Skip to content

Commit

Permalink
Sync to upstream/release/612 (#1162)
Browse files Browse the repository at this point in the history
New solver

* Fix bugs where bidirectional type inference would fail to take effect
at the proper stage.
* Improve inference of mutually recursive functions
* Fix crashes

---------

Co-authored-by: Aaron Weiss <[email protected]>
Co-authored-by: Andy Friesen <[email protected]>
Co-authored-by: Vyacheslav Egorov <[email protected]>
  • Loading branch information
3 people authored Feb 9, 2024
1 parent 67ce75e commit d6c2472
Show file tree
Hide file tree
Showing 13 changed files with 796 additions and 334 deletions.
1 change: 1 addition & 0 deletions Analysis/include/Luau/Constraint.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ struct FunctionCheckConstraint
TypePackId argsPack;

class AstExprCall* callSite = nullptr;
NotNull<DenseHashMap<const AstExpr*, TypeId>> astExpectedTypes;
};

// prim FreeType ExpectedType PrimitiveType
Expand Down
186 changes: 131 additions & 55 deletions Analysis/src/ConstraintGenerator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,53 @@ void forEachConstraint(const Checkpoint& start, const Checkpoint& end, const Con
f(cg->constraints[i]);
}

struct HasFreeType : TypeOnceVisitor
{
bool result = false;

HasFreeType()
{
}

bool visit(TypeId ty) override
{
if (result || ty->persistent)
return false;
return true;
}

bool visit(TypePackId tp) override
{
if (result)
return false;
return true;
}

bool visit(TypeId ty, const ClassType&) override
{
return false;
}

bool visit(TypeId ty, const FreeType&) override
{
result = true;
return false;
}

bool visit(TypePackId ty, const FreeTypePack&) override
{
result = true;
return false;
}
};

bool hasFreeType(TypeId ty)
{
HasFreeType hft{};
hft.traverse(ty);
return hft.result;
}

} // namespace

