Skip to content

Commit

Permalink
refactor and fix validation of conditional return type
Browse files Browse the repository at this point in the history
  • Loading branch information
gabritto committed Oct 4, 2024
1 parent 7e1fe01 commit bbaf88f
Showing 1 changed file with 84 additions and 66 deletions.
150 changes: 84 additions & 66 deletions src/compiler/checker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2367,6 +2367,8 @@ export function createTypeChecker(host: TypeCheckerHost): TypeChecker {
[".jsx", ".jsx"],
[".json", ".json"],
];

var narrowableReturnTypeCache = new Map<TypeId, boolean>;
/* eslint-enable no-var */

initializeTypeChecker();
Expand Down Expand Up @@ -19266,6 +19268,7 @@ export function createTypeChecker(host: TypeCheckerHost): TypeChecker {
forConstraint: boolean,
aliasSymbol?: Symbol,
aliasTypeArguments?: readonly Type[],
forNarrowing?: boolean,
): Type {
let result;
let extraTypes: Type[] | undefined;
Expand All @@ -19288,7 +19291,9 @@ export function createTypeChecker(host: TypeCheckerHost): TypeChecker {
if (checkType === wildcardType || extendsType === wildcardType) {
return wildcardType;
}
const effectiveCheckType = isNarrowingSubstitutionType(checkType) ? (checkType as SubstitutionType).constraint : checkType;
const effectiveCheckType = forNarrowing && isNarrowingSubstitutionType(checkType)
? (checkType as SubstitutionType).constraint
: checkType;
const checkTypeNode = skipTypeParentheses(root.node.checkType);
const extendsTypeNode = skipTypeParentheses(root.node.extendsType);
// When the check and extends types are simple tuple types of the same arity, we defer resolution of the
Expand Down Expand Up @@ -20498,19 +20503,26 @@ export function createTypeChecker(host: TypeCheckerHost): TypeChecker {
if (!result) {
const newMapper = createTypeMapper(root.outerTypeParameters, typeArguments);
const checkType = root.checkType;
let distributionType = root.isDistributive ? getReducedType(getMappedType(checkType, newMapper)) : undefined;
let narrowingBaseType: Type | undefined;
let mappedCheckType = root.isDistributive ? getReducedType(getMappedType(checkType, newMapper)) : undefined;
if (mappedCheckType && isNarrowingSubstitutionType(mappedCheckType)) {
narrowingBaseType = (mappedCheckType as SubstitutionType).baseType;
mappedCheckType = getReducedType((mappedCheckType as SubstitutionType).constraint);
const forNarrowing = distributionType && isNarrowingSubstitutionType(distributionType) && isNarrowableConditionalTypeWorker(type);
if (forNarrowing) {
narrowingBaseType = (distributionType as SubstitutionType).baseType;
distributionType = getReducedType((distributionType as SubstitutionType).constraint);
}
const distributionType = root.isDistributive ? mappedCheckType : undefined;
// Distributive conditional types are distributed over union types. For example, when the
// distributive conditional type T extends U ? X : Y is instantiated with A | B for T, the
// result is (A extends U ? X : Y) | (B extends U ? X : Y).
if (distributionType && checkType !== distributionType && distributionType.flags & (TypeFlags.Union | TypeFlags.Never)) {
const mapperCallback = narrowingBaseType ?
(t: Type) => getConditionalType(root, prependTypeMapping(checkType, getSubstitutionType(narrowingBaseType, t, /*isNarrowed*/ true), newMapper), forConstraint) :
(t: Type) => getConditionalType(
root,
prependTypeMapping(checkType, getSubstitutionType(narrowingBaseType, t, /*isNarrowed*/ true), newMapper),
forConstraint,
/*aliasSymbol*/ undefined,
/*aliasTypeArguments*/ undefined,
forNarrowing,
) :
(t: Type) => getConditionalType(root, prependTypeMapping(checkType, t, newMapper), forConstraint);
if (narrowingBaseType) {
result = mapType(distributionType, mapperCallback, /*noReductions*/ undefined, /*toIntersection*/ true);
Expand All @@ -20520,7 +20532,7 @@ export function createTypeChecker(host: TypeCheckerHost): TypeChecker {
}
}
else {
result = getConditionalType(root, newMapper, forConstraint, aliasSymbol, aliasTypeArguments);
result = getConditionalType(root, newMapper, forConstraint, aliasSymbol, aliasTypeArguments, forNarrowing);
}
root.instantiations!.set(id, result);
}
Expand Down Expand Up @@ -45763,10 +45775,12 @@ export function createTypeChecker(host: TypeCheckerHost): TypeChecker {
const allTypeParameters = appendTypeParameters(getOuterTypeParameters(container, /*includeThisTypes*/ false), getEffectiveTypeParameterDeclarations(container as DeclarationWithTypeParameters));
const narrowableTypeParameters = allTypeParameters && getNarrowableTypeParameters(allTypeParameters);

// >> TODO: another optimization would be to check if any of the narrowable type parameters
// match the types in the return type that can be narrowed
if (
!narrowableTypeParameters ||
!narrowableTypeParameters.length ||
!isNarrowableReturnType(narrowableTypeParameters.map(trio => trio[0]), unwrappedReturnType)
!isNarrowableReturnType(unwrappedReturnType)
) {
checkTypeAssignableToAndOptionallyElaborate(unwrappedExprType, unwrappedReturnType, errorNode, expr);
return;
Expand Down Expand Up @@ -45926,66 +45940,70 @@ export function createTypeChecker(host: TypeCheckerHost): TypeChecker {

// A narrowable indexed access type is one that has the shape `A[T]`,
// where `T` is a narrowable type parameter.
function isNarrowableReturnType(returnType: IndexedAccessType | ConditionalType): boolean {
return isConditionalType(returnType)
? isNarrowableConditionalType(returnType)
: !!(returnType.indexType.flags & TypeFlags.TypeParameter);
}

function isNarrowableConditionalType(type: ConditionalType): boolean {
let result = narrowableReturnTypeCache.get(type.id);
if (result === undefined) {
result = isNarrowableConditionalTypeWorker(type);
narrowableReturnTypeCache.set(type.id, result);
}
return result;
}

// A narrowable conditional type is one that has the following shape:
// `T extends A ? TrueBranch<T> : FalseBranch<T>`, such that:
// (0) The conditional type's check type is a narrowable type parameter;
// (1) `A` is a type belonging to the constraint of the type parameter,
// or a union of types belonging to the constraint of the type parameter;
// (2) There are no `infer` type parameters in the conditional type;
// (3) `TrueBranch<T>` and `FalseBranch<T>` must be valid, recursively;
// In particular, the false-most branch of the conditional type must be `never`.
function isNarrowableReturnType(
typeParameters: TypeParameter[],
returnType: IndexedAccessType | ConditionalType,
): boolean {
return !isConditionalType(returnType)
&& typeParameters.includes(returnType.indexType)
|| isNarrowableConditionalType(returnType, /*branch*/ undefined);
// `branch` can be `true` if `type` is the true type of a conditional, `false` if it's the false type of a conditional,
// and `undefined` if neither.
function isNarrowableConditionalType(type: Type, branch: boolean | undefined): boolean {
if (!isConditionalType(type)) {
// This is type `R` in `T extends A ? R : ...`
if (branch === true) {
return true;
}
// This is type `never` in `T extends A ? R : never`
if (branch === false) {
return type === neverType;
}
return false;
}
// (0)
if (!(type.checkType.flags & TypeFlags.TypeParameter)) {
return false;
}
const typeParameter = typeParameters.find(tp => tp === type.checkType);
if (!typeParameter) {
return false;
}
const constraintType = getConstraintOfTypeParameter(typeParameter) as UnionType;
// (0)
if (!type.root.isDistributive) {
return false;
}
// (2)
if (type.root.inferTypeParameters?.length) {
return false;
}
// (1)
if (
!everyType(type.extendsType, extendsType =>
some(
constraintType.types,
constraintType => isTypeIdenticalTo(constraintType, extendsType),
))
) {
return false;
}
// `T extends A ? TrueBranch<T> : FalseBranch<T>`, in other words:
// (0) The conditional type is distributive;
// (1) The conditional type has no `infer` type parameters;
// (2) The conditional type's check type is a narrowable type parameter (i.e. a type parameter with a union constraint);
// (3) The extends type `A` is a type or a union of types belonging to the union constraint of the type parameter;
// (4) `TrueBranch<T>` and `FalseBranch<T>` must be valid, recursively.
// In particular, the false-most branch of the conditional type must be `never`.
function isNarrowableConditionalTypeWorker(type: ConditionalType): boolean {
// (0)
if (!type.root.isDistributive) {
return false;
}
// (1)
if (type.root.inferTypeParameters) {
return false;
}

// (2)
if (!(type.checkType.flags & TypeFlags.TypeParameter)) {
return false;
}

return isNarrowableConditionalType(getTrueTypeFromConditionalType(type), /*branch*/ true) &&
isNarrowableConditionalType(getFalseTypeFromConditionalType(type), /*branch*/ false);
// (2)
const constraintType = getConstraintOfTypeParameter(type.checkType as TypeParameter);
if (!constraintType || !(constraintType.flags & TypeFlags.Union)) {
return false;
}
// (3)
if (
!everyType(type.extendsType, extendsType =>
some(
(constraintType as UnionType).types,
constraintType => isTypeIdenticalTo(constraintType, extendsType),
))
) {
return false;
}

// (4)
const trueType = getTrueTypeFromConditionalType(type);
const falseType = getFalseTypeFromConditionalType(type);
const isValidTrueType = isConditionalType(trueType)
? isNarrowableConditionalTypeWorker(trueType)
: true;
const isValidFalseType = isConditionalType(falseType)
? isNarrowableConditionalTypeWorker(falseType)
: falseType === neverType;
return isValidTrueType && isValidFalseType;
}

function isConditionalType(type: Type): type is ConditionalType {
Expand Down

0 comments on commit bbaf88f

Please sign in to comment.