ConstraintGenerator::ConstraintGenerator(ModulePtr module, NotNull<Normalizer> normalizer, NotNull<ModuleResolver> moduleResolver,
Expand Down Expand Up @@ -229,6 +276,8 @@ std::optional<TypeId> ConstraintGenerator::lookup(const ScopePtr& scope, DefId d
{
if (auto found = scope->lookup(def))
return *found;
else if (phi->operands.size() == 1)
return lookup(scope, phi->operands[0], prototype);
else if (!prototype)
return std::nullopt;

Expand Down Expand Up @@ -837,6 +886,10 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatLocalFuncti
FunctionSignature sig = checkFunctionSignature(scope, function->func, /* expectedType */ std::nullopt, function->name->location);
sig.bodyScope->bindings[function->name] = Binding{sig.signature, function->func->location};

bool sigFullyDefined = !hasFreeType(sig.signature);
if (sigFullyDefined)
asMutable(functionType)->ty.emplace<BoundType>(sig.signature);

DefId def = dfg->getDef(function->name);
scope->lvalueTypes[def] = functionType;
scope->rvalueRefinements[def] = functionType;
Expand All @@ -847,25 +900,32 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatLocalFuncti
checkFunctionBody(sig.bodyScope, function->func);
Checkpoint end = checkpoint(this);

NotNull<Scope> constraintScope{sig.signatureScope ? sig.signatureScope.get() : sig.bodyScope.get()};
std::unique_ptr<Constraint> c =
std::make_unique<Constraint>(constraintScope, function->name->location, GeneralizationConstraint{functionType, sig.signature});
if (!sigFullyDefined)
{
NotNull<Scope> constraintScope{sig.signatureScope ? sig.signatureScope.get() : sig.bodyScope.get()};
std::unique_ptr<Constraint> c =
std::make_unique<Constraint>(constraintScope, function->name->location, GeneralizationConstraint{functionType, sig.signature});

Constraint* previous = nullptr;
forEachConstraint(start, end, this, [&c, &previous](const ConstraintPtr& constraint) {
c->dependencies.push_back(NotNull{constraint.get()});
Constraint* previous = nullptr;
forEachConstraint(start, end, this,
[&c, &previous](const ConstraintPtr& constraint)
{
c->dependencies.push_back(NotNull{constraint.get()});

if (auto psc = get<PackSubtypeConstraint>(*constraint); psc && psc->returns)
{
if (previous)
constraint->dependencies.push_back(NotNull{previous});
if (auto psc = get<PackSubtypeConstraint>(*constraint); psc && psc->returns)
{
if (previous)
constraint->dependencies.push_back(NotNull{previous});

previous = constraint.get();
}
});
previous = constraint.get();
}
});

addConstraint(scope, std::move(c));
module->astTypes[function->func] = functionType;
addConstraint(scope, std::move(c));
module->astTypes[function->func] = functionType;
}
else
module->astTypes[function->func] = sig.signature;

return ControlFlow::None;
}
Expand All @@ -879,12 +939,19 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatFunction* f

Checkpoint start = checkpoint(this);
FunctionSignature sig = checkFunctionSignature(scope, function->func, /* expectedType */ std::nullopt, function->name->location);
bool sigFullyDefined = !hasFreeType(sig.signature);

if (sigFullyDefined)
asMutable(generalizedType)->ty.emplace<BoundType>(sig.signature);

DenseHashSet<Constraint*> excludeList{nullptr};

DefId def = dfg->getDef(function->name);
std::optional<TypeId> existingFunctionTy = lookup(scope, def);

if (sigFullyDefined && existingFunctionTy && get<BlockedType>(*existingFunctionTy))
asMutable(*existingFunctionTy)->ty.emplace<BoundType>(sig.signature);

if (AstExprLocal* localName = function->name->as<AstExprLocal>())
{
if (existingFunctionTy)
Expand All @@ -906,7 +973,8 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatFunction* f
if (!existingFunctionTy)
ice->ice("prepopulateGlobalScope did not populate a global name", globalName->location);

generalizedType = *existingFunctionTy;
if (!sigFullyDefined)
generalizedType = *existingFunctionTy;

sig.bodyScope->bindings[globalName->name] = Binding{sig.signature, globalName->location};
sig.bodyScope->lvalueTypes[def] = sig.signature;
Expand Down Expand Up @@ -943,25 +1011,30 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatFunction* f
checkFunctionBody(sig.bodyScope, function->func);
Checkpoint end = checkpoint(this);

NotNull<Scope> constraintScope{sig.signatureScope ? sig.signatureScope.get() : sig.bodyScope.get()};
std::unique_ptr<Constraint> c =
std::make_unique<Constraint>(constraintScope, function->name->location, GeneralizationConstraint{generalizedType, sig.signature});
if (!sigFullyDefined)
{
NotNull<Scope> constraintScope{sig.signatureScope ? sig.signatureScope.get() : sig.bodyScope.get()};
std::unique_ptr<Constraint> c =
std::make_unique<Constraint>(constraintScope, function->name->location, GeneralizationConstraint{generalizedType, sig.signature});

Constraint* previous = nullptr;
forEachConstraint(start, end, this, [&c, &excludeList, &previous](const ConstraintPtr& constraint) {
if (!excludeList.contains(constraint.get()))
c->dependencies.push_back(NotNull{constraint.get()});
Constraint* previous = nullptr;
forEachConstraint(start, end, this,
[&c, &excludeList, &previous](const ConstraintPtr& constraint)
{
if (!excludeList.contains(constraint.get()))
c->dependencies.push_back(NotNull{constraint.get()});

if (auto psc = get<PackSubtypeConstraint>(*constraint); psc && psc->returns)
{
if (previous)
constraint->dependencies.push_back(NotNull{previous});
if (auto psc = get<PackSubtypeConstraint>(*constraint); psc && psc->returns)
{
if (previous)
constraint->dependencies.push_back(NotNull{previous});

previous = constraint.get();
}
});
previous = constraint.get();
}
});

addConstraint(scope, std::move(c));
addConstraint(scope, std::move(c));
}

return ControlFlow::None;
}
Expand Down Expand Up @@ -1626,24 +1699,6 @@ InferencePack ConstraintGenerator::checkPack(const ScopePtr& scope, AstExprCall*
TypePackId argPack = addTypePack(std::move(args), argTail);
FunctionType ftv(TypeLevel{}, scope.get(), argPack, rets, std::nullopt, call->self);

NotNull<Constraint> fcc = addConstraint(scope, call->func->location,
FunctionCallConstraint{
fnType,
argPack,
rets,
call,
std::move(discriminantTypes),
&module->astOverloadResolvedTypes,
});

NotNull<Constraint> foo = addConstraint(scope, call->func->location,
FunctionCheckConstraint{
fnType,
argPack,
call
}
);

/*
* To make bidirectional type checking work, we need to solve these constraints in a particular order:
*
Expand All @@ -1653,14 +1708,35 @@ InferencePack ConstraintGenerator::checkPack(const ScopePtr& scope, AstExprCall*
* 4. Solve the call
*/

forEachConstraint(funcBeginCheckpoint, funcEndCheckpoint, this, [foo](const ConstraintPtr& constraint) {
foo->dependencies.emplace_back(constraint.get());
NotNull<Constraint> checkConstraint = addConstraint(scope, call->func->location,
FunctionCheckConstraint{
fnType,
argPack,
call,
NotNull{&module->astExpectedTypes}
}
);

forEachConstraint(funcBeginCheckpoint, funcEndCheckpoint, this, [checkConstraint](const ConstraintPtr& constraint) {
checkConstraint->dependencies.emplace_back(constraint.get());
});

forEachConstraint(argBeginCheckpoint, argEndCheckpoint, this, [foo, fcc](const ConstraintPtr& constraint) {
constraint->dependencies.emplace_back(foo);
NotNull<Constraint> callConstraint = addConstraint(scope, call->func->location,
FunctionCallConstraint{
fnType,
argPack,
rets,
call,
std::move(discriminantTypes),
&module->astOverloadResolvedTypes,
});

callConstraint->dependencies.push_back(checkConstraint);

forEachConstraint(argBeginCheckpoint, argEndCheckpoint, this, [checkConstraint, callConstraint](const ConstraintPtr& constraint) {
constraint->dependencies.emplace_back(checkConstraint);

fcc->dependencies.emplace_back(constraint.get());
callConstraint->dependencies.emplace_back(constraint.get());
});

return InferencePack{rets, {refinementArena.variadic(returnRefinements)}};
Expand Down Expand Up @@ -1884,7 +1960,7 @@ Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprFunction* fun
checkFunctionBody(sig.bodyScope, func);
Checkpoint endCheckpoint = checkpoint(this);

if (generalize)
if (generalize && hasFreeType(sig.signature))
{
TypeId generalizedTy = arena->addType(BlockedType{});
NotNull<Constraint> gc = addConstraint(sig.signatureScope, func->location, GeneralizationConstraint{generalizedTy, sig.signature});
Expand Down
54 changes: 29 additions & 25 deletions Analysis/src/ConstraintSolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1139,7 +1139,7 @@ bool ConstraintSolver::tryDispatch(const FunctionCheckConstraint& c, NotNull<con
return block(fn, constraint);

if (isBlocked(argsPack))
return block(argsPack, constraint);
return true;

// We know the type of the function and the arguments it expects to receive.
// We also know the TypeIds of the actual arguments that will be passed.
Expand All @@ -1152,38 +1152,42 @@ bool ConstraintSolver::tryDispatch(const FunctionCheckConstraint& c, NotNull<con
// types.

// FIXME: Bidirectional type checking of overloaded functions is not yet supported.
if (auto ftv = get<FunctionType>(fn))
const FunctionType* ftv = get<FunctionType>(fn);
if (!ftv)
return true;

const std::vector<TypeId> expectedArgs = flatten(ftv->argTypes).first;
const std::vector<TypeId> argPackHead = flatten(argsPack).first;

for (size_t i = 0; i < c.callSite->args.size && i < expectedArgs.size() && i < argPackHead.size(); ++i)
{
const std::vector<TypeId> expectedArgs = flatten(ftv->argTypes).first;
const std::vector<TypeId> argPackHead = flatten(argsPack).first;
const TypeId expectedArgTy = follow(expectedArgs[i]);
const TypeId actualArgTy = follow(argPackHead[i]);
const AstExpr* expr = c.callSite->args.data[i];

for (size_t i = 0; i < c.callSite->args.size && i < expectedArgs.size() && i < argPackHead.size(); ++i)
{
const TypeId expectedArgTy = follow(expectedArgs[i]);
const TypeId actualArgTy = follow(argPackHead[i]);
(*c.astExpectedTypes)[expr] = expectedArgTy;

const FunctionType* expectedLambdaTy = get<FunctionType>(expectedArgTy);
const FunctionType* lambdaTy = get<FunctionType>(actualArgTy);
const AstExprFunction* lambdaExpr = c.callSite->args.data[i]->as<AstExprFunction>();
const FunctionType* expectedLambdaTy = get<FunctionType>(expectedArgTy);
const FunctionType* lambdaTy = get<FunctionType>(actualArgTy);
const AstExprFunction* lambdaExpr = expr->as<AstExprFunction>();

if (expectedLambdaTy && lambdaTy && lambdaExpr)
{
const std::vector<TypeId> expectedLambdaArgTys = flatten(expectedLambdaTy->argTypes).first;
const std::vector<TypeId> lambdaArgTys = flatten(lambdaTy->argTypes).first;
if (expectedLambdaTy && lambdaTy && lambdaExpr)
{
const std::vector<TypeId> expectedLambdaArgTys = flatten(expectedLambdaTy->argTypes).first;
const std::vector<TypeId> lambdaArgTys = flatten(lambdaTy->argTypes).first;

for (size_t j = 0; j < expectedLambdaArgTys.size() && j < lambdaArgTys.size() && j < lambdaExpr->args.size; ++j)
for (size_t j = 0; j < expectedLambdaArgTys.size() && j < lambdaArgTys.size() && j < lambdaExpr->args.size; ++j)
{
if (!lambdaExpr->args.data[j]->annotation && get<FreeType>(follow(lambdaArgTys[j])))
{
if (!lambdaExpr->args.data[j]->annotation && get<FreeType>(follow(lambdaArgTys[j])))
{
asMutable(lambdaArgTys[j])->ty.emplace<BoundType>(expectedLambdaArgTys[j]);
}
asMutable(lambdaArgTys[j])->ty.emplace<BoundType>(expectedLambdaArgTys[j]);
}
}
else
{
Unifier2 u2{arena, builtinTypes, constraint->scope, NotNull{&iceReporter}};
u2.unify(actualArgTy, expectedArgTy);
}
}
else if (expr->is<AstExprConstantBool>() || expr->is<AstExprConstantString>() || expr->is<AstExprConstantNumber>() || expr->is<AstExprConstantNil>() || expr->is<AstExprTable>())
{
Unifier2 u2{arena, builtinTypes, constraint->scope, NotNull{&iceReporter}};
u2.unify(actualArgTy, expectedArgTy);
}
}

Expand Down
4 changes: 4 additions & 0 deletions Analysis/src/Subtyping.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -593,6 +593,8 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, TypePackId
return SubtypingResult{false}.withSubComponent(TypePath::PackField::Tail);
}
}
else if (get<ErrorTypePack>(*subTail))
return SubtypingResult{true}.withSubComponent(TypePath::PackField::Tail);
else
unexpected(*subTail);
}
Expand Down Expand Up @@ -643,6 +645,8 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, TypePackId
return SubtypingResult{false}.withSuperComponent(TypePath::PackField::Tail);
}
}
else if (get<ErrorTypePack>(*superTail))
return SubtypingResult{true}.withSuperComponent(TypePath::PackField::Tail);
else
unexpected(*superTail);
}
Expand Down
Loading

0 comments on commit d6c2472

Please sign in to comment.