From c6ac06e6569da483c8e217af6c728e618a014575 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 1 Sep 2022 16:00:14 -0700 Subject: [PATCH 01/66] Sync to upstream/release/543 --- Analysis/include/Luau/Anyification.h | 3 +- Analysis/include/Luau/BuiltinDefinitions.h | 1 + Analysis/include/Luau/Constraint.h | 31 +- .../include/Luau/ConstraintGraphBuilder.h | 11 +- Analysis/include/Luau/ConstraintSolver.h | 39 +- .../include/Luau/ConstraintSolverLogger.h | 3 +- Analysis/include/Luau/Frontend.h | 2 +- Analysis/include/Luau/Scope.h | 3 + Analysis/include/Luau/ToString.h | 1 - Analysis/include/Luau/TypeArena.h | 3 +- Analysis/include/Luau/TypeInfer.h | 3 +- Analysis/include/Luau/TypePack.h | 14 +- Analysis/include/Luau/TypeUtils.h | 10 +- Analysis/include/Luau/TypeVar.h | 12 +- Analysis/include/Luau/Unifier.h | 3 +- Analysis/include/Luau/VisitTypeVar.h | 25 +- Analysis/src/Anyification.cpp | 11 +- Analysis/src/Autocomplete.cpp | 12 +- Analysis/src/BuiltinDefinitions.cpp | 57 +++ Analysis/src/Clone.cpp | 36 +- Analysis/src/Constraint.cpp | 7 +- Analysis/src/ConstraintGraphBuilder.cpp | 225 +++++++---- Analysis/src/ConstraintSolver.cpp | 382 ++++++++++++++++-- Analysis/src/ConstraintSolverLogger.cpp | 1 + Analysis/src/Frontend.cpp | 16 +- Analysis/src/Instantiation.cpp | 1 + Analysis/src/JsonEmitter.cpp | 6 +- Analysis/src/Linter.cpp | 30 +- Analysis/src/Module.cpp | 3 +- Analysis/src/Substitution.cpp | 4 +- Analysis/src/ToString.cpp | 36 +- Analysis/src/TypeArena.cpp | 13 +- Analysis/src/TypeAttach.cpp | 5 + Analysis/src/TypeChecker2.cpp | 224 +++++++++- Analysis/src/TypeInfer.cpp | 61 +-- Analysis/src/TypePack.cpp | 7 + Analysis/src/TypeUtils.cpp | 34 +- Analysis/src/TypeVar.cpp | 21 +- Analysis/src/TypedAllocator.cpp | 2 + Analysis/src/Unifiable.cpp | 3 +- Analysis/src/Unifier.cpp | 15 +- Ast/include/Luau/Ast.h | 10 + Ast/include/Luau/Parser.h | 6 + Ast/src/Ast.cpp | 13 + Ast/src/Lexer.cpp | 17 +- Ast/src/Parser.cpp | 87 ++-- Common/include/Luau/Bytecode.h | 2 +- Common/include/Luau/ExperimentalFlags.h | 3 +- Compiler/src/BytecodeBuilder.cpp | 5 - Compiler/src/Compiler.cpp | 371 +++++------------ VM/src/lbuiltins.cpp | 6 +- VM/src/lstrlib.cpp | 5 - VM/src/lvmexecute.cpp | 9 + tests/AssemblyBuilderX64.test.cpp | 3 +- tests/Compiler.test.cpp | 27 +- tests/Conformance.test.cpp | 20 +- tests/ConstraintGraphBuilder.test.cpp | 23 +- tests/ConstraintSolver.test.cpp | 11 +- tests/Fixture.cpp | 5 +- tests/Fixture.h | 5 +- tests/JsonEmitter.test.cpp | 2 +- tests/Linter.test.cpp | 14 +- tests/Parser.test.cpp | 45 ++- tests/TypeInfer.anyerror.test.cpp | 13 + tests/TypeInfer.builtins.test.cpp | 12 - tests/TypeInfer.definitions.test.cpp | 26 ++ tests/TypeInfer.functions.test.cpp | 1 + tests/TypeInfer.loops.test.cpp | 41 +- tests/TypeInfer.modules.test.cpp | 29 ++ tests/TypeInfer.primitives.test.cpp | 12 +- tests/TypeInfer.provisional.test.cpp | 20 +- tests/TypeInfer.tables.test.cpp | 2 - tests/conformance/basic.lua | 5 + tests/conformance/errors.lua | 14 +- tools/faillist.txt | 72 +--- tools/test_dcr.py | 16 +- 76 files changed, 1531 insertions(+), 797 deletions(-) diff --git a/Analysis/include/Luau/Anyification.h b/Analysis/include/Luau/Anyification.h index ee8d66898..9dd7d8e00 100644 --- a/Analysis/include/Luau/Anyification.h +++ b/Analysis/include/Luau/Anyification.h @@ -19,6 +19,7 @@ using ScopePtr = std::shared_ptr; // A substitution which replaces free types by any struct Anyification : Substitution { + Anyification(TypeArena* arena, NotNull scope, InternalErrorReporter* iceHandler, TypeId anyType, TypePackId anyTypePack); Anyification(TypeArena* arena, const ScopePtr& scope, InternalErrorReporter* iceHandler, TypeId anyType, TypePackId anyTypePack); NotNull scope; InternalErrorReporter* iceHandler; @@ -35,4 +36,4 @@ struct Anyification : Substitution bool ignoreChildren(TypePackId ty) override; }; -} \ No newline at end of file +} // namespace Luau \ No newline at end of file diff --git a/Analysis/include/Luau/BuiltinDefinitions.h b/Analysis/include/Luau/BuiltinDefinitions.h index 07d897b2f..28a4368e7 100644 --- a/Analysis/include/Luau/BuiltinDefinitions.h +++ b/Analysis/include/Luau/BuiltinDefinitions.h @@ -34,6 +34,7 @@ TypeId makeFunction( // Polymorphic std::initializer_list paramTypes, std::initializer_list paramNames, std::initializer_list retTypes); void attachMagicFunction(TypeId ty, MagicFunction fn); +void attachDcrMagicFunction(TypeId ty, DcrMagicFunction fn); Property makeProperty(TypeId ty, std::optional documentationSymbol = std::nullopt); void assignPropDocumentationSymbols(TableTypeVar::Props& props, const std::string& baseName); diff --git a/Analysis/include/Luau/Constraint.h b/Analysis/include/Luau/Constraint.h index e9f04e79f..d5cfcf3f2 100644 --- a/Analysis/include/Luau/Constraint.h +++ b/Analysis/include/Luau/Constraint.h @@ -56,6 +56,10 @@ struct UnaryConstraint TypeId resultType; }; +// let L : leftType +// let R : rightType +// in +// L op R : resultType struct BinaryConstraint { AstExprBinary::Op op; @@ -64,6 +68,14 @@ struct BinaryConstraint TypeId resultType; }; +// iteratee is iterable +// iterators is the iteration types. +struct IterableConstraint +{ + TypePackId iterator; + TypePackId variables; +}; + // name(namedType) = name struct NameConstraint { @@ -78,20 +90,31 @@ struct TypeAliasExpansionConstraint TypeId target; }; -using ConstraintV = Variant; using ConstraintPtr = std::unique_ptr; +struct FunctionCallConstraint +{ + std::vector> innerConstraints; + TypeId fn; + TypePackId result; + class AstExprCall* astFragment; +}; + +using ConstraintV = Variant; + struct Constraint { - Constraint(ConstraintV&& c, NotNull scope); + Constraint(NotNull scope, const Location& location, ConstraintV&& c); Constraint(const Constraint&) = delete; Constraint& operator=(const Constraint&) = delete; + NotNull scope; + Location location; ConstraintV c; + std::vector> dependencies; - NotNull scope; }; inline Constraint& asMutable(const Constraint& c) diff --git a/Analysis/include/Luau/ConstraintGraphBuilder.h b/Analysis/include/Luau/ConstraintGraphBuilder.h index 0e41e1eea..1cba0d33d 100644 --- a/Analysis/include/Luau/ConstraintGraphBuilder.h +++ b/Analysis/include/Luau/ConstraintGraphBuilder.h @@ -9,6 +9,7 @@ #include "Luau/Ast.h" #include "Luau/Constraint.h" #include "Luau/Module.h" +#include "Luau/ModuleResolver.h" #include "Luau/NotNull.h" #include "Luau/Symbol.h" #include "Luau/TypeVar.h" @@ -51,12 +52,15 @@ struct ConstraintGraphBuilder // It is pretty uncommon for constraint generation to itself produce errors, but it can happen. std::vector errors; + // Needed to resolve modules to make 'require' import types properly. + NotNull moduleResolver; // Occasionally constraint generation needs to produce an ICE. const NotNull ice; ScopePtr globalScope; - ConstraintGraphBuilder(const ModuleName& moduleName, ModulePtr module, TypeArena* arena, NotNull ice, const ScopePtr& globalScope); + ConstraintGraphBuilder(const ModuleName& moduleName, ModulePtr module, TypeArena* arena, NotNull moduleResolver, + NotNull ice, const ScopePtr& globalScope); /** * Fabricates a new free type belonging to a given scope. @@ -82,7 +86,7 @@ struct ConstraintGraphBuilder * @param scope the scope to add the constraint to. * @param cv the constraint variant to add. */ - void addConstraint(const ScopePtr& scope, ConstraintV cv); + void addConstraint(const ScopePtr& scope, const Location& location, ConstraintV cv); /** * Adds a constraint to a given scope. @@ -104,6 +108,7 @@ struct ConstraintGraphBuilder void visit(const ScopePtr& scope, AstStatBlock* block); void visit(const ScopePtr& scope, AstStatLocal* local); void visit(const ScopePtr& scope, AstStatFor* for_); + void visit(const ScopePtr& scope, AstStatForIn* forIn); void visit(const ScopePtr& scope, AstStatWhile* while_); void visit(const ScopePtr& scope, AstStatRepeat* repeat); void visit(const ScopePtr& scope, AstStatLocalFunction* function); @@ -117,8 +122,6 @@ struct ConstraintGraphBuilder void visit(const ScopePtr& scope, AstStatDeclareClass* declareClass); void visit(const ScopePtr& scope, AstStatDeclareFunction* declareFunction); - TypePackId checkExprList(const ScopePtr& scope, const AstArray& exprs); - TypePackId checkPack(const ScopePtr& scope, AstArray exprs); TypePackId checkPack(const ScopePtr& scope, AstExpr* expr); diff --git a/Analysis/include/Luau/ConstraintSolver.h b/Analysis/include/Luau/ConstraintSolver.h index a270ec997..002aa9475 100644 --- a/Analysis/include/Luau/ConstraintSolver.h +++ b/Analysis/include/Luau/ConstraintSolver.h @@ -17,6 +17,8 @@ namespace Luau // never dereference this pointer. using BlockedConstraintId = const void*; +struct ModuleResolver; + struct InstantiationSignature { TypeFun fn; @@ -42,6 +44,7 @@ struct ConstraintSolver // The entire set of constraints that the solver is trying to resolve. std::vector> constraints; NotNull rootScope; + ModuleName currentModuleName; // Constraints that the solver has generated, rather than sourcing from the // scope tree. @@ -63,9 +66,13 @@ struct ConstraintSolver // Recorded errors that take place within the solver. ErrorVec errors; + NotNull moduleResolver; + std::vector requireCycles; + ConstraintSolverLogger logger; - explicit ConstraintSolver(TypeArena* arena, NotNull rootScope); + explicit ConstraintSolver(TypeArena* arena, NotNull rootScope, ModuleName moduleName, NotNull moduleResolver, + std::vector requireCycles); /** * Attempts to dispatch all pending constraints and reach a type solution @@ -86,8 +93,17 @@ struct ConstraintSolver bool tryDispatch(const InstantiationConstraint& c, NotNull constraint, bool force); bool tryDispatch(const UnaryConstraint& c, NotNull constraint, bool force); bool tryDispatch(const BinaryConstraint& c, NotNull constraint, bool force); + bool tryDispatch(const IterableConstraint& c, NotNull constraint, bool force); bool tryDispatch(const NameConstraint& c, NotNull constraint); bool tryDispatch(const TypeAliasExpansionConstraint& c, NotNull constraint); + bool tryDispatch(const FunctionCallConstraint& c, NotNull constraint); + + // for a, ... in some_table do + bool tryDispatchIterableTable(TypeId iteratorTy, const IterableConstraint& c, NotNull constraint, bool force); + + // for a, ... in next_function, t, ... do + bool tryDispatchIterableFunction( + TypeId nextTy, TypeId tableTy, TypeId firstIndexTy, const IterableConstraint& c, NotNull constraint, bool force); void block(NotNull target, NotNull constraint); /** @@ -108,6 +124,11 @@ struct ConstraintSolver */ bool isBlocked(TypeId ty); + /** + * @returns true if the TypePackId is in a blocked state. + */ + bool isBlocked(TypePackId tp); + /** * Returns whether the constraint is blocked on anything. * @param constraint the constraint to check. @@ -133,10 +154,22 @@ struct ConstraintSolver /** Pushes a new solver constraint to the solver. * @param cv the body of the constraint. **/ - void pushConstraint(ConstraintV cv, NotNull scope); + void pushConstraint(NotNull scope, const Location& location, ConstraintV cv); + + /** + * Attempts to resolve a module from its module information. Returns the + * module-level return type of the module, or the error type if one cannot + * be found. Reports errors to the solver if the module cannot be found or + * the require is illegal. + * @param module the module information to look up. + * @param location the location where the require is taking place; used for + * error locations. + **/ + TypeId resolveModule(const ModuleInfo& module, const Location& location); void reportError(TypeErrorData&& data, const Location& location); void reportError(TypeError e); + private: /** * Marks a constraint as being blocked on a type or type pack. The constraint @@ -154,6 +187,8 @@ struct ConstraintSolver * @param progressed the type or type pack pointer that has progressed. **/ void unblock_(BlockedConstraintId progressed); + + ToStringOptions opts; }; void dump(NotNull rootScope, struct ToStringOptions& opts); diff --git a/Analysis/include/Luau/ConstraintSolverLogger.h b/Analysis/include/Luau/ConstraintSolverLogger.h index 55170c42b..65aa9a7e6 100644 --- a/Analysis/include/Luau/ConstraintSolverLogger.h +++ b/Analysis/include/Luau/ConstraintSolverLogger.h @@ -16,7 +16,8 @@ struct ConstraintSolverLogger { std::string compileOutput(); void captureBoundarySnapshot(const Scope* rootScope, std::vector>& unsolvedConstraints); - void prepareStepSnapshot(const Scope* rootScope, NotNull current, std::vector>& unsolvedConstraints, bool force); + void prepareStepSnapshot( + const Scope* rootScope, NotNull current, std::vector>& unsolvedConstraints, bool force); void commitPreparedStepSnapshot(); private: diff --git a/Analysis/include/Luau/Frontend.h b/Analysis/include/Luau/Frontend.h index f8da32732..556126892 100644 --- a/Analysis/include/Luau/Frontend.h +++ b/Analysis/include/Luau/Frontend.h @@ -157,7 +157,7 @@ struct Frontend ScopePtr getGlobalScope(); private: - ModulePtr check(const SourceModule& sourceModule, Mode mode, const ScopePtr& environmentScope); + ModulePtr check(const SourceModule& sourceModule, Mode mode, const ScopePtr& environmentScope, std::vector requireCycles); std::pair getSourceNode(CheckResult& checkResult, const ModuleName& name); SourceModule parse(const ModuleName& name, std::string_view src, const ParseOptions& parseOptions); diff --git a/Analysis/include/Luau/Scope.h b/Analysis/include/Luau/Scope.h index dc1233351..b7569d8eb 100644 --- a/Analysis/include/Luau/Scope.h +++ b/Analysis/include/Luau/Scope.h @@ -40,6 +40,9 @@ struct Scope std::optional varargPack; // All constraints belonging to this scope. std::vector constraints; + // Constraints belonging to this scope that are queued manually by other + // constraints. + std::vector unqueuedConstraints; TypeLevel level; diff --git a/Analysis/include/Luau/ToString.h b/Analysis/include/Luau/ToString.h index eabbc2beb..61e07e9fa 100644 --- a/Analysis/include/Luau/ToString.h +++ b/Analysis/include/Luau/ToString.h @@ -92,7 +92,6 @@ inline std::string toString(const Constraint& c) return toString(c, ToStringOptions{}); } - std::string toString(const TypeVar& tv, ToStringOptions& opts); std::string toString(const TypePackVar& tp, ToStringOptions& opts); diff --git a/Analysis/include/Luau/TypeArena.h b/Analysis/include/Luau/TypeArena.h index be36f19c7..decc8c590 100644 --- a/Analysis/include/Luau/TypeArena.h +++ b/Analysis/include/Luau/TypeArena.h @@ -29,9 +29,10 @@ struct TypeArena TypeId addTV(TypeVar&& tv); TypeId freshType(TypeLevel level); + TypeId freshType(Scope* scope); TypePackId addTypePack(std::initializer_list types); - TypePackId addTypePack(std::vector types); + TypePackId addTypePack(std::vector types, std::optional tail = {}); TypePackId addTypePack(TypePack pack); TypePackId addTypePack(TypePackVar pack); diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index e253eddf8..0b427946b 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -166,7 +166,8 @@ struct TypeChecker */ bool unify(TypeId subTy, TypeId superTy, const ScopePtr& scope, const Location& location); bool unify(TypeId subTy, TypeId superTy, const ScopePtr& scope, const Location& location, const UnifierOptions& options); - bool unify(TypePackId subTy, TypePackId superTy, const ScopePtr& scope, const Location& location, CountMismatch::Context ctx = CountMismatch::Context::Arg); + bool unify(TypePackId subTy, TypePackId superTy, const ScopePtr& scope, const Location& location, + CountMismatch::Context ctx = CountMismatch::Context::Arg); /** Attempt to unify the types. * If this fails, and the subTy type can be instantiated, do so and try unification again. diff --git a/Analysis/include/Luau/TypePack.h b/Analysis/include/Luau/TypePack.h index b17003b1a..8269230b4 100644 --- a/Analysis/include/Luau/TypePack.h +++ b/Analysis/include/Luau/TypePack.h @@ -15,6 +15,7 @@ struct TypeArena; struct TypePack; struct VariadicTypePack; +struct BlockedTypePack; struct TypePackVar; @@ -24,7 +25,7 @@ using TypePackId = const TypePackVar*; using FreeTypePack = Unifiable::Free; using BoundTypePack = Unifiable::Bound; using GenericTypePack = Unifiable::Generic; -using TypePackVariant = Unifiable::Variant; +using TypePackVariant = Unifiable::Variant; /* A TypePack is a rope-like string of TypeIds. We use this structure to encode * notions like packs of unknown length and packs of any length, as well as more @@ -43,6 +44,17 @@ struct VariadicTypePack bool hidden = false; // if true, we don't display this when toString()ing a pack with this variadic as its tail. }; +/** + * Analogous to a BlockedTypeVar. + */ +struct BlockedTypePack +{ + BlockedTypePack(); + size_t index; + + static size_t nextIndex; +}; + struct TypePackVar { explicit TypePackVar(const TypePackVariant& ty); diff --git a/Analysis/include/Luau/TypeUtils.h b/Analysis/include/Luau/TypeUtils.h index 0aff5a7df..6c611fb2c 100644 --- a/Analysis/include/Luau/TypeUtils.h +++ b/Analysis/include/Luau/TypeUtils.h @@ -11,12 +11,16 @@ namespace Luau { +struct TxnLog; + using ScopePtr = std::shared_ptr; std::optional findMetatableEntry(ErrorVec& errors, TypeId type, const std::string& entry, Location location); std::optional findTablePropertyRespectingMeta(ErrorVec& errors, TypeId ty, const std::string& name, Location location); -std::optional getIndexTypeFromType( - const ScopePtr& scope, ErrorVec& errors, TypeArena* arena, TypeId type, const std::string& prop, const Location& location, bool addErrors, - InternalErrorReporter& handle); +std::optional getIndexTypeFromType(const ScopePtr& scope, ErrorVec& errors, TypeArena* arena, TypeId type, const std::string& prop, + const Location& location, bool addErrors, InternalErrorReporter& handle); + +// Returns the minimum and maximum number of types the argument list can accept. +std::pair> getParameterExtents(const TxnLog* log, TypePackId tp); } // namespace Luau diff --git a/Analysis/include/Luau/TypeVar.h b/Analysis/include/Luau/TypeVar.h index 35b37394a..e67b36014 100644 --- a/Analysis/include/Luau/TypeVar.h +++ b/Analysis/include/Luau/TypeVar.h @@ -1,11 +1,13 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include "Luau/Ast.h" #include "Luau/DenseHash.h" #include "Luau/Predicate.h" #include "Luau/Unifiable.h" #include "Luau/Variant.h" #include "Luau/Common.h" +#include "Luau/NotNull.h" #include #include @@ -262,6 +264,8 @@ struct WithPredicate using MagicFunction = std::function>( struct TypeChecker&, const std::shared_ptr&, const class AstExprCall&, WithPredicate)>; +using DcrMagicFunction = std::function, TypePackId, const class AstExprCall*)>; + struct FunctionTypeVar { // Global monomorphic function @@ -287,7 +291,8 @@ struct FunctionTypeVar std::vector> argNames; TypePackId retTypes; std::optional definition; - MagicFunction magicFunction = nullptr; // Function pointer, can be nullptr. + MagicFunction magicFunction = nullptr; // Function pointer, can be nullptr. + DcrMagicFunction dcrMagicFunction = nullptr; // can be nullptr bool hasSelf; Tags tags; bool hasNoGenerics = false; @@ -462,8 +467,9 @@ struct TypeFun */ struct PendingExpansionTypeVar { - PendingExpansionTypeVar(TypeFun fn, std::vector typeArguments, std::vector packArguments); - TypeFun fn; + PendingExpansionTypeVar(std::optional prefix, AstName name, std::vector typeArguments, std::vector packArguments); + std::optional prefix; + AstName name; std::vector typeArguments; std::vector packArguments; size_t index; diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index 312b05849..c7eb51a65 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -59,7 +59,8 @@ struct Unifier UnifierSharedState& sharedState; - Unifier(TypeArena* types, Mode mode, NotNull scope, const Location& location, Variance variance, UnifierSharedState& sharedState, TxnLog* parentLog = nullptr); + Unifier(TypeArena* types, Mode mode, NotNull scope, const Location& location, Variance variance, UnifierSharedState& sharedState, + TxnLog* parentLog = nullptr); // Test whether the two type vars unify. Never commits the result. ErrorVec canUnify(TypeId subTy, TypeId superTy); diff --git a/Analysis/include/Luau/VisitTypeVar.h b/Analysis/include/Luau/VisitTypeVar.h index 9d7fa9fef..315e5992f 100644 --- a/Analysis/include/Luau/VisitTypeVar.h +++ b/Analysis/include/Luau/VisitTypeVar.h @@ -188,6 +188,10 @@ struct GenericTypeVarVisitor { return visit(tp); } + virtual bool visit(TypePackId tp, const BlockedTypePack& btp) + { + return visit(tp); + } void traverse(TypeId ty) { @@ -314,24 +318,6 @@ struct GenericTypeVarVisitor { if (visit(ty, *petv)) { - traverse(petv->fn.type); - - for (const GenericTypeDefinition& p : petv->fn.typeParams) - { - traverse(p.ty); - - if (p.defaultValue) - traverse(*p.defaultValue); - } - - for (const GenericTypePackDefinition& p : petv->fn.typePackParams) - { - traverse(p.tp); - - if (p.defaultValue) - traverse(*p.defaultValue); - } - for (TypeId a : petv->typeArguments) traverse(a); @@ -388,6 +374,9 @@ struct GenericTypeVarVisitor if (res) traverse(pack->ty); } + else if (auto btp = get(tp)) + visit(tp, *btp); + else LUAU_ASSERT(!"GenericTypeVarVisitor::traverse(TypePackId) is not exhaustive!"); diff --git a/Analysis/src/Anyification.cpp b/Analysis/src/Anyification.cpp index b6e58009c..abcaba020 100644 --- a/Analysis/src/Anyification.cpp +++ b/Analysis/src/Anyification.cpp @@ -11,15 +11,20 @@ LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution) namespace Luau { -Anyification::Anyification(TypeArena* arena, const ScopePtr& scope, InternalErrorReporter* iceHandler, TypeId anyType, TypePackId anyTypePack) +Anyification::Anyification(TypeArena* arena, NotNull scope, InternalErrorReporter* iceHandler, TypeId anyType, TypePackId anyTypePack) : Substitution(TxnLog::empty(), arena) - , scope(NotNull{scope.get()}) + , scope(scope) , iceHandler(iceHandler) , anyType(anyType) , anyTypePack(anyTypePack) { } +Anyification::Anyification(TypeArena* arena, const ScopePtr& scope, InternalErrorReporter* iceHandler, TypeId anyType, TypePackId anyTypePack) + : Anyification(arena, NotNull{scope.get()}, iceHandler, anyType, anyTypePack) +{ +} + bool Anyification::isDirty(TypeId ty) { if (ty->persistent) @@ -93,4 +98,4 @@ bool Anyification::ignoreChildren(TypePackId ty) return ty->persistent; } -} +} // namespace Luau diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index 5c484899a..378a1cb7d 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -1215,8 +1215,8 @@ static bool autocompleteIfElseExpression( } } -static AutocompleteContext autocompleteExpression(const SourceModule& sourceModule, const Module& module, const TypeChecker& typeChecker, TypeArena* typeArena, - const std::vector& ancestry, Position position, AutocompleteEntryMap& result) +static AutocompleteContext autocompleteExpression(const SourceModule& sourceModule, const Module& module, const TypeChecker& typeChecker, + TypeArena* typeArena, const std::vector& ancestry, Position position, AutocompleteEntryMap& result) { LUAU_ASSERT(!ancestry.empty()); @@ -1422,8 +1422,8 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M PropIndexType indexType = indexName->op == ':' ? PropIndexType::Colon : PropIndexType::Point; if (!FFlag::LuauSelfCallAutocompleteFix3 && isString(ty)) - return { - autocompleteProps(*module, typeArena, typeChecker.globalScope->bindings[AstName{"string"}].typeId, indexType, ancestry), ancestry, AutocompleteContext::Property}; + return {autocompleteProps(*module, typeArena, typeChecker.globalScope->bindings[AstName{"string"}].typeId, indexType, ancestry), ancestry, + AutocompleteContext::Property}; else return {autocompleteProps(*module, typeArena, ty, indexType, ancestry), ancestry, AutocompleteContext::Property}; } @@ -1522,8 +1522,8 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M else if (AstStatIf* statIf = node->as(); statIf && !statIf->elseLocation.has_value()) { - return { - {{"else", AutocompleteEntry{AutocompleteEntryKind::Keyword}}, {"elseif", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Keyword}; + return {{{"else", AutocompleteEntry{AutocompleteEntryKind::Keyword}}, {"elseif", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, + ancestry, AutocompleteContext::Keyword}; } else if (AstStatIf* statIf = parent->as(); statIf && node->is()) { diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index 826179b39..e011eaa55 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -5,6 +5,7 @@ #include "Luau/Symbol.h" #include "Luau/Common.h" #include "Luau/ToString.h" +#include "Luau/ConstraintSolver.h" #include @@ -32,6 +33,8 @@ static std::optional> magicFunctionPack( static std::optional> magicFunctionRequire( TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate); +static bool dcrMagicFunctionRequire(NotNull solver, TypePackId result, const AstExprCall* expr); + TypeId makeUnion(TypeArena& arena, std::vector&& types) { return arena.addType(UnionTypeVar{std::move(types)}); @@ -105,6 +108,14 @@ void attachMagicFunction(TypeId ty, MagicFunction fn) LUAU_ASSERT(!"Got a non functional type"); } +void attachDcrMagicFunction(TypeId ty, DcrMagicFunction fn) +{ + if (auto ftv = getMutable(ty)) + ftv->dcrMagicFunction = fn; + else + LUAU_ASSERT(!"Got a non functional type"); +} + Property makeProperty(TypeId ty, std::optional documentationSymbol) { return { @@ -263,6 +274,7 @@ void registerBuiltinTypes(TypeChecker& typeChecker) } attachMagicFunction(getGlobalBinding(typeChecker, "require"), magicFunctionRequire); + attachDcrMagicFunction(getGlobalBinding(typeChecker, "require"), dcrMagicFunctionRequire); } static std::optional> magicFunctionSelect( @@ -509,4 +521,49 @@ static std::optional> magicFunctionRequire( return std::nullopt; } +static bool checkRequirePathDcr(NotNull solver, AstExpr* expr) +{ + // require(foo.parent.bar) will technically work, but it depends on legacy goop that + // Luau does not and could not support without a bunch of work. It's deprecated anyway, so + // we'll warn here if we see it. + bool good = true; + AstExprIndexName* indexExpr = expr->as(); + + while (indexExpr) + { + if (indexExpr->index == "parent") + { + solver->reportError(DeprecatedApiUsed{"parent", "Parent"}, indexExpr->indexLocation); + good = false; + } + + indexExpr = indexExpr->expr->as(); + } + + return good; +} + +static bool dcrMagicFunctionRequire(NotNull solver, TypePackId result, const AstExprCall* expr) +{ + if (expr->args.size != 1) + { + solver->reportError(GenericError{"require takes 1 argument"}, expr->location); + return false; + } + + if (!checkRequirePathDcr(solver, expr->args.data[0])) + return false; + + if (auto moduleInfo = solver->moduleResolver->resolveModuleInfo(solver->currentModuleName, *expr)) + { + TypeId moduleType = solver->resolveModule(*moduleInfo, expr->location); + TypePackId moduleResult = solver->arena->addTypePack({moduleType}); + asMutable(result)->ty.emplace(moduleResult); + + return true; + } + + return false; +} + } // namespace Luau diff --git a/Analysis/src/Clone.cpp b/Analysis/src/Clone.cpp index 2e04b527f..7048d201b 100644 --- a/Analysis/src/Clone.cpp +++ b/Analysis/src/Clone.cpp @@ -102,6 +102,11 @@ struct TypePackCloner defaultClone(t); } + void operator()(const BlockedTypePack& t) + { + defaultClone(t); + } + // While we are a-cloning, we can flatten out bound TypeVars and make things a bit tighter. // We just need to be sure that we rewrite pointers both to the binder and the bindee to the same pointer. void operator()(const Unifiable::Bound& t) @@ -170,7 +175,7 @@ void TypeCloner::operator()(const BlockedTypeVar& t) void TypeCloner::operator()(const PendingExpansionTypeVar& t) { - TypeId res = dest.addType(PendingExpansionTypeVar{t.fn, t.typeArguments, t.packArguments}); + TypeId res = dest.addType(PendingExpansionTypeVar{t.prefix, t.name, t.typeArguments, t.packArguments}); PendingExpansionTypeVar* petv = getMutable(res); LUAU_ASSERT(petv); @@ -184,32 +189,6 @@ void TypeCloner::operator()(const PendingExpansionTypeVar& t) for (TypePackId arg : t.packArguments) packArguments.push_back(clone(arg, dest, cloneState)); - TypeFun fn; - fn.type = clone(t.fn.type, dest, cloneState); - - for (const GenericTypeDefinition& param : t.fn.typeParams) - { - TypeId ty = clone(param.ty, dest, cloneState); - std::optional defaultValue = param.defaultValue; - - if (defaultValue) - defaultValue = clone(*defaultValue, dest, cloneState); - - fn.typeParams.push_back(GenericTypeDefinition{ty, defaultValue}); - } - - for (const GenericTypePackDefinition& param : t.fn.typePackParams) - { - TypePackId tp = clone(param.tp, dest, cloneState); - std::optional defaultValue = param.defaultValue; - - if (defaultValue) - defaultValue = clone(*defaultValue, dest, cloneState); - - fn.typePackParams.push_back(GenericTypePackDefinition{tp, defaultValue}); - } - - petv->fn = std::move(fn); petv->typeArguments = std::move(typeArguments); petv->packArguments = std::move(packArguments); } @@ -461,6 +440,7 @@ TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool alwaysCl clone.generics = ftv->generics; clone.genericPacks = ftv->genericPacks; clone.magicFunction = ftv->magicFunction; + clone.dcrMagicFunction = ftv->dcrMagicFunction; clone.tags = ftv->tags; clone.argNames = ftv->argNames; result = dest.addType(std::move(clone)); @@ -502,7 +482,7 @@ TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool alwaysCl } else if (const PendingExpansionTypeVar* petv = get(ty)) { - PendingExpansionTypeVar clone{petv->fn, petv->typeArguments, petv->packArguments}; + PendingExpansionTypeVar clone{petv->prefix, petv->name, petv->typeArguments, petv->packArguments}; result = dest.addType(std::move(clone)); } else if (const ClassTypeVar* ctv = get(ty); FFlag::LuauClonePublicInterfaceLess && ctv && alwaysClone) diff --git a/Analysis/src/Constraint.cpp b/Analysis/src/Constraint.cpp index d272c0279..3a6417dc1 100644 --- a/Analysis/src/Constraint.cpp +++ b/Analysis/src/Constraint.cpp @@ -5,9 +5,10 @@ namespace Luau { -Constraint::Constraint(ConstraintV&& c, NotNull scope) - : c(std::move(c)) - , scope(scope) +Constraint::Constraint(NotNull scope, const Location& location, ConstraintV&& c) + : scope(scope) + , location(location) + , c(std::move(c)) { } diff --git a/Analysis/src/ConstraintGraphBuilder.cpp b/Analysis/src/ConstraintGraphBuilder.cpp index 8f9947405..9fabc528d 100644 --- a/Analysis/src/ConstraintGraphBuilder.cpp +++ b/Analysis/src/ConstraintGraphBuilder.cpp @@ -4,6 +4,7 @@ #include "Luau/Ast.h" #include "Luau/Common.h" #include "Luau/Constraint.h" +#include "Luau/ModuleResolver.h" #include "Luau/RecursionCounter.h" #include "Luau/ToString.h" @@ -16,13 +17,31 @@ namespace Luau const AstStat* getFallthrough(const AstStat* node); // TypeInfer.cpp -ConstraintGraphBuilder::ConstraintGraphBuilder( - const ModuleName& moduleName, ModulePtr module, TypeArena* arena, NotNull ice, const ScopePtr& globalScope) +static std::optional matchRequire(const AstExprCall& call) +{ + const char* require = "require"; + + if (call.args.size != 1) + return std::nullopt; + + const AstExprGlobal* funcAsGlobal = call.func->as(); + if (!funcAsGlobal || funcAsGlobal->name != require) + return std::nullopt; + + if (call.args.size != 1) + return std::nullopt; + + return call.args.data[0]; +} + +ConstraintGraphBuilder::ConstraintGraphBuilder(const ModuleName& moduleName, ModulePtr module, TypeArena* arena, + NotNull moduleResolver, NotNull ice, const ScopePtr& globalScope) : moduleName(moduleName) , module(module) , singletonTypes(getSingletonTypes()) , arena(arena) , rootScope(nullptr) + , moduleResolver(moduleResolver) , ice(ice) , globalScope(globalScope) { @@ -54,9 +73,9 @@ ScopePtr ConstraintGraphBuilder::childScope(AstNode* node, const ScopePtr& paren return scope; } -void ConstraintGraphBuilder::addConstraint(const ScopePtr& scope, ConstraintV cv) +void ConstraintGraphBuilder::addConstraint(const ScopePtr& scope, const Location& location, ConstraintV cv) { - scope->constraints.emplace_back(new Constraint{std::move(cv), NotNull{scope.get()}}); + scope->constraints.emplace_back(new Constraint{NotNull{scope.get()}, location, std::move(cv)}); } void ConstraintGraphBuilder::addConstraint(const ScopePtr& scope, std::unique_ptr c) @@ -77,13 +96,6 @@ void ConstraintGraphBuilder::visit(AstStatBlock* block) prepopulateGlobalScope(scope, block); - // TODO: We should share the global scope. - rootScope->privateTypeBindings["nil"] = TypeFun{singletonTypes.nilType}; - rootScope->privateTypeBindings["number"] = TypeFun{singletonTypes.numberType}; - rootScope->privateTypeBindings["string"] = TypeFun{singletonTypes.stringType}; - rootScope->privateTypeBindings["boolean"] = TypeFun{singletonTypes.booleanType}; - rootScope->privateTypeBindings["thread"] = TypeFun{singletonTypes.threadType}; - visitBlockWithoutChildScope(scope, block); } @@ -158,6 +170,8 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStat* stat) visit(scope, s); else if (auto s = stat->as()) visit(scope, s); + else if (auto s = stat->as()) + visit(scope, s); else if (auto s = stat->as()) visit(scope, s); else if (auto s = stat->as()) @@ -201,7 +215,7 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocal* local) { location = local->annotation->location; TypeId annotation = resolveType(scope, local->annotation, /* topLevel */ true); - addConstraint(scope, SubtypeConstraint{ty, annotation}); + addConstraint(scope, location, SubtypeConstraint{ty, annotation}); } varTypes.push_back(ty); @@ -225,14 +239,38 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocal* local) { std::vector tailValues{varTypes.begin() + i, varTypes.end()}; TypePackId tailPack = arena->addTypePack(std::move(tailValues)); - addConstraint(scope, PackSubtypeConstraint{exprPack, tailPack}); + addConstraint(scope, local->location, PackSubtypeConstraint{exprPack, tailPack}); } } else { TypeId exprType = check(scope, value); if (i < varTypes.size()) - addConstraint(scope, SubtypeConstraint{varTypes[i], exprType}); + addConstraint(scope, local->location, SubtypeConstraint{varTypes[i], exprType}); + } + } + + if (local->values.size > 0) + { + // To correctly handle 'require', we need to import the exported type bindings into the variable 'namespace'. + for (size_t i = 0; i < local->values.size && i < local->vars.size; ++i) + { + const AstExprCall* call = local->values.data[i]->as(); + if (!call) + continue; + + if (auto maybeRequire = matchRequire(*call)) + { + AstExpr* require = *maybeRequire; + + if (auto moduleInfo = moduleResolver->resolveModuleInfo(moduleName, *require)) + { + const Name name{local->vars.data[i]->name.value}; + + if (ModulePtr module = moduleResolver->getModule(moduleInfo->name)) + scope->importedTypeBindings[name] = module->getModuleScope()->exportedTypeBindings; + } + } } } } @@ -244,7 +282,7 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFor* for_) return; TypeId t = check(scope, expr); - addConstraint(scope, SubtypeConstraint{t, singletonTypes.numberType}); + addConstraint(scope, expr->location, SubtypeConstraint{t, singletonTypes.numberType}); }; checkNumber(for_->from); @@ -257,6 +295,29 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFor* for_) visit(forScope, for_->body); } +void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatForIn* forIn) +{ + ScopePtr loopScope = childScope(forIn, scope); + + TypePackId iterator = checkPack(scope, forIn->values); + + std::vector variableTypes; + variableTypes.reserve(forIn->vars.size); + for (AstLocal* var : forIn->vars) + { + TypeId ty = freshType(loopScope); + loopScope->bindings[var] = Binding{ty, var->location}; + variableTypes.push_back(ty); + } + + // It is always ok to provide too few variables, so we give this pack a free tail. + TypePackId variablePack = arena->addTypePack(std::move(variableTypes), arena->addTypePack(FreeTypePack{loopScope.get()})); + + addConstraint(loopScope, getLocation(forIn->values), IterableConstraint{iterator, variablePack}); + + visit(loopScope, forIn->body); +} + void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatWhile* while_) { check(scope, while_->condition); @@ -284,6 +345,9 @@ void addConstraints(Constraint* constraint, NotNull scope) for (const auto& c : scope->constraints) constraint->dependencies.push_back(NotNull{c.get()}); + for (const auto& c : scope->unqueuedConstraints) + constraint->dependencies.push_back(NotNull{c.get()}); + for (NotNull childScope : scope->children) addConstraints(constraint, childScope); } @@ -308,7 +372,8 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocalFunction* checkFunctionBody(sig.bodyScope, function->func); NotNull constraintScope{sig.signatureScope ? sig.signatureScope.get() : sig.bodyScope.get()}; - std::unique_ptr c = std::make_unique(GeneralizationConstraint{functionType, sig.signature}, constraintScope); + std::unique_ptr c = + std::make_unique(constraintScope, function->name->location, GeneralizationConstraint{functionType, sig.signature}); addConstraints(c.get(), NotNull(sig.bodyScope.get())); addConstraint(scope, std::move(c)); @@ -366,7 +431,7 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFunction* funct prop.type = functionType; prop.location = function->name->location; - addConstraint(scope, SubtypeConstraint{containingTableType, prospectiveTableType}); + addConstraint(scope, indexName->location, SubtypeConstraint{containingTableType, prospectiveTableType}); } else if (AstExprError* err = function->name->as()) { @@ -378,7 +443,8 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFunction* funct checkFunctionBody(sig.bodyScope, function->func); NotNull constraintScope{sig.signatureScope ? sig.signatureScope.get() : sig.bodyScope.get()}; - std::unique_ptr c = std::make_unique(GeneralizationConstraint{functionType, sig.signature}, constraintScope); + std::unique_ptr c = + std::make_unique(constraintScope, function->name->location, GeneralizationConstraint{functionType, sig.signature}); addConstraints(c.get(), NotNull(sig.bodyScope.get())); addConstraint(scope, std::move(c)); @@ -387,7 +453,7 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFunction* funct void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatReturn* ret) { TypePackId exprTypes = checkPack(scope, ret->list); - addConstraint(scope, PackSubtypeConstraint{exprTypes, scope->returnType}); + addConstraint(scope, ret->location, PackSubtypeConstraint{exprTypes, scope->returnType}); } void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatBlock* block) @@ -399,10 +465,10 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatBlock* block) void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatAssign* assign) { - TypePackId varPackId = checkExprList(scope, assign->vars); + TypePackId varPackId = checkPack(scope, assign->vars); TypePackId valuePack = checkPack(scope, assign->values); - addConstraint(scope, PackSubtypeConstraint{valuePack, varPackId}); + addConstraint(scope, assign->location, PackSubtypeConstraint{valuePack, varPackId}); } void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatCompoundAssign* assign) @@ -435,8 +501,6 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatIf* ifStatement void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatTypeAlias* alias) { - // TODO: Exported type aliases - auto bindingIt = scope->privateTypeBindings.find(alias->name.value); ScopePtr* defnIt = astTypeAliasDefiningScopes.find(alias); // These will be undefined if the alias was a duplicate definition, in which @@ -449,6 +513,12 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatTypeAlias* alia ScopePtr resolvingScope = *defnIt; TypeId ty = resolveType(resolvingScope, alias->type, /* topLevel */ true); + if (alias->exported) + { + Name typeName(alias->name.value); + scope->exportedTypeBindings[typeName] = TypeFun{ty}; + } + LUAU_ASSERT(get(bindingIt->second.type)); // Rather than using a subtype constraint, we instead directly bind @@ -457,7 +527,7 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatTypeAlias* alia // bind the free alias type to an unrelated type, causing havoc. asMutable(bindingIt->second.type)->ty.emplace(ty); - addConstraint(scope, NameConstraint{ty, alias->name.value}); + addConstraint(scope, alias->location, NameConstraint{ty, alias->name.value}); } void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareGlobal* global) @@ -615,44 +685,22 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareFunction TypePackId ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstArray exprs) { - if (exprs.size == 0) - return arena->addTypePack({}); - - std::vector types; - TypePackId last = nullptr; - - for (size_t i = 0; i < exprs.size; ++i) - { - if (i < exprs.size - 1) - types.push_back(check(scope, exprs.data[i])); - else - last = checkPack(scope, exprs.data[i]); - } - - LUAU_ASSERT(last != nullptr); - - return arena->addTypePack(TypePack{std::move(types), last}); -} - -TypePackId ConstraintGraphBuilder::checkExprList(const ScopePtr& scope, const AstArray& exprs) -{ - TypePackId result = arena->addTypePack({}); - TypePack* resultPack = getMutable(result); - LUAU_ASSERT(resultPack); + std::vector head; + std::optional tail; for (size_t i = 0; i < exprs.size; ++i) { AstExpr* expr = exprs.data[i]; if (i < exprs.size - 1) - resultPack->head.push_back(check(scope, expr)); + head.push_back(check(scope, expr)); else - resultPack->tail = checkPack(scope, expr); + tail = checkPack(scope, expr); } - if (resultPack->head.empty() && resultPack->tail) - return *resultPack->tail; + if (head.empty() && tail) + return *tail; else - return result; + return arena->addTypePack(TypePack{std::move(head), tail}); } TypePackId ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExpr* expr) @@ -683,13 +731,26 @@ TypePackId ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExpr* exp astOriginalCallTypes[call->func] = fnType; TypeId instantiatedType = arena->addType(BlockedTypeVar{}); - addConstraint(scope, InstantiationConstraint{instantiatedType, fnType}); - - TypePackId rets = freshTypePack(scope); + TypePackId rets = arena->addTypePack(BlockedTypePack{}); FunctionTypeVar ftv(arena->addTypePack(TypePack{args, {}}), rets); TypeId inferredFnType = arena->addType(ftv); - addConstraint(scope, SubtypeConstraint{inferredFnType, instantiatedType}); + scope->unqueuedConstraints.push_back( + std::make_unique(NotNull{scope.get()}, call->func->location, InstantiationConstraint{instantiatedType, fnType})); + NotNull ic(scope->unqueuedConstraints.back().get()); + + scope->unqueuedConstraints.push_back( + std::make_unique(NotNull{scope.get()}, call->func->location, SubtypeConstraint{inferredFnType, instantiatedType})); + NotNull sc(scope->unqueuedConstraints.back().get()); + + addConstraint(scope, call->func->location, + FunctionCallConstraint{ + {ic, sc}, + fnType, + rets, + call, + }); + result = rets; } else if (AstExprVarargs* varargs = expr->as()) @@ -805,7 +866,7 @@ TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIndexName* in TypeId expectedTableType = arena->addType(std::move(ttv)); - addConstraint(scope, SubtypeConstraint{obj, expectedTableType}); + addConstraint(scope, indexName->expr->location, SubtypeConstraint{obj, expectedTableType}); return result; } @@ -820,7 +881,7 @@ TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIndexExpr* in TableIndexer indexer{indexType, result}; TypeId tableType = arena->addType(TableTypeVar{TableTypeVar::Props{}, TableIndexer{indexType, result}, TypeLevel{}, TableState::Free}); - addConstraint(scope, SubtypeConstraint{obj, tableType}); + addConstraint(scope, indexExpr->expr->location, SubtypeConstraint{obj, tableType}); return result; } @@ -834,7 +895,7 @@ TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprUnary* unary) case AstExprUnary::Minus: { TypeId resultType = arena->addType(BlockedTypeVar{}); - addConstraint(scope, UnaryConstraint{AstExprUnary::Minus, operandType, resultType}); + addConstraint(scope, unary->location, UnaryConstraint{AstExprUnary::Minus, operandType, resultType}); return resultType; } default: @@ -853,19 +914,19 @@ TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprBinary* binar { case AstExprBinary::Or: { - addConstraint(scope, SubtypeConstraint{leftType, rightType}); + addConstraint(scope, binary->location, SubtypeConstraint{leftType, rightType}); return leftType; } case AstExprBinary::Add: { TypeId resultType = arena->addType(BlockedTypeVar{}); - addConstraint(scope, BinaryConstraint{AstExprBinary::Add, leftType, rightType, resultType}); + addConstraint(scope, binary->location, BinaryConstraint{AstExprBinary::Add, leftType, rightType, resultType}); return resultType; } case AstExprBinary::Sub: { TypeId resultType = arena->addType(BlockedTypeVar{}); - addConstraint(scope, BinaryConstraint{AstExprBinary::Sub, leftType, rightType, resultType}); + addConstraint(scope, binary->location, BinaryConstraint{AstExprBinary::Sub, leftType, rightType, resultType}); return resultType; } default: @@ -886,8 +947,8 @@ TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIfElse* ifEls if (ifElse->hasElse) { TypeId resultType = arena->addType(BlockedTypeVar{}); - addConstraint(scope, SubtypeConstraint{thenType, resultType}); - addConstraint(scope, SubtypeConstraint{elseType, resultType}); + addConstraint(scope, ifElse->trueExpr->location, SubtypeConstraint{thenType, resultType}); + addConstraint(scope, ifElse->falseExpr->location, SubtypeConstraint{elseType, resultType}); return resultType; } @@ -906,7 +967,7 @@ TypeId ConstraintGraphBuilder::checkExprTable(const ScopePtr& scope, AstExprTabl TableTypeVar* ttv = getMutable(ty); LUAU_ASSERT(ttv); - auto createIndexer = [this, scope, ttv](TypeId currentIndexType, TypeId currentResultType) { + auto createIndexer = [this, scope, ttv](const Location& location, TypeId currentIndexType, TypeId currentResultType) { if (!ttv->indexer) { TypeId indexType = this->freshType(scope); @@ -914,8 +975,8 @@ TypeId ConstraintGraphBuilder::checkExprTable(const ScopePtr& scope, AstExprTabl ttv->indexer = TableIndexer{indexType, resultType}; } - addConstraint(scope, SubtypeConstraint{ttv->indexer->indexType, currentIndexType}); - addConstraint(scope, SubtypeConstraint{ttv->indexer->indexResultType, currentResultType}); + addConstraint(scope, location, SubtypeConstraint{ttv->indexer->indexType, currentIndexType}); + addConstraint(scope, location, SubtypeConstraint{ttv->indexer->indexResultType, currentResultType}); }; for (const AstExprTable::Item& item : expr->items) @@ -937,13 +998,15 @@ TypeId ConstraintGraphBuilder::checkExprTable(const ScopePtr& scope, AstExprTabl } else { - createIndexer(keyTy, itemTy); + createIndexer(item.key->location, keyTy, itemTy); } } else { TypeId numberType = singletonTypes.numberType; - createIndexer(numberType, itemTy); + // FIXME? The location isn't quite right here. Not sure what is + // right. + createIndexer(item.value->location, numberType, itemTy); } } @@ -1008,7 +1071,7 @@ ConstraintGraphBuilder::FunctionSignature ConstraintGraphBuilder::checkFunctionS if (fn->returnAnnotation) { TypePackId annotatedRetType = resolveTypePack(signatureScope, *fn->returnAnnotation); - addConstraint(signatureScope, PackSubtypeConstraint{returnType, annotatedRetType}); + addConstraint(signatureScope, getLocation(*fn->returnAnnotation), PackSubtypeConstraint{returnType, annotatedRetType}); } std::vector argTypes; @@ -1022,7 +1085,7 @@ ConstraintGraphBuilder::FunctionSignature ConstraintGraphBuilder::checkFunctionS if (local->annotation) { TypeId argAnnotation = resolveType(signatureScope, local->annotation, /* topLevel */ true); - addConstraint(signatureScope, SubtypeConstraint{t, argAnnotation}); + addConstraint(signatureScope, local->annotation->location, SubtypeConstraint{t, argAnnotation}); } } @@ -1056,7 +1119,7 @@ void ConstraintGraphBuilder::checkFunctionBody(const ScopePtr& scope, AstExprFun if (nullptr != getFallthrough(fn->body)) { TypePackId empty = arena->addTypePack({}); // TODO we could have CSG retain one of these forever - addConstraint(scope, PackSubtypeConstraint{scope->returnType, empty}); + addConstraint(scope, fn->location, PackSubtypeConstraint{scope->returnType, empty}); } } @@ -1066,16 +1129,13 @@ TypeId ConstraintGraphBuilder::resolveType(const ScopePtr& scope, AstType* ty, b if (auto ref = ty->as()) { - // TODO: Support imported types w/ require tracing. - LUAU_ASSERT(!ref->prefix); - std::optional alias = scope->lookupType(ref->name.value); - if (alias.has_value()) + if (alias.has_value() || ref->prefix.has_value()) { // If the alias is not generic, we don't need to set up a blocked // type and an instantiation constraint. - if (alias->typeParams.empty() && alias->typePackParams.empty()) + if (alias.has_value() && alias->typeParams.empty() && alias->typePackParams.empty()) { result = alias->type; } @@ -1104,11 +1164,11 @@ TypeId ConstraintGraphBuilder::resolveType(const ScopePtr& scope, AstType* ty, b } } - result = arena->addType(PendingExpansionTypeVar{*alias, parameters, packParameters}); + result = arena->addType(PendingExpansionTypeVar{ref->prefix, ref->name, parameters, packParameters}); if (topLevel) { - addConstraint(scope, TypeAliasExpansionConstraint{ /* target */ result }); + addConstraint(scope, ty->location, TypeAliasExpansionConstraint{/* target */ result}); } } } @@ -1141,8 +1201,7 @@ TypeId ConstraintGraphBuilder::resolveType(const ScopePtr& scope, AstType* ty, b }; } - // TODO: Remove TypeLevel{} here, we don't need it. - result = arena->addType(TableTypeVar{props, indexer, TypeLevel{}, TableState::Sealed}); + result = arena->addType(TableTypeVar{props, indexer, scope->level, TableState::Sealed}); } else if (auto fn = ty->as()) { @@ -1363,7 +1422,7 @@ TypeId ConstraintGraphBuilder::flattenPack(const ScopePtr& scope, Location locat TypePack onePack{{typeResult}, freshTypePack(scope)}; TypePackId oneTypePack = arena->addTypePack(std::move(onePack)); - addConstraint(scope, PackSubtypeConstraint{tp, oneTypePack}); + addConstraint(scope, location, PackSubtypeConstraint{tp, oneTypePack}); return typeResult; } diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index b2b1d4725..1088d9824 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -1,9 +1,11 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Anyification.h" #include "Luau/ApplyTypeFunction.h" #include "Luau/ConstraintSolver.h" #include "Luau/Instantiation.h" #include "Luau/Location.h" +#include "Luau/ModuleResolver.h" #include "Luau/Quantify.h" #include "Luau/ToString.h" #include "Luau/Unifier.h" @@ -240,11 +242,17 @@ void dump(ConstraintSolver* cs, ToStringOptions& opts) } } -ConstraintSolver::ConstraintSolver(TypeArena* arena, NotNull rootScope) +ConstraintSolver::ConstraintSolver(TypeArena* arena, NotNull rootScope, ModuleName moduleName, NotNull moduleResolver, + std::vector requireCycles) : arena(arena) , constraints(collectConstraints(rootScope)) , rootScope(rootScope) + , currentModuleName(std::move(moduleName)) + , moduleResolver(moduleResolver) + , requireCycles(requireCycles) { + opts.exhaustive = true; + for (NotNull c : constraints) { unsolvedConstraints.push_back(c); @@ -261,9 +269,6 @@ void ConstraintSolver::run() if (done()) return; - ToStringOptions opts; - opts.exhaustive = true; - if (FFlag::DebugLuauLogSolver) { printf("Starting solver\n"); @@ -371,10 +376,14 @@ bool ConstraintSolver::tryDispatch(NotNull constraint, bool fo success = tryDispatch(*uc, constraint, force); else if (auto bc = get(*constraint)) success = tryDispatch(*bc, constraint, force); + else if (auto ic = get(*constraint)) + success = tryDispatch(*ic, constraint, force); else if (auto nc = get(*constraint)) success = tryDispatch(*nc, constraint); else if (auto taec = get(*constraint)) success = tryDispatch(*taec, constraint); + else if (auto fcc = get(*constraint)) + success = tryDispatch(*fcc, constraint); else LUAU_ASSERT(0); @@ -400,6 +409,11 @@ bool ConstraintSolver::tryDispatch(const SubtypeConstraint& c, NotNull constraint, bool force) { + if (isBlocked(c.subPack)) + return block(c.subPack, constraint); + else if (isBlocked(c.superPack)) + return block(c.superPack, constraint); + unify(c.subPack, c.superPack, constraint->scope); return true; } @@ -512,6 +526,82 @@ bool ConstraintSolver::tryDispatch(const BinaryConstraint& c, NotNull constraint, bool force) +{ + /* + * for .. in loops can play out in a bunch of different ways depending on + * the shape of iteratee. + * + * iteratee might be: + * * (nextFn) + * * (nextFn, table) + * * (nextFn, table, firstIndex) + * * table with a metatable and __index + * * table with a metatable and __call but no __index (if the metatable has + * both, __index takes precedence) + * * table with an indexer but no __index or __call (or no metatable) + * + * To dispatch this constraint, we need first to know enough about iteratee + * to figure out which of the above shapes we are actually working with. + * + * If `force` is true and we still do not know, we must flag a warning. Type + * families are the fix for this. + * + * Since we need to know all of this stuff about the types of the iteratee, + * we have no choice but for ConstraintSolver to also be the thing that + * applies constraints to the types of the iterators. + */ + + auto block_ = [&](auto&& t) { + if (force) + { + // If we haven't figured out the type of the iteratee by now, + // there's nothing we can do. + return true; + } + + block(t, constraint); + return false; + }; + + auto [iteratorTypes, iteratorTail] = flatten(c.iterator); + if (iteratorTail) + return block_(*iteratorTail); + + if (0 == iteratorTypes.size()) + { + Anyification anyify{ + arena, constraint->scope, &iceReporter, getSingletonTypes().errorRecoveryType(), getSingletonTypes().errorRecoveryTypePack()}; + std::optional anyified = anyify.substitute(c.variables); + LUAU_ASSERT(anyified); + unify(*anyified, c.variables, constraint->scope); + + return true; + } + + TypeId nextTy = follow(iteratorTypes[0]); + if (get(nextTy)) + return block_(nextTy); + + if (get(nextTy)) + { + TypeId tableTy = getSingletonTypes().nilType; + if (iteratorTypes.size() >= 2) + tableTy = iteratorTypes[1]; + + TypeId firstIndexTy = getSingletonTypes().nilType; + if (iteratorTypes.size() >= 3) + firstIndexTy = iteratorTypes[2]; + + return tryDispatchIterableFunction(nextTy, tableTy, firstIndexTy, c, constraint, force); + } + + else + return tryDispatchIterableTable(iteratorTypes[0], c, constraint, force); + + return true; +} + bool ConstraintSolver::tryDispatch(const NameConstraint& c, NotNull constraint) { if (isBlocked(c.namedType)) @@ -519,7 +609,7 @@ bool ConstraintSolver::tryDispatch(const NameConstraint& c, NotNullpersistent) + if (target->persistent || target->owningArena != arena) return true; if (TableTypeVar* ttv = getMutable(target)) @@ -536,19 +626,27 @@ struct InfiniteTypeFinder : TypeVarOnceVisitor { ConstraintSolver* solver; const InstantiationSignature& signature; + NotNull scope; bool foundInfiniteType = false; - explicit InfiniteTypeFinder(ConstraintSolver* solver, const InstantiationSignature& signature) + explicit InfiniteTypeFinder(ConstraintSolver* solver, const InstantiationSignature& signature, NotNull scope) : solver(solver) , signature(signature) + , scope(scope) { } bool visit(TypeId ty, const PendingExpansionTypeVar& petv) override { - auto [typeArguments, packArguments] = saturateArguments(petv.fn, petv.typeArguments, petv.packArguments, solver->arena); + std::optional tf = + (petv.prefix) ? scope->lookupImportedType(petv.prefix->value, petv.name.value) : scope->lookupType(petv.name.value); + + if (!tf.has_value()) + return true; + + auto [typeArguments, packArguments] = saturateArguments(*tf, petv.typeArguments, petv.packArguments, solver->arena); - if (follow(petv.fn.type) == follow(signature.fn.type) && (signature.arguments != typeArguments || signature.packArguments != packArguments)) + if (follow(tf->type) == follow(signature.fn.type) && (signature.arguments != typeArguments || signature.packArguments != packArguments)) { foundInfiniteType = true; return false; @@ -563,17 +661,19 @@ struct InstantiationQueuer : TypeVarOnceVisitor ConstraintSolver* solver; const InstantiationSignature& signature; NotNull scope; + Location location; - explicit InstantiationQueuer(ConstraintSolver* solver, const InstantiationSignature& signature, NotNull scope) + explicit InstantiationQueuer(NotNull scope, const Location& location, ConstraintSolver* solver, const InstantiationSignature& signature) : solver(solver) , signature(signature) , scope(scope) + , location(location) { } bool visit(TypeId ty, const PendingExpansionTypeVar& petv) override { - solver->pushConstraint(TypeAliasExpansionConstraint{ty}, scope); + solver->pushConstraint(scope, location, TypeAliasExpansionConstraint{ty}); return false; } }; @@ -592,23 +692,32 @@ bool ConstraintSolver::tryDispatch(const TypeAliasExpansionConstraint& c, NotNul unblock(c.target); }; + std::optional tf = (petv->prefix) ? constraint->scope->lookupImportedType(petv->prefix->value, petv->name.value) + : constraint->scope->lookupType(petv->name.value); + + if (!tf.has_value()) + { + reportError(UnknownSymbol{petv->name.value, UnknownSymbol::Context::Type}, constraint->location); + bindResult(getSingletonTypes().errorRecoveryType()); + return true; + } + // If there are no parameters to the type function we can just use the type // directly. - if (petv->fn.typeParams.empty() && petv->fn.typePackParams.empty()) + if (tf->typeParams.empty() && tf->typePackParams.empty()) { - bindResult(petv->fn.type); + bindResult(tf->type); return true; } - auto [typeArguments, packArguments] = saturateArguments(petv->fn, petv->typeArguments, petv->packArguments, arena); + auto [typeArguments, packArguments] = saturateArguments(*tf, petv->typeArguments, petv->packArguments, arena); - bool sameTypes = - std::equal(typeArguments.begin(), typeArguments.end(), petv->fn.typeParams.begin(), petv->fn.typeParams.end(), [](auto&& itp, auto&& p) { - return itp == p.ty; - }); + bool sameTypes = std::equal(typeArguments.begin(), typeArguments.end(), tf->typeParams.begin(), tf->typeParams.end(), [](auto&& itp, auto&& p) { + return itp == p.ty; + }); - bool samePacks = std::equal( - packArguments.begin(), packArguments.end(), petv->fn.typePackParams.begin(), petv->fn.typePackParams.end(), [](auto&& itp, auto&& p) { + bool samePacks = + std::equal(packArguments.begin(), packArguments.end(), tf->typePackParams.begin(), tf->typePackParams.end(), [](auto&& itp, auto&& p) { return itp == p.tp; }); @@ -617,12 +726,12 @@ bool ConstraintSolver::tryDispatch(const TypeAliasExpansionConstraint& c, NotNul // to the TypeFun's type. if (sameTypes && samePacks) { - bindResult(petv->fn.type); + bindResult(tf->type); return true; } InstantiationSignature signature{ - petv->fn, + *tf, typeArguments, packArguments, }; @@ -642,8 +751,8 @@ bool ConstraintSolver::tryDispatch(const TypeAliasExpansionConstraint& c, NotNul // https://github.com/Roblox/luau/pull/68 for the RFC responsible for this. // This is a little nicer than using a recursion limit because we can catch // the infinite expansion before actually trying to expand it. - InfiniteTypeFinder itf{this, signature}; - itf.traverse(petv->fn.type); + InfiniteTypeFinder itf{this, signature, constraint->scope}; + itf.traverse(tf->type); if (itf.foundInfiniteType) { @@ -655,15 +764,15 @@ bool ConstraintSolver::tryDispatch(const TypeAliasExpansionConstraint& c, NotNul ApplyTypeFunction applyTypeFunction{arena}; for (size_t i = 0; i < typeArguments.size(); ++i) { - applyTypeFunction.typeArguments[petv->fn.typeParams[i].ty] = typeArguments[i]; + applyTypeFunction.typeArguments[tf->typeParams[i].ty] = typeArguments[i]; } for (size_t i = 0; i < packArguments.size(); ++i) { - applyTypeFunction.typePackArguments[petv->fn.typePackParams[i].tp] = packArguments[i]; + applyTypeFunction.typePackArguments[tf->typePackParams[i].tp] = packArguments[i]; } - std::optional maybeInstantiated = applyTypeFunction.substitute(petv->fn.type); + std::optional maybeInstantiated = applyTypeFunction.substitute(tf->type); // Note that ApplyTypeFunction::encounteredForwardedType is never set in // DCR, because we do not use free types for forward-declared generic // aliases. @@ -683,7 +792,7 @@ bool ConstraintSolver::tryDispatch(const TypeAliasExpansionConstraint& c, NotNul // Type function application will happily give us the exact same type if // there are e.g. generic saturatedTypeArguments that go unused. - bool needsClone = follow(petv->fn.type) == target; + bool needsClone = follow(tf->type) == target; // Only tables have the properties we're trying to set. TableTypeVar* ttv = getMutableTableType(target); @@ -722,7 +831,7 @@ bool ConstraintSolver::tryDispatch(const TypeAliasExpansionConstraint& c, NotNul // The application is not recursive, so we need to queue up application of // any child type function instantiations within the result in order for it // to be complete. - InstantiationQueuer queuer{this, signature, constraint->scope}; + InstantiationQueuer queuer{constraint->scope, constraint->location, this, signature}; queuer.traverse(target); instantiatedAliases[signature] = target; @@ -730,6 +839,152 @@ bool ConstraintSolver::tryDispatch(const TypeAliasExpansionConstraint& c, NotNul return true; } +bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNull constraint) +{ + TypeId fn = follow(c.fn); + TypePackId result = follow(c.result); + + if (isBlocked(c.fn)) + { + return block(c.fn, constraint); + } + + const FunctionTypeVar* ftv = get(fn); + bool usedMagic = false; + + if (ftv && ftv->dcrMagicFunction != nullptr) + { + usedMagic = ftv->dcrMagicFunction(NotNull(this), result, c.astFragment); + } + + if (!usedMagic) + { + for (const auto& inner : c.innerConstraints) + { + unsolvedConstraints.push_back(inner); + } + + asMutable(c.result)->ty.emplace(constraint->scope); + } + + unblock(c.result); + + return true; +} + +bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const IterableConstraint& c, NotNull constraint, bool force) +{ + auto block_ = [&](auto&& t) { + if (force) + { + // TODO: I believe it is the case that, if we are asked to force + // this constraint, then we can do nothing but fail. I'd like to + // find a code sample that gets here. + LUAU_ASSERT(0); + } + else + block(t, constraint); + return false; + }; + + // We may have to block here if we don't know what the iteratee type is, + // if it's a free table, if we don't know it has a metatable, and so on. + iteratorTy = follow(iteratorTy); + if (get(iteratorTy)) + return block_(iteratorTy); + + auto anyify = [&](auto ty) { + Anyification anyify{arena, constraint->scope, &iceReporter, getSingletonTypes().anyType, getSingletonTypes().anyTypePack}; + std::optional anyified = anyify.substitute(ty); + if (!anyified) + reportError(CodeTooComplex{}, constraint->location); + else + unify(*anyified, ty, constraint->scope); + }; + + auto errorify = [&](auto ty) { + Anyification anyify{ + arena, constraint->scope, &iceReporter, getSingletonTypes().errorRecoveryType(), getSingletonTypes().errorRecoveryTypePack()}; + std::optional errorified = anyify.substitute(ty); + if (!errorified) + reportError(CodeTooComplex{}, constraint->location); + else + unify(*errorified, ty, constraint->scope); + }; + + if (get(iteratorTy)) + { + anyify(c.variables); + return true; + } + + if (get(iteratorTy)) + { + errorify(c.variables); + return true; + } + + // Irksome: I don't think we have any way to guarantee that this table + // type never has a metatable. + + if (auto iteratorTable = get(iteratorTy)) + { + if (iteratorTable->state == TableState::Free) + return block_(iteratorTy); + + if (iteratorTable->indexer) + { + TypePackId expectedVariablePack = arena->addTypePack({iteratorTable->indexer->indexType, iteratorTable->indexer->indexResultType}); + unify(c.variables, expectedVariablePack, constraint->scope); + } + else + errorify(c.variables); + } + else if (auto iteratorMetatable = get(iteratorTy)) + { + TypeId metaTy = follow(iteratorMetatable->metatable); + if (get(metaTy)) + return block_(metaTy); + + LUAU_ASSERT(0); + } + else + errorify(c.variables); + + return true; +} + +bool ConstraintSolver::tryDispatchIterableFunction( + TypeId nextTy, TypeId tableTy, TypeId firstIndexTy, const IterableConstraint& c, NotNull constraint, bool force) +{ + // We need to know whether or not this type is nil or not. + // If we don't know, block and reschedule ourselves. + firstIndexTy = follow(firstIndexTy); + if (get(firstIndexTy)) + { + if (force) + LUAU_ASSERT(0); + else + block(firstIndexTy, constraint); + return false; + } + + const TypeId firstIndex = isNil(firstIndexTy) ? arena->freshType(constraint->scope) // FIXME: Surely this should be a union (free | nil) + : firstIndexTy; + + // nextTy : (tableTy, indexTy?) -> (indexTy, valueTailTy...) + const TypePackId nextArgPack = arena->addTypePack({tableTy, arena->addType(UnionTypeVar{{firstIndex, getSingletonTypes().nilType}})}); + const TypePackId valueTailTy = arena->addTypePack(FreeTypePack{constraint->scope}); + const TypePackId nextRetPack = arena->addTypePack(TypePack{{firstIndex}, valueTailTy}); + + const TypeId expectedNextTy = arena->addType(FunctionTypeVar{nextArgPack, nextRetPack}); + unify(nextTy, expectedNextTy, constraint->scope); + + pushConstraint(constraint->scope, constraint->location, PackSubtypeConstraint{c.variables, nextRetPack}); + + return true; +} + void ConstraintSolver::block_(BlockedConstraintId target, NotNull constraint) { blocked[target].push_back(constraint); @@ -741,14 +996,14 @@ void ConstraintSolver::block_(BlockedConstraintId target, NotNull target, NotNull constraint) { if (FFlag::DebugLuauLogSolver) - printf("block Constraint %s on\t%s\n", toString(*target).c_str(), toString(*constraint).c_str()); + printf("block Constraint %s on\t%s\n", toString(*target, opts).c_str(), toString(*constraint, opts).c_str()); block_(target, constraint); } bool ConstraintSolver::block(TypeId target, NotNull constraint) { if (FFlag::DebugLuauLogSolver) - printf("block TypeId %s on\t%s\n", toString(target).c_str(), toString(*constraint).c_str()); + printf("block TypeId %s on\t%s\n", toString(target, opts).c_str(), toString(*constraint, opts).c_str()); block_(target, constraint); return false; } @@ -756,7 +1011,7 @@ bool ConstraintSolver::block(TypeId target, NotNull constraint bool ConstraintSolver::block(TypePackId target, NotNull constraint) { if (FFlag::DebugLuauLogSolver) - printf("block TypeId %s on\t%s\n", toString(target).c_str(), toString(*constraint).c_str()); + printf("block TypeId %s on\t%s\n", toString(target, opts).c_str(), toString(*constraint, opts).c_str()); block_(target, constraint); return false; } @@ -772,7 +1027,7 @@ void ConstraintSolver::unblock_(BlockedConstraintId progressed) { auto& count = blockedConstraints[unblockedConstraint]; if (FFlag::DebugLuauLogSolver) - printf("Unblocking count=%d\t%s\n", int(count), toString(*unblockedConstraint).c_str()); + printf("Unblocking count=%d\t%s\n", int(count), toString(*unblockedConstraint, opts).c_str()); // This assertion being hit indicates that `blocked` and // `blockedConstraints` desynchronized at some point. This is problematic @@ -817,6 +1072,11 @@ bool ConstraintSolver::isBlocked(TypeId ty) return nullptr != get(follow(ty)) || nullptr != get(follow(ty)); } +bool ConstraintSolver::isBlocked(TypePackId tp) +{ + return nullptr != get(follow(tp)); +} + bool ConstraintSolver::isBlocked(NotNull constraint) { auto blockedIt = blockedConstraints.find(constraint); @@ -830,6 +1090,13 @@ void ConstraintSolver::unify(TypeId subType, TypeId superType, NotNull sc u.tryUnify(subType, superType); + if (!u.errors.empty()) + { + TypeId errorType = getSingletonTypes().errorRecoveryType(); + u.tryUnify(subType, errorType); + u.tryUnify(superType, errorType); + } + const auto [changedTypes, changedPacks] = u.log.getChanges(); u.log.commit(); @@ -853,22 +1120,69 @@ void ConstraintSolver::unify(TypePackId subPack, TypePackId superPack, NotNull scope) +void ConstraintSolver::pushConstraint(NotNull scope, const Location& location, ConstraintV cv) { - std::unique_ptr c = std::make_unique(std::move(cv), scope); + std::unique_ptr c = std::make_unique(scope, location, std::move(cv)); NotNull borrow = NotNull(c.get()); solverConstraints.push_back(std::move(c)); unsolvedConstraints.push_back(borrow); } +TypeId ConstraintSolver::resolveModule(const ModuleInfo& info, const Location& location) +{ + if (info.name.empty()) + { + reportError(UnknownRequire{}, location); + return getSingletonTypes().errorRecoveryType(); + } + + std::string humanReadableName = moduleResolver->getHumanReadableModuleName(info.name); + + for (const auto& [location, path] : requireCycles) + { + if (!path.empty() && path.front() == humanReadableName) + return getSingletonTypes().anyType; + } + + ModulePtr module = moduleResolver->getModule(info.name); + if (!module) + { + if (!moduleResolver->moduleExists(info.name) && !info.optional) + reportError(UnknownRequire{humanReadableName}, location); + + return getSingletonTypes().errorRecoveryType(); + } + + if (module->type != SourceCode::Type::Module) + { + reportError(IllegalRequire{humanReadableName, "Module is not a ModuleScript. It cannot be required."}, location); + return getSingletonTypes().errorRecoveryType(); + } + + TypePackId modulePack = module->getModuleScope()->returnType; + if (get(modulePack)) + return getSingletonTypes().errorRecoveryType(); + + std::optional moduleType = first(modulePack); + if (!moduleType) + { + reportError(IllegalRequire{humanReadableName, "Module does not return exactly 1 value. It cannot be required."}, location); + return getSingletonTypes().errorRecoveryType(); + } + + return *moduleType; +} + void ConstraintSolver::reportError(TypeErrorData&& data, const Location& location) { errors.emplace_back(location, std::move(data)); + errors.back().moduleName = currentModuleName; } void ConstraintSolver::reportError(TypeError e) { errors.emplace_back(std::move(e)); + errors.back().moduleName = currentModuleName; } } // namespace Luau diff --git a/Analysis/src/ConstraintSolverLogger.cpp b/Analysis/src/ConstraintSolverLogger.cpp index 097ceeecc..5ba405216 100644 --- a/Analysis/src/ConstraintSolverLogger.cpp +++ b/Analysis/src/ConstraintSolverLogger.cpp @@ -3,6 +3,7 @@ #include "Luau/ConstraintSolverLogger.h" #include "Luau/JsonEmitter.h" +#include "Luau/ToString.h" LUAU_FASTFLAG(LuauFixNameMaps); diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index c8c5d4b59..d8839f2f3 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -14,6 +14,7 @@ #include "Luau/TypeChecker2.h" #include "Luau/TypeInfer.h" #include "Luau/Variant.h" +#include "Luau/BuiltinDefinitions.h" #include #include @@ -99,7 +100,7 @@ LoadDefinitionFileResult Frontend::loadDefinitionFile(std::string_view source, c module.root = parseResult.root; module.mode = Mode::Definition; - ModulePtr checkedModule = check(module, Mode::Definition, globalScope); + ModulePtr checkedModule = check(module, Mode::Definition, globalScope, {}); if (checkedModule->errors.size() > 0) return LoadDefinitionFileResult{false, parseResult, checkedModule}; @@ -526,7 +527,7 @@ CheckResult Frontend::check(const ModuleName& name, std::optional requireCycles) { ModulePtr result = std::make_shared(); - ConstraintGraphBuilder cgb{sourceModule.name, result, &result->internalTypes, NotNull(&iceHandler), getGlobalScope()}; + ConstraintGraphBuilder cgb{sourceModule.name, result, &result->internalTypes, NotNull(&moduleResolver), NotNull(&iceHandler), getGlobalScope()}; cgb.visit(sourceModule.root); result->errors = std::move(cgb.errors); - ConstraintSolver cs{&result->internalTypes, NotNull(cgb.rootScope)}; + ConstraintSolver cs{&result->internalTypes, NotNull(cgb.rootScope), sourceModule.name, NotNull(&moduleResolver), requireCycles}; cs.run(); for (TypeError& e : cs.errors) @@ -852,11 +853,12 @@ ModulePtr Frontend::check(const SourceModule& sourceModule, Mode mode, const Sco result->astOriginalCallTypes = std::move(cgb.astOriginalCallTypes); result->astResolvedTypes = std::move(cgb.astResolvedTypes); result->astResolvedTypePacks = std::move(cgb.astResolvedTypePacks); - - result->clonePublicInterface(iceHandler); + result->type = sourceModule.type; Luau::check(sourceModule, result.get()); + result->clonePublicInterface(iceHandler); + return result; } diff --git a/Analysis/src/Instantiation.cpp b/Analysis/src/Instantiation.cpp index e98ab1858..2d1d62f31 100644 --- a/Analysis/src/Instantiation.cpp +++ b/Analysis/src/Instantiation.cpp @@ -46,6 +46,7 @@ TypeId Instantiation::clean(TypeId ty) FunctionTypeVar clone = FunctionTypeVar{level, ftv->argTypes, ftv->retTypes, ftv->definition, ftv->hasSelf}; clone.magicFunction = ftv->magicFunction; + clone.dcrMagicFunction = ftv->dcrMagicFunction; clone.tags = ftv->tags; clone.argNames = ftv->argNames; TypeId result = addType(std::move(clone)); diff --git a/Analysis/src/JsonEmitter.cpp b/Analysis/src/JsonEmitter.cpp index e99619baa..9c8a7af9d 100644 --- a/Analysis/src/JsonEmitter.cpp +++ b/Analysis/src/JsonEmitter.cpp @@ -11,7 +11,8 @@ namespace Luau::Json static constexpr int CHUNK_SIZE = 1024; ObjectEmitter::ObjectEmitter(NotNull emitter) - : emitter(emitter), finished(false) + : emitter(emitter) + , finished(false) { comma = emitter->pushComma(); emitter->writeRaw('{'); @@ -33,7 +34,8 @@ void ObjectEmitter::finish() } ArrayEmitter::ArrayEmitter(NotNull emitter) - : emitter(emitter), finished(false) + : emitter(emitter) + , finished(false) { comma = emitter->pushComma(); emitter->writeRaw('['); diff --git a/Analysis/src/Linter.cpp b/Analysis/src/Linter.cpp index 426ff9d6a..669739a04 100644 --- a/Analysis/src/Linter.cpp +++ b/Analysis/src/Linter.cpp @@ -216,7 +216,8 @@ static bool similar(AstExpr* lhs, AstExpr* rhs) return false; for (size_t i = 0; i < le->strings.size; ++i) - if (le->strings.data[i].size != re->strings.data[i].size || memcmp(le->strings.data[i].data, re->strings.data[i].data, le->strings.data[i].size) != 0) + if (le->strings.data[i].size != re->strings.data[i].size || + memcmp(le->strings.data[i].data, re->strings.data[i].data, le->strings.data[i].size) != 0) return false; for (size_t i = 0; i < le->expressions.size; ++i) @@ -2675,13 +2676,18 @@ class LintComparisonPrecedence : AstVisitor private: LintContext* context; - bool isComparison(AstExprBinary::Op op) + static bool isEquality(AstExprBinary::Op op) + { + return op == AstExprBinary::CompareNe || op == AstExprBinary::CompareEq; + } + + static bool isComparison(AstExprBinary::Op op) { return op == AstExprBinary::CompareNe || op == AstExprBinary::CompareEq || op == AstExprBinary::CompareLt || op == AstExprBinary::CompareLe || op == AstExprBinary::CompareGt || op == AstExprBinary::CompareGe; } - bool isNot(AstExpr* node) + static bool isNot(AstExpr* node) { AstExprUnary* expr = node->as(); @@ -2698,22 +2704,26 @@ class LintComparisonPrecedence : AstVisitor { std::string op = toString(node->op); - if (node->op == AstExprBinary::CompareEq || node->op == AstExprBinary::CompareNe) + if (isEquality(node->op)) emitWarning(*context, LintWarning::Code_ComparisonPrecedence, node->location, - "not X %s Y is equivalent to (not X) %s Y; consider using X %s Y, or wrap one of the expressions in parentheses to silence", - op.c_str(), op.c_str(), node->op == AstExprBinary::CompareEq ? "~=" : "=="); + "not X %s Y is equivalent to (not X) %s Y; consider using X %s Y, or add parentheses to silence", op.c_str(), op.c_str(), + node->op == AstExprBinary::CompareEq ? "~=" : "=="); else emitWarning(*context, LintWarning::Code_ComparisonPrecedence, node->location, - "not X %s Y is equivalent to (not X) %s Y; wrap one of the expressions in parentheses to silence", op.c_str(), op.c_str()); + "not X %s Y is equivalent to (not X) %s Y; add parentheses to silence", op.c_str(), op.c_str()); } else if (AstExprBinary* left = node->left->as(); left && isComparison(left->op)) { std::string lop = toString(left->op); std::string rop = toString(node->op); - emitWarning(*context, LintWarning::Code_ComparisonPrecedence, node->location, - "X %s Y %s Z is equivalent to (X %s Y) %s Z; wrap one of the expressions in parentheses to silence", lop.c_str(), rop.c_str(), - lop.c_str(), rop.c_str()); + if (isEquality(left->op) || isEquality(node->op)) + emitWarning(*context, LintWarning::Code_ComparisonPrecedence, node->location, + "X %s Y %s Z is equivalent to (X %s Y) %s Z; add parentheses to silence", lop.c_str(), rop.c_str(), lop.c_str(), rop.c_str()); + else + emitWarning(*context, LintWarning::Code_ComparisonPrecedence, node->location, + "X %s Y %s Z is equivalent to (X %s Y) %s Z; did you mean X %s Y and Y %s Z?", lop.c_str(), rop.c_str(), lop.c_str(), rop.c_str(), + lop.c_str(), rop.c_str()); } return true; diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index de796c7a4..4c9e95378 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -219,8 +219,7 @@ void Module::clonePublicInterface(InternalErrorReporter& ice) TypePackId returnType = moduleScope->returnType; std::optional varargPack = FFlag::DebugLuauDeferredConstraintResolution ? std::nullopt : moduleScope->varargPack; - std::unordered_map* exportedTypeBindings = - FFlag::DebugLuauDeferredConstraintResolution ? nullptr : &moduleScope->exportedTypeBindings; + std::unordered_map* exportedTypeBindings = &moduleScope->exportedTypeBindings; TxnLog log; ClonePublicInterface clonePublicInterface{&log, this}; diff --git a/Analysis/src/Substitution.cpp b/Analysis/src/Substitution.cpp index 0beeb58ca..fa12f306b 100644 --- a/Analysis/src/Substitution.cpp +++ b/Analysis/src/Substitution.cpp @@ -510,7 +510,7 @@ void Substitution::foundDirty(TypeId ty) ty = log->follow(ty); if (FFlag::LuauSubstitutionReentrant && newTypes.contains(ty)) - return; + return; if (isDirty(ty)) newTypes[ty] = follow(clean(ty)); @@ -523,7 +523,7 @@ void Substitution::foundDirty(TypePackId tp) tp = log->follow(tp); if (FFlag::LuauSubstitutionReentrant && newPacks.contains(tp)) - return; + return; if (isDirty(tp)) newPacks[tp] = follow(clean(tp)); diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index ace44cda6..13cd7490e 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -1034,6 +1034,13 @@ struct TypePackStringifier { stringify(btv.boundTo); } + + void operator()(TypePackId, const BlockedTypePack& btp) + { + state.emit("*blocked-tp-"); + state.emit(btp.index); + state.emit("*"); + } }; void TypeVarStringifier::stringify(TypePackId tp) @@ -1095,9 +1102,8 @@ ToStringResult toStringDetailed(TypeId ty, ToStringOptions& opts) ToStringResult result; - StringifierState state = FFlag::LuauFixNameMaps - ? StringifierState{opts, result, opts.nameMap} - : StringifierState{opts, result, opts.DEPRECATED_nameMap}; + StringifierState state = + FFlag::LuauFixNameMaps ? StringifierState{opts, result, opts.nameMap} : StringifierState{opts, result, opts.DEPRECATED_nameMap}; std::set cycles; std::set cycleTPs; @@ -1204,9 +1210,8 @@ ToStringResult toStringDetailed(TypePackId tp, ToStringOptions& opts) * 4. Print out the root of the type using the same algorithm as step 3. */ ToStringResult result; - StringifierState state = FFlag::LuauFixNameMaps - ? StringifierState{opts, result, opts.nameMap} - : StringifierState{opts, result, opts.DEPRECATED_nameMap}; + StringifierState state = + FFlag::LuauFixNameMaps ? StringifierState{opts, result, opts.nameMap} : StringifierState{opts, result, opts.DEPRECATED_nameMap}; std::set cycles; std::set cycleTPs; @@ -1292,9 +1297,8 @@ std::string toString(const TypePackVar& tp, ToStringOptions& opts) std::string toStringNamedFunction(const std::string& funcName, const FunctionTypeVar& ftv, ToStringOptions& opts) { ToStringResult result; - StringifierState state = FFlag::LuauFixNameMaps - ? StringifierState{opts, result, opts.nameMap} - : StringifierState{opts, result, opts.DEPRECATED_nameMap}; + StringifierState state = + FFlag::LuauFixNameMaps ? StringifierState{opts, result, opts.nameMap} : StringifierState{opts, result, opts.DEPRECATED_nameMap}; TypeVarStringifier tvs{state}; state.emit(funcName); @@ -1427,8 +1431,7 @@ std::string toString(const Constraint& constraint, ToStringOptions& opts) using T = std::decay_t; // TODO: Inline and delete this function when clipping FFlag::LuauFixNameMaps - auto tos = [](auto&& a, ToStringOptions& opts) - { + auto tos = [](auto&& a, ToStringOptions& opts) { if (FFlag::LuauFixNameMaps) return toString(a, opts); else @@ -1478,6 +1481,13 @@ std::string toString(const Constraint& constraint, ToStringOptions& opts) return resultStr + " ~ Binary<" + toString(c.op) + ", " + leftStr + ", " + rightStr + ">"; } + else if constexpr (std::is_same_v) + { + std::string iteratorStr = tos(c.iterator, opts); + std::string variableStr = tos(c.variables, opts); + + return variableStr + " ~ Iterate<" + iteratorStr + ">"; + } else if constexpr (std::is_same_v) { std::string namedStr = tos(c.namedType, opts); @@ -1488,6 +1498,10 @@ std::string toString(const Constraint& constraint, ToStringOptions& opts) std::string targetStr = tos(c.target, opts); return "expand " + targetStr; } + else if constexpr (std::is_same_v) + { + return "call " + tos(c.fn, opts) + " with { result = " + tos(c.result, opts) + " }"; + } else static_assert(always_false_v, "Non-exhaustive constraint switch"); }; diff --git a/Analysis/src/TypeArena.cpp b/Analysis/src/TypeArena.cpp index 0c89d130c..c7980ab0b 100644 --- a/Analysis/src/TypeArena.cpp +++ b/Analysis/src/TypeArena.cpp @@ -31,6 +31,15 @@ TypeId TypeArena::freshType(TypeLevel level) return allocated; } +TypeId TypeArena::freshType(Scope* scope) +{ + TypeId allocated = typeVars.allocate(FreeTypeVar{scope}); + + asMutable(allocated)->owningArena = this; + + return allocated; +} + TypePackId TypeArena::addTypePack(std::initializer_list types) { TypePackId allocated = typePacks.allocate(TypePack{std::move(types)}); @@ -40,9 +49,9 @@ TypePackId TypeArena::addTypePack(std::initializer_list types) return allocated; } -TypePackId TypeArena::addTypePack(std::vector types) +TypePackId TypeArena::addTypePack(std::vector types, std::optional tail) { - TypePackId allocated = typePacks.allocate(TypePack{std::move(types)}); + TypePackId allocated = typePacks.allocate(TypePack{std::move(types), tail}); asMutable(allocated)->owningArena = this; diff --git a/Analysis/src/TypeAttach.cpp b/Analysis/src/TypeAttach.cpp index f21a4fa9c..84494083b 100644 --- a/Analysis/src/TypeAttach.cpp +++ b/Analysis/src/TypeAttach.cpp @@ -373,6 +373,11 @@ class TypePackRehydrationVisitor return Luau::visit(*this, btp.boundTo->ty); } + AstTypePack* operator()(const BlockedTypePack& btp) const + { + return allocator->alloc(Location(), AstName("*blocked*")); + } + AstTypePack* operator()(const TypePack& tp) const { AstArray head; diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index e5813cd2a..480bdf403 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -166,7 +166,8 @@ struct TypeChecker2 auto pusher = pushStack(stat); if (0) - {} + { + } else if (auto s = stat->as()) return visit(s); else if (auto s = stat->as()) @@ -239,11 +240,9 @@ struct TypeChecker2 visit(repeatStatement->condition); } - void visit(AstStatBreak*) - {} + void visit(AstStatBreak*) {} - void visit(AstStatContinue*) - {} + void visit(AstStatContinue*) {} void visit(AstStatReturn* ret) { @@ -339,6 +338,50 @@ struct TypeChecker2 visit(forStatement->body); } + // "Render" a type pack out to an array of a given length. Expands + // variadics and various other things to get there. + static std::vector flatten(TypeArena& arena, TypePackId pack, size_t length) + { + std::vector result; + + auto it = begin(pack); + auto endIt = end(pack); + + while (it != endIt) + { + result.push_back(*it); + + if (result.size() >= length) + return result; + + ++it; + } + + if (!it.tail()) + return result; + + TypePackId tail = *it.tail(); + if (get(tail)) + LUAU_ASSERT(0); + else if (auto vtp = get(tail)) + { + while (result.size() < length) + result.push_back(vtp->ty); + } + else if (get(tail) || get(tail)) + { + while (result.size() < length) + result.push_back(arena.addType(FreeTypeVar{nullptr})); + } + else if (auto etp = get(tail)) + { + while (result.size() < length) + result.push_back(getSingletonTypes().errorRecoveryType()); + } + + return result; + } + void visit(AstStatForIn* forInStatement) { for (AstLocal* local : forInStatement->vars) @@ -351,6 +394,128 @@ struct TypeChecker2 visit(expr); visit(forInStatement->body); + + // Rule out crazy stuff. Maybe possible if the file is not syntactically valid. + if (!forInStatement->vars.size || !forInStatement->values.size) + return; + + NotNull scope = stack.back(); + TypeArena tempArena; + + std::vector variableTypes; + for (AstLocal* var : forInStatement->vars) + { + std::optional ty = scope->lookup(var); + LUAU_ASSERT(ty); + variableTypes.emplace_back(*ty); + } + + // ugh. There's nothing in the AST to hang a whole type pack on for the + // set of iteratees, so we have to piece it back together by hand. + std::vector valueTypes; + for (size_t i = 0; i < forInStatement->values.size - 1; ++i) + valueTypes.emplace_back(lookupType(forInStatement->values.data[i])); + TypePackId iteratorTail = lookupPack(forInStatement->values.data[forInStatement->values.size - 1]); + TypePackId iteratorPack = tempArena.addTypePack(valueTypes, iteratorTail); + + // ... and then expand it out to 3 values (if possible) + const std::vector iteratorTypes = flatten(tempArena, iteratorPack, 3); + if (iteratorTypes.empty()) + { + reportError(GenericError{"for..in loops require at least one value to iterate over. Got zero"}, getLocation(forInStatement->values)); + return; + } + TypeId iteratorTy = follow(iteratorTypes[0]); + + /* + * If the first iterator argument is a function + * * There must be 1 to 3 iterator arguments. Name them (nextTy, + * arrayTy, startIndexTy) + * * The return type of nextTy() must correspond to the variables' + * types and counts. HOWEVER the first iterator will never be nil. + * * The first return value of nextTy must be compatible with + * startIndexTy. + * * The first argument to nextTy() must be compatible with arrayTy if + * present. nil if not. + * * The second argument to nextTy() must be compatible with + * startIndexTy if it is present. Else, it must be compatible with + * nil. + * * nextTy() must be callable with only 2 arguments. + */ + if (const FunctionTypeVar* nextFn = get(iteratorTy)) + { + if (iteratorTypes.size() < 1 || iteratorTypes.size() > 3) + reportError(GenericError{"for..in loops must be passed (next, [table[, state]])"}, getLocation(forInStatement->values)); + + // It is okay if there aren't enough iterators, but the iteratee must provide enough. + std::vector expectedVariableTypes = flatten(tempArena, nextFn->retTypes, variableTypes.size()); + if (expectedVariableTypes.size() < variableTypes.size()) + reportError(GenericError{"next() does not return enough values"}, forInStatement->vars.data[0]->location); + + for (size_t i = 0; i < std::min(expectedVariableTypes.size(), variableTypes.size()); ++i) + reportErrors(tryUnify(scope, forInStatement->vars.data[i]->location, variableTypes[i], expectedVariableTypes[i])); + + // nextFn is going to be invoked with (arrayTy, startIndexTy) + + // It will be passed two arguments on every iteration save the + // first. + + // It may be invoked with 0 or 1 argument on the first iteration. + // This depends on the types in iterateePack and therefore + // iteratorTypes. + + // If iteratorTypes is too short to be a valid call to nextFn, we have to report a count mismatch error. + // If 2 is too short to be a valid call to nextFn, we have to report a count mismatch error. + // If 2 is too long to be a valid call to nextFn, we have to report a count mismatch error. + auto [minCount, maxCount] = getParameterExtents(TxnLog::empty(), nextFn->argTypes); + + if (minCount > 2) + reportError(CountMismatch{2, minCount, CountMismatch::Arg}, forInStatement->vars.data[0]->location); + if (maxCount && *maxCount < 2) + reportError(CountMismatch{2, *maxCount, CountMismatch::Arg}, forInStatement->vars.data[0]->location); + + const std::vector flattenedArgTypes = flatten(tempArena, nextFn->argTypes, 2); + const auto [argTypes, argsTail] = Luau::flatten(nextFn->argTypes); + + size_t firstIterationArgCount = iteratorTypes.empty() ? 0 : iteratorTypes.size() - 1; + size_t actualArgCount = expectedVariableTypes.size(); + + if (firstIterationArgCount < minCount) + reportError(CountMismatch{2, firstIterationArgCount, CountMismatch::Arg}, forInStatement->vars.data[0]->location); + else if (actualArgCount < minCount) + reportError(CountMismatch{2, actualArgCount, CountMismatch::Arg}, forInStatement->vars.data[0]->location); + + if (iteratorTypes.size() >= 2 && flattenedArgTypes.size() > 0) + { + size_t valueIndex = forInStatement->values.size > 1 ? 1 : 0; + reportErrors(tryUnify(scope, forInStatement->values.data[valueIndex]->location, iteratorTypes[1], flattenedArgTypes[0])); + } + + if (iteratorTypes.size() == 3 && flattenedArgTypes.size() > 1) + { + size_t valueIndex = forInStatement->values.size > 2 ? 2 : 0; + reportErrors(tryUnify(scope, forInStatement->values.data[valueIndex]->location, iteratorTypes[2], flattenedArgTypes[1])); + } + } + else if (const TableTypeVar* ttv = get(iteratorTy)) + { + if ((forInStatement->vars.size == 1 || forInStatement->vars.size == 2) && ttv->indexer) + { + reportErrors(tryUnify(scope, forInStatement->vars.data[0]->location, variableTypes[0], ttv->indexer->indexType)); + if (variableTypes.size() == 2) + reportErrors(tryUnify(scope, forInStatement->vars.data[1]->location, variableTypes[1], ttv->indexer->indexResultType)); + } + else + reportError(GenericError{"Cannot iterate over a table without indexer"}, forInStatement->values.data[0]->location); + } + else if (get(iteratorTy) || get(iteratorTy)) + { + // nothing + } + else + { + reportError(CannotCallNonFunction{iteratorTy}, forInStatement->values.data[0]->location); + } } void visit(AstStatAssign* assign) @@ -456,7 +621,8 @@ struct TypeChecker2 auto StackPusher = pushStack(expr); if (0) - {} + { + } else if (auto e = expr->as()) return visit(e); else if (auto e = expr->as()) @@ -561,9 +727,21 @@ struct TypeChecker2 TypePackId expectedRetType = lookupPack(call); TypeId functionType = lookupType(call->func); - TypeId instantiatedFunctionType = instantiation.substitute(functionType).value_or(nullptr); LUAU_ASSERT(functionType); + if (get(functionType) || get(functionType)) + return; + + // TODO: Lots of other types are callable: intersections of functions + // and things with the __call metamethod. + if (!get(functionType)) + { + reportError(CannotCallNonFunction{functionType}, call->func->location); + return; + } + + TypeId instantiatedFunctionType = follow(instantiation.substitute(functionType).value_or(nullptr)); + TypePack args; for (AstExpr* arg : call->args) { @@ -575,12 +753,11 @@ struct TypeChecker2 TypePackId argsTp = arena.addTypePack(args); FunctionTypeVar ftv{argsTp, expectedRetType}; TypeId expectedType = arena.addType(ftv); + if (!isSubtype(expectedType, instantiatedFunctionType, stack.back(), ice)) { - unfreeze(module->interfaceTypes); CloneState cloneState; - expectedType = clone(expectedType, module->interfaceTypes, cloneState); - freeze(module->interfaceTypes); + expectedType = clone(expectedType, module->internalTypes, cloneState); reportError(TypeMismatch{expectedType, functionType}, call->location); } } @@ -592,7 +769,8 @@ struct TypeChecker2 // leftType must have a property called indexName->index - std::optional ty = getIndexTypeFromType(module->getModuleScope(), leftType, indexName->index.value, indexName->location, /* addErrors */ true); + std::optional ty = + getIndexTypeFromType(module->getModuleScope(), leftType, indexName->index.value, indexName->location, /* addErrors */ true); if (ty) { if (!isSubtype(resultType, *ty, stack.back(), ice)) @@ -972,18 +1150,34 @@ struct TypeChecker2 } } - void reportError(TypeErrorData&& data, const Location& location) + template + ErrorVec tryUnify(NotNull scope, const Location& location, TID subTy, TID superTy) + { + UnifierSharedState sharedState{&ice}; + Unifier u{&module->internalTypes, Mode::Strict, scope, location, Covariant, sharedState}; + u.anyIsTop = true; + u.tryUnify(subTy, superTy); + + return std::move(u.errors); + } + + void reportError(TypeErrorData data, const Location& location) { module->errors.emplace_back(location, sourceModule->name, std::move(data)); } void reportError(TypeError e) { - module->errors.emplace_back(std::move(e)); + reportError(std::move(e.data), e.location); + } + + void reportErrors(ErrorVec errors) + { + for (TypeError e : errors) + reportError(std::move(e)); } - std::optional getIndexTypeFromType( - const ScopePtr& scope, TypeId type, const std::string& prop, const Location& location, bool addErrors) + std::optional getIndexTypeFromType(const ScopePtr& scope, TypeId type, const std::string& prop, const Location& location, bool addErrors) { return Luau::getIndexTypeFromType(scope, module->errors, &module->internalTypes, type, prop, location, addErrors, ice); } diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 77168053b..2bda2804c 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -32,7 +32,6 @@ LUAU_FASTINTVARIABLE(LuauCheckRecursionLimit, 300) LUAU_FASTINTVARIABLE(LuauVisitRecursionLimit, 500) LUAU_FASTFLAG(LuauKnowsTheDataModel3) LUAU_FASTFLAG(LuauAutocompleteDynamicLimits) -LUAU_FASTFLAGVARIABLE(LuauExpectedTableUnionIndexerType, false) LUAU_FASTFLAGVARIABLE(LuauInplaceDemoteSkipAllBound, false) LUAU_FASTFLAGVARIABLE(LuauLowerBoundsCalculation, false) LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) @@ -45,6 +44,7 @@ LUAU_FASTFLAGVARIABLE(LuauBinaryNeedsExpectedTypesToo, false) LUAU_FASTFLAGVARIABLE(LuauNeverTypesAndOperatorsInference, false) LUAU_FASTFLAGVARIABLE(LuauReturnsFromCallsitesAreNotWidened, false) LUAU_FASTFLAGVARIABLE(LuauCompleteVisitor, false) +LUAU_FASTFLAGVARIABLE(LuauUnionOfTypesFollow, false) namespace Luau { @@ -1659,6 +1659,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& declar TypeId propTy = resolveType(scope, *prop.ty); bool assignToMetatable = isMetamethod(propName); + Luau::ClassTypeVar::Props& assignTo = assignToMetatable ? metatable->props : ctv->props; // Function types always take 'self', but this isn't reflected in the // parsed annotation. Add it here. @@ -1674,16 +1675,13 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& declar } } - if (ctv->props.count(propName) == 0) + if (assignTo.count(propName) == 0) { - if (assignToMetatable) - metatable->props[propName] = {propTy}; - else - ctv->props[propName] = {propTy}; + assignTo[propName] = {propTy}; } else { - TypeId currentTy = assignToMetatable ? metatable->props[propName].type : ctv->props[propName].type; + TypeId currentTy = assignTo[propName].type; // We special-case this logic to keep the intersection flat; otherwise we // would create a ton of nested intersection types. @@ -1693,19 +1691,13 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& declar options.push_back(propTy); TypeId newItv = addType(IntersectionTypeVar{std::move(options)}); - if (assignToMetatable) - metatable->props[propName] = {newItv}; - else - ctv->props[propName] = {newItv}; + assignTo[propName] = {newItv}; } else if (get(currentTy)) { TypeId intersection = addType(IntersectionTypeVar{{currentTy, propTy}}); - if (assignToMetatable) - metatable->props[propName] = {intersection}; - else - ctv->props[propName] = {intersection}; + assignTo[propName] = {intersection}; } else { @@ -2351,7 +2343,7 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp { if (auto prop = ttv->props.find(key->value.data); prop != ttv->props.end()) expectedResultTypes.push_back(prop->second.type); - else if (FFlag::LuauExpectedTableUnionIndexerType && ttv->indexer && maybeString(ttv->indexer->indexType)) + else if (ttv->indexer && maybeString(ttv->indexer->indexType)) expectedResultTypes.push_back(ttv->indexer->indexResultType); } } @@ -2506,6 +2498,12 @@ std::string opToMetaTableEntry(const AstExprBinary::Op& op) TypeId TypeChecker::unionOfTypes(TypeId a, TypeId b, const ScopePtr& scope, const Location& location, bool unifyFreeTypes) { + if (FFlag::LuauUnionOfTypesFollow) + { + a = follow(a); + b = follow(b); + } + if (unifyFreeTypes && (get(a) || get(b))) { if (unify(b, a, scope, location)) @@ -3667,33 +3665,6 @@ WithPredicate TypeChecker::checkExprPackHelper(const ScopePtr& scope } } -// Returns the minimum number of arguments the argument list can accept. -static size_t getMinParameterCount(TxnLog* log, TypePackId tp) -{ - size_t minCount = 0; - size_t optionalCount = 0; - - auto it = begin(tp, log); - auto endIter = end(tp); - - while (it != endIter) - { - TypeId ty = *it; - if (isOptional(ty)) - ++optionalCount; - else - { - minCount += optionalCount; - optionalCount = 0; - minCount++; - } - - ++it; - } - - return minCount; -} - void TypeChecker::checkArgumentList( const ScopePtr& scope, Unifier& state, TypePackId argPack, TypePackId paramPack, const std::vector& argLocations) { @@ -3713,7 +3684,7 @@ void TypeChecker::checkArgumentList( if (!argLocations.empty()) location = {state.location.begin, argLocations.back().end}; - size_t minParams = getMinParameterCount(&state.log, paramPack); + size_t minParams = getParameterExtents(&state.log, paramPack).first; state.reportError(TypeError{location, CountMismatch{minParams, std::distance(begin(argPack), end(argPack))}}); }; @@ -3812,7 +3783,7 @@ void TypeChecker::checkArgumentList( } // ok else { - size_t minParams = getMinParameterCount(&state.log, paramPack); + size_t minParams = getParameterExtents(&state.log, paramPack).first; std::optional tail = flatten(paramPack, state.log).second; bool isVariadic = tail && Luau::isVariadic(*tail); diff --git a/Analysis/src/TypePack.cpp b/Analysis/src/TypePack.cpp index d4544483a..2fa9413a8 100644 --- a/Analysis/src/TypePack.cpp +++ b/Analysis/src/TypePack.cpp @@ -8,6 +8,13 @@ namespace Luau { +BlockedTypePack::BlockedTypePack() + : index(++nextIndex) +{ +} + +size_t BlockedTypePack::nextIndex = 0; + TypePackVar::TypePackVar(const TypePackVariant& tp) : ty(tp) { diff --git a/Analysis/src/TypeUtils.cpp b/Analysis/src/TypeUtils.cpp index 60bca0a30..56fcceccc 100644 --- a/Analysis/src/TypeUtils.cpp +++ b/Analysis/src/TypeUtils.cpp @@ -84,9 +84,8 @@ std::optional findTablePropertyRespectingMeta(ErrorVec& errors, TypeId t return std::nullopt; } -std::optional getIndexTypeFromType( - const ScopePtr& scope, ErrorVec& errors, TypeArena* arena, TypeId type, const std::string& prop, const Location& location, bool addErrors, - InternalErrorReporter& handle) +std::optional getIndexTypeFromType(const ScopePtr& scope, ErrorVec& errors, TypeArena* arena, TypeId type, const std::string& prop, + const Location& location, bool addErrors, InternalErrorReporter& handle) { type = follow(type); @@ -190,4 +189,33 @@ std::optional getIndexTypeFromType( return std::nullopt; } +std::pair> getParameterExtents(const TxnLog* log, TypePackId tp) +{ + size_t minCount = 0; + size_t optionalCount = 0; + + auto it = begin(tp, log); + auto endIter = end(tp); + + while (it != endIter) + { + TypeId ty = *it; + if (isOptional(ty)) + ++optionalCount; + else + { + minCount += optionalCount; + optionalCount = 0; + minCount++; + } + + ++it; + } + + if (it.tail()) + return {minCount, std::nullopt}; + else + return {minCount, minCount + optionalCount}; +} + } // namespace Luau diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index 8974f8c70..4abee0f6d 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -24,9 +24,7 @@ LUAU_FASTINTVARIABLE(LuauTypeMaximumStringifierLength, 500) LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0) LUAU_FASTINT(LuauTypeInferRecursionLimit) LUAU_FASTFLAG(LuauUnknownAndNeverType) -LUAU_FASTFLAGVARIABLE(LuauDeduceGmatchReturnTypes, false) LUAU_FASTFLAGVARIABLE(LuauMaybeGenericIntersectionTypes, false) -LUAU_FASTFLAGVARIABLE(LuauDeduceFindMatchReturnTypes, false) LUAU_FASTFLAGVARIABLE(LuauStringFormatArgumentErrorFix, false) namespace Luau @@ -446,8 +444,10 @@ BlockedTypeVar::BlockedTypeVar() int BlockedTypeVar::nextIndex = 0; -PendingExpansionTypeVar::PendingExpansionTypeVar(TypeFun fn, std::vector typeArguments, std::vector packArguments) - : fn(fn) +PendingExpansionTypeVar::PendingExpansionTypeVar( + std::optional prefix, AstName name, std::vector typeArguments, std::vector packArguments) + : prefix(prefix) + , name(name) , typeArguments(typeArguments) , packArguments(packArguments) , index(++nextIndex) @@ -787,8 +787,8 @@ TypeId SingletonTypes::makeStringMetatable() makeFunction(*arena, stringType, {}, {}, {stringType}, {}, {arena->addType(FunctionTypeVar{emptyPack, stringVariadicList})}); attachMagicFunction(gmatchFunc, magicFunctionGmatch); - const TypeId matchFunc = arena->addType(FunctionTypeVar{arena->addTypePack({stringType, stringType, optionalNumber}), - arena->addTypePack(TypePackVar{VariadicTypePack{FFlag::LuauDeduceFindMatchReturnTypes ? stringType : optionalString}})}); + const TypeId matchFunc = arena->addType( + FunctionTypeVar{arena->addTypePack({stringType, stringType, optionalNumber}), arena->addTypePack(TypePackVar{VariadicTypePack{stringType}})}); attachMagicFunction(matchFunc, magicFunctionMatch); const TypeId findFunc = arena->addType(FunctionTypeVar{arena->addTypePack({stringType, stringType, optionalNumber, optionalBoolean}), @@ -1221,9 +1221,6 @@ static std::vector parsePatternString(TypeChecker& typechecker, const ch static std::optional> magicFunctionGmatch( TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) { - if (!FFlag::LuauDeduceGmatchReturnTypes) - return std::nullopt; - auto [paramPack, _predicates] = withPredicate; const auto& [params, tail] = flatten(paramPack); @@ -1256,9 +1253,6 @@ static std::optional> magicFunctionGmatch( static std::optional> magicFunctionMatch( TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) { - if (!FFlag::LuauDeduceFindMatchReturnTypes) - return std::nullopt; - auto [paramPack, _predicates] = withPredicate; const auto& [params, tail] = flatten(paramPack); @@ -1295,9 +1289,6 @@ static std::optional> magicFunctionMatch( static std::optional> magicFunctionFind( TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) { - if (!FFlag::LuauDeduceFindMatchReturnTypes) - return std::nullopt; - auto [paramPack, _predicates] = withPredicate; const auto& [params, tail] = flatten(paramPack); diff --git a/Analysis/src/TypedAllocator.cpp b/Analysis/src/TypedAllocator.cpp index 9ce8c3dc0..c95c8eae6 100644 --- a/Analysis/src/TypedAllocator.cpp +++ b/Analysis/src/TypedAllocator.cpp @@ -36,7 +36,9 @@ void* pagedAllocate(size_t size) { // By default we use operator new/delete instead of malloc/free so that they can be overridden externally if (!FFlag::DebugLuauFreezeArena) + { return ::operator new(size, std::nothrow); + } // On Windows, VirtualAlloc results in 64K granularity allocations; we allocate in chunks of ~32K so aligned_malloc is a little more efficient // On Linux, we must use mmap because using regular heap results in mprotect() fragmenting the page table and us bumping into 64K mmap limit. diff --git a/Analysis/src/Unifiable.cpp b/Analysis/src/Unifiable.cpp index 63d8647db..fa76e8204 100644 --- a/Analysis/src/Unifiable.cpp +++ b/Analysis/src/Unifiable.cpp @@ -13,7 +13,8 @@ Free::Free(TypeLevel level) } Free::Free(Scope* scope) - : scope(scope) + : index(++nextIndex) + , scope(scope) { } diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index b5f58c835..b135cd0c8 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -317,7 +317,8 @@ static std::optional> getTableMat return std::nullopt; } -Unifier::Unifier(TypeArena* types, Mode mode, NotNull scope, const Location& location, Variance variance, UnifierSharedState& sharedState, TxnLog* parentLog) +Unifier::Unifier(TypeArena* types, Mode mode, NotNull scope, const Location& location, Variance variance, UnifierSharedState& sharedState, + TxnLog* parentLog) : types(types) , mode(mode) , scope(scope) @@ -492,9 +493,9 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool if (log.get(subTy)) tryUnifyWithConstrainedSubTypeVar(subTy, superTy); - else if (const UnionTypeVar* uv = log.getMutable(subTy)) + else if (const UnionTypeVar* subUnion = log.getMutable(subTy)) { - tryUnifyUnionWithType(subTy, uv, superTy); + tryUnifyUnionWithType(subTy, subUnion, superTy); } else if (const UnionTypeVar* uv = log.getMutable(superTy)) { @@ -555,14 +556,14 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool log.popSeen(superTy, subTy); } -void Unifier::tryUnifyUnionWithType(TypeId subTy, const UnionTypeVar* uv, TypeId superTy) +void Unifier::tryUnifyUnionWithType(TypeId subTy, const UnionTypeVar* subUnion, TypeId superTy) { // A | B <: T if A <: T and B <: T bool failed = false; std::optional unificationTooComplex; std::optional firstFailedOption; - for (TypeId type : uv->options) + for (TypeId type : subUnion->options) { Unifier innerState = makeChildUnifier(); innerState.tryUnify_(type, superTy); @@ -608,9 +609,9 @@ void Unifier::tryUnifyUnionWithType(TypeId subTy, const UnionTypeVar* uv, TypeId } }; - if (auto utv = log.getMutable(superTy)) + if (auto superUnion = log.getMutable(superTy)) { - for (TypeId ty : utv) + for (TypeId ty : superUnion) tryBind(ty); } else diff --git a/Ast/include/Luau/Ast.h b/Ast/include/Luau/Ast.h index 612283fb1..070511632 100644 --- a/Ast/include/Luau/Ast.h +++ b/Ast/include/Luau/Ast.h @@ -1276,6 +1276,16 @@ class AstTypePackGeneric : public AstTypePack }; AstName getIdentifier(AstExpr*); +Location getLocation(const AstTypeList& typeList); + +template // AstNode, AstExpr, AstLocal, etc +Location getLocation(AstArray array) +{ + if (0 == array.size) + return {}; + + return Location{array.data[0]->location.begin, array.data[array.size - 1]->location.end}; +} #undef LUAU_RTTI diff --git a/Ast/include/Luau/Parser.h b/Ast/include/Luau/Parser.h index 956fcf648..848d71179 100644 --- a/Ast/include/Luau/Parser.h +++ b/Ast/include/Luau/Parser.h @@ -304,6 +304,12 @@ class Parser AstExprError* reportExprError(const Location& location, const AstArray& expressions, const char* format, ...) LUAU_PRINTF_ATTR(4, 5); AstTypeError* reportTypeAnnotationError(const Location& location, const AstArray& types, bool isMissing, const char* format, ...) LUAU_PRINTF_ATTR(5, 6); + // `parseErrorLocation` is associated with the parser error + // `astErrorLocation` is associated with the AstTypeError created + // It can be useful to have different error locations so that the parse error can include the next lexeme, while the AstTypeError can precisely + // define the location (possibly of zero size) where a type annotation is expected. + AstTypeError* reportMissingTypeAnnotationError(const Location& parseErrorLocation, const Location& astErrorLocation, const char* format, ...) + LUAU_PRINTF_ATTR(4, 5); AstExpr* reportFunctionArgsError(AstExpr* func, bool self); void reportAmbiguousCallError(); diff --git a/Ast/src/Ast.cpp b/Ast/src/Ast.cpp index 8291a5b11..cbed8bae1 100644 --- a/Ast/src/Ast.cpp +++ b/Ast/src/Ast.cpp @@ -952,4 +952,17 @@ AstName getIdentifier(AstExpr* node) return AstName(); } +Location getLocation(const AstTypeList& typeList) +{ + Location result; + if (typeList.types.size) + { + result = Location{typeList.types.data[0]->location, typeList.types.data[typeList.types.size - 1]->location}; + } + if (typeList.tailType) + result.end = typeList.tailType->location.end; + + return result; +} + } // namespace Luau diff --git a/Ast/src/Lexer.cpp b/Ast/src/Lexer.cpp index b4db8bdf4..d93f2ccb6 100644 --- a/Ast/src/Lexer.cpp +++ b/Ast/src/Lexer.cpp @@ -91,18 +91,8 @@ Lexeme::Lexeme(const Location& location, Type type, const char* data, size_t siz , length(unsigned(size)) , data(data) { - LUAU_ASSERT( - type == RawString - || type == QuotedString - || type == InterpStringBegin - || type == InterpStringMid - || type == InterpStringEnd - || type == InterpStringSimple - || type == BrokenInterpDoubleBrace - || type == Number - || type == Comment - || type == BlockComment - ); + LUAU_ASSERT(type == RawString || type == QuotedString || type == InterpStringBegin || type == InterpStringMid || type == InterpStringEnd || + type == InterpStringSimple || type == BrokenInterpDoubleBrace || type == Number || type == Comment || type == BlockComment); } Lexeme::Lexeme(const Location& location, Type type, const char* name) @@ -644,7 +634,8 @@ Lexeme Lexer::readInterpolatedStringSection(Position start, Lexeme::Type formatT if (peekch(1) == '{') { - Lexeme brokenDoubleBrace = Lexeme(Location(start, position()), Lexeme::BrokenInterpDoubleBrace, &buffer[startOffset], offset - startOffset); + Lexeme brokenDoubleBrace = + Lexeme(Location(start, position()), Lexeme::BrokenInterpDoubleBrace, &buffer[startOffset], offset - startOffset); consume(); consume(); return brokenDoubleBrace; diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index b6de27da5..0914054f9 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -24,6 +24,7 @@ LUAU_FASTFLAGVARIABLE(LuauLintParseIntegerIssues, false) LUAU_DYNAMIC_FASTFLAGVARIABLE(LuaReportParseIntegerIssues, false) LUAU_FASTFLAGVARIABLE(LuauInterpolatedStringBaseSupport, false) +LUAU_FASTFLAGVARIABLE(LuauTypeAnnotationLocationChange, false) bool lua_telemetry_parsed_out_of_range_bin_integer = false; bool lua_telemetry_parsed_out_of_range_hex_integer = false; @@ -1564,44 +1565,43 @@ AstTypeOrPack Parser::parseSimpleTypeAnnotation(bool allowPack) { incrementRecursionCounter("type annotation"); - Location begin = lexer.current().location; + Location start = lexer.current().location; if (lexer.current().type == Lexeme::ReservedNil) { nextLexeme(); - return {allocator.alloc(begin, std::nullopt, nameNil), {}}; + return {allocator.alloc(start, std::nullopt, nameNil), {}}; } else if (lexer.current().type == Lexeme::ReservedTrue) { nextLexeme(); - return {allocator.alloc(begin, true)}; + return {allocator.alloc(start, true)}; } else if (lexer.current().type == Lexeme::ReservedFalse) { nextLexeme(); - return {allocator.alloc(begin, false)}; + return {allocator.alloc(start, false)}; } else if (lexer.current().type == Lexeme::RawString || lexer.current().type == Lexeme::QuotedString) { if (std::optional> value = parseCharArray()) { AstArray svalue = *value; - return {allocator.alloc(begin, svalue)}; + return {allocator.alloc(start, svalue)}; } else - return {reportTypeAnnotationError(begin, {}, /*isMissing*/ false, "String literal contains malformed escape sequence")}; + return {reportTypeAnnotationError(start, {}, /*isMissing*/ false, "String literal contains malformed escape sequence")}; } else if (lexer.current().type == Lexeme::InterpStringBegin || lexer.current().type == Lexeme::InterpStringSimple) { parseInterpString(); - return {reportTypeAnnotationError(begin, {}, /*isMissing*/ false, "Interpolated string literals cannot be used as types")}; + return {reportTypeAnnotationError(start, {}, /*isMissing*/ false, "Interpolated string literals cannot be used as types")}; } else if (lexer.current().type == Lexeme::BrokenString) { - Location location = lexer.current().location; nextLexeme(); - return {reportTypeAnnotationError(location, {}, /*isMissing*/ false, "Malformed string")}; + return {reportTypeAnnotationError(start, {}, /*isMissing*/ false, "Malformed string")}; } else if (lexer.current().type == Lexeme::Name) { @@ -1632,7 +1632,7 @@ AstTypeOrPack Parser::parseSimpleTypeAnnotation(bool allowPack) expectMatchAndConsume(')', typeofBegin); - return {allocator.alloc(Location(begin, end), expr), {}}; + return {allocator.alloc(Location(start, end), expr), {}}; } bool hasParameters = false; @@ -1646,7 +1646,7 @@ AstTypeOrPack Parser::parseSimpleTypeAnnotation(bool allowPack) Location end = lexer.previousLocation(); - return {allocator.alloc(Location(begin, end), prefix, name.name, hasParameters, parameters), {}}; + return {allocator.alloc(Location(start, end), prefix, name.name, hasParameters, parameters), {}}; } else if (lexer.current().type == '{') { @@ -1658,23 +1658,35 @@ AstTypeOrPack Parser::parseSimpleTypeAnnotation(bool allowPack) } else if (lexer.current().type == Lexeme::ReservedFunction) { - Location location = lexer.current().location; - nextLexeme(); - return {reportTypeAnnotationError(location, {}, /*isMissing*/ false, + return {reportTypeAnnotationError(start, {}, /*isMissing*/ false, "Using 'function' as a type annotation is not supported, consider replacing with a function type annotation e.g. '(...any) -> " "...any'"), {}}; } else { - Location location = lexer.current().location; + if (FFlag::LuauTypeAnnotationLocationChange) + { + // For a missing type annotation, capture 'space' between last token and the next one + Location astErrorlocation(lexer.previousLocation().end, start.begin); + // The parse error includes the next lexeme to make it easier to display where the error is (e.g. in an IDE or a CLI error message). + // Including the current lexeme also makes the parse error consistent with other parse errors returned by Luau. + Location parseErrorLocation(lexer.previousLocation().end, start.end); + return { + reportMissingTypeAnnotationError(parseErrorLocation, astErrorlocation, "Expected type, got %s", lexer.current().toString().c_str()), + {}}; + } + else + { + Location location = lexer.current().location; - // For a missing type annotation, capture 'space' between last token and the next one - location = Location(lexer.previousLocation().end, lexer.current().location.begin); + // For a missing type annotation, capture 'space' between last token and the next one + location = Location(lexer.previousLocation().end, lexer.current().location.begin); - return {reportTypeAnnotationError(location, {}, /*isMissing*/ true, "Expected type, got %s", lexer.current().toString().c_str()), {}}; + return {reportTypeAnnotationError(location, {}, /*isMissing*/ true, "Expected type, got %s", lexer.current().toString().c_str()), {}}; + } } } @@ -2245,7 +2257,8 @@ AstExpr* Parser::parseSimpleExpr() { return parseNumber(); } - else if (lexer.current().type == Lexeme::RawString || lexer.current().type == Lexeme::QuotedString || (FFlag::LuauInterpolatedStringBaseSupport && lexer.current().type == Lexeme::InterpStringSimple)) + else if (lexer.current().type == Lexeme::RawString || lexer.current().type == Lexeme::QuotedString || + (FFlag::LuauInterpolatedStringBaseSupport && lexer.current().type == Lexeme::InterpStringSimple)) { return parseString(); } @@ -2653,7 +2666,8 @@ AstArray Parser::parseTypeParams() std::optional> Parser::parseCharArray() { - LUAU_ASSERT(lexer.current().type == Lexeme::QuotedString || lexer.current().type == Lexeme::RawString || lexer.current().type == Lexeme::InterpStringSimple); + LUAU_ASSERT(lexer.current().type == Lexeme::QuotedString || lexer.current().type == Lexeme::RawString || + lexer.current().type == Lexeme::InterpStringSimple); scratchData.assign(lexer.current().data, lexer.current().length); @@ -2691,14 +2705,11 @@ AstExpr* Parser::parseInterpString() Location startLocation = lexer.current().location; - do { + do + { Lexeme currentLexeme = lexer.current(); - LUAU_ASSERT( - currentLexeme.type == Lexeme::InterpStringBegin - || currentLexeme.type == Lexeme::InterpStringMid - || currentLexeme.type == Lexeme::InterpStringEnd - || currentLexeme.type == Lexeme::InterpStringSimple - ); + LUAU_ASSERT(currentLexeme.type == Lexeme::InterpStringBegin || currentLexeme.type == Lexeme::InterpStringMid || + currentLexeme.type == Lexeme::InterpStringEnd || currentLexeme.type == Lexeme::InterpStringSimple); Location location = currentLexeme.location; @@ -2973,8 +2984,7 @@ bool Parser::expectMatchEndAndConsume(Lexeme::Type type, const MatchLexeme& begi { // If the token matches on a different line and a different column, it suggests misleading indentation // This can be used to pinpoint the problem location for a possible future *actual* mismatch - if (lexer.current().location.begin.line != begin.position.line && - lexer.current().location.begin.column != begin.position.column && + if (lexer.current().location.begin.line != begin.position.line && lexer.current().location.begin.column != begin.position.column && endMismatchSuspect.position.line < begin.position.line) // Only replace the previous suspect with more recent suspects { endMismatchSuspect = begin; @@ -3108,6 +3118,13 @@ AstExprError* Parser::reportExprError(const Location& location, const AstArray& types, bool isMissing, const char* format, ...) { + if (FFlag::LuauTypeAnnotationLocationChange) + { + // Missing type annotations should be using `reportMissingTypeAnnotationError` when LuauTypeAnnotationLocationChange is enabled + // Note: `isMissing` can be removed once FFlag::LuauTypeAnnotationLocationChange is removed since it will always be true. + LUAU_ASSERT(!isMissing); + } + va_list args; va_start(args, format); report(location, format, args); @@ -3116,6 +3133,18 @@ AstTypeError* Parser::reportTypeAnnotationError(const Location& location, const return allocator.alloc(location, types, isMissing, unsigned(parseErrors.size() - 1)); } +AstTypeError* Parser::reportMissingTypeAnnotationError(const Location& parseErrorLocation, const Location& astErrorLocation, const char* format, ...) +{ + LUAU_ASSERT(FFlag::LuauTypeAnnotationLocationChange); + + va_list args; + va_start(args, format); + report(parseErrorLocation, format, args); + va_end(args); + + return allocator.alloc(astErrorLocation, AstArray{}, true, unsigned(parseErrors.size() - 1)); +} + void Parser::nextLexeme() { Lexeme::Type type = lexer.next(/* skipComments= */ false, true).type; diff --git a/Common/include/Luau/Bytecode.h b/Common/include/Luau/Bytecode.h index d6660f05c..decde93fa 100644 --- a/Common/include/Luau/Bytecode.h +++ b/Common/include/Luau/Bytecode.h @@ -414,7 +414,7 @@ enum LuauBytecodeTag // Bytecode version; runtime supports [MIN, MAX], compiler emits TARGET by default but may emit a higher version when flags are enabled LBC_VERSION_MIN = 2, LBC_VERSION_MAX = 3, - LBC_VERSION_TARGET = 2, + LBC_VERSION_TARGET = 3, // Types of constant table entries LBC_CONSTANT_NIL = 0, LBC_CONSTANT_BOOLEAN, diff --git a/Common/include/Luau/ExperimentalFlags.h b/Common/include/Luau/ExperimentalFlags.h index 809c78dac..ce47cd9ad 100644 --- a/Common/include/Luau/ExperimentalFlags.h +++ b/Common/include/Luau/ExperimentalFlags.h @@ -13,7 +13,8 @@ inline bool isFlagExperimental(const char* flag) static const char* kList[] = { "LuauLowerBoundsCalculation", "LuauInterpolatedStringBaseSupport", - nullptr, // makes sure we always have at least one entry + // makes sure we always have at least one entry + nullptr, }; for (const char* item : kList) diff --git a/Compiler/src/BytecodeBuilder.cpp b/Compiler/src/BytecodeBuilder.cpp index 46ab26488..713d08cdb 100644 --- a/Compiler/src/BytecodeBuilder.cpp +++ b/Compiler/src/BytecodeBuilder.cpp @@ -6,8 +6,6 @@ #include #include -LUAU_FASTFLAGVARIABLE(LuauCompileBytecodeV3, false) - namespace Luau { @@ -1079,9 +1077,6 @@ std::string BytecodeBuilder::getError(const std::string& message) uint8_t BytecodeBuilder::getVersion() { - if (FFlag::LuauCompileBytecodeV3) - return 3; - // This function usually returns LBC_VERSION_TARGET but may sometimes return a higher number (within LBC_VERSION_MIN/MAX) under fast flags return LBC_VERSION_TARGET; } diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index 4429e4ccf..d44daf0cf 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -25,14 +25,8 @@ LUAU_FASTINTVARIABLE(LuauCompileInlineThreshold, 25) LUAU_FASTINTVARIABLE(LuauCompileInlineThresholdMaxBoost, 300) LUAU_FASTINTVARIABLE(LuauCompileInlineDepth, 5) -LUAU_FASTFLAGVARIABLE(LuauCompileXEQ, false) - LUAU_FASTFLAG(LuauInterpolatedStringBaseSupport) -LUAU_FASTFLAGVARIABLE(LuauCompileOptimalAssignment, false) - -LUAU_FASTFLAGVARIABLE(LuauCompileExtractK, false) - namespace Luau { @@ -406,47 +400,29 @@ struct Compiler } } - void compileExprFastcallN(AstExprCall* expr, uint8_t target, uint8_t targetCount, bool targetTop, bool multRet, uint8_t regs, int bfid, int bfK = -1) + void compileExprFastcallN( + AstExprCall* expr, uint8_t target, uint8_t targetCount, bool targetTop, bool multRet, uint8_t regs, int bfid, int bfK = -1) { LUAU_ASSERT(!expr->self); LUAU_ASSERT(expr->args.size >= 1); LUAU_ASSERT(expr->args.size <= 2 || (bfid == LBF_BIT32_EXTRACTK && expr->args.size == 3)); LUAU_ASSERT(bfid == LBF_BIT32_EXTRACTK ? bfK >= 0 : bfK < 0); - LuauOpcode opc = expr->args.size == 1 ? LOP_FASTCALL1 : LOP_FASTCALL2; - - if (FFlag::LuauCompileExtractK) - { - opc = expr->args.size == 1 ? LOP_FASTCALL1 : (bfK >= 0 || isConstant(expr->args.data[1])) ? LOP_FASTCALL2K : LOP_FASTCALL2; - } + LuauOpcode opc = expr->args.size == 1 ? LOP_FASTCALL1 : (bfK >= 0 || isConstant(expr->args.data[1])) ? LOP_FASTCALL2K : LOP_FASTCALL2; uint32_t args[3] = {}; for (size_t i = 0; i < expr->args.size; ++i) { - if (FFlag::LuauCompileExtractK) + if (i > 0 && opc == LOP_FASTCALL2K) { - if (i > 0 && opc == LOP_FASTCALL2K) - { - int32_t cid = getConstantIndex(expr->args.data[i]); - if (cid < 0) - CompileError::raise(expr->location, "Exceeded constant limit; simplify the code to compile"); + int32_t cid = getConstantIndex(expr->args.data[i]); + if (cid < 0) + CompileError::raise(expr->location, "Exceeded constant limit; simplify the code to compile"); - args[i] = cid; - continue; // TODO: remove this and change if below to else if - } - } - else if (i > 0) - { - if (int32_t cid = getConstantIndex(expr->args.data[i]); cid >= 0) - { - opc = LOP_FASTCALL2K; - args[i] = cid; - break; - } + args[i] = cid; } - - if (int reg = getExprLocalReg(expr->args.data[i]); reg >= 0) + else if (int reg = getExprLocalReg(expr->args.data[i]); reg >= 0) { args[i] = uint8_t(reg); } @@ -468,24 +444,10 @@ struct Compiler // these FASTCALL variants. for (size_t i = 0; i < expr->args.size; ++i) { - if (FFlag::LuauCompileExtractK) - { - if (i > 0 && opc == LOP_FASTCALL2K) - emitLoadK(uint8_t(regs + 1 + i), args[i]); - else if (args[i] != regs + 1 + i) - bytecode.emitABC(LOP_MOVE, uint8_t(regs + 1 + i), uint8_t(args[i]), 0); - } - else - { - if (i > 0 && opc == LOP_FASTCALL2K) - { - emitLoadK(uint8_t(regs + 1 + i), args[i]); - break; - } - - if (args[i] != regs + 1 + i) - bytecode.emitABC(LOP_MOVE, uint8_t(regs + 1 + i), uint8_t(args[i]), 0); - } + if (i > 0 && opc == LOP_FASTCALL2K) + emitLoadK(uint8_t(regs + 1 + i), args[i]); + else if (args[i] != regs + 1 + i) + bytecode.emitABC(LOP_MOVE, uint8_t(regs + 1 + i), uint8_t(args[i]), 0); } // note, these instructions are normally not executed and are used as a fallback for FASTCALL @@ -758,7 +720,7 @@ struct Compiler } // Optimization: for bit32.extract with constant in-range f/w we compile using FASTCALL2K and a special builtin - if (FFlag::LuauCompileExtractK && bfid == LBF_BIT32_EXTRACT && expr->args.size == 3 && isConstant(expr->args.data[1]) && isConstant(expr->args.data[2])) + if (bfid == LBF_BIT32_EXTRACT && expr->args.size == 3 && isConstant(expr->args.data[1]) && isConstant(expr->args.data[2])) { Constant fc = getConstant(expr->args.data[1]); Constant wc = getConstant(expr->args.data[2]); @@ -1080,102 +1042,64 @@ struct Compiler std::swap(left, right); } - if (FFlag::LuauCompileXEQ) - { - uint8_t rl = compileExprAuto(left, rs); - - if (isEq && operandIsConstant) - { - const Constant* cv = constants.find(right); - LUAU_ASSERT(cv && cv->type != Constant::Type_Unknown); + uint8_t rl = compileExprAuto(left, rs); - LuauOpcode opc = LOP_NOP; - int32_t cid = -1; - uint32_t flip = (expr->op == AstExprBinary::CompareEq) == not_ ? 0x80000000 : 0; - - switch (cv->type) - { - case Constant::Type_Nil: - opc = LOP_JUMPXEQKNIL; - cid = 0; - break; - - case Constant::Type_Boolean: - opc = LOP_JUMPXEQKB; - cid = cv->valueBoolean; - break; - - case Constant::Type_Number: - opc = LOP_JUMPXEQKN; - cid = getConstantIndex(right); - break; + if (isEq && operandIsConstant) + { + const Constant* cv = constants.find(right); + LUAU_ASSERT(cv && cv->type != Constant::Type_Unknown); - case Constant::Type_String: - opc = LOP_JUMPXEQKS; - cid = getConstantIndex(right); - break; + LuauOpcode opc = LOP_NOP; + int32_t cid = -1; + uint32_t flip = (expr->op == AstExprBinary::CompareEq) == not_ ? 0x80000000 : 0; - default: - LUAU_ASSERT(!"Unexpected constant type"); - } + switch (cv->type) + { + case Constant::Type_Nil: + opc = LOP_JUMPXEQKNIL; + cid = 0; + break; - if (cid < 0) - CompileError::raise(expr->location, "Exceeded constant limit; simplify the code to compile"); + case Constant::Type_Boolean: + opc = LOP_JUMPXEQKB; + cid = cv->valueBoolean; + break; - size_t jumpLabel = bytecode.emitLabel(); + case Constant::Type_Number: + opc = LOP_JUMPXEQKN; + cid = getConstantIndex(right); + break; - bytecode.emitAD(opc, rl, 0); - bytecode.emitAux(cid | flip); + case Constant::Type_String: + opc = LOP_JUMPXEQKS; + cid = getConstantIndex(right); + break; - return jumpLabel; + default: + LUAU_ASSERT(!"Unexpected constant type"); } - else - { - LuauOpcode opc = getJumpOpCompare(expr->op, not_); - uint8_t rr = compileExprAuto(right, rs); + if (cid < 0) + CompileError::raise(expr->location, "Exceeded constant limit; simplify the code to compile"); - size_t jumpLabel = bytecode.emitLabel(); + size_t jumpLabel = bytecode.emitLabel(); - if (expr->op == AstExprBinary::CompareGt || expr->op == AstExprBinary::CompareGe) - { - bytecode.emitAD(opc, rr, 0); - bytecode.emitAux(rl); - } - else - { - bytecode.emitAD(opc, rl, 0); - bytecode.emitAux(rr); - } + bytecode.emitAD(opc, rl, 0); + bytecode.emitAux(cid | flip); - return jumpLabel; - } + return jumpLabel; } else { LuauOpcode opc = getJumpOpCompare(expr->op, not_); - uint8_t rl = compileExprAuto(left, rs); - int32_t rr = -1; - - if (isEq && operandIsConstant) - { - if (opc == LOP_JUMPIFEQ) - opc = LOP_JUMPIFEQK; - else if (opc == LOP_JUMPIFNOTEQ) - opc = LOP_JUMPIFNOTEQK; - - rr = getConstantIndex(right); - LUAU_ASSERT(rr >= 0); - } - else - rr = compileExprAuto(right, rs); + uint8_t rr = compileExprAuto(right, rs); size_t jumpLabel = bytecode.emitLabel(); if (expr->op == AstExprBinary::CompareGt || expr->op == AstExprBinary::CompareGe) { - bytecode.emitAD(opc, uint8_t(rr), 0); + bytecode.emitAD(opc, rr, 0); bytecode.emitAux(rl); } else @@ -2979,62 +2903,6 @@ struct Compiler loops.pop_back(); } - void resolveAssignConflicts(AstStat* stat, std::vector& vars) - { - LUAU_ASSERT(!FFlag::LuauCompileOptimalAssignment); - - // regsUsed[i] is true if we have assigned the register during earlier assignments - // regsRemap[i] is set to the register where the original (pre-assignment) copy was made - // note: regsRemap is uninitialized intentionally to speed small assignments up; regsRemap[i] is valid iff regsUsed[i] - std::bitset<256> regsUsed; - uint8_t regsRemap[256]; - - for (size_t i = 0; i < vars.size(); ++i) - { - LValue& li = vars[i]; - - if (li.kind == LValue::Kind_Local) - { - if (!regsUsed[li.reg]) - { - regsUsed[li.reg] = true; - regsRemap[li.reg] = li.reg; - } - } - else if (li.kind == LValue::Kind_IndexName || li.kind == LValue::Kind_IndexNumber || li.kind == LValue::Kind_IndexExpr) - { - // we're looking for assignments before this one that invalidate any of the registers involved - if (regsUsed[li.reg]) - { - // the register may have been evacuated previously, but if it wasn't - move it now - if (regsRemap[li.reg] == li.reg) - { - uint8_t reg = allocReg(stat, 1); - bytecode.emitABC(LOP_MOVE, reg, li.reg, 0); - - regsRemap[li.reg] = reg; - } - - li.reg = regsRemap[li.reg]; - } - - if (li.kind == LValue::Kind_IndexExpr && regsUsed[li.index]) - { - // the register may have been evacuated previously, but if it wasn't - move it now - if (regsRemap[li.index] == li.index) - { - uint8_t reg = allocReg(stat, 1); - bytecode.emitABC(LOP_MOVE, reg, li.index, 0); - - regsRemap[li.index] = reg; - } - - li.index = regsRemap[li.index]; - } - } - } - } - struct Assignment { LValue lvalue; @@ -3146,110 +3014,81 @@ struct Compiler return; } - if (FFlag::LuauCompileOptimalAssignment) + // compute all l-values: note that this doesn't assign anything yet but it allocates registers and computes complex expressions on the + // left hand side - for example, in "a[expr] = foo" expr will get evaluated here + std::vector vars(stat->vars.size); + + for (size_t i = 0; i < stat->vars.size; ++i) + vars[i].lvalue = compileLValue(stat->vars.data[i], rs); + + // perform conflict resolution: if any expression refers to a local that is assigned before evaluating it, we assign to a temporary + // register after this, vars[i].conflictReg is set for locals that need to be assigned in the second pass + resolveAssignConflicts(stat, vars, stat->values); + + // compute rhs into (mostly) fresh registers + // note that when the lhs assigment is a local, we evaluate directly into that register + // this is possible because resolveAssignConflicts renamed conflicting locals into temporaries + // after this, vars[i].valueReg is set to a register with the value for *all* vars, but some have already been assigned + for (size_t i = 0; i < stat->vars.size && i < stat->values.size; ++i) { - // compute all l-values: note that this doesn't assign anything yet but it allocates registers and computes complex expressions on the - // left hand side - for example, in "a[expr] = foo" expr will get evaluated here - std::vector vars(stat->vars.size); + AstExpr* value = stat->values.data[i]; - for (size_t i = 0; i < stat->vars.size; ++i) - vars[i].lvalue = compileLValue(stat->vars.data[i], rs); + if (i + 1 == stat->values.size && stat->vars.size > stat->values.size) + { + // allocate a consecutive range of regs for all remaining vars and compute everything into temps + // note, this also handles trailing nils + uint8_t rest = uint8_t(stat->vars.size - stat->values.size + 1); + uint8_t temp = allocReg(stat, rest); - // perform conflict resolution: if any expression refers to a local that is assigned before evaluating it, we assign to a temporary - // register after this, vars[i].conflictReg is set for locals that need to be assigned in the second pass - resolveAssignConflicts(stat, vars, stat->values); + compileExprTempN(value, temp, rest, /* targetTop= */ true); - // compute rhs into (mostly) fresh registers - // note that when the lhs assigment is a local, we evaluate directly into that register - // this is possible because resolveAssignConflicts renamed conflicting locals into temporaries - // after this, vars[i].valueReg is set to a register with the value for *all* vars, but some have already been assigned - for (size_t i = 0; i < stat->vars.size && i < stat->values.size; ++i) + for (size_t j = i; j < stat->vars.size; ++j) + vars[j].valueReg = uint8_t(temp + (j - i)); + } + else { - AstExpr* value = stat->values.data[i]; + Assignment& var = vars[i]; - if (i + 1 == stat->values.size && stat->vars.size > stat->values.size) + // if target is a local, use compileExpr directly to target + if (var.lvalue.kind == LValue::Kind_Local) { - // allocate a consecutive range of regs for all remaining vars and compute everything into temps - // note, this also handles trailing nils - uint8_t rest = uint8_t(stat->vars.size - stat->values.size + 1); - uint8_t temp = allocReg(stat, rest); - - compileExprTempN(value, temp, rest, /* targetTop= */ true); + var.valueReg = (var.conflictReg == kInvalidReg) ? var.lvalue.reg : var.conflictReg; - for (size_t j = i; j < stat->vars.size; ++j) - vars[j].valueReg = uint8_t(temp + (j - i)); + compileExpr(stat->values.data[i], var.valueReg); } else { - Assignment& var = vars[i]; - - // if target is a local, use compileExpr directly to target - if (var.lvalue.kind == LValue::Kind_Local) - { - var.valueReg = (var.conflictReg == kInvalidReg) ? var.lvalue.reg : var.conflictReg; - - compileExpr(stat->values.data[i], var.valueReg); - } - else - { - var.valueReg = compileExprAuto(stat->values.data[i], rs); - } + var.valueReg = compileExprAuto(stat->values.data[i], rs); } } + } - // compute expressions with side effects for lulz - for (size_t i = stat->vars.size; i < stat->values.size; ++i) - { - RegScope rsi(this); - compileExprAuto(stat->values.data[i], rsi); - } - - // almost done... let's assign everything left to right, noting that locals were either written-to directly, or will be written-to in a - // separate pass to avoid conflicts - for (const Assignment& var : vars) - { - LUAU_ASSERT(var.valueReg != kInvalidReg); + // compute expressions with side effects for lulz + for (size_t i = stat->vars.size; i < stat->values.size; ++i) + { + RegScope rsi(this); + compileExprAuto(stat->values.data[i], rsi); + } - if (var.lvalue.kind != LValue::Kind_Local) - { - setDebugLine(var.lvalue.location); - compileAssign(var.lvalue, var.valueReg); - } - } + // almost done... let's assign everything left to right, noting that locals were either written-to directly, or will be written-to in a + // separate pass to avoid conflicts + for (const Assignment& var : vars) + { + LUAU_ASSERT(var.valueReg != kInvalidReg); - // all regular local writes are done by the prior loops by computing result directly into target, so this just handles conflicts OR - // local copies from temporary registers in multret context, since in that case we have to allocate consecutive temporaries - for (const Assignment& var : vars) + if (var.lvalue.kind != LValue::Kind_Local) { - if (var.lvalue.kind == LValue::Kind_Local && var.valueReg != var.lvalue.reg) - bytecode.emitABC(LOP_MOVE, var.lvalue.reg, var.valueReg, 0); + setDebugLine(var.lvalue.location); + compileAssign(var.lvalue, var.valueReg); } } - else - { - // compute all l-values: note that this doesn't assign anything yet but it allocates registers and computes complex expressions on the - // left hand side for example, in "a[expr] = foo" expr will get evaluated here - std::vector vars(stat->vars.size); - - for (size_t i = 0; i < stat->vars.size; ++i) - vars[i] = compileLValue(stat->vars.data[i], rs); - // perform conflict resolution: if any lvalue refers to a local reg that will be reassigned before that, we save the local variable in a - // temporary reg - resolveAssignConflicts(stat, vars); - - // compute values into temporaries - uint8_t regs = allocReg(stat, unsigned(stat->vars.size)); - - compileExprListTemp(stat->values, regs, uint8_t(stat->vars.size), /* targetTop= */ true); - - // assign variables that have associated values; note that if we have fewer values than variables, we'll assign nil because - // compileExprListTemp will generate nils - for (size_t i = 0; i < stat->vars.size; ++i) - { - setDebugLine(stat->vars.data[i]); - compileAssign(vars[i], uint8_t(regs + i)); - } + // all regular local writes are done by the prior loops by computing result directly into target, so this just handles conflicts OR + // local copies from temporary registers in multret context, since in that case we have to allocate consecutive temporaries + for (const Assignment& var : vars) + { + if (var.lvalue.kind == LValue::Kind_Local && var.valueReg != var.lvalue.reg) + bytecode.emitABC(LOP_MOVE, var.lvalue.reg, var.valueReg, 0); } } diff --git a/VM/src/lbuiltins.cpp b/VM/src/lbuiltins.cpp index 422c82b13..c16e5aa70 100644 --- a/VM/src/lbuiltins.cpp +++ b/VM/src/lbuiltins.cpp @@ -15,8 +15,6 @@ #include #endif -LUAU_FASTFLAGVARIABLE(LuauFasterBit32NoWidth, false) - // luauF functions implement FASTCALL instruction that performs a direct execution of some builtin functions from the VM // The rule of thumb is that FASTCALL functions can not call user code, yield, fail, or reallocate stack. // If types of the arguments mismatch, luauF_* needs to return -1 and the execution will fall back to the usual call path @@ -602,7 +600,7 @@ static int luauF_btest(lua_State* L, StkId res, TValue* arg0, int nresults, StkI static int luauF_extract(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) { - if (nparams >= (3 - FFlag::LuauFasterBit32NoWidth) && nresults <= 1 && ttisnumber(arg0) && ttisnumber(args)) + if (nparams >= 2 && nresults <= 1 && ttisnumber(arg0) && ttisnumber(args)) { double a1 = nvalue(arg0); double a2 = nvalue(args); @@ -693,7 +691,7 @@ static int luauF_lshift(lua_State* L, StkId res, TValue* arg0, int nresults, Stk static int luauF_replace(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) { - if (nparams >= (4 - FFlag::LuauFasterBit32NoWidth) && nresults <= 1 && ttisnumber(arg0) && ttisnumber(args) && ttisnumber(args + 1)) + if (nparams >= 3 && nresults <= 1 && ttisnumber(arg0) && ttisnumber(args) && ttisnumber(args + 1)) { double a1 = nvalue(arg0); double a2 = nvalue(args); diff --git a/VM/src/lstrlib.cpp b/VM/src/lstrlib.cpp index b3ea1094f..192ea0b5c 100644 --- a/VM/src/lstrlib.cpp +++ b/VM/src/lstrlib.cpp @@ -8,8 +8,6 @@ #include #include -LUAU_FASTFLAGVARIABLE(LuauTostringFormatSpecifier, false); - // macro to `unsign' a character #define uchar(c) ((unsigned char)(c)) @@ -1036,9 +1034,6 @@ static int str_format(lua_State* L) } case '*': { - if (!FFlag::LuauTostringFormatSpecifier) - luaL_error(L, "invalid option '%%*' to 'format'"); - if (formatItemSize != 1) luaL_error(L, "'%%*' does not take a form"); diff --git a/VM/src/lvmexecute.cpp b/VM/src/lvmexecute.cpp index 7306b055b..aa1da8aee 100644 --- a/VM/src/lvmexecute.cpp +++ b/VM/src/lvmexecute.cpp @@ -3031,7 +3031,16 @@ static void luau_execute(lua_State* L) TValue* kv = VM_KV(aux & 0xffffff); LUAU_ASSERT(ttisnumber(kv)); +#if defined(__aarch64__) + // On several ARM chips (Apple M1/M2, Neoverse N1), comparing the result of a floating-point comparison is expensive, and a branch + // is much cheaper; on some 32-bit ARM chips (Cortex A53) the performance is about the same so we prefer less branchy variant there + if (aux >> 31) + pc += !(ttisnumber(ra) && nvalue(ra) == nvalue(kv)) ? LUAU_INSN_D(insn) : 1; + else + pc += (ttisnumber(ra) && nvalue(ra) == nvalue(kv)) ? LUAU_INSN_D(insn) : 1; +#else pc += int(ttisnumber(ra) && nvalue(ra) == nvalue(kv)) != (aux >> 31) ? LUAU_INSN_D(insn) : 1; +#endif LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); VM_NEXT(); } diff --git a/tests/AssemblyBuilderX64.test.cpp b/tests/AssemblyBuilderX64.test.cpp index 3f75ebac3..08f241ed7 100644 --- a/tests/AssemblyBuilderX64.test.cpp +++ b/tests/AssemblyBuilderX64.test.cpp @@ -4,7 +4,6 @@ #include "doctest.h" -#include #include using namespace Luau::CodeGen; @@ -22,7 +21,7 @@ std::string bytecodeAsArray(const std::vector& bytecode) class AssemblyBuilderX64Fixture { public: - void check(std::function f, std::vector code, std::vector data = {}) + void check(void (*f)(AssemblyBuilderX64& build), std::vector code, std::vector data = {}) { AssemblyBuilderX64 build(/* logText= */ false); diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index 0a3c6507c..d8520a6e0 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -1241,8 +1241,7 @@ TEST_CASE("InterpStringZeroCost") { ScopedFastFlag sff{"LuauInterpolatedStringBaseSupport", true}; - CHECK_EQ( - "\n" + compileFunction0(R"(local _ = `hello, {"world"}!`)"), + CHECK_EQ("\n" + compileFunction0(R"(local _ = `hello, {"world"}!`)"), R"( LOADK R1 K0 LOADK R3 K1 @@ -1250,16 +1249,14 @@ NAMECALL R1 R1 K2 CALL R1 2 1 MOVE R0 R1 RETURN R0 0 -)" - ); +)"); } TEST_CASE("InterpStringRegisterCleanup") { ScopedFastFlag sff{"LuauInterpolatedStringBaseSupport", true}; - CHECK_EQ( - "\n" + compileFunction0(R"( + CHECK_EQ("\n" + compileFunction0(R"( local a, b, c = nil, "um", "uh oh" a = `foo{"bar"}` print(a) @@ -1278,8 +1275,7 @@ GETIMPORT R3 6 MOVE R4 R0 CALL R3 1 0 RETURN R0 0 -)" - ); +)"); } TEST_CASE("ConstantFoldArith") @@ -2488,8 +2484,6 @@ end TEST_CASE("DebugLineInfoRepeatUntil") { - ScopedFastFlag sff("LuauCompileXEQ", true); - CHECK_EQ("\n" + compileFunction0Coverage(R"( local f = 0 repeat @@ -2834,8 +2828,6 @@ RETURN R0 0 TEST_CASE("AssignmentConflict") { - ScopedFastFlag sff("LuauCompileOptimalAssignment", true); - // assignments are left to right CHECK_EQ("\n" + compileFunction0("local a, b a, b = 1, 2"), R"( LOADNIL R0 @@ -3610,8 +3602,6 @@ RETURN R0 1 TEST_CASE("ConstantJumpCompare") { - ScopedFastFlag sff("LuauCompileXEQ", true); - CHECK_EQ("\n" + compileFunction0(R"( local obj = ... local b = obj == 1 @@ -6210,8 +6200,6 @@ L4: RETURN R0 -1 TEST_CASE("BuiltinFoldingMultret") { - ScopedFastFlag sff("LuauCompileXEQ", true); - CHECK_EQ("\n" + compileFunction(R"( local NoLanes: Lanes = --[[ ]] 0b0000000000000000000000000000000 local OffscreenLane: Lane = --[[ ]] 0b1000000000000000000000000000000 @@ -6350,8 +6338,6 @@ RETURN R2 1 TEST_CASE("MultipleAssignments") { - ScopedFastFlag sff("LuauCompileOptimalAssignment", true); - // order of assignments is left to right CHECK_EQ("\n" + compileFunction0(R"( local a, b @@ -6574,15 +6560,14 @@ RETURN R0 0 TEST_CASE("BuiltinExtractK") { - ScopedFastFlag sff("LuauCompileExtractK", true); - // below, K0 refers to a packed f+w constant for bit32.extractk builtin // K1 and K2 refer to 1 and 3 and are only used during fallback path CHECK_EQ("\n" + compileFunction0(R"( local v = ... return bit32.extract(v, 1, 3) -)"), R"( +)"), + R"( GETVARARGS R0 1 FASTCALL2K 59 R0 K0 L0 MOVE R2 R0 diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index f6f5b41f5..a7ffb493d 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -289,15 +289,12 @@ TEST_CASE("Clear") TEST_CASE("Strings") { - ScopedFastFlag sff{"LuauTostringFormatSpecifier", true}; - runConformance("strings.lua"); } TEST_CASE("StringInterp") { ScopedFastFlag sffInterpStrings{"LuauInterpolatedStringBaseSupport", true}; - ScopedFastFlag sffTostringFormat{"LuauTostringFormatSpecifier", true}; runConformance("stringinterp.lua"); } @@ -725,13 +722,16 @@ TEST_CASE("NewUserdataOverflow") StateRef globalState(luaL_newstate(), lua_close); lua_State* L = globalState.get(); - lua_pushcfunction(L, [](lua_State* L1) { - // The following userdata request might cause an overflow. - lua_newuserdatadtor(L1, SIZE_MAX, [](void* d){}); - // The overflow might segfault in the following call. - lua_getmetatable(L1, -1); - return 0; - }, nullptr); + lua_pushcfunction( + L, + [](lua_State* L1) { + // The following userdata request might cause an overflow. + lua_newuserdatadtor(L1, SIZE_MAX, [](void* d) {}); + // The overflow might segfault in the following call. + lua_getmetatable(L1, -1); + return 0; + }, + nullptr); CHECK(lua_pcall(L, 0, 0, 0) == LUA_ERRRUN); CHECK(strcmp(lua_tostring(L, -1), "memory allocation error: block too big") == 0); diff --git a/tests/ConstraintGraphBuilder.test.cpp b/tests/ConstraintGraphBuilder.test.cpp index 00c3309ca..bbe294290 100644 --- a/tests/ConstraintGraphBuilder.test.cpp +++ b/tests/ConstraintGraphBuilder.test.cpp @@ -57,13 +57,12 @@ TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "nil_primitive") auto constraints = collectConstraints(NotNull(cgb.rootScope)); ToStringOptions opts; - REQUIRE(5 <= constraints.size()); + REQUIRE(4 <= constraints.size()); CHECK("*blocked-1* ~ gen () -> (a...)" == toString(*constraints[0], opts)); - CHECK("*blocked-2* ~ inst *blocked-1*" == toString(*constraints[1], opts)); - CHECK("() -> (b...) <: *blocked-2*" == toString(*constraints[2], opts)); - CHECK("b... <: c" == toString(*constraints[3], opts)); - CHECK("nil <: a..." == toString(*constraints[4], opts)); + CHECK("call *blocked-1* with { result = *blocked-tp-1* }" == toString(*constraints[1], opts)); + CHECK("*blocked-tp-1* <: b" == toString(*constraints[2], opts)); + CHECK("nil <: a..." == toString(*constraints[3], opts)); } TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "function_application") @@ -76,13 +75,12 @@ TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "function_application") cgb.visit(block); auto constraints = collectConstraints(NotNull(cgb.rootScope)); - REQUIRE(4 == constraints.size()); + REQUIRE(3 == constraints.size()); ToStringOptions opts; CHECK("string <: a" == toString(*constraints[0], opts)); - CHECK("*blocked-1* ~ inst a" == toString(*constraints[1], opts)); - CHECK("(string) -> (b...) <: *blocked-1*" == toString(*constraints[2], opts)); - CHECK("b... <: c" == toString(*constraints[3], opts)); + CHECK("call a with { result = *blocked-tp-1* }" == toString(*constraints[1], opts)); + CHECK("*blocked-tp-1* <: b" == toString(*constraints[2], opts)); } TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "local_function_definition") @@ -114,13 +112,12 @@ TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "recursive_function") cgb.visit(block); auto constraints = collectConstraints(NotNull(cgb.rootScope)); - REQUIRE(4 == constraints.size()); + REQUIRE(3 == constraints.size()); ToStringOptions opts; CHECK("*blocked-1* ~ gen (a) -> (b...)" == toString(*constraints[0], opts)); - CHECK("*blocked-2* ~ inst (a) -> (b...)" == toString(*constraints[1], opts)); - CHECK("(a) -> (c...) <: *blocked-2*" == toString(*constraints[2], opts)); - CHECK("c... <: b..." == toString(*constraints[3], opts)); + CHECK("call (a) -> (b...) with { result = *blocked-tp-1* }" == toString(*constraints[1], opts)); + CHECK("*blocked-tp-1* <: b..." == toString(*constraints[2], opts)); } TEST_SUITE_END(); diff --git a/tests/ConstraintSolver.test.cpp b/tests/ConstraintSolver.test.cpp index e33f6570e..2c4897330 100644 --- a/tests/ConstraintSolver.test.cpp +++ b/tests/ConstraintSolver.test.cpp @@ -28,7 +28,8 @@ TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "hello") cgb.visit(block); NotNull rootScope = NotNull(cgb.rootScope); - ConstraintSolver cs{&arena, rootScope}; + NullModuleResolver resolver; + ConstraintSolver cs{&arena, rootScope, "MainModule", NotNull(&resolver), {}}; cs.run(); @@ -48,7 +49,8 @@ TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "generic_function") cgb.visit(block); NotNull rootScope = NotNull(cgb.rootScope); - ConstraintSolver cs{&arena, rootScope}; + NullModuleResolver resolver; + ConstraintSolver cs{&arena, rootScope, "MainModule", NotNull(&resolver), {}}; cs.run(); @@ -57,7 +59,6 @@ TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "generic_function") CHECK("(a) -> a" == toString(idType)); } -#if 1 TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "proper_let_generalization") { AstStatBlock* block = parse(R"( @@ -77,7 +78,8 @@ TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "proper_let_generalization") ToStringOptions opts; - ConstraintSolver cs{&arena, rootScope}; + NullModuleResolver resolver; + ConstraintSolver cs{&arena, rootScope, "MainModule", NotNull(&resolver), {}}; cs.run(); @@ -85,6 +87,5 @@ TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "proper_let_generalization") CHECK("(a) -> number" == toString(idType, opts)); } -#endif TEST_SUITE_END(); diff --git a/tests/Fixture.cpp b/tests/Fixture.cpp index 4051f8512..476b7a2a5 100644 --- a/tests/Fixture.cpp +++ b/tests/Fixture.cpp @@ -2,6 +2,8 @@ #include "Fixture.h" #include "Luau/AstQuery.h" +#include "Luau/ModuleResolver.h" +#include "Luau/NotNull.h" #include "Luau/Parser.h" #include "Luau/TypeVar.h" #include "Luau/TypeAttach.h" @@ -444,10 +446,11 @@ BuiltinsFixture::BuiltinsFixture(bool freeze, bool prepareAutocomplete) ConstraintGraphBuilderFixture::ConstraintGraphBuilderFixture() : Fixture() , mainModule(new Module) - , cgb(mainModuleName, mainModule, &arena, NotNull(&ice), frontend.getGlobalScope()) + , cgb(mainModuleName, mainModule, &arena, NotNull(&moduleResolver), NotNull(&ice), frontend.getGlobalScope()) , forceTheFlag{"DebugLuauDeferredConstraintResolution", true} { BlockedTypeVar::nextIndex = 0; + BlockedTypePack::nextIndex = 0; } ModuleName fromString(std::string_view name) diff --git a/tests/Fixture.h b/tests/Fixture.h index e82ebf000..8923b2085 100644 --- a/tests/Fixture.h +++ b/tests/Fixture.h @@ -132,6 +132,7 @@ struct Fixture TestFileResolver fileResolver; TestConfigResolver configResolver; + NullModuleResolver moduleResolver; std::unique_ptr sourceModule; Frontend frontend; InternalErrorReporter ice; @@ -213,7 +214,7 @@ Nth nth(int nth = 1) struct FindNthOccurenceOf : public AstVisitor { Nth requestedNth; - size_t currentOccurrence = 0; + int currentOccurrence = 0; AstNode* theNode = nullptr; FindNthOccurenceOf(Nth nth); @@ -244,7 +245,7 @@ struct FindNthOccurenceOf : public AstVisitor * 2. Luau::query(Luau::query(block)) * 3. Luau::query(block, {nth(2)}) */ -template +template T* query(AstNode* node, const std::vector& nths = {nth(N)}) { static_assert(std::is_base_of_v, "T must be a derived class of AstNode"); diff --git a/tests/JsonEmitter.test.cpp b/tests/JsonEmitter.test.cpp index ebe832093..ff9a59552 100644 --- a/tests/JsonEmitter.test.cpp +++ b/tests/JsonEmitter.test.cpp @@ -94,7 +94,7 @@ TEST_CASE("write_optional") write(emitter, std::optional{true}); emitter.writeComma(); write(emitter, std::nullopt); - + CHECK(emitter.str() == "true,null"); } diff --git a/tests/Linter.test.cpp b/tests/Linter.test.cpp index a7d09e8f9..b560c89e3 100644 --- a/tests/Linter.test.cpp +++ b/tests/Linter.test.cpp @@ -1181,7 +1181,7 @@ s:match("[]") nons:match("[]") )~"); - CHECK_EQ(result.warnings.size(), 2); + REQUIRE_EQ(result.warnings.size(), 2); CHECK_EQ(result.warnings[0].text, "Invalid match pattern: expected ] at the end of the string to close a set"); CHECK_EQ(result.warnings[0].location.begin.line, 3); CHECK_EQ(result.warnings[1].text, "Invalid match pattern: expected ] at the end of the string to close a set"); @@ -1746,6 +1746,7 @@ local _ = not a == b local _ = not a ~= b local _ = not a <= b local _ = a <= b == 0 +local _ = a <= b <= 0 local _ = not a == not b -- weird but ok @@ -1760,11 +1761,12 @@ local _ = (a <= b) == 0 local _ = a <= (b == 0) )"); - REQUIRE_EQ(result.warnings.size(), 4); - CHECK_EQ(result.warnings[0].text, "not X == Y is equivalent to (not X) == Y; consider using X ~= Y, or wrap one of the expressions in parentheses to silence"); - CHECK_EQ(result.warnings[1].text, "not X ~= Y is equivalent to (not X) ~= Y; consider using X == Y, or wrap one of the expressions in parentheses to silence"); - CHECK_EQ(result.warnings[2].text, "not X <= Y is equivalent to (not X) <= Y; wrap one of the expressions in parentheses to silence"); - CHECK_EQ(result.warnings[3].text, "X <= Y == Z is equivalent to (X <= Y) == Z; wrap one of the expressions in parentheses to silence"); + REQUIRE_EQ(result.warnings.size(), 5); + CHECK_EQ(result.warnings[0].text, "not X == Y is equivalent to (not X) == Y; consider using X ~= Y, or add parentheses to silence"); + CHECK_EQ(result.warnings[1].text, "not X ~= Y is equivalent to (not X) ~= Y; consider using X == Y, or add parentheses to silence"); + CHECK_EQ(result.warnings[2].text, "not X <= Y is equivalent to (not X) <= Y; add parentheses to silence"); + CHECK_EQ(result.warnings[3].text, "X <= Y == Z is equivalent to (X <= Y) == Z; add parentheses to silence"); + CHECK_EQ(result.warnings[4].text, "X <= Y <= Z is equivalent to (X <= Y) <= Z; did you mean X <= Y and Y <= Z?"); } TEST_SUITE_END(); diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index 2dd477080..44a6b4acf 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -943,8 +943,7 @@ TEST_CASE_FIXTURE(Fixture, "parse_interpolated_string_without_end_brace") { ScopedFastFlag sff{"LuauInterpolatedStringBaseSupport", true}; - auto columnOfEndBraceError = [this](const char* code) - { + auto columnOfEndBraceError = [this](const char* code) { try { parse(code); @@ -1737,6 +1736,48 @@ TEST_CASE_FIXTURE(Fixture, "parse_error_type_annotation") matchParseError("local a : 2 = 2", "Expected type, got '2'"); } +TEST_CASE_FIXTURE(Fixture, "parse_error_missing_type_annotation") +{ + ScopedFastFlag LuauTypeAnnotationLocationChange{"LuauTypeAnnotationLocationChange", true}; + + { + ParseResult result = tryParse("local x:"); + CHECK(result.errors.size() == 1); + Position begin = result.errors[0].getLocation().begin; + Position end = result.errors[0].getLocation().end; + CHECK(begin.line == end.line); + int width = end.column - begin.column; + CHECK(width == 0); + CHECK(result.errors[0].getMessage() == "Expected type, got "); + } + + { + ParseResult result = tryParse(R"( +local x:=42 + )"); + CHECK(result.errors.size() == 1); + Position begin = result.errors[0].getLocation().begin; + Position end = result.errors[0].getLocation().end; + CHECK(begin.line == end.line); + int width = end.column - begin.column; + CHECK(width == 1); // Length of `=` + CHECK(result.errors[0].getMessage() == "Expected type, got '='"); + } + + { + ParseResult result = tryParse(R"( +function func():end + )"); + CHECK(result.errors.size() == 1); + Position begin = result.errors[0].getLocation().begin; + Position end = result.errors[0].getLocation().end; + CHECK(begin.line == end.line); + int width = end.column - begin.column; + CHECK(width == 3); // Length of `end` + CHECK(result.errors[0].getMessage() == "Expected type, got 'end'"); + } +} + TEST_CASE_FIXTURE(Fixture, "parse_declarations") { AstStatBlock* stat = parseEx(R"( diff --git a/tests/TypeInfer.anyerror.test.cpp b/tests/TypeInfer.anyerror.test.cpp index f47661040..8c6f2e4fa 100644 --- a/tests/TypeInfer.anyerror.test.cpp +++ b/tests/TypeInfer.anyerror.test.cpp @@ -346,4 +346,17 @@ TEST_CASE_FIXTURE(Fixture, "prop_access_on_any_with_other_options") LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(BuiltinsFixture, "union_of_types_regression_test") +{ + ScopedFastFlag LuauUnionOfTypesFollow{"LuauUnionOfTypesFollow", true}; + + CheckResult result = check(R"( +--!strict +local stat +stat = stat and tonumber(stat) or stat + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.builtins.test.cpp b/tests/TypeInfer.builtins.test.cpp index 12fb4aa2c..07a04363f 100644 --- a/tests/TypeInfer.builtins.test.cpp +++ b/tests/TypeInfer.builtins.test.cpp @@ -1061,7 +1061,6 @@ end TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types") { - ScopedFastFlag sffs{"LuauDeduceGmatchReturnTypes", true}; CheckResult result = check(R"END( local a, b, c = string.gmatch("This is a string", "(.()(%a+))")() )END"); @@ -1075,7 +1074,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types") TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types2") { - ScopedFastFlag sffs{"LuauDeduceGmatchReturnTypes", true}; CheckResult result = check(R"END( local a, b, c = ("This is a string"):gmatch("(.()(%a+))")() )END"); @@ -1089,7 +1087,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types2") TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types_default_capture") { - ScopedFastFlag sffs{"LuauDeduceGmatchReturnTypes", true}; CheckResult result = check(R"END( local a, b, c, d = string.gmatch("T(his)() is a string", ".")() )END"); @@ -1107,7 +1104,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types_default_capture") TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types_balanced_escaped_parens") { - ScopedFastFlag sffs{"LuauDeduceGmatchReturnTypes", true}; CheckResult result = check(R"END( local a, b, c, d = string.gmatch("T(his) is a string", "((.)%b()())")() )END"); @@ -1127,7 +1123,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types_balanced_escaped_parens TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types_parens_in_sets_are_ignored") { - ScopedFastFlag sffs{"LuauDeduceGmatchReturnTypes", true}; CheckResult result = check(R"END( local a, b, c = string.gmatch("T(his)() is a string", "(T[()])()")() )END"); @@ -1146,7 +1141,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types_parens_in_sets_are_igno TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types_set_containing_lbracket") { - ScopedFastFlag sffs{"LuauDeduceGmatchReturnTypes", true}; CheckResult result = check(R"END( local a, b = string.gmatch("[[[", "()([[])")() )END"); @@ -1196,7 +1190,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types_invalid_pattern_fallbac TEST_CASE_FIXTURE(BuiltinsFixture, "match_capture_types") { - ScopedFastFlag sffs{"LuauDeduceFindMatchReturnTypes", true}; CheckResult result = check(R"END( local a, b, c = string.match("This is a string", "(.()(%a+))") )END"); @@ -1210,7 +1203,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "match_capture_types") TEST_CASE_FIXTURE(BuiltinsFixture, "match_capture_types2") { - ScopedFastFlag sffs{"LuauDeduceFindMatchReturnTypes", true}; CheckResult result = check(R"END( local a, b, c = string.match("This is a string", "(.()(%a+))", "this should be a number") )END"); @@ -1229,7 +1221,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "match_capture_types2") TEST_CASE_FIXTURE(BuiltinsFixture, "find_capture_types") { - ScopedFastFlag sffs{"LuauDeduceFindMatchReturnTypes", true}; CheckResult result = check(R"END( local d, e, a, b, c = string.find("This is a string", "(.()(%a+))") )END"); @@ -1245,7 +1236,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "find_capture_types") TEST_CASE_FIXTURE(BuiltinsFixture, "find_capture_types2") { - ScopedFastFlag sffs{"LuauDeduceFindMatchReturnTypes", true}; CheckResult result = check(R"END( local d, e, a, b, c = string.find("This is a string", "(.()(%a+))", "this should be a number") )END"); @@ -1266,7 +1256,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "find_capture_types2") TEST_CASE_FIXTURE(BuiltinsFixture, "find_capture_types3") { - ScopedFastFlag sffs{"LuauDeduceFindMatchReturnTypes", true}; CheckResult result = check(R"END( local d, e, a, b, c = string.find("This is a string", "(.()(%a+))", 1, "this should be a bool") )END"); @@ -1287,7 +1276,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "find_capture_types3") TEST_CASE_FIXTURE(BuiltinsFixture, "find_capture_types3") { - ScopedFastFlag sffs{"LuauDeduceFindMatchReturnTypes", true}; CheckResult result = check(R"END( local d, e, a, b = string.find("This is a string", "(.()(%a+))", 1, true) )END"); diff --git a/tests/TypeInfer.definitions.test.cpp b/tests/TypeInfer.definitions.test.cpp index fd6fb83f1..9fe0c6aaf 100644 --- a/tests/TypeInfer.definitions.test.cpp +++ b/tests/TypeInfer.definitions.test.cpp @@ -336,4 +336,30 @@ local s : Cls = GetCls() LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "class_definition_overload_metamethods") +{ + loadDefinition(R"( + declare class Vector3 + end + + declare class CFrame + function __mul(self, other: CFrame): CFrame + function __mul(self, other: Vector3): Vector3 + end + + declare function newVector3(): Vector3 + declare function newCFrame(): CFrame + )"); + + CheckResult result = check(R"( + local base = newCFrame() + local shouldBeCFrame = base * newCFrame() + local shouldBeVector = base * newVector3() + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ(toString(requireType("shouldBeCFrame")), "CFrame"); + CHECK_EQ(toString(requireType("shouldBeVector")), "Vector3"); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index 5cc759d6b..3a6f44911 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -133,6 +133,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "vararg_function_is_quantified") TableTypeVar* ttv = getMutable(*r); REQUIRE(ttv); + REQUIRE(ttv->props.count("f")); TypeId k = ttv->props["f"].type; REQUIRE(k); diff --git a/tests/TypeInfer.loops.test.cpp b/tests/TypeInfer.loops.test.cpp index 9b10092cf..85249ecdc 100644 --- a/tests/TypeInfer.loops.test.cpp +++ b/tests/TypeInfer.loops.test.cpp @@ -14,6 +14,7 @@ using namespace Luau; LUAU_FASTFLAG(LuauSpecialTypesAsterisked) +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) TEST_SUITE_BEGIN("TypeInferLoops"); @@ -109,6 +110,18 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_with_just_one_iterator_is_ok") LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_loop_with_zero_iterators_dcr") +{ + ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true}; + + CheckResult result = check(R"( + function no_iter() end + for key in no_iter() do end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_with_a_custom_iterator_should_type_check") { CheckResult result = check(R"( @@ -141,7 +154,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_on_error") end )"); - CHECK_EQ(2, result.errors.size()); + LUAU_REQUIRE_ERROR_COUNT(2, result); TypeId p = requireType("p"); if (FFlag::LuauSpecialTypesAsterisked) @@ -232,6 +245,30 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_loop_error_on_iterator_requiring_args CHECK_EQ(0, acm->actual); } +TEST_CASE_FIXTURE(Fixture, "for_in_loop_with_incompatible_args_to_iterator") +{ + CheckResult result = check(R"( + function my_iter(state: string, index: number) + return state, index + end + + local my_state = {} + local first_index = "first" + + -- Type errors here. my_state and first_index cannot be passed to my_iter + for a, b in my_iter, my_state, first_index do + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + + CHECK(get(result.errors[1])); + CHECK(Location{{9, 29}, {9, 37}} == result.errors[0].location); + + CHECK(get(result.errors[1])); + CHECK(Location{{9, 39}, {9, 50}} == result.errors[1].location); +} + TEST_CASE_FIXTURE(Fixture, "for_in_loop_with_custom_iterator") { CheckResult result = check(R"( @@ -503,7 +540,7 @@ TEST_CASE_FIXTURE(Fixture, "loop_iter_basic") end )"); - LUAU_REQUIRE_ERROR_COUNT(0, result); + LUAU_REQUIRE_NO_ERRORS(result); CHECK_EQ(*typeChecker.numberType, *requireType("key")); } diff --git a/tests/TypeInfer.modules.test.cpp b/tests/TypeInfer.modules.test.cpp index b5f2296c6..ede84f4a5 100644 --- a/tests/TypeInfer.modules.test.cpp +++ b/tests/TypeInfer.modules.test.cpp @@ -16,6 +16,35 @@ LUAU_FASTFLAG(LuauSpecialTypesAsterisked) TEST_SUITE_BEGIN("TypeInferModules"); +TEST_CASE_FIXTURE(BuiltinsFixture, "dcr_require_basic") +{ + fileResolver.source["game/A"] = R"( + --!strict + return { + a = 1, + } + )"; + + fileResolver.source["game/B"] = R"( + --!strict + local A = require(game.A) + + local b = A.a + )"; + + CheckResult aResult = frontend.check("game/A"); + LUAU_REQUIRE_NO_ERRORS(aResult); + + CheckResult bResult = frontend.check("game/B"); + LUAU_REQUIRE_NO_ERRORS(bResult); + + ModulePtr b = frontend.moduleResolver.modules["game/B"]; + REQUIRE(b != nullptr); + std::optional bType = requireType(b, "b"); + REQUIRE(bType); + CHECK(toString(*bType) == "number"); +} + TEST_CASE_FIXTURE(BuiltinsFixture, "require") { fileResolver.source["game/A"] = R"( diff --git a/tests/TypeInfer.primitives.test.cpp b/tests/TypeInfer.primitives.test.cpp index a06fd7491..5a6fb0e49 100644 --- a/tests/TypeInfer.primitives.test.cpp +++ b/tests/TypeInfer.primitives.test.cpp @@ -11,7 +11,6 @@ #include "doctest.h" -LUAU_FASTFLAG(LuauDeduceFindMatchReturnTypes) LUAU_FASTFLAG(LuauSpecialTypesAsterisked) using namespace Luau; @@ -61,8 +60,8 @@ TEST_CASE_FIXTURE(Fixture, "string_method") CheckResult result = check(R"( local p = ("tacos"):len() )"); - CHECK_EQ(0, result.errors.size()); + LUAU_REQUIRE_NO_ERRORS(result); CHECK_EQ(*requireType("p"), *typeChecker.numberType); } @@ -73,8 +72,8 @@ TEST_CASE_FIXTURE(Fixture, "string_function_indirect") local l = s.lower local p = l(s) )"); - CHECK_EQ(0, result.errors.size()); + LUAU_REQUIRE_NO_ERRORS(result); CHECK_EQ(*requireType("p"), *typeChecker.stringType); } @@ -84,12 +83,9 @@ TEST_CASE_FIXTURE(Fixture, "string_function_other") local s:string local p = s:match("foo") )"); - CHECK_EQ(0, result.errors.size()); - if (FFlag::LuauDeduceFindMatchReturnTypes) - CHECK_EQ(toString(requireType("p")), "string"); - else - CHECK_EQ(toString(requireType("p")), "string?"); + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ(toString(requireType("p")), "string"); } TEST_CASE_FIXTURE(Fixture, "CheckMethodsOfNumber") diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index c8fc7f2da..472d0ed55 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -62,7 +62,13 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "xpcall_returns_what_f_returns") local a:boolean,b:number,c:string=xpcall(function(): (number,string)return 1,'foo'end,function(): (string,number)return'foo',1 end) )"; - CHECK_EQ(expected, decorateWithTypes(code)); + CheckResult result = check(code); + + CHECK("boolean" == toString(requireType("a"))); + CHECK("number" == toString(requireType("b"))); + CHECK("string" == toString(requireType("c"))); + + CHECK(expected == decorateWithTypes(code)); } // We had a bug where if you have two type packs that looks like: @@ -609,4 +615,16 @@ TEST_CASE_FIXTURE(Fixture, "free_options_cannot_be_unified_together") CHECK("b?" == toString(option2, opts)); // This should not hold. } +TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_loop_with_zero_iterators") +{ + ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", false}; + + CheckResult result = check(R"( + function no_iter() end + for key in no_iter() do end -- This should not be ok + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index 95b85c488..97c3da4f6 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -2994,8 +2994,6 @@ TEST_CASE_FIXTURE(Fixture, "expected_indexer_value_type_extra_2") TEST_CASE_FIXTURE(Fixture, "expected_indexer_from_table_union") { - ScopedFastFlag luauExpectedTableUnionIndexerType{"LuauExpectedTableUnionIndexerType", true}; - LUAU_REQUIRE_NO_ERRORS(check(R"(local a: {[string]: {number | string}} = {a = {2, 's'}})")); LUAU_REQUIRE_NO_ERRORS(check(R"(local a: {[string]: {number | string}}? = {a = {2, 's'}})")); LUAU_REQUIRE_NO_ERRORS(check(R"(local a: {[string]: {[string]: {string?}}?} = {["a"] = {["b"] = {"a", "b"}}})")); diff --git a/tests/conformance/basic.lua b/tests/conformance/basic.lua index d92747513..9c19da59a 100644 --- a/tests/conformance/basic.lua +++ b/tests/conformance/basic.lua @@ -111,6 +111,11 @@ assert((function() local a = nil a = a and 2 return a end)() == nil) assert((function() local a = 1 a = a or 2 return a end)() == 1) assert((function() local a = nil a = a or 2 return a end)() == 2) +assert((function() local a a = 1 local b = 2 b = a and b return b end)() == 2) +assert((function() local a a = nil local b = 2 b = a and b return b end)() == nil) +assert((function() local a a = 1 local b = 2 b = a or b return b end)() == 1) +assert((function() local a a = nil local b = 2 b = a or b return b end)() == 2) + -- binary arithmetics coerces strings to numbers (sadly) assert(1 + "2" == 3) assert(2 * "0xa" == 20) diff --git a/tests/conformance/errors.lua b/tests/conformance/errors.lua index b69d437bb..529e9b0ca 100644 --- a/tests/conformance/errors.lua +++ b/tests/conformance/errors.lua @@ -369,21 +369,31 @@ assert(not a and b:match('[^ ]+') == "short:1:") local a,b = loadstring("nope", "=" .. string.rep("thisisaverylongstringitssolongthatitwontfitintotheinternalbufferprovidedtovariousdebugfacilities", 10)) assert(not a and b:match('[^ ]+') == "thisisaverylongstringitssolongthatitwontfitintotheinternalbufferprovidedtovariousdebugfacilitiesthisisaverylongstringitssolongthatitwontfitintotheinternalbufferprovidedtovariousdebugfacilitiesthisisaverylongstringitssolongthatitwontfitintotheinternalbuffe:1:") --- arith errors function ecall(fn, ...) local ok, err = pcall(fn, ...) assert(not ok) - return err:sub(err:find(": ") + 2, #err) + return err:sub((err:find(": ") or -1) + 2, #err) end +-- arith errors assert(ecall(function() return nil + 5 end) == "attempt to perform arithmetic (add) on nil and number") assert(ecall(function() return "a" + "b" end) == "attempt to perform arithmetic (add) on string") assert(ecall(function() return 1 > nil end) == "attempt to compare nil < number") -- note reversed order (by design) assert(ecall(function() return "a" <= 5 end) == "attempt to compare string <= number") +-- table errors assert(ecall(function() local t = {} t[nil] = 2 end) == "table index is nil") assert(ecall(function() local t = {} t[0/0] = 2 end) == "table index is NaN") +assert(ecall(function() local t = {} rawset(t, nil, 2) end) == "table index is nil") +assert(ecall(function() local t = {} rawset(t, 0/0, 2) end) == "table index is NaN") + +assert(ecall(function() local t = {} t[nil] = nil end) == "table index is nil") +assert(ecall(function() local t = {} t[0/0] = nil end) == "table index is NaN") + +assert(ecall(function() local t = {} rawset(t, nil, nil) end) == "table index is nil") +assert(ecall(function() local t = {} rawset(t, 0/0, nil) end) == "table index is NaN") + -- for loop type errors assert(ecall(function() for i='a',2 do end end) == "invalid 'for' initial value (number expected, got string)") assert(ecall(function() for i=1,'a' do end end) == "invalid 'for' limit (number expected, got string)") diff --git a/tools/faillist.txt b/tools/faillist.txt index 54e7ac059..3de9db769 100644 --- a/tools/faillist.txt +++ b/tools/faillist.txt @@ -1,10 +1,8 @@ AnnotationTests.builtin_types_are_not_exported -AnnotationTests.cannot_use_nonexported_type -AnnotationTests.cloned_interface_maintains_pointers_between_definitions AnnotationTests.duplicate_type_param_name AnnotationTests.for_loop_counter_annotation_is_checked AnnotationTests.generic_aliases_are_cloned_properly -AnnotationTests.interface_types_belong_to_interface_arena +AnnotationTests.instantiation_clone_has_to_follow AnnotationTests.luau_ice_triggers_an_ice AnnotationTests.luau_ice_triggers_an_ice_exception_with_flag AnnotationTests.luau_ice_triggers_an_ice_exception_with_flag_handler @@ -26,6 +24,7 @@ AutocompleteTest.autocomplete_first_function_arg_expected_type AutocompleteTest.autocomplete_for_in_middle_keywords AutocompleteTest.autocomplete_for_middle_keywords AutocompleteTest.autocomplete_if_middle_keywords +AutocompleteTest.autocomplete_interpolated_string AutocompleteTest.autocomplete_on_string_singletons AutocompleteTest.autocomplete_oop_implicit_self AutocompleteTest.autocomplete_repeat_middle_keyword @@ -56,14 +55,12 @@ AutocompleteTest.global_functions_are_not_scoped_lexically AutocompleteTest.globals_are_order_independent AutocompleteTest.if_then_else_elseif_completions AutocompleteTest.keyword_methods -AutocompleteTest.keyword_types AutocompleteTest.library_non_self_calls_are_fine AutocompleteTest.library_self_calls_are_invalid AutocompleteTest.local_function AutocompleteTest.local_function_params AutocompleteTest.local_functions_fall_out_of_scope AutocompleteTest.method_call_inside_function_body -AutocompleteTest.module_type_members AutocompleteTest.nested_member_completions AutocompleteTest.nested_recursive_function AutocompleteTest.no_function_name_suggestions @@ -78,7 +75,6 @@ AutocompleteTest.return_types AutocompleteTest.sometimes_the_metatable_is_an_error AutocompleteTest.source_module_preservation_and_invalidation AutocompleteTest.statement_between_two_statements -AutocompleteTest.stop_at_first_stat_when_recommending_keywords AutocompleteTest.string_prim_non_self_calls_are_avoided AutocompleteTest.string_prim_self_calls_are_fine AutocompleteTest.suggest_external_module_type @@ -155,6 +151,7 @@ BuiltinTests.string_format_tostring_specifier BuiltinTests.string_format_tostring_specifier_type_constraint BuiltinTests.string_format_use_correct_argument BuiltinTests.string_format_use_correct_argument2 +BuiltinTests.string_format_use_correct_argument3 BuiltinTests.string_lib_self_noself BuiltinTests.table_concat_returns_string BuiltinTests.table_dot_remove_optionally_returns_generic @@ -168,31 +165,21 @@ BuiltinTests.tonumber_returns_optional_number_type BuiltinTests.tonumber_returns_optional_number_type2 DefinitionTests.declaring_generic_functions DefinitionTests.definition_file_classes -DefinitionTests.definition_file_loading -DefinitionTests.definitions_documentation_symbols -DefinitionTests.documentation_symbols_dont_attach_to_persistent_types -DefinitionTests.single_class_type_identity_in_global_types FrontendTest.ast_node_at_position FrontendTest.automatically_check_dependent_scripts -FrontendTest.check_without_builtin_next FrontendTest.dont_reparse_clean_file_when_linting FrontendTest.environments FrontendTest.imported_table_modification_2 FrontendTest.it_should_be_safe_to_stringify_errors_when_full_type_graph_is_discarded -FrontendTest.no_use_after_free_with_type_fun_instantiation FrontendTest.nocheck_cycle_used_by_checked -FrontendTest.nocheck_modules_are_typed FrontendTest.produce_errors_for_unchanged_file_with_a_syntax_error FrontendTest.recheck_if_dependent_script_is_dirty FrontendTest.reexport_cyclic_type -FrontendTest.reexport_type_alias -FrontendTest.report_require_to_nonexistent_file FrontendTest.report_syntax_error_in_required_file FrontendTest.trace_requires_in_nonstrict_mode GenericsTests.apply_type_function_nested_generics1 GenericsTests.apply_type_function_nested_generics2 GenericsTests.better_mismatch_error_messages -GenericsTests.bound_tables_do_not_clone_original_fields GenericsTests.calling_self_generic_methods GenericsTests.check_generic_typepack_function GenericsTests.check_mutual_generic_functions @@ -208,7 +195,6 @@ GenericsTests.factories_of_generics GenericsTests.generic_argument_count_too_few GenericsTests.generic_argument_count_too_many GenericsTests.generic_factories -GenericsTests.generic_functions_dont_cache_type_parameters GenericsTests.generic_functions_in_types GenericsTests.generic_functions_should_be_memory_safe GenericsTests.generic_table_method @@ -245,7 +231,6 @@ IntersectionTypes.no_stack_overflow_from_flattenintersection IntersectionTypes.overload_is_not_a_function IntersectionTypes.select_correct_union_fn IntersectionTypes.should_still_pick_an_overload_whose_arguments_are_unions -IntersectionTypes.table_intersection_write IntersectionTypes.table_intersection_write_sealed IntersectionTypes.table_intersection_write_sealed_indirect IntersectionTypes.table_write_sealed_indirect @@ -255,7 +240,6 @@ Linter.TableOperations ModuleTests.clone_self_property ModuleTests.deepClone_cyclic_table ModuleTests.do_not_clone_reexports -ModuleTests.do_not_clone_types_of_reexported_values NonstrictModeTests.delay_function_does_not_require_its_argument_to_return_anything NonstrictModeTests.for_in_iterator_variables_are_any NonstrictModeTests.function_parameters_are_any @@ -319,6 +303,7 @@ ProvisionalTests.typeguard_inference_incomplete ProvisionalTests.weird_fail_to_unify_type_pack ProvisionalTests.weirditer_should_not_loop_forever ProvisionalTests.while_body_are_also_refined +ProvisionalTests.xpcall_returns_what_f_returns RefinementTest.and_constraint RefinementTest.and_or_peephole_refinement RefinementTest.apply_refinements_on_astexprindexexpr_whose_subscript_expr_is_constant_string @@ -358,6 +343,7 @@ RefinementTest.not_and_constraint RefinementTest.not_t_or_some_prop_of_t RefinementTest.or_predicate_with_truthy_predicates RefinementTest.parenthesized_expressions_are_followed_through +RefinementTest.refine_a_property_not_to_be_nil_through_an_intersection_table RefinementTest.refine_the_correct_types_opposite_of_when_a_is_not_number_or_string RefinementTest.refine_unknowns RefinementTest.string_not_equal_to_string_or_nil @@ -394,7 +380,6 @@ TableTests.augment_table TableTests.builtin_table_names TableTests.call_method TableTests.cannot_augment_sealed_table -TableTests.cannot_call_tables TableTests.cannot_change_type_of_unsealed_table_prop TableTests.casting_sealed_tables_with_props_into_table_with_indexer TableTests.casting_tables_with_props_into_table_with_indexer3 @@ -409,7 +394,6 @@ TableTests.defining_a_self_method_for_a_builtin_sealed_table_must_fail TableTests.defining_a_self_method_for_a_local_sealed_table_must_fail TableTests.dont_crash_when_setmetatable_does_not_produce_a_metatabletypevar TableTests.dont_hang_when_trying_to_look_up_in_cyclic_metatable_index -TableTests.dont_invalidate_the_properties_iterator_of_free_table_when_rolled_back TableTests.dont_leak_free_table_props TableTests.dont_quantify_table_that_belongs_to_outer_scope TableTests.dont_suggest_exact_match_keys @@ -448,6 +432,7 @@ TableTests.length_operator_intersection TableTests.length_operator_non_table_union TableTests.length_operator_union TableTests.length_operator_union_errors +TableTests.less_exponential_blowup_please TableTests.meta_add TableTests.meta_add_both_ways TableTests.meta_add_inferred @@ -470,7 +455,6 @@ TableTests.quantify_even_that_table_was_never_exported_at_all TableTests.quantify_metatables_of_metatables_of_table TableTests.quantifying_a_bound_var_works TableTests.reasonable_error_when_adding_a_nonexistent_property_to_an_array_like_table -TableTests.recursive_metatable_type_call TableTests.result_is_always_any_if_lhs_is_any TableTests.result_is_bool_for_equality_operators_if_lhs_is_any TableTests.right_table_missing_key @@ -480,7 +464,6 @@ TableTests.scalar_is_not_a_subtype_of_a_compatible_polymorphic_shape_type TableTests.shared_selfs TableTests.shared_selfs_from_free_param TableTests.shared_selfs_through_metatables -TableTests.table_function_check_use_after_free TableTests.table_indexing_error_location TableTests.table_insert_should_cope_with_optional_properties_in_nonstrict TableTests.table_insert_should_cope_with_optional_properties_in_strict @@ -525,13 +508,10 @@ TryUnifyTests.result_of_failed_typepack_unification_is_constrained TryUnifyTests.typepack_unification_should_trim_free_tails TryUnifyTests.variadics_should_use_reversed_properly TypeAliases.cli_38393_recursive_intersection_oom -TypeAliases.corecursive_types_generic TypeAliases.forward_declared_alias_is_not_clobbered_by_prior_unification_with_any -TypeAliases.general_require_multi_assign TypeAliases.generic_param_remap TypeAliases.mismatched_generic_pack_type_param TypeAliases.mismatched_generic_type_param -TypeAliases.mutually_recursive_types_errors TypeAliases.mutually_recursive_types_restriction_not_ok_1 TypeAliases.mutually_recursive_types_restriction_not_ok_2 TypeAliases.mutually_recursive_types_swapsies_not_ok @@ -543,7 +523,6 @@ TypeAliases.type_alias_fwd_declaration_is_precise TypeAliases.type_alias_local_mutation TypeAliases.type_alias_local_rename TypeAliases.type_alias_of_an_imported_recursive_generic_type -TypeAliases.type_alias_of_an_imported_recursive_type TypeInfer.checking_should_not_ice TypeInfer.cyclic_follow TypeInfer.do_not_bind_a_free_table_to_a_union_containing_that_table @@ -559,20 +538,21 @@ TypeInfer.tc_if_else_expressions_expected_type_1 TypeInfer.tc_if_else_expressions_expected_type_2 TypeInfer.tc_if_else_expressions_expected_type_3 TypeInfer.tc_if_else_expressions_type_union +TypeInfer.tc_interpolated_string_basic +TypeInfer.tc_interpolated_string_constant_type +TypeInfer.tc_interpolated_string_with_invalid_expression TypeInfer.type_infer_recursion_limit_no_ice -TypeInfer.warn_on_lowercase_parent_property TypeInferAnyError.assign_prop_to_table_by_calling_any_yields_any TypeInferAnyError.can_get_length_of_any -TypeInferAnyError.for_in_loop_iterator_is_any TypeInferAnyError.for_in_loop_iterator_is_any2 -TypeInferAnyError.for_in_loop_iterator_is_error -TypeInferAnyError.for_in_loop_iterator_is_error2 TypeInferAnyError.for_in_loop_iterator_returns_any -TypeInferAnyError.for_in_loop_iterator_returns_any2 TypeInferAnyError.length_of_error_type_does_not_produce_an_error TypeInferAnyError.replace_every_free_type_when_unifying_a_complex_function_with_any +TypeInferAnyError.union_of_types_regression_test TypeInferClasses.call_base_method TypeInferClasses.call_instance_method +TypeInferClasses.can_assign_to_prop_of_base_class_using_string +TypeInferClasses.can_read_prop_of_base_class_using_string TypeInferClasses.class_type_mismatch_with_name_conflict TypeInferClasses.classes_can_have_overloaded_operators TypeInferClasses.classes_without_overloaded_operators_cannot_be_added @@ -582,10 +562,13 @@ TypeInferClasses.higher_order_function_return_type_is_not_contravariant TypeInferClasses.higher_order_function_return_values_are_covariant TypeInferClasses.optional_class_field_access_error TypeInferClasses.table_class_unification_reports_sane_errors_for_missing_properties +TypeInferClasses.table_indexers_are_invariant +TypeInferClasses.table_properties_are_invariant TypeInferClasses.warn_when_prop_almost_matches TypeInferClasses.we_can_report_when_someone_is_trying_to_use_a_table_rather_than_a_class TypeInferFunctions.another_indirect_function_case_where_it_is_ok_to_provide_too_many_arguments TypeInferFunctions.another_recursive_local_function +TypeInferFunctions.call_o_with_another_argument_after_foo_was_quantified TypeInferFunctions.calling_function_with_anytypepack_doesnt_leak_free_types TypeInferFunctions.calling_function_with_incorrect_argument_type_yields_errors_spanning_argument TypeInferFunctions.complicated_return_types_require_an_explicit_annotation @@ -634,43 +617,23 @@ TypeInferFunctions.too_many_arguments TypeInferFunctions.too_many_return_values TypeInferFunctions.vararg_function_is_quantified TypeInferFunctions.vararg_functions_should_allow_calls_of_any_types_and_size -TypeInferLoops.for_in_loop TypeInferLoops.for_in_loop_error_on_factory_not_returning_the_right_amount_of_values -TypeInferLoops.for_in_loop_error_on_iterator_requiring_args_but_none_given -TypeInferLoops.for_in_loop_on_error -TypeInferLoops.for_in_loop_on_non_function -TypeInferLoops.for_in_loop_should_fail_with_non_function_iterator -TypeInferLoops.for_in_loop_where_iteratee_is_free TypeInferLoops.for_in_loop_with_custom_iterator TypeInferLoops.for_in_loop_with_next -TypeInferLoops.for_in_with_a_custom_iterator_should_type_check -TypeInferLoops.for_in_with_an_iterator_of_type_any TypeInferLoops.for_in_with_just_one_iterator_is_ok -TypeInferLoops.fuzz_fail_missing_instantitation_follow -TypeInferLoops.ipairs_produces_integral_indices -TypeInferLoops.loop_iter_basic TypeInferLoops.loop_iter_iter_metamethod TypeInferLoops.loop_iter_no_indexer_nonstrict -TypeInferLoops.loop_iter_no_indexer_strict TypeInferLoops.loop_iter_trailing_nil TypeInferLoops.loop_typecheck_crash_on_empty_optional -TypeInferLoops.properly_infer_iteratee_is_a_free_table TypeInferLoops.unreachable_code_after_infinite_loop TypeInferLoops.varlist_declared_by_for_in_loop_should_be_free -TypeInferModules.do_not_modify_imported_types -TypeInferModules.do_not_modify_imported_types_2 -TypeInferModules.do_not_modify_imported_types_3 -TypeInferModules.general_require_call_expression +TypeInferModules.custom_require_global TypeInferModules.general_require_type_mismatch TypeInferModules.module_type_conflict TypeInferModules.module_type_conflict_instantiated -TypeInferModules.require TypeInferModules.require_a_variadic_function -TypeInferModules.require_failed_module -TypeInferModules.require_module_that_does_not_export TypeInferModules.require_types TypeInferModules.type_error_of_unknown_qualified_type -TypeInferModules.warn_if_you_try_to_require_a_non_modulescript TypeInferOOP.CheckMethodsOfSealed TypeInferOOP.dont_suggest_using_colon_rather_than_dot_if_another_overload_works TypeInferOOP.dont_suggest_using_colon_rather_than_dot_if_it_wont_help_2 @@ -725,7 +688,6 @@ TypeInferOperators.typecheck_unary_minus_error TypeInferOperators.unary_not_is_boolean TypeInferOperators.unknown_type_in_comparison TypeInferOperators.UnknownGlobalCompoundAssign -TypeInferPrimitives.cannot_call_primitives TypeInferPrimitives.CheckMethodsOfNumber TypeInferPrimitives.string_function_other TypeInferPrimitives.string_index @@ -768,7 +730,6 @@ TypePackTests.type_alias_type_packs_nested TypePackTests.type_pack_hidden_free_tail_infinite_growth TypePackTests.type_pack_type_parameters TypePackTests.varargs_inference_through_multiple_scopes -TypePackTests.variadic_argument_tail TypePackTests.variadic_pack_syntax TypePackTests.variadic_packs TypeSingletons.bool_singleton_subtype @@ -793,7 +754,6 @@ TypeSingletons.string_singletons_escape_chars TypeSingletons.string_singletons_mismatch TypeSingletons.table_insert_with_a_singleton_argument TypeSingletons.table_properties_type_error_escapes -TypeSingletons.tagged_unions_immutable_tag TypeSingletons.tagged_unions_using_singletons TypeSingletons.taking_the_length_of_string_singleton TypeSingletons.taking_the_length_of_union_of_string_singleton diff --git a/tools/test_dcr.py b/tools/test_dcr.py index da33706c3..db932253b 100644 --- a/tools/test_dcr.py +++ b/tools/test_dcr.py @@ -84,12 +84,16 @@ def main(): failList = loadFailList() + commandLine = [ + args.path, + "--reporters=xml", + "--fflags=true,DebugLuauDeferredConstraintResolution=true", + ] + + print('>', ' '.join(commandLine), file=sys.stderr) + p = sp.Popen( - [ - args.path, - "--reporters=xml", - "--fflags=true,DebugLuauDeferredConstraintResolution=true", - ], + commandLine, stdout=sp.PIPE, ) @@ -122,7 +126,7 @@ def main(): with open(FAIL_LIST_PATH, "w", newline="\n") as f: for name in newFailList: print(name, file=f) - print("Updated faillist.txt") + print("Updated faillist.txt", file=sys.stderr) if handler.numSkippedTests > 0: print( From dec4b67b5a3335123623e80328ef39e41800dfe7 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 8 Sep 2022 14:44:50 -0700 Subject: [PATCH 02/66] Sync to upstream/release/544 --- Analysis/include/Luau/Anyification.h | 7 +- Analysis/include/Luau/BuiltinDefinitions.h | 12 +- .../include/Luau/ConstraintGraphBuilder.h | 7 +- Analysis/include/Luau/ConstraintSolver.h | 11 +- .../include/Luau/ConstraintSolverLogger.h | 29 -- Analysis/include/Luau/DcrLogger.h | 131 ++++++ Analysis/include/Luau/Documentation.h | 1 + Analysis/include/Luau/Frontend.h | 3 + Analysis/include/Luau/JsonEmitter.h | 12 + Analysis/include/Luau/Module.h | 2 +- Analysis/include/Luau/Normalize.h | 22 +- Analysis/include/Luau/TypeArena.h | 2 + Analysis/include/Luau/TypeChecker2.h | 6 +- Analysis/include/Luau/TypeInfer.h | 3 +- Analysis/include/Luau/TypeUtils.h | 10 +- Analysis/include/Luau/TypeVar.h | 38 +- Analysis/include/Luau/Unifier.h | 13 +- Analysis/src/Anyification.cpp | 11 +- Analysis/src/Autocomplete.cpp | 108 +++-- Analysis/src/BuiltinDefinitions.cpp | 146 ++++++- Analysis/src/ConstraintGraphBuilder.cpp | 83 ++-- Analysis/src/ConstraintSolver.cpp | 95 +++-- Analysis/src/ConstraintSolverLogger.cpp | 150 ------- Analysis/src/DcrLogger.cpp | 395 ++++++++++++++++++ Analysis/src/EmbeddedBuiltinDefinitions.cpp | 6 +- Analysis/src/Frontend.cpp | 35 +- Analysis/src/Linter.cpp | 3 +- Analysis/src/Module.cpp | 22 +- Analysis/src/Normalize.cpp | 46 +- Analysis/src/TypeArena.cpp | 9 + Analysis/src/TypeChecker2.cpp | 71 ++-- Analysis/src/TypeInfer.cpp | 144 ++++--- Analysis/src/TypeUtils.cpp | 38 +- Analysis/src/TypeVar.cpp | 63 ++- Analysis/src/Unifier.cpp | 39 +- CodeGen/include/Luau/CodeAllocator.h | 50 +++ CodeGen/include/Luau/OperandX64.h | 1 + CodeGen/include/Luau/RegisterX64.h | 1 + CodeGen/src/CodeAllocator.cpp | 188 +++++++++ Common/include/Luau/Bytecode.h | 27 +- Compiler/src/BytecodeBuilder.cpp | 31 -- Compiler/src/Compiler.cpp | 8 - Makefile | 6 +- Sources.cmake | 7 +- VM/include/lua.h | 10 + VM/src/lapi.cpp | 17 + VM/src/lcorolib.cpp | 34 +- VM/src/lgc.cpp | 77 +++- VM/src/lvmexecute.cpp | 14 +- tests/Autocomplete.test.cpp | 6 +- tests/CodeAllocator.test.cpp | 160 +++++++ tests/Compiler.test.cpp | 104 ++++- tests/Conformance.test.cpp | 3 +- tests/ConstraintSolver.test.cpp | 12 +- tests/CostModel.test.cpp | 19 + tests/Fixture.cpp | 11 +- tests/Fixture.h | 3 + tests/Frontend.test.cpp | 4 +- tests/LValue.test.cpp | 55 +-- tests/Linter.test.cpp | 10 +- tests/Module.test.cpp | 12 +- tests/Normalize.test.cpp | 28 +- tests/NotNull.test.cpp | 1 + tests/TypeInfer.aliases.test.cpp | 10 +- tests/TypeInfer.annotations.test.cpp | 6 +- tests/TypeInfer.classes.test.cpp | 10 +- tests/TypeInfer.definitions.test.cpp | 20 +- tests/TypeInfer.functions.test.cpp | 4 +- tests/TypeInfer.primitives.test.cpp | 14 + tests/TypeInfer.provisional.test.cpp | 12 +- tests/TypeInfer.refinements.test.cpp | 10 +- tests/TypeInfer.tryUnify.test.cpp | 2 +- tests/TypeInfer.typePacks.cpp | 4 +- tests/TypeVar.test.cpp | 2 +- tests/conformance/types.lua | 3 - tools/faillist.txt | 11 +- 76 files changed, 1986 insertions(+), 794 deletions(-) delete mode 100644 Analysis/include/Luau/ConstraintSolverLogger.h create mode 100644 Analysis/include/Luau/DcrLogger.h delete mode 100644 Analysis/src/ConstraintSolverLogger.cpp create mode 100644 Analysis/src/DcrLogger.cpp create mode 100644 CodeGen/include/Luau/CodeAllocator.h create mode 100644 CodeGen/src/CodeAllocator.cpp create mode 100644 tests/CodeAllocator.test.cpp diff --git a/Analysis/include/Luau/Anyification.h b/Analysis/include/Luau/Anyification.h index 9dd7d8e00..a6f3e2a90 100644 --- a/Analysis/include/Luau/Anyification.h +++ b/Analysis/include/Luau/Anyification.h @@ -19,9 +19,12 @@ using ScopePtr = std::shared_ptr; // A substitution which replaces free types by any struct Anyification : Substitution { - Anyification(TypeArena* arena, NotNull scope, InternalErrorReporter* iceHandler, TypeId anyType, TypePackId anyTypePack); - Anyification(TypeArena* arena, const ScopePtr& scope, InternalErrorReporter* iceHandler, TypeId anyType, TypePackId anyTypePack); + Anyification(TypeArena* arena, NotNull scope, NotNull singletonTypes, InternalErrorReporter* iceHandler, TypeId anyType, + TypePackId anyTypePack); + Anyification(TypeArena* arena, const ScopePtr& scope, NotNull singletonTypes, InternalErrorReporter* iceHandler, TypeId anyType, + TypePackId anyTypePack); NotNull scope; + NotNull singletonTypes; InternalErrorReporter* iceHandler; TypeId anyType; diff --git a/Analysis/include/Luau/BuiltinDefinitions.h b/Analysis/include/Luau/BuiltinDefinitions.h index 28a4368e7..0292dff78 100644 --- a/Analysis/include/Luau/BuiltinDefinitions.h +++ b/Analysis/include/Luau/BuiltinDefinitions.h @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include "Luau/Frontend.h" #include "Luau/Scope.h" #include "Luau/TypeInfer.h" @@ -8,6 +9,7 @@ namespace Luau { void registerBuiltinTypes(TypeChecker& typeChecker); +void registerBuiltinTypes(Frontend& frontend); TypeId makeUnion(TypeArena& arena, std::vector&& types); TypeId makeIntersection(TypeArena& arena, std::vector&& types); @@ -15,6 +17,7 @@ TypeId makeIntersection(TypeArena& arena, std::vector&& types); /** Build an optional 't' */ TypeId makeOption(TypeChecker& typeChecker, TypeArena& arena, TypeId t); +TypeId makeOption(Frontend& frontend, TypeArena& arena, TypeId t); /** Small utility function for building up type definitions from C++. */ @@ -41,12 +44,17 @@ void assignPropDocumentationSymbols(TableTypeVar::Props& props, const std::strin std::string getBuiltinDefinitionSource(); -void addGlobalBinding(TypeChecker& typeChecker, const std::string& name, TypeId ty, const std::string& packageName); void addGlobalBinding(TypeChecker& typeChecker, const std::string& name, Binding binding); +void addGlobalBinding(TypeChecker& typeChecker, const std::string& name, TypeId ty, const std::string& packageName); void addGlobalBinding(TypeChecker& typeChecker, const ScopePtr& scope, const std::string& name, TypeId ty, const std::string& packageName); void addGlobalBinding(TypeChecker& typeChecker, const ScopePtr& scope, const std::string& name, Binding binding); -std::optional tryGetGlobalBinding(TypeChecker& typeChecker, const std::string& name); +void addGlobalBinding(Frontend& frontend, const std::string& name, TypeId ty, const std::string& packageName); +void addGlobalBinding(Frontend& frontend, const std::string& name, Binding binding); +void addGlobalBinding(Frontend& frontend, const ScopePtr& scope, const std::string& name, TypeId ty, const std::string& packageName); +void addGlobalBinding(Frontend& frontend, const ScopePtr& scope, const std::string& name, Binding binding); +std::optional tryGetGlobalBinding(Frontend& frontend, const std::string& name); Binding* tryGetGlobalBindingRef(TypeChecker& typeChecker, const std::string& name); +TypeId getGlobalBinding(Frontend& frontend, const std::string& name); TypeId getGlobalBinding(TypeChecker& typeChecker, const std::string& name); } // namespace Luau diff --git a/Analysis/include/Luau/ConstraintGraphBuilder.h b/Analysis/include/Luau/ConstraintGraphBuilder.h index 1cba0d33d..1567e0ada 100644 --- a/Analysis/include/Luau/ConstraintGraphBuilder.h +++ b/Analysis/include/Luau/ConstraintGraphBuilder.h @@ -21,6 +21,8 @@ namespace Luau struct Scope; using ScopePtr = std::shared_ptr; +struct DcrLogger; + struct ConstraintGraphBuilder { // A list of all the scopes in the module. This vector holds ownership of the @@ -30,7 +32,7 @@ struct ConstraintGraphBuilder ModuleName moduleName; ModulePtr module; - SingletonTypes& singletonTypes; + NotNull singletonTypes; const NotNull arena; // The root scope of the module we're generating constraints for. // This is null when the CGB is initially constructed. @@ -58,9 +60,10 @@ struct ConstraintGraphBuilder const NotNull ice; ScopePtr globalScope; + DcrLogger* logger; ConstraintGraphBuilder(const ModuleName& moduleName, ModulePtr module, TypeArena* arena, NotNull moduleResolver, - NotNull ice, const ScopePtr& globalScope); + NotNull singletonTypes, NotNull ice, const ScopePtr& globalScope, DcrLogger* logger); /** * Fabricates a new free type belonging to a given scope. diff --git a/Analysis/include/Luau/ConstraintSolver.h b/Analysis/include/Luau/ConstraintSolver.h index 002aa9475..059e97cb3 100644 --- a/Analysis/include/Luau/ConstraintSolver.h +++ b/Analysis/include/Luau/ConstraintSolver.h @@ -5,14 +5,16 @@ #include "Luau/Error.h" #include "Luau/Variant.h" #include "Luau/Constraint.h" -#include "Luau/ConstraintSolverLogger.h" #include "Luau/TypeVar.h" +#include "Luau/ToString.h" #include namespace Luau { +struct DcrLogger; + // TypeId, TypePackId, or Constraint*. It is impossible to know which, but we // never dereference this pointer. using BlockedConstraintId = const void*; @@ -40,6 +42,7 @@ struct HashInstantiationSignature struct ConstraintSolver { TypeArena* arena; + NotNull singletonTypes; InternalErrorReporter iceReporter; // The entire set of constraints that the solver is trying to resolve. std::vector> constraints; @@ -69,10 +72,10 @@ struct ConstraintSolver NotNull moduleResolver; std::vector requireCycles; - ConstraintSolverLogger logger; + DcrLogger* logger; - explicit ConstraintSolver(TypeArena* arena, NotNull rootScope, ModuleName moduleName, NotNull moduleResolver, - std::vector requireCycles); + explicit ConstraintSolver(TypeArena* arena, NotNull singletonTypes, NotNull rootScope, ModuleName moduleName, + NotNull moduleResolver, std::vector requireCycles, DcrLogger* logger); /** * Attempts to dispatch all pending constraints and reach a type solution diff --git a/Analysis/include/Luau/ConstraintSolverLogger.h b/Analysis/include/Luau/ConstraintSolverLogger.h deleted file mode 100644 index 65aa9a7e6..000000000 --- a/Analysis/include/Luau/ConstraintSolverLogger.h +++ /dev/null @@ -1,29 +0,0 @@ -// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details - -#include "Luau/Constraint.h" -#include "Luau/NotNull.h" -#include "Luau/Scope.h" -#include "Luau/ToString.h" - -#include -#include -#include - -namespace Luau -{ - -struct ConstraintSolverLogger -{ - std::string compileOutput(); - void captureBoundarySnapshot(const Scope* rootScope, std::vector>& unsolvedConstraints); - void prepareStepSnapshot( - const Scope* rootScope, NotNull current, std::vector>& unsolvedConstraints, bool force); - void commitPreparedStepSnapshot(); - -private: - std::vector snapshots; - std::optional preparedSnapshot; - ToStringOptions opts; -}; - -} // namespace Luau diff --git a/Analysis/include/Luau/DcrLogger.h b/Analysis/include/Luau/DcrLogger.h new file mode 100644 index 000000000..bd8672e32 --- /dev/null +++ b/Analysis/include/Luau/DcrLogger.h @@ -0,0 +1,131 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Constraint.h" +#include "Luau/NotNull.h" +#include "Luau/Scope.h" +#include "Luau/ToString.h" +#include "Luau/Error.h" +#include "Luau/Variant.h" + +#include +#include +#include + +namespace Luau +{ + +struct ErrorSnapshot +{ + std::string message; + Location location; +}; + +struct BindingSnapshot +{ + std::string typeId; + std::string typeString; + Location location; +}; + +struct TypeBindingSnapshot +{ + std::string typeId; + std::string typeString; +}; + +struct ConstraintGenerationLog +{ + std::string source; + std::unordered_map constraintLocations; + std::vector errors; +}; + +struct ScopeSnapshot +{ + std::unordered_map bindings; + std::unordered_map typeBindings; + std::unordered_map typePackBindings; + std::vector children; +}; + +enum class ConstraintBlockKind +{ + TypeId, + TypePackId, + ConstraintId, +}; + +struct ConstraintBlock +{ + ConstraintBlockKind kind; + std::string stringification; +}; + +struct ConstraintSnapshot +{ + std::string stringification; + std::vector blocks; +}; + +struct BoundarySnapshot +{ + std::unordered_map constraints; + ScopeSnapshot rootScope; +}; + +struct StepSnapshot +{ + std::string currentConstraint; + bool forced; + std::unordered_map unsolvedConstraints; + ScopeSnapshot rootScope; +}; + +struct TypeSolveLog +{ + BoundarySnapshot initialState; + std::vector stepStates; + BoundarySnapshot finalState; +}; + +struct TypeCheckLog +{ + std::vector errors; +}; + +using ConstraintBlockTarget = Variant>; + +struct DcrLogger +{ + std::string compileOutput(); + + void captureSource(std::string source); + void captureGenerationError(const TypeError& error); + void captureConstraintLocation(NotNull constraint, Location location); + + void pushBlock(NotNull constraint, TypeId block); + void pushBlock(NotNull constraint, TypePackId block); + void pushBlock(NotNull constraint, NotNull block); + void popBlock(TypeId block); + void popBlock(TypePackId block); + void popBlock(NotNull block); + + void captureInitialSolverState(const Scope* rootScope, const std::vector>& unsolvedConstraints); + StepSnapshot prepareStepSnapshot(const Scope* rootScope, NotNull current, bool force, const std::vector>& unsolvedConstraints); + void commitStepSnapshot(StepSnapshot snapshot); + void captureFinalSolverState(const Scope* rootScope, const std::vector>& unsolvedConstraints); + + void captureTypeCheckError(const TypeError& error); +private: + ConstraintGenerationLog generationLog; + std::unordered_map, std::vector> constraintBlocks; + TypeSolveLog solveLog; + TypeCheckLog checkLog; + + ToStringOptions opts; + + std::vector snapshotBlocks(NotNull constraint); +}; + +} // namespace Luau diff --git a/Analysis/include/Luau/Documentation.h b/Analysis/include/Luau/Documentation.h index 7a2b56ffb..67a9feb19 100644 --- a/Analysis/include/Luau/Documentation.h +++ b/Analysis/include/Luau/Documentation.h @@ -1,3 +1,4 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once #include "Luau/DenseHash.h" diff --git a/Analysis/include/Luau/Frontend.h b/Analysis/include/Luau/Frontend.h index 556126892..04c598de1 100644 --- a/Analysis/include/Luau/Frontend.h +++ b/Analysis/include/Luau/Frontend.h @@ -174,6 +174,9 @@ struct Frontend ScopePtr globalScope; public: + SingletonTypes singletonTypes_; + const NotNull singletonTypes; + FileResolver* fileResolver; FrontendModuleResolver moduleResolver; FrontendModuleResolver moduleResolverForAutocomplete; diff --git a/Analysis/include/Luau/JsonEmitter.h b/Analysis/include/Luau/JsonEmitter.h index 0bf3327a8..d8dc96e43 100644 --- a/Analysis/include/Luau/JsonEmitter.h +++ b/Analysis/include/Luau/JsonEmitter.h @@ -4,6 +4,7 @@ #include #include #include +#include #include #include "Luau/NotNull.h" @@ -232,4 +233,15 @@ void write(JsonEmitter& emitter, const std::optional& v) emitter.writeRaw("null"); } +template +void write(JsonEmitter& emitter, const std::unordered_map& map) +{ + ObjectEmitter o = emitter.writeObject(); + + for (const auto& [k, v] : map) + o.writePair(k, v); + + o.finish(); +} + } // namespace Luau::Json diff --git a/Analysis/include/Luau/Module.h b/Analysis/include/Luau/Module.h index bec51b81a..d22aad12c 100644 --- a/Analysis/include/Luau/Module.h +++ b/Analysis/include/Luau/Module.h @@ -90,7 +90,7 @@ struct Module // Once a module has been typechecked, we clone its public interface into a separate arena. // This helps us to force TypeVar ownership into a DAG rather than a DCG. - void clonePublicInterface(InternalErrorReporter& ice); + void clonePublicInterface(NotNull singletonTypes, InternalErrorReporter& ice); }; } // namespace Luau diff --git a/Analysis/include/Luau/Normalize.h b/Analysis/include/Luau/Normalize.h index 78b241e4f..48dbe2bea 100644 --- a/Analysis/include/Luau/Normalize.h +++ b/Analysis/include/Luau/Normalize.h @@ -1,4 +1,5 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once #include "Luau/Module.h" #include "Luau/NotNull.h" @@ -12,17 +13,20 @@ namespace Luau struct InternalErrorReporter; struct Module; struct Scope; +struct SingletonTypes; using ModulePtr = std::shared_ptr; -bool isSubtype(TypeId subTy, TypeId superTy, NotNull scope, InternalErrorReporter& ice); -bool isSubtype(TypePackId subTy, TypePackId superTy, NotNull scope, InternalErrorReporter& ice); - -std::pair normalize(TypeId ty, NotNull scope, TypeArena& arena, InternalErrorReporter& ice); -std::pair normalize(TypeId ty, NotNull module, InternalErrorReporter& ice); -std::pair normalize(TypeId ty, const ModulePtr& module, InternalErrorReporter& ice); -std::pair normalize(TypePackId ty, NotNull scope, TypeArena& arena, InternalErrorReporter& ice); -std::pair normalize(TypePackId ty, NotNull module, InternalErrorReporter& ice); -std::pair normalize(TypePackId ty, const ModulePtr& module, InternalErrorReporter& ice); +bool isSubtype(TypeId subTy, TypeId superTy, NotNull scope, NotNull singletonTypes, InternalErrorReporter& ice); +bool isSubtype(TypePackId subTy, TypePackId superTy, NotNull scope, NotNull singletonTypes, InternalErrorReporter& ice); + +std::pair normalize( + TypeId ty, NotNull scope, TypeArena& arena, NotNull singletonTypes, InternalErrorReporter& ice); +std::pair normalize(TypeId ty, NotNull module, NotNull singletonTypes, InternalErrorReporter& ice); +std::pair normalize(TypeId ty, const ModulePtr& module, NotNull singletonTypes, InternalErrorReporter& ice); +std::pair normalize( + TypePackId ty, NotNull scope, TypeArena& arena, NotNull singletonTypes, InternalErrorReporter& ice); +std::pair normalize(TypePackId ty, NotNull module, NotNull singletonTypes, InternalErrorReporter& ice); +std::pair normalize(TypePackId ty, const ModulePtr& module, NotNull singletonTypes, InternalErrorReporter& ice); } // namespace Luau diff --git a/Analysis/include/Luau/TypeArena.h b/Analysis/include/Luau/TypeArena.h index decc8c590..1e029aeb8 100644 --- a/Analysis/include/Luau/TypeArena.h +++ b/Analysis/include/Luau/TypeArena.h @@ -31,6 +31,8 @@ struct TypeArena TypeId freshType(TypeLevel level); TypeId freshType(Scope* scope); + TypePackId freshTypePack(Scope* scope); + TypePackId addTypePack(std::initializer_list types); TypePackId addTypePack(std::vector types, std::optional tail = {}); TypePackId addTypePack(TypePack pack); diff --git a/Analysis/include/Luau/TypeChecker2.h b/Analysis/include/Luau/TypeChecker2.h index a6c7a3e3a..a9cd6ec8c 100644 --- a/Analysis/include/Luau/TypeChecker2.h +++ b/Analysis/include/Luau/TypeChecker2.h @@ -4,10 +4,14 @@ #include "Luau/Ast.h" #include "Luau/Module.h" +#include "Luau/NotNull.h" namespace Luau { -void check(const SourceModule& sourceModule, Module* module); +struct DcrLogger; +struct SingletonTypes; + +void check(NotNull singletonTypes, DcrLogger* logger, const SourceModule& sourceModule, Module* module); } // namespace Luau diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index 0b427946b..b0b3f3ac5 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -58,7 +58,7 @@ class TimeLimitError : public std::exception // within a program are borrowed pointers into this set. struct TypeChecker { - explicit TypeChecker(ModuleResolver* resolver, InternalErrorReporter* iceHandler); + explicit TypeChecker(ModuleResolver* resolver, NotNull singletonTypes, InternalErrorReporter* iceHandler); TypeChecker(const TypeChecker&) = delete; TypeChecker& operator=(const TypeChecker&) = delete; @@ -353,6 +353,7 @@ struct TypeChecker ModuleName currentModuleName; std::function prepareModuleScope; + NotNull singletonTypes; InternalErrorReporter* iceHandler; UnifierSharedState unifierState; diff --git a/Analysis/include/Luau/TypeUtils.h b/Analysis/include/Luau/TypeUtils.h index 6c611fb2c..6890f881a 100644 --- a/Analysis/include/Luau/TypeUtils.h +++ b/Analysis/include/Luau/TypeUtils.h @@ -15,10 +15,12 @@ struct TxnLog; using ScopePtr = std::shared_ptr; -std::optional findMetatableEntry(ErrorVec& errors, TypeId type, const std::string& entry, Location location); -std::optional findTablePropertyRespectingMeta(ErrorVec& errors, TypeId ty, const std::string& name, Location location); -std::optional getIndexTypeFromType(const ScopePtr& scope, ErrorVec& errors, TypeArena* arena, TypeId type, const std::string& prop, - const Location& location, bool addErrors, InternalErrorReporter& handle); +std::optional findMetatableEntry( + NotNull singletonTypes, ErrorVec& errors, TypeId type, const std::string& entry, Location location); +std::optional findTablePropertyRespectingMeta( + NotNull singletonTypes, ErrorVec& errors, TypeId ty, const std::string& name, Location location); +std::optional getIndexTypeFromType(const ScopePtr& scope, ErrorVec& errors, TypeArena* arena, NotNull singletonTypes, + TypeId type, const std::string& prop, const Location& location, bool addErrors, InternalErrorReporter& handle); // Returns the minimum and maximum number of types the argument list can accept. std::pair> getParameterExtents(const TxnLog* log, TypePackId tp); diff --git a/Analysis/include/Luau/TypeVar.h b/Analysis/include/Luau/TypeVar.h index e67b36014..2847d0b16 100644 --- a/Analysis/include/Luau/TypeVar.h +++ b/Analysis/include/Luau/TypeVar.h @@ -586,7 +586,7 @@ bool isOverloadedFunction(TypeId ty); // True when string is a subtype of ty bool maybeString(TypeId ty); -std::optional getMetatable(TypeId type); +std::optional getMetatable(TypeId type, NotNull singletonTypes); TableTypeVar* getMutableTableType(TypeId type); const TableTypeVar* getTableType(TypeId type); @@ -614,21 +614,6 @@ bool hasLength(TypeId ty, DenseHashSet& seen, int* recursionCount); struct SingletonTypes { - const TypeId nilType; - const TypeId numberType; - const TypeId stringType; - const TypeId booleanType; - const TypeId threadType; - const TypeId trueType; - const TypeId falseType; - const TypeId anyType; - const TypeId unknownType; - const TypeId neverType; - - const TypePackId anyTypePack; - const TypePackId neverTypePack; - const TypePackId uninhabitableTypePack; - SingletonTypes(); ~SingletonTypes(); SingletonTypes(const SingletonTypes&) = delete; @@ -644,9 +629,28 @@ struct SingletonTypes bool debugFreezeArena = false; TypeId makeStringMetatable(); + +public: + const TypeId nilType; + const TypeId numberType; + const TypeId stringType; + const TypeId booleanType; + const TypeId threadType; + const TypeId trueType; + const TypeId falseType; + const TypeId anyType; + const TypeId unknownType; + const TypeId neverType; + const TypeId errorType; + + const TypePackId anyTypePack; + const TypePackId neverTypePack; + const TypePackId uninhabitableTypePack; + const TypePackId errorTypePack; }; -SingletonTypes& getSingletonTypes(); +// Clip with FFlagLuauNoMoreGlobalSingletonTypes +SingletonTypes& DEPRECATED_getSingletonTypes(); void persist(TypeId ty); void persist(TypePackId tp); diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index c7eb51a65..4d46869dd 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -3,10 +3,11 @@ #include "Luau/Error.h" #include "Luau/Location.h" +#include "Luau/ParseOptions.h" #include "Luau/Scope.h" +#include "Luau/Substitution.h" #include "Luau/TxnLog.h" #include "Luau/TypeArena.h" -#include "Luau/TypeInfer.h" #include "Luau/UnifierSharedState.h" #include @@ -23,11 +24,14 @@ enum Variance // A substitution which replaces singleton types by their wider types struct Widen : Substitution { - Widen(TypeArena* arena) + Widen(TypeArena* arena, NotNull singletonTypes) : Substitution(TxnLog::empty(), arena) + , singletonTypes(singletonTypes) { } + NotNull singletonTypes; + bool isDirty(TypeId ty) override; bool isDirty(TypePackId ty) override; TypeId clean(TypeId ty) override; @@ -47,6 +51,7 @@ struct UnifierOptions struct Unifier { TypeArena* const types; + NotNull singletonTypes; Mode mode; NotNull scope; // const Scope maybe @@ -59,8 +64,8 @@ struct Unifier UnifierSharedState& sharedState; - Unifier(TypeArena* types, Mode mode, NotNull scope, const Location& location, Variance variance, UnifierSharedState& sharedState, - TxnLog* parentLog = nullptr); + Unifier(TypeArena* types, NotNull singletonTypes, Mode mode, NotNull scope, const Location& location, Variance variance, + UnifierSharedState& sharedState, TxnLog* parentLog = nullptr); // Test whether the two type vars unify. Never commits the result. ErrorVec canUnify(TypeId subTy, TypeId superTy); diff --git a/Analysis/src/Anyification.cpp b/Analysis/src/Anyification.cpp index abcaba020..cc9796eec 100644 --- a/Analysis/src/Anyification.cpp +++ b/Analysis/src/Anyification.cpp @@ -11,17 +11,20 @@ LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution) namespace Luau { -Anyification::Anyification(TypeArena* arena, NotNull scope, InternalErrorReporter* iceHandler, TypeId anyType, TypePackId anyTypePack) +Anyification::Anyification(TypeArena* arena, NotNull scope, NotNull singletonTypes, InternalErrorReporter* iceHandler, + TypeId anyType, TypePackId anyTypePack) : Substitution(TxnLog::empty(), arena) , scope(scope) + , singletonTypes(singletonTypes) , iceHandler(iceHandler) , anyType(anyType) , anyTypePack(anyTypePack) { } -Anyification::Anyification(TypeArena* arena, const ScopePtr& scope, InternalErrorReporter* iceHandler, TypeId anyType, TypePackId anyTypePack) - : Anyification(arena, NotNull{scope.get()}, iceHandler, anyType, anyTypePack) +Anyification::Anyification(TypeArena* arena, const ScopePtr& scope, NotNull singletonTypes, InternalErrorReporter* iceHandler, + TypeId anyType, TypePackId anyTypePack) + : Anyification(arena, NotNull{scope.get()}, singletonTypes, iceHandler, anyType, anyTypePack) { } @@ -71,7 +74,7 @@ TypeId Anyification::clean(TypeId ty) for (TypeId& ty : copy) ty = replace(ty); TypeId res = copy.size() == 1 ? copy[0] : addType(UnionTypeVar{std::move(copy)}); - auto [t, ok] = normalize(res, scope, *arena, *iceHandler); + auto [t, ok] = normalize(res, scope, *arena, singletonTypes, *iceHandler); if (!ok) normalizationTooComplex = true; return t; diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index 378a1cb7d..2fc145d32 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -14,8 +14,6 @@ LUAU_FASTFLAG(LuauSelfCallAutocompleteFix3) -LUAU_FASTFLAGVARIABLE(LuauAutocompleteFixGlobalOrder, false) - static const std::unordered_set kStatementStartingKeywords = { "while", "if", "local", "repeat", "function", "do", "for", "return", "break", "continue", "type", "export"}; @@ -137,27 +135,28 @@ static std::optional findExpectedTypeAt(const Module& module, AstNode* n return *it; } -static bool checkTypeMatch(TypeId subTy, TypeId superTy, NotNull scope, TypeArena* typeArena) +static bool checkTypeMatch(TypeId subTy, TypeId superTy, NotNull scope, TypeArena* typeArena, NotNull singletonTypes) { InternalErrorReporter iceReporter; UnifierSharedState unifierState(&iceReporter); - Unifier unifier(typeArena, Mode::Strict, scope, Location(), Variance::Covariant, unifierState); + Unifier unifier(typeArena, singletonTypes, Mode::Strict, scope, Location(), Variance::Covariant, unifierState); return unifier.canUnify(subTy, superTy).empty(); } -static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typeArena, AstNode* node, Position position, TypeId ty) +static TypeCorrectKind checkTypeCorrectKind( + const Module& module, TypeArena* typeArena, NotNull singletonTypes, AstNode* node, Position position, TypeId ty) { ty = follow(ty); NotNull moduleScope{module.getModuleScope().get()}; - auto canUnify = [&typeArena, moduleScope](TypeId subTy, TypeId superTy) { + auto canUnify = [&typeArena, singletonTypes, moduleScope](TypeId subTy, TypeId superTy) { LUAU_ASSERT(!FFlag::LuauSelfCallAutocompleteFix3); InternalErrorReporter iceReporter; UnifierSharedState unifierState(&iceReporter); - Unifier unifier(typeArena, Mode::Strict, moduleScope, Location(), Variance::Covariant, unifierState); + Unifier unifier(typeArena, singletonTypes, Mode::Strict, moduleScope, Location(), Variance::Covariant, unifierState); unifier.tryUnify(subTy, superTy); bool ok = unifier.errors.empty(); @@ -171,11 +170,11 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ TypeId expectedType = follow(*typeAtPosition); - auto checkFunctionType = [typeArena, moduleScope, &canUnify, &expectedType](const FunctionTypeVar* ftv) { + auto checkFunctionType = [typeArena, singletonTypes, moduleScope, &canUnify, &expectedType](const FunctionTypeVar* ftv) { if (FFlag::LuauSelfCallAutocompleteFix3) { if (std::optional firstRetTy = first(ftv->retTypes)) - return checkTypeMatch(*firstRetTy, expectedType, moduleScope, typeArena); + return checkTypeMatch(*firstRetTy, expectedType, moduleScope, typeArena, singletonTypes); return false; } @@ -214,7 +213,8 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ } if (FFlag::LuauSelfCallAutocompleteFix3) - return checkTypeMatch(ty, expectedType, NotNull{module.getModuleScope().get()}, typeArena) ? TypeCorrectKind::Correct : TypeCorrectKind::None; + return checkTypeMatch(ty, expectedType, NotNull{module.getModuleScope().get()}, typeArena, singletonTypes) ? TypeCorrectKind::Correct + : TypeCorrectKind::None; else return canUnify(ty, expectedType) ? TypeCorrectKind::Correct : TypeCorrectKind::None; } @@ -226,8 +226,8 @@ enum class PropIndexType Key, }; -static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId rootTy, TypeId ty, PropIndexType indexType, - const std::vector& nodes, AutocompleteEntryMap& result, std::unordered_set& seen, +static void autocompleteProps(const Module& module, TypeArena* typeArena, NotNull singletonTypes, TypeId rootTy, TypeId ty, + PropIndexType indexType, const std::vector& nodes, AutocompleteEntryMap& result, std::unordered_set& seen, std::optional containingClass = std::nullopt) { if (FFlag::LuauSelfCallAutocompleteFix3) @@ -272,7 +272,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId return colonIndex; } }; - auto isWrongIndexer = [typeArena, &module, rootTy, indexType](Luau::TypeId type) { + auto isWrongIndexer = [typeArena, singletonTypes, &module, rootTy, indexType](Luau::TypeId type) { LUAU_ASSERT(FFlag::LuauSelfCallAutocompleteFix3); if (indexType == PropIndexType::Key) @@ -280,7 +280,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId bool calledWithSelf = indexType == PropIndexType::Colon; - auto isCompatibleCall = [typeArena, &module, rootTy, calledWithSelf](const FunctionTypeVar* ftv) { + auto isCompatibleCall = [typeArena, singletonTypes, &module, rootTy, calledWithSelf](const FunctionTypeVar* ftv) { // Strong match with definition is a success if (calledWithSelf == ftv->hasSelf) return true; @@ -293,7 +293,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId // When called with '.', but declared with 'self', it is considered invalid if first argument is compatible if (std::optional firstArgTy = first(ftv->argTypes)) { - if (checkTypeMatch(rootTy, *firstArgTy, NotNull{module.getModuleScope().get()}, typeArena)) + if (checkTypeMatch(rootTy, *firstArgTy, NotNull{module.getModuleScope().get()}, typeArena, singletonTypes)) return calledWithSelf; } @@ -327,8 +327,9 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId if (result.count(name) == 0 && name != kParseNameError) { Luau::TypeId type = Luau::follow(prop.type); - TypeCorrectKind typeCorrect = indexType == PropIndexType::Key ? TypeCorrectKind::Correct - : checkTypeCorrectKind(module, typeArena, nodes.back(), {{}, {}}, type); + TypeCorrectKind typeCorrect = indexType == PropIndexType::Key + ? TypeCorrectKind::Correct + : checkTypeCorrectKind(module, typeArena, singletonTypes, nodes.back(), {{}, {}}, type); ParenthesesRecommendation parens = indexType == PropIndexType::Key ? ParenthesesRecommendation::None : getParenRecommendation(type, nodes, typeCorrect); @@ -355,13 +356,13 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId TypeId followed = follow(indexIt->second.type); if (get(followed) || get(followed)) { - autocompleteProps(module, typeArena, rootTy, followed, indexType, nodes, result, seen); + autocompleteProps(module, typeArena, singletonTypes, rootTy, followed, indexType, nodes, result, seen); } else if (auto indexFunction = get(followed)) { std::optional indexFunctionResult = first(indexFunction->retTypes); if (indexFunctionResult) - autocompleteProps(module, typeArena, rootTy, *indexFunctionResult, indexType, nodes, result, seen); + autocompleteProps(module, typeArena, singletonTypes, rootTy, *indexFunctionResult, indexType, nodes, result, seen); } } }; @@ -371,13 +372,13 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId containingClass = containingClass.value_or(cls); fillProps(cls->props); if (cls->parent) - autocompleteProps(module, typeArena, rootTy, *cls->parent, indexType, nodes, result, seen, containingClass); + autocompleteProps(module, typeArena, singletonTypes, rootTy, *cls->parent, indexType, nodes, result, seen, containingClass); } else if (auto tbl = get(ty)) fillProps(tbl->props); else if (auto mt = get(ty)) { - autocompleteProps(module, typeArena, rootTy, mt->table, indexType, nodes, result, seen); + autocompleteProps(module, typeArena, singletonTypes, rootTy, mt->table, indexType, nodes, result, seen); if (FFlag::LuauSelfCallAutocompleteFix3) { @@ -395,12 +396,12 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId { TypeId followed = follow(indexIt->second.type); if (get(followed) || get(followed)) - autocompleteProps(module, typeArena, rootTy, followed, indexType, nodes, result, seen); + autocompleteProps(module, typeArena, singletonTypes, rootTy, followed, indexType, nodes, result, seen); else if (auto indexFunction = get(followed)) { std::optional indexFunctionResult = first(indexFunction->retTypes); if (indexFunctionResult) - autocompleteProps(module, typeArena, rootTy, *indexFunctionResult, indexType, nodes, result, seen); + autocompleteProps(module, typeArena, singletonTypes, rootTy, *indexFunctionResult, indexType, nodes, result, seen); } } } @@ -413,7 +414,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId AutocompleteEntryMap inner; std::unordered_set innerSeen = seen; - autocompleteProps(module, typeArena, rootTy, ty, indexType, nodes, inner, innerSeen); + autocompleteProps(module, typeArena, singletonTypes, rootTy, ty, indexType, nodes, inner, innerSeen); for (auto& pair : inner) result.insert(pair); @@ -436,7 +437,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId if (iter == endIter) return; - autocompleteProps(module, typeArena, rootTy, *iter, indexType, nodes, result, seen); + autocompleteProps(module, typeArena, singletonTypes, rootTy, *iter, indexType, nodes, result, seen); ++iter; @@ -454,7 +455,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId continue; } - autocompleteProps(module, typeArena, rootTy, *iter, indexType, nodes, inner, innerSeen); + autocompleteProps(module, typeArena, singletonTypes, rootTy, *iter, indexType, nodes, inner, innerSeen); std::unordered_set toRemove; @@ -481,7 +482,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId } else if (FFlag::LuauSelfCallAutocompleteFix3 && get(get(ty))) { - autocompleteProps(module, typeArena, rootTy, getSingletonTypes().stringType, indexType, nodes, result, seen); + autocompleteProps(module, typeArena, singletonTypes, rootTy, singletonTypes->stringType, indexType, nodes, result, seen); } } @@ -506,18 +507,18 @@ static void autocompleteKeywords( } } -static void autocompleteProps( - const Module& module, TypeArena* typeArena, TypeId ty, PropIndexType indexType, const std::vector& nodes, AutocompleteEntryMap& result) +static void autocompleteProps(const Module& module, TypeArena* typeArena, NotNull singletonTypes, TypeId ty, PropIndexType indexType, + const std::vector& nodes, AutocompleteEntryMap& result) { std::unordered_set seen; - autocompleteProps(module, typeArena, ty, ty, indexType, nodes, result, seen); + autocompleteProps(module, typeArena, singletonTypes, ty, ty, indexType, nodes, result, seen); } -AutocompleteEntryMap autocompleteProps( - const Module& module, TypeArena* typeArena, TypeId ty, PropIndexType indexType, const std::vector& nodes) +AutocompleteEntryMap autocompleteProps(const Module& module, TypeArena* typeArena, NotNull singletonTypes, TypeId ty, + PropIndexType indexType, const std::vector& nodes) { AutocompleteEntryMap result; - autocompleteProps(module, typeArena, ty, indexType, nodes, result); + autocompleteProps(module, typeArena, singletonTypes, ty, indexType, nodes, result); return result; } @@ -1079,19 +1080,11 @@ T* extractStat(const std::vector& ancestry) static bool isBindingLegalAtCurrentPosition(const Symbol& symbol, const Binding& binding, Position pos) { - if (FFlag::LuauAutocompleteFixGlobalOrder) - { - if (symbol.local) - return binding.location.end < pos; + if (symbol.local) + return binding.location.end < pos; - // Builtin globals have an empty location; for defined globals, we want pos to be outside of the definition range to suggest it - return binding.location == Location() || !binding.location.containsClosed(pos); - } - else - { - // Default Location used for global bindings, which are always legal. - return binding.location == Location() || binding.location.end < pos; - } + // Builtin globals have an empty location; for defined globals, we want pos to be outside of the definition range to suggest it + return binding.location == Location() || !binding.location.containsClosed(pos); } static AutocompleteEntryMap autocompleteStatement( @@ -1220,12 +1213,14 @@ static AutocompleteContext autocompleteExpression(const SourceModule& sourceModu { LUAU_ASSERT(!ancestry.empty()); + NotNull singletonTypes = typeChecker.singletonTypes; + AstNode* node = ancestry.rbegin()[0]; if (node->is()) { if (auto it = module.astTypes.find(node->asExpr())) - autocompleteProps(module, typeArena, *it, PropIndexType::Point, ancestry, result); + autocompleteProps(module, typeArena, singletonTypes, *it, PropIndexType::Point, ancestry, result); } else if (autocompleteIfElseExpression(node, ancestry, position, result)) return AutocompleteContext::Keyword; @@ -1249,7 +1244,7 @@ static AutocompleteContext autocompleteExpression(const SourceModule& sourceModu std::string n = toString(name); if (!result.count(n)) { - TypeCorrectKind typeCorrect = checkTypeCorrectKind(module, typeArena, node, position, binding.typeId); + TypeCorrectKind typeCorrect = checkTypeCorrectKind(module, typeArena, singletonTypes, node, position, binding.typeId); result[n] = {AutocompleteEntryKind::Binding, binding.typeId, binding.deprecated, false, typeCorrect, std::nullopt, std::nullopt, binding.documentationSymbol, {}, getParenRecommendation(binding.typeId, ancestry, typeCorrect)}; @@ -1259,9 +1254,9 @@ static AutocompleteContext autocompleteExpression(const SourceModule& sourceModu scope = scope->parent; } - TypeCorrectKind correctForNil = checkTypeCorrectKind(module, typeArena, node, position, typeChecker.nilType); - TypeCorrectKind correctForTrue = checkTypeCorrectKind(module, typeArena, node, position, getSingletonTypes().trueType); - TypeCorrectKind correctForFalse = checkTypeCorrectKind(module, typeArena, node, position, getSingletonTypes().falseType); + TypeCorrectKind correctForNil = checkTypeCorrectKind(module, typeArena, singletonTypes, node, position, typeChecker.nilType); + TypeCorrectKind correctForTrue = checkTypeCorrectKind(module, typeArena, singletonTypes, node, position, singletonTypes->trueType); + TypeCorrectKind correctForFalse = checkTypeCorrectKind(module, typeArena, singletonTypes, node, position, singletonTypes->falseType); TypeCorrectKind correctForFunction = functionIsExpectedAt(module, node, position).value_or(false) ? TypeCorrectKind::Correct : TypeCorrectKind::None; @@ -1396,6 +1391,8 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M if (isWithinComment(sourceModule, position)) return {}; + NotNull singletonTypes = typeChecker.singletonTypes; + std::vector ancestry = findAncestryAtPositionForAutocomplete(sourceModule, position); LUAU_ASSERT(!ancestry.empty()); AstNode* node = ancestry.back(); @@ -1422,10 +1419,11 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M PropIndexType indexType = indexName->op == ':' ? PropIndexType::Colon : PropIndexType::Point; if (!FFlag::LuauSelfCallAutocompleteFix3 && isString(ty)) - return {autocompleteProps(*module, typeArena, typeChecker.globalScope->bindings[AstName{"string"}].typeId, indexType, ancestry), ancestry, - AutocompleteContext::Property}; + return {autocompleteProps( + *module, typeArena, singletonTypes, typeChecker.globalScope->bindings[AstName{"string"}].typeId, indexType, ancestry), + ancestry, AutocompleteContext::Property}; else - return {autocompleteProps(*module, typeArena, ty, indexType, ancestry), ancestry, AutocompleteContext::Property}; + return {autocompleteProps(*module, typeArena, singletonTypes, ty, indexType, ancestry), ancestry, AutocompleteContext::Property}; } else if (auto typeReference = node->as()) { @@ -1548,7 +1546,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M { if (auto it = module->astExpectedTypes.find(exprTable)) { - auto result = autocompleteProps(*module, typeArena, *it, PropIndexType::Key, ancestry); + auto result = autocompleteProps(*module, typeArena, singletonTypes, *it, PropIndexType::Key, ancestry); // Remove keys that are already completed for (const auto& item : exprTable->items) @@ -1590,7 +1588,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M if (auto idxExpr = ancestry.at(ancestry.size() - 2)->as()) { if (auto it = module->astTypes.find(idxExpr->expr)) - autocompleteProps(*module, typeArena, follow(*it), PropIndexType::Point, ancestry, result); + autocompleteProps(*module, typeArena, singletonTypes, follow(*it), PropIndexType::Point, ancestry, result); } else if (auto binExpr = ancestry.at(ancestry.size() - 2)->as()) { diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index e011eaa55..8f4863d00 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -6,6 +6,7 @@ #include "Luau/Common.h" #include "Luau/ToString.h" #include "Luau/ConstraintSolver.h" +#include "Luau/TypeInfer.h" #include @@ -45,6 +46,11 @@ TypeId makeIntersection(TypeArena& arena, std::vector&& types) return arena.addType(IntersectionTypeVar{std::move(types)}); } +TypeId makeOption(Frontend& frontend, TypeArena& arena, TypeId t) +{ + return makeUnion(arena, {frontend.typeChecker.nilType, t}); +} + TypeId makeOption(TypeChecker& typeChecker, TypeArena& arena, TypeId t) { return makeUnion(arena, {typeChecker.nilType, t}); @@ -128,32 +134,48 @@ Property makeProperty(TypeId ty, std::optional documentationSymbol) }; } +void addGlobalBinding(Frontend& frontend, const std::string& name, TypeId ty, const std::string& packageName) +{ + addGlobalBinding(frontend, frontend.getGlobalScope(), name, ty, packageName); +} + +void addGlobalBinding(TypeChecker& typeChecker, const ScopePtr& scope, const std::string& name, TypeId ty, const std::string& packageName); + void addGlobalBinding(TypeChecker& typeChecker, const std::string& name, TypeId ty, const std::string& packageName) { addGlobalBinding(typeChecker, typeChecker.globalScope, name, ty, packageName); } +void addGlobalBinding(Frontend& frontend, const std::string& name, Binding binding) +{ + addGlobalBinding(frontend, frontend.getGlobalScope(), name, binding); +} + void addGlobalBinding(TypeChecker& typeChecker, const std::string& name, Binding binding) { addGlobalBinding(typeChecker, typeChecker.globalScope, name, binding); } +void addGlobalBinding(Frontend& frontend, const ScopePtr& scope, const std::string& name, TypeId ty, const std::string& packageName) +{ + std::string documentationSymbol = packageName + "/global/" + name; + addGlobalBinding(frontend, scope, name, Binding{ty, Location{}, {}, {}, documentationSymbol}); +} + void addGlobalBinding(TypeChecker& typeChecker, const ScopePtr& scope, const std::string& name, TypeId ty, const std::string& packageName) { std::string documentationSymbol = packageName + "/global/" + name; addGlobalBinding(typeChecker, scope, name, Binding{ty, Location{}, {}, {}, documentationSymbol}); } -void addGlobalBinding(TypeChecker& typeChecker, const ScopePtr& scope, const std::string& name, Binding binding) +void addGlobalBinding(Frontend& frontend, const ScopePtr& scope, const std::string& name, Binding binding) { - scope->bindings[typeChecker.globalNames.names->getOrAdd(name.c_str())] = binding; + addGlobalBinding(frontend.typeChecker, scope, name, binding); } -TypeId getGlobalBinding(TypeChecker& typeChecker, const std::string& name) +void addGlobalBinding(TypeChecker& typeChecker, const ScopePtr& scope, const std::string& name, Binding binding) { - auto t = tryGetGlobalBinding(typeChecker, name); - LUAU_ASSERT(t.has_value()); - return t->typeId; + scope->bindings[typeChecker.globalNames.names->getOrAdd(name.c_str())] = binding; } std::optional tryGetGlobalBinding(TypeChecker& typeChecker, const std::string& name) @@ -166,6 +188,23 @@ std::optional tryGetGlobalBinding(TypeChecker& typeChecker, const std:: return std::nullopt; } +TypeId getGlobalBinding(TypeChecker& typeChecker, const std::string& name) +{ + auto t = tryGetGlobalBinding(typeChecker, name); + LUAU_ASSERT(t.has_value()); + return t->typeId; +} + +TypeId getGlobalBinding(Frontend& frontend, const std::string& name) +{ + return getGlobalBinding(frontend.typeChecker, name); +} + +std::optional tryGetGlobalBinding(Frontend& frontend, const std::string& name) +{ + return tryGetGlobalBinding(frontend.typeChecker, name); +} + Binding* tryGetGlobalBindingRef(TypeChecker& typeChecker, const std::string& name) { AstName astName = typeChecker.globalNames.names->get(name.c_str()); @@ -195,6 +234,7 @@ void registerBuiltinTypes(TypeChecker& typeChecker) TypeId nilType = typeChecker.nilType; TypeArena& arena = typeChecker.globalTypes; + NotNull singletonTypes = typeChecker.singletonTypes; LoadDefinitionFileResult loadResult = Luau::loadDefinitionFile(typeChecker, typeChecker.globalScope, getBuiltinDefinitionSource(), "@luau"); LUAU_ASSERT(loadResult.success); @@ -203,7 +243,7 @@ void registerBuiltinTypes(TypeChecker& typeChecker) TypeId genericV = arena.addType(GenericTypeVar{"V"}); TypeId mapOfKtoV = arena.addType(TableTypeVar{{}, TableIndexer(genericK, genericV), typeChecker.globalScope->level, TableState::Generic}); - std::optional stringMetatableTy = getMetatable(getSingletonTypes().stringType); + std::optional stringMetatableTy = getMetatable(singletonTypes->stringType, singletonTypes); LUAU_ASSERT(stringMetatableTy); const TableTypeVar* stringMetatableTable = get(follow(*stringMetatableTy)); LUAU_ASSERT(stringMetatableTable); @@ -277,6 +317,98 @@ void registerBuiltinTypes(TypeChecker& typeChecker) attachDcrMagicFunction(getGlobalBinding(typeChecker, "require"), dcrMagicFunctionRequire); } +void registerBuiltinTypes(Frontend& frontend) +{ + LUAU_ASSERT(!frontend.globalTypes.typeVars.isFrozen()); + LUAU_ASSERT(!frontend.globalTypes.typePacks.isFrozen()); + + TypeId nilType = frontend.typeChecker.nilType; + + TypeArena& arena = frontend.globalTypes; + NotNull singletonTypes = frontend.singletonTypes; + + LoadDefinitionFileResult loadResult = frontend.loadDefinitionFile(getBuiltinDefinitionSource(), "@luau"); + LUAU_ASSERT(loadResult.success); + + TypeId genericK = arena.addType(GenericTypeVar{"K"}); + TypeId genericV = arena.addType(GenericTypeVar{"V"}); + TypeId mapOfKtoV = arena.addType(TableTypeVar{{}, TableIndexer(genericK, genericV), frontend.getGlobalScope()->level, TableState::Generic}); + + std::optional stringMetatableTy = getMetatable(singletonTypes->stringType, singletonTypes); + LUAU_ASSERT(stringMetatableTy); + const TableTypeVar* stringMetatableTable = get(follow(*stringMetatableTy)); + LUAU_ASSERT(stringMetatableTable); + + auto it = stringMetatableTable->props.find("__index"); + LUAU_ASSERT(it != stringMetatableTable->props.end()); + + addGlobalBinding(frontend, "string", it->second.type, "@luau"); + + // next(t: Table, i: K?) -> (K, V) + TypePackId nextArgsTypePack = arena.addTypePack(TypePack{{mapOfKtoV, makeOption(frontend, arena, genericK)}}); + addGlobalBinding(frontend, "next", + arena.addType(FunctionTypeVar{{genericK, genericV}, {}, nextArgsTypePack, arena.addTypePack(TypePack{{genericK, genericV}})}), "@luau"); + + TypePackId pairsArgsTypePack = arena.addTypePack({mapOfKtoV}); + + TypeId pairsNext = arena.addType(FunctionTypeVar{nextArgsTypePack, arena.addTypePack(TypePack{{genericK, genericV}})}); + TypePackId pairsReturnTypePack = arena.addTypePack(TypePack{{pairsNext, mapOfKtoV, nilType}}); + + // pairs(t: Table) -> ((Table, K?) -> (K, V), Table, nil) + addGlobalBinding(frontend, "pairs", arena.addType(FunctionTypeVar{{genericK, genericV}, {}, pairsArgsTypePack, pairsReturnTypePack}), "@luau"); + + TypeId genericMT = arena.addType(GenericTypeVar{"MT"}); + + TableTypeVar tab{TableState::Generic, frontend.getGlobalScope()->level}; + TypeId tabTy = arena.addType(tab); + + TypeId tableMetaMT = arena.addType(MetatableTypeVar{tabTy, genericMT}); + + addGlobalBinding(frontend, "getmetatable", makeFunction(arena, std::nullopt, {genericMT}, {}, {tableMetaMT}, {genericMT}), "@luau"); + + // clang-format off + // setmetatable(T, MT) -> { @metatable MT, T } + addGlobalBinding(frontend, "setmetatable", + arena.addType( + FunctionTypeVar{ + {genericMT}, + {}, + arena.addTypePack(TypePack{{FFlag::LuauUnknownAndNeverType ? tabTy : tableMetaMT, genericMT}}), + arena.addTypePack(TypePack{{tableMetaMT}}) + } + ), "@luau" + ); + // clang-format on + + for (const auto& pair : frontend.getGlobalScope()->bindings) + { + persist(pair.second.typeId); + + if (TableTypeVar* ttv = getMutable(pair.second.typeId)) + { + if (!ttv->name) + ttv->name = toString(pair.first); + } + } + + attachMagicFunction(getGlobalBinding(frontend, "assert"), magicFunctionAssert); + attachMagicFunction(getGlobalBinding(frontend, "setmetatable"), magicFunctionSetMetaTable); + attachMagicFunction(getGlobalBinding(frontend, "select"), magicFunctionSelect); + + if (TableTypeVar* ttv = getMutable(getGlobalBinding(frontend, "table"))) + { + // tabTy is a generic table type which we can't express via declaration syntax yet + ttv->props["freeze"] = makeProperty(makeFunction(arena, std::nullopt, {tabTy}, {tabTy}), "@luau/global/table.freeze"); + ttv->props["clone"] = makeProperty(makeFunction(arena, std::nullopt, {tabTy}, {tabTy}), "@luau/global/table.clone"); + + attachMagicFunction(ttv->props["pack"].type, magicFunctionPack); + } + + attachMagicFunction(getGlobalBinding(frontend, "require"), magicFunctionRequire); + attachDcrMagicFunction(getGlobalBinding(frontend, "require"), dcrMagicFunctionRequire); +} + + static std::optional> magicFunctionSelect( TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) { diff --git a/Analysis/src/ConstraintGraphBuilder.cpp b/Analysis/src/ConstraintGraphBuilder.cpp index 9fabc528d..e9c61e412 100644 --- a/Analysis/src/ConstraintGraphBuilder.cpp +++ b/Analysis/src/ConstraintGraphBuilder.cpp @@ -7,8 +7,10 @@ #include "Luau/ModuleResolver.h" #include "Luau/RecursionCounter.h" #include "Luau/ToString.h" +#include "Luau/DcrLogger.h" LUAU_FASTINT(LuauCheckRecursionLimit); +LUAU_FASTFLAG(DebugLuauLogSolverToJson); #include "Luau/Scope.h" @@ -35,17 +37,20 @@ static std::optional matchRequire(const AstExprCall& call) } ConstraintGraphBuilder::ConstraintGraphBuilder(const ModuleName& moduleName, ModulePtr module, TypeArena* arena, - NotNull moduleResolver, NotNull ice, const ScopePtr& globalScope) + NotNull moduleResolver, NotNull singletonTypes, NotNull ice, const ScopePtr& globalScope, DcrLogger* logger) : moduleName(moduleName) , module(module) - , singletonTypes(getSingletonTypes()) + , singletonTypes(singletonTypes) , arena(arena) , rootScope(nullptr) , moduleResolver(moduleResolver) , ice(ice) , globalScope(globalScope) + , logger(logger) { - LUAU_ASSERT(arena); + if (FFlag::DebugLuauLogSolverToJson) + LUAU_ASSERT(logger); + LUAU_ASSERT(module); } @@ -66,6 +71,7 @@ ScopePtr ConstraintGraphBuilder::childScope(AstNode* node, const ScopePtr& paren scopes.emplace_back(node->location, scope); scope->returnType = parent->returnType; + scope->varargPack = parent->varargPack; parent->children.push_back(NotNull{scope.get()}); module->astScopes[node] = scope.get(); @@ -282,7 +288,7 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFor* for_) return; TypeId t = check(scope, expr); - addConstraint(scope, expr->location, SubtypeConstraint{t, singletonTypes.numberType}); + addConstraint(scope, expr->location, SubtypeConstraint{t, singletonTypes->numberType}); }; checkNumber(for_->from); @@ -290,7 +296,7 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFor* for_) checkNumber(for_->step); ScopePtr forScope = childScope(for_, scope); - forScope->bindings[for_->var] = Binding{singletonTypes.numberType, for_->var->location}; + forScope->bindings[for_->var] = Binding{singletonTypes->numberType, for_->var->location}; visit(forScope, for_->body); } @@ -435,7 +441,7 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFunction* funct } else if (AstExprError* err = function->name->as()) { - functionType = singletonTypes.errorRecoveryType(); + functionType = singletonTypes->errorRecoveryType(); } LUAU_ASSERT(functionType != nullptr); @@ -657,12 +663,18 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareFunction std::vector genericTys; genericTys.reserve(generics.size()); for (auto& [name, generic] : generics) + { genericTys.push_back(generic.ty); + scope->privateTypeBindings[name] = TypeFun{generic.ty}; + } std::vector genericTps; genericTps.reserve(genericPacks.size()); for (auto& [name, generic] : genericPacks) + { genericTps.push_back(generic.tp); + scope->privateTypePackBindings[name] = generic.tp; + } ScopePtr funScope = scope; if (!generics.empty() || !genericPacks.empty()) @@ -710,7 +722,7 @@ TypePackId ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExpr* exp if (recursionCount >= FInt::LuauCheckRecursionLimit) { reportCodeTooComplex(expr->location); - return singletonTypes.errorRecoveryTypePack(); + return singletonTypes->errorRecoveryTypePack(); } TypePackId result = nullptr; @@ -758,7 +770,7 @@ TypePackId ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExpr* exp if (scope->varargPack) result = *scope->varargPack; else - result = singletonTypes.errorRecoveryTypePack(); + result = singletonTypes->errorRecoveryTypePack(); } else { @@ -778,7 +790,7 @@ TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExpr* expr) if (recursionCount >= FInt::LuauCheckRecursionLimit) { reportCodeTooComplex(expr->location); - return singletonTypes.errorRecoveryType(); + return singletonTypes->errorRecoveryType(); } TypeId result = nullptr; @@ -786,20 +798,20 @@ TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExpr* expr) if (auto group = expr->as()) result = check(scope, group->expr); else if (expr->is()) - result = singletonTypes.stringType; + result = singletonTypes->stringType; else if (expr->is()) - result = singletonTypes.numberType; + result = singletonTypes->numberType; else if (expr->is()) - result = singletonTypes.booleanType; + result = singletonTypes->booleanType; else if (expr->is()) - result = singletonTypes.nilType; + result = singletonTypes->nilType; else if (auto a = expr->as()) { std::optional ty = scope->lookup(a->local); if (ty) result = *ty; else - result = singletonTypes.errorRecoveryType(); // FIXME? Record an error at this point? + result = singletonTypes->errorRecoveryType(); // FIXME? Record an error at this point? } else if (auto g = expr->as()) { @@ -812,7 +824,7 @@ TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExpr* expr) * global that is not already in-scope is definitely an unknown symbol. */ reportError(g->location, UnknownSymbol{g->name.value}); - result = singletonTypes.errorRecoveryType(); // FIXME? Record an error at this point? + result = singletonTypes->errorRecoveryType(); // FIXME? Record an error at this point? } } else if (expr->is()) @@ -842,7 +854,7 @@ TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExpr* expr) else if (auto err = expr->as()) { // Open question: Should we traverse into this? - result = singletonTypes.errorRecoveryType(); + result = singletonTypes->errorRecoveryType(); } else { @@ -903,7 +915,7 @@ TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprUnary* unary) } LUAU_UNREACHABLE(); - return singletonTypes.errorRecoveryType(); + return singletonTypes->errorRecoveryType(); } TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprBinary* binary) @@ -1003,7 +1015,7 @@ TypeId ConstraintGraphBuilder::checkExprTable(const ScopePtr& scope, AstExprTabl } else { - TypeId numberType = singletonTypes.numberType; + TypeId numberType = singletonTypes->numberType; // FIXME? The location isn't quite right here. Not sure what is // right. createIndexer(item.value->location, numberType, itemTy); @@ -1068,6 +1080,23 @@ ConstraintGraphBuilder::FunctionSignature ConstraintGraphBuilder::checkFunctionS signatureScope = bodyScope; } + std::optional varargPack; + + if (fn->vararg) + { + if (fn->varargAnnotation) + { + TypePackId annotationType = resolveTypePack(signatureScope, fn->varargAnnotation); + varargPack = annotationType; + } + else + { + varargPack = arena->freshTypePack(signatureScope.get()); + } + + signatureScope->varargPack = varargPack; + } + if (fn->returnAnnotation) { TypePackId annotatedRetType = resolveTypePack(signatureScope, *fn->returnAnnotation); @@ -1092,7 +1121,7 @@ ConstraintGraphBuilder::FunctionSignature ConstraintGraphBuilder::checkFunctionS // TODO: Vararg annotation. // TODO: Preserve argument names in the function's type. - FunctionTypeVar actualFunction{arena->addTypePack(argTypes), returnType}; + FunctionTypeVar actualFunction{arena->addTypePack(argTypes, varargPack), returnType}; actualFunction.hasNoGenerics = !hasGenerics; actualFunction.generics = std::move(genericTypes); actualFunction.genericPacks = std::move(genericTypePacks); @@ -1175,7 +1204,7 @@ TypeId ConstraintGraphBuilder::resolveType(const ScopePtr& scope, AstType* ty, b else { reportError(ty->location, UnknownSymbol{ref->name.value, UnknownSymbol::Context::Type}); - result = singletonTypes.errorRecoveryType(); + result = singletonTypes->errorRecoveryType(); } } else if (auto tab = ty->as()) @@ -1308,12 +1337,12 @@ TypeId ConstraintGraphBuilder::resolveType(const ScopePtr& scope, AstType* ty, b } else if (ty->is()) { - result = singletonTypes.errorRecoveryType(); + result = singletonTypes->errorRecoveryType(); } else { LUAU_ASSERT(0); - result = singletonTypes.errorRecoveryType(); + result = singletonTypes->errorRecoveryType(); } astResolvedTypes[ty] = result; @@ -1341,13 +1370,13 @@ TypePackId ConstraintGraphBuilder::resolveTypePack(const ScopePtr& scope, AstTyp else { reportError(tp->location, UnknownSymbol{gen->genericName.value, UnknownSymbol::Context::Type}); - result = singletonTypes.errorRecoveryTypePack(); + result = singletonTypes->errorRecoveryTypePack(); } } else { LUAU_ASSERT(0); - result = singletonTypes.errorRecoveryTypePack(); + result = singletonTypes->errorRecoveryTypePack(); } astResolvedTypePacks[tp] = result; @@ -1430,11 +1459,17 @@ TypeId ConstraintGraphBuilder::flattenPack(const ScopePtr& scope, Location locat void ConstraintGraphBuilder::reportError(Location location, TypeErrorData err) { errors.push_back(TypeError{location, moduleName, std::move(err)}); + + if (FFlag::DebugLuauLogSolverToJson) + logger->captureGenerationError(errors.back()); } void ConstraintGraphBuilder::reportCodeTooComplex(Location location) { errors.push_back(TypeError{location, moduleName, CodeTooComplex{}}); + + if (FFlag::DebugLuauLogSolverToJson) + logger->captureGenerationError(errors.back()); } struct GlobalPrepopulator : AstVisitor diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index 1088d9824..f964a855a 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -9,6 +9,7 @@ #include "Luau/Quantify.h" #include "Luau/ToString.h" #include "Luau/Unifier.h" +#include "Luau/DcrLogger.h" #include "Luau/VisitTypeVar.h" LUAU_FASTFLAGVARIABLE(DebugLuauLogSolver, false); @@ -50,8 +51,8 @@ static void dumpConstraints(NotNull scope, ToStringOptions& opts) dumpConstraints(child, opts); } -static std::pair, std::vector> saturateArguments( - const TypeFun& fn, const std::vector& rawTypeArguments, const std::vector& rawPackArguments, TypeArena* arena) +static std::pair, std::vector> saturateArguments(TypeArena* arena, NotNull singletonTypes, + const TypeFun& fn, const std::vector& rawTypeArguments, const std::vector& rawPackArguments) { std::vector saturatedTypeArguments; std::vector extraTypes; @@ -131,7 +132,7 @@ static std::pair, std::vector> saturateArguments if (!defaultTy) break; - TypeId instantiatedDefault = atf.substitute(defaultTy).value_or(getSingletonTypes().errorRecoveryType()); + TypeId instantiatedDefault = atf.substitute(defaultTy).value_or(singletonTypes->errorRecoveryType()); atf.typeArguments[fn.typeParams[i].ty] = instantiatedDefault; saturatedTypeArguments.push_back(instantiatedDefault); } @@ -149,7 +150,7 @@ static std::pair, std::vector> saturateArguments if (!defaultTp) break; - TypePackId instantiatedDefault = atf.substitute(defaultTp).value_or(getSingletonTypes().errorRecoveryTypePack()); + TypePackId instantiatedDefault = atf.substitute(defaultTp).value_or(singletonTypes->errorRecoveryTypePack()); atf.typePackArguments[fn.typePackParams[i].tp] = instantiatedDefault; saturatedPackArguments.push_back(instantiatedDefault); } @@ -167,12 +168,12 @@ static std::pair, std::vector> saturateArguments // even if they're missing, so we use the error type as a filler. for (size_t i = saturatedTypeArguments.size(); i < typesRequired; ++i) { - saturatedTypeArguments.push_back(getSingletonTypes().errorRecoveryType()); + saturatedTypeArguments.push_back(singletonTypes->errorRecoveryType()); } for (size_t i = saturatedPackArguments.size(); i < packsRequired; ++i) { - saturatedPackArguments.push_back(getSingletonTypes().errorRecoveryTypePack()); + saturatedPackArguments.push_back(singletonTypes->errorRecoveryTypePack()); } // At this point, these two conditions should be true. If they aren't we @@ -242,14 +243,16 @@ void dump(ConstraintSolver* cs, ToStringOptions& opts) } } -ConstraintSolver::ConstraintSolver(TypeArena* arena, NotNull rootScope, ModuleName moduleName, NotNull moduleResolver, - std::vector requireCycles) +ConstraintSolver::ConstraintSolver(TypeArena* arena, NotNull singletonTypes, NotNull rootScope, ModuleName moduleName, + NotNull moduleResolver, std::vector requireCycles, DcrLogger* logger) : arena(arena) + , singletonTypes(singletonTypes) , constraints(collectConstraints(rootScope)) , rootScope(rootScope) , currentModuleName(std::move(moduleName)) , moduleResolver(moduleResolver) , requireCycles(requireCycles) + , logger(logger) { opts.exhaustive = true; @@ -262,6 +265,9 @@ ConstraintSolver::ConstraintSolver(TypeArena* arena, NotNull rootScope, M block(dep, c); } } + + if (FFlag::DebugLuauLogSolverToJson) + LUAU_ASSERT(logger); } void ConstraintSolver::run() @@ -277,7 +283,7 @@ void ConstraintSolver::run() if (FFlag::DebugLuauLogSolverToJson) { - logger.captureBoundarySnapshot(rootScope, unsolvedConstraints); + logger->captureInitialSolverState(rootScope, unsolvedConstraints); } auto runSolverPass = [&](bool force) { @@ -294,10 +300,11 @@ void ConstraintSolver::run() } std::string saveMe = FFlag::DebugLuauLogSolver ? toString(*c, opts) : std::string{}; + StepSnapshot snapshot; if (FFlag::DebugLuauLogSolverToJson) { - logger.prepareStepSnapshot(rootScope, c, unsolvedConstraints, force); + snapshot = logger->prepareStepSnapshot(rootScope, c, force, unsolvedConstraints); } bool success = tryDispatch(c, force); @@ -311,7 +318,7 @@ void ConstraintSolver::run() if (FFlag::DebugLuauLogSolverToJson) { - logger.commitPreparedStepSnapshot(); + logger->commitStepSnapshot(snapshot); } if (FFlag::DebugLuauLogSolver) @@ -347,8 +354,7 @@ void ConstraintSolver::run() if (FFlag::DebugLuauLogSolverToJson) { - logger.captureBoundarySnapshot(rootScope, unsolvedConstraints); - printf("Logger output:\n%s\n", logger.compileOutput().c_str()); + logger->captureFinalSolverState(rootScope, unsolvedConstraints); } } @@ -516,7 +522,7 @@ bool ConstraintSolver::tryDispatch(const BinaryConstraint& c, NotNullty.emplace(getSingletonTypes().errorRecoveryType()); + asMutable(resultType)->ty.emplace(singletonTypes->errorRecoveryType()); // reportError(constraint->location, CannotInferBinaryOperation{c.op, std::nullopt, CannotInferBinaryOperation::Operation}); return true; } @@ -571,7 +577,7 @@ bool ConstraintSolver::tryDispatch(const IterableConstraint& c, NotNullscope, &iceReporter, getSingletonTypes().errorRecoveryType(), getSingletonTypes().errorRecoveryTypePack()}; + arena, constraint->scope, singletonTypes, &iceReporter, singletonTypes->errorRecoveryType(), singletonTypes->errorRecoveryTypePack()}; std::optional anyified = anyify.substitute(c.variables); LUAU_ASSERT(anyified); unify(*anyified, c.variables, constraint->scope); @@ -585,11 +591,11 @@ bool ConstraintSolver::tryDispatch(const IterableConstraint& c, NotNull(nextTy)) { - TypeId tableTy = getSingletonTypes().nilType; + TypeId tableTy = singletonTypes->nilType; if (iteratorTypes.size() >= 2) tableTy = iteratorTypes[1]; - TypeId firstIndexTy = getSingletonTypes().nilType; + TypeId firstIndexTy = singletonTypes->nilType; if (iteratorTypes.size() >= 3) firstIndexTy = iteratorTypes[2]; @@ -644,7 +650,7 @@ struct InfiniteTypeFinder : TypeVarOnceVisitor if (!tf.has_value()) return true; - auto [typeArguments, packArguments] = saturateArguments(*tf, petv.typeArguments, petv.packArguments, solver->arena); + auto [typeArguments, packArguments] = saturateArguments(solver->arena, solver->singletonTypes, *tf, petv.typeArguments, petv.packArguments); if (follow(tf->type) == follow(signature.fn.type) && (signature.arguments != typeArguments || signature.packArguments != packArguments)) { @@ -698,7 +704,7 @@ bool ConstraintSolver::tryDispatch(const TypeAliasExpansionConstraint& c, NotNul if (!tf.has_value()) { reportError(UnknownSymbol{petv->name.value, UnknownSymbol::Context::Type}, constraint->location); - bindResult(getSingletonTypes().errorRecoveryType()); + bindResult(singletonTypes->errorRecoveryType()); return true; } @@ -710,7 +716,7 @@ bool ConstraintSolver::tryDispatch(const TypeAliasExpansionConstraint& c, NotNul return true; } - auto [typeArguments, packArguments] = saturateArguments(*tf, petv->typeArguments, petv->packArguments, arena); + auto [typeArguments, packArguments] = saturateArguments(arena, singletonTypes, *tf, petv->typeArguments, petv->packArguments); bool sameTypes = std::equal(typeArguments.begin(), typeArguments.end(), tf->typeParams.begin(), tf->typeParams.end(), [](auto&& itp, auto&& p) { return itp == p.ty; @@ -757,7 +763,7 @@ bool ConstraintSolver::tryDispatch(const TypeAliasExpansionConstraint& c, NotNul if (itf.foundInfiniteType) { // TODO (CLI-56761): Report an error. - bindResult(getSingletonTypes().errorRecoveryType()); + bindResult(singletonTypes->errorRecoveryType()); return true; } @@ -780,7 +786,7 @@ bool ConstraintSolver::tryDispatch(const TypeAliasExpansionConstraint& c, NotNul if (!maybeInstantiated.has_value()) { // TODO (CLI-56761): Report an error. - bindResult(getSingletonTypes().errorRecoveryType()); + bindResult(singletonTypes->errorRecoveryType()); return true; } @@ -894,7 +900,7 @@ bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const Iterabl return block_(iteratorTy); auto anyify = [&](auto ty) { - Anyification anyify{arena, constraint->scope, &iceReporter, getSingletonTypes().anyType, getSingletonTypes().anyTypePack}; + Anyification anyify{arena, constraint->scope, singletonTypes, &iceReporter, singletonTypes->anyType, singletonTypes->anyTypePack}; std::optional anyified = anyify.substitute(ty); if (!anyified) reportError(CodeTooComplex{}, constraint->location); @@ -904,7 +910,7 @@ bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const Iterabl auto errorify = [&](auto ty) { Anyification anyify{ - arena, constraint->scope, &iceReporter, getSingletonTypes().errorRecoveryType(), getSingletonTypes().errorRecoveryTypePack()}; + arena, constraint->scope, singletonTypes, &iceReporter, singletonTypes->errorRecoveryType(), singletonTypes->errorRecoveryTypePack()}; std::optional errorified = anyify.substitute(ty); if (!errorified) reportError(CodeTooComplex{}, constraint->location); @@ -973,7 +979,7 @@ bool ConstraintSolver::tryDispatchIterableFunction( : firstIndexTy; // nextTy : (tableTy, indexTy?) -> (indexTy, valueTailTy...) - const TypePackId nextArgPack = arena->addTypePack({tableTy, arena->addType(UnionTypeVar{{firstIndex, getSingletonTypes().nilType}})}); + const TypePackId nextArgPack = arena->addTypePack({tableTy, arena->addType(UnionTypeVar{{firstIndex, singletonTypes->nilType}})}); const TypePackId valueTailTy = arena->addTypePack(FreeTypePack{constraint->scope}); const TypePackId nextRetPack = arena->addTypePack(TypePack{{firstIndex}, valueTailTy}); @@ -995,23 +1001,35 @@ void ConstraintSolver::block_(BlockedConstraintId target, NotNull target, NotNull constraint) { + if (FFlag::DebugLuauLogSolverToJson) + logger->pushBlock(constraint, target); + if (FFlag::DebugLuauLogSolver) printf("block Constraint %s on\t%s\n", toString(*target, opts).c_str(), toString(*constraint, opts).c_str()); + block_(target, constraint); } bool ConstraintSolver::block(TypeId target, NotNull constraint) { + if (FFlag::DebugLuauLogSolverToJson) + logger->pushBlock(constraint, target); + if (FFlag::DebugLuauLogSolver) printf("block TypeId %s on\t%s\n", toString(target, opts).c_str(), toString(*constraint, opts).c_str()); + block_(target, constraint); return false; } bool ConstraintSolver::block(TypePackId target, NotNull constraint) { + if (FFlag::DebugLuauLogSolverToJson) + logger->pushBlock(constraint, target); + if (FFlag::DebugLuauLogSolver) printf("block TypeId %s on\t%s\n", toString(target, opts).c_str(), toString(*constraint, opts).c_str()); + block_(target, constraint); return false; } @@ -1042,16 +1060,25 @@ void ConstraintSolver::unblock_(BlockedConstraintId progressed) void ConstraintSolver::unblock(NotNull progressed) { + if (FFlag::DebugLuauLogSolverToJson) + logger->popBlock(progressed); + return unblock_(progressed); } void ConstraintSolver::unblock(TypeId progressed) { + if (FFlag::DebugLuauLogSolverToJson) + logger->popBlock(progressed); + return unblock_(progressed); } void ConstraintSolver::unblock(TypePackId progressed) { + if (FFlag::DebugLuauLogSolverToJson) + logger->popBlock(progressed); + return unblock_(progressed); } @@ -1086,13 +1113,13 @@ bool ConstraintSolver::isBlocked(NotNull constraint) void ConstraintSolver::unify(TypeId subType, TypeId superType, NotNull scope) { UnifierSharedState sharedState{&iceReporter}; - Unifier u{arena, Mode::Strict, scope, Location{}, Covariant, sharedState}; + Unifier u{arena, singletonTypes, Mode::Strict, scope, Location{}, Covariant, sharedState}; u.tryUnify(subType, superType); if (!u.errors.empty()) { - TypeId errorType = getSingletonTypes().errorRecoveryType(); + TypeId errorType = singletonTypes->errorRecoveryType(); u.tryUnify(subType, errorType); u.tryUnify(superType, errorType); } @@ -1108,7 +1135,7 @@ void ConstraintSolver::unify(TypeId subType, TypeId superType, NotNull sc void ConstraintSolver::unify(TypePackId subPack, TypePackId superPack, NotNull scope) { UnifierSharedState sharedState{&iceReporter}; - Unifier u{arena, Mode::Strict, scope, Location{}, Covariant, sharedState}; + Unifier u{arena, singletonTypes, Mode::Strict, scope, Location{}, Covariant, sharedState}; u.tryUnify(subPack, superPack); @@ -1133,7 +1160,7 @@ TypeId ConstraintSolver::resolveModule(const ModuleInfo& info, const Location& l if (info.name.empty()) { reportError(UnknownRequire{}, location); - return getSingletonTypes().errorRecoveryType(); + return singletonTypes->errorRecoveryType(); } std::string humanReadableName = moduleResolver->getHumanReadableModuleName(info.name); @@ -1141,7 +1168,7 @@ TypeId ConstraintSolver::resolveModule(const ModuleInfo& info, const Location& l for (const auto& [location, path] : requireCycles) { if (!path.empty() && path.front() == humanReadableName) - return getSingletonTypes().anyType; + return singletonTypes->anyType; } ModulePtr module = moduleResolver->getModule(info.name); @@ -1150,24 +1177,24 @@ TypeId ConstraintSolver::resolveModule(const ModuleInfo& info, const Location& l if (!moduleResolver->moduleExists(info.name) && !info.optional) reportError(UnknownRequire{humanReadableName}, location); - return getSingletonTypes().errorRecoveryType(); + return singletonTypes->errorRecoveryType(); } if (module->type != SourceCode::Type::Module) { reportError(IllegalRequire{humanReadableName, "Module is not a ModuleScript. It cannot be required."}, location); - return getSingletonTypes().errorRecoveryType(); + return singletonTypes->errorRecoveryType(); } TypePackId modulePack = module->getModuleScope()->returnType; if (get(modulePack)) - return getSingletonTypes().errorRecoveryType(); + return singletonTypes->errorRecoveryType(); std::optional moduleType = first(modulePack); if (!moduleType) { reportError(IllegalRequire{humanReadableName, "Module does not return exactly 1 value. It cannot be required."}, location); - return getSingletonTypes().errorRecoveryType(); + return singletonTypes->errorRecoveryType(); } return *moduleType; diff --git a/Analysis/src/ConstraintSolverLogger.cpp b/Analysis/src/ConstraintSolverLogger.cpp deleted file mode 100644 index 5ba405216..000000000 --- a/Analysis/src/ConstraintSolverLogger.cpp +++ /dev/null @@ -1,150 +0,0 @@ -// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details - -#include "Luau/ConstraintSolverLogger.h" - -#include "Luau/JsonEmitter.h" -#include "Luau/ToString.h" - -LUAU_FASTFLAG(LuauFixNameMaps); - -namespace Luau -{ - -static void dumpScopeAndChildren(const Scope* scope, Json::JsonEmitter& emitter, ToStringOptions& opts) -{ - emitter.writeRaw("{"); - Json::write(emitter, "bindings"); - emitter.writeRaw(":"); - - Json::ObjectEmitter o = emitter.writeObject(); - - for (const auto& [name, binding] : scope->bindings) - { - if (FFlag::LuauFixNameMaps) - o.writePair(name.c_str(), toString(binding.typeId, opts)); - else - { - ToStringResult result = toStringDetailed(binding.typeId, opts); - opts.DEPRECATED_nameMap = std::move(result.DEPRECATED_nameMap); - o.writePair(name.c_str(), result.name); - } - } - - o.finish(); - emitter.writeRaw(","); - Json::write(emitter, "children"); - emitter.writeRaw(":"); - - Json::ArrayEmitter a = emitter.writeArray(); - for (const Scope* child : scope->children) - { - emitter.writeComma(); - dumpScopeAndChildren(child, emitter, opts); - } - - a.finish(); - emitter.writeRaw("}"); -} - -static std::string dumpConstraintsToDot(std::vector>& constraints, ToStringOptions& opts) -{ - std::string result = "digraph Constraints {\n"; - result += "rankdir=LR\n"; - - std::unordered_set> contained; - for (NotNull c : constraints) - { - contained.insert(c); - } - - for (NotNull c : constraints) - { - std::string shape; - if (get(*c)) - shape = "box"; - else if (get(*c)) - shape = "box3d"; - else - shape = "oval"; - - std::string id = std::to_string(reinterpret_cast(c.get())); - result += id; - result += " [label=\""; - result += toString(*c, opts); - result += "\" shape=" + shape + "];\n"; - - for (NotNull dep : c->dependencies) - { - if (contained.count(dep) == 0) - continue; - - result += std::to_string(reinterpret_cast(dep.get())); - result += " -> "; - result += id; - result += ";\n"; - } - } - - result += "}"; - - return result; -} - -std::string ConstraintSolverLogger::compileOutput() -{ - Json::JsonEmitter emitter; - emitter.writeRaw("["); - for (const std::string& snapshot : snapshots) - { - emitter.writeComma(); - emitter.writeRaw(snapshot); - } - - emitter.writeRaw("]"); - return emitter.str(); -} - -void ConstraintSolverLogger::captureBoundarySnapshot(const Scope* rootScope, std::vector>& unsolvedConstraints) -{ - Json::JsonEmitter emitter; - Json::ObjectEmitter o = emitter.writeObject(); - o.writePair("type", "boundary"); - o.writePair("constraintGraph", dumpConstraintsToDot(unsolvedConstraints, opts)); - emitter.writeComma(); - Json::write(emitter, "rootScope"); - emitter.writeRaw(":"); - dumpScopeAndChildren(rootScope, emitter, opts); - o.finish(); - - snapshots.push_back(emitter.str()); -} - -void ConstraintSolverLogger::prepareStepSnapshot( - const Scope* rootScope, NotNull current, std::vector>& unsolvedConstraints, bool force) -{ - Json::JsonEmitter emitter; - Json::ObjectEmitter o = emitter.writeObject(); - o.writePair("type", "step"); - o.writePair("constraintGraph", dumpConstraintsToDot(unsolvedConstraints, opts)); - o.writePair("currentId", std::to_string(reinterpret_cast(current.get()))); - o.writePair("current", toString(*current, opts)); - o.writePair("force", force); - emitter.writeComma(); - Json::write(emitter, "rootScope"); - emitter.writeRaw(":"); - dumpScopeAndChildren(rootScope, emitter, opts); - o.finish(); - - preparedSnapshot = emitter.str(); -} - -void ConstraintSolverLogger::commitPreparedStepSnapshot() -{ - if (preparedSnapshot) - { - snapshots.push_back(std::move(*preparedSnapshot)); - preparedSnapshot = std::nullopt; - } -} - -} // namespace Luau diff --git a/Analysis/src/DcrLogger.cpp b/Analysis/src/DcrLogger.cpp new file mode 100644 index 000000000..a2eb96e5c --- /dev/null +++ b/Analysis/src/DcrLogger.cpp @@ -0,0 +1,395 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/DcrLogger.h" + +#include + +#include "Luau/JsonEmitter.h" + +namespace Luau +{ + +namespace Json +{ + +void write(JsonEmitter& emitter, const Location& location) +{ + ObjectEmitter o = emitter.writeObject(); + o.writePair("beginLine", location.begin.line); + o.writePair("beginColumn", location.begin.column); + o.writePair("endLine", location.end.line); + o.writePair("endColumn", location.end.column); + o.finish(); +} + +void write(JsonEmitter& emitter, const ErrorSnapshot& snapshot) +{ + ObjectEmitter o = emitter.writeObject(); + o.writePair("message", snapshot.message); + o.writePair("location", snapshot.location); + o.finish(); +} + +void write(JsonEmitter& emitter, const BindingSnapshot& snapshot) +{ + ObjectEmitter o = emitter.writeObject(); + o.writePair("typeId", snapshot.typeId); + o.writePair("typeString", snapshot.typeString); + o.writePair("location", snapshot.location); + o.finish(); +} + +void write(JsonEmitter& emitter, const TypeBindingSnapshot& snapshot) +{ + ObjectEmitter o = emitter.writeObject(); + o.writePair("typeId", snapshot.typeId); + o.writePair("typeString", snapshot.typeString); + o.finish(); +} + +void write(JsonEmitter& emitter, const ConstraintGenerationLog& log) +{ + ObjectEmitter o = emitter.writeObject(); + o.writePair("source", log.source); + + emitter.writeComma(); + write(emitter, "constraintLocations"); + emitter.writeRaw(":"); + + ObjectEmitter locationEmitter = emitter.writeObject(); + + for (const auto& [id, location] : log.constraintLocations) + { + locationEmitter.writePair(id, location); + } + + locationEmitter.finish(); + o.writePair("errors", log.errors); + o.finish(); +} + +void write(JsonEmitter& emitter, const ScopeSnapshot& snapshot) +{ + ObjectEmitter o = emitter.writeObject(); + o.writePair("bindings", snapshot.bindings); + o.writePair("typeBindings", snapshot.typeBindings); + o.writePair("typePackBindings", snapshot.typePackBindings); + o.writePair("children", snapshot.children); + o.finish(); +} + +void write(JsonEmitter& emitter, const ConstraintBlockKind& kind) +{ + switch (kind) + { + case ConstraintBlockKind::TypeId: + return write(emitter, "type"); + case ConstraintBlockKind::TypePackId: + return write(emitter, "typePack"); + case ConstraintBlockKind::ConstraintId: + return write(emitter, "constraint"); + default: + LUAU_ASSERT(0); + } +} + +void write(JsonEmitter& emitter, const ConstraintBlock& block) +{ + ObjectEmitter o = emitter.writeObject(); + o.writePair("kind", block.kind); + o.writePair("stringification", block.stringification); + o.finish(); +} + +void write(JsonEmitter& emitter, const ConstraintSnapshot& snapshot) +{ + ObjectEmitter o = emitter.writeObject(); + o.writePair("stringification", snapshot.stringification); + o.writePair("blocks", snapshot.blocks); + o.finish(); +} + +void write(JsonEmitter& emitter, const BoundarySnapshot& snapshot) +{ + ObjectEmitter o = emitter.writeObject(); + o.writePair("rootScope", snapshot.rootScope); + o.writePair("constraints", snapshot.constraints); + o.finish(); +} + +void write(JsonEmitter& emitter, const StepSnapshot& snapshot) +{ + ObjectEmitter o = emitter.writeObject(); + o.writePair("currentConstraint", snapshot.currentConstraint); + o.writePair("forced", snapshot.forced); + o.writePair("unsolvedConstraints", snapshot.unsolvedConstraints); + o.writePair("rootScope", snapshot.rootScope); + o.finish(); +} + +void write(JsonEmitter& emitter, const TypeSolveLog& log) +{ + ObjectEmitter o = emitter.writeObject(); + o.writePair("initialState", log.initialState); + o.writePair("stepStates", log.stepStates); + o.writePair("finalState", log.finalState); + o.finish(); +} + +void write(JsonEmitter& emitter, const TypeCheckLog& log) +{ + ObjectEmitter o = emitter.writeObject(); + o.writePair("errors", log.errors); + o.finish(); +} + +} // namespace Json + +static std::string toPointerId(NotNull ptr) +{ + return std::to_string(reinterpret_cast(ptr.get())); +} + +static ScopeSnapshot snapshotScope(const Scope* scope, ToStringOptions& opts) +{ + std::unordered_map bindings; + std::unordered_map typeBindings; + std::unordered_map typePackBindings; + std::vector children; + + for (const auto& [name, binding] : scope->bindings) + { + std::string id = std::to_string(reinterpret_cast(binding.typeId)); + ToStringResult result = toStringDetailed(binding.typeId, opts); + + bindings[name.c_str()] = BindingSnapshot{ + id, + result.name, + binding.location, + }; + } + + for (const auto& [name, tf] : scope->exportedTypeBindings) + { + std::string id = std::to_string(reinterpret_cast(tf.type)); + + typeBindings[name] = TypeBindingSnapshot{ + id, + toString(tf.type, opts), + }; + } + + for (const auto& [name, tf] : scope->privateTypeBindings) + { + std::string id = std::to_string(reinterpret_cast(tf.type)); + + typeBindings[name] = TypeBindingSnapshot{ + id, + toString(tf.type, opts), + }; + } + + for (const auto& [name, tp] : scope->privateTypePackBindings) + { + std::string id = std::to_string(reinterpret_cast(tp)); + + typePackBindings[name] = TypeBindingSnapshot{ + id, + toString(tp, opts), + }; + } + + for (const auto& child : scope->children) + { + children.push_back(snapshotScope(child.get(), opts)); + } + + return ScopeSnapshot{ + bindings, + typeBindings, + typePackBindings, + children, + }; +} + +std::string DcrLogger::compileOutput() +{ + Json::JsonEmitter emitter; + Json::ObjectEmitter o = emitter.writeObject(); + o.writePair("generation", generationLog); + o.writePair("solve", solveLog); + o.writePair("check", checkLog); + o.finish(); + + return emitter.str(); +} + +void DcrLogger::captureSource(std::string source) +{ + generationLog.source = std::move(source); +} + +void DcrLogger::captureGenerationError(const TypeError& error) +{ + std::string stringifiedError = toString(error); + generationLog.errors.push_back(ErrorSnapshot { + /* message */ stringifiedError, + /* location */ error.location, + }); +} + +void DcrLogger::captureConstraintLocation(NotNull constraint, Location location) +{ + std::string id = toPointerId(constraint); + generationLog.constraintLocations[id] = location; +} + +void DcrLogger::pushBlock(NotNull constraint, TypeId block) +{ + constraintBlocks[constraint].push_back(block); +} + +void DcrLogger::pushBlock(NotNull constraint, TypePackId block) +{ + constraintBlocks[constraint].push_back(block); +} + +void DcrLogger::pushBlock(NotNull constraint, NotNull block) +{ + constraintBlocks[constraint].push_back(block); +} + +void DcrLogger::popBlock(TypeId block) +{ + for (auto& [_, list] : constraintBlocks) + { + list.erase(std::remove(list.begin(), list.end(), block), list.end()); + } +} + +void DcrLogger::popBlock(TypePackId block) +{ + for (auto& [_, list] : constraintBlocks) + { + list.erase(std::remove(list.begin(), list.end(), block), list.end()); + } +} + +void DcrLogger::popBlock(NotNull block) +{ + for (auto& [_, list] : constraintBlocks) + { + list.erase(std::remove(list.begin(), list.end(), block), list.end()); + } +} + +void DcrLogger::captureInitialSolverState(const Scope* rootScope, const std::vector>& unsolvedConstraints) +{ + solveLog.initialState.rootScope = snapshotScope(rootScope, opts); + solveLog.initialState.constraints.clear(); + + for (NotNull c : unsolvedConstraints) + { + std::string id = toPointerId(c); + solveLog.initialState.constraints[id] = { + toString(*c.get(), opts), + snapshotBlocks(c), + }; + } +} + +StepSnapshot DcrLogger::prepareStepSnapshot(const Scope* rootScope, NotNull current, bool force, const std::vector>& unsolvedConstraints) +{ + ScopeSnapshot scopeSnapshot = snapshotScope(rootScope, opts); + std::string currentId = toPointerId(current); + std::unordered_map constraints; + + for (NotNull c : unsolvedConstraints) + { + std::string id = toPointerId(c); + constraints[id] = { + toString(*c.get(), opts), + snapshotBlocks(c), + }; + } + + return StepSnapshot{ + currentId, + force, + constraints, + scopeSnapshot, + }; +} + +void DcrLogger::commitStepSnapshot(StepSnapshot snapshot) +{ + solveLog.stepStates.push_back(std::move(snapshot)); +} + +void DcrLogger::captureFinalSolverState(const Scope* rootScope, const std::vector>& unsolvedConstraints) +{ + solveLog.finalState.rootScope = snapshotScope(rootScope, opts); + solveLog.finalState.constraints.clear(); + + for (NotNull c : unsolvedConstraints) + { + std::string id = toPointerId(c); + solveLog.finalState.constraints[id] = { + toString(*c.get(), opts), + snapshotBlocks(c), + }; + } +} + +void DcrLogger::captureTypeCheckError(const TypeError& error) +{ + std::string stringifiedError = toString(error); + checkLog.errors.push_back(ErrorSnapshot { + /* message */ stringifiedError, + /* location */ error.location, + }); +} + +std::vector DcrLogger::snapshotBlocks(NotNull c) +{ + auto it = constraintBlocks.find(c); + if (it == constraintBlocks.end()) + { + return {}; + } + + std::vector snapshot; + + for (const ConstraintBlockTarget& target : it->second) + { + if (const TypeId* ty = get_if(&target)) + { + snapshot.push_back({ + ConstraintBlockKind::TypeId, + toString(*ty, opts), + }); + } + else if (const TypePackId* tp = get_if(&target)) + { + snapshot.push_back({ + ConstraintBlockKind::TypePackId, + toString(*tp, opts), + }); + } + else if (const NotNull* c = get_if>(&target)) + { + snapshot.push_back({ + ConstraintBlockKind::ConstraintId, + toString(*(c->get()), opts), + }); + } + else + { + LUAU_ASSERT(0); + } + } + + return snapshot; +} + +} // namespace Luau diff --git a/Analysis/src/EmbeddedBuiltinDefinitions.cpp b/Analysis/src/EmbeddedBuiltinDefinitions.cpp index 456635314..0f04ace08 100644 --- a/Analysis/src/EmbeddedBuiltinDefinitions.cpp +++ b/Analysis/src/EmbeddedBuiltinDefinitions.cpp @@ -187,13 +187,9 @@ declare utf8: { char: (...number) -> string, charpattern: string, codes: (string) -> ((string, number) -> (number, number), string, number), - -- FIXME - codepoint: (string, number?, number?) -> (number, ...number), + codepoint: (string, number?, number?) -> ...number, len: (string, number?, number?) -> (number?, number?), offset: (string, number?, number?) -> number, - nfdnormalize: (string) -> string, - nfcnormalize: (string) -> string, - graphemes: (string, number?, number?) -> (() -> (number, number)), } -- Cannot use `typeof` here because it will produce a polytype when we expect a monotype. diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index d8839f2f3..01e82baad 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -6,6 +6,7 @@ #include "Luau/Config.h" #include "Luau/ConstraintGraphBuilder.h" #include "Luau/ConstraintSolver.h" +#include "Luau/DcrLogger.h" #include "Luau/FileResolver.h" #include "Luau/Parser.h" #include "Luau/Scope.h" @@ -23,10 +24,12 @@ LUAU_FASTINT(LuauTypeInferIterationLimit) LUAU_FASTINT(LuauTarjanChildLimit) LUAU_FASTFLAG(LuauInferInNoCheckMode) +LUAU_FASTFLAG(LuauNoMoreGlobalSingletonTypes) LUAU_FASTFLAGVARIABLE(LuauKnowsTheDataModel3, false) LUAU_FASTFLAGVARIABLE(LuauAutocompleteDynamicLimits, false) LUAU_FASTINTVARIABLE(LuauAutocompleteCheckTimeoutMs, 100) LUAU_FASTFLAGVARIABLE(DebugLuauDeferredConstraintResolution, false) +LUAU_FASTFLAG(DebugLuauLogSolverToJson); namespace Luau { @@ -389,11 +392,12 @@ double getTimestamp() } // namespace Frontend::Frontend(FileResolver* fileResolver, ConfigResolver* configResolver, const FrontendOptions& options) - : fileResolver(fileResolver) + : singletonTypes(NotNull{FFlag::LuauNoMoreGlobalSingletonTypes ? &singletonTypes_ : &DEPRECATED_getSingletonTypes()}) + , fileResolver(fileResolver) , moduleResolver(this) , moduleResolverForAutocomplete(this) - , typeChecker(&moduleResolver, &iceHandler) - , typeCheckerForAutocomplete(&moduleResolverForAutocomplete, &iceHandler) + , typeChecker(&moduleResolver, singletonTypes, &iceHandler) + , typeCheckerForAutocomplete(&moduleResolverForAutocomplete, singletonTypes, &iceHandler) , configResolver(configResolver) , options(options) { @@ -837,11 +841,22 @@ ModulePtr Frontend::check(const SourceModule& sourceModule, Mode mode, const Sco { ModulePtr result = std::make_shared(); - ConstraintGraphBuilder cgb{sourceModule.name, result, &result->internalTypes, NotNull(&moduleResolver), NotNull(&iceHandler), getGlobalScope()}; + std::unique_ptr logger; + if (FFlag::DebugLuauLogSolverToJson) + { + logger = std::make_unique(); + std::optional source = fileResolver->readSource(sourceModule.name); + if (source) + { + logger->captureSource(source->source); + } + } + + ConstraintGraphBuilder cgb{sourceModule.name, result, &result->internalTypes, NotNull(&moduleResolver), singletonTypes, NotNull(&iceHandler), getGlobalScope(), logger.get()}; cgb.visit(sourceModule.root); result->errors = std::move(cgb.errors); - ConstraintSolver cs{&result->internalTypes, NotNull(cgb.rootScope), sourceModule.name, NotNull(&moduleResolver), requireCycles}; + ConstraintSolver cs{&result->internalTypes, singletonTypes, NotNull(cgb.rootScope), sourceModule.name, NotNull(&moduleResolver), requireCycles, logger.get()}; cs.run(); for (TypeError& e : cs.errors) @@ -855,9 +870,15 @@ ModulePtr Frontend::check(const SourceModule& sourceModule, Mode mode, const Sco result->astResolvedTypePacks = std::move(cgb.astResolvedTypePacks); result->type = sourceModule.type; - Luau::check(sourceModule, result.get()); + Luau::check(singletonTypes, logger.get(), sourceModule, result.get()); + + if (FFlag::DebugLuauLogSolverToJson) + { + std::string output = logger->compileOutput(); + printf("%s\n", output.c_str()); + } - result->clonePublicInterface(iceHandler); + result->clonePublicInterface(singletonTypes, iceHandler); return result; } diff --git a/Analysis/src/Linter.cpp b/Analysis/src/Linter.cpp index 669739a04..7f67a7db1 100644 --- a/Analysis/src/Linter.cpp +++ b/Analysis/src/Linter.cpp @@ -14,7 +14,6 @@ LUAU_FASTINTVARIABLE(LuauSuggestionDistance, 4) LUAU_FASTFLAGVARIABLE(LuauLintGlobalNeverReadBeforeWritten, false) -LUAU_FASTFLAGVARIABLE(LuauLintComparisonPrecedence, false) LUAU_FASTFLAGVARIABLE(LuauLintFixDeprecationMessage, false) namespace Luau @@ -2954,7 +2953,7 @@ std::vector lint(AstStat* root, const AstNameTable& names, const Sc if (context.warningEnabled(LintWarning::Code_IntegerParsing)) LintIntegerParsing::process(context); - if (context.warningEnabled(LintWarning::Code_ComparisonPrecedence) && FFlag::LuauLintComparisonPrecedence) + if (context.warningEnabled(LintWarning::Code_ComparisonPrecedence)) LintComparisonPrecedence::process(context); std::sort(context.result.begin(), context.result.end(), WarningComparator()); diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index 4c9e95378..b9deac769 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -92,10 +92,12 @@ struct ForceNormal : TypeVarOnceVisitor struct ClonePublicInterface : Substitution { + NotNull singletonTypes; NotNull module; - ClonePublicInterface(const TxnLog* log, Module* module) + ClonePublicInterface(const TxnLog* log, NotNull singletonTypes, Module* module) : Substitution(log, &module->interfaceTypes) + , singletonTypes(singletonTypes) , module(module) { LUAU_ASSERT(module); @@ -147,7 +149,7 @@ struct ClonePublicInterface : Substitution else { module->errors.push_back(TypeError{module->scopes[0].first, UnificationTooComplex{}}); - return getSingletonTypes().errorRecoveryType(); + return singletonTypes->errorRecoveryType(); } } @@ -163,7 +165,7 @@ struct ClonePublicInterface : Substitution else { module->errors.push_back(TypeError{module->scopes[0].first, UnificationTooComplex{}}); - return getSingletonTypes().errorRecoveryTypePack(); + return singletonTypes->errorRecoveryTypePack(); } } @@ -208,7 +210,7 @@ Module::~Module() unfreeze(internalTypes); } -void Module::clonePublicInterface(InternalErrorReporter& ice) +void Module::clonePublicInterface(NotNull singletonTypes, InternalErrorReporter& ice) { LUAU_ASSERT(interfaceTypes.typeVars.empty()); LUAU_ASSERT(interfaceTypes.typePacks.empty()); @@ -222,7 +224,7 @@ void Module::clonePublicInterface(InternalErrorReporter& ice) std::unordered_map* exportedTypeBindings = &moduleScope->exportedTypeBindings; TxnLog log; - ClonePublicInterface clonePublicInterface{&log, this}; + ClonePublicInterface clonePublicInterface{&log, singletonTypes, this}; if (FFlag::LuauClonePublicInterfaceLess) returnType = clonePublicInterface.cloneTypePack(returnType); @@ -243,12 +245,12 @@ void Module::clonePublicInterface(InternalErrorReporter& ice) if (FFlag::LuauLowerBoundsCalculation) { - normalize(returnType, NotNull{this}, ice); + normalize(returnType, NotNull{this}, singletonTypes, ice); if (FFlag::LuauForceExportSurfacesToBeNormal) forceNormal.traverse(returnType); if (varargPack) { - normalize(*varargPack, NotNull{this}, ice); + normalize(*varargPack, NotNull{this}, singletonTypes, ice); if (FFlag::LuauForceExportSurfacesToBeNormal) forceNormal.traverse(*varargPack); } @@ -264,7 +266,7 @@ void Module::clonePublicInterface(InternalErrorReporter& ice) tf = clone(tf, interfaceTypes, cloneState); if (FFlag::LuauLowerBoundsCalculation) { - normalize(tf.type, NotNull{this}, ice); + normalize(tf.type, NotNull{this}, singletonTypes, ice); // We're about to freeze the memory. We know that the flag is conservative by design. Cyclic tables // won't be marked normal. If the types aren't normal by now, they never will be. @@ -275,7 +277,7 @@ void Module::clonePublicInterface(InternalErrorReporter& ice) if (param.defaultValue) { - normalize(*param.defaultValue, NotNull{this}, ice); + normalize(*param.defaultValue, NotNull{this}, singletonTypes, ice); forceNormal.traverse(*param.defaultValue); } } @@ -301,7 +303,7 @@ void Module::clonePublicInterface(InternalErrorReporter& ice) ty = clone(ty, interfaceTypes, cloneState); if (FFlag::LuauLowerBoundsCalculation) { - normalize(ty, NotNull{this}, ice); + normalize(ty, NotNull{this}, singletonTypes, ice); if (FFlag::LuauForceExportSurfacesToBeNormal) forceNormal.traverse(ty); diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp index 94adaf5c7..c3f0bb9d6 100644 --- a/Analysis/src/Normalize.cpp +++ b/Analysis/src/Normalize.cpp @@ -54,11 +54,11 @@ struct Replacer } // anonymous namespace -bool isSubtype(TypeId subTy, TypeId superTy, NotNull scope, InternalErrorReporter& ice) +bool isSubtype(TypeId subTy, TypeId superTy, NotNull scope, NotNull singletonTypes, InternalErrorReporter& ice) { UnifierSharedState sharedState{&ice}; TypeArena arena; - Unifier u{&arena, Mode::Strict, scope, Location{}, Covariant, sharedState}; + Unifier u{&arena, singletonTypes, Mode::Strict, scope, Location{}, Covariant, sharedState}; u.anyIsTop = true; u.tryUnify(subTy, superTy); @@ -66,11 +66,11 @@ bool isSubtype(TypeId subTy, TypeId superTy, NotNull scope, InternalError return ok; } -bool isSubtype(TypePackId subPack, TypePackId superPack, NotNull scope, InternalErrorReporter& ice) +bool isSubtype(TypePackId subPack, TypePackId superPack, NotNull scope, NotNull singletonTypes, InternalErrorReporter& ice) { UnifierSharedState sharedState{&ice}; TypeArena arena; - Unifier u{&arena, Mode::Strict, scope, Location{}, Covariant, sharedState}; + Unifier u{&arena, singletonTypes, Mode::Strict, scope, Location{}, Covariant, sharedState}; u.anyIsTop = true; u.tryUnify(subPack, superPack); @@ -133,15 +133,17 @@ struct Normalize final : TypeVarVisitor { using TypeVarVisitor::Set; - Normalize(TypeArena& arena, NotNull scope, InternalErrorReporter& ice) + Normalize(TypeArena& arena, NotNull scope, NotNull singletonTypes, InternalErrorReporter& ice) : arena(arena) , scope(scope) + , singletonTypes(singletonTypes) , ice(ice) { } TypeArena& arena; NotNull scope; + NotNull singletonTypes; InternalErrorReporter& ice; int iterationLimit = 0; @@ -499,9 +501,9 @@ struct Normalize final : TypeVarVisitor for (TypeId& part : result) { - if (isSubtype(ty, part, scope, ice)) + if (isSubtype(ty, part, scope, singletonTypes, ice)) return; // no need to do anything - else if (isSubtype(part, ty, scope, ice)) + else if (isSubtype(part, ty, scope, singletonTypes, ice)) { part = ty; // replace the less general type by the more general one return; @@ -553,12 +555,12 @@ struct Normalize final : TypeVarVisitor bool merged = false; for (TypeId& part : result->parts) { - if (isSubtype(part, ty, scope, ice)) + if (isSubtype(part, ty, scope, singletonTypes, ice)) { merged = true; break; // no need to do anything } - else if (isSubtype(ty, part, scope, ice)) + else if (isSubtype(ty, part, scope, singletonTypes, ice)) { merged = true; part = ty; // replace the less general type by the more general one @@ -691,13 +693,14 @@ struct Normalize final : TypeVarVisitor /** * @returns A tuple of TypeId and a success indicator. (true indicates that the normalization completed successfully) */ -std::pair normalize(TypeId ty, NotNull scope, TypeArena& arena, InternalErrorReporter& ice) +std::pair normalize( + TypeId ty, NotNull scope, TypeArena& arena, NotNull singletonTypes, InternalErrorReporter& ice) { CloneState state; if (FFlag::DebugLuauCopyBeforeNormalizing) (void)clone(ty, arena, state); - Normalize n{arena, scope, ice}; + Normalize n{arena, scope, singletonTypes, ice}; n.traverse(ty); return {ty, !n.limitExceeded}; @@ -707,39 +710,40 @@ std::pair normalize(TypeId ty, NotNull scope, TypeArena& ar // reclaim memory used by wantonly allocated intermediate types here. // The main wrinkle here is that we don't want clone() to copy a type if the source and dest // arena are the same. -std::pair normalize(TypeId ty, NotNull module, InternalErrorReporter& ice) +std::pair normalize(TypeId ty, NotNull module, NotNull singletonTypes, InternalErrorReporter& ice) { - return normalize(ty, NotNull{module->getModuleScope().get()}, module->internalTypes, ice); + return normalize(ty, NotNull{module->getModuleScope().get()}, module->internalTypes, singletonTypes, ice); } -std::pair normalize(TypeId ty, const ModulePtr& module, InternalErrorReporter& ice) +std::pair normalize(TypeId ty, const ModulePtr& module, NotNull singletonTypes, InternalErrorReporter& ice) { - return normalize(ty, NotNull{module.get()}, ice); + return normalize(ty, NotNull{module.get()}, singletonTypes, ice); } /** * @returns A tuple of TypeId and a success indicator. (true indicates that the normalization completed successfully) */ -std::pair normalize(TypePackId tp, NotNull scope, TypeArena& arena, InternalErrorReporter& ice) +std::pair normalize( + TypePackId tp, NotNull scope, TypeArena& arena, NotNull singletonTypes, InternalErrorReporter& ice) { CloneState state; if (FFlag::DebugLuauCopyBeforeNormalizing) (void)clone(tp, arena, state); - Normalize n{arena, scope, ice}; + Normalize n{arena, scope, singletonTypes, ice}; n.traverse(tp); return {tp, !n.limitExceeded}; } -std::pair normalize(TypePackId tp, NotNull module, InternalErrorReporter& ice) +std::pair normalize(TypePackId tp, NotNull module, NotNull singletonTypes, InternalErrorReporter& ice) { - return normalize(tp, NotNull{module->getModuleScope().get()}, module->internalTypes, ice); + return normalize(tp, NotNull{module->getModuleScope().get()}, module->internalTypes, singletonTypes, ice); } -std::pair normalize(TypePackId tp, const ModulePtr& module, InternalErrorReporter& ice) +std::pair normalize(TypePackId tp, const ModulePtr& module, NotNull singletonTypes, InternalErrorReporter& ice) { - return normalize(tp, NotNull{module.get()}, ice); + return normalize(tp, NotNull{module.get()}, singletonTypes, ice); } } // namespace Luau diff --git a/Analysis/src/TypeArena.cpp b/Analysis/src/TypeArena.cpp index c7980ab0b..abf31aee2 100644 --- a/Analysis/src/TypeArena.cpp +++ b/Analysis/src/TypeArena.cpp @@ -40,6 +40,15 @@ TypeId TypeArena::freshType(Scope* scope) return allocated; } +TypePackId TypeArena::freshTypePack(Scope* scope) +{ + TypePackId allocated = typePacks.allocate(FreeTypePack{scope}); + + asMutable(allocated)->owningArena = this; + + return allocated; +} + TypePackId TypeArena::addTypePack(std::initializer_list types) { TypePackId allocated = typePacks.allocate(TypePack{std::move(types)}); diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index 480bdf403..88363b43e 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -1,8 +1,6 @@ - +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/TypeChecker2.h" -#include - #include "Luau/Ast.h" #include "Luau/AstQuery.h" #include "Luau/Clone.h" @@ -13,6 +11,12 @@ #include "Luau/TypeUtils.h" #include "Luau/TypeVar.h" #include "Luau/Unifier.h" +#include "Luau/ToString.h" +#include "Luau/DcrLogger.h" + +#include + +LUAU_FASTFLAG(DebugLuauLogSolverToJson); namespace Luau { @@ -54,18 +58,22 @@ struct StackPusher struct TypeChecker2 { + NotNull singletonTypes; + DcrLogger* logger; + InternalErrorReporter ice; // FIXME accept a pointer from Frontend const SourceModule* sourceModule; Module* module; - InternalErrorReporter ice; // FIXME accept a pointer from Frontend - SingletonTypes& singletonTypes; std::vector> stack; - TypeChecker2(const SourceModule* sourceModule, Module* module) - : sourceModule(sourceModule) + TypeChecker2(NotNull singletonTypes, DcrLogger* logger, const SourceModule* sourceModule, Module* module) + : singletonTypes(singletonTypes) + , logger(logger) + , sourceModule(sourceModule) , module(module) - , singletonTypes(getSingletonTypes()) { + if (FFlag::DebugLuauLogSolverToJson) + LUAU_ASSERT(logger); } std::optional pushStack(AstNode* node) @@ -85,7 +93,7 @@ struct TypeChecker2 if (tp) return follow(*tp); else - return singletonTypes.anyTypePack; + return singletonTypes->anyTypePack; } TypeId lookupType(AstExpr* expr) @@ -101,7 +109,7 @@ struct TypeChecker2 if (tp) return flattenPack(*tp); - return singletonTypes.anyType; + return singletonTypes->anyType; } TypeId lookupAnnotation(AstType* annotation) @@ -253,7 +261,7 @@ struct TypeChecker2 TypePackId actualRetType = reconstructPack(ret->list, arena); UnifierSharedState sharedState{&ice}; - Unifier u{&arena, Mode::Strict, stack.back(), ret->location, Covariant, sharedState}; + Unifier u{&arena, singletonTypes, Mode::Strict, stack.back(), ret->location, Covariant, sharedState}; u.anyIsTop = true; u.tryUnify(actualRetType, expectedRetType); @@ -299,7 +307,7 @@ struct TypeChecker2 if (var->annotation) { TypeId varType = lookupAnnotation(var->annotation); - if (!isSubtype(*it, varType, stack.back(), ice)) + if (!isSubtype(*it, varType, stack.back(), singletonTypes, ice)) { reportError(TypeMismatch{varType, *it}, value->location); } @@ -317,7 +325,7 @@ struct TypeChecker2 if (var->annotation) { TypeId varType = lookupAnnotation(var->annotation); - if (!isSubtype(varType, valueType, stack.back(), ice)) + if (!isSubtype(varType, valueType, stack.back(), singletonTypes, ice)) { reportError(TypeMismatch{varType, valueType}, value->location); } @@ -340,7 +348,7 @@ struct TypeChecker2 // "Render" a type pack out to an array of a given length. Expands // variadics and various other things to get there. - static std::vector flatten(TypeArena& arena, TypePackId pack, size_t length) + std::vector flatten(TypeArena& arena, TypePackId pack, size_t length) { std::vector result; @@ -376,7 +384,7 @@ struct TypeChecker2 else if (auto etp = get(tail)) { while (result.size() < length) - result.push_back(getSingletonTypes().errorRecoveryType()); + result.push_back(singletonTypes->errorRecoveryType()); } return result; @@ -532,7 +540,7 @@ struct TypeChecker2 visit(rhs); TypeId rhsType = lookupType(rhs); - if (!isSubtype(rhsType, lhsType, stack.back(), ice)) + if (!isSubtype(rhsType, lhsType, stack.back(), singletonTypes, ice)) { reportError(TypeMismatch{lhsType, rhsType}, rhs->location); } @@ -681,9 +689,9 @@ struct TypeChecker2 void visit(AstExprConstantNumber* number) { TypeId actualType = lookupType(number); - TypeId numberType = getSingletonTypes().numberType; + TypeId numberType = singletonTypes->numberType; - if (!isSubtype(numberType, actualType, stack.back(), ice)) + if (!isSubtype(numberType, actualType, stack.back(), singletonTypes, ice)) { reportError(TypeMismatch{actualType, numberType}, number->location); } @@ -692,9 +700,9 @@ struct TypeChecker2 void visit(AstExprConstantString* string) { TypeId actualType = lookupType(string); - TypeId stringType = getSingletonTypes().stringType; + TypeId stringType = singletonTypes->stringType; - if (!isSubtype(stringType, actualType, stack.back(), ice)) + if (!isSubtype(stringType, actualType, stack.back(), singletonTypes, ice)) { reportError(TypeMismatch{actualType, stringType}, string->location); } @@ -754,7 +762,7 @@ struct TypeChecker2 FunctionTypeVar ftv{argsTp, expectedRetType}; TypeId expectedType = arena.addType(ftv); - if (!isSubtype(expectedType, instantiatedFunctionType, stack.back(), ice)) + if (!isSubtype(expectedType, instantiatedFunctionType, stack.back(), singletonTypes, ice)) { CloneState cloneState; expectedType = clone(expectedType, module->internalTypes, cloneState); @@ -773,7 +781,7 @@ struct TypeChecker2 getIndexTypeFromType(module->getModuleScope(), leftType, indexName->index.value, indexName->location, /* addErrors */ true); if (ty) { - if (!isSubtype(resultType, *ty, stack.back(), ice)) + if (!isSubtype(resultType, *ty, stack.back(), singletonTypes, ice)) { reportError(TypeMismatch{resultType, *ty}, indexName->location); } @@ -806,7 +814,7 @@ struct TypeChecker2 TypeId inferredArgTy = *argIt; TypeId annotatedArgTy = lookupAnnotation(arg->annotation); - if (!isSubtype(annotatedArgTy, inferredArgTy, stack.back(), ice)) + if (!isSubtype(annotatedArgTy, inferredArgTy, stack.back(), singletonTypes, ice)) { reportError(TypeMismatch{annotatedArgTy, inferredArgTy}, arg->location); } @@ -851,10 +859,10 @@ struct TypeChecker2 TypeId computedType = lookupType(expr->expr); // Note: As an optimization, we try 'number <: number | string' first, as that is the more likely case. - if (isSubtype(annotationType, computedType, stack.back(), ice)) + if (isSubtype(annotationType, computedType, stack.back(), singletonTypes, ice)) return; - if (isSubtype(computedType, annotationType, stack.back(), ice)) + if (isSubtype(computedType, annotationType, stack.back(), singletonTypes, ice)) return; reportError(TypesAreUnrelated{computedType, annotationType}, expr->location); @@ -908,7 +916,7 @@ struct TypeChecker2 return result; } else if (get(pack)) - return singletonTypes.errorRecoveryType(); + return singletonTypes->errorRecoveryType(); else ice.ice("flattenPack got a weird pack!"); } @@ -1154,7 +1162,7 @@ struct TypeChecker2 ErrorVec tryUnify(NotNull scope, const Location& location, TID subTy, TID superTy) { UnifierSharedState sharedState{&ice}; - Unifier u{&module->internalTypes, Mode::Strict, scope, location, Covariant, sharedState}; + Unifier u{&module->internalTypes, singletonTypes, Mode::Strict, scope, location, Covariant, sharedState}; u.anyIsTop = true; u.tryUnify(subTy, superTy); @@ -1164,6 +1172,9 @@ struct TypeChecker2 void reportError(TypeErrorData data, const Location& location) { module->errors.emplace_back(location, sourceModule->name, std::move(data)); + + if (FFlag::DebugLuauLogSolverToJson) + logger->captureTypeCheckError(module->errors.back()); } void reportError(TypeError e) @@ -1179,13 +1190,13 @@ struct TypeChecker2 std::optional getIndexTypeFromType(const ScopePtr& scope, TypeId type, const std::string& prop, const Location& location, bool addErrors) { - return Luau::getIndexTypeFromType(scope, module->errors, &module->internalTypes, type, prop, location, addErrors, ice); + return Luau::getIndexTypeFromType(scope, module->errors, &module->internalTypes, singletonTypes, type, prop, location, addErrors, ice); } }; -void check(const SourceModule& sourceModule, Module* module) +void check(NotNull singletonTypes, DcrLogger* logger, const SourceModule& sourceModule, Module* module) { - TypeChecker2 typeChecker{&sourceModule, module}; + TypeChecker2 typeChecker{singletonTypes, logger, &sourceModule, module}; typeChecker.visit(sourceModule.root); } diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 2bda2804c..c14081969 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -248,21 +248,22 @@ size_t HashBoolNamePair::operator()(const std::pair& pair) const return std::hash()(pair.first) ^ std::hash()(pair.second); } -TypeChecker::TypeChecker(ModuleResolver* resolver, InternalErrorReporter* iceHandler) +TypeChecker::TypeChecker(ModuleResolver* resolver, NotNull singletonTypes, InternalErrorReporter* iceHandler) : resolver(resolver) + , singletonTypes(singletonTypes) , iceHandler(iceHandler) , unifierState(iceHandler) - , nilType(getSingletonTypes().nilType) - , numberType(getSingletonTypes().numberType) - , stringType(getSingletonTypes().stringType) - , booleanType(getSingletonTypes().booleanType) - , threadType(getSingletonTypes().threadType) - , anyType(getSingletonTypes().anyType) - , unknownType(getSingletonTypes().unknownType) - , neverType(getSingletonTypes().neverType) - , anyTypePack(getSingletonTypes().anyTypePack) - , neverTypePack(getSingletonTypes().neverTypePack) - , uninhabitableTypePack(getSingletonTypes().uninhabitableTypePack) + , nilType(singletonTypes->nilType) + , numberType(singletonTypes->numberType) + , stringType(singletonTypes->stringType) + , booleanType(singletonTypes->booleanType) + , threadType(singletonTypes->threadType) + , anyType(singletonTypes->anyType) + , unknownType(singletonTypes->unknownType) + , neverType(singletonTypes->neverType) + , anyTypePack(singletonTypes->anyTypePack) + , neverTypePack(singletonTypes->neverTypePack) + , uninhabitableTypePack(singletonTypes->uninhabitableTypePack) , duplicateTypeAliases{{false, {}}} { globalScope = std::make_shared(globalTypes.addTypePack(TypePackVar{FreeTypePack{TypeLevel{}}})); @@ -357,7 +358,7 @@ ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mo prepareErrorsForDisplay(currentModule->errors); - currentModule->clonePublicInterface(*iceHandler); + currentModule->clonePublicInterface(singletonTypes, *iceHandler); // Clear unifier cache since it's keyed off internal types that get deallocated // This avoids fake cross-module cache hits and keeps cache size at bay when typechecking large module graphs. @@ -1606,7 +1607,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias if (FFlag::LuauLowerBoundsCalculation) { - auto [t, ok] = normalize(bindingType, currentModule, *iceHandler); + auto [t, ok] = normalize(bindingType, currentModule, singletonTypes, *iceHandler); bindingType = t; if (!ok) reportError(typealias.location, NormalizationTooComplex{}); @@ -1923,7 +1924,7 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp std::optional TypeChecker::findTablePropertyRespectingMeta(TypeId lhsType, Name name, const Location& location, bool addErrors) { ErrorVec errors; - auto result = Luau::findTablePropertyRespectingMeta(errors, lhsType, name, location); + auto result = Luau::findTablePropertyRespectingMeta(singletonTypes, errors, lhsType, name, location); if (addErrors) reportErrors(errors); return result; @@ -1932,7 +1933,7 @@ std::optional TypeChecker::findTablePropertyRespectingMeta(TypeId lhsTyp std::optional TypeChecker::findMetatableEntry(TypeId type, std::string entry, const Location& location, bool addErrors) { ErrorVec errors; - auto result = Luau::findMetatableEntry(errors, type, entry, location); + auto result = Luau::findMetatableEntry(singletonTypes, errors, type, entry, location); if (addErrors) reportErrors(errors); return result; @@ -2034,8 +2035,8 @@ std::optional TypeChecker::getIndexTypeFromTypeImpl( if (FFlag::LuauLowerBoundsCalculation) { - auto [t, ok] = normalize(addType(UnionTypeVar{std::move(goodOptions)}), currentModule, - *iceHandler); // FIXME Inefficient. We craft a UnionTypeVar and immediately throw it away. + // FIXME Inefficient. We craft a UnionTypeVar and immediately throw it away. + auto [t, ok] = normalize(addType(UnionTypeVar{std::move(goodOptions)}), currentModule, singletonTypes, *iceHandler); if (!ok) reportError(location, NormalizationTooComplex{}); @@ -2642,8 +2643,8 @@ TypeId TypeChecker::checkRelationalOperation( std::string metamethodName = opToMetaTableEntry(expr.op); - std::optional leftMetatable = isString(lhsType) ? std::nullopt : getMetatable(follow(lhsType)); - std::optional rightMetatable = isString(rhsType) ? std::nullopt : getMetatable(follow(rhsType)); + std::optional leftMetatable = isString(lhsType) ? std::nullopt : getMetatable(follow(lhsType), singletonTypes); + std::optional rightMetatable = isString(rhsType) ? std::nullopt : getMetatable(follow(rhsType), singletonTypes); if (leftMetatable != rightMetatable) { @@ -2654,7 +2655,7 @@ TypeId TypeChecker::checkRelationalOperation( { for (TypeId leftOption : utv) { - if (getMetatable(follow(leftOption)) == rightMetatable) + if (getMetatable(follow(leftOption), singletonTypes) == rightMetatable) { matches = true; break; @@ -2668,7 +2669,7 @@ TypeId TypeChecker::checkRelationalOperation( { for (TypeId rightOption : utv) { - if (getMetatable(follow(rightOption)) == leftMetatable) + if (getMetatable(follow(rightOption), singletonTypes) == leftMetatable) { matches = true; break; @@ -4113,7 +4114,7 @@ std::optional> TypeChecker::checkCallOverload(const Sc std::vector adjustedArgTypes; auto it = begin(argPack); auto endIt = end(argPack); - Widen widen{¤tModule->internalTypes}; + Widen widen{¤tModule->internalTypes, singletonTypes}; for (; it != endIt; ++it) { adjustedArgTypes.push_back(addType(ConstrainedTypeVar{level, {widen(*it)}})); @@ -4649,7 +4650,7 @@ TypeId TypeChecker::quantify(const ScopePtr& scope, TypeId ty, Location location if (FFlag::LuauLowerBoundsCalculation) { - auto [t, ok] = Luau::normalize(ty, currentModule, *iceHandler); + auto [t, ok] = Luau::normalize(ty, currentModule, singletonTypes, *iceHandler); if (!ok) reportError(location, NormalizationTooComplex{}); return t; @@ -4664,7 +4665,7 @@ TypeId TypeChecker::quantify(const ScopePtr& scope, TypeId ty, Location location if (FFlag::LuauLowerBoundsCalculation && ftv) { - auto [t, ok] = Luau::normalize(ty, currentModule, *iceHandler); + auto [t, ok] = Luau::normalize(ty, currentModule, singletonTypes, *iceHandler); if (!ok) reportError(location, NormalizationTooComplex{}); return t; @@ -4701,13 +4702,13 @@ TypeId TypeChecker::anyify(const ScopePtr& scope, TypeId ty, Location location) { if (FFlag::LuauLowerBoundsCalculation) { - auto [t, ok] = normalize(ty, currentModule, *iceHandler); + auto [t, ok] = normalize(ty, currentModule, singletonTypes, *iceHandler); if (!ok) reportError(location, NormalizationTooComplex{}); ty = t; } - Anyification anyification{¤tModule->internalTypes, scope, iceHandler, anyType, anyTypePack}; + Anyification anyification{¤tModule->internalTypes, scope, singletonTypes, iceHandler, anyType, anyTypePack}; std::optional any = anyification.substitute(ty); if (anyification.normalizationTooComplex) reportError(location, NormalizationTooComplex{}); @@ -4724,13 +4725,13 @@ TypePackId TypeChecker::anyify(const ScopePtr& scope, TypePackId ty, Location lo { if (FFlag::LuauLowerBoundsCalculation) { - auto [t, ok] = normalize(ty, currentModule, *iceHandler); + auto [t, ok] = normalize(ty, currentModule, singletonTypes, *iceHandler); if (!ok) reportError(location, NormalizationTooComplex{}); ty = t; } - Anyification anyification{¤tModule->internalTypes, scope, iceHandler, anyType, anyTypePack}; + Anyification anyification{¤tModule->internalTypes, scope, singletonTypes, iceHandler, anyType, anyTypePack}; std::optional any = anyification.substitute(ty); if (any.has_value()) return *any; @@ -4868,7 +4869,8 @@ void TypeChecker::merge(RefinementMap& l, const RefinementMap& r) Unifier TypeChecker::mkUnifier(const ScopePtr& scope, const Location& location) { - return Unifier{¤tModule->internalTypes, currentModule->mode, NotNull{scope.get()}, location, Variance::Covariant, unifierState}; + return Unifier{ + ¤tModule->internalTypes, singletonTypes, currentModule->mode, NotNull{scope.get()}, location, Variance::Covariant, unifierState}; } TypeId TypeChecker::freshType(const ScopePtr& scope) @@ -4883,7 +4885,7 @@ TypeId TypeChecker::freshType(TypeLevel level) TypeId TypeChecker::singletonType(bool value) { - return value ? getSingletonTypes().trueType : getSingletonTypes().falseType; + return value ? singletonTypes->trueType : singletonTypes->falseType; } TypeId TypeChecker::singletonType(std::string value) @@ -4894,22 +4896,22 @@ TypeId TypeChecker::singletonType(std::string value) TypeId TypeChecker::errorRecoveryType(const ScopePtr& scope) { - return getSingletonTypes().errorRecoveryType(); + return singletonTypes->errorRecoveryType(); } TypeId TypeChecker::errorRecoveryType(TypeId guess) { - return getSingletonTypes().errorRecoveryType(guess); + return singletonTypes->errorRecoveryType(guess); } TypePackId TypeChecker::errorRecoveryTypePack(const ScopePtr& scope) { - return getSingletonTypes().errorRecoveryTypePack(); + return singletonTypes->errorRecoveryTypePack(); } TypePackId TypeChecker::errorRecoveryTypePack(TypePackId guess) { - return getSingletonTypes().errorRecoveryTypePack(guess); + return singletonTypes->errorRecoveryTypePack(guess); } TypeIdPredicate TypeChecker::mkTruthyPredicate(bool sense) @@ -5836,48 +5838,52 @@ void TypeChecker::resolve(const TypeGuardPredicate& typeguardP, RefinementMap& r return; } - using ConditionFunc = bool(TypeId); - using SenseToTypeIdPredicate = std::function; - auto mkFilter = [](ConditionFunc f, std::optional other = std::nullopt) -> SenseToTypeIdPredicate { - return [f, other](bool sense) -> TypeIdPredicate { - return [f, other, sense](TypeId ty) -> std::optional { - if (FFlag::LuauUnknownAndNeverType && sense && get(ty)) - return other.value_or(ty); + auto refine = [this, &lvalue = typeguardP.lvalue, &refis, &scope, sense](bool(f)(TypeId), std::optional mapsTo = std::nullopt) { + TypeIdPredicate predicate = [f, mapsTo, sense](TypeId ty) -> std::optional { + if (FFlag::LuauUnknownAndNeverType && sense && get(ty)) + return mapsTo.value_or(ty); - if (f(ty) == sense) - return ty; + if (f(ty) == sense) + return ty; - if (isUndecidable(ty)) - return other.value_or(ty); + if (isUndecidable(ty)) + return mapsTo.value_or(ty); - return std::nullopt; - }; + return std::nullopt; }; - }; - - // Note: "vector" never happens here at this point, so we don't have to write something for it. - // clang-format off - static const std::unordered_map primitives{ - // Trivial primitives. - {"nil", mkFilter(isNil, nilType)}, // This can still happen when sense is false! - {"string", mkFilter(isString, stringType)}, - {"number", mkFilter(isNumber, numberType)}, - {"boolean", mkFilter(isBoolean, booleanType)}, - {"thread", mkFilter(isThread, threadType)}, - - // Non-trivial primitives. - {"table", mkFilter([](TypeId ty) -> bool { return isTableIntersection(ty) || get(ty) || get(ty); })}, - {"function", mkFilter([](TypeId ty) -> bool { return isOverloadedFunction(ty) || get(ty); })}, - // For now, we don't really care about being accurate with userdata if the typeguard was using typeof. - {"userdata", mkFilter([](TypeId ty) -> bool { return get(ty); })}, + refineLValue(lvalue, refis, scope, predicate); }; - // clang-format on - if (auto it = primitives.find(typeguardP.kind); it != primitives.end()) + // Note: "vector" never happens here at this point, so we don't have to write something for it. + if (typeguardP.kind == "nil") + return refine(isNil, nilType); // This can still happen when sense is false! + else if (typeguardP.kind == "string") + return refine(isString, stringType); + else if (typeguardP.kind == "number") + return refine(isNumber, numberType); + else if (typeguardP.kind == "boolean") + return refine(isBoolean, booleanType); + else if (typeguardP.kind == "thread") + return refine(isThread, threadType); + else if (typeguardP.kind == "table") + { + return refine([](TypeId ty) -> bool { + return isTableIntersection(ty) || get(ty) || get(ty); + }); + } + else if (typeguardP.kind == "function") { - refineLValue(typeguardP.lvalue, refis, scope, it->second(sense)); - return; + return refine([](TypeId ty) -> bool { + return isOverloadedFunction(ty) || get(ty); + }); + } + else if (typeguardP.kind == "userdata") + { + // For now, we don't really care about being accurate with userdata if the typeguard was using typeof. + return refine([](TypeId ty) -> bool { + return get(ty); + }); } if (!typeguardP.isTypeof) diff --git a/Analysis/src/TypeUtils.cpp b/Analysis/src/TypeUtils.cpp index 56fcceccc..a96820d67 100644 --- a/Analysis/src/TypeUtils.cpp +++ b/Analysis/src/TypeUtils.cpp @@ -9,18 +9,19 @@ namespace Luau { -std::optional findMetatableEntry(ErrorVec& errors, TypeId type, const std::string& entry, Location location) +std::optional findMetatableEntry( + NotNull singletonTypes, ErrorVec& errors, TypeId type, const std::string& entry, Location location) { type = follow(type); - std::optional metatable = getMetatable(type); + std::optional metatable = getMetatable(type, singletonTypes); if (!metatable) return std::nullopt; TypeId unwrapped = follow(*metatable); if (get(unwrapped)) - return getSingletonTypes().anyType; + return singletonTypes->anyType; const TableTypeVar* mtt = getTableType(unwrapped); if (!mtt) @@ -36,7 +37,8 @@ std::optional findMetatableEntry(ErrorVec& errors, TypeId type, const st return std::nullopt; } -std::optional findTablePropertyRespectingMeta(ErrorVec& errors, TypeId ty, const std::string& name, Location location) +std::optional findTablePropertyRespectingMeta( + NotNull singletonTypes, ErrorVec& errors, TypeId ty, const std::string& name, Location location) { if (get(ty)) return ty; @@ -48,7 +50,7 @@ std::optional findTablePropertyRespectingMeta(ErrorVec& errors, TypeId t return it->second.type; } - std::optional mtIndex = findMetatableEntry(errors, ty, "__index", location); + std::optional mtIndex = findMetatableEntry(singletonTypes, errors, ty, "__index", location); int count = 0; while (mtIndex) { @@ -69,23 +71,23 @@ std::optional findTablePropertyRespectingMeta(ErrorVec& errors, TypeId t { std::optional r = first(follow(itf->retTypes)); if (!r) - return getSingletonTypes().nilType; + return singletonTypes->nilType; else return *r; } else if (get(index)) - return getSingletonTypes().anyType; + return singletonTypes->anyType; else errors.push_back(TypeError{location, GenericError{"__index should either be a function or table. Got " + toString(index)}}); - mtIndex = findMetatableEntry(errors, *mtIndex, "__index", location); + mtIndex = findMetatableEntry(singletonTypes, errors, *mtIndex, "__index", location); } return std::nullopt; } -std::optional getIndexTypeFromType(const ScopePtr& scope, ErrorVec& errors, TypeArena* arena, TypeId type, const std::string& prop, - const Location& location, bool addErrors, InternalErrorReporter& handle) +std::optional getIndexTypeFromType(const ScopePtr& scope, ErrorVec& errors, TypeArena* arena, NotNull singletonTypes, + TypeId type, const std::string& prop, const Location& location, bool addErrors, InternalErrorReporter& handle) { type = follow(type); @@ -97,14 +99,14 @@ std::optional getIndexTypeFromType(const ScopePtr& scope, ErrorVec& erro if (isString(type)) { - std::optional mtIndex = Luau::findMetatableEntry(errors, getSingletonTypes().stringType, "__index", location); + std::optional mtIndex = Luau::findMetatableEntry(singletonTypes, errors, singletonTypes->stringType, "__index", location); LUAU_ASSERT(mtIndex); type = *mtIndex; } if (getTableType(type)) { - return findTablePropertyRespectingMeta(errors, type, prop, location); + return findTablePropertyRespectingMeta(singletonTypes, errors, type, prop, location); } else if (const ClassTypeVar* cls = get(type)) { @@ -125,7 +127,8 @@ std::optional getIndexTypeFromType(const ScopePtr& scope, ErrorVec& erro if (get(follow(t))) return t; - if (std::optional ty = getIndexTypeFromType(scope, errors, arena, t, prop, location, /* addErrors= */ false, handle)) + if (std::optional ty = + getIndexTypeFromType(scope, errors, arena, singletonTypes, t, prop, location, /* addErrors= */ false, handle)) goodOptions.push_back(*ty); else badOptions.push_back(t); @@ -144,17 +147,17 @@ std::optional getIndexTypeFromType(const ScopePtr& scope, ErrorVec& erro } if (goodOptions.empty()) - return getSingletonTypes().neverType; + return singletonTypes->neverType; if (goodOptions.size() == 1) return goodOptions[0]; // TODO: inefficient. TypeId result = arena->addType(UnionTypeVar{std::move(goodOptions)}); - auto [ty, ok] = normalize(result, NotNull{scope.get()}, *arena, handle); + auto [ty, ok] = normalize(result, NotNull{scope.get()}, *arena, singletonTypes, handle); if (!ok && addErrors) errors.push_back(TypeError{location, NormalizationTooComplex{}}); - return ok ? ty : getSingletonTypes().anyType; + return ok ? ty : singletonTypes->anyType; } else if (const IntersectionTypeVar* itv = get(type)) { @@ -165,7 +168,8 @@ std::optional getIndexTypeFromType(const ScopePtr& scope, ErrorVec& erro // TODO: we should probably limit recursion here? // RecursionLimiter _rl(&recursionCount, FInt::LuauTypeInferRecursionLimit); - if (std::optional ty = getIndexTypeFromType(scope, errors, arena, t, prop, location, /* addErrors= */ false, handle)) + if (std::optional ty = + getIndexTypeFromType(scope, errors, arena, singletonTypes, t, prop, location, /* addErrors= */ false, handle)) parts.push_back(*ty); } diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index 4abee0f6d..4f6603fb9 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -26,6 +26,7 @@ LUAU_FASTINT(LuauTypeInferRecursionLimit) LUAU_FASTFLAG(LuauUnknownAndNeverType) LUAU_FASTFLAGVARIABLE(LuauMaybeGenericIntersectionTypes, false) LUAU_FASTFLAGVARIABLE(LuauStringFormatArgumentErrorFix, false) +LUAU_FASTFLAGVARIABLE(LuauNoMoreGlobalSingletonTypes, false) namespace Luau { @@ -239,7 +240,7 @@ bool isOverloadedFunction(TypeId ty) return std::all_of(parts.begin(), parts.end(), isFunction); } -std::optional getMetatable(TypeId type) +std::optional getMetatable(TypeId type, NotNull singletonTypes) { type = follow(type); @@ -249,7 +250,7 @@ std::optional getMetatable(TypeId type) return classType->metatable; else if (isString(type)) { - auto ptv = get(getSingletonTypes().stringType); + auto ptv = get(singletonTypes->stringType); LUAU_ASSERT(ptv && ptv->metatable); return ptv->metatable; } @@ -707,44 +708,30 @@ TypeId makeFunction(TypeArena& arena, std::optional selfType, std::initi std::initializer_list genericPacks, std::initializer_list paramTypes, std::initializer_list paramNames, std::initializer_list retTypes); -static TypeVar nilType_{PrimitiveTypeVar{PrimitiveTypeVar::NilType}, /*persistent*/ true}; -static TypeVar numberType_{PrimitiveTypeVar{PrimitiveTypeVar::Number}, /*persistent*/ true}; -static TypeVar stringType_{PrimitiveTypeVar{PrimitiveTypeVar::String}, /*persistent*/ true}; -static TypeVar booleanType_{PrimitiveTypeVar{PrimitiveTypeVar::Boolean}, /*persistent*/ true}; -static TypeVar threadType_{PrimitiveTypeVar{PrimitiveTypeVar::Thread}, /*persistent*/ true}; -static TypeVar trueType_{SingletonTypeVar{BooleanSingleton{true}}, /*persistent*/ true}; -static TypeVar falseType_{SingletonTypeVar{BooleanSingleton{false}}, /*persistent*/ true}; -static TypeVar anyType_{AnyTypeVar{}, /*persistent*/ true}; -static TypeVar unknownType_{UnknownTypeVar{}, /*persistent*/ true}; -static TypeVar neverType_{NeverTypeVar{}, /*persistent*/ true}; -static TypeVar errorType_{ErrorTypeVar{}, /*persistent*/ true}; - -static TypePackVar anyTypePack_{VariadicTypePack{&anyType_}, /*persistent*/ true}; -static TypePackVar errorTypePack_{Unifiable::Error{}, /*persistent*/ true}; -static TypePackVar neverTypePack_{VariadicTypePack{&neverType_}, /*persistent*/ true}; -static TypePackVar uninhabitableTypePack_{TypePack{{&neverType_}, &neverTypePack_}, /*persistent*/ true}; - SingletonTypes::SingletonTypes() - : nilType(&nilType_) - , numberType(&numberType_) - , stringType(&stringType_) - , booleanType(&booleanType_) - , threadType(&threadType_) - , trueType(&trueType_) - , falseType(&falseType_) - , anyType(&anyType_) - , unknownType(&unknownType_) - , neverType(&neverType_) - , anyTypePack(&anyTypePack_) - , neverTypePack(&neverTypePack_) - , uninhabitableTypePack(&uninhabitableTypePack_) - , arena(new TypeArena) + : arena(new TypeArena) + , debugFreezeArena(FFlag::DebugLuauFreezeArena) + , nilType(arena->addType(TypeVar{PrimitiveTypeVar{PrimitiveTypeVar::NilType}, /*persistent*/ true})) + , numberType(arena->addType(TypeVar{PrimitiveTypeVar{PrimitiveTypeVar::Number}, /*persistent*/ true})) + , stringType(arena->addType(TypeVar{PrimitiveTypeVar{PrimitiveTypeVar::String}, /*persistent*/ true})) + , booleanType(arena->addType(TypeVar{PrimitiveTypeVar{PrimitiveTypeVar::Boolean}, /*persistent*/ true})) + , threadType(arena->addType(TypeVar{PrimitiveTypeVar{PrimitiveTypeVar::Thread}, /*persistent*/ true})) + , trueType(arena->addType(TypeVar{SingletonTypeVar{BooleanSingleton{true}}, /*persistent*/ true})) + , falseType(arena->addType(TypeVar{SingletonTypeVar{BooleanSingleton{false}}, /*persistent*/ true})) + , anyType(arena->addType(TypeVar{AnyTypeVar{}, /*persistent*/ true})) + , unknownType(arena->addType(TypeVar{UnknownTypeVar{}, /*persistent*/ true})) + , neverType(arena->addType(TypeVar{NeverTypeVar{}, /*persistent*/ true})) + , errorType(arena->addType(TypeVar{ErrorTypeVar{}, /*persistent*/ true})) + , anyTypePack(arena->addTypePack(TypePackVar{VariadicTypePack{anyType}, /*persistent*/ true})) + , neverTypePack(arena->addTypePack(TypePackVar{VariadicTypePack{neverType}, /*persistent*/ true})) + , uninhabitableTypePack(arena->addTypePack({neverType}, neverTypePack)) + , errorTypePack(arena->addTypePack(TypePackVar{Unifiable::Error{}, /*persistent*/ true})) { TypeId stringMetatable = makeStringMetatable(); - stringType_.ty = PrimitiveTypeVar{PrimitiveTypeVar::String, stringMetatable}; + asMutable(stringType)->ty = PrimitiveTypeVar{PrimitiveTypeVar::String, stringMetatable}; persist(stringMetatable); + persist(uninhabitableTypePack); - debugFreezeArena = FFlag::DebugLuauFreezeArena; freeze(*arena); } @@ -834,12 +821,12 @@ TypeId SingletonTypes::makeStringMetatable() TypeId SingletonTypes::errorRecoveryType() { - return &errorType_; + return errorType; } TypePackId SingletonTypes::errorRecoveryTypePack() { - return &errorTypePack_; + return errorTypePack; } TypeId SingletonTypes::errorRecoveryType(TypeId guess) @@ -852,7 +839,7 @@ TypePackId SingletonTypes::errorRecoveryTypePack(TypePackId guess) return guess; } -SingletonTypes& getSingletonTypes() +SingletonTypes& DEPRECATED_getSingletonTypes() { static SingletonTypes singletonTypes; return singletonTypes; diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index b135cd0c8..fd6784321 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -257,12 +257,12 @@ TypeId Widen::clean(TypeId ty) LUAU_ASSERT(stv); if (get(stv)) - return getSingletonTypes().stringType; + return singletonTypes->stringType; else { // If this assert trips, it's likely we now have number singletons. LUAU_ASSERT(get(stv)); - return getSingletonTypes().booleanType; + return singletonTypes->booleanType; } } @@ -317,9 +317,10 @@ static std::optional> getTableMat return std::nullopt; } -Unifier::Unifier(TypeArena* types, Mode mode, NotNull scope, const Location& location, Variance variance, UnifierSharedState& sharedState, - TxnLog* parentLog) +Unifier::Unifier(TypeArena* types, NotNull singletonTypes, Mode mode, NotNull scope, const Location& location, + Variance variance, UnifierSharedState& sharedState, TxnLog* parentLog) : types(types) + , singletonTypes(singletonTypes) , mode(mode) , scope(scope) , log(parentLog) @@ -409,7 +410,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool { promoteTypeLevels(log, types, superFree->level, subTy); - Widen widen{types}; + Widen widen{types, singletonTypes}; log.replace(superTy, BoundTypeVar(widen(subTy))); } @@ -1018,7 +1019,7 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal { if (!occursCheck(superTp, subTp)) { - Widen widen{types}; + Widen widen{types, singletonTypes}; log.replace(superTp, Unifiable::Bound(widen(subTp))); } } @@ -1162,13 +1163,13 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal while (superIter.good()) { - tryUnify_(*superIter, getSingletonTypes().errorRecoveryType()); + tryUnify_(*superIter, singletonTypes->errorRecoveryType()); superIter.advance(); } while (subIter.good()) { - tryUnify_(*subIter, getSingletonTypes().errorRecoveryType()); + tryUnify_(*subIter, singletonTypes->errorRecoveryType()); subIter.advance(); } @@ -1613,7 +1614,7 @@ void Unifier::tryUnifyScalarShape(TypeId subTy, TypeId superTy, bool reversed) // Given t1 where t1 = { lower: (t1) -> (a, b...) } // It should be the case that `string <: t1` iff `(subtype's metatable).__index <: t1` - if (auto metatable = getMetatable(subTy)) + if (auto metatable = getMetatable(subTy, singletonTypes)) { auto mttv = log.get(*metatable); if (!mttv) @@ -1658,10 +1659,10 @@ TypeId Unifier::deeplyOptional(TypeId ty, std::unordered_map see TableTypeVar* resultTtv = getMutable(result); for (auto& [name, prop] : resultTtv->props) prop.type = deeplyOptional(prop.type, seen); - return types->addType(UnionTypeVar{{getSingletonTypes().nilType, result}}); + return types->addType(UnionTypeVar{{singletonTypes->nilType, result}}); } else - return types->addType(UnionTypeVar{{getSingletonTypes().nilType, ty}}); + return types->addType(UnionTypeVar{{singletonTypes->nilType, ty}}); } void Unifier::tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed) @@ -1951,7 +1952,7 @@ void Unifier::tryUnifyWithAny(TypeId subTy, TypeId anyTy) anyTp = types->addTypePack(TypePackVar{VariadicTypePack{anyTy}}); else { - const TypePackId anyTypePack = types->addTypePack(TypePackVar{VariadicTypePack{getSingletonTypes().anyType}}); + const TypePackId anyTypePack = types->addTypePack(TypePackVar{VariadicTypePack{singletonTypes->anyType}}); anyTp = get(anyTy) ? anyTypePack : types->addTypePack(TypePackVar{Unifiable::Error{}}); } @@ -1960,15 +1961,15 @@ void Unifier::tryUnifyWithAny(TypeId subTy, TypeId anyTy) sharedState.tempSeenTy.clear(); sharedState.tempSeenTp.clear(); - Luau::tryUnifyWithAny(queue, *this, sharedState.tempSeenTy, sharedState.tempSeenTp, types, - FFlag::LuauUnknownAndNeverType ? anyTy : getSingletonTypes().anyType, anyTp); + Luau::tryUnifyWithAny( + queue, *this, sharedState.tempSeenTy, sharedState.tempSeenTp, types, FFlag::LuauUnknownAndNeverType ? anyTy : singletonTypes->anyType, anyTp); } void Unifier::tryUnifyWithAny(TypePackId subTy, TypePackId anyTp) { LUAU_ASSERT(get(anyTp)); - const TypeId anyTy = getSingletonTypes().errorRecoveryType(); + const TypeId anyTy = singletonTypes->errorRecoveryType(); std::vector queue; @@ -1982,7 +1983,7 @@ void Unifier::tryUnifyWithAny(TypePackId subTy, TypePackId anyTp) std::optional Unifier::findTablePropertyRespectingMeta(TypeId lhsType, Name name) { - return Luau::findTablePropertyRespectingMeta(errors, lhsType, name, location); + return Luau::findTablePropertyRespectingMeta(singletonTypes, errors, lhsType, name, location); } void Unifier::tryUnifyWithConstrainedSubTypeVar(TypeId subTy, TypeId superTy) @@ -2193,7 +2194,7 @@ bool Unifier::occursCheck(DenseHashSet& seen, TypeId needle, TypeId hays if (needle == haystack) { reportError(TypeError{location, OccursCheckFailed{}}); - log.replace(needle, *getSingletonTypes().errorRecoveryType()); + log.replace(needle, *singletonTypes->errorRecoveryType()); return true; } @@ -2250,7 +2251,7 @@ bool Unifier::occursCheck(DenseHashSet& seen, TypePackId needle, Typ if (needle == haystack) { reportError(TypeError{location, OccursCheckFailed{}}); - log.replace(needle, *getSingletonTypes().errorRecoveryTypePack()); + log.replace(needle, *singletonTypes->errorRecoveryTypePack()); return true; } @@ -2269,7 +2270,7 @@ bool Unifier::occursCheck(DenseHashSet& seen, TypePackId needle, Typ Unifier Unifier::makeChildUnifier() { - Unifier u = Unifier{types, mode, scope, location, variance, sharedState, &log}; + Unifier u = Unifier{types, singletonTypes, mode, scope, location, variance, sharedState, &log}; u.anyIsTop = anyIsTop; return u; } diff --git a/CodeGen/include/Luau/CodeAllocator.h b/CodeGen/include/Luau/CodeAllocator.h new file mode 100644 index 000000000..c80b5c389 --- /dev/null +++ b/CodeGen/include/Luau/CodeAllocator.h @@ -0,0 +1,50 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include +#include +#include + +namespace Luau +{ +namespace CodeGen +{ + +struct CodeAllocator +{ + CodeAllocator(size_t blockSize, size_t maxTotalSize); + ~CodeAllocator(); + + // Places data and code into the executable page area + // To allow allocation while previously allocated code is already running, allocation has page granularity + // It's important to group functions together so that page alignment won't result in a lot of wasted space + bool allocate(uint8_t* data, size_t dataSize, uint8_t* code, size_t codeSize, uint8_t*& result, size_t& resultSize, uint8_t*& resultCodeStart); + + // Provided to callbacks + void* context = nullptr; + + // Called when new block is created to create and setup the unwinding information for all the code in the block + // Some platforms require this data to be placed inside the block itself, so we also return 'unwindDataSizeInBlock' + void* (*createBlockUnwindInfo)(void* context, uint8_t* block, size_t blockSize, size_t& unwindDataSizeInBlock) = nullptr; + + // Called to destroy unwinding information returned by 'createBlockUnwindInfo' + void (*destroyBlockUnwindInfo)(void* context, void* unwindData) = nullptr; + + static const size_t kMaxUnwindDataSize = 128; + + bool allocateNewBlock(size_t& unwindInfoSize); + + // Current block we use for allocations + uint8_t* blockPos = nullptr; + uint8_t* blockEnd = nullptr; + + // All allocated blocks + std::vector blocks; + std::vector unwindInfos; + + size_t blockSize = 0; + size_t maxTotalSize = 0; +}; + +} // namespace CodeGen +} // namespace Luau diff --git a/CodeGen/include/Luau/OperandX64.h b/CodeGen/include/Luau/OperandX64.h index 146beafb9..432b5874a 100644 --- a/CodeGen/include/Luau/OperandX64.h +++ b/CodeGen/include/Luau/OperandX64.h @@ -1,3 +1,4 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once #include "Luau/Common.h" diff --git a/CodeGen/include/Luau/RegisterX64.h b/CodeGen/include/Luau/RegisterX64.h index ae89f600a..3b6e1a483 100644 --- a/CodeGen/include/Luau/RegisterX64.h +++ b/CodeGen/include/Luau/RegisterX64.h @@ -1,3 +1,4 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once #include "Luau/Common.h" diff --git a/CodeGen/src/CodeAllocator.cpp b/CodeGen/src/CodeAllocator.cpp new file mode 100644 index 000000000..f74320640 --- /dev/null +++ b/CodeGen/src/CodeAllocator.cpp @@ -0,0 +1,188 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/CodeAllocator.h" + +#include "Luau/Common.h" + +#include + +#if defined(_WIN32) +#define WIN32_LEAN_AND_MEAN +#define NOMINMAX +#include + +const size_t kPageSize = 4096; +#else +#include +#include + +const size_t kPageSize = sysconf(_SC_PAGESIZE); +#endif + +static size_t alignToPageSize(size_t size) +{ + return (size + kPageSize - 1) & ~(kPageSize - 1); +} + +#if defined(_WIN32) +static uint8_t* allocatePages(size_t size) +{ + return (uint8_t*)VirtualAlloc(nullptr, alignToPageSize(size), MEM_RESERVE | MEM_COMMIT, PAGE_READWRITE); +} + +static void freePages(uint8_t* mem, size_t size) +{ + if (VirtualFree(mem, 0, MEM_RELEASE) == 0) + LUAU_ASSERT(!"failed to deallocate block memory"); +} + +static void makePagesExecutable(uint8_t* mem, size_t size) +{ + LUAU_ASSERT((uintptr_t(mem) & (kPageSize - 1)) == 0); + LUAU_ASSERT(size == alignToPageSize(size)); + + DWORD oldProtect; + if (VirtualProtect(mem, size, PAGE_EXECUTE_READ, &oldProtect) == 0) + LUAU_ASSERT(!"failed to change page protection"); +} + +static void flushInstructionCache(uint8_t* mem, size_t size) +{ + if (FlushInstructionCache(GetCurrentProcess(), mem, size) == 0) + LUAU_ASSERT(!"failed to flush instruction cache"); +} +#else +static uint8_t* allocatePages(size_t size) +{ + return (uint8_t*)mmap(nullptr, alignToPageSize(size), PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_ANON, -1, 0); +} + +static void freePages(uint8_t* mem, size_t size) +{ + if (munmap(mem, alignToPageSize(size)) != 0) + LUAU_ASSERT(!"failed to deallocate block memory"); +} + +static void makePagesExecutable(uint8_t* mem, size_t size) +{ + LUAU_ASSERT((uintptr_t(mem) & (kPageSize - 1)) == 0); + LUAU_ASSERT(size == alignToPageSize(size)); + + if (mprotect(mem, size, PROT_READ | PROT_EXEC) != 0) + LUAU_ASSERT(!"failed to change page protection"); +} + +static void flushInstructionCache(uint8_t* mem, size_t size) +{ + __builtin___clear_cache((char*)mem, (char*)mem + size); +} +#endif + +namespace Luau +{ +namespace CodeGen +{ + +CodeAllocator::CodeAllocator(size_t blockSize, size_t maxTotalSize) + : blockSize(blockSize) + , maxTotalSize(maxTotalSize) +{ + LUAU_ASSERT(blockSize > kMaxUnwindDataSize); + LUAU_ASSERT(maxTotalSize >= blockSize); +} + +CodeAllocator::~CodeAllocator() +{ + if (destroyBlockUnwindInfo) + { + for (void* unwindInfo : unwindInfos) + destroyBlockUnwindInfo(context, unwindInfo); + } + + for (uint8_t* block : blocks) + freePages(block, blockSize); +} + +bool CodeAllocator::allocate( + uint8_t* data, size_t dataSize, uint8_t* code, size_t codeSize, uint8_t*& result, size_t& resultSize, uint8_t*& resultCodeStart) +{ + // 'Round up' to preserve 16 byte alignment + size_t alignedDataSize = (dataSize + 15) & ~15; + + size_t totalSize = alignedDataSize + codeSize; + + // Function has to fit into a single block with unwinding information + if (totalSize > blockSize - kMaxUnwindDataSize) + return false; + + size_t unwindInfoSize = 0; + + // We might need a new block + if (totalSize > size_t(blockEnd - blockPos)) + { + if (!allocateNewBlock(unwindInfoSize)) + return false; + + LUAU_ASSERT(totalSize <= size_t(blockEnd - blockPos)); + } + + LUAU_ASSERT((uintptr_t(blockPos) & (kPageSize - 1)) == 0); // Allocation starts on page boundary + + size_t dataOffset = unwindInfoSize + alignedDataSize - dataSize; + size_t codeOffset = unwindInfoSize + alignedDataSize; + + if (dataSize) + memcpy(blockPos + dataOffset, data, dataSize); + if (codeSize) + memcpy(blockPos + codeOffset, code, codeSize); + + size_t pageSize = alignToPageSize(unwindInfoSize + totalSize); + + makePagesExecutable(blockPos, pageSize); + flushInstructionCache(blockPos + codeOffset, codeSize); + + result = blockPos + unwindInfoSize; + resultSize = totalSize; + resultCodeStart = blockPos + codeOffset; + + blockPos += pageSize; + LUAU_ASSERT((uintptr_t(blockPos) & (kPageSize - 1)) == 0); // Allocation ends on page boundary + + return true; +} + +bool CodeAllocator::allocateNewBlock(size_t& unwindInfoSize) +{ + // Stop allocating once we reach a global limit + if ((blocks.size() + 1) * blockSize > maxTotalSize) + return false; + + uint8_t* block = allocatePages(blockSize); + + if (!block) + return false; + + blockPos = block; + blockEnd = block + blockSize; + + blocks.push_back(block); + + if (createBlockUnwindInfo) + { + void* unwindInfo = createBlockUnwindInfo(context, block, blockSize, unwindInfoSize); + + // 'Round up' to preserve 16 byte alignment of the following data and code + unwindInfoSize = (unwindInfoSize + 15) & ~15; + + LUAU_ASSERT(unwindInfoSize <= kMaxUnwindDataSize); + + if (!unwindInfo) + return false; + + unwindInfos.push_back(unwindInfo); + } + + return true; +} + +} // namespace CodeGen +} // namespace Luau diff --git a/Common/include/Luau/Bytecode.h b/Common/include/Luau/Bytecode.h index decde93fa..8b6ccddff 100644 --- a/Common/include/Luau/Bytecode.h +++ b/Common/include/Luau/Bytecode.h @@ -289,17 +289,19 @@ enum LuauOpcode // the first variable is then copied into index; generator/state are immutable, index isn't visible to user code LOP_FORGLOOP, - // FORGPREP_INEXT/FORGLOOP_INEXT: FORGLOOP with 2 output variables (no AUX encoding), assuming generator is luaB_inext - // FORGPREP_INEXT prepares the index variable and jumps to FORGLOOP_INEXT - // FORGLOOP_INEXT has identical encoding and semantics to FORGLOOP (except for AUX encoding) + // FORGPREP_INEXT: prepare FORGLOOP with 2 output variables (no AUX encoding), assuming generator is luaB_inext, and jump to FORGLOOP + // A: target register (see FORGLOOP for register layout) LOP_FORGPREP_INEXT, - LOP_FORGLOOP_INEXT, - // FORGPREP_NEXT/FORGLOOP_NEXT: FORGLOOP with 2 output variables (no AUX encoding), assuming generator is luaB_next - // FORGPREP_NEXT prepares the index variable and jumps to FORGLOOP_NEXT - // FORGLOOP_NEXT has identical encoding and semantics to FORGLOOP (except for AUX encoding) + // removed in v3 + LOP_DEP_FORGLOOP_INEXT, + + // FORGPREP_NEXT: prepare FORGLOOP with 2 output variables (no AUX encoding), assuming generator is luaB_next, and jump to FORGLOOP + // A: target register (see FORGLOOP for register layout) LOP_FORGPREP_NEXT, - LOP_FORGLOOP_NEXT, + + // removed in v3 + LOP_DEP_FORGLOOP_NEXT, // GETVARARGS: copy variables into the target register from vararg storage for current function // A: target register @@ -343,12 +345,9 @@ enum LuauOpcode // B: source register (for VAL/REF) or upvalue index (for UPVAL/UPREF) LOP_CAPTURE, - // JUMPIFEQK, JUMPIFNOTEQK: jumps to target offset if the comparison with constant is true (or false, for NOT variants) - // A: source register 1 - // D: jump offset (-32768..32767; 0 means "next instruction" aka "don't jump") - // AUX: constant table index - LOP_JUMPIFEQK, - LOP_JUMPIFNOTEQK, + // removed in v3 + LOP_DEP_JUMPIFEQK, + LOP_DEP_JUMPIFNOTEQK, // FASTCALL1: perform a fast call of a built-in function using 1 register argument // A: builtin function id (see LuauBuiltinFunction) diff --git a/Compiler/src/BytecodeBuilder.cpp b/Compiler/src/BytecodeBuilder.cpp index 713d08cdb..2848447f6 100644 --- a/Compiler/src/BytecodeBuilder.cpp +++ b/Compiler/src/BytecodeBuilder.cpp @@ -73,8 +73,6 @@ static int getOpLength(LuauOpcode op) case LOP_SETLIST: case LOP_FORGLOOP: case LOP_LOADKX: - case LOP_JUMPIFEQK: - case LOP_JUMPIFNOTEQK: case LOP_FASTCALL2: case LOP_FASTCALL2K: case LOP_JUMPXEQKNIL: @@ -106,12 +104,8 @@ inline bool isJumpD(LuauOpcode op) case LOP_FORGPREP: case LOP_FORGLOOP: case LOP_FORGPREP_INEXT: - case LOP_FORGLOOP_INEXT: case LOP_FORGPREP_NEXT: - case LOP_FORGLOOP_NEXT: case LOP_JUMPBACK: - case LOP_JUMPIFEQK: - case LOP_JUMPIFNOTEQK: case LOP_JUMPXEQKNIL: case LOP_JUMPXEQKB: case LOP_JUMPXEQKN: @@ -1247,13 +1241,6 @@ void BytecodeBuilder::validate() const VJUMP(LUAU_INSN_D(insn)); break; - case LOP_JUMPIFEQK: - case LOP_JUMPIFNOTEQK: - VREG(LUAU_INSN_A(insn)); - VCONSTANY(insns[i + 1]); - VJUMP(LUAU_INSN_D(insn)); - break; - case LOP_JUMPXEQKNIL: case LOP_JUMPXEQKB: VREG(LUAU_INSN_A(insn)); @@ -1360,9 +1347,7 @@ void BytecodeBuilder::validate() const break; case LOP_FORGPREP_INEXT: - case LOP_FORGLOOP_INEXT: case LOP_FORGPREP_NEXT: - case LOP_FORGLOOP_NEXT: VREG(LUAU_INSN_A(insn) + 4); // forg loop protocol: A, A+1, A+2 are used for iteration protocol; A+3, A+4 are loop variables VJUMP(LUAU_INSN_D(insn)); break; @@ -1728,18 +1713,10 @@ void BytecodeBuilder::dumpInstruction(const uint32_t* code, std::string& result, formatAppend(result, "FORGPREP_INEXT R%d L%d\n", LUAU_INSN_A(insn), targetLabel); break; - case LOP_FORGLOOP_INEXT: - formatAppend(result, "FORGLOOP_INEXT R%d L%d\n", LUAU_INSN_A(insn), targetLabel); - break; - case LOP_FORGPREP_NEXT: formatAppend(result, "FORGPREP_NEXT R%d L%d\n", LUAU_INSN_A(insn), targetLabel); break; - case LOP_FORGLOOP_NEXT: - formatAppend(result, "FORGLOOP_NEXT R%d L%d\n", LUAU_INSN_A(insn), targetLabel); - break; - case LOP_GETVARARGS: formatAppend(result, "GETVARARGS R%d %d\n", LUAU_INSN_A(insn), LUAU_INSN_B(insn) - 1); break; @@ -1797,14 +1774,6 @@ void BytecodeBuilder::dumpInstruction(const uint32_t* code, std::string& result, LUAU_INSN_A(insn) == LCT_UPVAL ? 'U' : 'R', LUAU_INSN_B(insn)); break; - case LOP_JUMPIFEQK: - formatAppend(result, "JUMPIFEQK R%d K%d L%d\n", LUAU_INSN_A(insn), *code++, targetLabel); - break; - - case LOP_JUMPIFNOTEQK: - formatAppend(result, "JUMPIFNOTEQK R%d K%d L%d\n", LUAU_INSN_A(insn), *code++, targetLabel); - break; - case LOP_JUMPXEQKNIL: formatAppend(result, "JUMPXEQKNIL R%d L%d%s\n", LUAU_INSN_A(insn), targetLabel, *code >> 31 ? " NOT" : ""); code++; diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index d44daf0cf..ce2c5a920 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -3457,14 +3457,6 @@ struct Compiler return uint8_t(top); } - void reserveReg(AstNode* node, unsigned int count) - { - if (regTop + count > kMaxRegisterCount) - CompileError::raise(node->location, "Out of registers when trying to allocate %d registers: exceeded limit %d", count, kMaxRegisterCount); - - stackSize = std::max(stackSize, regTop + count); - } - void setDebugLine(AstNode* node) { if (options.debugLevel >= 1) diff --git a/Makefile b/Makefile index 0db7b2818..bd72cf881 100644 --- a/Makefile +++ b/Makefile @@ -142,12 +142,16 @@ coverage: $(TESTS_TARGET) llvm-cov export -format lcov --instr-profile default.profdata build/coverage/luau-tests >coverage.info format: - find . -name '*.h' -or -name '*.cpp' | xargs clang-format-11 -i + git ls-files '*.h' '*.cpp' | xargs clang-format-11 -i luau-size: luau nm --print-size --demangle luau | grep ' t void luau_execute' | awk -F ' ' '{sum += strtonum("0x" $$2)} END {print sum " interpreter" }' nm --print-size --demangle luau | grep ' t luauF_' | awk -F ' ' '{sum += strtonum("0x" $$2)} END {print sum " builtins" }' +check-source: + git ls-files '*.h' '*.cpp' | xargs -I+ sh -c 'grep -L LICENSE +' + git ls-files '*.h' ':!:extern' | xargs -I+ sh -c 'grep -L "#pragma once" +' + # executable target aliases luau: $(REPL_CLI_TARGET) ln -fs $^ $@ diff --git a/Sources.cmake b/Sources.cmake index 50770f921..ff0b5a6ed 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -56,12 +56,14 @@ target_sources(Luau.Compiler PRIVATE # Luau.CodeGen Sources target_sources(Luau.CodeGen PRIVATE CodeGen/include/Luau/AssemblyBuilderX64.h + CodeGen/include/Luau/CodeAllocator.h CodeGen/include/Luau/Condition.h CodeGen/include/Luau/Label.h CodeGen/include/Luau/OperandX64.h CodeGen/include/Luau/RegisterX64.h CodeGen/src/AssemblyBuilderX64.cpp + CodeGen/src/CodeAllocator.cpp ) # Luau.Analysis Sources @@ -77,7 +79,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/include/Luau/Constraint.h Analysis/include/Luau/ConstraintGraphBuilder.h Analysis/include/Luau/ConstraintSolver.h - Analysis/include/Luau/ConstraintSolverLogger.h + Analysis/include/Luau/DcrLogger.h Analysis/include/Luau/Documentation.h Analysis/include/Luau/Error.h Analysis/include/Luau/FileResolver.h @@ -127,7 +129,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/src/Constraint.cpp Analysis/src/ConstraintGraphBuilder.cpp Analysis/src/ConstraintSolver.cpp - Analysis/src/ConstraintSolverLogger.cpp + Analysis/src/DcrLogger.cpp Analysis/src/EmbeddedBuiltinDefinitions.cpp Analysis/src/Error.cpp Analysis/src/Frontend.cpp @@ -266,6 +268,7 @@ if(TARGET Luau.UnitTest) tests/AstVisitor.test.cpp tests/Autocomplete.test.cpp tests/BuiltinDefinitions.test.cpp + tests/CodeAllocator.test.cpp tests/Compiler.test.cpp tests/Config.test.cpp tests/ConstraintGraphBuilder.test.cpp diff --git a/VM/include/lua.h b/VM/include/lua.h index 1f315a086..5ce40ae2a 100644 --- a/VM/include/lua.h +++ b/VM/include/lua.h @@ -35,6 +35,15 @@ enum lua_Status LUA_BREAK, // yielded for a debug breakpoint }; +enum lua_CoStatus +{ + LUA_CORUN = 0, // running + LUA_COSUS, // suspended + LUA_CONOR, // 'normal' (it resumed another coroutine) + LUA_COFIN, // finished + LUA_COERR, // finished with error +}; + typedef struct lua_State lua_State; typedef int (*lua_CFunction)(lua_State* L); @@ -224,6 +233,7 @@ LUA_API int lua_status(lua_State* L); LUA_API int lua_isyieldable(lua_State* L); LUA_API void* lua_getthreaddata(lua_State* L); LUA_API void lua_setthreaddata(lua_State* L, void* data); +LUA_API int lua_costatus(lua_State* L, lua_State* co); /* ** garbage-collection function and options diff --git a/VM/src/lapi.cpp b/VM/src/lapi.cpp index 4396e5d1f..6a9c46dae 100644 --- a/VM/src/lapi.cpp +++ b/VM/src/lapi.cpp @@ -1008,6 +1008,23 @@ int lua_status(lua_State* L) return L->status; } +int lua_costatus(lua_State* L, lua_State* co) +{ + if (co == L) + return LUA_CORUN; + if (co->status == LUA_YIELD) + return LUA_COSUS; + if (co->status == LUA_BREAK) + return LUA_CONOR; + if (co->status != 0) // some error occurred + return LUA_COERR; + if (co->ci != co->base_ci) // does it have frames? + return LUA_CONOR; + if (co->top == co->base) + return LUA_COFIN; + return LUA_COSUS; // initial state +} + void* lua_getthreaddata(lua_State* L) { return L->userdata; diff --git a/VM/src/lcorolib.cpp b/VM/src/lcorolib.cpp index 7b967e343..3d39a2de5 100644 --- a/VM/src/lcorolib.cpp +++ b/VM/src/lcorolib.cpp @@ -5,38 +5,16 @@ #include "lstate.h" #include "lvm.h" -#define CO_RUN 0 // running -#define CO_SUS 1 // suspended -#define CO_NOR 2 // 'normal' (it resumed another coroutine) -#define CO_DEAD 3 - #define CO_STATUS_ERROR -1 #define CO_STATUS_BREAK -2 -static const char* const statnames[] = {"running", "suspended", "normal", "dead"}; - -static int auxstatus(lua_State* L, lua_State* co) -{ - if (co == L) - return CO_RUN; - if (co->status == LUA_YIELD) - return CO_SUS; - if (co->status == LUA_BREAK) - return CO_NOR; - if (co->status != 0) // some error occurred - return CO_DEAD; - if (co->ci != co->base_ci) // does it have frames? - return CO_NOR; - if (co->top == co->base) - return CO_DEAD; - return CO_SUS; // initial state -} +static const char* const statnames[] = {"running", "suspended", "normal", "dead", "dead"}; // dead appears twice for LUA_COERR and LUA_COFIN static int costatus(lua_State* L) { lua_State* co = lua_tothread(L, 1); luaL_argexpected(L, co, 1, "thread"); - lua_pushstring(L, statnames[auxstatus(L, co)]); + lua_pushstring(L, statnames[lua_costatus(L, co)]); return 1; } @@ -45,8 +23,8 @@ static int auxresume(lua_State* L, lua_State* co, int narg) // error handling for edge cases if (co->status != LUA_YIELD) { - int status = auxstatus(L, co); - if (status != CO_SUS) + int status = lua_costatus(L, co); + if (status != LUA_COSUS) { lua_pushfstring(L, "cannot resume %s coroutine", statnames[status]); return CO_STATUS_ERROR; @@ -236,8 +214,8 @@ static int coclose(lua_State* L) lua_State* co = lua_tothread(L, 1); luaL_argexpected(L, co, 1, "thread"); - int status = auxstatus(L, co); - if (status != CO_DEAD && status != CO_SUS) + int status = lua_costatus(L, co); + if (status != LUA_COFIN && status != LUA_COERR && status != LUA_COSUS) luaL_error(L, "cannot close %s coroutine", statnames[status]); if (co->status == LUA_OK || co->status == LUA_YIELD) diff --git a/VM/src/lgc.cpp b/VM/src/lgc.cpp index b95d6dee4..fb610e130 100644 --- a/VM/src/lgc.cpp +++ b/VM/src/lgc.cpp @@ -123,6 +123,7 @@ LUAU_FASTFLAGVARIABLE(LuauSimplerUpval, false) LUAU_FASTFLAGVARIABLE(LuauNoSleepBit, false) LUAU_FASTFLAGVARIABLE(LuauEagerShrink, false) +LUAU_FASTFLAGVARIABLE(LuauFasterSweep, false) #define GC_SWEEPPAGESTEPCOST 16 @@ -848,6 +849,7 @@ static size_t atomic(lua_State* L) static bool sweepgco(lua_State* L, lua_Page* page, GCObject* gco) { + LUAU_ASSERT(!FFlag::LuauFasterSweep); global_State* g = L->global; int deadmask = otherwhite(g); @@ -890,22 +892,62 @@ static int sweepgcopage(lua_State* L, lua_Page* page) int blockSize; luaM_getpagewalkinfo(page, &start, &end, &busyBlocks, &blockSize); - for (char* pos = start; pos != end; pos += blockSize) + LUAU_ASSERT(busyBlocks > 0); + + if (FFlag::LuauFasterSweep) { - GCObject* gco = (GCObject*)pos; + LUAU_ASSERT(FFlag::LuauNoSleepBit && FFlag::LuauEagerShrink); + + global_State* g = L->global; + + int deadmask = otherwhite(g); + LUAU_ASSERT(testbit(deadmask, FIXEDBIT)); // make sure we never sweep fixed objects - // skip memory blocks that are already freed - if (gco->gch.tt == LUA_TNIL) - continue; + int newwhite = luaC_white(g); - // when true is returned it means that the element was deleted - if (sweepgco(L, page, gco)) + for (char* pos = start; pos != end; pos += blockSize) { - LUAU_ASSERT(busyBlocks > 0); + GCObject* gco = (GCObject*)pos; + + // skip memory blocks that are already freed + if (gco->gch.tt == LUA_TNIL) + continue; + + // is the object alive? + if ((gco->gch.marked ^ WHITEBITS) & deadmask) + { + LUAU_ASSERT(!isdead(g, gco)); + // make it white (for next cycle) + gco->gch.marked = cast_byte((gco->gch.marked & maskmarks) | newwhite); + } + else + { + LUAU_ASSERT(isdead(g, gco)); + freeobj(L, gco, page); - // if the last block was removed, page would be removed as well - if (--busyBlocks == 0) - return int(pos - start) / blockSize + 1; + // if the last block was removed, page would be removed as well + if (--busyBlocks == 0) + return int(pos - start) / blockSize + 1; + } + } + } + else + { + for (char* pos = start; pos != end; pos += blockSize) + { + GCObject* gco = (GCObject*)pos; + + // skip memory blocks that are already freed + if (gco->gch.tt == LUA_TNIL) + continue; + + // when true is returned it means that the element was deleted + if (sweepgco(L, page, gco)) + { + // if the last block was removed, page would be removed as well + if (--busyBlocks == 0) + return int(pos - start) / blockSize + 1; + } } } @@ -993,10 +1035,19 @@ static size_t gcstep(lua_State* L, size_t limit) // nothing more to sweep? if (g->sweepgcopage == NULL) { - // don't forget to visit main thread - sweepgco(L, NULL, obj2gco(g->mainthread)); + // don't forget to visit main thread, it's the only object not allocated in GCO pages + if (FFlag::LuauFasterSweep) + { + LUAU_ASSERT(!isdead(g, obj2gco(g->mainthread))); + makewhite(g, obj2gco(g->mainthread)); // make it white (for next cycle) + } + else + { + sweepgco(L, NULL, obj2gco(g->mainthread)); + } shrinkbuffers(L); + g->gcstate = GCSpause; // end collection } break; diff --git a/VM/src/lvmexecute.cpp b/VM/src/lvmexecute.cpp index aa1da8aee..e34889160 100644 --- a/VM/src/lvmexecute.cpp +++ b/VM/src/lvmexecute.cpp @@ -107,10 +107,10 @@ LUAU_FASTFLAG(LuauNoSleepBit) VM_DISPATCH_OP(LOP_POWK), VM_DISPATCH_OP(LOP_AND), VM_DISPATCH_OP(LOP_OR), VM_DISPATCH_OP(LOP_ANDK), VM_DISPATCH_OP(LOP_ORK), \ VM_DISPATCH_OP(LOP_CONCAT), VM_DISPATCH_OP(LOP_NOT), VM_DISPATCH_OP(LOP_MINUS), VM_DISPATCH_OP(LOP_LENGTH), VM_DISPATCH_OP(LOP_NEWTABLE), \ VM_DISPATCH_OP(LOP_DUPTABLE), VM_DISPATCH_OP(LOP_SETLIST), VM_DISPATCH_OP(LOP_FORNPREP), VM_DISPATCH_OP(LOP_FORNLOOP), \ - VM_DISPATCH_OP(LOP_FORGLOOP), VM_DISPATCH_OP(LOP_FORGPREP_INEXT), VM_DISPATCH_OP(LOP_FORGLOOP_INEXT), VM_DISPATCH_OP(LOP_FORGPREP_NEXT), \ - VM_DISPATCH_OP(LOP_FORGLOOP_NEXT), VM_DISPATCH_OP(LOP_GETVARARGS), VM_DISPATCH_OP(LOP_DUPCLOSURE), VM_DISPATCH_OP(LOP_PREPVARARGS), \ + VM_DISPATCH_OP(LOP_FORGLOOP), VM_DISPATCH_OP(LOP_FORGPREP_INEXT), VM_DISPATCH_OP(LOP_DEP_FORGLOOP_INEXT), VM_DISPATCH_OP(LOP_FORGPREP_NEXT), \ + VM_DISPATCH_OP(LOP_DEP_FORGLOOP_NEXT), VM_DISPATCH_OP(LOP_GETVARARGS), VM_DISPATCH_OP(LOP_DUPCLOSURE), VM_DISPATCH_OP(LOP_PREPVARARGS), \ VM_DISPATCH_OP(LOP_LOADKX), VM_DISPATCH_OP(LOP_JUMPX), VM_DISPATCH_OP(LOP_FASTCALL), VM_DISPATCH_OP(LOP_COVERAGE), \ - VM_DISPATCH_OP(LOP_CAPTURE), VM_DISPATCH_OP(LOP_JUMPIFEQK), VM_DISPATCH_OP(LOP_JUMPIFNOTEQK), VM_DISPATCH_OP(LOP_FASTCALL1), \ + VM_DISPATCH_OP(LOP_CAPTURE), VM_DISPATCH_OP(LOP_DEP_JUMPIFEQK), VM_DISPATCH_OP(LOP_DEP_JUMPIFNOTEQK), VM_DISPATCH_OP(LOP_FASTCALL1), \ VM_DISPATCH_OP(LOP_FASTCALL2), VM_DISPATCH_OP(LOP_FASTCALL2K), VM_DISPATCH_OP(LOP_FORGPREP), VM_DISPATCH_OP(LOP_JUMPXEQKNIL), \ VM_DISPATCH_OP(LOP_JUMPXEQKB), VM_DISPATCH_OP(LOP_JUMPXEQKN), VM_DISPATCH_OP(LOP_JUMPXEQKS), @@ -2401,7 +2401,7 @@ static void luau_execute(lua_State* L) VM_NEXT(); } - VM_CASE(LOP_FORGLOOP_INEXT) + VM_CASE(LOP_DEP_FORGLOOP_INEXT) { VM_INTERRUPT(); Instruction insn = *pc++; @@ -2473,7 +2473,7 @@ static void luau_execute(lua_State* L) VM_NEXT(); } - VM_CASE(LOP_FORGLOOP_NEXT) + VM_CASE(LOP_DEP_FORGLOOP_NEXT) { VM_INTERRUPT(); Instruction insn = *pc++; @@ -2748,7 +2748,7 @@ static void luau_execute(lua_State* L) LUAU_UNREACHABLE(); } - VM_CASE(LOP_JUMPIFEQK) + VM_CASE(LOP_DEP_JUMPIFEQK) { Instruction insn = *pc++; uint32_t aux = *pc; @@ -2793,7 +2793,7 @@ static void luau_execute(lua_State* L) } } - VM_CASE(LOP_JUMPIFNOTEQK) + VM_CASE(LOP_DEP_JUMPIFNOTEQK) { Instruction insn = *pc++; uint32_t aux = *pc; diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index 988cbe807..a64d372fc 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -107,8 +107,8 @@ struct ACFixture : ACFixtureImpl ACFixture() : ACFixtureImpl() { - addGlobalBinding(frontend.typeChecker, "table", Binding{typeChecker.anyType}); - addGlobalBinding(frontend.typeChecker, "math", Binding{typeChecker.anyType}); + addGlobalBinding(frontend, "table", Binding{typeChecker.anyType}); + addGlobalBinding(frontend, "math", Binding{typeChecker.anyType}); addGlobalBinding(frontend.typeCheckerForAutocomplete, "table", Binding{typeChecker.anyType}); addGlobalBinding(frontend.typeCheckerForAutocomplete, "math", Binding{typeChecker.anyType}); } @@ -3200,8 +3200,6 @@ a.@1 TEST_CASE_FIXTURE(ACFixture, "globals_are_order_independent") { - ScopedFastFlag sff("LuauAutocompleteFixGlobalOrder", true); - check(R"( local myLocal = 4 function abc0() diff --git a/tests/CodeAllocator.test.cpp b/tests/CodeAllocator.test.cpp new file mode 100644 index 000000000..005bc9598 --- /dev/null +++ b/tests/CodeAllocator.test.cpp @@ -0,0 +1,160 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/AssemblyBuilderX64.h" +#include "Luau/CodeAllocator.h" + +#include "doctest.h" + +#include + +using namespace Luau::CodeGen; + +TEST_SUITE_BEGIN("CodeAllocation"); + +TEST_CASE("CodeAllocation") +{ + size_t blockSize = 1024 * 1024; + size_t maxTotalSize = 1024 * 1024; + CodeAllocator allocator(blockSize, maxTotalSize); + + uint8_t* nativeData = nullptr; + size_t sizeNativeData = 0; + uint8_t* nativeEntry = nullptr; + + std::vector code; + code.resize(128); + + REQUIRE(allocator.allocate(nullptr, 0, code.data(), code.size(), nativeData, sizeNativeData, nativeEntry)); + CHECK(nativeData != nullptr); + CHECK(sizeNativeData == 128); + CHECK(nativeEntry != nullptr); + CHECK(nativeEntry == nativeData); + + std::vector data; + data.resize(8); + + REQUIRE(allocator.allocate(data.data(), data.size(), code.data(), code.size(), nativeData, sizeNativeData, nativeEntry)); + CHECK(nativeData != nullptr); + CHECK(sizeNativeData == 16 + 128); + CHECK(nativeEntry != nullptr); + CHECK(nativeEntry == nativeData + 16); +} + +TEST_CASE("CodeAllocationFailure") +{ + size_t blockSize = 4096; + size_t maxTotalSize = 8192; + CodeAllocator allocator(blockSize, maxTotalSize); + + uint8_t* nativeData; + size_t sizeNativeData; + uint8_t* nativeEntry; + + std::vector code; + code.resize(6000); + + REQUIRE(!allocator.allocate(nullptr, 0, code.data(), code.size(), nativeData, sizeNativeData, nativeEntry)); + + code.resize(3000); + REQUIRE(allocator.allocate(nullptr, 0, code.data(), code.size(), nativeData, sizeNativeData, nativeEntry)); + REQUIRE(allocator.allocate(nullptr, 0, code.data(), code.size(), nativeData, sizeNativeData, nativeEntry)); + REQUIRE(!allocator.allocate(nullptr, 0, code.data(), code.size(), nativeData, sizeNativeData, nativeEntry)); +} + +TEST_CASE("CodeAllocationWithUnwindCallbacks") +{ + struct Info + { + std::vector unwind; + uint8_t* block = nullptr; + bool destroyCalled = false; + }; + Info info; + info.unwind.resize(8); + + { + size_t blockSize = 1024 * 1024; + size_t maxTotalSize = 1024 * 1024; + CodeAllocator allocator(blockSize, maxTotalSize); + + uint8_t* nativeData = nullptr; + size_t sizeNativeData = 0; + uint8_t* nativeEntry = nullptr; + + std::vector code; + code.resize(128); + + std::vector data; + data.resize(8); + + allocator.context = &info; + allocator.createBlockUnwindInfo = [](void* context, uint8_t* block, size_t blockSize, size_t& unwindDataSizeInBlock) -> void* { + Info& info = *(Info*)context; + + CHECK(info.unwind.size() == 8); + memcpy(block, info.unwind.data(), info.unwind.size()); + unwindDataSizeInBlock = 8; + + info.block = block; + + return new int(7); + }; + allocator.destroyBlockUnwindInfo = [](void* context, void* unwindData) { + Info& info = *(Info*)context; + + info.destroyCalled = true; + + CHECK(*(int*)unwindData == 7); + delete (int*)unwindData; + }; + + REQUIRE(allocator.allocate(data.data(), data.size(), code.data(), code.size(), nativeData, sizeNativeData, nativeEntry)); + CHECK(nativeData != nullptr); + CHECK(sizeNativeData == 16 + 128); + CHECK(nativeEntry != nullptr); + CHECK(nativeEntry == nativeData + 16); + CHECK(nativeData == info.block + 16); + } + + CHECK(info.destroyCalled); +} + +#if defined(__x86_64__) || defined(_M_X64) +TEST_CASE("GeneratedCodeExecution") +{ +#if defined(_WIN32) + // Windows x64 ABI + constexpr RegisterX64 rArg1 = rcx; + constexpr RegisterX64 rArg2 = rdx; +#else + // System V AMD64 ABI + constexpr RegisterX64 rArg1 = rdi; + constexpr RegisterX64 rArg2 = rsi; +#endif + + AssemblyBuilderX64 build(/* logText= */ false); + + build.mov(rax, rArg1); + build.add(rax, rArg2); + build.imul(rax, rax, 7); + build.ret(); + + build.finalize(); + + size_t blockSize = 1024 * 1024; + size_t maxTotalSize = 1024 * 1024; + CodeAllocator allocator(blockSize, maxTotalSize); + + uint8_t* nativeData; + size_t sizeNativeData; + uint8_t* nativeEntry; + REQUIRE(allocator.allocate(build.data.data(), build.data.size(), build.code.data(), build.code.size(), nativeData, sizeNativeData, nativeEntry)); + REQUIRE(nativeEntry); + + using FunctionType = int64_t(int64_t, int64_t); + FunctionType* f = (FunctionType*)nativeEntry; + int64_t result = f(10, 20); + CHECK(result == 210); +} +#endif + +TEST_SUITE_END(); diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index d8520a6e0..e6222b029 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -6100,6 +6100,7 @@ return math.round(7.6), bit32.extract(-1, 31), bit32.replace(100, 1, 0), + math.log(100, 10), (type("fin")) )", 0, 2), @@ -6153,8 +6154,9 @@ LOADN R45 1 LOADN R46 8 LOADN R47 1 LOADN R48 101 -LOADK R49 K3 -RETURN R0 50 +LOADN R49 2 +LOADK R50 K3 +RETURN R0 51 )"); } @@ -6166,7 +6168,12 @@ return math.max(1, true), string.byte("abc", 42), bit32.rshift(10, 42), - bit32.extract(1, 2, "3") + bit32.extract(1, 2, "3"), + bit32.bor(1, true), + bit32.band(1, true), + bit32.bxor(1, true), + bit32.btest(1, true), + math.min(1, true) )", 0, 2), R"( @@ -6193,9 +6200,94 @@ LOADN R6 2 LOADK R7 K14 FASTCALL 34 L4 GETIMPORT R4 16 -CALL R4 3 -1 -L4: RETURN R0 -1 -)"); +CALL R4 3 1 +L4: LOADN R6 1 +FASTCALL2K 31 R6 K3 L5 +LOADK R7 K3 +GETIMPORT R5 18 +CALL R5 2 1 +L5: LOADN R7 1 +FASTCALL2K 29 R7 K3 L6 +LOADK R8 K3 +GETIMPORT R6 20 +CALL R6 2 1 +L6: LOADN R8 1 +FASTCALL2K 32 R8 K3 L7 +LOADK R9 K3 +GETIMPORT R7 22 +CALL R7 2 1 +L7: LOADN R9 1 +FASTCALL2K 33 R9 K3 L8 +LOADK R10 K3 +GETIMPORT R8 24 +CALL R8 2 1 +L8: LOADN R10 1 +FASTCALL2K 19 R10 K3 L9 +LOADK R11 K3 +GETIMPORT R9 26 +CALL R9 2 -1 +L9: RETURN R0 -1 +)"); +} + +TEST_CASE("BuiltinFoldingProhibitedCoverage") +{ + const char* builtins[] = { + "math.abs", + "math.acos", + "math.asin", + "math.atan2", + "math.atan", + "math.ceil", + "math.cosh", + "math.cos", + "math.deg", + "math.exp", + "math.floor", + "math.fmod", + "math.ldexp", + "math.log10", + "math.log", + "math.max", + "math.min", + "math.pow", + "math.rad", + "math.sinh", + "math.sin", + "math.sqrt", + "math.tanh", + "math.tan", + "bit32.arshift", + "bit32.band", + "bit32.bnot", + "bit32.bor", + "bit32.bxor", + "bit32.btest", + "bit32.extract", + "bit32.lrotate", + "bit32.lshift", + "bit32.replace", + "bit32.rrotate", + "bit32.rshift", + "type", + "string.byte", + "string.len", + "typeof", + "math.clamp", + "math.sign", + "math.round", + }; + + for (const char* func : builtins) + { + std::string source = "return "; + source += func; + source += "()"; + + std::string bc = compileFunction(source.c_str(), 0, 2); + + CHECK(bc.find("FASTCALL") != std::string::npos); + } } TEST_CASE("BuiltinFoldingMultret") diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index a7ffb493d..c6bdb4dbe 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -496,7 +496,8 @@ TEST_CASE("Types") runConformance("types.lua", [](lua_State* L) { Luau::NullModuleResolver moduleResolver; Luau::InternalErrorReporter iceHandler; - Luau::TypeChecker env(&moduleResolver, &iceHandler); + Luau::SingletonTypes singletonTypes; + Luau::TypeChecker env(&moduleResolver, Luau::NotNull{&singletonTypes}, &iceHandler); Luau::registerBuiltinTypes(env); Luau::freeze(env.globalTypes); diff --git a/tests/ConstraintSolver.test.cpp b/tests/ConstraintSolver.test.cpp index 2c4897330..fba578230 100644 --- a/tests/ConstraintSolver.test.cpp +++ b/tests/ConstraintSolver.test.cpp @@ -26,10 +26,10 @@ TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "hello") )"); cgb.visit(block); - NotNull rootScope = NotNull(cgb.rootScope); + NotNull rootScope{cgb.rootScope}; NullModuleResolver resolver; - ConstraintSolver cs{&arena, rootScope, "MainModule", NotNull(&resolver), {}}; + ConstraintSolver cs{&arena, singletonTypes, rootScope, "MainModule", NotNull(&resolver), {}, &logger}; cs.run(); @@ -47,10 +47,10 @@ TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "generic_function") )"); cgb.visit(block); - NotNull rootScope = NotNull(cgb.rootScope); + NotNull rootScope{cgb.rootScope}; NullModuleResolver resolver; - ConstraintSolver cs{&arena, rootScope, "MainModule", NotNull(&resolver), {}}; + ConstraintSolver cs{&arena, singletonTypes, rootScope, "MainModule", NotNull(&resolver), {}, &logger}; cs.run(); @@ -74,12 +74,12 @@ TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "proper_let_generalization") )"); cgb.visit(block); - NotNull rootScope = NotNull(cgb.rootScope); + NotNull rootScope{cgb.rootScope}; ToStringOptions opts; NullModuleResolver resolver; - ConstraintSolver cs{&arena, rootScope, "MainModule", NotNull(&resolver), {}}; + ConstraintSolver cs{&arena, singletonTypes, rootScope, "MainModule", NotNull(&resolver), {}, &logger}; cs.run(); diff --git a/tests/CostModel.test.cpp b/tests/CostModel.test.cpp index eacc718b3..d82d5d835 100644 --- a/tests/CostModel.test.cpp +++ b/tests/CostModel.test.cpp @@ -1,6 +1,8 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Parser.h" +#include "ScopedFlags.h" + #include "doctest.h" using namespace Luau; @@ -223,4 +225,21 @@ end CHECK_EQ(6, Luau::Compile::computeCost(model, args2, 1)); } +TEST_CASE("InterpString") +{ + ScopedFastFlag sff("LuauInterpolatedStringBaseSupport", true); + + uint64_t model = modelFunction(R"( +function test(a) + return `hello, {a}!` +end +)"); + + const bool args1[] = {false}; + const bool args2[] = {true}; + + CHECK_EQ(3, Luau::Compile::computeCost(model, args1, 1)); + CHECK_EQ(3, Luau::Compile::computeCost(model, args2, 1)); +} + TEST_SUITE_END(); diff --git a/tests/Fixture.cpp b/tests/Fixture.cpp index 476b7a2a5..6c4594f4b 100644 --- a/tests/Fixture.cpp +++ b/tests/Fixture.cpp @@ -91,6 +91,7 @@ Fixture::Fixture(bool freeze, bool prepareAutocomplete) : sff_DebugLuauFreezeArena("DebugLuauFreezeArena", freeze) , frontend(&fileResolver, &configResolver, {/* retainFullTypeGraphs= */ true}) , typeChecker(frontend.typeChecker) + , singletonTypes(frontend.singletonTypes) { configResolver.defaultConfig.mode = Mode::Strict; configResolver.defaultConfig.enabledLint.warningMask = ~0ull; @@ -367,9 +368,9 @@ void Fixture::dumpErrors(std::ostream& os, const std::vector& errors) void Fixture::registerTestTypes() { - addGlobalBinding(typeChecker, "game", typeChecker.anyType, "@luau"); - addGlobalBinding(typeChecker, "workspace", typeChecker.anyType, "@luau"); - addGlobalBinding(typeChecker, "script", typeChecker.anyType, "@luau"); + addGlobalBinding(frontend, "game", typeChecker.anyType, "@luau"); + addGlobalBinding(frontend, "workspace", typeChecker.anyType, "@luau"); + addGlobalBinding(frontend, "script", typeChecker.anyType, "@luau"); } void Fixture::dumpErrors(const CheckResult& cr) @@ -434,7 +435,7 @@ BuiltinsFixture::BuiltinsFixture(bool freeze, bool prepareAutocomplete) Luau::unfreeze(frontend.typeChecker.globalTypes); Luau::unfreeze(frontend.typeCheckerForAutocomplete.globalTypes); - registerBuiltinTypes(frontend.typeChecker); + registerBuiltinTypes(frontend); if (prepareAutocomplete) registerBuiltinTypes(frontend.typeCheckerForAutocomplete); registerTestTypes(); @@ -446,7 +447,7 @@ BuiltinsFixture::BuiltinsFixture(bool freeze, bool prepareAutocomplete) ConstraintGraphBuilderFixture::ConstraintGraphBuilderFixture() : Fixture() , mainModule(new Module) - , cgb(mainModuleName, mainModule, &arena, NotNull(&moduleResolver), NotNull(&ice), frontend.getGlobalScope()) + , cgb(mainModuleName, mainModule, &arena, NotNull(&moduleResolver), singletonTypes, NotNull(&ice), frontend.getGlobalScope(), &logger) , forceTheFlag{"DebugLuauDeferredConstraintResolution", true} { BlockedTypeVar::nextIndex = 0; diff --git a/tests/Fixture.h b/tests/Fixture.h index 8923b2085..03101bbf3 100644 --- a/tests/Fixture.h +++ b/tests/Fixture.h @@ -12,6 +12,7 @@ #include "Luau/ToString.h" #include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" +#include "Luau/DcrLogger.h" #include "IostreamOptional.h" #include "ScopedFlags.h" @@ -137,6 +138,7 @@ struct Fixture Frontend frontend; InternalErrorReporter ice; TypeChecker& typeChecker; + NotNull singletonTypes; std::string decorateWithTypes(const std::string& code); @@ -165,6 +167,7 @@ struct ConstraintGraphBuilderFixture : Fixture TypeArena arena; ModulePtr mainModule; ConstraintGraphBuilder cgb; + DcrLogger logger; ScopedFastFlag forceTheFlag; diff --git a/tests/Frontend.test.cpp b/tests/Frontend.test.cpp index 3c27ad54e..a8a9e044b 100644 --- a/tests/Frontend.test.cpp +++ b/tests/Frontend.test.cpp @@ -81,8 +81,8 @@ struct FrontendFixture : BuiltinsFixture { FrontendFixture() { - addGlobalBinding(typeChecker, "game", frontend.typeChecker.anyType, "@test"); - addGlobalBinding(typeChecker, "script", frontend.typeChecker.anyType, "@test"); + addGlobalBinding(frontend, "game", frontend.typeChecker.anyType, "@test"); + addGlobalBinding(frontend, "script", frontend.typeChecker.anyType, "@test"); } }; diff --git a/tests/LValue.test.cpp b/tests/LValue.test.cpp index 606f6de39..0bb91ecef 100644 --- a/tests/LValue.test.cpp +++ b/tests/LValue.test.cpp @@ -34,23 +34,28 @@ static LValue mkSymbol(const std::string& s) return Symbol{AstName{s.data()}}; } +struct LValueFixture +{ + SingletonTypes singletonTypes; +}; + TEST_SUITE_BEGIN("LValue"); -TEST_CASE("Luau_merge_hashmap_order") +TEST_CASE_FIXTURE(LValueFixture, "Luau_merge_hashmap_order") { std::string a = "a"; std::string b = "b"; std::string c = "c"; RefinementMap m{{ - {mkSymbol(b), getSingletonTypes().stringType}, - {mkSymbol(c), getSingletonTypes().numberType}, + {mkSymbol(b), singletonTypes.stringType}, + {mkSymbol(c), singletonTypes.numberType}, }}; RefinementMap other{{ - {mkSymbol(a), getSingletonTypes().stringType}, - {mkSymbol(b), getSingletonTypes().stringType}, - {mkSymbol(c), getSingletonTypes().booleanType}, + {mkSymbol(a), singletonTypes.stringType}, + {mkSymbol(b), singletonTypes.stringType}, + {mkSymbol(c), singletonTypes.booleanType}, }}; TypeArena arena; @@ -66,21 +71,21 @@ TEST_CASE("Luau_merge_hashmap_order") CHECK_EQ("boolean | number", toString(m[mkSymbol(c)])); } -TEST_CASE("Luau_merge_hashmap_order2") +TEST_CASE_FIXTURE(LValueFixture, "Luau_merge_hashmap_order2") { std::string a = "a"; std::string b = "b"; std::string c = "c"; RefinementMap m{{ - {mkSymbol(a), getSingletonTypes().stringType}, - {mkSymbol(b), getSingletonTypes().stringType}, - {mkSymbol(c), getSingletonTypes().numberType}, + {mkSymbol(a), singletonTypes.stringType}, + {mkSymbol(b), singletonTypes.stringType}, + {mkSymbol(c), singletonTypes.numberType}, }}; RefinementMap other{{ - {mkSymbol(b), getSingletonTypes().stringType}, - {mkSymbol(c), getSingletonTypes().booleanType}, + {mkSymbol(b), singletonTypes.stringType}, + {mkSymbol(c), singletonTypes.booleanType}, }}; TypeArena arena; @@ -96,7 +101,7 @@ TEST_CASE("Luau_merge_hashmap_order2") CHECK_EQ("boolean | number", toString(m[mkSymbol(c)])); } -TEST_CASE("one_map_has_overlap_at_end_whereas_other_has_it_in_start") +TEST_CASE_FIXTURE(LValueFixture, "one_map_has_overlap_at_end_whereas_other_has_it_in_start") { std::string a = "a"; std::string b = "b"; @@ -105,15 +110,15 @@ TEST_CASE("one_map_has_overlap_at_end_whereas_other_has_it_in_start") std::string e = "e"; RefinementMap m{{ - {mkSymbol(a), getSingletonTypes().stringType}, - {mkSymbol(b), getSingletonTypes().numberType}, - {mkSymbol(c), getSingletonTypes().booleanType}, + {mkSymbol(a), singletonTypes.stringType}, + {mkSymbol(b), singletonTypes.numberType}, + {mkSymbol(c), singletonTypes.booleanType}, }}; RefinementMap other{{ - {mkSymbol(c), getSingletonTypes().stringType}, - {mkSymbol(d), getSingletonTypes().numberType}, - {mkSymbol(e), getSingletonTypes().booleanType}, + {mkSymbol(c), singletonTypes.stringType}, + {mkSymbol(d), singletonTypes.numberType}, + {mkSymbol(e), singletonTypes.booleanType}, }}; TypeArena arena; @@ -133,7 +138,7 @@ TEST_CASE("one_map_has_overlap_at_end_whereas_other_has_it_in_start") CHECK_EQ("boolean", toString(m[mkSymbol(e)])); } -TEST_CASE("hashing_lvalue_global_prop_access") +TEST_CASE_FIXTURE(LValueFixture, "hashing_lvalue_global_prop_access") { std::string t1 = "t"; std::string x1 = "x"; @@ -154,13 +159,13 @@ TEST_CASE("hashing_lvalue_global_prop_access") CHECK_EQ(LValueHasher{}(t_x2), LValueHasher{}(t_x2)); RefinementMap m; - m[t_x1] = getSingletonTypes().stringType; - m[t_x2] = getSingletonTypes().numberType; + m[t_x1] = singletonTypes.stringType; + m[t_x2] = singletonTypes.numberType; CHECK_EQ(1, m.size()); } -TEST_CASE("hashing_lvalue_local_prop_access") +TEST_CASE_FIXTURE(LValueFixture, "hashing_lvalue_local_prop_access") { std::string t1 = "t"; std::string x1 = "x"; @@ -183,8 +188,8 @@ TEST_CASE("hashing_lvalue_local_prop_access") CHECK_EQ(LValueHasher{}(t_x2), LValueHasher{}(t_x2)); RefinementMap m; - m[t_x1] = getSingletonTypes().stringType; - m[t_x2] = getSingletonTypes().numberType; + m[t_x1] = singletonTypes.stringType; + m[t_x2] = singletonTypes.numberType; CHECK_EQ(2, m.size()); } diff --git a/tests/Linter.test.cpp b/tests/Linter.test.cpp index b560c89e3..8c7d762ed 100644 --- a/tests/Linter.test.cpp +++ b/tests/Linter.test.cpp @@ -35,7 +35,7 @@ TEST_CASE_FIXTURE(Fixture, "UnknownGlobal") TEST_CASE_FIXTURE(Fixture, "DeprecatedGlobal") { // Normally this would be defined externally, so hack it in for testing - addGlobalBinding(typeChecker, "Wait", Binding{typeChecker.anyType, {}, true, "wait", "@test/global/Wait"}); + addGlobalBinding(frontend, "Wait", Binding{typeChecker.anyType, {}, true, "wait", "@test/global/Wait"}); LintResult result = lintTyped("Wait(5)"); @@ -49,7 +49,7 @@ TEST_CASE_FIXTURE(Fixture, "DeprecatedGlobalNoReplacement") // Normally this would be defined externally, so hack it in for testing const char* deprecationReplacementString = ""; - addGlobalBinding(typeChecker, "Version", Binding{typeChecker.anyType, {}, true, deprecationReplacementString}); + addGlobalBinding(frontend, "Version", Binding{typeChecker.anyType, {}, true, deprecationReplacementString}); LintResult result = lintTyped("Version()"); @@ -380,7 +380,7 @@ return bar() TEST_CASE_FIXTURE(Fixture, "ImportUnused") { // Normally this would be defined externally, so hack it in for testing - addGlobalBinding(typeChecker, "game", typeChecker.anyType, "@test"); + addGlobalBinding(frontend, "game", typeChecker.anyType, "@test"); LintResult result = lint(R"( local Roact = require(game.Packages.Roact) @@ -1464,7 +1464,7 @@ TEST_CASE_FIXTURE(Fixture, "DeprecatedApi") getMutable(colorType)->props = {{"toHSV", {typeChecker.anyType, /* deprecated= */ true, "Color3:ToHSV"}}}; - addGlobalBinding(typeChecker, "Color3", Binding{colorType, {}}); + addGlobalBinding(frontend, "Color3", Binding{colorType, {}}); freeze(typeChecker.globalTypes); @@ -1737,8 +1737,6 @@ local _ = 0x0xffffffffffffffffffffffffffffffffff TEST_CASE_FIXTURE(Fixture, "ComparisonPrecedence") { - ScopedFastFlag sff("LuauLintComparisonPrecedence", true); - LintResult result = lint(R"( local a, b = ... diff --git a/tests/Module.test.cpp b/tests/Module.test.cpp index 5ec375c11..58d389947 100644 --- a/tests/Module.test.cpp +++ b/tests/Module.test.cpp @@ -10,6 +10,7 @@ using namespace Luau; +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); LUAU_FASTFLAG(LuauLowerBoundsCalculation); TEST_SUITE_BEGIN("ModuleTests"); @@ -134,7 +135,10 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "builtin_types_point_into_globalTypes_arena") REQUIRE(signType != nullptr); CHECK(!isInArena(signType, module->interfaceTypes)); - CHECK(isInArena(signType, typeChecker.globalTypes)); + if (FFlag::DebugLuauDeferredConstraintResolution) + CHECK(isInArena(signType, frontend.globalTypes)); + else + CHECK(isInArena(signType, typeChecker.globalTypes)); } TEST_CASE_FIXTURE(Fixture, "deepClone_union") @@ -230,7 +234,7 @@ TEST_CASE_FIXTURE(Fixture, "clone_constrained_intersection") { TypeArena src; - TypeId constrained = src.addType(ConstrainedTypeVar{TypeLevel{}, {getSingletonTypes().numberType, getSingletonTypes().stringType}}); + TypeId constrained = src.addType(ConstrainedTypeVar{TypeLevel{}, {singletonTypes->numberType, singletonTypes->stringType}}); TypeArena dest; CloneState cloneState; @@ -240,8 +244,8 @@ TEST_CASE_FIXTURE(Fixture, "clone_constrained_intersection") const ConstrainedTypeVar* ctv = get(cloned); REQUIRE_EQ(2, ctv->parts.size()); - CHECK_EQ(getSingletonTypes().numberType, ctv->parts[0]); - CHECK_EQ(getSingletonTypes().stringType, ctv->parts[1]); + CHECK_EQ(singletonTypes->numberType, ctv->parts[0]); + CHECK_EQ(singletonTypes->stringType, ctv->parts[1]); } TEST_CASE_FIXTURE(BuiltinsFixture, "clone_self_property") diff --git a/tests/Normalize.test.cpp b/tests/Normalize.test.cpp index b017d8ddc..42e9a9336 100644 --- a/tests/Normalize.test.cpp +++ b/tests/Normalize.test.cpp @@ -15,13 +15,13 @@ struct NormalizeFixture : Fixture bool isSubtype(TypeId a, TypeId b) { - return ::Luau::isSubtype(a, b, NotNull{getMainModule()->getModuleScope().get()}, ice); + return ::Luau::isSubtype(a, b, NotNull{getMainModule()->getModuleScope().get()}, singletonTypes, ice); } }; -void createSomeClasses(TypeChecker& typeChecker) +void createSomeClasses(Frontend& frontend) { - auto& arena = typeChecker.globalTypes; + auto& arena = frontend.globalTypes; unfreeze(arena); @@ -32,23 +32,23 @@ void createSomeClasses(TypeChecker& typeChecker) parentClass->props["virtual_method"] = {makeFunction(arena, parentType, {}, {})}; - addGlobalBinding(typeChecker, "Parent", {parentType}); - typeChecker.globalScope->exportedTypeBindings["Parent"] = TypeFun{{}, parentType}; + addGlobalBinding(frontend, "Parent", {parentType}); + frontend.getGlobalScope()->exportedTypeBindings["Parent"] = TypeFun{{}, parentType}; TypeId childType = arena.addType(ClassTypeVar{"Child", {}, parentType, std::nullopt, {}, nullptr, "Test"}); ClassTypeVar* childClass = getMutable(childType); childClass->props["virtual_method"] = {makeFunction(arena, childType, {}, {})}; - addGlobalBinding(typeChecker, "Child", {childType}); - typeChecker.globalScope->exportedTypeBindings["Child"] = TypeFun{{}, childType}; + addGlobalBinding(frontend, "Child", {childType}); + frontend.getGlobalScope()->exportedTypeBindings["Child"] = TypeFun{{}, childType}; TypeId unrelatedType = arena.addType(ClassTypeVar{"Unrelated", {}, std::nullopt, std::nullopt, {}, nullptr, "Test"}); - addGlobalBinding(typeChecker, "Unrelated", {unrelatedType}); - typeChecker.globalScope->exportedTypeBindings["Unrelated"] = TypeFun{{}, unrelatedType}; + addGlobalBinding(frontend, "Unrelated", {unrelatedType}); + frontend.getGlobalScope()->exportedTypeBindings["Unrelated"] = TypeFun{{}, unrelatedType}; - for (const auto& [name, ty] : typeChecker.globalScope->exportedTypeBindings) + for (const auto& [name, ty] : frontend.getGlobalScope()->exportedTypeBindings) persist(ty.type); freeze(arena); @@ -508,7 +508,7 @@ TEST_CASE_FIXTURE(NormalizeFixture, "cyclic_table") TEST_CASE_FIXTURE(NormalizeFixture, "classes") { - createSomeClasses(typeChecker); + createSomeClasses(frontend); check(""); // Ensure that we have a main Module. @@ -596,7 +596,7 @@ TEST_CASE_FIXTURE(NormalizeFixture, "union_with_overlapping_field_that_has_a_sub )"); ModulePtr tempModule{new Module}; - tempModule->scopes.emplace_back(Location(), std::make_shared(getSingletonTypes().anyTypePack)); + tempModule->scopes.emplace_back(Location(), std::make_shared(singletonTypes->anyTypePack)); // HACK: Normalization is an in-place operation. We need to cheat a little here and unfreeze // the arena that the type lives in. @@ -604,7 +604,7 @@ TEST_CASE_FIXTURE(NormalizeFixture, "union_with_overlapping_field_that_has_a_sub unfreeze(mainModule->internalTypes); TypeId tType = requireType("t"); - normalize(tType, tempModule, *typeChecker.iceHandler); + normalize(tType, tempModule, singletonTypes, *typeChecker.iceHandler); CHECK_EQ("{| x: number? |}", toString(tType, {true})); } @@ -1085,7 +1085,7 @@ TEST_CASE_FIXTURE(Fixture, "bound_typevars_should_only_be_marked_normal_if_their TEST_CASE_FIXTURE(BuiltinsFixture, "skip_force_normal_on_external_types") { - createSomeClasses(typeChecker); + createSomeClasses(frontend); CheckResult result = check(R"( export type t0 = { a: Child } diff --git a/tests/NotNull.test.cpp b/tests/NotNull.test.cpp index e77ba78ac..dfa06aa1b 100644 --- a/tests/NotNull.test.cpp +++ b/tests/NotNull.test.cpp @@ -1,3 +1,4 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/NotNull.h" #include "doctest.h" diff --git a/tests/TypeInfer.aliases.test.cpp b/tests/TypeInfer.aliases.test.cpp index e487fd48f..dd91467d5 100644 --- a/tests/TypeInfer.aliases.test.cpp +++ b/tests/TypeInfer.aliases.test.cpp @@ -76,8 +76,8 @@ TEST_CASE_FIXTURE(Fixture, "cannot_steal_hoisted_type_alias") Location{{1, 21}, {1, 26}}, getMainSourceModule()->name, TypeMismatch{ - getSingletonTypes().numberType, - getSingletonTypes().stringType, + singletonTypes->numberType, + singletonTypes->stringType, }, }); } @@ -87,8 +87,8 @@ TEST_CASE_FIXTURE(Fixture, "cannot_steal_hoisted_type_alias") Location{{1, 8}, {1, 26}}, getMainSourceModule()->name, TypeMismatch{ - getSingletonTypes().numberType, - getSingletonTypes().stringType, + singletonTypes->numberType, + singletonTypes->stringType, }, }); } @@ -501,7 +501,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "type_alias_import_mutation") CheckResult result = check("type t10 = typeof(table)"); LUAU_REQUIRE_NO_ERRORS(result); - TypeId ty = getGlobalBinding(frontend.typeChecker, "table"); + TypeId ty = getGlobalBinding(frontend, "table"); CHECK_EQ(toString(ty), "table"); const TableTypeVar* ttv = get(ty); diff --git a/tests/TypeInfer.annotations.test.cpp b/tests/TypeInfer.annotations.test.cpp index 8a86ee5fd..5d18b335d 100644 --- a/tests/TypeInfer.annotations.test.cpp +++ b/tests/TypeInfer.annotations.test.cpp @@ -557,7 +557,7 @@ TEST_CASE_FIXTURE(Fixture, "cloned_interface_maintains_pointers_between_definiti TEST_CASE_FIXTURE(BuiltinsFixture, "use_type_required_from_another_file") { - addGlobalBinding(frontend.typeChecker, "script", frontend.typeChecker.anyType, "@test"); + addGlobalBinding(frontend, "script", frontend.typeChecker.anyType, "@test"); fileResolver.source["Modules/Main"] = R"( --!strict @@ -583,7 +583,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "use_type_required_from_another_file") TEST_CASE_FIXTURE(BuiltinsFixture, "cannot_use_nonexported_type") { - addGlobalBinding(frontend.typeChecker, "script", frontend.typeChecker.anyType, "@test"); + addGlobalBinding(frontend, "script", frontend.typeChecker.anyType, "@test"); fileResolver.source["Modules/Main"] = R"( --!strict @@ -609,7 +609,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "cannot_use_nonexported_type") TEST_CASE_FIXTURE(BuiltinsFixture, "builtin_types_are_not_exported") { - addGlobalBinding(frontend.typeChecker, "script", frontend.typeChecker.anyType, "@test"); + addGlobalBinding(frontend, "script", frontend.typeChecker.anyType, "@test"); fileResolver.source["Modules/Main"] = R"( --!strict diff --git a/tests/TypeInfer.classes.test.cpp b/tests/TypeInfer.classes.test.cpp index 6f4191e3f..98883dfa7 100644 --- a/tests/TypeInfer.classes.test.cpp +++ b/tests/TypeInfer.classes.test.cpp @@ -32,7 +32,7 @@ struct ClassFixture : BuiltinsFixture {"New", {makeFunction(arena, nullopt, {}, {baseClassInstanceType})}}, }; typeChecker.globalScope->exportedTypeBindings["BaseClass"] = TypeFun{{}, baseClassInstanceType}; - addGlobalBinding(typeChecker, "BaseClass", baseClassType, "@test"); + addGlobalBinding(frontend, "BaseClass", baseClassType, "@test"); TypeId childClassInstanceType = arena.addType(ClassTypeVar{"ChildClass", {}, baseClassInstanceType, nullopt, {}, {}, "Test"}); @@ -45,7 +45,7 @@ struct ClassFixture : BuiltinsFixture {"New", {makeFunction(arena, nullopt, {}, {childClassInstanceType})}}, }; typeChecker.globalScope->exportedTypeBindings["ChildClass"] = TypeFun{{}, childClassInstanceType}; - addGlobalBinding(typeChecker, "ChildClass", childClassType, "@test"); + addGlobalBinding(frontend, "ChildClass", childClassType, "@test"); TypeId grandChildInstanceType = arena.addType(ClassTypeVar{"GrandChild", {}, childClassInstanceType, nullopt, {}, {}, "Test"}); @@ -58,7 +58,7 @@ struct ClassFixture : BuiltinsFixture {"New", {makeFunction(arena, nullopt, {}, {grandChildInstanceType})}}, }; typeChecker.globalScope->exportedTypeBindings["GrandChild"] = TypeFun{{}, grandChildInstanceType}; - addGlobalBinding(typeChecker, "GrandChild", childClassType, "@test"); + addGlobalBinding(frontend, "GrandChild", childClassType, "@test"); TypeId anotherChildInstanceType = arena.addType(ClassTypeVar{"AnotherChild", {}, baseClassInstanceType, nullopt, {}, {}, "Test"}); @@ -71,7 +71,7 @@ struct ClassFixture : BuiltinsFixture {"New", {makeFunction(arena, nullopt, {}, {anotherChildInstanceType})}}, }; typeChecker.globalScope->exportedTypeBindings["AnotherChild"] = TypeFun{{}, anotherChildInstanceType}; - addGlobalBinding(typeChecker, "AnotherChild", childClassType, "@test"); + addGlobalBinding(frontend, "AnotherChild", childClassType, "@test"); TypeId vector2MetaType = arena.addType(TableTypeVar{}); @@ -89,7 +89,7 @@ struct ClassFixture : BuiltinsFixture {"__add", {makeFunction(arena, nullopt, {vector2InstanceType, vector2InstanceType}, {vector2InstanceType})}}, }; typeChecker.globalScope->exportedTypeBindings["Vector2"] = TypeFun{{}, vector2InstanceType}; - addGlobalBinding(typeChecker, "Vector2", vector2Type, "@test"); + addGlobalBinding(frontend, "Vector2", vector2Type, "@test"); for (const auto& [name, tf] : typeChecker.globalScope->exportedTypeBindings) persist(tf.type); diff --git a/tests/TypeInfer.definitions.test.cpp b/tests/TypeInfer.definitions.test.cpp index 9fe0c6aaf..26280c134 100644 --- a/tests/TypeInfer.definitions.test.cpp +++ b/tests/TypeInfer.definitions.test.cpp @@ -19,13 +19,13 @@ TEST_CASE_FIXTURE(Fixture, "definition_file_simple") declare foo2: typeof(foo) )"); - TypeId globalFooTy = getGlobalBinding(frontend.typeChecker, "foo"); + TypeId globalFooTy = getGlobalBinding(frontend, "foo"); CHECK_EQ(toString(globalFooTy), "number"); - TypeId globalBarTy = getGlobalBinding(frontend.typeChecker, "bar"); + TypeId globalBarTy = getGlobalBinding(frontend, "bar"); CHECK_EQ(toString(globalBarTy), "(number) -> string"); - TypeId globalFoo2Ty = getGlobalBinding(frontend.typeChecker, "foo2"); + TypeId globalFoo2Ty = getGlobalBinding(frontend, "foo2"); CHECK_EQ(toString(globalFoo2Ty), "number"); CheckResult result = check(R"( @@ -48,20 +48,20 @@ TEST_CASE_FIXTURE(Fixture, "definition_file_loading") declare function var(...: any): string )"); - TypeId globalFooTy = getGlobalBinding(frontend.typeChecker, "foo"); + TypeId globalFooTy = getGlobalBinding(frontend, "foo"); CHECK_EQ(toString(globalFooTy), "number"); - std::optional globalAsdfTy = frontend.typeChecker.globalScope->lookupType("Asdf"); + std::optional globalAsdfTy = frontend.getGlobalScope()->lookupType("Asdf"); REQUIRE(bool(globalAsdfTy)); CHECK_EQ(toString(globalAsdfTy->type), "number | string"); - TypeId globalBarTy = getGlobalBinding(frontend.typeChecker, "bar"); + TypeId globalBarTy = getGlobalBinding(frontend, "bar"); CHECK_EQ(toString(globalBarTy), "(number) -> string"); - TypeId globalFoo2Ty = getGlobalBinding(frontend.typeChecker, "foo2"); + TypeId globalFoo2Ty = getGlobalBinding(frontend, "foo2"); CHECK_EQ(toString(globalFoo2Ty), "number"); - TypeId globalVarTy = getGlobalBinding(frontend.typeChecker, "var"); + TypeId globalVarTy = getGlobalBinding(frontend, "var"); CHECK_EQ(toString(globalVarTy), "(...any) -> string"); @@ -85,7 +85,7 @@ TEST_CASE_FIXTURE(Fixture, "load_definition_file_errors_do_not_pollute_global_sc freeze(typeChecker.globalTypes); REQUIRE(!parseFailResult.success); - std::optional fooTy = tryGetGlobalBinding(typeChecker, "foo"); + std::optional fooTy = tryGetGlobalBinding(frontend, "foo"); CHECK(!fooTy.has_value()); LoadDefinitionFileResult checkFailResult = loadDefinitionFile(typeChecker, typeChecker.globalScope, R"( @@ -95,7 +95,7 @@ TEST_CASE_FIXTURE(Fixture, "load_definition_file_errors_do_not_pollute_global_sc "@test"); REQUIRE(!checkFailResult.success); - std::optional barTy = tryGetGlobalBinding(typeChecker, "bar"); + std::optional barTy = tryGetGlobalBinding(frontend, "bar"); CHECK(!barTy.has_value()); } diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index 3a6f44911..35e67ec55 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -127,6 +127,8 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "vararg_function_is_quantified") return T )"); + LUAU_REQUIRE_NO_ERRORS(result); + auto r = first(getMainModule()->getModuleScope()->returnType); REQUIRE(r); @@ -136,8 +138,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "vararg_function_is_quantified") REQUIRE(ttv->props.count("f")); TypeId k = ttv->props["f"].type; REQUIRE(k); - - LUAU_REQUIRE_NO_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "list_only_alternative_overloads_that_match_argument_count") diff --git a/tests/TypeInfer.primitives.test.cpp b/tests/TypeInfer.primitives.test.cpp index 5a6fb0e49..a31c9c503 100644 --- a/tests/TypeInfer.primitives.test.cpp +++ b/tests/TypeInfer.primitives.test.cpp @@ -102,4 +102,18 @@ end CHECK_EQ(toString(result.errors[1]), "Type 'number' could not be converted into 'string'"); } +TEST_CASE("singleton_types") +{ + BuiltinsFixture a; + + { + BuiltinsFixture b; + } + + // Check that Frontend 'a' environment wasn't modified by 'b' + CheckResult result = a.check("local s: string = 'hello' local t = s:lower()"); + + CHECK(result.errors.empty()); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index 472d0ed55..3482b75cf 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -353,6 +353,9 @@ TEST_CASE_FIXTURE(Fixture, "weird_fail_to_unify_type_pack") { ScopedFastFlag sff[] = { {"LuauLowerBoundsCalculation", false}, + // I'm not sure why this is broken without DCR, but it seems to be fixed + // when DCR is enabled. + {"DebugLuauDeferredConstraintResolution", false}, }; CheckResult result = check(R"( @@ -367,6 +370,9 @@ TEST_CASE_FIXTURE(Fixture, "weird_fail_to_unify_variadic_pack") { ScopedFastFlag sff[] = { {"LuauLowerBoundsCalculation", false}, + // I'm not sure why this is broken without DCR, but it seems to be fixed + // when DCR is enabled. + {"DebugLuauDeferredConstraintResolution", false}, }; CheckResult result = check(R"( @@ -588,9 +594,9 @@ TEST_CASE_FIXTURE(Fixture, "free_options_cannot_be_unified_together") }; TypeArena arena; - TypeId nilType = getSingletonTypes().nilType; + TypeId nilType = singletonTypes->nilType; - std::unique_ptr scope = std::make_unique(getSingletonTypes().anyTypePack); + std::unique_ptr scope = std::make_unique(singletonTypes->anyTypePack); TypeId free1 = arena.addType(FreeTypePack{scope.get()}); TypeId option1 = arena.addType(UnionTypeVar{{nilType, free1}}); @@ -600,7 +606,7 @@ TEST_CASE_FIXTURE(Fixture, "free_options_cannot_be_unified_together") InternalErrorReporter iceHandler; UnifierSharedState sharedState{&iceHandler}; - Unifier u{&arena, Mode::Strict, NotNull{scope.get()}, Location{}, Variance::Covariant, sharedState}; + Unifier u{&arena, singletonTypes, Mode::Strict, NotNull{scope.get()}, Location{}, Variance::Covariant, sharedState}; u.tryUnify(option1, option2); diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index e61e6e45d..d2bff9c3c 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -50,7 +50,7 @@ struct RefinementClassFixture : Fixture {"Y", Property{typeChecker.numberType}}, {"Z", Property{typeChecker.numberType}}, }; - normalize(vec3, scope, arena, *typeChecker.iceHandler); + normalize(vec3, scope, arena, singletonTypes, *typeChecker.iceHandler); TypeId inst = arena.addType(ClassTypeVar{"Instance", {}, std::nullopt, std::nullopt, {}, nullptr, "Test"}); @@ -58,21 +58,21 @@ struct RefinementClassFixture : Fixture TypePackId isARets = arena.addTypePack({typeChecker.booleanType}); TypeId isA = arena.addType(FunctionTypeVar{isAParams, isARets}); getMutable(isA)->magicFunction = magicFunctionInstanceIsA; - normalize(isA, scope, arena, *typeChecker.iceHandler); + normalize(isA, scope, arena, singletonTypes, *typeChecker.iceHandler); getMutable(inst)->props = { {"Name", Property{typeChecker.stringType}}, {"IsA", Property{isA}}, }; - normalize(inst, scope, arena, *typeChecker.iceHandler); + normalize(inst, scope, arena, singletonTypes, *typeChecker.iceHandler); TypeId folder = typeChecker.globalTypes.addType(ClassTypeVar{"Folder", {}, inst, std::nullopt, {}, nullptr, "Test"}); - normalize(folder, scope, arena, *typeChecker.iceHandler); + normalize(folder, scope, arena, singletonTypes, *typeChecker.iceHandler); TypeId part = typeChecker.globalTypes.addType(ClassTypeVar{"Part", {}, inst, std::nullopt, {}, nullptr, "Test"}); getMutable(part)->props = { {"Position", Property{vec3}}, }; - normalize(part, scope, arena, *typeChecker.iceHandler); + normalize(part, scope, arena, singletonTypes, *typeChecker.iceHandler); typeChecker.globalScope->exportedTypeBindings["Vector3"] = TypeFun{{}, vec3}; typeChecker.globalScope->exportedTypeBindings["Instance"] = TypeFun{{}, inst}; diff --git a/tests/TypeInfer.tryUnify.test.cpp b/tests/TypeInfer.tryUnify.test.cpp index 02fdfd733..7fa0fac0f 100644 --- a/tests/TypeInfer.tryUnify.test.cpp +++ b/tests/TypeInfer.tryUnify.test.cpp @@ -18,7 +18,7 @@ struct TryUnifyFixture : Fixture InternalErrorReporter iceHandler; UnifierSharedState unifierState{&iceHandler}; - Unifier state{&arena, Mode::Strict, NotNull{globalScope.get()}, Location{}, Variance::Covariant, unifierState}; + Unifier state{&arena, singletonTypes, Mode::Strict, NotNull{globalScope.get()}, Location{}, Variance::Covariant, unifierState}; }; TEST_SUITE_BEGIN("TryUnifyTests"); diff --git a/tests/TypeInfer.typePacks.cpp b/tests/TypeInfer.typePacks.cpp index 7aefa00d8..eaa8b0539 100644 --- a/tests/TypeInfer.typePacks.cpp +++ b/tests/TypeInfer.typePacks.cpp @@ -194,7 +194,7 @@ TEST_CASE_FIXTURE(Fixture, "variadic_packs") TypePackId listOfStrings = arena.addTypePack(TypePackVar{VariadicTypePack{typeChecker.stringType}}); // clang-format off - addGlobalBinding(typeChecker, "foo", + addGlobalBinding(frontend, "foo", arena.addType( FunctionTypeVar{ listOfNumbers, @@ -203,7 +203,7 @@ TEST_CASE_FIXTURE(Fixture, "variadic_packs") ), "@test" ); - addGlobalBinding(typeChecker, "bar", + addGlobalBinding(frontend, "bar", arena.addType( FunctionTypeVar{ arena.addTypePack({{typeChecker.numberType}, listOfStrings}), diff --git a/tests/TypeVar.test.cpp b/tests/TypeVar.test.cpp index 32be82157..b81c80ce4 100644 --- a/tests/TypeVar.test.cpp +++ b/tests/TypeVar.test.cpp @@ -273,7 +273,7 @@ TEST_CASE_FIXTURE(Fixture, "substitution_skip_failure") TypeId root = &ttvTweenResult; typeChecker.currentModule = std::make_shared(); - typeChecker.currentModule->scopes.emplace_back(Location{}, std::make_shared(getSingletonTypes().anyTypePack)); + typeChecker.currentModule->scopes.emplace_back(Location{}, std::make_shared(singletonTypes->anyTypePack)); TypeId result = typeChecker.anyify(typeChecker.globalScope, root, Location{}); diff --git a/tests/conformance/types.lua b/tests/conformance/types.lua index cdddceefd..3539b34d5 100644 --- a/tests/conformance/types.lua +++ b/tests/conformance/types.lua @@ -10,9 +10,6 @@ local ignore = -- what follows is a set of mismatches that hopefully eventually will go down to 0 "_G.require", -- need to move to Roblox type defs - "_G.utf8.nfcnormalize", -- need to move to Roblox type defs - "_G.utf8.nfdnormalize", -- need to move to Roblox type defs - "_G.utf8.graphemes", -- need to move to Roblox type defs } function verify(real, rtti, path) diff --git a/tools/faillist.txt b/tools/faillist.txt index 3de9db769..ef995aa62 100644 --- a/tools/faillist.txt +++ b/tools/faillist.txt @@ -163,6 +163,7 @@ BuiltinTests.table_pack_reduce BuiltinTests.table_pack_variadic BuiltinTests.tonumber_returns_optional_number_type BuiltinTests.tonumber_returns_optional_number_type2 +DefinitionTests.class_definition_overload_metamethods DefinitionTests.declaring_generic_functions DefinitionTests.definition_file_classes FrontendTest.ast_node_at_position @@ -199,7 +200,6 @@ GenericsTests.generic_functions_in_types GenericsTests.generic_functions_should_be_memory_safe GenericsTests.generic_table_method GenericsTests.generic_type_pack_parentheses -GenericsTests.generic_type_pack_syntax GenericsTests.generic_type_pack_unification1 GenericsTests.generic_type_pack_unification2 GenericsTests.generic_type_pack_unification3 @@ -300,10 +300,8 @@ ProvisionalTests.operator_eq_completely_incompatible ProvisionalTests.pcall_returns_at_least_two_value_but_function_returns_nothing ProvisionalTests.setmetatable_constrains_free_type_into_free_table ProvisionalTests.typeguard_inference_incomplete -ProvisionalTests.weird_fail_to_unify_type_pack ProvisionalTests.weirditer_should_not_loop_forever ProvisionalTests.while_body_are_also_refined -ProvisionalTests.xpcall_returns_what_f_returns RefinementTest.and_constraint RefinementTest.and_or_peephole_refinement RefinementTest.apply_refinements_on_astexprindexexpr_whose_subscript_expr_is_constant_string @@ -494,13 +492,10 @@ ToString.function_type_with_argument_names_generic ToString.no_parentheses_around_cyclic_function_type_in_union ToString.toStringDetailed2 ToString.toStringErrorPack -ToString.toStringNamedFunction_generic_pack ToString.toStringNamedFunction_hide_type_params ToString.toStringNamedFunction_id ToString.toStringNamedFunction_map -ToString.toStringNamedFunction_overrides_param_names ToString.toStringNamedFunction_variadics -TranspilerTests.type_lists_should_be_emitted_correctly TranspilerTests.types_should_not_be_considered_cyclic_if_they_are_not_recursive TryUnifyTests.cli_41095_concat_log_in_sealed_table_unification TryUnifyTests.members_of_failed_typepack_unification_are_unified_with_errorType @@ -616,7 +611,6 @@ TypeInferFunctions.too_few_arguments_variadic_generic2 TypeInferFunctions.too_many_arguments TypeInferFunctions.too_many_return_values TypeInferFunctions.vararg_function_is_quantified -TypeInferFunctions.vararg_functions_should_allow_calls_of_any_types_and_size TypeInferLoops.for_in_loop_error_on_factory_not_returning_the_right_amount_of_values TypeInferLoops.for_in_loop_with_custom_iterator TypeInferLoops.for_in_loop_with_next @@ -641,7 +635,6 @@ TypeInferOOP.dont_suggest_using_colon_rather_than_dot_if_not_defined_with_colon TypeInferOOP.inferred_methods_of_free_tables_have_the_same_level_as_the_enclosing_table TypeInferOOP.inferring_hundreds_of_self_calls_should_not_suffocate_memory TypeInferOOP.methods_are_topologically_sorted -TypeInferOOP.nonstrict_self_mismatch_tail TypeInferOperators.and_adds_boolean TypeInferOperators.and_adds_boolean_no_superfluous_union TypeInferOperators.and_binexps_dont_unify @@ -689,6 +682,7 @@ TypeInferOperators.unary_not_is_boolean TypeInferOperators.unknown_type_in_comparison TypeInferOperators.UnknownGlobalCompoundAssign TypeInferPrimitives.CheckMethodsOfNumber +TypeInferPrimitives.singleton_types TypeInferPrimitives.string_function_other TypeInferPrimitives.string_index TypeInferPrimitives.string_length @@ -730,7 +724,6 @@ TypePackTests.type_alias_type_packs_nested TypePackTests.type_pack_hidden_free_tail_infinite_growth TypePackTests.type_pack_type_parameters TypePackTests.varargs_inference_through_multiple_scopes -TypePackTests.variadic_pack_syntax TypePackTests.variadic_packs TypeSingletons.bool_singleton_subtype TypeSingletons.bool_singletons From 6c708975d39a7a3d3172d1e0b734d8f78a894f0d Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 8 Sep 2022 14:58:20 -0700 Subject: [PATCH 03/66] Patch the test for now to work with 16K pages This exposes a bug in CodeAllocator that needs to be fixed properly --- tests/CodeAllocator.test.cpp | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/CodeAllocator.test.cpp b/tests/CodeAllocator.test.cpp index 005bc9598..41f3a9001 100644 --- a/tests/CodeAllocator.test.cpp +++ b/tests/CodeAllocator.test.cpp @@ -41,8 +41,8 @@ TEST_CASE("CodeAllocation") TEST_CASE("CodeAllocationFailure") { - size_t blockSize = 4096; - size_t maxTotalSize = 8192; + size_t blockSize = 16384; + size_t maxTotalSize = 32768; CodeAllocator allocator(blockSize, maxTotalSize); uint8_t* nativeData; @@ -50,11 +50,13 @@ TEST_CASE("CodeAllocationFailure") uint8_t* nativeEntry; std::vector code; - code.resize(6000); + code.resize(18000); + // allocation has to fit in a block REQUIRE(!allocator.allocate(nullptr, 0, code.data(), code.size(), nativeData, sizeNativeData, nativeEntry)); - code.resize(3000); + // each allocation exhausts a block, so third allocation fails + code.resize(10000); REQUIRE(allocator.allocate(nullptr, 0, code.data(), code.size(), nativeData, sizeNativeData, nativeEntry)); REQUIRE(allocator.allocate(nullptr, 0, code.data(), code.size(), nativeData, sizeNativeData, nativeEntry)); REQUIRE(!allocator.allocate(nullptr, 0, code.data(), code.size(), nativeData, sizeNativeData, nativeEntry)); From dd710f67caa10c12fd60835254eb9e5bae29c2bf Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 15 Sep 2022 15:13:58 -0700 Subject: [PATCH 04/66] Sync to upstream/release/545 --- Analysis/include/Luau/ConstraintSolver.h | 3 + Analysis/include/Luau/Error.h | 2 + Analysis/include/Luau/ToString.h | 2 + Analysis/include/Luau/TypeInfer.h | 4 +- Analysis/include/Luau/TypePack.h | 3 + Analysis/include/Luau/TypeUtils.h | 2 +- Analysis/src/ConstraintGraphBuilder.cpp | 36 ++- Analysis/src/ConstraintSolver.cpp | 77 +++++-- Analysis/src/Error.cpp | 21 +- Analysis/src/ToString.cpp | 16 ++ Analysis/src/TypeChecker2.cpp | 20 +- Analysis/src/TypeInfer.cpp | 69 ++++-- Analysis/src/TypePack.cpp | 9 +- Analysis/src/TypeUtils.cpp | 6 +- Analysis/src/TypeVar.cpp | 4 +- Analysis/src/Unifier.cpp | 14 +- CodeGen/include/Luau/CodeAllocator.h | 2 +- CodeGen/include/Luau/CodeBlockUnwind.h | 17 ++ CodeGen/include/Luau/UnwindBuilder.h | 35 +++ CodeGen/include/Luau/UnwindBuilderDwarf2.h | 40 ++++ CodeGen/include/Luau/UnwindBuilderWin.h | 51 +++++ CodeGen/src/CodeAllocator.cpp | 25 +- CodeGen/src/CodeBlockUnwind.cpp | 123 ++++++++++ CodeGen/src/UnwindBuilderDwarf2.cpp | 253 +++++++++++++++++++++ CodeGen/src/UnwindBuilderWin.cpp | 120 ++++++++++ Common/include/Luau/Bytecode.h | 2 +- Sources.cmake | 11 + VM/include/lua.h | 1 + VM/src/lapi.cpp | 17 +- VM/src/ldo.cpp | 16 +- VM/src/lfunc.cpp | 82 +------ VM/src/lfunc.h | 1 - VM/src/lgc.cpp | 215 +++++++---------- VM/src/lgc.h | 31 +-- VM/src/lobject.h | 2 - VM/src/lstate.cpp | 11 +- VM/src/lstate.h | 4 +- VM/src/lstring.cpp | 59 +---- VM/src/lvmexecute.cpp | 240 +------------------ tests/AstQuery.test.cpp | 3 +- tests/AstQueryDsl.cpp | 45 ++++ tests/AstQueryDsl.h | 83 +++++++ tests/CodeAllocator.test.cpp | 182 ++++++++++++++- tests/Conformance.test.cpp | 6 +- tests/ConstraintGraphBuilder.test.cpp | 5 +- tests/ConstraintGraphBuilderFixture.cpp | 17 ++ tests/ConstraintGraphBuilderFixture.h | 27 +++ tests/ConstraintSolver.test.cpp | 8 +- tests/Fixture.cpp | 47 ---- tests/Fixture.h | 83 ------- tests/TypeInfer.builtins.test.cpp | 24 +- tests/TypeInfer.functions.test.cpp | 52 +++++ tests/TypeInfer.generics.test.cpp | 8 +- tests/TypeInfer.provisional.test.cpp | 2 + tests/TypeInfer.singletons.test.cpp | 9 + tests/TypeInfer.tables.test.cpp | 8 +- tests/TypeInfer.test.cpp | 17 ++ tests/TypeInfer.typePacks.cpp | 36 +++ tools/faillist.txt | 7 +- 59 files changed, 1517 insertions(+), 798 deletions(-) create mode 100644 CodeGen/include/Luau/CodeBlockUnwind.h create mode 100644 CodeGen/include/Luau/UnwindBuilder.h create mode 100644 CodeGen/include/Luau/UnwindBuilderDwarf2.h create mode 100644 CodeGen/include/Luau/UnwindBuilderWin.h create mode 100644 CodeGen/src/CodeBlockUnwind.cpp create mode 100644 CodeGen/src/UnwindBuilderDwarf2.cpp create mode 100644 CodeGen/src/UnwindBuilderWin.cpp create mode 100644 tests/AstQueryDsl.cpp create mode 100644 tests/AstQueryDsl.h create mode 100644 tests/ConstraintGraphBuilderFixture.cpp create mode 100644 tests/ConstraintGraphBuilderFixture.h diff --git a/Analysis/include/Luau/ConstraintSolver.h b/Analysis/include/Luau/ConstraintSolver.h index 059e97cb3..fe6a025b2 100644 --- a/Analysis/include/Luau/ConstraintSolver.h +++ b/Analysis/include/Luau/ConstraintSolver.h @@ -191,6 +191,9 @@ struct ConstraintSolver **/ void unblock_(BlockedConstraintId progressed); + TypeId errorRecoveryType() const; + TypePackId errorRecoveryTypePack() const; + ToStringOptions opts; }; diff --git a/Analysis/include/Luau/Error.h b/Analysis/include/Luau/Error.h index 4c81d33d2..eab6a21d3 100644 --- a/Analysis/include/Luau/Error.h +++ b/Analysis/include/Luau/Error.h @@ -95,9 +95,11 @@ struct CountMismatch Return, }; size_t expected; + std::optional maximum; size_t actual; Context context = Arg; bool isVariadic = false; + std::string function; bool operator==(const CountMismatch& rhs) const; }; diff --git a/Analysis/include/Luau/ToString.h b/Analysis/include/Luau/ToString.h index 61e07e9fa..dd2d709bc 100644 --- a/Analysis/include/Luau/ToString.h +++ b/Analysis/include/Luau/ToString.h @@ -92,6 +92,8 @@ inline std::string toString(const Constraint& c) return toString(c, ToStringOptions{}); } +std::string toString(const LValue& lvalue); + std::string toString(const TypeVar& tv, ToStringOptions& opts); std::string toString(const TypePackVar& tp, ToStringOptions& opts); diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index b0b3f3ac5..bbb9bd6d1 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -127,8 +127,8 @@ struct TypeChecker std::optional originalNameLoc, std::optional selfType, std::optional expectedType); void checkFunctionBody(const ScopePtr& scope, TypeId type, const AstExprFunction& function); - void checkArgumentList( - const ScopePtr& scope, Unifier& state, TypePackId paramPack, TypePackId argPack, const std::vector& argLocations); + void checkArgumentList(const ScopePtr& scope, const AstExpr& funName, Unifier& state, TypePackId paramPack, TypePackId argPack, + const std::vector& argLocations); WithPredicate checkExprPack(const ScopePtr& scope, const AstExpr& expr); diff --git a/Analysis/include/Luau/TypePack.h b/Analysis/include/Luau/TypePack.h index 8269230b4..296880942 100644 --- a/Analysis/include/Luau/TypePack.h +++ b/Analysis/include/Luau/TypePack.h @@ -185,6 +185,9 @@ std::pair, std::optional> flatten(TypePackId tp, bool isVariadic(TypePackId tp); bool isVariadic(TypePackId tp, const TxnLog& log); +// Returns true if the TypePack is Generic or Variadic. Does not walk TypePacks!! +bool isVariadicTail(TypePackId tp, const TxnLog& log, bool includeHiddenVariadics = false); + bool containsNever(TypePackId tp); } // namespace Luau diff --git a/Analysis/include/Luau/TypeUtils.h b/Analysis/include/Luau/TypeUtils.h index 6890f881a..efc1d8814 100644 --- a/Analysis/include/Luau/TypeUtils.h +++ b/Analysis/include/Luau/TypeUtils.h @@ -23,6 +23,6 @@ std::optional getIndexTypeFromType(const ScopePtr& scope, ErrorVec& erro TypeId type, const std::string& prop, const Location& location, bool addErrors, InternalErrorReporter& handle); // Returns the minimum and maximum number of types the argument list can accept. -std::pair> getParameterExtents(const TxnLog* log, TypePackId tp); +std::pair> getParameterExtents(const TxnLog* log, TypePackId tp, bool includeHiddenVariadics = false); } // namespace Luau diff --git a/Analysis/src/ConstraintGraphBuilder.cpp b/Analysis/src/ConstraintGraphBuilder.cpp index e9c61e412..6a65ab925 100644 --- a/Analysis/src/ConstraintGraphBuilder.cpp +++ b/Analysis/src/ConstraintGraphBuilder.cpp @@ -729,6 +729,11 @@ TypePackId ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExpr* exp if (AstExprCall* call = expr->as()) { + TypeId fnType = check(scope, call->func); + + const size_t constraintIndex = scope->constraints.size(); + const size_t scopeIndex = scopes.size(); + std::vector args; for (AstExpr* arg : call->args) @@ -738,7 +743,8 @@ TypePackId ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExpr* exp // TODO self - TypeId fnType = check(scope, call->func); + const size_t constraintEndIndex = scope->constraints.size(); + const size_t scopeEndIndex = scopes.size(); astOriginalCallTypes[call->func] = fnType; @@ -753,7 +759,23 @@ TypePackId ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExpr* exp scope->unqueuedConstraints.push_back( std::make_unique(NotNull{scope.get()}, call->func->location, SubtypeConstraint{inferredFnType, instantiatedType})); - NotNull sc(scope->unqueuedConstraints.back().get()); + NotNull sc(scope->unqueuedConstraints.back().get()); + + // We force constraints produced by checking function arguments to wait + // until after we have resolved the constraint on the function itself. + // This ensures, for instance, that we start inferring the contents of + // lambdas under the assumption that their arguments and return types + // will be compatible with the enclosing function call. + for (size_t ci = constraintIndex; ci < constraintEndIndex; ++ci) + scope->constraints[ci]->dependencies.push_back(sc); + + for (size_t si = scopeIndex; si < scopeEndIndex; ++si) + { + for (auto& c : scopes[si].second->constraints) + { + c->dependencies.push_back(sc); + } + } addConstraint(scope, call->func->location, FunctionCallConstraint{ @@ -1080,7 +1102,7 @@ ConstraintGraphBuilder::FunctionSignature ConstraintGraphBuilder::checkFunctionS signatureScope = bodyScope; } - std::optional varargPack; + TypePackId varargPack = nullptr; if (fn->vararg) { @@ -1096,6 +1118,14 @@ ConstraintGraphBuilder::FunctionSignature ConstraintGraphBuilder::checkFunctionS signatureScope->varargPack = varargPack; } + else + { + varargPack = arena->addTypePack(VariadicTypePack{singletonTypes->anyType, /*hidden*/ true}); + // We do not add to signatureScope->varargPack because ... is not valid + // in functions without an explicit ellipsis. + } + + LUAU_ASSERT(nullptr != varargPack); if (fn->returnAnnotation) { diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index f964a855a..6fb57b15e 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -240,6 +240,12 @@ void dump(ConstraintSolver* cs, ToStringOptions& opts) int blockCount = it == cs->blockedConstraints.end() ? 0 : int(it->second); printf("\t%d\t\t%s\n", blockCount, toString(*dep, opts).c_str()); } + + if (auto fcc = get(*c)) + { + for (NotNull inner : fcc->innerConstraints) + printf("\t\t\t%s\n", toString(*inner, opts).c_str()); + } } } @@ -522,7 +528,7 @@ bool ConstraintSolver::tryDispatch(const BinaryConstraint& c, NotNullty.emplace(singletonTypes->errorRecoveryType()); + asMutable(resultType)->ty.emplace(errorRecoveryType()); // reportError(constraint->location, CannotInferBinaryOperation{c.op, std::nullopt, CannotInferBinaryOperation::Operation}); return true; } @@ -574,10 +580,24 @@ bool ConstraintSolver::tryDispatch(const IterableConstraint& c, NotNullscope, singletonTypes, &iceReporter, singletonTypes->errorRecoveryType(), singletonTypes->errorRecoveryTypePack()}; + Anyification anyify{arena, constraint->scope, singletonTypes, &iceReporter, errorRecoveryType(), errorRecoveryTypePack()}; std::optional anyified = anyify.substitute(c.variables); LUAU_ASSERT(anyified); unify(*anyified, c.variables, constraint->scope); @@ -704,7 +724,7 @@ bool ConstraintSolver::tryDispatch(const TypeAliasExpansionConstraint& c, NotNul if (!tf.has_value()) { reportError(UnknownSymbol{petv->name.value, UnknownSymbol::Context::Type}, constraint->location); - bindResult(singletonTypes->errorRecoveryType()); + bindResult(errorRecoveryType()); return true; } @@ -763,7 +783,7 @@ bool ConstraintSolver::tryDispatch(const TypeAliasExpansionConstraint& c, NotNul if (itf.foundInfiniteType) { // TODO (CLI-56761): Report an error. - bindResult(singletonTypes->errorRecoveryType()); + bindResult(errorRecoveryType()); return true; } @@ -786,7 +806,7 @@ bool ConstraintSolver::tryDispatch(const TypeAliasExpansionConstraint& c, NotNul if (!maybeInstantiated.has_value()) { // TODO (CLI-56761): Report an error. - bindResult(singletonTypes->errorRecoveryType()); + bindResult(errorRecoveryType()); return true; } @@ -863,13 +883,21 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNulldcrMagicFunction(NotNull(this), result, c.astFragment); } - if (!usedMagic) + if (usedMagic) { - for (const auto& inner : c.innerConstraints) - { - unsolvedConstraints.push_back(inner); - } - + // There are constraints that are blocked on these constraints. If we + // are never going to even examine them, then we should not block + // anything else on them. + // + // TODO CLI-58842 +#if 0 + for (auto& c: c.innerConstraints) + unblock(c); +#endif + } + else + { + unsolvedConstraints.insert(end(unsolvedConstraints), begin(c.innerConstraints), end(c.innerConstraints)); asMutable(c.result)->ty.emplace(constraint->scope); } @@ -909,8 +937,7 @@ bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const Iterabl }; auto errorify = [&](auto ty) { - Anyification anyify{ - arena, constraint->scope, singletonTypes, &iceReporter, singletonTypes->errorRecoveryType(), singletonTypes->errorRecoveryTypePack()}; + Anyification anyify{arena, constraint->scope, singletonTypes, &iceReporter, errorRecoveryType(), errorRecoveryTypePack()}; std::optional errorified = anyify.substitute(ty); if (!errorified) reportError(CodeTooComplex{}, constraint->location); @@ -1119,7 +1146,7 @@ void ConstraintSolver::unify(TypeId subType, TypeId superType, NotNull sc if (!u.errors.empty()) { - TypeId errorType = singletonTypes->errorRecoveryType(); + TypeId errorType = errorRecoveryType(); u.tryUnify(subType, errorType); u.tryUnify(superType, errorType); } @@ -1160,7 +1187,7 @@ TypeId ConstraintSolver::resolveModule(const ModuleInfo& info, const Location& l if (info.name.empty()) { reportError(UnknownRequire{}, location); - return singletonTypes->errorRecoveryType(); + return errorRecoveryType(); } std::string humanReadableName = moduleResolver->getHumanReadableModuleName(info.name); @@ -1177,24 +1204,24 @@ TypeId ConstraintSolver::resolveModule(const ModuleInfo& info, const Location& l if (!moduleResolver->moduleExists(info.name) && !info.optional) reportError(UnknownRequire{humanReadableName}, location); - return singletonTypes->errorRecoveryType(); + return errorRecoveryType(); } if (module->type != SourceCode::Type::Module) { reportError(IllegalRequire{humanReadableName, "Module is not a ModuleScript. It cannot be required."}, location); - return singletonTypes->errorRecoveryType(); + return errorRecoveryType(); } TypePackId modulePack = module->getModuleScope()->returnType; if (get(modulePack)) - return singletonTypes->errorRecoveryType(); + return errorRecoveryType(); std::optional moduleType = first(modulePack); if (!moduleType) { reportError(IllegalRequire{humanReadableName, "Module does not return exactly 1 value. It cannot be required."}, location); - return singletonTypes->errorRecoveryType(); + return errorRecoveryType(); } return *moduleType; @@ -1212,4 +1239,14 @@ void ConstraintSolver::reportError(TypeError e) errors.back().moduleName = currentModuleName; } +TypeId ConstraintSolver::errorRecoveryType() const +{ + return singletonTypes->errorRecoveryType(); +} + +TypePackId ConstraintSolver::errorRecoveryTypePack() const +{ + return singletonTypes->errorRecoveryTypePack(); +} + } // namespace Luau diff --git a/Analysis/src/Error.cpp b/Analysis/src/Error.cpp index 93cb65b90..9ecdc82b3 100644 --- a/Analysis/src/Error.cpp +++ b/Analysis/src/Error.cpp @@ -10,7 +10,8 @@ LUAU_FASTFLAGVARIABLE(LuauTypeMismatchModuleNameResolution, false) LUAU_FASTFLAGVARIABLE(LuauUseInternalCompilerErrorException, false) -static std::string wrongNumberOfArgsString(size_t expectedCount, size_t actualCount, const char* argPrefix = nullptr, bool isVariadic = false) +static std::string wrongNumberOfArgsString( + size_t expectedCount, std::optional maximumCount, size_t actualCount, const char* argPrefix = nullptr, bool isVariadic = false) { std::string s = "expects "; @@ -19,11 +20,14 @@ static std::string wrongNumberOfArgsString(size_t expectedCount, size_t actualCo s += std::to_string(expectedCount) + " "; + if (maximumCount && expectedCount != *maximumCount) + s += "to " + std::to_string(*maximumCount) + " "; + if (argPrefix) s += std::string(argPrefix) + " "; s += "argument"; - if (expectedCount != 1) + if ((maximumCount ? *maximumCount : expectedCount) != 1) s += "s"; s += ", but "; @@ -185,7 +189,12 @@ struct ErrorConverter return "Function only returns " + std::to_string(e.expected) + " value" + expectedS + ". " + std::to_string(e.actual) + " are required here"; case CountMismatch::Arg: - return "Argument count mismatch. Function " + wrongNumberOfArgsString(e.expected, e.actual, /*argPrefix*/ nullptr, e.isVariadic); + if (!e.function.empty()) + return "Argument count mismatch. Function '" + e.function + "' " + + wrongNumberOfArgsString(e.expected, e.maximum, e.actual, /*argPrefix*/ nullptr, e.isVariadic); + else + return "Argument count mismatch. Function " + + wrongNumberOfArgsString(e.expected, e.maximum, e.actual, /*argPrefix*/ nullptr, e.isVariadic); } LUAU_ASSERT(!"Unknown context"); @@ -247,10 +256,10 @@ struct ErrorConverter if (e.typeFun.typeParams.size() != e.actualParameters) return "Generic type '" + name + "' " + - wrongNumberOfArgsString(e.typeFun.typeParams.size(), e.actualParameters, "type", !e.typeFun.typePackParams.empty()); + wrongNumberOfArgsString(e.typeFun.typeParams.size(), std::nullopt, e.actualParameters, "type", !e.typeFun.typePackParams.empty()); return "Generic type '" + name + "' " + - wrongNumberOfArgsString(e.typeFun.typePackParams.size(), e.actualPackParameters, "type pack", /*isVariadic*/ false); + wrongNumberOfArgsString(e.typeFun.typePackParams.size(), std::nullopt, e.actualPackParameters, "type pack", /*isVariadic*/ false); } std::string operator()(const Luau::SyntaxError& e) const @@ -547,7 +556,7 @@ bool DuplicateTypeDefinition::operator==(const DuplicateTypeDefinition& rhs) con bool CountMismatch::operator==(const CountMismatch& rhs) const { - return expected == rhs.expected && actual == rhs.actual && context == rhs.context; + return expected == rhs.expected && maximum == rhs.maximum && actual == rhs.actual && context == rhs.context && function == rhs.function; } bool FunctionDoesNotTakeSelf::operator==(const FunctionDoesNotTakeSelf&) const diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index 13cd7490e..711d461fb 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -1519,4 +1519,20 @@ std::string dump(const Constraint& c) return s; } +std::string toString(const LValue& lvalue) +{ + std::string s; + for (const LValue* current = &lvalue; current; current = baseof(*current)) + { + if (auto field = get(*current)) + s = "." + field->key + s; + else if (auto symbol = get(*current)) + s = toString(*symbol) + s; + else + LUAU_ASSERT(!"Unknown LValue"); + } + + return s; +} + } // namespace Luau diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index 88363b43e..76b27acdc 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -408,7 +408,7 @@ struct TypeChecker2 return; NotNull scope = stack.back(); - TypeArena tempArena; + TypeArena& arena = module->internalTypes; std::vector variableTypes; for (AstLocal* var : forInStatement->vars) @@ -424,10 +424,10 @@ struct TypeChecker2 for (size_t i = 0; i < forInStatement->values.size - 1; ++i) valueTypes.emplace_back(lookupType(forInStatement->values.data[i])); TypePackId iteratorTail = lookupPack(forInStatement->values.data[forInStatement->values.size - 1]); - TypePackId iteratorPack = tempArena.addTypePack(valueTypes, iteratorTail); + TypePackId iteratorPack = arena.addTypePack(valueTypes, iteratorTail); // ... and then expand it out to 3 values (if possible) - const std::vector iteratorTypes = flatten(tempArena, iteratorPack, 3); + const std::vector iteratorTypes = flatten(arena, iteratorPack, 3); if (iteratorTypes.empty()) { reportError(GenericError{"for..in loops require at least one value to iterate over. Got zero"}, getLocation(forInStatement->values)); @@ -456,7 +456,7 @@ struct TypeChecker2 reportError(GenericError{"for..in loops must be passed (next, [table[, state]])"}, getLocation(forInStatement->values)); // It is okay if there aren't enough iterators, but the iteratee must provide enough. - std::vector expectedVariableTypes = flatten(tempArena, nextFn->retTypes, variableTypes.size()); + std::vector expectedVariableTypes = flatten(arena, nextFn->retTypes, variableTypes.size()); if (expectedVariableTypes.size() < variableTypes.size()) reportError(GenericError{"next() does not return enough values"}, forInStatement->vars.data[0]->location); @@ -475,23 +475,23 @@ struct TypeChecker2 // If iteratorTypes is too short to be a valid call to nextFn, we have to report a count mismatch error. // If 2 is too short to be a valid call to nextFn, we have to report a count mismatch error. // If 2 is too long to be a valid call to nextFn, we have to report a count mismatch error. - auto [minCount, maxCount] = getParameterExtents(TxnLog::empty(), nextFn->argTypes); + auto [minCount, maxCount] = getParameterExtents(TxnLog::empty(), nextFn->argTypes, /*includeHiddenVariadics*/ true); if (minCount > 2) - reportError(CountMismatch{2, minCount, CountMismatch::Arg}, forInStatement->vars.data[0]->location); + reportError(CountMismatch{2, std::nullopt, minCount, CountMismatch::Arg}, forInStatement->vars.data[0]->location); if (maxCount && *maxCount < 2) - reportError(CountMismatch{2, *maxCount, CountMismatch::Arg}, forInStatement->vars.data[0]->location); + reportError(CountMismatch{2, std::nullopt, *maxCount, CountMismatch::Arg}, forInStatement->vars.data[0]->location); - const std::vector flattenedArgTypes = flatten(tempArena, nextFn->argTypes, 2); + const std::vector flattenedArgTypes = flatten(arena, nextFn->argTypes, 2); const auto [argTypes, argsTail] = Luau::flatten(nextFn->argTypes); size_t firstIterationArgCount = iteratorTypes.empty() ? 0 : iteratorTypes.size() - 1; size_t actualArgCount = expectedVariableTypes.size(); if (firstIterationArgCount < minCount) - reportError(CountMismatch{2, firstIterationArgCount, CountMismatch::Arg}, forInStatement->vars.data[0]->location); + reportError(CountMismatch{2, std::nullopt, firstIterationArgCount, CountMismatch::Arg}, forInStatement->vars.data[0]->location); else if (actualArgCount < minCount) - reportError(CountMismatch{2, actualArgCount, CountMismatch::Arg}, forInStatement->vars.data[0]->location); + reportError(CountMismatch{2, std::nullopt, actualArgCount, CountMismatch::Arg}, forInStatement->vars.data[0]->location); if (iteratorTypes.size() >= 2 && flattenedArgTypes.size() > 0) { diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index c14081969..b6f20ac81 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -32,13 +32,15 @@ LUAU_FASTINTVARIABLE(LuauCheckRecursionLimit, 300) LUAU_FASTINTVARIABLE(LuauVisitRecursionLimit, 500) LUAU_FASTFLAG(LuauKnowsTheDataModel3) LUAU_FASTFLAG(LuauAutocompleteDynamicLimits) +LUAU_FASTFLAGVARIABLE(LuauFunctionArgMismatchDetails, false) LUAU_FASTFLAGVARIABLE(LuauInplaceDemoteSkipAllBound, false) LUAU_FASTFLAGVARIABLE(LuauLowerBoundsCalculation, false) LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) LUAU_FASTFLAGVARIABLE(LuauSelfCallAutocompleteFix3, false) LUAU_FASTFLAGVARIABLE(LuauReturnAnyInsteadOfICE, false) // Eventually removed as false. -LUAU_FASTFLAGVARIABLE(DebugLuauSharedSelf, false); +LUAU_FASTFLAGVARIABLE(DebugLuauSharedSelf, false) LUAU_FASTFLAGVARIABLE(LuauUnknownAndNeverType, false) +LUAU_FASTFLAGVARIABLE(LuauCallUnifyPackTails, false) LUAU_FASTFLAGVARIABLE(LuauCheckGenericHOFTypes, false) LUAU_FASTFLAGVARIABLE(LuauBinaryNeedsExpectedTypesToo, false) LUAU_FASTFLAGVARIABLE(LuauNeverTypesAndOperatorsInference, false) @@ -1346,7 +1348,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) } Unifier state = mkUnifier(loopScope, firstValue->location); - checkArgumentList(loopScope, state, argPack, iterFunc->argTypes, /*argLocations*/ {}); + checkArgumentList(loopScope, *firstValue, state, argPack, iterFunc->argTypes, /*argLocations*/ {}); state.log.commit(); @@ -3666,8 +3668,8 @@ WithPredicate TypeChecker::checkExprPackHelper(const ScopePtr& scope } } -void TypeChecker::checkArgumentList( - const ScopePtr& scope, Unifier& state, TypePackId argPack, TypePackId paramPack, const std::vector& argLocations) +void TypeChecker::checkArgumentList(const ScopePtr& scope, const AstExpr& funName, Unifier& state, TypePackId argPack, TypePackId paramPack, + const std::vector& argLocations) { /* Important terminology refresher: * A function requires parameters. @@ -3679,14 +3681,27 @@ void TypeChecker::checkArgumentList( size_t paramIndex = 0; - auto reportCountMismatchError = [&state, &argLocations, paramPack, argPack]() { + auto reportCountMismatchError = [&state, &argLocations, paramPack, argPack, &funName]() { // For this case, we want the error span to cover every errant extra parameter Location location = state.location; if (!argLocations.empty()) location = {state.location.begin, argLocations.back().end}; - size_t minParams = getParameterExtents(&state.log, paramPack).first; - state.reportError(TypeError{location, CountMismatch{minParams, std::distance(begin(argPack), end(argPack))}}); + if (FFlag::LuauFunctionArgMismatchDetails) + { + std::string namePath; + if (std::optional lValue = tryGetLValue(funName)) + namePath = toString(*lValue); + + auto [minParams, optMaxParams] = getParameterExtents(&state.log, paramPack); + state.reportError(TypeError{location, + CountMismatch{minParams, optMaxParams, std::distance(begin(argPack), end(argPack)), CountMismatch::Context::Arg, false, namePath}}); + } + else + { + size_t minParams = getParameterExtents(&state.log, paramPack).first; + state.reportError(TypeError{location, CountMismatch{minParams, std::nullopt, std::distance(begin(argPack), end(argPack))}}); + } }; while (true) @@ -3698,11 +3713,8 @@ void TypeChecker::checkArgumentList( std::optional argTail = argIter.tail(); std::optional paramTail = paramIter.tail(); - // If we hit the end of both type packs simultaneously, then there are definitely no further type - // errors to report. All we need to do is tie up any free tails. - // - // If one side has a free tail and the other has none at all, we create an empty pack and bind the - // free tail to that. + // If we hit the end of both type packs simultaneously, we have to unify them. + // But if one side has a free tail and the other has none at all, we create an empty pack and bind the free tail to that. if (argTail) { @@ -3713,6 +3725,10 @@ void TypeChecker::checkArgumentList( else state.log.replace(*argTail, TypePackVar(TypePack{{}})); } + else if (FFlag::LuauCallUnifyPackTails && paramTail) + { + state.tryUnify(*argTail, *paramTail); + } } else if (paramTail) { @@ -3784,12 +3800,25 @@ void TypeChecker::checkArgumentList( } // ok else { - size_t minParams = getParameterExtents(&state.log, paramPack).first; + auto [minParams, optMaxParams] = getParameterExtents(&state.log, paramPack); std::optional tail = flatten(paramPack, state.log).second; bool isVariadic = tail && Luau::isVariadic(*tail); - state.reportError(TypeError{state.location, CountMismatch{minParams, paramIndex, CountMismatch::Context::Arg, isVariadic}}); + if (FFlag::LuauFunctionArgMismatchDetails) + { + std::string namePath; + if (std::optional lValue = tryGetLValue(funName)) + namePath = toString(*lValue); + + state.reportError(TypeError{ + state.location, CountMismatch{minParams, optMaxParams, paramIndex, CountMismatch::Context::Arg, isVariadic, namePath}}); + } + else + { + state.reportError( + TypeError{state.location, CountMismatch{minParams, std::nullopt, paramIndex, CountMismatch::Context::Arg, isVariadic}}); + } return; } ++paramIter; @@ -4185,13 +4214,13 @@ std::optional> TypeChecker::checkCallOverload(const Sc Unifier state = mkUnifier(scope, expr.location); // Unify return types - checkArgumentList(scope, state, retPack, ftv->retTypes, /*argLocations*/ {}); + checkArgumentList(scope, *expr.func, state, retPack, ftv->retTypes, /*argLocations*/ {}); if (!state.errors.empty()) { return {}; } - checkArgumentList(scope, state, argPack, ftv->argTypes, *argLocations); + checkArgumentList(scope, *expr.func, state, argPack, ftv->argTypes, *argLocations); if (!state.errors.empty()) { @@ -4245,7 +4274,7 @@ bool TypeChecker::handleSelfCallMismatch(const ScopePtr& scope, const AstExprCal TypePackId editedArgPack = addTypePack(TypePack{editedParamList}); Unifier editedState = mkUnifier(scope, expr.location); - checkArgumentList(scope, editedState, editedArgPack, ftv->argTypes, editedArgLocations); + checkArgumentList(scope, *expr.func, editedState, editedArgPack, ftv->argTypes, editedArgLocations); if (editedState.errors.empty()) { @@ -4276,7 +4305,7 @@ bool TypeChecker::handleSelfCallMismatch(const ScopePtr& scope, const AstExprCal Unifier editedState = mkUnifier(scope, expr.location); - checkArgumentList(scope, editedState, editedArgPack, ftv->argTypes, editedArgLocations); + checkArgumentList(scope, *expr.func, editedState, editedArgPack, ftv->argTypes, editedArgLocations); if (editedState.errors.empty()) { @@ -4345,8 +4374,8 @@ void TypeChecker::reportOverloadResolutionError(const ScopePtr& scope, const Ast // Unify return types if (const FunctionTypeVar* ftv = get(overload)) { - checkArgumentList(scope, state, retPack, ftv->retTypes, {}); - checkArgumentList(scope, state, argPack, ftv->argTypes, argLocations); + checkArgumentList(scope, *expr.func, state, retPack, ftv->retTypes, {}); + checkArgumentList(scope, *expr.func, state, argPack, ftv->argTypes, argLocations); } if (state.errors.empty()) diff --git a/Analysis/src/TypePack.cpp b/Analysis/src/TypePack.cpp index 2fa9413a8..0fa4df605 100644 --- a/Analysis/src/TypePack.cpp +++ b/Analysis/src/TypePack.cpp @@ -357,10 +357,15 @@ bool isVariadic(TypePackId tp, const TxnLog& log) if (!tail) return false; - if (log.get(*tail)) + return isVariadicTail(*tail, log); +} + +bool isVariadicTail(TypePackId tp, const TxnLog& log, bool includeHiddenVariadics) +{ + if (log.get(tp)) return true; - if (auto vtp = log.get(*tail); vtp && !vtp->hidden) + if (auto vtp = log.get(tp); vtp && (includeHiddenVariadics || !vtp->hidden)) return true; return false; diff --git a/Analysis/src/TypeUtils.cpp b/Analysis/src/TypeUtils.cpp index a96820d67..6ea04ea9b 100644 --- a/Analysis/src/TypeUtils.cpp +++ b/Analysis/src/TypeUtils.cpp @@ -6,6 +6,8 @@ #include "Luau/ToString.h" #include "Luau/TypeInfer.h" +LUAU_FASTFLAG(LuauFunctionArgMismatchDetails) + namespace Luau { @@ -193,7 +195,7 @@ std::optional getIndexTypeFromType(const ScopePtr& scope, ErrorVec& erro return std::nullopt; } -std::pair> getParameterExtents(const TxnLog* log, TypePackId tp) +std::pair> getParameterExtents(const TxnLog* log, TypePackId tp, bool includeHiddenVariadics) { size_t minCount = 0; size_t optionalCount = 0; @@ -216,7 +218,7 @@ std::pair> getParameterExtents(const TxnLog* log, ++it; } - if (it.tail()) + if (it.tail() && (!FFlag::LuauFunctionArgMismatchDetails || isVariadicTail(*it.tail(), *log, includeHiddenVariadics))) return {minCount, std::nullopt}; else return {minCount, minCount + optionalCount}; diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index 4f6603fb9..3a820ea6e 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -1133,14 +1133,14 @@ std::optional> magicFunctionFormat( size_t numExpectedParams = expected.size() + 1; // + 1 for the format string if (numExpectedParams != numActualParams && (!tail || numExpectedParams < numActualParams)) - typechecker.reportError(TypeError{expr.location, CountMismatch{numExpectedParams, numActualParams}}); + typechecker.reportError(TypeError{expr.location, CountMismatch{numExpectedParams, std::nullopt, numActualParams}}); } else { size_t actualParamSize = params.size() - paramOffset; if (expected.size() != actualParamSize && (!tail || expected.size() < actualParamSize)) - typechecker.reportError(TypeError{expr.location, CountMismatch{expected.size(), actualParamSize}}); + typechecker.reportError(TypeError{expr.location, CountMismatch{expected.size(), std::nullopt, actualParamSize}}); } return WithPredicate{arena.addTypePack({typechecker.stringType})}; } diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index fd6784321..505e9e437 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -22,6 +22,7 @@ LUAU_FASTFLAG(LuauErrorRecoveryType); LUAU_FASTFLAG(LuauUnknownAndNeverType) LUAU_FASTFLAGVARIABLE(LuauScalarShapeSubtyping, false) LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution) +LUAU_FASTFLAG(LuauCallUnifyPackTails) namespace Luau { @@ -1159,7 +1160,7 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal size_t actualSize = size(subTp); if (ctx == CountMismatch::Result) std::swap(expectedSize, actualSize); - reportError(TypeError{location, CountMismatch{expectedSize, actualSize, ctx}}); + reportError(TypeError{location, CountMismatch{expectedSize, std::nullopt, actualSize, ctx}}); while (superIter.good()) { @@ -2118,6 +2119,15 @@ void Unifier::unifyLowerBound(TypePackId subTy, TypePackId superTy, TypeLevel de for (; superIter != superEndIter; ++superIter) tp->head.push_back(*superIter); } + else if (const VariadicTypePack* subVariadic = log.getMutable(subTailPack); + subVariadic && FFlag::LuauCallUnifyPackTails) + { + while (superIter != superEndIter) + { + tryUnify_(subVariadic->ty, *superIter); + ++superIter; + } + } } else { @@ -2125,7 +2135,7 @@ void Unifier::unifyLowerBound(TypePackId subTy, TypePackId superTy, TypeLevel de { if (!isOptional(*superIter)) { - errors.push_back(TypeError{location, CountMismatch{size(superTy), size(subTy), CountMismatch::Return}}); + errors.push_back(TypeError{location, CountMismatch{size(superTy), std::nullopt, size(subTy), CountMismatch::Return}}); return; } ++superIter; diff --git a/CodeGen/include/Luau/CodeAllocator.h b/CodeGen/include/Luau/CodeAllocator.h index c80b5c389..01e131216 100644 --- a/CodeGen/include/Luau/CodeAllocator.h +++ b/CodeGen/include/Luau/CodeAllocator.h @@ -24,7 +24,7 @@ struct CodeAllocator void* context = nullptr; // Called when new block is created to create and setup the unwinding information for all the code in the block - // Some platforms require this data to be placed inside the block itself, so we also return 'unwindDataSizeInBlock' + // If data is placed inside the block itself (some platforms require this), we also return 'unwindDataSizeInBlock' void* (*createBlockUnwindInfo)(void* context, uint8_t* block, size_t blockSize, size_t& unwindDataSizeInBlock) = nullptr; // Called to destroy unwinding information returned by 'createBlockUnwindInfo' diff --git a/CodeGen/include/Luau/CodeBlockUnwind.h b/CodeGen/include/Luau/CodeBlockUnwind.h new file mode 100644 index 000000000..ddae33a60 --- /dev/null +++ b/CodeGen/include/Luau/CodeBlockUnwind.h @@ -0,0 +1,17 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include +#include + +namespace Luau +{ +namespace CodeGen +{ + +// context must be an UnwindBuilder +void* createBlockUnwindInfo(void* context, uint8_t* block, size_t blockSize, size_t& unwindDataSizeInBlock); +void destroyBlockUnwindInfo(void* context, void* unwindData); + +} // namespace CodeGen +} // namespace Luau diff --git a/CodeGen/include/Luau/UnwindBuilder.h b/CodeGen/include/Luau/UnwindBuilder.h new file mode 100644 index 000000000..c6f611b0f --- /dev/null +++ b/CodeGen/include/Luau/UnwindBuilder.h @@ -0,0 +1,35 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/RegisterX64.h" + +#include +#include + +namespace Luau +{ +namespace CodeGen +{ + +class UnwindBuilder +{ +public: + virtual ~UnwindBuilder() {} + + virtual void start() = 0; + + virtual void spill(int espOffset, RegisterX64 reg) = 0; + virtual void save(RegisterX64 reg) = 0; + virtual void allocStack(int size) = 0; + virtual void setupFrameReg(RegisterX64 reg, int espOffset) = 0; + + virtual void finish() = 0; + + virtual size_t getSize() const = 0; + + // This will place the unwinding data at the target address and might update values of some fields + virtual void finalize(char* target, void* funcAddress, size_t funcSize) const = 0; +}; + +} // namespace CodeGen +} // namespace Luau diff --git a/CodeGen/include/Luau/UnwindBuilderDwarf2.h b/CodeGen/include/Luau/UnwindBuilderDwarf2.h new file mode 100644 index 000000000..09c91d438 --- /dev/null +++ b/CodeGen/include/Luau/UnwindBuilderDwarf2.h @@ -0,0 +1,40 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/RegisterX64.h" +#include "UnwindBuilder.h" + +namespace Luau +{ +namespace CodeGen +{ + +class UnwindBuilderDwarf2 : public UnwindBuilder +{ +public: + void start() override; + + void spill(int espOffset, RegisterX64 reg) override; + void save(RegisterX64 reg) override; + void allocStack(int size) override; + void setupFrameReg(RegisterX64 reg, int espOffset) override; + + void finish() override; + + size_t getSize() const override; + + void finalize(char* target, void* funcAddress, size_t funcSize) const override; + +private: + static const unsigned kRawDataLimit = 128; + char rawData[kRawDataLimit]; + char* pos = rawData; + + uint32_t stackOffset = 0; + + // We will remember the FDE location to write some of the fields like entry length, function start and size later + char* fdeEntryStart = nullptr; +}; + +} // namespace CodeGen +} // namespace Luau diff --git a/CodeGen/include/Luau/UnwindBuilderWin.h b/CodeGen/include/Luau/UnwindBuilderWin.h new file mode 100644 index 000000000..801eb6e47 --- /dev/null +++ b/CodeGen/include/Luau/UnwindBuilderWin.h @@ -0,0 +1,51 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/RegisterX64.h" +#include "UnwindBuilder.h" + +#include + +namespace Luau +{ +namespace CodeGen +{ + +// This struct matches the layout of UNWIND_CODE from ehdata.h +struct UnwindCodeWin +{ + uint8_t offset; + uint8_t opcode : 4; + uint8_t opinfo : 4; +}; + +class UnwindBuilderWin : public UnwindBuilder +{ +public: + void start() override; + + void spill(int espOffset, RegisterX64 reg) override; + void save(RegisterX64 reg) override; + void allocStack(int size) override; + void setupFrameReg(RegisterX64 reg, int espOffset) override; + + void finish() override; + + size_t getSize() const override; + + void finalize(char* target, void* funcAddress, size_t funcSize) const override; + +private: + // Windows unwind codes are written in reverse, so we have to collect them all first + std::vector unwindCodes; + + uint8_t prologSize = 0; + RegisterX64 frameReg = rax; // rax means that frame register is not used + uint8_t frameRegOffset = 0; + uint32_t stackOffset = 0; + + size_t infoSize = 0; +}; + +} // namespace CodeGen +} // namespace Luau diff --git a/CodeGen/src/CodeAllocator.cpp b/CodeGen/src/CodeAllocator.cpp index f74320640..aacf40a34 100644 --- a/CodeGen/src/CodeAllocator.cpp +++ b/CodeGen/src/CodeAllocator.cpp @@ -6,8 +6,13 @@ #include #if defined(_WIN32) + +#ifndef WIN32_LEAN_AND_MEAN #define WIN32_LEAN_AND_MEAN +#endif +#ifndef NOMINMAX #define NOMINMAX +#endif #include const size_t kPageSize = 4096; @@ -135,17 +140,29 @@ bool CodeAllocator::allocate( if (codeSize) memcpy(blockPos + codeOffset, code, codeSize); - size_t pageSize = alignToPageSize(unwindInfoSize + totalSize); + size_t pageAlignedSize = alignToPageSize(unwindInfoSize + totalSize); - makePagesExecutable(blockPos, pageSize); + makePagesExecutable(blockPos, pageAlignedSize); flushInstructionCache(blockPos + codeOffset, codeSize); result = blockPos + unwindInfoSize; resultSize = totalSize; resultCodeStart = blockPos + codeOffset; - blockPos += pageSize; - LUAU_ASSERT((uintptr_t(blockPos) & (kPageSize - 1)) == 0); // Allocation ends on page boundary + // Ensure that future allocations from the block start from a page boundary. + // This is important since we use W^X, and writing to the previous page would require briefly removing + // executable bit from it, which may result in access violations if that code is being executed concurrently. + if (pageAlignedSize <= size_t(blockEnd - blockPos)) + { + blockPos += pageAlignedSize; + LUAU_ASSERT((uintptr_t(blockPos) & (kPageSize - 1)) == 0); + LUAU_ASSERT(blockPos <= blockEnd); + } + else + { + // Future allocations will need to allocate fresh blocks + blockPos = blockEnd; + } return true; } diff --git a/CodeGen/src/CodeBlockUnwind.cpp b/CodeGen/src/CodeBlockUnwind.cpp new file mode 100644 index 000000000..6191cee40 --- /dev/null +++ b/CodeGen/src/CodeBlockUnwind.cpp @@ -0,0 +1,123 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/CodeBlockUnwind.h" + +#include "Luau/UnwindBuilder.h" + +#include + +#if defined(_WIN32) && defined(_M_X64) + +#ifndef WIN32_LEAN_AND_MEAN +#define WIN32_LEAN_AND_MEAN +#endif +#ifndef NOMINMAX +#define NOMINMAX +#endif +#include + +#elif !defined(_WIN32) + +// Defined in unwind.h which may not be easily discoverable on various platforms +extern "C" void __register_frame(const void*); +extern "C" void __deregister_frame(const void*); + +#endif + +#if defined(__APPLE__) +// On Mac, each FDE inside eh_frame section has to be handled separately +static void visitFdeEntries(char* pos, void (*cb)(const void*)) +{ + for (;;) + { + unsigned partLength; + memcpy(&partLength, pos, sizeof(partLength)); + + if (partLength == 0) // Zero-length section signals completion + break; + + unsigned partId; + memcpy(&partId, pos + 4, sizeof(partId)); + + if (partId != 0) // Skip CIE part + cb(pos); // CIE is found using an offset in FDE + + pos += partLength + 4; + } +} +#endif + +namespace Luau +{ +namespace CodeGen +{ + +void* createBlockUnwindInfo(void* context, uint8_t* block, size_t blockSize, size_t& unwindDataSizeInBlock) +{ +#if defined(_WIN32) && defined(_M_X64) + UnwindBuilder* unwind = (UnwindBuilder*)context; + + // All unwinding related data is placed together at the start of the block + size_t unwindSize = sizeof(RUNTIME_FUNCTION) + unwind->getSize(); + unwindSize = (unwindSize + 15) & ~15; // Align to 16 bytes + LUAU_ASSERT(blockSize >= unwindSize); + + RUNTIME_FUNCTION* runtimeFunc = (RUNTIME_FUNCTION*)block; + runtimeFunc->BeginAddress = DWORD(unwindSize); // Code will start after the unwind info + runtimeFunc->EndAddress = DWORD(blockSize); // Whole block is a part of a 'single function' + runtimeFunc->UnwindInfoAddress = DWORD(sizeof(RUNTIME_FUNCTION)); // Unwind info is placed at the start of the block + + char* unwindData = (char*)block + runtimeFunc->UnwindInfoAddress; + unwind->finalize(unwindData, block + unwindSize, blockSize - unwindSize); + + if (!RtlAddFunctionTable(runtimeFunc, 1, uintptr_t(block))) + { + LUAU_ASSERT(!"failed to allocate function table"); + return nullptr; + } + + unwindDataSizeInBlock = unwindSize; + return block; +#elif !defined(_WIN32) + UnwindBuilder* unwind = (UnwindBuilder*)context; + + // All unwinding related data is placed together at the start of the block + size_t unwindSize = unwind->getSize(); + unwindSize = (unwindSize + 15) & ~15; // Align to 16 bytes + LUAU_ASSERT(blockSize >= unwindSize); + + char* unwindData = (char*)block; + unwind->finalize(unwindData, block, blockSize); + +#if defined(__APPLE__) + visitFdeEntries(unwindData, __register_frame); +#else + __register_frame(unwindData); +#endif + + unwindDataSizeInBlock = unwindSize; + return block; +#endif + + return nullptr; +} + +void destroyBlockUnwindInfo(void* context, void* unwindData) +{ +#if defined(_WIN32) && defined(_M_X64) + RUNTIME_FUNCTION* runtimeFunc = (RUNTIME_FUNCTION*)unwindData; + + if (!RtlDeleteFunctionTable(runtimeFunc)) + LUAU_ASSERT(!"failed to deallocate function table"); +#elif !defined(_WIN32) + +#if defined(__APPLE__) + visitFdeEntries((char*)unwindData, __deregister_frame); +#else + __deregister_frame(unwindData); +#endif + +#endif +} + +} // namespace CodeGen +} // namespace Luau diff --git a/CodeGen/src/UnwindBuilderDwarf2.cpp b/CodeGen/src/UnwindBuilderDwarf2.cpp new file mode 100644 index 000000000..38e3e712f --- /dev/null +++ b/CodeGen/src/UnwindBuilderDwarf2.cpp @@ -0,0 +1,253 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/UnwindBuilderDwarf2.h" + +#include + +// General information about Dwarf2 format can be found at: +// https://dwarfstd.org/doc/dwarf-2.0.0.pdf [DWARF Debugging Information Format] +// Main part for async exception unwinding is in section '6.4 Call Frame Information' + +// Information about System V ABI (AMD64) can be found at: +// https://refspecs.linuxbase.org/elf/x86_64-abi-0.99.pdf [System V Application Binary Interface (AMD64 Architecture Processor Supplement)] +// Interaction between Dwarf2 and System V ABI can be found in sections '3.6.2 DWARF Register Number Mapping' and '4.2.4 EH_FRAME sections' + +static char* writeu8(char* target, uint8_t value) +{ + memcpy(target, &value, sizeof(value)); + return target + sizeof(value); +} + +static char* writeu32(char* target, uint32_t value) +{ + memcpy(target, &value, sizeof(value)); + return target + sizeof(value); +} + +static char* writeu64(char* target, uint64_t value) +{ + memcpy(target, &value, sizeof(value)); + return target + sizeof(value); +} + +static char* writeuleb128(char* target, uint64_t value) +{ + do + { + char byte = value & 0x7f; + value >>= 7; + + if (value) + byte |= 0x80; + + *target++ = byte; + } while (value); + + return target; +} + +// Call frame instruction opcodes +#define DW_CFA_advance_loc 0x40 +#define DW_CFA_offset 0x80 +#define DW_CFA_restore 0xc0 +#define DW_CFA_nop 0x00 +#define DW_CFA_set_loc 0x01 +#define DW_CFA_advance_loc1 0x02 +#define DW_CFA_advance_loc2 0x03 +#define DW_CFA_advance_loc4 0x04 +#define DW_CFA_offset_extended 0x05 +#define DW_CFA_restore_extended 0x06 +#define DW_CFA_undefined 0x07 +#define DW_CFA_same_value 0x08 +#define DW_CFA_register 0x09 +#define DW_CFA_remember_state 0x0a +#define DW_CFA_restore_state 0x0b +#define DW_CFA_def_cfa 0x0c +#define DW_CFA_def_cfa_register 0x0d +#define DW_CFA_def_cfa_offset 0x0e +#define DW_CFA_def_cfa_expression 0x0f +#define DW_CFA_expression 0x10 +#define DW_CFA_offset_extended_sf 0x11 +#define DW_CFA_def_cfa_sf 0x12 +#define DW_CFA_def_cfa_offset_sf 0x13 +#define DW_CFA_val_offset 0x14 +#define DW_CFA_val_offset_sf 0x15 +#define DW_CFA_val_expression 0x16 +#define DW_CFA_lo_user 0x1c +#define DW_CFA_hi_user 0x3f + +// Register numbers for x64 +#define DW_REG_RAX 0 +#define DW_REG_RDX 1 +#define DW_REG_RCX 2 +#define DW_REG_RBX 3 +#define DW_REG_RSI 4 +#define DW_REG_RDI 5 +#define DW_REG_RBP 6 +#define DW_REG_RSP 7 +#define DW_REG_R8 8 +#define DW_REG_R9 9 +#define DW_REG_R10 10 +#define DW_REG_R11 11 +#define DW_REG_R12 12 +#define DW_REG_R13 13 +#define DW_REG_R14 14 +#define DW_REG_R15 15 +#define DW_REG_RA 16 + +const int regIndexToDwRegX64[16] = {DW_REG_RAX, DW_REG_RCX, DW_REG_RDX, DW_REG_RBX, DW_REG_RSP, DW_REG_RBP, DW_REG_RSI, DW_REG_RDI, DW_REG_R8, + DW_REG_R9, DW_REG_R10, DW_REG_R11, DW_REG_R12, DW_REG_R13, DW_REG_R14, DW_REG_R15}; + +const int kCodeAlignFactor = 1; +const int kDataAlignFactor = 8; +const int kDwarfAlign = 8; +const int kFdeInitialLocationOffset = 8; +const int kFdeAddressRangeOffset = 16; + +// Define canonical frame address expression as [reg + offset] +static char* defineCfaExpression(char* pos, int dwReg, uint32_t stackOffset) +{ + pos = writeu8(pos, DW_CFA_def_cfa); + pos = writeuleb128(pos, dwReg); + pos = writeuleb128(pos, stackOffset); + return pos; +} + +// Update offset value in canonical frame address expression +static char* defineCfaExpressionOffset(char* pos, uint32_t stackOffset) +{ + pos = writeu8(pos, DW_CFA_def_cfa_offset); + pos = writeuleb128(pos, stackOffset); + return pos; +} + +static char* defineSavedRegisterLocation(char* pos, int dwReg, uint32_t stackOffset) +{ + LUAU_ASSERT(stackOffset % kDataAlignFactor == 0 && "stack offsets have to be measured in kDataAlignFactor units"); + + if (dwReg <= 15) + { + pos = writeu8(pos, DW_CFA_offset + dwReg); + } + else + { + pos = writeu8(pos, DW_CFA_offset_extended); + pos = writeuleb128(pos, dwReg); + } + + pos = writeuleb128(pos, stackOffset / kDataAlignFactor); + return pos; +} + +static char* advanceLocation(char* pos, uint8_t offset) +{ + pos = writeu8(pos, DW_CFA_advance_loc1); + pos = writeu8(pos, offset); + return pos; +} + +static char* alignPosition(char* start, char* pos) +{ + size_t size = pos - start; + size_t pad = ((size + kDwarfAlign - 1) & ~(kDwarfAlign - 1)) - size; + + for (size_t i = 0; i < pad; i++) + pos = writeu8(pos, DW_CFA_nop); + + return pos; +} + +namespace Luau +{ +namespace CodeGen +{ + +void UnwindBuilderDwarf2::start() +{ + char* cieLength = pos; + pos = writeu32(pos, 0); // Length (to be filled later) + + pos = writeu32(pos, 0); // CIE id. 0 -- .eh_frame + pos = writeu8(pos, 1); // Version + + pos = writeu8(pos, 0); // CIE augmentation String "" + + pos = writeuleb128(pos, kCodeAlignFactor); // Code align factor + pos = writeuleb128(pos, -kDataAlignFactor & 0x7f); // Data align factor of (as signed LEB128) + pos = writeu8(pos, DW_REG_RA); // Return address register + + // Optional CIE augmentation section (not present) + + // Call frame instructions (common for all FDEs, of which we have 1) + stackOffset = 8; // Return address was pushed by calling the function + + pos = defineCfaExpression(pos, DW_REG_RSP, stackOffset); // Define CFA to be the rsp + 8 + pos = defineSavedRegisterLocation(pos, DW_REG_RA, 8); // Define return address register (RA) to be located at CFA - 8 + + pos = alignPosition(cieLength, pos); + writeu32(cieLength, unsigned(pos - cieLength - 4)); // Length field itself is excluded from length + + fdeEntryStart = pos; // Will be written at the end + pos = writeu32(pos, 0); // Length (to be filled later) + pos = writeu32(pos, unsigned(pos - rawData)); // CIE pointer + pos = writeu64(pos, 0); // Initial location (to be filled later) + pos = writeu64(pos, 0); // Address range (to be filled later) + + // Optional CIE augmentation section (not present) + + // Function call frame instructions to follow +} + +void UnwindBuilderDwarf2::spill(int espOffset, RegisterX64 reg) +{ + pos = advanceLocation(pos, 5); // REX.W mov [rsp + imm8], reg +} + +void UnwindBuilderDwarf2::save(RegisterX64 reg) +{ + stackOffset += 8; + pos = advanceLocation(pos, 2); // REX.W push reg + pos = defineCfaExpressionOffset(pos, stackOffset); + pos = defineSavedRegisterLocation(pos, regIndexToDwRegX64[reg.index], stackOffset); +} + +void UnwindBuilderDwarf2::allocStack(int size) +{ + stackOffset += size; + pos = advanceLocation(pos, 4); // REX.W sub rsp, imm8 + pos = defineCfaExpressionOffset(pos, stackOffset); +} + +void UnwindBuilderDwarf2::setupFrameReg(RegisterX64 reg, int espOffset) +{ + // Not required for unwinding +} + +void UnwindBuilderDwarf2::finish() +{ + LUAU_ASSERT(stackOffset % 16 == 0 && "stack has to be aligned to 16 bytes after prologue"); + + pos = alignPosition(fdeEntryStart, pos); + writeu32(fdeEntryStart, unsigned(pos - fdeEntryStart - 4)); // Length field itself is excluded from length + + // Terminate section + pos = writeu32(pos, 0); + + LUAU_ASSERT(getSize() <= kRawDataLimit); +} + +size_t UnwindBuilderDwarf2::getSize() const +{ + return size_t(pos - rawData); +} + +void UnwindBuilderDwarf2::finalize(char* target, void* funcAddress, size_t funcSize) const +{ + memcpy(target, rawData, getSize()); + + unsigned fdeEntryStartPos = unsigned(fdeEntryStart - rawData); + writeu64(target + fdeEntryStartPos + kFdeInitialLocationOffset, uintptr_t(funcAddress)); + writeu64(target + fdeEntryStartPos + kFdeAddressRangeOffset, funcSize); +} + +} // namespace CodeGen +} // namespace Luau diff --git a/CodeGen/src/UnwindBuilderWin.cpp b/CodeGen/src/UnwindBuilderWin.cpp new file mode 100644 index 000000000..5405fcf21 --- /dev/null +++ b/CodeGen/src/UnwindBuilderWin.cpp @@ -0,0 +1,120 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/UnwindBuilderWin.h" + +#include + +// Information about the Windows x64 unwinding data setup can be found at: +// https://docs.microsoft.com/en-us/cpp/build/exception-handling-x64 [x64 exception handling] + +#define UWOP_PUSH_NONVOL 0 +#define UWOP_ALLOC_LARGE 1 +#define UWOP_ALLOC_SMALL 2 +#define UWOP_SET_FPREG 3 +#define UWOP_SAVE_NONVOL 4 +#define UWOP_SAVE_NONVOL_FAR 5 +#define UWOP_SAVE_XMM128 8 +#define UWOP_SAVE_XMM128_FAR 9 +#define UWOP_PUSH_MACHFRAME 10 + +namespace Luau +{ +namespace CodeGen +{ + +// This struct matches the layout of UNWIND_INFO from ehdata.h +struct UnwindInfoWin +{ + uint8_t version : 3; + uint8_t flags : 5; + uint8_t prologsize; + uint8_t unwindcodecount; + uint8_t framereg : 4; + uint8_t frameregoff : 4; +}; + +void UnwindBuilderWin::start() +{ + stackOffset = 8; // Return address was pushed by calling the function + + unwindCodes.reserve(16); +} + +void UnwindBuilderWin::spill(int espOffset, RegisterX64 reg) +{ + prologSize += 5; // REX.W mov [rsp + imm8], reg +} + +void UnwindBuilderWin::save(RegisterX64 reg) +{ + prologSize += 2; // REX.W push reg + stackOffset += 8; + unwindCodes.push_back({prologSize, UWOP_PUSH_NONVOL, reg.index}); +} + +void UnwindBuilderWin::allocStack(int size) +{ + LUAU_ASSERT(size >= 8 && size <= 128 && size % 8 == 0); + + prologSize += 4; // REX.W sub rsp, imm8 + stackOffset += size; + unwindCodes.push_back({prologSize, UWOP_ALLOC_SMALL, uint8_t((size - 8) / 8)}); +} + +void UnwindBuilderWin::setupFrameReg(RegisterX64 reg, int espOffset) +{ + LUAU_ASSERT(espOffset < 256 && espOffset % 16 == 0); + + frameReg = reg; + frameRegOffset = uint8_t(espOffset / 16); + + prologSize += 5; // REX.W lea rbp, [rsp + imm8] + unwindCodes.push_back({prologSize, UWOP_SET_FPREG, frameRegOffset}); +} + +void UnwindBuilderWin::finish() +{ + // Windows unwind code count is stored in uint8_t, so we can't have more + LUAU_ASSERT(unwindCodes.size() < 256); + + LUAU_ASSERT(stackOffset % 16 == 0 && "stack has to be aligned to 16 bytes after prologue"); + + size_t codeArraySize = unwindCodes.size(); + codeArraySize = (codeArraySize + 1) & ~1; // Size has to be even, but unwind code count doesn't have to + + infoSize = sizeof(UnwindInfoWin) + sizeof(UnwindCodeWin) * codeArraySize; +} + +size_t UnwindBuilderWin::getSize() const +{ + return infoSize; +} + +void UnwindBuilderWin::finalize(char* target, void* funcAddress, size_t funcSize) const +{ + UnwindInfoWin info; + info.version = 1; + info.flags = 0; // No EH + info.prologsize = prologSize; + info.unwindcodecount = uint8_t(unwindCodes.size()); + info.framereg = frameReg.index; + info.frameregoff = frameRegOffset; + + memcpy(target, &info, sizeof(info)); + target += sizeof(UnwindInfoWin); + + if (!unwindCodes.empty()) + { + // Copy unwind codes in reverse order + // Some unwind codes take up two array slots, but we don't use those atm + char* pos = target + sizeof(UnwindCodeWin) * (unwindCodes.size() - 1); + + for (size_t i = 0; i < unwindCodes.size(); i++) + { + memcpy(pos, &unwindCodes[i], sizeof(UnwindCodeWin)); + pos -= sizeof(UnwindCodeWin); + } + } +} + +} // namespace CodeGen +} // namespace Luau diff --git a/Common/include/Luau/Bytecode.h b/Common/include/Luau/Bytecode.h index 8b6ccddff..f8652434b 100644 --- a/Common/include/Luau/Bytecode.h +++ b/Common/include/Luau/Bytecode.h @@ -411,7 +411,7 @@ enum LuauOpcode enum LuauBytecodeTag { // Bytecode version; runtime supports [MIN, MAX], compiler emits TARGET by default but may emit a higher version when flags are enabled - LBC_VERSION_MIN = 2, + LBC_VERSION_MIN = 3, LBC_VERSION_MAX = 3, LBC_VERSION_TARGET = 3, // Types of constant table entries diff --git a/Sources.cmake b/Sources.cmake index ff0b5a6ed..580c8b3cf 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -57,13 +57,20 @@ target_sources(Luau.Compiler PRIVATE target_sources(Luau.CodeGen PRIVATE CodeGen/include/Luau/AssemblyBuilderX64.h CodeGen/include/Luau/CodeAllocator.h + CodeGen/include/Luau/CodeBlockUnwind.h CodeGen/include/Luau/Condition.h CodeGen/include/Luau/Label.h CodeGen/include/Luau/OperandX64.h CodeGen/include/Luau/RegisterX64.h + CodeGen/include/Luau/UnwindBuilder.h + CodeGen/include/Luau/UnwindBuilderDwarf2.h + CodeGen/include/Luau/UnwindBuilderWin.h CodeGen/src/AssemblyBuilderX64.cpp CodeGen/src/CodeAllocator.cpp + CodeGen/src/CodeBlockUnwind.cpp + CodeGen/src/UnwindBuilderDwarf2.cpp + CodeGen/src/UnwindBuilderWin.cpp ) # Luau.Analysis Sources @@ -258,9 +265,13 @@ endif() if(TARGET Luau.UnitTest) # Luau.UnitTest Sources target_sources(Luau.UnitTest PRIVATE + tests/AstQueryDsl.h + tests/ConstraintGraphBuilderFixture.h tests/Fixture.h tests/IostreamOptional.h tests/ScopedFlags.h + tests/AstQueryDsl.cpp + tests/ConstraintGraphBuilderFixture.cpp tests/Fixture.cpp tests/AssemblyBuilderX64.test.cpp tests/AstJsonEncoder.test.cpp diff --git a/VM/include/lua.h b/VM/include/lua.h index 5ce40ae2a..cdd56e96d 100644 --- a/VM/include/lua.h +++ b/VM/include/lua.h @@ -210,6 +210,7 @@ LUA_API void lua_getfenv(lua_State* L, int idx); */ LUA_API void lua_settable(lua_State* L, int idx); LUA_API void lua_setfield(lua_State* L, int idx, const char* k); +LUA_API void lua_rawsetfield(lua_State* L, int idx, const char* k); LUA_API void lua_rawset(lua_State* L, int idx); LUA_API void lua_rawseti(lua_State* L, int idx, int n); LUA_API int lua_setmetatable(lua_State* L, int objindex); diff --git a/VM/src/lapi.cpp b/VM/src/lapi.cpp index 6a9c46dae..cbcaa3cc0 100644 --- a/VM/src/lapi.cpp +++ b/VM/src/lapi.cpp @@ -34,8 +34,6 @@ * therefore call luaC_checkGC before luaC_threadbarrier to guarantee the object is pushed to a gray thread. */ -LUAU_FASTFLAG(LuauSimplerUpval) - const char* lua_ident = "$Lua: Lua 5.1.4 Copyright (C) 1994-2008 Lua.org, PUC-Rio $\n" "$Authors: R. Ierusalimschy, L. H. de Figueiredo & W. Celes $\n" "$URL: www.lua.org $\n"; @@ -847,6 +845,19 @@ void lua_setfield(lua_State* L, int idx, const char* k) return; } +void lua_rawsetfield(lua_State* L, int idx, const char* k) +{ + api_checknelems(L, 1); + StkId t = index2addr(L, idx); + api_check(L, ttistable(t)); + if (hvalue(t)->readonly) + luaG_runerror(L, "Attempt to modify a readonly table"); + setobj2t(L, luaH_setstr(L, hvalue(t), luaS_new(L, k)), L->top - 1); + luaC_barriert(L, hvalue(t), L->top - 1); + L->top--; + return; +} + void lua_rawset(lua_State* L, int idx) { api_checknelems(L, 2); @@ -1285,8 +1296,6 @@ const char* lua_setupvalue(lua_State* L, int funcindex, int n) L->top--; setobj(L, val, L->top); luaC_barrier(L, clvalue(fi), L->top); - if (!FFlag::LuauSimplerUpval) - luaC_upvalbarrier(L, cast_to(UpVal*, NULL), val); } return name; } diff --git a/VM/src/ldo.cpp b/VM/src/ldo.cpp index 51f63d32c..ecd2fcbbf 100644 --- a/VM/src/ldo.cpp +++ b/VM/src/ldo.cpp @@ -243,14 +243,14 @@ void luaD_call(lua_State* L, StkId func, int nResults) { // is a Lua function? L->ci->flags |= LUA_CALLINFO_RETURN; // luau_execute will stop after returning from the stack frame - int oldactive = luaC_threadactive(L); - l_setbit(L->stackstate, THREAD_ACTIVEBIT); + bool oldactive = L->isactive; + L->isactive = true; luaC_threadbarrier(L); luau_execute(L); // call it if (!oldactive) - resetbit(L->stackstate, THREAD_ACTIVEBIT); + L->isactive = false; } L->nCcalls--; @@ -427,7 +427,7 @@ static int resume_error(lua_State* L, const char* msg) static void resume_finish(lua_State* L, int status) { L->nCcalls = L->baseCcalls; - resetbit(L->stackstate, THREAD_ACTIVEBIT); + L->isactive = false; if (status != 0) { // error? @@ -452,7 +452,7 @@ int lua_resume(lua_State* L, lua_State* from, int nargs) return resume_error(L, "C stack overflow"); L->baseCcalls = ++L->nCcalls; - l_setbit(L->stackstate, THREAD_ACTIVEBIT); + L->isactive = true; luaC_threadbarrier(L); @@ -481,7 +481,7 @@ int lua_resumeerror(lua_State* L, lua_State* from) return resume_error(L, "C stack overflow"); L->baseCcalls = ++L->nCcalls; - l_setbit(L->stackstate, THREAD_ACTIVEBIT); + L->isactive = true; luaC_threadbarrier(L); @@ -546,7 +546,7 @@ int luaD_pcall(lua_State* L, Pfunc func, void* u, ptrdiff_t old_top, ptrdiff_t e { unsigned short oldnCcalls = L->nCcalls; ptrdiff_t old_ci = saveci(L, L->ci); - int oldactive = luaC_threadactive(L); + bool oldactive = L->isactive; int status = luaD_rawrunprotected(L, func, u); if (status != 0) { @@ -560,7 +560,7 @@ int luaD_pcall(lua_State* L, Pfunc func, void* u, ptrdiff_t old_top, ptrdiff_t e // since the call failed with an error, we might have to reset the 'active' thread state if (!oldactive) - resetbit(L->stackstate, THREAD_ACTIVEBIT); + L->isactive = false; // Restore nCcalls before calling the debugprotectederror callback which may rely on the proper value to have been restored. L->nCcalls = oldnCcalls; diff --git a/VM/src/lfunc.cpp b/VM/src/lfunc.cpp index 8c78083b2..3c1869b5b 100644 --- a/VM/src/lfunc.cpp +++ b/VM/src/lfunc.cpp @@ -6,9 +6,6 @@ #include "lmem.h" #include "lgc.h" -LUAU_FASTFLAG(LuauSimplerUpval) -LUAU_FASTFLAG(LuauNoSleepBit) - Proto* luaF_newproto(lua_State* L) { Proto* f = luaM_newgco(L, Proto, sizeof(Proto), L->activememcat); @@ -74,21 +71,16 @@ UpVal* luaF_findupval(lua_State* L, StkId level) UpVal* p; while (*pp != NULL && (p = *pp)->v >= level) { - LUAU_ASSERT(!FFlag::LuauSimplerUpval || !isdead(g, obj2gco(p))); + LUAU_ASSERT(!isdead(g, obj2gco(p))); LUAU_ASSERT(upisopen(p)); if (p->v == level) - { // found a corresponding upvalue? - if (!FFlag::LuauSimplerUpval && isdead(g, obj2gco(p))) // is it dead? - changewhite(obj2gco(p)); // resurrect it return p; - } pp = &p->u.open.threadnext; } - LUAU_ASSERT(luaC_threadactive(L)); - LUAU_ASSERT(!luaC_threadsleeping(L)); - LUAU_ASSERT(!FFlag::LuauNoSleepBit || !isblack(obj2gco(L))); // we don't use luaC_threadbarrier because active threads never turn black + LUAU_ASSERT(L->isactive); + LUAU_ASSERT(!isblack(obj2gco(L))); // we don't use luaC_threadbarrier because active threads never turn black UpVal* uv = luaM_newgco(L, UpVal, sizeof(UpVal), L->activememcat); // not found: create a new one luaC_init(L, uv, LUA_TUPVAL); @@ -96,22 +88,8 @@ UpVal* luaF_findupval(lua_State* L, StkId level) uv->v = level; // current value lives in the stack // chain the upvalue in the threads open upvalue list at the proper position - if (FFlag::LuauSimplerUpval) - { - uv->u.open.threadnext = *pp; - *pp = uv; - } - else - { - UpVal* next = *pp; - uv->u.open.threadnext = next; - - uv->u.open.threadprev = pp; - if (next) - next->u.open.threadprev = &uv->u.open.threadnext; - - *pp = uv; - } + uv->u.open.threadnext = *pp; + *pp = uv; // double link the upvalue in the global open upvalue list uv->u.open.prev = &g->uvhead; @@ -123,26 +101,8 @@ UpVal* luaF_findupval(lua_State* L, StkId level) return uv; } -void luaF_unlinkupval(UpVal* uv) -{ - LUAU_ASSERT(!FFlag::LuauSimplerUpval); - - // unlink upvalue from the global open upvalue list - LUAU_ASSERT(uv->u.open.next->u.open.prev == uv && uv->u.open.prev->u.open.next == uv); - uv->u.open.next->u.open.prev = uv->u.open.prev; - uv->u.open.prev->u.open.next = uv->u.open.next; - - // unlink upvalue from the thread open upvalue list - *uv->u.open.threadprev = uv->u.open.threadnext; - - if (UpVal* next = uv->u.open.threadnext) - next->u.open.threadprev = uv->u.open.threadprev; -} - void luaF_freeupval(lua_State* L, UpVal* uv, lua_Page* page) { - if (!FFlag::LuauSimplerUpval && uv->v != &uv->u.value) // is it open? - luaF_unlinkupval(uv); // remove from open list luaM_freegco(L, uv, sizeof(UpVal), uv->memcat, page); // free upvalue } @@ -154,41 +114,17 @@ void luaF_close(lua_State* L, StkId level) { GCObject* o = obj2gco(uv); LUAU_ASSERT(!isblack(o) && upisopen(uv)); + LUAU_ASSERT(!isdead(g, o)); - if (FFlag::LuauSimplerUpval) - { - LUAU_ASSERT(!isdead(g, o)); - - // unlink value *before* closing it since value storage overlaps - L->openupval = uv->u.open.threadnext; + // unlink value *before* closing it since value storage overlaps + L->openupval = uv->u.open.threadnext; - luaF_closeupval(L, uv, /* dead= */ false); - } - else - { - // by removing the upvalue from global/thread open upvalue lists, L->openupval will be pointing to the next upvalue - luaF_unlinkupval(uv); - - if (isdead(g, o)) - { - // close the upvalue without copying the dead data so that luaF_freeupval will not unlink again - uv->v = &uv->u.value; - } - else - { - setobj(L, &uv->u.value, uv->v); - uv->v = &uv->u.value; - // GC state of a new closed upvalue has to be initialized - luaC_upvalclosed(L, uv); - } - } + luaF_closeupval(L, uv, /* dead= */ false); } } void luaF_closeupval(lua_State* L, UpVal* uv, bool dead) { - LUAU_ASSERT(FFlag::LuauSimplerUpval); - // unlink value from all lists *before* closing it since value storage overlaps LUAU_ASSERT(uv->u.open.next->u.open.prev == uv && uv->u.open.prev->u.open.next == uv); uv->u.open.next->u.open.prev = uv->u.open.prev; diff --git a/VM/src/lfunc.h b/VM/src/lfunc.h index 899d040b5..679836e7e 100644 --- a/VM/src/lfunc.h +++ b/VM/src/lfunc.h @@ -15,7 +15,6 @@ LUAI_FUNC void luaF_close(lua_State* L, StkId level); LUAI_FUNC void luaF_closeupval(lua_State* L, UpVal* uv, bool dead); LUAI_FUNC void luaF_freeproto(lua_State* L, Proto* f, struct lua_Page* page); LUAI_FUNC void luaF_freeclosure(lua_State* L, Closure* c, struct lua_Page* page); -LUAI_FUNC void luaF_unlinkupval(UpVal* uv); LUAI_FUNC void luaF_freeupval(lua_State* L, UpVal* uv, struct lua_Page* page); LUAI_FUNC const LocVar* luaF_getlocal(const Proto* func, int local_number, int pc); LUAI_FUNC const LocVar* luaF_findlocal(const Proto* func, int local_reg, int pc); diff --git a/VM/src/lgc.cpp b/VM/src/lgc.cpp index fb610e130..c2a672e65 100644 --- a/VM/src/lgc.cpp +++ b/VM/src/lgc.cpp @@ -13,6 +13,8 @@ #include +LUAU_FASTFLAGVARIABLE(LuauBetterThreadMark, false) + /* * Luau uses an incremental non-generational non-moving mark&sweep garbage collector. * @@ -108,21 +110,14 @@ * API calls that can write to thread stacks outside of execution (which implies active) uses a thread barrier that checks if the thread is * black, and if it is it marks it as gray and puts it on a gray list to be rescanned during atomic phase. * - * NOTE: The above is only true when LuauNoSleepBit is enabled. - * * Upvalues are special objects that can be closed, in which case they contain the value (acting as a reference cell) and can be dealt * with using the regular algorithm, or open, in which case they refer to a stack slot in some other thread. These are difficult to deal * with because the stack writes are not monitored. Because of this open upvalues are treated in a somewhat special way: they are never marked * as black (doing so would violate the GC invariant), and they are kept in a special global list (global_State::uvhead) which is traversed * during atomic phase. This is needed because an open upvalue might point to a stack location in a dead thread that never marked the stack * slot - upvalues like this are identified since they don't have `markedopen` bit set during thread traversal and closed in `clearupvals`. - * - * NOTE: The above is only true when LuauSimplerUpval is enabled. */ -LUAU_FASTFLAGVARIABLE(LuauSimplerUpval, false) -LUAU_FASTFLAGVARIABLE(LuauNoSleepBit, false) -LUAU_FASTFLAGVARIABLE(LuauEagerShrink, false) LUAU_FASTFLAGVARIABLE(LuauFasterSweep, false) #define GC_SWEEPPAGESTEPCOST 16 @@ -408,14 +403,11 @@ static void traversestack(global_State* g, lua_State* l) stringmark(l->namecall); for (StkId o = l->stack; o < l->top; o++) markvalue(g, o); - if (FFlag::LuauSimplerUpval) + for (UpVal* uv = l->openupval; uv; uv = uv->u.open.threadnext) { - for (UpVal* uv = l->openupval; uv; uv = uv->u.open.threadnext) - { - LUAU_ASSERT(upisopen(uv)); - uv->markedopen = 1; - markobject(g, uv); - } + LUAU_ASSERT(upisopen(uv)); + uv->markedopen = 1; + markobject(g, uv); } } @@ -426,8 +418,29 @@ static void clearstack(lua_State* l) setnilvalue(o); } -// TODO: pull function definition here when FFlag::LuauEagerShrink is removed -static void shrinkstack(lua_State* L); +static void shrinkstack(lua_State* L) +{ + // compute used stack - note that we can't use th->top if we're in the middle of vararg call + StkId lim = L->top; + for (CallInfo* ci = L->base_ci; ci <= L->ci; ci++) + { + LUAU_ASSERT(ci->top <= L->stack_last); + if (lim < ci->top) + lim = ci->top; + } + + // shrink stack and callinfo arrays if we aren't using most of the space + int ci_used = cast_int(L->ci - L->base_ci); // number of `ci' in use + int s_used = cast_int(lim - L->stack); // part of stack in use + if (L->size_ci > LUAI_MAXCALLS) // handling overflow? + return; // do not touch the stacks + if (3 * ci_used < L->size_ci && 2 * BASIC_CI_SIZE < L->size_ci) + luaD_reallocCI(L, L->size_ci / 2); // still big enough... + condhardstacktests(luaD_reallocCI(L, ci_used + 1)); + if (3 * s_used < L->stacksize && 2 * (BASIC_STACK_SIZE + EXTRA_STACK) < L->stacksize) + luaD_reallocstack(L, L->stacksize / 2); // still big enough... + condhardstacktests(luaD_reallocstack(L, s_used)); +} /* ** traverse one gray object, turning it to black. @@ -460,37 +473,56 @@ static size_t propagatemark(global_State* g) lua_State* th = gco2th(o); g->gray = th->gclist; - LUAU_ASSERT(!luaC_threadsleeping(th)); + bool active = th->isactive || th == th->global->mainthread; - // threads that are executing and the main thread remain gray - bool active = luaC_threadactive(th) || th == th->global->mainthread; - - // TODO: Refactor this logic after LuauNoSleepBit is removed - if (!active && g->gcstate == GCSpropagate) + if (FFlag::LuauBetterThreadMark) { traversestack(g, th); - clearstack(th); - if (!FFlag::LuauNoSleepBit) - l_setbit(th->stackstate, THREAD_SLEEPINGBIT); + // active threads will need to be rescanned later to mark new stack writes so we mark them gray again + if (active) + { + th->gclist = g->grayagain; + g->grayagain = o; + + black2gray(o); + } + + // the stack needs to be cleared after the last modification of the thread state before sweep begins + // if the thread is inactive, we might not see the thread in this cycle so we must clear it now + if (!active || g->gcstate == GCSatomic) + clearstack(th); + + // we could shrink stack at any time but we opt to do it during initial mark to do that just once per cycle + if (g->gcstate == GCSpropagate) + shrinkstack(th); } else { - th->gclist = g->grayagain; - g->grayagain = o; + // TODO: Refactor this logic! + if (!active && g->gcstate == GCSpropagate) + { + traversestack(g, th); + clearstack(th); + } + else + { + th->gclist = g->grayagain; + g->grayagain = o; - black2gray(o); + black2gray(o); - traversestack(g, th); + traversestack(g, th); - // final traversal? - if (g->gcstate == GCSatomic) - clearstack(th); - } + // final traversal? + if (g->gcstate == GCSatomic) + clearstack(th); + } - // we could shrink stack at any time but we opt to skip it during atomic since it's redundant to do that more than once per cycle - if (FFlag::LuauEagerShrink && g->gcstate != GCSatomic) - shrinkstack(th); + // we could shrink stack at any time but we opt to skip it during atomic since it's redundant to do that more than once per cycle + if (g->gcstate != GCSatomic) + shrinkstack(th); + } return sizeof(lua_State) + sizeof(TValue) * th->stacksize + sizeof(CallInfo) * th->size_ci; } @@ -593,30 +625,6 @@ static size_t cleartable(lua_State* L, GCObject* l) return work; } -static void shrinkstack(lua_State* L) -{ - // compute used stack - note that we can't use th->top if we're in the middle of vararg call - StkId lim = L->top; - for (CallInfo* ci = L->base_ci; ci <= L->ci; ci++) - { - LUAU_ASSERT(ci->top <= L->stack_last); - if (lim < ci->top) - lim = ci->top; - } - - // shrink stack and callinfo arrays if we aren't using most of the space - int ci_used = cast_int(L->ci - L->base_ci); // number of `ci' in use - int s_used = cast_int(lim - L->stack); // part of stack in use - if (L->size_ci > LUAI_MAXCALLS) // handling overflow? - return; // do not touch the stacks - if (3 * ci_used < L->size_ci && 2 * BASIC_CI_SIZE < L->size_ci) - luaD_reallocCI(L, L->size_ci / 2); // still big enough... - condhardstacktests(luaD_reallocCI(L, ci_used + 1)); - if (3 * s_used < L->stacksize && 2 * (BASIC_STACK_SIZE + EXTRA_STACK) < L->stacksize) - luaD_reallocstack(L, L->stacksize / 2); // still big enough... - condhardstacktests(luaD_reallocstack(L, s_used)); -} - static void freeobj(lua_State* L, GCObject* o, lua_Page* page) { switch (o->gch.tt) @@ -669,21 +677,6 @@ static void shrinkbuffersfull(lua_State* L) static bool deletegco(void* context, lua_Page* page, GCObject* gco) { - // we are in the process of deleting everything - // threads with open upvalues will attempt to close them all on removal - // but those upvalues might point to stack values that were already deleted - if (!FFlag::LuauSimplerUpval && gco->gch.tt == LUA_TTHREAD) - { - lua_State* th = gco2th(gco); - - while (UpVal* uv = th->openupval) - { - luaF_unlinkupval(uv); - // close the upvalue without copying the dead data so that luaF_freeupval will not unlink again - uv->v = &uv->u.value; - } - } - lua_State* L = (lua_State*)context; freeobj(L, gco, page); return true; @@ -701,7 +694,6 @@ void luaC_freeall(lua_State* L) LUAU_ASSERT(g->strt.hash[i] == NULL); LUAU_ASSERT(L->global->strt.nuse == 0); - LUAU_ASSERT(g->strbufgc == NULL); } static void markmt(global_State* g) @@ -829,15 +821,12 @@ static size_t atomic(lua_State* L) g->gcmetrics.currcycle.atomictimeclear += recordGcDeltaTime(currts); #endif - if (FFlag::LuauSimplerUpval) - { - // close orphaned live upvalues of dead threads and clear dead upvalues - work += clearupvals(L); + // close orphaned live upvalues of dead threads and clear dead upvalues + work += clearupvals(L); #ifdef LUAI_GCMETRICS - g->gcmetrics.currcycle.atomictimeupval += recordGcDeltaTime(currts); + g->gcmetrics.currcycle.atomictimeupval += recordGcDeltaTime(currts); #endif - } // flip current white g->currentwhite = cast_byte(otherwhite(g)); @@ -857,20 +846,6 @@ static bool sweepgco(lua_State* L, lua_Page* page, GCObject* gco) int alive = (gco->gch.marked ^ WHITEBITS) & deadmask; - if (gco->gch.tt == LUA_TTHREAD) - { - lua_State* th = gco2th(gco); - - if (alive) - { - if (!FFlag::LuauNoSleepBit) - resetbit(th->stackstate, THREAD_SLEEPINGBIT); - - if (!FFlag::LuauEagerShrink) - shrinkstack(th); - } - } - if (alive) { LUAU_ASSERT(!isdead(g, gco)); @@ -896,8 +871,6 @@ static int sweepgcopage(lua_State* L, lua_Page* page) if (FFlag::LuauFasterSweep) { - LUAU_ASSERT(FFlag::LuauNoSleepBit && FFlag::LuauEagerShrink); - global_State* g = L->global; int deadmask = otherwhite(g); @@ -1183,7 +1156,7 @@ void luaC_fullgc(lua_State* L) startGcCycleMetrics(g); #endif - if (FFlag::LuauSimplerUpval ? keepinvariant(g) : g->gcstate <= GCSatomic) + if (keepinvariant(g)) { // reset sweep marks to sweep all elements (returning them to white) g->sweepgcopage = g->allgcopages; @@ -1201,14 +1174,11 @@ void luaC_fullgc(lua_State* L) gcstep(L, SIZE_MAX); } - if (FFlag::LuauSimplerUpval) + // clear markedopen bits for all open upvalues; these might be stuck from half-finished mark prior to full gc + for (UpVal* uv = g->uvhead.u.open.next; uv != &g->uvhead; uv = uv->u.open.next) { - // clear markedopen bits for all open upvalues; these might be stuck from half-finished mark prior to full gc - for (UpVal* uv = g->uvhead.u.open.next; uv != &g->uvhead; uv = uv->u.open.next) - { - LUAU_ASSERT(upisopen(uv)); - uv->markedopen = 0; - } + LUAU_ASSERT(upisopen(uv)); + uv->markedopen = 0; } #ifdef LUAI_GCMETRICS @@ -1245,16 +1215,6 @@ void luaC_fullgc(lua_State* L) #endif } -void luaC_barrierupval(lua_State* L, GCObject* v) -{ - LUAU_ASSERT(!FFlag::LuauSimplerUpval); - global_State* g = L->global; - LUAU_ASSERT(iswhite(v) && !isdead(g, v)); - - if (keepinvariant(g)) - reallymarkobject(g, v); -} - void luaC_barrierf(lua_State* L, GCObject* o, GCObject* v) { global_State* g = L->global; @@ -1346,29 +1306,6 @@ int64_t luaC_allocationrate(lua_State* L) return int64_t((g->gcstats.atomicstarttotalsizebytes - g->gcstats.endtotalsizebytes) / duration); } -void luaC_wakethread(lua_State* L) -{ - LUAU_ASSERT(!FFlag::LuauNoSleepBit); - if (!luaC_threadsleeping(L)) - return; - - global_State* g = L->global; - - resetbit(L->stackstate, THREAD_SLEEPINGBIT); - - if (keepinvariant(g)) - { - GCObject* o = obj2gco(L); - - LUAU_ASSERT(isblack(o)); - - L->gclist = g->grayagain; - g->grayagain = o; - - black2gray(o); - } -} - const char* luaC_statename(int state) { switch (state) diff --git a/VM/src/lgc.h b/VM/src/lgc.h index 69379c89e..51216bd8e 100644 --- a/VM/src/lgc.h +++ b/VM/src/lgc.h @@ -6,8 +6,6 @@ #include "lobject.h" #include "lstate.h" -LUAU_FASTFLAG(LuauNoSleepBit) - /* ** Default settings for GC tunables (settable via lua_gc) */ @@ -75,14 +73,6 @@ LUAU_FASTFLAG(LuauNoSleepBit) #define luaC_white(g) cast_to(uint8_t, ((g)->currentwhite) & WHITEBITS) -// Thread stack states -// TODO: Remove with FFlag::LuauNoSleepBit and replace with lua_State::threadactive -#define THREAD_ACTIVEBIT 0 // thread is currently active -#define THREAD_SLEEPINGBIT 1 // thread is not executing and stack should not be modified - -#define luaC_threadactive(L) (testbit((L)->stackstate, THREAD_ACTIVEBIT)) -#define luaC_threadsleeping(L) (testbit((L)->stackstate, THREAD_SLEEPINGBIT)) - #define luaC_checkGC(L) \ { \ condhardstacktests(luaD_reallocstack(L, L->stacksize - EXTRA_STACK)); \ @@ -121,25 +111,10 @@ LUAU_FASTFLAG(LuauNoSleepBit) luaC_barrierf(L, obj2gco(p), obj2gco(o)); \ } -// TODO: Remove with FFlag::LuauSimplerUpval -#define luaC_upvalbarrier(L, uv, tv) \ - { \ - if (iscollectable(tv) && iswhite(gcvalue(tv)) && (!(uv) || (uv)->v != &(uv)->u.value)) \ - luaC_barrierupval(L, gcvalue(tv)); \ - } - #define luaC_threadbarrier(L) \ { \ - if (FFlag::LuauNoSleepBit) \ - { \ - if (isblack(obj2gco(L))) \ - luaC_barrierback(L, obj2gco(L), &L->gclist); \ - } \ - else \ - { \ - if (luaC_threadsleeping(L)) \ - luaC_wakethread(L); \ - } \ + if (isblack(obj2gco(L))) \ + luaC_barrierback(L, obj2gco(L), &L->gclist); \ } #define luaC_init(L, o, tt_) \ @@ -154,12 +129,10 @@ LUAI_FUNC size_t luaC_step(lua_State* L, bool assist); LUAI_FUNC void luaC_fullgc(lua_State* L); LUAI_FUNC void luaC_initobj(lua_State* L, GCObject* o, uint8_t tt); LUAI_FUNC void luaC_upvalclosed(lua_State* L, UpVal* uv); -LUAI_FUNC void luaC_barrierupval(lua_State* L, GCObject* v); LUAI_FUNC void luaC_barrierf(lua_State* L, GCObject* o, GCObject* v); LUAI_FUNC void luaC_barriertable(lua_State* L, Table* t, GCObject* v); LUAI_FUNC void luaC_barrierback(lua_State* L, GCObject* o, GCObject** gclist); LUAI_FUNC void luaC_validate(lua_State* L); LUAI_FUNC void luaC_dump(lua_State* L, void* file, const char* (*categoryName)(lua_State* L, uint8_t memcat)); LUAI_FUNC int64_t luaC_allocationrate(lua_State* L); -LUAI_FUNC void luaC_wakethread(lua_State* L); LUAI_FUNC const char* luaC_statename(int state); diff --git a/VM/src/lobject.h b/VM/src/lobject.h index 778e22ba8..48aaf94b7 100644 --- a/VM/src/lobject.h +++ b/VM/src/lobject.h @@ -332,8 +332,6 @@ typedef struct UpVal // thread linked list (when open) struct UpVal* threadnext; - // note: this is the location of a pointer to this upvalue in the previous element that can be either an UpVal or a lua_State - struct UpVal** threadprev; // TODO: remove with FFlag::LuauSimplerUpval } open; } u; } UpVal; diff --git a/VM/src/lstate.cpp b/VM/src/lstate.cpp index e1cb2ab7a..fdd7fc2b8 100644 --- a/VM/src/lstate.cpp +++ b/VM/src/lstate.cpp @@ -10,8 +10,6 @@ #include "ldo.h" #include "ldebug.h" -LUAU_FASTFLAG(LuauSimplerUpval) - /* ** Main thread combines a thread state and the global state */ @@ -79,7 +77,7 @@ static void preinit_state(lua_State* L, global_State* g) L->namecall = NULL; L->cachedslot = 0; L->singlestep = false; - L->stackstate = 0; + L->isactive = false; L->activememcat = 0; L->userdata = NULL; } @@ -89,7 +87,6 @@ static void close_state(lua_State* L) global_State* g = L->global; luaF_close(L, L->stack); // close all upvalues for this thread luaC_freeall(L); // collect all objects - LUAU_ASSERT(g->strbufgc == NULL); LUAU_ASSERT(g->strt.nuse == 0); luaM_freearray(L, L->global->strt.hash, L->global->strt.size, TString*, 0); freestack(L, L); @@ -121,11 +118,6 @@ lua_State* luaE_newthread(lua_State* L) void luaE_freethread(lua_State* L, lua_State* L1, lua_Page* page) { - if (!FFlag::LuauSimplerUpval) - { - luaF_close(L1, L1->stack); // close all upvalues for this thread - LUAU_ASSERT(L1->openupval == NULL); - } global_State* g = L->global; if (g->cb.userthread) g->cb.userthread(NULL, L1); @@ -199,7 +191,6 @@ lua_State* lua_newstate(lua_Alloc f, void* ud) g->gray = NULL; g->grayagain = NULL; g->weak = NULL; - g->strbufgc = NULL; g->totalbytes = sizeof(LG); g->gcgoal = LUAI_GCGOAL; g->gcstepmul = LUAI_GCSTEPMUL; diff --git a/VM/src/lstate.h b/VM/src/lstate.h index df47ce7e0..06544463a 100644 --- a/VM/src/lstate.h +++ b/VM/src/lstate.h @@ -167,8 +167,6 @@ typedef struct global_State GCObject* grayagain; // list of objects to be traversed atomically GCObject* weak; // list of weak tables (to be cleared) - TString* strbufgc; // list of all string buffer objects; TODO: remove with LuauNoStrbufLink - size_t GCthreshold; // when totalbytes > GCthreshold, run GC step size_t totalbytes; // number of bytes currently allocated @@ -222,8 +220,8 @@ struct lua_State uint8_t status; uint8_t activememcat; // memory category that is used for new GC object allocations - uint8_t stackstate; + bool isactive; // thread is currently executing, stack may be mutated without barriers bool singlestep; // call debugstep hook after each instruction diff --git a/VM/src/lstring.cpp b/VM/src/lstring.cpp index f43d03b19..e57f6c29e 100644 --- a/VM/src/lstring.cpp +++ b/VM/src/lstring.cpp @@ -7,8 +7,6 @@ #include -LUAU_FASTFLAGVARIABLE(LuauNoStrbufLink, false) - unsigned int luaS_hash(const char* str, size_t len) { // Note that this hashing algorithm is replicated in BytecodeBuilder.cpp, BytecodeBuilder::getStringHash @@ -96,51 +94,18 @@ static TString* newlstr(lua_State* L, const char* str, size_t l, unsigned int h) return ts; } -static void unlinkstrbuf(lua_State* L, TString* ts) -{ - LUAU_ASSERT(!FFlag::LuauNoStrbufLink); - global_State* g = L->global; - - TString** p = &g->strbufgc; - - while (TString* curr = *p) - { - if (curr == ts) - { - *p = curr->next; - return; - } - else - { - p = &curr->next; - } - } - - LUAU_ASSERT(!"failed to find string buffer"); -} - TString* luaS_bufstart(lua_State* L, size_t size) { if (size > MAXSSIZE) luaM_toobig(L); - global_State* g = L->global; - TString* ts = luaM_newgco(L, TString, sizestring(size), L->activememcat); luaC_init(L, ts, LUA_TSTRING); ts->atom = ATOM_UNDEF; ts->hash = 0; // computed in luaS_buffinish ts->len = unsigned(size); - if (FFlag::LuauNoStrbufLink) - { - ts->next = NULL; - } - else - { - ts->next = g->strbufgc; - g->strbufgc = ts; - } + ts->next = NULL; return ts; } @@ -164,10 +129,7 @@ TString* luaS_buffinish(lua_State* L, TString* ts) } } - if (FFlag::LuauNoStrbufLink) - LUAU_ASSERT(ts->next == NULL); - else - unlinkstrbuf(L, ts); + LUAU_ASSERT(ts->next == NULL); ts->hash = h; ts->data[ts->len] = '\0'; // ending 0 @@ -222,21 +184,10 @@ static bool unlinkstr(lua_State* L, TString* ts) void luaS_free(lua_State* L, TString* ts, lua_Page* page) { - if (FFlag::LuauNoStrbufLink) - { - if (unlinkstr(L, ts)) - L->global->strt.nuse--; - else - LUAU_ASSERT(ts->next == NULL); // orphaned string buffer - } + if (unlinkstr(L, ts)) + L->global->strt.nuse--; else - { - // Unchain from the string table - if (!unlinkstr(L, ts)) - unlinkstrbuf(L, ts); // An unlikely scenario when we have a string buffer on our hands - else - L->global->strt.nuse--; - } + LUAU_ASSERT(ts->next == NULL); // orphaned string buffer luaM_freegco(L, ts, sizestring(ts->len), ts->memcat, page); } diff --git a/VM/src/lvmexecute.cpp b/VM/src/lvmexecute.cpp index e34889160..c3c744b2e 100644 --- a/VM/src/lvmexecute.cpp +++ b/VM/src/lvmexecute.cpp @@ -16,9 +16,6 @@ #include -LUAU_FASTFLAG(LuauSimplerUpval) -LUAU_FASTFLAG(LuauNoSleepBit) - // Disable c99-designator to avoid the warning in CGOTO dispatch table #ifdef __clang__ #if __has_warning("-Wc99-designator") @@ -69,11 +66,6 @@ LUAU_FASTFLAG(LuauNoSleepBit) #define VM_PATCH_C(pc, slot) *const_cast(pc) = ((uint8_t(slot) << 24) | (0x00ffffffu & *(pc))) #define VM_PATCH_E(pc, slot) *const_cast(pc) = ((uint32_t(slot) << 8) | (0x000000ffu & *(pc))) -// NOTE: If debugging the Luau code, disable this macro to prevent timeouts from -// occurring when tracing code in Visual Studio / XCode -#if 0 -#define VM_INTERRUPT() -#else #define VM_INTERRUPT() \ { \ void (*interrupt)(lua_State*, int) = L->global->cb.interrupt; \ @@ -87,7 +79,6 @@ LUAU_FASTFLAG(LuauNoSleepBit) } \ } \ } -#endif #define VM_DISPATCH_OP(op) &&CASE_##op @@ -150,32 +141,6 @@ LUAU_NOINLINE static void luau_prepareFORN(lua_State* L, StkId plimit, StkId pst luaG_forerror(L, pstep, "step"); } -LUAU_NOINLINE static bool luau_loopFORG(lua_State* L, int a, int c) -{ - // note: it's safe to push arguments past top for complicated reasons (see top of the file) - StkId ra = &L->base[a]; - LUAU_ASSERT(ra + 3 <= L->top); - - setobjs2s(L, ra + 3 + 2, ra + 2); - setobjs2s(L, ra + 3 + 1, ra + 1); - setobjs2s(L, ra + 3, ra); - - L->top = ra + 3 + 3; // func. + 2 args (state and index) - LUAU_ASSERT(L->top <= L->stack_last); - - luaD_call(L, ra + 3, c); - L->top = L->ci->top; - - // recompute ra since stack might have been reallocated - ra = &L->base[a]; - LUAU_ASSERT(ra < L->top); - - // copy first variable back into the iteration index - setobjs2s(L, ra + 2, ra + 3); - - return ttisnil(ra + 2); -} - // calls a C function f with no yielding support; optionally save one resulting value to the res register // the function and arguments have to already be pushed to L->top LUAU_NOINLINE static void luau_callTM(lua_State* L, int nparams, int res) @@ -316,9 +281,8 @@ static void luau_execute(lua_State* L) const Instruction* pc; LUAU_ASSERT(isLua(L->ci)); - LUAU_ASSERT(luaC_threadactive(L)); - LUAU_ASSERT(!luaC_threadsleeping(L)); - LUAU_ASSERT(!FFlag::LuauNoSleepBit || !isblack(obj2gco(L))); // we don't use luaC_threadbarrier because active threads never turn black + LUAU_ASSERT(L->isactive); + LUAU_ASSERT(!isblack(obj2gco(L))); // we don't use luaC_threadbarrier because active threads never turn black pc = L->ci->savedpc; cl = clvalue(L->ci->func); @@ -498,8 +462,6 @@ static void luau_execute(lua_State* L) setobj(L, uv->v, ra); luaC_barrier(L, uv, ra); - if (!FFlag::LuauSimplerUpval) - luaC_upvalbarrier(L, uv, uv->v); VM_NEXT(); } @@ -2403,52 +2365,8 @@ static void luau_execute(lua_State* L) VM_CASE(LOP_DEP_FORGLOOP_INEXT) { - VM_INTERRUPT(); - Instruction insn = *pc++; - StkId ra = VM_REG(LUAU_INSN_A(insn)); - - // fast-path: ipairs/inext - if (ttisnil(ra) && ttistable(ra + 1) && ttislightuserdata(ra + 2)) - { - Table* h = hvalue(ra + 1); - int index = int(reinterpret_cast(pvalue(ra + 2))); - - // if 1-based index of the last iteration is in bounds, this means 0-based index of the current iteration is in bounds - if (unsigned(index) < unsigned(h->sizearray)) - { - // note that nil elements inside the array terminate the traversal - if (!ttisnil(&h->array[index])) - { - setpvalue(ra + 2, reinterpret_cast(uintptr_t(index + 1))); - setnvalue(ra + 3, double(index + 1)); - setobj2s(L, ra + 4, &h->array[index]); - - pc += LUAU_INSN_D(insn); - LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); - VM_NEXT(); - } - else - { - // fallthrough to exit - VM_NEXT(); - } - } - else - { - // fallthrough to exit - VM_NEXT(); - } - } - else - { - // slow-path; can call Lua/C generators - bool stop; - VM_PROTECT(stop = luau_loopFORG(L, LUAU_INSN_A(insn), 2)); - - pc += stop ? 0 : LUAU_INSN_D(insn); - LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); - VM_NEXT(); - } + LUAU_ASSERT(!"Unsupported deprecated opcode"); + LUAU_UNREACHABLE(); } VM_CASE(LOP_FORGPREP_NEXT) @@ -2475,68 +2393,8 @@ static void luau_execute(lua_State* L) VM_CASE(LOP_DEP_FORGLOOP_NEXT) { - VM_INTERRUPT(); - Instruction insn = *pc++; - StkId ra = VM_REG(LUAU_INSN_A(insn)); - - // fast-path: pairs/next - if (ttisnil(ra) && ttistable(ra + 1) && ttislightuserdata(ra + 2)) - { - Table* h = hvalue(ra + 1); - int index = int(reinterpret_cast(pvalue(ra + 2))); - - int sizearray = h->sizearray; - int sizenode = 1 << h->lsizenode; - - // first we advance index through the array portion - while (unsigned(index) < unsigned(sizearray)) - { - if (!ttisnil(&h->array[index])) - { - setpvalue(ra + 2, reinterpret_cast(uintptr_t(index + 1))); - setnvalue(ra + 3, double(index + 1)); - setobj2s(L, ra + 4, &h->array[index]); - - pc += LUAU_INSN_D(insn); - LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); - VM_NEXT(); - } - - index++; - } - - // then we advance index through the hash portion - while (unsigned(index - sizearray) < unsigned(sizenode)) - { - LuaNode* n = &h->node[index - sizearray]; - - if (!ttisnil(gval(n))) - { - setpvalue(ra + 2, reinterpret_cast(uintptr_t(index + 1))); - getnodekey(L, ra + 3, n); - setobj2s(L, ra + 4, gval(n)); - - pc += LUAU_INSN_D(insn); - LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); - VM_NEXT(); - } - - index++; - } - - // fallthrough to exit - VM_NEXT(); - } - else - { - // slow-path; can call Lua/C generators - bool stop; - VM_PROTECT(stop = luau_loopFORG(L, LUAU_INSN_A(insn), 2)); - - pc += stop ? 0 : LUAU_INSN_D(insn); - LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); - VM_NEXT(); - } + LUAU_ASSERT(!"Unsupported deprecated opcode"); + LUAU_UNREACHABLE(); } VM_CASE(LOP_GETVARARGS) @@ -2750,92 +2608,14 @@ static void luau_execute(lua_State* L) VM_CASE(LOP_DEP_JUMPIFEQK) { - Instruction insn = *pc++; - uint32_t aux = *pc; - StkId ra = VM_REG(LUAU_INSN_A(insn)); - TValue* rb = VM_KV(aux); - - // Note that all jumps below jump by 1 in the "false" case to skip over aux - if (ttype(ra) == ttype(rb)) - { - switch (ttype(ra)) - { - case LUA_TNIL: - pc += LUAU_INSN_D(insn); - LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); - VM_NEXT(); - - case LUA_TBOOLEAN: - pc += bvalue(ra) == bvalue(rb) ? LUAU_INSN_D(insn) : 1; - LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); - VM_NEXT(); - - case LUA_TNUMBER: - pc += nvalue(ra) == nvalue(rb) ? LUAU_INSN_D(insn) : 1; - LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); - VM_NEXT(); - - case LUA_TSTRING: - pc += gcvalue(ra) == gcvalue(rb) ? LUAU_INSN_D(insn) : 1; - LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); - VM_NEXT(); - - default:; - } - - LUAU_ASSERT(!"Constant is expected to be of primitive type"); - } - else - { - pc += 1; - LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); - VM_NEXT(); - } + LUAU_ASSERT(!"Unsupported deprecated opcode"); + LUAU_UNREACHABLE(); } VM_CASE(LOP_DEP_JUMPIFNOTEQK) { - Instruction insn = *pc++; - uint32_t aux = *pc; - StkId ra = VM_REG(LUAU_INSN_A(insn)); - TValue* rb = VM_KV(aux); - - // Note that all jumps below jump by 1 in the "true" case to skip over aux - if (ttype(ra) == ttype(rb)) - { - switch (ttype(ra)) - { - case LUA_TNIL: - pc += 1; - LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); - VM_NEXT(); - - case LUA_TBOOLEAN: - pc += bvalue(ra) != bvalue(rb) ? LUAU_INSN_D(insn) : 1; - LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); - VM_NEXT(); - - case LUA_TNUMBER: - pc += nvalue(ra) != nvalue(rb) ? LUAU_INSN_D(insn) : 1; - LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); - VM_NEXT(); - - case LUA_TSTRING: - pc += gcvalue(ra) != gcvalue(rb) ? LUAU_INSN_D(insn) : 1; - LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); - VM_NEXT(); - - default:; - } - - LUAU_ASSERT(!"Constant is expected to be of primitive type"); - } - else - { - pc += LUAU_INSN_D(insn); - LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); - VM_NEXT(); - } + LUAU_ASSERT(!"Unsupported deprecated opcode"); + LUAU_UNREACHABLE(); } VM_CASE(LOP_FASTCALL1) diff --git a/tests/AstQuery.test.cpp b/tests/AstQuery.test.cpp index 2b650fa47..4b21c443c 100644 --- a/tests/AstQuery.test.cpp +++ b/tests/AstQuery.test.cpp @@ -1,9 +1,10 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "Fixture.h" #include "Luau/AstQuery.h" +#include "AstQueryDsl.h" #include "doctest.h" +#include "Fixture.h" using namespace Luau; diff --git a/tests/AstQueryDsl.cpp b/tests/AstQueryDsl.cpp new file mode 100644 index 000000000..0cf28f3bd --- /dev/null +++ b/tests/AstQueryDsl.cpp @@ -0,0 +1,45 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "AstQueryDsl.h" + +namespace Luau +{ + +FindNthOccurenceOf::FindNthOccurenceOf(Nth nth) + : requestedNth(nth) +{ +} + +bool FindNthOccurenceOf::checkIt(AstNode* n) +{ + if (theNode) + return false; + + if (n->classIndex == requestedNth.classIndex) + { + // Human factor: the requestedNth starts from 1 because of the term `nth`. + if (currentOccurrence + 1 != requestedNth.nth) + ++currentOccurrence; + else + theNode = n; + } + + return !theNode; // once found, returns false and stops traversal +} + +bool FindNthOccurenceOf::visit(AstNode* n) +{ + return checkIt(n); +} + +bool FindNthOccurenceOf::visit(AstType* t) +{ + return checkIt(t); +} + +bool FindNthOccurenceOf::visit(AstTypePack* t) +{ + return checkIt(t); +} + +} diff --git a/tests/AstQueryDsl.h b/tests/AstQueryDsl.h new file mode 100644 index 000000000..6bf3bd303 --- /dev/null +++ b/tests/AstQueryDsl.h @@ -0,0 +1,83 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Ast.h" +#include "Luau/Common.h" + +#include +#include + +namespace Luau +{ + +struct Nth +{ + int classIndex; + int nth; +}; + +template +Nth nth(int nth = 1) +{ + static_assert(std::is_base_of_v, "T must be a derived class of AstNode"); + LUAU_ASSERT(nth > 0); // Did you mean to use `nth(1)`? + + return Nth{T::ClassIndex(), nth}; +} + +struct FindNthOccurenceOf : public AstVisitor +{ + Nth requestedNth; + int currentOccurrence = 0; + AstNode* theNode = nullptr; + + FindNthOccurenceOf(Nth nth); + + bool checkIt(AstNode* n); + + bool visit(AstNode* n) override; + bool visit(AstType* n) override; + bool visit(AstTypePack* n) override; +}; + +/** DSL querying of the AST. + * + * Given an AST, one can query for a particular node directly without having to manually unwrap the tree, for example: + * + * ``` + * if a and b then + * print(a + b) + * end + * + * function f(x, y) + * return x + y + * end + * ``` + * + * There are numerous ways to access the second AstExprBinary. + * 1. Luau::query(block, {nth(), nth()}) + * 2. Luau::query(Luau::query(block)) + * 3. Luau::query(block, {nth(2)}) + */ +template +T* query(AstNode* node, const std::vector& nths = {nth(N)}) +{ + static_assert(std::is_base_of_v, "T must be a derived class of AstNode"); + + // If a nested query call fails to find the node in question, subsequent calls can propagate rather than trying to do more. + // This supports `query(query(...))` + + for (Nth nth : nths) + { + if (!node) + return nullptr; + + FindNthOccurenceOf finder{nth}; + node->visit(&finder); + node = finder.theNode; + } + + return node ? node->as() : nullptr; +} + +} diff --git a/tests/CodeAllocator.test.cpp b/tests/CodeAllocator.test.cpp index 005bc9598..758fb44cb 100644 --- a/tests/CodeAllocator.test.cpp +++ b/tests/CodeAllocator.test.cpp @@ -1,9 +1,16 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/AssemblyBuilderX64.h" #include "Luau/CodeAllocator.h" +#include "Luau/CodeBlockUnwind.h" +#include "Luau/UnwindBuilder.h" +#include "Luau/UnwindBuilderDwarf2.h" +#include "Luau/UnwindBuilderWin.h" #include "doctest.h" +#include +#include + #include using namespace Luau::CodeGen; @@ -41,8 +48,8 @@ TEST_CASE("CodeAllocation") TEST_CASE("CodeAllocationFailure") { - size_t blockSize = 4096; - size_t maxTotalSize = 8192; + size_t blockSize = 3000; + size_t maxTotalSize = 7000; CodeAllocator allocator(blockSize, maxTotalSize); uint8_t* nativeData; @@ -50,11 +57,13 @@ TEST_CASE("CodeAllocationFailure") uint8_t* nativeEntry; std::vector code; - code.resize(6000); + code.resize(4000); + // allocation has to fit in a block REQUIRE(!allocator.allocate(nullptr, 0, code.data(), code.size(), nativeData, sizeNativeData, nativeEntry)); - code.resize(3000); + // each allocation exhausts a block, so third allocation fails + code.resize(2000); REQUIRE(allocator.allocate(nullptr, 0, code.data(), code.size(), nativeData, sizeNativeData, nativeEntry)); REQUIRE(allocator.allocate(nullptr, 0, code.data(), code.size(), nativeData, sizeNativeData, nativeEntry)); REQUIRE(!allocator.allocate(nullptr, 0, code.data(), code.size(), nativeData, sizeNativeData, nativeEntry)); @@ -118,19 +127,84 @@ TEST_CASE("CodeAllocationWithUnwindCallbacks") CHECK(info.destroyCalled); } -#if defined(__x86_64__) || defined(_M_X64) -TEST_CASE("GeneratedCodeExecution") +TEST_CASE("WindowsUnwindCodesX64") { + UnwindBuilderWin unwind; + + unwind.start(); + unwind.spill(16, rdx); + unwind.spill(8, rcx); + unwind.save(rdi); + unwind.save(rsi); + unwind.save(rbx); + unwind.save(rbp); + unwind.save(r12); + unwind.save(r13); + unwind.save(r14); + unwind.save(r15); + unwind.allocStack(72); + unwind.setupFrameReg(rbp, 48); + unwind.finish(); + + std::vector data; + data.resize(unwind.getSize()); + unwind.finalize(data.data(), nullptr, 0); + + std::vector expected{0x01, 0x23, 0x0a, 0x35, 0x23, 0x33, 0x1e, 0x82, 0x1a, 0xf0, 0x18, 0xe0, 0x16, 0xd0, 0x14, 0xc0, 0x12, 0x50, 0x10, + 0x30, 0x0e, 0x60, 0x0c, 0x70}; + + REQUIRE(data.size() == expected.size()); + CHECK(memcmp(data.data(), expected.data(), expected.size()) == 0); +} + +TEST_CASE("Dwarf2UnwindCodesX64") +{ + UnwindBuilderDwarf2 unwind; + + unwind.start(); + unwind.save(rdi); + unwind.save(rsi); + unwind.save(rbx); + unwind.save(rbp); + unwind.save(r12); + unwind.save(r13); + unwind.save(r14); + unwind.save(r15); + unwind.allocStack(72); + unwind.setupFrameReg(rbp, 48); + unwind.finish(); + + std::vector data; + data.resize(unwind.getSize()); + unwind.finalize(data.data(), nullptr, 0); + + std::vector expected{0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x01, 0x78, 0x10, 0x0c, 0x07, 0x08, 0x05, 0x10, 0x01, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x4c, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x02, 0x0e, 0x10, 0x85, 0x02, 0x02, 0x02, 0x0e, 0x18, 0x84, 0x03, 0x02, 0x02, 0x0e, 0x20, 0x83, + 0x04, 0x02, 0x02, 0x0e, 0x28, 0x86, 0x05, 0x02, 0x02, 0x0e, 0x30, 0x8c, 0x06, 0x02, 0x02, 0x0e, 0x38, 0x8d, 0x07, 0x02, 0x02, 0x0e, 0x40, + 0x8e, 0x08, 0x02, 0x02, 0x0e, 0x48, 0x8f, 0x09, 0x02, 0x04, 0x0e, 0x90, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}; + + REQUIRE(data.size() == expected.size()); + CHECK(memcmp(data.data(), expected.data(), expected.size()) == 0); +} + +#if defined(__x86_64__) || defined(_M_X64) + #if defined(_WIN32) - // Windows x64 ABI - constexpr RegisterX64 rArg1 = rcx; - constexpr RegisterX64 rArg2 = rdx; +// Windows x64 ABI +constexpr RegisterX64 rArg1 = rcx; +constexpr RegisterX64 rArg2 = rdx; #else - // System V AMD64 ABI - constexpr RegisterX64 rArg1 = rdi; - constexpr RegisterX64 rArg2 = rsi; +// System V AMD64 ABI +constexpr RegisterX64 rArg1 = rdi; +constexpr RegisterX64 rArg2 = rsi; #endif +constexpr RegisterX64 rNonVol1 = r12; +constexpr RegisterX64 rNonVol2 = rbx; + +TEST_CASE("GeneratedCodeExecution") +{ AssemblyBuilderX64 build(/* logText= */ false); build.mov(rax, rArg1); @@ -155,6 +229,90 @@ TEST_CASE("GeneratedCodeExecution") int64_t result = f(10, 20); CHECK(result == 210); } + +void throwing(int64_t arg) +{ + CHECK(arg == 25); + + throw std::runtime_error("testing"); +} + +TEST_CASE("GeneratedCodeExecutionWithThrow") +{ + AssemblyBuilderX64 build(/* logText= */ false); + +#if defined(_WIN32) + std::unique_ptr unwind = std::make_unique(); +#else + std::unique_ptr unwind = std::make_unique(); +#endif + + unwind->start(); + + // Prologue + build.push(rNonVol1); + unwind->save(rNonVol1); + build.push(rNonVol2); + unwind->save(rNonVol2); + build.push(rbp); + unwind->save(rbp); + + int stackSize = 32; + int localsSize = 16; + + build.sub(rsp, stackSize + localsSize); + unwind->allocStack(stackSize + localsSize); + + build.lea(rbp, qword[rsp + stackSize]); + unwind->setupFrameReg(rbp, stackSize); + + unwind->finish(); + + // Body + build.mov(rNonVol1, rArg1); + build.mov(rNonVol2, rArg2); + + build.add(rNonVol1, 15); + build.mov(rArg1, rNonVol1); + build.call(rNonVol2); + + // Epilogue + build.lea(rsp, qword[rbp + localsSize]); + build.pop(rbp); + build.pop(rNonVol2); + build.pop(rNonVol1); + build.ret(); + + build.finalize(); + + size_t blockSize = 1024 * 1024; + size_t maxTotalSize = 1024 * 1024; + CodeAllocator allocator(blockSize, maxTotalSize); + + allocator.context = unwind.get(); + allocator.createBlockUnwindInfo = createBlockUnwindInfo; + allocator.destroyBlockUnwindInfo = destroyBlockUnwindInfo; + + uint8_t* nativeData; + size_t sizeNativeData; + uint8_t* nativeEntry; + REQUIRE(allocator.allocate(build.data.data(), build.data.size(), build.code.data(), build.code.size(), nativeData, sizeNativeData, nativeEntry)); + REQUIRE(nativeEntry); + + using FunctionType = int64_t(int64_t, void (*)(int64_t)); + FunctionType* f = (FunctionType*)nativeEntry; + + // To simplify debugging, CHECK_THROWS_WITH_AS is not used here + try + { + f(10, throwing); + } + catch (const std::runtime_error& error) + { + CHECK(strcmp(error.what(), "testing") == 0); + } +} + #endif TEST_SUITE_END(); diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index c6bdb4dbe..25129bffb 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -746,6 +746,8 @@ TEST_CASE("ApiTables") lua_newtable(L); lua_pushnumber(L, 123.0); lua_setfield(L, -2, "key"); + lua_pushnumber(L, 456.0); + lua_rawsetfield(L, -2, "key2"); lua_pushstring(L, "test"); lua_rawseti(L, -2, 5); @@ -761,8 +763,8 @@ TEST_CASE("ApiTables") lua_pop(L, 1); // lua_rawgetfield - CHECK(lua_rawgetfield(L, -1, "key") == LUA_TNUMBER); - CHECK(lua_tonumber(L, -1) == 123.0); + CHECK(lua_rawgetfield(L, -1, "key2") == LUA_TNUMBER); + CHECK(lua_tonumber(L, -1) == 456.0); lua_pop(L, 1); // lua_rawget diff --git a/tests/ConstraintGraphBuilder.test.cpp b/tests/ConstraintGraphBuilder.test.cpp index bbe294290..5c34e3d64 100644 --- a/tests/ConstraintGraphBuilder.test.cpp +++ b/tests/ConstraintGraphBuilder.test.cpp @@ -1,8 +1,11 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "Fixture.h" #include "Luau/ConstraintGraphBuilder.h" +#include "Luau/NotNull.h" +#include "Luau/ToString.h" +#include "ConstraintGraphBuilderFixture.h" +#include "Fixture.h" #include "doctest.h" using namespace Luau; diff --git a/tests/ConstraintGraphBuilderFixture.cpp b/tests/ConstraintGraphBuilderFixture.cpp new file mode 100644 index 000000000..1958d1dca --- /dev/null +++ b/tests/ConstraintGraphBuilderFixture.cpp @@ -0,0 +1,17 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "ConstraintGraphBuilderFixture.h" + +namespace Luau +{ + +ConstraintGraphBuilderFixture::ConstraintGraphBuilderFixture() + : Fixture() + , mainModule(new Module) + , cgb("MainModule", mainModule, &arena, NotNull(&moduleResolver), singletonTypes, NotNull(&ice), frontend.getGlobalScope(), &logger) + , forceTheFlag{"DebugLuauDeferredConstraintResolution", true} +{ + BlockedTypeVar::nextIndex = 0; + BlockedTypePack::nextIndex = 0; +} + +} diff --git a/tests/ConstraintGraphBuilderFixture.h b/tests/ConstraintGraphBuilderFixture.h new file mode 100644 index 000000000..262e39016 --- /dev/null +++ b/tests/ConstraintGraphBuilderFixture.h @@ -0,0 +1,27 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/ConstraintGraphBuilder.h" +#include "Luau/DcrLogger.h" +#include "Luau/TypeArena.h" +#include "Luau/Module.h" + +#include "Fixture.h" +#include "ScopedFlags.h" + +namespace Luau +{ + +struct ConstraintGraphBuilderFixture : Fixture +{ + TypeArena arena; + ModulePtr mainModule; + ConstraintGraphBuilder cgb; + DcrLogger logger; + + ScopedFastFlag forceTheFlag; + + ConstraintGraphBuilderFixture(); +}; + +} diff --git a/tests/ConstraintSolver.test.cpp b/tests/ConstraintSolver.test.cpp index fba578230..9976bd2c6 100644 --- a/tests/ConstraintSolver.test.cpp +++ b/tests/ConstraintSolver.test.cpp @@ -1,12 +1,12 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "Fixture.h" - -#include "doctest.h" - #include "Luau/ConstraintGraphBuilder.h" #include "Luau/ConstraintSolver.h" +#include "ConstraintGraphBuilderFixture.h" +#include "Fixture.h" +#include "doctest.h" + using namespace Luau; static TypeId requireBinding(NotNull scope, const char* name) diff --git a/tests/Fixture.cpp b/tests/Fixture.cpp index 6c4594f4b..3f77978ce 100644 --- a/tests/Fixture.cpp +++ b/tests/Fixture.cpp @@ -444,16 +444,6 @@ BuiltinsFixture::BuiltinsFixture(bool freeze, bool prepareAutocomplete) Luau::freeze(frontend.typeCheckerForAutocomplete.globalTypes); } -ConstraintGraphBuilderFixture::ConstraintGraphBuilderFixture() - : Fixture() - , mainModule(new Module) - , cgb(mainModuleName, mainModule, &arena, NotNull(&moduleResolver), singletonTypes, NotNull(&ice), frontend.getGlobalScope(), &logger) - , forceTheFlag{"DebugLuauDeferredConstraintResolution", true} -{ - BlockedTypeVar::nextIndex = 0; - BlockedTypePack::nextIndex = 0; -} - ModuleName fromString(std::string_view name) { return ModuleName(name); @@ -516,41 +506,4 @@ void dump(const std::vector& constraints) printf("%s\n", toString(c, opts).c_str()); } -FindNthOccurenceOf::FindNthOccurenceOf(Nth nth) - : requestedNth(nth) -{ -} - -bool FindNthOccurenceOf::checkIt(AstNode* n) -{ - if (theNode) - return false; - - if (n->classIndex == requestedNth.classIndex) - { - // Human factor: the requestedNth starts from 1 because of the term `nth`. - if (currentOccurrence + 1 != requestedNth.nth) - ++currentOccurrence; - else - theNode = n; - } - - return !theNode; // once found, returns false and stops traversal -} - -bool FindNthOccurenceOf::visit(AstNode* n) -{ - return checkIt(n); -} - -bool FindNthOccurenceOf::visit(AstType* t) -{ - return checkIt(t); -} - -bool FindNthOccurenceOf::visit(AstTypePack* t) -{ - return checkIt(t); -} - } // namespace Luau diff --git a/tests/Fixture.h b/tests/Fixture.h index 03101bbf3..2fb48468b 100644 --- a/tests/Fixture.h +++ b/tests/Fixture.h @@ -12,7 +12,6 @@ #include "Luau/ToString.h" #include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" -#include "Luau/DcrLogger.h" #include "IostreamOptional.h" #include "ScopedFlags.h" @@ -162,18 +161,6 @@ struct BuiltinsFixture : Fixture BuiltinsFixture(bool freeze = true, bool prepareAutocomplete = false); }; -struct ConstraintGraphBuilderFixture : Fixture -{ - TypeArena arena; - ModulePtr mainModule; - ConstraintGraphBuilder cgb; - DcrLogger logger; - - ScopedFastFlag forceTheFlag; - - ConstraintGraphBuilderFixture(); -}; - ModuleName fromString(std::string_view name); template @@ -199,76 +186,6 @@ std::optional lookupName(ScopePtr scope, const std::string& name); // Wa std::optional linearSearchForBinding(Scope* scope, const char* name); -struct Nth -{ - int classIndex; - int nth; -}; - -template -Nth nth(int nth = 1) -{ - static_assert(std::is_base_of_v, "T must be a derived class of AstNode"); - LUAU_ASSERT(nth > 0); // Did you mean to use `nth(1)`? - - return Nth{T::ClassIndex(), nth}; -} - -struct FindNthOccurenceOf : public AstVisitor -{ - Nth requestedNth; - int currentOccurrence = 0; - AstNode* theNode = nullptr; - - FindNthOccurenceOf(Nth nth); - - bool checkIt(AstNode* n); - - bool visit(AstNode* n) override; - bool visit(AstType* n) override; - bool visit(AstTypePack* n) override; -}; - -/** DSL querying of the AST. - * - * Given an AST, one can query for a particular node directly without having to manually unwrap the tree, for example: - * - * ``` - * if a and b then - * print(a + b) - * end - * - * function f(x, y) - * return x + y - * end - * ``` - * - * There are numerous ways to access the second AstExprBinary. - * 1. Luau::query(block, {nth(), nth()}) - * 2. Luau::query(Luau::query(block)) - * 3. Luau::query(block, {nth(2)}) - */ -template -T* query(AstNode* node, const std::vector& nths = {nth(N)}) -{ - static_assert(std::is_base_of_v, "T must be a derived class of AstNode"); - - // If a nested query call fails to find the node in question, subsequent calls can propagate rather than trying to do more. - // This supports `query(query(...))` - - for (Nth nth : nths) - { - if (!node) - return nullptr; - - FindNthOccurenceOf finder{nth}; - node->visit(&finder); - node = finder.theNode; - } - - return node ? node->as() : nullptr; -} - } // namespace Luau #define LUAU_REQUIRE_ERRORS(result) \ diff --git a/tests/TypeInfer.builtins.test.cpp b/tests/TypeInfer.builtins.test.cpp index 07a04363f..6da3f569b 100644 --- a/tests/TypeInfer.builtins.test.cpp +++ b/tests/TypeInfer.builtins.test.cpp @@ -459,7 +459,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "thread_is_a_type") )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*typeChecker.threadType, *requireType("co")); + CHECK("thread" == toString(requireType("co"))); } TEST_CASE_FIXTURE(BuiltinsFixture, "coroutine_resume_anything_goes") @@ -627,6 +627,8 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "select_with_decimal_argument_is_rounded_down // Could be flaky if the fix has regressed. TEST_CASE_FIXTURE(BuiltinsFixture, "bad_select_should_not_crash") { + ScopedFastFlag luauFunctionArgMismatchDetails{"LuauFunctionArgMismatchDetails", true}; + CheckResult result = check(R"( do end local _ = function(l0,...) @@ -638,8 +640,8 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "bad_select_should_not_crash") )"); LUAU_REQUIRE_ERROR_COUNT(2, result); - CHECK_EQ("Argument count mismatch. Function expects at least 1 argument, but none are specified", toString(result.errors[0])); - CHECK_EQ("Argument count mismatch. Function expects 1 argument, but none are specified", toString(result.errors[1])); + CHECK_EQ("Argument count mismatch. Function '_' expects at least 1 argument, but none are specified", toString(result.errors[0])); + CHECK_EQ("Argument count mismatch. Function 'select' expects 1 argument, but none are specified", toString(result.errors[1])); } TEST_CASE_FIXTURE(BuiltinsFixture, "select_way_out_of_range") @@ -824,12 +826,12 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "string_lib_self_noself") TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_definition") { CheckResult result = check(R"_( -local a, b, c = ("hey"):gmatch("(.)(.)(.)")() + local a, b, c = ("hey"):gmatch("(.)(.)(.)")() -for c in ("hey"):gmatch("(.)") do - print(c:upper()) -end -)_"); + for c in ("hey"):gmatch("(.)") do + print(c:upper()) + end + )_"); LUAU_REQUIRE_NO_ERRORS(result); } @@ -1008,6 +1010,8 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "table_freeze_is_generic") TEST_CASE_FIXTURE(BuiltinsFixture, "set_metatable_needs_arguments") { + ScopedFastFlag luauFunctionArgMismatchDetails{"LuauFunctionArgMismatchDetails", true}; + ScopedFastFlag sff{"LuauSetMetaTableArgsCheck", true}; CheckResult result = check(R"( local a = {b=setmetatable} @@ -1016,8 +1020,8 @@ a:b() a:b({}) )"); LUAU_REQUIRE_ERROR_COUNT(2, result); - CHECK_EQ(toString(result.errors[0]), "Argument count mismatch. Function expects 2 arguments, but none are specified"); - CHECK_EQ(toString(result.errors[1]), "Argument count mismatch. Function expects 2 arguments, but only 1 is specified"); + CHECK_EQ(toString(result.errors[0]), "Argument count mismatch. Function 'a.b' expects 2 arguments, but none are specified"); + CHECK_EQ(toString(result.errors[1]), "Argument count mismatch. Function 'a.b' expects 2 arguments, but only 1 is specified"); } TEST_CASE_FIXTURE(Fixture, "typeof_unresolved_function") diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index 35e67ec55..bde28dccb 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -1732,4 +1732,56 @@ TEST_CASE_FIXTURE(Fixture, "dont_mutate_the_underlying_head_of_typepack_when_cal LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(BuiltinsFixture, "improved_function_arg_mismatch_errors") +{ + ScopedFastFlag luauFunctionArgMismatchDetails{"LuauFunctionArgMismatchDetails", true}; + + CheckResult result = check(R"( +local function foo1(a: number) end +foo1() + +local function foo2(a: number, b: string?) end +foo2() + +local function foo3(a: number, b: string?, c: any) end -- any is optional +foo3() + +string.find() + +local t = {} +function t.foo(x: number, y: string?, ...: any) end +function t:bar(x: number, y: string?) end +t.foo() + +t:bar() + +local u = { a = t } +u.a.foo() + )"); + + LUAU_REQUIRE_ERROR_COUNT(7, result); + CHECK_EQ(toString(result.errors[0]), "Argument count mismatch. Function 'foo1' expects 1 argument, but none are specified"); + CHECK_EQ(toString(result.errors[1]), "Argument count mismatch. Function 'foo2' expects 1 to 2 arguments, but none are specified"); + CHECK_EQ(toString(result.errors[2]), "Argument count mismatch. Function 'foo3' expects 1 to 3 arguments, but none are specified"); + CHECK_EQ(toString(result.errors[3]), "Argument count mismatch. Function 'string.find' expects 2 to 4 arguments, but none are specified"); + CHECK_EQ(toString(result.errors[4]), "Argument count mismatch. Function 't.foo' expects at least 1 argument, but none are specified"); + CHECK_EQ(toString(result.errors[5]), "Argument count mismatch. Function 't.bar' expects 2 to 3 arguments, but only 1 is specified"); + CHECK_EQ(toString(result.errors[6]), "Argument count mismatch. Function 'u.a.foo' expects at least 1 argument, but none are specified"); +} + +// This might be surprising, but since 'any' became optional, unannotated functions in non-strict 'expect' 0 arguments +TEST_CASE_FIXTURE(BuiltinsFixture, "improved_function_arg_mismatch_error_nonstrict") +{ + ScopedFastFlag luauFunctionArgMismatchDetails{"LuauFunctionArgMismatchDetails", true}; + + CheckResult result = check(R"( +--!nonstrict +local function foo(a, b) end +foo(string.find("hello", "e")) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Argument count mismatch. Function 'foo' expects 0 to 2 arguments, but 3 are specified"); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.generics.test.cpp b/tests/TypeInfer.generics.test.cpp index 9ac259cf8..3c8677706 100644 --- a/tests/TypeInfer.generics.test.cpp +++ b/tests/TypeInfer.generics.test.cpp @@ -782,6 +782,8 @@ local TheDispatcher: Dispatcher = { TEST_CASE_FIXTURE(Fixture, "generic_argument_count_too_few") { + ScopedFastFlag luauFunctionArgMismatchDetails{"LuauFunctionArgMismatchDetails", true}; + CheckResult result = check(R"( function test(a: number) return 1 @@ -794,11 +796,13 @@ wrapper(test) )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), R"(Argument count mismatch. Function expects 2 arguments, but only 1 is specified)"); + CHECK_EQ(toString(result.errors[0]), R"(Argument count mismatch. Function 'wrapper' expects 2 arguments, but only 1 is specified)"); } TEST_CASE_FIXTURE(Fixture, "generic_argument_count_too_many") { + ScopedFastFlag luauFunctionArgMismatchDetails{"LuauFunctionArgMismatchDetails", true}; + CheckResult result = check(R"( function test2(a: number, b: string) return 1 @@ -811,7 +815,7 @@ wrapper(test2, 1, "", 3) )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), R"(Argument count mismatch. Function expects 3 arguments, but 4 are specified)"); + CHECK_EQ(toString(result.errors[0]), R"(Argument count mismatch. Function 'wrapper' expects 3 arguments, but 4 are specified)"); } TEST_CASE_FIXTURE(Fixture, "generic_function") diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index 3482b75cf..40ea0ca11 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -69,6 +69,8 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "xpcall_returns_what_f_returns") CHECK("string" == toString(requireType("c"))); CHECK(expected == decorateWithTypes(code)); + + LUAU_REQUIRE_NO_ERRORS(result); } // We had a bug where if you have two type packs that looks like: diff --git a/tests/TypeInfer.singletons.test.cpp b/tests/TypeInfer.singletons.test.cpp index 0a130d494..7d98b5db7 100644 --- a/tests/TypeInfer.singletons.test.cpp +++ b/tests/TypeInfer.singletons.test.cpp @@ -202,6 +202,15 @@ TEST_CASE_FIXTURE(Fixture, "tagged_unions_immutable_tag") LUAU_REQUIRE_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "table_has_a_boolean") +{ + CheckResult result = check(R"( + local t={a=1,b=false} + )"); + + CHECK("{ a: number, b: boolean }" == toString(requireType("t"), {true})); +} + TEST_CASE_FIXTURE(Fixture, "table_properties_singleton_strings") { CheckResult result = check(R"( diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index 97c3da4f6..d183f650e 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -2615,9 +2615,11 @@ do end TEST_CASE_FIXTURE(BuiltinsFixture, "dont_crash_when_setmetatable_does_not_produce_a_metatabletypevar") { + ScopedFastFlag luauFunctionArgMismatchDetails{"LuauFunctionArgMismatchDetails", true}; + CheckResult result = check("local x = setmetatable({})"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Argument count mismatch. Function expects 2 arguments, but only 1 is specified", toString(result.errors[0])); + CHECK_EQ("Argument count mismatch. Function 'setmetatable' expects 2 arguments, but only 1 is specified", toString(result.errors[0])); } TEST_CASE_FIXTURE(BuiltinsFixture, "instantiate_table_cloning") @@ -2695,6 +2697,8 @@ local baz = foo[bar] TEST_CASE_FIXTURE(BuiltinsFixture, "table_simple_call") { + ScopedFastFlag luauFunctionArgMismatchDetails{"LuauFunctionArgMismatchDetails", true}; + CheckResult result = check(R"( local a = setmetatable({ x = 2 }, { __call = function(self) @@ -2706,7 +2710,7 @@ local c = a(2) -- too many arguments )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Argument count mismatch. Function expects 1 argument, but 2 are specified", toString(result.errors[0])); + CHECK_EQ("Argument count mismatch. Function 'a' expects 1 argument, but 2 are specified", toString(result.errors[0])); } TEST_CASE_FIXTURE(BuiltinsFixture, "access_index_metamethod_that_returns_variadic") diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index e1dc5023e..e8bfb67f9 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -1105,4 +1105,21 @@ end CHECK_EQ(*getMainModule()->astResolvedTypes.find(annotation), *ty); } +TEST_CASE_FIXTURE(Fixture, "bidirectional_checking_of_higher_order_function") +{ + CheckResult result = check(R"( + function higher(cb: (number) -> ()) end + + higher(function(n) -- no error here. n : number + local e: string = n -- error here. n /: string + end) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + Location location = result.errors[0].location; + CHECK(location.begin.line == 4); + CHECK(location.end.line == 4); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.typePacks.cpp b/tests/TypeInfer.typePacks.cpp index eaa8b0539..7d33809fa 100644 --- a/tests/TypeInfer.typePacks.cpp +++ b/tests/TypeInfer.typePacks.cpp @@ -964,4 +964,40 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "detect_cyclic_typepacks2") LUAU_REQUIRE_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "unify_variadic_tails_in_arguments") +{ + ScopedFastFlag luauCallUnifyPackTails{"LuauCallUnifyPackTails", true}; + + CheckResult result = check(R"( + function foo(...: string): number + return 1 + end + + function bar(...: number): number + return foo(...) + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type 'number' could not be converted into 'string'"); +} + +TEST_CASE_FIXTURE(Fixture, "unify_variadic_tails_in_arguments_free") +{ + ScopedFastFlag luauCallUnifyPackTails{"LuauCallUnifyPackTails", true}; + + CheckResult result = check(R"( + function foo(...: T...): T... + return ... + end + + function bar(...: number): boolean + return foo(...) + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type 'number' could not be converted into 'boolean'"); +} + TEST_SUITE_END(); diff --git a/tools/faillist.txt b/tools/faillist.txt index ef995aa62..9c2df8059 100644 --- a/tools/faillist.txt +++ b/tools/faillist.txt @@ -383,8 +383,6 @@ TableTests.casting_sealed_tables_with_props_into_table_with_indexer TableTests.casting_tables_with_props_into_table_with_indexer3 TableTests.casting_tables_with_props_into_table_with_indexer4 TableTests.checked_prop_too_early -TableTests.common_table_element_union_in_call -TableTests.common_table_element_union_in_call_tail TableTests.confusing_indexing TableTests.defining_a_method_for_a_builtin_sealed_table_must_fail TableTests.defining_a_method_for_a_local_sealed_table_must_fail @@ -540,7 +538,6 @@ TypeInfer.type_infer_recursion_limit_no_ice TypeInferAnyError.assign_prop_to_table_by_calling_any_yields_any TypeInferAnyError.can_get_length_of_any TypeInferAnyError.for_in_loop_iterator_is_any2 -TypeInferAnyError.for_in_loop_iterator_returns_any TypeInferAnyError.length_of_error_type_does_not_produce_an_error TypeInferAnyError.replace_every_free_type_when_unifying_a_complex_function_with_any TypeInferAnyError.union_of_types_regression_test @@ -561,7 +558,6 @@ TypeInferClasses.table_indexers_are_invariant TypeInferClasses.table_properties_are_invariant TypeInferClasses.warn_when_prop_almost_matches TypeInferClasses.we_can_report_when_someone_is_trying_to_use_a_table_rather_than_a_class -TypeInferFunctions.another_indirect_function_case_where_it_is_ok_to_provide_too_many_arguments TypeInferFunctions.another_recursive_local_function TypeInferFunctions.call_o_with_another_argument_after_foo_was_quantified TypeInferFunctions.calling_function_with_anytypepack_doesnt_leak_free_types @@ -586,13 +582,14 @@ TypeInferFunctions.function_statement_sealed_table_assignment_through_indexer TypeInferFunctions.higher_order_function_2 TypeInferFunctions.higher_order_function_4 TypeInferFunctions.ignored_return_values +TypeInferFunctions.improved_function_arg_mismatch_error_nonstrict +TypeInferFunctions.improved_function_arg_mismatch_errors TypeInferFunctions.inconsistent_higher_order_function TypeInferFunctions.inconsistent_return_types TypeInferFunctions.infer_anonymous_function_arguments TypeInferFunctions.infer_return_type_from_selected_overload TypeInferFunctions.infer_that_function_does_not_return_a_table TypeInferFunctions.it_is_ok_not_to_supply_enough_retvals -TypeInferFunctions.it_is_ok_to_oversaturate_a_higher_order_function_argument TypeInferFunctions.list_all_overloads_if_no_overload_takes_given_argument_count TypeInferFunctions.list_only_alternative_overloads_that_match_argument_count TypeInferFunctions.no_lossy_function_type From 48fb5a3483a3ab5ca2ec8803ce49d3f771d97253 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Fri, 23 Sep 2022 11:32:10 -0700 Subject: [PATCH 05/66] Sync to upstream/release/546 --- Analysis/include/Luau/Constraint.h | 43 +- .../include/Luau/ConstraintGraphBuilder.h | 12 +- Analysis/include/Luau/ConstraintSolver.h | 12 + Analysis/include/Luau/Normalize.h | 4 +- Analysis/include/Luau/RequireTracer.h | 1 + Analysis/src/ConstraintGraphBuilder.cpp | 271 +- Analysis/src/ConstraintSolver.cpp | 172 +- Analysis/src/Normalize.cpp | 167 +- Analysis/src/ToString.cpp | 9 + Analysis/src/TypeChecker2.cpp | 25 +- Ast/include/Luau/DenseHash.h | 188 +- Ast/include/Luau/Lexer.h | 2 + Ast/src/Parser.cpp | 137 +- CMakeLists.txt | 1 + CodeGen/include/Luau/UnwindBuilderDwarf2.h | 6 +- CodeGen/src/AssemblyBuilderX64.cpp | 46 +- CodeGen/src/ByteUtils.h | 78 + CodeGen/src/Fallbacks.cpp | 2511 +++++++++++++++++ CodeGen/src/Fallbacks.h | 93 + CodeGen/src/FallbacksProlog.h | 56 + CodeGen/src/UnwindBuilderDwarf2.cpp | 52 +- Compiler/src/ConstantFolding.cpp | 1 + Makefile | 2 +- Sources.cmake | 5 + VM/include/lua.h | 3 + VM/include/luaconf.h | 5 + VM/src/lapi.cpp | 64 +- VM/src/ldebug.cpp | 5 + VM/src/lfunc.cpp | 14 + VM/src/lobject.h | 4 + VM/src/lstate.cpp | 5 + VM/src/lstate.h | 17 + VM/src/lvm.h | 3 + VM/src/lvmexecute.cpp | 153 +- VM/src/lvmutils.cpp | 78 + tests/Compiler.test.cpp | 69 +- tests/Conformance.test.cpp | 43 + tests/Linter.test.cpp | 3 - tests/Normalize.test.cpp | 4 +- tests/Parser.test.cpp | 2 - tests/TypeInfer.aliases.test.cpp | 14 +- tests/TypeInfer.provisional.test.cpp | 6 +- tests/TypeInfer.singletons.test.cpp | 25 + tests/TypeInfer.test.cpp | 39 +- tools/faillist.txt | 122 +- tools/lvmexecute_split.py | 100 + 46 files changed, 3951 insertions(+), 721 deletions(-) create mode 100644 CodeGen/src/ByteUtils.h create mode 100644 CodeGen/src/Fallbacks.cpp create mode 100644 CodeGen/src/Fallbacks.h create mode 100644 CodeGen/src/FallbacksProlog.h create mode 100644 tools/lvmexecute_split.py diff --git a/Analysis/include/Luau/Constraint.h b/Analysis/include/Luau/Constraint.h index d5cfcf3f2..3ffb3fb8c 100644 --- a/Analysis/include/Luau/Constraint.h +++ b/Analysis/include/Luau/Constraint.h @@ -90,18 +90,49 @@ struct TypeAliasExpansionConstraint TypeId target; }; -using ConstraintPtr = std::unique_ptr; - struct FunctionCallConstraint { - std::vector> innerConstraints; + std::vector> innerConstraints; TypeId fn; TypePackId result; class AstExprCall* astFragment; }; -using ConstraintV = Variant; +// result ~ prim ExpectedType SomeSingletonType MultitonType +// +// If ExpectedType is potentially a singleton (an actual singleton or a union +// that contains a singleton), then result ~ SomeSingletonType +// +// else result ~ MultitonType +struct PrimitiveTypeConstraint +{ + TypeId resultType; + TypeId expectedType; + TypeId singletonType; + TypeId multitonType; +}; + +// result ~ hasProp type "prop_name" +// +// If the subject is a table, bind the result to the named prop. If the table +// has an indexer, bind it to the index result type. If the subject is a union, +// bind the result to the union of its constituents' properties. +// +// It would be nice to get rid of this constraint and someday replace it with +// +// T <: {p: X} +// +// Where {} describes an inexact shape type. +struct HasPropConstraint +{ + TypeId resultType; + TypeId subjectType; + std::string prop; +}; + +using ConstraintV = + Variant; struct Constraint { @@ -117,6 +148,8 @@ struct Constraint std::vector> dependencies; }; +using ConstraintPtr = std::unique_ptr; + inline Constraint& asMutable(const Constraint& c) { return const_cast(c); diff --git a/Analysis/include/Luau/ConstraintGraphBuilder.h b/Analysis/include/Luau/ConstraintGraphBuilder.h index 1567e0ada..e7d8ad459 100644 --- a/Analysis/include/Luau/ConstraintGraphBuilder.h +++ b/Analysis/include/Luau/ConstraintGraphBuilder.h @@ -125,23 +125,25 @@ struct ConstraintGraphBuilder void visit(const ScopePtr& scope, AstStatDeclareClass* declareClass); void visit(const ScopePtr& scope, AstStatDeclareFunction* declareFunction); - TypePackId checkPack(const ScopePtr& scope, AstArray exprs); - TypePackId checkPack(const ScopePtr& scope, AstExpr* expr); + TypePackId checkPack(const ScopePtr& scope, AstArray exprs, const std::vector& expectedTypes = {}); + TypePackId checkPack(const ScopePtr& scope, AstExpr* expr, const std::vector& expectedTypes = {}); /** * Checks an expression that is expected to evaluate to one type. * @param scope the scope the expression is contained within. * @param expr the expression to check. + * @param expectedType the type of the expression that is expected from its + * surrounding context. Used to implement bidirectional type checking. * @return the type of the expression. */ - TypeId check(const ScopePtr& scope, AstExpr* expr); + TypeId check(const ScopePtr& scope, AstExpr* expr, std::optional expectedType = {}); - TypeId checkExprTable(const ScopePtr& scope, AstExprTable* expr); + TypeId check(const ScopePtr& scope, AstExprTable* expr, std::optional expectedType); TypeId check(const ScopePtr& scope, AstExprIndexName* indexName); TypeId check(const ScopePtr& scope, AstExprIndexExpr* indexExpr); TypeId check(const ScopePtr& scope, AstExprUnary* unary); TypeId check(const ScopePtr& scope, AstExprBinary* binary); - TypeId check(const ScopePtr& scope, AstExprIfElse* ifElse); + TypeId check(const ScopePtr& scope, AstExprIfElse* ifElse, std::optional expectedType); TypeId check(const ScopePtr& scope, AstExprTypeAssertion* typeAssert); struct FunctionSignature diff --git a/Analysis/include/Luau/ConstraintSolver.h b/Analysis/include/Luau/ConstraintSolver.h index fe6a025b2..abea51b87 100644 --- a/Analysis/include/Luau/ConstraintSolver.h +++ b/Analysis/include/Luau/ConstraintSolver.h @@ -100,6 +100,8 @@ struct ConstraintSolver bool tryDispatch(const NameConstraint& c, NotNull constraint); bool tryDispatch(const TypeAliasExpansionConstraint& c, NotNull constraint); bool tryDispatch(const FunctionCallConstraint& c, NotNull constraint); + bool tryDispatch(const PrimitiveTypeConstraint& c, NotNull constraint); + bool tryDispatch(const HasPropConstraint& c, NotNull constraint); // for a, ... in some_table do bool tryDispatchIterableTable(TypeId iteratorTy, const IterableConstraint& c, NotNull constraint, bool force); @@ -116,6 +118,16 @@ struct ConstraintSolver bool block(TypeId target, NotNull constraint); bool block(TypePackId target, NotNull constraint); + // Traverse the type. If any blocked or pending typevars are found, block + // the constraint on them. + // + // Returns false if a type blocks the constraint. + // + // FIXME: This use of a boolean for the return result is an appalling + // interface. + bool recursiveBlock(TypeId target, NotNull constraint); + bool recursiveBlock(TypePackId target, NotNull constraint); + void unblock(NotNull progressed); void unblock(TypeId progressed); void unblock(TypePackId progressed); diff --git a/Analysis/include/Luau/Normalize.h b/Analysis/include/Luau/Normalize.h index 48dbe2bea..8e8b889b9 100644 --- a/Analysis/include/Luau/Normalize.h +++ b/Analysis/include/Luau/Normalize.h @@ -17,8 +17,8 @@ struct SingletonTypes; using ModulePtr = std::shared_ptr; -bool isSubtype(TypeId subTy, TypeId superTy, NotNull scope, NotNull singletonTypes, InternalErrorReporter& ice); -bool isSubtype(TypePackId subTy, TypePackId superTy, NotNull scope, NotNull singletonTypes, InternalErrorReporter& ice); +bool isSubtype(TypeId subTy, TypeId superTy, NotNull scope, NotNull singletonTypes, InternalErrorReporter& ice, bool anyIsTop = true); +bool isSubtype(TypePackId subTy, TypePackId superTy, NotNull scope, NotNull singletonTypes, InternalErrorReporter& ice, bool anyIsTop = true); std::pair normalize( TypeId ty, NotNull scope, TypeArena& arena, NotNull singletonTypes, InternalErrorReporter& ice); diff --git a/Analysis/include/Luau/RequireTracer.h b/Analysis/include/Luau/RequireTracer.h index f69d133e2..718a6cc1b 100644 --- a/Analysis/include/Luau/RequireTracer.h +++ b/Analysis/include/Luau/RequireTracer.h @@ -6,6 +6,7 @@ #include "Luau/Location.h" #include +#include namespace Luau { diff --git a/Analysis/src/ConstraintGraphBuilder.cpp b/Analysis/src/ConstraintGraphBuilder.cpp index 6a65ab925..aa1e9547d 100644 --- a/Analysis/src/ConstraintGraphBuilder.cpp +++ b/Analysis/src/ConstraintGraphBuilder.cpp @@ -36,6 +36,20 @@ static std::optional matchRequire(const AstExprCall& call) return call.args.data[0]; } +static bool matchSetmetatable(const AstExprCall& call) +{ + const char* smt = "setmetatable"; + + if (call.args.size != 2) + return false; + + const AstExprGlobal* funcAsGlobal = call.func->as(); + if (!funcAsGlobal || funcAsGlobal->name != smt) + return false; + + return true; +} + ConstraintGraphBuilder::ConstraintGraphBuilder(const ModuleName& moduleName, ModulePtr module, TypeArena* arena, NotNull moduleResolver, NotNull singletonTypes, NotNull ice, const ScopePtr& globalScope, DcrLogger* logger) : moduleName(moduleName) @@ -214,15 +228,16 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocal* local) for (AstLocal* local : local->vars) { - TypeId ty = freshType(scope); + TypeId ty = nullptr; Location location = local->location; if (local->annotation) { location = local->annotation->location; - TypeId annotation = resolveType(scope, local->annotation, /* topLevel */ true); - addConstraint(scope, location, SubtypeConstraint{ty, annotation}); + ty = resolveType(scope, local->annotation, /* topLevel */ true); } + else + ty = freshType(scope); varTypes.push_back(ty); scope->bindings[local] = Binding{ty, location}; @@ -231,6 +246,8 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocal* local) for (size_t i = 0; i < local->values.size; ++i) { AstExpr* value = local->values.data[i]; + const bool hasAnnotation = i < local->vars.size && nullptr != local->vars.data[i]->annotation; + if (value->is()) { // HACK: we leave nil-initialized things floating under the assumption that they will later be populated. @@ -239,7 +256,11 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocal* local) } else if (i == local->values.size - 1) { - TypePackId exprPack = checkPack(scope, value); + std::vector expectedTypes; + if (hasAnnotation) + expectedTypes.insert(begin(expectedTypes), begin(varTypes) + i, end(varTypes)); + + TypePackId exprPack = checkPack(scope, value, expectedTypes); if (i < local->vars.size) { @@ -250,7 +271,11 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocal* local) } else { - TypeId exprType = check(scope, value); + std::optional expectedType; + if (hasAnnotation) + expectedType = varTypes.at(i); + + TypeId exprType = check(scope, value, expectedType); if (i < varTypes.size()) addConstraint(scope, local->location, SubtypeConstraint{varTypes[i], exprType}); } @@ -458,7 +483,15 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFunction* funct void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatReturn* ret) { - TypePackId exprTypes = checkPack(scope, ret->list); + // At this point, the only way scope->returnType should have anything + // interesting in it is if the function has an explicit return annotation. + // If this is the case, then we can expect that the return expression + // conforms to that. + std::vector expectedTypes; + for (TypeId ty : scope->returnType) + expectedTypes.push_back(ty); + + TypePackId exprTypes = checkPack(scope, ret->list, expectedTypes); addConstraint(scope, ret->location, PackSubtypeConstraint{exprTypes, scope->returnType}); } @@ -695,7 +728,7 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareFunction scope->bindings[global->name] = Binding{fnType, global->location}; } -TypePackId ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstArray exprs) +TypePackId ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstArray exprs, const std::vector& expectedTypes) { std::vector head; std::optional tail; @@ -704,9 +737,17 @@ TypePackId ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstArray expectedType; + if (i < expectedTypes.size()) + expectedType = expectedTypes[i]; head.push_back(check(scope, expr)); + } else - tail = checkPack(scope, expr); + { + std::vector expectedTailTypes{begin(expectedTypes) + i, end(expectedTypes)}; + tail = checkPack(scope, expr, expectedTailTypes); + } } if (head.empty() && tail) @@ -715,7 +756,7 @@ TypePackId ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstArrayaddTypePack(TypePack{std::move(head), tail}); } -TypePackId ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExpr* expr) +TypePackId ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExpr* expr, const std::vector& expectedTypes) { RecursionCounter counter{&recursionCount}; @@ -730,7 +771,6 @@ TypePackId ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExpr* exp if (AstExprCall* call = expr->as()) { TypeId fnType = check(scope, call->func); - const size_t constraintIndex = scope->constraints.size(); const size_t scopeIndex = scopes.size(); @@ -743,49 +783,63 @@ TypePackId ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExpr* exp // TODO self - const size_t constraintEndIndex = scope->constraints.size(); - const size_t scopeEndIndex = scopes.size(); - - astOriginalCallTypes[call->func] = fnType; - - TypeId instantiatedType = arena->addType(BlockedTypeVar{}); - TypePackId rets = arena->addTypePack(BlockedTypePack{}); - FunctionTypeVar ftv(arena->addTypePack(TypePack{args, {}}), rets); - TypeId inferredFnType = arena->addType(ftv); - - scope->unqueuedConstraints.push_back( - std::make_unique(NotNull{scope.get()}, call->func->location, InstantiationConstraint{instantiatedType, fnType})); - NotNull ic(scope->unqueuedConstraints.back().get()); - - scope->unqueuedConstraints.push_back( - std::make_unique(NotNull{scope.get()}, call->func->location, SubtypeConstraint{inferredFnType, instantiatedType})); - NotNull sc(scope->unqueuedConstraints.back().get()); - - // We force constraints produced by checking function arguments to wait - // until after we have resolved the constraint on the function itself. - // This ensures, for instance, that we start inferring the contents of - // lambdas under the assumption that their arguments and return types - // will be compatible with the enclosing function call. - for (size_t ci = constraintIndex; ci < constraintEndIndex; ++ci) - scope->constraints[ci]->dependencies.push_back(sc); + if (matchSetmetatable(*call)) + { + LUAU_ASSERT(args.size() == 2); + TypeId target = args[0]; + TypeId mt = args[1]; - for (size_t si = scopeIndex; si < scopeEndIndex; ++si) + MetatableTypeVar mtv{target, mt}; + TypeId resultTy = arena->addType(mtv); + result = arena->addTypePack({resultTy}); + } + else { - for (auto& c : scopes[si].second->constraints) + const size_t constraintEndIndex = scope->constraints.size(); + const size_t scopeEndIndex = scopes.size(); + + astOriginalCallTypes[call->func] = fnType; + + TypeId instantiatedType = arena->addType(BlockedTypeVar{}); + // TODO: How do expectedTypes play into this? Do they? + TypePackId rets = arena->addTypePack(BlockedTypePack{}); + FunctionTypeVar ftv(arena->addTypePack(TypePack{args, {}}), rets); + TypeId inferredFnType = arena->addType(ftv); + + scope->unqueuedConstraints.push_back( + std::make_unique(NotNull{scope.get()}, call->func->location, InstantiationConstraint{instantiatedType, fnType})); + NotNull ic(scope->unqueuedConstraints.back().get()); + + scope->unqueuedConstraints.push_back( + std::make_unique(NotNull{scope.get()}, call->func->location, SubtypeConstraint{inferredFnType, instantiatedType})); + NotNull sc(scope->unqueuedConstraints.back().get()); + + // We force constraints produced by checking function arguments to wait + // until after we have resolved the constraint on the function itself. + // This ensures, for instance, that we start inferring the contents of + // lambdas under the assumption that their arguments and return types + // will be compatible with the enclosing function call. + for (size_t ci = constraintIndex; ci < constraintEndIndex; ++ci) + scope->constraints[ci]->dependencies.push_back(sc); + + for (size_t si = scopeIndex; si < scopeEndIndex; ++si) { - c->dependencies.push_back(sc); + for (auto& c : scopes[si].second->constraints) + { + c->dependencies.push_back(sc); + } } - } - addConstraint(scope, call->func->location, - FunctionCallConstraint{ - {ic, sc}, - fnType, - rets, - call, - }); + addConstraint(scope, call->func->location, + FunctionCallConstraint{ + {ic, sc}, + fnType, + rets, + call, + }); - result = rets; + result = rets; + } } else if (AstExprVarargs* varargs = expr->as()) { @@ -796,7 +850,10 @@ TypePackId ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExpr* exp } else { - TypeId t = check(scope, expr); + std::optional expectedType; + if (!expectedTypes.empty()) + expectedType = expectedTypes[0]; + TypeId t = check(scope, expr, expectedType); result = arena->addTypePack({t}); } @@ -805,7 +862,7 @@ TypePackId ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExpr* exp return result; } -TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExpr* expr) +TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExpr* expr, std::optional expectedType) { RecursionCounter counter{&recursionCount}; @@ -819,12 +876,47 @@ TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExpr* expr) if (auto group = expr->as()) result = check(scope, group->expr); - else if (expr->is()) - result = singletonTypes->stringType; + else if (auto stringExpr = expr->as()) + { + if (expectedType) + { + const TypeId expectedTy = follow(*expectedType); + if (get(expectedTy) || get(expectedTy)) + { + result = arena->addType(BlockedTypeVar{}); + TypeId singletonType = arena->addType(SingletonTypeVar(StringSingleton{std::string(stringExpr->value.data, stringExpr->value.size)})); + addConstraint(scope, expr->location, PrimitiveTypeConstraint{result, expectedTy, singletonType, singletonTypes->stringType}); + } + else if (maybeSingleton(expectedTy)) + result = arena->addType(SingletonTypeVar{StringSingleton{std::string{stringExpr->value.data, stringExpr->value.size}}}); + else + result = singletonTypes->stringType; + } + else + result = singletonTypes->stringType; + } else if (expr->is()) result = singletonTypes->numberType; - else if (expr->is()) - result = singletonTypes->booleanType; + else if (auto boolExpr = expr->as()) + { + if (expectedType) + { + const TypeId expectedTy = follow(*expectedType); + const TypeId singletonType = boolExpr->value ? singletonTypes->trueType : singletonTypes->falseType; + + if (get(expectedTy) || get(expectedTy)) + { + result = arena->addType(BlockedTypeVar{}); + addConstraint(scope, expr->location, PrimitiveTypeConstraint{result, expectedTy, singletonType, singletonTypes->booleanType}); + } + else if (maybeSingleton(expectedTy)) + result = singletonType; + else + result = singletonTypes->booleanType; + } + else + result = singletonTypes->booleanType; + } else if (expr->is()) result = singletonTypes->nilType; else if (auto a = expr->as()) @@ -864,13 +956,13 @@ TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExpr* expr) else if (auto indexExpr = expr->as()) result = check(scope, indexExpr); else if (auto table = expr->as()) - result = checkExprTable(scope, table); + result = check(scope, table, expectedType); else if (auto unary = expr->as()) result = check(scope, unary); else if (auto binary = expr->as()) result = check(scope, binary); else if (auto ifElse = expr->as()) - result = check(scope, ifElse); + result = check(scope, ifElse, expectedType); else if (auto typeAssert = expr->as()) result = check(scope, typeAssert); else if (auto err = expr->as()) @@ -924,20 +1016,9 @@ TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprUnary* unary) { TypeId operandType = check(scope, unary->expr); - switch (unary->op) - { - case AstExprUnary::Minus: - { - TypeId resultType = arena->addType(BlockedTypeVar{}); - addConstraint(scope, unary->location, UnaryConstraint{AstExprUnary::Minus, operandType, resultType}); - return resultType; - } - default: - LUAU_ASSERT(0); - } - - LUAU_UNREACHABLE(); - return singletonTypes->errorRecoveryType(); + TypeId resultType = arena->addType(BlockedTypeVar{}); + addConstraint(scope, unary->location, UnaryConstraint{unary->op, operandType, resultType}); + return resultType; } TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprBinary* binary) @@ -946,22 +1027,34 @@ TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprBinary* binar TypeId rightType = check(scope, binary->right); switch (binary->op) { + case AstExprBinary::And: case AstExprBinary::Or: { addConstraint(scope, binary->location, SubtypeConstraint{leftType, rightType}); return leftType; } case AstExprBinary::Add: + case AstExprBinary::Sub: + case AstExprBinary::Mul: + case AstExprBinary::Div: + case AstExprBinary::Mod: + case AstExprBinary::Pow: + case AstExprBinary::CompareNe: + case AstExprBinary::CompareEq: + case AstExprBinary::CompareLt: + case AstExprBinary::CompareLe: + case AstExprBinary::CompareGt: + case AstExprBinary::CompareGe: { TypeId resultType = arena->addType(BlockedTypeVar{}); - addConstraint(scope, binary->location, BinaryConstraint{AstExprBinary::Add, leftType, rightType, resultType}); + addConstraint(scope, binary->location, BinaryConstraint{binary->op, leftType, rightType, resultType}); return resultType; } - case AstExprBinary::Sub: + case AstExprBinary::Concat: { - TypeId resultType = arena->addType(BlockedTypeVar{}); - addConstraint(scope, binary->location, BinaryConstraint{AstExprBinary::Sub, leftType, rightType, resultType}); - return resultType; + addConstraint(scope, binary->left->location, SubtypeConstraint{leftType, singletonTypes->stringType}); + addConstraint(scope, binary->right->location, SubtypeConstraint{rightType, singletonTypes->stringType}); + return singletonTypes->stringType; } default: LUAU_ASSERT(0); @@ -971,16 +1064,16 @@ TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprBinary* binar return nullptr; } -TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIfElse* ifElse) +TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIfElse* ifElse, std::optional expectedType) { check(scope, ifElse->condition); - TypeId thenType = check(scope, ifElse->trueExpr); - TypeId elseType = check(scope, ifElse->falseExpr); + TypeId thenType = check(scope, ifElse->trueExpr, expectedType); + TypeId elseType = check(scope, ifElse->falseExpr, expectedType); if (ifElse->hasElse) { - TypeId resultType = arena->addType(BlockedTypeVar{}); + TypeId resultType = expectedType ? *expectedType : freshType(scope); addConstraint(scope, ifElse->trueExpr->location, SubtypeConstraint{thenType, resultType}); addConstraint(scope, ifElse->falseExpr->location, SubtypeConstraint{elseType, resultType}); return resultType; @@ -995,7 +1088,7 @@ TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprTypeAssertion return resolveType(scope, typeAssert->annotation); } -TypeId ConstraintGraphBuilder::checkExprTable(const ScopePtr& scope, AstExprTable* expr) +TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprTable* expr, std::optional expectedType) { TypeId ty = arena->addType(TableTypeVar{}); TableTypeVar* ttv = getMutable(ty); @@ -1015,7 +1108,18 @@ TypeId ConstraintGraphBuilder::checkExprTable(const ScopePtr& scope, AstExprTabl for (const AstExprTable::Item& item : expr->items) { - TypeId itemTy = check(scope, item.value); + std::optional expectedValueType; + + if (item.key && expectedType) + { + if (auto stringKey = item.key->as()) + { + expectedValueType = arena->addType(BlockedTypeVar{}); + addConstraint(scope, item.value->location, HasPropConstraint{*expectedValueType, *expectedType, stringKey->value.data}); + } + } + + TypeId itemTy = check(scope, item.value, expectedValueType); if (get(follow(itemTy))) return ty; @@ -1130,7 +1234,12 @@ ConstraintGraphBuilder::FunctionSignature ConstraintGraphBuilder::checkFunctionS if (fn->returnAnnotation) { TypePackId annotatedRetType = resolveTypePack(signatureScope, *fn->returnAnnotation); - addConstraint(signatureScope, getLocation(*fn->returnAnnotation), PackSubtypeConstraint{returnType, annotatedRetType}); + + // We bind the annotated type directly here so that, when we need to + // generate constraints for return types, we have a guarantee that we + // know the annotated return type already, if one was provided. + LUAU_ASSERT(get(returnType)); + asMutable(returnType)->ty.emplace(annotatedRetType); } std::vector argTypes; diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index 6fb57b15e..b2bf773f4 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -396,8 +396,12 @@ bool ConstraintSolver::tryDispatch(NotNull constraint, bool fo success = tryDispatch(*taec, constraint); else if (auto fcc = get(*constraint)) success = tryDispatch(*fcc, constraint); + else if (auto fcc = get(*constraint)) + success = tryDispatch(*fcc, constraint); + else if (auto hpc = get(*constraint)) + success = tryDispatch(*hpc, constraint); else - LUAU_ASSERT(0); + LUAU_ASSERT(false); if (success) { @@ -409,6 +413,11 @@ bool ConstraintSolver::tryDispatch(NotNull constraint, bool fo bool ConstraintSolver::tryDispatch(const SubtypeConstraint& c, NotNull constraint, bool force) { + if (!recursiveBlock(c.subType, constraint)) + return false; + if (!recursiveBlock(c.superType, constraint)) + return false; + if (isBlocked(c.subType)) return block(c.subType, constraint); else if (isBlocked(c.superType)) @@ -421,6 +430,9 @@ bool ConstraintSolver::tryDispatch(const SubtypeConstraint& c, NotNull constraint, bool force) { + if (!recursiveBlock(c.subPack, constraint) || !recursiveBlock(c.superPack, constraint)) + return false; + if (isBlocked(c.subPack)) return block(c.subPack, constraint); else if (isBlocked(c.superPack)) @@ -480,13 +492,30 @@ bool ConstraintSolver::tryDispatch(const UnaryConstraint& c, NotNull(c.resultType)); - if (isNumber(operandType) || get(operandType) || get(operandType)) + switch (c.op) { - asMutable(c.resultType)->ty.emplace(c.operandType); - return true; + case AstExprUnary::Not: + { + asMutable(c.resultType)->ty.emplace(singletonTypes->booleanType); + return true; + } + case AstExprUnary::Len: + { + asMutable(c.resultType)->ty.emplace(singletonTypes->numberType); + return true; + } + case AstExprUnary::Minus: + { + if (isNumber(operandType) || get(operandType) || get(operandType)) + { + asMutable(c.resultType)->ty.emplace(c.operandType); + return true; + } + break; + } } - LUAU_ASSERT(0); // TODO metatable handling + LUAU_ASSERT(false); // TODO metatable handling return false; } @@ -906,6 +935,91 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNull constraint) +{ + TypeId expectedType = follow(c.expectedType); + if (isBlocked(expectedType) || get(expectedType)) + return block(expectedType, constraint); + + TypeId bindTo = maybeSingleton(expectedType) ? c.singletonType : c.multitonType; + asMutable(c.resultType)->ty.emplace(bindTo); + + return true; +} + +bool ConstraintSolver::tryDispatch(const HasPropConstraint& c, NotNull constraint) +{ + TypeId subjectType = follow(c.subjectType); + + if (isBlocked(subjectType) || get(subjectType)) + return block(subjectType, constraint); + + TypeId resultType = nullptr; + + auto collectParts = [&](auto&& unionOrIntersection) -> std::pair> { + bool blocked = false; + + std::vector parts; + for (TypeId expectedPart : unionOrIntersection) + { + expectedPart = follow(expectedPart); + if (isBlocked(expectedPart) || get(expectedPart)) + { + blocked = true; + block(expectedPart, constraint); + } + else if (const TableTypeVar* ttv = get(follow(expectedPart))) + { + if (auto prop = ttv->props.find(c.prop); prop != ttv->props.end()) + parts.push_back(prop->second.type); + else if (ttv->indexer && maybeString(ttv->indexer->indexType)) + parts.push_back(ttv->indexer->indexResultType); + } + } + + return {blocked, parts}; + }; + + if (auto ttv = get(subjectType)) + { + if (auto prop = ttv->props.find(c.prop); prop != ttv->props.end()) + resultType = prop->second.type; + else if (ttv->indexer && maybeString(ttv->indexer->indexType)) + resultType = ttv->indexer->indexResultType; + } + else if (auto utv = get(subjectType)) + { + auto [blocked, parts] = collectParts(utv); + + if (blocked) + return false; + else if (parts.size() == 1) + resultType = parts[0]; + else if (parts.size() > 1) + resultType = arena->addType(UnionTypeVar{std::move(parts)}); + else + LUAU_ASSERT(false); // parts.size() == 0 + } + else if (auto itv = get(subjectType)) + { + auto [blocked, parts] = collectParts(itv); + + if (blocked) + return false; + else if (parts.size() == 1) + resultType = parts[0]; + else if (parts.size() > 1) + resultType = arena->addType(IntersectionTypeVar{std::move(parts)}); + else + LUAU_ASSERT(false); // parts.size() == 0 + } + + if (resultType) + asMutable(c.resultType)->ty.emplace(resultType); + + return true; +} + bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const IterableConstraint& c, NotNull constraint, bool force) { auto block_ = [&](auto&& t) { @@ -914,7 +1028,7 @@ bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const Iterabl // TODO: I believe it is the case that, if we are asked to force // this constraint, then we can do nothing but fail. I'd like to // find a code sample that gets here. - LUAU_ASSERT(0); + LUAU_ASSERT(false); } else block(t, constraint); @@ -979,7 +1093,7 @@ bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const Iterabl if (get(metaTy)) return block_(metaTy); - LUAU_ASSERT(0); + LUAU_ASSERT(false); } else errorify(c.variables); @@ -996,7 +1110,7 @@ bool ConstraintSolver::tryDispatchIterableFunction( if (get(firstIndexTy)) { if (force) - LUAU_ASSERT(0); + LUAU_ASSERT(false); else block(firstIndexTy, constraint); return false; @@ -1061,6 +1175,48 @@ bool ConstraintSolver::block(TypePackId target, NotNull constr return false; } +struct Blocker : TypeVarOnceVisitor +{ + NotNull solver; + NotNull constraint; + + bool blocked = false; + + explicit Blocker(NotNull solver, NotNull constraint) + : solver(solver) + , constraint(constraint) + { + } + + bool visit(TypeId ty, const BlockedTypeVar&) + { + blocked = true; + solver->block(ty, constraint); + return false; + } + + bool visit(TypeId ty, const PendingExpansionTypeVar&) + { + blocked = true; + solver->block(ty, constraint); + return false; + } +}; + +bool ConstraintSolver::recursiveBlock(TypeId target, NotNull constraint) +{ + Blocker blocker{NotNull{this}, constraint}; + blocker.traverse(target); + return !blocker.blocked; +} + +bool ConstraintSolver::recursiveBlock(TypePackId pack, NotNull constraint) +{ + Blocker blocker{NotNull{this}, constraint}; + blocker.traverse(pack); + return !blocker.blocked; +} + void ConstraintSolver::unblock_(BlockedConstraintId progressed) { auto it = blocked.find(progressed); diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp index c3f0bb9d6..42f615172 100644 --- a/Analysis/src/Normalize.cpp +++ b/Analysis/src/Normalize.cpp @@ -5,6 +5,7 @@ #include #include "Luau/Clone.h" +#include "Luau/Common.h" #include "Luau/Unifier.h" #include "Luau/VisitTypeVar.h" @@ -13,8 +14,8 @@ LUAU_FASTFLAGVARIABLE(DebugLuauCopyBeforeNormalizing, false) // This could theoretically be 2000 on amd64, but x86 requires this. LUAU_FASTINTVARIABLE(LuauNormalizeIterationLimit, 1200); LUAU_FASTFLAGVARIABLE(LuauNormalizeCombineTableFix, false); -LUAU_FASTFLAGVARIABLE(LuauFixNormalizationOfCyclicUnions, false); LUAU_FASTFLAG(LuauUnknownAndNeverType) +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) namespace Luau { @@ -54,24 +55,24 @@ struct Replacer } // anonymous namespace -bool isSubtype(TypeId subTy, TypeId superTy, NotNull scope, NotNull singletonTypes, InternalErrorReporter& ice) +bool isSubtype(TypeId subTy, TypeId superTy, NotNull scope, NotNull singletonTypes, InternalErrorReporter& ice, bool anyIsTop) { UnifierSharedState sharedState{&ice}; TypeArena arena; Unifier u{&arena, singletonTypes, Mode::Strict, scope, Location{}, Covariant, sharedState}; - u.anyIsTop = true; + u.anyIsTop = anyIsTop; u.tryUnify(subTy, superTy); const bool ok = u.errors.empty() && u.log.empty(); return ok; } -bool isSubtype(TypePackId subPack, TypePackId superPack, NotNull scope, NotNull singletonTypes, InternalErrorReporter& ice) +bool isSubtype(TypePackId subPack, TypePackId superPack, NotNull scope, NotNull singletonTypes, InternalErrorReporter& ice, bool anyIsTop) { UnifierSharedState sharedState{&ice}; TypeArena arena; Unifier u{&arena, singletonTypes, Mode::Strict, scope, Location{}, Covariant, sharedState}; - u.anyIsTop = true; + u.anyIsTop = anyIsTop; u.tryUnify(subPack, superPack); const bool ok = u.errors.empty() && u.log.empty(); @@ -319,18 +320,11 @@ struct Normalize final : TypeVarVisitor UnionTypeVar* utv = &const_cast(utvRef); - // TODO: Clip tempOptions and optionsRef when clipping FFlag::LuauFixNormalizationOfCyclicUnions - std::vector tempOptions; - if (!FFlag::LuauFixNormalizationOfCyclicUnions) - tempOptions = std::move(utv->options); - - std::vector& optionsRef = FFlag::LuauFixNormalizationOfCyclicUnions ? utv->options : tempOptions; - // We might transmute, so it's not safe to rely on the builtin traversal logic of visitTypeVar - for (TypeId option : optionsRef) + for (TypeId option : utv->options) traverse(option); - std::vector newOptions = normalizeUnion(optionsRef); + std::vector newOptions = normalizeUnion(utv->options); const bool normal = areNormal(newOptions, seen, ice); @@ -355,106 +349,54 @@ struct Normalize final : TypeVarVisitor IntersectionTypeVar* itv = &const_cast(itvRef); - if (FFlag::LuauFixNormalizationOfCyclicUnions) - { - std::vector oldParts = itv->parts; - IntersectionTypeVar newIntersection; - - for (TypeId part : oldParts) - traverse(part); - - std::vector tables; - for (TypeId part : oldParts) - { - part = follow(part); - if (get(part)) - tables.push_back(part); - else - { - Replacer replacer{&arena, nullptr, nullptr}; // FIXME this is super super WEIRD - combineIntoIntersection(replacer, &newIntersection, part); - } - } - - // Don't allocate a new table if there's just one in the intersection. - if (tables.size() == 1) - newIntersection.parts.push_back(tables[0]); - else if (!tables.empty()) - { - const TableTypeVar* first = get(tables[0]); - LUAU_ASSERT(first); + std::vector oldParts = itv->parts; + IntersectionTypeVar newIntersection; - TypeId newTable = arena.addType(TableTypeVar{first->state, first->level}); - TableTypeVar* ttv = getMutable(newTable); - for (TypeId part : tables) - { - // Intuition: If combineIntoTable() needs to clone a table, any references to 'part' are cyclic and need - // to be rewritten to point at 'newTable' in the clone. - Replacer replacer{&arena, part, newTable}; - combineIntoTable(replacer, ttv, part); - } - - newIntersection.parts.push_back(newTable); - } - - itv->parts = std::move(newIntersection.parts); - - asMutable(ty)->normal = areNormal(itv->parts, seen, ice); + for (TypeId part : oldParts) + traverse(part); - if (itv->parts.size() == 1) + std::vector tables; + for (TypeId part : oldParts) + { + part = follow(part); + if (get(part)) + tables.push_back(part); + else { - TypeId part = itv->parts[0]; - *asMutable(ty) = BoundTypeVar{part}; + Replacer replacer{&arena, nullptr, nullptr}; // FIXME this is super super WEIRD + combineIntoIntersection(replacer, &newIntersection, part); } } - else - { - std::vector oldParts = std::move(itv->parts); - for (TypeId part : oldParts) - traverse(part); + // Don't allocate a new table if there's just one in the intersection. + if (tables.size() == 1) + newIntersection.parts.push_back(tables[0]); + else if (!tables.empty()) + { + const TableTypeVar* first = get(tables[0]); + LUAU_ASSERT(first); - std::vector tables; - for (TypeId part : oldParts) + TypeId newTable = arena.addType(TableTypeVar{first->state, first->level}); + TableTypeVar* ttv = getMutable(newTable); + for (TypeId part : tables) { - part = follow(part); - if (get(part)) - tables.push_back(part); - else - { - Replacer replacer{&arena, nullptr, nullptr}; // FIXME this is super super WEIRD - combineIntoIntersection(replacer, itv, part); - } + // Intuition: If combineIntoTable() needs to clone a table, any references to 'part' are cyclic and need + // to be rewritten to point at 'newTable' in the clone. + Replacer replacer{&arena, part, newTable}; + combineIntoTable(replacer, ttv, part); } - // Don't allocate a new table if there's just one in the intersection. - if (tables.size() == 1) - itv->parts.push_back(tables[0]); - else if (!tables.empty()) - { - const TableTypeVar* first = get(tables[0]); - LUAU_ASSERT(first); - - TypeId newTable = arena.addType(TableTypeVar{first->state, first->level}); - TableTypeVar* ttv = getMutable(newTable); - for (TypeId part : tables) - { - // Intuition: If combineIntoTable() needs to clone a table, any references to 'part' are cyclic and need - // to be rewritten to point at 'newTable' in the clone. - Replacer replacer{&arena, part, newTable}; - combineIntoTable(replacer, ttv, part); - } + newIntersection.parts.push_back(newTable); + } - itv->parts.push_back(newTable); - } + itv->parts = std::move(newIntersection.parts); - asMutable(ty)->normal = areNormal(itv->parts, seen, ice); + asMutable(ty)->normal = areNormal(itv->parts, seen, ice); - if (itv->parts.size() == 1) - { - TypeId part = itv->parts[0]; - *asMutable(ty) = BoundTypeVar{part}; - } + if (itv->parts.size() == 1) + { + TypeId part = itv->parts[0]; + *asMutable(ty) = BoundTypeVar{part}; } return false; @@ -629,21 +571,18 @@ struct Normalize final : TypeVarVisitor table->props.insert({propName, prop}); } - if (FFlag::LuauFixNormalizationOfCyclicUnions) + if (tyTable->indexer) { - if (tyTable->indexer) + if (table->indexer) { - if (table->indexer) - { - table->indexer->indexType = combine(replacer, replacer.smartClone(tyTable->indexer->indexType), table->indexer->indexType); - table->indexer->indexResultType = - combine(replacer, replacer.smartClone(tyTable->indexer->indexResultType), table->indexer->indexResultType); - } - else - { - table->indexer = - TableIndexer{replacer.smartClone(tyTable->indexer->indexType), replacer.smartClone(tyTable->indexer->indexResultType)}; - } + table->indexer->indexType = combine(replacer, replacer.smartClone(tyTable->indexer->indexType), table->indexer->indexType); + table->indexer->indexResultType = + combine(replacer, replacer.smartClone(tyTable->indexer->indexResultType), table->indexer->indexResultType); + } + else + { + table->indexer = + TableIndexer{replacer.smartClone(tyTable->indexer->indexType), replacer.smartClone(tyTable->indexer->indexResultType)}; } } diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index 711d461fb..0b389547e 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -1502,6 +1502,15 @@ std::string toString(const Constraint& constraint, ToStringOptions& opts) { return "call " + tos(c.fn, opts) + " with { result = " + tos(c.result, opts) + " }"; } + else if constexpr (std::is_same_v) + { + return tos(c.resultType, opts) + " ~ prim " + tos(c.expectedType, opts) + ", " + tos(c.singletonType, opts) + ", " + + tos(c.multitonType, opts); + } + else if constexpr (std::is_same_v) + { + return tos(c.resultType, opts) + " ~ hasProp " + tos(c.subjectType, opts) + ", \"" + c.prop + "\""; + } else static_assert(always_false_v, "Non-exhaustive constraint switch"); }; diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index 76b27acdc..ea06882a5 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -307,10 +307,9 @@ struct TypeChecker2 if (var->annotation) { TypeId varType = lookupAnnotation(var->annotation); - if (!isSubtype(*it, varType, stack.back(), singletonTypes, ice)) - { - reportError(TypeMismatch{varType, *it}, value->location); - } + ErrorVec errors = tryUnify(stack.back(), value->location, *it, varType); + if (!errors.empty()) + reportErrors(std::move(errors)); } ++it; @@ -325,7 +324,7 @@ struct TypeChecker2 if (var->annotation) { TypeId varType = lookupAnnotation(var->annotation); - if (!isSubtype(varType, valueType, stack.back(), singletonTypes, ice)) + if (!isSubtype(varType, valueType, stack.back(), singletonTypes, ice, /* anyIsTop */ false)) { reportError(TypeMismatch{varType, valueType}, value->location); } @@ -540,7 +539,7 @@ struct TypeChecker2 visit(rhs); TypeId rhsType = lookupType(rhs); - if (!isSubtype(rhsType, lhsType, stack.back(), singletonTypes, ice)) + if (!isSubtype(rhsType, lhsType, stack.back(), singletonTypes, ice, /* anyIsTop */ false)) { reportError(TypeMismatch{lhsType, rhsType}, rhs->location); } @@ -691,7 +690,7 @@ struct TypeChecker2 TypeId actualType = lookupType(number); TypeId numberType = singletonTypes->numberType; - if (!isSubtype(numberType, actualType, stack.back(), singletonTypes, ice)) + if (!isSubtype(numberType, actualType, stack.back(), singletonTypes, ice, /* anyIsTop */ false)) { reportError(TypeMismatch{actualType, numberType}, number->location); } @@ -702,7 +701,7 @@ struct TypeChecker2 TypeId actualType = lookupType(string); TypeId stringType = singletonTypes->stringType; - if (!isSubtype(stringType, actualType, stack.back(), singletonTypes, ice)) + if (!isSubtype(stringType, actualType, stack.back(), singletonTypes, ice, /* anyIsTop */ false)) { reportError(TypeMismatch{actualType, stringType}, string->location); } @@ -762,7 +761,7 @@ struct TypeChecker2 FunctionTypeVar ftv{argsTp, expectedRetType}; TypeId expectedType = arena.addType(ftv); - if (!isSubtype(expectedType, instantiatedFunctionType, stack.back(), singletonTypes, ice)) + if (!isSubtype(instantiatedFunctionType, expectedType, stack.back(), singletonTypes, ice, /* anyIsTop */ false)) { CloneState cloneState; expectedType = clone(expectedType, module->internalTypes, cloneState); @@ -781,7 +780,7 @@ struct TypeChecker2 getIndexTypeFromType(module->getModuleScope(), leftType, indexName->index.value, indexName->location, /* addErrors */ true); if (ty) { - if (!isSubtype(resultType, *ty, stack.back(), singletonTypes, ice)) + if (!isSubtype(resultType, *ty, stack.back(), singletonTypes, ice, /* anyIsTop */ false)) { reportError(TypeMismatch{resultType, *ty}, indexName->location); } @@ -814,7 +813,7 @@ struct TypeChecker2 TypeId inferredArgTy = *argIt; TypeId annotatedArgTy = lookupAnnotation(arg->annotation); - if (!isSubtype(annotatedArgTy, inferredArgTy, stack.back(), singletonTypes, ice)) + if (!isSubtype(annotatedArgTy, inferredArgTy, stack.back(), singletonTypes, ice, /* anyIsTop */ false)) { reportError(TypeMismatch{annotatedArgTy, inferredArgTy}, arg->location); } @@ -859,10 +858,10 @@ struct TypeChecker2 TypeId computedType = lookupType(expr->expr); // Note: As an optimization, we try 'number <: number | string' first, as that is the more likely case. - if (isSubtype(annotationType, computedType, stack.back(), singletonTypes, ice)) + if (isSubtype(annotationType, computedType, stack.back(), singletonTypes, ice, /* anyIsTop */ false)) return; - if (isSubtype(computedType, annotationType, stack.back(), singletonTypes, ice)) + if (isSubtype(computedType, annotationType, stack.back(), singletonTypes, ice, /* anyIsTop */ false)) return; reportError(TypesAreUnrelated{computedType, annotationType}, expr->location); diff --git a/Ast/include/Luau/DenseHash.h b/Ast/include/Luau/DenseHash.h index f85431116..b222feb01 100644 --- a/Ast/include/Luau/DenseHash.h +++ b/Ast/include/Luau/DenseHash.h @@ -5,7 +5,6 @@ #include #include -#include #include #include @@ -35,30 +34,125 @@ class DenseHashTable class iterator; DenseHashTable(const Key& empty_key, size_t buckets = 0) - : count(0) + : data(nullptr) + , capacity(0) + , count(0) , empty_key(empty_key) { + // validate that equality operator is at least somewhat functional + LUAU_ASSERT(eq(empty_key, empty_key)); // buckets has to be power-of-two or zero LUAU_ASSERT((buckets & (buckets - 1)) == 0); - // don't move this to initializer list! this works around an MSVC codegen issue on AMD CPUs: - // https://developercommunity.visualstudio.com/t/stdvector-constructor-from-size-t-is-25-times-slow/1546547 if (buckets) - resize_data(buckets); + { + data = static_cast(::operator new(sizeof(Item) * buckets)); + capacity = buckets; + + ItemInterface::fill(data, buckets, empty_key); + } + } + + ~DenseHashTable() + { + if (data) + destroy(); + } + + DenseHashTable(const DenseHashTable& other) + : data(nullptr) + , capacity(0) + , count(other.count) + , empty_key(other.empty_key) + { + if (other.capacity) + { + data = static_cast(::operator new(sizeof(Item) * other.capacity)); + + for (size_t i = 0; i < other.capacity; ++i) + { + new (&data[i]) Item(other.data[i]); + capacity = i + 1; // if Item copy throws, capacity will note the number of initialized objects for destroy() to clean up + } + } + } + + DenseHashTable(DenseHashTable&& other) + : data(other.data) + , capacity(other.capacity) + , count(other.count) + , empty_key(other.empty_key) + { + other.data = nullptr; + other.capacity = 0; + other.count = 0; + } + + DenseHashTable& operator=(DenseHashTable&& other) + { + if (this != &other) + { + if (data) + destroy(); + + data = other.data; + capacity = other.capacity; + count = other.count; + empty_key = other.empty_key; + + other.data = nullptr; + other.capacity = 0; + other.count = 0; + } + + return *this; + } + + DenseHashTable& operator=(const DenseHashTable& other) + { + if (this != &other) + { + DenseHashTable copy(other); + *this = std::move(copy); + } + + return *this; } void clear() { - data.clear(); + if (count == 0) + return; + + if (capacity > 32) + { + destroy(); + } + else + { + ItemInterface::destroy(data, capacity); + ItemInterface::fill(data, capacity, empty_key); + } + count = 0; } + void destroy() + { + ItemInterface::destroy(data, capacity); + + ::operator delete(data); + data = nullptr; + + capacity = 0; + } + Item* insert_unsafe(const Key& key) { // It is invalid to insert empty_key into the table since it acts as a "entry does not exist" marker LUAU_ASSERT(!eq(key, empty_key)); - size_t hashmod = data.size() - 1; + size_t hashmod = capacity - 1; size_t bucket = hasher(key) & hashmod; for (size_t probe = 0; probe <= hashmod; ++probe) @@ -90,12 +184,12 @@ class DenseHashTable const Item* find(const Key& key) const { - if (data.empty()) + if (count == 0) return 0; if (eq(key, empty_key)) return 0; - size_t hashmod = data.size() - 1; + size_t hashmod = capacity - 1; size_t bucket = hasher(key) & hashmod; for (size_t probe = 0; probe <= hashmod; ++probe) @@ -121,18 +215,11 @@ class DenseHashTable void rehash() { - size_t newsize = data.empty() ? 16 : data.size() * 2; - - if (data.empty() && data.capacity() >= newsize) - { - LUAU_ASSERT(count == 0); - resize_data(newsize); - return; - } + size_t newsize = capacity == 0 ? 16 : capacity * 2; DenseHashTable newtable(empty_key, newsize); - for (size_t i = 0; i < data.size(); ++i) + for (size_t i = 0; i < capacity; ++i) { const Key& key = ItemInterface::getKey(data[i]); @@ -144,12 +231,14 @@ class DenseHashTable } LUAU_ASSERT(count == newtable.count); - data.swap(newtable.data); + + std::swap(data, newtable.data); + std::swap(capacity, newtable.capacity); } void rehash_if_full() { - if (count >= data.size() * 3 / 4) + if (count >= capacity * 3 / 4) { rehash(); } @@ -159,7 +248,7 @@ class DenseHashTable { size_t start = 0; - while (start < data.size() && eq(ItemInterface::getKey(data[start]), empty_key)) + while (start < capacity && eq(ItemInterface::getKey(data[start]), empty_key)) start++; return const_iterator(this, start); @@ -167,14 +256,14 @@ class DenseHashTable const_iterator end() const { - return const_iterator(this, data.size()); + return const_iterator(this, capacity); } iterator begin() { size_t start = 0; - while (start < data.size() && eq(ItemInterface::getKey(data[start]), empty_key)) + while (start < capacity && eq(ItemInterface::getKey(data[start]), empty_key)) start++; return iterator(this, start); @@ -182,7 +271,7 @@ class DenseHashTable iterator end() { - return iterator(this, data.size()); + return iterator(this, capacity); } size_t size() const @@ -227,7 +316,7 @@ class DenseHashTable const_iterator& operator++() { - size_t size = set->data.size(); + size_t size = set->capacity; do { @@ -286,7 +375,7 @@ class DenseHashTable iterator& operator++() { - size_t size = set->data.size(); + size_t size = set->capacity; do { @@ -309,23 +398,8 @@ class DenseHashTable }; private: - template - void resize_data(size_t count, typename std::enable_if_t>* dummy = nullptr) - { - data.resize(count, ItemInterface::create(empty_key)); - } - - template - void resize_data(size_t count, typename std::enable_if_t>* dummy = nullptr) - { - size_t size = data.size(); - data.resize(count); - - for (size_t i = size; i < count; i++) - data[i].first = empty_key; - } - - std::vector data; + Item* data; + size_t capacity; size_t count; Key empty_key; Hash hasher; @@ -345,9 +419,16 @@ struct ItemInterfaceSet item = key; } - static Key create(const Key& key) + static void fill(Key* data, size_t count, const Key& key) + { + for (size_t i = 0; i < count; ++i) + new (&data[i]) Key(key); + } + + static void destroy(Key* data, size_t count) { - return key; + for (size_t i = 0; i < count; ++i) + data[i].~Key(); } }; @@ -364,9 +445,22 @@ struct ItemInterfaceMap item.first = key; } - static std::pair create(const Key& key) + static void fill(std::pair* data, size_t count, const Key& key) + { + for (size_t i = 0; i < count; ++i) + { + new (&data[i].first) Key(key); + new (&data[i].second) Value(); + } + } + + static void destroy(std::pair* data, size_t count) { - return std::pair(key, Value()); + for (size_t i = 0; i < count; ++i) + { + data[i].first.~Key(); + data[i].second.~Value(); + } } }; diff --git a/Ast/include/Luau/Lexer.h b/Ast/include/Luau/Lexer.h index 7e7fe76ba..929402b33 100644 --- a/Ast/include/Luau/Lexer.h +++ b/Ast/include/Luau/Lexer.h @@ -6,6 +6,8 @@ #include "Luau/DenseHash.h" #include "Luau/Common.h" +#include + namespace Luau { diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 0914054f9..cf3eaaaea 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -20,7 +20,6 @@ LUAU_DYNAMIC_FASTFLAGVARIABLE(LuaReportParseWrongNamedType, false) bool lua_telemetry_parsed_named_non_function_type = false; LUAU_FASTFLAGVARIABLE(LuauErrorDoubleHexPrefix, false) -LUAU_FASTFLAGVARIABLE(LuauLintParseIntegerIssues, false) LUAU_DYNAMIC_FASTFLAGVARIABLE(LuaReportParseIntegerIssues, false) LUAU_FASTFLAGVARIABLE(LuauInterpolatedStringBaseSupport, false) @@ -2070,95 +2069,8 @@ AstExpr* Parser::parseAssertionExpr() return expr; } -static const char* parseInteger_DEPRECATED(double& result, const char* data, int base) -{ - LUAU_ASSERT(!FFlag::LuauLintParseIntegerIssues); - - char* end = nullptr; - unsigned long long value = strtoull(data, &end, base); - - if (value == ULLONG_MAX && errno == ERANGE) - { - // 'errno' might have been set before we called 'strtoull', but we don't want the overhead of resetting a TLS variable on each call - // so we only reset it when we get a result that might be an out-of-range error and parse again to make sure - errno = 0; - value = strtoull(data, &end, base); - - if (errno == ERANGE) - { - if (DFFlag::LuaReportParseIntegerIssues) - { - if (base == 2) - lua_telemetry_parsed_out_of_range_bin_integer = true; - else - lua_telemetry_parsed_out_of_range_hex_integer = true; - } - } - } - - result = double(value); - return *end == 0 ? nullptr : "Malformed number"; -} - -static const char* parseNumber_DEPRECATED2(double& result, const char* data) -{ - LUAU_ASSERT(!FFlag::LuauLintParseIntegerIssues); - - // binary literal - if (data[0] == '0' && (data[1] == 'b' || data[1] == 'B') && data[2]) - return parseInteger_DEPRECATED(result, data + 2, 2); - - // hexadecimal literal - if (data[0] == '0' && (data[1] == 'x' || data[1] == 'X') && data[2]) - { - if (DFFlag::LuaReportParseIntegerIssues && data[2] == '0' && (data[3] == 'x' || data[3] == 'X')) - lua_telemetry_parsed_double_prefix_hex_integer = true; - - return parseInteger_DEPRECATED(result, data + 2, 16); - } - - char* end = nullptr; - double value = strtod(data, &end); - - result = value; - return *end == 0 ? nullptr : "Malformed number"; -} - -static bool parseNumber_DEPRECATED(double& result, const char* data) -{ - LUAU_ASSERT(!FFlag::LuauLintParseIntegerIssues); - - // binary literal - if (data[0] == '0' && (data[1] == 'b' || data[1] == 'B') && data[2]) - { - char* end = nullptr; - unsigned long long value = strtoull(data + 2, &end, 2); - - result = double(value); - return *end == 0; - } - // hexadecimal literal - else if (data[0] == '0' && (data[1] == 'x' || data[1] == 'X') && data[2]) - { - char* end = nullptr; - unsigned long long value = strtoull(data + 2, &end, 16); - - result = double(value); - return *end == 0; - } - else - { - char* end = nullptr; - double value = strtod(data, &end); - - result = value; - return *end == 0; - } -} - static ConstantNumberParseResult parseInteger(double& result, const char* data, int base) { - LUAU_ASSERT(FFlag::LuauLintParseIntegerIssues); LUAU_ASSERT(base == 2 || base == 16); char* end = nullptr; @@ -2195,8 +2107,6 @@ static ConstantNumberParseResult parseInteger(double& result, const char* data, static ConstantNumberParseResult parseDouble(double& result, const char* data) { - LUAU_ASSERT(FFlag::LuauLintParseIntegerIssues); - // binary literal if (data[0] == '0' && (data[1] == 'b' || data[1] == 'B') && data[2]) return parseInteger(result, data + 2, 2); @@ -2771,49 +2681,14 @@ AstExpr* Parser::parseNumber() scratchData.erase(std::remove(scratchData.begin(), scratchData.end(), '_'), scratchData.end()); } - if (FFlag::LuauLintParseIntegerIssues) - { - double value = 0; - ConstantNumberParseResult result = parseDouble(value, scratchData.c_str()); - nextLexeme(); - - if (result == ConstantNumberParseResult::Malformed) - return reportExprError(start, {}, "Malformed number"); - - return allocator.alloc(start, value, result); - } - else if (DFFlag::LuaReportParseIntegerIssues) - { - double value = 0; - if (const char* error = parseNumber_DEPRECATED2(value, scratchData.c_str())) - { - nextLexeme(); - - return reportExprError(start, {}, "%s", error); - } - else - { - nextLexeme(); - - return allocator.alloc(start, value); - } - } - else - { - double value = 0; - if (parseNumber_DEPRECATED(value, scratchData.c_str())) - { - nextLexeme(); + double value = 0; + ConstantNumberParseResult result = parseDouble(value, scratchData.c_str()); + nextLexeme(); - return allocator.alloc(start, value); - } - else - { - nextLexeme(); + if (result == ConstantNumberParseResult::Malformed) + return reportExprError(start, {}, "Malformed number"); - return reportExprError(start, {}, "Malformed number"); - } - } + return allocator.alloc(start, value, result); } AstLocal* Parser::pushLocal(const Binding& binding) diff --git a/CMakeLists.txt b/CMakeLists.txt index 3dafe5fec..9ad16e8d2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -70,6 +70,7 @@ target_link_libraries(Luau.Analysis PUBLIC Luau.Ast) target_compile_features(Luau.CodeGen PRIVATE cxx_std_17) target_include_directories(Luau.CodeGen PUBLIC CodeGen/include) +target_link_libraries(Luau.CodeGen PRIVATE Luau.VM Luau.VM.Internals) # Code generation needs VM internals target_link_libraries(Luau.CodeGen PUBLIC Luau.Common) target_compile_features(Luau.VM PRIVATE cxx_std_11) diff --git a/CodeGen/include/Luau/UnwindBuilderDwarf2.h b/CodeGen/include/Luau/UnwindBuilderDwarf2.h index 09c91d438..25dbc55ba 100644 --- a/CodeGen/include/Luau/UnwindBuilderDwarf2.h +++ b/CodeGen/include/Luau/UnwindBuilderDwarf2.h @@ -27,13 +27,13 @@ class UnwindBuilderDwarf2 : public UnwindBuilder private: static const unsigned kRawDataLimit = 128; - char rawData[kRawDataLimit]; - char* pos = rawData; + uint8_t rawData[kRawDataLimit]; + uint8_t* pos = rawData; uint32_t stackOffset = 0; // We will remember the FDE location to write some of the fields like entry length, function start and size later - char* fdeEntryStart = nullptr; + uint8_t* fdeEntryStart = nullptr; }; } // namespace CodeGen diff --git a/CodeGen/src/AssemblyBuilderX64.cpp b/CodeGen/src/AssemblyBuilderX64.cpp index 0fd103287..32325b0dc 100644 --- a/CodeGen/src/AssemblyBuilderX64.cpp +++ b/CodeGen/src/AssemblyBuilderX64.cpp @@ -1,6 +1,8 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/AssemblyBuilderX64.h" +#include "ByteUtils.h" + #include #include #include @@ -46,44 +48,6 @@ const unsigned AVX_F2 = 0b11; const unsigned kMaxAlign = 16; -// Utility functions to correctly write data on big endian machines -#if defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ -#include - -static void writeu32(uint8_t* target, uint32_t value) -{ - value = htole32(value); - memcpy(target, &value, sizeof(value)); -} - -static void writeu64(uint8_t* target, uint64_t value) -{ - value = htole64(value); - memcpy(target, &value, sizeof(value)); -} - -static void writef32(uint8_t* target, float value) -{ - static_assert(sizeof(float) == sizeof(uint32_t), "type size must match to reinterpret data"); - uint32_t data; - memcpy(&data, &value, sizeof(value)); - writeu32(target, data); -} - -static void writef64(uint8_t* target, double value) -{ - static_assert(sizeof(double) == sizeof(uint64_t), "type size must match to reinterpret data"); - uint64_t data; - memcpy(&data, &value, sizeof(value)); - writeu64(target, data); -} -#else -#define writeu32(target, value) memcpy(target, &value, sizeof(value)) -#define writeu64(target, value) memcpy(target, &value, sizeof(value)) -#define writef32(target, value) memcpy(target, &value, sizeof(value)) -#define writef64(target, value) memcpy(target, &value, sizeof(value)) -#endif - AssemblyBuilderX64::AssemblyBuilderX64(bool logText) : logText(logText) { @@ -1014,16 +978,14 @@ void AssemblyBuilderX64::placeImm32(int32_t imm) { uint8_t* pos = codePos; LUAU_ASSERT(pos + sizeof(imm) < codeEnd); - writeu32(pos, imm); - codePos = pos + sizeof(imm); + codePos = writeu32(pos, imm); } void AssemblyBuilderX64::placeImm64(int64_t imm) { uint8_t* pos = codePos; LUAU_ASSERT(pos + sizeof(imm) < codeEnd); - writeu64(pos, imm); - codePos = pos + sizeof(imm); + codePos = writeu64(pos, imm); } void AssemblyBuilderX64::placeLabel(Label& label) diff --git a/CodeGen/src/ByteUtils.h b/CodeGen/src/ByteUtils.h new file mode 100644 index 000000000..1e1e341d7 --- /dev/null +++ b/CodeGen/src/ByteUtils.h @@ -0,0 +1,78 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#if defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ +#include +#endif + +#include + +inline uint8_t* writeu8(uint8_t* target, uint8_t value) +{ + *target = value; + return target + sizeof(value); +} + +inline uint8_t* writeu32(uint8_t* target, uint32_t value) +{ +#if defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ + value = htole32(value); +#endif + + memcpy(target, &value, sizeof(value)); + return target + sizeof(value); +} + +inline uint8_t* writeu64(uint8_t* target, uint64_t value) +{ +#if defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ + value = htole64(value); +#endif + + memcpy(target, &value, sizeof(value)); + return target + sizeof(value); +} + +inline uint8_t* writeuleb128(uint8_t* target, uint64_t value) +{ + do + { + uint8_t byte = value & 0x7f; + value >>= 7; + + if (value) + byte |= 0x80; + + *target++ = byte; + } while (value); + + return target; +} + +inline uint8_t* writef32(uint8_t* target, float value) +{ +#if defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ + static_assert(sizeof(float) == sizeof(uint32_t), "type size must match to reinterpret data"); + uint32_t data; + memcpy(&data, &value, sizeof(value)); + writeu32(target, data); +#else + memcpy(target, &value, sizeof(value)); +#endif + + return target + sizeof(value); +} + +inline uint8_t* writef64(uint8_t* target, double value) +{ +#if defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ + static_assert(sizeof(double) == sizeof(uint64_t), "type size must match to reinterpret data"); + uint64_t data; + memcpy(&data, &value, sizeof(value)); + writeu64(target, data); +#else + memcpy(target, &value, sizeof(value)); +#endif + + return target + sizeof(value); +} diff --git a/CodeGen/src/Fallbacks.cpp b/CodeGen/src/Fallbacks.cpp new file mode 100644 index 000000000..3893d349d --- /dev/null +++ b/CodeGen/src/Fallbacks.cpp @@ -0,0 +1,2511 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +// This code is based on Lua 5.x implementation licensed under MIT License; see lua_LICENSE.txt for details +// This file was generated by 'tools/lvmexecute_split.py' script, do not modify it by hand +#include "Fallbacks.h" +#include "FallbacksProlog.h" + +const Instruction* execute_LOP_NOP(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k) +{ + Instruction insn = *pc++; + LUAU_ASSERT(insn == 0); + return pc; +} + +const Instruction* execute_LOP_LOADNIL(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k) +{ + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + + setnilvalue(ra); + return pc; +} + +const Instruction* execute_LOP_LOADB(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k) +{ + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + + setbvalue(ra, LUAU_INSN_B(insn)); + + pc += LUAU_INSN_C(insn); + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + return pc; +} + +const Instruction* execute_LOP_LOADN(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k) +{ + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + + setnvalue(ra, LUAU_INSN_D(insn)); + return pc; +} + +const Instruction* execute_LOP_LOADK(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k) +{ + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + TValue* kv = VM_KV(LUAU_INSN_D(insn)); + + setobj2s(L, ra, kv); + return pc; +} + +const Instruction* execute_LOP_MOVE(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k) +{ + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + StkId rb = VM_REG(LUAU_INSN_B(insn)); + + setobj2s(L, ra, rb); + return pc; +} + +const Instruction* execute_LOP_GETGLOBAL(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k) +{ + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + uint32_t aux = *pc++; + TValue* kv = VM_KV(aux); + LUAU_ASSERT(ttisstring(kv)); + + // fast-path: value is in expected slot + Table* h = cl->env; + int slot = LUAU_INSN_C(insn) & h->nodemask8; + LuaNode* n = &h->node[slot]; + + if (LUAU_LIKELY(ttisstring(gkey(n)) && tsvalue(gkey(n)) == tsvalue(kv)) && !ttisnil(gval(n))) + { + setobj2s(L, ra, gval(n)); + return pc; + } + else + { + // slow-path, may invoke Lua calls via __index metamethod + TValue g; + sethvalue(L, &g, h); + L->cachedslot = slot; + VM_PROTECT(luaV_gettable(L, &g, kv, ra)); + // save cachedslot to accelerate future lookups; patches currently executing instruction since pc-2 rolls back two pc++ + VM_PATCH_C(pc - 2, L->cachedslot); + return pc; + } +} + +const Instruction* execute_LOP_SETGLOBAL(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k) +{ + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + uint32_t aux = *pc++; + TValue* kv = VM_KV(aux); + LUAU_ASSERT(ttisstring(kv)); + + // fast-path: value is in expected slot + Table* h = cl->env; + int slot = LUAU_INSN_C(insn) & h->nodemask8; + LuaNode* n = &h->node[slot]; + + if (LUAU_LIKELY(ttisstring(gkey(n)) && tsvalue(gkey(n)) == tsvalue(kv) && !ttisnil(gval(n)) && !h->readonly)) + { + setobj2t(L, gval(n), ra); + luaC_barriert(L, h, ra); + return pc; + } + else + { + // slow-path, may invoke Lua calls via __newindex metamethod + TValue g; + sethvalue(L, &g, h); + L->cachedslot = slot; + VM_PROTECT(luaV_settable(L, &g, kv, ra)); + // save cachedslot to accelerate future lookups; patches currently executing instruction since pc-2 rolls back two pc++ + VM_PATCH_C(pc - 2, L->cachedslot); + return pc; + } +} + +const Instruction* execute_LOP_GETUPVAL(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k) +{ + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + TValue* ur = VM_UV(LUAU_INSN_B(insn)); + TValue* v = ttisupval(ur) ? upvalue(ur)->v : ur; + + setobj2s(L, ra, v); + return pc; +} + +const Instruction* execute_LOP_SETUPVAL(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k) +{ + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + TValue* ur = VM_UV(LUAU_INSN_B(insn)); + UpVal* uv = upvalue(ur); + + setobj(L, uv->v, ra); + luaC_barrier(L, uv, ra); + return pc; +} + +const Instruction* execute_LOP_CLOSEUPVALS(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k) +{ + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + + if (L->openupval && L->openupval->v >= ra) + luaF_close(L, ra); + return pc; +} + +const Instruction* execute_LOP_GETIMPORT(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k) +{ + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + TValue* kv = VM_KV(LUAU_INSN_D(insn)); + + // fast-path: import resolution was successful and closure environment is "safe" for import + if (!ttisnil(kv) && cl->env->safeenv) + { + setobj2s(L, ra, kv); + pc++; // skip over AUX + return pc; + } + else + { + uint32_t aux = *pc++; + + VM_PROTECT(luaV_getimport(L, cl->env, k, aux, /* propagatenil= */ false)); + ra = VM_REG(LUAU_INSN_A(insn)); // previous call may change the stack + + setobj2s(L, ra, L->top - 1); + L->top--; + return pc; + } +} + +const Instruction* execute_LOP_GETTABLEKS(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k) +{ + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + StkId rb = VM_REG(LUAU_INSN_B(insn)); + uint32_t aux = *pc++; + TValue* kv = VM_KV(aux); + LUAU_ASSERT(ttisstring(kv)); + + // fast-path: built-in table + if (ttistable(rb)) + { + Table* h = hvalue(rb); + + int slot = LUAU_INSN_C(insn) & h->nodemask8; + LuaNode* n = &h->node[slot]; + + // fast-path: value is in expected slot + if (LUAU_LIKELY(ttisstring(gkey(n)) && tsvalue(gkey(n)) == tsvalue(kv) && !ttisnil(gval(n)))) + { + setobj2s(L, ra, gval(n)); + return pc; + } + else if (!h->metatable) + { + // fast-path: value is not in expected slot, but the table lookup doesn't involve metatable + const TValue* res = luaH_getstr(h, tsvalue(kv)); + + if (res != luaO_nilobject) + { + int cachedslot = gval2slot(h, res); + // save cachedslot to accelerate future lookups; patches currently executing instruction since pc-2 rolls back two pc++ + VM_PATCH_C(pc - 2, cachedslot); + } + + setobj2s(L, ra, res); + return pc; + } + else + { + // slow-path, may invoke Lua calls via __index metamethod + L->cachedslot = slot; + VM_PROTECT(luaV_gettable(L, rb, kv, ra)); + // save cachedslot to accelerate future lookups; patches currently executing instruction since pc-2 rolls back two pc++ + VM_PATCH_C(pc - 2, L->cachedslot); + return pc; + } + } + else + { + // fast-path: user data with C __index TM + const TValue* fn = 0; + if (ttisuserdata(rb) && (fn = fasttm(L, uvalue(rb)->metatable, TM_INDEX)) && ttisfunction(fn) && clvalue(fn)->isC) + { + // note: it's safe to push arguments past top for complicated reasons (see top of the file) + LUAU_ASSERT(L->top + 3 < L->stack + L->stacksize); + StkId top = L->top; + setobj2s(L, top + 0, fn); + setobj2s(L, top + 1, rb); + setobj2s(L, top + 2, kv); + L->top = top + 3; + + L->cachedslot = LUAU_INSN_C(insn); + VM_PROTECT(luaV_callTM(L, 2, LUAU_INSN_A(insn))); + // save cachedslot to accelerate future lookups; patches currently executing instruction since pc-2 rolls back two pc++ + VM_PATCH_C(pc - 2, L->cachedslot); + return pc; + } + else if (ttisvector(rb)) + { + // fast-path: quick case-insensitive comparison with "X"/"Y"/"Z" + const char* name = getstr(tsvalue(kv)); + int ic = (name[0] | ' ') - 'x'; + +#if LUA_VECTOR_SIZE == 4 + // 'w' is before 'x' in ascii, so ic is -1 when indexing with 'w' + if (ic == -1) + ic = 3; +#endif + + if (unsigned(ic) < LUA_VECTOR_SIZE && name[1] == '\0') + { + const float* v = rb->value.v; // silences ubsan when indexing v[] + setnvalue(ra, v[ic]); + return pc; + } + + fn = fasttm(L, L->global->mt[LUA_TVECTOR], TM_INDEX); + + if (fn && ttisfunction(fn) && clvalue(fn)->isC) + { + // note: it's safe to push arguments past top for complicated reasons (see top of the file) + LUAU_ASSERT(L->top + 3 < L->stack + L->stacksize); + StkId top = L->top; + setobj2s(L, top + 0, fn); + setobj2s(L, top + 1, rb); + setobj2s(L, top + 2, kv); + L->top = top + 3; + + L->cachedslot = LUAU_INSN_C(insn); + VM_PROTECT(luaV_callTM(L, 2, LUAU_INSN_A(insn))); + // save cachedslot to accelerate future lookups; patches currently executing instruction since pc-2 rolls back two pc++ + VM_PATCH_C(pc - 2, L->cachedslot); + return pc; + } + + // fall through to slow path + } + + // fall through to slow path + } + + // slow-path, may invoke Lua calls via __index metamethod + VM_PROTECT(luaV_gettable(L, rb, kv, ra)); + return pc; +} + +const Instruction* execute_LOP_SETTABLEKS(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k) +{ + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + StkId rb = VM_REG(LUAU_INSN_B(insn)); + uint32_t aux = *pc++; + TValue* kv = VM_KV(aux); + LUAU_ASSERT(ttisstring(kv)); + + // fast-path: built-in table + if (ttistable(rb)) + { + Table* h = hvalue(rb); + + int slot = LUAU_INSN_C(insn) & h->nodemask8; + LuaNode* n = &h->node[slot]; + + // fast-path: value is in expected slot + if (LUAU_LIKELY(ttisstring(gkey(n)) && tsvalue(gkey(n)) == tsvalue(kv) && !ttisnil(gval(n)) && !h->readonly)) + { + setobj2t(L, gval(n), ra); + luaC_barriert(L, h, ra); + return pc; + } + else if (fastnotm(h->metatable, TM_NEWINDEX) && !h->readonly) + { + VM_PROTECT_PC(); // set may fail + + TValue* res = luaH_setstr(L, h, tsvalue(kv)); + int cachedslot = gval2slot(h, res); + // save cachedslot to accelerate future lookups; patches currently executing instruction since pc-2 rolls back two pc++ + VM_PATCH_C(pc - 2, cachedslot); + setobj2t(L, res, ra); + luaC_barriert(L, h, ra); + return pc; + } + else + { + // slow-path, may invoke Lua calls via __newindex metamethod + L->cachedslot = slot; + VM_PROTECT(luaV_settable(L, rb, kv, ra)); + // save cachedslot to accelerate future lookups; patches currently executing instruction since pc-2 rolls back two pc++ + VM_PATCH_C(pc - 2, L->cachedslot); + return pc; + } + } + else + { + // fast-path: user data with C __newindex TM + const TValue* fn = 0; + if (ttisuserdata(rb) && (fn = fasttm(L, uvalue(rb)->metatable, TM_NEWINDEX)) && ttisfunction(fn) && clvalue(fn)->isC) + { + // note: it's safe to push arguments past top for complicated reasons (see top of the file) + LUAU_ASSERT(L->top + 4 < L->stack + L->stacksize); + StkId top = L->top; + setobj2s(L, top + 0, fn); + setobj2s(L, top + 1, rb); + setobj2s(L, top + 2, kv); + setobj2s(L, top + 3, ra); + L->top = top + 4; + + L->cachedslot = LUAU_INSN_C(insn); + VM_PROTECT(luaV_callTM(L, 3, -1)); + // save cachedslot to accelerate future lookups; patches currently executing instruction since pc-2 rolls back two pc++ + VM_PATCH_C(pc - 2, L->cachedslot); + return pc; + } + else + { + // slow-path, may invoke Lua calls via __newindex metamethod + VM_PROTECT(luaV_settable(L, rb, kv, ra)); + return pc; + } + } +} + +const Instruction* execute_LOP_GETTABLE(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k) +{ + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + StkId rb = VM_REG(LUAU_INSN_B(insn)); + StkId rc = VM_REG(LUAU_INSN_C(insn)); + + // fast-path: array lookup + if (ttistable(rb) && ttisnumber(rc)) + { + Table* h = hvalue(rb); + + double indexd = nvalue(rc); + int index = int(indexd); + + // index has to be an exact integer and in-bounds for the array portion + if (LUAU_LIKELY(unsigned(index - 1) < unsigned(h->sizearray) && !h->metatable && double(index) == indexd)) + { + setobj2s(L, ra, &h->array[unsigned(index - 1)]); + return pc; + } + + // fall through to slow path + } + + // slow-path: handles out of bounds array lookups, non-integer numeric keys, non-array table lookup, __index MT calls + VM_PROTECT(luaV_gettable(L, rb, rc, ra)); + return pc; +} + +const Instruction* execute_LOP_SETTABLE(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k) +{ + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + StkId rb = VM_REG(LUAU_INSN_B(insn)); + StkId rc = VM_REG(LUAU_INSN_C(insn)); + + // fast-path: array assign + if (ttistable(rb) && ttisnumber(rc)) + { + Table* h = hvalue(rb); + + double indexd = nvalue(rc); + int index = int(indexd); + + // index has to be an exact integer and in-bounds for the array portion + if (LUAU_LIKELY(unsigned(index - 1) < unsigned(h->sizearray) && !h->metatable && !h->readonly && double(index) == indexd)) + { + setobj2t(L, &h->array[unsigned(index - 1)], ra); + luaC_barriert(L, h, ra); + return pc; + } + + // fall through to slow path + } + + // slow-path: handles out of bounds array assignments, non-integer numeric keys, non-array table access, __newindex MT calls + VM_PROTECT(luaV_settable(L, rb, rc, ra)); + return pc; +} + +const Instruction* execute_LOP_GETTABLEN(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k) +{ + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + StkId rb = VM_REG(LUAU_INSN_B(insn)); + int c = LUAU_INSN_C(insn); + + // fast-path: array lookup + if (ttistable(rb)) + { + Table* h = hvalue(rb); + + if (LUAU_LIKELY(unsigned(c) < unsigned(h->sizearray) && !h->metatable)) + { + setobj2s(L, ra, &h->array[c]); + return pc; + } + + // fall through to slow path + } + + // slow-path: handles out of bounds array lookups + TValue n; + setnvalue(&n, c + 1); + VM_PROTECT(luaV_gettable(L, rb, &n, ra)); + return pc; +} + +const Instruction* execute_LOP_SETTABLEN(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k) +{ + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + StkId rb = VM_REG(LUAU_INSN_B(insn)); + int c = LUAU_INSN_C(insn); + + // fast-path: array assign + if (ttistable(rb)) + { + Table* h = hvalue(rb); + + if (LUAU_LIKELY(unsigned(c) < unsigned(h->sizearray) && !h->metatable && !h->readonly)) + { + setobj2t(L, &h->array[c], ra); + luaC_barriert(L, h, ra); + return pc; + } + + // fall through to slow path + } + + // slow-path: handles out of bounds array lookups + TValue n; + setnvalue(&n, c + 1); + VM_PROTECT(luaV_settable(L, rb, &n, ra)); + return pc; +} + +const Instruction* execute_LOP_NEWCLOSURE(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k) +{ + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + + Proto* pv = cl->l.p->p[LUAU_INSN_D(insn)]; + LUAU_ASSERT(unsigned(LUAU_INSN_D(insn)) < unsigned(cl->l.p->sizep)); + + // note: we save closure to stack early in case the code below wants to capture it by value + Closure* ncl = luaF_newLclosure(L, pv->nups, cl->env, pv); + setclvalue(L, ra, ncl); + + for (int ui = 0; ui < pv->nups; ++ui) + { + Instruction uinsn = *pc++; + LUAU_ASSERT(LUAU_INSN_OP(uinsn) == LOP_CAPTURE); + + switch (LUAU_INSN_A(uinsn)) + { + case LCT_VAL: + setobj(L, &ncl->l.uprefs[ui], VM_REG(LUAU_INSN_B(uinsn))); + break; + + case LCT_REF: + setupvalue(L, &ncl->l.uprefs[ui], luaF_findupval(L, VM_REG(LUAU_INSN_B(uinsn)))); + break; + + case LCT_UPVAL: + setobj(L, &ncl->l.uprefs[ui], VM_UV(LUAU_INSN_B(uinsn))); + break; + + default: + LUAU_ASSERT(!"Unknown upvalue capture type"); + } + } + + VM_PROTECT(luaC_checkGC(L)); + return pc; +} + +const Instruction* execute_LOP_NAMECALL(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k) +{ + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + StkId rb = VM_REG(LUAU_INSN_B(insn)); + uint32_t aux = *pc++; + TValue* kv = VM_KV(aux); + LUAU_ASSERT(ttisstring(kv)); + + if (ttistable(rb)) + { + Table* h = hvalue(rb); + // note: we can't use nodemask8 here because we need to query the main position of the table, and 8-bit nodemask8 only works + // for predictive lookups + LuaNode* n = &h->node[tsvalue(kv)->hash & (sizenode(h) - 1)]; + + const TValue* mt = 0; + const LuaNode* mtn = 0; + + // fast-path: key is in the table in expected slot + if (ttisstring(gkey(n)) && tsvalue(gkey(n)) == tsvalue(kv) && !ttisnil(gval(n))) + { + // note: order of copies allows rb to alias ra+1 or ra + setobj2s(L, ra + 1, rb); + setobj2s(L, ra, gval(n)); + } + // fast-path: key is absent from the base, table has an __index table, and it has the result in the expected slot + else if (gnext(n) == 0 && (mt = fasttm(L, hvalue(rb)->metatable, TM_INDEX)) && ttistable(mt) && + (mtn = &hvalue(mt)->node[LUAU_INSN_C(insn) & hvalue(mt)->nodemask8]) && ttisstring(gkey(mtn)) && tsvalue(gkey(mtn)) == tsvalue(kv) && + !ttisnil(gval(mtn))) + { + // note: order of copies allows rb to alias ra+1 or ra + setobj2s(L, ra + 1, rb); + setobj2s(L, ra, gval(mtn)); + } + else + { + // slow-path: handles full table lookup + setobj2s(L, ra + 1, rb); + L->cachedslot = LUAU_INSN_C(insn); + VM_PROTECT(luaV_gettable(L, rb, kv, ra)); + // save cachedslot to accelerate future lookups; patches currently executing instruction since pc-2 rolls back two pc++ + VM_PATCH_C(pc - 2, L->cachedslot); + // recompute ra since stack might have been reallocated + ra = VM_REG(LUAU_INSN_A(insn)); + if (ttisnil(ra)) + luaG_methoderror(L, ra + 1, tsvalue(kv)); + } + } + else + { + Table* mt = ttisuserdata(rb) ? uvalue(rb)->metatable : L->global->mt[ttype(rb)]; + const TValue* tmi = 0; + + // fast-path: metatable with __namecall + if (const TValue* fn = fasttm(L, mt, TM_NAMECALL)) + { + // note: order of copies allows rb to alias ra+1 or ra + setobj2s(L, ra + 1, rb); + setobj2s(L, ra, fn); + + L->namecall = tsvalue(kv); + } + else if ((tmi = fasttm(L, mt, TM_INDEX)) && ttistable(tmi)) + { + Table* h = hvalue(tmi); + int slot = LUAU_INSN_C(insn) & h->nodemask8; + LuaNode* n = &h->node[slot]; + + // fast-path: metatable with __index that has method in expected slot + if (LUAU_LIKELY(ttisstring(gkey(n)) && tsvalue(gkey(n)) == tsvalue(kv) && !ttisnil(gval(n)))) + { + // note: order of copies allows rb to alias ra+1 or ra + setobj2s(L, ra + 1, rb); + setobj2s(L, ra, gval(n)); + } + else + { + // slow-path: handles slot mismatch + setobj2s(L, ra + 1, rb); + L->cachedslot = slot; + VM_PROTECT(luaV_gettable(L, rb, kv, ra)); + // save cachedslot to accelerate future lookups; patches currently executing instruction since pc-2 rolls back two pc++ + VM_PATCH_C(pc - 2, L->cachedslot); + // recompute ra since stack might have been reallocated + ra = VM_REG(LUAU_INSN_A(insn)); + if (ttisnil(ra)) + luaG_methoderror(L, ra + 1, tsvalue(kv)); + } + } + else + { + // slow-path: handles non-table __index + setobj2s(L, ra + 1, rb); + VM_PROTECT(luaV_gettable(L, rb, kv, ra)); + // recompute ra since stack might have been reallocated + ra = VM_REG(LUAU_INSN_A(insn)); + if (ttisnil(ra)) + luaG_methoderror(L, ra + 1, tsvalue(kv)); + } + } + + // intentional fallthrough to CALL + LUAU_ASSERT(LUAU_INSN_OP(*pc) == LOP_CALL); + return pc; +} + +const Instruction* execute_LOP_CALL(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k) +{ + VM_INTERRUPT(); + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + + int nparams = LUAU_INSN_B(insn) - 1; + int nresults = LUAU_INSN_C(insn) - 1; + + StkId argtop = L->top; + argtop = (nparams == LUA_MULTRET) ? argtop : ra + 1 + nparams; + + // slow-path: not a function call + if (LUAU_UNLIKELY(!ttisfunction(ra))) + { + VM_PROTECT(luaV_tryfuncTM(L, ra)); + argtop++; // __call adds an extra self + } + + Closure* ccl = clvalue(ra); + L->ci->savedpc = pc; + + CallInfo* ci = incr_ci(L); + ci->func = ra; + ci->base = ra + 1; + ci->top = argtop + ccl->stacksize; // note: technically UB since we haven't reallocated the stack yet + ci->savedpc = NULL; + ci->flags = 0; + ci->nresults = nresults; + + L->base = ci->base; + L->top = argtop; + + // note: this reallocs stack, but we don't need to VM_PROTECT this + // this is because we're going to modify base/savedpc manually anyhow + // crucially, we can't use ra/argtop after this line + luaD_checkstack(L, ccl->stacksize); + + LUAU_ASSERT(ci->top <= L->stack_last); + + if (!ccl->isC) + { + Proto* p = ccl->l.p; + + // fill unused parameters with nil + StkId argi = L->top; + StkId argend = L->base + p->numparams; + while (argi < argend) + setnilvalue(argi++); // complete missing arguments + L->top = p->is_vararg ? argi : ci->top; + + // reentry + pc = p->code; + cl = ccl; + base = L->base; + k = p->k; + return pc; + } + else + { + lua_CFunction func = ccl->c.f; + int n = func(L); + + // yield + if (n < 0) + return NULL; + + // ci is our callinfo, cip is our parent + CallInfo* ci = L->ci; + CallInfo* cip = ci - 1; + + // copy return values into parent stack (but only up to nresults!), fill the rest with nil + // note: in MULTRET context nresults starts as -1 so i != 0 condition never activates intentionally + StkId res = ci->func; + StkId vali = L->top - n; + StkId valend = L->top; + + int i; + for (i = nresults; i != 0 && vali < valend; i--) + setobjs2s(L, res++, vali++); + while (i-- > 0) + setnilvalue(res++); + + // pop the stack frame + L->ci = cip; + L->base = cip->base; + L->top = (nresults == LUA_MULTRET) ? res : cip->top; + + base = L->base; // stack may have been reallocated, so we need to refresh base ptr + return pc; + } +} + +const Instruction* execute_LOP_RETURN(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k) +{ + VM_INTERRUPT(); + Instruction insn = *pc++; + StkId ra = &base[LUAU_INSN_A(insn)]; // note: this can point to L->top if b == LUA_MULTRET making VM_REG unsafe to use + int b = LUAU_INSN_B(insn) - 1; + + // ci is our callinfo, cip is our parent + CallInfo* ci = L->ci; + CallInfo* cip = ci - 1; + + StkId res = ci->func; // note: we assume CALL always puts func+args and expects results to start at func + + StkId vali = ra; + StkId valend = (b == LUA_MULTRET) ? L->top : ra + b; // copy as much as possible for MULTRET calls, and only as much as needed otherwise + + int nresults = ci->nresults; + + // copy return values into parent stack (but only up to nresults!), fill the rest with nil + // note: in MULTRET context nresults starts as -1 so i != 0 condition never activates intentionally + int i; + for (i = nresults; i != 0 && vali < valend; i--) + setobjs2s(L, res++, vali++); + while (i-- > 0) + setnilvalue(res++); + + // pop the stack frame + L->ci = cip; + L->base = cip->base; + L->top = (nresults == LUA_MULTRET) ? res : cip->top; + + // we're done! + if (LUAU_UNLIKELY(ci->flags & LUA_CALLINFO_RETURN)) + { + L->top = res; + return NULL; + } + + LUAU_ASSERT(isLua(L->ci)); + + Closure* nextcl = clvalue(cip->func); + Proto* nextproto = nextcl->l.p; + + // reentry + pc = cip->savedpc; + cl = nextcl; + base = L->base; + k = nextproto->k; + return pc; +} + +const Instruction* execute_LOP_JUMP(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k) +{ + Instruction insn = *pc++; + + pc += LUAU_INSN_D(insn); + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + return pc; +} + +const Instruction* execute_LOP_JUMPIF(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k) +{ + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + + pc += l_isfalse(ra) ? 0 : LUAU_INSN_D(insn); + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + return pc; +} + +const Instruction* execute_LOP_JUMPIFNOT(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k) +{ + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + + pc += l_isfalse(ra) ? LUAU_INSN_D(insn) : 0; + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + return pc; +} + +const Instruction* execute_LOP_JUMPIFEQ(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k) +{ + Instruction insn = *pc++; + uint32_t aux = *pc; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + StkId rb = VM_REG(aux); + + // Note that all jumps below jump by 1 in the "false" case to skip over aux + if (ttype(ra) == ttype(rb)) + { + switch (ttype(ra)) + { + case LUA_TNIL: + pc += LUAU_INSN_D(insn); + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + return pc; + + case LUA_TBOOLEAN: + pc += bvalue(ra) == bvalue(rb) ? LUAU_INSN_D(insn) : 1; + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + return pc; + + case LUA_TLIGHTUSERDATA: + pc += pvalue(ra) == pvalue(rb) ? LUAU_INSN_D(insn) : 1; + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + return pc; + + case LUA_TNUMBER: + pc += nvalue(ra) == nvalue(rb) ? LUAU_INSN_D(insn) : 1; + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + return pc; + + case LUA_TVECTOR: + pc += luai_veceq(vvalue(ra), vvalue(rb)) ? LUAU_INSN_D(insn) : 1; + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + return pc; + + case LUA_TSTRING: + case LUA_TFUNCTION: + case LUA_TTHREAD: + pc += gcvalue(ra) == gcvalue(rb) ? LUAU_INSN_D(insn) : 1; + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + return pc; + + case LUA_TTABLE: + // fast-path: same metatable, no EQ metamethod + if (hvalue(ra)->metatable == hvalue(rb)->metatable) + { + const TValue* fn = fasttm(L, hvalue(ra)->metatable, TM_EQ); + + if (!fn) + { + pc += hvalue(ra) == hvalue(rb) ? LUAU_INSN_D(insn) : 1; + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + return pc; + } + } + // slow path after switch() + break; + + case LUA_TUSERDATA: + // fast-path: same metatable, no EQ metamethod or C metamethod + if (uvalue(ra)->metatable == uvalue(rb)->metatable) + { + const TValue* fn = fasttm(L, uvalue(ra)->metatable, TM_EQ); + + if (!fn) + { + pc += uvalue(ra) == uvalue(rb) ? LUAU_INSN_D(insn) : 1; + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + return pc; + } + else if (ttisfunction(fn) && clvalue(fn)->isC) + { + // note: it's safe to push arguments past top for complicated reasons (see top of the file) + LUAU_ASSERT(L->top + 3 < L->stack + L->stacksize); + StkId top = L->top; + setobj2s(L, top + 0, fn); + setobj2s(L, top + 1, ra); + setobj2s(L, top + 2, rb); + int res = int(top - base); + L->top = top + 3; + + VM_PROTECT(luaV_callTM(L, 2, res)); + pc += !l_isfalse(&base[res]) ? LUAU_INSN_D(insn) : 1; + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + return pc; + } + } + // slow path after switch() + break; + + default:; + } + + // slow-path: tables with metatables and userdata values + // note that we don't have a fast path for userdata values without metatables, since that's very rare + int res; + VM_PROTECT(res = luaV_equalval(L, ra, rb)); + + pc += (res == 1) ? LUAU_INSN_D(insn) : 1; + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + return pc; + } + else + { + pc += 1; + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + return pc; + } +} + +const Instruction* execute_LOP_JUMPIFNOTEQ(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k) +{ + Instruction insn = *pc++; + uint32_t aux = *pc; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + StkId rb = VM_REG(aux); + + // Note that all jumps below jump by 1 in the "true" case to skip over aux + if (ttype(ra) == ttype(rb)) + { + switch (ttype(ra)) + { + case LUA_TNIL: + pc += 1; + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + return pc; + + case LUA_TBOOLEAN: + pc += bvalue(ra) != bvalue(rb) ? LUAU_INSN_D(insn) : 1; + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + return pc; + + case LUA_TLIGHTUSERDATA: + pc += pvalue(ra) != pvalue(rb) ? LUAU_INSN_D(insn) : 1; + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + return pc; + + case LUA_TNUMBER: + pc += nvalue(ra) != nvalue(rb) ? LUAU_INSN_D(insn) : 1; + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + return pc; + + case LUA_TVECTOR: + pc += !luai_veceq(vvalue(ra), vvalue(rb)) ? LUAU_INSN_D(insn) : 1; + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + return pc; + + case LUA_TSTRING: + case LUA_TFUNCTION: + case LUA_TTHREAD: + pc += gcvalue(ra) != gcvalue(rb) ? LUAU_INSN_D(insn) : 1; + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + return pc; + + case LUA_TTABLE: + // fast-path: same metatable, no EQ metamethod + if (hvalue(ra)->metatable == hvalue(rb)->metatable) + { + const TValue* fn = fasttm(L, hvalue(ra)->metatable, TM_EQ); + + if (!fn) + { + pc += hvalue(ra) != hvalue(rb) ? LUAU_INSN_D(insn) : 1; + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + return pc; + } + } + // slow path after switch() + break; + + case LUA_TUSERDATA: + // fast-path: same metatable, no EQ metamethod or C metamethod + if (uvalue(ra)->metatable == uvalue(rb)->metatable) + { + const TValue* fn = fasttm(L, uvalue(ra)->metatable, TM_EQ); + + if (!fn) + { + pc += uvalue(ra) != uvalue(rb) ? LUAU_INSN_D(insn) : 1; + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + return pc; + } + else if (ttisfunction(fn) && clvalue(fn)->isC) + { + // note: it's safe to push arguments past top for complicated reasons (see top of the file) + LUAU_ASSERT(L->top + 3 < L->stack + L->stacksize); + StkId top = L->top; + setobj2s(L, top + 0, fn); + setobj2s(L, top + 1, ra); + setobj2s(L, top + 2, rb); + int res = int(top - base); + L->top = top + 3; + + VM_PROTECT(luaV_callTM(L, 2, res)); + pc += l_isfalse(&base[res]) ? LUAU_INSN_D(insn) : 1; + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + return pc; + } + } + // slow path after switch() + break; + + default:; + } + + // slow-path: tables with metatables and userdata values + // note that we don't have a fast path for userdata values without metatables, since that's very rare + int res; + VM_PROTECT(res = luaV_equalval(L, ra, rb)); + + pc += (res == 0) ? LUAU_INSN_D(insn) : 1; + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + return pc; + } + else + { + pc += LUAU_INSN_D(insn); + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + return pc; + } +} + +const Instruction* execute_LOP_JUMPIFLE(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k) +{ + Instruction insn = *pc++; + uint32_t aux = *pc; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + StkId rb = VM_REG(aux); + + // fast-path: number + // Note that all jumps below jump by 1 in the "false" case to skip over aux + if (ttisnumber(ra) && ttisnumber(rb)) + { + pc += nvalue(ra) <= nvalue(rb) ? LUAU_INSN_D(insn) : 1; + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + return pc; + } + // fast-path: string + else if (ttisstring(ra) && ttisstring(rb)) + { + pc += luaV_strcmp(tsvalue(ra), tsvalue(rb)) <= 0 ? LUAU_INSN_D(insn) : 1; + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + return pc; + } + else + { + int res; + VM_PROTECT(res = luaV_lessequal(L, ra, rb)); + + pc += (res == 1) ? LUAU_INSN_D(insn) : 1; + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + return pc; + } +} + +const Instruction* execute_LOP_JUMPIFNOTLE(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k) +{ + Instruction insn = *pc++; + uint32_t aux = *pc; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + StkId rb = VM_REG(aux); + + // fast-path: number + // Note that all jumps below jump by 1 in the "true" case to skip over aux + if (ttisnumber(ra) && ttisnumber(rb)) + { + pc += !(nvalue(ra) <= nvalue(rb)) ? LUAU_INSN_D(insn) : 1; + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + return pc; + } + // fast-path: string + else if (ttisstring(ra) && ttisstring(rb)) + { + pc += !(luaV_strcmp(tsvalue(ra), tsvalue(rb)) <= 0) ? LUAU_INSN_D(insn) : 1; + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + return pc; + } + else + { + int res; + VM_PROTECT(res = luaV_lessequal(L, ra, rb)); + + pc += (res == 0) ? LUAU_INSN_D(insn) : 1; + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + return pc; + } +} + +const Instruction* execute_LOP_JUMPIFLT(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k) +{ + Instruction insn = *pc++; + uint32_t aux = *pc; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + StkId rb = VM_REG(aux); + + // fast-path: number + // Note that all jumps below jump by 1 in the "false" case to skip over aux + if (ttisnumber(ra) && ttisnumber(rb)) + { + pc += nvalue(ra) < nvalue(rb) ? LUAU_INSN_D(insn) : 1; + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + return pc; + } + // fast-path: string + else if (ttisstring(ra) && ttisstring(rb)) + { + pc += luaV_strcmp(tsvalue(ra), tsvalue(rb)) < 0 ? LUAU_INSN_D(insn) : 1; + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + return pc; + } + else + { + int res; + VM_PROTECT(res = luaV_lessthan(L, ra, rb)); + + pc += (res == 1) ? LUAU_INSN_D(insn) : 1; + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + return pc; + } +} + +const Instruction* execute_LOP_JUMPIFNOTLT(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k) +{ + Instruction insn = *pc++; + uint32_t aux = *pc; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + StkId rb = VM_REG(aux); + + // fast-path: number + // Note that all jumps below jump by 1 in the "true" case to skip over aux + if (ttisnumber(ra) && ttisnumber(rb)) + { + pc += !(nvalue(ra) < nvalue(rb)) ? LUAU_INSN_D(insn) : 1; + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + return pc; + } + // fast-path: string + else if (ttisstring(ra) && ttisstring(rb)) + { + pc += !(luaV_strcmp(tsvalue(ra), tsvalue(rb)) < 0) ? LUAU_INSN_D(insn) : 1; + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + return pc; + } + else + { + int res; + VM_PROTECT(res = luaV_lessthan(L, ra, rb)); + + pc += (res == 0) ? LUAU_INSN_D(insn) : 1; + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + return pc; + } +} + +const Instruction* execute_LOP_ADD(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k) +{ + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + StkId rb = VM_REG(LUAU_INSN_B(insn)); + StkId rc = VM_REG(LUAU_INSN_C(insn)); + + // fast-path + if (ttisnumber(rb) && ttisnumber(rc)) + { + setnvalue(ra, nvalue(rb) + nvalue(rc)); + return pc; + } + else if (ttisvector(rb) && ttisvector(rc)) + { + const float* vb = rb->value.v; + const float* vc = rc->value.v; + setvvalue(ra, vb[0] + vc[0], vb[1] + vc[1], vb[2] + vc[2], vb[3] + vc[3]); + return pc; + } + else + { + // fast-path for userdata with C functions + const TValue* fn = 0; + if (ttisuserdata(rb) && (fn = luaT_gettmbyobj(L, rb, TM_ADD)) && ttisfunction(fn) && clvalue(fn)->isC) + { + // note: it's safe to push arguments past top for complicated reasons (see top of the file) + LUAU_ASSERT(L->top + 3 < L->stack + L->stacksize); + StkId top = L->top; + setobj2s(L, top + 0, fn); + setobj2s(L, top + 1, rb); + setobj2s(L, top + 2, rc); + L->top = top + 3; + + VM_PROTECT(luaV_callTM(L, 2, LUAU_INSN_A(insn))); + return pc; + } + else + { + // slow-path, may invoke C/Lua via metamethods + VM_PROTECT(luaV_doarith(L, ra, rb, rc, TM_ADD)); + return pc; + } + } +} + +const Instruction* execute_LOP_SUB(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k) +{ + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + StkId rb = VM_REG(LUAU_INSN_B(insn)); + StkId rc = VM_REG(LUAU_INSN_C(insn)); + + // fast-path + if (ttisnumber(rb) && ttisnumber(rc)) + { + setnvalue(ra, nvalue(rb) - nvalue(rc)); + return pc; + } + else if (ttisvector(rb) && ttisvector(rc)) + { + const float* vb = rb->value.v; + const float* vc = rc->value.v; + setvvalue(ra, vb[0] - vc[0], vb[1] - vc[1], vb[2] - vc[2], vb[3] - vc[3]); + return pc; + } + else + { + // fast-path for userdata with C functions + const TValue* fn = 0; + if (ttisuserdata(rb) && (fn = luaT_gettmbyobj(L, rb, TM_SUB)) && ttisfunction(fn) && clvalue(fn)->isC) + { + // note: it's safe to push arguments past top for complicated reasons (see top of the file) + LUAU_ASSERT(L->top + 3 < L->stack + L->stacksize); + StkId top = L->top; + setobj2s(L, top + 0, fn); + setobj2s(L, top + 1, rb); + setobj2s(L, top + 2, rc); + L->top = top + 3; + + VM_PROTECT(luaV_callTM(L, 2, LUAU_INSN_A(insn))); + return pc; + } + else + { + // slow-path, may invoke C/Lua via metamethods + VM_PROTECT(luaV_doarith(L, ra, rb, rc, TM_SUB)); + return pc; + } + } +} + +const Instruction* execute_LOP_MUL(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k) +{ + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + StkId rb = VM_REG(LUAU_INSN_B(insn)); + StkId rc = VM_REG(LUAU_INSN_C(insn)); + + // fast-path + if (ttisnumber(rb) && ttisnumber(rc)) + { + setnvalue(ra, nvalue(rb) * nvalue(rc)); + return pc; + } + else if (ttisvector(rb) && ttisnumber(rc)) + { + const float* vb = rb->value.v; + float vc = cast_to(float, nvalue(rc)); + setvvalue(ra, vb[0] * vc, vb[1] * vc, vb[2] * vc, vb[3] * vc); + return pc; + } + else if (ttisvector(rb) && ttisvector(rc)) + { + const float* vb = rb->value.v; + const float* vc = rc->value.v; + setvvalue(ra, vb[0] * vc[0], vb[1] * vc[1], vb[2] * vc[2], vb[3] * vc[3]); + return pc; + } + else if (ttisnumber(rb) && ttisvector(rc)) + { + float vb = cast_to(float, nvalue(rb)); + const float* vc = rc->value.v; + setvvalue(ra, vb * vc[0], vb * vc[1], vb * vc[2], vb * vc[3]); + return pc; + } + else + { + // fast-path for userdata with C functions + StkId rbc = ttisnumber(rb) ? rc : rb; + const TValue* fn = 0; + if (ttisuserdata(rbc) && (fn = luaT_gettmbyobj(L, rbc, TM_MUL)) && ttisfunction(fn) && clvalue(fn)->isC) + { + // note: it's safe to push arguments past top for complicated reasons (see top of the file) + LUAU_ASSERT(L->top + 3 < L->stack + L->stacksize); + StkId top = L->top; + setobj2s(L, top + 0, fn); + setobj2s(L, top + 1, rb); + setobj2s(L, top + 2, rc); + L->top = top + 3; + + VM_PROTECT(luaV_callTM(L, 2, LUAU_INSN_A(insn))); + return pc; + } + else + { + // slow-path, may invoke C/Lua via metamethods + VM_PROTECT(luaV_doarith(L, ra, rb, rc, TM_MUL)); + return pc; + } + } +} + +const Instruction* execute_LOP_DIV(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k) +{ + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + StkId rb = VM_REG(LUAU_INSN_B(insn)); + StkId rc = VM_REG(LUAU_INSN_C(insn)); + + // fast-path + if (ttisnumber(rb) && ttisnumber(rc)) + { + setnvalue(ra, nvalue(rb) / nvalue(rc)); + return pc; + } + else if (ttisvector(rb) && ttisnumber(rc)) + { + const float* vb = rb->value.v; + float vc = cast_to(float, nvalue(rc)); + setvvalue(ra, vb[0] / vc, vb[1] / vc, vb[2] / vc, vb[3] / vc); + return pc; + } + else if (ttisvector(rb) && ttisvector(rc)) + { + const float* vb = rb->value.v; + const float* vc = rc->value.v; + setvvalue(ra, vb[0] / vc[0], vb[1] / vc[1], vb[2] / vc[2], vb[3] / vc[3]); + return pc; + } + else if (ttisnumber(rb) && ttisvector(rc)) + { + float vb = cast_to(float, nvalue(rb)); + const float* vc = rc->value.v; + setvvalue(ra, vb / vc[0], vb / vc[1], vb / vc[2], vb / vc[3]); + return pc; + } + else + { + // fast-path for userdata with C functions + StkId rbc = ttisnumber(rb) ? rc : rb; + const TValue* fn = 0; + if (ttisuserdata(rbc) && (fn = luaT_gettmbyobj(L, rbc, TM_DIV)) && ttisfunction(fn) && clvalue(fn)->isC) + { + // note: it's safe to push arguments past top for complicated reasons (see top of the file) + LUAU_ASSERT(L->top + 3 < L->stack + L->stacksize); + StkId top = L->top; + setobj2s(L, top + 0, fn); + setobj2s(L, top + 1, rb); + setobj2s(L, top + 2, rc); + L->top = top + 3; + + VM_PROTECT(luaV_callTM(L, 2, LUAU_INSN_A(insn))); + return pc; + } + else + { + // slow-path, may invoke C/Lua via metamethods + VM_PROTECT(luaV_doarith(L, ra, rb, rc, TM_DIV)); + return pc; + } + } +} + +const Instruction* execute_LOP_MOD(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k) +{ + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + StkId rb = VM_REG(LUAU_INSN_B(insn)); + StkId rc = VM_REG(LUAU_INSN_C(insn)); + + // fast-path + if (ttisnumber(rb) && ttisnumber(rc)) + { + double nb = nvalue(rb); + double nc = nvalue(rc); + setnvalue(ra, luai_nummod(nb, nc)); + return pc; + } + else + { + // slow-path, may invoke C/Lua via metamethods + VM_PROTECT(luaV_doarith(L, ra, rb, rc, TM_MOD)); + return pc; + } +} + +const Instruction* execute_LOP_POW(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k) +{ + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + StkId rb = VM_REG(LUAU_INSN_B(insn)); + StkId rc = VM_REG(LUAU_INSN_C(insn)); + + // fast-path + if (ttisnumber(rb) && ttisnumber(rc)) + { + setnvalue(ra, pow(nvalue(rb), nvalue(rc))); + return pc; + } + else + { + // slow-path, may invoke C/Lua via metamethods + VM_PROTECT(luaV_doarith(L, ra, rb, rc, TM_POW)); + return pc; + } +} + +const Instruction* execute_LOP_ADDK(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k) +{ + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + StkId rb = VM_REG(LUAU_INSN_B(insn)); + TValue* kv = VM_KV(LUAU_INSN_C(insn)); + + // fast-path + if (ttisnumber(rb)) + { + setnvalue(ra, nvalue(rb) + nvalue(kv)); + return pc; + } + else + { + // slow-path, may invoke C/Lua via metamethods + VM_PROTECT(luaV_doarith(L, ra, rb, kv, TM_ADD)); + return pc; + } +} + +const Instruction* execute_LOP_SUBK(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k) +{ + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + StkId rb = VM_REG(LUAU_INSN_B(insn)); + TValue* kv = VM_KV(LUAU_INSN_C(insn)); + + // fast-path + if (ttisnumber(rb)) + { + setnvalue(ra, nvalue(rb) - nvalue(kv)); + return pc; + } + else + { + // slow-path, may invoke C/Lua via metamethods + VM_PROTECT(luaV_doarith(L, ra, rb, kv, TM_SUB)); + return pc; + } +} + +const Instruction* execute_LOP_MULK(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k) +{ + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + StkId rb = VM_REG(LUAU_INSN_B(insn)); + TValue* kv = VM_KV(LUAU_INSN_C(insn)); + + // fast-path + if (ttisnumber(rb)) + { + setnvalue(ra, nvalue(rb) * nvalue(kv)); + return pc; + } + else if (ttisvector(rb)) + { + const float* vb = rb->value.v; + float vc = cast_to(float, nvalue(kv)); + setvvalue(ra, vb[0] * vc, vb[1] * vc, vb[2] * vc, vb[3] * vc); + return pc; + } + else + { + // fast-path for userdata with C functions + const TValue* fn = 0; + if (ttisuserdata(rb) && (fn = luaT_gettmbyobj(L, rb, TM_MUL)) && ttisfunction(fn) && clvalue(fn)->isC) + { + // note: it's safe to push arguments past top for complicated reasons (see top of the file) + LUAU_ASSERT(L->top + 3 < L->stack + L->stacksize); + StkId top = L->top; + setobj2s(L, top + 0, fn); + setobj2s(L, top + 1, rb); + setobj2s(L, top + 2, kv); + L->top = top + 3; + + VM_PROTECT(luaV_callTM(L, 2, LUAU_INSN_A(insn))); + return pc; + } + else + { + // slow-path, may invoke C/Lua via metamethods + VM_PROTECT(luaV_doarith(L, ra, rb, kv, TM_MUL)); + return pc; + } + } +} + +const Instruction* execute_LOP_DIVK(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k) +{ + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + StkId rb = VM_REG(LUAU_INSN_B(insn)); + TValue* kv = VM_KV(LUAU_INSN_C(insn)); + + // fast-path + if (ttisnumber(rb)) + { + setnvalue(ra, nvalue(rb) / nvalue(kv)); + return pc; + } + else if (ttisvector(rb)) + { + const float* vb = rb->value.v; + float vc = cast_to(float, nvalue(kv)); + setvvalue(ra, vb[0] / vc, vb[1] / vc, vb[2] / vc, vb[3] / vc); + return pc; + } + else + { + // fast-path for userdata with C functions + const TValue* fn = 0; + if (ttisuserdata(rb) && (fn = luaT_gettmbyobj(L, rb, TM_DIV)) && ttisfunction(fn) && clvalue(fn)->isC) + { + // note: it's safe to push arguments past top for complicated reasons (see top of the file) + LUAU_ASSERT(L->top + 3 < L->stack + L->stacksize); + StkId top = L->top; + setobj2s(L, top + 0, fn); + setobj2s(L, top + 1, rb); + setobj2s(L, top + 2, kv); + L->top = top + 3; + + VM_PROTECT(luaV_callTM(L, 2, LUAU_INSN_A(insn))); + return pc; + } + else + { + // slow-path, may invoke C/Lua via metamethods + VM_PROTECT(luaV_doarith(L, ra, rb, kv, TM_DIV)); + return pc; + } + } +} + +const Instruction* execute_LOP_MODK(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k) +{ + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + StkId rb = VM_REG(LUAU_INSN_B(insn)); + TValue* kv = VM_KV(LUAU_INSN_C(insn)); + + // fast-path + if (ttisnumber(rb)) + { + double nb = nvalue(rb); + double nk = nvalue(kv); + setnvalue(ra, luai_nummod(nb, nk)); + return pc; + } + else + { + // slow-path, may invoke C/Lua via metamethods + VM_PROTECT(luaV_doarith(L, ra, rb, kv, TM_MOD)); + return pc; + } +} + +const Instruction* execute_LOP_POWK(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k) +{ + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + StkId rb = VM_REG(LUAU_INSN_B(insn)); + TValue* kv = VM_KV(LUAU_INSN_C(insn)); + + // fast-path + if (ttisnumber(rb)) + { + double nb = nvalue(rb); + double nk = nvalue(kv); + + // pow is very slow so we specialize this for ^2, ^0.5 and ^3 + double r = (nk == 2.0) ? nb * nb : (nk == 0.5) ? sqrt(nb) : (nk == 3.0) ? nb * nb * nb : pow(nb, nk); + + setnvalue(ra, r); + return pc; + } + else + { + // slow-path, may invoke C/Lua via metamethods + VM_PROTECT(luaV_doarith(L, ra, rb, kv, TM_POW)); + return pc; + } +} + +const Instruction* execute_LOP_AND(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k) +{ + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + StkId rb = VM_REG(LUAU_INSN_B(insn)); + StkId rc = VM_REG(LUAU_INSN_C(insn)); + + setobj2s(L, ra, l_isfalse(rb) ? rb : rc); + return pc; +} + +const Instruction* execute_LOP_OR(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k) +{ + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + StkId rb = VM_REG(LUAU_INSN_B(insn)); + StkId rc = VM_REG(LUAU_INSN_C(insn)); + + setobj2s(L, ra, l_isfalse(rb) ? rc : rb); + return pc; +} + +const Instruction* execute_LOP_ANDK(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k) +{ + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + StkId rb = VM_REG(LUAU_INSN_B(insn)); + TValue* kv = VM_KV(LUAU_INSN_C(insn)); + + setobj2s(L, ra, l_isfalse(rb) ? rb : kv); + return pc; +} + +const Instruction* execute_LOP_ORK(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k) +{ + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + StkId rb = VM_REG(LUAU_INSN_B(insn)); + TValue* kv = VM_KV(LUAU_INSN_C(insn)); + + setobj2s(L, ra, l_isfalse(rb) ? kv : rb); + return pc; +} + +const Instruction* execute_LOP_CONCAT(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k) +{ + Instruction insn = *pc++; + int b = LUAU_INSN_B(insn); + int c = LUAU_INSN_C(insn); + + // This call may realloc the stack! So we need to query args further down + VM_PROTECT(luaV_concat(L, c - b + 1, c)); + + StkId ra = VM_REG(LUAU_INSN_A(insn)); + + setobjs2s(L, ra, base + b); + VM_PROTECT(luaC_checkGC(L)); + return pc; +} + +const Instruction* execute_LOP_NOT(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k) +{ + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + StkId rb = VM_REG(LUAU_INSN_B(insn)); + + int res = l_isfalse(rb); + setbvalue(ra, res); + return pc; +} + +const Instruction* execute_LOP_MINUS(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k) +{ + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + StkId rb = VM_REG(LUAU_INSN_B(insn)); + + // fast-path + if (ttisnumber(rb)) + { + setnvalue(ra, -nvalue(rb)); + return pc; + } + else if (ttisvector(rb)) + { + const float* vb = rb->value.v; + setvvalue(ra, -vb[0], -vb[1], -vb[2], -vb[3]); + return pc; + } + else + { + // fast-path for userdata with C functions + const TValue* fn = 0; + if (ttisuserdata(rb) && (fn = luaT_gettmbyobj(L, rb, TM_UNM)) && ttisfunction(fn) && clvalue(fn)->isC) + { + // note: it's safe to push arguments past top for complicated reasons (see top of the file) + LUAU_ASSERT(L->top + 2 < L->stack + L->stacksize); + StkId top = L->top; + setobj2s(L, top + 0, fn); + setobj2s(L, top + 1, rb); + L->top = top + 2; + + VM_PROTECT(luaV_callTM(L, 1, LUAU_INSN_A(insn))); + return pc; + } + else + { + // slow-path, may invoke C/Lua via metamethods + VM_PROTECT(luaV_doarith(L, ra, rb, rb, TM_UNM)); + return pc; + } + } +} + +const Instruction* execute_LOP_LENGTH(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k) +{ + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + StkId rb = VM_REG(LUAU_INSN_B(insn)); + + // fast-path #1: tables + if (ttistable(rb)) + { + Table* h = hvalue(rb); + + if (fastnotm(h->metatable, TM_LEN)) + { + setnvalue(ra, cast_num(luaH_getn(h))); + return pc; + } + else + { + // slow-path, may invoke C/Lua via metamethods + VM_PROTECT(luaV_dolen(L, ra, rb)); + return pc; + } + } + // fast-path #2: strings (not very important but easy to do) + else if (ttisstring(rb)) + { + TString* ts = tsvalue(rb); + setnvalue(ra, cast_num(ts->len)); + return pc; + } + else + { + // slow-path, may invoke C/Lua via metamethods + VM_PROTECT(luaV_dolen(L, ra, rb)); + return pc; + } +} + +const Instruction* execute_LOP_NEWTABLE(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k) +{ + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + int b = LUAU_INSN_B(insn); + uint32_t aux = *pc++; + + sethvalue(L, ra, luaH_new(L, aux, b == 0 ? 0 : (1 << (b - 1)))); + VM_PROTECT(luaC_checkGC(L)); + return pc; +} + +const Instruction* execute_LOP_DUPTABLE(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k) +{ + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + TValue* kv = VM_KV(LUAU_INSN_D(insn)); + + sethvalue(L, ra, luaH_clone(L, hvalue(kv))); + VM_PROTECT(luaC_checkGC(L)); + return pc; +} + +const Instruction* execute_LOP_SETLIST(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k) +{ + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + StkId rb = &base[LUAU_INSN_B(insn)]; // note: this can point to L->top if c == LUA_MULTRET making VM_REG unsafe to use + int c = LUAU_INSN_C(insn) - 1; + uint32_t index = *pc++; + + if (c == LUA_MULTRET) + { + c = int(L->top - rb); + L->top = L->ci->top; + } + + Table* h = hvalue(ra); + + if (!ttistable(ra)) + return NULL; // temporary workaround to weaken a rather powerful exploitation primitive in case of a MITM attack on bytecode + + int last = index + c - 1; + if (last > h->sizearray) + luaH_resizearray(L, h, last); + + TValue* array = h->array; + + for (int i = 0; i < c; ++i) + setobj2t(L, &array[index + i - 1], rb + i); + + luaC_barrierfast(L, h); + return pc; +} + +const Instruction* execute_LOP_FORNPREP(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k) +{ + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + + if (!ttisnumber(ra + 0) || !ttisnumber(ra + 1) || !ttisnumber(ra + 2)) + { + // slow-path: can convert arguments to numbers and trigger Lua errors + // Note: this doesn't reallocate stack so we don't need to recompute ra/base + VM_PROTECT_PC(); + + luaV_prepareFORN(L, ra + 0, ra + 1, ra + 2); + } + + double limit = nvalue(ra + 0); + double step = nvalue(ra + 1); + double idx = nvalue(ra + 2); + + // Note: make sure the loop condition is exactly the same between this and LOP_FORNLOOP so that we handle NaN/etc. consistently + pc += (step > 0 ? idx <= limit : limit <= idx) ? 0 : LUAU_INSN_D(insn); + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + return pc; +} + +const Instruction* execute_LOP_FORNLOOP(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k) +{ + VM_INTERRUPT(); + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + LUAU_ASSERT(ttisnumber(ra + 0) && ttisnumber(ra + 1) && ttisnumber(ra + 2)); + + double limit = nvalue(ra + 0); + double step = nvalue(ra + 1); + double idx = nvalue(ra + 2) + step; + + setnvalue(ra + 2, idx); + + // Note: make sure the loop condition is exactly the same between this and LOP_FORNPREP so that we handle NaN/etc. consistently + if (step > 0 ? idx <= limit : limit <= idx) + { + pc += LUAU_INSN_D(insn); + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + return pc; + } + else + { + // fallthrough to exit + return pc; + } +} + +const Instruction* execute_LOP_FORGPREP(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k) +{ + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + + if (ttisfunction(ra)) + { + // will be called during FORGLOOP + } + else + { + Table* mt = ttistable(ra) ? hvalue(ra)->metatable : ttisuserdata(ra) ? uvalue(ra)->metatable : cast_to(Table*, NULL); + + if (const TValue* fn = fasttm(L, mt, TM_ITER)) + { + setobj2s(L, ra + 1, ra); + setobj2s(L, ra, fn); + + L->top = ra + 2; // func + self arg + LUAU_ASSERT(L->top <= L->stack_last); + + VM_PROTECT(luaD_call(L, ra, 3)); + L->top = L->ci->top; + + // recompute ra since stack might have been reallocated + ra = VM_REG(LUAU_INSN_A(insn)); + + // protect against __iter returning nil, since nil is used as a marker for builtin iteration in FORGLOOP + if (ttisnil(ra)) + { + VM_PROTECT(luaG_typeerror(L, ra, "call")); + } + } + else if (fasttm(L, mt, TM_CALL)) + { + // table or userdata with __call, will be called during FORGLOOP + // TODO: we might be able to stop supporting this depending on whether it's used in practice + } + else if (ttistable(ra)) + { + // set up registers for builtin iteration + setobj2s(L, ra + 1, ra); + setpvalue(ra + 2, reinterpret_cast(uintptr_t(0))); + setnilvalue(ra); + } + else + { + VM_PROTECT(luaG_typeerror(L, ra, "iterate over")); + } + } + + pc += LUAU_INSN_D(insn); + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + return pc; +} + +const Instruction* execute_LOP_FORGLOOP(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k) +{ + VM_INTERRUPT(); + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + uint32_t aux = *pc; + + // fast-path: builtin table iteration + // note: ra=nil guarantees ra+1=table and ra+2=userdata because of the setup by FORGPREP* opcodes + // TODO: remove the table check per guarantee above + if (ttisnil(ra) && ttistable(ra + 1)) + { + Table* h = hvalue(ra + 1); + int index = int(reinterpret_cast(pvalue(ra + 2))); + + int sizearray = h->sizearray; + + // clear extra variables since we might have more than two + // note: while aux encodes ipairs bit, when set we always use 2 variables, so it's safe to check this via a signed comparison + if (LUAU_UNLIKELY(int(aux) > 2)) + for (int i = 2; i < int(aux); ++i) + setnilvalue(ra + 3 + i); + + // terminate ipairs-style traversal early when encountering nil + if (int(aux) < 0 && (unsigned(index) >= unsigned(sizearray) || ttisnil(&h->array[index]))) + { + pc++; + return pc; + } + + // first we advance index through the array portion + while (unsigned(index) < unsigned(sizearray)) + { + TValue* e = &h->array[index]; + + if (!ttisnil(e)) + { + setpvalue(ra + 2, reinterpret_cast(uintptr_t(index + 1))); + setnvalue(ra + 3, double(index + 1)); + setobj2s(L, ra + 4, e); + + pc += LUAU_INSN_D(insn); + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + return pc; + } + + index++; + } + + int sizenode = 1 << h->lsizenode; + + // then we advance index through the hash portion + while (unsigned(index - sizearray) < unsigned(sizenode)) + { + LuaNode* n = &h->node[index - sizearray]; + + if (!ttisnil(gval(n))) + { + setpvalue(ra + 2, reinterpret_cast(uintptr_t(index + 1))); + getnodekey(L, ra + 3, n); + setobj2s(L, ra + 4, gval(n)); + + pc += LUAU_INSN_D(insn); + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + return pc; + } + + index++; + } + + // fallthrough to exit + pc++; + return pc; + } + else + { + // note: it's safe to push arguments past top for complicated reasons (see top of the file) + setobjs2s(L, ra + 3 + 2, ra + 2); + setobjs2s(L, ra + 3 + 1, ra + 1); + setobjs2s(L, ra + 3, ra); + + L->top = ra + 3 + 3; // func + 2 args (state and index) + LUAU_ASSERT(L->top <= L->stack_last); + + VM_PROTECT(luaD_call(L, ra + 3, uint8_t(aux))); + L->top = L->ci->top; + + // recompute ra since stack might have been reallocated + ra = VM_REG(LUAU_INSN_A(insn)); + + // copy first variable back into the iteration index + setobjs2s(L, ra + 2, ra + 3); + + // note that we need to increment pc by 1 to exit the loop since we need to skip over aux + pc += ttisnil(ra + 3) ? 1 : LUAU_INSN_D(insn); + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + return pc; + } +} + +const Instruction* execute_LOP_FORGPREP_INEXT(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k) +{ + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + + // fast-path: ipairs/inext + if (cl->env->safeenv && ttistable(ra + 1) && ttisnumber(ra + 2) && nvalue(ra + 2) == 0.0) + { + setnilvalue(ra); + // ra+1 is already the table + setpvalue(ra + 2, reinterpret_cast(uintptr_t(0))); + } + else if (!ttisfunction(ra)) + { + VM_PROTECT(luaG_typeerror(L, ra, "iterate over")); + } + + pc += LUAU_INSN_D(insn); + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + return pc; +} + +const Instruction* execute_LOP_DEP_FORGLOOP_INEXT(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k) +{ + LUAU_ASSERT(!"Unsupported deprecated opcode"); + LUAU_UNREACHABLE(); +} + +const Instruction* execute_LOP_FORGPREP_NEXT(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k) +{ + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + + // fast-path: pairs/next + if (cl->env->safeenv && ttistable(ra + 1) && ttisnil(ra + 2)) + { + setnilvalue(ra); + // ra+1 is already the table + setpvalue(ra + 2, reinterpret_cast(uintptr_t(0))); + } + else if (!ttisfunction(ra)) + { + VM_PROTECT(luaG_typeerror(L, ra, "iterate over")); + } + + pc += LUAU_INSN_D(insn); + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + return pc; +} + +const Instruction* execute_LOP_DEP_FORGLOOP_NEXT(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k) +{ + LUAU_ASSERT(!"Unsupported deprecated opcode"); + LUAU_UNREACHABLE(); +} + +const Instruction* execute_LOP_GETVARARGS(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k) +{ + Instruction insn = *pc++; + int b = LUAU_INSN_B(insn) - 1; + int n = cast_int(base - L->ci->func) - cl->l.p->numparams - 1; + + if (b == LUA_MULTRET) + { + VM_PROTECT(luaD_checkstack(L, n)); + StkId ra = VM_REG(LUAU_INSN_A(insn)); // previous call may change the stack + + for (int j = 0; j < n; j++) + setobjs2s(L, ra + j, base - n + j); + + L->top = ra + n; + return pc; + } + else + { + StkId ra = VM_REG(LUAU_INSN_A(insn)); + + for (int j = 0; j < b && j < n; j++) + setobjs2s(L, ra + j, base - n + j); + for (int j = n; j < b; j++) + setnilvalue(ra + j); + return pc; + } +} + +const Instruction* execute_LOP_DUPCLOSURE(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k) +{ + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + TValue* kv = VM_KV(LUAU_INSN_D(insn)); + + Closure* kcl = clvalue(kv); + + // clone closure if the environment is not shared + // note: we save closure to stack early in case the code below wants to capture it by value + Closure* ncl = (kcl->env == cl->env) ? kcl : luaF_newLclosure(L, kcl->nupvalues, cl->env, kcl->l.p); + setclvalue(L, ra, ncl); + + // this loop does three things: + // - if the closure was created anew, it just fills it with upvalues + // - if the closure from the constant table is used, it fills it with upvalues so that it can be shared in the future + // - if the closure is reused, it checks if the reuse is safe via rawequal, and falls back to duplicating the closure + // normally this would use two separate loops, for reuse check and upvalue setup, but MSVC codegen goes crazy if you do that + for (int ui = 0; ui < kcl->nupvalues; ++ui) + { + Instruction uinsn = pc[ui]; + LUAU_ASSERT(LUAU_INSN_OP(uinsn) == LOP_CAPTURE); + LUAU_ASSERT(LUAU_INSN_A(uinsn) == LCT_VAL || LUAU_INSN_A(uinsn) == LCT_UPVAL); + + TValue* uv = (LUAU_INSN_A(uinsn) == LCT_VAL) ? VM_REG(LUAU_INSN_B(uinsn)) : VM_UV(LUAU_INSN_B(uinsn)); + + // check if the existing closure is safe to reuse + if (ncl == kcl && luaO_rawequalObj(&ncl->l.uprefs[ui], uv)) + continue; + + // lazily clone the closure and update the upvalues + if (ncl == kcl && kcl->preload == 0) + { + ncl = luaF_newLclosure(L, kcl->nupvalues, cl->env, kcl->l.p); + setclvalue(L, ra, ncl); + + ui = -1; // restart the loop to fill all upvalues + continue; + } + + // this updates a newly created closure, or an existing closure created during preload, in which case we need a barrier + setobj(L, &ncl->l.uprefs[ui], uv); + luaC_barrier(L, ncl, uv); + } + + // this is a noop if ncl is newly created or shared successfully, but it has to run after the closure is preloaded for the first time + ncl->preload = 0; + + if (kcl != ncl) + VM_PROTECT(luaC_checkGC(L)); + + pc += kcl->nupvalues; + return pc; +} + +const Instruction* execute_LOP_PREPVARARGS(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k) +{ + Instruction insn = *pc++; + int numparams = LUAU_INSN_A(insn); + + // all fixed parameters are copied after the top so we need more stack space + VM_PROTECT(luaD_checkstack(L, cl->stacksize + numparams)); + + // the caller must have filled extra fixed arguments with nil + LUAU_ASSERT(cast_int(L->top - base) >= numparams); + + // move fixed parameters to final position + StkId fixed = base; // first fixed argument + base = L->top; // final position of first argument + + for (int i = 0; i < numparams; ++i) + { + setobjs2s(L, base + i, fixed + i); + setnilvalue(fixed + i); + } + + // rewire our stack frame to point to the new base + L->ci->base = base; + L->ci->top = base + cl->stacksize; + + L->base = base; + L->top = L->ci->top; + return pc; +} + +const Instruction* execute_LOP_JUMPBACK(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k) +{ + VM_INTERRUPT(); + Instruction insn = *pc++; + + pc += LUAU_INSN_D(insn); + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + return pc; +} + +const Instruction* execute_LOP_LOADKX(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k) +{ + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + uint32_t aux = *pc++; + TValue* kv = VM_KV(aux); + + setobj2s(L, ra, kv); + return pc; +} + +const Instruction* execute_LOP_JUMPX(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k) +{ + VM_INTERRUPT(); + Instruction insn = *pc++; + + pc += LUAU_INSN_E(insn); + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + return pc; +} + +const Instruction* execute_LOP_FASTCALL(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k) +{ + Instruction insn = *pc++; + int bfid = LUAU_INSN_A(insn); + int skip = LUAU_INSN_C(insn); + LUAU_ASSERT(unsigned(pc - cl->l.p->code + skip) < unsigned(cl->l.p->sizecode)); + + Instruction call = pc[skip]; + LUAU_ASSERT(LUAU_INSN_OP(call) == LOP_CALL); + + StkId ra = VM_REG(LUAU_INSN_A(call)); + + int nparams = LUAU_INSN_B(call) - 1; + int nresults = LUAU_INSN_C(call) - 1; + + nparams = (nparams == LUA_MULTRET) ? int(L->top - ra - 1) : nparams; + + luau_FastFunction f = luauF_table[bfid]; + + if (cl->env->safeenv && f) + { + VM_PROTECT_PC(); + + int n = f(L, ra, ra + 1, nresults, ra + 2, nparams); + + if (n >= 0) + { + L->top = (nresults == LUA_MULTRET) ? ra + n : L->ci->top; + + pc += skip + 1; // skip instructions that compute function as well as CALL + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + return pc; + } + else + { + // continue execution through the fallback code + return pc; + } + } + else + { + // continue execution through the fallback code + return pc; + } +} + +const Instruction* execute_LOP_COVERAGE(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k) +{ + Instruction insn = *pc++; + int hits = LUAU_INSN_E(insn); + + // update hits with saturated add and patch the instruction in place + hits = (hits < (1 << 23) - 1) ? hits + 1 : hits; + VM_PATCH_E(pc - 1, hits); + + return pc; +} + +const Instruction* execute_LOP_CAPTURE(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k) +{ + LUAU_ASSERT(!"CAPTURE is a pseudo-opcode and must be executed as part of NEWCLOSURE"); + LUAU_UNREACHABLE(); +} + +const Instruction* execute_LOP_DEP_JUMPIFEQK(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k) +{ + LUAU_ASSERT(!"Unsupported deprecated opcode"); + LUAU_UNREACHABLE(); +} + +const Instruction* execute_LOP_DEP_JUMPIFNOTEQK(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k) +{ + LUAU_ASSERT(!"Unsupported deprecated opcode"); + LUAU_UNREACHABLE(); +} + +const Instruction* execute_LOP_FASTCALL1(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k) +{ + Instruction insn = *pc++; + int bfid = LUAU_INSN_A(insn); + TValue* arg = VM_REG(LUAU_INSN_B(insn)); + int skip = LUAU_INSN_C(insn); + + LUAU_ASSERT(unsigned(pc - cl->l.p->code + skip) < unsigned(cl->l.p->sizecode)); + + Instruction call = pc[skip]; + LUAU_ASSERT(LUAU_INSN_OP(call) == LOP_CALL); + + StkId ra = VM_REG(LUAU_INSN_A(call)); + + int nparams = 1; + int nresults = LUAU_INSN_C(call) - 1; + + luau_FastFunction f = luauF_table[bfid]; + + if (cl->env->safeenv && f) + { + VM_PROTECT_PC(); + + int n = f(L, ra, arg, nresults, NULL, nparams); + + if (n >= 0) + { + L->top = (nresults == LUA_MULTRET) ? ra + n : L->ci->top; + + pc += skip + 1; // skip instructions that compute function as well as CALL + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + return pc; + } + else + { + // continue execution through the fallback code + return pc; + } + } + else + { + // continue execution through the fallback code + return pc; + } +} + +const Instruction* execute_LOP_FASTCALL2(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k) +{ + Instruction insn = *pc++; + int bfid = LUAU_INSN_A(insn); + int skip = LUAU_INSN_C(insn) - 1; + uint32_t aux = *pc++; + TValue* arg1 = VM_REG(LUAU_INSN_B(insn)); + TValue* arg2 = VM_REG(aux); + + LUAU_ASSERT(unsigned(pc - cl->l.p->code + skip) < unsigned(cl->l.p->sizecode)); + + Instruction call = pc[skip]; + LUAU_ASSERT(LUAU_INSN_OP(call) == LOP_CALL); + + StkId ra = VM_REG(LUAU_INSN_A(call)); + + int nparams = 2; + int nresults = LUAU_INSN_C(call) - 1; + + luau_FastFunction f = luauF_table[bfid]; + + if (cl->env->safeenv && f) + { + VM_PROTECT_PC(); + + int n = f(L, ra, arg1, nresults, arg2, nparams); + + if (n >= 0) + { + L->top = (nresults == LUA_MULTRET) ? ra + n : L->ci->top; + + pc += skip + 1; // skip instructions that compute function as well as CALL + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + return pc; + } + else + { + // continue execution through the fallback code + return pc; + } + } + else + { + // continue execution through the fallback code + return pc; + } +} + +const Instruction* execute_LOP_FASTCALL2K(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k) +{ + Instruction insn = *pc++; + int bfid = LUAU_INSN_A(insn); + int skip = LUAU_INSN_C(insn) - 1; + uint32_t aux = *pc++; + TValue* arg1 = VM_REG(LUAU_INSN_B(insn)); + TValue* arg2 = VM_KV(aux); + + LUAU_ASSERT(unsigned(pc - cl->l.p->code + skip) < unsigned(cl->l.p->sizecode)); + + Instruction call = pc[skip]; + LUAU_ASSERT(LUAU_INSN_OP(call) == LOP_CALL); + + StkId ra = VM_REG(LUAU_INSN_A(call)); + + int nparams = 2; + int nresults = LUAU_INSN_C(call) - 1; + + luau_FastFunction f = luauF_table[bfid]; + + if (cl->env->safeenv && f) + { + VM_PROTECT_PC(); + + int n = f(L, ra, arg1, nresults, arg2, nparams); + + if (n >= 0) + { + L->top = (nresults == LUA_MULTRET) ? ra + n : L->ci->top; + + pc += skip + 1; // skip instructions that compute function as well as CALL + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + return pc; + } + else + { + // continue execution through the fallback code + return pc; + } + } + else + { + // continue execution through the fallback code + return pc; + } +} + +const Instruction* execute_LOP_BREAK(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k) +{ + LUAU_ASSERT(!"Unsupported deprecated opcode"); + LUAU_UNREACHABLE(); +} + +const Instruction* execute_LOP_JUMPXEQKNIL(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k) +{ + Instruction insn = *pc++; + uint32_t aux = *pc; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + + static_assert(LUA_TNIL == 0, "we expect type-1 to be negative iff type is nil"); + // condition is equivalent to: int(ttisnil(ra)) != (aux >> 31) + pc += int((ttype(ra) - 1) ^ aux) < 0 ? LUAU_INSN_D(insn) : 1; + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + return pc; +} + +const Instruction* execute_LOP_JUMPXEQKB(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k) +{ + Instruction insn = *pc++; + uint32_t aux = *pc; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + + pc += int(ttisboolean(ra) && bvalue(ra) == int(aux & 1)) != (aux >> 31) ? LUAU_INSN_D(insn) : 1; + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + return pc; +} + +const Instruction* execute_LOP_JUMPXEQKN(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k) +{ + Instruction insn = *pc++; + uint32_t aux = *pc; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + TValue* kv = VM_KV(aux & 0xffffff); + LUAU_ASSERT(ttisnumber(kv)); + +#if defined(__aarch64__) + // On several ARM chips (Apple M1/M2, Neoverse N1), comparing the result of a floating-point comparison is expensive, and a branch + // is much cheaper; on some 32-bit ARM chips (Cortex A53) the performance is about the same so we prefer less branchy variant there + if (aux >> 31) + pc += !(ttisnumber(ra) && nvalue(ra) == nvalue(kv)) ? LUAU_INSN_D(insn) : 1; + else + pc += (ttisnumber(ra) && nvalue(ra) == nvalue(kv)) ? LUAU_INSN_D(insn) : 1; +#else + pc += int(ttisnumber(ra) && nvalue(ra) == nvalue(kv)) != (aux >> 31) ? LUAU_INSN_D(insn) : 1; +#endif + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + return pc; +} + +const Instruction* execute_LOP_JUMPXEQKS(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k) +{ + Instruction insn = *pc++; + uint32_t aux = *pc; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + TValue* kv = VM_KV(aux & 0xffffff); + LUAU_ASSERT(ttisstring(kv)); + + pc += int(ttisstring(ra) && gcvalue(ra) == gcvalue(kv)) != (aux >> 31) ? LUAU_INSN_D(insn) : 1; + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + return pc; +} diff --git a/CodeGen/src/Fallbacks.h b/CodeGen/src/Fallbacks.h new file mode 100644 index 000000000..3bec8c5b4 --- /dev/null +++ b/CodeGen/src/Fallbacks.h @@ -0,0 +1,93 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +// This file was generated by 'tools/lvmexecute_split.py' script, do not modify it by hand +#pragma once + +#include + +struct lua_State; +struct Closure; +typedef uint32_t Instruction; +typedef struct lua_TValue TValue; +typedef TValue* StkId; + +const Instruction* execute_LOP_NOP(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k); +const Instruction* execute_LOP_LOADNIL(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k); +const Instruction* execute_LOP_LOADB(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k); +const Instruction* execute_LOP_LOADN(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k); +const Instruction* execute_LOP_LOADK(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k); +const Instruction* execute_LOP_MOVE(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k); +const Instruction* execute_LOP_GETGLOBAL(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k); +const Instruction* execute_LOP_SETGLOBAL(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k); +const Instruction* execute_LOP_GETUPVAL(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k); +const Instruction* execute_LOP_SETUPVAL(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k); +const Instruction* execute_LOP_CLOSEUPVALS(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k); +const Instruction* execute_LOP_GETIMPORT(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k); +const Instruction* execute_LOP_GETTABLEKS(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k); +const Instruction* execute_LOP_SETTABLEKS(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k); +const Instruction* execute_LOP_GETTABLE(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k); +const Instruction* execute_LOP_SETTABLE(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k); +const Instruction* execute_LOP_GETTABLEN(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k); +const Instruction* execute_LOP_SETTABLEN(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k); +const Instruction* execute_LOP_NEWCLOSURE(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k); +const Instruction* execute_LOP_NAMECALL(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k); +const Instruction* execute_LOP_CALL(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k); +const Instruction* execute_LOP_RETURN(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k); +const Instruction* execute_LOP_JUMP(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k); +const Instruction* execute_LOP_JUMPIF(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k); +const Instruction* execute_LOP_JUMPIFNOT(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k); +const Instruction* execute_LOP_JUMPIFEQ(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k); +const Instruction* execute_LOP_JUMPIFNOTEQ(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k); +const Instruction* execute_LOP_JUMPIFLE(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k); +const Instruction* execute_LOP_JUMPIFNOTLE(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k); +const Instruction* execute_LOP_JUMPIFLT(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k); +const Instruction* execute_LOP_JUMPIFNOTLT(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k); +const Instruction* execute_LOP_ADD(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k); +const Instruction* execute_LOP_SUB(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k); +const Instruction* execute_LOP_MUL(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k); +const Instruction* execute_LOP_DIV(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k); +const Instruction* execute_LOP_MOD(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k); +const Instruction* execute_LOP_POW(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k); +const Instruction* execute_LOP_ADDK(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k); +const Instruction* execute_LOP_SUBK(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k); +const Instruction* execute_LOP_MULK(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k); +const Instruction* execute_LOP_DIVK(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k); +const Instruction* execute_LOP_MODK(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k); +const Instruction* execute_LOP_POWK(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k); +const Instruction* execute_LOP_AND(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k); +const Instruction* execute_LOP_OR(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k); +const Instruction* execute_LOP_ANDK(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k); +const Instruction* execute_LOP_ORK(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k); +const Instruction* execute_LOP_CONCAT(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k); +const Instruction* execute_LOP_NOT(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k); +const Instruction* execute_LOP_MINUS(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k); +const Instruction* execute_LOP_LENGTH(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k); +const Instruction* execute_LOP_NEWTABLE(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k); +const Instruction* execute_LOP_DUPTABLE(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k); +const Instruction* execute_LOP_SETLIST(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k); +const Instruction* execute_LOP_FORNPREP(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k); +const Instruction* execute_LOP_FORNLOOP(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k); +const Instruction* execute_LOP_FORGPREP(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k); +const Instruction* execute_LOP_FORGLOOP(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k); +const Instruction* execute_LOP_FORGPREP_INEXT(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k); +const Instruction* execute_LOP_DEP_FORGLOOP_INEXT(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k); +const Instruction* execute_LOP_FORGPREP_NEXT(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k); +const Instruction* execute_LOP_DEP_FORGLOOP_NEXT(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k); +const Instruction* execute_LOP_GETVARARGS(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k); +const Instruction* execute_LOP_DUPCLOSURE(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k); +const Instruction* execute_LOP_PREPVARARGS(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k); +const Instruction* execute_LOP_JUMPBACK(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k); +const Instruction* execute_LOP_LOADKX(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k); +const Instruction* execute_LOP_JUMPX(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k); +const Instruction* execute_LOP_FASTCALL(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k); +const Instruction* execute_LOP_COVERAGE(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k); +const Instruction* execute_LOP_CAPTURE(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k); +const Instruction* execute_LOP_DEP_JUMPIFEQK(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k); +const Instruction* execute_LOP_DEP_JUMPIFNOTEQK(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k); +const Instruction* execute_LOP_FASTCALL1(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k); +const Instruction* execute_LOP_FASTCALL2(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k); +const Instruction* execute_LOP_FASTCALL2K(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k); +const Instruction* execute_LOP_BREAK(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k); +const Instruction* execute_LOP_JUMPXEQKNIL(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k); +const Instruction* execute_LOP_JUMPXEQKB(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k); +const Instruction* execute_LOP_JUMPXEQKN(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k); +const Instruction* execute_LOP_JUMPXEQKS(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k); diff --git a/CodeGen/src/FallbacksProlog.h b/CodeGen/src/FallbacksProlog.h new file mode 100644 index 000000000..bbb06b84b --- /dev/null +++ b/CodeGen/src/FallbacksProlog.h @@ -0,0 +1,56 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "lvm.h" + +#include "lbuiltins.h" +#include "lbytecode.h" +#include "ldebug.h" +#include "ldo.h" +#include "lfunc.h" +#include "lgc.h" +#include "lmem.h" +#include "lnumutils.h" +#include "lstate.h" +#include "lstring.h" +#include "ltable.h" + +#include + +// All external function calls that can cause stack realloc or Lua calls have to be wrapped in VM_PROTECT +// This makes sure that we save the pc (in case the Lua call needs to generate a backtrace) before the call, +// and restores the stack pointer after in case stack gets reallocated +// Should only be used on the slow paths. +#define VM_PROTECT(x) \ + { \ + L->ci->savedpc = pc; \ + { \ + x; \ + }; \ + base = L->base; \ + } + +// Some external functions can cause an error, but never reallocate the stack; for these, VM_PROTECT_PC() is +// a cheaper version of VM_PROTECT that can be called before the external call. +#define VM_PROTECT_PC() L->ci->savedpc = pc + +#define VM_REG(i) (LUAU_ASSERT(unsigned(i) < unsigned(L->top - base)), &base[i]) +#define VM_KV(i) (LUAU_ASSERT(unsigned(i) < unsigned(cl->l.p->sizek)), &k[i]) +#define VM_UV(i) (LUAU_ASSERT(unsigned(i) < unsigned(cl->nupvalues)), &cl->l.uprefs[i]) + +#define VM_PATCH_C(pc, slot) *const_cast(pc) = ((uint8_t(slot) << 24) | (0x00ffffffu & *(pc))) +#define VM_PATCH_E(pc, slot) *const_cast(pc) = ((uint32_t(slot) << 8) | (0x000000ffu & *(pc))) + +#define VM_INTERRUPT() \ + { \ + void (*interrupt)(lua_State*, int) = L->global->cb.interrupt; \ + if (LUAU_UNLIKELY(!!interrupt)) \ + { /* the interrupt hook is called right before we advance pc */ \ + VM_PROTECT(L->ci->savedpc++; interrupt(L, -1)); \ + if (L->status != 0) \ + { \ + L->ci->savedpc--; \ + return NULL; \ + } \ + } \ + } diff --git a/CodeGen/src/UnwindBuilderDwarf2.cpp b/CodeGen/src/UnwindBuilderDwarf2.cpp index 38e3e712f..f3886d9ce 100644 --- a/CodeGen/src/UnwindBuilderDwarf2.cpp +++ b/CodeGen/src/UnwindBuilderDwarf2.cpp @@ -1,6 +1,8 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/UnwindBuilderDwarf2.h" +#include "ByteUtils.h" + #include // General information about Dwarf2 format can be found at: @@ -11,40 +13,6 @@ // https://refspecs.linuxbase.org/elf/x86_64-abi-0.99.pdf [System V Application Binary Interface (AMD64 Architecture Processor Supplement)] // Interaction between Dwarf2 and System V ABI can be found in sections '3.6.2 DWARF Register Number Mapping' and '4.2.4 EH_FRAME sections' -static char* writeu8(char* target, uint8_t value) -{ - memcpy(target, &value, sizeof(value)); - return target + sizeof(value); -} - -static char* writeu32(char* target, uint32_t value) -{ - memcpy(target, &value, sizeof(value)); - return target + sizeof(value); -} - -static char* writeu64(char* target, uint64_t value) -{ - memcpy(target, &value, sizeof(value)); - return target + sizeof(value); -} - -static char* writeuleb128(char* target, uint64_t value) -{ - do - { - char byte = value & 0x7f; - value >>= 7; - - if (value) - byte |= 0x80; - - *target++ = byte; - } while (value); - - return target; -} - // Call frame instruction opcodes #define DW_CFA_advance_loc 0x40 #define DW_CFA_offset 0x80 @@ -104,7 +72,7 @@ const int kFdeInitialLocationOffset = 8; const int kFdeAddressRangeOffset = 16; // Define canonical frame address expression as [reg + offset] -static char* defineCfaExpression(char* pos, int dwReg, uint32_t stackOffset) +static uint8_t* defineCfaExpression(uint8_t* pos, int dwReg, uint32_t stackOffset) { pos = writeu8(pos, DW_CFA_def_cfa); pos = writeuleb128(pos, dwReg); @@ -113,14 +81,14 @@ static char* defineCfaExpression(char* pos, int dwReg, uint32_t stackOffset) } // Update offset value in canonical frame address expression -static char* defineCfaExpressionOffset(char* pos, uint32_t stackOffset) +static uint8_t* defineCfaExpressionOffset(uint8_t* pos, uint32_t stackOffset) { pos = writeu8(pos, DW_CFA_def_cfa_offset); pos = writeuleb128(pos, stackOffset); return pos; } -static char* defineSavedRegisterLocation(char* pos, int dwReg, uint32_t stackOffset) +static uint8_t* defineSavedRegisterLocation(uint8_t* pos, int dwReg, uint32_t stackOffset) { LUAU_ASSERT(stackOffset % kDataAlignFactor == 0 && "stack offsets have to be measured in kDataAlignFactor units"); @@ -138,14 +106,14 @@ static char* defineSavedRegisterLocation(char* pos, int dwReg, uint32_t stackOff return pos; } -static char* advanceLocation(char* pos, uint8_t offset) +static uint8_t* advanceLocation(uint8_t* pos, uint8_t offset) { pos = writeu8(pos, DW_CFA_advance_loc1); pos = writeu8(pos, offset); return pos; } -static char* alignPosition(char* start, char* pos) +static uint8_t* alignPosition(uint8_t* start, uint8_t* pos) { size_t size = pos - start; size_t pad = ((size + kDwarfAlign - 1) & ~(kDwarfAlign - 1)) - size; @@ -163,7 +131,7 @@ namespace CodeGen void UnwindBuilderDwarf2::start() { - char* cieLength = pos; + uint8_t* cieLength = pos; pos = writeu32(pos, 0); // Length (to be filled later) pos = writeu32(pos, 0); // CIE id. 0 -- .eh_frame @@ -245,8 +213,8 @@ void UnwindBuilderDwarf2::finalize(char* target, void* funcAddress, size_t funcS memcpy(target, rawData, getSize()); unsigned fdeEntryStartPos = unsigned(fdeEntryStart - rawData); - writeu64(target + fdeEntryStartPos + kFdeInitialLocationOffset, uintptr_t(funcAddress)); - writeu64(target + fdeEntryStartPos + kFdeAddressRangeOffset, funcSize); + writeu64((uint8_t*)target + fdeEntryStartPos + kFdeInitialLocationOffset, uintptr_t(funcAddress)); + writeu64((uint8_t*)target + fdeEntryStartPos + kFdeAddressRangeOffset, funcSize); } } // namespace CodeGen diff --git a/Compiler/src/ConstantFolding.cpp b/Compiler/src/ConstantFolding.cpp index e35c883ae..510f2b7b9 100644 --- a/Compiler/src/ConstantFolding.cpp +++ b/Compiler/src/ConstantFolding.cpp @@ -3,6 +3,7 @@ #include "BuiltinFolding.h" +#include #include namespace Luau diff --git a/Makefile b/Makefile index bd72cf881..a773af4a3 100644 --- a/Makefile +++ b/Makefile @@ -105,7 +105,7 @@ endif $(AST_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include $(COMPILER_OBJECTS): CXXFLAGS+=-std=c++17 -ICompiler/include -ICommon/include -IAst/include $(ANALYSIS_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -IAnalysis/include -$(CODEGEN_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -ICodeGen/include +$(CODEGEN_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -ICodeGen/include -IVM/include -IVM/src # Code generation needs VM internals $(VM_OBJECTS): CXXFLAGS+=-std=c++11 -ICommon/include -IVM/include $(ISOCLINE_OBJECTS): CXXFLAGS+=-Wno-unused-function -Iextern/isocline/include $(TESTS_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -ICompiler/include -IAnalysis/include -ICodeGen/include -IVM/include -ICLI -Iextern -DDOCTEST_CONFIG_DOUBLE_STRINGIFY diff --git a/Sources.cmake b/Sources.cmake index 580c8b3cf..381004831 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -69,8 +69,13 @@ target_sources(Luau.CodeGen PRIVATE CodeGen/src/AssemblyBuilderX64.cpp CodeGen/src/CodeAllocator.cpp CodeGen/src/CodeBlockUnwind.cpp + CodeGen/src/Fallbacks.cpp CodeGen/src/UnwindBuilderDwarf2.cpp CodeGen/src/UnwindBuilderWin.cpp + + CodeGen/src/ByteUtils.h + CodeGen/src/Fallbacks.h + CodeGen/src/FallbacksProlog.h ) # Luau.Analysis Sources diff --git a/VM/include/lua.h b/VM/include/lua.h index cdd56e96d..0b34bd0a5 100644 --- a/VM/include/lua.h +++ b/VM/include/lua.h @@ -304,6 +304,7 @@ LUA_API size_t lua_totalbytes(lua_State* L, int category); LUA_API l_noret lua_error(lua_State* L); LUA_API int lua_next(lua_State* L, int idx); +LUA_API int lua_rawiter(lua_State* L, int idx, int iter); LUA_API void lua_concat(lua_State* L, int n); @@ -316,6 +317,8 @@ LUA_API void lua_setuserdatadtor(lua_State* L, int tag, void (*dtor)(lua_State*, LUA_API void lua_clonefunction(lua_State* L, int idx); +LUA_API void lua_cleartable(lua_State* L, int idx); + /* ** reference system, can be used to pin objects */ diff --git a/VM/include/luaconf.h b/VM/include/luaconf.h index 7b0f4c30a..1d8d78134 100644 --- a/VM/include/luaconf.h +++ b/VM/include/luaconf.h @@ -109,6 +109,11 @@ #define LUA_MAXCAPTURES 32 #endif +// enables callbacks to redirect code execution from Luau VM to a custom implementation +#ifndef LUA_CUSTOM_EXECUTION +#define LUA_CUSTOM_EXECUTION 0 +#endif + // }================================================================== /* diff --git a/VM/src/lapi.cpp b/VM/src/lapi.cpp index cbcaa3cc0..e5ce4d5a3 100644 --- a/VM/src/lapi.cpp +++ b/VM/src/lapi.cpp @@ -51,6 +51,12 @@ const char* luau_ident = "$Luau: Copyright (C) 2019-2022 Roblox Corporation $\n" L->top++; \ } +#define api_update_top(L, p) \ + { \ + api_check(L, p >= L->base && p < L->ci->top); \ + L->top = p; \ + } + #define updateatom(L, ts) \ { \ if (ts->atom == ATOM_UNDEF) \ @@ -851,7 +857,7 @@ void lua_rawsetfield(lua_State* L, int idx, const char* k) StkId t = index2addr(L, idx); api_check(L, ttistable(t)); if (hvalue(t)->readonly) - luaG_runerror(L, "Attempt to modify a readonly table"); + luaG_readonlyerror(L); setobj2t(L, luaH_setstr(L, hvalue(t), luaS_new(L, k)), L->top - 1); luaC_barriert(L, hvalue(t), L->top - 1); L->top--; @@ -1204,6 +1210,52 @@ int lua_next(lua_State* L, int idx) return more; } +int lua_rawiter(lua_State* L, int idx, int iter) +{ + luaC_threadbarrier(L); + StkId t = index2addr(L, idx); + api_check(L, ttistable(t)); + api_check(L, iter >= 0); + + Table* h = hvalue(t); + int sizearray = h->sizearray; + + // first we advance iter through the array portion + for (; unsigned(iter) < unsigned(sizearray); ++iter) + { + TValue* e = &h->array[iter]; + + if (!ttisnil(e)) + { + StkId top = L->top; + setnvalue(top + 0, double(iter + 1)); + setobj2s(L, top + 1, e); + api_update_top(L, top + 2); + return iter + 1; + } + } + + int sizenode = 1 << h->lsizenode; + + // then we advance iter through the hash portion + for (; unsigned(iter - sizearray) < unsigned(sizenode); ++iter) + { + LuaNode* n = &h->node[iter - sizearray]; + + if (!ttisnil(gval(n))) + { + StkId top = L->top; + getnodekey(L, top + 0, n); + setobj2s(L, top + 1, gval(n)); + api_update_top(L, top + 2); + return iter + 1; + } + } + + // traversal finished + return -1; +} + void lua_concat(lua_State* L, int n) { api_checknelems(L, n); @@ -1376,6 +1428,16 @@ void lua_clonefunction(lua_State* L, int idx) api_incr_top(L); } +void lua_cleartable(lua_State* L, int idx) +{ + StkId t = index2addr(L, idx); + api_check(L, ttistable(t)); + Table* tt = hvalue(t); + if (tt->readonly) + luaG_readonlyerror(L); + luaH_clear(tt); +} + lua_Callbacks* lua_callbacks(lua_State* L) { return &L->global->cb; diff --git a/VM/src/ldebug.cpp b/VM/src/ldebug.cpp index fee9aaa93..7c181d4ec 100644 --- a/VM/src/ldebug.cpp +++ b/VM/src/ldebug.cpp @@ -340,6 +340,11 @@ void luaG_breakpoint(lua_State* L, Proto* p, int line, bool enable) p->code[i] |= op; LUAU_ASSERT(LUAU_INSN_OP(p->code[i]) == op); +#if LUA_CUSTOM_EXECUTION + if (L->global->ecb.setbreakpoint) + L->global->ecb.setbreakpoint(L, p, i); +#endif + // note: this is important! // we only patch the *first* instruction in each proto that's attributed to a given line // this can be changed, but if requires making patching a bit more nuanced so that we don't patch AUX words diff --git a/VM/src/lfunc.cpp b/VM/src/lfunc.cpp index 3c1869b5b..aeb3d710b 100644 --- a/VM/src/lfunc.cpp +++ b/VM/src/lfunc.cpp @@ -31,6 +31,11 @@ Proto* luaF_newproto(lua_State* L) f->source = NULL; f->debugname = NULL; f->debuginsn = NULL; + +#if LUA_CUSTOM_EXECUTION + f->execdata = NULL; +#endif + return f; } @@ -149,6 +154,15 @@ void luaF_freeproto(lua_State* L, Proto* f, lua_Page* page) luaM_freearray(L, f->upvalues, f->sizeupvalues, TString*, f->memcat); if (f->debuginsn) luaM_freearray(L, f->debuginsn, f->sizecode, uint8_t, f->memcat); + +#if LUA_CUSTOM_EXECUTION + if (f->execdata) + { + LUAU_ASSERT(L->global->ecb.destroy); + L->global->ecb.destroy(L, f); + } +#endif + luaM_freegco(L, f, sizeof(Proto), f->memcat, page); } diff --git a/VM/src/lobject.h b/VM/src/lobject.h index 48aaf94b7..97cbfbb11 100644 --- a/VM/src/lobject.h +++ b/VM/src/lobject.h @@ -281,6 +281,10 @@ typedef struct Proto TString* debugname; uint8_t* debuginsn; // a copy of code[] array with just opcodes +#if LUA_CUSTOM_EXECUTION + void* execdata; +#endif + GCObject* gclist; diff --git a/VM/src/lstate.cpp b/VM/src/lstate.cpp index fdd7fc2b8..cfe2cbfb1 100644 --- a/VM/src/lstate.cpp +++ b/VM/src/lstate.cpp @@ -212,6 +212,11 @@ lua_State* lua_newstate(lua_Alloc f, void* ud) g->memcatbytes[0] = sizeof(LG); g->cb = lua_Callbacks(); + +#if LUA_CUSTOM_EXECUTION + g->ecb = lua_ExecutionCallbacks(); +#endif + g->gcstats = GCStats(); #ifdef LUAI_GCMETRICS diff --git a/VM/src/lstate.h b/VM/src/lstate.h index 06544463a..5b7d08836 100644 --- a/VM/src/lstate.h +++ b/VM/src/lstate.h @@ -146,6 +146,19 @@ struct GCMetrics }; #endif +#if LUA_CUSTOM_EXECUTION + +// Callbacks that can be used to to redirect code execution from Luau bytecode VM to a custom implementation (AoT/JiT/sandboxing/...) +typedef struct lua_ExecutionCallbacks +{ + void* context; + void (*destroy)(lua_State* L, Proto* proto); // called when function is destroyed + int (*enter)(lua_State* L, Proto* proto); // called when function is about to start/resume (when execdata is present), return 0 to exit VM + void (*setbreakpoint)(lua_State* L, Proto* proto, int line); // called when a breakpoint is set in a function +} lua_ExecutionCallbacks; + +#endif + /* ** `global state', shared by all threads of this state */ @@ -202,6 +215,10 @@ typedef struct global_State lua_Callbacks cb; +#if LUA_CUSTOM_EXECUTION + lua_ExecutionCallbacks ecb; +#endif + GCStats gcstats; #ifdef LUAI_GCMETRICS diff --git a/VM/src/lvm.h b/VM/src/lvm.h index 25a271661..c4b1c18b5 100644 --- a/VM/src/lvm.h +++ b/VM/src/lvm.h @@ -24,6 +24,9 @@ LUAI_FUNC void luaV_gettable(lua_State* L, const TValue* t, TValue* key, StkId v LUAI_FUNC void luaV_settable(lua_State* L, const TValue* t, TValue* key, StkId val); LUAI_FUNC void luaV_concat(lua_State* L, int total, int last); LUAI_FUNC void luaV_getimport(lua_State* L, Table* env, TValue* k, uint32_t id, bool propagatenil); +LUAI_FUNC void luaV_prepareFORN(lua_State* L, StkId plimit, StkId pstep, StkId pinit); +LUAI_FUNC void luaV_callTM(lua_State* L, int nparams, int res); +LUAI_FUNC void luaV_tryfuncTM(lua_State* L, StkId func); LUAI_FUNC void luau_execute(lua_State* L); LUAI_FUNC int luau_precall(lua_State* L, struct lua_TValue* func, int nresults); diff --git a/VM/src/lvmexecute.cpp b/VM/src/lvmexecute.cpp index c3c744b2e..6ceed5120 100644 --- a/VM/src/lvmexecute.cpp +++ b/VM/src/lvmexecute.cpp @@ -131,84 +131,6 @@ goto dispatchContinue #endif -LUAU_NOINLINE static void luau_prepareFORN(lua_State* L, StkId plimit, StkId pstep, StkId pinit) -{ - if (!ttisnumber(pinit) && !luaV_tonumber(pinit, pinit)) - luaG_forerror(L, pinit, "initial value"); - if (!ttisnumber(plimit) && !luaV_tonumber(plimit, plimit)) - luaG_forerror(L, plimit, "limit"); - if (!ttisnumber(pstep) && !luaV_tonumber(pstep, pstep)) - luaG_forerror(L, pstep, "step"); -} - -// calls a C function f with no yielding support; optionally save one resulting value to the res register -// the function and arguments have to already be pushed to L->top -LUAU_NOINLINE static void luau_callTM(lua_State* L, int nparams, int res) -{ - ++L->nCcalls; - - if (L->nCcalls >= LUAI_MAXCCALLS) - luaD_checkCstack(L); - - luaD_checkstack(L, LUA_MINSTACK); - - StkId top = L->top; - StkId fun = top - nparams - 1; - - CallInfo* ci = incr_ci(L); - ci->func = fun; - ci->base = fun + 1; - ci->top = top + LUA_MINSTACK; - ci->savedpc = NULL; - ci->flags = 0; - ci->nresults = (res >= 0); - LUAU_ASSERT(ci->top <= L->stack_last); - - LUAU_ASSERT(ttisfunction(ci->func)); - LUAU_ASSERT(clvalue(ci->func)->isC); - - L->base = fun + 1; - LUAU_ASSERT(L->top == L->base + nparams); - - lua_CFunction func = clvalue(fun)->c.f; - int n = func(L); - LUAU_ASSERT(n >= 0); // yields should have been blocked by nCcalls - - // ci is our callinfo, cip is our parent - // note that we read L->ci again since it may have been reallocated by the call - CallInfo* cip = L->ci - 1; - - // copy return value into parent stack - if (res >= 0) - { - if (n > 0) - { - setobj2s(L, &cip->base[res], L->top - n); - } - else - { - setnilvalue(&cip->base[res]); - } - } - - L->ci = cip; - L->base = cip->base; - L->top = cip->top; - - --L->nCcalls; -} - -LUAU_NOINLINE static void luau_tryfuncTM(lua_State* L, StkId func) -{ - const TValue* tm = luaT_gettmbyobj(L, func, TM_CALL); - if (!ttisfunction(tm)) - luaG_typeerror(L, func, "call"); - for (StkId p = L->top; p > func; p--) // open space for metamethod - setobjs2s(L, p, p - 1); - L->top++; // stack space pre-allocated by the caller - setobj2s(L, func, tm); // tag method is the new function to be called -} - LUAU_NOINLINE void luau_callhook(lua_State* L, lua_Hook hook, void* userdata) { ptrdiff_t base = savestack(L, L->base); @@ -284,6 +206,20 @@ static void luau_execute(lua_State* L) LUAU_ASSERT(L->isactive); LUAU_ASSERT(!isblack(obj2gco(L))); // we don't use luaC_threadbarrier because active threads never turn black +#if LUA_CUSTOM_EXECUTION + Proto* p = clvalue(L->ci->func)->l.p; + + if (p->execdata) + { + if (L->global->ecb.enter(L, p) == 0) + return; + } + +reentry: +#endif + + LUAU_ASSERT(isLua(L->ci)); + pc = L->ci->savedpc; cl = clvalue(L->ci->func); base = L->base; @@ -564,7 +500,7 @@ static void luau_execute(lua_State* L) L->top = top + 3; L->cachedslot = LUAU_INSN_C(insn); - VM_PROTECT(luau_callTM(L, 2, LUAU_INSN_A(insn))); + VM_PROTECT(luaV_callTM(L, 2, LUAU_INSN_A(insn))); // save cachedslot to accelerate future lookups; patches currently executing instruction since pc-2 rolls back two pc++ VM_PATCH_C(pc - 2, L->cachedslot); VM_NEXT(); @@ -601,7 +537,7 @@ static void luau_execute(lua_State* L) L->top = top + 3; L->cachedslot = LUAU_INSN_C(insn); - VM_PROTECT(luau_callTM(L, 2, LUAU_INSN_A(insn))); + VM_PROTECT(luaV_callTM(L, 2, LUAU_INSN_A(insn))); // save cachedslot to accelerate future lookups; patches currently executing instruction since pc-2 rolls back two pc++ VM_PATCH_C(pc - 2, L->cachedslot); VM_NEXT(); @@ -680,7 +616,7 @@ static void luau_execute(lua_State* L) L->top = top + 4; L->cachedslot = LUAU_INSN_C(insn); - VM_PROTECT(luau_callTM(L, 3, -1)); + VM_PROTECT(luaV_callTM(L, 3, -1)); // save cachedslot to accelerate future lookups; patches currently executing instruction since pc-2 rolls back two pc++ VM_PATCH_C(pc - 2, L->cachedslot); VM_NEXT(); @@ -973,7 +909,7 @@ static void luau_execute(lua_State* L) // slow-path: not a function call if (LUAU_UNLIKELY(!ttisfunction(ra))) { - VM_PROTECT(luau_tryfuncTM(L, ra)); + VM_PROTECT(luaV_tryfuncTM(L, ra)); argtop++; // __call adds an extra self } @@ -1009,6 +945,18 @@ static void luau_execute(lua_State* L) setnilvalue(argi++); // complete missing arguments L->top = p->is_vararg ? argi : ci->top; +#if LUA_CUSTOM_EXECUTION + if (p->execdata) + { + LUAU_ASSERT(L->global->ecb.enter); + + if (L->global->ecb.enter(L, p) == 1) + goto reentry; + else + goto exit; + } +#endif + // reentry pc = p->code; cl = ccl; @@ -1092,11 +1040,26 @@ static void luau_execute(lua_State* L) LUAU_ASSERT(isLua(L->ci)); + Closure* nextcl = clvalue(cip->func); + Proto* nextproto = nextcl->l.p; + +#if LUA_CUSTOM_EXECUTION + if (nextproto->execdata) + { + LUAU_ASSERT(L->global->ecb.enter); + + if (L->global->ecb.enter(L, nextproto) == 1) + goto reentry; + else + goto exit; + } +#endif + // reentry pc = cip->savedpc; - cl = clvalue(cip->func); + cl = nextcl; base = L->base; - k = cl->l.p->k; + k = nextproto->k; VM_NEXT(); } @@ -1212,7 +1175,7 @@ static void luau_execute(lua_State* L) int res = int(top - base); L->top = top + 3; - VM_PROTECT(luau_callTM(L, 2, res)); + VM_PROTECT(luaV_callTM(L, 2, res)); pc += !l_isfalse(&base[res]) ? LUAU_INSN_D(insn) : 1; LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); VM_NEXT(); @@ -1324,7 +1287,7 @@ static void luau_execute(lua_State* L) int res = int(top - base); L->top = top + 3; - VM_PROTECT(luau_callTM(L, 2, res)); + VM_PROTECT(luaV_callTM(L, 2, res)); pc += l_isfalse(&base[res]) ? LUAU_INSN_D(insn) : 1; LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); VM_NEXT(); @@ -1519,7 +1482,7 @@ static void luau_execute(lua_State* L) setobj2s(L, top + 2, rc); L->top = top + 3; - VM_PROTECT(luau_callTM(L, 2, LUAU_INSN_A(insn))); + VM_PROTECT(luaV_callTM(L, 2, LUAU_INSN_A(insn))); VM_NEXT(); } else @@ -1565,7 +1528,7 @@ static void luau_execute(lua_State* L) setobj2s(L, top + 2, rc); L->top = top + 3; - VM_PROTECT(luau_callTM(L, 2, LUAU_INSN_A(insn))); + VM_PROTECT(luaV_callTM(L, 2, LUAU_INSN_A(insn))); VM_NEXT(); } else @@ -1626,7 +1589,7 @@ static void luau_execute(lua_State* L) setobj2s(L, top + 2, rc); L->top = top + 3; - VM_PROTECT(luau_callTM(L, 2, LUAU_INSN_A(insn))); + VM_PROTECT(luaV_callTM(L, 2, LUAU_INSN_A(insn))); VM_NEXT(); } else @@ -1687,7 +1650,7 @@ static void luau_execute(lua_State* L) setobj2s(L, top + 2, rc); L->top = top + 3; - VM_PROTECT(luau_callTM(L, 2, LUAU_INSN_A(insn))); + VM_PROTECT(luaV_callTM(L, 2, LUAU_INSN_A(insn))); VM_NEXT(); } else @@ -1819,7 +1782,7 @@ static void luau_execute(lua_State* L) setobj2s(L, top + 2, kv); L->top = top + 3; - VM_PROTECT(luau_callTM(L, 2, LUAU_INSN_A(insn))); + VM_PROTECT(luaV_callTM(L, 2, LUAU_INSN_A(insn))); VM_NEXT(); } else @@ -1865,7 +1828,7 @@ static void luau_execute(lua_State* L) setobj2s(L, top + 2, kv); L->top = top + 3; - VM_PROTECT(luau_callTM(L, 2, LUAU_INSN_A(insn))); + VM_PROTECT(luaV_callTM(L, 2, LUAU_INSN_A(insn))); VM_NEXT(); } else @@ -2029,7 +1992,7 @@ static void luau_execute(lua_State* L) setobj2s(L, top + 1, rb); L->top = top + 2; - VM_PROTECT(luau_callTM(L, 1, LUAU_INSN_A(insn))); + VM_PROTECT(luaV_callTM(L, 1, LUAU_INSN_A(insn))); VM_NEXT(); } else @@ -2145,7 +2108,7 @@ static void luau_execute(lua_State* L) // Note: this doesn't reallocate stack so we don't need to recompute ra/base VM_PROTECT_PC(); - luau_prepareFORN(L, ra + 0, ra + 1, ra + 2); + luaV_prepareFORN(L, ra + 0, ra + 1, ra + 2); } double limit = nvalue(ra + 0); @@ -2861,7 +2824,7 @@ int luau_precall(lua_State* L, StkId func, int nresults) { if (!ttisfunction(func)) { - luau_tryfuncTM(L, func); + luaV_tryfuncTM(L, func); // L->top is incremented by tryfuncTM } diff --git a/VM/src/lvmutils.cpp b/VM/src/lvmutils.cpp index 33d47020f..35124e632 100644 --- a/VM/src/lvmutils.cpp +++ b/VM/src/lvmutils.cpp @@ -508,3 +508,81 @@ void luaV_dolen(lua_State* L, StkId ra, const TValue* rb) if (!ttisnumber(res)) luaG_runerror(L, "'__len' must return a number"); // note, we can't access rb since stack may have been reallocated } + +LUAU_NOINLINE void luaV_prepareFORN(lua_State* L, StkId plimit, StkId pstep, StkId pinit) +{ + if (!ttisnumber(pinit) && !luaV_tonumber(pinit, pinit)) + luaG_forerror(L, pinit, "initial value"); + if (!ttisnumber(plimit) && !luaV_tonumber(plimit, plimit)) + luaG_forerror(L, plimit, "limit"); + if (!ttisnumber(pstep) && !luaV_tonumber(pstep, pstep)) + luaG_forerror(L, pstep, "step"); +} + +// calls a C function f with no yielding support; optionally save one resulting value to the res register +// the function and arguments have to already be pushed to L->top +LUAU_NOINLINE void luaV_callTM(lua_State* L, int nparams, int res) +{ + ++L->nCcalls; + + if (L->nCcalls >= LUAI_MAXCCALLS) + luaD_checkCstack(L); + + luaD_checkstack(L, LUA_MINSTACK); + + StkId top = L->top; + StkId fun = top - nparams - 1; + + CallInfo* ci = incr_ci(L); + ci->func = fun; + ci->base = fun + 1; + ci->top = top + LUA_MINSTACK; + ci->savedpc = NULL; + ci->flags = 0; + ci->nresults = (res >= 0); + LUAU_ASSERT(ci->top <= L->stack_last); + + LUAU_ASSERT(ttisfunction(ci->func)); + LUAU_ASSERT(clvalue(ci->func)->isC); + + L->base = fun + 1; + LUAU_ASSERT(L->top == L->base + nparams); + + lua_CFunction func = clvalue(fun)->c.f; + int n = func(L); + LUAU_ASSERT(n >= 0); // yields should have been blocked by nCcalls + + // ci is our callinfo, cip is our parent + // note that we read L->ci again since it may have been reallocated by the call + CallInfo* cip = L->ci - 1; + + // copy return value into parent stack + if (res >= 0) + { + if (n > 0) + { + setobj2s(L, &cip->base[res], L->top - n); + } + else + { + setnilvalue(&cip->base[res]); + } + } + + L->ci = cip; + L->base = cip->base; + L->top = cip->top; + + --L->nCcalls; +} + +LUAU_NOINLINE void luaV_tryfuncTM(lua_State* L, StkId func) +{ + const TValue* tm = luaT_gettmbyobj(L, func, TM_CALL); + if (!ttisfunction(tm)) + luaG_typeerror(L, func, "call"); + for (StkId p = L->top; p > func; p--) // open space for metamethod + setobjs2s(L, p, p - 1); + L->top++; // stack space pre-allocated by the caller + setobj2s(L, func, tm); // tag method is the new function to be called +} diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index e6222b029..9409c8222 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -62,6 +62,29 @@ LOADN R0 5 LOADK R1 K0 RETURN R0 2 )"); + + CHECK_EQ("\n" + bcb.dumpEverything(), R"( +Function 0 (??): +LOADN R0 5 +LOADK R1 K0 +RETURN R0 2 + +)"); +} + +TEST_CASE("CompileError") +{ + std::string source = "local " + rep("a,", 300) + "a = ..."; + + // fails to parse + std::string bc1 = Luau::compile(source + " !#*$!#$^&!*#&$^*"); + + // parses, but fails to compile (too many locals) + std::string bc2 = Luau::compile(source); + + // 0 acts as a special marker for error bytecode + CHECK_EQ(bc1[0], 0); + CHECK_EQ(bc2[0], 0); } TEST_CASE("LocalsDirectReference") @@ -1230,6 +1253,27 @@ RETURN R0 0 )"); } +TEST_CASE("UnaryBasic") +{ + CHECK_EQ("\n" + compileFunction0("local a = ... return not a"), R"( +GETVARARGS R0 1 +NOT R1 R0 +RETURN R1 1 +)"); + + CHECK_EQ("\n" + compileFunction0("local a = ... return -a"), R"( +GETVARARGS R0 1 +MINUS R1 R0 +RETURN R1 1 +)"); + + CHECK_EQ("\n" + compileFunction0("local a = ... return #a"), R"( +GETVARARGS R0 1 +LENGTH R1 R0 +RETURN R1 1 +)"); +} + TEST_CASE("InterpStringWithNoExpressions") { ScopedFastFlag sff{"LuauInterpolatedStringBaseSupport", true}; @@ -4975,6 +5019,27 @@ MOVE R1 R0 CALL R1 0 1 RETURN R1 1 )"); + + // we can't inline any functions in modules with getfenv/setfenv + CHECK_EQ("\n" + compileFunction(R"( +local function foo() + return 42 +end + +local x = foo() +getfenv() +return x +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +MOVE R1 R0 +CALL R1 0 1 +GETIMPORT R2 2 +CALL R2 0 0 +RETURN R1 1 +)"); + } TEST_CASE("InlineNestedLoops") @@ -6101,6 +6166,7 @@ return bit32.extract(-1, 31), bit32.replace(100, 1, 0), math.log(100, 10), + typeof(nil), (type("fin")) )", 0, 2), @@ -6156,7 +6222,8 @@ LOADN R47 1 LOADN R48 101 LOADN R49 2 LOADK R50 K3 -RETURN R0 51 +LOADK R51 K4 +RETURN R0 52 )"); } diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index 25129bffb..77b30487b 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -778,6 +778,49 @@ TEST_CASE("ApiTables") CHECK(strcmp(lua_tostring(L, -1), "test") == 0); lua_pop(L, 1); + // lua_cleartable + lua_cleartable(L, -1); + lua_pushnil(L); + CHECK(lua_next(L, -2) == 0); + + lua_pop(L, 1); +} + +TEST_CASE("ApiIter") +{ + StateRef globalState(luaL_newstate(), lua_close); + lua_State* L = globalState.get(); + + lua_newtable(L); + lua_pushnumber(L, 123.0); + lua_setfield(L, -2, "key"); + lua_pushnumber(L, 456.0); + lua_rawsetfield(L, -2, "key2"); + lua_pushstring(L, "test"); + lua_rawseti(L, -2, 1); + + // Lua-compatible iteration interface: lua_next + double sum1 = 0; + lua_pushnil(L); + while (lua_next(L, -2)) + { + sum1 += lua_tonumber(L, -2); // key + sum1 += lua_tonumber(L, -1); // value + lua_pop(L, 1); // pop value, key is used by lua_next + } + CHECK(sum1 == 580); + + // Luau iteration interface: lua_rawiter (faster and preferable to lua_next) + double sum2 = 0; + for (int index = 0; index = lua_rawiter(L, -1, index), index >= 0; ) + { + sum2 += lua_tonumber(L, -2); // key + sum2 += lua_tonumber(L, -1); // value + lua_pop(L, 2); // pop both key and value + } + CHECK(sum2 == 580); + + // pop table lua_pop(L, 1); } diff --git a/tests/Linter.test.cpp b/tests/Linter.test.cpp index 8c7d762ed..921e6691c 100644 --- a/tests/Linter.test.cpp +++ b/tests/Linter.test.cpp @@ -1705,8 +1705,6 @@ TEST_CASE_FIXTURE(Fixture, "TestStringInterpolation") TEST_CASE_FIXTURE(Fixture, "IntegerParsing") { - ScopedFastFlag luauLintParseIntegerIssues{"LuauLintParseIntegerIssues", true}; - LintResult result = lint(R"( local _ = 0b10000000000000000000000000000000000000000000000000000000000000000 local _ = 0x10000000000000000 @@ -1720,7 +1718,6 @@ local _ = 0x10000000000000000 // TODO: remove with FFlagLuauErrorDoubleHexPrefix TEST_CASE_FIXTURE(Fixture, "IntegerParsingDoublePrefix") { - ScopedFastFlag luauLintParseIntegerIssues{"LuauLintParseIntegerIssues", true}; ScopedFastFlag luauErrorDoubleHexPrefix{"LuauErrorDoubleHexPrefix", false}; // Lint will be available until we start rejecting code LintResult result = lint(R"( diff --git a/tests/Normalize.test.cpp b/tests/Normalize.test.cpp index 42e9a9336..31df707d7 100644 --- a/tests/Normalize.test.cpp +++ b/tests/Normalize.test.cpp @@ -2,6 +2,7 @@ #include "Fixture.h" +#include "Luau/Common.h" #include "doctest.h" #include "Luau/Normalize.h" @@ -747,7 +748,6 @@ TEST_CASE_FIXTURE(Fixture, "cyclic_union") { ScopedFastFlag sff[] = { {"LuauLowerBoundsCalculation", true}, - {"LuauFixNormalizationOfCyclicUnions", true}, }; CheckResult result = check(R"( @@ -765,7 +765,6 @@ TEST_CASE_FIXTURE(Fixture, "cyclic_intersection") { ScopedFastFlag sff[] = { {"LuauLowerBoundsCalculation", true}, - {"LuauFixNormalizationOfCyclicUnions", true}, }; CheckResult result = check(R"( @@ -784,7 +783,6 @@ TEST_CASE_FIXTURE(Fixture, "intersection_of_tables_with_indexers") { ScopedFastFlag sff[] = { {"LuauLowerBoundsCalculation", true}, - {"LuauFixNormalizationOfCyclicUnions", true}, }; CheckResult result = check(R"( diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index 44a6b4acf..b4064cfb5 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -682,7 +682,6 @@ TEST_CASE_FIXTURE(Fixture, "parse_numbers_binary") TEST_CASE_FIXTURE(Fixture, "parse_numbers_error") { - ScopedFastFlag luauLintParseIntegerIssues{"LuauLintParseIntegerIssues", true}; ScopedFastFlag luauErrorDoubleHexPrefix{"LuauErrorDoubleHexPrefix", true}; CHECK_EQ(getParseError("return 0b123"), "Malformed number"); @@ -695,7 +694,6 @@ TEST_CASE_FIXTURE(Fixture, "parse_numbers_error") TEST_CASE_FIXTURE(Fixture, "parse_numbers_error_soft") { - ScopedFastFlag luauLintParseIntegerIssues{"LuauLintParseIntegerIssues", true}; ScopedFastFlag luauErrorDoubleHexPrefix{"LuauErrorDoubleHexPrefix", false}; CHECK_EQ(getParseError("return 0x0x0x0x0x0x0x0"), "Malformed number"); diff --git a/tests/TypeInfer.aliases.test.cpp b/tests/TypeInfer.aliases.test.cpp index dd91467d5..834391a75 100644 --- a/tests/TypeInfer.aliases.test.cpp +++ b/tests/TypeInfer.aliases.test.cpp @@ -198,8 +198,12 @@ TEST_CASE_FIXTURE(Fixture, "generic_aliases") LUAU_REQUIRE_ERROR_COUNT(1, result); + const char* expectedError = "Type '{ v: string }' could not be converted into 'T'\n" + "caused by:\n" + " Property 'v' is not compatible. Type 'string' could not be converted into 'number'"; + CHECK(result.errors[0].location == Location{{4, 31}, {4, 44}}); - CHECK(toString(result.errors[0]) == "Type '{ v: string }' could not be converted into 'T'"); + CHECK(toString(result.errors[0]) == expectedError); } TEST_CASE_FIXTURE(Fixture, "dependent_generic_aliases") @@ -215,8 +219,14 @@ TEST_CASE_FIXTURE(Fixture, "dependent_generic_aliases") LUAU_REQUIRE_ERROR_COUNT(1, result); + const char* expectedError = "Type '{ t: { v: string } }' could not be converted into 'U'\n" + "caused by:\n" + " Property 't' is not compatible. Type '{ v: string }' could not be converted into 'T'\n" + "caused by:\n" + " Property 'v' is not compatible. Type 'string' could not be converted into 'number'"; + CHECK(result.errors[0].location == Location{{4, 31}, {4, 52}}); - CHECK(toString(result.errors[0]) == "Type '{ t: { v: string } }' could not be converted into 'U'"); + CHECK(toString(result.errors[0]) == expectedError); } TEST_CASE_FIXTURE(Fixture, "mutually_recursive_generic_aliases") diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index 40ea0ca11..2c4c35040 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -558,11 +558,10 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "greedy_inference_with_shared_self_triggers_f CHECK_EQ("Not all codepaths in this function return 'self, a...'.", toString(result.errors[0])); } -TEST_CASE_FIXTURE(Fixture, "dcr_cant_partially_dispatch_a_constraint") +TEST_CASE_FIXTURE(Fixture, "dcr_can_partially_dispatch_a_constraint") { ScopedFastFlag sff[] = { {"DebugLuauDeferredConstraintResolution", true}, - {"LuauSpecialTypesAsterisked", true}, }; CheckResult result = check(R"( @@ -577,7 +576,6 @@ TEST_CASE_FIXTURE(Fixture, "dcr_cant_partially_dispatch_a_constraint") LUAU_REQUIRE_NO_ERRORS(result); - // We should be able to resolve this to number, but we're not there yet. // Solving this requires recognizing that we can partially solve the // following constraint: // @@ -586,7 +584,7 @@ TEST_CASE_FIXTURE(Fixture, "dcr_cant_partially_dispatch_a_constraint") // The correct thing for us to do is to consider the constraint dispatched, // but we need to also record a new constraint number <: *blocked* to finish // the job later. - CHECK("(a, *error-type*) -> ()" == toString(requireType("prime_iter"))); + CHECK("(a, number) -> ()" == toString(requireType("prime_iter"))); } TEST_CASE_FIXTURE(Fixture, "free_options_cannot_be_unified_together") diff --git a/tests/TypeInfer.singletons.test.cpp b/tests/TypeInfer.singletons.test.cpp index 7d98b5db7..5ee956d7b 100644 --- a/tests/TypeInfer.singletons.test.cpp +++ b/tests/TypeInfer.singletons.test.cpp @@ -314,6 +314,31 @@ caused by: toString(result.errors[0])); } +TEST_CASE_FIXTURE(Fixture, "parametric_tagged_union_alias") +{ + ScopedFastFlag sff[] = { + {"DebugLuauDeferredConstraintResolution", true}, + }; + + CheckResult result = check(R"( + type Ok = {success: true, result: T} + type Err = {success: false, error: T} + type Result = Ok | Err + + local a : Result = {success = false, result = "hotdogs"} + local b : Result = {success = true, result = "hotdogs"} + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + const std::string expectedError = "Type '{ result: string, success: false }' could not be converted into 'Err | Ok'\n" + "caused by:\n" + " None of the union options are compatible. For example: Table type '{ result: string, success: false }'" + " not compatible with type 'Err' because the former is missing field 'error'"; + + CHECK(toString(result.errors[0]) == expectedError); +} + TEST_CASE_FIXTURE(Fixture, "if_then_else_expression_singleton_options") { CheckResult result = check(R"( diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index e8bfb67f9..a96c36760 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -1082,7 +1082,7 @@ TEST_CASE_FIXTURE(Fixture, "do_not_bind_a_free_table_to_a_union_containing_that_ )"); } -TEST_CASE_FIXTURE(Fixture, "types stored in astResolvedTypes") +TEST_CASE_FIXTURE(Fixture, "types_stored_in_astResolvedTypes") { CheckResult result = check(R"( type alias = typeof("hello") @@ -1122,4 +1122,41 @@ TEST_CASE_FIXTURE(Fixture, "bidirectional_checking_of_higher_order_function") CHECK(location.end.line == 4); } +TEST_CASE_FIXTURE(Fixture, "dcr_can_partially_dispatch_a_constraint") +{ + ScopedFastFlag sff[] = { + {"DebugLuauDeferredConstraintResolution", true}, + }; + + CheckResult result = check(R"( + local function hasDivisors(value: number) + end + + function prime_iter(state, index) + hasDivisors(index) + index += 1 + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + // Solving this requires recognizing that we can't dispatch a constraint + // like this without doing further work: + // + // (*blocked*) -> () <: (number) -> (b...) + // + // We solve this by searching both types for BlockedTypeVars and block the + // constraint on any we find. It also gets the job done, but I'm worried + // about the efficiency of doing so many deep type traversals and it may + // make us more prone to getting stuck on constraint cycles. + // + // If this doesn't pan out, a possible solution is to go further down the + // path of supporting partial constraint dispatch. The way it would work is + // that we'd dispatch the above constraint by binding b... to (), but we + // would append a new constraint number <: *blocked* to the constraint set + // to be solved later. This should be faster and theoretically less prone + // to cyclic constraint dependencies. + CHECK("(a, number) -> ()" == toString(requireType("prime_iter"))); +} + TEST_SUITE_END(); diff --git a/tools/faillist.txt b/tools/faillist.txt index 9c2df8059..825fb2f68 100644 --- a/tools/faillist.txt +++ b/tools/faillist.txt @@ -10,7 +10,6 @@ AnnotationTests.luau_ice_triggers_an_ice_handler AnnotationTests.luau_print_is_magic_if_the_flag_is_set AnnotationTests.occurs_check_on_cyclic_intersection_typevar AnnotationTests.occurs_check_on_cyclic_union_typevar -AnnotationTests.too_many_type_params AnnotationTests.two_type_params AnnotationTests.use_type_required_from_another_file AstQuery.last_argument_function_call_type @@ -28,7 +27,6 @@ AutocompleteTest.autocomplete_interpolated_string AutocompleteTest.autocomplete_on_string_singletons AutocompleteTest.autocomplete_oop_implicit_self AutocompleteTest.autocomplete_repeat_middle_keyword -AutocompleteTest.autocomplete_string_singleton_equality AutocompleteTest.autocomplete_string_singleton_escape AutocompleteTest.autocomplete_string_singletons AutocompleteTest.autocomplete_while_middle_keywords @@ -85,7 +83,6 @@ AutocompleteTest.type_correct_expected_argument_type_pack_suggestion AutocompleteTest.type_correct_expected_argument_type_suggestion AutocompleteTest.type_correct_expected_argument_type_suggestion_optional AutocompleteTest.type_correct_expected_argument_type_suggestion_self -AutocompleteTest.type_correct_expected_return_type_pack_suggestion AutocompleteTest.type_correct_expected_return_type_suggestion AutocompleteTest.type_correct_full_type_suggestion AutocompleteTest.type_correct_function_no_parenthesis @@ -113,7 +110,6 @@ BuiltinTests.dont_add_definitions_to_persistent_types BuiltinTests.find_capture_types BuiltinTests.find_capture_types2 BuiltinTests.find_capture_types3 -BuiltinTests.getfenv BuiltinTests.global_singleton_types_are_sealed BuiltinTests.gmatch_capture_types BuiltinTests.gmatch_capture_types2 @@ -130,13 +126,9 @@ BuiltinTests.next_iterator_should_infer_types_and_type_check BuiltinTests.os_time_takes_optional_date_table BuiltinTests.pairs_iterator_should_infer_types_and_type_check BuiltinTests.see_thru_select -BuiltinTests.see_thru_select_count -BuiltinTests.select_on_variadic BuiltinTests.select_slightly_out_of_range BuiltinTests.select_way_out_of_range BuiltinTests.select_with_decimal_argument_is_rounded_down -BuiltinTests.select_with_variadic_typepack_tail -BuiltinTests.select_with_variadic_typepack_tail_and_string_head BuiltinTests.set_metatable_needs_arguments BuiltinTests.setmetatable_should_not_mutate_persisted_types BuiltinTests.sort @@ -147,20 +139,16 @@ BuiltinTests.string_format_arg_types_inference BuiltinTests.string_format_as_method BuiltinTests.string_format_correctly_ordered_types BuiltinTests.string_format_report_all_type_errors_at_correct_positions -BuiltinTests.string_format_tostring_specifier -BuiltinTests.string_format_tostring_specifier_type_constraint BuiltinTests.string_format_use_correct_argument BuiltinTests.string_format_use_correct_argument2 BuiltinTests.string_format_use_correct_argument3 BuiltinTests.string_lib_self_noself BuiltinTests.table_concat_returns_string -BuiltinTests.table_dot_remove_optionally_returns_generic BuiltinTests.table_freeze_is_generic BuiltinTests.table_insert_correctly_infers_type_of_array_2_args_overload BuiltinTests.table_insert_correctly_infers_type_of_array_3_args_overload BuiltinTests.table_pack BuiltinTests.table_pack_reduce -BuiltinTests.table_pack_variadic BuiltinTests.tonumber_returns_optional_number_type BuiltinTests.tonumber_returns_optional_number_type2 DefinitionTests.class_definition_overload_metamethods @@ -168,7 +156,6 @@ DefinitionTests.declaring_generic_functions DefinitionTests.definition_file_classes FrontendTest.ast_node_at_position FrontendTest.automatically_check_dependent_scripts -FrontendTest.dont_reparse_clean_file_when_linting FrontendTest.environments FrontendTest.imported_table_modification_2 FrontendTest.it_should_be_safe_to_stringify_errors_when_full_type_graph_is_discarded @@ -187,11 +174,8 @@ GenericsTests.check_mutual_generic_functions GenericsTests.correctly_instantiate_polymorphic_member_functions GenericsTests.do_not_always_instantiate_generic_intersection_types GenericsTests.do_not_infer_generic_functions -GenericsTests.dont_unify_bound_types GenericsTests.duplicate_generic_type_packs GenericsTests.duplicate_generic_types -GenericsTests.error_detailed_function_mismatch_generic_pack -GenericsTests.error_detailed_function_mismatch_generic_types GenericsTests.factories_of_generics GenericsTests.generic_argument_count_too_few GenericsTests.generic_argument_count_too_many @@ -205,30 +189,22 @@ GenericsTests.generic_type_pack_unification2 GenericsTests.generic_type_pack_unification3 GenericsTests.infer_generic_function_function_argument GenericsTests.infer_generic_function_function_argument_overloaded -GenericsTests.infer_generic_lib_function_function_argument GenericsTests.infer_generic_methods +GenericsTests.inferred_local_vars_can_be_polytypes GenericsTests.instantiate_cyclic_generic_function GenericsTests.instantiate_generic_function_in_assignments GenericsTests.instantiate_generic_function_in_assignments2 GenericsTests.instantiated_function_argument_names GenericsTests.instantiation_sharing_types GenericsTests.local_vars_can_be_instantiated_polytypes -GenericsTests.mutable_state_polymorphism GenericsTests.no_stack_overflow_from_quantifying GenericsTests.properties_can_be_instantiated_polytypes -GenericsTests.rank_N_types_via_typeof GenericsTests.reject_clashing_generic_and_pack_names GenericsTests.self_recursive_instantiated_param -IntersectionTypes.argument_is_intersection -IntersectionTypes.error_detailed_intersection_all -IntersectionTypes.error_detailed_intersection_part -IntersectionTypes.fx_intersection_as_argument -IntersectionTypes.fx_union_as_argument_fails IntersectionTypes.index_on_an_intersection_type_with_mixed_types IntersectionTypes.index_on_an_intersection_type_with_property_guaranteed_to_exist IntersectionTypes.index_on_an_intersection_type_works_at_arbitrary_depth IntersectionTypes.no_stack_overflow_from_flattenintersection -IntersectionTypes.overload_is_not_a_function IntersectionTypes.select_correct_union_fn IntersectionTypes.should_still_pick_an_overload_whose_arguments_are_unions IntersectionTypes.table_intersection_write_sealed @@ -236,17 +212,13 @@ IntersectionTypes.table_intersection_write_sealed_indirect IntersectionTypes.table_write_sealed_indirect isSubtype.intersection_of_tables isSubtype.table_with_table_prop -Linter.TableOperations ModuleTests.clone_self_property ModuleTests.deepClone_cyclic_table ModuleTests.do_not_clone_reexports -NonstrictModeTests.delay_function_does_not_require_its_argument_to_return_anything NonstrictModeTests.for_in_iterator_variables_are_any NonstrictModeTests.function_parameters_are_any NonstrictModeTests.inconsistent_module_return_types_are_ok -NonstrictModeTests.inconsistent_return_types_are_ok NonstrictModeTests.infer_nullary_function -NonstrictModeTests.infer_the_maximum_number_of_values_the_function_could_return NonstrictModeTests.inline_table_props_are_also_any NonstrictModeTests.local_tables_are_not_any NonstrictModeTests.locals_are_any_by_default @@ -294,20 +266,16 @@ ProvisionalTests.greedy_inference_with_shared_self_triggers_function_with_no_ret ProvisionalTests.invariant_table_properties_means_instantiating_tables_in_call_is_unsound ProvisionalTests.it_should_be_agnostic_of_actual_size ProvisionalTests.lower_bounds_calculation_is_too_permissive_with_overloaded_higher_order_functions -ProvisionalTests.lvalue_equals_another_lvalue_with_no_overlap ProvisionalTests.normalization_fails_on_certain_kinds_of_cyclic_tables -ProvisionalTests.operator_eq_completely_incompatible ProvisionalTests.pcall_returns_at_least_two_value_but_function_returns_nothing ProvisionalTests.setmetatable_constrains_free_type_into_free_table ProvisionalTests.typeguard_inference_incomplete ProvisionalTests.weirditer_should_not_loop_forever ProvisionalTests.while_body_are_also_refined RefinementTest.and_constraint -RefinementTest.and_or_peephole_refinement RefinementTest.apply_refinements_on_astexprindexexpr_whose_subscript_expr_is_constant_string RefinementTest.assert_a_to_be_truthy_then_assert_a_to_be_number RefinementTest.assert_non_binary_expressions_actually_resolve_constraints -RefinementTest.assign_table_with_refined_property_with_a_similar_type_is_illegal RefinementTest.call_a_more_specific_function_using_typeguard RefinementTest.correctly_lookup_a_shadowed_local_that_which_was_previously_refined RefinementTest.correctly_lookup_property_whose_base_was_previously_refined @@ -319,24 +287,19 @@ RefinementTest.discriminate_tag RefinementTest.either_number_or_string RefinementTest.eliminate_subclasses_of_instance RefinementTest.falsiness_of_TruthyPredicate_narrows_into_nil -RefinementTest.free_type_is_equal_to_an_lvalue -RefinementTest.impossible_type_narrow_is_not_an_error RefinementTest.index_on_a_refined_property RefinementTest.invert_is_truthy_constraint RefinementTest.invert_is_truthy_constraint_ifelse_expression RefinementTest.is_truthy_constraint RefinementTest.is_truthy_constraint_ifelse_expression -RefinementTest.lvalue_is_equal_to_a_term -RefinementTest.lvalue_is_equal_to_another_lvalue RefinementTest.lvalue_is_not_nil RefinementTest.merge_should_be_fully_agnostic_of_hashmap_ordering +RefinementTest.narrow_boolean_to_true_or_false RefinementTest.narrow_property_of_a_bounded_variable RefinementTest.narrow_this_large_union RefinementTest.nonoptional_type_can_narrow_to_nil_if_sense_is_true RefinementTest.not_a_and_not_b RefinementTest.not_a_and_not_b2 -RefinementTest.not_a_or_not_b -RefinementTest.not_a_or_not_b2 RefinementTest.not_and_constraint RefinementTest.not_t_or_some_prop_of_t RefinementTest.or_predicate_with_truthy_predicates @@ -344,7 +307,6 @@ RefinementTest.parenthesized_expressions_are_followed_through RefinementTest.refine_a_property_not_to_be_nil_through_an_intersection_table RefinementTest.refine_the_correct_types_opposite_of_when_a_is_not_number_or_string RefinementTest.refine_unknowns -RefinementTest.string_not_equal_to_string_or_nil RefinementTest.term_is_equal_to_an_lvalue RefinementTest.truthy_constraint_on_properties RefinementTest.type_assertion_expr_carry_its_constraints @@ -363,11 +325,9 @@ RefinementTest.typeguard_narrows_for_functions RefinementTest.typeguard_narrows_for_table RefinementTest.typeguard_not_to_be_string RefinementTest.typeguard_only_look_up_types_from_global_scope -RefinementTest.unknown_lvalue_is_not_synonymous_with_other_on_not_equal RefinementTest.what_nonsensical_condition RefinementTest.x_as_any_if_x_is_instance_elseif_x_is_table RefinementTest.x_is_not_instance_or_else_not_part -RuntimeLimits.typescript_port_of_Result_type TableTests.a_free_shape_can_turn_into_a_scalar_if_it_is_compatible TableTests.a_free_shape_cannot_turn_into_a_scalar_if_it_is_not_compatible TableTests.access_index_metamethod_that_returns_variadic @@ -383,7 +343,6 @@ TableTests.casting_sealed_tables_with_props_into_table_with_indexer TableTests.casting_tables_with_props_into_table_with_indexer3 TableTests.casting_tables_with_props_into_table_with_indexer4 TableTests.checked_prop_too_early -TableTests.confusing_indexing TableTests.defining_a_method_for_a_builtin_sealed_table_must_fail TableTests.defining_a_method_for_a_local_sealed_table_must_fail TableTests.defining_a_self_method_for_a_builtin_sealed_table_must_fail @@ -393,11 +352,7 @@ TableTests.dont_hang_when_trying_to_look_up_in_cyclic_metatable_index TableTests.dont_leak_free_table_props TableTests.dont_quantify_table_that_belongs_to_outer_scope TableTests.dont_suggest_exact_match_keys -TableTests.error_detailed_indexer_key -TableTests.error_detailed_indexer_value TableTests.error_detailed_metatable_prop -TableTests.error_detailed_prop -TableTests.error_detailed_prop_nested TableTests.expected_indexer_from_table_union TableTests.expected_indexer_value_type_extra TableTests.expected_indexer_value_type_extra_2 @@ -422,11 +377,7 @@ TableTests.infer_indexer_from_value_property_in_literal TableTests.inferred_return_type_of_free_table TableTests.inferring_crazy_table_should_also_be_quick TableTests.instantiate_table_cloning_3 -TableTests.instantiate_tables_at_scope_level TableTests.leaking_bad_metatable_errors -TableTests.length_operator_intersection -TableTests.length_operator_non_table_union -TableTests.length_operator_union TableTests.length_operator_union_errors TableTests.less_exponential_blowup_please TableTests.meta_add @@ -444,16 +395,15 @@ TableTests.open_table_unification_2 TableTests.pass_a_union_of_tables_to_a_function_that_requires_a_table TableTests.pass_a_union_of_tables_to_a_function_that_requires_a_table_2 TableTests.pass_incompatible_union_to_a_generic_table_without_crashing +TableTests.passing_compatible_unions_to_a_generic_table_without_crashing TableTests.persistent_sealed_table_is_immutable TableTests.prop_access_on_key_whose_types_mismatches TableTests.property_lookup_through_tabletypevar_metatable TableTests.quantify_even_that_table_was_never_exported_at_all -TableTests.quantify_metatables_of_metatables_of_table TableTests.quantifying_a_bound_var_works TableTests.reasonable_error_when_adding_a_nonexistent_property_to_an_array_like_table TableTests.result_is_always_any_if_lhs_is_any TableTests.result_is_bool_for_equality_operators_if_lhs_is_any -TableTests.right_table_missing_key TableTests.right_table_missing_key2 TableTests.scalar_is_a_subtype_of_a_compatible_polymorphic_shape_type TableTests.scalar_is_not_a_subtype_of_a_compatible_polymorphic_shape_type @@ -463,14 +413,11 @@ TableTests.shared_selfs_through_metatables TableTests.table_indexing_error_location TableTests.table_insert_should_cope_with_optional_properties_in_nonstrict TableTests.table_insert_should_cope_with_optional_properties_in_strict -TableTests.table_length TableTests.table_param_row_polymorphism_2 TableTests.table_param_row_polymorphism_3 TableTests.table_simple_call TableTests.table_subtyping_with_extra_props_dont_report_multiple_errors TableTests.table_subtyping_with_missing_props_dont_report_multiple_errors -TableTests.table_subtyping_with_missing_props_dont_report_multiple_errors2 -TableTests.table_unifies_into_map TableTests.tables_get_names_from_their_locals TableTests.tc_member_function TableTests.tc_member_function_2 @@ -480,10 +427,8 @@ TableTests.unification_of_unions_in_a_self_referential_type TableTests.unifying_tables_shouldnt_uaf2 TableTests.used_colon_instead_of_dot TableTests.used_dot_instead_of_colon -TableTests.width_subtyping ToDot.bound_table ToDot.function -ToDot.metatable ToDot.table ToString.exhaustive_toString_of_cyclic_table ToString.function_type_with_argument_names_generic @@ -500,7 +445,6 @@ TryUnifyTests.members_of_failed_typepack_unification_are_unified_with_errorType TryUnifyTests.result_of_failed_typepack_unification_is_constrained TryUnifyTests.typepack_unification_should_trim_free_tails TryUnifyTests.variadics_should_use_reversed_properly -TypeAliases.cli_38393_recursive_intersection_oom TypeAliases.forward_declared_alias_is_not_clobbered_by_prior_unification_with_any TypeAliases.generic_param_remap TypeAliases.mismatched_generic_pack_type_param @@ -517,29 +461,19 @@ TypeAliases.type_alias_local_mutation TypeAliases.type_alias_local_rename TypeAliases.type_alias_of_an_imported_recursive_generic_type TypeInfer.checking_should_not_ice -TypeInfer.cyclic_follow -TypeInfer.do_not_bind_a_free_table_to_a_union_containing_that_table TypeInfer.dont_report_type_errors_within_an_AstStatError TypeInfer.globals TypeInfer.globals2 TypeInfer.infer_assignment_value_types_mutable_lval TypeInfer.no_stack_overflow_from_isoptional TypeInfer.tc_after_error_recovery_no_replacement_name_in_error -TypeInfer.tc_if_else_expressions1 -TypeInfer.tc_if_else_expressions2 -TypeInfer.tc_if_else_expressions_expected_type_1 -TypeInfer.tc_if_else_expressions_expected_type_2 TypeInfer.tc_if_else_expressions_expected_type_3 -TypeInfer.tc_if_else_expressions_type_union TypeInfer.tc_interpolated_string_basic TypeInfer.tc_interpolated_string_constant_type TypeInfer.tc_interpolated_string_with_invalid_expression TypeInfer.type_infer_recursion_limit_no_ice TypeInferAnyError.assign_prop_to_table_by_calling_any_yields_any -TypeInferAnyError.can_get_length_of_any TypeInferAnyError.for_in_loop_iterator_is_any2 -TypeInferAnyError.length_of_error_type_does_not_produce_an_error -TypeInferAnyError.replace_every_free_type_when_unifying_a_complex_function_with_any TypeInferAnyError.union_of_types_regression_test TypeInferClasses.call_base_method TypeInferClasses.call_instance_method @@ -549,39 +483,23 @@ TypeInferClasses.class_type_mismatch_with_name_conflict TypeInferClasses.classes_can_have_overloaded_operators TypeInferClasses.classes_without_overloaded_operators_cannot_be_added TypeInferClasses.detailed_class_unification_error -TypeInferClasses.function_arguments_are_covariant -TypeInferClasses.higher_order_function_return_type_is_not_contravariant -TypeInferClasses.higher_order_function_return_values_are_covariant +TypeInferClasses.higher_order_function_arguments_are_contravariant TypeInferClasses.optional_class_field_access_error TypeInferClasses.table_class_unification_reports_sane_errors_for_missing_properties -TypeInferClasses.table_indexers_are_invariant -TypeInferClasses.table_properties_are_invariant TypeInferClasses.warn_when_prop_almost_matches TypeInferClasses.we_can_report_when_someone_is_trying_to_use_a_table_rather_than_a_class -TypeInferFunctions.another_recursive_local_function TypeInferFunctions.call_o_with_another_argument_after_foo_was_quantified -TypeInferFunctions.calling_function_with_anytypepack_doesnt_leak_free_types TypeInferFunctions.calling_function_with_incorrect_argument_type_yields_errors_spanning_argument -TypeInferFunctions.complicated_return_types_require_an_explicit_annotation TypeInferFunctions.dont_give_other_overloads_message_if_only_one_argument_matching_overload_exists TypeInferFunctions.dont_infer_parameter_types_for_functions_from_their_call_site TypeInferFunctions.duplicate_functions_with_different_signatures_not_allowed_in_nonstrict -TypeInferFunctions.error_detailed_function_mismatch_arg -TypeInferFunctions.error_detailed_function_mismatch_arg_count -TypeInferFunctions.error_detailed_function_mismatch_ret -TypeInferFunctions.error_detailed_function_mismatch_ret_count -TypeInferFunctions.error_detailed_function_mismatch_ret_mult TypeInferFunctions.free_is_not_bound_to_unknown TypeInferFunctions.func_expr_doesnt_leak_free TypeInferFunctions.function_cast_error_uses_correct_language -TypeInferFunctions.function_decl_non_self_sealed_overwrite TypeInferFunctions.function_decl_non_self_sealed_overwrite_2 TypeInferFunctions.function_decl_non_self_unsealed_overwrite TypeInferFunctions.function_does_not_return_enough_values TypeInferFunctions.function_statement_sealed_table_assignment_through_indexer -TypeInferFunctions.higher_order_function_2 -TypeInferFunctions.higher_order_function_4 -TypeInferFunctions.ignored_return_values TypeInferFunctions.improved_function_arg_mismatch_error_nonstrict TypeInferFunctions.improved_function_arg_mismatch_errors TypeInferFunctions.inconsistent_higher_order_function @@ -589,15 +507,11 @@ TypeInferFunctions.inconsistent_return_types TypeInferFunctions.infer_anonymous_function_arguments TypeInferFunctions.infer_return_type_from_selected_overload TypeInferFunctions.infer_that_function_does_not_return_a_table -TypeInferFunctions.it_is_ok_not_to_supply_enough_retvals TypeInferFunctions.list_all_overloads_if_no_overload_takes_given_argument_count TypeInferFunctions.list_only_alternative_overloads_that_match_argument_count TypeInferFunctions.no_lossy_function_type -TypeInferFunctions.occurs_check_failure_in_function_return_type TypeInferFunctions.quantify_constrained_types TypeInferFunctions.record_matching_overload -TypeInferFunctions.recursive_function -TypeInferFunctions.recursive_local_function TypeInferFunctions.report_exiting_without_return_nonstrict TypeInferFunctions.report_exiting_without_return_strict TypeInferFunctions.return_type_by_overload @@ -608,11 +522,7 @@ TypeInferFunctions.too_few_arguments_variadic_generic2 TypeInferFunctions.too_many_arguments TypeInferFunctions.too_many_return_values TypeInferFunctions.vararg_function_is_quantified -TypeInferLoops.for_in_loop_error_on_factory_not_returning_the_right_amount_of_values -TypeInferLoops.for_in_loop_with_custom_iterator -TypeInferLoops.for_in_loop_with_next TypeInferLoops.for_in_with_just_one_iterator_is_ok -TypeInferLoops.loop_iter_iter_metamethod TypeInferLoops.loop_iter_no_indexer_nonstrict TypeInferLoops.loop_iter_trailing_nil TypeInferLoops.loop_typecheck_crash_on_empty_optional @@ -631,7 +541,6 @@ TypeInferOOP.dont_suggest_using_colon_rather_than_dot_if_it_wont_help_2 TypeInferOOP.dont_suggest_using_colon_rather_than_dot_if_not_defined_with_colon TypeInferOOP.inferred_methods_of_free_tables_have_the_same_level_as_the_enclosing_table TypeInferOOP.inferring_hundreds_of_self_calls_should_not_suffocate_memory -TypeInferOOP.methods_are_topologically_sorted TypeInferOperators.and_adds_boolean TypeInferOperators.and_adds_boolean_no_superfluous_union TypeInferOperators.and_binexps_dont_unify @@ -641,13 +550,10 @@ TypeInferOperators.cannot_compare_tables_that_do_not_have_the_same_metatable TypeInferOperators.cannot_indirectly_compare_types_that_do_not_have_a_metatable TypeInferOperators.cannot_indirectly_compare_types_that_do_not_offer_overloaded_ordering_operators TypeInferOperators.cli_38355_recursive_union -TypeInferOperators.compare_numbers -TypeInferOperators.compare_strings TypeInferOperators.compound_assign_mismatch_metatable TypeInferOperators.compound_assign_mismatch_op TypeInferOperators.compound_assign_mismatch_result TypeInferOperators.concat_op_on_free_lhs_and_string_rhs -TypeInferOperators.concat_op_on_string_lhs_and_free_rhs TypeInferOperators.disallow_string_and_types_without_metatables_from_arithmetic_binary_ops TypeInferOperators.dont_strip_nil_from_rhs_or_operator TypeInferOperators.equality_operations_succeed_if_any_union_branch_succeeds @@ -655,18 +561,12 @@ TypeInferOperators.error_on_invalid_operand_types_to_relational_operators TypeInferOperators.error_on_invalid_operand_types_to_relational_operators2 TypeInferOperators.expected_types_through_binary_and TypeInferOperators.expected_types_through_binary_or -TypeInferOperators.in_nonstrict_mode_strip_nil_from_intersections_when_considering_relational_operators TypeInferOperators.infer_any_in_all_modes_when_lhs_is_unknown -TypeInferOperators.operator_eq_operands_are_not_subtypes_of_each_other_but_has_overlap -TypeInferOperators.operator_eq_verifies_types_do_intersect TypeInferOperators.or_joins_types TypeInferOperators.or_joins_types_with_no_extras -TypeInferOperators.primitive_arith_no_metatable -TypeInferOperators.primitive_arith_no_metatable_with_follows TypeInferOperators.primitive_arith_possible_metatable TypeInferOperators.produce_the_correct_error_message_when_comparing_a_table_with_a_metatable_with_one_that_does_not TypeInferOperators.refine_and_or -TypeInferOperators.some_primitive_binary_ops TypeInferOperators.strict_binary_op_where_lhs_unknown TypeInferOperators.strip_nil_from_lhs_or_operator TypeInferOperators.strip_nil_from_lhs_or_operator2 @@ -676,13 +576,11 @@ TypeInferOperators.typecheck_unary_len_error TypeInferOperators.typecheck_unary_minus TypeInferOperators.typecheck_unary_minus_error TypeInferOperators.unary_not_is_boolean -TypeInferOperators.unknown_type_in_comparison TypeInferOperators.UnknownGlobalCompoundAssign TypeInferPrimitives.CheckMethodsOfNumber TypeInferPrimitives.singleton_types TypeInferPrimitives.string_function_other TypeInferPrimitives.string_index -TypeInferPrimitives.string_length TypeInferPrimitives.string_method TypeInferUnknownNever.assign_to_global_which_is_never TypeInferUnknownNever.assign_to_local_which_is_never @@ -692,9 +590,7 @@ TypeInferUnknownNever.call_never TypeInferUnknownNever.dont_unify_operands_if_one_of_the_operand_is_never_in_any_ordering_operators TypeInferUnknownNever.index_on_union_of_tables_for_properties_that_is_never TypeInferUnknownNever.index_on_union_of_tables_for_properties_that_is_sorta_never -TypeInferUnknownNever.length_of_never TypeInferUnknownNever.math_operators_and_never -TypeInferUnknownNever.type_packs_containing_never_is_itself_uninhabitable TypeInferUnknownNever.type_packs_containing_never_is_itself_uninhabitable2 TypeInferUnknownNever.unary_minus_of_never TypePackTests.higher_order_function @@ -718,18 +614,17 @@ TypePackTests.type_alias_type_packs TypePackTests.type_alias_type_packs_errors TypePackTests.type_alias_type_packs_import TypePackTests.type_alias_type_packs_nested -TypePackTests.type_pack_hidden_free_tail_infinite_growth TypePackTests.type_pack_type_parameters +TypePackTests.unify_variadic_tails_in_arguments +TypePackTests.unify_variadic_tails_in_arguments_free TypePackTests.varargs_inference_through_multiple_scopes TypePackTests.variadic_packs -TypeSingletons.bool_singleton_subtype -TypeSingletons.bool_singletons -TypeSingletons.bool_singletons_mismatch TypeSingletons.enums_using_singletons TypeSingletons.enums_using_singletons_mismatch TypeSingletons.enums_using_singletons_subtyping TypeSingletons.error_detailed_tagged_union_mismatch_bool TypeSingletons.error_detailed_tagged_union_mismatch_string +TypeSingletons.function_call_with_singletons TypeSingletons.function_call_with_singletons_mismatch TypeSingletons.if_then_else_expression_singleton_options TypeSingletons.indexing_on_string_singletons @@ -752,7 +647,6 @@ TypeSingletons.widening_happens_almost_everywhere TypeSingletons.widening_happens_almost_everywhere_except_for_tables UnionTypes.error_detailed_optional UnionTypes.error_detailed_union_all -UnionTypes.error_detailed_union_part UnionTypes.error_takes_optional_arguments UnionTypes.index_on_a_union_type_with_missing_property UnionTypes.index_on_a_union_type_with_mixed_types @@ -771,6 +665,4 @@ UnionTypes.optional_union_follow UnionTypes.optional_union_functions UnionTypes.optional_union_members UnionTypes.optional_union_methods -UnionTypes.return_types_can_be_disjoint UnionTypes.table_union_write_indirect -UnionTypes.union_equality_comparisons diff --git a/tools/lvmexecute_split.py b/tools/lvmexecute_split.py new file mode 100644 index 000000000..10e3ccbb8 --- /dev/null +++ b/tools/lvmexecute_split.py @@ -0,0 +1,100 @@ +#!/usr/bin/python +# This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +# This code can be used to split lvmexecute.cpp VM switch into separate functions for use as native code generation fallbacks +import sys +import re + +input = sys.stdin.readlines() + +inst = "" + +header = """// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +// This file was generated by 'tools/lvmexecute_split.py' script, do not modify it by hand +#pragma once + +#include + +struct lua_State; +struct Closure; +typedef uint32_t Instruction; +typedef struct lua_TValue TValue; +typedef TValue* StkId; + +""" + +source = """// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +// This code is based on Lua 5.x implementation licensed under MIT License; see lua_LICENSE.txt for details +// This file was generated by 'tools/lvmexecute_split.py' script, do not modify it by hand +#include "Fallbacks.h" +#include "FallbacksProlog.h" + +""" + +function = "" + +state = 0 + +# parse with the state machine +for line in input: + # find the start of an instruction + if state == 0: + match = re.match("\s+VM_CASE\((LOP_[A-Z_0-9]+)\)", line) + + if match: + inst = match[1] + signature = "const Instruction* execute_" + inst + "(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k)" + header += signature + ";\n" + function = signature + "\n" + state = 1 + + # find the end of an instruction + elif state == 1: + # remove jumps back into the native code + if line == "#if LUA_CUSTOM_EXECUTION\n": + state = 2 + continue + + if line[0] == ' ': + finalline = line[12:-1] + "\n" + else: + finalline = line + + finalline = finalline.replace("VM_NEXT();", "return pc;"); + finalline = finalline.replace("goto exit;", "return NULL;"); + finalline = finalline.replace("return;", "return NULL;"); + + function += finalline + match = re.match(" }", line) + + if match: + # break is not supported + if inst == "LOP_BREAK": + function = "const Instruction* execute_" + inst + "(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k)\n" + function += "{\n LUAU_ASSERT(!\"Unsupported deprecated opcode\");\n LUAU_UNREACHABLE();\n}\n" + # handle fallthrough + elif inst == "LOP_NAMECALL": + function = function[:-len(finalline)] + function += " return pc;\n}\n" + + source += function + "\n" + state = 0 + + # skip LUA_CUSTOM_EXECUTION code blocks + elif state == 2: + if line == "#endif\n": + state = 3 + continue + + # skip extra line + elif state == 3: + state = 1 + +# make sure we found the ending +assert(state == 0) + +with open("Fallbacks.h", "w") as fp: + fp.writelines(header) + +with open("Fallbacks.cpp", "w") as fp: + fp.writelines(source) From 4176e1c0447d1fdca38d7428b2aec843159d3495 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Fri, 23 Sep 2022 11:43:13 -0700 Subject: [PATCH 06/66] Fix internals library --- CMakeLists.txt | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index 9ad16e8d2..43289f418 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -51,6 +51,8 @@ if(LUAU_BUILD_WEB) add_executable(Luau.Web) endif() +# Proxy target to make it possible to depend on private VM headers +add_library(Luau.VM.Internals INTERFACE) include(Sources.cmake) @@ -79,6 +81,8 @@ target_link_libraries(Luau.VM PUBLIC Luau.Common) target_include_directories(isocline PUBLIC extern/isocline/include) +target_include_directories(Luau.VM.Internals INTERFACE VM/src) + set(LUAU_OPTIONS) if(MSVC) From d0989b9e15e03bcb9093846a195b449af7d0f110 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 29 Sep 2022 15:11:54 -0700 Subject: [PATCH 07/66] Sync to upstream/release/547 --- Analysis/include/Luau/BuiltinDefinitions.h | 4 +- Analysis/include/Luau/Config.h | 7 +- Analysis/include/Luau/Constraint.h | 3 +- .../include/Luau/ConstraintGraphBuilder.h | 1 + Analysis/include/Luau/ConstraintSolver.h | 1 + Analysis/include/Luau/Error.h | 5 +- Analysis/include/Luau/Frontend.h | 10 +- Analysis/include/Luau/Instantiation.h | 10 +- Analysis/include/Luau/Scope.h | 15 +- Analysis/include/Luau/TxnLog.h | 10 + Analysis/include/Luau/TypeArena.h | 1 + Analysis/include/Luau/TypeInfer.h | 10 +- Analysis/include/Luau/TypeUtils.h | 4 + Analysis/include/Luau/TypeVar.h | 18 +- Analysis/include/Luau/Unifiable.h | 1 + Analysis/include/Luau/Unifier.h | 3 +- Analysis/src/Autocomplete.cpp | 55 ++-- Analysis/src/BuiltinDefinitions.cpp | 95 +++++- Analysis/src/Clone.cpp | 7 +- Analysis/src/Config.cpp | 11 +- Analysis/src/ConstraintGraphBuilder.cpp | 64 +++- Analysis/src/ConstraintSolver.cpp | 68 ++++- Analysis/src/Error.cpp | 14 +- Analysis/src/Frontend.cpp | 15 +- Analysis/src/Instantiation.cpp | 8 +- Analysis/src/Quantify.cpp | 52 +--- Analysis/src/Scope.cpp | 28 +- Analysis/src/ToString.cpp | 5 + Analysis/src/TxnLog.cpp | 39 +++ Analysis/src/TypeArena.cpp | 9 + Analysis/src/TypeChecker2.cpp | 259 +++++++++------- Analysis/src/TypeInfer.cpp | 279 +++++++++++------- Analysis/src/TypeUtils.cpp | 42 +++ Analysis/src/TypeVar.cpp | 36 ++- Analysis/src/Unifiable.cpp | 7 + Analysis/src/Unifier.cpp | 108 +++++-- CLI/Analyze.cpp | 2 +- CMakeLists.txt | 4 + CodeGen/src/Fallbacks.cpp | 20 +- Common/include/Luau/Bytecode.h | 4 + Compiler/src/Builtins.cpp | 10 + VM/src/lapi.cpp | 10 +- VM/src/laux.cpp | 6 +- VM/src/lbuiltins.cpp | 81 ++++- VM/src/ldebug.cpp | 2 +- VM/src/ldo.cpp | 12 +- VM/src/lgc.cpp | 101 ++----- VM/src/lobject.cpp | 2 +- VM/src/lobject.h | 14 +- VM/src/ltm.cpp | 1 + VM/src/ltm.h | 1 + VM/src/lvmexecute.cpp | 24 +- VM/src/lvmload.cpp | 2 +- VM/src/lvmutils.cpp | 10 +- fuzz/linter.cpp | 2 +- fuzz/proto.cpp | 2 +- fuzz/typeck.cpp | 2 +- tests/Compiler.test.cpp | 13 +- tests/Conformance.test.cpp | 2 +- tests/Fixture.cpp | 9 +- tests/Frontend.test.cpp | 4 +- tests/ToString.test.cpp | 16 + tests/TypeInfer.aliases.test.cpp | 39 ++- tests/TypeInfer.annotations.test.cpp | 9 +- tests/TypeInfer.builtins.test.cpp | 54 +++- tests/TypeInfer.functions.test.cpp | 48 ++- tests/TypeInfer.loops.test.cpp | 75 ++++- tests/TypeInfer.provisional.test.cpp | 4 +- tests/TypeInfer.refinements.test.cpp | 22 -- tests/TypeInfer.test.cpp | 34 +++ tests/TypeInfer.tryUnify.test.cpp | 36 +++ tests/conformance/events.lua | 3 + tools/faillist.txt | 105 +++---- tools/lldb_formatters.lldb | 7 +- tools/lldb_formatters.py | 152 ++++++++-- tools/perfgraph.py | 1 - tools/perfstat.py | 65 ++++ tools/test_dcr.py | 2 +- 78 files changed, 1639 insertions(+), 677 deletions(-) create mode 100644 tools/perfstat.py diff --git a/Analysis/include/Luau/BuiltinDefinitions.h b/Analysis/include/Luau/BuiltinDefinitions.h index 0292dff78..616367bb4 100644 --- a/Analysis/include/Luau/BuiltinDefinitions.h +++ b/Analysis/include/Luau/BuiltinDefinitions.h @@ -8,9 +8,11 @@ namespace Luau { -void registerBuiltinTypes(TypeChecker& typeChecker); void registerBuiltinTypes(Frontend& frontend); +void registerBuiltinGlobals(TypeChecker& typeChecker); +void registerBuiltinGlobals(Frontend& frontend); + TypeId makeUnion(TypeArena& arena, std::vector&& types); TypeId makeIntersection(TypeArena& arena, std::vector&& types); diff --git a/Analysis/include/Luau/Config.h b/Analysis/include/Luau/Config.h index 56cdfe781..8ba4ffa56 100644 --- a/Analysis/include/Luau/Config.h +++ b/Analysis/include/Luau/Config.h @@ -17,12 +17,9 @@ constexpr const char* kConfigName = ".luaurc"; struct Config { - Config() - { - enabledLint.setDefaults(); - } + Config(); - Mode mode = Mode::NoCheck; + Mode mode; ParseOptions parseOptions; diff --git a/Analysis/include/Luau/Constraint.h b/Analysis/include/Luau/Constraint.h index 3ffb3fb8c..0e19f13f5 100644 --- a/Analysis/include/Luau/Constraint.h +++ b/Analysis/include/Luau/Constraint.h @@ -94,8 +94,9 @@ struct FunctionCallConstraint { std::vector> innerConstraints; TypeId fn; + TypePackId argsPack; TypePackId result; - class AstExprCall* astFragment; + class AstExprCall* callSite; }; // result ~ prim ExpectedType SomeSingletonType MultitonType diff --git a/Analysis/include/Luau/ConstraintGraphBuilder.h b/Analysis/include/Luau/ConstraintGraphBuilder.h index e7d8ad459..973c0a8ea 100644 --- a/Analysis/include/Luau/ConstraintGraphBuilder.h +++ b/Analysis/include/Luau/ConstraintGraphBuilder.h @@ -124,6 +124,7 @@ struct ConstraintGraphBuilder void visit(const ScopePtr& scope, AstStatDeclareGlobal* declareGlobal); void visit(const ScopePtr& scope, AstStatDeclareClass* declareClass); void visit(const ScopePtr& scope, AstStatDeclareFunction* declareFunction); + void visit(const ScopePtr& scope, AstStatError* error); TypePackId checkPack(const ScopePtr& scope, AstArray exprs, const std::vector& expectedTypes = {}); TypePackId checkPack(const ScopePtr& scope, AstExpr* expr, const std::vector& expectedTypes = {}); diff --git a/Analysis/include/Luau/ConstraintSolver.h b/Analysis/include/Luau/ConstraintSolver.h index abea51b87..06f53e4ab 100644 --- a/Analysis/include/Luau/ConstraintSolver.h +++ b/Analysis/include/Luau/ConstraintSolver.h @@ -104,6 +104,7 @@ struct ConstraintSolver bool tryDispatch(const HasPropConstraint& c, NotNull constraint); // for a, ... in some_table do + // also handles __iter metamethod bool tryDispatchIterableTable(TypeId iteratorTy, const IterableConstraint& c, NotNull constraint, bool force); // for a, ... in next_function, t, ... do diff --git a/Analysis/include/Luau/Error.h b/Analysis/include/Luau/Error.h index eab6a21d3..f3735864f 100644 --- a/Analysis/include/Luau/Error.h +++ b/Analysis/include/Luau/Error.h @@ -81,7 +81,7 @@ struct OnlyTablesCanHaveMethods struct DuplicateTypeDefinition { Name name; - Location previousLocation; + std::optional previousLocation; bool operator==(const DuplicateTypeDefinition& rhs) const; }; @@ -91,7 +91,8 @@ struct CountMismatch enum Context { Arg, - Result, + FunctionResult, + ExprListResult, Return, }; size_t expected; diff --git a/Analysis/include/Luau/Frontend.h b/Analysis/include/Luau/Frontend.h index 04c598de1..5df6f4b59 100644 --- a/Analysis/include/Luau/Frontend.h +++ b/Analysis/include/Luau/Frontend.h @@ -157,7 +157,8 @@ struct Frontend ScopePtr getGlobalScope(); private: - ModulePtr check(const SourceModule& sourceModule, Mode mode, const ScopePtr& environmentScope, std::vector requireCycles); + ModulePtr check(const SourceModule& sourceModule, Mode mode, const ScopePtr& environmentScope, std::vector requireCycles, + bool forAutocomplete = false); std::pair getSourceNode(CheckResult& checkResult, const ModuleName& name); SourceModule parse(const ModuleName& name, std::string_view src, const ParseOptions& parseOptions); @@ -171,10 +172,9 @@ struct Frontend std::unordered_map environments; std::unordered_map> builtinDefinitions; - ScopePtr globalScope; + SingletonTypes singletonTypes_; public: - SingletonTypes singletonTypes_; const NotNull singletonTypes; FileResolver* fileResolver; @@ -186,13 +186,15 @@ struct Frontend FrontendOptions options; InternalErrorReporter iceHandler; TypeArena globalTypes; - TypeArena arenaForAutocomplete; std::unordered_map sourceNodes; std::unordered_map sourceModules; std::unordered_map requireTrace; Stats stats = {}; + +private: + ScopePtr globalScope; }; } // namespace Luau diff --git a/Analysis/include/Luau/Instantiation.h b/Analysis/include/Luau/Instantiation.h index e05ceebe4..cd88d33a4 100644 --- a/Analysis/include/Luau/Instantiation.h +++ b/Analysis/include/Luau/Instantiation.h @@ -14,16 +14,18 @@ struct TxnLog; // A substitution which replaces generic types in a given set by free types. struct ReplaceGenerics : Substitution { - ReplaceGenerics( - const TxnLog* log, TypeArena* arena, TypeLevel level, const std::vector& generics, const std::vector& genericPacks) + ReplaceGenerics(const TxnLog* log, TypeArena* arena, TypeLevel level, Scope* scope, const std::vector& generics, + const std::vector& genericPacks) : Substitution(log, arena) , level(level) + , scope(scope) , generics(generics) , genericPacks(genericPacks) { } TypeLevel level; + Scope* scope; std::vector generics; std::vector genericPacks; bool ignoreChildren(TypeId ty) override; @@ -36,13 +38,15 @@ struct ReplaceGenerics : Substitution // A substitution which replaces generic functions by monomorphic functions struct Instantiation : Substitution { - Instantiation(const TxnLog* log, TypeArena* arena, TypeLevel level) + Instantiation(const TxnLog* log, TypeArena* arena, TypeLevel level, Scope* scope) : Substitution(log, arena) , level(level) + , scope(scope) { } TypeLevel level; + Scope* scope; bool ignoreChildren(TypeId ty) override; bool isDirty(TypeId ty) override; bool isDirty(TypePackId tp) override; diff --git a/Analysis/include/Luau/Scope.h b/Analysis/include/Luau/Scope.h index b7569d8eb..b2da7bc0f 100644 --- a/Analysis/include/Luau/Scope.h +++ b/Analysis/include/Luau/Scope.h @@ -49,9 +49,11 @@ struct Scope std::unordered_map exportedTypeBindings; std::unordered_map privateTypeBindings; std::unordered_map typeAliasLocations; - std::unordered_map> importedTypeBindings; + DenseHashSet builtinTypeNames{""}; + void addBuiltinTypeBinding(const Name& name, const TypeFun& tyFun); + std::optional lookup(Symbol sym); std::optional lookupType(const Name& name); @@ -61,7 +63,7 @@ struct Scope std::optional lookupPack(const Name& name); // WARNING: This function linearly scans for a string key of equal value! It is thus O(n**2) - std::optional linearSearchForBinding(const std::string& name, bool traverseScopeChain = true); + std::optional linearSearchForBinding(const std::string& name, bool traverseScopeChain = true) const; RefinementMap refinements; @@ -73,4 +75,13 @@ struct Scope std::unordered_map typeAliasTypePackParameters; }; +// Returns true iff the left scope encloses the right scope. A Scope* equal to +// nullptr is considered to be the outermost-possible scope. +bool subsumesStrict(Scope* left, Scope* right); + +// Returns true if the left scope encloses the right scope, or if they are the +// same scope. As in subsumesStrict(), nullptr is considered to be the +// outermost-possible scope. +bool subsumes(Scope* left, Scope* right); + } // namespace Luau diff --git a/Analysis/include/Luau/TxnLog.h b/Analysis/include/Luau/TxnLog.h index 016cc927b..3c3122c27 100644 --- a/Analysis/include/Luau/TxnLog.h +++ b/Analysis/include/Luau/TxnLog.h @@ -186,6 +186,16 @@ struct TxnLog // The pointer returned lives until `commit` or `clear` is called. PendingTypePack* changeLevel(TypePackId tp, TypeLevel newLevel); + // Queues the replacement of a type's scope with the provided scope. + // + // The pointer returned lives until `commit` or `clear` is called. + PendingType* changeScope(TypeId ty, NotNull scope); + + // Queues the replacement of a type pack's scope with the provided scope. + // + // The pointer returned lives until `commit` or `clear` is called. + PendingTypePack* changeScope(TypePackId tp, NotNull scope); + // Queues a replacement of a table type with another table type with a new // indexer. // diff --git a/Analysis/include/Luau/TypeArena.h b/Analysis/include/Luau/TypeArena.h index 1e029aeb8..c67f643bc 100644 --- a/Analysis/include/Luau/TypeArena.h +++ b/Analysis/include/Luau/TypeArena.h @@ -30,6 +30,7 @@ struct TypeArena TypeId freshType(TypeLevel level); TypeId freshType(Scope* scope); + TypeId freshType(Scope* scope, TypeLevel level); TypePackId freshTypePack(Scope* scope); diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index bbb9bd6d1..e5675ebb3 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -80,10 +80,12 @@ struct TypeChecker void check(const ScopePtr& scope, const AstStatForIn& forin); void check(const ScopePtr& scope, TypeId ty, const ScopePtr& funScope, const AstStatFunction& function); void check(const ScopePtr& scope, TypeId ty, const ScopePtr& funScope, const AstStatLocalFunction& function); - void check(const ScopePtr& scope, const AstStatTypeAlias& typealias, int subLevel = 0, bool forwardDeclare = false); + void check(const ScopePtr& scope, const AstStatTypeAlias& typealias); void check(const ScopePtr& scope, const AstStatDeclareClass& declaredClass); void check(const ScopePtr& scope, const AstStatDeclareFunction& declaredFunction); + void prototype(const ScopePtr& scope, const AstStatTypeAlias& typealias, int subLevel = 0); + void checkBlock(const ScopePtr& scope, const AstStatBlock& statement); void checkBlockWithoutRecursionCheck(const ScopePtr& scope, const AstStatBlock& statement); void checkBlockTypeAliases(const ScopePtr& scope, std::vector& sorted); @@ -392,8 +394,12 @@ struct TypeChecker std::vector> deferredQuantification; }; +using PrintLineProc = void(*)(const std::string&); + +extern PrintLineProc luauPrintLine; + // Unit test hook -void setPrintLine(void (*pl)(const std::string& s)); +void setPrintLine(PrintLineProc pl); void resetPrintLine(); } // namespace Luau diff --git a/Analysis/include/Luau/TypeUtils.h b/Analysis/include/Luau/TypeUtils.h index efc1d8814..e5a205bab 100644 --- a/Analysis/include/Luau/TypeUtils.h +++ b/Analysis/include/Luau/TypeUtils.h @@ -25,4 +25,8 @@ std::optional getIndexTypeFromType(const ScopePtr& scope, ErrorVec& erro // Returns the minimum and maximum number of types the argument list can accept. std::pair> getParameterExtents(const TxnLog* log, TypePackId tp, bool includeHiddenVariadics = false); +// "Render" a type pack out to an array of a given length. Expands variadics and +// various other things to get there. +std::vector flatten(TypeArena& arena, NotNull singletonTypes, TypePackId pack, size_t length); + } // namespace Luau diff --git a/Analysis/include/Luau/TypeVar.h b/Analysis/include/Luau/TypeVar.h index 2847d0b16..1d587ffe5 100644 --- a/Analysis/include/Luau/TypeVar.h +++ b/Analysis/include/Luau/TypeVar.h @@ -27,6 +27,7 @@ namespace Luau struct TypeArena; struct Scope; +using ScopePtr = std::shared_ptr; /** * There are three kinds of type variables: @@ -264,7 +265,15 @@ struct WithPredicate using MagicFunction = std::function>( struct TypeChecker&, const std::shared_ptr&, const class AstExprCall&, WithPredicate)>; -using DcrMagicFunction = std::function, TypePackId, const class AstExprCall*)>; +struct MagicFunctionCallContext +{ + NotNull solver; + const class AstExprCall* callSite; + TypePackId arguments; + TypePackId result; +}; + +using DcrMagicFunction = std::function; struct FunctionTypeVar { @@ -277,10 +286,14 @@ struct FunctionTypeVar // Local monomorphic function FunctionTypeVar(TypeLevel level, TypePackId argTypes, TypePackId retTypes, std::optional defn = {}, bool hasSelf = false); + FunctionTypeVar( + TypeLevel level, Scope* scope, TypePackId argTypes, TypePackId retTypes, std::optional defn = {}, bool hasSelf = false); // Local polymorphic function FunctionTypeVar(TypeLevel level, std::vector generics, std::vector genericPacks, TypePackId argTypes, TypePackId retTypes, std::optional defn = {}, bool hasSelf = false); + FunctionTypeVar(TypeLevel level, Scope* scope, std::vector generics, std::vector genericPacks, TypePackId argTypes, + TypePackId retTypes, std::optional defn = {}, bool hasSelf = false); TypeLevel level; Scope* scope = nullptr; @@ -345,8 +358,9 @@ struct TableTypeVar using Props = std::map; TableTypeVar() = default; - explicit TableTypeVar(TableState state, TypeLevel level); + explicit TableTypeVar(TableState state, TypeLevel level, Scope* scope = nullptr); TableTypeVar(const Props& props, const std::optional& indexer, TypeLevel level, TableState state); + TableTypeVar(const Props& props, const std::optional& indexer, TypeLevel level, Scope* scope, TableState state); Props props; std::optional indexer; diff --git a/Analysis/include/Luau/Unifiable.h b/Analysis/include/Luau/Unifiable.h index e5eb41983..0ea175cc4 100644 --- a/Analysis/include/Luau/Unifiable.h +++ b/Analysis/include/Luau/Unifiable.h @@ -85,6 +85,7 @@ struct Free { explicit Free(TypeLevel level); explicit Free(Scope* scope); + explicit Free(Scope* scope, TypeLevel level); int index; TypeLevel level; diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index 4d46869dd..26a922f5c 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -60,6 +60,7 @@ struct Unifier Location location; Variance variance = Covariant; bool anyIsTop = false; // If true, we consider any to be a top type. If false, it is a familiar but weird mix of top and bottom all at once. + bool useScopes = false; // If true, we use the scope hierarchy rather than TypeLevels CountMismatch::Context ctx = CountMismatch::Arg; UnifierSharedState& sharedState; @@ -140,6 +141,6 @@ struct Unifier std::optional firstPackErrorPos; }; -void promoteTypeLevels(TxnLog& log, const TypeArena* arena, TypeLevel minLevel, TypePackId tp); +void promoteTypeLevels(TxnLog& log, const TypeArena* arena, TypeLevel minLevel, Scope* outerScope, bool useScope, TypePackId tp); } // namespace Luau diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index 2fc145d32..1f594fed5 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -1208,13 +1208,11 @@ static bool autocompleteIfElseExpression( } } -static AutocompleteContext autocompleteExpression(const SourceModule& sourceModule, const Module& module, const TypeChecker& typeChecker, +static AutocompleteContext autocompleteExpression(const SourceModule& sourceModule, const Module& module, NotNull singletonTypes, TypeArena* typeArena, const std::vector& ancestry, Position position, AutocompleteEntryMap& result) { LUAU_ASSERT(!ancestry.empty()); - NotNull singletonTypes = typeChecker.singletonTypes; - AstNode* node = ancestry.rbegin()[0]; if (node->is()) @@ -1254,16 +1252,16 @@ static AutocompleteContext autocompleteExpression(const SourceModule& sourceModu scope = scope->parent; } - TypeCorrectKind correctForNil = checkTypeCorrectKind(module, typeArena, singletonTypes, node, position, typeChecker.nilType); + TypeCorrectKind correctForNil = checkTypeCorrectKind(module, typeArena, singletonTypes, node, position, singletonTypes->nilType); TypeCorrectKind correctForTrue = checkTypeCorrectKind(module, typeArena, singletonTypes, node, position, singletonTypes->trueType); TypeCorrectKind correctForFalse = checkTypeCorrectKind(module, typeArena, singletonTypes, node, position, singletonTypes->falseType); TypeCorrectKind correctForFunction = functionIsExpectedAt(module, node, position).value_or(false) ? TypeCorrectKind::Correct : TypeCorrectKind::None; result["if"] = {AutocompleteEntryKind::Keyword, std::nullopt, false, false}; - result["true"] = {AutocompleteEntryKind::Keyword, typeChecker.booleanType, false, false, correctForTrue}; - result["false"] = {AutocompleteEntryKind::Keyword, typeChecker.booleanType, false, false, correctForFalse}; - result["nil"] = {AutocompleteEntryKind::Keyword, typeChecker.nilType, false, false, correctForNil}; + result["true"] = {AutocompleteEntryKind::Keyword, singletonTypes->booleanType, false, false, correctForTrue}; + result["false"] = {AutocompleteEntryKind::Keyword, singletonTypes->booleanType, false, false, correctForFalse}; + result["nil"] = {AutocompleteEntryKind::Keyword, singletonTypes->nilType, false, false, correctForNil}; result["not"] = {AutocompleteEntryKind::Keyword}; result["function"] = {AutocompleteEntryKind::Keyword, std::nullopt, false, false, correctForFunction}; @@ -1274,11 +1272,11 @@ static AutocompleteContext autocompleteExpression(const SourceModule& sourceModu return AutocompleteContext::Expression; } -static AutocompleteResult autocompleteExpression(const SourceModule& sourceModule, const Module& module, const TypeChecker& typeChecker, +static AutocompleteResult autocompleteExpression(const SourceModule& sourceModule, const Module& module, NotNull singletonTypes, TypeArena* typeArena, const std::vector& ancestry, Position position) { AutocompleteEntryMap result; - AutocompleteContext context = autocompleteExpression(sourceModule, module, typeChecker, typeArena, ancestry, position, result); + AutocompleteContext context = autocompleteExpression(sourceModule, module, singletonTypes, typeArena, ancestry, position, result); return {result, ancestry, context}; } @@ -1385,13 +1383,13 @@ static std::optional autocompleteStringParams(const Source return std::nullopt; } -static AutocompleteResult autocomplete(const SourceModule& sourceModule, const ModulePtr& module, const TypeChecker& typeChecker, - TypeArena* typeArena, Position position, StringCompletionCallback callback) +static AutocompleteResult autocomplete(const SourceModule& sourceModule, const ModulePtr& module, NotNull singletonTypes, + Scope* globalScope, Position position, StringCompletionCallback callback) { if (isWithinComment(sourceModule, position)) return {}; - NotNull singletonTypes = typeChecker.singletonTypes; + TypeArena typeArena; std::vector ancestry = findAncestryAtPositionForAutocomplete(sourceModule, position); LUAU_ASSERT(!ancestry.empty()); @@ -1419,11 +1417,10 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M PropIndexType indexType = indexName->op == ':' ? PropIndexType::Colon : PropIndexType::Point; if (!FFlag::LuauSelfCallAutocompleteFix3 && isString(ty)) - return {autocompleteProps( - *module, typeArena, singletonTypes, typeChecker.globalScope->bindings[AstName{"string"}].typeId, indexType, ancestry), + return {autocompleteProps(*module, &typeArena, singletonTypes, globalScope->bindings[AstName{"string"}].typeId, indexType, ancestry), ancestry, AutocompleteContext::Property}; else - return {autocompleteProps(*module, typeArena, singletonTypes, ty, indexType, ancestry), ancestry, AutocompleteContext::Property}; + return {autocompleteProps(*module, &typeArena, singletonTypes, ty, indexType, ancestry), ancestry, AutocompleteContext::Property}; } else if (auto typeReference = node->as()) { @@ -1441,7 +1438,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M if (statLocal->vars.size == 1 && (!statLocal->equalsSignLocation || position < statLocal->equalsSignLocation->begin)) return {{{"function", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Unknown}; else if (statLocal->equalsSignLocation && position >= statLocal->equalsSignLocation->end) - return autocompleteExpression(sourceModule, *module, typeChecker, typeArena, ancestry, position); + return autocompleteExpression(sourceModule, *module, singletonTypes, &typeArena, ancestry, position); else return {}; } @@ -1455,7 +1452,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M if (statFor->from->location.containsClosed(position) || statFor->to->location.containsClosed(position) || (statFor->step && statFor->step->location.containsClosed(position))) - return autocompleteExpression(sourceModule, *module, typeChecker, typeArena, ancestry, position); + return autocompleteExpression(sourceModule, *module, singletonTypes, &typeArena, ancestry, position); return {}; } @@ -1485,7 +1482,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M AstExpr* lastExpr = statForIn->values.data[statForIn->values.size - 1]; if (lastExpr->location.containsClosed(position)) - return autocompleteExpression(sourceModule, *module, typeChecker, typeArena, ancestry, position); + return autocompleteExpression(sourceModule, *module, singletonTypes, &typeArena, ancestry, position); if (position > lastExpr->location.end) return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Keyword}; @@ -1509,7 +1506,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Keyword}; if (!statWhile->hasDo || position < statWhile->doLocation.begin) - return autocompleteExpression(sourceModule, *module, typeChecker, typeArena, ancestry, position); + return autocompleteExpression(sourceModule, *module, singletonTypes, &typeArena, ancestry, position); if (statWhile->hasDo && position > statWhile->doLocation.end) return {autocompleteStatement(sourceModule, *module, ancestry, position), ancestry, AutocompleteContext::Statement}; @@ -1526,7 +1523,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M else if (AstStatIf* statIf = parent->as(); statIf && node->is()) { if (statIf->condition->is()) - return autocompleteExpression(sourceModule, *module, typeChecker, typeArena, ancestry, position); + return autocompleteExpression(sourceModule, *module, singletonTypes, &typeArena, ancestry, position); else if (!statIf->thenLocation || statIf->thenLocation->containsClosed(position)) return {{{"then", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Keyword}; } @@ -1534,7 +1531,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M statIf && (!statIf->thenLocation || statIf->thenLocation->containsClosed(position))) return {{{"then", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Keyword}; else if (AstStatRepeat* statRepeat = node->as(); statRepeat && statRepeat->condition->is()) - return autocompleteExpression(sourceModule, *module, typeChecker, typeArena, ancestry, position); + return autocompleteExpression(sourceModule, *module, singletonTypes, &typeArena, ancestry, position); else if (AstStatRepeat* statRepeat = extractStat(ancestry); statRepeat) return {autocompleteStatement(sourceModule, *module, ancestry, position), ancestry, AutocompleteContext::Statement}; else if (AstExprTable* exprTable = parent->as(); exprTable && (node->is() || node->is())) @@ -1546,7 +1543,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M { if (auto it = module->astExpectedTypes.find(exprTable)) { - auto result = autocompleteProps(*module, typeArena, singletonTypes, *it, PropIndexType::Key, ancestry); + auto result = autocompleteProps(*module, &typeArena, singletonTypes, *it, PropIndexType::Key, ancestry); // Remove keys that are already completed for (const auto& item : exprTable->items) @@ -1560,7 +1557,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M // If we know for sure that a key is being written, do not offer general expression suggestions if (!key) - autocompleteExpression(sourceModule, *module, typeChecker, typeArena, ancestry, position, result); + autocompleteExpression(sourceModule, *module, singletonTypes, &typeArena, ancestry, position, result); return {result, ancestry, AutocompleteContext::Property}; } @@ -1588,7 +1585,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M if (auto idxExpr = ancestry.at(ancestry.size() - 2)->as()) { if (auto it = module->astTypes.find(idxExpr->expr)) - autocompleteProps(*module, typeArena, singletonTypes, follow(*it), PropIndexType::Point, ancestry, result); + autocompleteProps(*module, &typeArena, singletonTypes, follow(*it), PropIndexType::Point, ancestry, result); } else if (auto binExpr = ancestry.at(ancestry.size() - 2)->as()) { @@ -1604,12 +1601,10 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M } if (node->is()) - { return {}; - } if (node->asExpr()) - return autocompleteExpression(sourceModule, *module, typeChecker, typeArena, ancestry, position); + return autocompleteExpression(sourceModule, *module, singletonTypes, &typeArena, ancestry, position); else if (node->asStat()) return {autocompleteStatement(sourceModule, *module, ancestry, position), ancestry, AutocompleteContext::Statement}; @@ -1628,15 +1623,15 @@ AutocompleteResult autocomplete(Frontend& frontend, const ModuleName& moduleName if (!sourceModule) return {}; - TypeChecker& typeChecker = frontend.typeCheckerForAutocomplete; ModulePtr module = frontend.moduleResolverForAutocomplete.getModule(moduleName); if (!module) return {}; - AutocompleteResult autocompleteResult = autocomplete(*sourceModule, module, typeChecker, &frontend.arenaForAutocomplete, position, callback); + NotNull singletonTypes = frontend.singletonTypes; + Scope* globalScope = frontend.typeCheckerForAutocomplete.globalScope.get(); - frontend.arenaForAutocomplete.clear(); + AutocompleteResult autocompleteResult = autocomplete(*sourceModule, module, singletonTypes, globalScope, position, callback); return autocompleteResult; } diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index 8f4863d00..dbe27bfd4 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -1,18 +1,22 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/BuiltinDefinitions.h" +#include "Luau/Ast.h" #include "Luau/Frontend.h" #include "Luau/Symbol.h" #include "Luau/Common.h" #include "Luau/ToString.h" #include "Luau/ConstraintSolver.h" #include "Luau/TypeInfer.h" +#include "Luau/TypePack.h" +#include "Luau/TypeVar.h" #include LUAU_FASTFLAGVARIABLE(LuauSetMetaTableArgsCheck, false) LUAU_FASTFLAG(LuauUnknownAndNeverType) LUAU_FASTFLAGVARIABLE(LuauBuiltInMetatableNoBadSynthetic, false) +LUAU_FASTFLAG(LuauReportShadowedTypeAlias) /** FIXME: Many of these type definitions are not quite completely accurate. * @@ -34,7 +38,9 @@ static std::optional> magicFunctionPack( static std::optional> magicFunctionRequire( TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate); -static bool dcrMagicFunctionRequire(NotNull solver, TypePackId result, const AstExprCall* expr); + +static bool dcrMagicFunctionSelect(MagicFunctionCallContext context); +static bool dcrMagicFunctionRequire(MagicFunctionCallContext context); TypeId makeUnion(TypeArena& arena, std::vector&& types) { @@ -226,7 +232,22 @@ void assignPropDocumentationSymbols(TableTypeVar::Props& props, const std::strin } } -void registerBuiltinTypes(TypeChecker& typeChecker) +void registerBuiltinTypes(Frontend& frontend) +{ + frontend.getGlobalScope()->addBuiltinTypeBinding("any", TypeFun{{}, frontend.singletonTypes->anyType}); + frontend.getGlobalScope()->addBuiltinTypeBinding("nil", TypeFun{{}, frontend.singletonTypes->nilType}); + frontend.getGlobalScope()->addBuiltinTypeBinding("number", TypeFun{{}, frontend.singletonTypes->numberType}); + frontend.getGlobalScope()->addBuiltinTypeBinding("string", TypeFun{{}, frontend.singletonTypes->stringType}); + frontend.getGlobalScope()->addBuiltinTypeBinding("boolean", TypeFun{{}, frontend.singletonTypes->booleanType}); + frontend.getGlobalScope()->addBuiltinTypeBinding("thread", TypeFun{{}, frontend.singletonTypes->threadType}); + if (FFlag::LuauUnknownAndNeverType) + { + frontend.getGlobalScope()->addBuiltinTypeBinding("unknown", TypeFun{{}, frontend.singletonTypes->unknownType}); + frontend.getGlobalScope()->addBuiltinTypeBinding("never", TypeFun{{}, frontend.singletonTypes->neverType}); + } +} + +void registerBuiltinGlobals(TypeChecker& typeChecker) { LUAU_ASSERT(!typeChecker.globalTypes.typeVars.isFrozen()); LUAU_ASSERT(!typeChecker.globalTypes.typePacks.isFrozen()); @@ -303,6 +324,7 @@ void registerBuiltinTypes(TypeChecker& typeChecker) attachMagicFunction(getGlobalBinding(typeChecker, "assert"), magicFunctionAssert); attachMagicFunction(getGlobalBinding(typeChecker, "setmetatable"), magicFunctionSetMetaTable); attachMagicFunction(getGlobalBinding(typeChecker, "select"), magicFunctionSelect); + attachDcrMagicFunction(getGlobalBinding(typeChecker, "select"), dcrMagicFunctionSelect); if (TableTypeVar* ttv = getMutable(getGlobalBinding(typeChecker, "table"))) { @@ -317,12 +339,13 @@ void registerBuiltinTypes(TypeChecker& typeChecker) attachDcrMagicFunction(getGlobalBinding(typeChecker, "require"), dcrMagicFunctionRequire); } -void registerBuiltinTypes(Frontend& frontend) +void registerBuiltinGlobals(Frontend& frontend) { LUAU_ASSERT(!frontend.globalTypes.typeVars.isFrozen()); LUAU_ASSERT(!frontend.globalTypes.typePacks.isFrozen()); - TypeId nilType = frontend.typeChecker.nilType; + if (FFlag::LuauReportShadowedTypeAlias) + registerBuiltinTypes(frontend); TypeArena& arena = frontend.globalTypes; NotNull singletonTypes = frontend.singletonTypes; @@ -352,7 +375,7 @@ void registerBuiltinTypes(Frontend& frontend) TypePackId pairsArgsTypePack = arena.addTypePack({mapOfKtoV}); TypeId pairsNext = arena.addType(FunctionTypeVar{nextArgsTypePack, arena.addTypePack(TypePack{{genericK, genericV}})}); - TypePackId pairsReturnTypePack = arena.addTypePack(TypePack{{pairsNext, mapOfKtoV, nilType}}); + TypePackId pairsReturnTypePack = arena.addTypePack(TypePack{{pairsNext, mapOfKtoV, frontend.singletonTypes->nilType}}); // pairs(t: Table) -> ((Table, K?) -> (K, V), Table, nil) addGlobalBinding(frontend, "pairs", arena.addType(FunctionTypeVar{{genericK, genericV}, {}, pairsArgsTypePack, pairsReturnTypePack}), "@luau"); @@ -394,6 +417,7 @@ void registerBuiltinTypes(Frontend& frontend) attachMagicFunction(getGlobalBinding(frontend, "assert"), magicFunctionAssert); attachMagicFunction(getGlobalBinding(frontend, "setmetatable"), magicFunctionSetMetaTable); attachMagicFunction(getGlobalBinding(frontend, "select"), magicFunctionSelect); + attachDcrMagicFunction(getGlobalBinding(frontend, "select"), dcrMagicFunctionSelect); if (TableTypeVar* ttv = getMutable(getGlobalBinding(frontend, "table"))) { @@ -408,7 +432,6 @@ void registerBuiltinTypes(Frontend& frontend) attachDcrMagicFunction(getGlobalBinding(frontend, "require"), dcrMagicFunctionRequire); } - static std::optional> magicFunctionSelect( TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) { @@ -450,6 +473,50 @@ static std::optional> magicFunctionSelect( return std::nullopt; } +static bool dcrMagicFunctionSelect(MagicFunctionCallContext context) +{ + if (context.callSite->args.size <= 0) + { + context.solver->reportError(TypeError{context.callSite->location, GenericError{"select should take 1 or more arguments"}}); + return false; + } + + AstExpr* arg1 = context.callSite->args.data[0]; + + if (AstExprConstantNumber* num = arg1->as()) + { + const auto& [v, tail] = flatten(context.arguments); + + int offset = int(num->value); + if (offset > 0) + { + if (size_t(offset) < v.size()) + { + std::vector res(v.begin() + offset, v.end()); + TypePackId resTypePack = context.solver->arena->addTypePack({std::move(res), tail}); + asMutable(context.result)->ty.emplace(resTypePack); + } + else if (tail) + asMutable(context.result)->ty.emplace(*tail); + + return true; + } + + return false; + } + + if (AstExprConstantString* str = arg1->as()) + { + if (str->value.size == 1 && str->value.data[0] == '#') { + TypePackId numberTypePack = context.solver->arena->addTypePack({context.solver->singletonTypes->numberType}); + asMutable(context.result)->ty.emplace(numberTypePack); + return true; + } + } + + return false; +} + static std::optional> magicFunctionSetMetaTable( TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) { @@ -675,22 +742,22 @@ static bool checkRequirePathDcr(NotNull solver, AstExpr* expr) return good; } -static bool dcrMagicFunctionRequire(NotNull solver, TypePackId result, const AstExprCall* expr) +static bool dcrMagicFunctionRequire(MagicFunctionCallContext context) { - if (expr->args.size != 1) + if (context.callSite->args.size != 1) { - solver->reportError(GenericError{"require takes 1 argument"}, expr->location); + context.solver->reportError(GenericError{"require takes 1 argument"}, context.callSite->location); return false; } - if (!checkRequirePathDcr(solver, expr->args.data[0])) + if (!checkRequirePathDcr(context.solver, context.callSite->args.data[0])) return false; - if (auto moduleInfo = solver->moduleResolver->resolveModuleInfo(solver->currentModuleName, *expr)) + if (auto moduleInfo = context.solver->moduleResolver->resolveModuleInfo(context.solver->currentModuleName, *context.callSite)) { - TypeId moduleType = solver->resolveModule(*moduleInfo, expr->location); - TypePackId moduleResult = solver->arena->addTypePack({moduleType}); - asMutable(result)->ty.emplace(moduleResult); + TypeId moduleType = context.solver->resolveModule(*moduleInfo, context.callSite->location); + TypePackId moduleResult = context.solver->arena->addTypePack({moduleType}); + asMutable(context.result)->ty.emplace(moduleResult); return true; } diff --git a/Analysis/src/Clone.cpp b/Analysis/src/Clone.cpp index 7048d201b..fd3a089b4 100644 --- a/Analysis/src/Clone.cpp +++ b/Analysis/src/Clone.cpp @@ -220,6 +220,9 @@ void TypeCloner::operator()(const SingletonTypeVar& t) void TypeCloner::operator()(const FunctionTypeVar& t) { + // FISHY: We always erase the scope when we clone things. clone() was + // originally written so that we could copy a module's type surface into an + // export arena. This probably dates to that. TypeId result = dest.addType(FunctionTypeVar{TypeLevel{0, 0}, {}, {}, nullptr, nullptr, t.definition, t.hasSelf}); FunctionTypeVar* ftv = getMutable(result); LUAU_ASSERT(ftv != nullptr); @@ -436,7 +439,7 @@ TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool alwaysCl if (const FunctionTypeVar* ftv = get(ty)) { - FunctionTypeVar clone = FunctionTypeVar{ftv->level, ftv->argTypes, ftv->retTypes, ftv->definition, ftv->hasSelf}; + FunctionTypeVar clone = FunctionTypeVar{ftv->level, ftv->scope, ftv->argTypes, ftv->retTypes, ftv->definition, ftv->hasSelf}; clone.generics = ftv->generics; clone.genericPacks = ftv->genericPacks; clone.magicFunction = ftv->magicFunction; @@ -448,7 +451,7 @@ TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool alwaysCl else if (const TableTypeVar* ttv = get(ty)) { LUAU_ASSERT(!ttv->boundTo); - TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, ttv->level, ttv->state}; + TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, ttv->level, ttv->scope, ttv->state}; clone.definitionModuleName = ttv->definitionModuleName; clone.name = ttv->name; clone.syntheticName = ttv->syntheticName; diff --git a/Analysis/src/Config.cpp b/Analysis/src/Config.cpp index 35a2259d1..00ca7b16f 100644 --- a/Analysis/src/Config.cpp +++ b/Analysis/src/Config.cpp @@ -4,15 +4,18 @@ #include "Luau/Lexer.h" #include "Luau/StringUtils.h" -namespace +LUAU_FASTFLAGVARIABLE(LuauEnableNonstrictByDefaultForLuauConfig, false) + +namespace Luau { using Error = std::optional; -} - -namespace Luau +Config::Config() + : mode(FFlag::LuauEnableNonstrictByDefaultForLuauConfig ? Mode::Nonstrict : Mode::NoCheck) { + enabledLint.setDefaults(); +} static Error parseBoolean(bool& result, const std::string& value) { diff --git a/Analysis/src/ConstraintGraphBuilder.cpp b/Analysis/src/ConstraintGraphBuilder.cpp index aa1e9547d..169f46452 100644 --- a/Analysis/src/ConstraintGraphBuilder.cpp +++ b/Analysis/src/ConstraintGraphBuilder.cpp @@ -11,6 +11,7 @@ LUAU_FASTINT(LuauCheckRecursionLimit); LUAU_FASTFLAG(DebugLuauLogSolverToJson); +LUAU_FASTFLAG(DebugLuauMagicTypes); #include "Luau/Scope.h" @@ -218,6 +219,8 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStat* stat) visit(scope, s); else if (auto s = stat->as()) visit(scope, s); + else if (auto s = stat->as()) + visit(scope, s); else LUAU_ASSERT(0); } @@ -454,8 +457,10 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFunction* funct TypeId containingTableType = check(scope, indexName->expr); functionType = arena->addType(BlockedTypeVar{}); - TypeId prospectiveTableType = - arena->addType(TableTypeVar{}); // TODO look into stack utilization. This is probably ok because it scales with AST depth. + + // TODO look into stack utilization. This is probably ok because it scales with AST depth. + TypeId prospectiveTableType = arena->addType(TableTypeVar{TableState::Unsealed, TypeLevel{}, scope.get()}); + NotNull prospectiveTable{getMutable(prospectiveTableType)}; Property& prop = prospectiveTable->props[indexName->index.value]; @@ -619,7 +624,7 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareClass* d TypeId classTy = arena->addType(ClassTypeVar(className, {}, superTy, std::nullopt, {}, {}, moduleName)); ClassTypeVar* ctv = getMutable(classTy); - TypeId metaTy = arena->addType(TableTypeVar{TableState::Sealed, scope->level}); + TypeId metaTy = arena->addType(TableTypeVar{TableState::Sealed, scope->level, scope.get()}); TableTypeVar* metatable = getMutable(metaTy); ctv->metatable = metaTy; @@ -715,7 +720,7 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareFunction TypePackId paramPack = resolveTypePack(funScope, global->params); TypePackId retPack = resolveTypePack(funScope, global->retTypes); - TypeId fnType = arena->addType(FunctionTypeVar{funScope->level, std::move(genericTys), std::move(genericTps), paramPack, retPack}); + TypeId fnType = arena->addType(FunctionTypeVar{TypeLevel{}, funScope.get(), std::move(genericTys), std::move(genericTps), paramPack, retPack}); FunctionTypeVar* ftv = getMutable(fnType); ftv->argNames.reserve(global->paramNames.size); @@ -728,6 +733,14 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareFunction scope->bindings[global->name] = Binding{fnType, global->location}; } +void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatError* error) +{ + for (AstStat* stat : error->statements) + visit(scope, stat); + for (AstExpr* expr : error->expressions) + check(scope, expr); +} + TypePackId ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstArray exprs, const std::vector& expectedTypes) { std::vector head; @@ -745,7 +758,9 @@ TypePackId ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstArray expectedTailTypes{begin(expectedTypes) + i, end(expectedTypes)}; + std::vector expectedTailTypes; + if (i < expectedTypes.size()) + expectedTailTypes.assign(begin(expectedTypes) + i, end(expectedTypes)); tail = checkPack(scope, expr, expectedTailTypes); } } @@ -803,7 +818,8 @@ TypePackId ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExpr* exp TypeId instantiatedType = arena->addType(BlockedTypeVar{}); // TODO: How do expectedTypes play into this? Do they? TypePackId rets = arena->addTypePack(BlockedTypePack{}); - FunctionTypeVar ftv(arena->addTypePack(TypePack{args, {}}), rets); + TypePackId argPack = arena->addTypePack(TypePack{args, {}}); + FunctionTypeVar ftv(TypeLevel{}, scope.get(), argPack, rets); TypeId inferredFnType = arena->addType(ftv); scope->unqueuedConstraints.push_back( @@ -834,6 +850,7 @@ TypePackId ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExpr* exp FunctionCallConstraint{ {ic, sc}, fnType, + argPack, rets, call, }); @@ -968,6 +985,9 @@ TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExpr* expr, std:: else if (auto err = expr->as()) { // Open question: Should we traverse into this? + for (AstExpr* subExpr : err->expressions) + check(scope, subExpr); + result = singletonTypes->errorRecoveryType(); } else @@ -988,7 +1008,7 @@ TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIndexName* in TableTypeVar::Props props{{indexName->index.value, Property{result}}}; const std::optional indexer; - TableTypeVar ttv{std::move(props), indexer, TypeLevel{}, TableState::Free}; + TableTypeVar ttv{std::move(props), indexer, TypeLevel{}, scope.get(), TableState::Free}; TypeId expectedTableType = arena->addType(std::move(ttv)); @@ -1005,7 +1025,8 @@ TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIndexExpr* in TypeId result = freshType(scope); TableIndexer indexer{indexType, result}; - TypeId tableType = arena->addType(TableTypeVar{TableTypeVar::Props{}, TableIndexer{indexType, result}, TypeLevel{}, TableState::Free}); + TypeId tableType = + arena->addType(TableTypeVar{TableTypeVar::Props{}, TableIndexer{indexType, result}, TypeLevel{}, scope.get(), TableState::Free}); addConstraint(scope, indexExpr->expr->location, SubtypeConstraint{obj, tableType}); @@ -1094,6 +1115,9 @@ TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprTable* expr, TableTypeVar* ttv = getMutable(ty); LUAU_ASSERT(ttv); + ttv->state = TableState::Unsealed; + ttv->scope = scope.get(); + auto createIndexer = [this, scope, ttv](const Location& location, TypeId currentIndexType, TypeId currentResultType) { if (!ttv->indexer) { @@ -1195,7 +1219,7 @@ ConstraintGraphBuilder::FunctionSignature ConstraintGraphBuilder::checkFunctionS } else { - bodyScope = childScope(fn->body, parent); + bodyScope = childScope(fn, parent); returnType = freshTypePack(bodyScope); bodyScope->returnType = returnType; @@ -1260,7 +1284,7 @@ ConstraintGraphBuilder::FunctionSignature ConstraintGraphBuilder::checkFunctionS // TODO: Vararg annotation. // TODO: Preserve argument names in the function's type. - FunctionTypeVar actualFunction{arena->addTypePack(argTypes, varargPack), returnType}; + FunctionTypeVar actualFunction{TypeLevel{}, parent.get(), arena->addTypePack(argTypes, varargPack), returnType}; actualFunction.hasNoGenerics = !hasGenerics; actualFunction.generics = std::move(genericTypes); actualFunction.genericPacks = std::move(genericTypePacks); @@ -1297,6 +1321,22 @@ TypeId ConstraintGraphBuilder::resolveType(const ScopePtr& scope, AstType* ty, b if (auto ref = ty->as()) { + if (FFlag::DebugLuauMagicTypes) + { + if (ref->name == "_luau_ice") + ice->ice("_luau_ice encountered", ty->location); + else if (ref->name == "_luau_print") + { + if (ref->parameters.size != 1 || !ref->parameters.data[0].type) + { + reportError(ty->location, GenericError{"_luau_print requires one generic parameter"}); + return singletonTypes->errorRecoveryType(); + } + else + return resolveType(scope, ref->parameters.data[0].type, topLevel); + } + } + std::optional alias = scope->lookupType(ref->name.value); if (alias.has_value() || ref->prefix.has_value()) @@ -1369,7 +1409,7 @@ TypeId ConstraintGraphBuilder::resolveType(const ScopePtr& scope, AstType* ty, b }; } - result = arena->addType(TableTypeVar{props, indexer, scope->level, TableState::Sealed}); + result = arena->addType(TableTypeVar{props, indexer, scope->level, scope.get(), TableState::Sealed}); } else if (auto fn = ty->as()) { @@ -1414,7 +1454,7 @@ TypeId ConstraintGraphBuilder::resolveType(const ScopePtr& scope, AstType* ty, b // TODO: FunctionTypeVar needs a pointer to the scope so that we know // how to quantify/instantiate it. - FunctionTypeVar ftv{argTypes, returnTypes}; + FunctionTypeVar ftv{TypeLevel{}, scope.get(), {}, {}, argTypes, returnTypes}; // This replicates the behavior of the appropriate FunctionTypeVar // constructors. diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index b2bf773f4..35b8387fd 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -8,9 +8,11 @@ #include "Luau/ModuleResolver.h" #include "Luau/Quantify.h" #include "Luau/ToString.h" +#include "Luau/TypeVar.h" #include "Luau/Unifier.h" #include "Luau/DcrLogger.h" #include "Luau/VisitTypeVar.h" +#include "Luau/TypeUtils.h" LUAU_FASTFLAGVARIABLE(DebugLuauLogSolver, false); LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJson, false); @@ -439,6 +441,7 @@ bool ConstraintSolver::tryDispatch(const PackSubtypeConstraint& c, NotNullscope); + return true; } @@ -465,7 +468,7 @@ bool ConstraintSolver::tryDispatch(const InstantiationConstraint& c, NotNullscope); std::optional instantiated = inst.substitute(c.superType); LUAU_ASSERT(instantiated); // TODO FIXME HANDLE THIS @@ -909,7 +912,7 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNulldcrMagicFunction != nullptr) { - usedMagic = ftv->dcrMagicFunction(NotNull(this), result, c.astFragment); + usedMagic = ftv->dcrMagicFunction(MagicFunctionCallContext{NotNull(this), c.callSite, c.argsPack, result}); } if (usedMagic) @@ -1087,6 +1090,63 @@ bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const Iterabl else errorify(c.variables); } + else if (std::optional iterFn = findMetatableEntry(singletonTypes, errors, iteratorTy, "__iter", Location{})) + { + if (isBlocked(*iterFn)) + { + return block(*iterFn, constraint); + } + + Instantiation instantiation(TxnLog::empty(), arena, TypeLevel{}, constraint->scope); + + if (std::optional instantiatedIterFn = instantiation.substitute(*iterFn)) + { + if (auto iterFtv = get(*instantiatedIterFn)) + { + TypePackId expectedIterArgs = arena->addTypePack({iteratorTy}); + unify(iterFtv->argTypes, expectedIterArgs, constraint->scope); + + std::vector iterRets = flatten(*arena, singletonTypes, iterFtv->retTypes, 2); + + if (iterRets.size() < 1) + { + // We've done what we can; this will get reported as an + // error by the type checker. + return true; + } + + TypeId nextFn = iterRets[0]; + TypeId table = iterRets.size() == 2 ? iterRets[1] : arena->freshType(constraint->scope); + + if (std::optional instantiatedNextFn = instantiation.substitute(nextFn)) + { + const TypeId firstIndex = arena->freshType(constraint->scope); + + // nextTy : (iteratorTy, indexTy?) -> (indexTy, valueTailTy...) + const TypePackId nextArgPack = arena->addTypePack({table, arena->addType(UnionTypeVar{{firstIndex, singletonTypes->nilType}})}); + const TypePackId valueTailTy = arena->addTypePack(FreeTypePack{constraint->scope}); + const TypePackId nextRetPack = arena->addTypePack(TypePack{{firstIndex}, valueTailTy}); + + const TypeId expectedNextTy = arena->addType(FunctionTypeVar{nextArgPack, nextRetPack}); + unify(*instantiatedNextFn, expectedNextTy, constraint->scope); + + pushConstraint(constraint->scope, constraint->location, PackSubtypeConstraint{c.variables, nextRetPack}); + } + else + { + reportError(UnificationTooComplex{}, constraint->location); + } + } + else + { + // TODO: Support __call and function overloads (what does an overload even mean for this?) + } + } + else + { + reportError(UnificationTooComplex{}, constraint->location); + } + } else if (auto iteratorMetatable = get(iteratorTy)) { TypeId metaTy = follow(iteratorMetatable->metatable); @@ -1124,7 +1184,7 @@ bool ConstraintSolver::tryDispatchIterableFunction( const TypePackId valueTailTy = arena->addTypePack(FreeTypePack{constraint->scope}); const TypePackId nextRetPack = arena->addTypePack(TypePack{{firstIndex}, valueTailTy}); - const TypeId expectedNextTy = arena->addType(FunctionTypeVar{nextArgPack, nextRetPack}); + const TypeId expectedNextTy = arena->addType(FunctionTypeVar{TypeLevel{}, constraint->scope, nextArgPack, nextRetPack}); unify(nextTy, expectedNextTy, constraint->scope); pushConstraint(constraint->scope, constraint->location, PackSubtypeConstraint{c.variables, nextRetPack}); @@ -1297,6 +1357,7 @@ void ConstraintSolver::unify(TypeId subType, TypeId superType, NotNull sc { UnifierSharedState sharedState{&iceReporter}; Unifier u{arena, singletonTypes, Mode::Strict, scope, Location{}, Covariant, sharedState}; + u.useScopes = true; u.tryUnify(subType, superType); @@ -1319,6 +1380,7 @@ void ConstraintSolver::unify(TypePackId subPack, TypePackId superPack, NotNullbegin.line + 1); + return s; } std::string operator()(const Luau::CountMismatch& e) const @@ -183,11 +186,14 @@ struct ErrorConverter case CountMismatch::Return: return "Expected to return " + std::to_string(e.expected) + " value" + expectedS + ", but " + std::to_string(e.actual) + " " + actualVerb + " returned here"; - case CountMismatch::Result: + case CountMismatch::FunctionResult: // It is alright if right hand side produces more values than the // left hand side accepts. In this context consider only the opposite case. - return "Function only returns " + std::to_string(e.expected) + " value" + expectedS + ". " + std::to_string(e.actual) + - " are required here"; + return "Function only returns " + std::to_string(e.expected) + " value" + expectedS + ", but " + std::to_string(e.actual) + " " + + actualVerb + " required here"; + case CountMismatch::ExprListResult: + return "Expression list has " + std::to_string(e.expected) + " value" + expectedS + ", but " + std::to_string(e.actual) + " " + + actualVerb + " required here"; case CountMismatch::Arg: if (!e.function.empty()) return "Argument count mismatch. Function '" + e.function + "' " + diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 01e82baad..1890e0811 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -400,6 +400,7 @@ Frontend::Frontend(FileResolver* fileResolver, ConfigResolver* configResolver, c , typeCheckerForAutocomplete(&moduleResolverForAutocomplete, singletonTypes, &iceHandler) , configResolver(configResolver) , options(options) + , globalScope(typeChecker.globalScope) { } @@ -505,7 +506,10 @@ CheckResult Frontend::check(const ModuleName& name, std::optional requireCycles) +ModulePtr Frontend::check( + const SourceModule& sourceModule, Mode mode, const ScopePtr& environmentScope, std::vector requireCycles, bool forAutocomplete) { ModulePtr result = std::make_shared(); @@ -852,7 +857,11 @@ ModulePtr Frontend::check(const SourceModule& sourceModule, Mode mode, const Sco } } - ConstraintGraphBuilder cgb{sourceModule.name, result, &result->internalTypes, NotNull(&moduleResolver), singletonTypes, NotNull(&iceHandler), getGlobalScope(), logger.get()}; + const NotNull mr{forAutocomplete ? &moduleResolverForAutocomplete : &moduleResolver}; + const ScopePtr& globalScope{forAutocomplete ? typeCheckerForAutocomplete.globalScope : typeChecker.globalScope}; + + ConstraintGraphBuilder cgb{ + sourceModule.name, result, &result->internalTypes, mr, singletonTypes, NotNull(&iceHandler), globalScope, logger.get()}; cgb.visit(sourceModule.root); result->errors = std::move(cgb.errors); diff --git a/Analysis/src/Instantiation.cpp b/Analysis/src/Instantiation.cpp index 2d1d62f31..3d0cd0d11 100644 --- a/Analysis/src/Instantiation.cpp +++ b/Analysis/src/Instantiation.cpp @@ -44,7 +44,7 @@ TypeId Instantiation::clean(TypeId ty) const FunctionTypeVar* ftv = log->getMutable(ty); LUAU_ASSERT(ftv); - FunctionTypeVar clone = FunctionTypeVar{level, ftv->argTypes, ftv->retTypes, ftv->definition, ftv->hasSelf}; + FunctionTypeVar clone = FunctionTypeVar{level, scope, ftv->argTypes, ftv->retTypes, ftv->definition, ftv->hasSelf}; clone.magicFunction = ftv->magicFunction; clone.dcrMagicFunction = ftv->dcrMagicFunction; clone.tags = ftv->tags; @@ -53,7 +53,7 @@ TypeId Instantiation::clean(TypeId ty) // Annoyingly, we have to do this even if there are no generics, // to replace any generic tables. - ReplaceGenerics replaceGenerics{log, arena, level, ftv->generics, ftv->genericPacks}; + ReplaceGenerics replaceGenerics{log, arena, level, scope, ftv->generics, ftv->genericPacks}; // TODO: What to do if this returns nullopt? // We don't have access to the error-reporting machinery @@ -114,12 +114,12 @@ TypeId ReplaceGenerics::clean(TypeId ty) LUAU_ASSERT(isDirty(ty)); if (const TableTypeVar* ttv = log->getMutable(ty)) { - TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, level, TableState::Free}; + TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, level, scope, TableState::Free}; clone.definitionModuleName = ttv->definitionModuleName; return addType(std::move(clone)); } else - return addType(FreeTypeVar{level}); + return addType(FreeTypeVar{scope, level}); } TypePackId ReplaceGenerics::clean(TypePackId tp) diff --git a/Analysis/src/Quantify.cpp b/Analysis/src/Quantify.cpp index 7e6ff2f99..e4c069bd6 100644 --- a/Analysis/src/Quantify.cpp +++ b/Analysis/src/Quantify.cpp @@ -15,19 +15,6 @@ LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution) namespace Luau { -/// @return true if outer encloses inner -static bool subsumes(Scope* outer, Scope* inner) -{ - while (inner) - { - if (inner == outer) - return true; - inner = inner->parent.get(); - } - - return false; -} - struct Quantifier final : TypeVarOnceVisitor { TypeLevel level; @@ -43,12 +30,6 @@ struct Quantifier final : TypeVarOnceVisitor LUAU_ASSERT(!FFlag::DebugLuauDeferredConstraintResolution); } - explicit Quantifier(Scope* scope) - : scope(scope) - { - LUAU_ASSERT(FFlag::DebugLuauDeferredConstraintResolution); - } - /// @return true if outer encloses inner bool subsumes(Scope* outer, Scope* inner) { @@ -66,13 +47,10 @@ struct Quantifier final : TypeVarOnceVisitor { seenMutableType = true; - if (FFlag::DebugLuauDeferredConstraintResolution ? !subsumes(scope, ftv.scope) : !level.subsumes(ftv.level)) + if (!level.subsumes(ftv.level)) return false; - if (FFlag::DebugLuauDeferredConstraintResolution) - *asMutable(ty) = GenericTypeVar{scope}; - else - *asMutable(ty) = GenericTypeVar{level}; + *asMutable(ty) = GenericTypeVar{level}; generics.push_back(ty); @@ -85,7 +63,7 @@ struct Quantifier final : TypeVarOnceVisitor seenMutableType = true; - if (FFlag::DebugLuauDeferredConstraintResolution ? !subsumes(scope, ctv->scope) : !level.subsumes(ctv->level)) + if (!level.subsumes(ctv->level)) return false; std::vector opts = std::move(ctv->parts); @@ -113,7 +91,7 @@ struct Quantifier final : TypeVarOnceVisitor if (ttv.state == TableState::Free) seenMutableType = true; - if (FFlag::DebugLuauDeferredConstraintResolution ? !subsumes(scope, ttv.scope) : !level.subsumes(ttv.level)) + if (!level.subsumes(ttv.level)) { if (ttv.state == TableState::Unsealed) seenMutableType = true; @@ -137,7 +115,7 @@ struct Quantifier final : TypeVarOnceVisitor { seenMutableType = true; - if (FFlag::DebugLuauDeferredConstraintResolution ? !subsumes(scope, ftp.scope) : !level.subsumes(ftp.level)) + if (!level.subsumes(ftp.level)) return false; *asMutable(tp) = GenericTypePack{level}; @@ -197,20 +175,6 @@ void quantify(TypeId ty, TypeLevel level) } } -void quantify(TypeId ty, Scope* scope) -{ - Quantifier q{scope}; - q.traverse(ty); - - FunctionTypeVar* ftv = getMutable(ty); - LUAU_ASSERT(ftv); - ftv->generics.insert(ftv->generics.end(), q.generics.begin(), q.generics.end()); - ftv->genericPacks.insert(ftv->genericPacks.end(), q.genericPacks.begin(), q.genericPacks.end()); - - if (ftv->generics.empty() && ftv->genericPacks.empty() && !q.seenMutableType && !q.seenGenericType) - ftv->hasNoGenerics = true; -} - struct PureQuantifier : Substitution { Scope* scope; @@ -253,7 +217,7 @@ struct PureQuantifier : Substitution { if (auto ftv = get(ty)) { - TypeId result = arena->addType(GenericTypeVar{}); + TypeId result = arena->addType(GenericTypeVar{scope}); insertedGenerics.push_back(result); return result; } @@ -264,7 +228,8 @@ struct PureQuantifier : Substitution LUAU_ASSERT(resultTable); *resultTable = *ttv; - resultTable->scope = nullptr; + resultTable->level = TypeLevel{}; + resultTable->scope = scope; resultTable->state = TableState::Generic; return result; @@ -306,6 +271,7 @@ TypeId quantify(TypeArena* arena, TypeId ty, Scope* scope) FunctionTypeVar* ftv = getMutable(*result); LUAU_ASSERT(ftv); + ftv->scope = scope; ftv->generics.insert(ftv->generics.end(), quantifier.insertedGenerics.begin(), quantifier.insertedGenerics.end()); ftv->genericPacks.insert(ftv->genericPacks.end(), quantifier.insertedGenericPacks.begin(), quantifier.insertedGenericPacks.end()); ftv->hasNoGenerics = ftv->generics.empty() && ftv->genericPacks.empty(); diff --git a/Analysis/src/Scope.cpp b/Analysis/src/Scope.cpp index c129b9733..9a7d36090 100644 --- a/Analysis/src/Scope.cpp +++ b/Analysis/src/Scope.cpp @@ -21,6 +21,12 @@ Scope::Scope(const ScopePtr& parent, int subLevel) level.subLevel = subLevel; } +void Scope::addBuiltinTypeBinding(const Name& name, const TypeFun& tyFun) +{ + exportedTypeBindings[name] = tyFun; + builtinTypeNames.insert(name); +} + std::optional Scope::lookupType(const Name& name) { const Scope* scope = this; @@ -82,9 +88,9 @@ std::optional Scope::lookupPack(const Name& name) } } -std::optional Scope::linearSearchForBinding(const std::string& name, bool traverseScopeChain) +std::optional Scope::linearSearchForBinding(const std::string& name, bool traverseScopeChain) const { - Scope* scope = this; + const Scope* scope = this; while (scope) { @@ -122,4 +128,22 @@ std::optional Scope::lookup(Symbol sym) } } +bool subsumesStrict(Scope* left, Scope* right) +{ + while (right) + { + if (right->parent.get() == left) + return true; + + right = right->parent.get(); + } + + return false; +} + +bool subsumes(Scope* left, Scope* right) +{ + return left == right || subsumesStrict(left, right); +} + } // namespace Luau diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index 0b389547e..135602511 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -14,6 +14,7 @@ LUAU_FASTFLAG(LuauLowerBoundsCalculation) LUAU_FASTFLAG(LuauUnknownAndNeverType) LUAU_FASTFLAGVARIABLE(LuauSpecialTypesAsterisked, false) LUAU_FASTFLAGVARIABLE(LuauFixNameMaps, false) +LUAU_FASTFLAGVARIABLE(LuauUnseeArrayTtv, false) /* * Prefix generic typenames with gen- @@ -631,6 +632,10 @@ struct TypeVarStringifier state.emit("{"); stringify(ttv.indexer->indexResultType); state.emit("}"); + + if (FFlag::LuauUnseeArrayTtv) + state.unsee(&ttv); + return; } diff --git a/Analysis/src/TxnLog.cpp b/Analysis/src/TxnLog.cpp index 74d77307f..06bde1950 100644 --- a/Analysis/src/TxnLog.cpp +++ b/Analysis/src/TxnLog.cpp @@ -289,6 +289,45 @@ PendingTypePack* TxnLog::changeLevel(TypePackId tp, TypeLevel newLevel) return newTp; } +PendingType* TxnLog::changeScope(TypeId ty, NotNull newScope) +{ + LUAU_ASSERT(get(ty) || get(ty) || get(ty) || get(ty)); + + PendingType* newTy = queue(ty); + if (FreeTypeVar* ftv = Luau::getMutable(newTy)) + { + ftv->scope = newScope; + } + else if (TableTypeVar* ttv = Luau::getMutable(newTy)) + { + LUAU_ASSERT(ttv->state == TableState::Free || ttv->state == TableState::Generic); + ttv->scope = newScope; + } + else if (FunctionTypeVar* ftv = Luau::getMutable(newTy)) + { + ftv->scope = newScope; + } + else if (ConstrainedTypeVar* ctv = Luau::getMutable(newTy)) + { + ctv->scope = newScope; + } + + return newTy; +} + +PendingTypePack* TxnLog::changeScope(TypePackId tp, NotNull newScope) +{ + LUAU_ASSERT(get(tp)); + + PendingTypePack* newTp = queue(tp); + if (FreeTypePack* ftp = Luau::getMutable(newTp)) + { + ftp->scope = newScope; + } + + return newTp; +} + PendingType* TxnLog::changeIndexer(TypeId ty, std::optional indexer) { LUAU_ASSERT(get(ty)); diff --git a/Analysis/src/TypeArena.cpp b/Analysis/src/TypeArena.cpp index abf31aee2..666ab8674 100644 --- a/Analysis/src/TypeArena.cpp +++ b/Analysis/src/TypeArena.cpp @@ -40,6 +40,15 @@ TypeId TypeArena::freshType(Scope* scope) return allocated; } +TypeId TypeArena::freshType(Scope* scope, TypeLevel level) +{ + TypeId allocated = typeVars.allocate(FreeTypeVar{scope, level}); + + asMutable(allocated)->owningArena = this; + + return allocated; +} + TypePackId TypeArena::freshTypePack(Scope* scope) { TypePackId allocated = typePacks.allocate(FreeTypePack{scope}); diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index ea06882a5..f98a2123e 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -17,10 +17,16 @@ #include LUAU_FASTFLAG(DebugLuauLogSolverToJson); +LUAU_FASTFLAG(DebugLuauMagicTypes); namespace Luau { +// TypeInfer.h +// TODO move these +using PrintLineProc = void(*)(const std::string&); +extern PrintLineProc luauPrintLine; + /* Push a scope onto the end of a stack for the lifetime of the StackPusher instance. * TypeChecker2 uses this to maintain knowledge about which scope encloses every * given AstNode. @@ -114,6 +120,19 @@ struct TypeChecker2 TypeId lookupAnnotation(AstType* annotation) { + if (FFlag::DebugLuauMagicTypes) + { + if (auto ref = annotation->as(); ref && ref->name == "_luau_print" && ref->parameters.size > 0) + { + if (auto ann = ref->parameters.data[0].type) + { + TypeId argTy = lookupAnnotation(ref->parameters.data[0].type); + luauPrintLine(format("_luau_print (%d, %d): %s\n", annotation->location.begin.line, annotation->location.begin.column, toString(argTy).c_str())); + return follow(argTy); + } + } + } + TypeId* ty = module->astResolvedTypes.find(annotation); LUAU_ASSERT(ty); return follow(*ty); @@ -284,50 +303,49 @@ struct TypeChecker2 void visit(AstStatLocal* local) { - for (size_t i = 0; i < local->values.size; ++i) + size_t count = std::max(local->values.size, local->vars.size); + for (size_t i = 0; i < count; ++i) { - AstExpr* value = local->values.data[i]; + AstExpr* value = i < local->values.size ? local->values.data[i] : nullptr; - visit(value); + if (value) + visit(value); - if (i == local->values.size - 1) + if (i != local->values.size - 1) { - if (i < local->values.size) - { - TypePackId valueTypes = lookupPack(value); - auto it = begin(valueTypes); - for (size_t j = i; j < local->vars.size; ++j) - { - if (it == end(valueTypes)) - { - break; - } + AstLocal* var = i < local->vars.size ? local->vars.data[i] : nullptr; - AstLocal* var = local->vars.data[i]; - if (var->annotation) - { - TypeId varType = lookupAnnotation(var->annotation); - ErrorVec errors = tryUnify(stack.back(), value->location, *it, varType); - if (!errors.empty()) - reportErrors(std::move(errors)); - } - - ++it; - } + if (var && var->annotation) + { + TypeId varType = lookupAnnotation(var->annotation); + TypeId valueType = value ? lookupType(value) : nullptr; + if (valueType && !isSubtype(varType, valueType, stack.back(), singletonTypes, ice, /* anyIsTop */ false)) + reportError(TypeMismatch{varType, valueType}, value->location); } } else { - TypeId valueType = lookupType(value); - AstLocal* var = local->vars.data[i]; + LUAU_ASSERT(value); - if (var->annotation) + TypePackId valueTypes = lookupPack(value); + auto it = begin(valueTypes); + for (size_t j = i; j < local->vars.size; ++j) { - TypeId varType = lookupAnnotation(var->annotation); - if (!isSubtype(varType, valueType, stack.back(), singletonTypes, ice, /* anyIsTop */ false)) + if (it == end(valueTypes)) { - reportError(TypeMismatch{varType, valueType}, value->location); + break; } + + AstLocal* var = local->vars.data[i]; + if (var->annotation) + { + TypeId varType = lookupAnnotation(var->annotation); + ErrorVec errors = tryUnify(stack.back(), value->location, *it, varType); + if (!errors.empty()) + reportErrors(std::move(errors)); + } + + ++it; } } } @@ -345,50 +363,6 @@ struct TypeChecker2 visit(forStatement->body); } - // "Render" a type pack out to an array of a given length. Expands - // variadics and various other things to get there. - std::vector flatten(TypeArena& arena, TypePackId pack, size_t length) - { - std::vector result; - - auto it = begin(pack); - auto endIt = end(pack); - - while (it != endIt) - { - result.push_back(*it); - - if (result.size() >= length) - return result; - - ++it; - } - - if (!it.tail()) - return result; - - TypePackId tail = *it.tail(); - if (get(tail)) - LUAU_ASSERT(0); - else if (auto vtp = get(tail)) - { - while (result.size() < length) - result.push_back(vtp->ty); - } - else if (get(tail) || get(tail)) - { - while (result.size() < length) - result.push_back(arena.addType(FreeTypeVar{nullptr})); - } - else if (auto etp = get(tail)) - { - while (result.size() < length) - result.push_back(singletonTypes->errorRecoveryType()); - } - - return result; - } - void visit(AstStatForIn* forInStatement) { for (AstLocal* local : forInStatement->vars) @@ -426,7 +400,7 @@ struct TypeChecker2 TypePackId iteratorPack = arena.addTypePack(valueTypes, iteratorTail); // ... and then expand it out to 3 values (if possible) - const std::vector iteratorTypes = flatten(arena, iteratorPack, 3); + const std::vector iteratorTypes = flatten(arena, singletonTypes, iteratorPack, 3); if (iteratorTypes.empty()) { reportError(GenericError{"for..in loops require at least one value to iterate over. Got zero"}, getLocation(forInStatement->values)); @@ -434,34 +408,31 @@ struct TypeChecker2 } TypeId iteratorTy = follow(iteratorTypes[0]); - /* - * If the first iterator argument is a function - * * There must be 1 to 3 iterator arguments. Name them (nextTy, - * arrayTy, startIndexTy) - * * The return type of nextTy() must correspond to the variables' - * types and counts. HOWEVER the first iterator will never be nil. - * * The first return value of nextTy must be compatible with - * startIndexTy. - * * The first argument to nextTy() must be compatible with arrayTy if - * present. nil if not. - * * The second argument to nextTy() must be compatible with - * startIndexTy if it is present. Else, it must be compatible with - * nil. - * * nextTy() must be callable with only 2 arguments. - */ - if (const FunctionTypeVar* nextFn = get(iteratorTy)) + auto checkFunction = [this, &arena, &scope, &forInStatement, &variableTypes](const FunctionTypeVar* iterFtv, std::vector iterTys, bool isMm) { - if (iteratorTypes.size() < 1 || iteratorTypes.size() > 3) - reportError(GenericError{"for..in loops must be passed (next, [table[, state]])"}, getLocation(forInStatement->values)); + if (iterTys.size() < 1 || iterTys.size() > 3) + { + if (isMm) + reportError(GenericError{"__iter metamethod must return (next[, table[, state]])"}, getLocation(forInStatement->values)); + else + reportError(GenericError{"for..in loops must be passed (next[, table[, state]])"}, getLocation(forInStatement->values)); + return; + } + // It is okay if there aren't enough iterators, but the iteratee must provide enough. - std::vector expectedVariableTypes = flatten(arena, nextFn->retTypes, variableTypes.size()); + std::vector expectedVariableTypes = flatten(arena, singletonTypes, iterFtv->retTypes, variableTypes.size()); if (expectedVariableTypes.size() < variableTypes.size()) - reportError(GenericError{"next() does not return enough values"}, forInStatement->vars.data[0]->location); + { + if (isMm) + reportError(GenericError{"__iter metamethod's next() function does not return enough values"}, getLocation(forInStatement->values)); + else + reportError(GenericError{"next() does not return enough values"}, forInStatement->values.data[0]->location); + } for (size_t i = 0; i < std::min(expectedVariableTypes.size(), variableTypes.size()); ++i) reportErrors(tryUnify(scope, forInStatement->vars.data[i]->location, variableTypes[i], expectedVariableTypes[i])); - + // nextFn is going to be invoked with (arrayTy, startIndexTy) // It will be passed two arguments on every iteration save the @@ -474,17 +445,15 @@ struct TypeChecker2 // If iteratorTypes is too short to be a valid call to nextFn, we have to report a count mismatch error. // If 2 is too short to be a valid call to nextFn, we have to report a count mismatch error. // If 2 is too long to be a valid call to nextFn, we have to report a count mismatch error. - auto [minCount, maxCount] = getParameterExtents(TxnLog::empty(), nextFn->argTypes, /*includeHiddenVariadics*/ true); + auto [minCount, maxCount] = getParameterExtents(TxnLog::empty(), iterFtv->argTypes, /*includeHiddenVariadics*/ true); if (minCount > 2) reportError(CountMismatch{2, std::nullopt, minCount, CountMismatch::Arg}, forInStatement->vars.data[0]->location); if (maxCount && *maxCount < 2) reportError(CountMismatch{2, std::nullopt, *maxCount, CountMismatch::Arg}, forInStatement->vars.data[0]->location); - const std::vector flattenedArgTypes = flatten(arena, nextFn->argTypes, 2); - const auto [argTypes, argsTail] = Luau::flatten(nextFn->argTypes); - - size_t firstIterationArgCount = iteratorTypes.empty() ? 0 : iteratorTypes.size() - 1; + const std::vector flattenedArgTypes = flatten(arena, singletonTypes, iterFtv->argTypes, 2); + size_t firstIterationArgCount = iterTys.empty() ? 0 : iterTys.size() - 1; size_t actualArgCount = expectedVariableTypes.size(); if (firstIterationArgCount < minCount) @@ -492,17 +461,37 @@ struct TypeChecker2 else if (actualArgCount < minCount) reportError(CountMismatch{2, std::nullopt, actualArgCount, CountMismatch::Arg}, forInStatement->vars.data[0]->location); - if (iteratorTypes.size() >= 2 && flattenedArgTypes.size() > 0) + if (iterTys.size() >= 2 && flattenedArgTypes.size() > 0) { size_t valueIndex = forInStatement->values.size > 1 ? 1 : 0; - reportErrors(tryUnify(scope, forInStatement->values.data[valueIndex]->location, iteratorTypes[1], flattenedArgTypes[0])); + reportErrors(tryUnify(scope, forInStatement->values.data[valueIndex]->location, iterTys[1], flattenedArgTypes[0])); } - if (iteratorTypes.size() == 3 && flattenedArgTypes.size() > 1) + if (iterTys.size() == 3 && flattenedArgTypes.size() > 1) { size_t valueIndex = forInStatement->values.size > 2 ? 2 : 0; - reportErrors(tryUnify(scope, forInStatement->values.data[valueIndex]->location, iteratorTypes[2], flattenedArgTypes[1])); + reportErrors(tryUnify(scope, forInStatement->values.data[valueIndex]->location, iterTys[2], flattenedArgTypes[1])); } + }; + + /* + * If the first iterator argument is a function + * * There must be 1 to 3 iterator arguments. Name them (nextTy, + * arrayTy, startIndexTy) + * * The return type of nextTy() must correspond to the variables' + * types and counts. HOWEVER the first iterator will never be nil. + * * The first return value of nextTy must be compatible with + * startIndexTy. + * * The first argument to nextTy() must be compatible with arrayTy if + * present. nil if not. + * * The second argument to nextTy() must be compatible with + * startIndexTy if it is present. Else, it must be compatible with + * nil. + * * nextTy() must be callable with only 2 arguments. + */ + if (const FunctionTypeVar* nextFn = get(iteratorTy)) + { + checkFunction(nextFn, iteratorTypes, false); } else if (const TableTypeVar* ttv = get(iteratorTy)) { @@ -519,6 +508,62 @@ struct TypeChecker2 { // nothing } + else if (std::optional iterMmTy = findMetatableEntry(singletonTypes, module->errors, iteratorTy, "__iter", forInStatement->values.data[0]->location)) + { + Instantiation instantiation{TxnLog::empty(), &arena, TypeLevel{}, scope}; + + if (std::optional instantiatedIterMmTy = instantiation.substitute(*iterMmTy)) + { + if (const FunctionTypeVar* iterMmFtv = get(*instantiatedIterMmTy)) + { + TypePackId argPack = arena.addTypePack({iteratorTy}); + reportErrors(tryUnify(scope, forInStatement->values.data[0]->location, argPack, iterMmFtv->argTypes)); + + std::vector mmIteratorTypes = flatten(arena, singletonTypes, iterMmFtv->retTypes, 3); + + if (mmIteratorTypes.size() == 0) + { + reportError(GenericError{"__iter must return at least one value"}, forInStatement->values.data[0]->location); + return; + } + + TypeId nextFn = follow(mmIteratorTypes[0]); + + if (std::optional instantiatedNextFn = instantiation.substitute(nextFn)) + { + std::vector instantiatedIteratorTypes = mmIteratorTypes; + instantiatedIteratorTypes[0] = *instantiatedNextFn; + + if (const FunctionTypeVar* nextFtv = get(*instantiatedNextFn)) + { + checkFunction(nextFtv, instantiatedIteratorTypes, true); + } + else + { + reportError(CannotCallNonFunction{*instantiatedNextFn}, forInStatement->values.data[0]->location); + } + } + else + { + reportError(UnificationTooComplex{}, forInStatement->values.data[0]->location); + } + } + else + { + // TODO: This will not tell the user that this is because the + // metamethod isn't callable. This is not ideal, and we should + // improve this error message. + + // TODO: This will also not handle intersections of functions or + // callable tables (which are supported by the runtime). + reportError(CannotCallNonFunction{*iterMmTy}, forInStatement->values.data[0]->location); + } + } + else + { + reportError(UnificationTooComplex{}, forInStatement->values.data[0]->location); + } + } else { reportError(CannotCallNonFunction{iteratorTy}, forInStatement->values.data[0]->location); @@ -730,7 +775,7 @@ struct TypeChecker2 visit(arg); TypeArena arena; - Instantiation instantiation{TxnLog::empty(), &arena, TypeLevel{}}; + Instantiation instantiation{TxnLog::empty(), &arena, TypeLevel{}, stack.back()}; TypePackId expectedRetType = lookupPack(call); TypeId functionType = lookupType(call->func); diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index b6f20ac81..b96046be7 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -47,6 +47,8 @@ LUAU_FASTFLAGVARIABLE(LuauNeverTypesAndOperatorsInference, false) LUAU_FASTFLAGVARIABLE(LuauReturnsFromCallsitesAreNotWidened, false) LUAU_FASTFLAGVARIABLE(LuauCompleteVisitor, false) LUAU_FASTFLAGVARIABLE(LuauUnionOfTypesFollow, false) +LUAU_FASTFLAGVARIABLE(LuauReportShadowedTypeAlias, false) +LUAU_FASTFLAGVARIABLE(LuauBetterMessagingOnCountMismatch, false) namespace Luau { @@ -66,9 +68,7 @@ static void defaultLuauPrintLine(const std::string& s) printf("%s\n", s.c_str()); } -using PrintLineProc = decltype(&defaultLuauPrintLine); - -static PrintLineProc luauPrintLine = &defaultLuauPrintLine; +PrintLineProc luauPrintLine = &defaultLuauPrintLine; void setPrintLine(PrintLineProc pl) { @@ -270,16 +270,16 @@ TypeChecker::TypeChecker(ModuleResolver* resolver, NotNull singl { globalScope = std::make_shared(globalTypes.addTypePack(TypePackVar{FreeTypePack{TypeLevel{}}})); - globalScope->exportedTypeBindings["any"] = TypeFun{{}, anyType}; - globalScope->exportedTypeBindings["nil"] = TypeFun{{}, nilType}; - globalScope->exportedTypeBindings["number"] = TypeFun{{}, numberType}; - globalScope->exportedTypeBindings["string"] = TypeFun{{}, stringType}; - globalScope->exportedTypeBindings["boolean"] = TypeFun{{}, booleanType}; - globalScope->exportedTypeBindings["thread"] = TypeFun{{}, threadType}; + globalScope->addBuiltinTypeBinding("any", TypeFun{{}, anyType}); + globalScope->addBuiltinTypeBinding("nil", TypeFun{{}, nilType}); + globalScope->addBuiltinTypeBinding("number", TypeFun{{}, numberType}); + globalScope->addBuiltinTypeBinding("string", TypeFun{{}, stringType}); + globalScope->addBuiltinTypeBinding("boolean", TypeFun{{}, booleanType}); + globalScope->addBuiltinTypeBinding("thread", TypeFun{{}, threadType}); if (FFlag::LuauUnknownAndNeverType) { - globalScope->exportedTypeBindings["unknown"] = TypeFun{{}, unknownType}; - globalScope->exportedTypeBindings["never"] = TypeFun{{}, neverType}; + globalScope->addBuiltinTypeBinding("unknown", TypeFun{{}, unknownType}); + globalScope->addBuiltinTypeBinding("never", TypeFun{{}, neverType}); } } @@ -534,7 +534,7 @@ void TypeChecker::checkBlockWithoutRecursionCheck(const ScopePtr& scope, const A { if (const auto& typealias = stat->as()) { - check(scope, *typealias, subLevel, true); + prototype(scope, *typealias, subLevel); ++subLevel; } } @@ -698,6 +698,10 @@ LUAU_NOINLINE void TypeChecker::checkBlockTypeAliases(const ScopePtr& scope, std auto& bindings = typealias->exported ? scope->exportedTypeBindings : scope->privateTypeBindings; Name name = typealias->name.value; + + if (FFlag::LuauReportShadowedTypeAlias && duplicateTypeAliases.contains({typealias->exported, name})) + continue; + TypeId type = bindings[name].type; if (get(follow(type))) { @@ -1109,8 +1113,23 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatLocal& local) TypePackId valuePack = checkExprList(scope, local.location, local.values, /* substituteFreeForNil= */ true, instantiateGenerics, expectedTypes).type; + // If the expression list only contains one expression and it's a function call or is otherwise within parentheses, use FunctionResult. + // Otherwise, we'll want to use ExprListResult to make the error messaging more general. + CountMismatch::Context ctx = FFlag::LuauBetterMessagingOnCountMismatch ? CountMismatch::ExprListResult : CountMismatch::FunctionResult; + if (FFlag::LuauBetterMessagingOnCountMismatch) + { + if (local.values.size == 1) + { + AstExpr* e = local.values.data[0]; + while (auto group = e->as()) + e = group->expr; + if (e->is()) + ctx = CountMismatch::FunctionResult; + } + } + Unifier state = mkUnifier(scope, local.location); - state.ctx = CountMismatch::Result; + state.ctx = ctx; state.tryUnify(valuePack, variablePack); reportErrors(state.errors); @@ -1472,10 +1491,8 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco scope->bindings[function.name] = {quantify(funScope, ty, function.name->location), function.name->location}; } -void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias, int subLevel, bool forwardDeclare) +void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias) { - // This function should be called at most twice for each type alias. - // Once with forwardDeclare, and once without. Name name = typealias.name.value; // If the alias is missing a name, we can't do anything with it. Ignore it. @@ -1490,14 +1507,134 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias auto& bindingsMap = typealias.exported ? scope->exportedTypeBindings : scope->privateTypeBindings; - if (forwardDeclare) + // If the first pass failed (this should mean a duplicate definition), the second pass isn't going to be + // interesting. + if (duplicateTypeAliases.find({typealias.exported, name})) + return; + + // By now this alias must have been `prototype()`d first. + if (!binding) + ice("Not predeclared"); + + ScopePtr aliasScope = childScope(scope, typealias.location); + aliasScope->level = scope->level.incr(); + + for (auto param : binding->typeParams) { - if (binding) + auto generic = get(param.ty); + LUAU_ASSERT(generic); + aliasScope->privateTypeBindings[generic->name] = TypeFun{{}, param.ty}; + } + + for (auto param : binding->typePackParams) + { + auto generic = get(param.tp); + LUAU_ASSERT(generic); + aliasScope->privateTypePackBindings[generic->name] = param.tp; + } + + TypeId ty = resolveType(aliasScope, *typealias.type); + if (auto ttv = getMutable(follow(ty))) + { + // If the table is already named and we want to rename the type function, we have to bind new alias to a copy + // Additionally, we can't modify types that come from other modules + if (ttv->name || follow(ty)->owningArena != ¤tModule->internalTypes) + { + bool sameTys = std::equal(ttv->instantiatedTypeParams.begin(), ttv->instantiatedTypeParams.end(), binding->typeParams.begin(), + binding->typeParams.end(), [](auto&& itp, auto&& tp) { + return itp == tp.ty; + }); + bool sameTps = std::equal(ttv->instantiatedTypePackParams.begin(), ttv->instantiatedTypePackParams.end(), binding->typePackParams.begin(), + binding->typePackParams.end(), [](auto&& itpp, auto&& tpp) { + return itpp == tpp.tp; + }); + + // Copy can be skipped if this is an identical alias + if (!ttv->name || ttv->name != name || !sameTys || !sameTps) + { + // This is a shallow clone, original recursive links to self are not updated + TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, ttv->level, ttv->state}; + clone.definitionModuleName = ttv->definitionModuleName; + clone.name = name; + + for (auto param : binding->typeParams) + clone.instantiatedTypeParams.push_back(param.ty); + + for (auto param : binding->typePackParams) + clone.instantiatedTypePackParams.push_back(param.tp); + + bool isNormal = ty->normal; + ty = addType(std::move(clone)); + + if (FFlag::LuauLowerBoundsCalculation) + asMutable(ty)->normal = isNormal; + } + } + else { - Location location = scope->typeAliasLocations[name]; - reportError(TypeError{typealias.location, DuplicateTypeDefinition{name, location}}); + ttv->name = name; + + ttv->instantiatedTypeParams.clear(); + for (auto param : binding->typeParams) + ttv->instantiatedTypeParams.push_back(param.ty); + + ttv->instantiatedTypePackParams.clear(); + for (auto param : binding->typePackParams) + ttv->instantiatedTypePackParams.push_back(param.tp); + } + } + else if (auto mtv = getMutable(follow(ty))) + { + // We can't modify types that come from other modules + if (follow(ty)->owningArena == ¤tModule->internalTypes) + mtv->syntheticName = name; + } + + TypeId& bindingType = bindingsMap[name].type; + + if (unify(ty, bindingType, aliasScope, typealias.location)) + bindingType = ty; + + if (FFlag::LuauLowerBoundsCalculation) + { + auto [t, ok] = normalize(bindingType, currentModule, singletonTypes, *iceHandler); + bindingType = t; + if (!ok) + reportError(typealias.location, NormalizationTooComplex{}); + } +} + +void TypeChecker::prototype(const ScopePtr& scope, const AstStatTypeAlias& typealias, int subLevel) +{ + Name name = typealias.name.value; + + // If the alias is missing a name, we can't do anything with it. Ignore it. + if (name == kParseNameError) + return; + std::optional binding; + if (auto it = scope->exportedTypeBindings.find(name); it != scope->exportedTypeBindings.end()) + binding = it->second; + else if (auto it = scope->privateTypeBindings.find(name); it != scope->privateTypeBindings.end()) + binding = it->second; + + auto& bindingsMap = typealias.exported ? scope->exportedTypeBindings : scope->privateTypeBindings; + + if (binding) + { + Location location = scope->typeAliasLocations[name]; + reportError(TypeError{typealias.location, DuplicateTypeDefinition{name, location}}); + + if (!FFlag::LuauReportShadowedTypeAlias) bindingsMap[name] = TypeFun{binding->typeParams, binding->typePackParams, errorRecoveryType(anyType)}; + + duplicateTypeAliases.insert({typealias.exported, name}); + } + else if (FFlag::LuauReportShadowedTypeAlias) + { + if (globalScope->builtinTypeNames.contains(name)) + { + reportError(typealias.location, DuplicateTypeDefinition{name}); duplicateTypeAliases.insert({typealias.exported, name}); } else @@ -1520,100 +1657,20 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias } else { - // If the first pass failed (this should mean a duplicate definition), the second pass isn't going to be - // interesting. - if (duplicateTypeAliases.find({typealias.exported, name})) - return; - - if (!binding) - ice("Not predeclared"); - ScopePtr aliasScope = childScope(scope, typealias.location); aliasScope->level = scope->level.incr(); + aliasScope->level.subLevel = subLevel; - for (auto param : binding->typeParams) - { - auto generic = get(param.ty); - LUAU_ASSERT(generic); - aliasScope->privateTypeBindings[generic->name] = TypeFun{{}, param.ty}; - } - - for (auto param : binding->typePackParams) - { - auto generic = get(param.tp); - LUAU_ASSERT(generic); - aliasScope->privateTypePackBindings[generic->name] = param.tp; - } - - TypeId ty = resolveType(aliasScope, *typealias.type); - if (auto ttv = getMutable(follow(ty))) - { - // If the table is already named and we want to rename the type function, we have to bind new alias to a copy - // Additionally, we can't modify types that come from other modules - if (ttv->name || follow(ty)->owningArena != ¤tModule->internalTypes) - { - bool sameTys = std::equal(ttv->instantiatedTypeParams.begin(), ttv->instantiatedTypeParams.end(), binding->typeParams.begin(), - binding->typeParams.end(), [](auto&& itp, auto&& tp) { - return itp == tp.ty; - }); - bool sameTps = std::equal(ttv->instantiatedTypePackParams.begin(), ttv->instantiatedTypePackParams.end(), - binding->typePackParams.begin(), binding->typePackParams.end(), [](auto&& itpp, auto&& tpp) { - return itpp == tpp.tp; - }); - - // Copy can be skipped if this is an identical alias - if (!ttv->name || ttv->name != name || !sameTys || !sameTps) - { - // This is a shallow clone, original recursive links to self are not updated - TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, ttv->level, ttv->state}; - clone.definitionModuleName = ttv->definitionModuleName; - clone.name = name; - - for (auto param : binding->typeParams) - clone.instantiatedTypeParams.push_back(param.ty); - - for (auto param : binding->typePackParams) - clone.instantiatedTypePackParams.push_back(param.tp); - - bool isNormal = ty->normal; - ty = addType(std::move(clone)); - - if (FFlag::LuauLowerBoundsCalculation) - asMutable(ty)->normal = isNormal; - } - } - else - { - ttv->name = name; - - ttv->instantiatedTypeParams.clear(); - for (auto param : binding->typeParams) - ttv->instantiatedTypeParams.push_back(param.ty); + auto [generics, genericPacks] = + createGenericTypes(aliasScope, scope->level, typealias, typealias.generics, typealias.genericPacks, /* useCache = */ true); - ttv->instantiatedTypePackParams.clear(); - for (auto param : binding->typePackParams) - ttv->instantiatedTypePackParams.push_back(param.tp); - } - } - else if (auto mtv = getMutable(follow(ty))) - { - // We can't modify types that come from other modules - if (follow(ty)->owningArena == ¤tModule->internalTypes) - mtv->syntheticName = name; - } - - TypeId& bindingType = bindingsMap[name].type; - - if (unify(ty, bindingType, aliasScope, typealias.location)) - bindingType = ty; + TypeId ty = freshType(aliasScope); + FreeTypeVar* ftv = getMutable(ty); + LUAU_ASSERT(ftv); + ftv->forwardedTypeAlias = true; + bindingsMap[name] = {std::move(generics), std::move(genericPacks), ty}; - if (FFlag::LuauLowerBoundsCalculation) - { - auto [t, ok] = normalize(bindingType, currentModule, singletonTypes, *iceHandler); - bindingType = t; - if (!ok) - reportError(typealias.location, NormalizationTooComplex{}); - } + scope->typeAliasLocations[name] = typealias.location; } } @@ -4152,7 +4209,7 @@ std::optional> TypeChecker::checkCallOverload(const Sc TypePackId adjustedArgPack = addTypePack(TypePack{std::move(adjustedArgTypes), it.tail()}); TxnLog log; - promoteTypeLevels(log, ¤tModule->internalTypes, level, retPack); + promoteTypeLevels(log, ¤tModule->internalTypes, level, /*scope*/ nullptr, /*useScope*/ false, retPack); log.commit(); *asMutable(fn) = FunctionTypeVar{level, adjustedArgPack, retPack}; @@ -4712,7 +4769,7 @@ TypeId TypeChecker::instantiate(const ScopePtr& scope, TypeId ty, Location locat if (ftv && ftv->hasNoGenerics) return ty; - Instantiation instantiation{log, ¤tModule->internalTypes, scope->level}; + Instantiation instantiation{log, ¤tModule->internalTypes, scope->level, /*scope*/ nullptr}; if (FFlag::LuauAutocompleteDynamicLimits && instantiationChildLimit) instantiation.childLimit = *instantiationChildLimit; diff --git a/Analysis/src/TypeUtils.cpp b/Analysis/src/TypeUtils.cpp index 6ea04ea9b..ca00c2699 100644 --- a/Analysis/src/TypeUtils.cpp +++ b/Analysis/src/TypeUtils.cpp @@ -224,4 +224,46 @@ std::pair> getParameterExtents(const TxnLog* log, return {minCount, minCount + optionalCount}; } +std::vector flatten(TypeArena& arena, NotNull singletonTypes, TypePackId pack, size_t length) +{ + std::vector result; + + auto it = begin(pack); + auto endIt = end(pack); + + while (it != endIt) + { + result.push_back(*it); + + if (result.size() >= length) + return result; + + ++it; + } + + if (!it.tail()) + return result; + + TypePackId tail = *it.tail(); + if (get(tail)) + LUAU_ASSERT(0); + else if (auto vtp = get(tail)) + { + while (result.size() < length) + result.push_back(vtp->ty); + } + else if (get(tail) || get(tail)) + { + while (result.size() < length) + result.push_back(arena.addType(FreeTypeVar{nullptr})); + } + else if (auto etp = get(tail)) + { + while (result.size() < length) + result.push_back(singletonTypes->errorRecoveryType()); + } + + return result; +} + } // namespace Luau diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index 3a820ea6e..bf6bf34a7 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -474,6 +474,17 @@ FunctionTypeVar::FunctionTypeVar(TypeLevel level, TypePackId argTypes, TypePackI { } +FunctionTypeVar::FunctionTypeVar( + TypeLevel level, Scope* scope, TypePackId argTypes, TypePackId retTypes, std::optional defn, bool hasSelf) + : level(level) + , scope(scope) + , argTypes(argTypes) + , retTypes(retTypes) + , definition(std::move(defn)) + , hasSelf(hasSelf) +{ +} + FunctionTypeVar::FunctionTypeVar(std::vector generics, std::vector genericPacks, TypePackId argTypes, TypePackId retTypes, std::optional defn, bool hasSelf) : generics(generics) @@ -497,9 +508,23 @@ FunctionTypeVar::FunctionTypeVar(TypeLevel level, std::vector generics, { } -TableTypeVar::TableTypeVar(TableState state, TypeLevel level) +FunctionTypeVar::FunctionTypeVar(TypeLevel level, Scope* scope, std::vector generics, std::vector genericPacks, + TypePackId argTypes, TypePackId retTypes, std::optional defn, bool hasSelf) + : level(level) + , scope(scope) + , generics(generics) + , genericPacks(genericPacks) + , argTypes(argTypes) + , retTypes(retTypes) + , definition(std::move(defn)) + , hasSelf(hasSelf) +{ +} + +TableTypeVar::TableTypeVar(TableState state, TypeLevel level, Scope* scope) : state(state) , level(level) + , scope(scope) { } @@ -511,6 +536,15 @@ TableTypeVar::TableTypeVar(const Props& props, const std::optional { } +TableTypeVar::TableTypeVar(const Props& props, const std::optional& indexer, TypeLevel level, Scope* scope, TableState state) + : props(props) + , indexer(indexer) + , state(state) + , level(level) + , scope(scope) +{ +} + // Test TypeVars for equivalence // More complex than we'd like because TypeVars can self-reference. diff --git a/Analysis/src/Unifiable.cpp b/Analysis/src/Unifiable.cpp index fa76e8204..a3d4540cb 100644 --- a/Analysis/src/Unifiable.cpp +++ b/Analysis/src/Unifiable.cpp @@ -18,6 +18,13 @@ Free::Free(Scope* scope) { } +Free::Free(Scope* scope, TypeLevel level) + : index(++nextIndex) + , level(level) + , scope(scope) +{ +} + int Free::nextIndex = 0; Generic::Generic() diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 505e9e437..c13a6f8b5 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -23,6 +23,7 @@ LUAU_FASTFLAG(LuauUnknownAndNeverType) LUAU_FASTFLAGVARIABLE(LuauScalarShapeSubtyping, false) LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution) LUAU_FASTFLAG(LuauCallUnifyPackTails) +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) namespace Luau { @@ -33,10 +34,15 @@ struct PromoteTypeLevels final : TypeVarOnceVisitor const TypeArena* typeArena = nullptr; TypeLevel minLevel; - PromoteTypeLevels(TxnLog& log, const TypeArena* typeArena, TypeLevel minLevel) + Scope* outerScope = nullptr; + bool useScopes; + + PromoteTypeLevels(TxnLog& log, const TypeArena* typeArena, TypeLevel minLevel, Scope* outerScope, bool useScopes) : log(log) , typeArena(typeArena) , minLevel(minLevel) + , outerScope(outerScope) + , useScopes(useScopes) { } @@ -44,9 +50,18 @@ struct PromoteTypeLevels final : TypeVarOnceVisitor void promote(TID ty, T* t) { LUAU_ASSERT(t); - if (minLevel.subsumesStrict(t->level)) + + if (useScopes) { - log.changeLevel(ty, minLevel); + if (subsumesStrict(outerScope, t->scope)) + log.changeScope(ty, NotNull{outerScope}); + } + else + { + if (minLevel.subsumesStrict(t->level)) + { + log.changeLevel(ty, minLevel); + } } } @@ -123,23 +138,23 @@ struct PromoteTypeLevels final : TypeVarOnceVisitor } }; -static void promoteTypeLevels(TxnLog& log, const TypeArena* typeArena, TypeLevel minLevel, TypeId ty) +static void promoteTypeLevels(TxnLog& log, const TypeArena* typeArena, TypeLevel minLevel, Scope* outerScope, bool useScopes, TypeId ty) { // Type levels of types from other modules are already global, so we don't need to promote anything inside if (ty->owningArena != typeArena) return; - PromoteTypeLevels ptl{log, typeArena, minLevel}; + PromoteTypeLevels ptl{log, typeArena, minLevel, outerScope, useScopes}; ptl.traverse(ty); } -void promoteTypeLevels(TxnLog& log, const TypeArena* typeArena, TypeLevel minLevel, TypePackId tp) +void promoteTypeLevels(TxnLog& log, const TypeArena* typeArena, TypeLevel minLevel, Scope* outerScope, bool useScopes, TypePackId tp) { // Type levels of types from other modules are already global, so we don't need to promote anything inside if (tp->owningArena != typeArena) return; - PromoteTypeLevels ptl{log, typeArena, minLevel}; + PromoteTypeLevels ptl{log, typeArena, minLevel, outerScope, useScopes}; ptl.traverse(tp); } @@ -318,6 +333,16 @@ static std::optional> getTableMat return std::nullopt; } +// TODO: Inline and clip with FFlag::DebugLuauDeferredConstraintResolution +template +static bool subsumes(bool useScopes, TY_A* left, TY_B* right) +{ + if (useScopes) + return subsumes(left->scope, right->scope); + else + return left->level.subsumes(right->level); +} + Unifier::Unifier(TypeArena* types, NotNull singletonTypes, Mode mode, NotNull scope, const Location& location, Variance variance, UnifierSharedState& sharedState, TxnLog* parentLog) : types(types) @@ -375,7 +400,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool auto superFree = log.getMutable(superTy); auto subFree = log.getMutable(subTy); - if (superFree && subFree && superFree->level.subsumes(subFree->level)) + if (superFree && subFree && subsumes(useScopes, superFree, subFree)) { if (!occursCheck(subTy, superTy)) log.replace(subTy, BoundTypeVar(superTy)); @@ -386,7 +411,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool { if (!occursCheck(superTy, subTy)) { - if (superFree->level.subsumes(subFree->level)) + if (subsumes(useScopes, superFree, subFree)) { log.changeLevel(subTy, superFree->level); } @@ -400,7 +425,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool { // Unification can't change the level of a generic. auto subGeneric = log.getMutable(subTy); - if (subGeneric && !subGeneric->level.subsumes(superFree->level)) + if (subGeneric && !subsumes(useScopes, subGeneric, superFree)) { // TODO: a more informative error message? CLI-39912 reportError(TypeError{location, GenericError{"Generic subtype escaping scope"}}); @@ -409,7 +434,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool if (!occursCheck(superTy, subTy)) { - promoteTypeLevels(log, types, superFree->level, subTy); + promoteTypeLevels(log, types, superFree->level, superFree->scope, useScopes, subTy); Widen widen{types, singletonTypes}; log.replace(superTy, BoundTypeVar(widen(subTy))); @@ -429,7 +454,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool // Unification can't change the level of a generic. auto superGeneric = log.getMutable(superTy); - if (superGeneric && !superGeneric->level.subsumes(subFree->level)) + if (superGeneric && !subsumes(useScopes, superGeneric, subFree)) { // TODO: a more informative error message? CLI-39912 reportError(TypeError{location, GenericError{"Generic supertype escaping scope"}}); @@ -438,7 +463,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool if (!occursCheck(subTy, superTy)) { - promoteTypeLevels(log, types, subFree->level, superTy); + promoteTypeLevels(log, types, subFree->level, subFree->scope, useScopes, superTy); log.replace(subTy, BoundTypeVar(superTy)); } @@ -855,6 +880,7 @@ struct WeirdIter size_t index; bool growing; TypeLevel level; + Scope* scope = nullptr; WeirdIter(TypePackId packId, TxnLog& log) : packId(packId) @@ -915,6 +941,7 @@ struct WeirdIter LUAU_ASSERT(log.getMutable(newTail)); level = log.getMutable(packId)->level; + scope = log.getMutable(packId)->scope; log.replace(packId, BoundTypePack(newTail)); packId = newTail; pack = log.getMutable(newTail); @@ -1055,8 +1082,8 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal auto superIter = WeirdIter(superTp, log); auto subIter = WeirdIter(subTp, log); - auto mkFreshType = [this](TypeLevel level) { - return types->freshType(level); + auto mkFreshType = [this](Scope* scope, TypeLevel level) { + return types->freshType(scope, level); }; const TypePackId emptyTp = types->addTypePack(TypePack{{}, std::nullopt}); @@ -1072,12 +1099,12 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal if (superIter.good() && subIter.growing) { - subIter.pushType(mkFreshType(subIter.level)); + subIter.pushType(mkFreshType(subIter.scope, subIter.level)); } if (subIter.good() && superIter.growing) { - superIter.pushType(mkFreshType(superIter.level)); + superIter.pushType(mkFreshType(superIter.scope, superIter.level)); } if (superIter.good() && subIter.good()) @@ -1158,7 +1185,7 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal // these to produce the expected error message. size_t expectedSize = size(superTp); size_t actualSize = size(subTp); - if (ctx == CountMismatch::Result) + if (ctx == CountMismatch::FunctionResult || ctx == CountMismatch::ExprListResult) std::swap(expectedSize, actualSize); reportError(TypeError{location, CountMismatch{expectedSize, std::nullopt, actualSize, ctx}}); @@ -1271,7 +1298,7 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal else if (!innerState.errors.empty()) reportError(TypeError{location, TypeMismatch{superTy, subTy, "", innerState.errors.front()}}); - innerState.ctx = CountMismatch::Result; + innerState.ctx = CountMismatch::FunctionResult; innerState.tryUnify_(subFunction->retTypes, superFunction->retTypes); if (!reported) @@ -1295,7 +1322,7 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal ctx = CountMismatch::Arg; tryUnify_(superFunction->argTypes, subFunction->argTypes, isFunctionCall); - ctx = CountMismatch::Result; + ctx = CountMismatch::FunctionResult; tryUnify_(subFunction->retTypes, superFunction->retTypes); } @@ -1693,8 +1720,45 @@ void Unifier::tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed) { case TableState::Free: { - tryUnify_(subTy, superMetatable->table); - log.bindTable(subTy, superTy); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + Unifier innerState = makeChildUnifier(); + bool missingProperty = false; + + for (const auto& [propName, prop] : subTable->props) + { + if (std::optional mtPropTy = findTablePropertyRespectingMeta(superTy, propName)) + { + innerState.tryUnify(prop.type, *mtPropTy); + } + else + { + reportError(mismatchError); + missingProperty = true; + break; + } + } + + if (const TableTypeVar* superTable = log.get(log.follow(superMetatable->table))) + { + // TODO: Unify indexers. + } + + if (auto e = hasUnificationTooComplex(innerState.errors)) + reportError(*e); + else if (!innerState.errors.empty()) + reportError(TypeError{location, TypeMismatch{reversed ? subTy : superTy, reversed ? superTy : subTy, "", innerState.errors.front()}}); + else if (!missingProperty) + { + log.concat(std::move(innerState.log)); + log.bindTable(subTy, superTy); + } + } + else + { + tryUnify_(subTy, superMetatable->table); + log.bindTable(subTy, superTy); + } break; } diff --git a/CLI/Analyze.cpp b/CLI/Analyze.cpp index cd50ef007..7e4c5691c 100644 --- a/CLI/Analyze.cpp +++ b/CLI/Analyze.cpp @@ -270,7 +270,7 @@ int main(int argc, char** argv) CliConfigResolver configResolver(mode); Luau::Frontend frontend(&fileResolver, &configResolver, frontendOptions); - Luau::registerBuiltinTypes(frontend.typeChecker); + Luau::registerBuiltinGlobals(frontend.typeChecker); Luau::freeze(frontend.typeChecker.globalTypes); #ifdef CALLGRIND diff --git a/CMakeLists.txt b/CMakeLists.txt index 9ad16e8d2..43289f418 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -51,6 +51,8 @@ if(LUAU_BUILD_WEB) add_executable(Luau.Web) endif() +# Proxy target to make it possible to depend on private VM headers +add_library(Luau.VM.Internals INTERFACE) include(Sources.cmake) @@ -79,6 +81,8 @@ target_link_libraries(Luau.VM PUBLIC Luau.Common) target_include_directories(isocline PUBLIC extern/isocline/include) +target_include_directories(Luau.VM.Internals INTERFACE VM/src) + set(LUAU_OPTIONS) if(MSVC) diff --git a/CodeGen/src/Fallbacks.cpp b/CodeGen/src/Fallbacks.cpp index 3893d349d..625a9bd6a 100644 --- a/CodeGen/src/Fallbacks.cpp +++ b/CodeGen/src/Fallbacks.cpp @@ -720,7 +720,7 @@ const Instruction* execute_LOP_CALL(lua_State* L, const Instruction* pc, Closure int i; for (i = nresults; i != 0 && vali < valend; i--) - setobjs2s(L, res++, vali++); + setobj2s(L, res++, vali++); while (i-- > 0) setnilvalue(res++); @@ -756,7 +756,7 @@ const Instruction* execute_LOP_RETURN(lua_State* L, const Instruction* pc, Closu // note: in MULTRET context nresults starts as -1 so i != 0 condition never activates intentionally int i; for (i = nresults; i != 0 && vali < valend; i--) - setobjs2s(L, res++, vali++); + setobj2s(L, res++, vali++); while (i-- > 0) setnilvalue(res++); @@ -1667,7 +1667,7 @@ const Instruction* execute_LOP_CONCAT(lua_State* L, const Instruction* pc, Closu StkId ra = VM_REG(LUAU_INSN_A(insn)); - setobjs2s(L, ra, base + b); + setobj2s(L, ra, base + b); VM_PROTECT(luaC_checkGC(L)); return pc; } @@ -2003,9 +2003,9 @@ const Instruction* execute_LOP_FORGLOOP(lua_State* L, const Instruction* pc, Clo else { // note: it's safe to push arguments past top for complicated reasons (see top of the file) - setobjs2s(L, ra + 3 + 2, ra + 2); - setobjs2s(L, ra + 3 + 1, ra + 1); - setobjs2s(L, ra + 3, ra); + setobj2s(L, ra + 3 + 2, ra + 2); + setobj2s(L, ra + 3 + 1, ra + 1); + setobj2s(L, ra + 3, ra); L->top = ra + 3 + 3; // func + 2 args (state and index) LUAU_ASSERT(L->top <= L->stack_last); @@ -2017,7 +2017,7 @@ const Instruction* execute_LOP_FORGLOOP(lua_State* L, const Instruction* pc, Clo ra = VM_REG(LUAU_INSN_A(insn)); // copy first variable back into the iteration index - setobjs2s(L, ra + 2, ra + 3); + setobj2s(L, ra + 2, ra + 3); // note that we need to increment pc by 1 to exit the loop since we need to skip over aux pc += ttisnil(ra + 3) ? 1 : LUAU_INSN_D(insn); @@ -2094,7 +2094,7 @@ const Instruction* execute_LOP_GETVARARGS(lua_State* L, const Instruction* pc, C StkId ra = VM_REG(LUAU_INSN_A(insn)); // previous call may change the stack for (int j = 0; j < n; j++) - setobjs2s(L, ra + j, base - n + j); + setobj2s(L, ra + j, base - n + j); L->top = ra + n; return pc; @@ -2104,7 +2104,7 @@ const Instruction* execute_LOP_GETVARARGS(lua_State* L, const Instruction* pc, C StkId ra = VM_REG(LUAU_INSN_A(insn)); for (int j = 0; j < b && j < n; j++) - setobjs2s(L, ra + j, base - n + j); + setobj2s(L, ra + j, base - n + j); for (int j = n; j < b; j++) setnilvalue(ra + j); return pc; @@ -2183,7 +2183,7 @@ const Instruction* execute_LOP_PREPVARARGS(lua_State* L, const Instruction* pc, for (int i = 0; i < numparams; ++i) { - setobjs2s(L, base + i, fixed + i); + setobj2s(L, base + i, fixed + i); setnilvalue(fixed + i); } diff --git a/Common/include/Luau/Bytecode.h b/Common/include/Luau/Bytecode.h index f8652434b..32e5ba9bb 100644 --- a/Common/include/Luau/Bytecode.h +++ b/Common/include/Luau/Bytecode.h @@ -517,6 +517,10 @@ enum LuauBuiltinFunction // bit32.extract(_, k, k) LBF_BIT32_EXTRACTK, + + // get/setmetatable + LBF_GETMETATABLE, + LBF_SETMETATABLE, }; // Capture type, used in LOP_CAPTURE diff --git a/Compiler/src/Builtins.cpp b/Compiler/src/Builtins.cpp index 269337301..8d4640daf 100644 --- a/Compiler/src/Builtins.cpp +++ b/Compiler/src/Builtins.cpp @@ -4,6 +4,8 @@ #include "Luau/Bytecode.h" #include "Luau/Compiler.h" +LUAU_FASTFLAGVARIABLE(LuauCompileBuiltinMT, false) + namespace Luau { namespace Compile @@ -64,6 +66,14 @@ static int getBuiltinFunctionId(const Builtin& builtin, const CompileOptions& op if (builtin.isGlobal("select")) return LBF_SELECT_VARARG; + if (FFlag::LuauCompileBuiltinMT) + { + if (builtin.isGlobal("getmetatable")) + return LBF_GETMETATABLE; + if (builtin.isGlobal("setmetatable")) + return LBF_SETMETATABLE; + } + if (builtin.object == "math") { if (builtin.method == "abs") diff --git a/VM/src/lapi.cpp b/VM/src/lapi.cpp index e5ce4d5a3..28307eb90 100644 --- a/VM/src/lapi.cpp +++ b/VM/src/lapi.cpp @@ -235,7 +235,7 @@ void lua_remove(lua_State* L, int idx) StkId p = index2addr(L, idx); api_checkvalidindex(L, p); while (++p < L->top) - setobjs2s(L, p - 1, p); + setobj2s(L, p - 1, p); L->top--; return; } @@ -246,8 +246,8 @@ void lua_insert(lua_State* L, int idx) StkId p = index2addr(L, idx); api_checkvalidindex(L, p); for (StkId q = L->top; q > p; q--) - setobjs2s(L, q, q - 1); - setobjs2s(L, p, L->top); + setobj2s(L, q, q - 1); + setobj2s(L, p, L->top); return; } @@ -614,7 +614,7 @@ void lua_pushlstring(lua_State* L, const char* s, size_t len) { luaC_checkGC(L); luaC_threadbarrier(L); - setsvalue2s(L, L->top, luaS_newlstr(L, s, len)); + setsvalue(L, L->top, luaS_newlstr(L, s, len)); api_incr_top(L); return; } @@ -1269,7 +1269,7 @@ void lua_concat(lua_State* L, int n) else if (n == 0) { // push empty string luaC_threadbarrier(L); - setsvalue2s(L, L->top, luaS_newlstr(L, "", 0)); + setsvalue(L, L->top, luaS_newlstr(L, "", 0)); api_incr_top(L); } // else n == 1; nothing to do diff --git a/VM/src/laux.cpp b/VM/src/laux.cpp index c42e5ccc2..c2d07dddd 100644 --- a/VM/src/laux.cpp +++ b/VM/src/laux.cpp @@ -400,7 +400,7 @@ char* luaL_extendbuffer(luaL_Buffer* B, size_t additionalsize, int boxloc) lua_insert(L, boxloc); } - setsvalue2s(L, L->top + boxloc, newStorage); + setsvalue(L, L->top + boxloc, newStorage); B->p = newStorage->data + (B->p - base); B->end = newStorage->data + nextsize; B->storage = newStorage; @@ -451,11 +451,11 @@ void luaL_pushresult(luaL_Buffer* B) // if we finished just at the end of the string buffer, we can convert it to a mutable stirng without a copy if (B->p == B->end) { - setsvalue2s(L, L->top - 1, luaS_buffinish(L, storage)); + setsvalue(L, L->top - 1, luaS_buffinish(L, storage)); } else { - setsvalue2s(L, L->top - 1, luaS_newlstr(L, storage->data, B->p - storage->data)); + setsvalue(L, L->top - 1, luaS_newlstr(L, storage->data, B->p - storage->data)); } } else diff --git a/VM/src/lbuiltins.cpp b/VM/src/lbuiltins.cpp index c16e5aa70..87b6ae0ff 100644 --- a/VM/src/lbuiltins.cpp +++ b/VM/src/lbuiltins.cpp @@ -789,7 +789,7 @@ static int luauF_type(lua_State* L, StkId res, TValue* arg0, int nresults, StkId int tt = ttype(arg0); TString* ttname = L->global->ttname[tt]; - setsvalue2s(L, res, ttname); + setsvalue(L, res, ttname); return 1; } @@ -861,7 +861,7 @@ static int luauF_char(lua_State* L, StkId res, TValue* arg0, int nresults, StkId buffer[nparams] = 0; - setsvalue2s(L, res, luaS_newlstr(L, buffer, nparams)); + setsvalue(L, res, luaS_newlstr(L, buffer, nparams)); return 1; } @@ -887,7 +887,7 @@ static int luauF_typeof(lua_State* L, StkId res, TValue* arg0, int nresults, Stk { const TString* ttname = luaT_objtypenamestr(L, arg0); - setsvalue2s(L, res, ttname); + setsvalue(L, res, ttname); return 1; } @@ -904,7 +904,7 @@ static int luauF_sub(lua_State* L, StkId res, TValue* arg0, int nresults, StkId if (i >= 1 && j >= i && unsigned(j - 1) < unsigned(ts->len)) { - setsvalue2s(L, res, luaS_newlstr(L, getstr(ts) + (i - 1), j - i + 1)); + setsvalue(L, res, luaS_newlstr(L, getstr(ts) + (i - 1), j - i + 1)); return 1; } } @@ -993,12 +993,13 @@ static int luauF_rawset(lua_State* L, StkId res, TValue* arg0, int nresults, Stk else if (ttisvector(key) && luai_vecisnan(vvalue(key))) return -1; - if (hvalue(arg0)->readonly) + Table* t = hvalue(arg0); + if (t->readonly) return -1; setobj2s(L, res, arg0); - setobj2t(L, luaH_set(L, hvalue(arg0), args), args + 1); - luaC_barriert(L, hvalue(arg0), args + 1); + setobj2t(L, luaH_set(L, t, args), args + 1); + luaC_barriert(L, t, args + 1); return 1; } @@ -1009,12 +1010,13 @@ static int luauF_tinsert(lua_State* L, StkId res, TValue* arg0, int nresults, St { if (nparams == 2 && nresults <= 0 && ttistable(arg0)) { - if (hvalue(arg0)->readonly) + Table* t = hvalue(arg0); + if (t->readonly) return -1; - int pos = luaH_getn(hvalue(arg0)) + 1; - setobj2t(L, luaH_setnum(L, hvalue(arg0), pos), args); - luaC_barriert(L, hvalue(arg0), args); + int pos = luaH_getn(t) + 1; + setobj2t(L, luaH_setnum(L, t, pos), args); + luaC_barriert(L, t, args); return 0; } @@ -1193,6 +1195,60 @@ static int luauF_extractk(lua_State* L, StkId res, TValue* arg0, int nresults, S return -1; } +static int luauF_getmetatable(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (nparams >= 1 && nresults <= 1) + { + Table* mt = NULL; + if (ttistable(arg0)) + mt = hvalue(arg0)->metatable; + else if (ttisuserdata(arg0)) + mt = uvalue(arg0)->metatable; + else + mt = L->global->mt[ttype(arg0)]; + + const TValue* mtv = mt ? luaH_getstr(mt, L->global->tmname[TM_METATABLE]) : luaO_nilobject; + if (!ttisnil(mtv)) + { + setobj2s(L, res, mtv); + return 1; + } + + if (mt) + { + sethvalue(L, res, mt); + return 1; + } + else + { + setnilvalue(res); + return 1; + } + } + + return -1; +} + +static int luauF_setmetatable(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + // note: setmetatable(_, nil) is rare so we use fallback for it to optimize the fast path + if (nparams >= 2 && nresults <= 1 && ttistable(arg0) && ttistable(args)) + { + Table* t = hvalue(arg0); + if (t->readonly || t->metatable != NULL) + return -1; // note: overwriting non-null metatable is very rare but it requires __metatable check + + Table* mt = hvalue(args); + t->metatable = mt; + luaC_objbarrier(L, t, mt); + + sethvalue(L, res, t); + return 1; + } + + return -1; +} + luau_FastFunction luauF_table[256] = { NULL, luauF_assert, @@ -1268,4 +1324,7 @@ luau_FastFunction luauF_table[256] = { luauF_rawlen, luauF_extractk, + + luauF_getmetatable, + luauF_setmetatable, }; diff --git a/VM/src/ldebug.cpp b/VM/src/ldebug.cpp index 7c181d4ec..e695cd2b3 100644 --- a/VM/src/ldebug.cpp +++ b/VM/src/ldebug.cpp @@ -83,7 +83,7 @@ const char* lua_setlocal(lua_State* L, int level, int n) Proto* fp = getluaproto(ci); const LocVar* var = fp ? luaF_getlocal(fp, n, currentpc(L, ci)) : NULL; if (var) - setobjs2s(L, ci->base + var->reg, L->top - 1); + setobj2s(L, ci->base + var->reg, L->top - 1); L->top--; // pop value const char* name = var ? getstr(var->varname) : NULL; return name; diff --git a/VM/src/ldo.cpp b/VM/src/ldo.cpp index ecd2fcbbf..ff8105b8c 100644 --- a/VM/src/ldo.cpp +++ b/VM/src/ldo.cpp @@ -263,18 +263,18 @@ static void seterrorobj(lua_State* L, int errcode, StkId oldtop) { case LUA_ERRMEM: { - setsvalue2s(L, oldtop, luaS_newliteral(L, LUA_MEMERRMSG)); // can not fail because string is pinned in luaopen + setsvalue(L, oldtop, luaS_newliteral(L, LUA_MEMERRMSG)); // can not fail because string is pinned in luaopen break; } case LUA_ERRERR: { - setsvalue2s(L, oldtop, luaS_newliteral(L, LUA_ERRERRMSG)); // can not fail because string is pinned in luaopen + setsvalue(L, oldtop, luaS_newliteral(L, LUA_ERRERRMSG)); // can not fail because string is pinned in luaopen break; } case LUA_ERRSYNTAX: case LUA_ERRRUN: { - setobjs2s(L, oldtop, L->top - 1); // error message on current top + setobj2s(L, oldtop, L->top - 1); // error message on current top break; } } @@ -419,7 +419,7 @@ static void resume_handle(lua_State* L, void* ud) static int resume_error(lua_State* L, const char* msg) { L->top = L->ci->base; - setsvalue2s(L, L->top, luaS_new(L, msg)); + setsvalue(L, L->top, luaS_new(L, msg)); incr_top(L); return LUA_ERRRUN; } @@ -525,8 +525,8 @@ static void callerrfunc(lua_State* L, void* ud) { StkId errfunc = cast_to(StkId, ud); - setobjs2s(L, L->top, L->top - 1); - setobjs2s(L, L->top - 1, errfunc); + setobj2s(L, L->top, L->top - 1); + setobj2s(L, L->top - 1, errfunc); incr_top(L); luaD_call(L, L->top - 2, 1); } diff --git a/VM/src/lgc.cpp b/VM/src/lgc.cpp index c2a672e65..4b9fbb69b 100644 --- a/VM/src/lgc.cpp +++ b/VM/src/lgc.cpp @@ -118,8 +118,6 @@ LUAU_FASTFLAGVARIABLE(LuauBetterThreadMark, false) * slot - upvalues like this are identified since they don't have `markedopen` bit set during thread traversal and closed in `clearupvals`. */ -LUAU_FASTFLAGVARIABLE(LuauFasterSweep, false) - #define GC_SWEEPPAGESTEPCOST 16 #define GC_INTERRUPT(state) \ @@ -836,28 +834,6 @@ static size_t atomic(lua_State* L) return work; } -static bool sweepgco(lua_State* L, lua_Page* page, GCObject* gco) -{ - LUAU_ASSERT(!FFlag::LuauFasterSweep); - global_State* g = L->global; - - int deadmask = otherwhite(g); - LUAU_ASSERT(testbit(deadmask, FIXEDBIT)); // make sure we never sweep fixed objects - - int alive = (gco->gch.marked ^ WHITEBITS) & deadmask; - - if (alive) - { - LUAU_ASSERT(!isdead(g, gco)); - makewhite(g, gco); // make it white (for next cycle) - return false; - } - - LUAU_ASSERT(isdead(g, gco)); - freeobj(L, gco, page); - return true; -} - // a version of generic luaM_visitpage specialized for the main sweep stage static int sweepgcopage(lua_State* L, lua_Page* page) { @@ -869,58 +845,36 @@ static int sweepgcopage(lua_State* L, lua_Page* page) LUAU_ASSERT(busyBlocks > 0); - if (FFlag::LuauFasterSweep) - { - global_State* g = L->global; - - int deadmask = otherwhite(g); - LUAU_ASSERT(testbit(deadmask, FIXEDBIT)); // make sure we never sweep fixed objects + global_State* g = L->global; - int newwhite = luaC_white(g); + int deadmask = otherwhite(g); + LUAU_ASSERT(testbit(deadmask, FIXEDBIT)); // make sure we never sweep fixed objects - for (char* pos = start; pos != end; pos += blockSize) - { - GCObject* gco = (GCObject*)pos; + int newwhite = luaC_white(g); - // skip memory blocks that are already freed - if (gco->gch.tt == LUA_TNIL) - continue; + for (char* pos = start; pos != end; pos += blockSize) + { + GCObject* gco = (GCObject*)pos; - // is the object alive? - if ((gco->gch.marked ^ WHITEBITS) & deadmask) - { - LUAU_ASSERT(!isdead(g, gco)); - // make it white (for next cycle) - gco->gch.marked = cast_byte((gco->gch.marked & maskmarks) | newwhite); - } - else - { - LUAU_ASSERT(isdead(g, gco)); - freeobj(L, gco, page); + // skip memory blocks that are already freed + if (gco->gch.tt == LUA_TNIL) + continue; - // if the last block was removed, page would be removed as well - if (--busyBlocks == 0) - return int(pos - start) / blockSize + 1; - } + // is the object alive? + if ((gco->gch.marked ^ WHITEBITS) & deadmask) + { + LUAU_ASSERT(!isdead(g, gco)); + // make it white (for next cycle) + gco->gch.marked = cast_byte((gco->gch.marked & maskmarks) | newwhite); } - } - else - { - for (char* pos = start; pos != end; pos += blockSize) + else { - GCObject* gco = (GCObject*)pos; - - // skip memory blocks that are already freed - if (gco->gch.tt == LUA_TNIL) - continue; + LUAU_ASSERT(isdead(g, gco)); + freeobj(L, gco, page); - // when true is returned it means that the element was deleted - if (sweepgco(L, page, gco)) - { - // if the last block was removed, page would be removed as well - if (--busyBlocks == 0) - return int(pos - start) / blockSize + 1; - } + // if the last block was removed, page would be removed as well + if (--busyBlocks == 0) + return int(pos - start) / blockSize + 1; } } @@ -1009,15 +963,8 @@ static size_t gcstep(lua_State* L, size_t limit) if (g->sweepgcopage == NULL) { // don't forget to visit main thread, it's the only object not allocated in GCO pages - if (FFlag::LuauFasterSweep) - { - LUAU_ASSERT(!isdead(g, obj2gco(g->mainthread))); - makewhite(g, obj2gco(g->mainthread)); // make it white (for next cycle) - } - else - { - sweepgco(L, NULL, obj2gco(g->mainthread)); - } + LUAU_ASSERT(!isdead(g, obj2gco(g->mainthread))); + makewhite(g, obj2gco(g->mainthread)); // make it white (for next cycle) shrinkbuffers(L); diff --git a/VM/src/lobject.cpp b/VM/src/lobject.cpp index b6a40bb6b..f5f1cd0e8 100644 --- a/VM/src/lobject.cpp +++ b/VM/src/lobject.cpp @@ -102,7 +102,7 @@ const char* luaO_pushvfstring(lua_State* L, const char* fmt, va_list argp) char result[LUA_BUFFERSIZE]; vsnprintf(result, sizeof(result), fmt, argp); - setsvalue2s(L, L->top, luaS_new(L, result)); + setsvalue(L, L->top, luaS_new(L, result)); incr_top(L); return svalue(L->top - 1); } diff --git a/VM/src/lobject.h b/VM/src/lobject.h index 97cbfbb11..5f5e7b1c8 100644 --- a/VM/src/lobject.h +++ b/VM/src/lobject.h @@ -200,20 +200,14 @@ typedef struct lua_TValue ** different types of sets, according to destination */ -// from stack to (same) stack -#define setobjs2s setobj -// to stack (not from same stack) +// to stack #define setobj2s setobj -#define setsvalue2s setsvalue -#define sethvalue2s sethvalue -#define setptvalue2s setptvalue -// from table to same table +// from table to same table (no barrier) #define setobjt2t setobj -// to table +// to table (needs barrier) #define setobj2t setobj -// to new object +// to new object (no barrier) #define setobj2n setobj -#define setsvalue2n setsvalue #define setttype(obj, tt) (ttype(obj) = (tt)) diff --git a/VM/src/ltm.cpp b/VM/src/ltm.cpp index cb7ba097a..d753e8a43 100644 --- a/VM/src/ltm.cpp +++ b/VM/src/ltm.cpp @@ -57,6 +57,7 @@ const char* const luaT_eventname[] = { "__le", "__concat", "__type", + "__metatable", }; // clang-format on diff --git a/VM/src/ltm.h b/VM/src/ltm.h index f20ce1b22..4b1c28181 100644 --- a/VM/src/ltm.h +++ b/VM/src/ltm.h @@ -36,6 +36,7 @@ typedef enum TM_LE, TM_CONCAT, TM_TYPE, + TM_METATABLE, TM_N // number of elements in the enum } TMS; diff --git a/VM/src/lvmexecute.cpp b/VM/src/lvmexecute.cpp index 6ceed5120..490358c4b 100644 --- a/VM/src/lvmexecute.cpp +++ b/VM/src/lvmexecute.cpp @@ -985,7 +985,7 @@ static void luau_execute(lua_State* L) int i; for (i = nresults; i != 0 && vali < valend; i--) - setobjs2s(L, res++, vali++); + setobj2s(L, res++, vali++); while (i-- > 0) setnilvalue(res++); @@ -1022,7 +1022,7 @@ static void luau_execute(lua_State* L) // note: in MULTRET context nresults starts as -1 so i != 0 condition never activates intentionally int i; for (i = nresults; i != 0 && vali < valend; i--) - setobjs2s(L, res++, vali++); + setobj2s(L, res++, vali++); while (i-- > 0) setnilvalue(res++); @@ -1945,7 +1945,7 @@ static void luau_execute(lua_State* L) StkId ra = VM_REG(LUAU_INSN_A(insn)); - setobjs2s(L, ra, base + b); + setobj2s(L, ra, base + b); VM_PROTECT(luaC_checkGC(L)); VM_NEXT(); } @@ -2281,9 +2281,9 @@ static void luau_execute(lua_State* L) else { // note: it's safe to push arguments past top for complicated reasons (see top of the file) - setobjs2s(L, ra + 3 + 2, ra + 2); - setobjs2s(L, ra + 3 + 1, ra + 1); - setobjs2s(L, ra + 3, ra); + setobj2s(L, ra + 3 + 2, ra + 2); + setobj2s(L, ra + 3 + 1, ra + 1); + setobj2s(L, ra + 3, ra); L->top = ra + 3 + 3; // func + 2 args (state and index) LUAU_ASSERT(L->top <= L->stack_last); @@ -2295,7 +2295,7 @@ static void luau_execute(lua_State* L) ra = VM_REG(LUAU_INSN_A(insn)); // copy first variable back into the iteration index - setobjs2s(L, ra + 2, ra + 3); + setobj2s(L, ra + 2, ra + 3); // note that we need to increment pc by 1 to exit the loop since we need to skip over aux pc += ttisnil(ra + 3) ? 1 : LUAU_INSN_D(insn); @@ -2372,7 +2372,7 @@ static void luau_execute(lua_State* L) StkId ra = VM_REG(LUAU_INSN_A(insn)); // previous call may change the stack for (int j = 0; j < n; j++) - setobjs2s(L, ra + j, base - n + j); + setobj2s(L, ra + j, base - n + j); L->top = ra + n; VM_NEXT(); @@ -2382,7 +2382,7 @@ static void luau_execute(lua_State* L) StkId ra = VM_REG(LUAU_INSN_A(insn)); for (int j = 0; j < b && j < n; j++) - setobjs2s(L, ra + j, base - n + j); + setobj2s(L, ra + j, base - n + j); for (int j = n; j < b; j++) setnilvalue(ra + j); VM_NEXT(); @@ -2461,7 +2461,7 @@ static void luau_execute(lua_State* L) for (int i = 0; i < numparams; ++i) { - setobjs2s(L, base + i, fixed + i); + setobj2s(L, base + i, fixed + i); setnilvalue(fixed + i); } @@ -2878,7 +2878,7 @@ int luau_precall(lua_State* L, StkId func, int nresults) int i; for (i = nresults; i != 0 && vali < valend; i--) - setobjs2s(L, res++, vali++); + setobj2s(L, res++, vali++); while (i-- > 0) setnilvalue(res++); @@ -2906,7 +2906,7 @@ void luau_poscall(lua_State* L, StkId first) int i; for (i = ci->nresults; i != 0 && vali < valend; i--) - setobjs2s(L, res++, vali++); + setobj2s(L, res++, vali++); while (i-- > 0) setnilvalue(res++); diff --git a/VM/src/lvmload.cpp b/VM/src/lvmload.cpp index 0ae85ab61..bd40bad2f 100644 --- a/VM/src/lvmload.cpp +++ b/VM/src/lvmload.cpp @@ -240,7 +240,7 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size case LBC_CONSTANT_STRING: { TString* v = readString(strings, data, size, offset); - setsvalue2n(L, &p->k[j], v); + setsvalue(L, &p->k[j], v); break; } diff --git a/VM/src/lvmutils.cpp b/VM/src/lvmutils.cpp index 35124e632..5c5551580 100644 --- a/VM/src/lvmutils.cpp +++ b/VM/src/lvmutils.cpp @@ -38,7 +38,7 @@ int luaV_tostring(lua_State* L, StkId obj) double n = nvalue(obj); char* e = luai_num2str(s, n); LUAU_ASSERT(e < s + sizeof(s)); - setsvalue2s(L, obj, luaS_newlstr(L, s, e - s)); + setsvalue(L, obj, luaS_newlstr(L, s, e - s)); return 1; } } @@ -70,7 +70,7 @@ static StkId callTMres(lua_State* L, StkId res, const TValue* f, const TValue* p luaD_call(L, L->top - 3, 1); res = restorestack(L, result); L->top--; - setobjs2s(L, res, L->top); + setobj2s(L, res, L->top); return res; } @@ -350,11 +350,11 @@ void luaV_concat(lua_State* L, int total, int last) if (tl < LUA_BUFFERSIZE) { - setsvalue2s(L, top - n, luaS_newlstr(L, buffer, tl)); + setsvalue(L, top - n, luaS_newlstr(L, buffer, tl)); } else { - setsvalue2s(L, top - n, luaS_buffinish(L, ts)); + setsvalue(L, top - n, luaS_buffinish(L, ts)); } } total -= n - 1; // got `n' strings to create 1 new @@ -582,7 +582,7 @@ LUAU_NOINLINE void luaV_tryfuncTM(lua_State* L, StkId func) if (!ttisfunction(tm)) luaG_typeerror(L, func, "call"); for (StkId p = L->top; p > func; p--) // open space for metamethod - setobjs2s(L, p, p - 1); + setobj2s(L, p, p - 1); L->top++; // stack space pre-allocated by the caller setobj2s(L, func, tm); // tag method is the new function to be called } diff --git a/fuzz/linter.cpp b/fuzz/linter.cpp index 0bdd49f58..66ca5bb14 100644 --- a/fuzz/linter.cpp +++ b/fuzz/linter.cpp @@ -21,7 +21,7 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t* Data, size_t Size) static Luau::NullModuleResolver moduleResolver; static Luau::InternalErrorReporter iceHandler; static Luau::TypeChecker sharedEnv(&moduleResolver, &iceHandler); - static int once = (Luau::registerBuiltinTypes(sharedEnv), 1); + static int once = (Luau::registerBuiltinGlobals(sharedEnv), 1); (void)once; static int once2 = (Luau::freeze(sharedEnv.globalTypes), 1); (void)once2; diff --git a/fuzz/proto.cpp b/fuzz/proto.cpp index f64b61571..9e0a68296 100644 --- a/fuzz/proto.cpp +++ b/fuzz/proto.cpp @@ -96,7 +96,7 @@ int registerTypes(Luau::TypeChecker& env) using namespace Luau; using std::nullopt; - Luau::registerBuiltinTypes(env); + Luau::registerBuiltinGlobals(env); TypeArena& arena = env.globalTypes; diff --git a/fuzz/typeck.cpp b/fuzz/typeck.cpp index 3905cc191..a6c9ae284 100644 --- a/fuzz/typeck.cpp +++ b/fuzz/typeck.cpp @@ -26,7 +26,7 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t* Data, size_t Size) static Luau::NullModuleResolver moduleResolver; static Luau::InternalErrorReporter iceHandler; static Luau::TypeChecker sharedEnv(&moduleResolver, &iceHandler); - static int once = (Luau::registerBuiltinTypes(sharedEnv), 1); + static int once = (Luau::registerBuiltinGlobals(sharedEnv), 1); (void)once; static int once2 = (Luau::freeze(sharedEnv.globalTypes), 1); (void)once2; diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index 9409c8222..d0f777852 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -798,6 +798,8 @@ RETURN R0 1 TEST_CASE("TableSizePredictionSetMetatable") { + ScopedFastFlag sff("LuauCompileBuiltinMT", true); + CHECK_EQ("\n" + compileFunction0(R"( local t = setmetatable({}, nil) t.field1 = 1 @@ -805,14 +807,15 @@ t.field2 = 2 return t )"), R"( -GETIMPORT R0 1 NEWTABLE R1 2 0 -LOADNIL R2 +FASTCALL2K 61 R1 K0 L0 +LOADK R2 K0 +GETIMPORT R0 2 CALL R0 2 1 -LOADN R1 1 -SETTABLEKS R1 R0 K2 -LOADN R1 2 +L0: LOADN R1 1 SETTABLEKS R1 R0 K3 +LOADN R1 2 +SETTABLEKS R1 R0 K4 RETURN R0 1 )"); } diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index 77b30487b..a6e4e6e1f 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -499,7 +499,7 @@ TEST_CASE("Types") Luau::SingletonTypes singletonTypes; Luau::TypeChecker env(&moduleResolver, Luau::NotNull{&singletonTypes}, &iceHandler); - Luau::registerBuiltinTypes(env); + Luau::registerBuiltinGlobals(env); Luau::freeze(env.globalTypes); lua_newtable(L); diff --git a/tests/Fixture.cpp b/tests/Fixture.cpp index 3f77978ce..dcc0222ab 100644 --- a/tests/Fixture.cpp +++ b/tests/Fixture.cpp @@ -8,7 +8,6 @@ #include "Luau/TypeVar.h" #include "Luau/TypeAttach.h" #include "Luau/Transpiler.h" - #include "Luau/BuiltinDefinitions.h" #include "doctest.h" @@ -20,6 +19,8 @@ static const char* mainModuleName = "MainModule"; LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); +LUAU_FASTFLAG(LuauUnknownAndNeverType) +LUAU_FASTFLAG(LuauReportShadowedTypeAlias) namespace Luau { @@ -97,6 +98,8 @@ Fixture::Fixture(bool freeze, bool prepareAutocomplete) configResolver.defaultConfig.enabledLint.warningMask = ~0ull; configResolver.defaultConfig.parseOptions.captureComments = true; + registerBuiltinTypes(frontend); + Luau::freeze(frontend.typeChecker.globalTypes); Luau::freeze(frontend.typeCheckerForAutocomplete.globalTypes); @@ -435,9 +438,9 @@ BuiltinsFixture::BuiltinsFixture(bool freeze, bool prepareAutocomplete) Luau::unfreeze(frontend.typeChecker.globalTypes); Luau::unfreeze(frontend.typeCheckerForAutocomplete.globalTypes); - registerBuiltinTypes(frontend); + registerBuiltinGlobals(frontend); if (prepareAutocomplete) - registerBuiltinTypes(frontend.typeCheckerForAutocomplete); + registerBuiltinGlobals(frontend.typeCheckerForAutocomplete); registerTestTypes(); Luau::freeze(frontend.typeChecker.globalTypes); diff --git a/tests/Frontend.test.cpp b/tests/Frontend.test.cpp index a8a9e044b..5f74931c0 100644 --- a/tests/Frontend.test.cpp +++ b/tests/Frontend.test.cpp @@ -1070,12 +1070,12 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "reexport_type_alias") fileResolver.source["Module/A"] = R"( type KeyOfTestEvents = "test-file-start" | "test-file-success" | "test-file-failure" | "test-case-result" - type unknown = any + type MyAny = any export type TestFileEvent = ( eventName: T, args: any --[[ ROBLOX TODO: Unhandled node for type: TSIndexedAccessType ]] --[[ TestEvents[T] ]] - ) -> unknown + ) -> MyAny return {} )"; diff --git a/tests/ToString.test.cpp b/tests/ToString.test.cpp index ab5d85989..1339ec28a 100644 --- a/tests/ToString.test.cpp +++ b/tests/ToString.test.cpp @@ -790,4 +790,20 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_hide_self_param") CHECK_EQ("foo:method(arg: string): ()", toStringNamedFunction("foo:method", *ftv, opts)); } +TEST_CASE_FIXTURE(Fixture, "tostring_unsee_ttv_if_array") +{ + ScopedFastFlag sff("LuauUnseeArrayTtv", true); + + CheckResult result = check(R"( + local x: {string} + -- This code is constructed very specifically to use the same (by pointer + -- identity) type in the function twice. + local y: (typeof(x), typeof(x)) -> () + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK(toString(requireType("y")) == "({string}, {string}) -> ()"); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.aliases.test.cpp b/tests/TypeInfer.aliases.test.cpp index 834391a75..5ecc2a8ca 100644 --- a/tests/TypeInfer.aliases.test.cpp +++ b/tests/TypeInfer.aliases.test.cpp @@ -443,7 +443,8 @@ TEST_CASE_FIXTURE(Fixture, "reported_location_is_correct_when_type_alias_are_dup auto dtd = get(result.errors[0]); REQUIRE(dtd); CHECK_EQ(dtd->name, "B"); - CHECK_EQ(dtd->previousLocation.begin.line + 1, 3); + REQUIRE(dtd->previousLocation); + CHECK_EQ(dtd->previousLocation->begin.line + 1, 3); } TEST_CASE_FIXTURE(Fixture, "stringify_optional_parameterized_alias") @@ -868,4 +869,40 @@ TEST_CASE_FIXTURE(Fixture, "recursive_types_restriction_not_ok") LUAU_REQUIRE_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "report_shadowed_aliases") +{ + ScopedFastFlag sff{"LuauReportShadowedTypeAlias", true}; + + // We allow a previous type alias to depend on a future type alias. That exact feature enables a confusing example, like the following snippet, + // which has the type alias FakeString point to the type alias `string` that which points to `number`. + CheckResult result = check(R"( + type MyString = string + type string = number + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK(toString(result.errors[0]) == "Redefinition of type 'string'"); + + std::optional t1 = lookupType("MyString"); + REQUIRE(t1); + CHECK(isPrim(*t1, PrimitiveTypeVar::String)); + + std::optional t2 = lookupType("string"); + REQUIRE(t2); + CHECK(isPrim(*t2, PrimitiveTypeVar::String)); +} + +TEST_CASE_FIXTURE(Fixture, "it_is_ok_to_shadow_user_defined_alias") +{ + CheckResult result = check(R"( + type T = number + + do + type T = string + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.annotations.test.cpp b/tests/TypeInfer.annotations.test.cpp index 5d18b335d..5f2c22cfd 100644 --- a/tests/TypeInfer.annotations.test.cpp +++ b/tests/TypeInfer.annotations.test.cpp @@ -745,7 +745,12 @@ TEST_CASE_FIXTURE(Fixture, "luau_ice_is_not_special_without_the_flag") TEST_CASE_FIXTURE(BuiltinsFixture, "luau_print_is_magic_if_the_flag_is_set") { - // Luau::resetPrintLine(); + static std::vector output; + output.clear(); + Luau::setPrintLine([](const std::string& s) { + output.push_back(s); + }); + ScopedFastFlag sffs{"DebugLuauMagicTypes", true}; CheckResult result = check(R"( @@ -753,6 +758,8 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "luau_print_is_magic_if_the_flag_is_set") )"); LUAU_REQUIRE_NO_ERRORS(result); + + REQUIRE(1 == output.size()); } TEST_CASE_FIXTURE(Fixture, "luau_print_is_not_special_without_the_flag") diff --git a/tests/TypeInfer.builtins.test.cpp b/tests/TypeInfer.builtins.test.cpp index 6da3f569b..037f79d8a 100644 --- a/tests/TypeInfer.builtins.test.cpp +++ b/tests/TypeInfer.builtins.test.cpp @@ -11,6 +11,7 @@ using namespace Luau; LUAU_FASTFLAG(LuauLowerBoundsCalculation); LUAU_FASTFLAG(LuauSpecialTypesAsterisked); LUAU_FASTFLAG(LuauStringFormatArgumentErrorFix) +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) TEST_SUITE_BEGIN("BuiltinTests"); @@ -596,6 +597,15 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "xpcall") CHECK_EQ("boolean", toString(requireType("c"))); } +TEST_CASE_FIXTURE(BuiltinsFixture, "trivial_select") +{ + CheckResult result = check(R"( + local a:number = select(1, 42) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_CASE_FIXTURE(BuiltinsFixture, "see_thru_select") { CheckResult result = check(R"( @@ -679,10 +689,20 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "select_with_variadic_typepack_tail") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("any", toString(requireType("foo"))); - CHECK_EQ("any", toString(requireType("bar"))); - CHECK_EQ("any", toString(requireType("baz"))); - CHECK_EQ("any", toString(requireType("quux"))); + if (FFlag::DebugLuauDeferredConstraintResolution && FFlag::LuauSpecialTypesAsterisked) + { + CHECK_EQ("string", toString(requireType("foo"))); + CHECK_EQ("*error-type*", toString(requireType("bar"))); + CHECK_EQ("*error-type*", toString(requireType("baz"))); + CHECK_EQ("*error-type*", toString(requireType("quux"))); + } + else + { + CHECK_EQ("any", toString(requireType("foo"))); + CHECK_EQ("any", toString(requireType("bar"))); + CHECK_EQ("any", toString(requireType("baz"))); + CHECK_EQ("any", toString(requireType("quux"))); + } } TEST_CASE_FIXTURE(BuiltinsFixture, "select_with_variadic_typepack_tail_and_string_head") @@ -698,10 +718,20 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "select_with_variadic_typepack_tail_and_strin LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("any", toString(requireType("foo"))); - CHECK_EQ("any", toString(requireType("bar"))); - CHECK_EQ("any", toString(requireType("baz"))); - CHECK_EQ("any", toString(requireType("quux"))); + if (FFlag::DebugLuauDeferredConstraintResolution && FFlag::LuauSpecialTypesAsterisked) + { + CHECK_EQ("string", toString(requireType("foo"))); + CHECK_EQ("string", toString(requireType("bar"))); + CHECK_EQ("*error-type*", toString(requireType("baz"))); + CHECK_EQ("*error-type*", toString(requireType("quux"))); + } + else + { + CHECK_EQ("any", toString(requireType("foo"))); + CHECK_EQ("any", toString(requireType("bar"))); + CHECK_EQ("any", toString(requireType("baz"))); + CHECK_EQ("any", toString(requireType("quux"))); + } } TEST_CASE_FIXTURE(Fixture, "string_format_as_method") @@ -1099,7 +1129,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types_default_capture") CountMismatch* acm = get(result.errors[0]); REQUIRE(acm); - CHECK_EQ(acm->context, CountMismatch::Result); + CHECK_EQ(acm->context, CountMismatch::FunctionResult); CHECK_EQ(acm->expected, 1); CHECK_EQ(acm->actual, 4); @@ -1116,7 +1146,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types_balanced_escaped_parens CountMismatch* acm = get(result.errors[0]); REQUIRE(acm); - CHECK_EQ(acm->context, CountMismatch::Result); + CHECK_EQ(acm->context, CountMismatch::FunctionResult); CHECK_EQ(acm->expected, 3); CHECK_EQ(acm->actual, 4); @@ -1135,7 +1165,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types_parens_in_sets_are_igno CountMismatch* acm = get(result.errors[0]); REQUIRE(acm); - CHECK_EQ(acm->context, CountMismatch::Result); + CHECK_EQ(acm->context, CountMismatch::FunctionResult); CHECK_EQ(acm->expected, 2); CHECK_EQ(acm->actual, 3); @@ -1288,7 +1318,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "find_capture_types3") CountMismatch* acm = get(result.errors[0]); REQUIRE(acm); - CHECK_EQ(acm->context, CountMismatch::Result); + CHECK_EQ(acm->context, CountMismatch::FunctionResult); CHECK_EQ(acm->expected, 2); CHECK_EQ(acm->actual, 4); diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index bde28dccb..a4420b9a8 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -837,6 +837,8 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "calling_function_with_anytypepack_doesnt_lea TEST_CASE_FIXTURE(Fixture, "too_many_return_values") { + ScopedFastFlag sff{"LuauBetterMessagingOnCountMismatch", true}; + CheckResult result = check(R"( --!strict @@ -851,7 +853,49 @@ TEST_CASE_FIXTURE(Fixture, "too_many_return_values") CountMismatch* acm = get(result.errors[0]); REQUIRE(acm); - CHECK_EQ(acm->context, CountMismatch::Result); + CHECK_EQ(acm->context, CountMismatch::FunctionResult); + CHECK_EQ(acm->expected, 1); + CHECK_EQ(acm->actual, 2); +} + +TEST_CASE_FIXTURE(Fixture, "too_many_return_values_in_parentheses") +{ + ScopedFastFlag sff{"LuauBetterMessagingOnCountMismatch", true}; + + CheckResult result = check(R"( + --!strict + + function f() + return 55 + end + + local a, b = (f()) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CountMismatch* acm = get(result.errors[0]); + REQUIRE(acm); + CHECK_EQ(acm->context, CountMismatch::FunctionResult); + CHECK_EQ(acm->expected, 1); + CHECK_EQ(acm->actual, 2); +} + +TEST_CASE_FIXTURE(Fixture, "too_many_return_values_no_function") +{ + ScopedFastFlag sff{"LuauBetterMessagingOnCountMismatch", true}; + + CheckResult result = check(R"( + --!strict + + local a, b = 55 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CountMismatch* acm = get(result.errors[0]); + REQUIRE(acm); + CHECK_EQ(acm->context, CountMismatch::ExprListResult); CHECK_EQ(acm->expected, 1); CHECK_EQ(acm->actual, 2); } @@ -1271,7 +1315,7 @@ local b: B = a LUAU_REQUIRE_ERROR_COUNT(1, result); CHECK_EQ(toString(result.errors[0]), R"(Type '(number, number) -> number' could not be converted into '(number, number) -> (number, boolean)' caused by: - Function only returns 1 value. 2 are required here)"); + Function only returns 1 value, but 2 are required here)"); } TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_ret") diff --git a/tests/TypeInfer.loops.test.cpp b/tests/TypeInfer.loops.test.cpp index 85249ecdc..588a9a763 100644 --- a/tests/TypeInfer.loops.test.cpp +++ b/tests/TypeInfer.loops.test.cpp @@ -526,6 +526,16 @@ TEST_CASE_FIXTURE(Fixture, "fuzz_fail_missing_instantitation_follow") )"); } +TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_with_generic_next") +{ + CheckResult result = check(R"( + for k: number, v: number in next, {1, 2, 3} do + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_CASE_FIXTURE(Fixture, "loop_iter_basic") { CheckResult result = check(R"( @@ -584,11 +594,48 @@ TEST_CASE_FIXTURE(Fixture, "loop_iter_no_indexer_nonstrict") LUAU_REQUIRE_ERROR_COUNT(0, result); } -TEST_CASE_FIXTURE(BuiltinsFixture, "loop_iter_iter_metamethod") +TEST_CASE_FIXTURE(BuiltinsFixture, "loop_iter_metamethod_nil") { + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + CheckResult result = check(R"( - local t = {} - setmetatable(t, { __iter = function(o) return next, o.children end }) + local t = setmetatable({}, { __iter = function(o) return next, nil end, }) + for k: number, v: string in t do + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK(toString(result.errors[0]) == "Type 'nil' could not be converted into '{- [a]: b -}'"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "loop_iter_metamethod_not_enough_returns") +{ + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + + CheckResult result = check(R"( + local t = setmetatable({}, { __iter = function(o) end }) + for k: number, v: string in t do + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK(result.errors[0] == TypeError{ + Location{{2, 36}, {2, 37}}, + GenericError{"__iter must return at least one value"}, + }); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "loop_iter_metamethod_ok") +{ + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + + CheckResult result = check(R"( + local t = setmetatable({ + children = {"foo"} + }, { __iter = function(o) return next, o.children end }) for k: number, v: string in t do end )"); @@ -596,4 +643,26 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "loop_iter_iter_metamethod") LUAU_REQUIRE_ERROR_COUNT(0, result); } +TEST_CASE_FIXTURE(BuiltinsFixture, "loop_iter_metamethod_ok_with_inference") +{ + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + + CheckResult result = check(R"( + local t = setmetatable({ + children = {"foo"} + }, { __iter = function(o) return next, o.children end }) + + local a, b + for k, v in t do + a = k + b = v + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK(toString(requireType("a")) == "number"); + CHECK(toString(requireType("b")) == "string"); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index 2c4c35040..45740a0b1 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -439,13 +439,15 @@ TEST_CASE_FIXTURE(Fixture, "normalization_fails_on_certain_kinds_of_cyclic_table // Belongs in TypeInfer.builtins.test.cpp. TEST_CASE_FIXTURE(BuiltinsFixture, "pcall_returns_at_least_two_value_but_function_returns_nothing") { + ScopedFastFlag sff{"LuauBetterMessagingOnCountMismatch", true}; + CheckResult result = check(R"( local function f(): () end local ok, res = pcall(f) )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Function only returns 1 value. 2 are required here", toString(result.errors[0])); + CHECK_EQ("Function only returns 1 value, but 2 are required here", toString(result.errors[0])); // LUAU_REQUIRE_NO_ERRORS(result); // CHECK_EQ("boolean", toString(requireType("ok"))); // CHECK_EQ("any", toString(requireType("res"))); diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index d2bff9c3c..b6dedcbd3 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -256,28 +256,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "typeguard_in_assert_position") REQUIRE_EQ("number", toString(requireType("b"))); } -TEST_CASE_FIXTURE(Fixture, "typeguard_only_look_up_types_from_global_scope") -{ - CheckResult result = check(R"( - type ActuallyString = string - - do -- Necessary. Otherwise toposort has ActuallyString come after string type alias. - type string = number - local foo: string = 1 - - if type(foo) == "string" then - local bar: ActuallyString = foo - local baz: boolean = foo - end - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ("never", toString(requireTypeAtPosition({8, 44}))); - CHECK_EQ("never", toString(requireTypeAtPosition({9, 38}))); -} - TEST_CASE_FIXTURE(Fixture, "call_a_more_specific_function_using_typeguard") { CheckResult result = check(R"( diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index a96c36760..8ed61b496 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -1159,4 +1159,38 @@ TEST_CASE_FIXTURE(Fixture, "dcr_can_partially_dispatch_a_constraint") CHECK("(a, number) -> ()" == toString(requireType("prime_iter"))); } +TEST_CASE_FIXTURE(BuiltinsFixture, "it_is_ok_to_have_inconsistent_number_of_return_values_in_nonstrict") +{ + CheckResult result = check(R"( + --!nonstrict + function validate(stats, hits, misses) + local checked = {} + + for _,l in ipairs(hits) do + if not (stats[l] and stats[l] > 0) then + return false, string.format("expected line %d to be hit", l) + end + checked[l] = true + end + + for _,l in ipairs(misses) do + if not (stats[l] and stats[l] == 0) then + return false, string.format("expected line %d to be missed", l) + end + checked[l] = true + end + + for k,v in pairs(stats) do + if type(k) == "number" and not checked[k] then + return false, string.format("expected line %d to be absent", k) + end + end + + return true + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.tryUnify.test.cpp b/tests/TypeInfer.tryUnify.test.cpp index 7fa0fac0f..3911c520d 100644 --- a/tests/TypeInfer.tryUnify.test.cpp +++ b/tests/TypeInfer.tryUnify.test.cpp @@ -271,4 +271,40 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "txnlog_preserves_pack_owner") CHECK_EQ(a->owningArena, &arena); } +TEST_CASE_FIXTURE(TryUnifyFixture, "metatables_unify_against_shape_of_free_table") +{ + ScopedFastFlag sff("DebugLuauDeferredConstraintResolution", true); + + TableTypeVar::Props freeProps{ + {"foo", {typeChecker.numberType}}, + }; + + TypeId free = arena.addType(TableTypeVar{freeProps, std::nullopt, TypeLevel{}, TableState::Free}); + + TableTypeVar::Props indexProps{ + {"foo", {typeChecker.stringType}}, + }; + + TypeId index = arena.addType(TableTypeVar{indexProps, std::nullopt, TypeLevel{}, TableState::Sealed}); + + TableTypeVar::Props mtProps{ + {"__index", {index}}, + }; + + TypeId mt = arena.addType(TableTypeVar{mtProps, std::nullopt, TypeLevel{}, TableState::Sealed}); + + TypeId target = arena.addType(TableTypeVar{TableState::Unsealed, TypeLevel{}}); + TypeId metatable = arena.addType(MetatableTypeVar{target, mt}); + + state.tryUnify(metatable, free); + state.log.commit(); + + REQUIRE_EQ(state.errors.size(), 1); + + std::string expected = "Type '{ @metatable {| __index: {| foo: string |} |}, { } }' could not be converted into '{- foo: number -}'\n" + "caused by:\n" + " Type 'number' could not be converted into 'string'"; + CHECK_EQ(toString(state.errors[0]), expected); +} + TEST_SUITE_END(); diff --git a/tests/conformance/events.lua b/tests/conformance/events.lua index 6dcdbf0e9..0c6055dac 100644 --- a/tests/conformance/events.lua +++ b/tests/conformance/events.lua @@ -24,6 +24,9 @@ assert(getmetatable(nil) == nil) a={}; setmetatable(a, {__metatable = "xuxu", __tostring=function(x) return x.name end}) assert(getmetatable(a) == "xuxu") +ud=newproxy(true); getmetatable(ud).__metatable = "xuxu" +assert(getmetatable(ud) == "xuxu") + local res,err = pcall(tostring, a) assert(not res and err == "'__tostring' must return a string") -- cannot change a protected metatable diff --git a/tools/faillist.txt b/tools/faillist.txt index 825fb2f68..00e01011b 100644 --- a/tools/faillist.txt +++ b/tools/faillist.txt @@ -1,101 +1,47 @@ AnnotationTests.builtin_types_are_not_exported +AnnotationTests.corecursive_types_error_on_tight_loop AnnotationTests.duplicate_type_param_name AnnotationTests.for_loop_counter_annotation_is_checked AnnotationTests.generic_aliases_are_cloned_properly AnnotationTests.instantiation_clone_has_to_follow -AnnotationTests.luau_ice_triggers_an_ice -AnnotationTests.luau_ice_triggers_an_ice_exception_with_flag -AnnotationTests.luau_ice_triggers_an_ice_exception_with_flag_handler -AnnotationTests.luau_ice_triggers_an_ice_handler -AnnotationTests.luau_print_is_magic_if_the_flag_is_set AnnotationTests.occurs_check_on_cyclic_intersection_typevar AnnotationTests.occurs_check_on_cyclic_union_typevar +AnnotationTests.too_many_type_params AnnotationTests.two_type_params AnnotationTests.use_type_required_from_another_file AstQuery.last_argument_function_call_type AstQuery::getDocumentationSymbolAtPosition.overloaded_fn -AutocompleteTest.argument_types -AutocompleteTest.arguments_to_global_lambda -AutocompleteTest.autocomplete_boolean_singleton -AutocompleteTest.autocomplete_end_with_fn_exprs -AutocompleteTest.autocomplete_end_with_lambda AutocompleteTest.autocomplete_first_function_arg_expected_type -AutocompleteTest.autocomplete_for_in_middle_keywords -AutocompleteTest.autocomplete_for_middle_keywords -AutocompleteTest.autocomplete_if_middle_keywords AutocompleteTest.autocomplete_interpolated_string -AutocompleteTest.autocomplete_on_string_singletons AutocompleteTest.autocomplete_oop_implicit_self -AutocompleteTest.autocomplete_repeat_middle_keyword +AutocompleteTest.autocomplete_string_singleton_equality AutocompleteTest.autocomplete_string_singleton_escape AutocompleteTest.autocomplete_string_singletons -AutocompleteTest.autocomplete_while_middle_keywords AutocompleteTest.autocompleteProp_index_function_metamethod_is_variadic -AutocompleteTest.bias_toward_inner_scope AutocompleteTest.cyclic_table AutocompleteTest.do_compatible_self_calls -AutocompleteTest.do_not_overwrite_context_sensitive_kws -AutocompleteTest.do_not_suggest_internal_module_type AutocompleteTest.do_wrong_compatible_self_calls -AutocompleteTest.dont_offer_any_suggestions_from_within_a_broken_comment -AutocompleteTest.dont_offer_any_suggestions_from_within_a_broken_comment_at_the_very_end_of_the_file -AutocompleteTest.dont_offer_any_suggestions_from_within_a_comment -AutocompleteTest.dont_suggest_local_before_its_definition -AutocompleteTest.function_expr_params -AutocompleteTest.function_in_assignment_has_parentheses -AutocompleteTest.function_in_assignment_has_parentheses_2 -AutocompleteTest.function_parameters -AutocompleteTest.function_result_passed_to_function_has_parentheses -AutocompleteTest.generic_types -AutocompleteTest.get_suggestions_for_the_very_start_of_the_script -AutocompleteTest.global_function_params -AutocompleteTest.global_functions_are_not_scoped_lexically -AutocompleteTest.globals_are_order_independent -AutocompleteTest.if_then_else_elseif_completions AutocompleteTest.keyword_methods -AutocompleteTest.library_non_self_calls_are_fine -AutocompleteTest.library_self_calls_are_invalid -AutocompleteTest.local_function -AutocompleteTest.local_function_params -AutocompleteTest.local_functions_fall_out_of_scope -AutocompleteTest.method_call_inside_function_body -AutocompleteTest.nested_member_completions -AutocompleteTest.nested_recursive_function -AutocompleteTest.no_function_name_suggestions AutocompleteTest.no_incompatible_self_calls AutocompleteTest.no_incompatible_self_calls_2 -AutocompleteTest.no_incompatible_self_calls_on_class AutocompleteTest.no_wrong_compatible_self_calls_with_generics -AutocompleteTest.recursive_function -AutocompleteTest.recursive_function_global -AutocompleteTest.recursive_function_local -AutocompleteTest.return_types -AutocompleteTest.sometimes_the_metatable_is_an_error -AutocompleteTest.source_module_preservation_and_invalidation -AutocompleteTest.statement_between_two_statements -AutocompleteTest.string_prim_non_self_calls_are_avoided -AutocompleteTest.string_prim_self_calls_are_fine -AutocompleteTest.suggest_external_module_type -AutocompleteTest.table_intersection -AutocompleteTest.table_union +AutocompleteTest.suggest_table_keys AutocompleteTest.type_correct_argument_type_suggestion AutocompleteTest.type_correct_expected_argument_type_pack_suggestion -AutocompleteTest.type_correct_expected_argument_type_suggestion AutocompleteTest.type_correct_expected_argument_type_suggestion_optional AutocompleteTest.type_correct_expected_argument_type_suggestion_self +AutocompleteTest.type_correct_expected_return_type_pack_suggestion AutocompleteTest.type_correct_expected_return_type_suggestion AutocompleteTest.type_correct_full_type_suggestion AutocompleteTest.type_correct_function_no_parenthesis AutocompleteTest.type_correct_function_return_types AutocompleteTest.type_correct_function_type_suggestion AutocompleteTest.type_correct_keywords -AutocompleteTest.type_correct_local_type_suggestion -AutocompleteTest.type_correct_sealed_table AutocompleteTest.type_correct_suggestion_for_overloads AutocompleteTest.type_correct_suggestion_in_argument +AutocompleteTest.type_correct_suggestion_in_table AutocompleteTest.unsealed_table AutocompleteTest.unsealed_table_2 -AutocompleteTest.user_defined_local_functions_in_own_definition BuiltinTests.aliased_string_format BuiltinTests.assert_removes_falsy_types BuiltinTests.assert_removes_falsy_types2 @@ -149,21 +95,20 @@ BuiltinTests.table_insert_correctly_infers_type_of_array_2_args_overload BuiltinTests.table_insert_correctly_infers_type_of_array_3_args_overload BuiltinTests.table_pack BuiltinTests.table_pack_reduce +BuiltinTests.table_pack_variadic BuiltinTests.tonumber_returns_optional_number_type BuiltinTests.tonumber_returns_optional_number_type2 DefinitionTests.class_definition_overload_metamethods DefinitionTests.declaring_generic_functions DefinitionTests.definition_file_classes -FrontendTest.ast_node_at_position FrontendTest.automatically_check_dependent_scripts FrontendTest.environments FrontendTest.imported_table_modification_2 FrontendTest.it_should_be_safe_to_stringify_errors_when_full_type_graph_is_discarded FrontendTest.nocheck_cycle_used_by_checked -FrontendTest.produce_errors_for_unchanged_file_with_a_syntax_error FrontendTest.recheck_if_dependent_script_is_dirty FrontendTest.reexport_cyclic_type -FrontendTest.report_syntax_error_in_required_file +FrontendTest.reexport_type_alias FrontendTest.trace_requires_in_nonstrict_mode GenericsTests.apply_type_function_nested_generics1 GenericsTests.apply_type_function_nested_generics2 @@ -212,13 +157,16 @@ IntersectionTypes.table_intersection_write_sealed_indirect IntersectionTypes.table_write_sealed_indirect isSubtype.intersection_of_tables isSubtype.table_with_table_prop +ModuleTests.any_persistance_does_not_leak ModuleTests.clone_self_property ModuleTests.deepClone_cyclic_table ModuleTests.do_not_clone_reexports NonstrictModeTests.for_in_iterator_variables_are_any NonstrictModeTests.function_parameters_are_any NonstrictModeTests.inconsistent_module_return_types_are_ok +NonstrictModeTests.inconsistent_return_types_are_ok NonstrictModeTests.infer_nullary_function +NonstrictModeTests.infer_the_maximum_number_of_values_the_function_could_return NonstrictModeTests.inline_table_props_are_also_any NonstrictModeTests.local_tables_are_not_any NonstrictModeTests.locals_are_any_by_default @@ -324,7 +272,6 @@ RefinementTest.typeguard_in_if_condition_position RefinementTest.typeguard_narrows_for_functions RefinementTest.typeguard_narrows_for_table RefinementTest.typeguard_not_to_be_string -RefinementTest.typeguard_only_look_up_types_from_global_scope RefinementTest.what_nonsensical_condition RefinementTest.x_as_any_if_x_is_instance_elseif_x_is_table RefinementTest.x_is_not_instance_or_else_not_part @@ -349,6 +296,7 @@ TableTests.defining_a_self_method_for_a_builtin_sealed_table_must_fail TableTests.defining_a_self_method_for_a_local_sealed_table_must_fail TableTests.dont_crash_when_setmetatable_does_not_produce_a_metatabletypevar TableTests.dont_hang_when_trying_to_look_up_in_cyclic_metatable_index +TableTests.dont_invalidate_the_properties_iterator_of_free_table_when_rolled_back TableTests.dont_leak_free_table_props TableTests.dont_quantify_table_that_belongs_to_outer_scope TableTests.dont_suggest_exact_match_keys @@ -363,13 +311,11 @@ TableTests.found_like_key_in_table_function_call TableTests.found_like_key_in_table_property_access TableTests.found_multiple_like_keys TableTests.function_calls_produces_sealed_table_given_unsealed_table -TableTests.generalize_table_argument TableTests.getmetatable_returns_pointer_to_metatable TableTests.give_up_after_one_metatable_index_look_up TableTests.hide_table_error_properties TableTests.indexer_fn TableTests.indexer_on_sealed_table_must_unify_with_free_table -TableTests.indexer_table TableTests.indexing_from_a_table_should_prefer_properties_when_possible TableTests.inequality_operators_imply_exactly_matching_types TableTests.infer_array_2 @@ -395,11 +341,11 @@ TableTests.open_table_unification_2 TableTests.pass_a_union_of_tables_to_a_function_that_requires_a_table TableTests.pass_a_union_of_tables_to_a_function_that_requires_a_table_2 TableTests.pass_incompatible_union_to_a_generic_table_without_crashing -TableTests.passing_compatible_unions_to_a_generic_table_without_crashing TableTests.persistent_sealed_table_is_immutable TableTests.prop_access_on_key_whose_types_mismatches TableTests.property_lookup_through_tabletypevar_metatable TableTests.quantify_even_that_table_was_never_exported_at_all +TableTests.quantify_metatables_of_metatables_of_table TableTests.quantifying_a_bound_var_works TableTests.reasonable_error_when_adding_a_nonexistent_property_to_an_array_like_table TableTests.result_is_always_any_if_lhs_is_any @@ -435,8 +381,11 @@ ToString.function_type_with_argument_names_generic ToString.no_parentheses_around_cyclic_function_type_in_union ToString.toStringDetailed2 ToString.toStringErrorPack +ToString.toStringNamedFunction_generic_pack +ToString.toStringNamedFunction_hide_self_param ToString.toStringNamedFunction_hide_type_params ToString.toStringNamedFunction_id +ToString.toStringNamedFunction_include_self_param ToString.toStringNamedFunction_map ToString.toStringNamedFunction_variadics TranspilerTests.types_should_not_be_considered_cyclic_if_they_are_not_recursive @@ -453,6 +402,7 @@ TypeAliases.mutually_recursive_types_restriction_not_ok_1 TypeAliases.mutually_recursive_types_restriction_not_ok_2 TypeAliases.mutually_recursive_types_swapsies_not_ok TypeAliases.recursive_types_restriction_not_ok +TypeAliases.report_shadowed_aliases TypeAliases.stringify_optional_parameterized_alias TypeAliases.stringify_type_alias_of_recursive_template_table_type TypeAliases.stringify_type_alias_of_recursive_template_table_type2 @@ -460,11 +410,14 @@ TypeAliases.type_alias_fwd_declaration_is_precise TypeAliases.type_alias_local_mutation TypeAliases.type_alias_local_rename TypeAliases.type_alias_of_an_imported_recursive_generic_type +TypeAliases.type_alias_of_an_imported_recursive_type TypeInfer.checking_should_not_ice +TypeInfer.dont_report_type_errors_within_an_AstExprError TypeInfer.dont_report_type_errors_within_an_AstStatError TypeInfer.globals TypeInfer.globals2 TypeInfer.infer_assignment_value_types_mutable_lval +TypeInfer.it_is_ok_to_have_inconsistent_number_of_return_values_in_nonstrict TypeInfer.no_stack_overflow_from_isoptional TypeInfer.tc_after_error_recovery_no_replacement_name_in_error TypeInfer.tc_if_else_expressions_expected_type_3 @@ -489,6 +442,7 @@ TypeInferClasses.table_class_unification_reports_sane_errors_for_missing_propert TypeInferClasses.warn_when_prop_almost_matches TypeInferClasses.we_can_report_when_someone_is_trying_to_use_a_table_rather_than_a_class TypeInferFunctions.call_o_with_another_argument_after_foo_was_quantified +TypeInferFunctions.calling_function_with_anytypepack_doesnt_leak_free_types TypeInferFunctions.calling_function_with_incorrect_argument_type_yields_errors_spanning_argument TypeInferFunctions.dont_give_other_overloads_message_if_only_one_argument_matching_overload_exists TypeInferFunctions.dont_infer_parameter_types_for_functions_from_their_call_site @@ -500,16 +454,20 @@ TypeInferFunctions.function_decl_non_self_sealed_overwrite_2 TypeInferFunctions.function_decl_non_self_unsealed_overwrite TypeInferFunctions.function_does_not_return_enough_values TypeInferFunctions.function_statement_sealed_table_assignment_through_indexer +TypeInferFunctions.ignored_return_values TypeInferFunctions.improved_function_arg_mismatch_error_nonstrict TypeInferFunctions.improved_function_arg_mismatch_errors TypeInferFunctions.inconsistent_higher_order_function TypeInferFunctions.inconsistent_return_types TypeInferFunctions.infer_anonymous_function_arguments TypeInferFunctions.infer_return_type_from_selected_overload +TypeInferFunctions.infer_return_value_type TypeInferFunctions.infer_that_function_does_not_return_a_table +TypeInferFunctions.it_is_ok_not_to_supply_enough_retvals TypeInferFunctions.list_all_overloads_if_no_overload_takes_given_argument_count TypeInferFunctions.list_only_alternative_overloads_that_match_argument_count TypeInferFunctions.no_lossy_function_type +TypeInferFunctions.occurs_check_failure_in_function_return_type TypeInferFunctions.quantify_constrained_types TypeInferFunctions.record_matching_overload TypeInferFunctions.report_exiting_without_return_nonstrict @@ -521,7 +479,13 @@ TypeInferFunctions.too_few_arguments_variadic_generic TypeInferFunctions.too_few_arguments_variadic_generic2 TypeInferFunctions.too_many_arguments TypeInferFunctions.too_many_return_values +TypeInferFunctions.too_many_return_values_in_parentheses +TypeInferFunctions.too_many_return_values_no_function TypeInferFunctions.vararg_function_is_quantified +TypeInferLoops.for_in_loop_error_on_factory_not_returning_the_right_amount_of_values +TypeInferLoops.for_in_loop_with_custom_iterator +TypeInferLoops.for_in_loop_with_next +TypeInferLoops.for_in_with_generic_next TypeInferLoops.for_in_with_just_one_iterator_is_ok TypeInferLoops.loop_iter_no_indexer_nonstrict TypeInferLoops.loop_iter_trailing_nil @@ -529,6 +493,9 @@ TypeInferLoops.loop_typecheck_crash_on_empty_optional TypeInferLoops.unreachable_code_after_infinite_loop TypeInferLoops.varlist_declared_by_for_in_loop_should_be_free TypeInferModules.custom_require_global +TypeInferModules.do_not_modify_imported_types +TypeInferModules.do_not_modify_imported_types_2 +TypeInferModules.do_not_modify_imported_types_3 TypeInferModules.general_require_type_mismatch TypeInferModules.module_type_conflict TypeInferModules.module_type_conflict_instantiated @@ -539,8 +506,8 @@ TypeInferOOP.CheckMethodsOfSealed TypeInferOOP.dont_suggest_using_colon_rather_than_dot_if_another_overload_works TypeInferOOP.dont_suggest_using_colon_rather_than_dot_if_it_wont_help_2 TypeInferOOP.dont_suggest_using_colon_rather_than_dot_if_not_defined_with_colon -TypeInferOOP.inferred_methods_of_free_tables_have_the_same_level_as_the_enclosing_table TypeInferOOP.inferring_hundreds_of_self_calls_should_not_suffocate_memory +TypeInferOOP.methods_are_topologically_sorted TypeInferOperators.and_adds_boolean TypeInferOperators.and_adds_boolean_no_superfluous_union TypeInferOperators.and_binexps_dont_unify @@ -564,6 +531,7 @@ TypeInferOperators.expected_types_through_binary_or TypeInferOperators.infer_any_in_all_modes_when_lhs_is_unknown TypeInferOperators.or_joins_types TypeInferOperators.or_joins_types_with_no_extras +TypeInferOperators.primitive_arith_no_metatable TypeInferOperators.primitive_arith_possible_metatable TypeInferOperators.produce_the_correct_error_message_when_comparing_a_table_with_a_metatable_with_one_that_does_not TypeInferOperators.refine_and_or @@ -591,6 +559,7 @@ TypeInferUnknownNever.dont_unify_operands_if_one_of_the_operand_is_never_in_any_ TypeInferUnknownNever.index_on_union_of_tables_for_properties_that_is_never TypeInferUnknownNever.index_on_union_of_tables_for_properties_that_is_sorta_never TypeInferUnknownNever.math_operators_and_never +TypeInferUnknownNever.type_packs_containing_never_is_itself_uninhabitable TypeInferUnknownNever.type_packs_containing_never_is_itself_uninhabitable2 TypeInferUnknownNever.unary_minus_of_never TypePackTests.higher_order_function diff --git a/tools/lldb_formatters.lldb b/tools/lldb_formatters.lldb index f6fa6cf5a..4a5acd742 100644 --- a/tools/lldb_formatters.lldb +++ b/tools/lldb_formatters.lldb @@ -1,2 +1,7 @@ +type synthetic add -x "^Luau::detail::DenseHashTable<.*>$" -l lldb_formatters.DenseHashTableSyntheticChildrenProvider +type summary add "Luau::Symbol" -F lldb_formatters.luau_symbol_summary + type synthetic add -x "^Luau::Variant<.+>$" -l lldb_formatters.LuauVariantSyntheticChildrenProvider -type summary add -x "^Luau::Variant<.+>$" -l lldb_formatters.luau_variant_summary +type summary add -x "^Luau::Variant<.+>$" -F lldb_formatters.luau_variant_summary + +type synthetic add -x "^Luau::AstArray<.+>$" -l lldb_formatters.AstArraySyntheticChildrenProvider diff --git a/tools/lldb_formatters.py b/tools/lldb_formatters.py index ff610d096..19fc0f54c 100644 --- a/tools/lldb_formatters.py +++ b/tools/lldb_formatters.py @@ -4,30 +4,31 @@ # We're forced to resort to parsing names as strings. def templateParams(s): depth = 0 - start = s.find('<') + 1 + start = s.find("<") + 1 result = [] for i, c in enumerate(s[start:], start): - if c == '<': + if c == "<": depth += 1 - elif c == '>': + elif c == ">": if depth == 0: - result.append(s[start: i].strip()) + result.append(s[start:i].strip()) break depth -= 1 - elif c == ',' and depth == 0: - result.append(s[start: i].strip()) + elif c == "," and depth == 0: + result.append(s[start:i].strip()) start = i + 1 return result + def getType(target, typeName): stars = 0 typeName = typeName.strip() - while typeName.endswith('*'): + while typeName.endswith("*"): stars += 1 typeName = typeName[:-1] - if typeName.startswith('const '): + if typeName.startswith("const "): typeName = typeName[6:] ty = target.FindFirstType(typeName.strip()) @@ -36,13 +37,10 @@ def getType(target, typeName): return ty + def luau_variant_summary(valobj, internal_dict, options): - type_id = valobj.GetChildMemberWithName("typeId").GetValueAsUnsigned() - storage = valobj.GetChildMemberWithName("storage") - params = templateParams(valobj.GetType().GetCanonicalType().GetName()) - stored_type = params[type_id] - value = storage.Cast(stored_type.GetPointerType()).Dereference() - return stored_type.GetDisplayTypeName() + " [" + value.GetValue() + "]" + return valobj.GetChildMemberWithName("type").GetSummary()[1:-1] + class LuauVariantSyntheticChildrenProvider: node_names = ["type", "value"] @@ -74,26 +72,42 @@ def get_child_at_index(self, index): if node == "type": if self.current_type: - return self.valobj.CreateValueFromExpression(node, f"(const char*)\"{self.current_type.GetDisplayTypeName()}\"") + return self.valobj.CreateValueFromExpression( + node, f'(const char*)"{self.current_type.GetDisplayTypeName()}"' + ) else: - return self.valobj.CreateValueFromExpression(node, "(const char*)\"\"") + return self.valobj.CreateValueFromExpression( + node, '(const char*)""' + ) elif node == "value": if self.stored_value is not None: if self.current_type is not None: - return self.valobj.CreateValueFromData(node, self.stored_value.GetData(), self.current_type) + return self.valobj.CreateValueFromData( + node, self.stored_value.GetData(), self.current_type + ) else: - return self.valobj.CreateValueExpression(node, "(const char*)\"\"") + return self.valobj.CreateValueExpression( + node, '(const char*)""' + ) else: - return self.valobj.CreateValueFromExpression(node, "(const char*)\"\"") + return self.valobj.CreateValueFromExpression( + node, '(const char*)""' + ) else: return None def update(self): - self.type_index = self.valobj.GetChildMemberWithName("typeId").GetValueAsSigned() - self.type_params = templateParams(self.valobj.GetType().GetCanonicalType().GetName()) + self.type_index = self.valobj.GetChildMemberWithName( + "typeId" + ).GetValueAsSigned() + self.type_params = templateParams( + self.valobj.GetType().GetCanonicalType().GetName() + ) if len(self.type_params) > self.type_index: - self.current_type = getType(self.valobj.GetTarget(), self.type_params[self.type_index]) + self.current_type = getType( + self.valobj.GetTarget(), self.type_params[self.type_index] + ) if self.current_type: storage = self.valobj.GetChildMemberWithName("storage") @@ -105,3 +119,97 @@ def update(self): self.stored_value = None return False + + +class DenseHashTableSyntheticChildrenProvider: + def __init__(self, valobj, internal_dict): + """this call should initialize the Python object using valobj as the variable to provide synthetic children for""" + self.valobj = valobj + self.update() + + def num_children(self): + """this call should return the number of children that you want your object to have""" + return self.capacity + + def get_child_index(self, name): + """this call should return the index of the synthetic child whose name is given as argument""" + try: + if name.startswith("[") and name.endswith("]"): + return int(name[1:-1]) + else: + return -1 + except Exception as e: + print("get_child_index exception", e) + return -1 + + def get_child_at_index(self, index): + """this call should return a new LLDB SBValue object representing the child at the index given as argument""" + try: + dataMember = self.valobj.GetChildMemberWithName("data") + + data = dataMember.GetPointeeData(index) + + return self.valobj.CreateValueFromData( + f"[{index}]", + data, + dataMember.Dereference().GetType(), + ) + + except Exception as e: + print("get_child_at_index error", e) + + def update(self): + """this call should be used to update the internal state of this Python object whenever the state of the variables in LLDB changes.[1] + Also, this method is invoked before any other method in the interface.""" + self.capacity = self.valobj.GetChildMemberWithName( + "capacity" + ).GetValueAsUnsigned() + + def has_children(self): + """this call should return True if this object might have children, and False if this object can be guaranteed not to have children.[2]""" + return True + + +def luau_symbol_summary(valobj, internal_dict, options): + local = valobj.GetChildMemberWithName("local") + global_ = valobj.GetChildMemberWithName("global").GetChildMemberWithName("value") + + if local.GetValueAsUnsigned() != 0: + return f'local {local.GetChildMemberWithName("name").GetChildMemberWithName("value").GetSummary()}' + elif global_.GetValueAsUnsigned() != 0: + return f"global {global_.GetSummary()}" + else: + return "???" + + +class AstArraySyntheticChildrenProvider: + def __init__(self, valobj, internal_dict): + self.valobj = valobj + + def num_children(self): + return self.size + + def get_child_index(self, name): + try: + if name.startswith("[") and name.endswith("]"): + return int(name[1:-1]) + else: + return -1 + except Exception as e: + print("get_child_index error:", e) + + def get_child_at_index(self, index): + try: + dataMember = self.valobj.GetChildMemberWithName("data") + data = dataMember.GetPointeeData(index) + return self.valobj.CreateValueFromData( + f"[{index}]", data, dataMember.Dereference().GetType() + ) + except Exception as e: + print("get_child_index error:", e) + + def update(self): + self.size = self.valobj.GetChildMemberWithName("size").GetValueAsUnsigned() + + def has_children(self): + return True diff --git a/tools/perfgraph.py b/tools/perfgraph.py index 7d2639df7..eb6b68ce1 100644 --- a/tools/perfgraph.py +++ b/tools/perfgraph.py @@ -4,7 +4,6 @@ # Given a profile dump, this tool generates a flame graph based on the stacks listed in the profile # The result of analysis is a .svg file which can be viewed in a browser -import sys import svg import argparse import json diff --git a/tools/perfstat.py b/tools/perfstat.py new file mode 100644 index 000000000..e5cfd1173 --- /dev/null +++ b/tools/perfstat.py @@ -0,0 +1,65 @@ +#!/usr/bin/python +# This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +# Given a profile dump, this tool displays top functions based on the stacks listed in the profile + +import argparse + +class Node: + def __init__(self): + self.function = "" + self.source = "" + self.line = 0 + self.hier_ticks = 0 + self.self_ticks = 0 + + def title(self): + if self.line > 0: + return "{} ({}:{})".format(self.function, self.source, self.line) + else: + return self.function + +argumentParser = argparse.ArgumentParser(description='Display summary statistics from Luau sampling profiler dumps') +argumentParser.add_argument('source_file', type=open) +argumentParser.add_argument('--limit', dest='limit', type=int, default=10, help='Display top N functions') + +arguments = argumentParser.parse_args() + +dump = arguments.source_file.readlines() + +stats = {} +total = 0 +total_gc = 0 + +for l in dump: + ticks, stack = l.strip().split(" ", 1) + hier = {} + + for f in reversed(stack.split(";")): + source, function, line = f.split(",") + node = stats.setdefault(f, Node()) + + node.function = function + node.source = source + node.line = int(line) if len(line) > 0 else 0 + + if not node in hier: + node.hier_ticks += int(ticks) + hier[node] = True + + total += int(ticks) + node.self_ticks += int(ticks) + + if node.source == "GC": + total_gc += int(ticks) + +if total > 0: + print(f"Runtime: {total:,} usec ({100.0 * total_gc / total:.2f}% GC)") + print() + print("Top functions (self time):") + for n in sorted(stats.values(), key=lambda node: node.self_ticks, reverse=True)[:arguments.limit]: + print(f"{n.self_ticks:12,} usec ({100.0 * n.self_ticks / total:.2f}%): {n.title()}") + print() + print("Top functions (total time):") + for n in sorted(stats.values(), key=lambda node: node.hier_ticks, reverse=True)[:arguments.limit]: + print(f"{n.hier_ticks:12,} usec ({100.0 * n.hier_ticks / total:.2f}%): {n.title()}") diff --git a/tools/test_dcr.py b/tools/test_dcr.py index db932253b..1e3a50176 100644 --- a/tools/test_dcr.py +++ b/tools/test_dcr.py @@ -39,7 +39,7 @@ def startElement(self, name, attrs): elif name == "OverallResultsAsserts": if self.currentTest: - passed = 0 == safeParseInt(attrs["failures"]) + passed = attrs["test_case_success"] == "true" dottedName = ".".join(self.currentTest) From 91e144ac1b83e0b3fddf42db5549a6b916034fb5 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 6 Oct 2022 16:55:58 -0700 Subject: [PATCH 08/66] Sync to upstream/release/548 --- Analysis/include/Luau/ConstraintSolver.h | 12 +- Analysis/include/Luau/Error.h | 2 +- Analysis/include/Luau/Frontend.h | 7 +- Analysis/include/Luau/Normalize.h | 229 ++- Analysis/include/Luau/TypeInfer.h | 3 + Analysis/include/Luau/Unifiable.h | 4 +- Analysis/include/Luau/Unifier.h | 10 +- Analysis/src/Autocomplete.cpp | 132 +- Analysis/src/ConstraintSolver.cpp | 20 +- Analysis/src/Error.cpp | 4 +- Analysis/src/Frontend.cpp | 8 +- Analysis/src/Module.cpp | 14 +- Analysis/src/Normalize.cpp | 1696 +++++++++++++++++++- Analysis/src/TypeChecker2.cpp | 6 +- Analysis/src/TypeInfer.cpp | 103 +- Analysis/src/TypeVar.cpp | 5 + Analysis/src/Unifiable.cpp | 26 +- Analysis/src/Unifier.cpp | 288 +++- Ast/src/Parser.cpp | 23 + CLI/Profiler.cpp | 6 +- CodeGen/include/Luau/AssemblyBuilderX64.h | 1 + CodeGen/include/Luau/CodeAllocator.h | 8 +- CodeGen/include/Luau/CodeBlockUnwind.h | 2 +- CodeGen/include/Luau/UnwindBuilder.h | 5 +- CodeGen/include/Luau/UnwindBuilderDwarf2.h | 5 + CodeGen/include/Luau/UnwindBuilderWin.h | 5 + CodeGen/src/AssemblyBuilderX64.cpp | 28 +- CodeGen/src/CodeAllocator.cpp | 18 +- CodeGen/src/CodeBlockUnwind.cpp | 6 +- CodeGen/src/UnwindBuilderDwarf2.cpp | 10 + CodeGen/src/UnwindBuilderWin.cpp | 10 + Common/include/Luau/ExperimentalFlags.h | 1 + Makefile | 17 +- VM/include/lua.h | 4 +- VM/src/ldebug.cpp | 59 +- VM/src/lgc.cpp | 59 +- VM/src/lobject.cpp | 74 +- VM/src/lobject.h | 2 +- VM/src/lvmexecute.cpp | 9 +- VM/src/lvmload.cpp | 8 +- VM/src/lvmutils.cpp | 7 +- tests/AssemblyBuilderX64.test.cpp | 12 +- tests/Autocomplete.test.cpp | 18 - tests/CodeAllocator.test.cpp | 119 +- tests/ConstraintSolver.test.cpp | 16 +- tests/Fixture.cpp | 4 +- tests/Linter.test.cpp | 164 +- tests/Normalize.test.cpp | 10 +- tests/Parser.test.cpp | 55 + tests/TypeInfer.classes.test.cpp | 34 + tests/TypeInfer.functions.test.cpp | 30 +- tests/TypeInfer.generics.test.cpp | 65 +- tests/TypeInfer.intersectionTypes.test.cpp | 455 ++++++ tests/TypeInfer.modules.test.cpp | 23 +- tests/TypeInfer.provisional.test.cpp | 110 +- tests/TypeInfer.tables.test.cpp | 65 +- tests/TypeInfer.test.cpp | 60 +- tests/TypeInfer.tryUnify.test.cpp | 4 +- tests/TypeInfer.typePacks.cpp | 19 + tests/TypeInfer.unionTypes.test.cpp | 177 ++ tests/conformance/basic.lua | 3 + tests/main.cpp | 17 + tools/lldb_formatters.lldb | 3 + tools/natvis/Ast.natvis | 21 + tools/test_dcr.py | 50 +- 65 files changed, 3995 insertions(+), 475 deletions(-) diff --git a/Analysis/include/Luau/ConstraintSolver.h b/Analysis/include/Luau/ConstraintSolver.h index 06f53e4ab..9d5aadfbc 100644 --- a/Analysis/include/Luau/ConstraintSolver.h +++ b/Analysis/include/Luau/ConstraintSolver.h @@ -7,6 +7,7 @@ #include "Luau/Constraint.h" #include "Luau/TypeVar.h" #include "Luau/ToString.h" +#include "Luau/Normalize.h" #include @@ -44,6 +45,7 @@ struct ConstraintSolver TypeArena* arena; NotNull singletonTypes; InternalErrorReporter iceReporter; + NotNull normalizer; // The entire set of constraints that the solver is trying to resolve. std::vector> constraints; NotNull rootScope; @@ -74,9 +76,12 @@ struct ConstraintSolver DcrLogger* logger; - explicit ConstraintSolver(TypeArena* arena, NotNull singletonTypes, NotNull rootScope, ModuleName moduleName, + explicit ConstraintSolver(NotNull normalizer, NotNull rootScope, ModuleName moduleName, NotNull moduleResolver, std::vector requireCycles, DcrLogger* logger); + // Randomize the order in which to dispatch constraints + void randomize(unsigned seed); + /** * Attempts to dispatch all pending constraints and reach a type solution * that satisfies all of the constraints. @@ -85,8 +90,9 @@ struct ConstraintSolver bool done(); - /** Attempt to dispatch a constraint. Returns true if it was successful. - * If tryDispatch() returns false, the constraint remains in the unsolved set and will be retried later. + /** Attempt to dispatch a constraint. Returns true if it was successful. If + * tryDispatch() returns false, the constraint remains in the unsolved set + * and will be retried later. */ bool tryDispatch(NotNull c, bool force); diff --git a/Analysis/include/Luau/Error.h b/Analysis/include/Luau/Error.h index f3735864f..677548830 100644 --- a/Analysis/include/Luau/Error.h +++ b/Analysis/include/Luau/Error.h @@ -16,7 +16,7 @@ struct TypeMismatch TypeMismatch() = default; TypeMismatch(TypeId wantedType, TypeId givenType); TypeMismatch(TypeId wantedType, TypeId givenType, std::string reason); - TypeMismatch(TypeId wantedType, TypeId givenType, std::string reason, TypeError error); + TypeMismatch(TypeId wantedType, TypeId givenType, std::string reason, std::optional error); TypeId wantedType = nullptr; TypeId givenType = nullptr; diff --git a/Analysis/include/Luau/Frontend.h b/Analysis/include/Luau/Frontend.h index 5df6f4b59..b2662c688 100644 --- a/Analysis/include/Luau/Frontend.h +++ b/Analysis/include/Luau/Frontend.h @@ -83,8 +83,13 @@ struct FrontendOptions // is complete. bool retainFullTypeGraphs = false; - // Run typechecking only in mode required for autocomplete (strict mode in order to get more precise type information) + // Run typechecking only in mode required for autocomplete (strict mode in + // order to get more precise type information) bool forAutocomplete = false; + + // If not empty, randomly shuffle the constraint set before attempting to + // solve. Use this value to seed the random number generator. + std::optional randomizeConstraintResolutionSeed; }; struct CheckResult diff --git a/Analysis/include/Luau/Normalize.h b/Analysis/include/Luau/Normalize.h index 8e8b889b9..41e50d1b6 100644 --- a/Analysis/include/Luau/Normalize.h +++ b/Analysis/include/Luau/Normalize.h @@ -1,9 +1,9 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once -#include "Luau/Module.h" #include "Luau/NotNull.h" #include "Luau/TypeVar.h" +#include "Luau/UnifierSharedState.h" #include @@ -29,4 +29,231 @@ std::pair normalize( std::pair normalize(TypePackId ty, NotNull module, NotNull singletonTypes, InternalErrorReporter& ice); std::pair normalize(TypePackId ty, const ModulePtr& module, NotNull singletonTypes, InternalErrorReporter& ice); +class TypeIds +{ +private: + std::unordered_set types; + std::vector order; + std::size_t hash = 0; + +public: + using iterator = std::vector::iterator; + using const_iterator = std::vector::const_iterator; + + TypeIds(const TypeIds&) = delete; + TypeIds(TypeIds&&) = default; + TypeIds() = default; + ~TypeIds() = default; + TypeIds& operator=(TypeIds&&) = default; + + void insert(TypeId ty); + /// Erase every element that does not also occur in tys + void retain(const TypeIds& tys); + void clear(); + + iterator begin(); + iterator end(); + const_iterator begin() const; + const_iterator end() const; + iterator erase(const_iterator it); + + size_t size() const; + bool empty() const; + size_t count(TypeId ty) const; + + template + void insert(Iterator begin, Iterator end) + { + for (Iterator it = begin; it != end; ++it) + insert(*it); + } + + bool operator ==(const TypeIds& there) const; + size_t getHash() const; +}; + +} // namespace Luau + +template<> struct std::hash +{ + std::size_t operator()(const Luau::TypeIds& tys) const + { + return tys.getHash(); + } +}; + +template<> struct std::hash +{ + std::size_t operator()(const Luau::TypeIds* tys) const + { + return tys->getHash(); + } +}; + +template<> struct std::equal_to +{ + bool operator()(const Luau::TypeIds& here, const Luau::TypeIds& there) const + { + return here == there; + } +}; + +template<> struct std::equal_to +{ + bool operator()(const Luau::TypeIds* here, const Luau::TypeIds* there) const + { + return *here == *there; + } +}; + +namespace Luau +{ + +// A normalized string type is either `string` (represented by `nullopt`) +// or a union of string singletons. +using NormalizedStringType = std::optional>; + +// A normalized function type is either `never` (represented by `nullopt`) +// or an intersection of function types. +// NOTE: type normalization can fail on function types with generics +// (e.g. because we do not support unions and intersections of generic type packs), +// so this type may contain `error`. +using NormalizedFunctionType = std::optional; + +// A normalized generic/free type is a union, where each option is of the form (X & T) where +// * X is either a free type or a generic +// * T is a normalized type. +struct NormalizedType; +using NormalizedTyvars = std::unordered_map>; + +// A normalized type is either any, unknown, or one of the form P | T | F | G where +// * P is a union of primitive types (including singletons, classes and the error type) +// * T is a union of table types +// * F is a union of an intersection of function types +// * G is a union of generic/free normalized types, intersected with a normalized type +struct NormalizedType +{ + // The top part of the type. + // This type is either never, unknown, or any. + // If this type is not never, all the other fields are null. + TypeId tops; + + // The boolean part of the type. + // This type is either never, boolean type, or a boolean singleton. + TypeId booleans; + + // The class part of the type. + // Each element of this set is a class, and none of the classes are subclasses of each other. + TypeIds classes; + + // The error part of the type. + // This type is either never or the error type. + TypeId errors; + + // The nil part of the type. + // This type is either never or nil. + TypeId nils; + + // The number part of the type. + // This type is either never or number. + TypeId numbers; + + // The string part of the type. + // This may be the `string` type, or a union of singletons. + NormalizedStringType strings = std::map{}; + + // The thread part of the type. + // This type is either never or thread. + TypeId threads; + + // The (meta)table part of the type. + // Each element of this set is a (meta)table type. + TypeIds tables; + + // The function part of the type. + NormalizedFunctionType functions; + + // The generic/free part of the type. + NormalizedTyvars tyvars; + + NormalizedType(NotNull singletonTypes); + + NormalizedType(const NormalizedType&) = delete; + NormalizedType(NormalizedType&&) = default; + NormalizedType() = delete; + ~NormalizedType() = default; + NormalizedType& operator=(NormalizedType&&) = default; + NormalizedType& operator=(NormalizedType&) = delete; +}; + +class Normalizer +{ + std::unordered_map> cachedNormals; + std::unordered_map cachedIntersections; + std::unordered_map cachedUnions; + std::unordered_map> cachedTypeIds; + bool withinResourceLimits(); + +public: + TypeArena* arena; + NotNull singletonTypes; + NotNull sharedState; + + Normalizer(TypeArena* arena, NotNull singletonTypes, NotNull sharedState); + Normalizer(const Normalizer&) = delete; + Normalizer(Normalizer&&) = delete; + Normalizer() = delete; + ~Normalizer() = default; + Normalizer& operator=(Normalizer&&) = delete; + Normalizer& operator=(Normalizer&) = delete; + + // If this returns null, the typechecker should emit a "too complex" error + const NormalizedType* normalize(TypeId ty); + void clearNormal(NormalizedType& norm); + + // ------- Cached TypeIds + TypeId unionType(TypeId here, TypeId there); + TypeId intersectionType(TypeId here, TypeId there); + const TypeIds* cacheTypeIds(TypeIds tys); + void clearCaches(); + + // ------- Normalizing unions + void unionTysWithTy(TypeIds& here, TypeId there); + TypeId unionOfTops(TypeId here, TypeId there); + TypeId unionOfBools(TypeId here, TypeId there); + void unionClassesWithClass(TypeIds& heres, TypeId there); + void unionClasses(TypeIds& heres, const TypeIds& theres); + void unionStrings(NormalizedStringType& here, const NormalizedStringType& there); + std::optional unionOfTypePacks(TypePackId here, TypePackId there); + std::optional unionOfFunctions(TypeId here, TypeId there); + std::optional unionSaturatedFunctions(TypeId here, TypeId there); + void unionFunctionsWithFunction(NormalizedFunctionType& heress, TypeId there); + void unionFunctions(NormalizedFunctionType& heress, const NormalizedFunctionType& theress); + void unionTablesWithTable(TypeIds& heres, TypeId there); + void unionTables(TypeIds& heres, const TypeIds& theres); + bool unionNormals(NormalizedType& here, const NormalizedType& there, int ignoreSmallerTyvars = -1); + bool unionNormalWithTy(NormalizedType& here, TypeId there, int ignoreSmallerTyvars = -1); + + // ------- Normalizing intersections + void intersectTysWithTy(TypeIds& here, TypeId there); + TypeId intersectionOfTops(TypeId here, TypeId there); + TypeId intersectionOfBools(TypeId here, TypeId there); + void intersectClasses(TypeIds& heres, const TypeIds& theres); + void intersectClassesWithClass(TypeIds& heres, TypeId there); + void intersectStrings(NormalizedStringType& here, const NormalizedStringType& there); + std::optional intersectionOfTypePacks(TypePackId here, TypePackId there); + std::optional intersectionOfTables(TypeId here, TypeId there); + void intersectTablesWithTable(TypeIds& heres, TypeId there); + void intersectTables(TypeIds& heres, const TypeIds& theres); + std::optional intersectionOfFunctions(TypeId here, TypeId there); + void intersectFunctionsWithFunction(NormalizedFunctionType& heress, TypeId there); + void intersectFunctions(NormalizedFunctionType& heress, const NormalizedFunctionType& theress); + bool intersectTyvarsWithTy(NormalizedTyvars& here, TypeId there); + bool intersectNormals(NormalizedType& here, const NormalizedType& there, int ignoreSmallerTyvars = -1); + bool intersectNormalWithTy(NormalizedType& here, TypeId there); + + // -------- Convert back from a normalized type to a type + TypeId typeFromNormal(const NormalizedType& norm); +}; + } // namespace Luau diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index e5675ebb3..3184b0d30 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -234,6 +234,8 @@ struct TypeChecker TypeId anyify(const ScopePtr& scope, TypeId ty, Location location); TypePackId anyify(const ScopePtr& scope, TypePackId ty, Location location); + TypePackId anyifyModuleReturnTypePackGenerics(TypePackId ty); + void reportError(const TypeError& error); void reportError(const Location& location, TypeErrorData error); void reportErrors(const ErrorVec& errors); @@ -359,6 +361,7 @@ struct TypeChecker InternalErrorReporter* iceHandler; UnifierSharedState unifierState; + Normalizer normalizer; std::vector requireCycles; diff --git a/Analysis/include/Luau/Unifiable.h b/Analysis/include/Luau/Unifiable.h index 0ea175cc4..c43daa21a 100644 --- a/Analysis/include/Luau/Unifiable.h +++ b/Analysis/include/Luau/Unifiable.h @@ -96,7 +96,7 @@ struct Free bool forwardedTypeAlias = false; private: - static int nextIndex; + static int DEPRECATED_nextIndex; }; template @@ -127,7 +127,7 @@ struct Generic bool explicitName = false; private: - static int nextIndex; + static int DEPRECATED_nextIndex; }; struct Error diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index 26a922f5c..f6219dfbe 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -9,6 +9,7 @@ #include "Luau/TxnLog.h" #include "Luau/TypeArena.h" #include "Luau/UnifierSharedState.h" +#include "Normalize.h" #include @@ -52,6 +53,7 @@ struct Unifier { TypeArena* const types; NotNull singletonTypes; + NotNull normalizer; Mode mode; NotNull scope; // const Scope maybe @@ -60,13 +62,14 @@ struct Unifier Location location; Variance variance = Covariant; bool anyIsTop = false; // If true, we consider any to be a top type. If false, it is a familiar but weird mix of top and bottom all at once. + bool normalize; // Normalize unions and intersections if necessary bool useScopes = false; // If true, we use the scope hierarchy rather than TypeLevels CountMismatch::Context ctx = CountMismatch::Arg; UnifierSharedState& sharedState; - Unifier(TypeArena* types, NotNull singletonTypes, Mode mode, NotNull scope, const Location& location, Variance variance, - UnifierSharedState& sharedState, TxnLog* parentLog = nullptr); + Unifier(NotNull normalizer, Mode mode, NotNull scope, const Location& location, Variance variance, + TxnLog* parentLog = nullptr); // Test whether the two type vars unify. Never commits the result. ErrorVec canUnify(TypeId subTy, TypeId superTy); @@ -84,6 +87,7 @@ struct Unifier void tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTypeVar* uv, bool cacheEnabled, bool isFunctionCall); void tryUnifyTypeWithIntersection(TypeId subTy, TypeId superTy, const IntersectionTypeVar* uv); void tryUnifyIntersectionWithType(TypeId subTy, const IntersectionTypeVar* uv, TypeId superTy, bool cacheEnabled, bool isFunctionCall); + void tryUnifyNormalizedTypes(TypeId subTy, TypeId superTy, const NormalizedType& subNorm, const NormalizedType& superNorm, std::string reason, std::optional error = std::nullopt); void tryUnifyPrimitives(TypeId subTy, TypeId superTy); void tryUnifySingletons(TypeId subTy, TypeId superTy); void tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCall = false); @@ -92,6 +96,8 @@ struct Unifier void tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed); void tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed); + TypePackId tryApplyOverloadedFunction(TypeId function, const NormalizedFunctionType& overloads, TypePackId args); + TypeId widen(TypeId ty); TypePackId widen(TypePackId tp); diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index 1f594fed5..224e94401 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -12,8 +12,6 @@ #include #include -LUAU_FASTFLAG(LuauSelfCallAutocompleteFix3) - static const std::unordered_set kStatementStartingKeywords = { "while", "if", "local", "repeat", "function", "do", "for", "return", "break", "continue", "type", "export"}; @@ -139,7 +137,8 @@ static bool checkTypeMatch(TypeId subTy, TypeId superTy, NotNull scope, T { InternalErrorReporter iceReporter; UnifierSharedState unifierState(&iceReporter); - Unifier unifier(typeArena, singletonTypes, Mode::Strict, scope, Location(), Variance::Covariant, unifierState); + Normalizer normalizer{typeArena, singletonTypes, NotNull{&unifierState}}; + Unifier unifier(NotNull{&normalizer}, Mode::Strict, scope, Location(), Variance::Covariant); return unifier.canUnify(subTy, superTy).empty(); } @@ -151,18 +150,6 @@ static TypeCorrectKind checkTypeCorrectKind( NotNull moduleScope{module.getModuleScope().get()}; - auto canUnify = [&typeArena, singletonTypes, moduleScope](TypeId subTy, TypeId superTy) { - LUAU_ASSERT(!FFlag::LuauSelfCallAutocompleteFix3); - - InternalErrorReporter iceReporter; - UnifierSharedState unifierState(&iceReporter); - Unifier unifier(typeArena, singletonTypes, Mode::Strict, moduleScope, Location(), Variance::Covariant, unifierState); - - unifier.tryUnify(subTy, superTy); - bool ok = unifier.errors.empty(); - return ok; - }; - auto typeAtPosition = findExpectedTypeAt(module, node, position); if (!typeAtPosition) @@ -170,30 +157,11 @@ static TypeCorrectKind checkTypeCorrectKind( TypeId expectedType = follow(*typeAtPosition); - auto checkFunctionType = [typeArena, singletonTypes, moduleScope, &canUnify, &expectedType](const FunctionTypeVar* ftv) { - if (FFlag::LuauSelfCallAutocompleteFix3) - { - if (std::optional firstRetTy = first(ftv->retTypes)) - return checkTypeMatch(*firstRetTy, expectedType, moduleScope, typeArena, singletonTypes); - - return false; - } - else - { - auto [retHead, retTail] = flatten(ftv->retTypes); - - if (!retHead.empty() && canUnify(retHead.front(), expectedType)) - return true; - - // We might only have a variadic tail pack, check if the element is compatible - if (retTail) - { - if (const VariadicTypePack* vtp = get(follow(*retTail)); vtp && canUnify(vtp->ty, expectedType)) - return true; - } + auto checkFunctionType = [typeArena, singletonTypes, moduleScope, &expectedType](const FunctionTypeVar* ftv) { + if (std::optional firstRetTy = first(ftv->retTypes)) + return checkTypeMatch(*firstRetTy, expectedType, moduleScope, typeArena, singletonTypes); - return false; - } + return false; }; // We also want to suggest functions that return compatible result @@ -212,11 +180,8 @@ static TypeCorrectKind checkTypeCorrectKind( } } - if (FFlag::LuauSelfCallAutocompleteFix3) - return checkTypeMatch(ty, expectedType, NotNull{module.getModuleScope().get()}, typeArena, singletonTypes) ? TypeCorrectKind::Correct - : TypeCorrectKind::None; - else - return canUnify(ty, expectedType) ? TypeCorrectKind::Correct : TypeCorrectKind::None; + return checkTypeMatch(ty, expectedType, NotNull{module.getModuleScope().get()}, typeArena, singletonTypes) ? TypeCorrectKind::Correct + : TypeCorrectKind::None; } enum class PropIndexType @@ -230,51 +195,14 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, NotNul PropIndexType indexType, const std::vector& nodes, AutocompleteEntryMap& result, std::unordered_set& seen, std::optional containingClass = std::nullopt) { - if (FFlag::LuauSelfCallAutocompleteFix3) - rootTy = follow(rootTy); - + rootTy = follow(rootTy); ty = follow(ty); if (seen.count(ty)) return; seen.insert(ty); - auto isWrongIndexer_DEPRECATED = [indexType, useStrictFunctionIndexers = !!get(ty)](Luau::TypeId type) { - LUAU_ASSERT(!FFlag::LuauSelfCallAutocompleteFix3); - - if (indexType == PropIndexType::Key) - return false; - - bool colonIndex = indexType == PropIndexType::Colon; - - if (const FunctionTypeVar* ftv = get(type)) - { - return useStrictFunctionIndexers ? colonIndex != ftv->hasSelf : false; - } - else if (const IntersectionTypeVar* itv = get(type)) - { - bool allHaveSelf = true; - for (auto subType : itv->parts) - { - if (const FunctionTypeVar* ftv = get(Luau::follow(subType))) - { - allHaveSelf &= ftv->hasSelf; - } - else - { - return colonIndex; - } - } - return useStrictFunctionIndexers ? colonIndex != allHaveSelf : false; - } - else - { - return colonIndex; - } - }; auto isWrongIndexer = [typeArena, singletonTypes, &module, rootTy, indexType](Luau::TypeId type) { - LUAU_ASSERT(FFlag::LuauSelfCallAutocompleteFix3); - if (indexType == PropIndexType::Key) return false; @@ -337,7 +265,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, NotNul AutocompleteEntryKind::Property, type, prop.deprecated, - FFlag::LuauSelfCallAutocompleteFix3 ? isWrongIndexer(type) : isWrongIndexer_DEPRECATED(type), + isWrongIndexer(type), typeCorrect, containingClass, &prop, @@ -380,31 +308,8 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, NotNul { autocompleteProps(module, typeArena, singletonTypes, rootTy, mt->table, indexType, nodes, result, seen); - if (FFlag::LuauSelfCallAutocompleteFix3) - { - if (auto mtable = get(mt->metatable)) - fillMetatableProps(mtable); - } - else - { - auto mtable = get(mt->metatable); - if (!mtable) - return; - - auto indexIt = mtable->props.find("__index"); - if (indexIt != mtable->props.end()) - { - TypeId followed = follow(indexIt->second.type); - if (get(followed) || get(followed)) - autocompleteProps(module, typeArena, singletonTypes, rootTy, followed, indexType, nodes, result, seen); - else if (auto indexFunction = get(followed)) - { - std::optional indexFunctionResult = first(indexFunction->retTypes); - if (indexFunctionResult) - autocompleteProps(module, typeArena, singletonTypes, rootTy, *indexFunctionResult, indexType, nodes, result, seen); - } - } - } + if (auto mtable = get(mt->metatable)) + fillMetatableProps(mtable); } else if (auto i = get(ty)) { @@ -446,9 +351,6 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, NotNul AutocompleteEntryMap inner; std::unordered_set innerSeen; - if (!FFlag::LuauSelfCallAutocompleteFix3) - innerSeen = seen; - if (isNil(*iter)) { ++iter; @@ -472,7 +374,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, NotNul ++iter; } } - else if (auto pt = get(ty); pt && FFlag::LuauSelfCallAutocompleteFix3) + else if (auto pt = get(ty)) { if (pt->metatable) { @@ -480,7 +382,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, NotNul fillMetatableProps(mtable); } } - else if (FFlag::LuauSelfCallAutocompleteFix3 && get(get(ty))) + else if (get(get(ty))) { autocompleteProps(module, typeArena, singletonTypes, rootTy, singletonTypes->stringType, indexType, nodes, result, seen); } @@ -1416,11 +1318,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M TypeId ty = follow(*it); PropIndexType indexType = indexName->op == ':' ? PropIndexType::Colon : PropIndexType::Point; - if (!FFlag::LuauSelfCallAutocompleteFix3 && isString(ty)) - return {autocompleteProps(*module, &typeArena, singletonTypes, globalScope->bindings[AstName{"string"}].typeId, indexType, ancestry), - ancestry, AutocompleteContext::Property}; - else - return {autocompleteProps(*module, &typeArena, singletonTypes, ty, indexType, ancestry), ancestry, AutocompleteContext::Property}; + return {autocompleteProps(*module, &typeArena, singletonTypes, ty, indexType, ancestry), ancestry, AutocompleteContext::Property}; } else if (auto typeReference = node->as()) { diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index 35b8387fd..5b3ec03cc 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -14,6 +14,8 @@ #include "Luau/VisitTypeVar.h" #include "Luau/TypeUtils.h" +#include + LUAU_FASTFLAGVARIABLE(DebugLuauLogSolver, false); LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJson, false); LUAU_FASTFLAG(LuauFixNameMaps) @@ -251,10 +253,11 @@ void dump(ConstraintSolver* cs, ToStringOptions& opts) } } -ConstraintSolver::ConstraintSolver(TypeArena* arena, NotNull singletonTypes, NotNull rootScope, ModuleName moduleName, +ConstraintSolver::ConstraintSolver(NotNull normalizer, NotNull rootScope, ModuleName moduleName, NotNull moduleResolver, std::vector requireCycles, DcrLogger* logger) - : arena(arena) - , singletonTypes(singletonTypes) + : arena(normalizer->arena) + , singletonTypes(normalizer->singletonTypes) + , normalizer(normalizer) , constraints(collectConstraints(rootScope)) , rootScope(rootScope) , currentModuleName(std::move(moduleName)) @@ -278,6 +281,12 @@ ConstraintSolver::ConstraintSolver(TypeArena* arena, NotNull sin LUAU_ASSERT(logger); } +void ConstraintSolver::randomize(unsigned seed) +{ + std::mt19937 g(seed); + std::shuffle(begin(unsolvedConstraints), end(unsolvedConstraints), g); +} + void ConstraintSolver::run() { if (done()) @@ -1355,8 +1364,7 @@ bool ConstraintSolver::isBlocked(NotNull constraint) void ConstraintSolver::unify(TypeId subType, TypeId superType, NotNull scope) { - UnifierSharedState sharedState{&iceReporter}; - Unifier u{arena, singletonTypes, Mode::Strict, scope, Location{}, Covariant, sharedState}; + Unifier u{normalizer, Mode::Strict, scope, Location{}, Covariant}; u.useScopes = true; u.tryUnify(subType, superType); @@ -1379,7 +1387,7 @@ void ConstraintSolver::unify(TypeId subType, TypeId superType, NotNull sc void ConstraintSolver::unify(TypePackId subPack, TypePackId superPack, NotNull scope) { UnifierSharedState sharedState{&iceReporter}; - Unifier u{arena, singletonTypes, Mode::Strict, scope, Location{}, Covariant, sharedState}; + Unifier u{normalizer, Mode::Strict, scope, Location{}, Covariant}; u.useScopes = true; u.tryUnify(subPack, superPack); diff --git a/Analysis/src/Error.cpp b/Analysis/src/Error.cpp index aa496ee40..d13e26c0b 100644 --- a/Analysis/src/Error.cpp +++ b/Analysis/src/Error.cpp @@ -511,11 +511,11 @@ TypeMismatch::TypeMismatch(TypeId wantedType, TypeId givenType, std::string reas { } -TypeMismatch::TypeMismatch(TypeId wantedType, TypeId givenType, std::string reason, TypeError error) +TypeMismatch::TypeMismatch(TypeId wantedType, TypeId givenType, std::string reason, std::optional error) : wantedType(wantedType) , givenType(givenType) , reason(reason) - , error(std::make_shared(std::move(error))) + , error(error ? std::make_shared(std::move(*error)) : nullptr) { } diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 1890e0811..5705ac17f 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -860,12 +860,18 @@ ModulePtr Frontend::check( const NotNull mr{forAutocomplete ? &moduleResolverForAutocomplete : &moduleResolver}; const ScopePtr& globalScope{forAutocomplete ? typeCheckerForAutocomplete.globalScope : typeChecker.globalScope}; + Normalizer normalizer{&result->internalTypes, singletonTypes, NotNull{&typeChecker.unifierState}}; + ConstraintGraphBuilder cgb{ sourceModule.name, result, &result->internalTypes, mr, singletonTypes, NotNull(&iceHandler), globalScope, logger.get()}; cgb.visit(sourceModule.root); result->errors = std::move(cgb.errors); - ConstraintSolver cs{&result->internalTypes, singletonTypes, NotNull(cgb.rootScope), sourceModule.name, NotNull(&moduleResolver), requireCycles, logger.get()}; + ConstraintSolver cs{NotNull{&normalizer}, NotNull(cgb.rootScope), sourceModule.name, NotNull(&moduleResolver), requireCycles, logger.get()}; + + if (options.randomizeConstraintResolutionSeed) + cs.randomize(*options.randomizeConstraintResolutionSeed); + cs.run(); for (TypeError& e : cs.errors) diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index b9deac769..45eb87d65 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -14,6 +14,7 @@ #include +LUAU_FASTFLAG(LuauAnyifyModuleReturnGenerics) LUAU_FASTFLAG(LuauLowerBoundsCalculation); LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); LUAU_FASTFLAGVARIABLE(LuauForceExportSurfacesToBeNormal, false); @@ -285,13 +286,16 @@ void Module::clonePublicInterface(NotNull singletonTypes, Intern } } - for (TypeId ty : returnType) + if (!FFlag::LuauAnyifyModuleReturnGenerics) { - if (get(follow(ty))) + for (TypeId ty : returnType) { - auto t = asMutable(ty); - t->ty = AnyTypeVar{}; - t->normal = true; + if (get(follow(ty))) + { + auto t = asMutable(ty); + t->ty = AnyTypeVar{}; + t->normal = true; + } } } diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp index 42f615172..c008bcfc0 100644 --- a/Analysis/src/Normalize.cpp +++ b/Analysis/src/Normalize.cpp @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Normalize.h" +#include "Luau/ToString.h" #include @@ -10,15 +11,1703 @@ #include "Luau/VisitTypeVar.h" LUAU_FASTFLAGVARIABLE(DebugLuauCopyBeforeNormalizing, false) +LUAU_FASTFLAGVARIABLE(DebugLuauCheckNormalizeInvariant, false) // This could theoretically be 2000 on amd64, but x86 requires this. LUAU_FASTINTVARIABLE(LuauNormalizeIterationLimit, 1200); +LUAU_FASTINTVARIABLE(LuauNormalizeCacheLimit, 100000); +LUAU_FASTINT(LuauTypeInferRecursionLimit); LUAU_FASTFLAGVARIABLE(LuauNormalizeCombineTableFix, false); +LUAU_FASTFLAGVARIABLE(LuauTypeNormalization2, false); LUAU_FASTFLAG(LuauUnknownAndNeverType) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) namespace Luau { +void TypeIds::insert(TypeId ty) +{ + ty = follow(ty); + auto [_, fresh] = types.insert(ty); + if (fresh) + { + order.push_back(ty); + hash ^= std::hash{}(ty); + } +} + +void TypeIds::clear() +{ + order.clear(); + types.clear(); + hash = 0; +} + +TypeIds::iterator TypeIds::begin() +{ + return order.begin(); +} + +TypeIds::iterator TypeIds::end() +{ + return order.end(); +} + +TypeIds::const_iterator TypeIds::begin() const +{ + return order.begin(); +} + +TypeIds::const_iterator TypeIds::end() const +{ + return order.end(); +} + +TypeIds::iterator TypeIds::erase(TypeIds::const_iterator it) +{ + TypeId ty = *it; + types.erase(ty); + hash ^= std::hash{}(ty); + return order.erase(it); +} + +size_t TypeIds::size() const +{ + return types.size(); +} + +bool TypeIds::empty() const +{ + return types.empty(); +} + +size_t TypeIds::count(TypeId ty) const +{ + ty = follow(ty); + return types.count(ty); +} + +void TypeIds::retain(const TypeIds& there) +{ + for (auto it = begin(); it != end();) + { + if (there.count(*it)) + it++; + else + it = erase(it); + } +} + +size_t TypeIds::getHash() const +{ + return hash; +} + +bool TypeIds::operator==(const TypeIds& there) const +{ + return hash == there.hash && types == there.types; +} + +NormalizedType::NormalizedType(NotNull singletonTypes) + : tops(singletonTypes->neverType) + , booleans(singletonTypes->neverType) + , errors(singletonTypes->neverType) + , nils(singletonTypes->neverType) + , numbers(singletonTypes->neverType) + , threads(singletonTypes->neverType) +{ +} + +static bool isInhabited(const NormalizedType& norm) +{ + return !get(norm.tops) + || !get(norm.booleans) + || !norm.classes.empty() + || !get(norm.errors) + || !get(norm.nils) + || !get(norm.numbers) + || !norm.strings || !norm.strings->empty() + || !get(norm.threads) + || norm.functions + || !norm.tables.empty() + || !norm.tyvars.empty(); +} + +static int tyvarIndex(TypeId ty) +{ + if (const GenericTypeVar* gtv = get(ty)) + return gtv->index; + else if (const FreeTypeVar* ftv = get(ty)) + return ftv->index; + else + return 0; +} + +#ifdef LUAU_ASSERTENABLED + +static bool isNormalizedTop(TypeId ty) +{ + return get(ty) || get(ty) || get(ty); +} + +static bool isNormalizedBoolean(TypeId ty) +{ + if (get(ty)) + return true; + else if (const PrimitiveTypeVar* ptv = get(ty)) + return ptv->type == PrimitiveTypeVar::Boolean; + else if (const SingletonTypeVar* stv = get(ty)) + return get(stv); + else + return false; +} + +static bool isNormalizedError(TypeId ty) +{ + if (get(ty) || get(ty)) + return true; + else + return false; +} + +static bool isNormalizedNil(TypeId ty) +{ + if (get(ty)) + return true; + else if (const PrimitiveTypeVar* ptv = get(ty)) + return ptv->type == PrimitiveTypeVar::NilType; + else + return false; +} + +static bool isNormalizedNumber(TypeId ty) +{ + if (get(ty)) + return true; + else if (const PrimitiveTypeVar* ptv = get(ty)) + return ptv->type == PrimitiveTypeVar::Number; + else + return false; +} + +static bool isNormalizedString(const NormalizedStringType& ty) +{ + if (!ty) + return true; + + for (auto& [str, ty] : *ty) + { + if (const SingletonTypeVar* stv = get(ty)) + { + if (const StringSingleton* sstv = get(stv)) + { + if (sstv->value != str) + return false; + } + else + return false; + } + else + return false; + } + + return true; +} + +static bool isNormalizedThread(TypeId ty) +{ + if (get(ty)) + return true; + else if (const PrimitiveTypeVar* ptv = get(ty)) + return ptv->type == PrimitiveTypeVar::Thread; + else + return false; +} + +static bool areNormalizedFunctions(const NormalizedFunctionType& tys) +{ + if (tys) + for (TypeId ty : *tys) + if (!get(ty) && !get(ty)) + return false; + return true; +} + +static bool areNormalizedTables(const TypeIds& tys) +{ + for (TypeId ty : tys) + if (!get(ty) && !get(ty)) + return false; + return true; +} + +static bool areNormalizedClasses(const TypeIds& tys) +{ + for (TypeId ty : tys) + if (!get(ty)) + return false; + return true; +} + +static bool isPlainTyvar(TypeId ty) +{ + return (get(ty) || get(ty)); +} + +static bool isNormalizedTyvar(const NormalizedTyvars& tyvars) +{ + for (auto& [tyvar, intersect] : tyvars) + { + if (!isPlainTyvar(tyvar)) + return false; + if (!isInhabited(*intersect)) + return false; + for (auto& [other, _] : intersect->tyvars) + if (tyvarIndex(other) <= tyvarIndex(tyvar)) + return false; + } + return true; +} + +#endif // LUAU_ASSERTENABLED + +static void assertInvariant(const NormalizedType& norm) +{ + #ifdef LUAU_ASSERTENABLED + if (!FFlag::DebugLuauCheckNormalizeInvariant) + return; + + LUAU_ASSERT(isNormalizedTop(norm.tops)); + LUAU_ASSERT(isNormalizedBoolean(norm.booleans)); + LUAU_ASSERT(areNormalizedClasses(norm.classes)); + LUAU_ASSERT(isNormalizedError(norm.errors)); + LUAU_ASSERT(isNormalizedNil(norm.nils)); + LUAU_ASSERT(isNormalizedNumber(norm.numbers)); + LUAU_ASSERT(isNormalizedString(norm.strings)); + LUAU_ASSERT(isNormalizedThread(norm.threads)); + LUAU_ASSERT(areNormalizedFunctions(norm.functions)); + LUAU_ASSERT(areNormalizedTables(norm.tables)); + LUAU_ASSERT(isNormalizedTyvar(norm.tyvars)); + for (auto& [_, child] : norm.tyvars) + assertInvariant(*child); + #endif +} + +Normalizer::Normalizer(TypeArena* arena, NotNull singletonTypes, NotNull sharedState) + : arena(arena) + , singletonTypes(singletonTypes) + , sharedState(sharedState) +{ +} + +const NormalizedType* Normalizer::normalize(TypeId ty) +{ + if (!arena) + sharedState->iceHandler->ice("Normalizing types outside a module"); + + auto found = cachedNormals.find(ty); + if (found != cachedNormals.end()) + return found->second.get(); + + NormalizedType norm{singletonTypes}; + if (!unionNormalWithTy(norm, ty)) + return nullptr; + std::unique_ptr uniq = std::make_unique(std::move(norm)); + const NormalizedType* result = uniq.get(); + cachedNormals[ty] = std::move(uniq); + return result; +} + +void Normalizer::clearNormal(NormalizedType& norm) +{ + norm.tops = singletonTypes->neverType; + norm.booleans = singletonTypes->neverType; + norm.classes.clear(); + norm.errors = singletonTypes->neverType; + norm.nils = singletonTypes->neverType; + norm.numbers = singletonTypes->neverType; + if (norm.strings) + norm.strings->clear(); + else + norm.strings.emplace(); + norm.threads = singletonTypes->neverType; + norm.tables.clear(); + norm.functions = std::nullopt; + norm.tyvars.clear(); +} + +// ------- Cached TypeIds +const TypeIds* Normalizer::cacheTypeIds(TypeIds tys) +{ + auto found = cachedTypeIds.find(&tys); + if (found != cachedTypeIds.end()) + return found->first; + + std::unique_ptr uniq = std::make_unique(std::move(tys)); + const TypeIds* result = uniq.get(); + cachedTypeIds[result] = std::move(uniq); + return result; +} + +TypeId Normalizer::unionType(TypeId here, TypeId there) +{ + here = follow(here); + there = follow(there); + + if (here == there) + return here; + if (get(here) || get(there)) + return there; + if (get(there) || get(here)) + return here; + + TypeIds tmps; + + if (const UnionTypeVar* utv = get(here)) + { + TypeIds heres; + heres.insert(begin(utv), end(utv)); + tmps.insert(heres.begin(), heres.end()); + cachedUnions[cacheTypeIds(std::move(heres))] = here; + } + else + tmps.insert(here); + + if (const UnionTypeVar* utv = get(there)) + { + TypeIds theres; + theres.insert(begin(utv), end(utv)); + tmps.insert(theres.begin(), theres.end()); + cachedUnions[cacheTypeIds(std::move(theres))] = there; + } + else + tmps.insert(there); + + auto cacheHit = cachedUnions.find(&tmps); + if (cacheHit != cachedUnions.end()) + return cacheHit->second; + + std::vector parts; + parts.insert(parts.end(), tmps.begin(), tmps.end()); + TypeId result = arena->addType(UnionTypeVar{std::move(parts)}); + cachedUnions[cacheTypeIds(std::move(tmps))] = result; + + return result; +} + +TypeId Normalizer::intersectionType(TypeId here, TypeId there) +{ + here = follow(here); + there = follow(there); + + if (here == there) + return here; + if (get(here) || get(there)) + return here; + if (get(there) || get(here)) + return there; + + TypeIds tmps; + + if (const IntersectionTypeVar* utv = get(here)) + { + TypeIds heres; + heres.insert(begin(utv), end(utv)); + tmps.insert(heres.begin(), heres.end()); + cachedIntersections[cacheTypeIds(std::move(heres))] = here; + } + else + tmps.insert(here); + + if (const IntersectionTypeVar* utv = get(there)) + { + TypeIds theres; + theres.insert(begin(utv), end(utv)); + tmps.insert(theres.begin(), theres.end()); + cachedIntersections[cacheTypeIds(std::move(theres))] = there; + } + else + tmps.insert(there); + + if (tmps.size() == 1) + return *tmps.begin(); + + auto cacheHit = cachedIntersections.find(&tmps); + if (cacheHit != cachedIntersections.end()) + return cacheHit->second; + + std::vector parts; + parts.insert(parts.end(), tmps.begin(), tmps.end()); + TypeId result = arena->addType(IntersectionTypeVar{std::move(parts)}); + cachedIntersections[cacheTypeIds(std::move(tmps))] = result; + + return result; +} + +void Normalizer::clearCaches() +{ + cachedNormals.clear(); + cachedIntersections.clear(); + cachedUnions.clear(); + cachedTypeIds.clear(); +} + +// ------- Normalizing unions +TypeId Normalizer::unionOfTops(TypeId here, TypeId there) +{ + if (get(here) || get(there)) + return there; + else + return here; +} + +TypeId Normalizer::unionOfBools(TypeId here, TypeId there) +{ + if (get(here)) + return there; + if (get(there)) + return here; + if (const BooleanSingleton* hbool = get(get(here))) + if (const BooleanSingleton* tbool = get(get(there))) + if (hbool->value == tbool->value) + return here; + return singletonTypes->booleanType; +} + +void Normalizer::unionClassesWithClass(TypeIds& heres, TypeId there) +{ + if (heres.count(there)) + return; + + const ClassTypeVar* tctv = get(there); + + for (auto it = heres.begin(); it != heres.end();) + { + TypeId here = *it; + const ClassTypeVar* hctv = get(here); + if (isSubclass(tctv, hctv)) + return; + else if (isSubclass(hctv, tctv)) + it = heres.erase(it); + else + it++; + } + + heres.insert(there); +} + +void Normalizer::unionClasses(TypeIds& heres, const TypeIds& theres) +{ + for (TypeId there : theres) + unionClassesWithClass(heres, there); +} + +void Normalizer::unionStrings(NormalizedStringType& here, const NormalizedStringType& there) +{ + if (!there) + here.reset(); + else if (here) + here->insert(there->begin(), there->end()); +} + +std::optional Normalizer::unionOfTypePacks(TypePackId here, TypePackId there) +{ + if (here == there) + return here; + + std::vector head; + std::optional tail; + + bool hereSubThere = true; + bool thereSubHere = true; + + TypePackIterator ith = begin(here); + TypePackIterator itt = begin(there); + + while (ith != end(here) && itt != end(there)) + { + TypeId hty = *ith; + TypeId tty = *itt; + TypeId ty = unionType(hty, tty); + if (ty != hty) + thereSubHere = false; + if (ty != tty) + hereSubThere = false; + head.push_back(ty); + ith++; + itt++; + } + + auto dealWithDifferentArities = [&](TypePackIterator& ith, TypePackIterator itt, TypePackId here, TypePackId there, bool& hereSubThere, bool& thereSubHere) + { + if (ith != end(here)) + { + TypeId tty = singletonTypes->nilType; + if (std::optional ttail = itt.tail()) + { + if (const VariadicTypePack* tvtp = get(*ttail)) + tty = tvtp->ty; + else + // Luau doesn't have unions of type pack variables + return false; + } + else + // Type packs of different arities are incomparable + return false; + + while (ith != end(here)) + { + TypeId hty = *ith; + TypeId ty = unionType(hty, tty); + if (ty != hty) + thereSubHere = false; + if (ty != tty) + hereSubThere = false; + head.push_back(ty); + ith++; + } + } + return true; + }; + + if (!dealWithDifferentArities(ith, itt, here, there, hereSubThere, thereSubHere)) + return std::nullopt; + + if (!dealWithDifferentArities(itt, ith, there, here, thereSubHere, hereSubThere)) + return std::nullopt; + + if (std::optional htail = ith.tail()) + { + if (std::optional ttail = itt.tail()) + { + if (*htail == *ttail) + tail = htail; + else if (const VariadicTypePack* hvtp = get(*htail)) + { + if (const VariadicTypePack* tvtp = get(*ttail)) + { + TypeId ty = unionType(hvtp->ty, tvtp->ty); + if (ty != hvtp->ty) + thereSubHere = false; + if (ty != tvtp->ty) + hereSubThere = false; + bool hidden = hvtp->hidden & tvtp->hidden; + tail = arena->addTypePack(VariadicTypePack{ty,hidden}); + } + else + // Luau doesn't have unions of type pack variables + return std::nullopt; + } + else + // Luau doesn't have unions of type pack variables + return std::nullopt; + } + else if (get(*htail)) + { + hereSubThere = false; + tail = htail; + } + else + // Luau doesn't have unions of type pack variables + return std::nullopt; + } + else if (std::optional ttail = itt.tail()) + { + if (get(*ttail)) + { + thereSubHere = false; + tail = htail; + } + else + // Luau doesn't have unions of type pack variables + return std::nullopt; + } + + if (hereSubThere) + return there; + else if (thereSubHere) + return here; + if (!head.empty()) + return arena->addTypePack(TypePack{head,tail}); + else if (tail) + return *tail; + else + // TODO: Add an emptyPack to singleton types + return arena->addTypePack({}); +} + +std::optional Normalizer::unionOfFunctions(TypeId here, TypeId there) +{ + if (get(here)) + return here; + + if (get(there)) + return there; + + const FunctionTypeVar* hftv = get(here); + LUAU_ASSERT(hftv); + const FunctionTypeVar* tftv = get(there); + LUAU_ASSERT(tftv); + + if (hftv->generics != tftv->generics) + return std::nullopt; + if (hftv->genericPacks != tftv->genericPacks) + return std::nullopt; + + std::optional argTypes = intersectionOfTypePacks(hftv->argTypes, tftv->argTypes); + if (!argTypes) + return std::nullopt; + + std::optional retTypes = unionOfTypePacks(hftv->retTypes, tftv->retTypes); + if (!retTypes) + return std::nullopt; + + if (*argTypes == hftv->argTypes && *retTypes == hftv->retTypes) + return here; + if (*argTypes == tftv->argTypes && *retTypes == tftv->retTypes) + return there; + + FunctionTypeVar result{*argTypes, *retTypes}; + result.generics = hftv->generics; + result.genericPacks = hftv->genericPacks; + return arena->addType(std::move(result)); +} + +void Normalizer::unionFunctions(NormalizedFunctionType& heres, const NormalizedFunctionType& theres) +{ + if (!theres) + return; + + TypeIds tmps; + + if (!heres) + { + tmps.insert(theres->begin(), theres->end()); + heres = std::move(tmps); + return; + } + + for (TypeId here : *heres) + for (TypeId there : *theres) + { + if (std::optional fun = unionOfFunctions(here, there)) + tmps.insert(*fun); + else + tmps.insert(singletonTypes->errorRecoveryType(there)); + } + + heres = std::move(tmps); +} + +void Normalizer::unionFunctionsWithFunction(NormalizedFunctionType& heres, TypeId there) +{ + if (!heres) + { + TypeIds tmps; + tmps.insert(there); + heres = std::move(tmps); + return; + } + + TypeIds tmps; + for (TypeId here : *heres) + { + if (std::optional fun = unionOfFunctions(here, there)) + tmps.insert(*fun); + else + tmps.insert(singletonTypes->errorRecoveryType(there)); + } + heres = std::move(tmps); +} + +void Normalizer::unionTablesWithTable(TypeIds& heres, TypeId there) +{ + // TODO: remove unions of tables where possible + heres.insert(there); +} + +void Normalizer::unionTables(TypeIds& heres, const TypeIds& theres) +{ + for (TypeId there : theres) + unionTablesWithTable(heres, there); +} + +// So why `ignoreSmallerTyvars`? +// +// First up, what it does... Every tyvar has an index, and this parameter says to ignore +// any tyvars in `there` if their index is less than or equal to the parameter. +// The parameter is always greater than any tyvars mentioned in here, so the result is +// a lower bound on any tyvars in `here.tyvars`. +// +// This is used to maintain in invariant, which is that in any tyvar `X&T`, any any tyvar +// `Y&U` in `T`, the index of `X` is less than the index of `Y`. This is an implementation +// of *ordered decision diagrams* (https://en.wikipedia.org/wiki/Binary_decision_diagram#Variable_ordering) +// which are a compression technique used to save memory usage when representing boolean formulae. +// +// The idea is that if you have an out-of-order decision diagram +// like `Z&(X|Y)`, to re-order it in this case to `(X&Z)|(Y&Z)`. +// The hope is that by imposing a global order, there's a higher chance of sharing opportunities, +// and hence reduced memory. +// +// And yes, this is essentially a SAT solver hidden inside a typechecker. +// That's what you get for having a type system with generics, intersection and union types. +bool Normalizer::unionNormals(NormalizedType& here, const NormalizedType& there, int ignoreSmallerTyvars) +{ + TypeId tops = unionOfTops(here.tops, there.tops); + if (!get(tops)) + { + clearNormal(here); + here.tops = tops; + return true; + } + + for (auto it = there.tyvars.begin(); it != there.tyvars.end(); it++) + { + TypeId tyvar = it->first; + const NormalizedType& inter = *it->second; + int index = tyvarIndex(tyvar); + if (index <= ignoreSmallerTyvars) + continue; + auto [emplaced, fresh] = here.tyvars.emplace(tyvar, std::make_unique(NormalizedType{singletonTypes})); + if (fresh) + if (!unionNormals(*emplaced->second, here, index)) + return false; + if (!unionNormals(*emplaced->second, inter, index)) + return false; + } + + here.booleans = unionOfBools(here.booleans, there.booleans); + unionClasses(here.classes, there.classes); + here.errors = (get(there.errors) ? here.errors : there.errors); + here.nils = (get(there.nils) ? here.nils : there.nils); + here.numbers = (get(there.numbers) ? here.numbers : there.numbers); + unionStrings(here.strings, there.strings); + here.threads = (get(there.threads) ? here.threads : there.threads); + unionFunctions(here.functions, there.functions); + unionTables(here.tables, there.tables); + return true; +} + +bool Normalizer::withinResourceLimits() +{ + // If cache is too large, clear it + if (FInt::LuauNormalizeCacheLimit > 0) + { + size_t cacheUsage = cachedNormals.size() + cachedIntersections.size() + cachedUnions.size() + cachedTypeIds.size(); + if (cacheUsage > size_t(FInt::LuauNormalizeCacheLimit)) + { + clearCaches(); + return false; + } + } + + // Check the recursion count + if (sharedState->counters.recursionLimit > 0) + if (sharedState->counters.recursionLimit < sharedState->counters.recursionCount) + return false; + + return true; +} + +// See above for an explaination of `ignoreSmallerTyvars`. +bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, int ignoreSmallerTyvars) +{ + RecursionCounter _rc(&sharedState->counters.recursionCount); + if (!withinResourceLimits()) + return false; + + there = follow(there); + if (get(there) || get(there)) + { + TypeId tops = unionOfTops(here.tops, there); + clearNormal(here); + here.tops = tops; + return true; + } + else if (get(there) || !get(here.tops)) + return true; + else if (const UnionTypeVar* utv = get(there)) + { + for (UnionTypeVarIterator it = begin(utv); it != end(utv); ++it) + if (!unionNormalWithTy(here, *it)) + return false; + return true; + } + else if (const IntersectionTypeVar* itv = get(there)) + { + NormalizedType norm{singletonTypes}; + norm.tops = singletonTypes->anyType; + for (IntersectionTypeVarIterator it = begin(itv); it != end(itv); ++it) + if (!intersectNormalWithTy(norm, *it)) + return false; + return unionNormals(here, norm); + } + else if (get(there) || get(there)) + { + if (tyvarIndex(there) <= ignoreSmallerTyvars) + return true; + NormalizedType inter{singletonTypes}; + inter.tops = singletonTypes->unknownType; + here.tyvars.insert_or_assign(there, std::make_unique(std::move(inter))); + } + else if (get(there)) + unionFunctionsWithFunction(here.functions, there); + else if (get(there) || get(there)) + unionTablesWithTable(here.tables, there); + else if (get(there)) + unionClassesWithClass(here.classes, there); + else if (get(there)) + here.errors = there; + else if (const PrimitiveTypeVar* ptv = get(there)) + { + if (ptv->type == PrimitiveTypeVar::Boolean) + here.booleans = there; + else if (ptv->type == PrimitiveTypeVar::NilType) + here.nils = there; + else if (ptv->type == PrimitiveTypeVar::Number) + here.numbers = there; + else if (ptv->type == PrimitiveTypeVar::String) + here.strings = std::nullopt; + else if (ptv->type == PrimitiveTypeVar::Thread) + here.threads = there; + else + LUAU_ASSERT(!"Unreachable"); + } + else if (const SingletonTypeVar* stv = get(there)) + { + if (get(stv)) + here.booleans = unionOfBools(here.booleans, there); + else if (const StringSingleton* sstv = get(stv)) + { + if (here.strings) + here.strings->insert({sstv->value, there}); + } + else + LUAU_ASSERT(!"Unreachable"); + } + else + LUAU_ASSERT(!"Unreachable"); + + for (auto& [tyvar, intersect] : here.tyvars) + if (!unionNormalWithTy(*intersect, there, tyvarIndex(tyvar))) + return false; + + assertInvariant(here); + return true; +} + +// ------- Normalizing intersections +TypeId Normalizer::intersectionOfTops(TypeId here, TypeId there) +{ + if (get(here) || get(there)) + return here; + else + return there; +} + +TypeId Normalizer::intersectionOfBools(TypeId here, TypeId there) +{ + if (get(here)) + return here; + if (get(there)) + return there; + if (const BooleanSingleton* hbool = get(get(here))) + if (const BooleanSingleton* tbool = get(get(there))) + return (hbool->value == tbool->value ? here : singletonTypes->neverType); + else + return here; + else + return there; +} + +void Normalizer::intersectClasses(TypeIds& heres, const TypeIds& theres) +{ + TypeIds tmp; + for (auto it = heres.begin(); it != heres.end();) + { + const ClassTypeVar* hctv = get(*it); + LUAU_ASSERT(hctv); + bool keep = false; + for (TypeId there : theres) + { + const ClassTypeVar* tctv = get(there); + LUAU_ASSERT(tctv); + if (isSubclass(hctv, tctv)) + { + keep = true; + break; + } + else if (isSubclass(tctv, hctv)) + { + keep = false; + tmp.insert(there); + break; + } + } + if (keep) + it++; + else + it = heres.erase(it); + } + heres.insert(tmp.begin(), tmp.end()); +} + +void Normalizer::intersectClassesWithClass(TypeIds& heres, TypeId there) +{ + bool foundSuper = false; + const ClassTypeVar* tctv = get(there); + LUAU_ASSERT(tctv); + for (auto it = heres.begin(); it != heres.end();) + { + const ClassTypeVar* hctv = get(*it); + LUAU_ASSERT(hctv); + if (isSubclass(hctv, tctv)) + it++; + else if (isSubclass(tctv, hctv)) + { + foundSuper = true; + break; + } + else + it = heres.erase(it); + } + if (foundSuper) + { + heres.clear(); + heres.insert(there); + } +} + +void Normalizer::intersectStrings(NormalizedStringType& here, const NormalizedStringType& there) +{ + if (!there) + return; + if (!here) + here.emplace(); + + for (auto it = here->begin(); it != here->end();) + { + if (there->count(it->first)) + it++; + else + it = here->erase(it); + } +} + +std::optional Normalizer::intersectionOfTypePacks(TypePackId here, TypePackId there) +{ + if (here == there) + return here; + + std::vector head; + std::optional tail; + + bool hereSubThere = true; + bool thereSubHere = true; + + TypePackIterator ith = begin(here); + TypePackIterator itt = begin(there); + + while (ith != end(here) && itt != end(there)) + { + TypeId hty = *ith; + TypeId tty = *itt; + TypeId ty = intersectionType(hty, tty); + if (ty != hty) + hereSubThere = false; + if (ty != tty) + thereSubHere = false; + head.push_back(ty); + ith++; + itt++; + } + + auto dealWithDifferentArities = [&](TypePackIterator& ith, TypePackIterator itt, TypePackId here, TypePackId there, bool& hereSubThere, bool& thereSubHere) + { + if (ith != end(here)) + { + TypeId tty = singletonTypes->nilType; + if (std::optional ttail = itt.tail()) + { + if (const VariadicTypePack* tvtp = get(*ttail)) + tty = tvtp->ty; + else + // Luau doesn't have intersections of type pack variables + return false; + } + else + // Type packs of different arities are incomparable + return false; + + while (ith != end(here)) + { + TypeId hty = *ith; + TypeId ty = intersectionType(hty, tty); + if (ty != hty) + hereSubThere = false; + if (ty != tty) + thereSubHere = false; + head.push_back(ty); + ith++; + } + } + return true; + }; + + if (!dealWithDifferentArities(ith, itt, here, there, hereSubThere, thereSubHere)) + return std::nullopt; + + if (!dealWithDifferentArities(itt, ith, there, here, thereSubHere, hereSubThere)) + return std::nullopt; + + if (std::optional htail = ith.tail()) + { + if (std::optional ttail = itt.tail()) + { + if (*htail == *ttail) + tail = htail; + else if (const VariadicTypePack* hvtp = get(*htail)) + { + if (const VariadicTypePack* tvtp = get(*ttail)) + { + TypeId ty = intersectionType(hvtp->ty, tvtp->ty); + if (ty != hvtp->ty) + thereSubHere = false; + if (ty != tvtp->ty) + hereSubThere = false; + bool hidden = hvtp->hidden & tvtp->hidden; + tail = arena->addTypePack(VariadicTypePack{ty,hidden}); + } + else + // Luau doesn't have unions of type pack variables + return std::nullopt; + } + else + // Luau doesn't have unions of type pack variables + return std::nullopt; + } + else if (get(*htail)) + hereSubThere = false; + else + // Luau doesn't have unions of type pack variables + return std::nullopt; + } + else if (std::optional ttail = itt.tail()) + { + if (get(*ttail)) + thereSubHere = false; + else + // Luau doesn't have unions of type pack variables + return std::nullopt; + } + + if (hereSubThere) + return here; + else if (thereSubHere) + return there; + if (!head.empty()) + return arena->addTypePack(TypePack{head,tail}); + else if (tail) + return *tail; + else + // TODO: Add an emptyPack to singleton types + return arena->addTypePack({}); +} + +std::optional Normalizer::intersectionOfTables(TypeId here, TypeId there) +{ + if (here == there) + return here; + + RecursionCounter _rc(&sharedState->counters.recursionCount); + if (sharedState->counters.recursionLimit > 0 && sharedState->counters.recursionLimit < sharedState->counters.recursionCount) + return std::nullopt; + + TypeId htable = here; + TypeId hmtable = nullptr; + if (const MetatableTypeVar* hmtv = get(here)) + { + htable = hmtv->table; + hmtable = hmtv->metatable; + } + TypeId ttable = there; + TypeId tmtable = nullptr; + if (const MetatableTypeVar* tmtv = get(there)) + { + ttable = tmtv->table; + tmtable = tmtv->metatable; + } + + const TableTypeVar* httv = get(htable); + LUAU_ASSERT(httv); + const TableTypeVar* tttv = get(ttable); + LUAU_ASSERT(tttv); + + if (httv->state == TableState::Free || tttv->state == TableState::Free) + return std::nullopt; + if (httv->state == TableState::Generic || tttv->state == TableState::Generic) + return std::nullopt; + + TableState state = httv->state; + if (tttv->state == TableState::Unsealed) + state = tttv->state; + + TypeLevel level = max(httv->level, tttv->level); + TableTypeVar result{state, level}; + + bool hereSubThere = true; + bool thereSubHere = true; + + for (const auto& [name, hprop] : httv->props) + { + Property prop = hprop; + auto tfound = tttv->props.find(name); + if (tfound == tttv->props.end()) + thereSubHere = false; + else + { + const auto& [_name, tprop] = *tfound; + // TODO: variance issues here, which can't be fixed until we have read/write property types + prop.type = intersectionType(hprop.type, tprop.type); + hereSubThere &= (prop.type == hprop.type); + thereSubHere &= (prop.type == tprop.type); + } + // TODO: string indexers + result.props[name] = prop; + } + + for (const auto& [name, tprop] : tttv->props) + { + if (httv->props.count(name) == 0) + { + result.props[name] = tprop; + hereSubThere = false; + } + } + + if (httv->indexer && tttv->indexer) + { + // TODO: What should intersection of indexes be? + TypeId index = unionType(httv->indexer->indexType, tttv->indexer->indexType); + TypeId indexResult = intersectionType(httv->indexer->indexResultType, tttv->indexer->indexResultType); + result.indexer = {index, indexResult}; + hereSubThere &= (httv->indexer->indexType == index) && (httv->indexer->indexResultType == indexResult); + thereSubHere &= (tttv->indexer->indexType == index) && (tttv->indexer->indexResultType == indexResult); + } + else if (httv->indexer) + { + result.indexer = httv->indexer; + thereSubHere = false; + } + else if (tttv->indexer) + { + result.indexer = tttv->indexer; + hereSubThere = false; + } + + TypeId table; + if (hereSubThere) + table = htable; + else if (thereSubHere) + table = ttable; + else + table = arena->addType(std::move(result)); + + if (tmtable && hmtable) + { + // NOTE: this assumes metatables are ivariant + if (std::optional mtable = intersectionOfTables(hmtable, tmtable)) + { + if (table == htable && *mtable == hmtable) + return here; + else if (table == ttable && *mtable == tmtable) + return there; + else + return arena->addType(MetatableTypeVar{table, *mtable}); + } + else + return std::nullopt; + + } + else if (hmtable) + { + if (table == htable) + return here; + else + return arena->addType(MetatableTypeVar{table, hmtable}); + } + else if (tmtable) + { + if (table == ttable) + return there; + else + return arena->addType(MetatableTypeVar{table, tmtable}); + } + else + return table; +} + +void Normalizer::intersectTablesWithTable(TypeIds& heres, TypeId there) +{ + TypeIds tmp; + for (TypeId here : heres) + if (std::optional inter = intersectionOfTables(here, there)) + tmp.insert(*inter); + heres.retain(tmp); + heres.insert(tmp.begin(), tmp.end()); +} + +void Normalizer::intersectTables(TypeIds& heres, const TypeIds& theres) +{ + TypeIds tmp; + for (TypeId here : heres) + for (TypeId there : theres) + if (std::optional inter = intersectionOfTables(here, there)) + tmp.insert(*inter); + heres.retain(tmp); + heres.insert(tmp.begin(), tmp.end()); +} + +std::optional Normalizer::intersectionOfFunctions(TypeId here, TypeId there) +{ + const FunctionTypeVar* hftv = get(here); + LUAU_ASSERT(hftv); + const FunctionTypeVar* tftv = get(there); + LUAU_ASSERT(tftv); + + if (hftv->generics != tftv->generics) + return std::nullopt; + if (hftv->genericPacks != tftv->genericPacks) + return std::nullopt; + if (hftv->retTypes != tftv->retTypes) + return std::nullopt; + + std::optional argTypes = unionOfTypePacks(hftv->argTypes, tftv->argTypes); + if (!argTypes) + return std::nullopt; + + if (*argTypes == hftv->argTypes) + return here; + if (*argTypes == tftv->argTypes) + return there; + + FunctionTypeVar result{*argTypes, hftv->retTypes}; + result.generics = hftv->generics; + result.genericPacks = hftv->genericPacks; + return arena->addType(std::move(result)); +} + +std::optional Normalizer::unionSaturatedFunctions(TypeId here, TypeId there) +{ + // Deep breath... + // + // When we come to check overloaded functions for subtyping, + // we have to compare (F1 & ... & FM) <: (G1 & ... G GN) + // where each Fi or Gj is a function type. Now that intersection on the right is no + // problem, since that's true if and only if (F1 & ... & FM) <: Gj for every j. + // But the intersection on the left is annoying, since we might have + // (F1 & ... & FM) <: G but no Fi <: G. For example + // + // ((number? -> number?) & (string? -> string?)) <: (nil -> nil) + // + // So in this case, what we do is define Apply for the result of applying + // a function of type F to an argument of type T, and then F <: (T -> U) + // if and only if Apply <: U. For example: + // + // if f : ((number? -> number?) & (string? -> string?)) + // then f(nil) must be nil, so + // Apply<((number? -> number?) & (string? -> string?)), nil> is nil + // + // So subtyping on overloaded functions "just" boils down to defining Apply. + // + // Now for non-overloaded functions, this is easy! + // Apply<(R -> S), T> is S if T <: R, and an error type otherwise. + // + // But for overloaded functions it's not so simple. We'd like Apply + // to just be Apply & ... & Apply but oh dear + // + // if f : ((number -> number) & (string -> string)) + // and x : (number | string) + // then f(x) : (number | string) + // + // so we want + // + // Apply<((number -> number) & (string -> string)), (number | string)> is (number | string) + // + // but + // + // Apply<(number -> number), (number | string)> is an error + // Apply<(string -> string), (number | string)> is an error + // + // that is Apply should consider all possible combinations of overloads of F, + // not just individual overloads. + // + // For this reason, when we're normalizing function types (in order to check subtyping + // or perform overload resolution) we should first *union-saturate* them. An overloaded + // function is union-saturated whenever: + // + // if (R -> S) is an overload of F + // and (T -> U) is an overload of F + // then ((R | T) -> (S | U)) is a subtype of an overload of F + // + // Any overloaded function can be normalized to a union-saturated one by adding enough extra overloads. + // For example, union-saturating + // + // ((number -> number) & (string -> string)) + // + // is + // + // ((number -> number) & (string -> string) & ((number | string) -> (number | string))) + // + // For union-saturated overloaded functions, the "obvious" algorithm works: + // + // Apply is Apply & ... & Apply + // + // so we can define Apply, so we can perform overloaded function resolution + // and check subtyping on overloaded function types, yay! + // + // This is yet another potential source of exponential blow-up, sigh, since + // the union-saturation of a function with N overloads may have 2^N overloads + // (one for every subset). In practice, that hopefully won't happen that often, + // in particular we only union-saturate overloads with different return types, + // and there are hopefully not very many cases of that. + // + // All of this is mechanically verified in Agda, at https://github.com/luau-lang/agda-typeck + // + // It is essentially the algorithm defined in https://pnwamk.github.io/sst-tutorial/ + // except that we're precomputing the union-saturation rather than converting + // to disjunctive normal form on the fly. + // + // This is all built on semantic subtyping: + // + // Covariance and Contravariance, Giuseppe Castagna, + // Logical Methods in Computer Science 16(1), 2022 + // https://arxiv.org/abs/1809.01427 + // + // A gentle introduction to semantic subtyping, Giuseppe Castagna and Alain Frisch, + // Proc. Principles and practice of declarative programming 2005, pp 198–208 + // https://doi.org/10.1145/1069774.1069793 + + const FunctionTypeVar* hftv = get(here); + if (!hftv) + return std::nullopt; + const FunctionTypeVar* tftv = get(there); + if (!tftv) + return std::nullopt; + + if (hftv->generics != tftv->generics) + return std::nullopt; + if (hftv->genericPacks != tftv->genericPacks) + return std::nullopt; + + std::optional argTypes = unionOfTypePacks(hftv->argTypes, tftv->argTypes); + if (!argTypes) + return std::nullopt; + std::optional retTypes = unionOfTypePacks(hftv->retTypes, tftv->retTypes); + if (!retTypes) + return std::nullopt; + + FunctionTypeVar result{*argTypes, *retTypes}; + result.generics = hftv->generics; + result.genericPacks = hftv->genericPacks; + return arena->addType(std::move(result)); +} + +void Normalizer::intersectFunctionsWithFunction(NormalizedFunctionType& heres, TypeId there) +{ + if (!heres) + return; + + for (auto it = heres->begin(); it != heres->end();) + { + TypeId here = *it; + if (get(here)) + it++; + else if (std::optional tmp = intersectionOfFunctions(here, there)) + { + heres->erase(it); + heres->insert(*tmp); + return; + } + else + it++; + } + + TypeIds tmps; + for (TypeId here : *heres) + { + if (std::optional tmp = unionSaturatedFunctions(here, there)) + tmps.insert(*tmp); + } + heres->insert(there); + heres->insert(tmps.begin(), tmps.end()); +} + +void Normalizer::intersectFunctions(NormalizedFunctionType& heres, const NormalizedFunctionType& theres) +{ + if (!heres) + return; + else if (!theres) + { + heres = std::nullopt; + return; + } + else + { + for (TypeId there : *theres) + intersectFunctionsWithFunction(heres, there); + } +} + +bool Normalizer::intersectTyvarsWithTy(NormalizedTyvars& here, TypeId there) +{ + for (auto it = here.begin(); it != here.end();) + { + NormalizedType& inter = *it->second; + if (!intersectNormalWithTy(inter, there)) + return false; + if (isInhabited(inter)) + ++it; + else + it = here.erase(it); + } + return true; +} + +// See above for an explaination of `ignoreSmallerTyvars`. +bool Normalizer::intersectNormals(NormalizedType& here, const NormalizedType& there, int ignoreSmallerTyvars) +{ + if (!get(there.tops)) + { + here.tops = intersectionOfTops(here.tops, there.tops); + return true; + } + else if (!get(here.tops)) + { + clearNormal(here); + return unionNormals(here, there, ignoreSmallerTyvars); + } + + here.booleans = intersectionOfBools(here.booleans, there.booleans); + intersectClasses(here.classes, there.classes); + here.errors = (get(there.errors) ? there.errors : here.errors); + here.nils = (get(there.nils) ? there.nils : here.nils); + here.numbers = (get(there.numbers) ? there.numbers : here.numbers); + intersectStrings(here.strings, there.strings); + here.threads = (get(there.threads) ? there.threads : here.threads); + intersectFunctions(here.functions, there.functions); + intersectTables(here.tables, there.tables); + + for (auto& [tyvar, inter] : there.tyvars) + { + int index = tyvarIndex(tyvar); + if (ignoreSmallerTyvars < index) + { + auto [found, fresh] = here.tyvars.emplace(tyvar, std::make_unique(NormalizedType{singletonTypes})); + if (fresh) + { + if (!unionNormals(*found->second, here, index)) + return false; + } + } + } + for (auto it = here.tyvars.begin(); it != here.tyvars.end();) + { + TypeId tyvar = it->first; + NormalizedType& inter = *it->second; + int index = tyvarIndex(tyvar); + LUAU_ASSERT(ignoreSmallerTyvars < index); + auto found = there.tyvars.find(tyvar); + if (found == there.tyvars.end()) + { + if (!intersectNormals(inter, there, index)) + return false; + } + else + { + if (!intersectNormals(inter, *found->second, index)) + return false; + } + if (isInhabited(inter)) + it++; + else + it = here.tyvars.erase(it); + } + return true; +} + +bool Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there) +{ + RecursionCounter _rc(&sharedState->counters.recursionCount); + if (!withinResourceLimits()) + return false; + + there = follow(there); + if (get(there) || get(there)) + { + here.tops = intersectionOfTops(here.tops, there); + return true; + } + else if (!get(here.tops)) + { + clearNormal(here); + return unionNormalWithTy(here, there); + } + else if (const UnionTypeVar* utv = get(there)) + { + NormalizedType norm{singletonTypes}; + for (UnionTypeVarIterator it = begin(utv); it != end(utv); ++it) + if (!unionNormalWithTy(norm, *it)) + return false; + return intersectNormals(here, norm); + } + else if (const IntersectionTypeVar* itv = get(there)) + { + for (IntersectionTypeVarIterator it = begin(itv); it != end(itv); ++it) + if (!intersectNormalWithTy(here, *it)) + return false; + return true; + } + else if (get(there) || get(there)) + { + NormalizedType thereNorm{singletonTypes}; + NormalizedType topNorm{singletonTypes}; + topNorm.tops = singletonTypes->unknownType; + thereNorm.tyvars.insert_or_assign(there, std::make_unique(std::move(topNorm))); + return intersectNormals(here, thereNorm); + } + + NormalizedTyvars tyvars = std::move(here.tyvars); + + if (const FunctionTypeVar* utv = get(there)) + { + NormalizedFunctionType functions = std::move(here.functions); + clearNormal(here); + intersectFunctionsWithFunction(functions, there); + here.functions = std::move(functions); + } + else if (get(there) || get(there)) + { + TypeIds tables = std::move(here.tables); + clearNormal(here); + intersectTablesWithTable(tables, there); + here.tables = std::move(tables); + } + else if (get(there)) + { + TypeIds classes = std::move(here.classes); + clearNormal(here); + intersectClassesWithClass(classes, there); + here.classes = std::move(classes); + } + else if (get(there)) + { + TypeId errors = here.errors; + clearNormal(here); + here.errors = errors; + } + else if (const PrimitiveTypeVar* ptv = get(there)) + { + TypeId booleans = here.booleans; + TypeId nils = here.nils; + TypeId numbers = here.numbers; + NormalizedStringType strings = std::move(here.strings); + TypeId threads = here.threads; + + clearNormal(here); + + if (ptv->type == PrimitiveTypeVar::Boolean) + here.booleans = booleans; + else if (ptv->type == PrimitiveTypeVar::NilType) + here.nils = nils; + else if (ptv->type == PrimitiveTypeVar::Number) + here.numbers = numbers; + else if (ptv->type == PrimitiveTypeVar::String) + here.strings = std::move(strings); + else if (ptv->type == PrimitiveTypeVar::Thread) + here.threads = threads; + else + LUAU_ASSERT(!"Unreachable"); + } + else if (const SingletonTypeVar* stv = get(there)) + { + TypeId booleans = here.booleans; + NormalizedStringType strings = std::move(here.strings); + + clearNormal(here); + + if (get(stv)) + here.booleans = intersectionOfBools(booleans, there); + else if (const StringSingleton* sstv = get(stv)) + { + if (!strings || strings->count(sstv->value)) + here.strings->insert({sstv->value, there}); + } + else + LUAU_ASSERT(!"Unreachable"); + } + else + LUAU_ASSERT(!"Unreachable"); + + if (!intersectTyvarsWithTy(tyvars, there)) + return false; + here.tyvars = std::move(tyvars); + + return true; +} + +// -------- Convert back from a normalized type to a type +TypeId Normalizer::typeFromNormal(const NormalizedType& norm) +{ + assertInvariant(norm); + if (!get(norm.tops)) + return norm.tops; + + std::vector result; + + if (!get(norm.booleans)) + result.push_back(norm.booleans); + result.insert(result.end(), norm.classes.begin(), norm.classes.end()); + if (!get(norm.errors)) + result.push_back(norm.errors); + if (norm.functions) + { + if (norm.functions->size() == 1) + result.push_back(*norm.functions->begin()); + else + { + std::vector parts; + parts.insert(parts.end(), norm.functions->begin(), norm.functions->end()); + result.push_back(arena->addType(IntersectionTypeVar{std::move(parts)})); + } + } + if (!get(norm.nils)) + result.push_back(norm.nils); + if (!get(norm.numbers)) + result.push_back(norm.numbers); + if (norm.strings) + for (auto& [_, ty] : *norm.strings) + result.push_back(ty); + else + result.push_back(singletonTypes->stringType); + result.insert(result.end(), norm.tables.begin(), norm.tables.end()); + for (auto& [tyvar, intersect] : norm.tyvars) + { + if (get(intersect->tops)) + { + TypeId ty = typeFromNormal(*intersect); + result.push_back(arena->addType(IntersectionTypeVar{{tyvar, ty}})); + } + else + result.push_back(tyvar); + } + + if (result.size() == 0) + return singletonTypes->neverType; + else if (result.size() == 1) + return result[0]; + else + return arena->addType(UnionTypeVar{std::move(result)}); +} namespace { @@ -59,7 +1748,8 @@ bool isSubtype(TypeId subTy, TypeId superTy, NotNull scope, NotNull scope, N { UnifierSharedState sharedState{&ice}; TypeArena arena; - Unifier u{&arena, singletonTypes, Mode::Strict, scope, Location{}, Covariant, sharedState}; + Normalizer normalizer{&arena, singletonTypes, NotNull{&sharedState}}; + Unifier u{NotNull{&normalizer}, Mode::Strict, scope, Location{}, Covariant}; u.anyIsTop = anyIsTop; u.tryUnify(subPack, superPack); @@ -686,3 +2377,4 @@ std::pair normalize(TypePackId tp, const ModulePtr& module, No } } // namespace Luau + diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index f98a2123e..b2f3cfd3f 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -280,7 +280,8 @@ struct TypeChecker2 TypePackId actualRetType = reconstructPack(ret->list, arena); UnifierSharedState sharedState{&ice}; - Unifier u{&arena, singletonTypes, Mode::Strict, stack.back(), ret->location, Covariant, sharedState}; + Normalizer normalizer{&arena, singletonTypes, NotNull{&sharedState}}; + Unifier u{NotNull{&normalizer}, Mode::Strict, stack.back(), ret->location, Covariant}; u.anyIsTop = true; u.tryUnify(actualRetType, expectedRetType); @@ -1206,7 +1207,8 @@ struct TypeChecker2 ErrorVec tryUnify(NotNull scope, const Location& location, TID subTy, TID superTy) { UnifierSharedState sharedState{&ice}; - Unifier u{&module->internalTypes, singletonTypes, Mode::Strict, scope, location, Covariant, sharedState}; + Normalizer normalizer{&module->internalTypes, singletonTypes, NotNull{&sharedState}}; + Unifier u{NotNull{&normalizer}, Mode::Strict, scope, location, Covariant}; u.anyIsTop = true; u.tryUnify(subTy, superTy); diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index b96046be7..cb21aa7fe 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -32,19 +32,21 @@ LUAU_FASTINTVARIABLE(LuauCheckRecursionLimit, 300) LUAU_FASTINTVARIABLE(LuauVisitRecursionLimit, 500) LUAU_FASTFLAG(LuauKnowsTheDataModel3) LUAU_FASTFLAG(LuauAutocompleteDynamicLimits) +LUAU_FASTFLAG(LuauTypeNormalization2) LUAU_FASTFLAGVARIABLE(LuauFunctionArgMismatchDetails, false) -LUAU_FASTFLAGVARIABLE(LuauInplaceDemoteSkipAllBound, false) LUAU_FASTFLAGVARIABLE(LuauLowerBoundsCalculation, false) LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) -LUAU_FASTFLAGVARIABLE(LuauSelfCallAutocompleteFix3, false) LUAU_FASTFLAGVARIABLE(LuauReturnAnyInsteadOfICE, false) // Eventually removed as false. LUAU_FASTFLAGVARIABLE(DebugLuauSharedSelf, false) +LUAU_FASTFLAGVARIABLE(LuauAnyifyModuleReturnGenerics, false) LUAU_FASTFLAGVARIABLE(LuauUnknownAndNeverType, false) LUAU_FASTFLAGVARIABLE(LuauCallUnifyPackTails, false) LUAU_FASTFLAGVARIABLE(LuauCheckGenericHOFTypes, false) LUAU_FASTFLAGVARIABLE(LuauBinaryNeedsExpectedTypesToo, false) +LUAU_FASTFLAGVARIABLE(LuauFixVarargExprHeadType, false) LUAU_FASTFLAGVARIABLE(LuauNeverTypesAndOperatorsInference, false) LUAU_FASTFLAGVARIABLE(LuauReturnsFromCallsitesAreNotWidened, false) +LUAU_FASTFLAG(LuauInstantiateInSubtyping) LUAU_FASTFLAGVARIABLE(LuauCompleteVisitor, false) LUAU_FASTFLAGVARIABLE(LuauUnionOfTypesFollow, false) LUAU_FASTFLAGVARIABLE(LuauReportShadowedTypeAlias, false) @@ -255,6 +257,7 @@ TypeChecker::TypeChecker(ModuleResolver* resolver, NotNull singl , singletonTypes(singletonTypes) , iceHandler(iceHandler) , unifierState(iceHandler) + , normalizer(nullptr, singletonTypes, NotNull{&unifierState}) , nilType(singletonTypes->nilType) , numberType(singletonTypes->numberType) , stringType(singletonTypes->stringType) @@ -301,12 +304,13 @@ ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mo LUAU_TIMETRACE_SCOPE("TypeChecker::check", "TypeChecker"); LUAU_TIMETRACE_ARGUMENT("module", module.name.c_str()); - currentModule.reset(new Module()); + currentModule.reset(new Module); currentModule->type = module.type; currentModule->allocator = module.allocator; currentModule->names = module.names; iceHandler->moduleName = module.name; + normalizer.arena = ¤tModule->internalTypes; if (FFlag::LuauAutocompleteDynamicLimits) { @@ -351,15 +355,23 @@ ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mo if (get(follow(moduleScope->returnType))) moduleScope->returnType = addTypePack(TypePack{{}, std::nullopt}); else - { moduleScope->returnType = anyify(moduleScope, moduleScope->returnType, Location{}); - } + + if (FFlag::LuauAnyifyModuleReturnGenerics) + moduleScope->returnType = anyifyModuleReturnTypePackGenerics(moduleScope->returnType); for (auto& [_, typeFun] : moduleScope->exportedTypeBindings) typeFun.type = anyify(moduleScope, typeFun.type, Location{}); prepareErrorsForDisplay(currentModule->errors); + if (FFlag::LuauTypeNormalization2) + { + // Clear the normalizer caches, since they contain types from the internal type surface + normalizer.clearCaches(); + normalizer.arena = nullptr; + } + currentModule->clonePublicInterface(singletonTypes, *iceHandler); // Clear unifier cache since it's keyed off internal types that get deallocated @@ -474,7 +486,7 @@ struct InplaceDemoter : TypeVarOnceVisitor TypeArena* arena; InplaceDemoter(TypeLevel level, TypeArena* arena) - : TypeVarOnceVisitor(/* skipBoundTypes= */ FFlag::LuauInplaceDemoteSkipAllBound) + : TypeVarOnceVisitor(/* skipBoundTypes= */ true) , newLevel(level) , arena(arena) { @@ -494,12 +506,6 @@ struct InplaceDemoter : TypeVarOnceVisitor return false; } - bool visit(TypeId ty, const BoundTypeVar& btyRef) override - { - LUAU_ASSERT(!FFlag::LuauInplaceDemoteSkipAllBound); - return true; - } - bool visit(TypeId ty) override { if (ty->owningArena != arena) @@ -1029,8 +1035,11 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatAssign& assign) if (right) { - if (!maybeGeneric(left) && isGeneric(right)) - right = instantiate(scope, right, loc); + if (!FFlag::LuauInstantiateInSubtyping) + { + if (!maybeGeneric(left) && isGeneric(right)) + right = instantiate(scope, right, loc); + } // Setting a table entry to nil doesn't mean nil is the type of the indexer, it is just deleting the entry const TableTypeVar* destTableTypeReceivingNil = nullptr; @@ -1104,7 +1113,9 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatLocal& local) variableTypes.push_back(ty); expectedTypes.push_back(ty); - instantiateGenerics.push_back(annotation != nullptr && !maybeGeneric(ty)); + // with FFlag::LuauInstantiateInSubtyping enabled, we shouldn't need to produce instantiateGenerics at all. + if (!FFlag::LuauInstantiateInSubtyping) + instantiateGenerics.push_back(annotation != nullptr && !maybeGeneric(ty)); } if (local.values.size > 0) @@ -1729,9 +1740,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& declar { ftv->argNames.insert(ftv->argNames.begin(), FunctionArgument{"self", {}}); ftv->argTypes = addTypePack(TypePack{{classTy}, ftv->argTypes}); - - if (FFlag::LuauSelfCallAutocompleteFix3) - ftv->hasSelf = true; + ftv->hasSelf = true; } } @@ -1905,8 +1914,18 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp if (get(varargPack)) { - std::vector types = flatten(varargPack).first; - return {!types.empty() ? types[0] : nilType}; + if (FFlag::LuauFixVarargExprHeadType) + { + if (std::optional ty = first(varargPack)) + return {*ty}; + + return {nilType}; + } + else + { + std::vector types = flatten(varargPack).first; + return {!types.empty() ? types[0] : nilType}; + } } else if (get(varargPack)) { @@ -3967,7 +3986,10 @@ void TypeChecker::checkArgumentList(const ScopePtr& scope, const AstExpr& funNam } else { - unifyWithInstantiationIfNeeded(*argIter, *paramIter, scope, state); + if (FFlag::LuauInstantiateInSubtyping) + state.tryUnify(*argIter, *paramIter, /*isFunctionCall*/ false); + else + unifyWithInstantiationIfNeeded(*argIter, *paramIter, scope, state); ++argIter; ++paramIter; } @@ -4523,8 +4545,11 @@ WithPredicate TypeChecker::checkExprList(const ScopePtr& scope, cons TypeId actualType = substituteFreeForNil && expr->is() ? freshType(scope) : type; - if (instantiateGenerics.size() > i && instantiateGenerics[i]) - actualType = instantiate(scope, actualType, expr->location); + if (!FFlag::LuauInstantiateInSubtyping) + { + if (instantiateGenerics.size() > i && instantiateGenerics[i]) + actualType = instantiate(scope, actualType, expr->location); + } if (expectedType) { @@ -4686,6 +4711,8 @@ bool TypeChecker::unifyWithInstantiationIfNeeded(TypeId subTy, TypeId superTy, c void TypeChecker::unifyWithInstantiationIfNeeded(TypeId subTy, TypeId superTy, const ScopePtr& scope, Unifier& state) { + LUAU_ASSERT(!FFlag::LuauInstantiateInSubtyping); + if (!maybeGeneric(subTy)) // Quick check to see if we definitely can't instantiate state.tryUnify(subTy, superTy, /*isFunctionCall*/ false); @@ -4828,6 +4855,33 @@ TypePackId TypeChecker::anyify(const ScopePtr& scope, TypePackId ty, Location lo } } +TypePackId TypeChecker::anyifyModuleReturnTypePackGenerics(TypePackId tp) +{ + tp = follow(tp); + + if (const VariadicTypePack* vtp = get(tp)) + return get(vtp->ty) ? anyTypePack : tp; + + if (!get(follow(tp))) + return tp; + + std::vector resultTypes; + std::optional resultTail; + + TypePackIterator it = begin(tp); + + for (TypePackIterator e = end(tp); it != e; ++it) + { + TypeId ty = follow(*it); + resultTypes.push_back(get(ty) ? anyType : ty); + } + + if (std::optional tail = it.tail()) + resultTail = anyifyModuleReturnTypePackGenerics(*tail); + + return addTypePack(resultTypes, resultTail); +} + void TypeChecker::reportError(const TypeError& error) { if (currentModule->mode == Mode::NoCheck) @@ -4955,8 +5009,7 @@ void TypeChecker::merge(RefinementMap& l, const RefinementMap& r) Unifier TypeChecker::mkUnifier(const ScopePtr& scope, const Location& location) { - return Unifier{ - ¤tModule->internalTypes, singletonTypes, currentModule->mode, NotNull{scope.get()}, location, Variance::Covariant, unifierState}; + return Unifier{NotNull{&normalizer}, currentModule->mode, NotNull{scope.get()}, location, Variance::Covariant}; } TypeId TypeChecker::freshType(const ScopePtr& scope) diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index bf6bf34a7..b143268e3 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -27,6 +27,7 @@ LUAU_FASTFLAG(LuauUnknownAndNeverType) LUAU_FASTFLAGVARIABLE(LuauMaybeGenericIntersectionTypes, false) LUAU_FASTFLAGVARIABLE(LuauStringFormatArgumentErrorFix, false) LUAU_FASTFLAGVARIABLE(LuauNoMoreGlobalSingletonTypes, false) +LUAU_FASTFLAG(LuauInstantiateInSubtyping) namespace Luau { @@ -339,6 +340,8 @@ bool isSubset(const UnionTypeVar& super, const UnionTypeVar& sub) // then instantiate U if `isGeneric(U)` is true, and `maybeGeneric(T)` is false. bool isGeneric(TypeId ty) { + LUAU_ASSERT(!FFlag::LuauInstantiateInSubtyping); + ty = follow(ty); if (auto ftv = get(ty)) return ftv->generics.size() > 0 || ftv->genericPacks.size() > 0; @@ -350,6 +353,8 @@ bool isGeneric(TypeId ty) bool maybeGeneric(TypeId ty) { + LUAU_ASSERT(!FFlag::LuauInstantiateInSubtyping); + if (FFlag::LuauMaybeGenericIntersectionTypes) { ty = follow(ty); diff --git a/Analysis/src/Unifiable.cpp b/Analysis/src/Unifiable.cpp index a3d4540cb..e0cc14149 100644 --- a/Analysis/src/Unifiable.cpp +++ b/Analysis/src/Unifiable.cpp @@ -1,60 +1,64 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Unifiable.h" +LUAU_FASTFLAG(LuauTypeNormalization2); + namespace Luau { namespace Unifiable { +static int nextIndex = 0; + Free::Free(TypeLevel level) - : index(++nextIndex) + : index(FFlag::LuauTypeNormalization2 ? ++nextIndex : ++DEPRECATED_nextIndex) , level(level) { } Free::Free(Scope* scope) - : index(++nextIndex) + : index(FFlag::LuauTypeNormalization2 ? ++nextIndex : ++DEPRECATED_nextIndex) , scope(scope) { } Free::Free(Scope* scope, TypeLevel level) - : index(++nextIndex) + : index(FFlag::LuauTypeNormalization2 ? ++nextIndex : ++DEPRECATED_nextIndex) , level(level) , scope(scope) { } -int Free::nextIndex = 0; +int Free::DEPRECATED_nextIndex = 0; Generic::Generic() - : index(++nextIndex) + : index(FFlag::LuauTypeNormalization2 ? ++nextIndex : ++DEPRECATED_nextIndex) , name("g" + std::to_string(index)) { } Generic::Generic(TypeLevel level) - : index(++nextIndex) + : index(FFlag::LuauTypeNormalization2 ? ++nextIndex : ++DEPRECATED_nextIndex) , level(level) , name("g" + std::to_string(index)) { } Generic::Generic(const Name& name) - : index(++nextIndex) + : index(FFlag::LuauTypeNormalization2 ? ++nextIndex : ++DEPRECATED_nextIndex) , name(name) , explicitName(true) { } Generic::Generic(Scope* scope) - : index(++nextIndex) + : index(FFlag::LuauTypeNormalization2 ? ++nextIndex : ++DEPRECATED_nextIndex) , scope(scope) { } Generic::Generic(TypeLevel level, const Name& name) - : index(++nextIndex) + : index(FFlag::LuauTypeNormalization2 ? ++nextIndex : ++DEPRECATED_nextIndex) , level(level) , name(name) , explicitName(true) @@ -62,14 +66,14 @@ Generic::Generic(TypeLevel level, const Name& name) } Generic::Generic(Scope* scope, const Name& name) - : index(++nextIndex) + : index(FFlag::LuauTypeNormalization2 ? ++nextIndex : ++DEPRECATED_nextIndex) , scope(scope) , name(name) , explicitName(true) { } -int Generic::nextIndex = 0; +int Generic::DEPRECATED_nextIndex = 0; Error::Error() : index(++nextIndex) diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index c13a6f8b5..5a01c9348 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -2,6 +2,7 @@ #include "Luau/Unifier.h" #include "Luau/Common.h" +#include "Luau/Instantiation.h" #include "Luau/RecursionCounter.h" #include "Luau/Scope.h" #include "Luau/TypePack.h" @@ -20,7 +21,9 @@ LUAU_FASTINTVARIABLE(LuauTypeInferLowerBoundsIterationLimit, 2000); LUAU_FASTFLAG(LuauLowerBoundsCalculation); LUAU_FASTFLAG(LuauErrorRecoveryType); LUAU_FASTFLAG(LuauUnknownAndNeverType) +LUAU_FASTFLAGVARIABLE(LuauSubtypeNormalizer, false); LUAU_FASTFLAGVARIABLE(LuauScalarShapeSubtyping, false) +LUAU_FASTFLAGVARIABLE(LuauInstantiateInSubtyping, false) LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution) LUAU_FASTFLAG(LuauCallUnifyPackTails) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) @@ -343,17 +346,19 @@ static bool subsumes(bool useScopes, TY_A* left, TY_B* right) return left->level.subsumes(right->level); } -Unifier::Unifier(TypeArena* types, NotNull singletonTypes, Mode mode, NotNull scope, const Location& location, - Variance variance, UnifierSharedState& sharedState, TxnLog* parentLog) - : types(types) - , singletonTypes(singletonTypes) +Unifier::Unifier(NotNull normalizer, Mode mode, NotNull scope, const Location& location, + Variance variance, TxnLog* parentLog) + : types(normalizer->arena) + , singletonTypes(normalizer->singletonTypes) + , normalizer(normalizer) , mode(mode) , scope(scope) , log(parentLog) , location(location) , variance(variance) - , sharedState(sharedState) + , sharedState(*normalizer->sharedState) { + normalize = FFlag::LuauSubtypeNormalizer; LUAU_ASSERT(sharedState.iceHandler); } @@ -524,7 +529,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool { tryUnifyUnionWithType(subTy, subUnion, superTy); } - else if (const UnionTypeVar* uv = log.getMutable(superTy)) + else if (const UnionTypeVar* uv = (FFlag::LuauSubtypeNormalizer? nullptr: log.getMutable(superTy))) { tryUnifyTypeWithUnion(subTy, superTy, uv, cacheEnabled, isFunctionCall); } @@ -532,6 +537,10 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool { tryUnifyTypeWithIntersection(subTy, superTy, uv); } + else if (const UnionTypeVar* uv = log.getMutable(superTy)) + { + tryUnifyTypeWithUnion(subTy, superTy, uv, cacheEnabled, isFunctionCall); + } else if (const IntersectionTypeVar* uv = log.getMutable(subTy)) { tryUnifyIntersectionWithType(subTy, uv, superTy, cacheEnabled, isFunctionCall); @@ -585,7 +594,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool void Unifier::tryUnifyUnionWithType(TypeId subTy, const UnionTypeVar* subUnion, TypeId superTy) { - // A | B <: T if A <: T and B <: T + // A | B <: T if and only if A <: T and B <: T bool failed = false; std::optional unificationTooComplex; std::optional firstFailedOption; @@ -715,6 +724,7 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp { TypeId type = uv->options[(i + startIndex) % uv->options.size()]; Unifier innerState = makeChildUnifier(); + innerState.normalize = false; innerState.tryUnify_(subTy, type, isFunctionCall); if (innerState.errors.empty()) @@ -741,6 +751,20 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp { reportError(*unificationTooComplex); } + else if (!found && normalize) + { + // It is possible that T <: A | B even though T normalize(subTy); + const NormalizedType* superNorm = normalizer->normalize(superTy); + if (!subNorm || !superNorm) + reportError(TypeError{location, UnificationTooComplex{}}); + else if ((failedOptionCount == 1 || foundHeuristic) && failedOption) + tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "None of the union options are compatible. For example:", *failedOption); + else + tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the union options are compatible"); + } else if (!found) { if ((failedOptionCount == 1 || foundHeuristic) && failedOption) @@ -755,7 +779,7 @@ void Unifier::tryUnifyTypeWithIntersection(TypeId subTy, TypeId superTy, const I std::optional unificationTooComplex; std::optional firstFailedOption; - // T <: A & B if T <: A and T <: B + // T <: A & B if and only if T <: A and T <: B for (TypeId type : uv->parts) { Unifier innerState = makeChildUnifier(); @@ -806,6 +830,7 @@ void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionTypeV { TypeId type = uv->parts[(i + startIndex) % uv->parts.size()]; Unifier innerState = makeChildUnifier(); + innerState.normalize = false; innerState.tryUnify_(type, superTy, isFunctionCall); if (innerState.errors.empty()) @@ -822,12 +847,207 @@ void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionTypeV if (unificationTooComplex) reportError(*unificationTooComplex); + else if (!found && normalize) + { + // It is possible that A & B <: T even though A normalize(subTy); + const NormalizedType* superNorm = normalizer->normalize(superTy); + if (subNorm && superNorm) + tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the intersection parts are compatible"); + else + reportError(TypeError{location, UnificationTooComplex{}}); + } else if (!found) { reportError(TypeError{location, TypeMismatch{superTy, subTy, "none of the intersection parts are compatible"}}); } } +void Unifier::tryUnifyNormalizedTypes(TypeId subTy, TypeId superTy, const NormalizedType& subNorm, const NormalizedType& superNorm, std::string reason, std::optional error) +{ + LUAU_ASSERT(FFlag::LuauSubtypeNormalizer); + + if (get(superNorm.tops) || get(superNorm.tops) || get(subNorm.tops)) + return; + else if (get(subNorm.tops)) + return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); + + if (get(subNorm.errors)) + if (!get(superNorm.errors)) + return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); + + if (get(subNorm.booleans)) + { + if (!get(superNorm.booleans)) + return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); + } + else if (const SingletonTypeVar* stv = get(subNorm.booleans)) + { + if (!get(superNorm.booleans) && stv != get(superNorm.booleans)) + return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); + } + + if (get(subNorm.nils)) + if (!get(superNorm.nils)) + return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); + + if (get(subNorm.numbers)) + if (!get(superNorm.numbers)) + return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); + + if (subNorm.strings && superNorm.strings) + { + for (auto [name, ty] : *subNorm.strings) + if (!superNorm.strings->count(name)) + return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); + } + else if (!subNorm.strings && superNorm.strings) + return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); + + if (get(subNorm.threads)) + if (!get(superNorm.errors)) + return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); + + for (TypeId subClass : subNorm.classes) + { + bool found = false; + const ClassTypeVar* subCtv = get(subClass); + for (TypeId superClass : superNorm.classes) + { + const ClassTypeVar* superCtv = get(superClass); + if (isSubclass(subCtv, superCtv)) + { + found = true; + break; + } + } + if (!found) + return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); + } + + for (TypeId subTable : subNorm.tables) + { + bool found = false; + for (TypeId superTable : superNorm.tables) + { + Unifier innerState = makeChildUnifier(); + if (get(superTable)) + innerState.tryUnifyWithMetatable(subTable, superTable, /* reversed */ false); + else if (get(subTable)) + innerState.tryUnifyWithMetatable(superTable, subTable, /* reversed */ true); + else + innerState.tryUnifyTables(subTable, superTable); + if (innerState.errors.empty()) + { + found = true; + log.concat(std::move(innerState.log)); + break; + } + else if (auto e = hasUnificationTooComplex(innerState.errors)) + return reportError(*e); + } + if (!found) + return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); + } + + if (subNorm.functions) + { + if (!superNorm.functions) + return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); + if (superNorm.functions->empty()) + return; + for (TypeId superFun : *superNorm.functions) + { + Unifier innerState = makeChildUnifier(); + const FunctionTypeVar* superFtv = get(superFun); + if (!superFtv) + return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); + TypePackId tgt = innerState.tryApplyOverloadedFunction(subTy, subNorm.functions, superFtv->argTypes); + innerState.tryUnify_(tgt, superFtv->retTypes); + if (innerState.errors.empty()) + log.concat(std::move(innerState.log)); + else if (auto e = hasUnificationTooComplex(innerState.errors)) + return reportError(*e); + else + return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); + } + } + + for (auto& [tyvar, subIntersect] : subNorm.tyvars) + { + auto found = superNorm.tyvars.find(tyvar); + if (found == superNorm.tyvars.end()) + tryUnifyNormalizedTypes(subTy, superTy, *subIntersect, superNorm, reason, error); + else + tryUnifyNormalizedTypes(subTy, superTy, *subIntersect, *found->second, reason, error); + if (!errors.empty()) + return; + } +} + +TypePackId Unifier::tryApplyOverloadedFunction(TypeId function, const NormalizedFunctionType& overloads, TypePackId args) +{ + if (!overloads || overloads->empty()) + { + reportError(TypeError{location, CannotCallNonFunction{function}}); + return singletonTypes->errorRecoveryTypePack(); + } + + std::optional result; + const FunctionTypeVar* firstFun = nullptr; + for (TypeId overload : *overloads) + { + if (const FunctionTypeVar* ftv = get(overload)) + { + // TODO: instantiate generics? + if (ftv->generics.empty() && ftv->genericPacks.empty()) + { + if (!firstFun) + firstFun = ftv; + Unifier innerState = makeChildUnifier(); + innerState.tryUnify_(args, ftv->argTypes); + if (innerState.errors.empty()) + { + log.concat(std::move(innerState.log)); + if (result) + { + // Annoyingly, since we don't support intersection of generic type packs, + // the intersection may fail. We rather arbitrarily use the first matching overload + // in that case. + if (std::optional intersect = normalizer->intersectionOfTypePacks(*result, ftv->retTypes)) + result = intersect; + } + else + result = ftv->retTypes; + } + else if (auto e = hasUnificationTooComplex(innerState.errors)) + { + reportError(*e); + return singletonTypes->errorRecoveryTypePack(args); + } + } + } + } + + if (result) + return *result; + else if (firstFun) + { + // TODO: better error reporting? + // The logic for error reporting overload resolution + // is currently over in TypeInfer.cpp, should we move it? + reportError(TypeError{location, GenericError{"No matching overload."}}); + return singletonTypes->errorRecoveryTypePack(firstFun->retTypes); + } + else + { + reportError(TypeError{location, CannotCallNonFunction{function}}); + return singletonTypes->errorRecoveryTypePack(); + } +} + bool Unifier::canCacheResult(TypeId subTy, TypeId superTy) { bool* superTyInfo = sharedState.skipCacheForType.find(superTy); @@ -1253,14 +1473,38 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal ice("passed non-function types to unifyFunction"); size_t numGenerics = superFunction->generics.size(); - if (numGenerics != subFunction->generics.size()) + size_t numGenericPacks = superFunction->genericPacks.size(); + + bool shouldInstantiate = (numGenerics == 0 && subFunction->generics.size() > 0) || (numGenericPacks == 0 && subFunction->genericPacks.size() > 0); + + if (FFlag::LuauInstantiateInSubtyping && variance == Covariant && shouldInstantiate) + { + Instantiation instantiation{&log, types, scope->level, scope}; + + std::optional instantiated = instantiation.substitute(subTy); + if (instantiated.has_value()) + { + subFunction = log.getMutable(*instantiated); + + if (!subFunction) + ice("instantiation made a function type into a non-function type in unifyFunction"); + + numGenerics = std::min(superFunction->generics.size(), subFunction->generics.size()); + numGenericPacks = std::min(superFunction->genericPacks.size(), subFunction->genericPacks.size()); + + } + else + { + reportError(TypeError{location, UnificationTooComplex{}}); + } + } + else if (numGenerics != subFunction->generics.size()) { numGenerics = std::min(superFunction->generics.size(), subFunction->generics.size()); reportError(TypeError{location, TypeMismatch{superTy, subTy, "different number of generic type parameters"}}); } - size_t numGenericPacks = superFunction->genericPacks.size(); if (numGenericPacks != subFunction->genericPacks.size()) { numGenericPacks = std::min(superFunction->genericPacks.size(), subFunction->genericPacks.size()); @@ -1376,6 +1620,27 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) std::vector missingProperties; std::vector extraProperties; + if (FFlag::LuauInstantiateInSubtyping) + { + if (variance == Covariant && subTable->state == TableState::Generic && superTable->state != TableState::Generic) + { + Instantiation instantiation{&log, types, subTable->level, scope}; + + std::optional instantiated = instantiation.substitute(subTy); + if (instantiated.has_value()) + { + subTable = log.getMutable(*instantiated); + + if (!subTable) + ice("instantiation made a table type into a non-table type in tryUnifyTables"); + } + else + { + reportError(TypeError{location, UnificationTooComplex{}}); + } + } + } + // Optimization: First test that the property sets are compatible without doing any recursive unification if (!subTable->indexer && subTable->state != TableState::Free) { @@ -2344,8 +2609,9 @@ bool Unifier::occursCheck(DenseHashSet& seen, TypePackId needle, Typ Unifier Unifier::makeChildUnifier() { - Unifier u = Unifier{types, singletonTypes, mode, scope, location, variance, sharedState, &log}; + Unifier u = Unifier{normalizer, mode, scope, location, variance, &log}; u.anyIsTop = anyIsTop; + u.normalize = normalize; return u; } diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index cf3eaaaea..c20c08471 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -25,6 +25,8 @@ LUAU_DYNAMIC_FASTFLAGVARIABLE(LuaReportParseIntegerIssues, false) LUAU_FASTFLAGVARIABLE(LuauInterpolatedStringBaseSupport, false) LUAU_FASTFLAGVARIABLE(LuauTypeAnnotationLocationChange, false) +LUAU_FASTFLAGVARIABLE(LuauCommaParenWarnings, false) + bool lua_telemetry_parsed_out_of_range_bin_integer = false; bool lua_telemetry_parsed_out_of_range_hex_integer = false; bool lua_telemetry_parsed_double_prefix_hex_integer = false; @@ -1062,6 +1064,12 @@ void Parser::parseExprList(TempVector& result) { nextLexeme(); + if (FFlag::LuauCommaParenWarnings && lexer.current().type == ')') + { + report(lexer.current().location, "Expected expression after ',' but got ')' instead"); + break; + } + result.push_back(parseExpr()); } } @@ -1148,7 +1156,14 @@ AstTypePack* Parser::parseTypeList(TempVector& result, TempVector, AstArray> Parser::parseG } if (lexer.current().type == ',') + { nextLexeme(); + + if (FFlag::LuauCommaParenWarnings && lexer.current().type == '>') + { + report(lexer.current().location, "Expected type after ',' but got '>' instead"); + break; + } + } else break; } diff --git a/CLI/Profiler.cpp b/CLI/Profiler.cpp index 30a171f0f..d3ad4e996 100644 --- a/CLI/Profiler.cpp +++ b/CLI/Profiler.cpp @@ -82,11 +82,13 @@ static void profilerLoop() if (now - last >= 1.0 / double(gProfiler.frequency)) { - gProfiler.ticks += uint64_t((now - last) * 1e6); + int64_t ticks = int64_t((now - last) * 1e6); + + gProfiler.ticks += ticks; gProfiler.samples++; gProfiler.callbacks->interrupt = profilerTrigger; - last = now; + last += ticks * 1e-6; } else { diff --git a/CodeGen/include/Luau/AssemblyBuilderX64.h b/CodeGen/include/Luau/AssemblyBuilderX64.h index cb799d319..15db7a156 100644 --- a/CodeGen/include/Luau/AssemblyBuilderX64.h +++ b/CodeGen/include/Luau/AssemblyBuilderX64.h @@ -152,6 +152,7 @@ class AssemblyBuilderX64 void placeModRegMem(OperandX64 rhs, uint8_t regop); void placeRex(RegisterX64 op); void placeRex(OperandX64 op); + void placeRexNoW(OperandX64 op); void placeRex(RegisterX64 lhs, OperandX64 rhs); void placeVex(OperandX64 dst, OperandX64 src1, OperandX64 src2, bool setW, uint8_t mode, uint8_t prefix); void placeImm8Or32(int32_t imm); diff --git a/CodeGen/include/Luau/CodeAllocator.h b/CodeGen/include/Luau/CodeAllocator.h index 01e131216..2ea64630a 100644 --- a/CodeGen/include/Luau/CodeAllocator.h +++ b/CodeGen/include/Luau/CodeAllocator.h @@ -24,13 +24,15 @@ struct CodeAllocator void* context = nullptr; // Called when new block is created to create and setup the unwinding information for all the code in the block - // If data is placed inside the block itself (some platforms require this), we also return 'unwindDataSizeInBlock' - void* (*createBlockUnwindInfo)(void* context, uint8_t* block, size_t blockSize, size_t& unwindDataSizeInBlock) = nullptr; + // 'startOffset' reserves space for data at the beginning of the page + void* (*createBlockUnwindInfo)(void* context, uint8_t* block, size_t blockSize, size_t& startOffset) = nullptr; // Called to destroy unwinding information returned by 'createBlockUnwindInfo' void (*destroyBlockUnwindInfo)(void* context, void* unwindData) = nullptr; - static const size_t kMaxUnwindDataSize = 128; + // Unwind information can be placed inside the block with some implementation-specific reservations at the beginning + // But to simplify block space checks, we limit the max size of all that data + static const size_t kMaxReservedDataSize = 256; bool allocateNewBlock(size_t& unwindInfoSize); diff --git a/CodeGen/include/Luau/CodeBlockUnwind.h b/CodeGen/include/Luau/CodeBlockUnwind.h index ddae33a60..0f7af3ac5 100644 --- a/CodeGen/include/Luau/CodeBlockUnwind.h +++ b/CodeGen/include/Luau/CodeBlockUnwind.h @@ -10,7 +10,7 @@ namespace CodeGen { // context must be an UnwindBuilder -void* createBlockUnwindInfo(void* context, uint8_t* block, size_t blockSize, size_t& unwindDataSizeInBlock); +void* createBlockUnwindInfo(void* context, uint8_t* block, size_t blockSize, size_t& startOffset); void destroyBlockUnwindInfo(void* context, void* unwindData); } // namespace CodeGen diff --git a/CodeGen/include/Luau/UnwindBuilder.h b/CodeGen/include/Luau/UnwindBuilder.h index c6f611b0f..b7237318a 100644 --- a/CodeGen/include/Luau/UnwindBuilder.h +++ b/CodeGen/include/Luau/UnwindBuilder.h @@ -14,7 +14,10 @@ namespace CodeGen class UnwindBuilder { public: - virtual ~UnwindBuilder() {} + virtual ~UnwindBuilder() = default; + + virtual void setBeginOffset(size_t beginOffset) = 0; + virtual size_t getBeginOffset() const = 0; virtual void start() = 0; diff --git a/CodeGen/include/Luau/UnwindBuilderDwarf2.h b/CodeGen/include/Luau/UnwindBuilderDwarf2.h index 25dbc55ba..dab6e9573 100644 --- a/CodeGen/include/Luau/UnwindBuilderDwarf2.h +++ b/CodeGen/include/Luau/UnwindBuilderDwarf2.h @@ -12,6 +12,9 @@ namespace CodeGen class UnwindBuilderDwarf2 : public UnwindBuilder { public: + void setBeginOffset(size_t beginOffset) override; + size_t getBeginOffset() const override; + void start() override; void spill(int espOffset, RegisterX64 reg) override; @@ -26,6 +29,8 @@ class UnwindBuilderDwarf2 : public UnwindBuilder void finalize(char* target, void* funcAddress, size_t funcSize) const override; private: + size_t beginOffset = 0; + static const unsigned kRawDataLimit = 128; uint8_t rawData[kRawDataLimit]; uint8_t* pos = rawData; diff --git a/CodeGen/include/Luau/UnwindBuilderWin.h b/CodeGen/include/Luau/UnwindBuilderWin.h index 801eb6e47..005137712 100644 --- a/CodeGen/include/Luau/UnwindBuilderWin.h +++ b/CodeGen/include/Luau/UnwindBuilderWin.h @@ -22,6 +22,9 @@ struct UnwindCodeWin class UnwindBuilderWin : public UnwindBuilder { public: + void setBeginOffset(size_t beginOffset) override; + size_t getBeginOffset() const override; + void start() override; void spill(int espOffset, RegisterX64 reg) override; @@ -36,6 +39,8 @@ class UnwindBuilderWin : public UnwindBuilder void finalize(char* target, void* funcAddress, size_t funcSize) const override; private: + size_t beginOffset = 0; + // Windows unwind codes are written in reverse, so we have to collect them all first std::vector unwindCodes; diff --git a/CodeGen/src/AssemblyBuilderX64.cpp b/CodeGen/src/AssemblyBuilderX64.cpp index 32325b0dc..cd3079ac3 100644 --- a/CodeGen/src/AssemblyBuilderX64.cpp +++ b/CodeGen/src/AssemblyBuilderX64.cpp @@ -354,10 +354,15 @@ void AssemblyBuilderX64::jmp(Label& label) void AssemblyBuilderX64::jmp(OperandX64 op) { + LUAU_ASSERT((op.cat == CategoryX64::reg ? op.base.size : op.memSize) == SizeX64::qword); + if (logText) log("jmp", op); - placeRex(op); + // Indirect absolute calls always work in 64 bit width mode, so REX.W is optional + // While we could keep an optional prefix, in Windows x64 ABI it signals a tail call return statement to the unwinder + placeRexNoW(op); + place(0xff); placeModRegMem(op, 4); commit(); @@ -376,10 +381,14 @@ void AssemblyBuilderX64::call(Label& label) void AssemblyBuilderX64::call(OperandX64 op) { + LUAU_ASSERT((op.cat == CategoryX64::reg ? op.base.size : op.memSize) == SizeX64::qword); + if (logText) log("call", op); - placeRex(op); + // Indirect absolute calls always work in 64 bit width mode, so REX.W is optional + placeRexNoW(op); + place(0xff); placeModRegMem(op, 2); commit(); @@ -838,6 +847,21 @@ void AssemblyBuilderX64::placeRex(OperandX64 op) place(code | 0x40); } +void AssemblyBuilderX64::placeRexNoW(OperandX64 op) +{ + uint8_t code = 0; + + if (op.cat == CategoryX64::reg) + code = REX_B(op.base); + else if (op.cat == CategoryX64::mem) + code = REX_X(op.index) | REX_B(op.base); + else + LUAU_ASSERT(!"No encoding for left operand of this category"); + + if (code != 0) + place(code | 0x40); +} + void AssemblyBuilderX64::placeRex(RegisterX64 lhs, OperandX64 rhs) { uint8_t code = REX_W(lhs.size == SizeX64::qword); diff --git a/CodeGen/src/CodeAllocator.cpp b/CodeGen/src/CodeAllocator.cpp index aacf40a34..b3787d1f2 100644 --- a/CodeGen/src/CodeAllocator.cpp +++ b/CodeGen/src/CodeAllocator.cpp @@ -91,7 +91,7 @@ CodeAllocator::CodeAllocator(size_t blockSize, size_t maxTotalSize) : blockSize(blockSize) , maxTotalSize(maxTotalSize) { - LUAU_ASSERT(blockSize > kMaxUnwindDataSize); + LUAU_ASSERT(blockSize > kMaxReservedDataSize); LUAU_ASSERT(maxTotalSize >= blockSize); } @@ -116,15 +116,15 @@ bool CodeAllocator::allocate( size_t totalSize = alignedDataSize + codeSize; // Function has to fit into a single block with unwinding information - if (totalSize > blockSize - kMaxUnwindDataSize) + if (totalSize > blockSize - kMaxReservedDataSize) return false; - size_t unwindInfoSize = 0; + size_t startOffset = 0; // We might need a new block if (totalSize > size_t(blockEnd - blockPos)) { - if (!allocateNewBlock(unwindInfoSize)) + if (!allocateNewBlock(startOffset)) return false; LUAU_ASSERT(totalSize <= size_t(blockEnd - blockPos)); @@ -132,20 +132,20 @@ bool CodeAllocator::allocate( LUAU_ASSERT((uintptr_t(blockPos) & (kPageSize - 1)) == 0); // Allocation starts on page boundary - size_t dataOffset = unwindInfoSize + alignedDataSize - dataSize; - size_t codeOffset = unwindInfoSize + alignedDataSize; + size_t dataOffset = startOffset + alignedDataSize - dataSize; + size_t codeOffset = startOffset + alignedDataSize; if (dataSize) memcpy(blockPos + dataOffset, data, dataSize); if (codeSize) memcpy(blockPos + codeOffset, code, codeSize); - size_t pageAlignedSize = alignToPageSize(unwindInfoSize + totalSize); + size_t pageAlignedSize = alignToPageSize(startOffset + totalSize); makePagesExecutable(blockPos, pageAlignedSize); flushInstructionCache(blockPos + codeOffset, codeSize); - result = blockPos + unwindInfoSize; + result = blockPos + startOffset; resultSize = totalSize; resultCodeStart = blockPos + codeOffset; @@ -190,7 +190,7 @@ bool CodeAllocator::allocateNewBlock(size_t& unwindInfoSize) // 'Round up' to preserve 16 byte alignment of the following data and code unwindInfoSize = (unwindInfoSize + 15) & ~15; - LUAU_ASSERT(unwindInfoSize <= kMaxUnwindDataSize); + LUAU_ASSERT(unwindInfoSize <= kMaxReservedDataSize); if (!unwindInfo) return false; diff --git a/CodeGen/src/CodeBlockUnwind.cpp b/CodeGen/src/CodeBlockUnwind.cpp index 6191cee40..c045ba6b3 100644 --- a/CodeGen/src/CodeBlockUnwind.cpp +++ b/CodeGen/src/CodeBlockUnwind.cpp @@ -51,7 +51,7 @@ namespace Luau namespace CodeGen { -void* createBlockUnwindInfo(void* context, uint8_t* block, size_t blockSize, size_t& unwindDataSizeInBlock) +void* createBlockUnwindInfo(void* context, uint8_t* block, size_t blockSize, size_t& beginOffset) { #if defined(_WIN32) && defined(_M_X64) UnwindBuilder* unwind = (UnwindBuilder*)context; @@ -75,7 +75,7 @@ void* createBlockUnwindInfo(void* context, uint8_t* block, size_t blockSize, siz return nullptr; } - unwindDataSizeInBlock = unwindSize; + beginOffset = unwindSize + unwind->getBeginOffset(); return block; #elif !defined(_WIN32) UnwindBuilder* unwind = (UnwindBuilder*)context; @@ -94,7 +94,7 @@ void* createBlockUnwindInfo(void* context, uint8_t* block, size_t blockSize, siz __register_frame(unwindData); #endif - unwindDataSizeInBlock = unwindSize; + beginOffset = unwindSize + unwind->getBeginOffset(); return block; #endif diff --git a/CodeGen/src/UnwindBuilderDwarf2.cpp b/CodeGen/src/UnwindBuilderDwarf2.cpp index f3886d9ce..8d06864ee 100644 --- a/CodeGen/src/UnwindBuilderDwarf2.cpp +++ b/CodeGen/src/UnwindBuilderDwarf2.cpp @@ -129,6 +129,16 @@ namespace Luau namespace CodeGen { +void UnwindBuilderDwarf2::setBeginOffset(size_t beginOffset) +{ + this->beginOffset = beginOffset; +} + +size_t UnwindBuilderDwarf2::getBeginOffset() const +{ + return beginOffset; +} + void UnwindBuilderDwarf2::start() { uint8_t* cieLength = pos; diff --git a/CodeGen/src/UnwindBuilderWin.cpp b/CodeGen/src/UnwindBuilderWin.cpp index 5405fcf21..1b3279e82 100644 --- a/CodeGen/src/UnwindBuilderWin.cpp +++ b/CodeGen/src/UnwindBuilderWin.cpp @@ -32,6 +32,16 @@ struct UnwindInfoWin uint8_t frameregoff : 4; }; +void UnwindBuilderWin::setBeginOffset(size_t beginOffset) +{ + this->beginOffset = beginOffset; +} + +size_t UnwindBuilderWin::getBeginOffset() const +{ + return beginOffset; +} + void UnwindBuilderWin::start() { stackOffset = 8; // Return address was pushed by calling the function diff --git a/Common/include/Luau/ExperimentalFlags.h b/Common/include/Luau/ExperimentalFlags.h index ce47cd9ad..7cff70a47 100644 --- a/Common/include/Luau/ExperimentalFlags.h +++ b/Common/include/Luau/ExperimentalFlags.h @@ -13,6 +13,7 @@ inline bool isFlagExperimental(const char* flag) static const char* kList[] = { "LuauLowerBoundsCalculation", "LuauInterpolatedStringBaseSupport", + "LuauInstantiateInSubtyping", // requires some fixes to lua-apps code // makes sure we always have at least one entry nullptr, }; diff --git a/Makefile b/Makefile index a773af4a3..338bb9ba9 100644 --- a/Makefile +++ b/Makefile @@ -4,6 +4,7 @@ MAKEFLAGS+=-r -j8 COMMA=, config=debug +protobuf=system BUILD=build/$(config) @@ -95,12 +96,22 @@ ifeq ($(config),fuzz) CXX=clang++ # our fuzzing infra relies on llvm fuzzer CXXFLAGS+=-fsanitize=address,fuzzer -Ibuild/libprotobuf-mutator -O2 LDFLAGS+=-fsanitize=address,fuzzer + LPROTOBUF=-lprotobuf + DPROTOBUF=-D CMAKE_BUILD_TYPE=Release -D LIB_PROTO_MUTATOR_TESTING=OFF + EPROTOC=protoc endif ifeq ($(config),profile) CXXFLAGS+=-O2 -DNDEBUG -gdwarf-4 -DCALLGRIND=1 endif +ifeq ($(protobuf),download) + CXXFLAGS+=-Ibuild/libprotobuf-mutator/external.protobuf/include + LPROTOBUF=build/libprotobuf-mutator/external.protobuf/lib/libprotobuf.a + DPROTOBUF+=-D LIB_PROTO_MUTATOR_DOWNLOAD_PROTOBUF=ON + EPROTOC=../build/libprotobuf-mutator/external.protobuf/bin/protoc +endif + # target-specific flags $(AST_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include $(COMPILER_OBJECTS): CXXFLAGS+=-std=c++17 -ICompiler/include -ICommon/include -IAst/include @@ -115,7 +126,7 @@ $(FUZZ_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -ICompiler/ $(TESTS_TARGET): LDFLAGS+=-lpthread $(REPL_CLI_TARGET): LDFLAGS+=-lpthread -fuzz-proto fuzz-prototest: LDFLAGS+=build/libprotobuf-mutator/src/libfuzzer/libprotobuf-mutator-libfuzzer.a build/libprotobuf-mutator/src/libprotobuf-mutator.a -lprotobuf +fuzz-proto fuzz-prototest: LDFLAGS+=build/libprotobuf-mutator/src/libfuzzer/libprotobuf-mutator-libfuzzer.a build/libprotobuf-mutator/src/libprotobuf-mutator.a $(LPROTOBUF) # pseudo targets .PHONY: all test clean coverage format luau-size aliases @@ -199,7 +210,7 @@ $(BUILD)/%.c.o: %.c # protobuf fuzzer setup fuzz/luau.pb.cpp: fuzz/luau.proto build/libprotobuf-mutator - cd fuzz && protoc luau.proto --cpp_out=. + cd fuzz && $(EPROTOC) luau.proto --cpp_out=. mv fuzz/luau.pb.cc fuzz/luau.pb.cpp $(BUILD)/fuzz/proto.cpp.o: fuzz/luau.pb.cpp @@ -207,7 +218,7 @@ $(BUILD)/fuzz/protoprint.cpp.o: fuzz/luau.pb.cpp build/libprotobuf-mutator: git clone https://github.com/google/libprotobuf-mutator build/libprotobuf-mutator - CXX= cmake -S build/libprotobuf-mutator -B build/libprotobuf-mutator -D CMAKE_BUILD_TYPE=Release -D LIB_PROTO_MUTATOR_TESTING=OFF + CXX= cmake -S build/libprotobuf-mutator -B build/libprotobuf-mutator $(DPROTOBUF) make -C build/libprotobuf-mutator -j8 # picks up include dependencies for all object files diff --git a/VM/include/lua.h b/VM/include/lua.h index 0b34bd0a5..6d0e98d70 100644 --- a/VM/include/lua.h +++ b/VM/include/lua.h @@ -401,13 +401,15 @@ struct lua_Debug const char* name; // (n) const char* what; // (s) `Lua', `C', `main', `tail' const char* source; // (s) + const char* short_src; // (s) int linedefined; // (s) int currentline; // (l) unsigned char nupvals; // (u) number of upvalues unsigned char nparams; // (a) number of parameters char isvararg; // (a) - char short_src[LUA_IDSIZE]; // (s) void* userdata; // only valid in luau_callhook + + char ssbuf[LUA_IDSIZE]; }; // }====================================================================== diff --git a/VM/src/ldebug.cpp b/VM/src/ldebug.cpp index e695cd2b3..82af5d380 100644 --- a/VM/src/ldebug.cpp +++ b/VM/src/ldebug.cpp @@ -12,6 +12,8 @@ #include #include +LUAU_FASTFLAGVARIABLE(LuauFasterGetInfo, false) + static const char* getfuncname(Closure* f); static int currentpc(lua_State* L, CallInfo* ci) @@ -89,9 +91,9 @@ const char* lua_setlocal(lua_State* L, int level, int n) return name; } -static int auxgetinfo(lua_State* L, const char* what, lua_Debug* ar, Closure* f, CallInfo* ci) +static Closure* auxgetinfo(lua_State* L, const char* what, lua_Debug* ar, Closure* f, CallInfo* ci) { - int status = 1; + Closure* cl = NULL; for (; *what; what++) { switch (*what) @@ -103,14 +105,23 @@ static int auxgetinfo(lua_State* L, const char* what, lua_Debug* ar, Closure* f, ar->source = "=[C]"; ar->what = "C"; ar->linedefined = -1; + if (FFlag::LuauFasterGetInfo) + ar->short_src = "[C]"; } else { - ar->source = getstr(f->l.p->source); + TString* source = f->l.p->source; + ar->source = getstr(source); ar->what = "Lua"; ar->linedefined = f->l.p->linedefined; + if (FFlag::LuauFasterGetInfo) + ar->short_src = luaO_chunkid(ar->ssbuf, sizeof(ar->ssbuf), getstr(source), source->len); + } + if (!FFlag::LuauFasterGetInfo) + { + luaO_chunkid(ar->ssbuf, LUA_IDSIZE, ar->source, 0); + ar->short_src = ar->ssbuf; } - luaO_chunkid(ar->short_src, ar->source, LUA_IDSIZE); break; } case 'l': @@ -150,10 +161,15 @@ static int auxgetinfo(lua_State* L, const char* what, lua_Debug* ar, Closure* f, ar->name = ci ? getfuncname(ci_func(ci)) : getfuncname(f); break; } + case 'f': + { + cl = f; + break; + } default:; } } - return status; + return cl; } int lua_stackdepth(lua_State* L) @@ -163,7 +179,6 @@ int lua_stackdepth(lua_State* L) int lua_getinfo(lua_State* L, int level, const char* what, lua_Debug* ar) { - int status = 0; Closure* f = NULL; CallInfo* ci = NULL; if (level < 0) @@ -180,15 +195,28 @@ int lua_getinfo(lua_State* L, int level, const char* what, lua_Debug* ar) } if (f) { - status = auxgetinfo(L, what, ar, f, ci); - if (strchr(what, 'f')) + if (FFlag::LuauFasterGetInfo) { - luaC_threadbarrier(L); - setclvalue(L, L->top, f); - incr_top(L); + // auxgetinfo fills ar and optionally requests to put closure on stack + if (Closure* fcl = auxgetinfo(L, what, ar, f, ci)) + { + luaC_threadbarrier(L); + setclvalue(L, L->top, fcl); + incr_top(L); + } + } + else + { + auxgetinfo(L, what, ar, f, ci); + if (strchr(what, 'f')) + { + luaC_threadbarrier(L); + setclvalue(L, L->top, f); + incr_top(L); + } } } - return status; + return f ? 1 : 0; } static const char* getfuncname(Closure* cl) @@ -284,10 +312,11 @@ static void pusherror(lua_State* L, const char* msg) CallInfo* ci = L->ci; if (isLua(ci)) { - char buff[LUA_IDSIZE]; // add file:line information - luaO_chunkid(buff, getstr(getluaproto(ci)->source), LUA_IDSIZE); + TString* source = getluaproto(ci)->source; + char chunkbuf[LUA_IDSIZE]; // add file:line information + const char* chunkid = luaO_chunkid(chunkbuf, sizeof(chunkbuf), getstr(source), source->len); int line = currentline(L, ci); - luaO_pushfstring(L, "%s:%d: %s", buff, line, msg); + luaO_pushfstring(L, "%s:%d: %s", chunkid, line, msg); } else { diff --git a/VM/src/lgc.cpp b/VM/src/lgc.cpp index 4b9fbb69b..f50b33d18 100644 --- a/VM/src/lgc.cpp +++ b/VM/src/lgc.cpp @@ -13,8 +13,6 @@ #include -LUAU_FASTFLAGVARIABLE(LuauBetterThreadMark, false) - /* * Luau uses an incremental non-generational non-moving mark&sweep garbage collector. * @@ -473,54 +471,25 @@ static size_t propagatemark(global_State* g) bool active = th->isactive || th == th->global->mainthread; - if (FFlag::LuauBetterThreadMark) - { - traversestack(g, th); + traversestack(g, th); - // active threads will need to be rescanned later to mark new stack writes so we mark them gray again - if (active) - { - th->gclist = g->grayagain; - g->grayagain = o; - - black2gray(o); - } - - // the stack needs to be cleared after the last modification of the thread state before sweep begins - // if the thread is inactive, we might not see the thread in this cycle so we must clear it now - if (!active || g->gcstate == GCSatomic) - clearstack(th); - - // we could shrink stack at any time but we opt to do it during initial mark to do that just once per cycle - if (g->gcstate == GCSpropagate) - shrinkstack(th); - } - else + // active threads will need to be rescanned later to mark new stack writes so we mark them gray again + if (active) { - // TODO: Refactor this logic! - if (!active && g->gcstate == GCSpropagate) - { - traversestack(g, th); - clearstack(th); - } - else - { - th->gclist = g->grayagain; - g->grayagain = o; - - black2gray(o); + th->gclist = g->grayagain; + g->grayagain = o; - traversestack(g, th); + black2gray(o); + } - // final traversal? - if (g->gcstate == GCSatomic) - clearstack(th); - } + // the stack needs to be cleared after the last modification of the thread state before sweep begins + // if the thread is inactive, we might not see the thread in this cycle so we must clear it now + if (!active || g->gcstate == GCSatomic) + clearstack(th); - // we could shrink stack at any time but we opt to skip it during atomic since it's redundant to do that more than once per cycle - if (g->gcstate != GCSatomic) - shrinkstack(th); - } + // we could shrink stack at any time but we opt to do it during initial mark to do that just once per cycle + if (g->gcstate == GCSpropagate) + shrinkstack(th); return sizeof(lua_State) + sizeof(TValue) * th->stacksize + sizeof(CallInfo) * th->size_ci; } diff --git a/VM/src/lobject.cpp b/VM/src/lobject.cpp index f5f1cd0e8..8b3e4783d 100644 --- a/VM/src/lobject.cpp +++ b/VM/src/lobject.cpp @@ -15,6 +15,8 @@ +LUAU_FASTFLAG(LuauFasterGetInfo) + const TValue luaO_nilobject_ = {{NULL}, {0}, LUA_TNIL}; int luaO_log2(unsigned int x) @@ -117,44 +119,68 @@ const char* luaO_pushfstring(lua_State* L, const char* fmt, ...) return msg; } -void luaO_chunkid(char* out, const char* source, size_t bufflen) +const char* luaO_chunkid(char* buf, size_t buflen, const char* source, size_t srclen) { if (*source == '=') { - source++; // skip the `=' - size_t srclen = strlen(source); - size_t dstlen = srclen < bufflen ? srclen : bufflen - 1; - memcpy(out, source, dstlen); - out[dstlen] = '\0'; + if (FFlag::LuauFasterGetInfo) + { + if (srclen <= buflen) + return source + 1; + // truncate the part after = + memcpy(buf, source + 1, buflen - 1); + buf[buflen - 1] = '\0'; + } + else + { + source++; // skip the `=' + size_t len = strlen(source); + size_t dstlen = len < buflen ? len : buflen - 1; + memcpy(buf, source, dstlen); + buf[dstlen] = '\0'; + } } else if (*source == '@') { - size_t l; - source++; // skip the `@' - bufflen -= sizeof("..."); - l = strlen(source); - strcpy(out, ""); - if (l > bufflen) + if (FFlag::LuauFasterGetInfo) + { + if (srclen <= buflen) + return source + 1; + // truncate the part after @ + memcpy(buf, "...", 3); + memcpy(buf + 3, source + srclen - (buflen - 4), buflen - 4); + buf[buflen - 1] = '\0'; + } + else { - source += (l - bufflen); // get last part of file name - strcat(out, "..."); + size_t l; + source++; // skip the `@' + buflen -= sizeof("..."); + l = strlen(source); + strcpy(buf, ""); + if (l > buflen) + { + source += (l - buflen); // get last part of file name + strcat(buf, "..."); + } + strcat(buf, source); } - strcat(out, source); } else - { // out = [string "string"] + { // buf = [string "string"] size_t len = strcspn(source, "\n\r"); // stop at first newline - bufflen -= sizeof("[string \"...\"]"); - if (len > bufflen) - len = bufflen; - strcpy(out, "[string \""); + buflen -= sizeof("[string \"...\"]"); + if (len > buflen) + len = buflen; + strcpy(buf, "[string \""); if (source[len] != '\0') { // must truncate? - strncat(out, source, len); - strcat(out, "..."); + strncat(buf, source, len); + strcat(buf, "..."); } else - strcat(out, source); - strcat(out, "\"]"); + strcat(buf, source); + strcat(buf, "\"]"); } + return buf; } diff --git a/VM/src/lobject.h b/VM/src/lobject.h index 5f5e7b1c8..41bf3386e 100644 --- a/VM/src/lobject.h +++ b/VM/src/lobject.h @@ -460,4 +460,4 @@ LUAI_FUNC int luaO_rawequalKey(const TKey* t1, const TValue* t2); LUAI_FUNC int luaO_str2d(const char* s, double* result); LUAI_FUNC const char* luaO_pushvfstring(lua_State* L, const char* fmt, va_list argp); LUAI_FUNC const char* luaO_pushfstring(lua_State* L, const char* fmt, ...); -LUAI_FUNC void luaO_chunkid(char* out, const char* source, size_t len); +LUAI_FUNC const char* luaO_chunkid(char* buf, size_t buflen, const char* source, size_t srclen); diff --git a/VM/src/lvmexecute.cpp b/VM/src/lvmexecute.cpp index 490358c4b..7ee3ee9bd 100644 --- a/VM/src/lvmexecute.cpp +++ b/VM/src/lvmexecute.cpp @@ -781,6 +781,7 @@ static void luau_execute(lua_State* L) default: LUAU_ASSERT(!"Unknown upvalue capture type"); + LUAU_UNREACHABLE(); // improves switch() codegen by eliding opcode bounds checks } } @@ -1184,7 +1185,9 @@ static void luau_execute(lua_State* L) // slow path after switch() break; - default:; + default: + LUAU_ASSERT(!"Unknown value type"); + LUAU_UNREACHABLE(); // improves switch() codegen by eliding opcode bounds checks } // slow-path: tables with metatables and userdata values @@ -1296,7 +1299,9 @@ static void luau_execute(lua_State* L) // slow path after switch() break; - default:; + default: + LUAU_ASSERT(!"Unknown value type"); + LUAU_UNREACHABLE(); // improves switch() codegen by eliding opcode bounds checks } // slow-path: tables with metatables and userdata values diff --git a/VM/src/lvmload.cpp b/VM/src/lvmload.cpp index bd40bad2f..3edec6889 100644 --- a/VM/src/lvmload.cpp +++ b/VM/src/lvmload.cpp @@ -148,16 +148,16 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size // 0 means the rest of the bytecode is the error message if (version == 0) { - char chunkid[LUA_IDSIZE]; - luaO_chunkid(chunkid, chunkname, LUA_IDSIZE); + char chunkbuf[LUA_IDSIZE]; + const char* chunkid = luaO_chunkid(chunkbuf, sizeof(chunkbuf), chunkname, strlen(chunkname)); lua_pushfstring(L, "%s%.*s", chunkid, int(size - offset), data + offset); return 1; } if (version < LBC_VERSION_MIN || version > LBC_VERSION_MAX) { - char chunkid[LUA_IDSIZE]; - luaO_chunkid(chunkid, chunkname, LUA_IDSIZE); + char chunkbuf[LUA_IDSIZE]; + const char* chunkid = luaO_chunkid(chunkbuf, sizeof(chunkbuf), chunkname, strlen(chunkname)); lua_pushfstring(L, "%s: bytecode version mismatch (expected [%d..%d], got %d)", chunkid, LBC_VERSION_MIN, LBC_VERSION_MAX, version); return 1; } diff --git a/VM/src/lvmutils.cpp b/VM/src/lvmutils.cpp index 5c5551580..05d397540 100644 --- a/VM/src/lvmutils.cpp +++ b/VM/src/lvmutils.cpp @@ -220,8 +220,13 @@ int luaV_strcmp(const TString* ls, const TString* rs) return 0; const char* l = getstr(ls); - size_t ll = ls->len; const char* r = getstr(rs); + + // always safe to read one character because even empty strings are nul terminated + if (*l != *r) + return uint8_t(*l) - uint8_t(*r); + + size_t ll = ls->len; size_t lr = rs->len; size_t lmin = ll < lr ? ll : lr; diff --git a/tests/AssemblyBuilderX64.test.cpp b/tests/AssemblyBuilderX64.test.cpp index 08f241ed7..a05108878 100644 --- a/tests/AssemblyBuilderX64.test.cpp +++ b/tests/AssemblyBuilderX64.test.cpp @@ -240,12 +240,12 @@ TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "FormsOfLea") TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "FormsOfAbsoluteJumps") { - SINGLE_COMPARE(jmp(rax), 0x48, 0xff, 0xe0); - SINGLE_COMPARE(jmp(r14), 0x49, 0xff, 0xe6); - SINGLE_COMPARE(jmp(qword[r14 + rdx * 4]), 0x49, 0xff, 0x24, 0x96); - SINGLE_COMPARE(call(rax), 0x48, 0xff, 0xd0); - SINGLE_COMPARE(call(r14), 0x49, 0xff, 0xd6); - SINGLE_COMPARE(call(qword[r14 + rdx * 4]), 0x49, 0xff, 0x14, 0x96); + SINGLE_COMPARE(jmp(rax), 0xff, 0xe0); + SINGLE_COMPARE(jmp(r14), 0x41, 0xff, 0xe6); + SINGLE_COMPARE(jmp(qword[r14 + rdx * 4]), 0x41, 0xff, 0x24, 0x96); + SINGLE_COMPARE(call(rax), 0xff, 0xd0); + SINGLE_COMPARE(call(r14), 0x41, 0xff, 0xd6); + SINGLE_COMPARE(call(qword[r14 + rdx * 4]), 0x41, 0xff, 0x14, 0x96); } TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "FormsOfImul") diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index a64d372fc..9a5c3411c 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -2955,8 +2955,6 @@ local abc = b@1 TEST_CASE_FIXTURE(ACFixture, "no_incompatible_self_calls_on_class") { - ScopedFastFlag selfCallAutocompleteFix3{"LuauSelfCallAutocompleteFix3", true}; - loadDefinition(R"( declare class Foo function one(self): number @@ -2995,8 +2993,6 @@ t.@1 TEST_CASE_FIXTURE(ACFixture, "do_compatible_self_calls") { - ScopedFastFlag selfCallAutocompleteFix3{"LuauSelfCallAutocompleteFix3", true}; - check(R"( local t = {} function t:m() end @@ -3011,8 +3007,6 @@ t:@1 TEST_CASE_FIXTURE(ACFixture, "no_incompatible_self_calls") { - ScopedFastFlag selfCallAutocompleteFix3{"LuauSelfCallAutocompleteFix3", true}; - check(R"( local t = {} function t.m() end @@ -3027,8 +3021,6 @@ t:@1 TEST_CASE_FIXTURE(ACFixture, "no_incompatible_self_calls_2") { - ScopedFastFlag selfCallAutocompleteFix3{"LuauSelfCallAutocompleteFix3", true}; - check(R"( local f: (() -> number) & ((number) -> number) = function(x: number?) return 2 end local t = {} @@ -3059,8 +3051,6 @@ t:@1 TEST_CASE_FIXTURE(ACFixture, "no_wrong_compatible_self_calls_with_generics") { - ScopedFastFlag selfCallAutocompleteFix3{"LuauSelfCallAutocompleteFix3", true}; - check(R"( local t = {} function t.m(a: T) end @@ -3076,8 +3066,6 @@ t:@1 TEST_CASE_FIXTURE(ACFixture, "string_prim_self_calls_are_fine") { - ScopedFastFlag selfCallAutocompleteFix3{"LuauSelfCallAutocompleteFix3", true}; - check(R"( local s = "hello" s:@1 @@ -3095,8 +3083,6 @@ s:@1 TEST_CASE_FIXTURE(ACFixture, "string_prim_non_self_calls_are_avoided") { - ScopedFastFlag selfCallAutocompleteFix3{"LuauSelfCallAutocompleteFix3", true}; - check(R"( local s = "hello" s.@1 @@ -3112,8 +3098,6 @@ s.@1 TEST_CASE_FIXTURE(ACBuiltinsFixture, "library_non_self_calls_are_fine") { - ScopedFastFlag selfCallAutocompleteFix3{"LuauSelfCallAutocompleteFix3", true}; - check(R"( string.@1 )"); @@ -3143,8 +3127,6 @@ table.@1 TEST_CASE_FIXTURE(ACBuiltinsFixture, "library_self_calls_are_invalid") { - ScopedFastFlag selfCallAutocompleteFix3{"LuauSelfCallAutocompleteFix3", true}; - check(R"( string:@1 )"); diff --git a/tests/CodeAllocator.test.cpp b/tests/CodeAllocator.test.cpp index 758fb44cb..4e553f054 100644 --- a/tests/CodeAllocator.test.cpp +++ b/tests/CodeAllocator.test.cpp @@ -96,12 +96,12 @@ TEST_CASE("CodeAllocationWithUnwindCallbacks") data.resize(8); allocator.context = &info; - allocator.createBlockUnwindInfo = [](void* context, uint8_t* block, size_t blockSize, size_t& unwindDataSizeInBlock) -> void* { + allocator.createBlockUnwindInfo = [](void* context, uint8_t* block, size_t blockSize, size_t& beginOffset) -> void* { Info& info = *(Info*)context; CHECK(info.unwind.size() == 8); memcpy(block, info.unwind.data(), info.unwind.size()); - unwindDataSizeInBlock = 8; + beginOffset = 8; info.block = block; @@ -194,10 +194,12 @@ TEST_CASE("Dwarf2UnwindCodesX64") // Windows x64 ABI constexpr RegisterX64 rArg1 = rcx; constexpr RegisterX64 rArg2 = rdx; +constexpr RegisterX64 rArg3 = r8; #else // System V AMD64 ABI constexpr RegisterX64 rArg1 = rdi; constexpr RegisterX64 rArg2 = rsi; +constexpr RegisterX64 rArg3 = rdx; #endif constexpr RegisterX64 rNonVol1 = r12; @@ -313,6 +315,119 @@ TEST_CASE("GeneratedCodeExecutionWithThrow") } } +TEST_CASE("GeneratedCodeExecutionWithThrowOutsideTheGate") +{ + AssemblyBuilderX64 build(/* logText= */ false); + +#if defined(_WIN32) + std::unique_ptr unwind = std::make_unique(); +#else + std::unique_ptr unwind = std::make_unique(); +#endif + + unwind->start(); + + // Prologue (some of these registers don't have to be saved, but we want to have a big prologue) + build.push(r10); + unwind->save(r10); + build.push(r11); + unwind->save(r11); + build.push(r12); + unwind->save(r12); + build.push(r13); + unwind->save(r13); + build.push(r14); + unwind->save(r14); + build.push(r15); + unwind->save(r15); + build.push(rbp); + unwind->save(rbp); + + int stackSize = 64; + int localsSize = 16; + + build.sub(rsp, stackSize + localsSize); + unwind->allocStack(stackSize + localsSize); + + build.lea(rbp, qword[rsp + stackSize]); + unwind->setupFrameReg(rbp, stackSize); + + unwind->finish(); + + size_t prologueSize = build.setLabel().location; + + // Body + build.mov(rax, rArg1); + build.mov(rArg1, 25); + build.jmp(rax); + + Label returnOffset = build.setLabel(); + + // Epilogue + build.lea(rsp, qword[rbp + localsSize]); + build.pop(rbp); + build.pop(r15); + build.pop(r14); + build.pop(r13); + build.pop(r12); + build.pop(r11); + build.pop(r10); + build.ret(); + + build.finalize(); + + size_t blockSize = 4096; // Force allocate to create a new block each time + size_t maxTotalSize = 1024 * 1024; + CodeAllocator allocator(blockSize, maxTotalSize); + + allocator.context = unwind.get(); + allocator.createBlockUnwindInfo = createBlockUnwindInfo; + allocator.destroyBlockUnwindInfo = destroyBlockUnwindInfo; + + uint8_t* nativeData1; + size_t sizeNativeData1; + uint8_t* nativeEntry1; + REQUIRE( + allocator.allocate(build.data.data(), build.data.size(), build.code.data(), build.code.size(), nativeData1, sizeNativeData1, nativeEntry1)); + REQUIRE(nativeEntry1); + + // Now we set the offset at the begining so that functions in new blocks will not overlay the locations + // specified by the unwind information of the entry function + unwind->setBeginOffset(prologueSize); + + using FunctionType = int64_t(void*, void (*)(int64_t), void*); + FunctionType* f = (FunctionType*)nativeEntry1; + + uint8_t* nativeExit = nativeEntry1 + returnOffset.location; + + AssemblyBuilderX64 build2(/* logText= */ false); + + build2.mov(r12, rArg3); + build2.call(rArg2); + build2.jmp(r12); + + build2.finalize(); + + uint8_t* nativeData2; + size_t sizeNativeData2; + uint8_t* nativeEntry2; + REQUIRE(allocator.allocate( + build2.data.data(), build2.data.size(), build2.code.data(), build2.code.size(), nativeData2, sizeNativeData2, nativeEntry2)); + REQUIRE(nativeEntry2); + + // To simplify debugging, CHECK_THROWS_WITH_AS is not used here + try + { + f(nativeEntry2, throwing, nativeExit); + } + catch (const std::runtime_error& error) + { + CHECK(strcmp(error.what(), "testing") == 0); + } + + REQUIRE(nativeEntry2); +} + #endif TEST_SUITE_END(); diff --git a/tests/ConstraintSolver.test.cpp b/tests/ConstraintSolver.test.cpp index 9976bd2c6..3420fd878 100644 --- a/tests/ConstraintSolver.test.cpp +++ b/tests/ConstraintSolver.test.cpp @@ -28,8 +28,11 @@ TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "hello") cgb.visit(block); NotNull rootScope{cgb.rootScope}; + InternalErrorReporter iceHandler; + UnifierSharedState sharedState{&iceHandler}; + Normalizer normalizer{&arena, singletonTypes, NotNull{&sharedState}}; NullModuleResolver resolver; - ConstraintSolver cs{&arena, singletonTypes, rootScope, "MainModule", NotNull(&resolver), {}, &logger}; + ConstraintSolver cs{NotNull{&normalizer}, rootScope, "MainModule", NotNull(&resolver), {}, &logger}; cs.run(); @@ -49,9 +52,11 @@ TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "generic_function") cgb.visit(block); NotNull rootScope{cgb.rootScope}; + InternalErrorReporter iceHandler; + UnifierSharedState sharedState{&iceHandler}; + Normalizer normalizer{&arena, singletonTypes, NotNull{&sharedState}}; NullModuleResolver resolver; - ConstraintSolver cs{&arena, singletonTypes, rootScope, "MainModule", NotNull(&resolver), {}, &logger}; - + ConstraintSolver cs{NotNull{&normalizer}, rootScope, "MainModule", NotNull(&resolver), {}, &logger}; cs.run(); TypeId idType = requireBinding(rootScope, "id"); @@ -79,7 +84,10 @@ TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "proper_let_generalization") ToStringOptions opts; NullModuleResolver resolver; - ConstraintSolver cs{&arena, singletonTypes, rootScope, "MainModule", NotNull(&resolver), {}, &logger}; + InternalErrorReporter iceHandler; + UnifierSharedState sharedState{&iceHandler}; + Normalizer normalizer{&arena, singletonTypes, NotNull{&sharedState}}; + ConstraintSolver cs{NotNull{&normalizer}, rootScope, "MainModule", NotNull(&resolver), {}, &logger}; cs.run(); diff --git a/tests/Fixture.cpp b/tests/Fixture.cpp index dcc0222ab..4d3c8854f 100644 --- a/tests/Fixture.cpp +++ b/tests/Fixture.cpp @@ -22,6 +22,8 @@ LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); LUAU_FASTFLAG(LuauUnknownAndNeverType) LUAU_FASTFLAG(LuauReportShadowedTypeAlias) +extern std::optional randomSeed; // tests/main.cpp + namespace Luau { @@ -90,7 +92,7 @@ std::optional TestFileResolver::getEnvironmentForModule(const Modul Fixture::Fixture(bool freeze, bool prepareAutocomplete) : sff_DebugLuauFreezeArena("DebugLuauFreezeArena", freeze) - , frontend(&fileResolver, &configResolver, {/* retainFullTypeGraphs= */ true}) + , frontend(&fileResolver, &configResolver, {/* retainFullTypeGraphs= */ true, /* forAutocomplete */ false, /* randomConstraintResolutionSeed */ randomSeed}) , typeChecker(frontend.typeChecker) , singletonTypes(frontend.singletonTypes) { diff --git a/tests/Linter.test.cpp b/tests/Linter.test.cpp index 921e6691c..fc3ede738 100644 --- a/tests/Linter.test.cpp +++ b/tests/Linter.test.cpp @@ -21,14 +21,14 @@ end return math.max(fib(5), 1) )"); - CHECK_EQ(result.warnings.size(), 0); + REQUIRE(0 == result.warnings.size()); } TEST_CASE_FIXTURE(Fixture, "UnknownGlobal") { LintResult result = lint("--!nocheck\nreturn foo"); - REQUIRE_EQ(result.warnings.size(), 1); + REQUIRE(1 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Unknown global 'foo'"); } @@ -39,7 +39,7 @@ TEST_CASE_FIXTURE(Fixture, "DeprecatedGlobal") LintResult result = lintTyped("Wait(5)"); - REQUIRE_EQ(result.warnings.size(), 1); + REQUIRE(1 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Global 'Wait' is deprecated, use 'wait' instead"); } @@ -53,7 +53,7 @@ TEST_CASE_FIXTURE(Fixture, "DeprecatedGlobalNoReplacement") LintResult result = lintTyped("Version()"); - REQUIRE_EQ(result.warnings.size(), 1); + REQUIRE(1 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Global 'Version' is deprecated"); } @@ -64,7 +64,7 @@ local _ = 5 return _ )"); - CHECK_EQ(result.warnings.size(), 1); + REQUIRE(1 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Placeholder value '_' is read here; consider using a named variable"); } @@ -75,7 +75,7 @@ _ = 5 print(_) )"); - CHECK_EQ(result.warnings.size(), 1); + REQUIRE(1 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Placeholder value '_' is read here; consider using a named variable"); } @@ -86,7 +86,7 @@ local _ = 5 _ = 6 )"); - CHECK_EQ(result.warnings.size(), 0); + REQUIRE(0 == result.warnings.size()); } TEST_CASE_FIXTURE(BuiltinsFixture, "BuiltinGlobalWrite") @@ -100,7 +100,7 @@ end assert(5) )"); - CHECK_EQ(result.warnings.size(), 2); + REQUIRE(2 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Built-in global 'math' is overwritten here; consider using a local or changing the name"); CHECK_EQ(result.warnings[1].text, "Built-in global 'assert' is overwritten here; consider using a local or changing the name"); } @@ -111,7 +111,7 @@ TEST_CASE_FIXTURE(Fixture, "MultilineBlock") if true then print(1) print(2) print(3) end )"); - CHECK_EQ(result.warnings.size(), 1); + REQUIRE(1 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "A new statement is on the same line; add semi-colon on previous statement to silence"); } @@ -121,7 +121,7 @@ TEST_CASE_FIXTURE(Fixture, "MultilineBlockSemicolonsWhitelisted") print(1); print(2); print(3) )"); - CHECK(result.warnings.empty()); + REQUIRE(0 == result.warnings.size()); } TEST_CASE_FIXTURE(Fixture, "MultilineBlockMissedSemicolon") @@ -130,7 +130,7 @@ TEST_CASE_FIXTURE(Fixture, "MultilineBlockMissedSemicolon") print(1); print(2) print(3) )"); - CHECK_EQ(result.warnings.size(), 1); + REQUIRE(1 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "A new statement is on the same line; add semi-colon on previous statement to silence"); } @@ -142,7 +142,7 @@ local _x do end )"); - CHECK_EQ(result.warnings.size(), 0); + REQUIRE(0 == result.warnings.size()); } TEST_CASE_FIXTURE(Fixture, "ConfusingIndentation") @@ -152,7 +152,7 @@ print(math.max(1, 2)) )"); - CHECK_EQ(result.warnings.size(), 1); + REQUIRE(1 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Statement spans multiple lines; use indentation to silence"); } @@ -167,7 +167,7 @@ end return bar() )"); - CHECK_EQ(result.warnings.size(), 1); + REQUIRE(1 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Global 'foo' is only used in the enclosing function 'bar'; consider changing it to local"); } @@ -188,7 +188,7 @@ end return bar() + baz() )"); - REQUIRE_EQ(result.warnings.size(), 1); + REQUIRE(1 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Global 'foo' is never read before being written. Consider changing it to local"); } @@ -213,7 +213,7 @@ end return bar() + baz() + read() )"); - CHECK_EQ(result.warnings.size(), 0); + REQUIRE(0 == result.warnings.size()); } TEST_CASE_FIXTURE(Fixture, "GlobalAsLocalWithConditional") @@ -233,7 +233,7 @@ end return bar() + baz() )"); - CHECK_EQ(result.warnings.size(), 0); + REQUIRE(0 == result.warnings.size()); } TEST_CASE_FIXTURE(Fixture, "GlobalAsLocal3WithConditionalRead") @@ -257,7 +257,7 @@ end return bar() + baz() + read() )"); - CHECK_EQ(result.warnings.size(), 0); + REQUIRE(0 == result.warnings.size()); } TEST_CASE_FIXTURE(Fixture, "GlobalAsLocalInnerRead") @@ -275,7 +275,7 @@ function baz() bar = 0 end return foo() + baz() )"); - CHECK_EQ(result.warnings.size(), 0); + REQUIRE(0 == result.warnings.size()); } TEST_CASE_FIXTURE(Fixture, "GlobalAsLocalMulti") @@ -304,7 +304,7 @@ fnA() -- prints "true", "nil" fnB() -- prints "false", "nil" )"); - CHECK_EQ(result.warnings.size(), 1); + REQUIRE(1 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Global 'moreInternalLogic' is only used in the enclosing function defined at line 2; consider changing it to local"); } @@ -319,7 +319,7 @@ local arg = 5 print(arg) )"); - CHECK_EQ(result.warnings.size(), 1); + REQUIRE(1 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Variable 'arg' shadows previous declaration at line 2"); } @@ -337,7 +337,7 @@ end return bar() )"); - CHECK_EQ(result.warnings.size(), 1); + REQUIRE(1 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Variable 'global' shadows a global variable used at line 3"); } @@ -352,7 +352,7 @@ end return bar() )"); - CHECK_EQ(result.warnings.size(), 1); + REQUIRE(1 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Variable 'a' shadows previous declaration at line 2"); } @@ -372,7 +372,7 @@ end return bar() )"); - CHECK_EQ(result.warnings.size(), 2); + REQUIRE(2 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Variable 'arg' is never used; prefix with '_' to silence"); CHECK_EQ(result.warnings[1].text, "Variable 'blarg' is never used; prefix with '_' to silence"); } @@ -387,7 +387,7 @@ local Roact = require(game.Packages.Roact) local _Roact = require(game.Packages.Roact) )"); - CHECK_EQ(result.warnings.size(), 1); + REQUIRE(1 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Import 'Roact' is never used; prefix with '_' to silence"); } @@ -412,7 +412,7 @@ end return foo() )"); - CHECK_EQ(result.warnings.size(), 2); + REQUIRE(2 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Function 'bar' is never used; prefix with '_' to silence"); CHECK_EQ(result.warnings[1].text, "Function 'qux' is never used; prefix with '_' to silence"); } @@ -427,7 +427,7 @@ end print("hi!") )"); - CHECK_EQ(result.warnings.size(), 1); + REQUIRE(1 == result.warnings.size()); CHECK_EQ(result.warnings[0].location.begin.line, 5); CHECK_EQ(result.warnings[0].text, "Unreachable code (previous statement always returns)"); } @@ -443,7 +443,7 @@ end print("hi!") )"); - CHECK_EQ(result.warnings.size(), 1); + REQUIRE(1 == result.warnings.size()); CHECK_EQ(result.warnings[0].location.begin.line, 3); CHECK_EQ(result.warnings[0].text, "Unreachable code (previous statement always breaks)"); } @@ -459,7 +459,7 @@ end print("hi!") )"); - CHECK_EQ(result.warnings.size(), 1); + REQUIRE(1 == result.warnings.size()); CHECK_EQ(result.warnings[0].location.begin.line, 3); CHECK_EQ(result.warnings[0].text, "Unreachable code (previous statement always continues)"); } @@ -495,7 +495,7 @@ end return { foo1, foo2, foo3 } )"); - CHECK_EQ(result.warnings.size(), 1); + REQUIRE(1 == result.warnings.size()); CHECK_EQ(result.warnings[0].location.begin.line, 7); CHECK_EQ(result.warnings[0].text, "Unreachable code (previous statement always returns)"); } @@ -515,7 +515,7 @@ end return foo1 )"); - CHECK_EQ(result.warnings.size(), 0); + REQUIRE(0 == result.warnings.size()); } TEST_CASE_FIXTURE(Fixture, "UnreachableCodeAssertFalseReturnSilent") @@ -532,7 +532,7 @@ end return foo1 )"); - CHECK_EQ(result.warnings.size(), 0); + REQUIRE(0 == result.warnings.size()); } TEST_CASE_FIXTURE(Fixture, "UnreachableCodeErrorReturnNonSilentBranchy") @@ -550,7 +550,7 @@ end return foo1 )"); - CHECK_EQ(result.warnings.size(), 1); + REQUIRE(1 == result.warnings.size()); CHECK_EQ(result.warnings[0].location.begin.line, 7); CHECK_EQ(result.warnings[0].text, "Unreachable code (previous statement always errors)"); } @@ -571,7 +571,7 @@ end return foo1 )"); - CHECK_EQ(result.warnings.size(), 1); + REQUIRE(1 == result.warnings.size()); CHECK_EQ(result.warnings[0].location.begin.line, 8); CHECK_EQ(result.warnings[0].text, "Unreachable code (previous statement always errors)"); } @@ -589,7 +589,7 @@ end return foo1 )"); - CHECK_EQ(result.warnings.size(), 0); + REQUIRE(0 == result.warnings.size()); } TEST_CASE_FIXTURE(Fixture, "UnreachableCodeLoopRepeat") @@ -605,8 +605,8 @@ end return foo1 )"); - CHECK_EQ(result.warnings.size(), - 0); // this is technically a bug, since the repeat body always returns; fixing this bug is a bit more involved than I'd like + // this is technically a bug, since the repeat body always returns; fixing this bug is a bit more involved than I'd like + REQUIRE(0 == result.warnings.size()); } TEST_CASE_FIXTURE(Fixture, "UnknownType") @@ -633,7 +633,7 @@ local _o02 = type(game) == "vector" local _o03 = typeof(game) == "Part" )"); - REQUIRE_EQ(result.warnings.size(), 3); + REQUIRE(3 == result.warnings.size()); CHECK_EQ(result.warnings[0].location.begin.line, 2); CHECK_EQ(result.warnings[0].text, "Unknown type 'Part' (expected primitive type)"); CHECK_EQ(result.warnings[1].location.begin.line, 3); @@ -654,7 +654,7 @@ for i=#t,1,-1 do end )"); - CHECK_EQ(result.warnings.size(), 1); + REQUIRE(1 == result.warnings.size()); CHECK_EQ(result.warnings[0].location.begin.line, 3); CHECK_EQ(result.warnings[0].text, "For loop should iterate backwards; did you forget to specify -1 as step?"); } @@ -669,7 +669,7 @@ for i=8,1,-1 do end )"); - CHECK_EQ(result.warnings.size(), 1); + REQUIRE(1 == result.warnings.size()); CHECK_EQ(result.warnings[0].location.begin.line, 1); CHECK_EQ(result.warnings[0].text, "For loop should iterate backwards; did you forget to specify -1 as step?"); } @@ -684,7 +684,7 @@ for i=1.3,7.5,1 do end )"); - CHECK_EQ(result.warnings.size(), 1); + REQUIRE(1 == result.warnings.size()); CHECK_EQ(result.warnings[0].location.begin.line, 1); CHECK_EQ(result.warnings[0].text, "For loop ends at 7.3 instead of 7.5; did you forget to specify step?"); } @@ -702,7 +702,7 @@ for i=#t,0 do end )"); - CHECK_EQ(result.warnings.size(), 2); + REQUIRE(2 == result.warnings.size()); CHECK_EQ(result.warnings[0].location.begin.line, 1); CHECK_EQ(result.warnings[0].text, "For loop starts at 0, but arrays start at 1"); CHECK_EQ(result.warnings[1].location.begin.line, 7); @@ -730,7 +730,7 @@ local _a,_b,_c = pcall(), nil end )"); - CHECK_EQ(result.warnings.size(), 2); + REQUIRE(2 == result.warnings.size()); CHECK_EQ(result.warnings[0].location.begin.line, 5); CHECK_EQ(result.warnings[0].text, "Assigning 2 values to 3 variables initializes extra variables with nil; add 'nil' to value list to silence"); CHECK_EQ(result.warnings[1].location.begin.line, 11); @@ -795,7 +795,7 @@ end return f1,f2,f3,f4,f5,f6,f7 )"); - CHECK_EQ(result.warnings.size(), 3); + REQUIRE(3 == result.warnings.size()); CHECK_EQ(result.warnings[0].location.begin.line, 4); CHECK_EQ(result.warnings[0].text, "Function 'f1' can implicitly return no values even though there's an explicit return at line 4; add explicit return to silence"); @@ -851,7 +851,7 @@ end return f1,f2,f3,f4 )"); - CHECK_EQ(result.warnings.size(), 2); + REQUIRE(2 == result.warnings.size()); CHECK_EQ(result.warnings[0].location.begin.line, 25); CHECK_EQ(result.warnings[0].text, "Function 'f3' can implicitly return no values even though there's an explicit return at line 21; add explicit return to silence"); @@ -874,7 +874,7 @@ type InputData = { } )"); - CHECK_EQ(result.warnings.size(), 0); + REQUIRE(0 == result.warnings.size()); } TEST_CASE_FIXTURE(Fixture, "BreakFromInfiniteLoopMakesStatementReachable") @@ -893,7 +893,7 @@ until true return 1 )"); - CHECK_EQ(result.warnings.size(), 0); + REQUIRE(0 == result.warnings.size()); } TEST_CASE_FIXTURE(Fixture, "IgnoreLintAll") @@ -903,7 +903,7 @@ TEST_CASE_FIXTURE(Fixture, "IgnoreLintAll") return foo )"); - CHECK_EQ(result.warnings.size(), 0); + REQUIRE(0 == result.warnings.size()); } TEST_CASE_FIXTURE(Fixture, "IgnoreLintSpecific") @@ -914,7 +914,7 @@ local x = 1 return foo )"); - CHECK_EQ(result.warnings.size(), 1); + REQUIRE(1 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Variable 'x' is never used; prefix with '_' to silence"); } @@ -933,7 +933,7 @@ local _ = ("%"):format() string.format("hello %+10d %.02f %%", 4, 5) )"); - CHECK_EQ(result.warnings.size(), 4); + REQUIRE(4 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Invalid format string: unfinished format specifier"); CHECK_EQ(result.warnings[1].text, "Invalid format string: invalid format specifier: must be a string format specifier or %"); CHECK_EQ(result.warnings[2].text, "Invalid format string: invalid format specifier: must be a string format specifier or %"); @@ -973,7 +973,7 @@ string.packsize("c99999999999999999999") string.packsize("=!1bbbI3c42") )"); - CHECK_EQ(result.warnings.size(), 11); + REQUIRE(11 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Invalid pack format: unexpected character; must be a pack specifier or space"); CHECK_EQ(result.warnings[1].text, "Invalid pack format: unexpected character; must be a pack specifier or space"); CHECK_EQ(result.warnings[2].text, "Invalid pack format: unexpected character; must be a pack specifier or space"); @@ -1017,7 +1017,7 @@ local _ = s:match("%q") string.match(s, "[A-Z]+(%d)%1") )"); - CHECK_EQ(result.warnings.size(), 14); + REQUIRE(14 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Invalid match pattern: invalid character class, must refer to a defined class or its inverse"); CHECK_EQ(result.warnings[1].text, "Invalid match pattern: invalid character class, must refer to a defined class or its inverse"); CHECK_EQ(result.warnings[2].text, "Invalid match pattern: invalid character class, must refer to a defined class or its inverse"); @@ -1049,7 +1049,7 @@ string.match(s, "((a)%1)") string.match(s, "((a)%3)") )~"); - CHECK_EQ(result.warnings.size(), 2); + REQUIRE(2 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Invalid match pattern: invalid capture reference, must refer to a closed capture"); CHECK_EQ(result.warnings[0].location.begin.line, 7); CHECK_EQ(result.warnings[1].text, "Invalid match pattern: invalid capture reference, must refer to a valid capture"); @@ -1087,7 +1087,7 @@ string.match(s, "[]|'[]") string.match(s, "[^]|'[]") )~"); - CHECK_EQ(result.warnings.size(), 7); + REQUIRE(7 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Invalid match pattern: expected ] at the end of the string to close a set"); CHECK_EQ(result.warnings[1].text, "Invalid match pattern: expected ] at the end of the string to close a set"); CHECK_EQ(result.warnings[2].text, "Invalid match pattern: character range can't include character sets"); @@ -1118,7 +1118,7 @@ string.find("foo"); ("foo"):find() )"); - CHECK_EQ(result.warnings.size(), 2); + REQUIRE(2 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Invalid match pattern: invalid character class, must refer to a defined class or its inverse"); CHECK_EQ(result.warnings[0].location.begin.line, 4); CHECK_EQ(result.warnings[1].text, "Invalid match pattern: invalid character class, must refer to a defined class or its inverse"); @@ -1141,7 +1141,7 @@ string.gsub(s, '[A-Z]+(%d)', "%0%1") string.gsub(s, 'foo', "%0") )"); - CHECK_EQ(result.warnings.size(), 4); + REQUIRE(4 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Invalid match replacement: unfinished replacement"); CHECK_EQ(result.warnings[1].text, "Invalid match replacement: unexpected replacement character; must be a digit or %"); CHECK_EQ(result.warnings[2].text, "Invalid match replacement: invalid capture index, must refer to pattern capture"); @@ -1162,7 +1162,7 @@ os.date("it's %c now") os.date("!*t") )"); - CHECK_EQ(result.warnings.size(), 4); + REQUIRE(4 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Invalid date format: unfinished replacement"); CHECK_EQ(result.warnings[1].text, "Invalid date format: unexpected replacement character; must be a date format specifier or %"); CHECK_EQ(result.warnings[2].text, "Invalid date format: unexpected replacement character; must be a date format specifier or %"); @@ -1181,7 +1181,7 @@ s:match("[]") nons:match("[]") )~"); - REQUIRE_EQ(result.warnings.size(), 2); + REQUIRE(2 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Invalid match pattern: expected ] at the end of the string to close a set"); CHECK_EQ(result.warnings[0].location.begin.line, 3); CHECK_EQ(result.warnings[1].text, "Invalid match pattern: expected ] at the end of the string to close a set"); @@ -1231,7 +1231,7 @@ _ = { } )"); - CHECK_EQ(result.warnings.size(), 6); + REQUIRE(6 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Table field 'first' is a duplicate; previously defined at line 3"); CHECK_EQ(result.warnings[1].text, "Table field 'first' is a duplicate; previously defined at line 9"); CHECK_EQ(result.warnings[2].text, "Table index 1 is a duplicate; previously defined as a list entry"); @@ -1248,7 +1248,7 @@ TEST_CASE_FIXTURE(Fixture, "ImportOnlyUsedInTypeAnnotation") local x: Foo.Y = 1 )"); - REQUIRE_EQ(result.warnings.size(), 1); + REQUIRE(1 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Variable 'x' is never used; prefix with '_' to silence"); } @@ -1259,7 +1259,7 @@ TEST_CASE_FIXTURE(Fixture, "DisableUnknownGlobalWithTypeChecking") unknownGlobal() )"); - REQUIRE_EQ(result.warnings.size(), 0); + REQUIRE(0 == result.warnings.size()); } TEST_CASE_FIXTURE(Fixture, "no_spurious_warning_after_a_function_type_alias") @@ -1271,7 +1271,7 @@ TEST_CASE_FIXTURE(Fixture, "no_spurious_warning_after_a_function_type_alias") return exports )"); - CHECK_EQ(result.warnings.size(), 0); + REQUIRE(0 == result.warnings.size()); } TEST_CASE_FIXTURE(Fixture, "use_all_parent_scopes_for_globals") @@ -1294,7 +1294,7 @@ TEST_CASE_FIXTURE(Fixture, "use_all_parent_scopes_for_globals") LintResult result = frontend.lint("A"); - CHECK_EQ(result.warnings.size(), 0); + REQUIRE(0 == result.warnings.size()); } TEST_CASE_FIXTURE(Fixture, "DeadLocalsUsed") @@ -1320,7 +1320,7 @@ do end )"); - CHECK_EQ(result.warnings.size(), 3); + REQUIRE(3 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Variable 'x' defined at line 4 is never initialized or assigned; initialize with 'nil' to silence"); CHECK_EQ(result.warnings[1].text, "Assigning 2 values to 3 variables initializes extra variables with nil; add 'nil' to value list to silence"); CHECK_EQ(result.warnings[2].text, "Variable 'c' defined at line 12 is never initialized or assigned; initialize with 'nil' to silence"); @@ -1333,7 +1333,7 @@ local foo function foo() end )"); - CHECK_EQ(result.warnings.size(), 0); + REQUIRE(0 == result.warnings.size()); } TEST_CASE_FIXTURE(Fixture, "DuplicateGlobalFunction") @@ -1408,7 +1408,7 @@ TEST_CASE_FIXTURE(Fixture, "DontTriggerTheWarningIfTheFunctionsAreInDifferentSco return c )"); - CHECK(result.warnings.empty()); + REQUIRE(0 == result.warnings.size()); } TEST_CASE_FIXTURE(Fixture, "LintHygieneUAF") @@ -1444,7 +1444,7 @@ TEST_CASE_FIXTURE(Fixture, "LintHygieneUAF") local h: Hooty.Pt )"); - CHECK_EQ(result.warnings.size(), 12); + REQUIRE(12 == result.warnings.size()); } TEST_CASE_FIXTURE(Fixture, "DeprecatedApi") @@ -1478,7 +1478,7 @@ return function (i: Instance) end )"); - REQUIRE_EQ(result.warnings.size(), 3); + REQUIRE(3 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Member 'Instance.Wait' is deprecated"); CHECK_EQ(result.warnings[1].text, "Member 'toHSV' is deprecated, use 'Color3:ToHSV' instead"); CHECK_EQ(result.warnings[2].text, "Member 'Instance.DataCost' is deprecated"); @@ -1511,7 +1511,7 @@ table.create(42, {}) table.create(42, {} :: {}) )"); - REQUIRE_EQ(result.warnings.size(), 10); + REQUIRE(10 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "table.insert will insert the value before the last element, which is likely a bug; consider removing the " "second argument or wrap it in parentheses to silence"); CHECK_EQ(result.warnings[1].text, "table.insert will append the value to the table; consider removing the second argument for efficiency"); @@ -1556,7 +1556,7 @@ _ = true and true or false -- no warning since this is is a common pattern used _ = if true then 1 elseif true then 2 else 3 )"); - REQUIRE_EQ(result.warnings.size(), 8); + REQUIRE(8 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Condition has already been checked on line 2"); CHECK_EQ(result.warnings[0].location.begin.line + 1, 4); CHECK_EQ(result.warnings[1].text, "Condition has already been checked on column 5"); @@ -1580,7 +1580,7 @@ elseif correct({a = 1, b = 2 * (-2), c = opaque.path['with']("calls", false)}) t end )"); - REQUIRE_EQ(result.warnings.size(), 1); + REQUIRE(1 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Condition has already been checked on line 4"); CHECK_EQ(result.warnings[0].location.begin.line + 1, 5); } @@ -1601,7 +1601,7 @@ end return foo, moo, a1, a2 )"); - REQUIRE_EQ(result.warnings.size(), 4); + REQUIRE(4 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Function parameter 'a1' already defined on column 14"); CHECK_EQ(result.warnings[1].text, "Variable 'a1' is never used; prefix with '_' to silence"); CHECK_EQ(result.warnings[2].text, "Variable 'a1' already defined on column 7"); @@ -1618,7 +1618,7 @@ _ = math.random() < 0.5 and 0 or 42 _ = (math.random() < 0.5 and false) or 42 -- currently ignored )"); - REQUIRE_EQ(result.warnings.size(), 2); + REQUIRE(2 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "The and-or expression always evaluates to the second alternative because the first alternative is false; " "consider using if-then-else expression instead"); CHECK_EQ(result.warnings[1].text, "The and-or expression always evaluates to the second alternative because the first alternative is nil; " @@ -1640,7 +1640,7 @@ do end --!nolint )"); - REQUIRE_EQ(result.warnings.size(), 6); + REQUIRE(6 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Unknown comment directive 'struct'; did you mean 'strict'?"); CHECK_EQ(result.warnings[1].text, "Unknown comment directive 'nolintGlobal'"); CHECK_EQ(result.warnings[2].text, "nolint directive refers to unknown lint rule 'Global'"); @@ -1656,7 +1656,7 @@ TEST_CASE_FIXTURE(Fixture, "WrongCommentMuteSelf") --!struct )"); - REQUIRE_EQ(result.warnings.size(), 0); // --!nolint disables WrongComment lint :) + REQUIRE(0 == result.warnings.size()); // --!nolint disables WrongComment lint :) } TEST_CASE_FIXTURE(Fixture, "DuplicateConditionsIfStatAndExpr") @@ -1668,7 +1668,7 @@ elseif if 0 then 5 else 4 then end )"); - REQUIRE_EQ(result.warnings.size(), 1); + REQUIRE(1 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Condition has already been checked on line 2"); } @@ -1681,13 +1681,13 @@ TEST_CASE_FIXTURE(Fixture, "WrongCommentOptimize") --!optimize 2 )"); - REQUIRE_EQ(result.warnings.size(), 3); + REQUIRE(3 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "optimize directive requires an optimization level"); CHECK_EQ(result.warnings[1].text, "optimize directive uses unknown optimization level 'me', 0..2 expected"); CHECK_EQ(result.warnings[2].text, "optimize directive uses unknown optimization level '100500', 0..2 expected"); result = lint("--!optimize "); - REQUIRE_EQ(result.warnings.size(), 1); + REQUIRE(1 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "optimize directive requires an optimization level"); } @@ -1700,7 +1700,7 @@ TEST_CASE_FIXTURE(Fixture, "TestStringInterpolation") local _ = `unknown {foo}` )"); - REQUIRE_EQ(result.warnings.size(), 1); + REQUIRE(1 == result.warnings.size()); } TEST_CASE_FIXTURE(Fixture, "IntegerParsing") @@ -1710,7 +1710,7 @@ local _ = 0b10000000000000000000000000000000000000000000000000000000000000000 local _ = 0x10000000000000000 )"); - REQUIRE_EQ(result.warnings.size(), 2); + REQUIRE(2 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Binary number literal exceeded available precision and has been truncated to 2^64"); CHECK_EQ(result.warnings[1].text, "Hexadecimal number literal exceeded available precision and has been truncated to 2^64"); } @@ -1725,7 +1725,7 @@ local _ = 0x0x123 local _ = 0x0xffffffffffffffffffffffffffffffffff )"); - REQUIRE_EQ(result.warnings.size(), 2); + REQUIRE(2 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Hexadecimal number literal has a double prefix, which will fail to parse in the future; remove the extra 0x to fix"); CHECK_EQ(result.warnings[1].text, @@ -1756,7 +1756,7 @@ local _ = (a <= b) == 0 local _ = a <= (b == 0) )"); - REQUIRE_EQ(result.warnings.size(), 5); + REQUIRE(5 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "not X == Y is equivalent to (not X) == Y; consider using X ~= Y, or add parentheses to silence"); CHECK_EQ(result.warnings[1].text, "not X ~= Y is equivalent to (not X) ~= Y; consider using X == Y, or add parentheses to silence"); CHECK_EQ(result.warnings[2].text, "not X <= Y is equivalent to (not X) <= Y; add parentheses to silence"); diff --git a/tests/Normalize.test.cpp b/tests/Normalize.test.cpp index 31df707d7..156cbbc5d 100644 --- a/tests/Normalize.test.cpp +++ b/tests/Normalize.test.cpp @@ -356,6 +356,11 @@ TEST_CASE_FIXTURE(NormalizeFixture, "table_with_any_prop") TEST_CASE_FIXTURE(NormalizeFixture, "intersection") { + ScopedFastFlag sffs[] { + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + check(R"( local a: number & string local b: number @@ -374,8 +379,9 @@ TEST_CASE_FIXTURE(NormalizeFixture, "intersection") CHECK(!isSubtype(c, a)); CHECK(isSubtype(a, c)); - CHECK(!isSubtype(d, a)); - CHECK(!isSubtype(a, d)); + // These types are both equivalent to never + CHECK(isSubtype(d, a)); + CHECK(isSubtype(a, d)); } TEST_CASE_FIXTURE(NormalizeFixture, "union_and_intersection") diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index b4064cfb5..662c29008 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -2722,4 +2722,59 @@ TEST_CASE_FIXTURE(Fixture, "error_message_for_using_function_as_type_annotation" result.errors[0].getMessage()); } +TEST_CASE_FIXTURE(Fixture, "get_a_nice_error_when_there_is_an_extra_comma_at_the_end_of_a_function_argument_list") +{ + ScopedFastFlag sff{"LuauCommaParenWarnings", true}; + + ParseResult result = tryParse(R"( + foo(a, b, c,) + )"); + + REQUIRE(1 == result.errors.size()); + + CHECK(Location({1, 20}, {1, 21}) == result.errors[0].getLocation()); + CHECK("Expected expression after ',' but got ')' instead" == result.errors[0].getMessage()); +} + +TEST_CASE_FIXTURE(Fixture, "get_a_nice_error_when_there_is_an_extra_comma_at_the_end_of_a_function_parameter_list") +{ + ScopedFastFlag sff{"LuauCommaParenWarnings", true}; + + ParseResult result = tryParse(R"( + export type VisitFn = ( + any, + Array>, -- extra comma here + ) -> any + )"); + + REQUIRE(1 == result.errors.size()); + + CHECK(Location({4, 8}, {4, 9}) == result.errors[0].getLocation()); + CHECK("Expected type after ',' but got ')' instead" == result.errors[0].getMessage()); +} + +TEST_CASE_FIXTURE(Fixture, "get_a_nice_error_when_there_is_an_extra_comma_at_the_end_of_a_generic_parameter_list") +{ + ScopedFastFlag sff{"LuauCommaParenWarnings", true}; + + ParseResult result = tryParse(R"( + export type VisitFn = (a: A, b: B) -> () + )"); + + REQUIRE(1 == result.errors.size()); + + CHECK(Location({1, 36}, {1, 37}) == result.errors[0].getLocation()); + CHECK("Expected type after ',' but got '>' instead" == result.errors[0].getMessage()); + + REQUIRE(1 == result.root->body.size); + + AstStatTypeAlias* t = result.root->body.data[0]->as(); + REQUIRE(t != nullptr); + + AstTypeFunction* f = t->type->as(); + REQUIRE(f != nullptr); + + CHECK(2 == f->generics.size); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.classes.test.cpp b/tests/TypeInfer.classes.test.cpp index 98883dfa7..87ec58c9f 100644 --- a/tests/TypeInfer.classes.test.cpp +++ b/tests/TypeInfer.classes.test.cpp @@ -480,4 +480,38 @@ local a: ChildClass = i CHECK_EQ("Type 'ChildClass' from 'Test' could not be converted into 'ChildClass' from 'MainModule'", toString(result.errors[0])); } +TEST_CASE_FIXTURE(ClassFixture, "intersections_of_unions_of_classes") +{ + ScopedFastFlag sffs[] { + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + local x : (BaseClass | Vector2) & (ChildClass | AnotherChild) + local y : (ChildClass | AnotherChild) + x = y + y = x + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(ClassFixture, "unions_of_intersections_of_classes") +{ + ScopedFastFlag sffs[] { + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + local x : (BaseClass & ChildClass) | (BaseClass & AnotherChild) | (BaseClass & Vector2) + local y : (ChildClass | AnotherChild) + x = y + y = x + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index a4420b9a8..fa99ff584 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -2,6 +2,7 @@ #include "Luau/AstQuery.h" #include "Luau/BuiltinDefinitions.h" +#include "Luau/Error.h" #include "Luau/Scope.h" #include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" @@ -14,6 +15,7 @@ using namespace Luau; LUAU_FASTFLAG(LuauLowerBoundsCalculation); +LUAU_FASTFLAG(LuauInstantiateInSubtyping); LUAU_FASTFLAG(LuauSpecialTypesAsterisked); TEST_SUITE_BEGIN("TypeInferFunctions"); @@ -1087,10 +1089,20 @@ f(function(a, b, c, ...) return a + b end) LUAU_REQUIRE_ERRORS(result); - CHECK_EQ(R"(Type '(number, number, a) -> number' could not be converted into '(number, number) -> number' + if (FFlag::LuauInstantiateInSubtyping) + { + CHECK_EQ(R"(Type '(number, number, a) -> number' could not be converted into '(number, number) -> number' caused by: Argument count mismatch. Function expects 3 arguments, but only 2 are specified)", - toString(result.errors[0])); + toString(result.errors[0])); + } + else + { + CHECK_EQ(R"(Type '(number, number, a) -> number' could not be converted into '(number, number) -> number' +caused by: + Argument count mismatch. Function expects 3 arguments, but only 2 are specified)", + toString(result.errors[0])); + } // Infer from variadic packs into elements result = check(R"( @@ -1189,10 +1201,20 @@ f(function(a, b, c, ...) return a + b end) LUAU_REQUIRE_ERRORS(result); - CHECK_EQ(R"(Type '(number, number, a) -> number' could not be converted into '(number, number) -> number' + if (FFlag::LuauInstantiateInSubtyping) + { + CHECK_EQ(R"(Type '(number, number, a) -> number' could not be converted into '(number, number) -> number' caused by: Argument count mismatch. Function expects 3 arguments, but only 2 are specified)", - toString(result.errors[0])); + toString(result.errors[0])); + } + else + { + CHECK_EQ(R"(Type '(number, number, a) -> number' could not be converted into '(number, number) -> number' +caused by: + Argument count mismatch. Function expects 3 arguments, but only 2 are specified)", + toString(result.errors[0])); + } // Infer from variadic packs into elements result = check(R"( diff --git a/tests/TypeInfer.generics.test.cpp b/tests/TypeInfer.generics.test.cpp index 3c8677706..1b02abc1b 100644 --- a/tests/TypeInfer.generics.test.cpp +++ b/tests/TypeInfer.generics.test.cpp @@ -10,6 +10,7 @@ #include "doctest.h" LUAU_FASTFLAG(LuauCheckGenericHOFTypes) +LUAU_FASTFLAG(LuauInstantiateInSubtyping) LUAU_FASTFLAG(LuauSpecialTypesAsterisked) using namespace Luau; @@ -960,7 +961,11 @@ TEST_CASE_FIXTURE(Fixture, "instantiate_generic_function_in_assignments") TypeMismatch* tm = get(result.errors[0]); REQUIRE(tm); CHECK_EQ("((number) -> number, string) -> number", toString(tm->wantedType)); - CHECK_EQ("((number) -> number, number) -> number", toString(tm->givenType)); + if (FFlag::LuauInstantiateInSubtyping) + CHECK_EQ("((a) -> (b...), a) -> (b...)", toString(tm->givenType)); + else + CHECK_EQ("((number) -> number, number) -> number", toString(tm->givenType)); + } TEST_CASE_FIXTURE(Fixture, "instantiate_generic_function_in_assignments2") @@ -980,7 +985,10 @@ TEST_CASE_FIXTURE(Fixture, "instantiate_generic_function_in_assignments2") TypeMismatch* tm = get(result.errors[0]); REQUIRE(tm); CHECK_EQ("(string, string) -> number", toString(tm->wantedType)); - CHECK_EQ("((string) -> number, string) -> number", toString(*tm->givenType)); + if (FFlag::LuauInstantiateInSubtyping) + CHECK_EQ("((a) -> (b...), a) -> (b...)", toString(tm->givenType)); + else + CHECK_EQ("((string) -> number, string) -> number", toString(*tm->givenType)); } TEST_CASE_FIXTURE(Fixture, "self_recursive_instantiated_param") @@ -1110,6 +1118,15 @@ local c = sumrec(function(x, y, f) return f(x, y) end) -- type binders are not i { LUAU_REQUIRE_NO_ERRORS(result); } + else if (FFlag::LuauInstantiateInSubtyping) + { + LUAU_REQUIRE_ERRORS(result); + CHECK_EQ( + R"(Type '(a, b, (a, b) -> (c...)) -> (c...)' could not be converted into '(a, a, (a, a) -> a) -> a' +caused by: + Argument #1 type is not compatible. Generic subtype escaping scope)", + toString(result.errors[0])); + } else { LUAU_REQUIRE_ERRORS(result); @@ -1219,4 +1236,48 @@ TEST_CASE_FIXTURE(Fixture, "do_not_always_instantiate_generic_intersection_types LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(BuiltinsFixture, "hof_subtype_instantiation_regression") +{ + CheckResult result = check(R"( +--!strict + +local function defaultSort(a: T, b: T) + return true +end +type A = any +return function(array: {T}): {T} + table.sort(array, defaultSort) + return array +end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "higher_rank_polymorphism_should_not_accept_instantiated_arguments") +{ + ScopedFastFlag sffs[] = { + {"LuauInstantiateInSubtyping", true}, + {"LuauCheckGenericHOFTypes", true}, // necessary because of interactions with the test + }; + + CheckResult result = check(R"( +--!strict + +local function instantiate(f: (a) -> a): (number) -> number + return f +end + +instantiate(function(x: string) return "foo" end) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + auto tm1 = get(result.errors[0]); + REQUIRE(tm1); + + CHECK_EQ("(a) -> a", toString(tm1->wantedType)); + CHECK_EQ("(string) -> string", toString(tm1->givenType)); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.intersectionTypes.test.cpp b/tests/TypeInfer.intersectionTypes.test.cpp index 818d0124c..e49df1017 100644 --- a/tests/TypeInfer.intersectionTypes.test.cpp +++ b/tests/TypeInfer.intersectionTypes.test.cpp @@ -446,4 +446,459 @@ TEST_CASE_FIXTURE(Fixture, "no_stack_overflow_from_flattenintersection") LUAU_REQUIRE_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "intersect_bool_and_false") +{ + CheckResult result = check(R"( + local x : (boolean & false) + local y : false = x -- OK + local z : true = x -- Not OK + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type 'boolean & false' could not be converted into 'true'; none of the intersection parts are compatible"); +} + +TEST_CASE_FIXTURE(Fixture, "intersect_false_and_bool_and_false") +{ + CheckResult result = check(R"( + local x : false & (boolean & false) + local y : false = x -- OK + local z : true = x -- Not OK + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + // TODO: odd stringification of `false & (boolean & false)`.) + CHECK_EQ(toString(result.errors[0]), "Type 'boolean & false & false' could not be converted into 'true'; none of the intersection parts are compatible"); +} + +TEST_CASE_FIXTURE(Fixture, "intersect_saturate_overloaded_functions") +{ + ScopedFastFlag sffs[] { + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + local x : ((number?) -> number?) & ((string?) -> string?) + local y : (nil) -> nil = x -- OK + local z : (number) -> number = x -- Not OK + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type '((number?) -> number?) & ((string?) -> string?)' could not be converted into '(number) -> number'; none of the intersection parts are compatible"); +} + +TEST_CASE_FIXTURE(Fixture, "union_saturate_overloaded_functions") +{ + ScopedFastFlag sffs[] { + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + local x : ((number) -> number) & ((string) -> string) + local y : ((number | string) -> (number | string)) = x -- OK + local z : ((number | boolean) -> (number | boolean)) = x -- Not OK + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type '((number) -> number) & ((string) -> string)' could not be converted into '(boolean | number) -> boolean | number'; none of the intersection parts are compatible"); +} + +TEST_CASE_FIXTURE(Fixture, "intersection_of_tables") +{ + ScopedFastFlag sffs[] { + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + local x : { p : number?, q : string? } & { p : number?, q : number?, r : number? } + local y : { p : number?, q : nil, r : number? } = x -- OK + local z : { p : nil } = x -- Not OK + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type '{| p: number?, q: number?, r: number? |} & {| p: number?, q: string? |}' could not be converted into '{| p: nil |}'; none of the intersection parts are compatible"); +} + +TEST_CASE_FIXTURE(Fixture, "intersection_of_tables_with_top_properties") +{ + CheckResult result = check(R"( + local x : { p : number?, q : any } & { p : unknown, q : string? } + local y : { p : number?, q : string? } = x -- OK + local z : { p : string?, q : number? } = x -- Not OK + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type '{| p: number?, q: any |} & {| p: unknown, q: string? |}' could not be converted into '{| p: string?, q: number? |}'; none of the intersection parts are compatible"); +} + +TEST_CASE_FIXTURE(Fixture, "intersection_of_tables_with_never_properties") +{ + ScopedFastFlag sffs[] { + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + local x : { p : number?, q : never } & { p : never, q : string? } + local y : { p : never, q : never } = x -- OK + local z : never = x -- OK + )"); + + // TODO: this should not produce type errors, since never <: { p : never } + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type '{| p: never, q: string? |} & {| p: number?, q: never |}' could not be converted into 'never'; none of the intersection parts are compatible"); +} + +TEST_CASE_FIXTURE(Fixture, "overloaded_functions_returning_intersections") +{ + ScopedFastFlag sffs[] { + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + local x : ((number?) -> ({ p : number } & { q : number })) & ((string?) -> ({ p : number } & { r : number })) + local y : (nil) -> { p : number, q : number, r : number} = x -- OK + local z : (number?) -> { p : number, q : number, r : number} = x -- Not OK + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type '((number?) -> {| p: number |} & {| q: number |}) & ((string?) -> {| p: number |} & {| r: number |})' could not be converted into '(number?) -> {| p: number, q: number, r: number |}'; none of the intersection parts are compatible"); +} + +TEST_CASE_FIXTURE(Fixture, "overloaded_functions_mentioning_generic") +{ + ScopedFastFlag sffs[] { + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + function f() + local x : ((number?) -> (a | number)) & ((string?) -> (a | string)) + local y : (nil) -> a = x -- OK + local z : (number?) -> a = x -- Not OK + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type '((number?) -> a | number) & ((string?) -> a | string)' could not be converted into '(number?) -> a'; none of the intersection parts are compatible"); +} + +TEST_CASE_FIXTURE(Fixture, "overloaded_functions_mentioning_generics") +{ + ScopedFastFlag sffs[] { + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + function f() + local x : ((a?) -> (a | b)) & ((c?) -> (b | c)) + local y : (nil) -> ((a & c) | b) = x -- OK + local z : (a?) -> ((a & c) | b) = x -- Not OK + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type '((a?) -> a | b) & ((c?) -> b | c)' could not be converted into '(a?) -> (a & c) | b'; none of the intersection parts are compatible"); +} + +TEST_CASE_FIXTURE(Fixture, "overloaded_functions_mentioning_generic_packs") +{ + ScopedFastFlag sffs[] { + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + function f() + local x : ((number?, a...) -> (number?, b...)) & ((string?, a...) -> (string?, b...)) + local y : ((nil, a...) -> (nil, b...)) = x -- OK + local z : ((nil, b...) -> (nil, a...)) = x -- Not OK + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type '((number?, a...) -> (number?, b...)) & ((string?, a...) -> (string?, b...))' could not be converted into '(nil, b...) -> (nil, a...)'; none of the intersection parts are compatible"); +} + +TEST_CASE_FIXTURE(Fixture, "overloadeded_functions_with_unknown_result") +{ + ScopedFastFlag sffs[] { + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + function f() + local x : ((number) -> number) & ((nil) -> unknown) + local y : (number?) -> unknown = x -- OK + local z : (number?) -> number? = x -- Not OK + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type '((nil) -> unknown) & ((number) -> number)' could not be converted into '(number?) -> number?'; none of the intersection parts are compatible"); +} + +TEST_CASE_FIXTURE(Fixture, "overloadeded_functions_with_unknown_arguments") +{ + ScopedFastFlag sffs[] { + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + function f() + local x : ((number) -> number?) & ((unknown) -> string?) + local y : (number) -> nil = x -- OK + local z : (number?) -> nil = x -- Not OK + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type '((number) -> number?) & ((unknown) -> string?)' could not be converted into '(number?) -> nil'; none of the intersection parts are compatible"); +} + +TEST_CASE_FIXTURE(Fixture, "overloadeded_functions_with_never_result") +{ + ScopedFastFlag sffs[] { + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + function f() + local x : ((number) -> number) & ((nil) -> never) + local y : (number?) -> number = x -- OK + local z : (number?) -> never = x -- Not OK + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type '((nil) -> never) & ((number) -> number)' could not be converted into '(number?) -> never'; none of the intersection parts are compatible"); +} + +TEST_CASE_FIXTURE(Fixture, "overloadeded_functions_with_never_arguments") +{ + ScopedFastFlag sffs[] { + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + function f() + local x : ((number) -> number?) & ((never) -> string?) + local y : (never) -> nil = x -- OK + local z : (number?) -> nil = x -- Not OK + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type '((never) -> string?) & ((number) -> number?)' could not be converted into '(number?) -> nil'; none of the intersection parts are compatible"); +} + +TEST_CASE_FIXTURE(Fixture, "overloadeded_functions_with_overlapping_results_and_variadics") +{ + CheckResult result = check(R"( + local x : ((string?) -> (string | number)) & ((number?) -> ...number) + local y : ((nil) -> (number, number?)) = x -- OK + local z : ((string | number) -> (number, number?)) = x -- Not OK + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type '((number?) -> (...number)) & ((string?) -> number | string)' could not be converted into '(number | string) -> (number, number?)'; none of the intersection parts are compatible"); +} + +TEST_CASE_FIXTURE(Fixture, "overloadeded_functions_with_weird_typepacks_1") +{ + CheckResult result = check(R"( + function f() + local x : (() -> a...) & (() -> b...) + local y : (() -> b...) & (() -> a...) = x -- OK + local z : () -> () = x -- Not OK + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type '(() -> (a...)) & (() -> (b...))' could not be converted into '() -> ()'; none of the intersection parts are compatible"); +} + +TEST_CASE_FIXTURE(Fixture, "overloadeded_functions_with_weird_typepacks_2") +{ + CheckResult result = check(R"( + function f() + local x : ((a...) -> ()) & ((b...) -> ()) + local y : ((b...) -> ()) & ((a...) -> ()) = x -- OK + local z : () -> () = x -- Not OK + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type '((a...) -> ()) & ((b...) -> ())' could not be converted into '() -> ()'; none of the intersection parts are compatible"); +} + +TEST_CASE_FIXTURE(Fixture, "overloadeded_functions_with_weird_typepacks_3") +{ + CheckResult result = check(R"( + function f() + local x : (() -> a...) & (() -> (number?,a...)) + local y : (() -> (number?,a...)) & (() -> a...) = x -- OK + local z : () -> (number) = x -- Not OK + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type '(() -> (a...)) & (() -> (number?, a...))' could not be converted into '() -> number'; none of the intersection parts are compatible"); +} + +TEST_CASE_FIXTURE(Fixture, "overloadeded_functions_with_weird_typepacks_4") +{ + CheckResult result = check(R"( + function f() + local x : ((a...) -> ()) & ((number,a...) -> number) + local y : ((number,a...) -> number) & ((a...) -> ()) = x -- OK + local z : (number?) -> () = x -- Not OK + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type '((a...) -> ()) & ((number, a...) -> number)' could not be converted into '(number?) -> ()'; none of the intersection parts are compatible"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "intersect_metatables") +{ + ScopedFastFlag sffs[] { + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + local a : string? = nil + local b : number? = nil + + local x = setmetatable({}, { p = 5, q = a }); + local y = setmetatable({}, { q = b, r = "hi" }); + local z = setmetatable({}, { p = 5, q = nil, r = "hi" }); + + type X = typeof(x) + type Y = typeof(y) + type Z = typeof(z) + + local xy : X&Y = z; + local yx : Y&X = z; + z = xy; + z = yx; + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "intersect_metatable_subtypes") +{ + ScopedFastFlag sffs[] { + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + local x = setmetatable({ a = 5 }, { p = 5 }); + local y = setmetatable({ b = "hi" }, { p = 5, q = "hi" }); + local z = setmetatable({ a = 5, b = "hi" }, { p = 5, q = "hi" }); + + type X = typeof(x) + type Y = typeof(y) + type Z = typeof(z) + + local xy : X&Y = z; + local yx : Y&X = z; + z = xy; + z = yx; + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + + +TEST_CASE_FIXTURE(BuiltinsFixture, "intersect_metatables_with_properties") +{ + ScopedFastFlag sffs[] { + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + local x = setmetatable({ a = 5 }, { p = 5 }); + local y = setmetatable({ b = "hi" }, { q = "hi" }); + local z = setmetatable({ a = 5, b = "hi" }, { p = 5, q = "hi" }); + + type X = typeof(x) + type Y = typeof(y) + type Z = typeof(z) + + local xy : X&Y = z; + z = xy; + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "intersect_metatable_with table") +{ + ScopedFastFlag sffs[] { + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + local x = setmetatable({ a = 5 }, { p = 5 }); + local z = setmetatable({ a = 5, b = "hi" }, { p = 5 }); + + type X = typeof(x) + type Y = { b : string } + type Z = typeof(z) + + -- TODO: once we have shape types, we should be able to initialize these with z + local xy : X&Y; + local yx : Y&X; + z = xy; + z = yx; + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "CLI-44817") +{ + ScopedFastFlag sffs[] { + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + type X = {x: number} + type Y = {y: number} + type Z = {z: number} + + type XY = {x: number, y: number} + type XYZ = {x:number, y: number, z: number} + + local xy: XY = {x = 0, y = 0} + local xyz: XYZ = {x = 0, y = 0, z = 0} + + local xNy: X&Y = xy + local xNyNz: X&Y&Z = xyz + + local t1: XY = xNy -- Type 'X & Y' could not be converted into 'XY' + local t2: XY = xNyNz -- Type 'X & Y & Z' could not be converted into 'XY' + local t3: XYZ = xNyNz -- Type 'X & Y & Z' could not be converted into 'XYZ' + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.modules.test.cpp b/tests/TypeInfer.modules.test.cpp index ede84f4a5..36943cac8 100644 --- a/tests/TypeInfer.modules.test.cpp +++ b/tests/TypeInfer.modules.test.cpp @@ -10,6 +10,8 @@ #include "doctest.h" +LUAU_FASTFLAG(LuauInstantiateInSubtyping) + using namespace Luau; LUAU_FASTFLAG(LuauSpecialTypesAsterisked) @@ -248,7 +250,24 @@ end return m )"); - LUAU_REQUIRE_NO_ERRORS(result); + + if (FFlag::LuauInstantiateInSubtyping) + { + // though this didn't error before the flag, it seems as though it should error since fields of a table are invariant. + // the user's intent would likely be that these "method" fields would be read-only, but without an annotation, accepting this should be unsound. + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ(R"(Type 'n' could not be converted into 't1 where t1 = {- Clone: (t1) -> (a...) -}' +caused by: + Property 'Clone' is not compatible. Type '(a) -> ()' could not be converted into 't1 where t1 = ({- Clone: t1 -}) -> (a...)'; different number of generic type parameters)", + toString(result.errors[0])); + } + else + { + LUAU_REQUIRE_NO_ERRORS(result); + } + } TEST_CASE_FIXTURE(BuiltinsFixture, "custom_require_global") @@ -367,8 +386,6 @@ type Table = typeof(tbl) TEST_CASE_FIXTURE(BuiltinsFixture, "do_not_modify_imported_types_5") { - ScopedFastFlag luauInplaceDemoteSkipAllBound{"LuauInplaceDemoteSkipAllBound", true}; - fileResolver.source["game/A"] = R"( export type Type = {x: number, y: number} local arrayops = {} diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index 45740a0b1..2aac6653f 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -271,30 +271,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "bail_early_if_unification_is_too_complicated } } -// Should be in TypeInfer.tables.test.cpp -// It's unsound to instantiate tables containing generic methods, -// since mutating properties means table properties should be invariant. -// We currently allow this but we shouldn't! -TEST_CASE_FIXTURE(Fixture, "invariant_table_properties_means_instantiating_tables_in_call_is_unsound") -{ - CheckResult result = check(R"( - --!strict - local t = {} - function t.m(x) return x end - local a : string = t.m("hi") - local b : number = t.m(5) - function f(x : { m : (number)->number }) - x.m = function(x) return 1+x end - end - f(t) -- This shouldn't typecheck - local c : string = t.m("hi") - )"); - - // TODO: this should error! - // This should be fixed by replacing generic tables by generics with type bounds. - LUAU_REQUIRE_NO_ERRORS(result); -} - // FIXME: Move this test to another source file when removing FFlag::LuauLowerBoundsCalculation TEST_CASE_FIXTURE(Fixture, "do_not_ice_when_trying_to_pick_first_of_generic_type_pack") { @@ -608,7 +584,8 @@ TEST_CASE_FIXTURE(Fixture, "free_options_cannot_be_unified_together") InternalErrorReporter iceHandler; UnifierSharedState sharedState{&iceHandler}; - Unifier u{&arena, singletonTypes, Mode::Strict, NotNull{scope.get()}, Location{}, Variance::Covariant, sharedState}; + Normalizer normalizer{&arena, singletonTypes, NotNull{&sharedState}}; + Unifier u{NotNull{&normalizer}, Mode::Strict, NotNull{scope.get()}, Location{}, Variance::Covariant}; u.tryUnify(option1, option2); @@ -635,4 +612,87 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_loop_with_zero_iterators") LUAU_REQUIRE_NO_ERRORS(result); } +// Ideally, we would not try to export a function type with generic types from incorrect scope +TEST_CASE_FIXTURE(BuiltinsFixture, "generic_type_leak_to_module_interface") +{ + ScopedFastFlag LuauAnyifyModuleReturnGenerics{"LuauAnyifyModuleReturnGenerics", true}; + + fileResolver.source["game/A"] = R"( +local wrapStrictTable + +local metatable = { + __index = function(self, key) + local value = self.__tbl[key] + if type(value) == "table" then + -- unification of the free 'wrapStrictTable' with this function type causes generics of this function to leak out of scope + return wrapStrictTable(value, self.__name .. "." .. key) + end + return value + end, +} + +return wrapStrictTable + )"; + + frontend.check("game/A"); + + fileResolver.source["game/B"] = R"( +local wrapStrictTable = require(game.A) + +local Constants = {} + +return wrapStrictTable(Constants, "Constants") + )"; + + frontend.check("game/B"); + + ModulePtr m = frontend.moduleResolver.modules["game/B"]; + REQUIRE(m); + + std::optional result = first(m->getModuleScope()->returnType); + REQUIRE(result); + CHECK(get(*result)); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "generic_type_leak_to_module_interface_variadic") +{ + ScopedFastFlag LuauAnyifyModuleReturnGenerics{"LuauAnyifyModuleReturnGenerics", true}; + + fileResolver.source["game/A"] = R"( +local wrapStrictTable + +local metatable = { + __index = function(self, key, ...: T) + local value = self.__tbl[key] + if type(value) == "table" then + -- unification of the free 'wrapStrictTable' with this function type causes generics of this function to leak out of scope + return wrapStrictTable(value, self.__name .. "." .. key) + end + return ... + end, +} + +return wrapStrictTable + )"; + + frontend.check("game/A"); + + fileResolver.source["game/B"] = R"( +local wrapStrictTable = require(game.A) + +local Constants = {} + +return wrapStrictTable(Constants, "Constants") + )"; + + frontend.check("game/B"); + + ModulePtr m = frontend.moduleResolver.modules["game/B"]; + REQUIRE(m); + + std::optional result = first(m->getModuleScope()->returnType); + REQUIRE(result); + CHECK(get(*result)); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index d183f650e..a6d870fad 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -11,7 +11,8 @@ using namespace Luau; -LUAU_FASTFLAG(LuauLowerBoundsCalculation); +LUAU_FASTFLAG(LuauLowerBoundsCalculation) +LUAU_FASTFLAG(LuauInstantiateInSubtyping) TEST_SUITE_BEGIN("TableTests"); @@ -2038,11 +2039,22 @@ caused by: caused by: Property 'y' is not compatible. Type 'string' could not be converted into 'number')"); - CHECK_EQ(toString(result.errors[1]), R"(Type 'b2' could not be converted into 'a2' + if (FFlag::LuauInstantiateInSubtyping) + { + CHECK_EQ(toString(result.errors[1]), R"(Type 'b2' could not be converted into 'a2' +caused by: + Type '{ __call: (a, b) -> () }' could not be converted into '{ __call: (a) -> () }' +caused by: + Property '__call' is not compatible. Type '(a, b) -> ()' could not be converted into '(a) -> ()'; different number of generic type parameters)"); + } + else + { + CHECK_EQ(toString(result.errors[1]), R"(Type 'b2' could not be converted into 'a2' caused by: Type '{ __call: (a, b) -> () }' could not be converted into '{ __call: (a) -> () }' caused by: Property '__call' is not compatible. Type '(a, b) -> ()' could not be converted into '(a) -> ()'; different number of generic type parameters)"); + } } TEST_CASE_FIXTURE(Fixture, "error_detailed_indexer_key") @@ -3173,4 +3185,53 @@ caused by: CHECK_EQ("(t1) -> string where t1 = {+ absolutely_no_scalar_has_this_method: (t1) -> (a, b...) +}", toString(requireType("f"))); } +TEST_CASE_FIXTURE(Fixture, "invariant_table_properties_means_instantiating_tables_in_call_is_unsound") +{ + ScopedFastFlag sff[]{ + {"LuauInstantiateInSubtyping", true}, + }; + + CheckResult result = check(R"( + --!strict + local t = {} + function t.m(x) return x end + local a : string = t.m("hi") + local b : number = t.m(5) + function f(x : { m : (number)->number }) + x.m = function(x) return 1+x end + end + f(t) -- This shouldn't typecheck + local c : string = t.m("hi") + )"); + + LUAU_REQUIRE_ERRORS(result); + CHECK_EQ(toString(result.errors[0]), R"(Type 't' could not be converted into '{| m: (number) -> number |}' +caused by: + Property 'm' is not compatible. Type '(a) -> a' could not be converted into '(number) -> number'; different number of generic type parameters)"); + // this error message is not great since the underlying issue is that the context is invariant, + // and `(number) -> number` cannot be a subtype of `(a) -> a`. +} + + +TEST_CASE_FIXTURE(BuiltinsFixture, "generic_table_instantiation_potential_regression") +{ + CheckResult result = check(R"( +--!strict + +function f(x) + x.p = 5 + return x +end +local g : ({ p : number, q : string }) -> ({ p : number, r : boolean }) = f + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + MissingProperties* error = get(result.errors[0]); + REQUIRE(error != nullptr); + REQUIRE(error->properties.size() == 1); + + CHECK_EQ("r", error->properties[0]); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index 8ed61b496..26171c518 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -17,7 +17,9 @@ LUAU_FASTFLAG(LuauLowerBoundsCalculation); LUAU_FASTFLAG(LuauFixLocationSpanTableIndexExpr); LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); +LUAU_FASTFLAG(LuauInstantiateInSubtyping); LUAU_FASTFLAG(LuauSpecialTypesAsterisked); +LUAU_FASTFLAG(LuauCheckGenericHOFTypes); using namespace Luau; @@ -999,7 +1001,26 @@ TEST_CASE_FIXTURE(Fixture, "cli_50041_committing_txnlog_in_apollo_client_error") end )"); - LUAU_REQUIRE_NO_ERRORS(result); + if (FFlag::LuauInstantiateInSubtyping && !FFlag::LuauCheckGenericHOFTypes) + { + // though this didn't error before the flag, it seems as though it should error since fields of a table are invariant. + // the user's intent would likely be that these "method" fields would be read-only, but without an annotation, accepting this should be unsound. + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ(R"(Type 't1 where t1 = {+ getStoreFieldName: (t1, {| fieldName: string |} & {| from: number? |}) -> (a, b...) +}' could not be converted into 'Policies' +caused by: + Property 'getStoreFieldName' is not compatible. Type 't1 where t1 = ({+ getStoreFieldName: t1 +}, {| fieldName: string |} & {| from: number? |}) -> (a, b...)' could not be converted into '(Policies, FieldSpecifier) -> string' +caused by: + Argument #2 type is not compatible. Type 'FieldSpecifier' could not be converted into 'FieldSpecifier & {| from: number? |}' +caused by: + Not all intersection parts are compatible. Table type 'FieldSpecifier' not compatible with type '{| from: number? |}' because the former has extra field 'fieldName')", + toString(result.errors[0])); + } + else + { + LUAU_REQUIRE_NO_ERRORS(result); + } } TEST_CASE_FIXTURE(Fixture, "type_infer_recursion_limit_no_ice") @@ -1020,6 +1041,43 @@ TEST_CASE_FIXTURE(Fixture, "type_infer_recursion_limit_no_ice") CHECK_EQ("Code is too complex to typecheck! Consider simplifying the code around this area", toString(result.errors[0])); } +TEST_CASE_FIXTURE(Fixture, "type_infer_recursion_limit_normalizer") +{ + ScopedFastInt sfi("LuauTypeInferRecursionLimit", 10); + ScopedFastFlag sffs[] { + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + {"LuauAutocompleteDynamicLimits", true}, + }; + + CheckResult result = check(R"( + function f() + local x : a&b&c&d&e&f&g&h&(i?) + local y : (a&b&c&d&e&f&g&h&i)? = x + end + )"); + + LUAU_REQUIRE_ERRORS(result); + CHECK_EQ("Internal error: Code is too complex to typecheck! Consider adding type annotations around this area", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "type_infer_cache_limit_normalizer") +{ + ScopedFastInt sfi("LuauNormalizeCacheLimit", 10); + ScopedFastFlag sffs[] { + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + local x : ((number) -> number) & ((string) -> string) & ((nil) -> nil) & (({}) -> {}) + local y : (number | string | nil | {}) -> (number | string | nil | {}) = x + )"); + + LUAU_REQUIRE_ERRORS(result); + CHECK_EQ("Internal error: Code is too complex to typecheck! Consider adding type annotations around this area", toString(result.errors[0])); +} + TEST_CASE_FIXTURE(Fixture, "follow_on_new_types_in_substitution") { CheckResult result = check(R"( diff --git a/tests/TypeInfer.tryUnify.test.cpp b/tests/TypeInfer.tryUnify.test.cpp index 3911c520d..dedb7d28a 100644 --- a/tests/TypeInfer.tryUnify.test.cpp +++ b/tests/TypeInfer.tryUnify.test.cpp @@ -17,8 +17,8 @@ struct TryUnifyFixture : Fixture ScopePtr globalScope{new Scope{arena.addTypePack({TypeId{}})}}; InternalErrorReporter iceHandler; UnifierSharedState unifierState{&iceHandler}; - - Unifier state{&arena, singletonTypes, Mode::Strict, NotNull{globalScope.get()}, Location{}, Variance::Covariant, unifierState}; + Normalizer normalizer{&arena, singletonTypes, NotNull{&unifierState}}; + Unifier state{NotNull{&normalizer}, Mode::Strict, NotNull{globalScope.get()}, Location{}, Variance::Covariant}; }; TEST_SUITE_BEGIN("TryUnifyTests"); diff --git a/tests/TypeInfer.typePacks.cpp b/tests/TypeInfer.typePacks.cpp index 7d33809fa..eb61c396b 100644 --- a/tests/TypeInfer.typePacks.cpp +++ b/tests/TypeInfer.typePacks.cpp @@ -1000,4 +1000,23 @@ TEST_CASE_FIXTURE(Fixture, "unify_variadic_tails_in_arguments_free") CHECK_EQ(toString(result.errors[0]), "Type 'number' could not be converted into 'boolean'"); } +TEST_CASE_FIXTURE(BuiltinsFixture, "type_packs_with_tails_in_vararg_adjustment") +{ + ScopedFastFlag luauFixVarargExprHeadType{"LuauFixVarargExprHeadType", true}; + + CheckResult result = check(R"( + local function wrapReject(fn: (self: any, ...TArg) -> ...TResult): (self: any, ...TArg) -> ...TResult + return function(self, ...) + local arguments = { ... } + local ok, result = pcall(function() + return fn(self, table.unpack(arguments)) + end) + return result + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp index 8eb485e90..64c9b5630 100644 --- a/tests/TypeInfer.unionTypes.test.cpp +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -541,5 +541,182 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "table_union_write_indirect") R"(Type '(string) -> number' could not be converted into '((number) -> string) | ((number) -> string)'; none of the union options are compatible)"); } +TEST_CASE_FIXTURE(Fixture, "union_true_and_false") +{ + ScopedFastFlag sffs[] { + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + local x : boolean + local y1 : (true | false) = x -- OK + local y2 : (true | false | (string & number)) = x -- OK + local y3 : (true | (string & number) | false) = x -- OK + local y4 : (true | (boolean & true) | false) = x -- OK + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "union_of_functions") +{ + ScopedFastFlag sffs[] { + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + local x : (number) -> number? + local y : ((number?) -> number?) | ((number) -> number) = x -- OK + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "union_of_generic_functions") +{ + CheckResult result = check(R"( + local x : (a) -> a? + local y : ((a?) -> a?) | ((b) -> b) = x -- Not OK + )"); + + // TODO: should this example typecheck? + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "union_of_generic_typepack_functions") +{ + CheckResult result = check(R"( + local x : (number, a...) -> (number?, a...) + local y : ((number?, a...) -> (number?, a...)) | ((number, b...) -> (number, b...)) = x -- Not OK + )"); + + // TODO: should this example typecheck? + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "union_of_functions_mentioning_generics") +{ + ScopedFastFlag sffs[] { + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + function f() + local x : (a) -> a? + local y : ((a?) -> nil) | ((a) -> a) = x -- OK + local z : ((b?) -> nil) | ((b) -> b) = x -- Not OK + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type '(a) -> a?' could not be converted into '((b) -> b) | ((b?) -> nil)'; none of the union options are compatible"); +} + +TEST_CASE_FIXTURE(Fixture, "union_of_functions_mentioning_generic_typepacks") +{ + ScopedFastFlag sffs[] { + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + function f() + local x : (number, a...) -> (number?, a...) + local y : ((number | string, a...) -> (number, a...)) | ((number?, a...) -> (nil, a...)) = x -- OK + local z : ((number) -> number) | ((number?, a...) -> (number?, a...)) = x -- Not OK + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type '(number, a...) -> (number?, a...)' could not be converted into '((number) -> number) | ((number?, a...) -> (number?, a...))'; none of the union options are compatible"); +} + +TEST_CASE_FIXTURE(Fixture, "union_of_functions_with_mismatching_arg_arities") +{ + ScopedFastFlag sffs[] { + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + local x : (number) -> number? + local y : ((number?) -> number) | ((number | string) -> nil) = x -- OK + local z : ((number, string?) -> number) | ((number) -> nil) = x -- Not OK + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type '(number) -> number?' could not be converted into '((number) -> nil) | ((number, string?) -> number)'; none of the union options are compatible"); +} + +TEST_CASE_FIXTURE(Fixture, "union_of_functions_with_mismatching_result_arities") +{ + ScopedFastFlag sffs[] { + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + local x : () -> (number | string) + local y : (() -> number) | (() -> string) = x -- OK + local z : (() -> number) | (() -> (string, string)) = x -- Not OK + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type '() -> number | string' could not be converted into '(() -> (string, string)) | (() -> number)'; none of the union options are compatible"); +} + +TEST_CASE_FIXTURE(Fixture, "union_of_functions_with_variadics") +{ + ScopedFastFlag sffs[] { + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + local x : (...nil) -> (...number?) + local y : ((...string?) -> (...number)) | ((...number?) -> nil) = x -- OK + local z : ((...string?) -> (...number)) | ((...string?) -> nil) = x -- OK + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type '(...nil) -> (...number?)' could not be converted into '((...string?) -> (...number)) | ((...string?) -> nil)'; none of the union options are compatible"); +} + +TEST_CASE_FIXTURE(Fixture, "union_of_functions_with_mismatching_arg_variadics") +{ + ScopedFastFlag sffs[] { + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + local x : (number) -> () + local y : ((number?) -> ()) | ((...number) -> ()) = x -- OK + local z : ((number?) -> ()) | ((...number?) -> ()) = x -- Not OK + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type '(number) -> ()' could not be converted into '((...number?) -> ()) | ((number?) -> ())'; none of the union options are compatible"); +} + +TEST_CASE_FIXTURE(Fixture, "union_of_functions_with_mismatching_result_variadics") +{ + ScopedFastFlag sffs[] { + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + local x : () -> (number?, ...number) + local y : (() -> (...number)) | (() -> nil) = x -- OK + local z : (() -> (...number)) | (() -> number) = x -- OK + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type '() -> (number?, ...number)' could not be converted into '(() -> (...number)) | (() -> number)'; none of the union options are compatible"); +} TEST_SUITE_END(); diff --git a/tests/conformance/basic.lua b/tests/conformance/basic.lua index 9c19da59a..b09b087b5 100644 --- a/tests/conformance/basic.lua +++ b/tests/conformance/basic.lua @@ -508,6 +508,9 @@ assert((function() function cmp(a,b) return ab,a>=b end return concat( assert((function() function cmp(a,b) return ab,a>=b end return concat(cmp('abc', 'abd')) end)() == "true,true,false,false") assert((function() function cmp(a,b) return ab,a>=b end return concat(cmp('ab\\0c', 'ab\\0d')) end)() == "true,true,false,false") assert((function() function cmp(a,b) return ab,a>=b end return concat(cmp('ab\\0c', 'ab\\0')) end)() == "false,false,true,true") +assert((function() function cmp(a,b) return ab,a>=b end return concat(cmp('\\0a', '\\0b')) end)() == "true,true,false,false") +assert((function() function cmp(a,b) return ab,a>=b end return concat(cmp('a', 'a\\0')) end)() == "true,true,false,false") +assert((function() function cmp(a,b) return ab,a>=b end return concat(cmp('a', '\200')) end)() == "true,true,false,false") -- array access assert((function() local a = {4,5,6} return a[3] end)() == 6) diff --git a/tests/main.cpp b/tests/main.cpp index 3e480c9f0..3f564c077 100644 --- a/tests/main.cpp +++ b/tests/main.cpp @@ -21,6 +21,7 @@ #include #endif +#include #include // Indicates if verbose output is enabled; can be overridden via --verbose @@ -30,6 +31,10 @@ bool verbose = false; // Default optimization level for conformance test; can be overridden via -On int optimizationLevel = 1; +// Something to seed a pseudorandom number generator with. Defaults to +// something from std::random_device. +std::optional randomSeed; + static bool skipFastFlag(const char* flagName) { if (strncmp(flagName, "Test", 4) == 0) @@ -261,6 +266,16 @@ int main(int argc, char** argv) optimizationLevel = level; } + int rseed = -1; + if (doctest::parseIntOption(argc, argv, "--random-seed=", doctest::option_int, rseed)) + randomSeed = unsigned(rseed); + + if (doctest::parseOption(argc, argv, "--randomize") && !randomSeed) + { + randomSeed = std::random_device()(); + printf("Using RNG seed %u\n", *randomSeed); + } + if (std::vector flags; doctest::parseCommaSepArgs(argc, argv, "--fflags=", flags)) setFastFlags(flags); @@ -295,6 +310,8 @@ int main(int argc, char** argv) printf(" --verbose Enables verbose output (e.g. lua 'print' statements)\n"); printf(" --fflags= Sets specified fast flags\n"); printf(" --list-fflags List all fast flags\n"); + printf(" --randomize Use a random RNG seed\n"); + printf(" --random-seed=n Use a particular RNG seed\n"); } return result; } diff --git a/tools/lldb_formatters.lldb b/tools/lldb_formatters.lldb index 4a5acd742..f10faa94e 100644 --- a/tools/lldb_formatters.lldb +++ b/tools/lldb_formatters.lldb @@ -5,3 +5,6 @@ type synthetic add -x "^Luau::Variant<.+>$" -l lldb_formatters.LuauVariantSynthe type summary add -x "^Luau::Variant<.+>$" -F lldb_formatters.luau_variant_summary type synthetic add -x "^Luau::AstArray<.+>$" -l lldb_formatters.AstArraySyntheticChildrenProvider + +type summary add --summary-string "${var.line}:${var.column}" Luau::Position +type summary add --summary-string "${var.begin}-${var.end}" Luau::Location diff --git a/tools/natvis/Ast.natvis b/tools/natvis/Ast.natvis index 322eb8f67..18e7b762a 100644 --- a/tools/natvis/Ast.natvis +++ b/tools/natvis/Ast.natvis @@ -22,4 +22,25 @@ + + {value,na} + + + + {name.value,na} + + + + local {local->name.value,na} + global {global.value,na} + + + + {line}:{column} + + + + {begin}-{end} + + diff --git a/tools/test_dcr.py b/tools/test_dcr.py index 1e3a50176..5f1c87058 100644 --- a/tools/test_dcr.py +++ b/tools/test_dcr.py @@ -22,6 +22,10 @@ def safeParseInt(i, default=0): return default +def makeDottedName(path): + return ".".join(path) + + class Handler(x.ContentHandler): def __init__(self, failList): self.currentTest = [] @@ -41,7 +45,7 @@ def startElement(self, name, attrs): if self.currentTest: passed = attrs["test_case_success"] == "true" - dottedName = ".".join(self.currentTest) + dottedName = makeDottedName(self.currentTest) # Sometimes we get multiple XML trees for the same test. All of # them must report a pass in order for us to consider the test @@ -60,6 +64,10 @@ def endElement(self, name): self.currentTest.pop() +def print_stderr(*args, **kw): + print(*args, **kw, file=sys.stderr) + + def main(): parser = argparse.ArgumentParser( description="Run Luau.UnitTest with deferred constraint resolution enabled" @@ -80,6 +88,16 @@ def main(): help="Write a new faillist.txt after running tests.", ) + parser.add_argument("--randomize", action="store_true", help="Pick a random seed") + + parser.add_argument( + "--random-seed", + action="store", + dest="random_seed", + type=int, + help="Accept a specific RNG seed", + ) + args = parser.parse_args() failList = loadFailList() @@ -90,7 +108,12 @@ def main(): "--fflags=true,DebugLuauDeferredConstraintResolution=true", ] - print('>', ' '.join(commandLine), file=sys.stderr) + if args.random_seed: + commandLine.append("--random-seed=" + str(args.random_seed)) + elif args.randomize: + commandLine.append("--randomize") + + print_stderr(">", " ".join(commandLine)) p = sp.Popen( commandLine, @@ -104,15 +127,21 @@ def main(): sys.stdout.buffer.write(line) return else: - x.parse(p.stdout, handler) + try: + x.parse(p.stdout, handler) + except x.SAXParseException as e: + print_stderr( + f"XML parsing failed during test {makeDottedName(handler.currentTest)}. That probably means that the test crashed" + ) + sys.exit(1) p.wait() for testName, passed in handler.results.items(): if passed and testName in failList: - print("UNEXPECTED: {} should have failed".format(testName)) + print_stderr(f"UNEXPECTED: {testName} should have failed") elif not passed and testName not in failList: - print("UNEXPECTED: {} should have passed".format(testName)) + print_stderr(f"UNEXPECTED: {testName} should have passed") if args.write: newFailList = sorted( @@ -126,14 +155,11 @@ def main(): with open(FAIL_LIST_PATH, "w", newline="\n") as f: for name in newFailList: print(name, file=f) - print("Updated faillist.txt", file=sys.stderr) + print_stderr("Updated faillist.txt") if handler.numSkippedTests > 0: - print( - "{} test(s) were skipped! That probably means that a test segfaulted!".format( - handler.numSkippedTests - ), - file=sys.stderr, + print_stderr( + f"{handler.numSkippedTests} test(s) were skipped! That probably means that a test segfaulted!" ) sys.exit(1) @@ -143,7 +169,7 @@ def main(): ) if ok: - print("Everything in order!", file=sys.stderr) + print_stderr("Everything in order!") sys.exit(0 if ok else 1) From d82e73607c698cc74ccbc17bfd74ec60670903f6 Mon Sep 17 00:00:00 2001 From: Vyacheslav Egorov Date: Fri, 14 Oct 2022 01:59:53 +0300 Subject: [PATCH 09/66] Sync to upstream/release/549 --- Analysis/include/Luau/ConstraintSolver.h | 8 +- Analysis/include/Luau/DcrLogger.h | 4 +- Analysis/include/Luau/Error.h | 1 - Analysis/include/Luau/JsonEmitter.h | 2 +- Analysis/include/Luau/Normalize.h | 22 +- Analysis/include/Luau/TypeInfer.h | 2 +- Analysis/include/Luau/Unifier.h | 11 +- Analysis/src/BuiltinDefinitions.cpp | 5 +- Analysis/src/ConstraintGraphBuilder.cpp | 5 +- Analysis/src/ConstraintSolver.cpp | 71 +- Analysis/src/DcrLogger.cpp | 11 +- Analysis/src/Error.cpp | 43 +- Analysis/src/Linter.cpp | 20 +- Analysis/src/Module.cpp | 39 - Analysis/src/Normalize.cpp | 133 ++- Analysis/src/Substitution.cpp | 4 - Analysis/src/ToString.cpp | 4 +- Analysis/src/TypeChecker2.cpp | 21 +- Analysis/src/TypeInfer.cpp | 308 +------ Analysis/src/TypeUtils.cpp | 4 +- Analysis/src/TypeVar.cpp | 18 +- Analysis/src/Unifier.cpp | 23 +- Ast/include/Luau/StringUtils.h | 8 +- CLI/Repl.cpp | 71 +- CMakeLists.txt | 11 +- CodeGen/include/Luau/AssemblyBuilderX64.h | 19 +- CodeGen/include/Luau/CodeAllocator.h | 3 +- CodeGen/include/Luau/CodeGen.h | 24 + CodeGen/include/Luau/Condition.h | 3 + CodeGen/src/AssemblyBuilderX64.cpp | 37 +- CodeGen/src/CodeGen.cpp | 449 ++++++++++ CodeGen/src/CodeGenX64.cpp | 154 ++++ CodeGen/src/CodeGenX64.h | 18 + CodeGen/src/CustomExecUtils.h | 145 ++++ CodeGen/src/EmitBuiltinsX64.cpp | 109 +++ CodeGen/src/EmitBuiltinsX64.h | 28 + CodeGen/src/EmitCommonX64.cpp | 345 ++++++++ CodeGen/src/EmitCommonX64.h | 175 ++++ CodeGen/src/EmitInstructionX64.cpp | 925 +++++++++++++++++++++ CodeGen/src/EmitInstructionX64.h | 66 ++ CodeGen/src/Fallbacks.cpp | 251 ++++-- CodeGen/src/Fallbacks.h | 162 ++-- CodeGen/src/NativeState.cpp | 122 +++ CodeGen/src/NativeState.h | 94 +++ Common/include/Luau/Common.h | 6 + Common/include/Luau/ExperimentalFlags.h | 1 - Compiler/include/Luau/BytecodeBuilder.h | 7 + Compiler/src/BytecodeBuilder.cpp | 37 + Compiler/src/Compiler.cpp | 29 +- Makefile | 9 +- Sources.cmake | 14 + VM/include/lua.h | 28 +- VM/include/luaconf.h | 2 +- VM/src/lbuiltins.cpp | 98 +-- VM/src/lfunc.cpp | 2 +- VM/src/lobject.cpp | 2 - VM/src/lobject.h | 1 + VM/src/lstate.cpp | 6 + VM/src/lstate.h | 9 +- VM/src/lvmexecute.cpp | 34 +- VM/src/lvmload.cpp | 1 + bench/tests/sha256.lua | 19 +- tests/AssemblyBuilderX64.test.cpp | 15 +- tests/AstQueryDsl.cpp | 2 +- tests/AstQueryDsl.h | 2 +- tests/Compiler.test.cpp | 38 +- tests/Conformance.test.cpp | 54 +- tests/ConstraintGraphBuilderFixture.cpp | 2 +- tests/ConstraintGraphBuilderFixture.h | 2 +- tests/Fixture.cpp | 3 +- tests/Frontend.test.cpp | 2 - tests/Linter.test.cpp | 2 - tests/Module.test.cpp | 8 +- tests/Normalize.test.cpp | 702 +--------------- tests/RuntimeLimits.test.cpp | 7 +- tests/ToDot.test.cpp | 29 +- tests/ToString.test.cpp | 17 + tests/TypeInfer.annotations.test.cpp | 48 +- tests/TypeInfer.builtins.test.cpp | 22 +- tests/TypeInfer.classes.test.cpp | 4 +- tests/TypeInfer.functions.test.cpp | 182 +--- tests/TypeInfer.generics.test.cpp | 29 +- tests/TypeInfer.intersectionTypes.test.cpp | 117 +-- tests/TypeInfer.loops.test.cpp | 6 +- tests/TypeInfer.modules.test.cpp | 6 +- tests/TypeInfer.provisional.test.cpp | 284 ++++--- tests/TypeInfer.refinements.test.cpp | 6 +- tests/TypeInfer.singletons.test.cpp | 22 - tests/TypeInfer.tables.test.cpp | 12 +- tests/TypeInfer.test.cpp | 58 +- tests/TypeInfer.tryUnify.test.cpp | 4 +- tests/TypeInfer.typePacks.cpp | 10 +- tests/TypeInfer.unionTypes.test.cpp | 54 +- tests/conformance/basic.lua | 7 + tests/conformance/bitwise.lua | 8 + tests/conformance/debugger.lua | 15 + tests/conformance/events.lua | 39 +- tests/conformance/interrupt.lua | 9 + tests/conformance/math.lua | 39 +- tests/conformance/safeenv.lua | 23 + tests/main.cpp | 14 +- tools/faillist.txt | 74 +- tools/lvmexecute_split.py | 23 +- tools/test_dcr.py | 39 +- 104 files changed, 4099 insertions(+), 2223 deletions(-) create mode 100644 CodeGen/include/Luau/CodeGen.h create mode 100644 CodeGen/src/CodeGen.cpp create mode 100644 CodeGen/src/CodeGenX64.cpp create mode 100644 CodeGen/src/CodeGenX64.h create mode 100644 CodeGen/src/CustomExecUtils.h create mode 100644 CodeGen/src/EmitBuiltinsX64.cpp create mode 100644 CodeGen/src/EmitBuiltinsX64.h create mode 100644 CodeGen/src/EmitCommonX64.cpp create mode 100644 CodeGen/src/EmitCommonX64.h create mode 100644 CodeGen/src/EmitInstructionX64.cpp create mode 100644 CodeGen/src/EmitInstructionX64.h create mode 100644 CodeGen/src/NativeState.cpp create mode 100644 CodeGen/src/NativeState.h create mode 100644 tests/conformance/safeenv.lua diff --git a/Analysis/include/Luau/ConstraintSolver.h b/Analysis/include/Luau/ConstraintSolver.h index 9d5aadfbc..0bf6d1bc7 100644 --- a/Analysis/include/Luau/ConstraintSolver.h +++ b/Analysis/include/Luau/ConstraintSolver.h @@ -76,8 +76,8 @@ struct ConstraintSolver DcrLogger* logger; - explicit ConstraintSolver(NotNull normalizer, NotNull rootScope, ModuleName moduleName, - NotNull moduleResolver, std::vector requireCycles, DcrLogger* logger); + explicit ConstraintSolver(NotNull normalizer, NotNull rootScope, ModuleName moduleName, NotNull moduleResolver, + std::vector requireCycles, DcrLogger* logger); // Randomize the order in which to dispatch constraints void randomize(unsigned seed); @@ -88,7 +88,9 @@ struct ConstraintSolver **/ void run(); - bool done(); + bool isDone(); + + void finalizeModule(); /** Attempt to dispatch a constraint. Returns true if it was successful. If * tryDispatch() returns false, the constraint remains in the unsolved set diff --git a/Analysis/include/Luau/DcrLogger.h b/Analysis/include/Luau/DcrLogger.h index bd8672e32..30d2e15ec 100644 --- a/Analysis/include/Luau/DcrLogger.h +++ b/Analysis/include/Luau/DcrLogger.h @@ -112,11 +112,13 @@ struct DcrLogger void popBlock(NotNull block); void captureInitialSolverState(const Scope* rootScope, const std::vector>& unsolvedConstraints); - StepSnapshot prepareStepSnapshot(const Scope* rootScope, NotNull current, bool force, const std::vector>& unsolvedConstraints); + StepSnapshot prepareStepSnapshot( + const Scope* rootScope, NotNull current, bool force, const std::vector>& unsolvedConstraints); void commitStepSnapshot(StepSnapshot snapshot); void captureFinalSolverState(const Scope* rootScope, const std::vector>& unsolvedConstraints); void captureTypeCheckError(const TypeError& error); + private: ConstraintGenerationLog generationLog; std::unordered_map, std::vector> constraintBlocks; diff --git a/Analysis/include/Luau/Error.h b/Analysis/include/Luau/Error.h index 677548830..7338627cf 100644 --- a/Analysis/include/Luau/Error.h +++ b/Analysis/include/Luau/Error.h @@ -33,7 +33,6 @@ struct UnknownSymbol { Binding, Type, - Generic }; Name name; Context context; diff --git a/Analysis/include/Luau/JsonEmitter.h b/Analysis/include/Luau/JsonEmitter.h index d8dc96e43..1a416586a 100644 --- a/Analysis/include/Luau/JsonEmitter.h +++ b/Analysis/include/Luau/JsonEmitter.h @@ -240,7 +240,7 @@ void write(JsonEmitter& emitter, const std::unordered_map& map) for (const auto& [k, v] : map) o.writePair(k, v); - + o.finish(); } diff --git a/Analysis/include/Luau/Normalize.h b/Analysis/include/Luau/Normalize.h index 41e50d1b6..72ea95588 100644 --- a/Analysis/include/Luau/Normalize.h +++ b/Analysis/include/Luau/Normalize.h @@ -17,8 +17,10 @@ struct SingletonTypes; using ModulePtr = std::shared_ptr; -bool isSubtype(TypeId subTy, TypeId superTy, NotNull scope, NotNull singletonTypes, InternalErrorReporter& ice, bool anyIsTop = true); -bool isSubtype(TypePackId subTy, TypePackId superTy, NotNull scope, NotNull singletonTypes, InternalErrorReporter& ice, bool anyIsTop = true); +bool isSubtype( + TypeId subTy, TypeId superTy, NotNull scope, NotNull singletonTypes, InternalErrorReporter& ice, bool anyIsTop = true); +bool isSubtype(TypePackId subTy, TypePackId superTy, NotNull scope, NotNull singletonTypes, InternalErrorReporter& ice, + bool anyIsTop = true); std::pair normalize( TypeId ty, NotNull scope, TypeArena& arena, NotNull singletonTypes, InternalErrorReporter& ice); @@ -68,13 +70,14 @@ class TypeIds insert(*it); } - bool operator ==(const TypeIds& there) const; + bool operator==(const TypeIds& there) const; size_t getHash() const; }; } // namespace Luau -template<> struct std::hash +template<> +struct std::hash { std::size_t operator()(const Luau::TypeIds& tys) const { @@ -82,7 +85,8 @@ template<> struct std::hash } }; -template<> struct std::hash +template<> +struct std::hash { std::size_t operator()(const Luau::TypeIds* tys) const { @@ -90,7 +94,8 @@ template<> struct std::hash } }; -template<> struct std::equal_to +template<> +struct std::equal_to { bool operator()(const Luau::TypeIds& here, const Luau::TypeIds& there) const { @@ -98,7 +103,8 @@ template<> struct std::equal_to } }; -template<> struct std::equal_to +template<> +struct std::equal_to { bool operator()(const Luau::TypeIds* here, const Luau::TypeIds* there) const { @@ -160,7 +166,7 @@ struct NormalizedType // The string part of the type. // This may be the `string` type, or a union of singletons. - NormalizedStringType strings = std::map{}; + NormalizedStringType strings = std::map{}; // The thread part of the type. // This type is either never or thread. diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index 3184b0d30..1c4d1cb41 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -397,7 +397,7 @@ struct TypeChecker std::vector> deferredQuantification; }; -using PrintLineProc = void(*)(const std::string&); +using PrintLineProc = void (*)(const std::string&); extern PrintLineProc luauPrintLine; diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index f6219dfbe..10f3f48cb 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -61,15 +61,15 @@ struct Unifier ErrorVec errors; Location location; Variance variance = Covariant; - bool anyIsTop = false; // If true, we consider any to be a top type. If false, it is a familiar but weird mix of top and bottom all at once. - bool normalize; // Normalize unions and intersections if necessary + bool anyIsTop = false; // If true, we consider any to be a top type. If false, it is a familiar but weird mix of top and bottom all at once. + bool normalize; // Normalize unions and intersections if necessary bool useScopes = false; // If true, we use the scope hierarchy rather than TypeLevels CountMismatch::Context ctx = CountMismatch::Arg; UnifierSharedState& sharedState; - Unifier(NotNull normalizer, Mode mode, NotNull scope, const Location& location, Variance variance, - TxnLog* parentLog = nullptr); + Unifier( + NotNull normalizer, Mode mode, NotNull scope, const Location& location, Variance variance, TxnLog* parentLog = nullptr); // Test whether the two type vars unify. Never commits the result. ErrorVec canUnify(TypeId subTy, TypeId superTy); @@ -87,7 +87,8 @@ struct Unifier void tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTypeVar* uv, bool cacheEnabled, bool isFunctionCall); void tryUnifyTypeWithIntersection(TypeId subTy, TypeId superTy, const IntersectionTypeVar* uv); void tryUnifyIntersectionWithType(TypeId subTy, const IntersectionTypeVar* uv, TypeId superTy, bool cacheEnabled, bool isFunctionCall); - void tryUnifyNormalizedTypes(TypeId subTy, TypeId superTy, const NormalizedType& subNorm, const NormalizedType& superNorm, std::string reason, std::optional error = std::nullopt); + void tryUnifyNormalizedTypes(TypeId subTy, TypeId superTy, const NormalizedType& subNorm, const NormalizedType& superNorm, std::string reason, + std::optional error = std::nullopt); void tryUnifyPrimitives(TypeId subTy, TypeId superTy); void tryUnifySingletons(TypeId subTy, TypeId superTy); void tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCall = false); diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index dbe27bfd4..c5250a6db 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -497,7 +497,7 @@ static bool dcrMagicFunctionSelect(MagicFunctionCallContext context) asMutable(context.result)->ty.emplace(resTypePack); } else if (tail) - asMutable(context.result)->ty.emplace(*tail); + asMutable(context.result)->ty.emplace(*tail); return true; } @@ -507,7 +507,8 @@ static bool dcrMagicFunctionSelect(MagicFunctionCallContext context) if (AstExprConstantString* str = arg1->as()) { - if (str->value.size == 1 && str->value.data[0] == '#') { + if (str->value.size == 1 && str->value.data[0] == '#') + { TypePackId numberTypePack = context.solver->arena->addTypePack({context.solver->singletonTypes->numberType}); asMutable(context.result)->ty.emplace(numberTypePack); return true; diff --git a/Analysis/src/ConstraintGraphBuilder.cpp b/Analysis/src/ConstraintGraphBuilder.cpp index 169f46452..8436fb309 100644 --- a/Analysis/src/ConstraintGraphBuilder.cpp +++ b/Analysis/src/ConstraintGraphBuilder.cpp @@ -43,7 +43,7 @@ static bool matchSetmetatable(const AstExprCall& call) if (call.args.size != 2) return false; - + const AstExprGlobal* funcAsGlobal = call.func->as(); if (!funcAsGlobal || funcAsGlobal->name != smt) return false; @@ -52,7 +52,8 @@ static bool matchSetmetatable(const AstExprCall& call) } ConstraintGraphBuilder::ConstraintGraphBuilder(const ModuleName& moduleName, ModulePtr module, TypeArena* arena, - NotNull moduleResolver, NotNull singletonTypes, NotNull ice, const ScopePtr& globalScope, DcrLogger* logger) + NotNull moduleResolver, NotNull singletonTypes, NotNull ice, const ScopePtr& globalScope, + DcrLogger* logger) : moduleName(moduleName) , module(module) , singletonTypes(singletonTypes) diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index 5b3ec03cc..e29eeaaaa 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -14,8 +14,6 @@ #include "Luau/VisitTypeVar.h" #include "Luau/TypeUtils.h" -#include - LUAU_FASTFLAGVARIABLE(DebugLuauLogSolver, false); LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJson, false); LUAU_FASTFLAG(LuauFixNameMaps) @@ -283,13 +281,27 @@ ConstraintSolver::ConstraintSolver(NotNull normalizer, NotNull 0; --i) + { + // Fisher-Yates shuffle + size_t j = rng % (i + 1); + + std::swap(unsolvedConstraints[i], unsolvedConstraints[j]); + + // LCG RNG, constants from Numerical Recipes + // This may occasionally result in skewed shuffles due to distribution properties, but this is a debugging tool so it should be good enough + rng = rng * 1664525 + 1013904223; + } } void ConstraintSolver::run() { - if (done()) + if (isDone()) return; if (FFlag::DebugLuauLogSolver) @@ -364,6 +376,8 @@ void ConstraintSolver::run() progress |= runSolverPass(true); } while (progress); + finalizeModule(); + if (FFlag::DebugLuauLogSolver) { dumpBindings(rootScope, opts); @@ -375,11 +389,24 @@ void ConstraintSolver::run() } } -bool ConstraintSolver::done() +bool ConstraintSolver::isDone() { return unsolvedConstraints.empty(); } +void ConstraintSolver::finalizeModule() +{ + Anyification a{arena, rootScope, singletonTypes, &iceReporter, singletonTypes->anyType, singletonTypes->anyTypePack}; + std::optional returnType = a.substitute(rootScope->returnType); + if (!returnType) + { + reportError(CodeTooComplex{}, Location{}); + rootScope->returnType = singletonTypes->errorTypePack; + } + else + rootScope->returnType = *returnType; +} + bool ConstraintSolver::tryDispatch(NotNull constraint, bool force) { if (!force && isBlocked(constraint)) @@ -506,25 +533,25 @@ bool ConstraintSolver::tryDispatch(const UnaryConstraint& c, NotNullty.emplace(singletonTypes->booleanType); - return true; - } - case AstExprUnary::Len: + case AstExprUnary::Not: + { + asMutable(c.resultType)->ty.emplace(singletonTypes->booleanType); + return true; + } + case AstExprUnary::Len: + { + asMutable(c.resultType)->ty.emplace(singletonTypes->numberType); + return true; + } + case AstExprUnary::Minus: + { + if (isNumber(operandType) || get(operandType) || get(operandType)) { - asMutable(c.resultType)->ty.emplace(singletonTypes->numberType); + asMutable(c.resultType)->ty.emplace(c.operandType); return true; } - case AstExprUnary::Minus: - { - if (isNumber(operandType) || get(operandType) || get(operandType)) - { - asMutable(c.resultType)->ty.emplace(c.operandType); - return true; - } - break; - } + break; + } } LUAU_ASSERT(false); // TODO metatable handling diff --git a/Analysis/src/DcrLogger.cpp b/Analysis/src/DcrLogger.cpp index a2eb96e5c..ef33aa606 100644 --- a/Analysis/src/DcrLogger.cpp +++ b/Analysis/src/DcrLogger.cpp @@ -57,7 +57,7 @@ void write(JsonEmitter& emitter, const ConstraintGenerationLog& log) emitter.writeRaw(":"); ObjectEmitter locationEmitter = emitter.writeObject(); - + for (const auto& [id, location] : log.constraintLocations) { locationEmitter.writePair(id, location); @@ -232,7 +232,7 @@ void DcrLogger::captureSource(std::string source) void DcrLogger::captureGenerationError(const TypeError& error) { std::string stringifiedError = toString(error); - generationLog.errors.push_back(ErrorSnapshot { + generationLog.errors.push_back(ErrorSnapshot{ /* message */ stringifiedError, /* location */ error.location, }); @@ -298,7 +298,8 @@ void DcrLogger::captureInitialSolverState(const Scope* rootScope, const std::vec } } -StepSnapshot DcrLogger::prepareStepSnapshot(const Scope* rootScope, NotNull current, bool force, const std::vector>& unsolvedConstraints) +StepSnapshot DcrLogger::prepareStepSnapshot( + const Scope* rootScope, NotNull current, bool force, const std::vector>& unsolvedConstraints) { ScopeSnapshot scopeSnapshot = snapshotScope(rootScope, opts); std::string currentId = toPointerId(current); @@ -344,7 +345,7 @@ void DcrLogger::captureFinalSolverState(const Scope* rootScope, const std::vecto void DcrLogger::captureTypeCheckError(const TypeError& error) { std::string stringifiedError = toString(error); - checkLog.errors.push_back(ErrorSnapshot { + checkLog.errors.push_back(ErrorSnapshot{ /* message */ stringifiedError, /* location */ error.location, }); @@ -359,7 +360,7 @@ std::vector DcrLogger::snapshotBlocks(NotNull } std::vector snapshot; - + for (const ConstraintBlockTarget& target : it->second) { if (const TypeId* ty = get_if(&target)) diff --git a/Analysis/src/Error.cpp b/Analysis/src/Error.cpp index d13e26c0b..4e9b68820 100644 --- a/Analysis/src/Error.cpp +++ b/Analysis/src/Error.cpp @@ -8,7 +8,6 @@ #include LUAU_FASTFLAGVARIABLE(LuauTypeMismatchModuleNameResolution, false) -LUAU_FASTFLAGVARIABLE(LuauUseInternalCompilerErrorException, false) static std::string wrongNumberOfArgsString( size_t expectedCount, std::optional maximumCount, size_t actualCount, const char* argPrefix = nullptr, bool isVariadic = false) @@ -122,8 +121,6 @@ struct ErrorConverter return "Unknown global '" + e.name + "'"; case UnknownSymbol::Type: return "Unknown type '" + e.name + "'"; - case UnknownSymbol::Generic: - return "Unknown generic '" + e.name + "'"; } LUAU_ASSERT(!"Unexpected context for UnknownSymbol"); @@ -902,46 +899,22 @@ void copyErrors(ErrorVec& errors, TypeArena& destArena) void InternalErrorReporter::ice(const std::string& message, const Location& location) { - if (FFlag::LuauUseInternalCompilerErrorException) - { - InternalCompilerError error(message, moduleName, location); + InternalCompilerError error(message, moduleName, location); - if (onInternalError) - onInternalError(error.what()); + if (onInternalError) + onInternalError(error.what()); - throw error; - } - else - { - std::runtime_error error("Internal error in " + moduleName + " at " + toString(location) + ": " + message); - - if (onInternalError) - onInternalError(error.what()); - - throw error; - } + throw error; } void InternalErrorReporter::ice(const std::string& message) { - if (FFlag::LuauUseInternalCompilerErrorException) - { - InternalCompilerError error(message, moduleName); + InternalCompilerError error(message, moduleName); - if (onInternalError) - onInternalError(error.what()); + if (onInternalError) + onInternalError(error.what()); - throw error; - } - else - { - std::runtime_error error("Internal error in " + moduleName + ": " + message); - - if (onInternalError) - onInternalError(error.what()); - - throw error; - } + throw error; } const char* InternalCompilerError::what() const throw() diff --git a/Analysis/src/Linter.cpp b/Analysis/src/Linter.cpp index 7f67a7db1..d5578446a 100644 --- a/Analysis/src/Linter.cpp +++ b/Analysis/src/Linter.cpp @@ -14,7 +14,6 @@ LUAU_FASTINTVARIABLE(LuauSuggestionDistance, 4) LUAU_FASTFLAGVARIABLE(LuauLintGlobalNeverReadBeforeWritten, false) -LUAU_FASTFLAGVARIABLE(LuauLintFixDeprecationMessage, false) namespace Luau { @@ -307,22 +306,11 @@ class LintGlobalLocal : AstVisitor emitWarning(*context, LintWarning::Code_UnknownGlobal, gv->location, "Unknown global '%s'", gv->name.value); else if (g->deprecated) { - if (FFlag::LuauLintFixDeprecationMessage) - { - if (const char* replacement = *g->deprecated; replacement && strlen(replacement)) - emitWarning(*context, LintWarning::Code_DeprecatedGlobal, gv->location, "Global '%s' is deprecated, use '%s' instead", - gv->name.value, replacement); - else - emitWarning(*context, LintWarning::Code_DeprecatedGlobal, gv->location, "Global '%s' is deprecated", gv->name.value); - } + if (const char* replacement = *g->deprecated; replacement && strlen(replacement)) + emitWarning(*context, LintWarning::Code_DeprecatedGlobal, gv->location, "Global '%s' is deprecated, use '%s' instead", + gv->name.value, replacement); else - { - if (*g->deprecated) - emitWarning(*context, LintWarning::Code_DeprecatedGlobal, gv->location, "Global '%s' is deprecated, use '%s' instead", - gv->name.value, *g->deprecated); - else - emitWarning(*context, LintWarning::Code_DeprecatedGlobal, gv->location, "Global '%s' is deprecated", gv->name.value); - } + emitWarning(*context, LintWarning::Code_DeprecatedGlobal, gv->location, "Global '%s' is deprecated", gv->name.value); } } diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index 45eb87d65..31a089a43 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -15,7 +15,6 @@ #include LUAU_FASTFLAG(LuauAnyifyModuleReturnGenerics) -LUAU_FASTFLAG(LuauLowerBoundsCalculation); LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); LUAU_FASTFLAGVARIABLE(LuauForceExportSurfacesToBeNormal, false); LUAU_FASTFLAGVARIABLE(LuauClonePublicInterfaceLess, false); @@ -244,19 +243,6 @@ void Module::clonePublicInterface(NotNull singletonTypes, Intern ForceNormal forceNormal{&interfaceTypes}; - if (FFlag::LuauLowerBoundsCalculation) - { - normalize(returnType, NotNull{this}, singletonTypes, ice); - if (FFlag::LuauForceExportSurfacesToBeNormal) - forceNormal.traverse(returnType); - if (varargPack) - { - normalize(*varargPack, NotNull{this}, singletonTypes, ice); - if (FFlag::LuauForceExportSurfacesToBeNormal) - forceNormal.traverse(*varargPack); - } - } - if (exportedTypeBindings) { for (auto& [name, tf] : *exportedTypeBindings) @@ -265,24 +251,6 @@ void Module::clonePublicInterface(NotNull singletonTypes, Intern tf = clonePublicInterface.cloneTypeFun(tf); else tf = clone(tf, interfaceTypes, cloneState); - if (FFlag::LuauLowerBoundsCalculation) - { - normalize(tf.type, NotNull{this}, singletonTypes, ice); - - // We're about to freeze the memory. We know that the flag is conservative by design. Cyclic tables - // won't be marked normal. If the types aren't normal by now, they never will be. - forceNormal.traverse(tf.type); - for (GenericTypeDefinition param : tf.typeParams) - { - forceNormal.traverse(param.ty); - - if (param.defaultValue) - { - normalize(*param.defaultValue, NotNull{this}, singletonTypes, ice); - forceNormal.traverse(*param.defaultValue); - } - } - } } } @@ -305,13 +273,6 @@ void Module::clonePublicInterface(NotNull singletonTypes, Intern ty = clonePublicInterface.cloneType(ty); else ty = clone(ty, interfaceTypes, cloneState); - if (FFlag::LuauLowerBoundsCalculation) - { - normalize(ty, NotNull{this}, singletonTypes, ice); - - if (FFlag::LuauForceExportSurfacesToBeNormal) - forceNormal.traverse(ty); - } } freeze(internalTypes); diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp index c008bcfc0..81114b76b 100644 --- a/Analysis/src/Normalize.cpp +++ b/Analysis/src/Normalize.cpp @@ -119,17 +119,9 @@ NormalizedType::NormalizedType(NotNull singletonTypes) static bool isInhabited(const NormalizedType& norm) { - return !get(norm.tops) - || !get(norm.booleans) - || !norm.classes.empty() - || !get(norm.errors) - || !get(norm.nils) - || !get(norm.numbers) - || !norm.strings || !norm.strings->empty() - || !get(norm.threads) - || norm.functions - || !norm.tables.empty() - || !norm.tyvars.empty(); + return !get(norm.tops) || !get(norm.booleans) || !norm.classes.empty() || !get(norm.errors) || + !get(norm.nils) || !get(norm.numbers) || !norm.strings || !norm.strings->empty() || + !get(norm.threads) || norm.functions || !norm.tables.empty() || !norm.tyvars.empty(); } static int tyvarIndex(TypeId ty) @@ -139,7 +131,7 @@ static int tyvarIndex(TypeId ty) else if (const FreeTypeVar* ftv = get(ty)) return ftv->index; else - return 0; + return 0; } #ifdef LUAU_ASSERTENABLED @@ -193,7 +185,7 @@ static bool isNormalizedString(const NormalizedStringType& ty) { if (!ty) return true; - + for (auto& [str, ty] : *ty) { if (const SingletonTypeVar* stv = get(ty)) @@ -272,24 +264,24 @@ static bool isNormalizedTyvar(const NormalizedTyvars& tyvars) static void assertInvariant(const NormalizedType& norm) { - #ifdef LUAU_ASSERTENABLED - if (!FFlag::DebugLuauCheckNormalizeInvariant) - return; +#ifdef LUAU_ASSERTENABLED + if (!FFlag::DebugLuauCheckNormalizeInvariant) + return; - LUAU_ASSERT(isNormalizedTop(norm.tops)); - LUAU_ASSERT(isNormalizedBoolean(norm.booleans)); - LUAU_ASSERT(areNormalizedClasses(norm.classes)); - LUAU_ASSERT(isNormalizedError(norm.errors)); - LUAU_ASSERT(isNormalizedNil(norm.nils)); - LUAU_ASSERT(isNormalizedNumber(norm.numbers)); - LUAU_ASSERT(isNormalizedString(norm.strings)); - LUAU_ASSERT(isNormalizedThread(norm.threads)); - LUAU_ASSERT(areNormalizedFunctions(norm.functions)); - LUAU_ASSERT(areNormalizedTables(norm.tables)); - LUAU_ASSERT(isNormalizedTyvar(norm.tyvars)); - for (auto& [_, child] : norm.tyvars) - assertInvariant(*child); - #endif + LUAU_ASSERT(isNormalizedTop(norm.tops)); + LUAU_ASSERT(isNormalizedBoolean(norm.booleans)); + LUAU_ASSERT(areNormalizedClasses(norm.classes)); + LUAU_ASSERT(isNormalizedError(norm.errors)); + LUAU_ASSERT(isNormalizedNil(norm.nils)); + LUAU_ASSERT(isNormalizedNumber(norm.numbers)); + LUAU_ASSERT(isNormalizedString(norm.strings)); + LUAU_ASSERT(isNormalizedThread(norm.threads)); + LUAU_ASSERT(areNormalizedFunctions(norm.functions)); + LUAU_ASSERT(areNormalizedTables(norm.tables)); + LUAU_ASSERT(isNormalizedTyvar(norm.tyvars)); + for (auto& [_, child] : norm.tyvars) + assertInvariant(*child); +#endif } Normalizer::Normalizer(TypeArena* arena, NotNull singletonTypes, NotNull sharedState) @@ -359,7 +351,7 @@ TypeId Normalizer::unionType(TypeId here, TypeId there) return there; if (get(there) || get(here)) return here; - + TypeIds tmps; if (const UnionTypeVar* utv = get(here)) @@ -405,7 +397,7 @@ TypeId Normalizer::intersectionType(TypeId here, TypeId there) return here; if (get(there) || get(here)) return there; - + TypeIds tmps; if (const IntersectionTypeVar* utv = get(here)) @@ -516,13 +508,13 @@ std::optional Normalizer::unionOfTypePacks(TypePackId here, TypePack std::vector head; std::optional tail; - + bool hereSubThere = true; bool thereSubHere = true; TypePackIterator ith = begin(here); TypePackIterator itt = begin(there); - + while (ith != end(here) && itt != end(there)) { TypeId hty = *ith; @@ -537,8 +529,8 @@ std::optional Normalizer::unionOfTypePacks(TypePackId here, TypePack itt++; } - auto dealWithDifferentArities = [&](TypePackIterator& ith, TypePackIterator itt, TypePackId here, TypePackId there, bool& hereSubThere, bool& thereSubHere) - { + auto dealWithDifferentArities = [&](TypePackIterator& ith, TypePackIterator itt, TypePackId here, TypePackId there, bool& hereSubThere, + bool& thereSubHere) { if (ith != end(here)) { TypeId tty = singletonTypes->nilType; @@ -591,13 +583,13 @@ std::optional Normalizer::unionOfTypePacks(TypePackId here, TypePack if (ty != tvtp->ty) hereSubThere = false; bool hidden = hvtp->hidden & tvtp->hidden; - tail = arena->addTypePack(VariadicTypePack{ty,hidden}); + tail = arena->addTypePack(VariadicTypePack{ty, hidden}); } - else + else // Luau doesn't have unions of type pack variables return std::nullopt; } - else + else // Luau doesn't have unions of type pack variables return std::nullopt; } @@ -627,7 +619,7 @@ std::optional Normalizer::unionOfTypePacks(TypePackId here, TypePack else if (thereSubHere) return here; if (!head.empty()) - return arena->addTypePack(TypePack{head,tail}); + return arena->addTypePack(TypePack{head, tail}); else if (tail) return *tail; else @@ -639,10 +631,10 @@ std::optional Normalizer::unionOfFunctions(TypeId here, TypeId there) { if (get(here)) return here; - + if (get(there)) return there; - + const FunctionTypeVar* hftv = get(here); LUAU_ASSERT(hftv); const FunctionTypeVar* tftv = get(there); @@ -665,7 +657,7 @@ std::optional Normalizer::unionOfFunctions(TypeId here, TypeId there) return here; if (*argTypes == tftv->argTypes && *retTypes == tftv->retTypes) return there; - + FunctionTypeVar result{*argTypes, *retTypes}; result.generics = hftv->generics; result.genericPacks = hftv->genericPacks; @@ -802,9 +794,9 @@ bool Normalizer::withinResourceLimits() // Check the recursion count if (sharedState->counters.recursionLimit > 0) - if (sharedState->counters.recursionLimit < sharedState->counters.recursionCount) - return false; - + if (sharedState->counters.recursionLimit < sharedState->counters.recursionCount) + return false; + return true; } @@ -1000,13 +992,13 @@ std::optional Normalizer::intersectionOfTypePacks(TypePackId here, T std::vector head; std::optional tail; - + bool hereSubThere = true; bool thereSubHere = true; TypePackIterator ith = begin(here); TypePackIterator itt = begin(there); - + while (ith != end(here) && itt != end(there)) { TypeId hty = *ith; @@ -1021,8 +1013,8 @@ std::optional Normalizer::intersectionOfTypePacks(TypePackId here, T itt++; } - auto dealWithDifferentArities = [&](TypePackIterator& ith, TypePackIterator itt, TypePackId here, TypePackId there, bool& hereSubThere, bool& thereSubHere) - { + auto dealWithDifferentArities = [&](TypePackIterator& ith, TypePackIterator itt, TypePackId here, TypePackId there, bool& hereSubThere, + bool& thereSubHere) { if (ith != end(here)) { TypeId tty = singletonTypes->nilType; @@ -1075,13 +1067,13 @@ std::optional Normalizer::intersectionOfTypePacks(TypePackId here, T if (ty != tvtp->ty) hereSubThere = false; bool hidden = hvtp->hidden & tvtp->hidden; - tail = arena->addTypePack(VariadicTypePack{ty,hidden}); + tail = arena->addTypePack(VariadicTypePack{ty, hidden}); } - else + else // Luau doesn't have unions of type pack variables return std::nullopt; } - else + else // Luau doesn't have unions of type pack variables return std::nullopt; } @@ -1105,7 +1097,7 @@ std::optional Normalizer::intersectionOfTypePacks(TypePackId here, T else if (thereSubHere) return there; if (!head.empty()) - return arena->addTypePack(TypePack{head,tail}); + return arena->addTypePack(TypePack{head, tail}); else if (tail) return *tail; else @@ -1146,7 +1138,7 @@ std::optional Normalizer::intersectionOfTables(TypeId here, TypeId there return std::nullopt; if (httv->state == TableState::Generic || tttv->state == TableState::Generic) return std::nullopt; - + TableState state = httv->state; if (tttv->state == TableState::Unsealed) state = tttv->state; @@ -1226,21 +1218,20 @@ std::optional Normalizer::intersectionOfTables(TypeId here, TypeId there } else return std::nullopt; - } else if (hmtable) { if (table == htable) return here; else - return arena->addType(MetatableTypeVar{table, hmtable}); + return arena->addType(MetatableTypeVar{table, hmtable}); } else if (tmtable) { if (table == ttable) return there; else - return arena->addType(MetatableTypeVar{table, tmtable}); + return arena->addType(MetatableTypeVar{table, tmtable}); } else return table; @@ -1280,7 +1271,7 @@ std::optional Normalizer::intersectionOfFunctions(TypeId here, TypeId th return std::nullopt; if (hftv->retTypes != tftv->retTypes) return std::nullopt; - + std::optional argTypes = unionOfTypePacks(hftv->argTypes, tftv->argTypes); if (!argTypes) return std::nullopt; @@ -1289,7 +1280,7 @@ std::optional Normalizer::intersectionOfFunctions(TypeId here, TypeId th return here; if (*argTypes == tftv->argTypes) return there; - + FunctionTypeVar result{*argTypes, hftv->retTypes}; result.generics = hftv->generics; result.genericPacks = hftv->genericPacks; @@ -1299,7 +1290,7 @@ std::optional Normalizer::intersectionOfFunctions(TypeId here, TypeId th std::optional Normalizer::unionSaturatedFunctions(TypeId here, TypeId there) { // Deep breath... - // + // // When we come to check overloaded functions for subtyping, // we have to compare (F1 & ... & FM) <: (G1 & ... G GN) // where each Fi or Gj is a function type. Now that intersection on the right is no @@ -1319,12 +1310,12 @@ std::optional Normalizer::unionSaturatedFunctions(TypeId here, TypeId th // // So subtyping on overloaded functions "just" boils down to defining Apply. // - // Now for non-overloaded functions, this is easy! + // Now for non-overloaded functions, this is easy! // Apply<(R -> S), T> is S if T <: R, and an error type otherwise. // // But for overloaded functions it's not so simple. We'd like Apply // to just be Apply & ... & Apply but oh dear - // + // // if f : ((number -> number) & (string -> string)) // and x : (number | string) // then f(x) : (number | string) @@ -1334,7 +1325,7 @@ std::optional Normalizer::unionSaturatedFunctions(TypeId here, TypeId th // Apply<((number -> number) & (string -> string)), (number | string)> is (number | string) // // but - // + // // Apply<(number -> number), (number | string)> is an error // Apply<(string -> string), (number | string)> is an error // @@ -1382,7 +1373,7 @@ std::optional Normalizer::unionSaturatedFunctions(TypeId here, TypeId th // Covariance and Contravariance, Giuseppe Castagna, // Logical Methods in Computer Science 16(1), 2022 // https://arxiv.org/abs/1809.01427 - // + // // A gentle introduction to semantic subtyping, Giuseppe Castagna and Alain Frisch, // Proc. Principles and practice of declarative programming 2005, pp 198–208 // https://doi.org/10.1145/1069774.1069793 @@ -1398,7 +1389,7 @@ std::optional Normalizer::unionSaturatedFunctions(TypeId here, TypeId th return std::nullopt; if (hftv->genericPacks != tftv->genericPacks) return std::nullopt; - + std::optional argTypes = unionOfTypePacks(hftv->argTypes, tftv->argTypes); if (!argTypes) return std::nullopt; @@ -1416,7 +1407,7 @@ void Normalizer::intersectFunctionsWithFunction(NormalizedFunctionType& heres, T { if (!heres) return; - + for (auto it = heres->begin(); it != heres->end();) { TypeId here = *it; @@ -1450,7 +1441,7 @@ void Normalizer::intersectFunctions(NormalizedFunctionType& heres, const Normali { heres = std::nullopt; return; - } + } else { for (TypeId there : *theres) @@ -1530,7 +1521,7 @@ bool Normalizer::intersectNormals(NormalizedType& here, const NormalizedType& th if (isInhabited(inter)) it++; else - it = here.tyvars.erase(it); + it = here.tyvars.erase(it); } return true; } @@ -1757,7 +1748,8 @@ bool isSubtype(TypeId subTy, TypeId superTy, NotNull scope, NotNull scope, NotNull singletonTypes, InternalErrorReporter& ice, bool anyIsTop) +bool isSubtype( + TypePackId subPack, TypePackId superPack, NotNull scope, NotNull singletonTypes, InternalErrorReporter& ice, bool anyIsTop) { UnifierSharedState sharedState{&ice}; TypeArena arena; @@ -2377,4 +2369,3 @@ std::pair normalize(TypePackId tp, const ModulePtr& module, No } } // namespace Luau - diff --git a/Analysis/src/Substitution.cpp b/Analysis/src/Substitution.cpp index fa12f306b..2137d73ee 100644 --- a/Analysis/src/Substitution.cpp +++ b/Analysis/src/Substitution.cpp @@ -9,7 +9,6 @@ #include LUAU_FASTFLAGVARIABLE(LuauSubstitutionFixMissingFields, false) -LUAU_FASTFLAG(LuauLowerBoundsCalculation) LUAU_FASTFLAG(LuauClonePublicInterfaceLess) LUAU_FASTINTVARIABLE(LuauTarjanChildLimit, 10000) LUAU_FASTFLAGVARIABLE(LuauClassTypeVarsInSubstitution, false) @@ -553,9 +552,6 @@ TypePackId Substitution::replace(TypePackId tp) void Substitution::replaceChildren(TypeId ty) { - if (BoundTypeVar* btv = log->getMutable(ty); FFlag::LuauLowerBoundsCalculation && btv) - btv->boundTo = replace(btv->boundTo); - LUAU_ASSERT(ty == log->follow(ty)); if (ignoreChildren(ty)) diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index 135602511..f5ab9494c 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -10,11 +10,11 @@ #include #include -LUAU_FASTFLAG(LuauLowerBoundsCalculation) LUAU_FASTFLAG(LuauUnknownAndNeverType) LUAU_FASTFLAGVARIABLE(LuauSpecialTypesAsterisked, false) LUAU_FASTFLAGVARIABLE(LuauFixNameMaps, false) LUAU_FASTFLAGVARIABLE(LuauUnseeArrayTtv, false) +LUAU_FASTFLAGVARIABLE(LuauFunctionReturnStringificationFixup, false) /* * Prefix generic typenames with gen- @@ -523,7 +523,7 @@ struct TypeVarStringifier bool plural = true; - if (FFlag::LuauLowerBoundsCalculation) + if (FFlag::LuauFunctionReturnStringificationFixup) { auto retBegin = begin(ftv.retTypes); auto retEnd = end(ftv.retTypes); diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index b2f3cfd3f..4753a7c21 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -24,7 +24,7 @@ namespace Luau // TypeInfer.h // TODO move these -using PrintLineProc = void(*)(const std::string&); +using PrintLineProc = void (*)(const std::string&); extern PrintLineProc luauPrintLine; /* Push a scope onto the end of a stack for the lifetime of the StackPusher instance. @@ -127,7 +127,8 @@ struct TypeChecker2 if (auto ann = ref->parameters.data[0].type) { TypeId argTy = lookupAnnotation(ref->parameters.data[0].type); - luauPrintLine(format("_luau_print (%d, %d): %s\n", annotation->location.begin.line, annotation->location.begin.column, toString(argTy).c_str())); + luauPrintLine(format( + "_luau_print (%d, %d): %s\n", annotation->location.begin.line, annotation->location.begin.column, toString(argTy).c_str())); return follow(argTy); } } @@ -409,8 +410,8 @@ struct TypeChecker2 } TypeId iteratorTy = follow(iteratorTypes[0]); - auto checkFunction = [this, &arena, &scope, &forInStatement, &variableTypes](const FunctionTypeVar* iterFtv, std::vector iterTys, bool isMm) - { + auto checkFunction = [this, &arena, &scope, &forInStatement, &variableTypes]( + const FunctionTypeVar* iterFtv, std::vector iterTys, bool isMm) { if (iterTys.size() < 1 || iterTys.size() > 3) { if (isMm) @@ -420,20 +421,21 @@ struct TypeChecker2 return; } - + // It is okay if there aren't enough iterators, but the iteratee must provide enough. std::vector expectedVariableTypes = flatten(arena, singletonTypes, iterFtv->retTypes, variableTypes.size()); if (expectedVariableTypes.size() < variableTypes.size()) { if (isMm) - reportError(GenericError{"__iter metamethod's next() function does not return enough values"}, getLocation(forInStatement->values)); + reportError( + GenericError{"__iter metamethod's next() function does not return enough values"}, getLocation(forInStatement->values)); else reportError(GenericError{"next() does not return enough values"}, forInStatement->values.data[0]->location); } for (size_t i = 0; i < std::min(expectedVariableTypes.size(), variableTypes.size()); ++i) reportErrors(tryUnify(scope, forInStatement->vars.data[i]->location, variableTypes[i], expectedVariableTypes[i])); - + // nextFn is going to be invoked with (arrayTy, startIndexTy) // It will be passed two arguments on every iteration save the @@ -509,7 +511,8 @@ struct TypeChecker2 { // nothing } - else if (std::optional iterMmTy = findMetatableEntry(singletonTypes, module->errors, iteratorTy, "__iter", forInStatement->values.data[0]->location)) + else if (std::optional iterMmTy = + findMetatableEntry(singletonTypes, module->errors, iteratorTy, "__iter", forInStatement->values.data[0]->location)) { Instantiation instantiation{TxnLog::empty(), &arena, TypeLevel{}, scope}; @@ -554,7 +557,7 @@ struct TypeChecker2 // TODO: This will not tell the user that this is because the // metamethod isn't callable. This is not ideal, and we should // improve this error message. - + // TODO: This will also not handle intersections of functions or // callable tables (which are supported by the runtime). reportError(CannotCallNonFunction{*iterMmTy}, forInStatement->values.data[0]->location); diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index cb21aa7fe..b806edb7c 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -33,15 +33,11 @@ LUAU_FASTINTVARIABLE(LuauVisitRecursionLimit, 500) LUAU_FASTFLAG(LuauKnowsTheDataModel3) LUAU_FASTFLAG(LuauAutocompleteDynamicLimits) LUAU_FASTFLAG(LuauTypeNormalization2) -LUAU_FASTFLAGVARIABLE(LuauFunctionArgMismatchDetails, false) -LUAU_FASTFLAGVARIABLE(LuauLowerBoundsCalculation, false) LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) LUAU_FASTFLAGVARIABLE(LuauReturnAnyInsteadOfICE, false) // Eventually removed as false. LUAU_FASTFLAGVARIABLE(DebugLuauSharedSelf, false) LUAU_FASTFLAGVARIABLE(LuauAnyifyModuleReturnGenerics, false) LUAU_FASTFLAGVARIABLE(LuauUnknownAndNeverType, false) -LUAU_FASTFLAGVARIABLE(LuauCallUnifyPackTails, false) -LUAU_FASTFLAGVARIABLE(LuauCheckGenericHOFTypes, false) LUAU_FASTFLAGVARIABLE(LuauBinaryNeedsExpectedTypesToo, false) LUAU_FASTFLAGVARIABLE(LuauFixVarargExprHeadType, false) LUAU_FASTFLAGVARIABLE(LuauNeverTypesAndOperatorsInference, false) @@ -136,34 +132,6 @@ bool hasBreak(AstStat* node) } } -static bool hasReturn(const AstStat* node) -{ - struct Searcher : AstVisitor - { - bool result = false; - - bool visit(AstStat*) override - { - return !result; // if we've already found a return statement, don't bother to traverse inward anymore - } - - bool visit(AstStatReturn*) override - { - result = true; - return false; - } - - bool visit(AstExprFunction*) override - { - return false; // We don't care if the function uses a lambda that itself returns - } - }; - - Searcher searcher; - const_cast(node)->visit(&searcher); - return searcher.result; -} - // returns the last statement before the block exits, or nullptr if the block never exits const AstStat* getFallthrough(const AstStat* node) { @@ -550,16 +518,6 @@ void TypeChecker::checkBlockWithoutRecursionCheck(const ScopePtr& scope, const A std::unordered_map> functionDecls; - auto isLocalLambda = [](AstStat* stat) -> AstStatLocal* { - AstStatLocal* local = stat->as(); - - if (FFlag::LuauLowerBoundsCalculation && local && local->vars.size == 1 && local->values.size == 1 && - local->values.data[0]->is()) - return local; - else - return nullptr; - }; - auto checkBody = [&](AstStat* stat) { if (auto fun = stat->as()) { @@ -607,7 +565,7 @@ void TypeChecker::checkBlockWithoutRecursionCheck(const ScopePtr& scope, const A // function f(x:a):a local x: number = g(37) return x end // function g(x:number):number return f(x) end // ``` - if (containsFunctionCallOrReturn(**protoIter) || (FFlag::LuauLowerBoundsCalculation && isLocalLambda(*protoIter))) + if (containsFunctionCallOrReturn(**protoIter)) { while (checkIter != protoIter) { @@ -906,12 +864,6 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatReturn& return_) TypePackId retPack = checkExprList(scope, return_.location, return_.list, false, {}, expectedTypes).type; - if (useConstrainedIntersections()) - { - unifyLowerBound(retPack, scope->returnType, demoter.demotedLevel(scope->level), scope, return_.location); - return; - } - // HACK: Nonstrict mode gets a bit too smart and strict for us when we // start typechecking everything across module boundaries. if (isNonstrictMode() && follow(scope->returnType) == follow(currentModule->getModuleScope()->returnType)) @@ -1574,11 +1526,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias for (auto param : binding->typePackParams) clone.instantiatedTypePackParams.push_back(param.tp); - bool isNormal = ty->normal; ty = addType(std::move(clone)); - - if (FFlag::LuauLowerBoundsCalculation) - asMutable(ty)->normal = isNormal; } } else @@ -1605,14 +1553,6 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias if (unify(ty, bindingType, aliasScope, typealias.location)) bindingType = ty; - - if (FFlag::LuauLowerBoundsCalculation) - { - auto [t, ok] = normalize(bindingType, currentModule, singletonTypes, *iceHandler); - bindingType = t; - if (!ok) - reportError(typealias.location, NormalizationTooComplex{}); - } } void TypeChecker::prototype(const ScopePtr& scope, const AstStatTypeAlias& typealias, int subLevel) @@ -1959,9 +1899,8 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp } else if (const FreeTypePack* ftp = get(retPack)) { - TypeLevel level = FFlag::LuauLowerBoundsCalculation ? ftp->level : scope->level; - TypeId head = freshType(level); - TypePackId pack = addTypePack(TypePackVar{TypePack{{head}, freshTypePack(level)}}); + TypeId head = freshType(scope->level); + TypePackId pack = addTypePack(TypePackVar{TypePack{{head}, freshTypePack(scope->level)}}); unify(pack, retPack, scope, expr.location); return {head, std::move(result.predicates)}; } @@ -2111,27 +2050,14 @@ std::optional TypeChecker::getIndexTypeFromTypeImpl( return std::nullopt; } - if (FFlag::LuauLowerBoundsCalculation) - { - // FIXME Inefficient. We craft a UnionTypeVar and immediately throw it away. - auto [t, ok] = normalize(addType(UnionTypeVar{std::move(goodOptions)}), currentModule, singletonTypes, *iceHandler); - - if (!ok) - reportError(location, NormalizationTooComplex{}); - - return t; - } - else - { - std::vector result = reduceUnion(goodOptions); - if (FFlag::LuauUnknownAndNeverType && result.empty()) - return neverType; + std::vector result = reduceUnion(goodOptions); + if (FFlag::LuauUnknownAndNeverType && result.empty()) + return neverType; - if (result.size() == 1) - return result[0]; + if (result.size() == 1) + return result[0]; - return addType(UnionTypeVar{std::move(result)}); - } + return addType(UnionTypeVar{std::move(result)}); } else if (const IntersectionTypeVar* itv = get(type)) { @@ -3426,13 +3352,6 @@ std::pair TypeChecker::checkFunctionSignature(const ScopePtr& } } } - - if (!FFlag::LuauCheckGenericHOFTypes) - { - // We do not infer type binders, so if a generic function is required we do not propagate - if (expectedFunctionType && !(expectedFunctionType->generics.empty() && expectedFunctionType->genericPacks.empty())) - expectedFunctionType = nullptr; - } } auto [generics, genericPacks] = createGenericTypes(funScope, std::nullopt, expr, expr.generics, expr.genericPacks); @@ -3442,8 +3361,7 @@ std::pair TypeChecker::checkFunctionSignature(const ScopePtr& retPack = resolveTypePack(funScope, *expr.returnAnnotation); else if (isNonstrictMode()) retPack = anyTypePack; - else if (expectedFunctionType && - (!FFlag::LuauCheckGenericHOFTypes || (expectedFunctionType->generics.empty() && expectedFunctionType->genericPacks.empty()))) + else if (expectedFunctionType && expectedFunctionType->generics.empty() && expectedFunctionType->genericPacks.empty()) { auto [head, tail] = flatten(expectedFunctionType->retTypes); @@ -3488,10 +3406,6 @@ std::pair TypeChecker::checkFunctionSignature(const ScopePtr& funScope->varargPack = anyTypePack; } } - else if (FFlag::LuauLowerBoundsCalculation && !isNonstrictMode()) - { - funScope->varargPack = addTypePack(TypePackVar{VariadicTypePack{anyType, /*hidden*/ true}}); - } std::vector argTypes; @@ -3575,48 +3489,28 @@ std::pair TypeChecker::checkFunctionSignature(const ScopePtr& std::vector genericTys; // if we have a generic expected function type and no generics, we should use the expected ones. - if (FFlag::LuauCheckGenericHOFTypes) + if (expectedFunctionType && generics.empty()) { - if (expectedFunctionType && generics.empty()) - { - genericTys = expectedFunctionType->generics; - } - else - { - genericTys.reserve(generics.size()); - for (const GenericTypeDefinition& generic : generics) - genericTys.push_back(generic.ty); - } + genericTys = expectedFunctionType->generics; } else { genericTys.reserve(generics.size()); - std::transform(generics.begin(), generics.end(), std::back_inserter(genericTys), [](auto&& el) { - return el.ty; - }); + for (const GenericTypeDefinition& generic : generics) + genericTys.push_back(generic.ty); } std::vector genericTps; // if we have a generic expected function type and no generic typepacks, we should use the expected ones. - if (FFlag::LuauCheckGenericHOFTypes) + if (expectedFunctionType && genericPacks.empty()) { - if (expectedFunctionType && genericPacks.empty()) - { - genericTps = expectedFunctionType->genericPacks; - } - else - { - genericTps.reserve(genericPacks.size()); - for (const GenericTypePackDefinition& generic : genericPacks) - genericTps.push_back(generic.tp); - } + genericTps = expectedFunctionType->genericPacks; } else { genericTps.reserve(genericPacks.size()); - std::transform(genericPacks.begin(), genericPacks.end(), std::back_inserter(genericTps), [](auto&& el) { - return el.tp; - }); + for (const GenericTypePackDefinition& generic : genericPacks) + genericTps.push_back(generic.tp); } TypeId funTy = @@ -3674,24 +3568,9 @@ void TypeChecker::checkFunctionBody(const ScopePtr& scope, TypeId ty, const AstE { check(scope, *function.body); - if (useConstrainedIntersections()) - { - TypePackId retPack = follow(funTy->retTypes); - // It is possible for a function to have no annotation and no return statement, and yet still have an ascribed return type - // if it is expected to conform to some other interface. (eg the function may be a lambda passed as a callback) - if (!hasReturn(function.body) && !function.returnAnnotation.has_value() && get(retPack)) - { - auto level = getLevel(retPack); - if (level && scope->level.subsumes(*level)) - *asMutable(retPack) = TypePack{{}, std::nullopt}; - } - } - else - { - // We explicitly don't follow here to check if we have a 'true' free type instead of bound one - if (get_if(&funTy->retTypes->ty)) - *asMutable(funTy->retTypes) = TypePack{{}, std::nullopt}; - } + // We explicitly don't follow here to check if we have a 'true' free type instead of bound one + if (get_if(&funTy->retTypes->ty)) + *asMutable(funTy->retTypes) = TypePack{{}, std::nullopt}; bool reachesImplicitReturn = getFallthrough(function.body) != nullptr; @@ -3763,21 +3642,13 @@ void TypeChecker::checkArgumentList(const ScopePtr& scope, const AstExpr& funNam if (!argLocations.empty()) location = {state.location.begin, argLocations.back().end}; - if (FFlag::LuauFunctionArgMismatchDetails) - { - std::string namePath; - if (std::optional lValue = tryGetLValue(funName)) - namePath = toString(*lValue); + std::string namePath; + if (std::optional lValue = tryGetLValue(funName)) + namePath = toString(*lValue); - auto [minParams, optMaxParams] = getParameterExtents(&state.log, paramPack); - state.reportError(TypeError{location, - CountMismatch{minParams, optMaxParams, std::distance(begin(argPack), end(argPack)), CountMismatch::Context::Arg, false, namePath}}); - } - else - { - size_t minParams = getParameterExtents(&state.log, paramPack).first; - state.reportError(TypeError{location, CountMismatch{minParams, std::nullopt, std::distance(begin(argPack), end(argPack))}}); - } + auto [minParams, optMaxParams] = getParameterExtents(&state.log, paramPack); + state.reportError(TypeError{location, + CountMismatch{minParams, optMaxParams, std::distance(begin(argPack), end(argPack)), CountMismatch::Context::Arg, false, namePath}}); }; while (true) @@ -3801,7 +3672,7 @@ void TypeChecker::checkArgumentList(const ScopePtr& scope, const AstExpr& funNam else state.log.replace(*argTail, TypePackVar(TypePack{{}})); } - else if (FFlag::LuauCallUnifyPackTails && paramTail) + else if (paramTail) { state.tryUnify(*argTail, *paramTail); } @@ -3881,20 +3752,12 @@ void TypeChecker::checkArgumentList(const ScopePtr& scope, const AstExpr& funNam std::optional tail = flatten(paramPack, state.log).second; bool isVariadic = tail && Luau::isVariadic(*tail); - if (FFlag::LuauFunctionArgMismatchDetails) - { - std::string namePath; - if (std::optional lValue = tryGetLValue(funName)) - namePath = toString(*lValue); + std::string namePath; + if (std::optional lValue = tryGetLValue(funName)) + namePath = toString(*lValue); - state.reportError(TypeError{ - state.location, CountMismatch{minParams, optMaxParams, paramIndex, CountMismatch::Context::Arg, isVariadic, namePath}}); - } - else - { - state.reportError( - TypeError{state.location, CountMismatch{minParams, std::nullopt, paramIndex, CountMismatch::Context::Arg, isVariadic}}); - } + state.reportError(TypeError{ + state.location, CountMismatch{minParams, optMaxParams, paramIndex, CountMismatch::Context::Arg, isVariadic, namePath}}); return; } ++paramIter; @@ -3924,21 +3787,6 @@ void TypeChecker::checkArgumentList(const ScopePtr& scope, const AstExpr& funNam } else if (auto vtp = state.log.getMutable(tail)) { - if (FFlag::LuauLowerBoundsCalculation && vtp->hidden) - { - // We know that this function can technically be oversaturated, but we have its definition and we - // know that it's useless. - - TypeId e = errorRecoveryType(scope); - while (argIter != endIter) - { - unify(e, *argIter, scope, state.location); - ++argIter; - } - - reportCountMismatchError(); - return; - } // Function is variadic and requires that all subsequent parameters // be compatible with a type. size_t argIndex = paramIndex; @@ -4040,21 +3888,14 @@ WithPredicate TypeChecker::checkExprPackHelper(const ScopePtr& scope } TypePackId retPack; - if (FFlag::LuauLowerBoundsCalculation) + if (auto free = get(actualFunctionType)) { - retPack = freshTypePack(scope->level); + retPack = freshTypePack(free->level); + TypePackId freshArgPack = freshTypePack(free->level); + asMutable(actualFunctionType)->ty.emplace(free->level, freshArgPack, retPack); } else - { - if (auto free = get(actualFunctionType)) - { - retPack = freshTypePack(free->level); - TypePackId freshArgPack = freshTypePack(free->level); - asMutable(actualFunctionType)->ty.emplace(free->level, freshArgPack, retPack); - } - else - retPack = freshTypePack(scope->level); - } + retPack = freshTypePack(scope->level); // checkExpr will log the pre-instantiated type of the function. // That's not nearly as interesting as the instantiated type, which will include details about how @@ -4214,39 +4055,13 @@ std::optional> TypeChecker::checkCallOverload(const Sc // fn is one of the overloads of actualFunctionType, which // has been instantiated, so is a monotype. We can therefore // unify it with a monomorphic function. - if (useConstrainedIntersections()) - { - // This ternary is phrased deliberately. We need ties between sibling scopes to bias toward ftv->level. - const TypeLevel level = scope->level.subsumes(ftv->level) ? scope->level : ftv->level; - - std::vector adjustedArgTypes; - auto it = begin(argPack); - auto endIt = end(argPack); - Widen widen{¤tModule->internalTypes, singletonTypes}; - for (; it != endIt; ++it) - { - adjustedArgTypes.push_back(addType(ConstrainedTypeVar{level, {widen(*it)}})); - } - - TypePackId adjustedArgPack = addTypePack(TypePack{std::move(adjustedArgTypes), it.tail()}); - - TxnLog log; - promoteTypeLevels(log, ¤tModule->internalTypes, level, /*scope*/ nullptr, /*useScope*/ false, retPack); - log.commit(); + TypeId r = addType(FunctionTypeVar(scope->level, argPack, retPack)); - *asMutable(fn) = FunctionTypeVar{level, adjustedArgPack, retPack}; - return {{retPack}}; - } - else - { - TypeId r = addType(FunctionTypeVar(scope->level, argPack, retPack)); - - UnifierOptions options; - options.isFunctionCall = true; - unify(r, fn, scope, expr.location, options); + UnifierOptions options; + options.isFunctionCall = true; + unify(r, fn, scope, expr.location, options); - return {{retPack}}; - } + return {{retPack}}; } std::vector metaArgLocations; @@ -4760,14 +4575,6 @@ TypeId TypeChecker::quantify(const ScopePtr& scope, TypeId ty, Location location Luau::quantify(ty, scope->level); else if (auto ttv = getTableType(ty); ttv && ttv->selfTy) Luau::quantify(ty, scope->level); - - if (FFlag::LuauLowerBoundsCalculation) - { - auto [t, ok] = Luau::normalize(ty, currentModule, singletonTypes, *iceHandler); - if (!ok) - reportError(location, NormalizationTooComplex{}); - return t; - } } else { @@ -4775,14 +4582,6 @@ TypeId TypeChecker::quantify(const ScopePtr& scope, TypeId ty, Location location if (ftv) Luau::quantify(ty, scope->level); - - if (FFlag::LuauLowerBoundsCalculation && ftv) - { - auto [t, ok] = Luau::normalize(ty, currentModule, singletonTypes, *iceHandler); - if (!ok) - reportError(location, NormalizationTooComplex{}); - return t; - } } return ty; @@ -4813,14 +4612,6 @@ TypeId TypeChecker::instantiate(const ScopePtr& scope, TypeId ty, Location locat TypeId TypeChecker::anyify(const ScopePtr& scope, TypeId ty, Location location) { - if (FFlag::LuauLowerBoundsCalculation) - { - auto [t, ok] = normalize(ty, currentModule, singletonTypes, *iceHandler); - if (!ok) - reportError(location, NormalizationTooComplex{}); - ty = t; - } - Anyification anyification{¤tModule->internalTypes, scope, singletonTypes, iceHandler, anyType, anyTypePack}; std::optional any = anyification.substitute(ty); if (anyification.normalizationTooComplex) @@ -4836,14 +4627,6 @@ TypeId TypeChecker::anyify(const ScopePtr& scope, TypeId ty, Location location) TypePackId TypeChecker::anyify(const ScopePtr& scope, TypePackId ty, Location location) { - if (FFlag::LuauLowerBoundsCalculation) - { - auto [t, ok] = normalize(ty, currentModule, singletonTypes, *iceHandler); - if (!ok) - reportError(location, NormalizationTooComplex{}); - ty = t; - } - Anyification anyification{¤tModule->internalTypes, scope, singletonTypes, iceHandler, anyType, anyTypePack}; std::optional any = anyification.substitute(ty); if (any.has_value()) @@ -6083,11 +5866,6 @@ bool TypeChecker::isNonstrictMode() const return (currentModule->mode == Mode::Nonstrict) || (currentModule->mode == Mode::NoCheck); } -bool TypeChecker::useConstrainedIntersections() const -{ - return FFlag::LuauLowerBoundsCalculation && !isNonstrictMode(); -} - std::vector TypeChecker::unTypePack(const ScopePtr& scope, TypePackId tp, size_t expectedLength, const Location& location) { TypePackId expectedTypePack = addTypePack({}); diff --git a/Analysis/src/TypeUtils.cpp b/Analysis/src/TypeUtils.cpp index ca00c2699..688c87672 100644 --- a/Analysis/src/TypeUtils.cpp +++ b/Analysis/src/TypeUtils.cpp @@ -6,8 +6,6 @@ #include "Luau/ToString.h" #include "Luau/TypeInfer.h" -LUAU_FASTFLAG(LuauFunctionArgMismatchDetails) - namespace Luau { @@ -218,7 +216,7 @@ std::pair> getParameterExtents(const TxnLog* log, ++it; } - if (it.tail() && (!FFlag::LuauFunctionArgMismatchDetails || isVariadicTail(*it.tail(), *log, includeHiddenVariadics))) + if (it.tail() && isVariadicTail(*it.tail(), *log, includeHiddenVariadics)) return {minCount, std::nullopt}; else return {minCount, minCount + optionalCount}; diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index b143268e3..bcdaff7d2 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -25,7 +25,6 @@ LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0) LUAU_FASTINT(LuauTypeInferRecursionLimit) LUAU_FASTFLAG(LuauUnknownAndNeverType) LUAU_FASTFLAGVARIABLE(LuauMaybeGenericIntersectionTypes, false) -LUAU_FASTFLAGVARIABLE(LuauStringFormatArgumentErrorFix, false) LUAU_FASTFLAGVARIABLE(LuauNoMoreGlobalSingletonTypes, false) LUAU_FASTFLAG(LuauInstantiateInSubtyping) @@ -1166,21 +1165,12 @@ std::optional> magicFunctionFormat( } // if we know the argument count or if we have too many arguments for sure, we can issue an error - if (FFlag::LuauStringFormatArgumentErrorFix) - { - size_t numActualParams = params.size(); - size_t numExpectedParams = expected.size() + 1; // + 1 for the format string + size_t numActualParams = params.size(); + size_t numExpectedParams = expected.size() + 1; // + 1 for the format string - if (numExpectedParams != numActualParams && (!tail || numExpectedParams < numActualParams)) - typechecker.reportError(TypeError{expr.location, CountMismatch{numExpectedParams, std::nullopt, numActualParams}}); - } - else - { - size_t actualParamSize = params.size() - paramOffset; + if (numExpectedParams != numActualParams && (!tail || numExpectedParams < numActualParams)) + typechecker.reportError(TypeError{expr.location, CountMismatch{numExpectedParams, std::nullopt, numActualParams}}); - if (expected.size() != actualParamSize && (!tail || expected.size() < actualParamSize)) - typechecker.reportError(TypeError{expr.location, CountMismatch{expected.size(), std::nullopt, actualParamSize}}); - } return WithPredicate{arena.addTypePack({typechecker.stringType})}; } diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 5a01c9348..42fcd2fda 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -18,14 +18,12 @@ LUAU_FASTINT(LuauTypeInferTypePackLoopLimit); LUAU_FASTINT(LuauTypeInferIterationLimit); LUAU_FASTFLAG(LuauAutocompleteDynamicLimits) LUAU_FASTINTVARIABLE(LuauTypeInferLowerBoundsIterationLimit, 2000); -LUAU_FASTFLAG(LuauLowerBoundsCalculation); LUAU_FASTFLAG(LuauErrorRecoveryType); LUAU_FASTFLAG(LuauUnknownAndNeverType) LUAU_FASTFLAGVARIABLE(LuauSubtypeNormalizer, false); LUAU_FASTFLAGVARIABLE(LuauScalarShapeSubtyping, false) LUAU_FASTFLAGVARIABLE(LuauInstantiateInSubtyping, false) LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution) -LUAU_FASTFLAG(LuauCallUnifyPackTails) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) namespace Luau @@ -346,8 +344,7 @@ static bool subsumes(bool useScopes, TY_A* left, TY_B* right) return left->level.subsumes(right->level); } -Unifier::Unifier(NotNull normalizer, Mode mode, NotNull scope, const Location& location, - Variance variance, TxnLog* parentLog) +Unifier::Unifier(NotNull normalizer, Mode mode, NotNull scope, const Location& location, Variance variance, TxnLog* parentLog) : types(normalizer->arena) , singletonTypes(normalizer->singletonTypes) , normalizer(normalizer) @@ -529,7 +526,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool { tryUnifyUnionWithType(subTy, subUnion, superTy); } - else if (const UnionTypeVar* uv = (FFlag::LuauSubtypeNormalizer? nullptr: log.getMutable(superTy))) + else if (const UnionTypeVar* uv = (FFlag::LuauSubtypeNormalizer ? nullptr : log.getMutable(superTy))) { tryUnifyTypeWithUnion(subTy, superTy, uv, cacheEnabled, isFunctionCall); } @@ -865,7 +862,8 @@ void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionTypeV } } -void Unifier::tryUnifyNormalizedTypes(TypeId subTy, TypeId superTy, const NormalizedType& subNorm, const NormalizedType& superNorm, std::string reason, std::optional error) +void Unifier::tryUnifyNormalizedTypes( + TypeId subTy, TypeId superTy, const NormalizedType& subNorm, const NormalizedType& superNorm, std::string reason, std::optional error) { LUAU_ASSERT(FFlag::LuauSubtypeNormalizer); @@ -1371,12 +1369,12 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal else { // A union type including nil marks an optional argument - if ((!FFlag::LuauLowerBoundsCalculation || isNonstrictMode()) && superIter.good() && isOptional(*superIter)) + if (superIter.good() && isOptional(*superIter)) { superIter.advance(); continue; } - else if ((!FFlag::LuauLowerBoundsCalculation || isNonstrictMode()) && subIter.good() && isOptional(*subIter)) + else if (subIter.good() && isOptional(*subIter)) { subIter.advance(); continue; @@ -1394,7 +1392,7 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal return; } - if ((!FFlag::LuauLowerBoundsCalculation || isNonstrictMode()) && !isFunctionCall && subIter.good()) + if (!isFunctionCall && subIter.good()) { // Sometimes it is ok to pass too many arguments return; @@ -1491,7 +1489,6 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal numGenerics = std::min(superFunction->generics.size(), subFunction->generics.size()); numGenericPacks = std::min(superFunction->genericPacks.size(), subFunction->genericPacks.size()); - } else { @@ -2012,7 +2009,8 @@ void Unifier::tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed) if (auto e = hasUnificationTooComplex(innerState.errors)) reportError(*e); else if (!innerState.errors.empty()) - reportError(TypeError{location, TypeMismatch{reversed ? subTy : superTy, reversed ? superTy : subTy, "", innerState.errors.front()}}); + reportError( + TypeError{location, TypeMismatch{reversed ? subTy : superTy, reversed ? superTy : subTy, "", innerState.errors.front()}}); else if (!missingProperty) { log.concat(std::move(innerState.log)); @@ -2448,8 +2446,7 @@ void Unifier::unifyLowerBound(TypePackId subTy, TypePackId superTy, TypeLevel de for (; superIter != superEndIter; ++superIter) tp->head.push_back(*superIter); } - else if (const VariadicTypePack* subVariadic = log.getMutable(subTailPack); - subVariadic && FFlag::LuauCallUnifyPackTails) + else if (const VariadicTypePack* subVariadic = log.getMutable(subTailPack)) { while (superIter != superEndIter) { diff --git a/Ast/include/Luau/StringUtils.h b/Ast/include/Luau/StringUtils.h index dab761060..6345fde46 100644 --- a/Ast/include/Luau/StringUtils.h +++ b/Ast/include/Luau/StringUtils.h @@ -1,17 +1,13 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include "Luau/Common.h" + #include #include #include -#if defined(__GNUC__) -#define LUAU_PRINTF_ATTR(fmt, arg) __attribute__((format(printf, fmt, arg))) -#else -#define LUAU_PRINTF_ATTR(fmt, arg) -#endif - namespace Luau { diff --git a/CLI/Repl.cpp b/CLI/Repl.cpp index 4d3beec93..aecddf383 100644 --- a/CLI/Repl.cpp +++ b/CLI/Repl.cpp @@ -4,6 +4,7 @@ #include "lua.h" #include "lualib.h" +#include "Luau/CodeGen.h" #include "Luau/Compiler.h" #include "Luau/BytecodeBuilder.h" #include "Luau/Parser.h" @@ -46,11 +47,15 @@ enum class CompileFormat { Text, Binary, + Remarks, + Codegen, Null }; constexpr int MaxTraversalLimit = 50; +static bool codegen = false; + // Ctrl-C handling static void sigintCallback(lua_State* L, int gc) { @@ -159,6 +164,9 @@ static int lua_require(lua_State* L) std::string bytecode = Luau::compile(*source, copts()); if (luau_load(ML, chunkname.c_str(), bytecode.data(), bytecode.size(), 0) == 0) { + if (codegen) + Luau::CodeGen::compile(ML, -1); + if (coverageActive()) coverageTrack(ML, -1); @@ -242,6 +250,9 @@ static int lua_callgrind(lua_State* L) void setupState(lua_State* L) { + if (codegen) + Luau::CodeGen::create(L); + luaL_openlibs(L); static const luaL_Reg funcs[] = { @@ -276,6 +287,9 @@ std::string runCode(lua_State* L, const std::string& source) return error; } + if (codegen) + Luau::CodeGen::compile(L, -1); + lua_State* T = lua_newthread(L); lua_pushvalue(L, -2); @@ -604,6 +618,9 @@ static bool runFile(const char* name, lua_State* GL, bool repl) if (luau_load(L, chunkname.c_str(), bytecode.data(), bytecode.size(), 0) == 0) { + if (codegen) + Luau::CodeGen::compile(L, -1); + if (coverageActive()) coverageTrack(L, -1); @@ -656,6 +673,20 @@ static void reportError(const char* name, const Luau::CompileError& error) report(name, error.getLocation(), "CompileError", error.what()); } +static std::string getCodegenAssembly(const char* name, const std::string& bytecode) +{ + std::unique_ptr globalState(luaL_newstate(), lua_close); + lua_State* L = globalState.get(); + + setupState(L); + + if (luau_load(L, name, bytecode.data(), bytecode.size(), 0) == 0) + return Luau::CodeGen::getAssemblyText(L, -1); + + fprintf(stderr, "Error loading bytecode %s\n", name); + return ""; +} + static bool compileFile(const char* name, CompileFormat format) { std::optional source = readFile(name); @@ -675,6 +706,11 @@ static bool compileFile(const char* name, CompileFormat format) Luau::BytecodeBuilder::Dump_Remarks); bcb.setDumpSource(*source); } + else if (format == CompileFormat::Remarks) + { + bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Source | Luau::BytecodeBuilder::Dump_Remarks); + bcb.setDumpSource(*source); + } Luau::compileOrThrow(bcb, *source, copts()); @@ -683,9 +719,15 @@ static bool compileFile(const char* name, CompileFormat format) case CompileFormat::Text: printf("%s", bcb.dumpEverything().c_str()); break; + case CompileFormat::Remarks: + printf("%s", bcb.dumpSourceRemarks().c_str()); + break; case CompileFormat::Binary: fwrite(bcb.getBytecode().data(), 1, bcb.getBytecode().size(), stdout); break; + case CompileFormat::Codegen: + printf("%s", getCodegenAssembly(name, bcb.getBytecode()).c_str()); + break; case CompileFormat::Null: break; } @@ -713,7 +755,7 @@ static void displayHelp(const char* argv0) printf("\n"); printf("Available modes:\n"); printf(" omitted: compile and run input files one by one\n"); - printf(" --compile[=format]: compile input files and output resulting formatted bytecode (binary or text)\n"); + printf(" --compile[=format]: compile input files and output resulting formatted bytecode (binary, text, remarks, codegen or null)\n"); printf("\n"); printf("Available options:\n"); printf(" --coverage: collect code coverage while running the code and output results to coverage.out\n"); @@ -723,6 +765,7 @@ static void displayHelp(const char* argv0) printf(" -g: compile with debug level n (default 1, n should be between 0 and 2).\n"); printf(" --profile[=N]: profile the code using N Hz sampling (default 10000) and output results to profile.out\n"); printf(" --timetrace: record compiler time tracing information into trace.json\n"); + printf(" --codegen: execute code using native code generation\n"); } static int assertionHandler(const char* expr, const char* file, int line, const char* function) @@ -761,6 +804,14 @@ int replMain(int argc, char** argv) { compileFormat = CompileFormat::Text; } + else if (strcmp(argv[1], "--compile=remarks") == 0) + { + compileFormat = CompileFormat::Remarks; + } + else if (strcmp(argv[1], "--compile=codegen") == 0) + { + compileFormat = CompileFormat::Codegen; + } else if (strcmp(argv[1], "--compile=null") == 0) { compileFormat = CompileFormat::Null; @@ -811,6 +862,10 @@ int replMain(int argc, char** argv) { profile = atoi(argv[i] + 10); } + else if (strcmp(argv[i], "--codegen") == 0) + { + codegen = true; + } else if (strcmp(argv[i], "--coverage") == 0) { coverage = true; @@ -839,12 +894,26 @@ int replMain(int argc, char** argv) } #endif +#if !LUA_CUSTOM_EXECUTION + if (codegen) + { + fprintf(stderr, "To run with --codegen, Luau has to be built with LUA_CUSTOM_EXECUTION enabled\n"); + return 1; + } +#endif + const std::vector files = getSourceFiles(argc, argv); if (mode == CliMode::Unknown) { mode = files.empty() ? CliMode::Repl : CliMode::RunSourceFiles; } + if (mode != CliMode::Compile && codegen && !Luau::CodeGen::isSupported()) + { + fprintf(stderr, "Cannot enable --codegen, native code generation is not supported in current configuration\n"); + return 1; + } + switch (mode) { case CliMode::Compile: diff --git a/CMakeLists.txt b/CMakeLists.txt index 43289f418..0016160ae 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -12,6 +12,7 @@ option(LUAU_BUILD_WEB "Build Web module" OFF) option(LUAU_WERROR "Warnings as errors" OFF) option(LUAU_STATIC_CRT "Link with the static CRT (/MT)" OFF) option(LUAU_EXTERN_C "Use extern C for all APIs" OFF) +option(LUAU_NATIVE "Enable support for native code generation" OFF) if(LUAU_STATIC_CRT) cmake_minimum_required(VERSION 3.15) @@ -132,6 +133,10 @@ if(LUAU_EXTERN_C) target_compile_definitions(Luau.Compiler PUBLIC LUACODE_API=extern\"C\") endif() +if(LUAU_NATIVE) + target_compile_definitions(Luau.VM PUBLIC LUA_CUSTOM_EXECUTION=1) +endif() + if (MSVC AND MSVC_VERSION GREATER_EQUAL 1924) # disable partial redundancy elimination which regresses interpreter codegen substantially in VS2022: # https://developercommunity.visualstudio.com/t/performance-regression-on-a-complex-interpreter-lo/1631863 @@ -167,7 +172,7 @@ if(LUAU_BUILD_CLI) target_include_directories(Luau.Repl.CLI PRIVATE extern extern/isocline/include) - target_link_libraries(Luau.Repl.CLI PRIVATE Luau.Compiler Luau.VM isocline) + target_link_libraries(Luau.Repl.CLI PRIVATE Luau.Compiler Luau.CodeGen Luau.VM isocline) if(UNIX) find_library(LIBPTHREAD pthread) @@ -193,11 +198,11 @@ if(LUAU_BUILD_TESTS) target_compile_options(Luau.Conformance PRIVATE ${LUAU_OPTIONS}) target_include_directories(Luau.Conformance PRIVATE extern) - target_link_libraries(Luau.Conformance PRIVATE Luau.Analysis Luau.Compiler Luau.VM) + target_link_libraries(Luau.Conformance PRIVATE Luau.Analysis Luau.Compiler Luau.CodeGen Luau.VM) target_compile_options(Luau.CLI.Test PRIVATE ${LUAU_OPTIONS}) target_include_directories(Luau.CLI.Test PRIVATE extern CLI) - target_link_libraries(Luau.CLI.Test PRIVATE Luau.Compiler Luau.VM isocline) + target_link_libraries(Luau.CLI.Test PRIVATE Luau.Compiler Luau.CodeGen Luau.VM isocline) if(UNIX) find_library(LIBPTHREAD pthread) if (LIBPTHREAD) diff --git a/CodeGen/include/Luau/AssemblyBuilderX64.h b/CodeGen/include/Luau/AssemblyBuilderX64.h index 15db7a156..1c7550170 100644 --- a/CodeGen/include/Luau/AssemblyBuilderX64.h +++ b/CodeGen/include/Luau/AssemblyBuilderX64.h @@ -15,6 +15,14 @@ namespace Luau namespace CodeGen { +enum class RoundingModeX64 +{ + RoundToNearestEven = 0b00, + RoundToNegativeInfinity = 0b01, + RoundToPositiveInfinity = 0b10, + RoundToZero = 0b11, +}; + class AssemblyBuilderX64 { public: @@ -48,6 +56,8 @@ class AssemblyBuilderX64 void imul(OperandX64 op); void neg(OperandX64 op); void not_(OperandX64 op); + void dec(OperandX64 op); + void inc(OperandX64 op); // Additional forms of imul void imul(OperandX64 lhs, OperandX64 rhs); @@ -82,13 +92,12 @@ class AssemblyBuilderX64 void vxorpd(OperandX64 dst, OperandX64 src1, OperandX64 src2); - void vcomisd(OperandX64 src1, OperandX64 src2); void vucomisd(OperandX64 src1, OperandX64 src2); void vcvttsd2si(OperandX64 dst, OperandX64 src); void vcvtsi2sd(OperandX64 dst, OperandX64 src1, OperandX64 src2); - void vroundsd(OperandX64 dst, OperandX64 src1, OperandX64 src2, uint8_t mode); + void vroundsd(OperandX64 dst, OperandX64 src1, OperandX64 src2, RoundingModeX64 roundingMode); // inexact void vsqrtpd(OperandX64 dst, OperandX64 src); void vsqrtps(OperandX64 dst, OperandX64 src); @@ -120,6 +129,8 @@ class AssemblyBuilderX64 OperandX64 f32x4(float x, float y, float z, float w); OperandX64 bytes(const void* ptr, size_t size, size_t align = 8); + void logAppend(const char* fmt, ...) LUAU_PRINTF_ATTR(2, 3); + // Resulting data and code that need to be copied over one after the other // The *end* of 'data' has to be aligned to 16 bytes, this will also align 'code' std::vector data; @@ -127,6 +138,8 @@ class AssemblyBuilderX64 std::string text; + const bool logText = false; + private: // Instruction archetypes void placeBinary(const char* name, OperandX64 lhs, OperandX64 rhs, uint8_t codeimm8, uint8_t codeimm, uint8_t codeimmImm8, uint8_t code8rev, @@ -178,7 +191,6 @@ class AssemblyBuilderX64 LUAU_NOINLINE void log(Label label); LUAU_NOINLINE void log(const char* opcode, Label label); void log(OperandX64 op); - void logAppend(const char* fmt, ...); const char* getSizeName(SizeX64 size); const char* getRegisterName(RegisterX64 reg); @@ -187,7 +199,6 @@ class AssemblyBuilderX64 std::vector(() -> a, a) -> ()", toString(requireType("f"))); -} - -TEST_CASE_FIXTURE(Fixture, "fuzz_failure_instersection_combine_must_follow") -{ - ScopedFastFlag flags[] = { - {"LuauLowerBoundsCalculation", true}, - }; - - CheckResult result = check(R"( - export type t0 = {_:{_:any} & {_:any|string}} & {_:{_:{}}} - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "fuzz_failure_bound_type_is_normal_but_not_its_bounded_to") -{ - ScopedFastFlag sff{"LuauLowerBoundsCalculation", true}; - - CheckResult result = check(R"( - type t252 = ((t0)|(any))|(any) - type t0 = t252,t24...> - )"); - - LUAU_REQUIRE_ERRORS(result); -} - -// We had an issue where a normal BoundTypeVar might point at a non-normal BoundTypeVar if it in turn pointed to a -// normal TypeVar because we were calling follow() in an improper place. -TEST_CASE_FIXTURE(Fixture, "bound_typevars_should_only_be_marked_normal_if_their_pointee_is_normal") -{ - ScopedFastFlag sff[]{ - {"LuauLowerBoundsCalculation", true}, - }; - - CheckResult result = check(R"( - local T = {} - - function T:M() - local function f(a) - print(self.prop) - self:g(a) - self.prop = a - end - end - - return T - )"); -} - TEST_CASE_FIXTURE(BuiltinsFixture, "skip_force_normal_on_external_types") { createSomeClasses(frontend); @@ -1108,68 +472,4 @@ export type t0 = (((any)&({_:l0.t0,n0:t0,_G:any,}))&({_:any,}))&(((any)&({_:l0.t LUAU_REQUIRE_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "normalize_unions_containing_never") -{ - ScopedFastFlag sff{"LuauLowerBoundsCalculation", true}; - - CheckResult result = check(R"( - type Foo = string | never - local foo: Foo - )"); - - CHECK_EQ("string", toString(requireType("foo"))); -} - -TEST_CASE_FIXTURE(Fixture, "normalize_unions_containing_unknown") -{ - ScopedFastFlag sff{"LuauLowerBoundsCalculation", true}; - - CheckResult result = check(R"( - type Foo = string | unknown - local foo: Foo - )"); - - CHECK_EQ("unknown", toString(requireType("foo"))); -} - -TEST_CASE_FIXTURE(Fixture, "any_wins_the_battle_over_unknown_in_unions") -{ - ScopedFastFlag sff{"LuauLowerBoundsCalculation", true}; - - CheckResult result = check(R"( - type Foo = unknown | any - local foo: Foo - - type Bar = any | unknown - local bar: Bar - )"); - - CHECK_EQ("any", toString(requireType("foo"))); - CHECK_EQ("any", toString(requireType("bar"))); -} - -TEST_CASE_FIXTURE(BuiltinsFixture, "normalization_does_not_convert_ever") -{ - ScopedFastFlag sff[]{ - {"LuauLowerBoundsCalculation", true}, - }; - - CheckResult result = check(R"( - --!strict - local function f() - if math.random() > 0.5 then - return true - end - type Ret = typeof(f()) - if math.random() > 0.5 then - return "something" - end - return "something" :: Ret - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("() -> boolean | string", toString(requireType("f"))); -} - TEST_SUITE_END(); diff --git a/tests/RuntimeLimits.test.cpp b/tests/RuntimeLimits.test.cpp index 6619147b3..7e50d5b64 100644 --- a/tests/RuntimeLimits.test.cpp +++ b/tests/RuntimeLimits.test.cpp @@ -15,8 +15,6 @@ using namespace Luau; -LUAU_FASTFLAG(LuauLowerBoundsCalculation); - struct LimitFixture : BuiltinsFixture { #if defined(_NOOPT) || defined(_DEBUG) @@ -267,10 +265,7 @@ TEST_CASE_FIXTURE(LimitFixture, "typescript_port_of_Result_type") CheckResult result = check(src); CodeTooComplex ctc; - if (FFlag::LuauLowerBoundsCalculation) - LUAU_REQUIRE_ERRORS(result); - else - CHECK(hasError(result, &ctc)); + CHECK(hasError(result, &ctc)); } TEST_SUITE_END(); diff --git a/tests/ToDot.test.cpp b/tests/ToDot.test.cpp index 95dcd70ad..98eb9863b 100644 --- a/tests/ToDot.test.cpp +++ b/tests/ToDot.test.cpp @@ -7,8 +7,6 @@ #include "doctest.h" -LUAU_FASTFLAG(LuauLowerBoundsCalculation) - using namespace Luau; struct ToDotClassFixture : Fixture @@ -111,29 +109,7 @@ local function f(a, ...: string) return a end ToDotOptions opts; opts.showPointers = false; - if (FFlag::LuauLowerBoundsCalculation) - { - CHECK_EQ(R"(digraph graphname { -n1 [label="FunctionTypeVar 1"]; -n1 -> n2 [label="arg"]; -n2 [label="TypePack 2"]; -n2 -> n3; -n3 [label="GenericTypeVar 3"]; -n2 -> n4 [label="tail"]; -n4 [label="VariadicTypePack 4"]; -n4 -> n5; -n5 [label="string"]; -n1 -> n6 [label="ret"]; -n6 [label="TypePack 6"]; -n6 -> n7; -n7 [label="BoundTypeVar 7"]; -n7 -> n3; -})", - toDot(requireType("f"), opts)); - } - else - { - CHECK_EQ(R"(digraph graphname { + CHECK_EQ(R"(digraph graphname { n1 [label="FunctionTypeVar 1"]; n1 -> n2 [label="arg"]; n2 [label="TypePack 2"]; @@ -149,8 +125,7 @@ n6 -> n7; n7 [label="TypePack 7"]; n7 -> n3; })", - toDot(requireType("f"), opts)); - } + toDot(requireType("f"), opts)); } TEST_CASE_FIXTURE(Fixture, "union") diff --git a/tests/ToString.test.cpp b/tests/ToString.test.cpp index 1339ec28a..53e5f71ba 100644 --- a/tests/ToString.test.cpp +++ b/tests/ToString.test.cpp @@ -12,6 +12,7 @@ using namespace Luau; LUAU_FASTFLAG(LuauRecursiveTypeParameterRestriction); LUAU_FASTFLAG(LuauSpecialTypesAsterisked); LUAU_FASTFLAG(LuauFixNameMaps); +LUAU_FASTFLAG(LuauFunctionReturnStringificationFixup); TEST_SUITE_BEGIN("ToString"); @@ -570,6 +571,22 @@ TEST_CASE_FIXTURE(Fixture, "toString_the_boundTo_table_type_contained_within_a_T CHECK_EQ("{| hello: number, world: number |}", toString(&tpv2)); } +TEST_CASE_FIXTURE(Fixture, "no_parentheses_around_return_type_if_pack_has_an_empty_head_link") +{ + TypeArena arena; + TypePackId realTail = arena.addTypePack({singletonTypes->stringType}); + TypePackId emptyTail = arena.addTypePack({}, realTail); + + TypePackId argList = arena.addTypePack({singletonTypes->stringType}); + + TypeId functionType = arena.addType(FunctionTypeVar{argList, emptyTail}); + + if (FFlag::LuauFunctionReturnStringificationFixup) + CHECK("(string) -> string" == toString(functionType)); + else + CHECK("(string) -> (string)" == toString(functionType)); +} + TEST_CASE_FIXTURE(Fixture, "no_parentheses_around_cyclic_function_type_in_union") { CheckResult result = check(R"( diff --git a/tests/TypeInfer.annotations.test.cpp b/tests/TypeInfer.annotations.test.cpp index 5f2c22cfd..28767889d 100644 --- a/tests/TypeInfer.annotations.test.cpp +++ b/tests/TypeInfer.annotations.test.cpp @@ -657,50 +657,9 @@ struct AssertionCatcher int AssertionCatcher::tripped; } // namespace -TEST_CASE_FIXTURE(Fixture, "luau_ice_triggers_an_ice") -{ - ScopedFastFlag sffs[] = { - {"DebugLuauMagicTypes", true}, - {"LuauUseInternalCompilerErrorException", false}, - }; - - AssertionCatcher ac; - - CHECK_THROWS_AS(check(R"( - local a: _luau_ice = 55 - )"), - std::runtime_error); - - LUAU_ASSERT(1 == AssertionCatcher::tripped); -} - -TEST_CASE_FIXTURE(Fixture, "luau_ice_triggers_an_ice_handler") -{ - ScopedFastFlag sffs[] = { - {"DebugLuauMagicTypes", true}, - {"LuauUseInternalCompilerErrorException", false}, - }; - - bool caught = false; - - frontend.iceHandler.onInternalError = [&](const char*) { - caught = true; - }; - - CHECK_THROWS_AS(check(R"( - local a: _luau_ice = 55 - )"), - std::runtime_error); - - CHECK_EQ(true, caught); -} - TEST_CASE_FIXTURE(Fixture, "luau_ice_triggers_an_ice_exception_with_flag") { - ScopedFastFlag sffs[] = { - {"DebugLuauMagicTypes", true}, - {"LuauUseInternalCompilerErrorException", true}, - }; + ScopedFastFlag sffs{"DebugLuauMagicTypes", true}; AssertionCatcher ac; @@ -714,10 +673,7 @@ TEST_CASE_FIXTURE(Fixture, "luau_ice_triggers_an_ice_exception_with_flag") TEST_CASE_FIXTURE(Fixture, "luau_ice_triggers_an_ice_exception_with_flag_handler") { - ScopedFastFlag sffs[] = { - {"DebugLuauMagicTypes", true}, - {"LuauUseInternalCompilerErrorException", true}, - }; + ScopedFastFlag sffs{"DebugLuauMagicTypes", true}; bool caught = false; diff --git a/tests/TypeInfer.builtins.test.cpp b/tests/TypeInfer.builtins.test.cpp index 037f79d8a..f9c104fd9 100644 --- a/tests/TypeInfer.builtins.test.cpp +++ b/tests/TypeInfer.builtins.test.cpp @@ -8,9 +8,7 @@ using namespace Luau; -LUAU_FASTFLAG(LuauLowerBoundsCalculation); LUAU_FASTFLAG(LuauSpecialTypesAsterisked); -LUAU_FASTFLAG(LuauStringFormatArgumentErrorFix) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) TEST_SUITE_BEGIN("BuiltinTests"); @@ -637,8 +635,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "select_with_decimal_argument_is_rounded_down // Could be flaky if the fix has regressed. TEST_CASE_FIXTURE(BuiltinsFixture, "bad_select_should_not_crash") { - ScopedFastFlag luauFunctionArgMismatchDetails{"LuauFunctionArgMismatchDetails", true}; - CheckResult result = check(R"( do end local _ = function(l0,...) @@ -754,14 +750,7 @@ TEST_CASE_FIXTURE(Fixture, "string_format_use_correct_argument") LUAU_REQUIRE_ERROR_COUNT(1, result); - if (FFlag::LuauStringFormatArgumentErrorFix) - { - CHECK_EQ("Argument count mismatch. Function expects 2 arguments, but 3 are specified", toString(result.errors[0])); - } - else - { - CHECK_EQ("Argument count mismatch. Function expects 1 argument, but 2 are specified", toString(result.errors[0])); - } + CHECK_EQ("Argument count mismatch. Function expects 2 arguments, but 3 are specified", toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "string_format_use_correct_argument2") @@ -778,8 +767,6 @@ TEST_CASE_FIXTURE(Fixture, "string_format_use_correct_argument2") TEST_CASE_FIXTURE(BuiltinsFixture, "string_format_use_correct_argument3") { - ScopedFastFlag LuauStringFormatArgumentErrorFix{"LuauStringFormatArgumentErrorFix", true}; - CheckResult result = check(R"( local s1 = string.format("%d") local s2 = string.format("%d", 1) @@ -966,10 +953,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "assert_removes_falsy_types") )"); LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::LuauLowerBoundsCalculation) - CHECK_EQ("((boolean | number)?) -> number | true", toString(requireType("f"))); - else - CHECK_EQ("((boolean | number)?) -> boolean | number", toString(requireType("f"))); + CHECK_EQ("((boolean | number)?) -> boolean | number", toString(requireType("f"))); } TEST_CASE_FIXTURE(BuiltinsFixture, "assert_removes_falsy_types2") @@ -1040,8 +1024,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "table_freeze_is_generic") TEST_CASE_FIXTURE(BuiltinsFixture, "set_metatable_needs_arguments") { - ScopedFastFlag luauFunctionArgMismatchDetails{"LuauFunctionArgMismatchDetails", true}; - ScopedFastFlag sff{"LuauSetMetaTableArgsCheck", true}; CheckResult result = check(R"( local a = {b=setmetatable} diff --git a/tests/TypeInfer.classes.test.cpp b/tests/TypeInfer.classes.test.cpp index 87ec58c9f..d00f1d831 100644 --- a/tests/TypeInfer.classes.test.cpp +++ b/tests/TypeInfer.classes.test.cpp @@ -482,7 +482,7 @@ local a: ChildClass = i TEST_CASE_FIXTURE(ClassFixture, "intersections_of_unions_of_classes") { - ScopedFastFlag sffs[] { + ScopedFastFlag sffs[]{ {"LuauSubtypeNormalizer", true}, {"LuauTypeNormalization2", true}, }; @@ -499,7 +499,7 @@ TEST_CASE_FIXTURE(ClassFixture, "intersections_of_unions_of_classes") TEST_CASE_FIXTURE(ClassFixture, "unions_of_intersections_of_classes") { - ScopedFastFlag sffs[] { + ScopedFastFlag sffs[]{ {"LuauSubtypeNormalizer", true}, {"LuauTypeNormalization2", true}, }; diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index fa99ff584..edc25c7e8 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -14,7 +14,6 @@ using namespace Luau; -LUAU_FASTFLAG(LuauLowerBoundsCalculation); LUAU_FASTFLAG(LuauInstantiateInSubtyping); LUAU_FASTFLAG(LuauSpecialTypesAsterisked); @@ -299,22 +298,6 @@ TEST_CASE_FIXTURE(Fixture, "cyclic_function_type_in_rets") CHECK_EQ("t1 where t1 = () -> t1", toString(requireType("f"))); } -TEST_CASE_FIXTURE(Fixture, "cyclic_function_type_in_args") -{ - ScopedFastFlag sff[] = { - {"LuauLowerBoundsCalculation", true}, - }; - - CheckResult result = check(R"( - function f(g) - return f(f) - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("t1 where t1 = (t1) -> (a...)", toString(requireType("f"))); -} - TEST_CASE_FIXTURE(Fixture, "another_higher_order_function") { CheckResult result = check(R"( @@ -1132,16 +1115,13 @@ f(function(x) return x * 2 end) LUAU_REQUIRE_ERROR_COUNT(1, result); CHECK_EQ("Type 'number' could not be converted into 'Table'", toString(result.errors[0])); - if (!FFlag::LuauLowerBoundsCalculation) - { - // Return type doesn't inference 'nil' - result = check(R"( - function f(a: (number) -> nil) return a(4) end - f(function(x) print(x) end) - )"); + // Return type doesn't inference 'nil' + result = check(R"( + function f(a: (number) -> nil) return a(4) end + f(function(x) print(x) end) + )"); - LUAU_REQUIRE_NO_ERRORS(result); - } + LUAU_REQUIRE_NO_ERRORS(result); } TEST_CASE_FIXTURE(BuiltinsFixture, "infer_anonymous_function_arguments") @@ -1244,16 +1224,13 @@ f(function(x) return x * 2 end) LUAU_REQUIRE_ERROR_COUNT(1, result); CHECK_EQ("Type 'number' could not be converted into 'Table'", toString(result.errors[0])); - if (!FFlag::LuauLowerBoundsCalculation) - { - // Return type doesn't inference 'nil' - result = check(R"( - function f(a: (number) -> nil) return a(4) end - f(function(x) print(x) end) - )"); + // Return type doesn't inference 'nil' + result = check(R"( + function f(a: (number) -> nil) return a(4) end + f(function(x) print(x) end) + )"); - LUAU_REQUIRE_NO_ERRORS(result); - } + LUAU_REQUIRE_NO_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "infer_anonymous_function_arguments_outside_call") @@ -1436,87 +1413,6 @@ end CHECK_EQ(toString(result.errors[1]), R"(Type 'string' could not be converted into 'number')"); } -TEST_CASE_FIXTURE(Fixture, "inconsistent_return_types") -{ - const ScopedFastFlag flags[] = { - {"LuauLowerBoundsCalculation", true}, - }; - - CheckResult result = check(R"( - function foo(a: boolean, b: number) - if a then - return nil - else - return b - end - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("(boolean, number) -> number?", toString(requireType("foo"))); - - // TODO: Test multiple returns - // Think of various cases where typepacks need to grow. maybe consult other tests - // Basic normalization of ConstrainedTypeVars during quantification -} - -TEST_CASE_FIXTURE(Fixture, "inconsistent_higher_order_function") -{ - const ScopedFastFlag flags[] = { - {"LuauLowerBoundsCalculation", true}, - }; - - CheckResult result = check(R"( - function foo(f) - f(5) - f("six") - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ("((number | string) -> (a...)) -> ()", toString(requireType("foo"))); -} - - -/* The bug here is that we are using the same level 2.0 for both the body of resolveDispatcher and the - * lambda useCallback. - * - * I think what we want to do is, at each scope level, never reuse the same sublevel. - * - * We also adjust checkBlock to consider the syntax `local x = function() ... end` to be sortable - * in the same way as `local function x() ... end`. This causes the function `resolveDispatcher` to be - * checked before the lambda. - */ -TEST_CASE_FIXTURE(Fixture, "inferred_higher_order_functions_are_quantified_at_the_right_time") -{ - ScopedFastFlag sff[] = { - {"LuauLowerBoundsCalculation", true}, - }; - - CheckResult result = check(R"( - --!strict - - local function resolveDispatcher() - return (nil :: any) :: {useCallback: (any) -> any} - end - - local useCallback = function(deps: any) - return resolveDispatcher().useCallback(deps) - end - )"); - - // LUAU_REQUIRE_NO_ERRORS is particularly unhelpful when this test is broken. - // You get a TypeMismatch error where both types stringify the same. - - CHECK(result.errors.empty()); - if (!result.errors.empty()) - { - for (const auto& e : result.errors) - printf("%s: %s\n", toString(e.location).c_str(), toString(e).c_str()); - } -} - TEST_CASE_FIXTURE(Fixture, "inferred_higher_order_functions_are_quantified_at_the_right_time2") { CheckResult result = check(R"( @@ -1700,56 +1596,6 @@ TEST_CASE_FIXTURE(Fixture, "occurs_check_failure_in_function_return_type") CHECK(nullptr != get(result.errors[0])); } -TEST_CASE_FIXTURE(Fixture, "quantify_constrained_types") -{ - ScopedFastFlag sff[]{ - {"LuauLowerBoundsCalculation", true}, - }; - - CheckResult result = check(R"( - --!strict - local function foo(f) - f(5) - f("hi") - local function g() - return f - end - local h = g() - h(true) - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ("((boolean | number | string) -> (a...)) -> ()", toString(requireType("foo"))); -} - -TEST_CASE_FIXTURE(Fixture, "call_o_with_another_argument_after_foo_was_quantified") -{ - ScopedFastFlag sff[]{ - {"LuauLowerBoundsCalculation", true}, - }; - - CheckResult result = check(R"( - local function f(o) - local t = {} - t[o] = true - - local function foo(o) - o.m1(5) - t[o] = nil - end - - o.m1("hi") - - return t - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - // TODO: check the normalized type of f -} - TEST_CASE_FIXTURE(Fixture, "free_is_not_bound_to_unknown") { CheckResult result = check(R"( @@ -1800,8 +1646,6 @@ TEST_CASE_FIXTURE(Fixture, "dont_mutate_the_underlying_head_of_typepack_when_cal TEST_CASE_FIXTURE(BuiltinsFixture, "improved_function_arg_mismatch_errors") { - ScopedFastFlag luauFunctionArgMismatchDetails{"LuauFunctionArgMismatchDetails", true}; - CheckResult result = check(R"( local function foo1(a: number) end foo1() @@ -1838,8 +1682,6 @@ u.a.foo() // This might be surprising, but since 'any' became optional, unannotated functions in non-strict 'expect' 0 arguments TEST_CASE_FIXTURE(BuiltinsFixture, "improved_function_arg_mismatch_error_nonstrict") { - ScopedFastFlag luauFunctionArgMismatchDetails{"LuauFunctionArgMismatchDetails", true}; - CheckResult result = check(R"( --!nonstrict local function foo(a, b) end diff --git a/tests/TypeInfer.generics.test.cpp b/tests/TypeInfer.generics.test.cpp index 1b02abc1b..e1729ef5f 100644 --- a/tests/TypeInfer.generics.test.cpp +++ b/tests/TypeInfer.generics.test.cpp @@ -9,7 +9,6 @@ #include "doctest.h" -LUAU_FASTFLAG(LuauCheckGenericHOFTypes) LUAU_FASTFLAG(LuauInstantiateInSubtyping) LUAU_FASTFLAG(LuauSpecialTypesAsterisked) @@ -783,8 +782,6 @@ local TheDispatcher: Dispatcher = { TEST_CASE_FIXTURE(Fixture, "generic_argument_count_too_few") { - ScopedFastFlag luauFunctionArgMismatchDetails{"LuauFunctionArgMismatchDetails", true}; - CheckResult result = check(R"( function test(a: number) return 1 @@ -802,8 +799,6 @@ wrapper(test) TEST_CASE_FIXTURE(Fixture, "generic_argument_count_too_many") { - ScopedFastFlag luauFunctionArgMismatchDetails{"LuauFunctionArgMismatchDetails", true}; - CheckResult result = check(R"( function test2(a: number, b: string) return 1 @@ -965,7 +960,6 @@ TEST_CASE_FIXTURE(Fixture, "instantiate_generic_function_in_assignments") CHECK_EQ("((a) -> (b...), a) -> (b...)", toString(tm->givenType)); else CHECK_EQ("((number) -> number, number) -> number", toString(tm->givenType)); - } TEST_CASE_FIXTURE(Fixture, "instantiate_generic_function_in_assignments2") @@ -1114,27 +1108,7 @@ local b = sumrec(sum) -- ok local c = sumrec(function(x, y, f) return f(x, y) end) -- type binders are not inferred )"); - if (FFlag::LuauCheckGenericHOFTypes) - { - LUAU_REQUIRE_NO_ERRORS(result); - } - else if (FFlag::LuauInstantiateInSubtyping) - { - LUAU_REQUIRE_ERRORS(result); - CHECK_EQ( - R"(Type '(a, b, (a, b) -> (c...)) -> (c...)' could not be converted into '(a, a, (a, a) -> a) -> a' -caused by: - Argument #1 type is not compatible. Generic subtype escaping scope)", - toString(result.errors[0])); - } - else - { - LUAU_REQUIRE_ERRORS(result); - CHECK_EQ( - "Type '(a, b, (a, b) -> (c...)) -> (c...)' could not be converted into '(a, a, (a, a) -> a) -> a'; different number of generic type " - "parameters", - toString(result.errors[0])); - } + LUAU_REQUIRE_NO_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "substitution_with_bound_table") @@ -1258,7 +1232,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "higher_rank_polymorphism_should_not_accept_i { ScopedFastFlag sffs[] = { {"LuauInstantiateInSubtyping", true}, - {"LuauCheckGenericHOFTypes", true}, // necessary because of interactions with the test }; CheckResult result = check(R"( diff --git a/tests/TypeInfer.intersectionTypes.test.cpp b/tests/TypeInfer.intersectionTypes.test.cpp index e49df1017..ca22c351b 100644 --- a/tests/TypeInfer.intersectionTypes.test.cpp +++ b/tests/TypeInfer.intersectionTypes.test.cpp @@ -8,7 +8,6 @@ using namespace Luau; -LUAU_FASTFLAG(LuauLowerBoundsCalculation); TEST_SUITE_BEGIN("IntersectionTypes"); @@ -306,10 +305,7 @@ TEST_CASE_FIXTURE(Fixture, "table_intersection_write_sealed") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - if (FFlag::LuauLowerBoundsCalculation) - CHECK_EQ(toString(result.errors[0]), "Cannot add property 'z' to table '{| x: number, y: number |}'"); - else - CHECK_EQ(toString(result.errors[0]), "Cannot add property 'z' to table 'X & Y'"); + CHECK_EQ(toString(result.errors[0]), "Cannot add property 'z' to table 'X & Y'"); } TEST_CASE_FIXTURE(Fixture, "table_intersection_write_sealed_indirect") @@ -333,16 +329,9 @@ TEST_CASE_FIXTURE(Fixture, "table_intersection_write_sealed_indirect") CHECK_EQ(toString(result.errors[0]), R"(Type '(string, number) -> string' could not be converted into '(string) -> string' caused by: Argument count mismatch. Function expects 2 arguments, but only 1 is specified)"); - if (FFlag::LuauLowerBoundsCalculation) - CHECK_EQ(toString(result.errors[1]), "Cannot add property 'z' to table '{| x: (number) -> number, y: (string) -> string |}'"); - else - CHECK_EQ(toString(result.errors[1]), "Cannot add property 'z' to table 'X & Y'"); + CHECK_EQ(toString(result.errors[1]), "Cannot add property 'z' to table 'X & Y'"); CHECK_EQ(toString(result.errors[2]), "Type 'number' could not be converted into 'string'"); - - if (FFlag::LuauLowerBoundsCalculation) - CHECK_EQ(toString(result.errors[3]), "Cannot add property 'w' to table '{| x: (number) -> number, y: (string) -> string |}'"); - else - CHECK_EQ(toString(result.errors[3]), "Cannot add property 'w' to table 'X & Y'"); + CHECK_EQ(toString(result.errors[3]), "Cannot add property 'w' to table 'X & Y'"); } TEST_CASE_FIXTURE(Fixture, "table_write_sealed_indirect") @@ -381,15 +370,11 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "table_intersection_setmetatable") TEST_CASE_FIXTURE(Fixture, "error_detailed_intersection_part") { - ScopedFastFlag flags[] = {{"LuauLowerBoundsCalculation", false}}; - CheckResult result = check(R"( type X = { x: number } type Y = { y: number } type Z = { z: number } - type XYZ = X & Y & Z - local a: XYZ = 3 )"); @@ -401,15 +386,11 @@ caused by: TEST_CASE_FIXTURE(Fixture, "error_detailed_intersection_all") { - ScopedFastFlag flags[] = {{"LuauLowerBoundsCalculation", false}}; - CheckResult result = check(R"( type X = { x: number } type Y = { y: number } type Z = { z: number } - type XYZ = X & Y & Z - local a: XYZ local b: number = a )"); @@ -468,12 +449,13 @@ TEST_CASE_FIXTURE(Fixture, "intersect_false_and_bool_and_false") LUAU_REQUIRE_ERROR_COUNT(1, result); // TODO: odd stringification of `false & (boolean & false)`.) - CHECK_EQ(toString(result.errors[0]), "Type 'boolean & false & false' could not be converted into 'true'; none of the intersection parts are compatible"); + CHECK_EQ(toString(result.errors[0]), + "Type 'boolean & false & false' could not be converted into 'true'; none of the intersection parts are compatible"); } TEST_CASE_FIXTURE(Fixture, "intersect_saturate_overloaded_functions") { - ScopedFastFlag sffs[] { + ScopedFastFlag sffs[]{ {"LuauSubtypeNormalizer", true}, {"LuauTypeNormalization2", true}, }; @@ -485,12 +467,13 @@ TEST_CASE_FIXTURE(Fixture, "intersect_saturate_overloaded_functions") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), "Type '((number?) -> number?) & ((string?) -> string?)' could not be converted into '(number) -> number'; none of the intersection parts are compatible"); + CHECK_EQ(toString(result.errors[0]), "Type '((number?) -> number?) & ((string?) -> string?)' could not be converted into '(number) -> number'; " + "none of the intersection parts are compatible"); } TEST_CASE_FIXTURE(Fixture, "union_saturate_overloaded_functions") { - ScopedFastFlag sffs[] { + ScopedFastFlag sffs[]{ {"LuauSubtypeNormalizer", true}, {"LuauTypeNormalization2", true}, }; @@ -502,12 +485,13 @@ TEST_CASE_FIXTURE(Fixture, "union_saturate_overloaded_functions") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), "Type '((number) -> number) & ((string) -> string)' could not be converted into '(boolean | number) -> boolean | number'; none of the intersection parts are compatible"); + CHECK_EQ(toString(result.errors[0]), "Type '((number) -> number) & ((string) -> string)' could not be converted into '(boolean | number) -> " + "boolean | number'; none of the intersection parts are compatible"); } TEST_CASE_FIXTURE(Fixture, "intersection_of_tables") { - ScopedFastFlag sffs[] { + ScopedFastFlag sffs[]{ {"LuauSubtypeNormalizer", true}, {"LuauTypeNormalization2", true}, }; @@ -519,7 +503,8 @@ TEST_CASE_FIXTURE(Fixture, "intersection_of_tables") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), "Type '{| p: number?, q: number?, r: number? |} & {| p: number?, q: string? |}' could not be converted into '{| p: nil |}'; none of the intersection parts are compatible"); + CHECK_EQ(toString(result.errors[0]), "Type '{| p: number?, q: number?, r: number? |} & {| p: number?, q: string? |}' could not be converted into " + "'{| p: nil |}'; none of the intersection parts are compatible"); } TEST_CASE_FIXTURE(Fixture, "intersection_of_tables_with_top_properties") @@ -531,12 +516,13 @@ TEST_CASE_FIXTURE(Fixture, "intersection_of_tables_with_top_properties") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), "Type '{| p: number?, q: any |} & {| p: unknown, q: string? |}' could not be converted into '{| p: string?, q: number? |}'; none of the intersection parts are compatible"); + CHECK_EQ(toString(result.errors[0]), "Type '{| p: number?, q: any |} & {| p: unknown, q: string? |}' could not be converted into '{| p: string?, " + "q: number? |}'; none of the intersection parts are compatible"); } TEST_CASE_FIXTURE(Fixture, "intersection_of_tables_with_never_properties") { - ScopedFastFlag sffs[] { + ScopedFastFlag sffs[]{ {"LuauSubtypeNormalizer", true}, {"LuauTypeNormalization2", true}, }; @@ -549,12 +535,13 @@ TEST_CASE_FIXTURE(Fixture, "intersection_of_tables_with_never_properties") // TODO: this should not produce type errors, since never <: { p : never } LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), "Type '{| p: never, q: string? |} & {| p: number?, q: never |}' could not be converted into 'never'; none of the intersection parts are compatible"); + CHECK_EQ(toString(result.errors[0]), "Type '{| p: never, q: string? |} & {| p: number?, q: never |}' could not be converted into 'never'; none " + "of the intersection parts are compatible"); } TEST_CASE_FIXTURE(Fixture, "overloaded_functions_returning_intersections") { - ScopedFastFlag sffs[] { + ScopedFastFlag sffs[]{ {"LuauSubtypeNormalizer", true}, {"LuauTypeNormalization2", true}, }; @@ -566,12 +553,14 @@ TEST_CASE_FIXTURE(Fixture, "overloaded_functions_returning_intersections") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), "Type '((number?) -> {| p: number |} & {| q: number |}) & ((string?) -> {| p: number |} & {| r: number |})' could not be converted into '(number?) -> {| p: number, q: number, r: number |}'; none of the intersection parts are compatible"); + CHECK_EQ(toString(result.errors[0]), + "Type '((number?) -> {| p: number |} & {| q: number |}) & ((string?) -> {| p: number |} & {| r: number |})' could not be converted into " + "'(number?) -> {| p: number, q: number, r: number |}'; none of the intersection parts are compatible"); } TEST_CASE_FIXTURE(Fixture, "overloaded_functions_mentioning_generic") { - ScopedFastFlag sffs[] { + ScopedFastFlag sffs[]{ {"LuauSubtypeNormalizer", true}, {"LuauTypeNormalization2", true}, }; @@ -585,12 +574,13 @@ TEST_CASE_FIXTURE(Fixture, "overloaded_functions_mentioning_generic") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), "Type '((number?) -> a | number) & ((string?) -> a | string)' could not be converted into '(number?) -> a'; none of the intersection parts are compatible"); + CHECK_EQ(toString(result.errors[0]), "Type '((number?) -> a | number) & ((string?) -> a | string)' could not be converted into '(number?) -> a'; " + "none of the intersection parts are compatible"); } TEST_CASE_FIXTURE(Fixture, "overloaded_functions_mentioning_generics") { - ScopedFastFlag sffs[] { + ScopedFastFlag sffs[]{ {"LuauSubtypeNormalizer", true}, {"LuauTypeNormalization2", true}, }; @@ -604,12 +594,13 @@ TEST_CASE_FIXTURE(Fixture, "overloaded_functions_mentioning_generics") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), "Type '((a?) -> a | b) & ((c?) -> b | c)' could not be converted into '(a?) -> (a & c) | b'; none of the intersection parts are compatible"); + CHECK_EQ(toString(result.errors[0]), + "Type '((a?) -> a | b) & ((c?) -> b | c)' could not be converted into '(a?) -> (a & c) | b'; none of the intersection parts are compatible"); } TEST_CASE_FIXTURE(Fixture, "overloaded_functions_mentioning_generic_packs") { - ScopedFastFlag sffs[] { + ScopedFastFlag sffs[]{ {"LuauSubtypeNormalizer", true}, {"LuauTypeNormalization2", true}, }; @@ -623,12 +614,13 @@ TEST_CASE_FIXTURE(Fixture, "overloaded_functions_mentioning_generic_packs") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), "Type '((number?, a...) -> (number?, b...)) & ((string?, a...) -> (string?, b...))' could not be converted into '(nil, b...) -> (nil, a...)'; none of the intersection parts are compatible"); + CHECK_EQ(toString(result.errors[0]), "Type '((number?, a...) -> (number?, b...)) & ((string?, a...) -> (string?, b...))' could not be converted " + "into '(nil, b...) -> (nil, a...)'; none of the intersection parts are compatible"); } TEST_CASE_FIXTURE(Fixture, "overloadeded_functions_with_unknown_result") { - ScopedFastFlag sffs[] { + ScopedFastFlag sffs[]{ {"LuauSubtypeNormalizer", true}, {"LuauTypeNormalization2", true}, }; @@ -642,12 +634,13 @@ TEST_CASE_FIXTURE(Fixture, "overloadeded_functions_with_unknown_result") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), "Type '((nil) -> unknown) & ((number) -> number)' could not be converted into '(number?) -> number?'; none of the intersection parts are compatible"); + CHECK_EQ(toString(result.errors[0]), "Type '((nil) -> unknown) & ((number) -> number)' could not be converted into '(number?) -> number?'; none " + "of the intersection parts are compatible"); } TEST_CASE_FIXTURE(Fixture, "overloadeded_functions_with_unknown_arguments") { - ScopedFastFlag sffs[] { + ScopedFastFlag sffs[]{ {"LuauSubtypeNormalizer", true}, {"LuauTypeNormalization2", true}, }; @@ -661,12 +654,13 @@ TEST_CASE_FIXTURE(Fixture, "overloadeded_functions_with_unknown_arguments") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), "Type '((number) -> number?) & ((unknown) -> string?)' could not be converted into '(number?) -> nil'; none of the intersection parts are compatible"); + CHECK_EQ(toString(result.errors[0]), "Type '((number) -> number?) & ((unknown) -> string?)' could not be converted into '(number?) -> nil'; none " + "of the intersection parts are compatible"); } TEST_CASE_FIXTURE(Fixture, "overloadeded_functions_with_never_result") { - ScopedFastFlag sffs[] { + ScopedFastFlag sffs[]{ {"LuauSubtypeNormalizer", true}, {"LuauTypeNormalization2", true}, }; @@ -680,12 +674,13 @@ TEST_CASE_FIXTURE(Fixture, "overloadeded_functions_with_never_result") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), "Type '((nil) -> never) & ((number) -> number)' could not be converted into '(number?) -> never'; none of the intersection parts are compatible"); + CHECK_EQ(toString(result.errors[0]), "Type '((nil) -> never) & ((number) -> number)' could not be converted into '(number?) -> never'; none of " + "the intersection parts are compatible"); } TEST_CASE_FIXTURE(Fixture, "overloadeded_functions_with_never_arguments") { - ScopedFastFlag sffs[] { + ScopedFastFlag sffs[]{ {"LuauSubtypeNormalizer", true}, {"LuauTypeNormalization2", true}, }; @@ -699,7 +694,8 @@ TEST_CASE_FIXTURE(Fixture, "overloadeded_functions_with_never_arguments") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), "Type '((never) -> string?) & ((number) -> number?)' could not be converted into '(number?) -> nil'; none of the intersection parts are compatible"); + CHECK_EQ(toString(result.errors[0]), "Type '((never) -> string?) & ((number) -> number?)' could not be converted into '(number?) -> nil'; none " + "of the intersection parts are compatible"); } TEST_CASE_FIXTURE(Fixture, "overloadeded_functions_with_overlapping_results_and_variadics") @@ -711,7 +707,8 @@ TEST_CASE_FIXTURE(Fixture, "overloadeded_functions_with_overlapping_results_and_ )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), "Type '((number?) -> (...number)) & ((string?) -> number | string)' could not be converted into '(number | string) -> (number, number?)'; none of the intersection parts are compatible"); + CHECK_EQ(toString(result.errors[0]), "Type '((number?) -> (...number)) & ((string?) -> number | string)' could not be converted into '(number | " + "string) -> (number, number?)'; none of the intersection parts are compatible"); } TEST_CASE_FIXTURE(Fixture, "overloadeded_functions_with_weird_typepacks_1") @@ -725,7 +722,8 @@ TEST_CASE_FIXTURE(Fixture, "overloadeded_functions_with_weird_typepacks_1") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), "Type '(() -> (a...)) & (() -> (b...))' could not be converted into '() -> ()'; none of the intersection parts are compatible"); + CHECK_EQ(toString(result.errors[0]), + "Type '(() -> (a...)) & (() -> (b...))' could not be converted into '() -> ()'; none of the intersection parts are compatible"); } TEST_CASE_FIXTURE(Fixture, "overloadeded_functions_with_weird_typepacks_2") @@ -739,7 +737,8 @@ TEST_CASE_FIXTURE(Fixture, "overloadeded_functions_with_weird_typepacks_2") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), "Type '((a...) -> ()) & ((b...) -> ())' could not be converted into '() -> ()'; none of the intersection parts are compatible"); + CHECK_EQ(toString(result.errors[0]), + "Type '((a...) -> ()) & ((b...) -> ())' could not be converted into '() -> ()'; none of the intersection parts are compatible"); } TEST_CASE_FIXTURE(Fixture, "overloadeded_functions_with_weird_typepacks_3") @@ -753,7 +752,8 @@ TEST_CASE_FIXTURE(Fixture, "overloadeded_functions_with_weird_typepacks_3") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), "Type '(() -> (a...)) & (() -> (number?, a...))' could not be converted into '() -> number'; none of the intersection parts are compatible"); + CHECK_EQ(toString(result.errors[0]), + "Type '(() -> (a...)) & (() -> (number?, a...))' could not be converted into '() -> number'; none of the intersection parts are compatible"); } TEST_CASE_FIXTURE(Fixture, "overloadeded_functions_with_weird_typepacks_4") @@ -767,12 +767,13 @@ TEST_CASE_FIXTURE(Fixture, "overloadeded_functions_with_weird_typepacks_4") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), "Type '((a...) -> ()) & ((number, a...) -> number)' could not be converted into '(number?) -> ()'; none of the intersection parts are compatible"); + CHECK_EQ(toString(result.errors[0]), "Type '((a...) -> ()) & ((number, a...) -> number)' could not be converted into '(number?) -> ()'; none of " + "the intersection parts are compatible"); } TEST_CASE_FIXTURE(BuiltinsFixture, "intersect_metatables") { - ScopedFastFlag sffs[] { + ScopedFastFlag sffs[]{ {"LuauSubtypeNormalizer", true}, {"LuauTypeNormalization2", true}, }; @@ -800,7 +801,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "intersect_metatables") TEST_CASE_FIXTURE(BuiltinsFixture, "intersect_metatable_subtypes") { - ScopedFastFlag sffs[] { + ScopedFastFlag sffs[]{ {"LuauSubtypeNormalizer", true}, {"LuauTypeNormalization2", true}, }; @@ -826,7 +827,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "intersect_metatable_subtypes") TEST_CASE_FIXTURE(BuiltinsFixture, "intersect_metatables_with_properties") { - ScopedFastFlag sffs[] { + ScopedFastFlag sffs[]{ {"LuauSubtypeNormalizer", true}, {"LuauTypeNormalization2", true}, }; @@ -849,7 +850,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "intersect_metatables_with_properties") TEST_CASE_FIXTURE(BuiltinsFixture, "intersect_metatable_with table") { - ScopedFastFlag sffs[] { + ScopedFastFlag sffs[]{ {"LuauSubtypeNormalizer", true}, {"LuauTypeNormalization2", true}, }; @@ -874,7 +875,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "intersect_metatable_with table") TEST_CASE_FIXTURE(Fixture, "CLI-44817") { - ScopedFastFlag sffs[] { + ScopedFastFlag sffs[]{ {"LuauSubtypeNormalizer", true}, {"LuauTypeNormalization2", true}, }; diff --git a/tests/TypeInfer.loops.test.cpp b/tests/TypeInfer.loops.test.cpp index 588a9a763..d6f787bec 100644 --- a/tests/TypeInfer.loops.test.cpp +++ b/tests/TypeInfer.loops.test.cpp @@ -622,9 +622,9 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "loop_iter_metamethod_not_enough_returns") LUAU_REQUIRE_ERROR_COUNT(1, result); CHECK(result.errors[0] == TypeError{ - Location{{2, 36}, {2, 37}}, - GenericError{"__iter must return at least one value"}, - }); + Location{{2, 36}, {2, 37}}, + GenericError{"__iter must return at least one value"}, + }); } TEST_CASE_FIXTURE(BuiltinsFixture, "loop_iter_metamethod_ok") diff --git a/tests/TypeInfer.modules.test.cpp b/tests/TypeInfer.modules.test.cpp index 36943cac8..8b7b35141 100644 --- a/tests/TypeInfer.modules.test.cpp +++ b/tests/TypeInfer.modules.test.cpp @@ -254,20 +254,20 @@ return m if (FFlag::LuauInstantiateInSubtyping) { // though this didn't error before the flag, it seems as though it should error since fields of a table are invariant. - // the user's intent would likely be that these "method" fields would be read-only, but without an annotation, accepting this should be unsound. + // the user's intent would likely be that these "method" fields would be read-only, but without an annotation, accepting this should be + // unsound. LUAU_REQUIRE_ERROR_COUNT(1, result); CHECK_EQ(R"(Type 'n' could not be converted into 't1 where t1 = {- Clone: (t1) -> (a...) -}' caused by: Property 'Clone' is not compatible. Type '(a) -> ()' could not be converted into 't1 where t1 = ({- Clone: t1 -}) -> (a...)'; different number of generic type parameters)", - toString(result.errors[0])); + toString(result.errors[0])); } else { LUAU_REQUIRE_NO_ERRORS(result); } - } TEST_CASE_FIXTURE(BuiltinsFixture, "custom_require_global") diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index 2aac6653f..ccc4d775a 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -7,8 +7,6 @@ #include -LUAU_FASTFLAG(LuauLowerBoundsCalculation) - using namespace Luau; TEST_SUITE_BEGIN("ProvisionalTests"); @@ -301,19 +299,10 @@ TEST_CASE_FIXTURE(Fixture, "do_not_ice_when_trying_to_pick_first_of_generic_type LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::LuauLowerBoundsCalculation) - { - CHECK_EQ("() -> ()", toString(requireType("f"))); - CHECK_EQ("() -> ()", toString(requireType("g"))); - CHECK_EQ("nil", toString(requireType("x"))); - } - else - { - // f and g should have the type () -> () - CHECK_EQ("() -> (a...)", toString(requireType("f"))); - CHECK_EQ("() -> (a...)", toString(requireType("g"))); - CHECK_EQ("any", toString(requireType("x"))); // any is returned instead of ICE for now - } + // f and g should have the type () -> () + CHECK_EQ("() -> (a...)", toString(requireType("f"))); + CHECK_EQ("() -> (a...)", toString(requireType("g"))); + CHECK_EQ("any", toString(requireType("x"))); // any is returned instead of ICE for now } TEST_CASE_FIXTURE(Fixture, "specialization_binds_with_prototypes_too_early") @@ -330,7 +319,6 @@ TEST_CASE_FIXTURE(Fixture, "specialization_binds_with_prototypes_too_early") TEST_CASE_FIXTURE(Fixture, "weird_fail_to_unify_type_pack") { ScopedFastFlag sff[] = { - {"LuauLowerBoundsCalculation", false}, // I'm not sure why this is broken without DCR, but it seems to be fixed // when DCR is enabled. {"DebugLuauDeferredConstraintResolution", false}, @@ -347,7 +335,6 @@ TEST_CASE_FIXTURE(Fixture, "weird_fail_to_unify_type_pack") TEST_CASE_FIXTURE(Fixture, "weird_fail_to_unify_variadic_pack") { ScopedFastFlag sff[] = { - {"LuauLowerBoundsCalculation", false}, // I'm not sure why this is broken without DCR, but it seems to be fixed // when DCR is enabled. {"DebugLuauDeferredConstraintResolution", false}, @@ -362,56 +349,6 @@ TEST_CASE_FIXTURE(Fixture, "weird_fail_to_unify_variadic_pack") LUAU_REQUIRE_ERRORS(result); // Should not have any errors. } -TEST_CASE_FIXTURE(Fixture, "lower_bounds_calculation_is_too_permissive_with_overloaded_higher_order_functions") -{ - ScopedFastFlag sff[] = { - {"LuauLowerBoundsCalculation", true}, - }; - - CheckResult result = check(R"( - function foo(f) - f(5, 'a') - f('b', 6) - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - // We incorrectly infer that the argument to foo could be called with (number, number) or (string, string) - // even though that is strictly more permissive than the actual source text shows. - CHECK("((number | string, number | string) -> (a...)) -> ()" == toString(requireType("foo"))); -} - -// Once fixed, move this to Normalize.test.cpp -TEST_CASE_FIXTURE(Fixture, "normalization_fails_on_certain_kinds_of_cyclic_tables") -{ -#if defined(_DEBUG) || defined(_NOOPT) - ScopedFastInt sfi("LuauNormalizeIterationLimit", 500); -#endif - - ScopedFastFlag flags[] = { - {"LuauLowerBoundsCalculation", true}, - }; - - // We use a function and inferred parameter types to prevent intermediate normalizations from being performed. - // This exposes a bug where the type of y is mutated. - CheckResult result = check(R"( - function strange(x, y) - x.x = y - y.x = x - - type R = {x: typeof(x)} & {x: typeof(y)} - local r: R - - return r - end - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - CHECK(nullptr != get(result.errors[0])); -} - // Belongs in TypeInfer.builtins.test.cpp. TEST_CASE_FIXTURE(BuiltinsFixture, "pcall_returns_at_least_two_value_but_function_returns_nothing") { @@ -473,36 +410,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "function_returns_many_things_but_first_of_it CHECK_EQ("boolean", toString(requireType("b"))); } -TEST_CASE_FIXTURE(Fixture, "constrained_is_level_dependent") -{ - ScopedFastFlag sff[]{ - {"LuauLowerBoundsCalculation", true}, - }; - - CheckResult result = check(R"( - local function f(o) - local t = {} - t[o] = true - - local function foo(o) - o:m1() - t[o] = nil - end - - local function bar(o) - o:m2() - t[o] = true - end - - return t - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - // TODO: We're missing generics b... - CHECK_EQ("(t1) -> {| [t1]: boolean |} where t1 = t2 ; t2 = {+ m1: (t1) -> (a...), m2: (t2) -> (b...) +}", toString(requireType("f"))); -} - TEST_CASE_FIXTURE(Fixture, "free_is_not_bound_to_any") { CheckResult result = check(R"( @@ -695,4 +602,187 @@ return wrapStrictTable(Constants, "Constants") CHECK(get(*result)); } +// We need a simplification step to make this do the right thing. ("normalization-lite") +TEST_CASE_FIXTURE(BuiltinsFixture, "table_insert_with_a_singleton_argument") +{ + CheckResult result = check(R"( + local function foo(t, x) + if x == "hi" or x == "bye" then + table.insert(t, x) + end + + return t + end + + local t = foo({}, "hi") + table.insert(t, "totally_unrelated_type" :: "totally_unrelated_type") + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + // We'd really like for this to be {string} + CHECK_EQ("{string | string}", toString(requireType("t"))); +} + +struct NormalizeFixture : Fixture +{ + bool isSubtype(TypeId a, TypeId b) + { + return ::Luau::isSubtype(a, b, NotNull{getMainModule()->getModuleScope().get()}, singletonTypes, ice); + } +}; + +TEST_CASE_FIXTURE(NormalizeFixture, "intersection_of_functions_of_different_arities") +{ + check(R"( + type A = (any) -> () + type B = (any, any) -> () + type T = A & B + + local a: A + local b: B + local t: T + )"); + + [[maybe_unused]] TypeId a = requireType("a"); + [[maybe_unused]] TypeId b = requireType("b"); + + // CHECK(!isSubtype(a, b)); // !! + // CHECK(!isSubtype(b, a)); + + CHECK("((any) -> ()) & ((any, any) -> ())" == toString(requireType("t"))); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "functions_with_mismatching_arity") +{ + check(R"( + local a: (number) -> () + local b: () -> () + + local c: () -> number + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + TypeId c = requireType("c"); + + // CHECK(!isSubtype(b, a)); + // CHECK(!isSubtype(c, a)); + + CHECK(!isSubtype(a, b)); + // CHECK(!isSubtype(c, b)); + + CHECK(!isSubtype(a, c)); + CHECK(!isSubtype(b, c)); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "functions_with_mismatching_arity_but_optional_parameters") +{ + /* + * (T0..TN) <: (T0..TN, A?) + * (T0..TN) <: (T0..TN, any) + * (T0..TN, A?) R <: U -> S if U <: T and R <: S + * A | B <: T if A <: T and B <: T + * T <: A | B if T <: A or T <: B + */ + check(R"( + local a: (number?) -> () + local b: (number) -> () + local c: (number, number?) -> () + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + TypeId c = requireType("c"); + + /* + * (number) -> () () + * because number? () + * because number? () <: (number) -> () + * because number <: number? (because number <: number) + */ + CHECK(isSubtype(a, b)); + + /* + * (number, number?) -> () <: (number) -> (number) + * The packs have inequal lengths, but (number) <: (number, number?) + * and number <: number + */ + // CHECK(!isSubtype(c, b)); + + /* + * (number?) -> () () + * because (number, number?) () () + * because (number, number?) () + local b: (number) -> () + local c: (number, any) -> () + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + TypeId c = requireType("c"); + + /* + * (number) -> () () + * because number? () + * because number? () <: (number) -> () + * because number <: number? (because number <: number) + */ + CHECK(isSubtype(a, b)); + + /* + * (number, any) -> () (number) + * The packs have inequal lengths + */ + // CHECK(!isSubtype(c, b)); + + /* + * (number?) -> () () + * The packs have inequal lengths + */ + // CHECK(!isSubtype(a, c)); + + /* + * (number) -> () () + * The packs have inequal lengths + */ + // CHECK(!isSubtype(b, c)); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index b6dedcbd3..f707f9522 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -7,7 +7,6 @@ #include "doctest.h" -LUAU_FASTFLAG(LuauLowerBoundsCalculation) LUAU_FASTFLAG(LuauSpecialTypesAsterisked) using namespace Luau; @@ -608,10 +607,7 @@ TEST_CASE_FIXTURE(Fixture, "type_guard_can_filter_for_intersection_of_tables") LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::LuauLowerBoundsCalculation) - CHECK_EQ("{| x: number, y: number |}", toString(requireTypeAtPosition({4, 28}))); - else - CHECK_EQ("{| x: number |} & {| y: number |}", toString(requireTypeAtPosition({4, 28}))); + CHECK_EQ("{| x: number |} & {| y: number |}", toString(requireTypeAtPosition({4, 28}))); CHECK_EQ("nil", toString(requireTypeAtPosition({6, 28}))); } diff --git a/tests/TypeInfer.singletons.test.cpp b/tests/TypeInfer.singletons.test.cpp index 5ee956d7b..73ccac701 100644 --- a/tests/TypeInfer.singletons.test.cpp +++ b/tests/TypeInfer.singletons.test.cpp @@ -421,28 +421,6 @@ TEST_CASE_FIXTURE(Fixture, "widening_happens_almost_everywhere_except_for_tables LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(BuiltinsFixture, "table_insert_with_a_singleton_argument") -{ - ScopedFastFlag sff{"LuauLowerBoundsCalculation", true}; - - CheckResult result = check(R"( - local function foo(t, x) - if x == "hi" or x == "bye" then - table.insert(t, x) - end - - return t - end - - local t = foo({}, "hi") - table.insert(t, "totally_unrelated_type" :: "totally_unrelated_type") - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ("{string}", toString(requireType("t"))); -} - TEST_CASE_FIXTURE(Fixture, "functions_are_not_to_be_widened") { CheckResult result = check(R"( diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index a6d870fad..53f9a1abb 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -11,7 +11,6 @@ using namespace Luau; -LUAU_FASTFLAG(LuauLowerBoundsCalculation) LUAU_FASTFLAG(LuauInstantiateInSubtyping) TEST_SUITE_BEGIN("TableTests"); @@ -1196,10 +1195,7 @@ TEST_CASE_FIXTURE(Fixture, "pass_incompatible_union_to_a_generic_table_without_c )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - if (FFlag::LuauLowerBoundsCalculation) - CHECK(get(result.errors[0])); - else - CHECK(get(result.errors[0])); + CHECK(get(result.errors[0])); } // This unit test could be flaky if the fix has regressed. @@ -2627,8 +2623,6 @@ do end TEST_CASE_FIXTURE(BuiltinsFixture, "dont_crash_when_setmetatable_does_not_produce_a_metatabletypevar") { - ScopedFastFlag luauFunctionArgMismatchDetails{"LuauFunctionArgMismatchDetails", true}; - CheckResult result = check("local x = setmetatable({})"); LUAU_REQUIRE_ERROR_COUNT(1, result); CHECK_EQ("Argument count mismatch. Function 'setmetatable' expects 2 arguments, but only 1 is specified", toString(result.errors[0])); @@ -2709,8 +2703,6 @@ local baz = foo[bar] TEST_CASE_FIXTURE(BuiltinsFixture, "table_simple_call") { - ScopedFastFlag luauFunctionArgMismatchDetails{"LuauFunctionArgMismatchDetails", true}; - CheckResult result = check(R"( local a = setmetatable({ x = 2 }, { __call = function(self) @@ -2887,7 +2879,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "dont_leak_free_table_props") TEST_CASE_FIXTURE(Fixture, "inferred_return_type_of_free_table") { ScopedFastFlag sff[] = { - {"LuauLowerBoundsCalculation", true}, + // {"LuauLowerBoundsCalculation", true}, {"DebugLuauSharedSelf", true}, }; diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index 26171c518..239b8c28f 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -14,12 +14,10 @@ #include -LUAU_FASTFLAG(LuauLowerBoundsCalculation); LUAU_FASTFLAG(LuauFixLocationSpanTableIndexExpr); LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); LUAU_FASTFLAG(LuauInstantiateInSubtyping); LUAU_FASTFLAG(LuauSpecialTypesAsterisked); -LUAU_FASTFLAG(LuauCheckGenericHOFTypes); using namespace Luau; @@ -89,7 +87,6 @@ TEST_CASE_FIXTURE(Fixture, "infer_in_nocheck_mode") { ScopedFastFlag sff[]{ {"DebugLuauDeferredConstraintResolution", false}, - {"LuauLowerBoundsCalculation", true}, }; CheckResult result = check(R"( @@ -1001,21 +998,23 @@ TEST_CASE_FIXTURE(Fixture, "cli_50041_committing_txnlog_in_apollo_client_error") end )"); - if (FFlag::LuauInstantiateInSubtyping && !FFlag::LuauCheckGenericHOFTypes) + if (FFlag::LuauInstantiateInSubtyping) { // though this didn't error before the flag, it seems as though it should error since fields of a table are invariant. - // the user's intent would likely be that these "method" fields would be read-only, but without an annotation, accepting this should be unsound. + // the user's intent would likely be that these "method" fields would be read-only, but without an annotation, accepting this should be + // unsound. LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(R"(Type 't1 where t1 = {+ getStoreFieldName: (t1, {| fieldName: string |} & {| from: number? |}) -> (a, b...) +}' could not be converted into 'Policies' + CHECK_EQ( + R"(Type 't1 where t1 = {+ getStoreFieldName: (t1, {| fieldName: string |} & {| from: number? |}) -> (a, b...) +}' could not be converted into 'Policies' caused by: Property 'getStoreFieldName' is not compatible. Type 't1 where t1 = ({+ getStoreFieldName: t1 +}, {| fieldName: string |} & {| from: number? |}) -> (a, b...)' could not be converted into '(Policies, FieldSpecifier) -> string' caused by: Argument #2 type is not compatible. Type 'FieldSpecifier' could not be converted into 'FieldSpecifier & {| from: number? |}' caused by: Not all intersection parts are compatible. Table type 'FieldSpecifier' not compatible with type '{| from: number? |}' because the former has extra field 'fieldName')", - toString(result.errors[0])); + toString(result.errors[0])); } else { @@ -1044,7 +1043,7 @@ TEST_CASE_FIXTURE(Fixture, "type_infer_recursion_limit_no_ice") TEST_CASE_FIXTURE(Fixture, "type_infer_recursion_limit_normalizer") { ScopedFastInt sfi("LuauTypeInferRecursionLimit", 10); - ScopedFastFlag sffs[] { + ScopedFastFlag sffs[]{ {"LuauSubtypeNormalizer", true}, {"LuauTypeNormalization2", true}, {"LuauAutocompleteDynamicLimits", true}, @@ -1057,14 +1056,14 @@ TEST_CASE_FIXTURE(Fixture, "type_infer_recursion_limit_normalizer") end )"); - LUAU_REQUIRE_ERRORS(result); + LUAU_REQUIRE_ERROR_COUNT(1, result); CHECK_EQ("Internal error: Code is too complex to typecheck! Consider adding type annotations around this area", toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "type_infer_cache_limit_normalizer") { ScopedFastInt sfi("LuauNormalizeCacheLimit", 10); - ScopedFastFlag sffs[] { + ScopedFastFlag sffs[]{ {"LuauSubtypeNormalizer", true}, {"LuauTypeNormalization2", true}, }; @@ -1101,45 +1100,6 @@ TEST_CASE_FIXTURE(Fixture, "follow_on_new_types_in_substitution") LUAU_REQUIRE_NO_ERRORS(result); } -/** - * The problem we had here was that the type of q in B.h was initially inferring to {} | {prop: free} before we bound - * that second table to the enclosing union. - */ -TEST_CASE_FIXTURE(Fixture, "do_not_bind_a_free_table_to_a_union_containing_that_table") -{ - ScopedFastFlag flag[] = { - {"LuauLowerBoundsCalculation", true}, - }; - - CheckResult result = check(R"( - --!strict - - local A = {} - - function A:f() - local t = {} - - for key, value in pairs(self) do - t[key] = value - end - - return t - end - - local B = A:f() - - function B.g(t) - assert(type(t) == "table") - assert(t.prop ~= nil) - end - - function B.h(q) - q = q or {} - return q or {} - end - )"); -} - TEST_CASE_FIXTURE(Fixture, "types_stored_in_astResolvedTypes") { CheckResult result = check(R"( diff --git a/tests/TypeInfer.tryUnify.test.cpp b/tests/TypeInfer.tryUnify.test.cpp index dedb7d28a..c178d2a4e 100644 --- a/tests/TypeInfer.tryUnify.test.cpp +++ b/tests/TypeInfer.tryUnify.test.cpp @@ -302,8 +302,8 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "metatables_unify_against_shape_of_free_table REQUIRE_EQ(state.errors.size(), 1); std::string expected = "Type '{ @metatable {| __index: {| foo: string |} |}, { } }' could not be converted into '{- foo: number -}'\n" - "caused by:\n" - " Type 'number' could not be converted into 'string'"; + "caused by:\n" + " Type 'number' could not be converted into 'string'"; CHECK_EQ(toString(state.errors[0]), expected); } diff --git a/tests/TypeInfer.typePacks.cpp b/tests/TypeInfer.typePacks.cpp index eb61c396b..aaa7ded44 100644 --- a/tests/TypeInfer.typePacks.cpp +++ b/tests/TypeInfer.typePacks.cpp @@ -7,9 +7,9 @@ #include "doctest.h" -using namespace Luau; +LUAU_FASTFLAG(LuauFunctionReturnStringificationFixup); -LUAU_FASTFLAG(LuauLowerBoundsCalculation); +using namespace Luau; TEST_SUITE_BEGIN("TypePackTests"); @@ -311,7 +311,7 @@ local c: Packed auto ttvA = get(requireType("a")); REQUIRE(ttvA); CHECK_EQ(toString(requireType("a")), "Packed"); - if (FFlag::LuauLowerBoundsCalculation) + if (FFlag::LuauFunctionReturnStringificationFixup) CHECK_EQ(toString(requireType("a"), {true}), "{| f: (number) -> number |}"); else CHECK_EQ(toString(requireType("a"), {true}), "{| f: (number) -> (number) |}"); @@ -966,8 +966,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "detect_cyclic_typepacks2") TEST_CASE_FIXTURE(Fixture, "unify_variadic_tails_in_arguments") { - ScopedFastFlag luauCallUnifyPackTails{"LuauCallUnifyPackTails", true}; - CheckResult result = check(R"( function foo(...: string): number return 1 @@ -984,8 +982,6 @@ TEST_CASE_FIXTURE(Fixture, "unify_variadic_tails_in_arguments") TEST_CASE_FIXTURE(Fixture, "unify_variadic_tails_in_arguments_free") { - ScopedFastFlag luauCallUnifyPackTails{"LuauCallUnifyPackTails", true}; - CheckResult result = check(R"( function foo(...: T...): T... return ... diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp index 64c9b5630..dc5516345 100644 --- a/tests/TypeInfer.unionTypes.test.cpp +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -6,7 +6,6 @@ #include "doctest.h" -LUAU_FASTFLAG(LuauLowerBoundsCalculation) LUAU_FASTFLAG(LuauSpecialTypesAsterisked) using namespace Luau; @@ -360,10 +359,7 @@ a.x = 2 )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - if (FFlag::LuauLowerBoundsCalculation) - CHECK_EQ("Value of type '{| x: number, y: number |}?' could be nil", toString(result.errors[0])); - else - CHECK_EQ("Value of type '({| x: number |} & {| y: number |})?' could be nil", toString(result.errors[0])); + CHECK_EQ("Value of type '({| x: number |} & {| y: number |})?' could be nil", toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "optional_length_error") @@ -532,18 +528,13 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "table_union_write_indirect") LUAU_REQUIRE_ERROR_COUNT(1, result); // NOTE: union normalization will improve this message - if (FFlag::LuauLowerBoundsCalculation) - CHECK_EQ(toString(result.errors[0]), "Type '(string) -> number' could not be converted into '(number) -> string'\n" - "caused by:\n" - " Argument #1 type is not compatible. Type 'number' could not be converted into 'string'"); - else - CHECK_EQ(toString(result.errors[0]), - R"(Type '(string) -> number' could not be converted into '((number) -> string) | ((number) -> string)'; none of the union options are compatible)"); + CHECK_EQ(toString(result.errors[0]), + R"(Type '(string) -> number' could not be converted into '((number) -> string) | ((number) -> string)'; none of the union options are compatible)"); } TEST_CASE_FIXTURE(Fixture, "union_true_and_false") { - ScopedFastFlag sffs[] { + ScopedFastFlag sffs[]{ {"LuauSubtypeNormalizer", true}, {"LuauTypeNormalization2", true}, }; @@ -561,7 +552,7 @@ TEST_CASE_FIXTURE(Fixture, "union_true_and_false") TEST_CASE_FIXTURE(Fixture, "union_of_functions") { - ScopedFastFlag sffs[] { + ScopedFastFlag sffs[]{ {"LuauSubtypeNormalizer", true}, {"LuauTypeNormalization2", true}, }; @@ -598,7 +589,7 @@ TEST_CASE_FIXTURE(Fixture, "union_of_generic_typepack_functions") TEST_CASE_FIXTURE(Fixture, "union_of_functions_mentioning_generics") { - ScopedFastFlag sffs[] { + ScopedFastFlag sffs[]{ {"LuauSubtypeNormalizer", true}, {"LuauTypeNormalization2", true}, }; @@ -612,12 +603,13 @@ TEST_CASE_FIXTURE(Fixture, "union_of_functions_mentioning_generics") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), "Type '(a) -> a?' could not be converted into '((b) -> b) | ((b?) -> nil)'; none of the union options are compatible"); + CHECK_EQ(toString(result.errors[0]), + "Type '(a) -> a?' could not be converted into '((b) -> b) | ((b?) -> nil)'; none of the union options are compatible"); } TEST_CASE_FIXTURE(Fixture, "union_of_functions_mentioning_generic_typepacks") { - ScopedFastFlag sffs[] { + ScopedFastFlag sffs[]{ {"LuauSubtypeNormalizer", true}, {"LuauTypeNormalization2", true}, }; @@ -631,12 +623,13 @@ TEST_CASE_FIXTURE(Fixture, "union_of_functions_mentioning_generic_typepacks") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), "Type '(number, a...) -> (number?, a...)' could not be converted into '((number) -> number) | ((number?, a...) -> (number?, a...))'; none of the union options are compatible"); + CHECK_EQ(toString(result.errors[0]), "Type '(number, a...) -> (number?, a...)' could not be converted into '((number) -> number) | ((number?, " + "a...) -> (number?, a...))'; none of the union options are compatible"); } TEST_CASE_FIXTURE(Fixture, "union_of_functions_with_mismatching_arg_arities") { - ScopedFastFlag sffs[] { + ScopedFastFlag sffs[]{ {"LuauSubtypeNormalizer", true}, {"LuauTypeNormalization2", true}, }; @@ -648,12 +641,13 @@ TEST_CASE_FIXTURE(Fixture, "union_of_functions_with_mismatching_arg_arities") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), "Type '(number) -> number?' could not be converted into '((number) -> nil) | ((number, string?) -> number)'; none of the union options are compatible"); + CHECK_EQ(toString(result.errors[0]), "Type '(number) -> number?' could not be converted into '((number) -> nil) | ((number, string?) -> " + "number)'; none of the union options are compatible"); } TEST_CASE_FIXTURE(Fixture, "union_of_functions_with_mismatching_result_arities") { - ScopedFastFlag sffs[] { + ScopedFastFlag sffs[]{ {"LuauSubtypeNormalizer", true}, {"LuauTypeNormalization2", true}, }; @@ -665,12 +659,13 @@ TEST_CASE_FIXTURE(Fixture, "union_of_functions_with_mismatching_result_arities") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), "Type '() -> number | string' could not be converted into '(() -> (string, string)) | (() -> number)'; none of the union options are compatible"); + CHECK_EQ(toString(result.errors[0]), "Type '() -> number | string' could not be converted into '(() -> (string, string)) | (() -> number)'; none " + "of the union options are compatible"); } TEST_CASE_FIXTURE(Fixture, "union_of_functions_with_variadics") { - ScopedFastFlag sffs[] { + ScopedFastFlag sffs[]{ {"LuauSubtypeNormalizer", true}, {"LuauTypeNormalization2", true}, }; @@ -682,12 +677,13 @@ TEST_CASE_FIXTURE(Fixture, "union_of_functions_with_variadics") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), "Type '(...nil) -> (...number?)' could not be converted into '((...string?) -> (...number)) | ((...string?) -> nil)'; none of the union options are compatible"); + CHECK_EQ(toString(result.errors[0]), "Type '(...nil) -> (...number?)' could not be converted into '((...string?) -> (...number)) | ((...string?) " + "-> nil)'; none of the union options are compatible"); } TEST_CASE_FIXTURE(Fixture, "union_of_functions_with_mismatching_arg_variadics") { - ScopedFastFlag sffs[] { + ScopedFastFlag sffs[]{ {"LuauSubtypeNormalizer", true}, {"LuauTypeNormalization2", true}, }; @@ -699,12 +695,13 @@ TEST_CASE_FIXTURE(Fixture, "union_of_functions_with_mismatching_arg_variadics") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), "Type '(number) -> ()' could not be converted into '((...number?) -> ()) | ((number?) -> ())'; none of the union options are compatible"); + CHECK_EQ(toString(result.errors[0]), + "Type '(number) -> ()' could not be converted into '((...number?) -> ()) | ((number?) -> ())'; none of the union options are compatible"); } TEST_CASE_FIXTURE(Fixture, "union_of_functions_with_mismatching_result_variadics") { - ScopedFastFlag sffs[] { + ScopedFastFlag sffs[]{ {"LuauSubtypeNormalizer", true}, {"LuauTypeNormalization2", true}, }; @@ -716,7 +713,8 @@ TEST_CASE_FIXTURE(Fixture, "union_of_functions_with_mismatching_result_variadics )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), "Type '() -> (number?, ...number)' could not be converted into '(() -> (...number)) | (() -> number)'; none of the union options are compatible"); + CHECK_EQ(toString(result.errors[0]), "Type '() -> (number?, ...number)' could not be converted into '(() -> (...number)) | (() -> number)'; none " + "of the union options are compatible"); } TEST_SUITE_END(); diff --git a/tests/conformance/basic.lua b/tests/conformance/basic.lua index b09b087b5..b7e85aa74 100644 --- a/tests/conformance/basic.lua +++ b/tests/conformance/basic.lua @@ -93,6 +93,7 @@ assert((function() local a = 1 a = a * 2 return a end)() == 2) assert((function() local a = 1 a = a / 2 return a end)() == 0.5) assert((function() local a = 5 a = a % 2 return a end)() == 1) assert((function() local a = 3 a = a ^ 2 return a end)() == 9) +assert((function() local a = 9 a = a ^ 0.5 return a end)() == 3) assert((function() local a = '1' a = a .. '2' return a end)() == "12") assert((function() local a = '1' a = a .. '2' .. '3' return a end)() == "123") @@ -475,6 +476,12 @@ assert(rawequal("a", "a") == true) assert(rawequal("a", "b") == false) assert((function() a = {} b = {} mt = { __eq = function(l, r) return #l == #r end } setmetatable(a, mt) setmetatable(b, mt) return concat(a == b, rawequal(a, b)) end)() == "true,false") +-- rawequal fallback +assert(concat(pcall(rawequal, "a", "a")) == "true,true") +assert(concat(pcall(rawequal, "a", "b")) == "true,false") +assert(concat(pcall(rawequal, "a", nil)) == "true,false") +assert(pcall(rawequal, "a") == false) + -- metatable ops local function vec3t(x, y, z) return setmetatable({x=x, y=y, z=z}, { diff --git a/tests/conformance/bitwise.lua b/tests/conformance/bitwise.lua index f0c5698d4..3b117892d 100644 --- a/tests/conformance/bitwise.lua +++ b/tests/conformance/bitwise.lua @@ -71,6 +71,7 @@ for _, b in pairs(c) do assert(bit32.bxor(b) == b) assert(bit32.bxor(b, b) == 0) assert(bit32.bxor(b, 0) == b) + assert(bit32.bxor(b, b, b) == b) assert(bit32.bnot(b) ~= b) assert(bit32.bnot(bit32.bnot(b)) == b) assert(bit32.bnot(b) == 2^32 - 1 - b) @@ -104,6 +105,9 @@ assert(bit32.extract(0xa0001111, 16) == 0) assert(bit32.extract(0xa0001111, 31) == 1) assert(bit32.extract(42, 1, 3) == 5) +local pos pos = 1 +assert(bit32.extract(42, pos, 3) == 5) -- test bit32.extract builtin instead of bit32.extractk + assert(not pcall(bit32.extract, 0, -1)) assert(not pcall(bit32.extract, 0, 32)) assert(not pcall(bit32.extract, 0, 0, 33)) @@ -144,13 +148,17 @@ assert(bit32.lrotate("0x12345678", 4) == 0x23456781) assert(bit32.rrotate("0x12345678", -4) == 0x23456781) assert(bit32.arshift("0x12345678", 1) == 0x12345678 / 2) assert(bit32.arshift("-1", 32) == 0xffffffff) +assert(bit32.arshift("-1", 1) == 0xffffffff) assert(bit32.bnot("1") == 0xfffffffe) assert(bit32.band("1", 3) == 1) assert(bit32.band(1, "3") == 1) +assert(bit32.band(1, 3, "5") == 1) assert(bit32.bor("1", 2) == 3) assert(bit32.bor(1, "2") == 3) +assert(bit32.bor(1, 3, "5") == 7) assert(bit32.bxor("1", 3) == 2) assert(bit32.bxor(1, "3") == 2) +assert(bit32.bxor(1, 3, "5") == 7) assert(bit32.btest(1, "3") == true) assert(bit32.btest("1", 3) == true) assert(bit32.countlz("42") == 26) diff --git a/tests/conformance/debugger.lua b/tests/conformance/debugger.lua index ec0b412e0..c773013b7 100644 --- a/tests/conformance/debugger.lua +++ b/tests/conformance/debugger.lua @@ -54,4 +54,19 @@ breakpoint(49, false) -- validate that disabling breakpoints works bar() +local function breakpointSetFromMetamethod() + local a = setmetatable({}, { + __index = function() + breakpoint(67) + return 2 + end + }) + + local b = a.x + + assert(b == 2) +end + +breakpointSetFromMetamethod() + return 'OK' diff --git a/tests/conformance/events.lua b/tests/conformance/events.lua index 0c6055dac..447b67bce 100644 --- a/tests/conformance/events.lua +++ b/tests/conformance/events.lua @@ -4,20 +4,6 @@ print('testing metatables') local unpack = table.unpack -X = 20; B = 30 - -local _G = getfenv() -setfenv(1, setmetatable({}, {__index=_G})) - -collectgarbage() - -X = X+10 -assert(X == 30 and _G.X == 20) -B = false -assert(B == false) -B = nil -assert(B == 30) - assert(getmetatable{} == nil) assert(getmetatable(4) == nil) assert(getmetatable(nil) == nil) @@ -299,14 +285,8 @@ x = c(3,4,5) assert(i == 3 and x[1] == 3 and x[3] == 5) -assert(_G.X == 20) -assert(_G == getfenv(0)) - print'+' -local _g = _G -setfenv(1, setmetatable({}, {__index=function (_,k) return _g[k] end})) - -- testing proxies assert(getmetatable(newproxy()) == nil) assert(getmetatable(newproxy(false)) == nil) @@ -480,4 +460,23 @@ do end end +function testfenv() + X = 20; B = 30 + + local _G = getfenv() + setfenv(1, setmetatable({}, {__index=_G})) + + X = X+10 + assert(X == 30 and _G.X == 20) + B = false + assert(B == false) + B = nil + assert(B == 30) + + assert(_G.X == 20) + assert(_G == getfenv(0)) +end + +testfenv() -- DONT MOVE THIS LINE + return 'OK' diff --git a/tests/conformance/interrupt.lua b/tests/conformance/interrupt.lua index 2b1270991..d4b7c80a4 100644 --- a/tests/conformance/interrupt.lua +++ b/tests/conformance/interrupt.lua @@ -8,4 +8,13 @@ end foo() +function bar() + local i = 0 + while i < 10 do + i += i + 1 + end +end + +bar() + return "OK" diff --git a/tests/conformance/math.lua b/tests/conformance/math.lua index 79ea0fb69..0cd0cdce7 100644 --- a/tests/conformance/math.lua +++ b/tests/conformance/math.lua @@ -152,14 +152,34 @@ assert(eq(a[1000][3], 1000/3, 0.001)) print('+') do -- testing NaN - local NaN = 10e500 - 10e400 + local NaN -- to avoid constant folding + NaN = 10e500 - 10e400 + assert(NaN ~= NaN) + assert(not (NaN == NaN)) + assert(not (NaN < NaN)) assert(not (NaN <= NaN)) assert(not (NaN > NaN)) assert(not (NaN >= NaN)) + + assert(not (0 == NaN)) assert(not (0 < NaN)) + assert(not (0 <= NaN)) + assert(not (0 > NaN)) + assert(not (0 >= NaN)) + + assert(not (NaN == 0)) assert(not (NaN < 0)) + assert(not (NaN <= 0)) + assert(not (NaN > 0)) + assert(not (NaN >= 0)) + + assert(if NaN < 0 then false else true) + assert(if NaN <= 0 then false else true) + assert(if NaN > 0 then false else true) + assert(if NaN >= 0 then false else true) + local a = {} assert(not pcall(function () a[NaN] = 1 end)) assert(a[NaN] == nil) @@ -215,6 +235,16 @@ assert(flag); assert(select(2, pcall(math.random, 1, 2, 3)):match("wrong number of arguments")) +-- min/max +assert(math.min(1) == 1) +assert(math.min(1, 2) == 1) +assert(math.min(1, 2, -1) == -1) +assert(math.min(1, -1, 2) == -1) +assert(math.max(1) == 1) +assert(math.max(1, 2) == 2) +assert(math.max(1, 2, -1) == 2) +assert(math.max(1, -1, 2) == 2) + -- noise assert(math.noise(0.5) == 0) assert(math.noise(0.5, 0.5) == -0.25) @@ -277,8 +307,10 @@ assert(math.log("10", 10) == 1) assert(math.log("9", 3) == 2) assert(math.max("1", 2) == 2) assert(math.max(2, "1") == 2) +assert(math.max(1, 2, "3") == 3) assert(math.min("1", 2) == 1) assert(math.min(2, "1") == 1) +assert(math.min(1, 2, "3") == 1) local v,f = math.modf("1.5") assert(v == 1 and f == 0.5) assert(math.pow("2", 2) == 4) @@ -295,4 +327,9 @@ assert(math.sign("-2") == -1) assert(math.sign("0") == 0) assert(math.round("1.8") == 2) +-- test that fastcalls return correct number of results +assert(select('#', math.floor(1.4)) == 1) +assert(select('#', math.ceil(1.6)) == 1) +assert(select('#', math.sqrt(9)) == 1) + return('OK') diff --git a/tests/conformance/safeenv.lua b/tests/conformance/safeenv.lua new file mode 100644 index 000000000..3a430b5fa --- /dev/null +++ b/tests/conformance/safeenv.lua @@ -0,0 +1,23 @@ +-- This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +print("safeenv reset") + +local function envChangeInMetamethod() + -- declare constant so that at O2 this test doesn't interfere with constant folding which we can't deoptimize + local ten + ten = 10 + + local a = setmetatable({}, { + __index = function() + getfenv().math = { abs = function(n) return n*n end } + return 2 + end + }) + + local b = a.x + + assert(math.abs(ten) == 100) +end + +envChangeInMetamethod() + +return"OK" diff --git a/tests/main.cpp b/tests/main.cpp index 3f564c077..82ce4e16a 100644 --- a/tests/main.cpp +++ b/tests/main.cpp @@ -21,7 +21,6 @@ #include #endif -#include #include // Indicates if verbose output is enabled; can be overridden via --verbose @@ -31,8 +30,10 @@ bool verbose = false; // Default optimization level for conformance test; can be overridden via -On int optimizationLevel = 1; -// Something to seed a pseudorandom number generator with. Defaults to -// something from std::random_device. +// Run conformance tests with native code generation +bool codegen = false; + +// Something to seed a pseudorandom number generator with std::optional randomSeed; static bool skipFastFlag(const char* flagName) @@ -257,6 +258,11 @@ int main(int argc, char** argv) verbose = true; } + if (doctest::parseFlag(argc, argv, "--codegen")) + { + codegen = true; + } + int level = -1; if (doctest::parseIntOption(argc, argv, "-O", doctest::option_int, level)) { @@ -272,7 +278,7 @@ int main(int argc, char** argv) if (doctest::parseOption(argc, argv, "--randomize") && !randomSeed) { - randomSeed = std::random_device()(); + randomSeed = unsigned(time(nullptr)); printf("Using RNG seed %u\n", *randomSeed); } diff --git a/tools/faillist.txt b/tools/faillist.txt index 00e01011b..0eb022096 100644 --- a/tools/faillist.txt +++ b/tools/faillist.txt @@ -48,7 +48,6 @@ BuiltinTests.assert_removes_falsy_types2 BuiltinTests.assert_removes_falsy_types_even_from_type_pack_tail_but_only_for_the_first_type BuiltinTests.assert_returns_false_and_string_iff_it_knows_the_first_argument_cannot_be_truthy BuiltinTests.bad_select_should_not_crash -BuiltinTests.coroutine_resume_anything_goes BuiltinTests.coroutine_wrap_anything_goes BuiltinTests.debug_info_is_crazy BuiltinTests.debug_traceback_is_crazy @@ -56,7 +55,6 @@ BuiltinTests.dont_add_definitions_to_persistent_types BuiltinTests.find_capture_types BuiltinTests.find_capture_types2 BuiltinTests.find_capture_types3 -BuiltinTests.global_singleton_types_are_sealed BuiltinTests.gmatch_capture_types BuiltinTests.gmatch_capture_types2 BuiltinTests.gmatch_capture_types_balanced_escaped_parens @@ -69,7 +67,6 @@ BuiltinTests.match_capture_types BuiltinTests.match_capture_types2 BuiltinTests.math_max_checks_for_numbers BuiltinTests.next_iterator_should_infer_types_and_type_check -BuiltinTests.os_time_takes_optional_date_table BuiltinTests.pairs_iterator_should_infer_types_and_type_check BuiltinTests.see_thru_select BuiltinTests.select_slightly_out_of_range @@ -77,7 +74,6 @@ BuiltinTests.select_way_out_of_range BuiltinTests.select_with_decimal_argument_is_rounded_down BuiltinTests.set_metatable_needs_arguments BuiltinTests.setmetatable_should_not_mutate_persisted_types -BuiltinTests.sort BuiltinTests.sort_with_bad_predicate BuiltinTests.sort_with_predicate BuiltinTests.string_format_arg_count_mismatch @@ -88,8 +84,6 @@ BuiltinTests.string_format_report_all_type_errors_at_correct_positions BuiltinTests.string_format_use_correct_argument BuiltinTests.string_format_use_correct_argument2 BuiltinTests.string_format_use_correct_argument3 -BuiltinTests.string_lib_self_noself -BuiltinTests.table_concat_returns_string BuiltinTests.table_freeze_is_generic BuiltinTests.table_insert_correctly_infers_type_of_array_2_args_overload BuiltinTests.table_insert_correctly_infers_type_of_array_3_args_overload @@ -101,12 +95,10 @@ BuiltinTests.tonumber_returns_optional_number_type2 DefinitionTests.class_definition_overload_metamethods DefinitionTests.declaring_generic_functions DefinitionTests.definition_file_classes -FrontendTest.automatically_check_dependent_scripts FrontendTest.environments FrontendTest.imported_table_modification_2 FrontendTest.it_should_be_safe_to_stringify_errors_when_full_type_graph_is_discarded FrontendTest.nocheck_cycle_used_by_checked -FrontendTest.recheck_if_dependent_script_is_dirty FrontendTest.reexport_cyclic_type FrontendTest.reexport_type_alias FrontendTest.trace_requires_in_nonstrict_mode @@ -132,18 +124,15 @@ GenericsTests.generic_type_pack_parentheses GenericsTests.generic_type_pack_unification1 GenericsTests.generic_type_pack_unification2 GenericsTests.generic_type_pack_unification3 +GenericsTests.higher_rank_polymorphism_should_not_accept_instantiated_arguments GenericsTests.infer_generic_function_function_argument GenericsTests.infer_generic_function_function_argument_overloaded GenericsTests.infer_generic_methods GenericsTests.inferred_local_vars_can_be_polytypes GenericsTests.instantiate_cyclic_generic_function -GenericsTests.instantiate_generic_function_in_assignments -GenericsTests.instantiate_generic_function_in_assignments2 GenericsTests.instantiated_function_argument_names GenericsTests.instantiation_sharing_types -GenericsTests.local_vars_can_be_instantiated_polytypes GenericsTests.no_stack_overflow_from_quantifying -GenericsTests.properties_can_be_instantiated_polytypes GenericsTests.reject_clashing_generic_and_pack_names GenericsTests.self_recursive_instantiated_param IntersectionTypes.index_on_an_intersection_type_with_mixed_types @@ -155,8 +144,6 @@ IntersectionTypes.should_still_pick_an_overload_whose_arguments_are_unions IntersectionTypes.table_intersection_write_sealed IntersectionTypes.table_intersection_write_sealed_indirect IntersectionTypes.table_write_sealed_indirect -isSubtype.intersection_of_tables -isSubtype.table_with_table_prop ModuleTests.any_persistance_does_not_leak ModuleTests.clone_self_property ModuleTests.deepClone_cyclic_table @@ -172,51 +159,24 @@ NonstrictModeTests.local_tables_are_not_any NonstrictModeTests.locals_are_any_by_default NonstrictModeTests.offer_a_hint_if_you_use_a_dot_instead_of_a_colon NonstrictModeTests.parameters_having_type_any_are_optional -NonstrictModeTests.returning_insufficient_return_values -NonstrictModeTests.returning_too_many_values NonstrictModeTests.table_dot_insert_and_recursive_calls NonstrictModeTests.table_props_are_any -Normalize.any_wins_the_battle_over_unknown_in_unions -Normalize.constrained_intersection_of_intersections -Normalize.cyclic_intersection Normalize.cyclic_table_normalizes_sensibly -Normalize.cyclic_union -Normalize.fuzz_failure_bound_type_is_normal_but_not_its_bounded_to Normalize.intersection_combine_on_bound_self -Normalize.intersection_inside_a_table_inside_another_intersection -Normalize.intersection_inside_a_table_inside_another_intersection_2 -Normalize.intersection_inside_a_table_inside_another_intersection_3 -Normalize.intersection_inside_a_table_inside_another_intersection_4 -Normalize.intersection_of_confluent_overlapping_tables -Normalize.intersection_of_disjoint_tables -Normalize.intersection_of_functions -Normalize.intersection_of_overlapping_tables -Normalize.intersection_of_tables_with_indexers -Normalize.normalization_does_not_convert_ever -Normalize.normalize_module_return_type -Normalize.normalize_unions_containing_never -Normalize.normalize_unions_containing_unknown -Normalize.union_of_distinct_free_types -Normalize.variadic_tail_is_marked_normal -Normalize.visiting_a_type_twice_is_not_considered_normal ParseErrorRecovery.generic_type_list_recovery ParseErrorRecovery.recovery_of_parenthesized_expressions ParserTests.parse_nesting_based_end_detection_failsafe_earlier ParserTests.parse_nesting_based_end_detection_local_function ProvisionalTests.bail_early_if_unification_is_too_complicated -ProvisionalTests.choose_the_right_overload_for_pcall -ProvisionalTests.constrained_is_level_dependent ProvisionalTests.discriminate_from_x_not_equal_to_nil ProvisionalTests.do_not_ice_when_trying_to_pick_first_of_generic_type_pack ProvisionalTests.error_on_eq_metamethod_returning_a_type_other_than_boolean -ProvisionalTests.function_returns_many_things_but_first_of_it_is_forgotten +ProvisionalTests.generic_type_leak_to_module_interface_variadic ProvisionalTests.greedy_inference_with_shared_self_triggers_function_with_no_returns -ProvisionalTests.invariant_table_properties_means_instantiating_tables_in_call_is_unsound -ProvisionalTests.it_should_be_agnostic_of_actual_size -ProvisionalTests.lower_bounds_calculation_is_too_permissive_with_overloaded_higher_order_functions -ProvisionalTests.normalization_fails_on_certain_kinds_of_cyclic_tables ProvisionalTests.pcall_returns_at_least_two_value_but_function_returns_nothing ProvisionalTests.setmetatable_constrains_free_type_into_free_table +ProvisionalTests.specialization_binds_with_prototypes_too_early +ProvisionalTests.table_insert_with_a_singleton_argument ProvisionalTests.typeguard_inference_incomplete ProvisionalTests.weirditer_should_not_loop_forever ProvisionalTests.while_body_are_also_refined @@ -225,7 +185,6 @@ RefinementTest.apply_refinements_on_astexprindexexpr_whose_subscript_expr_is_con RefinementTest.assert_a_to_be_truthy_then_assert_a_to_be_number RefinementTest.assert_non_binary_expressions_actually_resolve_constraints RefinementTest.call_a_more_specific_function_using_typeguard -RefinementTest.correctly_lookup_a_shadowed_local_that_which_was_previously_refined RefinementTest.correctly_lookup_property_whose_base_was_previously_refined RefinementTest.correctly_lookup_property_whose_base_was_previously_refined2 RefinementTest.discriminate_from_isa_of_x @@ -311,6 +270,7 @@ TableTests.found_like_key_in_table_function_call TableTests.found_like_key_in_table_property_access TableTests.found_multiple_like_keys TableTests.function_calls_produces_sealed_table_given_unsealed_table +TableTests.generic_table_instantiation_potential_regression TableTests.getmetatable_returns_pointer_to_metatable TableTests.give_up_after_one_metatable_index_look_up TableTests.hide_table_error_properties @@ -323,6 +283,7 @@ TableTests.infer_indexer_from_value_property_in_literal TableTests.inferred_return_type_of_free_table TableTests.inferring_crazy_table_should_also_be_quick TableTests.instantiate_table_cloning_3 +TableTests.invariant_table_properties_means_instantiating_tables_in_call_is_unsound TableTests.leaking_bad_metatable_errors TableTests.length_operator_union_errors TableTests.less_exponential_blowup_please @@ -340,7 +301,6 @@ TableTests.oop_polymorphic TableTests.open_table_unification_2 TableTests.pass_a_union_of_tables_to_a_function_that_requires_a_table TableTests.pass_a_union_of_tables_to_a_function_that_requires_a_table_2 -TableTests.pass_incompatible_union_to_a_generic_table_without_crashing TableTests.persistent_sealed_table_is_immutable TableTests.prop_access_on_key_whose_types_mismatches TableTests.property_lookup_through_tabletypevar_metatable @@ -378,7 +338,6 @@ ToDot.function ToDot.table ToString.exhaustive_toString_of_cyclic_table ToString.function_type_with_argument_names_generic -ToString.no_parentheses_around_cyclic_function_type_in_union ToString.toStringDetailed2 ToString.toStringErrorPack ToString.toStringNamedFunction_generic_pack @@ -412,12 +371,12 @@ TypeAliases.type_alias_local_rename TypeAliases.type_alias_of_an_imported_recursive_generic_type TypeAliases.type_alias_of_an_imported_recursive_type TypeInfer.checking_should_not_ice +TypeInfer.cli_50041_committing_txnlog_in_apollo_client_error TypeInfer.dont_report_type_errors_within_an_AstExprError TypeInfer.dont_report_type_errors_within_an_AstStatError TypeInfer.globals TypeInfer.globals2 TypeInfer.infer_assignment_value_types_mutable_lval -TypeInfer.it_is_ok_to_have_inconsistent_number_of_return_values_in_nonstrict TypeInfer.no_stack_overflow_from_isoptional TypeInfer.tc_after_error_recovery_no_replacement_name_in_error TypeInfer.tc_if_else_expressions_expected_type_3 @@ -427,7 +386,7 @@ TypeInfer.tc_interpolated_string_with_invalid_expression TypeInfer.type_infer_recursion_limit_no_ice TypeInferAnyError.assign_prop_to_table_by_calling_any_yields_any TypeInferAnyError.for_in_loop_iterator_is_any2 -TypeInferAnyError.union_of_types_regression_test +TypeInferAnyError.for_in_loop_iterator_is_error2 TypeInferClasses.call_base_method TypeInferClasses.call_instance_method TypeInferClasses.can_assign_to_prop_of_base_class_using_string @@ -441,7 +400,6 @@ TypeInferClasses.optional_class_field_access_error TypeInferClasses.table_class_unification_reports_sane_errors_for_missing_properties TypeInferClasses.warn_when_prop_almost_matches TypeInferClasses.we_can_report_when_someone_is_trying_to_use_a_table_rather_than_a_class -TypeInferFunctions.call_o_with_another_argument_after_foo_was_quantified TypeInferFunctions.calling_function_with_anytypepack_doesnt_leak_free_types TypeInferFunctions.calling_function_with_incorrect_argument_type_yields_errors_spanning_argument TypeInferFunctions.dont_give_other_overloads_message_if_only_one_argument_matching_overload_exists @@ -454,26 +412,20 @@ TypeInferFunctions.function_decl_non_self_sealed_overwrite_2 TypeInferFunctions.function_decl_non_self_unsealed_overwrite TypeInferFunctions.function_does_not_return_enough_values TypeInferFunctions.function_statement_sealed_table_assignment_through_indexer -TypeInferFunctions.ignored_return_values TypeInferFunctions.improved_function_arg_mismatch_error_nonstrict TypeInferFunctions.improved_function_arg_mismatch_errors -TypeInferFunctions.inconsistent_higher_order_function -TypeInferFunctions.inconsistent_return_types TypeInferFunctions.infer_anonymous_function_arguments TypeInferFunctions.infer_return_type_from_selected_overload TypeInferFunctions.infer_return_value_type TypeInferFunctions.infer_that_function_does_not_return_a_table -TypeInferFunctions.it_is_ok_not_to_supply_enough_retvals TypeInferFunctions.list_all_overloads_if_no_overload_takes_given_argument_count TypeInferFunctions.list_only_alternative_overloads_that_match_argument_count TypeInferFunctions.no_lossy_function_type TypeInferFunctions.occurs_check_failure_in_function_return_type -TypeInferFunctions.quantify_constrained_types TypeInferFunctions.record_matching_overload TypeInferFunctions.report_exiting_without_return_nonstrict TypeInferFunctions.report_exiting_without_return_strict TypeInferFunctions.return_type_by_overload -TypeInferFunctions.strict_mode_ok_with_missing_arguments TypeInferFunctions.too_few_arguments_variadic TypeInferFunctions.too_few_arguments_variadic_generic TypeInferFunctions.too_few_arguments_variadic_generic2 @@ -489,14 +441,13 @@ TypeInferLoops.for_in_with_generic_next TypeInferLoops.for_in_with_just_one_iterator_is_ok TypeInferLoops.loop_iter_no_indexer_nonstrict TypeInferLoops.loop_iter_trailing_nil -TypeInferLoops.loop_typecheck_crash_on_empty_optional TypeInferLoops.unreachable_code_after_infinite_loop TypeInferLoops.varlist_declared_by_for_in_loop_should_be_free +TypeInferModules.bound_free_table_export_is_ok TypeInferModules.custom_require_global TypeInferModules.do_not_modify_imported_types TypeInferModules.do_not_modify_imported_types_2 TypeInferModules.do_not_modify_imported_types_3 -TypeInferModules.general_require_type_mismatch TypeInferModules.module_type_conflict TypeInferModules.module_type_conflict_instantiated TypeInferModules.require_a_variadic_function @@ -531,7 +482,6 @@ TypeInferOperators.expected_types_through_binary_or TypeInferOperators.infer_any_in_all_modes_when_lhs_is_unknown TypeInferOperators.or_joins_types TypeInferOperators.or_joins_types_with_no_extras -TypeInferOperators.primitive_arith_no_metatable TypeInferOperators.primitive_arith_possible_metatable TypeInferOperators.produce_the_correct_error_message_when_comparing_a_table_with_a_metatable_with_one_that_does_not TypeInferOperators.refine_and_or @@ -543,7 +493,6 @@ TypeInferOperators.typecheck_overloaded_multiply_that_is_an_intersection_on_rhs TypeInferOperators.typecheck_unary_len_error TypeInferOperators.typecheck_unary_minus TypeInferOperators.typecheck_unary_minus_error -TypeInferOperators.unary_not_is_boolean TypeInferOperators.UnknownGlobalCompoundAssign TypeInferPrimitives.CheckMethodsOfNumber TypeInferPrimitives.singleton_types @@ -563,8 +512,6 @@ TypeInferUnknownNever.type_packs_containing_never_is_itself_uninhabitable TypeInferUnknownNever.type_packs_containing_never_is_itself_uninhabitable2 TypeInferUnknownNever.unary_minus_of_never TypePackTests.higher_order_function -TypePackTests.multiple_varargs_inference_are_not_confused -TypePackTests.no_return_size_should_be_zero TypePackTests.pack_tail_unification_check TypePackTests.parenthesized_varargs_returns_any TypePackTests.type_alias_backwards_compatible @@ -606,7 +553,6 @@ TypeSingletons.string_singleton_subtype TypeSingletons.string_singletons TypeSingletons.string_singletons_escape_chars TypeSingletons.string_singletons_mismatch -TypeSingletons.table_insert_with_a_singleton_argument TypeSingletons.table_properties_type_error_escapes TypeSingletons.tagged_unions_using_singletons TypeSingletons.taking_the_length_of_string_singleton @@ -616,14 +562,12 @@ TypeSingletons.widening_happens_almost_everywhere TypeSingletons.widening_happens_almost_everywhere_except_for_tables UnionTypes.error_detailed_optional UnionTypes.error_detailed_union_all -UnionTypes.error_takes_optional_arguments UnionTypes.index_on_a_union_type_with_missing_property UnionTypes.index_on_a_union_type_with_mixed_types UnionTypes.index_on_a_union_type_with_one_optional_property UnionTypes.index_on_a_union_type_with_one_property_of_type_any UnionTypes.index_on_a_union_type_with_property_guaranteed_to_exist UnionTypes.index_on_a_union_type_works_at_arbitrary_depth -UnionTypes.optional_arguments UnionTypes.optional_assignment_errors UnionTypes.optional_call_error UnionTypes.optional_field_access_error diff --git a/tools/lvmexecute_split.py b/tools/lvmexecute_split.py index 10e3ccbb8..48d66cb08 100644 --- a/tools/lvmexecute_split.py +++ b/tools/lvmexecute_split.py @@ -43,16 +43,23 @@ if match: inst = match[1] - signature = "const Instruction* execute_" + inst + "(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k)" + signature = "const Instruction* execute_" + inst + "(lua_State* L, const Instruction* pc, StkId base, TValue* k)" header += signature + ";\n" function = signature + "\n" + function += "{\n" + function += " [[maybe_unused]] Closure* cl = clvalue(L->ci->func);\n" state = 1 - # find the end of an instruction + # first line of the instruction which is "{" elif state == 1: + assert(line == " {\n") + state = 2 + + # find the end of an instruction + elif state == 2: # remove jumps back into the native code if line == "#if LUA_CUSTOM_EXECUTION\n": - state = 2 + state = 3 continue if line[0] == ' ': @@ -70,7 +77,7 @@ if match: # break is not supported if inst == "LOP_BREAK": - function = "const Instruction* execute_" + inst + "(lua_State* L, const Instruction* pc, Closure* cl, StkId base, TValue* k)\n" + function = "const Instruction* execute_" + inst + "(lua_State* L, const Instruction* pc, StkId base, TValue* k)\n" function += "{\n LUAU_ASSERT(!\"Unsupported deprecated opcode\");\n LUAU_UNREACHABLE();\n}\n" # handle fallthrough elif inst == "LOP_NAMECALL": @@ -81,14 +88,14 @@ state = 0 # skip LUA_CUSTOM_EXECUTION code blocks - elif state == 2: + elif state == 3: if line == "#endif\n": - state = 3 + state = 4 continue # skip extra line - elif state == 3: - state = 1 + elif state == 4: + state = 2 # make sure we found the ending assert(state == 0) diff --git a/tools/test_dcr.py b/tools/test_dcr.py index 5f1c87058..76bf11ac7 100644 --- a/tools/test_dcr.py +++ b/tools/test_dcr.py @@ -5,6 +5,9 @@ import subprocess as sp import sys import xml.sax as x +import colorama as c + +c.init() SCRIPT_PATH = os.path.split(sys.argv[0])[0] FAIL_LIST_PATH = os.path.join(SCRIPT_PATH, "faillist.txt") @@ -35,6 +38,10 @@ def __init__(self, failList): self.numSkippedTests = 0 + self.pass_count = 0 + self.fail_count = 0 + self.test_count = 0 + def startElement(self, name, attrs): if name == "TestSuite": self.currentTest.append(attrs["name"]) @@ -53,6 +60,12 @@ def startElement(self, name, attrs): r = self.results.get(dottedName, True) self.results[dottedName] = r and passed + self.test_count += 1 + if passed: + self.pass_count += 1 + else: + self.fail_count += 1 + elif name == "OverallResultsTestCases": self.numSkippedTests = safeParseInt(attrs.get("skipped", 0)) @@ -137,11 +150,33 @@ def main(): p.wait() + unexpected_fails = 0 + unexpected_passes = 0 + for testName, passed in handler.results.items(): if passed and testName in failList: - print_stderr(f"UNEXPECTED: {testName} should have failed") + unexpected_passes += 1 + print_stderr( + f"UNEXPECTED: {c.Fore.RED}{testName}{c.Fore.RESET} should have failed" + ) elif not passed and testName not in failList: - print_stderr(f"UNEXPECTED: {testName} should have passed") + unexpected_fails += 1 + print_stderr( + f"UNEXPECTED: {c.Fore.GREEN}{testName}{c.Fore.RESET} should have passed" + ) + + if unexpected_fails or unexpected_passes: + print_stderr("") + print_stderr(f"Unexpected fails: {unexpected_fails}") + print_stderr(f"Unexpected passes: {unexpected_passes}") + + pass_percent = int(handler.pass_count / handler.test_count * 100) + + print_stderr("") + print_stderr( + f"{handler.pass_count} of {handler.test_count} tests passed. ({pass_percent}%)" + ) + print_stderr(f"{handler.fail_count} tests failed.") if args.write: newFailList = sorted( From 48fd16d4a968c9e575416dc9d5c3c8cf08911347 Mon Sep 17 00:00:00 2001 From: Vyacheslav Egorov Date: Fri, 14 Oct 2022 03:16:47 +0300 Subject: [PATCH 10/66] Fix build error --- CodeGen/src/NativeState.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CodeGen/src/NativeState.cpp b/CodeGen/src/NativeState.cpp index c108dd10a..4f38ff9fb 100644 --- a/CodeGen/src/NativeState.cpp +++ b/CodeGen/src/NativeState.cpp @@ -94,7 +94,7 @@ void initHelperFunctions(NativeState& data) "fast call tables are not of the same length"); // Replace missing fast call functions with an empty placeholder that forces LOP_CALL fallback - for (int i = 0; i < sizeof(data.context.luauF_table) / sizeof(data.context.luauF_table[0]); i++) + for (int i = 0; i < int(sizeof(data.context.luauF_table) / sizeof(data.context.luauF_table[0])); i++) data.context.luauF_table[i] = luauF_table[i] ? luauF_table[i] : luauF_missing; data.context.luaV_lessthan = luaV_lessthan; From 6a98d15fe7d0bd7ff5ec7bb13010c71ad51b79e6 Mon Sep 17 00:00:00 2001 From: Vyacheslav Egorov Date: Fri, 14 Oct 2022 12:55:17 +0300 Subject: [PATCH 11/66] Responding to PR comments --- CodeGen/src/NativeState.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CodeGen/src/NativeState.cpp b/CodeGen/src/NativeState.cpp index 4f38ff9fb..da6ceb6dc 100644 --- a/CodeGen/src/NativeState.cpp +++ b/CodeGen/src/NativeState.cpp @@ -94,7 +94,7 @@ void initHelperFunctions(NativeState& data) "fast call tables are not of the same length"); // Replace missing fast call functions with an empty placeholder that forces LOP_CALL fallback - for (int i = 0; i < int(sizeof(data.context.luauF_table) / sizeof(data.context.luauF_table[0])); i++) + for (size_t i = 0; i < sizeof(data.context.luauF_table) / sizeof(data.context.luauF_table[0]); i++) data.context.luauF_table[i] = luauF_table[i] ? luauF_table[i] : luauF_missing; data.context.luaV_lessthan = luaV_lessthan; From 2eff6cfe50fbecaaeb5533ac4171c808fb6c51db Mon Sep 17 00:00:00 2001 From: Andy Friesen Date: Fri, 21 Oct 2022 10:33:43 -0700 Subject: [PATCH 12/66] Sync to upstream/release/550 --- Analysis/include/Luau/Clone.h | 2 + Analysis/include/Luau/Constraint.h | 17 +- .../include/Luau/ConstraintGraphBuilder.h | 31 +- Analysis/include/Luau/ConstraintSolver.h | 3 + Analysis/include/Luau/DataFlowGraphBuilder.h | 115 ++++ Analysis/include/Luau/Def.h | 78 +++ Analysis/include/Luau/LValue.h | 2 + Analysis/include/Luau/Metamethods.h | 32 + Analysis/include/Luau/Normalize.h | 9 - Analysis/include/Luau/Scope.h | 5 +- Analysis/include/Luau/Symbol.h | 14 +- Analysis/include/Luau/TypeInfer.h | 6 - Analysis/include/Luau/TypeUtils.h | 19 + Analysis/include/Luau/TypeVar.h | 72 +- Analysis/include/Luau/Unifier.h | 5 - Analysis/include/Luau/Variant.h | 2 +- Analysis/include/Luau/VisitTypeVar.h | 24 +- Analysis/src/Anyification.cpp | 14 - Analysis/src/BuiltinDefinitions.cpp | 46 +- Analysis/src/Clone.cpp | 50 +- Analysis/src/ConstraintGraphBuilder.cpp | 371 +++++++--- Analysis/src/ConstraintSolver.cpp | 285 +++++++- Analysis/src/DataFlowGraphBuilder.cpp | 440 ++++++++++++ Analysis/src/Def.cpp | 12 + Analysis/src/Error.cpp | 13 +- Analysis/src/Frontend.cpp | 53 +- Analysis/src/Module.cpp | 33 - Analysis/src/Normalize.cpp | 638 +----------------- Analysis/src/Quantify.cpp | 23 - Analysis/src/Scope.cpp | 55 +- Analysis/src/Substitution.cpp | 18 +- Analysis/src/ToDot.cpp | 9 - Analysis/src/ToString.cpp | 51 +- Analysis/src/TxnLog.cpp | 13 +- Analysis/src/TypeAttach.cpp | 21 +- Analysis/src/TypeChecker2.cpp | 268 +++++++- Analysis/src/TypeInfer.cpp | 53 +- Analysis/src/TypeUtils.cpp | 86 ++- Analysis/src/TypeVar.cpp | 33 +- Analysis/src/Unifier.cpp | 261 +------ Ast/src/Parser.cpp | 19 + CLI/Analyze.cpp | 5 +- CLI/Repl.cpp | 59 +- CodeGen/include/Luau/AssemblyBuilderX64.h | 26 +- CodeGen/include/Luau/CodeAllocator.h | 2 + CodeGen/include/Luau/CodeGen.h | 16 +- CodeGen/include/Luau/OperandX64.h | 5 + CodeGen/include/Luau/RegisterX64.h | 15 + CodeGen/src/AssemblyBuilderX64.cpp | 206 +++++- CodeGen/src/ByteUtils.h | 12 +- CodeGen/src/CodeAllocator.cpp | 8 +- CodeGen/src/CodeBlockUnwind.cpp | 5 +- CodeGen/src/CodeGen.cpp | 283 ++++++-- CodeGen/src/CodeGenX64.cpp | 4 +- CodeGen/src/CustomExecUtils.h | 15 - CodeGen/src/EmitCommonX64.cpp | 82 ++- CodeGen/src/EmitCommonX64.h | 62 +- CodeGen/src/EmitInstructionX64.cpp | 599 +++++++++++----- CodeGen/src/EmitInstructionX64.h | 84 ++- CodeGen/src/NativeState.cpp | 63 +- CodeGen/src/NativeState.h | 14 +- Common/include/Luau/Common.h | 4 + Compiler/include/Luau/BytecodeBuilder.h | 9 +- Compiler/src/Builtins.cpp | 13 +- Compiler/src/BytecodeBuilder.cpp | 167 ++++- Makefile | 15 +- Sources.cmake | 9 +- VM/src/ldebug.cpp | 36 +- VM/src/lobject.cpp | 52 +- VM/src/lvmexecute.cpp | 28 +- bench/gc/test_GC_Boehm_Trees.lua | 2 +- tests/AssemblyBuilderX64.test.cpp | 221 ++++-- tests/CodeAllocator.test.cpp | 12 +- tests/Compiler.test.cpp | 2 - tests/Conformance.test.cpp | 63 ++ tests/ConstraintGraphBuilder.test.cpp | 126 ---- tests/ConstraintGraphBuilderFixture.cpp | 19 +- tests/ConstraintGraphBuilderFixture.h | 13 +- tests/ConstraintSolver.test.cpp | 46 +- tests/DataFlowGraphBuilder.test.cpp | 104 +++ tests/Fixture.cpp | 3 +- tests/Module.test.cpp | 18 - tests/Normalize.test.cpp | 20 - tests/Symbol.test.cpp | 24 +- tests/ToDot.test.cpp | 25 +- tests/TypeInfer.aliases.test.cpp | 12 +- tests/TypeInfer.definitions.test.cpp | 17 + tests/TypeInfer.functions.test.cpp | 43 ++ tests/TypeInfer.intersectionTypes.test.cpp | 2 +- tests/TypeInfer.operators.test.cpp | 106 ++- tests/TypeInfer.refinements.test.cpp | 60 +- tests/TypeInfer.tables.test.cpp | 76 ++- tests/TypeInfer.test.cpp | 1 - tests/TypeInfer.typePacks.cpp | 4 + tests/TypeVar.test.cpp | 2 - tests/conformance/basic.lua | 7 + tests/conformance/calls.lua | 10 + tests/conformance/datetime.lua | 1 + tests/conformance/errors.lua | 2 + tests/conformance/events.lua | 8 + tests/conformance/iter.lua | 20 + tests/conformance/math.lua | 7 + tests/conformance/move.lua | 2 + tests/conformance/strings.lua | 30 + tests/conformance/tables.lua | 124 +++- tests/conformance/tpack.lua | 2 + tools/faillist.txt | 51 +- tools/test_dcr.py | 28 +- 108 files changed, 4249 insertions(+), 2263 deletions(-) create mode 100644 Analysis/include/Luau/DataFlowGraphBuilder.h create mode 100644 Analysis/include/Luau/Def.h create mode 100644 Analysis/include/Luau/Metamethods.h create mode 100644 Analysis/src/DataFlowGraphBuilder.cpp create mode 100644 Analysis/src/Def.cpp delete mode 100644 tests/ConstraintGraphBuilder.test.cpp create mode 100644 tests/DataFlowGraphBuilder.test.cpp diff --git a/Analysis/include/Luau/Clone.h b/Analysis/include/Luau/Clone.h index 217e1cc35..f003c2425 100644 --- a/Analysis/include/Luau/Clone.h +++ b/Analysis/include/Luau/Clone.h @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include #include "Luau/TypeArena.h" #include "Luau/TypeVar.h" @@ -26,5 +27,6 @@ TypeId clone(TypeId tp, TypeArena& dest, CloneState& cloneState); TypeFun clone(const TypeFun& typeFun, TypeArena& dest, CloneState& cloneState); TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool alwaysClone = false); +TypeId shallowClone(TypeId ty, NotNull dest); } // namespace Luau diff --git a/Analysis/include/Luau/Constraint.h b/Analysis/include/Luau/Constraint.h index 0e19f13f5..7f092f5b2 100644 --- a/Analysis/include/Luau/Constraint.h +++ b/Analysis/include/Luau/Constraint.h @@ -2,9 +2,10 @@ #pragma once #include "Luau/Ast.h" // Used for some of the enumerations +#include "Luau/Def.h" #include "Luau/NotNull.h" -#include "Luau/Variant.h" #include "Luau/TypeVar.h" +#include "Luau/Variant.h" #include #include @@ -131,9 +132,15 @@ struct HasPropConstraint std::string prop; }; -using ConstraintV = - Variant; +struct RefinementConstraint +{ + DefId def; + TypeId discriminantType; +}; + +using ConstraintV = Variant; struct Constraint { @@ -143,7 +150,7 @@ struct Constraint Constraint& operator=(const Constraint&) = delete; NotNull scope; - Location location; + Location location; // TODO: Extract this out into only the constraints that needs a location. Not all constraints needs locations. ConstraintV c; std::vector> dependencies; diff --git a/Analysis/include/Luau/ConstraintGraphBuilder.h b/Analysis/include/Luau/ConstraintGraphBuilder.h index 973c0a8ea..dc5d45988 100644 --- a/Analysis/include/Luau/ConstraintGraphBuilder.h +++ b/Analysis/include/Luau/ConstraintGraphBuilder.h @@ -1,13 +1,9 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details - #pragma once -#include -#include -#include - #include "Luau/Ast.h" #include "Luau/Constraint.h" +#include "Luau/DataFlowGraphBuilder.h" #include "Luau/Module.h" #include "Luau/ModuleResolver.h" #include "Luau/NotNull.h" @@ -15,6 +11,10 @@ #include "Luau/TypeVar.h" #include "Luau/Variant.h" +#include +#include +#include + namespace Luau { @@ -48,6 +48,7 @@ struct ConstraintGraphBuilder DenseHashMap astResolvedTypePacks{nullptr}; // Defining scopes for AST nodes. DenseHashMap astTypeAliasDefiningScopes{nullptr}; + NotNull dfg; int recursionCount = 0; @@ -63,7 +64,8 @@ struct ConstraintGraphBuilder DcrLogger* logger; ConstraintGraphBuilder(const ModuleName& moduleName, ModulePtr module, TypeArena* arena, NotNull moduleResolver, - NotNull singletonTypes, NotNull ice, const ScopePtr& globalScope, DcrLogger* logger); + NotNull singletonTypes, NotNull ice, const ScopePtr& globalScope, DcrLogger* logger, + NotNull dfg); /** * Fabricates a new free type belonging to a given scope. @@ -88,15 +90,17 @@ struct ConstraintGraphBuilder * Adds a new constraint with no dependencies to a given scope. * @param scope the scope to add the constraint to. * @param cv the constraint variant to add. + * @return the pointer to the inserted constraint */ - void addConstraint(const ScopePtr& scope, const Location& location, ConstraintV cv); + NotNull addConstraint(const ScopePtr& scope, const Location& location, ConstraintV cv); /** * Adds a constraint to a given scope. * @param scope the scope to add the constraint to. Must not be null. * @param c the constraint to add. + * @return the pointer to the inserted constraint */ - void addConstraint(const ScopePtr& scope, std::unique_ptr c); + NotNull addConstraint(const ScopePtr& scope, std::unique_ptr c); /** * The entry point to the ConstraintGraphBuilder. This will construct a set @@ -139,13 +143,20 @@ struct ConstraintGraphBuilder */ TypeId check(const ScopePtr& scope, AstExpr* expr, std::optional expectedType = {}); - TypeId check(const ScopePtr& scope, AstExprTable* expr, std::optional expectedType); + TypeId check(const ScopePtr& scope, AstExprLocal* local); + TypeId check(const ScopePtr& scope, AstExprGlobal* global); TypeId check(const ScopePtr& scope, AstExprIndexName* indexName); TypeId check(const ScopePtr& scope, AstExprIndexExpr* indexExpr); TypeId check(const ScopePtr& scope, AstExprUnary* unary); - TypeId check(const ScopePtr& scope, AstExprBinary* binary); + TypeId check_(const ScopePtr& scope, AstExprUnary* unary); + TypeId check(const ScopePtr& scope, AstExprBinary* binary, std::optional expectedType); TypeId check(const ScopePtr& scope, AstExprIfElse* ifElse, std::optional expectedType); TypeId check(const ScopePtr& scope, AstExprTypeAssertion* typeAssert); + TypeId check(const ScopePtr& scope, AstExprTable* expr, std::optional expectedType); + + TypePackId checkLValues(const ScopePtr& scope, AstArray exprs); + + TypeId checkLValue(const ScopePtr& scope, AstExpr* expr); struct FunctionSignature { diff --git a/Analysis/include/Luau/ConstraintSolver.h b/Analysis/include/Luau/ConstraintSolver.h index 0bf6d1bc7..5cc63e656 100644 --- a/Analysis/include/Luau/ConstraintSolver.h +++ b/Analysis/include/Luau/ConstraintSolver.h @@ -110,6 +110,7 @@ struct ConstraintSolver bool tryDispatch(const FunctionCallConstraint& c, NotNull constraint); bool tryDispatch(const PrimitiveTypeConstraint& c, NotNull constraint); bool tryDispatch(const HasPropConstraint& c, NotNull constraint); + bool tryDispatch(const RefinementConstraint& c, NotNull constraint); // for a, ... in some_table do // also handles __iter metamethod @@ -215,6 +216,8 @@ struct ConstraintSolver TypeId errorRecoveryType() const; TypePackId errorRecoveryTypePack() const; + TypeId unionOfTypes(TypeId a, TypeId b, NotNull scope, bool unifyFreeTypes); + ToStringOptions opts; }; diff --git a/Analysis/include/Luau/DataFlowGraphBuilder.h b/Analysis/include/Luau/DataFlowGraphBuilder.h new file mode 100644 index 000000000..3a72403e3 --- /dev/null +++ b/Analysis/include/Luau/DataFlowGraphBuilder.h @@ -0,0 +1,115 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +// Do not include LValue. It should never be used here. +#include "Luau/Ast.h" +#include "Luau/DenseHash.h" +#include "Luau/Def.h" +#include "Luau/Symbol.h" + +#include + +namespace Luau +{ + +struct DataFlowGraph +{ + DataFlowGraph(DataFlowGraph&&) = default; + DataFlowGraph& operator=(DataFlowGraph&&) = default; + + // TODO: AstExprLocal, AstExprGlobal, and AstLocal* are guaranteed never to return nullopt. + // We leave them to return an optional as we build it out, but the end state is for them to return a non-optional DefId. + std::optional getDef(const AstExpr* expr) const; + std::optional getDef(const AstLocal* local) const; + + /// Retrieve the Def that corresponds to the given Symbol. + /// + /// We do not perform dataflow analysis on globals, so this function always + /// yields nullopt when passed a global Symbol. + std::optional getDef(const Symbol& symbol) const; + +private: + DataFlowGraph() = default; + + DataFlowGraph(const DataFlowGraph&) = delete; + DataFlowGraph& operator=(const DataFlowGraph&) = delete; + + DefArena arena; + DenseHashMap astDefs{nullptr}; + DenseHashMap localDefs{nullptr}; + + friend struct DataFlowGraphBuilder; +}; + +struct DfgScope +{ + DfgScope* parent; + DenseHashMap bindings{Symbol{}}; +}; + +struct ExpressionFlowGraph +{ + std::optional def; +}; + +// Currently unsound. We do not presently track the control flow of the program. +// Additionally, we do not presently track assignments. +struct DataFlowGraphBuilder +{ + static DataFlowGraph build(AstStatBlock* root, NotNull handle); + +private: + DataFlowGraphBuilder() = default; + + DataFlowGraphBuilder(const DataFlowGraphBuilder&) = delete; + DataFlowGraphBuilder& operator=(const DataFlowGraphBuilder&) = delete; + + DataFlowGraph graph; + NotNull arena{&graph.arena}; + struct InternalErrorReporter* handle; + std::vector> scopes; + + DfgScope* childScope(DfgScope* scope); + + std::optional use(DfgScope* scope, Symbol symbol, AstExpr* e); + + void visit(DfgScope* scope, AstStatBlock* b); + void visitBlockWithoutChildScope(DfgScope* scope, AstStatBlock* b); + + // TODO: visit type aliases + void visit(DfgScope* scope, AstStat* s); + void visit(DfgScope* scope, AstStatIf* i); + void visit(DfgScope* scope, AstStatWhile* w); + void visit(DfgScope* scope, AstStatRepeat* r); + void visit(DfgScope* scope, AstStatBreak* b); + void visit(DfgScope* scope, AstStatContinue* c); + void visit(DfgScope* scope, AstStatReturn* r); + void visit(DfgScope* scope, AstStatExpr* e); + void visit(DfgScope* scope, AstStatLocal* l); + void visit(DfgScope* scope, AstStatFor* f); + void visit(DfgScope* scope, AstStatForIn* f); + void visit(DfgScope* scope, AstStatAssign* a); + void visit(DfgScope* scope, AstStatCompoundAssign* c); + void visit(DfgScope* scope, AstStatFunction* f); + void visit(DfgScope* scope, AstStatLocalFunction* l); + + ExpressionFlowGraph visitExpr(DfgScope* scope, AstExpr* e); + ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprLocal* l); + ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprGlobal* g); + ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprCall* c); + ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprIndexName* i); + ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprIndexExpr* i); + ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprFunction* f); + ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprTable* t); + ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprUnary* u); + ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprBinary* b); + ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprTypeAssertion* t); + ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprIfElse* i); + ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprInterpString* i); + + // TODO: visitLValue + // TODO: visitTypes (because of typeof which has access to values namespace, needs unreachable scope) + // TODO: visitTypePacks (because of typeof which has access to values namespace, needs unreachable scope) +}; + +} // namespace Luau diff --git a/Analysis/include/Luau/Def.h b/Analysis/include/Luau/Def.h new file mode 100644 index 000000000..ac1fa132c --- /dev/null +++ b/Analysis/include/Luau/Def.h @@ -0,0 +1,78 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/NotNull.h" +#include "Luau/TypedAllocator.h" +#include "Luau/Variant.h" + +namespace Luau +{ + +using Def = Variant; + +/** + * We statically approximate a value at runtime using a symbolic value, which we call a Def. + * + * DataFlowGraphBuilder will allocate these defs as a stand-in for some Luau values, and bind them to places that + * can hold a Luau value, and then observes how those defs will commute as it statically evaluate the program. + * + * It must also be noted that defs are a cyclic graph, so it is not safe to recursively traverse into it expecting it to terminate. + */ +using DefId = NotNull; + +/** + * A "single-object" value. + * + * Leaky implementation note: sometimes "multiple-object" values, but none of which were interesting enough to warrant creating a phi node instead. + * That can happen because there's no point in creating a phi node that points to either resultant in `if math.random() > 0.5 then 5 else "hello"`. + * This might become of utmost importance if we wanted to do some backward reasoning, e.g. if `5` is taken, then `cond` must be `truthy`. + */ +struct Undefined +{ +}; + +/** + * A phi node is a union of defs. + * + * We need this because we're statically evaluating a program, and sometimes a place may be assigned with + * different defs, and when that happens, we need a special data type that merges in all the defs + * that will flow into that specific place. For example, consider this simple program: + * + * ``` + * x-1 + * if cond() then + * x-2 = 5 + * else + * x-3 = "hello" + * end + * x-4 : {x-2, x-3} + * ``` + * + * At x-4, we know for a fact statically that either `5` or `"hello"` can flow into the variable `x` after the branch, but + * we cannot make any definitive decisions about which one, so we just take in both. + */ +struct Phi +{ + std::vector operands; +}; + +template +T* getMutable(DefId def) +{ + return get_if(def.get()); +} + +template +const T* get(DefId def) +{ + return getMutable(def); +} + +struct DefArena +{ + TypedAllocator allocator; + + DefId freshDef(); +}; + +} // namespace Luau diff --git a/Analysis/include/Luau/LValue.h b/Analysis/include/Luau/LValue.h index 1a92d52d0..518cbfafe 100644 --- a/Analysis/include/Luau/LValue.h +++ b/Analysis/include/Luau/LValue.h @@ -14,6 +14,8 @@ struct TypeVar; using TypeId = const TypeVar*; struct Field; + +// Deprecated. Do not use in new work. using LValue = Variant; struct Field diff --git a/Analysis/include/Luau/Metamethods.h b/Analysis/include/Luau/Metamethods.h new file mode 100644 index 000000000..84b0092fb --- /dev/null +++ b/Analysis/include/Luau/Metamethods.h @@ -0,0 +1,32 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Ast.h" + +#include + +namespace Luau +{ + +static const std::unordered_map kBinaryOpMetamethods{ + {AstExprBinary::Op::CompareEq, "__eq"}, + {AstExprBinary::Op::CompareNe, "__eq"}, + {AstExprBinary::Op::CompareGe, "__lt"}, + {AstExprBinary::Op::CompareGt, "__le"}, + {AstExprBinary::Op::CompareLe, "__le"}, + {AstExprBinary::Op::CompareLt, "__lt"}, + {AstExprBinary::Op::Add, "__add"}, + {AstExprBinary::Op::Sub, "__sub"}, + {AstExprBinary::Op::Mul, "__mul"}, + {AstExprBinary::Op::Div, "__div"}, + {AstExprBinary::Op::Pow, "__pow"}, + {AstExprBinary::Op::Mod, "__mod"}, + {AstExprBinary::Op::Concat, "__concat"}, +}; + +static const std::unordered_map kUnaryOpMetamethods{ + {AstExprUnary::Op::Minus, "__unm"}, + {AstExprUnary::Op::Len, "__len"}, +}; + +} // namespace Luau diff --git a/Analysis/include/Luau/Normalize.h b/Analysis/include/Luau/Normalize.h index 72ea95588..a23d0fda0 100644 --- a/Analysis/include/Luau/Normalize.h +++ b/Analysis/include/Luau/Normalize.h @@ -22,15 +22,6 @@ bool isSubtype( bool isSubtype(TypePackId subTy, TypePackId superTy, NotNull scope, NotNull singletonTypes, InternalErrorReporter& ice, bool anyIsTop = true); -std::pair normalize( - TypeId ty, NotNull scope, TypeArena& arena, NotNull singletonTypes, InternalErrorReporter& ice); -std::pair normalize(TypeId ty, NotNull module, NotNull singletonTypes, InternalErrorReporter& ice); -std::pair normalize(TypeId ty, const ModulePtr& module, NotNull singletonTypes, InternalErrorReporter& ice); -std::pair normalize( - TypePackId ty, NotNull scope, TypeArena& arena, NotNull singletonTypes, InternalErrorReporter& ice); -std::pair normalize(TypePackId ty, NotNull module, NotNull singletonTypes, InternalErrorReporter& ice); -std::pair normalize(TypePackId ty, const ModulePtr& module, NotNull singletonTypes, InternalErrorReporter& ice); - class TypeIds { private: diff --git a/Analysis/include/Luau/Scope.h b/Analysis/include/Luau/Scope.h index b2da7bc0f..ccf2964ce 100644 --- a/Analysis/include/Luau/Scope.h +++ b/Analysis/include/Luau/Scope.h @@ -54,7 +54,9 @@ struct Scope DenseHashSet builtinTypeNames{""}; void addBuiltinTypeBinding(const Name& name, const TypeFun& tyFun); - std::optional lookup(Symbol sym); + std::optional lookup(Symbol sym) const; + std::optional lookup(DefId def) const; + std::optional> lookupEx(Symbol sym); std::optional lookupType(const Name& name); std::optional lookupImportedType(const Name& moduleAlias, const Name& name); @@ -66,6 +68,7 @@ struct Scope std::optional linearSearchForBinding(const std::string& name, bool traverseScopeChain = true) const; RefinementMap refinements; + DenseHashMap dcrRefinements{nullptr}; // For mutually recursive type aliases, it's important that // they use the same types for the same names. diff --git a/Analysis/include/Luau/Symbol.h b/Analysis/include/Luau/Symbol.h index 1fe037e56..0432946cc 100644 --- a/Analysis/include/Luau/Symbol.h +++ b/Analysis/include/Luau/Symbol.h @@ -6,10 +6,11 @@ #include +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) + namespace Luau { -// TODO Rename this to Name once the old type alias is gone. struct Symbol { Symbol() @@ -40,9 +41,12 @@ struct Symbol { if (local) return local == rhs.local; - if (global.value) + else if (global.value) return rhs.global.value && global == rhs.global.value; // Subtlety: AstName::operator==(const char*) uses strcmp, not pointer identity. - return false; + else if (FFlag::DebugLuauDeferredConstraintResolution) + return !rhs.local && !rhs.global.value; // Reflexivity: we already know `this` Symbol is empty, so check that rhs is. + else + return false; } bool operator!=(const Symbol& rhs) const @@ -58,8 +62,8 @@ struct Symbol return global < rhs.global; else if (local) return true; - else - return false; + + return false; } AstName astName() const diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index 1c4d1cb41..384637bbc 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -192,18 +192,12 @@ struct TypeChecker ErrorVec canUnify(TypeId subTy, TypeId superTy, const ScopePtr& scope, const Location& location); ErrorVec canUnify(TypePackId subTy, TypePackId superTy, const ScopePtr& scope, const Location& location); - void unifyLowerBound(TypePackId subTy, TypePackId superTy, TypeLevel demotedLevel, const ScopePtr& scope, const Location& location); - std::optional findMetatableEntry(TypeId type, std::string entry, const Location& location, bool addErrors); std::optional findTablePropertyRespectingMeta(TypeId lhsType, Name name, const Location& location, bool addErrors); std::optional getIndexTypeFromType(const ScopePtr& scope, TypeId type, const Name& name, const Location& location, bool addErrors); std::optional getIndexTypeFromTypeImpl(const ScopePtr& scope, TypeId type, const Name& name, const Location& location, bool addErrors); - // Reduces the union to its simplest possible shape. - // (A | B) | B | C yields A | B | C - std::vector reduceUnion(const std::vector& types); - std::optional tryStripUnionFromNil(TypeId ty); TypeId stripFromNilAndReport(TypeId ty, const Location& location); diff --git a/Analysis/include/Luau/TypeUtils.h b/Analysis/include/Luau/TypeUtils.h index e5a205bab..7409dbe74 100644 --- a/Analysis/include/Luau/TypeUtils.h +++ b/Analysis/include/Luau/TypeUtils.h @@ -29,4 +29,23 @@ std::pair> getParameterExtents(const TxnLog* log, // various other things to get there. std::vector flatten(TypeArena& arena, NotNull singletonTypes, TypePackId pack, size_t length); +/** + * Reduces a union by decomposing to the any/error type if it appears in the + * type list, and by merging child unions. Also strips out duplicate (by pointer + * identity) types. + * @param types the input type list to reduce. + * @returns the reduced type list. +*/ +std::vector reduceUnion(const std::vector& types); + +/** + * Tries to remove nil from a union type, if there's another option. T | nil + * reduces to T, but nil itself does not reduce. + * @param singletonTypes the singleton types to use + * @param arena the type arena to allocate the new type in, if necessary + * @param ty the type to remove nil from + * @returns a type with nil removed, or nil itself if that were the only option. +*/ +TypeId stripNil(NotNull singletonTypes, TypeArena& arena, TypeId ty); + } // namespace Luau diff --git a/Analysis/include/Luau/TypeVar.h b/Analysis/include/Luau/TypeVar.h index 1d587ffe5..70c12cb9d 100644 --- a/Analysis/include/Luau/TypeVar.h +++ b/Analysis/include/Luau/TypeVar.h @@ -2,22 +2,23 @@ #pragma once #include "Luau/Ast.h" +#include "Luau/Common.h" #include "Luau/DenseHash.h" +#include "Luau/Def.h" +#include "Luau/NotNull.h" #include "Luau/Predicate.h" #include "Luau/Unifiable.h" #include "Luau/Variant.h" -#include "Luau/Common.h" -#include "Luau/NotNull.h" +#include +#include +#include +#include #include #include -#include -#include #include +#include #include -#include -#include -#include LUAU_FASTINT(LuauTableTypeMaximumStringifierLength) LUAU_FASTINT(LuauTypeMaximumStringifierLength) @@ -131,24 +132,6 @@ struct PrimitiveTypeVar } }; -struct ConstrainedTypeVar -{ - explicit ConstrainedTypeVar(TypeLevel level) - : level(level) - { - } - - explicit ConstrainedTypeVar(TypeLevel level, const std::vector& parts) - : parts(parts) - , level(level) - { - } - - std::vector parts; - TypeLevel level; - Scope* scope = nullptr; -}; - // Singleton types https://github.com/Roblox/luau/blob/master/rfcs/syntax-singleton-types.md // Types for true and false struct BooleanSingleton @@ -496,11 +479,13 @@ struct AnyTypeVar { }; +// T | U struct UnionTypeVar { std::vector options; }; +// T & U struct IntersectionTypeVar { std::vector parts; @@ -519,12 +504,27 @@ struct NeverTypeVar { }; -using ErrorTypeVar = Unifiable::Error; +// Invariant 1: there should never be a reason why such UseTypeVar exists without it mapping to another type. +// Invariant 2: UseTypeVar should always disappear across modules. +struct UseTypeVar +{ + DefId def; + NotNull scope; +}; -using TypeVariant = - Unifiable::Variant; +// ~T +// TODO: Some simplification step that overwrites the type graph to make sure negation +// types disappear from the user's view, and (?) a debug flag to disable that +struct NegationTypeVar +{ + TypeId ty; +}; +using ErrorTypeVar = Unifiable::Error; + +using TypeVariant = Unifiable::Variant; struct TypeVar final { @@ -541,7 +541,6 @@ struct TypeVar final TypeVar(const TypeVariant& ty, bool persistent) : ty(ty) , persistent(persistent) - , normal(persistent) // We assume that all persistent types are irreducable. { } @@ -549,7 +548,6 @@ struct TypeVar final void reassign(const TypeVar& rhs) { ty = rhs.ty; - normal = rhs.normal; documentationSymbol = rhs.documentationSymbol; } @@ -560,10 +558,6 @@ struct TypeVar final // Persistent TypeVars do not get cloned. bool persistent = false; - // Normalization sets this for types that are fully normalized. - // This implies that they are transitively immutable. - bool normal = false; - std::optional documentationSymbol; // Pointer to the type arena that allocated this type. @@ -656,6 +650,8 @@ struct SingletonTypes const TypeId unknownType; const TypeId neverType; const TypeId errorType; + const TypeId falsyType; // No type binding! + const TypeId truthyType; // No type binding! const TypePackId anyTypePack; const TypePackId neverTypePack; @@ -703,7 +699,6 @@ T* getMutable(TypeId tv) const std::vector& getTypes(const UnionTypeVar* utv); const std::vector& getTypes(const IntersectionTypeVar* itv); -const std::vector& getTypes(const ConstrainedTypeVar* ctv); template struct TypeIterator; @@ -716,10 +711,6 @@ using IntersectionTypeVarIterator = TypeIterator; IntersectionTypeVarIterator begin(const IntersectionTypeVar* itv); IntersectionTypeVarIterator end(const IntersectionTypeVar* itv); -using ConstrainedTypeVarIterator = TypeIterator; -ConstrainedTypeVarIterator begin(const ConstrainedTypeVar* ctv); -ConstrainedTypeVarIterator end(const ConstrainedTypeVar* ctv); - /* Traverses the type T yielding each TypeId. * If the iterator encounters a nested type T, it will instead yield each TypeId within. */ @@ -793,7 +784,6 @@ struct TypeIterator // with templates portability in this area, so not worth it. Thanks MSVC. friend UnionTypeVarIterator end(const UnionTypeVar*); friend IntersectionTypeVarIterator end(const IntersectionTypeVar*); - friend ConstrainedTypeVarIterator end(const ConstrainedTypeVar*); private: TypeIterator() = default; diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index 10f3f48cb..c15cae31d 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -119,12 +119,7 @@ struct Unifier std::optional findTablePropertyRespectingMeta(TypeId lhsType, Name name); - void tryUnifyWithConstrainedSubTypeVar(TypeId subTy, TypeId superTy); - void tryUnifyWithConstrainedSuperTypeVar(TypeId subTy, TypeId superTy); - public: - void unifyLowerBound(TypePackId subTy, TypePackId superTy, TypeLevel demotedLevel); - // Returns true if the type "needle" already occurs within "haystack" and reports an "infinite type error" bool occursCheck(TypeId needle, TypeId haystack); bool occursCheck(DenseHashSet& seen, TypeId needle, TypeId haystack); diff --git a/Analysis/include/Luau/Variant.h b/Analysis/include/Luau/Variant.h index f637222ef..76812c9bf 100644 --- a/Analysis/include/Luau/Variant.h +++ b/Analysis/include/Luau/Variant.h @@ -105,7 +105,7 @@ class Variant tableDtor[typeId](&storage); typeId = tid; - new (&storage) TT(std::forward(args)...); + new (&storage) TT{std::forward(args)...}; return *reinterpret_cast(&storage); } diff --git a/Analysis/include/Luau/VisitTypeVar.h b/Analysis/include/Luau/VisitTypeVar.h index 315e5992f..d4f8528ff 100644 --- a/Analysis/include/Luau/VisitTypeVar.h +++ b/Analysis/include/Luau/VisitTypeVar.h @@ -103,10 +103,6 @@ struct GenericTypeVarVisitor { return visit(ty); } - virtual bool visit(TypeId ty, const ConstrainedTypeVar& ctv) - { - return visit(ty); - } virtual bool visit(TypeId ty, const PrimitiveTypeVar& ptv) { return visit(ty); @@ -159,6 +155,14 @@ struct GenericTypeVarVisitor { return visit(ty); } + virtual bool visit(TypeId ty, const UseTypeVar& utv) + { + return visit(ty); + } + virtual bool visit(TypeId ty, const NegationTypeVar& ntv) + { + return visit(ty); + } virtual bool visit(TypePackId tp) { @@ -216,14 +220,6 @@ struct GenericTypeVarVisitor visit(ty, *gtv); else if (auto etv = get(ty)) visit(ty, *etv); - else if (auto ctv = get(ty)) - { - if (visit(ty, *ctv)) - { - for (TypeId part : ctv->parts) - traverse(part); - } - } else if (auto ptv = get(ty)) visit(ty, *ptv); else if (auto ftv = get(ty)) @@ -325,6 +321,10 @@ struct GenericTypeVarVisitor traverse(a); } } + else if (auto utv = get(ty)) + visit(ty, *utv); + else if (auto ntv = get(ty)) + visit(ty, *ntv); else if (!FFlag::LuauCompleteVisitor) return visit_detail::unsee(seen, ty); else diff --git a/Analysis/src/Anyification.cpp b/Analysis/src/Anyification.cpp index cc9796eec..5dd761c25 100644 --- a/Analysis/src/Anyification.cpp +++ b/Analysis/src/Anyification.cpp @@ -37,8 +37,6 @@ bool Anyification::isDirty(TypeId ty) return (ttv->state == TableState::Free || ttv->state == TableState::Unsealed); else if (log->getMutable(ty)) return true; - else if (get(ty)) - return true; else return false; } @@ -65,20 +63,8 @@ TypeId Anyification::clean(TypeId ty) clone.syntheticName = ttv->syntheticName; clone.tags = ttv->tags; TypeId res = addType(std::move(clone)); - asMutable(res)->normal = ty->normal; return res; } - else if (auto ctv = get(ty)) - { - std::vector copy = ctv->parts; - for (TypeId& ty : copy) - ty = replace(ty); - TypeId res = copy.size() == 1 ? copy[0] : addType(UnionTypeVar{std::move(copy)}); - auto [t, ok] = normalize(res, scope, *arena, singletonTypes, *iceHandler); - if (!ok) - normalizationTooComplex = true; - return t; - } else return anyType; } diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index c5250a6db..6051e117a 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -10,6 +10,7 @@ #include "Luau/TypeInfer.h" #include "Luau/TypePack.h" #include "Luau/TypeVar.h" +#include "Luau/TypeUtils.h" #include @@ -41,6 +42,7 @@ static std::optional> magicFunctionRequire( static bool dcrMagicFunctionSelect(MagicFunctionCallContext context); static bool dcrMagicFunctionRequire(MagicFunctionCallContext context); +static bool dcrMagicFunctionPack(MagicFunctionCallContext context); TypeId makeUnion(TypeArena& arena, std::vector&& types) { @@ -333,6 +335,7 @@ void registerBuiltinGlobals(TypeChecker& typeChecker) ttv->props["clone"] = makeProperty(makeFunction(arena, std::nullopt, {tabTy}, {tabTy}), "@luau/global/table.clone"); attachMagicFunction(ttv->props["pack"].type, magicFunctionPack); + attachDcrMagicFunction(ttv->props["pack"].type, dcrMagicFunctionPack); } attachMagicFunction(getGlobalBinding(typeChecker, "require"), magicFunctionRequire); @@ -660,7 +663,7 @@ static std::optional> magicFunctionPack( options.push_back(vtp->ty); } - options = typechecker.reduceUnion(options); + options = reduceUnion(options); // table.pack() -> {| n: number, [number]: nil |} // table.pack(1) -> {| n: number, [number]: number |} @@ -679,6 +682,47 @@ static std::optional> magicFunctionPack( return WithPredicate{arena.addTypePack({packedTable})}; } +static bool dcrMagicFunctionPack(MagicFunctionCallContext context) +{ + + TypeArena* arena = context.solver->arena; + + const auto& [paramTypes, paramTail] = flatten(context.arguments); + + std::vector options; + options.reserve(paramTypes.size()); + for (auto type : paramTypes) + options.push_back(type); + + if (paramTail) + { + if (const VariadicTypePack* vtp = get(*paramTail)) + options.push_back(vtp->ty); + } + + options = reduceUnion(options); + + // table.pack() -> {| n: number, [number]: nil |} + // table.pack(1) -> {| n: number, [number]: number |} + // table.pack(1, "foo") -> {| n: number, [number]: number | string |} + TypeId result = nullptr; + if (options.empty()) + result = context.solver->singletonTypes->nilType; + else if (options.size() == 1) + result = options[0]; + else + result = arena->addType(UnionTypeVar{std::move(options)}); + + TypeId numberType = context.solver->singletonTypes->numberType; + TypeId packedTable = arena->addType( + TableTypeVar{{{"n", {numberType}}}, TableIndexer(numberType, result), {}, TableState::Sealed}); + + TypePackId tableTypePack = arena->addTypePack({packedTable}); + asMutable(context.result)->ty.emplace(tableTypePack); + + return true; +} + static bool checkRequirePath(TypeChecker& typechecker, AstExpr* expr) { // require(foo.parent.bar) will technically work, but it depends on legacy goop that diff --git a/Analysis/src/Clone.cpp b/Analysis/src/Clone.cpp index fd3a089b4..85408919b 100644 --- a/Analysis/src/Clone.cpp +++ b/Analysis/src/Clone.cpp @@ -1,6 +1,6 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details - #include "Luau/Clone.h" + #include "Luau/RecursionCounter.h" #include "Luau/TxnLog.h" #include "Luau/TypePack.h" @@ -51,7 +51,6 @@ struct TypeCloner void operator()(const BlockedTypeVar& t); void operator()(const PendingExpansionTypeVar& t); void operator()(const PrimitiveTypeVar& t); - void operator()(const ConstrainedTypeVar& t); void operator()(const SingletonTypeVar& t); void operator()(const FunctionTypeVar& t); void operator()(const TableTypeVar& t); @@ -63,6 +62,8 @@ struct TypeCloner void operator()(const LazyTypeVar& t); void operator()(const UnknownTypeVar& t); void operator()(const NeverTypeVar& t); + void operator()(const UseTypeVar& t); + void operator()(const NegationTypeVar& t); }; struct TypePackCloner @@ -198,21 +199,6 @@ void TypeCloner::operator()(const PrimitiveTypeVar& t) defaultClone(t); } -void TypeCloner::operator()(const ConstrainedTypeVar& t) -{ - TypeId res = dest.addType(ConstrainedTypeVar{t.level}); - ConstrainedTypeVar* ctv = getMutable(res); - LUAU_ASSERT(ctv); - - seenTypes[typeId] = res; - - std::vector parts; - for (TypeId part : t.parts) - parts.push_back(clone(part, dest, cloneState)); - - ctv->parts = std::move(parts); -} - void TypeCloner::operator()(const SingletonTypeVar& t) { defaultClone(t); @@ -352,6 +338,21 @@ void TypeCloner::operator()(const NeverTypeVar& t) defaultClone(t); } +void TypeCloner::operator()(const UseTypeVar& t) +{ + TypeId result = dest.addType(BoundTypeVar{follow(typeId)}); + seenTypes[typeId] = result; +} + +void TypeCloner::operator()(const NegationTypeVar& t) +{ + TypeId result = dest.addType(AnyTypeVar{}); + seenTypes[typeId] = result; + + TypeId ty = clone(t.ty, dest, cloneState); + asMutable(result)->ty = NegationTypeVar{ty}; +} + } // anonymous namespace TypePackId clone(TypePackId tp, TypeArena& dest, CloneState& cloneState) @@ -390,7 +391,6 @@ TypeId clone(TypeId typeId, TypeArena& dest, CloneState& cloneState) if (!res->persistent) { asMutable(res)->documentationSymbol = typeId->documentationSymbol; - asMutable(res)->normal = typeId->normal; } } @@ -478,11 +478,6 @@ TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool alwaysCl clone.parts = itv->parts; result = dest.addType(std::move(clone)); } - else if (const ConstrainedTypeVar* ctv = get(ty)) - { - ConstrainedTypeVar clone{ctv->level, ctv->parts}; - result = dest.addType(std::move(clone)); - } else if (const PendingExpansionTypeVar* petv = get(ty)) { PendingExpansionTypeVar clone{petv->prefix, petv->name, petv->typeArguments, petv->packArguments}; @@ -497,6 +492,10 @@ TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool alwaysCl { result = dest.addType(*ty); } + else if (const NegationTypeVar* ntv = get(ty)) + { + result = dest.addType(NegationTypeVar{ntv->ty}); + } else return result; @@ -504,4 +503,9 @@ TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool alwaysCl return result; } +TypeId shallowClone(TypeId ty, NotNull dest) +{ + return shallowClone(ty, *dest, TxnLog::empty()); +} + } // namespace Luau diff --git a/Analysis/src/ConstraintGraphBuilder.cpp b/Analysis/src/ConstraintGraphBuilder.cpp index 8436fb309..de2b0a4e1 100644 --- a/Analysis/src/ConstraintGraphBuilder.cpp +++ b/Analysis/src/ConstraintGraphBuilder.cpp @@ -1,20 +1,21 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details - #include "Luau/ConstraintGraphBuilder.h" + #include "Luau/Ast.h" +#include "Luau/Clone.h" #include "Luau/Common.h" #include "Luau/Constraint.h" +#include "Luau/DcrLogger.h" #include "Luau/ModuleResolver.h" #include "Luau/RecursionCounter.h" +#include "Luau/Scope.h" #include "Luau/ToString.h" -#include "Luau/DcrLogger.h" +#include "Luau/TypeUtils.h" LUAU_FASTINT(LuauCheckRecursionLimit); LUAU_FASTFLAG(DebugLuauLogSolverToJson); LUAU_FASTFLAG(DebugLuauMagicTypes); -#include "Luau/Scope.h" - namespace Luau { @@ -53,12 +54,13 @@ static bool matchSetmetatable(const AstExprCall& call) ConstraintGraphBuilder::ConstraintGraphBuilder(const ModuleName& moduleName, ModulePtr module, TypeArena* arena, NotNull moduleResolver, NotNull singletonTypes, NotNull ice, const ScopePtr& globalScope, - DcrLogger* logger) + DcrLogger* logger, NotNull dfg) : moduleName(moduleName) , module(module) , singletonTypes(singletonTypes) , arena(arena) , rootScope(nullptr) + , dfg(dfg) , moduleResolver(moduleResolver) , ice(ice) , globalScope(globalScope) @@ -95,14 +97,14 @@ ScopePtr ConstraintGraphBuilder::childScope(AstNode* node, const ScopePtr& paren return scope; } -void ConstraintGraphBuilder::addConstraint(const ScopePtr& scope, const Location& location, ConstraintV cv) +NotNull ConstraintGraphBuilder::addConstraint(const ScopePtr& scope, const Location& location, ConstraintV cv) { - scope->constraints.emplace_back(new Constraint{NotNull{scope.get()}, location, std::move(cv)}); + return NotNull{scope->constraints.emplace_back(new Constraint{NotNull{scope.get()}, location, std::move(cv)}).get()}; } -void ConstraintGraphBuilder::addConstraint(const ScopePtr& scope, std::unique_ptr c) +NotNull ConstraintGraphBuilder::addConstraint(const ScopePtr& scope, std::unique_ptr c) { - scope->constraints.emplace_back(std::move(c)); + return NotNull{scope->constraints.emplace_back(std::move(c)).get()}; } void ConstraintGraphBuilder::visit(AstStatBlock* block) @@ -229,22 +231,16 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStat* stat) void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocal* local) { std::vector varTypes; + varTypes.reserve(local->vars.size); for (AstLocal* local : local->vars) { TypeId ty = nullptr; - Location location = local->location; if (local->annotation) - { - location = local->annotation->location; ty = resolveType(scope, local->annotation, /* topLevel */ true); - } - else - ty = freshType(scope); varTypes.push_back(ty); - scope->bindings[local] = Binding{ty, location}; } for (size_t i = 0; i < local->values.size; ++i) @@ -257,6 +253,9 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocal* local) // HACK: we leave nil-initialized things floating under the assumption that they will later be populated. // See the test TypeInfer/infer_locals_with_nil_value. // Better flow awareness should make this obsolete. + + if (!varTypes[i]) + varTypes[i] = freshType(scope); } else if (i == local->values.size - 1) { @@ -268,6 +267,20 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocal* local) if (i < local->vars.size) { + std::vector packTypes = flatten(*arena, singletonTypes, exprPack, varTypes.size() - i); + + // fill out missing values in varTypes with values from exprPack + for (size_t j = i; j < varTypes.size(); ++j) + { + if (!varTypes[j]) + { + if (j - i < packTypes.size()) + varTypes[j] = packTypes[j - i]; + else + varTypes[j] = freshType(scope); + } + } + std::vector tailValues{varTypes.begin() + i, varTypes.end()}; TypePackId tailPack = arena->addTypePack(std::move(tailValues)); addConstraint(scope, local->location, PackSubtypeConstraint{exprPack, tailPack}); @@ -281,10 +294,31 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocal* local) TypeId exprType = check(scope, value, expectedType); if (i < varTypes.size()) - addConstraint(scope, local->location, SubtypeConstraint{varTypes[i], exprType}); + { + if (varTypes[i]) + addConstraint(scope, local->location, SubtypeConstraint{varTypes[i], exprType}); + else + varTypes[i] = exprType; + } } } + for (size_t i = 0; i < local->vars.size; ++i) + { + AstLocal* l = local->vars.data[i]; + Location location = l->location; + + if (!varTypes[i]) + varTypes[i] = freshType(scope); + + scope->bindings[l] = Binding{varTypes[i], location}; + + // HACK: In the greedy solver, we say the type state of a variable is the type annotation itself, but + // the actual type state is the corresponding initializer expression (if it exists) or nil otherwise. + if (auto def = dfg->getDef(l)) + scope->dcrRefinements[*def] = varTypes[i]; + } + if (local->values.size > 0) { // To correctly handle 'require', we need to import the exported type bindings into the variable 'namespace'. @@ -510,7 +544,7 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatBlock* block) void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatAssign* assign) { - TypePackId varPackId = checkPack(scope, assign->vars); + TypePackId varPackId = checkLValues(scope, assign->vars); TypePackId valuePack = checkPack(scope, assign->values); addConstraint(scope, assign->location, PackSubtypeConstraint{valuePack, varPackId}); @@ -532,7 +566,10 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatCompoundAssign* void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatIf* ifStatement) { - check(scope, ifStatement->condition); + // TODO: Optimization opportunity, the interior scope of the condition could be + // reused for the then body, so we don't need to refine twice. + ScopePtr condScope = childScope(ifStatement->condition, scope); + check(condScope, ifStatement->condition, std::nullopt); ScopePtr thenScope = childScope(ifStatement->thenbody, scope); visit(thenScope, ifStatement->thenbody); @@ -893,7 +930,7 @@ TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExpr* expr, std:: TypeId result = nullptr; if (auto group = expr->as()) - result = check(scope, group->expr); + result = check(scope, group->expr, expectedType); else if (auto stringExpr = expr->as()) { if (expectedType) @@ -937,32 +974,14 @@ TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExpr* expr, std:: } else if (expr->is()) result = singletonTypes->nilType; - else if (auto a = expr->as()) - { - std::optional ty = scope->lookup(a->local); - if (ty) - result = *ty; - else - result = singletonTypes->errorRecoveryType(); // FIXME? Record an error at this point? - } - else if (auto g = expr->as()) - { - std::optional ty = scope->lookup(g->name); - if (ty) - result = *ty; - else - { - /* prepopulateGlobalScope() has already added all global functions to the environment by this point, so any - * global that is not already in-scope is definitely an unknown symbol. - */ - reportError(g->location, UnknownSymbol{g->name.value}); - result = singletonTypes->errorRecoveryType(); // FIXME? Record an error at this point? - } - } + else if (auto local = expr->as()) + result = check(scope, local); + else if (auto global = expr->as()) + result = check(scope, global); else if (expr->is()) result = flattenPack(scope, expr->location, checkPack(scope, expr)); else if (expr->is()) - result = flattenPack(scope, expr->location, checkPack(scope, expr)); + result = flattenPack(scope, expr->location, checkPack(scope, expr)); // TODO: needs predicates too else if (auto a = expr->as()) { FunctionSignature sig = checkFunctionSignature(scope, a); @@ -978,7 +997,7 @@ TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExpr* expr, std:: else if (auto unary = expr->as()) result = check(scope, unary); else if (auto binary = expr->as()) - result = check(scope, binary); + result = check(scope, binary, expectedType); else if (auto ifElse = expr->as()) result = check(scope, ifElse, expectedType); else if (auto typeAssert = expr->as()) @@ -1002,6 +1021,37 @@ TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExpr* expr, std:: return result; } +TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprLocal* local) +{ + std::optional resultTy; + + if (auto def = dfg->getDef(local)) + resultTy = scope->lookup(*def); + + if (!resultTy) + { + if (auto ty = scope->lookup(local->local)) + resultTy = *ty; + } + + if (!resultTy) + return singletonTypes->errorRecoveryType(); // TODO: replace with ice, locals should never exist before its definition. + + return *resultTy; +} + +TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprGlobal* global) +{ + if (std::optional ty = scope->lookup(global->name)) + return *ty; + + /* prepopulateGlobalScope() has already added all global functions to the environment by this point, so any + * global that is not already in-scope is definitely an unknown symbol. + */ + reportError(global->location, UnknownSymbol{global->name.value}); + return singletonTypes->errorRecoveryType(); +} + TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIndexName* indexName) { TypeId obj = check(scope, indexName->expr); @@ -1036,54 +1086,32 @@ TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIndexExpr* in TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprUnary* unary) { - TypeId operandType = check(scope, unary->expr); - + TypeId operandType = check_(scope, unary); TypeId resultType = arena->addType(BlockedTypeVar{}); addConstraint(scope, unary->location, UnaryConstraint{unary->op, operandType, resultType}); return resultType; } -TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprBinary* binary) +TypeId ConstraintGraphBuilder::check_(const ScopePtr& scope, AstExprUnary* unary) { - TypeId leftType = check(scope, binary->left); - TypeId rightType = check(scope, binary->right); - switch (binary->op) - { - case AstExprBinary::And: - case AstExprBinary::Or: - { - addConstraint(scope, binary->location, SubtypeConstraint{leftType, rightType}); - return leftType; - } - case AstExprBinary::Add: - case AstExprBinary::Sub: - case AstExprBinary::Mul: - case AstExprBinary::Div: - case AstExprBinary::Mod: - case AstExprBinary::Pow: - case AstExprBinary::CompareNe: - case AstExprBinary::CompareEq: - case AstExprBinary::CompareLt: - case AstExprBinary::CompareLe: - case AstExprBinary::CompareGt: - case AstExprBinary::CompareGe: - { - TypeId resultType = arena->addType(BlockedTypeVar{}); - addConstraint(scope, binary->location, BinaryConstraint{binary->op, leftType, rightType, resultType}); - return resultType; - } - case AstExprBinary::Concat: + if (unary->op == AstExprUnary::Not) { - addConstraint(scope, binary->left->location, SubtypeConstraint{leftType, singletonTypes->stringType}); - addConstraint(scope, binary->right->location, SubtypeConstraint{rightType, singletonTypes->stringType}); - return singletonTypes->stringType; - } - default: - LUAU_ASSERT(0); + TypeId ty = check(scope, unary->expr, std::nullopt); + + return ty; } - LUAU_ASSERT(0); - return nullptr; + return check(scope, unary->expr); +} + +TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprBinary* binary, std::optional expectedType) +{ + TypeId leftType = check(scope, binary->left, expectedType); + TypeId rightType = check(scope, binary->right, expectedType); + + TypeId resultType = arena->addType(BlockedTypeVar{}); + addConstraint(scope, binary->location, BinaryConstraint{binary->op, leftType, rightType, resultType}); + return resultType; } TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIfElse* ifElse, std::optional expectedType) @@ -1106,10 +1134,182 @@ TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIfElse* ifEls TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprTypeAssertion* typeAssert) { - check(scope, typeAssert->expr); + check(scope, typeAssert->expr, std::nullopt); return resolveType(scope, typeAssert->annotation); } +TypePackId ConstraintGraphBuilder::checkLValues(const ScopePtr& scope, AstArray exprs) +{ + std::vector types; + types.reserve(exprs.size); + + for (size_t i = 0; i < exprs.size; ++i) + { + AstExpr* const expr = exprs.data[i]; + types.push_back(checkLValue(scope, expr)); + } + + return arena->addTypePack(std::move(types)); +} + +static bool isUnsealedTable(TypeId ty) +{ + ty = follow(ty); + const TableTypeVar* ttv = get(ty); + return ttv && ttv->state == TableState::Unsealed; +}; + +/** + * If the expr is a dotted set of names, and if the root symbol refers to an + * unsealed table, return that table type, plus the indeces that follow as a + * vector. + */ +static std::optional>> extractDottedName(AstExpr* expr) +{ + std::vector names; + + while (expr) + { + if (auto global = expr->as()) + { + std::reverse(begin(names), end(names)); + return std::pair{global->name, std::move(names)}; + } + else if (auto local = expr->as()) + { + std::reverse(begin(names), end(names)); + return std::pair{local->local, std::move(names)}; + } + else if (auto indexName = expr->as()) + { + names.push_back(indexName->index.value); + expr = indexName->expr; + } + else + return std::nullopt; + } + + return std::nullopt; +} + +/** + * Create a shallow copy of `ty` and its properties along `path`. Insert a new + * property (the last segment of `path`) into the tail table with the value `t`. + * + * On success, returns the new outermost table type. If the root table or any + * of its subkeys are not unsealed tables, the function fails and returns + * std::nullopt. + * + * TODO: Prove that we completely give up in the face of indexers and + * metatables. + */ +static std::optional updateTheTableType(NotNull arena, TypeId ty, const std::vector& path, TypeId replaceTy) +{ + if (path.empty()) + return std::nullopt; + + // First walk the path and ensure that it's unsealed tables all the way + // to the end. + { + TypeId t = ty; + for (size_t i = 0; i < path.size() - 1; ++i) + { + if (!isUnsealedTable(t)) + return std::nullopt; + + const TableTypeVar* tbl = get(t); + auto it = tbl->props.find(path[i]); + if (it == tbl->props.end()) + return std::nullopt; + + t = it->second.type; + } + + // The last path segment should not be a property of the table at all. + // We are not changing property types. We are only admitting this one + // new property to be appended. + if (!isUnsealedTable(t)) + return std::nullopt; + const TableTypeVar* tbl = get(t); + auto it = tbl->props.find(path.back()); + if (it != tbl->props.end()) + return std::nullopt; + } + + const TypeId res = shallowClone(ty, arena); + TypeId t = res; + + for (size_t i = 0; i < path.size() - 1; ++i) + { + const std::string segment = path[i]; + + TableTypeVar* ttv = getMutable(t); + LUAU_ASSERT(ttv); + + auto propIt = ttv->props.find(segment); + if (propIt != ttv->props.end()) + { + LUAU_ASSERT(isUnsealedTable(propIt->second.type)); + t = shallowClone(follow(propIt->second.type), arena); + ttv->props[segment].type = t; + } + else + return std::nullopt; + } + + TableTypeVar* ttv = getMutable(t); + LUAU_ASSERT(ttv); + + const std::string lastSegment = path.back(); + LUAU_ASSERT(0 == ttv->props.count(lastSegment)); + ttv->props[lastSegment] = Property{replaceTy}; + return res; +} + +/** + * This function is mostly about identifying properties that are being inserted into unsealed tables. + * + * If expr has the form name.a.b.c + */ +TypeId ConstraintGraphBuilder::checkLValue(const ScopePtr& scope, AstExpr* expr) +{ + if (auto indexExpr = expr->as()) + { + if (auto constantString = indexExpr->index->as()) + { + AstName syntheticIndex{constantString->value.data}; + AstExprIndexName synthetic{ + indexExpr->location, indexExpr->expr, syntheticIndex, constantString->location, indexExpr->expr->location.end, '.'}; + return checkLValue(scope, &synthetic); + } + } + + auto dottedPath = extractDottedName(expr); + if (!dottedPath) + return check(scope, expr); + const auto [sym, segments] = std::move(*dottedPath); + + if (!sym.local) + return check(scope, expr); + + auto lookupResult = scope->lookupEx(sym); + if (!lookupResult) + return check(scope, expr); + const auto [ty, symbolScope] = std::move(*lookupResult); + + TypeId replaceTy = arena->freshType(scope.get()); + + std::optional updatedType = updateTheTableType(arena, ty, segments, replaceTy); + if (!updatedType) + return check(scope, expr); + + std::optional def = dfg->getDef(sym); + LUAU_ASSERT(def); + symbolScope->bindings[sym].typeId = *updatedType; + symbolScope->dcrRefinements[*def] = *updatedType; + return replaceTy; +} + TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprTable* expr, std::optional expectedType) { TypeId ty = arena->addType(TableTypeVar{}); @@ -1275,6 +1475,9 @@ ConstraintGraphBuilder::FunctionSignature ConstraintGraphBuilder::checkFunctionS argTypes.push_back(t); signatureScope->bindings[local] = Binding{t, local->location}; + if (auto def = dfg->getDef(local)) + signatureScope->dcrRefinements[*def] = t; + if (local->annotation) { TypeId argAnnotation = resolveType(signatureScope, local->annotation, /* topLevel */ true); diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index e29eeaaaa..60f4666aa 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -3,14 +3,16 @@ #include "Luau/Anyification.h" #include "Luau/ApplyTypeFunction.h" #include "Luau/ConstraintSolver.h" +#include "Luau/DcrLogger.h" #include "Luau/Instantiation.h" #include "Luau/Location.h" +#include "Luau/Metamethods.h" #include "Luau/ModuleResolver.h" #include "Luau/Quantify.h" #include "Luau/ToString.h" +#include "Luau/TypeUtils.h" #include "Luau/TypeVar.h" #include "Luau/Unifier.h" -#include "Luau/DcrLogger.h" #include "Luau/VisitTypeVar.h" #include "Luau/TypeUtils.h" @@ -438,6 +440,8 @@ bool ConstraintSolver::tryDispatch(NotNull constraint, bool fo success = tryDispatch(*fcc, constraint); else if (auto hpc = get(*constraint)) success = tryDispatch(*hpc, constraint); + else if (auto rc = get(*constraint)) + success = tryDispatch(*rc, constraint); else LUAU_ASSERT(false); @@ -564,44 +568,192 @@ bool ConstraintSolver::tryDispatch(const BinaryConstraint& c, NotNull + * + * This constraint is the one that is meant to unblock A, so it doesn't + * make any sense to stop and wait for someone else to do it. + */ + + if (isBlocked(leftType) && leftType != resultType) + return block(c.leftType, constraint); + + if (isBlocked(rightType) && rightType != resultType) + return block(c.rightType, constraint); + + if (!force) { - /* Compound assignments create constraints of the form - * - * A <: Binary - * - * This constraint is the one that is meant to unblock A, so it doesn't - * make any sense to stop and wait for someone else to do it. - */ - if (leftType != resultType && rightType != resultType) - { - block(c.leftType, constraint); - block(c.rightType, constraint); - return false; - } + // Logical expressions may proceed if the LHS is free. + if (get(leftType) && !isLogical) + return block(leftType, constraint); } - if (isNumber(leftType)) + // Logical expressions may proceed if the LHS is free. + if (isBlocked(leftType) || (get(leftType) && !isLogical)) { - unify(leftType, rightType, constraint->scope); - asMutable(resultType)->ty.emplace(leftType); + asMutable(resultType)->ty.emplace(errorRecoveryType()); + unblock(resultType); return true; } - if (!force) + // For or expressions, the LHS will never have nil as a possible output. + // Consider: + // local foo = nil or 2 + // `foo` will always be 2. + if (c.op == AstExprBinary::Op::Or) + leftType = stripNil(singletonTypes, *arena, leftType); + + // Metatables go first, even if there is primitive behavior. + if (auto it = kBinaryOpMetamethods.find(c.op); it != kBinaryOpMetamethods.end()) { - if (get(leftType)) - return block(leftType, constraint); + // Metatables are not the same. The metamethod will not be invoked. + if ((c.op == AstExprBinary::Op::CompareEq || c.op == AstExprBinary::Op::CompareNe) && + getMetatable(leftType, singletonTypes) != getMetatable(rightType, singletonTypes)) + { + // TODO: Boolean singleton false? The result is _always_ boolean false. + asMutable(resultType)->ty.emplace(singletonTypes->booleanType); + unblock(resultType); + return true; + } + + std::optional mm; + + // The LHS metatable takes priority over the RHS metatable, where + // present. + if (std::optional leftMm = findMetatableEntry(singletonTypes, errors, leftType, it->second, constraint->location)) + mm = leftMm; + else if (std::optional rightMm = findMetatableEntry(singletonTypes, errors, rightType, it->second, constraint->location)) + mm = rightMm; + + if (mm) + { + // TODO: Is a table with __call legal here? + // TODO: Overloads + if (const FunctionTypeVar* ftv = get(follow(*mm))) + { + TypePackId inferredArgs; + // For >= and > we invoke __lt and __le respectively with + // swapped argument ordering. + if (c.op == AstExprBinary::Op::CompareGe || c.op == AstExprBinary::Op::CompareGt) + { + inferredArgs = arena->addTypePack({rightType, leftType}); + } + else + { + inferredArgs = arena->addTypePack({leftType, rightType}); + } + + unify(inferredArgs, ftv->argTypes, constraint->scope); + + TypeId mmResult; + + // Comparison operations always evaluate to a boolean, + // regardless of what the metamethod returns. + switch (c.op) + { + case AstExprBinary::Op::CompareEq: + case AstExprBinary::Op::CompareNe: + case AstExprBinary::Op::CompareGe: + case AstExprBinary::Op::CompareGt: + case AstExprBinary::Op::CompareLe: + case AstExprBinary::Op::CompareLt: + mmResult = singletonTypes->booleanType; + break; + default: + mmResult = first(ftv->retTypes).value_or(errorRecoveryType()); + } + + asMutable(resultType)->ty.emplace(mmResult); + unblock(resultType); + return true; + } + } + + // If there's no metamethod available, fall back to primitive behavior. } - if (isBlocked(leftType)) + // If any is present, the expression must evaluate to any as well. + bool leftAny = get(leftType) || get(leftType); + bool rightAny = get(rightType) || get(rightType); + bool anyPresent = leftAny || rightAny; + + switch (c.op) { - asMutable(resultType)->ty.emplace(errorRecoveryType()); - // reportError(constraint->location, CannotInferBinaryOperation{c.op, std::nullopt, CannotInferBinaryOperation::Operation}); + // For arithmetic operators, if the LHS is a number, the RHS must be a + // number as well. The result will also be a number. + case AstExprBinary::Op::Add: + case AstExprBinary::Op::Sub: + case AstExprBinary::Op::Mul: + case AstExprBinary::Op::Div: + case AstExprBinary::Op::Pow: + case AstExprBinary::Op::Mod: + if (isNumber(leftType)) + { + unify(leftType, rightType, constraint->scope); + asMutable(resultType)->ty.emplace(anyPresent ? singletonTypes->anyType : leftType); + unblock(resultType); + return true; + } + + break; + // For concatenation, if the LHS is a string, the RHS must be a string as + // well. The result will also be a string. + case AstExprBinary::Op::Concat: + if (isString(leftType)) + { + unify(leftType, rightType, constraint->scope); + asMutable(resultType)->ty.emplace(anyPresent ? singletonTypes->anyType : leftType); + unblock(resultType); + return true; + } + + break; + // Inexact comparisons require that the types be both numbers or both + // strings, and evaluate to a boolean. + case AstExprBinary::Op::CompareGe: + case AstExprBinary::Op::CompareGt: + case AstExprBinary::Op::CompareLe: + case AstExprBinary::Op::CompareLt: + if ((isNumber(leftType) && isNumber(rightType)) || (isString(leftType) && isString(rightType))) + { + asMutable(resultType)->ty.emplace(singletonTypes->booleanType); + unblock(resultType); + return true; + } + + break; + // == and ~= always evaluate to a boolean, and impose no other constraints + // on their parameters. + case AstExprBinary::Op::CompareEq: + case AstExprBinary::Op::CompareNe: + asMutable(resultType)->ty.emplace(singletonTypes->booleanType); + unblock(resultType); return true; + // And evalutes to a boolean if the LHS is falsey, and the RHS type if LHS is + // truthy. + case AstExprBinary::Op::And: + asMutable(resultType)->ty.emplace(unionOfTypes(rightType, singletonTypes->booleanType, constraint->scope, false)); + unblock(resultType); + return true; + // Or evaluates to the LHS type if the LHS is truthy, and the RHS type if + // LHS is falsey. + case AstExprBinary::Op::Or: + asMutable(resultType)->ty.emplace(unionOfTypes(rightType, leftType, constraint->scope, true)); + unblock(resultType); + return true; + default: + iceReporter.ice("Unhandled AstExprBinary::Op for binary operation", constraint->location); + break; } - // TODO metatables, classes + // We failed to either evaluate a metamethod or invoke primitive behavior. + unify(leftType, errorRecoveryType(), constraint->scope); + unify(rightType, errorRecoveryType(), constraint->scope); + asMutable(resultType)->ty.emplace(errorRecoveryType()); + unblock(resultType); return true; } @@ -943,6 +1095,31 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNull callMm = findMetatableEntry(singletonTypes, errors, fn, "__call", constraint->location)) + { + std::vector args{fn}; + + for (TypeId arg : c.argsPack) + args.push_back(arg); + + TypeId instantiatedType = arena->addType(BlockedTypeVar{}); + TypeId inferredFnType = + arena->addType(FunctionTypeVar(TypeLevel{}, constraint->scope.get(), arena->addTypePack(TypePack{args, {}}), c.result)); + + // Alter the inner constraints. + LUAU_ASSERT(c.innerConstraints.size() == 2); + + asMutable(*c.innerConstraints.at(0)).c = InstantiationConstraint{instantiatedType, *callMm}; + asMutable(*c.innerConstraints.at(1)).c = SubtypeConstraint{inferredFnType, instantiatedType}; + + unsolvedConstraints.insert(end(unsolvedConstraints), begin(c.innerConstraints), end(c.innerConstraints)); + + asMutable(c.result)->ty.emplace(constraint->scope); + unblock(c.result); + return true; + } + const FunctionTypeVar* ftv = get(fn); bool usedMagic = false; @@ -1059,6 +1236,29 @@ bool ConstraintSolver::tryDispatch(const HasPropConstraint& c, NotNull constraint) +{ + // TODO: Figure out exact details on when refinements need to be blocked. + // It's possible that it never needs to be, since we can just use intersection types with the discriminant type? + + if (!constraint->scope->parent) + iceReporter.ice("No parent scope"); + + std::optional previousTy = constraint->scope->parent->lookup(c.def); + if (!previousTy) + iceReporter.ice("No previous type"); + + std::optional useTy = constraint->scope->lookup(c.def); + if (!useTy) + iceReporter.ice("The def is not bound to a type"); + + TypeId resultTy = follow(*useTy); + std::vector parts{*previousTy, c.discriminantType}; + asMutable(resultTy)->ty.emplace(std::move(parts)); + + return true; +} + bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const IterableConstraint& c, NotNull constraint, bool force) { auto block_ = [&](auto&& t) { @@ -1502,4 +1702,39 @@ TypePackId ConstraintSolver::errorRecoveryTypePack() const return singletonTypes->errorRecoveryTypePack(); } +TypeId ConstraintSolver::unionOfTypes(TypeId a, TypeId b, NotNull scope, bool unifyFreeTypes) +{ + a = follow(a); + b = follow(b); + + if (unifyFreeTypes && (get(a) || get(b))) + { + Unifier u{normalizer, Mode::Strict, scope, Location{}, Covariant}; + u.useScopes = true; + u.tryUnify(b, a); + + if (u.errors.empty()) + { + u.log.commit(); + return a; + } + else + { + return singletonTypes->errorRecoveryType(singletonTypes->anyType); + } + } + + if (*a == *b) + return a; + + std::vector types = reduceUnion({a, b}); + if (types.empty()) + return singletonTypes->neverType; + + if (types.size() == 1) + return types[0]; + + return arena->addType(UnionTypeVar{types}); +} + } // namespace Luau diff --git a/Analysis/src/DataFlowGraphBuilder.cpp b/Analysis/src/DataFlowGraphBuilder.cpp new file mode 100644 index 000000000..e2c4c2857 --- /dev/null +++ b/Analysis/src/DataFlowGraphBuilder.cpp @@ -0,0 +1,440 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/DataFlowGraphBuilder.h" + +#include "Luau/Error.h" + +LUAU_FASTFLAG(DebugLuauFreezeArena) +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) + +namespace Luau +{ + +std::optional DataFlowGraph::getDef(const AstExpr* expr) const +{ + if (auto def = astDefs.find(expr)) + return NotNull{*def}; + return std::nullopt; +} + +std::optional DataFlowGraph::getDef(const AstLocal* local) const +{ + if (auto def = localDefs.find(local)) + return NotNull{*def}; + return std::nullopt; +} + +std::optional DataFlowGraph::getDef(const Symbol& symbol) const +{ + if (symbol.local) + return getDef(symbol.local); + else + return std::nullopt; +} + +DataFlowGraph DataFlowGraphBuilder::build(AstStatBlock* block, NotNull handle) +{ + LUAU_ASSERT(FFlag::DebugLuauDeferredConstraintResolution); + + DataFlowGraphBuilder builder; + builder.handle = handle; + builder.visit(nullptr, block); // nullptr is the root DFG scope. + if (FFlag::DebugLuauFreezeArena) + builder.arena->allocator.freeze(); + return std::move(builder.graph); +} + +DfgScope* DataFlowGraphBuilder::childScope(DfgScope* scope) +{ + return scopes.emplace_back(new DfgScope{scope}).get(); +} + +std::optional DataFlowGraphBuilder::use(DfgScope* scope, Symbol symbol, AstExpr* e) +{ + for (DfgScope* current = scope; current; current = current->parent) + { + if (auto loc = current->bindings.find(symbol)) + { + graph.astDefs[e] = *loc; + return NotNull{*loc}; + } + } + + return std::nullopt; +} + +void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatBlock* b) +{ + DfgScope* child = childScope(scope); + return visitBlockWithoutChildScope(child, b); +} + +void DataFlowGraphBuilder::visitBlockWithoutChildScope(DfgScope* scope, AstStatBlock* b) +{ + for (AstStat* s : b->body) + visit(scope, s); +} + +void DataFlowGraphBuilder::visit(DfgScope* scope, AstStat* s) +{ + if (auto b = s->as()) + return visit(scope, b); + else if (auto i = s->as()) + return visit(scope, i); + else if (auto w = s->as()) + return visit(scope, w); + else if (auto r = s->as()) + return visit(scope, r); + else if (auto b = s->as()) + return visit(scope, b); + else if (auto c = s->as()) + return visit(scope, c); + else if (auto r = s->as()) + return visit(scope, r); + else if (auto e = s->as()) + return visit(scope, e); + else if (auto l = s->as()) + return visit(scope, l); + else if (auto f = s->as()) + return visit(scope, f); + else if (auto f = s->as()) + return visit(scope, f); + else if (auto a = s->as()) + return visit(scope, a); + else if (auto c = s->as()) + return visit(scope, c); + else if (auto f = s->as()) + return visit(scope, f); + else if (auto l = s->as()) + return visit(scope, l); + else if (auto t = s->as()) + return; // ok + else if (auto d = s->as()) + return; // ok + else if (auto d = s->as()) + return; // ok + else if (auto d = s->as()) + return; // ok + else if (auto d = s->as()) + return; // ok + else if (auto _ = s->as()) + return; // ok + else + handle->ice("Unknown AstStat in DataFlowGraphBuilder"); +} + +void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatIf* i) +{ + DfgScope* condScope = childScope(scope); + visitExpr(condScope, i->condition); + visit(condScope, i->thenbody); + + if (i->elsebody) + visit(scope, i->elsebody); +} + +void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatWhile* w) +{ + // TODO(controlflow): entry point has a back edge from exit point + DfgScope* whileScope = childScope(scope); + visitExpr(whileScope, w->condition); + visit(whileScope, w->body); +} + +void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatRepeat* r) +{ + // TODO(controlflow): entry point has a back edge from exit point + DfgScope* repeatScope = childScope(scope); // TODO: loop scope. + visitBlockWithoutChildScope(repeatScope, r->body); + visitExpr(repeatScope, r->condition); +} + +void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatBreak* b) +{ + // TODO: Control flow analysis + return; // ok +} + +void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatContinue* c) +{ + // TODO: Control flow analysis + return; // ok +} + +void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatReturn* r) +{ + // TODO: Control flow analysis + for (AstExpr* e : r->list) + visitExpr(scope, e); +} + +void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatExpr* e) +{ + visitExpr(scope, e->expr); +} + +void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatLocal* l) +{ + // TODO: alias tracking + for (AstExpr* e : l->values) + visitExpr(scope, e); + + for (AstLocal* local : l->vars) + { + DefId def = arena->freshDef(); + graph.localDefs[local] = def; + scope->bindings[local] = def; + } +} + +void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatFor* f) +{ + DfgScope* forScope = childScope(scope); // TODO: loop scope. + DefId def = arena->freshDef(); + graph.localDefs[f->var] = def; + scope->bindings[f->var] = def; + + // TODO(controlflow): entry point has a back edge from exit point + visit(forScope, f->body); +} + +void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatForIn* f) +{ + DfgScope* forScope = childScope(scope); // TODO: loop scope. + + for (AstLocal* local : f->vars) + { + DefId def = arena->freshDef(); + graph.localDefs[local] = def; + forScope->bindings[local] = def; + } + + // TODO(controlflow): entry point has a back edge from exit point + for (AstExpr* e : f->values) + visitExpr(forScope, e); + + visit(forScope, f->body); +} + +void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatAssign* a) +{ + for (AstExpr* r : a->values) + visitExpr(scope, r); + + for (AstExpr* l : a->vars) + { + AstExpr* root = l; + + bool isUpdatable = true; + while (true) + { + if (root->is() || root->is()) + break; + + AstExprIndexName* indexName = root->as(); + if (!indexName) + { + isUpdatable = false; + break; + } + + root = indexName->expr; + } + + if (isUpdatable) + { + // TODO global? + if (auto exprLocal = root->as()) + { + DefId def = arena->freshDef(); + graph.astDefs[exprLocal] = def; + + // Update the def in the scope that introduced the local. Not + // the current scope. + AstLocal* local = exprLocal->local; + DfgScope* s = scope; + while (s && !s->bindings.find(local)) + s = s->parent; + LUAU_ASSERT(s && s->bindings.find(local)); + s->bindings[local] = def; + } + } + + visitExpr(scope, l); // TODO: they point to a new def!! + } +} + +void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatCompoundAssign* c) +{ + // TODO(typestates): The lhs is being read and written to. This might or might not be annoying. + visitExpr(scope, c->value); +} + +void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatFunction* f) +{ + visitExpr(scope, f->name); + visitExpr(scope, f->func); +} + +void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatLocalFunction* l) +{ + DefId def = arena->freshDef(); + graph.localDefs[l->name] = def; + scope->bindings[l->name] = def; + + visitExpr(scope, l->func); +} + +ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExpr* e) +{ + if (auto g = e->as()) + return visitExpr(scope, g->expr); + else if (auto c = e->as()) + return {}; // ok + else if (auto c = e->as()) + return {}; // ok + else if (auto c = e->as()) + return {}; // ok + else if (auto c = e->as()) + return {}; // ok + else if (auto l = e->as()) + return visitExpr(scope, l); + else if (auto g = e->as()) + return visitExpr(scope, g); + else if (auto v = e->as()) + return {}; // ok + else if (auto c = e->as()) + return visitExpr(scope, c); + else if (auto i = e->as()) + return visitExpr(scope, i); + else if (auto i = e->as()) + return visitExpr(scope, i); + else if (auto f = e->as()) + return visitExpr(scope, f); + else if (auto t = e->as()) + return visitExpr(scope, t); + else if (auto u = e->as()) + return visitExpr(scope, u); + else if (auto b = e->as()) + return visitExpr(scope, b); + else if (auto t = e->as()) + return visitExpr(scope, t); + else if (auto i = e->as()) + return visitExpr(scope, i); + else if (auto i = e->as()) + return visitExpr(scope, i); + else if (auto _ = e->as()) + return {}; // ok + else + handle->ice("Unknown AstExpr in DataFlowGraphBuilder"); +} + +ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprLocal* l) +{ + return {use(scope, l->local, l)}; +} + +ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprGlobal* g) +{ + return {use(scope, g->name, g)}; +} + +ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprCall* c) +{ + visitExpr(scope, c->func); + + for (AstExpr* arg : c->args) + visitExpr(scope, arg); + + return {}; +} + +ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprIndexName* i) +{ + std::optional def = visitExpr(scope, i->expr).def; + if (!def) + return {}; + + // TODO: properties for the above def. + return {}; +} + +ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprIndexExpr* i) +{ + visitExpr(scope, i->expr); + visitExpr(scope, i->expr); + + if (i->index->as()) + { + // TODO: properties for the def + } + + return {}; +} + +ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprFunction* f) +{ + if (AstLocal* self = f->self) + { + DefId def = arena->freshDef(); + graph.localDefs[self] = def; + scope->bindings[self] = def; + } + + for (AstLocal* param : f->args) + { + DefId def = arena->freshDef(); + graph.localDefs[param] = def; + scope->bindings[param] = def; + } + + visit(scope, f->body); + + return {}; +} + +ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprTable* t) +{ + return {}; +} + +ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprUnary* u) +{ + visitExpr(scope, u->expr); + + return {}; +} + +ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprBinary* b) +{ + visitExpr(scope, b->left); + visitExpr(scope, b->right); + + return {}; +} + +ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprTypeAssertion* t) +{ + ExpressionFlowGraph result = visitExpr(scope, t->expr); + // TODO: visit type + return result; +} + +ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprIfElse* i) +{ + DfgScope* condScope = childScope(scope); + visitExpr(condScope, i->condition); + visitExpr(condScope, i->trueExpr); + + visitExpr(scope, i->falseExpr); + + return {}; +} + +ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprInterpString* i) +{ + for (AstExpr* e : i->expressions) + visitExpr(scope, e); + return {}; +} + +} // namespace Luau diff --git a/Analysis/src/Def.cpp b/Analysis/src/Def.cpp new file mode 100644 index 000000000..935301c86 --- /dev/null +++ b/Analysis/src/Def.cpp @@ -0,0 +1,12 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Def.h" + +namespace Luau +{ + +DefId DefArena::freshDef() +{ + return NotNull{allocator.allocate(Undefined{})}; +} + +} // namespace Luau diff --git a/Analysis/src/Error.cpp b/Analysis/src/Error.cpp index 4e9b68820..e55530036 100644 --- a/Analysis/src/Error.cpp +++ b/Analysis/src/Error.cpp @@ -7,8 +7,6 @@ #include -LUAU_FASTFLAGVARIABLE(LuauTypeMismatchModuleNameResolution, false) - static std::string wrongNumberOfArgsString( size_t expectedCount, std::optional maximumCount, size_t actualCount, const char* argPrefix = nullptr, bool isVariadic = false) { @@ -70,7 +68,7 @@ struct ErrorConverter { if (auto wantedDefinitionModule = getDefinitionModuleName(tm.wantedType)) { - if (FFlag::LuauTypeMismatchModuleNameResolution && fileResolver != nullptr) + if (fileResolver != nullptr) { std::string givenModuleName = fileResolver->getHumanReadableModuleName(*givenDefinitionModule); std::string wantedModuleName = fileResolver->getHumanReadableModuleName(*wantedDefinitionModule); @@ -96,14 +94,7 @@ struct ErrorConverter if (!tm.reason.empty()) result += tm.reason + " "; - if (FFlag::LuauTypeMismatchModuleNameResolution) - { - result += Luau::toString(*tm.error, TypeErrorToStringOptions{fileResolver}); - } - else - { - result += Luau::toString(*tm.error); - } + result += Luau::toString(*tm.error, TypeErrorToStringOptions{fileResolver}); } else if (!tm.reason.empty()) { diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 5705ac17f..8f2a3ebd6 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -1,11 +1,13 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Frontend.h" +#include "Luau/BuiltinDefinitions.h" #include "Luau/Clone.h" #include "Luau/Common.h" #include "Luau/Config.h" #include "Luau/ConstraintGraphBuilder.h" #include "Luau/ConstraintSolver.h" +#include "Luau/DataFlowGraphBuilder.h" #include "Luau/DcrLogger.h" #include "Luau/FileResolver.h" #include "Luau/Parser.h" @@ -15,7 +17,6 @@ #include "Luau/TypeChecker2.h" #include "Luau/TypeInfer.h" #include "Luau/Variant.h" -#include "Luau/BuiltinDefinitions.h" #include #include @@ -26,7 +27,6 @@ LUAU_FASTINT(LuauTarjanChildLimit) LUAU_FASTFLAG(LuauInferInNoCheckMode) LUAU_FASTFLAG(LuauNoMoreGlobalSingletonTypes) LUAU_FASTFLAGVARIABLE(LuauKnowsTheDataModel3, false) -LUAU_FASTFLAGVARIABLE(LuauAutocompleteDynamicLimits, false) LUAU_FASTINTVARIABLE(LuauAutocompleteCheckTimeoutMs, 100) LUAU_FASTFLAGVARIABLE(DebugLuauDeferredConstraintResolution, false) LUAU_FASTFLAG(DebugLuauLogSolverToJson); @@ -488,23 +488,19 @@ CheckResult Frontend::check(const ModuleName& name, std::optional 0) - typeCheckerForAutocomplete.instantiationChildLimit = - std::max(1, int(FInt::LuauTarjanChildLimit * sourceNode.autocompleteLimitsMult)); - else - typeCheckerForAutocomplete.instantiationChildLimit = std::nullopt; - - if (FInt::LuauTypeInferIterationLimit > 0) - typeCheckerForAutocomplete.unifierIterationLimit = - std::max(1, int(FInt::LuauTypeInferIterationLimit * sourceNode.autocompleteLimitsMult)); - else - typeCheckerForAutocomplete.unifierIterationLimit = std::nullopt; - } + // TODO: This is a dirty ad hoc solution for autocomplete timeouts + // We are trying to dynamically adjust our existing limits to lower total typechecking time under the limit + // so that we'll have type information for the whole file at lower quality instead of a full abort in the middle + if (FInt::LuauTarjanChildLimit > 0) + typeCheckerForAutocomplete.instantiationChildLimit = std::max(1, int(FInt::LuauTarjanChildLimit * sourceNode.autocompleteLimitsMult)); + else + typeCheckerForAutocomplete.instantiationChildLimit = std::nullopt; + + if (FInt::LuauTypeInferIterationLimit > 0) + typeCheckerForAutocomplete.unifierIterationLimit = + std::max(1, int(FInt::LuauTypeInferIterationLimit * sourceNode.autocompleteLimitsMult)); + else + typeCheckerForAutocomplete.unifierIterationLimit = std::nullopt; ModulePtr moduleForAutocomplete = FFlag::DebugLuauDeferredConstraintResolution ? check(sourceModule, mode, environmentScope, requireCycles, /*forAutocomplete*/ true) @@ -518,10 +514,9 @@ CheckResult Frontend::check(const ModuleName& name, std::optional mr{forAutocomplete ? &moduleResolverForAutocomplete : &moduleResolver}; const ScopePtr& globalScope{forAutocomplete ? typeCheckerForAutocomplete.globalScope : typeChecker.globalScope}; Normalizer normalizer{&result->internalTypes, singletonTypes, NotNull{&typeChecker.unifierState}}; ConstraintGraphBuilder cgb{ - sourceModule.name, result, &result->internalTypes, mr, singletonTypes, NotNull(&iceHandler), globalScope, logger.get()}; + sourceModule.name, + result, + &result->internalTypes, + mr, + singletonTypes, + NotNull(&iceHandler), + globalScope, + logger.get(), + NotNull{&dfg}, + }; + cgb.visit(sourceModule.root); result->errors = std::move(cgb.errors); diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index 31a089a43..0412f0077 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -60,36 +60,6 @@ bool isWithinComment(const SourceModule& sourceModule, Position pos) return contains(pos, *iter); } -struct ForceNormal : TypeVarOnceVisitor -{ - const TypeArena* typeArena = nullptr; - - ForceNormal(const TypeArena* typeArena) - : typeArena(typeArena) - { - } - - bool visit(TypeId ty) override - { - if (ty->owningArena != typeArena) - return false; - - asMutable(ty)->normal = true; - return true; - } - - bool visit(TypeId ty, const FreeTypeVar& ftv) override - { - visit(ty); - return true; - } - - bool visit(TypePackId tp, const FreeTypePack& ftp) override - { - return true; - } -}; - struct ClonePublicInterface : Substitution { NotNull singletonTypes; @@ -241,8 +211,6 @@ void Module::clonePublicInterface(NotNull singletonTypes, Intern moduleScope->varargPack = varargPack; } - ForceNormal forceNormal{&interfaceTypes}; - if (exportedTypeBindings) { for (auto& [name, tf] : *exportedTypeBindings) @@ -262,7 +230,6 @@ void Module::clonePublicInterface(NotNull singletonTypes, Intern { auto t = asMutable(ty); t->ty = AnyTypeVar{}; - t->normal = true; } } } diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp index 81114b76b..cea159c36 100644 --- a/Analysis/src/Normalize.cpp +++ b/Analysis/src/Normalize.cpp @@ -16,11 +16,11 @@ LUAU_FASTFLAGVARIABLE(DebugLuauCheckNormalizeInvariant, false) // This could theoretically be 2000 on amd64, but x86 requires this. LUAU_FASTINTVARIABLE(LuauNormalizeIterationLimit, 1200); LUAU_FASTINTVARIABLE(LuauNormalizeCacheLimit, 100000); -LUAU_FASTINT(LuauTypeInferRecursionLimit); LUAU_FASTFLAGVARIABLE(LuauNormalizeCombineTableFix, false); LUAU_FASTFLAGVARIABLE(LuauTypeNormalization2, false); LUAU_FASTFLAG(LuauUnknownAndNeverType) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) +LUAU_FASTFLAG(LuauOverloadedFunctionSubtypingPerf); namespace Luau { @@ -1269,19 +1269,35 @@ std::optional Normalizer::intersectionOfFunctions(TypeId here, TypeId th return std::nullopt; if (hftv->genericPacks != tftv->genericPacks) return std::nullopt; - if (hftv->retTypes != tftv->retTypes) - return std::nullopt; - std::optional argTypes = unionOfTypePacks(hftv->argTypes, tftv->argTypes); - if (!argTypes) + TypePackId argTypes; + TypePackId retTypes; + + if (hftv->retTypes == tftv->retTypes) + { + std::optional argTypesOpt = unionOfTypePacks(hftv->argTypes, tftv->argTypes); + if (!argTypesOpt) + return std::nullopt; + argTypes = *argTypesOpt; + retTypes = hftv->retTypes; + } + else if (FFlag::LuauOverloadedFunctionSubtypingPerf && hftv->argTypes == tftv->argTypes) + { + std::optional retTypesOpt = intersectionOfTypePacks(hftv->argTypes, tftv->argTypes); + if (!retTypesOpt) + return std::nullopt; + argTypes = hftv->argTypes; + retTypes = *retTypesOpt; + } + else return std::nullopt; - if (*argTypes == hftv->argTypes) + if (argTypes == hftv->argTypes && retTypes == hftv->retTypes) return here; - if (*argTypes == tftv->argTypes) + if (argTypes == tftv->argTypes && retTypes == tftv->retTypes) return there; - FunctionTypeVar result{*argTypes, hftv->retTypes}; + FunctionTypeVar result{argTypes, retTypes}; result.generics = hftv->generics; result.genericPacks = hftv->genericPacks; return arena->addType(std::move(result)); @@ -1762,610 +1778,4 @@ bool isSubtype( return ok; } -template -static bool areNormal_(const T& t, const std::unordered_set& seen, InternalErrorReporter& ice) -{ - int count = 0; - auto isNormal = [&](TypeId ty) { - ++count; - if (count >= FInt::LuauNormalizeIterationLimit) - ice.ice("Luau::areNormal hit iteration limit"); - - return ty->normal; - }; - - return std::all_of(begin(t), end(t), isNormal); -} - -static bool areNormal(const std::vector& types, const std::unordered_set& seen, InternalErrorReporter& ice) -{ - return areNormal_(types, seen, ice); -} - -static bool areNormal(TypePackId tp, const std::unordered_set& seen, InternalErrorReporter& ice) -{ - tp = follow(tp); - if (get(tp)) - return false; - - auto [head, tail] = flatten(tp); - - if (!areNormal_(head, seen, ice)) - return false; - - if (!tail) - return true; - - if (auto vtp = get(*tail)) - return vtp->ty->normal || follow(vtp->ty)->normal || seen.find(asMutable(vtp->ty)) != seen.end(); - - return true; -} - -#define CHECK_ITERATION_LIMIT(...) \ - do \ - { \ - if (iterationLimit > FInt::LuauNormalizeIterationLimit) \ - { \ - limitExceeded = true; \ - return __VA_ARGS__; \ - } \ - ++iterationLimit; \ - } while (false) - -struct Normalize final : TypeVarVisitor -{ - using TypeVarVisitor::Set; - - Normalize(TypeArena& arena, NotNull scope, NotNull singletonTypes, InternalErrorReporter& ice) - : arena(arena) - , scope(scope) - , singletonTypes(singletonTypes) - , ice(ice) - { - } - - TypeArena& arena; - NotNull scope; - NotNull singletonTypes; - InternalErrorReporter& ice; - - int iterationLimit = 0; - bool limitExceeded = false; - - bool visit(TypeId ty, const FreeTypeVar&) override - { - LUAU_ASSERT(!ty->normal); - return false; - } - - bool visit(TypeId ty, const BoundTypeVar& btv) override - { - // A type could be considered normal when it is in the stack, but we will eventually find out it is not normal as normalization progresses. - // So we need to avoid eagerly saying that this bound type is normal if the thing it is bound to is in the stack. - if (seen.find(asMutable(btv.boundTo)) != seen.end()) - return false; - - // It should never be the case that this TypeVar is normal, but is bound to a non-normal type, except in nontrivial cases. - LUAU_ASSERT(!ty->normal || ty->normal == btv.boundTo->normal); - - if (!ty->normal) - asMutable(ty)->normal = btv.boundTo->normal; - return !ty->normal; - } - - bool visit(TypeId ty, const PrimitiveTypeVar&) override - { - LUAU_ASSERT(ty->normal); - return false; - } - - bool visit(TypeId ty, const GenericTypeVar&) override - { - if (!ty->normal) - asMutable(ty)->normal = true; - return false; - } - - bool visit(TypeId ty, const ErrorTypeVar&) override - { - if (!ty->normal) - asMutable(ty)->normal = true; - return false; - } - - bool visit(TypeId ty, const UnknownTypeVar&) override - { - if (!ty->normal) - asMutable(ty)->normal = true; - return false; - } - - bool visit(TypeId ty, const NeverTypeVar&) override - { - if (!ty->normal) - asMutable(ty)->normal = true; - return false; - } - - bool visit(TypeId ty, const ConstrainedTypeVar& ctvRef) override - { - CHECK_ITERATION_LIMIT(false); - LUAU_ASSERT(!ty->normal); - - ConstrainedTypeVar* ctv = const_cast(&ctvRef); - - std::vector parts = std::move(ctv->parts); - - // We might transmute, so it's not safe to rely on the builtin traversal logic of visitTypeVar - for (TypeId part : parts) - traverse(part); - - std::vector newParts = normalizeUnion(parts); - ctv->parts = std::move(newParts); - - return false; - } - - bool visit(TypeId ty, const FunctionTypeVar& ftv) override - { - CHECK_ITERATION_LIMIT(false); - - if (ty->normal) - return false; - - traverse(ftv.argTypes); - traverse(ftv.retTypes); - - asMutable(ty)->normal = areNormal(ftv.argTypes, seen, ice) && areNormal(ftv.retTypes, seen, ice); - - return false; - } - - bool visit(TypeId ty, const TableTypeVar& ttv) override - { - CHECK_ITERATION_LIMIT(false); - - if (ty->normal) - return false; - - bool normal = true; - - auto checkNormal = [&](TypeId t) { - // if t is on the stack, it is possible that this type is normal. - // If t is not normal and it is not on the stack, this type is definitely not normal. - if (!t->normal && seen.find(asMutable(t)) == seen.end()) - normal = false; - }; - - if (ttv.boundTo) - { - traverse(*ttv.boundTo); - asMutable(ty)->normal = (*ttv.boundTo)->normal; - return false; - } - - for (const auto& [_name, prop] : ttv.props) - { - traverse(prop.type); - checkNormal(prop.type); - } - - if (ttv.indexer) - { - traverse(ttv.indexer->indexType); - checkNormal(ttv.indexer->indexType); - traverse(ttv.indexer->indexResultType); - checkNormal(ttv.indexer->indexResultType); - } - - // An unsealed table can never be normal, ditto for free tables iff the type it is bound to is also not normal. - if (ttv.state == TableState::Generic || ttv.state == TableState::Sealed || (ttv.state == TableState::Free && follow(ty)->normal)) - asMutable(ty)->normal = normal; - - return false; - } - - bool visit(TypeId ty, const MetatableTypeVar& mtv) override - { - CHECK_ITERATION_LIMIT(false); - - if (ty->normal) - return false; - - traverse(mtv.table); - traverse(mtv.metatable); - - asMutable(ty)->normal = mtv.table->normal && mtv.metatable->normal; - - return false; - } - - bool visit(TypeId ty, const ClassTypeVar& ctv) override - { - if (!ty->normal) - asMutable(ty)->normal = true; - return false; - } - - bool visit(TypeId ty, const AnyTypeVar&) override - { - LUAU_ASSERT(ty->normal); - return false; - } - - bool visit(TypeId ty, const UnionTypeVar& utvRef) override - { - CHECK_ITERATION_LIMIT(false); - - if (ty->normal) - return false; - - UnionTypeVar* utv = &const_cast(utvRef); - - // We might transmute, so it's not safe to rely on the builtin traversal logic of visitTypeVar - for (TypeId option : utv->options) - traverse(option); - - std::vector newOptions = normalizeUnion(utv->options); - - const bool normal = areNormal(newOptions, seen, ice); - - LUAU_ASSERT(!newOptions.empty()); - - if (newOptions.size() == 1) - *asMutable(ty) = BoundTypeVar{newOptions[0]}; - else - utv->options = std::move(newOptions); - - asMutable(ty)->normal = normal; - - return false; - } - - bool visit(TypeId ty, const IntersectionTypeVar& itvRef) override - { - CHECK_ITERATION_LIMIT(false); - - if (ty->normal) - return false; - - IntersectionTypeVar* itv = &const_cast(itvRef); - - std::vector oldParts = itv->parts; - IntersectionTypeVar newIntersection; - - for (TypeId part : oldParts) - traverse(part); - - std::vector tables; - for (TypeId part : oldParts) - { - part = follow(part); - if (get(part)) - tables.push_back(part); - else - { - Replacer replacer{&arena, nullptr, nullptr}; // FIXME this is super super WEIRD - combineIntoIntersection(replacer, &newIntersection, part); - } - } - - // Don't allocate a new table if there's just one in the intersection. - if (tables.size() == 1) - newIntersection.parts.push_back(tables[0]); - else if (!tables.empty()) - { - const TableTypeVar* first = get(tables[0]); - LUAU_ASSERT(first); - - TypeId newTable = arena.addType(TableTypeVar{first->state, first->level}); - TableTypeVar* ttv = getMutable(newTable); - for (TypeId part : tables) - { - // Intuition: If combineIntoTable() needs to clone a table, any references to 'part' are cyclic and need - // to be rewritten to point at 'newTable' in the clone. - Replacer replacer{&arena, part, newTable}; - combineIntoTable(replacer, ttv, part); - } - - newIntersection.parts.push_back(newTable); - } - - itv->parts = std::move(newIntersection.parts); - - asMutable(ty)->normal = areNormal(itv->parts, seen, ice); - - if (itv->parts.size() == 1) - { - TypeId part = itv->parts[0]; - *asMutable(ty) = BoundTypeVar{part}; - } - - return false; - } - - std::vector normalizeUnion(const std::vector& options) - { - if (options.size() == 1) - return options; - - std::vector result; - - for (TypeId part : options) - { - // AnyTypeVar always win the battle no matter what we do, so we're done. - if (FFlag::LuauUnknownAndNeverType && get(follow(part))) - return {part}; - - combineIntoUnion(result, part); - } - - return result; - } - - void combineIntoUnion(std::vector& result, TypeId ty) - { - ty = follow(ty); - if (auto utv = get(ty)) - { - for (TypeId t : utv) - { - // AnyTypeVar always win the battle no matter what we do, so we're done. - if (FFlag::LuauUnknownAndNeverType && get(t)) - { - result = {t}; - return; - } - - combineIntoUnion(result, t); - } - - return; - } - - for (TypeId& part : result) - { - if (isSubtype(ty, part, scope, singletonTypes, ice)) - return; // no need to do anything - else if (isSubtype(part, ty, scope, singletonTypes, ice)) - { - part = ty; // replace the less general type by the more general one - return; - } - } - - result.push_back(ty); - } - - /** - * @param replacer knows how to clone a type such that any recursive references point at the new containing type. - * @param result is an intersection that is safe for us to mutate in-place. - */ - void combineIntoIntersection(Replacer& replacer, IntersectionTypeVar* result, TypeId ty) - { - // Note: this check guards against running out of stack space - // so if you increase the size of a stack frame, you'll need to decrease the limit. - CHECK_ITERATION_LIMIT(); - - ty = follow(ty); - if (auto itv = get(ty)) - { - for (TypeId part : itv->parts) - combineIntoIntersection(replacer, result, part); - return; - } - - // Let's say that the last part of our result intersection is always a table, if any table is part of this intersection - if (get(ty)) - { - if (result->parts.empty()) - result->parts.push_back(arena.addType(TableTypeVar{TableState::Sealed, TypeLevel{}})); - - TypeId theTable = result->parts.back(); - - if (!get(follow(theTable))) - { - result->parts.push_back(arena.addType(TableTypeVar{TableState::Sealed, TypeLevel{}})); - theTable = result->parts.back(); - } - - TypeId newTable = replacer.smartClone(theTable); - result->parts.back() = newTable; - - combineIntoTable(replacer, getMutable(newTable), ty); - } - else if (auto ftv = get(ty)) - { - bool merged = false; - for (TypeId& part : result->parts) - { - if (isSubtype(part, ty, scope, singletonTypes, ice)) - { - merged = true; - break; // no need to do anything - } - else if (isSubtype(ty, part, scope, singletonTypes, ice)) - { - merged = true; - part = ty; // replace the less general type by the more general one - break; - } - } - - if (!merged) - result->parts.push_back(ty); - } - else - result->parts.push_back(ty); - } - - TableState combineTableStates(TableState lhs, TableState rhs) - { - if (lhs == rhs) - return lhs; - - if (lhs == TableState::Free || rhs == TableState::Free) - return TableState::Free; - - if (lhs == TableState::Unsealed || rhs == TableState::Unsealed) - return TableState::Unsealed; - - return lhs; - } - - /** - * @param replacer gives us a way to clone a type such that recursive references are rewritten to the new - * "containing" type. - * @param table always points into a table that is safe for us to mutate. - */ - void combineIntoTable(Replacer& replacer, TableTypeVar* table, TypeId ty) - { - // Note: this check guards against running out of stack space - // so if you increase the size of a stack frame, you'll need to decrease the limit. - CHECK_ITERATION_LIMIT(); - - LUAU_ASSERT(table); - - ty = follow(ty); - - TableTypeVar* tyTable = getMutable(ty); - LUAU_ASSERT(tyTable); - - for (const auto& [propName, prop] : tyTable->props) - { - if (auto it = table->props.find(propName); it != table->props.end()) - { - /** - * If we are going to recursively merge intersections of tables, we need to ensure that we never mutate - * a table that comes from somewhere else in the type graph. - * - * smarClone() does some nice things for us: It will perform a clone that is as shallow as possible - * while still rewriting any cyclic references back to the new 'root' table. - * - * replacer also keeps a mapping of types that have previously been copied, so we have the added - * advantage here of knowing that, whether or not a new copy was actually made, the resulting TypeVar is - * safe for us to mutate in-place. - */ - TypeId clone = replacer.smartClone(it->second.type); - it->second.type = combine(replacer, clone, prop.type); - } - else - table->props.insert({propName, prop}); - } - - if (tyTable->indexer) - { - if (table->indexer) - { - table->indexer->indexType = combine(replacer, replacer.smartClone(tyTable->indexer->indexType), table->indexer->indexType); - table->indexer->indexResultType = - combine(replacer, replacer.smartClone(tyTable->indexer->indexResultType), table->indexer->indexResultType); - } - else - { - table->indexer = - TableIndexer{replacer.smartClone(tyTable->indexer->indexType), replacer.smartClone(tyTable->indexer->indexResultType)}; - } - } - - table->state = combineTableStates(table->state, tyTable->state); - table->level = max(table->level, tyTable->level); - } - - /** - * @param a is always cloned by the caller. It is safe to mutate in-place. - * @param b will never be mutated. - */ - TypeId combine(Replacer& replacer, TypeId a, TypeId b) - { - b = follow(b); - - if (FFlag::LuauNormalizeCombineTableFix && a == b) - return a; - - if (!get(a) && !get(a)) - { - if (!FFlag::LuauNormalizeCombineTableFix && a == b) - return a; - else - return arena.addType(IntersectionTypeVar{{a, b}}); - } - - if (auto itv = getMutable(a)) - { - combineIntoIntersection(replacer, itv, b); - return a; - } - else if (auto ttv = getMutable(a)) - { - if (FFlag::LuauNormalizeCombineTableFix && !get(b)) - return arena.addType(IntersectionTypeVar{{a, b}}); - combineIntoTable(replacer, ttv, b); - return a; - } - - LUAU_ASSERT(!"Impossible"); - LUAU_UNREACHABLE(); - } -}; - -#undef CHECK_ITERATION_LIMIT - -/** - * @returns A tuple of TypeId and a success indicator. (true indicates that the normalization completed successfully) - */ -std::pair normalize( - TypeId ty, NotNull scope, TypeArena& arena, NotNull singletonTypes, InternalErrorReporter& ice) -{ - CloneState state; - if (FFlag::DebugLuauCopyBeforeNormalizing) - (void)clone(ty, arena, state); - - Normalize n{arena, scope, singletonTypes, ice}; - n.traverse(ty); - - return {ty, !n.limitExceeded}; -} - -// TODO: Think about using a temporary arena and cloning types out of it so that we -// reclaim memory used by wantonly allocated intermediate types here. -// The main wrinkle here is that we don't want clone() to copy a type if the source and dest -// arena are the same. -std::pair normalize(TypeId ty, NotNull module, NotNull singletonTypes, InternalErrorReporter& ice) -{ - return normalize(ty, NotNull{module->getModuleScope().get()}, module->internalTypes, singletonTypes, ice); -} - -std::pair normalize(TypeId ty, const ModulePtr& module, NotNull singletonTypes, InternalErrorReporter& ice) -{ - return normalize(ty, NotNull{module.get()}, singletonTypes, ice); -} - -/** - * @returns A tuple of TypeId and a success indicator. (true indicates that the normalization completed successfully) - */ -std::pair normalize( - TypePackId tp, NotNull scope, TypeArena& arena, NotNull singletonTypes, InternalErrorReporter& ice) -{ - CloneState state; - if (FFlag::DebugLuauCopyBeforeNormalizing) - (void)clone(tp, arena, state); - - Normalize n{arena, scope, singletonTypes, ice}; - n.traverse(tp); - - return {tp, !n.limitExceeded}; -} - -std::pair normalize(TypePackId tp, NotNull module, NotNull singletonTypes, InternalErrorReporter& ice) -{ - return normalize(tp, NotNull{module->getModuleScope().get()}, module->internalTypes, singletonTypes, ice); -} - -std::pair normalize(TypePackId tp, const ModulePtr& module, NotNull singletonTypes, InternalErrorReporter& ice) -{ - return normalize(tp, NotNull{module.get()}, singletonTypes, ice); -} - } // namespace Luau diff --git a/Analysis/src/Quantify.cpp b/Analysis/src/Quantify.cpp index e4c069bd6..e9de094b8 100644 --- a/Analysis/src/Quantify.cpp +++ b/Analysis/src/Quantify.cpp @@ -57,29 +57,6 @@ struct Quantifier final : TypeVarOnceVisitor return false; } - bool visit(TypeId ty, const ConstrainedTypeVar&) override - { - ConstrainedTypeVar* ctv = getMutable(ty); - - seenMutableType = true; - - if (!level.subsumes(ctv->level)) - return false; - - std::vector opts = std::move(ctv->parts); - - // We might transmute, so it's not safe to rely on the builtin traversal logic - for (TypeId opt : opts) - traverse(opt); - - if (opts.size() == 1) - *asMutable(ty) = BoundTypeVar{opts[0]}; - else - *asMutable(ty) = UnionTypeVar{std::move(opts)}; - - return false; - } - bool visit(TypeId ty, const TableTypeVar&) override { LUAU_ASSERT(getMutable(ty)); diff --git a/Analysis/src/Scope.cpp b/Analysis/src/Scope.cpp index 9a7d36090..84925f790 100644 --- a/Analysis/src/Scope.cpp +++ b/Analysis/src/Scope.cpp @@ -27,6 +27,44 @@ void Scope::addBuiltinTypeBinding(const Name& name, const TypeFun& tyFun) builtinTypeNames.insert(name); } +std::optional Scope::lookup(Symbol sym) const +{ + auto r = const_cast(this)->lookupEx(sym); + if (r) + return r->first; + else + return std::nullopt; +} + +std::optional> Scope::lookupEx(Symbol sym) +{ + Scope* s = this; + + while (true) + { + auto it = s->bindings.find(sym); + if (it != s->bindings.end()) + return std::pair{it->second.typeId, s}; + + if (s->parent) + s = s->parent.get(); + else + return std::nullopt; + } +} + +// TODO: We might kill Scope::lookup(Symbol) once data flow is fully fleshed out with type states and control flow analysis. +std::optional Scope::lookup(DefId def) const +{ + for (const Scope* current = this; current; current = current->parent.get()) + { + if (auto ty = current->dcrRefinements.find(def)) + return *ty; + } + + return std::nullopt; +} + std::optional Scope::lookupType(const Name& name) { const Scope* scope = this; @@ -111,23 +149,6 @@ std::optional Scope::linearSearchForBinding(const std::string& name, bo return std::nullopt; } -std::optional Scope::lookup(Symbol sym) -{ - Scope* s = this; - - while (true) - { - auto it = s->bindings.find(sym); - if (it != s->bindings.end()) - return it->second.typeId; - - if (s->parent) - s = s->parent.get(); - else - return std::nullopt; - } -} - bool subsumesStrict(Scope* left, Scope* right) { while (right) diff --git a/Analysis/src/Substitution.cpp b/Analysis/src/Substitution.cpp index 2137d73ee..20ed34f6c 100644 --- a/Analysis/src/Substitution.cpp +++ b/Analysis/src/Substitution.cpp @@ -73,11 +73,6 @@ void Tarjan::visitChildren(TypeId ty, int index) for (TypeId part : itv->parts) visitChild(part); } - else if (const ConstrainedTypeVar* ctv = get(ty)) - { - for (TypeId part : ctv->parts) - visitChild(part); - } else if (const PendingExpansionTypeVar* petv = get(ty)) { for (TypeId a : petv->typeArguments) @@ -97,6 +92,10 @@ void Tarjan::visitChildren(TypeId ty, int index) if (ctv->metatable) visitChild(*ctv->metatable); } + else if (const NegationTypeVar* ntv = get(ty)) + { + visitChild(ntv->ty); + } } void Tarjan::visitChildren(TypePackId tp, int index) @@ -605,11 +604,6 @@ void Substitution::replaceChildren(TypeId ty) for (TypeId& part : itv->parts) part = replace(part); } - else if (ConstrainedTypeVar* ctv = getMutable(ty)) - { - for (TypeId& part : ctv->parts) - part = replace(part); - } else if (PendingExpansionTypeVar* petv = getMutable(ty)) { for (TypeId& a : petv->typeArguments) @@ -629,6 +623,10 @@ void Substitution::replaceChildren(TypeId ty) if (ctv->metatable) ctv->metatable = replace(*ctv->metatable); } + else if (NegationTypeVar* ntv = getMutable(ty)) + { + ntv->ty = replace(ntv->ty); + } } void Substitution::replaceChildren(TypePackId tp) diff --git a/Analysis/src/ToDot.cpp b/Analysis/src/ToDot.cpp index 0d989ca03..68fa53931 100644 --- a/Analysis/src/ToDot.cpp +++ b/Analysis/src/ToDot.cpp @@ -237,15 +237,6 @@ void StateDot::visitChildren(TypeId ty, int index) finishNodeLabel(ty); finishNode(); } - else if (const ConstrainedTypeVar* ctv = get(ty)) - { - formatAppend(result, "ConstrainedTypeVar %d", index); - finishNodeLabel(ty); - finishNode(); - - for (TypeId part : ctv->parts) - visitChild(part, index); - } else if (get(ty)) { formatAppend(result, "ErrorTypeVar %d", index); diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index f5ab9494c..5897ca211 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -399,29 +399,6 @@ struct TypeVarStringifier state.emit(state.getName(ty)); } - void operator()(TypeId, const ConstrainedTypeVar& ctv) - { - state.result.invalid = true; - - state.emit("["); - if (FFlag::DebugLuauVerboseTypeNames) - state.emit(ctv.level); - state.emit("["); - - bool first = true; - for (TypeId ty : ctv.parts) - { - if (first) - first = false; - else - state.emit("|"); - - stringify(ty); - } - - state.emit("]]"); - } - void operator()(TypeId, const BlockedTypeVar& btv) { state.emit("*blocked-"); @@ -870,6 +847,28 @@ struct TypeVarStringifier { state.emit("never"); } + + void operator()(TypeId ty, const UseTypeVar&) + { + stringify(follow(ty)); + } + + void operator()(TypeId, const NegationTypeVar& ntv) + { + state.emit("~"); + + // The precedence of `~` should be less than `|` and `&`. + TypeId followed = follow(ntv.ty); + bool parens = get(followed) || get(followed); + + if (parens) + state.emit("("); + + stringify(ntv.ty); + + if (parens) + state.emit(")"); + } }; struct TypePackStringifier @@ -1432,7 +1431,7 @@ std::string generateName(size_t i) std::string toString(const Constraint& constraint, ToStringOptions& opts) { - auto go = [&opts](auto&& c) { + auto go = [&opts](auto&& c) -> std::string { using T = std::decay_t; // TODO: Inline and delete this function when clipping FFlag::LuauFixNameMaps @@ -1516,6 +1515,10 @@ std::string toString(const Constraint& constraint, ToStringOptions& opts) { return tos(c.resultType, opts) + " ~ hasProp " + tos(c.subjectType, opts) + ", \"" + c.prop + "\""; } + else if constexpr (std::is_same_v) + { + return "TODO"; + } else static_assert(always_false_v, "Non-exhaustive constraint switch"); }; diff --git a/Analysis/src/TxnLog.cpp b/Analysis/src/TxnLog.cpp index 06bde1950..034aeaeca 100644 --- a/Analysis/src/TxnLog.cpp +++ b/Analysis/src/TxnLog.cpp @@ -251,7 +251,7 @@ PendingType* TxnLog::bindTable(TypeId ty, std::optional newBoundTo) PendingType* TxnLog::changeLevel(TypeId ty, TypeLevel newLevel) { - LUAU_ASSERT(get(ty) || get(ty) || get(ty) || get(ty)); + LUAU_ASSERT(get(ty) || get(ty) || get(ty)); PendingType* newTy = queue(ty); if (FreeTypeVar* ftv = Luau::getMutable(newTy)) @@ -267,11 +267,6 @@ PendingType* TxnLog::changeLevel(TypeId ty, TypeLevel newLevel) { ftv->level = newLevel; } - else if (ConstrainedTypeVar* ctv = Luau::getMutable(newTy)) - { - if (FFlag::LuauUnknownAndNeverType) - ctv->level = newLevel; - } return newTy; } @@ -291,7 +286,7 @@ PendingTypePack* TxnLog::changeLevel(TypePackId tp, TypeLevel newLevel) PendingType* TxnLog::changeScope(TypeId ty, NotNull newScope) { - LUAU_ASSERT(get(ty) || get(ty) || get(ty) || get(ty)); + LUAU_ASSERT(get(ty) || get(ty) || get(ty)); PendingType* newTy = queue(ty); if (FreeTypeVar* ftv = Luau::getMutable(newTy)) @@ -307,10 +302,6 @@ PendingType* TxnLog::changeScope(TypeId ty, NotNull newScope) { ftv->scope = newScope; } - else if (ConstrainedTypeVar* ctv = Luau::getMutable(newTy)) - { - ctv->scope = newScope; - } return newTy; } diff --git a/Analysis/src/TypeAttach.cpp b/Analysis/src/TypeAttach.cpp index 84494083b..f2613cae2 100644 --- a/Analysis/src/TypeAttach.cpp +++ b/Analysis/src/TypeAttach.cpp @@ -104,16 +104,6 @@ class TypeRehydrationVisitor return allocator->alloc(Location(), std::nullopt, AstName("*pending-expansion*")); } - AstType* operator()(const ConstrainedTypeVar& ctv) - { - AstArray types; - types.size = ctv.parts.size(); - types.data = static_cast(allocator->allocate(sizeof(AstType*) * ctv.parts.size())); - for (size_t i = 0; i < ctv.parts.size(); ++i) - types.data[i] = Luau::visit(*this, ctv.parts[i]->ty); - return allocator->alloc(Location(), types); - } - AstType* operator()(const SingletonTypeVar& stv) { if (const BooleanSingleton* bs = get(&stv)) @@ -348,6 +338,17 @@ class TypeRehydrationVisitor { return allocator->alloc(Location(), std::nullopt, AstName{"never"}); } + AstType* operator()(const UseTypeVar& utv) + { + std::optional ty = utv.scope->lookup(utv.def); + LUAU_ASSERT(ty); + return Luau::visit(*this, (*ty)->ty); + } + AstType* operator()(const NegationTypeVar& ntv) + { + // FIXME: do the same thing we do with ErrorTypeVar + throw std::runtime_error("Cannot convert NegationTypeVar into AstNode"); + } private: Allocator* allocator; diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index 4753a7c21..bd220e9c0 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -5,6 +5,7 @@ #include "Luau/AstQuery.h" #include "Luau/Clone.h" #include "Luau/Instantiation.h" +#include "Luau/Metamethods.h" #include "Luau/Normalize.h" #include "Luau/ToString.h" #include "Luau/TxnLog.h" @@ -62,6 +63,23 @@ struct StackPusher } }; +static std::optional getIdentifierOfBaseVar(AstExpr* node) +{ + if (AstExprGlobal* expr = node->as()) + return expr->name.value; + + if (AstExprLocal* expr = node->as()) + return expr->local->name.value; + + if (AstExprIndexExpr* expr = node->as()) + return getIdentifierOfBaseVar(expr->expr); + + if (AstExprIndexName* expr = node->as()) + return getIdentifierOfBaseVar(expr->expr); + + return std::nullopt; +} + struct TypeChecker2 { NotNull singletonTypes; @@ -750,7 +768,7 @@ struct TypeChecker2 TypeId actualType = lookupType(string); TypeId stringType = singletonTypes->stringType; - if (!isSubtype(stringType, actualType, stack.back(), singletonTypes, ice, /* anyIsTop */ false)) + if (!isSubtype(actualType, stringType, stack.back(), singletonTypes, ice, /* anyIsTop */ false)) { reportError(TypeMismatch{actualType, stringType}, string->location); } @@ -783,26 +801,55 @@ struct TypeChecker2 TypePackId expectedRetType = lookupPack(call); TypeId functionType = lookupType(call->func); - LUAU_ASSERT(functionType); + TypeId testFunctionType = functionType; + TypePack args; if (get(functionType) || get(functionType)) return; - - // TODO: Lots of other types are callable: intersections of functions - // and things with the __call metamethod. - if (!get(functionType)) + else if (std::optional callMm = findMetatableEntry(singletonTypes, module->errors, functionType, "__call", call->func->location)) + { + if (get(follow(*callMm))) + { + if (std::optional instantiatedCallMm = instantiation.substitute(*callMm)) + { + args.head.push_back(functionType); + testFunctionType = follow(*instantiatedCallMm); + } + else + { + reportError(UnificationTooComplex{}, call->func->location); + return; + } + } + else + { + // TODO: This doesn't flag the __call metamethod as the problem + // very clearly. + reportError(CannotCallNonFunction{*callMm}, call->func->location); + return; + } + } + else if (get(functionType)) + { + if (std::optional instantiatedFunctionType = instantiation.substitute(functionType)) + { + testFunctionType = *instantiatedFunctionType; + } + else + { + reportError(UnificationTooComplex{}, call->func->location); + return; + } + } + else { reportError(CannotCallNonFunction{functionType}, call->func->location); return; } - TypeId instantiatedFunctionType = follow(instantiation.substitute(functionType).value_or(nullptr)); - - TypePack args; for (AstExpr* arg : call->args) { - TypeId argTy = module->astTypes[arg]; - LUAU_ASSERT(argTy); + TypeId argTy = lookupType(arg); args.head.push_back(argTy); } @@ -810,7 +857,7 @@ struct TypeChecker2 FunctionTypeVar ftv{argsTp, expectedRetType}; TypeId expectedType = arena.addType(ftv); - if (!isSubtype(instantiatedFunctionType, expectedType, stack.back(), singletonTypes, ice, /* anyIsTop */ false)) + if (!isSubtype(testFunctionType, expectedType, stack.back(), singletonTypes, ice, /* anyIsTop */ false)) { CloneState cloneState; expectedType = clone(expectedType, module->internalTypes, cloneState); @@ -893,9 +940,204 @@ struct TypeChecker2 void visit(AstExprBinary* expr) { - // TODO! visit(expr->left); visit(expr->right); + + NotNull scope = stack.back(); + + bool isEquality = expr->op == AstExprBinary::Op::CompareEq || expr->op == AstExprBinary::Op::CompareNe; + bool isComparison = expr->op >= AstExprBinary::Op::CompareEq && expr->op <= AstExprBinary::Op::CompareGe; + bool isLogical = expr->op == AstExprBinary::Op::And || expr->op == AstExprBinary::Op::Or; + + TypeId leftType = lookupType(expr->left); + TypeId rightType = lookupType(expr->right); + + if (expr->op == AstExprBinary::Op::Or) + { + leftType = stripNil(singletonTypes, module->internalTypes, leftType); + } + + bool isStringOperation = isString(leftType) && isString(rightType); + + if (get(leftType) || get(leftType) || get(rightType) || get(rightType)) + return; + + if ((get(leftType) || get(leftType)) && !isEquality && !isLogical) + { + auto name = getIdentifierOfBaseVar(expr->left); + reportError(CannotInferBinaryOperation{expr->op, name, + isComparison ? CannotInferBinaryOperation::OpKind::Comparison : CannotInferBinaryOperation::OpKind::Operation}, + expr->location); + return; + } + + if (auto it = kBinaryOpMetamethods.find(expr->op); it != kBinaryOpMetamethods.end()) + { + std::optional leftMt = getMetatable(leftType, singletonTypes); + std::optional rightMt = getMetatable(rightType, singletonTypes); + + bool matches = leftMt == rightMt; + if (isEquality && !matches) + { + auto testUnion = [&matches, singletonTypes = this->singletonTypes](const UnionTypeVar* utv, std::optional otherMt) { + for (TypeId option : utv) + { + if (getMetatable(follow(option), singletonTypes) == otherMt) + { + matches = true; + break; + } + } + }; + + if (const UnionTypeVar* utv = get(leftType); utv && rightMt) + { + testUnion(utv, rightMt); + } + + if (const UnionTypeVar* utv = get(rightType); utv && leftMt && !matches) + { + testUnion(utv, leftMt); + } + } + + if (!matches && isComparison) + { + reportError(GenericError{format("Types %s and %s cannot be compared with %s because they do not have the same metatable", + toString(leftType).c_str(), toString(rightType).c_str(), toString(expr->op).c_str())}, + expr->location); + + return; + } + + std::optional mm; + if (std::optional leftMm = findMetatableEntry(singletonTypes, module->errors, leftType, it->second, expr->left->location)) + mm = leftMm; + else if (std::optional rightMm = findMetatableEntry(singletonTypes, module->errors, rightType, it->second, expr->right->location)) + mm = rightMm; + + if (mm) + { + if (const FunctionTypeVar* ftv = get(*mm)) + { + TypePackId expectedArgs; + // For >= and > we invoke __lt and __le respectively with + // swapped argument ordering. + if (expr->op == AstExprBinary::Op::CompareGe || expr->op == AstExprBinary::Op::CompareGt) + { + expectedArgs = module->internalTypes.addTypePack({rightType, leftType}); + } + else + { + expectedArgs = module->internalTypes.addTypePack({leftType, rightType}); + } + + reportErrors(tryUnify(scope, expr->location, ftv->argTypes, expectedArgs)); + + if (expr->op == AstExprBinary::CompareEq || expr->op == AstExprBinary::CompareNe || expr->op == AstExprBinary::CompareGe || + expr->op == AstExprBinary::CompareGt || expr->op == AstExprBinary::Op::CompareLe || expr->op == AstExprBinary::Op::CompareLt) + { + TypePackId expectedRets = module->internalTypes.addTypePack({singletonTypes->booleanType}); + if (!isSubtype(ftv->retTypes, expectedRets, scope, singletonTypes, ice)) + { + reportError(GenericError{format("Metamethod '%s' must return type 'boolean'", it->second)}, expr->location); + } + } + else if (!first(ftv->retTypes)) + { + reportError(GenericError{format("Metamethod '%s' must return a value", it->second)}, expr->location); + } + } + else + { + reportError(CannotCallNonFunction{*mm}, expr->location); + } + + return; + } + // If this is a string comparison, or a concatenation of strings, we + // want to fall through to primitive behavior. + else if (!isEquality && !(isStringOperation && (expr->op == AstExprBinary::Op::Concat || isComparison))) + { + if (leftMt || rightMt) + { + if (isComparison) + { + reportError(GenericError{format( + "Types '%s' and '%s' cannot be compared with %s because neither type's metatable has a '%s' metamethod", + toString(leftType).c_str(), toString(rightType).c_str(), toString(expr->op).c_str(), it->second)}, + expr->location); + } + else + { + reportError(GenericError{format( + "Operator %s is not applicable for '%s' and '%s' because neither type's metatable has a '%s' metamethod", + toString(expr->op).c_str(), toString(leftType).c_str(), toString(rightType).c_str(), it->second)}, + expr->location); + } + + return; + } + else if (!leftMt && !rightMt && (get(leftType) || get(rightType))) + { + if (isComparison) + { + reportError(GenericError{format("Types '%s' and '%s' cannot be compared with %s because neither type has a metatable", + toString(leftType).c_str(), toString(rightType).c_str(), toString(expr->op).c_str())}, + expr->location); + } + else + { + reportError(GenericError{format("Operator %s is not applicable for '%s' and '%s' because neither type has a metatable", + toString(expr->op).c_str(), toString(leftType).c_str(), toString(rightType).c_str())}, + expr->location); + } + + return; + } + } + } + + switch (expr->op) + { + case AstExprBinary::Op::Add: + case AstExprBinary::Op::Sub: + case AstExprBinary::Op::Mul: + case AstExprBinary::Op::Div: + case AstExprBinary::Op::Pow: + case AstExprBinary::Op::Mod: + reportErrors(tryUnify(scope, expr->left->location, leftType, singletonTypes->numberType)); + reportErrors(tryUnify(scope, expr->right->location, rightType, singletonTypes->numberType)); + + break; + case AstExprBinary::Op::Concat: + reportErrors(tryUnify(scope, expr->left->location, leftType, singletonTypes->stringType)); + reportErrors(tryUnify(scope, expr->right->location, rightType, singletonTypes->stringType)); + + break; + case AstExprBinary::Op::CompareGe: + case AstExprBinary::Op::CompareGt: + case AstExprBinary::Op::CompareLe: + case AstExprBinary::Op::CompareLt: + if (isNumber(leftType)) + reportErrors(tryUnify(scope, expr->right->location, rightType, singletonTypes->numberType)); + else if (isString(leftType)) + reportErrors(tryUnify(scope, expr->right->location, rightType, singletonTypes->stringType)); + else + reportError(GenericError{format("Types '%s' and '%s' cannot be compared with relational operator %s", toString(leftType).c_str(), + toString(rightType).c_str(), toString(expr->op).c_str())}, + expr->location); + + break; + case AstExprBinary::Op::And: + case AstExprBinary::Op::Or: + case AstExprBinary::Op::CompareEq: + case AstExprBinary::Op::CompareNe: + break; + default: + // Unhandled AstExprBinary::Op possibility. + LUAU_ASSERT(false); + } } void visit(AstExprTypeAssertion* expr) diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index b806edb7c..d5c6b2c46 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -31,7 +31,6 @@ LUAU_FASTINTVARIABLE(LuauTypeInferTypePackLoopLimit, 5000) LUAU_FASTINTVARIABLE(LuauCheckRecursionLimit, 300) LUAU_FASTINTVARIABLE(LuauVisitRecursionLimit, 500) LUAU_FASTFLAG(LuauKnowsTheDataModel3) -LUAU_FASTFLAG(LuauAutocompleteDynamicLimits) LUAU_FASTFLAG(LuauTypeNormalization2) LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) LUAU_FASTFLAGVARIABLE(LuauReturnAnyInsteadOfICE, false) // Eventually removed as false. @@ -280,11 +279,8 @@ ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mo iceHandler->moduleName = module.name; normalizer.arena = ¤tModule->internalTypes; - if (FFlag::LuauAutocompleteDynamicLimits) - { - unifierState.counters.recursionLimit = FInt::LuauTypeInferRecursionLimit; - unifierState.counters.iterationLimit = unifierIterationLimit ? *unifierIterationLimit : FInt::LuauTypeInferIterationLimit; - } + unifierState.counters.recursionLimit = FInt::LuauTypeInferRecursionLimit; + unifierState.counters.iterationLimit = unifierIterationLimit ? *unifierIterationLimit : FInt::LuauTypeInferIterationLimit; ScopePtr parentScope = environmentScope.value_or(globalScope); ScopePtr moduleScope = std::make_shared(parentScope); @@ -773,16 +769,6 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatRepeat& statement) checkExpr(repScope, *statement.condition); } -void TypeChecker::unifyLowerBound(TypePackId subTy, TypePackId superTy, TypeLevel demotedLevel, const ScopePtr& scope, const Location& location) -{ - Unifier state = mkUnifier(scope, location); - state.unifyLowerBound(subTy, superTy, demotedLevel); - - state.log.commit(); - - reportErrors(state.errors); -} - struct Demoter : Substitution { Demoter(TypeArena* arena) @@ -2091,39 +2077,6 @@ std::optional TypeChecker::getIndexTypeFromTypeImpl( return std::nullopt; } -std::vector TypeChecker::reduceUnion(const std::vector& types) -{ - std::vector result; - for (TypeId t : types) - { - t = follow(t); - if (get(t)) - continue; - - if (get(t) || get(t)) - return {t}; - - if (const UnionTypeVar* utv = get(t)) - { - for (TypeId ty : utv) - { - ty = follow(ty); - if (get(ty)) - continue; - if (get(ty) || get(ty)) - return {ty}; - - if (result.end() == std::find(result.begin(), result.end(), ty)) - result.push_back(ty); - } - } - else if (std::find(result.begin(), result.end(), t) == result.end()) - result.push_back(t); - } - - return result; -} - std::optional TypeChecker::tryStripUnionFromNil(TypeId ty) { if (const UnionTypeVar* utv = get(ty)) @@ -4597,7 +4550,7 @@ TypeId TypeChecker::instantiate(const ScopePtr& scope, TypeId ty, Location locat Instantiation instantiation{log, ¤tModule->internalTypes, scope->level, /*scope*/ nullptr}; - if (FFlag::LuauAutocompleteDynamicLimits && instantiationChildLimit) + if (instantiationChildLimit) instantiation.childLimit = *instantiationChildLimit; std::optional instantiated = instantiation.substitute(ty); diff --git a/Analysis/src/TypeUtils.cpp b/Analysis/src/TypeUtils.cpp index 688c87672..72597c4a1 100644 --- a/Analysis/src/TypeUtils.cpp +++ b/Analysis/src/TypeUtils.cpp @@ -6,6 +6,8 @@ #include "Luau/ToString.h" #include "Luau/TypeInfer.h" +#include + namespace Luau { @@ -146,18 +148,15 @@ std::optional getIndexTypeFromType(const ScopePtr& scope, ErrorVec& erro return std::nullopt; } + goodOptions = reduceUnion(goodOptions); + if (goodOptions.empty()) return singletonTypes->neverType; if (goodOptions.size() == 1) return goodOptions[0]; - // TODO: inefficient. - TypeId result = arena->addType(UnionTypeVar{std::move(goodOptions)}); - auto [ty, ok] = normalize(result, NotNull{scope.get()}, *arena, singletonTypes, handle); - if (!ok && addErrors) - errors.push_back(TypeError{location, NormalizationTooComplex{}}); - return ok ? ty : singletonTypes->anyType; + return arena->addType(UnionTypeVar{std::move(goodOptions)}); } else if (const IntersectionTypeVar* itv = get(type)) { @@ -264,4 +263,79 @@ std::vector flatten(TypeArena& arena, NotNull singletonT return result; } +std::vector reduceUnion(const std::vector& types) +{ + std::vector result; + for (TypeId t : types) + { + t = follow(t); + if (get(t)) + continue; + + if (get(t) || get(t)) + return {t}; + + if (const UnionTypeVar* utv = get(t)) + { + for (TypeId ty : utv) + { + ty = follow(ty); + if (get(ty)) + continue; + if (get(ty) || get(ty)) + return {ty}; + + if (result.end() == std::find(result.begin(), result.end(), ty)) + result.push_back(ty); + } + } + else if (std::find(result.begin(), result.end(), t) == result.end()) + result.push_back(t); + } + + return result; +} + +static std::optional tryStripUnionFromNil(TypeArena& arena, TypeId ty) +{ + if (const UnionTypeVar* utv = get(ty)) + { + if (!std::any_of(begin(utv), end(utv), isNil)) + return ty; + + std::vector result; + + for (TypeId option : utv) + { + if (!isNil(option)) + result.push_back(option); + } + + if (result.empty()) + return std::nullopt; + + return result.size() == 1 ? result[0] : arena.addType(UnionTypeVar{std::move(result)}); + } + + return std::nullopt; +} + +TypeId stripNil(NotNull singletonTypes, TypeArena& arena, TypeId ty) +{ + ty = follow(ty); + + if (get(ty)) + { + std::optional cleaned = tryStripUnionFromNil(arena, ty); + + // If there is no union option without 'nil' + if (!cleaned) + return singletonTypes->nilType; + + return follow(*cleaned); + } + + return follow(ty); +} + } // namespace Luau diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index bcdaff7d2..19d3d2669 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -57,6 +57,13 @@ TypeId follow(TypeId t, std::function mapper) return btv->boundTo; else if (auto ttv = get(mapper(ty))) return ttv->boundTo; + else if (auto utv = get(mapper(ty))) + { + std::optional ty = utv->scope->lookup(utv->def); + if (!ty) + throw std::runtime_error("UseTypeVar must map to another TypeId"); + return *ty; + } else return std::nullopt; }; @@ -760,6 +767,8 @@ SingletonTypes::SingletonTypes() , unknownType(arena->addType(TypeVar{UnknownTypeVar{}, /*persistent*/ true})) , neverType(arena->addType(TypeVar{NeverTypeVar{}, /*persistent*/ true})) , errorType(arena->addType(TypeVar{ErrorTypeVar{}, /*persistent*/ true})) + , falsyType(arena->addType(TypeVar{UnionTypeVar{{falseType, nilType}}, /*persistent*/ true})) + , truthyType(arena->addType(TypeVar{NegationTypeVar{falsyType}, /*persistent*/ true})) , anyTypePack(arena->addTypePack(TypePackVar{VariadicTypePack{anyType}, /*persistent*/ true})) , neverTypePack(arena->addTypePack(TypePackVar{VariadicTypePack{neverType}, /*persistent*/ true})) , uninhabitableTypePack(arena->addTypePack({neverType}, neverTypePack)) @@ -896,7 +905,6 @@ void persist(TypeId ty) continue; asMutable(t)->persistent = true; - asMutable(t)->normal = true; // all persistent types are assumed to be normal if (auto btv = get(t)) queue.push_back(btv->boundTo); @@ -933,11 +941,6 @@ void persist(TypeId ty) for (TypeId opt : itv->parts) queue.push_back(opt); } - else if (auto ctv = get(t)) - { - for (TypeId opt : ctv->parts) - queue.push_back(opt); - } else if (auto mtv = get(t)) { queue.push_back(mtv->table); @@ -990,8 +993,6 @@ const TypeLevel* getLevel(TypeId ty) return &ttv->level; else if (auto ftv = get(ty)) return &ftv->level; - else if (auto ctv = get(ty)) - return &ctv->level; else return nullptr; } @@ -1056,11 +1057,6 @@ const std::vector& getTypes(const IntersectionTypeVar* itv) return itv->parts; } -const std::vector& getTypes(const ConstrainedTypeVar* ctv) -{ - return ctv->parts; -} - UnionTypeVarIterator begin(const UnionTypeVar* utv) { return UnionTypeVarIterator{utv}; @@ -1081,17 +1077,6 @@ IntersectionTypeVarIterator end(const IntersectionTypeVar* itv) return IntersectionTypeVarIterator{}; } -ConstrainedTypeVarIterator begin(const ConstrainedTypeVar* ctv) -{ - return ConstrainedTypeVarIterator{ctv}; -} - -ConstrainedTypeVarIterator end(const ConstrainedTypeVar* ctv) -{ - return ConstrainedTypeVarIterator{}; -} - - static std::vector parseFormatString(TypeChecker& typechecker, const char* data, size_t size) { const char* options = "cdiouxXeEfgGqs*"; diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 42fcd2fda..e23e6161c 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -13,16 +13,13 @@ #include -LUAU_FASTINT(LuauTypeInferRecursionLimit); LUAU_FASTINT(LuauTypeInferTypePackLoopLimit); -LUAU_FASTINT(LuauTypeInferIterationLimit); -LUAU_FASTFLAG(LuauAutocompleteDynamicLimits) -LUAU_FASTINTVARIABLE(LuauTypeInferLowerBoundsIterationLimit, 2000); LUAU_FASTFLAG(LuauErrorRecoveryType); LUAU_FASTFLAG(LuauUnknownAndNeverType) LUAU_FASTFLAGVARIABLE(LuauSubtypeNormalizer, false); LUAU_FASTFLAGVARIABLE(LuauScalarShapeSubtyping, false) LUAU_FASTFLAGVARIABLE(LuauInstantiateInSubtyping, false) +LUAU_FASTFLAGVARIABLE(LuauOverloadedFunctionSubtypingPerf, false); LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) @@ -95,15 +92,6 @@ struct PromoteTypeLevels final : TypeVarOnceVisitor return true; } - bool visit(TypeId ty, const ConstrainedTypeVar&) override - { - if (!FFlag::LuauUnknownAndNeverType) - return visit(ty); - - promote(ty, log.getMutable(ty)); - return true; - } - bool visit(TypeId ty, const FunctionTypeVar&) override { // Type levels of types from other modules are already global, so we don't need to promote anything inside @@ -368,26 +356,14 @@ void Unifier::tryUnify(TypeId subTy, TypeId superTy, bool isFunctionCall, bool i void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool isIntersection) { - RecursionLimiter _ra(&sharedState.counters.recursionCount, - FFlag::LuauAutocompleteDynamicLimits ? sharedState.counters.recursionLimit : FInt::LuauTypeInferRecursionLimit); + RecursionLimiter _ra(&sharedState.counters.recursionCount, sharedState.counters.recursionLimit); ++sharedState.counters.iterationCount; - if (FFlag::LuauAutocompleteDynamicLimits) - { - if (sharedState.counters.iterationLimit > 0 && sharedState.counters.iterationLimit < sharedState.counters.iterationCount) - { - reportError(TypeError{location, UnificationTooComplex{}}); - return; - } - } - else + if (sharedState.counters.iterationLimit > 0 && sharedState.counters.iterationLimit < sharedState.counters.iterationCount) { - if (FInt::LuauTypeInferIterationLimit > 0 && FInt::LuauTypeInferIterationLimit < sharedState.counters.iterationCount) - { - reportError(TypeError{location, UnificationTooComplex{}}); - return; - } + reportError(TypeError{location, UnificationTooComplex{}}); + return; } superTy = log.follow(superTy); @@ -396,9 +372,6 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool if (superTy == subTy) return; - if (log.get(superTy)) - return tryUnifyWithConstrainedSuperTypeVar(subTy, superTy); - auto superFree = log.getMutable(superTy); auto subFree = log.getMutable(subTy); @@ -520,9 +493,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool size_t errorCount = errors.size(); - if (log.get(subTy)) - tryUnifyWithConstrainedSubTypeVar(subTy, superTy); - else if (const UnionTypeVar* subUnion = log.getMutable(subTy)) + if (const UnionTypeVar* subUnion = log.getMutable(subTy)) { tryUnifyUnionWithType(subTy, subUnion, superTy); } @@ -1011,10 +982,17 @@ TypePackId Unifier::tryApplyOverloadedFunction(TypeId function, const Normalized log.concat(std::move(innerState.log)); if (result) { + if (FFlag::LuauOverloadedFunctionSubtypingPerf) + { + innerState.log.clear(); + innerState.tryUnify_(*result, ftv->retTypes); + } + if (FFlag::LuauOverloadedFunctionSubtypingPerf && innerState.errors.empty()) + log.concat(std::move(innerState.log)); // Annoyingly, since we don't support intersection of generic type packs, // the intersection may fail. We rather arbitrarily use the first matching overload // in that case. - if (std::optional intersect = normalizer->intersectionOfTypePacks(*result, ftv->retTypes)) + else if (std::optional intersect = normalizer->intersectionOfTypePacks(*result, ftv->retTypes)) result = intersect; } else @@ -1214,26 +1192,14 @@ void Unifier::tryUnify(TypePackId subTp, TypePackId superTp, bool isFunctionCall */ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCall) { - RecursionLimiter _ra(&sharedState.counters.recursionCount, - FFlag::LuauAutocompleteDynamicLimits ? sharedState.counters.recursionLimit : FInt::LuauTypeInferRecursionLimit); + RecursionLimiter _ra(&sharedState.counters.recursionCount, sharedState.counters.recursionLimit); ++sharedState.counters.iterationCount; - if (FFlag::LuauAutocompleteDynamicLimits) - { - if (sharedState.counters.iterationLimit > 0 && sharedState.counters.iterationLimit < sharedState.counters.iterationCount) - { - reportError(TypeError{location, UnificationTooComplex{}}); - return; - } - } - else + if (sharedState.counters.iterationLimit > 0 && sharedState.counters.iterationLimit < sharedState.counters.iterationCount) { - if (FInt::LuauTypeInferIterationLimit > 0 && FInt::LuauTypeInferIterationLimit < sharedState.counters.iterationCount) - { - reportError(TypeError{location, UnificationTooComplex{}}); - return; - } + reportError(TypeError{location, UnificationTooComplex{}}); + return; } superTp = log.follow(superTp); @@ -2314,186 +2280,6 @@ std::optional Unifier::findTablePropertyRespectingMeta(TypeId lhsType, N return Luau::findTablePropertyRespectingMeta(singletonTypes, errors, lhsType, name, location); } -void Unifier::tryUnifyWithConstrainedSubTypeVar(TypeId subTy, TypeId superTy) -{ - const ConstrainedTypeVar* subConstrained = get(subTy); - if (!subConstrained) - ice("tryUnifyWithConstrainedSubTypeVar received non-ConstrainedTypeVar subTy!"); - - const std::vector& subTyParts = subConstrained->parts; - - // A | B <: T if A <: T and B <: T - bool failed = false; - std::optional unificationTooComplex; - - const size_t count = subTyParts.size(); - - for (size_t i = 0; i < count; ++i) - { - TypeId type = subTyParts[i]; - Unifier innerState = makeChildUnifier(); - innerState.tryUnify_(type, superTy); - - if (i == count - 1) - log.concat(std::move(innerState.log)); - - ++i; - - if (auto e = hasUnificationTooComplex(innerState.errors)) - unificationTooComplex = e; - - if (!innerState.errors.empty()) - { - failed = true; - break; - } - } - - if (unificationTooComplex) - reportError(*unificationTooComplex); - else if (failed) - reportError(TypeError{location, TypeMismatch{superTy, subTy}}); - else - log.replace(subTy, BoundTypeVar{superTy}); -} - -void Unifier::tryUnifyWithConstrainedSuperTypeVar(TypeId subTy, TypeId superTy) -{ - ConstrainedTypeVar* superC = log.getMutable(superTy); - if (!superC) - ice("tryUnifyWithConstrainedSuperTypeVar received non-ConstrainedTypeVar superTy!"); - - // subTy could be a - // table - // metatable - // class - // function - // primitive - // free - // generic - // intersection - // union - // Do we really just tack it on? I think we might! - // We can certainly do some deduplication. - // Is there any point to deducing Player|Instance when we could just reduce to Instance? - // Is it actually ok to have multiple free types in a single intersection? What if they are later unified into the same type? - // Maybe we do a simplification step during quantification. - - auto it = std::find(superC->parts.begin(), superC->parts.end(), subTy); - if (it != superC->parts.end()) - return; - - superC->parts.push_back(subTy); -} - -void Unifier::unifyLowerBound(TypePackId subTy, TypePackId superTy, TypeLevel demotedLevel) -{ - // The duplication between this and regular typepack unification is tragic. - - auto superIter = begin(superTy, &log); - auto superEndIter = end(superTy); - - auto subIter = begin(subTy, &log); - auto subEndIter = end(subTy); - - int count = FInt::LuauTypeInferLowerBoundsIterationLimit; - - for (; subIter != subEndIter; ++subIter) - { - if (0 >= --count) - ice("Internal recursion counter limit exceeded in Unifier::unifyLowerBound"); - - if (superIter != superEndIter) - { - tryUnify_(*subIter, *superIter); - ++superIter; - continue; - } - - if (auto t = superIter.tail()) - { - TypePackId tailPack = follow(*t); - - if (log.get(tailPack) && occursCheck(tailPack, subTy)) - return; - - FreeTypePack* freeTailPack = log.getMutable(tailPack); - if (!freeTailPack) - return; - - TypePack* tp = getMutable(log.replace(tailPack, TypePack{})); - - for (; subIter != subEndIter; ++subIter) - { - tp->head.push_back(types->addType(ConstrainedTypeVar{demotedLevel, {follow(*subIter)}})); - } - - tp->tail = subIter.tail(); - } - - return; - } - - if (superIter != superEndIter) - { - if (auto subTail = subIter.tail()) - { - TypePackId subTailPack = follow(*subTail); - if (get(subTailPack)) - { - TypePack* tp = getMutable(log.replace(subTailPack, TypePack{})); - - for (; superIter != superEndIter; ++superIter) - tp->head.push_back(*superIter); - } - else if (const VariadicTypePack* subVariadic = log.getMutable(subTailPack)) - { - while (superIter != superEndIter) - { - tryUnify_(subVariadic->ty, *superIter); - ++superIter; - } - } - } - else - { - while (superIter != superEndIter) - { - if (!isOptional(*superIter)) - { - errors.push_back(TypeError{location, CountMismatch{size(superTy), std::nullopt, size(subTy), CountMismatch::Return}}); - return; - } - ++superIter; - } - } - - return; - } - - // Both iters are at their respective tails - auto subTail = subIter.tail(); - auto superTail = superIter.tail(); - if (subTail && superTail) - tryUnify(*subTail, *superTail); - else if (subTail) - { - const FreeTypePack* freeSubTail = log.getMutable(*subTail); - if (freeSubTail) - { - log.replace(*subTail, TypePack{}); - } - } - else if (superTail) - { - const FreeTypePack* freeSuperTail = log.getMutable(*superTail); - if (freeSuperTail) - { - log.replace(*superTail, TypePack{}); - } - } -} - bool Unifier::occursCheck(TypeId needle, TypeId haystack) { sharedState.tempSeenTy.clear(); @@ -2503,8 +2289,7 @@ bool Unifier::occursCheck(TypeId needle, TypeId haystack) bool Unifier::occursCheck(DenseHashSet& seen, TypeId needle, TypeId haystack) { - RecursionLimiter _ra(&sharedState.counters.recursionCount, - FFlag::LuauAutocompleteDynamicLimits ? sharedState.counters.recursionLimit : FInt::LuauTypeInferRecursionLimit); + RecursionLimiter _ra(&sharedState.counters.recursionCount, sharedState.counters.recursionLimit); bool occurrence = false; @@ -2547,11 +2332,6 @@ bool Unifier::occursCheck(DenseHashSet& seen, TypeId needle, TypeId hays for (TypeId ty : a->parts) check(ty); } - else if (auto a = log.getMutable(haystack)) - { - for (TypeId ty : a->parts) - check(ty); - } return occurrence; } @@ -2579,8 +2359,7 @@ bool Unifier::occursCheck(DenseHashSet& seen, TypePackId needle, Typ if (!log.getMutable(needle)) ice("Expected needle pack to be free"); - RecursionLimiter _ra(&sharedState.counters.recursionCount, - FFlag::LuauAutocompleteDynamicLimits ? sharedState.counters.recursionLimit : FInt::LuauTypeInferRecursionLimit); + RecursionLimiter _ra(&sharedState.counters.recursionCount, sharedState.counters.recursionLimit); while (!log.getMutable(haystack)) { diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index c20c08471..7150b18fc 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -905,6 +905,25 @@ AstStat* Parser::parseDeclaration(const Location& start) { props.push_back(parseDeclaredClassMethod()); } + else if (lexer.current().type == '[') + { + const Lexeme begin = lexer.current(); + nextLexeme(); // [ + + std::optional> chars = parseCharArray(); + + expectMatchAndConsume(']', begin); + expectAndConsume(':', "property type annotation"); + AstType* type = parseTypeAnnotation(); + + // TODO: since AstName conains a char*, it can't contain null + bool containsNull = chars && (strnlen(chars->data, chars->size) < chars->size); + + if (chars && !containsNull) + props.push_back(AstDeclaredClassProp{AstName(chars->data), type, false}); + else + report(begin.location, "String literal contains malformed escape sequence"); + } else { Name propName = parseName("property name"); diff --git a/CLI/Analyze.cpp b/CLI/Analyze.cpp index 7e4c5691c..6257e2f3a 100644 --- a/CLI/Analyze.cpp +++ b/CLI/Analyze.cpp @@ -14,7 +14,6 @@ #endif LUAU_FASTFLAG(DebugLuauTimeTracing) -LUAU_FASTFLAG(LuauTypeMismatchModuleNameResolution) enum class ReportFormat { @@ -55,11 +54,9 @@ static void reportError(const Luau::Frontend& frontend, ReportFormat format, con if (const Luau::SyntaxError* syntaxError = Luau::get_if(&error.data)) report(format, humanReadableName.c_str(), error.location, "SyntaxError", syntaxError->message.c_str()); - else if (FFlag::LuauTypeMismatchModuleNameResolution) + else report(format, humanReadableName.c_str(), error.location, "TypeError", Luau::toString(error, Luau::TypeErrorToStringOptions{frontend.fileResolver}).c_str()); - else - report(format, humanReadableName.c_str(), error.location, "TypeError", Luau::toString(error).c_str()); } static void reportWarning(ReportFormat format, const char* name, const Luau::LintWarning& warning) diff --git a/CLI/Repl.cpp b/CLI/Repl.cpp index aecddf383..a9dd8970a 100644 --- a/CLI/Repl.cpp +++ b/CLI/Repl.cpp @@ -16,6 +16,7 @@ #include "isocline.h" +#include #include #ifdef _WIN32 @@ -49,6 +50,8 @@ enum class CompileFormat Binary, Remarks, Codegen, + CodegenVerbose, + CodegenNull, Null }; @@ -673,21 +676,33 @@ static void reportError(const char* name, const Luau::CompileError& error) report(name, error.getLocation(), "CompileError", error.what()); } -static std::string getCodegenAssembly(const char* name, const std::string& bytecode) +static std::string getCodegenAssembly(const char* name, const std::string& bytecode, Luau::CodeGen::AssemblyOptions options) { std::unique_ptr globalState(luaL_newstate(), lua_close); lua_State* L = globalState.get(); - setupState(L); - if (luau_load(L, name, bytecode.data(), bytecode.size(), 0) == 0) - return Luau::CodeGen::getAssemblyText(L, -1); + return Luau::CodeGen::getAssembly(L, -1, options); fprintf(stderr, "Error loading bytecode %s\n", name); return ""; } -static bool compileFile(const char* name, CompileFormat format) +static void annotateInstruction(void* context, std::string& text, int fid, int instid) +{ + Luau::BytecodeBuilder& bcb = *(Luau::BytecodeBuilder*)context; + + bcb.annotateInstruction(text, fid, instid); +} + +struct CompileStats +{ + size_t lines; + size_t bytecode; + size_t codegen; +}; + +static bool compileFile(const char* name, CompileFormat format, CompileStats& stats) { std::optional source = readFile(name); if (!source) @@ -696,9 +711,12 @@ static bool compileFile(const char* name, CompileFormat format) return false; } + stats.lines += std::count(source->begin(), source->end(), '\n'); + try { Luau::BytecodeBuilder bcb; + Luau::CodeGen::AssemblyOptions options = {format == CompileFormat::CodegenNull, format == CompileFormat::Codegen, annotateInstruction, &bcb}; if (format == CompileFormat::Text) { @@ -711,8 +729,15 @@ static bool compileFile(const char* name, CompileFormat format) bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Source | Luau::BytecodeBuilder::Dump_Remarks); bcb.setDumpSource(*source); } + else if (format == CompileFormat::Codegen || format == CompileFormat::CodegenVerbose) + { + bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Source | Luau::BytecodeBuilder::Dump_Locals | + Luau::BytecodeBuilder::Dump_Remarks); + bcb.setDumpSource(*source); + } Luau::compileOrThrow(bcb, *source, copts()); + stats.bytecode += bcb.getBytecode().size(); switch (format) { @@ -726,7 +751,11 @@ static bool compileFile(const char* name, CompileFormat format) fwrite(bcb.getBytecode().data(), 1, bcb.getBytecode().size(), stdout); break; case CompileFormat::Codegen: - printf("%s", getCodegenAssembly(name, bcb.getBytecode()).c_str()); + case CompileFormat::CodegenVerbose: + printf("%s", getCodegenAssembly(name, bcb.getBytecode(), options).c_str()); + break; + case CompileFormat::CodegenNull: + stats.codegen += getCodegenAssembly(name, bcb.getBytecode(), options).size(); break; case CompileFormat::Null: break; @@ -755,7 +784,7 @@ static void displayHelp(const char* argv0) printf("\n"); printf("Available modes:\n"); printf(" omitted: compile and run input files one by one\n"); - printf(" --compile[=format]: compile input files and output resulting formatted bytecode (binary, text, remarks, codegen or null)\n"); + printf(" --compile[=format]: compile input files and output resulting bytecode/assembly (binary, text, remarks, codegen)\n"); printf("\n"); printf("Available options:\n"); printf(" --coverage: collect code coverage while running the code and output results to coverage.out\n"); @@ -812,6 +841,14 @@ int replMain(int argc, char** argv) { compileFormat = CompileFormat::Codegen; } + else if (strcmp(argv[1], "--compile=codegenverbose") == 0) + { + compileFormat = CompileFormat::CodegenVerbose; + } + else if (strcmp(argv[1], "--compile=codegennull") == 0) + { + compileFormat = CompileFormat::CodegenNull; + } else if (strcmp(argv[1], "--compile=null") == 0) { compileFormat = CompileFormat::Null; @@ -923,10 +960,16 @@ int replMain(int argc, char** argv) _setmode(_fileno(stdout), _O_BINARY); #endif + CompileStats stats = {}; int failed = 0; for (const std::string& path : files) - failed += !compileFile(path.c_str(), compileFormat); + failed += !compileFile(path.c_str(), compileFormat, stats); + + if (compileFormat == CompileFormat::Null) + printf("Compiled %d KLOC into %d KB bytecode\n", int(stats.lines / 1000), int(stats.bytecode / 1024)); + else if (compileFormat == CompileFormat::CodegenNull) + printf("Compiled %d KLOC into %d KB bytecode => %d KB native code\n", int(stats.lines / 1000), int(stats.bytecode / 1024), int(stats.codegen / 1024)); return failed ? 1 : 0; } diff --git a/CodeGen/include/Luau/AssemblyBuilderX64.h b/CodeGen/include/Luau/AssemblyBuilderX64.h index 1c7550170..e48388c50 100644 --- a/CodeGen/include/Luau/AssemblyBuilderX64.h +++ b/CodeGen/include/Luau/AssemblyBuilderX64.h @@ -23,6 +23,19 @@ enum class RoundingModeX64 RoundToZero = 0b11, }; +enum class AlignmentDataX64 +{ + Nop, + Int3, + Ud2, // int3 will be used as a fall-back if it doesn't fit +}; + +enum class ABIX64 +{ + Windows, + SystemV, +}; + class AssemblyBuilderX64 { public: @@ -80,6 +93,10 @@ class AssemblyBuilderX64 void int3(); + // Code alignment + void nop(uint32_t length = 1); + void align(uint32_t alignment, AlignmentDataX64 data = AlignmentDataX64::Nop); + // AVX void vaddpd(OperandX64 dst, OperandX64 src1, OperandX64 src2); void vaddps(OperandX64 dst, OperandX64 src1, OperandX64 src2); @@ -131,6 +148,8 @@ class AssemblyBuilderX64 void logAppend(const char* fmt, ...) LUAU_PRINTF_ATTR(2, 3); + uint32_t getCodeSize() const; + // Resulting data and code that need to be copied over one after the other // The *end* of 'data' has to be aligned to 16 bytes, this will also align 'code' std::vector data; @@ -140,6 +159,8 @@ class AssemblyBuilderX64 const bool logText = false; + const ABIX64 abi; + private: // Instruction archetypes void placeBinary(const char* name, OperandX64 lhs, OperandX64 rhs, uint8_t codeimm8, uint8_t codeimm, uint8_t codeimmImm8, uint8_t code8rev, @@ -177,7 +198,6 @@ class AssemblyBuilderX64 void commit(); LUAU_NOINLINE void extend(); - uint32_t getCodeSize(); // Data size_t allocateData(size_t size, size_t align); @@ -192,8 +212,8 @@ class AssemblyBuilderX64 LUAU_NOINLINE void log(const char* opcode, Label label); void log(OperandX64 op); - const char* getSizeName(SizeX64 size); - const char* getRegisterName(RegisterX64 reg); + const char* getSizeName(SizeX64 size) const; + const char* getRegisterName(RegisterX64 reg) const; uint32_t nextLabel = 1; std::vector(a) -> a" == toString(idType)); @@ -66,7 +42,7 @@ TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "generic_function") TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "proper_let_generalization") { - AstStatBlock* block = parse(R"( + solve(R"( local function a(c) local function d(e) return c @@ -78,21 +54,9 @@ TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "proper_let_generalization") local b = a(5) )"); - cgb.visit(block); - NotNull rootScope{cgb.rootScope}; - - ToStringOptions opts; - - NullModuleResolver resolver; - InternalErrorReporter iceHandler; - UnifierSharedState sharedState{&iceHandler}; - Normalizer normalizer{&arena, singletonTypes, NotNull{&sharedState}}; - ConstraintSolver cs{NotNull{&normalizer}, rootScope, "MainModule", NotNull(&resolver), {}, &logger}; - - cs.run(); - TypeId idType = requireBinding(rootScope, "b"); + ToStringOptions opts; CHECK("(a) -> number" == toString(idType, opts)); } diff --git a/tests/DataFlowGraphBuilder.test.cpp b/tests/DataFlowGraphBuilder.test.cpp new file mode 100644 index 000000000..9aa7cde6b --- /dev/null +++ b/tests/DataFlowGraphBuilder.test.cpp @@ -0,0 +1,104 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/DataFlowGraphBuilder.h" +#include "Luau/Error.h" +#include "Luau/Parser.h" + +#include "AstQueryDsl.h" +#include "ScopedFlags.h" + +#include "doctest.h" + +using namespace Luau; + +class DataFlowGraphFixture +{ + // Only needed to fix the operator== reflexivity of an empty Symbol. + ScopedFastFlag dcr{"DebugLuauDeferredConstraintResolution", true}; + + InternalErrorReporter handle; + + Allocator allocator; + AstNameTable names{allocator}; + AstStatBlock* module; + + std::optional graph; + +public: + void dfg(const std::string& code) + { + ParseResult parseResult = Parser::parse(code.c_str(), code.size(), names, allocator); + if (!parseResult.errors.empty()) + throw ParseErrors(std::move(parseResult.errors)); + module = parseResult.root; + graph = DataFlowGraphBuilder::build(module, NotNull{&handle}); + } + + template + std::optional getDef(const std::vector& nths = {nth(N)}) + { + T* node = query(module, nths); + REQUIRE(node); + return graph->getDef(node); + } + + template + DefId requireDef(const std::vector& nths = {nth(N)}) + { + auto loc = getDef(nths); + REQUIRE(loc); + return NotNull{*loc}; + } +}; + +TEST_SUITE_BEGIN("DataFlowGraphBuilder"); + +TEST_CASE_FIXTURE(DataFlowGraphFixture, "define_locals_in_local_stat") +{ + dfg(R"( + local x = 5 + local y = x + )"); + + REQUIRE(getDef()); +} + +TEST_CASE_FIXTURE(DataFlowGraphFixture, "define_parameters_in_functions") +{ + dfg(R"( + local function f(x) + local y = x + end + )"); + + REQUIRE(getDef()); +} + +TEST_CASE_FIXTURE(DataFlowGraphFixture, "find_aliases") +{ + dfg(R"( + local x = 5 + local y = x + local z = y + )"); + + DefId x = requireDef(); + DefId y = requireDef(); + REQUIRE(x != y); // TODO: they should be equal but it's not just locals that can alias, so we'll support this later. +} + +TEST_CASE_FIXTURE(DataFlowGraphFixture, "independent_locals") +{ + dfg(R"( + local x = 5 + local y = 5 + + local a = x + local b = y + )"); + + DefId x = requireDef(); + DefId y = requireDef(); + REQUIRE(x != y); +} + +TEST_SUITE_END(); diff --git a/tests/Fixture.cpp b/tests/Fixture.cpp index 9a77bf392..579b8942b 100644 --- a/tests/Fixture.cpp +++ b/tests/Fixture.cpp @@ -430,7 +430,8 @@ LoadDefinitionFileResult Fixture::loadDefinition(const std::string& source) LoadDefinitionFileResult result = frontend.loadDefinitionFile(source, "@test"); freeze(typeChecker.globalTypes); - dumpErrors(result.module); + if (result.module) + dumpErrors(result.module); REQUIRE_MESSAGE(result.success, "loadDefinition: unable to load definition file"); return result; } diff --git a/tests/Module.test.cpp b/tests/Module.test.cpp index d5f635e65..33d9c75a7 100644 --- a/tests/Module.test.cpp +++ b/tests/Module.test.cpp @@ -226,24 +226,6 @@ TEST_CASE_FIXTURE(Fixture, "clone_free_tables") CHECK_EQ(clonedTtv->state, TableState::Free); } -TEST_CASE_FIXTURE(Fixture, "clone_constrained_intersection") -{ - TypeArena src; - - TypeId constrained = src.addType(ConstrainedTypeVar{TypeLevel{}, {singletonTypes->numberType, singletonTypes->stringType}}); - - TypeArena dest; - CloneState cloneState; - - TypeId cloned = clone(constrained, dest, cloneState); - CHECK_NE(constrained, cloned); - - const ConstrainedTypeVar* ctv = get(cloned); - REQUIRE_EQ(2, ctv->parts.size()); - CHECK_EQ(singletonTypes->numberType, ctv->parts[0]); - CHECK_EQ(singletonTypes->stringType, ctv->parts[1]); -} - TEST_CASE_FIXTURE(BuiltinsFixture, "clone_self_property") { fileResolver.source["Module/A"] = R"( diff --git a/tests/Normalize.test.cpp b/tests/Normalize.test.cpp index b3522f6e2..20e0e34c5 100644 --- a/tests/Normalize.test.cpp +++ b/tests/Normalize.test.cpp @@ -391,26 +391,6 @@ TEST_SUITE_END(); TEST_SUITE_BEGIN("Normalize"); -TEST_CASE_FIXTURE(NormalizeFixture, "union_with_overlapping_field_that_has_a_subtype_relationship") -{ - check(R"( - local t: {x: number} | {x: number?} - )"); - - ModulePtr tempModule{new Module}; - tempModule->scopes.emplace_back(Location(), std::make_shared(singletonTypes->anyTypePack)); - - // HACK: Normalization is an in-place operation. We need to cheat a little here and unfreeze - // the arena that the type lives in. - ModulePtr mainModule = getMainModule(); - unfreeze(mainModule->internalTypes); - - TypeId tType = requireType("t"); - normalize(tType, tempModule, singletonTypes, *typeChecker.iceHandler); - - CHECK_EQ("{| x: number? |}", toString(tType, {true})); -} - TEST_CASE_FIXTURE(Fixture, "higher_order_function") { check(R"( diff --git a/tests/Symbol.test.cpp b/tests/Symbol.test.cpp index e7d2973b8..278c6ce2b 100644 --- a/tests/Symbol.test.cpp +++ b/tests/Symbol.test.cpp @@ -10,7 +10,7 @@ using namespace Luau; TEST_SUITE_BEGIN("SymbolTests"); -TEST_CASE("hashing_globals") +TEST_CASE("equality_and_hashing_of_globals") { std::string s1 = "name"; std::string s2 = "name"; @@ -37,7 +37,7 @@ TEST_CASE("hashing_globals") REQUIRE_EQ(1, theMap.size()); } -TEST_CASE("hashing_locals") +TEST_CASE("equality_and_hashing_of_locals") { std::string s1 = "name"; std::string s2 = "name"; @@ -64,4 +64,24 @@ TEST_CASE("hashing_locals") REQUIRE_EQ(2, theMap.size()); } +TEST_CASE("equality_of_empty_symbols") +{ + ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true}; + + std::string s1 = "name"; + std::string s2 = "name"; + + AstName one{s1.data()}; + AstLocal two{AstName{s2.data()}, Location(), nullptr, 0, 0, nullptr}; + + Symbol global{one}; + Symbol local{&two}; + Symbol empty1{}; + Symbol empty2{}; + + CHECK(empty1 != global); + CHECK(empty1 != local); + CHECK(empty1 == empty2); +} + TEST_SUITE_END(); diff --git a/tests/ToDot.test.cpp b/tests/ToDot.test.cpp index 98eb9863b..26c9a1ee4 100644 --- a/tests/ToDot.test.cpp +++ b/tests/ToDot.test.cpp @@ -79,8 +79,8 @@ n1 [label="AnyTypeVar 1"]; TEST_CASE_FIXTURE(Fixture, "bound") { CheckResult result = check(R"( -local a = 444 -local b = a +function a(): number return 444 end +local b = a() )"); LUAU_REQUIRE_NO_ERRORS(result); @@ -367,27 +367,6 @@ n3 [label="number"]; toDot(*ty, opts)); } -TEST_CASE_FIXTURE(Fixture, "constrained") -{ - // ConstrainedTypeVars never appear in the final type graph, so we have to create one directly - // to dotify it. - TypeVar t{ConstrainedTypeVar{TypeLevel{}, {typeChecker.numberType, typeChecker.stringType, typeChecker.nilType}}}; - - ToDotOptions opts; - opts.showPointers = false; - - CHECK_EQ(R"(digraph graphname { -n1 [label="ConstrainedTypeVar 1"]; -n1 -> n2; -n2 [label="number"]; -n1 -> n3; -n3 [label="string"]; -n1 -> n4; -n4 [label="nil"]; -})", - toDot(&t, opts)); -} - TEST_CASE_FIXTURE(Fixture, "singletontypes") { CheckResult result = check(R"( diff --git a/tests/TypeInfer.aliases.test.cpp b/tests/TypeInfer.aliases.test.cpp index 5ecc2a8ca..1bb97fbc7 100644 --- a/tests/TypeInfer.aliases.test.cpp +++ b/tests/TypeInfer.aliases.test.cpp @@ -846,8 +846,16 @@ TEST_CASE_FIXTURE(Fixture, "forward_declared_alias_is_not_clobbered_by_prior_uni type FutureIntersection = A & B )"); - // TODO: shared self causes this test to break in bizarre ways. - LUAU_REQUIRE_ERRORS(result); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + // To be quite honest, I don't know exactly why DCR fixes this. + LUAU_REQUIRE_NO_ERRORS(result); + } + else + { + // TODO: shared self causes this test to break in bizarre ways. + LUAU_REQUIRE_ERRORS(result); + } } TEST_CASE_FIXTURE(Fixture, "recursive_types_restriction_ok") diff --git a/tests/TypeInfer.definitions.test.cpp b/tests/TypeInfer.definitions.test.cpp index 26280c134..15c63ec78 100644 --- a/tests/TypeInfer.definitions.test.cpp +++ b/tests/TypeInfer.definitions.test.cpp @@ -362,4 +362,21 @@ TEST_CASE_FIXTURE(Fixture, "class_definition_overload_metamethods") CHECK_EQ(toString(requireType("shouldBeVector")), "Vector3"); } +TEST_CASE_FIXTURE(Fixture, "class_definition_string_props") +{ + loadDefinition(R"( + declare class Foo + ["a property"]: string + end + )"); + + CheckResult result = check(R"( + local x: Foo + local y = x["a property"] + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ(toString(requireType("y")), "string"); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index edc25c7e8..075bb01ad 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -1692,4 +1692,47 @@ foo(string.find("hello", "e")) CHECK_EQ(toString(result.errors[0]), "Argument count mismatch. Function 'foo' expects 0 to 2 arguments, but 3 are specified"); } +TEST_CASE_FIXTURE(Fixture, "luau_subtyping_is_np_hard") +{ + ScopedFastFlag sffs[]{ + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + {"LuauOverloadedFunctionSubtypingPerf", true}, + }; + + CheckResult result = check(R"( +--!strict + +-- An example of coding up graph coloring in the Luau type system. +-- This codes a three-node, two color problem. +-- A three-node triangle is uncolorable, +-- but a three-node line is colorable. + +type Red = "red" +type Blue = "blue" +type Color = Red | Blue +type Coloring = (Color) -> (Color) -> (Color) -> boolean +type Uncolorable = (Color) -> (Color) -> (Color) -> false + +type Line = Coloring + & ((Red) -> (Red) -> (Color) -> false) + & ((Blue) -> (Blue) -> (Color) -> false) + & ((Color) -> (Red) -> (Red) -> false) + & ((Color) -> (Blue) -> (Blue) -> false) + +type Triangle = Line + & ((Red) -> (Color) -> (Red) -> false) + & ((Blue) -> (Color) -> (Blue) -> false) + +local x : Triangle +local y : Line +local z : Uncolorable +z = x -- OK, so the triangle is uncolorable +z = y -- Not OK, so the line is colorable + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type '((\"blue\" | \"red\") -> (\"blue\" | \"red\") -> (\"blue\" | \"red\") -> boolean) & ((\"blue\" | \"red\") -> (\"blue\") -> (\"blue\") -> false) & ((\"blue\" | \"red\") -> (\"red\") -> (\"red\") -> false) & ((\"blue\") -> (\"blue\") -> (\"blue\" | \"red\") -> false) & ((\"red\") -> (\"red\") -> (\"blue\" | \"red\") -> false)' could not be converted into '(\"blue\" | \"red\") -> (\"blue\" | \"red\") -> (\"blue\" | \"red\") -> false'; none of the intersection parts are compatible"); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.intersectionTypes.test.cpp b/tests/TypeInfer.intersectionTypes.test.cpp index ca22c351b..0c10eb87e 100644 --- a/tests/TypeInfer.intersectionTypes.test.cpp +++ b/tests/TypeInfer.intersectionTypes.test.cpp @@ -781,7 +781,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "intersect_metatables") CheckResult result = check(R"( local a : string? = nil local b : number? = nil - + local x = setmetatable({}, { p = 5, q = a }); local y = setmetatable({}, { q = b, r = "hi" }); local z = setmetatable({}, { p = 5, q = nil, r = "hi" }); diff --git a/tests/TypeInfer.operators.test.cpp b/tests/TypeInfer.operators.test.cpp index 3d6c0193f..e572c87ac 100644 --- a/tests/TypeInfer.operators.test.cpp +++ b/tests/TypeInfer.operators.test.cpp @@ -13,6 +13,8 @@ using namespace Luau; +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) + TEST_SUITE_BEGIN("TypeInferOperators"); TEST_CASE_FIXTURE(Fixture, "or_joins_types") @@ -33,7 +35,7 @@ TEST_CASE_FIXTURE(Fixture, "or_joins_types_with_no_extras") local x:number|string = s local y = x or "s" )"); - CHECK_EQ(0, result.errors.size()); + LUAU_REQUIRE_NO_ERRORS(result); CHECK_EQ(toString(*requireType("s")), "number | string"); CHECK_EQ(toString(*requireType("y")), "number | string"); } @@ -44,7 +46,7 @@ TEST_CASE_FIXTURE(Fixture, "or_joins_types_with_no_superfluous_union") local s = "a" or "b" local x:string = s )"); - CHECK_EQ(0, result.errors.size()); + LUAU_REQUIRE_NO_ERRORS(result); CHECK_EQ(*requireType("s"), *typeChecker.stringType); } @@ -54,7 +56,7 @@ TEST_CASE_FIXTURE(Fixture, "and_adds_boolean") local s = "a" and 10 local x:boolean|number = s )"); - CHECK_EQ(0, result.errors.size()); + LUAU_REQUIRE_NO_ERRORS(result); CHECK_EQ(toString(*requireType("s")), "boolean | number"); } @@ -64,7 +66,7 @@ TEST_CASE_FIXTURE(Fixture, "and_adds_boolean_no_superfluous_union") local s = "a" and true local x:boolean = s )"); - CHECK_EQ(0, result.errors.size()); + LUAU_REQUIRE_NO_ERRORS(result); CHECK_EQ(*requireType("x"), *typeChecker.booleanType); } @@ -73,7 +75,7 @@ TEST_CASE_FIXTURE(Fixture, "and_or_ternary") CheckResult result = check(R"( local s = (1/2) > 0.5 and "a" or 10 )"); - CHECK_EQ(0, result.errors.size()); + LUAU_REQUIRE_NO_ERRORS(result); CHECK_EQ(toString(*requireType("s")), "number | string"); } @@ -81,7 +83,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "primitive_arith_no_metatable") { CheckResult result = check(R"( function add(a: number, b: string) - return a + (tonumber(b) :: number), a .. b + return a + (tonumber(b) :: number), tostring(a) .. b end local n, s = add(2,"3") )"); @@ -558,15 +560,21 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "disallow_string_and_types_without_metatables LUAU_REQUIRE_ERROR_COUNT(3, result); TypeMismatch* tm = get(result.errors[0]); - REQUIRE_EQ(*tm->wantedType, *typeChecker.numberType); - REQUIRE_EQ(*tm->givenType, *typeChecker.stringType); + REQUIRE(tm); + CHECK_EQ(*tm->wantedType, *typeChecker.numberType); + CHECK_EQ(*tm->givenType, *typeChecker.stringType); + + GenericError* gen1 = get(result.errors[1]); + REQUIRE(gen1); + if (FFlag::DebugLuauDeferredConstraintResolution) + CHECK_EQ(gen1->message, "Operator + is not applicable for '{ value: number }' and 'number' because neither type has a metatable"); + else + CHECK_EQ(gen1->message, "Binary operator '+' not supported by types 'foo' and 'number'"); TypeMismatch* tm2 = get(result.errors[2]); + REQUIRE(tm2); CHECK_EQ(*tm2->wantedType, *typeChecker.numberType); CHECK_EQ(*tm2->givenType, *requireType("foo")); - - GenericError* gen2 = get(result.errors[1]); - REQUIRE_EQ(gen2->message, "Binary operator '+' not supported by types 'foo' and 'number'"); } // CLI-29033 @@ -611,12 +619,10 @@ TEST_CASE_FIXTURE(Fixture, "strict_binary_op_where_lhs_unknown") { std::vector ops = {"+", "-", "*", "/", "%", "^", ".."}; - std::string src = R"( - function foo(a, b) - )"; + std::string src = "function foo(a, b)\n"; for (const auto& op : ops) - src += "local _ = a " + op + "b\n"; + src += "local _ = a " + op + " b\n"; src += "end"; @@ -651,7 +657,11 @@ TEST_CASE_FIXTURE(Fixture, "error_on_invalid_operand_types_to_relational_operato GenericError* ge = get(result.errors[0]); REQUIRE(ge); - CHECK_EQ("Type 'boolean' cannot be compared with relational operator <", ge->message); + + if (FFlag::DebugLuauDeferredConstraintResolution) + CHECK_EQ("Types 'boolean' and 'boolean' cannot be compared with relational operator <", ge->message); + else + CHECK_EQ("Type 'boolean' cannot be compared with relational operator <", ge->message); } TEST_CASE_FIXTURE(Fixture, "error_on_invalid_operand_types_to_relational_operators2") @@ -666,7 +676,10 @@ TEST_CASE_FIXTURE(Fixture, "error_on_invalid_operand_types_to_relational_operato GenericError* ge = get(result.errors[0]); REQUIRE(ge); - CHECK_EQ("Type 'number | string' cannot be compared with relational operator <", ge->message); + if (FFlag::DebugLuauDeferredConstraintResolution) + CHECK_EQ("Types 'number | string' and 'number | string' cannot be compared with relational operator <", ge->message); + else + CHECK_EQ("Type 'number | string' cannot be compared with relational operator <", ge->message); } TEST_CASE_FIXTURE(Fixture, "cli_38355_recursive_union") @@ -891,4 +904,63 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "expected_types_through_binary_or") LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(BuiltinsFixture, "mm_ops_must_return_a_value") +{ + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + + CheckResult result = check(R"( + local mm = { + __add = function(self, other) + return + end, + } + + local x = setmetatable({}, mm) + local y = x + 123 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK(requireType("y") == singletonTypes->errorRecoveryType()); + + const GenericError* ge = get(result.errors[0]); + REQUIRE(ge); + CHECK(ge->message == "Metamethod '__add' must return a value"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "mm_comparisons_must_return_a_boolean") +{ + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + + CheckResult result = check(R"( + local mm1 = { + __lt = function(self, other) + return 123 + end, + } + + local mm2 = { + __lt = function(self, other) + return + end, + } + + local o1 = setmetatable({}, mm1) + local v1 = o1 < o1 + + local o2 = setmetatable({}, mm2) + local v2 = o2 < o2 + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + + CHECK(requireType("v1") == singletonTypes->booleanType); + CHECK(requireType("v2") == singletonTypes->booleanType); + + CHECK(toString(result.errors[0]) == "Metamethod '__lt' must return type 'boolean'"); + CHECK(toString(result.errors[1]) == "Metamethod '__lt' must return type 'boolean'"); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index f707f9522..15d94430b 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -8,6 +8,7 @@ #include "doctest.h" LUAU_FASTFLAG(LuauSpecialTypesAsterisked) +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) using namespace Luau; @@ -49,7 +50,6 @@ struct RefinementClassFixture : Fixture {"Y", Property{typeChecker.numberType}}, {"Z", Property{typeChecker.numberType}}, }; - normalize(vec3, scope, arena, singletonTypes, *typeChecker.iceHandler); TypeId inst = arena.addType(ClassTypeVar{"Instance", {}, std::nullopt, std::nullopt, {}, nullptr, "Test"}); @@ -57,21 +57,17 @@ struct RefinementClassFixture : Fixture TypePackId isARets = arena.addTypePack({typeChecker.booleanType}); TypeId isA = arena.addType(FunctionTypeVar{isAParams, isARets}); getMutable(isA)->magicFunction = magicFunctionInstanceIsA; - normalize(isA, scope, arena, singletonTypes, *typeChecker.iceHandler); getMutable(inst)->props = { {"Name", Property{typeChecker.stringType}}, {"IsA", Property{isA}}, }; - normalize(inst, scope, arena, singletonTypes, *typeChecker.iceHandler); TypeId folder = typeChecker.globalTypes.addType(ClassTypeVar{"Folder", {}, inst, std::nullopt, {}, nullptr, "Test"}); - normalize(folder, scope, arena, singletonTypes, *typeChecker.iceHandler); TypeId part = typeChecker.globalTypes.addType(ClassTypeVar{"Part", {}, inst, std::nullopt, {}, nullptr, "Test"}); getMutable(part)->props = { {"Position", Property{vec3}}, }; - normalize(part, scope, arena, singletonTypes, *typeChecker.iceHandler); typeChecker.globalScope->exportedTypeBindings["Vector3"] = TypeFun{{}, vec3}; typeChecker.globalScope->exportedTypeBindings["Instance"] = TypeFun{{}, inst}; @@ -102,8 +98,16 @@ TEST_CASE_FIXTURE(Fixture, "is_truthy_constraint") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("string", toString(requireTypeAtPosition({3, 26}))); - CHECK_EQ("nil", toString(requireTypeAtPosition({5, 26}))); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("(string?) & ~(false?)", toString(requireTypeAtPosition({3, 26}))); + CHECK_EQ("(string?) & ~~(false?)", toString(requireTypeAtPosition({5, 26}))); + } + else + { + CHECK_EQ("string", toString(requireTypeAtPosition({3, 26}))); + CHECK_EQ("nil", toString(requireTypeAtPosition({5, 26}))); + } } TEST_CASE_FIXTURE(Fixture, "invert_is_truthy_constraint") @@ -120,8 +124,16 @@ TEST_CASE_FIXTURE(Fixture, "invert_is_truthy_constraint") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("nil", toString(requireTypeAtPosition({3, 26}))); - CHECK_EQ("string", toString(requireTypeAtPosition({5, 26}))); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("(string?) & ~~(false?)", toString(requireTypeAtPosition({3, 26}))); + CHECK_EQ("(string?) & ~~~(false?)", toString(requireTypeAtPosition({5, 26}))); + } + else + { + CHECK_EQ("nil", toString(requireTypeAtPosition({3, 26}))); + CHECK_EQ("string", toString(requireTypeAtPosition({5, 26}))); + } } TEST_CASE_FIXTURE(Fixture, "parenthesized_expressions_are_followed_through") @@ -138,8 +150,16 @@ TEST_CASE_FIXTURE(Fixture, "parenthesized_expressions_are_followed_through") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("nil", toString(requireTypeAtPosition({3, 26}))); - CHECK_EQ("string", toString(requireTypeAtPosition({5, 26}))); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("(string?) & ~~(false?)", toString(requireTypeAtPosition({3, 26}))); + CHECK_EQ("(string?) & ~~~(false?)", toString(requireTypeAtPosition({5, 26}))); + } + else + { + CHECK_EQ("nil", toString(requireTypeAtPosition({3, 26}))); + CHECK_EQ("string", toString(requireTypeAtPosition({5, 26}))); + } } TEST_CASE_FIXTURE(Fixture, "and_constraint") @@ -963,19 +983,27 @@ TEST_CASE_FIXTURE(Fixture, "and_or_peephole_refinement") TEST_CASE_FIXTURE(Fixture, "narrow_boolean_to_true_or_false") { CheckResult result = check(R"( - local function is_true(b: true) end - local function is_false(b: false) end - local function f(x: boolean) if x then - is_true(x) + local foo = x else - is_false(x) + local foo = x end end )"); LUAU_REQUIRE_NO_ERRORS(result); + + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("boolean & ~(false?)", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("boolean & ~~(false?)", toString(requireTypeAtPosition({5, 28}))); + } + else + { + CHECK_EQ("true", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("false", toString(requireTypeAtPosition({5, 28}))); + } } TEST_CASE_FIXTURE(Fixture, "discriminate_on_properties_of_disjoint_tables_where_that_property_is_true_or_false") diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index 53f9a1abb..2a208cce0 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -11,6 +11,8 @@ using namespace Luau; +LUAU_FASTFLAG(LuauLowerBoundsCalculation); +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); LUAU_FASTFLAG(LuauInstantiateInSubtyping) TEST_SUITE_BEGIN("TableTests"); @@ -44,7 +46,7 @@ TEST_CASE_FIXTURE(Fixture, "augment_table") const TableTypeVar* tType = get(requireType("t")); REQUIRE(tType != nullptr); - CHECK(tType->props.find("foo") != tType->props.end()); + CHECK(1 == tType->props.count("foo")); } TEST_CASE_FIXTURE(Fixture, "augment_nested_table") @@ -101,7 +103,11 @@ TEST_CASE_FIXTURE(Fixture, "updating_sealed_table_prop_is_ok") TEST_CASE_FIXTURE(Fixture, "cannot_change_type_of_unsealed_table_prop") { - CheckResult result = check("local t = {} t.prop = 999 t.prop = 'hello'"); + CheckResult result = check(R"( + local t = {} + t.prop = 999 + t.prop = 'hello' + )"); LUAU_REQUIRE_ERROR_COUNT(1, result); } @@ -858,11 +864,12 @@ TEST_CASE_FIXTURE(Fixture, "assigning_to_an_unsealed_table_with_string_literal_s LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*typeChecker.stringType, *requireType("a")); + CHECK("string" == toString(*typeChecker.stringType)); TableTypeVar* tableType = getMutable(requireType("t")); REQUIRE(tableType != nullptr); REQUIRE(tableType->indexer == std::nullopt); + REQUIRE(0 != tableType->props.count("a")); TypeId propertyA = tableType->props["a"].type; REQUIRE(propertyA != nullptr); @@ -2390,9 +2397,12 @@ TEST_CASE_FIXTURE(Fixture, "wrong_assign_does_hit_indexer") TEST_CASE_FIXTURE(Fixture, "nil_assign_doesnt_hit_no_indexer") { - CheckResult result = check("local a = {a=1, b=2} a['a'] = nil"); + CheckResult result = check(R"( + local a = {a=1, b=2} + a['a'] = nil + )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(result.errors[0], (TypeError{Location{Position{0, 30}, Position{0, 33}}, TypeMismatch{ + CHECK_EQ(result.errors[0], (TypeError{Location{Position{2, 17}, Position{2, 20}}, TypeMismatch{ typeChecker.numberType, typeChecker.nilType, }})); @@ -2701,6 +2711,62 @@ local baz = foo[bar] CHECK_EQ(result.errors[0].location, Location{Position{3, 16}, Position{3, 19}}); } +TEST_CASE_FIXTURE(BuiltinsFixture, "table_call_metamethod_basic") +{ + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + + CheckResult result = check(R"( + local a = setmetatable({ + a = 1, + }, { + __call = function(self, b: number) + return self.a * b + end, + }) + + local foo = a(12) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK(requireType("foo") == singletonTypes->numberType); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "table_call_metamethod_must_be_callable") +{ + CheckResult result = check(R"( + local a = setmetatable({}, { + __call = 123, + }) + + local foo = a() + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK(result.errors[0] == TypeError{ + Location{{5, 20}, {5, 21}}, + CannotCallNonFunction{singletonTypes->numberType}, + }); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "table_call_metamethod_generic") +{ + CheckResult result = check(R"( + local a = setmetatable({}, { + __call = function(self, b: T) + return b + end, + }) + + local foo = a(12) + local bar = a("bar") + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK(requireType("foo") == singletonTypes->numberType); + CHECK(requireType("bar") == singletonTypes->stringType); +} + TEST_CASE_FIXTURE(BuiltinsFixture, "table_simple_call") { CheckResult result = check(R"( diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index 239b8c28f..dff9649fb 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -1046,7 +1046,6 @@ TEST_CASE_FIXTURE(Fixture, "type_infer_recursion_limit_normalizer") ScopedFastFlag sffs[]{ {"LuauSubtypeNormalizer", true}, {"LuauTypeNormalization2", true}, - {"LuauAutocompleteDynamicLimits", true}, }; CheckResult result = check(R"( diff --git a/tests/TypeInfer.typePacks.cpp b/tests/TypeInfer.typePacks.cpp index aaa7ded44..4c8eeac60 100644 --- a/tests/TypeInfer.typePacks.cpp +++ b/tests/TypeInfer.typePacks.cpp @@ -467,6 +467,8 @@ type I = W TEST_CASE_FIXTURE(Fixture, "type_alias_type_pack_explicit") { + ScopedFastFlag sff("LuauFunctionReturnStringificationFixup", true); + CheckResult result = check(R"( type X = (T...) -> (T...) @@ -490,6 +492,8 @@ type F = X<(string, ...number)> TEST_CASE_FIXTURE(Fixture, "type_alias_type_pack_explicit_multi") { + ScopedFastFlag sff("LuauFunctionReturnStringificationFixup", true); + CheckResult result = check(R"( type Y = (T...) -> (U...) diff --git a/tests/TypeVar.test.cpp b/tests/TypeVar.test.cpp index b81c80ce4..5dd1b1bcc 100644 --- a/tests/TypeVar.test.cpp +++ b/tests/TypeVar.test.cpp @@ -436,7 +436,6 @@ TEST_CASE("proof_that_isBoolean_uses_all_of") TEST_CASE("content_reassignment") { TypeVar myAny{AnyTypeVar{}, /*presistent*/ true}; - myAny.normal = true; myAny.documentationSymbol = "@global/any"; TypeArena arena; @@ -446,7 +445,6 @@ TEST_CASE("content_reassignment") CHECK(get(futureAny) != nullptr); CHECK(!futureAny->persistent); - CHECK(futureAny->normal); CHECK(futureAny->documentationSymbol == "@global/any"); CHECK(futureAny->owningArena == &arena); } diff --git a/tests/conformance/basic.lua b/tests/conformance/basic.lua index b7e85aa74..7a05f8e9c 100644 --- a/tests/conformance/basic.lua +++ b/tests/conformance/basic.lua @@ -93,7 +93,10 @@ assert((function() local a = 1 a = a * 2 return a end)() == 2) assert((function() local a = 1 a = a / 2 return a end)() == 0.5) assert((function() local a = 5 a = a % 2 return a end)() == 1) assert((function() local a = 3 a = a ^ 2 return a end)() == 9) +assert((function() local a = 3 a = a ^ 3 return a end)() == 27) assert((function() local a = 9 a = a ^ 0.5 return a end)() == 3) +assert((function() local a = -2 a = a ^ 2 return a end)() == 4) +assert((function() local a = -2 a = a ^ 0.5 return tostring(a) end)() == "nan") assert((function() local a = '1' a = a .. '2' return a end)() == "12") assert((function() local a = '1' a = a .. '2' .. '3' return a end)() == "123") @@ -706,7 +709,11 @@ end assert(chainTest(100) == "v0,v100") -- this validates import fallbacks +assert(idontexist == nil) +assert(math.idontexist == nil) assert(pcall(function() return idontexist.a end) == false) +assert(pcall(function() return math.pow.a end) == false) +assert(pcall(function() return math.a.b end) == false) -- make sure that NaN is preserved by the bytecode compiler local realnan = tostring(math.abs(0)/math.abs(0)) diff --git a/tests/conformance/calls.lua b/tests/conformance/calls.lua index 7f9610a32..621a921aa 100644 --- a/tests/conformance/calls.lua +++ b/tests/conformance/calls.lua @@ -226,4 +226,14 @@ assert((function () return nil end)(4) == nil) assert((function () local a; return a end)(4) == nil) assert((function (a) return a end)() == nil) +-- C-stack overflow while handling C-stack overflow +if not limitedstack then + local function loop () + assert(pcall(loop)) + end + + local err, msg = xpcall(loop, loop) + assert(not err and string.find(msg, "error")) +end + return('OK') diff --git a/tests/conformance/datetime.lua b/tests/conformance/datetime.lua index ca35cf2f1..dc73948b6 100644 --- a/tests/conformance/datetime.lua +++ b/tests/conformance/datetime.lua @@ -16,6 +16,7 @@ D = os.date("*t", t) assert(os.date(string.rep("%d", 1000), t) == string.rep(os.date("%d", t), 1000)) assert(os.date(string.rep("%", 200)) == string.rep("%", 100)) +assert(os.date("", -1) == nil) local function checkDateTable (t) local D = os.date("!*t", t) diff --git a/tests/conformance/errors.lua b/tests/conformance/errors.lua index 529e9b0ca..57d2b6939 100644 --- a/tests/conformance/errors.lua +++ b/tests/conformance/errors.lua @@ -405,5 +405,7 @@ assert(ecall(function() (""):foo() end) == "attempt to call missing method 'foo' assert(ecall(function() (42):foo() end) == "attempt to index number with 'foo'") assert(ecall(function() ({foo=42}):foo() end) == "attempt to call a number value") assert(ecall(function() local ud = newproxy(true) getmetatable(ud).__index = {} ud:foo() end) == "attempt to call missing method 'foo' of userdata") +assert(ecall(function() local ud = newproxy(true) getmetatable(ud).__index = function() end ud:foo() end) == "attempt to call missing method 'foo' of userdata") +assert(ecall(function() local ud = newproxy(true) getmetatable(ud).__index = function() error("nope") end ud:foo() end) == "nope") return('OK') diff --git a/tests/conformance/events.lua b/tests/conformance/events.lua index 447b67bce..94314c3fb 100644 --- a/tests/conformance/events.lua +++ b/tests/conformance/events.lua @@ -13,6 +13,11 @@ assert(getmetatable(a) == "xuxu") ud=newproxy(true); getmetatable(ud).__metatable = "xuxu" assert(getmetatable(ud) == "xuxu") +assert(pcall(getmetatable) == false) +assert(pcall(function() return getmetatable() end) == false) +assert(select(2, pcall(getmetatable, {})) == nil) +assert(select(2, pcall(getmetatable, ud)) == "xuxu") + local res,err = pcall(tostring, a) assert(not res and err == "'__tostring' must return a string") -- cannot change a protected metatable @@ -475,6 +480,9 @@ function testfenv() assert(_G.X == 20) assert(_G == getfenv(0)) + + assert(pcall(getfenv, 10) == false) + assert(pcall(setfenv, setfenv, {}) == false) end testfenv() -- DONT MOVE THIS LINE diff --git a/tests/conformance/iter.lua b/tests/conformance/iter.lua index 468ffafb3..5f8f1a89e 100644 --- a/tests/conformance/iter.lua +++ b/tests/conformance/iter.lua @@ -193,4 +193,24 @@ do assert(x == 15) end +-- pairs/ipairs/next may be substituted through getfenv +-- however, they *must* be substituted with functions - we don't support them falling back to generalized iteration +function testgetfenv() + local env = getfenv(1) + env.pairs = function() return "nope" end + env.ipairs = function() return "nope" end + env.next = {1, 2, 3} + + local ok, err = pcall(function() for k, v in pairs({}) do end end) + assert(not ok and err:match("attempt to iterate over a string value")) + + local ok, err = pcall(function() for k, v in ipairs({}) do end end) + assert(not ok and err:match("attempt to iterate over a string value")) + + local ok, err = pcall(function() for k, v in next, {} do end end) + assert(not ok and err:match("attempt to iterate over a table value")) +end + +testgetfenv() -- DONT MOVE THIS LINE + return"OK" diff --git a/tests/conformance/math.lua b/tests/conformance/math.lua index 0cd0cdce7..972c399b2 100644 --- a/tests/conformance/math.lua +++ b/tests/conformance/math.lua @@ -283,6 +283,13 @@ assert(math.fmod(-3, 2) == -1) assert(math.fmod(3, -2) == 1) assert(math.fmod(-3, -2) == -1) +-- pow +assert(math.pow(2, 0) == 1) +assert(math.pow(2, 2) == 4) +assert(math.pow(4, 0.5) == 2) +assert(math.pow(-2, 2) == 4) +assert(tostring(math.pow(-2, 0.5)) == "nan") + -- most of the tests above go through fastcall path -- to make sure the basic implementations are also correct we test these functions with string->number coercions assert(math.abs("-4") == 4) diff --git a/tests/conformance/move.lua b/tests/conformance/move.lua index 3f28b4b35..27a96ffc8 100644 --- a/tests/conformance/move.lua +++ b/tests/conformance/move.lua @@ -74,4 +74,6 @@ checkerror("wrap around", table.move, {}, 1, maxI, 2) checkerror("wrap around", table.move, {}, 1, 2, maxI) checkerror("wrap around", table.move, {}, minI, -2, 2) +checkerror("readonly", table.move, table.freeze({}), 1, 1, 1) + return"OK" diff --git a/tests/conformance/strings.lua b/tests/conformance/strings.lua index 3d8fdd1f4..61bac7266 100644 --- a/tests/conformance/strings.lua +++ b/tests/conformance/strings.lua @@ -48,6 +48,7 @@ assert(string.find("", "") == 1) assert(string.find('', 'aaa', 1) == nil) assert(('alo(.)alo'):find('(.)', 1, 1) == 4) assert(string.find('', '1', 2) == nil) +assert(string.find('123', '2', 0) == 2) print('+') assert(string.len("") == 0) @@ -88,6 +89,8 @@ assert(string.lower("\0ABCc%$") == "\0abcc%$") assert(string.rep('teste', 0) == '') assert(string.rep('tés\00tę', 2) == 'tés\0tętés\000tę') assert(string.rep('', 10) == '') +assert(string.rep('', 1e9) == '') +assert(pcall(string.rep, 'x', 2e9) == false) assert(string.reverse"" == "") assert(string.reverse"\0\1\2\3" == "\3\2\1\0") @@ -126,6 +129,13 @@ assert(string.format("-%.20s.20s", string.rep("%", 2000)) == "-"..string.rep("%" assert(string.format('"-%20s.20s"', string.rep("%", 2000)) == string.format("%q", "-"..string.rep("%", 2000)..".20s")) +assert(string.format("%o %u %x %X", -1, -1, -1, -1) == "1777777777777777777777 18446744073709551615 ffffffffffffffff FFFFFFFFFFFFFFFF") + +assert(string.format("%e %E", 1.5, -1.5) == "1.500000e+00 -1.500000E+00") + +assert(pcall(string.format, "%##################d", 1) == false) +assert(pcall(string.format, "%.123d", 1) == false) +assert(pcall(string.format, "%?", 1) == false) -- longest number that can be formated assert(string.len(string.format('%99.99f', -1e308)) >= 100) @@ -179,6 +189,26 @@ assert(table.concat(a, ",", 2) == "b,c") assert(table.concat(a, ",", 3) == "c") assert(table.concat(a, ",", 4) == "") +-- string.split +do + local function eq(a, b) + if #a ~= #b then + return false + end + for i=1,#a do + if a[i] ~= b[i] then + return false + end + end + return true + end + + assert(eq(string.split("abc", ""), {'a', 'b', 'c'})) + assert(eq(string.split("abc", "b"), {'a', 'c'})) + assert(eq(string.split("abc", "d"), {'abc'})) + assert(eq(string.split("abc", "c"), {'ab', ''})) +end + --[[ local locales = { "ptb", "ISO-8859-1", "pt_BR" } local function trylocale (w) diff --git a/tests/conformance/tables.lua b/tests/conformance/tables.lua index 0eff85408..7ae80cc4c 100644 --- a/tests/conformance/tables.lua +++ b/tests/conformance/tables.lua @@ -87,35 +87,59 @@ print'+' -- testing tables dynamically built local lim = 130 -local a = {}; a[2] = 1; check(a, 0, 1) -a = {}; a[0] = 1; check(a, 0, 1); a[2] = 1; check(a, 0, 2) -a = {}; a[0] = 1; a[1] = 1; check(a, 1, 1) -a = {} -for i = 1,lim do - a[i] = 1 - assert(#a == i) - check(a, mp2(i), 0) + +do + local a = {}; a[2] = 1; check(a, 0, 1) + a = {}; a[0] = 1; check(a, 0, 1); a[2] = 1; check(a, 0, 2) + a = {}; a[0] = 1; a[1] = 1; check(a, 1, 1) + a = {} + for i = 1,lim do + a[i] = 1 + assert(#a == i) + check(a, mp2(i), 0) + end end -a = {} -for i = 1,lim do - a['a'..i] = 1 - assert(#a == 0) - check(a, 0, mp2(i)) +do + local a = {} + for i = 1,lim do + a['a'..i] = 1 + assert(#a == 0) + check(a, 0, mp2(i)) + end end -a = {} -for i=1,16 do a[i] = i end -check(a, 16, 0) -for i=1,11 do a[i] = nil end -for i=30,40 do a[i] = nil end -- force a rehash (?) -check(a, 0, 8) -a[10] = 1 -for i=30,40 do a[i] = nil end -- force a rehash (?) -check(a, 0, 8) -for i=1,14 do a[i] = nil end -for i=30,50 do a[i] = nil end -- force a rehash (?) -check(a, 0, 4) +do + local a = {} + for i=1,16 do a[i] = i end + check(a, 16, 0) + for i=1,11 do a[i] = nil end + for i=30,40 do a[i] = nil end -- force a rehash (?) + check(a, 0, 8) + a[10] = 1 + for i=30,40 do a[i] = nil end -- force a rehash (?) + check(a, 0, 8) + for i=1,14 do a[i] = nil end + for i=30,50 do a[i] = nil end -- force a rehash (?) + check(a, 0, 4) +end + +do -- rehash moving elements from array to hash + local a = {} + for i = 1, 100 do a[i] = i end + check(a, 128, 0) + + for i = 5, 95 do a[i] = nil end + check(a, 128, 0) + + a.x = 1 -- force a re-hash + check(a, 4, 8) + + for i = 1, 4 do assert(a[i] == i) end + for i = 5, 95 do assert(a[i] == nil) end + for i = 96, 100 do assert(a[i] == i) end + assert(a.x == 1) +end -- reverse filling for i=1,lim do @@ -612,4 +636,54 @@ do assert(hit and child.foo == nil and parent.foo == nil) end +-- testing next x GC of deleted keys +do + local co = coroutine.wrap(function (t) + for k, v in pairs(t) do + local k1 = next(t) -- all previous keys were deleted + assert(k == k1) -- current key is the first in the table + t[k] = nil + local expected = (type(k) == "table" and k[1] or + type(k) == "function" and k() or + string.sub(k, 1, 1)) + assert(expected == v) + coroutine.yield(v) + end + end) + local t = {} + t[{1}] = 1 -- add several unanchored, collectable keys + t[{2}] = 2 + t[string.rep("a", 50)] = "a" -- long string + t[string.rep("b", 50)] = "b" + t[{3}] = 3 + t[string.rep("c", 10)] = "c" -- short string + t[function () return 10 end] = 10 + local count = 7 + while co(t) do + collectgarbage("collect") -- collect dead keys + count = count - 1 + end + assert(count == 0 and next(t) == nil) -- traversed the whole table +end + +-- test error cases for table functions +do + assert(pcall(table.insert, {}) == false) + assert(pcall(table.insert, {}, 1, 2, 3) == false) + assert(pcall(table.insert, table.freeze({1, 2, 3}), 4) == false) + assert(pcall(table.insert, table.freeze({1, 2, 3}), 1, 4) == false) + + assert(pcall(table.remove, table.freeze({1})) == false) + + assert(pcall(table.concat, {true}) == false) + + assert(pcall(table.create) == false) + assert(pcall(table.create, -1) == false) + assert(pcall(table.create, 1e9) == false) + + assert(pcall(table.find, {}, 42, 0) == false) + + assert(pcall(table.clear, table.freeze({})) == false) +end + return"OK" diff --git a/tests/conformance/tpack.lua b/tests/conformance/tpack.lua index 835bf5648..b240f4825 100644 --- a/tests/conformance/tpack.lua +++ b/tests/conformance/tpack.lua @@ -306,6 +306,8 @@ do -- testing initial position assert(i == 4 and p == 17) local i, p = unpack("!4 i4", x, -#x) assert(i == 1 and p == 5) + local i, p = unpack("!4 i4", x, 0) + assert(i == 1 and p == 5) -- limits for i = 1, #x + 1 do diff --git a/tools/faillist.txt b/tools/faillist.txt index 0eb022096..c869e0c47 100644 --- a/tools/faillist.txt +++ b/tools/faillist.txt @@ -18,12 +18,10 @@ AutocompleteTest.autocomplete_string_singleton_equality AutocompleteTest.autocomplete_string_singleton_escape AutocompleteTest.autocomplete_string_singletons AutocompleteTest.autocompleteProp_index_function_metamethod_is_variadic -AutocompleteTest.cyclic_table AutocompleteTest.do_compatible_self_calls AutocompleteTest.do_wrong_compatible_self_calls AutocompleteTest.keyword_methods AutocompleteTest.no_incompatible_self_calls -AutocompleteTest.no_incompatible_self_calls_2 AutocompleteTest.no_wrong_compatible_self_calls_with_generics AutocompleteTest.suggest_table_keys AutocompleteTest.type_correct_argument_type_suggestion @@ -40,8 +38,6 @@ AutocompleteTest.type_correct_keywords AutocompleteTest.type_correct_suggestion_for_overloads AutocompleteTest.type_correct_suggestion_in_argument AutocompleteTest.type_correct_suggestion_in_table -AutocompleteTest.unsealed_table -AutocompleteTest.unsealed_table_2 BuiltinTests.aliased_string_format BuiltinTests.assert_removes_falsy_types BuiltinTests.assert_removes_falsy_types2 @@ -75,7 +71,6 @@ BuiltinTests.select_with_decimal_argument_is_rounded_down BuiltinTests.set_metatable_needs_arguments BuiltinTests.setmetatable_should_not_mutate_persisted_types BuiltinTests.sort_with_bad_predicate -BuiltinTests.sort_with_predicate BuiltinTests.string_format_arg_count_mismatch BuiltinTests.string_format_arg_types_inference BuiltinTests.string_format_as_method @@ -93,6 +88,7 @@ BuiltinTests.table_pack_variadic BuiltinTests.tonumber_returns_optional_number_type BuiltinTests.tonumber_returns_optional_number_type2 DefinitionTests.class_definition_overload_metamethods +DefinitionTests.class_definition_string_props DefinitionTests.declaring_generic_functions DefinitionTests.definition_file_classes FrontendTest.environments @@ -128,7 +124,7 @@ GenericsTests.higher_rank_polymorphism_should_not_accept_instantiated_arguments GenericsTests.infer_generic_function_function_argument GenericsTests.infer_generic_function_function_argument_overloaded GenericsTests.infer_generic_methods -GenericsTests.inferred_local_vars_can_be_polytypes +GenericsTests.infer_generic_property GenericsTests.instantiate_cyclic_generic_function GenericsTests.instantiated_function_argument_names GenericsTests.instantiation_sharing_types @@ -173,6 +169,7 @@ ProvisionalTests.do_not_ice_when_trying_to_pick_first_of_generic_type_pack ProvisionalTests.error_on_eq_metamethod_returning_a_type_other_than_boolean ProvisionalTests.generic_type_leak_to_module_interface_variadic ProvisionalTests.greedy_inference_with_shared_self_triggers_function_with_no_returns +ProvisionalTests.lvalue_equals_another_lvalue_with_no_overlap ProvisionalTests.pcall_returns_at_least_two_value_but_function_returns_nothing ProvisionalTests.setmetatable_constrains_free_type_into_free_table ProvisionalTests.specialization_binds_with_prototypes_too_early @@ -234,17 +231,14 @@ RefinementTest.typeguard_not_to_be_string RefinementTest.what_nonsensical_condition RefinementTest.x_as_any_if_x_is_instance_elseif_x_is_table RefinementTest.x_is_not_instance_or_else_not_part +RuntimeLimits.typescript_port_of_Result_type TableTests.a_free_shape_can_turn_into_a_scalar_if_it_is_compatible TableTests.a_free_shape_cannot_turn_into_a_scalar_if_it_is_not_compatible TableTests.access_index_metamethod_that_returns_variadic TableTests.accidentally_checked_prop_in_opposite_branch -TableTests.assigning_to_an_unsealed_table_with_string_literal_should_infer_new_properties_over_indexer -TableTests.augment_nested_table -TableTests.augment_table TableTests.builtin_table_names TableTests.call_method TableTests.cannot_augment_sealed_table -TableTests.cannot_change_type_of_unsealed_table_prop TableTests.casting_sealed_tables_with_props_into_table_with_indexer TableTests.casting_tables_with_props_into_table_with_indexer3 TableTests.casting_tables_with_props_into_table_with_indexer4 @@ -294,6 +288,7 @@ TableTests.metatable_mismatch_should_fail TableTests.missing_metatable_for_sealed_tables_do_not_get_inferred TableTests.mixed_tables_with_implicit_numbered_keys TableTests.nil_assign_doesnt_hit_indexer +TableTests.nil_assign_doesnt_hit_no_indexer TableTests.okay_to_add_property_to_unsealed_tables_by_function_call TableTests.only_ascribe_synthetic_names_at_module_scope TableTests.oop_indexer_works @@ -347,7 +342,6 @@ ToString.toStringNamedFunction_id ToString.toStringNamedFunction_include_self_param ToString.toStringNamedFunction_map ToString.toStringNamedFunction_variadics -TranspilerTests.types_should_not_be_considered_cyclic_if_they_are_not_recursive TryUnifyTests.cli_41095_concat_log_in_sealed_table_unification TryUnifyTests.members_of_failed_typepack_unification_are_unified_with_errorType TryUnifyTests.result_of_failed_typepack_unification_is_constrained @@ -377,22 +371,20 @@ TypeInfer.dont_report_type_errors_within_an_AstStatError TypeInfer.globals TypeInfer.globals2 TypeInfer.infer_assignment_value_types_mutable_lval +TypeInfer.it_is_ok_to_have_inconsistent_number_of_return_values_in_nonstrict TypeInfer.no_stack_overflow_from_isoptional TypeInfer.tc_after_error_recovery_no_replacement_name_in_error TypeInfer.tc_if_else_expressions_expected_type_3 TypeInfer.tc_interpolated_string_basic -TypeInfer.tc_interpolated_string_constant_type TypeInfer.tc_interpolated_string_with_invalid_expression TypeInfer.type_infer_recursion_limit_no_ice -TypeInferAnyError.assign_prop_to_table_by_calling_any_yields_any +TypeInfer.type_infer_recursion_limit_normalizer TypeInferAnyError.for_in_loop_iterator_is_any2 TypeInferAnyError.for_in_loop_iterator_is_error2 TypeInferClasses.call_base_method TypeInferClasses.call_instance_method -TypeInferClasses.can_assign_to_prop_of_base_class_using_string TypeInferClasses.can_read_prop_of_base_class_using_string TypeInferClasses.class_type_mismatch_with_name_conflict -TypeInferClasses.classes_can_have_overloaded_operators TypeInferClasses.classes_without_overloaded_operators_cannot_be_added TypeInferClasses.detailed_class_unification_error TypeInferClasses.higher_order_function_arguments_are_contravariant @@ -420,6 +412,7 @@ TypeInferFunctions.infer_return_value_type TypeInferFunctions.infer_that_function_does_not_return_a_table TypeInferFunctions.list_all_overloads_if_no_overload_takes_given_argument_count TypeInferFunctions.list_only_alternative_overloads_that_match_argument_count +TypeInferFunctions.luau_subtyping_is_np_hard TypeInferFunctions.no_lossy_function_type TypeInferFunctions.occurs_check_failure_in_function_return_type TypeInferFunctions.record_matching_overload @@ -459,9 +452,6 @@ TypeInferOOP.dont_suggest_using_colon_rather_than_dot_if_it_wont_help_2 TypeInferOOP.dont_suggest_using_colon_rather_than_dot_if_not_defined_with_colon TypeInferOOP.inferring_hundreds_of_self_calls_should_not_suffocate_memory TypeInferOOP.methods_are_topologically_sorted -TypeInferOperators.and_adds_boolean -TypeInferOperators.and_adds_boolean_no_superfluous_union -TypeInferOperators.and_binexps_dont_unify TypeInferOperators.and_or_ternary TypeInferOperators.CallAndOrOfFunctions TypeInferOperators.cannot_compare_tables_that_do_not_have_the_same_metatable @@ -471,23 +461,11 @@ TypeInferOperators.cli_38355_recursive_union TypeInferOperators.compound_assign_mismatch_metatable TypeInferOperators.compound_assign_mismatch_op TypeInferOperators.compound_assign_mismatch_result -TypeInferOperators.concat_op_on_free_lhs_and_string_rhs TypeInferOperators.disallow_string_and_types_without_metatables_from_arithmetic_binary_ops -TypeInferOperators.dont_strip_nil_from_rhs_or_operator -TypeInferOperators.equality_operations_succeed_if_any_union_branch_succeeds -TypeInferOperators.error_on_invalid_operand_types_to_relational_operators -TypeInferOperators.error_on_invalid_operand_types_to_relational_operators2 -TypeInferOperators.expected_types_through_binary_and -TypeInferOperators.expected_types_through_binary_or +TypeInferOperators.in_nonstrict_mode_strip_nil_from_intersections_when_considering_relational_operators TypeInferOperators.infer_any_in_all_modes_when_lhs_is_unknown -TypeInferOperators.or_joins_types -TypeInferOperators.or_joins_types_with_no_extras -TypeInferOperators.primitive_arith_possible_metatable TypeInferOperators.produce_the_correct_error_message_when_comparing_a_table_with_a_metatable_with_one_that_does_not TypeInferOperators.refine_and_or -TypeInferOperators.strict_binary_op_where_lhs_unknown -TypeInferOperators.strip_nil_from_lhs_or_operator -TypeInferOperators.strip_nil_from_lhs_or_operator2 TypeInferOperators.typecheck_overloaded_multiply_that_is_an_intersection TypeInferOperators.typecheck_overloaded_multiply_that_is_an_intersection_on_rhs TypeInferOperators.typecheck_unary_len_error @@ -535,26 +513,17 @@ TypePackTests.unify_variadic_tails_in_arguments TypePackTests.unify_variadic_tails_in_arguments_free TypePackTests.varargs_inference_through_multiple_scopes TypePackTests.variadic_packs -TypeSingletons.enums_using_singletons -TypeSingletons.enums_using_singletons_mismatch -TypeSingletons.enums_using_singletons_subtyping TypeSingletons.error_detailed_tagged_union_mismatch_bool TypeSingletons.error_detailed_tagged_union_mismatch_string TypeSingletons.function_call_with_singletons TypeSingletons.function_call_with_singletons_mismatch -TypeSingletons.if_then_else_expression_singleton_options TypeSingletons.indexing_on_string_singletons TypeSingletons.indexing_on_union_of_string_singletons -TypeSingletons.no_widening_from_callsites TypeSingletons.overloaded_function_call_with_singletons TypeSingletons.overloaded_function_call_with_singletons_mismatch TypeSingletons.return_type_of_f_is_not_widened -TypeSingletons.string_singleton_subtype -TypeSingletons.string_singletons -TypeSingletons.string_singletons_escape_chars -TypeSingletons.string_singletons_mismatch +TypeSingletons.table_properties_singleton_strings_mismatch TypeSingletons.table_properties_type_error_escapes -TypeSingletons.tagged_unions_using_singletons TypeSingletons.taking_the_length_of_string_singleton TypeSingletons.taking_the_length_of_union_of_string_singleton TypeSingletons.widen_the_supertype_if_it_is_free_and_subtype_has_singleton diff --git a/tools/test_dcr.py b/tools/test_dcr.py index 76bf11ac7..6d553b648 100644 --- a/tools/test_dcr.py +++ b/tools/test_dcr.py @@ -42,6 +42,8 @@ def __init__(self, failList): self.fail_count = 0 self.test_count = 0 + self.crashed_tests = [] + def startElement(self, name, attrs): if name == "TestSuite": self.currentTest.append(attrs["name"]) @@ -69,6 +71,10 @@ def startElement(self, name, attrs): elif name == "OverallResultsTestCases": self.numSkippedTests = safeParseInt(attrs.get("skipped", 0)) + elif name == "Exception": + if attrs.get("crash") == "true": + self.crashed_tests.append(makeDottedName(self.currentTest)) + def endElement(self, name): if name == "TestCase": self.currentTest.pop() @@ -192,15 +198,23 @@ def main(): print(name, file=f) print_stderr("Updated faillist.txt") + if handler.crashed_tests: + print_stderr() + for test in handler.crashed_tests: + print_stderr( + f"{c.Fore.RED}{test}{c.Fore.RESET} threw an exception and crashed the test process!" + ) + if handler.numSkippedTests > 0: - print_stderr( - f"{handler.numSkippedTests} test(s) were skipped! That probably means that a test segfaulted!" + print_stderr(f"{handler.numSkippedTests} test(s) were skipped!") + + ok = ( + not handler.crashed_tests + and handler.numSkippedTests == 0 + and all( + not passed == (dottedName in failList) + for dottedName, passed in handler.results.items() ) - sys.exit(1) - - ok = all( - not passed == (dottedName in failList) - for dottedName, passed in handler.results.items() ) if ok: From 99c0db3b0845b6e9450753d5ed45a2cf6f3e6e68 Mon Sep 17 00:00:00 2001 From: Vyacheslav Egorov Date: Fri, 28 Oct 2022 01:22:49 +0300 Subject: [PATCH 13/66] Sync to upstream/release/551 --- .../include/Luau/ConstraintGraphBuilder.h | 57 +- Analysis/include/Luau/Error.h | 24 +- Analysis/include/Luau/Normalize.h | 75 +- Analysis/include/Luau/RecursionCounter.h | 22 +- Analysis/include/Luau/ToString.h | 2 + Analysis/include/Luau/TypeInfer.h | 13 +- Analysis/include/Luau/Unifier.h | 2 + Analysis/src/AstQuery.cpp | 83 ++- Analysis/src/ConstraintGraphBuilder.cpp | 437 +++++------ Analysis/src/ConstraintSolver.cpp | 44 +- Analysis/src/EmbeddedBuiltinDefinitions.cpp | 66 +- Analysis/src/Error.cpp | 43 ++ Analysis/src/Frontend.cpp | 153 +++- Analysis/src/IostreamHelpers.cpp | 2 + Analysis/src/Normalize.cpp | 406 +++++++++- Analysis/src/ToString.cpp | 40 +- Analysis/src/TopoSortStatements.cpp | 5 +- Analysis/src/TypeAttach.cpp | 2 +- Analysis/src/TypeChecker2.cpp | 61 +- Analysis/src/TypeInfer.cpp | 81 +- Analysis/src/TypePack.cpp | 3 +- Analysis/src/TypeVar.cpp | 8 +- Analysis/src/Unifier.cpp | 61 +- Ast/include/Luau/ParseResult.h | 2 + Ast/include/Luau/Parser.h | 4 +- Ast/src/Parser.cpp | 64 +- CLI/Repl.cpp | 19 +- CMakeLists.txt | 5 + CodeGen/include/Luau/CodeGen.h | 2 +- CodeGen/src/CodeGen.cpp | 691 ++++++++++-------- CodeGen/src/CodeGenUtils.cpp | 76 ++ CodeGen/src/CodeGenUtils.h | 17 + CodeGen/src/EmitCommonX64.h | 15 +- CodeGen/src/EmitInstructionX64.cpp | 383 +++++++++- CodeGen/src/EmitInstructionX64.h | 16 +- CodeGen/src/NativeState.cpp | 23 +- CodeGen/src/NativeState.h | 8 + Compiler/include/Luau/BytecodeBuilder.h | 2 +- Compiler/src/BytecodeBuilder.cpp | 18 +- Makefile | 4 +- Sources.cmake | 3 + VM/include/luaconf.h | 8 + VM/src/lbuiltins.cpp | 26 +- VM/src/lbuiltins.h | 2 +- VM/src/lnumutils.h | 1 + VM/src/lvmexecute.cpp | 52 +- bench/tests/voxelgen.lua | 456 ++++++++++++ tests/AstQuery.test.cpp | 71 ++ tests/Fixture.cpp | 9 + tests/Fixture.h | 2 + tests/Frontend.test.cpp | 27 + tests/Module.test.cpp | 10 +- tests/Normalize.test.cpp | 141 +++- tests/Parser.test.cpp | 2 - tests/ToString.test.cpp | 2 + tests/TypeInfer.aliases.test.cpp | 10 + tests/TypeInfer.annotations.test.cpp | 10 +- tests/TypeInfer.anyerror.test.cpp | 2 - tests/TypeInfer.definitions.test.cpp | 17 + tests/TypeInfer.functions.test.cpp | 53 +- tests/TypeInfer.modules.test.cpp | 18 +- tests/TypeInfer.negations.test.cpp | 52 ++ tests/TypeInfer.operators.test.cpp | 42 +- tests/TypeInfer.provisional.test.cpp | 13 +- tests/TypeInfer.tables.test.cpp | 61 +- tests/VisitTypeVar.test.cpp | 9 +- tools/faillist.txt | 23 +- 67 files changed, 3226 insertions(+), 935 deletions(-) create mode 100644 CodeGen/src/CodeGenUtils.cpp create mode 100644 CodeGen/src/CodeGenUtils.h create mode 100644 bench/tests/voxelgen.lua create mode 100644 tests/TypeInfer.negations.test.cpp diff --git a/Analysis/include/Luau/ConstraintGraphBuilder.h b/Analysis/include/Luau/ConstraintGraphBuilder.h index dc5d45988..6106717c5 100644 --- a/Analysis/include/Luau/ConstraintGraphBuilder.h +++ b/Analysis/include/Luau/ConstraintGraphBuilder.h @@ -23,6 +23,30 @@ using ScopePtr = std::shared_ptr; struct DcrLogger; +struct Inference +{ + TypeId ty = nullptr; + + Inference() = default; + + explicit Inference(TypeId ty) + : ty(ty) + { + } +}; + +struct InferencePack +{ + TypePackId tp = nullptr; + + InferencePack() = default; + + explicit InferencePack(TypePackId tp) + : tp(tp) + { + } +}; + struct ConstraintGraphBuilder { // A list of all the scopes in the module. This vector holds ownership of the @@ -130,8 +154,10 @@ struct ConstraintGraphBuilder void visit(const ScopePtr& scope, AstStatDeclareFunction* declareFunction); void visit(const ScopePtr& scope, AstStatError* error); - TypePackId checkPack(const ScopePtr& scope, AstArray exprs, const std::vector& expectedTypes = {}); - TypePackId checkPack(const ScopePtr& scope, AstExpr* expr, const std::vector& expectedTypes = {}); + InferencePack checkPack(const ScopePtr& scope, AstArray exprs, const std::vector& expectedTypes = {}); + InferencePack checkPack(const ScopePtr& scope, AstExpr* expr, const std::vector& expectedTypes = {}); + + InferencePack checkPack(const ScopePtr& scope, AstExprCall* call, const std::vector& expectedTypes); /** * Checks an expression that is expected to evaluate to one type. @@ -141,18 +167,19 @@ struct ConstraintGraphBuilder * surrounding context. Used to implement bidirectional type checking. * @return the type of the expression. */ - TypeId check(const ScopePtr& scope, AstExpr* expr, std::optional expectedType = {}); - - TypeId check(const ScopePtr& scope, AstExprLocal* local); - TypeId check(const ScopePtr& scope, AstExprGlobal* global); - TypeId check(const ScopePtr& scope, AstExprIndexName* indexName); - TypeId check(const ScopePtr& scope, AstExprIndexExpr* indexExpr); - TypeId check(const ScopePtr& scope, AstExprUnary* unary); - TypeId check_(const ScopePtr& scope, AstExprUnary* unary); - TypeId check(const ScopePtr& scope, AstExprBinary* binary, std::optional expectedType); - TypeId check(const ScopePtr& scope, AstExprIfElse* ifElse, std::optional expectedType); - TypeId check(const ScopePtr& scope, AstExprTypeAssertion* typeAssert); - TypeId check(const ScopePtr& scope, AstExprTable* expr, std::optional expectedType); + Inference check(const ScopePtr& scope, AstExpr* expr, std::optional expectedType = {}); + + Inference check(const ScopePtr& scope, AstExprConstantString* string, std::optional expectedType); + Inference check(const ScopePtr& scope, AstExprConstantBool* bool_, std::optional expectedType); + Inference check(const ScopePtr& scope, AstExprLocal* local); + Inference check(const ScopePtr& scope, AstExprGlobal* global); + Inference check(const ScopePtr& scope, AstExprIndexName* indexName); + Inference check(const ScopePtr& scope, AstExprIndexExpr* indexExpr); + Inference check(const ScopePtr& scope, AstExprUnary* unary); + Inference check(const ScopePtr& scope, AstExprBinary* binary, std::optional expectedType); + Inference check(const ScopePtr& scope, AstExprIfElse* ifElse, std::optional expectedType); + Inference check(const ScopePtr& scope, AstExprTypeAssertion* typeAssert); + Inference check(const ScopePtr& scope, AstExprTable* expr, std::optional expectedType); TypePackId checkLValues(const ScopePtr& scope, AstArray exprs); @@ -202,7 +229,7 @@ struct ConstraintGraphBuilder std::vector> createGenerics(const ScopePtr& scope, AstArray generics); std::vector> createGenericPacks(const ScopePtr& scope, AstArray packs); - TypeId flattenPack(const ScopePtr& scope, Location location, TypePackId tp); + Inference flattenPack(const ScopePtr& scope, Location location, InferencePack pack); void reportError(Location location, TypeErrorData err); void reportCodeTooComplex(Location location); diff --git a/Analysis/include/Luau/Error.h b/Analysis/include/Luau/Error.h index 7338627cf..f7bd9d502 100644 --- a/Analysis/include/Luau/Error.h +++ b/Analysis/include/Luau/Error.h @@ -7,6 +7,8 @@ #include "Luau/Variant.h" #include "Luau/TypeArena.h" +LUAU_FASTFLAG(LuauIceExceptionInheritanceChange) + namespace Luau { struct TypeError; @@ -302,12 +304,20 @@ struct NormalizationTooComplex } }; +struct TypePackMismatch +{ + TypePackId wantedTp; + TypePackId givenTp; + + bool operator==(const TypePackMismatch& rhs) const; +}; + using TypeErrorData = Variant; + TypesAreUnrelated, NormalizationTooComplex, TypePackMismatch>; struct TypeError { @@ -374,6 +384,10 @@ struct InternalErrorReporter class InternalCompilerError : public std::exception { public: + explicit InternalCompilerError(const std::string& message) + : message(message) + { + } explicit InternalCompilerError(const std::string& message, const std::string& moduleName) : message(message) , moduleName(moduleName) @@ -388,8 +402,14 @@ class InternalCompilerError : public std::exception virtual const char* what() const throw(); const std::string message; - const std::string moduleName; + const std::optional moduleName; const std::optional location; }; +// These two function overloads only exist to facilitate fast flagging a change to InternalCompilerError +// Both functions can be removed when FFlagLuauIceExceptionInheritanceChange is removed and calling code +// can directly throw InternalCompilerError. +[[noreturn]] void throwRuntimeError(const std::string& message); +[[noreturn]] void throwRuntimeError(const std::string& message, const std::string& moduleName); + } // namespace Luau diff --git a/Analysis/include/Luau/Normalize.h b/Analysis/include/Luau/Normalize.h index a23d0fda0..f98442dd1 100644 --- a/Analysis/include/Luau/Normalize.h +++ b/Analysis/include/Luau/Normalize.h @@ -106,9 +106,68 @@ struct std::equal_to namespace Luau { -// A normalized string type is either `string` (represented by `nullopt`) -// or a union of string singletons. -using NormalizedStringType = std::optional>; +/** A normalized string type is either `string` (represented by `nullopt`) or a + * union of string singletons. + * + * When FFlagLuauNegatedStringSingletons is unset, the representation is as + * follows: + * + * * The `string` data type is represented by the option `singletons` having the + * value `std::nullopt`. + * * The type `never` is represented by `singletons` being populated with an + * empty map. + * * A union of string singletons is represented by a map populated by the names + * and TypeIds of the singletons contained therein. + * + * When FFlagLuauNegatedStringSingletons is set, the representation is as + * follows: + * + * * A union of string singletons is finite and includes the singletons named by + * the `singletons` field. + * * An intersection of negated string singletons is cofinite and includes the + * singletons excluded by the `singletons` field. It is implied that cofinite + * values are exclusions from `string` itself. + * * The `string` data type is a cofinite set minus zero elements. + * * The `never` data type is a finite set plus zero elements. + */ +struct NormalizedStringType +{ + // When false, this type represents a union of singleton string types. + // eg "a" | "b" | "c" + // + // When true, this type represents string intersected with negated string + // singleton types. + // eg string & ~"a" & ~"b" & ... + bool isCofinite = false; + + // TODO: This field cannot be nullopt when FFlagLuauNegatedStringSingletons + // is set. When clipping that flag, we can remove the wrapping optional. + std::optional> singletons; + + void resetToString(); + void resetToNever(); + + bool isNever() const; + bool isString() const; + + /// Returns true if the string has finite domain. + /// + /// Important subtlety: This method returns true for `never`. The empty set + /// is indeed an empty set. + bool isUnion() const; + + /// Returns true if the string has infinite domain. + bool isIntersection() const; + + bool includes(const std::string& str) const; + + static const NormalizedStringType never; + + NormalizedStringType() = default; + NormalizedStringType(bool isCofinite, std::optional> singletons); +}; + +bool isSubtype(const NormalizedStringType& subStr, const NormalizedStringType& superStr); // A normalized function type is either `never` (represented by `nullopt`) // or an intersection of function types. @@ -157,7 +216,7 @@ struct NormalizedType // The string part of the type. // This may be the `string` type, or a union of singletons. - NormalizedStringType strings = std::map{}; + NormalizedStringType strings; // The thread part of the type. // This type is either never or thread. @@ -231,8 +290,14 @@ class Normalizer bool unionNormals(NormalizedType& here, const NormalizedType& there, int ignoreSmallerTyvars = -1); bool unionNormalWithTy(NormalizedType& here, TypeId there, int ignoreSmallerTyvars = -1); + // ------- Negations + NormalizedType negateNormal(const NormalizedType& here); + TypeIds negateAll(const TypeIds& theres); + TypeId negate(TypeId there); + void subtractPrimitive(NormalizedType& here, TypeId ty); + void subtractSingleton(NormalizedType& here, TypeId ty); + // ------- Normalizing intersections - void intersectTysWithTy(TypeIds& here, TypeId there); TypeId intersectionOfTops(TypeId here, TypeId there); TypeId intersectionOfBools(TypeId here, TypeId there); void intersectClasses(TypeIds& heres, const TypeIds& theres); diff --git a/Analysis/include/Luau/RecursionCounter.h b/Analysis/include/Luau/RecursionCounter.h index f964dbfe8..632afd195 100644 --- a/Analysis/include/Luau/RecursionCounter.h +++ b/Analysis/include/Luau/RecursionCounter.h @@ -2,6 +2,7 @@ #pragma once #include "Luau/Common.h" +#include "Luau/Error.h" #include #include @@ -9,10 +10,20 @@ namespace Luau { -struct RecursionLimitException : public std::exception +struct RecursionLimitException : public InternalCompilerError +{ + RecursionLimitException() + : InternalCompilerError("Internal recursion counter limit exceeded") + { + LUAU_ASSERT(FFlag::LuauIceExceptionInheritanceChange); + } +}; + +struct RecursionLimitException_DEPRECATED : public std::exception { const char* what() const noexcept { + LUAU_ASSERT(!FFlag::LuauIceExceptionInheritanceChange); return "Internal recursion counter limit exceeded"; } }; @@ -42,7 +53,14 @@ struct RecursionLimiter : RecursionCounter { if (limit > 0 && *count > limit) { - throw RecursionLimitException(); + if (FFlag::LuauIceExceptionInheritanceChange) + { + throw RecursionLimitException(); + } + else + { + throw RecursionLimitException_DEPRECATED(); + } } } }; diff --git a/Analysis/include/Luau/ToString.h b/Analysis/include/Luau/ToString.h index dd2d709bc..ff2561e65 100644 --- a/Analysis/include/Luau/ToString.h +++ b/Analysis/include/Luau/ToString.h @@ -117,6 +117,8 @@ inline std::string toStringNamedFunction(const std::string& funcName, const Func return toStringNamedFunction(funcName, ftv, opts); } +std::optional getFunctionNameAsString(const AstExpr& expr); + // It could be useful to see the text representation of a type during a debugging session instead of exploring the content of the class // These functions will dump the type to stdout and can be evaluated in Watch/Immediate windows or as gdb/lldb expression std::string dump(TypeId ty); diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index 384637bbc..c5d7501dc 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -48,7 +48,17 @@ struct HashBoolNamePair size_t operator()(const std::pair& pair) const; }; -class TimeLimitError : public std::exception +class TimeLimitError : public InternalCompilerError +{ +public: + explicit TimeLimitError(const std::string& moduleName) + : InternalCompilerError("Typeinfer failed to complete in allotted time", moduleName) + { + LUAU_ASSERT(FFlag::LuauIceExceptionInheritanceChange); + } +}; + +class TimeLimitError_DEPRECATED : public std::exception { public: virtual const char* what() const throw(); @@ -236,6 +246,7 @@ struct TypeChecker [[noreturn]] void ice(const std::string& message, const Location& location); [[noreturn]] void ice(const std::string& message); + [[noreturn]] void throwTimeLimitError(); ScopePtr childFunctionScope(const ScopePtr& parent, const Location& location, int subLevel = 0); ScopePtr childScope(const ScopePtr& parent, const Location& location); diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index c15cae31d..7bf4d50b7 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -96,6 +96,8 @@ struct Unifier void tryUnifyScalarShape(TypeId subTy, TypeId superTy, bool reversed); void tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed); void tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed); + void tryUnifyTypeWithNegation(TypeId subTy, TypeId superTy); + void tryUnifyNegationWithType(TypeId subTy, TypeId superTy); TypePackId tryApplyOverloadedFunction(TypeId function, const NormalizedFunctionType& overloads, TypePackId args); diff --git a/Analysis/src/AstQuery.cpp b/Analysis/src/AstQuery.cpp index 502997048..b93c2cc22 100644 --- a/Analysis/src/AstQuery.cpp +++ b/Analysis/src/AstQuery.cpp @@ -11,6 +11,8 @@ #include +LUAU_FASTFLAGVARIABLE(LuauCheckOverloadedDocSymbol, false) + namespace Luau { @@ -427,6 +429,38 @@ ExprOrLocal findExprOrLocalAtPosition(const SourceModule& source, Position pos) return findVisitor.result; } +static std::optional checkOverloadedDocumentationSymbol( + const Module& module, const TypeId ty, const AstExpr* parentExpr, const std::optional documentationSymbol) +{ + LUAU_ASSERT(FFlag::LuauCheckOverloadedDocSymbol); + + if (!documentationSymbol) + return std::nullopt; + + // This might be an overloaded function. + if (get(follow(ty))) + { + TypeId matchingOverload = nullptr; + if (parentExpr && parentExpr->is()) + { + if (auto it = module.astOverloadResolvedTypes.find(parentExpr)) + { + matchingOverload = *it; + } + } + + if (matchingOverload) + { + std::string overloadSymbol = *documentationSymbol + "/overload/"; + // Default toString options are fine for this purpose. + overloadSymbol += toString(matchingOverload); + return overloadSymbol; + } + } + + return documentationSymbol; +} + std::optional getDocumentationSymbolAtPosition(const SourceModule& source, const Module& module, Position position) { std::vector ancestry = findAstAncestryOfPosition(source, position); @@ -436,31 +470,38 @@ std::optional getDocumentationSymbolAtPosition(const Source if (std::optional binding = findBindingAtPosition(module, source, position)) { - if (binding->documentationSymbol) + if (FFlag::LuauCheckOverloadedDocSymbol) { - // This might be an overloaded function binding. - if (get(follow(binding->typeId))) + return checkOverloadedDocumentationSymbol(module, binding->typeId, parentExpr, binding->documentationSymbol); + } + else + { + if (binding->documentationSymbol) { - TypeId matchingOverload = nullptr; - if (parentExpr && parentExpr->is()) + // This might be an overloaded function binding. + if (get(follow(binding->typeId))) { - if (auto it = module.astOverloadResolvedTypes.find(parentExpr)) + TypeId matchingOverload = nullptr; + if (parentExpr && parentExpr->is()) { - matchingOverload = *it; + if (auto it = module.astOverloadResolvedTypes.find(parentExpr)) + { + matchingOverload = *it; + } } - } - if (matchingOverload) - { - std::string overloadSymbol = *binding->documentationSymbol + "/overload/"; - // Default toString options are fine for this purpose. - overloadSymbol += toString(matchingOverload); - return overloadSymbol; + if (matchingOverload) + { + std::string overloadSymbol = *binding->documentationSymbol + "/overload/"; + // Default toString options are fine for this purpose. + overloadSymbol += toString(matchingOverload); + return overloadSymbol; + } } } - } - return binding->documentationSymbol; + return binding->documentationSymbol; + } } if (targetExpr) @@ -474,14 +515,20 @@ std::optional getDocumentationSymbolAtPosition(const Source { if (auto propIt = ttv->props.find(indexName->index.value); propIt != ttv->props.end()) { - return propIt->second.documentationSymbol; + if (FFlag::LuauCheckOverloadedDocSymbol) + return checkOverloadedDocumentationSymbol(module, propIt->second.type, parentExpr, propIt->second.documentationSymbol); + else + return propIt->second.documentationSymbol; } } else if (const ClassTypeVar* ctv = get(parentTy)) { if (auto propIt = ctv->props.find(indexName->index.value); propIt != ctv->props.end()) { - return propIt->second.documentationSymbol; + if (FFlag::LuauCheckOverloadedDocSymbol) + return checkOverloadedDocumentationSymbol(module, propIt->second.type, parentExpr, propIt->second.documentationSymbol); + else + return propIt->second.documentationSymbol; } } } diff --git a/Analysis/src/ConstraintGraphBuilder.cpp b/Analysis/src/ConstraintGraphBuilder.cpp index de2b0a4e1..455fc221d 100644 --- a/Analysis/src/ConstraintGraphBuilder.cpp +++ b/Analysis/src/ConstraintGraphBuilder.cpp @@ -263,7 +263,7 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocal* local) if (hasAnnotation) expectedTypes.insert(begin(expectedTypes), begin(varTypes) + i, end(varTypes)); - TypePackId exprPack = checkPack(scope, value, expectedTypes); + TypePackId exprPack = checkPack(scope, value, expectedTypes).tp; if (i < local->vars.size) { @@ -292,7 +292,7 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocal* local) if (hasAnnotation) expectedType = varTypes.at(i); - TypeId exprType = check(scope, value, expectedType); + TypeId exprType = check(scope, value, expectedType).ty; if (i < varTypes.size()) { if (varTypes[i]) @@ -350,7 +350,7 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFor* for_) if (!expr) return; - TypeId t = check(scope, expr); + TypeId t = check(scope, expr).ty; addConstraint(scope, expr->location, SubtypeConstraint{t, singletonTypes->numberType}); }; @@ -368,7 +368,7 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatForIn* forIn) { ScopePtr loopScope = childScope(forIn, scope); - TypePackId iterator = checkPack(scope, forIn->values); + TypePackId iterator = checkPack(scope, forIn->values).tp; std::vector variableTypes; variableTypes.reserve(forIn->vars.size); @@ -489,7 +489,7 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFunction* funct } else if (AstExprIndexName* indexName = function->name->as()) { - TypeId containingTableType = check(scope, indexName->expr); + TypeId containingTableType = check(scope, indexName->expr).ty; functionType = arena->addType(BlockedTypeVar{}); @@ -531,7 +531,7 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatReturn* ret) for (TypeId ty : scope->returnType) expectedTypes.push_back(ty); - TypePackId exprTypes = checkPack(scope, ret->list, expectedTypes); + TypePackId exprTypes = checkPack(scope, ret->list, expectedTypes).tp; addConstraint(scope, ret->location, PackSubtypeConstraint{exprTypes, scope->returnType}); } @@ -545,7 +545,7 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatBlock* block) void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatAssign* assign) { TypePackId varPackId = checkLValues(scope, assign->vars); - TypePackId valuePack = checkPack(scope, assign->values); + TypePackId valuePack = checkPack(scope, assign->values).tp; addConstraint(scope, assign->location, PackSubtypeConstraint{valuePack, varPackId}); } @@ -732,7 +732,6 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareClass* d void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareFunction* global) { - std::vector> generics = createGenerics(scope, global->generics); std::vector> genericPacks = createGenericPacks(scope, global->genericPacks); @@ -779,7 +778,7 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatError* error) check(scope, expr); } -TypePackId ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstArray exprs, const std::vector& expectedTypes) +InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstArray exprs, const std::vector& expectedTypes) { std::vector head; std::optional tail; @@ -792,201 +791,180 @@ TypePackId ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstArray expectedType; if (i < expectedTypes.size()) expectedType = expectedTypes[i]; - head.push_back(check(scope, expr)); + head.push_back(check(scope, expr).ty); } else { std::vector expectedTailTypes; if (i < expectedTypes.size()) expectedTailTypes.assign(begin(expectedTypes) + i, end(expectedTypes)); - tail = checkPack(scope, expr, expectedTailTypes); + tail = checkPack(scope, expr, expectedTailTypes).tp; } } if (head.empty() && tail) - return *tail; + return InferencePack{*tail}; else - return arena->addTypePack(TypePack{std::move(head), tail}); + return InferencePack{arena->addTypePack(TypePack{std::move(head), tail})}; } -TypePackId ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExpr* expr, const std::vector& expectedTypes) +InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExpr* expr, const std::vector& expectedTypes) { RecursionCounter counter{&recursionCount}; if (recursionCount >= FInt::LuauCheckRecursionLimit) { reportCodeTooComplex(expr->location); - return singletonTypes->errorRecoveryTypePack(); + return InferencePack{singletonTypes->errorRecoveryTypePack()}; } - TypePackId result = nullptr; + InferencePack result; if (AstExprCall* call = expr->as()) - { - TypeId fnType = check(scope, call->func); - const size_t constraintIndex = scope->constraints.size(); - const size_t scopeIndex = scopes.size(); - - std::vector args; - - for (AstExpr* arg : call->args) - { - args.push_back(check(scope, arg)); - } - - // TODO self - - if (matchSetmetatable(*call)) - { - LUAU_ASSERT(args.size() == 2); - TypeId target = args[0]; - TypeId mt = args[1]; - - MetatableTypeVar mtv{target, mt}; - TypeId resultTy = arena->addType(mtv); - result = arena->addTypePack({resultTy}); - } - else - { - const size_t constraintEndIndex = scope->constraints.size(); - const size_t scopeEndIndex = scopes.size(); - - astOriginalCallTypes[call->func] = fnType; - - TypeId instantiatedType = arena->addType(BlockedTypeVar{}); - // TODO: How do expectedTypes play into this? Do they? - TypePackId rets = arena->addTypePack(BlockedTypePack{}); - TypePackId argPack = arena->addTypePack(TypePack{args, {}}); - FunctionTypeVar ftv(TypeLevel{}, scope.get(), argPack, rets); - TypeId inferredFnType = arena->addType(ftv); - - scope->unqueuedConstraints.push_back( - std::make_unique(NotNull{scope.get()}, call->func->location, InstantiationConstraint{instantiatedType, fnType})); - NotNull ic(scope->unqueuedConstraints.back().get()); - - scope->unqueuedConstraints.push_back( - std::make_unique(NotNull{scope.get()}, call->func->location, SubtypeConstraint{inferredFnType, instantiatedType})); - NotNull sc(scope->unqueuedConstraints.back().get()); - - // We force constraints produced by checking function arguments to wait - // until after we have resolved the constraint on the function itself. - // This ensures, for instance, that we start inferring the contents of - // lambdas under the assumption that their arguments and return types - // will be compatible with the enclosing function call. - for (size_t ci = constraintIndex; ci < constraintEndIndex; ++ci) - scope->constraints[ci]->dependencies.push_back(sc); - - for (size_t si = scopeIndex; si < scopeEndIndex; ++si) - { - for (auto& c : scopes[si].second->constraints) - { - c->dependencies.push_back(sc); - } - } - - addConstraint(scope, call->func->location, - FunctionCallConstraint{ - {ic, sc}, - fnType, - argPack, - rets, - call, - }); - - result = rets; - } - } + result = {checkPack(scope, call, expectedTypes)}; else if (AstExprVarargs* varargs = expr->as()) { if (scope->varargPack) - result = *scope->varargPack; + result = InferencePack{*scope->varargPack}; else - result = singletonTypes->errorRecoveryTypePack(); + result = InferencePack{singletonTypes->errorRecoveryTypePack()}; } else { std::optional expectedType; if (!expectedTypes.empty()) expectedType = expectedTypes[0]; - TypeId t = check(scope, expr, expectedType); - result = arena->addTypePack({t}); + TypeId t = check(scope, expr, expectedType).ty; + result = InferencePack{arena->addTypePack({t})}; } - LUAU_ASSERT(result); - astTypePacks[expr] = result; + LUAU_ASSERT(result.tp); + astTypePacks[expr] = result.tp; return result; } -TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExpr* expr, std::optional expectedType) +InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCall* call, const std::vector& expectedTypes) { - RecursionCounter counter{&recursionCount}; + TypeId fnType = check(scope, call->func).ty; + const size_t constraintIndex = scope->constraints.size(); + const size_t scopeIndex = scopes.size(); - if (recursionCount >= FInt::LuauCheckRecursionLimit) + std::vector args; + + for (AstExpr* arg : call->args) { - reportCodeTooComplex(expr->location); - return singletonTypes->errorRecoveryType(); + args.push_back(check(scope, arg).ty); } - TypeId result = nullptr; + // TODO self - if (auto group = expr->as()) - result = check(scope, group->expr, expectedType); - else if (auto stringExpr = expr->as()) + if (matchSetmetatable(*call)) { - if (expectedType) - { - const TypeId expectedTy = follow(*expectedType); - if (get(expectedTy) || get(expectedTy)) - { - result = arena->addType(BlockedTypeVar{}); - TypeId singletonType = arena->addType(SingletonTypeVar(StringSingleton{std::string(stringExpr->value.data, stringExpr->value.size)})); - addConstraint(scope, expr->location, PrimitiveTypeConstraint{result, expectedTy, singletonType, singletonTypes->stringType}); - } - else if (maybeSingleton(expectedTy)) - result = arena->addType(SingletonTypeVar{StringSingleton{std::string{stringExpr->value.data, stringExpr->value.size}}}); - else - result = singletonTypes->stringType; - } - else - result = singletonTypes->stringType; + LUAU_ASSERT(args.size() == 2); + TypeId target = args[0]; + TypeId mt = args[1]; + + AstExpr* targetExpr = call->args.data[0]; + + MetatableTypeVar mtv{target, mt}; + TypeId resultTy = arena->addType(mtv); + + if (AstExprLocal* targetLocal = targetExpr->as()) + scope->bindings[targetLocal->local].typeId = resultTy; + + return InferencePack{arena->addTypePack({resultTy})}; } - else if (expr->is()) - result = singletonTypes->numberType; - else if (auto boolExpr = expr->as()) + else { - if (expectedType) - { - const TypeId expectedTy = follow(*expectedType); - const TypeId singletonType = boolExpr->value ? singletonTypes->trueType : singletonTypes->falseType; + const size_t constraintEndIndex = scope->constraints.size(); + const size_t scopeEndIndex = scopes.size(); + + astOriginalCallTypes[call->func] = fnType; + + TypeId instantiatedType = arena->addType(BlockedTypeVar{}); + // TODO: How do expectedTypes play into this? Do they? + TypePackId rets = arena->addTypePack(BlockedTypePack{}); + TypePackId argPack = arena->addTypePack(TypePack{args, {}}); + FunctionTypeVar ftv(TypeLevel{}, scope.get(), argPack, rets); + TypeId inferredFnType = arena->addType(ftv); + + scope->unqueuedConstraints.push_back( + std::make_unique(NotNull{scope.get()}, call->func->location, InstantiationConstraint{instantiatedType, fnType})); + NotNull ic(scope->unqueuedConstraints.back().get()); - if (get(expectedTy) || get(expectedTy)) + scope->unqueuedConstraints.push_back( + std::make_unique(NotNull{scope.get()}, call->func->location, SubtypeConstraint{inferredFnType, instantiatedType})); + NotNull sc(scope->unqueuedConstraints.back().get()); + + // We force constraints produced by checking function arguments to wait + // until after we have resolved the constraint on the function itself. + // This ensures, for instance, that we start inferring the contents of + // lambdas under the assumption that their arguments and return types + // will be compatible with the enclosing function call. + for (size_t ci = constraintIndex; ci < constraintEndIndex; ++ci) + scope->constraints[ci]->dependencies.push_back(sc); + + for (size_t si = scopeIndex; si < scopeEndIndex; ++si) + { + for (auto& c : scopes[si].second->constraints) { - result = arena->addType(BlockedTypeVar{}); - addConstraint(scope, expr->location, PrimitiveTypeConstraint{result, expectedTy, singletonType, singletonTypes->booleanType}); + c->dependencies.push_back(sc); } - else if (maybeSingleton(expectedTy)) - result = singletonType; - else - result = singletonTypes->booleanType; } - else - result = singletonTypes->booleanType; + + addConstraint(scope, call->func->location, + FunctionCallConstraint{ + {ic, sc}, + fnType, + argPack, + rets, + call, + }); + + return InferencePack{rets}; } +} + +Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExpr* expr, std::optional expectedType) +{ + RecursionCounter counter{&recursionCount}; + + if (recursionCount >= FInt::LuauCheckRecursionLimit) + { + reportCodeTooComplex(expr->location); + return Inference{singletonTypes->errorRecoveryType()}; + } + + Inference result; + + if (auto group = expr->as()) + result = check(scope, group->expr, expectedType); + else if (auto stringExpr = expr->as()) + result = check(scope, stringExpr, expectedType); + else if (expr->is()) + result = Inference{singletonTypes->numberType}; + else if (auto boolExpr = expr->as()) + result = check(scope, boolExpr, expectedType); else if (expr->is()) - result = singletonTypes->nilType; + result = Inference{singletonTypes->nilType}; else if (auto local = expr->as()) result = check(scope, local); else if (auto global = expr->as()) result = check(scope, global); else if (expr->is()) result = flattenPack(scope, expr->location, checkPack(scope, expr)); - else if (expr->is()) - result = flattenPack(scope, expr->location, checkPack(scope, expr)); // TODO: needs predicates too + else if (auto call = expr->as()) + { + std::vector expectedTypes; + if (expectedType) + expectedTypes.push_back(*expectedType); + result = flattenPack(scope, expr->location, checkPack(scope, call, expectedTypes)); // TODO: needs predicates too + } else if (auto a = expr->as()) { FunctionSignature sig = checkFunctionSignature(scope, a); checkFunctionBody(sig.bodyScope, a); - return sig.signature; + return Inference{sig.signature}; } else if (auto indexName = expr->as()) result = check(scope, indexName); @@ -1008,20 +986,63 @@ TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExpr* expr, std:: for (AstExpr* subExpr : err->expressions) check(scope, subExpr); - result = singletonTypes->errorRecoveryType(); + result = Inference{singletonTypes->errorRecoveryType()}; } else { LUAU_ASSERT(0); - result = freshType(scope); + result = Inference{freshType(scope)}; } - LUAU_ASSERT(result); - astTypes[expr] = result; + LUAU_ASSERT(result.ty); + astTypes[expr] = result.ty; return result; } -TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprLocal* local) +Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprConstantString* string, std::optional expectedType) +{ + if (expectedType) + { + const TypeId expectedTy = follow(*expectedType); + if (get(expectedTy) || get(expectedTy)) + { + TypeId ty = arena->addType(BlockedTypeVar{}); + TypeId singletonType = arena->addType(SingletonTypeVar(StringSingleton{std::string(string->value.data, string->value.size)})); + addConstraint(scope, string->location, PrimitiveTypeConstraint{ty, expectedTy, singletonType, singletonTypes->stringType}); + return Inference{ty}; + } + else if (maybeSingleton(expectedTy)) + return Inference{arena->addType(SingletonTypeVar{StringSingleton{std::string{string->value.data, string->value.size}}})}; + + return Inference{singletonTypes->stringType}; + } + + return Inference{singletonTypes->stringType}; +} + +Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprConstantBool* boolExpr, std::optional expectedType) +{ + if (expectedType) + { + const TypeId expectedTy = follow(*expectedType); + const TypeId singletonType = boolExpr->value ? singletonTypes->trueType : singletonTypes->falseType; + + if (get(expectedTy) || get(expectedTy)) + { + TypeId ty = arena->addType(BlockedTypeVar{}); + addConstraint(scope, boolExpr->location, PrimitiveTypeConstraint{ty, expectedTy, singletonType, singletonTypes->booleanType}); + return Inference{ty}; + } + else if (maybeSingleton(expectedTy)) + return Inference{singletonType}; + + return Inference{singletonTypes->booleanType}; + } + + return Inference{singletonTypes->booleanType}; +} + +Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprLocal* local) { std::optional resultTy; @@ -1035,26 +1056,26 @@ TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprLocal* local) } if (!resultTy) - return singletonTypes->errorRecoveryType(); // TODO: replace with ice, locals should never exist before its definition. + return Inference{singletonTypes->errorRecoveryType()}; // TODO: replace with ice, locals should never exist before its definition. - return *resultTy; + return Inference{*resultTy}; } -TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprGlobal* global) +Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprGlobal* global) { if (std::optional ty = scope->lookup(global->name)) - return *ty; + return Inference{*ty}; /* prepopulateGlobalScope() has already added all global functions to the environment by this point, so any * global that is not already in-scope is definitely an unknown symbol. */ reportError(global->location, UnknownSymbol{global->name.value}); - return singletonTypes->errorRecoveryType(); + return Inference{singletonTypes->errorRecoveryType()}; } -TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIndexName* indexName) +Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIndexName* indexName) { - TypeId obj = check(scope, indexName->expr); + TypeId obj = check(scope, indexName->expr).ty; TypeId result = freshType(scope); TableTypeVar::Props props{{indexName->index.value, Property{result}}}; @@ -1065,13 +1086,13 @@ TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIndexName* in addConstraint(scope, indexName->expr->location, SubtypeConstraint{obj, expectedTableType}); - return result; + return Inference{result}; } -TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIndexExpr* indexExpr) +Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIndexExpr* indexExpr) { - TypeId obj = check(scope, indexExpr->expr); - TypeId indexType = check(scope, indexExpr->index); + TypeId obj = check(scope, indexExpr->expr).ty; + TypeId indexType = check(scope, indexExpr->index).ty; TypeId result = freshType(scope); @@ -1081,61 +1102,49 @@ TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIndexExpr* in addConstraint(scope, indexExpr->expr->location, SubtypeConstraint{obj, tableType}); - return result; + return Inference{result}; } -TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprUnary* unary) +Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprUnary* unary) { - TypeId operandType = check_(scope, unary); + TypeId operandType = check(scope, unary->expr).ty; TypeId resultType = arena->addType(BlockedTypeVar{}); addConstraint(scope, unary->location, UnaryConstraint{unary->op, operandType, resultType}); - return resultType; + return Inference{resultType}; } -TypeId ConstraintGraphBuilder::check_(const ScopePtr& scope, AstExprUnary* unary) +Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprBinary* binary, std::optional expectedType) { - if (unary->op == AstExprUnary::Not) - { - TypeId ty = check(scope, unary->expr, std::nullopt); - - return ty; - } - - return check(scope, unary->expr); -} - -TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprBinary* binary, std::optional expectedType) -{ - TypeId leftType = check(scope, binary->left, expectedType); - TypeId rightType = check(scope, binary->right, expectedType); + TypeId leftType = check(scope, binary->left, expectedType).ty; + TypeId rightType = check(scope, binary->right, expectedType).ty; TypeId resultType = arena->addType(BlockedTypeVar{}); addConstraint(scope, binary->location, BinaryConstraint{binary->op, leftType, rightType, resultType}); - return resultType; + return Inference{resultType}; } -TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIfElse* ifElse, std::optional expectedType) +Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIfElse* ifElse, std::optional expectedType) { check(scope, ifElse->condition); - TypeId thenType = check(scope, ifElse->trueExpr, expectedType); - TypeId elseType = check(scope, ifElse->falseExpr, expectedType); + TypeId thenType = check(scope, ifElse->trueExpr, expectedType).ty; + TypeId elseType = check(scope, ifElse->falseExpr, expectedType).ty; if (ifElse->hasElse) { TypeId resultType = expectedType ? *expectedType : freshType(scope); addConstraint(scope, ifElse->trueExpr->location, SubtypeConstraint{thenType, resultType}); addConstraint(scope, ifElse->falseExpr->location, SubtypeConstraint{elseType, resultType}); - return resultType; + return Inference{resultType}; } - return thenType; + return Inference{thenType}; } -TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprTypeAssertion* typeAssert) +Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprTypeAssertion* typeAssert) { check(scope, typeAssert->expr, std::nullopt); - return resolveType(scope, typeAssert->annotation); + return Inference{resolveType(scope, typeAssert->annotation)}; } TypePackId ConstraintGraphBuilder::checkLValues(const ScopePtr& scope, AstArray exprs) @@ -1286,22 +1295,22 @@ TypeId ConstraintGraphBuilder::checkLValue(const ScopePtr& scope, AstExpr* expr) auto dottedPath = extractDottedName(expr); if (!dottedPath) - return check(scope, expr); + return check(scope, expr).ty; const auto [sym, segments] = std::move(*dottedPath); if (!sym.local) - return check(scope, expr); + return check(scope, expr).ty; auto lookupResult = scope->lookupEx(sym); if (!lookupResult) - return check(scope, expr); + return check(scope, expr).ty; const auto [ty, symbolScope] = std::move(*lookupResult); TypeId replaceTy = arena->freshType(scope.get()); std::optional updatedType = updateTheTableType(arena, ty, segments, replaceTy); if (!updatedType) - return check(scope, expr); + return check(scope, expr).ty; std::optional def = dfg->getDef(sym); LUAU_ASSERT(def); @@ -1310,7 +1319,7 @@ TypeId ConstraintGraphBuilder::checkLValue(const ScopePtr& scope, AstExpr* expr) return replaceTy; } -TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprTable* expr, std::optional expectedType) +Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprTable* expr, std::optional expectedType) { TypeId ty = arena->addType(TableTypeVar{}); TableTypeVar* ttv = getMutable(ty); @@ -1344,16 +1353,14 @@ TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprTable* expr, } } - TypeId itemTy = check(scope, item.value, expectedValueType); - if (get(follow(itemTy))) - return ty; + TypeId itemTy = check(scope, item.value, expectedValueType).ty; if (item.key) { // Even though we don't need to use the type of the item's key if // it's a string constant, we still want to check it to populate // astTypes. - TypeId keyTy = check(scope, item.key); + TypeId keyTy = check(scope, item.key).ty; if (AstExprConstantString* key = item.key->as()) { @@ -1373,7 +1380,7 @@ TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprTable* expr, } } - return ty; + return Inference{ty}; } ConstraintGraphBuilder::FunctionSignature ConstraintGraphBuilder::checkFunctionSignature(const ScopePtr& parent, AstExprFunction* fn) @@ -1541,9 +1548,18 @@ TypeId ConstraintGraphBuilder::resolveType(const ScopePtr& scope, AstType* ty, b } } - std::optional alias = scope->lookupType(ref->name.value); + std::optional alias; - if (alias.has_value() || ref->prefix.has_value()) + if (ref->prefix.has_value()) + { + alias = scope->lookupImportedType(ref->prefix->value, ref->name.value); + } + else + { + alias = scope->lookupType(ref->name.value); + } + + if (alias.has_value()) { // If the alias is not generic, we don't need to set up a blocked // type and an instantiation constraint. @@ -1586,7 +1602,11 @@ TypeId ConstraintGraphBuilder::resolveType(const ScopePtr& scope, AstType* ty, b } else { - reportError(ty->location, UnknownSymbol{ref->name.value, UnknownSymbol::Context::Type}); + std::string typeName; + if (ref->prefix) + typeName = std::string(ref->prefix->value) + "."; + typeName += ref->name.value; + result = singletonTypes->errorRecoveryType(); } } @@ -1685,7 +1705,7 @@ TypeId ConstraintGraphBuilder::resolveType(const ScopePtr& scope, AstType* ty, b else if (auto tof = ty->as()) { // TODO: Recursion limit. - TypeId exprType = check(scope, tof->expr); + TypeId exprType = check(scope, tof->expr).ty; result = exprType; } else if (auto unionAnnotation = ty->as()) @@ -1694,7 +1714,7 @@ TypeId ConstraintGraphBuilder::resolveType(const ScopePtr& scope, AstType* ty, b for (AstType* part : unionAnnotation->types) { // TODO: Recursion limit. - parts.push_back(resolveType(scope, part)); + parts.push_back(resolveType(scope, part, topLevel)); } result = arena->addType(UnionTypeVar{parts}); @@ -1705,7 +1725,7 @@ TypeId ConstraintGraphBuilder::resolveType(const ScopePtr& scope, AstType* ty, b for (AstType* part : intersectionAnnotation->types) { // TODO: Recursion limit. - parts.push_back(resolveType(scope, part)); + parts.push_back(resolveType(scope, part, topLevel)); } result = arena->addType(IntersectionTypeVar{parts}); @@ -1795,10 +1815,7 @@ std::vector> ConstraintGraphBuilder::crea if (generic.defaultValue) defaultTy = resolveType(scope, generic.defaultValue); - result.push_back({generic.name.value, GenericTypeDefinition{ - genericTy, - defaultTy, - }}); + result.push_back({generic.name.value, GenericTypeDefinition{genericTy, defaultTy}}); } return result; @@ -1816,19 +1833,17 @@ std::vector> ConstraintGraphBuilder:: if (generic.defaultValue) defaultTy = resolveTypePack(scope, generic.defaultValue); - result.push_back({generic.name.value, GenericTypePackDefinition{ - genericTy, - defaultTy, - }}); + result.push_back({generic.name.value, GenericTypePackDefinition{genericTy, defaultTy}}); } return result; } -TypeId ConstraintGraphBuilder::flattenPack(const ScopePtr& scope, Location location, TypePackId tp) +Inference ConstraintGraphBuilder::flattenPack(const ScopePtr& scope, Location location, InferencePack pack) { + auto [tp] = pack; if (auto f = first(tp)) - return *f; + return Inference{*f}; TypeId typeResult = freshType(scope); TypePack onePack{{typeResult}, freshTypePack(scope)}; @@ -1836,7 +1851,7 @@ TypeId ConstraintGraphBuilder::flattenPack(const ScopePtr& scope, Location locat addConstraint(scope, location, PackSubtypeConstraint{tp, oneTypePack}); - return typeResult; + return Inference{typeResult}; } void ConstraintGraphBuilder::reportError(Location location, TypeErrorData err) diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index 60f4666aa..5e43be0f8 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -544,6 +544,7 @@ bool ConstraintSolver::tryDispatch(const UnaryConstraint& c, NotNullty.emplace(singletonTypes->numberType); return true; } @@ -552,13 +553,46 @@ bool ConstraintSolver::tryDispatch(const UnaryConstraint& c, NotNull(operandType) || get(operandType)) { asMutable(c.resultType)->ty.emplace(c.operandType); - return true; } - break; + else if (std::optional mm = findMetatableEntry(singletonTypes, errors, operandType, "__unm", constraint->location)) + { + const FunctionTypeVar* ftv = get(follow(*mm)); + + if (!ftv) + { + if (std::optional callMm = findMetatableEntry(singletonTypes, errors, follow(*mm), "__call", constraint->location)) + { + ftv = get(follow(*callMm)); + } + } + + if (!ftv) + { + asMutable(c.resultType)->ty.emplace(singletonTypes->errorRecoveryType()); + return true; + } + + TypePackId argsPack = arena->addTypePack({operandType}); + unify(ftv->argTypes, argsPack, constraint->scope); + + TypeId result = singletonTypes->errorRecoveryType(); + if (ftv) + { + result = first(ftv->retTypes).value_or(singletonTypes->errorRecoveryType()); + } + + asMutable(c.resultType)->ty.emplace(result); + } + else + { + asMutable(c.resultType)->ty.emplace(singletonTypes->errorRecoveryType()); + } + + return true; } } - LUAU_ASSERT(false); // TODO metatable handling + LUAU_ASSERT(false); return false; } @@ -862,6 +896,10 @@ bool ConstraintSolver::tryDispatch(const NameConstraint& c, NotNullname = c.name; else if (MetatableTypeVar* mtv = getMutable(target)) mtv->syntheticName = c.name; + else if (get(target) || get(target)) + { + // nothing (yet) + } else return block(c.namedType, constraint); diff --git a/Analysis/src/EmbeddedBuiltinDefinitions.cpp b/Analysis/src/EmbeddedBuiltinDefinitions.cpp index 0f04ace08..67abbff1f 100644 --- a/Analysis/src/EmbeddedBuiltinDefinitions.cpp +++ b/Analysis/src/EmbeddedBuiltinDefinitions.cpp @@ -26,34 +26,34 @@ declare bit32: { } declare math: { - frexp: (number) -> (number, number), - ldexp: (number, number) -> number, - fmod: (number, number) -> number, - modf: (number) -> (number, number), - pow: (number, number) -> number, - exp: (number) -> number, - - ceil: (number) -> number, - floor: (number) -> number, - abs: (number) -> number, - sqrt: (number) -> number, - - log: (number, number?) -> number, - log10: (number) -> number, - - rad: (number) -> number, - deg: (number) -> number, - - sin: (number) -> number, - cos: (number) -> number, - tan: (number) -> number, - sinh: (number) -> number, - cosh: (number) -> number, - tanh: (number) -> number, - atan: (number) -> number, - acos: (number) -> number, - asin: (number) -> number, - atan2: (number, number) -> number, + frexp: (n: number) -> (number, number), + ldexp: (s: number, e: number) -> number, + fmod: (x: number, y: number) -> number, + modf: (n: number) -> (number, number), + pow: (x: number, y: number) -> number, + exp: (n: number) -> number, + + ceil: (n: number) -> number, + floor: (n: number) -> number, + abs: (n: number) -> number, + sqrt: (n: number) -> number, + + log: (n: number, base: number?) -> number, + log10: (n: number) -> number, + + rad: (n: number) -> number, + deg: (n: number) -> number, + + sin: (n: number) -> number, + cos: (n: number) -> number, + tan: (n: number) -> number, + sinh: (n: number) -> number, + cosh: (n: number) -> number, + tanh: (n: number) -> number, + atan: (n: number) -> number, + acos: (n: number) -> number, + asin: (n: number) -> number, + atan2: (y: number, x: number) -> number, min: (number, ...number) -> number, max: (number, ...number) -> number, @@ -61,13 +61,13 @@ declare math: { pi: number, huge: number, - randomseed: (number) -> (), + randomseed: (seed: number) -> (), random: (number?, number?) -> number, - sign: (number) -> number, - clamp: (number, number, number) -> number, - noise: (number, number?, number?) -> number, - round: (number) -> number, + sign: (n: number) -> number, + clamp: (n: number, min: number, max: number) -> number, + noise: (x: number, y: number?, z: number?) -> number, + round: (n: number) -> number, } type DateTypeArg = { diff --git a/Analysis/src/Error.cpp b/Analysis/src/Error.cpp index e55530036..ed1a49cde 100644 --- a/Analysis/src/Error.cpp +++ b/Analysis/src/Error.cpp @@ -7,6 +7,8 @@ #include +LUAU_FASTFLAGVARIABLE(LuauIceExceptionInheritanceChange, false) + static std::string wrongNumberOfArgsString( size_t expectedCount, std::optional maximumCount, size_t actualCount, const char* argPrefix = nullptr, bool isVariadic = false) { @@ -460,6 +462,11 @@ struct ErrorConverter { return "Code is too complex to typecheck! Consider simplifying the code around this area"; } + + std::string operator()(const TypePackMismatch& e) const + { + return "Type pack '" + toString(e.givenTp) + "' could not be converted into '" + toString(e.wantedTp) + "'"; + } }; struct InvalidNameChecker @@ -718,6 +725,11 @@ bool TypesAreUnrelated::operator==(const TypesAreUnrelated& rhs) const return left == rhs.left && right == rhs.right; } +bool TypePackMismatch::operator==(const TypePackMismatch& rhs) const +{ + return *wantedTp == *rhs.wantedTp && *givenTp == *rhs.givenTp; +} + std::string toString(const TypeError& error) { return toString(error, TypeErrorToStringOptions{}); @@ -869,6 +881,11 @@ void copyError(T& e, TypeArena& destArena, CloneState cloneState) else if constexpr (std::is_same_v) { } + else if constexpr (std::is_same_v) + { + e.wantedTp = clone(e.wantedTp); + e.givenTp = clone(e.givenTp); + } else static_assert(always_false_v, "Non-exhaustive type switch"); } @@ -913,4 +930,30 @@ const char* InternalCompilerError::what() const throw() return this->message.data(); } +// TODO: Inline me when LuauIceExceptionInheritanceChange is deleted. +void throwRuntimeError(const std::string& message) +{ + if (FFlag::LuauIceExceptionInheritanceChange) + { + throw InternalCompilerError(message); + } + else + { + throw std::runtime_error(message); + } +} + +// TODO: Inline me when LuauIceExceptionInheritanceChange is deleted. +void throwRuntimeError(const std::string& message, const std::string& moduleName) +{ + if (FFlag::LuauIceExceptionInheritanceChange) + { + throw InternalCompilerError(message, moduleName); + } + else + { + throw std::runtime_error(message); + } +} + } // namespace Luau diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 8f2a3ebd6..39e6428d2 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -30,6 +30,8 @@ LUAU_FASTFLAGVARIABLE(LuauKnowsTheDataModel3, false) LUAU_FASTINTVARIABLE(LuauAutocompleteCheckTimeoutMs, 100) LUAU_FASTFLAGVARIABLE(DebugLuauDeferredConstraintResolution, false) LUAU_FASTFLAG(DebugLuauLogSolverToJson); +LUAU_FASTFLAGVARIABLE(LuauFixMarkDirtyReverseDeps, false) +LUAU_FASTFLAGVARIABLE(LuauPersistTypesAfterGeneratingDocSyms, false) namespace Luau { @@ -110,24 +112,57 @@ LoadDefinitionFileResult Frontend::loadDefinitionFile(std::string_view source, c CloneState cloneState; - for (const auto& [name, ty] : checkedModule->declaredGlobals) + if (FFlag::LuauPersistTypesAfterGeneratingDocSyms) { - TypeId globalTy = clone(ty, globalTypes, cloneState); - std::string documentationSymbol = packageName + "/global/" + name; - generateDocumentationSymbols(globalTy, documentationSymbol); - globalScope->bindings[typeChecker.globalNames.names->getOrAdd(name.c_str())] = {globalTy, Location(), false, {}, documentationSymbol}; + std::vector typesToPersist; + typesToPersist.reserve(checkedModule->declaredGlobals.size() + checkedModule->getModuleScope()->exportedTypeBindings.size()); - persist(globalTy); - } + for (const auto& [name, ty] : checkedModule->declaredGlobals) + { + TypeId globalTy = clone(ty, globalTypes, cloneState); + std::string documentationSymbol = packageName + "/global/" + name; + generateDocumentationSymbols(globalTy, documentationSymbol); + globalScope->bindings[typeChecker.globalNames.names->getOrAdd(name.c_str())] = {globalTy, Location(), false, {}, documentationSymbol}; + + typesToPersist.push_back(globalTy); + } + + for (const auto& [name, ty] : checkedModule->getModuleScope()->exportedTypeBindings) + { + TypeFun globalTy = clone(ty, globalTypes, cloneState); + std::string documentationSymbol = packageName + "/globaltype/" + name; + generateDocumentationSymbols(globalTy.type, documentationSymbol); + globalScope->exportedTypeBindings[name] = globalTy; - for (const auto& [name, ty] : checkedModule->getModuleScope()->exportedTypeBindings) + typesToPersist.push_back(globalTy.type); + } + + for (TypeId ty : typesToPersist) + { + persist(ty); + } + } + else { - TypeFun globalTy = clone(ty, globalTypes, cloneState); - std::string documentationSymbol = packageName + "/globaltype/" + name; - generateDocumentationSymbols(globalTy.type, documentationSymbol); - globalScope->exportedTypeBindings[name] = globalTy; + for (const auto& [name, ty] : checkedModule->declaredGlobals) + { + TypeId globalTy = clone(ty, globalTypes, cloneState); + std::string documentationSymbol = packageName + "/global/" + name; + generateDocumentationSymbols(globalTy, documentationSymbol); + globalScope->bindings[typeChecker.globalNames.names->getOrAdd(name.c_str())] = {globalTy, Location(), false, {}, documentationSymbol}; + + persist(globalTy); + } - persist(globalTy.type); + for (const auto& [name, ty] : checkedModule->getModuleScope()->exportedTypeBindings) + { + TypeFun globalTy = clone(ty, globalTypes, cloneState); + std::string documentationSymbol = packageName + "/globaltype/" + name; + generateDocumentationSymbols(globalTy.type, documentationSymbol); + globalScope->exportedTypeBindings[name] = globalTy; + + persist(globalTy.type); + } } return LoadDefinitionFileResult{true, parseResult, checkedModule}; @@ -159,24 +194,57 @@ LoadDefinitionFileResult loadDefinitionFile(TypeChecker& typeChecker, ScopePtr t CloneState cloneState; - for (const auto& [name, ty] : checkedModule->declaredGlobals) + if (FFlag::LuauPersistTypesAfterGeneratingDocSyms) { - TypeId globalTy = clone(ty, typeChecker.globalTypes, cloneState); - std::string documentationSymbol = packageName + "/global/" + name; - generateDocumentationSymbols(globalTy, documentationSymbol); - targetScope->bindings[typeChecker.globalNames.names->getOrAdd(name.c_str())] = {globalTy, Location(), false, {}, documentationSymbol}; + std::vector typesToPersist; + typesToPersist.reserve(checkedModule->declaredGlobals.size() + checkedModule->getModuleScope()->exportedTypeBindings.size()); - persist(globalTy); - } + for (const auto& [name, ty] : checkedModule->declaredGlobals) + { + TypeId globalTy = clone(ty, typeChecker.globalTypes, cloneState); + std::string documentationSymbol = packageName + "/global/" + name; + generateDocumentationSymbols(globalTy, documentationSymbol); + targetScope->bindings[typeChecker.globalNames.names->getOrAdd(name.c_str())] = {globalTy, Location(), false, {}, documentationSymbol}; + + typesToPersist.push_back(globalTy); + } + + for (const auto& [name, ty] : checkedModule->getModuleScope()->exportedTypeBindings) + { + TypeFun globalTy = clone(ty, typeChecker.globalTypes, cloneState); + std::string documentationSymbol = packageName + "/globaltype/" + name; + generateDocumentationSymbols(globalTy.type, documentationSymbol); + targetScope->exportedTypeBindings[name] = globalTy; - for (const auto& [name, ty] : checkedModule->getModuleScope()->exportedTypeBindings) + typesToPersist.push_back(globalTy.type); + } + + for (TypeId ty : typesToPersist) + { + persist(ty); + } + } + else { - TypeFun globalTy = clone(ty, typeChecker.globalTypes, cloneState); - std::string documentationSymbol = packageName + "/globaltype/" + name; - generateDocumentationSymbols(globalTy.type, documentationSymbol); - targetScope->exportedTypeBindings[name] = globalTy; + for (const auto& [name, ty] : checkedModule->declaredGlobals) + { + TypeId globalTy = clone(ty, typeChecker.globalTypes, cloneState); + std::string documentationSymbol = packageName + "/global/" + name; + generateDocumentationSymbols(globalTy, documentationSymbol); + targetScope->bindings[typeChecker.globalNames.names->getOrAdd(name.c_str())] = {globalTy, Location(), false, {}, documentationSymbol}; + + persist(globalTy); + } - persist(globalTy.type); + for (const auto& [name, ty] : checkedModule->getModuleScope()->exportedTypeBindings) + { + TypeFun globalTy = clone(ty, typeChecker.globalTypes, cloneState); + std::string documentationSymbol = packageName + "/globaltype/" + name; + generateDocumentationSymbols(globalTy.type, documentationSymbol); + targetScope->exportedTypeBindings[name] = globalTy; + + persist(globalTy.type); + } } return LoadDefinitionFileResult{true, parseResult, checkedModule}; @@ -425,13 +493,13 @@ CheckResult Frontend::check(const ModuleName& name, std::optionalsecond == nullptr) - throw std::runtime_error("Frontend::modules does not have data for " + name); + throwRuntimeError("Frontend::modules does not have data for " + name, name); } else { auto it2 = moduleResolver.modules.find(name); if (it2 == moduleResolver.modules.end() || it2->second == nullptr) - throw std::runtime_error("Frontend::modules does not have data for " + name); + throwRuntimeError("Frontend::modules does not have data for " + name, name); } return CheckResult{ @@ -538,7 +606,7 @@ CheckResult Frontend::check(const ModuleName& name, std::optional* marked sourceNode.dirtyModule = true; sourceNode.dirtyModuleForAutocomplete = true; - if (0 == reverseDeps.count(name)) - continue; + if (FFlag::LuauFixMarkDirtyReverseDeps) + { + if (0 == reverseDeps.count(next)) + continue; - sourceModules.erase(name); + sourceModules.erase(next); - const std::vector& dependents = reverseDeps[name]; - queue.insert(queue.end(), dependents.begin(), dependents.end()); + const std::vector& dependents = reverseDeps[next]; + queue.insert(queue.end(), dependents.begin(), dependents.end()); + } + else + { + if (0 == reverseDeps.count(name)) + continue; + + sourceModules.erase(name); + + const std::vector& dependents = reverseDeps[name]; + queue.insert(queue.end(), dependents.begin(), dependents.end()); + } } } @@ -993,11 +1074,11 @@ SourceModule Frontend::parse(const ModuleName& name, std::string_view src, const double timestamp = getTimestamp(); - auto parseResult = Luau::Parser::parse(src.data(), src.size(), *sourceModule.names, *sourceModule.allocator, parseOptions); + Luau::ParseResult parseResult = Luau::Parser::parse(src.data(), src.size(), *sourceModule.names, *sourceModule.allocator, parseOptions); stats.timeParse += getTimestamp() - timestamp; stats.files++; - stats.lines += std::count(src.begin(), src.end(), '\n') + (src.size() && src.back() != '\n'); + stats.lines += parseResult.lines; if (!parseResult.errors.empty()) sourceModule.parseErrors.insert(sourceModule.parseErrors.end(), parseResult.errors.begin(), parseResult.errors.end()); diff --git a/Analysis/src/IostreamHelpers.cpp b/Analysis/src/IostreamHelpers.cpp index e4fac4554..b47270a07 100644 --- a/Analysis/src/IostreamHelpers.cpp +++ b/Analysis/src/IostreamHelpers.cpp @@ -188,6 +188,8 @@ static void errorToString(std::ostream& stream, const T& err) stream << "TypesAreUnrelated { left = '" + toString(err.left) + "', right = '" + toString(err.right) + "' }"; else if constexpr (std::is_same_v) stream << "NormalizationTooComplex { }"; + else if constexpr (std::is_same_v) + stream << "TypePackMismatch { wanted = '" + toString(err.wantedTp) + "', given = '" + toString(err.givenTp) + "' }"; else static_assert(always_false_v, "Non-exhaustive type switch"); } diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp index cea159c36..5ef4b7e7c 100644 --- a/Analysis/src/Normalize.cpp +++ b/Analysis/src/Normalize.cpp @@ -7,6 +7,7 @@ #include "Luau/Clone.h" #include "Luau/Common.h" +#include "Luau/TypeVar.h" #include "Luau/Unifier.h" #include "Luau/VisitTypeVar.h" @@ -18,6 +19,7 @@ LUAU_FASTINTVARIABLE(LuauNormalizeIterationLimit, 1200); LUAU_FASTINTVARIABLE(LuauNormalizeCacheLimit, 100000); LUAU_FASTFLAGVARIABLE(LuauNormalizeCombineTableFix, false); LUAU_FASTFLAGVARIABLE(LuauTypeNormalization2, false); +LUAU_FASTFLAGVARIABLE(LuauNegatedStringSingletons, false); LUAU_FASTFLAG(LuauUnknownAndNeverType) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) LUAU_FASTFLAG(LuauOverloadedFunctionSubtypingPerf); @@ -107,12 +109,110 @@ bool TypeIds::operator==(const TypeIds& there) const return hash == there.hash && types == there.types; } +NormalizedStringType::NormalizedStringType(bool isCofinite, std::optional> singletons) + : isCofinite(isCofinite) + , singletons(std::move(singletons)) +{ + if (!FFlag::LuauNegatedStringSingletons) + LUAU_ASSERT(!isCofinite); +} + +void NormalizedStringType::resetToString() +{ + if (FFlag::LuauNegatedStringSingletons) + { + isCofinite = true; + singletons->clear(); + } + else + singletons.reset(); +} + +void NormalizedStringType::resetToNever() +{ + if (FFlag::LuauNegatedStringSingletons) + { + isCofinite = false; + singletons.emplace(); + } + else + { + if (singletons) + singletons->clear(); + else + singletons.emplace(); + } +} + +bool NormalizedStringType::isNever() const +{ + if (FFlag::LuauNegatedStringSingletons) + return !isCofinite && singletons->empty(); + else + return singletons && singletons->empty(); +} + +bool NormalizedStringType::isString() const +{ + if (FFlag::LuauNegatedStringSingletons) + return isCofinite && singletons->empty(); + else + return !singletons; +} + +bool NormalizedStringType::isUnion() const +{ + if (FFlag::LuauNegatedStringSingletons) + return !isCofinite; + else + return singletons.has_value(); +} + +bool NormalizedStringType::isIntersection() const +{ + if (FFlag::LuauNegatedStringSingletons) + return isCofinite; + else + return false; +} + +bool NormalizedStringType::includes(const std::string& str) const +{ + if (isString()) + return true; + else if (isUnion() && singletons->count(str)) + return true; + else if (isIntersection() && !singletons->count(str)) + return true; + else + return false; +} + +const NormalizedStringType NormalizedStringType::never{false, {{}}}; + +bool isSubtype(const NormalizedStringType& subStr, const NormalizedStringType& superStr) +{ + if (subStr.isUnion() && superStr.isUnion()) + { + for (auto [name, ty] : *subStr.singletons) + { + if (!superStr.singletons->count(name)) + return false; + } + } + else if (subStr.isString() && superStr.isUnion()) + return false; + + return true; +} + NormalizedType::NormalizedType(NotNull singletonTypes) : tops(singletonTypes->neverType) , booleans(singletonTypes->neverType) , errors(singletonTypes->neverType) , nils(singletonTypes->neverType) , numbers(singletonTypes->neverType) + , strings{NormalizedStringType::never} , threads(singletonTypes->neverType) { } @@ -120,7 +220,7 @@ NormalizedType::NormalizedType(NotNull singletonTypes) static bool isInhabited(const NormalizedType& norm) { return !get(norm.tops) || !get(norm.booleans) || !norm.classes.empty() || !get(norm.errors) || - !get(norm.nils) || !get(norm.numbers) || !norm.strings || !norm.strings->empty() || + !get(norm.nils) || !get(norm.numbers) || !norm.strings.isNever() || !get(norm.threads) || norm.functions || !norm.tables.empty() || !norm.tyvars.empty(); } @@ -183,10 +283,10 @@ static bool isNormalizedNumber(TypeId ty) static bool isNormalizedString(const NormalizedStringType& ty) { - if (!ty) + if (ty.isString()) return true; - for (auto& [str, ty] : *ty) + for (auto& [str, ty] : *ty.singletons) { if (const SingletonTypeVar* stv = get(ty)) { @@ -317,10 +417,7 @@ void Normalizer::clearNormal(NormalizedType& norm) norm.errors = singletonTypes->neverType; norm.nils = singletonTypes->neverType; norm.numbers = singletonTypes->neverType; - if (norm.strings) - norm.strings->clear(); - else - norm.strings.emplace(); + norm.strings.resetToNever(); norm.threads = singletonTypes->neverType; norm.tables.clear(); norm.functions = std::nullopt; @@ -495,10 +592,56 @@ void Normalizer::unionClasses(TypeIds& heres, const TypeIds& theres) void Normalizer::unionStrings(NormalizedStringType& here, const NormalizedStringType& there) { - if (!there) - here.reset(); - else if (here) - here->insert(there->begin(), there->end()); + if (FFlag::LuauNegatedStringSingletons) + { + if (there.isString()) + here.resetToString(); + else if (here.isUnion() && there.isUnion()) + here.singletons->insert(there.singletons->begin(), there.singletons->end()); + else if (here.isUnion() && there.isIntersection()) + { + here.isCofinite = true; + for (const auto& pair : *there.singletons) + { + auto it = here.singletons->find(pair.first); + if (it != end(*here.singletons)) + here.singletons->erase(it); + else + here.singletons->insert(pair); + } + } + else if (here.isIntersection() && there.isUnion()) + { + for (const auto& [name, ty] : *there.singletons) + here.singletons->erase(name); + } + else if (here.isIntersection() && there.isIntersection()) + { + auto iter = begin(*here.singletons); + auto endIter = end(*here.singletons); + + while (iter != endIter) + { + if (!there.singletons->count(iter->first)) + { + auto eraseIt = iter; + ++iter; + here.singletons->erase(eraseIt); + } + else + ++iter; + } + } + else + LUAU_ASSERT(!"Unreachable"); + } + else + { + if (there.isString()) + here.resetToString(); + else if (here.isUnion()) + here.singletons->insert(there.singletons->begin(), there.singletons->end()); + } } std::optional Normalizer::unionOfTypePacks(TypePackId here, TypePackId there) @@ -858,7 +1001,7 @@ bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, int ignor else if (ptv->type == PrimitiveTypeVar::Number) here.numbers = there; else if (ptv->type == PrimitiveTypeVar::String) - here.strings = std::nullopt; + here.strings.resetToString(); else if (ptv->type == PrimitiveTypeVar::Thread) here.threads = there; else @@ -870,12 +1013,33 @@ bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, int ignor here.booleans = unionOfBools(here.booleans, there); else if (const StringSingleton* sstv = get(stv)) { - if (here.strings) - here.strings->insert({sstv->value, there}); + if (FFlag::LuauNegatedStringSingletons) + { + if (here.strings.isCofinite) + { + auto it = here.strings.singletons->find(sstv->value); + if (it != here.strings.singletons->end()) + here.strings.singletons->erase(it); + } + else + here.strings.singletons->insert({sstv->value, there}); + } + else + { + if (here.strings.isUnion()) + here.strings.singletons->insert({sstv->value, there}); + } } else LUAU_ASSERT(!"Unreachable"); } + else if (const NegationTypeVar* ntv = get(there)) + { + const NormalizedType* thereNormal = normalize(ntv->ty); + NormalizedType tn = negateNormal(*thereNormal); + if (!unionNormals(here, tn)) + return false; + } else LUAU_ASSERT(!"Unreachable"); @@ -887,6 +1051,159 @@ bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, int ignor return true; } +// ------- Negations + +NormalizedType Normalizer::negateNormal(const NormalizedType& here) +{ + NormalizedType result{singletonTypes}; + if (!get(here.tops)) + { + // The negation of unknown or any is never. Easy. + return result; + } + + if (!get(here.errors)) + { + // Negating an error yields the same error. + result.errors = here.errors; + return result; + } + + if (get(here.booleans)) + result.booleans = singletonTypes->booleanType; + else if (get(here.booleans)) + result.booleans = singletonTypes->neverType; + else if (auto stv = get(here.booleans)) + { + auto boolean = get(stv); + LUAU_ASSERT(boolean != nullptr); + if (boolean->value) + result.booleans = singletonTypes->falseType; + else + result.booleans = singletonTypes->trueType; + } + + result.classes = negateAll(here.classes); + result.nils = get(here.nils) ? singletonTypes->nilType : singletonTypes->neverType; + result.numbers = get(here.numbers) ? singletonTypes->numberType : singletonTypes->neverType; + + result.strings = here.strings; + result.strings.isCofinite = !result.strings.isCofinite; + + result.threads = get(here.threads) ? singletonTypes->threadType : singletonTypes->neverType; + + // TODO: negating tables + // TODO: negating functions + // TODO: negating tyvars? + + return result; +} + +TypeIds Normalizer::negateAll(const TypeIds& theres) +{ + TypeIds tys; + for (TypeId there : theres) + tys.insert(negate(there)); + return tys; +} + +TypeId Normalizer::negate(TypeId there) +{ + there = follow(there); + if (get(there)) + return there; + else if (get(there)) + return singletonTypes->neverType; + else if (get(there)) + return singletonTypes->unknownType; + else if (auto ntv = get(there)) + return ntv->ty; // TODO: do we want to normalize this? + else if (auto utv = get(there)) + { + std::vector parts; + for (TypeId option : utv) + parts.push_back(negate(option)); + return arena->addType(IntersectionTypeVar{std::move(parts)}); + } + else if (auto itv = get(there)) + { + std::vector options; + for (TypeId part : itv) + options.push_back(negate(part)); + return arena->addType(UnionTypeVar{std::move(options)}); + } + else + return there; +} + +void Normalizer::subtractPrimitive(NormalizedType& here, TypeId ty) +{ + const PrimitiveTypeVar* ptv = get(follow(ty)); + LUAU_ASSERT(ptv); + switch (ptv->type) + { + case PrimitiveTypeVar::NilType: + here.nils = singletonTypes->neverType; + break; + case PrimitiveTypeVar::Boolean: + here.booleans = singletonTypes->neverType; + break; + case PrimitiveTypeVar::Number: + here.numbers = singletonTypes->neverType; + break; + case PrimitiveTypeVar::String: + here.strings.resetToNever(); + break; + case PrimitiveTypeVar::Thread: + here.threads = singletonTypes->neverType; + break; + } +} + +void Normalizer::subtractSingleton(NormalizedType& here, TypeId ty) +{ + LUAU_ASSERT(FFlag::LuauNegatedStringSingletons); + + const SingletonTypeVar* stv = get(ty); + LUAU_ASSERT(stv); + + if (const StringSingleton* ss = get(stv)) + { + if (here.strings.isCofinite) + here.strings.singletons->insert({ss->value, ty}); + else + { + auto it = here.strings.singletons->find(ss->value); + if (it != here.strings.singletons->end()) + here.strings.singletons->erase(it); + } + } + else if (const BooleanSingleton* bs = get(stv)) + { + if (get(here.booleans)) + { + // Nothing + } + else if (get(here.booleans)) + here.booleans = bs->value ? singletonTypes->falseType : singletonTypes->trueType; + else if (auto hereSingleton = get(here.booleans)) + { + const BooleanSingleton* hereBooleanSingleton = get(hereSingleton); + LUAU_ASSERT(hereBooleanSingleton); + + // Crucial subtlety: ty (and thus bs) are the value that is being + // negated out. We therefore reduce to never when the values match, + // rather than when they differ. + if (bs->value == hereBooleanSingleton->value) + here.booleans = singletonTypes->neverType; + } + else + LUAU_ASSERT(!"Unreachable"); + } + else + LUAU_ASSERT(!"Unreachable"); +} + // ------- Normalizing intersections TypeId Normalizer::intersectionOfTops(TypeId here, TypeId there) { @@ -971,17 +1288,17 @@ void Normalizer::intersectClassesWithClass(TypeIds& heres, TypeId there) void Normalizer::intersectStrings(NormalizedStringType& here, const NormalizedStringType& there) { - if (!there) + if (there.isString()) return; - if (!here) - here.emplace(); + if (here.isString()) + here.resetToNever(); - for (auto it = here->begin(); it != here->end();) + for (auto it = here.singletons->begin(); it != here.singletons->end();) { - if (there->count(it->first)) + if (there.singletons->count(it->first)) it++; else - it = here->erase(it); + it = here.singletons->erase(it); } } @@ -1646,12 +1963,35 @@ bool Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there) here.booleans = intersectionOfBools(booleans, there); else if (const StringSingleton* sstv = get(stv)) { - if (!strings || strings->count(sstv->value)) - here.strings->insert({sstv->value, there}); + if (strings.includes(sstv->value)) + here.strings.singletons->insert({sstv->value, there}); } else LUAU_ASSERT(!"Unreachable"); } + else if (const NegationTypeVar* ntv = get(there); FFlag::LuauNegatedStringSingletons && ntv) + { + TypeId t = follow(ntv->ty); + if (const PrimitiveTypeVar* ptv = get(t)) + subtractPrimitive(here, ntv->ty); + else if (const SingletonTypeVar* stv = get(t)) + subtractSingleton(here, follow(ntv->ty)); + else if (const UnionTypeVar* itv = get(t)) + { + for (TypeId part : itv->options) + { + const NormalizedType* normalPart = normalize(part); + NormalizedType negated = negateNormal(*normalPart); + intersectNormals(here, negated); + } + } + else + { + // TODO negated unions, intersections, table, and function. + // Report a TypeError for other types. + LUAU_ASSERT(!"Unimplemented"); + } + } else LUAU_ASSERT(!"Unreachable"); @@ -1691,11 +2031,25 @@ TypeId Normalizer::typeFromNormal(const NormalizedType& norm) result.push_back(norm.nils); if (!get(norm.numbers)) result.push_back(norm.numbers); - if (norm.strings) - for (auto& [_, ty] : *norm.strings) - result.push_back(ty); - else + if (norm.strings.isString()) result.push_back(singletonTypes->stringType); + else if (norm.strings.isUnion()) + { + for (auto& [_, ty] : *norm.strings.singletons) + result.push_back(ty); + } + else if (FFlag::LuauNegatedStringSingletons && norm.strings.isIntersection()) + { + std::vector parts; + parts.push_back(singletonTypes->stringType); + for (const auto& [name, ty] : *norm.strings.singletons) + parts.push_back(arena->addType(NegationTypeVar{ty})); + + result.push_back(arena->addType(IntersectionTypeVar{std::move(parts)})); + } + if (!get(norm.threads)) + result.push_back(singletonTypes->threadType); + result.insert(result.end(), norm.tables.begin(), norm.tables.end()); for (auto& [tyvar, intersect] : norm.tyvars) { diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index 5897ca211..44000647a 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -11,6 +11,7 @@ #include LUAU_FASTFLAG(LuauUnknownAndNeverType) +LUAU_FASTFLAG(LuauLvaluelessPath) LUAU_FASTFLAGVARIABLE(LuauSpecialTypesAsterisked, false) LUAU_FASTFLAGVARIABLE(LuauFixNameMaps, false) LUAU_FASTFLAGVARIABLE(LuauUnseeArrayTtv, false) @@ -434,7 +435,7 @@ struct TypeVarStringifier return; default: LUAU_ASSERT(!"Unknown primitive type"); - throw std::runtime_error("Unknown primitive type " + std::to_string(ptv.type)); + throwRuntimeError("Unknown primitive type " + std::to_string(ptv.type)); } } @@ -451,7 +452,7 @@ struct TypeVarStringifier else { LUAU_ASSERT(!"Unknown singleton type"); - throw std::runtime_error("Unknown singleton type"); + throwRuntimeError("Unknown singleton type"); } } @@ -1538,6 +1539,8 @@ std::string dump(const Constraint& c) std::string toString(const LValue& lvalue) { + LUAU_ASSERT(!FFlag::LuauLvaluelessPath); + std::string s; for (const LValue* current = &lvalue; current; current = baseof(*current)) { @@ -1552,4 +1555,37 @@ std::string toString(const LValue& lvalue) return s; } +std::optional getFunctionNameAsString(const AstExpr& expr) +{ + LUAU_ASSERT(FFlag::LuauLvaluelessPath); + + const AstExpr* curr = &expr; + std::string s; + + for (;;) + { + if (auto local = curr->as()) + return local->local->name.value + s; + + if (auto global = curr->as()) + return global->name.value + s; + + if (auto indexname = curr->as()) + { + curr = indexname->expr; + + s = "." + std::string(indexname->index.value) + s; + } + else if (auto group = curr->as()) + { + curr = group->expr; + } + else + { + return std::nullopt; + } + } + + return s; +} } // namespace Luau diff --git a/Analysis/src/TopoSortStatements.cpp b/Analysis/src/TopoSortStatements.cpp index 1ea2e27d0..052c10dea 100644 --- a/Analysis/src/TopoSortStatements.cpp +++ b/Analysis/src/TopoSortStatements.cpp @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/TopoSortStatements.h" +#include "Luau/Error.h" /* Decide the order in which we typecheck Lua statements in a block. * * Algorithm: @@ -149,7 +150,7 @@ Identifier mkName(const AstStatFunction& function) auto name = mkName(*function.name); LUAU_ASSERT(bool(name)); if (!name) - throw std::runtime_error("Internal error: Function declaration has a bad name"); + throwRuntimeError("Internal error: Function declaration has a bad name"); return *name; } @@ -255,7 +256,7 @@ struct ArcCollector : public AstVisitor { auto name = mkName(*node->name); if (!name) - throw std::runtime_error("Internal error: AstStatFunction has a bad name"); + throwRuntimeError("Internal error: AstStatFunction has a bad name"); add(*name); return true; diff --git a/Analysis/src/TypeAttach.cpp b/Analysis/src/TypeAttach.cpp index f2613cae2..179846d7c 100644 --- a/Analysis/src/TypeAttach.cpp +++ b/Analysis/src/TypeAttach.cpp @@ -347,7 +347,7 @@ class TypeRehydrationVisitor AstType* operator()(const NegationTypeVar& ntv) { // FIXME: do the same thing we do with ErrorTypeVar - throw std::runtime_error("Cannot convert NegationTypeVar into AstNode"); + throwRuntimeError("Cannot convert NegationTypeVar into AstNode"); } private: diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index bd220e9c0..a26731586 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -934,8 +934,62 @@ struct TypeChecker2 void visit(AstExprUnary* expr) { - // TODO! visit(expr->expr); + + NotNull scope = stack.back(); + TypeId operandType = lookupType(expr->expr); + + if (get(operandType) || get(operandType) || get(operandType)) + return; + + if (auto it = kUnaryOpMetamethods.find(expr->op); it != kUnaryOpMetamethods.end()) + { + std::optional mm = findMetatableEntry(singletonTypes, module->errors, operandType, it->second, expr->location); + if (mm) + { + if (const FunctionTypeVar* ftv = get(follow(*mm))) + { + TypePackId expectedArgs = module->internalTypes.addTypePack({operandType}); + reportErrors(tryUnify(scope, expr->location, ftv->argTypes, expectedArgs)); + + if (std::optional ret = first(ftv->retTypes)) + { + if (expr->op == AstExprUnary::Op::Len) + { + reportErrors(tryUnify(scope, expr->location, follow(*ret), singletonTypes->numberType)); + } + } + else + { + reportError(GenericError{format("Metamethod '%s' must return a value", it->second)}, expr->location); + } + } + + return; + } + } + + if (expr->op == AstExprUnary::Op::Len) + { + DenseHashSet seen{nullptr}; + int recursionCount = 0; + + if (!hasLength(operandType, seen, &recursionCount)) + { + reportError(NotATable{operandType}, expr->location); + } + } + else if (expr->op == AstExprUnary::Op::Minus) + { + reportErrors(tryUnify(scope, expr->location, operandType, singletonTypes->numberType)); + } + else if (expr->op == AstExprUnary::Op::Not) + { + } + else + { + LUAU_ASSERT(!"Unhandled unary operator"); + } } void visit(AstExprBinary* expr) @@ -1240,9 +1294,8 @@ struct TypeChecker2 Scope* scope = findInnermostScope(ty->location); LUAU_ASSERT(scope); - // TODO: Imported types - - std::optional alias = scope->lookupType(ty->name.value); + std::optional alias = + (ty->prefix) ? scope->lookupImportedType(ty->prefix->value, ty->name.value) : scope->lookupType(ty->name.value); if (alias.has_value()) { diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index d5c6b2c46..ccb1490a2 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -36,6 +36,7 @@ LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) LUAU_FASTFLAGVARIABLE(LuauReturnAnyInsteadOfICE, false) // Eventually removed as false. LUAU_FASTFLAGVARIABLE(DebugLuauSharedSelf, false) LUAU_FASTFLAGVARIABLE(LuauAnyifyModuleReturnGenerics, false) +LUAU_FASTFLAGVARIABLE(LuauLvaluelessPath, false) LUAU_FASTFLAGVARIABLE(LuauUnknownAndNeverType, false) LUAU_FASTFLAGVARIABLE(LuauBinaryNeedsExpectedTypesToo, false) LUAU_FASTFLAGVARIABLE(LuauFixVarargExprHeadType, false) @@ -43,15 +44,15 @@ LUAU_FASTFLAGVARIABLE(LuauNeverTypesAndOperatorsInference, false) LUAU_FASTFLAGVARIABLE(LuauReturnsFromCallsitesAreNotWidened, false) LUAU_FASTFLAG(LuauInstantiateInSubtyping) LUAU_FASTFLAGVARIABLE(LuauCompleteVisitor, false) -LUAU_FASTFLAGVARIABLE(LuauUnionOfTypesFollow, false) LUAU_FASTFLAGVARIABLE(LuauReportShadowedTypeAlias, false) LUAU_FASTFLAGVARIABLE(LuauBetterMessagingOnCountMismatch, false) +LUAU_FASTFLAGVARIABLE(LuauArgMismatchReportFunctionLocation, false) namespace Luau { - -const char* TimeLimitError::what() const throw() +const char* TimeLimitError_DEPRECATED::what() const throw() { + LUAU_ASSERT(!FFlag::LuauIceExceptionInheritanceChange); return "Typeinfer failed to complete in allotted time"; } @@ -264,6 +265,11 @@ ModulePtr TypeChecker::check(const SourceModule& module, Mode mode, std::optiona reportErrorCodeTooComplex(module.root->location); return std::move(currentModule); } + catch (const RecursionLimitException_DEPRECATED&) + { + reportErrorCodeTooComplex(module.root->location); + return std::move(currentModule); + } } ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mode mode, std::optional environmentScope) @@ -308,6 +314,10 @@ ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mo { currentModule->timeout = true; } + catch (const TimeLimitError_DEPRECATED&) + { + currentModule->timeout = true; + } if (FFlag::DebugLuauSharedSelf) { @@ -415,7 +425,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStat& program) ice("Unknown AstStat"); if (finishTime && TimeTrace::getClock() > *finishTime) - throw TimeLimitError(); + throwTimeLimitError(); } // This particular overload is for do...end. If you need to not increase the scope level, use checkBlock directly. @@ -442,6 +452,11 @@ void TypeChecker::checkBlock(const ScopePtr& scope, const AstStatBlock& block) reportErrorCodeTooComplex(block.location); return; } + catch (const RecursionLimitException_DEPRECATED&) + { + reportErrorCodeTooComplex(block.location); + return; + } } struct InplaceDemoter : TypeVarOnceVisitor @@ -2456,11 +2471,8 @@ std::string opToMetaTableEntry(const AstExprBinary::Op& op) TypeId TypeChecker::unionOfTypes(TypeId a, TypeId b, const ScopePtr& scope, const Location& location, bool unifyFreeTypes) { - if (FFlag::LuauUnionOfTypesFollow) - { - a = follow(a); - b = follow(b); - } + a = follow(a); + b = follow(b); if (unifyFreeTypes && (get(a) || get(b))) { @@ -3596,8 +3608,17 @@ void TypeChecker::checkArgumentList(const ScopePtr& scope, const AstExpr& funNam location = {state.location.begin, argLocations.back().end}; std::string namePath; - if (std::optional lValue = tryGetLValue(funName)) - namePath = toString(*lValue); + + if (FFlag::LuauLvaluelessPath) + { + if (std::optional path = getFunctionNameAsString(funName)) + namePath = *path; + } + else + { + if (std::optional lValue = tryGetLValue(funName)) + namePath = toString(*lValue); + } auto [minParams, optMaxParams] = getParameterExtents(&state.log, paramPack); state.reportError(TypeError{location, @@ -3706,11 +3727,28 @@ void TypeChecker::checkArgumentList(const ScopePtr& scope, const AstExpr& funNam bool isVariadic = tail && Luau::isVariadic(*tail); std::string namePath; - if (std::optional lValue = tryGetLValue(funName)) - namePath = toString(*lValue); - state.reportError(TypeError{ - state.location, CountMismatch{minParams, optMaxParams, paramIndex, CountMismatch::Context::Arg, isVariadic, namePath}}); + if (FFlag::LuauLvaluelessPath) + { + if (std::optional path = getFunctionNameAsString(funName)) + namePath = *path; + } + else + { + if (std::optional lValue = tryGetLValue(funName)) + namePath = toString(*lValue); + } + + if (FFlag::LuauArgMismatchReportFunctionLocation) + { + state.reportError(TypeError{ + funName.location, CountMismatch{minParams, optMaxParams, paramIndex, CountMismatch::Context::Arg, isVariadic, namePath}}); + } + else + { + state.reportError(TypeError{ + state.location, CountMismatch{minParams, optMaxParams, paramIndex, CountMismatch::Context::Arg, isVariadic, namePath}}); + } return; } ++paramIter; @@ -4647,6 +4685,19 @@ void TypeChecker::ice(const std::string& message) iceHandler->ice(message); } +// TODO: Inline me when LuauIceExceptionInheritanceChange is deleted. +void TypeChecker::throwTimeLimitError() +{ + if (FFlag::LuauIceExceptionInheritanceChange) + { + throw TimeLimitError(iceHandler->moduleName); + } + else + { + throw TimeLimitError_DEPRECATED(); + } +} + void TypeChecker::prepareErrorsForDisplay(ErrorVec& errVec) { // Remove errors with names that were generated by recovery from a parse error diff --git a/Analysis/src/TypePack.cpp b/Analysis/src/TypePack.cpp index 0fa4df605..0852f0535 100644 --- a/Analysis/src/TypePack.cpp +++ b/Analysis/src/TypePack.cpp @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/TypePack.h" +#include "Luau/Error.h" #include "Luau/TxnLog.h" #include @@ -234,7 +235,7 @@ TypePackId follow(TypePackId tp, std::function mapper) cycleTester = nullptr; if (tp == cycleTester) - throw std::runtime_error("Luau::follow detected a TypeVar cycle!!"); + throwRuntimeError("Luau::follow detected a TypeVar cycle!!"); } } } diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index 19d3d2669..94d633c78 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -61,7 +61,7 @@ TypeId follow(TypeId t, std::function mapper) { std::optional ty = utv->scope->lookup(utv->def); if (!ty) - throw std::runtime_error("UseTypeVar must map to another TypeId"); + throwRuntimeError("UseTypeVar must map to another TypeId"); return *ty; } else @@ -73,7 +73,7 @@ TypeId follow(TypeId t, std::function mapper) { TypeId res = ltv->thunk(); if (get(res)) - throw std::runtime_error("Lazy TypeVar cannot resolve to another Lazy TypeVar"); + throwRuntimeError("Lazy TypeVar cannot resolve to another Lazy TypeVar"); *asMutable(ty) = BoundTypeVar(res); } @@ -111,7 +111,7 @@ TypeId follow(TypeId t, std::function mapper) cycleTester = nullptr; if (t == cycleTester) - throw std::runtime_error("Luau::follow detected a TypeVar cycle!!"); + throwRuntimeError("Luau::follow detected a TypeVar cycle!!"); } } } @@ -946,7 +946,7 @@ void persist(TypeId ty) queue.push_back(mtv->table); queue.push_back(mtv->metatable); } - else if (get(t) || get(t) || get(t) || get(t) || get(t)) + else if (get(t) || get(t) || get(t) || get(t) || get(t) || get(t)) { } else diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index e23e6161c..b5eba9803 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -16,6 +16,7 @@ LUAU_FASTINT(LuauTypeInferTypePackLoopLimit); LUAU_FASTFLAG(LuauErrorRecoveryType); LUAU_FASTFLAG(LuauUnknownAndNeverType) +LUAU_FASTFLAGVARIABLE(LuauReportTypeMismatchForTypePackUnificationFailure, false) LUAU_FASTFLAGVARIABLE(LuauSubtypeNormalizer, false); LUAU_FASTFLAGVARIABLE(LuauScalarShapeSubtyping, false) LUAU_FASTFLAGVARIABLE(LuauInstantiateInSubtyping, false) @@ -273,7 +274,7 @@ TypeId Widen::clean(TypeId ty) TypePackId Widen::clean(TypePackId) { - throw std::runtime_error("Widen attempted to clean a dirty type pack?"); + throwRuntimeError("Widen attempted to clean a dirty type pack?"); } bool Widen::ignoreChildren(TypeId ty) @@ -551,6 +552,12 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool else if (log.getMutable(subTy)) tryUnifyWithClass(subTy, superTy, /*reversed*/ true); + else if (log.get(superTy)) + tryUnifyTypeWithNegation(subTy, superTy); + + else if (log.get(subTy)) + tryUnifyNegationWithType(subTy, superTy); + else reportError(TypeError{location, TypeMismatch{superTy, subTy}}); @@ -866,13 +873,7 @@ void Unifier::tryUnifyNormalizedTypes( if (!get(superNorm.numbers)) return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); - if (subNorm.strings && superNorm.strings) - { - for (auto [name, ty] : *subNorm.strings) - if (!superNorm.strings->count(name)) - return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); - } - else if (!subNorm.strings && superNorm.strings) + if (!isSubtype(subNorm.strings, superNorm.strings)) return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); if (get(subNorm.threads)) @@ -1392,7 +1393,10 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal } else { - reportError(TypeError{location, GenericError{"Failed to unify type packs"}}); + if (FFlag::LuauReportTypeMismatchForTypePackUnificationFailure) + reportError(TypeError{location, TypePackMismatch{subTp, superTp}}); + else + reportError(TypeError{location, GenericError{"Failed to unify type packs"}}); } } @@ -1441,7 +1445,10 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal bool shouldInstantiate = (numGenerics == 0 && subFunction->generics.size() > 0) || (numGenericPacks == 0 && subFunction->genericPacks.size() > 0); - if (FFlag::LuauInstantiateInSubtyping && variance == Covariant && shouldInstantiate) + // TODO: This is unsound when the context is invariant, but the annotation burden without allowing it and without + // read-only properties is too high for lua-apps. Read-only properties _should_ resolve their issue by allowing + // generic methods in tables to be marked read-only. + if (FFlag::LuauInstantiateInSubtyping && shouldInstantiate) { Instantiation instantiation{&log, types, scope->level, scope}; @@ -1576,6 +1583,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) { TableTypeVar* superTable = log.getMutable(superTy); TableTypeVar* subTable = log.getMutable(subTy); + TableTypeVar* instantiatedSubTable = subTable; if (!superTable || !subTable) ice("passed non-table types to unifyTables"); @@ -1593,6 +1601,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) if (instantiated.has_value()) { subTable = log.getMutable(*instantiated); + instantiatedSubTable = subTable; if (!subTable) ice("instantiation made a table type into a non-table type in tryUnifyTables"); @@ -1696,7 +1705,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) // txn log. TableTypeVar* newSuperTable = log.getMutable(superTy); TableTypeVar* newSubTable = log.getMutable(subTy); - if (superTable != newSuperTable || subTable != newSubTable) + if (superTable != newSuperTable || (subTable != newSubTable && subTable != instantiatedSubTable)) { if (errors.empty()) return tryUnifyTables(subTy, superTy, isIntersection); @@ -1758,7 +1767,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) // txn log. TableTypeVar* newSuperTable = log.getMutable(superTy); TableTypeVar* newSubTable = log.getMutable(subTy); - if (superTable != newSuperTable || subTable != newSubTable) + if (superTable != newSuperTable || (subTable != newSubTable && subTable != instantiatedSubTable)) { if (errors.empty()) return tryUnifyTables(subTy, superTy, isIntersection); @@ -2098,6 +2107,34 @@ void Unifier::tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed) return fail(); } +void Unifier::tryUnifyTypeWithNegation(TypeId subTy, TypeId superTy) +{ + const NegationTypeVar* ntv = get(superTy); + if (!ntv) + ice("tryUnifyTypeWithNegation superTy must be a negation type"); + + const NormalizedType* subNorm = normalizer->normalize(subTy); + const NormalizedType* superNorm = normalizer->normalize(superTy); + if (!subNorm || !superNorm) + return reportError(TypeError{location, UnificationTooComplex{}}); + + // T (subTy); + if (!ntv) + ice("tryUnifyNegationWithType subTy must be a negation type"); + + // TODO: ~T & queue, DenseHashSet& seenTypePacks, Unifier& state, TypePackId a, TypePackId anyTypePack) { while (true) diff --git a/Ast/include/Luau/ParseResult.h b/Ast/include/Luau/ParseResult.h index 17ce2e3bb..9c0a9527f 100644 --- a/Ast/include/Luau/ParseResult.h +++ b/Ast/include/Luau/ParseResult.h @@ -58,6 +58,8 @@ struct Comment struct ParseResult { AstStatBlock* root; + size_t lines = 0; + std::vector hotcomments; std::vector errors; diff --git a/Ast/include/Luau/Parser.h b/Ast/include/Luau/Parser.h index 848d71179..8b7eb73cf 100644 --- a/Ast/include/Luau/Parser.h +++ b/Ast/include/Luau/Parser.h @@ -302,8 +302,8 @@ class Parser AstStatError* reportStatError(const Location& location, const AstArray& expressions, const AstArray& statements, const char* format, ...) LUAU_PRINTF_ATTR(5, 6); AstExprError* reportExprError(const Location& location, const AstArray& expressions, const char* format, ...) LUAU_PRINTF_ATTR(4, 5); - AstTypeError* reportTypeAnnotationError(const Location& location, const AstArray& types, bool isMissing, const char* format, ...) - LUAU_PRINTF_ATTR(5, 6); + AstTypeError* reportTypeAnnotationError(const Location& location, const AstArray& types, const char* format, ...) + LUAU_PRINTF_ATTR(4, 5); // `parseErrorLocation` is associated with the parser error // `astErrorLocation` is associated with the AstTypeError created // It can be useful to have different error locations so that the parse error can include the next lexeme, while the AstTypeError can precisely diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 7150b18fc..85c5f5c60 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -23,7 +23,6 @@ LUAU_FASTFLAGVARIABLE(LuauErrorDoubleHexPrefix, false) LUAU_DYNAMIC_FASTFLAGVARIABLE(LuaReportParseIntegerIssues, false) LUAU_FASTFLAGVARIABLE(LuauInterpolatedStringBaseSupport, false) -LUAU_FASTFLAGVARIABLE(LuauTypeAnnotationLocationChange, false) LUAU_FASTFLAGVARIABLE(LuauCommaParenWarnings, false) @@ -164,15 +163,16 @@ ParseResult Parser::parse(const char* buffer, size_t bufferSize, AstNameTable& n try { AstStatBlock* root = p.parseChunk(); + size_t lines = p.lexer.current().location.end.line + (bufferSize > 0 && buffer[bufferSize - 1] != '\n'); - return ParseResult{root, std::move(p.hotcomments), std::move(p.parseErrors), std::move(p.commentLocations)}; + return ParseResult{root, lines, std::move(p.hotcomments), std::move(p.parseErrors), std::move(p.commentLocations)}; } catch (ParseError& err) { // when catching a fatal error, append it to the list of non-fatal errors and return p.parseErrors.push_back(err); - return ParseResult{nullptr, {}, p.parseErrors}; + return ParseResult{nullptr, 0, {}, p.parseErrors}; } } @@ -811,9 +811,8 @@ AstDeclaredClassProp Parser::parseDeclaredClassMethod() if (args.size() == 0 || args[0].name.name != "self" || args[0].annotation != nullptr) { - return AstDeclaredClassProp{fnName.name, - reportTypeAnnotationError(Location(start, end), {}, /*isMissing*/ false, "'self' must be present as the unannotated first parameter"), - true}; + return AstDeclaredClassProp{ + fnName.name, reportTypeAnnotationError(Location(start, end), {}, "'self' must be present as the unannotated first parameter"), true}; } // Skip the first index. @@ -824,8 +823,7 @@ AstDeclaredClassProp Parser::parseDeclaredClassMethod() if (args[i].annotation) vars.push_back(args[i].annotation); else - vars.push_back(reportTypeAnnotationError( - Location(start, end), {}, /*isMissing*/ false, "All declaration parameters aside from 'self' must be annotated")); + vars.push_back(reportTypeAnnotationError(Location(start, end), {}, "All declaration parameters aside from 'self' must be annotated")); } if (vararg && !varargAnnotation) @@ -1537,7 +1535,7 @@ AstType* Parser::parseTypeAnnotation(TempVector& parts, const Location if (isUnion && isIntersection) { - return reportTypeAnnotationError(Location(begin, parts.back()->location), copy(parts), /*isMissing*/ false, + return reportTypeAnnotationError(Location(begin, parts.back()->location), copy(parts), "Mixing union and intersection types is not allowed; consider wrapping in parentheses."); } @@ -1623,18 +1621,18 @@ AstTypeOrPack Parser::parseSimpleTypeAnnotation(bool allowPack) return {allocator.alloc(start, svalue)}; } else - return {reportTypeAnnotationError(start, {}, /*isMissing*/ false, "String literal contains malformed escape sequence")}; + return {reportTypeAnnotationError(start, {}, "String literal contains malformed escape sequence")}; } else if (lexer.current().type == Lexeme::InterpStringBegin || lexer.current().type == Lexeme::InterpStringSimple) { parseInterpString(); - return {reportTypeAnnotationError(start, {}, /*isMissing*/ false, "Interpolated string literals cannot be used as types")}; + return {reportTypeAnnotationError(start, {}, "Interpolated string literals cannot be used as types")}; } else if (lexer.current().type == Lexeme::BrokenString) { nextLexeme(); - return {reportTypeAnnotationError(start, {}, /*isMissing*/ false, "Malformed string")}; + return {reportTypeAnnotationError(start, {}, "Malformed string")}; } else if (lexer.current().type == Lexeme::Name) { @@ -1693,33 +1691,20 @@ AstTypeOrPack Parser::parseSimpleTypeAnnotation(bool allowPack) { nextLexeme(); - return {reportTypeAnnotationError(start, {}, /*isMissing*/ false, + return {reportTypeAnnotationError(start, {}, "Using 'function' as a type annotation is not supported, consider replacing with a function type annotation e.g. '(...any) -> " "...any'"), {}}; } else { - if (FFlag::LuauTypeAnnotationLocationChange) - { - // For a missing type annotation, capture 'space' between last token and the next one - Location astErrorlocation(lexer.previousLocation().end, start.begin); - // The parse error includes the next lexeme to make it easier to display where the error is (e.g. in an IDE or a CLI error message). - // Including the current lexeme also makes the parse error consistent with other parse errors returned by Luau. - Location parseErrorLocation(lexer.previousLocation().end, start.end); - return { - reportMissingTypeAnnotationError(parseErrorLocation, astErrorlocation, "Expected type, got %s", lexer.current().toString().c_str()), - {}}; - } - else - { - Location location = lexer.current().location; - - // For a missing type annotation, capture 'space' between last token and the next one - location = Location(lexer.previousLocation().end, lexer.current().location.begin); - - return {reportTypeAnnotationError(location, {}, /*isMissing*/ true, "Expected type, got %s", lexer.current().toString().c_str()), {}}; - } + // For a missing type annotation, capture 'space' between last token and the next one + Location astErrorlocation(lexer.previousLocation().end, start.begin); + // The parse error includes the next lexeme to make it easier to display where the error is (e.g. in an IDE or a CLI error message). + // Including the current lexeme also makes the parse error consistent with other parse errors returned by Luau. + Location parseErrorLocation(lexer.previousLocation().end, start.end); + return { + reportMissingTypeAnnotationError(parseErrorLocation, astErrorlocation, "Expected type, got %s", lexer.current().toString().c_str()), {}}; } } @@ -3033,27 +3018,18 @@ AstExprError* Parser::reportExprError(const Location& location, const AstArray(location, expressions, unsigned(parseErrors.size() - 1)); } -AstTypeError* Parser::reportTypeAnnotationError(const Location& location, const AstArray& types, bool isMissing, const char* format, ...) +AstTypeError* Parser::reportTypeAnnotationError(const Location& location, const AstArray& types, const char* format, ...) { - if (FFlag::LuauTypeAnnotationLocationChange) - { - // Missing type annotations should be using `reportMissingTypeAnnotationError` when LuauTypeAnnotationLocationChange is enabled - // Note: `isMissing` can be removed once FFlag::LuauTypeAnnotationLocationChange is removed since it will always be true. - LUAU_ASSERT(!isMissing); - } - va_list args; va_start(args, format); report(location, format, args); va_end(args); - return allocator.alloc(location, types, isMissing, unsigned(parseErrors.size() - 1)); + return allocator.alloc(location, types, false, unsigned(parseErrors.size() - 1)); } AstTypeError* Parser::reportMissingTypeAnnotationError(const Location& parseErrorLocation, const Location& astErrorLocation, const char* format, ...) { - LUAU_ASSERT(FFlag::LuauTypeAnnotationLocationChange); - va_list args; va_start(args, format); report(parseErrorLocation, format, args); diff --git a/CLI/Repl.cpp b/CLI/Repl.cpp index a9dd8970a..87e19db8b 100644 --- a/CLI/Repl.cpp +++ b/CLI/Repl.cpp @@ -16,7 +16,6 @@ #include "isocline.h" -#include #include #ifdef _WIN32 @@ -688,11 +687,11 @@ static std::string getCodegenAssembly(const char* name, const std::string& bytec return ""; } -static void annotateInstruction(void* context, std::string& text, int fid, int instid) +static void annotateInstruction(void* context, std::string& text, int fid, int instpos) { Luau::BytecodeBuilder& bcb = *(Luau::BytecodeBuilder*)context; - bcb.annotateInstruction(text, fid, instid); + bcb.annotateInstruction(text, fid, instpos); } struct CompileStats @@ -711,7 +710,8 @@ static bool compileFile(const char* name, CompileFormat format, CompileStats& st return false; } - stats.lines += std::count(source->begin(), source->end(), '\n'); + // NOTE: Normally, you should use Luau::compile or luau_compile (see lua_require as an example) + // This function is much more complicated because it supports many output human-readable formats through internal interfaces try { @@ -736,7 +736,16 @@ static bool compileFile(const char* name, CompileFormat format, CompileStats& st bcb.setDumpSource(*source); } - Luau::compileOrThrow(bcb, *source, copts()); + Luau::Allocator allocator; + Luau::AstNameTable names(allocator); + Luau::ParseResult result = Luau::Parser::parse(source->c_str(), source->size(), names, allocator); + + if (!result.errors.empty()) + throw Luau::ParseErrors(result.errors); + + stats.lines += result.lines; + + Luau::compileOrThrow(bcb, result, names, copts()); stats.bytecode += bcb.getBytecode().size(); switch (format) diff --git a/CMakeLists.txt b/CMakeLists.txt index 0016160ae..05d701ee4 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -143,6 +143,11 @@ if (MSVC AND MSVC_VERSION GREATER_EQUAL 1924) set_source_files_properties(VM/src/lvmexecute.cpp PROPERTIES COMPILE_FLAGS /d2ssa-pre-) endif() +if (NOT MSVC) + # disable support for math_errno which allows compilers to lower sqrt() into a single CPU instruction + target_compile_options(Luau.VM PRIVATE -fno-math-errno) +endif() + if(MSVC AND LUAU_BUILD_CLI) # the default stack size that MSVC linker uses is 1 MB; we need more stack space in Debug because stack frames are larger set_target_properties(Luau.Analyze.CLI PROPERTIES LINK_FLAGS_DEBUG /STACK:2097152) diff --git a/CodeGen/include/Luau/CodeGen.h b/CodeGen/include/Luau/CodeGen.h index e8b30195d..cef9ec7cb 100644 --- a/CodeGen/include/Luau/CodeGen.h +++ b/CodeGen/include/Luau/CodeGen.h @@ -17,7 +17,7 @@ void create(lua_State* L); // Builds target function and all inner functions void compile(lua_State* L, int idx); -using annotatorFn = void (*)(void* context, std::string& result, int fid, int instid); +using annotatorFn = void (*)(void* context, std::string& result, int fid, int instpos); struct AssemblyOptions { diff --git a/CodeGen/src/CodeGen.cpp b/CodeGen/src/CodeGen.cpp index f78ead596..78645766f 100644 --- a/CodeGen/src/CodeGen.cpp +++ b/CodeGen/src/CodeGen.cpp @@ -34,7 +34,350 @@ namespace CodeGen constexpr uint32_t kFunctionAlignment = 32; -static NativeProto* assembleFunction(AssemblyBuilderX64& build, NativeState& data, Proto* proto, AssemblyOptions options) +struct InstructionOutline +{ + int pcpos; + int length; +}; + +static void assembleHelpers(AssemblyBuilderX64& build, ModuleHelpers& helpers) +{ + if (build.logText) + build.logAppend("; exitContinueVm\n"); + helpers.exitContinueVm = build.setLabel(); + emitExit(build, /* continueInVm */ true); + + if (build.logText) + build.logAppend("; exitNoContinueVm\n"); + helpers.exitNoContinueVm = build.setLabel(); + emitExit(build, /* continueInVm */ false); +} + +static int emitInst( + AssemblyBuilderX64& build, NativeState& data, ModuleHelpers& helpers, Proto* proto, LuauOpcode op, const Instruction* pc, int i, Label* labelarr, Label& fallback) +{ + int skip = 0; + + switch (op) + { + case LOP_NOP: + break; + case LOP_LOADNIL: + emitInstLoadNil(build, pc); + break; + case LOP_LOADB: + emitInstLoadB(build, pc, i, labelarr); + break; + case LOP_LOADN: + emitInstLoadN(build, pc); + break; + case LOP_LOADK: + emitInstLoadK(build, pc); + break; + case LOP_LOADKX: + emitInstLoadKX(build, pc); + break; + case LOP_MOVE: + emitInstMove(build, pc); + break; + case LOP_GETGLOBAL: + emitInstGetGlobal(build, pc, i, fallback); + break; + case LOP_SETGLOBAL: + emitInstSetGlobal(build, pc, i, labelarr, fallback); + break; + case LOP_RETURN: + emitInstReturn(build, helpers, pc, i, labelarr); + break; + case LOP_GETTABLE: + emitInstGetTable(build, pc, i, fallback); + break; + case LOP_SETTABLE: + emitInstSetTable(build, pc, i, labelarr, fallback); + break; + case LOP_GETTABLEKS: + emitInstGetTableKS(build, pc, i, fallback); + break; + case LOP_SETTABLEKS: + emitInstSetTableKS(build, pc, i, labelarr, fallback); + break; + case LOP_GETTABLEN: + emitInstGetTableN(build, pc, i, fallback); + break; + case LOP_SETTABLEN: + emitInstSetTableN(build, pc, i, labelarr, fallback); + break; + case LOP_JUMP: + emitInstJump(build, pc, i, labelarr); + break; + case LOP_JUMPBACK: + emitInstJumpBack(build, pc, i, labelarr); + break; + case LOP_JUMPIF: + emitInstJumpIf(build, pc, i, labelarr, /* not_ */ false); + break; + case LOP_JUMPIFNOT: + emitInstJumpIf(build, pc, i, labelarr, /* not_ */ true); + break; + case LOP_JUMPIFEQ: + emitInstJumpIfEq(build, pc, i, labelarr, /* not_ */ false, fallback); + break; + case LOP_JUMPIFLE: + emitInstJumpIfCond(build, pc, i, labelarr, Condition::LessEqual, fallback); + break; + case LOP_JUMPIFLT: + emitInstJumpIfCond(build, pc, i, labelarr, Condition::Less, fallback); + break; + case LOP_JUMPIFNOTEQ: + emitInstJumpIfEq(build, pc, i, labelarr, /* not_ */ true, fallback); + break; + case LOP_JUMPIFNOTLE: + emitInstJumpIfCond(build, pc, i, labelarr, Condition::NotLessEqual, fallback); + break; + case LOP_JUMPIFNOTLT: + emitInstJumpIfCond(build, pc, i, labelarr, Condition::NotLess, fallback); + break; + case LOP_JUMPX: + emitInstJumpX(build, pc, i, labelarr); + break; + case LOP_JUMPXEQKNIL: + emitInstJumpxEqNil(build, pc, i, labelarr); + break; + case LOP_JUMPXEQKB: + emitInstJumpxEqB(build, pc, i, labelarr); + break; + case LOP_JUMPXEQKN: + emitInstJumpxEqN(build, pc, proto->k, i, labelarr); + break; + case LOP_JUMPXEQKS: + emitInstJumpxEqS(build, pc, i, labelarr); + break; + case LOP_ADD: + emitInstBinary(build, pc, i, TM_ADD, fallback); + break; + case LOP_SUB: + emitInstBinary(build, pc, i, TM_SUB, fallback); + break; + case LOP_MUL: + emitInstBinary(build, pc, i, TM_MUL, fallback); + break; + case LOP_DIV: + emitInstBinary(build, pc, i, TM_DIV, fallback); + break; + case LOP_MOD: + emitInstBinary(build, pc, i, TM_MOD, fallback); + break; + case LOP_POW: + emitInstBinary(build, pc, i, TM_POW, fallback); + break; + case LOP_ADDK: + emitInstBinaryK(build, pc, i, TM_ADD, fallback); + break; + case LOP_SUBK: + emitInstBinaryK(build, pc, i, TM_SUB, fallback); + break; + case LOP_MULK: + emitInstBinaryK(build, pc, i, TM_MUL, fallback); + break; + case LOP_DIVK: + emitInstBinaryK(build, pc, i, TM_DIV, fallback); + break; + case LOP_MODK: + emitInstBinaryK(build, pc, i, TM_MOD, fallback); + break; + case LOP_POWK: + emitInstPowK(build, pc, proto->k, i, fallback); + break; + case LOP_NOT: + emitInstNot(build, pc); + break; + case LOP_MINUS: + emitInstMinus(build, pc, i, fallback); + break; + case LOP_LENGTH: + emitInstLength(build, pc, i, fallback); + break; + case LOP_NEWTABLE: + emitInstNewTable(build, pc, i, labelarr); + break; + case LOP_DUPTABLE: + emitInstDupTable(build, pc, i, labelarr); + break; + case LOP_SETLIST: + emitInstSetList(build, pc, i, labelarr); + break; + case LOP_GETUPVAL: + emitInstGetUpval(build, pc, i); + break; + case LOP_SETUPVAL: + emitInstSetUpval(build, pc, i, labelarr); + break; + case LOP_CLOSEUPVALS: + emitInstCloseUpvals(build, pc, i, labelarr); + break; + case LOP_FASTCALL: + skip = emitInstFastCall(build, pc, i, labelarr); + break; + case LOP_FASTCALL1: + skip = emitInstFastCall1(build, pc, i, labelarr); + break; + case LOP_FASTCALL2: + skip = emitInstFastCall2(build, pc, i, labelarr); + break; + case LOP_FASTCALL2K: + skip = emitInstFastCall2K(build, pc, i, labelarr); + break; + case LOP_FORNPREP: + emitInstForNPrep(build, pc, i, labelarr); + break; + case LOP_FORNLOOP: + emitInstForNLoop(build, pc, i, labelarr); + break; + case LOP_FORGLOOP: + emitinstForGLoop(build, pc, i, labelarr, fallback); + break; + case LOP_FORGPREP_NEXT: + emitInstForGPrepNext(build, pc, i, labelarr, fallback); + break; + case LOP_FORGPREP_INEXT: + emitInstForGPrepInext(build, pc, i, labelarr, fallback); + break; + case LOP_AND: + emitInstAnd(build, pc); + break; + case LOP_ANDK: + emitInstAndK(build, pc); + break; + case LOP_OR: + emitInstOr(build, pc); + break; + case LOP_ORK: + emitInstOrK(build, pc); + break; + case LOP_GETIMPORT: + emitInstGetImport(build, pc, fallback); + break; + case LOP_CONCAT: + emitInstConcat(build, pc, i, labelarr); + break; + default: + emitFallback(build, data, op, i); + break; + } + + return skip; +} + +static void emitInstFallback(AssemblyBuilderX64& build, NativeState& data, LuauOpcode op, const Instruction* pc, int i, Label* labelarr) +{ + switch (op) + { + case LOP_GETIMPORT: + emitInstGetImportFallback(build, pc, i); + break; + case LOP_GETTABLE: + emitInstGetTableFallback(build, pc, i); + break; + case LOP_SETTABLE: + emitInstSetTableFallback(build, pc, i); + break; + case LOP_GETTABLEN: + emitInstGetTableNFallback(build, pc, i); + break; + case LOP_SETTABLEN: + emitInstSetTableNFallback(build, pc, i); + break; + case LOP_JUMPIFEQ: + emitInstJumpIfEqFallback(build, pc, i, labelarr, /* not_ */ false); + break; + case LOP_JUMPIFLE: + emitInstJumpIfCondFallback(build, pc, i, labelarr, Condition::LessEqual); + break; + case LOP_JUMPIFLT: + emitInstJumpIfCondFallback(build, pc, i, labelarr, Condition::Less); + break; + case LOP_JUMPIFNOTEQ: + emitInstJumpIfEqFallback(build, pc, i, labelarr, /* not_ */ true); + break; + case LOP_JUMPIFNOTLE: + emitInstJumpIfCondFallback(build, pc, i, labelarr, Condition::NotLessEqual); + break; + case LOP_JUMPIFNOTLT: + emitInstJumpIfCondFallback(build, pc, i, labelarr, Condition::NotLess); + break; + case LOP_ADD: + emitInstBinaryFallback(build, pc, i, TM_ADD); + break; + case LOP_SUB: + emitInstBinaryFallback(build, pc, i, TM_SUB); + break; + case LOP_MUL: + emitInstBinaryFallback(build, pc, i, TM_MUL); + break; + case LOP_DIV: + emitInstBinaryFallback(build, pc, i, TM_DIV); + break; + case LOP_MOD: + emitInstBinaryFallback(build, pc, i, TM_MOD); + break; + case LOP_POW: + emitInstBinaryFallback(build, pc, i, TM_POW); + break; + case LOP_ADDK: + emitInstBinaryKFallback(build, pc, i, TM_ADD); + break; + case LOP_SUBK: + emitInstBinaryKFallback(build, pc, i, TM_SUB); + break; + case LOP_MULK: + emitInstBinaryKFallback(build, pc, i, TM_MUL); + break; + case LOP_DIVK: + emitInstBinaryKFallback(build, pc, i, TM_DIV); + break; + case LOP_MODK: + emitInstBinaryKFallback(build, pc, i, TM_MOD); + break; + case LOP_POWK: + emitInstBinaryKFallback(build, pc, i, TM_POW); + break; + case LOP_MINUS: + emitInstMinusFallback(build, pc, i); + break; + case LOP_LENGTH: + emitInstLengthFallback(build, pc, i); + break; + case LOP_FORGLOOP: + emitinstForGLoopFallback(build, pc, i, labelarr); + break; + case LOP_FORGPREP_NEXT: + case LOP_FORGPREP_INEXT: + emitInstForGPrepXnextFallback(build, pc, i, labelarr); + break; + case LOP_GETGLOBAL: + // TODO: luaV_gettable + cachedslot update instead of full fallback + emitFallback(build, data, op, i); + break; + case LOP_SETGLOBAL: + // TODO: luaV_settable + cachedslot update instead of full fallback + emitFallback(build, data, op, i); + break; + case LOP_GETTABLEKS: + // Full fallback required for LOP_GETTABLEKS because 'luaV_gettable' doesn't handle builtin vector field access + // It is also required to perform cached slot update + // TODO: extra fast-paths could be lowered before the full fallback + emitFallback(build, data, op, i); + break; + case LOP_SETTABLEKS: + // TODO: luaV_settable + cachedslot update instead of full fallback + emitFallback(build, data, op, i); + break; + default: + LUAU_ASSERT(!"Expected fallback for instruction"); + } +} + +static NativeProto* assembleFunction(AssemblyBuilderX64& build, NativeState& data, ModuleHelpers& helpers, Proto* proto, AssemblyOptions options) { NativeProto* result = new NativeProto(); @@ -59,222 +402,65 @@ static NativeProto* assembleFunction(AssemblyBuilderX64& build, NativeState& dat std::vector(a) -> ()' could not be converted into 't1 where t1 = ({- Clone: t1 -}) -> (a...)'; different number of generic type parameters)", - toString(result.errors[0])); - } - else - { - LUAU_REQUIRE_NO_ERRORS(result); - } + LUAU_REQUIRE_NO_ERRORS(result); } TEST_CASE_FIXTURE(BuiltinsFixture, "custom_require_global") diff --git a/tests/TypeInfer.negations.test.cpp b/tests/TypeInfer.negations.test.cpp new file mode 100644 index 000000000..1035eda49 --- /dev/null +++ b/tests/TypeInfer.negations.test.cpp @@ -0,0 +1,52 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Fixture.h" + +#include "doctest.h" +#include "Luau/Common.h" +#include "ScopedFlags.h" + +using namespace Luau; + +namespace +{ +struct NegationFixture : Fixture +{ + TypeArena arena; + ScopedFastFlag sff[2] { + {"LuauNegatedStringSingletons", true}, + {"LuauSubtypeNormalizer", true}, + }; + + NegationFixture() + { + registerNotType(*this, arena); + } +}; +} + +TEST_SUITE_BEGIN("Negations"); + +TEST_CASE_FIXTURE(NegationFixture, "negated_string_is_a_subtype_of_string") +{ + CheckResult result = check(R"( + function foo(arg: string) end + local a: string & Not<"Hello"> + foo(a) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(NegationFixture, "string_is_not_a_subtype_of_negated_string") +{ + CheckResult result = check(R"( + function foo(arg: string & Not<"hello">) end + local a: string + foo(a) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_SUITE_END(); diff --git a/tests/TypeInfer.operators.test.cpp b/tests/TypeInfer.operators.test.cpp index e572c87ac..b2516f6d8 100644 --- a/tests/TypeInfer.operators.test.cpp +++ b/tests/TypeInfer.operators.test.cpp @@ -434,16 +434,17 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "typecheck_unary_minus") { CheckResult result = check(R"( --!strict - local foo = { - value = 10 - } + local foo local mt = {} - setmetatable(foo, mt) mt.__unm = function(val: typeof(foo)): string - return val.value .. "test" + return tostring(val.value) .. "test" end + foo = setmetatable({ + value = 10 + }, mt) + local a = -foo local b = 1+-1 @@ -459,25 +460,32 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "typecheck_unary_minus") CHECK_EQ("string", toString(requireType("a"))); CHECK_EQ("number", toString(requireType("b"))); - GenericError* gen = get(result.errors[0]); - REQUIRE_EQ(gen->message, "Unary operator '-' not supported by type 'bar'"); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK(toString(result.errors[0]) == "Type '{ value: number }' could not be converted into 'number'"); + } + else + { + GenericError* gen = get(result.errors[0]); + REQUIRE(gen); + REQUIRE_EQ(gen->message, "Unary operator '-' not supported by type 'bar'"); + } } TEST_CASE_FIXTURE(BuiltinsFixture, "typecheck_unary_minus_error") { CheckResult result = check(R"( --!strict - local foo = { - value = 10 - } - local mt = {} - setmetatable(foo, mt) mt.__unm = function(val: boolean): string return "test" end + local foo = setmetatable({ + value = 10 + }, mt) + local a = -foo )"); @@ -494,16 +502,16 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "typecheck_unary_len_error") { CheckResult result = check(R"( --!strict - local foo = { - value = 10 - } local mt = {} - setmetatable(foo, mt) - mt.__len = function(val: any): string + mt.__len = function(val): string return "test" end + local foo = setmetatable({ + value = 10, + }, mt) + local a = #foo )"); diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index ccc4d775a..8e04c165c 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -624,15 +624,18 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "table_insert_with_a_singleton_argument") CHECK_EQ("{string | string}", toString(requireType("t"))); } -struct NormalizeFixture : Fixture +namespace +{ +struct IsSubtypeFixture : Fixture { bool isSubtype(TypeId a, TypeId b) { return ::Luau::isSubtype(a, b, NotNull{getMainModule()->getModuleScope().get()}, singletonTypes, ice); } }; +} -TEST_CASE_FIXTURE(NormalizeFixture, "intersection_of_functions_of_different_arities") +TEST_CASE_FIXTURE(IsSubtypeFixture, "intersection_of_functions_of_different_arities") { check(R"( type A = (any) -> () @@ -653,7 +656,7 @@ TEST_CASE_FIXTURE(NormalizeFixture, "intersection_of_functions_of_different_arit CHECK("((any) -> ()) & ((any, any) -> ())" == toString(requireType("t"))); } -TEST_CASE_FIXTURE(NormalizeFixture, "functions_with_mismatching_arity") +TEST_CASE_FIXTURE(IsSubtypeFixture, "functions_with_mismatching_arity") { check(R"( local a: (number) -> () @@ -676,7 +679,7 @@ TEST_CASE_FIXTURE(NormalizeFixture, "functions_with_mismatching_arity") CHECK(!isSubtype(b, c)); } -TEST_CASE_FIXTURE(NormalizeFixture, "functions_with_mismatching_arity_but_optional_parameters") +TEST_CASE_FIXTURE(IsSubtypeFixture, "functions_with_mismatching_arity_but_optional_parameters") { /* * (T0..TN) <: (T0..TN, A?) @@ -736,7 +739,7 @@ TEST_CASE_FIXTURE(NormalizeFixture, "functions_with_mismatching_arity_but_option // CHECK(!isSubtype(b, c)); } -TEST_CASE_FIXTURE(NormalizeFixture, "functions_with_mismatching_arity_but_any_is_an_optional_param") +TEST_CASE_FIXTURE(IsSubtypeFixture, "functions_with_mismatching_arity_but_any_is_an_optional_param") { check(R"( local a: (number?) -> () diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index 2a208cce0..7de412ffd 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -1,5 +1,8 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/BuiltinDefinitions.h" +#include "Luau/Common.h" +#include "Luau/Frontend.h" +#include "Luau/ToString.h" #include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" @@ -14,6 +17,7 @@ using namespace Luau; LUAU_FASTFLAG(LuauLowerBoundsCalculation); LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); LUAU_FASTFLAG(LuauInstantiateInSubtyping) +LUAU_FASTFLAG(LuauSpecialTypesAsterisked) TEST_SUITE_BEGIN("TableTests"); @@ -1957,7 +1961,11 @@ TEST_CASE_FIXTURE(Fixture, "invariant_table_properties_means_instantiating_table local c : string = t.m("hi") )"); - LUAU_REQUIRE_ERRORS(result); + // TODO: test behavior is wrong with LuauInstantiateInSubtyping until we can re-enable the covariant requirement for instantiation in subtyping + if (FFlag::LuauInstantiateInSubtyping) + LUAU_REQUIRE_NO_ERRORS(result); + else + LUAU_REQUIRE_ERRORS(result); } TEST_CASE_FIXTURE(BuiltinsFixture, "table_insert_should_cope_with_optional_properties_in_nonstrict") @@ -3262,11 +3270,13 @@ TEST_CASE_FIXTURE(Fixture, "invariant_table_properties_means_instantiating_table local c : string = t.m("hi") )"); - LUAU_REQUIRE_ERRORS(result); - CHECK_EQ(toString(result.errors[0]), R"(Type 't' could not be converted into '{| m: (number) -> number |}' -caused by: - Property 'm' is not compatible. Type '(a) -> a' could not be converted into '(number) -> number'; different number of generic type parameters)"); - // this error message is not great since the underlying issue is that the context is invariant, + LUAU_REQUIRE_NO_ERRORS(result); + // TODO: test behavior is wrong until we can re-enable the covariant requirement for instantiation in subtyping +// LUAU_REQUIRE_ERRORS(result); +// CHECK_EQ(toString(result.errors[0]), R"(Type 't' could not be converted into '{| m: (number) -> number |}' +// caused by: +// Property 'm' is not compatible. Type '(a) -> a' could not be converted into '(number) -> number'; different number of generic type parameters)"); +// // this error message is not great since the underlying issue is that the context is invariant, // and `(number) -> number` cannot be a subtype of `(a) -> a`. } @@ -3292,4 +3302,43 @@ local g : ({ p : number, q : string }) -> ({ p : number, r : boolean }) = f CHECK_EQ("r", error->properties[0]); } +TEST_CASE_FIXTURE(BuiltinsFixture, "setmetatable_has_a_side_effect") +{ + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + + CheckResult result = check(R"( + local mt = { + __add = function(x, y) + return 123 + end, + } + + local foo = {} + setmetatable(foo, mt) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK(toString(requireType("foo")) == "{ @metatable { __add: (a, b) -> number }, { } }"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "tables_should_be_fully_populated") +{ + CheckResult result = check(R"( + local t = { + x = 5 :: NonexistingTypeWhichEndsUpReturningAnErrorType, + y = 5 + } + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + ToStringOptions opts; + opts.exhaustive = true; + if (FFlag::LuauSpecialTypesAsterisked) + CHECK_EQ("{ x: *error-type*, y: number }", toString(requireType("t"), opts)); + else + CHECK_EQ("{ x: , y: number }", toString(requireType("t"), opts)); +} + TEST_SUITE_END(); diff --git a/tests/VisitTypeVar.test.cpp b/tests/VisitTypeVar.test.cpp index 4fba694a8..589c3bad5 100644 --- a/tests/VisitTypeVar.test.cpp +++ b/tests/VisitTypeVar.test.cpp @@ -22,7 +22,14 @@ TEST_CASE_FIXTURE(Fixture, "throw_when_limit_is_exceeded") TypeId tType = requireType("t"); - CHECK_THROWS_AS(toString(tType), RecursionLimitException); + if (FFlag::LuauIceExceptionInheritanceChange) + { + CHECK_THROWS_AS(toString(tType), RecursionLimitException); + } + else + { + CHECK_THROWS_AS(toString(tType), RecursionLimitException_DEPRECATED); + } } TEST_CASE_FIXTURE(Fixture, "dont_throw_when_limit_is_high_enough") diff --git a/tools/faillist.txt b/tools/faillist.txt index c869e0c47..a4c05b7bf 100644 --- a/tools/faillist.txt +++ b/tools/faillist.txt @@ -1,14 +1,14 @@ -AnnotationTests.builtin_types_are_not_exported AnnotationTests.corecursive_types_error_on_tight_loop AnnotationTests.duplicate_type_param_name AnnotationTests.for_loop_counter_annotation_is_checked AnnotationTests.generic_aliases_are_cloned_properly AnnotationTests.instantiation_clone_has_to_follow +AnnotationTests.luau_print_is_not_special_without_the_flag AnnotationTests.occurs_check_on_cyclic_intersection_typevar AnnotationTests.occurs_check_on_cyclic_union_typevar AnnotationTests.too_many_type_params AnnotationTests.two_type_params -AnnotationTests.use_type_required_from_another_file +AnnotationTests.unknown_type_reference_generates_error AstQuery.last_argument_function_call_type AstQuery::getDocumentationSymbolAtPosition.overloaded_fn AutocompleteTest.autocomplete_first_function_arg_expected_type @@ -86,7 +86,6 @@ BuiltinTests.table_pack BuiltinTests.table_pack_reduce BuiltinTests.table_pack_variadic BuiltinTests.tonumber_returns_optional_number_type -BuiltinTests.tonumber_returns_optional_number_type2 DefinitionTests.class_definition_overload_metamethods DefinitionTests.class_definition_string_props DefinitionTests.declaring_generic_functions @@ -96,7 +95,6 @@ FrontendTest.imported_table_modification_2 FrontendTest.it_should_be_safe_to_stringify_errors_when_full_type_graph_is_discarded FrontendTest.nocheck_cycle_used_by_checked FrontendTest.reexport_cyclic_type -FrontendTest.reexport_type_alias FrontendTest.trace_requires_in_nonstrict_mode GenericsTests.apply_type_function_nested_generics1 GenericsTests.apply_type_function_nested_generics2 @@ -105,7 +103,6 @@ GenericsTests.calling_self_generic_methods GenericsTests.check_generic_typepack_function GenericsTests.check_mutual_generic_functions GenericsTests.correctly_instantiate_polymorphic_member_functions -GenericsTests.do_not_always_instantiate_generic_intersection_types GenericsTests.do_not_infer_generic_functions GenericsTests.duplicate_generic_type_packs GenericsTests.duplicate_generic_types @@ -143,7 +140,6 @@ IntersectionTypes.table_write_sealed_indirect ModuleTests.any_persistance_does_not_leak ModuleTests.clone_self_property ModuleTests.deepClone_cyclic_table -ModuleTests.do_not_clone_reexports NonstrictModeTests.for_in_iterator_variables_are_any NonstrictModeTests.function_parameters_are_any NonstrictModeTests.inconsistent_module_return_types_are_ok @@ -158,7 +154,6 @@ NonstrictModeTests.parameters_having_type_any_are_optional NonstrictModeTests.table_dot_insert_and_recursive_calls NonstrictModeTests.table_props_are_any Normalize.cyclic_table_normalizes_sensibly -Normalize.intersection_combine_on_bound_self ParseErrorRecovery.generic_type_list_recovery ParseErrorRecovery.recovery_of_parenthesized_expressions ParserTests.parse_nesting_based_end_detection_failsafe_earlier @@ -249,7 +244,6 @@ TableTests.defining_a_self_method_for_a_builtin_sealed_table_must_fail TableTests.defining_a_self_method_for_a_local_sealed_table_must_fail TableTests.dont_crash_when_setmetatable_does_not_produce_a_metatabletypevar TableTests.dont_hang_when_trying_to_look_up_in_cyclic_metatable_index -TableTests.dont_invalidate_the_properties_iterator_of_free_table_when_rolled_back TableTests.dont_leak_free_table_props TableTests.dont_quantify_table_that_belongs_to_outer_scope TableTests.dont_suggest_exact_match_keys @@ -279,7 +273,6 @@ TableTests.inferring_crazy_table_should_also_be_quick TableTests.instantiate_table_cloning_3 TableTests.invariant_table_properties_means_instantiating_tables_in_call_is_unsound TableTests.leaking_bad_metatable_errors -TableTests.length_operator_union_errors TableTests.less_exponential_blowup_please TableTests.meta_add TableTests.meta_add_both_ways @@ -347,9 +340,9 @@ TryUnifyTests.members_of_failed_typepack_unification_are_unified_with_errorType TryUnifyTests.result_of_failed_typepack_unification_is_constrained TryUnifyTests.typepack_unification_should_trim_free_tails TryUnifyTests.variadics_should_use_reversed_properly +TypeAliases.cannot_create_cyclic_type_with_unknown_module TypeAliases.forward_declared_alias_is_not_clobbered_by_prior_unification_with_any TypeAliases.generic_param_remap -TypeAliases.mismatched_generic_pack_type_param TypeAliases.mismatched_generic_type_param TypeAliases.mutually_recursive_types_restriction_not_ok_1 TypeAliases.mutually_recursive_types_restriction_not_ok_2 @@ -363,7 +356,7 @@ TypeAliases.type_alias_fwd_declaration_is_precise TypeAliases.type_alias_local_mutation TypeAliases.type_alias_local_rename TypeAliases.type_alias_of_an_imported_recursive_generic_type -TypeAliases.type_alias_of_an_imported_recursive_type +TypeInfer.check_type_infer_recursion_count TypeInfer.checking_should_not_ice TypeInfer.cli_50041_committing_txnlog_in_apollo_client_error TypeInfer.dont_report_type_errors_within_an_AstExprError @@ -394,6 +387,7 @@ TypeInferClasses.warn_when_prop_almost_matches TypeInferClasses.we_can_report_when_someone_is_trying_to_use_a_table_rather_than_a_class TypeInferFunctions.calling_function_with_anytypepack_doesnt_leak_free_types TypeInferFunctions.calling_function_with_incorrect_argument_type_yields_errors_spanning_argument +TypeInferFunctions.cannot_hoist_interior_defns_into_signature TypeInferFunctions.dont_give_other_overloads_message_if_only_one_argument_matching_overload_exists TypeInferFunctions.dont_infer_parameter_types_for_functions_from_their_call_site TypeInferFunctions.duplicate_functions_with_different_signatures_not_allowed_in_nonstrict @@ -439,12 +433,9 @@ TypeInferLoops.varlist_declared_by_for_in_loop_should_be_free TypeInferModules.bound_free_table_export_is_ok TypeInferModules.custom_require_global TypeInferModules.do_not_modify_imported_types -TypeInferModules.do_not_modify_imported_types_2 -TypeInferModules.do_not_modify_imported_types_3 TypeInferModules.module_type_conflict TypeInferModules.module_type_conflict_instantiated TypeInferModules.require_a_variadic_function -TypeInferModules.require_types TypeInferModules.type_error_of_unknown_qualified_type TypeInferOOP.CheckMethodsOfSealed TypeInferOOP.dont_suggest_using_colon_rather_than_dot_if_another_overload_works @@ -468,9 +459,6 @@ TypeInferOperators.produce_the_correct_error_message_when_comparing_a_table_with TypeInferOperators.refine_and_or TypeInferOperators.typecheck_overloaded_multiply_that_is_an_intersection TypeInferOperators.typecheck_overloaded_multiply_that_is_an_intersection_on_rhs -TypeInferOperators.typecheck_unary_len_error -TypeInferOperators.typecheck_unary_minus -TypeInferOperators.typecheck_unary_minus_error TypeInferOperators.UnknownGlobalCompoundAssign TypeInferPrimitives.CheckMethodsOfNumber TypeInferPrimitives.singleton_types @@ -489,6 +477,7 @@ TypeInferUnknownNever.math_operators_and_never TypeInferUnknownNever.type_packs_containing_never_is_itself_uninhabitable TypeInferUnknownNever.type_packs_containing_never_is_itself_uninhabitable2 TypeInferUnknownNever.unary_minus_of_never +TypePackTests.detect_cyclic_typepacks2 TypePackTests.higher_order_function TypePackTests.pack_tail_unification_check TypePackTests.parenthesized_varargs_returns_any From a6cbb0f65c648ae82934e78b44bcdfd79fec8509 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 27 Oct 2022 17:47:14 -0700 Subject: [PATCH 14/66] Fix clang-14 / GNU ld interaction for target_clones --- VM/src/lnumutils.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VM/src/lnumutils.h b/VM/src/lnumutils.h index 4d7c9516f..abaf9af7a 100644 --- a/VM/src/lnumutils.h +++ b/VM/src/lnumutils.h @@ -34,7 +34,7 @@ inline bool luai_vecisnan(const float* a) } LUAU_FASTMATH_BEGIN -LUAU_DISPATCH_SSE41 +// TODO: LUAU_DISPATCH_SSE41 would be nice here, but clang-14 doesn't support it correctly on inline functions... inline double luai_nummod(double a, double b) { return a - floor(a / b) * b; From e3fdab308218928f9b9ecf6c77b8c6e8211045c3 Mon Sep 17 00:00:00 2001 From: Andy Friesen Date: Fri, 4 Nov 2022 10:02:37 -0700 Subject: [PATCH 15/66] Sync to upstream/release/552 --- Analysis/include/Luau/Connective.h | 68 + Analysis/include/Luau/Constraint.h | 7 +- .../include/Luau/ConstraintGraphBuilder.h | 19 +- Analysis/include/Luau/ConstraintSolver.h | 2 +- Analysis/include/Luau/Normalize.h | 40 +- Analysis/include/Luau/TypeUtils.h | 4 +- Analysis/include/Luau/TypeVar.h | 18 +- Analysis/include/Luau/Unifier.h | 2 +- Analysis/include/Luau/Variant.h | 7 +- Analysis/include/Luau/VisitTypeVar.h | 6 - Analysis/src/BuiltinDefinitions.cpp | 3 +- Analysis/src/Clone.cpp | 7 - Analysis/src/Connective.cpp | 32 + Analysis/src/ConstraintGraphBuilder.cpp | 245 +- Analysis/src/ConstraintSolver.cpp | 29 +- Analysis/src/EmbeddedBuiltinDefinitions.cpp | 76 +- Analysis/src/Normalize.cpp | 227 +- Analysis/src/ToString.cpp | 114 +- Analysis/src/TypeAttach.cpp | 6 - Analysis/src/TypeChecker2.cpp | 31 +- Analysis/src/TypeVar.cpp | 11 +- Analysis/src/Unifier.cpp | 181 +- Ast/src/Lexer.cpp | 2 +- Ast/src/Parser.cpp | 15 +- CLI/Repl.cpp | 3 +- CodeGen/include/Luau/AddressA64.h | 52 + CodeGen/include/Luau/AssemblyBuilderA64.h | 144 ++ CodeGen/include/Luau/AssemblyBuilderX64.h | 4 +- CodeGen/include/Luau/ConditionA64.h | 37 + .../Luau/{Condition.h => ConditionX64.h} | 2 +- CodeGen/include/Luau/OperandX64.h | 4 +- CodeGen/include/Luau/RegisterA64.h | 105 + CodeGen/include/Luau/RegisterX64.h | 5 + CodeGen/src/AssemblyBuilderA64.cpp | 607 +++++ CodeGen/src/AssemblyBuilderX64.cpp | 21 +- CodeGen/src/CodeGen.cpp | 31 +- CodeGen/src/CodeGenUtils.cpp | 54 + CodeGen/src/CodeGenUtils.h | 3 + CodeGen/src/CodeGenX64.cpp | 4 +- CodeGen/src/EmitCommonX64.cpp | 143 +- CodeGen/src/EmitCommonX64.h | 38 +- CodeGen/src/EmitInstructionX64.cpp | 322 ++- CodeGen/src/EmitInstructionX64.h | 7 +- CodeGen/src/Fallbacks.cpp | 2121 +---------------- CodeGen/src/Fallbacks.h | 69 - CodeGen/src/NativeState.cpp | 5 +- CodeGen/src/NativeState.h | 7 +- Sources.cmake | 10 +- VM/include/luaconf.h | 8 - VM/src/lbuiltins.cpp | 9 +- VM/src/lnumutils.h | 1 - tests/AssemblyBuilderA64.test.cpp | 221 ++ tests/AssemblyBuilderX64.test.cpp | 75 +- tests/AstJsonEncoder.test.cpp | 2 +- tests/CodeAllocator.test.cpp | 54 +- tests/Conformance.test.cpp | 10 +- tests/Fixture.cpp | 3 +- tests/Fixture.h | 2 +- tests/Normalize.test.cpp | 73 +- tests/Parser.test.cpp | 48 + tests/ToString.test.cpp | 30 +- tests/TypeInfer.anyerror.test.cpp | 22 +- tests/TypeInfer.builtins.test.cpp | 10 +- tests/TypeInfer.functions.test.cpp | 92 +- tests/TypeInfer.generics.test.cpp | 6 +- tests/TypeInfer.loops.test.cpp | 6 +- tests/TypeInfer.modules.test.cpp | 12 +- tests/TypeInfer.negations.test.cpp | 6 +- tests/TypeInfer.primitives.test.cpp | 7 +- tests/TypeInfer.provisional.test.cpp | 2 +- tests/TypeInfer.refinements.test.cpp | 222 +- tests/TypeInfer.singletons.test.cpp | 10 + tests/TypeInfer.tables.test.cpp | 17 +- tests/TypeInfer.test.cpp | 24 +- tests/TypeInfer.tryUnify.test.cpp | 12 +- tests/TypeInfer.unionTypes.test.cpp | 7 +- tests/Variant.test.cpp | 31 + tools/faillist.txt | 23 +- tools/lvmexecute_split.py | 9 +- tools/stack-usage-reporter.py | 173 ++ 80 files changed, 3126 insertions(+), 3051 deletions(-) create mode 100644 Analysis/include/Luau/Connective.h create mode 100644 Analysis/src/Connective.cpp create mode 100644 CodeGen/include/Luau/AddressA64.h create mode 100644 CodeGen/include/Luau/AssemblyBuilderA64.h create mode 100644 CodeGen/include/Luau/ConditionA64.h rename CodeGen/include/Luau/{Condition.h => ConditionX64.h} (94%) create mode 100644 CodeGen/include/Luau/RegisterA64.h create mode 100644 CodeGen/src/AssemblyBuilderA64.cpp create mode 100644 tests/AssemblyBuilderA64.test.cpp create mode 100644 tools/stack-usage-reporter.py diff --git a/Analysis/include/Luau/Connective.h b/Analysis/include/Luau/Connective.h new file mode 100644 index 000000000..c9daa0f9e --- /dev/null +++ b/Analysis/include/Luau/Connective.h @@ -0,0 +1,68 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Def.h" +#include "Luau/TypedAllocator.h" +#include "Luau/TypeVar.h" +#include "Luau/Variant.h" + +#include + +namespace Luau +{ + +struct Negation; +struct Conjunction; +struct Disjunction; +struct Equivalence; +struct Proposition; +using Connective = Variant; +using ConnectiveId = Connective*; // Can and most likely is nullptr. + +struct Negation +{ + ConnectiveId connective; +}; + +struct Conjunction +{ + ConnectiveId lhs; + ConnectiveId rhs; +}; + +struct Disjunction +{ + ConnectiveId lhs; + ConnectiveId rhs; +}; + +struct Equivalence +{ + ConnectiveId lhs; + ConnectiveId rhs; +}; + +struct Proposition +{ + DefId def; + TypeId discriminantTy; +}; + +template +const T* get(ConnectiveId connective) +{ + return get_if(connective); +} + +struct ConnectiveArena +{ + TypedAllocator allocator; + + ConnectiveId negation(ConnectiveId connective); + ConnectiveId conjunction(ConnectiveId lhs, ConnectiveId rhs); + ConnectiveId disjunction(ConnectiveId lhs, ConnectiveId rhs); + ConnectiveId equivalence(ConnectiveId lhs, ConnectiveId rhs); + ConnectiveId proposition(DefId def, TypeId discriminantTy); +}; + +} // namespace Luau diff --git a/Analysis/include/Luau/Constraint.h b/Analysis/include/Luau/Constraint.h index 7f092f5b2..4370d0cf4 100644 --- a/Analysis/include/Luau/Constraint.h +++ b/Analysis/include/Luau/Constraint.h @@ -132,15 +132,16 @@ struct HasPropConstraint std::string prop; }; -struct RefinementConstraint +// result ~ if isSingleton D then ~D else unknown where D = discriminantType +struct SingletonOrTopTypeConstraint { - DefId def; + TypeId resultType; TypeId discriminantType; }; using ConstraintV = Variant; + HasPropConstraint, SingletonOrTopTypeConstraint>; struct Constraint { diff --git a/Analysis/include/Luau/ConstraintGraphBuilder.h b/Analysis/include/Luau/ConstraintGraphBuilder.h index 6106717c5..cb5900ea9 100644 --- a/Analysis/include/Luau/ConstraintGraphBuilder.h +++ b/Analysis/include/Luau/ConstraintGraphBuilder.h @@ -2,6 +2,7 @@ #pragma once #include "Luau/Ast.h" +#include "Luau/Connective.h" #include "Luau/Constraint.h" #include "Luau/DataFlowGraphBuilder.h" #include "Luau/Module.h" @@ -26,11 +27,13 @@ struct DcrLogger; struct Inference { TypeId ty = nullptr; + ConnectiveId connective = nullptr; Inference() = default; - explicit Inference(TypeId ty) + explicit Inference(TypeId ty, ConnectiveId connective = nullptr) : ty(ty) + , connective(connective) { } }; @@ -38,11 +41,13 @@ struct Inference struct InferencePack { TypePackId tp = nullptr; + std::vector connectives; InferencePack() = default; - explicit InferencePack(TypePackId tp) + explicit InferencePack(TypePackId tp, const std::vector& connectives = {}) : tp(tp) + , connectives(connectives) { } }; @@ -73,6 +78,7 @@ struct ConstraintGraphBuilder // Defining scopes for AST nodes. DenseHashMap astTypeAliasDefiningScopes{nullptr}; NotNull dfg; + ConnectiveArena connectiveArena; int recursionCount = 0; @@ -126,6 +132,8 @@ struct ConstraintGraphBuilder */ NotNull addConstraint(const ScopePtr& scope, std::unique_ptr c); + void applyRefinements(const ScopePtr& scope, Location location, ConnectiveId connective); + /** * The entry point to the ConstraintGraphBuilder. This will construct a set * of scopes, constraints, and free types that can be solved later. @@ -167,10 +175,10 @@ struct ConstraintGraphBuilder * surrounding context. Used to implement bidirectional type checking. * @return the type of the expression. */ - Inference check(const ScopePtr& scope, AstExpr* expr, std::optional expectedType = {}); + Inference check(const ScopePtr& scope, AstExpr* expr, std::optional expectedType = {}, bool forceSingleton = false); - Inference check(const ScopePtr& scope, AstExprConstantString* string, std::optional expectedType); - Inference check(const ScopePtr& scope, AstExprConstantBool* bool_, std::optional expectedType); + Inference check(const ScopePtr& scope, AstExprConstantString* string, std::optional expectedType, bool forceSingleton); + Inference check(const ScopePtr& scope, AstExprConstantBool* bool_, std::optional expectedType, bool forceSingleton); Inference check(const ScopePtr& scope, AstExprLocal* local); Inference check(const ScopePtr& scope, AstExprGlobal* global); Inference check(const ScopePtr& scope, AstExprIndexName* indexName); @@ -180,6 +188,7 @@ struct ConstraintGraphBuilder Inference check(const ScopePtr& scope, AstExprIfElse* ifElse, std::optional expectedType); Inference check(const ScopePtr& scope, AstExprTypeAssertion* typeAssert); Inference check(const ScopePtr& scope, AstExprTable* expr, std::optional expectedType); + std::tuple checkBinary(const ScopePtr& scope, AstExprBinary* binary, std::optional expectedType); TypePackId checkLValues(const ScopePtr& scope, AstArray exprs); diff --git a/Analysis/include/Luau/ConstraintSolver.h b/Analysis/include/Luau/ConstraintSolver.h index 5cc63e656..07f027ad2 100644 --- a/Analysis/include/Luau/ConstraintSolver.h +++ b/Analysis/include/Luau/ConstraintSolver.h @@ -110,7 +110,7 @@ struct ConstraintSolver bool tryDispatch(const FunctionCallConstraint& c, NotNull constraint); bool tryDispatch(const PrimitiveTypeConstraint& c, NotNull constraint); bool tryDispatch(const HasPropConstraint& c, NotNull constraint); - bool tryDispatch(const RefinementConstraint& c, NotNull constraint); + bool tryDispatch(const SingletonOrTopTypeConstraint& c, NotNull constraint); // for a, ... in some_table do // also handles __iter metamethod diff --git a/Analysis/include/Luau/Normalize.h b/Analysis/include/Luau/Normalize.h index f98442dd1..b28c06a58 100644 --- a/Analysis/include/Luau/Normalize.h +++ b/Analysis/include/Luau/Normalize.h @@ -17,10 +17,8 @@ struct SingletonTypes; using ModulePtr = std::shared_ptr; -bool isSubtype( - TypeId subTy, TypeId superTy, NotNull scope, NotNull singletonTypes, InternalErrorReporter& ice, bool anyIsTop = true); -bool isSubtype(TypePackId subTy, TypePackId superTy, NotNull scope, NotNull singletonTypes, InternalErrorReporter& ice, - bool anyIsTop = true); +bool isSubtype(TypeId subTy, TypeId superTy, NotNull scope, NotNull singletonTypes, InternalErrorReporter& ice); +bool isSubtype(TypePackId subTy, TypePackId superTy, NotNull scope, NotNull singletonTypes, InternalErrorReporter& ice); class TypeIds { @@ -169,12 +167,26 @@ struct NormalizedStringType bool isSubtype(const NormalizedStringType& subStr, const NormalizedStringType& superStr); -// A normalized function type is either `never` (represented by `nullopt`) +// A normalized function type can be `never`, the top function type `function`, // or an intersection of function types. -// NOTE: type normalization can fail on function types with generics -// (e.g. because we do not support unions and intersections of generic type packs), -// so this type may contain `error`. -using NormalizedFunctionType = std::optional; +// +// NOTE: type normalization can fail on function types with generics (e.g. +// because we do not support unions and intersections of generic type packs), so +// this type may contain `error`. +struct NormalizedFunctionType +{ + NormalizedFunctionType(); + + bool isTop = false; + // TODO: Remove this wrapping optional when clipping + // FFlagLuauNegatedFunctionTypes. + std::optional parts; + + void resetToNever(); + void resetToTop(); + + bool isNever() const; +}; // A normalized generic/free type is a union, where each option is of the form (X & T) where // * X is either a free type or a generic @@ -234,12 +246,14 @@ struct NormalizedType NormalizedType(NotNull singletonTypes); - NormalizedType(const NormalizedType&) = delete; - NormalizedType(NormalizedType&&) = default; NormalizedType() = delete; ~NormalizedType() = default; + + NormalizedType(const NormalizedType&) = delete; + NormalizedType& operator=(const NormalizedType&) = delete; + + NormalizedType(NormalizedType&&) = default; NormalizedType& operator=(NormalizedType&&) = default; - NormalizedType& operator=(NormalizedType&) = delete; }; class Normalizer @@ -291,7 +305,7 @@ class Normalizer bool unionNormalWithTy(NormalizedType& here, TypeId there, int ignoreSmallerTyvars = -1); // ------- Negations - NormalizedType negateNormal(const NormalizedType& here); + std::optional negateNormal(const NormalizedType& here); TypeIds negateAll(const TypeIds& theres); TypeId negate(TypeId there); void subtractPrimitive(NormalizedType& here, TypeId ty); diff --git a/Analysis/include/Luau/TypeUtils.h b/Analysis/include/Luau/TypeUtils.h index 7409dbe74..085ee21b0 100644 --- a/Analysis/include/Luau/TypeUtils.h +++ b/Analysis/include/Luau/TypeUtils.h @@ -35,7 +35,7 @@ std::vector flatten(TypeArena& arena, NotNull singletonT * identity) types. * @param types the input type list to reduce. * @returns the reduced type list. -*/ + */ std::vector reduceUnion(const std::vector& types); /** @@ -45,7 +45,7 @@ std::vector reduceUnion(const std::vector& types); * @param arena the type arena to allocate the new type in, if necessary * @param ty the type to remove nil from * @returns a type with nil removed, or nil itself if that were the only option. -*/ + */ TypeId stripNil(NotNull singletonTypes, TypeArena& arena, TypeId ty); } // namespace Luau diff --git a/Analysis/include/Luau/TypeVar.h b/Analysis/include/Luau/TypeVar.h index 70c12cb9d..0ab4d4749 100644 --- a/Analysis/include/Luau/TypeVar.h +++ b/Analysis/include/Luau/TypeVar.h @@ -115,6 +115,7 @@ struct PrimitiveTypeVar Number, String, Thread, + Function, }; Type type; @@ -504,14 +505,6 @@ struct NeverTypeVar { }; -// Invariant 1: there should never be a reason why such UseTypeVar exists without it mapping to another type. -// Invariant 2: UseTypeVar should always disappear across modules. -struct UseTypeVar -{ - DefId def; - NotNull scope; -}; - // ~T // TODO: Some simplification step that overwrites the type graph to make sure negation // types disappear from the user's view, and (?) a debug flag to disable that @@ -522,9 +515,9 @@ struct NegationTypeVar using ErrorTypeVar = Unifiable::Error; -using TypeVariant = Unifiable::Variant; +using TypeVariant = + Unifiable::Variant; struct TypeVar final { @@ -644,13 +637,14 @@ struct SingletonTypes const TypeId stringType; const TypeId booleanType; const TypeId threadType; + const TypeId functionType; const TypeId trueType; const TypeId falseType; const TypeId anyType; const TypeId unknownType; const TypeId neverType; const TypeId errorType; - const TypeId falsyType; // No type binding! + const TypeId falsyType; // No type binding! const TypeId truthyType; // No type binding! const TypePackId anyTypePack; diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index 7bf4d50b7..b5f58d3c6 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -61,7 +61,6 @@ struct Unifier ErrorVec errors; Location location; Variance variance = Covariant; - bool anyIsTop = false; // If true, we consider any to be a top type. If false, it is a familiar but weird mix of top and bottom all at once. bool normalize; // Normalize unions and intersections if necessary bool useScopes = false; // If true, we use the scope hierarchy rather than TypeLevels CountMismatch::Context ctx = CountMismatch::Arg; @@ -131,6 +130,7 @@ struct Unifier Unifier makeChildUnifier(); void reportError(TypeError err); + LUAU_NOINLINE void reportError(Location location, TypeErrorData data); private: bool isNonstrictMode() const; diff --git a/Analysis/include/Luau/Variant.h b/Analysis/include/Luau/Variant.h index 76812c9bf..016c51f62 100644 --- a/Analysis/include/Luau/Variant.h +++ b/Analysis/include/Luau/Variant.h @@ -58,13 +58,15 @@ class Variant constexpr int tid = getTypeId(); typeId = tid; - new (&storage) TT(value); + new (&storage) TT(std::forward(value)); } Variant(const Variant& other) { + static constexpr FnCopy table[sizeof...(Ts)] = {&fnCopy...}; + typeId = other.typeId; - tableCopy[typeId](&storage, &other.storage); + table[typeId](&storage, &other.storage); } Variant(Variant&& other) @@ -192,7 +194,6 @@ class Variant return *static_cast(lhs) == *static_cast(rhs); } - static constexpr FnCopy tableCopy[sizeof...(Ts)] = {&fnCopy...}; static constexpr FnMove tableMove[sizeof...(Ts)] = {&fnMove...}; static constexpr FnDtor tableDtor[sizeof...(Ts)] = {&fnDtor...}; diff --git a/Analysis/include/Luau/VisitTypeVar.h b/Analysis/include/Luau/VisitTypeVar.h index d4f8528ff..3dcddba19 100644 --- a/Analysis/include/Luau/VisitTypeVar.h +++ b/Analysis/include/Luau/VisitTypeVar.h @@ -155,10 +155,6 @@ struct GenericTypeVarVisitor { return visit(ty); } - virtual bool visit(TypeId ty, const UseTypeVar& utv) - { - return visit(ty); - } virtual bool visit(TypeId ty, const NegationTypeVar& ntv) { return visit(ty); @@ -321,8 +317,6 @@ struct GenericTypeVarVisitor traverse(a); } } - else if (auto utv = get(ty)) - visit(ty, *utv); else if (auto ntv = get(ty)) visit(ty, *ntv); else if (!FFlag::LuauCompleteVisitor) diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index 6051e117a..ee53ae6b4 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -714,8 +714,7 @@ static bool dcrMagicFunctionPack(MagicFunctionCallContext context) result = arena->addType(UnionTypeVar{std::move(options)}); TypeId numberType = context.solver->singletonTypes->numberType; - TypeId packedTable = arena->addType( - TableTypeVar{{{"n", {numberType}}}, TableIndexer(numberType, result), {}, TableState::Sealed}); + TypeId packedTable = arena->addType(TableTypeVar{{{"n", {numberType}}}, TableIndexer(numberType, result), {}, TableState::Sealed}); TypePackId tableTypePack = arena->addTypePack({packedTable}); asMutable(context.result)->ty.emplace(tableTypePack); diff --git a/Analysis/src/Clone.cpp b/Analysis/src/Clone.cpp index 85408919b..86e1c7fc9 100644 --- a/Analysis/src/Clone.cpp +++ b/Analysis/src/Clone.cpp @@ -62,7 +62,6 @@ struct TypeCloner void operator()(const LazyTypeVar& t); void operator()(const UnknownTypeVar& t); void operator()(const NeverTypeVar& t); - void operator()(const UseTypeVar& t); void operator()(const NegationTypeVar& t); }; @@ -338,12 +337,6 @@ void TypeCloner::operator()(const NeverTypeVar& t) defaultClone(t); } -void TypeCloner::operator()(const UseTypeVar& t) -{ - TypeId result = dest.addType(BoundTypeVar{follow(typeId)}); - seenTypes[typeId] = result; -} - void TypeCloner::operator()(const NegationTypeVar& t) { TypeId result = dest.addType(AnyTypeVar{}); diff --git a/Analysis/src/Connective.cpp b/Analysis/src/Connective.cpp new file mode 100644 index 000000000..114b5f2f7 --- /dev/null +++ b/Analysis/src/Connective.cpp @@ -0,0 +1,32 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Connective.h" + +namespace Luau +{ + +ConnectiveId ConnectiveArena::negation(ConnectiveId connective) +{ + return NotNull{allocator.allocate(Negation{connective})}; +} + +ConnectiveId ConnectiveArena::conjunction(ConnectiveId lhs, ConnectiveId rhs) +{ + return NotNull{allocator.allocate(Conjunction{lhs, rhs})}; +} + +ConnectiveId ConnectiveArena::disjunction(ConnectiveId lhs, ConnectiveId rhs) +{ + return NotNull{allocator.allocate(Disjunction{lhs, rhs})}; +} + +ConnectiveId ConnectiveArena::equivalence(ConnectiveId lhs, ConnectiveId rhs) +{ + return NotNull{allocator.allocate(Equivalence{lhs, rhs})}; +} + +ConnectiveId ConnectiveArena::proposition(DefId def, TypeId discriminantTy) +{ + return NotNull{allocator.allocate(Proposition{def, discriminantTy})}; +} + +} // namespace Luau diff --git a/Analysis/src/ConstraintGraphBuilder.cpp b/Analysis/src/ConstraintGraphBuilder.cpp index 455fc221d..79a69ca47 100644 --- a/Analysis/src/ConstraintGraphBuilder.cpp +++ b/Analysis/src/ConstraintGraphBuilder.cpp @@ -107,6 +107,101 @@ NotNull ConstraintGraphBuilder::addConstraint(const ScopePtr& scope, return NotNull{scope->constraints.emplace_back(std::move(c)).get()}; } +static void unionRefinements(const std::unordered_map& lhs, const std::unordered_map& rhs, + std::unordered_map& dest, NotNull arena) +{ + for (auto [def, ty] : lhs) + { + auto rhsIt = rhs.find(def); + if (rhsIt == rhs.end()) + continue; + + std::vector discriminants{{ty, rhsIt->second}}; + + if (auto destIt = dest.find(def); destIt != dest.end()) + discriminants.push_back(destIt->second); + + dest[def] = arena->addType(UnionTypeVar{std::move(discriminants)}); + } +} + +static void computeRefinement(const ScopePtr& scope, ConnectiveId connective, std::unordered_map* refis, bool sense, + NotNull arena, bool eq, std::vector* constraints) +{ + using RefinementMap = std::unordered_map; + + if (!connective) + return; + else if (auto negation = get(connective)) + return computeRefinement(scope, negation->connective, refis, !sense, arena, eq, constraints); + else if (auto conjunction = get(connective)) + { + RefinementMap lhsRefis; + RefinementMap rhsRefis; + + computeRefinement(scope, conjunction->lhs, sense ? refis : &lhsRefis, sense, arena, eq, constraints); + computeRefinement(scope, conjunction->rhs, sense ? refis : &rhsRefis, sense, arena, eq, constraints); + + if (!sense) + unionRefinements(lhsRefis, rhsRefis, *refis, arena); + } + else if (auto disjunction = get(connective)) + { + RefinementMap lhsRefis; + RefinementMap rhsRefis; + + computeRefinement(scope, disjunction->lhs, sense ? &lhsRefis : refis, sense, arena, eq, constraints); + computeRefinement(scope, disjunction->rhs, sense ? &rhsRefis : refis, sense, arena, eq, constraints); + + if (sense) + unionRefinements(lhsRefis, rhsRefis, *refis, arena); + } + else if (auto equivalence = get(connective)) + { + computeRefinement(scope, equivalence->lhs, refis, sense, arena, true, constraints); + computeRefinement(scope, equivalence->rhs, refis, sense, arena, true, constraints); + } + else if (auto proposition = get(connective)) + { + TypeId discriminantTy = proposition->discriminantTy; + if (!sense && !eq) + discriminantTy = arena->addType(NegationTypeVar{proposition->discriminantTy}); + else if (!sense && eq) + { + discriminantTy = arena->addType(BlockedTypeVar{}); + constraints->push_back(SingletonOrTopTypeConstraint{discriminantTy, proposition->discriminantTy}); + } + + if (auto it = refis->find(proposition->def); it != refis->end()) + (*refis)[proposition->def] = arena->addType(IntersectionTypeVar{{discriminantTy, it->second}}); + else + (*refis)[proposition->def] = discriminantTy; + } +} + +void ConstraintGraphBuilder::applyRefinements(const ScopePtr& scope, Location location, ConnectiveId connective) +{ + if (!connective) + return; + + std::unordered_map refinements; + std::vector constraints; + computeRefinement(scope, connective, &refinements, /*sense*/ true, arena, /*eq*/ false, &constraints); + + for (auto [def, discriminantTy] : refinements) + { + std::optional defTy = scope->lookup(def); + if (!defTy) + ice->ice("Every DefId must map to a type!"); + + TypeId resultTy = arena->addType(IntersectionTypeVar{{*defTy, discriminantTy}}); + scope->dcrRefinements[def] = resultTy; + } + + for (auto& c : constraints) + addConstraint(scope, location, c); +} + void ConstraintGraphBuilder::visit(AstStatBlock* block) { LUAU_ASSERT(scopes.empty()); @@ -250,14 +345,33 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocal* local) if (value->is()) { - // HACK: we leave nil-initialized things floating under the assumption that they will later be populated. - // See the test TypeInfer/infer_locals_with_nil_value. - // Better flow awareness should make this obsolete. + // HACK: we leave nil-initialized things floating under the + // assumption that they will later be populated. + // + // See the test TypeInfer/infer_locals_with_nil_value. Better flow + // awareness should make this obsolete. if (!varTypes[i]) varTypes[i] = freshType(scope); } - else if (i == local->values.size - 1) + // Only function calls and vararg expressions can produce packs. All + // other expressions produce exactly one value. + else if (i != local->values.size - 1 || (!value->is() && !value->is())) + { + std::optional expectedType; + if (hasAnnotation) + expectedType = varTypes.at(i); + + TypeId exprType = check(scope, value, expectedType).ty; + if (i < varTypes.size()) + { + if (varTypes[i]) + addConstraint(scope, local->location, SubtypeConstraint{exprType, varTypes[i]}); + else + varTypes[i] = exprType; + } + } + else { std::vector expectedTypes; if (hasAnnotation) @@ -286,21 +400,6 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocal* local) addConstraint(scope, local->location, PackSubtypeConstraint{exprPack, tailPack}); } } - else - { - std::optional expectedType; - if (hasAnnotation) - expectedType = varTypes.at(i); - - TypeId exprType = check(scope, value, expectedType).ty; - if (i < varTypes.size()) - { - if (varTypes[i]) - addConstraint(scope, local->location, SubtypeConstraint{varTypes[i], exprType}); - else - varTypes[i] = exprType; - } - } } for (size_t i = 0; i < local->vars.size; ++i) @@ -569,14 +668,16 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatIf* ifStatement // TODO: Optimization opportunity, the interior scope of the condition could be // reused for the then body, so we don't need to refine twice. ScopePtr condScope = childScope(ifStatement->condition, scope); - check(condScope, ifStatement->condition, std::nullopt); + auto [_, connective] = check(condScope, ifStatement->condition, std::nullopt); ScopePtr thenScope = childScope(ifStatement->thenbody, scope); + applyRefinements(thenScope, Location{}, connective); visit(thenScope, ifStatement->thenbody); if (ifStatement->elsebody) { ScopePtr elseScope = childScope(ifStatement->elsebody, scope); + applyRefinements(elseScope, Location{}, connectiveArena.negation(connective)); visit(elseScope, ifStatement->elsebody); } } @@ -925,7 +1026,7 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCa } } -Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExpr* expr, std::optional expectedType) +Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExpr* expr, std::optional expectedType, bool forceSingleton) { RecursionCounter counter{&recursionCount}; @@ -938,13 +1039,13 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExpr* expr, st Inference result; if (auto group = expr->as()) - result = check(scope, group->expr, expectedType); + result = check(scope, group->expr, expectedType, forceSingleton); else if (auto stringExpr = expr->as()) - result = check(scope, stringExpr, expectedType); + result = check(scope, stringExpr, expectedType, forceSingleton); else if (expr->is()) result = Inference{singletonTypes->numberType}; else if (auto boolExpr = expr->as()) - result = check(scope, boolExpr, expectedType); + result = check(scope, boolExpr, expectedType, forceSingleton); else if (expr->is()) result = Inference{singletonTypes->nilType}; else if (auto local = expr->as()) @@ -999,8 +1100,11 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExpr* expr, st return result; } -Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprConstantString* string, std::optional expectedType) +Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprConstantString* string, std::optional expectedType, bool forceSingleton) { + if (forceSingleton) + return Inference{arena->addType(SingletonTypeVar{StringSingleton{std::string{string->value.data, string->value.size}}})}; + if (expectedType) { const TypeId expectedTy = follow(*expectedType); @@ -1020,12 +1124,15 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprConstantSt return Inference{singletonTypes->stringType}; } -Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprConstantBool* boolExpr, std::optional expectedType) +Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprConstantBool* boolExpr, std::optional expectedType, bool forceSingleton) { + const TypeId singletonType = boolExpr->value ? singletonTypes->trueType : singletonTypes->falseType; + if (forceSingleton) + return Inference{singletonType}; + if (expectedType) { const TypeId expectedTy = follow(*expectedType); - const TypeId singletonType = boolExpr->value ? singletonTypes->trueType : singletonTypes->falseType; if (get(expectedTy) || get(expectedTy)) { @@ -1045,8 +1152,8 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprConstantBo Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprLocal* local) { std::optional resultTy; - - if (auto def = dfg->getDef(local)) + auto def = dfg->getDef(local); + if (def) resultTy = scope->lookup(*def); if (!resultTy) @@ -1058,7 +1165,10 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprLocal* loc if (!resultTy) return Inference{singletonTypes->errorRecoveryType()}; // TODO: replace with ice, locals should never exist before its definition. - return Inference{*resultTy}; + if (def) + return Inference{*resultTy, connectiveArena.proposition(*def, singletonTypes->truthyType)}; + else + return Inference{*resultTy}; } Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprGlobal* global) @@ -1107,20 +1217,23 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIndexExpr* Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprUnary* unary) { - TypeId operandType = check(scope, unary->expr).ty; + auto [operandType, connective] = check(scope, unary->expr); TypeId resultType = arena->addType(BlockedTypeVar{}); addConstraint(scope, unary->location, UnaryConstraint{unary->op, operandType, resultType}); - return Inference{resultType}; + + if (unary->op == AstExprUnary::Not) + return Inference{resultType, connectiveArena.negation(connective)}; + else + return Inference{resultType}; } Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprBinary* binary, std::optional expectedType) { - TypeId leftType = check(scope, binary->left, expectedType).ty; - TypeId rightType = check(scope, binary->right, expectedType).ty; + auto [leftType, rightType, connective] = checkBinary(scope, binary, expectedType); TypeId resultType = arena->addType(BlockedTypeVar{}); addConstraint(scope, binary->location, BinaryConstraint{binary->op, leftType, rightType, resultType}); - return Inference{resultType}; + return Inference{resultType, std::move(connective)}; } Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIfElse* ifElse, std::optional expectedType) @@ -1147,6 +1260,58 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprTypeAssert return Inference{resolveType(scope, typeAssert->annotation)}; } +std::tuple ConstraintGraphBuilder::checkBinary( + const ScopePtr& scope, AstExprBinary* binary, std::optional expectedType) +{ + if (binary->op == AstExprBinary::And) + { + auto [leftType, leftConnective] = check(scope, binary->left, expectedType); + + ScopePtr rightScope = childScope(binary->right, scope); + applyRefinements(rightScope, binary->right->location, leftConnective); + auto [rightType, rightConnective] = check(rightScope, binary->right, expectedType); + + return {leftType, rightType, connectiveArena.conjunction(leftConnective, rightConnective)}; + } + else if (binary->op == AstExprBinary::Or) + { + auto [leftType, leftConnective] = check(scope, binary->left, expectedType); + + ScopePtr rightScope = childScope(binary->right, scope); + applyRefinements(rightScope, binary->right->location, connectiveArena.negation(leftConnective)); + auto [rightType, rightConnective] = check(rightScope, binary->right, expectedType); + + return {leftType, rightType, connectiveArena.disjunction(leftConnective, rightConnective)}; + } + else if (binary->op == AstExprBinary::CompareEq || binary->op == AstExprBinary::CompareNe) + { + TypeId leftType = check(scope, binary->left, expectedType, true).ty; + TypeId rightType = check(scope, binary->right, expectedType, true).ty; + + ConnectiveId leftConnective = nullptr; + if (auto def = dfg->getDef(binary->left)) + leftConnective = connectiveArena.proposition(*def, rightType); + + ConnectiveId rightConnective = nullptr; + if (auto def = dfg->getDef(binary->right)) + rightConnective = connectiveArena.proposition(*def, leftType); + + if (binary->op == AstExprBinary::CompareNe) + { + leftConnective = connectiveArena.negation(leftConnective); + rightConnective = connectiveArena.negation(rightConnective); + } + + return {leftType, rightType, connectiveArena.equivalence(leftConnective, rightConnective)}; + } + else + { + TypeId leftType = check(scope, binary->left, expectedType).ty; + TypeId rightType = check(scope, binary->right, expectedType).ty; + return {leftType, rightType, nullptr}; + } +} + TypePackId ConstraintGraphBuilder::checkLValues(const ScopePtr& scope, AstArray exprs) { std::vector types; @@ -1841,9 +2006,13 @@ std::vector> ConstraintGraphBuilder:: Inference ConstraintGraphBuilder::flattenPack(const ScopePtr& scope, Location location, InferencePack pack) { - auto [tp] = pack; + const auto& [tp, connectives] = pack; + ConnectiveId connective = nullptr; + if (!connectives.empty()) + connective = connectives[0]; + if (auto f = first(tp)) - return Inference{*f}; + return Inference{*f, connective}; TypeId typeResult = freshType(scope); TypePack onePack{{typeResult}, freshTypePack(scope)}; @@ -1851,7 +2020,7 @@ Inference ConstraintGraphBuilder::flattenPack(const ScopePtr& scope, Location lo addConstraint(scope, location, PackSubtypeConstraint{tp, oneTypePack}); - return Inference{typeResult}; + return Inference{typeResult, connective}; } void ConstraintGraphBuilder::reportError(Location location, TypeErrorData err) diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index 5e43be0f8..c53ac659a 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -440,8 +440,8 @@ bool ConstraintSolver::tryDispatch(NotNull constraint, bool fo success = tryDispatch(*fcc, constraint); else if (auto hpc = get(*constraint)) success = tryDispatch(*hpc, constraint); - else if (auto rc = get(*constraint)) - success = tryDispatch(*rc, constraint); + else if (auto sottc = get(*constraint)) + success = tryDispatch(*sottc, constraint); else LUAU_ASSERT(false); @@ -1274,25 +1274,18 @@ bool ConstraintSolver::tryDispatch(const HasPropConstraint& c, NotNull constraint) +bool ConstraintSolver::tryDispatch(const SingletonOrTopTypeConstraint& c, NotNull constraint) { - // TODO: Figure out exact details on when refinements need to be blocked. - // It's possible that it never needs to be, since we can just use intersection types with the discriminant type? - - if (!constraint->scope->parent) - iceReporter.ice("No parent scope"); - - std::optional previousTy = constraint->scope->parent->lookup(c.def); - if (!previousTy) - iceReporter.ice("No previous type"); + if (isBlocked(c.discriminantType)) + return false; - std::optional useTy = constraint->scope->lookup(c.def); - if (!useTy) - iceReporter.ice("The def is not bound to a type"); + TypeId followed = follow(c.discriminantType); - TypeId resultTy = follow(*useTy); - std::vector parts{*previousTy, c.discriminantType}; - asMutable(resultTy)->ty.emplace(std::move(parts)); + // `nil` is a singleton type too! There's only one value of type `nil`. + if (get(followed) || isNil(followed)) + *asMutable(c.resultType) = NegationTypeVar{c.discriminantType}; + else + *asMutable(c.resultType) = BoundTypeVar{singletonTypes->unknownType}; return true; } diff --git a/Analysis/src/EmbeddedBuiltinDefinitions.cpp b/Analysis/src/EmbeddedBuiltinDefinitions.cpp index 67abbff1f..339de9755 100644 --- a/Analysis/src/EmbeddedBuiltinDefinitions.cpp +++ b/Analysis/src/EmbeddedBuiltinDefinitions.cpp @@ -13,16 +13,16 @@ declare bit32: { bor: (...number) -> number, bxor: (...number) -> number, btest: (number, ...number) -> boolean, - rrotate: (number, number) -> number, - lrotate: (number, number) -> number, - lshift: (number, number) -> number, - arshift: (number, number) -> number, - rshift: (number, number) -> number, - bnot: (number) -> number, - extract: (number, number, number?) -> number, - replace: (number, number, number, number?) -> number, - countlz: (number) -> number, - countrz: (number) -> number, + rrotate: (x: number, disp: number) -> number, + lrotate: (x: number, disp: number) -> number, + lshift: (x: number, disp: number) -> number, + arshift: (x: number, disp: number) -> number, + rshift: (x: number, disp: number) -> number, + bnot: (x: number) -> number, + extract: (n: number, field: number, width: number?) -> number, + replace: (n: number, v: number, field: number, width: number?) -> number, + countlz: (n: number) -> number, + countrz: (n: number) -> number, } declare math: { @@ -93,9 +93,9 @@ type DateTypeResult = { } declare os: { - time: (DateTypeArg?) -> number, - date: (string?, number?) -> DateTypeResult | string, - difftime: (DateTypeResult | number, DateTypeResult | number) -> number, + time: (time: DateTypeArg?) -> number, + date: (formatString: string?, time: number?) -> DateTypeResult | string, + difftime: (t2: DateTypeResult | number, t1: DateTypeResult | number) -> number, clock: () -> number, } @@ -145,51 +145,51 @@ declare function loadstring(src: string, chunkname: string?): (((A...) -> declare function newproxy(mt: boolean?): any declare coroutine: { - create: ((A...) -> R...) -> thread, - resume: (thread, A...) -> (boolean, R...), + create: (f: (A...) -> R...) -> thread, + resume: (co: thread, A...) -> (boolean, R...), running: () -> thread, - status: (thread) -> "dead" | "running" | "normal" | "suspended", + status: (co: thread) -> "dead" | "running" | "normal" | "suspended", -- FIXME: This technically returns a function, but we can't represent this yet. - wrap: ((A...) -> R...) -> any, + wrap: (f: (A...) -> R...) -> any, yield: (A...) -> R..., isyieldable: () -> boolean, - close: (thread) -> (boolean, any) + close: (co: thread) -> (boolean, any) } declare table: { - concat: ({V}, string?, number?, number?) -> string, - insert: (({V}, V) -> ()) & (({V}, number, V) -> ()), - maxn: ({V}) -> number, - remove: ({V}, number?) -> V?, - sort: ({V}, ((V, V) -> boolean)?) -> (), - create: (number, V?) -> {V}, - find: ({V}, V, number?) -> number?, - - unpack: ({V}, number?, number?) -> ...V, + concat: (t: {V}, sep: string?, i: number?, j: number?) -> string, + insert: ((t: {V}, value: V) -> ()) & ((t: {V}, pos: number, value: V) -> ()), + maxn: (t: {V}) -> number, + remove: (t: {V}, number?) -> V?, + sort: (t: {V}, comp: ((V, V) -> boolean)?) -> (), + create: (count: number, value: V?) -> {V}, + find: (haystack: {V}, needle: V, init: number?) -> number?, + + unpack: (list: {V}, i: number?, j: number?) -> ...V, pack: (...V) -> { n: number, [number]: V }, - getn: ({V}) -> number, - foreach: ({[K]: V}, (K, V) -> ()) -> (), + getn: (t: {V}) -> number, + foreach: (t: {[K]: V}, f: (K, V) -> ()) -> (), foreachi: ({V}, (number, V) -> ()) -> (), - move: ({V}, number, number, number, {V}?) -> {V}, - clear: ({[K]: V}) -> (), + move: (src: {V}, a: number, b: number, t: number, dst: {V}?) -> {V}, + clear: (table: {[K]: V}) -> (), - isfrozen: ({[K]: V}) -> boolean, + isfrozen: (t: {[K]: V}) -> boolean, } declare debug: { - info: ((thread, number, string) -> R...) & ((number, string) -> R...) & (((A...) -> R1..., string) -> R2...), - traceback: ((string?, number?) -> string) & ((thread, string?, number?) -> string), + info: ((thread: thread, level: number, options: string) -> R...) & ((level: number, options: string) -> R...) & ((func: (A...) -> R1..., options: string) -> R2...), + traceback: ((message: string?, level: number?) -> string) & ((thread: thread, message: string?, level: number?) -> string), } declare utf8: { char: (...number) -> string, charpattern: string, - codes: (string) -> ((string, number) -> (number, number), string, number), - codepoint: (string, number?, number?) -> ...number, - len: (string, number?, number?) -> (number?, number?), - offset: (string, number?, number?) -> number, + codes: (str: string) -> ((string, number) -> (number, number), string, number), + codepoint: (str: string, i: number?, j: number?) -> ...number, + len: (s: string, i: number?, j: number?) -> (number?, number?), + offset: (s: string, n: number?, i: number?) -> number, } -- Cannot use `typeof` here because it will produce a polytype when we expect a monotype. diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp index 5ef4b7e7c..21e9f7874 100644 --- a/Analysis/src/Normalize.cpp +++ b/Analysis/src/Normalize.cpp @@ -7,9 +7,9 @@ #include "Luau/Clone.h" #include "Luau/Common.h" +#include "Luau/RecursionCounter.h" #include "Luau/TypeVar.h" #include "Luau/Unifier.h" -#include "Luau/VisitTypeVar.h" LUAU_FASTFLAGVARIABLE(DebugLuauCopyBeforeNormalizing, false) LUAU_FASTFLAGVARIABLE(DebugLuauCheckNormalizeInvariant, false) @@ -20,6 +20,7 @@ LUAU_FASTINTVARIABLE(LuauNormalizeCacheLimit, 100000); LUAU_FASTFLAGVARIABLE(LuauNormalizeCombineTableFix, false); LUAU_FASTFLAGVARIABLE(LuauTypeNormalization2, false); LUAU_FASTFLAGVARIABLE(LuauNegatedStringSingletons, false); +LUAU_FASTFLAGVARIABLE(LuauNegatedFunctionTypes, false); LUAU_FASTFLAG(LuauUnknownAndNeverType) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) LUAU_FASTFLAG(LuauOverloadedFunctionSubtypingPerf); @@ -206,6 +207,28 @@ bool isSubtype(const NormalizedStringType& subStr, const NormalizedStringType& s return true; } +NormalizedFunctionType::NormalizedFunctionType() + : parts(FFlag::LuauNegatedFunctionTypes ? std::optional{TypeIds{}} : std::nullopt) +{ +} + +void NormalizedFunctionType::resetToTop() +{ + isTop = true; + parts.emplace(); +} + +void NormalizedFunctionType::resetToNever() +{ + isTop = false; + parts.emplace(); +} + +bool NormalizedFunctionType::isNever() const +{ + return !isTop && (!parts || parts->empty()); +} + NormalizedType::NormalizedType(NotNull singletonTypes) : tops(singletonTypes->neverType) , booleans(singletonTypes->neverType) @@ -220,8 +243,8 @@ NormalizedType::NormalizedType(NotNull singletonTypes) static bool isInhabited(const NormalizedType& norm) { return !get(norm.tops) || !get(norm.booleans) || !norm.classes.empty() || !get(norm.errors) || - !get(norm.nils) || !get(norm.numbers) || !norm.strings.isNever() || - !get(norm.threads) || norm.functions || !norm.tables.empty() || !norm.tyvars.empty(); + !get(norm.nils) || !get(norm.numbers) || !norm.strings.isNever() || !get(norm.threads) || + !norm.functions.isNever() || !norm.tables.empty() || !norm.tyvars.empty(); } static int tyvarIndex(TypeId ty) @@ -317,10 +340,14 @@ static bool isNormalizedThread(TypeId ty) static bool areNormalizedFunctions(const NormalizedFunctionType& tys) { - if (tys) - for (TypeId ty : *tys) + if (tys.parts) + { + for (TypeId ty : *tys.parts) + { if (!get(ty) && !get(ty)) return false; + } + } return true; } @@ -420,7 +447,7 @@ void Normalizer::clearNormal(NormalizedType& norm) norm.strings.resetToNever(); norm.threads = singletonTypes->neverType; norm.tables.clear(); - norm.functions = std::nullopt; + norm.functions.resetToNever(); norm.tyvars.clear(); } @@ -809,20 +836,28 @@ std::optional Normalizer::unionOfFunctions(TypeId here, TypeId there) void Normalizer::unionFunctions(NormalizedFunctionType& heres, const NormalizedFunctionType& theres) { - if (!theres) + if (FFlag::LuauNegatedFunctionTypes) + { + if (heres.isTop) + return; + if (theres.isTop) + heres.resetToTop(); + } + + if (theres.isNever()) return; TypeIds tmps; - if (!heres) + if (heres.isNever()) { - tmps.insert(theres->begin(), theres->end()); - heres = std::move(tmps); + tmps.insert(theres.parts->begin(), theres.parts->end()); + heres.parts = std::move(tmps); return; } - for (TypeId here : *heres) - for (TypeId there : *theres) + for (TypeId here : *heres.parts) + for (TypeId there : *theres.parts) { if (std::optional fun = unionOfFunctions(here, there)) tmps.insert(*fun); @@ -830,28 +865,28 @@ void Normalizer::unionFunctions(NormalizedFunctionType& heres, const NormalizedF tmps.insert(singletonTypes->errorRecoveryType(there)); } - heres = std::move(tmps); + heres.parts = std::move(tmps); } void Normalizer::unionFunctionsWithFunction(NormalizedFunctionType& heres, TypeId there) { - if (!heres) + if (heres.isNever()) { TypeIds tmps; tmps.insert(there); - heres = std::move(tmps); + heres.parts = std::move(tmps); return; } TypeIds tmps; - for (TypeId here : *heres) + for (TypeId here : *heres.parts) { if (std::optional fun = unionOfFunctions(here, there)) tmps.insert(*fun); else tmps.insert(singletonTypes->errorRecoveryType(there)); } - heres = std::move(tmps); + heres.parts = std::move(tmps); } void Normalizer::unionTablesWithTable(TypeIds& heres, TypeId there) @@ -1004,6 +1039,11 @@ bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, int ignor here.strings.resetToString(); else if (ptv->type == PrimitiveTypeVar::Thread) here.threads = there; + else if (ptv->type == PrimitiveTypeVar::Function) + { + LUAU_ASSERT(FFlag::LuauNegatedFunctionTypes); + here.functions.resetToTop(); + } else LUAU_ASSERT(!"Unreachable"); } @@ -1036,8 +1076,11 @@ bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, int ignor else if (const NegationTypeVar* ntv = get(there)) { const NormalizedType* thereNormal = normalize(ntv->ty); - NormalizedType tn = negateNormal(*thereNormal); - if (!unionNormals(here, tn)) + std::optional tn = negateNormal(*thereNormal); + if (!tn) + return false; + + if (!unionNormals(here, *tn)) return false; } else @@ -1053,7 +1096,7 @@ bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, int ignor // ------- Negations -NormalizedType Normalizer::negateNormal(const NormalizedType& here) +std::optional Normalizer::negateNormal(const NormalizedType& here) { NormalizedType result{singletonTypes}; if (!get(here.tops)) @@ -1092,10 +1135,24 @@ NormalizedType Normalizer::negateNormal(const NormalizedType& here) result.threads = get(here.threads) ? singletonTypes->threadType : singletonTypes->neverType; + /* + * Things get weird and so, so complicated if we allow negations of + * arbitrary function types. Ordinary code can never form these kinds of + * types, so we decline to negate them. + */ + if (FFlag::LuauNegatedFunctionTypes) + { + if (here.functions.isNever()) + result.functions.resetToTop(); + else if (here.functions.isTop) + result.functions.resetToNever(); + else + return std::nullopt; + } + // TODO: negating tables - // TODO: negating functions // TODO: negating tyvars? - + return result; } @@ -1142,21 +1199,25 @@ void Normalizer::subtractPrimitive(NormalizedType& here, TypeId ty) LUAU_ASSERT(ptv); switch (ptv->type) { - case PrimitiveTypeVar::NilType: - here.nils = singletonTypes->neverType; - break; - case PrimitiveTypeVar::Boolean: - here.booleans = singletonTypes->neverType; - break; - case PrimitiveTypeVar::Number: - here.numbers = singletonTypes->neverType; - break; - case PrimitiveTypeVar::String: - here.strings.resetToNever(); - break; - case PrimitiveTypeVar::Thread: - here.threads = singletonTypes->neverType; - break; + case PrimitiveTypeVar::NilType: + here.nils = singletonTypes->neverType; + break; + case PrimitiveTypeVar::Boolean: + here.booleans = singletonTypes->neverType; + break; + case PrimitiveTypeVar::Number: + here.numbers = singletonTypes->neverType; + break; + case PrimitiveTypeVar::String: + here.strings.resetToNever(); + break; + case PrimitiveTypeVar::Thread: + here.threads = singletonTypes->neverType; + break; + case PrimitiveTypeVar::Function: + LUAU_ASSERT(FFlag::LuauNegatedStringSingletons); + here.functions.resetToNever(); + break; } } @@ -1589,7 +1650,7 @@ std::optional Normalizer::intersectionOfFunctions(TypeId here, TypeId th TypePackId argTypes; TypePackId retTypes; - + if (hftv->retTypes == tftv->retTypes) { std::optional argTypesOpt = unionOfTypePacks(hftv->argTypes, tftv->argTypes); @@ -1598,7 +1659,7 @@ std::optional Normalizer::intersectionOfFunctions(TypeId here, TypeId th argTypes = *argTypesOpt; retTypes = hftv->retTypes; } - else if (FFlag::LuauOverloadedFunctionSubtypingPerf && hftv->argTypes == tftv->argTypes) + else if (FFlag::LuauOverloadedFunctionSubtypingPerf && hftv->argTypes == tftv->argTypes) { std::optional retTypesOpt = intersectionOfTypePacks(hftv->argTypes, tftv->argTypes); if (!retTypesOpt) @@ -1738,18 +1799,20 @@ std::optional Normalizer::unionSaturatedFunctions(TypeId here, TypeId th void Normalizer::intersectFunctionsWithFunction(NormalizedFunctionType& heres, TypeId there) { - if (!heres) + if (heres.isNever()) return; - for (auto it = heres->begin(); it != heres->end();) + heres.isTop = false; + + for (auto it = heres.parts->begin(); it != heres.parts->end();) { TypeId here = *it; if (get(here)) it++; else if (std::optional tmp = intersectionOfFunctions(here, there)) { - heres->erase(it); - heres->insert(*tmp); + heres.parts->erase(it); + heres.parts->insert(*tmp); return; } else @@ -1757,27 +1820,27 @@ void Normalizer::intersectFunctionsWithFunction(NormalizedFunctionType& heres, T } TypeIds tmps; - for (TypeId here : *heres) + for (TypeId here : *heres.parts) { if (std::optional tmp = unionSaturatedFunctions(here, there)) tmps.insert(*tmp); } - heres->insert(there); - heres->insert(tmps.begin(), tmps.end()); + heres.parts->insert(there); + heres.parts->insert(tmps.begin(), tmps.end()); } void Normalizer::intersectFunctions(NormalizedFunctionType& heres, const NormalizedFunctionType& theres) { - if (!heres) + if (heres.isNever()) return; - else if (!theres) + else if (theres.isNever()) { - heres = std::nullopt; + heres.resetToNever(); return; } else { - for (TypeId there : *theres) + for (TypeId there : *theres.parts) intersectFunctionsWithFunction(heres, there); } } @@ -1935,6 +1998,7 @@ bool Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there) TypeId nils = here.nils; TypeId numbers = here.numbers; NormalizedStringType strings = std::move(here.strings); + NormalizedFunctionType functions = std::move(here.functions); TypeId threads = here.threads; clearNormal(here); @@ -1949,6 +2013,11 @@ bool Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there) here.strings = std::move(strings); else if (ptv->type == PrimitiveTypeVar::Thread) here.threads = threads; + else if (ptv->type == PrimitiveTypeVar::Function) + { + LUAU_ASSERT(FFlag::LuauNegatedFunctionTypes); + here.functions = std::move(functions); + } else LUAU_ASSERT(!"Unreachable"); } @@ -1981,8 +2050,10 @@ bool Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there) for (TypeId part : itv->options) { const NormalizedType* normalPart = normalize(part); - NormalizedType negated = negateNormal(*normalPart); - intersectNormals(here, negated); + std::optional negated = negateNormal(*normalPart); + if (!negated) + return false; + intersectNormals(here, *negated); } } else @@ -2016,14 +2087,16 @@ TypeId Normalizer::typeFromNormal(const NormalizedType& norm) result.insert(result.end(), norm.classes.begin(), norm.classes.end()); if (!get(norm.errors)) result.push_back(norm.errors); - if (norm.functions) + if (FFlag::LuauNegatedFunctionTypes && norm.functions.isTop) + result.push_back(singletonTypes->functionType); + else if (!norm.functions.isNever()) { - if (norm.functions->size() == 1) - result.push_back(*norm.functions->begin()); + if (norm.functions.parts->size() == 1) + result.push_back(*norm.functions.parts->begin()); else { std::vector parts; - parts.insert(parts.end(), norm.functions->begin(), norm.functions->end()); + parts.insert(parts.end(), norm.functions.parts->begin(), norm.functions.parts->end()); result.push_back(arena->addType(IntersectionTypeVar{std::move(parts)})); } } @@ -2070,62 +2143,24 @@ TypeId Normalizer::typeFromNormal(const NormalizedType& norm) return arena->addType(UnionTypeVar{std::move(result)}); } -namespace -{ - -struct Replacer -{ - TypeArena* arena; - TypeId sourceType; - TypeId replacedType; - DenseHashMap newTypes; - - Replacer(TypeArena* arena, TypeId sourceType, TypeId replacedType) - : arena(arena) - , sourceType(sourceType) - , replacedType(replacedType) - , newTypes(nullptr) - { - } - - TypeId smartClone(TypeId t) - { - t = follow(t); - TypeId* res = newTypes.find(t); - if (res) - return *res; - - TypeId result = shallowClone(t, *arena, TxnLog::empty()); - newTypes[t] = result; - newTypes[result] = result; - - return result; - } -}; - -} // anonymous namespace - -bool isSubtype(TypeId subTy, TypeId superTy, NotNull scope, NotNull singletonTypes, InternalErrorReporter& ice, bool anyIsTop) +bool isSubtype(TypeId subTy, TypeId superTy, NotNull scope, NotNull singletonTypes, InternalErrorReporter& ice) { UnifierSharedState sharedState{&ice}; TypeArena arena; Normalizer normalizer{&arena, singletonTypes, NotNull{&sharedState}}; Unifier u{NotNull{&normalizer}, Mode::Strict, scope, Location{}, Covariant}; - u.anyIsTop = anyIsTop; u.tryUnify(subTy, superTy); const bool ok = u.errors.empty() && u.log.empty(); return ok; } -bool isSubtype( - TypePackId subPack, TypePackId superPack, NotNull scope, NotNull singletonTypes, InternalErrorReporter& ice, bool anyIsTop) +bool isSubtype(TypePackId subPack, TypePackId superPack, NotNull scope, NotNull singletonTypes, InternalErrorReporter& ice) { UnifierSharedState sharedState{&ice}; TypeArena arena; Normalizer normalizer{&arena, singletonTypes, NotNull{&sharedState}}; Unifier u{NotNull{&normalizer}, Mode::Strict, scope, Location{}, Covariant}; - u.anyIsTop = anyIsTop; u.tryUnify(subPack, superPack); const bool ok = u.errors.empty() && u.log.empty(); diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index 44000647a..48215f244 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -10,12 +10,12 @@ #include #include -LUAU_FASTFLAG(LuauUnknownAndNeverType) +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) LUAU_FASTFLAG(LuauLvaluelessPath) -LUAU_FASTFLAGVARIABLE(LuauSpecialTypesAsterisked, false) +LUAU_FASTFLAG(LuauUnknownAndNeverType) LUAU_FASTFLAGVARIABLE(LuauFixNameMaps, false) -LUAU_FASTFLAGVARIABLE(LuauUnseeArrayTtv, false) LUAU_FASTFLAGVARIABLE(LuauFunctionReturnStringificationFixup, false) +LUAU_FASTFLAGVARIABLE(LuauUnseeArrayTtv, false) /* * Prefix generic typenames with gen- @@ -225,6 +225,20 @@ struct StringifierState result.name += s; } + void emitLevel(Scope* scope) + { + size_t count = 0; + for (Scope* s = scope; s; s = s->parent.get()) + ++count; + + emit(count); + emit("-"); + char buffer[16]; + uint32_t s = uint32_t(intptr_t(scope) & 0xFFFFFF); + snprintf(buffer, sizeof(buffer), "0x%x", s); + emit(buffer); + } + void emit(TypeLevel level) { emit(std::to_string(level.level)); @@ -296,10 +310,7 @@ struct TypeVarStringifier if (tv->ty.valueless_by_exception()) { state.result.error = true; - if (FFlag::LuauSpecialTypesAsterisked) - state.emit("* VALUELESS BY EXCEPTION *"); - else - state.emit("< VALUELESS BY EXCEPTION >"); + state.emit("* VALUELESS BY EXCEPTION *"); return; } @@ -376,7 +387,10 @@ struct TypeVarStringifier if (FFlag::DebugLuauVerboseTypeNames) { state.emit("-"); - state.emit(ftv.level); + if (FFlag::DebugLuauDeferredConstraintResolution) + state.emitLevel(ftv.scope); + else + state.emit(ftv.level); } } @@ -398,6 +412,15 @@ struct TypeVarStringifier } else state.emit(state.getName(ty)); + + if (FFlag::DebugLuauVerboseTypeNames) + { + state.emit("-"); + if (FFlag::DebugLuauDeferredConstraintResolution) + state.emitLevel(gtv.scope); + else + state.emit(gtv.level); + } } void operator()(TypeId, const BlockedTypeVar& btv) @@ -433,6 +456,9 @@ struct TypeVarStringifier case PrimitiveTypeVar::Thread: state.emit("thread"); return; + case PrimitiveTypeVar::Function: + state.emit("function"); + return; default: LUAU_ASSERT(!"Unknown primitive type"); throwRuntimeError("Unknown primitive type " + std::to_string(ptv.type)); @@ -461,10 +487,7 @@ struct TypeVarStringifier if (state.hasSeen(&ftv)) { state.result.cycle = true; - if (FFlag::LuauSpecialTypesAsterisked) - state.emit("*CYCLE*"); - else - state.emit(""); + state.emit("*CYCLE*"); return; } @@ -572,10 +595,7 @@ struct TypeVarStringifier if (state.hasSeen(&ttv)) { state.result.cycle = true; - if (FFlag::LuauSpecialTypesAsterisked) - state.emit("*CYCLE*"); - else - state.emit(""); + state.emit("*CYCLE*"); return; } @@ -709,10 +729,7 @@ struct TypeVarStringifier if (state.hasSeen(&uv)) { state.result.cycle = true; - if (FFlag::LuauSpecialTypesAsterisked) - state.emit("*CYCLE*"); - else - state.emit(""); + state.emit("*CYCLE*"); return; } @@ -779,10 +796,7 @@ struct TypeVarStringifier if (state.hasSeen(&uv)) { state.result.cycle = true; - if (FFlag::LuauSpecialTypesAsterisked) - state.emit("*CYCLE*"); - else - state.emit(""); + state.emit("*CYCLE*"); return; } @@ -827,10 +841,7 @@ struct TypeVarStringifier void operator()(TypeId, const ErrorTypeVar& tv) { state.result.error = true; - if (FFlag::LuauSpecialTypesAsterisked) - state.emit(FFlag::LuauUnknownAndNeverType ? "*error-type*" : "*unknown*"); - else - state.emit(FFlag::LuauUnknownAndNeverType ? "" : "*unknown*"); + state.emit(FFlag::LuauUnknownAndNeverType ? "*error-type*" : "*unknown*"); } void operator()(TypeId, const LazyTypeVar& ltv) @@ -849,11 +860,6 @@ struct TypeVarStringifier state.emit("never"); } - void operator()(TypeId ty, const UseTypeVar&) - { - stringify(follow(ty)); - } - void operator()(TypeId, const NegationTypeVar& ntv) { state.emit("~"); @@ -906,10 +912,7 @@ struct TypePackStringifier if (tp->ty.valueless_by_exception()) { state.result.error = true; - if (FFlag::LuauSpecialTypesAsterisked) - state.emit("* VALUELESS TP BY EXCEPTION *"); - else - state.emit("< VALUELESS TP BY EXCEPTION >"); + state.emit("* VALUELESS TP BY EXCEPTION *"); return; } @@ -932,10 +935,7 @@ struct TypePackStringifier if (state.hasSeen(&tp)) { state.result.cycle = true; - if (FFlag::LuauSpecialTypesAsterisked) - state.emit("*CYCLETP*"); - else - state.emit(""); + state.emit("*CYCLETP*"); return; } @@ -980,10 +980,7 @@ struct TypePackStringifier void operator()(TypePackId, const Unifiable::Error& error) { state.result.error = true; - if (FFlag::LuauSpecialTypesAsterisked) - state.emit(FFlag::LuauUnknownAndNeverType ? "*error-type*" : "*unknown*"); - else - state.emit(FFlag::LuauUnknownAndNeverType ? "" : "*unknown*"); + state.emit(FFlag::LuauUnknownAndNeverType ? "*error-type*" : "*unknown*"); } void operator()(TypePackId, const VariadicTypePack& pack) @@ -991,10 +988,7 @@ struct TypePackStringifier state.emit("..."); if (FFlag::DebugLuauVerboseTypeNames && pack.hidden) { - if (FFlag::LuauSpecialTypesAsterisked) - state.emit("*hidden*"); - else - state.emit(""); + state.emit("*hidden*"); } stringify(pack.ty); } @@ -1029,7 +1023,10 @@ struct TypePackStringifier if (FFlag::DebugLuauVerboseTypeNames) { state.emit("-"); - state.emit(pack.level); + if (FFlag::DebugLuauDeferredConstraintResolution) + state.emitLevel(pack.scope); + else + state.emit(pack.level); } state.emit("..."); @@ -1197,10 +1194,7 @@ ToStringResult toStringDetailed(TypeId ty, ToStringOptions& opts) { result.truncated = true; - if (FFlag::LuauSpecialTypesAsterisked) - result.name += "... *TRUNCATED*"; - else - result.name += "... "; + result.name += "... *TRUNCATED*"; } return result; @@ -1270,10 +1264,7 @@ ToStringResult toStringDetailed(TypePackId tp, ToStringOptions& opts) if (opts.maxTypeLength > 0 && result.name.length() > opts.maxTypeLength) { - if (FFlag::LuauSpecialTypesAsterisked) - result.name += "... *TRUNCATED*"; - else - result.name += "... "; + result.name += "... *TRUNCATED*"; } return result; @@ -1516,9 +1507,12 @@ std::string toString(const Constraint& constraint, ToStringOptions& opts) { return tos(c.resultType, opts) + " ~ hasProp " + tos(c.subjectType, opts) + ", \"" + c.prop + "\""; } - else if constexpr (std::is_same_v) + else if constexpr (std::is_same_v) { - return "TODO"; + std::string result = tos(c.resultType, opts); + std::string discriminant = tos(c.discriminantType, opts); + + return result + " ~ if isSingleton D then ~D else unknown where D = " + discriminant; } else static_assert(always_false_v, "Non-exhaustive constraint switch"); diff --git a/Analysis/src/TypeAttach.cpp b/Analysis/src/TypeAttach.cpp index 179846d7c..c97ed05d2 100644 --- a/Analysis/src/TypeAttach.cpp +++ b/Analysis/src/TypeAttach.cpp @@ -338,12 +338,6 @@ class TypeRehydrationVisitor { return allocator->alloc(Location(), std::nullopt, AstName{"never"}); } - AstType* operator()(const UseTypeVar& utv) - { - std::optional ty = utv.scope->lookup(utv.def); - LUAU_ASSERT(ty); - return Luau::visit(*this, (*ty)->ty); - } AstType* operator()(const NegationTypeVar& ntv) { // FIXME: do the same thing we do with ErrorTypeVar diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index a26731586..dde41a65f 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -301,7 +301,6 @@ struct TypeChecker2 UnifierSharedState sharedState{&ice}; Normalizer normalizer{&arena, singletonTypes, NotNull{&sharedState}}; Unifier u{NotNull{&normalizer}, Mode::Strict, stack.back(), ret->location, Covariant}; - u.anyIsTop = true; u.tryUnify(actualRetType, expectedRetType); const bool ok = u.errors.empty() && u.log.empty(); @@ -331,16 +330,21 @@ struct TypeChecker2 if (value) visit(value); - if (i != local->values.size - 1) + TypeId* maybeValueType = value ? module->astTypes.find(value) : nullptr; + if (i != local->values.size - 1 || maybeValueType) { AstLocal* var = i < local->vars.size ? local->vars.data[i] : nullptr; if (var && var->annotation) { - TypeId varType = lookupAnnotation(var->annotation); + TypeId annotationType = lookupAnnotation(var->annotation); TypeId valueType = value ? lookupType(value) : nullptr; - if (valueType && !isSubtype(varType, valueType, stack.back(), singletonTypes, ice, /* anyIsTop */ false)) - reportError(TypeMismatch{varType, valueType}, value->location); + if (valueType) + { + ErrorVec errors = tryUnify(stack.back(), value->location, valueType, annotationType); + if (!errors.empty()) + reportErrors(std::move(errors)); + } } } else @@ -606,7 +610,7 @@ struct TypeChecker2 visit(rhs); TypeId rhsType = lookupType(rhs); - if (!isSubtype(rhsType, lhsType, stack.back(), singletonTypes, ice, /* anyIsTop */ false)) + if (!isSubtype(rhsType, lhsType, stack.back(), singletonTypes, ice)) { reportError(TypeMismatch{lhsType, rhsType}, rhs->location); } @@ -757,7 +761,7 @@ struct TypeChecker2 TypeId actualType = lookupType(number); TypeId numberType = singletonTypes->numberType; - if (!isSubtype(numberType, actualType, stack.back(), singletonTypes, ice, /* anyIsTop */ false)) + if (!isSubtype(numberType, actualType, stack.back(), singletonTypes, ice)) { reportError(TypeMismatch{actualType, numberType}, number->location); } @@ -768,7 +772,7 @@ struct TypeChecker2 TypeId actualType = lookupType(string); TypeId stringType = singletonTypes->stringType; - if (!isSubtype(actualType, stringType, stack.back(), singletonTypes, ice, /* anyIsTop */ false)) + if (!isSubtype(actualType, stringType, stack.back(), singletonTypes, ice)) { reportError(TypeMismatch{actualType, stringType}, string->location); } @@ -857,7 +861,7 @@ struct TypeChecker2 FunctionTypeVar ftv{argsTp, expectedRetType}; TypeId expectedType = arena.addType(ftv); - if (!isSubtype(testFunctionType, expectedType, stack.back(), singletonTypes, ice, /* anyIsTop */ false)) + if (!isSubtype(testFunctionType, expectedType, stack.back(), singletonTypes, ice)) { CloneState cloneState; expectedType = clone(expectedType, module->internalTypes, cloneState); @@ -876,7 +880,7 @@ struct TypeChecker2 getIndexTypeFromType(module->getModuleScope(), leftType, indexName->index.value, indexName->location, /* addErrors */ true); if (ty) { - if (!isSubtype(resultType, *ty, stack.back(), singletonTypes, ice, /* anyIsTop */ false)) + if (!isSubtype(resultType, *ty, stack.back(), singletonTypes, ice)) { reportError(TypeMismatch{resultType, *ty}, indexName->location); } @@ -909,7 +913,7 @@ struct TypeChecker2 TypeId inferredArgTy = *argIt; TypeId annotatedArgTy = lookupAnnotation(arg->annotation); - if (!isSubtype(annotatedArgTy, inferredArgTy, stack.back(), singletonTypes, ice, /* anyIsTop */ false)) + if (!isSubtype(annotatedArgTy, inferredArgTy, stack.back(), singletonTypes, ice)) { reportError(TypeMismatch{annotatedArgTy, inferredArgTy}, arg->location); } @@ -1203,10 +1207,10 @@ struct TypeChecker2 TypeId computedType = lookupType(expr->expr); // Note: As an optimization, we try 'number <: number | string' first, as that is the more likely case. - if (isSubtype(annotationType, computedType, stack.back(), singletonTypes, ice, /* anyIsTop */ false)) + if (isSubtype(annotationType, computedType, stack.back(), singletonTypes, ice)) return; - if (isSubtype(computedType, annotationType, stack.back(), singletonTypes, ice, /* anyIsTop */ false)) + if (isSubtype(computedType, annotationType, stack.back(), singletonTypes, ice)) return; reportError(TypesAreUnrelated{computedType, annotationType}, expr->location); @@ -1507,7 +1511,6 @@ struct TypeChecker2 UnifierSharedState sharedState{&ice}; Normalizer normalizer{&module->internalTypes, singletonTypes, NotNull{&sharedState}}; Unifier u{NotNull{&normalizer}, Mode::Strict, scope, location, Covariant}; - u.anyIsTop = true; u.tryUnify(subTy, superTy); return std::move(u.errors); diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index 94d633c78..de0890e18 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -57,13 +57,6 @@ TypeId follow(TypeId t, std::function mapper) return btv->boundTo; else if (auto ttv = get(mapper(ty))) return ttv->boundTo; - else if (auto utv = get(mapper(ty))) - { - std::optional ty = utv->scope->lookup(utv->def); - if (!ty) - throwRuntimeError("UseTypeVar must map to another TypeId"); - return *ty; - } else return std::nullopt; }; @@ -761,6 +754,7 @@ SingletonTypes::SingletonTypes() , stringType(arena->addType(TypeVar{PrimitiveTypeVar{PrimitiveTypeVar::String}, /*persistent*/ true})) , booleanType(arena->addType(TypeVar{PrimitiveTypeVar{PrimitiveTypeVar::Boolean}, /*persistent*/ true})) , threadType(arena->addType(TypeVar{PrimitiveTypeVar{PrimitiveTypeVar::Thread}, /*persistent*/ true})) + , functionType(arena->addType(TypeVar{PrimitiveTypeVar{PrimitiveTypeVar::Function}, /*persistent*/ true})) , trueType(arena->addType(TypeVar{SingletonTypeVar{BooleanSingleton{true}}, /*persistent*/ true})) , falseType(arena->addType(TypeVar{SingletonTypeVar{BooleanSingleton{false}}, /*persistent*/ true})) , anyType(arena->addType(TypeVar{AnyTypeVar{}, /*persistent*/ true})) @@ -946,7 +940,8 @@ void persist(TypeId ty) queue.push_back(mtv->table); queue.push_back(mtv->metatable); } - else if (get(t) || get(t) || get(t) || get(t) || get(t) || get(t)) + else if (get(t) || get(t) || get(t) || get(t) || get(t) || + get(t)) { } else diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index b5eba9803..df5d86f1e 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -8,6 +8,7 @@ #include "Luau/TypePack.h" #include "Luau/TypeUtils.h" #include "Luau/TimeTrace.h" +#include "Luau/TypeVar.h" #include "Luau/VisitTypeVar.h" #include "Luau/ToString.h" @@ -23,6 +24,7 @@ LUAU_FASTFLAGVARIABLE(LuauInstantiateInSubtyping, false) LUAU_FASTFLAGVARIABLE(LuauOverloadedFunctionSubtypingPerf, false); LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) +LUAU_FASTFLAG(LuauNegatedFunctionTypes) namespace Luau { @@ -363,7 +365,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool if (sharedState.counters.iterationLimit > 0 && sharedState.counters.iterationLimit < sharedState.counters.iterationCount) { - reportError(TypeError{location, UnificationTooComplex{}}); + reportError(location, UnificationTooComplex{}); return; } @@ -404,7 +406,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool if (subGeneric && !subsumes(useScopes, subGeneric, superFree)) { // TODO: a more informative error message? CLI-39912 - reportError(TypeError{location, GenericError{"Generic subtype escaping scope"}}); + reportError(location, GenericError{"Generic subtype escaping scope"}); return; } @@ -433,7 +435,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool if (superGeneric && !subsumes(useScopes, superGeneric, subFree)) { // TODO: a more informative error message? CLI-39912 - reportError(TypeError{location, GenericError{"Generic supertype escaping scope"}}); + reportError(location, GenericError{"Generic supertype escaping scope"}); return; } @@ -450,15 +452,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool return tryUnifyWithAny(subTy, superTy); if (get(subTy)) - { - if (anyIsTop) - { - reportError(TypeError{location, TypeMismatch{superTy, subTy}}); - return; - } - else - return tryUnifyWithAny(superTy, subTy); - } + return tryUnifyWithAny(superTy, subTy); if (log.get(subTy)) return tryUnifyWithAny(superTy, subTy); @@ -478,7 +472,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool if (auto error = sharedState.cachedUnifyError.find({subTy, superTy})) { - reportError(TypeError{location, *error}); + reportError(location, *error); return; } } @@ -520,6 +514,12 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool else if ((log.getMutable(superTy) || log.getMutable(superTy)) && log.getMutable(subTy)) tryUnifySingletons(subTy, superTy); + else if (auto ptv = get(superTy); + FFlag::LuauNegatedFunctionTypes && ptv && ptv->type == PrimitiveTypeVar::Function && get(subTy)) + { + // Ok. Do nothing. forall functions F, F <: function + } + else if (log.getMutable(superTy) && log.getMutable(subTy)) tryUnifyFunctions(subTy, superTy, isFunctionCall); @@ -559,7 +559,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool tryUnifyNegationWithType(subTy, superTy); else - reportError(TypeError{location, TypeMismatch{superTy, subTy}}); + reportError(location, TypeMismatch{superTy, subTy}); if (cacheEnabled) cacheResult(subTy, superTy, errorCount); @@ -633,9 +633,9 @@ void Unifier::tryUnifyUnionWithType(TypeId subTy, const UnionTypeVar* subUnion, else if (failed) { if (firstFailedOption) - reportError(TypeError{location, TypeMismatch{superTy, subTy, "Not all union options are compatible.", *firstFailedOption}}); + reportError(location, TypeMismatch{superTy, subTy, "Not all union options are compatible.", *firstFailedOption}); else - reportError(TypeError{location, TypeMismatch{superTy, subTy}}); + reportError(location, TypeMismatch{superTy, subTy}); } } @@ -734,7 +734,7 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp const NormalizedType* subNorm = normalizer->normalize(subTy); const NormalizedType* superNorm = normalizer->normalize(superTy); if (!subNorm || !superNorm) - reportError(TypeError{location, UnificationTooComplex{}}); + reportError(location, UnificationTooComplex{}); else if ((failedOptionCount == 1 || foundHeuristic) && failedOption) tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "None of the union options are compatible. For example:", *failedOption); else @@ -743,9 +743,9 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp else if (!found) { if ((failedOptionCount == 1 || foundHeuristic) && failedOption) - reportError(TypeError{location, TypeMismatch{superTy, subTy, "None of the union options are compatible. For example:", *failedOption}}); + reportError(location, TypeMismatch{superTy, subTy, "None of the union options are compatible. For example:", *failedOption}); else - reportError(TypeError{location, TypeMismatch{superTy, subTy, "none of the union options are compatible"}}); + reportError(location, TypeMismatch{superTy, subTy, "none of the union options are compatible"}); } } @@ -774,7 +774,7 @@ void Unifier::tryUnifyTypeWithIntersection(TypeId subTy, TypeId superTy, const I if (unificationTooComplex) reportError(*unificationTooComplex); else if (firstFailedOption) - reportError(TypeError{location, TypeMismatch{superTy, subTy, "Not all intersection parts are compatible.", *firstFailedOption}}); + reportError(location, TypeMismatch{superTy, subTy, "Not all intersection parts are compatible.", *firstFailedOption}); } void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionTypeVar* uv, TypeId superTy, bool cacheEnabled, bool isFunctionCall) @@ -832,11 +832,11 @@ void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionTypeV if (subNorm && superNorm) tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the intersection parts are compatible"); else - reportError(TypeError{location, UnificationTooComplex{}}); + reportError(location, UnificationTooComplex{}); } else if (!found) { - reportError(TypeError{location, TypeMismatch{superTy, subTy, "none of the intersection parts are compatible"}}); + reportError(location, TypeMismatch{superTy, subTy, "none of the intersection parts are compatible"}); } } @@ -848,37 +848,37 @@ void Unifier::tryUnifyNormalizedTypes( if (get(superNorm.tops) || get(superNorm.tops) || get(subNorm.tops)) return; else if (get(subNorm.tops)) - return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); + return reportError(location, TypeMismatch{superTy, subTy, reason, error}); if (get(subNorm.errors)) if (!get(superNorm.errors)) - return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); + return reportError(location, TypeMismatch{superTy, subTy, reason, error}); if (get(subNorm.booleans)) { if (!get(superNorm.booleans)) - return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); + return reportError(location, TypeMismatch{superTy, subTy, reason, error}); } else if (const SingletonTypeVar* stv = get(subNorm.booleans)) { if (!get(superNorm.booleans) && stv != get(superNorm.booleans)) - return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); + return reportError(location, TypeMismatch{superTy, subTy, reason, error}); } if (get(subNorm.nils)) if (!get(superNorm.nils)) - return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); + return reportError(location, TypeMismatch{superTy, subTy, reason, error}); if (get(subNorm.numbers)) if (!get(superNorm.numbers)) - return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); + return reportError(location, TypeMismatch{superTy, subTy, reason, error}); if (!isSubtype(subNorm.strings, superNorm.strings)) - return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); + return reportError(location, TypeMismatch{superTy, subTy, reason, error}); if (get(subNorm.threads)) if (!get(superNorm.errors)) - return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); + return reportError(location, TypeMismatch{superTy, subTy, reason, error}); for (TypeId subClass : subNorm.classes) { @@ -894,7 +894,7 @@ void Unifier::tryUnifyNormalizedTypes( } } if (!found) - return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); + return reportError(location, TypeMismatch{superTy, subTy, reason, error}); } for (TypeId subTable : subNorm.tables) @@ -919,21 +919,19 @@ void Unifier::tryUnifyNormalizedTypes( return reportError(*e); } if (!found) - return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); + return reportError(location, TypeMismatch{superTy, subTy, reason, error}); } - if (subNorm.functions) + if (!subNorm.functions.isNever()) { - if (!superNorm.functions) - return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); - if (superNorm.functions->empty()) - return; - for (TypeId superFun : *superNorm.functions) + if (superNorm.functions.isNever()) + return reportError(location, TypeMismatch{superTy, subTy, reason, error}); + for (TypeId superFun : *superNorm.functions.parts) { Unifier innerState = makeChildUnifier(); const FunctionTypeVar* superFtv = get(superFun); if (!superFtv) - return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); + return reportError(location, TypeMismatch{superTy, subTy, reason, error}); TypePackId tgt = innerState.tryApplyOverloadedFunction(subTy, subNorm.functions, superFtv->argTypes); innerState.tryUnify_(tgt, superFtv->retTypes); if (innerState.errors.empty()) @@ -941,7 +939,7 @@ void Unifier::tryUnifyNormalizedTypes( else if (auto e = hasUnificationTooComplex(innerState.errors)) return reportError(*e); else - return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); + return reportError(location, TypeMismatch{superTy, subTy, reason, error}); } } @@ -959,15 +957,15 @@ void Unifier::tryUnifyNormalizedTypes( TypePackId Unifier::tryApplyOverloadedFunction(TypeId function, const NormalizedFunctionType& overloads, TypePackId args) { - if (!overloads || overloads->empty()) + if (overloads.isNever()) { - reportError(TypeError{location, CannotCallNonFunction{function}}); + reportError(location, CannotCallNonFunction{function}); return singletonTypes->errorRecoveryTypePack(); } std::optional result; const FunctionTypeVar* firstFun = nullptr; - for (TypeId overload : *overloads) + for (TypeId overload : *overloads.parts) { if (const FunctionTypeVar* ftv = get(overload)) { @@ -1015,12 +1013,12 @@ TypePackId Unifier::tryApplyOverloadedFunction(TypeId function, const Normalized // TODO: better error reporting? // The logic for error reporting overload resolution // is currently over in TypeInfer.cpp, should we move it? - reportError(TypeError{location, GenericError{"No matching overload."}}); + reportError(location, GenericError{"No matching overload."}); return singletonTypes->errorRecoveryTypePack(firstFun->retTypes); } else { - reportError(TypeError{location, CannotCallNonFunction{function}}); + reportError(location, CannotCallNonFunction{function}); return singletonTypes->errorRecoveryTypePack(); } } @@ -1199,7 +1197,7 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal if (sharedState.counters.iterationLimit > 0 && sharedState.counters.iterationLimit < sharedState.counters.iterationCount) { - reportError(TypeError{location, UnificationTooComplex{}}); + reportError(location, UnificationTooComplex{}); return; } @@ -1372,7 +1370,7 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal size_t actualSize = size(subTp); if (ctx == CountMismatch::FunctionResult || ctx == CountMismatch::ExprListResult) std::swap(expectedSize, actualSize); - reportError(TypeError{location, CountMismatch{expectedSize, std::nullopt, actualSize, ctx}}); + reportError(location, CountMismatch{expectedSize, std::nullopt, actualSize, ctx}); while (superIter.good()) { @@ -1394,9 +1392,9 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal else { if (FFlag::LuauReportTypeMismatchForTypePackUnificationFailure) - reportError(TypeError{location, TypePackMismatch{subTp, superTp}}); + reportError(location, TypePackMismatch{subTp, superTp}); else - reportError(TypeError{location, GenericError{"Failed to unify type packs"}}); + reportError(location, GenericError{"Failed to unify type packs"}); } } @@ -1408,7 +1406,7 @@ void Unifier::tryUnifyPrimitives(TypeId subTy, TypeId superTy) ice("passed non primitive types to unifyPrimitives"); if (superPrim->type != subPrim->type) - reportError(TypeError{location, TypeMismatch{superTy, subTy}}); + reportError(location, TypeMismatch{superTy, subTy}); } void Unifier::tryUnifySingletons(TypeId subTy, TypeId superTy) @@ -1429,7 +1427,7 @@ void Unifier::tryUnifySingletons(TypeId subTy, TypeId superTy) if (superPrim && superPrim->type == PrimitiveTypeVar::String && get(subSingleton) && variance == Covariant) return; - reportError(TypeError{location, TypeMismatch{superTy, subTy}}); + reportError(location, TypeMismatch{superTy, subTy}); } void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCall) @@ -1465,21 +1463,21 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal } else { - reportError(TypeError{location, UnificationTooComplex{}}); + reportError(location, UnificationTooComplex{}); } } else if (numGenerics != subFunction->generics.size()) { numGenerics = std::min(superFunction->generics.size(), subFunction->generics.size()); - reportError(TypeError{location, TypeMismatch{superTy, subTy, "different number of generic type parameters"}}); + reportError(location, TypeMismatch{superTy, subTy, "different number of generic type parameters"}); } if (numGenericPacks != subFunction->genericPacks.size()) { numGenericPacks = std::min(superFunction->genericPacks.size(), subFunction->genericPacks.size()); - reportError(TypeError{location, TypeMismatch{superTy, subTy, "different number of generic type pack parameters"}}); + reportError(location, TypeMismatch{superTy, subTy, "different number of generic type pack parameters"}); } for (size_t i = 0; i < numGenerics; i++) @@ -1506,11 +1504,10 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal if (auto e = hasUnificationTooComplex(innerState.errors)) reportError(*e); else if (!innerState.errors.empty() && innerState.firstPackErrorPos) - reportError( - TypeError{location, TypeMismatch{superTy, subTy, format("Argument #%d type is not compatible.", *innerState.firstPackErrorPos), - innerState.errors.front()}}); + reportError(location, TypeMismatch{superTy, subTy, format("Argument #%d type is not compatible.", *innerState.firstPackErrorPos), + innerState.errors.front()}); else if (!innerState.errors.empty()) - reportError(TypeError{location, TypeMismatch{superTy, subTy, "", innerState.errors.front()}}); + reportError(location, TypeMismatch{superTy, subTy, "", innerState.errors.front()}); innerState.ctx = CountMismatch::FunctionResult; innerState.tryUnify_(subFunction->retTypes, superFunction->retTypes); @@ -1520,13 +1517,12 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal if (auto e = hasUnificationTooComplex(innerState.errors)) reportError(*e); else if (!innerState.errors.empty() && size(superFunction->retTypes) == 1 && finite(superFunction->retTypes)) - reportError(TypeError{location, TypeMismatch{superTy, subTy, "Return type is not compatible.", innerState.errors.front()}}); + reportError(location, TypeMismatch{superTy, subTy, "Return type is not compatible.", innerState.errors.front()}); else if (!innerState.errors.empty() && innerState.firstPackErrorPos) - reportError( - TypeError{location, TypeMismatch{superTy, subTy, format("Return #%d type is not compatible.", *innerState.firstPackErrorPos), - innerState.errors.front()}}); + reportError(location, TypeMismatch{superTy, subTy, format("Return #%d type is not compatible.", *innerState.firstPackErrorPos), + innerState.errors.front()}); else if (!innerState.errors.empty()) - reportError(TypeError{location, TypeMismatch{superTy, subTy, "", innerState.errors.front()}}); + reportError(location, TypeMismatch{superTy, subTy, "", innerState.errors.front()}); } log.concat(std::move(innerState.log)); @@ -1608,7 +1604,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) } else { - reportError(TypeError{location, UnificationTooComplex{}}); + reportError(location, UnificationTooComplex{}); } } } @@ -1626,7 +1622,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) if (!missingProperties.empty()) { - reportError(TypeError{location, MissingProperties{superTy, subTy, std::move(missingProperties)}}); + reportError(location, MissingProperties{superTy, subTy, std::move(missingProperties)}); return; } } @@ -1644,7 +1640,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) if (!extraProperties.empty()) { - reportError(TypeError{location, MissingProperties{superTy, subTy, std::move(extraProperties), MissingProperties::Extra}}); + reportError(location, MissingProperties{superTy, subTy, std::move(extraProperties), MissingProperties::Extra}); return; } } @@ -1825,13 +1821,13 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) if (!missingProperties.empty()) { - reportError(TypeError{location, MissingProperties{superTy, subTy, std::move(missingProperties)}}); + reportError(location, MissingProperties{superTy, subTy, std::move(missingProperties)}); return; } if (!extraProperties.empty()) { - reportError(TypeError{location, MissingProperties{superTy, subTy, std::move(extraProperties), MissingProperties::Extra}}); + reportError(location, MissingProperties{superTy, subTy, std::move(extraProperties), MissingProperties::Extra}); return; } @@ -1867,14 +1863,14 @@ void Unifier::tryUnifyScalarShape(TypeId subTy, TypeId superTy, bool reversed) std::swap(subTy, superTy); if (auto ttv = log.get(superTy); !ttv || ttv->state != TableState::Free) - return reportError(TypeError{location, TypeMismatch{osuperTy, osubTy}}); + return reportError(location, TypeMismatch{osuperTy, osubTy}); auto fail = [&](std::optional e) { std::string reason = "The former's metatable does not satisfy the requirements."; if (e) - reportError(TypeError{location, TypeMismatch{osuperTy, osubTy, reason, *e}}); + reportError(location, TypeMismatch{osuperTy, osubTy, reason, *e}); else - reportError(TypeError{location, TypeMismatch{osuperTy, osubTy, reason}}); + reportError(location, TypeMismatch{osuperTy, osubTy, reason}); }; // Given t1 where t1 = { lower: (t1) -> (a, b...) } @@ -1906,7 +1902,7 @@ void Unifier::tryUnifyScalarShape(TypeId subTy, TypeId superTy, bool reversed) } } - reportError(TypeError{location, TypeMismatch{osuperTy, osubTy}}); + reportError(location, TypeMismatch{osuperTy, osubTy}); return; } @@ -1947,7 +1943,7 @@ void Unifier::tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed) if (auto e = hasUnificationTooComplex(innerState.errors)) reportError(*e); else if (!innerState.errors.empty()) - reportError(TypeError{location, TypeMismatch{reversed ? subTy : superTy, reversed ? superTy : subTy, "", innerState.errors.front()}}); + reportError(location, TypeMismatch{reversed ? subTy : superTy, reversed ? superTy : subTy, "", innerState.errors.front()}); log.concat(std::move(innerState.log)); } @@ -2024,9 +2020,9 @@ void Unifier::tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed) auto fail = [&]() { if (!reversed) - reportError(TypeError{location, TypeMismatch{superTy, subTy}}); + reportError(location, TypeMismatch{superTy, subTy}); else - reportError(TypeError{location, TypeMismatch{subTy, superTy}}); + reportError(location, TypeMismatch{subTy, superTy}); }; const ClassTypeVar* superClass = get(superTy); @@ -2071,7 +2067,7 @@ void Unifier::tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed) if (!classProp) { ok = false; - reportError(TypeError{location, UnknownProperty{superTy, propName}}); + reportError(location, UnknownProperty{superTy, propName}); } else { @@ -2095,7 +2091,7 @@ void Unifier::tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed) { ok = false; std::string msg = "Class " + superClass->name + " does not have an indexer"; - reportError(TypeError{location, GenericError{msg}}); + reportError(location, GenericError{msg}); } if (!ok) @@ -2116,13 +2112,13 @@ void Unifier::tryUnifyTypeWithNegation(TypeId subTy, TypeId superTy) const NormalizedType* subNorm = normalizer->normalize(subTy); const NormalizedType* superNorm = normalizer->normalize(superTy); if (!subNorm || !superNorm) - return reportError(TypeError{location, UnificationTooComplex{}}); + return reportError(location, UnificationTooComplex{}); // T & queue, DenseHashSet& seenTypePacks, Unifier& state, TypePackId a, TypePackId anyTypePack) @@ -2195,7 +2191,7 @@ void Unifier::tryUnifyVariadics(TypePackId subTp, TypePackId superTp, bool rever } else if (get(tail)) { - reportError(TypeError{location, GenericError{"Cannot unify variadic and generic packs"}}); + reportError(location, GenericError{"Cannot unify variadic and generic packs"}); } else if (get(tail)) { @@ -2209,7 +2205,7 @@ void Unifier::tryUnifyVariadics(TypePackId subTp, TypePackId superTp, bool rever } else { - reportError(TypeError{location, GenericError{"Failed to unify variadic packs"}}); + reportError(location, GenericError{"Failed to unify variadic packs"}); } } @@ -2351,7 +2347,7 @@ bool Unifier::occursCheck(DenseHashSet& seen, TypeId needle, TypeId hays if (needle == haystack) { - reportError(TypeError{location, OccursCheckFailed{}}); + reportError(location, OccursCheckFailed{}); log.replace(needle, *singletonTypes->errorRecoveryType()); return true; @@ -2402,7 +2398,7 @@ bool Unifier::occursCheck(DenseHashSet& seen, TypePackId needle, Typ { if (needle == haystack) { - reportError(TypeError{location, OccursCheckFailed{}}); + reportError(location, OccursCheckFailed{}); log.replace(needle, *singletonTypes->errorRecoveryTypePack()); return true; @@ -2423,18 +2419,31 @@ bool Unifier::occursCheck(DenseHashSet& seen, TypePackId needle, Typ Unifier Unifier::makeChildUnifier() { Unifier u = Unifier{normalizer, mode, scope, location, variance, &log}; - u.anyIsTop = anyIsTop; u.normalize = normalize; + u.useScopes = useScopes; return u; } // A utility function that appends the given error to the unifier's error log. // This allows setting a breakpoint wherever the unifier reports an error. +// +// Note: report error accepts its arguments by value intentionally to reduce the stack usage of functions which call `reportError`. +void Unifier::reportError(Location location, TypeErrorData data) +{ + errors.emplace_back(std::move(location), std::move(data)); +} + +// A utility function that appends the given error to the unifier's error log. +// This allows setting a breakpoint wherever the unifier reports an error. +// +// Note: to conserve stack space in calling functions it is generally preferred to call `Unifier::reportError(Location location, TypeErrorData data)` +// instead of this method. void Unifier::reportError(TypeError err) { errors.push_back(std::move(err)); } + bool Unifier::isNonstrictMode() const { return (mode == Mode::Nonstrict) || (mode == Mode::NoCheck); @@ -2445,7 +2454,7 @@ void Unifier::checkChildUnifierTypeMismatch(const ErrorVec& innerErrors, TypeId if (auto e = hasUnificationTooComplex(innerErrors)) reportError(*e); else if (!innerErrors.empty()) - reportError(TypeError{location, TypeMismatch{wantedType, givenType}}); + reportError(location, TypeMismatch{wantedType, givenType}); } void Unifier::checkChildUnifierTypeMismatch(const ErrorVec& innerErrors, const std::string& prop, TypeId wantedType, TypeId givenType) diff --git a/Ast/src/Lexer.cpp b/Ast/src/Lexer.cpp index d93f2ccb6..66436acde 100644 --- a/Ast/src/Lexer.cpp +++ b/Ast/src/Lexer.cpp @@ -641,8 +641,8 @@ Lexeme Lexer::readInterpolatedStringSection(Position start, Lexeme::Type formatT return brokenDoubleBrace; } - Lexeme lexemeOutput(Location(start, position()), Lexeme::InterpStringBegin, &buffer[startOffset], offset - startOffset); consume(); + Lexeme lexemeOutput(Location(start, position()), Lexeme::InterpStringBegin, &buffer[startOffset], offset - startOffset - 1); return lexemeOutput; } diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 85c5f5c60..4c0cc1251 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -25,6 +25,7 @@ LUAU_DYNAMIC_FASTFLAGVARIABLE(LuaReportParseIntegerIssues, false) LUAU_FASTFLAGVARIABLE(LuauInterpolatedStringBaseSupport, false) LUAU_FASTFLAGVARIABLE(LuauCommaParenWarnings, false) +LUAU_FASTFLAGVARIABLE(LuauTableConstructorRecovery, false) bool lua_telemetry_parsed_out_of_range_bin_integer = false; bool lua_telemetry_parsed_out_of_range_hex_integer = false; @@ -2310,9 +2311,13 @@ AstExpr* Parser::parseTableConstructor() MatchLexeme matchBrace = lexer.current(); expectAndConsume('{', "table literal"); + unsigned lastElementIndent = 0; while (lexer.current().type != '}') { + if (FFlag::LuauTableConstructorRecovery) + lastElementIndent = lexer.current().location.begin.column; + if (lexer.current().type == '[') { MatchLexeme matchLocationBracket = lexer.current(); @@ -2357,10 +2362,14 @@ AstExpr* Parser::parseTableConstructor() { nextLexeme(); } - else + else if (FFlag::LuauTableConstructorRecovery && (lexer.current().type == '[' || lexer.current().type == Lexeme::Name) && + lexer.current().location.begin.column == lastElementIndent) { - if (lexer.current().type != '}') - break; + report(lexer.current().location, "Expected ',' after table constructor element"); + } + else if (lexer.current().type != '}') + { + break; } } diff --git a/CLI/Repl.cpp b/CLI/Repl.cpp index 87e19db8b..e567725e5 100644 --- a/CLI/Repl.cpp +++ b/CLI/Repl.cpp @@ -978,7 +978,8 @@ int replMain(int argc, char** argv) if (compileFormat == CompileFormat::Null) printf("Compiled %d KLOC into %d KB bytecode\n", int(stats.lines / 1000), int(stats.bytecode / 1024)); else if (compileFormat == CompileFormat::CodegenNull) - printf("Compiled %d KLOC into %d KB bytecode => %d KB native code\n", int(stats.lines / 1000), int(stats.bytecode / 1024), int(stats.codegen / 1024)); + printf("Compiled %d KLOC into %d KB bytecode => %d KB native code\n", int(stats.lines / 1000), int(stats.bytecode / 1024), + int(stats.codegen / 1024)); return failed ? 1 : 0; } diff --git a/CodeGen/include/Luau/AddressA64.h b/CodeGen/include/Luau/AddressA64.h new file mode 100644 index 000000000..351e67151 --- /dev/null +++ b/CodeGen/include/Luau/AddressA64.h @@ -0,0 +1,52 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/RegisterA64.h" + +namespace Luau +{ +namespace CodeGen +{ + +enum class AddressKindA64 : uint8_t +{ + imm, // reg + imm + reg, // reg + reg + + // TODO: + // reg + reg << shift + // reg + sext(reg) << shift + // reg + uext(reg) << shift + // pc + offset +}; + +struct AddressA64 +{ + AddressA64(RegisterA64 base, int off = 0) + : kind(AddressKindA64::imm) + , base(base) + , offset(xzr) + , data(off) + { + LUAU_ASSERT(base.kind == KindA64::x); + LUAU_ASSERT(off >= 0 && off < 4096); + } + + AddressA64(RegisterA64 base, RegisterA64 offset) + : kind(AddressKindA64::reg) + , base(base) + , offset(offset) + , data(0) + { + LUAU_ASSERT(base.kind == KindA64::x); + LUAU_ASSERT(offset.kind == KindA64::x); + } + + AddressKindA64 kind; + RegisterA64 base; + RegisterA64 offset; + int data; +}; + +} // namespace CodeGen +} // namespace Luau diff --git a/CodeGen/include/Luau/AssemblyBuilderA64.h b/CodeGen/include/Luau/AssemblyBuilderA64.h new file mode 100644 index 000000000..9a1402bec --- /dev/null +++ b/CodeGen/include/Luau/AssemblyBuilderA64.h @@ -0,0 +1,144 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/RegisterA64.h" +#include "Luau/AddressA64.h" +#include "Luau/ConditionA64.h" +#include "Luau/Label.h" + +#include +#include + +namespace Luau +{ +namespace CodeGen +{ + +class AssemblyBuilderA64 +{ +public: + explicit AssemblyBuilderA64(bool logText); + ~AssemblyBuilderA64(); + + // Moves + void mov(RegisterA64 dst, RegisterA64 src); + void mov(RegisterA64 dst, uint16_t src, int shift = 0); + void movk(RegisterA64 dst, uint16_t src, int shift = 0); + + // Arithmetics + void add(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, int shift = 0); + void add(RegisterA64 dst, RegisterA64 src1, int src2); + void sub(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, int shift = 0); + void sub(RegisterA64 dst, RegisterA64 src1, int src2); + void neg(RegisterA64 dst, RegisterA64 src); + + // Comparisons + // Note: some arithmetic instructions also have versions that update flags (ADDS etc) but we aren't using them atm + // TODO: add cmp + + // Binary + // Note: shifted-register support and bitfield operations are omitted for simplicity + // TODO: support immediate arguments (they have odd encoding and forbid many values) + // TODO: support not variants for and/or/eor (required to support not...) + void and_(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); + void orr(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); + void eor(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); + void lsl(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); + void lsr(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); + void asr(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); + void ror(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); + void clz(RegisterA64 dst, RegisterA64 src); + void rbit(RegisterA64 dst, RegisterA64 src); + + // Load + // Note: paired loads are currently omitted for simplicity + void ldr(RegisterA64 dst, AddressA64 src); + void ldrb(RegisterA64 dst, AddressA64 src); + void ldrh(RegisterA64 dst, AddressA64 src); + void ldrsb(RegisterA64 dst, AddressA64 src); + void ldrsh(RegisterA64 dst, AddressA64 src); + void ldrsw(RegisterA64 dst, AddressA64 src); + + // Store + void str(RegisterA64 src, AddressA64 dst); + void strb(RegisterA64 src, AddressA64 dst); + void strh(RegisterA64 src, AddressA64 dst); + + // Control flow + // Note: tbz/tbnz are currently not supported because they have 15-bit offsets and we don't support branch thunks + void b(ConditionA64 cond, Label& label); + void cbz(RegisterA64 src, Label& label); + void cbnz(RegisterA64 src, Label& label); + void ret(); + + // Run final checks + bool finalize(); + + // Places a label at current location and returns it + Label setLabel(); + + // Assigns label position to the current location + void setLabel(Label& label); + + void logAppend(const char* fmt, ...) LUAU_PRINTF_ATTR(2, 3); + + uint32_t getCodeSize() const; + + // Resulting data and code that need to be copied over one after the other + // The *end* of 'data' has to be aligned to 16 bytes, this will also align 'code' + std::vector data; + std::vector code; + + std::string text; + + const bool logText = false; + +private: + // Instruction archetypes + void place0(const char* name, uint32_t word); + void placeSR3(const char* name, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, uint8_t op, int shift = 0); + void placeSR2(const char* name, RegisterA64 dst, RegisterA64 src, uint8_t op); + void placeR3(const char* name, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, uint8_t op, uint8_t op2); + void placeR1(const char* name, RegisterA64 dst, RegisterA64 src, uint32_t op); + void placeI12(const char* name, RegisterA64 dst, RegisterA64 src1, int src2, uint8_t op); + void placeI16(const char* name, RegisterA64 dst, int src, uint8_t op, int shift = 0); + void placeA(const char* name, RegisterA64 dst, AddressA64 src, uint8_t op, uint8_t size); + void placeBC(const char* name, Label& label, uint8_t op, uint8_t cond); + void placeBR(const char* name, Label& label, uint8_t op, RegisterA64 cond); + + void place(uint32_t word); + void placeLabel(Label& label); + + void commit(); + LUAU_NOINLINE void extend(); + + // Data + size_t allocateData(size_t size, size_t align); + + // Logging of assembly in text form + LUAU_NOINLINE void log(const char* opcode); + LUAU_NOINLINE void log(const char* opcode, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, int shift = 0); + LUAU_NOINLINE void log(const char* opcode, RegisterA64 dst, RegisterA64 src1, int src2); + LUAU_NOINLINE void log(const char* opcode, RegisterA64 dst, RegisterA64 src); + LUAU_NOINLINE void log(const char* opcode, RegisterA64 dst, int src, int shift = 0); + LUAU_NOINLINE void log(const char* opcode, RegisterA64 dst, AddressA64 src); + LUAU_NOINLINE void log(const char* opcode, RegisterA64 src, Label label); + LUAU_NOINLINE void log(const char* opcode, Label label); + LUAU_NOINLINE void log(Label label); + LUAU_NOINLINE void log(RegisterA64 reg); + LUAU_NOINLINE void log(AddressA64 addr); + + uint32_t nextLabel = 1; + std::vector(a) -> a' could not be converted into '(number) -> number'; different number of generic type parameters)"); -// // this error message is not great since the underlying issue is that the context is invariant, + // LUAU_REQUIRE_ERRORS(result); + // CHECK_EQ(toString(result.errors[0]), R"(Type 't' could not be converted into '{| m: (number) -> number |}' + // caused by: + // Property 'm' is not compatible. Type '(a) -> a' could not be converted into '(number) -> number'; different number of generic type + // parameters)"); + // // this error message is not great since the underlying issue is that the context is invariant, // and `(number) -> number` cannot be a subtype of `(a) -> a`. } @@ -3335,10 +3335,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "tables_should_be_fully_populated") ToStringOptions opts; opts.exhaustive = true; - if (FFlag::LuauSpecialTypesAsterisked) - CHECK_EQ("{ x: *error-type*, y: number }", toString(requireType("t"), opts)); - else - CHECK_EQ("{ x: , y: number }", toString(requireType("t"), opts)); + CHECK_EQ("{ x: *error-type*, y: number }", toString(requireType("t"), opts)); } TEST_SUITE_END(); diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index dff9649fb..6c7201a64 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -17,7 +17,6 @@ LUAU_FASTFLAG(LuauFixLocationSpanTableIndexExpr); LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); LUAU_FASTFLAG(LuauInstantiateInSubtyping); -LUAU_FASTFLAG(LuauSpecialTypesAsterisked); using namespace Luau; @@ -238,20 +237,10 @@ TEST_CASE_FIXTURE(Fixture, "type_errors_infer_types") // TODO: Should we assert anything about these tests when DCR is being used? if (!FFlag::DebugLuauDeferredConstraintResolution) { - if (FFlag::LuauSpecialTypesAsterisked) - { - CHECK_EQ("*error-type*", toString(requireType("c"))); - CHECK_EQ("*error-type*", toString(requireType("d"))); - CHECK_EQ("*error-type*", toString(requireType("e"))); - CHECK_EQ("*error-type*", toString(requireType("f"))); - } - else - { - CHECK_EQ("", toString(requireType("c"))); - CHECK_EQ("", toString(requireType("d"))); - CHECK_EQ("", toString(requireType("e"))); - CHECK_EQ("", toString(requireType("f"))); - } + CHECK_EQ("*error-type*", toString(requireType("c"))); + CHECK_EQ("*error-type*", toString(requireType("d"))); + CHECK_EQ("*error-type*", toString(requireType("e"))); + CHECK_EQ("*error-type*", toString(requireType("f"))); } } @@ -662,10 +651,7 @@ TEST_CASE_FIXTURE(Fixture, "no_stack_overflow_from_isoptional") std::optional t0 = getMainModule()->getModuleScope()->lookupType("t0"); REQUIRE(t0); - if (FFlag::LuauSpecialTypesAsterisked) - CHECK_EQ("*error-type*", toString(t0->type)); - else - CHECK_EQ("", toString(t0->type)); + CHECK_EQ("*error-type*", toString(t0->type)); auto it = std::find_if(result.errors.begin(), result.errors.end(), [](TypeError& err) { return get(err); diff --git a/tests/TypeInfer.tryUnify.test.cpp b/tests/TypeInfer.tryUnify.test.cpp index c178d2a4e..f04a3d950 100644 --- a/tests/TypeInfer.tryUnify.test.cpp +++ b/tests/TypeInfer.tryUnify.test.cpp @@ -9,8 +9,6 @@ using namespace Luau; -LUAU_FASTFLAG(LuauSpecialTypesAsterisked) - struct TryUnifyFixture : Fixture { TypeArena arena; @@ -124,10 +122,7 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "members_of_failed_typepack_unification_are_u LUAU_REQUIRE_ERROR_COUNT(1, result); CHECK_EQ("a", toString(requireType("a"))); - if (FFlag::LuauSpecialTypesAsterisked) - CHECK_EQ("*error-type*", toString(requireType("b"))); - else - CHECK_EQ("", toString(requireType("b"))); + CHECK_EQ("*error-type*", toString(requireType("b"))); } TEST_CASE_FIXTURE(TryUnifyFixture, "result_of_failed_typepack_unification_is_constrained") @@ -142,10 +137,7 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "result_of_failed_typepack_unification_is_con LUAU_REQUIRE_ERROR_COUNT(1, result); CHECK_EQ("a", toString(requireType("a"))); - if (FFlag::LuauSpecialTypesAsterisked) - CHECK_EQ("*error-type*", toString(requireType("b"))); - else - CHECK_EQ("", toString(requireType("b"))); + CHECK_EQ("*error-type*", toString(requireType("b"))); CHECK_EQ("number", toString(requireType("c"))); } diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp index dc5516345..0c25386f7 100644 --- a/tests/TypeInfer.unionTypes.test.cpp +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -6,8 +6,6 @@ #include "doctest.h" -LUAU_FASTFLAG(LuauSpecialTypesAsterisked) - using namespace Luau; TEST_SUITE_BEGIN("UnionTypes"); @@ -199,10 +197,7 @@ TEST_CASE_FIXTURE(Fixture, "index_on_a_union_type_with_missing_property") CHECK_EQ(mup->missing[0], *bTy); CHECK_EQ(mup->key, "x"); - if (FFlag::LuauSpecialTypesAsterisked) - CHECK_EQ("*error-type*", toString(requireType("r"))); - else - CHECK_EQ("", toString(requireType("r"))); + CHECK_EQ("*error-type*", toString(requireType("r"))); } TEST_CASE_FIXTURE(Fixture, "index_on_a_union_type_with_one_property_of_type_any") diff --git a/tests/Variant.test.cpp b/tests/Variant.test.cpp index aa0731ca4..83eec519a 100644 --- a/tests/Variant.test.cpp +++ b/tests/Variant.test.cpp @@ -217,4 +217,35 @@ TEST_CASE("Visit") CHECK(r3 == "1231147"); } +struct MoveOnly +{ + MoveOnly() = default; + + MoveOnly(const MoveOnly&) = delete; + MoveOnly& operator=(const MoveOnly&) = delete; + + MoveOnly(MoveOnly&&) = default; + MoveOnly& operator=(MoveOnly&&) = default; +}; + +TEST_CASE("Move") +{ + Variant v1 = MoveOnly{}; + Variant v2 = std::move(v1); +} + +TEST_CASE("MoveWithCopyableAlternative") +{ + Variant v1 = std::string{"Hello, world! I am longer than a normal hello world string to avoid SSO."}; + Variant v2 = std::move(v1); + + std::string* s1 = get_if(&v1); + REQUIRE(s1); + CHECK(*s1 == ""); + + std::string* s2 = get_if(&v2); + REQUIRE(s2); + CHECK(*s2 == "Hello, world! I am longer than a normal hello world string to avoid SSO."); +} + TEST_SUITE_END(); diff --git a/tools/faillist.txt b/tools/faillist.txt index a4c05b7bf..4ac2b357d 100644 --- a/tools/faillist.txt +++ b/tools/faillist.txt @@ -10,7 +10,9 @@ AnnotationTests.too_many_type_params AnnotationTests.two_type_params AnnotationTests.unknown_type_reference_generates_error AstQuery.last_argument_function_call_type +AstQuery::getDocumentationSymbolAtPosition.overloaded_class_method AstQuery::getDocumentationSymbolAtPosition.overloaded_fn +AstQuery::getDocumentationSymbolAtPosition.table_overloaded_function_prop AutocompleteTest.autocomplete_first_function_arg_expected_type AutocompleteTest.autocomplete_interpolated_string AutocompleteTest.autocomplete_oop_implicit_self @@ -106,17 +108,14 @@ GenericsTests.correctly_instantiate_polymorphic_member_functions GenericsTests.do_not_infer_generic_functions GenericsTests.duplicate_generic_type_packs GenericsTests.duplicate_generic_types -GenericsTests.factories_of_generics GenericsTests.generic_argument_count_too_few GenericsTests.generic_argument_count_too_many GenericsTests.generic_factories -GenericsTests.generic_functions_in_types GenericsTests.generic_functions_should_be_memory_safe GenericsTests.generic_table_method GenericsTests.generic_type_pack_parentheses GenericsTests.generic_type_pack_unification1 GenericsTests.generic_type_pack_unification2 -GenericsTests.generic_type_pack_unification3 GenericsTests.higher_rank_polymorphism_should_not_accept_instantiated_arguments GenericsTests.infer_generic_function_function_argument GenericsTests.infer_generic_function_function_argument_overloaded @@ -172,7 +171,6 @@ ProvisionalTests.table_insert_with_a_singleton_argument ProvisionalTests.typeguard_inference_incomplete ProvisionalTests.weirditer_should_not_loop_forever ProvisionalTests.while_body_are_also_refined -RefinementTest.and_constraint RefinementTest.apply_refinements_on_astexprindexexpr_whose_subscript_expr_is_constant_string RefinementTest.assert_a_to_be_truthy_then_assert_a_to_be_number RefinementTest.assert_non_binary_expressions_actually_resolve_constraints @@ -187,28 +185,17 @@ RefinementTest.either_number_or_string RefinementTest.eliminate_subclasses_of_instance RefinementTest.falsiness_of_TruthyPredicate_narrows_into_nil RefinementTest.index_on_a_refined_property -RefinementTest.invert_is_truthy_constraint RefinementTest.invert_is_truthy_constraint_ifelse_expression -RefinementTest.is_truthy_constraint RefinementTest.is_truthy_constraint_ifelse_expression -RefinementTest.lvalue_is_not_nil RefinementTest.merge_should_be_fully_agnostic_of_hashmap_ordering -RefinementTest.narrow_boolean_to_true_or_false RefinementTest.narrow_property_of_a_bounded_variable RefinementTest.narrow_this_large_union RefinementTest.nonoptional_type_can_narrow_to_nil_if_sense_is_true -RefinementTest.not_a_and_not_b -RefinementTest.not_a_and_not_b2 -RefinementTest.not_and_constraint RefinementTest.not_t_or_some_prop_of_t -RefinementTest.or_predicate_with_truthy_predicates -RefinementTest.parenthesized_expressions_are_followed_through RefinementTest.refine_a_property_not_to_be_nil_through_an_intersection_table RefinementTest.refine_the_correct_types_opposite_of_when_a_is_not_number_or_string RefinementTest.refine_unknowns -RefinementTest.term_is_equal_to_an_lvalue RefinementTest.truthy_constraint_on_properties -RefinementTest.type_assertion_expr_carry_its_constraints RefinementTest.type_comparison_ifelse_expression RefinementTest.type_guard_can_filter_for_intersection_of_tables RefinementTest.type_guard_can_filter_for_overloaded_function @@ -271,6 +258,7 @@ TableTests.infer_indexer_from_value_property_in_literal TableTests.inferred_return_type_of_free_table TableTests.inferring_crazy_table_should_also_be_quick TableTests.instantiate_table_cloning_3 +TableTests.invariant_table_properties_means_instantiating_tables_in_assignment_is_unsound TableTests.invariant_table_properties_means_instantiating_tables_in_call_is_unsound TableTests.leaking_bad_metatable_errors TableTests.less_exponential_blowup_please @@ -315,7 +303,6 @@ TableTests.table_subtyping_with_missing_props_dont_report_multiple_errors TableTests.tables_get_names_from_their_locals TableTests.tc_member_function TableTests.tc_member_function_2 -TableTests.top_table_type TableTests.type_mismatch_on_massive_table_is_cut_short TableTests.unification_of_unions_in_a_self_referential_type TableTests.unifying_tables_shouldnt_uaf2 @@ -417,12 +404,12 @@ TypeInferFunctions.too_few_arguments_variadic TypeInferFunctions.too_few_arguments_variadic_generic TypeInferFunctions.too_few_arguments_variadic_generic2 TypeInferFunctions.too_many_arguments +TypeInferFunctions.too_many_arguments_error_location TypeInferFunctions.too_many_return_values TypeInferFunctions.too_many_return_values_in_parentheses TypeInferFunctions.too_many_return_values_no_function TypeInferFunctions.vararg_function_is_quantified TypeInferLoops.for_in_loop_error_on_factory_not_returning_the_right_amount_of_values -TypeInferLoops.for_in_loop_with_custom_iterator TypeInferLoops.for_in_loop_with_next TypeInferLoops.for_in_with_generic_next TypeInferLoops.for_in_with_just_one_iterator_is_ok @@ -430,7 +417,6 @@ TypeInferLoops.loop_iter_no_indexer_nonstrict TypeInferLoops.loop_iter_trailing_nil TypeInferLoops.unreachable_code_after_infinite_loop TypeInferLoops.varlist_declared_by_for_in_loop_should_be_free -TypeInferModules.bound_free_table_export_is_ok TypeInferModules.custom_require_global TypeInferModules.do_not_modify_imported_types TypeInferModules.module_type_conflict @@ -443,6 +429,7 @@ TypeInferOOP.dont_suggest_using_colon_rather_than_dot_if_it_wont_help_2 TypeInferOOP.dont_suggest_using_colon_rather_than_dot_if_not_defined_with_colon TypeInferOOP.inferring_hundreds_of_self_calls_should_not_suffocate_memory TypeInferOOP.methods_are_topologically_sorted +TypeInferOOP.object_constructor_can_refer_to_method_of_self TypeInferOperators.and_or_ternary TypeInferOperators.CallAndOrOfFunctions TypeInferOperators.cannot_compare_tables_that_do_not_have_the_same_metatable diff --git a/tools/lvmexecute_split.py b/tools/lvmexecute_split.py index 48d66cb08..f4a78960b 100644 --- a/tools/lvmexecute_split.py +++ b/tools/lvmexecute_split.py @@ -32,6 +32,9 @@ """ function = "" +signature = "" + +includeInsts = ["LOP_NEWCLOSURE", "LOP_NAMECALL", "LOP_FORGPREP", "LOP_GETVARARGS", "LOP_DUPCLOSURE", "LOP_PREPVARARGS", "LOP_COVERAGE", "LOP_BREAK", "LOP_GETGLOBAL", "LOP_SETGLOBAL", "LOP_GETTABLEKS", "LOP_SETTABLEKS"] state = 0 @@ -44,7 +47,6 @@ if match: inst = match[1] signature = "const Instruction* execute_" + inst + "(lua_State* L, const Instruction* pc, StkId base, TValue* k)" - header += signature + ";\n" function = signature + "\n" function += "{\n" function += " [[maybe_unused]] Closure* cl = clvalue(L->ci->func);\n" @@ -84,7 +86,10 @@ function = function[:-len(finalline)] function += " return pc;\n}\n" - source += function + "\n" + if inst in includeInsts: + header += signature + ";\n" + source += function + "\n" + state = 0 # skip LUA_CUSTOM_EXECUTION code blocks diff --git a/tools/stack-usage-reporter.py b/tools/stack-usage-reporter.py new file mode 100644 index 000000000..91e74887d --- /dev/null +++ b/tools/stack-usage-reporter.py @@ -0,0 +1,173 @@ +#!/usr/bin/python +# This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +# The purpose of this script is to analyze disassembly generated by objdump or +# dumpbin to print (or to compare) the stack usage of functions/methods. +# This is a quickly written script, so it is quite possible it may not handle +# all code properly. +# +# The script expects the user to create a text assembly dump to be passed to +# the script. +# +# objdump Example +# objdump --demangle --disassemble objfile.o > objfile.s +# +# dumpbin Example +# dumpbin /disasm objfile.obj > objfile.s +# +# If the script is passed a single file, then all stack size information that +# is found it printed. If two files are passed, then the script compares the +# stack usage of the two files (useful for A/B comparisons). +# Currently more than two input files are not supported. (But adding support shouldn't +# be very difficult.) +# +# Note: The script only handles x64 disassembly. Supporting x86 is likely +# trivial, but ARM support could be difficult. +# Thus far the script has been tested with MSVC on Win64 and clang on OSX. + +import argparse +import re + +blank_re = re.compile('\s*') + +class LineReader: + def __init__(self, lines): + self.lines = list(reversed(lines)) + def get_line(self): + return self.lines.pop(-1) + def peek_line(self): + return self.lines[-1] + def consume_blank_lines(self): + while blank_re.fullmatch(self.peek_line()): + self.get_line() + def is_empty(self): + return len(self.lines) == 0 + +def parse_objdump_assembly(in_file): + results = {} + text_section_re = re.compile('Disassembly of section __TEXT,__text:\s*') + symbol_re = re.compile('[^<]*<(.*)>:\s*') + stack_alloc = re.compile('.*subq\s*\$(\d*), %rsp\s*') + + lr = LineReader(in_file.readlines()) + + def find_stack_alloc_size(): + while True: + if lr.is_empty(): + return None + if blank_re.fullmatch(lr.peek_line()): + return None + + line = lr.get_line() + mo = stack_alloc.fullmatch(line) + if mo: + lr.consume_blank_lines() + return int(mo.group(1)) + + # Find beginning of disassembly + while not text_section_re.fullmatch(lr.get_line()): + pass + + # Scan for symbols + while not lr.is_empty(): + lr.consume_blank_lines() + if lr.is_empty(): + break + line = lr.get_line() + mo = symbol_re.fullmatch(line) + # Found a symbol + if mo: + symbol = mo.group(1) + stack_size = find_stack_alloc_size() + if stack_size != None: + results[symbol] = stack_size + + return results + +def parse_dumpbin_assembly(in_file): + results = {} + + file_type_re = re.compile('File Type: COFF OBJECT\s*') + symbol_re = re.compile('[^(]*\((.*)\):\s*') + summary_re = re.compile('\s*Summary\s*') + stack_alloc = re.compile('.*sub\s*rsp,([A-Z0-9]*)h\s*') + + lr = LineReader(in_file.readlines()) + + def find_stack_alloc_size(): + while True: + if lr.is_empty(): + return None + if blank_re.fullmatch(lr.peek_line()): + return None + + line = lr.get_line() + mo = stack_alloc.fullmatch(line) + if mo: + lr.consume_blank_lines() + return int(mo.group(1), 16) # return value in decimal + + # Find beginning of disassembly + while not file_type_re.fullmatch(lr.get_line()): + pass + + # Scan for symbols + while not lr.is_empty(): + lr.consume_blank_lines() + if lr.is_empty(): + break + line = lr.get_line() + if summary_re.fullmatch(line): + break + mo = symbol_re.fullmatch(line) + # Found a symbol + if mo: + symbol = mo.group(1) + stack_size = find_stack_alloc_size() + if stack_size != None: + results[symbol] = stack_size + return results + +def main(): + parser = argparse.ArgumentParser(description='Tool used for reporting or comparing the stack usage of functions/methods') + parser.add_argument('--format', choices=['dumpbin', 'objdump'], required=True, help='Specifies the program used to generate the input files') + parser.add_argument('--input', action='append', required=True, help='Input assembly file. This option may be specified multiple times.') + parser.add_argument('--md-output', action='store_true', help='Show table output in markdown format') + parser.add_argument('--only-diffs', action='store_true', help='Only show stack info when it differs between the input files') + args = parser.parse_args() + + parsers = {'dumpbin': parse_dumpbin_assembly, 'objdump' : parse_objdump_assembly} + parse_func = parsers[args.format] + + input_results = [] + for input_name in args.input: + with open(input_name) as in_file: + results = parse_func(in_file) + input_results.append(results) + + if len(input_results) == 1: + # Print out the results sorted by size + size_sorted = sorted([(size, symbol) for symbol, size in results.items()], reverse=True) + print(input_name) + for size, symbol in size_sorted: + print(f'{size:10}\t{symbol}') + print() + elif len(input_results) == 2: + common_symbols = set(input_results[0].keys()).intersection(set(input_results[1].keys())) + print(f'Found {len(common_symbols)} common symbols') + stack_sizes = sorted([(input_results[0][sym], input_results[1][sym], sym) for sym in common_symbols], reverse=True) + if args.md_output: + print('Before | After | Symbol') + print('-- | -- | --') + for size0, size1, symbol in stack_sizes: + if args.only_diffs and size0 == size1: + continue + if args.md_output: + print(f'{size0} | {size1} | {symbol}') + else: + print(f'{size0:10}\t{size1:10}\t{symbol}') + else: + print("TODO support more than 2 inputs") + +if __name__ == '__main__': + main() From 3155ba0358abbf91d8ba71447b25c9ee44af598e Mon Sep 17 00:00:00 2001 From: Vyacheslav Egorov Date: Fri, 11 Nov 2022 00:04:44 +0200 Subject: [PATCH 16/66] Sync to upstream/release/553 --- .../include/Luau/ConstraintGraphBuilder.h | 20 +- Analysis/include/Luau/ConstraintSolver.h | 4 +- Analysis/include/Luau/Scope.h | 5 - Analysis/include/Luau/ToString.h | 2 - Analysis/include/Luau/TypeInfer.h | 4 +- Analysis/src/BuiltinDefinitions.cpp | 94 ++++++-- Analysis/src/ConstraintGraphBuilder.cpp | 189 ++++++++++++---- Analysis/src/ConstraintSolver.cpp | 108 +++++---- Analysis/src/EmbeddedBuiltinDefinitions.cpp | 8 +- Analysis/src/Frontend.cpp | 3 +- Analysis/src/Module.cpp | 14 -- Analysis/src/ToString.cpp | 110 ++++----- Analysis/src/TypeChecker2.cpp | 35 ++- Analysis/src/TypeInfer.cpp | 202 +++++++++++++---- Analysis/src/TypeVar.cpp | 204 +++++++++++++++-- Analysis/src/Unifier.cpp | 43 +++- Ast/src/Parser.cpp | 13 +- CodeGen/include/Luau/AddressA64.h | 7 +- CodeGen/include/Luau/AssemblyBuilderA64.h | 29 ++- CodeGen/src/AssemblyBuilderA64.cpp | 208 ++++++++++++++---- CodeGen/src/CodeGenX64.cpp | 70 +++--- CodeGen/src/EmitCommonX64.h | 8 +- CodeGen/src/UnwindBuilderDwarf2.cpp | 20 +- CodeGen/src/UnwindBuilderWin.cpp | 6 +- Common/include/Luau/ExperimentalFlags.h | 2 + tests/AssemblyBuilderA64.test.cpp | 65 +++++- tests/CodeAllocator.test.cpp | 9 +- tests/ConstraintGraphBuilderFixture.cpp | 4 +- tests/Frontend.test.cpp | 8 - tests/Parser.test.cpp | 17 ++ tests/Repl.test.cpp | 17 ++ tests/ToString.test.cpp | 71 +++--- tests/TypeInfer.aliases.test.cpp | 13 +- tests/TypeInfer.builtins.test.cpp | 6 +- tests/TypeInfer.functions.test.cpp | 2 +- tests/TypeInfer.operators.test.cpp | 83 ++++++- tests/TypeInfer.provisional.test.cpp | 26 +-- tests/TypeInfer.refinements.test.cpp | 189 ++++++++++++---- tests/TypeInfer.tables.test.cpp | 99 ++++++++- tests/TypeInfer.test.cpp | 37 ---- tests/TypeInfer.typePacks.cpp | 2 - tests/TypeInfer.unionTypes.test.cpp | 17 ++ tests/TypeInfer.unknownnever.test.cpp | 4 +- tools/faillist.txt | 32 +-- 44 files changed, 1531 insertions(+), 578 deletions(-) diff --git a/Analysis/include/Luau/ConstraintGraphBuilder.h b/Analysis/include/Luau/ConstraintGraphBuilder.h index cb5900ea9..2bfc62f17 100644 --- a/Analysis/include/Luau/ConstraintGraphBuilder.h +++ b/Analysis/include/Luau/ConstraintGraphBuilder.h @@ -66,6 +66,14 @@ struct ConstraintGraphBuilder // The root scope of the module we're generating constraints for. // This is null when the CGB is initially constructed. Scope* rootScope; + + // Constraints that go straight to the solver. + std::vector constraints; + + // Constraints that do not go to the solver right away. Other constraints + // will enqueue them during solving. + std::vector unqueuedConstraints; + // A mapping of AST node to TypeId. DenseHashMap astTypes{nullptr}; // A mapping of AST node to TypePackId. @@ -252,16 +260,8 @@ struct ConstraintGraphBuilder void prepopulateGlobalScope(const ScopePtr& globalScope, AstStatBlock* program); }; -/** - * Collects a vector of borrowed constraints from the scope and all its child - * scopes. It is important to only call this function when you're done adding - * constraints to the scope or its descendants, lest the borrowed pointers - * become invalid due to a container reallocation. - * @param rootScope the root scope of the scope graph to collect constraints - * from. - * @return a list of pointers to constraints contained within the scope graph. - * None of these pointers should be null. +/** Borrow a vector of pointers from a vector of owning pointers to constraints. */ -std::vector> collectConstraints(NotNull rootScope); +std::vector> borrowConstraints(const std::vector& constraints); } // namespace Luau diff --git a/Analysis/include/Luau/ConstraintSolver.h b/Analysis/include/Luau/ConstraintSolver.h index 07f027ad2..7b89a2781 100644 --- a/Analysis/include/Luau/ConstraintSolver.h +++ b/Analysis/include/Luau/ConstraintSolver.h @@ -76,8 +76,8 @@ struct ConstraintSolver DcrLogger* logger; - explicit ConstraintSolver(NotNull normalizer, NotNull rootScope, ModuleName moduleName, NotNull moduleResolver, - std::vector requireCycles, DcrLogger* logger); + explicit ConstraintSolver(NotNull normalizer, NotNull rootScope, std::vector> constraints, + ModuleName moduleName, NotNull moduleResolver, std::vector requireCycles, DcrLogger* logger); // Randomize the order in which to dispatch constraints void randomize(unsigned seed); diff --git a/Analysis/include/Luau/Scope.h b/Analysis/include/Luau/Scope.h index ccf2964ce..a26f506d6 100644 --- a/Analysis/include/Luau/Scope.h +++ b/Analysis/include/Luau/Scope.h @@ -38,11 +38,6 @@ struct Scope std::unordered_map bindings; TypePackId returnType; std::optional varargPack; - // All constraints belonging to this scope. - std::vector constraints; - // Constraints belonging to this scope that are queued manually by other - // constraints. - std::vector unqueuedConstraints; TypeLevel level; diff --git a/Analysis/include/Luau/ToString.h b/Analysis/include/Luau/ToString.h index ff2561e65..0200a7190 100644 --- a/Analysis/include/Luau/ToString.h +++ b/Analysis/include/Luau/ToString.h @@ -34,7 +34,6 @@ struct ToStringOptions size_t maxTableLength = size_t(FInt::LuauTableTypeMaximumStringifierLength); // Only applied to TableTypeVars size_t maxTypeLength = size_t(FInt::LuauTypeMaximumStringifierLength); ToStringNameMap nameMap; - std::optional DEPRECATED_nameMap; std::shared_ptr scope; // If present, module names will be added and types that are not available in scope will be marked as 'invalid' std::vector namedFunctionOverrideArgNames; // If present, named function argument names will be overridden }; @@ -42,7 +41,6 @@ struct ToStringOptions struct ToStringResult { std::string name; - ToStringNameMap DEPRECATED_nameMap; bool invalid = false; bool error = false; diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index c5d7501dc..4eaa59694 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -280,14 +280,14 @@ struct TypeChecker TypeId singletonType(bool value); TypeId singletonType(std::string value); - TypeIdPredicate mkTruthyPredicate(bool sense); + TypeIdPredicate mkTruthyPredicate(bool sense, TypeId emptySetTy); // TODO: Return TypeId only. std::optional filterMapImpl(TypeId type, TypeIdPredicate predicate); std::pair, bool> filterMap(TypeId type, TypeIdPredicate predicate); public: - std::pair, bool> pickTypesFromSense(TypeId type, bool sense); + std::pair, bool> pickTypesFromSense(TypeId type, bool sense, TypeId emptySetTy); private: TypeId unionOfTypes(TypeId a, TypeId b, const ScopePtr& scope, const Location& location, bool unifyFreeTypes = true); diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index ee53ae6b4..67e3979a7 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -17,7 +17,9 @@ LUAU_FASTFLAGVARIABLE(LuauSetMetaTableArgsCheck, false) LUAU_FASTFLAG(LuauUnknownAndNeverType) LUAU_FASTFLAGVARIABLE(LuauBuiltInMetatableNoBadSynthetic, false) +LUAU_FASTFLAG(LuauOptionalNextKey) LUAU_FASTFLAG(LuauReportShadowedTypeAlias) +LUAU_FASTFLAG(LuauNewLibraryTypeNames) /** FIXME: Many of these type definitions are not quite completely accurate. * @@ -276,18 +278,38 @@ void registerBuiltinGlobals(TypeChecker& typeChecker) addGlobalBinding(typeChecker, "string", it->second.type, "@luau"); - // next(t: Table, i: K?) -> (K, V) - TypePackId nextArgsTypePack = arena.addTypePack(TypePack{{mapOfKtoV, makeOption(typeChecker, arena, genericK)}}); - addGlobalBinding(typeChecker, "next", - arena.addType(FunctionTypeVar{{genericK, genericV}, {}, nextArgsTypePack, arena.addTypePack(TypePack{{genericK, genericV}})}), "@luau"); + if (FFlag::LuauOptionalNextKey) + { + // next(t: Table, i: K?) -> (K?, V) + TypePackId nextArgsTypePack = arena.addTypePack(TypePack{{mapOfKtoV, makeOption(typeChecker, arena, genericK)}}); + TypePackId nextRetsTypePack = arena.addTypePack(TypePack{{makeOption(typeChecker, arena, genericK), genericV}}); + addGlobalBinding(typeChecker, "next", arena.addType(FunctionTypeVar{{genericK, genericV}, {}, nextArgsTypePack, nextRetsTypePack}), "@luau"); + + TypePackId pairsArgsTypePack = arena.addTypePack({mapOfKtoV}); + + TypeId pairsNext = arena.addType(FunctionTypeVar{nextArgsTypePack, nextRetsTypePack}); + TypePackId pairsReturnTypePack = arena.addTypePack(TypePack{{pairsNext, mapOfKtoV, nilType}}); - TypePackId pairsArgsTypePack = arena.addTypePack({mapOfKtoV}); + // pairs(t: Table) -> ((Table, K?) -> (K, V), Table, nil) + addGlobalBinding( + typeChecker, "pairs", arena.addType(FunctionTypeVar{{genericK, genericV}, {}, pairsArgsTypePack, pairsReturnTypePack}), "@luau"); + } + else + { + // next(t: Table, i: K?) -> (K, V) + TypePackId nextArgsTypePack = arena.addTypePack(TypePack{{mapOfKtoV, makeOption(typeChecker, arena, genericK)}}); + addGlobalBinding(typeChecker, "next", + arena.addType(FunctionTypeVar{{genericK, genericV}, {}, nextArgsTypePack, arena.addTypePack(TypePack{{genericK, genericV}})}), "@luau"); - TypeId pairsNext = arena.addType(FunctionTypeVar{nextArgsTypePack, arena.addTypePack(TypePack{{genericK, genericV}})}); - TypePackId pairsReturnTypePack = arena.addTypePack(TypePack{{pairsNext, mapOfKtoV, nilType}}); + TypePackId pairsArgsTypePack = arena.addTypePack({mapOfKtoV}); - // pairs(t: Table) -> ((Table, K?) -> (K, V), Table, nil) - addGlobalBinding(typeChecker, "pairs", arena.addType(FunctionTypeVar{{genericK, genericV}, {}, pairsArgsTypePack, pairsReturnTypePack}), "@luau"); + TypeId pairsNext = arena.addType(FunctionTypeVar{nextArgsTypePack, arena.addTypePack(TypePack{{genericK, genericV}})}); + TypePackId pairsReturnTypePack = arena.addTypePack(TypePack{{pairsNext, mapOfKtoV, nilType}}); + + // pairs(t: Table) -> ((Table, K?) -> (K, V), Table, nil) + addGlobalBinding( + typeChecker, "pairs", arena.addType(FunctionTypeVar{{genericK, genericV}, {}, pairsArgsTypePack, pairsReturnTypePack}), "@luau"); + } TypeId genericMT = arena.addType(GenericTypeVar{"MT"}); @@ -319,7 +341,12 @@ void registerBuiltinGlobals(TypeChecker& typeChecker) if (TableTypeVar* ttv = getMutable(pair.second.typeId)) { if (!ttv->name) - ttv->name = toString(pair.first); + { + if (FFlag::LuauNewLibraryTypeNames) + ttv->name = "typeof(" + toString(pair.first) + ")"; + else + ttv->name = toString(pair.first); + } } } @@ -370,18 +397,38 @@ void registerBuiltinGlobals(Frontend& frontend) addGlobalBinding(frontend, "string", it->second.type, "@luau"); - // next(t: Table, i: K?) -> (K, V) - TypePackId nextArgsTypePack = arena.addTypePack(TypePack{{mapOfKtoV, makeOption(frontend, arena, genericK)}}); - addGlobalBinding(frontend, "next", - arena.addType(FunctionTypeVar{{genericK, genericV}, {}, nextArgsTypePack, arena.addTypePack(TypePack{{genericK, genericV}})}), "@luau"); + if (FFlag::LuauOptionalNextKey) + { + // next(t: Table, i: K?) -> (K?, V) + TypePackId nextArgsTypePack = arena.addTypePack(TypePack{{mapOfKtoV, makeOption(frontend, arena, genericK)}}); + TypePackId nextRetsTypePack = arena.addTypePack(TypePack{{makeOption(frontend, arena, genericK), genericV}}); + addGlobalBinding(frontend, "next", arena.addType(FunctionTypeVar{{genericK, genericV}, {}, nextArgsTypePack, nextRetsTypePack}), "@luau"); + + TypePackId pairsArgsTypePack = arena.addTypePack({mapOfKtoV}); + + TypeId pairsNext = arena.addType(FunctionTypeVar{nextArgsTypePack, nextRetsTypePack}); + TypePackId pairsReturnTypePack = arena.addTypePack(TypePack{{pairsNext, mapOfKtoV, frontend.singletonTypes->nilType}}); - TypePackId pairsArgsTypePack = arena.addTypePack({mapOfKtoV}); + // pairs(t: Table) -> ((Table, K?) -> (K?, V), Table, nil) + addGlobalBinding( + frontend, "pairs", arena.addType(FunctionTypeVar{{genericK, genericV}, {}, pairsArgsTypePack, pairsReturnTypePack}), "@luau"); + } + else + { + // next(t: Table, i: K?) -> (K, V) + TypePackId nextArgsTypePack = arena.addTypePack(TypePack{{mapOfKtoV, makeOption(frontend, arena, genericK)}}); + addGlobalBinding(frontend, "next", + arena.addType(FunctionTypeVar{{genericK, genericV}, {}, nextArgsTypePack, arena.addTypePack(TypePack{{genericK, genericV}})}), "@luau"); - TypeId pairsNext = arena.addType(FunctionTypeVar{nextArgsTypePack, arena.addTypePack(TypePack{{genericK, genericV}})}); - TypePackId pairsReturnTypePack = arena.addTypePack(TypePack{{pairsNext, mapOfKtoV, frontend.singletonTypes->nilType}}); + TypePackId pairsArgsTypePack = arena.addTypePack({mapOfKtoV}); - // pairs(t: Table) -> ((Table, K?) -> (K, V), Table, nil) - addGlobalBinding(frontend, "pairs", arena.addType(FunctionTypeVar{{genericK, genericV}, {}, pairsArgsTypePack, pairsReturnTypePack}), "@luau"); + TypeId pairsNext = arena.addType(FunctionTypeVar{nextArgsTypePack, arena.addTypePack(TypePack{{genericK, genericV}})}); + TypePackId pairsReturnTypePack = arena.addTypePack(TypePack{{pairsNext, mapOfKtoV, frontend.singletonTypes->nilType}}); + + // pairs(t: Table) -> ((Table, K?) -> (K, V), Table, nil) + addGlobalBinding( + frontend, "pairs", arena.addType(FunctionTypeVar{{genericK, genericV}, {}, pairsArgsTypePack, pairsReturnTypePack}), "@luau"); + } TypeId genericMT = arena.addType(GenericTypeVar{"MT"}); @@ -413,7 +460,12 @@ void registerBuiltinGlobals(Frontend& frontend) if (TableTypeVar* ttv = getMutable(pair.second.typeId)) { if (!ttv->name) - ttv->name = toString(pair.first); + { + if (FFlag::LuauNewLibraryTypeNames) + ttv->name = "typeof(" + toString(pair.first) + ")"; + else + ttv->name = toString(pair.first); + } } } @@ -623,7 +675,7 @@ static std::optional> magicFunctionAssert( if (head.size() > 0) { - auto [ty, ok] = typechecker.pickTypesFromSense(head[0], true); + auto [ty, ok] = typechecker.pickTypesFromSense(head[0], true, typechecker.singletonTypes->nilType); if (FFlag::LuauUnknownAndNeverType) { if (get(*ty)) diff --git a/Analysis/src/ConstraintGraphBuilder.cpp b/Analysis/src/ConstraintGraphBuilder.cpp index 79a69ca47..42dc07f6d 100644 --- a/Analysis/src/ConstraintGraphBuilder.cpp +++ b/Analysis/src/ConstraintGraphBuilder.cpp @@ -52,6 +52,70 @@ static bool matchSetmetatable(const AstExprCall& call) return true; } +struct TypeGuard +{ + bool isTypeof; + AstExpr* target; + std::string type; +}; + +static std::optional matchTypeGuard(const AstExprBinary* binary) +{ + if (binary->op != AstExprBinary::CompareEq && binary->op != AstExprBinary::CompareNe) + return std::nullopt; + + AstExpr* left = binary->left; + AstExpr* right = binary->right; + if (right->is()) + std::swap(left, right); + + if (!right->is()) + return std::nullopt; + + AstExprCall* call = left->as(); + AstExprConstantString* string = right->as(); + if (!call || !string) + return std::nullopt; + + AstExprGlobal* callee = call->func->as(); + if (!callee) + return std::nullopt; + + if (callee->name != "type" && callee->name != "typeof") + return std::nullopt; + + if (call->args.size != 1) + return std::nullopt; + + return TypeGuard{ + /*isTypeof*/ callee->name == "typeof", + /*target*/ call->args.data[0], + /*type*/ std::string(string->value.data, string->value.size), + }; +} + +namespace +{ + +struct Checkpoint +{ + size_t offset; +}; + +Checkpoint checkpoint(const ConstraintGraphBuilder* cgb) +{ + return Checkpoint{cgb->constraints.size()}; +} + +template +void forEachConstraint(const Checkpoint& start, const Checkpoint& end, const ConstraintGraphBuilder* cgb, F f) +{ + for (size_t i = start.offset; i < end.offset; ++i) + f(cgb->constraints[i]); +} + +} // namespace + ConstraintGraphBuilder::ConstraintGraphBuilder(const ModuleName& moduleName, ModulePtr module, TypeArena* arena, NotNull moduleResolver, NotNull singletonTypes, NotNull ice, const ScopePtr& globalScope, DcrLogger* logger, NotNull dfg) @@ -99,12 +163,12 @@ ScopePtr ConstraintGraphBuilder::childScope(AstNode* node, const ScopePtr& paren NotNull ConstraintGraphBuilder::addConstraint(const ScopePtr& scope, const Location& location, ConstraintV cv) { - return NotNull{scope->constraints.emplace_back(new Constraint{NotNull{scope.get()}, location, std::move(cv)}).get()}; + return NotNull{constraints.emplace_back(new Constraint{NotNull{scope.get()}, location, std::move(cv)}).get()}; } NotNull ConstraintGraphBuilder::addConstraint(const ScopePtr& scope, std::unique_ptr c) { - return NotNull{scope->constraints.emplace_back(std::move(c)).get()}; + return NotNull{constraints.emplace_back(std::move(c)).get()}; } static void unionRefinements(const std::unordered_map& lhs, const std::unordered_map& rhs, @@ -476,6 +540,9 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatForIn* forIn) TypeId ty = freshType(loopScope); loopScope->bindings[var] = Binding{ty, var->location}; variableTypes.push_back(ty); + + if (auto def = dfg->getDef(var)) + loopScope->dcrRefinements[*def] = ty; } // It is always ok to provide too few variables, so we give this pack a free tail. @@ -506,20 +573,6 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatRepeat* repeat) check(repeatScope, repeat->condition); } -void addConstraints(Constraint* constraint, NotNull scope) -{ - scope->constraints.reserve(scope->constraints.size() + scope->constraints.size()); - - for (const auto& c : scope->constraints) - constraint->dependencies.push_back(NotNull{c.get()}); - - for (const auto& c : scope->unqueuedConstraints) - constraint->dependencies.push_back(NotNull{c.get()}); - - for (NotNull childScope : scope->children) - addConstraints(constraint, childScope); -} - void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocalFunction* function) { // Local @@ -537,12 +590,17 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocalFunction* FunctionSignature sig = checkFunctionSignature(scope, function->func); sig.bodyScope->bindings[function->name] = Binding{sig.signature, function->func->location}; + auto start = checkpoint(this); checkFunctionBody(sig.bodyScope, function->func); + auto end = checkpoint(this); NotNull constraintScope{sig.signatureScope ? sig.signatureScope.get() : sig.bodyScope.get()}; std::unique_ptr c = std::make_unique(constraintScope, function->name->location, GeneralizationConstraint{functionType, sig.signature}); - addConstraints(c.get(), NotNull(sig.bodyScope.get())); + + forEachConstraint(start, end, this, [&c](const ConstraintPtr& constraint) { + c->dependencies.push_back(NotNull{constraint.get()}); + }); addConstraint(scope, std::move(c)); } @@ -610,12 +668,17 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFunction* funct LUAU_ASSERT(functionType != nullptr); + auto start = checkpoint(this); checkFunctionBody(sig.bodyScope, function->func); + auto end = checkpoint(this); NotNull constraintScope{sig.signatureScope ? sig.signatureScope.get() : sig.bodyScope.get()}; std::unique_ptr c = std::make_unique(constraintScope, function->name->location, GeneralizationConstraint{functionType, sig.signature}); - addConstraints(c.get(), NotNull(sig.bodyScope.get())); + + forEachConstraint(start, end, this, [&c](const ConstraintPtr& constraint) { + c->dependencies.push_back(NotNull{constraint.get()}); + }); addConstraint(scope, std::move(c)); } @@ -947,8 +1010,7 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExpr* InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCall* call, const std::vector& expectedTypes) { TypeId fnType = check(scope, call->func).ty; - const size_t constraintIndex = scope->constraints.size(); - const size_t scopeIndex = scopes.size(); + auto startCheckpoint = checkpoint(this); std::vector args; @@ -977,8 +1039,7 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCa } else { - const size_t constraintEndIndex = scope->constraints.size(); - const size_t scopeEndIndex = scopes.size(); + auto endCheckpoint = checkpoint(this); astOriginalCallTypes[call->func] = fnType; @@ -989,29 +1050,22 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCa FunctionTypeVar ftv(TypeLevel{}, scope.get(), argPack, rets); TypeId inferredFnType = arena->addType(ftv); - scope->unqueuedConstraints.push_back( + unqueuedConstraints.push_back( std::make_unique(NotNull{scope.get()}, call->func->location, InstantiationConstraint{instantiatedType, fnType})); - NotNull ic(scope->unqueuedConstraints.back().get()); + NotNull ic(unqueuedConstraints.back().get()); - scope->unqueuedConstraints.push_back( + unqueuedConstraints.push_back( std::make_unique(NotNull{scope.get()}, call->func->location, SubtypeConstraint{inferredFnType, instantiatedType})); - NotNull sc(scope->unqueuedConstraints.back().get()); + NotNull sc(unqueuedConstraints.back().get()); // We force constraints produced by checking function arguments to wait // until after we have resolved the constraint on the function itself. // This ensures, for instance, that we start inferring the contents of // lambdas under the assumption that their arguments and return types // will be compatible with the enclosing function call. - for (size_t ci = constraintIndex; ci < constraintEndIndex; ++ci) - scope->constraints[ci]->dependencies.push_back(sc); - - for (size_t si = scopeIndex; si < scopeEndIndex; ++si) - { - for (auto& c : scopes[si].second->constraints) - { - c->dependencies.push_back(sc); - } - } + forEachConstraint(startCheckpoint, endCheckpoint, this, [sc](const ConstraintPtr& constraint) { + constraint->dependencies.push_back(sc); + }); addConstraint(scope, call->func->location, FunctionCallConstraint{ @@ -1283,6 +1337,54 @@ std::tuple ConstraintGraphBuilder::checkBinary( return {leftType, rightType, connectiveArena.disjunction(leftConnective, rightConnective)}; } + else if (auto typeguard = matchTypeGuard(binary)) + { + TypeId leftType = check(scope, binary->left).ty; + TypeId rightType = check(scope, binary->right).ty; + + std::optional def = dfg->getDef(typeguard->target); + if (!def) + return {leftType, rightType, nullptr}; + + TypeId discriminantTy = singletonTypes->neverType; + if (typeguard->type == "nil") + discriminantTy = singletonTypes->nilType; + else if (typeguard->type == "string") + discriminantTy = singletonTypes->stringType; + else if (typeguard->type == "number") + discriminantTy = singletonTypes->numberType; + else if (typeguard->type == "boolean") + discriminantTy = singletonTypes->threadType; + else if (typeguard->type == "table") + discriminantTy = singletonTypes->neverType; // TODO: replace with top table type + else if (typeguard->type == "function") + discriminantTy = singletonTypes->functionType; + else if (typeguard->type == "userdata") + { + // For now, we don't really care about being accurate with userdata if the typeguard was using typeof + discriminantTy = singletonTypes->neverType; // TODO: replace with top class type + } + else if (!typeguard->isTypeof && typeguard->type == "vector") + discriminantTy = singletonTypes->neverType; // TODO: figure out a way to deal with this quirky type + else if (!typeguard->isTypeof) + discriminantTy = singletonTypes->neverType; + else if (auto typeFun = globalScope->lookupType(typeguard->type); typeFun && typeFun->typeParams.empty() && typeFun->typePackParams.empty()) + { + TypeId ty = follow(typeFun->type); + + // We're only interested in the root class of any classes. + if (auto ctv = get(ty); !ctv || !ctv->parent) + discriminantTy = ty; + } + + ConnectiveId proposition = connectiveArena.proposition(*def, discriminantTy); + if (binary->op == AstExprBinary::CompareEq) + return {leftType, rightType, proposition}; + else if (binary->op == AstExprBinary::CompareNe) + return {leftType, rightType, connectiveArena.negation(proposition)}; + else + ice->ice("matchTypeGuard should only return a Some under `==` or `~=`!"); + } else if (binary->op == AstExprBinary::CompareEq || binary->op == AstExprBinary::CompareNe) { TypeId leftType = check(scope, binary->left, expectedType, true).ty; @@ -2066,19 +2168,14 @@ void ConstraintGraphBuilder::prepopulateGlobalScope(const ScopePtr& globalScope, program->visit(&gp); } -void collectConstraints(std::vector>& result, NotNull scope) +std::vector> borrowConstraints(const std::vector& constraints) { - for (const auto& c : scope->constraints) - result.push_back(NotNull{c.get()}); + std::vector> result; + result.reserve(constraints.size()); - for (NotNull child : scope->children) - collectConstraints(result, child); -} + for (const auto& c : constraints) + result.emplace_back(c.get()); -std::vector> collectConstraints(NotNull rootScope) -{ - std::vector> result; - collectConstraints(result, rootScope); return result; } diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index c53ac659a..533652e23 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -18,7 +18,6 @@ LUAU_FASTFLAGVARIABLE(DebugLuauLogSolver, false); LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJson, false); -LUAU_FASTFLAG(LuauFixNameMaps) namespace Luau { @@ -27,34 +26,14 @@ namespace Luau { for (const auto& [k, v] : scope->bindings) { - if (FFlag::LuauFixNameMaps) - { - auto d = toString(v.typeId, opts); - printf("\t%s : %s\n", k.c_str(), d.c_str()); - } - else - { - auto d = toStringDetailed(v.typeId, opts); - opts.DEPRECATED_nameMap = d.DEPRECATED_nameMap; - printf("\t%s : %s\n", k.c_str(), d.name.c_str()); - } + auto d = toString(v.typeId, opts); + printf("\t%s : %s\n", k.c_str(), d.c_str()); } for (NotNull child : scope->children) dumpBindings(child, opts); } -static void dumpConstraints(NotNull scope, ToStringOptions& opts) -{ - for (const ConstraintPtr& c : scope->constraints) - { - printf("\t%s\n", toString(*c, opts).c_str()); - } - - for (NotNull child : scope->children) - dumpConstraints(child, opts); -} - static std::pair, std::vector> saturateArguments(TypeArena* arena, NotNull singletonTypes, const TypeFun& fn, const std::vector& rawTypeArguments, const std::vector& rawPackArguments) { @@ -219,12 +198,6 @@ size_t HashInstantiationSignature::operator()(const InstantiationSignature& sign return hash; } -void dump(NotNull rootScope, ToStringOptions& opts) -{ - printf("constraints:\n"); - dumpConstraints(rootScope, opts); -} - void dump(ConstraintSolver* cs, ToStringOptions& opts) { printf("constraints:\n"); @@ -248,17 +221,17 @@ void dump(ConstraintSolver* cs, ToStringOptions& opts) if (auto fcc = get(*c)) { for (NotNull inner : fcc->innerConstraints) - printf("\t\t\t%s\n", toString(*inner, opts).c_str()); + printf("\t ->\t\t%s\n", toString(*inner, opts).c_str()); } } } -ConstraintSolver::ConstraintSolver(NotNull normalizer, NotNull rootScope, ModuleName moduleName, - NotNull moduleResolver, std::vector requireCycles, DcrLogger* logger) +ConstraintSolver::ConstraintSolver(NotNull normalizer, NotNull rootScope, std::vector> constraints, + ModuleName moduleName, NotNull moduleResolver, std::vector requireCycles, DcrLogger* logger) : arena(normalizer->arena) , singletonTypes(normalizer->singletonTypes) , normalizer(normalizer) - , constraints(collectConstraints(rootScope)) + , constraints(std::move(constraints)) , rootScope(rootScope) , currentModuleName(std::move(moduleName)) , moduleResolver(moduleResolver) @@ -267,7 +240,7 @@ ConstraintSolver::ConstraintSolver(NotNull normalizer, NotNull c : constraints) + for (NotNull c : this->constraints) { unsolvedConstraints.push_back(c); @@ -310,6 +283,8 @@ void ConstraintSolver::run() { printf("Starting solver\n"); dump(this, opts); + printf("Bindings:\n"); + dumpBindings(rootScope, opts); } if (FFlag::DebugLuauLogSolverToJson) @@ -633,13 +608,6 @@ bool ConstraintSolver::tryDispatch(const BinaryConstraint& c, NotNullty.emplace(unionOfTypes(rightType, singletonTypes->booleanType, constraint->scope, false)); + { + TypeId leftFilteredTy = arena->addType(IntersectionTypeVar{{singletonTypes->falsyType, leftType}}); + + // TODO: normaliztion here should be replaced by a more limited 'simplification' + const NormalizedType* normalized = normalizer->normalize(arena->addType(UnionTypeVar{{leftFilteredTy, rightType}})); + + if (!normalized) + { + reportError(CodeTooComplex{}, constraint->location); + asMutable(resultType)->ty.emplace(errorRecoveryType()); + } + else + { + asMutable(resultType)->ty.emplace(normalizer->typeFromNormal(*normalized)); + } + unblock(resultType); return true; + } // Or evaluates to the LHS type if the LHS is truthy, and the RHS type if // LHS is falsey. case AstExprBinary::Op::Or: - asMutable(resultType)->ty.emplace(unionOfTypes(rightType, leftType, constraint->scope, true)); + { + TypeId rightFilteredTy = arena->addType(IntersectionTypeVar{{singletonTypes->truthyType, leftType}}); + + // TODO: normaliztion here should be replaced by a more limited 'simplification' + const NormalizedType* normalized = normalizer->normalize(arena->addType(UnionTypeVar{{rightFilteredTy, rightType}})); + + if (!normalized) + { + reportError(CodeTooComplex{}, constraint->location); + asMutable(resultType)->ty.emplace(errorRecoveryType()); + } + else + { + asMutable(resultType)->ty.emplace(normalizer->typeFromNormal(*normalized)); + } + unblock(resultType); return true; + } default: iceReporter.ice("Unhandled AstExprBinary::Op for binary operation", constraint->location); break; @@ -1148,6 +1148,17 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNullsecond) + block(ic, blockedConstraint); + } + } + asMutable(*c.innerConstraints.at(0)).c = InstantiationConstraint{instantiatedType, *callMm}; asMutable(*c.innerConstraints.at(1)).c = SubtypeConstraint{inferredFnType, instantiatedType}; @@ -1180,6 +1191,17 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNullsecond) + block(ic, blockedConstraint); + } + } + unsolvedConstraints.insert(end(unsolvedConstraints), begin(c.innerConstraints), end(c.innerConstraints)); asMutable(c.result)->ty.emplace(constraint->scope); } diff --git a/Analysis/src/EmbeddedBuiltinDefinitions.cpp b/Analysis/src/EmbeddedBuiltinDefinitions.cpp index 339de9755..b0f21737f 100644 --- a/Analysis/src/EmbeddedBuiltinDefinitions.cpp +++ b/Analysis/src/EmbeddedBuiltinDefinitions.cpp @@ -2,6 +2,7 @@ #include "Luau/BuiltinDefinitions.h" LUAU_FASTFLAG(LuauUnknownAndNeverType) +LUAU_FASTFLAG(LuauOptionalNextKey) namespace Luau { @@ -126,7 +127,7 @@ declare function rawlen(obj: {[K]: V} | string): number declare function setfenv(target: number | (T...) -> R..., env: {[string]: any}): ((T...) -> R...)? -declare function ipairs(tab: {V}): (({V}, number) -> (number, V), {V}, number) +-- TODO: place ipairs definition here with removal of FFlagLuauOptionalNextKey declare function pcall(f: (A...) -> R..., ...: A...): (boolean, R...) @@ -207,6 +208,11 @@ std::string getBuiltinDefinitionSource() else result += "declare function error(message: T, level: number?)\n"; + if (FFlag::LuauOptionalNextKey) + result += "declare function ipairs(tab: {V}): (({V}, number) -> (number?, V), {V}, number)\n"; + else + result += "declare function ipairs(tab: {V}): (({V}, number) -> (number, V), {V}, number)\n"; + return result; } diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 39e6428d2..22a9ecfa3 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -955,7 +955,8 @@ ModulePtr Frontend::check( cgb.visit(sourceModule.root); result->errors = std::move(cgb.errors); - ConstraintSolver cs{NotNull{&normalizer}, NotNull(cgb.rootScope), sourceModule.name, NotNull(&moduleResolver), requireCycles, logger.get()}; + ConstraintSolver cs{NotNull{&normalizer}, NotNull(cgb.rootScope), borrowConstraints(cgb.constraints), sourceModule.name, NotNull(&moduleResolver), + requireCycles, logger.get()}; if (options.randomizeConstraintResolutionSeed) cs.randomize(*options.randomizeConstraintResolutionSeed); diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index 0412f0077..62674aa8e 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -14,9 +14,7 @@ #include -LUAU_FASTFLAG(LuauAnyifyModuleReturnGenerics) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); -LUAU_FASTFLAGVARIABLE(LuauForceExportSurfacesToBeNormal, false); LUAU_FASTFLAGVARIABLE(LuauClonePublicInterfaceLess, false); LUAU_FASTFLAG(LuauSubstitutionReentrant); LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution); @@ -222,18 +220,6 @@ void Module::clonePublicInterface(NotNull singletonTypes, Intern } } - if (!FFlag::LuauAnyifyModuleReturnGenerics) - { - for (TypeId ty : returnType) - { - if (get(follow(ty))) - { - auto t = asMutable(ty); - t->ty = AnyTypeVar{}; - } - } - } - for (auto& [name, ty] : declaredGlobals) { if (FFlag::LuauClonePublicInterfaceLess) diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index 48215f244..5062c3f73 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -13,7 +13,6 @@ LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) LUAU_FASTFLAG(LuauLvaluelessPath) LUAU_FASTFLAG(LuauUnknownAndNeverType) -LUAU_FASTFLAGVARIABLE(LuauFixNameMaps, false) LUAU_FASTFLAGVARIABLE(LuauFunctionReturnStringificationFixup, false) LUAU_FASTFLAGVARIABLE(LuauUnseeArrayTtv, false) @@ -130,28 +129,15 @@ struct StringifierState bool exhaustive; - StringifierState(ToStringOptions& opts, ToStringResult& result, const std::optional& DEPRECATED_nameMap) + StringifierState(ToStringOptions& opts, ToStringResult& result) : opts(opts) , result(result) , exhaustive(opts.exhaustive) { - if (!FFlag::LuauFixNameMaps && DEPRECATED_nameMap) - result.DEPRECATED_nameMap = *DEPRECATED_nameMap; - - if (!FFlag::LuauFixNameMaps) - { - for (const auto& [_, v] : result.DEPRECATED_nameMap.typeVars) - usedNames.insert(v); - for (const auto& [_, v] : result.DEPRECATED_nameMap.typePacks) - usedNames.insert(v); - } - else - { - for (const auto& [_, v] : opts.nameMap.typeVars) - usedNames.insert(v); - for (const auto& [_, v] : opts.nameMap.typePacks) - usedNames.insert(v); - } + for (const auto& [_, v] : opts.nameMap.typeVars) + usedNames.insert(v); + for (const auto& [_, v] : opts.nameMap.typePacks) + usedNames.insert(v); } bool hasSeen(const void* tv) @@ -174,8 +160,8 @@ struct StringifierState std::string getName(TypeId ty) { - const size_t s = FFlag::LuauFixNameMaps ? opts.nameMap.typeVars.size() : result.DEPRECATED_nameMap.typeVars.size(); - std::string& n = FFlag::LuauFixNameMaps ? opts.nameMap.typeVars[ty] : result.DEPRECATED_nameMap.typeVars[ty]; + const size_t s = opts.nameMap.typeVars.size(); + std::string& n = opts.nameMap.typeVars[ty]; if (!n.empty()) return n; @@ -197,8 +183,8 @@ struct StringifierState std::string getName(TypePackId ty) { - const size_t s = FFlag::LuauFixNameMaps ? opts.nameMap.typePacks.size() : result.DEPRECATED_nameMap.typePacks.size(); - std::string& n = FFlag::LuauFixNameMaps ? opts.nameMap.typePacks[ty] : result.DEPRECATED_nameMap.typePacks[ty]; + const size_t s = opts.nameMap.typePacks.size(); + std::string& n = opts.nameMap.typePacks[ty]; if (!n.empty()) return n; @@ -404,10 +390,7 @@ struct TypeVarStringifier if (gtv.explicitName) { state.usedNames.insert(gtv.name); - if (FFlag::LuauFixNameMaps) - state.opts.nameMap.typeVars[ty] = gtv.name; - else - state.result.DEPRECATED_nameMap.typeVars[ty] = gtv.name; + state.opts.nameMap.typeVars[ty] = gtv.name; state.emit(gtv.name); } else @@ -1000,10 +983,7 @@ struct TypePackStringifier if (pack.explicitName) { state.usedNames.insert(pack.name); - if (FFlag::LuauFixNameMaps) - state.opts.nameMap.typePacks[tp] = pack.name; - else - state.result.DEPRECATED_nameMap.typePacks[tp] = pack.name; + state.opts.nameMap.typePacks[tp] = pack.name; state.emit(pack.name); } else @@ -1104,8 +1084,7 @@ ToStringResult toStringDetailed(TypeId ty, ToStringOptions& opts) ToStringResult result; - StringifierState state = - FFlag::LuauFixNameMaps ? StringifierState{opts, result, opts.nameMap} : StringifierState{opts, result, opts.DEPRECATED_nameMap}; + StringifierState state{opts, result}; std::set cycles; std::set cycleTPs; @@ -1209,8 +1188,7 @@ ToStringResult toStringDetailed(TypePackId tp, ToStringOptions& opts) * 4. Print out the root of the type using the same algorithm as step 3. */ ToStringResult result; - StringifierState state = - FFlag::LuauFixNameMaps ? StringifierState{opts, result, opts.nameMap} : StringifierState{opts, result, opts.DEPRECATED_nameMap}; + StringifierState state{opts, result}; std::set cycles; std::set cycleTPs; @@ -1293,8 +1271,7 @@ std::string toString(const TypePackVar& tp, ToStringOptions& opts) std::string toStringNamedFunction(const std::string& funcName, const FunctionTypeVar& ftv, ToStringOptions& opts) { ToStringResult result; - StringifierState state = - FFlag::LuauFixNameMaps ? StringifierState{opts, result, opts.nameMap} : StringifierState{opts, result, opts.DEPRECATED_nameMap}; + StringifierState state{opts, result}; TypeVarStringifier tvs{state}; state.emit(funcName); @@ -1426,91 +1403,84 @@ std::string toString(const Constraint& constraint, ToStringOptions& opts) auto go = [&opts](auto&& c) -> std::string { using T = std::decay_t; - // TODO: Inline and delete this function when clipping FFlag::LuauFixNameMaps - auto tos = [](auto&& a, ToStringOptions& opts) { - if (FFlag::LuauFixNameMaps) - return toString(a, opts); - else - { - ToStringResult tsr = toStringDetailed(a, opts); - opts.DEPRECATED_nameMap = std::move(tsr.DEPRECATED_nameMap); - return tsr.name; - } + auto tos = [&opts](auto&& a) + { + return toString(a, opts); }; if constexpr (std::is_same_v) { - std::string subStr = tos(c.subType, opts); - std::string superStr = tos(c.superType, opts); + std::string subStr = tos(c.subType); + std::string superStr = tos(c.superType); return subStr + " <: " + superStr; } else if constexpr (std::is_same_v) { - std::string subStr = tos(c.subPack, opts); - std::string superStr = tos(c.superPack, opts); + std::string subStr = tos(c.subPack); + std::string superStr = tos(c.superPack); return subStr + " <: " + superStr; } else if constexpr (std::is_same_v) { - std::string subStr = tos(c.generalizedType, opts); - std::string superStr = tos(c.sourceType, opts); + std::string subStr = tos(c.generalizedType); + std::string superStr = tos(c.sourceType); return subStr + " ~ gen " + superStr; } else if constexpr (std::is_same_v) { - std::string subStr = tos(c.subType, opts); - std::string superStr = tos(c.superType, opts); + std::string subStr = tos(c.subType); + std::string superStr = tos(c.superType); return subStr + " ~ inst " + superStr; } else if constexpr (std::is_same_v) { - std::string resultStr = tos(c.resultType, opts); - std::string operandStr = tos(c.operandType, opts); + std::string resultStr = tos(c.resultType); + std::string operandStr = tos(c.operandType); return resultStr + " ~ Unary<" + toString(c.op) + ", " + operandStr + ">"; } else if constexpr (std::is_same_v) { - std::string resultStr = tos(c.resultType, opts); - std::string leftStr = tos(c.leftType, opts); - std::string rightStr = tos(c.rightType, opts); + std::string resultStr = tos(c.resultType); + std::string leftStr = tos(c.leftType); + std::string rightStr = tos(c.rightType); return resultStr + " ~ Binary<" + toString(c.op) + ", " + leftStr + ", " + rightStr + ">"; } else if constexpr (std::is_same_v) { - std::string iteratorStr = tos(c.iterator, opts); - std::string variableStr = tos(c.variables, opts); + std::string iteratorStr = tos(c.iterator); + std::string variableStr = tos(c.variables); return variableStr + " ~ Iterate<" + iteratorStr + ">"; } else if constexpr (std::is_same_v) { - std::string namedStr = tos(c.namedType, opts); + std::string namedStr = tos(c.namedType); return "@name(" + namedStr + ") = " + c.name; } else if constexpr (std::is_same_v) { - std::string targetStr = tos(c.target, opts); + std::string targetStr = tos(c.target); return "expand " + targetStr; } else if constexpr (std::is_same_v) { - return "call " + tos(c.fn, opts) + " with { result = " + tos(c.result, opts) + " }"; + return "call " + tos(c.fn) + " with { result = " + tos(c.result) + " }"; } else if constexpr (std::is_same_v) { - return tos(c.resultType, opts) + " ~ prim " + tos(c.expectedType, opts) + ", " + tos(c.singletonType, opts) + ", " + - tos(c.multitonType, opts); + return tos(c.resultType) + " ~ prim " + tos(c.expectedType) + ", " + tos(c.singletonType) + ", " + + tos(c.multitonType); } else if constexpr (std::is_same_v) { - return tos(c.resultType, opts) + " ~ hasProp " + tos(c.subjectType, opts) + ", \"" + c.prop + "\""; + return tos(c.resultType) + " ~ hasProp " + tos(c.subjectType) + ", \"" + c.prop + "\""; } else if constexpr (std::is_same_v) { - std::string result = tos(c.resultType, opts); - std::string discriminant = tos(c.discriminantType, opts); + std::string result = tos(c.resultType); + std::string discriminant = tos(c.discriminantType); return result + " ~ if isSingleton D then ~D else unknown where D = " + discriminant; } diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index dde41a65f..03575c405 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -610,7 +610,7 @@ struct TypeChecker2 visit(rhs); TypeId rhsType = lookupType(rhs); - if (!isSubtype(rhsType, lhsType, stack.back(), singletonTypes, ice)) + if (!isSubtype(rhsType, lhsType, stack.back())) { reportError(TypeMismatch{lhsType, rhsType}, rhs->location); } @@ -761,7 +761,7 @@ struct TypeChecker2 TypeId actualType = lookupType(number); TypeId numberType = singletonTypes->numberType; - if (!isSubtype(numberType, actualType, stack.back(), singletonTypes, ice)) + if (!isSubtype(numberType, actualType, stack.back())) { reportError(TypeMismatch{actualType, numberType}, number->location); } @@ -772,7 +772,7 @@ struct TypeChecker2 TypeId actualType = lookupType(string); TypeId stringType = singletonTypes->stringType; - if (!isSubtype(actualType, stringType, stack.back(), singletonTypes, ice)) + if (!isSubtype(actualType, stringType, stack.back())) { reportError(TypeMismatch{actualType, stringType}, string->location); } @@ -861,7 +861,7 @@ struct TypeChecker2 FunctionTypeVar ftv{argsTp, expectedRetType}; TypeId expectedType = arena.addType(ftv); - if (!isSubtype(testFunctionType, expectedType, stack.back(), singletonTypes, ice)) + if (!isSubtype(testFunctionType, expectedType, stack.back())) { CloneState cloneState; expectedType = clone(expectedType, module->internalTypes, cloneState); @@ -880,7 +880,7 @@ struct TypeChecker2 getIndexTypeFromType(module->getModuleScope(), leftType, indexName->index.value, indexName->location, /* addErrors */ true); if (ty) { - if (!isSubtype(resultType, *ty, stack.back(), singletonTypes, ice)) + if (!isSubtype(resultType, *ty, stack.back())) { reportError(TypeMismatch{resultType, *ty}, indexName->location); } @@ -913,7 +913,7 @@ struct TypeChecker2 TypeId inferredArgTy = *argIt; TypeId annotatedArgTy = lookupAnnotation(arg->annotation); - if (!isSubtype(annotatedArgTy, inferredArgTy, stack.back(), singletonTypes, ice)) + if (!isSubtype(annotatedArgTy, inferredArgTy, stack.back())) { reportError(TypeMismatch{annotatedArgTy, inferredArgTy}, arg->location); } @@ -954,7 +954,7 @@ struct TypeChecker2 if (const FunctionTypeVar* ftv = get(follow(*mm))) { TypePackId expectedArgs = module->internalTypes.addTypePack({operandType}); - reportErrors(tryUnify(scope, expr->location, ftv->argTypes, expectedArgs)); + reportErrors(tryUnify(scope, expr->location, expectedArgs, ftv->argTypes)); if (std::optional ret = first(ftv->retTypes)) { @@ -1096,7 +1096,7 @@ struct TypeChecker2 expr->op == AstExprBinary::CompareGt || expr->op == AstExprBinary::Op::CompareLe || expr->op == AstExprBinary::Op::CompareLt) { TypePackId expectedRets = module->internalTypes.addTypePack({singletonTypes->booleanType}); - if (!isSubtype(ftv->retTypes, expectedRets, scope, singletonTypes, ice)) + if (!isSubtype(ftv->retTypes, expectedRets, scope)) { reportError(GenericError{format("Metamethod '%s' must return type 'boolean'", it->second)}, expr->location); } @@ -1207,10 +1207,10 @@ struct TypeChecker2 TypeId computedType = lookupType(expr->expr); // Note: As an optimization, we try 'number <: number | string' first, as that is the more likely case. - if (isSubtype(annotationType, computedType, stack.back(), singletonTypes, ice)) + if (isSubtype(annotationType, computedType, stack.back())) return; - if (isSubtype(computedType, annotationType, stack.back(), singletonTypes, ice)) + if (isSubtype(computedType, annotationType, stack.back())) return; reportError(TypesAreUnrelated{computedType, annotationType}, expr->location); @@ -1505,12 +1505,27 @@ struct TypeChecker2 } } + template + bool isSubtype(TID subTy, TID superTy, NotNull scope) + { + UnifierSharedState sharedState{&ice}; + TypeArena arena; + Normalizer normalizer{&arena, singletonTypes, NotNull{&sharedState}}; + Unifier u{NotNull{&normalizer}, Mode::Strict, scope, Location{}, Covariant}; + u.useScopes = true; + + u.tryUnify(subTy, superTy); + const bool ok = u.errors.empty() && u.log.empty(); + return ok; + } + template ErrorVec tryUnify(NotNull scope, const Location& location, TID subTy, TID superTy) { UnifierSharedState sharedState{&ice}; Normalizer normalizer{&module->internalTypes, singletonTypes, NotNull{&sharedState}}; Unifier u{NotNull{&normalizer}, Mode::Strict, scope, location, Covariant}; + u.useScopes = true; u.tryUnify(subTy, superTy); return std::move(u.errors); diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index ccb1490a2..8ecd45bd5 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -35,18 +35,20 @@ LUAU_FASTFLAG(LuauTypeNormalization2) LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) LUAU_FASTFLAGVARIABLE(LuauReturnAnyInsteadOfICE, false) // Eventually removed as false. LUAU_FASTFLAGVARIABLE(DebugLuauSharedSelf, false) -LUAU_FASTFLAGVARIABLE(LuauAnyifyModuleReturnGenerics, false) LUAU_FASTFLAGVARIABLE(LuauLvaluelessPath, false) +LUAU_FASTFLAGVARIABLE(LuauNilIterator, false) LUAU_FASTFLAGVARIABLE(LuauUnknownAndNeverType, false) LUAU_FASTFLAGVARIABLE(LuauBinaryNeedsExpectedTypesToo, false) -LUAU_FASTFLAGVARIABLE(LuauFixVarargExprHeadType, false) LUAU_FASTFLAGVARIABLE(LuauNeverTypesAndOperatorsInference, false) LUAU_FASTFLAGVARIABLE(LuauReturnsFromCallsitesAreNotWidened, false) +LUAU_FASTFLAGVARIABLE(LuauTryhardAnd, false) LUAU_FASTFLAG(LuauInstantiateInSubtyping) LUAU_FASTFLAGVARIABLE(LuauCompleteVisitor, false) +LUAU_FASTFLAGVARIABLE(LuauOptionalNextKey, false) LUAU_FASTFLAGVARIABLE(LuauReportShadowedTypeAlias, false) LUAU_FASTFLAGVARIABLE(LuauBetterMessagingOnCountMismatch, false) LUAU_FASTFLAGVARIABLE(LuauArgMismatchReportFunctionLocation, false) +LUAU_FASTFLAGVARIABLE(LuauImplicitElseRefinement, false) namespace Luau { @@ -331,8 +333,7 @@ ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mo else moduleScope->returnType = anyify(moduleScope, moduleScope->returnType, Location{}); - if (FFlag::LuauAnyifyModuleReturnGenerics) - moduleScope->returnType = anyifyModuleReturnTypePackGenerics(moduleScope->returnType); + moduleScope->returnType = anyifyModuleReturnTypePackGenerics(moduleScope->returnType); for (auto& [_, typeFun] : moduleScope->exportedTypeBindings) typeFun.type = anyify(moduleScope, typeFun.type, Location{}); @@ -1209,10 +1210,10 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) AstExpr* firstValue = forin.values.data[0]; // next is a function that takes Table and an optional index of type K - // next(t: Table, index: K | nil) -> (K, V) + // next(t: Table, index: K | nil) -> (K?, V) // however, pairs and ipairs are quite messy, but they both share the same types // pairs returns 'next, t, nil', thus the type would be - // pairs(t: Table) -> ((Table, K | nil) -> (K, V), Table, K | nil) + // pairs(t: Table) -> ((Table, K | nil) -> (K?, V), Table, K | nil) // ipairs returns 'next, t, 0', thus ipairs will also share the same type as pairs, except K = number // // we can also define our own custom iterators by by returning a wrapped coroutine that calls coroutine.yield @@ -1255,6 +1256,9 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) iterTy = instantiate(scope, checkExpr(scope, *firstValue).type, firstValue->location); } + if (FFlag::LuauNilIterator) + iterTy = stripFromNilAndReport(iterTy, firstValue->location); + if (std::optional iterMM = findMetatableEntry(iterTy, "__iter", firstValue->location, /* addErrors= */ true)) { // if __iter metamethod is present, it will be called and the results are going to be called as if they are functions @@ -1338,21 +1342,61 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) reportErrors(state.errors); } - TypePackId varPack = addTypePack(TypePackVar{TypePack{varTypes, freshTypePack(scope)}}); - - if (forin.values.size >= 2) + if (FFlag::LuauOptionalNextKey) { - AstArray arguments{forin.values.data + 1, forin.values.size - 1}; + TypePackId retPack = iterFunc->retTypes; - Position start = firstValue->location.begin; - Position end = values[forin.values.size - 1]->location.end; - AstExprCall exprCall{Location(start, end), firstValue, arguments, /* self= */ false, Location()}; + if (forin.values.size >= 2) + { + AstArray arguments{forin.values.data + 1, forin.values.size - 1}; + + Position start = firstValue->location.begin; + Position end = values[forin.values.size - 1]->location.end; + AstExprCall exprCall{Location(start, end), firstValue, arguments, /* self= */ false, Location()}; + + retPack = checkExprPack(scope, exprCall).type; + } + + // We need to remove 'nil' from the set of options of the first return value + // Because for loop stops when it gets 'nil', this result is never actually assigned to the first variable + if (std::optional fty = first(retPack); fty && !varTypes.empty()) + { + TypeId keyTy = follow(*fty); + + if (get(keyTy)) + { + if (std::optional ty = tryStripUnionFromNil(keyTy)) + keyTy = *ty; + } + + unify(keyTy, varTypes.front(), scope, forin.location); + + // We have already handled the first variable type, make it match in the pack check + varTypes.front() = *fty; + } + + TypePackId varPack = addTypePack(TypePackVar{TypePack{varTypes, freshTypePack(scope)}}); - TypePackId retPack = checkExprPack(scope, exprCall).type; unify(retPack, varPack, scope, forin.location); } else - unify(iterFunc->retTypes, varPack, scope, forin.location); + { + TypePackId varPack = addTypePack(TypePackVar{TypePack{varTypes, freshTypePack(scope)}}); + + if (forin.values.size >= 2) + { + AstArray arguments{forin.values.data + 1, forin.values.size - 1}; + + Position start = firstValue->location.begin; + Position end = values[forin.values.size - 1]->location.end; + AstExprCall exprCall{Location(start, end), firstValue, arguments, /* self= */ false, Location()}; + + TypePackId retPack = checkExprPack(scope, exprCall).type; + unify(retPack, varPack, scope, forin.location); + } + else + unify(iterFunc->retTypes, varPack, scope, forin.location); + } check(loopScope, *forin.body); } @@ -1855,18 +1899,10 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp if (get(varargPack)) { - if (FFlag::LuauFixVarargExprHeadType) - { - if (std::optional ty = first(varargPack)) - return {*ty}; + if (std::optional ty = first(varargPack)) + return {*ty}; - return {nilType}; - } - else - { - std::vector types = flatten(varargPack).first; - return {!types.empty() ? types[0] : nilType}; - } + return {nilType}; } else if (get(varargPack)) { @@ -2717,12 +2753,54 @@ TypeId TypeChecker::checkRelationalOperation( case AstExprBinary::And: if (lhsIsAny) + { return lhsType; - return unionOfTypes(rhsType, booleanType, scope, expr.location, false); + } + else if (FFlag::LuauTryhardAnd) + { + // If lhs is free, we can't tell which 'falsy' components it has, if any + if (get(lhsType)) + return unionOfTypes(addType(UnionTypeVar{{nilType, singletonType(false)}}), rhsType, scope, expr.location, false); + + auto [oty, notNever] = pickTypesFromSense(lhsType, false, neverType); // Filter out falsy types + + if (notNever) + { + LUAU_ASSERT(oty); + return unionOfTypes(*oty, rhsType, scope, expr.location, false); + } + else + { + return rhsType; + } + } + else + { + return unionOfTypes(rhsType, booleanType, scope, expr.location, false); + } case AstExprBinary::Or: if (lhsIsAny) + { return lhsType; - return unionOfTypes(lhsType, rhsType, scope, expr.location); + } + else if (FFlag::LuauTryhardAnd) + { + auto [oty, notNever] = pickTypesFromSense(lhsType, true, neverType); // Filter out truthy types + + if (notNever) + { + LUAU_ASSERT(oty); + return unionOfTypes(*oty, rhsType, scope, expr.location); + } + else + { + return rhsType; + } + } + else + { + return unionOfTypes(lhsType, rhsType, scope, expr.location); + } default: LUAU_ASSERT(0); ice(format("checkRelationalOperation called with incorrect binary expression '%s'", toString(expr.op).c_str()), expr.location); @@ -4840,9 +4918,9 @@ TypePackId TypeChecker::errorRecoveryTypePack(TypePackId guess) return singletonTypes->errorRecoveryTypePack(guess); } -TypeIdPredicate TypeChecker::mkTruthyPredicate(bool sense) +TypeIdPredicate TypeChecker::mkTruthyPredicate(bool sense, TypeId emptySetTy) { - return [this, sense](TypeId ty) -> std::optional { + return [this, sense, emptySetTy](TypeId ty) -> std::optional { // any/error/free gets a special pass unconditionally because they can't be decided. if (get(ty) || get(ty) || get(ty)) return ty; @@ -4860,7 +4938,7 @@ TypeIdPredicate TypeChecker::mkTruthyPredicate(bool sense) return sense ? std::nullopt : std::optional(ty); // at this point, anything else is kept if sense is true, or replaced by nil - return sense ? ty : nilType; + return sense ? ty : emptySetTy; }; } @@ -4886,9 +4964,9 @@ std::pair, bool> TypeChecker::filterMap(TypeId type, TypeI } } -std::pair, bool> TypeChecker::pickTypesFromSense(TypeId type, bool sense) +std::pair, bool> TypeChecker::pickTypesFromSense(TypeId type, bool sense, TypeId emptySetTy) { - return filterMap(type, mkTruthyPredicate(sense)); + return filterMap(type, mkTruthyPredicate(sense, emptySetTy)); } TypeId TypeChecker::addTV(TypeVar&& tv) @@ -5657,7 +5735,7 @@ void TypeChecker::resolve(const TruthyPredicate& truthyP, RefinementMap& refis, if (ty && fromOr) return addRefinement(refis, truthyP.lvalue, *ty); - refineLValue(truthyP.lvalue, refis, scope, mkTruthyPredicate(sense)); + refineLValue(truthyP.lvalue, refis, scope, mkTruthyPredicate(sense, nilType)); } void TypeChecker::resolve(const AndPredicate& andP, RefinementMap& refis, const ScopePtr& scope, bool sense) @@ -5850,13 +5928,57 @@ void TypeChecker::resolve(const EqPredicate& eqP, RefinementMap& refis, const Sc if (maybeSingleton(eqP.type)) { - // Normally we'd write option <: eqP.type, but singletons are always the subtype, so we flip this. - if (!sense || canUnify(eqP.type, option, scope, eqP.location).empty()) - return sense ? eqP.type : option; + if (FFlag::LuauImplicitElseRefinement) + { + bool optionIsSubtype = canUnify(option, eqP.type, scope, eqP.location).empty(); + bool targetIsSubtype = canUnify(eqP.type, option, scope, eqP.location).empty(); + + // terminology refresher: + // - option is the type of the expression `x`, and + // - eqP.type is the type of the expression `"hello"` + // + // "hello" == x where + // x : "hello" | "world" -> x : "hello" + // x : number | string -> x : "hello" + // x : number -> x : never + // + // "hello" ~= x where + // x : "hello" | "world" -> x : "world" + // x : number | string -> x : number | string + // x : number -> x : number + + // local variable works around an odd gcc 9.3 warning: may be used uninitialized + std::optional nope = std::nullopt; + + if (sense) + { + if (optionIsSubtype && !targetIsSubtype) + return option; + else if (!optionIsSubtype && targetIsSubtype) + return eqP.type; + else if (!optionIsSubtype && !targetIsSubtype) + return nope; + else if (optionIsSubtype && targetIsSubtype) + return eqP.type; + } + else + { + bool isOptionSingleton = get(option); + if (!isOptionSingleton) + return option; + else if (optionIsSubtype && targetIsSubtype) + return nope; + } + } + else + { + if (!sense || canUnify(eqP.type, option, scope, eqP.location).empty()) + return sense ? eqP.type : option; - // local variable works around an odd gcc 9.3 warning: may be used uninitialized - std::optional res = std::nullopt; - return res; + // local variable works around an odd gcc 9.3 warning: may be used uninitialized + std::optional res = std::nullopt; + return res; + } } return option; diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index de0890e18..814eca0d5 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -3,6 +3,7 @@ #include "Luau/BuiltinDefinitions.h" #include "Luau/Common.h" +#include "Luau/ConstraintSolver.h" #include "Luau/DenseHash.h" #include "Luau/Error.h" #include "Luau/RecursionCounter.h" @@ -26,6 +27,7 @@ LUAU_FASTINT(LuauTypeInferRecursionLimit) LUAU_FASTFLAG(LuauUnknownAndNeverType) LUAU_FASTFLAGVARIABLE(LuauMaybeGenericIntersectionTypes, false) LUAU_FASTFLAGVARIABLE(LuauNoMoreGlobalSingletonTypes, false) +LUAU_FASTFLAGVARIABLE(LuauNewLibraryTypeNames, false) LUAU_FASTFLAG(LuauInstantiateInSubtyping) namespace Luau @@ -33,15 +35,19 @@ namespace Luau std::optional> magicFunctionFormat( TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate); +static bool dcrMagicFunctionFormat(MagicFunctionCallContext context); static std::optional> magicFunctionGmatch( TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate); +static bool dcrMagicFunctionGmatch(MagicFunctionCallContext context); static std::optional> magicFunctionMatch( TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate); +static bool dcrMagicFunctionMatch(MagicFunctionCallContext context); static std::optional> magicFunctionFind( TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate); +static bool dcrMagicFunctionFind(MagicFunctionCallContext context); TypeId follow(TypeId t) { @@ -800,6 +806,7 @@ TypeId SingletonTypes::makeStringMetatable() FunctionTypeVar formatFTV{arena->addTypePack(TypePack{{stringType}, anyTypePack}), oneStringPack}; formatFTV.magicFunction = &magicFunctionFormat; const TypeId formatFn = arena->addType(formatFTV); + attachDcrMagicFunction(formatFn, dcrMagicFunctionFormat); const TypePackId emptyPack = arena->addTypePack({}); const TypePackId stringVariadicList = arena->addTypePack(TypePackVar{VariadicTypePack{stringType}}); @@ -814,14 +821,17 @@ TypeId SingletonTypes::makeStringMetatable() const TypeId gmatchFunc = makeFunction(*arena, stringType, {}, {}, {stringType}, {}, {arena->addType(FunctionTypeVar{emptyPack, stringVariadicList})}); attachMagicFunction(gmatchFunc, magicFunctionGmatch); + attachDcrMagicFunction(gmatchFunc, dcrMagicFunctionGmatch); const TypeId matchFunc = arena->addType( FunctionTypeVar{arena->addTypePack({stringType, stringType, optionalNumber}), arena->addTypePack(TypePackVar{VariadicTypePack{stringType}})}); attachMagicFunction(matchFunc, magicFunctionMatch); + attachDcrMagicFunction(matchFunc, dcrMagicFunctionMatch); const TypeId findFunc = arena->addType(FunctionTypeVar{arena->addTypePack({stringType, stringType, optionalNumber, optionalBoolean}), arena->addTypePack(TypePack{{optionalNumber, optionalNumber}, stringVariadicList})}); attachMagicFunction(findFunc, magicFunctionFind); + attachDcrMagicFunction(findFunc, dcrMagicFunctionFind); TableTypeVar::Props stringLib = { {"byte", {arena->addType(FunctionTypeVar{arena->addTypePack({stringType, optionalNumber, optionalNumber}), numberVariadicList})}}, @@ -855,7 +865,7 @@ TypeId SingletonTypes::makeStringMetatable() TypeId tableType = arena->addType(TableTypeVar{std::move(stringLib), std::nullopt, TypeLevel{}, TableState::Sealed}); if (TableTypeVar* ttv = getMutable(tableType)) - ttv->name = "string"; + ttv->name = FFlag::LuauNewLibraryTypeNames ? "typeof(string)" : "string"; return arena->addType(TableTypeVar{{{{"__index", {tableType}}}}, std::nullopt, TypeLevel{}, TableState::Sealed}); } @@ -1072,7 +1082,7 @@ IntersectionTypeVarIterator end(const IntersectionTypeVar* itv) return IntersectionTypeVarIterator{}; } -static std::vector parseFormatString(TypeChecker& typechecker, const char* data, size_t size) +static std::vector parseFormatString(NotNull singletonTypes, const char* data, size_t size) { const char* options = "cdiouxXeEfgGqs*"; @@ -1095,13 +1105,13 @@ static std::vector parseFormatString(TypeChecker& typechecker, const cha break; if (data[i] == 'q' || data[i] == 's') - result.push_back(typechecker.stringType); + result.push_back(singletonTypes->stringType); else if (data[i] == '*') - result.push_back(typechecker.unknownType); + result.push_back(singletonTypes->unknownType); else if (strchr(options, data[i])) - result.push_back(typechecker.numberType); + result.push_back(singletonTypes->numberType); else - result.push_back(typechecker.errorRecoveryType(typechecker.anyType)); + result.push_back(singletonTypes->errorRecoveryType(singletonTypes->anyType)); } } @@ -1130,7 +1140,7 @@ std::optional> magicFunctionFormat( if (!fmt) return std::nullopt; - std::vector expected = parseFormatString(typechecker, fmt->value.data, fmt->value.size); + std::vector expected = parseFormatString(typechecker.singletonTypes, fmt->value.data, fmt->value.size); const auto& [params, tail] = flatten(paramPack); size_t paramOffset = 1; @@ -1154,7 +1164,50 @@ std::optional> magicFunctionFormat( return WithPredicate{arena.addTypePack({typechecker.stringType})}; } -static std::vector parsePatternString(TypeChecker& typechecker, const char* data, size_t size) +static bool dcrMagicFunctionFormat(MagicFunctionCallContext context) +{ + TypeArena* arena = context.solver->arena; + + AstExprConstantString* fmt = nullptr; + if (auto index = context.callSite->func->as(); index && context.callSite->self) + { + if (auto group = index->expr->as()) + fmt = group->expr->as(); + else + fmt = index->expr->as(); + } + + if (!context.callSite->self && context.callSite->args.size > 0) + fmt = context.callSite->args.data[0]->as(); + + if (!fmt) + return false; + + std::vector expected = parseFormatString(context.solver->singletonTypes, fmt->value.data, fmt->value.size); + const auto& [params, tail] = flatten(context.arguments); + + size_t paramOffset = 1; + + // unify the prefix one argument at a time + for (size_t i = 0; i < expected.size() && i + paramOffset < params.size(); ++i) + { + context.solver->unify(params[i + paramOffset], expected[i], context.solver->rootScope); + } + + // if we know the argument count or if we have too many arguments for sure, we can issue an error + size_t numActualParams = params.size(); + size_t numExpectedParams = expected.size() + 1; // + 1 for the format string + + if (numExpectedParams != numActualParams && (!tail || numExpectedParams < numActualParams)) + context.solver->reportError(TypeError{context.callSite->location, CountMismatch{numExpectedParams, std::nullopt, numActualParams}}); + + TypePackId resultPack = arena->addTypePack({context.solver->singletonTypes->stringType}); + asMutable(context.result)->ty.emplace(resultPack); + + return true; +} + +static std::vector parsePatternString(NotNull singletonTypes, const char* data, size_t size) { std::vector result; int depth = 0; @@ -1186,12 +1239,12 @@ static std::vector parsePatternString(TypeChecker& typechecker, const ch if (i + 1 < size && data[i + 1] == ')') { i++; - result.push_back(typechecker.numberType); + result.push_back(singletonTypes->numberType); continue; } ++depth; - result.push_back(typechecker.stringType); + result.push_back(singletonTypes->stringType); } else if (data[i] == ')') { @@ -1209,7 +1262,7 @@ static std::vector parsePatternString(TypeChecker& typechecker, const ch return std::vector(); if (result.empty()) - result.push_back(typechecker.stringType); + result.push_back(singletonTypes->stringType); return result; } @@ -1233,7 +1286,7 @@ static std::optional> magicFunctionGmatch( if (!pattern) return std::nullopt; - std::vector returnTypes = parsePatternString(typechecker, pattern->value.data, pattern->value.size); + std::vector returnTypes = parsePatternString(typechecker.singletonTypes, pattern->value.data, pattern->value.size); if (returnTypes.empty()) return std::nullopt; @@ -1246,6 +1299,39 @@ static std::optional> magicFunctionGmatch( return WithPredicate{arena.addTypePack({iteratorType})}; } +static bool dcrMagicFunctionGmatch(MagicFunctionCallContext context) +{ + const auto& [params, tail] = flatten(context.arguments); + + if (params.size() != 2) + return false; + + TypeArena* arena = context.solver->arena; + + AstExprConstantString* pattern = nullptr; + size_t index = context.callSite->self ? 0 : 1; + if (context.callSite->args.size > index) + pattern = context.callSite->args.data[index]->as(); + + if (!pattern) + return false; + + std::vector returnTypes = parsePatternString(context.solver->singletonTypes, pattern->value.data, pattern->value.size); + + if (returnTypes.empty()) + return false; + + context.solver->unify(params[0], context.solver->singletonTypes->stringType, context.solver->rootScope); + + const TypePackId emptyPack = arena->addTypePack({}); + const TypePackId returnList = arena->addTypePack(returnTypes); + const TypeId iteratorType = arena->addType(FunctionTypeVar{emptyPack, returnList}); + const TypePackId resTypePack = arena->addTypePack({iteratorType}); + asMutable(context.result)->ty.emplace(resTypePack); + + return true; +} + static std::optional> magicFunctionMatch( TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) { @@ -1265,7 +1351,7 @@ static std::optional> magicFunctionMatch( if (!pattern) return std::nullopt; - std::vector returnTypes = parsePatternString(typechecker, pattern->value.data, pattern->value.size); + std::vector returnTypes = parsePatternString(typechecker.singletonTypes, pattern->value.data, pattern->value.size); if (returnTypes.empty()) return std::nullopt; @@ -1282,6 +1368,42 @@ static std::optional> magicFunctionMatch( return WithPredicate{returnList}; } +static bool dcrMagicFunctionMatch(MagicFunctionCallContext context) +{ + const auto& [params, tail] = flatten(context.arguments); + + if (params.size() < 2 || params.size() > 3) + return false; + + TypeArena* arena = context.solver->arena; + + AstExprConstantString* pattern = nullptr; + size_t patternIndex = context.callSite->self ? 0 : 1; + if (context.callSite->args.size > patternIndex) + pattern = context.callSite->args.data[patternIndex]->as(); + + if (!pattern) + return false; + + std::vector returnTypes = parsePatternString(context.solver->singletonTypes, pattern->value.data, pattern->value.size); + + if (returnTypes.empty()) + return false; + + context.solver->unify(params[0], context.solver->singletonTypes->stringType, context.solver->rootScope); + + const TypeId optionalNumber = arena->addType(UnionTypeVar{{context.solver->singletonTypes->nilType, context.solver->singletonTypes->numberType}}); + + size_t initIndex = context.callSite->self ? 1 : 2; + if (params.size() == 3 && context.callSite->args.size > initIndex) + context.solver->unify(params[2], optionalNumber, context.solver->rootScope); + + const TypePackId returnList = arena->addTypePack(returnTypes); + asMutable(context.result)->ty.emplace(returnList); + + return true; +} + static std::optional> magicFunctionFind( TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) { @@ -1312,7 +1434,7 @@ static std::optional> magicFunctionFind( std::vector returnTypes; if (!plain) { - returnTypes = parsePatternString(typechecker, pattern->value.data, pattern->value.size); + returnTypes = parsePatternString(typechecker.singletonTypes, pattern->value.data, pattern->value.size); if (returnTypes.empty()) return std::nullopt; @@ -1336,6 +1458,60 @@ static std::optional> magicFunctionFind( return WithPredicate{returnList}; } +static bool dcrMagicFunctionFind(MagicFunctionCallContext context) +{ + const auto& [params, tail] = flatten(context.arguments); + + if (params.size() < 2 || params.size() > 4) + return false; + + TypeArena* arena = context.solver->arena; + NotNull singletonTypes = context.solver->singletonTypes; + + AstExprConstantString* pattern = nullptr; + size_t patternIndex = context.callSite->self ? 0 : 1; + if (context.callSite->args.size > patternIndex) + pattern = context.callSite->args.data[patternIndex]->as(); + + if (!pattern) + return false; + + bool plain = false; + size_t plainIndex = context.callSite->self ? 2 : 3; + if (context.callSite->args.size > plainIndex) + { + AstExprConstantBool* p = context.callSite->args.data[plainIndex]->as(); + plain = p && p->value; + } + + std::vector returnTypes; + if (!plain) + { + returnTypes = parsePatternString(singletonTypes, pattern->value.data, pattern->value.size); + + if (returnTypes.empty()) + return false; + } + + context.solver->unify(params[0], singletonTypes->stringType, context.solver->rootScope); + + const TypeId optionalNumber = arena->addType(UnionTypeVar{{singletonTypes->nilType, singletonTypes->numberType}}); + const TypeId optionalBoolean = arena->addType(UnionTypeVar{{singletonTypes->nilType, singletonTypes->booleanType}}); + + size_t initIndex = context.callSite->self ? 1 : 2; + if (params.size() >= 3 && context.callSite->args.size > initIndex) + context.solver->unify(params[2], optionalNumber, context.solver->rootScope); + + if (params.size() == 4 && context.callSite->args.size > plainIndex) + context.solver->unify(params[3], optionalBoolean, context.solver->rootScope); + + returnTypes.insert(returnTypes.begin(), {optionalNumber, optionalNumber}); + + const TypePackId returnList = arena->addTypePack(returnTypes); + asMutable(context.result)->ty.emplace(returnList); + return true; +} + std::vector filterMap(TypeId type, TypeIdPredicate predicate) { type = follow(type); diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index df5d86f1e..4dc909831 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -22,6 +22,7 @@ LUAU_FASTFLAGVARIABLE(LuauSubtypeNormalizer, false); LUAU_FASTFLAGVARIABLE(LuauScalarShapeSubtyping, false) LUAU_FASTFLAGVARIABLE(LuauInstantiateInSubtyping, false) LUAU_FASTFLAGVARIABLE(LuauOverloadedFunctionSubtypingPerf, false); +LUAU_FASTFLAGVARIABLE(LuauScalarShapeUnifyToMtOwner, false) LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) LUAU_FASTFLAG(LuauNegatedFunctionTypes) @@ -1699,8 +1700,20 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) // Recursive unification can change the txn log, and invalidate the old // table. If we detect that this has happened, we start over, with the updated // txn log. - TableTypeVar* newSuperTable = log.getMutable(superTy); - TableTypeVar* newSubTable = log.getMutable(subTy); + TypeId superTyNew = FFlag::LuauScalarShapeUnifyToMtOwner ? log.follow(superTy) : superTy; + TypeId subTyNew = FFlag::LuauScalarShapeUnifyToMtOwner ? log.follow(subTy) : subTy; + + if (FFlag::LuauScalarShapeUnifyToMtOwner) + { + // If one of the types stopped being a table altogether, we need to restart from the top + if ((superTy != superTyNew || subTy != subTyNew) && errors.empty()) + return tryUnify(subTy, superTy, false, isIntersection); + } + + // Otherwise, restart only the table unification + TableTypeVar* newSuperTable = log.getMutable(superTyNew); + TableTypeVar* newSubTable = log.getMutable(subTyNew); + if (superTable != newSuperTable || (subTable != newSubTable && subTable != instantiatedSubTable)) { if (errors.empty()) @@ -1862,7 +1875,9 @@ void Unifier::tryUnifyScalarShape(TypeId subTy, TypeId superTy, bool reversed) if (reversed) std::swap(subTy, superTy); - if (auto ttv = log.get(superTy); !ttv || ttv->state != TableState::Free) + TableTypeVar* superTable = log.getMutable(superTy); + + if (!superTable || superTable->state != TableState::Free) return reportError(location, TypeMismatch{osuperTy, osubTy}); auto fail = [&](std::optional e) { @@ -1887,6 +1902,20 @@ void Unifier::tryUnifyScalarShape(TypeId subTy, TypeId superTy, bool reversed) Unifier child = makeChildUnifier(); child.tryUnify_(ty, superTy); + if (FFlag::LuauScalarShapeUnifyToMtOwner) + { + // To perform subtype <: free table unification, we have tried to unify (subtype's metatable) <: free table + // There is a chance that it was unified with the origial subtype, but then, (subtype's metatable) <: subtype could've failed + // Here we check if we have a new supertype instead of the original free table and try original subtype <: new supertype check + TypeId newSuperTy = child.log.follow(superTy); + + if (superTy != newSuperTy && canUnify(subTy, newSuperTy).empty()) + { + log.replace(superTy, BoundTypeVar{subTy}); + return; + } + } + if (auto e = hasUnificationTooComplex(child.errors)) reportError(*e); else if (!child.errors.empty()) @@ -1894,6 +1923,14 @@ void Unifier::tryUnifyScalarShape(TypeId subTy, TypeId superTy, bool reversed) log.concat(std::move(child.log)); + if (FFlag::LuauScalarShapeUnifyToMtOwner) + { + // To perform subtype <: free table unification, we have tried to unify (subtype's metatable) <: free table + // We return success because subtype <: free table which means that correct unification is to replace free table with the subtype + if (child.errors.empty()) + log.replace(superTy, BoundTypeVar{subTy}); + } + return; } else diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 4c0cc1251..8338a04a7 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -27,6 +27,8 @@ LUAU_FASTFLAGVARIABLE(LuauInterpolatedStringBaseSupport, false) LUAU_FASTFLAGVARIABLE(LuauCommaParenWarnings, false) LUAU_FASTFLAGVARIABLE(LuauTableConstructorRecovery, false) +LUAU_FASTFLAGVARIABLE(LuauParserErrorsOnMissingDefaultTypePackArgument, false) + bool lua_telemetry_parsed_out_of_range_bin_integer = false; bool lua_telemetry_parsed_out_of_range_hex_integer = false; bool lua_telemetry_parsed_double_prefix_hex_integer = false; @@ -2503,7 +2505,7 @@ std::pair, AstArray> Parser::parseG namePacks.push_back({name, nameLocation, typePack}); } - else if (lexer.current().type == '(') + else if (!FFlag::LuauParserErrorsOnMissingDefaultTypePackArgument && lexer.current().type == '(') { auto [type, typePack] = parseTypeOrPackAnnotation(); @@ -2512,6 +2514,15 @@ std::pair, AstArray> Parser::parseG namePacks.push_back({name, nameLocation, typePack}); } + else if (FFlag::LuauParserErrorsOnMissingDefaultTypePackArgument) + { + auto [type, typePack] = parseTypeOrPackAnnotation(); + + if (type) + report(type->location, "Expected type pack after '=', got type"); + + namePacks.push_back({name, nameLocation, typePack}); + } } else { diff --git a/CodeGen/include/Luau/AddressA64.h b/CodeGen/include/Luau/AddressA64.h index 351e67151..53efd3c37 100644 --- a/CodeGen/include/Luau/AddressA64.h +++ b/CodeGen/include/Luau/AddressA64.h @@ -17,7 +17,6 @@ enum class AddressKindA64 : uint8_t // reg + reg << shift // reg + sext(reg) << shift // reg + uext(reg) << shift - // pc + offset }; struct AddressA64 @@ -28,8 +27,8 @@ struct AddressA64 , offset(xzr) , data(off) { - LUAU_ASSERT(base.kind == KindA64::x); - LUAU_ASSERT(off >= 0 && off < 4096); + LUAU_ASSERT(base.kind == KindA64::x || base == sp); + LUAU_ASSERT(off >= -256 && off < 4096); } AddressA64(RegisterA64 base, RegisterA64 offset) @@ -48,5 +47,7 @@ struct AddressA64 int data; }; +using mem = AddressA64; + } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/include/Luau/AssemblyBuilderA64.h b/CodeGen/include/Luau/AssemblyBuilderA64.h index 9a1402bec..9e12168a0 100644 --- a/CodeGen/include/Luau/AssemblyBuilderA64.h +++ b/CodeGen/include/Luau/AssemblyBuilderA64.h @@ -34,15 +34,18 @@ class AssemblyBuilderA64 // Comparisons // Note: some arithmetic instructions also have versions that update flags (ADDS etc) but we aren't using them atm - // TODO: add cmp + void cmp(RegisterA64 src1, RegisterA64 src2); + void cmp(RegisterA64 src1, int src2); - // Binary + // Bitwise // Note: shifted-register support and bitfield operations are omitted for simplicity // TODO: support immediate arguments (they have odd encoding and forbid many values) - // TODO: support not variants for and/or/eor (required to support not...) void and_(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); void orr(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); void eor(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); + void mvn(RegisterA64 dst, RegisterA64 src); + + // Shifts void lsl(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); void lsr(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); void asr(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); @@ -66,11 +69,19 @@ class AssemblyBuilderA64 // Control flow // Note: tbz/tbnz are currently not supported because they have 15-bit offsets and we don't support branch thunks + void b(Label& label); void b(ConditionA64 cond, Label& label); void cbz(RegisterA64 src, Label& label); void cbnz(RegisterA64 src, Label& label); + void br(RegisterA64 src); + void blr(RegisterA64 src); void ret(); + // Address of embedded data + void adr(RegisterA64 dst, const void* ptr, size_t size); + void adr(RegisterA64 dst, uint64_t value); + void adr(RegisterA64 dst, double value); + // Run final checks bool finalize(); @@ -97,17 +108,21 @@ class AssemblyBuilderA64 // Instruction archetypes void place0(const char* name, uint32_t word); void placeSR3(const char* name, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, uint8_t op, int shift = 0); - void placeSR2(const char* name, RegisterA64 dst, RegisterA64 src, uint8_t op); + void placeSR2(const char* name, RegisterA64 dst, RegisterA64 src, uint8_t op, uint8_t op2 = 0); void placeR3(const char* name, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, uint8_t op, uint8_t op2); void placeR1(const char* name, RegisterA64 dst, RegisterA64 src, uint32_t op); void placeI12(const char* name, RegisterA64 dst, RegisterA64 src1, int src2, uint8_t op); void placeI16(const char* name, RegisterA64 dst, int src, uint8_t op, int shift = 0); void placeA(const char* name, RegisterA64 dst, AddressA64 src, uint8_t op, uint8_t size); void placeBC(const char* name, Label& label, uint8_t op, uint8_t cond); - void placeBR(const char* name, Label& label, uint8_t op, RegisterA64 cond); + void placeBCR(const char* name, Label& label, uint8_t op, RegisterA64 cond); + void placeBR(const char* name, RegisterA64 src, uint32_t op); + void placeADR(const char* name, RegisterA64 src, uint8_t op); void place(uint32_t word); - void placeLabel(Label& label); + + void patchLabel(Label& label); + void patchImm19(uint32_t location, int value); void commit(); LUAU_NOINLINE void extend(); @@ -123,6 +138,7 @@ class AssemblyBuilderA64 LUAU_NOINLINE void log(const char* opcode, RegisterA64 dst, int src, int shift = 0); LUAU_NOINLINE void log(const char* opcode, RegisterA64 dst, AddressA64 src); LUAU_NOINLINE void log(const char* opcode, RegisterA64 src, Label label); + LUAU_NOINLINE void log(const char* opcode, RegisterA64 src); LUAU_NOINLINE void log(const char* opcode, Label label); LUAU_NOINLINE void log(Label label); LUAU_NOINLINE void log(RegisterA64 reg); @@ -133,6 +149,7 @@ class AssemblyBuilderA64 std::vector labelLocations; bool finalized = false; + bool overflowed = false; size_t dataPos = 0; diff --git a/CodeGen/src/AssemblyBuilderA64.cpp b/CodeGen/src/AssemblyBuilderA64.cpp index e4237f530..286800d6d 100644 --- a/CodeGen/src/AssemblyBuilderA64.cpp +++ b/CodeGen/src/AssemblyBuilderA64.cpp @@ -37,7 +37,10 @@ AssemblyBuilderA64::~AssemblyBuilderA64() void AssemblyBuilderA64::mov(RegisterA64 dst, RegisterA64 src) { - placeSR2("mov", dst, src, 0b01'01010); + if (dst == sp || src == sp) + placeR1("mov", dst, src, 0b00'100010'0'000000000000); + else + placeSR2("mov", dst, src, 0b01'01010); } void AssemblyBuilderA64::mov(RegisterA64 dst, uint16_t src, int shift) @@ -75,6 +78,20 @@ void AssemblyBuilderA64::neg(RegisterA64 dst, RegisterA64 src) placeSR2("neg", dst, src, 0b10'01011); } +void AssemblyBuilderA64::cmp(RegisterA64 src1, RegisterA64 src2) +{ + RegisterA64 dst = src1.kind == KindA64::x ? xzr : wzr; + + placeSR3("cmp", dst, src1, src2, 0b11'01011); +} + +void AssemblyBuilderA64::cmp(RegisterA64 src1, int src2) +{ + RegisterA64 dst = src1.kind == KindA64::x ? xzr : wzr; + + placeI12("cmp", dst, src1, src2, 0b11'10001); +} + void AssemblyBuilderA64::and_(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2) { placeSR3("and", dst, src1, src2, 0b00'01010); @@ -90,6 +107,11 @@ void AssemblyBuilderA64::eor(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2 placeSR3("eor", dst, src1, src2, 0b10'01010); } +void AssemblyBuilderA64::mvn(RegisterA64 dst, RegisterA64 src) +{ + placeSR2("mvn", dst, src, 0b01'01010, 0b1); +} + void AssemblyBuilderA64::lsl(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2) { placeR3("lsl", dst, src1, src2, 0b11010110, 0b0010'00); @@ -183,6 +205,12 @@ void AssemblyBuilderA64::strh(RegisterA64 src, AddressA64 dst) placeA("strh", src, dst, 0b11100000, 0b01); } +void AssemblyBuilderA64::b(Label& label) +{ + // Note: we aren't using 'b' form since it has a 26-bit immediate which requires custom fixup logic + placeBC("b", label, 0b0101010'0, codeForCondition[int(ConditionA64::Always)]); +} + void AssemblyBuilderA64::b(ConditionA64 cond, Label& label) { placeBC(textForCondition[int(cond)], label, 0b0101010'0, codeForCondition[int(cond)]); @@ -190,12 +218,22 @@ void AssemblyBuilderA64::b(ConditionA64 cond, Label& label) void AssemblyBuilderA64::cbz(RegisterA64 src, Label& label) { - placeBR("cbz", label, 0b011010'0, src); + placeBCR("cbz", label, 0b011010'0, src); } void AssemblyBuilderA64::cbnz(RegisterA64 src, Label& label) { - placeBR("cbnz", label, 0b011010'1, src); + placeBCR("cbnz", label, 0b011010'1, src); +} + +void AssemblyBuilderA64::br(RegisterA64 src) +{ + placeBR("br", src, 0b1101011'0'0'00'11111'0000'0'0); +} + +void AssemblyBuilderA64::blr(RegisterA64 src) +{ + placeBR("blr", src, 0b1101011'0'0'01'11111'0000'0'0); } void AssemblyBuilderA64::ret() @@ -203,10 +241,41 @@ void AssemblyBuilderA64::ret() place0("ret", 0b1101011'0'0'10'11111'0000'0'0'11110'00000); } -bool AssemblyBuilderA64::finalize() +void AssemblyBuilderA64::adr(RegisterA64 dst, const void* ptr, size_t size) { - bool success = true; + size_t pos = allocateData(size, 4); + uint32_t location = getCodeSize(); + + memcpy(&data[pos], ptr, size); + placeADR("adr", dst, 0b10000); + + patchImm19(location, -int(location) - int((data.size() - pos) / 4)); +} +void AssemblyBuilderA64::adr(RegisterA64 dst, uint64_t value) +{ + size_t pos = allocateData(8, 8); + uint32_t location = getCodeSize(); + + writeu64(&data[pos], value); + placeADR("adr", dst, 0b10000); + + patchImm19(location, -int(location) - int((data.size() - pos) / 4)); +} + +void AssemblyBuilderA64::adr(RegisterA64 dst, double value) +{ + size_t pos = allocateData(8, 8); + uint32_t location = getCodeSize(); + + writef64(&data[pos], value); + placeADR("adr", dst, 0b10000); + + patchImm19(location, -int(location) - int((data.size() - pos) / 4)); +} + +bool AssemblyBuilderA64::finalize() +{ code.resize(codePos - code.data()); // Resolve jump targets @@ -214,15 +283,9 @@ bool AssemblyBuilderA64::finalize() { // If this assertion fires, a label was used in jmp without calling setLabel LUAU_ASSERT(labelLocations[fixup.id - 1] != ~0u); - int value = int(labelLocations[fixup.id - 1]) - int(fixup.location); - // imm19 encoding word offset, at bit offset 5 - // note that 18 bits of word offsets = 20 bits of byte offsets = +-1MB - if (value > -(1 << 18) && value < (1 << 18)) - code[fixup.location] |= (value & ((1 << 19) - 1)) << 5; - else - success = false; // overflow + patchImm19(fixup.location, value); } size_t dataSize = data.size() - dataPos; @@ -235,7 +298,7 @@ bool AssemblyBuilderA64::finalize() finalized = true; - return success; + return !overflowed; } Label AssemblyBuilderA64::setLabel() @@ -303,7 +366,7 @@ void AssemblyBuilderA64::placeSR3(const char* name, RegisterA64 dst, RegisterA64 commit(); } -void AssemblyBuilderA64::placeSR2(const char* name, RegisterA64 dst, RegisterA64 src, uint8_t op) +void AssemblyBuilderA64::placeSR2(const char* name, RegisterA64 dst, RegisterA64 src, uint8_t op, uint8_t op2) { if (logText) log(name, dst, src); @@ -313,7 +376,7 @@ void AssemblyBuilderA64::placeSR2(const char* name, RegisterA64 dst, RegisterA64 uint32_t sf = (dst.kind == KindA64::x) ? 0x80000000 : 0; - place(dst.index | (0x1f << 5) | (src.index << 16) | (op << 24) | sf); + place(dst.index | (0x1f << 5) | (src.index << 16) | (op2 << 21) | (op << 24) | sf); commit(); } @@ -336,10 +399,10 @@ void AssemblyBuilderA64::placeR1(const char* name, RegisterA64 dst, RegisterA64 if (logText) log(name, dst, src); - LUAU_ASSERT(dst.kind == KindA64::w || dst.kind == KindA64::x); - LUAU_ASSERT(dst.kind == src.kind); + LUAU_ASSERT(dst.kind == KindA64::w || dst.kind == KindA64::x || dst == sp); + LUAU_ASSERT(dst.kind == src.kind || (dst.kind == KindA64::x && src == sp) || (dst == sp && src.kind == KindA64::x)); - uint32_t sf = (dst.kind == KindA64::x) ? 0x80000000 : 0; + uint32_t sf = (dst.kind != KindA64::w) ? 0x80000000 : 0; place(dst.index | (src.index << 5) | (op << 10) | sf); commit(); @@ -350,11 +413,11 @@ void AssemblyBuilderA64::placeI12(const char* name, RegisterA64 dst, RegisterA64 if (logText) log(name, dst, src1, src2); - LUAU_ASSERT(dst.kind == KindA64::w || dst.kind == KindA64::x); - LUAU_ASSERT(dst.kind == src1.kind); + LUAU_ASSERT(dst.kind == KindA64::w || dst.kind == KindA64::x || dst == sp); + LUAU_ASSERT(dst.kind == src1.kind || (dst.kind == KindA64::x && src1 == sp) || (dst == sp && src1.kind == KindA64::x)); LUAU_ASSERT(src2 >= 0 && src2 < (1 << 12)); - uint32_t sf = (dst.kind == KindA64::x) ? 0x80000000 : 0; + uint32_t sf = (dst.kind != KindA64::w) ? 0x80000000 : 0; place(dst.index | (src1.index << 5) | (src2 << 10) | (op << 24) | sf); commit(); @@ -383,8 +446,12 @@ void AssemblyBuilderA64::placeA(const char* name, RegisterA64 dst, AddressA64 sr switch (src.kind) { case AddressKindA64::imm: - LUAU_ASSERT(src.data % (1 << size) == 0); - place(dst.index | (src.base.index << 5) | ((src.data >> size) << 10) | (op << 22) | (1 << 24) | (size << 30)); + if (src.data >= 0 && src.data % (1 << size) == 0) + place(dst.index | (src.base.index << 5) | ((src.data >> size) << 10) | (op << 22) | (1 << 24) | (size << 30)); + else if (src.data >= -256 && src.data <= 255) + place(dst.index | (src.base.index << 5) | ((src.data & ((1 << 9) - 1)) << 12) | (op << 22) | (size << 30)); + else + LUAU_ASSERT(!"Unable to encode large immediate offset"); break; case AddressKindA64::reg: place(dst.index | (src.base.index << 5) | (0b10 << 10) | (0b011 << 13) | (src.offset.index << 16) | (1 << 21) | (op << 22) | (size << 30)); @@ -396,27 +463,49 @@ void AssemblyBuilderA64::placeA(const char* name, RegisterA64 dst, AddressA64 sr void AssemblyBuilderA64::placeBC(const char* name, Label& label, uint8_t op, uint8_t cond) { - placeLabel(label); + place(cond | (op << 24)); + commit(); + + patchLabel(label); if (logText) log(name, label); - - place(cond | (op << 24)); - commit(); } -void AssemblyBuilderA64::placeBR(const char* name, Label& label, uint8_t op, RegisterA64 cond) +void AssemblyBuilderA64::placeBCR(const char* name, Label& label, uint8_t op, RegisterA64 cond) { - placeLabel(label); + LUAU_ASSERT(cond.kind == KindA64::w || cond.kind == KindA64::x); + + uint32_t sf = (cond.kind == KindA64::x) ? 0x80000000 : 0; + + place(cond.index | (op << 24) | sf); + commit(); + + patchLabel(label); if (logText) log(name, cond, label); +} - LUAU_ASSERT(cond.kind == KindA64::w || cond.kind == KindA64::x); +void AssemblyBuilderA64::placeBR(const char* name, RegisterA64 src, uint32_t op) +{ + if (logText) + log(name, src); - uint32_t sf = (cond.kind == KindA64::x) ? 0x80000000 : 0; + LUAU_ASSERT(src.kind == KindA64::x); - place(cond.index | (op << 24) | sf); + place((src.index << 5) | (op << 10)); + commit(); +} + +void AssemblyBuilderA64::placeADR(const char* name, RegisterA64 dst, uint8_t op) +{ + if (logText) + log(name, dst); + + LUAU_ASSERT(dst.kind == KindA64::x); + + place(dst.index | (op << 24)); commit(); } @@ -426,8 +515,10 @@ void AssemblyBuilderA64::place(uint32_t word) *codePos++ = word; } -void AssemblyBuilderA64::placeLabel(Label& label) +void AssemblyBuilderA64::patchLabel(Label& label) { + uint32_t location = getCodeSize() - 1; + if (label.location == ~0u) { if (label.id == 0) @@ -436,18 +527,26 @@ void AssemblyBuilderA64::placeLabel(Label& label) labelLocations.push_back(~0u); } - pendingLabels.push_back({label.id, getCodeSize()}); + pendingLabels.push_back({label.id, location}); } else { - // note: if label has an assigned location we can in theory avoid patching it later, but - // we need to handle potential overflow of 19-bit offsets - LUAU_ASSERT(label.id != 0); - labelLocations[label.id - 1] = label.location; - pendingLabels.push_back({label.id, getCodeSize()}); + int value = int(label.location) - int(location); + + patchImm19(location, value); } } +void AssemblyBuilderA64::patchImm19(uint32_t location, int value) +{ + // imm19 encoding word offset, at bit offset 5 + // note that 18 bits of word offsets = 20 bits of byte offsets = +-1MB + if (value > -(1 << 18) && value < (1 << 18)) + code[location] |= (value & ((1 << 19) - 1)) << 5; + else + overflowed = true; +} + void AssemblyBuilderA64::commit() { LUAU_ASSERT(codePos <= codeEnd); @@ -491,8 +590,11 @@ void AssemblyBuilderA64::log(const char* opcode) void AssemblyBuilderA64::log(const char* opcode, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, int shift) { logAppend(" %-12s", opcode); - log(dst); - text.append(","); + if (dst != xzr && dst != wzr) + { + log(dst); + text.append(","); + } log(src1); text.append(","); log(src2); @@ -504,8 +606,11 @@ void AssemblyBuilderA64::log(const char* opcode, RegisterA64 dst, RegisterA64 sr void AssemblyBuilderA64::log(const char* opcode, RegisterA64 dst, RegisterA64 src1, int src2) { logAppend(" %-12s", opcode); - log(dst); - text.append(","); + if (dst != xzr && dst != wzr) + { + log(dst); + text.append(","); + } log(src1); text.append(","); logAppend("#%d", src2); @@ -549,6 +654,13 @@ void AssemblyBuilderA64::log(const char* opcode, RegisterA64 src, Label label) logAppend(".L%d\n", label.id); } +void AssemblyBuilderA64::log(const char* opcode, RegisterA64 src) +{ + logAppend(" %-12s", opcode); + log(src); + text.append("\n"); +} + void AssemblyBuilderA64::log(const char* opcode, Label label) { logAppend(" %-12s.L%d\n", opcode, label.id); @@ -565,20 +677,24 @@ void AssemblyBuilderA64::log(RegisterA64 reg) { case KindA64::w: if (reg.index == 31) - logAppend("wzr"); + text.append("wzr"); else logAppend("w%d", reg.index); break; case KindA64::x: if (reg.index == 31) - logAppend("xzr"); + text.append("xzr"); else logAppend("x%d", reg.index); break; case KindA64::none: - LUAU_ASSERT(!"Unexpected register kind"); + if (reg.index == 31) + text.append("sp"); + else + LUAU_ASSERT(!"Unexpected register kind"); + break; } } diff --git a/CodeGen/src/CodeGenX64.cpp b/CodeGen/src/CodeGenX64.cpp index 3074cce2a..b23d2b38c 100644 --- a/CodeGen/src/CodeGenX64.cpp +++ b/CodeGen/src/CodeGenX64.cpp @@ -14,15 +14,15 @@ * Each line is 8 bytes, stack grows downwards. * * | ... previous frames ... - * | rdx home space | (saved only on windows) - * | rcx home space | (saved only on windows) + * | rdx home space | (unused) + * | rcx home space | (unused) * | return address | - * | ... saved non-volatile registers ... + * | ... saved non-volatile registers ... <-- rsp + kStackSize + kLocalsSize * | unused | for 16 byte alignment of the stack * | sCode | - * | sClosure | <-- rbp points here - * | argument 6 | - * | argument 5 | + * | sClosure | <-- rsp + kStackSize + * | argument 6 | <-- rsp + 40 + * | argument 5 | <-- rsp + 32 * | r9 home space | * | r8 home space | * | rdx home space | @@ -48,28 +48,18 @@ bool initEntryFunction(NativeState& data) unwind.start(); - if (build.abi == ABIX64::Windows) - { - // Place arguments in home space - build.mov(qword[rsp + 16], rArg2); - unwind.spill(16, rArg2); - build.mov(qword[rsp + 8], rArg1); - unwind.spill(8, rArg1); - - // Save non-volatile registers that are specific to Windows x64 ABI - build.push(rdi); - unwind.save(rdi); - build.push(rsi); - unwind.save(rsi); + // Save common non-volatile registers + build.push(rbp); + unwind.save(rbp); - // Once we start using non-volatile SIMD registers, we will save those here + if (build.abi == ABIX64::SystemV) + { + build.mov(rbp, rsp); + unwind.setupFrameReg(rbp, 0); } - // Save common non-volatile registers build.push(rbx); unwind.save(rbx); - build.push(rbp); - unwind.save(rbp); build.push(r12); unwind.save(r12); build.push(r13); @@ -79,16 +69,20 @@ bool initEntryFunction(NativeState& data) build.push(r15); unwind.save(r15); - int stacksize = 32 + 16; // 4 home locations for registers, 16 bytes for additional function call arguments - int localssize = 24; // 3 local pointers that also correctly align the stack + if (build.abi == ABIX64::Windows) + { + // Save non-volatile registers that are specific to Windows x64 ABI + build.push(rdi); + unwind.save(rdi); + build.push(rsi); + unwind.save(rsi); - // Allocate stack space (reg home area + local data) - build.sub(rsp, stacksize + localssize); - unwind.allocStack(stacksize + localssize); + // TODO: once we start using non-volatile SIMD registers on Windows, we will save those here + } - // Setup frame pointer - build.lea(rbp, addr[rsp + stacksize]); - unwind.setupFrameReg(rbp, stacksize); + // Allocate stack space (reg home area + local data) + build.sub(rsp, kStackSize + kLocalsSize); + unwind.allocStack(kStackSize + kLocalsSize); unwind.finish(); @@ -113,13 +107,7 @@ bool initEntryFunction(NativeState& data) Label returnOff = build.setLabel(); // Cleanup and exit - build.lea(rsp, addr[rbp + localssize]); - build.pop(r15); - build.pop(r14); - build.pop(r13); - build.pop(r12); - build.pop(rbp); - build.pop(rbx); + build.add(rsp, kStackSize + kLocalsSize); if (build.abi == ABIX64::Windows) { @@ -127,6 +115,12 @@ bool initEntryFunction(NativeState& data) build.pop(rdi); } + build.pop(r15); + build.pop(r14); + build.pop(r13); + build.pop(r12); + build.pop(rbx); + build.pop(rbp); build.ret(); build.finalize(); diff --git a/CodeGen/src/EmitCommonX64.h b/CodeGen/src/EmitCommonX64.h index 071ef6afd..615448551 100644 --- a/CodeGen/src/EmitCommonX64.h +++ b/CodeGen/src/EmitCommonX64.h @@ -32,8 +32,12 @@ constexpr RegisterX64 rNativeContext = r13; // NativeContext* context constexpr RegisterX64 rConstants = r12; // TValue* k // Native code is as stackless as the interpreter, so we can place some data on the stack once and have it accessible at any point -constexpr OperandX64 sClosure = qword[rbp + 0]; // Closure* cl -constexpr OperandX64 sCode = qword[rbp + 8]; // Instruction* code +// See CodeGenX64.cpp for layout +constexpr unsigned kStackSize = 32 + 16; // 4 home locations for registers, 16 bytes for additional function call arguments +constexpr unsigned kLocalsSize = 24; // 3 extra slots for our custom locals (also aligns the stack to 16 byte boundary) + +constexpr OperandX64 sClosure = qword[rsp + kStackSize + 0]; // Closure* cl +constexpr OperandX64 sCode = qword[rsp + kStackSize + 8]; // Instruction* code // TODO: These should be replaced with a portable call function that checks the ABI at runtime and reorders moves accordingly to avoid conflicts #if defined(_WIN32) diff --git a/CodeGen/src/UnwindBuilderDwarf2.cpp b/CodeGen/src/UnwindBuilderDwarf2.cpp index 8d06864ee..7dc86d3ec 100644 --- a/CodeGen/src/UnwindBuilderDwarf2.cpp +++ b/CodeGen/src/UnwindBuilderDwarf2.cpp @@ -13,11 +13,10 @@ // https://refspecs.linuxbase.org/elf/x86_64-abi-0.99.pdf [System V Application Binary Interface (AMD64 Architecture Processor Supplement)] // Interaction between Dwarf2 and System V ABI can be found in sections '3.6.2 DWARF Register Number Mapping' and '4.2.4 EH_FRAME sections' -// Call frame instruction opcodes +// Call frame instruction opcodes (Dwarf2, page 78, ch. 7.23 figure 37) #define DW_CFA_advance_loc 0x40 #define DW_CFA_offset 0x80 #define DW_CFA_restore 0xc0 -#define DW_CFA_nop 0x00 #define DW_CFA_set_loc 0x01 #define DW_CFA_advance_loc1 0x02 #define DW_CFA_advance_loc2 0x03 @@ -33,17 +32,11 @@ #define DW_CFA_def_cfa_register 0x0d #define DW_CFA_def_cfa_offset 0x0e #define DW_CFA_def_cfa_expression 0x0f -#define DW_CFA_expression 0x10 -#define DW_CFA_offset_extended_sf 0x11 -#define DW_CFA_def_cfa_sf 0x12 -#define DW_CFA_def_cfa_offset_sf 0x13 -#define DW_CFA_val_offset 0x14 -#define DW_CFA_val_offset_sf 0x15 -#define DW_CFA_val_expression 0x16 +#define DW_CFA_nop 0x00 #define DW_CFA_lo_user 0x1c #define DW_CFA_hi_user 0x3f -// Register numbers for x64 +// Register numbers for x64 (System V ABI, page 57, ch. 3.7, figure 3.36) #define DW_REG_RAX 0 #define DW_REG_RDX 1 #define DW_REG_RCX 2 @@ -197,7 +190,12 @@ void UnwindBuilderDwarf2::allocStack(int size) void UnwindBuilderDwarf2::setupFrameReg(RegisterX64 reg, int espOffset) { - // Not required for unwinding + if (espOffset != 0) + pos = advanceLocation(pos, 5); // REX.W lea rbp, [rsp + imm8] + else + pos = advanceLocation(pos, 3); // REX.W mov rbp, rsp + + // Cfa is based on rsp, so no additonal commands are required } void UnwindBuilderDwarf2::finish() diff --git a/CodeGen/src/UnwindBuilderWin.cpp b/CodeGen/src/UnwindBuilderWin.cpp index 1b3279e82..13e92ab0a 100644 --- a/CodeGen/src/UnwindBuilderWin.cpp +++ b/CodeGen/src/UnwindBuilderWin.cpp @@ -77,7 +77,11 @@ void UnwindBuilderWin::setupFrameReg(RegisterX64 reg, int espOffset) frameReg = reg; frameRegOffset = uint8_t(espOffset / 16); - prologSize += 5; // REX.W lea rbp, [rsp + imm8] + if (espOffset != 0) + prologSize += 5; // REX.W lea rbp, [rsp + imm8] + else + prologSize += 3; // REX.W mov rbp, rsp + unwindCodes.push_back({prologSize, UWOP_SET_FPREG, frameRegOffset}); } diff --git a/Common/include/Luau/ExperimentalFlags.h b/Common/include/Luau/ExperimentalFlags.h index 17fe26b18..15db9ea38 100644 --- a/Common/include/Luau/ExperimentalFlags.h +++ b/Common/include/Luau/ExperimentalFlags.h @@ -13,6 +13,8 @@ inline bool isFlagExperimental(const char* flag) static const char* kList[] = { "LuauInterpolatedStringBaseSupport", "LuauInstantiateInSubtyping", // requires some fixes to lua-apps code + "LuauOptionalNextKey", // waiting for a fix to land in lua-apps + "LuauTryhardAnd", // waiting for a fix in graphql-lua -> apollo-client-lia -> lua-apps // makes sure we always have at least one entry nullptr, }; diff --git a/tests/AssemblyBuilderA64.test.cpp b/tests/AssemblyBuilderA64.test.cpp index d1f29c23c..d808ac491 100644 --- a/tests/AssemblyBuilderA64.test.cpp +++ b/tests/AssemblyBuilderA64.test.cpp @@ -69,6 +69,7 @@ TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "Unary") { SINGLE_COMPARE(neg(x0, x1), 0xCB0103E0); SINGLE_COMPARE(neg(w0, w1), 0x4B0103E0); + SINGLE_COMPARE(mvn(x0, x1), 0xAA2103E0); SINGLE_COMPARE(clz(x0, x1), 0xDAC01020); SINGLE_COMPARE(clz(w0, w1), 0x5AC01020); @@ -91,19 +92,22 @@ TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "Binary") SINGLE_COMPARE(lsr(x0, x1, x2), 0x9AC22420); SINGLE_COMPARE(asr(x0, x1, x2), 0x9AC22820); SINGLE_COMPARE(ror(x0, x1, x2), 0x9AC22C20); + SINGLE_COMPARE(cmp(x0, x1), 0xEB01001F); // reg, imm SINGLE_COMPARE(add(x3, x7, 78), 0x910138E3); SINGLE_COMPARE(add(w3, w7, 78), 0x110138E3); SINGLE_COMPARE(sub(w3, w7, 78), 0x510138E3); + SINGLE_COMPARE(cmp(w0, 42), 0x7100A81F); } TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "Loads") { // address forms SINGLE_COMPARE(ldr(x0, x1), 0xF9400020); - SINGLE_COMPARE(ldr(x0, AddressA64(x1, 8)), 0xF9400420); - SINGLE_COMPARE(ldr(x0, AddressA64(x1, x7)), 0xF8676820); + SINGLE_COMPARE(ldr(x0, mem(x1, 8)), 0xF9400420); + SINGLE_COMPARE(ldr(x0, mem(x1, x7)), 0xF8676820); + SINGLE_COMPARE(ldr(x0, mem(x1, -7)), 0xF85F9020); // load sizes SINGLE_COMPARE(ldr(x0, x1), 0xF9400020); @@ -121,8 +125,9 @@ TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "Stores") { // address forms SINGLE_COMPARE(str(x0, x1), 0xF9000020); - SINGLE_COMPARE(str(x0, AddressA64(x1, 8)), 0xF9000420); - SINGLE_COMPARE(str(x0, AddressA64(x1, x7)), 0xF8276820); + SINGLE_COMPARE(str(x0, mem(x1, 8)), 0xF9000420); + SINGLE_COMPARE(str(x0, mem(x1, x7)), 0xF8276820); + SINGLE_COMPARE(strh(w0, mem(x1, -7)), 0x781F9020); // store sizes SINGLE_COMPARE(str(x0, x1), 0xF9000020); @@ -169,26 +174,69 @@ TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "ControlFlow") build.cbz(x0, skip); build.cbnz(x0, skip); build.setLabel(skip); + build.b(skip); }, - {0x54000060, 0xB4000040, 0xB5000020})); + {0x54000060, 0xB4000040, 0xB5000020, 0x5400000E})); // Basic control flow + SINGLE_COMPARE(br(x0), 0xD61F0000); + SINGLE_COMPARE(blr(x0), 0xD63F0000); SINGLE_COMPARE(ret(), 0xD65F03C0); } +TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "StackOps") +{ + SINGLE_COMPARE(mov(x0, sp), 0x910003E0); + SINGLE_COMPARE(mov(sp, x0), 0x9100001F); + + SINGLE_COMPARE(add(sp, sp, 4), 0x910013FF); + SINGLE_COMPARE(sub(sp, sp, 4), 0xD10013FF); + + SINGLE_COMPARE(add(x0, sp, 4), 0x910013E0); + SINGLE_COMPARE(sub(sp, x0, 4), 0xD100101F); + + SINGLE_COMPARE(ldr(x0, mem(sp, 8)), 0xF94007E0); + SINGLE_COMPARE(str(x0, mem(sp, 8)), 0xF90007E0); +} + +TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "Constants") +{ + // clang-format off + CHECK(check( + [](AssemblyBuilderA64& build) { + char arr[12] = "hello world"; + build.adr(x0, arr, 12); + build.adr(x0, uint64_t(0x1234567887654321)); + build.adr(x0, 1.0); + }, + { + 0x10ffffa0, 0x10ffff20, 0x10fffec0 + }, + { + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xf0, 0x3f, + 0x21, 0x43, 0x65, 0x87, 0x78, 0x56, 0x34, 0x12, + 0x00, 0x00, 0x00, 0x00, // 4b padding to align double + 'h', 'e', 'l', 'l', 'o', ' ', 'w', 'o', 'r', 'l', 'd', 0x0, + })); + // clang-format on +} + TEST_CASE("LogTest") { AssemblyBuilderA64 build(/* logText= */ true); + build.add(sp, sp, 4); build.add(w0, w1, w2); build.add(x0, x1, x2, 2); build.add(w7, w8, 5); build.add(x7, x8, 5); build.ldr(x7, x8); - build.ldr(x7, AddressA64(x8, 8)); - build.ldr(x7, AddressA64(x8, x9)); + build.ldr(x7, mem(x8, 8)); + build.ldr(x7, mem(x8, x9)); build.mov(x1, x2); build.movk(x1, 42, 16); + build.cmp(x1, x2); + build.blr(x0); Label l; build.b(ConditionA64::Plus, l); @@ -200,6 +248,7 @@ TEST_CASE("LogTest") build.finalize(); std::string expected = R"( + add sp,sp,#4 add w0,w1,w2 add x0,x1,x2 LSL #2 add w7,w8,#5 @@ -209,6 +258,8 @@ TEST_CASE("LogTest") ldr x7,[x8,x9] mov x1,x2 movk x1,#42 LSL #16 + cmp x1,x2 + blr x0 b.pl .L1 cbz x7,.L1 .L1: diff --git a/tests/CodeAllocator.test.cpp b/tests/CodeAllocator.test.cpp index 2e9a4b376..65b485a7f 100644 --- a/tests/CodeAllocator.test.cpp +++ b/tests/CodeAllocator.test.cpp @@ -185,7 +185,7 @@ TEST_CASE("Dwarf2UnwindCodesX64") 0x00, 0x00, 0x00, 0x00, 0x00, 0x4c, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x02, 0x0e, 0x10, 0x85, 0x02, 0x02, 0x02, 0x0e, 0x18, 0x84, 0x03, 0x02, 0x02, 0x0e, 0x20, 0x83, 0x04, 0x02, 0x02, 0x0e, 0x28, 0x86, 0x05, 0x02, 0x02, 0x0e, 0x30, 0x8c, 0x06, 0x02, 0x02, 0x0e, 0x38, 0x8d, 0x07, 0x02, 0x02, 0x0e, 0x40, - 0x8e, 0x08, 0x02, 0x02, 0x0e, 0x48, 0x8f, 0x09, 0x02, 0x04, 0x0e, 0x90, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}; + 0x8e, 0x08, 0x02, 0x02, 0x0e, 0x48, 0x8f, 0x09, 0x02, 0x04, 0x0e, 0x90, 0x01, 0x02, 0x05, 0x00, 0x00, 0x00, 0x00, 0x00}; REQUIRE(data.size() == expected.size()); CHECK(memcmp(data.data(), expected.data(), expected.size()) == 0); @@ -446,7 +446,12 @@ TEST_CASE("GeneratedCodeExecutionA64") build.mov(x1, 0); // doesn't execute due to cbnz above build.setLabel(skip); - build.add(x1, x1, 1); + uint8_t one = 1; + build.adr(x2, &one, 1); + build.ldrb(w2, x2); + build.sub(x1, x1, x2); + + build.add(x1, x1, 2); build.add(x0, x0, x1, /* LSL */ 1); build.ret(); diff --git a/tests/ConstraintGraphBuilderFixture.cpp b/tests/ConstraintGraphBuilderFixture.cpp index d011719ed..30e1b2e6e 100644 --- a/tests/ConstraintGraphBuilderFixture.cpp +++ b/tests/ConstraintGraphBuilderFixture.cpp @@ -21,13 +21,13 @@ void ConstraintGraphBuilderFixture::generateConstraints(const std::string& code) frontend.getGlobalScope(), &logger, NotNull{dfg.get()}); cgb->visit(root); rootScope = cgb->rootScope; - constraints = Luau::collectConstraints(NotNull{cgb->rootScope}); + constraints = Luau::borrowConstraints(cgb->constraints); } void ConstraintGraphBuilderFixture::solve(const std::string& code) { generateConstraints(code); - ConstraintSolver cs{NotNull{&normalizer}, NotNull{rootScope}, "MainModule", NotNull(&moduleResolver), {}, &logger}; + ConstraintSolver cs{NotNull{&normalizer}, NotNull{rootScope}, constraints, "MainModule", NotNull(&moduleResolver), {}, &logger}; cs.run(); } diff --git a/tests/Frontend.test.cpp b/tests/Frontend.test.cpp index df0abdc96..4e72dd4e7 100644 --- a/tests/Frontend.test.cpp +++ b/tests/Frontend.test.cpp @@ -1054,10 +1054,6 @@ TEST_CASE("check_without_builtin_next") TEST_CASE_FIXTURE(BuiltinsFixture, "reexport_cyclic_type") { - ScopedFastFlag sff[] = { - {"LuauForceExportSurfacesToBeNormal", true}, - }; - fileResolver.source["Module/A"] = R"( type F = (set: G) -> () @@ -1089,10 +1085,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "reexport_cyclic_type") TEST_CASE_FIXTURE(BuiltinsFixture, "reexport_type_alias") { - ScopedFastFlag sff[] = { - {"LuauForceExportSurfacesToBeNormal", true}, - }; - fileResolver.source["Module/A"] = R"( type KeyOfTestEvents = "test-file-start" | "test-file-success" | "test-file-failure" | "test-case-result" type MyAny = any diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index 77cf6130a..c0989a2e7 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -2823,4 +2823,21 @@ TEST_CASE_FIXTURE(Fixture, "get_a_nice_error_when_there_is_no_comma_after_last_t CHECK(table->items.size == 1); } +TEST_CASE_FIXTURE(Fixture, "missing_default_type_pack_argument_after_variadic_type_parameter") +{ + ScopedFastFlag sff{"LuauParserErrorsOnMissingDefaultTypePackArgument", true}; + + ParseResult result = tryParse(R"( + type Foo = nil + )"); + + REQUIRE_EQ(2, result.errors.size()); + + CHECK_EQ(Location{{1, 23}, {1, 25}}, result.errors[0].getLocation()); + CHECK_EQ("Expected type, got '>'", result.errors[0].getMessage()); + + CHECK_EQ(Location{{1, 23}, {1, 24}}, result.errors[1].getLocation()); + CHECK_EQ("Expected type pack after '=', got type", result.errors[1].getMessage()); +} + TEST_SUITE_END(); diff --git a/tests/Repl.test.cpp b/tests/Repl.test.cpp index 18c243b0a..c22d464ee 100644 --- a/tests/Repl.test.cpp +++ b/tests/Repl.test.cpp @@ -404,3 +404,20 @@ t60 = makeChainedTable(60) } TEST_SUITE_END(); + +TEST_SUITE_BEGIN("RegressionTests"); + +TEST_CASE_FIXTURE(ReplFixture, "InfiniteRecursion") +{ + // If the infinite recrusion is not caught, test will fail + runCode(L, R"( +local NewProxyOne = newproxy(true) +local MetaTableOne = getmetatable(NewProxyOne) +MetaTableOne.__index = function() + return NewProxyOne.Game +end +print(NewProxyOne.HelloICauseACrash) +)"); +} + +TEST_SUITE_END(); diff --git a/tests/ToString.test.cpp b/tests/ToString.test.cpp index a510f914a..8bb1fbaf2 100644 --- a/tests/ToString.test.cpp +++ b/tests/ToString.test.cpp @@ -10,7 +10,6 @@ using namespace Luau; LUAU_FASTFLAG(LuauRecursiveTypeParameterRestriction); -LUAU_FASTFLAG(LuauFixNameMaps); LUAU_FASTFLAG(LuauFunctionReturnStringificationFixup); TEST_SUITE_BEGIN("ToString"); @@ -266,11 +265,23 @@ TEST_CASE_FIXTURE(Fixture, "quit_stringifying_type_when_length_is_exceeded") ToStringOptions o; o.exhaustive = false; - o.maxTypeLength = 40; - CHECK_EQ(toString(requireType("f0"), o), "() -> ()"); - CHECK_EQ(toString(requireType("f1"), o), "(() -> ()) -> () -> ()"); - CHECK_EQ(toString(requireType("f2"), o), "((() -> ()) -> () -> ()) -> (() -> ()) -> ... *TRUNCATED*"); - CHECK_EQ(toString(requireType("f3"), o), "(((() -> ()) -> () -> ()) -> (() -> ()) -> ... *TRUNCATED*"); + + if (FFlag::DebugLuauDeferredConstraintResolution) + { + o.maxTypeLength = 30; + CHECK_EQ(toString(requireType("f0"), o), "() -> ()"); + CHECK_EQ(toString(requireType("f1"), o), "(a) -> () -> ()"); + CHECK_EQ(toString(requireType("f2"), o), "(b) -> (a) -> () -> ()"); + CHECK_EQ(toString(requireType("f3"), o), "(c) -> (b) -> (a) -> (... *TRUNCATED*"); + } + else + { + o.maxTypeLength = 40; + CHECK_EQ(toString(requireType("f0"), o), "() -> ()"); + CHECK_EQ(toString(requireType("f1"), o), "(() -> ()) -> () -> ()"); + CHECK_EQ(toString(requireType("f2"), o), "((() -> ()) -> () -> ()) -> (() -> ()) -> ... *TRUNCATED*"); + CHECK_EQ(toString(requireType("f3"), o), "(((() -> ()) -> () -> ()) -> (() -> ()) -> ... *TRUNCATED*"); + } } TEST_CASE_FIXTURE(Fixture, "stringifying_type_is_still_capped_when_exhaustive") @@ -285,11 +296,22 @@ TEST_CASE_FIXTURE(Fixture, "stringifying_type_is_still_capped_when_exhaustive") ToStringOptions o; o.exhaustive = true; - o.maxTypeLength = 40; - CHECK_EQ(toString(requireType("f0"), o), "() -> ()"); - CHECK_EQ(toString(requireType("f1"), o), "(() -> ()) -> () -> ()"); - CHECK_EQ(toString(requireType("f2"), o), "((() -> ()) -> () -> ()) -> (() -> ()) -> ... *TRUNCATED*"); - CHECK_EQ(toString(requireType("f3"), o), "(((() -> ()) -> () -> ()) -> (() -> ()) -> ... *TRUNCATED*"); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + o.maxTypeLength = 30; + CHECK_EQ(toString(requireType("f0"), o), "() -> ()"); + CHECK_EQ(toString(requireType("f1"), o), "(a) -> () -> ()"); + CHECK_EQ(toString(requireType("f2"), o), "(b) -> (a) -> () -> ()"); + CHECK_EQ(toString(requireType("f3"), o), "(c) -> (b) -> (a) -> (... *TRUNCATED*"); + } + else + { + o.maxTypeLength = 40; + CHECK_EQ(toString(requireType("f0"), o), "() -> ()"); + CHECK_EQ(toString(requireType("f1"), o), "(() -> ()) -> () -> ()"); + CHECK_EQ(toString(requireType("f2"), o), "((() -> ()) -> () -> ()) -> (() -> ()) -> ... *TRUNCATED*"); + CHECK_EQ(toString(requireType("f3"), o), "(((() -> ()) -> () -> ()) -> (() -> ()) -> ... *TRUNCATED*"); + } } TEST_CASE_FIXTURE(Fixture, "stringifying_table_type_correctly_use_matching_table_state_braces") @@ -423,28 +445,19 @@ TEST_CASE_FIXTURE(Fixture, "toStringDetailed") TypeId id3Type = requireType("id3"); ToStringResult nameData = toStringDetailed(id3Type, opts); - if (FFlag::LuauFixNameMaps) - REQUIRE(3 == opts.nameMap.typeVars.size()); - else - REQUIRE_EQ(3, nameData.DEPRECATED_nameMap.typeVars.size()); + REQUIRE(3 == opts.nameMap.typeVars.size()); REQUIRE_EQ("(a, b, c) -> (a, b, c)", nameData.name); - ToStringOptions opts2; // TODO: delete opts2 when clipping FFlag::LuauFixNameMaps - if (FFlag::LuauFixNameMaps) - opts2.nameMap = std::move(opts.nameMap); - else - opts2.DEPRECATED_nameMap = std::move(nameData.DEPRECATED_nameMap); - const FunctionTypeVar* ftv = get(follow(id3Type)); REQUIRE(ftv != nullptr); auto params = flatten(ftv->argTypes).first; REQUIRE(3 == params.size()); - CHECK("a" == toString(params[0], opts2)); - CHECK("b" == toString(params[1], opts2)); - CHECK("c" == toString(params[2], opts2)); + CHECK("a" == toString(params[0], opts)); + CHECK("b" == toString(params[1], opts)); + CHECK("c" == toString(params[2], opts)); } TEST_CASE_FIXTURE(BuiltinsFixture, "toStringDetailed2") @@ -471,13 +484,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "toStringDetailed2") TypeId tType = requireType("inst"); ToStringResult r = toStringDetailed(tType, opts); CHECK_EQ("{ @metatable { __index: { @metatable {| __index: base |}, child } }, inst }", r.name); - if (FFlag::LuauFixNameMaps) - CHECK(0 == opts.nameMap.typeVars.size()); - else - CHECK_EQ(0, r.DEPRECATED_nameMap.typeVars.size()); - - if (!FFlag::LuauFixNameMaps) - opts.DEPRECATED_nameMap = r.DEPRECATED_nameMap; + CHECK(0 == opts.nameMap.typeVars.size()); const MetatableTypeVar* tMeta = get(tType); REQUIRE(tMeta); @@ -502,8 +509,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "toStringDetailed2") REQUIRE(tMeta6->props.count("two") > 0); ToStringResult oneResult = toStringDetailed(tMeta5->props["one"].type, opts); - if (!FFlag::LuauFixNameMaps) - opts.DEPRECATED_nameMap = oneResult.DEPRECATED_nameMap; std::string twoResult = toString(tMeta6->props["two"].type, opts); diff --git a/tests/TypeInfer.aliases.test.cpp b/tests/TypeInfer.aliases.test.cpp index 8c738b7d2..ef40e2783 100644 --- a/tests/TypeInfer.aliases.test.cpp +++ b/tests/TypeInfer.aliases.test.cpp @@ -8,6 +8,7 @@ using namespace Luau; LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) +LUAU_FASTFLAG(LuauNoMoreGlobalSingletonTypes) TEST_SUITE_BEGIN("TypeAliases"); @@ -509,11 +510,21 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "general_require_multi_assign") TEST_CASE_FIXTURE(BuiltinsFixture, "type_alias_import_mutation") { + ScopedFastFlag luauNewLibraryTypeNames{"LuauNewLibraryTypeNames", true}; + CheckResult result = check("type t10 = typeof(table)"); LUAU_REQUIRE_NO_ERRORS(result); TypeId ty = getGlobalBinding(frontend, "table"); - CHECK_EQ(toString(ty), "table"); + + if (FFlag::LuauNoMoreGlobalSingletonTypes) + { + CHECK_EQ(toString(ty), "typeof(table)"); + } + else + { + CHECK_EQ(toString(ty), "table"); + } const TableTypeVar* ttv = get(ty); REQUIRE(ttv); diff --git a/tests/TypeInfer.builtins.test.cpp b/tests/TypeInfer.builtins.test.cpp index 7c465c538..787aea9ae 100644 --- a/tests/TypeInfer.builtins.test.cpp +++ b/tests/TypeInfer.builtins.test.cpp @@ -57,7 +57,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "next_iterator_should_infer_types_and_type_ch local s = "foo" local t = { [s] = 1 } - local c: string, d: number = next(t) + local c: string?, d: number = next(t) )"); LUAU_REQUIRE_ERROR_COUNT(1, result); @@ -69,7 +69,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "pairs_iterator_should_infer_types_and_type_c type Map = { [K]: V } local map: Map = { ["foo"] = 1, ["bar"] = 2, ["baz"] = 3 } - local it: (Map, string | nil) -> (string, number), t: Map, i: nil = pairs(map) + local it: (Map, string | nil) -> (string?, number), t: Map, i: nil = pairs(map) )"); LUAU_REQUIRE_NO_ERRORS(result); @@ -81,7 +81,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "ipairs_iterator_should_infer_types_and_type_ type Map = { [K]: V } local array: Map = { "foo", "bar", "baz" } - local it: (Map, number) -> (number, string), t: Map, i: number = ipairs(array) + local it: (Map, number) -> (number?, string), t: Map, i: number = ipairs(array) )"); LUAU_REQUIRE_NO_ERRORS(result); diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index ddf733494..b306515a9 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -532,7 +532,7 @@ TEST_CASE_FIXTURE(Fixture, "infer_higher_order_function") REQUIRE_EQ(2, argVec.size()); const FunctionTypeVar* fType = get(follow(argVec[0])); - REQUIRE(fType != nullptr); + REQUIRE_MESSAGE(fType != nullptr, "Expected a function but got " << toString(argVec[0])); std::vector fArgs = flatten(fType->argTypes).first; diff --git a/tests/TypeInfer.operators.test.cpp b/tests/TypeInfer.operators.test.cpp index b2516f6d8..c4bbc2e11 100644 --- a/tests/TypeInfer.operators.test.cpp +++ b/tests/TypeInfer.operators.test.cpp @@ -50,14 +50,18 @@ TEST_CASE_FIXTURE(Fixture, "or_joins_types_with_no_superfluous_union") CHECK_EQ(*requireType("s"), *typeChecker.stringType); } -TEST_CASE_FIXTURE(Fixture, "and_adds_boolean") +TEST_CASE_FIXTURE(Fixture, "and_does_not_always_add_boolean") { + ScopedFastFlag sff[]{ + {"LuauTryhardAnd", true}, + }; + CheckResult result = check(R"( local s = "a" and 10 local x:boolean|number = s )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(toString(*requireType("s")), "boolean | number"); + CHECK_EQ(toString(*requireType("s")), "number"); } TEST_CASE_FIXTURE(Fixture, "and_adds_boolean_no_superfluous_union") @@ -971,4 +975,79 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "mm_comparisons_must_return_a_boolean") CHECK(toString(result.errors[1]) == "Metamethod '__lt' must return type 'boolean'"); } +TEST_CASE_FIXTURE(BuiltinsFixture, "reworked_and") +{ + ScopedFastFlag sff[]{ + {"LuauTryhardAnd", true}, + }; + + CheckResult result = check(R"( +local a: number? = 5 +local b: boolean = (a or 1) > 10 +local c -- free + +local x = a and 1 +local y = 'a' and 1 +local z = b and 1 +local w = c and 1 + )"); + + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK("number?" == toString(requireType("x"))); + CHECK("number" == toString(requireType("y"))); + CHECK("false | number" == toString(requireType("z"))); + CHECK("number" == toString(requireType("w"))); // Normalizer considers free & falsy == never + } + else + { + CHECK("number?" == toString(requireType("x"))); + CHECK("number" == toString(requireType("y"))); + CHECK("boolean | number" == toString(requireType("z"))); // 'false' widened to boolean + CHECK("(boolean | number)?" == toString(requireType("w"))); + } +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "reworked_or") +{ + ScopedFastFlag sff[]{ + {"LuauTryhardAnd", true}, + }; + + CheckResult result = check(R"( +local a: number | false = 5 +local b: number? = 6 +local c: boolean = true +local d: true = true +local e: false = false +local f: nil = false + +local a1 = a or 'a' +local b1 = b or 4 +local c1 = c or 'c' +local d1 = d or 'd' +local e1 = e or 'e' +local f1 = f or 'f' + )"); + + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK("number | string" == toString(requireType("a1"))); + CHECK("number" == toString(requireType("b1"))); + CHECK("string | true" == toString(requireType("c1"))); + CHECK("string | true" == toString(requireType("d1"))); + CHECK("string" == toString(requireType("e1"))); + CHECK("string" == toString(requireType("f1"))); + } + else + { + CHECK("number | string" == toString(requireType("a1"))); + CHECK("number" == toString(requireType("b1"))); + CHECK("boolean | string" == toString(requireType("c1"))); // 'true' widened to boolean + CHECK("boolean | string" == toString(requireType("d1"))); // 'true' widened to boolean + CHECK("string" == toString(requireType("e1"))); + CHECK("string" == toString(requireType("f1"))); + } +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index f6e60cdce..5688eaaa1 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -461,23 +461,27 @@ TEST_CASE_FIXTURE(Fixture, "dcr_can_partially_dispatch_a_constraint") LUAU_REQUIRE_NO_ERRORS(result); - // Solving this requires recognizing that we can partially solve the - // following constraint: + // Solving this requires recognizing that we can't dispatch a constraint + // like this without doing further work: // // (*blocked*) -> () <: (number) -> (b...) // - // The correct thing for us to do is to consider the constraint dispatched, - // but we need to also record a new constraint number <: *blocked* to finish - // the job later. + // We solve this by searching both types for BlockedTypeVars and block the + // constraint on any we find. It also gets the job done, but I'm worried + // about the efficiency of doing so many deep type traversals and it may + // make us more prone to getting stuck on constraint cycles. + // + // If this doesn't pan out, a possible solution is to go further down the + // path of supporting partial constraint dispatch. The way it would work is + // that we'd dispatch the above constraint by binding b... to (), but we + // would append a new constraint number <: *blocked* to the constraint set + // to be solved later. This should be faster and theoretically less prone + // to cyclic constraint dependencies. CHECK("(a, number) -> ()" == toString(requireType("prime_iter"))); } TEST_CASE_FIXTURE(Fixture, "free_options_cannot_be_unified_together") { - ScopedFastFlag sff[] = { - {"LuauFixNameMaps", true}, - }; - TypeArena arena; TypeId nilType = singletonTypes->nilType; @@ -522,8 +526,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_loop_with_zero_iterators") // Ideally, we would not try to export a function type with generic types from incorrect scope TEST_CASE_FIXTURE(BuiltinsFixture, "generic_type_leak_to_module_interface") { - ScopedFastFlag LuauAnyifyModuleReturnGenerics{"LuauAnyifyModuleReturnGenerics", true}; - fileResolver.source["game/A"] = R"( local wrapStrictTable @@ -563,8 +565,6 @@ return wrapStrictTable(Constants, "Constants") TEST_CASE_FIXTURE(BuiltinsFixture, "generic_type_leak_to_module_interface_variadic") { - ScopedFastFlag LuauAnyifyModuleReturnGenerics{"LuauAnyifyModuleReturnGenerics", true}; - fileResolver.source["game/A"] = R"( local wrapStrictTable diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index 26f23438f..66550be3e 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -35,7 +35,7 @@ std::optional> magicFunctionInstanceIsA( return WithPredicate{booleanPack, {IsAPredicate{std::move(*lvalue), expr.location, tfun->type}}}; } -struct RefinementClassFixture : Fixture +struct RefinementClassFixture : BuiltinsFixture { RefinementClassFixture() { @@ -320,7 +320,7 @@ TEST_CASE_FIXTURE(Fixture, "type_assertion_expr_carry_its_constraints") } } -TEST_CASE_FIXTURE(Fixture, "typeguard_in_if_condition_position") +TEST_CASE_FIXTURE(BuiltinsFixture, "typeguard_in_if_condition_position") { CheckResult result = check(R"( function f(s: any) @@ -332,7 +332,14 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_in_if_condition_position") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("number", toString(requireTypeAtPosition({3, 26}))); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("any & number", toString(requireTypeAtPosition({3, 26}))); + } + else + { + CHECK_EQ("number", toString(requireTypeAtPosition({3, 26}))); + } } TEST_CASE_FIXTURE(BuiltinsFixture, "typeguard_in_assert_position") @@ -344,10 +351,11 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "typeguard_in_assert_position") )"); LUAU_REQUIRE_NO_ERRORS(result); + REQUIRE_EQ("number", toString(requireType("b"))); } -TEST_CASE_FIXTURE(Fixture, "call_a_more_specific_function_using_typeguard") +TEST_CASE_FIXTURE(BuiltinsFixture, "call_an_incompatible_function_after_using_typeguard") { CheckResult result = check(R"( local function f(x: number) @@ -362,6 +370,7 @@ TEST_CASE_FIXTURE(Fixture, "call_a_more_specific_function_using_typeguard") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type 'string' could not be converted into 'number'", toString(result.errors[0])); } @@ -648,7 +657,7 @@ TEST_CASE_FIXTURE(Fixture, "narrow_property_of_a_bounded_variable") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "type_narrow_to_vector") +TEST_CASE_FIXTURE(BuiltinsFixture, "type_narrow_to_vector") { CheckResult result = check(R"( local function f(x) @@ -663,7 +672,7 @@ TEST_CASE_FIXTURE(Fixture, "type_narrow_to_vector") CHECK_EQ("*error-type*", toString(requireTypeAtPosition({3, 28}))); } -TEST_CASE_FIXTURE(Fixture, "nonoptional_type_can_narrow_to_nil_if_sense_is_true") +TEST_CASE_FIXTURE(BuiltinsFixture, "nonoptional_type_can_narrow_to_nil_if_sense_is_true") { CheckResult result = check(R"( local t = {"hello"} @@ -690,7 +699,7 @@ TEST_CASE_FIXTURE(Fixture, "nonoptional_type_can_narrow_to_nil_if_sense_is_true" CHECK_EQ("string", toString(requireTypeAtPosition({12, 24}))); // equivalent to type(v) ~= "nil" } -TEST_CASE_FIXTURE(Fixture, "typeguard_not_to_be_string") +TEST_CASE_FIXTURE(BuiltinsFixture, "typeguard_not_to_be_string") { CheckResult result = check(R"( local function f(x: string | number | boolean) @@ -704,11 +713,19 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_not_to_be_string") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("boolean | number", toString(requireTypeAtPosition({3, 28}))); // type(x) ~= "string" - CHECK_EQ("string", toString(requireTypeAtPosition({5, 28}))); // type(x) == "string" + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("(boolean | number | string) & ~string", toString(requireTypeAtPosition({3, 28}))); // type(x) ~= "string" + CHECK_EQ("(boolean | number | string) & string", toString(requireTypeAtPosition({5, 28}))); // type(x) == "string" + } + else + { + CHECK_EQ("boolean | number", toString(requireTypeAtPosition({3, 28}))); // type(x) ~= "string" + CHECK_EQ("string", toString(requireTypeAtPosition({5, 28}))); // type(x) == "string" + } } -TEST_CASE_FIXTURE(Fixture, "typeguard_narrows_for_table") +TEST_CASE_FIXTURE(BuiltinsFixture, "typeguard_narrows_for_table") { CheckResult result = check(R"( local function f(x: string | {x: number} | {y: boolean}) @@ -726,7 +743,7 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_narrows_for_table") CHECK_EQ("string", toString(requireTypeAtPosition({5, 28}))); // type(x) ~= "table" } -TEST_CASE_FIXTURE(Fixture, "typeguard_narrows_for_functions") +TEST_CASE_FIXTURE(BuiltinsFixture, "typeguard_narrows_for_functions") { CheckResult result = check(R"( local function weird(x: string | ((number) -> string)) @@ -740,11 +757,19 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_narrows_for_functions") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("(number) -> string", toString(requireTypeAtPosition({3, 28}))); // type(x) == "function" - CHECK_EQ("string", toString(requireTypeAtPosition({5, 28}))); // type(x) ~= "function" + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("(((number) -> string) | string) & function", toString(requireTypeAtPosition({3, 28}))); // type(x) == "function" + CHECK_EQ("(((number) -> string) | string) & ~function", toString(requireTypeAtPosition({5, 28}))); // type(x) ~= "function" + } + else + { + CHECK_EQ("(number) -> string", toString(requireTypeAtPosition({3, 28}))); // type(x) == "function" + CHECK_EQ("string", toString(requireTypeAtPosition({5, 28}))); // type(x) ~= "function" + } } -TEST_CASE_FIXTURE(Fixture, "type_guard_can_filter_for_intersection_of_tables") +TEST_CASE_FIXTURE(BuiltinsFixture, "type_guard_can_filter_for_intersection_of_tables") { CheckResult result = check(R"( type XYCoord = {x: number} & {y: number} @@ -763,7 +788,7 @@ TEST_CASE_FIXTURE(Fixture, "type_guard_can_filter_for_intersection_of_tables") CHECK_EQ("nil", toString(requireTypeAtPosition({6, 28}))); } -TEST_CASE_FIXTURE(Fixture, "type_guard_can_filter_for_overloaded_function") +TEST_CASE_FIXTURE(BuiltinsFixture, "type_guard_can_filter_for_overloaded_function") { CheckResult result = check(R"( type SomeOverloadedFunction = ((number) -> string) & ((string) -> number) @@ -778,8 +803,16 @@ TEST_CASE_FIXTURE(Fixture, "type_guard_can_filter_for_overloaded_function") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("((number) -> string) & ((string) -> number)", toString(requireTypeAtPosition({4, 28}))); - CHECK_EQ("nil", toString(requireTypeAtPosition({6, 28}))); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("((((number) -> string) & ((string) -> number))?) & function", toString(requireTypeAtPosition({4, 28}))); + CHECK_EQ("((((number) -> string) & ((string) -> number))?) & ~function", toString(requireTypeAtPosition({6, 28}))); + } + else + { + CHECK_EQ("((number) -> string) & ((string) -> number)", toString(requireTypeAtPosition({4, 28}))); + CHECK_EQ("nil", toString(requireTypeAtPosition({6, 28}))); + } } TEST_CASE_FIXTURE(BuiltinsFixture, "type_guard_narrowed_into_nothingness") @@ -884,7 +917,7 @@ TEST_CASE_FIXTURE(Fixture, "not_a_and_not_b2") } } -TEST_CASE_FIXTURE(Fixture, "either_number_or_string") +TEST_CASE_FIXTURE(BuiltinsFixture, "either_number_or_string") { CheckResult result = check(R"( local function f(x: any) @@ -896,7 +929,14 @@ TEST_CASE_FIXTURE(Fixture, "either_number_or_string") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("number | string", toString(requireTypeAtPosition({3, 28}))); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("(number | string) & any", toString(requireTypeAtPosition({3, 28}))); + } + else + { + CHECK_EQ("number | string", toString(requireTypeAtPosition({3, 28}))); + } } TEST_CASE_FIXTURE(Fixture, "not_t_or_some_prop_of_t") @@ -946,10 +986,17 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "merge_should_be_fully_agnostic_of_hashmap_or LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("string", toString(requireTypeAtPosition({6, 28}))); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("(string | {| x: string |}) & string", toString(requireTypeAtPosition({6, 28}))); + } + else + { + CHECK_EQ("string", toString(requireTypeAtPosition({6, 28}))); + } } -TEST_CASE_FIXTURE(Fixture, "refine_the_correct_types_opposite_of_when_a_is_not_number_or_string") +TEST_CASE_FIXTURE(BuiltinsFixture, "refine_the_correct_types_opposite_of_when_a_is_not_number_or_string") { CheckResult result = check(R"( local function f(a: string | number | boolean) @@ -963,8 +1010,16 @@ TEST_CASE_FIXTURE(Fixture, "refine_the_correct_types_opposite_of_when_a_is_not_n LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("boolean", toString(requireTypeAtPosition({3, 28}))); - CHECK_EQ("number | string", toString(requireTypeAtPosition({5, 28}))); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("(boolean | number | string) & ~number & ~string", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("(boolean | number | string) & (number | string)", toString(requireTypeAtPosition({5, 28}))); + } + else + { + CHECK_EQ("boolean", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("number | string", toString(requireTypeAtPosition({5, 28}))); + } } TEST_CASE_FIXTURE(BuiltinsFixture, "is_truthy_constraint_ifelse_expression") @@ -995,7 +1050,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "invert_is_truthy_constraint_ifelse_expressio CHECK_EQ("string", toString(requireTypeAtPosition({2, 50}))); } -TEST_CASE_FIXTURE(Fixture, "type_comparison_ifelse_expression") +TEST_CASE_FIXTURE(BuiltinsFixture, "type_comparison_ifelse_expression") { CheckResult result = check(R"( function returnOne(x) @@ -1027,7 +1082,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "correctly_lookup_a_shadowed_local_that_which CHECK_EQ("Type 'number' does not have key 'sub'", toString(result.errors[0])); } -TEST_CASE_FIXTURE(Fixture, "correctly_lookup_property_whose_base_was_previously_refined") +TEST_CASE_FIXTURE(BuiltinsFixture, "correctly_lookup_property_whose_base_was_previously_refined") { CheckResult result = check(R"( type T = {x: string | number} @@ -1246,8 +1301,16 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "typeguard_cast_instance_or_vector3_to LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("Vector3", toString(requireTypeAtPosition({3, 28}))); - CHECK_EQ("Instance", toString(requireTypeAtPosition({5, 28}))); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("(Instance | Vector3) & Vector3", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("(Instance | Vector3) & ~Vector3", toString(requireTypeAtPosition({5, 28}))); + } + else + { + CHECK_EQ("Vector3", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("Instance", toString(requireTypeAtPosition({5, 28}))); + } } TEST_CASE_FIXTURE(RefinementClassFixture, "type_narrow_for_all_the_userdata") @@ -1282,14 +1345,22 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "eliminate_subclasses_of_instance") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("Folder | Part", toString(requireTypeAtPosition({3, 28}))); - CHECK_EQ("string", toString(requireTypeAtPosition({5, 28}))); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("(Folder | Part | string) & Instance", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("(Folder | Part | string) & ~Instance", toString(requireTypeAtPosition({5, 28}))); + } + else + { + CHECK_EQ("Folder | Part", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("string", toString(requireTypeAtPosition({5, 28}))); + } } -TEST_CASE_FIXTURE(RefinementClassFixture, "narrow_this_large_union") +TEST_CASE_FIXTURE(RefinementClassFixture, "narrow_from_subclasses_of_instance_or_string_or_vector3") { CheckResult result = check(R"( - local function f(x: Part | Folder | Instance | string | Vector3 | any) + local function f(x: Part | Folder | string | Vector3) if typeof(x) == "Instance" then local foo = x else @@ -1300,8 +1371,16 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "narrow_this_large_union") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("Folder | Instance | Part", toString(requireTypeAtPosition({3, 28}))); - CHECK_EQ("Vector3 | any | string", toString(requireTypeAtPosition({5, 28}))); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("(Folder | Part | Vector3 | string) & Instance", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("(Folder | Part | Vector3 | string) & ~Instance", toString(requireTypeAtPosition({5, 28}))); + } + else + { + CHECK_EQ("Folder | Part", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("Vector3 | string", toString(requireTypeAtPosition({5, 28}))); + } } TEST_CASE_FIXTURE(RefinementClassFixture, "x_as_any_if_x_is_instance_elseif_x_is_table") @@ -1342,7 +1421,7 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "x_is_not_instance_or_else_not_part") CHECK_EQ("Part", toString(requireTypeAtPosition({5, 28}))); } -TEST_CASE_FIXTURE(Fixture, "typeguard_doesnt_leak_to_elseif") +TEST_CASE_FIXTURE(BuiltinsFixture, "typeguard_doesnt_leak_to_elseif") { CheckResult result = check(R"( function f(a) @@ -1373,8 +1452,16 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "refine_unknowns") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("string", toString(requireTypeAtPosition({3, 28}))); - CHECK_EQ("unknown", toString(requireTypeAtPosition({5, 28}))); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("unknown & string", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("unknown & ~string", toString(requireTypeAtPosition({5, 28}))); + } + else + { + CHECK_EQ("string", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("unknown", toString(requireTypeAtPosition({5, 28}))); + } } TEST_CASE_FIXTURE(BuiltinsFixture, "falsiness_of_TruthyPredicate_narrows_into_nil") @@ -1408,7 +1495,35 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "what_nonsensical_condition") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("never", toString(requireTypeAtPosition({3, 28}))); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("a & number & string", toString(requireTypeAtPosition({3, 28}))); + } + else + { + CHECK_EQ("never", toString(requireTypeAtPosition({3, 28}))); + } +} + +TEST_CASE_FIXTURE(Fixture, "else_with_no_explicit_expression_should_also_refine_the_tagged_union") +{ + ScopedFastFlag sff{"LuauImplicitElseRefinement", true}; + + CheckResult result = check(R"( + type Ok = { tag: "ok", value: T } + type Err = { tag: "err", err: E } + type Result = Ok | Err + + function and_then(r: Result, f: (T) -> U): Result + if r.tag == "ok" then + return { tag = "ok", value = f(r.value) } + else + return r + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); } TEST_SUITE_END(); diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index 68757fef5..bf0a0af6c 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -17,6 +17,7 @@ using namespace Luau; LUAU_FASTFLAG(LuauLowerBoundsCalculation); LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); LUAU_FASTFLAG(LuauInstantiateInSubtyping) +LUAU_FASTFLAG(LuauNoMoreGlobalSingletonTypes) TEST_SUITE_BEGIN("TableTests"); @@ -1721,6 +1722,8 @@ TEST_CASE_FIXTURE(Fixture, "hide_table_error_properties") TEST_CASE_FIXTURE(BuiltinsFixture, "builtin_table_names") { + ScopedFastFlag luauNewLibraryTypeNames{"LuauNewLibraryTypeNames", true}; + CheckResult result = check(R"( os.h = 2 string.k = 3 @@ -1728,19 +1731,36 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "builtin_table_names") LUAU_REQUIRE_ERROR_COUNT(2, result); - CHECK_EQ("Cannot add property 'h' to table 'os'", toString(result.errors[0])); - CHECK_EQ("Cannot add property 'k' to table 'string'", toString(result.errors[1])); + if (FFlag::LuauNoMoreGlobalSingletonTypes) + { + CHECK_EQ("Cannot add property 'h' to table 'typeof(os)'", toString(result.errors[0])); + CHECK_EQ("Cannot add property 'k' to table 'typeof(string)'", toString(result.errors[1])); + } + else + { + CHECK_EQ("Cannot add property 'h' to table 'os'", toString(result.errors[0])); + CHECK_EQ("Cannot add property 'k' to table 'string'", toString(result.errors[1])); + } } TEST_CASE_FIXTURE(BuiltinsFixture, "persistent_sealed_table_is_immutable") { + ScopedFastFlag luauNewLibraryTypeNames{"LuauNewLibraryTypeNames", true}; + CheckResult result = check(R"( --!nonstrict function os:bad() end )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Cannot add property 'bad' to table 'os'", toString(result.errors[0])); + if (FFlag::LuauNoMoreGlobalSingletonTypes) + { + CHECK_EQ("Cannot add property 'bad' to table 'typeof(os)'", toString(result.errors[0])); + } + else + { + CHECK_EQ("Cannot add property 'bad' to table 'os'", toString(result.errors[0])); + } const TableTypeVar* osType = get(requireType("os")); REQUIRE(osType != nullptr); @@ -3188,6 +3208,7 @@ TEST_CASE_FIXTURE(Fixture, "scalar_is_a_subtype_of_a_compatible_polymorphic_shap TEST_CASE_FIXTURE(Fixture, "scalar_is_not_a_subtype_of_a_compatible_polymorphic_shape_type") { ScopedFastFlag sff{"LuauScalarShapeSubtyping", true}; + ScopedFastFlag luauNewLibraryTypeNames{"LuauNewLibraryTypeNames", true}; CheckResult result = check(R"( local function f(s) @@ -3200,25 +3221,47 @@ TEST_CASE_FIXTURE(Fixture, "scalar_is_not_a_subtype_of_a_compatible_polymorphic_ )"); LUAU_REQUIRE_ERROR_COUNT(3, result); - CHECK_EQ(R"(Type 'string' could not be converted into 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' + + if (FFlag::LuauNoMoreGlobalSingletonTypes) + { + CHECK_EQ(R"(Type 'string' could not be converted into 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' +caused by: + The former's metatable does not satisfy the requirements. Table type 'typeof(string)' not compatible with type 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' because the former is missing field 'absolutely_no_scalar_has_this_method')", + toString(result.errors[0])); + CHECK_EQ(R"(Type '"bar"' could not be converted into 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' +caused by: + The former's metatable does not satisfy the requirements. Table type 'typeof(string)' not compatible with type 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' because the former is missing field 'absolutely_no_scalar_has_this_method')", + toString(result.errors[1])); + CHECK_EQ(R"(Type '"bar" | "baz"' could not be converted into 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' +caused by: + Not all union options are compatible. Type '"bar"' could not be converted into 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' +caused by: + The former's metatable does not satisfy the requirements. Table type 'typeof(string)' not compatible with type 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' because the former is missing field 'absolutely_no_scalar_has_this_method')", + toString(result.errors[2])); + } + else + { + CHECK_EQ(R"(Type 'string' could not be converted into 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' caused by: The former's metatable does not satisfy the requirements. Table type 'string' not compatible with type 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' because the former is missing field 'absolutely_no_scalar_has_this_method')", - toString(result.errors[0])); - CHECK_EQ(R"(Type '"bar"' could not be converted into 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' + toString(result.errors[0])); + CHECK_EQ(R"(Type '"bar"' could not be converted into 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' caused by: The former's metatable does not satisfy the requirements. Table type 'string' not compatible with type 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' because the former is missing field 'absolutely_no_scalar_has_this_method')", - toString(result.errors[1])); - CHECK_EQ(R"(Type '"bar" | "baz"' could not be converted into 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' + toString(result.errors[1])); + CHECK_EQ(R"(Type '"bar" | "baz"' could not be converted into 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' caused by: Not all union options are compatible. Type '"bar"' could not be converted into 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' caused by: The former's metatable does not satisfy the requirements. Table type 'string' not compatible with type 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' because the former is missing field 'absolutely_no_scalar_has_this_method')", - toString(result.errors[2])); + toString(result.errors[2])); + } } TEST_CASE_FIXTURE(Fixture, "a_free_shape_can_turn_into_a_scalar_if_it_is_compatible") { ScopedFastFlag sff{"LuauScalarShapeSubtyping", true}; + ScopedFastFlag luauScalarShapeUnifyToMtOwner{"LuauScalarShapeUnifyToMtOwner", true}; // Changes argument from table type to primitive CheckResult result = check(R"( local function f(s): string @@ -3234,6 +3277,7 @@ TEST_CASE_FIXTURE(Fixture, "a_free_shape_can_turn_into_a_scalar_if_it_is_compati TEST_CASE_FIXTURE(Fixture, "a_free_shape_cannot_turn_into_a_scalar_if_it_is_not_compatible") { ScopedFastFlag sff{"LuauScalarShapeSubtyping", true}; + ScopedFastFlag luauNewLibraryTypeNames{"LuauNewLibraryTypeNames", true}; CheckResult result = check(R"( local function f(s): string @@ -3243,11 +3287,42 @@ TEST_CASE_FIXTURE(Fixture, "a_free_shape_cannot_turn_into_a_scalar_if_it_is_not_ )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(R"(Type 't1 where t1 = {+ absolutely_no_scalar_has_this_method: (t1) -> (a, b...) +}' could not be converted into 'string' + if (FFlag::LuauNoMoreGlobalSingletonTypes) + { + CHECK_EQ(R"(Type 't1 where t1 = {+ absolutely_no_scalar_has_this_method: (t1) -> (a, b...) +}' could not be converted into 'string' +caused by: + The former's metatable does not satisfy the requirements. Table type 'typeof(string)' not compatible with type 't1 where t1 = {+ absolutely_no_scalar_has_this_method: (t1) -> (a, b...) +}' because the former is missing field 'absolutely_no_scalar_has_this_method')", + toString(result.errors[0])); + CHECK_EQ("(t1) -> string where t1 = {+ absolutely_no_scalar_has_this_method: (t1) -> (a, b...) +}", toString(requireType("f"))); + } + else + { + CHECK_EQ(R"(Type 't1 where t1 = {+ absolutely_no_scalar_has_this_method: (t1) -> (a, b...) +}' could not be converted into 'string' caused by: The former's metatable does not satisfy the requirements. Table type 'string' not compatible with type 't1 where t1 = {+ absolutely_no_scalar_has_this_method: (t1) -> (a, b...) +}' because the former is missing field 'absolutely_no_scalar_has_this_method')", - toString(result.errors[0])); - CHECK_EQ("(t1) -> string where t1 = {+ absolutely_no_scalar_has_this_method: (t1) -> (a, b...) +}", toString(requireType("f"))); + toString(result.errors[0])); + CHECK_EQ("(t1) -> string where t1 = {+ absolutely_no_scalar_has_this_method: (t1) -> (a, b...) +}", toString(requireType("f"))); + } +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "a_free_shape_can_turn_into_a_scalar_directly") +{ + ScopedFastFlag luauScalarShapeSubtyping{"LuauScalarShapeSubtyping", true}; + ScopedFastFlag luauScalarShapeUnifyToMtOwner{"LuauScalarShapeUnifyToMtOwner", true}; + + CheckResult result = check(R"( + local function stringByteList(str) + local out = {} + for i = 1, #str do + table.insert(out, string.byte(str, i)) + end + return table.concat(out, ",") + end + + local x = stringByteList("xoo") + )"); + + LUAU_REQUIRE_NO_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "invariant_table_properties_means_instantiating_tables_in_call_is_unsound") diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index 6c7201a64..04b8bf574 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -1125,43 +1125,6 @@ TEST_CASE_FIXTURE(Fixture, "bidirectional_checking_of_higher_order_function") CHECK(location.end.line == 4); } -TEST_CASE_FIXTURE(Fixture, "dcr_can_partially_dispatch_a_constraint") -{ - ScopedFastFlag sff[] = { - {"DebugLuauDeferredConstraintResolution", true}, - }; - - CheckResult result = check(R"( - local function hasDivisors(value: number) - end - - function prime_iter(state, index) - hasDivisors(index) - index += 1 - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - // Solving this requires recognizing that we can't dispatch a constraint - // like this without doing further work: - // - // (*blocked*) -> () <: (number) -> (b...) - // - // We solve this by searching both types for BlockedTypeVars and block the - // constraint on any we find. It also gets the job done, but I'm worried - // about the efficiency of doing so many deep type traversals and it may - // make us more prone to getting stuck on constraint cycles. - // - // If this doesn't pan out, a possible solution is to go further down the - // path of supporting partial constraint dispatch. The way it would work is - // that we'd dispatch the above constraint by binding b... to (), but we - // would append a new constraint number <: *blocked* to the constraint set - // to be solved later. This should be faster and theoretically less prone - // to cyclic constraint dependencies. - CHECK("(a, number) -> ()" == toString(requireType("prime_iter"))); -} - TEST_CASE_FIXTURE(BuiltinsFixture, "it_is_ok_to_have_inconsistent_number_of_return_values_in_nonstrict") { CheckResult result = check(R"( diff --git a/tests/TypeInfer.typePacks.cpp b/tests/TypeInfer.typePacks.cpp index 4c8eeac60..cde651dfe 100644 --- a/tests/TypeInfer.typePacks.cpp +++ b/tests/TypeInfer.typePacks.cpp @@ -1002,8 +1002,6 @@ TEST_CASE_FIXTURE(Fixture, "unify_variadic_tails_in_arguments_free") TEST_CASE_FIXTURE(BuiltinsFixture, "type_packs_with_tails_in_vararg_adjustment") { - ScopedFastFlag luauFixVarargExprHeadType{"LuauFixVarargExprHeadType", true}; - CheckResult result = check(R"( local function wrapReject(fn: (self: any, ...TArg) -> ...TResult): (self: any, ...TArg) -> ...TResult return function(self, ...) diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp index 0c25386f7..627fbb566 100644 --- a/tests/TypeInfer.unionTypes.test.cpp +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -395,6 +395,23 @@ local e = a.z CHECK_EQ("Type 'A | B | C | D' does not have key 'z'", toString(result.errors[3])); } +TEST_CASE_FIXTURE(Fixture, "optional_iteration") +{ + ScopedFastFlag luauNilIterator{"LuauNilIterator", true}; + + CheckResult result = check(R"( +function foo(values: {number}?) + local s = 0 + for _, value in values do + s += value + end +end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Value of type '{number}?' could be nil", toString(result.errors[0])); +} + TEST_CASE_FIXTURE(Fixture, "unify_unsealed_table_union_check") { CheckResult result = check(R"( diff --git a/tests/TypeInfer.unknownnever.test.cpp b/tests/TypeInfer.unknownnever.test.cpp index 2288db4e9..9d1e46c54 100644 --- a/tests/TypeInfer.unknownnever.test.cpp +++ b/tests/TypeInfer.unknownnever.test.cpp @@ -284,6 +284,7 @@ TEST_CASE_FIXTURE(Fixture, "dont_unify_operands_if_one_of_the_operand_is_never_i ScopedFastFlag sff[]{ {"LuauUnknownAndNeverType", true}, {"LuauNeverTypesAndOperatorsInference", true}, + {"LuauTryhardAnd", true}, }; CheckResult result = check(R"( @@ -293,7 +294,8 @@ TEST_CASE_FIXTURE(Fixture, "dont_unify_operands_if_one_of_the_operand_is_never_i )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("(nil, a) -> boolean", toString(requireType("ord"))); + // Widening doesn't normalize yet, so the result is a bit strange + CHECK_EQ("(nil, a) -> boolean | boolean", toString(requireType("ord"))); } TEST_CASE_FIXTURE(Fixture, "math_operators_and_never") diff --git a/tools/faillist.txt b/tools/faillist.txt index 4ac2b357d..a228c1714 100644 --- a/tools/faillist.txt +++ b/tools/faillist.txt @@ -32,10 +32,8 @@ AutocompleteTest.type_correct_expected_argument_type_suggestion_optional AutocompleteTest.type_correct_expected_argument_type_suggestion_self AutocompleteTest.type_correct_expected_return_type_pack_suggestion AutocompleteTest.type_correct_expected_return_type_suggestion -AutocompleteTest.type_correct_full_type_suggestion AutocompleteTest.type_correct_function_no_parenthesis AutocompleteTest.type_correct_function_return_types -AutocompleteTest.type_correct_function_type_suggestion AutocompleteTest.type_correct_keywords AutocompleteTest.type_correct_suggestion_for_overloads AutocompleteTest.type_correct_suggestion_in_argument @@ -53,12 +51,10 @@ BuiltinTests.dont_add_definitions_to_persistent_types BuiltinTests.find_capture_types BuiltinTests.find_capture_types2 BuiltinTests.find_capture_types3 -BuiltinTests.gmatch_capture_types BuiltinTests.gmatch_capture_types2 BuiltinTests.gmatch_capture_types_balanced_escaped_parens BuiltinTests.gmatch_capture_types_default_capture BuiltinTests.gmatch_capture_types_parens_in_sets_are_ignored -BuiltinTests.gmatch_capture_types_set_containing_lbracket BuiltinTests.gmatch_definition BuiltinTests.ipairs_iterator_should_infer_types_and_type_check BuiltinTests.match_capture_types @@ -74,13 +70,12 @@ BuiltinTests.set_metatable_needs_arguments BuiltinTests.setmetatable_should_not_mutate_persisted_types BuiltinTests.sort_with_bad_predicate BuiltinTests.string_format_arg_count_mismatch -BuiltinTests.string_format_arg_types_inference BuiltinTests.string_format_as_method BuiltinTests.string_format_correctly_ordered_types BuiltinTests.string_format_report_all_type_errors_at_correct_positions BuiltinTests.string_format_use_correct_argument BuiltinTests.string_format_use_correct_argument2 -BuiltinTests.string_format_use_correct_argument3 +BuiltinTests.strings_have_methods BuiltinTests.table_freeze_is_generic BuiltinTests.table_insert_correctly_infers_type_of_array_2_args_overload BuiltinTests.table_insert_correctly_infers_type_of_array_3_args_overload @@ -114,7 +109,6 @@ GenericsTests.generic_factories GenericsTests.generic_functions_should_be_memory_safe GenericsTests.generic_table_method GenericsTests.generic_type_pack_parentheses -GenericsTests.generic_type_pack_unification1 GenericsTests.generic_type_pack_unification2 GenericsTests.higher_rank_polymorphism_should_not_accept_instantiated_arguments GenericsTests.infer_generic_function_function_argument @@ -174,46 +168,36 @@ ProvisionalTests.while_body_are_also_refined RefinementTest.apply_refinements_on_astexprindexexpr_whose_subscript_expr_is_constant_string RefinementTest.assert_a_to_be_truthy_then_assert_a_to_be_number RefinementTest.assert_non_binary_expressions_actually_resolve_constraints -RefinementTest.call_a_more_specific_function_using_typeguard +RefinementTest.call_an_incompatible_function_after_using_typeguard RefinementTest.correctly_lookup_property_whose_base_was_previously_refined RefinementTest.correctly_lookup_property_whose_base_was_previously_refined2 RefinementTest.discriminate_from_isa_of_x RefinementTest.discriminate_from_truthiness_of_x RefinementTest.discriminate_on_properties_of_disjoint_tables_where_that_property_is_true_or_false RefinementTest.discriminate_tag -RefinementTest.either_number_or_string -RefinementTest.eliminate_subclasses_of_instance +RefinementTest.else_with_no_explicit_expression_should_also_refine_the_tagged_union RefinementTest.falsiness_of_TruthyPredicate_narrows_into_nil RefinementTest.index_on_a_refined_property RefinementTest.invert_is_truthy_constraint_ifelse_expression RefinementTest.is_truthy_constraint_ifelse_expression -RefinementTest.merge_should_be_fully_agnostic_of_hashmap_ordering RefinementTest.narrow_property_of_a_bounded_variable -RefinementTest.narrow_this_large_union RefinementTest.nonoptional_type_can_narrow_to_nil_if_sense_is_true RefinementTest.not_t_or_some_prop_of_t RefinementTest.refine_a_property_not_to_be_nil_through_an_intersection_table -RefinementTest.refine_the_correct_types_opposite_of_when_a_is_not_number_or_string RefinementTest.refine_unknowns RefinementTest.truthy_constraint_on_properties RefinementTest.type_comparison_ifelse_expression RefinementTest.type_guard_can_filter_for_intersection_of_tables -RefinementTest.type_guard_can_filter_for_overloaded_function RefinementTest.type_guard_narrowed_into_nothingness RefinementTest.type_narrow_for_all_the_userdata RefinementTest.type_narrow_to_vector RefinementTest.typeguard_cast_free_table_to_vector -RefinementTest.typeguard_cast_instance_or_vector3_to_vector -RefinementTest.typeguard_doesnt_leak_to_elseif RefinementTest.typeguard_in_assert_position -RefinementTest.typeguard_in_if_condition_position -RefinementTest.typeguard_narrows_for_functions RefinementTest.typeguard_narrows_for_table -RefinementTest.typeguard_not_to_be_string -RefinementTest.what_nonsensical_condition RefinementTest.x_as_any_if_x_is_instance_elseif_x_is_table RefinementTest.x_is_not_instance_or_else_not_part RuntimeLimits.typescript_port_of_Result_type +TableTests.a_free_shape_can_turn_into_a_scalar_directly TableTests.a_free_shape_can_turn_into_a_scalar_if_it_is_compatible TableTests.a_free_shape_cannot_turn_into_a_scalar_if_it_is_not_compatible TableTests.access_index_metamethod_that_returns_variadic @@ -249,7 +233,6 @@ TableTests.generic_table_instantiation_potential_regression TableTests.getmetatable_returns_pointer_to_metatable TableTests.give_up_after_one_metatable_index_look_up TableTests.hide_table_error_properties -TableTests.indexer_fn TableTests.indexer_on_sealed_table_must_unify_with_free_table TableTests.indexing_from_a_table_should_prefer_properties_when_possible TableTests.inequality_operators_imply_exactly_matching_types @@ -262,7 +245,6 @@ TableTests.invariant_table_properties_means_instantiating_tables_in_assignment_i TableTests.invariant_table_properties_means_instantiating_tables_in_call_is_unsound TableTests.leaking_bad_metatable_errors TableTests.less_exponential_blowup_please -TableTests.meta_add TableTests.meta_add_both_ways TableTests.meta_add_inferred TableTests.metatable_mismatch_should_fail @@ -389,7 +371,6 @@ TypeInferFunctions.improved_function_arg_mismatch_error_nonstrict TypeInferFunctions.improved_function_arg_mismatch_errors TypeInferFunctions.infer_anonymous_function_arguments TypeInferFunctions.infer_return_type_from_selected_overload -TypeInferFunctions.infer_return_value_type TypeInferFunctions.infer_that_function_does_not_return_a_table TypeInferFunctions.list_all_overloads_if_no_overload_takes_given_argument_count TypeInferFunctions.list_only_alternative_overloads_that_match_argument_count @@ -409,10 +390,12 @@ TypeInferFunctions.too_many_return_values TypeInferFunctions.too_many_return_values_in_parentheses TypeInferFunctions.too_many_return_values_no_function TypeInferFunctions.vararg_function_is_quantified +TypeInferLoops.for_in_loop TypeInferLoops.for_in_loop_error_on_factory_not_returning_the_right_amount_of_values TypeInferLoops.for_in_loop_with_next TypeInferLoops.for_in_with_generic_next TypeInferLoops.for_in_with_just_one_iterator_is_ok +TypeInferLoops.loop_iter_metamethod_ok_with_inference TypeInferLoops.loop_iter_no_indexer_nonstrict TypeInferLoops.loop_iter_trailing_nil TypeInferLoops.unreachable_code_after_infinite_loop @@ -430,8 +413,6 @@ TypeInferOOP.dont_suggest_using_colon_rather_than_dot_if_not_defined_with_colon TypeInferOOP.inferring_hundreds_of_self_calls_should_not_suffocate_memory TypeInferOOP.methods_are_topologically_sorted TypeInferOOP.object_constructor_can_refer_to_method_of_self -TypeInferOperators.and_or_ternary -TypeInferOperators.CallAndOrOfFunctions TypeInferOperators.cannot_compare_tables_that_do_not_have_the_same_metatable TypeInferOperators.cannot_indirectly_compare_types_that_do_not_have_a_metatable TypeInferOperators.cannot_indirectly_compare_types_that_do_not_offer_overloaded_ordering_operators @@ -517,6 +498,7 @@ UnionTypes.optional_assignment_errors UnionTypes.optional_call_error UnionTypes.optional_field_access_error UnionTypes.optional_index_error +UnionTypes.optional_iteration UnionTypes.optional_length_error UnionTypes.optional_missing_key_error_details UnionTypes.optional_union_follow From f52169509cb0bc4330b6dd7fd2e3d6e502e333bc Mon Sep 17 00:00:00 2001 From: Andy Friesen Date: Fri, 18 Nov 2022 10:45:14 -0800 Subject: [PATCH 17/66] Sync to upstream/release/554 --- Analysis/include/Luau/Constraint.h | 19 +- Analysis/include/Luau/ConstraintSolver.h | 3 + Analysis/src/Autocomplete.cpp | 38 +++- Analysis/src/ConstraintGraphBuilder.cpp | 124 +++-------- Analysis/src/ConstraintSolver.cpp | 270 +++++++++++++++++++---- Analysis/src/ToString.cpp | 5 + Analysis/src/TypeChecker2.cpp | 9 + Analysis/src/TypedAllocator.cpp | 4 + Ast/src/Parser.cpp | 50 ++++- CodeGen/src/CodeAllocator.cpp | 4 + Common/include/Luau/ExperimentalFlags.h | 1 - VM/src/lapi.cpp | 32 --- tests/AstJsonEncoder.test.cpp | 2 +- tests/Autocomplete.test.cpp | 67 +++++- tests/TypeInfer.tables.test.cpp | 4 +- tools/faillist.txt | 12 - 16 files changed, 440 insertions(+), 204 deletions(-) diff --git a/Analysis/include/Luau/Constraint.h b/Analysis/include/Luau/Constraint.h index 4370d0cf4..16a08e879 100644 --- a/Analysis/include/Luau/Constraint.h +++ b/Analysis/include/Luau/Constraint.h @@ -132,6 +132,23 @@ struct HasPropConstraint std::string prop; }; +// result ~ setProp subjectType ["prop", "prop2", ...] propType +// +// If the subject is a table or table-like thing that already has the named +// property chain, we unify propType with that existing property type. +// +// If the subject is a free table, we augment it in place. +// +// If the subject is an unsealed table, result is an augmented table that +// includes that new prop. +struct SetPropConstraint +{ + TypeId resultType; + TypeId subjectType; + std::vector path; + TypeId propType; +}; + // result ~ if isSingleton D then ~D else unknown where D = discriminantType struct SingletonOrTopTypeConstraint { @@ -141,7 +158,7 @@ struct SingletonOrTopTypeConstraint using ConstraintV = Variant; + HasPropConstraint, SetPropConstraint, SingletonOrTopTypeConstraint>; struct Constraint { diff --git a/Analysis/include/Luau/ConstraintSolver.h b/Analysis/include/Luau/ConstraintSolver.h index 7b89a2781..e05f6f1f4 100644 --- a/Analysis/include/Luau/ConstraintSolver.h +++ b/Analysis/include/Luau/ConstraintSolver.h @@ -110,6 +110,7 @@ struct ConstraintSolver bool tryDispatch(const FunctionCallConstraint& c, NotNull constraint); bool tryDispatch(const PrimitiveTypeConstraint& c, NotNull constraint); bool tryDispatch(const HasPropConstraint& c, NotNull constraint); + bool tryDispatch(const SetPropConstraint& c, NotNull constraint); bool tryDispatch(const SingletonOrTopTypeConstraint& c, NotNull constraint); // for a, ... in some_table do @@ -120,6 +121,8 @@ struct ConstraintSolver bool tryDispatchIterableFunction( TypeId nextTy, TypeId tableTy, TypeId firstIndexTy, const IterableConstraint& c, NotNull constraint, bool force); + std::optional lookupTableProp(TypeId subjectType, const std::string& propName); + void block(NotNull target, NotNull constraint); /** * Block a constraint on the resolution of a TypeVar. diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index 224e94401..50dc254fc 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -1219,6 +1219,31 @@ static std::optional getMethodContainingClass(const ModuleP return std::nullopt; } +static bool stringPartOfInterpString(const AstNode* node, Position position) +{ + const AstExprInterpString* interpString = node->as(); + if (!interpString) + { + return false; + } + + for (const AstExpr* expression : interpString->expressions) + { + if (expression->location.containsClosed(position)) + { + return false; + } + } + + return true; +} + +static bool isSimpleInterpolatedString(const AstNode* node) +{ + const AstExprInterpString* interpString = node->as(); + return interpString != nullptr && interpString->expressions.size == 0; +} + static std::optional autocompleteStringParams(const SourceModule& sourceModule, const ModulePtr& module, const std::vector& nodes, Position position, StringCompletionCallback callback) { @@ -1227,7 +1252,7 @@ static std::optional autocompleteStringParams(const Source return std::nullopt; } - if (!nodes.back()->is() && !nodes.back()->is()) + if (!nodes.back()->is() && !isSimpleInterpolatedString(nodes.back()) && !nodes.back()->is()) { return std::nullopt; } @@ -1432,7 +1457,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M return autocompleteExpression(sourceModule, *module, singletonTypes, &typeArena, ancestry, position); else if (AstStatRepeat* statRepeat = extractStat(ancestry); statRepeat) return {autocompleteStatement(sourceModule, *module, ancestry, position), ancestry, AutocompleteContext::Statement}; - else if (AstExprTable* exprTable = parent->as(); exprTable && (node->is() || node->is())) + else if (AstExprTable* exprTable = parent->as(); exprTable && (node->is() || node->is() || node->is())) { for (const auto& [kind, key, value] : exprTable->items) { @@ -1471,7 +1496,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M { return {*ret, ancestry, AutocompleteContext::String}; } - else if (node->is()) + else if (node->is() || isSimpleInterpolatedString(node)) { AutocompleteEntryMap result; @@ -1497,6 +1522,13 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M return {result, ancestry, AutocompleteContext::String}; } + else if (stringPartOfInterpString(node, position)) + { + // We're not a simple interpolated string, we're something like `a{"b"}@1`, and we + // can't know what to format to + AutocompleteEntryMap map; + return {map, ancestry, AutocompleteContext::String}; + } if (node->is()) return {}; diff --git a/Analysis/src/ConstraintGraphBuilder.cpp b/Analysis/src/ConstraintGraphBuilder.cpp index 42dc07f6d..e3572fe8c 100644 --- a/Analysis/src/ConstraintGraphBuilder.cpp +++ b/Analysis/src/ConstraintGraphBuilder.cpp @@ -11,6 +11,7 @@ #include "Luau/Scope.h" #include "Luau/ToString.h" #include "Luau/TypeUtils.h" +#include "Luau/TypeVar.h" LUAU_FASTINT(LuauCheckRecursionLimit); LUAU_FASTFLAG(DebugLuauLogSolverToJson); @@ -1019,7 +1020,22 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCa args.push_back(check(scope, arg).ty); } - // TODO self + if (call->self) + { + AstExprIndexName* indexExpr = call->func->as(); + if (!indexExpr) + ice->ice("method call expression has no 'self'"); + + // The call to `check` we already did on `call->func` should have already produced a type for + // `indexExpr->expr`, so we can get it from `astTypes` to avoid exponential blow-up. + TypeId selfType = astTypes[indexExpr->expr]; + + // If we don't have a type for self, it means we had a code too complex error already. + if (selfType == nullptr) + selfType = singletonTypes->errorRecoveryType(); + + args.insert(args.begin(), selfType); + } if (matchSetmetatable(*call)) { @@ -1428,13 +1444,6 @@ TypePackId ConstraintGraphBuilder::checkLValues(const ScopePtr& scope, AstArray< return arena->addTypePack(std::move(types)); } -static bool isUnsealedTable(TypeId ty) -{ - ty = follow(ty); - const TableTypeVar* ttv = get(ty); - return ttv && ttv->state == TableState::Unsealed; -}; - /** * If the expr is a dotted set of names, and if the root symbol refers to an * unsealed table, return that table type, plus the indeces that follow as a @@ -1468,80 +1477,6 @@ static std::optional>> extractDottedN return std::nullopt; } -/** - * Create a shallow copy of `ty` and its properties along `path`. Insert a new - * property (the last segment of `path`) into the tail table with the value `t`. - * - * On success, returns the new outermost table type. If the root table or any - * of its subkeys are not unsealed tables, the function fails and returns - * std::nullopt. - * - * TODO: Prove that we completely give up in the face of indexers and - * metatables. - */ -static std::optional updateTheTableType(NotNull arena, TypeId ty, const std::vector& path, TypeId replaceTy) -{ - if (path.empty()) - return std::nullopt; - - // First walk the path and ensure that it's unsealed tables all the way - // to the end. - { - TypeId t = ty; - for (size_t i = 0; i < path.size() - 1; ++i) - { - if (!isUnsealedTable(t)) - return std::nullopt; - - const TableTypeVar* tbl = get(t); - auto it = tbl->props.find(path[i]); - if (it == tbl->props.end()) - return std::nullopt; - - t = it->second.type; - } - - // The last path segment should not be a property of the table at all. - // We are not changing property types. We are only admitting this one - // new property to be appended. - if (!isUnsealedTable(t)) - return std::nullopt; - const TableTypeVar* tbl = get(t); - auto it = tbl->props.find(path.back()); - if (it != tbl->props.end()) - return std::nullopt; - } - - const TypeId res = shallowClone(ty, arena); - TypeId t = res; - - for (size_t i = 0; i < path.size() - 1; ++i) - { - const std::string segment = path[i]; - - TableTypeVar* ttv = getMutable(t); - LUAU_ASSERT(ttv); - - auto propIt = ttv->props.find(segment); - if (propIt != ttv->props.end()) - { - LUAU_ASSERT(isUnsealedTable(propIt->second.type)); - t = shallowClone(follow(propIt->second.type), arena); - ttv->props[segment].type = t; - } - else - return std::nullopt; - } - - TableTypeVar* ttv = getMutable(t); - LUAU_ASSERT(ttv); - - const std::string lastSegment = path.back(); - LUAU_ASSERT(0 == ttv->props.count(lastSegment)); - ttv->props[lastSegment] = Property{replaceTy}; - return res; -} - /** * This function is mostly about identifying properties that are being inserted into unsealed tables. * @@ -1559,31 +1494,36 @@ TypeId ConstraintGraphBuilder::checkLValue(const ScopePtr& scope, AstExpr* expr) return checkLValue(scope, &synthetic); } } + else if (!expr->is()) + return check(scope, expr).ty; auto dottedPath = extractDottedName(expr); if (!dottedPath) return check(scope, expr).ty; const auto [sym, segments] = std::move(*dottedPath); - if (!sym.local) - return check(scope, expr).ty; + LUAU_ASSERT(!segments.empty()); auto lookupResult = scope->lookupEx(sym); if (!lookupResult) return check(scope, expr).ty; - const auto [ty, symbolScope] = std::move(*lookupResult); + const auto [subjectType, symbolScope] = std::move(*lookupResult); - TypeId replaceTy = arena->freshType(scope.get()); + TypeId propTy = freshType(scope); - std::optional updatedType = updateTheTableType(arena, ty, segments, replaceTy); - if (!updatedType) - return check(scope, expr).ty; + std::vector segmentStrings(begin(segments), end(segments)); + + TypeId updatedType = arena->addType(BlockedTypeVar{}); + addConstraint(scope, expr->location, SetPropConstraint{updatedType, subjectType, std::move(segmentStrings), propTy}); std::optional def = dfg->getDef(sym); LUAU_ASSERT(def); - symbolScope->bindings[sym].typeId = *updatedType; - symbolScope->dcrRefinements[*def] = *updatedType; - return replaceTy; + symbolScope->bindings[sym].typeId = updatedType; + symbolScope->dcrRefinements[*def] = updatedType; + + astTypes[expr] = propTy; + + return propTy; } Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprTable* expr, std::optional expectedType) diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index 533652e23..250e7ae22 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -2,6 +2,7 @@ #include "Luau/Anyification.h" #include "Luau/ApplyTypeFunction.h" +#include "Luau/Clone.h" #include "Luau/ConstraintSolver.h" #include "Luau/DcrLogger.h" #include "Luau/Instantiation.h" @@ -415,6 +416,8 @@ bool ConstraintSolver::tryDispatch(NotNull constraint, bool fo success = tryDispatch(*fcc, constraint); else if (auto hpc = get(*constraint)) success = tryDispatch(*hpc, constraint); + else if (auto spc = get(*constraint)) + success = tryDispatch(*spc, constraint); else if (auto sottc = get(*constraint)) success = tryDispatch(*sottc, constraint); else @@ -1230,69 +1233,180 @@ bool ConstraintSolver::tryDispatch(const HasPropConstraint& c, NotNull(subjectType)) return block(subjectType, constraint); - TypeId resultType = nullptr; + std::optional resultType = lookupTableProp(subjectType, c.prop); + if (!resultType) + return false; - auto collectParts = [&](auto&& unionOrIntersection) -> std::pair> { - bool blocked = false; + if (isBlocked(*resultType)) + { + block(*resultType, constraint); + return false; + } - std::vector parts; - for (TypeId expectedPart : unionOrIntersection) + asMutable(c.resultType)->ty.emplace(*resultType); + return true; +} + +static bool isUnsealedTable(TypeId ty) +{ + ty = follow(ty); + const TableTypeVar* ttv = get(ty); + return ttv && ttv->state == TableState::Unsealed; +} + +/** + * Create a shallow copy of `ty` and its properties along `path`. Insert a new + * property (the last segment of `path`) into the tail table with the value `t`. + * + * On success, returns the new outermost table type. If the root table or any + * of its subkeys are not unsealed tables, the function fails and returns + * std::nullopt. + * + * TODO: Prove that we completely give up in the face of indexers and + * metatables. + */ +static std::optional updateTheTableType(NotNull arena, TypeId ty, const std::vector& path, TypeId replaceTy) +{ + if (path.empty()) + return std::nullopt; + + // First walk the path and ensure that it's unsealed tables all the way + // to the end. + { + TypeId t = ty; + for (size_t i = 0; i < path.size() - 1; ++i) { - expectedPart = follow(expectedPart); - if (isBlocked(expectedPart) || get(expectedPart)) - { - blocked = true; - block(expectedPart, constraint); - } - else if (const TableTypeVar* ttv = get(follow(expectedPart))) - { - if (auto prop = ttv->props.find(c.prop); prop != ttv->props.end()) - parts.push_back(prop->second.type); - else if (ttv->indexer && maybeString(ttv->indexer->indexType)) - parts.push_back(ttv->indexer->indexResultType); - } + if (!isUnsealedTable(t)) + return std::nullopt; + + const TableTypeVar* tbl = get(t); + auto it = tbl->props.find(path[i]); + if (it == tbl->props.end()) + return std::nullopt; + + t = it->second.type; } - return {blocked, parts}; - }; + // The last path segment should not be a property of the table at all. + // We are not changing property types. We are only admitting this one + // new property to be appended. + if (!isUnsealedTable(t)) + return std::nullopt; + const TableTypeVar* tbl = get(t); + if (0 != tbl->props.count(path.back())) + return std::nullopt; + } - if (auto ttv = get(subjectType)) + const TypeId res = shallowClone(ty, arena); + TypeId t = res; + + for (size_t i = 0; i < path.size() - 1; ++i) { - if (auto prop = ttv->props.find(c.prop); prop != ttv->props.end()) - resultType = prop->second.type; - else if (ttv->indexer && maybeString(ttv->indexer->indexType)) - resultType = ttv->indexer->indexResultType; + const std::string segment = path[i]; + + TableTypeVar* ttv = getMutable(t); + LUAU_ASSERT(ttv); + + auto propIt = ttv->props.find(segment); + if (propIt != ttv->props.end()) + { + LUAU_ASSERT(isUnsealedTable(propIt->second.type)); + t = shallowClone(follow(propIt->second.type), arena); + ttv->props[segment].type = t; + } + else + return std::nullopt; } - else if (auto utv = get(subjectType)) - { - auto [blocked, parts] = collectParts(utv); - if (blocked) - return false; - else if (parts.size() == 1) - resultType = parts[0]; - else if (parts.size() > 1) - resultType = arena->addType(UnionTypeVar{std::move(parts)}); + TableTypeVar* ttv = getMutable(t); + LUAU_ASSERT(ttv); + + const std::string lastSegment = path.back(); + LUAU_ASSERT(0 == ttv->props.count(lastSegment)); + ttv->props[lastSegment] = Property{replaceTy}; + return res; +} + +bool ConstraintSolver::tryDispatch(const SetPropConstraint& c, NotNull constraint) +{ + TypeId subjectType = follow(c.subjectType); + + if (isBlocked(subjectType)) + return block(subjectType, constraint); + + std::optional existingPropType = subjectType; + for (const std::string& segment : c.path) + { + ErrorVec e; + std::optional propTy = lookupTableProp(*existingPropType, segment); + if (!propTy) + { + existingPropType = std::nullopt; + break; + } + else if (isBlocked(*propTy)) + return block(*propTy, constraint); else - LUAU_ASSERT(false); // parts.size() == 0 + existingPropType = follow(*propTy); } - else if (auto itv = get(subjectType)) + + auto bind = [](TypeId a, TypeId b) { + asMutable(a)->ty.emplace(b); + }; + + if (existingPropType) { - auto [blocked, parts] = collectParts(itv); + unify(c.propType, *existingPropType, constraint->scope); + bind(c.resultType, c.subjectType); + return true; + } - if (blocked) - return false; - else if (parts.size() == 1) - resultType = parts[0]; - else if (parts.size() > 1) - resultType = arena->addType(IntersectionTypeVar{std::move(parts)}); + if (get(subjectType)) + { + TypeId ty = arena->freshType(constraint->scope); + + // Mint a chain of free tables per c.path + for (auto it = rbegin(c.path); it != rend(c.path); ++it) + { + TableTypeVar t{TableState::Free, TypeLevel{}, constraint->scope}; + t.props[*it] = {ty}; + + ty = arena->addType(std::move(t)); + } + + LUAU_ASSERT(ty); + + bind(subjectType, ty); + bind(c.resultType, ty); + return true; + } + else if (auto ttv = getMutable(subjectType)) + { + if (ttv->state == TableState::Free) + { + ttv->props[c.path[0]] = Property{c.propType}; + bind(c.resultType, c.subjectType); + return true; + } + else if (ttv->state == TableState::Unsealed) + { + std::optional augmented = updateTheTableType(NotNull{arena}, subjectType, c.path, c.propType); + bind(c.resultType, augmented.value_or(subjectType)); + return true; + } else - LUAU_ASSERT(false); // parts.size() == 0 + { + bind(c.resultType, subjectType); + return true; + } + } + else if (get(subjectType) || get(subjectType)) + { + bind(c.resultType, subjectType); + return true; } - if (resultType) - asMutable(c.resultType)->ty.emplace(resultType); - + LUAU_ASSERT(0); return true; } @@ -1481,6 +1595,68 @@ bool ConstraintSolver::tryDispatchIterableFunction( return true; } +std::optional ConstraintSolver::lookupTableProp(TypeId subjectType, const std::string& propName) +{ + auto collectParts = [&](auto&& unionOrIntersection) -> std::pair, std::vector> { + std::optional blocked; + + std::vector parts; + for (TypeId expectedPart : unionOrIntersection) + { + expectedPart = follow(expectedPart); + if (isBlocked(expectedPart) || get(expectedPart)) + blocked = expectedPart; + else if (const TableTypeVar* ttv = get(follow(expectedPart))) + { + if (auto prop = ttv->props.find(propName); prop != ttv->props.end()) + parts.push_back(prop->second.type); + else if (ttv->indexer && maybeString(ttv->indexer->indexType)) + parts.push_back(ttv->indexer->indexResultType); + } + } + + return {blocked, parts}; + }; + + std::optional resultType; + + if (auto ttv = get(subjectType)) + { + if (auto prop = ttv->props.find(propName); prop != ttv->props.end()) + resultType = prop->second.type; + else if (ttv->indexer && maybeString(ttv->indexer->indexType)) + resultType = ttv->indexer->indexResultType; + } + else if (auto utv = get(subjectType)) + { + auto [blocked, parts] = collectParts(utv); + + if (blocked) + resultType = *blocked; + else if (parts.size() == 1) + resultType = parts[0]; + else if (parts.size() > 1) + resultType = arena->addType(UnionTypeVar{std::move(parts)}); + else + LUAU_ASSERT(false); // parts.size() == 0 + } + else if (auto itv = get(subjectType)) + { + auto [blocked, parts] = collectParts(itv); + + if (blocked) + resultType = *blocked; + else if (parts.size() == 1) + resultType = parts[0]; + else if (parts.size() > 1) + resultType = arena->addType(IntersectionTypeVar{std::move(parts)}); + else + LUAU_ASSERT(false); // parts.size() == 0 + } + + return resultType; +} + void ConstraintSolver::block_(BlockedConstraintId target, NotNull constraint) { blocked[target].push_back(constraint); diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index 5062c3f73..9e1fed26e 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -1477,6 +1477,11 @@ std::string toString(const Constraint& constraint, ToStringOptions& opts) { return tos(c.resultType) + " ~ hasProp " + tos(c.subjectType) + ", \"" + c.prop + "\""; } + else if constexpr (std::is_same_v) + { + const std::string pathStr = c.path.size() == 1 ? "\"" + c.path[0] + "\"" : "[\"" + join(c.path, "\", \"") + "\"]"; + return tos(c.resultType) + " ~ setProp " + tos(c.subjectType) + ", " + pathStr + " " + tos(c.propType); + } else if constexpr (std::is_same_v) { std::string result = tos(c.resultType); diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index 03575c405..35493bdb2 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -857,6 +857,15 @@ struct TypeChecker2 args.head.push_back(argTy); } + if (call->self) + { + AstExprIndexName* indexExpr = call->func->as(); + if (!indexExpr) + ice.ice("method call expression has no 'self'"); + + args.head.insert(args.head.begin(), lookupType(indexExpr->expr)); + } + TypePackId argsTp = arena.addTypePack(args); FunctionTypeVar ftv{argsTp, expectedRetType}; TypeId expectedType = arena.addType(ftv); diff --git a/Analysis/src/TypedAllocator.cpp b/Analysis/src/TypedAllocator.cpp index c95c8eae6..133104d3f 100644 --- a/Analysis/src/TypedAllocator.cpp +++ b/Analysis/src/TypedAllocator.cpp @@ -17,8 +17,12 @@ const size_t kPageSize = 4096; #include #include +#if defined(__FreeBSD__) && !(_POSIX_C_SOURCE >= 200112L) +const size_t kPageSize = getpagesize(); +#else const size_t kPageSize = sysconf(_SC_PAGESIZE); #endif +#endif #include diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 8338a04a7..85b0d31ab 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -2661,6 +2661,7 @@ AstExpr* Parser::parseInterpString() TempVector expressions(scratchExpr); Location startLocation = lexer.current().location; + Location endLocation; do { @@ -2668,16 +2669,16 @@ AstExpr* Parser::parseInterpString() LUAU_ASSERT(currentLexeme.type == Lexeme::InterpStringBegin || currentLexeme.type == Lexeme::InterpStringMid || currentLexeme.type == Lexeme::InterpStringEnd || currentLexeme.type == Lexeme::InterpStringSimple); - Location location = currentLexeme.location; + endLocation = currentLexeme.location; - Location startOfBrace = Location(location.end, 1); + Location startOfBrace = Location(endLocation.end, 1); scratchData.assign(currentLexeme.data, currentLexeme.length); if (!Lexer::fixupQuotedString(scratchData)) { nextLexeme(); - return reportExprError(startLocation, {}, "Interpolated string literal contains malformed escape sequence"); + return reportExprError(Location{startLocation, endLocation}, {}, "Interpolated string literal contains malformed escape sequence"); } AstArray chars = copy(scratchData); @@ -2688,15 +2689,36 @@ AstExpr* Parser::parseInterpString() if (currentLexeme.type == Lexeme::InterpStringEnd || currentLexeme.type == Lexeme::InterpStringSimple) { - AstArray> stringsArray = copy(strings); - AstArray expressionsArray = copy(expressions); - - return allocator.alloc(startLocation, stringsArray, expressionsArray); + break; } - AstExpr* expression = parseExpr(); + bool errorWhileChecking = false; + + switch (lexer.current().type) + { + case Lexeme::InterpStringMid: + case Lexeme::InterpStringEnd: + { + errorWhileChecking = true; + nextLexeme(); + expressions.push_back(reportExprError(endLocation, {}, "Malformed interpolated string, expected expression inside '{}'")); + break; + } + case Lexeme::BrokenString: + { + errorWhileChecking = true; + nextLexeme(); + expressions.push_back(reportExprError(endLocation, {}, "Malformed interpolated string, did you forget to add a '`'?")); + break; + } + default: + expressions.push_back(parseExpr()); + } - expressions.push_back(expression); + if (errorWhileChecking) + { + break; + } switch (lexer.current().type) { @@ -2706,14 +2728,18 @@ AstExpr* Parser::parseInterpString() break; case Lexeme::BrokenInterpDoubleBrace: nextLexeme(); - return reportExprError(location, {}, ERROR_INVALID_INTERP_DOUBLE_BRACE); + return reportExprError(endLocation, {}, ERROR_INVALID_INTERP_DOUBLE_BRACE); case Lexeme::BrokenString: nextLexeme(); - return reportExprError(location, {}, "Malformed interpolated string, did you forget to add a '}'?"); + return reportExprError(endLocation, {}, "Malformed interpolated string, did you forget to add a '}'?"); default: - return reportExprError(location, {}, "Malformed interpolated string, got %s", lexer.current().toString().c_str()); + return reportExprError(endLocation, {}, "Malformed interpolated string, got %s", lexer.current().toString().c_str()); } } while (true); + + AstArray> stringsArray = copy(strings); + AstArray expressionsArray = copy(expressions); + return allocator.alloc(Location{startLocation, endLocation}, stringsArray, expressionsArray); } AstExpr* Parser::parseNumber() diff --git a/CodeGen/src/CodeAllocator.cpp b/CodeGen/src/CodeAllocator.cpp index 823df0d89..e1950dbc7 100644 --- a/CodeGen/src/CodeAllocator.cpp +++ b/CodeGen/src/CodeAllocator.cpp @@ -20,8 +20,12 @@ const size_t kPageSize = 4096; #include #include +#if defined(__FreeBSD__) && !(_POSIX_C_SOURCE >= 200112L) +const size_t kPageSize = getpagesize(); +#else const size_t kPageSize = sysconf(_SC_PAGESIZE); #endif +#endif static size_t alignToPageSize(size_t size) { diff --git a/Common/include/Luau/ExperimentalFlags.h b/Common/include/Luau/ExperimentalFlags.h index 15db9ea38..41c1af59d 100644 --- a/Common/include/Luau/ExperimentalFlags.h +++ b/Common/include/Luau/ExperimentalFlags.h @@ -13,7 +13,6 @@ inline bool isFlagExperimental(const char* flag) static const char* kList[] = { "LuauInterpolatedStringBaseSupport", "LuauInstantiateInSubtyping", // requires some fixes to lua-apps code - "LuauOptionalNextKey", // waiting for a fix to land in lua-apps "LuauTryhardAnd", // waiting for a fix in graphql-lua -> apollo-client-lia -> lua-apps // makes sure we always have at least one entry nullptr, diff --git a/VM/src/lapi.cpp b/VM/src/lapi.cpp index 28307eb90..d2091c6b5 100644 --- a/VM/src/lapi.cpp +++ b/VM/src/lapi.cpp @@ -148,7 +148,6 @@ void lua_rawcheckstack(lua_State* L, int size) { luaD_checkstack(L, size); expandstacklimit(L, L->top + size); - return; } void lua_xmove(lua_State* from, lua_State* to, int n) @@ -167,8 +166,6 @@ void lua_xmove(lua_State* from, lua_State* to, int n) from->top = ftop; to->top = ttop + n; - - return; } void lua_xpush(lua_State* from, lua_State* to, int idx) @@ -177,7 +174,6 @@ void lua_xpush(lua_State* from, lua_State* to, int idx) luaC_threadbarrier(to); setobj2s(to, to->top, index2addr(from, idx)); api_incr_top(to); - return; } lua_State* lua_newthread(lua_State* L) @@ -227,7 +223,6 @@ void lua_settop(lua_State* L, int idx) api_check(L, -(idx + 1) <= (L->top - L->base)); L->top += idx + 1; // `subtract' index (index is negative) } - return; } void lua_remove(lua_State* L, int idx) @@ -237,7 +232,6 @@ void lua_remove(lua_State* L, int idx) while (++p < L->top) setobj2s(L, p - 1, p); L->top--; - return; } void lua_insert(lua_State* L, int idx) @@ -248,7 +242,6 @@ void lua_insert(lua_State* L, int idx) for (StkId q = L->top; q > p; q--) setobj2s(L, q, q - 1); setobj2s(L, p, L->top); - return; } void lua_replace(lua_State* L, int idx) @@ -277,7 +270,6 @@ void lua_replace(lua_State* L, int idx) luaC_barrier(L, curr_func(L), L->top - 1); } L->top--; - return; } void lua_pushvalue(lua_State* L, int idx) @@ -286,7 +278,6 @@ void lua_pushvalue(lua_State* L, int idx) StkId o = index2addr(L, idx); setobj2s(L, L->top, o); api_incr_top(L); - return; } /* @@ -570,28 +561,24 @@ void lua_pushnil(lua_State* L) { setnilvalue(L->top); api_incr_top(L); - return; } void lua_pushnumber(lua_State* L, double n) { setnvalue(L->top, n); api_incr_top(L); - return; } void lua_pushinteger(lua_State* L, int n) { setnvalue(L->top, cast_num(n)); api_incr_top(L); - return; } void lua_pushunsigned(lua_State* L, unsigned u) { setnvalue(L->top, cast_num(u)); api_incr_top(L); - return; } #if LUA_VECTOR_SIZE == 4 @@ -599,14 +586,12 @@ void lua_pushvector(lua_State* L, float x, float y, float z, float w) { setvvalue(L->top, x, y, z, w); api_incr_top(L); - return; } #else void lua_pushvector(lua_State* L, float x, float y, float z) { setvvalue(L->top, x, y, z, 0.0f); api_incr_top(L); - return; } #endif @@ -616,7 +601,6 @@ void lua_pushlstring(lua_State* L, const char* s, size_t len) luaC_threadbarrier(L); setsvalue(L, L->top, luaS_newlstr(L, s, len)); api_incr_top(L); - return; } void lua_pushstring(lua_State* L, const char* s) @@ -661,21 +645,18 @@ void lua_pushcclosurek(lua_State* L, lua_CFunction fn, const char* debugname, in setclvalue(L, L->top, cl); LUAU_ASSERT(iswhite(obj2gco(cl))); api_incr_top(L); - return; } void lua_pushboolean(lua_State* L, int b) { setbvalue(L->top, (b != 0)); // ensure that true is 1 api_incr_top(L); - return; } void lua_pushlightuserdata(lua_State* L, void* p) { setpvalue(L->top, p); api_incr_top(L); - return; } int lua_pushthread(lua_State* L) @@ -748,7 +729,6 @@ void lua_createtable(lua_State* L, int narray, int nrec) luaC_threadbarrier(L); sethvalue(L, L->top, luaH_new(L, narray, nrec)); api_incr_top(L); - return; } void lua_setreadonly(lua_State* L, int objindex, int enabled) @@ -758,7 +738,6 @@ void lua_setreadonly(lua_State* L, int objindex, int enabled) Table* t = hvalue(o); api_check(L, t != hvalue(registry(L))); t->readonly = bool(enabled); - return; } int lua_getreadonly(lua_State* L, int objindex) @@ -776,7 +755,6 @@ void lua_setsafeenv(lua_State* L, int objindex, int enabled) api_check(L, ttistable(o)); Table* t = hvalue(o); t->safeenv = bool(enabled); - return; } int lua_getmetatable(lua_State* L, int objindex) @@ -822,7 +800,6 @@ void lua_getfenv(lua_State* L, int idx) break; } api_incr_top(L); - return; } /* @@ -836,7 +813,6 @@ void lua_settable(lua_State* L, int idx) api_checkvalidindex(L, t); luaV_settable(L, t, L->top - 2, L->top - 1); L->top -= 2; // pop index and value - return; } void lua_setfield(lua_State* L, int idx, const char* k) @@ -848,7 +824,6 @@ void lua_setfield(lua_State* L, int idx, const char* k) setsvalue(L, &key, luaS_new(L, k)); luaV_settable(L, t, &key, L->top - 1); L->top--; - return; } void lua_rawsetfield(lua_State* L, int idx, const char* k) @@ -861,7 +836,6 @@ void lua_rawsetfield(lua_State* L, int idx, const char* k) setobj2t(L, luaH_setstr(L, hvalue(t), luaS_new(L, k)), L->top - 1); luaC_barriert(L, hvalue(t), L->top - 1); L->top--; - return; } void lua_rawset(lua_State* L, int idx) @@ -874,7 +848,6 @@ void lua_rawset(lua_State* L, int idx) setobj2t(L, luaH_set(L, hvalue(t), L->top - 2), L->top - 1); luaC_barriert(L, hvalue(t), L->top - 1); L->top -= 2; - return; } void lua_rawseti(lua_State* L, int idx, int n) @@ -887,7 +860,6 @@ void lua_rawseti(lua_State* L, int idx, int n) setobj2t(L, luaH_setnum(L, hvalue(o), n), L->top - 1); luaC_barriert(L, hvalue(o), L->top - 1); L->top--; - return; } int lua_setmetatable(lua_State* L, int objindex) @@ -979,7 +951,6 @@ void lua_call(lua_State* L, int nargs, int nresults) luaD_call(L, func, nresults); adjustresults(L, nresults); - return; } /* @@ -995,7 +966,6 @@ static void f_call(lua_State* L, void* ud) { struct CallS* c = cast_to(struct CallS*, ud); luaD_call(L, c->func, c->nresults); - return; } int lua_pcall(lua_State* L, int nargs, int nresults, int errfunc) @@ -1273,7 +1243,6 @@ void lua_concat(lua_State* L, int n) api_incr_top(L); } // else n == 1; nothing to do - return; } void* lua_newuserdatatagged(lua_State* L, size_t sz, int tag) @@ -1397,7 +1366,6 @@ void lua_unref(lua_State* L, int ref) TValue* slot = luaH_setnum(L, reg, ref); setnvalue(slot, g->registryfree); // NB: no barrier needed because value isn't collectable g->registryfree = ref; - return; } void lua_setuserdatatag(lua_State* L, int idx, int tag) diff --git a/tests/AstJsonEncoder.test.cpp b/tests/AstJsonEncoder.test.cpp index 81e749410..a14d5f595 100644 --- a/tests/AstJsonEncoder.test.cpp +++ b/tests/AstJsonEncoder.test.cpp @@ -183,7 +183,7 @@ TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstExprInterpString") AstStat* statement = expectParseStatement("local a = `var = {x}`"); std::string_view expected = - R"({"type":"AstStatLocal","location":"0,0 - 0,18","vars":[{"luauType":null,"name":"a","type":"AstLocal","location":"0,6 - 0,7"}],"values":[{"type":"AstExprInterpString","location":"0,10 - 0,18","strings":["var = ",""],"expressions":[{"type":"AstExprGlobal","location":"0,18 - 0,19","global":"x"}]}]})"; + R"({"type":"AstStatLocal","location":"0,0 - 0,21","vars":[{"luauType":null,"name":"a","type":"AstLocal","location":"0,6 - 0,7"}],"values":[{"type":"AstExprInterpString","location":"0,10 - 0,21","strings":["var = ",""],"expressions":[{"type":"AstExprGlobal","location":"0,18 - 0,19","global":"x"}]}]})"; CHECK(toJson(statement) == expected); } diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index 9a5c3411c..45baec2ce 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -7,6 +7,7 @@ #include "Luau/StringUtils.h" #include "Fixture.h" +#include "ScopedFlags.h" #include "doctest.h" @@ -2708,13 +2709,77 @@ a = if temp then even else abc@3 CHECK(ac.entryMap.count("abcdef")); } -TEST_CASE_FIXTURE(ACFixture, "autocomplete_interpolated_string") +TEST_CASE_FIXTURE(ACFixture, "autocomplete_interpolated_string_constant") { + ScopedFastFlag sff{"LuauInterpolatedStringBaseSupport", true}; + + check(R"(f(`@1`))"); + auto ac = autocomplete('1'); + CHECK(ac.entryMap.empty()); + CHECK_EQ(ac.context, AutocompleteContext::String); + + check(R"(f(`@1 {"a"}`))"); + ac = autocomplete('1'); + CHECK(ac.entryMap.empty()); + CHECK_EQ(ac.context, AutocompleteContext::String); + + check(R"(f(`{"a"} @1`))"); + ac = autocomplete('1'); + CHECK(ac.entryMap.empty()); + CHECK_EQ(ac.context, AutocompleteContext::String); + + check(R"(f(`{"a"} @1 {"b"}`))"); + ac = autocomplete('1'); + CHECK(ac.entryMap.empty()); + CHECK_EQ(ac.context, AutocompleteContext::String); +} + +TEST_CASE_FIXTURE(ACFixture, "autocomplete_interpolated_string_expression") +{ + ScopedFastFlag sff{"LuauInterpolatedStringBaseSupport", true}; + check(R"(f(`expression = {@1}`))"); + auto ac = autocomplete('1'); + CHECK(ac.entryMap.count("table")); + CHECK_EQ(ac.context, AutocompleteContext::Expression); +} + +TEST_CASE_FIXTURE(ACFixture, "autocomplete_interpolated_string_expression_with_comments") +{ + ScopedFastFlag sff{"LuauInterpolatedStringBaseSupport", true}; + + check(R"(f(`expression = {--[[ bla bla bla ]]@1`))"); auto ac = autocomplete('1'); CHECK(ac.entryMap.count("table")); CHECK_EQ(ac.context, AutocompleteContext::Expression); + + check(R"(f(`expression = {@1 --[[ bla bla bla ]]`))"); + ac = autocomplete('1'); + CHECK(!ac.entryMap.empty()); + CHECK(ac.entryMap.count("table")); + CHECK_EQ(ac.context, AutocompleteContext::Expression); +} + +TEST_CASE_FIXTURE(ACFixture, "autocomplete_interpolated_string_as_singleton") +{ + ScopedFastFlag sff{"LuauInterpolatedStringBaseSupport", true}; + + check(R"( + --!strict + local function f(a: "cat" | "dog") end + + f(`@1`) + f(`uhhh{'try'}@2`) + )"); + + auto ac = autocomplete('1'); + CHECK(ac.entryMap.count("cat")); + CHECK_EQ(ac.context, AutocompleteContext::String); + + ac = autocomplete('2'); + CHECK(ac.entryMap.empty()); + CHECK_EQ(ac.context, AutocompleteContext::String); } TEST_CASE_FIXTURE(ACFixture, "autocomplete_explicit_type_pack") diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index bf0a0af6c..10186e3a0 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -50,7 +50,7 @@ TEST_CASE_FIXTURE(Fixture, "augment_table") const TableTypeVar* tType = get(requireType("t")); REQUIRE(tType != nullptr); - CHECK(1 == tType->props.count("foo")); + CHECK("{ foo: string }" == toString(requireType("t"), {true})); } TEST_CASE_FIXTURE(Fixture, "augment_nested_table") @@ -65,7 +65,7 @@ TEST_CASE_FIXTURE(Fixture, "augment_nested_table") const TableTypeVar* pType = get(tType->props["p"].type); REQUIRE(pType != nullptr); - CHECK(pType->props.find("foo") != pType->props.end()); + CHECK("{ p: { foo: string } }" == toString(requireType("t"), {true})); } TEST_CASE_FIXTURE(Fixture, "cannot_augment_sealed_table") diff --git a/tools/faillist.txt b/tools/faillist.txt index a228c1714..433d0cfbe 100644 --- a/tools/faillist.txt +++ b/tools/faillist.txt @@ -51,7 +51,6 @@ BuiltinTests.dont_add_definitions_to_persistent_types BuiltinTests.find_capture_types BuiltinTests.find_capture_types2 BuiltinTests.find_capture_types3 -BuiltinTests.gmatch_capture_types2 BuiltinTests.gmatch_capture_types_balanced_escaped_parens BuiltinTests.gmatch_capture_types_default_capture BuiltinTests.gmatch_capture_types_parens_in_sets_are_ignored @@ -73,9 +72,7 @@ BuiltinTests.string_format_arg_count_mismatch BuiltinTests.string_format_as_method BuiltinTests.string_format_correctly_ordered_types BuiltinTests.string_format_report_all_type_errors_at_correct_positions -BuiltinTests.string_format_use_correct_argument BuiltinTests.string_format_use_correct_argument2 -BuiltinTests.strings_have_methods BuiltinTests.table_freeze_is_generic BuiltinTests.table_insert_correctly_infers_type_of_array_2_args_overload BuiltinTests.table_insert_correctly_infers_type_of_array_3_args_overload @@ -115,7 +112,6 @@ GenericsTests.infer_generic_function_function_argument GenericsTests.infer_generic_function_function_argument_overloaded GenericsTests.infer_generic_methods GenericsTests.infer_generic_property -GenericsTests.instantiate_cyclic_generic_function GenericsTests.instantiated_function_argument_names GenericsTests.instantiation_sharing_types GenericsTests.no_stack_overflow_from_quantifying @@ -198,7 +194,6 @@ RefinementTest.x_as_any_if_x_is_instance_elseif_x_is_table RefinementTest.x_is_not_instance_or_else_not_part RuntimeLimits.typescript_port_of_Result_type TableTests.a_free_shape_can_turn_into_a_scalar_directly -TableTests.a_free_shape_can_turn_into_a_scalar_if_it_is_compatible TableTests.a_free_shape_cannot_turn_into_a_scalar_if_it_is_not_compatible TableTests.access_index_metamethod_that_returns_variadic TableTests.accidentally_checked_prop_in_opposite_branch @@ -269,7 +264,6 @@ TableTests.reasonable_error_when_adding_a_nonexistent_property_to_an_array_like_ TableTests.result_is_always_any_if_lhs_is_any TableTests.result_is_bool_for_equality_operators_if_lhs_is_any TableTests.right_table_missing_key2 -TableTests.scalar_is_a_subtype_of_a_compatible_polymorphic_shape_type TableTests.scalar_is_not_a_subtype_of_a_compatible_polymorphic_shape_type TableTests.shared_selfs TableTests.shared_selfs_from_free_param @@ -285,7 +279,6 @@ TableTests.table_subtyping_with_missing_props_dont_report_multiple_errors TableTests.tables_get_names_from_their_locals TableTests.tc_member_function TableTests.tc_member_function_2 -TableTests.type_mismatch_on_massive_table_is_cut_short TableTests.unification_of_unions_in_a_self_referential_type TableTests.unifying_tables_shouldnt_uaf2 TableTests.used_colon_instead_of_dot @@ -343,8 +336,6 @@ TypeInfer.type_infer_recursion_limit_no_ice TypeInfer.type_infer_recursion_limit_normalizer TypeInferAnyError.for_in_loop_iterator_is_any2 TypeInferAnyError.for_in_loop_iterator_is_error2 -TypeInferClasses.call_base_method -TypeInferClasses.call_instance_method TypeInferClasses.can_read_prop_of_base_class_using_string TypeInferClasses.class_type_mismatch_with_name_conflict TypeInferClasses.classes_without_overloaded_operators_cannot_be_added @@ -429,10 +420,7 @@ TypeInferOperators.typecheck_overloaded_multiply_that_is_an_intersection TypeInferOperators.typecheck_overloaded_multiply_that_is_an_intersection_on_rhs TypeInferOperators.UnknownGlobalCompoundAssign TypeInferPrimitives.CheckMethodsOfNumber -TypeInferPrimitives.singleton_types -TypeInferPrimitives.string_function_other TypeInferPrimitives.string_index -TypeInferPrimitives.string_method TypeInferUnknownNever.assign_to_global_which_is_never TypeInferUnknownNever.assign_to_local_which_is_never TypeInferUnknownNever.assign_to_prop_which_is_never From fc459699daad3eb9e373a3f3be5382155bbbeafc Mon Sep 17 00:00:00 2001 From: Vyacheslav Egorov Date: Fri, 2 Dec 2022 12:46:05 +0200 Subject: [PATCH 18/66] Sync to upstream/release/555 --- Analysis/include/Luau/BuiltinDefinitions.h | 1 + Analysis/include/Luau/Connective.h | 4 +- Analysis/include/Luau/Constraint.h | 6 +- .../include/Luau/ConstraintGraphBuilder.h | 4 +- ...DataFlowGraphBuilder.h => DataFlowGraph.h} | 5 + Analysis/include/Luau/Def.h | 53 ++- Analysis/include/Luau/Error.h | 46 +- Analysis/include/Luau/Normalize.h | 6 + Analysis/include/Luau/NotNull.h | 14 + Analysis/include/Luau/RecursionCounter.h | 19 +- Analysis/include/Luau/ToString.h | 4 +- Analysis/include/Luau/TxnLog.h | 2 + Analysis/include/Luau/TypeInfer.h | 13 +- Analysis/include/Luau/TypeUtils.h | 6 +- Analysis/include/Luau/TypeVar.h | 29 +- Analysis/include/Luau/Unifier.h | 4 + Analysis/src/AstQuery.cpp | 54 +-- Analysis/src/Autocomplete.cpp | 3 +- Analysis/src/BuiltinDefinitions.cpp | 8 + Analysis/src/ConstraintGraphBuilder.cpp | 449 +++++++++++++----- Analysis/src/ConstraintSolver.cpp | 27 +- ...FlowGraphBuilder.cpp => DataFlowGraph.cpp} | 37 +- Analysis/src/Def.cpp | 9 +- Analysis/src/Error.cpp | 85 ++-- Analysis/src/Frontend.cpp | 139 ++---- Analysis/src/IostreamHelpers.cpp | 2 + Analysis/src/Normalize.cpp | 71 ++- Analysis/src/ToString.cpp | 141 +++--- Analysis/src/TopoSortStatements.cpp | 4 +- Analysis/src/TxnLog.cpp | 36 ++ Analysis/src/TypeAttach.cpp | 2 +- Analysis/src/TypeChecker2.cpp | 120 +++-- Analysis/src/TypeInfer.cpp | 399 +++++++++++----- Analysis/src/TypePack.cpp | 8 +- Analysis/src/TypeUtils.cpp | 120 +++-- Analysis/src/TypeVar.cpp | 30 +- Analysis/src/TypedAllocator.cpp | 4 + Analysis/src/Unifier.cpp | 322 +++++++++---- Ast/include/Luau/Location.h | 30 ++ Ast/src/Parser.cpp | 15 +- Compiler/src/BytecodeBuilder.cpp | 5 +- Compiler/src/Compiler.cpp | 45 +- Sources.cmake | 15 +- VM/src/loslib.cpp | 15 + tests/AstQuery.test.cpp | 4 - tests/ClassFixture.cpp | 113 +++++ tests/ClassFixture.h | 13 + tests/Compiler.test.cpp | 15 + ...uilder.test.cpp => DataFlowGraph.test.cpp} | 2 +- tests/Module.test.cpp | 10 +- tests/NotNull.test.cpp | 2 + tests/Parser.test.cpp | 8 - tests/ToString.test.cpp | 8 +- tests/TypeInfer.aliases.test.cpp | 31 +- tests/TypeInfer.annotations.test.cpp | 2 - tests/TypeInfer.builtins.test.cpp | 32 +- tests/TypeInfer.classes.test.cpp | 153 +++--- tests/TypeInfer.definitions.test.cpp | 24 +- tests/TypeInfer.functions.test.cpp | 20 +- tests/TypeInfer.generics.test.cpp | 34 +- tests/TypeInfer.intersectionTypes.test.cpp | 55 ++- tests/TypeInfer.modules.test.cpp | 26 +- tests/TypeInfer.operators.test.cpp | 41 ++ tests/TypeInfer.provisional.test.cpp | 20 +- tests/TypeInfer.refinements.test.cpp | 246 +++++++++- tests/TypeInfer.tables.test.cpp | 85 +++- tests/TypeInfer.test.cpp | 15 +- tests/TypeInfer.tryUnify.test.cpp | 77 +++ tests/TypeInfer.typePacks.cpp | 57 ++- tests/TypeInfer.unionTypes.test.cpp | 33 ++ tests/VisitTypeVar.test.cpp | 9 +- tools/faillist.txt | 41 +- 72 files changed, 2520 insertions(+), 1067 deletions(-) rename Analysis/include/Luau/{DataFlowGraphBuilder.h => DataFlowGraph.h} (91%) rename Analysis/src/{DataFlowGraphBuilder.cpp => DataFlowGraph.cpp} (92%) create mode 100644 tests/ClassFixture.cpp create mode 100644 tests/ClassFixture.h rename tests/{DataFlowGraphBuilder.test.cpp => DataFlowGraph.test.cpp} (98%) diff --git a/Analysis/include/Luau/BuiltinDefinitions.h b/Analysis/include/Luau/BuiltinDefinitions.h index 616367bb4..4702995d4 100644 --- a/Analysis/include/Luau/BuiltinDefinitions.h +++ b/Analysis/include/Luau/BuiltinDefinitions.h @@ -40,6 +40,7 @@ TypeId makeFunction( // Polymorphic void attachMagicFunction(TypeId ty, MagicFunction fn); void attachDcrMagicFunction(TypeId ty, DcrMagicFunction fn); +void attachDcrMagicRefinement(TypeId ty, DcrMagicRefinement fn); Property makeProperty(TypeId ty, std::optional documentationSymbol = std::nullopt); void assignPropDocumentationSymbols(TableTypeVar::Props& props, const std::string& baseName); diff --git a/Analysis/include/Luau/Connective.h b/Analysis/include/Luau/Connective.h index c9daa0f9e..4a6be93c3 100644 --- a/Analysis/include/Luau/Connective.h +++ b/Analysis/include/Luau/Connective.h @@ -3,7 +3,6 @@ #include "Luau/Def.h" #include "Luau/TypedAllocator.h" -#include "Luau/TypeVar.h" #include "Luau/Variant.h" #include @@ -11,6 +10,9 @@ namespace Luau { +struct TypeVar; +using TypeId = const TypeVar*; + struct Negation; struct Conjunction; struct Disjunction; diff --git a/Analysis/include/Luau/Constraint.h b/Analysis/include/Luau/Constraint.h index 16a08e879..e13613ed8 100644 --- a/Analysis/include/Luau/Constraint.h +++ b/Analysis/include/Luau/Constraint.h @@ -149,11 +149,15 @@ struct SetPropConstraint TypeId propType; }; -// result ~ if isSingleton D then ~D else unknown where D = discriminantType +// if negation: +// result ~ if isSingleton D then ~D else unknown where D = discriminantType +// if not negation: +// result ~ if isSingleton D then D else unknown where D = discriminantType struct SingletonOrTopTypeConstraint { TypeId resultType; TypeId discriminantType; + bool negated; }; using ConstraintV = Variant expectedType = {}); /** * Checks the body of a function expression. diff --git a/Analysis/include/Luau/DataFlowGraphBuilder.h b/Analysis/include/Luau/DataFlowGraph.h similarity index 91% rename from Analysis/include/Luau/DataFlowGraphBuilder.h rename to Analysis/include/Luau/DataFlowGraph.h index 3a72403e3..bd096ea90 100644 --- a/Analysis/include/Luau/DataFlowGraphBuilder.h +++ b/Analysis/include/Luau/DataFlowGraph.h @@ -69,9 +69,14 @@ struct DataFlowGraphBuilder struct InternalErrorReporter* handle; std::vector> scopes; + // Does not belong in DataFlowGraphBuilder, but the old solver allows properties to escape the scope they were defined in, + // so we will need to be able to emulate this same behavior here too. We can kill this once we have better flow sensitivity. + DenseHashMap> props{nullptr}; + DfgScope* childScope(DfgScope* scope); std::optional use(DfgScope* scope, Symbol symbol, AstExpr* e); + DefId use(DefId def, AstExprIndexName* e); void visit(DfgScope* scope, AstStatBlock* b); void visitBlockWithoutChildScope(DfgScope* scope, AstStatBlock* b); diff --git a/Analysis/include/Luau/Def.h b/Analysis/include/Luau/Def.h index ac1fa132c..1eef7dfdc 100644 --- a/Analysis/include/Luau/Def.h +++ b/Analysis/include/Luau/Def.h @@ -5,37 +5,38 @@ #include "Luau/TypedAllocator.h" #include "Luau/Variant.h" +#include +#include + namespace Luau { -using Def = Variant; - -/** - * We statically approximate a value at runtime using a symbolic value, which we call a Def. - * - * DataFlowGraphBuilder will allocate these defs as a stand-in for some Luau values, and bind them to places that - * can hold a Luau value, and then observes how those defs will commute as it statically evaluate the program. - * - * It must also be noted that defs are a cyclic graph, so it is not safe to recursively traverse into it expecting it to terminate. - */ +struct Def; using DefId = NotNull; +struct FieldMetadata +{ + DefId parent; + std::string propName; +}; + /** - * A "single-object" value. + * A cell is a "single-object" value. * * Leaky implementation note: sometimes "multiple-object" values, but none of which were interesting enough to warrant creating a phi node instead. * That can happen because there's no point in creating a phi node that points to either resultant in `if math.random() > 0.5 then 5 else "hello"`. * This might become of utmost importance if we wanted to do some backward reasoning, e.g. if `5` is taken, then `cond` must be `truthy`. */ -struct Undefined +struct Cell { + std::optional field; }; /** - * A phi node is a union of defs. + * A phi node is a union of cells. * * We need this because we're statically evaluating a program, and sometimes a place may be assigned with - * different defs, and when that happens, we need a special data type that merges in all the defs + * different cells, and when that happens, we need a special data type that merges in all the cells * that will flow into that specific place. For example, consider this simple program: * * ``` @@ -56,23 +57,35 @@ struct Phi std::vector operands; }; -template -T* getMutable(DefId def) +/** + * We statically approximate a value at runtime using a symbolic value, which we call a Def. + * + * DataFlowGraphBuilder will allocate these defs as a stand-in for some Luau values, and bind them to places that + * can hold a Luau value, and then observes how those defs will commute as it statically evaluate the program. + * + * It must also be noted that defs are a cyclic graph, so it is not safe to recursively traverse into it expecting it to terminate. + */ +struct Def { - return get_if(def.get()); -} + using V = Variant; + + V v; +}; template const T* get(DefId def) { - return getMutable(def); + return get_if(&def->v); } struct DefArena { TypedAllocator allocator; - DefId freshDef(); + DefId freshCell(); + DefId freshCell(DefId parent, const std::string& prop); + // TODO: implement once we have cases where we need to merge in definitions + // DefId phi(const std::vector& defs); }; } // namespace Luau diff --git a/Analysis/include/Luau/Error.h b/Analysis/include/Luau/Error.h index f7bd9d502..893880464 100644 --- a/Analysis/include/Luau/Error.h +++ b/Analysis/include/Luau/Error.h @@ -7,21 +7,31 @@ #include "Luau/Variant.h" #include "Luau/TypeArena.h" -LUAU_FASTFLAG(LuauIceExceptionInheritanceChange) - namespace Luau { struct TypeError; + struct TypeMismatch { + enum Context + { + CovariantContext, + InvariantContext + }; + TypeMismatch() = default; TypeMismatch(TypeId wantedType, TypeId givenType); TypeMismatch(TypeId wantedType, TypeId givenType, std::string reason); TypeMismatch(TypeId wantedType, TypeId givenType, std::string reason, std::optional error); + TypeMismatch(TypeId wantedType, TypeId givenType, Context context); + TypeMismatch(TypeId wantedType, TypeId givenType, std::string reason, Context context); + TypeMismatch(TypeId wantedType, TypeId givenType, std::string reason, std::optional error, Context context); + TypeId wantedType = nullptr; TypeId givenType = nullptr; + Context context = CovariantContext; std::string reason; std::shared_ptr error; @@ -312,12 +322,33 @@ struct TypePackMismatch bool operator==(const TypePackMismatch& rhs) const; }; +struct DynamicPropertyLookupOnClassesUnsafe +{ + TypeId ty; + + bool operator==(const DynamicPropertyLookupOnClassesUnsafe& rhs) const; +}; + using TypeErrorData = Variant; + TypesAreUnrelated, NormalizationTooComplex, TypePackMismatch, DynamicPropertyLookupOnClassesUnsafe>; + +struct TypeErrorSummary +{ + Location location; + ModuleName moduleName; + int code; + + TypeErrorSummary(const Location& location, const ModuleName& moduleName, int code) + : location(location) + , moduleName(moduleName) + , code(code) + { + } +}; struct TypeError { @@ -325,6 +356,7 @@ struct TypeError ModuleName moduleName; TypeErrorData data; + static int minCode(); int code() const; TypeError() = default; @@ -342,6 +374,8 @@ struct TypeError } bool operator==(const TypeError& rhs) const; + + TypeErrorSummary summary() const; }; template @@ -406,10 +440,4 @@ class InternalCompilerError : public std::exception const std::optional location; }; -// These two function overloads only exist to facilitate fast flagging a change to InternalCompilerError -// Both functions can be removed when FFlagLuauIceExceptionInheritanceChange is removed and calling code -// can directly throw InternalCompilerError. -[[noreturn]] void throwRuntimeError(const std::string& message); -[[noreturn]] void throwRuntimeError(const std::string& message, const std::string& moduleName); - } // namespace Luau diff --git a/Analysis/include/Luau/Normalize.h b/Analysis/include/Luau/Normalize.h index b28c06a58..d7e104ee5 100644 --- a/Analysis/include/Luau/Normalize.h +++ b/Analysis/include/Luau/Normalize.h @@ -194,6 +194,8 @@ struct NormalizedFunctionType struct NormalizedType; using NormalizedTyvars = std::unordered_map>; +bool isInhabited_DEPRECATED(const NormalizedType& norm); + // A normalized type is either any, unknown, or one of the form P | T | F | G where // * P is a union of primitive types (including singletons, classes and the error type) // * T is a union of table types @@ -328,6 +330,10 @@ class Normalizer bool intersectNormals(NormalizedType& here, const NormalizedType& there, int ignoreSmallerTyvars = -1); bool intersectNormalWithTy(NormalizedType& here, TypeId there); + // Check for inhabitance + bool isInhabited(TypeId ty, std::unordered_set seen = {}); + bool isInhabited(const NormalizedType* norm, std::unordered_set seen = {}); + // -------- Convert back from a normalized type to a type TypeId typeFromNormal(const NormalizedType& norm); }; diff --git a/Analysis/include/Luau/NotNull.h b/Analysis/include/Luau/NotNull.h index 714fa1437..ecdcb4769 100644 --- a/Analysis/include/Luau/NotNull.h +++ b/Analysis/include/Luau/NotNull.h @@ -59,6 +59,20 @@ struct NotNull return ptr; } + template + bool operator==(NotNull other) const noexcept + { + return get() == other.get(); + } + + template + bool operator!=(NotNull other) const noexcept + { + return get() != other.get(); + } + + operator bool() const noexcept = delete; + T& operator[](int) = delete; T& operator+(int) = delete; diff --git a/Analysis/include/Luau/RecursionCounter.h b/Analysis/include/Luau/RecursionCounter.h index 632afd195..77af10a0a 100644 --- a/Analysis/include/Luau/RecursionCounter.h +++ b/Analysis/include/Luau/RecursionCounter.h @@ -15,16 +15,6 @@ struct RecursionLimitException : public InternalCompilerError RecursionLimitException() : InternalCompilerError("Internal recursion counter limit exceeded") { - LUAU_ASSERT(FFlag::LuauIceExceptionInheritanceChange); - } -}; - -struct RecursionLimitException_DEPRECATED : public std::exception -{ - const char* what() const noexcept - { - LUAU_ASSERT(!FFlag::LuauIceExceptionInheritanceChange); - return "Internal recursion counter limit exceeded"; } }; @@ -53,14 +43,7 @@ struct RecursionLimiter : RecursionCounter { if (limit > 0 && *count > limit) { - if (FFlag::LuauIceExceptionInheritanceChange) - { - throw RecursionLimitException(); - } - else - { - throw RecursionLimitException_DEPRECATED(); - } + throw RecursionLimitException(); } } }; diff --git a/Analysis/include/Luau/ToString.h b/Analysis/include/Luau/ToString.h index 0200a7190..186cc9a5b 100644 --- a/Analysis/include/Luau/ToString.h +++ b/Analysis/include/Luau/ToString.h @@ -30,7 +30,7 @@ struct ToStringOptions bool hideTableKind = false; // If true, all tables will be surrounded with plain '{}' bool hideNamedFunctionTypeParameters = false; // If true, type parameters of functions will be hidden at top-level. bool hideFunctionSelfArgument = false; // If true, `self: X` will be omitted from the function signature if the function has self - bool indent = false; + bool DEPRECATED_indent = false; // TODO Deprecated field, prune when clipping flag FFlagLuauLineBreaksDeterminIndents size_t maxTableLength = size_t(FInt::LuauTableTypeMaximumStringifierLength); // Only applied to TableTypeVars size_t maxTypeLength = size_t(FInt::LuauTypeMaximumStringifierLength); ToStringNameMap nameMap; @@ -90,8 +90,6 @@ inline std::string toString(const Constraint& c) return toString(c, ToStringOptions{}); } -std::string toString(const LValue& lvalue); - std::string toString(const TypeVar& tv, ToStringOptions& opts); std::string toString(const TypePackVar& tp, ToStringOptions& opts); diff --git a/Analysis/include/Luau/TxnLog.h b/Analysis/include/Luau/TxnLog.h index 3c3122c27..b1a834126 100644 --- a/Analysis/include/Luau/TxnLog.h +++ b/Analysis/include/Luau/TxnLog.h @@ -108,6 +108,8 @@ struct TxnLog // If both logs talk about the same type, pack, or table, the rhs takes // priority. void concat(TxnLog rhs); + void concatAsIntersections(TxnLog rhs, NotNull arena); + void concatAsUnion(TxnLog rhs, NotNull arena); // Commits the TxnLog, rebinding all type pointers to their pending states. // Clears the TxnLog afterwards. diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index 4eaa59694..c6f153d1d 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -54,16 +54,9 @@ class TimeLimitError : public InternalCompilerError explicit TimeLimitError(const std::string& moduleName) : InternalCompilerError("Typeinfer failed to complete in allotted time", moduleName) { - LUAU_ASSERT(FFlag::LuauIceExceptionInheritanceChange); } }; -class TimeLimitError_DEPRECATED : public std::exception -{ -public: - virtual const char* what() const throw(); -}; - // All TypeVars are retained via Environment::typeVars. All TypeIds // within a program are borrowed pointers into this set. struct TypeChecker @@ -95,6 +88,7 @@ struct TypeChecker void check(const ScopePtr& scope, const AstStatDeclareFunction& declaredFunction); void prototype(const ScopePtr& scope, const AstStatTypeAlias& typealias, int subLevel = 0); + void prototype(const ScopePtr& scope, const AstStatDeclareClass& declaredClass); void checkBlock(const ScopePtr& scope, const AstStatBlock& statement); void checkBlockWithoutRecursionCheck(const ScopePtr& scope, const AstStatBlock& statement); @@ -399,6 +393,11 @@ struct TypeChecker */ DenseHashSet, HashBoolNamePair> duplicateTypeAliases; + /** + * A set of incorrect class definitions which is used to avoid a second-pass analysis. + */ + DenseHashSet incorrectClassDefinitions{nullptr}; + std::vector> deferredQuantification; }; diff --git a/Analysis/include/Luau/TypeUtils.h b/Analysis/include/Luau/TypeUtils.h index 085ee21b0..aa9cdde2a 100644 --- a/Analysis/include/Luau/TypeUtils.h +++ b/Analysis/include/Luau/TypeUtils.h @@ -25,9 +25,9 @@ std::optional getIndexTypeFromType(const ScopePtr& scope, ErrorVec& erro // Returns the minimum and maximum number of types the argument list can accept. std::pair> getParameterExtents(const TxnLog* log, TypePackId tp, bool includeHiddenVariadics = false); -// "Render" a type pack out to an array of a given length. Expands variadics and -// various other things to get there. -std::vector flatten(TypeArena& arena, NotNull singletonTypes, TypePackId pack, size_t length); +// Extend the provided pack to at least `length` types. +// Returns a temporary TypePack that contains those types plus a tail. +TypePack extendTypePack(TypeArena& arena, NotNull singletonTypes, TypePackId pack, size_t length); /** * Reduces a union by decomposing to the any/error type if it appears in the diff --git a/Analysis/include/Luau/TypeVar.h b/Analysis/include/Luau/TypeVar.h index 0ab4d4749..d355746a5 100644 --- a/Analysis/include/Luau/TypeVar.h +++ b/Analysis/include/Luau/TypeVar.h @@ -3,6 +3,8 @@ #include "Luau/Ast.h" #include "Luau/Common.h" +#include "Luau/Connective.h" +#include "Luau/DataFlowGraph.h" #include "Luau/DenseHash.h" #include "Luau/Def.h" #include "Luau/NotNull.h" @@ -257,7 +259,17 @@ struct MagicFunctionCallContext TypePackId result; }; -using DcrMagicFunction = std::function; +using DcrMagicFunction = bool (*)(MagicFunctionCallContext); + +struct MagicRefinementContext +{ + ScopePtr scope; + NotNull dfg; + NotNull connectiveArena; + const class AstExprCall* callSite; +}; + +using DcrMagicRefinement = std::vector (*)(MagicRefinementContext); struct FunctionTypeVar { @@ -279,19 +291,20 @@ struct FunctionTypeVar FunctionTypeVar(TypeLevel level, Scope* scope, std::vector generics, std::vector genericPacks, TypePackId argTypes, TypePackId retTypes, std::optional defn = {}, bool hasSelf = false); - TypeLevel level; - Scope* scope = nullptr; + std::optional definition; /// These should all be generic std::vector generics; std::vector genericPacks; - TypePackId argTypes; std::vector> argNames; + Tags tags; + TypeLevel level; + Scope* scope = nullptr; + TypePackId argTypes; TypePackId retTypes; - std::optional definition; - MagicFunction magicFunction = nullptr; // Function pointer, can be nullptr. - DcrMagicFunction dcrMagicFunction = nullptr; // can be nullptr + MagicFunction magicFunction = nullptr; + DcrMagicFunction dcrMagicFunction = nullptr; // Fired only while solving constraints + DcrMagicRefinement dcrMagicRefinement = nullptr; // Fired only while generating constraints bool hasSelf; - Tags tags; bool hasNoGenerics = false; }; diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index b5f58d3c6..af3864ea8 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -120,6 +120,9 @@ struct Unifier std::optional findTablePropertyRespectingMeta(TypeId lhsType, Name name); + TxnLog combineLogsIntoIntersection(std::vector logs); + TxnLog combineLogsIntoUnion(std::vector logs); + public: // Returns true if the type "needle" already occurs within "haystack" and reports an "infinite type error" bool occursCheck(TypeId needle, TypeId haystack); @@ -134,6 +137,7 @@ struct Unifier private: bool isNonstrictMode() const; + TypeMismatch::Context mismatchContext(); void checkChildUnifierTypeMismatch(const ErrorVec& innerErrors, TypeId wantedType, TypeId givenType); void checkChildUnifierTypeMismatch(const ErrorVec& innerErrors, const std::string& prop, TypeId wantedType, TypeId givenType); diff --git a/Analysis/src/AstQuery.cpp b/Analysis/src/AstQuery.cpp index b93c2cc22..85d2320ae 100644 --- a/Analysis/src/AstQuery.cpp +++ b/Analysis/src/AstQuery.cpp @@ -11,15 +11,12 @@ #include -LUAU_FASTFLAGVARIABLE(LuauCheckOverloadedDocSymbol, false) - namespace Luau { namespace { - struct AutocompleteNodeFinder : public AstVisitor { const Position pos; @@ -432,8 +429,6 @@ ExprOrLocal findExprOrLocalAtPosition(const SourceModule& source, Position pos) static std::optional checkOverloadedDocumentationSymbol( const Module& module, const TypeId ty, const AstExpr* parentExpr, const std::optional documentationSymbol) { - LUAU_ASSERT(FFlag::LuauCheckOverloadedDocSymbol); - if (!documentationSymbol) return std::nullopt; @@ -469,40 +464,7 @@ std::optional getDocumentationSymbolAtPosition(const Source AstExpr* parentExpr = ancestry.size() >= 2 ? ancestry[ancestry.size() - 2]->asExpr() : nullptr; if (std::optional binding = findBindingAtPosition(module, source, position)) - { - if (FFlag::LuauCheckOverloadedDocSymbol) - { - return checkOverloadedDocumentationSymbol(module, binding->typeId, parentExpr, binding->documentationSymbol); - } - else - { - if (binding->documentationSymbol) - { - // This might be an overloaded function binding. - if (get(follow(binding->typeId))) - { - TypeId matchingOverload = nullptr; - if (parentExpr && parentExpr->is()) - { - if (auto it = module.astOverloadResolvedTypes.find(parentExpr)) - { - matchingOverload = *it; - } - } - - if (matchingOverload) - { - std::string overloadSymbol = *binding->documentationSymbol + "/overload/"; - // Default toString options are fine for this purpose. - overloadSymbol += toString(matchingOverload); - return overloadSymbol; - } - } - } - - return binding->documentationSymbol; - } - } + return checkOverloadedDocumentationSymbol(module, binding->typeId, parentExpr, binding->documentationSymbol); if (targetExpr) { @@ -514,22 +476,12 @@ std::optional getDocumentationSymbolAtPosition(const Source if (const TableTypeVar* ttv = get(parentTy)) { if (auto propIt = ttv->props.find(indexName->index.value); propIt != ttv->props.end()) - { - if (FFlag::LuauCheckOverloadedDocSymbol) - return checkOverloadedDocumentationSymbol(module, propIt->second.type, parentExpr, propIt->second.documentationSymbol); - else - return propIt->second.documentationSymbol; - } + return checkOverloadedDocumentationSymbol(module, propIt->second.type, parentExpr, propIt->second.documentationSymbol); } else if (const ClassTypeVar* ctv = get(parentTy)) { if (auto propIt = ctv->props.find(indexName->index.value); propIt != ctv->props.end()) - { - if (FFlag::LuauCheckOverloadedDocSymbol) - return checkOverloadedDocumentationSymbol(module, propIt->second.type, parentExpr, propIt->second.documentationSymbol); - else - return propIt->second.documentationSymbol; - } + return checkOverloadedDocumentationSymbol(module, propIt->second.type, parentExpr, propIt->second.documentationSymbol); } } } diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index 50dc254fc..5374c6b17 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -1457,7 +1457,8 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M return autocompleteExpression(sourceModule, *module, singletonTypes, &typeArena, ancestry, position); else if (AstStatRepeat* statRepeat = extractStat(ancestry); statRepeat) return {autocompleteStatement(sourceModule, *module, ancestry, position), ancestry, AutocompleteContext::Statement}; - else if (AstExprTable* exprTable = parent->as(); exprTable && (node->is() || node->is() || node->is())) + else if (AstExprTable* exprTable = parent->as(); + exprTable && (node->is() || node->is() || node->is())) { for (const auto& [kind, key, value] : exprTable->items) { diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index 67e3979a7..39568674c 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -132,6 +132,14 @@ void attachDcrMagicFunction(TypeId ty, DcrMagicFunction fn) LUAU_ASSERT(!"Got a non functional type"); } +void attachDcrMagicRefinement(TypeId ty, DcrMagicRefinement fn) +{ + if (auto ftv = getMutable(ty)) + ftv->dcrMagicRefinement = fn; + else + LUAU_ASSERT(!"Got a non functional type"); +} + Property makeProperty(TypeId ty, std::optional documentationSymbol) { return { diff --git a/Analysis/src/ConstraintGraphBuilder.cpp b/Analysis/src/ConstraintGraphBuilder.cpp index e3572fe8c..600b6d23d 100644 --- a/Analysis/src/ConstraintGraphBuilder.cpp +++ b/Analysis/src/ConstraintGraphBuilder.cpp @@ -9,7 +9,9 @@ #include "Luau/ModuleResolver.h" #include "Luau/RecursionCounter.h" #include "Luau/Scope.h" +#include "Luau/Substitution.h" #include "Luau/ToString.h" +#include "Luau/TxnLog.h" #include "Luau/TypeUtils.h" #include "Luau/TypeVar.h" @@ -191,7 +193,7 @@ static void unionRefinements(const std::unordered_map& lhs, const } static void computeRefinement(const ScopePtr& scope, ConnectiveId connective, std::unordered_map* refis, bool sense, - NotNull arena, bool eq, std::vector* constraints) + NotNull arena, bool eq, std::vector* constraints) { using RefinementMap = std::unordered_map; @@ -231,10 +233,10 @@ static void computeRefinement(const ScopePtr& scope, ConnectiveId connective, st TypeId discriminantTy = proposition->discriminantTy; if (!sense && !eq) discriminantTy = arena->addType(NegationTypeVar{proposition->discriminantTy}); - else if (!sense && eq) + else if (eq) { discriminantTy = arena->addType(BlockedTypeVar{}); - constraints->push_back(SingletonOrTopTypeConstraint{discriminantTy, proposition->discriminantTy}); + constraints->push_back(SingletonOrTopTypeConstraint{discriminantTy, proposition->discriminantTy, !sense}); } if (auto it = refis->find(proposition->def); it != refis->end()) @@ -244,23 +246,43 @@ static void computeRefinement(const ScopePtr& scope, ConnectiveId connective, st } } +static std::pair computeDiscriminantType(NotNull arena, const ScopePtr& scope, DefId def, TypeId discriminantTy) +{ + LUAU_ASSERT(get(def)); + + while (const Cell* current = get(def)) + { + if (!current->field) + break; + + TableTypeVar::Props props{{current->field->propName, Property{discriminantTy}}}; + discriminantTy = arena->addType(TableTypeVar{std::move(props), std::nullopt, TypeLevel{}, scope.get(), TableState::Sealed}); + + def = current->field->parent; + current = get(def); + } + + return {def, discriminantTy}; +} + void ConstraintGraphBuilder::applyRefinements(const ScopePtr& scope, Location location, ConnectiveId connective) { if (!connective) return; std::unordered_map refinements; - std::vector constraints; + std::vector constraints; computeRefinement(scope, connective, &refinements, /*sense*/ true, arena, /*eq*/ false, &constraints); for (auto [def, discriminantTy] : refinements) { - std::optional defTy = scope->lookup(def); + auto [def2, discriminantTy2] = computeDiscriminantType(arena, scope, def, discriminantTy); + std::optional defTy = scope->lookup(def2); if (!defTy) ice->ice("Every DefId must map to a type!"); - TypeId resultTy = arena->addType(IntersectionTypeVar{{*defTy, discriminantTy}}); - scope->dcrRefinements[def] = resultTy; + TypeId resultTy = arena->addType(IntersectionTypeVar{{*defTy, discriminantTy2}}); + scope->dcrRefinements[def2] = resultTy; } for (auto& c : constraints) @@ -446,15 +468,15 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocal* local) if (i < local->vars.size) { - std::vector packTypes = flatten(*arena, singletonTypes, exprPack, varTypes.size() - i); + TypePack packTypes = extendTypePack(*arena, singletonTypes, exprPack, varTypes.size() - i); // fill out missing values in varTypes with values from exprPack for (size_t j = i; j < varTypes.size(); ++j) { if (!varTypes[j]) { - if (j - i < packTypes.size()) - varTypes[j] = packTypes[j - i]; + if (j - i < packTypes.head.size()) + varTypes[j] = packTypes.head[j - i]; else varTypes[j] = freshType(scope); } @@ -591,9 +613,9 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocalFunction* FunctionSignature sig = checkFunctionSignature(scope, function->func); sig.bodyScope->bindings[function->name] = Binding{sig.signature, function->func->location}; - auto start = checkpoint(this); + Checkpoint start = checkpoint(this); checkFunctionBody(sig.bodyScope, function->func); - auto end = checkpoint(this); + Checkpoint end = checkpoint(this); NotNull constraintScope{sig.signatureScope ? sig.signatureScope.get() : sig.bodyScope.get()}; std::unique_ptr c = @@ -611,7 +633,7 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFunction* funct // Name could be AstStatLocal, AstStatGlobal, AstStatIndexName. // With or without self - TypeId functionType = nullptr; + TypeId generalizedType = arena->addType(BlockedTypeVar{}); FunctionSignature sig = checkFunctionSignature(scope, function->func); @@ -620,62 +642,59 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFunction* funct std::optional existingFunctionTy = scope->lookup(localName->local); if (existingFunctionTy) { - // Duplicate definition - functionType = *existingFunctionTy; + addConstraint(scope, function->name->location, SubtypeConstraint{generalizedType, *existingFunctionTy}); + + Symbol sym{localName->local}; + std::optional def = dfg->getDef(sym); + LUAU_ASSERT(def); + scope->bindings[sym].typeId = generalizedType; + scope->dcrRefinements[*def] = generalizedType; } else - { - functionType = arena->addType(BlockedTypeVar{}); - scope->bindings[localName->local] = Binding{functionType, localName->location}; - } + scope->bindings[localName->local] = Binding{generalizedType, localName->location}; + sig.bodyScope->bindings[localName->local] = Binding{sig.signature, localName->location}; } else if (AstExprGlobal* globalName = function->name->as()) { std::optional existingFunctionTy = scope->lookup(globalName->name); - if (existingFunctionTy) - { - // Duplicate definition - functionType = *existingFunctionTy; - } - else - { - functionType = arena->addType(BlockedTypeVar{}); - rootScope->bindings[globalName->name] = Binding{functionType, globalName->location}; - } + if (!existingFunctionTy) + ice->ice("prepopulateGlobalScope did not populate a global name", globalName->location); + + generalizedType = *existingFunctionTy; + sig.bodyScope->bindings[globalName->name] = Binding{sig.signature, globalName->location}; } else if (AstExprIndexName* indexName = function->name->as()) { TypeId containingTableType = check(scope, indexName->expr).ty; - functionType = arena->addType(BlockedTypeVar{}); - // TODO look into stack utilization. This is probably ok because it scales with AST depth. TypeId prospectiveTableType = arena->addType(TableTypeVar{TableState::Unsealed, TypeLevel{}, scope.get()}); NotNull prospectiveTable{getMutable(prospectiveTableType)}; Property& prop = prospectiveTable->props[indexName->index.value]; - prop.type = functionType; + prop.type = generalizedType; prop.location = function->name->location; addConstraint(scope, indexName->location, SubtypeConstraint{containingTableType, prospectiveTableType}); } else if (AstExprError* err = function->name->as()) { - functionType = singletonTypes->errorRecoveryType(); + generalizedType = singletonTypes->errorRecoveryType(); } - LUAU_ASSERT(functionType != nullptr); + if (generalizedType == nullptr) + ice->ice("generalizedType == nullptr", function->location); - auto start = checkpoint(this); + Checkpoint start = checkpoint(this); checkFunctionBody(sig.bodyScope, function->func); - auto end = checkpoint(this); + Checkpoint end = checkpoint(this); NotNull constraintScope{sig.signatureScope ? sig.signatureScope.get() : sig.bodyScope.get()}; std::unique_ptr c = - std::make_unique(constraintScope, function->name->location, GeneralizationConstraint{functionType, sig.signature}); + std::make_unique(constraintScope, function->name->location, GeneralizationConstraint{generalizedType, sig.signature}); forEachConstraint(start, end, this, [&c](const ConstraintPtr& constraint) { c->dependencies.push_back(NotNull{constraint.get()}); @@ -708,7 +727,9 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatBlock* block) void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatAssign* assign) { TypePackId varPackId = checkLValues(scope, assign->vars); - TypePackId valuePack = checkPack(scope, assign->values).tp; + + TypePack expectedTypes = extendTypePack(*arena, singletonTypes, varPackId, assign->values.size); + TypePackId valuePack = checkPack(scope, assign->values, expectedTypes.head).tp; addConstraint(scope, assign->location, PackSubtypeConstraint{valuePack, varPackId}); } @@ -729,8 +750,6 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatCompoundAssign* void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatIf* ifStatement) { - // TODO: Optimization opportunity, the interior scope of the condition could be - // reused for the then body, so we don't need to refine twice. ScopePtr condScope = childScope(ifStatement->condition, scope); auto [_, connective] = check(condScope, ifStatement->condition, std::nullopt); @@ -986,7 +1005,7 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExpr* InferencePack result; if (AstExprCall* call = expr->as()) - result = {checkPack(scope, call, expectedTypes)}; + result = checkPack(scope, call, expectedTypes); else if (AstExprVarargs* varargs = expr->as()) { if (scope->varargPack) @@ -1010,38 +1029,101 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExpr* InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCall* call, const std::vector& expectedTypes) { + std::vector exprArgs; + if (call->self) + { + AstExprIndexName* indexExpr = call->func->as(); + if (!indexExpr) + ice->ice("method call expression has no 'self'"); + + exprArgs.push_back(indexExpr->expr); + } + exprArgs.insert(exprArgs.end(), call->args.begin(), call->args.end()); + + Checkpoint startCheckpoint = checkpoint(this); TypeId fnType = check(scope, call->func).ty; - auto startCheckpoint = checkpoint(this); + Checkpoint fnEndCheckpoint = checkpoint(this); - std::vector args; + TypePackId expectedArgPack = arena->freshTypePack(scope.get()); + TypePackId expectedRetPack = arena->freshTypePack(scope.get()); + TypeId expectedFunctionType = arena->addType(FunctionTypeVar{expectedArgPack, expectedRetPack}); + + TypeId instantiatedFnType = arena->addType(BlockedTypeVar{}); + addConstraint(scope, call->location, InstantiationConstraint{instantiatedFnType, fnType}); + + NotNull extractArgsConstraint = addConstraint(scope, call->location, SubtypeConstraint{instantiatedFnType, expectedFunctionType}); + + // Fully solve fnType, then extract its argument list as expectedArgPack. + forEachConstraint(startCheckpoint, fnEndCheckpoint, this, [extractArgsConstraint](const ConstraintPtr& constraint) { + extractArgsConstraint->dependencies.emplace_back(constraint.get()); + }); + + const AstExpr* lastArg = exprArgs.size() ? exprArgs[exprArgs.size() - 1] : nullptr; + const bool needTail = lastArg && (lastArg->is() || lastArg->is()); + + TypePack expectedArgs; + + if (!needTail) + expectedArgs = extendTypePack(*arena, singletonTypes, expectedArgPack, exprArgs.size()); + else + expectedArgs = extendTypePack(*arena, singletonTypes, expectedArgPack, exprArgs.size() - 1); - for (AstExpr* arg : call->args) + std::vector connectives; + if (auto ftv = get(follow(fnType)); ftv && ftv->dcrMagicRefinement) { - args.push_back(check(scope, arg).ty); + MagicRefinementContext ctx{globalScope, dfg, NotNull{&connectiveArena}, call}; + connectives = ftv->dcrMagicRefinement(ctx); } - if (call->self) - { - AstExprIndexName* indexExpr = call->func->as(); - if (!indexExpr) - ice->ice("method call expression has no 'self'"); - // The call to `check` we already did on `call->func` should have already produced a type for - // `indexExpr->expr`, so we can get it from `astTypes` to avoid exponential blow-up. - TypeId selfType = astTypes[indexExpr->expr]; + std::vector args; + std::optional argTail; - // If we don't have a type for self, it means we had a code too complex error already. - if (selfType == nullptr) - selfType = singletonTypes->errorRecoveryType(); + Checkpoint argCheckpoint = checkpoint(this); - args.insert(args.begin(), selfType); + for (size_t i = 0; i < exprArgs.size(); ++i) + { + AstExpr* arg = exprArgs[i]; + std::optional expectedType; + if (i < expectedArgs.head.size()) + expectedType = expectedArgs.head[i]; + + if (i == 0 && call->self) + { + // The self type has already been computed as a side effect of + // computing fnType. If computing that did not cause us to exceed a + // recursion limit, we can fetch it from astTypes rather than + // recomputing it. + TypeId* selfTy = astTypes.find(exprArgs[0]); + if (selfTy) + args.push_back(*selfTy); + else + args.push_back(arena->freshType(scope.get())); + } + else if (i < exprArgs.size() - 1 || !(arg->is() || arg->is())) + args.push_back(check(scope, arg, expectedType).ty); + else + argTail = checkPack(scope, arg, {}).tp; // FIXME? not sure about expectedTypes here } + Checkpoint argEndCheckpoint = checkpoint(this); + + // Do not solve argument constraints until after we have extracted the + // expected types from the callable. + forEachConstraint(argCheckpoint, argEndCheckpoint, this, [extractArgsConstraint](const ConstraintPtr& constraint) { + constraint->dependencies.push_back(extractArgsConstraint); + }); + if (matchSetmetatable(*call)) { - LUAU_ASSERT(args.size() == 2); - TypeId target = args[0]; - TypeId mt = args[1]; + TypePack argTailPack; + if (argTail && args.size() < 2) + argTailPack = extendTypePack(*arena, singletonTypes, *argTail, 2 - args.size()); + + LUAU_ASSERT(args.size() + argTailPack.head.size() == 2); + + TypeId target = args.size() > 0 ? args[0] : argTailPack.head[0]; + TypeId mt = args.size() > 1 ? args[1] : argTailPack.head[args.size() == 0 ? 1 : 0]; AstExpr* targetExpr = call->args.data[0]; @@ -1051,18 +1133,16 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCa if (AstExprLocal* targetLocal = targetExpr->as()) scope->bindings[targetLocal->local].typeId = resultTy; - return InferencePack{arena->addTypePack({resultTy})}; + return InferencePack{arena->addTypePack({resultTy}), std::move(connectives)}; } else { - auto endCheckpoint = checkpoint(this); - astOriginalCallTypes[call->func] = fnType; TypeId instantiatedType = arena->addType(BlockedTypeVar{}); // TODO: How do expectedTypes play into this? Do they? TypePackId rets = arena->addTypePack(BlockedTypePack{}); - TypePackId argPack = arena->addTypePack(TypePack{args, {}}); + TypePackId argPack = arena->addTypePack(TypePack{args, argTail}); FunctionTypeVar ftv(TypeLevel{}, scope.get(), argPack, rets); TypeId inferredFnType = arena->addType(ftv); @@ -1071,19 +1151,10 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCa NotNull ic(unqueuedConstraints.back().get()); unqueuedConstraints.push_back( - std::make_unique(NotNull{scope.get()}, call->func->location, SubtypeConstraint{inferredFnType, instantiatedType})); + std::make_unique(NotNull{scope.get()}, call->func->location, SubtypeConstraint{instantiatedType, inferredFnType})); NotNull sc(unqueuedConstraints.back().get()); - // We force constraints produced by checking function arguments to wait - // until after we have resolved the constraint on the function itself. - // This ensures, for instance, that we start inferring the contents of - // lambdas under the assumption that their arguments and return types - // will be compatible with the enclosing function call. - forEachConstraint(startCheckpoint, endCheckpoint, this, [sc](const ConstraintPtr& constraint) { - constraint->dependencies.push_back(sc); - }); - - addConstraint(scope, call->func->location, + NotNull fcc = addConstraint(scope, call->func->location, FunctionCallConstraint{ {ic, sc}, fnType, @@ -1092,7 +1163,16 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCa call, }); - return InferencePack{rets}; + // We force constraints produced by checking function arguments to wait + // until after we have resolved the constraint on the function itself. + // This ensures, for instance, that we start inferring the contents of + // lambdas under the assumption that their arguments and return types + // will be compatible with the enclosing function call. + forEachConstraint(fnEndCheckpoint, argEndCheckpoint, this, [fcc](const ConstraintPtr& constraint) { + fcc->dependencies.emplace_back(constraint.get()); + }); + + return InferencePack{rets, std::move(connectives)}; } } @@ -1133,9 +1213,19 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExpr* expr, st } else if (auto a = expr->as()) { - FunctionSignature sig = checkFunctionSignature(scope, a); + Checkpoint startCheckpoint = checkpoint(this); + FunctionSignature sig = checkFunctionSignature(scope, a, expectedType); checkFunctionBody(sig.bodyScope, a); - return Inference{sig.signature}; + Checkpoint endCheckpoint = checkpoint(this); + + TypeId generalizedTy = arena->addType(BlockedTypeVar{}); + NotNull gc = addConstraint(scope, expr->location, GeneralizationConstraint{generalizedTy, sig.signature}); + + forEachConstraint(startCheckpoint, endCheckpoint, this, [gc](const ConstraintPtr& constraint) { + gc->dependencies.emplace_back(constraint.get()); + }); + + return Inference{generalizedTy}; } else if (auto indexName = expr->as()) result = check(scope, indexName); @@ -1253,10 +1343,83 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprGlobal* gl return Inference{singletonTypes->errorRecoveryType()}; } +static std::optional lookupProp(TypeId ty, const std::string& propName, NotNull arena) +{ + ty = follow(ty); + + if (auto ctv = get(ty)) + { + if (auto prop = lookupClassProp(ctv, propName)) + return prop->type; + } + else if (auto ttv = get(ty)) + { + if (auto it = ttv->props.find(propName); it != ttv->props.end()) + return it->second.type; + } + else if (auto utv = get(ty)) + { + std::vector types; + + for (TypeId ty : utv) + { + if (auto prop = lookupProp(ty, propName, arena)) + { + if (std::find(begin(types), end(types), *prop) == end(types)) + types.push_back(*prop); + } + else + return std::nullopt; + } + + if (types.size() == 1) + return types[0]; + else + return arena->addType(IntersectionTypeVar{std::move(types)}); + } + else if (auto utv = get(ty)) + { + std::vector types; + + for (TypeId ty : utv) + { + if (auto prop = lookupProp(ty, propName, arena)) + { + if (std::find(begin(types), end(types), *prop) == end(types)) + types.push_back(*prop); + } + else + return std::nullopt; + } + + if (types.size() == 1) + return types[0]; + else + return arena->addType(UnionTypeVar{std::move(types)}); + } + + return std::nullopt; +} + Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIndexName* indexName) { TypeId obj = check(scope, indexName->expr).ty; - TypeId result = freshType(scope); + + // HACK: We need to return the actual type for type refinements so that it can invoke the dcrMagicRefinement function. + TypeId result; + if (auto prop = lookupProp(obj, indexName->index.value, arena)) + result = *prop; + else + result = freshType(scope); + + std::optional def = dfg->getDef(indexName); + if (def) + { + if (auto ty = scope->lookup(*def)) + return Inference{*ty, connectiveArena.proposition(*def, singletonTypes->truthyType)}; + else + scope->dcrRefinements[*def] = result; + } TableTypeVar::Props props{{indexName->index.value, Property{result}}}; const std::optional indexer; @@ -1266,7 +1429,10 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIndexName* addConstraint(scope, indexName->expr->location, SubtypeConstraint{obj, expectedTableType}); - return Inference{result}; + if (def) + return Inference{result, connectiveArena.proposition(*def, singletonTypes->truthyType)}; + else + return Inference{result}; } Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIndexExpr* indexExpr) @@ -1555,8 +1721,16 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprTable* exp { if (auto stringKey = item.key->as()) { - expectedValueType = arena->addType(BlockedTypeVar{}); - addConstraint(scope, item.value->location, HasPropConstraint{*expectedValueType, *expectedType, stringKey->value.data}); + ErrorVec errorVec; + std::optional propTy = + findTablePropertyRespectingMeta(singletonTypes, errorVec, follow(*expectedType), stringKey->value.data, item.value->location); + if (propTy) + expectedValueType = propTy; + else + { + expectedValueType = arena->addType(BlockedTypeVar{}); + addConstraint(scope, item.value->location, HasPropConstraint{*expectedValueType, *expectedType, stringKey->value.data}); + } } } @@ -1590,7 +1764,8 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprTable* exp return Inference{ty}; } -ConstraintGraphBuilder::FunctionSignature ConstraintGraphBuilder::checkFunctionSignature(const ScopePtr& parent, AstExprFunction* fn) +ConstraintGraphBuilder::FunctionSignature ConstraintGraphBuilder::checkFunctionSignature( + const ScopePtr& parent, AstExprFunction* fn, std::optional expectedType) { ScopePtr signatureScope = nullptr; ScopePtr bodyScope = nullptr; @@ -1599,22 +1774,22 @@ ConstraintGraphBuilder::FunctionSignature ConstraintGraphBuilder::checkFunctionS std::vector genericTypes; std::vector genericTypePacks; + if (expectedType) + expectedType = follow(*expectedType); + bool hasGenerics = fn->generics.size > 0 || fn->genericPacks.size > 0; - // If we don't have any generics, we can save some memory and compute by not - // creating the signatureScope, which is only used to scope the declared - // generics properly. - if (hasGenerics) - { - signatureScope = childScope(fn, parent); + signatureScope = childScope(fn, parent); - // We need to assign returnType before creating bodyScope so that the - // return type gets propogated to bodyScope. - returnType = freshTypePack(signatureScope); - signatureScope->returnType = returnType; + // We need to assign returnType before creating bodyScope so that the + // return type gets propogated to bodyScope. + returnType = freshTypePack(signatureScope); + signatureScope->returnType = returnType; - bodyScope = childScope(fn->body, signatureScope); + bodyScope = childScope(fn->body, signatureScope); + if (hasGenerics) + { std::vector> genericDefinitions = createGenerics(signatureScope, fn->generics); std::vector> genericPackDefinitions = createGenericPacks(signatureScope, fn->genericPacks); @@ -1631,18 +1806,48 @@ ConstraintGraphBuilder::FunctionSignature ConstraintGraphBuilder::checkFunctionS genericTypePacks.push_back(g.tp); signatureScope->privateTypePackBindings[name] = g.tp; } + + expectedType.reset(); } - else + + std::vector argTypes; + TypePack expectedArgPack; + + const FunctionTypeVar* expectedFunction = expectedType ? get(*expectedType) : nullptr; + + if (expectedFunction) { - bodyScope = childScope(fn, parent); + expectedArgPack = extendTypePack(*arena, singletonTypes, expectedFunction->argTypes, fn->args.size); + + genericTypes = expectedFunction->generics; + genericTypePacks = expectedFunction->genericPacks; + } - returnType = freshTypePack(bodyScope); - bodyScope->returnType = returnType; + for (size_t i = 0; i < fn->args.size; ++i) + { + AstLocal* local = fn->args.data[i]; + + TypeId t = freshType(signatureScope); + argTypes.push_back(t); + signatureScope->bindings[local] = Binding{t, local->location}; + + TypeId annotationTy = t; + + if (local->annotation) + { + annotationTy = resolveType(signatureScope, local->annotation, /* topLevel */ true); + addConstraint(signatureScope, local->annotation->location, SubtypeConstraint{t, annotationTy}); + } + else if (i < expectedArgPack.head.size()) + { + addConstraint(signatureScope, local->location, SubtypeConstraint{t, expectedArgPack.head[i]}); + } - // To eliminate the need to branch on hasGenerics below, we say that the - // signature scope is the body scope when there is no real signature - // scope. - signatureScope = bodyScope; + // HACK: This is the one case where the type of the definition will diverge from the type of the binding. + // We need to do this because there are cases where type refinements needs to have the information available + // at constraint generation time. + if (auto def = dfg->getDef(local)) + signatureScope->dcrRefinements[*def] = annotationTy; } TypePackId varargPack = nullptr; @@ -1654,22 +1859,28 @@ ConstraintGraphBuilder::FunctionSignature ConstraintGraphBuilder::checkFunctionS TypePackId annotationType = resolveTypePack(signatureScope, fn->varargAnnotation); varargPack = annotationType; } + else if (expectedArgPack.tail && get(*expectedArgPack.tail)) + varargPack = *expectedArgPack.tail; else - { - varargPack = arena->freshTypePack(signatureScope.get()); - } + varargPack = singletonTypes->anyTypePack; signatureScope->varargPack = varargPack; + bodyScope->varargPack = varargPack; } else { varargPack = arena->addTypePack(VariadicTypePack{singletonTypes->anyType, /*hidden*/ true}); // We do not add to signatureScope->varargPack because ... is not valid // in functions without an explicit ellipsis. + + signatureScope->varargPack = std::nullopt; + bodyScope->varargPack = std::nullopt; } LUAU_ASSERT(nullptr != varargPack); + // If there is both an annotation and an expected type, the annotation wins. + // Type checking will sort out any discrepancies later. if (fn->returnAnnotation) { TypePackId annotatedRetType = resolveTypePack(signatureScope, *fn->returnAnnotation); @@ -1680,26 +1891,11 @@ ConstraintGraphBuilder::FunctionSignature ConstraintGraphBuilder::checkFunctionS LUAU_ASSERT(get(returnType)); asMutable(returnType)->ty.emplace(annotatedRetType); } - - std::vector argTypes; - - for (AstLocal* local : fn->args) + else if (expectedFunction) { - TypeId t = freshType(signatureScope); - argTypes.push_back(t); - signatureScope->bindings[local] = Binding{t, local->location}; - - if (auto def = dfg->getDef(local)) - signatureScope->dcrRefinements[*def] = t; - - if (local->annotation) - { - TypeId argAnnotation = resolveType(signatureScope, local->annotation, /* topLevel */ true); - addConstraint(signatureScope, local->annotation->location, SubtypeConstraint{t, argAnnotation}); - } + asMutable(returnType)->ty.emplace(expectedFunction->retTypes); } - // TODO: Vararg annotation. // TODO: Preserve argument names in the function's type. FunctionTypeVar actualFunction{TypeLevel{}, parent.get(), arena->addTypePack(argTypes, varargPack), returnType}; @@ -1711,11 +1907,14 @@ ConstraintGraphBuilder::FunctionSignature ConstraintGraphBuilder::checkFunctionS LUAU_ASSERT(actualFunctionType); astTypes[fn] = actualFunctionType; + if (expectedType && get(*expectedType)) + { + asMutable(*expectedType)->ty.emplace(actualFunctionType); + } + return { /* signature */ actualFunctionType, - // Undo the workaround we made above: if there's no signature scope, - // don't report it. - /* signatureScope */ hasGenerics ? signatureScope : nullptr, + /* signatureScope */ signatureScope, /* bodyScope */ bodyScope, }; } diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index 250e7ae22..d59ea70ae 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -1233,9 +1233,22 @@ bool ConstraintSolver::tryDispatch(const HasPropConstraint& c, NotNull(subjectType)) return block(subjectType, constraint); + if (get(subjectType)) + { + TableTypeVar& ttv = asMutable(subjectType)->ty.emplace(TableState::Free, TypeLevel{}, constraint->scope); + ttv.props[c.prop] = Property{c.resultType}; + asMutable(c.resultType)->ty.emplace(constraint->scope); + unblock(c.resultType); + return true; + } + std::optional resultType = lookupTableProp(subjectType, c.prop); if (!resultType) - return false; + { + asMutable(c.resultType)->ty.emplace(singletonTypes->errorRecoveryType()); + unblock(c.resultType); + return true; + } if (isBlocked(*resultType)) { @@ -1418,8 +1431,10 @@ bool ConstraintSolver::tryDispatch(const SingletonOrTopTypeConstraint& c, NotNul TypeId followed = follow(c.discriminantType); // `nil` is a singleton type too! There's only one value of type `nil`. - if (get(followed) || isNil(followed)) + if (c.negated && (get(followed) || isNil(followed))) *asMutable(c.resultType) = NegationTypeVar{c.discriminantType}; + else if (!c.negated && get(followed)) + *asMutable(c.resultType) = BoundTypeVar{c.discriminantType}; else *asMutable(c.resultType) = BoundTypeVar{singletonTypes->unknownType}; @@ -1509,17 +1524,17 @@ bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const Iterabl TypePackId expectedIterArgs = arena->addTypePack({iteratorTy}); unify(iterFtv->argTypes, expectedIterArgs, constraint->scope); - std::vector iterRets = flatten(*arena, singletonTypes, iterFtv->retTypes, 2); + TypePack iterRets = extendTypePack(*arena, singletonTypes, iterFtv->retTypes, 2); - if (iterRets.size() < 1) + if (iterRets.head.size() < 1) { // We've done what we can; this will get reported as an // error by the type checker. return true; } - TypeId nextFn = iterRets[0]; - TypeId table = iterRets.size() == 2 ? iterRets[1] : arena->freshType(constraint->scope); + TypeId nextFn = iterRets.head[0]; + TypeId table = iterRets.head.size() == 2 ? iterRets.head[1] : arena->freshType(constraint->scope); if (std::optional instantiatedNextFn = instantiation.substitute(nextFn)) { diff --git a/Analysis/src/DataFlowGraphBuilder.cpp b/Analysis/src/DataFlowGraph.cpp similarity index 92% rename from Analysis/src/DataFlowGraphBuilder.cpp rename to Analysis/src/DataFlowGraph.cpp index e2c4c2857..cffd00c91 100644 --- a/Analysis/src/DataFlowGraphBuilder.cpp +++ b/Analysis/src/DataFlowGraph.cpp @@ -1,5 +1,5 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "Luau/DataFlowGraphBuilder.h" +#include "Luau/DataFlowGraph.h" #include "Luau/Error.h" @@ -11,6 +11,9 @@ namespace Luau std::optional DataFlowGraph::getDef(const AstExpr* expr) const { + // We need to skip through AstExprGroup because DFG doesn't try its best to transitively + while (auto group = expr->as()) + expr = group->expr; if (auto def = astDefs.find(expr)) return NotNull{*def}; return std::nullopt; @@ -52,16 +55,25 @@ std::optional DataFlowGraphBuilder::use(DfgScope* scope, Symbol symbol, A { for (DfgScope* current = scope; current; current = current->parent) { - if (auto loc = current->bindings.find(symbol)) + if (auto def = current->bindings.find(symbol)) { - graph.astDefs[e] = *loc; - return NotNull{*loc}; + graph.astDefs[e] = *def; + return NotNull{*def}; } } return std::nullopt; } +DefId DataFlowGraphBuilder::use(DefId def, AstExprIndexName* e) +{ + auto& propertyDef = props[def][e->index.value]; + if (!propertyDef) + propertyDef = arena->freshCell(def, e->index.value); + graph.astDefs[e] = propertyDef; + return NotNull{propertyDef}; +} + void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatBlock* b) { DfgScope* child = childScope(scope); @@ -180,7 +192,7 @@ void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatLocal* l) for (AstLocal* local : l->vars) { - DefId def = arena->freshDef(); + DefId def = arena->freshCell(); graph.localDefs[local] = def; scope->bindings[local] = def; } @@ -189,7 +201,7 @@ void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatLocal* l) void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatFor* f) { DfgScope* forScope = childScope(scope); // TODO: loop scope. - DefId def = arena->freshDef(); + DefId def = arena->freshCell(); graph.localDefs[f->var] = def; scope->bindings[f->var] = def; @@ -203,7 +215,7 @@ void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatForIn* f) for (AstLocal* local : f->vars) { - DefId def = arena->freshDef(); + DefId def = arena->freshCell(); graph.localDefs[local] = def; forScope->bindings[local] = def; } @@ -245,7 +257,7 @@ void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatAssign* a) // TODO global? if (auto exprLocal = root->as()) { - DefId def = arena->freshDef(); + DefId def = arena->freshCell(); graph.astDefs[exprLocal] = def; // Update the def in the scope that introduced the local. Not @@ -277,7 +289,7 @@ void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatFunction* f) void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatLocalFunction* l) { - DefId def = arena->freshDef(); + DefId def = arena->freshCell(); graph.localDefs[l->name] = def; scope->bindings[l->name] = def; @@ -354,8 +366,7 @@ ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprInde if (!def) return {}; - // TODO: properties for the above def. - return {}; + return {use(*def, i)}; } ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprIndexExpr* i) @@ -375,14 +386,14 @@ ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprFunc { if (AstLocal* self = f->self) { - DefId def = arena->freshDef(); + DefId def = arena->freshCell(); graph.localDefs[self] = def; scope->bindings[self] = def; } for (AstLocal* param : f->args) { - DefId def = arena->freshDef(); + DefId def = arena->freshCell(); graph.localDefs[param] = def; scope->bindings[param] = def; } diff --git a/Analysis/src/Def.cpp b/Analysis/src/Def.cpp index 935301c86..8ce1129c6 100644 --- a/Analysis/src/Def.cpp +++ b/Analysis/src/Def.cpp @@ -4,9 +4,14 @@ namespace Luau { -DefId DefArena::freshDef() +DefId DefArena::freshCell() { - return NotNull{allocator.allocate(Undefined{})}; + return NotNull{allocator.allocate(Def{Cell{std::nullopt}})}; +} + +DefId DefArena::freshCell(DefId parent, const std::string& prop) +{ + return NotNull{allocator.allocate(Def{Cell{FieldMetadata{parent, prop}}})}; } } // namespace Luau diff --git a/Analysis/src/Error.cpp b/Analysis/src/Error.cpp index ed1a49cde..aefaa2c71 100644 --- a/Analysis/src/Error.cpp +++ b/Analysis/src/Error.cpp @@ -2,12 +2,14 @@ #include "Luau/Error.h" #include "Luau/Clone.h" +#include "Luau/Common.h" #include "Luau/StringUtils.h" #include "Luau/ToString.h" #include +#include -LUAU_FASTFLAGVARIABLE(LuauIceExceptionInheritanceChange, false) +LUAU_FASTFLAGVARIABLE(LuauTypeMismatchInvarianceInError, false) static std::string wrongNumberOfArgsString( size_t expectedCount, std::optional maximumCount, size_t actualCount, const char* argPrefix = nullptr, bool isVariadic = false) @@ -89,6 +91,7 @@ struct ErrorConverter if (result.empty()) result = "Type '" + givenTypeName + "' could not be converted into '" + wantedTypeName + "'"; + if (tm.error) { result += "\ncaused by:\n "; @@ -102,6 +105,10 @@ struct ErrorConverter { result += "; " + tm.reason; } + else if (FFlag::LuauTypeMismatchInvarianceInError && tm.context == TypeMismatch::InvariantContext) + { + result += " in an invariant context"; + } return result; } @@ -467,6 +474,11 @@ struct ErrorConverter { return "Type pack '" + toString(e.givenTp) + "' could not be converted into '" + toString(e.wantedTp) + "'"; } + + std::string operator()(const DynamicPropertyLookupOnClassesUnsafe& e) const + { + return "Attempting a dynamic property access on type '" + Luau::toString(e.ty) + "' is unsafe and may cause exceptions at runtime"; + } }; struct InvalidNameChecker @@ -514,6 +526,30 @@ TypeMismatch::TypeMismatch(TypeId wantedType, TypeId givenType, std::string reas { } +TypeMismatch::TypeMismatch(TypeId wantedType, TypeId givenType, TypeMismatch::Context context) + : wantedType(wantedType) + , givenType(givenType) + , context(context) +{ +} + +TypeMismatch::TypeMismatch(TypeId wantedType, TypeId givenType, std::string reason, TypeMismatch::Context context) + : wantedType(wantedType) + , givenType(givenType) + , context(context) + , reason(reason) +{ +} + +TypeMismatch::TypeMismatch(TypeId wantedType, TypeId givenType, std::string reason, std::optional error, TypeMismatch::Context context) + : wantedType(wantedType) + , givenType(givenType) + , context(context) + , reason(reason) + , error(error ? std::make_shared(std::move(*error)) : nullptr) +{ +} + bool TypeMismatch::operator==(const TypeMismatch& rhs) const { if (!!error != !!rhs.error) @@ -522,7 +558,7 @@ bool TypeMismatch::operator==(const TypeMismatch& rhs) const if (error && !(*error == *rhs.error)) return false; - return *wantedType == *rhs.wantedType && *givenType == *rhs.givenType && reason == rhs.reason; + return *wantedType == *rhs.wantedType && *givenType == *rhs.givenType && reason == rhs.reason && context == rhs.context; } bool UnknownSymbol::operator==(const UnknownSymbol& rhs) const @@ -662,7 +698,17 @@ bool FunctionExitsWithoutReturning::operator==(const FunctionExitsWithoutReturni int TypeError::code() const { - return 1000 + int(data.index()); + return minCode() + int(data.index()); +} + +int TypeError::minCode() +{ + return 1000; +} + +TypeErrorSummary TypeError::summary() const +{ + return TypeErrorSummary{location, moduleName, code()}; } bool TypeError::operator==(const TypeError& rhs) const @@ -730,6 +776,11 @@ bool TypePackMismatch::operator==(const TypePackMismatch& rhs) const return *wantedTp == *rhs.wantedTp && *givenTp == *rhs.givenTp; } +bool DynamicPropertyLookupOnClassesUnsafe::operator==(const DynamicPropertyLookupOnClassesUnsafe& rhs) const +{ + return ty == rhs.ty; +} + std::string toString(const TypeError& error) { return toString(error, TypeErrorToStringOptions{}); @@ -886,6 +937,8 @@ void copyError(T& e, TypeArena& destArena, CloneState cloneState) e.wantedTp = clone(e.wantedTp); e.givenTp = clone(e.givenTp); } + else if constexpr (std::is_same_v) + e.ty = clone(e.ty); else static_assert(always_false_v, "Non-exhaustive type switch"); } @@ -930,30 +983,4 @@ const char* InternalCompilerError::what() const throw() return this->message.data(); } -// TODO: Inline me when LuauIceExceptionInheritanceChange is deleted. -void throwRuntimeError(const std::string& message) -{ - if (FFlag::LuauIceExceptionInheritanceChange) - { - throw InternalCompilerError(message); - } - else - { - throw std::runtime_error(message); - } -} - -// TODO: Inline me when LuauIceExceptionInheritanceChange is deleted. -void throwRuntimeError(const std::string& message, const std::string& moduleName) -{ - if (FFlag::LuauIceExceptionInheritanceChange) - { - throw InternalCompilerError(message, moduleName); - } - else - { - throw std::runtime_error(message); - } -} - } // namespace Luau diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 22a9ecfa3..356ced0b3 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -7,7 +7,7 @@ #include "Luau/Config.h" #include "Luau/ConstraintGraphBuilder.h" #include "Luau/ConstraintSolver.h" -#include "Luau/DataFlowGraphBuilder.h" +#include "Luau/DataFlowGraph.h" #include "Luau/DcrLogger.h" #include "Luau/FileResolver.h" #include "Luau/Parser.h" @@ -31,7 +31,6 @@ LUAU_FASTINTVARIABLE(LuauAutocompleteCheckTimeoutMs, 100) LUAU_FASTFLAGVARIABLE(DebugLuauDeferredConstraintResolution, false) LUAU_FASTFLAG(DebugLuauLogSolverToJson); LUAU_FASTFLAGVARIABLE(LuauFixMarkDirtyReverseDeps, false) -LUAU_FASTFLAGVARIABLE(LuauPersistTypesAfterGeneratingDocSyms, false) namespace Luau { @@ -112,57 +111,32 @@ LoadDefinitionFileResult Frontend::loadDefinitionFile(std::string_view source, c CloneState cloneState; - if (FFlag::LuauPersistTypesAfterGeneratingDocSyms) - { - std::vector typesToPersist; - typesToPersist.reserve(checkedModule->declaredGlobals.size() + checkedModule->getModuleScope()->exportedTypeBindings.size()); - - for (const auto& [name, ty] : checkedModule->declaredGlobals) - { - TypeId globalTy = clone(ty, globalTypes, cloneState); - std::string documentationSymbol = packageName + "/global/" + name; - generateDocumentationSymbols(globalTy, documentationSymbol); - globalScope->bindings[typeChecker.globalNames.names->getOrAdd(name.c_str())] = {globalTy, Location(), false, {}, documentationSymbol}; - - typesToPersist.push_back(globalTy); - } + std::vector typesToPersist; + typesToPersist.reserve(checkedModule->declaredGlobals.size() + checkedModule->getModuleScope()->exportedTypeBindings.size()); - for (const auto& [name, ty] : checkedModule->getModuleScope()->exportedTypeBindings) - { - TypeFun globalTy = clone(ty, globalTypes, cloneState); - std::string documentationSymbol = packageName + "/globaltype/" + name; - generateDocumentationSymbols(globalTy.type, documentationSymbol); - globalScope->exportedTypeBindings[name] = globalTy; - - typesToPersist.push_back(globalTy.type); - } + for (const auto& [name, ty] : checkedModule->declaredGlobals) + { + TypeId globalTy = clone(ty, globalTypes, cloneState); + std::string documentationSymbol = packageName + "/global/" + name; + generateDocumentationSymbols(globalTy, documentationSymbol); + globalScope->bindings[typeChecker.globalNames.names->getOrAdd(name.c_str())] = {globalTy, Location(), false, {}, documentationSymbol}; - for (TypeId ty : typesToPersist) - { - persist(ty); - } + typesToPersist.push_back(globalTy); } - else - { - for (const auto& [name, ty] : checkedModule->declaredGlobals) - { - TypeId globalTy = clone(ty, globalTypes, cloneState); - std::string documentationSymbol = packageName + "/global/" + name; - generateDocumentationSymbols(globalTy, documentationSymbol); - globalScope->bindings[typeChecker.globalNames.names->getOrAdd(name.c_str())] = {globalTy, Location(), false, {}, documentationSymbol}; - persist(globalTy); - } + for (const auto& [name, ty] : checkedModule->getModuleScope()->exportedTypeBindings) + { + TypeFun globalTy = clone(ty, globalTypes, cloneState); + std::string documentationSymbol = packageName + "/globaltype/" + name; + generateDocumentationSymbols(globalTy.type, documentationSymbol); + globalScope->exportedTypeBindings[name] = globalTy; - for (const auto& [name, ty] : checkedModule->getModuleScope()->exportedTypeBindings) - { - TypeFun globalTy = clone(ty, globalTypes, cloneState); - std::string documentationSymbol = packageName + "/globaltype/" + name; - generateDocumentationSymbols(globalTy.type, documentationSymbol); - globalScope->exportedTypeBindings[name] = globalTy; + typesToPersist.push_back(globalTy.type); + } - persist(globalTy.type); - } + for (TypeId ty : typesToPersist) + { + persist(ty); } return LoadDefinitionFileResult{true, parseResult, checkedModule}; @@ -194,57 +168,32 @@ LoadDefinitionFileResult loadDefinitionFile(TypeChecker& typeChecker, ScopePtr t CloneState cloneState; - if (FFlag::LuauPersistTypesAfterGeneratingDocSyms) - { - std::vector typesToPersist; - typesToPersist.reserve(checkedModule->declaredGlobals.size() + checkedModule->getModuleScope()->exportedTypeBindings.size()); - - for (const auto& [name, ty] : checkedModule->declaredGlobals) - { - TypeId globalTy = clone(ty, typeChecker.globalTypes, cloneState); - std::string documentationSymbol = packageName + "/global/" + name; - generateDocumentationSymbols(globalTy, documentationSymbol); - targetScope->bindings[typeChecker.globalNames.names->getOrAdd(name.c_str())] = {globalTy, Location(), false, {}, documentationSymbol}; - - typesToPersist.push_back(globalTy); - } + std::vector typesToPersist; + typesToPersist.reserve(checkedModule->declaredGlobals.size() + checkedModule->getModuleScope()->exportedTypeBindings.size()); - for (const auto& [name, ty] : checkedModule->getModuleScope()->exportedTypeBindings) - { - TypeFun globalTy = clone(ty, typeChecker.globalTypes, cloneState); - std::string documentationSymbol = packageName + "/globaltype/" + name; - generateDocumentationSymbols(globalTy.type, documentationSymbol); - targetScope->exportedTypeBindings[name] = globalTy; - - typesToPersist.push_back(globalTy.type); - } + for (const auto& [name, ty] : checkedModule->declaredGlobals) + { + TypeId globalTy = clone(ty, typeChecker.globalTypes, cloneState); + std::string documentationSymbol = packageName + "/global/" + name; + generateDocumentationSymbols(globalTy, documentationSymbol); + targetScope->bindings[typeChecker.globalNames.names->getOrAdd(name.c_str())] = {globalTy, Location(), false, {}, documentationSymbol}; - for (TypeId ty : typesToPersist) - { - persist(ty); - } + typesToPersist.push_back(globalTy); } - else - { - for (const auto& [name, ty] : checkedModule->declaredGlobals) - { - TypeId globalTy = clone(ty, typeChecker.globalTypes, cloneState); - std::string documentationSymbol = packageName + "/global/" + name; - generateDocumentationSymbols(globalTy, documentationSymbol); - targetScope->bindings[typeChecker.globalNames.names->getOrAdd(name.c_str())] = {globalTy, Location(), false, {}, documentationSymbol}; - persist(globalTy); - } + for (const auto& [name, ty] : checkedModule->getModuleScope()->exportedTypeBindings) + { + TypeFun globalTy = clone(ty, typeChecker.globalTypes, cloneState); + std::string documentationSymbol = packageName + "/globaltype/" + name; + generateDocumentationSymbols(globalTy.type, documentationSymbol); + targetScope->exportedTypeBindings[name] = globalTy; - for (const auto& [name, ty] : checkedModule->getModuleScope()->exportedTypeBindings) - { - TypeFun globalTy = clone(ty, typeChecker.globalTypes, cloneState); - std::string documentationSymbol = packageName + "/globaltype/" + name; - generateDocumentationSymbols(globalTy.type, documentationSymbol); - targetScope->exportedTypeBindings[name] = globalTy; + typesToPersist.push_back(globalTy.type); + } - persist(globalTy.type); - } + for (TypeId ty : typesToPersist) + { + persist(ty); } return LoadDefinitionFileResult{true, parseResult, checkedModule}; @@ -493,13 +442,13 @@ CheckResult Frontend::check(const ModuleName& name, std::optionalsecond == nullptr) - throwRuntimeError("Frontend::modules does not have data for " + name, name); + throw InternalCompilerError("Frontend::modules does not have data for " + name, name); } else { auto it2 = moduleResolver.modules.find(name); if (it2 == moduleResolver.modules.end() || it2->second == nullptr) - throwRuntimeError("Frontend::modules does not have data for " + name, name); + throw InternalCompilerError("Frontend::modules does not have data for " + name, name); } return CheckResult{ @@ -606,7 +555,7 @@ CheckResult Frontend::check(const ModuleName& name, std::optional) stream << "TypePackMismatch { wanted = '" + toString(err.wantedTp) + "', given = '" + toString(err.givenTp) + "' }"; + else if constexpr (std::is_same_v) + stream << "DynamicPropertyLookupOnClassesUnsafe { " << toString(err.ty) << " }"; else static_assert(always_false_v, "Non-exhaustive type switch"); } diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp index 21e9f7874..fa3503fdd 100644 --- a/Analysis/src/Normalize.cpp +++ b/Analysis/src/Normalize.cpp @@ -24,6 +24,7 @@ LUAU_FASTFLAGVARIABLE(LuauNegatedFunctionTypes, false); LUAU_FASTFLAG(LuauUnknownAndNeverType) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) LUAU_FASTFLAG(LuauOverloadedFunctionSubtypingPerf); +LUAU_FASTFLAG(LuauUninhabitedSubAnything) namespace Luau { @@ -240,13 +241,75 @@ NormalizedType::NormalizedType(NotNull singletonTypes) { } -static bool isInhabited(const NormalizedType& norm) +static bool isShallowInhabited(const NormalizedType& norm) { + // This test is just a shallow check, for example it returns `true` for `{ p : never }` return !get(norm.tops) || !get(norm.booleans) || !norm.classes.empty() || !get(norm.errors) || !get(norm.nils) || !get(norm.numbers) || !norm.strings.isNever() || !get(norm.threads) || !norm.functions.isNever() || !norm.tables.empty() || !norm.tyvars.empty(); } +bool isInhabited_DEPRECATED(const NormalizedType& norm) +{ + LUAU_ASSERT(!FFlag::LuauUninhabitedSubAnything); + return isShallowInhabited(norm); +} + +bool Normalizer::isInhabited(const NormalizedType* norm, std::unordered_set seen) +{ + if (!get(norm->tops) || !get(norm->booleans) || !get(norm->errors) || + !get(norm->nils) || !get(norm->numbers) || !get(norm->threads) || + !norm->classes.empty() || !norm->strings.isNever() || !norm->functions.isNever()) + return true; + + for (const auto& [_, intersect] : norm->tyvars) + { + if (isInhabited(intersect.get(), seen)) + return true; + } + + for (TypeId table : norm->tables) + { + if (isInhabited(table, seen)) + return true; + } + + return false; +} + +bool Normalizer::isInhabited(TypeId ty, std::unordered_set seen) +{ + // TODO: use log.follow(ty), CLI-64291 + ty = follow(ty); + + if (get(ty)) + return false; + + if (!get(ty) && !get(ty) && !get(ty) && !get(ty)) + return true; + + if (seen.count(ty)) + return true; + + seen.insert(ty); + + if (const TableTypeVar* ttv = get(ty)) + { + for (const auto& [_, prop] : ttv->props) + { + if (!isInhabited(prop.type, seen)) + return false; + } + return true; + } + + if (const MetatableTypeVar* mtv = get(ty)) + return isInhabited(mtv->table, seen) && isInhabited(mtv->metatable, seen); + + const NormalizedType* norm = normalize(ty); + return isInhabited(norm, seen); +} + static int tyvarIndex(TypeId ty) { if (const GenericTypeVar* gtv = get(ty)) @@ -378,7 +441,7 @@ static bool isNormalizedTyvar(const NormalizedTyvars& tyvars) { if (!isPlainTyvar(tyvar)) return false; - if (!isInhabited(*intersect)) + if (!isShallowInhabited(*intersect)) return false; for (auto& [other, _] : intersect->tyvars) if (tyvarIndex(other) <= tyvarIndex(tyvar)) @@ -1852,7 +1915,7 @@ bool Normalizer::intersectTyvarsWithTy(NormalizedTyvars& here, TypeId there) NormalizedType& inter = *it->second; if (!intersectNormalWithTy(inter, there)) return false; - if (isInhabited(inter)) + if (isShallowInhabited(inter)) ++it; else it = here.erase(it); @@ -1914,7 +1977,7 @@ bool Normalizer::intersectNormals(NormalizedType& here, const NormalizedType& th if (!intersectNormals(inter, *found->second, index)) return false; } - if (isInhabited(inter)) + if (isShallowInhabited(inter)) it++; else it = here.tyvars.erase(it); diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index 9e1fed26e..145e7fa76 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -11,8 +11,8 @@ #include LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) -LUAU_FASTFLAG(LuauLvaluelessPath) LUAU_FASTFLAG(LuauUnknownAndNeverType) +LUAU_FASTFLAGVARIABLE(LuauLineBreaksDetermineIndents, false) LUAU_FASTFLAGVARIABLE(LuauFunctionReturnStringificationFixup, false) LUAU_FASTFLAGVARIABLE(LuauUnseeArrayTtv, false) @@ -272,10 +272,20 @@ struct StringifierState private: void emitIndentation() { - if (!opts.indent) - return; + if (!FFlag::LuauLineBreaksDetermineIndents) + { + if (!opts.DEPRECATED_indent) + return; - emit(std::string(indentation, ' ')); + emit(std::string(indentation, ' ')); + } + else + { + if (!opts.useLineBreaks) + return; + + emit(std::string(indentation, ' ')); + } } }; @@ -444,7 +454,7 @@ struct TypeVarStringifier return; default: LUAU_ASSERT(!"Unknown primitive type"); - throwRuntimeError("Unknown primitive type " + std::to_string(ptv.type)); + throw InternalCompilerError("Unknown primitive type " + std::to_string(ptv.type)); } } @@ -461,7 +471,7 @@ struct TypeVarStringifier else { LUAU_ASSERT(!"Unknown singleton type"); - throwRuntimeError("Unknown singleton type"); + throw InternalCompilerError("Unknown singleton type"); } } @@ -507,24 +517,13 @@ struct TypeVarStringifier bool plural = true; - if (FFlag::LuauFunctionReturnStringificationFixup) + auto retBegin = begin(ftv.retTypes); + auto retEnd = end(ftv.retTypes); + if (retBegin != retEnd) { - auto retBegin = begin(ftv.retTypes); - auto retEnd = end(ftv.retTypes); - if (retBegin != retEnd) - { - ++retBegin; - if (retBegin == retEnd && !retBegin.tail()) - plural = false; - } - } - else - { - if (auto retPack = get(follow(ftv.retTypes))) - { - if (retPack->head.size() == 1 && !retPack->tail) - plural = false; - } + ++retBegin; + if (retBegin == retEnd && !retBegin.tail()) + plural = false; } if (plural) @@ -978,8 +977,6 @@ struct TypePackStringifier void operator()(TypePackId tp, const GenericTypePack& pack) { - if (FFlag::DebugLuauVerboseTypeNames) - state.emit("gen-"); if (pack.explicitName) { state.usedNames.insert(pack.name); @@ -990,6 +987,15 @@ struct TypePackStringifier { state.emit(state.getName(tp)); } + + if (FFlag::DebugLuauVerboseTypeNames) + { + state.emit("-"); + if (FFlag::DebugLuauDeferredConstraintResolution) + state.emitLevel(pack.scope); + else + state.emit(pack.level); + } state.emit("..."); } @@ -1139,7 +1145,7 @@ ToStringResult toStringDetailed(TypeId ty, ToStringOptions& opts) else tvs.stringify(ty); - if (!state.cycleNames.empty()) + if (!state.cycleNames.empty() || !state.cycleTpNames.empty()) { result.cycle = true; state.emit(" where "); @@ -1169,6 +1175,29 @@ ToStringResult toStringDetailed(TypeId ty, ToStringOptions& opts) semi = true; } + std::vector> sortedCycleTpNames(state.cycleTpNames.begin(), state.cycleTpNames.end()); + std::sort(sortedCycleTpNames.begin(), sortedCycleTpNames.end(), [](const auto& a, const auto& b) { + return a.second < b.second; + }); + + TypePackStringifier tps{state}; + + for (const auto& [cycleTp, name] : sortedCycleTpNames) + { + if (semi) + state.emit(" ; "); + + state.emit(name); + state.emit(" = "); + Luau::visit( + [&tps, cycleTy = cycleTp](auto&& t) { + return tps(cycleTy, t); + }, + cycleTp->ty); + + semi = true; + } + if (opts.maxTypeLength > 0 && result.name.length() > opts.maxTypeLength) { result.truncated = true; @@ -1351,22 +1380,30 @@ std::string toStringNamedFunction(const std::string& funcName, const FunctionTyp return result.name; } +static ToStringOptions& dumpOptions() +{ + static ToStringOptions opts = ([]() { + ToStringOptions o; + o.exhaustive = true; + o.functionTypeArguments = true; + o.maxTableLength = 0; + o.maxTypeLength = 0; + return o; + })(); + + return opts; +} + std::string dump(TypeId ty) { - ToStringOptions opts; - opts.exhaustive = true; - opts.functionTypeArguments = true; - std::string s = toString(ty, opts); + std::string s = toString(ty, dumpOptions()); printf("%s\n", s.c_str()); return s; } std::string dump(TypePackId ty) { - ToStringOptions opts; - opts.exhaustive = true; - opts.functionTypeArguments = true; - std::string s = toString(ty, opts); + std::string s = toString(ty, dumpOptions()); printf("%s\n", s.c_str()); return s; } @@ -1381,10 +1418,7 @@ std::string dump(const ScopePtr& scope, const char* name) } TypeId ty = binding->typeId; - ToStringOptions opts; - opts.exhaustive = true; - opts.functionTypeArguments = true; - std::string s = toString(ty, opts); + std::string s = toString(ty, dumpOptions()); printf("%s\n", s.c_str()); return s; } @@ -1403,8 +1437,7 @@ std::string toString(const Constraint& constraint, ToStringOptions& opts) auto go = [&opts](auto&& c) -> std::string { using T = std::decay_t; - auto tos = [&opts](auto&& a) - { + auto tos = [&opts](auto&& a) { return toString(a, opts); }; @@ -1470,8 +1503,7 @@ std::string toString(const Constraint& constraint, ToStringOptions& opts) } else if constexpr (std::is_same_v) { - return tos(c.resultType) + " ~ prim " + tos(c.expectedType) + ", " + tos(c.singletonType) + ", " + - tos(c.multitonType); + return tos(c.resultType) + " ~ prim " + tos(c.expectedType) + ", " + tos(c.singletonType) + ", " + tos(c.multitonType); } else if constexpr (std::is_same_v) { @@ -1487,7 +1519,10 @@ std::string toString(const Constraint& constraint, ToStringOptions& opts) std::string result = tos(c.resultType); std::string discriminant = tos(c.discriminantType); - return result + " ~ if isSingleton D then ~D else unknown where D = " + discriminant; + if (c.negated) + return result + " ~ if isSingleton D then ~D else unknown where D = " + discriminant; + else + return result + " ~ if isSingleton D then D else unknown where D = " + discriminant; } else static_assert(always_false_v, "Non-exhaustive constraint switch"); @@ -1506,28 +1541,8 @@ std::string dump(const Constraint& c) return s; } -std::string toString(const LValue& lvalue) -{ - LUAU_ASSERT(!FFlag::LuauLvaluelessPath); - - std::string s; - for (const LValue* current = &lvalue; current; current = baseof(*current)) - { - if (auto field = get(*current)) - s = "." + field->key + s; - else if (auto symbol = get(*current)) - s = toString(*symbol) + s; - else - LUAU_ASSERT(!"Unknown LValue"); - } - - return s; -} - std::optional getFunctionNameAsString(const AstExpr& expr) { - LUAU_ASSERT(FFlag::LuauLvaluelessPath); - const AstExpr* curr = &expr; std::string s; diff --git a/Analysis/src/TopoSortStatements.cpp b/Analysis/src/TopoSortStatements.cpp index 052c10dea..fbf741158 100644 --- a/Analysis/src/TopoSortStatements.cpp +++ b/Analysis/src/TopoSortStatements.cpp @@ -150,7 +150,7 @@ Identifier mkName(const AstStatFunction& function) auto name = mkName(*function.name); LUAU_ASSERT(bool(name)); if (!name) - throwRuntimeError("Internal error: Function declaration has a bad name"); + throw InternalCompilerError("Internal error: Function declaration has a bad name"); return *name; } @@ -256,7 +256,7 @@ struct ArcCollector : public AstVisitor { auto name = mkName(*node->name); if (!name) - throwRuntimeError("Internal error: AstStatFunction has a bad name"); + throw InternalCompilerError("Internal error: AstStatFunction has a bad name"); add(*name); return true; diff --git a/Analysis/src/TxnLog.cpp b/Analysis/src/TxnLog.cpp index 034aeaeca..1a73b049f 100644 --- a/Analysis/src/TxnLog.cpp +++ b/Analysis/src/TxnLog.cpp @@ -78,6 +78,42 @@ void TxnLog::concat(TxnLog rhs) typePackChanges[tp] = std::move(rep); } +void TxnLog::concatAsIntersections(TxnLog rhs, NotNull arena) +{ + for (auto& [ty, rightRep] : rhs.typeVarChanges) + { + if (auto leftRep = typeVarChanges.find(ty)) + { + TypeId leftTy = arena->addType((*leftRep)->pending); + TypeId rightTy = arena->addType(rightRep->pending); + typeVarChanges[ty]->pending.ty = IntersectionTypeVar{{leftTy, rightTy}}; + } + else + typeVarChanges[ty] = std::move(rightRep); + } + + for (auto& [tp, rep] : rhs.typePackChanges) + typePackChanges[tp] = std::move(rep); +} + +void TxnLog::concatAsUnion(TxnLog rhs, NotNull arena) +{ + for (auto& [ty, rightRep] : rhs.typeVarChanges) + { + if (auto leftRep = typeVarChanges.find(ty)) + { + TypeId leftTy = arena->addType((*leftRep)->pending); + TypeId rightTy = arena->addType(rightRep->pending); + typeVarChanges[ty]->pending.ty = UnionTypeVar{{leftTy, rightTy}}; + } + else + typeVarChanges[ty] = std::move(rightRep); + } + + for (auto& [tp, rep] : rhs.typePackChanges) + typePackChanges[tp] = std::move(rep); +} + void TxnLog::commit() { for (auto& [ty, rep] : typeVarChanges) diff --git a/Analysis/src/TypeAttach.cpp b/Analysis/src/TypeAttach.cpp index c97ed05d2..e483c0473 100644 --- a/Analysis/src/TypeAttach.cpp +++ b/Analysis/src/TypeAttach.cpp @@ -341,7 +341,7 @@ class TypeRehydrationVisitor AstType* operator()(const NegationTypeVar& ntv) { // FIXME: do the same thing we do with ErrorTypeVar - throwRuntimeError("Cannot convert NegationTypeVar into AstNode"); + throw InternalCompilerError("Cannot convert NegationTypeVar into AstNode"); } private: diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index 35493bdb2..84c0ca3b0 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -295,11 +295,11 @@ struct TypeChecker2 Scope* scope = findInnermostScope(ret->location); TypePackId expectedRetType = scope->returnType; - TypeArena arena; - TypePackId actualRetType = reconstructPack(ret->list, arena); + TypeArena* arena = &module->internalTypes; + TypePackId actualRetType = reconstructPack(ret->list, *arena); UnifierSharedState sharedState{&ice}; - Normalizer normalizer{&arena, singletonTypes, NotNull{&sharedState}}; + Normalizer normalizer{arena, singletonTypes, NotNull{&sharedState}}; Unifier u{NotNull{&normalizer}, Mode::Strict, stack.back(), ret->location, Covariant}; u.tryUnify(actualRetType, expectedRetType); @@ -424,13 +424,13 @@ struct TypeChecker2 TypePackId iteratorPack = arena.addTypePack(valueTypes, iteratorTail); // ... and then expand it out to 3 values (if possible) - const std::vector iteratorTypes = flatten(arena, singletonTypes, iteratorPack, 3); - if (iteratorTypes.empty()) + TypePack iteratorTypes = extendTypePack(arena, singletonTypes, iteratorPack, 3); + if (iteratorTypes.head.empty()) { reportError(GenericError{"for..in loops require at least one value to iterate over. Got zero"}, getLocation(forInStatement->values)); return; } - TypeId iteratorTy = follow(iteratorTypes[0]); + TypeId iteratorTy = follow(iteratorTypes.head[0]); auto checkFunction = [this, &arena, &scope, &forInStatement, &variableTypes]( const FunctionTypeVar* iterFtv, std::vector iterTys, bool isMm) { @@ -445,8 +445,8 @@ struct TypeChecker2 } // It is okay if there aren't enough iterators, but the iteratee must provide enough. - std::vector expectedVariableTypes = flatten(arena, singletonTypes, iterFtv->retTypes, variableTypes.size()); - if (expectedVariableTypes.size() < variableTypes.size()) + TypePack expectedVariableTypes = extendTypePack(arena, singletonTypes, iterFtv->retTypes, variableTypes.size()); + if (expectedVariableTypes.head.size() < variableTypes.size()) { if (isMm) reportError( @@ -455,8 +455,8 @@ struct TypeChecker2 reportError(GenericError{"next() does not return enough values"}, forInStatement->values.data[0]->location); } - for (size_t i = 0; i < std::min(expectedVariableTypes.size(), variableTypes.size()); ++i) - reportErrors(tryUnify(scope, forInStatement->vars.data[i]->location, variableTypes[i], expectedVariableTypes[i])); + for (size_t i = 0; i < std::min(expectedVariableTypes.head.size(), variableTypes.size()); ++i) + reportErrors(tryUnify(scope, forInStatement->vars.data[i]->location, variableTypes[i], expectedVariableTypes.head[i])); // nextFn is going to be invoked with (arrayTy, startIndexTy) @@ -477,25 +477,25 @@ struct TypeChecker2 if (maxCount && *maxCount < 2) reportError(CountMismatch{2, std::nullopt, *maxCount, CountMismatch::Arg}, forInStatement->vars.data[0]->location); - const std::vector flattenedArgTypes = flatten(arena, singletonTypes, iterFtv->argTypes, 2); + TypePack flattenedArgTypes = extendTypePack(arena, singletonTypes, iterFtv->argTypes, 2); size_t firstIterationArgCount = iterTys.empty() ? 0 : iterTys.size() - 1; - size_t actualArgCount = expectedVariableTypes.size(); + size_t actualArgCount = expectedVariableTypes.head.size(); if (firstIterationArgCount < minCount) reportError(CountMismatch{2, std::nullopt, firstIterationArgCount, CountMismatch::Arg}, forInStatement->vars.data[0]->location); else if (actualArgCount < minCount) reportError(CountMismatch{2, std::nullopt, actualArgCount, CountMismatch::Arg}, forInStatement->vars.data[0]->location); - if (iterTys.size() >= 2 && flattenedArgTypes.size() > 0) + if (iterTys.size() >= 2 && flattenedArgTypes.head.size() > 0) { size_t valueIndex = forInStatement->values.size > 1 ? 1 : 0; - reportErrors(tryUnify(scope, forInStatement->values.data[valueIndex]->location, iterTys[1], flattenedArgTypes[0])); + reportErrors(tryUnify(scope, forInStatement->values.data[valueIndex]->location, iterTys[1], flattenedArgTypes.head[0])); } - if (iterTys.size() == 3 && flattenedArgTypes.size() > 1) + if (iterTys.size() == 3 && flattenedArgTypes.head.size() > 1) { size_t valueIndex = forInStatement->values.size > 2 ? 2 : 0; - reportErrors(tryUnify(scope, forInStatement->values.data[valueIndex]->location, iterTys[2], flattenedArgTypes[1])); + reportErrors(tryUnify(scope, forInStatement->values.data[valueIndex]->location, iterTys[2], flattenedArgTypes.head[1])); } }; @@ -516,7 +516,7 @@ struct TypeChecker2 */ if (const FunctionTypeVar* nextFn = get(iteratorTy)) { - checkFunction(nextFn, iteratorTypes, false); + checkFunction(nextFn, iteratorTypes.head, false); } else if (const TableTypeVar* ttv = get(iteratorTy)) { @@ -545,19 +545,19 @@ struct TypeChecker2 TypePackId argPack = arena.addTypePack({iteratorTy}); reportErrors(tryUnify(scope, forInStatement->values.data[0]->location, argPack, iterMmFtv->argTypes)); - std::vector mmIteratorTypes = flatten(arena, singletonTypes, iterMmFtv->retTypes, 3); + TypePack mmIteratorTypes = extendTypePack(arena, singletonTypes, iterMmFtv->retTypes, 3); - if (mmIteratorTypes.size() == 0) + if (mmIteratorTypes.head.size() == 0) { reportError(GenericError{"__iter must return at least one value"}, forInStatement->values.data[0]->location); return; } - TypeId nextFn = follow(mmIteratorTypes[0]); + TypeId nextFn = follow(mmIteratorTypes.head[0]); if (std::optional instantiatedNextFn = instantiation.substitute(nextFn)) { - std::vector instantiatedIteratorTypes = mmIteratorTypes; + std::vector instantiatedIteratorTypes = mmIteratorTypes.head; instantiatedIteratorTypes[0] = *instantiatedNextFn; if (const FunctionTypeVar* nextFtv = get(*instantiatedNextFn)) @@ -800,8 +800,8 @@ struct TypeChecker2 for (AstExpr* arg : call->args) visit(arg); - TypeArena arena; - Instantiation instantiation{TxnLog::empty(), &arena, TypeLevel{}, stack.back()}; + TypeArena* arena = &module->internalTypes; + Instantiation instantiation{TxnLog::empty(), arena, TypeLevel{}, stack.back()}; TypePackId expectedRetType = lookupPack(call); TypeId functionType = lookupType(call->func); @@ -845,30 +845,70 @@ struct TypeChecker2 return; } } + else if (auto utv = get(functionType)) + { + // Sometimes it's okay to call a union of functions, but only if all of the functions are the same. + std::optional fst; + for (TypeId ty : utv) + { + if (!fst) + fst = follow(ty); + else if (fst != follow(ty)) + { + reportError(CannotCallNonFunction{functionType}, call->func->location); + return; + } + } + + if (!fst) + ice.ice("UnionTypeVar had no elements, so fst is nullopt?"); + + if (std::optional instantiatedFunctionType = instantiation.substitute(*fst)) + { + testFunctionType = *instantiatedFunctionType; + } + else + { + reportError(UnificationTooComplex{}, call->func->location); + return; + } + } else { reportError(CannotCallNonFunction{functionType}, call->func->location); return; } - for (AstExpr* arg : call->args) - { - TypeId argTy = lookupType(arg); - args.head.push_back(argTy); - } - if (call->self) { AstExprIndexName* indexExpr = call->func->as(); if (!indexExpr) ice.ice("method call expression has no 'self'"); - args.head.insert(args.head.begin(), lookupType(indexExpr->expr)); + args.head.push_back(lookupType(indexExpr->expr)); } - TypePackId argsTp = arena.addTypePack(args); + for (size_t i = 0; i < call->args.size; ++i) + { + AstExpr* arg = call->args.data[i]; + TypeId* argTy = module->astTypes.find(arg); + if (argTy) + args.head.push_back(*argTy); + else if (i == call->args.size - 1) + { + TypePackId* argTail = module->astTypePacks.find(arg); + if (argTail) + args.tail = *argTail; + else + args.tail = singletonTypes->anyTypePack; + } + else + args.head.push_back(singletonTypes->anyType); + } + + TypePackId argsTp = arena->addTypePack(args); FunctionTypeVar ftv{argsTp, expectedRetType}; - TypeId expectedType = arena.addType(ftv); + TypeId expectedType = arena->addType(ftv); if (!isSubtype(testFunctionType, expectedType, stack.back())) { @@ -881,19 +921,7 @@ struct TypeChecker2 void visit(AstExprIndexName* indexName) { TypeId leftType = lookupType(indexName->expr); - TypeId resultType = lookupType(indexName); - - // leftType must have a property called indexName->index - - std::optional ty = - getIndexTypeFromType(module->getModuleScope(), leftType, indexName->index.value, indexName->location, /* addErrors */ true); - if (ty) - { - if (!isSubtype(resultType, *ty, stack.back())) - { - reportError(TypeMismatch{resultType, *ty}, indexName->location); - } - } + getIndexTypeFromType(module->getModuleScope(), leftType, indexName->index.value, indexName->location, /* addErrors */ true); } void visit(AstExprIndexExpr* indexExpr) @@ -1085,7 +1113,7 @@ struct TypeChecker2 if (mm) { - if (const FunctionTypeVar* ftv = get(*mm)) + if (const FunctionTypeVar* ftv = get(follow(*mm))) { TypePackId expectedArgs; // For >= and > we invoke __lt and __le respectively with diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 8ecd45bd5..9f64a6010 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -35,11 +35,12 @@ LUAU_FASTFLAG(LuauTypeNormalization2) LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) LUAU_FASTFLAGVARIABLE(LuauReturnAnyInsteadOfICE, false) // Eventually removed as false. LUAU_FASTFLAGVARIABLE(DebugLuauSharedSelf, false) -LUAU_FASTFLAGVARIABLE(LuauLvaluelessPath, false) LUAU_FASTFLAGVARIABLE(LuauNilIterator, false) LUAU_FASTFLAGVARIABLE(LuauUnknownAndNeverType, false) +LUAU_FASTFLAGVARIABLE(LuauTypeInferMissingFollows, false) LUAU_FASTFLAGVARIABLE(LuauBinaryNeedsExpectedTypesToo, false) LUAU_FASTFLAGVARIABLE(LuauNeverTypesAndOperatorsInference, false) +LUAU_FASTFLAGVARIABLE(LuauFollowInLvalueIndexCheck, false) LUAU_FASTFLAGVARIABLE(LuauReturnsFromCallsitesAreNotWidened, false) LUAU_FASTFLAGVARIABLE(LuauTryhardAnd, false) LUAU_FASTFLAG(LuauInstantiateInSubtyping) @@ -47,16 +48,15 @@ LUAU_FASTFLAGVARIABLE(LuauCompleteVisitor, false) LUAU_FASTFLAGVARIABLE(LuauOptionalNextKey, false) LUAU_FASTFLAGVARIABLE(LuauReportShadowedTypeAlias, false) LUAU_FASTFLAGVARIABLE(LuauBetterMessagingOnCountMismatch, false) -LUAU_FASTFLAGVARIABLE(LuauArgMismatchReportFunctionLocation, false) +LUAU_FASTFLAGVARIABLE(LuauIntersectionTestForEquality, false) LUAU_FASTFLAGVARIABLE(LuauImplicitElseRefinement, false) +LUAU_FASTFLAGVARIABLE(LuauAllowIndexClassParameters, false) +LUAU_FASTFLAGVARIABLE(LuauDeclareClassPrototype, false) +LUAU_FASTFLAG(LuauUninhabitedSubAnything) +LUAU_FASTFLAGVARIABLE(LuauCallableClasses, false) namespace Luau { -const char* TimeLimitError_DEPRECATED::what() const throw() -{ - LUAU_ASSERT(!FFlag::LuauIceExceptionInheritanceChange); - return "Typeinfer failed to complete in allotted time"; -} static bool typeCouldHaveMetatable(TypeId ty) { @@ -267,11 +267,6 @@ ModulePtr TypeChecker::check(const SourceModule& module, Mode mode, std::optiona reportErrorCodeTooComplex(module.root->location); return std::move(currentModule); } - catch (const RecursionLimitException_DEPRECATED&) - { - reportErrorCodeTooComplex(module.root->location); - return std::move(currentModule); - } } ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mode mode, std::optional environmentScope) @@ -316,10 +311,6 @@ ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mo { currentModule->timeout = true; } - catch (const TimeLimitError_DEPRECATED&) - { - currentModule->timeout = true; - } if (FFlag::DebugLuauSharedSelf) { @@ -356,6 +347,7 @@ ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mo unifierState.skipCacheForType.clear(); duplicateTypeAliases.clear(); + incorrectClassDefinitions.clear(); return std::move(currentModule); } @@ -426,7 +418,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStat& program) ice("Unknown AstStat"); if (finishTime && TimeTrace::getClock() > *finishTime) - throwTimeLimitError(); + throw TimeLimitError(iceHandler->moduleName); } // This particular overload is for do...end. If you need to not increase the scope level, use checkBlock directly. @@ -453,11 +445,6 @@ void TypeChecker::checkBlock(const ScopePtr& scope, const AstStatBlock& block) reportErrorCodeTooComplex(block.location); return; } - catch (const RecursionLimitException_DEPRECATED&) - { - reportErrorCodeTooComplex(block.location); - return; - } } struct InplaceDemoter : TypeVarOnceVisitor @@ -523,6 +510,10 @@ void TypeChecker::checkBlockWithoutRecursionCheck(const ScopePtr& scope, const A prototype(scope, *typealias, subLevel); ++subLevel; } + else if (const auto& declaredClass = stat->as(); FFlag::LuauDeclareClassPrototype && declaredClass) + { + prototype(scope, *declaredClass); + } } auto protoIter = sorted.begin(); @@ -959,9 +950,9 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatAssign& assign) TypeId right = nullptr; - Location loc = 0 == assign.values.size ? assign.location - : i < assign.values.size ? assign.values.data[i]->location - : assign.values.data[assign.values.size - 1]->location; + Location loc = 0 == assign.values.size + ? assign.location + : i < assign.values.size ? assign.values.data[i]->location : assign.values.data[assign.values.size - 1]->location; if (valueIter != valueEnd) { @@ -1670,8 +1661,10 @@ void TypeChecker::prototype(const ScopePtr& scope, const AstStatTypeAlias& typea } } -void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& declaredClass) +void TypeChecker::prototype(const ScopePtr& scope, const AstStatDeclareClass& declaredClass) { + LUAU_ASSERT(FFlag::LuauDeclareClassPrototype); + std::optional superTy = std::nullopt; if (declaredClass.superName) { @@ -1681,6 +1674,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& declar if (!lookupType) { reportError(declaredClass.location, UnknownSymbol{superName, UnknownSymbol::Type}); + incorrectClassDefinitions.insert(&declaredClass); return; } @@ -1692,7 +1686,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& declar { reportError(declaredClass.location, GenericError{format("Cannot use non-class type '%s' as a superclass of class '%s'", superName.c_str(), declaredClass.name.value)}); - + incorrectClassDefinitions.insert(&declaredClass); return; } } @@ -1701,61 +1695,174 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& declar TypeId classTy = addType(ClassTypeVar(className, {}, superTy, std::nullopt, {}, {}, currentModuleName)); ClassTypeVar* ctv = getMutable(classTy); - TypeId metaTy = addType(TableTypeVar{TableState::Sealed, scope->level}); - TableTypeVar* metatable = getMutable(metaTy); ctv->metatable = metaTy; - scope->exportedTypeBindings[className] = TypeFun{{}, classTy}; +} - for (const AstDeclaredClassProp& prop : declaredClass.props) +void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& declaredClass) +{ + if (FFlag::LuauDeclareClassPrototype) { - Name propName(prop.name.value); - TypeId propTy = resolveType(scope, *prop.ty); + Name className(declaredClass.name.value); + + // Don't bother checking if the class definition was incorrect + if (incorrectClassDefinitions.find(&declaredClass)) + return; + + std::optional binding; + if (auto it = scope->exportedTypeBindings.find(className); it != scope->exportedTypeBindings.end()) + binding = it->second; + + // This class definition must have been `prototype()`d first. + if (!binding) + ice("Class not predeclared"); - bool assignToMetatable = isMetamethod(propName); - Luau::ClassTypeVar::Props& assignTo = assignToMetatable ? metatable->props : ctv->props; + TypeId classTy = binding->type; + ClassTypeVar* ctv = getMutable(classTy); - // Function types always take 'self', but this isn't reflected in the - // parsed annotation. Add it here. - if (prop.isMethod) + if (!ctv->metatable) + ice("No metatable for declared class"); + + TableTypeVar* metatable = getMutable(*ctv->metatable); + for (const AstDeclaredClassProp& prop : declaredClass.props) { - if (FunctionTypeVar* ftv = getMutable(propTy)) + Name propName(prop.name.value); + TypeId propTy = resolveType(scope, *prop.ty); + + bool assignToMetatable = isMetamethod(propName); + Luau::ClassTypeVar::Props& assignTo = assignToMetatable ? metatable->props : ctv->props; + + // Function types always take 'self', but this isn't reflected in the + // parsed annotation. Add it here. + if (prop.isMethod) { - ftv->argNames.insert(ftv->argNames.begin(), FunctionArgument{"self", {}}); - ftv->argTypes = addTypePack(TypePack{{classTy}, ftv->argTypes}); - ftv->hasSelf = true; + if (FunctionTypeVar* ftv = getMutable(propTy)) + { + ftv->argNames.insert(ftv->argNames.begin(), FunctionArgument{"self", {}}); + ftv->argTypes = addTypePack(TypePack{{classTy}, ftv->argTypes}); + ftv->hasSelf = true; + } } - } - if (assignTo.count(propName) == 0) - { - assignTo[propName] = {propTy}; + if (assignTo.count(propName) == 0) + { + assignTo[propName] = {propTy}; + } + else + { + TypeId currentTy = assignTo[propName].type; + + // We special-case this logic to keep the intersection flat; otherwise we + // would create a ton of nested intersection types. + if (const IntersectionTypeVar* itv = get(currentTy)) + { + std::vector options = itv->parts; + options.push_back(propTy); + TypeId newItv = addType(IntersectionTypeVar{std::move(options)}); + + assignTo[propName] = {newItv}; + } + else if (get(currentTy)) + { + TypeId intersection = addType(IntersectionTypeVar{{currentTy, propTy}}); + + assignTo[propName] = {intersection}; + } + else + { + reportError(declaredClass.location, GenericError{format("Cannot overload non-function class member '%s'", propName.c_str())}); + } + } } - else + } + else + { + std::optional superTy = std::nullopt; + if (declaredClass.superName) { - TypeId currentTy = assignTo[propName].type; + Name superName = Name(declaredClass.superName->value); + std::optional lookupType = scope->lookupType(superName); - // We special-case this logic to keep the intersection flat; otherwise we - // would create a ton of nested intersection types. - if (const IntersectionTypeVar* itv = get(currentTy)) + if (!lookupType) { - std::vector options = itv->parts; - options.push_back(propTy); - TypeId newItv = addType(IntersectionTypeVar{std::move(options)}); + reportError(declaredClass.location, UnknownSymbol{superName, UnknownSymbol::Type}); + return; + } - assignTo[propName] = {newItv}; + // We don't have generic classes, so this assertion _should_ never be hit. + LUAU_ASSERT(lookupType->typeParams.size() == 0 && lookupType->typePackParams.size() == 0); + superTy = lookupType->type; + + if (!get(follow(*superTy))) + { + reportError(declaredClass.location, GenericError{format("Cannot use non-class type '%s' as a superclass of class '%s'", + superName.c_str(), declaredClass.name.value)}); + return; } - else if (get(currentTy)) + } + + Name className(declaredClass.name.value); + + TypeId classTy = addType(ClassTypeVar(className, {}, superTy, std::nullopt, {}, {}, currentModuleName)); + + ClassTypeVar* ctv = getMutable(classTy); + TypeId metaTy = addType(TableTypeVar{TableState::Sealed, scope->level}); + TableTypeVar* metatable = getMutable(metaTy); + + ctv->metatable = metaTy; + + scope->exportedTypeBindings[className] = TypeFun{{}, classTy}; + + for (const AstDeclaredClassProp& prop : declaredClass.props) + { + Name propName(prop.name.value); + TypeId propTy = resolveType(scope, *prop.ty); + + bool assignToMetatable = isMetamethod(propName); + Luau::ClassTypeVar::Props& assignTo = assignToMetatable ? metatable->props : ctv->props; + + // Function types always take 'self', but this isn't reflected in the + // parsed annotation. Add it here. + if (prop.isMethod) { - TypeId intersection = addType(IntersectionTypeVar{{currentTy, propTy}}); + if (FunctionTypeVar* ftv = getMutable(propTy)) + { + ftv->argNames.insert(ftv->argNames.begin(), FunctionArgument{"self", {}}); + ftv->argTypes = addTypePack(TypePack{{classTy}, ftv->argTypes}); + ftv->hasSelf = true; + } + } - assignTo[propName] = {intersection}; + if (assignTo.count(propName) == 0) + { + assignTo[propName] = {propTy}; } else { - reportError(declaredClass.location, GenericError{format("Cannot overload non-function class member '%s'", propName.c_str())}); + TypeId currentTy = assignTo[propName].type; + + // We special-case this logic to keep the intersection flat; otherwise we + // would create a ton of nested intersection types. + if (const IntersectionTypeVar* itv = get(currentTy)) + { + std::vector options = itv->parts; + options.push_back(propTy); + TypeId newItv = addType(IntersectionTypeVar{std::move(options)}); + + assignTo[propName] = {newItv}; + } + else if (get(currentTy)) + { + TypeId intersection = addType(IntersectionTypeVar{{currentTy, propTy}}); + + assignTo[propName] = {intersection}; + } + else + { + reportError(declaredClass.location, GenericError{format("Cannot overload non-function class member '%s'", propName.c_str())}); + } } } } @@ -2548,6 +2655,48 @@ static std::optional getIdentifierOfBaseVar(AstExpr* node) return std::nullopt; } +/** Return true if comparison between the types a and b should be permitted with + * the == or ~= operators. + * + * Two types are considered eligible for equality testing if it is possible for + * the test to ever succeed. In other words, we test to see whether the two + * types have any overlap at all. + * + * In order to make things work smoothly with the greedy solver, this function + * exempts any and FreeTypeVars from this requirement. + * + * This function does not (yet?) take into account extra Lua restrictions like + * that two tables can only be compared if they have the same metatable. That + * is presently handled by the caller. + * + * @return True if the types are comparable. False if they are not. + * + * If an internal recursion limit is reached while performing this test, the + * function returns std::nullopt. + */ +static std::optional areEqComparable(NotNull arena, NotNull normalizer, TypeId a, TypeId b) +{ + a = follow(a); + b = follow(b); + + auto isExempt = [](TypeId t) { + return isNil(t) || get(t); + }; + + if (isExempt(a) || isExempt(b)) + return true; + + TypeId c = arena->addType(IntersectionTypeVar{{a, b}}); + const NormalizedType* n = normalizer->normalize(c); + if (!n) + return std::nullopt; + + if (FFlag::LuauUninhabitedSubAnything) + return normalizer->isInhabited(n); + else + return isInhabited_DEPRECATED(*n); +} + TypeId TypeChecker::checkRelationalOperation( const ScopePtr& scope, const AstExprBinary& expr, TypeId lhsType, TypeId rhsType, const PredicateVec& predicates) { @@ -2618,6 +2767,28 @@ TypeId TypeChecker::checkRelationalOperation( return booleanType; } + if (FFlag::LuauIntersectionTestForEquality && isEquality) + { + // Unless either type is free or any, an equality comparison is only + // valid when the intersection of the two operands is non-empty. + // + // eg it is okay to compare string? == number? because the two types + // have nil in common, but string == number is not allowed. + std::optional eqTestResult = areEqComparable(NotNull{¤tModule->internalTypes}, NotNull{&normalizer}, lhsType, rhsType); + if (!eqTestResult) + { + reportErrorCodeTooComplex(expr.location); + return errorRecoveryType(booleanType); + } + + if (!*eqTestResult) + { + reportError( + expr.location, GenericError{format("Type %s cannot be compared with %s", toString(lhsType).c_str(), toString(rhsType).c_str())}); + return errorRecoveryType(booleanType); + } + } + /* Subtlety here: * We need to do this unification first, but there are situations where we don't actually want to * report any problems that might have been surfaced as a result of this step because we might already @@ -2630,7 +2801,7 @@ TypeId TypeChecker::checkRelationalOperation( state.log.commit(); } - bool needsMetamethod = !isEquality; + const bool needsMetamethod = !isEquality; TypeId leftType = follow(lhsType); if (get(leftType) || get(leftType) || get(leftType) || get(leftType)) @@ -3212,6 +3383,9 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex TypeId indexType = checkExpr(scope, *expr.index).type; + if (FFlag::LuauFollowInLvalueIndexCheck) + exprType = follow(exprType); + if (get(exprType) || get(exprType)) return exprType; @@ -3233,6 +3407,16 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex return prop->type; } } + else if (FFlag::LuauAllowIndexClassParameters) + { + if (const ClassTypeVar* exprClass = get(exprType)) + { + if (isNonstrictMode()) + return unknownType; + reportError(TypeError{expr.location, DynamicPropertyLookupOnClassesUnsafe{exprType}}); + return errorRecoveryType(scope); + } + } TableTypeVar* exprTable = getMutableTableType(exprType); @@ -3687,16 +3871,8 @@ void TypeChecker::checkArgumentList(const ScopePtr& scope, const AstExpr& funNam std::string namePath; - if (FFlag::LuauLvaluelessPath) - { - if (std::optional path = getFunctionNameAsString(funName)) - namePath = *path; - } - else - { - if (std::optional lValue = tryGetLValue(funName)) - namePath = toString(*lValue); - } + if (std::optional path = getFunctionNameAsString(funName)) + namePath = *path; auto [minParams, optMaxParams] = getParameterExtents(&state.log, paramPack); state.reportError(TypeError{location, @@ -3806,27 +3982,11 @@ void TypeChecker::checkArgumentList(const ScopePtr& scope, const AstExpr& funNam std::string namePath; - if (FFlag::LuauLvaluelessPath) - { - if (std::optional path = getFunctionNameAsString(funName)) - namePath = *path; - } - else - { - if (std::optional lValue = tryGetLValue(funName)) - namePath = toString(*lValue); - } + if (std::optional path = getFunctionNameAsString(funName)) + namePath = *path; - if (FFlag::LuauArgMismatchReportFunctionLocation) - { - state.reportError(TypeError{ - funName.location, CountMismatch{minParams, optMaxParams, paramIndex, CountMismatch::Context::Arg, isVariadic, namePath}}); - } - else - { - state.reportError(TypeError{ - state.location, CountMismatch{minParams, optMaxParams, paramIndex, CountMismatch::Context::Arg, isVariadic, namePath}}); - } + state.reportError(TypeError{ + funName.location, CountMismatch{minParams, optMaxParams, paramIndex, CountMismatch::Context::Arg, isVariadic, namePath}}); return; } ++paramIter; @@ -4135,26 +4295,33 @@ std::optional> TypeChecker::checkCallOverload(const Sc std::vector metaArgLocations; - // Might be a callable table + // Might be a callable table or class + std::optional callTy = std::nullopt; if (const MetatableTypeVar* mttv = get(fn)) { - if (std::optional ty = getIndexTypeFromType(scope, mttv->metatable, "__call", expr.func->location, /* addErrors= */ false)) - { - // Construct arguments with 'self' added in front - TypePackId metaCallArgPack = addTypePack(TypePackVar(TypePack{args->head, args->tail})); + callTy = getIndexTypeFromType(scope, mttv->metatable, "__call", expr.func->location, /* addErrors= */ false); + } + else if (const ClassTypeVar* ctv = get(fn); FFlag::LuauCallableClasses && ctv && ctv->metatable) + { + callTy = getIndexTypeFromType(scope, *ctv->metatable, "__call", expr.func->location, /* addErrors= */ false); + } - TypePack* metaCallArgs = getMutable(metaCallArgPack); - metaCallArgs->head.insert(metaCallArgs->head.begin(), fn); + if (callTy) + { + // Construct arguments with 'self' added in front + TypePackId metaCallArgPack = addTypePack(TypePackVar(TypePack{args->head, args->tail})); - metaArgLocations = *argLocations; - metaArgLocations.insert(metaArgLocations.begin(), expr.func->location); + TypePack* metaCallArgs = getMutable(metaCallArgPack); + metaCallArgs->head.insert(metaCallArgs->head.begin(), fn); - fn = instantiate(scope, *ty, expr.func->location); + metaArgLocations = *argLocations; + metaArgLocations.insert(metaArgLocations.begin(), expr.func->location); - argPack = metaCallArgPack; - args = metaCallArgs; - argLocations = &metaArgLocations; - } + fn = instantiate(scope, *callTy, expr.func->location); + + argPack = metaCallArgPack; + args = metaCallArgs; + argLocations = &metaArgLocations; } const FunctionTypeVar* ftv = get(fn); @@ -4331,7 +4498,7 @@ void TypeChecker::reportOverloadResolutionError(const ScopePtr& scope, const Ast std::string s; for (size_t i = 0; i < overloadTypes.size(); ++i) { - TypeId overload = overloadTypes[i]; + TypeId overload = FFlag::LuauTypeInferMissingFollows ? follow(overloadTypes[i]) : overloadTypes[i]; Unifier state = mkUnifier(scope, expr.location); // Unify return types @@ -4712,7 +4879,10 @@ TypePackId TypeChecker::anyifyModuleReturnTypePackGenerics(TypePackId tp) tp = follow(tp); if (const VariadicTypePack* vtp = get(tp)) - return get(vtp->ty) ? anyTypePack : tp; + { + TypeId ty = FFlag::LuauTypeInferMissingFollows ? follow(vtp->ty) : vtp->ty; + return get(ty) ? anyTypePack : tp; + } if (!get(follow(tp))) return tp; @@ -4763,19 +4933,6 @@ void TypeChecker::ice(const std::string& message) iceHandler->ice(message); } -// TODO: Inline me when LuauIceExceptionInheritanceChange is deleted. -void TypeChecker::throwTimeLimitError() -{ - if (FFlag::LuauIceExceptionInheritanceChange) - { - throw TimeLimitError(iceHandler->moduleName); - } - else - { - throw TimeLimitError_DEPRECATED(); - } -} - void TypeChecker::prepareErrorsForDisplay(ErrorVec& errVec) { // Remove errors with names that were generated by recovery from a parse error @@ -5955,11 +6112,11 @@ void TypeChecker::resolve(const EqPredicate& eqP, RefinementMap& refis, const Sc if (optionIsSubtype && !targetIsSubtype) return option; else if (!optionIsSubtype && targetIsSubtype) - return eqP.type; + return FFlag::LuauTypeInferMissingFollows ? follow(eqP.type) : eqP.type; else if (!optionIsSubtype && !targetIsSubtype) return nope; else if (optionIsSubtype && targetIsSubtype) - return eqP.type; + return FFlag::LuauTypeInferMissingFollows ? follow(eqP.type) : eqP.type; } else { diff --git a/Analysis/src/TypePack.cpp b/Analysis/src/TypePack.cpp index 0852f0535..0f75c3efc 100644 --- a/Analysis/src/TypePack.cpp +++ b/Analysis/src/TypePack.cpp @@ -6,6 +6,8 @@ #include +LUAU_FASTFLAGVARIABLE(LuauTxnLogTypePackIterator, false) + namespace Luau { @@ -60,8 +62,8 @@ TypePackIterator::TypePackIterator(TypePackId typePack) } TypePackIterator::TypePackIterator(TypePackId typePack, const TxnLog* log) - : currentTypePack(follow(typePack)) - , tp(get(currentTypePack)) + : currentTypePack(FFlag::LuauTxnLogTypePackIterator ? log->follow(typePack) : follow(typePack)) + , tp(FFlag::LuauTxnLogTypePackIterator ? log->get(currentTypePack) : get(currentTypePack)) , currentIndex(0) , log(log) { @@ -235,7 +237,7 @@ TypePackId follow(TypePackId tp, std::function mapper) cycleTester = nullptr; if (tp == cycleTester) - throwRuntimeError("Luau::follow detected a TypeVar cycle!!"); + throw InternalCompilerError("Luau::follow detected a TypeVar cycle!!"); } } } diff --git a/Analysis/src/TypeUtils.cpp b/Analysis/src/TypeUtils.cpp index 72597c4a1..876c45a77 100644 --- a/Analysis/src/TypeUtils.cpp +++ b/Analysis/src/TypeUtils.cpp @@ -122,10 +122,6 @@ std::optional getIndexTypeFromType(const ScopePtr& scope, ErrorVec& erro for (TypeId t : utv) { - // TODO: we should probably limit recursion here? - // RecursionLimiter _rl(&recursionCount, FInt::LuauTypeInferRecursionLimit); - - // Not needed when we normalize types. if (get(follow(t))) return t; @@ -164,9 +160,6 @@ std::optional getIndexTypeFromType(const ScopePtr& scope, ErrorVec& erro for (TypeId t : itv->parts) { - // TODO: we should probably limit recursion here? - // RecursionLimiter _rl(&recursionCount, FInt::LuauTypeInferRecursionLimit); - if (std::optional ty = getIndexTypeFromType(scope, errors, arena, singletonTypes, t, prop, location, /* addErrors= */ false, handle)) parts.push_back(*ty); @@ -183,7 +176,7 @@ std::optional getIndexTypeFromType(const ScopePtr& scope, ErrorVec& erro if (parts.size() == 1) return parts[0]; - return arena->addType(IntersectionTypeVar{std::move(parts)}); // Not at all correct. + return arena->addType(IntersectionTypeVar{std::move(parts)}); } if (addErrors) @@ -221,46 +214,95 @@ std::pair> getParameterExtents(const TxnLog* log, return {minCount, minCount + optionalCount}; } -std::vector flatten(TypeArena& arena, NotNull singletonTypes, TypePackId pack, size_t length) +TypePack extendTypePack(TypeArena& arena, NotNull singletonTypes, TypePackId pack, size_t length) { - std::vector result; + TypePack result; - auto it = begin(pack); - auto endIt = end(pack); - - while (it != endIt) + while (true) { - result.push_back(*it); + pack = follow(pack); + + if (const TypePack* p = get(pack)) + { + size_t i = 0; + while (i < p->head.size() && result.head.size() < length) + { + result.head.push_back(p->head[i]); + ++i; + } + + if (result.head.size() == length) + { + if (i == p->head.size()) + result.tail = p->tail; + else + { + TypePackId newTail = arena.addTypePack(TypePack{}); + TypePack* newTailPack = getMutable(newTail); + + newTailPack->head.insert(newTailPack->head.begin(), p->head.begin() + i, p->head.end()); + newTailPack->tail = p->tail; - if (result.size() >= length) + result.tail = newTail; + } + + return result; + } + else if (p->tail) + { + pack = *p->tail; + continue; + } + else + { + // There just aren't enough types in this pack to satisfy the request. + return result; + } + } + else if (const VariadicTypePack* vtp = get(pack)) + { + while (result.head.size() < length) + result.head.push_back(vtp->ty); + result.tail = pack; return result; + } + else if (FreeTypePack* ftp = getMutable(pack)) + { + // If we need to get concrete types out of a free pack, we choose to + // interpret this as proof that the pack must have at least 'length' + // elements. We mint fresh types for each element we're extracting + // and rebind the free pack to be a TypePack containing them. We + // also have to create a new tail. - ++it; - } + TypePack newPack; + newPack.tail = arena.freshTypePack(ftp->scope); - if (!it.tail()) - return result; + while (result.head.size() < length) + { + newPack.head.push_back(arena.freshType(ftp->scope)); + result.head.push_back(newPack.head.back()); + } - TypePackId tail = *it.tail(); - if (get(tail)) - LUAU_ASSERT(0); - else if (auto vtp = get(tail)) - { - while (result.size() < length) - result.push_back(vtp->ty); - } - else if (get(tail) || get(tail)) - { - while (result.size() < length) - result.push_back(arena.addType(FreeTypeVar{nullptr})); - } - else if (auto etp = get(tail)) - { - while (result.size() < length) - result.push_back(singletonTypes->errorRecoveryType()); - } + asMutable(pack)->ty.emplace(std::move(newPack)); - return result; + return result; + } + else if (const Unifiable::Error* etp = getMutable(pack)) + { + while (result.head.size() < length) + result.head.push_back(singletonTypes->errorRecoveryType()); + + result.tail = pack; + return result; + } + else + { + // If the pack is blocked or generic, we can't extract. + // Return whatever we've got with this pack as the tail. + result.tail = pack; + return result; + } + } } std::vector reduceUnion(const std::vector& types) diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index 814eca0d5..6771d89b0 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -72,7 +72,7 @@ TypeId follow(TypeId t, std::function mapper) { TypeId res = ltv->thunk(); if (get(res)) - throwRuntimeError("Lazy TypeVar cannot resolve to another Lazy TypeVar"); + throw InternalCompilerError("Lazy TypeVar cannot resolve to another Lazy TypeVar"); *asMutable(ty) = BoundTypeVar(res); } @@ -110,7 +110,7 @@ TypeId follow(TypeId t, std::function mapper) cycleTester = nullptr; if (t == cycleTester) - throwRuntimeError("Luau::follow detected a TypeVar cycle!!"); + throw InternalCompilerError("Luau::follow detected a TypeVar cycle!!"); } } } @@ -468,65 +468,65 @@ PendingExpansionTypeVar::PendingExpansionTypeVar( size_t PendingExpansionTypeVar::nextIndex = 0; FunctionTypeVar::FunctionTypeVar(TypePackId argTypes, TypePackId retTypes, std::optional defn, bool hasSelf) - : argTypes(argTypes) + : definition(std::move(defn)) + , argTypes(argTypes) , retTypes(retTypes) - , definition(std::move(defn)) , hasSelf(hasSelf) { } FunctionTypeVar::FunctionTypeVar(TypeLevel level, TypePackId argTypes, TypePackId retTypes, std::optional defn, bool hasSelf) - : level(level) + : definition(std::move(defn)) + , level(level) , argTypes(argTypes) , retTypes(retTypes) - , definition(std::move(defn)) , hasSelf(hasSelf) { } FunctionTypeVar::FunctionTypeVar( TypeLevel level, Scope* scope, TypePackId argTypes, TypePackId retTypes, std::optional defn, bool hasSelf) - : level(level) + : definition(std::move(defn)) + , level(level) , scope(scope) , argTypes(argTypes) , retTypes(retTypes) - , definition(std::move(defn)) , hasSelf(hasSelf) { } FunctionTypeVar::FunctionTypeVar(std::vector generics, std::vector genericPacks, TypePackId argTypes, TypePackId retTypes, std::optional defn, bool hasSelf) - : generics(generics) + : definition(std::move(defn)) + , generics(generics) , genericPacks(genericPacks) , argTypes(argTypes) , retTypes(retTypes) - , definition(std::move(defn)) , hasSelf(hasSelf) { } FunctionTypeVar::FunctionTypeVar(TypeLevel level, std::vector generics, std::vector genericPacks, TypePackId argTypes, TypePackId retTypes, std::optional defn, bool hasSelf) - : level(level) + : definition(std::move(defn)) , generics(generics) , genericPacks(genericPacks) + , level(level) , argTypes(argTypes) , retTypes(retTypes) - , definition(std::move(defn)) , hasSelf(hasSelf) { } FunctionTypeVar::FunctionTypeVar(TypeLevel level, Scope* scope, std::vector generics, std::vector genericPacks, TypePackId argTypes, TypePackId retTypes, std::optional defn, bool hasSelf) - : level(level) - , scope(scope) + : definition(std::move(defn)) , generics(generics) , genericPacks(genericPacks) + , level(level) + , scope(scope) , argTypes(argTypes) , retTypes(retTypes) - , definition(std::move(defn)) , hasSelf(hasSelf) { } diff --git a/Analysis/src/TypedAllocator.cpp b/Analysis/src/TypedAllocator.cpp index 133104d3f..4dc26219c 100644 --- a/Analysis/src/TypedAllocator.cpp +++ b/Analysis/src/TypedAllocator.cpp @@ -48,6 +48,8 @@ void* pagedAllocate(size_t size) // On Linux, we must use mmap because using regular heap results in mprotect() fragmenting the page table and us bumping into 64K mmap limit. #ifdef _WIN32 return _aligned_malloc(size, kPageSize); +#elif defined(__FreeBSD__) + return aligned_alloc(kPageSize, size); #else return mmap(nullptr, pageAlign(size), PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_ANON, -1, 0); #endif @@ -61,6 +63,8 @@ void pagedDeallocate(void* ptr, size_t size) #ifdef _WIN32 _aligned_free(ptr); +#elif defined(__FreeBSD__) + free(ptr); #else int rc = munmap(ptr, size); LUAU_ASSERT(rc == 0); diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 4dc909831..5ff405e62 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -22,8 +22,10 @@ LUAU_FASTFLAGVARIABLE(LuauSubtypeNormalizer, false); LUAU_FASTFLAGVARIABLE(LuauScalarShapeSubtyping, false) LUAU_FASTFLAGVARIABLE(LuauInstantiateInSubtyping, false) LUAU_FASTFLAGVARIABLE(LuauOverloadedFunctionSubtypingPerf, false); -LUAU_FASTFLAGVARIABLE(LuauScalarShapeUnifyToMtOwner, false) +LUAU_FASTFLAGVARIABLE(LuauScalarShapeUnifyToMtOwner2, false) +LUAU_FASTFLAGVARIABLE(LuauUninhabitedSubAnything, false) LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution) +LUAU_FASTFLAG(LuauTxnLogTypePackIterator) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) LUAU_FASTFLAG(LuauNegatedFunctionTypes) @@ -51,6 +53,9 @@ struct PromoteTypeLevels final : TypeVarOnceVisitor template void promote(TID ty, T* t) { + if (FFlag::DebugLuauDeferredConstraintResolution && !t) + return; + LUAU_ASSERT(t); if (useScopes) @@ -102,6 +107,11 @@ struct PromoteTypeLevels final : TypeVarOnceVisitor if (ty->owningArena != typeArena) return false; + // Surprise, it's actually a BoundTypePack that hasn't been committed yet. + // Calling getMutable on this will trigger an assertion. + if (FFlag::LuauScalarShapeUnifyToMtOwner2 && !log.is(ty)) + return true; + promote(ty, log.getMutable(ty)); return true; } @@ -115,6 +125,11 @@ struct PromoteTypeLevels final : TypeVarOnceVisitor if (ttv.state != TableState::Free && ttv.state != TableState::Generic) return true; + // Surprise, it's actually a BoundTypePack that hasn't been committed yet. + // Calling getMutable on this will trigger an assertion. + if (FFlag::LuauScalarShapeUnifyToMtOwner2 && !log.is(ty)) + return true; + promote(ty, log.getMutable(ty)); return true; } @@ -277,7 +292,7 @@ TypeId Widen::clean(TypeId ty) TypePackId Widen::clean(TypePackId) { - throwRuntimeError("Widen attempted to clean a dirty type pack?"); + throw InternalCompilerError("Widen attempted to clean a dirty type pack?"); } bool Widen::ignoreChildren(TypeId ty) @@ -336,6 +351,20 @@ static bool subsumes(bool useScopes, TY_A* left, TY_B* right) return left->level.subsumes(right->level); } +TypeMismatch::Context Unifier::mismatchContext() +{ + switch (variance) + { + case Covariant: + return TypeMismatch::CovariantContext; + case Invariant: + return TypeMismatch::InvariantContext; + default: + LUAU_ASSERT(false); // This codepath should be unreachable. + return TypeMismatch::CovariantContext; + } +} + Unifier::Unifier(NotNull normalizer, Mode mode, NotNull scope, const Location& location, Variance variance, TxnLog* parentLog) : types(normalizer->arena) , singletonTypes(normalizer->singletonTypes) @@ -559,8 +588,11 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool else if (log.get(subTy)) tryUnifyNegationWithType(subTy, superTy); + else if (FFlag::LuauUninhabitedSubAnything && !normalizer->isInhabited(subTy)) + {} + else - reportError(location, TypeMismatch{superTy, subTy}); + reportError(location, TypeMismatch{superTy, subTy, mismatchContext()}); if (cacheEnabled) cacheResult(subTy, superTy, errorCount); @@ -575,11 +607,16 @@ void Unifier::tryUnifyUnionWithType(TypeId subTy, const UnionTypeVar* subUnion, std::optional unificationTooComplex; std::optional firstFailedOption; + std::vector logs; + for (TypeId type : subUnion->options) { Unifier innerState = makeChildUnifier(); innerState.tryUnify_(type, superTy); + if (FFlag::DebugLuauDeferredConstraintResolution) + logs.push_back(std::move(innerState.log)); + if (auto e = hasUnificationTooComplex(innerState.errors)) unificationTooComplex = e; else if (!innerState.errors.empty()) @@ -592,51 +629,56 @@ void Unifier::tryUnifyUnionWithType(TypeId subTy, const UnionTypeVar* subUnion, } } - // even if A | B <: T fails, we want to bind some options of T with A | B iff A | B was a subtype of that option. - auto tryBind = [this, subTy](TypeId superOption) { - superOption = log.follow(superOption); - - // just skip if the superOption is not free-ish. - auto ttv = log.getMutable(superOption); - if (!log.is(superOption) && (!ttv || ttv->state != TableState::Free)) - return; + if (FFlag::DebugLuauDeferredConstraintResolution) + log.concatAsUnion(combineLogsIntoUnion(std::move(logs)), NotNull{types}); + else + { + // even if A | B <: T fails, we want to bind some options of T with A | B iff A | B was a subtype of that option. + auto tryBind = [this, subTy](TypeId superOption) { + superOption = log.follow(superOption); - // If superOption is already present in subTy, do nothing. Nothing new has been learned, but the subtype - // test is successful. - if (auto subUnion = get(subTy)) - { - if (end(subUnion) != std::find(begin(subUnion), end(subUnion), superOption)) + // just skip if the superOption is not free-ish. + auto ttv = log.getMutable(superOption); + if (!log.is(superOption) && (!ttv || ttv->state != TableState::Free)) return; - } - // Since we have already checked if S <: T, checking it again will not queue up the type for replacement. - // So we'll have to do it ourselves. We assume they unified cleanly if they are still in the seen set. - if (log.haveSeen(subTy, superOption)) + // If superOption is already present in subTy, do nothing. Nothing new has been learned, but the subtype + // test is successful. + if (auto subUnion = get(subTy)) + { + if (end(subUnion) != std::find(begin(subUnion), end(subUnion), superOption)) + return; + } + + // Since we have already checked if S <: T, checking it again will not queue up the type for replacement. + // So we'll have to do it ourselves. We assume they unified cleanly if they are still in the seen set. + if (log.haveSeen(subTy, superOption)) + { + // TODO: would it be nice for TxnLog::replace to do this? + if (log.is(superOption)) + log.bindTable(superOption, subTy); + else + log.replace(superOption, *subTy); + } + }; + + if (auto superUnion = log.getMutable(superTy)) { - // TODO: would it be nice for TxnLog::replace to do this? - if (log.is(superOption)) - log.bindTable(superOption, subTy); - else - log.replace(superOption, *subTy); + for (TypeId ty : superUnion) + tryBind(ty); } - }; - - if (auto superUnion = log.getMutable(superTy)) - { - for (TypeId ty : superUnion) - tryBind(ty); + else + tryBind(superTy); } - else - tryBind(superTy); if (unificationTooComplex) reportError(*unificationTooComplex); else if (failed) { if (firstFailedOption) - reportError(location, TypeMismatch{superTy, subTy, "Not all union options are compatible.", *firstFailedOption}); + reportError(location, TypeMismatch{superTy, subTy, "Not all union options are compatible.", *firstFailedOption, mismatchContext()}); else - reportError(location, TypeMismatch{superTy, subTy}); + reportError(location, TypeMismatch{superTy, subTy, mismatchContext()}); } } @@ -696,6 +738,8 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp } } + std::vector logs; + for (size_t i = 0; i < uv->options.size(); ++i) { TypeId type = uv->options[(i + startIndex) % uv->options.size()]; @@ -706,9 +750,13 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp if (innerState.errors.empty()) { found = true; - log.concat(std::move(innerState.log)); - - break; + if (FFlag::DebugLuauDeferredConstraintResolution) + logs.push_back(std::move(innerState.log)); + else + { + log.concat(std::move(innerState.log)); + break; + } } else if (auto e = hasUnificationTooComplex(innerState.errors)) { @@ -723,6 +771,9 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp } } + if (FFlag::DebugLuauDeferredConstraintResolution) + log.concatAsUnion(combineLogsIntoUnion(std::move(logs)), NotNull{types}); + if (unificationTooComplex) { reportError(*unificationTooComplex); @@ -744,9 +795,10 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp else if (!found) { if ((failedOptionCount == 1 || foundHeuristic) && failedOption) - reportError(location, TypeMismatch{superTy, subTy, "None of the union options are compatible. For example:", *failedOption}); + reportError( + location, TypeMismatch{superTy, subTy, "None of the union options are compatible. For example:", *failedOption, mismatchContext()}); else - reportError(location, TypeMismatch{superTy, subTy, "none of the union options are compatible"}); + reportError(location, TypeMismatch{superTy, subTy, "none of the union options are compatible", mismatchContext()}); } } @@ -755,6 +807,8 @@ void Unifier::tryUnifyTypeWithIntersection(TypeId subTy, TypeId superTy, const I std::optional unificationTooComplex; std::optional firstFailedOption; + std::vector logs; + // T <: A & B if and only if T <: A and T <: B for (TypeId type : uv->parts) { @@ -769,13 +823,19 @@ void Unifier::tryUnifyTypeWithIntersection(TypeId subTy, TypeId superTy, const I firstFailedOption = {innerState.errors.front()}; } - log.concat(std::move(innerState.log)); + if (FFlag::DebugLuauDeferredConstraintResolution) + logs.push_back(std::move(innerState.log)); + else + log.concat(std::move(innerState.log)); } + if (FFlag::DebugLuauDeferredConstraintResolution) + log.concat(combineLogsIntoIntersection(std::move(logs))); + if (unificationTooComplex) reportError(*unificationTooComplex); else if (firstFailedOption) - reportError(location, TypeMismatch{superTy, subTy, "Not all intersection parts are compatible.", *firstFailedOption}); + reportError(location, TypeMismatch{superTy, subTy, "Not all intersection parts are compatible.", *firstFailedOption, mismatchContext()}); } void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionTypeVar* uv, TypeId superTy, bool cacheEnabled, bool isFunctionCall) @@ -802,6 +862,8 @@ void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionTypeV } } + std::vector logs; + for (size_t i = 0; i < uv->parts.size(); ++i) { TypeId type = uv->parts[(i + startIndex) % uv->parts.size()]; @@ -812,8 +874,13 @@ void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionTypeV if (innerState.errors.empty()) { found = true; - log.concat(std::move(innerState.log)); - break; + if (FFlag::DebugLuauDeferredConstraintResolution) + logs.push_back(std::move(innerState.log)); + else + { + log.concat(std::move(innerState.log)); + break; + } } else if (auto e = hasUnificationTooComplex(innerState.errors)) { @@ -821,6 +888,9 @@ void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionTypeV } } + if (FFlag::DebugLuauDeferredConstraintResolution) + log.concat(combineLogsIntoIntersection(std::move(logs))); + if (unificationTooComplex) reportError(*unificationTooComplex); else if (!found && normalize) @@ -837,7 +907,7 @@ void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionTypeV } else if (!found) { - reportError(location, TypeMismatch{superTy, subTy, "none of the intersection parts are compatible"}); + reportError(location, TypeMismatch{superTy, subTy, "none of the intersection parts are compatible", mismatchContext()}); } } @@ -849,37 +919,37 @@ void Unifier::tryUnifyNormalizedTypes( if (get(superNorm.tops) || get(superNorm.tops) || get(subNorm.tops)) return; else if (get(subNorm.tops)) - return reportError(location, TypeMismatch{superTy, subTy, reason, error}); + return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); if (get(subNorm.errors)) if (!get(superNorm.errors)) - return reportError(location, TypeMismatch{superTy, subTy, reason, error}); + return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); if (get(subNorm.booleans)) { if (!get(superNorm.booleans)) - return reportError(location, TypeMismatch{superTy, subTy, reason, error}); + return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); } else if (const SingletonTypeVar* stv = get(subNorm.booleans)) { if (!get(superNorm.booleans) && stv != get(superNorm.booleans)) - return reportError(location, TypeMismatch{superTy, subTy, reason, error}); + return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); } if (get(subNorm.nils)) if (!get(superNorm.nils)) - return reportError(location, TypeMismatch{superTy, subTy, reason, error}); + return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); if (get(subNorm.numbers)) if (!get(superNorm.numbers)) - return reportError(location, TypeMismatch{superTy, subTy, reason, error}); + return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); if (!isSubtype(subNorm.strings, superNorm.strings)) - return reportError(location, TypeMismatch{superTy, subTy, reason, error}); + return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); if (get(subNorm.threads)) if (!get(superNorm.errors)) - return reportError(location, TypeMismatch{superTy, subTy, reason, error}); + return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); for (TypeId subClass : subNorm.classes) { @@ -895,7 +965,7 @@ void Unifier::tryUnifyNormalizedTypes( } } if (!found) - return reportError(location, TypeMismatch{superTy, subTy, reason, error}); + return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); } for (TypeId subTable : subNorm.tables) @@ -920,19 +990,19 @@ void Unifier::tryUnifyNormalizedTypes( return reportError(*e); } if (!found) - return reportError(location, TypeMismatch{superTy, subTy, reason, error}); + return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); } if (!subNorm.functions.isNever()) { if (superNorm.functions.isNever()) - return reportError(location, TypeMismatch{superTy, subTy, reason, error}); + return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); for (TypeId superFun : *superNorm.functions.parts) { Unifier innerState = makeChildUnifier(); const FunctionTypeVar* superFtv = get(superFun); if (!superFtv) - return reportError(location, TypeMismatch{superTy, subTy, reason, error}); + return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); TypePackId tgt = innerState.tryApplyOverloadedFunction(subTy, subNorm.functions, superFtv->argTypes); innerState.tryUnify_(tgt, superFtv->retTypes); if (innerState.errors.empty()) @@ -940,7 +1010,7 @@ void Unifier::tryUnifyNormalizedTypes( else if (auto e = hasUnificationTooComplex(innerState.errors)) return reportError(*e); else - return reportError(location, TypeMismatch{superTy, subTy, reason, error}); + return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); } } @@ -1306,7 +1376,7 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal // If both are at the end, we're done if (!superIter.good() && !subIter.good()) { - if (subTpv->tail && superTpv->tail) + if (!FFlag::LuauTxnLogTypePackIterator && subTpv->tail && superTpv->tail) { tryUnify_(*subTpv->tail, *superTpv->tail); break; @@ -1314,10 +1384,27 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal const bool lFreeTail = superTpv->tail && log.getMutable(log.follow(*superTpv->tail)) != nullptr; const bool rFreeTail = subTpv->tail && log.getMutable(log.follow(*subTpv->tail)) != nullptr; - if (lFreeTail) + if (FFlag::LuauTxnLogTypePackIterator && lFreeTail && rFreeTail) + { + tryUnify_(*subTpv->tail, *superTpv->tail); + } + else if (lFreeTail) + { tryUnify_(emptyTp, *superTpv->tail); + } else if (rFreeTail) + { tryUnify_(emptyTp, *subTpv->tail); + } + else if (FFlag::LuauTxnLogTypePackIterator && subTpv->tail && superTpv->tail) + { + if (log.getMutable(superIter.packId)) + tryUnifyVariadics(subIter.packId, superIter.packId, false, int(subIter.index)); + else if (log.getMutable(subIter.packId)) + tryUnifyVariadics(superIter.packId, subIter.packId, true, int(superIter.index)); + else + tryUnify_(*subTpv->tail, *superTpv->tail); + } break; } @@ -1407,7 +1494,7 @@ void Unifier::tryUnifyPrimitives(TypeId subTy, TypeId superTy) ice("passed non primitive types to unifyPrimitives"); if (superPrim->type != subPrim->type) - reportError(location, TypeMismatch{superTy, subTy}); + reportError(location, TypeMismatch{superTy, subTy, mismatchContext()}); } void Unifier::tryUnifySingletons(TypeId subTy, TypeId superTy) @@ -1428,7 +1515,7 @@ void Unifier::tryUnifySingletons(TypeId subTy, TypeId superTy) if (superPrim && superPrim->type == PrimitiveTypeVar::String && get(subSingleton) && variance == Covariant) return; - reportError(location, TypeMismatch{superTy, subTy}); + reportError(location, TypeMismatch{superTy, subTy, mismatchContext()}); } void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCall) @@ -1471,14 +1558,14 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal { numGenerics = std::min(superFunction->generics.size(), subFunction->generics.size()); - reportError(location, TypeMismatch{superTy, subTy, "different number of generic type parameters"}); + reportError(location, TypeMismatch{superTy, subTy, "different number of generic type parameters", mismatchContext()}); } if (numGenericPacks != subFunction->genericPacks.size()) { numGenericPacks = std::min(superFunction->genericPacks.size(), subFunction->genericPacks.size()); - reportError(location, TypeMismatch{superTy, subTy, "different number of generic type pack parameters"}); + reportError(location, TypeMismatch{superTy, subTy, "different number of generic type pack parameters", mismatchContext()}); } for (size_t i = 0; i < numGenerics; i++) @@ -1506,9 +1593,9 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal reportError(*e); else if (!innerState.errors.empty() && innerState.firstPackErrorPos) reportError(location, TypeMismatch{superTy, subTy, format("Argument #%d type is not compatible.", *innerState.firstPackErrorPos), - innerState.errors.front()}); + innerState.errors.front(), mismatchContext()}); else if (!innerState.errors.empty()) - reportError(location, TypeMismatch{superTy, subTy, "", innerState.errors.front()}); + reportError(location, TypeMismatch{superTy, subTy, "", innerState.errors.front(), mismatchContext()}); innerState.ctx = CountMismatch::FunctionResult; innerState.tryUnify_(subFunction->retTypes, superFunction->retTypes); @@ -1518,12 +1605,12 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal if (auto e = hasUnificationTooComplex(innerState.errors)) reportError(*e); else if (!innerState.errors.empty() && size(superFunction->retTypes) == 1 && finite(superFunction->retTypes)) - reportError(location, TypeMismatch{superTy, subTy, "Return type is not compatible.", innerState.errors.front()}); + reportError(location, TypeMismatch{superTy, subTy, "Return type is not compatible.", innerState.errors.front(), mismatchContext()}); else if (!innerState.errors.empty() && innerState.firstPackErrorPos) reportError(location, TypeMismatch{superTy, subTy, format("Return #%d type is not compatible.", *innerState.firstPackErrorPos), - innerState.errors.front()}); + innerState.errors.front(), mismatchContext()}); else if (!innerState.errors.empty()) - reportError(location, TypeMismatch{superTy, subTy, "", innerState.errors.front()}); + reportError(location, TypeMismatch{superTy, subTy, "", innerState.errors.front(), mismatchContext()}); } log.concat(std::move(innerState.log)); @@ -1700,10 +1787,10 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) // Recursive unification can change the txn log, and invalidate the old // table. If we detect that this has happened, we start over, with the updated // txn log. - TypeId superTyNew = FFlag::LuauScalarShapeUnifyToMtOwner ? log.follow(superTy) : superTy; - TypeId subTyNew = FFlag::LuauScalarShapeUnifyToMtOwner ? log.follow(subTy) : subTy; + TypeId superTyNew = FFlag::LuauScalarShapeUnifyToMtOwner2 ? log.follow(superTy) : superTy; + TypeId subTyNew = FFlag::LuauScalarShapeUnifyToMtOwner2 ? log.follow(subTy) : subTy; - if (FFlag::LuauScalarShapeUnifyToMtOwner) + if (FFlag::LuauScalarShapeUnifyToMtOwner2) { // If one of the types stopped being a table altogether, we need to restart from the top if ((superTy != superTyNew || subTy != subTyNew) && errors.empty()) @@ -1771,11 +1858,21 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) else extraProperties.push_back(name); + TypeId superTyNew = FFlag::LuauScalarShapeUnifyToMtOwner2 ? log.follow(superTy) : superTy; + TypeId subTyNew = FFlag::LuauScalarShapeUnifyToMtOwner2 ? log.follow(subTy) : subTy; + + if (FFlag::LuauScalarShapeUnifyToMtOwner2) + { + // If one of the types stopped being a table altogether, we need to restart from the top + if ((superTy != superTyNew || subTy != subTyNew) && errors.empty()) + return tryUnify(subTy, superTy, false, isIntersection); + } + // Recursive unification can change the txn log, and invalidate the old // table. If we detect that this has happened, we start over, with the updated // txn log. - TableTypeVar* newSuperTable = log.getMutable(superTy); - TableTypeVar* newSubTable = log.getMutable(subTy); + TableTypeVar* newSuperTable = log.getMutable(superTyNew); + TableTypeVar* newSubTable = log.getMutable(subTyNew); if (superTable != newSuperTable || (subTable != newSubTable && subTable != instantiatedSubTable)) { if (errors.empty()) @@ -1829,8 +1926,19 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) } // Changing the indexer can invalidate the table pointers. - superTable = log.getMutable(superTy); - subTable = log.getMutable(subTy); + if (FFlag::LuauScalarShapeUnifyToMtOwner2) + { + superTable = log.getMutable(log.follow(superTy)); + subTable = log.getMutable(log.follow(subTy)); + + if (!superTable || !subTable) + return; + } + else + { + superTable = log.getMutable(superTy); + subTable = log.getMutable(subTy); + } if (!missingProperties.empty()) { @@ -1872,20 +1980,23 @@ void Unifier::tryUnifyScalarShape(TypeId subTy, TypeId superTy, bool reversed) TypeId osubTy = subTy; TypeId osuperTy = superTy; + if (FFlag::LuauUninhabitedSubAnything && !normalizer->isInhabited(subTy)) + return; + if (reversed) std::swap(subTy, superTy); TableTypeVar* superTable = log.getMutable(superTy); if (!superTable || superTable->state != TableState::Free) - return reportError(location, TypeMismatch{osuperTy, osubTy}); + return reportError(location, TypeMismatch{osuperTy, osubTy, mismatchContext()}); auto fail = [&](std::optional e) { std::string reason = "The former's metatable does not satisfy the requirements."; if (e) - reportError(location, TypeMismatch{osuperTy, osubTy, reason, *e}); + reportError(location, TypeMismatch{osuperTy, osubTy, reason, *e, mismatchContext()}); else - reportError(location, TypeMismatch{osuperTy, osubTy, reason}); + reportError(location, TypeMismatch{osuperTy, osubTy, reason, mismatchContext()}); }; // Given t1 where t1 = { lower: (t1) -> (a, b...) } @@ -1902,7 +2013,7 @@ void Unifier::tryUnifyScalarShape(TypeId subTy, TypeId superTy, bool reversed) Unifier child = makeChildUnifier(); child.tryUnify_(ty, superTy); - if (FFlag::LuauScalarShapeUnifyToMtOwner) + if (FFlag::LuauScalarShapeUnifyToMtOwner2) { // To perform subtype <: free table unification, we have tried to unify (subtype's metatable) <: free table // There is a chance that it was unified with the origial subtype, but then, (subtype's metatable) <: subtype could've failed @@ -1923,7 +2034,7 @@ void Unifier::tryUnifyScalarShape(TypeId subTy, TypeId superTy, bool reversed) log.concat(std::move(child.log)); - if (FFlag::LuauScalarShapeUnifyToMtOwner) + if (FFlag::LuauScalarShapeUnifyToMtOwner2) { // To perform subtype <: free table unification, we have tried to unify (subtype's metatable) <: free table // We return success because subtype <: free table which means that correct unification is to replace free table with the subtype @@ -1939,7 +2050,7 @@ void Unifier::tryUnifyScalarShape(TypeId subTy, TypeId superTy, bool reversed) } } - reportError(location, TypeMismatch{osuperTy, osubTy}); + reportError(location, TypeMismatch{osuperTy, osubTy, mismatchContext()}); return; } @@ -1969,7 +2080,7 @@ void Unifier::tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed) if (!superMetatable) ice("tryUnifyMetatable invoked with non-metatable TypeVar"); - TypeError mismatchError = TypeError{location, TypeMismatch{reversed ? subTy : superTy, reversed ? superTy : subTy}}; + TypeError mismatchError = TypeError{location, TypeMismatch{reversed ? subTy : superTy, reversed ? superTy : subTy, mismatchContext()}}; if (const MetatableTypeVar* subMetatable = log.getMutable(subTy)) { @@ -1980,7 +2091,8 @@ void Unifier::tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed) if (auto e = hasUnificationTooComplex(innerState.errors)) reportError(*e); else if (!innerState.errors.empty()) - reportError(location, TypeMismatch{reversed ? subTy : superTy, reversed ? superTy : subTy, "", innerState.errors.front()}); + reportError( + location, TypeMismatch{reversed ? subTy : superTy, reversed ? superTy : subTy, "", innerState.errors.front(), mismatchContext()}); log.concat(std::move(innerState.log)); } @@ -2017,8 +2129,8 @@ void Unifier::tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed) if (auto e = hasUnificationTooComplex(innerState.errors)) reportError(*e); else if (!innerState.errors.empty()) - reportError( - TypeError{location, TypeMismatch{reversed ? subTy : superTy, reversed ? superTy : subTy, "", innerState.errors.front()}}); + reportError(TypeError{location, + TypeMismatch{reversed ? subTy : superTy, reversed ? superTy : subTy, "", innerState.errors.front(), mismatchContext()}}); else if (!missingProperty) { log.concat(std::move(innerState.log)); @@ -2057,9 +2169,9 @@ void Unifier::tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed) auto fail = [&]() { if (!reversed) - reportError(location, TypeMismatch{superTy, subTy}); + reportError(location, TypeMismatch{superTy, subTy, mismatchContext()}); else - reportError(location, TypeMismatch{subTy, superTy}); + reportError(location, TypeMismatch{subTy, superTy, mismatchContext()}); }; const ClassTypeVar* superClass = get(superTy); @@ -2155,7 +2267,7 @@ void Unifier::tryUnifyTypeWithNegation(TypeId subTy, TypeId superTy) Unifier state = makeChildUnifier(); state.tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, ""); if (state.errors.empty()) - reportError(location, TypeMismatch{superTy, subTy}); + reportError(location, TypeMismatch{superTy, subTy, mismatchContext()}); } void Unifier::tryUnifyNegationWithType(TypeId subTy, TypeId superTy) @@ -2165,7 +2277,7 @@ void Unifier::tryUnifyNegationWithType(TypeId subTy, TypeId superTy) ice("tryUnifyNegationWithType subTy must be a negation type"); // TODO: ~T & queue, DenseHashSet& seenTypePacks, Unifier& state, TypePackId a, TypePackId anyTypePack) @@ -2200,9 +2312,11 @@ void Unifier::tryUnifyVariadics(TypePackId subTp, TypePackId superTp, bool rever if (!superVariadic) ice("passed non-variadic pack to tryUnifyVariadics"); - if (const VariadicTypePack* subVariadic = get(subTp)) + if (const VariadicTypePack* subVariadic = FFlag::LuauTxnLogTypePackIterator ? log.get(subTp) : get(subTp)) + { tryUnify_(reversed ? superVariadic->ty : subVariadic->ty, reversed ? subVariadic->ty : superVariadic->ty); - else if (get(subTp)) + } + else if (FFlag::LuauTxnLogTypePackIterator ? log.get(subTp) : get(subTp)) { TypePackIterator subIter = begin(subTp, &log); TypePackIterator subEnd = end(subTp); @@ -2350,6 +2464,24 @@ std::optional Unifier::findTablePropertyRespectingMeta(TypeId lhsType, N return Luau::findTablePropertyRespectingMeta(singletonTypes, errors, lhsType, name, location); } +TxnLog Unifier::combineLogsIntoIntersection(std::vector logs) +{ + LUAU_ASSERT(FFlag::DebugLuauDeferredConstraintResolution); + TxnLog result; + for (TxnLog& log : logs) + result.concatAsIntersections(std::move(log), NotNull{types}); + return result; +} + +TxnLog Unifier::combineLogsIntoUnion(std::vector logs) +{ + LUAU_ASSERT(FFlag::DebugLuauDeferredConstraintResolution); + TxnLog result; + for (TxnLog& log : logs) + result.concatAsUnion(std::move(log), NotNull{types}); + return result; +} + bool Unifier::occursCheck(TypeId needle, TypeId haystack) { sharedState.tempSeenTy.clear(); @@ -2491,7 +2623,7 @@ void Unifier::checkChildUnifierTypeMismatch(const ErrorVec& innerErrors, TypeId if (auto e = hasUnificationTooComplex(innerErrors)) reportError(*e); else if (!innerErrors.empty()) - reportError(location, TypeMismatch{wantedType, givenType}); + reportError(location, TypeMismatch{wantedType, givenType, mismatchContext()}); } void Unifier::checkChildUnifierTypeMismatch(const ErrorVec& innerErrors, const std::string& prop, TypeId wantedType, TypeId givenType) @@ -2499,8 +2631,8 @@ void Unifier::checkChildUnifierTypeMismatch(const ErrorVec& innerErrors, const s if (auto e = hasUnificationTooComplex(innerErrors)) reportError(*e); else if (!innerErrors.empty()) - reportError( - TypeError{location, TypeMismatch{wantedType, givenType, format("Property '%s' is not compatible.", prop.c_str()), innerErrors.front()}}); + reportError(TypeError{location, + TypeMismatch{wantedType, givenType, format("Property '%s' is not compatible.", prop.c_str()), innerErrors.front(), mismatchContext()}}); } void Unifier::ice(const std::string& message, const Location& location) diff --git a/Ast/include/Luau/Location.h b/Ast/include/Luau/Location.h index d3c0a4623..e39bbf8c5 100644 --- a/Ast/include/Luau/Location.h +++ b/Ast/include/Luau/Location.h @@ -50,6 +50,20 @@ struct Position { return *this == rhs || *this > rhs; } + + void shift(const Position& start, const Position& oldEnd, const Position& newEnd) + { + if (*this >= start) + { + if (this->line > oldEnd.line) + this->line += (newEnd.line - oldEnd.line); + else + { + this->line = newEnd.line; + this->column += (newEnd.column - oldEnd.column); + } + } + } }; struct Location @@ -93,6 +107,10 @@ struct Location { return begin <= l.begin && end >= l.end; } + bool overlaps(const Location& l) const + { + return (begin <= l.begin && end >= l.begin) || (begin <= l.end && end >= l.end) || (begin >= l.begin && end <= l.end); + } bool contains(const Position& p) const { return begin <= p && p < end; @@ -101,6 +119,18 @@ struct Location { return begin <= p && p <= end; } + void extend(const Location& other) + { + if (other.begin < begin) + begin = other.begin; + if (other.end > end) + end = other.end; + } + void shift(const Position& start, const Position& oldEnd, const Position& newEnd) + { + begin.shift(start, oldEnd, newEnd); + end.shift(start, oldEnd, newEnd); + } }; std::string toString(const Position& position); diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 85b0d31ab..5cd5f7437 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -24,9 +24,6 @@ LUAU_DYNAMIC_FASTFLAGVARIABLE(LuaReportParseIntegerIssues, false) LUAU_FASTFLAGVARIABLE(LuauInterpolatedStringBaseSupport, false) -LUAU_FASTFLAGVARIABLE(LuauCommaParenWarnings, false) -LUAU_FASTFLAGVARIABLE(LuauTableConstructorRecovery, false) - LUAU_FASTFLAGVARIABLE(LuauParserErrorsOnMissingDefaultTypePackArgument, false) bool lua_telemetry_parsed_out_of_range_bin_integer = false; @@ -1084,7 +1081,7 @@ void Parser::parseExprList(TempVector& result) { nextLexeme(); - if (FFlag::LuauCommaParenWarnings && lexer.current().type == ')') + if (lexer.current().type == ')') { report(lexer.current().location, "Expected expression after ',' but got ')' instead"); break; @@ -1179,7 +1176,7 @@ AstTypePack* Parser::parseTypeList(TempVector& result, TempVector, AstArray> Parser::parseG { nextLexeme(); - if (FFlag::LuauCommaParenWarnings && lexer.current().type == '>') + if (lexer.current().type == '>') { report(lexer.current().location, "Expected type after ',' but got '>' instead"); break; diff --git a/Compiler/src/BytecodeBuilder.cpp b/Compiler/src/BytecodeBuilder.cpp index 83764244c..7d230738b 100644 --- a/Compiler/src/BytecodeBuilder.cpp +++ b/Compiler/src/BytecodeBuilder.cpp @@ -1895,10 +1895,7 @@ void BytecodeBuilder::dumpInstruction(const uint32_t* code, std::string& result, case LOP_CAPTURE: formatAppend(result, "CAPTURE %s %c%d\n", - LUAU_INSN_A(insn) == LCT_UPVAL ? "UPVAL" - : LUAU_INSN_A(insn) == LCT_REF ? "REF" - : LUAU_INSN_A(insn) == LCT_VAL ? "VAL" - : "", + LUAU_INSN_A(insn) == LCT_UPVAL ? "UPVAL" : LUAU_INSN_A(insn) == LCT_REF ? "REF" : LUAU_INSN_A(insn) == LCT_VAL ? "VAL" : "", LUAU_INSN_A(insn) == LCT_UPVAL ? 'U' : 'R', LUAU_INSN_B(insn)); break; diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index 7ccd1164d..5d6723669 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -26,6 +26,7 @@ LUAU_FASTINTVARIABLE(LuauCompileInlineThresholdMaxBoost, 300) LUAU_FASTINTVARIABLE(LuauCompileInlineDepth, 5) LUAU_FASTFLAG(LuauInterpolatedStringBaseSupport) +LUAU_FASTFLAGVARIABLE(LuauMultiAssignmentConflictFix, false) namespace Luau { @@ -2977,16 +2978,46 @@ struct Compiler Visitor visitor(this); - // mark any registers that are used *after* assignment as conflicting - for (size_t i = 0; i < vars.size(); ++i) + if (FFlag::LuauMultiAssignmentConflictFix) { - const LValue& li = vars[i].lvalue; + // mark any registers that are used *after* assignment as conflicting - if (i < values.size) - values.data[i]->visit(&visitor); + // first we go through assignments to locals, since they are performed before assignments to other l-values + for (size_t i = 0; i < vars.size(); ++i) + { + const LValue& li = vars[i].lvalue; + + if (li.kind == LValue::Kind_Local) + { + if (i < values.size) + values.data[i]->visit(&visitor); + + visitor.assigned[li.reg] = true; + } + } + + // and now we handle all other l-values + for (size_t i = 0; i < vars.size(); ++i) + { + const LValue& li = vars[i].lvalue; - if (li.kind == LValue::Kind_Local) - visitor.assigned[li.reg] = true; + if (li.kind != LValue::Kind_Local && i < values.size) + values.data[i]->visit(&visitor); + } + } + else + { + // mark any registers that are used *after* assignment as conflicting + for (size_t i = 0; i < vars.size(); ++i) + { + const LValue& li = vars[i].lvalue; + + if (i < values.size) + values.data[i]->visit(&visitor); + + if (li.kind == LValue::Kind_Local) + visitor.assigned[li.reg] = true; + } } // mark any registers used in trailing expressions as conflicting as well diff --git a/Sources.cmake b/Sources.cmake index 33ecde938..e243ea74d 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -112,7 +112,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/include/Luau/Constraint.h Analysis/include/Luau/ConstraintGraphBuilder.h Analysis/include/Luau/ConstraintSolver.h - Analysis/include/Luau/DataFlowGraphBuilder.h + Analysis/include/Luau/DataFlowGraph.h Analysis/include/Luau/DcrLogger.h Analysis/include/Luau/Def.h Analysis/include/Luau/Documentation.h @@ -166,7 +166,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/src/Constraint.cpp Analysis/src/ConstraintGraphBuilder.cpp Analysis/src/ConstraintSolver.cpp - Analysis/src/DataFlowGraphBuilder.cpp + Analysis/src/DataFlowGraph.cpp Analysis/src/DcrLogger.cpp Analysis/src/Def.cpp Analysis/src/EmbeddedBuiltinDefinitions.cpp @@ -298,15 +298,16 @@ endif() if(TARGET Luau.UnitTest) # Luau.UnitTest Sources target_sources(Luau.UnitTest PRIVATE + tests/AstQueryDsl.cpp tests/AstQueryDsl.h + tests/ClassFixture.cpp + tests/ClassFixture.h + tests/ConstraintGraphBuilderFixture.cpp tests/ConstraintGraphBuilderFixture.h + tests/Fixture.cpp tests/Fixture.h tests/IostreamOptional.h tests/ScopedFlags.h - tests/AstQueryDsl.cpp - tests/ConstraintGraphBuilderFixture.cpp - tests/Fixture.cpp - tests/AssemblyBuilderA64.test.cpp tests/AssemblyBuilderX64.test.cpp tests/AstJsonEncoder.test.cpp tests/AstQuery.test.cpp @@ -318,7 +319,7 @@ if(TARGET Luau.UnitTest) tests/Config.test.cpp tests/ConstraintSolver.test.cpp tests/CostModel.test.cpp - tests/DataFlowGraphBuilder.test.cpp + tests/DataFlowGraph.test.cpp tests/Error.test.cpp tests/Frontend.test.cpp tests/JsonEmitter.test.cpp diff --git a/VM/src/loslib.cpp b/VM/src/loslib.cpp index 91ccec0c8..62a5668b2 100644 --- a/VM/src/loslib.cpp +++ b/VM/src/loslib.cpp @@ -22,6 +22,21 @@ static time_t timegm(struct tm* timep) { return _mkgmtime(timep); } +#elif defined(__FreeBSD__) +static tm* gmtime_r(const time_t* timep, tm* result) +{ + return gmtime_s(timep, result) == 0 ? result : NULL; +} + +static tm* localtime_r(const time_t* timep, tm* result) +{ + return localtime_s(timep, result) == 0 ? result : NULL; +} + +static time_t timegm(struct tm* timep) +{ + return mktime(timep); +} #endif static int os_clock(lua_State* L) diff --git a/tests/AstQuery.test.cpp b/tests/AstQuery.test.cpp index 48bb40d49..a642334af 100644 --- a/tests/AstQuery.test.cpp +++ b/tests/AstQuery.test.cpp @@ -91,8 +91,6 @@ TEST_CASE_FIXTURE(DocumentationSymbolFixture, "class_method") TEST_CASE_FIXTURE(DocumentationSymbolFixture, "overloaded_class_method") { - ScopedFastFlag luauCheckOverloadedDocSymbol{"LuauCheckOverloadedDocSymbol", true}; - loadDefinition(R"( declare class Foo function bar(self, x: string): number @@ -127,8 +125,6 @@ TEST_CASE_FIXTURE(DocumentationSymbolFixture, "table_function_prop") TEST_CASE_FIXTURE(DocumentationSymbolFixture, "table_overloaded_function_prop") { - ScopedFastFlag luauCheckOverloadedDocSymbol{"LuauCheckOverloadedDocSymbol", true}; - loadDefinition(R"( declare Foo: { new: ((number) -> string) & ((string) -> number) diff --git a/tests/ClassFixture.cpp b/tests/ClassFixture.cpp new file mode 100644 index 000000000..18939e24d --- /dev/null +++ b/tests/ClassFixture.cpp @@ -0,0 +1,113 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "ClassFixture.h" + +#include "Luau/BuiltinDefinitions.h" + +using std::nullopt; + +namespace Luau +{ + +ClassFixture::ClassFixture() +{ + TypeArena& arena = typeChecker.globalTypes; + TypeId numberType = typeChecker.numberType; + + unfreeze(arena); + + TypeId baseClassInstanceType = arena.addType(ClassTypeVar{"BaseClass", {}, nullopt, nullopt, {}, {}, "Test"}); + getMutable(baseClassInstanceType)->props = { + {"BaseMethod", {makeFunction(arena, baseClassInstanceType, {numberType}, {})}}, + {"BaseField", {numberType}}, + }; + + TypeId baseClassType = arena.addType(ClassTypeVar{"BaseClass", {}, nullopt, nullopt, {}, {}, "Test"}); + getMutable(baseClassType)->props = { + {"StaticMethod", {makeFunction(arena, nullopt, {}, {numberType})}}, + {"Clone", {makeFunction(arena, nullopt, {baseClassInstanceType}, {baseClassInstanceType})}}, + {"New", {makeFunction(arena, nullopt, {}, {baseClassInstanceType})}}, + }; + typeChecker.globalScope->exportedTypeBindings["BaseClass"] = TypeFun{{}, baseClassInstanceType}; + addGlobalBinding(frontend, "BaseClass", baseClassType, "@test"); + + TypeId childClassInstanceType = arena.addType(ClassTypeVar{"ChildClass", {}, baseClassInstanceType, nullopt, {}, {}, "Test"}); + + getMutable(childClassInstanceType)->props = { + {"Method", {makeFunction(arena, childClassInstanceType, {}, {typeChecker.stringType})}}, + }; + + TypeId childClassType = arena.addType(ClassTypeVar{"ChildClass", {}, baseClassType, nullopt, {}, {}, "Test"}); + getMutable(childClassType)->props = { + {"New", {makeFunction(arena, nullopt, {}, {childClassInstanceType})}}, + }; + typeChecker.globalScope->exportedTypeBindings["ChildClass"] = TypeFun{{}, childClassInstanceType}; + addGlobalBinding(frontend, "ChildClass", childClassType, "@test"); + + TypeId grandChildInstanceType = arena.addType(ClassTypeVar{"GrandChild", {}, childClassInstanceType, nullopt, {}, {}, "Test"}); + + getMutable(grandChildInstanceType)->props = { + {"Method", {makeFunction(arena, grandChildInstanceType, {}, {typeChecker.stringType})}}, + }; + + TypeId grandChildType = arena.addType(ClassTypeVar{"GrandChild", {}, baseClassType, nullopt, {}, {}, "Test"}); + getMutable(grandChildType)->props = { + {"New", {makeFunction(arena, nullopt, {}, {grandChildInstanceType})}}, + }; + typeChecker.globalScope->exportedTypeBindings["GrandChild"] = TypeFun{{}, grandChildInstanceType}; + addGlobalBinding(frontend, "GrandChild", childClassType, "@test"); + + TypeId anotherChildInstanceType = arena.addType(ClassTypeVar{"AnotherChild", {}, baseClassInstanceType, nullopt, {}, {}, "Test"}); + + getMutable(anotherChildInstanceType)->props = { + {"Method", {makeFunction(arena, anotherChildInstanceType, {}, {typeChecker.stringType})}}, + }; + + TypeId anotherChildType = arena.addType(ClassTypeVar{"AnotherChild", {}, baseClassType, nullopt, {}, {}, "Test"}); + getMutable(anotherChildType)->props = { + {"New", {makeFunction(arena, nullopt, {}, {anotherChildInstanceType})}}, + }; + typeChecker.globalScope->exportedTypeBindings["AnotherChild"] = TypeFun{{}, anotherChildInstanceType}; + addGlobalBinding(frontend, "AnotherChild", childClassType, "@test"); + + TypeId unrelatedClassInstanceType = arena.addType(ClassTypeVar{"UnrelatedClass", {}, nullopt, nullopt, {}, {}, "Test"}); + + TypeId unrelatedClassType = arena.addType(ClassTypeVar{"UnrelatedClass", {}, nullopt, nullopt, {}, {}, "Test"}); + getMutable(unrelatedClassType)->props = { + {"New", {makeFunction(arena, nullopt, {}, {unrelatedClassInstanceType})}}, + }; + typeChecker.globalScope->exportedTypeBindings["UnrelatedClass"] = TypeFun{{}, unrelatedClassInstanceType}; + addGlobalBinding(frontend, "UnrelatedClass", unrelatedClassType, "@test"); + + TypeId vector2MetaType = arena.addType(TableTypeVar{}); + + TypeId vector2InstanceType = arena.addType(ClassTypeVar{"Vector2", {}, nullopt, vector2MetaType, {}, {}, "Test"}); + getMutable(vector2InstanceType)->props = { + {"X", {numberType}}, + {"Y", {numberType}}, + }; + + TypeId vector2Type = arena.addType(ClassTypeVar{"Vector2", {}, nullopt, nullopt, {}, {}, "Test"}); + getMutable(vector2Type)->props = { + {"New", {makeFunction(arena, nullopt, {numberType, numberType}, {vector2InstanceType})}}, + }; + getMutable(vector2MetaType)->props = { + {"__add", {makeFunction(arena, nullopt, {vector2InstanceType, vector2InstanceType}, {vector2InstanceType})}}, + }; + typeChecker.globalScope->exportedTypeBindings["Vector2"] = TypeFun{{}, vector2InstanceType}; + addGlobalBinding(frontend, "Vector2", vector2Type, "@test"); + + TypeId callableClassMetaType = arena.addType(TableTypeVar{}); + TypeId callableClassType = arena.addType(ClassTypeVar{"CallableClass", {}, nullopt, callableClassMetaType, {}, {}, "Test"}); + getMutable(callableClassMetaType)->props = { + {"__call", {makeFunction(arena, nullopt, {callableClassType, typeChecker.stringType}, {typeChecker.numberType})}}, + }; + typeChecker.globalScope->exportedTypeBindings["CallableClass"] = TypeFun{{}, callableClassType}; + + for (const auto& [name, tf] : typeChecker.globalScope->exportedTypeBindings) + persist(tf.type); + + freeze(arena); +} + +} // namespace Luau diff --git a/tests/ClassFixture.h b/tests/ClassFixture.h new file mode 100644 index 000000000..66aec7646 --- /dev/null +++ b/tests/ClassFixture.h @@ -0,0 +1,13 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Fixture.h" + +namespace Luau +{ + +struct ClassFixture : BuiltinsFixture +{ + ClassFixture(); +}; + +} // namespace Luau diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index 3da40df80..d2cf0ae8e 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -6751,6 +6751,21 @@ ADD R4 R0 R1 MOVE R0 R2 MOVE R1 R3 RETURN R0 0 +)"); + + ScopedFastFlag luauMultiAssignmentConflictFix{"LuauMultiAssignmentConflictFix", true}; + + // because we perform assignments to complex l-values after assignments to locals, we make sure register conflicts are tracked accordingly + CHECK_EQ("\n" + compileFunction0(R"( + local a, b = ... + a[1], b = b, b + 1 + )"), + R"( +GETVARARGS R0 2 +ADDK R2 R1 K0 +SETTABLEN R1 R0 1 +MOVE R1 R2 +RETURN R0 0 )"); } diff --git a/tests/DataFlowGraphBuilder.test.cpp b/tests/DataFlowGraph.test.cpp similarity index 98% rename from tests/DataFlowGraphBuilder.test.cpp rename to tests/DataFlowGraph.test.cpp index 9aa7cde6b..d8230700a 100644 --- a/tests/DataFlowGraphBuilder.test.cpp +++ b/tests/DataFlowGraph.test.cpp @@ -1,5 +1,5 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "Luau/DataFlowGraphBuilder.h" +#include "Luau/DataFlowGraph.h" #include "Luau/Error.h" #include "Luau/Parser.h" diff --git a/tests/Module.test.cpp b/tests/Module.test.cpp index b289b59e5..33d9c75a7 100644 --- a/tests/Module.test.cpp +++ b/tests/Module.test.cpp @@ -11,7 +11,6 @@ using namespace Luau; LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); -LUAU_FASTFLAG(LuauIceExceptionInheritanceChange); TEST_SUITE_BEGIN("ModuleTests"); @@ -279,14 +278,7 @@ TEST_CASE_FIXTURE(Fixture, "clone_recursion_limit") TypeArena dest; CloneState cloneState; - if (FFlag::LuauIceExceptionInheritanceChange) - { - CHECK_THROWS_AS(clone(table, dest, cloneState), RecursionLimitException); - } - else - { - CHECK_THROWS_AS(clone(table, dest, cloneState), RecursionLimitException_DEPRECATED); - } + CHECK_THROWS_AS(clone(table, dest, cloneState), RecursionLimitException); } TEST_CASE_FIXTURE(Fixture, "any_persistance_does_not_leak") diff --git a/tests/NotNull.test.cpp b/tests/NotNull.test.cpp index dfa06aa1b..b827b81bb 100644 --- a/tests/NotNull.test.cpp +++ b/tests/NotNull.test.cpp @@ -9,6 +9,8 @@ using Luau::NotNull; +static_assert(!std::is_convertible, bool>::value, "NotNull ought not to be convertible into bool"); + namespace { diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index c0989a2e7..18e91e1ba 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -2723,8 +2723,6 @@ TEST_CASE_FIXTURE(Fixture, "error_message_for_using_function_as_type_annotation" TEST_CASE_FIXTURE(Fixture, "get_a_nice_error_when_there_is_an_extra_comma_at_the_end_of_a_function_argument_list") { - ScopedFastFlag sff{"LuauCommaParenWarnings", true}; - ParseResult result = tryParse(R"( foo(a, b, c,) )"); @@ -2737,8 +2735,6 @@ TEST_CASE_FIXTURE(Fixture, "get_a_nice_error_when_there_is_an_extra_comma_at_the TEST_CASE_FIXTURE(Fixture, "get_a_nice_error_when_there_is_an_extra_comma_at_the_end_of_a_function_parameter_list") { - ScopedFastFlag sff{"LuauCommaParenWarnings", true}; - ParseResult result = tryParse(R"( export type VisitFn = ( any, @@ -2754,8 +2750,6 @@ TEST_CASE_FIXTURE(Fixture, "get_a_nice_error_when_there_is_an_extra_comma_at_the TEST_CASE_FIXTURE(Fixture, "get_a_nice_error_when_there_is_an_extra_comma_at_the_end_of_a_generic_parameter_list") { - ScopedFastFlag sff{"LuauCommaParenWarnings", true}; - ParseResult result = tryParse(R"( export type VisitFn = (a: A, b: B) -> () )"); @@ -2778,8 +2772,6 @@ TEST_CASE_FIXTURE(Fixture, "get_a_nice_error_when_there_is_an_extra_comma_at_the TEST_CASE_FIXTURE(Fixture, "get_a_nice_error_when_there_is_no_comma_between_table_members") { - ScopedFastFlag luauTableConstructorRecovery{"LuauTableConstructorRecovery", true}; - ParseResult result = tryParse(R"( local t = { first = 1 diff --git a/tests/ToString.test.cpp b/tests/ToString.test.cpp index 8bb1fbaf2..29954dc46 100644 --- a/tests/ToString.test.cpp +++ b/tests/ToString.test.cpp @@ -10,7 +10,6 @@ using namespace Luau; LUAU_FASTFLAG(LuauRecursiveTypeParameterRestriction); -LUAU_FASTFLAG(LuauFunctionReturnStringificationFixup); TEST_SUITE_BEGIN("ToString"); @@ -83,7 +82,7 @@ TEST_CASE_FIXTURE(Fixture, "table_respects_use_line_break") ToStringOptions opts; opts.useLineBreaks = true; - opts.indent = true; + opts.DEPRECATED_indent = true; //clang-format off CHECK_EQ("{|\n" @@ -568,10 +567,7 @@ TEST_CASE_FIXTURE(Fixture, "no_parentheses_around_return_type_if_pack_has_an_emp TypeId functionType = arena.addType(FunctionTypeVar{argList, emptyTail}); - if (FFlag::LuauFunctionReturnStringificationFixup) - CHECK("(string) -> string" == toString(functionType)); - else - CHECK("(string) -> (string)" == toString(functionType)); + CHECK("(string) -> string" == toString(functionType)); } TEST_CASE_FIXTURE(Fixture, "no_parentheses_around_cyclic_function_type_in_union") diff --git a/tests/TypeInfer.aliases.test.cpp b/tests/TypeInfer.aliases.test.cpp index ef40e2783..53c54f4f3 100644 --- a/tests/TypeInfer.aliases.test.cpp +++ b/tests/TypeInfer.aliases.test.cpp @@ -9,6 +9,7 @@ using namespace Luau; LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) LUAU_FASTFLAG(LuauNoMoreGlobalSingletonTypes) +LUAU_FASTFLAG(LuauTypeMismatchInvarianceInError) TEST_SUITE_BEGIN("TypeAliases"); @@ -199,9 +200,15 @@ TEST_CASE_FIXTURE(Fixture, "generic_aliases") LUAU_REQUIRE_ERROR_COUNT(1, result); - const char* expectedError = "Type '{ v: string }' could not be converted into 'T'\n" - "caused by:\n" - " Property 'v' is not compatible. Type 'string' could not be converted into 'number'"; + const char* expectedError; + if (FFlag::LuauTypeMismatchInvarianceInError) + expectedError = "Type '{ v: string }' could not be converted into 'T'\n" + "caused by:\n" + " Property 'v' is not compatible. Type 'string' could not be converted into 'number' in an invariant context"; + else + expectedError = "Type '{ v: string }' could not be converted into 'T'\n" + "caused by:\n" + " Property 'v' is not compatible. Type 'string' could not be converted into 'number'"; CHECK(result.errors[0].location == Location{{4, 31}, {4, 44}}); CHECK(toString(result.errors[0]) == expectedError); @@ -220,11 +227,19 @@ TEST_CASE_FIXTURE(Fixture, "dependent_generic_aliases") LUAU_REQUIRE_ERROR_COUNT(1, result); - const char* expectedError = "Type '{ t: { v: string } }' could not be converted into 'U'\n" - "caused by:\n" - " Property 't' is not compatible. Type '{ v: string }' could not be converted into 'T'\n" - "caused by:\n" - " Property 'v' is not compatible. Type 'string' could not be converted into 'number'"; + const char* expectedError; + if (FFlag::LuauTypeMismatchInvarianceInError) + expectedError = "Type '{ t: { v: string } }' could not be converted into 'U'\n" + "caused by:\n" + " Property 't' is not compatible. Type '{ v: string }' could not be converted into 'T'\n" + "caused by:\n" + " Property 'v' is not compatible. Type 'string' could not be converted into 'number' in an invariant context"; + else + expectedError = "Type '{ t: { v: string } }' could not be converted into 'U'\n" + "caused by:\n" + " Property 't' is not compatible. Type '{ v: string }' could not be converted into 'T'\n" + "caused by:\n" + " Property 'v' is not compatible. Type 'string' could not be converted into 'number'"; CHECK(result.errors[0].location == Location{{4, 31}, {4, 52}}); CHECK(toString(result.errors[0]) == expectedError); diff --git a/tests/TypeInfer.annotations.test.cpp b/tests/TypeInfer.annotations.test.cpp index bb97bbeb1..b94e1df04 100644 --- a/tests/TypeInfer.annotations.test.cpp +++ b/tests/TypeInfer.annotations.test.cpp @@ -7,8 +7,6 @@ #include "doctest.h" -LUAU_FASTFLAG(LuauIceExceptionInheritanceChange) - using namespace Luau; TEST_SUITE_BEGIN("AnnotationTests"); diff --git a/tests/TypeInfer.builtins.test.cpp b/tests/TypeInfer.builtins.test.cpp index 787aea9ae..32e31e16e 100644 --- a/tests/TypeInfer.builtins.test.cpp +++ b/tests/TypeInfer.builtins.test.cpp @@ -684,20 +684,10 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "select_with_variadic_typepack_tail") LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::DebugLuauDeferredConstraintResolution) - { - CHECK_EQ("string", toString(requireType("foo"))); - CHECK_EQ("*error-type*", toString(requireType("bar"))); - CHECK_EQ("*error-type*", toString(requireType("baz"))); - CHECK_EQ("*error-type*", toString(requireType("quux"))); - } - else - { - CHECK_EQ("any", toString(requireType("foo"))); - CHECK_EQ("any", toString(requireType("bar"))); - CHECK_EQ("any", toString(requireType("baz"))); - CHECK_EQ("any", toString(requireType("quux"))); - } + CHECK_EQ("any", toString(requireType("foo"))); + CHECK_EQ("any", toString(requireType("bar"))); + CHECK_EQ("any", toString(requireType("baz"))); + CHECK_EQ("any", toString(requireType("quux"))); } TEST_CASE_FIXTURE(BuiltinsFixture, "select_with_variadic_typepack_tail_and_string_head") @@ -714,19 +704,13 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "select_with_variadic_typepack_tail_and_strin LUAU_REQUIRE_NO_ERRORS(result); if (FFlag::DebugLuauDeferredConstraintResolution) - { CHECK_EQ("string", toString(requireType("foo"))); - CHECK_EQ("string", toString(requireType("bar"))); - CHECK_EQ("*error-type*", toString(requireType("baz"))); - CHECK_EQ("*error-type*", toString(requireType("quux"))); - } else - { CHECK_EQ("any", toString(requireType("foo"))); - CHECK_EQ("any", toString(requireType("bar"))); - CHECK_EQ("any", toString(requireType("baz"))); - CHECK_EQ("any", toString(requireType("quux"))); - } + + CHECK_EQ("any", toString(requireType("bar"))); + CHECK_EQ("any", toString(requireType("baz"))); + CHECK_EQ("any", toString(requireType("quux"))); } TEST_CASE_FIXTURE(Fixture, "string_format_as_method") diff --git a/tests/TypeInfer.classes.test.cpp b/tests/TypeInfer.classes.test.cpp index d00f1d831..07dfc33fe 100644 --- a/tests/TypeInfer.classes.test.cpp +++ b/tests/TypeInfer.classes.test.cpp @@ -1,102 +1,18 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/BuiltinDefinitions.h" +#include "Luau/Common.h" #include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" #include "Fixture.h" +#include "ClassFixture.h" #include "doctest.h" using namespace Luau; using std::nullopt; -struct ClassFixture : BuiltinsFixture -{ - ClassFixture() - { - TypeArena& arena = typeChecker.globalTypes; - TypeId numberType = typeChecker.numberType; - - unfreeze(arena); - - TypeId baseClassInstanceType = arena.addType(ClassTypeVar{"BaseClass", {}, nullopt, nullopt, {}, {}, "Test"}); - getMutable(baseClassInstanceType)->props = { - {"BaseMethod", {makeFunction(arena, baseClassInstanceType, {numberType}, {})}}, - {"BaseField", {numberType}}, - }; - - TypeId baseClassType = arena.addType(ClassTypeVar{"BaseClass", {}, nullopt, nullopt, {}, {}, "Test"}); - getMutable(baseClassType)->props = { - {"StaticMethod", {makeFunction(arena, nullopt, {}, {numberType})}}, - {"Clone", {makeFunction(arena, nullopt, {baseClassInstanceType}, {baseClassInstanceType})}}, - {"New", {makeFunction(arena, nullopt, {}, {baseClassInstanceType})}}, - }; - typeChecker.globalScope->exportedTypeBindings["BaseClass"] = TypeFun{{}, baseClassInstanceType}; - addGlobalBinding(frontend, "BaseClass", baseClassType, "@test"); - - TypeId childClassInstanceType = arena.addType(ClassTypeVar{"ChildClass", {}, baseClassInstanceType, nullopt, {}, {}, "Test"}); - - getMutable(childClassInstanceType)->props = { - {"Method", {makeFunction(arena, childClassInstanceType, {}, {typeChecker.stringType})}}, - }; - - TypeId childClassType = arena.addType(ClassTypeVar{"ChildClass", {}, baseClassType, nullopt, {}, {}, "Test"}); - getMutable(childClassType)->props = { - {"New", {makeFunction(arena, nullopt, {}, {childClassInstanceType})}}, - }; - typeChecker.globalScope->exportedTypeBindings["ChildClass"] = TypeFun{{}, childClassInstanceType}; - addGlobalBinding(frontend, "ChildClass", childClassType, "@test"); - - TypeId grandChildInstanceType = arena.addType(ClassTypeVar{"GrandChild", {}, childClassInstanceType, nullopt, {}, {}, "Test"}); - - getMutable(grandChildInstanceType)->props = { - {"Method", {makeFunction(arena, grandChildInstanceType, {}, {typeChecker.stringType})}}, - }; - - TypeId grandChildType = arena.addType(ClassTypeVar{"GrandChild", {}, baseClassType, nullopt, {}, {}, "Test"}); - getMutable(grandChildType)->props = { - {"New", {makeFunction(arena, nullopt, {}, {grandChildInstanceType})}}, - }; - typeChecker.globalScope->exportedTypeBindings["GrandChild"] = TypeFun{{}, grandChildInstanceType}; - addGlobalBinding(frontend, "GrandChild", childClassType, "@test"); - - TypeId anotherChildInstanceType = arena.addType(ClassTypeVar{"AnotherChild", {}, baseClassInstanceType, nullopt, {}, {}, "Test"}); - - getMutable(anotherChildInstanceType)->props = { - {"Method", {makeFunction(arena, anotherChildInstanceType, {}, {typeChecker.stringType})}}, - }; - - TypeId anotherChildType = arena.addType(ClassTypeVar{"AnotherChild", {}, baseClassType, nullopt, {}, {}, "Test"}); - getMutable(anotherChildType)->props = { - {"New", {makeFunction(arena, nullopt, {}, {anotherChildInstanceType})}}, - }; - typeChecker.globalScope->exportedTypeBindings["AnotherChild"] = TypeFun{{}, anotherChildInstanceType}; - addGlobalBinding(frontend, "AnotherChild", childClassType, "@test"); - - TypeId vector2MetaType = arena.addType(TableTypeVar{}); - - TypeId vector2InstanceType = arena.addType(ClassTypeVar{"Vector2", {}, nullopt, vector2MetaType, {}, {}, "Test"}); - getMutable(vector2InstanceType)->props = { - {"X", {numberType}}, - {"Y", {numberType}}, - }; - - TypeId vector2Type = arena.addType(ClassTypeVar{"Vector2", {}, nullopt, nullopt, {}, {}, "Test"}); - getMutable(vector2Type)->props = { - {"New", {makeFunction(arena, nullopt, {numberType, numberType}, {vector2InstanceType})}}, - }; - getMutable(vector2MetaType)->props = { - {"__add", {makeFunction(arena, nullopt, {vector2InstanceType, vector2InstanceType}, {vector2InstanceType})}}, - }; - typeChecker.globalScope->exportedTypeBindings["Vector2"] = TypeFun{{}, vector2InstanceType}; - addGlobalBinding(frontend, "Vector2", vector2Type, "@test"); - - for (const auto& [name, tf] : typeChecker.globalScope->exportedTypeBindings) - persist(tf.type); - - freeze(arena); - } -}; +LUAU_FASTFLAG(LuauTypeMismatchInvarianceInError); TEST_SUITE_BEGIN("TypeInferClasses"); @@ -514,4 +430,67 @@ TEST_CASE_FIXTURE(ClassFixture, "unions_of_intersections_of_classes") LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(ClassFixture, "index_instance_property") +{ + ScopedFastFlag luauAllowIndexClassParameters{"LuauAllowIndexClassParameters", true}; + + CheckResult result = check(R"( + local function execute(object: BaseClass, name: string) + print(object[name]) + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Attempting a dynamic property access on type 'BaseClass' is unsafe and may cause exceptions at runtime", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(ClassFixture, "index_instance_property_nonstrict") +{ + ScopedFastFlag luauAllowIndexClassParameters{"LuauAllowIndexClassParameters", true}; + + CheckResult result = check(R"( + --!nonstrict + + local function execute(object: BaseClass, name: string) + print(object[name]) + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(ClassFixture, "type_mismatch_invariance_required_for_error") +{ + CheckResult result = check(R"( +type A = { x: ChildClass } +type B = { x: BaseClass } + +local a: A +local b: B = a + )"); + + LUAU_REQUIRE_ERRORS(result); + if (FFlag::LuauTypeMismatchInvarianceInError) + CHECK_EQ(toString(result.errors[0]), R"(Type 'A' could not be converted into 'B' +caused by: + Property 'x' is not compatible. Type 'ChildClass' could not be converted into 'BaseClass' in an invariant context)"); + else + CHECK_EQ(toString(result.errors[0]), R"(Type 'A' could not be converted into 'B' +caused by: + Property 'x' is not compatible. Type 'ChildClass' could not be converted into 'BaseClass')"); +} + +TEST_CASE_FIXTURE(ClassFixture, "callable_classes") +{ + ScopedFastFlag luauCallableClasses{"LuauCallableClasses", true}; + + CheckResult result = check(R"( + local x : CallableClass + local y = x("testing") + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("number", toString(requireType("y"))); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.definitions.test.cpp b/tests/TypeInfer.definitions.test.cpp index 684b47e94..26115046d 100644 --- a/tests/TypeInfer.definitions.test.cpp +++ b/tests/TypeInfer.definitions.test.cpp @@ -311,8 +311,6 @@ TEST_CASE_FIXTURE(Fixture, "definitions_documentation_symbols") TEST_CASE_FIXTURE(Fixture, "definitions_symbols_are_generated_for_recursively_referenced_types") { - ScopedFastFlag LuauPersistTypesAfterGeneratingDocSyms("LuauPersistTypesAfterGeneratingDocSyms", true); - loadDefinition(R"( declare class MyClass function myMethod(self) @@ -396,4 +394,26 @@ TEST_CASE_FIXTURE(Fixture, "class_definition_string_props") CHECK_EQ(toString(requireType("y")), "string"); } +TEST_CASE_FIXTURE(Fixture, "class_definitions_reference_other_classes") +{ + ScopedFastFlag LuauDeclareClassPrototype("LuauDeclareClassPrototype", true); + + unfreeze(typeChecker.globalTypes); + LoadDefinitionFileResult result = loadDefinitionFile(typeChecker, typeChecker.globalScope, R"( + declare class Channel + Messages: { Message } + OnMessage: (message: Message) -> () + end + + declare class Message + Text: string + Channel: Channel + end + )", + "@test"); + freeze(typeChecker.globalTypes); + + REQUIRE(result.success); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index b306515a9..552180401 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -230,8 +230,6 @@ TEST_CASE_FIXTURE(Fixture, "too_many_arguments") TEST_CASE_FIXTURE(Fixture, "too_many_arguments_error_location") { - ScopedFastFlag sff{"LuauArgMismatchReportFunctionLocation", true}; - CheckResult result = check(R"( --!strict @@ -507,7 +505,9 @@ TEST_CASE_FIXTURE(Fixture, "complicated_return_types_require_an_explicit_annotat LUAU_REQUIRE_NO_ERRORS(result); - const FunctionTypeVar* functionType = get(requireType("most_of_the_natural_numbers")); + TypeId ty = requireType("most_of_the_natural_numbers"); + const FunctionTypeVar* functionType = get(ty); + REQUIRE_MESSAGE(functionType, "Expected function but got " << toString(ty)); std::optional retType = first(functionType->retTypes); REQUIRE(retType); @@ -1830,4 +1830,18 @@ TEST_CASE_FIXTURE(Fixture, "other_things_are_not_related_to_function") CHECK(5 == result.errors[3].location.begin.line); } +TEST_CASE_FIXTURE(BuiltinsFixture, "fuzz_must_follow_in_overload_resolution") +{ + ScopedFastFlag luauTypeInferMissingFollows{"LuauTypeInferMissingFollows", true}; + + CheckResult result = check(R"( +for _ in function():(t0)&((()->())&(()->())) +end do +_(_(_,_,_),_) +end + )"); + + LUAU_REQUIRE_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.generics.test.cpp b/tests/TypeInfer.generics.test.cpp index de41c3a63..c25f8e5fc 100644 --- a/tests/TypeInfer.generics.test.cpp +++ b/tests/TypeInfer.generics.test.cpp @@ -10,6 +10,7 @@ #include "doctest.h" LUAU_FASTFLAG(LuauInstantiateInSubtyping) +LUAU_FASTFLAG(LuauTypeMismatchInvarianceInError) using namespace Luau; @@ -717,12 +718,24 @@ y.a.c = y )"); LUAU_REQUIRE_ERRORS(result); - CHECK_EQ(toString(result.errors[0]), - R"(Type 'y' could not be converted into 'T' + if (FFlag::LuauTypeMismatchInvarianceInError) + { + CHECK_EQ(toString(result.errors[0]), + R"(Type 'y' could not be converted into 'T' +caused by: + Property 'a' is not compatible. Type '{ c: T?, d: number }' could not be converted into 'U' +caused by: + Property 'd' is not compatible. Type 'number' could not be converted into 'string' in an invariant context)"); + } + else + { + CHECK_EQ(toString(result.errors[0]), + R"(Type 'y' could not be converted into 'T' caused by: Property 'a' is not compatible. Type '{ c: T?, d: number }' could not be converted into 'U' caused by: Property 'd' is not compatible. Type 'number' could not be converted into 'string')"); + } } TEST_CASE_FIXTURE(Fixture, "generic_type_pack_unification1") @@ -1249,4 +1262,21 @@ instantiate(function(x: string) return "foo" end) CHECK_EQ("(string) -> string", toString(tm1->givenType)); } +TEST_CASE_FIXTURE(Fixture, "bidirectional_checking_and_generalization_play_nice") +{ + CheckResult result = check(R"( + local foo = function(a) + return a() + end + + local a = foo(function() return 1 end) + local b = foo(function() return "bar" end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK("number" == toString(requireType("a"))); + CHECK("string" == toString(requireType("b"))); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.intersectionTypes.test.cpp b/tests/TypeInfer.intersectionTypes.test.cpp index 0c10eb87e..7d0621a79 100644 --- a/tests/TypeInfer.intersectionTypes.test.cpp +++ b/tests/TypeInfer.intersectionTypes.test.cpp @@ -185,7 +185,15 @@ TEST_CASE_FIXTURE(Fixture, "index_on_an_intersection_type_works_at_arbitrary_dep )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("string & string", toString(requireType("r"))); + + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("string", toString(requireType("r"))); + } + else + { + CHECK_EQ("string & string", toString(requireType("r"))); + } } TEST_CASE_FIXTURE(Fixture, "index_on_an_intersection_type_with_mixed_types") @@ -199,7 +207,7 @@ TEST_CASE_FIXTURE(Fixture, "index_on_an_intersection_type_with_mixed_types") )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("number & string", toString(requireType("r"))); // TODO(amccord): This should be an error. + CHECK_EQ("number & string", toString(requireType("r"))); } TEST_CASE_FIXTURE(Fixture, "index_on_an_intersection_type_with_one_part_missing_the_property") @@ -525,18 +533,16 @@ TEST_CASE_FIXTURE(Fixture, "intersection_of_tables_with_never_properties") ScopedFastFlag sffs[]{ {"LuauSubtypeNormalizer", true}, {"LuauTypeNormalization2", true}, + {"LuauUninhabitedSubAnything", true}, }; CheckResult result = check(R"( - local x : { p : number?, q : never } & { p : never, q : string? } + local x : { p : number?, q : never } & { p : never, q : string? } -- OK local y : { p : never, q : never } = x -- OK local z : never = x -- OK )"); - // TODO: this should not produce type errors, since never <: { p : never } - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), "Type '{| p: never, q: string? |} & {| p: number?, q: never |}' could not be converted into 'never'; none " - "of the intersection parts are compatible"); + LUAU_REQUIRE_NO_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "overloaded_functions_returning_intersections") @@ -848,7 +854,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "intersect_metatables_with_properties") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(BuiltinsFixture, "intersect_metatable_with table") +TEST_CASE_FIXTURE(BuiltinsFixture, "intersect_metatable_with_table") { ScopedFastFlag sffs[]{ {"LuauSubtypeNormalizer", true}, @@ -902,4 +908,37 @@ TEST_CASE_FIXTURE(Fixture, "CLI-44817") LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "less_greedy_unification_with_intersection_types") +{ + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + + CheckResult result = check(R"( + local function f(t): { x: number } & { x: string } + local x = t.x + return t + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("({| x: number |} & {| x: string |}) -> {| x: number |} & {| x: string |}", toString(requireType("f"))); +} + +TEST_CASE_FIXTURE(Fixture, "less_greedy_unification_with_intersection_types_2") +{ + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + + CheckResult result = check(R"( + local function f(t: { x: number } & { x: string }) + return t.x + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("({| x: number |} & {| x: string |}) -> number & string", toString(requireType("f"))); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.modules.test.cpp b/tests/TypeInfer.modules.test.cpp index 4cc628fbf..b06c80e92 100644 --- a/tests/TypeInfer.modules.test.cpp +++ b/tests/TypeInfer.modules.test.cpp @@ -11,6 +11,7 @@ #include "doctest.h" LUAU_FASTFLAG(LuauInstantiateInSubtyping) +LUAU_FASTFLAG(LuauTypeMismatchInvarianceInError) using namespace Luau; @@ -408,7 +409,12 @@ local b: B.T = a CheckResult result = frontend.check("game/C"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), R"(Type 'T' from 'game/A' could not be converted into 'T' from 'game/B' + if (FFlag::LuauTypeMismatchInvarianceInError) + CHECK_EQ(toString(result.errors[0]), R"(Type 'T' from 'game/A' could not be converted into 'T' from 'game/B' +caused by: + Property 'x' is not compatible. Type 'number' could not be converted into 'string' in an invariant context)"); + else + CHECK_EQ(toString(result.errors[0]), R"(Type 'T' from 'game/A' could not be converted into 'T' from 'game/B' caused by: Property 'x' is not compatible. Type 'number' could not be converted into 'string')"); } @@ -442,7 +448,12 @@ local b: B.T = a CheckResult result = frontend.check("game/D"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), R"(Type 'T' from 'game/B' could not be converted into 'T' from 'game/C' + if (FFlag::LuauTypeMismatchInvarianceInError) + CHECK_EQ(toString(result.errors[0]), R"(Type 'T' from 'game/B' could not be converted into 'T' from 'game/C' +caused by: + Property 'x' is not compatible. Type 'number' could not be converted into 'string' in an invariant context)"); + else + CHECK_EQ(toString(result.errors[0]), R"(Type 'T' from 'game/B' could not be converted into 'T' from 'game/C' caused by: Property 'x' is not compatible. Type 'number' could not be converted into 'string')"); } @@ -462,4 +473,15 @@ return l0 LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(BuiltinsFixture, "fuzz_anyify_variadic_return_must_follow") +{ + ScopedFastFlag luauTypeInferMissingFollows{"LuauTypeInferMissingFollows", true}; + + CheckResult result = check(R"( +return unpack(l0[_]) + )"); + + LUAU_REQUIRE_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.operators.test.cpp b/tests/TypeInfer.operators.test.cpp index c4bbc2e11..21806082f 100644 --- a/tests/TypeInfer.operators.test.cpp +++ b/tests/TypeInfer.operators.test.cpp @@ -8,6 +8,7 @@ #include "Luau/VisitTypeVar.h" #include "Fixture.h" +#include "ClassFixture.h" #include "doctest.h" @@ -817,6 +818,21 @@ TEST_CASE_FIXTURE(Fixture, "operator_eq_operands_are_not_subtypes_of_each_other_ LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "operator_eq_completely_incompatible") +{ + ScopedFastFlag sff{"LuauIntersectionTestForEquality", true}; + + CheckResult result = check(R"( + local a: string | number = "hi" + local b: {x: string}? = {x = "bye"} + + local r1 = a == b + local r2 = b == a + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); +} + TEST_CASE_FIXTURE(Fixture, "refine_and_or") { CheckResult result = check(R"( @@ -916,6 +932,31 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "expected_types_through_binary_or") LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(ClassFixture, "unrelated_classes_cannot_be_compared") +{ + ScopedFastFlag sff{"LuauIntersectionTestForEquality", true}; + + CheckResult result = check(R"( + local a = BaseClass.New() + local b = UnrelatedClass.New() + + local c = a == b + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(Fixture, "unrelated_primitives_cannot_be_compared") +{ + ScopedFastFlag sff{"LuauIntersectionTestForEquality", true}; + + CheckResult result = check(R"( + local c = 5 == true + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + TEST_CASE_FIXTURE(BuiltinsFixture, "mm_ops_must_return_a_value") { if (!FFlag::DebugLuauDeferredConstraintResolution) diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index 5688eaaa1..259341744 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -170,24 +170,12 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "error_on_eq_metamethod_returning_a_type_othe CHECK_EQ("Metamethod '__eq' must return type 'boolean'", ge->message); } -// Requires success typing to confidently determine that this expression has no overlap. -TEST_CASE_FIXTURE(Fixture, "operator_eq_completely_incompatible") -{ - CheckResult result = check(R"( - local a: string | number = "hi" - local b: {x: string}? = {x = "bye"} - - local r1 = a == b - local r2 = b == a - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - // Belongs in TypeInfer.refinements.test.cpp. -// We'll need to not only report an error on `a == b`, but also to refine both operands as `never` in the `==` branch. +// We need refine both operands as `never` in the `==` branch. TEST_CASE_FIXTURE(Fixture, "lvalue_equals_another_lvalue_with_no_overlap") { + ScopedFastFlag sff{"LuauIntersectionTestForEquality", true}; + CheckResult result = check(R"( local function f(a: string, b: boolean?) if a == b then @@ -198,7 +186,7 @@ TEST_CASE_FIXTURE(Fixture, "lvalue_equals_another_lvalue_with_no_overlap") end )"); - LUAU_REQUIRE_NO_ERRORS(result); + LUAU_REQUIRE_ERROR_COUNT(1, result); CHECK_EQ(toString(requireTypeAtPosition({3, 33})), "string"); // a == b CHECK_EQ(toString(requireTypeAtPosition({3, 36})), "boolean?"); // a == b diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index 66550be3e..e5bc186a0 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -8,6 +8,7 @@ #include "doctest.h" LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) +LUAU_FASTFLAG(LuauTypeMismatchInvarianceInError) using namespace Luau; @@ -35,6 +36,27 @@ std::optional> magicFunctionInstanceIsA( return WithPredicate{booleanPack, {IsAPredicate{std::move(*lvalue), expr.location, tfun->type}}}; } +std::vector dcrMagicRefinementInstanceIsA(MagicRefinementContext ctx) +{ + if (ctx.callSite->args.size != 1) + return {}; + + auto index = ctx.callSite->func->as(); + auto str = ctx.callSite->args.data[0]->as(); + if (!index || !str) + return {}; + + std::optional def = ctx.dfg->getDef(index->expr); + if (!def) + return {}; + + std::optional tfun = ctx.scope->lookupType(std::string(str->value.data, str->value.size)); + if (!tfun) + return {}; + + return {ctx.connectiveArena->proposition(*def, tfun->type)}; +} + struct RefinementClassFixture : BuiltinsFixture { RefinementClassFixture() @@ -56,6 +78,7 @@ struct RefinementClassFixture : BuiltinsFixture TypePackId isARets = arena.addTypePack({typeChecker.booleanType}); TypeId isA = arena.addType(FunctionTypeVar{isAParams, isARets}); getMutable(isA)->magicFunction = magicFunctionInstanceIsA; + getMutable(isA)->dcrMagicRefinement = dcrMagicRefinementInstanceIsA; getMutable(inst)->props = { {"Name", Property{typeChecker.stringType}}, @@ -397,13 +420,21 @@ TEST_CASE_FIXTURE(Fixture, "truthy_constraint_on_properties") local t: {x: number?} = {x = 1} if t.x then - local foo: number = t.x + local t2 = t + local foo = t.x end local bar = t.x )"); LUAU_REQUIRE_NO_ERRORS(result); + + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK("{| x: number? |} & {| x: ~(false?) |}" == toString(requireTypeAtPosition({4, 23}))); + CHECK("(number?) & ~(false?)" == toString(requireTypeAtPosition({5, 26}))); + } + CHECK_EQ("number?", toString(requireType("bar"))); } @@ -442,12 +473,24 @@ TEST_CASE_FIXTURE(Fixture, "assign_table_with_refined_property_with_a_similar_ty )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(R"(Type '{| x: number? |}' could not be converted into '{| x: number |}' + + if (FFlag::LuauTypeMismatchInvarianceInError) + { + CHECK_EQ(R"(Type '{| x: number? |}' could not be converted into '{| x: number |}' +caused by: + Property 'x' is not compatible. Type 'number?' could not be converted into 'number' in an invariant context)", + toString(result.errors[0])); + } + else + { + CHECK_EQ(R"(Type '{| x: number? |}' could not be converted into '{| x: number |}' caused by: Property 'x' is not compatible. Type 'number?' could not be converted into 'number')", - toString(result.errors[0])); + toString(result.errors[0])); + } } + TEST_CASE_FIXTURE(Fixture, "lvalue_is_equal_to_another_lvalue") { CheckResult result = check(R"( @@ -464,8 +507,8 @@ TEST_CASE_FIXTURE(Fixture, "lvalue_is_equal_to_another_lvalue") if (FFlag::DebugLuauDeferredConstraintResolution) { - CHECK_EQ(toString(requireTypeAtPosition({3, 33})), "((number | string)?) & (boolean?)"); // a == b - CHECK_EQ(toString(requireTypeAtPosition({3, 36})), "((number | string)?) & (boolean?)"); // a == b + CHECK_EQ(toString(requireTypeAtPosition({3, 33})), "((number | string)?) & unknown"); // a == b + CHECK_EQ(toString(requireTypeAtPosition({3, 36})), "(boolean?) & unknown"); // a == b CHECK_EQ(toString(requireTypeAtPosition({5, 33})), "((number | string)?) & unknown"); // a ~= b CHECK_EQ(toString(requireTypeAtPosition({5, 36})), "(boolean?) & unknown"); // a ~= b @@ -496,7 +539,7 @@ TEST_CASE_FIXTURE(Fixture, "lvalue_is_equal_to_a_term") if (FFlag::DebugLuauDeferredConstraintResolution) { - CHECK_EQ(toString(requireTypeAtPosition({3, 28})), "((number | string)?) & number"); // a == 1 + CHECK_EQ(toString(requireTypeAtPosition({3, 28})), "((number | string)?) & unknown"); // a == 1 CHECK_EQ(toString(requireTypeAtPosition({5, 28})), "((number | string)?) & unknown"); // a ~= 1 } else @@ -548,8 +591,8 @@ TEST_CASE_FIXTURE(Fixture, "lvalue_is_not_nil") if (FFlag::DebugLuauDeferredConstraintResolution) { - CHECK_EQ(toString(requireTypeAtPosition({3, 28})), "((number | string)?) & ~nil"); // a ~= nil - CHECK_EQ(toString(requireTypeAtPosition({5, 28})), "((number | string)?) & nil"); // a == nil + CHECK_EQ(toString(requireTypeAtPosition({3, 28})), "((number | string)?) & ~nil"); // a ~= nil + CHECK_EQ(toString(requireTypeAtPosition({5, 28})), "((number | string)?) & unknown"); // a == nil } else { @@ -573,8 +616,8 @@ TEST_CASE_FIXTURE(Fixture, "free_type_is_equal_to_an_lvalue") if (FFlag::DebugLuauDeferredConstraintResolution) { ToStringOptions opts; - CHECK_EQ(toString(requireTypeAtPosition({3, 33}), opts), "(string?) & a"); // a == b - CHECK_EQ(toString(requireTypeAtPosition({3, 36}), opts), "(string?) & a"); // a == b + CHECK_EQ(toString(requireTypeAtPosition({3, 33}), opts), "a & unknown"); // a == b + CHECK_EQ(toString(requireTypeAtPosition({3, 36}), opts), "(string?) & unknown"); // a == b } else { @@ -628,8 +671,8 @@ TEST_CASE_FIXTURE(Fixture, "string_not_equal_to_string_or_nil") CHECK_EQ(toString(requireTypeAtPosition({6, 29})), "string & unknown"); // a ~= b CHECK_EQ(toString(requireTypeAtPosition({6, 32})), "(string?) & unknown"); // a ~= b - CHECK_EQ(toString(requireTypeAtPosition({8, 29})), "(string?) & string"); // a == b - CHECK_EQ(toString(requireTypeAtPosition({8, 32})), "(string?) & string"); // a == b + CHECK_EQ(toString(requireTypeAtPosition({8, 29})), "string & unknown"); // a == b + CHECK_EQ(toString(requireTypeAtPosition({8, 32})), "(string?) & unknown"); // a == b } else { @@ -1146,8 +1189,17 @@ TEST_CASE_FIXTURE(Fixture, "discriminate_from_truthiness_of_x") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(R"({| tag: "exists", x: string |})", toString(requireTypeAtPosition({5, 28}))); - CHECK_EQ(R"({| tag: "exists", x: string |} | {| tag: "missing", x: nil |})", toString(requireTypeAtPosition({7, 28}))); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ(R"(({| tag: "exists", x: string |} | {| tag: "missing", x: nil |}) & {| x: ~(false?) |})", toString(requireTypeAtPosition({5, 28}))); + CHECK_EQ( + R"(({| tag: "exists", x: string |} | {| tag: "missing", x: nil |}) & {| x: ~~(false?) |})", toString(requireTypeAtPosition({7, 28}))); + } + else + { + CHECK_EQ(R"({| tag: "exists", x: string |})", toString(requireTypeAtPosition({5, 28}))); + CHECK_EQ(R"({| tag: "exists", x: string |} | {| tag: "missing", x: nil |})", toString(requireTypeAtPosition({7, 28}))); + } } TEST_CASE_FIXTURE(Fixture, "discriminate_tag") @@ -1159,17 +1211,57 @@ TEST_CASE_FIXTURE(Fixture, "discriminate_tag") local function f(animal: Animal) if animal.tag == "Cat" then - local cat: Cat = animal + local cat = animal elseif animal.tag == "Dog" then - local dog: Dog = animal + local dog = animal + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ(R"((Cat | Dog) & {| tag: "Cat" |})", toString(requireTypeAtPosition({7, 33}))); + CHECK_EQ(R"((Cat | Dog) & {| tag: ~"Cat" |} & {| tag: "Dog" |})", toString(requireTypeAtPosition({9, 33}))); + } + else + { + CHECK_EQ("Cat", toString(requireTypeAtPosition({7, 33}))); + CHECK_EQ("Dog", toString(requireTypeAtPosition({9, 33}))); + } +} + +TEST_CASE_FIXTURE(Fixture, "discriminate_tag_with_implicit_else") +{ + ScopedFastFlag sff{"LuauImplicitElseRefinement", true}; + + CheckResult result = check(R"( + type Cat = {tag: "Cat", name: string, catfood: string} + type Dog = {tag: "Dog", name: string, dogfood: string} + type Animal = Cat | Dog + + local function f(animal: Animal) + if animal.tag == "Cat" then + local cat = animal + else + local dog = animal end end )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("Cat", toString(requireTypeAtPosition({7, 33}))); - CHECK_EQ("Dog", toString(requireTypeAtPosition({9, 33}))); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ(R"((Cat | Dog) & {| tag: "Cat" |})", toString(requireTypeAtPosition({7, 33}))); + CHECK_EQ(R"((Cat | Dog) & {| tag: ~"Cat" |})", toString(requireTypeAtPosition({9, 33}))); + } + else + { + CHECK_EQ("Cat", toString(requireTypeAtPosition({7, 33}))); + CHECK_EQ("Dog", toString(requireTypeAtPosition({9, 33}))); + } } TEST_CASE_FIXTURE(Fixture, "and_or_peephole_refinement") @@ -1258,8 +1350,16 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "discriminate_from_isa_of_x") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(R"({| tag: "Part", x: Part |})", toString(requireTypeAtPosition({5, 28}))); - CHECK_EQ(R"({| tag: "Folder", x: Folder |})", toString(requireTypeAtPosition({7, 28}))); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ(R"(({| tag: "Folder", x: Folder |} | {| tag: "Part", x: Part |}) & {| x: Part |})", toString(requireTypeAtPosition({5, 28}))); + CHECK_EQ(R"(({| tag: "Folder", x: Folder |} | {| tag: "Part", x: Part |}) & {| x: ~Part |})", toString(requireTypeAtPosition({7, 28}))); + } + else + { + CHECK_EQ(R"({| tag: "Part", x: Part |})", toString(requireTypeAtPosition({5, 28}))); + CHECK_EQ(R"({| tag: "Folder", x: Folder |})", toString(requireTypeAtPosition({7, 28}))); + } } TEST_CASE_FIXTURE(RefinementClassFixture, "typeguard_cast_free_table_to_vector") @@ -1399,8 +1499,96 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "x_as_any_if_x_is_instance_elseif_x_is LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("Folder", toString(requireTypeAtPosition({5, 28}))); - CHECK_EQ("any", toString(requireTypeAtPosition({7, 28}))); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("Folder & Instance & {- -}", toString(requireTypeAtPosition({5, 28}))); + CHECK_EQ("(~Folder | ~Instance) & {- -} & never", toString(requireTypeAtPosition({7, 28}))); + } + else + { + CHECK_EQ("Folder", toString(requireTypeAtPosition({5, 28}))); + CHECK_EQ("any", toString(requireTypeAtPosition({7, 28}))); + } +} + +TEST_CASE_FIXTURE(RefinementClassFixture, "refine_param_of_type_instance_without_using_typeof") +{ + CheckResult result = check(R"( + local function f(x: Instance) + if x:IsA("Folder") then + local foo = x + elseif typeof(x) == "table" then + local foo = x + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("Folder & Instance", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("Instance & ~Folder & never", toString(requireTypeAtPosition({5, 28}))); + } + else + { + CHECK_EQ("Folder", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("never", toString(requireTypeAtPosition({5, 28}))); + } +} + +TEST_CASE_FIXTURE(RefinementClassFixture, "refine_param_of_type_folder_or_part_without_using_typeof") +{ + CheckResult result = check(R"( + local function f(x: Part | Folder) + if x:IsA("Folder") then + local foo = x + else + local foo = x + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("(Folder | Part) & Folder", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("(Folder | Part) & ~Folder", toString(requireTypeAtPosition({5, 28}))); + } + else + { + CHECK_EQ("Folder", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("Part", toString(requireTypeAtPosition({5, 28}))); + } +} + +TEST_CASE_FIXTURE(RefinementClassFixture, "isa_type_refinement_must_be_known_ahead_of_time") +{ + CheckResult result = check(R"( + local function f(x): Instance + if x:IsA("Folder") then + local foo = x + else + local foo = x + end + + return x + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("Instance", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("Instance", toString(requireTypeAtPosition({5, 28}))); + } + else + { + CHECK_EQ("Instance", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("Instance", toString(requireTypeAtPosition({5, 28}))); + } } TEST_CASE_FIXTURE(RefinementClassFixture, "x_is_not_instance_or_else_not_part") @@ -1526,4 +1714,18 @@ TEST_CASE_FIXTURE(Fixture, "else_with_no_explicit_expression_should_also_refine_ LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "fuzz_filtered_refined_types_are_followed") +{ + ScopedFastFlag luauTypeInferMissingFollows{"LuauTypeInferMissingFollows", true}; + + CheckResult result = check(R"( +local _ +do +local _ = _ ~= _ or _ or _ +end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index 10186e3a0..c94ed1f9a 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -18,6 +18,7 @@ LUAU_FASTFLAG(LuauLowerBoundsCalculation); LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); LUAU_FASTFLAG(LuauInstantiateInSubtyping) LUAU_FASTFLAG(LuauNoMoreGlobalSingletonTypes) +LUAU_FASTFLAG(LuauTypeMismatchInvarianceInError) TEST_SUITE_BEGIN("TableTests"); @@ -2024,7 +2025,12 @@ local b: B = a )"); LUAU_REQUIRE_ERRORS(result); - CHECK_EQ(toString(result.errors[0]), R"(Type 'A' could not be converted into 'B' + if (FFlag::LuauTypeMismatchInvarianceInError) + CHECK_EQ(toString(result.errors[0]), R"(Type 'A' could not be converted into 'B' +caused by: + Property 'y' is not compatible. Type 'number' could not be converted into 'string' in an invariant context)"); + else + CHECK_EQ(toString(result.errors[0]), R"(Type 'A' could not be converted into 'B' caused by: Property 'y' is not compatible. Type 'number' could not be converted into 'string')"); } @@ -2043,7 +2049,14 @@ local b: B = a )"); LUAU_REQUIRE_ERRORS(result); - CHECK_EQ(toString(result.errors[0]), R"(Type 'A' could not be converted into 'B' + if (FFlag::LuauTypeMismatchInvarianceInError) + CHECK_EQ(toString(result.errors[0]), R"(Type 'A' could not be converted into 'B' +caused by: + Property 'b' is not compatible. Type 'AS' could not be converted into 'BS' +caused by: + Property 'y' is not compatible. Type 'number' could not be converted into 'string' in an invariant context)"); + else + CHECK_EQ(toString(result.errors[0]), R"(Type 'A' could not be converted into 'B' caused by: Property 'b' is not compatible. Type 'AS' could not be converted into 'BS' caused by: @@ -2063,7 +2076,14 @@ local c2: typeof(a2) = b2 )"); LUAU_REQUIRE_ERROR_COUNT(2, result); - CHECK_EQ(toString(result.errors[0]), R"(Type 'b1' could not be converted into 'a1' + if (FFlag::LuauTypeMismatchInvarianceInError) + CHECK_EQ(toString(result.errors[0]), R"(Type 'b1' could not be converted into 'a1' +caused by: + Type '{ x: number, y: string }' could not be converted into '{ x: number, y: number }' +caused by: + Property 'y' is not compatible. Type 'string' could not be converted into 'number' in an invariant context)"); + else + CHECK_EQ(toString(result.errors[0]), R"(Type 'b1' could not be converted into 'a1' caused by: Type '{ x: number, y: string }' could not be converted into '{ x: number, y: number }' caused by: @@ -2098,7 +2118,12 @@ TEST_CASE_FIXTURE(Fixture, "error_detailed_indexer_key") )"); LUAU_REQUIRE_ERRORS(result); - CHECK_EQ(toString(result.errors[0]), R"(Type 'A' could not be converted into 'B' + if (FFlag::LuauTypeMismatchInvarianceInError) + CHECK_EQ(toString(result.errors[0]), R"(Type 'A' could not be converted into 'B' +caused by: + Property '[indexer key]' is not compatible. Type 'number' could not be converted into 'string' in an invariant context)"); + else + CHECK_EQ(toString(result.errors[0]), R"(Type 'A' could not be converted into 'B' caused by: Property '[indexer key]' is not compatible. Type 'number' could not be converted into 'string')"); } @@ -2114,7 +2139,12 @@ TEST_CASE_FIXTURE(Fixture, "error_detailed_indexer_value") )"); LUAU_REQUIRE_ERRORS(result); - CHECK_EQ(toString(result.errors[0]), R"(Type 'A' could not be converted into 'B' + if (FFlag::LuauTypeMismatchInvarianceInError) + CHECK_EQ(toString(result.errors[0]), R"(Type 'A' could not be converted into 'B' +caused by: + Property '[indexer value]' is not compatible. Type 'number' could not be converted into 'string' in an invariant context)"); + else + CHECK_EQ(toString(result.errors[0]), R"(Type 'A' could not be converted into 'B' caused by: Property '[indexer value]' is not compatible. Type 'number' could not be converted into 'string')"); } @@ -3261,7 +3291,7 @@ caused by: TEST_CASE_FIXTURE(Fixture, "a_free_shape_can_turn_into_a_scalar_if_it_is_compatible") { ScopedFastFlag sff{"LuauScalarShapeSubtyping", true}; - ScopedFastFlag luauScalarShapeUnifyToMtOwner{"LuauScalarShapeUnifyToMtOwner", true}; // Changes argument from table type to primitive + ScopedFastFlag luauScalarShapeUnifyToMtOwner{"LuauScalarShapeUnifyToMtOwner2", true}; // Changes argument from table type to primitive CheckResult result = check(R"( local function f(s): string @@ -3308,7 +3338,7 @@ caused by: TEST_CASE_FIXTURE(BuiltinsFixture, "a_free_shape_can_turn_into_a_scalar_directly") { ScopedFastFlag luauScalarShapeSubtyping{"LuauScalarShapeSubtyping", true}; - ScopedFastFlag luauScalarShapeUnifyToMtOwner{"LuauScalarShapeUnifyToMtOwner", true}; + ScopedFastFlag luauScalarShapeUnifyToMtOwner{"LuauScalarShapeUnifyToMtOwner2", true}; CheckResult result = check(R"( local function stringByteList(str) @@ -3394,7 +3424,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "setmetatable_has_a_side_effect") )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK(toString(requireType("foo")) == "{ @metatable { __add: (a, b) -> number }, { } }"); + CHECK(toString(requireType("foo")) == "{ @metatable { __add: (a, b) -> number }, { } }"); } TEST_CASE_FIXTURE(BuiltinsFixture, "tables_should_be_fully_populated") @@ -3413,4 +3443,43 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "tables_should_be_fully_populated") CHECK_EQ("{ x: *error-type*, y: number }", toString(requireType("t"), opts)); } +TEST_CASE_FIXTURE(Fixture, "fuzz_table_indexer_unification_can_bound_owner_to_string") +{ + ScopedFastFlag luauScalarShapeUnifyToMtOwner{"LuauScalarShapeUnifyToMtOwner2", true}; + + CheckResult result = check(R"( +sin,_ = nil +_ = _[_.sin][_._][_][_]._ +_[_] = _ + )"); + + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "fuzz_table_extra_prop_unification_can_bound_owner_to_string") +{ + ScopedFastFlag luauScalarShapeUnifyToMtOwner{"LuauScalarShapeUnifyToMtOwner2", true}; + + CheckResult result = check(R"( +l0,_ = nil +_ = _,_[_.n5]._[_][_][_]._ +_._.foreach[_],_ = _[_],_._ + )"); + + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "fuzz_typelevel_promote_on_changed_table_type") +{ + ScopedFastFlag luauScalarShapeUnifyToMtOwner{"LuauScalarShapeUnifyToMtOwner2", true}; + + CheckResult result = check(R"( +_._,_ = nil +_ = _.foreach[_]._,_[_.n5]._[_.foreach][_][_]._ +_ = _._ + )"); + + LUAU_REQUIRE_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index 04b8bf574..e42cea638 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -351,7 +351,7 @@ TEST_CASE_FIXTURE(Fixture, "check_expr_recursion_limit") CheckResult result = check(R"(("foo"))" + rep(":lower()", limit)); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK(nullptr != get(result.errors[0])); + CHECK_MESSAGE(nullptr != get(result.errors[0]), "Expected CodeTooComplex but got " << toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "globals") @@ -1159,4 +1159,17 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "it_is_ok_to_have_inconsistent_number_of_retu LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "fuzz_free_table_type_change_during_index_check") +{ + ScopedFastFlag luauFollowInLvalueIndexCheck{"LuauFollowInLvalueIndexCheck", true}; + + CheckResult result = check(R"( +local _ = nil +while _["" >= _] do +end + )"); + + LUAU_REQUIRE_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.tryUnify.test.cpp b/tests/TypeInfer.tryUnify.test.cpp index f04a3d950..5cc07a286 100644 --- a/tests/TypeInfer.tryUnify.test.cpp +++ b/tests/TypeInfer.tryUnify.test.cpp @@ -110,6 +110,68 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "incompatible_tables_are_preserved") CHECK_NE(*getMutable(&tableOne)->props["foo"].type, *getMutable(&tableTwo)->props["foo"].type); } +TEST_CASE_FIXTURE(TryUnifyFixture, "uninhabited_intersection_sub_never") +{ + ScopedFastFlag sffs[]{ + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + function f(arg : string & number) : never + return arg + end + )"); + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(TryUnifyFixture, "uninhabited_intersection_sub_anything") +{ + ScopedFastFlag sffs[]{ + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + function f(arg : string & number) : boolean + return arg + end + )"); + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(TryUnifyFixture, "uninhabited_table_sub_never") +{ + ScopedFastFlag sffs[]{ + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + {"LuauUninhabitedSubAnything", true}, + }; + + CheckResult result = check(R"( + function f(arg : { prop : string & number }) : never + return arg + end + )"); + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(TryUnifyFixture, "uninhabited_table_sub_anything") +{ + ScopedFastFlag sffs[]{ + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + {"LuauUninhabitedSubAnything", true}, + }; + + CheckResult result = check(R"( + function f(arg : { prop : string & number }) : boolean + return arg + end + )"); + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_CASE_FIXTURE(TryUnifyFixture, "members_of_failed_typepack_unification_are_unified_with_errorType") { CheckResult result = check(R"( @@ -299,4 +361,19 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "metatables_unify_against_shape_of_free_table CHECK_EQ(toString(state.errors[0]), expected); } +TEST_CASE_FIXTURE(TryUnifyFixture, "fuzz_tail_unification_issue") +{ + ScopedFastFlag luauTxnLogTypePackIterator{"LuauTxnLogTypePackIterator", true}; + + TypePackVar variadicAny{VariadicTypePack{typeChecker.anyType}}; + TypePackVar packTmp{TypePack{{typeChecker.anyType}, &variadicAny}}; + TypePackVar packSub{TypePack{{typeChecker.anyType, typeChecker.anyType}, &packTmp}}; + + TypeVar freeTy{FreeTypeVar{TypeLevel{}}}; + TypePackVar freeTp{FreeTypePack{TypeLevel{}}}; + TypePackVar packSuper{TypePack{{&freeTy}, &freeTp}}; + + state.tryUnify(&packSub, &packSuper); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.typePacks.cpp b/tests/TypeInfer.typePacks.cpp index cde651dfe..0e4074f75 100644 --- a/tests/TypeInfer.typePacks.cpp +++ b/tests/TypeInfer.typePacks.cpp @@ -7,8 +7,6 @@ #include "doctest.h" -LUAU_FASTFLAG(LuauFunctionReturnStringificationFixup); - using namespace Luau; TEST_SUITE_BEGIN("TypePackTests"); @@ -311,10 +309,7 @@ local c: Packed auto ttvA = get(requireType("a")); REQUIRE(ttvA); CHECK_EQ(toString(requireType("a")), "Packed"); - if (FFlag::LuauFunctionReturnStringificationFixup) - CHECK_EQ(toString(requireType("a"), {true}), "{| f: (number) -> number |}"); - else - CHECK_EQ(toString(requireType("a"), {true}), "{| f: (number) -> (number) |}"); + CHECK_EQ(toString(requireType("a"), {true}), "{| f: (number) -> number |}"); REQUIRE(ttvA->instantiatedTypeParams.size() == 1); REQUIRE(ttvA->instantiatedTypePackParams.size() == 1); CHECK_EQ(toString(ttvA->instantiatedTypeParams[0], {true}), "number"); @@ -467,8 +462,6 @@ type I = W TEST_CASE_FIXTURE(Fixture, "type_alias_type_pack_explicit") { - ScopedFastFlag sff("LuauFunctionReturnStringificationFixup", true); - CheckResult result = check(R"( type X = (T...) -> (T...) @@ -492,8 +485,6 @@ type F = X<(string, ...number)> TEST_CASE_FIXTURE(Fixture, "type_alias_type_pack_explicit_multi") { - ScopedFastFlag sff("LuauFunctionReturnStringificationFixup", true); - CheckResult result = check(R"( type Y = (T...) -> (U...) @@ -1002,6 +993,10 @@ TEST_CASE_FIXTURE(Fixture, "unify_variadic_tails_in_arguments_free") TEST_CASE_FIXTURE(BuiltinsFixture, "type_packs_with_tails_in_vararg_adjustment") { + std::optional sff; + if (FFlag::DebugLuauDeferredConstraintResolution) + sff = {"LuauInstantiateInSubtyping", true}; + CheckResult result = check(R"( local function wrapReject(fn: (self: any, ...TArg) -> ...TResult): (self: any, ...TArg) -> ...TResult return function(self, ...) @@ -1017,4 +1012,46 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "type_packs_with_tails_in_vararg_adjustment") LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(BuiltinsFixture, "generalize_expectedTypes_with_proper_scope") +{ + ScopedFastFlag sff[] = { + {"DebugLuauDeferredConstraintResolution", true}, + {"LuauInstantiateInSubtyping", true}, + }; + + CheckResult result = check(R"( + local function f(fn: () -> ...TResult): () -> ...TResult + return function() + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "fuzz_typepack_iter_follow") +{ + ScopedFastFlag luauTxnLogTypePackIterator{"LuauTxnLogTypePackIterator", true}; + + CheckResult result = check(R"( +local _ +local _ = _,_(),_(_) + )"); + + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "fuzz_typepack_iter_follow_2") +{ + ScopedFastFlag luauTxnLogTypePackIterator{"LuauTxnLogTypePackIterator", true}; + + CheckResult result = check(R"( +function test(name, searchTerm) + local found = string.find(name:lower(), searchTerm:lower()) +end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp index 627fbb566..adfc61b63 100644 --- a/tests/TypeInfer.unionTypes.test.cpp +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -729,4 +729,37 @@ TEST_CASE_FIXTURE(Fixture, "union_of_functions_with_mismatching_result_variadics "of the union options are compatible"); } +TEST_CASE_FIXTURE(Fixture, "less_greedy_unification_with_union_types") +{ + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + + CheckResult result = check(R"( + local function f(t): { x: number } | { x: string } + local x = t.x + return t + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("({| x: number |} | {| x: string |}) -> {| x: number |} | {| x: string |}", toString(requireType("f"))); +} + +TEST_CASE_FIXTURE(Fixture, "less_greedy_unification_with_union_types_2") +{ + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + + CheckResult result = check(R"( + local function f(t: { x: number } | { x: string }) + return t.x + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("({| x: number |} | {| x: string |}) -> number | string", toString(requireType("f"))); +} + TEST_SUITE_END(); diff --git a/tests/VisitTypeVar.test.cpp b/tests/VisitTypeVar.test.cpp index 589c3bad5..4fba694a8 100644 --- a/tests/VisitTypeVar.test.cpp +++ b/tests/VisitTypeVar.test.cpp @@ -22,14 +22,7 @@ TEST_CASE_FIXTURE(Fixture, "throw_when_limit_is_exceeded") TypeId tType = requireType("t"); - if (FFlag::LuauIceExceptionInheritanceChange) - { - CHECK_THROWS_AS(toString(tType), RecursionLimitException); - } - else - { - CHECK_THROWS_AS(toString(tType), RecursionLimitException_DEPRECATED); - } + CHECK_THROWS_AS(toString(tType), RecursionLimitException); } TEST_CASE_FIXTURE(Fixture, "dont_throw_when_limit_is_high_enough") diff --git a/tools/faillist.txt b/tools/faillist.txt index 433d0cfbe..6f49db84c 100644 --- a/tools/faillist.txt +++ b/tools/faillist.txt @@ -14,9 +14,11 @@ AstQuery::getDocumentationSymbolAtPosition.overloaded_class_method AstQuery::getDocumentationSymbolAtPosition.overloaded_fn AstQuery::getDocumentationSymbolAtPosition.table_overloaded_function_prop AutocompleteTest.autocomplete_first_function_arg_expected_type -AutocompleteTest.autocomplete_interpolated_string +AutocompleteTest.autocomplete_interpolated_string_as_singleton +AutocompleteTest.autocomplete_interpolated_string_constant +AutocompleteTest.autocomplete_interpolated_string_expression +AutocompleteTest.autocomplete_interpolated_string_expression_with_comments AutocompleteTest.autocomplete_oop_implicit_self -AutocompleteTest.autocomplete_string_singleton_equality AutocompleteTest.autocomplete_string_singleton_escape AutocompleteTest.autocomplete_string_singletons AutocompleteTest.autocompleteProp_index_function_metamethod_is_variadic @@ -25,9 +27,11 @@ AutocompleteTest.do_wrong_compatible_self_calls AutocompleteTest.keyword_methods AutocompleteTest.no_incompatible_self_calls AutocompleteTest.no_wrong_compatible_self_calls_with_generics +AutocompleteTest.suggest_external_module_type AutocompleteTest.suggest_table_keys AutocompleteTest.type_correct_argument_type_suggestion AutocompleteTest.type_correct_expected_argument_type_pack_suggestion +AutocompleteTest.type_correct_expected_argument_type_suggestion AutocompleteTest.type_correct_expected_argument_type_suggestion_optional AutocompleteTest.type_correct_expected_argument_type_suggestion_self AutocompleteTest.type_correct_expected_return_type_pack_suggestion @@ -68,7 +72,6 @@ BuiltinTests.select_with_decimal_argument_is_rounded_down BuiltinTests.set_metatable_needs_arguments BuiltinTests.setmetatable_should_not_mutate_persisted_types BuiltinTests.sort_with_bad_predicate -BuiltinTests.string_format_arg_count_mismatch BuiltinTests.string_format_as_method BuiltinTests.string_format_correctly_ordered_types BuiltinTests.string_format_report_all_type_errors_at_correct_positions @@ -106,10 +109,10 @@ GenericsTests.generic_factories GenericsTests.generic_functions_should_be_memory_safe GenericsTests.generic_table_method GenericsTests.generic_type_pack_parentheses -GenericsTests.generic_type_pack_unification2 GenericsTests.higher_rank_polymorphism_should_not_accept_instantiated_arguments GenericsTests.infer_generic_function_function_argument GenericsTests.infer_generic_function_function_argument_overloaded +GenericsTests.infer_generic_lib_function_function_argument GenericsTests.infer_generic_methods GenericsTests.infer_generic_property GenericsTests.instantiated_function_argument_names @@ -117,9 +120,6 @@ GenericsTests.instantiation_sharing_types GenericsTests.no_stack_overflow_from_quantifying GenericsTests.reject_clashing_generic_and_pack_names GenericsTests.self_recursive_instantiated_param -IntersectionTypes.index_on_an_intersection_type_with_mixed_types -IntersectionTypes.index_on_an_intersection_type_with_property_guaranteed_to_exist -IntersectionTypes.index_on_an_intersection_type_works_at_arbitrary_depth IntersectionTypes.no_stack_overflow_from_flattenintersection IntersectionTypes.select_correct_union_fn IntersectionTypes.should_still_pick_an_overload_whose_arguments_are_unions @@ -151,6 +151,7 @@ ProvisionalTests.bail_early_if_unification_is_too_complicated ProvisionalTests.discriminate_from_x_not_equal_to_nil ProvisionalTests.do_not_ice_when_trying_to_pick_first_of_generic_type_pack ProvisionalTests.error_on_eq_metamethod_returning_a_type_other_than_boolean +ProvisionalTests.free_options_cannot_be_unified_together ProvisionalTests.generic_type_leak_to_module_interface_variadic ProvisionalTests.greedy_inference_with_shared_self_triggers_function_with_no_returns ProvisionalTests.lvalue_equals_another_lvalue_with_no_overlap @@ -164,15 +165,15 @@ ProvisionalTests.while_body_are_also_refined RefinementTest.apply_refinements_on_astexprindexexpr_whose_subscript_expr_is_constant_string RefinementTest.assert_a_to_be_truthy_then_assert_a_to_be_number RefinementTest.assert_non_binary_expressions_actually_resolve_constraints +RefinementTest.assign_table_with_refined_property_with_a_similar_type_is_illegal RefinementTest.call_an_incompatible_function_after_using_typeguard RefinementTest.correctly_lookup_property_whose_base_was_previously_refined RefinementTest.correctly_lookup_property_whose_base_was_previously_refined2 -RefinementTest.discriminate_from_isa_of_x -RefinementTest.discriminate_from_truthiness_of_x RefinementTest.discriminate_on_properties_of_disjoint_tables_where_that_property_is_true_or_false RefinementTest.discriminate_tag RefinementTest.else_with_no_explicit_expression_should_also_refine_the_tagged_union RefinementTest.falsiness_of_TruthyPredicate_narrows_into_nil +RefinementTest.fuzz_filtered_refined_types_are_followed RefinementTest.index_on_a_refined_property RefinementTest.invert_is_truthy_constraint_ifelse_expression RefinementTest.is_truthy_constraint_ifelse_expression @@ -181,7 +182,6 @@ RefinementTest.nonoptional_type_can_narrow_to_nil_if_sense_is_true RefinementTest.not_t_or_some_prop_of_t RefinementTest.refine_a_property_not_to_be_nil_through_an_intersection_table RefinementTest.refine_unknowns -RefinementTest.truthy_constraint_on_properties RefinementTest.type_comparison_ifelse_expression RefinementTest.type_guard_can_filter_for_intersection_of_tables RefinementTest.type_guard_narrowed_into_nothingness @@ -210,7 +210,6 @@ TableTests.defining_a_self_method_for_a_builtin_sealed_table_must_fail TableTests.defining_a_self_method_for_a_local_sealed_table_must_fail TableTests.dont_crash_when_setmetatable_does_not_produce_a_metatabletypevar TableTests.dont_hang_when_trying_to_look_up_in_cyclic_metatable_index -TableTests.dont_leak_free_table_props TableTests.dont_quantify_table_that_belongs_to_outer_scope TableTests.dont_suggest_exact_match_keys TableTests.error_detailed_metatable_prop @@ -240,6 +239,7 @@ TableTests.invariant_table_properties_means_instantiating_tables_in_assignment_i TableTests.invariant_table_properties_means_instantiating_tables_in_call_is_unsound TableTests.leaking_bad_metatable_errors TableTests.less_exponential_blowup_please +TableTests.meta_add TableTests.meta_add_both_ways TableTests.meta_add_inferred TableTests.metatable_mismatch_should_fail @@ -252,8 +252,6 @@ TableTests.only_ascribe_synthetic_names_at_module_scope TableTests.oop_indexer_works TableTests.oop_polymorphic TableTests.open_table_unification_2 -TableTests.pass_a_union_of_tables_to_a_function_that_requires_a_table -TableTests.pass_a_union_of_tables_to_a_function_that_requires_a_table_2 TableTests.persistent_sealed_table_is_immutable TableTests.prop_access_on_key_whose_types_mismatches TableTests.property_lookup_through_tabletypevar_metatable @@ -268,6 +266,7 @@ TableTests.scalar_is_not_a_subtype_of_a_compatible_polymorphic_shape_type TableTests.shared_selfs TableTests.shared_selfs_from_free_param TableTests.shared_selfs_through_metatables +TableTests.table_call_metamethod_basic TableTests.table_indexing_error_location TableTests.table_insert_should_cope_with_optional_properties_in_nonstrict TableTests.table_insert_should_cope_with_optional_properties_in_strict @@ -323,6 +322,7 @@ TypeInfer.checking_should_not_ice TypeInfer.cli_50041_committing_txnlog_in_apollo_client_error TypeInfer.dont_report_type_errors_within_an_AstExprError TypeInfer.dont_report_type_errors_within_an_AstStatError +TypeInfer.fuzz_free_table_type_change_during_index_check TypeInfer.globals TypeInfer.globals2 TypeInfer.infer_assignment_value_types_mutable_lval @@ -335,7 +335,6 @@ TypeInfer.tc_interpolated_string_with_invalid_expression TypeInfer.type_infer_recursion_limit_no_ice TypeInfer.type_infer_recursion_limit_normalizer TypeInferAnyError.for_in_loop_iterator_is_any2 -TypeInferAnyError.for_in_loop_iterator_is_error2 TypeInferClasses.can_read_prop_of_base_class_using_string TypeInferClasses.class_type_mismatch_with_name_conflict TypeInferClasses.classes_without_overloaded_operators_cannot_be_added @@ -351,8 +350,6 @@ TypeInferFunctions.cannot_hoist_interior_defns_into_signature TypeInferFunctions.dont_give_other_overloads_message_if_only_one_argument_matching_overload_exists TypeInferFunctions.dont_infer_parameter_types_for_functions_from_their_call_site TypeInferFunctions.duplicate_functions_with_different_signatures_not_allowed_in_nonstrict -TypeInferFunctions.free_is_not_bound_to_unknown -TypeInferFunctions.func_expr_doesnt_leak_free TypeInferFunctions.function_cast_error_uses_correct_language TypeInferFunctions.function_decl_non_self_sealed_overwrite_2 TypeInferFunctions.function_decl_non_self_unsealed_overwrite @@ -385,12 +382,11 @@ TypeInferLoops.for_in_loop TypeInferLoops.for_in_loop_error_on_factory_not_returning_the_right_amount_of_values TypeInferLoops.for_in_loop_with_next TypeInferLoops.for_in_with_generic_next -TypeInferLoops.for_in_with_just_one_iterator_is_ok TypeInferLoops.loop_iter_metamethod_ok_with_inference TypeInferLoops.loop_iter_no_indexer_nonstrict TypeInferLoops.loop_iter_trailing_nil +TypeInferLoops.properly_infer_iteratee_is_a_free_table TypeInferLoops.unreachable_code_after_infinite_loop -TypeInferLoops.varlist_declared_by_for_in_loop_should_be_free TypeInferModules.custom_require_global TypeInferModules.do_not_modify_imported_types TypeInferModules.module_type_conflict @@ -414,6 +410,8 @@ TypeInferOperators.compound_assign_mismatch_result TypeInferOperators.disallow_string_and_types_without_metatables_from_arithmetic_binary_ops TypeInferOperators.in_nonstrict_mode_strip_nil_from_intersections_when_considering_relational_operators TypeInferOperators.infer_any_in_all_modes_when_lhs_is_unknown +TypeInferOperators.mm_comparisons_must_return_a_boolean +TypeInferOperators.mm_ops_must_return_a_value TypeInferOperators.produce_the_correct_error_message_when_comparing_a_table_with_a_metatable_with_one_that_does_not TypeInferOperators.refine_and_or TypeInferOperators.typecheck_overloaded_multiply_that_is_an_intersection @@ -427,16 +425,13 @@ TypeInferUnknownNever.assign_to_prop_which_is_never TypeInferUnknownNever.assign_to_subscript_which_is_never TypeInferUnknownNever.call_never TypeInferUnknownNever.dont_unify_operands_if_one_of_the_operand_is_never_in_any_ordering_operators -TypeInferUnknownNever.index_on_union_of_tables_for_properties_that_is_never TypeInferUnknownNever.index_on_union_of_tables_for_properties_that_is_sorta_never TypeInferUnknownNever.math_operators_and_never TypeInferUnknownNever.type_packs_containing_never_is_itself_uninhabitable TypeInferUnknownNever.type_packs_containing_never_is_itself_uninhabitable2 TypeInferUnknownNever.unary_minus_of_never TypePackTests.detect_cyclic_typepacks2 -TypePackTests.higher_order_function TypePackTests.pack_tail_unification_check -TypePackTests.parenthesized_varargs_returns_any TypePackTests.type_alias_backwards_compatible TypePackTests.type_alias_default_export TypePackTests.type_alias_default_mixed_self @@ -456,7 +451,6 @@ TypePackTests.type_alias_type_packs_nested TypePackTests.type_pack_type_parameters TypePackTests.unify_variadic_tails_in_arguments TypePackTests.unify_variadic_tails_in_arguments_free -TypePackTests.varargs_inference_through_multiple_scopes TypePackTests.variadic_packs TypeSingletons.error_detailed_tagged_union_mismatch_bool TypeSingletons.error_detailed_tagged_union_mismatch_string @@ -477,11 +471,8 @@ TypeSingletons.widening_happens_almost_everywhere_except_for_tables UnionTypes.error_detailed_optional UnionTypes.error_detailed_union_all UnionTypes.index_on_a_union_type_with_missing_property -UnionTypes.index_on_a_union_type_with_mixed_types UnionTypes.index_on_a_union_type_with_one_optional_property UnionTypes.index_on_a_union_type_with_one_property_of_type_any -UnionTypes.index_on_a_union_type_with_property_guaranteed_to_exist -UnionTypes.index_on_a_union_type_works_at_arbitrary_depth UnionTypes.optional_assignment_errors UnionTypes.optional_call_error UnionTypes.optional_field_access_error From f10b294d6297fbf1b24cc82b7fa5f76dfb444115 Mon Sep 17 00:00:00 2001 From: Vyacheslav Egorov Date: Fri, 2 Dec 2022 15:46:09 +0200 Subject: [PATCH 19/66] What even is this --- .github/workflows/build.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index dbd6a495d..24a763b7c 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -27,6 +27,7 @@ jobs: - uses: actions/checkout@v1 - name: make tests run: | + g++ --version make -j2 config=sanitize werror=1 native=1 luau-tests - name: run tests run: | From 6cd507dff0f97f9c40afa909fa47531e793db26c Mon Sep 17 00:00:00 2001 From: Vyacheslav Egorov Date: Fri, 2 Dec 2022 18:22:01 +0200 Subject: [PATCH 20/66] Work-around for gcc --- .github/workflows/build.yml | 1 - Analysis/src/ConstraintGraphBuilder.cpp | 4 +++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 24a763b7c..dbd6a495d 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -27,7 +27,6 @@ jobs: - uses: actions/checkout@v1 - name: make tests run: | - g++ --version make -j2 config=sanitize werror=1 native=1 luau-tests - name: run tests run: | diff --git a/Analysis/src/ConstraintGraphBuilder.cpp b/Analysis/src/ConstraintGraphBuilder.cpp index 600b6d23d..d41c77723 100644 --- a/Analysis/src/ConstraintGraphBuilder.cpp +++ b/Analysis/src/ConstraintGraphBuilder.cpp @@ -1807,7 +1807,9 @@ ConstraintGraphBuilder::FunctionSignature ConstraintGraphBuilder::checkFunctionS signatureScope->privateTypePackBindings[name] = g.tp; } - expectedType.reset(); + // Local variable works around an odd gcc 11.3 warning: may be used uninitialized + std::optional none = std::nullopt; + expectedType = none; } std::vector argTypes; From abe6768a1dcff5bea386a421f4807c33f5197993 Mon Sep 17 00:00:00 2001 From: Andy Friesen Date: Fri, 9 Dec 2022 10:07:25 -0800 Subject: [PATCH 21/66] Sync to upstream/release/556 --- Analysis/include/Luau/BuiltinDefinitions.h | 9 +- Analysis/include/Luau/Constraint.h | 7 + .../include/Luau/ConstraintGraphBuilder.h | 12 + Analysis/include/Luau/ConstraintSolver.h | 9 +- Analysis/include/Luau/Error.h | 6 +- Analysis/include/Luau/Linter.h | 1 + Analysis/include/Luau/Normalize.h | 22 +- Analysis/include/Luau/Scope.h | 1 - Analysis/include/Luau/ToString.h | 26 +- Analysis/include/Luau/TxnLog.h | 6 +- Analysis/include/Luau/TypeUtils.h | 4 +- Analysis/include/Luau/TypeVar.h | 7 +- Analysis/src/AstJsonEncoder.cpp | 27 ++- Analysis/src/AstQuery.cpp | 22 +- Analysis/src/Autocomplete.cpp | 47 +++- Analysis/src/BuiltinDefinitions.cpp | 13 + Analysis/src/ConstraintGraphBuilder.cpp | 48 ++-- Analysis/src/ConstraintSolver.cpp | 44 ++-- Analysis/src/Error.cpp | 1 + Analysis/src/Frontend.cpp | 29 +-- Analysis/src/Normalize.cpp | 195 ++++++--------- Analysis/src/ToString.cpp | 13 + Analysis/src/TxnLog.cpp | 1 + Analysis/src/TypeChecker2.cpp | 110 ++++++++- Analysis/src/TypeInfer.cpp | 4 +- Analysis/src/TypeUtils.cpp | 97 -------- Analysis/src/TypeVar.cpp | 7 - Analysis/src/Unifier.cpp | 11 +- Ast/include/Luau/Ast.h | 1 + Ast/include/Luau/Location.h | 136 ++--------- Ast/src/Location.cpp | 122 +++++++++- CLI/Ast.cpp | 1 + CodeGen/src/CodeGen.cpp | 100 ++++---- CodeGen/src/EmitCommonX64.cpp | 34 +-- CodeGen/src/EmitCommonX64.h | 22 +- CodeGen/src/EmitInstructionX64.cpp | 226 +++++++++--------- CodeGen/src/EmitInstructionX64.h | 64 ++--- tests/AstJsonEncoder.test.cpp | 7 + tests/Autocomplete.test.cpp | 65 +++++ tests/Fixture.cpp | 48 +++- tests/Fixture.h | 49 +--- tests/Frontend.test.cpp | 4 - tests/Normalize.test.cpp | 1 - tests/ToString.test.cpp | 12 +- tests/TypeInfer.aliases.test.cpp | 14 +- tests/TypeInfer.intersectionTypes.test.cpp | 2 +- tests/TypeInfer.negations.test.cpp | 3 +- tests/TypeInfer.operators.test.cpp | 86 +++++-- tests/TypeInfer.provisional.test.cpp | 30 +++ tests/TypeInfer.refinements.test.cpp | 82 +++---- tests/TypeInfer.tables.test.cpp | 69 ++---- tests/TypeInfer.tryUnify.test.cpp | 4 +- tools/faillist.txt | 58 +++-- 53 files changed, 1100 insertions(+), 919 deletions(-) diff --git a/Analysis/include/Luau/BuiltinDefinitions.h b/Analysis/include/Luau/BuiltinDefinitions.h index 4702995d4..16cccafe9 100644 --- a/Analysis/include/Luau/BuiltinDefinitions.h +++ b/Analysis/include/Luau/BuiltinDefinitions.h @@ -1,13 +1,18 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once -#include "Luau/Frontend.h" #include "Luau/Scope.h" -#include "Luau/TypeInfer.h" +#include "Luau/TypeVar.h" + +#include namespace Luau { +struct Frontend; +struct TypeChecker; +struct TypeArena; + void registerBuiltinTypes(Frontend& frontend); void registerBuiltinGlobals(TypeChecker& typeChecker); diff --git a/Analysis/include/Luau/Constraint.h b/Analysis/include/Luau/Constraint.h index e13613ed8..9eea9c288 100644 --- a/Analysis/include/Luau/Constraint.h +++ b/Analysis/include/Luau/Constraint.h @@ -3,6 +3,7 @@ #include "Luau/Ast.h" // Used for some of the enumerations #include "Luau/Def.h" +#include "Luau/DenseHash.h" #include "Luau/NotNull.h" #include "Luau/TypeVar.h" #include "Luau/Variant.h" @@ -67,6 +68,12 @@ struct BinaryConstraint TypeId leftType; TypeId rightType; TypeId resultType; + + // When we dispatch this constraint, we update the key at this map to record + // the overload that we selected. + AstExpr* expr; + DenseHashMap* astOriginalCallTypes; + DenseHashMap* astOverloadResolvedTypes; }; // iteratee is iterable diff --git a/Analysis/include/Luau/ConstraintGraphBuilder.h b/Analysis/include/Luau/ConstraintGraphBuilder.h index d81fe9189..c25a55371 100644 --- a/Analysis/include/Luau/ConstraintGraphBuilder.h +++ b/Analysis/include/Luau/ConstraintGraphBuilder.h @@ -76,15 +76,27 @@ struct ConstraintGraphBuilder // A mapping of AST node to TypeId. DenseHashMap astTypes{nullptr}; + // A mapping of AST node to TypePackId. DenseHashMap astTypePacks{nullptr}; + + // If the node was applied as a function, this is the unspecialized type of + // that expression. DenseHashMap astOriginalCallTypes{nullptr}; + + // If overload resolution was performed on this element, this is the + // overload that was selected. + DenseHashMap astOverloadResolvedTypes{nullptr}; + // Types resolved from type annotations. Analogous to astTypes. DenseHashMap astResolvedTypes{nullptr}; + // Type packs resolved from type annotations. Analogous to astTypePacks. DenseHashMap astResolvedTypePacks{nullptr}; + // Defining scopes for AST nodes. DenseHashMap astTypeAliasDefiningScopes{nullptr}; + NotNull dfg; ConnectiveArena connectiveArena; diff --git a/Analysis/include/Luau/ConstraintSolver.h b/Analysis/include/Luau/ConstraintSolver.h index e05f6f1f4..c02cd4d5c 100644 --- a/Analysis/include/Luau/ConstraintSolver.h +++ b/Analysis/include/Luau/ConstraintSolver.h @@ -2,12 +2,13 @@ #pragma once -#include "Luau/Error.h" -#include "Luau/Variant.h" #include "Luau/Constraint.h" -#include "Luau/TypeVar.h" -#include "Luau/ToString.h" +#include "Luau/Error.h" +#include "Luau/Module.h" #include "Luau/Normalize.h" +#include "Luau/ToString.h" +#include "Luau/TypeVar.h" +#include "Luau/Variant.h" #include diff --git a/Analysis/include/Luau/Error.h b/Analysis/include/Luau/Error.h index 893880464..739354f87 100644 --- a/Analysis/include/Luau/Error.h +++ b/Analysis/include/Luau/Error.h @@ -1,16 +1,16 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once -#include "Luau/FileResolver.h" #include "Luau/Location.h" #include "Luau/TypeVar.h" #include "Luau/Variant.h" -#include "Luau/TypeArena.h" namespace Luau { -struct TypeError; +struct FileResolver; +struct TypeArena; +struct TypeError; struct TypeMismatch { diff --git a/Analysis/include/Luau/Linter.h b/Analysis/include/Luau/Linter.h index 0e3d98803..6bbc3d660 100644 --- a/Analysis/include/Luau/Linter.h +++ b/Analysis/include/Luau/Linter.h @@ -4,6 +4,7 @@ #include "Luau/Location.h" #include +#include #include namespace Luau diff --git a/Analysis/include/Luau/Normalize.h b/Analysis/include/Luau/Normalize.h index d7e104ee5..392573155 100644 --- a/Analysis/include/Luau/Normalize.h +++ b/Analysis/include/Luau/Normalize.h @@ -42,6 +42,7 @@ class TypeIds void retain(const TypeIds& tys); void clear(); + TypeId front() const; iterator begin(); iterator end(); const_iterator begin() const; @@ -107,18 +108,7 @@ namespace Luau /** A normalized string type is either `string` (represented by `nullopt`) or a * union of string singletons. * - * When FFlagLuauNegatedStringSingletons is unset, the representation is as - * follows: - * - * * The `string` data type is represented by the option `singletons` having the - * value `std::nullopt`. - * * The type `never` is represented by `singletons` being populated with an - * empty map. - * * A union of string singletons is represented by a map populated by the names - * and TypeIds of the singletons contained therein. - * - * When FFlagLuauNegatedStringSingletons is set, the representation is as - * follows: + * The representation is as follows: * * * A union of string singletons is finite and includes the singletons named by * the `singletons` field. @@ -138,9 +128,7 @@ struct NormalizedStringType // eg string & ~"a" & ~"b" & ... bool isCofinite = false; - // TODO: This field cannot be nullopt when FFlagLuauNegatedStringSingletons - // is set. When clipping that flag, we can remove the wrapping optional. - std::optional> singletons; + std::map singletons; void resetToString(); void resetToNever(); @@ -161,8 +149,8 @@ struct NormalizedStringType static const NormalizedStringType never; - NormalizedStringType() = default; - NormalizedStringType(bool isCofinite, std::optional> singletons); + NormalizedStringType(); + NormalizedStringType(bool isCofinite, std::map singletons); }; bool isSubtype(const NormalizedStringType& subStr, const NormalizedStringType& superStr); diff --git a/Analysis/include/Luau/Scope.h b/Analysis/include/Luau/Scope.h index a26f506d6..851ed1a7d 100644 --- a/Analysis/include/Luau/Scope.h +++ b/Analysis/include/Luau/Scope.h @@ -1,7 +1,6 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once -#include "Luau/Constraint.h" #include "Luau/Location.h" #include "Luau/NotNull.h" #include "Luau/TypeVar.h" diff --git a/Analysis/include/Luau/ToString.h b/Analysis/include/Luau/ToString.h index 186cc9a5b..71c0e3595 100644 --- a/Analysis/include/Luau/ToString.h +++ b/Analysis/include/Luau/ToString.h @@ -2,13 +2,12 @@ #pragma once #include "Luau/Common.h" -#include "Luau/TypeVar.h" -#include "Luau/ConstraintGraphBuilder.h" -#include -#include #include +#include #include +#include +#include LUAU_FASTINT(LuauTableTypeMaximumStringifierLength) LUAU_FASTINT(LuauTypeMaximumStringifierLength) @@ -16,6 +15,22 @@ LUAU_FASTINT(LuauTypeMaximumStringifierLength) namespace Luau { +class AstExpr; + +struct Scope; + +struct TypeVar; +using TypeId = const TypeVar*; + +struct TypePackVar; +using TypePackId = const TypePackVar*; + +struct FunctionTypeVar; +struct Constraint; + +struct Position; +struct Location; + struct ToStringNameMap { std::unordered_map typeVars; @@ -125,4 +140,7 @@ std::string dump(const std::shared_ptr& scope, const char* name); std::string generateName(size_t n); +std::string toString(const Position& position); +std::string toString(const Location& location); + } // namespace Luau diff --git a/Analysis/include/Luau/TxnLog.h b/Analysis/include/Luau/TxnLog.h index b1a834126..82605bff7 100644 --- a/Analysis/include/Luau/TxnLog.h +++ b/Analysis/include/Luau/TxnLog.h @@ -1,12 +1,12 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once -#include -#include - #include "Luau/TypeVar.h" #include "Luau/TypePack.h" +#include +#include + namespace Luau { diff --git a/Analysis/include/Luau/TypeUtils.h b/Analysis/include/Luau/TypeUtils.h index aa9cdde2a..6ed70f468 100644 --- a/Analysis/include/Luau/TypeUtils.h +++ b/Analysis/include/Luau/TypeUtils.h @@ -4,6 +4,7 @@ #include "Luau/Error.h" #include "Luau/Location.h" #include "Luau/TypeVar.h" +#include "Luau/TypePack.h" #include #include @@ -12,6 +13,7 @@ namespace Luau { struct TxnLog; +struct TypeArena; using ScopePtr = std::shared_ptr; @@ -19,8 +21,6 @@ std::optional findMetatableEntry( NotNull singletonTypes, ErrorVec& errors, TypeId type, const std::string& entry, Location location); std::optional findTablePropertyRespectingMeta( NotNull singletonTypes, ErrorVec& errors, TypeId ty, const std::string& name, Location location); -std::optional getIndexTypeFromType(const ScopePtr& scope, ErrorVec& errors, TypeArena* arena, NotNull singletonTypes, - TypeId type, const std::string& prop, const Location& location, bool addErrors, InternalErrorReporter& handle); // Returns the minimum and maximum number of types the argument list can accept. std::pair> getParameterExtents(const TxnLog* log, TypePackId tp, bool includeHiddenVariadics = false); diff --git a/Analysis/include/Luau/TypeVar.h b/Analysis/include/Luau/TypeVar.h index d355746a5..852a40547 100644 --- a/Analysis/include/Luau/TypeVar.h +++ b/Analysis/include/Luau/TypeVar.h @@ -264,12 +264,14 @@ using DcrMagicFunction = bool (*)(MagicFunctionCallContext); struct MagicRefinementContext { ScopePtr scope; + NotNull cgb; NotNull dfg; NotNull connectiveArena; + std::vector argumentConnectives; const class AstExprCall* callSite; }; -using DcrMagicRefinement = std::vector (*)(MagicRefinementContext); +using DcrMagicRefinement = std::vector (*)(const MagicRefinementContext&); struct FunctionTypeVar { @@ -666,9 +668,6 @@ struct SingletonTypes const TypePackId errorTypePack; }; -// Clip with FFlagLuauNoMoreGlobalSingletonTypes -SingletonTypes& DEPRECATED_getSingletonTypes(); - void persist(TypeId ty); void persist(TypePackId tp); diff --git a/Analysis/src/AstJsonEncoder.cpp b/Analysis/src/AstJsonEncoder.cpp index 8d589037f..57c8c90b4 100644 --- a/Analysis/src/AstJsonEncoder.cpp +++ b/Analysis/src/AstJsonEncoder.cpp @@ -6,6 +6,8 @@ #include "Luau/StringUtils.h" #include "Luau/Common.h" +#include + namespace Luau { @@ -103,9 +105,28 @@ struct AstJsonEncoder : public AstVisitor void write(double d) { - char b[32]; - snprintf(b, sizeof(b), "%.17g", d); - writeRaw(b); + switch (fpclassify(d)) + { + case FP_INFINITE: + if (d < 0) + writeRaw("-Infinity"); + else + writeRaw("Infinity"); + break; + + case FP_NAN: + writeRaw("NaN"); + break; + + case FP_NORMAL: + case FP_SUBNORMAL: + case FP_ZERO: + default: + char b[32]; + snprintf(b, sizeof(b), "%.17g", d); + writeRaw(b); + break; + } } void writeString(std::string_view sv) diff --git a/Analysis/src/AstQuery.cpp b/Analysis/src/AstQuery.cpp index 85d2320ae..e6e7f3d93 100644 --- a/Analysis/src/AstQuery.cpp +++ b/Analysis/src/AstQuery.cpp @@ -11,6 +11,8 @@ #include +LUAU_FASTFLAG(LuauCompleteTableKeysBetter); + namespace Luau { @@ -29,12 +31,24 @@ struct AutocompleteNodeFinder : public AstVisitor bool visit(AstExpr* expr) override { - if (expr->location.begin < pos && pos <= expr->location.end) + if (FFlag::LuauCompleteTableKeysBetter) { - ancestry.push_back(expr); - return true; + if (expr->location.begin <= pos && pos <= expr->location.end) + { + ancestry.push_back(expr); + return true; + } + return false; + } + else + { + if (expr->location.begin < pos && pos <= expr->location.end) + { + ancestry.push_back(expr); + return true; + } + return false; } - return false; } bool visit(AstStat* stat) override diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index 5374c6b17..83a6f0217 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -12,6 +12,8 @@ #include #include +LUAU_FASTFLAGVARIABLE(LuauCompleteTableKeysBetter, false); + static const std::unordered_set kStatementStartingKeywords = { "while", "if", "local", "repeat", "function", "do", "for", "return", "break", "continue", "type", "export"}; @@ -966,13 +968,28 @@ T* extractStat(const std::vector& ancestry) if (!parent) return nullptr; - if (T* t = parent->as(); t && parent->is()) - return t; - AstNode* grandParent = ancestry.size() >= 3 ? ancestry.rbegin()[2] : nullptr; AstNode* greatGrandParent = ancestry.size() >= 4 ? ancestry.rbegin()[3] : nullptr; - if (!grandParent || !greatGrandParent) - return nullptr; + + if (FFlag::LuauCompleteTableKeysBetter) + { + if (!grandParent) + return nullptr; + + if (T* t = parent->as(); t && grandParent->is()) + return t; + + if (!greatGrandParent) + return nullptr; + } + else + { + if (T* t = parent->as(); t && parent->is()) + return t; + + if (!grandParent || !greatGrandParent) + return nullptr; + } if (T* t = greatGrandParent->as(); t && grandParent->is() && parent->is() && isIdentifier(node)) return t; @@ -1469,6 +1486,26 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M { auto result = autocompleteProps(*module, &typeArena, singletonTypes, *it, PropIndexType::Key, ancestry); + if (FFlag::LuauCompleteTableKeysBetter) + { + if (auto nodeIt = module->astExpectedTypes.find(node->asExpr())) + autocompleteStringSingleton(*nodeIt, !node->is(), result); + + if (!key) + { + // If there is "no key," it may be that the user + // intends for the current token to be the key, but + // has yet to type the `=` sign. + // + // If the key type is a union of singleton strings, + // suggest those too. + if (auto ttv = get(follow(*it)); ttv && ttv->indexer) + { + autocompleteStringSingleton(ttv->indexer->indexType, false, result); + } + } + } + // Remove keys that are already completed for (const auto& item : exprTable->items) { diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index 39568674c..612812c56 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -7,6 +7,7 @@ #include "Luau/Common.h" #include "Luau/ToString.h" #include "Luau/ConstraintSolver.h" +#include "Luau/ConstraintGraphBuilder.h" #include "Luau/TypeInfer.h" #include "Luau/TypePack.h" #include "Luau/TypeVar.h" @@ -46,6 +47,8 @@ static bool dcrMagicFunctionSelect(MagicFunctionCallContext context); static bool dcrMagicFunctionRequire(MagicFunctionCallContext context); static bool dcrMagicFunctionPack(MagicFunctionCallContext context); +static std::vector dcrMagicRefinementAssert(const MagicRefinementContext& context); + TypeId makeUnion(TypeArena& arena, std::vector&& types) { return arena.addType(UnionTypeVar{std::move(types)}); @@ -478,6 +481,7 @@ void registerBuiltinGlobals(Frontend& frontend) } attachMagicFunction(getGlobalBinding(frontend, "assert"), magicFunctionAssert); + attachDcrMagicRefinement(getGlobalBinding(frontend, "assert"), dcrMagicRefinementAssert); attachMagicFunction(getGlobalBinding(frontend, "setmetatable"), magicFunctionSetMetaTable); attachMagicFunction(getGlobalBinding(frontend, "select"), magicFunctionSelect); attachDcrMagicFunction(getGlobalBinding(frontend, "select"), dcrMagicFunctionSelect); @@ -703,6 +707,15 @@ static std::optional> magicFunctionAssert( return WithPredicate{arena.addTypePack(TypePack{std::move(head), tail})}; } +static std::vector dcrMagicRefinementAssert(const MagicRefinementContext& ctx) +{ + if (ctx.argumentConnectives.empty()) + return {}; + + ctx.cgb->applyRefinements(ctx.scope, ctx.callSite->location, ctx.argumentConnectives[0]); + return {}; +} + static std::optional> magicFunctionPack( TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) { diff --git a/Analysis/src/ConstraintGraphBuilder.cpp b/Analysis/src/ConstraintGraphBuilder.cpp index 600b6d23d..f0bd958cf 100644 --- a/Analysis/src/ConstraintGraphBuilder.cpp +++ b/Analysis/src/ConstraintGraphBuilder.cpp @@ -2,16 +2,12 @@ #include "Luau/ConstraintGraphBuilder.h" #include "Luau/Ast.h" -#include "Luau/Clone.h" #include "Luau/Common.h" #include "Luau/Constraint.h" #include "Luau/DcrLogger.h" #include "Luau/ModuleResolver.h" #include "Luau/RecursionCounter.h" #include "Luau/Scope.h" -#include "Luau/Substitution.h" -#include "Luau/ToString.h" -#include "Luau/TxnLog.h" #include "Luau/TypeUtils.h" #include "Luau/TypeVar.h" @@ -1068,16 +1064,9 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCa else expectedArgs = extendTypePack(*arena, singletonTypes, expectedArgPack, exprArgs.size() - 1); - std::vector connectives; - if (auto ftv = get(follow(fnType)); ftv && ftv->dcrMagicRefinement) - { - MagicRefinementContext ctx{globalScope, dfg, NotNull{&connectiveArena}, call}; - connectives = ftv->dcrMagicRefinement(ctx); - } - - std::vector args; std::optional argTail; + std::vector argumentConnectives; Checkpoint argCheckpoint = checkpoint(this); @@ -1101,7 +1090,11 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCa args.push_back(arena->freshType(scope.get())); } else if (i < exprArgs.size() - 1 || !(arg->is() || arg->is())) - args.push_back(check(scope, arg, expectedType).ty); + { + auto [ty, connective] = check(scope, arg, expectedType); + args.push_back(ty); + argumentConnectives.push_back(connective); + } else argTail = checkPack(scope, arg, {}).tp; // FIXME? not sure about expectedTypes here } @@ -1114,6 +1107,13 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCa constraint->dependencies.push_back(extractArgsConstraint); }); + std::vector returnConnectives; + if (auto ftv = get(follow(fnType)); ftv && ftv->dcrMagicRefinement) + { + MagicRefinementContext ctx{scope, NotNull{this}, dfg, NotNull{&connectiveArena}, std::move(argumentConnectives), call}; + returnConnectives = ftv->dcrMagicRefinement(ctx); + } + if (matchSetmetatable(*call)) { TypePack argTailPack; @@ -1133,7 +1133,7 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCa if (AstExprLocal* targetLocal = targetExpr->as()) scope->bindings[targetLocal->local].typeId = resultTy; - return InferencePack{arena->addTypePack({resultTy}), std::move(connectives)}; + return InferencePack{arena->addTypePack({resultTy}), std::move(returnConnectives)}; } else { @@ -1172,7 +1172,7 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCa fcc->dependencies.emplace_back(constraint.get()); }); - return InferencePack{rets, std::move(connectives)}; + return InferencePack{rets, std::move(returnConnectives)}; } } @@ -1468,16 +1468,22 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprBinary* bi auto [leftType, rightType, connective] = checkBinary(scope, binary, expectedType); TypeId resultType = arena->addType(BlockedTypeVar{}); - addConstraint(scope, binary->location, BinaryConstraint{binary->op, leftType, rightType, resultType}); + addConstraint(scope, binary->location, BinaryConstraint{binary->op, leftType, rightType, resultType, binary, &astOriginalCallTypes, &astOverloadResolvedTypes}); return Inference{resultType, std::move(connective)}; } Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIfElse* ifElse, std::optional expectedType) { - check(scope, ifElse->condition); + ScopePtr condScope = childScope(ifElse->condition, scope); + auto [_, connective] = check(scope, ifElse->condition); + + ScopePtr thenScope = childScope(ifElse->trueExpr, scope); + applyRefinements(thenScope, ifElse->trueExpr->location, connective); + TypeId thenType = check(thenScope, ifElse->trueExpr, expectedType).ty; - TypeId thenType = check(scope, ifElse->trueExpr, expectedType).ty; - TypeId elseType = check(scope, ifElse->falseExpr, expectedType).ty; + ScopePtr elseScope = childScope(ifElse->falseExpr, scope); + applyRefinements(elseScope, ifElse->falseExpr->location, connectiveArena.negation(connective)); + TypeId elseType = check(elseScope, ifElse->falseExpr, expectedType).ty; if (ifElse->hasElse) { @@ -1807,7 +1813,9 @@ ConstraintGraphBuilder::FunctionSignature ConstraintGraphBuilder::checkFunctionS signatureScope->privateTypePackBindings[name] = g.tp; } - expectedType.reset(); + // Local variable works around an odd gcc 11.3 warning: may be used uninitialized + std::optional none = std::nullopt; + expectedType = none; } std::vector argTypes; diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index d59ea70ae..d73c14a60 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -15,7 +15,6 @@ #include "Luau/TypeVar.h" #include "Luau/Unifier.h" #include "Luau/VisitTypeVar.h" -#include "Luau/TypeUtils.h" LUAU_FASTFLAGVARIABLE(DebugLuauLogSolver, false); LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJson, false); @@ -635,9 +634,17 @@ bool ConstraintSolver::tryDispatch(const BinaryConstraint& c, NotNullscope}; + std::optional instantiatedMm = instantiation.substitute(*mm); + if (!instantiatedMm) + { + reportError(CodeTooComplex{}, constraint->location); + return true; + } + // TODO: Is a table with __call legal here? // TODO: Overloads - if (const FunctionTypeVar* ftv = get(follow(*mm))) + if (const FunctionTypeVar* ftv = get(follow(*instantiatedMm))) { TypePackId inferredArgs; // For >= and > we invoke __lt and __le respectively with @@ -673,6 +680,9 @@ bool ConstraintSolver::tryDispatch(const BinaryConstraint& c, NotNullty.emplace(mmResult); unblock(resultType); + + (*c.astOriginalCallTypes)[c.expr] = *mm; + (*c.astOverloadResolvedTypes)[c.expr] = *instantiatedMm; return true; } } @@ -743,19 +753,7 @@ bool ConstraintSolver::tryDispatch(const BinaryConstraint& c, NotNulladdType(IntersectionTypeVar{{singletonTypes->falsyType, leftType}}); - // TODO: normaliztion here should be replaced by a more limited 'simplification' - const NormalizedType* normalized = normalizer->normalize(arena->addType(UnionTypeVar{{leftFilteredTy, rightType}})); - - if (!normalized) - { - reportError(CodeTooComplex{}, constraint->location); - asMutable(resultType)->ty.emplace(errorRecoveryType()); - } - else - { - asMutable(resultType)->ty.emplace(normalizer->typeFromNormal(*normalized)); - } - + asMutable(resultType)->ty.emplace(arena->addType(UnionTypeVar{{leftFilteredTy, rightType}})); unblock(resultType); return true; } @@ -763,21 +761,9 @@ bool ConstraintSolver::tryDispatch(const BinaryConstraint& c, NotNulladdType(IntersectionTypeVar{{singletonTypes->truthyType, leftType}}); - - // TODO: normaliztion here should be replaced by a more limited 'simplification' - const NormalizedType* normalized = normalizer->normalize(arena->addType(UnionTypeVar{{rightFilteredTy, rightType}})); - - if (!normalized) - { - reportError(CodeTooComplex{}, constraint->location); - asMutable(resultType)->ty.emplace(errorRecoveryType()); - } - else - { - asMutable(resultType)->ty.emplace(normalizer->typeFromNormal(*normalized)); - } + TypeId leftFilteredTy = arena->addType(IntersectionTypeVar{{singletonTypes->truthyType, leftType}}); + asMutable(resultType)->ty.emplace(arena->addType(UnionTypeVar{{leftFilteredTy, rightType}})); unblock(resultType); return true; } diff --git a/Analysis/src/Error.cpp b/Analysis/src/Error.cpp index aefaa2c71..748cf20fa 100644 --- a/Analysis/src/Error.cpp +++ b/Analysis/src/Error.cpp @@ -3,6 +3,7 @@ #include "Luau/Clone.h" #include "Luau/Common.h" +#include "Luau/FileResolver.h" #include "Luau/StringUtils.h" #include "Luau/ToString.h" diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 356ced0b3..e21e42e14 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -21,16 +21,15 @@ #include #include #include +#include LUAU_FASTINT(LuauTypeInferIterationLimit) LUAU_FASTINT(LuauTarjanChildLimit) LUAU_FASTFLAG(LuauInferInNoCheckMode) -LUAU_FASTFLAG(LuauNoMoreGlobalSingletonTypes) LUAU_FASTFLAGVARIABLE(LuauKnowsTheDataModel3, false) LUAU_FASTINTVARIABLE(LuauAutocompleteCheckTimeoutMs, 100) LUAU_FASTFLAGVARIABLE(DebugLuauDeferredConstraintResolution, false) LUAU_FASTFLAG(DebugLuauLogSolverToJson); -LUAU_FASTFLAGVARIABLE(LuauFixMarkDirtyReverseDeps, false) namespace Luau { @@ -409,7 +408,7 @@ double getTimestamp() } // namespace Frontend::Frontend(FileResolver* fileResolver, ConfigResolver* configResolver, const FrontendOptions& options) - : singletonTypes(NotNull{FFlag::LuauNoMoreGlobalSingletonTypes ? &singletonTypes_ : &DEPRECATED_getSingletonTypes()}) + : singletonTypes(NotNull{&singletonTypes_}) , fileResolver(fileResolver) , moduleResolver(this) , moduleResolverForAutocomplete(this) @@ -819,26 +818,13 @@ void Frontend::markDirty(const ModuleName& name, std::vector* marked sourceNode.dirtyModule = true; sourceNode.dirtyModuleForAutocomplete = true; - if (FFlag::LuauFixMarkDirtyReverseDeps) - { - if (0 == reverseDeps.count(next)) - continue; - - sourceModules.erase(next); - - const std::vector& dependents = reverseDeps[next]; - queue.insert(queue.end(), dependents.begin(), dependents.end()); - } - else - { - if (0 == reverseDeps.count(name)) - continue; + if (0 == reverseDeps.count(next)) + continue; - sourceModules.erase(name); + sourceModules.erase(next); - const std::vector& dependents = reverseDeps[name]; - queue.insert(queue.end(), dependents.begin(), dependents.end()); - } + const std::vector& dependents = reverseDeps[next]; + queue.insert(queue.end(), dependents.begin(), dependents.end()); } } @@ -919,6 +905,7 @@ ModulePtr Frontend::check( result->astTypes = std::move(cgb.astTypes); result->astTypePacks = std::move(cgb.astTypePacks); result->astOriginalCallTypes = std::move(cgb.astOriginalCallTypes); + result->astOverloadResolvedTypes = std::move(cgb.astOverloadResolvedTypes); result->astResolvedTypes = std::move(cgb.astResolvedTypes); result->astResolvedTypePacks = std::move(cgb.astResolvedTypePacks); result->type = sourceModule.type; diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp index fa3503fdd..09f0595d6 100644 --- a/Analysis/src/Normalize.cpp +++ b/Analysis/src/Normalize.cpp @@ -19,12 +19,11 @@ LUAU_FASTINTVARIABLE(LuauNormalizeIterationLimit, 1200); LUAU_FASTINTVARIABLE(LuauNormalizeCacheLimit, 100000); LUAU_FASTFLAGVARIABLE(LuauNormalizeCombineTableFix, false); LUAU_FASTFLAGVARIABLE(LuauTypeNormalization2, false); -LUAU_FASTFLAGVARIABLE(LuauNegatedStringSingletons, false); LUAU_FASTFLAGVARIABLE(LuauNegatedFunctionTypes, false); LUAU_FASTFLAG(LuauUnknownAndNeverType) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) LUAU_FASTFLAG(LuauOverloadedFunctionSubtypingPerf); -LUAU_FASTFLAG(LuauUninhabitedSubAnything) +LUAU_FASTFLAG(LuauUninhabitedSubAnything2) namespace Luau { @@ -46,6 +45,11 @@ void TypeIds::clear() hash = 0; } +TypeId TypeIds::front() const +{ + return order.at(0); +} + TypeIds::iterator TypeIds::begin() { return order.begin(); @@ -111,94 +115,68 @@ bool TypeIds::operator==(const TypeIds& there) const return hash == there.hash && types == there.types; } -NormalizedStringType::NormalizedStringType(bool isCofinite, std::optional> singletons) +NormalizedStringType::NormalizedStringType() +{} + +NormalizedStringType::NormalizedStringType(bool isCofinite, std::map singletons) : isCofinite(isCofinite) , singletons(std::move(singletons)) { - if (!FFlag::LuauNegatedStringSingletons) - LUAU_ASSERT(!isCofinite); } void NormalizedStringType::resetToString() { - if (FFlag::LuauNegatedStringSingletons) - { - isCofinite = true; - singletons->clear(); - } - else - singletons.reset(); + isCofinite = true; + singletons.clear(); } void NormalizedStringType::resetToNever() { - if (FFlag::LuauNegatedStringSingletons) - { - isCofinite = false; - singletons.emplace(); - } - else - { - if (singletons) - singletons->clear(); - else - singletons.emplace(); - } + isCofinite = false; + singletons.clear(); } bool NormalizedStringType::isNever() const { - if (FFlag::LuauNegatedStringSingletons) - return !isCofinite && singletons->empty(); - else - return singletons && singletons->empty(); + return !isCofinite && singletons.empty(); } bool NormalizedStringType::isString() const { - if (FFlag::LuauNegatedStringSingletons) - return isCofinite && singletons->empty(); - else - return !singletons; + return isCofinite && singletons.empty(); } bool NormalizedStringType::isUnion() const { - if (FFlag::LuauNegatedStringSingletons) - return !isCofinite; - else - return singletons.has_value(); + return !isCofinite; } bool NormalizedStringType::isIntersection() const { - if (FFlag::LuauNegatedStringSingletons) - return isCofinite; - else - return false; + return isCofinite; } bool NormalizedStringType::includes(const std::string& str) const { if (isString()) return true; - else if (isUnion() && singletons->count(str)) + else if (isUnion() && singletons.count(str)) return true; - else if (isIntersection() && !singletons->count(str)) + else if (isIntersection() && !singletons.count(str)) return true; else return false; } -const NormalizedStringType NormalizedStringType::never{false, {{}}}; +const NormalizedStringType NormalizedStringType::never; bool isSubtype(const NormalizedStringType& subStr, const NormalizedStringType& superStr) { if (subStr.isUnion() && superStr.isUnion()) { - for (auto [name, ty] : *subStr.singletons) + for (auto [name, ty] : subStr.singletons) { - if (!superStr.singletons->count(name)) + if (!superStr.singletons.count(name)) return false; } } @@ -251,17 +229,21 @@ static bool isShallowInhabited(const NormalizedType& norm) bool isInhabited_DEPRECATED(const NormalizedType& norm) { - LUAU_ASSERT(!FFlag::LuauUninhabitedSubAnything); + LUAU_ASSERT(!FFlag::LuauUninhabitedSubAnything2); return isShallowInhabited(norm); } bool Normalizer::isInhabited(const NormalizedType* norm, std::unordered_set seen) { + // If normalization failed, the type is complex, and so is more likely than not to be inhabited. + if (!norm) + return true; + if (!get(norm->tops) || !get(norm->booleans) || !get(norm->errors) || !get(norm->nils) || !get(norm->numbers) || !get(norm->threads) || !norm->classes.empty() || !norm->strings.isNever() || !norm->functions.isNever()) return true; - + for (const auto& [_, intersect] : norm->tyvars) { if (isInhabited(intersect.get(), seen)) @@ -372,7 +354,7 @@ static bool isNormalizedString(const NormalizedStringType& ty) if (ty.isString()) return true; - for (auto& [str, ty] : *ty.singletons) + for (auto& [str, ty] : ty.singletons) { if (const SingletonTypeVar* stv = get(ty)) { @@ -682,56 +664,46 @@ void Normalizer::unionClasses(TypeIds& heres, const TypeIds& theres) void Normalizer::unionStrings(NormalizedStringType& here, const NormalizedStringType& there) { - if (FFlag::LuauNegatedStringSingletons) + if (there.isString()) + here.resetToString(); + else if (here.isUnion() && there.isUnion()) + here.singletons.insert(there.singletons.begin(), there.singletons.end()); + else if (here.isUnion() && there.isIntersection()) { - if (there.isString()) - here.resetToString(); - else if (here.isUnion() && there.isUnion()) - here.singletons->insert(there.singletons->begin(), there.singletons->end()); - else if (here.isUnion() && there.isIntersection()) - { - here.isCofinite = true; - for (const auto& pair : *there.singletons) - { - auto it = here.singletons->find(pair.first); - if (it != end(*here.singletons)) - here.singletons->erase(it); - else - here.singletons->insert(pair); - } - } - else if (here.isIntersection() && there.isUnion()) + here.isCofinite = true; + for (const auto& pair : there.singletons) { - for (const auto& [name, ty] : *there.singletons) - here.singletons->erase(name); + auto it = here.singletons.find(pair.first); + if (it != end(here.singletons)) + here.singletons.erase(it); + else + here.singletons.insert(pair); } - else if (here.isIntersection() && there.isIntersection()) - { - auto iter = begin(*here.singletons); - auto endIter = end(*here.singletons); + } + else if (here.isIntersection() && there.isUnion()) + { + for (const auto& [name, ty] : there.singletons) + here.singletons.erase(name); + } + else if (here.isIntersection() && there.isIntersection()) + { + auto iter = begin(here.singletons); + auto endIter = end(here.singletons); - while (iter != endIter) + while (iter != endIter) + { + if (!there.singletons.count(iter->first)) { - if (!there.singletons->count(iter->first)) - { - auto eraseIt = iter; - ++iter; - here.singletons->erase(eraseIt); - } - else - ++iter; + auto eraseIt = iter; + ++iter; + here.singletons.erase(eraseIt); } + else + ++iter; } - else - LUAU_ASSERT(!"Unreachable"); } else - { - if (there.isString()) - here.resetToString(); - else if (here.isUnion()) - here.singletons->insert(there.singletons->begin(), there.singletons->end()); - } + LUAU_ASSERT(!"Unreachable"); } std::optional Normalizer::unionOfTypePacks(TypePackId here, TypePackId there) @@ -1116,22 +1088,14 @@ bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, int ignor here.booleans = unionOfBools(here.booleans, there); else if (const StringSingleton* sstv = get(stv)) { - if (FFlag::LuauNegatedStringSingletons) + if (here.strings.isCofinite) { - if (here.strings.isCofinite) - { - auto it = here.strings.singletons->find(sstv->value); - if (it != here.strings.singletons->end()) - here.strings.singletons->erase(it); - } - else - here.strings.singletons->insert({sstv->value, there}); + auto it = here.strings.singletons.find(sstv->value); + if (it != here.strings.singletons.end()) + here.strings.singletons.erase(it); } else - { - if (here.strings.isUnion()) - here.strings.singletons->insert({sstv->value, there}); - } + here.strings.singletons.insert({sstv->value, there}); } else LUAU_ASSERT(!"Unreachable"); @@ -1278,7 +1242,6 @@ void Normalizer::subtractPrimitive(NormalizedType& here, TypeId ty) here.threads = singletonTypes->neverType; break; case PrimitiveTypeVar::Function: - LUAU_ASSERT(FFlag::LuauNegatedStringSingletons); here.functions.resetToNever(); break; } @@ -1286,20 +1249,18 @@ void Normalizer::subtractPrimitive(NormalizedType& here, TypeId ty) void Normalizer::subtractSingleton(NormalizedType& here, TypeId ty) { - LUAU_ASSERT(FFlag::LuauNegatedStringSingletons); - const SingletonTypeVar* stv = get(ty); LUAU_ASSERT(stv); if (const StringSingleton* ss = get(stv)) { if (here.strings.isCofinite) - here.strings.singletons->insert({ss->value, ty}); + here.strings.singletons.insert({ss->value, ty}); else { - auto it = here.strings.singletons->find(ss->value); - if (it != here.strings.singletons->end()) - here.strings.singletons->erase(it); + auto it = here.strings.singletons.find(ss->value); + if (it != here.strings.singletons.end()) + here.strings.singletons.erase(it); } } else if (const BooleanSingleton* bs = get(stv)) @@ -1417,12 +1378,12 @@ void Normalizer::intersectStrings(NormalizedStringType& here, const NormalizedSt if (here.isString()) here.resetToNever(); - for (auto it = here.singletons->begin(); it != here.singletons->end();) + for (auto it = here.singletons.begin(); it != here.singletons.end();) { - if (there.singletons->count(it->first)) + if (there.singletons.count(it->first)) it++; else - it = here.singletons->erase(it); + it = here.singletons.erase(it); } } @@ -2096,12 +2057,12 @@ bool Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there) else if (const StringSingleton* sstv = get(stv)) { if (strings.includes(sstv->value)) - here.strings.singletons->insert({sstv->value, there}); + here.strings.singletons.insert({sstv->value, there}); } else LUAU_ASSERT(!"Unreachable"); } - else if (const NegationTypeVar* ntv = get(there); FFlag::LuauNegatedStringSingletons && ntv) + else if (const NegationTypeVar* ntv = get(there)) { TypeId t = follow(ntv->ty); if (const PrimitiveTypeVar* ptv = get(t)) @@ -2171,14 +2132,14 @@ TypeId Normalizer::typeFromNormal(const NormalizedType& norm) result.push_back(singletonTypes->stringType); else if (norm.strings.isUnion()) { - for (auto& [_, ty] : *norm.strings.singletons) + for (auto& [_, ty] : norm.strings.singletons) result.push_back(ty); } - else if (FFlag::LuauNegatedStringSingletons && norm.strings.isIntersection()) + else if (norm.strings.isIntersection()) { std::vector parts; parts.push_back(singletonTypes->stringType); - for (const auto& [name, ty] : *norm.strings.singletons) + for (const auto& [name, ty] : norm.strings.singletons) parts.push_back(arena->addType(NegationTypeVar{ty})); result.push_back(arena->addType(IntersectionTypeVar{std::move(parts)})); diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index 145e7fa76..ed7c682d6 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -1,6 +1,8 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/ToString.h" +#include "Luau/Constraint.h" +#include "Luau/Location.h" #include "Luau/Scope.h" #include "Luau/TypeInfer.h" #include "Luau/TypePack.h" @@ -1572,4 +1574,15 @@ std::optional getFunctionNameAsString(const AstExpr& expr) return s; } + +std::string toString(const Position& position) +{ + return "{ line = " + std::to_string(position.line) + ", col = " + std::to_string(position.column) + " }"; +} + +std::string toString(const Location& location) +{ + return "Location { " + toString(location.begin) + ", " + toString(location.end) + " }"; +} + } // namespace Luau diff --git a/Analysis/src/TxnLog.cpp b/Analysis/src/TxnLog.cpp index 1a73b049f..18596a638 100644 --- a/Analysis/src/TxnLog.cpp +++ b/Analysis/src/TxnLog.cpp @@ -2,6 +2,7 @@ #include "Luau/TxnLog.h" #include "Luau/ToString.h" +#include "Luau/TypeArena.h" #include "Luau/TypePack.h" #include diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index 84c0ca3b0..8c44f90a5 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -90,6 +90,9 @@ struct TypeChecker2 std::vector> stack; + UnifierSharedState sharedState{&ice}; + Normalizer normalizer{&module->internalTypes, singletonTypes, NotNull{&sharedState}}; + TypeChecker2(NotNull singletonTypes, DcrLogger* logger, const SourceModule* sourceModule, Module* module) : singletonTypes(singletonTypes) , logger(logger) @@ -298,8 +301,6 @@ struct TypeChecker2 TypeArena* arena = &module->internalTypes; TypePackId actualRetType = reconstructPack(ret->list, *arena); - UnifierSharedState sharedState{&ice}; - Normalizer normalizer{arena, singletonTypes, NotNull{&sharedState}}; Unifier u{NotNull{&normalizer}, Mode::Strict, stack.back(), ret->location, Covariant}; u.tryUnify(actualRetType, expectedRetType); @@ -921,7 +922,12 @@ struct TypeChecker2 void visit(AstExprIndexName* indexName) { TypeId leftType = lookupType(indexName->expr); - getIndexTypeFromType(module->getModuleScope(), leftType, indexName->index.value, indexName->location, /* addErrors */ true); + + const NormalizedType* norm = normalizer.normalize(leftType); + if (!norm) + reportError(NormalizationTooComplex{}, indexName->indexLocation); + + checkIndexTypeFromType(leftType, *norm, indexName->index.value, indexName->location); } void visit(AstExprIndexExpr* indexExpr) @@ -1109,11 +1115,18 @@ struct TypeChecker2 if (std::optional leftMm = findMetatableEntry(singletonTypes, module->errors, leftType, it->second, expr->left->location)) mm = leftMm; else if (std::optional rightMm = findMetatableEntry(singletonTypes, module->errors, rightType, it->second, expr->right->location)) + { mm = rightMm; + std::swap(leftType, rightType); + } if (mm) { - if (const FunctionTypeVar* ftv = get(follow(*mm))) + TypeId instantiatedMm = module->astOverloadResolvedTypes[expr]; + if (!instantiatedMm) + reportError(CodeTooComplex{}, expr->location); + + else if (const FunctionTypeVar* ftv = get(follow(instantiatedMm))) { TypePackId expectedArgs; // For >= and > we invoke __lt and __le respectively with @@ -1545,9 +1558,7 @@ struct TypeChecker2 template bool isSubtype(TID subTy, TID superTy, NotNull scope) { - UnifierSharedState sharedState{&ice}; TypeArena arena; - Normalizer normalizer{&arena, singletonTypes, NotNull{&sharedState}}; Unifier u{NotNull{&normalizer}, Mode::Strict, scope, Location{}, Covariant}; u.useScopes = true; @@ -1559,8 +1570,6 @@ struct TypeChecker2 template ErrorVec tryUnify(NotNull scope, const Location& location, TID subTy, TID superTy) { - UnifierSharedState sharedState{&ice}; - Normalizer normalizer{&module->internalTypes, singletonTypes, NotNull{&sharedState}}; Unifier u{NotNull{&normalizer}, Mode::Strict, scope, location, Covariant}; u.useScopes = true; u.tryUnify(subTy, superTy); @@ -1587,9 +1596,90 @@ struct TypeChecker2 reportError(std::move(e)); } - std::optional getIndexTypeFromType(const ScopePtr& scope, TypeId type, const std::string& prop, const Location& location, bool addErrors) + void checkIndexTypeFromType(TypeId denormalizedTy, const NormalizedType& norm, const std::string& prop, const Location& location) { - return Luau::getIndexTypeFromType(scope, module->errors, &module->internalTypes, singletonTypes, type, prop, location, addErrors, ice); + bool foundOneProp = false; + std::vector typesMissingTheProp; + + auto fetch = [&](TypeId ty) { + if (!normalizer.isInhabited(ty)) + return; + + bool found = hasIndexTypeFromType(ty, prop, location); + foundOneProp |= found; + if (!found) + typesMissingTheProp.push_back(ty); + }; + + fetch(norm.tops); + fetch(norm.booleans); + for (TypeId ty : norm.classes) + fetch(ty); + fetch(norm.errors); + fetch(norm.nils); + fetch(norm.numbers); + if (!norm.strings.isNever()) + fetch(singletonTypes->stringType); + fetch(norm.threads); + for (TypeId ty : norm.tables) + fetch(ty); + if (norm.functions.isTop) + fetch(singletonTypes->functionType); + else if (!norm.functions.isNever()) + { + if (norm.functions.parts->size() == 1) + fetch(norm.functions.parts->front()); + else + { + std::vector parts; + parts.insert(parts.end(), norm.functions.parts->begin(), norm.functions.parts->end()); + fetch(module->internalTypes.addType(IntersectionTypeVar{std::move(parts)})); + } + } + for (const auto& [tyvar, intersect] : norm.tyvars) + { + if (get(intersect->tops)) + { + TypeId ty = normalizer.typeFromNormal(*intersect); + fetch(module->internalTypes.addType(IntersectionTypeVar{{tyvar, ty}})); + } + else + fetch(tyvar); + } + + if (!typesMissingTheProp.empty()) + { + if (foundOneProp) + reportError(TypeError{location, MissingUnionProperty{denormalizedTy, typesMissingTheProp, prop}}); + else + reportError(TypeError{location, UnknownProperty{denormalizedTy, prop}}); + } + } + + bool hasIndexTypeFromType(TypeId ty, const std::string& prop, const Location& location) + { + if (get(ty) || get(ty) || get(ty)) + return true; + + if (isString(ty)) + { + std::optional mtIndex = Luau::findMetatableEntry(singletonTypes, module->errors, singletonTypes->stringType, "__index", location); + LUAU_ASSERT(mtIndex); + ty = *mtIndex; + } + + if (getTableType(ty)) + return bool(findTablePropertyRespectingMeta(singletonTypes, module->errors, ty, prop, location)); + else if (const ClassTypeVar* cls = get(ty)) + return bool(lookupClassProp(cls, prop)); + else if (const UnionTypeVar* utv = get(ty)) + ice.ice("getIndexTypeFromTypeHelper cannot take a UnionTypeVar"); + else if (const IntersectionTypeVar* itv = get(ty)) + return std::any_of(begin(itv), end(itv), [&](TypeId part) { + return hasIndexTypeFromType(part, prop, location); + }); + else + return false; } }; diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 9f64a6010..aa738ad96 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -52,7 +52,7 @@ LUAU_FASTFLAGVARIABLE(LuauIntersectionTestForEquality, false) LUAU_FASTFLAGVARIABLE(LuauImplicitElseRefinement, false) LUAU_FASTFLAGVARIABLE(LuauAllowIndexClassParameters, false) LUAU_FASTFLAGVARIABLE(LuauDeclareClassPrototype, false) -LUAU_FASTFLAG(LuauUninhabitedSubAnything) +LUAU_FASTFLAG(LuauUninhabitedSubAnything2) LUAU_FASTFLAGVARIABLE(LuauCallableClasses, false) namespace Luau @@ -2691,7 +2691,7 @@ static std::optional areEqComparable(NotNull arena, NotNullisInhabited(n); else return isInhabited_DEPRECATED(*n); diff --git a/Analysis/src/TypeUtils.cpp b/Analysis/src/TypeUtils.cpp index 876c45a77..7478ac22c 100644 --- a/Analysis/src/TypeUtils.cpp +++ b/Analysis/src/TypeUtils.cpp @@ -88,103 +88,6 @@ std::optional findTablePropertyRespectingMeta( return std::nullopt; } -std::optional getIndexTypeFromType(const ScopePtr& scope, ErrorVec& errors, TypeArena* arena, NotNull singletonTypes, - TypeId type, const std::string& prop, const Location& location, bool addErrors, InternalErrorReporter& handle) -{ - type = follow(type); - - if (get(type) || get(type) || get(type)) - return type; - - if (auto f = get(type)) - *asMutable(type) = TableTypeVar{TableState::Free, f->level}; - - if (isString(type)) - { - std::optional mtIndex = Luau::findMetatableEntry(singletonTypes, errors, singletonTypes->stringType, "__index", location); - LUAU_ASSERT(mtIndex); - type = *mtIndex; - } - - if (getTableType(type)) - { - return findTablePropertyRespectingMeta(singletonTypes, errors, type, prop, location); - } - else if (const ClassTypeVar* cls = get(type)) - { - if (const Property* p = lookupClassProp(cls, prop)) - return p->type; - } - else if (const UnionTypeVar* utv = get(type)) - { - std::vector goodOptions; - std::vector badOptions; - - for (TypeId t : utv) - { - if (get(follow(t))) - return t; - - if (std::optional ty = - getIndexTypeFromType(scope, errors, arena, singletonTypes, t, prop, location, /* addErrors= */ false, handle)) - goodOptions.push_back(*ty); - else - badOptions.push_back(t); - } - - if (!badOptions.empty()) - { - if (addErrors) - { - if (goodOptions.empty()) - errors.push_back(TypeError{location, UnknownProperty{type, prop}}); - else - errors.push_back(TypeError{location, MissingUnionProperty{type, badOptions, prop}}); - } - return std::nullopt; - } - - goodOptions = reduceUnion(goodOptions); - - if (goodOptions.empty()) - return singletonTypes->neverType; - - if (goodOptions.size() == 1) - return goodOptions[0]; - - return arena->addType(UnionTypeVar{std::move(goodOptions)}); - } - else if (const IntersectionTypeVar* itv = get(type)) - { - std::vector parts; - - for (TypeId t : itv->parts) - { - if (std::optional ty = - getIndexTypeFromType(scope, errors, arena, singletonTypes, t, prop, location, /* addErrors= */ false, handle)) - parts.push_back(*ty); - } - - // If no parts of the intersection had the property we looked up for, it never existed at all. - if (parts.empty()) - { - if (addErrors) - errors.push_back(TypeError{location, UnknownProperty{type, prop}}); - return std::nullopt; - } - - if (parts.size() == 1) - return parts[0]; - - return arena->addType(IntersectionTypeVar{std::move(parts)}); - } - - if (addErrors) - errors.push_back(TypeError{location, UnknownProperty{type, prop}}); - - return std::nullopt; -} - std::pair> getParameterExtents(const TxnLog* log, TypePackId tp, bool includeHiddenVariadics) { size_t minCount = 0; diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index 6771d89b0..159e77125 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -26,7 +26,6 @@ LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0) LUAU_FASTINT(LuauTypeInferRecursionLimit) LUAU_FASTFLAG(LuauUnknownAndNeverType) LUAU_FASTFLAGVARIABLE(LuauMaybeGenericIntersectionTypes, false) -LUAU_FASTFLAGVARIABLE(LuauNoMoreGlobalSingletonTypes, false) LUAU_FASTFLAGVARIABLE(LuauNewLibraryTypeNames, false) LUAU_FASTFLAG(LuauInstantiateInSubtyping) @@ -890,12 +889,6 @@ TypePackId SingletonTypes::errorRecoveryTypePack(TypePackId guess) return guess; } -SingletonTypes& DEPRECATED_getSingletonTypes() -{ - static SingletonTypes singletonTypes; - return singletonTypes; -} - void persist(TypeId ty) { std::deque queue{ty}; diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 5ff405e62..428820054 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -5,12 +5,13 @@ #include "Luau/Instantiation.h" #include "Luau/RecursionCounter.h" #include "Luau/Scope.h" +#include "Luau/StringUtils.h" +#include "Luau/TimeTrace.h" +#include "Luau/ToString.h" #include "Luau/TypePack.h" #include "Luau/TypeUtils.h" -#include "Luau/TimeTrace.h" #include "Luau/TypeVar.h" #include "Luau/VisitTypeVar.h" -#include "Luau/ToString.h" #include @@ -23,7 +24,7 @@ LUAU_FASTFLAGVARIABLE(LuauScalarShapeSubtyping, false) LUAU_FASTFLAGVARIABLE(LuauInstantiateInSubtyping, false) LUAU_FASTFLAGVARIABLE(LuauOverloadedFunctionSubtypingPerf, false); LUAU_FASTFLAGVARIABLE(LuauScalarShapeUnifyToMtOwner2, false) -LUAU_FASTFLAGVARIABLE(LuauUninhabitedSubAnything, false) +LUAU_FASTFLAGVARIABLE(LuauUninhabitedSubAnything2, false) LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution) LUAU_FASTFLAG(LuauTxnLogTypePackIterator) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) @@ -588,7 +589,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool else if (log.get(subTy)) tryUnifyNegationWithType(subTy, superTy); - else if (FFlag::LuauUninhabitedSubAnything && !normalizer->isInhabited(subTy)) + else if (FFlag::LuauUninhabitedSubAnything2 && !normalizer->isInhabited(subTy)) {} else @@ -1980,7 +1981,7 @@ void Unifier::tryUnifyScalarShape(TypeId subTy, TypeId superTy, bool reversed) TypeId osubTy = subTy; TypeId osuperTy = superTy; - if (FFlag::LuauUninhabitedSubAnything && !normalizer->isInhabited(subTy)) + if (FFlag::LuauUninhabitedSubAnything2 && !normalizer->isInhabited(subTy)) return; if (reversed) diff --git a/Ast/include/Luau/Ast.h b/Ast/include/Luau/Ast.h index 070511632..aa87d9e86 100644 --- a/Ast/include/Luau/Ast.h +++ b/Ast/include/Luau/Ast.h @@ -5,6 +5,7 @@ #include #include +#include #include diff --git a/Ast/include/Luau/Location.h b/Ast/include/Luau/Location.h index e39bbf8c5..dbe36becb 100644 --- a/Ast/include/Luau/Location.h +++ b/Ast/include/Luau/Location.h @@ -1,8 +1,6 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once -#include - namespace Luau { @@ -10,130 +8,36 @@ struct Position { unsigned int line, column; - Position(unsigned int line, unsigned int column) - : line(line) - , column(column) - { - } - - bool operator==(const Position& rhs) const - { - return this->column == rhs.column && this->line == rhs.line; - } - bool operator!=(const Position& rhs) const - { - return !(*this == rhs); - } - - bool operator<(const Position& rhs) const - { - if (line == rhs.line) - return column < rhs.column; - else - return line < rhs.line; - } - - bool operator>(const Position& rhs) const - { - if (line == rhs.line) - return column > rhs.column; - else - return line > rhs.line; - } - - bool operator<=(const Position& rhs) const - { - return *this == rhs || *this < rhs; - } + Position(unsigned int line, unsigned int column); - bool operator>=(const Position& rhs) const - { - return *this == rhs || *this > rhs; - } + bool operator==(const Position& rhs) const; + bool operator!=(const Position& rhs) const; + bool operator<(const Position& rhs) const; + bool operator>(const Position& rhs) const; + bool operator<=(const Position& rhs) const; + bool operator>=(const Position& rhs) const; - void shift(const Position& start, const Position& oldEnd, const Position& newEnd) - { - if (*this >= start) - { - if (this->line > oldEnd.line) - this->line += (newEnd.line - oldEnd.line); - else - { - this->line = newEnd.line; - this->column += (newEnd.column - oldEnd.column); - } - } - } + void shift(const Position& start, const Position& oldEnd, const Position& newEnd); }; struct Location { Position begin, end; - Location() - : begin(0, 0) - , end(0, 0) - { - } + Location(); + Location(const Position& begin, const Position& end); + Location(const Position& begin, unsigned int length); + Location(const Location& begin, const Location& end); - Location(const Position& begin, const Position& end) - : begin(begin) - , end(end) - { - } + bool operator==(const Location& rhs) const; + bool operator!=(const Location& rhs) const; - Location(const Position& begin, unsigned int length) - : begin(begin) - , end(begin.line, begin.column + length) - { - } - - Location(const Location& begin, const Location& end) - : begin(begin.begin) - , end(end.end) - { - } - - bool operator==(const Location& rhs) const - { - return this->begin == rhs.begin && this->end == rhs.end; - } - bool operator!=(const Location& rhs) const - { - return !(*this == rhs); - } - - bool encloses(const Location& l) const - { - return begin <= l.begin && end >= l.end; - } - bool overlaps(const Location& l) const - { - return (begin <= l.begin && end >= l.begin) || (begin <= l.end && end >= l.end) || (begin >= l.begin && end <= l.end); - } - bool contains(const Position& p) const - { - return begin <= p && p < end; - } - bool containsClosed(const Position& p) const - { - return begin <= p && p <= end; - } - void extend(const Location& other) - { - if (other.begin < begin) - begin = other.begin; - if (other.end > end) - end = other.end; - } - void shift(const Position& start, const Position& oldEnd, const Position& newEnd) - { - begin.shift(start, oldEnd, newEnd); - end.shift(start, oldEnd, newEnd); - } + bool encloses(const Location& l) const; + bool overlaps(const Location& l) const; + bool contains(const Position& p) const; + bool containsClosed(const Position& p) const; + void extend(const Location& other); + void shift(const Position& start, const Position& oldEnd, const Position& newEnd); }; -std::string toString(const Position& position); -std::string toString(const Location& location); - } // namespace Luau diff --git a/Ast/src/Location.cpp b/Ast/src/Location.cpp index d7a899ed9..67c2dd4b6 100644 --- a/Ast/src/Location.cpp +++ b/Ast/src/Location.cpp @@ -4,14 +4,128 @@ namespace Luau { -std::string toString(const Position& position) +Position::Position(unsigned int line, unsigned int column) + : line(line) + , column(column) { - return "{ line = " + std::to_string(position.line) + ", col = " + std::to_string(position.column) + " }"; } -std::string toString(const Location& location) +bool Position::operator==(const Position& rhs) const { - return "Location { " + toString(location.begin) + ", " + toString(location.end) + " }"; + return this->column == rhs.column && this->line == rhs.line; +} + +bool Position::operator!=(const Position& rhs) const +{ + return !(*this == rhs); +} + +bool Position::operator<(const Position& rhs) const +{ + if (line == rhs.line) + return column < rhs.column; + else + return line < rhs.line; +} + +bool Position::operator>(const Position& rhs) const +{ + if (line == rhs.line) + return column > rhs.column; + else + return line > rhs.line; +} + +bool Position::operator<=(const Position& rhs) const +{ + return *this == rhs || *this < rhs; +} + +bool Position::operator>=(const Position& rhs) const +{ + return *this == rhs || *this > rhs; +} + +void Position::shift(const Position& start, const Position& oldEnd, const Position& newEnd) +{ + if (*this >= start) + { + if (this->line > oldEnd.line) + this->line += (newEnd.line - oldEnd.line); + else + { + this->line = newEnd.line; + this->column += (newEnd.column - oldEnd.column); + } + } +} + +Location::Location() + : begin(0, 0) + , end(0, 0) +{ +} + +Location::Location(const Position& begin, const Position& end) + : begin(begin) + , end(end) +{ +} + +Location::Location(const Position& begin, unsigned int length) + : begin(begin) + , end(begin.line, begin.column + length) +{ +} + +Location::Location(const Location& begin, const Location& end) + : begin(begin.begin) + , end(end.end) +{ +} + +bool Location::operator==(const Location& rhs) const +{ + return this->begin == rhs.begin && this->end == rhs.end; +} + +bool Location::operator!=(const Location& rhs) const +{ + return !(*this == rhs); +} + +bool Location::encloses(const Location& l) const +{ + return begin <= l.begin && end >= l.end; +} + +bool Location::overlaps(const Location& l) const +{ + return (begin <= l.begin && end >= l.begin) || (begin <= l.end && end >= l.end) || (begin >= l.begin && end <= l.end); +} + +bool Location::contains(const Position& p) const +{ + return begin <= p && p < end; +} + +bool Location::containsClosed(const Position& p) const +{ + return begin <= p && p <= end; +} + +void Location::extend(const Location& other) +{ + if (other.begin < begin) + begin = other.begin; + if (other.end > end) + end = other.end; +} + +void Location::shift(const Position& start, const Position& oldEnd, const Position& newEnd) +{ + begin.shift(start, oldEnd, newEnd); + end.shift(start, oldEnd, newEnd); } } // namespace Luau diff --git a/CLI/Ast.cpp b/CLI/Ast.cpp index fd99d2259..99c583936 100644 --- a/CLI/Ast.cpp +++ b/CLI/Ast.cpp @@ -6,6 +6,7 @@ #include "Luau/AstJsonEncoder.h" #include "Luau/Parser.h" #include "Luau/ParseOptions.h" +#include "Luau/ToString.h" #include "FileUtils.h" diff --git a/CodeGen/src/CodeGen.cpp b/CodeGen/src/CodeGen.cpp index 39ca913f1..1c05b2986 100644 --- a/CodeGen/src/CodeGen.cpp +++ b/CodeGen/src/CodeGen.cpp @@ -59,7 +59,7 @@ static void assembleHelpers(AssemblyBuilderX64& build, ModuleHelpers& helpers) } static int emitInst(AssemblyBuilderX64& build, NativeState& data, ModuleHelpers& helpers, Proto* proto, LuauOpcode op, const Instruction* pc, int i, - Label* labelarr, Label& fallback) + Label* labelarr, Label& next, Label& fallback) { int skip = 0; @@ -89,31 +89,31 @@ static int emitInst(AssemblyBuilderX64& build, NativeState& data, ModuleHelpers& emitInstGetGlobal(build, pc, i, fallback); break; case LOP_SETGLOBAL: - emitInstSetGlobal(build, pc, i, labelarr, fallback); + emitInstSetGlobal(build, pc, i, next, fallback); break; case LOP_CALL: - emitInstCall(build, helpers, pc, i, labelarr); + emitInstCall(build, helpers, pc, i); break; case LOP_RETURN: - emitInstReturn(build, helpers, pc, i, labelarr); + emitInstReturn(build, helpers, pc, i); break; case LOP_GETTABLE: - emitInstGetTable(build, pc, i, fallback); + emitInstGetTable(build, pc, fallback); break; case LOP_SETTABLE: - emitInstSetTable(build, pc, i, labelarr, fallback); + emitInstSetTable(build, pc, next, fallback); break; case LOP_GETTABLEKS: emitInstGetTableKS(build, pc, i, fallback); break; case LOP_SETTABLEKS: - emitInstSetTableKS(build, pc, i, labelarr, fallback); + emitInstSetTableKS(build, pc, i, next, fallback); break; case LOP_GETTABLEN: - emitInstGetTableN(build, pc, i, fallback); + emitInstGetTableN(build, pc, fallback); break; case LOP_SETTABLEN: - emitInstSetTableN(build, pc, i, labelarr, fallback); + emitInstSetTableN(build, pc, next, fallback); break; case LOP_JUMP: emitInstJump(build, pc, i, labelarr); @@ -161,94 +161,96 @@ static int emitInst(AssemblyBuilderX64& build, NativeState& data, ModuleHelpers& emitInstJumpxEqS(build, pc, i, labelarr); break; case LOP_ADD: - emitInstBinary(build, pc, i, TM_ADD, fallback); + emitInstBinary(build, pc, TM_ADD, fallback); break; case LOP_SUB: - emitInstBinary(build, pc, i, TM_SUB, fallback); + emitInstBinary(build, pc, TM_SUB, fallback); break; case LOP_MUL: - emitInstBinary(build, pc, i, TM_MUL, fallback); + emitInstBinary(build, pc, TM_MUL, fallback); break; case LOP_DIV: - emitInstBinary(build, pc, i, TM_DIV, fallback); + emitInstBinary(build, pc, TM_DIV, fallback); break; case LOP_MOD: - emitInstBinary(build, pc, i, TM_MOD, fallback); + emitInstBinary(build, pc, TM_MOD, fallback); break; case LOP_POW: - emitInstBinary(build, pc, i, TM_POW, fallback); + emitInstBinary(build, pc, TM_POW, fallback); break; case LOP_ADDK: - emitInstBinaryK(build, pc, i, TM_ADD, fallback); + emitInstBinaryK(build, pc, TM_ADD, fallback); break; case LOP_SUBK: - emitInstBinaryK(build, pc, i, TM_SUB, fallback); + emitInstBinaryK(build, pc, TM_SUB, fallback); break; case LOP_MULK: - emitInstBinaryK(build, pc, i, TM_MUL, fallback); + emitInstBinaryK(build, pc, TM_MUL, fallback); break; case LOP_DIVK: - emitInstBinaryK(build, pc, i, TM_DIV, fallback); + emitInstBinaryK(build, pc, TM_DIV, fallback); break; case LOP_MODK: - emitInstBinaryK(build, pc, i, TM_MOD, fallback); + emitInstBinaryK(build, pc, TM_MOD, fallback); break; case LOP_POWK: - emitInstPowK(build, pc, proto->k, i, fallback); + emitInstPowK(build, pc, proto->k, fallback); break; case LOP_NOT: emitInstNot(build, pc); break; case LOP_MINUS: - emitInstMinus(build, pc, i, fallback); + emitInstMinus(build, pc, fallback); break; case LOP_LENGTH: - emitInstLength(build, pc, i, fallback); + emitInstLength(build, pc, fallback); break; case LOP_NEWTABLE: - emitInstNewTable(build, pc, i, labelarr); + emitInstNewTable(build, pc, i, next); break; case LOP_DUPTABLE: - emitInstDupTable(build, pc, i, labelarr); + emitInstDupTable(build, pc, i, next); break; case LOP_SETLIST: - emitInstSetList(build, pc, i, labelarr); + emitInstSetList(build, pc, next); break; case LOP_GETUPVAL: - emitInstGetUpval(build, pc, i); + emitInstGetUpval(build, pc); break; case LOP_SETUPVAL: - emitInstSetUpval(build, pc, i, labelarr); + emitInstSetUpval(build, pc, next); break; case LOP_CLOSEUPVALS: - emitInstCloseUpvals(build, pc, i, labelarr); + emitInstCloseUpvals(build, pc, next); break; case LOP_FASTCALL: - skip = emitInstFastCall(build, pc, i, labelarr); + // We want to lower next instruction at skip+2, but this instruction is only 1 long, so we need to add 1 + skip = emitInstFastCall(build, pc, i, next) + 1; break; case LOP_FASTCALL1: - skip = emitInstFastCall1(build, pc, i, labelarr); + // We want to lower next instruction at skip+2, but this instruction is only 1 long, so we need to add 1 + skip = emitInstFastCall1(build, pc, i, next) + 1; break; case LOP_FASTCALL2: - skip = emitInstFastCall2(build, pc, i, labelarr); + skip = emitInstFastCall2(build, pc, i, next); break; case LOP_FASTCALL2K: - skip = emitInstFastCall2K(build, pc, i, labelarr); + skip = emitInstFastCall2K(build, pc, i, next); break; case LOP_FORNPREP: - emitInstForNPrep(build, pc, i, labelarr); + emitInstForNPrep(build, pc, i, labelarr[i + 1 + LUAU_INSN_D(*pc)]); break; case LOP_FORNLOOP: - emitInstForNLoop(build, pc, i, labelarr); + emitInstForNLoop(build, pc, i, labelarr[i + 1 + LUAU_INSN_D(*pc)]); break; case LOP_FORGLOOP: - emitinstForGLoop(build, pc, i, labelarr, fallback); + emitinstForGLoop(build, pc, i, labelarr[i + 1 + LUAU_INSN_D(*pc)], next, fallback); break; case LOP_FORGPREP_NEXT: - emitInstForGPrepNext(build, pc, i, labelarr, fallback); + emitInstForGPrepNext(build, pc, labelarr[i + 1 + LUAU_INSN_D(*pc)], fallback); break; case LOP_FORGPREP_INEXT: - emitInstForGPrepInext(build, pc, i, labelarr, fallback); + emitInstForGPrepInext(build, pc, labelarr[i + 1 + LUAU_INSN_D(*pc)], fallback); break; case LOP_AND: emitInstAnd(build, pc); @@ -266,7 +268,7 @@ static int emitInst(AssemblyBuilderX64& build, NativeState& data, ModuleHelpers& emitInstGetImport(build, pc, fallback); break; case LOP_CONCAT: - emitInstConcat(build, pc, i, labelarr); + emitInstConcat(build, pc, i, next); break; default: emitFallback(build, data, op, i); @@ -281,7 +283,8 @@ static void emitInstFallback(AssemblyBuilderX64& build, NativeState& data, LuauO switch (op) { case LOP_GETIMPORT: - emitInstGetImportFallback(build, pc, i); + emitSetSavedPc(build, i + 1); + emitInstGetImportFallback(build, LUAU_INSN_A(*pc), pc[1]); break; case LOP_GETTABLE: emitInstGetTableFallback(build, pc, i); @@ -356,11 +359,11 @@ static void emitInstFallback(AssemblyBuilderX64& build, NativeState& data, LuauO emitInstLengthFallback(build, pc, i); break; case LOP_FORGLOOP: - emitinstForGLoopFallback(build, pc, i, labelarr); + emitinstForGLoopFallback(build, pc, i, labelarr[i + 1 + LUAU_INSN_D(*pc)]); break; case LOP_FORGPREP_NEXT: case LOP_FORGPREP_INEXT: - emitInstForGPrepXnextFallback(build, pc, i, labelarr); + emitInstForGPrepXnextFallback(build, pc, i, labelarr[i + 1 + LUAU_INSN_D(*pc)]); break; case LOP_GETGLOBAL: // TODO: luaV_gettable + cachedslot update instead of full fallback @@ -430,7 +433,9 @@ static NativeProto* assembleFunction(AssemblyBuilderX64& build, NativeState& dat if (options.annotator) options.annotator(options.annotatorContext, build.text, proto->bytecodeid, i); - int skip = emitInst(build, data, helpers, proto, op, pc, i, instLabels.data(), instFallbacks[i]); + Label& next = nexti < proto->sizecode ? instLabels[nexti] : start; // Last instruction can't use 'next' label + + int skip = emitInst(build, data, helpers, proto, op, pc, i, instLabels.data(), next, instFallbacks[i]); if (skip != 0) instOutlines.push_back({nexti, skip}); @@ -454,15 +459,20 @@ static NativeProto* assembleFunction(AssemblyBuilderX64& build, NativeState& dat const Instruction* pc = &proto->code[i]; LuauOpcode op = LuauOpcode(LUAU_INSN_OP(*pc)); + int nexti = i + getOpLength(op); + LUAU_ASSERT(nexti <= proto->sizecode); + build.setLabel(instLabels[i]); if (options.annotator && !options.skipOutlinedCode) options.annotator(options.annotatorContext, build.text, proto->bytecodeid, i); - int skip = emitInst(build, data, helpers, proto, op, pc, i, instLabels.data(), instFallbacks[i]); + Label& next = nexti < proto->sizecode ? instLabels[nexti] : start; // Last instruction can't use 'next' label + + int skip = emitInst(build, data, helpers, proto, op, pc, i, instLabels.data(), next, instFallbacks[i]); LUAU_ASSERT(skip == 0); - i += getOpLength(op); + i = nexti; } if (i < proto->sizecode) diff --git a/CodeGen/src/EmitCommonX64.cpp b/CodeGen/src/EmitCommonX64.cpp index cbaa84948..fe258ff83 100644 --- a/CodeGen/src/EmitCommonX64.cpp +++ b/CodeGen/src/EmitCommonX64.cpp @@ -61,10 +61,8 @@ void jumpOnNumberCmp(AssemblyBuilderX64& build, RegisterX64 tmp, OperandX64 lhs, } } -void jumpOnAnyCmpFallback(AssemblyBuilderX64& build, int ra, int rb, ConditionX64 cond, Label& label, int pcpos) +void jumpOnAnyCmpFallback(AssemblyBuilderX64& build, int ra, int rb, ConditionX64 cond, Label& label) { - emitSetSavedPc(build, pcpos + 1); - build.mov(rArg1, rState); build.lea(rArg2, luauRegAddress(ra)); build.lea(rArg3, luauRegAddress(rb)); @@ -85,10 +83,8 @@ void jumpOnAnyCmpFallback(AssemblyBuilderX64& build, int ra, int rb, ConditionX6 label); } -RegisterX64 getTableNodeAtCachedSlot(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 table, int pcpos) +void getTableNodeAtCachedSlot(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 node, RegisterX64 table, int pcpos) { - RegisterX64 node = rdx; - LUAU_ASSERT(tmp != node); LUAU_ASSERT(table != node); @@ -102,16 +98,12 @@ RegisterX64 getTableNodeAtCachedSlot(AssemblyBuilderX64& build, RegisterX64 tmp, // LuaNode* n = &h->node[slot]; build.shl(dwordReg(tmp), kLuaNodeSizeLog2); build.add(node, tmp); - - return node; } -void convertNumberToIndexOrJump(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 numd, RegisterX64 numi, int ri, Label& label) +void convertNumberToIndexOrJump(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 numd, RegisterX64 numi, Label& label) { LUAU_ASSERT(numi.size == SizeX64::dword); - build.vmovsd(numd, luauRegValue(ri)); - // Convert to integer, NaN is converted into 0x80000000 build.vcvttsd2si(numi, numd); @@ -124,10 +116,8 @@ void convertNumberToIndexOrJump(AssemblyBuilderX64& build, RegisterX64 tmp, Regi build.jcc(ConditionX64::NotZero, label); } -void callArithHelper(AssemblyBuilderX64& build, int ra, int rb, OperandX64 c, int pcpos, TMS tm) +void callArithHelper(AssemblyBuilderX64& build, int ra, int rb, OperandX64 c, TMS tm) { - emitSetSavedPc(build, pcpos + 1); - if (build.abi == ABIX64::Windows) build.mov(sArg5, tm); else @@ -142,10 +132,8 @@ void callArithHelper(AssemblyBuilderX64& build, int ra, int rb, OperandX64 c, in emitUpdateBase(build); } -void callLengthHelper(AssemblyBuilderX64& build, int ra, int rb, int pcpos) +void callLengthHelper(AssemblyBuilderX64& build, int ra, int rb) { - emitSetSavedPc(build, pcpos + 1); - build.mov(rArg1, rState); build.lea(rArg2, luauRegAddress(ra)); build.lea(rArg3, luauRegAddress(rb)); @@ -154,10 +142,8 @@ void callLengthHelper(AssemblyBuilderX64& build, int ra, int rb, int pcpos) emitUpdateBase(build); } -void callPrepareForN(AssemblyBuilderX64& build, int limit, int step, int init, int pcpos) +void callPrepareForN(AssemblyBuilderX64& build, int limit, int step, int init) { - emitSetSavedPc(build, pcpos + 1); - build.mov(rArg1, rState); build.lea(rArg2, luauRegAddress(limit)); build.lea(rArg3, luauRegAddress(step)); @@ -165,10 +151,8 @@ void callPrepareForN(AssemblyBuilderX64& build, int limit, int step, int init, i build.call(qword[rNativeContext + offsetof(NativeContext, luaV_prepareFORN)]); } -void callGetTable(AssemblyBuilderX64& build, int rb, OperandX64 c, int ra, int pcpos) +void callGetTable(AssemblyBuilderX64& build, int rb, OperandX64 c, int ra) { - emitSetSavedPc(build, pcpos + 1); - build.mov(rArg1, rState); build.lea(rArg2, luauRegAddress(rb)); build.lea(rArg3, c); @@ -178,10 +162,8 @@ void callGetTable(AssemblyBuilderX64& build, int rb, OperandX64 c, int ra, int p emitUpdateBase(build); } -void callSetTable(AssemblyBuilderX64& build, int rb, OperandX64 c, int ra, int pcpos) +void callSetTable(AssemblyBuilderX64& build, int rb, OperandX64 c, int ra) { - emitSetSavedPc(build, pcpos + 1); - build.mov(rArg1, rState); build.lea(rArg2, luauRegAddress(rb)); build.lea(rArg3, c); diff --git a/CodeGen/src/EmitCommonX64.h b/CodeGen/src/EmitCommonX64.h index 615448551..238a0ed42 100644 --- a/CodeGen/src/EmitCommonX64.h +++ b/CodeGen/src/EmitCommonX64.h @@ -99,7 +99,7 @@ inline OperandX64 luauRegTag(int ri) return dword[rBase + ri * sizeof(TValue) + offsetof(TValue, tt)]; } -inline OperandX64 luauRegValueBoolean(int ri) +inline OperandX64 luauRegValueInt(int ri) { return dword[rBase + ri * sizeof(TValue) + offsetof(TValue, value)]; } @@ -174,7 +174,7 @@ inline void jumpIfFalsy(AssemblyBuilderX64& build, int ri, Label& target, Label& jumpIfTagIs(build, ri, LUA_TNIL, target); // false if nil jumpIfTagIsNot(build, ri, LUA_TBOOLEAN, fallthrough); // true if not nil or boolean - build.cmp(luauRegValueBoolean(ri), 0); + build.cmp(luauRegValueInt(ri), 0); build.jcc(ConditionX64::Equal, target); // true if boolean value is 'true' } @@ -184,7 +184,7 @@ inline void jumpIfTruthy(AssemblyBuilderX64& build, int ri, Label& target, Label jumpIfTagIs(build, ri, LUA_TNIL, fallthrough); // false if nil jumpIfTagIsNot(build, ri, LUA_TBOOLEAN, target); // true if not nil or boolean - build.cmp(luauRegValueBoolean(ri), 0); + build.cmp(luauRegValueInt(ri), 0); build.jcc(ConditionX64::NotEqual, target); // true if boolean value is 'true' } @@ -236,16 +236,16 @@ inline void jumpIfNodeKeyNotInExpectedSlot(AssemblyBuilderX64& build, RegisterX6 } void jumpOnNumberCmp(AssemblyBuilderX64& build, RegisterX64 tmp, OperandX64 lhs, OperandX64 rhs, ConditionX64 cond, Label& label); -void jumpOnAnyCmpFallback(AssemblyBuilderX64& build, int ra, int rb, ConditionX64 cond, Label& label, int pcpos); +void jumpOnAnyCmpFallback(AssemblyBuilderX64& build, int ra, int rb, ConditionX64 cond, Label& label); -RegisterX64 getTableNodeAtCachedSlot(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 table, int pcpos); -void convertNumberToIndexOrJump(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 numd, RegisterX64 numi, int ri, Label& label); +void getTableNodeAtCachedSlot(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 node, RegisterX64 table, int pcpos); +void convertNumberToIndexOrJump(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 numd, RegisterX64 numi, Label& label); -void callArithHelper(AssemblyBuilderX64& build, int ra, int rb, OperandX64 c, int pcpos, TMS tm); -void callLengthHelper(AssemblyBuilderX64& build, int ra, int rb, int pcpos); -void callPrepareForN(AssemblyBuilderX64& build, int limit, int step, int init, int pcpos); -void callGetTable(AssemblyBuilderX64& build, int rb, OperandX64 c, int ra, int pcpos); -void callSetTable(AssemblyBuilderX64& build, int rb, OperandX64 c, int ra, int pcpos); +void callArithHelper(AssemblyBuilderX64& build, int ra, int rb, OperandX64 c, TMS tm); +void callLengthHelper(AssemblyBuilderX64& build, int ra, int rb); +void callPrepareForN(AssemblyBuilderX64& build, int limit, int step, int init); +void callGetTable(AssemblyBuilderX64& build, int rb, OperandX64 c, int ra); +void callSetTable(AssemblyBuilderX64& build, int rb, OperandX64 c, int ra); void callBarrierTable(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 table, int ra, Label& skip); void callBarrierObject(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 object, int ra, Label& skip); void callBarrierTableFast(AssemblyBuilderX64& build, RegisterX64 table, Label& skip); diff --git a/CodeGen/src/EmitInstructionX64.cpp b/CodeGen/src/EmitInstructionX64.cpp index 160f0f6cc..abbdb65ca 100644 --- a/CodeGen/src/EmitInstructionX64.cpp +++ b/CodeGen/src/EmitInstructionX64.cpp @@ -69,7 +69,7 @@ void emitInstMove(AssemblyBuilderX64& build, const Instruction* pc) build.vmovups(luauReg(ra), xmm0); } -void emitInstCall(AssemblyBuilderX64& build, ModuleHelpers& helpers, const Instruction* pc, int pcpos, Label* labelarr) +void emitInstCall(AssemblyBuilderX64& build, ModuleHelpers& helpers, const Instruction* pc, int pcpos) { int ra = LUAU_INSN_A(*pc); int nparams = LUAU_INSN_B(*pc) - 1; @@ -222,7 +222,7 @@ void emitInstCall(AssemblyBuilderX64& build, ModuleHelpers& helpers, const Instr } } -void emitInstReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers, const Instruction* pc, int pcpos, Label* labelarr) +void emitInstReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers, const Instruction* pc, int pcpos) { emitInterrupt(build, pcpos); @@ -435,7 +435,8 @@ void emitInstJumpIfEqFallback(AssemblyBuilderX64& build, const Instruction* pc, { Label& target = labelarr[pcpos + 1 + LUAU_INSN_D(*pc)]; - jumpOnAnyCmpFallback(build, LUAU_INSN_A(*pc), pc[1], not_ ? ConditionX64::NotEqual : ConditionX64::Equal, target, pcpos); + emitSetSavedPc(build, pcpos + 1); + jumpOnAnyCmpFallback(build, LUAU_INSN_A(*pc), pc[1], not_ ? ConditionX64::NotEqual : ConditionX64::Equal, target); } void emitInstJumpIfCond(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, ConditionX64 cond, Label& fallback) @@ -456,7 +457,8 @@ void emitInstJumpIfCondFallback(AssemblyBuilderX64& build, const Instruction* pc { Label& target = labelarr[pcpos + 1 + LUAU_INSN_D(*pc)]; - jumpOnAnyCmpFallback(build, LUAU_INSN_A(*pc), pc[1], cond, target, pcpos); + emitSetSavedPc(build, pcpos + 1); + jumpOnAnyCmpFallback(build, LUAU_INSN_A(*pc), pc[1], cond, target); } void emitInstJumpX(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr) @@ -488,7 +490,7 @@ void emitInstJumpxEqB(AssemblyBuilderX64& build, const Instruction* pc, int pcpo jumpIfTagIsNot(build, ra, LUA_TBOOLEAN, not_ ? target : exit); - build.test(luauRegValueBoolean(ra), 1); + build.test(luauRegValueInt(ra), 1); build.jcc((aux & 0x1) ^ not_ ? ConditionX64::NotZero : ConditionX64::Zero, target); } @@ -534,7 +536,7 @@ void emitInstJumpxEqS(AssemblyBuilderX64& build, const Instruction* pc, int pcpo build.jcc(not_ ? ConditionX64::NotEqual : ConditionX64::Equal, target); } -static void emitInstBinaryNumeric(AssemblyBuilderX64& build, int ra, int rb, int rc, OperandX64 opc, int pcpos, TMS tm, Label& fallback) +static void emitInstBinaryNumeric(AssemblyBuilderX64& build, int ra, int rb, int rc, OperandX64 opc, TMS tm, Label& fallback) { jumpIfTagIsNot(build, rb, LUA_TNUMBER, fallback); @@ -580,27 +582,29 @@ static void emitInstBinaryNumeric(AssemblyBuilderX64& build, int ra, int rb, int build.mov(luauRegTag(ra), LUA_TNUMBER); } -void emitInstBinary(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, TMS tm, Label& fallback) +void emitInstBinary(AssemblyBuilderX64& build, const Instruction* pc, TMS tm, Label& fallback) { - emitInstBinaryNumeric(build, LUAU_INSN_A(*pc), LUAU_INSN_B(*pc), LUAU_INSN_C(*pc), luauRegValue(LUAU_INSN_C(*pc)), pcpos, tm, fallback); + emitInstBinaryNumeric(build, LUAU_INSN_A(*pc), LUAU_INSN_B(*pc), LUAU_INSN_C(*pc), luauRegValue(LUAU_INSN_C(*pc)), tm, fallback); } void emitInstBinaryFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, TMS tm) { - callArithHelper(build, LUAU_INSN_A(*pc), LUAU_INSN_B(*pc), luauRegAddress(LUAU_INSN_C(*pc)), pcpos, tm); + emitSetSavedPc(build, pcpos + 1); + callArithHelper(build, LUAU_INSN_A(*pc), LUAU_INSN_B(*pc), luauRegAddress(LUAU_INSN_C(*pc)), tm); } -void emitInstBinaryK(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, TMS tm, Label& fallback) +void emitInstBinaryK(AssemblyBuilderX64& build, const Instruction* pc, TMS tm, Label& fallback) { - emitInstBinaryNumeric(build, LUAU_INSN_A(*pc), LUAU_INSN_B(*pc), -1, luauConstantValue(LUAU_INSN_C(*pc)), pcpos, tm, fallback); + emitInstBinaryNumeric(build, LUAU_INSN_A(*pc), LUAU_INSN_B(*pc), -1, luauConstantValue(LUAU_INSN_C(*pc)), tm, fallback); } void emitInstBinaryKFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, TMS tm) { - callArithHelper(build, LUAU_INSN_A(*pc), LUAU_INSN_B(*pc), luauConstantAddress(LUAU_INSN_C(*pc)), pcpos, tm); + emitSetSavedPc(build, pcpos + 1); + callArithHelper(build, LUAU_INSN_A(*pc), LUAU_INSN_B(*pc), luauConstantAddress(LUAU_INSN_C(*pc)), tm); } -void emitInstPowK(AssemblyBuilderX64& build, const Instruction* pc, const TValue* k, int pcpos, Label& fallback) +void emitInstPowK(AssemblyBuilderX64& build, const Instruction* pc, const TValue* k, Label& fallback) { int ra = LUAU_INSN_A(*pc); int rb = LUAU_INSN_B(*pc); @@ -647,17 +651,17 @@ void emitInstNot(AssemblyBuilderX64& build, const Instruction* pc) jumpIfFalsy(build, rb, saveone, savezero); build.setLabel(savezero); - build.mov(luauRegValueBoolean(ra), 0); + build.mov(luauRegValueInt(ra), 0); build.jmp(exit); build.setLabel(saveone); - build.mov(luauRegValueBoolean(ra), 1); + build.mov(luauRegValueInt(ra), 1); build.setLabel(exit); build.mov(luauRegTag(ra), LUA_TBOOLEAN); } -void emitInstMinus(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback) +void emitInstMinus(AssemblyBuilderX64& build, const Instruction* pc, Label& fallback) { int ra = LUAU_INSN_A(*pc); int rb = LUAU_INSN_B(*pc); @@ -675,10 +679,11 @@ void emitInstMinus(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, void emitInstMinusFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos) { - callArithHelper(build, LUAU_INSN_A(*pc), LUAU_INSN_B(*pc), luauRegAddress(LUAU_INSN_B(*pc)), pcpos, TM_UNM); + emitSetSavedPc(build, pcpos + 1); + callArithHelper(build, LUAU_INSN_A(*pc), LUAU_INSN_B(*pc), luauRegAddress(LUAU_INSN_B(*pc)), TM_UNM); } -void emitInstLength(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback) +void emitInstLength(AssemblyBuilderX64& build, const Instruction* pc, Label& fallback) { int ra = LUAU_INSN_A(*pc); int rb = LUAU_INSN_B(*pc); @@ -699,35 +704,32 @@ void emitInstLength(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, void emitInstLengthFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos) { - callLengthHelper(build, LUAU_INSN_A(*pc), LUAU_INSN_B(*pc), pcpos); + emitSetSavedPc(build, pcpos + 1); + callLengthHelper(build, LUAU_INSN_A(*pc), LUAU_INSN_B(*pc)); } -void emitInstNewTable(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr) +void emitInstNewTable(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& next) { int ra = LUAU_INSN_A(*pc); int b = LUAU_INSN_B(*pc); uint32_t aux = pc[1]; - Label& exit = labelarr[pcpos + 2]; - emitSetSavedPc(build, pcpos + 1); build.mov(rArg1, rState); build.mov(dwordReg(rArg2), aux); - build.mov(dwordReg(rArg3), 1 << (b - 1)); + build.mov(dwordReg(rArg3), b == 0 ? 0 : 1 << (b - 1)); build.call(qword[rNativeContext + offsetof(NativeContext, luaH_new)]); build.mov(luauRegValue(ra), rax); build.mov(luauRegTag(ra), LUA_TTABLE); - callCheckGc(build, pcpos, /* savepc = */ false, exit); + callCheckGc(build, pcpos, /* savepc = */ false, next); } -void emitInstDupTable(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr) +void emitInstDupTable(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& next) { int ra = LUAU_INSN_A(*pc); - Label& exit = labelarr[pcpos + 1]; - emitSetSavedPc(build, pcpos + 1); build.mov(rArg1, rState); @@ -736,18 +738,16 @@ void emitInstDupTable(AssemblyBuilderX64& build, const Instruction* pc, int pcpo build.mov(luauRegValue(ra), rax); build.mov(luauRegTag(ra), LUA_TTABLE); - callCheckGc(build, pcpos, /* savepc= */ false, exit); + callCheckGc(build, pcpos, /* savepc= */ false, next); } -void emitInstSetList(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr) +void emitInstSetList(AssemblyBuilderX64& build, const Instruction* pc, Label& next) { int ra = LUAU_INSN_A(*pc); int rb = LUAU_INSN_B(*pc); int c = LUAU_INSN_C(*pc) - 1; uint32_t index = pc[1]; - Label& exit = labelarr[pcpos + 2]; - OperandX64 last = index + c - 1; // Using non-volatile 'rbx' for dynamic 'c' value (for LUA_MULTRET) to skip later recomputation @@ -842,10 +842,10 @@ void emitInstSetList(AssemblyBuilderX64& build, const Instruction* pc, int pcpos build.setLabel(endLoop); } - callBarrierTableFast(build, table, exit); + callBarrierTableFast(build, table, next); } -void emitInstGetUpval(AssemblyBuilderX64& build, const Instruction* pc, int pcpos) +void emitInstGetUpval(AssemblyBuilderX64& build, const Instruction* pc) { int ra = LUAU_INSN_A(*pc); int up = LUAU_INSN_B(*pc); @@ -869,7 +869,7 @@ void emitInstGetUpval(AssemblyBuilderX64& build, const Instruction* pc, int pcpo build.vmovups(luauReg(ra), xmm0); } -void emitInstSetUpval(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr) +void emitInstSetUpval(AssemblyBuilderX64& build, const Instruction* pc, Label& next) { int ra = LUAU_INSN_A(*pc); int up = LUAU_INSN_B(*pc); @@ -884,32 +884,30 @@ void emitInstSetUpval(AssemblyBuilderX64& build, const Instruction* pc, int pcpo build.vmovups(xmm0, luauReg(ra)); build.vmovups(xmmword[tmp], xmm0); - callBarrierObject(build, tmp, upval, ra, labelarr[pcpos + 1]); + callBarrierObject(build, tmp, upval, ra, next); } -void emitInstCloseUpvals(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr) +void emitInstCloseUpvals(AssemblyBuilderX64& build, const Instruction* pc, Label& next) { int ra = LUAU_INSN_A(*pc); - Label& skip = labelarr[pcpos + 1]; - // L->openupval != 0 build.mov(rax, qword[rState + offsetof(lua_State, openupval)]); build.test(rax, rax); - build.jcc(ConditionX64::Zero, skip); + build.jcc(ConditionX64::Zero, next); // ra <= L->openuval->v build.lea(rcx, addr[rBase + ra * sizeof(TValue)]); build.cmp(rcx, qword[rax + offsetof(UpVal, v)]); - build.jcc(ConditionX64::Above, skip); + build.jcc(ConditionX64::Above, next); build.mov(rArg2, rcx); build.mov(rArg1, rState); build.call(qword[rNativeContext + offsetof(NativeContext, luaF_close)]); } -static int emitInstFastCallN(AssemblyBuilderX64& build, const Instruction* pc, bool customParams, int customParamCount, OperandX64 customArgs, - int pcpos, int instLen, Label* labelarr) +static int emitInstFastCallN( + AssemblyBuilderX64& build, const Instruction* pc, bool customParams, int customParamCount, OperandX64 customArgs, int pcpos, Label& fallback) { int bfid = LUAU_INSN_A(*pc); int skip = LUAU_INSN_C(*pc); @@ -923,11 +921,9 @@ static int emitInstFastCallN(AssemblyBuilderX64& build, const Instruction* pc, b int arg = customParams ? LUAU_INSN_B(*pc) : ra + 1; OperandX64 args = customParams ? customArgs : luauRegAddress(ra + 2); - Label& exit = labelarr[pcpos + instLen]; - - jumpIfUnsafeEnv(build, rax, exit); + jumpIfUnsafeEnv(build, rax, fallback); - BuiltinImplResult br = emitBuiltin(build, LuauBuiltinFunction(bfid), nparams, ra, arg, args, nresults, exit); + BuiltinImplResult br = emitBuiltin(build, LuauBuiltinFunction(bfid), nparams, ra, arg, args, nresults, fallback); if (br.type == BuiltinImplType::UsesFallback) { @@ -945,7 +941,7 @@ static int emitInstFastCallN(AssemblyBuilderX64& build, const Instruction* pc, b build.mov(qword[rState + offsetof(lua_State, top)], rax); } - return skip + 2 - instLen; // Return fallback instruction sequence length + return skip; // Return fallback instruction sequence length } // TODO: we can skip saving pc for some well-behaved builtins which we didn't inline @@ -996,8 +992,8 @@ static int emitInstFastCallN(AssemblyBuilderX64& build, const Instruction* pc, b build.call(rax); - build.test(eax, eax); // test here will set SF=1 for a negative number and it always sets OF to 0 - build.jcc(ConditionX64::Less, exit); // jl jumps if SF != OF + build.test(eax, eax); // test here will set SF=1 for a negative number and it always sets OF to 0 + build.jcc(ConditionX64::Less, fallback); // jl jumps if SF != OF if (nresults == LUA_MULTRET) { @@ -1014,35 +1010,33 @@ static int emitInstFastCallN(AssemblyBuilderX64& build, const Instruction* pc, b build.mov(qword[rState + offsetof(lua_State, top)], rax); } - return skip + 2 - instLen; // Return fallback instruction sequence length + return skip; // Return fallback instruction sequence length } -int emitInstFastCall1(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr) +int emitInstFastCall1(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback) { - return emitInstFastCallN(build, pc, /* customParams */ true, /* customParamCount */ 1, /* customArgs */ 0, pcpos, /* instLen */ 1, labelarr); + return emitInstFastCallN(build, pc, /* customParams */ true, /* customParamCount */ 1, /* customArgs */ 0, pcpos, fallback); } -int emitInstFastCall2(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr) +int emitInstFastCall2(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback) { - return emitInstFastCallN( - build, pc, /* customParams */ true, /* customParamCount */ 2, /* customArgs */ luauRegAddress(pc[1]), pcpos, /* instLen */ 2, labelarr); + return emitInstFastCallN(build, pc, /* customParams */ true, /* customParamCount */ 2, /* customArgs */ luauRegAddress(pc[1]), pcpos, fallback); } -int emitInstFastCall2K(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr) +int emitInstFastCall2K(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback) { return emitInstFastCallN( - build, pc, /* customParams */ true, /* customParamCount */ 2, /* customArgs */ luauConstantAddress(pc[1]), pcpos, /* instLen */ 2, labelarr); + build, pc, /* customParams */ true, /* customParamCount */ 2, /* customArgs */ luauConstantAddress(pc[1]), pcpos, fallback); } -int emitInstFastCall(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr) +int emitInstFastCall(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback) { - return emitInstFastCallN(build, pc, /* customParams */ false, /* customParamCount */ 0, /* customArgs */ 0, pcpos, /* instLen */ 1, labelarr); + return emitInstFastCallN(build, pc, /* customParams */ false, /* customParamCount */ 0, /* customArgs */ 0, pcpos, fallback); } -void emitInstForNPrep(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr) +void emitInstForNPrep(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& loopExit) { int ra = LUAU_INSN_A(*pc); - Label& loopExit = labelarr[pcpos + 1 + LUAU_INSN_D(*pc)]; Label tryConvert, exit; @@ -1080,18 +1074,18 @@ void emitInstForNPrep(AssemblyBuilderX64& build, const Instruction* pc, int pcpo // TOOD: place at the end of the function build.setLabel(tryConvert); - callPrepareForN(build, ra + 0, ra + 1, ra + 2, pcpos); + emitSetSavedPc(build, pcpos + 1); + callPrepareForN(build, ra + 0, ra + 1, ra + 2); build.jmp(retry); build.setLabel(exit); } -void emitInstForNLoop(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr) +void emitInstForNLoop(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& loopRepeat) { emitInterrupt(build, pcpos); int ra = LUAU_INSN_A(*pc); - Label& loopRepeat = labelarr[pcpos + 1 + LUAU_INSN_D(*pc)]; RegisterX64 limit = xmm0; RegisterX64 step = xmm1; @@ -1121,14 +1115,11 @@ void emitInstForNLoop(AssemblyBuilderX64& build, const Instruction* pc, int pcpo build.setLabel(exit); } -void emitinstForGLoop(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, Label& fallback) +void emitinstForGLoop(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& loopRepeat, Label& loopExit, Label& fallback) { int ra = LUAU_INSN_A(*pc); int aux = pc[1]; - Label& loopRepeat = labelarr[pcpos + 1 + LUAU_INSN_D(*pc)]; - Label& exit = labelarr[pcpos + 2]; - emitInterrupt(build, pcpos); // fast-path: builtin table iteration @@ -1160,13 +1151,13 @@ void emitinstForGLoop(AssemblyBuilderX64& build, const Instruction* pc, int pcpo // while (unsigned(index) < unsigned(sizearray)) Label arrayLoop = build.setLabel(); build.cmp(dwordReg(index), dword[table + offsetof(Table, sizearray)]); - build.jcc(ConditionX64::NotBelow, isIpairsIter ? exit : skipArray); + build.jcc(ConditionX64::NotBelow, isIpairsIter ? loopExit : skipArray); // If element is nil, we increment the index; if it's not, we still need 'index + 1' inside build.inc(index); build.cmp(dword[elemPtr + offsetof(TValue, tt)], LUA_TNIL); - build.jcc(ConditionX64::Equal, isIpairsIter ? exit : skipArrayNil); + build.jcc(ConditionX64::Equal, isIpairsIter ? loopExit : skipArrayNil); // setpvalue(ra + 2, reinterpret_cast(uintptr_t(index + 1))); build.mov(luauRegValue(ra + 2), index); @@ -1202,13 +1193,11 @@ void emitinstForGLoop(AssemblyBuilderX64& build, const Instruction* pc, int pcpo } } -void emitinstForGLoopFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr) +void emitinstForGLoopFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& loopRepeat) { int ra = LUAU_INSN_A(*pc); int aux = pc[1]; - Label& loopRepeat = labelarr[pcpos + 1 + LUAU_INSN_D(*pc)]; - emitSetSavedPc(build, pcpos + 1); build.mov(rArg1, rState); @@ -1220,12 +1209,10 @@ void emitinstForGLoopFallback(AssemblyBuilderX64& build, const Instruction* pc, build.jcc(ConditionX64::NotZero, loopRepeat); } -void emitInstForGPrepNext(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, Label& fallback) +void emitInstForGPrepNext(AssemblyBuilderX64& build, const Instruction* pc, Label& target, Label& fallback) { int ra = LUAU_INSN_A(*pc); - Label& target = labelarr[pcpos + 1 + LUAU_INSN_D(*pc)]; - // fast-path: pairs/next jumpIfUnsafeEnv(build, rax, fallback); jumpIfTagIsNot(build, ra + 1, LUA_TTABLE, fallback); @@ -1240,12 +1227,10 @@ void emitInstForGPrepNext(AssemblyBuilderX64& build, const Instruction* pc, int build.jmp(target); } -void emitInstForGPrepInext(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, Label& fallback) +void emitInstForGPrepInext(AssemblyBuilderX64& build, const Instruction* pc, Label& target, Label& fallback) { int ra = LUAU_INSN_A(*pc); - Label& target = labelarr[pcpos + 1 + LUAU_INSN_D(*pc)]; - // fast-path: ipairs/inext jumpIfUnsafeEnv(build, rax, fallback); jumpIfTagIsNot(build, ra + 1, LUA_TTABLE, fallback); @@ -1264,12 +1249,10 @@ void emitInstForGPrepInext(AssemblyBuilderX64& build, const Instruction* pc, int build.jmp(target); } -void emitInstForGPrepXnextFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr) +void emitInstForGPrepXnextFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& target) { int ra = LUAU_INSN_A(*pc); - Label& target = labelarr[pcpos + 1 + LUAU_INSN_D(*pc)]; - build.mov(rArg1, rState); build.lea(rArg2, luauRegAddress(ra)); build.mov(dwordReg(rArg3), pcpos + 1); @@ -1353,7 +1336,7 @@ void emitInstOrK(AssemblyBuilderX64& build, const Instruction* pc) emitInstOrX(build, LUAU_INSN_A(*pc), LUAU_INSN_B(*pc), luauConstant(LUAU_INSN_C(*pc))); } -void emitInstGetTableN(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback) +void emitInstGetTableN(AssemblyBuilderX64& build, const Instruction* pc, Label& fallback) { int ra = LUAU_INSN_A(*pc); int rb = LUAU_INSN_B(*pc); @@ -1376,12 +1359,14 @@ void emitInstGetTableN(AssemblyBuilderX64& build, const Instruction* pc, int pcp void emitInstGetTableNFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos) { + emitSetSavedPc(build, pcpos + 1); + TValue n; setnvalue(&n, LUAU_INSN_C(*pc) + 1); - callGetTable(build, LUAU_INSN_B(*pc), build.bytes(&n, sizeof(n)), LUAU_INSN_A(*pc), pcpos); + callGetTable(build, LUAU_INSN_B(*pc), build.bytes(&n, sizeof(n)), LUAU_INSN_A(*pc)); } -void emitInstSetTableN(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, Label& fallback) +void emitInstSetTableN(AssemblyBuilderX64& build, const Instruction* pc, Label& next, Label& fallback) { int ra = LUAU_INSN_A(*pc); int rb = LUAU_INSN_B(*pc); @@ -1404,17 +1389,19 @@ void emitInstSetTableN(AssemblyBuilderX64& build, const Instruction* pc, int pcp build.vmovups(xmm0, luauReg(ra)); build.vmovups(xmmword[rax + c * sizeof(TValue)], xmm0); - callBarrierTable(build, rax, table, ra, labelarr[pcpos + 1]); + callBarrierTable(build, rax, table, ra, next); } void emitInstSetTableNFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos) { + emitSetSavedPc(build, pcpos + 1); + TValue n; setnvalue(&n, LUAU_INSN_C(*pc) + 1); - callSetTable(build, LUAU_INSN_B(*pc), build.bytes(&n, sizeof(n)), LUAU_INSN_A(*pc), pcpos); + callSetTable(build, LUAU_INSN_B(*pc), build.bytes(&n, sizeof(n)), LUAU_INSN_A(*pc)); } -void emitInstGetTable(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback) +void emitInstGetTable(AssemblyBuilderX64& build, const Instruction* pc, Label& fallback) { int ra = LUAU_INSN_A(*pc); int rb = LUAU_INSN_B(*pc); @@ -1427,29 +1414,33 @@ void emitInstGetTable(AssemblyBuilderX64& build, const Instruction* pc, int pcpo RegisterX64 table = rcx; build.mov(table, luauRegValue(rb)); - convertNumberToIndexOrJump(build, xmm1, xmm0, eax, rc, fallback); + RegisterX64 intIndex = eax; + RegisterX64 fpIndex = xmm0; + build.vmovsd(fpIndex, luauRegValue(rc)); + convertNumberToIndexOrJump(build, xmm1, fpIndex, intIndex, fallback); // index - 1 - build.dec(eax); + build.dec(intIndex); // unsigned(index - 1) < unsigned(h->sizearray) - build.cmp(dword[table + offsetof(Table, sizearray)], eax); + build.cmp(dword[table + offsetof(Table, sizearray)], intIndex); build.jcc(ConditionX64::BelowEqual, fallback); jumpIfMetatablePresent(build, table, fallback); // setobj2s(L, ra, &h->array[unsigned(index - 1)]); build.mov(rdx, qword[table + offsetof(Table, array)]); - build.shl(eax, kTValueSizeLog2); + build.shl(intIndex, kTValueSizeLog2); setLuauReg(build, xmm0, ra, xmmword[rdx + rax]); } void emitInstGetTableFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos) { - callGetTable(build, LUAU_INSN_B(*pc), luauRegAddress(LUAU_INSN_C(*pc)), LUAU_INSN_A(*pc), pcpos); + emitSetSavedPc(build, pcpos + 1); + callGetTable(build, LUAU_INSN_B(*pc), luauRegAddress(LUAU_INSN_C(*pc)), LUAU_INSN_A(*pc)); } -void emitInstSetTable(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, Label& fallback) +void emitInstSetTable(AssemblyBuilderX64& build, const Instruction* pc, Label& next, Label& fallback) { int ra = LUAU_INSN_A(*pc); int rb = LUAU_INSN_B(*pc); @@ -1462,13 +1453,16 @@ void emitInstSetTable(AssemblyBuilderX64& build, const Instruction* pc, int pcpo RegisterX64 table = rcx; build.mov(table, luauRegValue(rb)); - convertNumberToIndexOrJump(build, xmm1, xmm0, eax, rc, fallback); + RegisterX64 intIndex = eax; + RegisterX64 fpIndex = xmm0; + build.vmovsd(fpIndex, luauRegValue(rc)); + convertNumberToIndexOrJump(build, xmm1, fpIndex, intIndex, fallback); // index - 1 - build.dec(eax); + build.dec(intIndex); // unsigned(index - 1) < unsigned(h->sizearray) - build.cmp(dword[table + offsetof(Table, sizearray)], eax); + build.cmp(dword[table + offsetof(Table, sizearray)], intIndex); build.jcc(ConditionX64::BelowEqual, fallback); jumpIfMetatablePresent(build, table, fallback); @@ -1476,16 +1470,17 @@ void emitInstSetTable(AssemblyBuilderX64& build, const Instruction* pc, int pcpo // setobj2t(L, &h->array[unsigned(index - 1)], ra); build.mov(rdx, qword[table + offsetof(Table, array)]); - build.shl(eax, kTValueSizeLog2); + build.shl(intIndex, kTValueSizeLog2); build.vmovups(xmm0, luauReg(ra)); build.vmovups(xmmword[rdx + rax], xmm0); - callBarrierTable(build, rdx, table, ra, labelarr[pcpos + 1]); + callBarrierTable(build, rdx, table, ra, next); } void emitInstSetTableFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos) { - callSetTable(build, LUAU_INSN_B(*pc), luauRegAddress(LUAU_INSN_C(*pc)), LUAU_INSN_A(*pc), pcpos); + emitSetSavedPc(build, pcpos + 1); + callSetTable(build, LUAU_INSN_B(*pc), luauRegAddress(LUAU_INSN_C(*pc)), LUAU_INSN_A(*pc)); } void emitInstGetImport(AssemblyBuilderX64& build, const Instruction* pc, Label& fallback) @@ -1504,13 +1499,8 @@ void emitInstGetImport(AssemblyBuilderX64& build, const Instruction* pc, Label& build.vmovups(luauReg(ra), xmm0); } -void emitInstGetImportFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos) +void emitInstGetImportFallback(AssemblyBuilderX64& build, int ra, uint32_t aux) { - int ra = LUAU_INSN_A(*pc); - uint32_t aux = pc[1]; - - emitSetSavedPc(build, pcpos + 1); - build.mov(rax, sClosure); // luaV_getimport(L, cl->env, k, aux, /* propagatenil= */ false) @@ -1548,14 +1538,15 @@ void emitInstGetTableKS(AssemblyBuilderX64& build, const Instruction* pc, int pc RegisterX64 table = rcx; build.mov(table, luauRegValue(rb)); - RegisterX64 node = getTableNodeAtCachedSlot(build, rax, table, pcpos); + RegisterX64 node = rdx; + getTableNodeAtCachedSlot(build, rax, node, table, pcpos); jumpIfNodeKeyNotInExpectedSlot(build, rax, node, luauConstantValue(aux), fallback); setLuauReg(build, xmm0, ra, luauNodeValue(node)); } -void emitInstSetTableKS(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, Label& fallback) +void emitInstSetTableKS(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& next, Label& fallback) { int ra = LUAU_INSN_A(*pc); int rb = LUAU_INSN_B(*pc); @@ -1567,14 +1558,15 @@ void emitInstSetTableKS(AssemblyBuilderX64& build, const Instruction* pc, int pc build.mov(table, luauRegValue(rb)); // fast-path: set value at the expected slot - RegisterX64 node = getTableNodeAtCachedSlot(build, rax, table, pcpos); + RegisterX64 node = rdx; + getTableNodeAtCachedSlot(build, rax, node, table, pcpos); jumpIfNodeKeyNotInExpectedSlot(build, rax, node, luauConstantValue(aux), fallback); jumpIfTableIsReadOnly(build, table, fallback); setNodeValue(build, xmm0, luauNodeValue(node), ra); - callBarrierTable(build, rax, table, ra, labelarr[pcpos + 2]); + callBarrierTable(build, rax, table, ra, next); } void emitInstGetGlobal(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback) @@ -1585,14 +1577,15 @@ void emitInstGetGlobal(AssemblyBuilderX64& build, const Instruction* pc, int pcp RegisterX64 table = rcx; build.mov(rax, sClosure); build.mov(table, qword[rax + offsetof(Closure, env)]); - RegisterX64 node = getTableNodeAtCachedSlot(build, rax, table, pcpos); + RegisterX64 node = rdx; + getTableNodeAtCachedSlot(build, rax, node, table, pcpos); jumpIfNodeKeyNotInExpectedSlot(build, rax, node, luauConstantValue(aux), fallback); setLuauReg(build, xmm0, ra, luauNodeValue(node)); } -void emitInstSetGlobal(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, Label& fallback) +void emitInstSetGlobal(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& next, Label& fallback) { int ra = LUAU_INSN_A(*pc); uint32_t aux = pc[1]; @@ -1600,17 +1593,18 @@ void emitInstSetGlobal(AssemblyBuilderX64& build, const Instruction* pc, int pcp RegisterX64 table = rcx; build.mov(rax, sClosure); build.mov(table, qword[rax + offsetof(Closure, env)]); - RegisterX64 node = getTableNodeAtCachedSlot(build, rax, table, pcpos); + RegisterX64 node = rdx; + getTableNodeAtCachedSlot(build, rax, node, table, pcpos); jumpIfNodeKeyNotInExpectedSlot(build, rax, node, luauConstantValue(aux), fallback); jumpIfTableIsReadOnly(build, table, fallback); setNodeValue(build, xmm0, luauNodeValue(node), ra); - callBarrierTable(build, rax, table, ra, labelarr[pcpos + 2]); + callBarrierTable(build, rax, table, ra, next); } -void emitInstConcat(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr) +void emitInstConcat(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& next) { int ra = LUAU_INSN_A(*pc); int rb = LUAU_INSN_B(*pc); @@ -1630,7 +1624,7 @@ void emitInstConcat(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, build.vmovups(xmm0, luauReg(rb)); build.vmovups(luauReg(ra), xmm0); - callCheckGc(build, pcpos, /* savepc= */ false, labelarr[pcpos + 1]); + callCheckGc(build, pcpos, /* savepc= */ false, next); } } // namespace CodeGen diff --git a/CodeGen/src/EmitInstructionX64.h b/CodeGen/src/EmitInstructionX64.h index ae310acab..1ecb06d4f 100644 --- a/CodeGen/src/EmitInstructionX64.h +++ b/CodeGen/src/EmitInstructionX64.h @@ -24,8 +24,8 @@ void emitInstLoadN(AssemblyBuilderX64& build, const Instruction* pc); void emitInstLoadK(AssemblyBuilderX64& build, const Instruction* pc); void emitInstLoadKX(AssemblyBuilderX64& build, const Instruction* pc); void emitInstMove(AssemblyBuilderX64& build, const Instruction* pc); -void emitInstCall(AssemblyBuilderX64& build, ModuleHelpers& helpers, const Instruction* pc, int pcpos, Label* labelarr); -void emitInstReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers, const Instruction* pc, int pcpos, Label* labelarr); +void emitInstCall(AssemblyBuilderX64& build, ModuleHelpers& helpers, const Instruction* pc, int pcpos); +void emitInstReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers, const Instruction* pc, int pcpos); void emitInstJump(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr); void emitInstJumpBack(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr); void emitInstJumpIf(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, bool not_); @@ -38,52 +38,52 @@ void emitInstJumpxEqNil(AssemblyBuilderX64& build, const Instruction* pc, int pc void emitInstJumpxEqB(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr); void emitInstJumpxEqN(AssemblyBuilderX64& build, const Instruction* pc, const TValue* k, int pcpos, Label* labelarr); void emitInstJumpxEqS(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr); -void emitInstBinary(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, TMS tm, Label& fallback); +void emitInstBinary(AssemblyBuilderX64& build, const Instruction* pc, TMS tm, Label& fallback); void emitInstBinaryFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, TMS tm); -void emitInstBinaryK(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, TMS tm, Label& fallback); +void emitInstBinaryK(AssemblyBuilderX64& build, const Instruction* pc, TMS tm, Label& fallback); void emitInstBinaryKFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, TMS tm); -void emitInstPowK(AssemblyBuilderX64& build, const Instruction* pc, const TValue* k, int pcpos, Label& fallback); +void emitInstPowK(AssemblyBuilderX64& build, const Instruction* pc, const TValue* k, Label& fallback); void emitInstNot(AssemblyBuilderX64& build, const Instruction* pc); -void emitInstMinus(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback); +void emitInstMinus(AssemblyBuilderX64& build, const Instruction* pc, Label& fallback); void emitInstMinusFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos); -void emitInstLength(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback); +void emitInstLength(AssemblyBuilderX64& build, const Instruction* pc, Label& fallback); void emitInstLengthFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos); -void emitInstNewTable(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr); -void emitInstDupTable(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr); -void emitInstSetList(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr); -void emitInstGetUpval(AssemblyBuilderX64& build, const Instruction* pc, int pcpos); -void emitInstSetUpval(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr); -void emitInstCloseUpvals(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr); -int emitInstFastCall1(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr); -int emitInstFastCall2(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr); -int emitInstFastCall2K(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr); -int emitInstFastCall(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr); -void emitInstForNPrep(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr); -void emitInstForNLoop(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr); -void emitinstForGLoop(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, Label& fallback); -void emitinstForGLoopFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr); -void emitInstForGPrepNext(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, Label& fallback); -void emitInstForGPrepInext(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, Label& fallback); -void emitInstForGPrepXnextFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr); +void emitInstNewTable(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& next); +void emitInstDupTable(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& next); +void emitInstSetList(AssemblyBuilderX64& build, const Instruction* pc, Label& next); +void emitInstGetUpval(AssemblyBuilderX64& build, const Instruction* pc); +void emitInstSetUpval(AssemblyBuilderX64& build, const Instruction* pc, Label& next); +void emitInstCloseUpvals(AssemblyBuilderX64& build, const Instruction* pc, Label& next); +int emitInstFastCall1(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback); +int emitInstFastCall2(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback); +int emitInstFastCall2K(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback); +int emitInstFastCall(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback); +void emitInstForNPrep(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& loopExit); +void emitInstForNLoop(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& loopRepeat); +void emitinstForGLoop(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& loopRepeat, Label& loopExit, Label& fallback); +void emitinstForGLoopFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& loopRepeat); +void emitInstForGPrepNext(AssemblyBuilderX64& build, const Instruction* pc, Label& target, Label& fallback); +void emitInstForGPrepInext(AssemblyBuilderX64& build, const Instruction* pc, Label& target, Label& fallback); +void emitInstForGPrepXnextFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& target); void emitInstAnd(AssemblyBuilderX64& build, const Instruction* pc); void emitInstAndK(AssemblyBuilderX64& build, const Instruction* pc); void emitInstOr(AssemblyBuilderX64& build, const Instruction* pc); void emitInstOrK(AssemblyBuilderX64& build, const Instruction* pc); -void emitInstGetTableN(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback); +void emitInstGetTableN(AssemblyBuilderX64& build, const Instruction* pc, Label& fallback); void emitInstGetTableNFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos); -void emitInstSetTableN(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, Label& fallback); +void emitInstSetTableN(AssemblyBuilderX64& build, const Instruction* pc, Label& next, Label& fallback); void emitInstSetTableNFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos); -void emitInstGetTable(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback); +void emitInstGetTable(AssemblyBuilderX64& build, const Instruction* pc, Label& fallback); void emitInstGetTableFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos); -void emitInstSetTable(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, Label& fallback); +void emitInstSetTable(AssemblyBuilderX64& build, const Instruction* pc, Label& next, Label& fallback); void emitInstSetTableFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos); void emitInstGetImport(AssemblyBuilderX64& build, const Instruction* pc, Label& fallback); -void emitInstGetImportFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos); +void emitInstGetImportFallback(AssemblyBuilderX64& build, int ra, uint32_t aux); void emitInstGetTableKS(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback); -void emitInstSetTableKS(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, Label& fallback); +void emitInstSetTableKS(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& next, Label& fallback); void emitInstGetGlobal(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback); -void emitInstSetGlobal(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, Label& fallback); -void emitInstConcat(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr); +void emitInstSetGlobal(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& next, Label& fallback); +void emitInstConcat(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& next); } // namespace CodeGen } // namespace Luau diff --git a/tests/AstJsonEncoder.test.cpp b/tests/AstJsonEncoder.test.cpp index a14d5f595..1532b7a8f 100644 --- a/tests/AstJsonEncoder.test.cpp +++ b/tests/AstJsonEncoder.test.cpp @@ -6,6 +6,7 @@ #include "doctest.h" +#include #include using namespace Luau; @@ -58,6 +59,9 @@ TEST_CASE("encode_constants") AstExprConstantBool b{Location(), true}; AstExprConstantNumber n{Location(), 8.2}; AstExprConstantNumber bigNum{Location(), 0.1677721600000003}; + AstExprConstantNumber positiveInfinity{Location(), INFINITY}; + AstExprConstantNumber negativeInfinity{Location(), -INFINITY}; + AstExprConstantNumber nan{Location(), NAN}; AstArray charString; charString.data = const_cast("a\x1d\0\\\"b"); @@ -69,6 +73,9 @@ TEST_CASE("encode_constants") CHECK_EQ(R"({"type":"AstExprConstantBool","location":"0,0 - 0,0","value":true})", toJson(&b)); CHECK_EQ(R"({"type":"AstExprConstantNumber","location":"0,0 - 0,0","value":8.1999999999999993})", toJson(&n)); CHECK_EQ(R"({"type":"AstExprConstantNumber","location":"0,0 - 0,0","value":0.16777216000000031})", toJson(&bigNum)); + CHECK_EQ(R"({"type":"AstExprConstantNumber","location":"0,0 - 0,0","value":Infinity})", toJson(&positiveInfinity)); + CHECK_EQ(R"({"type":"AstExprConstantNumber","location":"0,0 - 0,0","value":-Infinity})", toJson(&negativeInfinity)); + CHECK_EQ(R"({"type":"AstExprConstantNumber","location":"0,0 - 0,0","value":NaN})", toJson(&nan)); CHECK_EQ("{\"type\":\"AstExprConstantString\",\"location\":\"0,0 - 0,0\",\"value\":\"a\\u001d\\u0000\\\\\\\"b\"}", toJson(&needsEscaping)); } diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index 45baec2ce..123708cab 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -2948,6 +2948,71 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_string_singletons") CHECK_EQ(ac.context, AutocompleteContext::String); } +TEST_CASE_FIXTURE(ACFixture, "string_singleton_as_table_key") +{ + ScopedFastFlag sff{"LuauCompleteTableKeysBetter", true}; + + check(R"( + type Direction = "up" | "down" + + local a: {[Direction]: boolean} = {[@1] = true} + local b: {[Direction]: boolean} = {["@2"] = true} + local c: {[Direction]: boolean} = {u@3 = true} + local d: {[Direction]: boolean} = {[u@4] = true} + + local e: {[Direction]: boolean} = {[@5]} + local f: {[Direction]: boolean} = {["@6"]} + local g: {[Direction]: boolean} = {u@7} + local h: {[Direction]: boolean} = {[u@8]} + )"); + + auto ac = autocomplete('1'); + + CHECK(ac.entryMap.count("\"up\"")); + CHECK(ac.entryMap.count("\"down\"")); + + ac = autocomplete('2'); + + CHECK(ac.entryMap.count("up")); + CHECK(ac.entryMap.count("down")); + + ac = autocomplete('3'); + + CHECK(ac.entryMap.count("up")); + CHECK(ac.entryMap.count("down")); + + ac = autocomplete('4'); + + CHECK(!ac.entryMap.count("up")); + CHECK(!ac.entryMap.count("down")); + + CHECK(ac.entryMap.count("\"up\"")); + CHECK(ac.entryMap.count("\"down\"")); + + ac = autocomplete('5'); + + CHECK(ac.entryMap.count("\"up\"")); + CHECK(ac.entryMap.count("\"down\"")); + + ac = autocomplete('6'); + + CHECK(ac.entryMap.count("up")); + CHECK(ac.entryMap.count("down")); + + ac = autocomplete('7'); + + CHECK(ac.entryMap.count("up")); + CHECK(ac.entryMap.count("down")); + + ac = autocomplete('8'); + + CHECK(!ac.entryMap.count("up")); + CHECK(!ac.entryMap.count("down")); + + CHECK(ac.entryMap.count("\"up\"")); + CHECK(ac.entryMap.count("\"down\"")); +} + TEST_CASE_FIXTURE(ACFixture, "autocomplete_string_singleton_equality") { check(R"( diff --git a/tests/Fixture.cpp b/tests/Fixture.cpp index b28155e30..eb77ce521 100644 --- a/tests/Fixture.cpp +++ b/tests/Fixture.cpp @@ -2,19 +2,21 @@ #include "Fixture.h" #include "Luau/AstQuery.h" +#include "Luau/BuiltinDefinitions.h" +#include "Luau/Constraint.h" #include "Luau/ModuleResolver.h" #include "Luau/NotNull.h" #include "Luau/Parser.h" #include "Luau/TypeVar.h" #include "Luau/TypeAttach.h" #include "Luau/Transpiler.h" -#include "Luau/BuiltinDefinitions.h" #include "doctest.h" #include #include #include +#include static const char* mainModuleName = "MainModule"; @@ -27,6 +29,41 @@ extern std::optional randomSeed; // tests/main.cpp namespace Luau { +std::optional TestFileResolver::resolveModuleInfo(const ModuleName& currentModuleName, const AstExpr& pathExpr) +{ + if (auto name = pathExprToModuleName(currentModuleName, pathExpr)) + return {{*name, false}}; + + return std::nullopt; +} + +const ModulePtr TestFileResolver::getModule(const ModuleName& moduleName) const +{ + LUAU_ASSERT(false); + return nullptr; +} + +bool TestFileResolver::moduleExists(const ModuleName& moduleName) const +{ + auto it = source.find(moduleName); + return (it != source.end()); +} + +std::optional TestFileResolver::readSource(const ModuleName& name) +{ + auto it = source.find(name); + if (it == source.end()) + return std::nullopt; + + SourceCode::Type sourceType = SourceCode::Module; + + auto it2 = sourceTypes.find(name); + if (it2 != sourceTypes.end()) + sourceType = it2->second; + + return SourceCode{it->second, sourceType}; +} + std::optional TestFileResolver::resolveModule(const ModuleInfo* context, AstExpr* expr) { if (AstExprGlobal* g = expr->as()) @@ -90,6 +127,15 @@ std::optional TestFileResolver::getEnvironmentForModule(const Modul return std::nullopt; } +const Config& TestConfigResolver::getConfig(const ModuleName& name) const +{ + auto it = configFiles.find(name); + if (it != configFiles.end()) + return it->second; + + return defaultConfig; +} + Fixture::Fixture(bool freeze, bool prepareAutocomplete) : sff_DebugLuauFreezeArena("DebugLuauFreezeArena", freeze) , frontend(&fileResolver, &configResolver, diff --git a/tests/Fixture.h b/tests/Fixture.h index 24c9566fe..5d838b163 100644 --- a/tests/Fixture.h +++ b/tests/Fixture.h @@ -10,59 +10,31 @@ #include "Luau/ModuleResolver.h" #include "Luau/Scope.h" #include "Luau/ToString.h" -#include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" #include "IostreamOptional.h" #include "ScopedFlags.h" -#include #include #include - #include namespace Luau { +struct TypeChecker; + struct TestFileResolver : FileResolver , ModuleResolver { - std::optional resolveModuleInfo(const ModuleName& currentModuleName, const AstExpr& pathExpr) override - { - if (auto name = pathExprToModuleName(currentModuleName, pathExpr)) - return {{*name, false}}; - - return std::nullopt; - } + std::optional resolveModuleInfo(const ModuleName& currentModuleName, const AstExpr& pathExpr) override; - const ModulePtr getModule(const ModuleName& moduleName) const override - { - LUAU_ASSERT(false); - return nullptr; - } + const ModulePtr getModule(const ModuleName& moduleName) const override; - bool moduleExists(const ModuleName& moduleName) const override - { - auto it = source.find(moduleName); - return (it != source.end()); - } + bool moduleExists(const ModuleName& moduleName) const override; - std::optional readSource(const ModuleName& name) override - { - auto it = source.find(name); - if (it == source.end()) - return std::nullopt; - - SourceCode::Type sourceType = SourceCode::Module; - - auto it2 = sourceTypes.find(name); - if (it2 != sourceTypes.end()) - sourceType = it2->second; - - return SourceCode{it->second, sourceType}; - } + std::optional readSource(const ModuleName& name) override; std::optional resolveModule(const ModuleInfo* context, AstExpr* expr) override; @@ -80,14 +52,7 @@ struct TestConfigResolver : ConfigResolver Config defaultConfig; std::unordered_map configFiles; - const Config& getConfig(const ModuleName& name) const override - { - auto it = configFiles.find(name); - if (it != configFiles.end()) - return it->second; - - return defaultConfig; - } + const Config& getConfig(const ModuleName& name) const override; }; struct Fixture diff --git a/tests/Frontend.test.cpp b/tests/Frontend.test.cpp index 4e72dd4e7..6f92b6551 100644 --- a/tests/Frontend.test.cpp +++ b/tests/Frontend.test.cpp @@ -519,10 +519,6 @@ TEST_CASE_FIXTURE(FrontendFixture, "recheck_if_dependent_script_is_dirty") TEST_CASE_FIXTURE(FrontendFixture, "mark_non_immediate_reverse_deps_as_dirty") { - ScopedFastFlag sff[] = { - {"LuauFixMarkDirtyReverseDeps", true}, - }; - fileResolver.source["game/Gui/Modules/A"] = "return {hello=5, world=true}"; fileResolver.source["game/Gui/Modules/B"] = R"( return require(game:GetService('Gui').Modules.A) diff --git a/tests/Normalize.test.cpp b/tests/Normalize.test.cpp index a8f3c7ba0..e6bf00a12 100644 --- a/tests/Normalize.test.cpp +++ b/tests/Normalize.test.cpp @@ -393,7 +393,6 @@ TEST_SUITE_END(); struct NormalizeFixture : Fixture { - ScopedFastFlag sff0{"LuauNegatedStringSingletons", true}; ScopedFastFlag sff1{"LuauNegatedFunctionTypes", true}; TypeArena arena; diff --git a/tests/ToString.test.cpp b/tests/ToString.test.cpp index 29954dc46..05f49422b 100644 --- a/tests/ToString.test.cpp +++ b/tests/ToString.test.cpp @@ -269,9 +269,9 @@ TEST_CASE_FIXTURE(Fixture, "quit_stringifying_type_when_length_is_exceeded") { o.maxTypeLength = 30; CHECK_EQ(toString(requireType("f0"), o), "() -> ()"); - CHECK_EQ(toString(requireType("f1"), o), "(a) -> () -> ()"); - CHECK_EQ(toString(requireType("f2"), o), "(b) -> (a) -> () -> ()"); - CHECK_EQ(toString(requireType("f3"), o), "(c) -> (b) -> (a) -> (... *TRUNCATED*"); + CHECK_EQ(toString(requireType("f1"), o), "(a) -> (() -> ()) | (a & ~(false?))... *TRUNCATED*"); + CHECK_EQ(toString(requireType("f2"), o), "(b) -> ((a) -> (() -> ()) | (a & ~(false?))... *TRUNCATED*"); + CHECK_EQ(toString(requireType("f3"), o), "(c) -> ((b) -> ((a) -> (() -> ()) | (a & ~(false?))... *TRUNCATED*"); } else { @@ -299,9 +299,9 @@ TEST_CASE_FIXTURE(Fixture, "stringifying_type_is_still_capped_when_exhaustive") { o.maxTypeLength = 30; CHECK_EQ(toString(requireType("f0"), o), "() -> ()"); - CHECK_EQ(toString(requireType("f1"), o), "(a) -> () -> ()"); - CHECK_EQ(toString(requireType("f2"), o), "(b) -> (a) -> () -> ()"); - CHECK_EQ(toString(requireType("f3"), o), "(c) -> (b) -> (a) -> (... *TRUNCATED*"); + CHECK_EQ(toString(requireType("f1"), o), "(a) -> (() -> ()) | (a & ~(false?))... *TRUNCATED*"); + CHECK_EQ(toString(requireType("f2"), o), "(b) -> ((a) -> (() -> ()) | (a & ~(false?))... *TRUNCATED*"); + CHECK_EQ(toString(requireType("f3"), o), "(c) -> ((b) -> ((a) -> (() -> ()) | (a & ~(false?))... *TRUNCATED*"); } else { diff --git a/tests/TypeInfer.aliases.test.cpp b/tests/TypeInfer.aliases.test.cpp index 53c54f4f3..38e246c8d 100644 --- a/tests/TypeInfer.aliases.test.cpp +++ b/tests/TypeInfer.aliases.test.cpp @@ -8,8 +8,8 @@ using namespace Luau; LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) -LUAU_FASTFLAG(LuauNoMoreGlobalSingletonTypes) LUAU_FASTFLAG(LuauTypeMismatchInvarianceInError) +LUAU_FASTFLAG(LuauNewLibraryTypeNames) TEST_SUITE_BEGIN("TypeAliases"); @@ -525,21 +525,15 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "general_require_multi_assign") TEST_CASE_FIXTURE(BuiltinsFixture, "type_alias_import_mutation") { - ScopedFastFlag luauNewLibraryTypeNames{"LuauNewLibraryTypeNames", true}; - CheckResult result = check("type t10 = typeof(table)"); LUAU_REQUIRE_NO_ERRORS(result); TypeId ty = getGlobalBinding(frontend, "table"); - if (FFlag::LuauNoMoreGlobalSingletonTypes) - { - CHECK_EQ(toString(ty), "typeof(table)"); - } + if (FFlag::LuauNewLibraryTypeNames) + CHECK(toString(ty) == "typeof(table)"); else - { - CHECK_EQ(toString(ty), "table"); - } + CHECK(toString(ty) == "table"); const TableTypeVar* ttv = get(ty); REQUIRE(ttv); diff --git a/tests/TypeInfer.intersectionTypes.test.cpp b/tests/TypeInfer.intersectionTypes.test.cpp index 7d0621a79..188be63c7 100644 --- a/tests/TypeInfer.intersectionTypes.test.cpp +++ b/tests/TypeInfer.intersectionTypes.test.cpp @@ -533,7 +533,7 @@ TEST_CASE_FIXTURE(Fixture, "intersection_of_tables_with_never_properties") ScopedFastFlag sffs[]{ {"LuauSubtypeNormalizer", true}, {"LuauTypeNormalization2", true}, - {"LuauUninhabitedSubAnything", true}, + {"LuauUninhabitedSubAnything2", true}, }; CheckResult result = check(R"( diff --git a/tests/TypeInfer.negations.test.cpp b/tests/TypeInfer.negations.test.cpp index e8256f974..0e7fb03de 100644 --- a/tests/TypeInfer.negations.test.cpp +++ b/tests/TypeInfer.negations.test.cpp @@ -13,8 +13,7 @@ namespace struct NegationFixture : Fixture { TypeArena arena; - ScopedFastFlag sff[2]{ - {"LuauNegatedStringSingletons", true}, + ScopedFastFlag sff[1]{ {"LuauSubtypeNormalizer", true}, }; diff --git a/tests/TypeInfer.operators.test.cpp b/tests/TypeInfer.operators.test.cpp index 21806082f..93d7361bf 100644 --- a/tests/TypeInfer.operators.test.cpp +++ b/tests/TypeInfer.operators.test.cpp @@ -25,8 +25,16 @@ TEST_CASE_FIXTURE(Fixture, "or_joins_types") local x:string|number = s )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(toString(*requireType("s")), "number | string"); - CHECK_EQ(toString(*requireType("x")), "number | string"); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ(toString(*requireType("s")), "(string & ~(false?)) | number"); + CHECK_EQ(toString(*requireType("x")), "number | string"); + } + else + { + CHECK_EQ(toString(*requireType("s")), "number | string"); + CHECK_EQ(toString(*requireType("x")), "number | string"); + } } TEST_CASE_FIXTURE(Fixture, "or_joins_types_with_no_extras") @@ -37,8 +45,16 @@ TEST_CASE_FIXTURE(Fixture, "or_joins_types_with_no_extras") local y = x or "s" )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(toString(*requireType("s")), "number | string"); - CHECK_EQ(toString(*requireType("y")), "number | string"); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ(toString(*requireType("s")), "(string & ~(false?)) | number"); + CHECK_EQ(toString(*requireType("y")), "((number | string) & ~(false?)) | string"); + } + else + { + CHECK_EQ(toString(*requireType("s")), "number | string"); + CHECK_EQ(toString(*requireType("y")), "number | string"); + } } TEST_CASE_FIXTURE(Fixture, "or_joins_types_with_no_superfluous_union") @@ -62,7 +78,14 @@ TEST_CASE_FIXTURE(Fixture, "and_does_not_always_add_boolean") local x:boolean|number = s )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(toString(*requireType("s")), "number"); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ(toString(*requireType("s")), "((false?) & string) | number"); + } + else + { + CHECK_EQ(toString(*requireType("s")), "number"); + } } TEST_CASE_FIXTURE(Fixture, "and_adds_boolean_no_superfluous_union") @@ -81,7 +104,14 @@ TEST_CASE_FIXTURE(Fixture, "and_or_ternary") local s = (1/2) > 0.5 and "a" or 10 )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(toString(*requireType("s")), "number | string"); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ(toString(*requireType("s")), "((((false?) & boolean) | string) & ~(false?)) | number"); + } + else + { + CHECK_EQ(toString(*requireType("s")), "number | string"); + } } TEST_CASE_FIXTURE(BuiltinsFixture, "primitive_arith_no_metatable") @@ -405,11 +435,9 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "compound_assign_mismatch_metatable") local v2: V2 = setmetatable({ x = 3, y = 4 }, VMT) v1 %= v2 )"); - LUAU_REQUIRE_ERROR_COUNT(1, result); - TypeMismatch* tm = get(result.errors[0]); - CHECK_EQ(*tm->wantedType, *requireType("v2")); - CHECK_EQ(*tm->givenType, *typeChecker.numberType); + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK("Type 'number' could not be converted into 'V2'" == toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "CallOrOfFunctions") @@ -781,7 +809,14 @@ local b: number = 1 or a TypeMismatch* tm = get(result.errors[0]); REQUIRE(tm); CHECK_EQ(typeChecker.numberType, tm->wantedType); - CHECK_EQ("number?", toString(tm->givenType)); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("((number & ~(false?)) | number)?", toString(tm->givenType)); + } + else + { + CHECK_EQ("number?", toString(tm->givenType)); + } } TEST_CASE_FIXTURE(Fixture, "operator_eq_verifies_types_do_intersect") @@ -842,7 +877,14 @@ TEST_CASE_FIXTURE(Fixture, "refine_and_or") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("number", toString(requireType("u"))); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("((((false?) & ({| x: number? |}?)) | a) & ~(false?)) | number", toString(requireType("u"))); + } + else + { + CHECK_EQ("number", toString(requireType("u"))); + } } TEST_CASE_FIXTURE(Fixture, "infer_any_in_all_modes_when_lhs_is_unknown") @@ -1035,10 +1077,10 @@ local w = c and 1 if (FFlag::DebugLuauDeferredConstraintResolution) { - CHECK("number?" == toString(requireType("x"))); - CHECK("number" == toString(requireType("y"))); - CHECK("false | number" == toString(requireType("z"))); - CHECK("number" == toString(requireType("w"))); // Normalizer considers free & falsy == never + CHECK("((false?) & (number?)) | number" == toString(requireType("x"))); + CHECK("((false?) & string) | number" == toString(requireType("y"))); + CHECK("((false?) & boolean) | number" == toString(requireType("z"))); + CHECK("((false?) & a) | number" == toString(requireType("w"))); } else { @@ -1073,12 +1115,12 @@ local f1 = f or 'f' if (FFlag::DebugLuauDeferredConstraintResolution) { - CHECK("number | string" == toString(requireType("a1"))); - CHECK("number" == toString(requireType("b1"))); - CHECK("string | true" == toString(requireType("c1"))); - CHECK("string | true" == toString(requireType("d1"))); - CHECK("string" == toString(requireType("e1"))); - CHECK("string" == toString(requireType("f1"))); + CHECK("((false | number) & ~(false?)) | string" == toString(requireType("a1"))); + CHECK("((number?) & ~(false?)) | number" == toString(requireType("b1"))); + CHECK("(boolean & ~(false?)) | string" == toString(requireType("c1"))); + CHECK("(true & ~(false?)) | string" == toString(requireType("d1"))); + CHECK("(false & ~(false?)) | string" == toString(requireType("e1"))); + CHECK("(nil & ~(false?)) | string" == toString(requireType("f1"))); } else { diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index 259341744..b7408f876 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -9,6 +9,8 @@ using namespace Luau; +LUAU_FASTFLAG(LuauTypeMismatchInvarianceInError) + TEST_SUITE_BEGIN("ProvisionalTests"); // These tests check for behavior that differs from the final behavior we'd @@ -776,4 +778,32 @@ TEST_CASE_FIXTURE(IsSubtypeFixture, "functions_with_mismatching_arity_but_any_is // CHECK(!isSubtype(b, c)); } +TEST_CASE_FIXTURE(Fixture, "assign_table_with_refined_property_with_a_similar_type_is_illegal") +{ + CheckResult result = check(R"( + local t: {x: number?} = {x = nil} + + if t.x then + local u: {x: number} = t + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + if (FFlag::LuauTypeMismatchInvarianceInError) + { + CHECK_EQ(R"(Type '{| x: number? |}' could not be converted into '{| x: number |}' +caused by: + Property 'x' is not compatible. Type 'number?' could not be converted into 'number' in an invariant context)", + toString(result.errors[0])); + } + else + { + CHECK_EQ(R"(Type '{| x: number? |}' could not be converted into '{| x: number |}' +caused by: + Property 'x' is not compatible. Type 'number?' could not be converted into 'number')", + toString(result.errors[0])); + } +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index e5bc186a0..5a7c8432a 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -8,7 +8,6 @@ #include "doctest.h" LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) -LUAU_FASTFLAG(LuauTypeMismatchInvarianceInError) using namespace Luau; @@ -36,7 +35,7 @@ std::optional> magicFunctionInstanceIsA( return WithPredicate{booleanPack, {IsAPredicate{std::move(*lvalue), expr.location, tfun->type}}}; } -std::vector dcrMagicRefinementInstanceIsA(MagicRefinementContext ctx) +std::vector dcrMagicRefinementInstanceIsA(const MagicRefinementContext& ctx) { if (ctx.callSite->args.size != 1) return {}; @@ -462,35 +461,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "assert_non_binary_expressions_actually_resol LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "assign_table_with_refined_property_with_a_similar_type_is_illegal") -{ - CheckResult result = check(R"( - local t: {x: number?} = {x = nil} - - if t.x then - local u: {x: number} = t - end - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - if (FFlag::LuauTypeMismatchInvarianceInError) - { - CHECK_EQ(R"(Type '{| x: number? |}' could not be converted into '{| x: number |}' -caused by: - Property 'x' is not compatible. Type 'number?' could not be converted into 'number' in an invariant context)", - toString(result.errors[0])); - } - else - { - CHECK_EQ(R"(Type '{| x: number? |}' could not be converted into '{| x: number |}' -caused by: - Property 'x' is not compatible. Type 'number?' could not be converted into 'number')", - toString(result.errors[0])); - } -} - - TEST_CASE_FIXTURE(Fixture, "lvalue_is_equal_to_another_lvalue") { CheckResult result = check(R"( @@ -1009,8 +979,16 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "assert_a_to_be_truthy_then_assert_a_to_be_nu LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("number | string", toString(requireTypeAtPosition({3, 18}))); - CHECK_EQ("number", toString(requireTypeAtPosition({5, 18}))); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("((number | string)?) & ~(false?)", toString(requireTypeAtPosition({3, 18}))); + CHECK_EQ("((number | string)?) & ~(false?) & number", toString(requireTypeAtPosition({5, 18}))); + } + else + { + CHECK_EQ("number | string", toString(requireTypeAtPosition({3, 18}))); + CHECK_EQ("number", toString(requireTypeAtPosition({5, 18}))); + } } TEST_CASE_FIXTURE(BuiltinsFixture, "merge_should_be_fully_agnostic_of_hashmap_ordering") @@ -1031,7 +1009,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "merge_should_be_fully_agnostic_of_hashmap_or if (FFlag::DebugLuauDeferredConstraintResolution) { - CHECK_EQ("(string | {| x: string |}) & string", toString(requireTypeAtPosition({6, 28}))); + CHECK_EQ("(never | string) & (string | {| x: string |}) & string", toString(requireTypeAtPosition({6, 28}))); } else { @@ -1075,8 +1053,16 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "is_truthy_constraint_ifelse_expression") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("string", toString(requireTypeAtPosition({2, 29}))); - CHECK_EQ("nil", toString(requireTypeAtPosition({2, 45}))); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("(string?) & ~(false?)", toString(requireTypeAtPosition({2, 29}))); + CHECK_EQ("(string?) & ~~(false?)", toString(requireTypeAtPosition({2, 45}))); + } + else + { + CHECK_EQ("string", toString(requireTypeAtPosition({2, 29}))); + CHECK_EQ("nil", toString(requireTypeAtPosition({2, 45}))); + } } TEST_CASE_FIXTURE(BuiltinsFixture, "invert_is_truthy_constraint_ifelse_expression") @@ -1089,8 +1075,16 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "invert_is_truthy_constraint_ifelse_expressio LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("nil", toString(requireTypeAtPosition({2, 42}))); - CHECK_EQ("string", toString(requireTypeAtPosition({2, 50}))); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("(string?) & ~~(false?)", toString(requireTypeAtPosition({2, 42}))); + CHECK_EQ("(string?) & ~(false?)", toString(requireTypeAtPosition({2, 50}))); + } + else + { + CHECK_EQ("nil", toString(requireTypeAtPosition({2, 42}))); + CHECK_EQ("string", toString(requireTypeAtPosition({2, 50}))); + } } TEST_CASE_FIXTURE(BuiltinsFixture, "type_comparison_ifelse_expression") @@ -1107,8 +1101,16 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "type_comparison_ifelse_expression") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("number", toString(requireTypeAtPosition({6, 49}))); - CHECK_EQ("any", toString(requireTypeAtPosition({6, 66}))); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("any & number", toString(requireTypeAtPosition({6, 49}))); + CHECK_EQ("any & ~number", toString(requireTypeAtPosition({6, 66}))); + } + else + { + CHECK_EQ("number", toString(requireTypeAtPosition({6, 49}))); + CHECK_EQ("any", toString(requireTypeAtPosition({6, 66}))); + } } TEST_CASE_FIXTURE(BuiltinsFixture, "correctly_lookup_a_shadowed_local_that_which_was_previously_refined") diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index c94ed1f9a..c379559dc 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -17,8 +17,8 @@ using namespace Luau; LUAU_FASTFLAG(LuauLowerBoundsCalculation); LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); LUAU_FASTFLAG(LuauInstantiateInSubtyping) -LUAU_FASTFLAG(LuauNoMoreGlobalSingletonTypes) LUAU_FASTFLAG(LuauTypeMismatchInvarianceInError) +LUAU_FASTFLAG(LuauNewLibraryTypeNames) TEST_SUITE_BEGIN("TableTests"); @@ -1723,8 +1723,6 @@ TEST_CASE_FIXTURE(Fixture, "hide_table_error_properties") TEST_CASE_FIXTURE(BuiltinsFixture, "builtin_table_names") { - ScopedFastFlag luauNewLibraryTypeNames{"LuauNewLibraryTypeNames", true}; - CheckResult result = check(R"( os.h = 2 string.k = 3 @@ -1732,7 +1730,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "builtin_table_names") LUAU_REQUIRE_ERROR_COUNT(2, result); - if (FFlag::LuauNoMoreGlobalSingletonTypes) + if (FFlag::LuauNewLibraryTypeNames) { CHECK_EQ("Cannot add property 'h' to table 'typeof(os)'", toString(result.errors[0])); CHECK_EQ("Cannot add property 'k' to table 'typeof(string)'", toString(result.errors[1])); @@ -1746,22 +1744,16 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "builtin_table_names") TEST_CASE_FIXTURE(BuiltinsFixture, "persistent_sealed_table_is_immutable") { - ScopedFastFlag luauNewLibraryTypeNames{"LuauNewLibraryTypeNames", true}; - CheckResult result = check(R"( --!nonstrict function os:bad() end )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - if (FFlag::LuauNoMoreGlobalSingletonTypes) - { + if (FFlag::LuauNewLibraryTypeNames) CHECK_EQ("Cannot add property 'bad' to table 'typeof(os)'", toString(result.errors[0])); - } else - { CHECK_EQ("Cannot add property 'bad' to table 'os'", toString(result.errors[0])); - } const TableTypeVar* osType = get(requireType("os")); REQUIRE(osType != nullptr); @@ -3238,7 +3230,8 @@ TEST_CASE_FIXTURE(Fixture, "scalar_is_a_subtype_of_a_compatible_polymorphic_shap TEST_CASE_FIXTURE(Fixture, "scalar_is_not_a_subtype_of_a_compatible_polymorphic_shape_type") { ScopedFastFlag sff{"LuauScalarShapeSubtyping", true}; - ScopedFastFlag luauNewLibraryTypeNames{"LuauNewLibraryTypeNames", true}; + if (!FFlag::LuauNewLibraryTypeNames) + return; CheckResult result = check(R"( local function f(s) @@ -3252,40 +3245,20 @@ TEST_CASE_FIXTURE(Fixture, "scalar_is_not_a_subtype_of_a_compatible_polymorphic_ LUAU_REQUIRE_ERROR_COUNT(3, result); - if (FFlag::LuauNoMoreGlobalSingletonTypes) - { - CHECK_EQ(R"(Type 'string' could not be converted into 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' + CHECK_EQ(R"(Type 'string' could not be converted into 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' caused by: The former's metatable does not satisfy the requirements. Table type 'typeof(string)' not compatible with type 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' because the former is missing field 'absolutely_no_scalar_has_this_method')", - toString(result.errors[0])); - CHECK_EQ(R"(Type '"bar"' could not be converted into 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' + toString(result.errors[0])); + CHECK_EQ(R"(Type '"bar"' could not be converted into 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' caused by: The former's metatable does not satisfy the requirements. Table type 'typeof(string)' not compatible with type 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' because the former is missing field 'absolutely_no_scalar_has_this_method')", - toString(result.errors[1])); - CHECK_EQ(R"(Type '"bar" | "baz"' could not be converted into 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' + toString(result.errors[1])); + CHECK_EQ(R"(Type '"bar" | "baz"' could not be converted into 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' caused by: Not all union options are compatible. Type '"bar"' could not be converted into 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' caused by: The former's metatable does not satisfy the requirements. Table type 'typeof(string)' not compatible with type 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' because the former is missing field 'absolutely_no_scalar_has_this_method')", - toString(result.errors[2])); - } - else - { - CHECK_EQ(R"(Type 'string' could not be converted into 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' -caused by: - The former's metatable does not satisfy the requirements. Table type 'string' not compatible with type 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' because the former is missing field 'absolutely_no_scalar_has_this_method')", - toString(result.errors[0])); - CHECK_EQ(R"(Type '"bar"' could not be converted into 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' -caused by: - The former's metatable does not satisfy the requirements. Table type 'string' not compatible with type 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' because the former is missing field 'absolutely_no_scalar_has_this_method')", - toString(result.errors[1])); - CHECK_EQ(R"(Type '"bar" | "baz"' could not be converted into 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' -caused by: - Not all union options are compatible. Type '"bar"' could not be converted into 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' -caused by: - The former's metatable does not satisfy the requirements. Table type 'string' not compatible with type 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' because the former is missing field 'absolutely_no_scalar_has_this_method')", - toString(result.errors[2])); - } + toString(result.errors[2])); } TEST_CASE_FIXTURE(Fixture, "a_free_shape_can_turn_into_a_scalar_if_it_is_compatible") @@ -3307,7 +3280,8 @@ TEST_CASE_FIXTURE(Fixture, "a_free_shape_can_turn_into_a_scalar_if_it_is_compati TEST_CASE_FIXTURE(Fixture, "a_free_shape_cannot_turn_into_a_scalar_if_it_is_not_compatible") { ScopedFastFlag sff{"LuauScalarShapeSubtyping", true}; - ScopedFastFlag luauNewLibraryTypeNames{"LuauNewLibraryTypeNames", true}; + if (!FFlag::LuauNewLibraryTypeNames) + return; CheckResult result = check(R"( local function f(s): string @@ -3317,22 +3291,11 @@ TEST_CASE_FIXTURE(Fixture, "a_free_shape_cannot_turn_into_a_scalar_if_it_is_not_ )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - if (FFlag::LuauNoMoreGlobalSingletonTypes) - { - CHECK_EQ(R"(Type 't1 where t1 = {+ absolutely_no_scalar_has_this_method: (t1) -> (a, b...) +}' could not be converted into 'string' + CHECK_EQ(R"(Type 't1 where t1 = {+ absolutely_no_scalar_has_this_method: (t1) -> (a, b...) +}' could not be converted into 'string' caused by: The former's metatable does not satisfy the requirements. Table type 'typeof(string)' not compatible with type 't1 where t1 = {+ absolutely_no_scalar_has_this_method: (t1) -> (a, b...) +}' because the former is missing field 'absolutely_no_scalar_has_this_method')", - toString(result.errors[0])); - CHECK_EQ("(t1) -> string where t1 = {+ absolutely_no_scalar_has_this_method: (t1) -> (a, b...) +}", toString(requireType("f"))); - } - else - { - CHECK_EQ(R"(Type 't1 where t1 = {+ absolutely_no_scalar_has_this_method: (t1) -> (a, b...) +}' could not be converted into 'string' -caused by: - The former's metatable does not satisfy the requirements. Table type 'string' not compatible with type 't1 where t1 = {+ absolutely_no_scalar_has_this_method: (t1) -> (a, b...) +}' because the former is missing field 'absolutely_no_scalar_has_this_method')", - toString(result.errors[0])); - CHECK_EQ("(t1) -> string where t1 = {+ absolutely_no_scalar_has_this_method: (t1) -> (a, b...) +}", toString(requireType("f"))); - } + toString(result.errors[0])); + CHECK_EQ("(t1) -> string where t1 = {+ absolutely_no_scalar_has_this_method: (t1) -> (a, b...) +}", toString(requireType("f"))); } TEST_CASE_FIXTURE(BuiltinsFixture, "a_free_shape_can_turn_into_a_scalar_directly") diff --git a/tests/TypeInfer.tryUnify.test.cpp b/tests/TypeInfer.tryUnify.test.cpp index 5cc07a286..b1abdf7c9 100644 --- a/tests/TypeInfer.tryUnify.test.cpp +++ b/tests/TypeInfer.tryUnify.test.cpp @@ -145,7 +145,7 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "uninhabited_table_sub_never") ScopedFastFlag sffs[]{ {"LuauSubtypeNormalizer", true}, {"LuauTypeNormalization2", true}, - {"LuauUninhabitedSubAnything", true}, + {"LuauUninhabitedSubAnything2", true}, }; CheckResult result = check(R"( @@ -161,7 +161,7 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "uninhabited_table_sub_anything") ScopedFastFlag sffs[]{ {"LuauSubtypeNormalizer", true}, {"LuauTypeNormalization2", true}, - {"LuauUninhabitedSubAnything", true}, + {"LuauUninhabitedSubAnything2", true}, }; CheckResult result = check(R"( diff --git a/tools/faillist.txt b/tools/faillist.txt index 6f49db84c..5d6779f48 100644 --- a/tools/faillist.txt +++ b/tools/faillist.txt @@ -27,6 +27,7 @@ AutocompleteTest.do_wrong_compatible_self_calls AutocompleteTest.keyword_methods AutocompleteTest.no_incompatible_self_calls AutocompleteTest.no_wrong_compatible_self_calls_with_generics +AutocompleteTest.string_singleton_as_table_key AutocompleteTest.suggest_external_module_type AutocompleteTest.suggest_table_keys AutocompleteTest.type_correct_argument_type_suggestion @@ -88,7 +89,6 @@ DefinitionTests.class_definition_string_props DefinitionTests.declaring_generic_functions DefinitionTests.definition_file_classes FrontendTest.environments -FrontendTest.imported_table_modification_2 FrontendTest.it_should_be_safe_to_stringify_errors_when_full_type_graph_is_discarded FrontendTest.nocheck_cycle_used_by_checked FrontendTest.reexport_cyclic_type @@ -96,7 +96,6 @@ FrontendTest.trace_requires_in_nonstrict_mode GenericsTests.apply_type_function_nested_generics1 GenericsTests.apply_type_function_nested_generics2 GenericsTests.better_mismatch_error_messages -GenericsTests.calling_self_generic_methods GenericsTests.check_generic_typepack_function GenericsTests.check_mutual_generic_functions GenericsTests.correctly_instantiate_polymorphic_member_functions @@ -113,7 +112,6 @@ GenericsTests.higher_rank_polymorphism_should_not_accept_instantiated_arguments GenericsTests.infer_generic_function_function_argument GenericsTests.infer_generic_function_function_argument_overloaded GenericsTests.infer_generic_lib_function_function_argument -GenericsTests.infer_generic_methods GenericsTests.infer_generic_property GenericsTests.instantiated_function_argument_names GenericsTests.instantiation_sharing_types @@ -147,6 +145,7 @@ ParseErrorRecovery.generic_type_list_recovery ParseErrorRecovery.recovery_of_parenthesized_expressions ParserTests.parse_nesting_based_end_detection_failsafe_earlier ParserTests.parse_nesting_based_end_detection_local_function +ProvisionalTests.assign_table_with_refined_property_with_a_similar_type_is_illegal ProvisionalTests.bail_early_if_unification_is_too_complicated ProvisionalTests.discriminate_from_x_not_equal_to_nil ProvisionalTests.do_not_ice_when_trying_to_pick_first_of_generic_type_pack @@ -163,26 +162,16 @@ ProvisionalTests.typeguard_inference_incomplete ProvisionalTests.weirditer_should_not_loop_forever ProvisionalTests.while_body_are_also_refined RefinementTest.apply_refinements_on_astexprindexexpr_whose_subscript_expr_is_constant_string -RefinementTest.assert_a_to_be_truthy_then_assert_a_to_be_number -RefinementTest.assert_non_binary_expressions_actually_resolve_constraints -RefinementTest.assign_table_with_refined_property_with_a_similar_type_is_illegal RefinementTest.call_an_incompatible_function_after_using_typeguard -RefinementTest.correctly_lookup_property_whose_base_was_previously_refined RefinementTest.correctly_lookup_property_whose_base_was_previously_refined2 RefinementTest.discriminate_on_properties_of_disjoint_tables_where_that_property_is_true_or_false RefinementTest.discriminate_tag RefinementTest.else_with_no_explicit_expression_should_also_refine_the_tagged_union RefinementTest.falsiness_of_TruthyPredicate_narrows_into_nil -RefinementTest.fuzz_filtered_refined_types_are_followed -RefinementTest.index_on_a_refined_property -RefinementTest.invert_is_truthy_constraint_ifelse_expression -RefinementTest.is_truthy_constraint_ifelse_expression RefinementTest.narrow_property_of_a_bounded_variable RefinementTest.nonoptional_type_can_narrow_to_nil_if_sense_is_true -RefinementTest.not_t_or_some_prop_of_t RefinementTest.refine_a_property_not_to_be_nil_through_an_intersection_table RefinementTest.refine_unknowns -RefinementTest.type_comparison_ifelse_expression RefinementTest.type_guard_can_filter_for_intersection_of_tables RefinementTest.type_guard_narrowed_into_nothingness RefinementTest.type_narrow_for_all_the_userdata @@ -199,18 +188,18 @@ TableTests.access_index_metamethod_that_returns_variadic TableTests.accidentally_checked_prop_in_opposite_branch TableTests.builtin_table_names TableTests.call_method +TableTests.call_method_with_explicit_self_argument TableTests.cannot_augment_sealed_table TableTests.casting_sealed_tables_with_props_into_table_with_indexer TableTests.casting_tables_with_props_into_table_with_indexer3 TableTests.casting_tables_with_props_into_table_with_indexer4 TableTests.checked_prop_too_early -TableTests.defining_a_method_for_a_builtin_sealed_table_must_fail -TableTests.defining_a_method_for_a_local_sealed_table_must_fail -TableTests.defining_a_self_method_for_a_builtin_sealed_table_must_fail -TableTests.defining_a_self_method_for_a_local_sealed_table_must_fail +TableTests.defining_a_method_for_a_local_unsealed_table_is_ok +TableTests.defining_a_self_method_for_a_local_unsealed_table_is_ok TableTests.dont_crash_when_setmetatable_does_not_produce_a_metatabletypevar TableTests.dont_hang_when_trying_to_look_up_in_cyclic_metatable_index TableTests.dont_quantify_table_that_belongs_to_outer_scope +TableTests.dont_seal_an_unsealed_table_by_passing_it_to_a_function_that_takes_a_sealed_table TableTests.dont_suggest_exact_match_keys TableTests.error_detailed_metatable_prop TableTests.expected_indexer_from_table_union @@ -235,12 +224,11 @@ TableTests.infer_indexer_from_value_property_in_literal TableTests.inferred_return_type_of_free_table TableTests.inferring_crazy_table_should_also_be_quick TableTests.instantiate_table_cloning_3 +TableTests.instantiate_tables_at_scope_level TableTests.invariant_table_properties_means_instantiating_tables_in_assignment_is_unsound TableTests.invariant_table_properties_means_instantiating_tables_in_call_is_unsound TableTests.leaking_bad_metatable_errors TableTests.less_exponential_blowup_please -TableTests.meta_add -TableTests.meta_add_both_ways TableTests.meta_add_inferred TableTests.metatable_mismatch_should_fail TableTests.missing_metatable_for_sealed_tables_do_not_get_inferred @@ -253,7 +241,6 @@ TableTests.oop_indexer_works TableTests.oop_polymorphic TableTests.open_table_unification_2 TableTests.persistent_sealed_table_is_immutable -TableTests.prop_access_on_key_whose_types_mismatches TableTests.property_lookup_through_tabletypevar_metatable TableTests.quantify_even_that_table_was_never_exported_at_all TableTests.quantify_metatables_of_metatables_of_table @@ -267,6 +254,7 @@ TableTests.shared_selfs TableTests.shared_selfs_from_free_param TableTests.shared_selfs_through_metatables TableTests.table_call_metamethod_basic +TableTests.table_function_check_use_after_free TableTests.table_indexing_error_location TableTests.table_insert_should_cope_with_optional_properties_in_nonstrict TableTests.table_insert_should_cope_with_optional_properties_in_strict @@ -279,13 +267,17 @@ TableTests.tables_get_names_from_their_locals TableTests.tc_member_function TableTests.tc_member_function_2 TableTests.unification_of_unions_in_a_self_referential_type +TableTests.unifying_tables_shouldnt_uaf1 TableTests.unifying_tables_shouldnt_uaf2 +TableTests.used_colon_correctly TableTests.used_colon_instead_of_dot TableTests.used_dot_instead_of_colon +TableTests.used_dot_instead_of_colon_but_correctly ToDot.bound_table ToDot.function ToDot.table ToString.exhaustive_toString_of_cyclic_table +ToString.function_type_with_argument_names_and_self ToString.function_type_with_argument_names_generic ToString.toStringDetailed2 ToString.toStringErrorPack @@ -303,6 +295,7 @@ TryUnifyTests.typepack_unification_should_trim_free_tails TryUnifyTests.variadics_should_use_reversed_properly TypeAliases.cannot_create_cyclic_type_with_unknown_module TypeAliases.forward_declared_alias_is_not_clobbered_by_prior_unification_with_any +TypeAliases.forward_declared_alias_is_not_clobbered_by_prior_unification_with_any_2 TypeAliases.generic_param_remap TypeAliases.mismatched_generic_type_param TypeAliases.mutually_recursive_types_restriction_not_ok_1 @@ -322,6 +315,7 @@ TypeInfer.checking_should_not_ice TypeInfer.cli_50041_committing_txnlog_in_apollo_client_error TypeInfer.dont_report_type_errors_within_an_AstExprError TypeInfer.dont_report_type_errors_within_an_AstStatError +TypeInfer.follow_on_new_types_in_substitution TypeInfer.fuzz_free_table_type_change_during_index_check TypeInfer.globals TypeInfer.globals2 @@ -335,11 +329,13 @@ TypeInfer.tc_interpolated_string_with_invalid_expression TypeInfer.type_infer_recursion_limit_no_ice TypeInfer.type_infer_recursion_limit_normalizer TypeInferAnyError.for_in_loop_iterator_is_any2 +TypeInferAnyError.metatable_of_any_can_be_a_table TypeInferClasses.can_read_prop_of_base_class_using_string TypeInferClasses.class_type_mismatch_with_name_conflict TypeInferClasses.classes_without_overloaded_operators_cannot_be_added TypeInferClasses.detailed_class_unification_error TypeInferClasses.higher_order_function_arguments_are_contravariant +TypeInferClasses.index_instance_property TypeInferClasses.optional_class_field_access_error TypeInferClasses.table_class_unification_reports_sane_errors_for_missing_properties TypeInferClasses.warn_when_prop_almost_matches @@ -349,7 +345,9 @@ TypeInferFunctions.calling_function_with_incorrect_argument_type_yields_errors_s TypeInferFunctions.cannot_hoist_interior_defns_into_signature TypeInferFunctions.dont_give_other_overloads_message_if_only_one_argument_matching_overload_exists TypeInferFunctions.dont_infer_parameter_types_for_functions_from_their_call_site +TypeInferFunctions.dont_mutate_the_underlying_head_of_typepack_when_calling_with_self TypeInferFunctions.duplicate_functions_with_different_signatures_not_allowed_in_nonstrict +TypeInferFunctions.first_argument_can_be_optional TypeInferFunctions.function_cast_error_uses_correct_language TypeInferFunctions.function_decl_non_self_sealed_overwrite_2 TypeInferFunctions.function_decl_non_self_unsealed_overwrite @@ -387,36 +385,45 @@ TypeInferLoops.loop_iter_no_indexer_nonstrict TypeInferLoops.loop_iter_trailing_nil TypeInferLoops.properly_infer_iteratee_is_a_free_table TypeInferLoops.unreachable_code_after_infinite_loop +TypeInferLoops.varlist_declared_by_for_in_loop_should_be_free +TypeInferModules.bound_free_table_export_is_ok TypeInferModules.custom_require_global -TypeInferModules.do_not_modify_imported_types +TypeInferModules.do_not_modify_imported_types_4 +TypeInferModules.do_not_modify_imported_types_5 TypeInferModules.module_type_conflict TypeInferModules.module_type_conflict_instantiated TypeInferModules.require_a_variadic_function TypeInferModules.type_error_of_unknown_qualified_type -TypeInferOOP.CheckMethodsOfSealed TypeInferOOP.dont_suggest_using_colon_rather_than_dot_if_another_overload_works TypeInferOOP.dont_suggest_using_colon_rather_than_dot_if_it_wont_help_2 TypeInferOOP.dont_suggest_using_colon_rather_than_dot_if_not_defined_with_colon TypeInferOOP.inferring_hundreds_of_self_calls_should_not_suffocate_memory +TypeInferOOP.method_depends_on_table TypeInferOOP.methods_are_topologically_sorted +TypeInferOOP.nonstrict_self_mismatch_tail TypeInferOOP.object_constructor_can_refer_to_method_of_self +TypeInferOOP.table_oop +TypeInferOperators.CallAndOrOfFunctions +TypeInferOperators.CallOrOfFunctions TypeInferOperators.cannot_compare_tables_that_do_not_have_the_same_metatable TypeInferOperators.cannot_indirectly_compare_types_that_do_not_have_a_metatable TypeInferOperators.cannot_indirectly_compare_types_that_do_not_offer_overloaded_ordering_operators TypeInferOperators.cli_38355_recursive_union +TypeInferOperators.compound_assign_metatable TypeInferOperators.compound_assign_mismatch_metatable TypeInferOperators.compound_assign_mismatch_op TypeInferOperators.compound_assign_mismatch_result TypeInferOperators.disallow_string_and_types_without_metatables_from_arithmetic_binary_ops TypeInferOperators.in_nonstrict_mode_strip_nil_from_intersections_when_considering_relational_operators TypeInferOperators.infer_any_in_all_modes_when_lhs_is_unknown -TypeInferOperators.mm_comparisons_must_return_a_boolean -TypeInferOperators.mm_ops_must_return_a_value +TypeInferOperators.operator_eq_completely_incompatible +TypeInferOperators.or_joins_types_with_no_superfluous_union TypeInferOperators.produce_the_correct_error_message_when_comparing_a_table_with_a_metatable_with_one_that_does_not -TypeInferOperators.refine_and_or TypeInferOperators.typecheck_overloaded_multiply_that_is_an_intersection TypeInferOperators.typecheck_overloaded_multiply_that_is_an_intersection_on_rhs TypeInferOperators.UnknownGlobalCompoundAssign +TypeInferOperators.unrelated_classes_cannot_be_compared +TypeInferOperators.unrelated_primitives_cannot_be_compared TypeInferPrimitives.CheckMethodsOfNumber TypeInferPrimitives.string_index TypeInferUnknownNever.assign_to_global_which_is_never @@ -432,6 +439,7 @@ TypeInferUnknownNever.type_packs_containing_never_is_itself_uninhabitable2 TypeInferUnknownNever.unary_minus_of_never TypePackTests.detect_cyclic_typepacks2 TypePackTests.pack_tail_unification_check +TypePackTests.self_and_varargs_should_work TypePackTests.type_alias_backwards_compatible TypePackTests.type_alias_default_export TypePackTests.type_alias_default_mixed_self From 9958d23caa8437418189b38adf55ebd34e61b15e Mon Sep 17 00:00:00 2001 From: Vyacheslav Egorov Date: Tue, 3 Jan 2023 19:33:19 +0200 Subject: [PATCH 22/66] Sync to upstream/release/557 --- Analysis/include/Luau/Anyification.h | 8 +- Analysis/include/Luau/ApplyTypeFunction.h | 2 +- Analysis/include/Luau/AstQuery.h | 4 +- Analysis/include/Luau/Autocomplete.h | 6 +- Analysis/include/Luau/BuiltinDefinitions.h | 4 +- Analysis/include/Luau/Clone.h | 2 +- Analysis/include/Luau/Connective.h | 4 +- Analysis/include/Luau/Constraint.h | 8 +- .../include/Luau/ConstraintGraphBuilder.h | 6 +- Analysis/include/Luau/ConstraintSolver.h | 8 +- Analysis/include/Luau/Error.h | 2 +- Analysis/include/Luau/Frontend.h | 12 +- Analysis/include/Luau/Instantiation.h | 2 +- Analysis/include/Luau/IostreamHelpers.h | 4 +- Analysis/include/Luau/LValue.h | 4 +- Analysis/include/Luau/Module.h | 4 +- Analysis/include/Luau/Normalize.h | 54 +- Analysis/include/Luau/Predicate.h | 4 +- Analysis/include/Luau/Quantify.h | 2 +- Analysis/include/Luau/Scope.h | 2 +- Analysis/include/Luau/Substitution.h | 2 +- Analysis/include/Luau/ToDot.h | 4 +- Analysis/include/Luau/ToString.h | 18 +- Analysis/include/Luau/TxnLog.h | 18 +- Analysis/include/Luau/{TypeVar.h => Type.h} | 195 ++-- Analysis/include/Luau/TypeArena.h | 10 +- Analysis/include/Luau/TypeChecker2.h | 4 +- Analysis/include/Luau/TypeInfer.h | 20 +- Analysis/include/Luau/TypePack.h | 6 +- Analysis/include/Luau/TypeUtils.h | 12 +- Analysis/include/Luau/Unifiable.h | 4 +- Analysis/include/Luau/Unifier.h | 16 +- Analysis/include/Luau/UnifierSharedState.h | 2 +- .../Luau/{VisitTypeVar.h => VisitType.h} | 108 +- Analysis/src/Anyification.cpp | 18 +- Analysis/src/ApplyTypeFunction.cpp | 6 +- Analysis/src/AstQuery.cpp | 8 +- Analysis/src/Autocomplete.cpp | 216 ++-- Analysis/src/BuiltinDefinitions.cpp | 142 ++- Analysis/src/Clone.cpp | 126 +- Analysis/src/ConstraintGraphBuilder.cpp | 251 ++-- Analysis/src/ConstraintSolver.cpp | 263 +++-- Analysis/src/Error.cpp | 8 +- Analysis/src/Frontend.cpp | 45 +- Analysis/src/Instantiation.cpp | 24 +- Analysis/src/IostreamHelpers.cpp | 2 +- Analysis/src/Linter.cpp | 10 +- Analysis/src/Module.cpp | 28 +- Analysis/src/Normalize.cpp | 1039 +++++++++++++---- Analysis/src/Quantify.cpp | 40 +- Analysis/src/Substitution.cpp | 32 +- Analysis/src/ToDot.cpp | 68 +- Analysis/src/ToString.cpp | 116 +- Analysis/src/TxnLog.cpp | 38 +- Analysis/src/{TypeVar.cpp => Type.cpp} | 402 +++---- Analysis/src/TypeArena.cpp | 16 +- Analysis/src/TypeAttach.cpp | 54 +- Analysis/src/TypeChecker2.cpp | 188 +-- Analysis/src/TypeInfer.cpp | 481 ++++---- Analysis/src/TypePack.cpp | 6 +- Analysis/src/TypeUtils.cpp | 52 +- Analysis/src/Unifier.cpp | 405 ++++--- Ast/src/Location.cpp | 4 +- CodeGen/include/Luau/AssemblyBuilderX64.h | 7 + CodeGen/src/AssemblyBuilderX64.cpp | 35 + CodeGen/src/EmitBuiltinsX64.cpp | 376 ++++++ CodeGen/src/Fallbacks.cpp | 1 - CodeGen/src/NativeState.cpp | 16 + CodeGen/src/NativeState.h | 15 + Compiler/src/Compiler.cpp | 2 +- Sources.cmake | 8 +- fuzz/proto.cpp | 16 +- tests/AssemblyBuilderX64.test.cpp | 9 + tests/Autocomplete.test.cpp | 10 +- tests/BuiltinDefinitions.test.cpp | 8 +- tests/ClassFixture.cpp | 56 +- tests/Conformance.test.cpp | 26 +- tests/ConstraintGraphBuilderFixture.cpp | 4 +- tests/ConstraintGraphBuilderFixture.h | 2 +- tests/Fixture.cpp | 17 +- tests/Fixture.h | 6 +- tests/Frontend.test.cpp | 2 +- tests/LValue.test.cpp | 48 +- tests/Linter.test.cpp | 12 +- tests/Module.test.cpp | 50 +- tests/NonstrictMode.test.cpp | 10 +- tests/Normalize.test.cpp | 81 +- tests/ToDot.test.cpp | 68 +- tests/ToString.test.cpp | 157 +-- tests/Transpiler.test.cpp | 2 +- tests/TypeInfer.aliases.test.cpp | 28 +- tests/TypeInfer.annotations.test.cpp | 32 +- tests/TypeInfer.anyerror.test.cpp | 8 +- tests/TypeInfer.builtins.test.cpp | 12 +- tests/TypeInfer.classes.test.cpp | 2 +- tests/TypeInfer.definitions.test.cpp | 6 +- tests/TypeInfer.functions.test.cpp | 42 +- tests/TypeInfer.generics.test.cpp | 22 +- tests/TypeInfer.intersectionTypes.test.cpp | 8 +- tests/TypeInfer.loops.test.cpp | 4 +- tests/TypeInfer.modules.test.cpp | 10 +- tests/TypeInfer.negations.test.cpp | 2 + tests/TypeInfer.oop.test.cpp | 10 +- tests/TypeInfer.operators.test.cpp | 12 +- tests/TypeInfer.primitives.test.cpp | 4 +- tests/TypeInfer.provisional.test.cpp | 18 +- tests/TypeInfer.refinements.test.cpp | 23 +- tests/TypeInfer.tables.test.cpp | 116 +- tests/TypeInfer.test.cpp | 26 +- tests/TypeInfer.tryUnify.test.cpp | 84 +- tests/TypeInfer.typePacks.cpp | 26 +- tests/TypeInfer.unionTypes.test.cpp | 4 +- tests/TypePack.test.cpp | 16 +- tests/TypeVar.test.cpp | 176 +-- ...sitTypeVar.test.cpp => VisitType.test.cpp} | 0 tools/faillist.txt | 5 +- tools/natvis/Analysis.natvis | 4 +- 117 files changed, 3810 insertions(+), 2583 deletions(-) rename Analysis/include/Luau/{TypeVar.h => Type.h} (77%) rename Analysis/include/Luau/{VisitTypeVar.h => VisitType.h} (70%) rename Analysis/src/{TypeVar.cpp => Type.cpp} (71%) rename tests/{VisitTypeVar.test.cpp => VisitType.test.cpp} (100%) diff --git a/Analysis/include/Luau/Anyification.h b/Analysis/include/Luau/Anyification.h index a6f3e2a90..7b6f71716 100644 --- a/Analysis/include/Luau/Anyification.h +++ b/Analysis/include/Luau/Anyification.h @@ -4,7 +4,7 @@ #include "Luau/NotNull.h" #include "Luau/Substitution.h" -#include "Luau/TypeVar.h" +#include "Luau/Type.h" #include @@ -19,12 +19,12 @@ using ScopePtr = std::shared_ptr; // A substitution which replaces free types by any struct Anyification : Substitution { - Anyification(TypeArena* arena, NotNull scope, NotNull singletonTypes, InternalErrorReporter* iceHandler, TypeId anyType, + Anyification(TypeArena* arena, NotNull scope, NotNull builtinTypes, InternalErrorReporter* iceHandler, TypeId anyType, TypePackId anyTypePack); - Anyification(TypeArena* arena, const ScopePtr& scope, NotNull singletonTypes, InternalErrorReporter* iceHandler, TypeId anyType, + Anyification(TypeArena* arena, const ScopePtr& scope, NotNull builtinTypes, InternalErrorReporter* iceHandler, TypeId anyType, TypePackId anyTypePack); NotNull scope; - NotNull singletonTypes; + NotNull builtinTypes; InternalErrorReporter* iceHandler; TypeId anyType; diff --git a/Analysis/include/Luau/ApplyTypeFunction.h b/Analysis/include/Luau/ApplyTypeFunction.h index 8da3bc42d..3f5f47fd4 100644 --- a/Analysis/include/Luau/ApplyTypeFunction.h +++ b/Analysis/include/Luau/ApplyTypeFunction.h @@ -3,7 +3,7 @@ #include "Luau/Substitution.h" #include "Luau/TxnLog.h" -#include "Luau/TypeVar.h" +#include "Luau/Type.h" namespace Luau { diff --git a/Analysis/include/Luau/AstQuery.h b/Analysis/include/Luau/AstQuery.h index 950a19dac..bf7384623 100644 --- a/Analysis/include/Luau/AstQuery.h +++ b/Analysis/include/Luau/AstQuery.h @@ -13,8 +13,8 @@ struct Binding; struct SourceModule; struct Module; -struct TypeVar; -using TypeId = const TypeVar*; +struct Type; +using TypeId = const Type*; using ScopePtr = std::shared_ptr; diff --git a/Analysis/include/Luau/Autocomplete.h b/Analysis/include/Luau/Autocomplete.h index f40f8b492..a4101e162 100644 --- a/Analysis/include/Luau/Autocomplete.h +++ b/Analysis/include/Luau/Autocomplete.h @@ -2,7 +2,7 @@ #pragma once #include "Luau/Location.h" -#include "Luau/TypeVar.h" +#include "Luau/Type.h" #include #include @@ -65,7 +65,7 @@ struct AutocompleteEntry // Set if this suggestion matches the type expected in the context TypeCorrectKind typeCorrect = TypeCorrectKind::None; - std::optional containingClass = std::nullopt; + std::optional containingClass = std::nullopt; std::optional prop = std::nullopt; std::optional documentationSymbol = std::nullopt; Tags tags; @@ -89,7 +89,7 @@ struct AutocompleteResult }; using ModuleName = std::string; -using StringCompletionCallback = std::function(std::string tag, std::optional ctx)>; +using StringCompletionCallback = std::function(std::string tag, std::optional ctx)>; AutocompleteResult autocomplete(Frontend& frontend, const ModuleName& moduleName, Position position, StringCompletionCallback callback); diff --git a/Analysis/include/Luau/BuiltinDefinitions.h b/Analysis/include/Luau/BuiltinDefinitions.h index 16cccafe9..0604b40e2 100644 --- a/Analysis/include/Luau/BuiltinDefinitions.h +++ b/Analysis/include/Luau/BuiltinDefinitions.h @@ -2,7 +2,7 @@ #pragma once #include "Luau/Scope.h" -#include "Luau/TypeVar.h" +#include "Luau/Type.h" #include @@ -48,7 +48,7 @@ void attachDcrMagicFunction(TypeId ty, DcrMagicFunction fn); void attachDcrMagicRefinement(TypeId ty, DcrMagicRefinement fn); Property makeProperty(TypeId ty, std::optional documentationSymbol = std::nullopt); -void assignPropDocumentationSymbols(TableTypeVar::Props& props, const std::string& baseName); +void assignPropDocumentationSymbols(TableType::Props& props, const std::string& baseName); std::string getBuiltinDefinitionSource(); diff --git a/Analysis/include/Luau/Clone.h b/Analysis/include/Luau/Clone.h index f003c2425..51f1e7a67 100644 --- a/Analysis/include/Luau/Clone.h +++ b/Analysis/include/Luau/Clone.h @@ -3,7 +3,7 @@ #include #include "Luau/TypeArena.h" -#include "Luau/TypeVar.h" +#include "Luau/Type.h" #include diff --git a/Analysis/include/Luau/Connective.h b/Analysis/include/Luau/Connective.h index 4a6be93c3..d82bc4dda 100644 --- a/Analysis/include/Luau/Connective.h +++ b/Analysis/include/Luau/Connective.h @@ -10,8 +10,8 @@ namespace Luau { -struct TypeVar; -using TypeId = const TypeVar*; +struct Type; +using TypeId = const Type*; struct Negation; struct Conjunction; diff --git a/Analysis/include/Luau/Constraint.h b/Analysis/include/Luau/Constraint.h index 9eea9c288..b41329548 100644 --- a/Analysis/include/Luau/Constraint.h +++ b/Analysis/include/Luau/Constraint.h @@ -5,7 +5,7 @@ #include "Luau/Def.h" #include "Luau/DenseHash.h" #include "Luau/NotNull.h" -#include "Luau/TypeVar.h" +#include "Luau/Type.h" #include "Luau/Variant.h" #include @@ -17,8 +17,8 @@ namespace Luau struct Scope; -struct TypeVar; -using TypeId = const TypeVar*; +struct Type; +using TypeId = const Type*; struct TypePackVar; using TypePackId = const TypePackVar*; @@ -94,7 +94,7 @@ struct NameConstraint // target ~ inst target struct TypeAliasExpansionConstraint { - // Must be a PendingExpansionTypeVar. + // Must be a PendingExpansionType. TypeId target; }; diff --git a/Analysis/include/Luau/ConstraintGraphBuilder.h b/Analysis/include/Luau/ConstraintGraphBuilder.h index c25a55371..65ea5e093 100644 --- a/Analysis/include/Luau/ConstraintGraphBuilder.h +++ b/Analysis/include/Luau/ConstraintGraphBuilder.h @@ -9,7 +9,7 @@ #include "Luau/ModuleResolver.h" #include "Luau/NotNull.h" #include "Luau/Symbol.h" -#include "Luau/TypeVar.h" +#include "Luau/Type.h" #include "Luau/Variant.h" #include @@ -61,7 +61,7 @@ struct ConstraintGraphBuilder ModuleName moduleName; ModulePtr module; - NotNull singletonTypes; + NotNull builtinTypes; const NotNull arena; // The root scope of the module we're generating constraints for. // This is null when the CGB is initially constructed. @@ -114,7 +114,7 @@ struct ConstraintGraphBuilder DcrLogger* logger; ConstraintGraphBuilder(const ModuleName& moduleName, ModulePtr module, TypeArena* arena, NotNull moduleResolver, - NotNull singletonTypes, NotNull ice, const ScopePtr& globalScope, DcrLogger* logger, + NotNull builtinTypes, NotNull ice, const ScopePtr& globalScope, DcrLogger* logger, NotNull dfg); /** diff --git a/Analysis/include/Luau/ConstraintSolver.h b/Analysis/include/Luau/ConstraintSolver.h index c02cd4d5c..5c235a354 100644 --- a/Analysis/include/Luau/ConstraintSolver.h +++ b/Analysis/include/Luau/ConstraintSolver.h @@ -7,7 +7,7 @@ #include "Luau/Module.h" #include "Luau/Normalize.h" #include "Luau/ToString.h" -#include "Luau/TypeVar.h" +#include "Luau/Type.h" #include "Luau/Variant.h" #include @@ -44,7 +44,7 @@ struct HashInstantiationSignature struct ConstraintSolver { TypeArena* arena; - NotNull singletonTypes; + NotNull builtinTypes; InternalErrorReporter iceReporter; NotNull normalizer; // The entire set of constraints that the solver is trying to resolve. @@ -126,13 +126,13 @@ struct ConstraintSolver void block(NotNull target, NotNull constraint); /** - * Block a constraint on the resolution of a TypeVar. + * Block a constraint on the resolution of a Type. * @returns false always. This is just to allow tryDispatch to return the result of block() */ bool block(TypeId target, NotNull constraint); bool block(TypePackId target, NotNull constraint); - // Traverse the type. If any blocked or pending typevars are found, block + // Traverse the type. If any blocked or pending types are found, block // the constraint on them. // // Returns false if a type blocks the constraint. diff --git a/Analysis/include/Luau/Error.h b/Analysis/include/Luau/Error.h index 739354f87..69d4cca3c 100644 --- a/Analysis/include/Luau/Error.h +++ b/Analysis/include/Luau/Error.h @@ -2,7 +2,7 @@ #pragma once #include "Luau/Location.h" -#include "Luau/TypeVar.h" +#include "Luau/Type.h" #include "Luau/Variant.h" namespace Luau diff --git a/Analysis/include/Luau/Frontend.h b/Analysis/include/Luau/Frontend.h index b2662c688..dfb35cbdb 100644 --- a/Analysis/include/Luau/Frontend.h +++ b/Analysis/include/Luau/Frontend.h @@ -130,8 +130,8 @@ struct Frontend Frontend(FileResolver* fileResolver, ConfigResolver* configResolver, const FrontendOptions& options = {}); CheckResult check(const ModuleName& name, std::optional optionOverride = {}); // new shininess - LintResult lint(const ModuleName& name, std::optional enabledLintWarnings = {}); + LintResult lint(const ModuleName& name, std::optional enabledLintWarnings = {}); LintResult lint(const SourceModule& module, std::optional enabledLintWarnings = {}); bool isDirty(const ModuleName& name, bool forAutocomplete = false) const; @@ -165,22 +165,22 @@ struct Frontend ModulePtr check(const SourceModule& sourceModule, Mode mode, const ScopePtr& environmentScope, std::vector requireCycles, bool forAutocomplete = false); - std::pair getSourceNode(CheckResult& checkResult, const ModuleName& name); + std::pair getSourceNode(const ModuleName& name); SourceModule parse(const ModuleName& name, std::string_view src, const ParseOptions& parseOptions); - bool parseGraph(std::vector& buildQueue, CheckResult& checkResult, const ModuleName& root, bool forAutocomplete); + bool parseGraph(std::vector& buildQueue, const ModuleName& root, bool forAutocomplete); static LintResult classifyLints(const std::vector& warnings, const Config& config); - ScopePtr getModuleEnvironment(const SourceModule& module, const Config& config, bool forAutocomplete = false); + ScopePtr getModuleEnvironment(const SourceModule& module, const Config& config, bool forAutocomplete); std::unordered_map environments; std::unordered_map> builtinDefinitions; - SingletonTypes singletonTypes_; + BuiltinTypes builtinTypes_; public: - const NotNull singletonTypes; + const NotNull builtinTypes; FileResolver* fileResolver; FrontendModuleResolver moduleResolver; diff --git a/Analysis/include/Luau/Instantiation.h b/Analysis/include/Luau/Instantiation.h index cd88d33a4..c916f953b 100644 --- a/Analysis/include/Luau/Instantiation.h +++ b/Analysis/include/Luau/Instantiation.h @@ -2,7 +2,7 @@ #pragma once #include "Luau/Substitution.h" -#include "Luau/TypeVar.h" +#include "Luau/Type.h" #include "Luau/Unifiable.h" namespace Luau diff --git a/Analysis/include/Luau/IostreamHelpers.h b/Analysis/include/Luau/IostreamHelpers.h index 05b94516b..42b362bee 100644 --- a/Analysis/include/Luau/IostreamHelpers.h +++ b/Analysis/include/Luau/IostreamHelpers.h @@ -3,7 +3,7 @@ #include "Luau/Error.h" #include "Luau/Location.h" -#include "Luau/TypeVar.h" +#include "Luau/Type.h" #include "Luau/Ast.h" #include @@ -43,7 +43,7 @@ std::ostream& operator<<(std::ostream& lhs, const MissingUnionProperty& error); std::ostream& operator<<(std::ostream& lhs, const TypesAreUnrelated& error); std::ostream& operator<<(std::ostream& lhs, const TableState& tv); -std::ostream& operator<<(std::ostream& lhs, const TypeVar& tv); +std::ostream& operator<<(std::ostream& lhs, const Type& tv); std::ostream& operator<<(std::ostream& lhs, const TypePackVar& tv); std::ostream& operator<<(std::ostream& lhs, const TypeErrorData& ted); diff --git a/Analysis/include/Luau/LValue.h b/Analysis/include/Luau/LValue.h index 518cbfafe..9a8b863b3 100644 --- a/Analysis/include/Luau/LValue.h +++ b/Analysis/include/Luau/LValue.h @@ -10,8 +10,8 @@ namespace Luau { -struct TypeVar; -using TypeId = const TypeVar*; +struct Type; +using TypeId = const Type*; struct Field; diff --git a/Analysis/include/Luau/Module.h b/Analysis/include/Luau/Module.h index d22aad12c..d6d9f841b 100644 --- a/Analysis/include/Luau/Module.h +++ b/Analysis/include/Luau/Module.h @@ -89,8 +89,8 @@ struct Module ScopePtr getModuleScope() const; // Once a module has been typechecked, we clone its public interface into a separate arena. - // This helps us to force TypeVar ownership into a DAG rather than a DCG. - void clonePublicInterface(NotNull singletonTypes, InternalErrorReporter& ice); + // This helps us to force Type ownership into a DAG rather than a DCG. + void clonePublicInterface(NotNull builtinTypes, InternalErrorReporter& ice); }; } // namespace Luau diff --git a/Analysis/include/Luau/Normalize.h b/Analysis/include/Luau/Normalize.h index 392573155..865a9c4d3 100644 --- a/Analysis/include/Luau/Normalize.h +++ b/Analysis/include/Luau/Normalize.h @@ -2,7 +2,7 @@ #pragma once #include "Luau/NotNull.h" -#include "Luau/TypeVar.h" +#include "Luau/Type.h" #include "Luau/UnifierSharedState.h" #include @@ -13,12 +13,12 @@ namespace Luau struct InternalErrorReporter; struct Module; struct Scope; -struct SingletonTypes; +struct BuiltinTypes; using ModulePtr = std::shared_ptr; -bool isSubtype(TypeId subTy, TypeId superTy, NotNull scope, NotNull singletonTypes, InternalErrorReporter& ice); -bool isSubtype(TypePackId subTy, TypePackId superTy, NotNull scope, NotNull singletonTypes, InternalErrorReporter& ice); +bool isSubtype(TypeId subTy, TypeId superTy, NotNull scope, NotNull builtinTypes, InternalErrorReporter& ice); +bool isSubtype(TypePackId subTy, TypePackId superTy, NotNull scope, NotNull builtinTypes, InternalErrorReporter& ice); class TypeIds { @@ -31,7 +31,7 @@ class TypeIds using iterator = std::vector::iterator; using const_iterator = std::vector::const_iterator; - TypeIds(const TypeIds&) = delete; + TypeIds(const TypeIds&) = default; TypeIds(TypeIds&&) = default; TypeIds() = default; ~TypeIds() = default; @@ -155,6 +155,32 @@ struct NormalizedStringType bool isSubtype(const NormalizedStringType& subStr, const NormalizedStringType& superStr); +struct NormalizedClassType +{ + /** Has the following structure: + * + * (C1 & ~N11 & ... & ~Nn) | (C2 & ~N21 & ... & ~N2n) | ... + * + * C2 is either not a subtype of any other Cm, or it is and is also a + * subtype of one of Nmn types within the same cluster. + * + * Each TypeId is a class type. + */ + std::unordered_map classes; + + /** + * In order to maintain a consistent insertion order, we use this vector to + * keep track of it. An ordered std::map will sort by pointer identity, + * which is undesirable. + */ + std::vector ordering; + + void pushPair(TypeId ty, TypeIds negations); + + void resetToNever(); + bool isNever() const; +}; + // A normalized function type can be `never`, the top function type `function`, // or an intersection of function types. // @@ -200,9 +226,11 @@ struct NormalizedType // This type is either never, boolean type, or a boolean singleton. TypeId booleans; + NormalizedClassType classes; + // The class part of the type. // Each element of this set is a class, and none of the classes are subclasses of each other. - TypeIds classes; + TypeIds DEPRECATED_classes; // The error part of the type. // This type is either never or the error type. @@ -234,7 +262,7 @@ struct NormalizedType // The generic/free part of the type. NormalizedTyvars tyvars; - NormalizedType(NotNull singletonTypes); + NormalizedType(NotNull builtinTypes); NormalizedType() = delete; ~NormalizedType() = default; @@ -256,10 +284,10 @@ class Normalizer public: TypeArena* arena; - NotNull singletonTypes; + NotNull builtinTypes; NotNull sharedState; - Normalizer(TypeArena* arena, NotNull singletonTypes, NotNull sharedState); + Normalizer(TypeArena* arena, NotNull builtinTypes, NotNull sharedState); Normalizer(const Normalizer&) = delete; Normalizer(Normalizer&&) = delete; Normalizer() = delete; @@ -283,6 +311,8 @@ class Normalizer TypeId unionOfBools(TypeId here, TypeId there); void unionClassesWithClass(TypeIds& heres, TypeId there); void unionClasses(TypeIds& heres, const TypeIds& theres); + void unionClassesWithClass(NormalizedClassType& heres, TypeId there); + void unionClasses(NormalizedClassType& heres, const NormalizedClassType& theres); void unionStrings(NormalizedStringType& here, const NormalizedStringType& there); std::optional unionOfTypePacks(TypePackId here, TypePackId there); std::optional unionOfFunctions(TypeId here, TypeId there); @@ -304,8 +334,10 @@ class Normalizer // ------- Normalizing intersections TypeId intersectionOfTops(TypeId here, TypeId there); TypeId intersectionOfBools(TypeId here, TypeId there); - void intersectClasses(TypeIds& heres, const TypeIds& theres); - void intersectClassesWithClass(TypeIds& heres, TypeId there); + void DEPRECATED_intersectClasses(TypeIds& heres, const TypeIds& theres); + void DEPRECATED_intersectClassesWithClass(TypeIds& heres, TypeId there); + void intersectClasses(NormalizedClassType& heres, const NormalizedClassType& theres); + void intersectClassesWithClass(NormalizedClassType& heres, TypeId there); void intersectStrings(NormalizedStringType& here, const NormalizedStringType& there); std::optional intersectionOfTypePacks(TypePackId here, TypePackId there); std::optional intersectionOfTables(TypeId here, TypeId there); diff --git a/Analysis/include/Luau/Predicate.h b/Analysis/include/Luau/Predicate.h index df93b4f49..8d486ad51 100644 --- a/Analysis/include/Luau/Predicate.h +++ b/Analysis/include/Luau/Predicate.h @@ -10,8 +10,8 @@ namespace Luau { -struct TypeVar; -using TypeId = const TypeVar*; +struct Type; +using TypeId = const Type*; struct TruthyPredicate; struct IsAPredicate; diff --git a/Analysis/include/Luau/Quantify.h b/Analysis/include/Luau/Quantify.h index 7edf23b8c..b350fab52 100644 --- a/Analysis/include/Luau/Quantify.h +++ b/Analysis/include/Luau/Quantify.h @@ -1,7 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once -#include "Luau/TypeVar.h" +#include "Luau/Type.h" namespace Luau { diff --git a/Analysis/include/Luau/Scope.h b/Analysis/include/Luau/Scope.h index 851ed1a7d..797c9cb04 100644 --- a/Analysis/include/Luau/Scope.h +++ b/Analysis/include/Luau/Scope.h @@ -3,7 +3,7 @@ #include "Luau/Location.h" #include "Luau/NotNull.h" -#include "Luau/TypeVar.h" +#include "Luau/Type.h" #include #include diff --git a/Analysis/include/Luau/Substitution.h b/Analysis/include/Luau/Substitution.h index 6ad38f9de..2efca2df5 100644 --- a/Analysis/include/Luau/Substitution.h +++ b/Analysis/include/Luau/Substitution.h @@ -3,7 +3,7 @@ #include "Luau/TypeArena.h" #include "Luau/TypePack.h" -#include "Luau/TypeVar.h" +#include "Luau/Type.h" #include "Luau/DenseHash.h" // We provide an implementation of substitution on types, diff --git a/Analysis/include/Luau/ToDot.h b/Analysis/include/Luau/ToDot.h index ce518d3ae..1a9c2811a 100644 --- a/Analysis/include/Luau/ToDot.h +++ b/Analysis/include/Luau/ToDot.h @@ -7,8 +7,8 @@ namespace Luau { -struct TypeVar; -using TypeId = const TypeVar*; +struct Type; +using TypeId = const Type*; struct TypePackVar; using TypePackId = const TypePackVar*; diff --git a/Analysis/include/Luau/ToString.h b/Analysis/include/Luau/ToString.h index 71c0e3595..dd8aef574 100644 --- a/Analysis/include/Luau/ToString.h +++ b/Analysis/include/Luau/ToString.h @@ -19,13 +19,13 @@ class AstExpr; struct Scope; -struct TypeVar; -using TypeId = const TypeVar*; +struct Type; +using TypeId = const Type*; struct TypePackVar; using TypePackId = const TypePackVar*; -struct FunctionTypeVar; +struct FunctionType; struct Constraint; struct Position; @@ -33,7 +33,7 @@ struct Location; struct ToStringNameMap { - std::unordered_map typeVars; + std::unordered_map types; std::unordered_map typePacks; }; @@ -46,7 +46,7 @@ struct ToStringOptions bool hideNamedFunctionTypeParameters = false; // If true, type parameters of functions will be hidden at top-level. bool hideFunctionSelfArgument = false; // If true, `self: X` will be omitted from the function signature if the function has self bool DEPRECATED_indent = false; // TODO Deprecated field, prune when clipping flag FFlagLuauLineBreaksDeterminIndents - size_t maxTableLength = size_t(FInt::LuauTableTypeMaximumStringifierLength); // Only applied to TableTypeVars + size_t maxTableLength = size_t(FInt::LuauTableTypeMaximumStringifierLength); // Only applied to TableTypes size_t maxTypeLength = size_t(FInt::LuauTypeMaximumStringifierLength); ToStringNameMap nameMap; std::shared_ptr scope; // If present, module names will be added and types that are not available in scope will be marked as 'invalid' @@ -105,10 +105,10 @@ inline std::string toString(const Constraint& c) return toString(c, ToStringOptions{}); } -std::string toString(const TypeVar& tv, ToStringOptions& opts); +std::string toString(const Type& tv, ToStringOptions& opts); std::string toString(const TypePackVar& tp, ToStringOptions& opts); -inline std::string toString(const TypeVar& tv) +inline std::string toString(const Type& tv) { ToStringOptions opts; return toString(tv, opts); @@ -120,9 +120,9 @@ inline std::string toString(const TypePackVar& tp) return toString(tp, opts); } -std::string toStringNamedFunction(const std::string& funcName, const FunctionTypeVar& ftv, ToStringOptions& opts); +std::string toStringNamedFunction(const std::string& funcName, const FunctionType& ftv, ToStringOptions& opts); -inline std::string toStringNamedFunction(const std::string& funcName, const FunctionTypeVar& ftv) +inline std::string toStringNamedFunction(const std::string& funcName, const FunctionType& ftv) { ToStringOptions opts; return toStringNamedFunction(funcName, ftv, opts); diff --git a/Analysis/include/Luau/TxnLog.h b/Analysis/include/Luau/TxnLog.h index 82605bff7..0ed8a49ad 100644 --- a/Analysis/include/Luau/TxnLog.h +++ b/Analysis/include/Luau/TxnLog.h @@ -1,7 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once -#include "Luau/TypeVar.h" +#include "Luau/Type.h" #include "Luau/TypePack.h" #include @@ -12,14 +12,14 @@ namespace Luau using TypeOrPackId = const void*; -// Pending state for a TypeVar. Generated by a TxnLog and committed via +// Pending state for a Type. Generated by a TxnLog and committed via // TxnLog::commit. struct PendingType { - // The pending TypeVar state. - TypeVar pending; + // The pending Type state. + Type pending; - explicit PendingType(TypeVar state) + explicit PendingType(Type state) : pending(std::move(state)) { } @@ -163,7 +163,7 @@ struct TxnLog // Queues a replacement of a type with another type. // // The pointer returned lives until `commit` or `clear` is called. - PendingType* replace(TypeId ty, TypeVar replacement); + PendingType* replace(TypeId ty, Type replacement); // Queues a replacement of a type pack with another type pack. // @@ -225,7 +225,7 @@ struct TxnLog template PendingType* replace(TypeId ty, T replacement) { - return replace(ty, TypeVar(replacement)); + return replace(ty, Type(replacement)); } // Replaces a given type pack's state with a new variant. Returns the new @@ -262,12 +262,12 @@ struct TxnLog // Returns whether a given type or type pack is a given state, respecting the // log's pending state. // - // This method will not assert if called on a BoundTypeVar or BoundTypePack. + // This method will not assert if called on a BoundType or BoundTypePack. template bool is(TID ty) const { // We do not use getMutable here because this method can be called on - // BoundTypeVars, which triggers an assertion. + // BoundTypes, which triggers an assertion. auto* pendingTy = pending(ty); if (pendingTy) return Luau::get_if(&pendingTy->pending.ty) != nullptr; diff --git a/Analysis/include/Luau/TypeVar.h b/Analysis/include/Luau/Type.h similarity index 77% rename from Analysis/include/Luau/TypeVar.h rename to Analysis/include/Luau/Type.h index 852a40547..fcc073d8b 100644 --- a/Analysis/include/Luau/TypeVar.h +++ b/Analysis/include/Luau/Type.h @@ -69,46 +69,45 @@ using ScopePtr = std::shared_ptr; struct TypePackVar; using TypePackId = const TypePackVar*; -// TODO: rename to Type? CLI-39100 -struct TypeVar; +struct Type; // Should never be null -using TypeId = const TypeVar*; +using TypeId = const Type*; using Name = std::string; // A free type var is one whose exact shape has yet to be fully determined. -using FreeTypeVar = Unifiable::Free; +using FreeType = Unifiable::Free; // When a free type var is unified with any other, it is then "bound" // to that type var, indicating that the two types are actually the same type. -using BoundTypeVar = Unifiable::Bound; +using BoundType = Unifiable::Bound; -using GenericTypeVar = Unifiable::Generic; +using GenericType = Unifiable::Generic; using Tags = std::vector; using ModuleName = std::string; -/** A TypeVar that cannot be computed. +/** A Type that cannot be computed. * - * BlockedTypeVars essentially serve as a way to encode partial ordering on the - * constraint graph. Until a BlockedTypeVar is unblocked by its owning + * BlockedTypes essentially serve as a way to encode partial ordering on the + * constraint graph. Until a BlockedType is unblocked by its owning * constraint, nothing at all can be said about it. Constraints that need to - * process a BlockedTypeVar cannot be dispatched. + * process a BlockedType cannot be dispatched. * - * Whenever a BlockedTypeVar is added to the graph, we also record a constraint + * Whenever a BlockedType is added to the graph, we also record a constraint * that will eventually unblock it. */ -struct BlockedTypeVar +struct BlockedType { - BlockedTypeVar(); + BlockedType(); int index; static int nextIndex; }; -struct PrimitiveTypeVar +struct PrimitiveType { enum Type { @@ -123,12 +122,12 @@ struct PrimitiveTypeVar Type type; std::optional metatable; // string has a metatable - explicit PrimitiveTypeVar(Type type) + explicit PrimitiveType(Type type) : type(type) { } - explicit PrimitiveTypeVar(Type type, TypeId metatable) + explicit PrimitiveType(Type type, TypeId metatable) : type(type) , metatable(metatable) { @@ -173,25 +172,25 @@ struct StringSingleton using SingletonVariant = Luau::Variant; -struct SingletonTypeVar +struct SingletonType { - explicit SingletonTypeVar(const SingletonVariant& variant) + explicit SingletonType(const SingletonVariant& variant) : variant(variant) { } - explicit SingletonTypeVar(SingletonVariant&& variant) + explicit SingletonType(SingletonVariant&& variant) : variant(std::move(variant)) { } // Default operator== is C++20. - bool operator==(const SingletonTypeVar& rhs) const + bool operator==(const SingletonType& rhs) const { return variant == rhs.variant; } - bool operator!=(const SingletonTypeVar& rhs) const + bool operator!=(const SingletonType& rhs) const { return !(*this == rhs); } @@ -200,7 +199,7 @@ struct SingletonTypeVar }; template -const T* get(const SingletonTypeVar* stv) +const T* get(const SingletonType* stv) { if (stv) return get_if(&stv->variant); @@ -240,7 +239,7 @@ struct FunctionDefinition // TODO: Come up with a better name. // TODO: Do we actually need this? We'll find out later if we can delete this. -// Does not exactly belong in TypeVar.h, but this is the only way to appease the compiler. +// Does not exactly belong in Type.h, but this is the only way to appease the compiler. template struct WithPredicate { @@ -273,24 +272,24 @@ struct MagicRefinementContext using DcrMagicRefinement = std::vector (*)(const MagicRefinementContext&); -struct FunctionTypeVar +struct FunctionType { // Global monomorphic function - FunctionTypeVar(TypePackId argTypes, TypePackId retTypes, std::optional defn = {}, bool hasSelf = false); + FunctionType(TypePackId argTypes, TypePackId retTypes, std::optional defn = {}, bool hasSelf = false); // Global polymorphic function - FunctionTypeVar(std::vector generics, std::vector genericPacks, TypePackId argTypes, TypePackId retTypes, + FunctionType(std::vector generics, std::vector genericPacks, TypePackId argTypes, TypePackId retTypes, std::optional defn = {}, bool hasSelf = false); // Local monomorphic function - FunctionTypeVar(TypeLevel level, TypePackId argTypes, TypePackId retTypes, std::optional defn = {}, bool hasSelf = false); - FunctionTypeVar( + FunctionType(TypeLevel level, TypePackId argTypes, TypePackId retTypes, std::optional defn = {}, bool hasSelf = false); + FunctionType( TypeLevel level, Scope* scope, TypePackId argTypes, TypePackId retTypes, std::optional defn = {}, bool hasSelf = false); // Local polymorphic function - FunctionTypeVar(TypeLevel level, std::vector generics, std::vector genericPacks, TypePackId argTypes, TypePackId retTypes, + FunctionType(TypeLevel level, std::vector generics, std::vector genericPacks, TypePackId argTypes, TypePackId retTypes, std::optional defn = {}, bool hasSelf = false); - FunctionTypeVar(TypeLevel level, Scope* scope, std::vector generics, std::vector genericPacks, TypePackId argTypes, + FunctionType(TypeLevel level, Scope* scope, std::vector generics, std::vector genericPacks, TypePackId argTypes, TypePackId retTypes, std::optional defn = {}, bool hasSelf = false); std::optional definition; @@ -348,7 +347,7 @@ struct Property std::optional documentationSymbol; }; -struct TableTypeVar +struct TableType { // We choose std::map over unordered_map here just because we have unit tests that compare // textual outputs. I don't want to spend the effort making them resilient in the case where @@ -356,10 +355,10 @@ struct TableTypeVar // If this shows up in a profile, we can revisit it. using Props = std::map; - TableTypeVar() = default; - explicit TableTypeVar(TableState state, TypeLevel level, Scope* scope = nullptr); - TableTypeVar(const Props& props, const std::optional& indexer, TypeLevel level, TableState state); - TableTypeVar(const Props& props, const std::optional& indexer, TypeLevel level, Scope* scope, TableState state); + TableType() = default; + explicit TableType(TableState state, TypeLevel level, Scope* scope = nullptr); + TableType(const Props& props, const std::optional& indexer, TypeLevel level, TableState state); + TableType(const Props& props, const std::optional& indexer, TypeLevel level, Scope* scope, TableState state); Props props; std::optional indexer; @@ -384,12 +383,12 @@ struct TableTypeVar std::optional selfTy; }; -// Represents a metatable attached to a table typevar. Somewhat analogous to a bound typevar. -struct MetatableTypeVar +// Represents a metatable attached to a table type. Somewhat analogous to a bound type. +struct MetatableType { - // Always points to a TableTypeVar. + // Always points to a TableType. TypeId table; - // Always points to either a TableTypeVar or a MetatableTypeVar. + // Always points to either a TableType or a MetatableType. TypeId metatable; std::optional syntheticName; @@ -409,9 +408,9 @@ struct ClassUserData * Classes optionally have a parent class. * Two different classes that share the same properties are nevertheless distinct and mutually incompatible. */ -struct ClassTypeVar +struct ClassType { - using Props = TableTypeVar::Props; + using Props = TableType::Props; Name name; Props props; @@ -421,7 +420,7 @@ struct ClassTypeVar std::shared_ptr userData; ModuleName definitionModuleName; - ClassTypeVar(Name name, Props props, std::optional parent, std::optional metatable, Tags tags, + ClassType(Name name, Props props, std::optional parent, std::optional metatable, Tags tags, std::shared_ptr userData, ModuleName definitionModuleName) : name(name) , props(props) @@ -474,13 +473,13 @@ struct TypeFun * * In order to afford (co)recursive type aliases, we need to reason about a * partially-complete instantiation. This requires encoding more information in - * a type variable than a BlockedTypeVar affords, hence this. Each - * PendingExpansionTypeVar has a corresponding TypeAliasExpansionConstraint + * a type variable than a BlockedType affords, hence this. Each + * PendingExpansionType has a corresponding TypeAliasExpansionConstraint * enqueued in the solver to convert it to an actual instantiated type */ -struct PendingExpansionTypeVar +struct PendingExpansionType { - PendingExpansionTypeVar(std::optional prefix, AstName name, std::vector typeArguments, std::vector packArguments); + PendingExpansionType(std::optional prefix, AstName name, std::vector typeArguments, std::vector packArguments); std::optional prefix; AstName name; std::vector typeArguments; @@ -491,69 +490,68 @@ struct PendingExpansionTypeVar }; // Anything! All static checking is off. -struct AnyTypeVar +struct AnyType { }; // T | U -struct UnionTypeVar +struct UnionType { std::vector options; }; // T & U -struct IntersectionTypeVar +struct IntersectionType { std::vector parts; }; -struct LazyTypeVar +struct LazyType { std::function thunk; }; -struct UnknownTypeVar +struct UnknownType { }; -struct NeverTypeVar +struct NeverType { }; // ~T // TODO: Some simplification step that overwrites the type graph to make sure negation // types disappear from the user's view, and (?) a debug flag to disable that -struct NegationTypeVar +struct NegationType { TypeId ty; }; -using ErrorTypeVar = Unifiable::Error; +using ErrorType = Unifiable::Error; -using TypeVariant = - Unifiable::Variant; +using TypeVariant = Unifiable::Variant; -struct TypeVar final +struct Type final { - explicit TypeVar(const TypeVariant& ty) + explicit Type(const TypeVariant& ty) : ty(ty) { } - explicit TypeVar(TypeVariant&& ty) + explicit Type(TypeVariant&& ty) : ty(std::move(ty)) { } - TypeVar(const TypeVariant& ty, bool persistent) + Type(const TypeVariant& ty, bool persistent) : ty(ty) , persistent(persistent) { } // Re-assignes the content of the type, but doesn't change the owning arena and can't make type persistent. - void reassign(const TypeVar& rhs) + void reassign(const Type& rhs) { ty = rhs.ty; documentationSymbol = rhs.documentationSymbol; @@ -561,9 +559,9 @@ struct TypeVar final TypeVariant ty; - // Kludge: A persistent TypeVar is one that belongs to the global scope. + // Kludge: A persistent Type is one that belongs to the global scope. // Global type bindings are immutable but are reused many times. - // Persistent TypeVars do not get cloned. + // Persistent Types do not get cloned. bool persistent = false; std::optional documentationSymbol; @@ -571,25 +569,25 @@ struct TypeVar final // Pointer to the type arena that allocated this type. TypeArena* owningArena = nullptr; - bool operator==(const TypeVar& rhs) const; - bool operator!=(const TypeVar& rhs) const; + bool operator==(const Type& rhs) const; + bool operator!=(const Type& rhs) const; - TypeVar& operator=(const TypeVariant& rhs); - TypeVar& operator=(TypeVariant&& rhs); + Type& operator=(const TypeVariant& rhs); + Type& operator=(TypeVariant&& rhs); - TypeVar& operator=(const TypeVar& rhs); + Type& operator=(const Type& rhs); }; using SeenSet = std::set>; -bool areEqual(SeenSet& seen, const TypeVar& lhs, const TypeVar& rhs); +bool areEqual(SeenSet& seen, const Type& lhs, const Type& rhs); -// Follow BoundTypeVars until we get to something real +// Follow BoundTypes until we get to something real TypeId follow(TypeId t); TypeId follow(TypeId t, std::function mapper); std::vector flattenIntersection(TypeId ty); -bool isPrim(TypeId ty, PrimitiveTypeVar::Type primType); +bool isPrim(TypeId ty, PrimitiveType::Type primType); bool isNil(TypeId ty); bool isBoolean(TypeId ty); bool isNumber(TypeId ty); @@ -602,9 +600,9 @@ bool isOverloadedFunction(TypeId ty); // True when string is a subtype of ty bool maybeString(TypeId ty); -std::optional getMetatable(TypeId type, NotNull singletonTypes); -TableTypeVar* getMutableTableType(TypeId type); -const TableTypeVar* getTableType(TypeId type); +std::optional getMetatable(TypeId type, NotNull builtinTypes); +TableType* getMutableTableType(TypeId type); +const TableType* getTableType(TypeId type); // If the type has a name, return that. Else if it has a synthetic name, return that. // Returns nullptr if the type has no name. @@ -614,7 +612,7 @@ const std::string* getName(TypeId type); std::optional getDefinitionModuleName(TypeId type); // Checks whether a union contains all types of another union. -bool isSubset(const UnionTypeVar& super, const UnionTypeVar& sub); +bool isSubset(const UnionType& super, const UnionType& sub); // Checks if a type contains generic type binders bool isGeneric(const TypeId ty); @@ -628,12 +626,12 @@ bool maybeSingleton(TypeId ty); // Checks if the length operator can be applied on the value of type bool hasLength(TypeId ty, DenseHashSet& seen, int* recursionCount); -struct SingletonTypes +struct BuiltinTypes { - SingletonTypes(); - ~SingletonTypes(); - SingletonTypes(const SingletonTypes&) = delete; - void operator=(const SingletonTypes&) = delete; + BuiltinTypes(); + ~BuiltinTypes(); + BuiltinTypes(const BuiltinTypes&) = delete; + void operator=(const BuiltinTypes&) = delete; TypeId errorRecoveryType(TypeId guess); TypePackId errorRecoveryTypePack(TypePackId guess); @@ -653,6 +651,7 @@ struct SingletonTypes const TypeId booleanType; const TypeId threadType; const TypeId functionType; + const TypeId classType; const TypeId trueType; const TypeId falseType; const TypeId anyType; @@ -676,18 +675,18 @@ TypeLevel* getMutableLevel(TypeId ty); std::optional getLevel(TypePackId tp); -const Property* lookupClassProp(const ClassTypeVar* cls, const Name& name); -bool isSubclass(const ClassTypeVar* cls, const ClassTypeVar* parent); +const Property* lookupClassProp(const ClassType* cls, const Name& name); +bool isSubclass(const ClassType* cls, const ClassType* parent); -TypeVar* asMutable(TypeId ty); +Type* asMutable(TypeId ty); template const T* get(TypeId tv) { LUAU_ASSERT(tv); - if constexpr (!std::is_same_v) - LUAU_ASSERT(get_if(&tv->ty) == nullptr); + if constexpr (!std::is_same_v) + LUAU_ASSERT(get_if(&tv->ty) == nullptr); return get_if(&tv->ty); } @@ -697,25 +696,25 @@ T* getMutable(TypeId tv) { LUAU_ASSERT(tv); - if constexpr (!std::is_same_v) - LUAU_ASSERT(get_if(&tv->ty) == nullptr); + if constexpr (!std::is_same_v) + LUAU_ASSERT(get_if(&tv->ty) == nullptr); return get_if(&asMutable(tv)->ty); } -const std::vector& getTypes(const UnionTypeVar* utv); -const std::vector& getTypes(const IntersectionTypeVar* itv); +const std::vector& getTypes(const UnionType* utv); +const std::vector& getTypes(const IntersectionType* itv); template struct TypeIterator; -using UnionTypeVarIterator = TypeIterator; -UnionTypeVarIterator begin(const UnionTypeVar* utv); -UnionTypeVarIterator end(const UnionTypeVar* utv); +using UnionTypeIterator = TypeIterator; +UnionTypeIterator begin(const UnionType* utv); +UnionTypeIterator end(const UnionType* utv); -using IntersectionTypeVarIterator = TypeIterator; -IntersectionTypeVarIterator begin(const IntersectionTypeVar* itv); -IntersectionTypeVarIterator end(const IntersectionTypeVar* itv); +using IntersectionTypeIterator = TypeIterator; +IntersectionTypeIterator begin(const IntersectionType* itv); +IntersectionTypeIterator end(const IntersectionType* itv); /* Traverses the type T yielding each TypeId. * If the iterator encounters a nested type T, it will instead yield each TypeId within. @@ -788,8 +787,8 @@ struct TypeIterator // Normally, we'd have `begin` and `end` be a template but there's too much trouble // with templates portability in this area, so not worth it. Thanks MSVC. - friend UnionTypeVarIterator end(const UnionTypeVar*); - friend IntersectionTypeVarIterator end(const IntersectionTypeVar*); + friend UnionTypeIterator end(const UnionType*); + friend IntersectionTypeIterator end(const IntersectionType*); private: TypeIterator() = default; diff --git a/Analysis/include/Luau/TypeArena.h b/Analysis/include/Luau/TypeArena.h index c67f643bc..0e69bb4aa 100644 --- a/Analysis/include/Luau/TypeArena.h +++ b/Analysis/include/Luau/TypeArena.h @@ -2,7 +2,7 @@ #pragma once #include "Luau/TypedAllocator.h" -#include "Luau/TypeVar.h" +#include "Luau/Type.h" #include "Luau/TypePack.h" #include @@ -12,7 +12,7 @@ namespace Luau struct TypeArena { - TypedAllocator typeVars; + TypedAllocator types; TypedAllocator typePacks; void clear(); @@ -20,13 +20,13 @@ struct TypeArena template TypeId addType(T tv) { - if constexpr (std::is_same_v) + if constexpr (std::is_same_v) LUAU_ASSERT(tv.options.size() >= 2); - return addTV(TypeVar(std::move(tv))); + return addTV(Type(std::move(tv))); } - TypeId addTV(TypeVar&& tv); + TypeId addTV(Type&& tv); TypeId freshType(TypeLevel level); TypeId freshType(Scope* scope); diff --git a/Analysis/include/Luau/TypeChecker2.h b/Analysis/include/Luau/TypeChecker2.h index a9cd6ec8c..6045aecff 100644 --- a/Analysis/include/Luau/TypeChecker2.h +++ b/Analysis/include/Luau/TypeChecker2.h @@ -10,8 +10,8 @@ namespace Luau { struct DcrLogger; -struct SingletonTypes; +struct BuiltinTypes; -void check(NotNull singletonTypes, DcrLogger* logger, const SourceModule& sourceModule, Module* module); +void check(NotNull builtinTypes, DcrLogger* logger, const SourceModule& sourceModule, Module* module); } // namespace Luau diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index c6f153d1d..4c2d38ad1 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -9,7 +9,7 @@ #include "Luau/Substitution.h" #include "Luau/TxnLog.h" #include "Luau/TypePack.h" -#include "Luau/TypeVar.h" +#include "Luau/Type.h" #include "Luau/Unifier.h" #include "Luau/UnifierSharedState.h" @@ -28,7 +28,7 @@ struct ModuleResolver; using Name = std::string; using ScopePtr = std::shared_ptr; -using OverloadErrorEntry = std::tuple, std::vector, const FunctionTypeVar*>; +using OverloadErrorEntry = std::tuple, std::vector, const FunctionType*>; bool doesCallError(const AstExprCall* call); bool hasBreak(AstStat* node); @@ -57,11 +57,11 @@ class TimeLimitError : public InternalCompilerError } }; -// All TypeVars are retained via Environment::typeVars. All TypeIds +// All Types are retained via Environment::types. All TypeIds // within a program are borrowed pointers into this set. struct TypeChecker { - explicit TypeChecker(ModuleResolver* resolver, NotNull singletonTypes, InternalErrorReporter* iceHandler); + explicit TypeChecker(ModuleResolver* resolver, NotNull builtinTypes, InternalErrorReporter* iceHandler); TypeChecker(const TypeChecker&) = delete; TypeChecker& operator=(const TypeChecker&) = delete; @@ -163,7 +163,7 @@ struct TypeChecker // Reports an error if the type is already some kind of non-table. void tablify(TypeId type); - /** In nonstrict mode, many typevars need to be replaced by any. + /** In nonstrict mode, many types need to be replaced by any. */ TypeId anyIfNonstrict(TypeId ty) const; @@ -287,14 +287,14 @@ struct TypeChecker TypeId unionOfTypes(TypeId a, TypeId b, const ScopePtr& scope, const Location& location, bool unifyFreeTypes = true); // ex - // TypeId id = addType(FreeTypeVar()); + // TypeId id = addType(FreeType()); template TypeId addType(const T& tv) { - return addTV(TypeVar(tv)); + return addTV(Type(tv)); } - TypeId addTV(TypeVar&& tv); + TypeId addTV(Type&& tv); TypePackId addTypePack(TypePackVar&& tp); TypePackId addTypePack(TypePack&& tp); @@ -343,7 +343,7 @@ struct TypeChecker * Calling this function means submitting evidence that the pack must have the length provided. * If the pack is known not to have the correct length, an error will be reported. * The return vector is always of the exact requested length. In the event that the pack's length does - * not match up, excess TypeIds will be ErrorTypeVars. + * not match up, excess TypeIds will be ErrorTypes. */ std::vector unTypePack(const ScopePtr& scope, TypePackId pack, size_t expectedLength, const Location& location); @@ -356,7 +356,7 @@ struct TypeChecker ModuleName currentModuleName; std::function prepareModuleScope; - NotNull singletonTypes; + NotNull builtinTypes; InternalErrorReporter* iceHandler; UnifierSharedState unifierState; diff --git a/Analysis/include/Luau/TypePack.h b/Analysis/include/Luau/TypePack.h index 296880942..4831f2338 100644 --- a/Analysis/include/Luau/TypePack.h +++ b/Analysis/include/Luau/TypePack.h @@ -1,7 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once -#include "Luau/TypeVar.h" +#include "Luau/Type.h" #include "Luau/Unifiable.h" #include "Luau/Variant.h" @@ -45,7 +45,7 @@ struct VariadicTypePack }; /** - * Analogous to a BlockedTypeVar. + * Analogous to a BlockedType. */ struct BlockedTypePack { @@ -83,7 +83,7 @@ struct TypePackVar /* Walk the set of TypeIds in a TypePack. * - * Like TypeVars, individual TypePacks can be free, generic, or any. + * Like Types, individual TypePacks can be free, generic, or any. * * We afford the ability to work with these kinds of packs by giving the * iterator a .tail() property that yields the tail-most TypePack in the diff --git a/Analysis/include/Luau/TypeUtils.h b/Analysis/include/Luau/TypeUtils.h index 6ed70f468..3f535a03f 100644 --- a/Analysis/include/Luau/TypeUtils.h +++ b/Analysis/include/Luau/TypeUtils.h @@ -3,7 +3,7 @@ #include "Luau/Error.h" #include "Luau/Location.h" -#include "Luau/TypeVar.h" +#include "Luau/Type.h" #include "Luau/TypePack.h" #include @@ -18,16 +18,16 @@ struct TypeArena; using ScopePtr = std::shared_ptr; std::optional findMetatableEntry( - NotNull singletonTypes, ErrorVec& errors, TypeId type, const std::string& entry, Location location); + NotNull builtinTypes, ErrorVec& errors, TypeId type, const std::string& entry, Location location); std::optional findTablePropertyRespectingMeta( - NotNull singletonTypes, ErrorVec& errors, TypeId ty, const std::string& name, Location location); + NotNull builtinTypes, ErrorVec& errors, TypeId ty, const std::string& name, Location location); // Returns the minimum and maximum number of types the argument list can accept. std::pair> getParameterExtents(const TxnLog* log, TypePackId tp, bool includeHiddenVariadics = false); // Extend the provided pack to at least `length` types. // Returns a temporary TypePack that contains those types plus a tail. -TypePack extendTypePack(TypeArena& arena, NotNull singletonTypes, TypePackId pack, size_t length); +TypePack extendTypePack(TypeArena& arena, NotNull builtinTypes, TypePackId pack, size_t length); /** * Reduces a union by decomposing to the any/error type if it appears in the @@ -41,11 +41,11 @@ std::vector reduceUnion(const std::vector& types); /** * Tries to remove nil from a union type, if there's another option. T | nil * reduces to T, but nil itself does not reduce. - * @param singletonTypes the singleton types to use + * @param builtinTypes the singleton types to use * @param arena the type arena to allocate the new type in, if necessary * @param ty the type to remove nil from * @returns a type with nil removed, or nil itself if that were the only option. */ -TypeId stripNil(NotNull singletonTypes, TypeArena& arena, TypeId ty); +TypeId stripNil(NotNull builtinTypes, TypeArena& arena, TypeId ty); } // namespace Luau diff --git a/Analysis/include/Luau/Unifiable.h b/Analysis/include/Luau/Unifiable.h index c43daa21a..15e501f02 100644 --- a/Analysis/include/Luau/Unifiable.h +++ b/Analysis/include/Luau/Unifiable.h @@ -11,7 +11,7 @@ namespace Luau struct Scope; /** - * The 'level' of a TypeVar is an indirect way to talk about the scope that it 'belongs' too. + * The 'level' of a Type is an indirect way to talk about the scope that it 'belongs' too. * To start, read http://okmij.org/ftp/ML/generalization.html * * We extend the idea by adding a "sub-level" which helps us to differentiate sibling scopes @@ -132,7 +132,7 @@ struct Generic struct Error { - // This constructor has to be public, since it's used in TypeVar and TypePack, + // This constructor has to be public, since it's used in Type and TypePack, // but shouldn't be called directly. Please use errorRecoveryType() instead. Error(); diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index af3864ea8..cd3e856d9 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -25,13 +25,13 @@ enum Variance // A substitution which replaces singleton types by their wider types struct Widen : Substitution { - Widen(TypeArena* arena, NotNull singletonTypes) + Widen(TypeArena* arena, NotNull builtinTypes) : Substitution(TxnLog::empty(), arena) - , singletonTypes(singletonTypes) + , builtinTypes(builtinTypes) { } - NotNull singletonTypes; + NotNull builtinTypes; bool isDirty(TypeId ty) override; bool isDirty(TypePackId ty) override; @@ -52,7 +52,7 @@ struct UnifierOptions struct Unifier { TypeArena* const types; - NotNull singletonTypes; + NotNull builtinTypes; NotNull normalizer; Mode mode; @@ -82,10 +82,10 @@ struct Unifier private: void tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall = false, bool isIntersection = false); - void tryUnifyUnionWithType(TypeId subTy, const UnionTypeVar* uv, TypeId superTy); - void tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTypeVar* uv, bool cacheEnabled, bool isFunctionCall); - void tryUnifyTypeWithIntersection(TypeId subTy, TypeId superTy, const IntersectionTypeVar* uv); - void tryUnifyIntersectionWithType(TypeId subTy, const IntersectionTypeVar* uv, TypeId superTy, bool cacheEnabled, bool isFunctionCall); + void tryUnifyUnionWithType(TypeId subTy, const UnionType* uv, TypeId superTy); + void tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionType* uv, bool cacheEnabled, bool isFunctionCall); + void tryUnifyTypeWithIntersection(TypeId subTy, TypeId superTy, const IntersectionType* uv); + void tryUnifyIntersectionWithType(TypeId subTy, const IntersectionType* uv, TypeId superTy, bool cacheEnabled, bool isFunctionCall); void tryUnifyNormalizedTypes(TypeId subTy, TypeId superTy, const NormalizedType& subNorm, const NormalizedType& superNorm, std::string reason, std::optional error = std::nullopt); void tryUnifyPrimitives(TypeId subTy, TypeId superTy); diff --git a/Analysis/include/Luau/UnifierSharedState.h b/Analysis/include/Luau/UnifierSharedState.h index d4315d471..ada56ec56 100644 --- a/Analysis/include/Luau/UnifierSharedState.h +++ b/Analysis/include/Luau/UnifierSharedState.h @@ -3,7 +3,7 @@ #include "Luau/DenseHash.h" #include "Luau/Error.h" -#include "Luau/TypeVar.h" +#include "Luau/Type.h" #include "Luau/TypePack.h" #include diff --git a/Analysis/include/Luau/VisitTypeVar.h b/Analysis/include/Luau/VisitType.h similarity index 70% rename from Analysis/include/Luau/VisitTypeVar.h rename to Analysis/include/Luau/VisitType.h index 3dcddba19..fdac65856 100644 --- a/Analysis/include/Luau/VisitTypeVar.h +++ b/Analysis/include/Luau/VisitType.h @@ -6,7 +6,7 @@ #include "Luau/DenseHash.h" #include "Luau/RecursionCounter.h" #include "Luau/TypePack.h" -#include "Luau/TypeVar.h" +#include "Luau/Type.h" LUAU_FASTINT(LuauVisitRecursionLimit) LUAU_FASTFLAG(LuauCompleteVisitor); @@ -19,7 +19,7 @@ namespace visit_detail /** * Apply f(tid, t, seen) if doing so would pass type checking, else apply f(tid, t) * - * We do this to permit (but not require) TypeVar visitors to accept the seen set as an argument. + * We do this to permit (but not require) Type visitors to accept the seen set as an argument. */ template auto apply(A tid, const B& t, C& c, F& f) -> decltype(f(tid, t, c)) @@ -58,13 +58,13 @@ inline void unsee(std::unordered_set& seen, const void* tv) inline void unsee(DenseHashSet& seen, const void* tv) { - // When DenseHashSet is used for 'visitTypeVarOnce', where don't forget visited elements + // When DenseHashSet is used for 'visitTypeOnce', where don't forget visited elements } } // namespace visit_detail template -struct GenericTypeVarVisitor +struct GenericTypeVisitor { using Set = S; @@ -72,9 +72,9 @@ struct GenericTypeVarVisitor bool skipBoundTypes = false; int recursionCounter = 0; - GenericTypeVarVisitor() = default; + GenericTypeVisitor() = default; - explicit GenericTypeVarVisitor(Set seen, bool skipBoundTypes = false) + explicit GenericTypeVisitor(Set seen, bool skipBoundTypes = false) : seen(std::move(seen)) , skipBoundTypes(skipBoundTypes) { @@ -87,75 +87,75 @@ struct GenericTypeVarVisitor { return true; } - virtual bool visit(TypeId ty, const BoundTypeVar& btv) + virtual bool visit(TypeId ty, const BoundType& btv) { return visit(ty); } - virtual bool visit(TypeId ty, const FreeTypeVar& ftv) + virtual bool visit(TypeId ty, const FreeType& ftv) { return visit(ty); } - virtual bool visit(TypeId ty, const GenericTypeVar& gtv) + virtual bool visit(TypeId ty, const GenericType& gtv) { return visit(ty); } - virtual bool visit(TypeId ty, const ErrorTypeVar& etv) + virtual bool visit(TypeId ty, const ErrorType& etv) { return visit(ty); } - virtual bool visit(TypeId ty, const PrimitiveTypeVar& ptv) + virtual bool visit(TypeId ty, const PrimitiveType& ptv) { return visit(ty); } - virtual bool visit(TypeId ty, const FunctionTypeVar& ftv) + virtual bool visit(TypeId ty, const FunctionType& ftv) { return visit(ty); } - virtual bool visit(TypeId ty, const TableTypeVar& ttv) + virtual bool visit(TypeId ty, const TableType& ttv) { return visit(ty); } - virtual bool visit(TypeId ty, const MetatableTypeVar& mtv) + virtual bool visit(TypeId ty, const MetatableType& mtv) { return visit(ty); } - virtual bool visit(TypeId ty, const ClassTypeVar& ctv) + virtual bool visit(TypeId ty, const ClassType& ctv) { return visit(ty); } - virtual bool visit(TypeId ty, const AnyTypeVar& atv) + virtual bool visit(TypeId ty, const AnyType& atv) { return visit(ty); } - virtual bool visit(TypeId ty, const UnknownTypeVar& utv) + virtual bool visit(TypeId ty, const UnknownType& utv) { return visit(ty); } - virtual bool visit(TypeId ty, const NeverTypeVar& ntv) + virtual bool visit(TypeId ty, const NeverType& ntv) { return visit(ty); } - virtual bool visit(TypeId ty, const UnionTypeVar& utv) + virtual bool visit(TypeId ty, const UnionType& utv) { return visit(ty); } - virtual bool visit(TypeId ty, const IntersectionTypeVar& itv) + virtual bool visit(TypeId ty, const IntersectionType& itv) { return visit(ty); } - virtual bool visit(TypeId ty, const BlockedTypeVar& btv) + virtual bool visit(TypeId ty, const BlockedType& btv) { return visit(ty); } - virtual bool visit(TypeId ty, const PendingExpansionTypeVar& petv) + virtual bool visit(TypeId ty, const PendingExpansionType& petv) { return visit(ty); } - virtual bool visit(TypeId ty, const SingletonTypeVar& stv) + virtual bool visit(TypeId ty, const SingletonType& stv) { return visit(ty); } - virtual bool visit(TypeId ty, const NegationTypeVar& ntv) + virtual bool visit(TypeId ty, const NegationType& ntv) { return visit(ty); } @@ -203,22 +203,22 @@ struct GenericTypeVarVisitor return; } - if (auto btv = get(ty)) + if (auto btv = get(ty)) { if (skipBoundTypes) traverse(btv->boundTo); else if (visit(ty, *btv)) traverse(btv->boundTo); } - else if (auto ftv = get(ty)) + else if (auto ftv = get(ty)) visit(ty, *ftv); - else if (auto gtv = get(ty)) + else if (auto gtv = get(ty)) visit(ty, *gtv); - else if (auto etv = get(ty)) + else if (auto etv = get(ty)) visit(ty, *etv); - else if (auto ptv = get(ty)) + else if (auto ptv = get(ty)) visit(ty, *ptv); - else if (auto ftv = get(ty)) + else if (auto ftv = get(ty)) { if (visit(ty, *ftv)) { @@ -226,7 +226,7 @@ struct GenericTypeVarVisitor traverse(ftv->retTypes); } } - else if (auto ttv = get(ty)) + else if (auto ttv = get(ty)) { // Some visitors want to see bound tables, that's why we traverse the original type if (skipBoundTypes && ttv->boundTo) @@ -252,7 +252,7 @@ struct GenericTypeVarVisitor } } } - else if (auto mtv = get(ty)) + else if (auto mtv = get(ty)) { if (visit(ty, *mtv)) { @@ -260,7 +260,7 @@ struct GenericTypeVarVisitor traverse(mtv->metatable); } } - else if (auto ctv = get(ty)) + else if (auto ctv = get(ty)) { if (visit(ty, *ctv)) { @@ -274,9 +274,9 @@ struct GenericTypeVarVisitor traverse(*ctv->metatable); } } - else if (auto atv = get(ty)) + else if (auto atv = get(ty)) visit(ty, *atv); - else if (auto utv = get(ty)) + else if (auto utv = get(ty)) { if (visit(ty, *utv)) { @@ -284,7 +284,7 @@ struct GenericTypeVarVisitor traverse(optTy); } } - else if (auto itv = get(ty)) + else if (auto itv = get(ty)) { if (visit(ty, *itv)) { @@ -292,21 +292,21 @@ struct GenericTypeVarVisitor traverse(partTy); } } - else if (get(ty)) + else if (get(ty)) { - // Visiting into LazyTypeVar may necessarily cause infinite expansion, so we don't do that on purpose. - // Asserting also makes no sense, because the type _will_ happen here, most likely as a property of some ClassTypeVar + // Visiting into LazyType may necessarily cause infinite expansion, so we don't do that on purpose. + // Asserting also makes no sense, because the type _will_ happen here, most likely as a property of some ClassType // that doesn't need to be expanded. } - else if (auto stv = get(ty)) + else if (auto stv = get(ty)) visit(ty, *stv); - else if (auto btv = get(ty)) + else if (auto btv = get(ty)) visit(ty, *btv); - else if (auto utv = get(ty)) + else if (auto utv = get(ty)) visit(ty, *utv); - else if (auto ntv = get(ty)) + else if (auto ntv = get(ty)) visit(ty, *ntv); - else if (auto petv = get(ty)) + else if (auto petv = get(ty)) { if (visit(ty, *petv)) { @@ -317,12 +317,12 @@ struct GenericTypeVarVisitor traverse(a); } } - else if (auto ntv = get(ty)) + else if (auto ntv = get(ty)) visit(ty, *ntv); else if (!FFlag::LuauCompleteVisitor) return visit_detail::unsee(seen, ty); else - LUAU_ASSERT(!"GenericTypeVarVisitor::traverse(TypeId) is not exhaustive!"); + LUAU_ASSERT(!"GenericTypeVisitor::traverse(TypeId) is not exhaustive!"); visit_detail::unsee(seen, ty); } @@ -372,7 +372,7 @@ struct GenericTypeVarVisitor visit(tp, *btp); else - LUAU_ASSERT(!"GenericTypeVarVisitor::traverse(TypePackId) is not exhaustive!"); + LUAU_ASSERT(!"GenericTypeVisitor::traverse(TypePackId) is not exhaustive!"); visit_detail::unsee(seen, tp); } @@ -381,21 +381,21 @@ struct GenericTypeVarVisitor /** Visit each type under a given type. Skips over cycles and keeps recursion depth under control. * * The same type may be visited multiple times if there are multiple distinct paths to it. If this is undesirable, use - * TypeVarOnceVisitor. + * TypeOnceVisitor. */ -struct TypeVarVisitor : GenericTypeVarVisitor> +struct TypeVisitor : GenericTypeVisitor> { - explicit TypeVarVisitor(bool skipBoundTypes = false) - : GenericTypeVarVisitor{{}, skipBoundTypes} + explicit TypeVisitor(bool skipBoundTypes = false) + : GenericTypeVisitor{{}, skipBoundTypes} { } }; /// Visit each type under a given type. Each type will only be checked once even if there are multiple paths to it. -struct TypeVarOnceVisitor : GenericTypeVarVisitor> +struct TypeOnceVisitor : GenericTypeVisitor> { - explicit TypeVarOnceVisitor(bool skipBoundTypes = false) - : GenericTypeVarVisitor{DenseHashSet{nullptr}, skipBoundTypes} + explicit TypeOnceVisitor(bool skipBoundTypes = false) + : GenericTypeVisitor{DenseHashSet{nullptr}, skipBoundTypes} { } }; diff --git a/Analysis/src/Anyification.cpp b/Analysis/src/Anyification.cpp index 5dd761c25..e0ddeacf2 100644 --- a/Analysis/src/Anyification.cpp +++ b/Analysis/src/Anyification.cpp @@ -11,20 +11,20 @@ LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution) namespace Luau { -Anyification::Anyification(TypeArena* arena, NotNull scope, NotNull singletonTypes, InternalErrorReporter* iceHandler, +Anyification::Anyification(TypeArena* arena, NotNull scope, NotNull builtinTypes, InternalErrorReporter* iceHandler, TypeId anyType, TypePackId anyTypePack) : Substitution(TxnLog::empty(), arena) , scope(scope) - , singletonTypes(singletonTypes) + , builtinTypes(builtinTypes) , iceHandler(iceHandler) , anyType(anyType) , anyTypePack(anyTypePack) { } -Anyification::Anyification(TypeArena* arena, const ScopePtr& scope, NotNull singletonTypes, InternalErrorReporter* iceHandler, +Anyification::Anyification(TypeArena* arena, const ScopePtr& scope, NotNull builtinTypes, InternalErrorReporter* iceHandler, TypeId anyType, TypePackId anyTypePack) - : Anyification(arena, NotNull{scope.get()}, singletonTypes, iceHandler, anyType, anyTypePack) + : Anyification(arena, NotNull{scope.get()}, builtinTypes, iceHandler, anyType, anyTypePack) { } @@ -33,9 +33,9 @@ bool Anyification::isDirty(TypeId ty) if (ty->persistent) return false; - if (const TableTypeVar* ttv = log->getMutable(ty)) + if (const TableType* ttv = log->getMutable(ty)) return (ttv->state == TableState::Free || ttv->state == TableState::Unsealed); - else if (log->getMutable(ty)) + else if (log->getMutable(ty)) return true; else return false; @@ -55,9 +55,9 @@ bool Anyification::isDirty(TypePackId tp) TypeId Anyification::clean(TypeId ty) { LUAU_ASSERT(isDirty(ty)); - if (const TableTypeVar* ttv = log->getMutable(ty)) + if (const TableType* ttv = log->getMutable(ty)) { - TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, ttv->level, TableState::Sealed}; + TableType clone = TableType{ttv->props, ttv->indexer, ttv->level, TableState::Sealed}; clone.definitionModuleName = ttv->definitionModuleName; clone.name = ttv->name; clone.syntheticName = ttv->syntheticName; @@ -77,7 +77,7 @@ TypePackId Anyification::clean(TypePackId tp) bool Anyification::ignoreChildren(TypeId ty) { - if (FFlag::LuauClassTypeVarsInSubstitution && get(ty)) + if (FFlag::LuauClassTypeVarsInSubstitution && get(ty)) return true; return ty->persistent; diff --git a/Analysis/src/ApplyTypeFunction.cpp b/Analysis/src/ApplyTypeFunction.cpp index b293ed3d6..fe8cc8ac3 100644 --- a/Analysis/src/ApplyTypeFunction.cpp +++ b/Analysis/src/ApplyTypeFunction.cpp @@ -11,7 +11,7 @@ bool ApplyTypeFunction::isDirty(TypeId ty) { if (typeArguments.count(ty)) return true; - else if (const FreeTypeVar* ftv = get(ty)) + else if (const FreeType* ftv = get(ty)) { if (ftv->forwardedTypeAlias) encounteredForwardedType = true; @@ -31,9 +31,9 @@ bool ApplyTypeFunction::isDirty(TypePackId tp) bool ApplyTypeFunction::ignoreChildren(TypeId ty) { - if (get(ty)) + if (get(ty)) return true; - else if (FFlag::LuauClassTypeVarsInSubstitution && get(ty)) + else if (FFlag::LuauClassTypeVarsInSubstitution && get(ty)) return true; else return false; diff --git a/Analysis/src/AstQuery.cpp b/Analysis/src/AstQuery.cpp index e6e7f3d93..39f613e55 100644 --- a/Analysis/src/AstQuery.cpp +++ b/Analysis/src/AstQuery.cpp @@ -4,7 +4,7 @@ #include "Luau/Module.h" #include "Luau/Scope.h" #include "Luau/TypeInfer.h" -#include "Luau/TypeVar.h" +#include "Luau/Type.h" #include "Luau/ToString.h" #include "Luau/Common.h" @@ -447,7 +447,7 @@ static std::optional checkOverloadedDocumentationSymbol( return std::nullopt; // This might be an overloaded function. - if (get(follow(ty))) + if (get(follow(ty))) { TypeId matchingOverload = nullptr; if (parentExpr && parentExpr->is()) @@ -487,12 +487,12 @@ std::optional getDocumentationSymbolAtPosition(const Source if (auto it = module.astTypes.find(indexName->expr)) { TypeId parentTy = follow(*it); - if (const TableTypeVar* ttv = get(parentTy)) + if (const TableType* ttv = get(parentTy)) { if (auto propIt = ttv->props.find(indexName->index.value); propIt != ttv->props.end()) return checkOverloadedDocumentationSymbol(module, propIt->second.type, parentExpr, propIt->second.documentationSymbol); } - else if (const ClassTypeVar* ctv = get(parentTy)) + else if (const ClassType* ctv = get(parentTy)) { if (auto propIt = ctv->props.find(indexName->index.value); propIt != ctv->props.end()) return checkOverloadedDocumentationSymbol(module, propIt->second.type, parentExpr, propIt->second.documentationSymbol); diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index 83a6f0217..7a649546d 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -43,7 +43,7 @@ static bool alreadyHasParens(const std::vector& nodes) return false; } -static ParenthesesRecommendation getParenRecommendationForFunc(const FunctionTypeVar* func, const std::vector& nodes) +static ParenthesesRecommendation getParenRecommendationForFunc(const FunctionType* func, const std::vector& nodes) { if (alreadyHasParens(nodes)) { @@ -61,12 +61,12 @@ static ParenthesesRecommendation getParenRecommendationForFunc(const FunctionTyp return noArgFunction ? ParenthesesRecommendation::CursorAfter : ParenthesesRecommendation::CursorInside; } -static ParenthesesRecommendation getParenRecommendationForIntersect(const IntersectionTypeVar* intersect, const std::vector& nodes) +static ParenthesesRecommendation getParenRecommendationForIntersect(const IntersectionType* intersect, const std::vector& nodes) { ParenthesesRecommendation rec = ParenthesesRecommendation::None; for (Luau::TypeId partId : intersect->parts) { - if (auto partFunc = Luau::get(partId)) + if (auto partFunc = Luau::get(partId)) { rec = std::max(rec, getParenRecommendationForFunc(partFunc, nodes)); } @@ -85,11 +85,11 @@ static ParenthesesRecommendation getParenRecommendation(TypeId id, const std::ve return ParenthesesRecommendation::None; id = Luau::follow(id); - if (auto func = get(id)) + if (auto func = get(id)) { return getParenRecommendationForFunc(func, nodes); } - else if (auto intersect = get(id)) + else if (auto intersect = get(id)) { return getParenRecommendationForIntersect(intersect, nodes); } @@ -113,7 +113,7 @@ static std::optional findExpectedTypeAt(const Module& module, AstNode* n if (!it) return std::nullopt; - const FunctionTypeVar* ftv = get(follow(*it)); + const FunctionType* ftv = get(follow(*it)); if (!ftv) return std::nullopt; @@ -135,18 +135,18 @@ static std::optional findExpectedTypeAt(const Module& module, AstNode* n return *it; } -static bool checkTypeMatch(TypeId subTy, TypeId superTy, NotNull scope, TypeArena* typeArena, NotNull singletonTypes) +static bool checkTypeMatch(TypeId subTy, TypeId superTy, NotNull scope, TypeArena* typeArena, NotNull builtinTypes) { InternalErrorReporter iceReporter; UnifierSharedState unifierState(&iceReporter); - Normalizer normalizer{typeArena, singletonTypes, NotNull{&unifierState}}; + Normalizer normalizer{typeArena, builtinTypes, NotNull{&unifierState}}; Unifier unifier(NotNull{&normalizer}, Mode::Strict, scope, Location(), Variance::Covariant); return unifier.canUnify(subTy, superTy).empty(); } static TypeCorrectKind checkTypeCorrectKind( - const Module& module, TypeArena* typeArena, NotNull singletonTypes, AstNode* node, Position position, TypeId ty) + const Module& module, TypeArena* typeArena, NotNull builtinTypes, AstNode* node, Position position, TypeId ty) { ty = follow(ty); @@ -159,31 +159,31 @@ static TypeCorrectKind checkTypeCorrectKind( TypeId expectedType = follow(*typeAtPosition); - auto checkFunctionType = [typeArena, singletonTypes, moduleScope, &expectedType](const FunctionTypeVar* ftv) { + auto checkFunctionType = [typeArena, builtinTypes, moduleScope, &expectedType](const FunctionType* ftv) { if (std::optional firstRetTy = first(ftv->retTypes)) - return checkTypeMatch(*firstRetTy, expectedType, moduleScope, typeArena, singletonTypes); + return checkTypeMatch(*firstRetTy, expectedType, moduleScope, typeArena, builtinTypes); return false; }; // We also want to suggest functions that return compatible result - if (const FunctionTypeVar* ftv = get(ty); ftv && checkFunctionType(ftv)) + if (const FunctionType* ftv = get(ty); ftv && checkFunctionType(ftv)) { return TypeCorrectKind::CorrectFunctionResult; } - else if (const IntersectionTypeVar* itv = get(ty)) + else if (const IntersectionType* itv = get(ty)) { for (TypeId id : itv->parts) { - if (const FunctionTypeVar* ftv = get(id); ftv && checkFunctionType(ftv)) + if (const FunctionType* ftv = get(id); ftv && checkFunctionType(ftv)) { return TypeCorrectKind::CorrectFunctionResult; } } } - return checkTypeMatch(ty, expectedType, NotNull{module.getModuleScope().get()}, typeArena, singletonTypes) ? TypeCorrectKind::Correct - : TypeCorrectKind::None; + return checkTypeMatch(ty, expectedType, NotNull{module.getModuleScope().get()}, typeArena, builtinTypes) ? TypeCorrectKind::Correct + : TypeCorrectKind::None; } enum class PropIndexType @@ -193,9 +193,9 @@ enum class PropIndexType Key, }; -static void autocompleteProps(const Module& module, TypeArena* typeArena, NotNull singletonTypes, TypeId rootTy, TypeId ty, +static void autocompleteProps(const Module& module, TypeArena* typeArena, NotNull builtinTypes, TypeId rootTy, TypeId ty, PropIndexType indexType, const std::vector& nodes, AutocompleteEntryMap& result, std::unordered_set& seen, - std::optional containingClass = std::nullopt) + std::optional containingClass = std::nullopt) { rootTy = follow(rootTy); ty = follow(ty); @@ -204,41 +204,41 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, NotNul return; seen.insert(ty); - auto isWrongIndexer = [typeArena, singletonTypes, &module, rootTy, indexType](Luau::TypeId type) { + auto isWrongIndexer = [typeArena, builtinTypes, &module, rootTy, indexType](Luau::TypeId type) { if (indexType == PropIndexType::Key) return false; bool calledWithSelf = indexType == PropIndexType::Colon; - auto isCompatibleCall = [typeArena, singletonTypes, &module, rootTy, calledWithSelf](const FunctionTypeVar* ftv) { + auto isCompatibleCall = [typeArena, builtinTypes, &module, rootTy, calledWithSelf](const FunctionType* ftv) { // Strong match with definition is a success if (calledWithSelf == ftv->hasSelf) return true; // Calls on classes require strict match between how function is declared and how it's called - if (get(rootTy)) + if (get(rootTy)) return false; // When called with ':', but declared without 'self', it is invalid if a function has incompatible first argument or no arguments at all // When called with '.', but declared with 'self', it is considered invalid if first argument is compatible if (std::optional firstArgTy = first(ftv->argTypes)) { - if (checkTypeMatch(rootTy, *firstArgTy, NotNull{module.getModuleScope().get()}, typeArena, singletonTypes)) + if (checkTypeMatch(rootTy, *firstArgTy, NotNull{module.getModuleScope().get()}, typeArena, builtinTypes)) return calledWithSelf; } return !calledWithSelf; }; - if (const FunctionTypeVar* ftv = get(type)) + if (const FunctionType* ftv = get(type)) return !isCompatibleCall(ftv); // For intersections, any part that is successful makes the whole call successful - if (const IntersectionTypeVar* itv = get(type)) + if (const IntersectionType* itv = get(type)) { for (auto subType : itv->parts) { - if (const FunctionTypeVar* ftv = get(Luau::follow(subType))) + if (const FunctionType* ftv = get(Luau::follow(subType))) { if (isCompatibleCall(ftv)) return false; @@ -249,7 +249,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, NotNul return calledWithSelf; }; - auto fillProps = [&](const ClassTypeVar::Props& props) { + auto fillProps = [&](const ClassType::Props& props) { for (const auto& [name, prop] : props) { // We are walking up the class hierarchy, so if we encounter a property that we have @@ -259,7 +259,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, NotNul Luau::TypeId type = Luau::follow(prop.type); TypeCorrectKind typeCorrect = indexType == PropIndexType::Key ? TypeCorrectKind::Correct - : checkTypeCorrectKind(module, typeArena, singletonTypes, nodes.back(), {{}, {}}, type); + : checkTypeCorrectKind(module, typeArena, builtinTypes, nodes.back(), {{}, {}}, type); ParenthesesRecommendation parens = indexType == PropIndexType::Key ? ParenthesesRecommendation::None : getParenRecommendation(type, nodes, typeCorrect); @@ -279,41 +279,41 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, NotNul } }; - auto fillMetatableProps = [&](const TableTypeVar* mtable) { + auto fillMetatableProps = [&](const TableType* mtable) { auto indexIt = mtable->props.find("__index"); if (indexIt != mtable->props.end()) { TypeId followed = follow(indexIt->second.type); - if (get(followed) || get(followed)) + if (get(followed) || get(followed)) { - autocompleteProps(module, typeArena, singletonTypes, rootTy, followed, indexType, nodes, result, seen); + autocompleteProps(module, typeArena, builtinTypes, rootTy, followed, indexType, nodes, result, seen); } - else if (auto indexFunction = get(followed)) + else if (auto indexFunction = get(followed)) { std::optional indexFunctionResult = first(indexFunction->retTypes); if (indexFunctionResult) - autocompleteProps(module, typeArena, singletonTypes, rootTy, *indexFunctionResult, indexType, nodes, result, seen); + autocompleteProps(module, typeArena, builtinTypes, rootTy, *indexFunctionResult, indexType, nodes, result, seen); } } }; - if (auto cls = get(ty)) + if (auto cls = get(ty)) { containingClass = containingClass.value_or(cls); fillProps(cls->props); if (cls->parent) - autocompleteProps(module, typeArena, singletonTypes, rootTy, *cls->parent, indexType, nodes, result, seen, containingClass); + autocompleteProps(module, typeArena, builtinTypes, rootTy, *cls->parent, indexType, nodes, result, seen, containingClass); } - else if (auto tbl = get(ty)) + else if (auto tbl = get(ty)) fillProps(tbl->props); - else if (auto mt = get(ty)) + else if (auto mt = get(ty)) { - autocompleteProps(module, typeArena, singletonTypes, rootTy, mt->table, indexType, nodes, result, seen); + autocompleteProps(module, typeArena, builtinTypes, rootTy, mt->table, indexType, nodes, result, seen); - if (auto mtable = get(mt->metatable)) + if (auto mtable = get(mt->metatable)) fillMetatableProps(mtable); } - else if (auto i = get(ty)) + else if (auto i = get(ty)) { // Complete all properties in every variant for (TypeId ty : i->parts) @@ -321,13 +321,13 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, NotNul AutocompleteEntryMap inner; std::unordered_set innerSeen = seen; - autocompleteProps(module, typeArena, singletonTypes, rootTy, ty, indexType, nodes, inner, innerSeen); + autocompleteProps(module, typeArena, builtinTypes, rootTy, ty, indexType, nodes, inner, innerSeen); for (auto& pair : inner) result.insert(pair); } } - else if (auto u = get(ty)) + else if (auto u = get(ty)) { // Complete all properties common to all variants auto iter = begin(u); @@ -344,7 +344,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, NotNul if (iter == endIter) return; - autocompleteProps(module, typeArena, singletonTypes, rootTy, *iter, indexType, nodes, result, seen); + autocompleteProps(module, typeArena, builtinTypes, rootTy, *iter, indexType, nodes, result, seen); ++iter; @@ -359,7 +359,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, NotNul continue; } - autocompleteProps(module, typeArena, singletonTypes, rootTy, *iter, indexType, nodes, inner, innerSeen); + autocompleteProps(module, typeArena, builtinTypes, rootTy, *iter, indexType, nodes, inner, innerSeen); std::unordered_set toRemove; @@ -376,17 +376,17 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, NotNul ++iter; } } - else if (auto pt = get(ty)) + else if (auto pt = get(ty)) { if (pt->metatable) { - if (auto mtable = get(*pt->metatable)) + if (auto mtable = get(*pt->metatable)) fillMetatableProps(mtable); } } - else if (get(get(ty))) + else if (get(get(ty))) { - autocompleteProps(module, typeArena, singletonTypes, rootTy, singletonTypes->stringType, indexType, nodes, result, seen); + autocompleteProps(module, typeArena, builtinTypes, rootTy, builtinTypes->stringType, indexType, nodes, result, seen); } } @@ -411,18 +411,18 @@ static void autocompleteKeywords( } } -static void autocompleteProps(const Module& module, TypeArena* typeArena, NotNull singletonTypes, TypeId ty, PropIndexType indexType, +static void autocompleteProps(const Module& module, TypeArena* typeArena, NotNull builtinTypes, TypeId ty, PropIndexType indexType, const std::vector& nodes, AutocompleteEntryMap& result) { std::unordered_set seen; - autocompleteProps(module, typeArena, singletonTypes, ty, ty, indexType, nodes, result, seen); + autocompleteProps(module, typeArena, builtinTypes, ty, ty, indexType, nodes, result, seen); } -AutocompleteEntryMap autocompleteProps(const Module& module, TypeArena* typeArena, NotNull singletonTypes, TypeId ty, +AutocompleteEntryMap autocompleteProps(const Module& module, TypeArena* typeArena, NotNull builtinTypes, TypeId ty, PropIndexType indexType, const std::vector& nodes) { AutocompleteEntryMap result; - autocompleteProps(module, typeArena, singletonTypes, ty, indexType, nodes, result); + autocompleteProps(module, typeArena, builtinTypes, ty, indexType, nodes, result); return result; } @@ -455,15 +455,15 @@ static void autocompleteStringSingleton(TypeId ty, bool addQuotes, AutocompleteE ty = follow(ty); - if (auto ss = get(get(ty))) + if (auto ss = get(get(ty))) { result[formatKey(ss->value)] = AutocompleteEntry{AutocompleteEntryKind::String, ty, false, false, TypeCorrectKind::Correct}; } - else if (auto uty = get(ty)) + else if (auto uty = get(ty)) { for (auto el : uty) { - if (auto ss = get(get(el))) + if (auto ss = get(get(el))) result[formatKey(ss->value)] = AutocompleteEntry{AutocompleteEntryKind::String, ty, false, false, TypeCorrectKind::Correct}; } } @@ -474,14 +474,14 @@ static bool canSuggestInferredType(ScopePtr scope, TypeId ty) ty = follow(ty); // No point in suggesting 'any', invalid to suggest others - if (get(ty) || get(ty) || get(ty) || get(ty)) + if (get(ty) || get(ty) || get(ty) || get(ty)) return false; // No syntax for unnamed tables with a metatable - if (get(ty)) + if (get(ty)) return false; - if (const TableTypeVar* ttv = get(ty)) + if (const TableType* ttv = get(ty)) { if (ttv->name) return true; @@ -544,7 +544,7 @@ static std::optional findTypeElementAt(AstType* astType, TypeId ty, Posi if (AstTypeFunction* type = astType->as()) { - const FunctionTypeVar* ftv = get(ty); + const FunctionType* ftv = get(ty); if (!ftv) return {}; @@ -634,7 +634,7 @@ static std::optional tryGetTypePackTypeAt(TypePackId tp, size_t index) } template -std::optional returnFirstNonnullOptionOfType(const UnionTypeVar* utv) +std::optional returnFirstNonnullOptionOfType(const UnionType* utv) { std::optional ret; for (TypeId subTy : utv) @@ -667,18 +667,18 @@ static std::optional functionIsExpectedAt(const Module& module, AstNode* n TypeId expectedType = follow(*typeAtPosition); - if (get(expectedType)) + if (get(expectedType)) return true; - if (const IntersectionTypeVar* itv = get(expectedType)) + if (const IntersectionType* itv = get(expectedType)) { return std::all_of(begin(itv->parts), end(itv->parts), [](auto&& ty) { - return get(Luau::follow(ty)) != nullptr; + return get(Luau::follow(ty)) != nullptr; }); } - if (const UnionTypeVar* utv = get(expectedType)) - return returnFirstNonnullOptionOfType(utv).has_value(); + if (const UnionType* utv = get(expectedType)) + return returnFirstNonnullOptionOfType(utv).has_value(); return false; } @@ -766,7 +766,7 @@ AutocompleteEntryMap autocompleteTypeNames(const Module& module, Position positi { if (auto it = module.astTypes.find(exprCall->func)) { - if (const FunctionTypeVar* ftv = get(follow(*it))) + if (const FunctionType* ftv = get(follow(*it))) { if (auto ty = tryGetTypePackTypeAt(ftv->retTypes, tailPos)) inferredType = *ty; @@ -792,7 +792,7 @@ AutocompleteEntryMap autocompleteTypeNames(const Module& module, Position positi else if (AstExprFunction* node = parent->as()) { // For lookup inside expected function type if that's available - auto tryGetExpectedFunctionType = [](const Module& module, AstExpr* expr) -> const FunctionTypeVar* { + auto tryGetExpectedFunctionType = [](const Module& module, AstExpr* expr) -> const FunctionType* { auto it = module.astExpectedTypes.find(expr); if (!it) @@ -800,13 +800,13 @@ AutocompleteEntryMap autocompleteTypeNames(const Module& module, Position positi TypeId ty = follow(*it); - if (const FunctionTypeVar* ftv = get(ty)) + if (const FunctionType* ftv = get(ty)) return ftv; // Handle optional function type - if (const UnionTypeVar* utv = get(ty)) + if (const UnionType* utv = get(ty)) { - return returnFirstNonnullOptionOfType(utv).value_or(nullptr); + return returnFirstNonnullOptionOfType(utv).value_or(nullptr); } return nullptr; @@ -819,7 +819,7 @@ AutocompleteEntryMap autocompleteTypeNames(const Module& module, Position positi if (arg->annotation && arg->annotation->location.containsClosed(position)) { - if (const FunctionTypeVar* ftv = tryGetExpectedFunctionType(module, node)) + if (const FunctionType* ftv = tryGetExpectedFunctionType(module, node)) { if (auto ty = tryGetTypePackTypeAt(ftv->argTypes, i)) tryAddTypeCorrectSuggestion(result, startScope, topType, *ty, position); @@ -840,7 +840,7 @@ AutocompleteEntryMap autocompleteTypeNames(const Module& module, Position positi { if (variadic->location.containsClosed(position)) { - if (const FunctionTypeVar* ftv = tryGetExpectedFunctionType(module, node)) + if (const FunctionType* ftv = tryGetExpectedFunctionType(module, node)) { if (auto ty = tryGetTypePackTypeAt(ftv->argTypes, ~0u)) tryAddTypeCorrectSuggestion(result, startScope, topType, *ty, position); @@ -858,7 +858,7 @@ AutocompleteEntryMap autocompleteTypeNames(const Module& module, Position positi if (ret->location.containsClosed(position)) { - if (const FunctionTypeVar* ftv = tryGetExpectedFunctionType(module, node)) + if (const FunctionType* ftv = tryGetExpectedFunctionType(module, node)) { if (auto ty = tryGetTypePackTypeAt(ftv->retTypes, i)) tryAddTypeCorrectSuggestion(result, startScope, topType, *ty, position); @@ -875,7 +875,7 @@ AutocompleteEntryMap autocompleteTypeNames(const Module& module, Position positi { if (variadic->location.containsClosed(position)) { - if (const FunctionTypeVar* ftv = tryGetExpectedFunctionType(module, node)) + if (const FunctionType* ftv = tryGetExpectedFunctionType(module, node)) { if (auto ty = tryGetTypePackTypeAt(ftv->retTypes, ~0u)) tryAddTypeCorrectSuggestion(result, startScope, topType, *ty, position); @@ -1127,7 +1127,7 @@ static bool autocompleteIfElseExpression( } } -static AutocompleteContext autocompleteExpression(const SourceModule& sourceModule, const Module& module, NotNull singletonTypes, +static AutocompleteContext autocompleteExpression(const SourceModule& sourceModule, const Module& module, NotNull builtinTypes, TypeArena* typeArena, const std::vector& ancestry, Position position, AutocompleteEntryMap& result) { LUAU_ASSERT(!ancestry.empty()); @@ -1137,7 +1137,7 @@ static AutocompleteContext autocompleteExpression(const SourceModule& sourceModu if (node->is()) { if (auto it = module.astTypes.find(node->asExpr())) - autocompleteProps(module, typeArena, singletonTypes, *it, PropIndexType::Point, ancestry, result); + autocompleteProps(module, typeArena, builtinTypes, *it, PropIndexType::Point, ancestry, result); } else if (autocompleteIfElseExpression(node, ancestry, position, result)) return AutocompleteContext::Keyword; @@ -1161,7 +1161,7 @@ static AutocompleteContext autocompleteExpression(const SourceModule& sourceModu std::string n = toString(name); if (!result.count(n)) { - TypeCorrectKind typeCorrect = checkTypeCorrectKind(module, typeArena, singletonTypes, node, position, binding.typeId); + TypeCorrectKind typeCorrect = checkTypeCorrectKind(module, typeArena, builtinTypes, node, position, binding.typeId); result[n] = {AutocompleteEntryKind::Binding, binding.typeId, binding.deprecated, false, typeCorrect, std::nullopt, std::nullopt, binding.documentationSymbol, {}, getParenRecommendation(binding.typeId, ancestry, typeCorrect)}; @@ -1171,16 +1171,16 @@ static AutocompleteContext autocompleteExpression(const SourceModule& sourceModu scope = scope->parent; } - TypeCorrectKind correctForNil = checkTypeCorrectKind(module, typeArena, singletonTypes, node, position, singletonTypes->nilType); - TypeCorrectKind correctForTrue = checkTypeCorrectKind(module, typeArena, singletonTypes, node, position, singletonTypes->trueType); - TypeCorrectKind correctForFalse = checkTypeCorrectKind(module, typeArena, singletonTypes, node, position, singletonTypes->falseType); + TypeCorrectKind correctForNil = checkTypeCorrectKind(module, typeArena, builtinTypes, node, position, builtinTypes->nilType); + TypeCorrectKind correctForTrue = checkTypeCorrectKind(module, typeArena, builtinTypes, node, position, builtinTypes->trueType); + TypeCorrectKind correctForFalse = checkTypeCorrectKind(module, typeArena, builtinTypes, node, position, builtinTypes->falseType); TypeCorrectKind correctForFunction = functionIsExpectedAt(module, node, position).value_or(false) ? TypeCorrectKind::Correct : TypeCorrectKind::None; result["if"] = {AutocompleteEntryKind::Keyword, std::nullopt, false, false}; - result["true"] = {AutocompleteEntryKind::Keyword, singletonTypes->booleanType, false, false, correctForTrue}; - result["false"] = {AutocompleteEntryKind::Keyword, singletonTypes->booleanType, false, false, correctForFalse}; - result["nil"] = {AutocompleteEntryKind::Keyword, singletonTypes->nilType, false, false, correctForNil}; + result["true"] = {AutocompleteEntryKind::Keyword, builtinTypes->booleanType, false, false, correctForTrue}; + result["false"] = {AutocompleteEntryKind::Keyword, builtinTypes->booleanType, false, false, correctForFalse}; + result["nil"] = {AutocompleteEntryKind::Keyword, builtinTypes->nilType, false, false, correctForNil}; result["not"] = {AutocompleteEntryKind::Keyword}; result["function"] = {AutocompleteEntryKind::Keyword, std::nullopt, false, false, correctForFunction}; @@ -1191,15 +1191,15 @@ static AutocompleteContext autocompleteExpression(const SourceModule& sourceModu return AutocompleteContext::Expression; } -static AutocompleteResult autocompleteExpression(const SourceModule& sourceModule, const Module& module, NotNull singletonTypes, +static AutocompleteResult autocompleteExpression(const SourceModule& sourceModule, const Module& module, NotNull builtinTypes, TypeArena* typeArena, const std::vector& ancestry, Position position) { AutocompleteEntryMap result; - AutocompleteContext context = autocompleteExpression(sourceModule, module, singletonTypes, typeArena, ancestry, position, result); + AutocompleteContext context = autocompleteExpression(sourceModule, module, builtinTypes, typeArena, ancestry, position, result); return {result, ancestry, context}; } -static std::optional getMethodContainingClass(const ModulePtr& module, AstExpr* funcExpr) +static std::optional getMethodContainingClass(const ModulePtr& module, AstExpr* funcExpr) { AstExpr* parentExpr = nullptr; if (auto indexName = funcExpr->as()) @@ -1223,14 +1223,14 @@ static std::optional getMethodContainingClass(const ModuleP Luau::TypeId parentType = Luau::follow(*parentIt); - if (auto parentClass = Luau::get(parentType)) + if (auto parentClass = Luau::get(parentType)) { return parentClass; } - if (auto parentUnion = Luau::get(parentType)) + if (auto parentUnion = Luau::get(parentType)) { - return returnFirstNonnullOptionOfType(parentUnion); + return returnFirstNonnullOptionOfType(parentUnion); } return std::nullopt; @@ -1281,7 +1281,7 @@ static std::optional autocompleteStringParams(const Source } // HACK: All current instances of 'magic string' params are the first parameter of their functions, - // so we encode that here rather than putting a useless member on the FunctionTypeVar struct. + // so we encode that here rather than putting a useless member on the FunctionType struct. if (candidate->args.size > 1 && !candidate->args.data[0]->location.contains(position)) { return std::nullopt; @@ -1293,7 +1293,7 @@ static std::optional autocompleteStringParams(const Source return std::nullopt; } - auto performCallback = [&](const FunctionTypeVar* funcType) -> std::optional { + auto performCallback = [&](const FunctionType* funcType) -> std::optional { for (const std::string& tag : funcType->tags) { if (std::optional ret = callback(tag, getMethodContainingClass(module, candidate->func))) @@ -1305,16 +1305,16 @@ static std::optional autocompleteStringParams(const Source }; auto followedId = Luau::follow(*it); - if (auto functionType = Luau::get(followedId)) + if (auto functionType = Luau::get(followedId)) { return performCallback(functionType); } - if (auto intersect = Luau::get(followedId)) + if (auto intersect = Luau::get(followedId)) { for (TypeId part : intersect->parts) { - if (auto candidateFunctionType = Luau::get(part)) + if (auto candidateFunctionType = Luau::get(part)) { if (std::optional ret = performCallback(candidateFunctionType)) { @@ -1327,7 +1327,7 @@ static std::optional autocompleteStringParams(const Source return std::nullopt; } -static AutocompleteResult autocomplete(const SourceModule& sourceModule, const ModulePtr& module, NotNull singletonTypes, +static AutocompleteResult autocomplete(const SourceModule& sourceModule, const ModulePtr& module, NotNull builtinTypes, Scope* globalScope, Position position, StringCompletionCallback callback) { if (isWithinComment(sourceModule, position)) @@ -1360,7 +1360,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M TypeId ty = follow(*it); PropIndexType indexType = indexName->op == ':' ? PropIndexType::Colon : PropIndexType::Point; - return {autocompleteProps(*module, &typeArena, singletonTypes, ty, indexType, ancestry), ancestry, AutocompleteContext::Property}; + return {autocompleteProps(*module, &typeArena, builtinTypes, ty, indexType, ancestry), ancestry, AutocompleteContext::Property}; } else if (auto typeReference = node->as()) { @@ -1378,7 +1378,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M if (statLocal->vars.size == 1 && (!statLocal->equalsSignLocation || position < statLocal->equalsSignLocation->begin)) return {{{"function", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Unknown}; else if (statLocal->equalsSignLocation && position >= statLocal->equalsSignLocation->end) - return autocompleteExpression(sourceModule, *module, singletonTypes, &typeArena, ancestry, position); + return autocompleteExpression(sourceModule, *module, builtinTypes, &typeArena, ancestry, position); else return {}; } @@ -1392,7 +1392,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M if (statFor->from->location.containsClosed(position) || statFor->to->location.containsClosed(position) || (statFor->step && statFor->step->location.containsClosed(position))) - return autocompleteExpression(sourceModule, *module, singletonTypes, &typeArena, ancestry, position); + return autocompleteExpression(sourceModule, *module, builtinTypes, &typeArena, ancestry, position); return {}; } @@ -1422,7 +1422,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M AstExpr* lastExpr = statForIn->values.data[statForIn->values.size - 1]; if (lastExpr->location.containsClosed(position)) - return autocompleteExpression(sourceModule, *module, singletonTypes, &typeArena, ancestry, position); + return autocompleteExpression(sourceModule, *module, builtinTypes, &typeArena, ancestry, position); if (position > lastExpr->location.end) return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Keyword}; @@ -1446,7 +1446,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Keyword}; if (!statWhile->hasDo || position < statWhile->doLocation.begin) - return autocompleteExpression(sourceModule, *module, singletonTypes, &typeArena, ancestry, position); + return autocompleteExpression(sourceModule, *module, builtinTypes, &typeArena, ancestry, position); if (statWhile->hasDo && position > statWhile->doLocation.end) return {autocompleteStatement(sourceModule, *module, ancestry, position), ancestry, AutocompleteContext::Statement}; @@ -1463,7 +1463,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M else if (AstStatIf* statIf = parent->as(); statIf && node->is()) { if (statIf->condition->is()) - return autocompleteExpression(sourceModule, *module, singletonTypes, &typeArena, ancestry, position); + return autocompleteExpression(sourceModule, *module, builtinTypes, &typeArena, ancestry, position); else if (!statIf->thenLocation || statIf->thenLocation->containsClosed(position)) return {{{"then", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Keyword}; } @@ -1471,7 +1471,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M statIf && (!statIf->thenLocation || statIf->thenLocation->containsClosed(position))) return {{{"then", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Keyword}; else if (AstStatRepeat* statRepeat = node->as(); statRepeat && statRepeat->condition->is()) - return autocompleteExpression(sourceModule, *module, singletonTypes, &typeArena, ancestry, position); + return autocompleteExpression(sourceModule, *module, builtinTypes, &typeArena, ancestry, position); else if (AstStatRepeat* statRepeat = extractStat(ancestry); statRepeat) return {autocompleteStatement(sourceModule, *module, ancestry, position), ancestry, AutocompleteContext::Statement}; else if (AstExprTable* exprTable = parent->as(); @@ -1484,7 +1484,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M { if (auto it = module->astExpectedTypes.find(exprTable)) { - auto result = autocompleteProps(*module, &typeArena, singletonTypes, *it, PropIndexType::Key, ancestry); + auto result = autocompleteProps(*module, &typeArena, builtinTypes, *it, PropIndexType::Key, ancestry); if (FFlag::LuauCompleteTableKeysBetter) { @@ -1499,7 +1499,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M // // If the key type is a union of singleton strings, // suggest those too. - if (auto ttv = get(follow(*it)); ttv && ttv->indexer) + if (auto ttv = get(follow(*it)); ttv && ttv->indexer) { autocompleteStringSingleton(ttv->indexer->indexType, false, result); } @@ -1518,7 +1518,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M // If we know for sure that a key is being written, do not offer general expression suggestions if (!key) - autocompleteExpression(sourceModule, *module, singletonTypes, &typeArena, ancestry, position, result); + autocompleteExpression(sourceModule, *module, builtinTypes, &typeArena, ancestry, position, result); return {result, ancestry, AutocompleteContext::Property}; } @@ -1546,7 +1546,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M if (auto idxExpr = ancestry.at(ancestry.size() - 2)->as()) { if (auto it = module->astTypes.find(idxExpr->expr)) - autocompleteProps(*module, &typeArena, singletonTypes, follow(*it), PropIndexType::Point, ancestry, result); + autocompleteProps(*module, &typeArena, builtinTypes, follow(*it), PropIndexType::Point, ancestry, result); } else if (auto binExpr = ancestry.at(ancestry.size() - 2)->as()) { @@ -1572,7 +1572,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M return {}; if (node->asExpr()) - return autocompleteExpression(sourceModule, *module, singletonTypes, &typeArena, ancestry, position); + return autocompleteExpression(sourceModule, *module, builtinTypes, &typeArena, ancestry, position); else if (node->asStat()) return {autocompleteStatement(sourceModule, *module, ancestry, position), ancestry, AutocompleteContext::Statement}; @@ -1596,10 +1596,10 @@ AutocompleteResult autocomplete(Frontend& frontend, const ModuleName& moduleName if (!module) return {}; - NotNull singletonTypes = frontend.singletonTypes; + NotNull builtinTypes = frontend.builtinTypes; Scope* globalScope = frontend.typeCheckerForAutocomplete.globalScope.get(); - AutocompleteResult autocompleteResult = autocomplete(*sourceModule, module, singletonTypes, globalScope, position, callback); + AutocompleteResult autocompleteResult = autocomplete(*sourceModule, module, builtinTypes, globalScope, position, callback); return autocompleteResult; } diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index 612812c56..81702ff65 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -10,7 +10,7 @@ #include "Luau/ConstraintGraphBuilder.h" #include "Luau/TypeInfer.h" #include "Luau/TypePack.h" -#include "Luau/TypeVar.h" +#include "Luau/Type.h" #include "Luau/TypeUtils.h" #include @@ -51,12 +51,12 @@ static std::vector dcrMagicRefinementAssert(const MagicRefinementC TypeId makeUnion(TypeArena& arena, std::vector&& types) { - return arena.addType(UnionTypeVar{std::move(types)}); + return arena.addType(UnionType{std::move(types)}); } TypeId makeIntersection(TypeArena& arena, std::vector&& types) { - return arena.addType(IntersectionTypeVar{std::move(types)}); + return arena.addType(IntersectionType{std::move(types)}); } TypeId makeOption(Frontend& frontend, TypeArena& arena, TypeId t) @@ -99,7 +99,7 @@ TypeId makeFunction(TypeArena& arena, std::optional selfType, std::initi TypePackId paramPack = arena.addTypePack(std::move(params)); TypePackId retPack = arena.addTypePack(std::vector(retTypes)); - FunctionTypeVar ftv{generics, genericPacks, paramPack, retPack, {}, selfType.has_value()}; + FunctionType ftv{generics, genericPacks, paramPack, retPack, {}, selfType.has_value()}; if (selfType) ftv.argNames.push_back(Luau::FunctionArgument{"self", {}}); @@ -121,7 +121,7 @@ TypeId makeFunction(TypeArena& arena, std::optional selfType, std::initi void attachMagicFunction(TypeId ty, MagicFunction fn) { - if (auto ftv = getMutable(ty)) + if (auto ftv = getMutable(ty)) ftv->magicFunction = fn; else LUAU_ASSERT(!"Got a non functional type"); @@ -129,7 +129,7 @@ void attachMagicFunction(TypeId ty, MagicFunction fn) void attachDcrMagicFunction(TypeId ty, DcrMagicFunction fn) { - if (auto ftv = getMutable(ty)) + if (auto ftv = getMutable(ty)) ftv->dcrMagicFunction = fn; else LUAU_ASSERT(!"Got a non functional type"); @@ -137,7 +137,7 @@ void attachDcrMagicFunction(TypeId ty, DcrMagicFunction fn) void attachDcrMagicRefinement(TypeId ty, DcrMagicRefinement fn) { - if (auto ftv = getMutable(ty)) + if (auto ftv = getMutable(ty)) ftv->dcrMagicRefinement = fn; else LUAU_ASSERT(!"Got a non functional type"); @@ -239,7 +239,7 @@ Binding* tryGetGlobalBindingRef(TypeChecker& typeChecker, const std::string& nam return nullptr; } -void assignPropDocumentationSymbols(TableTypeVar::Props& props, const std::string& baseName) +void assignPropDocumentationSymbols(TableType::Props& props, const std::string& baseName) { for (auto& [name, prop] : props) { @@ -249,39 +249,39 @@ void assignPropDocumentationSymbols(TableTypeVar::Props& props, const std::strin void registerBuiltinTypes(Frontend& frontend) { - frontend.getGlobalScope()->addBuiltinTypeBinding("any", TypeFun{{}, frontend.singletonTypes->anyType}); - frontend.getGlobalScope()->addBuiltinTypeBinding("nil", TypeFun{{}, frontend.singletonTypes->nilType}); - frontend.getGlobalScope()->addBuiltinTypeBinding("number", TypeFun{{}, frontend.singletonTypes->numberType}); - frontend.getGlobalScope()->addBuiltinTypeBinding("string", TypeFun{{}, frontend.singletonTypes->stringType}); - frontend.getGlobalScope()->addBuiltinTypeBinding("boolean", TypeFun{{}, frontend.singletonTypes->booleanType}); - frontend.getGlobalScope()->addBuiltinTypeBinding("thread", TypeFun{{}, frontend.singletonTypes->threadType}); + frontend.getGlobalScope()->addBuiltinTypeBinding("any", TypeFun{{}, frontend.builtinTypes->anyType}); + frontend.getGlobalScope()->addBuiltinTypeBinding("nil", TypeFun{{}, frontend.builtinTypes->nilType}); + frontend.getGlobalScope()->addBuiltinTypeBinding("number", TypeFun{{}, frontend.builtinTypes->numberType}); + frontend.getGlobalScope()->addBuiltinTypeBinding("string", TypeFun{{}, frontend.builtinTypes->stringType}); + frontend.getGlobalScope()->addBuiltinTypeBinding("boolean", TypeFun{{}, frontend.builtinTypes->booleanType}); + frontend.getGlobalScope()->addBuiltinTypeBinding("thread", TypeFun{{}, frontend.builtinTypes->threadType}); if (FFlag::LuauUnknownAndNeverType) { - frontend.getGlobalScope()->addBuiltinTypeBinding("unknown", TypeFun{{}, frontend.singletonTypes->unknownType}); - frontend.getGlobalScope()->addBuiltinTypeBinding("never", TypeFun{{}, frontend.singletonTypes->neverType}); + frontend.getGlobalScope()->addBuiltinTypeBinding("unknown", TypeFun{{}, frontend.builtinTypes->unknownType}); + frontend.getGlobalScope()->addBuiltinTypeBinding("never", TypeFun{{}, frontend.builtinTypes->neverType}); } } void registerBuiltinGlobals(TypeChecker& typeChecker) { - LUAU_ASSERT(!typeChecker.globalTypes.typeVars.isFrozen()); + LUAU_ASSERT(!typeChecker.globalTypes.types.isFrozen()); LUAU_ASSERT(!typeChecker.globalTypes.typePacks.isFrozen()); TypeId nilType = typeChecker.nilType; TypeArena& arena = typeChecker.globalTypes; - NotNull singletonTypes = typeChecker.singletonTypes; + NotNull builtinTypes = typeChecker.builtinTypes; LoadDefinitionFileResult loadResult = Luau::loadDefinitionFile(typeChecker, typeChecker.globalScope, getBuiltinDefinitionSource(), "@luau"); LUAU_ASSERT(loadResult.success); - TypeId genericK = arena.addType(GenericTypeVar{"K"}); - TypeId genericV = arena.addType(GenericTypeVar{"V"}); - TypeId mapOfKtoV = arena.addType(TableTypeVar{{}, TableIndexer(genericK, genericV), typeChecker.globalScope->level, TableState::Generic}); + TypeId genericK = arena.addType(GenericType{"K"}); + TypeId genericV = arena.addType(GenericType{"V"}); + TypeId mapOfKtoV = arena.addType(TableType{{}, TableIndexer(genericK, genericV), typeChecker.globalScope->level, TableState::Generic}); - std::optional stringMetatableTy = getMetatable(singletonTypes->stringType, singletonTypes); + std::optional stringMetatableTy = getMetatable(builtinTypes->stringType, builtinTypes); LUAU_ASSERT(stringMetatableTy); - const TableTypeVar* stringMetatableTable = get(follow(*stringMetatableTy)); + const TableType* stringMetatableTable = get(follow(*stringMetatableTy)); LUAU_ASSERT(stringMetatableTable); auto it = stringMetatableTable->props.find("__index"); @@ -294,40 +294,40 @@ void registerBuiltinGlobals(TypeChecker& typeChecker) // next(t: Table, i: K?) -> (K?, V) TypePackId nextArgsTypePack = arena.addTypePack(TypePack{{mapOfKtoV, makeOption(typeChecker, arena, genericK)}}); TypePackId nextRetsTypePack = arena.addTypePack(TypePack{{makeOption(typeChecker, arena, genericK), genericV}}); - addGlobalBinding(typeChecker, "next", arena.addType(FunctionTypeVar{{genericK, genericV}, {}, nextArgsTypePack, nextRetsTypePack}), "@luau"); + addGlobalBinding(typeChecker, "next", arena.addType(FunctionType{{genericK, genericV}, {}, nextArgsTypePack, nextRetsTypePack}), "@luau"); TypePackId pairsArgsTypePack = arena.addTypePack({mapOfKtoV}); - TypeId pairsNext = arena.addType(FunctionTypeVar{nextArgsTypePack, nextRetsTypePack}); + TypeId pairsNext = arena.addType(FunctionType{nextArgsTypePack, nextRetsTypePack}); TypePackId pairsReturnTypePack = arena.addTypePack(TypePack{{pairsNext, mapOfKtoV, nilType}}); // pairs(t: Table) -> ((Table, K?) -> (K, V), Table, nil) addGlobalBinding( - typeChecker, "pairs", arena.addType(FunctionTypeVar{{genericK, genericV}, {}, pairsArgsTypePack, pairsReturnTypePack}), "@luau"); + typeChecker, "pairs", arena.addType(FunctionType{{genericK, genericV}, {}, pairsArgsTypePack, pairsReturnTypePack}), "@luau"); } else { // next(t: Table, i: K?) -> (K, V) TypePackId nextArgsTypePack = arena.addTypePack(TypePack{{mapOfKtoV, makeOption(typeChecker, arena, genericK)}}); addGlobalBinding(typeChecker, "next", - arena.addType(FunctionTypeVar{{genericK, genericV}, {}, nextArgsTypePack, arena.addTypePack(TypePack{{genericK, genericV}})}), "@luau"); + arena.addType(FunctionType{{genericK, genericV}, {}, nextArgsTypePack, arena.addTypePack(TypePack{{genericK, genericV}})}), "@luau"); TypePackId pairsArgsTypePack = arena.addTypePack({mapOfKtoV}); - TypeId pairsNext = arena.addType(FunctionTypeVar{nextArgsTypePack, arena.addTypePack(TypePack{{genericK, genericV}})}); + TypeId pairsNext = arena.addType(FunctionType{nextArgsTypePack, arena.addTypePack(TypePack{{genericK, genericV}})}); TypePackId pairsReturnTypePack = arena.addTypePack(TypePack{{pairsNext, mapOfKtoV, nilType}}); // pairs(t: Table) -> ((Table, K?) -> (K, V), Table, nil) addGlobalBinding( - typeChecker, "pairs", arena.addType(FunctionTypeVar{{genericK, genericV}, {}, pairsArgsTypePack, pairsReturnTypePack}), "@luau"); + typeChecker, "pairs", arena.addType(FunctionType{{genericK, genericV}, {}, pairsArgsTypePack, pairsReturnTypePack}), "@luau"); } - TypeId genericMT = arena.addType(GenericTypeVar{"MT"}); + TypeId genericMT = arena.addType(GenericType{"MT"}); - TableTypeVar tab{TableState::Generic, typeChecker.globalScope->level}; + TableType tab{TableState::Generic, typeChecker.globalScope->level}; TypeId tabTy = arena.addType(tab); - TypeId tableMetaMT = arena.addType(MetatableTypeVar{tabTy, genericMT}); + TypeId tableMetaMT = arena.addType(MetatableType{tabTy, genericMT}); addGlobalBinding(typeChecker, "getmetatable", makeFunction(arena, std::nullopt, {genericMT}, {}, {tableMetaMT}, {genericMT}), "@luau"); @@ -335,7 +335,7 @@ void registerBuiltinGlobals(TypeChecker& typeChecker) // setmetatable(T, MT) -> { @metatable MT, T } addGlobalBinding(typeChecker, "setmetatable", arena.addType( - FunctionTypeVar{ + FunctionType{ {genericMT}, {}, arena.addTypePack(TypePack{{FFlag::LuauUnknownAndNeverType ? tabTy : tableMetaMT, genericMT}}), @@ -349,7 +349,7 @@ void registerBuiltinGlobals(TypeChecker& typeChecker) { persist(pair.second.typeId); - if (TableTypeVar* ttv = getMutable(pair.second.typeId)) + if (TableType* ttv = getMutable(pair.second.typeId)) { if (!ttv->name) { @@ -366,7 +366,7 @@ void registerBuiltinGlobals(TypeChecker& typeChecker) attachMagicFunction(getGlobalBinding(typeChecker, "select"), magicFunctionSelect); attachDcrMagicFunction(getGlobalBinding(typeChecker, "select"), dcrMagicFunctionSelect); - if (TableTypeVar* ttv = getMutable(getGlobalBinding(typeChecker, "table"))) + if (TableType* ttv = getMutable(getGlobalBinding(typeChecker, "table"))) { // tabTy is a generic table type which we can't express via declaration syntax yet ttv->props["freeze"] = makeProperty(makeFunction(arena, std::nullopt, {tabTy}, {tabTy}), "@luau/global/table.freeze"); @@ -382,25 +382,25 @@ void registerBuiltinGlobals(TypeChecker& typeChecker) void registerBuiltinGlobals(Frontend& frontend) { - LUAU_ASSERT(!frontend.globalTypes.typeVars.isFrozen()); + LUAU_ASSERT(!frontend.globalTypes.types.isFrozen()); LUAU_ASSERT(!frontend.globalTypes.typePacks.isFrozen()); if (FFlag::LuauReportShadowedTypeAlias) registerBuiltinTypes(frontend); TypeArena& arena = frontend.globalTypes; - NotNull singletonTypes = frontend.singletonTypes; + NotNull builtinTypes = frontend.builtinTypes; LoadDefinitionFileResult loadResult = frontend.loadDefinitionFile(getBuiltinDefinitionSource(), "@luau"); LUAU_ASSERT(loadResult.success); - TypeId genericK = arena.addType(GenericTypeVar{"K"}); - TypeId genericV = arena.addType(GenericTypeVar{"V"}); - TypeId mapOfKtoV = arena.addType(TableTypeVar{{}, TableIndexer(genericK, genericV), frontend.getGlobalScope()->level, TableState::Generic}); + TypeId genericK = arena.addType(GenericType{"K"}); + TypeId genericV = arena.addType(GenericType{"V"}); + TypeId mapOfKtoV = arena.addType(TableType{{}, TableIndexer(genericK, genericV), frontend.getGlobalScope()->level, TableState::Generic}); - std::optional stringMetatableTy = getMetatable(singletonTypes->stringType, singletonTypes); + std::optional stringMetatableTy = getMetatable(builtinTypes->stringType, builtinTypes); LUAU_ASSERT(stringMetatableTy); - const TableTypeVar* stringMetatableTable = get(follow(*stringMetatableTy)); + const TableType* stringMetatableTable = get(follow(*stringMetatableTy)); LUAU_ASSERT(stringMetatableTable); auto it = stringMetatableTable->props.find("__index"); @@ -413,40 +413,38 @@ void registerBuiltinGlobals(Frontend& frontend) // next(t: Table, i: K?) -> (K?, V) TypePackId nextArgsTypePack = arena.addTypePack(TypePack{{mapOfKtoV, makeOption(frontend, arena, genericK)}}); TypePackId nextRetsTypePack = arena.addTypePack(TypePack{{makeOption(frontend, arena, genericK), genericV}}); - addGlobalBinding(frontend, "next", arena.addType(FunctionTypeVar{{genericK, genericV}, {}, nextArgsTypePack, nextRetsTypePack}), "@luau"); + addGlobalBinding(frontend, "next", arena.addType(FunctionType{{genericK, genericV}, {}, nextArgsTypePack, nextRetsTypePack}), "@luau"); TypePackId pairsArgsTypePack = arena.addTypePack({mapOfKtoV}); - TypeId pairsNext = arena.addType(FunctionTypeVar{nextArgsTypePack, nextRetsTypePack}); - TypePackId pairsReturnTypePack = arena.addTypePack(TypePack{{pairsNext, mapOfKtoV, frontend.singletonTypes->nilType}}); + TypeId pairsNext = arena.addType(FunctionType{nextArgsTypePack, nextRetsTypePack}); + TypePackId pairsReturnTypePack = arena.addTypePack(TypePack{{pairsNext, mapOfKtoV, frontend.builtinTypes->nilType}}); // pairs(t: Table) -> ((Table, K?) -> (K?, V), Table, nil) - addGlobalBinding( - frontend, "pairs", arena.addType(FunctionTypeVar{{genericK, genericV}, {}, pairsArgsTypePack, pairsReturnTypePack}), "@luau"); + addGlobalBinding(frontend, "pairs", arena.addType(FunctionType{{genericK, genericV}, {}, pairsArgsTypePack, pairsReturnTypePack}), "@luau"); } else { // next(t: Table, i: K?) -> (K, V) TypePackId nextArgsTypePack = arena.addTypePack(TypePack{{mapOfKtoV, makeOption(frontend, arena, genericK)}}); addGlobalBinding(frontend, "next", - arena.addType(FunctionTypeVar{{genericK, genericV}, {}, nextArgsTypePack, arena.addTypePack(TypePack{{genericK, genericV}})}), "@luau"); + arena.addType(FunctionType{{genericK, genericV}, {}, nextArgsTypePack, arena.addTypePack(TypePack{{genericK, genericV}})}), "@luau"); TypePackId pairsArgsTypePack = arena.addTypePack({mapOfKtoV}); - TypeId pairsNext = arena.addType(FunctionTypeVar{nextArgsTypePack, arena.addTypePack(TypePack{{genericK, genericV}})}); - TypePackId pairsReturnTypePack = arena.addTypePack(TypePack{{pairsNext, mapOfKtoV, frontend.singletonTypes->nilType}}); + TypeId pairsNext = arena.addType(FunctionType{nextArgsTypePack, arena.addTypePack(TypePack{{genericK, genericV}})}); + TypePackId pairsReturnTypePack = arena.addTypePack(TypePack{{pairsNext, mapOfKtoV, frontend.builtinTypes->nilType}}); // pairs(t: Table) -> ((Table, K?) -> (K, V), Table, nil) - addGlobalBinding( - frontend, "pairs", arena.addType(FunctionTypeVar{{genericK, genericV}, {}, pairsArgsTypePack, pairsReturnTypePack}), "@luau"); + addGlobalBinding(frontend, "pairs", arena.addType(FunctionType{{genericK, genericV}, {}, pairsArgsTypePack, pairsReturnTypePack}), "@luau"); } - TypeId genericMT = arena.addType(GenericTypeVar{"MT"}); + TypeId genericMT = arena.addType(GenericType{"MT"}); - TableTypeVar tab{TableState::Generic, frontend.getGlobalScope()->level}; + TableType tab{TableState::Generic, frontend.getGlobalScope()->level}; TypeId tabTy = arena.addType(tab); - TypeId tableMetaMT = arena.addType(MetatableTypeVar{tabTy, genericMT}); + TypeId tableMetaMT = arena.addType(MetatableType{tabTy, genericMT}); addGlobalBinding(frontend, "getmetatable", makeFunction(arena, std::nullopt, {genericMT}, {}, {tableMetaMT}, {genericMT}), "@luau"); @@ -454,7 +452,7 @@ void registerBuiltinGlobals(Frontend& frontend) // setmetatable(T, MT) -> { @metatable MT, T } addGlobalBinding(frontend, "setmetatable", arena.addType( - FunctionTypeVar{ + FunctionType{ {genericMT}, {}, arena.addTypePack(TypePack{{FFlag::LuauUnknownAndNeverType ? tabTy : tableMetaMT, genericMT}}), @@ -468,7 +466,7 @@ void registerBuiltinGlobals(Frontend& frontend) { persist(pair.second.typeId); - if (TableTypeVar* ttv = getMutable(pair.second.typeId)) + if (TableType* ttv = getMutable(pair.second.typeId)) { if (!ttv->name) { @@ -486,7 +484,7 @@ void registerBuiltinGlobals(Frontend& frontend) attachMagicFunction(getGlobalBinding(frontend, "select"), magicFunctionSelect); attachDcrMagicFunction(getGlobalBinding(frontend, "select"), dcrMagicFunctionSelect); - if (TableTypeVar* ttv = getMutable(getGlobalBinding(frontend, "table"))) + if (TableType* ttv = getMutable(getGlobalBinding(frontend, "table"))) { // tabTy is a generic table type which we can't express via declaration syntax yet ttv->props["freeze"] = makeProperty(makeFunction(arena, std::nullopt, {tabTy}, {tabTy}), "@luau/global/table.freeze"); @@ -576,7 +574,7 @@ static bool dcrMagicFunctionSelect(MagicFunctionCallContext context) { if (str->value.size == 1 && str->value.data[0] == '#') { - TypePackId numberTypePack = context.solver->arena->addTypePack({context.solver->singletonTypes->numberType}); + TypePackId numberTypePack = context.solver->arena->addTypePack({context.solver->builtinTypes->numberType}); asMutable(context.result)->ty.emplace(numberTypePack); return true; } @@ -609,7 +607,7 @@ static std::optional> magicFunctionSetMetaTable( typechecker.tablify(mt); } - if (const auto& tab = get(target)) + if (const auto& tab = get(target)) { if (target->persistent) { @@ -620,8 +618,8 @@ static std::optional> magicFunctionSetMetaTable( if (!FFlag::LuauUnknownAndNeverType) typechecker.tablify(mt); - const TableTypeVar* mtTtv = get(mt); - MetatableTypeVar mtv{target, mt}; + const TableType* mtTtv = get(mt); + MetatableType mtv{target, mt}; if ((tab->name || tab->syntheticName) && (mtTtv && (mtTtv->name || mtTtv->syntheticName))) { std::string tableName = tab->name ? *tab->name : *tab->syntheticName; @@ -656,7 +654,7 @@ static std::optional> magicFunctionSetMetaTable( return WithPredicate{arena.addTypePack({mtTy})}; } } - else if (get(target) || get(target) || isTableIntersection(target)) + else if (get(target) || get(target) || isTableIntersection(target)) { } else @@ -687,10 +685,10 @@ static std::optional> magicFunctionAssert( if (head.size() > 0) { - auto [ty, ok] = typechecker.pickTypesFromSense(head[0], true, typechecker.singletonTypes->nilType); + auto [ty, ok] = typechecker.pickTypesFromSense(head[0], true, typechecker.builtinTypes->nilType); if (FFlag::LuauUnknownAndNeverType) { - if (get(*ty)) + if (get(*ty)) head = {*ty}; else head[0] = *ty; @@ -747,10 +745,10 @@ static std::optional> magicFunctionPack( else if (options.size() == 1) result = options[0]; else - result = arena.addType(UnionTypeVar{std::move(options)}); + result = arena.addType(UnionType{std::move(options)}); - TypeId packedTable = arena.addType( - TableTypeVar{{{"n", {typechecker.numberType}}}, TableIndexer(typechecker.numberType, result), scope->level, TableState::Sealed}); + TypeId packedTable = + arena.addType(TableType{{{"n", {typechecker.numberType}}}, TableIndexer(typechecker.numberType, result), scope->level, TableState::Sealed}); return WithPredicate{arena.addTypePack({packedTable})}; } @@ -780,14 +778,14 @@ static bool dcrMagicFunctionPack(MagicFunctionCallContext context) // table.pack(1, "foo") -> {| n: number, [number]: number | string |} TypeId result = nullptr; if (options.empty()) - result = context.solver->singletonTypes->nilType; + result = context.solver->builtinTypes->nilType; else if (options.size() == 1) result = options[0]; else - result = arena->addType(UnionTypeVar{std::move(options)}); + result = arena->addType(UnionType{std::move(options)}); - TypeId numberType = context.solver->singletonTypes->numberType; - TypeId packedTable = arena->addType(TableTypeVar{{{"n", {numberType}}}, TableIndexer(numberType, result), {}, TableState::Sealed}); + TypeId numberType = context.solver->builtinTypes->numberType; + TypeId packedTable = arena->addType(TableType{{{"n", {numberType}}}, TableIndexer(numberType, result), {}, TableState::Sealed}); TypePackId tableTypePack = arena->addTypePack({packedTable}); asMutable(context.result)->ty.emplace(tableTypePack); diff --git a/Analysis/src/Clone.cpp b/Analysis/src/Clone.cpp index 86e1c7fc9..870d29490 100644 --- a/Analysis/src/Clone.cpp +++ b/Analysis/src/Clone.cpp @@ -48,21 +48,21 @@ struct TypeCloner void operator()(const Unifiable::Generic& t); void operator()(const Unifiable::Bound& t); void operator()(const Unifiable::Error& t); - void operator()(const BlockedTypeVar& t); - void operator()(const PendingExpansionTypeVar& t); - void operator()(const PrimitiveTypeVar& t); - void operator()(const SingletonTypeVar& t); - void operator()(const FunctionTypeVar& t); - void operator()(const TableTypeVar& t); - void operator()(const MetatableTypeVar& t); - void operator()(const ClassTypeVar& t); - void operator()(const AnyTypeVar& t); - void operator()(const UnionTypeVar& t); - void operator()(const IntersectionTypeVar& t); - void operator()(const LazyTypeVar& t); - void operator()(const UnknownTypeVar& t); - void operator()(const NeverTypeVar& t); - void operator()(const NegationTypeVar& t); + void operator()(const BlockedType& t); + void operator()(const PendingExpansionType& t); + void operator()(const PrimitiveType& t); + void operator()(const SingletonType& t); + void operator()(const FunctionType& t); + void operator()(const TableType& t); + void operator()(const MetatableType& t); + void operator()(const ClassType& t); + void operator()(const AnyType& t); + void operator()(const UnionType& t); + void operator()(const IntersectionType& t); + void operator()(const LazyType& t); + void operator()(const UnknownType& t); + void operator()(const NeverType& t); + void operator()(const NegationType& t); }; struct TypePackCloner @@ -107,7 +107,7 @@ struct TypePackCloner defaultClone(t); } - // While we are a-cloning, we can flatten out bound TypeVars and make things a bit tighter. + // While we are a-cloning, we can flatten out bound Types and make things a bit tighter. // We just need to be sure that we rewrite pointers both to the binder and the bindee to the same pointer. void operator()(const Unifiable::Bound& t) { @@ -159,7 +159,7 @@ void TypeCloner::operator()(const Unifiable::Bound& t) { TypeId boundTo = clone(t.boundTo, dest, cloneState); if (FFlag::DebugLuauCopyBeforeNormalizing) - boundTo = dest.addType(BoundTypeVar{boundTo}); + boundTo = dest.addType(BoundType{boundTo}); seenTypes[typeId] = boundTo; } @@ -168,15 +168,15 @@ void TypeCloner::operator()(const Unifiable::Error& t) defaultClone(t); } -void TypeCloner::operator()(const BlockedTypeVar& t) +void TypeCloner::operator()(const BlockedType& t) { defaultClone(t); } -void TypeCloner::operator()(const PendingExpansionTypeVar& t) +void TypeCloner::operator()(const PendingExpansionType& t) { - TypeId res = dest.addType(PendingExpansionTypeVar{t.prefix, t.name, t.typeArguments, t.packArguments}); - PendingExpansionTypeVar* petv = getMutable(res); + TypeId res = dest.addType(PendingExpansionType{t.prefix, t.name, t.typeArguments, t.packArguments}); + PendingExpansionType* petv = getMutable(res); LUAU_ASSERT(petv); seenTypes[typeId] = res; @@ -193,23 +193,23 @@ void TypeCloner::operator()(const PendingExpansionTypeVar& t) petv->packArguments = std::move(packArguments); } -void TypeCloner::operator()(const PrimitiveTypeVar& t) +void TypeCloner::operator()(const PrimitiveType& t) { defaultClone(t); } -void TypeCloner::operator()(const SingletonTypeVar& t) +void TypeCloner::operator()(const SingletonType& t) { defaultClone(t); } -void TypeCloner::operator()(const FunctionTypeVar& t) +void TypeCloner::operator()(const FunctionType& t) { // FISHY: We always erase the scope when we clone things. clone() was // originally written so that we could copy a module's type surface into an // export arena. This probably dates to that. - TypeId result = dest.addType(FunctionTypeVar{TypeLevel{0, 0}, {}, {}, nullptr, nullptr, t.definition, t.hasSelf}); - FunctionTypeVar* ftv = getMutable(result); + TypeId result = dest.addType(FunctionType{TypeLevel{0, 0}, {}, {}, nullptr, nullptr, t.definition, t.hasSelf}); + FunctionType* ftv = getMutable(result); LUAU_ASSERT(ftv != nullptr); seenTypes[typeId] = result; @@ -227,7 +227,7 @@ void TypeCloner::operator()(const FunctionTypeVar& t) ftv->hasNoGenerics = t.hasNoGenerics; } -void TypeCloner::operator()(const TableTypeVar& t) +void TypeCloner::operator()(const TableType& t) { // If table is now bound to another one, we ignore the content of the original if (!FFlag::DebugLuauCopyBeforeNormalizing && t.boundTo) @@ -237,8 +237,8 @@ void TypeCloner::operator()(const TableTypeVar& t) return; } - TypeId result = dest.addType(TableTypeVar{}); - TableTypeVar* ttv = getMutable(result); + TypeId result = dest.addType(TableType{}); + TableType* ttv = getMutable(result); LUAU_ASSERT(ttv != nullptr); *ttv = t; @@ -266,20 +266,20 @@ void TypeCloner::operator()(const TableTypeVar& t) ttv->tags = t.tags; } -void TypeCloner::operator()(const MetatableTypeVar& t) +void TypeCloner::operator()(const MetatableType& t) { - TypeId result = dest.addType(MetatableTypeVar{}); - MetatableTypeVar* mtv = getMutable(result); + TypeId result = dest.addType(MetatableType{}); + MetatableType* mtv = getMutable(result); seenTypes[typeId] = result; mtv->table = clone(t.table, dest, cloneState); mtv->metatable = clone(t.metatable, dest, cloneState); } -void TypeCloner::operator()(const ClassTypeVar& t) +void TypeCloner::operator()(const ClassType& t) { - TypeId result = dest.addType(ClassTypeVar{t.name, {}, std::nullopt, std::nullopt, t.tags, t.userData, t.definitionModuleName}); - ClassTypeVar* ctv = getMutable(result); + TypeId result = dest.addType(ClassType{t.name, {}, std::nullopt, std::nullopt, t.tags, t.userData, t.definitionModuleName}); + ClassType* ctv = getMutable(result); seenTypes[typeId] = result; @@ -293,12 +293,12 @@ void TypeCloner::operator()(const ClassTypeVar& t) ctv->metatable = clone(*t.metatable, dest, cloneState); } -void TypeCloner::operator()(const AnyTypeVar& t) +void TypeCloner::operator()(const AnyType& t) { defaultClone(t); } -void TypeCloner::operator()(const UnionTypeVar& t) +void TypeCloner::operator()(const UnionType& t) { std::vector options; options.reserve(t.options.size()); @@ -306,44 +306,44 @@ void TypeCloner::operator()(const UnionTypeVar& t) for (TypeId ty : t.options) options.push_back(clone(ty, dest, cloneState)); - TypeId result = dest.addType(UnionTypeVar{std::move(options)}); + TypeId result = dest.addType(UnionType{std::move(options)}); seenTypes[typeId] = result; } -void TypeCloner::operator()(const IntersectionTypeVar& t) +void TypeCloner::operator()(const IntersectionType& t) { - TypeId result = dest.addType(IntersectionTypeVar{}); + TypeId result = dest.addType(IntersectionType{}); seenTypes[typeId] = result; - IntersectionTypeVar* option = getMutable(result); + IntersectionType* option = getMutable(result); LUAU_ASSERT(option != nullptr); for (TypeId ty : t.parts) option->parts.push_back(clone(ty, dest, cloneState)); } -void TypeCloner::operator()(const LazyTypeVar& t) +void TypeCloner::operator()(const LazyType& t) { defaultClone(t); } -void TypeCloner::operator()(const UnknownTypeVar& t) +void TypeCloner::operator()(const UnknownType& t) { defaultClone(t); } -void TypeCloner::operator()(const NeverTypeVar& t) +void TypeCloner::operator()(const NeverType& t) { defaultClone(t); } -void TypeCloner::operator()(const NegationTypeVar& t) +void TypeCloner::operator()(const NegationType& t) { - TypeId result = dest.addType(AnyTypeVar{}); + TypeId result = dest.addType(AnyType{}); seenTypes[typeId] = result; TypeId ty = clone(t.ty, dest, cloneState); - asMutable(result)->ty = NegationTypeVar{ty}; + asMutable(result)->ty = NegationType{ty}; } } // anonymous namespace @@ -430,9 +430,9 @@ TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool alwaysCl if (auto pty = log->pending(ty)) ty = &pty->pending; - if (const FunctionTypeVar* ftv = get(ty)) + if (const FunctionType* ftv = get(ty)) { - FunctionTypeVar clone = FunctionTypeVar{ftv->level, ftv->scope, ftv->argTypes, ftv->retTypes, ftv->definition, ftv->hasSelf}; + FunctionType clone = FunctionType{ftv->level, ftv->scope, ftv->argTypes, ftv->retTypes, ftv->definition, ftv->hasSelf}; clone.generics = ftv->generics; clone.genericPacks = ftv->genericPacks; clone.magicFunction = ftv->magicFunction; @@ -441,10 +441,10 @@ TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool alwaysCl clone.argNames = ftv->argNames; result = dest.addType(std::move(clone)); } - else if (const TableTypeVar* ttv = get(ty)) + else if (const TableType* ttv = get(ty)) { LUAU_ASSERT(!ttv->boundTo); - TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, ttv->level, ttv->scope, ttv->state}; + TableType clone = TableType{ttv->props, ttv->indexer, ttv->level, ttv->scope, ttv->state}; clone.definitionModuleName = ttv->definitionModuleName; clone.name = ttv->name; clone.syntheticName = ttv->syntheticName; @@ -453,41 +453,41 @@ TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool alwaysCl clone.tags = ttv->tags; result = dest.addType(std::move(clone)); } - else if (const MetatableTypeVar* mtv = get(ty)) + else if (const MetatableType* mtv = get(ty)) { - MetatableTypeVar clone = MetatableTypeVar{mtv->table, mtv->metatable}; + MetatableType clone = MetatableType{mtv->table, mtv->metatable}; clone.syntheticName = mtv->syntheticName; result = dest.addType(std::move(clone)); } - else if (const UnionTypeVar* utv = get(ty)) + else if (const UnionType* utv = get(ty)) { - UnionTypeVar clone; + UnionType clone; clone.options = utv->options; result = dest.addType(std::move(clone)); } - else if (const IntersectionTypeVar* itv = get(ty)) + else if (const IntersectionType* itv = get(ty)) { - IntersectionTypeVar clone; + IntersectionType clone; clone.parts = itv->parts; result = dest.addType(std::move(clone)); } - else if (const PendingExpansionTypeVar* petv = get(ty)) + else if (const PendingExpansionType* petv = get(ty)) { - PendingExpansionTypeVar clone{petv->prefix, petv->name, petv->typeArguments, petv->packArguments}; + PendingExpansionType clone{petv->prefix, petv->name, petv->typeArguments, petv->packArguments}; result = dest.addType(std::move(clone)); } - else if (const ClassTypeVar* ctv = get(ty); FFlag::LuauClonePublicInterfaceLess && ctv && alwaysClone) + else if (const ClassType* ctv = get(ty); FFlag::LuauClonePublicInterfaceLess && ctv && alwaysClone) { - ClassTypeVar clone{ctv->name, ctv->props, ctv->parent, ctv->metatable, ctv->tags, ctv->userData, ctv->definitionModuleName}; + ClassType clone{ctv->name, ctv->props, ctv->parent, ctv->metatable, ctv->tags, ctv->userData, ctv->definitionModuleName}; result = dest.addType(std::move(clone)); } else if (FFlag::LuauClonePublicInterfaceLess && alwaysClone) { result = dest.addType(*ty); } - else if (const NegationTypeVar* ntv = get(ty)) + else if (const NegationType* ntv = get(ty)) { - result = dest.addType(NegationTypeVar{ntv->ty}); + result = dest.addType(NegationType{ntv->ty}); } else return result; diff --git a/Analysis/src/ConstraintGraphBuilder.cpp b/Analysis/src/ConstraintGraphBuilder.cpp index f0bd958cf..256eba54d 100644 --- a/Analysis/src/ConstraintGraphBuilder.cpp +++ b/Analysis/src/ConstraintGraphBuilder.cpp @@ -9,11 +9,12 @@ #include "Luau/RecursionCounter.h" #include "Luau/Scope.h" #include "Luau/TypeUtils.h" -#include "Luau/TypeVar.h" +#include "Luau/Type.h" LUAU_FASTINT(LuauCheckRecursionLimit); LUAU_FASTFLAG(DebugLuauLogSolverToJson); LUAU_FASTFLAG(DebugLuauMagicTypes); +LUAU_FASTFLAG(LuauNegatedClassTypes); namespace Luau { @@ -116,11 +117,11 @@ void forEachConstraint(const Checkpoint& start, const Checkpoint& end, const Con } // namespace ConstraintGraphBuilder::ConstraintGraphBuilder(const ModuleName& moduleName, ModulePtr module, TypeArena* arena, - NotNull moduleResolver, NotNull singletonTypes, NotNull ice, const ScopePtr& globalScope, + NotNull moduleResolver, NotNull builtinTypes, NotNull ice, const ScopePtr& globalScope, DcrLogger* logger, NotNull dfg) : moduleName(moduleName) , module(module) - , singletonTypes(singletonTypes) + , builtinTypes(builtinTypes) , arena(arena) , rootScope(nullptr) , dfg(dfg) @@ -137,7 +138,7 @@ ConstraintGraphBuilder::ConstraintGraphBuilder(const ModuleName& moduleName, Mod TypeId ConstraintGraphBuilder::freshType(const ScopePtr& scope) { - return arena->addType(FreeTypeVar{scope.get()}); + return arena->addType(FreeType{scope.get()}); } TypePackId ConstraintGraphBuilder::freshTypePack(const ScopePtr& scope) @@ -184,7 +185,7 @@ static void unionRefinements(const std::unordered_map& lhs, const if (auto destIt = dest.find(def); destIt != dest.end()) discriminants.push_back(destIt->second); - dest[def] = arena->addType(UnionTypeVar{std::move(discriminants)}); + dest[def] = arena->addType(UnionType{std::move(discriminants)}); } } @@ -228,15 +229,15 @@ static void computeRefinement(const ScopePtr& scope, ConnectiveId connective, st { TypeId discriminantTy = proposition->discriminantTy; if (!sense && !eq) - discriminantTy = arena->addType(NegationTypeVar{proposition->discriminantTy}); + discriminantTy = arena->addType(NegationType{proposition->discriminantTy}); else if (eq) { - discriminantTy = arena->addType(BlockedTypeVar{}); + discriminantTy = arena->addType(BlockedType{}); constraints->push_back(SingletonOrTopTypeConstraint{discriminantTy, proposition->discriminantTy, !sense}); } if (auto it = refis->find(proposition->def); it != refis->end()) - (*refis)[proposition->def] = arena->addType(IntersectionTypeVar{{discriminantTy, it->second}}); + (*refis)[proposition->def] = arena->addType(IntersectionType{{discriminantTy, it->second}}); else (*refis)[proposition->def] = discriminantTy; } @@ -251,8 +252,8 @@ static std::pair computeDiscriminantType(NotNull arena if (!current->field) break; - TableTypeVar::Props props{{current->field->propName, Property{discriminantTy}}}; - discriminantTy = arena->addType(TableTypeVar{std::move(props), std::nullopt, TypeLevel{}, scope.get(), TableState::Sealed}); + TableType::Props props{{current->field->propName, Property{discriminantTy}}}; + discriminantTy = arena->addType(TableType{std::move(props), std::nullopt, TypeLevel{}, scope.get(), TableState::Sealed}); def = current->field->parent; current = get(def); @@ -277,7 +278,7 @@ void ConstraintGraphBuilder::applyRefinements(const ScopePtr& scope, Location lo if (!defTy) ice->ice("Every DefId must map to a type!"); - TypeId resultTy = arena->addType(IntersectionTypeVar{{*defTy, discriminantTy2}}); + TypeId resultTy = arena->addType(IntersectionType{{*defTy, discriminantTy2}}); scope->dcrRefinements[def2] = resultTy; } @@ -464,7 +465,7 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocal* local) if (i < local->vars.size) { - TypePack packTypes = extendTypePack(*arena, singletonTypes, exprPack, varTypes.size() - i); + TypePack packTypes = extendTypePack(*arena, builtinTypes, exprPack, varTypes.size() - i); // fill out missing values in varTypes with values from exprPack for (size_t j = i; j < varTypes.size(); ++j) @@ -533,7 +534,7 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFor* for_) return; TypeId t = check(scope, expr).ty; - addConstraint(scope, expr->location, SubtypeConstraint{t, singletonTypes->numberType}); + addConstraint(scope, expr->location, SubtypeConstraint{t, builtinTypes->numberType}); }; checkNumber(for_->from); @@ -541,7 +542,7 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFor* for_) checkNumber(for_->step); ScopePtr forScope = childScope(for_, scope); - forScope->bindings[for_->var] = Binding{singletonTypes->numberType, for_->var->location}; + forScope->bindings[for_->var] = Binding{builtinTypes->numberType, for_->var->location}; visit(forScope, for_->body); } @@ -603,7 +604,7 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocalFunction* auto ty = scope->lookup(function->name); LUAU_ASSERT(!ty.has_value()); // The parser ensures that every local function has a distinct Symbol for its name. - functionType = arena->addType(BlockedTypeVar{}); + functionType = arena->addType(BlockedType{}); scope->bindings[function->name] = Binding{functionType, function->name->location}; FunctionSignature sig = checkFunctionSignature(scope, function->func); @@ -629,7 +630,7 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFunction* funct // Name could be AstStatLocal, AstStatGlobal, AstStatIndexName. // With or without self - TypeId generalizedType = arena->addType(BlockedTypeVar{}); + TypeId generalizedType = arena->addType(BlockedType{}); FunctionSignature sig = checkFunctionSignature(scope, function->func); @@ -666,9 +667,9 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFunction* funct TypeId containingTableType = check(scope, indexName->expr).ty; // TODO look into stack utilization. This is probably ok because it scales with AST depth. - TypeId prospectiveTableType = arena->addType(TableTypeVar{TableState::Unsealed, TypeLevel{}, scope.get()}); + TypeId prospectiveTableType = arena->addType(TableType{TableState::Unsealed, TypeLevel{}, scope.get()}); - NotNull prospectiveTable{getMutable(prospectiveTableType)}; + NotNull prospectiveTable{getMutable(prospectiveTableType)}; Property& prop = prospectiveTable->props[indexName->index.value]; prop.type = generalizedType; @@ -678,7 +679,7 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFunction* funct } else if (AstExprError* err = function->name->as()) { - generalizedType = singletonTypes->errorRecoveryType(); + generalizedType = builtinTypes->errorRecoveryType(); } if (generalizedType == nullptr) @@ -724,7 +725,7 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatAssign* assign) { TypePackId varPackId = checkLValues(scope, assign->vars); - TypePack expectedTypes = extendTypePack(*arena, singletonTypes, varPackId, assign->values.size); + TypePack expectedTypes = extendTypePack(*arena, builtinTypes, varPackId, assign->values.size); TypePackId valuePack = checkPack(scope, assign->values, expectedTypes.head).tp; addConstraint(scope, assign->location, PackSubtypeConstraint{valuePack, varPackId}); @@ -781,13 +782,13 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatTypeAlias* alia scope->exportedTypeBindings[typeName] = TypeFun{ty}; } - LUAU_ASSERT(get(bindingIt->second.type)); + LUAU_ASSERT(get(bindingIt->second.type)); // Rather than using a subtype constraint, we instead directly bind // the free type we generated in the first pass to the resolved type. // This prevents a case where you could cause another constraint to // bind the free alias type to an unrelated type, causing havoc. - asMutable(bindingIt->second.type)->ty.emplace(ty); + asMutable(bindingIt->second.type)->ty.emplace(ty); addConstraint(scope, alias->location, NameConstraint{ty, alias->name.value}); } @@ -812,7 +813,7 @@ static bool isMetamethod(const Name& name) void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareClass* declaredClass) { - std::optional superTy = std::nullopt; + std::optional superTy = FFlag::LuauNegatedClassTypes ? std::make_optional(builtinTypes->classType) : std::nullopt; if (declaredClass->superName) { Name superName = Name(declaredClass->superName->value); @@ -828,7 +829,7 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareClass* d LUAU_ASSERT(lookupType->typeParams.size() == 0 && lookupType->typePackParams.size() == 0); superTy = lookupType->type; - if (!get(follow(*superTy))) + if (!get(follow(*superTy))) { reportError(declaredClass->location, GenericError{format("Cannot use non-class type '%s' as a superclass of class '%s'", superName.c_str(), declaredClass->name.value)}); @@ -839,11 +840,11 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareClass* d Name className(declaredClass->name.value); - TypeId classTy = arena->addType(ClassTypeVar(className, {}, superTy, std::nullopt, {}, {}, moduleName)); - ClassTypeVar* ctv = getMutable(classTy); + TypeId classTy = arena->addType(ClassType(className, {}, superTy, std::nullopt, {}, {}, moduleName)); + ClassType* ctv = getMutable(classTy); - TypeId metaTy = arena->addType(TableTypeVar{TableState::Sealed, scope->level, scope.get()}); - TableTypeVar* metatable = getMutable(metaTy); + TypeId metaTy = arena->addType(TableType{TableState::Sealed, scope->level, scope.get()}); + TableType* metatable = getMutable(metaTy); ctv->metatable = metaTy; @@ -860,7 +861,7 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareClass* d // parsed annotation. Add it here. if (prop.isMethod) { - if (FunctionTypeVar* ftv = getMutable(propTy)) + if (FunctionType* ftv = getMutable(propTy)) { ftv->argNames.insert(ftv->argNames.begin(), FunctionArgument{"self", {}}); ftv->argTypes = arena->addTypePack(TypePack{{classTy}, ftv->argTypes}); @@ -882,20 +883,20 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareClass* d // We special-case this logic to keep the intersection flat; otherwise we // would create a ton of nested intersection types. - if (const IntersectionTypeVar* itv = get(currentTy)) + if (const IntersectionType* itv = get(currentTy)) { std::vector options = itv->parts; options.push_back(propTy); - TypeId newItv = arena->addType(IntersectionTypeVar{std::move(options)}); + TypeId newItv = arena->addType(IntersectionType{std::move(options)}); if (assignToMetatable) metatable->props[propName] = {newItv}; else ctv->props[propName] = {newItv}; } - else if (get(currentTy)) + else if (get(currentTy)) { - TypeId intersection = arena->addType(IntersectionTypeVar{{currentTy, propTy}}); + TypeId intersection = arena->addType(IntersectionType{{currentTy, propTy}}); if (assignToMetatable) metatable->props[propName] = {intersection}; @@ -937,8 +938,8 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareFunction TypePackId paramPack = resolveTypePack(funScope, global->params); TypePackId retPack = resolveTypePack(funScope, global->retTypes); - TypeId fnType = arena->addType(FunctionTypeVar{TypeLevel{}, funScope.get(), std::move(genericTys), std::move(genericTps), paramPack, retPack}); - FunctionTypeVar* ftv = getMutable(fnType); + TypeId fnType = arena->addType(FunctionType{TypeLevel{}, funScope.get(), std::move(genericTys), std::move(genericTps), paramPack, retPack}); + FunctionType* ftv = getMutable(fnType); ftv->argNames.reserve(global->paramNames.size); for (const auto& el : global->paramNames) @@ -995,7 +996,7 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExpr* if (recursionCount >= FInt::LuauCheckRecursionLimit) { reportCodeTooComplex(expr->location); - return InferencePack{singletonTypes->errorRecoveryTypePack()}; + return InferencePack{builtinTypes->errorRecoveryTypePack()}; } InferencePack result; @@ -1007,7 +1008,7 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExpr* if (scope->varargPack) result = InferencePack{*scope->varargPack}; else - result = InferencePack{singletonTypes->errorRecoveryTypePack()}; + result = InferencePack{builtinTypes->errorRecoveryTypePack()}; } else { @@ -1042,9 +1043,9 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCa TypePackId expectedArgPack = arena->freshTypePack(scope.get()); TypePackId expectedRetPack = arena->freshTypePack(scope.get()); - TypeId expectedFunctionType = arena->addType(FunctionTypeVar{expectedArgPack, expectedRetPack}); + TypeId expectedFunctionType = arena->addType(FunctionType{expectedArgPack, expectedRetPack}); - TypeId instantiatedFnType = arena->addType(BlockedTypeVar{}); + TypeId instantiatedFnType = arena->addType(BlockedType{}); addConstraint(scope, call->location, InstantiationConstraint{instantiatedFnType, fnType}); NotNull extractArgsConstraint = addConstraint(scope, call->location, SubtypeConstraint{instantiatedFnType, expectedFunctionType}); @@ -1060,9 +1061,9 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCa TypePack expectedArgs; if (!needTail) - expectedArgs = extendTypePack(*arena, singletonTypes, expectedArgPack, exprArgs.size()); + expectedArgs = extendTypePack(*arena, builtinTypes, expectedArgPack, exprArgs.size()); else - expectedArgs = extendTypePack(*arena, singletonTypes, expectedArgPack, exprArgs.size() - 1); + expectedArgs = extendTypePack(*arena, builtinTypes, expectedArgPack, exprArgs.size() - 1); std::vector args; std::optional argTail; @@ -1108,7 +1109,7 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCa }); std::vector returnConnectives; - if (auto ftv = get(follow(fnType)); ftv && ftv->dcrMagicRefinement) + if (auto ftv = get(follow(fnType)); ftv && ftv->dcrMagicRefinement) { MagicRefinementContext ctx{scope, NotNull{this}, dfg, NotNull{&connectiveArena}, std::move(argumentConnectives), call}; returnConnectives = ftv->dcrMagicRefinement(ctx); @@ -1118,7 +1119,7 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCa { TypePack argTailPack; if (argTail && args.size() < 2) - argTailPack = extendTypePack(*arena, singletonTypes, *argTail, 2 - args.size()); + argTailPack = extendTypePack(*arena, builtinTypes, *argTail, 2 - args.size()); LUAU_ASSERT(args.size() + argTailPack.head.size() == 2); @@ -1127,7 +1128,7 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCa AstExpr* targetExpr = call->args.data[0]; - MetatableTypeVar mtv{target, mt}; + MetatableType mtv{target, mt}; TypeId resultTy = arena->addType(mtv); if (AstExprLocal* targetLocal = targetExpr->as()) @@ -1139,11 +1140,11 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCa { astOriginalCallTypes[call->func] = fnType; - TypeId instantiatedType = arena->addType(BlockedTypeVar{}); + TypeId instantiatedType = arena->addType(BlockedType{}); // TODO: How do expectedTypes play into this? Do they? TypePackId rets = arena->addTypePack(BlockedTypePack{}); TypePackId argPack = arena->addTypePack(TypePack{args, argTail}); - FunctionTypeVar ftv(TypeLevel{}, scope.get(), argPack, rets); + FunctionType ftv(TypeLevel{}, scope.get(), argPack, rets); TypeId inferredFnType = arena->addType(ftv); unqueuedConstraints.push_back( @@ -1183,7 +1184,7 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExpr* expr, st if (recursionCount >= FInt::LuauCheckRecursionLimit) { reportCodeTooComplex(expr->location); - return Inference{singletonTypes->errorRecoveryType()}; + return Inference{builtinTypes->errorRecoveryType()}; } Inference result; @@ -1193,11 +1194,11 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExpr* expr, st else if (auto stringExpr = expr->as()) result = check(scope, stringExpr, expectedType, forceSingleton); else if (expr->is()) - result = Inference{singletonTypes->numberType}; + result = Inference{builtinTypes->numberType}; else if (auto boolExpr = expr->as()) result = check(scope, boolExpr, expectedType, forceSingleton); else if (expr->is()) - result = Inference{singletonTypes->nilType}; + result = Inference{builtinTypes->nilType}; else if (auto local = expr->as()) result = check(scope, local); else if (auto global = expr->as()) @@ -1218,7 +1219,7 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExpr* expr, st checkFunctionBody(sig.bodyScope, a); Checkpoint endCheckpoint = checkpoint(this); - TypeId generalizedTy = arena->addType(BlockedTypeVar{}); + TypeId generalizedTy = arena->addType(BlockedType{}); NotNull gc = addConstraint(scope, expr->location, GeneralizationConstraint{generalizedTy, sig.signature}); forEachConstraint(startCheckpoint, endCheckpoint, this, [gc](const ConstraintPtr& constraint) { @@ -1247,7 +1248,7 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExpr* expr, st for (AstExpr* subExpr : err->expressions) check(scope, subExpr); - result = Inference{singletonTypes->errorRecoveryType()}; + result = Inference{builtinTypes->errorRecoveryType()}; } else { @@ -1263,30 +1264,30 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExpr* expr, st Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprConstantString* string, std::optional expectedType, bool forceSingleton) { if (forceSingleton) - return Inference{arena->addType(SingletonTypeVar{StringSingleton{std::string{string->value.data, string->value.size}}})}; + return Inference{arena->addType(SingletonType{StringSingleton{std::string{string->value.data, string->value.size}}})}; if (expectedType) { const TypeId expectedTy = follow(*expectedType); - if (get(expectedTy) || get(expectedTy)) + if (get(expectedTy) || get(expectedTy)) { - TypeId ty = arena->addType(BlockedTypeVar{}); - TypeId singletonType = arena->addType(SingletonTypeVar(StringSingleton{std::string(string->value.data, string->value.size)})); - addConstraint(scope, string->location, PrimitiveTypeConstraint{ty, expectedTy, singletonType, singletonTypes->stringType}); + TypeId ty = arena->addType(BlockedType{}); + TypeId singletonType = arena->addType(SingletonType(StringSingleton{std::string(string->value.data, string->value.size)})); + addConstraint(scope, string->location, PrimitiveTypeConstraint{ty, expectedTy, singletonType, builtinTypes->stringType}); return Inference{ty}; } else if (maybeSingleton(expectedTy)) - return Inference{arena->addType(SingletonTypeVar{StringSingleton{std::string{string->value.data, string->value.size}}})}; + return Inference{arena->addType(SingletonType{StringSingleton{std::string{string->value.data, string->value.size}}})}; - return Inference{singletonTypes->stringType}; + return Inference{builtinTypes->stringType}; } - return Inference{singletonTypes->stringType}; + return Inference{builtinTypes->stringType}; } Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprConstantBool* boolExpr, std::optional expectedType, bool forceSingleton) { - const TypeId singletonType = boolExpr->value ? singletonTypes->trueType : singletonTypes->falseType; + const TypeId singletonType = boolExpr->value ? builtinTypes->trueType : builtinTypes->falseType; if (forceSingleton) return Inference{singletonType}; @@ -1294,19 +1295,19 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprConstantBo { const TypeId expectedTy = follow(*expectedType); - if (get(expectedTy) || get(expectedTy)) + if (get(expectedTy) || get(expectedTy)) { - TypeId ty = arena->addType(BlockedTypeVar{}); - addConstraint(scope, boolExpr->location, PrimitiveTypeConstraint{ty, expectedTy, singletonType, singletonTypes->booleanType}); + TypeId ty = arena->addType(BlockedType{}); + addConstraint(scope, boolExpr->location, PrimitiveTypeConstraint{ty, expectedTy, singletonType, builtinTypes->booleanType}); return Inference{ty}; } else if (maybeSingleton(expectedTy)) return Inference{singletonType}; - return Inference{singletonTypes->booleanType}; + return Inference{builtinTypes->booleanType}; } - return Inference{singletonTypes->booleanType}; + return Inference{builtinTypes->booleanType}; } Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprLocal* local) @@ -1323,10 +1324,10 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprLocal* loc } if (!resultTy) - return Inference{singletonTypes->errorRecoveryType()}; // TODO: replace with ice, locals should never exist before its definition. + return Inference{builtinTypes->errorRecoveryType()}; // TODO: replace with ice, locals should never exist before its definition. if (def) - return Inference{*resultTy, connectiveArena.proposition(*def, singletonTypes->truthyType)}; + return Inference{*resultTy, connectiveArena.proposition(*def, builtinTypes->truthyType)}; else return Inference{*resultTy}; } @@ -1340,24 +1341,24 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprGlobal* gl * global that is not already in-scope is definitely an unknown symbol. */ reportError(global->location, UnknownSymbol{global->name.value}); - return Inference{singletonTypes->errorRecoveryType()}; + return Inference{builtinTypes->errorRecoveryType()}; } static std::optional lookupProp(TypeId ty, const std::string& propName, NotNull arena) { ty = follow(ty); - if (auto ctv = get(ty)) + if (auto ctv = get(ty)) { if (auto prop = lookupClassProp(ctv, propName)) return prop->type; } - else if (auto ttv = get(ty)) + else if (auto ttv = get(ty)) { if (auto it = ttv->props.find(propName); it != ttv->props.end()) return it->second.type; } - else if (auto utv = get(ty)) + else if (auto utv = get(ty)) { std::vector types; @@ -1375,9 +1376,9 @@ static std::optional lookupProp(TypeId ty, const std::string& propName, if (types.size() == 1) return types[0]; else - return arena->addType(IntersectionTypeVar{std::move(types)}); + return arena->addType(IntersectionType{std::move(types)}); } - else if (auto utv = get(ty)) + else if (auto utv = get(ty)) { std::vector types; @@ -1395,7 +1396,7 @@ static std::optional lookupProp(TypeId ty, const std::string& propName, if (types.size() == 1) return types[0]; else - return arena->addType(UnionTypeVar{std::move(types)}); + return arena->addType(UnionType{std::move(types)}); } return std::nullopt; @@ -1416,21 +1417,21 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIndexName* if (def) { if (auto ty = scope->lookup(*def)) - return Inference{*ty, connectiveArena.proposition(*def, singletonTypes->truthyType)}; + return Inference{*ty, connectiveArena.proposition(*def, builtinTypes->truthyType)}; else scope->dcrRefinements[*def] = result; } - TableTypeVar::Props props{{indexName->index.value, Property{result}}}; + TableType::Props props{{indexName->index.value, Property{result}}}; const std::optional indexer; - TableTypeVar ttv{std::move(props), indexer, TypeLevel{}, scope.get(), TableState::Free}; + TableType ttv{std::move(props), indexer, TypeLevel{}, scope.get(), TableState::Free}; TypeId expectedTableType = arena->addType(std::move(ttv)); addConstraint(scope, indexName->expr->location, SubtypeConstraint{obj, expectedTableType}); if (def) - return Inference{result, connectiveArena.proposition(*def, singletonTypes->truthyType)}; + return Inference{result, connectiveArena.proposition(*def, builtinTypes->truthyType)}; else return Inference{result}; } @@ -1443,8 +1444,7 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIndexExpr* TypeId result = freshType(scope); TableIndexer indexer{indexType, result}; - TypeId tableType = - arena->addType(TableTypeVar{TableTypeVar::Props{}, TableIndexer{indexType, result}, TypeLevel{}, scope.get(), TableState::Free}); + TypeId tableType = arena->addType(TableType{TableType::Props{}, TableIndexer{indexType, result}, TypeLevel{}, scope.get(), TableState::Free}); addConstraint(scope, indexExpr->expr->location, SubtypeConstraint{obj, tableType}); @@ -1454,7 +1454,7 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIndexExpr* Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprUnary* unary) { auto [operandType, connective] = check(scope, unary->expr); - TypeId resultType = arena->addType(BlockedTypeVar{}); + TypeId resultType = arena->addType(BlockedType{}); addConstraint(scope, unary->location, UnaryConstraint{unary->op, operandType, resultType}); if (unary->op == AstExprUnary::Not) @@ -1467,8 +1467,9 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprBinary* bi { auto [leftType, rightType, connective] = checkBinary(scope, binary, expectedType); - TypeId resultType = arena->addType(BlockedTypeVar{}); - addConstraint(scope, binary->location, BinaryConstraint{binary->op, leftType, rightType, resultType, binary, &astOriginalCallTypes, &astOverloadResolvedTypes}); + TypeId resultType = arena->addType(BlockedType{}); + addConstraint(scope, binary->location, + BinaryConstraint{binary->op, leftType, rightType, resultType, binary, &astOriginalCallTypes, &astOverloadResolvedTypes}); return Inference{resultType, std::move(connective)}; } @@ -1534,34 +1535,34 @@ std::tuple ConstraintGraphBuilder::checkBinary( if (!def) return {leftType, rightType, nullptr}; - TypeId discriminantTy = singletonTypes->neverType; + TypeId discriminantTy = builtinTypes->neverType; if (typeguard->type == "nil") - discriminantTy = singletonTypes->nilType; + discriminantTy = builtinTypes->nilType; else if (typeguard->type == "string") - discriminantTy = singletonTypes->stringType; + discriminantTy = builtinTypes->stringType; else if (typeguard->type == "number") - discriminantTy = singletonTypes->numberType; + discriminantTy = builtinTypes->numberType; else if (typeguard->type == "boolean") - discriminantTy = singletonTypes->threadType; + discriminantTy = builtinTypes->threadType; else if (typeguard->type == "table") - discriminantTy = singletonTypes->neverType; // TODO: replace with top table type + discriminantTy = builtinTypes->neverType; // TODO: replace with top table type else if (typeguard->type == "function") - discriminantTy = singletonTypes->functionType; + discriminantTy = builtinTypes->functionType; else if (typeguard->type == "userdata") { // For now, we don't really care about being accurate with userdata if the typeguard was using typeof - discriminantTy = singletonTypes->neverType; // TODO: replace with top class type + discriminantTy = builtinTypes->neverType; // TODO: replace with top class type } else if (!typeguard->isTypeof && typeguard->type == "vector") - discriminantTy = singletonTypes->neverType; // TODO: figure out a way to deal with this quirky type + discriminantTy = builtinTypes->neverType; // TODO: figure out a way to deal with this quirky type else if (!typeguard->isTypeof) - discriminantTy = singletonTypes->neverType; + discriminantTy = builtinTypes->neverType; else if (auto typeFun = globalScope->lookupType(typeguard->type); typeFun && typeFun->typeParams.empty() && typeFun->typePackParams.empty()) { TypeId ty = follow(typeFun->type); // We're only interested in the root class of any classes. - if (auto ctv = get(ty); !ctv || !ctv->parent) + if (auto ctv = get(ty); !ctv || !ctv->parent) discriminantTy = ty; } @@ -1685,7 +1686,7 @@ TypeId ConstraintGraphBuilder::checkLValue(const ScopePtr& scope, AstExpr* expr) std::vector segmentStrings(begin(segments), end(segments)); - TypeId updatedType = arena->addType(BlockedTypeVar{}); + TypeId updatedType = arena->addType(BlockedType{}); addConstraint(scope, expr->location, SetPropConstraint{updatedType, subjectType, std::move(segmentStrings), propTy}); std::optional def = dfg->getDef(sym); @@ -1700,8 +1701,8 @@ TypeId ConstraintGraphBuilder::checkLValue(const ScopePtr& scope, AstExpr* expr) Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprTable* expr, std::optional expectedType) { - TypeId ty = arena->addType(TableTypeVar{}); - TableTypeVar* ttv = getMutable(ty); + TypeId ty = arena->addType(TableType{}); + TableType* ttv = getMutable(ty); LUAU_ASSERT(ttv); ttv->state = TableState::Unsealed; @@ -1729,12 +1730,12 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprTable* exp { ErrorVec errorVec; std::optional propTy = - findTablePropertyRespectingMeta(singletonTypes, errorVec, follow(*expectedType), stringKey->value.data, item.value->location); + findTablePropertyRespectingMeta(builtinTypes, errorVec, follow(*expectedType), stringKey->value.data, item.value->location); if (propTy) expectedValueType = propTy; else { - expectedValueType = arena->addType(BlockedTypeVar{}); + expectedValueType = arena->addType(BlockedType{}); addConstraint(scope, item.value->location, HasPropConstraint{*expectedValueType, *expectedType, stringKey->value.data}); } } @@ -1760,7 +1761,7 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprTable* exp } else { - TypeId numberType = singletonTypes->numberType; + TypeId numberType = builtinTypes->numberType; // FIXME? The location isn't quite right here. Not sure what is // right. createIndexer(item.value->location, numberType, itemTy); @@ -1821,11 +1822,11 @@ ConstraintGraphBuilder::FunctionSignature ConstraintGraphBuilder::checkFunctionS std::vector argTypes; TypePack expectedArgPack; - const FunctionTypeVar* expectedFunction = expectedType ? get(*expectedType) : nullptr; + const FunctionType* expectedFunction = expectedType ? get(*expectedType) : nullptr; if (expectedFunction) { - expectedArgPack = extendTypePack(*arena, singletonTypes, expectedFunction->argTypes, fn->args.size); + expectedArgPack = extendTypePack(*arena, builtinTypes, expectedFunction->argTypes, fn->args.size); genericTypes = expectedFunction->generics; genericTypePacks = expectedFunction->genericPacks; @@ -1870,14 +1871,14 @@ ConstraintGraphBuilder::FunctionSignature ConstraintGraphBuilder::checkFunctionS else if (expectedArgPack.tail && get(*expectedArgPack.tail)) varargPack = *expectedArgPack.tail; else - varargPack = singletonTypes->anyTypePack; + varargPack = builtinTypes->anyTypePack; signatureScope->varargPack = varargPack; bodyScope->varargPack = varargPack; } else { - varargPack = arena->addTypePack(VariadicTypePack{singletonTypes->anyType, /*hidden*/ true}); + varargPack = arena->addTypePack(VariadicTypePack{builtinTypes->anyType, /*hidden*/ true}); // We do not add to signatureScope->varargPack because ... is not valid // in functions without an explicit ellipsis. @@ -1906,7 +1907,7 @@ ConstraintGraphBuilder::FunctionSignature ConstraintGraphBuilder::checkFunctionS // TODO: Preserve argument names in the function's type. - FunctionTypeVar actualFunction{TypeLevel{}, parent.get(), arena->addTypePack(argTypes, varargPack), returnType}; + FunctionType actualFunction{TypeLevel{}, parent.get(), arena->addTypePack(argTypes, varargPack), returnType}; actualFunction.hasNoGenerics = !hasGenerics; actualFunction.generics = std::move(genericTypes); actualFunction.genericPacks = std::move(genericTypePacks); @@ -1915,9 +1916,9 @@ ConstraintGraphBuilder::FunctionSignature ConstraintGraphBuilder::checkFunctionS LUAU_ASSERT(actualFunctionType); astTypes[fn] = actualFunctionType; - if (expectedType && get(*expectedType)) + if (expectedType && get(*expectedType)) { - asMutable(*expectedType)->ty.emplace(actualFunctionType); + asMutable(*expectedType)->ty.emplace(actualFunctionType); } return { @@ -1955,7 +1956,7 @@ TypeId ConstraintGraphBuilder::resolveType(const ScopePtr& scope, AstType* ty, b if (ref->parameters.size != 1 || !ref->parameters.data[0].type) { reportError(ty->location, GenericError{"_luau_print requires one generic parameter"}); - return singletonTypes->errorRecoveryType(); + return builtinTypes->errorRecoveryType(); } else return resolveType(scope, ref->parameters.data[0].type, topLevel); @@ -2006,7 +2007,7 @@ TypeId ConstraintGraphBuilder::resolveType(const ScopePtr& scope, AstType* ty, b } } - result = arena->addType(PendingExpansionTypeVar{ref->prefix, ref->name, parameters, packParameters}); + result = arena->addType(PendingExpansionType{ref->prefix, ref->name, parameters, packParameters}); if (topLevel) { @@ -2021,12 +2022,12 @@ TypeId ConstraintGraphBuilder::resolveType(const ScopePtr& scope, AstType* ty, b typeName = std::string(ref->prefix->value) + "."; typeName += ref->name.value; - result = singletonTypes->errorRecoveryType(); + result = builtinTypes->errorRecoveryType(); } } else if (auto tab = ty->as()) { - TableTypeVar::Props props; + TableType::Props props; std::optional indexer; for (const AstTableProp& prop : tab->props) @@ -2047,7 +2048,7 @@ TypeId ConstraintGraphBuilder::resolveType(const ScopePtr& scope, AstType* ty, b }; } - result = arena->addType(TableTypeVar{props, indexer, scope->level, scope.get(), TableState::Sealed}); + result = arena->addType(TableType{props, indexer, scope->level, scope.get(), TableState::Sealed}); } else if (auto fn = ty->as()) { @@ -2090,11 +2091,11 @@ TypeId ConstraintGraphBuilder::resolveType(const ScopePtr& scope, AstType* ty, b TypePackId argTypes = resolveTypePack(signatureScope, fn->argTypes); TypePackId returnTypes = resolveTypePack(signatureScope, fn->returnTypes); - // TODO: FunctionTypeVar needs a pointer to the scope so that we know + // TODO: FunctionType needs a pointer to the scope so that we know // how to quantify/instantiate it. - FunctionTypeVar ftv{TypeLevel{}, scope.get(), {}, {}, argTypes, returnTypes}; + FunctionType ftv{TypeLevel{}, scope.get(), {}, {}, argTypes, returnTypes}; - // This replicates the behavior of the appropriate FunctionTypeVar + // This replicates the behavior of the appropriate FunctionType // constructors. ftv.hasNoGenerics = !hasGenerics; ftv.generics = std::move(genericTypes); @@ -2131,7 +2132,7 @@ TypeId ConstraintGraphBuilder::resolveType(const ScopePtr& scope, AstType* ty, b parts.push_back(resolveType(scope, part, topLevel)); } - result = arena->addType(UnionTypeVar{parts}); + result = arena->addType(UnionType{parts}); } else if (auto intersectionAnnotation = ty->as()) { @@ -2142,24 +2143,24 @@ TypeId ConstraintGraphBuilder::resolveType(const ScopePtr& scope, AstType* ty, b parts.push_back(resolveType(scope, part, topLevel)); } - result = arena->addType(IntersectionTypeVar{parts}); + result = arena->addType(IntersectionType{parts}); } else if (auto boolAnnotation = ty->as()) { - result = arena->addType(SingletonTypeVar(BooleanSingleton{boolAnnotation->value})); + result = arena->addType(SingletonType(BooleanSingleton{boolAnnotation->value})); } else if (auto stringAnnotation = ty->as()) { - result = arena->addType(SingletonTypeVar(StringSingleton{std::string(stringAnnotation->value.data, stringAnnotation->value.size)})); + result = arena->addType(SingletonType(StringSingleton{std::string(stringAnnotation->value.data, stringAnnotation->value.size)})); } else if (ty->is()) { - result = singletonTypes->errorRecoveryType(); + result = builtinTypes->errorRecoveryType(); } else { LUAU_ASSERT(0); - result = singletonTypes->errorRecoveryType(); + result = builtinTypes->errorRecoveryType(); } astResolvedTypes[ty] = result; @@ -2187,13 +2188,13 @@ TypePackId ConstraintGraphBuilder::resolveTypePack(const ScopePtr& scope, AstTyp else { reportError(tp->location, UnknownSymbol{gen->genericName.value, UnknownSymbol::Context::Type}); - result = singletonTypes->errorRecoveryTypePack(); + result = builtinTypes->errorRecoveryTypePack(); } } else { LUAU_ASSERT(0); - result = singletonTypes->errorRecoveryTypePack(); + result = builtinTypes->errorRecoveryTypePack(); } astResolvedTypePacks[tp] = result; @@ -2223,7 +2224,7 @@ std::vector> ConstraintGraphBuilder::crea std::vector> result; for (const auto& generic : generics) { - TypeId genericTy = arena->addType(GenericTypeVar{scope.get(), generic.name.value}); + TypeId genericTy = arena->addType(GenericType{scope.get(), generic.name.value}); std::optional defaultTy = std::nullopt; if (generic.defaultValue) @@ -2302,7 +2303,7 @@ struct GlobalPrepopulator : AstVisitor bool visit(AstStatFunction* function) override { if (AstExprGlobal* g = function->name->as()) - globalScope->bindings[g->name] = Binding{arena->addType(BlockedTypeVar{})}; + globalScope->bindings[g->name] = Binding{arena->addType(BlockedType{})}; return true; } diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index d73c14a60..67c1732c1 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -12,9 +12,9 @@ #include "Luau/Quantify.h" #include "Luau/ToString.h" #include "Luau/TypeUtils.h" -#include "Luau/TypeVar.h" +#include "Luau/Type.h" #include "Luau/Unifier.h" -#include "Luau/VisitTypeVar.h" +#include "Luau/VisitType.h" LUAU_FASTFLAGVARIABLE(DebugLuauLogSolver, false); LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJson, false); @@ -34,7 +34,7 @@ namespace Luau dumpBindings(child, opts); } -static std::pair, std::vector> saturateArguments(TypeArena* arena, NotNull singletonTypes, +static std::pair, std::vector> saturateArguments(TypeArena* arena, NotNull builtinTypes, const TypeFun& fn, const std::vector& rawTypeArguments, const std::vector& rawPackArguments) { std::vector saturatedTypeArguments; @@ -115,7 +115,7 @@ static std::pair, std::vector> saturateArguments if (!defaultTy) break; - TypeId instantiatedDefault = atf.substitute(defaultTy).value_or(singletonTypes->errorRecoveryType()); + TypeId instantiatedDefault = atf.substitute(defaultTy).value_or(builtinTypes->errorRecoveryType()); atf.typeArguments[fn.typeParams[i].ty] = instantiatedDefault; saturatedTypeArguments.push_back(instantiatedDefault); } @@ -133,7 +133,7 @@ static std::pair, std::vector> saturateArguments if (!defaultTp) break; - TypePackId instantiatedDefault = atf.substitute(defaultTp).value_or(singletonTypes->errorRecoveryTypePack()); + TypePackId instantiatedDefault = atf.substitute(defaultTp).value_or(builtinTypes->errorRecoveryTypePack()); atf.typePackArguments[fn.typePackParams[i].tp] = instantiatedDefault; saturatedPackArguments.push_back(instantiatedDefault); } @@ -151,12 +151,12 @@ static std::pair, std::vector> saturateArguments // even if they're missing, so we use the error type as a filler. for (size_t i = saturatedTypeArguments.size(); i < typesRequired; ++i) { - saturatedTypeArguments.push_back(singletonTypes->errorRecoveryType()); + saturatedTypeArguments.push_back(builtinTypes->errorRecoveryType()); } for (size_t i = saturatedPackArguments.size(); i < packsRequired; ++i) { - saturatedPackArguments.push_back(singletonTypes->errorRecoveryTypePack()); + saturatedPackArguments.push_back(builtinTypes->errorRecoveryTypePack()); } // At this point, these two conditions should be true. If they aren't we @@ -229,7 +229,7 @@ void dump(ConstraintSolver* cs, ToStringOptions& opts) ConstraintSolver::ConstraintSolver(NotNull normalizer, NotNull rootScope, std::vector> constraints, ModuleName moduleName, NotNull moduleResolver, std::vector requireCycles, DcrLogger* logger) : arena(normalizer->arena) - , singletonTypes(normalizer->singletonTypes) + , builtinTypes(normalizer->builtinTypes) , normalizer(normalizer) , constraints(std::move(constraints)) , rootScope(rootScope) @@ -373,12 +373,12 @@ bool ConstraintSolver::isDone() void ConstraintSolver::finalizeModule() { - Anyification a{arena, rootScope, singletonTypes, &iceReporter, singletonTypes->anyType, singletonTypes->anyTypePack}; + Anyification a{arena, rootScope, builtinTypes, &iceReporter, builtinTypes->anyType, builtinTypes->anyTypePack}; std::optional returnType = a.substitute(rootScope->returnType); if (!returnType) { reportError(CodeTooComplex{}, Location{}); - rootScope->returnType = singletonTypes->errorTypePack; + rootScope->returnType = builtinTypes->errorTypePack; } else rootScope->returnType = *returnType; @@ -470,7 +470,7 @@ bool ConstraintSolver::tryDispatch(const GeneralizationConstraint& c, NotNullscope); if (isBlocked(c.generalizedType)) - asMutable(c.generalizedType)->ty.emplace(generalized); + asMutable(c.generalizedType)->ty.emplace(generalized); else unify(c.generalizedType, generalized, constraint->scope); @@ -491,7 +491,7 @@ bool ConstraintSolver::tryDispatch(const InstantiationConstraint& c, NotNullty.emplace(*instantiated); + asMutable(c.subType)->ty.emplace(*instantiated); else unify(c.subType, *instantiated, constraint->scope); @@ -507,62 +507,62 @@ bool ConstraintSolver::tryDispatch(const UnaryConstraint& c, NotNull(operandType)) + if (get(operandType)) return block(operandType, constraint); - LUAU_ASSERT(get(c.resultType)); + LUAU_ASSERT(get(c.resultType)); switch (c.op) { case AstExprUnary::Not: { - asMutable(c.resultType)->ty.emplace(singletonTypes->booleanType); + asMutable(c.resultType)->ty.emplace(builtinTypes->booleanType); return true; } case AstExprUnary::Len: { // __len must return a number. - asMutable(c.resultType)->ty.emplace(singletonTypes->numberType); + asMutable(c.resultType)->ty.emplace(builtinTypes->numberType); return true; } case AstExprUnary::Minus: { - if (isNumber(operandType) || get(operandType) || get(operandType)) + if (isNumber(operandType) || get(operandType) || get(operandType)) { - asMutable(c.resultType)->ty.emplace(c.operandType); + asMutable(c.resultType)->ty.emplace(c.operandType); } - else if (std::optional mm = findMetatableEntry(singletonTypes, errors, operandType, "__unm", constraint->location)) + else if (std::optional mm = findMetatableEntry(builtinTypes, errors, operandType, "__unm", constraint->location)) { - const FunctionTypeVar* ftv = get(follow(*mm)); + const FunctionType* ftv = get(follow(*mm)); if (!ftv) { - if (std::optional callMm = findMetatableEntry(singletonTypes, errors, follow(*mm), "__call", constraint->location)) + if (std::optional callMm = findMetatableEntry(builtinTypes, errors, follow(*mm), "__call", constraint->location)) { - ftv = get(follow(*callMm)); + ftv = get(follow(*callMm)); } } if (!ftv) { - asMutable(c.resultType)->ty.emplace(singletonTypes->errorRecoveryType()); + asMutable(c.resultType)->ty.emplace(builtinTypes->errorRecoveryType()); return true; } TypePackId argsPack = arena->addTypePack({operandType}); unify(ftv->argTypes, argsPack, constraint->scope); - TypeId result = singletonTypes->errorRecoveryType(); + TypeId result = builtinTypes->errorRecoveryType(); if (ftv) { - result = first(ftv->retTypes).value_or(singletonTypes->errorRecoveryType()); + result = first(ftv->retTypes).value_or(builtinTypes->errorRecoveryType()); } - asMutable(c.resultType)->ty.emplace(result); + asMutable(c.resultType)->ty.emplace(result); } else { - asMutable(c.resultType)->ty.emplace(singletonTypes->errorRecoveryType()); + asMutable(c.resultType)->ty.emplace(builtinTypes->errorRecoveryType()); } return true; @@ -598,14 +598,14 @@ bool ConstraintSolver::tryDispatch(const BinaryConstraint& c, NotNull(leftType) && !isLogical) + if (get(leftType) && !isLogical) return block(leftType, constraint); } // Logical expressions may proceed if the LHS is free. - if (isBlocked(leftType) || (get(leftType) && !isLogical)) + if (isBlocked(leftType) || (get(leftType) && !isLogical)) { - asMutable(resultType)->ty.emplace(errorRecoveryType()); + asMutable(resultType)->ty.emplace(errorRecoveryType()); unblock(resultType); return true; } @@ -615,10 +615,10 @@ bool ConstraintSolver::tryDispatch(const BinaryConstraint& c, NotNullty.emplace(singletonTypes->booleanType); + asMutable(resultType)->ty.emplace(builtinTypes->booleanType); unblock(resultType); return true; } @@ -627,9 +627,9 @@ bool ConstraintSolver::tryDispatch(const BinaryConstraint& c, NotNull leftMm = findMetatableEntry(singletonTypes, errors, leftType, it->second, constraint->location)) + if (std::optional leftMm = findMetatableEntry(builtinTypes, errors, leftType, it->second, constraint->location)) mm = leftMm; - else if (std::optional rightMm = findMetatableEntry(singletonTypes, errors, rightType, it->second, constraint->location)) + else if (std::optional rightMm = findMetatableEntry(builtinTypes, errors, rightType, it->second, constraint->location)) mm = rightMm; if (mm) @@ -644,7 +644,7 @@ bool ConstraintSolver::tryDispatch(const BinaryConstraint& c, NotNull(follow(*instantiatedMm))) + if (const FunctionType* ftv = get(follow(*instantiatedMm))) { TypePackId inferredArgs; // For >= and > we invoke __lt and __le respectively with @@ -672,13 +672,13 @@ bool ConstraintSolver::tryDispatch(const BinaryConstraint& c, NotNullbooleanType; + mmResult = builtinTypes->booleanType; break; default: mmResult = first(ftv->retTypes).value_or(errorRecoveryType()); } - asMutable(resultType)->ty.emplace(mmResult); + asMutable(resultType)->ty.emplace(mmResult); unblock(resultType); (*c.astOriginalCallTypes)[c.expr] = *mm; @@ -691,8 +691,8 @@ bool ConstraintSolver::tryDispatch(const BinaryConstraint& c, NotNull(leftType) || get(leftType); - bool rightAny = get(rightType) || get(rightType); + bool leftAny = get(leftType) || get(leftType); + bool rightAny = get(rightType) || get(rightType); bool anyPresent = leftAny || rightAny; switch (c.op) @@ -708,7 +708,7 @@ bool ConstraintSolver::tryDispatch(const BinaryConstraint& c, NotNullscope); - asMutable(resultType)->ty.emplace(anyPresent ? singletonTypes->anyType : leftType); + asMutable(resultType)->ty.emplace(anyPresent ? builtinTypes->anyType : leftType); unblock(resultType); return true; } @@ -720,7 +720,7 @@ bool ConstraintSolver::tryDispatch(const BinaryConstraint& c, NotNullscope); - asMutable(resultType)->ty.emplace(anyPresent ? singletonTypes->anyType : leftType); + asMutable(resultType)->ty.emplace(anyPresent ? builtinTypes->anyType : leftType); unblock(resultType); return true; } @@ -734,7 +734,7 @@ bool ConstraintSolver::tryDispatch(const BinaryConstraint& c, NotNullty.emplace(singletonTypes->booleanType); + asMutable(resultType)->ty.emplace(builtinTypes->booleanType); unblock(resultType); return true; } @@ -744,16 +744,16 @@ bool ConstraintSolver::tryDispatch(const BinaryConstraint& c, NotNullty.emplace(singletonTypes->booleanType); + asMutable(resultType)->ty.emplace(builtinTypes->booleanType); unblock(resultType); return true; // And evalutes to a boolean if the LHS is falsey, and the RHS type if LHS is // truthy. case AstExprBinary::Op::And: { - TypeId leftFilteredTy = arena->addType(IntersectionTypeVar{{singletonTypes->falsyType, leftType}}); + TypeId leftFilteredTy = arena->addType(IntersectionType{{builtinTypes->falsyType, leftType}}); - asMutable(resultType)->ty.emplace(arena->addType(UnionTypeVar{{leftFilteredTy, rightType}})); + asMutable(resultType)->ty.emplace(arena->addType(UnionType{{leftFilteredTy, rightType}})); unblock(resultType); return true; } @@ -761,9 +761,9 @@ bool ConstraintSolver::tryDispatch(const BinaryConstraint& c, NotNulladdType(IntersectionTypeVar{{singletonTypes->truthyType, leftType}}); + TypeId leftFilteredTy = arena->addType(IntersectionType{{builtinTypes->truthyType, leftType}}); - asMutable(resultType)->ty.emplace(arena->addType(UnionTypeVar{{leftFilteredTy, rightType}})); + asMutable(resultType)->ty.emplace(arena->addType(UnionType{{leftFilteredTy, rightType}})); unblock(resultType); return true; } @@ -775,7 +775,7 @@ bool ConstraintSolver::tryDispatch(const BinaryConstraint& c, NotNullscope); unify(rightType, errorRecoveryType(), constraint->scope); - asMutable(resultType)->ty.emplace(errorRecoveryType()); + asMutable(resultType)->ty.emplace(errorRecoveryType()); unblock(resultType); return true; @@ -840,7 +840,7 @@ bool ConstraintSolver::tryDispatch(const IterableConstraint& c, NotNullscope, singletonTypes, &iceReporter, errorRecoveryType(), errorRecoveryTypePack()}; + Anyification anyify{arena, constraint->scope, builtinTypes, &iceReporter, errorRecoveryType(), errorRecoveryTypePack()}; std::optional anyified = anyify.substitute(c.variables); LUAU_ASSERT(anyified); unify(*anyified, c.variables, constraint->scope); @@ -849,16 +849,16 @@ bool ConstraintSolver::tryDispatch(const IterableConstraint& c, NotNull(nextTy)) + if (get(nextTy)) return block_(nextTy); - if (get(nextTy)) + if (get(nextTy)) { - TypeId tableTy = singletonTypes->nilType; + TypeId tableTy = builtinTypes->nilType; if (iteratorTypes.size() >= 2) tableTy = iteratorTypes[1]; - TypeId firstIndexTy = singletonTypes->nilType; + TypeId firstIndexTy = builtinTypes->nilType; if (iteratorTypes.size() >= 3) firstIndexTy = iteratorTypes[2]; @@ -881,11 +881,11 @@ bool ConstraintSolver::tryDispatch(const NameConstraint& c, NotNullpersistent || target->owningArena != arena) return true; - if (TableTypeVar* ttv = getMutable(target)) + if (TableType* ttv = getMutable(target)) ttv->name = c.name; - else if (MetatableTypeVar* mtv = getMutable(target)) + else if (MetatableType* mtv = getMutable(target)) mtv->syntheticName = c.name; - else if (get(target) || get(target)) + else if (get(target) || get(target)) { // nothing (yet) } @@ -895,7 +895,7 @@ bool ConstraintSolver::tryDispatch(const NameConstraint& c, NotNull tf = (petv.prefix) ? scope->lookupImportedType(petv.prefix->value, petv.name.value) : scope->lookupType(petv.name.value); @@ -917,7 +917,7 @@ struct InfiniteTypeFinder : TypeVarOnceVisitor if (!tf.has_value()) return true; - auto [typeArguments, packArguments] = saturateArguments(solver->arena, solver->singletonTypes, *tf, petv.typeArguments, petv.packArguments); + auto [typeArguments, packArguments] = saturateArguments(solver->arena, solver->builtinTypes, *tf, petv.typeArguments, petv.packArguments); if (follow(tf->type) == follow(signature.fn.type) && (signature.arguments != typeArguments || signature.packArguments != packArguments)) { @@ -929,7 +929,7 @@ struct InfiniteTypeFinder : TypeVarOnceVisitor } }; -struct InstantiationQueuer : TypeVarOnceVisitor +struct InstantiationQueuer : TypeOnceVisitor { ConstraintSolver* solver; const InstantiationSignature& signature; @@ -944,7 +944,7 @@ struct InstantiationQueuer : TypeVarOnceVisitor { } - bool visit(TypeId ty, const PendingExpansionTypeVar& petv) override + bool visit(TypeId ty, const PendingExpansionType& petv) override { solver->pushConstraint(scope, location, TypeAliasExpansionConstraint{ty}); return false; @@ -953,7 +953,7 @@ struct InstantiationQueuer : TypeVarOnceVisitor bool ConstraintSolver::tryDispatch(const TypeAliasExpansionConstraint& c, NotNull constraint) { - const PendingExpansionTypeVar* petv = get(follow(c.target)); + const PendingExpansionType* petv = get(follow(c.target)); if (!petv) { unblock(c.target); @@ -961,7 +961,7 @@ bool ConstraintSolver::tryDispatch(const TypeAliasExpansionConstraint& c, NotNul } auto bindResult = [this, &c](TypeId result) { - asMutable(c.target)->ty.emplace(result); + asMutable(c.target)->ty.emplace(result); unblock(c.target); }; @@ -983,7 +983,7 @@ bool ConstraintSolver::tryDispatch(const TypeAliasExpansionConstraint& c, NotNul return true; } - auto [typeArguments, packArguments] = saturateArguments(arena, singletonTypes, *tf, petv->typeArguments, petv->packArguments); + auto [typeArguments, packArguments] = saturateArguments(arena, builtinTypes, *tf, petv->typeArguments, petv->packArguments); bool sameTypes = std::equal(typeArguments.begin(), typeArguments.end(), tf->typeParams.begin(), tf->typeParams.end(), [](auto&& itp, auto&& p) { return itp == p.ty; @@ -1067,7 +1067,7 @@ bool ConstraintSolver::tryDispatch(const TypeAliasExpansionConstraint& c, NotNul // there are e.g. generic saturatedTypeArguments that go unused. bool needsClone = follow(tf->type) == target; // Only tables have the properties we're trying to set. - TableTypeVar* ttv = getMutableTableType(target); + TableType* ttv = getMutableTableType(target); if (ttv) { @@ -1078,17 +1078,17 @@ bool ConstraintSolver::tryDispatch(const TypeAliasExpansionConstraint& c, NotNul // explicitly clone that table as well. If we don't, we will // mutate another module's type surface and cause a // use-after-free. - if (get(target)) + if (get(target)) { instantiated = applyTypeFunction.clone(target); - MetatableTypeVar* mtv = getMutable(instantiated); + MetatableType* mtv = getMutable(instantiated); mtv->table = applyTypeFunction.clone(mtv->table); - ttv = getMutable(mtv->table); + ttv = getMutable(mtv->table); } - else if (get(target)) + else if (get(target)) { instantiated = applyTypeFunction.clone(target); - ttv = getMutable(instantiated); + ttv = getMutable(instantiated); } target = follow(instantiated); @@ -1123,16 +1123,15 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNull callMm = findMetatableEntry(singletonTypes, errors, fn, "__call", constraint->location)) + if (std::optional callMm = findMetatableEntry(builtinTypes, errors, fn, "__call", constraint->location)) { std::vector args{fn}; for (TypeId arg : c.argsPack) args.push_back(arg); - TypeId instantiatedType = arena->addType(BlockedTypeVar{}); - TypeId inferredFnType = - arena->addType(FunctionTypeVar(TypeLevel{}, constraint->scope.get(), arena->addTypePack(TypePack{args, {}}), c.result)); + TypeId instantiatedType = arena->addType(BlockedType{}); + TypeId inferredFnType = arena->addType(FunctionType(TypeLevel{}, constraint->scope.get(), arena->addTypePack(TypePack{args, {}}), c.result)); // Alter the inner constraints. LUAU_ASSERT(c.innerConstraints.size() == 2); @@ -1158,7 +1157,7 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNull(fn); + const FunctionType* ftv = get(fn); bool usedMagic = false; if (ftv && ftv->dcrMagicFunction != nullptr) @@ -1203,11 +1202,11 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNull constraint) { TypeId expectedType = follow(c.expectedType); - if (isBlocked(expectedType) || get(expectedType)) + if (isBlocked(expectedType) || get(expectedType)) return block(expectedType, constraint); TypeId bindTo = maybeSingleton(expectedType) ? c.singletonType : c.multitonType; - asMutable(c.resultType)->ty.emplace(bindTo); + asMutable(c.resultType)->ty.emplace(bindTo); return true; } @@ -1216,14 +1215,14 @@ bool ConstraintSolver::tryDispatch(const HasPropConstraint& c, NotNull(subjectType)) + if (isBlocked(subjectType) || get(subjectType)) return block(subjectType, constraint); - if (get(subjectType)) + if (get(subjectType)) { - TableTypeVar& ttv = asMutable(subjectType)->ty.emplace(TableState::Free, TypeLevel{}, constraint->scope); + TableType& ttv = asMutable(subjectType)->ty.emplace(TableState::Free, TypeLevel{}, constraint->scope); ttv.props[c.prop] = Property{c.resultType}; - asMutable(c.resultType)->ty.emplace(constraint->scope); + asMutable(c.resultType)->ty.emplace(constraint->scope); unblock(c.resultType); return true; } @@ -1231,7 +1230,7 @@ bool ConstraintSolver::tryDispatch(const HasPropConstraint& c, NotNull resultType = lookupTableProp(subjectType, c.prop); if (!resultType) { - asMutable(c.resultType)->ty.emplace(singletonTypes->errorRecoveryType()); + asMutable(c.resultType)->ty.emplace(builtinTypes->errorRecoveryType()); unblock(c.resultType); return true; } @@ -1242,14 +1241,14 @@ bool ConstraintSolver::tryDispatch(const HasPropConstraint& c, NotNullty.emplace(*resultType); + asMutable(c.resultType)->ty.emplace(*resultType); return true; } static bool isUnsealedTable(TypeId ty) { ty = follow(ty); - const TableTypeVar* ttv = get(ty); + const TableType* ttv = get(ty); return ttv && ttv->state == TableState::Unsealed; } @@ -1278,7 +1277,7 @@ static std::optional updateTheTableType(NotNull arena, TypeId if (!isUnsealedTable(t)) return std::nullopt; - const TableTypeVar* tbl = get(t); + const TableType* tbl = get(t); auto it = tbl->props.find(path[i]); if (it == tbl->props.end()) return std::nullopt; @@ -1291,7 +1290,7 @@ static std::optional updateTheTableType(NotNull arena, TypeId // new property to be appended. if (!isUnsealedTable(t)) return std::nullopt; - const TableTypeVar* tbl = get(t); + const TableType* tbl = get(t); if (0 != tbl->props.count(path.back())) return std::nullopt; } @@ -1303,7 +1302,7 @@ static std::optional updateTheTableType(NotNull arena, TypeId { const std::string segment = path[i]; - TableTypeVar* ttv = getMutable(t); + TableType* ttv = getMutable(t); LUAU_ASSERT(ttv); auto propIt = ttv->props.find(segment); @@ -1317,7 +1316,7 @@ static std::optional updateTheTableType(NotNull arena, TypeId return std::nullopt; } - TableTypeVar* ttv = getMutable(t); + TableType* ttv = getMutable(t); LUAU_ASSERT(ttv); const std::string lastSegment = path.back(); @@ -1350,7 +1349,7 @@ bool ConstraintSolver::tryDispatch(const SetPropConstraint& c, NotNullty.emplace(b); + asMutable(a)->ty.emplace(b); }; if (existingPropType) @@ -1360,14 +1359,14 @@ bool ConstraintSolver::tryDispatch(const SetPropConstraint& c, NotNull(subjectType)) + if (get(subjectType)) { TypeId ty = arena->freshType(constraint->scope); // Mint a chain of free tables per c.path for (auto it = rbegin(c.path); it != rend(c.path); ++it) { - TableTypeVar t{TableState::Free, TypeLevel{}, constraint->scope}; + TableType t{TableState::Free, TypeLevel{}, constraint->scope}; t.props[*it] = {ty}; ty = arena->addType(std::move(t)); @@ -1379,7 +1378,7 @@ bool ConstraintSolver::tryDispatch(const SetPropConstraint& c, NotNull(subjectType)) + else if (auto ttv = getMutable(subjectType)) { if (ttv->state == TableState::Free) { @@ -1399,7 +1398,7 @@ bool ConstraintSolver::tryDispatch(const SetPropConstraint& c, NotNull(subjectType) || get(subjectType)) + else if (get(subjectType) || get(subjectType)) { bind(c.resultType, subjectType); return true; @@ -1417,12 +1416,12 @@ bool ConstraintSolver::tryDispatch(const SingletonOrTopTypeConstraint& c, NotNul TypeId followed = follow(c.discriminantType); // `nil` is a singleton type too! There's only one value of type `nil`. - if (c.negated && (get(followed) || isNil(followed))) - *asMutable(c.resultType) = NegationTypeVar{c.discriminantType}; - else if (!c.negated && get(followed)) - *asMutable(c.resultType) = BoundTypeVar{c.discriminantType}; + if (c.negated && (get(followed) || isNil(followed))) + *asMutable(c.resultType) = NegationType{c.discriminantType}; + else if (!c.negated && get(followed)) + *asMutable(c.resultType) = BoundType{c.discriminantType}; else - *asMutable(c.resultType) = BoundTypeVar{singletonTypes->unknownType}; + *asMutable(c.resultType) = BoundType{builtinTypes->unknownType}; return true; } @@ -1445,11 +1444,11 @@ bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const Iterabl // We may have to block here if we don't know what the iteratee type is, // if it's a free table, if we don't know it has a metatable, and so on. iteratorTy = follow(iteratorTy); - if (get(iteratorTy)) + if (get(iteratorTy)) return block_(iteratorTy); auto anyify = [&](auto ty) { - Anyification anyify{arena, constraint->scope, singletonTypes, &iceReporter, singletonTypes->anyType, singletonTypes->anyTypePack}; + Anyification anyify{arena, constraint->scope, builtinTypes, &iceReporter, builtinTypes->anyType, builtinTypes->anyTypePack}; std::optional anyified = anyify.substitute(ty); if (!anyified) reportError(CodeTooComplex{}, constraint->location); @@ -1458,7 +1457,7 @@ bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const Iterabl }; auto errorify = [&](auto ty) { - Anyification anyify{arena, constraint->scope, singletonTypes, &iceReporter, errorRecoveryType(), errorRecoveryTypePack()}; + Anyification anyify{arena, constraint->scope, builtinTypes, &iceReporter, errorRecoveryType(), errorRecoveryTypePack()}; std::optional errorified = anyify.substitute(ty); if (!errorified) reportError(CodeTooComplex{}, constraint->location); @@ -1466,13 +1465,13 @@ bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const Iterabl unify(*errorified, ty, constraint->scope); }; - if (get(iteratorTy)) + if (get(iteratorTy)) { anyify(c.variables); return true; } - if (get(iteratorTy)) + if (get(iteratorTy)) { errorify(c.variables); return true; @@ -1481,7 +1480,7 @@ bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const Iterabl // Irksome: I don't think we have any way to guarantee that this table // type never has a metatable. - if (auto iteratorTable = get(iteratorTy)) + if (auto iteratorTable = get(iteratorTy)) { if (iteratorTable->state == TableState::Free) return block_(iteratorTy); @@ -1494,7 +1493,7 @@ bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const Iterabl else errorify(c.variables); } - else if (std::optional iterFn = findMetatableEntry(singletonTypes, errors, iteratorTy, "__iter", Location{})) + else if (std::optional iterFn = findMetatableEntry(builtinTypes, errors, iteratorTy, "__iter", Location{})) { if (isBlocked(*iterFn)) { @@ -1505,12 +1504,12 @@ bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const Iterabl if (std::optional instantiatedIterFn = instantiation.substitute(*iterFn)) { - if (auto iterFtv = get(*instantiatedIterFn)) + if (auto iterFtv = get(*instantiatedIterFn)) { TypePackId expectedIterArgs = arena->addTypePack({iteratorTy}); unify(iterFtv->argTypes, expectedIterArgs, constraint->scope); - TypePack iterRets = extendTypePack(*arena, singletonTypes, iterFtv->retTypes, 2); + TypePack iterRets = extendTypePack(*arena, builtinTypes, iterFtv->retTypes, 2); if (iterRets.head.size() < 1) { @@ -1527,11 +1526,11 @@ bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const Iterabl const TypeId firstIndex = arena->freshType(constraint->scope); // nextTy : (iteratorTy, indexTy?) -> (indexTy, valueTailTy...) - const TypePackId nextArgPack = arena->addTypePack({table, arena->addType(UnionTypeVar{{firstIndex, singletonTypes->nilType}})}); + const TypePackId nextArgPack = arena->addTypePack({table, arena->addType(UnionType{{firstIndex, builtinTypes->nilType}})}); const TypePackId valueTailTy = arena->addTypePack(FreeTypePack{constraint->scope}); const TypePackId nextRetPack = arena->addTypePack(TypePack{{firstIndex}, valueTailTy}); - const TypeId expectedNextTy = arena->addType(FunctionTypeVar{nextArgPack, nextRetPack}); + const TypeId expectedNextTy = arena->addType(FunctionType{nextArgPack, nextRetPack}); unify(*instantiatedNextFn, expectedNextTy, constraint->scope); pushConstraint(constraint->scope, constraint->location, PackSubtypeConstraint{c.variables, nextRetPack}); @@ -1551,10 +1550,10 @@ bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const Iterabl reportError(UnificationTooComplex{}, constraint->location); } } - else if (auto iteratorMetatable = get(iteratorTy)) + else if (auto iteratorMetatable = get(iteratorTy)) { TypeId metaTy = follow(iteratorMetatable->metatable); - if (get(metaTy)) + if (get(metaTy)) return block_(metaTy); LUAU_ASSERT(false); @@ -1571,7 +1570,7 @@ bool ConstraintSolver::tryDispatchIterableFunction( // We need to know whether or not this type is nil or not. // If we don't know, block and reschedule ourselves. firstIndexTy = follow(firstIndexTy); - if (get(firstIndexTy)) + if (get(firstIndexTy)) { if (force) LUAU_ASSERT(false); @@ -1584,11 +1583,11 @@ bool ConstraintSolver::tryDispatchIterableFunction( : firstIndexTy; // nextTy : (tableTy, indexTy?) -> (indexTy, valueTailTy...) - const TypePackId nextArgPack = arena->addTypePack({tableTy, arena->addType(UnionTypeVar{{firstIndex, singletonTypes->nilType}})}); + const TypePackId nextArgPack = arena->addTypePack({tableTy, arena->addType(UnionType{{firstIndex, builtinTypes->nilType}})}); const TypePackId valueTailTy = arena->addTypePack(FreeTypePack{constraint->scope}); const TypePackId nextRetPack = arena->addTypePack(TypePack{{firstIndex}, valueTailTy}); - const TypeId expectedNextTy = arena->addType(FunctionTypeVar{TypeLevel{}, constraint->scope, nextArgPack, nextRetPack}); + const TypeId expectedNextTy = arena->addType(FunctionType{TypeLevel{}, constraint->scope, nextArgPack, nextRetPack}); unify(nextTy, expectedNextTy, constraint->scope); pushConstraint(constraint->scope, constraint->location, PackSubtypeConstraint{c.variables, nextRetPack}); @@ -1605,9 +1604,9 @@ std::optional ConstraintSolver::lookupTableProp(TypeId subjectType, cons for (TypeId expectedPart : unionOrIntersection) { expectedPart = follow(expectedPart); - if (isBlocked(expectedPart) || get(expectedPart)) + if (isBlocked(expectedPart) || get(expectedPart)) blocked = expectedPart; - else if (const TableTypeVar* ttv = get(follow(expectedPart))) + else if (const TableType* ttv = get(follow(expectedPart))) { if (auto prop = ttv->props.find(propName); prop != ttv->props.end()) parts.push_back(prop->second.type); @@ -1621,14 +1620,14 @@ std::optional ConstraintSolver::lookupTableProp(TypeId subjectType, cons std::optional resultType; - if (auto ttv = get(subjectType)) + if (auto ttv = get(subjectType)) { if (auto prop = ttv->props.find(propName); prop != ttv->props.end()) resultType = prop->second.type; else if (ttv->indexer && maybeString(ttv->indexer->indexType)) resultType = ttv->indexer->indexResultType; } - else if (auto utv = get(subjectType)) + else if (auto utv = get(subjectType)) { auto [blocked, parts] = collectParts(utv); @@ -1637,11 +1636,11 @@ std::optional ConstraintSolver::lookupTableProp(TypeId subjectType, cons else if (parts.size() == 1) resultType = parts[0]; else if (parts.size() > 1) - resultType = arena->addType(UnionTypeVar{std::move(parts)}); + resultType = arena->addType(UnionType{std::move(parts)}); else LUAU_ASSERT(false); // parts.size() == 0 } - else if (auto itv = get(subjectType)) + else if (auto itv = get(subjectType)) { auto [blocked, parts] = collectParts(itv); @@ -1650,7 +1649,7 @@ std::optional ConstraintSolver::lookupTableProp(TypeId subjectType, cons else if (parts.size() == 1) resultType = parts[0]; else if (parts.size() > 1) - resultType = arena->addType(IntersectionTypeVar{std::move(parts)}); + resultType = arena->addType(IntersectionType{std::move(parts)}); else LUAU_ASSERT(false); // parts.size() == 0 } @@ -1701,7 +1700,7 @@ bool ConstraintSolver::block(TypePackId target, NotNull constr return false; } -struct Blocker : TypeVarOnceVisitor +struct Blocker : TypeOnceVisitor { NotNull solver; NotNull constraint; @@ -1714,14 +1713,14 @@ struct Blocker : TypeVarOnceVisitor { } - bool visit(TypeId ty, const BlockedTypeVar&) + bool visit(TypeId ty, const BlockedType&) { blocked = true; solver->block(ty, constraint); return false; } - bool visit(TypeId ty, const PendingExpansionTypeVar&) + bool visit(TypeId ty, const PendingExpansionType&) { blocked = true; solver->block(ty, constraint); @@ -1805,7 +1804,7 @@ void ConstraintSolver::unblock(const std::vector& packs) bool ConstraintSolver::isBlocked(TypeId ty) { - return nullptr != get(follow(ty)) || nullptr != get(follow(ty)); + return nullptr != get(follow(ty)) || nullptr != get(follow(ty)); } bool ConstraintSolver::isBlocked(TypePackId tp) @@ -1878,7 +1877,7 @@ TypeId ConstraintSolver::resolveModule(const ModuleInfo& info, const Location& l for (const auto& [location, path] : requireCycles) { if (!path.empty() && path.front() == humanReadableName) - return singletonTypes->anyType; + return builtinTypes->anyType; } ModulePtr module = moduleResolver->getModule(info.name); @@ -1924,12 +1923,12 @@ void ConstraintSolver::reportError(TypeError e) TypeId ConstraintSolver::errorRecoveryType() const { - return singletonTypes->errorRecoveryType(); + return builtinTypes->errorRecoveryType(); } TypePackId ConstraintSolver::errorRecoveryTypePack() const { - return singletonTypes->errorRecoveryTypePack(); + return builtinTypes->errorRecoveryTypePack(); } TypeId ConstraintSolver::unionOfTypes(TypeId a, TypeId b, NotNull scope, bool unifyFreeTypes) @@ -1937,7 +1936,7 @@ TypeId ConstraintSolver::unionOfTypes(TypeId a, TypeId b, NotNull scope, a = follow(a); b = follow(b); - if (unifyFreeTypes && (get(a) || get(b))) + if (unifyFreeTypes && (get(a) || get(b))) { Unifier u{normalizer, Mode::Strict, scope, Location{}, Covariant}; u.useScopes = true; @@ -1950,7 +1949,7 @@ TypeId ConstraintSolver::unionOfTypes(TypeId a, TypeId b, NotNull scope, } else { - return singletonTypes->errorRecoveryType(singletonTypes->anyType); + return builtinTypes->errorRecoveryType(builtinTypes->anyType); } } @@ -1959,12 +1958,12 @@ TypeId ConstraintSolver::unionOfTypes(TypeId a, TypeId b, NotNull scope, std::vector types = reduceUnion({a, b}); if (types.empty()) - return singletonTypes->neverType; + return builtinTypes->neverType; if (types.size() == 1) return types[0]; - return arena->addType(UnionTypeVar{types}); + return arena->addType(UnionType{types}); } } // namespace Luau diff --git a/Analysis/src/Error.cpp b/Analysis/src/Error.cpp index 748cf20fa..a527b2440 100644 --- a/Analysis/src/Error.cpp +++ b/Analysis/src/Error.cpp @@ -131,9 +131,9 @@ struct ErrorConverter std::string operator()(const Luau::UnknownProperty& e) const { TypeId t = follow(e.table); - if (get(t)) + if (get(t)) return "Key '" + e.key + "' not found in table '" + Luau::toString(t) + "'"; - else if (get(t)) + else if (get(t)) return "Key '" + e.key + "' not found in class '" + Luau::toString(t) + "'"; else return "Type '" + Luau::toString(e.table) + "' does not have key '" + e.key + "'"; @@ -301,7 +301,7 @@ struct ErrorConverter std::string s = "Key '" + e.key + "' not found in "; TypeId t = follow(e.table); - if (get(t)) + if (get(t)) s += "class"; else s += "table"; @@ -952,7 +952,7 @@ void copyErrors(ErrorVec& errors, TypeArena& destArena) copyError(e, destArena, cloneState); }; - LUAU_ASSERT(!destArena.typeVars.isFrozen()); + LUAU_ASSERT(!destArena.types.isFrozen()); LUAU_ASSERT(!destArena.typePacks.isFrozen()); for (TypeError& error : errors) diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index e21e42e14..5d2c15871 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -65,14 +65,14 @@ static void generateDocumentationSymbols(TypeId ty, const std::string& rootName) asMutable(ty)->documentationSymbol = rootName; - if (TableTypeVar* ttv = getMutable(ty)) + if (TableType* ttv = getMutable(ty)) { for (auto& [name, prop] : ttv->props) { prop.documentationSymbol = rootName + "." + name; } } - else if (ClassTypeVar* ctv = getMutable(ty)) + else if (ClassType* ctv = getMutable(ty)) { for (auto& [name, prop] : ctv->props) { @@ -408,12 +408,12 @@ double getTimestamp() } // namespace Frontend::Frontend(FileResolver* fileResolver, ConfigResolver* configResolver, const FrontendOptions& options) - : singletonTypes(NotNull{&singletonTypes_}) + : builtinTypes(NotNull{&builtinTypes_}) , fileResolver(fileResolver) , moduleResolver(this) , moduleResolverForAutocomplete(this) - , typeChecker(&moduleResolver, singletonTypes, &iceHandler) - , typeCheckerForAutocomplete(&moduleResolverForAutocomplete, singletonTypes, &iceHandler) + , typeChecker(&moduleResolver, builtinTypes, &iceHandler) + , typeCheckerForAutocomplete(&moduleResolverForAutocomplete, builtinTypes, &iceHandler) , configResolver(configResolver) , options(options) , globalScope(typeChecker.globalScope) @@ -455,12 +455,7 @@ CheckResult Frontend::check(const ModuleName& name, std::optional buildQueue; - bool cycleDetected = parseGraph(buildQueue, checkResult, name, frontendOptions.forAutocomplete); - - // Keep track of which AST nodes we've reported cycles in - std::unordered_set reportedCycles; - - double autocompleteTimeLimit = FInt::LuauAutocompleteCheckTimeoutMs / 1000.0; + bool cycleDetected = parseGraph(buildQueue, name, frontendOptions.forAutocomplete); for (const ModuleName& moduleName : buildQueue) { @@ -499,6 +494,8 @@ CheckResult Frontend::check(const ModuleName& name, std::optional& buildQueue, CheckResult& checkResult, const ModuleName& root, bool forAutocomplete) +bool Frontend::parseGraph(std::vector& buildQueue, const ModuleName& root, bool forAutocomplete) { LUAU_TIMETRACE_SCOPE("Frontend::parseGraph", "Frontend"); LUAU_TIMETRACE_ARGUMENT("root", root.c_str()); @@ -618,7 +615,7 @@ bool Frontend::parseGraph(std::vector& buildQueue, CheckResult& chec bool cyclic = false; { - auto [sourceNode, _] = getSourceNode(checkResult, root); + auto [sourceNode, _] = getSourceNode(root); if (sourceNode) stack.push_back(sourceNode); } @@ -682,7 +679,7 @@ bool Frontend::parseGraph(std::vector& buildQueue, CheckResult& chec } } - auto [sourceNode, _] = getSourceNode(checkResult, dep); + auto [sourceNode, _] = getSourceNode(dep); if (sourceNode) { stack.push_back(sourceNode); @@ -729,8 +726,7 @@ LintResult Frontend::lint(const ModuleName& name, std::optional mr{forAutocomplete ? &moduleResolverForAutocomplete : &moduleResolver}; const ScopePtr& globalScope{forAutocomplete ? typeCheckerForAutocomplete.globalScope : typeChecker.globalScope}; - Normalizer normalizer{&result->internalTypes, singletonTypes, NotNull{&typeChecker.unifierState}}; + Normalizer normalizer{&result->internalTypes, builtinTypes, NotNull{&typeChecker.unifierState}}; ConstraintGraphBuilder cgb{ sourceModule.name, result, &result->internalTypes, mr, - singletonTypes, + builtinTypes, NotNull(&iceHandler), globalScope, logger.get(), @@ -910,7 +906,12 @@ ModulePtr Frontend::check( result->astResolvedTypePacks = std::move(cgb.astResolvedTypePacks); result->type = sourceModule.type; - Luau::check(singletonTypes, logger.get(), sourceModule, result.get()); + result->clonePublicInterface(builtinTypes, iceHandler); + + freeze(result->internalTypes); + freeze(result->interfaceTypes); + + Luau::check(builtinTypes, logger.get(), sourceModule, result.get()); if (FFlag::DebugLuauLogSolverToJson) { @@ -918,13 +919,11 @@ ModulePtr Frontend::check( printf("%s\n", output.c_str()); } - result->clonePublicInterface(singletonTypes, iceHandler); - return result; } // Read AST into sourceModules if necessary. Trace require()s. Report parse errors. -std::pair Frontend::getSourceNode(CheckResult& checkResult, const ModuleName& name) +std::pair Frontend::getSourceNode(const ModuleName& name) { LUAU_TIMETRACE_SCOPE("Frontend::getSourceNode", "Frontend"); LUAU_TIMETRACE_ARGUMENT("name", name.c_str()); diff --git a/Analysis/src/Instantiation.cpp b/Analysis/src/Instantiation.cpp index 3d0cd0d11..209ba7e90 100644 --- a/Analysis/src/Instantiation.cpp +++ b/Analysis/src/Instantiation.cpp @@ -11,7 +11,7 @@ namespace Luau bool Instantiation::isDirty(TypeId ty) { - if (const FunctionTypeVar* ftv = log->getMutable(ty)) + if (const FunctionType* ftv = log->getMutable(ty)) { if (ftv->hasNoGenerics) return false; @@ -31,9 +31,9 @@ bool Instantiation::isDirty(TypePackId tp) bool Instantiation::ignoreChildren(TypeId ty) { - if (log->getMutable(ty)) + if (log->getMutable(ty)) return true; - else if (FFlag::LuauClassTypeVarsInSubstitution && get(ty)) + else if (FFlag::LuauClassTypeVarsInSubstitution && get(ty)) return true; else return false; @@ -41,10 +41,10 @@ bool Instantiation::ignoreChildren(TypeId ty) TypeId Instantiation::clean(TypeId ty) { - const FunctionTypeVar* ftv = log->getMutable(ty); + const FunctionType* ftv = log->getMutable(ty); LUAU_ASSERT(ftv); - FunctionTypeVar clone = FunctionTypeVar{level, scope, ftv->argTypes, ftv->retTypes, ftv->definition, ftv->hasSelf}; + FunctionType clone = FunctionType{level, scope, ftv->argTypes, ftv->retTypes, ftv->definition, ftv->hasSelf}; clone.magicFunction = ftv->magicFunction; clone.dcrMagicFunction = ftv->dcrMagicFunction; clone.tags = ftv->tags; @@ -71,7 +71,7 @@ TypePackId Instantiation::clean(TypePackId tp) bool ReplaceGenerics::ignoreChildren(TypeId ty) { - if (const FunctionTypeVar* ftv = log->getMutable(ty)) + if (const FunctionType* ftv = log->getMutable(ty)) { if (ftv->hasNoGenerics) return true; @@ -83,7 +83,7 @@ bool ReplaceGenerics::ignoreChildren(TypeId ty) // whenever we quantify, so the vectors overlap if and only if they are equal. return (!generics.empty() || !genericPacks.empty()) && (ftv->generics == generics) && (ftv->genericPacks == genericPacks); } - else if (FFlag::LuauClassTypeVarsInSubstitution && get(ty)) + else if (FFlag::LuauClassTypeVarsInSubstitution && get(ty)) return true; else { @@ -93,9 +93,9 @@ bool ReplaceGenerics::ignoreChildren(TypeId ty) bool ReplaceGenerics::isDirty(TypeId ty) { - if (const TableTypeVar* ttv = log->getMutable(ty)) + if (const TableType* ttv = log->getMutable(ty)) return ttv->state == TableState::Generic; - else if (log->getMutable(ty)) + else if (log->getMutable(ty)) return std::find(generics.begin(), generics.end(), ty) != generics.end(); else return false; @@ -112,14 +112,14 @@ bool ReplaceGenerics::isDirty(TypePackId tp) TypeId ReplaceGenerics::clean(TypeId ty) { LUAU_ASSERT(isDirty(ty)); - if (const TableTypeVar* ttv = log->getMutable(ty)) + if (const TableType* ttv = log->getMutable(ty)) { - TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, level, scope, TableState::Free}; + TableType clone = TableType{ttv->props, ttv->indexer, level, scope, TableState::Free}; clone.definitionModuleName = ttv->definitionModuleName; return addType(std::move(clone)); } else - return addType(FreeTypeVar{scope, level}); + return addType(FreeType{scope, level}); } TypePackId ReplaceGenerics::clean(TypePackId tp) diff --git a/Analysis/src/IostreamHelpers.cpp b/Analysis/src/IostreamHelpers.cpp index cd59fdfb1..43580da4d 100644 --- a/Analysis/src/IostreamHelpers.cpp +++ b/Analysis/src/IostreamHelpers.cpp @@ -215,7 +215,7 @@ std::ostream& operator<<(std::ostream& stream, const TableState& tv) return stream << static_cast::type>(tv); } -std::ostream& operator<<(std::ostream& stream, const TypeVar& tv) +std::ostream& operator<<(std::ostream& stream, const Type& tv) { return stream << toString(tv); } diff --git a/Analysis/src/Linter.cpp b/Analysis/src/Linter.cpp index d5578446a..4250b3117 100644 --- a/Analysis/src/Linter.cpp +++ b/Analysis/src/Linter.cpp @@ -2144,14 +2144,14 @@ class LintDeprecatedApi : AstVisitor if (!ty) return true; - if (const ClassTypeVar* cty = get(follow(*ty))) + if (const ClassType* cty = get(follow(*ty))) { const Property* prop = lookupClassProp(cty, node->index.value); if (prop && prop->deprecated) report(node->location, *prop, cty->name.c_str(), node->index.value); } - else if (const TableTypeVar* tty = get(follow(*ty))) + else if (const TableType* tty = get(follow(*ty))) { auto prop = tty->props.find(node->index.value); @@ -2302,16 +2302,16 @@ class LintTableOperations : AstVisitor size_t getReturnCount(TypeId ty) { - if (auto ftv = get(ty)) + if (auto ftv = get(ty)) return size(ftv->retTypes); - if (auto itv = get(ty)) + if (auto itv = get(ty)) { // We don't process the type recursively to avoid having to deal with self-recursive intersection types size_t result = 0; for (TypeId part : itv->parts) - if (auto ftv = get(follow(part))) + if (auto ftv = get(follow(part))) result = std::max(result, size(ftv->retTypes)); return result; diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index 62674aa8e..a73b928bf 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -9,8 +9,8 @@ #include "Luau/Scope.h" #include "Luau/TypeInfer.h" #include "Luau/TypePack.h" -#include "Luau/TypeVar.h" -#include "Luau/VisitTypeVar.h" +#include "Luau/Type.h" +#include "Luau/VisitType.h" #include @@ -60,12 +60,12 @@ bool isWithinComment(const SourceModule& sourceModule, Position pos) struct ClonePublicInterface : Substitution { - NotNull singletonTypes; + NotNull builtinTypes; NotNull module; - ClonePublicInterface(const TxnLog* log, NotNull singletonTypes, Module* module) + ClonePublicInterface(const TxnLog* log, NotNull builtinTypes, Module* module) : Substitution(log, &module->interfaceTypes) - , singletonTypes(singletonTypes) + , builtinTypes(builtinTypes) , module(module) { LUAU_ASSERT(module); @@ -76,9 +76,9 @@ struct ClonePublicInterface : Substitution if (ty->owningArena == &module->internalTypes) return true; - if (const FunctionTypeVar* ftv = get(ty)) + if (const FunctionType* ftv = get(ty)) return ftv->level.level != 0; - if (const TableTypeVar* ttv = get(ty)) + if (const TableType* ttv = get(ty)) return ttv->level.level != 0; return false; } @@ -92,9 +92,9 @@ struct ClonePublicInterface : Substitution { TypeId result = clone(ty); - if (FunctionTypeVar* ftv = getMutable(result)) + if (FunctionType* ftv = getMutable(result)) ftv->level = TypeLevel{0, 0}; - else if (TableTypeVar* ttv = getMutable(result)) + else if (TableType* ttv = getMutable(result)) ttv->level = TypeLevel{0, 0}; return result; @@ -117,7 +117,7 @@ struct ClonePublicInterface : Substitution else { module->errors.push_back(TypeError{module->scopes[0].first, UnificationTooComplex{}}); - return singletonTypes->errorRecoveryType(); + return builtinTypes->errorRecoveryType(); } } @@ -133,7 +133,7 @@ struct ClonePublicInterface : Substitution else { module->errors.push_back(TypeError{module->scopes[0].first, UnificationTooComplex{}}); - return singletonTypes->errorRecoveryTypePack(); + return builtinTypes->errorRecoveryTypePack(); } } @@ -178,9 +178,9 @@ Module::~Module() unfreeze(internalTypes); } -void Module::clonePublicInterface(NotNull singletonTypes, InternalErrorReporter& ice) +void Module::clonePublicInterface(NotNull builtinTypes, InternalErrorReporter& ice) { - LUAU_ASSERT(interfaceTypes.typeVars.empty()); + LUAU_ASSERT(interfaceTypes.types.empty()); LUAU_ASSERT(interfaceTypes.typePacks.empty()); CloneState cloneState; @@ -192,7 +192,7 @@ void Module::clonePublicInterface(NotNull singletonTypes, Intern std::unordered_map* exportedTypeBindings = &moduleScope->exportedTypeBindings; TxnLog log; - ClonePublicInterface clonePublicInterface{&log, singletonTypes, this}; + ClonePublicInterface clonePublicInterface{&log, builtinTypes, this}; if (FFlag::LuauClonePublicInterfaceLess) returnType = clonePublicInterface.cloneTypePack(returnType); diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp index 09f0595d6..b66546595 100644 --- a/Analysis/src/Normalize.cpp +++ b/Analysis/src/Normalize.cpp @@ -8,7 +8,7 @@ #include "Luau/Clone.h" #include "Luau/Common.h" #include "Luau/RecursionCounter.h" -#include "Luau/TypeVar.h" +#include "Luau/Type.h" #include "Luau/Unifier.h" LUAU_FASTFLAGVARIABLE(DebugLuauCopyBeforeNormalizing, false) @@ -19,6 +19,7 @@ LUAU_FASTINTVARIABLE(LuauNormalizeIterationLimit, 1200); LUAU_FASTINTVARIABLE(LuauNormalizeCacheLimit, 100000); LUAU_FASTFLAGVARIABLE(LuauNormalizeCombineTableFix, false); LUAU_FASTFLAGVARIABLE(LuauTypeNormalization2, false); +LUAU_FASTFLAGVARIABLE(LuauNegatedClassTypes, false); LUAU_FASTFLAGVARIABLE(LuauNegatedFunctionTypes, false); LUAU_FASTFLAG(LuauUnknownAndNeverType) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) @@ -115,8 +116,7 @@ bool TypeIds::operator==(const TypeIds& there) const return hash == there.hash && types == there.types; } -NormalizedStringType::NormalizedStringType() -{} +NormalizedStringType::NormalizedStringType() {} NormalizedStringType::NormalizedStringType(bool isCofinite, std::map singletons) : isCofinite(isCofinite) @@ -186,6 +186,23 @@ bool isSubtype(const NormalizedStringType& subStr, const NormalizedStringType& s return true; } +void NormalizedClassType::pushPair(TypeId ty, TypeIds negations) +{ + ordering.push_back(ty); + classes.insert(std::make_pair(ty, std::move(negations))); +} + +void NormalizedClassType::resetToNever() +{ + ordering.clear(); + classes.clear(); +} + +bool NormalizedClassType::isNever() const +{ + return classes.empty(); +} + NormalizedFunctionType::NormalizedFunctionType() : parts(FFlag::LuauNegatedFunctionTypes ? std::optional{TypeIds{}} : std::nullopt) { @@ -208,22 +225,29 @@ bool NormalizedFunctionType::isNever() const return !isTop && (!parts || parts->empty()); } -NormalizedType::NormalizedType(NotNull singletonTypes) - : tops(singletonTypes->neverType) - , booleans(singletonTypes->neverType) - , errors(singletonTypes->neverType) - , nils(singletonTypes->neverType) - , numbers(singletonTypes->neverType) +NormalizedType::NormalizedType(NotNull builtinTypes) + : tops(builtinTypes->neverType) + , booleans(builtinTypes->neverType) + , errors(builtinTypes->neverType) + , nils(builtinTypes->neverType) + , numbers(builtinTypes->neverType) , strings{NormalizedStringType::never} - , threads(singletonTypes->neverType) + , threads(builtinTypes->neverType) { } static bool isShallowInhabited(const NormalizedType& norm) { + bool inhabitedClasses; + + if (FFlag::LuauNegatedClassTypes) + inhabitedClasses = !norm.classes.isNever(); + else + inhabitedClasses = !norm.DEPRECATED_classes.empty(); + // This test is just a shallow check, for example it returns `true` for `{ p : never }` - return !get(norm.tops) || !get(norm.booleans) || !norm.classes.empty() || !get(norm.errors) || - !get(norm.nils) || !get(norm.numbers) || !norm.strings.isNever() || !get(norm.threads) || + return !get(norm.tops) || !get(norm.booleans) || inhabitedClasses || !get(norm.errors) || + !get(norm.nils) || !get(norm.numbers) || !norm.strings.isNever() || !get(norm.threads) || !norm.functions.isNever() || !norm.tables.empty() || !norm.tyvars.empty(); } @@ -239,9 +263,15 @@ bool Normalizer::isInhabited(const NormalizedType* norm, std::unordered_set(norm->tops) || !get(norm->booleans) || !get(norm->errors) || - !get(norm->nils) || !get(norm->numbers) || !get(norm->threads) || - !norm->classes.empty() || !norm->strings.isNever() || !norm->functions.isNever()) + bool inhabitedClasses; + if (FFlag::LuauNegatedClassTypes) + inhabitedClasses = !norm->classes.isNever(); + else + inhabitedClasses = !norm->DEPRECATED_classes.empty(); + + if (!get(norm->tops) || !get(norm->booleans) || !get(norm->errors) || !get(norm->nils) || + !get(norm->numbers) || !get(norm->threads) || inhabitedClasses || !norm->strings.isNever() || + !norm->functions.isNever()) return true; for (const auto& [_, intersect] : norm->tyvars) @@ -264,10 +294,10 @@ bool Normalizer::isInhabited(TypeId ty, std::unordered_set seen) // TODO: use log.follow(ty), CLI-64291 ty = follow(ty); - if (get(ty)) + if (get(ty)) return false; - if (!get(ty) && !get(ty) && !get(ty) && !get(ty)) + if (!get(ty) && !get(ty) && !get(ty) && !get(ty)) return true; if (seen.count(ty)) @@ -275,7 +305,7 @@ bool Normalizer::isInhabited(TypeId ty, std::unordered_set seen) seen.insert(ty); - if (const TableTypeVar* ttv = get(ty)) + if (const TableType* ttv = get(ty)) { for (const auto& [_, prop] : ttv->props) { @@ -285,7 +315,7 @@ bool Normalizer::isInhabited(TypeId ty, std::unordered_set seen) return true; } - if (const MetatableTypeVar* mtv = get(ty)) + if (const MetatableType* mtv = get(ty)) return isInhabited(mtv->table, seen) && isInhabited(mtv->metatable, seen); const NormalizedType* norm = normalize(ty); @@ -294,28 +324,50 @@ bool Normalizer::isInhabited(TypeId ty, std::unordered_set seen) static int tyvarIndex(TypeId ty) { - if (const GenericTypeVar* gtv = get(ty)) + if (const GenericType* gtv = get(ty)) return gtv->index; - else if (const FreeTypeVar* ftv = get(ty)) + else if (const FreeType* ftv = get(ty)) return ftv->index; else return 0; } +static bool isTop(NotNull builtinTypes, const NormalizedClassType& classes) +{ + if (classes.classes.size() != 1) + return false; + + auto first = classes.classes.begin(); + if (first->first != builtinTypes->classType) + return false; + + if (!first->second.empty()) + return false; + + return true; +} + +static void resetToTop(NotNull builtinTypes, NormalizedClassType& classes) +{ + classes.ordering.clear(); + classes.classes.clear(); + classes.pushPair(builtinTypes->classType, TypeIds{}); +} + #ifdef LUAU_ASSERTENABLED static bool isNormalizedTop(TypeId ty) { - return get(ty) || get(ty) || get(ty); + return get(ty) || get(ty) || get(ty); } static bool isNormalizedBoolean(TypeId ty) { - if (get(ty)) + if (get(ty)) return true; - else if (const PrimitiveTypeVar* ptv = get(ty)) - return ptv->type == PrimitiveTypeVar::Boolean; - else if (const SingletonTypeVar* stv = get(ty)) + else if (const PrimitiveType* ptv = get(ty)) + return ptv->type == PrimitiveType::Boolean; + else if (const SingletonType* stv = get(ty)) return get(stv); else return false; @@ -323,7 +375,7 @@ static bool isNormalizedBoolean(TypeId ty) static bool isNormalizedError(TypeId ty) { - if (get(ty) || get(ty)) + if (get(ty) || get(ty)) return true; else return false; @@ -331,20 +383,20 @@ static bool isNormalizedError(TypeId ty) static bool isNormalizedNil(TypeId ty) { - if (get(ty)) + if (get(ty)) return true; - else if (const PrimitiveTypeVar* ptv = get(ty)) - return ptv->type == PrimitiveTypeVar::NilType; + else if (const PrimitiveType* ptv = get(ty)) + return ptv->type == PrimitiveType::NilType; else return false; } static bool isNormalizedNumber(TypeId ty) { - if (get(ty)) + if (get(ty)) return true; - else if (const PrimitiveTypeVar* ptv = get(ty)) - return ptv->type == PrimitiveTypeVar::Number; + else if (const PrimitiveType* ptv = get(ty)) + return ptv->type == PrimitiveType::Number; else return false; } @@ -356,7 +408,7 @@ static bool isNormalizedString(const NormalizedStringType& ty) for (auto& [str, ty] : ty.singletons) { - if (const SingletonTypeVar* stv = get(ty)) + if (const SingletonType* stv = get(ty)) { if (const StringSingleton* sstv = get(stv)) { @@ -375,10 +427,10 @@ static bool isNormalizedString(const NormalizedStringType& ty) static bool isNormalizedThread(TypeId ty) { - if (get(ty)) + if (get(ty)) return true; - else if (const PrimitiveTypeVar* ptv = get(ty)) - return ptv->type == PrimitiveTypeVar::Thread; + else if (const PrimitiveType* ptv = get(ty)) + return ptv->type == PrimitiveType::Thread; else return false; } @@ -389,7 +441,7 @@ static bool areNormalizedFunctions(const NormalizedFunctionType& tys) { for (TypeId ty : *tys.parts) { - if (!get(ty) && !get(ty)) + if (!get(ty) && !get(ty)) return false; } } @@ -399,7 +451,7 @@ static bool areNormalizedFunctions(const NormalizedFunctionType& tys) static bool areNormalizedTables(const TypeIds& tys) { for (TypeId ty : tys) - if (!get(ty) && !get(ty)) + if (!get(ty) && !get(ty)) return false; return true; } @@ -407,14 +459,68 @@ static bool areNormalizedTables(const TypeIds& tys) static bool areNormalizedClasses(const TypeIds& tys) { for (TypeId ty : tys) - if (!get(ty)) + if (!get(ty)) return false; return true; } +static bool areNormalizedClasses(const NormalizedClassType& tys) +{ + for (const auto& [ty, negations] : tys.classes) + { + const ClassType* ctv = get(ty); + if (!ctv) + { + return false; + } + + for (TypeId negation : negations) + { + const ClassType* nctv = get(negation); + if (!nctv) + { + return false; + } + + if (!isSubclass(nctv, ctv)) + { + return false; + } + } + + for (const auto& [otherTy, otherNegations] : tys.classes) + { + if (otherTy == ty) + continue; + + const ClassType* octv = get(otherTy); + if (!octv) + { + return false; + } + + if (isSubclass(ctv, octv)) + { + auto iss = [ctv](TypeId t) { + const ClassType* c = get(t); + if (!c) + return false; + + return isSubclass(ctv, c); + }; + + if (!std::any_of(otherNegations.begin(), otherNegations.end(), iss)) + return false; + } + } + } + + return true; +} + static bool isPlainTyvar(TypeId ty) { - return (get(ty) || get(ty)); + return (get(ty) || get(ty)); } static bool isNormalizedTyvar(const NormalizedTyvars& tyvars) @@ -442,6 +548,7 @@ static void assertInvariant(const NormalizedType& norm) LUAU_ASSERT(isNormalizedTop(norm.tops)); LUAU_ASSERT(isNormalizedBoolean(norm.booleans)); + LUAU_ASSERT(areNormalizedClasses(norm.DEPRECATED_classes)); LUAU_ASSERT(areNormalizedClasses(norm.classes)); LUAU_ASSERT(isNormalizedError(norm.errors)); LUAU_ASSERT(isNormalizedNil(norm.nils)); @@ -456,9 +563,9 @@ static void assertInvariant(const NormalizedType& norm) #endif } -Normalizer::Normalizer(TypeArena* arena, NotNull singletonTypes, NotNull sharedState) +Normalizer::Normalizer(TypeArena* arena, NotNull builtinTypes, NotNull sharedState) : arena(arena) - , singletonTypes(singletonTypes) + , builtinTypes(builtinTypes) , sharedState(sharedState) { } @@ -472,7 +579,7 @@ const NormalizedType* Normalizer::normalize(TypeId ty) if (found != cachedNormals.end()) return found->second.get(); - NormalizedType norm{singletonTypes}; + NormalizedType norm{builtinTypes}; if (!unionNormalWithTy(norm, ty)) return nullptr; std::unique_ptr uniq = std::make_unique(std::move(norm)); @@ -483,14 +590,15 @@ const NormalizedType* Normalizer::normalize(TypeId ty) void Normalizer::clearNormal(NormalizedType& norm) { - norm.tops = singletonTypes->neverType; - norm.booleans = singletonTypes->neverType; - norm.classes.clear(); - norm.errors = singletonTypes->neverType; - norm.nils = singletonTypes->neverType; - norm.numbers = singletonTypes->neverType; + norm.tops = builtinTypes->neverType; + norm.booleans = builtinTypes->neverType; + norm.classes.resetToNever(); + norm.DEPRECATED_classes.clear(); + norm.errors = builtinTypes->neverType; + norm.nils = builtinTypes->neverType; + norm.numbers = builtinTypes->neverType; norm.strings.resetToNever(); - norm.threads = singletonTypes->neverType; + norm.threads = builtinTypes->neverType; norm.tables.clear(); norm.functions.resetToNever(); norm.tyvars.clear(); @@ -516,14 +624,14 @@ TypeId Normalizer::unionType(TypeId here, TypeId there) if (here == there) return here; - if (get(here) || get(there)) + if (get(here) || get(there)) return there; - if (get(there) || get(here)) + if (get(there) || get(here)) return here; TypeIds tmps; - if (const UnionTypeVar* utv = get(here)) + if (const UnionType* utv = get(here)) { TypeIds heres; heres.insert(begin(utv), end(utv)); @@ -533,7 +641,7 @@ TypeId Normalizer::unionType(TypeId here, TypeId there) else tmps.insert(here); - if (const UnionTypeVar* utv = get(there)) + if (const UnionType* utv = get(there)) { TypeIds theres; theres.insert(begin(utv), end(utv)); @@ -549,7 +657,7 @@ TypeId Normalizer::unionType(TypeId here, TypeId there) std::vector parts; parts.insert(parts.end(), tmps.begin(), tmps.end()); - TypeId result = arena->addType(UnionTypeVar{std::move(parts)}); + TypeId result = arena->addType(UnionType{std::move(parts)}); cachedUnions[cacheTypeIds(std::move(tmps))] = result; return result; @@ -562,14 +670,14 @@ TypeId Normalizer::intersectionType(TypeId here, TypeId there) if (here == there) return here; - if (get(here) || get(there)) + if (get(here) || get(there)) return here; - if (get(there) || get(here)) + if (get(there) || get(here)) return there; TypeIds tmps; - if (const IntersectionTypeVar* utv = get(here)) + if (const IntersectionType* utv = get(here)) { TypeIds heres; heres.insert(begin(utv), end(utv)); @@ -579,7 +687,7 @@ TypeId Normalizer::intersectionType(TypeId here, TypeId there) else tmps.insert(here); - if (const IntersectionTypeVar* utv = get(there)) + if (const IntersectionType* utv = get(there)) { TypeIds theres; theres.insert(begin(utv), end(utv)); @@ -598,7 +706,7 @@ TypeId Normalizer::intersectionType(TypeId here, TypeId there) std::vector parts; parts.insert(parts.end(), tmps.begin(), tmps.end()); - TypeId result = arena->addType(IntersectionTypeVar{std::move(parts)}); + TypeId result = arena->addType(IntersectionType{std::move(parts)}); cachedIntersections[cacheTypeIds(std::move(tmps))] = result; return result; @@ -615,7 +723,7 @@ void Normalizer::clearCaches() // ------- Normalizing unions TypeId Normalizer::unionOfTops(TypeId here, TypeId there) { - if (get(here) || get(there)) + if (get(here) || get(there)) return there; else return here; @@ -623,15 +731,15 @@ TypeId Normalizer::unionOfTops(TypeId here, TypeId there) TypeId Normalizer::unionOfBools(TypeId here, TypeId there) { - if (get(here)) + if (get(here)) return there; - if (get(there)) + if (get(there)) return here; - if (const BooleanSingleton* hbool = get(get(here))) - if (const BooleanSingleton* tbool = get(get(there))) + if (const BooleanSingleton* hbool = get(get(here))) + if (const BooleanSingleton* tbool = get(get(there))) if (hbool->value == tbool->value) return here; - return singletonTypes->booleanType; + return builtinTypes->booleanType; } void Normalizer::unionClassesWithClass(TypeIds& heres, TypeId there) @@ -639,12 +747,12 @@ void Normalizer::unionClassesWithClass(TypeIds& heres, TypeId there) if (heres.count(there)) return; - const ClassTypeVar* tctv = get(there); + const ClassType* tctv = get(there); for (auto it = heres.begin(); it != heres.end();) { TypeId here = *it; - const ClassTypeVar* hctv = get(here); + const ClassType* hctv = get(here); if (isSubclass(tctv, hctv)) return; else if (isSubclass(hctv, tctv)) @@ -662,6 +770,184 @@ void Normalizer::unionClasses(TypeIds& heres, const TypeIds& theres) unionClassesWithClass(heres, there); } +static bool isSubclass(TypeId test, TypeId parent) +{ + const ClassType* testCtv = get(test); + const ClassType* parentCtv = get(parent); + + LUAU_ASSERT(testCtv); + LUAU_ASSERT(parentCtv); + + return isSubclass(testCtv, parentCtv); +} + +void Normalizer::unionClassesWithClass(NormalizedClassType& heres, TypeId there) +{ + for (auto it = heres.ordering.begin(); it != heres.ordering.end();) + { + TypeId hereTy = *it; + TypeIds& hereNegations = heres.classes.at(hereTy); + + // If the incoming class is a subclass of another class in the map, we + // must ensure that it is negated by one of the negations in the same + // cluster. If it isn't, we do not need to insert it - the subtyping + // relationship is already handled by this entry. If it is, we must + // insert it, to capture the presence of this particular subtype. + if (isSubclass(there, hereTy)) + { + for (auto nIt = hereNegations.begin(); nIt != hereNegations.end();) + { + TypeId hereNegation = *nIt; + + // If the incoming class is a subclass of one of the negations, + // we must insert it into the class map. + if (isSubclass(there, hereNegation)) + { + heres.pushPair(there, TypeIds{}); + return; + } + // If the incoming class is a superclass of one of the + // negations, then the negation no longer applies and must be + // removed. This is also true if they are equal. Since classes + // are, at this time, entirely persistent (we do not clone + // them), a pointer identity check is sufficient. + else if (isSubclass(hereNegation, there)) + { + nIt = hereNegations.erase(nIt); + } + // If the incoming class is unrelated to the negation, we move + // on to the next item. + else + { + ++nIt; + } + } + + // If, at the end of the above loop, we haven't returned, that means + // that the class is not a subclass of one of the negations, and is + // covered by the existing subtype relationship. We can return now. + return; + } + // If the incoming class is a superclass of another class in the map, we + // need to replace the existing class with the incoming class, + // preserving the relevant negations. + else if (isSubclass(hereTy, there)) + { + TypeIds negations = std::move(hereNegations); + it = heres.ordering.erase(it); + heres.classes.erase(hereTy); + + heres.pushPair(there, std::move(negations)); + return; + } + + // If the incoming class is unrelated to the class in the map, we move + // on. If we do not otherwise exit from this method body, we will + // eventually fall out of this loop and insert the incoming class, which + // we have proven to be completely unrelated to any class in the map, + // into the map itself. + ++it; + } + + heres.pushPair(there, TypeIds{}); +} + +void Normalizer::unionClasses(NormalizedClassType& heres, const NormalizedClassType& theres) +{ + // This method bears much similarity with unionClassesWithClass, but is + // solving a more general problem. In unionClassesWithClass, we are dealing + // with a singular positive type. Since it's one type, we can use early + // returns as control flow. Since it's guaranteed to be positive, we do not + // have negations to worry about combining. The two aspects combine to make + // the tasks this method must perform different enough to warrant a separate + // implementation. + + for (const TypeId thereTy : theres.ordering) + { + const TypeIds& thereNegations = theres.classes.at(thereTy); + + // If it happens that there are _no_ classes in the current map, or the + // incoming class is completely unrelated to any class in the current + // map, we must insert the incoming pair as-is. + bool insert = true; + + for (auto it = heres.ordering.begin(); it != heres.ordering.end();) + { + TypeId hereTy = *it; + TypeIds& hereNegations = heres.classes.at(hereTy); + + if (isSubclass(thereTy, hereTy)) + { + bool inserted = false; + for (auto nIt = hereNegations.begin(); nIt != hereNegations.end();) + { + TypeId hereNegateTy = *nIt; + + // If the incoming class is a subclass of one of the negations, + // we must insert it into the class map. + if (isSubclass(thereTy, hereNegateTy)) + { + // We do not concern ourselves with iterator + // invalidation here because we will break out of the + // loop over `heres` when `inserted` is set, and we do + // not read from the iterator after this point. + inserted = true; + heres.pushPair(thereTy, thereNegations); + break; + } + // If the incoming class is a superclass of one of the + // negations, then the negation no longer applies and must + // be removed. This is also true if they are equal. Since + // classes are, at this time, entirely persistent (we do not + // clone them), a pointer identity check is sufficient. + else if (isSubclass(hereNegateTy, thereTy)) + { + inserted = true; + nIt = hereNegations.erase(nIt); + break; + } + // If the incoming class is unrelated to the negation, we + // move on to the next item. + else + { + ++nIt; + } + } + + if (inserted) + { + insert = false; + break; + } + } + else if (isSubclass(hereTy, thereTy)) + { + TypeIds negations = std::move(hereNegations); + unionClasses(negations, thereNegations); + + it = heres.ordering.erase(it); + heres.classes.erase(hereTy); + heres.pushPair(thereTy, std::move(negations)); + insert = false; + break; + } + else if (hereTy == thereTy) + { + unionClasses(hereNegations, thereNegations); + insert = false; + break; + } + + ++it; + } + + if (insert) + { + heres.pushPair(thereTy, thereNegations); + } + } +} + void Normalizer::unionStrings(NormalizedStringType& here, const NormalizedStringType& there) { if (there.isString()) @@ -738,7 +1024,7 @@ std::optional Normalizer::unionOfTypePacks(TypePackId here, TypePack bool& thereSubHere) { if (ith != end(here)) { - TypeId tty = singletonTypes->nilType; + TypeId tty = builtinTypes->nilType; if (std::optional ttail = itt.tail()) { if (const VariadicTypePack* tvtp = get(*ttail)) @@ -834,15 +1120,15 @@ std::optional Normalizer::unionOfTypePacks(TypePackId here, TypePack std::optional Normalizer::unionOfFunctions(TypeId here, TypeId there) { - if (get(here)) + if (get(here)) return here; - if (get(there)) + if (get(there)) return there; - const FunctionTypeVar* hftv = get(here); + const FunctionType* hftv = get(here); LUAU_ASSERT(hftv); - const FunctionTypeVar* tftv = get(there); + const FunctionType* tftv = get(there); LUAU_ASSERT(tftv); if (hftv->generics != tftv->generics) @@ -863,7 +1149,7 @@ std::optional Normalizer::unionOfFunctions(TypeId here, TypeId there) if (*argTypes == tftv->argTypes && *retTypes == tftv->retTypes) return there; - FunctionTypeVar result{*argTypes, *retTypes}; + FunctionType result{*argTypes, *retTypes}; result.generics = hftv->generics; result.genericPacks = hftv->genericPacks; return arena->addType(std::move(result)); @@ -897,7 +1183,7 @@ void Normalizer::unionFunctions(NormalizedFunctionType& heres, const NormalizedF if (std::optional fun = unionOfFunctions(here, there)) tmps.insert(*fun); else - tmps.insert(singletonTypes->errorRecoveryType(there)); + tmps.insert(builtinTypes->errorRecoveryType(there)); } heres.parts = std::move(tmps); @@ -919,7 +1205,7 @@ void Normalizer::unionFunctionsWithFunction(NormalizedFunctionType& heres, TypeI if (std::optional fun = unionOfFunctions(here, there)) tmps.insert(*fun); else - tmps.insert(singletonTypes->errorRecoveryType(there)); + tmps.insert(builtinTypes->errorRecoveryType(there)); } heres.parts = std::move(tmps); } @@ -958,7 +1244,7 @@ void Normalizer::unionTables(TypeIds& heres, const TypeIds& theres) bool Normalizer::unionNormals(NormalizedType& here, const NormalizedType& there, int ignoreSmallerTyvars) { TypeId tops = unionOfTops(here.tops, there.tops); - if (!get(tops)) + if (!get(tops)) { clearNormal(here); here.tops = tops; @@ -972,7 +1258,7 @@ bool Normalizer::unionNormals(NormalizedType& here, const NormalizedType& there, int index = tyvarIndex(tyvar); if (index <= ignoreSmallerTyvars) continue; - auto [emplaced, fresh] = here.tyvars.emplace(tyvar, std::make_unique(NormalizedType{singletonTypes})); + auto [emplaced, fresh] = here.tyvars.emplace(tyvar, std::make_unique(NormalizedType{builtinTypes})); if (fresh) if (!unionNormals(*emplaced->second, here, index)) return false; @@ -981,12 +1267,16 @@ bool Normalizer::unionNormals(NormalizedType& here, const NormalizedType& there, } here.booleans = unionOfBools(here.booleans, there.booleans); - unionClasses(here.classes, there.classes); - here.errors = (get(there.errors) ? here.errors : there.errors); - here.nils = (get(there.nils) ? here.nils : there.nils); - here.numbers = (get(there.numbers) ? here.numbers : there.numbers); + if (FFlag::LuauNegatedClassTypes) + unionClasses(here.classes, there.classes); + else + unionClasses(here.DEPRECATED_classes, there.DEPRECATED_classes); + + here.errors = (get(there.errors) ? here.errors : there.errors); + here.nils = (get(there.nils) ? here.nils : there.nils); + here.numbers = (get(there.numbers) ? here.numbers : there.numbers); unionStrings(here.strings, there.strings); - here.threads = (get(there.threads) ? here.threads : there.threads); + here.threads = (get(there.threads) ? here.threads : there.threads); unionFunctions(here.functions, there.functions); unionTables(here.tables, there.tables); return true; @@ -1021,60 +1311,69 @@ bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, int ignor return false; there = follow(there); - if (get(there) || get(there)) + if (get(there) || get(there)) { TypeId tops = unionOfTops(here.tops, there); clearNormal(here); here.tops = tops; return true; } - else if (get(there) || !get(here.tops)) + else if (get(there) || !get(here.tops)) return true; - else if (const UnionTypeVar* utv = get(there)) + else if (const UnionType* utv = get(there)) { - for (UnionTypeVarIterator it = begin(utv); it != end(utv); ++it) + for (UnionTypeIterator it = begin(utv); it != end(utv); ++it) if (!unionNormalWithTy(here, *it)) return false; return true; } - else if (const IntersectionTypeVar* itv = get(there)) + else if (const IntersectionType* itv = get(there)) { - NormalizedType norm{singletonTypes}; - norm.tops = singletonTypes->anyType; - for (IntersectionTypeVarIterator it = begin(itv); it != end(itv); ++it) + NormalizedType norm{builtinTypes}; + norm.tops = builtinTypes->anyType; + for (IntersectionTypeIterator it = begin(itv); it != end(itv); ++it) if (!intersectNormalWithTy(norm, *it)) return false; return unionNormals(here, norm); } - else if (get(there) || get(there)) + else if (get(there) || get(there)) { if (tyvarIndex(there) <= ignoreSmallerTyvars) return true; - NormalizedType inter{singletonTypes}; - inter.tops = singletonTypes->unknownType; + NormalizedType inter{builtinTypes}; + inter.tops = builtinTypes->unknownType; here.tyvars.insert_or_assign(there, std::make_unique(std::move(inter))); } - else if (get(there)) + else if (get(there)) unionFunctionsWithFunction(here.functions, there); - else if (get(there) || get(there)) + else if (get(there) || get(there)) unionTablesWithTable(here.tables, there); - else if (get(there)) - unionClassesWithClass(here.classes, there); - else if (get(there)) + else if (get(there)) + { + if (FFlag::LuauNegatedClassTypes) + { + unionClassesWithClass(here.classes, there); + } + else + { + unionClassesWithClass(here.DEPRECATED_classes, there); + } + } + else if (get(there)) here.errors = there; - else if (const PrimitiveTypeVar* ptv = get(there)) + else if (const PrimitiveType* ptv = get(there)) { - if (ptv->type == PrimitiveTypeVar::Boolean) + if (ptv->type == PrimitiveType::Boolean) here.booleans = there; - else if (ptv->type == PrimitiveTypeVar::NilType) + else if (ptv->type == PrimitiveType::NilType) here.nils = there; - else if (ptv->type == PrimitiveTypeVar::Number) + else if (ptv->type == PrimitiveType::Number) here.numbers = there; - else if (ptv->type == PrimitiveTypeVar::String) + else if (ptv->type == PrimitiveType::String) here.strings.resetToString(); - else if (ptv->type == PrimitiveTypeVar::Thread) + else if (ptv->type == PrimitiveType::Thread) here.threads = there; - else if (ptv->type == PrimitiveTypeVar::Function) + else if (ptv->type == PrimitiveType::Function) { LUAU_ASSERT(FFlag::LuauNegatedFunctionTypes); here.functions.resetToTop(); @@ -1082,7 +1381,7 @@ bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, int ignor else LUAU_ASSERT(!"Unreachable"); } - else if (const SingletonTypeVar* stv = get(there)) + else if (const SingletonType* stv = get(there)) { if (get(stv)) here.booleans = unionOfBools(here.booleans, there); @@ -1100,7 +1399,7 @@ bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, int ignor else LUAU_ASSERT(!"Unreachable"); } - else if (const NegationTypeVar* ntv = get(there)) + else if (const NegationType* ntv = get(there)) { const NormalizedType* thereNormal = normalize(ntv->ty); std::optional tn = negateNormal(*thereNormal); @@ -1125,42 +1424,73 @@ bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, int ignor std::optional Normalizer::negateNormal(const NormalizedType& here) { - NormalizedType result{singletonTypes}; - if (!get(here.tops)) + NormalizedType result{builtinTypes}; + if (!get(here.tops)) { // The negation of unknown or any is never. Easy. return result; } - if (!get(here.errors)) + if (!get(here.errors)) { // Negating an error yields the same error. result.errors = here.errors; return result; } - if (get(here.booleans)) - result.booleans = singletonTypes->booleanType; - else if (get(here.booleans)) - result.booleans = singletonTypes->neverType; - else if (auto stv = get(here.booleans)) + if (get(here.booleans)) + result.booleans = builtinTypes->booleanType; + else if (get(here.booleans)) + result.booleans = builtinTypes->neverType; + else if (auto stv = get(here.booleans)) { auto boolean = get(stv); LUAU_ASSERT(boolean != nullptr); if (boolean->value) - result.booleans = singletonTypes->falseType; + result.booleans = builtinTypes->falseType; else - result.booleans = singletonTypes->trueType; + result.booleans = builtinTypes->trueType; } - result.classes = negateAll(here.classes); - result.nils = get(here.nils) ? singletonTypes->nilType : singletonTypes->neverType; - result.numbers = get(here.numbers) ? singletonTypes->numberType : singletonTypes->neverType; + if (FFlag::LuauNegatedClassTypes) + { + if (here.classes.isNever()) + { + resetToTop(builtinTypes, result.classes); + } + else if (isTop(builtinTypes, result.classes)) + { + result.classes.resetToNever(); + } + else + { + TypeIds rootNegations{}; + + for (const auto& [hereParent, hereNegations] : here.classes.classes) + { + if (hereParent != builtinTypes->classType) + rootNegations.insert(hereParent); + + for (TypeId hereNegation : hereNegations) + unionClassesWithClass(result.classes, hereNegation); + } + + if (!rootNegations.empty()) + result.classes.pushPair(builtinTypes->classType, rootNegations); + } + } + else + { + result.DEPRECATED_classes = negateAll(here.DEPRECATED_classes); + } + + result.nils = get(here.nils) ? builtinTypes->nilType : builtinTypes->neverType; + result.numbers = get(here.numbers) ? builtinTypes->numberType : builtinTypes->neverType; result.strings = here.strings; result.strings.isCofinite = !result.strings.isCofinite; - result.threads = get(here.threads) ? singletonTypes->threadType : singletonTypes->neverType; + result.threads = get(here.threads) ? builtinTypes->threadType : builtinTypes->neverType; /* * Things get weird and so, so complicated if we allow negations of @@ -1194,27 +1524,27 @@ TypeIds Normalizer::negateAll(const TypeIds& theres) TypeId Normalizer::negate(TypeId there) { there = follow(there); - if (get(there)) + if (get(there)) return there; - else if (get(there)) - return singletonTypes->neverType; - else if (get(there)) - return singletonTypes->unknownType; - else if (auto ntv = get(there)) + else if (get(there)) + return builtinTypes->neverType; + else if (get(there)) + return builtinTypes->unknownType; + else if (auto ntv = get(there)) return ntv->ty; // TODO: do we want to normalize this? - else if (auto utv = get(there)) + else if (auto utv = get(there)) { std::vector parts; for (TypeId option : utv) parts.push_back(negate(option)); - return arena->addType(IntersectionTypeVar{std::move(parts)}); + return arena->addType(IntersectionType{std::move(parts)}); } - else if (auto itv = get(there)) + else if (auto itv = get(there)) { std::vector options; for (TypeId part : itv) options.push_back(negate(part)); - return arena->addType(UnionTypeVar{std::move(options)}); + return arena->addType(UnionType{std::move(options)}); } else return there; @@ -1222,26 +1552,26 @@ TypeId Normalizer::negate(TypeId there) void Normalizer::subtractPrimitive(NormalizedType& here, TypeId ty) { - const PrimitiveTypeVar* ptv = get(follow(ty)); + const PrimitiveType* ptv = get(follow(ty)); LUAU_ASSERT(ptv); switch (ptv->type) { - case PrimitiveTypeVar::NilType: - here.nils = singletonTypes->neverType; + case PrimitiveType::NilType: + here.nils = builtinTypes->neverType; break; - case PrimitiveTypeVar::Boolean: - here.booleans = singletonTypes->neverType; + case PrimitiveType::Boolean: + here.booleans = builtinTypes->neverType; break; - case PrimitiveTypeVar::Number: - here.numbers = singletonTypes->neverType; + case PrimitiveType::Number: + here.numbers = builtinTypes->neverType; break; - case PrimitiveTypeVar::String: + case PrimitiveType::String: here.strings.resetToNever(); break; - case PrimitiveTypeVar::Thread: - here.threads = singletonTypes->neverType; + case PrimitiveType::Thread: + here.threads = builtinTypes->neverType; break; - case PrimitiveTypeVar::Function: + case PrimitiveType::Function: here.functions.resetToNever(); break; } @@ -1249,7 +1579,7 @@ void Normalizer::subtractPrimitive(NormalizedType& here, TypeId ty) void Normalizer::subtractSingleton(NormalizedType& here, TypeId ty) { - const SingletonTypeVar* stv = get(ty); + const SingletonType* stv = get(ty); LUAU_ASSERT(stv); if (const StringSingleton* ss = get(stv)) @@ -1265,13 +1595,13 @@ void Normalizer::subtractSingleton(NormalizedType& here, TypeId ty) } else if (const BooleanSingleton* bs = get(stv)) { - if (get(here.booleans)) + if (get(here.booleans)) { // Nothing } - else if (get(here.booleans)) - here.booleans = bs->value ? singletonTypes->falseType : singletonTypes->trueType; - else if (auto hereSingleton = get(here.booleans)) + else if (get(here.booleans)) + here.booleans = bs->value ? builtinTypes->falseType : builtinTypes->trueType; + else if (auto hereSingleton = get(here.booleans)) { const BooleanSingleton* hereBooleanSingleton = get(hereSingleton); LUAU_ASSERT(hereBooleanSingleton); @@ -1280,7 +1610,7 @@ void Normalizer::subtractSingleton(NormalizedType& here, TypeId ty) // negated out. We therefore reduce to never when the values match, // rather than when they differ. if (bs->value == hereBooleanSingleton->value) - here.booleans = singletonTypes->neverType; + here.booleans = builtinTypes->neverType; } else LUAU_ASSERT(!"Unreachable"); @@ -1292,7 +1622,7 @@ void Normalizer::subtractSingleton(NormalizedType& here, TypeId ty) // ------- Normalizing intersections TypeId Normalizer::intersectionOfTops(TypeId here, TypeId there) { - if (get(here) || get(there)) + if (get(here) || get(there)) return here; else return there; @@ -1300,30 +1630,30 @@ TypeId Normalizer::intersectionOfTops(TypeId here, TypeId there) TypeId Normalizer::intersectionOfBools(TypeId here, TypeId there) { - if (get(here)) + if (get(here)) return here; - if (get(there)) + if (get(there)) return there; - if (const BooleanSingleton* hbool = get(get(here))) - if (const BooleanSingleton* tbool = get(get(there))) - return (hbool->value == tbool->value ? here : singletonTypes->neverType); + if (const BooleanSingleton* hbool = get(get(here))) + if (const BooleanSingleton* tbool = get(get(there))) + return (hbool->value == tbool->value ? here : builtinTypes->neverType); else return here; else return there; } -void Normalizer::intersectClasses(TypeIds& heres, const TypeIds& theres) +void Normalizer::DEPRECATED_intersectClasses(TypeIds& heres, const TypeIds& theres) { TypeIds tmp; for (auto it = heres.begin(); it != heres.end();) { - const ClassTypeVar* hctv = get(*it); + const ClassType* hctv = get(*it); LUAU_ASSERT(hctv); bool keep = false; for (TypeId there : theres) { - const ClassTypeVar* tctv = get(there); + const ClassType* tctv = get(there); LUAU_ASSERT(tctv); if (isSubclass(hctv, tctv)) { @@ -1345,14 +1675,14 @@ void Normalizer::intersectClasses(TypeIds& heres, const TypeIds& theres) heres.insert(tmp.begin(), tmp.end()); } -void Normalizer::intersectClassesWithClass(TypeIds& heres, TypeId there) +void Normalizer::DEPRECATED_intersectClassesWithClass(TypeIds& heres, TypeId there) { bool foundSuper = false; - const ClassTypeVar* tctv = get(there); + const ClassType* tctv = get(there); LUAU_ASSERT(tctv); for (auto it = heres.begin(); it != heres.end();) { - const ClassTypeVar* hctv = get(*it); + const ClassType* hctv = get(*it); LUAU_ASSERT(hctv); if (isSubclass(hctv, tctv)) it++; @@ -1371,6 +1701,157 @@ void Normalizer::intersectClassesWithClass(TypeIds& heres, TypeId there) } } +void Normalizer::intersectClasses(NormalizedClassType& heres, const NormalizedClassType& theres) +{ + if (theres.isNever()) + { + heres.resetToNever(); + return; + } + else if (isTop(builtinTypes, theres)) + { + return; + } + + // For intersections of two distinct class sets, we must normalize to a map + // where, for each entry, one of the following is true: + // - The class is the superclass of all other classes in the map + // - The class is a subclass of another class B in the map _and_ a subclass + // of one of B's negations. + // + // Once we have identified the common superclass, we proceed down the list + // of class types. For each class and negation pair in the incoming set, we + // check each entry in the current set. + // - If the incoming class is exactly identical to a class in the current + // set, we union the negations together and move on. + // - If the incoming class is a subclass of a class in the current set, we + // replace the current class with the incoming class. We keep negations + // that are a subclass of the incoming class, and discard ones that + // aren't. + // - If the incoming class is a superclass of a class in the current set, we + // take the negations that are a subclass of the current class and union + // them with the negations for the current class. + // - If the incoming class is unrelated to any class in the current set, we + // declare the result of the intersection operation to be never. + for (const TypeId thereTy : theres.ordering) + { + const TypeIds& thereNegations = theres.classes.at(thereTy); + + for (auto it = heres.ordering.begin(); it != heres.ordering.end();) + { + TypeId hereTy = *it; + TypeIds& hereNegations = heres.classes.at(hereTy); + + if (isSubclass(thereTy, hereTy)) + { + TypeIds negations = std::move(hereNegations); + + for (auto nIt = negations.begin(); nIt != negations.end();) + { + if (!isSubclass(*nIt, thereTy)) + { + nIt = negations.erase(nIt); + } + else + { + ++nIt; + } + } + + unionClasses(negations, thereNegations); + + it = heres.ordering.erase(it); + heres.classes.erase(hereTy); + heres.pushPair(thereTy, std::move(negations)); + break; + } + else if (isSubclass(hereTy, thereTy)) + { + TypeIds negations = thereNegations; + + for (auto nIt = negations.begin(); nIt != negations.end();) + { + if (!isSubclass(*nIt, hereTy)) + { + nIt = negations.erase(nIt); + } + else + { + ++nIt; + } + } + + unionClasses(hereNegations, negations); + break; + } + else if (hereTy == thereTy) + { + unionClasses(hereNegations, thereNegations); + break; + } + else + { + it = heres.ordering.erase(it); + heres.classes.erase(hereTy); + } + } + } +} + +void Normalizer::intersectClassesWithClass(NormalizedClassType& heres, TypeId there) +{ + for (auto it = heres.ordering.begin(); it != heres.ordering.end();) + { + TypeId hereTy = *it; + const TypeIds& hereNegations = heres.classes.at(hereTy); + + // If the incoming class _is_ the current class, we skip it. Maybe + // another entry will have a different story. We check for this first + // because isSubclass will be true if the types are equal, and entering + // either of those branches below will trigger wrong behaviors. + if (hereTy == there) + { + ++it; + } + // If the incoming class is a subclass of this type, we replace the + // current class with the incoming class. We preserve negations that are + // a subclass of the incoming class, and discard ones that aren't. + else if (isSubclass(there, hereTy)) + { + TypeIds negations = std::move(hereNegations); + + for (auto nIt = negations.begin(); nIt != negations.end();) + { + if (!isSubclass(*nIt, there)) + { + nIt = negations.erase(nIt); + } + else + { + ++nIt; + } + } + + it = heres.ordering.erase(it); + heres.classes.erase(hereTy); + heres.pushPair(there, std::move(negations)); + } + // If the incoming class is a superclass of the current class, we don't + // insert it into the map. + else if (isSubclass(hereTy, there)) + { + return; + } + // If the incoming class is completely unrelated to the current class, + // we drop the current class from the map. + else + { + it = heres.ordering.erase(it); + heres.classes.erase(hereTy); + } + } +} + void Normalizer::intersectStrings(NormalizedStringType& here, const NormalizedStringType& there) { if (there.isString()) @@ -1419,7 +1900,7 @@ std::optional Normalizer::intersectionOfTypePacks(TypePackId here, T bool& thereSubHere) { if (ith != end(here)) { - TypeId tty = singletonTypes->nilType; + TypeId tty = builtinTypes->nilType; if (std::optional ttail = itt.tail()) { if (const VariadicTypePack* tvtp = get(*ttail)) @@ -1518,22 +1999,22 @@ std::optional Normalizer::intersectionOfTables(TypeId here, TypeId there TypeId htable = here; TypeId hmtable = nullptr; - if (const MetatableTypeVar* hmtv = get(here)) + if (const MetatableType* hmtv = get(here)) { htable = hmtv->table; hmtable = hmtv->metatable; } TypeId ttable = there; TypeId tmtable = nullptr; - if (const MetatableTypeVar* tmtv = get(there)) + if (const MetatableType* tmtv = get(there)) { ttable = tmtv->table; tmtable = tmtv->metatable; } - const TableTypeVar* httv = get(htable); + const TableType* httv = get(htable); LUAU_ASSERT(httv); - const TableTypeVar* tttv = get(ttable); + const TableType* tttv = get(ttable); LUAU_ASSERT(tttv); if (httv->state == TableState::Free || tttv->state == TableState::Free) @@ -1546,7 +2027,7 @@ std::optional Normalizer::intersectionOfTables(TypeId here, TypeId there state = tttv->state; TypeLevel level = max(httv->level, tttv->level); - TableTypeVar result{state, level}; + TableType result{state, level}; bool hereSubThere = true; bool thereSubHere = true; @@ -1616,7 +2097,7 @@ std::optional Normalizer::intersectionOfTables(TypeId here, TypeId there else if (table == ttable && *mtable == tmtable) return there; else - return arena->addType(MetatableTypeVar{table, *mtable}); + return arena->addType(MetatableType{table, *mtable}); } else return std::nullopt; @@ -1626,14 +2107,14 @@ std::optional Normalizer::intersectionOfTables(TypeId here, TypeId there if (table == htable) return here; else - return arena->addType(MetatableTypeVar{table, hmtable}); + return arena->addType(MetatableType{table, hmtable}); } else if (tmtable) { if (table == ttable) return there; else - return arena->addType(MetatableTypeVar{table, tmtable}); + return arena->addType(MetatableType{table, tmtable}); } else return table; @@ -1662,9 +2143,9 @@ void Normalizer::intersectTables(TypeIds& heres, const TypeIds& theres) std::optional Normalizer::intersectionOfFunctions(TypeId here, TypeId there) { - const FunctionTypeVar* hftv = get(here); + const FunctionType* hftv = get(here); LUAU_ASSERT(hftv); - const FunctionTypeVar* tftv = get(there); + const FunctionType* tftv = get(there); LUAU_ASSERT(tftv); if (hftv->generics != tftv->generics) @@ -1699,7 +2180,7 @@ std::optional Normalizer::intersectionOfFunctions(TypeId here, TypeId th if (argTypes == tftv->argTypes && retTypes == tftv->retTypes) return there; - FunctionTypeVar result{argTypes, retTypes}; + FunctionType result{argTypes, retTypes}; result.generics = hftv->generics; result.genericPacks = hftv->genericPacks; return arena->addType(std::move(result)); @@ -1796,10 +2277,10 @@ std::optional Normalizer::unionSaturatedFunctions(TypeId here, TypeId th // Proc. Principles and practice of declarative programming 2005, pp 198–208 // https://doi.org/10.1145/1069774.1069793 - const FunctionTypeVar* hftv = get(here); + const FunctionType* hftv = get(here); if (!hftv) return std::nullopt; - const FunctionTypeVar* tftv = get(there); + const FunctionType* tftv = get(there); if (!tftv) return std::nullopt; @@ -1815,7 +2296,7 @@ std::optional Normalizer::unionSaturatedFunctions(TypeId here, TypeId th if (!retTypes) return std::nullopt; - FunctionTypeVar result{*argTypes, *retTypes}; + FunctionType result{*argTypes, *retTypes}; result.generics = hftv->generics; result.genericPacks = hftv->genericPacks; return arena->addType(std::move(result)); @@ -1831,7 +2312,7 @@ void Normalizer::intersectFunctionsWithFunction(NormalizedFunctionType& heres, T for (auto it = heres.parts->begin(); it != heres.parts->end();) { TypeId here = *it; - if (get(here)) + if (get(here)) it++; else if (std::optional tmp = intersectionOfFunctions(here, there)) { @@ -1887,24 +2368,33 @@ bool Normalizer::intersectTyvarsWithTy(NormalizedTyvars& here, TypeId there) // See above for an explaination of `ignoreSmallerTyvars`. bool Normalizer::intersectNormals(NormalizedType& here, const NormalizedType& there, int ignoreSmallerTyvars) { - if (!get(there.tops)) + if (!get(there.tops)) { here.tops = intersectionOfTops(here.tops, there.tops); return true; } - else if (!get(here.tops)) + else if (!get(here.tops)) { clearNormal(here); return unionNormals(here, there, ignoreSmallerTyvars); } here.booleans = intersectionOfBools(here.booleans, there.booleans); - intersectClasses(here.classes, there.classes); - here.errors = (get(there.errors) ? there.errors : here.errors); - here.nils = (get(there.nils) ? there.nils : here.nils); - here.numbers = (get(there.numbers) ? there.numbers : here.numbers); + + if (FFlag::LuauNegatedClassTypes) + { + intersectClasses(here.classes, there.classes); + } + else + { + DEPRECATED_intersectClasses(here.DEPRECATED_classes, there.DEPRECATED_classes); + } + + here.errors = (get(there.errors) ? there.errors : here.errors); + here.nils = (get(there.nils) ? there.nils : here.nils); + here.numbers = (get(there.numbers) ? there.numbers : here.numbers); intersectStrings(here.strings, there.strings); - here.threads = (get(there.threads) ? there.threads : here.threads); + here.threads = (get(there.threads) ? there.threads : here.threads); intersectFunctions(here.functions, there.functions); intersectTables(here.tables, there.tables); @@ -1913,7 +2403,7 @@ bool Normalizer::intersectNormals(NormalizedType& here, const NormalizedType& th int index = tyvarIndex(tyvar); if (ignoreSmallerTyvars < index) { - auto [found, fresh] = here.tyvars.emplace(tyvar, std::make_unique(NormalizedType{singletonTypes})); + auto [found, fresh] = here.tyvars.emplace(tyvar, std::make_unique(NormalizedType{builtinTypes})); if (fresh) { if (!unionNormals(*found->second, here, index)) @@ -1953,70 +2443,80 @@ bool Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there) return false; there = follow(there); - if (get(there) || get(there)) + if (get(there) || get(there)) { here.tops = intersectionOfTops(here.tops, there); return true; } - else if (!get(here.tops)) + else if (!get(here.tops)) { clearNormal(here); return unionNormalWithTy(here, there); } - else if (const UnionTypeVar* utv = get(there)) + else if (const UnionType* utv = get(there)) { - NormalizedType norm{singletonTypes}; - for (UnionTypeVarIterator it = begin(utv); it != end(utv); ++it) + NormalizedType norm{builtinTypes}; + for (UnionTypeIterator it = begin(utv); it != end(utv); ++it) if (!unionNormalWithTy(norm, *it)) return false; return intersectNormals(here, norm); } - else if (const IntersectionTypeVar* itv = get(there)) + else if (const IntersectionType* itv = get(there)) { - for (IntersectionTypeVarIterator it = begin(itv); it != end(itv); ++it) + for (IntersectionTypeIterator it = begin(itv); it != end(itv); ++it) if (!intersectNormalWithTy(here, *it)) return false; return true; } - else if (get(there) || get(there)) + else if (get(there) || get(there)) { - NormalizedType thereNorm{singletonTypes}; - NormalizedType topNorm{singletonTypes}; - topNorm.tops = singletonTypes->unknownType; + NormalizedType thereNorm{builtinTypes}; + NormalizedType topNorm{builtinTypes}; + topNorm.tops = builtinTypes->unknownType; thereNorm.tyvars.insert_or_assign(there, std::make_unique(std::move(topNorm))); return intersectNormals(here, thereNorm); } NormalizedTyvars tyvars = std::move(here.tyvars); - if (const FunctionTypeVar* utv = get(there)) + if (const FunctionType* utv = get(there)) { NormalizedFunctionType functions = std::move(here.functions); clearNormal(here); intersectFunctionsWithFunction(functions, there); here.functions = std::move(functions); } - else if (get(there) || get(there)) + else if (get(there) || get(there)) { TypeIds tables = std::move(here.tables); clearNormal(here); intersectTablesWithTable(tables, there); here.tables = std::move(tables); } - else if (get(there)) + else if (get(there)) { - TypeIds classes = std::move(here.classes); - clearNormal(here); - intersectClassesWithClass(classes, there); - here.classes = std::move(classes); + if (FFlag::LuauNegatedClassTypes) + { + NormalizedClassType nct = std::move(here.classes); + clearNormal(here); + intersectClassesWithClass(nct, there); + here.classes = std::move(nct); + } + else + { + TypeIds classes = std::move(here.DEPRECATED_classes); + clearNormal(here); + DEPRECATED_intersectClassesWithClass(classes, there); + here.DEPRECATED_classes = std::move(classes); + } } - else if (get(there)) + else if (get(there)) { TypeId errors = here.errors; clearNormal(here); here.errors = errors; } - else if (const PrimitiveTypeVar* ptv = get(there)) + else if (const PrimitiveType* ptv = get(there)) { TypeId booleans = here.booleans; TypeId nils = here.nils; @@ -2027,17 +2527,17 @@ bool Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there) clearNormal(here); - if (ptv->type == PrimitiveTypeVar::Boolean) + if (ptv->type == PrimitiveType::Boolean) here.booleans = booleans; - else if (ptv->type == PrimitiveTypeVar::NilType) + else if (ptv->type == PrimitiveType::NilType) here.nils = nils; - else if (ptv->type == PrimitiveTypeVar::Number) + else if (ptv->type == PrimitiveType::Number) here.numbers = numbers; - else if (ptv->type == PrimitiveTypeVar::String) + else if (ptv->type == PrimitiveType::String) here.strings = std::move(strings); - else if (ptv->type == PrimitiveTypeVar::Thread) + else if (ptv->type == PrimitiveType::Thread) here.threads = threads; - else if (ptv->type == PrimitiveTypeVar::Function) + else if (ptv->type == PrimitiveType::Function) { LUAU_ASSERT(FFlag::LuauNegatedFunctionTypes); here.functions = std::move(functions); @@ -2045,7 +2545,7 @@ bool Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there) else LUAU_ASSERT(!"Unreachable"); } - else if (const SingletonTypeVar* stv = get(there)) + else if (const SingletonType* stv = get(there)) { TypeId booleans = here.booleans; NormalizedStringType strings = std::move(here.strings); @@ -2062,14 +2562,22 @@ bool Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there) else LUAU_ASSERT(!"Unreachable"); } - else if (const NegationTypeVar* ntv = get(there)) + else if (const NegationType* ntv = get(there)) { TypeId t = follow(ntv->ty); - if (const PrimitiveTypeVar* ptv = get(t)) + if (const PrimitiveType* ptv = get(t)) subtractPrimitive(here, ntv->ty); - else if (const SingletonTypeVar* stv = get(t)) + else if (const SingletonType* stv = get(t)) subtractSingleton(here, follow(ntv->ty)); - else if (const UnionTypeVar* itv = get(t)) + else if (get(t) && FFlag::LuauNegatedClassTypes) + { + const NormalizedType* normal = normalize(t); + std::optional negated = negateNormal(*normal); + if (!negated) + return false; + intersectNormals(here, *negated); + } + else if (const UnionType* itv = get(t)) { for (TypeId part : itv->options) { @@ -2087,6 +2595,10 @@ bool Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there) LUAU_ASSERT(!"Unimplemented"); } } + else if (get(there) && FFlag::LuauNegatedClassTypes) + { + here.classes.resetToNever(); + } else LUAU_ASSERT(!"Unreachable"); @@ -2101,18 +2613,67 @@ bool Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there) TypeId Normalizer::typeFromNormal(const NormalizedType& norm) { assertInvariant(norm); - if (!get(norm.tops)) + if (!get(norm.tops)) return norm.tops; std::vector result; - if (!get(norm.booleans)) + if (!get(norm.booleans)) result.push_back(norm.booleans); - result.insert(result.end(), norm.classes.begin(), norm.classes.end()); - if (!get(norm.errors)) + + if (FFlag::LuauNegatedClassTypes) + { + if (isTop(builtinTypes, norm.classes)) + { + result.push_back(builtinTypes->classType); + } + else if (!norm.classes.isNever()) + { + std::vector parts; + parts.reserve(norm.classes.classes.size()); + + for (const TypeId normTy : norm.classes.ordering) + { + const TypeIds& normNegations = norm.classes.classes.at(normTy); + + if (normNegations.empty()) + { + parts.push_back(normTy); + } + else + { + std::vector intersection; + intersection.reserve(normNegations.size() + 1); + + intersection.push_back(normTy); + for (TypeId negation : normNegations) + { + intersection.push_back(arena->addType(NegationType{negation})); + } + + parts.push_back(arena->addType(IntersectionType{std::move(intersection)})); + } + } + + if (parts.size() == 1) + { + result.push_back(parts.at(0)); + } + else if (parts.size() > 1) + { + result.push_back(arena->addType(UnionType{std::move(parts)})); + } + } + } + else + { + result.insert(result.end(), norm.DEPRECATED_classes.begin(), norm.DEPRECATED_classes.end()); + } + + if (!get(norm.errors)) result.push_back(norm.errors); if (FFlag::LuauNegatedFunctionTypes && norm.functions.isTop) - result.push_back(singletonTypes->functionType); + result.push_back(builtinTypes->functionType); else if (!norm.functions.isNever()) { if (norm.functions.parts->size() == 1) @@ -2121,15 +2682,15 @@ TypeId Normalizer::typeFromNormal(const NormalizedType& norm) { std::vector parts; parts.insert(parts.end(), norm.functions.parts->begin(), norm.functions.parts->end()); - result.push_back(arena->addType(IntersectionTypeVar{std::move(parts)})); + result.push_back(arena->addType(IntersectionType{std::move(parts)})); } } - if (!get(norm.nils)) + if (!get(norm.nils)) result.push_back(norm.nils); - if (!get(norm.numbers)) + if (!get(norm.numbers)) result.push_back(norm.numbers); if (norm.strings.isString()) - result.push_back(singletonTypes->stringType); + result.push_back(builtinTypes->stringType); else if (norm.strings.isUnion()) { for (auto& [_, ty] : norm.strings.singletons) @@ -2138,40 +2699,40 @@ TypeId Normalizer::typeFromNormal(const NormalizedType& norm) else if (norm.strings.isIntersection()) { std::vector parts; - parts.push_back(singletonTypes->stringType); + parts.push_back(builtinTypes->stringType); for (const auto& [name, ty] : norm.strings.singletons) - parts.push_back(arena->addType(NegationTypeVar{ty})); + parts.push_back(arena->addType(NegationType{ty})); - result.push_back(arena->addType(IntersectionTypeVar{std::move(parts)})); + result.push_back(arena->addType(IntersectionType{std::move(parts)})); } - if (!get(norm.threads)) - result.push_back(singletonTypes->threadType); + if (!get(norm.threads)) + result.push_back(builtinTypes->threadType); result.insert(result.end(), norm.tables.begin(), norm.tables.end()); for (auto& [tyvar, intersect] : norm.tyvars) { - if (get(intersect->tops)) + if (get(intersect->tops)) { TypeId ty = typeFromNormal(*intersect); - result.push_back(arena->addType(IntersectionTypeVar{{tyvar, ty}})); + result.push_back(arena->addType(IntersectionType{{tyvar, ty}})); } else result.push_back(tyvar); } if (result.size() == 0) - return singletonTypes->neverType; + return builtinTypes->neverType; else if (result.size() == 1) return result[0]; else - return arena->addType(UnionTypeVar{std::move(result)}); + return arena->addType(UnionType{std::move(result)}); } -bool isSubtype(TypeId subTy, TypeId superTy, NotNull scope, NotNull singletonTypes, InternalErrorReporter& ice) +bool isSubtype(TypeId subTy, TypeId superTy, NotNull scope, NotNull builtinTypes, InternalErrorReporter& ice) { UnifierSharedState sharedState{&ice}; TypeArena arena; - Normalizer normalizer{&arena, singletonTypes, NotNull{&sharedState}}; + Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}}; Unifier u{NotNull{&normalizer}, Mode::Strict, scope, Location{}, Covariant}; u.tryUnify(subTy, superTy); @@ -2179,11 +2740,11 @@ bool isSubtype(TypeId subTy, TypeId superTy, NotNull scope, NotNull scope, NotNull singletonTypes, InternalErrorReporter& ice) +bool isSubtype(TypePackId subPack, TypePackId superPack, NotNull scope, NotNull builtinTypes, InternalErrorReporter& ice) { UnifierSharedState sharedState{&ice}; TypeArena arena; - Normalizer normalizer{&arena, singletonTypes, NotNull{&sharedState}}; + Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}}; Unifier u{NotNull{&normalizer}, Mode::Strict, scope, Location{}, Covariant}; u.tryUnify(subPack, superPack); diff --git a/Analysis/src/Quantify.cpp b/Analysis/src/Quantify.cpp index e9de094b8..22c5875be 100644 --- a/Analysis/src/Quantify.cpp +++ b/Analysis/src/Quantify.cpp @@ -5,8 +5,8 @@ #include "Luau/Scope.h" #include "Luau/Substitution.h" #include "Luau/TxnLog.h" -#include "Luau/TypeVar.h" -#include "Luau/VisitTypeVar.h" +#include "Luau/Type.h" +#include "Luau/VisitType.h" LUAU_FASTFLAG(DebugLuauSharedSelf) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); @@ -15,7 +15,7 @@ LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution) namespace Luau { -struct Quantifier final : TypeVarOnceVisitor +struct Quantifier final : TypeOnceVisitor { TypeLevel level; std::vector generics; @@ -43,24 +43,24 @@ struct Quantifier final : TypeVarOnceVisitor return false; } - bool visit(TypeId ty, const FreeTypeVar& ftv) override + bool visit(TypeId ty, const FreeType& ftv) override { seenMutableType = true; if (!level.subsumes(ftv.level)) return false; - *asMutable(ty) = GenericTypeVar{level}; + *asMutable(ty) = GenericType{level}; generics.push_back(ty); return false; } - bool visit(TypeId ty, const TableTypeVar&) override + bool visit(TypeId ty, const TableType&) override { - LUAU_ASSERT(getMutable(ty)); - TableTypeVar& ttv = *getMutable(ty); + LUAU_ASSERT(getMutable(ty)); + TableType& ttv = *getMutable(ty); if (ttv.state == TableState::Generic) seenGenericType = true; @@ -117,7 +117,7 @@ void quantify(TypeId ty, TypeLevel level) for (const auto& [_, prop] : ttv->props) { - auto ftv = getMutable(follow(prop.type)); + auto ftv = getMutable(follow(prop.type)); if (!ftv || !ftv->hasSelf) continue; @@ -128,7 +128,7 @@ void quantify(TypeId ty, TypeLevel level) } } } - else if (auto ftv = getMutable(ty)) + else if (auto ftv = getMutable(ty)) { Quantifier q{level}; q.traverse(ty); @@ -145,7 +145,7 @@ void quantify(TypeId ty, TypeLevel level) Quantifier q{level}; q.traverse(ty); - FunctionTypeVar* ftv = getMutable(ty); + FunctionType* ftv = getMutable(ty); LUAU_ASSERT(ftv); ftv->generics.insert(ftv->generics.end(), q.generics.begin(), q.generics.end()); ftv->genericPacks.insert(ftv->genericPacks.end(), q.genericPacks.begin(), q.genericPacks.end()); @@ -168,11 +168,11 @@ struct PureQuantifier : Substitution { LUAU_ASSERT(ty == follow(ty)); - if (auto ftv = get(ty)) + if (auto ftv = get(ty)) { return subsumes(scope, ftv->scope); } - else if (auto ttv = get(ty)) + else if (auto ttv = get(ty)) { return ttv->state == TableState::Free && subsumes(scope, ttv->scope); } @@ -192,16 +192,16 @@ struct PureQuantifier : Substitution TypeId clean(TypeId ty) override { - if (auto ftv = get(ty)) + if (auto ftv = get(ty)) { - TypeId result = arena->addType(GenericTypeVar{scope}); + TypeId result = arena->addType(GenericType{scope}); insertedGenerics.push_back(result); return result; } - else if (auto ttv = get(ty)) + else if (auto ttv = get(ty)) { - TypeId result = arena->addType(TableTypeVar{}); - TableTypeVar* resultTable = getMutable(result); + TypeId result = arena->addType(TableType{}); + TableType* resultTable = getMutable(result); LUAU_ASSERT(resultTable); *resultTable = *ttv; @@ -229,7 +229,7 @@ struct PureQuantifier : Substitution bool ignoreChildren(TypeId ty) override { - if (FFlag::LuauClassTypeVarsInSubstitution && get(ty)) + if (FFlag::LuauClassTypeVarsInSubstitution && get(ty)) return true; return ty->persistent; @@ -246,7 +246,7 @@ TypeId quantify(TypeArena* arena, TypeId ty, Scope* scope) std::optional result = quantifier.substitute(ty); LUAU_ASSERT(result); - FunctionTypeVar* ftv = getMutable(*result); + FunctionType* ftv = getMutable(*result); LUAU_ASSERT(ftv); ftv->scope = scope; ftv->generics.insert(ftv->generics.end(), quantifier.insertedGenerics.begin(), quantifier.insertedGenerics.end()); diff --git a/Analysis/src/Substitution.cpp b/Analysis/src/Substitution.cpp index 20ed34f6c..2469152eb 100644 --- a/Analysis/src/Substitution.cpp +++ b/Analysis/src/Substitution.cpp @@ -28,7 +28,7 @@ void Tarjan::visitChildren(TypeId ty, int index) if (auto pty = log->pending(ty)) ty = &pty->pending; - if (const FunctionTypeVar* ftv = get(ty)) + if (const FunctionType* ftv = get(ty)) { if (FFlag::LuauSubstitutionFixMissingFields) { @@ -41,7 +41,7 @@ void Tarjan::visitChildren(TypeId ty, int index) visitChild(ftv->argTypes); visitChild(ftv->retTypes); } - else if (const TableTypeVar* ttv = get(ty)) + else if (const TableType* ttv = get(ty)) { LUAU_ASSERT(!ttv->boundTo); for (const auto& [name, prop] : ttv->props) @@ -58,22 +58,22 @@ void Tarjan::visitChildren(TypeId ty, int index) for (TypePackId itp : ttv->instantiatedTypePackParams) visitChild(itp); } - else if (const MetatableTypeVar* mtv = get(ty)) + else if (const MetatableType* mtv = get(ty)) { visitChild(mtv->table); visitChild(mtv->metatable); } - else if (const UnionTypeVar* utv = get(ty)) + else if (const UnionType* utv = get(ty)) { for (TypeId opt : utv->options) visitChild(opt); } - else if (const IntersectionTypeVar* itv = get(ty)) + else if (const IntersectionType* itv = get(ty)) { for (TypeId part : itv->parts) visitChild(part); } - else if (const PendingExpansionTypeVar* petv = get(ty)) + else if (const PendingExpansionType* petv = get(ty)) { for (TypeId a : petv->typeArguments) visitChild(a); @@ -81,7 +81,7 @@ void Tarjan::visitChildren(TypeId ty, int index) for (TypePackId a : petv->packArguments) visitChild(a); } - else if (const ClassTypeVar* ctv = get(ty); FFlag::LuauClassTypeVarsInSubstitution && ctv) + else if (const ClassType* ctv = get(ty); FFlag::LuauClassTypeVarsInSubstitution && ctv) { for (auto [name, prop] : ctv->props) visitChild(prop.type); @@ -92,7 +92,7 @@ void Tarjan::visitChildren(TypeId ty, int index) if (ctv->metatable) visitChild(*ctv->metatable); } - else if (const NegationTypeVar* ntv = get(ty)) + else if (const NegationType* ntv = get(ty)) { visitChild(ntv->ty); } @@ -559,7 +559,7 @@ void Substitution::replaceChildren(TypeId ty) if (ty->owningArena != arena) return; - if (FunctionTypeVar* ftv = getMutable(ty)) + if (FunctionType* ftv = getMutable(ty)) { if (FFlag::LuauSubstitutionFixMissingFields) { @@ -572,7 +572,7 @@ void Substitution::replaceChildren(TypeId ty) ftv->argTypes = replace(ftv->argTypes); ftv->retTypes = replace(ftv->retTypes); } - else if (TableTypeVar* ttv = getMutable(ty)) + else if (TableType* ttv = getMutable(ty)) { LUAU_ASSERT(!ttv->boundTo); for (auto& [name, prop] : ttv->props) @@ -589,22 +589,22 @@ void Substitution::replaceChildren(TypeId ty) for (TypePackId& itp : ttv->instantiatedTypePackParams) itp = replace(itp); } - else if (MetatableTypeVar* mtv = getMutable(ty)) + else if (MetatableType* mtv = getMutable(ty)) { mtv->table = replace(mtv->table); mtv->metatable = replace(mtv->metatable); } - else if (UnionTypeVar* utv = getMutable(ty)) + else if (UnionType* utv = getMutable(ty)) { for (TypeId& opt : utv->options) opt = replace(opt); } - else if (IntersectionTypeVar* itv = getMutable(ty)) + else if (IntersectionType* itv = getMutable(ty)) { for (TypeId& part : itv->parts) part = replace(part); } - else if (PendingExpansionTypeVar* petv = getMutable(ty)) + else if (PendingExpansionType* petv = getMutable(ty)) { for (TypeId& a : petv->typeArguments) a = replace(a); @@ -612,7 +612,7 @@ void Substitution::replaceChildren(TypeId ty) for (TypePackId& a : petv->packArguments) a = replace(a); } - else if (ClassTypeVar* ctv = getMutable(ty); FFlag::LuauClassTypeVarsInSubstitution && ctv) + else if (ClassType* ctv = getMutable(ty); FFlag::LuauClassTypeVarsInSubstitution && ctv) { for (auto& [name, prop] : ctv->props) prop.type = replace(prop.type); @@ -623,7 +623,7 @@ void Substitution::replaceChildren(TypeId ty) if (ctv->metatable) ctv->metatable = replace(*ctv->metatable); } - else if (NegationTypeVar* ntv = getMutable(ty)) + else if (NegationType* ntv = getMutable(ty)) { ntv->ty = replace(ntv->ty); } diff --git a/Analysis/src/ToDot.cpp b/Analysis/src/ToDot.cpp index 68fa53931..117d39d20 100644 --- a/Analysis/src/ToDot.cpp +++ b/Analysis/src/ToDot.cpp @@ -3,7 +3,7 @@ #include "Luau/ToString.h" #include "Luau/TypePack.h" -#include "Luau/TypeVar.h" +#include "Luau/Type.h" #include "Luau/StringUtils.h" #include @@ -49,10 +49,10 @@ struct StateDot bool StateDot::canDuplicatePrimitive(TypeId ty) { - if (get(ty)) + if (get(ty)) return false; - return get(ty) || get(ty); + return get(ty) || get(ty); } void StateDot::visitChild(TypeId ty, int parentIndex, const char* linkName) @@ -72,9 +72,9 @@ void StateDot::visitChild(TypeId ty, int parentIndex, const char* linkName) if (opts.duplicatePrimitives && canDuplicatePrimitive(ty)) { - if (get(ty)) + if (get(ty)) formatAppend(result, "n%d [label=\"%s\"];\n", index, toString(ty).c_str()); - else if (get(ty)) + else if (get(ty)) formatAppend(result, "n%d [label=\"any\"];\n", index); } else @@ -139,31 +139,31 @@ void StateDot::visitChildren(TypeId ty, int index) startNode(index); startNodeLabel(); - if (const BoundTypeVar* btv = get(ty)) + if (const BoundType* btv = get(ty)) { - formatAppend(result, "BoundTypeVar %d", index); + formatAppend(result, "BoundType %d", index); finishNodeLabel(ty); finishNode(); visitChild(btv->boundTo, index); } - else if (const FunctionTypeVar* ftv = get(ty)) + else if (const FunctionType* ftv = get(ty)) { - formatAppend(result, "FunctionTypeVar %d", index); + formatAppend(result, "FunctionType %d", index); finishNodeLabel(ty); finishNode(); visitChild(ftv->argTypes, index, "arg"); visitChild(ftv->retTypes, index, "ret"); } - else if (const TableTypeVar* ttv = get(ty)) + else if (const TableType* ttv = get(ty)) { if (ttv->name) - formatAppend(result, "TableTypeVar %s", ttv->name->c_str()); + formatAppend(result, "TableType %s", ttv->name->c_str()); else if (ttv->syntheticName) - formatAppend(result, "TableTypeVar %s", ttv->syntheticName->c_str()); + formatAppend(result, "TableType %s", ttv->syntheticName->c_str()); else - formatAppend(result, "TableTypeVar %d", index); + formatAppend(result, "TableType %d", index); finishNodeLabel(ty); finishNode(); @@ -183,69 +183,69 @@ void StateDot::visitChildren(TypeId ty, int index) for (TypePackId itp : ttv->instantiatedTypePackParams) visitChild(itp, index, "typePackParam"); } - else if (const MetatableTypeVar* mtv = get(ty)) + else if (const MetatableType* mtv = get(ty)) { - formatAppend(result, "MetatableTypeVar %d", index); + formatAppend(result, "MetatableType %d", index); finishNodeLabel(ty); finishNode(); visitChild(mtv->table, index, "table"); visitChild(mtv->metatable, index, "metatable"); } - else if (const UnionTypeVar* utv = get(ty)) + else if (const UnionType* utv = get(ty)) { - formatAppend(result, "UnionTypeVar %d", index); + formatAppend(result, "UnionType %d", index); finishNodeLabel(ty); finishNode(); for (TypeId opt : utv->options) visitChild(opt, index); } - else if (const IntersectionTypeVar* itv = get(ty)) + else if (const IntersectionType* itv = get(ty)) { - formatAppend(result, "IntersectionTypeVar %d", index); + formatAppend(result, "IntersectionType %d", index); finishNodeLabel(ty); finishNode(); for (TypeId part : itv->parts) visitChild(part, index); } - else if (const GenericTypeVar* gtv = get(ty)) + else if (const GenericType* gtv = get(ty)) { if (gtv->explicitName) - formatAppend(result, "GenericTypeVar %s", gtv->name.c_str()); + formatAppend(result, "GenericType %s", gtv->name.c_str()); else - formatAppend(result, "GenericTypeVar %d", index); + formatAppend(result, "GenericType %d", index); finishNodeLabel(ty); finishNode(); } - else if (const FreeTypeVar* ftv = get(ty)) + else if (const FreeType* ftv = get(ty)) { - formatAppend(result, "FreeTypeVar %d", index); + formatAppend(result, "FreeType %d", index); finishNodeLabel(ty); finishNode(); } - else if (get(ty)) + else if (get(ty)) { - formatAppend(result, "AnyTypeVar %d", index); + formatAppend(result, "AnyType %d", index); finishNodeLabel(ty); finishNode(); } - else if (get(ty)) + else if (get(ty)) { - formatAppend(result, "PrimitiveTypeVar %s", toString(ty).c_str()); + formatAppend(result, "PrimitiveType %s", toString(ty).c_str()); finishNodeLabel(ty); finishNode(); } - else if (get(ty)) + else if (get(ty)) { - formatAppend(result, "ErrorTypeVar %d", index); + formatAppend(result, "ErrorType %d", index); finishNodeLabel(ty); finishNode(); } - else if (const ClassTypeVar* ctv = get(ty)) + else if (const ClassType* ctv = get(ty)) { - formatAppend(result, "ClassTypeVar %s", ctv->name.c_str()); + formatAppend(result, "ClassType %s", ctv->name.c_str()); finishNodeLabel(ty); finishNode(); @@ -258,7 +258,7 @@ void StateDot::visitChildren(TypeId ty, int index) if (ctv->metatable) visitChild(*ctv->metatable, index, "[metatable]"); } - else if (const SingletonTypeVar* stv = get(ty)) + else if (const SingletonType* stv = get(ty)) { std::string res; @@ -276,7 +276,7 @@ void StateDot::visitChildren(TypeId ty, int index) else LUAU_ASSERT(!"unknown singleton type"); - formatAppend(result, "SingletonTypeVar %s", res.c_str()); + formatAppend(result, "SingletonType %s", res.c_str()); finishNodeLabel(ty); finishNode(); } diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index ed7c682d6..e80085089 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -4,10 +4,11 @@ #include "Luau/Constraint.h" #include "Luau/Location.h" #include "Luau/Scope.h" +#include "Luau/TxnLog.h" #include "Luau/TypeInfer.h" #include "Luau/TypePack.h" -#include "Luau/TypeVar.h" -#include "Luau/VisitTypeVar.h" +#include "Luau/Type.h" +#include "Luau/VisitType.h" #include #include @@ -17,6 +18,7 @@ LUAU_FASTFLAG(LuauUnknownAndNeverType) LUAU_FASTFLAGVARIABLE(LuauLineBreaksDetermineIndents, false) LUAU_FASTFLAGVARIABLE(LuauFunctionReturnStringificationFixup, false) LUAU_FASTFLAGVARIABLE(LuauUnseeArrayTtv, false) +LUAU_FASTFLAGVARIABLE(LuauSerializeNilUnionAsNil, false) /* * Prefix generic typenames with gen- @@ -31,7 +33,7 @@ namespace Luau namespace { -struct FindCyclicTypes final : TypeVarVisitor +struct FindCyclicTypes final : TypeVisitor { FindCyclicTypes() = default; FindCyclicTypes(const FindCyclicTypes&) = delete; @@ -63,7 +65,7 @@ struct FindCyclicTypes final : TypeVarVisitor return visitedPacks.insert(tp).second; } - bool visit(TypeId ty, const TableTypeVar& ttv) override + bool visit(TypeId ty, const TableType& ttv) override { if (!visited.insert(ty).second) return false; @@ -82,7 +84,7 @@ struct FindCyclicTypes final : TypeVarVisitor return true; } - bool visit(TypeId ty, const ClassTypeVar&) override + bool visit(TypeId ty, const ClassType&) override { return false; } @@ -136,7 +138,7 @@ struct StringifierState , result(result) , exhaustive(opts.exhaustive) { - for (const auto& [_, v] : opts.nameMap.typeVars) + for (const auto& [_, v] : opts.nameMap.types) usedNames.insert(v); for (const auto& [_, v] : opts.nameMap.typePacks) usedNames.insert(v); @@ -162,8 +164,8 @@ struct StringifierState std::string getName(TypeId ty) { - const size_t s = opts.nameMap.typeVars.size(); - std::string& n = opts.nameMap.typeVars[ty]; + const size_t s = opts.nameMap.types.size(); + std::string& n = opts.nameMap.types[ty]; if (!n.empty()) return n; @@ -291,11 +293,11 @@ struct StringifierState } }; -struct TypeVarStringifier +struct TypeStringifier { StringifierState& state; - explicit TypeVarStringifier(StringifierState& state) + explicit TypeStringifier(StringifierState& state) : state(state) { } @@ -392,17 +394,17 @@ struct TypeVarStringifier } } - void operator()(TypeId, const BoundTypeVar& btv) + void operator()(TypeId, const BoundType& btv) { stringify(btv.boundTo); } - void operator()(TypeId ty, const GenericTypeVar& gtv) + void operator()(TypeId ty, const GenericType& gtv) { if (gtv.explicitName) { state.usedNames.insert(gtv.name); - state.opts.nameMap.typeVars[ty] = gtv.name; + state.opts.nameMap.types[ty] = gtv.name; state.emit(gtv.name); } else @@ -418,40 +420,40 @@ struct TypeVarStringifier } } - void operator()(TypeId, const BlockedTypeVar& btv) + void operator()(TypeId, const BlockedType& btv) { state.emit("*blocked-"); state.emit(btv.index); state.emit("*"); } - void operator()(TypeId ty, const PendingExpansionTypeVar& petv) + void operator()(TypeId ty, const PendingExpansionType& petv) { state.emit("*pending-expansion-"); state.emit(petv.index); state.emit("*"); } - void operator()(TypeId, const PrimitiveTypeVar& ptv) + void operator()(TypeId, const PrimitiveType& ptv) { switch (ptv.type) { - case PrimitiveTypeVar::NilType: + case PrimitiveType::NilType: state.emit("nil"); return; - case PrimitiveTypeVar::Boolean: + case PrimitiveType::Boolean: state.emit("boolean"); return; - case PrimitiveTypeVar::Number: + case PrimitiveType::Number: state.emit("number"); return; - case PrimitiveTypeVar::String: + case PrimitiveType::String: state.emit("string"); return; - case PrimitiveTypeVar::Thread: + case PrimitiveType::Thread: state.emit("thread"); return; - case PrimitiveTypeVar::Function: + case PrimitiveType::Function: state.emit("function"); return; default: @@ -460,7 +462,7 @@ struct TypeVarStringifier } } - void operator()(TypeId, const SingletonTypeVar& stv) + void operator()(TypeId, const SingletonType& stv) { if (const BooleanSingleton* bs = Luau::get(&stv)) state.emit(bs->value ? "true" : "false"); @@ -477,7 +479,7 @@ struct TypeVarStringifier } } - void operator()(TypeId, const FunctionTypeVar& ftv) + void operator()(TypeId, const FunctionType& ftv) { if (state.hasSeen(&ftv)) { @@ -539,7 +541,7 @@ struct TypeVarStringifier state.unsee(&ftv); } - void operator()(TypeId, const TableTypeVar& ttv) + void operator()(TypeId, const TableType& ttv) { if (ttv.boundTo) return stringify(*ttv.boundTo); @@ -681,7 +683,7 @@ struct TypeVarStringifier state.unsee(&ttv); } - void operator()(TypeId, const MetatableTypeVar& mtv) + void operator()(TypeId, const MetatableType& mtv) { state.result.invalid = true; if (!state.exhaustive && mtv.syntheticName) @@ -698,17 +700,17 @@ struct TypeVarStringifier state.emit(" }"); } - void operator()(TypeId, const ClassTypeVar& ctv) + void operator()(TypeId, const ClassType& ctv) { state.emit(ctv.name); } - void operator()(TypeId, const AnyTypeVar&) + void operator()(TypeId, const AnyType&) { state.emit("any"); } - void operator()(TypeId, const UnionTypeVar& uv) + void operator()(TypeId, const UnionType& uv) { if (state.hasSeen(&uv)) { @@ -718,6 +720,7 @@ struct TypeVarStringifier } bool optional = false; + bool hasNonNilDisjunct = false; std::vector results = {}; for (auto el : &uv) @@ -729,10 +732,14 @@ struct TypeVarStringifier optional = true; continue; } + else + { + hasNonNilDisjunct = true; + } std::string saved = std::move(state.result.name); - bool needParens = !state.cycleNames.count(el) && (get(el) || get(el)); + bool needParens = !state.cycleNames.count(el) && (get(el) || get(el)); if (needParens) state.emit("("); @@ -771,11 +778,17 @@ struct TypeVarStringifier if (results.size() > 1) s = ")?"; + if (FFlag::LuauSerializeNilUnionAsNil) + { + if (!hasNonNilDisjunct) + s = "nil"; + } + state.emit(s); } } - void operator()(TypeId, const IntersectionTypeVar& uv) + void operator()(TypeId, const IntersectionType& uv) { if (state.hasSeen(&uv)) { @@ -791,7 +804,7 @@ struct TypeVarStringifier std::string saved = std::move(state.result.name); - bool needParens = !state.cycleNames.count(el) && (get(el) || get(el)); + bool needParens = !state.cycleNames.count(el) && (get(el) || get(el)); if (needParens) state.emit("("); @@ -822,35 +835,35 @@ struct TypeVarStringifier } } - void operator()(TypeId, const ErrorTypeVar& tv) + void operator()(TypeId, const ErrorType& tv) { state.result.error = true; state.emit(FFlag::LuauUnknownAndNeverType ? "*error-type*" : "*unknown*"); } - void operator()(TypeId, const LazyTypeVar& ltv) + void operator()(TypeId, const LazyType& ltv) { state.result.invalid = true; state.emit("lazy?"); } - void operator()(TypeId, const UnknownTypeVar& ttv) + void operator()(TypeId, const UnknownType& ttv) { state.emit("unknown"); } - void operator()(TypeId, const NeverTypeVar& ttv) + void operator()(TypeId, const NeverType& ttv) { state.emit("never"); } - void operator()(TypeId, const NegationTypeVar& ntv) + void operator()(TypeId, const NegationType& ntv) { state.emit("~"); // The precedence of `~` should be less than `|` and `&`. TypeId followed = follow(ntv.ty); - bool parens = get(followed) || get(followed); + bool parens = get(followed) || get(followed); if (parens) state.emit("("); @@ -884,7 +897,7 @@ struct TypePackStringifier void stringify(TypeId tv) { - TypeVarStringifier tvs{state}; + TypeStringifier tvs{state}; tvs.stringify(tv); } @@ -1033,13 +1046,13 @@ struct TypePackStringifier } }; -void TypeVarStringifier::stringify(TypePackId tp) +void TypeStringifier::stringify(TypePackId tp) { TypePackStringifier tps(state); tps.stringify(tp); } -void TypeVarStringifier::stringify(TypePackId tpid, const std::vector>& names) +void TypeStringifier::stringify(TypePackId tpid, const std::vector>& names) { TypePackStringifier tps(state, names); tps.stringify(tpid); @@ -1055,7 +1068,7 @@ static void assignCycleNames(const std::set& cycles, const std::set(follow(cycleTy)); !exhaustive && ttv && (ttv->syntheticName || ttv->name)) + if (auto ttv = get(follow(cycleTy)); !exhaustive && ttv && (ttv->syntheticName || ttv->name)) { // If we have a cycle type in type parameters, assign a cycle name for this named table if (std::find_if(ttv->instantiatedTypeParams.begin(), ttv->instantiatedTypeParams.end(), [&](auto&& el) { @@ -1083,13 +1096,12 @@ static void assignCycleNames(const std::set& cycles, const std::set(ty); ttv && (ttv->name || ttv->syntheticName)) + if (auto ttv = get(ty); ttv && (ttv->name || ttv->syntheticName)) { if (ttv->syntheticName) result.invalid = true; @@ -1128,7 +1140,7 @@ ToStringResult toStringDetailed(TypeId ty, ToStringOptions& opts) return result; } - else if (auto mtv = get(ty); mtv && mtv->syntheticName) + else if (auto mtv = get(ty); mtv && mtv->syntheticName) { result.invalid = true; result.name = *mtv->syntheticName; @@ -1213,7 +1225,7 @@ ToStringResult toStringDetailed(TypeId ty, ToStringOptions& opts) ToStringResult toStringDetailed(TypePackId tp, ToStringOptions& opts) { /* - * 1. Walk the TypeVar and track seen TypeIds. When you reencounter a TypeId, add it to a set of seen cycles. + * 1. Walk the Type and track seen TypeIds. When you reencounter a TypeId, add it to a set of seen cycles. * 2. Generate some names for each cycle. For a starting point, we can just call them t0, t1 and so on. * 3. For each seen cycle, stringify it like we do now, but replace each known cycle with its name. * 4. Print out the root of the type using the same algorithm as step 3. @@ -1228,7 +1240,7 @@ ToStringResult toStringDetailed(TypePackId tp, ToStringOptions& opts) assignCycleNames(cycles, cycleTPs, state.cycleNames, state.cycleTpNames, opts.exhaustive); - TypeVarStringifier tvs{state}; + TypeStringifier tvs{state}; /* If the root itself is a cycle, we special case a little. * We go out of our way to print the following: @@ -1289,7 +1301,7 @@ std::string toString(TypePackId tp, ToStringOptions& opts) return toStringDetailed(tp, opts).name; } -std::string toString(const TypeVar& tv, ToStringOptions& opts) +std::string toString(const Type& tv, ToStringOptions& opts) { return toString(const_cast(&tv), opts); } @@ -1299,11 +1311,11 @@ std::string toString(const TypePackVar& tp, ToStringOptions& opts) return toString(const_cast(&tp), opts); } -std::string toStringNamedFunction(const std::string& funcName, const FunctionTypeVar& ftv, ToStringOptions& opts) +std::string toStringNamedFunction(const std::string& funcName, const FunctionType& ftv, ToStringOptions& opts) { ToStringResult result; StringifierState state{opts, result}; - TypeVarStringifier tvs{state}; + TypeStringifier tvs{state}; state.emit(funcName); diff --git a/Analysis/src/TxnLog.cpp b/Analysis/src/TxnLog.cpp index 18596a638..dacd82dc1 100644 --- a/Analysis/src/TxnLog.cpp +++ b/Analysis/src/TxnLog.cpp @@ -87,7 +87,7 @@ void TxnLog::concatAsIntersections(TxnLog rhs, NotNull arena) { TypeId leftTy = arena->addType((*leftRep)->pending); TypeId rightTy = arena->addType(rightRep->pending); - typeVarChanges[ty]->pending.ty = IntersectionTypeVar{{leftTy, rightTy}}; + typeVarChanges[ty]->pending.ty = IntersectionType{{leftTy, rightTy}}; } else typeVarChanges[ty] = std::move(rightRep); @@ -105,7 +105,7 @@ void TxnLog::concatAsUnion(TxnLog rhs, NotNull arena) { TypeId leftTy = arena->addType((*leftRep)->pending); TypeId rightTy = arena->addType(rightRep->pending); - typeVarChanges[ty]->pending.ty = UnionTypeVar{{leftTy, rightTy}}; + typeVarChanges[ty]->pending.ty = UnionType{{leftTy, rightTy}}; } else typeVarChanges[ty] = std::move(rightRep); @@ -261,7 +261,7 @@ PendingTypePack* TxnLog::pending(TypePackId tp) const return nullptr; } -PendingType* TxnLog::replace(TypeId ty, TypeVar replacement) +PendingType* TxnLog::replace(TypeId ty, Type replacement) { PendingType* newTy = queue(ty); newTy->pending.reassign(replacement); @@ -277,10 +277,10 @@ PendingTypePack* TxnLog::replace(TypePackId tp, TypePackVar replacement) PendingType* TxnLog::bindTable(TypeId ty, std::optional newBoundTo) { - LUAU_ASSERT(get(ty)); + LUAU_ASSERT(get(ty)); PendingType* newTy = queue(ty); - if (TableTypeVar* ttv = Luau::getMutable(newTy)) + if (TableType* ttv = Luau::getMutable(newTy)) ttv->boundTo = newBoundTo; return newTy; @@ -288,19 +288,19 @@ PendingType* TxnLog::bindTable(TypeId ty, std::optional newBoundTo) PendingType* TxnLog::changeLevel(TypeId ty, TypeLevel newLevel) { - LUAU_ASSERT(get(ty) || get(ty) || get(ty)); + LUAU_ASSERT(get(ty) || get(ty) || get(ty)); PendingType* newTy = queue(ty); - if (FreeTypeVar* ftv = Luau::getMutable(newTy)) + if (FreeType* ftv = Luau::getMutable(newTy)) { ftv->level = newLevel; } - else if (TableTypeVar* ttv = Luau::getMutable(newTy)) + else if (TableType* ttv = Luau::getMutable(newTy)) { LUAU_ASSERT(ttv->state == TableState::Free || ttv->state == TableState::Generic); ttv->level = newLevel; } - else if (FunctionTypeVar* ftv = Luau::getMutable(newTy)) + else if (FunctionType* ftv = Luau::getMutable(newTy)) { ftv->level = newLevel; } @@ -323,19 +323,19 @@ PendingTypePack* TxnLog::changeLevel(TypePackId tp, TypeLevel newLevel) PendingType* TxnLog::changeScope(TypeId ty, NotNull newScope) { - LUAU_ASSERT(get(ty) || get(ty) || get(ty)); + LUAU_ASSERT(get(ty) || get(ty) || get(ty)); PendingType* newTy = queue(ty); - if (FreeTypeVar* ftv = Luau::getMutable(newTy)) + if (FreeType* ftv = Luau::getMutable(newTy)) { ftv->scope = newScope; } - else if (TableTypeVar* ttv = Luau::getMutable(newTy)) + else if (TableType* ttv = Luau::getMutable(newTy)) { LUAU_ASSERT(ttv->state == TableState::Free || ttv->state == TableState::Generic); ttv->scope = newScope; } - else if (FunctionTypeVar* ftv = Luau::getMutable(newTy)) + else if (FunctionType* ftv = Luau::getMutable(newTy)) { ftv->scope = newScope; } @@ -358,10 +358,10 @@ PendingTypePack* TxnLog::changeScope(TypePackId tp, NotNull newScope) PendingType* TxnLog::changeIndexer(TypeId ty, std::optional indexer) { - LUAU_ASSERT(get(ty)); + LUAU_ASSERT(get(ty)); PendingType* newTy = queue(ty); - if (TableTypeVar* ttv = Luau::getMutable(newTy)) + if (TableType* ttv = Luau::getMutable(newTy)) { ttv->indexer = indexer; } @@ -371,11 +371,11 @@ PendingType* TxnLog::changeIndexer(TypeId ty, std::optional indexe std::optional TxnLog::getLevel(TypeId ty) const { - if (FreeTypeVar* ftv = getMutable(ty)) + if (FreeType* ftv = getMutable(ty)) return ftv->level; - else if (TableTypeVar* ttv = getMutable(ty); ttv && (ttv->state == TableState::Free || ttv->state == TableState::Generic)) + else if (TableType* ttv = getMutable(ty); ttv && (ttv->state == TableState::Free || ttv->state == TableState::Generic)) return ttv->level; - else if (FunctionTypeVar* ftv = getMutable(ty)) + else if (FunctionType* ftv = getMutable(ty)) return ftv->level; return std::nullopt; @@ -392,7 +392,7 @@ TypeId TxnLog::follow(TypeId ty) const // Ugly: Fabricate a TypeId that doesn't adhere to most of the invariants // that normally apply. This is safe because follow will only call get<> // on the returned pointer. - return const_cast(&state->pending); + return const_cast(&state->pending); }); } diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/Type.cpp similarity index 71% rename from Analysis/src/TypeVar.cpp rename to Analysis/src/Type.cpp index 159e77125..aba6bddc1 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/Type.cpp @@ -1,5 +1,5 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "Luau/TypeVar.h" +#include "Luau/Type.h" #include "Luau/BuiltinDefinitions.h" #include "Luau/Common.h" @@ -11,7 +11,7 @@ #include "Luau/ToString.h" #include "Luau/TypeInfer.h" #include "Luau/TypePack.h" -#include "Luau/VisitTypeVar.h" +#include "Luau/VisitType.h" #include #include @@ -60,20 +60,20 @@ TypeId follow(TypeId t, std::function mapper) auto advance = [&mapper](TypeId ty) -> std::optional { if (auto btv = get>(mapper(ty))) return btv->boundTo; - else if (auto ttv = get(mapper(ty))) + else if (auto ttv = get(mapper(ty))) return ttv->boundTo; else return std::nullopt; }; auto force = [&mapper](TypeId ty) { - if (auto ltv = get_if(&mapper(ty)->ty)) + if (auto ltv = get_if(&mapper(ty)->ty)) { TypeId res = ltv->thunk(); - if (get(res)) - throw InternalCompilerError("Lazy TypeVar cannot resolve to another Lazy TypeVar"); + if (get(res)) + throw InternalCompilerError("Lazy Type cannot resolve to another Lazy Type"); - *asMutable(ty) = BoundTypeVar(res); + *asMutable(ty) = BoundType(res); } }; @@ -109,14 +109,14 @@ TypeId follow(TypeId t, std::function mapper) cycleTester = nullptr; if (t == cycleTester) - throw InternalCompilerError("Luau::follow detected a TypeVar cycle!!"); + throw InternalCompilerError("Luau::follow detected a Type cycle!!"); } } } std::vector flattenIntersection(TypeId ty) { - if (!get(follow(ty))) + if (!get(follow(ty))) return {ty}; std::unordered_set seen; @@ -134,7 +134,7 @@ std::vector flattenIntersection(TypeId ty) seen.insert(current); - if (auto itv = get(current)) + if (auto itv = get(current)) { for (TypeId ty : itv->parts) queue.push_back(ty); @@ -146,23 +146,23 @@ std::vector flattenIntersection(TypeId ty) return result; } -bool isPrim(TypeId ty, PrimitiveTypeVar::Type primType) +bool isPrim(TypeId ty, PrimitiveType::Type primType) { - auto p = get(follow(ty)); + auto p = get(follow(ty)); return p && p->type == primType; } bool isNil(TypeId ty) { - return isPrim(ty, PrimitiveTypeVar::NilType); + return isPrim(ty, PrimitiveType::NilType); } bool isBoolean(TypeId ty) { - if (isPrim(ty, PrimitiveTypeVar::Boolean) || get(get(follow(ty)))) + if (isPrim(ty, PrimitiveType::Boolean) || get(get(follow(ty)))) return true; - if (auto utv = get(follow(ty))) + if (auto utv = get(follow(ty))) return std::all_of(begin(utv), end(utv), isBoolean); return false; @@ -170,7 +170,7 @@ bool isBoolean(TypeId ty) bool isNumber(TypeId ty) { - return isPrim(ty, PrimitiveTypeVar::Number); + return isPrim(ty, PrimitiveType::Number); } // Returns true when ty is a subtype of string @@ -178,10 +178,10 @@ bool isString(TypeId ty) { ty = follow(ty); - if (isPrim(ty, PrimitiveTypeVar::String) || get(get(ty))) + if (isPrim(ty, PrimitiveType::String) || get(get(ty))) return true; - if (auto utv = get(ty)) + if (auto utv = get(ty)) return std::all_of(begin(utv), end(utv), isString); return false; @@ -192,10 +192,10 @@ bool maybeString(TypeId ty) { ty = follow(ty); - if (isPrim(ty, PrimitiveTypeVar::String) || get(ty)) + if (isPrim(ty, PrimitiveType::String) || get(ty)) return true; - if (auto utv = get(ty)) + if (auto utv = get(ty)) return std::any_of(begin(utv), end(utv), maybeString); return false; @@ -203,7 +203,7 @@ bool maybeString(TypeId ty) bool isThread(TypeId ty) { - return isPrim(ty, PrimitiveTypeVar::Thread); + return isPrim(ty, PrimitiveType::Thread); } bool isOptional(TypeId ty) @@ -213,10 +213,10 @@ bool isOptional(TypeId ty) ty = follow(ty); - if (get(ty) || (FFlag::LuauUnknownAndNeverType && get(ty))) + if (get(ty) || (FFlag::LuauUnknownAndNeverType && get(ty))) return true; - auto utv = get(ty); + auto utv = get(ty); if (!utv) return false; @@ -225,7 +225,7 @@ bool isOptional(TypeId ty) bool isTableIntersection(TypeId ty) { - if (!get(follow(ty))) + if (!get(follow(ty))) return false; std::vector parts = flattenIntersection(ty); @@ -234,28 +234,28 @@ bool isTableIntersection(TypeId ty) bool isOverloadedFunction(TypeId ty) { - if (!get(follow(ty))) + if (!get(follow(ty))) return false; auto isFunction = [](TypeId part) -> bool { - return get(part); + return get(part); }; std::vector parts = flattenIntersection(ty); return std::all_of(parts.begin(), parts.end(), isFunction); } -std::optional getMetatable(TypeId type, NotNull singletonTypes) +std::optional getMetatable(TypeId type, NotNull builtinTypes) { type = follow(type); - if (const MetatableTypeVar* mtType = get(type)) + if (const MetatableType* mtType = get(type)) return mtType->metatable; - else if (const ClassTypeVar* classType = get(type)) + else if (const ClassType* classType = get(type)) return classType->metatable; else if (isString(type)) { - auto ptv = get(singletonTypes->stringType); + auto ptv = get(builtinTypes->stringType); LUAU_ASSERT(ptv && ptv->metatable); return ptv->metatable; } @@ -263,34 +263,34 @@ std::optional getMetatable(TypeId type, NotNull singleto return std::nullopt; } -const TableTypeVar* getTableType(TypeId type) +const TableType* getTableType(TypeId type) { type = follow(type); - if (const TableTypeVar* ttv = get(type)) + if (const TableType* ttv = get(type)) return ttv; - else if (const MetatableTypeVar* mtv = get(type)) - return get(follow(mtv->table)); + else if (const MetatableType* mtv = get(type)) + return get(follow(mtv->table)); else return nullptr; } -TableTypeVar* getMutableTableType(TypeId type) +TableType* getMutableTableType(TypeId type) { - return const_cast(getTableType(type)); + return const_cast(getTableType(type)); } const std::string* getName(TypeId type) { type = follow(type); - if (auto mtv = get(type)) + if (auto mtv = get(type)) { if (mtv->syntheticName) return &*mtv->syntheticName; type = follow(mtv->table); } - if (auto ttv = get(type)) + if (auto ttv = get(type)) { if (ttv->name) return &*ttv->name; @@ -305,17 +305,17 @@ std::optional getDefinitionModuleName(TypeId type) { type = follow(type); - if (auto ttv = get(type)) + if (auto ttv = get(type)) { if (!ttv->definitionModuleName.empty()) return ttv->definitionModuleName; } - else if (auto ftv = get(type)) + else if (auto ftv = get(type)) { if (ftv->definition) return ftv->definition->definitionModuleName; } - else if (auto ctv = get(type)) + else if (auto ctv = get(type)) { if (!ctv->definitionModuleName.empty()) return ctv->definitionModuleName; @@ -324,7 +324,7 @@ std::optional getDefinitionModuleName(TypeId type) return std::nullopt; } -bool isSubset(const UnionTypeVar& super, const UnionTypeVar& sub) +bool isSubset(const UnionType& super, const UnionType& sub) { std::unordered_set superTypes; @@ -347,7 +347,7 @@ bool isGeneric(TypeId ty) LUAU_ASSERT(!FFlag::LuauInstantiateInSubtyping); ty = follow(ty); - if (auto ftv = get(ty)) + if (auto ftv = get(ty)) return ftv->generics.size() > 0 || ftv->genericPacks.size() > 0; else // TODO: recurse on type synonyms CLI-39914 @@ -363,17 +363,17 @@ bool maybeGeneric(TypeId ty) { ty = follow(ty); - if (get(ty)) + if (get(ty)) return true; - if (auto ttv = get(ty)) + if (auto ttv = get(ty)) { // TODO: recurse on table types CLI-39914 (void)ttv; return true; } - if (auto itv = get(ty)) + if (auto itv = get(ty)) { return std::any_of(begin(itv), end(itv), maybeGeneric); } @@ -382,9 +382,9 @@ bool maybeGeneric(TypeId ty) } ty = follow(ty); - if (get(ty)) + if (get(ty)) return true; - else if (auto ttv = get(ty)) + else if (auto ttv = get(ty)) { // TODO: recurse on table types CLI-39914 (void)ttv; @@ -397,11 +397,11 @@ bool maybeGeneric(TypeId ty) bool maybeSingleton(TypeId ty) { ty = follow(ty); - if (get(ty)) + if (get(ty)) return true; - if (const UnionTypeVar* utv = get(ty)) + if (const UnionType* utv = get(ty)) for (TypeId option : utv) - if (get(follow(option))) + if (get(follow(option))) return true; return false; } @@ -415,10 +415,10 @@ bool hasLength(TypeId ty, DenseHashSet& seen, int* recursionCount) if (seen.contains(ty)) return true; - if (isString(ty) || get(ty) || get(ty) || get(ty)) + if (isString(ty) || get(ty) || get(ty) || get(ty)) return true; - if (auto uty = get(ty)) + if (auto uty = get(ty)) { seen.insert(ty); @@ -431,7 +431,7 @@ bool hasLength(TypeId ty, DenseHashSet& seen, int* recursionCount) return true; } - if (auto ity = get(ty)) + if (auto ity = get(ty)) { seen.insert(ty); @@ -447,14 +447,14 @@ bool hasLength(TypeId ty, DenseHashSet& seen, int* recursionCount) return false; } -BlockedTypeVar::BlockedTypeVar() +BlockedType::BlockedType() : index(++nextIndex) { } -int BlockedTypeVar::nextIndex = 0; +int BlockedType::nextIndex = 0; -PendingExpansionTypeVar::PendingExpansionTypeVar( +PendingExpansionType::PendingExpansionType( std::optional prefix, AstName name, std::vector typeArguments, std::vector packArguments) : prefix(prefix) , name(name) @@ -464,9 +464,9 @@ PendingExpansionTypeVar::PendingExpansionTypeVar( { } -size_t PendingExpansionTypeVar::nextIndex = 0; +size_t PendingExpansionType::nextIndex = 0; -FunctionTypeVar::FunctionTypeVar(TypePackId argTypes, TypePackId retTypes, std::optional defn, bool hasSelf) +FunctionType::FunctionType(TypePackId argTypes, TypePackId retTypes, std::optional defn, bool hasSelf) : definition(std::move(defn)) , argTypes(argTypes) , retTypes(retTypes) @@ -474,7 +474,7 @@ FunctionTypeVar::FunctionTypeVar(TypePackId argTypes, TypePackId retTypes, std:: { } -FunctionTypeVar::FunctionTypeVar(TypeLevel level, TypePackId argTypes, TypePackId retTypes, std::optional defn, bool hasSelf) +FunctionType::FunctionType(TypeLevel level, TypePackId argTypes, TypePackId retTypes, std::optional defn, bool hasSelf) : definition(std::move(defn)) , level(level) , argTypes(argTypes) @@ -483,7 +483,7 @@ FunctionTypeVar::FunctionTypeVar(TypeLevel level, TypePackId argTypes, TypePackI { } -FunctionTypeVar::FunctionTypeVar( +FunctionType::FunctionType( TypeLevel level, Scope* scope, TypePackId argTypes, TypePackId retTypes, std::optional defn, bool hasSelf) : definition(std::move(defn)) , level(level) @@ -494,7 +494,7 @@ FunctionTypeVar::FunctionTypeVar( { } -FunctionTypeVar::FunctionTypeVar(std::vector generics, std::vector genericPacks, TypePackId argTypes, TypePackId retTypes, +FunctionType::FunctionType(std::vector generics, std::vector genericPacks, TypePackId argTypes, TypePackId retTypes, std::optional defn, bool hasSelf) : definition(std::move(defn)) , generics(generics) @@ -505,7 +505,7 @@ FunctionTypeVar::FunctionTypeVar(std::vector generics, std::vector generics, std::vector genericPacks, TypePackId argTypes, +FunctionType::FunctionType(TypeLevel level, std::vector generics, std::vector genericPacks, TypePackId argTypes, TypePackId retTypes, std::optional defn, bool hasSelf) : definition(std::move(defn)) , generics(generics) @@ -517,8 +517,8 @@ FunctionTypeVar::FunctionTypeVar(TypeLevel level, std::vector generics, { } -FunctionTypeVar::FunctionTypeVar(TypeLevel level, Scope* scope, std::vector generics, std::vector genericPacks, - TypePackId argTypes, TypePackId retTypes, std::optional defn, bool hasSelf) +FunctionType::FunctionType(TypeLevel level, Scope* scope, std::vector generics, std::vector genericPacks, TypePackId argTypes, + TypePackId retTypes, std::optional defn, bool hasSelf) : definition(std::move(defn)) , generics(generics) , genericPacks(genericPacks) @@ -530,14 +530,14 @@ FunctionTypeVar::FunctionTypeVar(TypeLevel level, Scope* scope, std::vector& indexer, TypeLevel level, TableState state) +TableType::TableType(const Props& props, const std::optional& indexer, TypeLevel level, TableState state) : props(props) , indexer(indexer) , state(state) @@ -545,7 +545,7 @@ TableTypeVar::TableTypeVar(const Props& props, const std::optional { } -TableTypeVar::TableTypeVar(const Props& props, const std::optional& indexer, TypeLevel level, Scope* scope, TableState state) +TableType::TableType(const Props& props, const std::optional& indexer, TypeLevel level, Scope* scope, TableState state) : props(props) , indexer(indexer) , state(state) @@ -554,8 +554,8 @@ TableTypeVar::TableTypeVar(const Props& props, const std::optional { } -// Test TypeVars for equivalence -// More complex than we'd like because TypeVars can self-reference. +// Test Types for equivalence +// More complex than we'd like because Types can self-reference. bool areSeen(SeenSet& seen, const void* lhs, const void* rhs) { @@ -570,7 +570,7 @@ bool areSeen(SeenSet& seen, const void* lhs, const void* rhs) return false; } -bool areEqual(SeenSet& seen, const FunctionTypeVar& lhs, const FunctionTypeVar& rhs) +bool areEqual(SeenSet& seen, const FunctionType& lhs, const FunctionType& rhs) { if (areSeen(seen, &lhs, &rhs)) return true; @@ -586,7 +586,7 @@ bool areEqual(SeenSet& seen, const FunctionTypeVar& lhs, const FunctionTypeVar& return true; } -bool areEqual(SeenSet& seen, const TableTypeVar& lhs, const TableTypeVar& rhs) +bool areEqual(SeenSet& seen, const TableType& lhs, const TableType& rhs) { if (areSeen(seen, &lhs, &rhs)) return true; @@ -626,7 +626,7 @@ bool areEqual(SeenSet& seen, const TableTypeVar& lhs, const TableTypeVar& rhs) return true; } -static bool areEqual(SeenSet& seen, const MetatableTypeVar& lhs, const MetatableTypeVar& rhs) +static bool areEqual(SeenSet& seen, const MetatableType& lhs, const MetatableType& rhs) { if (areSeen(seen, &lhs, &rhs)) return true; @@ -634,110 +634,110 @@ static bool areEqual(SeenSet& seen, const MetatableTypeVar& lhs, const Metatable return areEqual(seen, *lhs.table, *rhs.table) && areEqual(seen, *lhs.metatable, *rhs.metatable); } -bool areEqual(SeenSet& seen, const TypeVar& lhs, const TypeVar& rhs) +bool areEqual(SeenSet& seen, const Type& lhs, const Type& rhs) { - if (auto bound = get_if(&lhs.ty)) + if (auto bound = get_if(&lhs.ty)) return areEqual(seen, *bound->boundTo, rhs); - if (auto bound = get_if(&rhs.ty)) + if (auto bound = get_if(&rhs.ty)) return areEqual(seen, lhs, *bound->boundTo); if (lhs.ty.index() != rhs.ty.index()) return false; { - const FreeTypeVar* lf = get_if(&lhs.ty); - const FreeTypeVar* rf = get_if(&rhs.ty); + const FreeType* lf = get_if(&lhs.ty); + const FreeType* rf = get_if(&rhs.ty); if (lf && rf) return lf->index == rf->index; } { - const GenericTypeVar* lg = get_if(&lhs.ty); - const GenericTypeVar* rg = get_if(&rhs.ty); + const GenericType* lg = get_if(&lhs.ty); + const GenericType* rg = get_if(&rhs.ty); if (lg && rg) return lg->index == rg->index; } { - const PrimitiveTypeVar* lp = get_if(&lhs.ty); - const PrimitiveTypeVar* rp = get_if(&rhs.ty); + const PrimitiveType* lp = get_if(&lhs.ty); + const PrimitiveType* rp = get_if(&rhs.ty); if (lp && rp) return lp->type == rp->type; } { - const GenericTypeVar* lg = get_if(&lhs.ty); - const GenericTypeVar* rg = get_if(&rhs.ty); + const GenericType* lg = get_if(&lhs.ty); + const GenericType* rg = get_if(&rhs.ty); if (lg && rg) return lg->index == rg->index; } { - const ErrorTypeVar* le = get_if(&lhs.ty); - const ErrorTypeVar* re = get_if(&rhs.ty); + const ErrorType* le = get_if(&lhs.ty); + const ErrorType* re = get_if(&rhs.ty); if (le && re) return le->index == re->index; } { - const FunctionTypeVar* lf = get_if(&lhs.ty); - const FunctionTypeVar* rf = get_if(&rhs.ty); + const FunctionType* lf = get_if(&lhs.ty); + const FunctionType* rf = get_if(&rhs.ty); if (lf && rf) return areEqual(seen, *lf, *rf); } { - const TableTypeVar* lt = get_if(&lhs.ty); - const TableTypeVar* rt = get_if(&rhs.ty); + const TableType* lt = get_if(&lhs.ty); + const TableType* rt = get_if(&rhs.ty); if (lt && rt) return areEqual(seen, *lt, *rt); } { - const MetatableTypeVar* lmt = get_if(&lhs.ty); - const MetatableTypeVar* rmt = get_if(&rhs.ty); + const MetatableType* lmt = get_if(&lhs.ty); + const MetatableType* rmt = get_if(&rhs.ty); if (lmt && rmt) return areEqual(seen, *lmt, *rmt); } - if (get_if(&lhs.ty) && get_if(&rhs.ty)) + if (get_if(&lhs.ty) && get_if(&rhs.ty)) return true; return false; } -TypeVar* asMutable(TypeId ty) +Type* asMutable(TypeId ty) { - return const_cast(ty); + return const_cast(ty); } -bool TypeVar::operator==(const TypeVar& rhs) const +bool Type::operator==(const Type& rhs) const { SeenSet seen; return areEqual(seen, *this, rhs); } -bool TypeVar::operator!=(const TypeVar& rhs) const +bool Type::operator!=(const Type& rhs) const { SeenSet seen; return !areEqual(seen, *this, rhs); } -TypeVar& TypeVar::operator=(const TypeVariant& rhs) +Type& Type::operator=(const TypeVariant& rhs) { ty = rhs; return *this; } -TypeVar& TypeVar::operator=(TypeVariant&& rhs) +Type& Type::operator=(TypeVariant&& rhs) { ty = std::move(rhs); return *this; } -TypeVar& TypeVar::operator=(const TypeVar& rhs) +Type& Type::operator=(const Type& rhs) { LUAU_ASSERT(owningArena == rhs.owningArena); LUAU_ASSERT(!rhs.persistent); @@ -751,37 +751,38 @@ TypeId makeFunction(TypeArena& arena, std::optional selfType, std::initi std::initializer_list genericPacks, std::initializer_list paramTypes, std::initializer_list paramNames, std::initializer_list retTypes); -SingletonTypes::SingletonTypes() +BuiltinTypes::BuiltinTypes() : arena(new TypeArena) , debugFreezeArena(FFlag::DebugLuauFreezeArena) - , nilType(arena->addType(TypeVar{PrimitiveTypeVar{PrimitiveTypeVar::NilType}, /*persistent*/ true})) - , numberType(arena->addType(TypeVar{PrimitiveTypeVar{PrimitiveTypeVar::Number}, /*persistent*/ true})) - , stringType(arena->addType(TypeVar{PrimitiveTypeVar{PrimitiveTypeVar::String}, /*persistent*/ true})) - , booleanType(arena->addType(TypeVar{PrimitiveTypeVar{PrimitiveTypeVar::Boolean}, /*persistent*/ true})) - , threadType(arena->addType(TypeVar{PrimitiveTypeVar{PrimitiveTypeVar::Thread}, /*persistent*/ true})) - , functionType(arena->addType(TypeVar{PrimitiveTypeVar{PrimitiveTypeVar::Function}, /*persistent*/ true})) - , trueType(arena->addType(TypeVar{SingletonTypeVar{BooleanSingleton{true}}, /*persistent*/ true})) - , falseType(arena->addType(TypeVar{SingletonTypeVar{BooleanSingleton{false}}, /*persistent*/ true})) - , anyType(arena->addType(TypeVar{AnyTypeVar{}, /*persistent*/ true})) - , unknownType(arena->addType(TypeVar{UnknownTypeVar{}, /*persistent*/ true})) - , neverType(arena->addType(TypeVar{NeverTypeVar{}, /*persistent*/ true})) - , errorType(arena->addType(TypeVar{ErrorTypeVar{}, /*persistent*/ true})) - , falsyType(arena->addType(TypeVar{UnionTypeVar{{falseType, nilType}}, /*persistent*/ true})) - , truthyType(arena->addType(TypeVar{NegationTypeVar{falsyType}, /*persistent*/ true})) + , nilType(arena->addType(Type{PrimitiveType{PrimitiveType::NilType}, /*persistent*/ true})) + , numberType(arena->addType(Type{PrimitiveType{PrimitiveType::Number}, /*persistent*/ true})) + , stringType(arena->addType(Type{PrimitiveType{PrimitiveType::String}, /*persistent*/ true})) + , booleanType(arena->addType(Type{PrimitiveType{PrimitiveType::Boolean}, /*persistent*/ true})) + , threadType(arena->addType(Type{PrimitiveType{PrimitiveType::Thread}, /*persistent*/ true})) + , functionType(arena->addType(Type{PrimitiveType{PrimitiveType::Function}, /*persistent*/ true})) + , classType(arena->addType(Type{ClassType{"class", {}, std::nullopt, std::nullopt, {}, {}, {}}, /*persistent*/ true})) + , trueType(arena->addType(Type{SingletonType{BooleanSingleton{true}}, /*persistent*/ true})) + , falseType(arena->addType(Type{SingletonType{BooleanSingleton{false}}, /*persistent*/ true})) + , anyType(arena->addType(Type{AnyType{}, /*persistent*/ true})) + , unknownType(arena->addType(Type{UnknownType{}, /*persistent*/ true})) + , neverType(arena->addType(Type{NeverType{}, /*persistent*/ true})) + , errorType(arena->addType(Type{ErrorType{}, /*persistent*/ true})) + , falsyType(arena->addType(Type{UnionType{{falseType, nilType}}, /*persistent*/ true})) + , truthyType(arena->addType(Type{NegationType{falsyType}, /*persistent*/ true})) , anyTypePack(arena->addTypePack(TypePackVar{VariadicTypePack{anyType}, /*persistent*/ true})) , neverTypePack(arena->addTypePack(TypePackVar{VariadicTypePack{neverType}, /*persistent*/ true})) , uninhabitableTypePack(arena->addTypePack({neverType}, neverTypePack)) , errorTypePack(arena->addTypePack(TypePackVar{Unifiable::Error{}, /*persistent*/ true})) { TypeId stringMetatable = makeStringMetatable(); - asMutable(stringType)->ty = PrimitiveTypeVar{PrimitiveTypeVar::String, stringMetatable}; + asMutable(stringType)->ty = PrimitiveType{PrimitiveType::String, stringMetatable}; persist(stringMetatable); persist(uninhabitableTypePack); freeze(*arena); } -SingletonTypes::~SingletonTypes() +BuiltinTypes::~BuiltinTypes() { // Destroy the arena with the same memory management flags it was created with bool prevFlag = FFlag::DebugLuauFreezeArena; @@ -793,16 +794,16 @@ SingletonTypes::~SingletonTypes() FFlag::DebugLuauFreezeArena.value = prevFlag; } -TypeId SingletonTypes::makeStringMetatable() +TypeId BuiltinTypes::makeStringMetatable() { - const TypeId optionalNumber = arena->addType(UnionTypeVar{{nilType, numberType}}); - const TypeId optionalString = arena->addType(UnionTypeVar{{nilType, stringType}}); - const TypeId optionalBoolean = arena->addType(UnionTypeVar{{nilType, booleanType}}); + const TypeId optionalNumber = arena->addType(UnionType{{nilType, numberType}}); + const TypeId optionalString = arena->addType(UnionType{{nilType, stringType}}); + const TypeId optionalBoolean = arena->addType(UnionType{{nilType, booleanType}}); const TypePackId oneStringPack = arena->addTypePack({stringType}); const TypePackId anyTypePack = arena->addTypePack(TypePackVar{VariadicTypePack{anyType}, true}); - FunctionTypeVar formatFTV{arena->addTypePack(TypePack{{stringType}, anyTypePack}), oneStringPack}; + FunctionType formatFTV{arena->addTypePack(TypePack{{stringType}, anyTypePack}), oneStringPack}; formatFTV.magicFunction = &magicFunctionFormat; const TypeId formatFn = arena->addType(formatFTV); attachDcrMagicFunction(formatFn, dcrMagicFunctionFormat); @@ -813,28 +814,28 @@ TypeId SingletonTypes::makeStringMetatable() const TypeId stringToStringType = makeFunction(*arena, std::nullopt, {}, {}, {stringType}, {}, {stringType}); - const TypeId replArgType = arena->addType( - UnionTypeVar{{stringType, arena->addType(TableTypeVar({}, TableIndexer(stringType, stringType), TypeLevel{}, TableState::Generic)), + const TypeId replArgType = + arena->addType(UnionType{{stringType, arena->addType(TableType({}, TableIndexer(stringType, stringType), TypeLevel{}, TableState::Generic)), makeFunction(*arena, std::nullopt, {}, {}, {stringType}, {}, {stringType})}}); const TypeId gsubFunc = makeFunction(*arena, stringType, {}, {}, {stringType, replArgType, optionalNumber}, {}, {stringType, numberType}); const TypeId gmatchFunc = - makeFunction(*arena, stringType, {}, {}, {stringType}, {}, {arena->addType(FunctionTypeVar{emptyPack, stringVariadicList})}); + makeFunction(*arena, stringType, {}, {}, {stringType}, {}, {arena->addType(FunctionType{emptyPack, stringVariadicList})}); attachMagicFunction(gmatchFunc, magicFunctionGmatch); attachDcrMagicFunction(gmatchFunc, dcrMagicFunctionGmatch); const TypeId matchFunc = arena->addType( - FunctionTypeVar{arena->addTypePack({stringType, stringType, optionalNumber}), arena->addTypePack(TypePackVar{VariadicTypePack{stringType}})}); + FunctionType{arena->addTypePack({stringType, stringType, optionalNumber}), arena->addTypePack(TypePackVar{VariadicTypePack{stringType}})}); attachMagicFunction(matchFunc, magicFunctionMatch); attachDcrMagicFunction(matchFunc, dcrMagicFunctionMatch); - const TypeId findFunc = arena->addType(FunctionTypeVar{arena->addTypePack({stringType, stringType, optionalNumber, optionalBoolean}), + const TypeId findFunc = arena->addType(FunctionType{arena->addTypePack({stringType, stringType, optionalNumber, optionalBoolean}), arena->addTypePack(TypePack{{optionalNumber, optionalNumber}, stringVariadicList})}); attachMagicFunction(findFunc, magicFunctionFind); attachDcrMagicFunction(findFunc, dcrMagicFunctionFind); - TableTypeVar::Props stringLib = { - {"byte", {arena->addType(FunctionTypeVar{arena->addTypePack({stringType, optionalNumber, optionalNumber}), numberVariadicList})}}, - {"char", {arena->addType(FunctionTypeVar{numberVariadicList, arena->addTypePack({stringType})})}}, + TableType::Props stringLib = { + {"byte", {arena->addType(FunctionType{arena->addTypePack({stringType, optionalNumber, optionalNumber}), numberVariadicList})}}, + {"char", {arena->addType(FunctionType{numberVariadicList, arena->addTypePack({stringType})})}}, {"find", {findFunc}}, {"format", {formatFn}}, // FIXME {"gmatch", {gmatchFunc}}, @@ -847,13 +848,13 @@ TypeId SingletonTypes::makeStringMetatable() {"sub", {makeFunction(*arena, stringType, {}, {}, {numberType, optionalNumber}, {}, {stringType})}}, {"upper", {stringToStringType}}, {"split", {makeFunction(*arena, stringType, {}, {}, {optionalString}, {}, - {arena->addType(TableTypeVar{{}, TableIndexer{numberType, stringType}, TypeLevel{}, TableState::Sealed})})}}, - {"pack", {arena->addType(FunctionTypeVar{ + {arena->addType(TableType{{}, TableIndexer{numberType, stringType}, TypeLevel{}, TableState::Sealed})})}}, + {"pack", {arena->addType(FunctionType{ arena->addTypePack(TypePack{{stringType}, anyTypePack}), oneStringPack, })}}, {"packsize", {makeFunction(*arena, stringType, {}, {}, {}, {}, {numberType})}}, - {"unpack", {arena->addType(FunctionTypeVar{ + {"unpack", {arena->addType(FunctionType{ arena->addTypePack(TypePack{{stringType, stringType, optionalNumber}}), anyTypePack, })}}, @@ -861,30 +862,30 @@ TypeId SingletonTypes::makeStringMetatable() assignPropDocumentationSymbols(stringLib, "@luau/global/string"); - TypeId tableType = arena->addType(TableTypeVar{std::move(stringLib), std::nullopt, TypeLevel{}, TableState::Sealed}); + TypeId tableType = arena->addType(TableType{std::move(stringLib), std::nullopt, TypeLevel{}, TableState::Sealed}); - if (TableTypeVar* ttv = getMutable(tableType)) + if (TableType* ttv = getMutable(tableType)) ttv->name = FFlag::LuauNewLibraryTypeNames ? "typeof(string)" : "string"; - return arena->addType(TableTypeVar{{{{"__index", {tableType}}}}, std::nullopt, TypeLevel{}, TableState::Sealed}); + return arena->addType(TableType{{{{"__index", {tableType}}}}, std::nullopt, TypeLevel{}, TableState::Sealed}); } -TypeId SingletonTypes::errorRecoveryType() +TypeId BuiltinTypes::errorRecoveryType() { return errorType; } -TypePackId SingletonTypes::errorRecoveryTypePack() +TypePackId BuiltinTypes::errorRecoveryTypePack() { return errorTypePack; } -TypeId SingletonTypes::errorRecoveryType(TypeId guess) +TypeId BuiltinTypes::errorRecoveryType(TypeId guess) { return guess; } -TypePackId SingletonTypes::errorRecoveryTypePack(TypePackId guess) +TypePackId BuiltinTypes::errorRecoveryTypePack(TypePackId guess) { return guess; } @@ -903,14 +904,14 @@ void persist(TypeId ty) asMutable(t)->persistent = true; - if (auto btv = get(t)) + if (auto btv = get(t)) queue.push_back(btv->boundTo); - else if (auto ftv = get(t)) + else if (auto ftv = get(t)) { persist(ftv->argTypes); persist(ftv->retTypes); } - else if (auto ttv = get(t)) + else if (auto ttv = get(t)) { LUAU_ASSERT(ttv->state != TableState::Free && ttv->state != TableState::Unsealed); @@ -923,28 +924,27 @@ void persist(TypeId ty) queue.push_back(ttv->indexer->indexResultType); } } - else if (auto ctv = get(t)) + else if (auto ctv = get(t)) { for (const auto& [_name, prop] : ctv->props) queue.push_back(prop.type); } - else if (auto utv = get(t)) + else if (auto utv = get(t)) { for (TypeId opt : utv->options) queue.push_back(opt); } - else if (auto itv = get(t)) + else if (auto itv = get(t)) { for (TypeId opt : itv->parts) queue.push_back(opt); } - else if (auto mtv = get(t)) + else if (auto mtv = get(t)) { queue.push_back(mtv->table); queue.push_back(mtv->metatable); } - else if (get(t) || get(t) || get(t) || get(t) || get(t) || - get(t)) + else if (get(t) || get(t) || get(t) || get(t) || get(t) || get(t)) { } else @@ -987,9 +987,9 @@ const TypeLevel* getLevel(TypeId ty) if (auto ftv = get(ty)) return &ftv->level; - else if (auto ttv = get(ty)) + else if (auto ttv = get(ty)) return &ttv->level; - else if (auto ftv = get(ty)) + else if (auto ftv = get(ty)) return &ftv->level; else return nullptr; @@ -1010,7 +1010,7 @@ std::optional getLevel(TypePackId tp) return std::nullopt; } -const Property* lookupClassProp(const ClassTypeVar* cls, const Name& name) +const Property* lookupClassProp(const ClassType* cls, const Name& name) { while (cls) { @@ -1019,7 +1019,7 @@ const Property* lookupClassProp(const ClassTypeVar* cls, const Name& name) return &it->second; if (cls->parent) - cls = get(*cls->parent); + cls = get(*cls->parent); else return nullptr; @@ -1029,7 +1029,7 @@ const Property* lookupClassProp(const ClassTypeVar* cls, const Name& name) return nullptr; } -bool isSubclass(const ClassTypeVar* cls, const ClassTypeVar* parent) +bool isSubclass(const ClassType* cls, const ClassType* parent) { while (cls) { @@ -1038,44 +1038,44 @@ bool isSubclass(const ClassTypeVar* cls, const ClassTypeVar* parent) else if (!cls->parent) return false; - cls = get(*cls->parent); + cls = get(*cls->parent); LUAU_ASSERT(cls); } return false; } -const std::vector& getTypes(const UnionTypeVar* utv) +const std::vector& getTypes(const UnionType* utv) { return utv->options; } -const std::vector& getTypes(const IntersectionTypeVar* itv) +const std::vector& getTypes(const IntersectionType* itv) { return itv->parts; } -UnionTypeVarIterator begin(const UnionTypeVar* utv) +UnionTypeIterator begin(const UnionType* utv) { - return UnionTypeVarIterator{utv}; + return UnionTypeIterator{utv}; } -UnionTypeVarIterator end(const UnionTypeVar* utv) +UnionTypeIterator end(const UnionType* utv) { - return UnionTypeVarIterator{}; + return UnionTypeIterator{}; } -IntersectionTypeVarIterator begin(const IntersectionTypeVar* itv) +IntersectionTypeIterator begin(const IntersectionType* itv) { - return IntersectionTypeVarIterator{itv}; + return IntersectionTypeIterator{itv}; } -IntersectionTypeVarIterator end(const IntersectionTypeVar* itv) +IntersectionTypeIterator end(const IntersectionType* itv) { - return IntersectionTypeVarIterator{}; + return IntersectionTypeIterator{}; } -static std::vector parseFormatString(NotNull singletonTypes, const char* data, size_t size) +static std::vector parseFormatString(NotNull builtinTypes, const char* data, size_t size) { const char* options = "cdiouxXeEfgGqs*"; @@ -1098,13 +1098,13 @@ static std::vector parseFormatString(NotNull singletonTy break; if (data[i] == 'q' || data[i] == 's') - result.push_back(singletonTypes->stringType); + result.push_back(builtinTypes->stringType); else if (data[i] == '*') - result.push_back(singletonTypes->unknownType); + result.push_back(builtinTypes->unknownType); else if (strchr(options, data[i])) - result.push_back(singletonTypes->numberType); + result.push_back(builtinTypes->numberType); else - result.push_back(singletonTypes->errorRecoveryType(singletonTypes->anyType)); + result.push_back(builtinTypes->errorRecoveryType(builtinTypes->anyType)); } } @@ -1133,7 +1133,7 @@ std::optional> magicFunctionFormat( if (!fmt) return std::nullopt; - std::vector expected = parseFormatString(typechecker.singletonTypes, fmt->value.data, fmt->value.size); + std::vector expected = parseFormatString(typechecker.builtinTypes, fmt->value.data, fmt->value.size); const auto& [params, tail] = flatten(paramPack); size_t paramOffset = 1; @@ -1176,7 +1176,7 @@ static bool dcrMagicFunctionFormat(MagicFunctionCallContext context) if (!fmt) return false; - std::vector expected = parseFormatString(context.solver->singletonTypes, fmt->value.data, fmt->value.size); + std::vector expected = parseFormatString(context.solver->builtinTypes, fmt->value.data, fmt->value.size); const auto& [params, tail] = flatten(context.arguments); size_t paramOffset = 1; @@ -1194,13 +1194,13 @@ static bool dcrMagicFunctionFormat(MagicFunctionCallContext context) if (numExpectedParams != numActualParams && (!tail || numExpectedParams < numActualParams)) context.solver->reportError(TypeError{context.callSite->location, CountMismatch{numExpectedParams, std::nullopt, numActualParams}}); - TypePackId resultPack = arena->addTypePack({context.solver->singletonTypes->stringType}); + TypePackId resultPack = arena->addTypePack({context.solver->builtinTypes->stringType}); asMutable(context.result)->ty.emplace(resultPack); return true; } -static std::vector parsePatternString(NotNull singletonTypes, const char* data, size_t size) +static std::vector parsePatternString(NotNull builtinTypes, const char* data, size_t size) { std::vector result; int depth = 0; @@ -1232,12 +1232,12 @@ static std::vector parsePatternString(NotNull singletonT if (i + 1 < size && data[i + 1] == ')') { i++; - result.push_back(singletonTypes->numberType); + result.push_back(builtinTypes->numberType); continue; } ++depth; - result.push_back(singletonTypes->stringType); + result.push_back(builtinTypes->stringType); } else if (data[i] == ')') { @@ -1255,7 +1255,7 @@ static std::vector parsePatternString(NotNull singletonT return std::vector(); if (result.empty()) - result.push_back(singletonTypes->stringType); + result.push_back(builtinTypes->stringType); return result; } @@ -1279,7 +1279,7 @@ static std::optional> magicFunctionGmatch( if (!pattern) return std::nullopt; - std::vector returnTypes = parsePatternString(typechecker.singletonTypes, pattern->value.data, pattern->value.size); + std::vector returnTypes = parsePatternString(typechecker.builtinTypes, pattern->value.data, pattern->value.size); if (returnTypes.empty()) return std::nullopt; @@ -1288,7 +1288,7 @@ static std::optional> magicFunctionGmatch( const TypePackId emptyPack = arena.addTypePack({}); const TypePackId returnList = arena.addTypePack(returnTypes); - const TypeId iteratorType = arena.addType(FunctionTypeVar{emptyPack, returnList}); + const TypeId iteratorType = arena.addType(FunctionType{emptyPack, returnList}); return WithPredicate{arena.addTypePack({iteratorType})}; } @@ -1309,16 +1309,16 @@ static bool dcrMagicFunctionGmatch(MagicFunctionCallContext context) if (!pattern) return false; - std::vector returnTypes = parsePatternString(context.solver->singletonTypes, pattern->value.data, pattern->value.size); + std::vector returnTypes = parsePatternString(context.solver->builtinTypes, pattern->value.data, pattern->value.size); if (returnTypes.empty()) return false; - context.solver->unify(params[0], context.solver->singletonTypes->stringType, context.solver->rootScope); + context.solver->unify(params[0], context.solver->builtinTypes->stringType, context.solver->rootScope); const TypePackId emptyPack = arena->addTypePack({}); const TypePackId returnList = arena->addTypePack(returnTypes); - const TypeId iteratorType = arena->addType(FunctionTypeVar{emptyPack, returnList}); + const TypeId iteratorType = arena->addType(FunctionType{emptyPack, returnList}); const TypePackId resTypePack = arena->addTypePack({iteratorType}); asMutable(context.result)->ty.emplace(resTypePack); @@ -1344,14 +1344,14 @@ static std::optional> magicFunctionMatch( if (!pattern) return std::nullopt; - std::vector returnTypes = parsePatternString(typechecker.singletonTypes, pattern->value.data, pattern->value.size); + std::vector returnTypes = parsePatternString(typechecker.builtinTypes, pattern->value.data, pattern->value.size); if (returnTypes.empty()) return std::nullopt; typechecker.unify(params[0], typechecker.stringType, scope, expr.args.data[0]->location); - const TypeId optionalNumber = arena.addType(UnionTypeVar{{typechecker.nilType, typechecker.numberType}}); + const TypeId optionalNumber = arena.addType(UnionType{{typechecker.nilType, typechecker.numberType}}); size_t initIndex = expr.self ? 1 : 2; if (params.size() == 3 && expr.args.size > initIndex) @@ -1378,14 +1378,14 @@ static bool dcrMagicFunctionMatch(MagicFunctionCallContext context) if (!pattern) return false; - std::vector returnTypes = parsePatternString(context.solver->singletonTypes, pattern->value.data, pattern->value.size); + std::vector returnTypes = parsePatternString(context.solver->builtinTypes, pattern->value.data, pattern->value.size); if (returnTypes.empty()) return false; - context.solver->unify(params[0], context.solver->singletonTypes->stringType, context.solver->rootScope); + context.solver->unify(params[0], context.solver->builtinTypes->stringType, context.solver->rootScope); - const TypeId optionalNumber = arena->addType(UnionTypeVar{{context.solver->singletonTypes->nilType, context.solver->singletonTypes->numberType}}); + const TypeId optionalNumber = arena->addType(UnionType{{context.solver->builtinTypes->nilType, context.solver->builtinTypes->numberType}}); size_t initIndex = context.callSite->self ? 1 : 2; if (params.size() == 3 && context.callSite->args.size > initIndex) @@ -1427,7 +1427,7 @@ static std::optional> magicFunctionFind( std::vector returnTypes; if (!plain) { - returnTypes = parsePatternString(typechecker.singletonTypes, pattern->value.data, pattern->value.size); + returnTypes = parsePatternString(typechecker.builtinTypes, pattern->value.data, pattern->value.size); if (returnTypes.empty()) return std::nullopt; @@ -1435,8 +1435,8 @@ static std::optional> magicFunctionFind( typechecker.unify(params[0], typechecker.stringType, scope, expr.args.data[0]->location); - const TypeId optionalNumber = arena.addType(UnionTypeVar{{typechecker.nilType, typechecker.numberType}}); - const TypeId optionalBoolean = arena.addType(UnionTypeVar{{typechecker.nilType, typechecker.booleanType}}); + const TypeId optionalNumber = arena.addType(UnionType{{typechecker.nilType, typechecker.numberType}}); + const TypeId optionalBoolean = arena.addType(UnionType{{typechecker.nilType, typechecker.booleanType}}); size_t initIndex = expr.self ? 1 : 2; if (params.size() >= 3 && expr.args.size > initIndex) @@ -1459,7 +1459,7 @@ static bool dcrMagicFunctionFind(MagicFunctionCallContext context) return false; TypeArena* arena = context.solver->arena; - NotNull singletonTypes = context.solver->singletonTypes; + NotNull builtinTypes = context.solver->builtinTypes; AstExprConstantString* pattern = nullptr; size_t patternIndex = context.callSite->self ? 0 : 1; @@ -1480,16 +1480,16 @@ static bool dcrMagicFunctionFind(MagicFunctionCallContext context) std::vector returnTypes; if (!plain) { - returnTypes = parsePatternString(singletonTypes, pattern->value.data, pattern->value.size); + returnTypes = parsePatternString(builtinTypes, pattern->value.data, pattern->value.size); if (returnTypes.empty()) return false; } - context.solver->unify(params[0], singletonTypes->stringType, context.solver->rootScope); + context.solver->unify(params[0], builtinTypes->stringType, context.solver->rootScope); - const TypeId optionalNumber = arena->addType(UnionTypeVar{{singletonTypes->nilType, singletonTypes->numberType}}); - const TypeId optionalBoolean = arena->addType(UnionTypeVar{{singletonTypes->nilType, singletonTypes->booleanType}}); + const TypeId optionalNumber = arena->addType(UnionType{{builtinTypes->nilType, builtinTypes->numberType}}); + const TypeId optionalBoolean = arena->addType(UnionType{{builtinTypes->nilType, builtinTypes->booleanType}}); size_t initIndex = context.callSite->self ? 1 : 2; if (params.size() >= 3 && context.callSite->args.size > initIndex) @@ -1509,7 +1509,7 @@ std::vector filterMap(TypeId type, TypeIdPredicate predicate) { type = follow(type); - if (auto utv = get(type)) + if (auto utv = get(type)) { std::set options; for (TypeId option : utv) @@ -1528,11 +1528,11 @@ static Tags* getTags(TypeId ty) { ty = follow(ty); - if (auto ftv = getMutable(ty)) + if (auto ftv = getMutable(ty)) return &ftv->tags; - else if (auto ttv = getMutable(ty)) + else if (auto ttv = getMutable(ty)) return &ttv->tags; - else if (auto ctv = getMutable(ty)) + else if (auto ctv = getMutable(ty)) return &ctv->tags; return nullptr; @@ -1565,7 +1565,7 @@ bool hasTag(TypeId ty, const std::string& tagName) // We special case classes because getTags only returns a pointer to one vector of tags. // But classes has multiple vector of tags, represented throughout the hierarchy. - if (auto ctv = get(ty)) + if (auto ctv = get(ty)) { while (ctv) { @@ -1574,7 +1574,7 @@ bool hasTag(TypeId ty, const std::string& tagName) else if (!ctv->parent) return false; - ctv = get(*ctv->parent); + ctv = get(*ctv->parent); LUAU_ASSERT(ctv); } } diff --git a/Analysis/src/TypeArena.cpp b/Analysis/src/TypeArena.cpp index 666ab8674..ed51517ea 100644 --- a/Analysis/src/TypeArena.cpp +++ b/Analysis/src/TypeArena.cpp @@ -9,13 +9,13 @@ namespace Luau void TypeArena::clear() { - typeVars.clear(); + types.clear(); typePacks.clear(); } -TypeId TypeArena::addTV(TypeVar&& tv) +TypeId TypeArena::addTV(Type&& tv) { - TypeId allocated = typeVars.allocate(std::move(tv)); + TypeId allocated = types.allocate(std::move(tv)); asMutable(allocated)->owningArena = this; @@ -24,7 +24,7 @@ TypeId TypeArena::addTV(TypeVar&& tv) TypeId TypeArena::freshType(TypeLevel level) { - TypeId allocated = typeVars.allocate(FreeTypeVar{level}); + TypeId allocated = types.allocate(FreeType{level}); asMutable(allocated)->owningArena = this; @@ -33,7 +33,7 @@ TypeId TypeArena::freshType(TypeLevel level) TypeId TypeArena::freshType(Scope* scope) { - TypeId allocated = typeVars.allocate(FreeTypeVar{scope}); + TypeId allocated = types.allocate(FreeType{scope}); asMutable(allocated)->owningArena = this; @@ -42,7 +42,7 @@ TypeId TypeArena::freshType(Scope* scope) TypeId TypeArena::freshType(Scope* scope, TypeLevel level) { - TypeId allocated = typeVars.allocate(FreeTypeVar{scope, level}); + TypeId allocated = types.allocate(FreeType{scope, level}); asMutable(allocated)->owningArena = this; @@ -99,7 +99,7 @@ void freeze(TypeArena& arena) if (!FFlag::DebugLuauFreezeArena) return; - arena.typeVars.freeze(); + arena.types.freeze(); arena.typePacks.freeze(); } @@ -108,7 +108,7 @@ void unfreeze(TypeArena& arena) if (!FFlag::DebugLuauFreezeArena) return; - arena.typeVars.unfreeze(); + arena.types.unfreeze(); arena.typePacks.unfreeze(); } diff --git a/Analysis/src/TypeAttach.cpp b/Analysis/src/TypeAttach.cpp index e483c0473..d1d89b25a 100644 --- a/Analysis/src/TypeAttach.cpp +++ b/Analysis/src/TypeAttach.cpp @@ -8,7 +8,7 @@ #include "Luau/ToString.h" #include "Luau/TypeInfer.h" #include "Luau/TypePack.h" -#include "Luau/TypeVar.h" +#include "Luau/Type.h" #include @@ -75,36 +75,36 @@ class TypeRehydrationVisitor AstTypePack* rehydrate(TypePackId tp); - AstType* operator()(const PrimitiveTypeVar& ptv) + AstType* operator()(const PrimitiveType& ptv) { switch (ptv.type) { - case PrimitiveTypeVar::NilType: + case PrimitiveType::NilType: return allocator->alloc(Location(), std::nullopt, AstName("nil")); - case PrimitiveTypeVar::Boolean: + case PrimitiveType::Boolean: return allocator->alloc(Location(), std::nullopt, AstName("boolean")); - case PrimitiveTypeVar::Number: + case PrimitiveType::Number: return allocator->alloc(Location(), std::nullopt, AstName("number")); - case PrimitiveTypeVar::String: + case PrimitiveType::String: return allocator->alloc(Location(), std::nullopt, AstName("string")); - case PrimitiveTypeVar::Thread: + case PrimitiveType::Thread: return allocator->alloc(Location(), std::nullopt, AstName("thread")); default: return nullptr; } } - AstType* operator()(const BlockedTypeVar& btv) + AstType* operator()(const BlockedType& btv) { return allocator->alloc(Location(), std::nullopt, AstName("*blocked*")); } - AstType* operator()(const PendingExpansionTypeVar& petv) + AstType* operator()(const PendingExpansionType& petv) { return allocator->alloc(Location(), std::nullopt, AstName("*pending-expansion*")); } - AstType* operator()(const SingletonTypeVar& stv) + AstType* operator()(const SingletonType& stv) { if (const BooleanSingleton* bs = get(&stv)) return allocator->alloc(Location(), bs->value); @@ -119,11 +119,11 @@ class TypeRehydrationVisitor return nullptr; } - AstType* operator()(const AnyTypeVar&) + AstType* operator()(const AnyType&) { return allocator->alloc(Location(), std::nullopt, AstName("any")); } - AstType* operator()(const TableTypeVar& ttv) + AstType* operator()(const TableType& ttv) { RecursionCounter counter(&count); @@ -182,12 +182,12 @@ class TypeRehydrationVisitor return allocator->alloc(Location(), props, indexer); } - AstType* operator()(const MetatableTypeVar& mtv) + AstType* operator()(const MetatableType& mtv) { return Luau::visit(*this, mtv.table->ty); } - AstType* operator()(const ClassTypeVar& ctv) + AstType* operator()(const ClassType& ctv) { RecursionCounter counter(&count); @@ -214,7 +214,7 @@ class TypeRehydrationVisitor return allocator->alloc(Location(), props); } - AstType* operator()(const FunctionTypeVar& ftv) + AstType* operator()(const FunctionType& ftv) { RecursionCounter counter(&count); @@ -227,7 +227,7 @@ class TypeRehydrationVisitor size_t numGenerics = 0; for (auto it = ftv.generics.begin(); it != ftv.generics.end(); ++it) { - if (auto gtv = get(*it)) + if (auto gtv = get(*it)) generics.data[numGenerics++] = {AstName(gtv->name.c_str()), Location(), nullptr}; } @@ -237,7 +237,7 @@ class TypeRehydrationVisitor size_t numGenericPacks = 0; for (auto it = ftv.genericPacks.begin(); it != ftv.genericPacks.end(); ++it) { - if (auto gtv = get(*it)) + if (auto gtv = get(*it)) genericPacks.data[numGenericPacks++] = {AstName(gtv->name.c_str()), Location(), nullptr}; } @@ -292,7 +292,7 @@ class TypeRehydrationVisitor { return allocator->alloc(Location(), std::nullopt, AstName("Unifiable")); } - AstType* operator()(const GenericTypeVar& gtv) + AstType* operator()(const GenericType& gtv) { return allocator->alloc(Location(), std::nullopt, AstName(getName(allocator, syntheticNames, gtv))); } @@ -300,11 +300,11 @@ class TypeRehydrationVisitor { return Luau::visit(*this, bound.boundTo->ty); } - AstType* operator()(const FreeTypeVar& ftv) + AstType* operator()(const FreeType& ftv) { return allocator->alloc(Location(), std::nullopt, AstName("free")); } - AstType* operator()(const UnionTypeVar& uv) + AstType* operator()(const UnionType& uv) { AstArray unionTypes; unionTypes.size = uv.options.size(); @@ -315,7 +315,7 @@ class TypeRehydrationVisitor } return allocator->alloc(Location(), unionTypes); } - AstType* operator()(const IntersectionTypeVar& uv) + AstType* operator()(const IntersectionType& uv) { AstArray intersectionTypes; intersectionTypes.size = uv.parts.size(); @@ -326,22 +326,22 @@ class TypeRehydrationVisitor } return allocator->alloc(Location(), intersectionTypes); } - AstType* operator()(const LazyTypeVar& ltv) + AstType* operator()(const LazyType& ltv) { return allocator->alloc(Location(), std::nullopt, AstName("")); } - AstType* operator()(const UnknownTypeVar& ttv) + AstType* operator()(const UnknownType& ttv) { return allocator->alloc(Location(), std::nullopt, AstName{"unknown"}); } - AstType* operator()(const NeverTypeVar& ttv) + AstType* operator()(const NeverType& ttv) { return allocator->alloc(Location(), std::nullopt, AstName{"never"}); } - AstType* operator()(const NegationTypeVar& ntv) + AstType* operator()(const NegationType& ntv) { - // FIXME: do the same thing we do with ErrorTypeVar - throw InternalCompilerError("Cannot convert NegationTypeVar into AstNode"); + // FIXME: do the same thing we do with ErrorType + throw InternalCompilerError("Cannot convert NegationType into AstNode"); } private: diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index 8c44f90a5..5451a454e 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -10,7 +10,7 @@ #include "Luau/ToString.h" #include "Luau/TxnLog.h" #include "Luau/TypeUtils.h" -#include "Luau/TypeVar.h" +#include "Luau/Type.h" #include "Luau/Unifier.h" #include "Luau/ToString.h" #include "Luau/DcrLogger.h" @@ -19,6 +19,7 @@ LUAU_FASTFLAG(DebugLuauLogSolverToJson); LUAU_FASTFLAG(DebugLuauMagicTypes); +LUAU_FASTFLAG(LuauNegatedClassTypes) namespace Luau { @@ -82,19 +83,20 @@ static std::optional getIdentifierOfBaseVar(AstExpr* node) struct TypeChecker2 { - NotNull singletonTypes; + NotNull builtinTypes; DcrLogger* logger; InternalErrorReporter ice; // FIXME accept a pointer from Frontend const SourceModule* sourceModule; Module* module; + TypeArena testArena; std::vector> stack; UnifierSharedState sharedState{&ice}; - Normalizer normalizer{&module->internalTypes, singletonTypes, NotNull{&sharedState}}; + Normalizer normalizer{&testArena, builtinTypes, NotNull{&sharedState}}; - TypeChecker2(NotNull singletonTypes, DcrLogger* logger, const SourceModule* sourceModule, Module* module) - : singletonTypes(singletonTypes) + TypeChecker2(NotNull builtinTypes, DcrLogger* logger, const SourceModule* sourceModule, Module* module) + : builtinTypes(builtinTypes) , logger(logger) , sourceModule(sourceModule) , module(module) @@ -120,7 +122,7 @@ struct TypeChecker2 if (tp) return follow(*tp); else - return singletonTypes->anyTypePack; + return builtinTypes->anyTypePack; } TypeId lookupType(AstExpr* expr) @@ -136,7 +138,7 @@ struct TypeChecker2 if (tp) return flattenPack(*tp); - return singletonTypes->anyType; + return builtinTypes->anyType; } TypeId lookupAnnotation(AstType* annotation) @@ -298,7 +300,7 @@ struct TypeChecker2 Scope* scope = findInnermostScope(ret->location); TypePackId expectedRetType = scope->returnType; - TypeArena* arena = &module->internalTypes; + TypeArena* arena = &testArena; TypePackId actualRetType = reconstructPack(ret->list, *arena); Unifier u{NotNull{&normalizer}, Mode::Strict, stack.back(), ret->location, Covariant}; @@ -346,6 +348,8 @@ struct TypeChecker2 if (!errors.empty()) reportErrors(std::move(errors)); } + + visit(var->annotation); } } else @@ -368,6 +372,8 @@ struct TypeChecker2 ErrorVec errors = tryUnify(stack.back(), value->location, *it, varType); if (!errors.empty()) reportErrors(std::move(errors)); + + visit(var->annotation); } ++it; @@ -406,7 +412,7 @@ struct TypeChecker2 return; NotNull scope = stack.back(); - TypeArena& arena = module->internalTypes; + TypeArena& arena = testArena; std::vector variableTypes; for (AstLocal* var : forInStatement->vars) @@ -425,7 +431,7 @@ struct TypeChecker2 TypePackId iteratorPack = arena.addTypePack(valueTypes, iteratorTail); // ... and then expand it out to 3 values (if possible) - TypePack iteratorTypes = extendTypePack(arena, singletonTypes, iteratorPack, 3); + TypePack iteratorTypes = extendTypePack(arena, builtinTypes, iteratorPack, 3); if (iteratorTypes.head.empty()) { reportError(GenericError{"for..in loops require at least one value to iterate over. Got zero"}, getLocation(forInStatement->values)); @@ -434,7 +440,7 @@ struct TypeChecker2 TypeId iteratorTy = follow(iteratorTypes.head[0]); auto checkFunction = [this, &arena, &scope, &forInStatement, &variableTypes]( - const FunctionTypeVar* iterFtv, std::vector iterTys, bool isMm) { + const FunctionType* iterFtv, std::vector iterTys, bool isMm) { if (iterTys.size() < 1 || iterTys.size() > 3) { if (isMm) @@ -446,7 +452,7 @@ struct TypeChecker2 } // It is okay if there aren't enough iterators, but the iteratee must provide enough. - TypePack expectedVariableTypes = extendTypePack(arena, singletonTypes, iterFtv->retTypes, variableTypes.size()); + TypePack expectedVariableTypes = extendTypePack(arena, builtinTypes, iterFtv->retTypes, variableTypes.size()); if (expectedVariableTypes.head.size() < variableTypes.size()) { if (isMm) @@ -478,7 +484,7 @@ struct TypeChecker2 if (maxCount && *maxCount < 2) reportError(CountMismatch{2, std::nullopt, *maxCount, CountMismatch::Arg}, forInStatement->vars.data[0]->location); - TypePack flattenedArgTypes = extendTypePack(arena, singletonTypes, iterFtv->argTypes, 2); + TypePack flattenedArgTypes = extendTypePack(arena, builtinTypes, iterFtv->argTypes, 2); size_t firstIterationArgCount = iterTys.empty() ? 0 : iterTys.size() - 1; size_t actualArgCount = expectedVariableTypes.head.size(); @@ -515,11 +521,11 @@ struct TypeChecker2 * nil. * * nextTy() must be callable with only 2 arguments. */ - if (const FunctionTypeVar* nextFn = get(iteratorTy)) + if (const FunctionType* nextFn = get(iteratorTy)) { checkFunction(nextFn, iteratorTypes.head, false); } - else if (const TableTypeVar* ttv = get(iteratorTy)) + else if (const TableType* ttv = get(iteratorTy)) { if ((forInStatement->vars.size == 1 || forInStatement->vars.size == 2) && ttv->indexer) { @@ -530,23 +536,23 @@ struct TypeChecker2 else reportError(GenericError{"Cannot iterate over a table without indexer"}, forInStatement->values.data[0]->location); } - else if (get(iteratorTy) || get(iteratorTy)) + else if (get(iteratorTy) || get(iteratorTy)) { // nothing } else if (std::optional iterMmTy = - findMetatableEntry(singletonTypes, module->errors, iteratorTy, "__iter", forInStatement->values.data[0]->location)) + findMetatableEntry(builtinTypes, module->errors, iteratorTy, "__iter", forInStatement->values.data[0]->location)) { Instantiation instantiation{TxnLog::empty(), &arena, TypeLevel{}, scope}; if (std::optional instantiatedIterMmTy = instantiation.substitute(*iterMmTy)) { - if (const FunctionTypeVar* iterMmFtv = get(*instantiatedIterMmTy)) + if (const FunctionType* iterMmFtv = get(*instantiatedIterMmTy)) { TypePackId argPack = arena.addTypePack({iteratorTy}); reportErrors(tryUnify(scope, forInStatement->values.data[0]->location, argPack, iterMmFtv->argTypes)); - TypePack mmIteratorTypes = extendTypePack(arena, singletonTypes, iterMmFtv->retTypes, 3); + TypePack mmIteratorTypes = extendTypePack(arena, builtinTypes, iterMmFtv->retTypes, 3); if (mmIteratorTypes.head.size() == 0) { @@ -561,7 +567,7 @@ struct TypeChecker2 std::vector instantiatedIteratorTypes = mmIteratorTypes.head; instantiatedIteratorTypes[0] = *instantiatedNextFn; - if (const FunctionTypeVar* nextFtv = get(*instantiatedNextFn)) + if (const FunctionType* nextFtv = get(*instantiatedNextFn)) { checkFunction(nextFtv, instantiatedIteratorTypes, true); } @@ -760,7 +766,7 @@ struct TypeChecker2 void visit(AstExprConstantNumber* number) { TypeId actualType = lookupType(number); - TypeId numberType = singletonTypes->numberType; + TypeId numberType = builtinTypes->numberType; if (!isSubtype(numberType, actualType, stack.back())) { @@ -771,7 +777,7 @@ struct TypeChecker2 void visit(AstExprConstantString* string) { TypeId actualType = lookupType(string); - TypeId stringType = singletonTypes->stringType; + TypeId stringType = builtinTypes->stringType; if (!isSubtype(actualType, stringType, stack.back())) { @@ -801,7 +807,7 @@ struct TypeChecker2 for (AstExpr* arg : call->args) visit(arg); - TypeArena* arena = &module->internalTypes; + TypeArena* arena = &testArena; Instantiation instantiation{TxnLog::empty(), arena, TypeLevel{}, stack.back()}; TypePackId expectedRetType = lookupPack(call); @@ -809,11 +815,11 @@ struct TypeChecker2 TypeId testFunctionType = functionType; TypePack args; - if (get(functionType) || get(functionType)) + if (get(functionType) || get(functionType)) return; - else if (std::optional callMm = findMetatableEntry(singletonTypes, module->errors, functionType, "__call", call->func->location)) + else if (std::optional callMm = findMetatableEntry(builtinTypes, module->errors, functionType, "__call", call->func->location)) { - if (get(follow(*callMm))) + if (get(follow(*callMm))) { if (std::optional instantiatedCallMm = instantiation.substitute(*callMm)) { @@ -834,7 +840,7 @@ struct TypeChecker2 return; } } - else if (get(functionType)) + else if (get(functionType)) { if (std::optional instantiatedFunctionType = instantiation.substitute(functionType)) { @@ -846,7 +852,7 @@ struct TypeChecker2 return; } } - else if (auto utv = get(functionType)) + else if (auto utv = get(functionType)) { // Sometimes it's okay to call a union of functions, but only if all of the functions are the same. std::optional fst; @@ -862,7 +868,7 @@ struct TypeChecker2 } if (!fst) - ice.ice("UnionTypeVar had no elements, so fst is nullopt?"); + ice.ice("UnionType had no elements, so fst is nullopt?"); if (std::optional instantiatedFunctionType = instantiation.substitute(*fst)) { @@ -901,20 +907,20 @@ struct TypeChecker2 if (argTail) args.tail = *argTail; else - args.tail = singletonTypes->anyTypePack; + args.tail = builtinTypes->anyTypePack; } else - args.head.push_back(singletonTypes->anyType); + args.head.push_back(builtinTypes->anyType); } TypePackId argsTp = arena->addTypePack(args); - FunctionTypeVar ftv{argsTp, expectedRetType}; + FunctionType ftv{argsTp, expectedRetType}; TypeId expectedType = arena->addType(ftv); if (!isSubtype(testFunctionType, expectedType, stack.back())) { CloneState cloneState; - expectedType = clone(expectedType, module->internalTypes, cloneState); + expectedType = clone(expectedType, testArena, cloneState); reportError(TypeMismatch{expectedType, functionType}, call->location); } } @@ -942,7 +948,7 @@ struct TypeChecker2 auto StackPusher = pushStack(fn); TypeId inferredFnTy = lookupType(fn); - const FunctionTypeVar* inferredFtv = get(inferredFnTy); + const FunctionType* inferredFtv = get(inferredFnTy); LUAU_ASSERT(inferredFtv); auto argIt = begin(inferredFtv->argTypes); @@ -986,24 +992,24 @@ struct TypeChecker2 NotNull scope = stack.back(); TypeId operandType = lookupType(expr->expr); - if (get(operandType) || get(operandType) || get(operandType)) + if (get(operandType) || get(operandType) || get(operandType)) return; if (auto it = kUnaryOpMetamethods.find(expr->op); it != kUnaryOpMetamethods.end()) { - std::optional mm = findMetatableEntry(singletonTypes, module->errors, operandType, it->second, expr->location); + std::optional mm = findMetatableEntry(builtinTypes, module->errors, operandType, it->second, expr->location); if (mm) { - if (const FunctionTypeVar* ftv = get(follow(*mm))) + if (const FunctionType* ftv = get(follow(*mm))) { - TypePackId expectedArgs = module->internalTypes.addTypePack({operandType}); + TypePackId expectedArgs = testArena.addTypePack({operandType}); reportErrors(tryUnify(scope, expr->location, expectedArgs, ftv->argTypes)); if (std::optional ret = first(ftv->retTypes)) { if (expr->op == AstExprUnary::Op::Len) { - reportErrors(tryUnify(scope, expr->location, follow(*ret), singletonTypes->numberType)); + reportErrors(tryUnify(scope, expr->location, follow(*ret), builtinTypes->numberType)); } } else @@ -1028,7 +1034,7 @@ struct TypeChecker2 } else if (expr->op == AstExprUnary::Op::Minus) { - reportErrors(tryUnify(scope, expr->location, operandType, singletonTypes->numberType)); + reportErrors(tryUnify(scope, expr->location, operandType, builtinTypes->numberType)); } else if (expr->op == AstExprUnary::Op::Not) { @@ -1055,15 +1061,15 @@ struct TypeChecker2 if (expr->op == AstExprBinary::Op::Or) { - leftType = stripNil(singletonTypes, module->internalTypes, leftType); + leftType = stripNil(builtinTypes, testArena, leftType); } bool isStringOperation = isString(leftType) && isString(rightType); - if (get(leftType) || get(leftType) || get(rightType) || get(rightType)) + if (get(leftType) || get(leftType) || get(rightType) || get(rightType)) return; - if ((get(leftType) || get(leftType)) && !isEquality && !isLogical) + if ((get(leftType) || get(leftType)) && !isEquality && !isLogical) { auto name = getIdentifierOfBaseVar(expr->left); reportError(CannotInferBinaryOperation{expr->op, name, @@ -1074,16 +1080,16 @@ struct TypeChecker2 if (auto it = kBinaryOpMetamethods.find(expr->op); it != kBinaryOpMetamethods.end()) { - std::optional leftMt = getMetatable(leftType, singletonTypes); - std::optional rightMt = getMetatable(rightType, singletonTypes); + std::optional leftMt = getMetatable(leftType, builtinTypes); + std::optional rightMt = getMetatable(rightType, builtinTypes); bool matches = leftMt == rightMt; if (isEquality && !matches) { - auto testUnion = [&matches, singletonTypes = this->singletonTypes](const UnionTypeVar* utv, std::optional otherMt) { + auto testUnion = [&matches, builtinTypes = this->builtinTypes](const UnionType* utv, std::optional otherMt) { for (TypeId option : utv) { - if (getMetatable(follow(option), singletonTypes) == otherMt) + if (getMetatable(follow(option), builtinTypes) == otherMt) { matches = true; break; @@ -1091,12 +1097,12 @@ struct TypeChecker2 } }; - if (const UnionTypeVar* utv = get(leftType); utv && rightMt) + if (const UnionType* utv = get(leftType); utv && rightMt) { testUnion(utv, rightMt); } - if (const UnionTypeVar* utv = get(rightType); utv && leftMt && !matches) + if (const UnionType* utv = get(rightType); utv && leftMt && !matches) { testUnion(utv, leftMt); } @@ -1112,9 +1118,9 @@ struct TypeChecker2 } std::optional mm; - if (std::optional leftMm = findMetatableEntry(singletonTypes, module->errors, leftType, it->second, expr->left->location)) + if (std::optional leftMm = findMetatableEntry(builtinTypes, module->errors, leftType, it->second, expr->left->location)) mm = leftMm; - else if (std::optional rightMm = findMetatableEntry(singletonTypes, module->errors, rightType, it->second, expr->right->location)) + else if (std::optional rightMm = findMetatableEntry(builtinTypes, module->errors, rightType, it->second, expr->right->location)) { mm = rightMm; std::swap(leftType, rightType); @@ -1126,18 +1132,18 @@ struct TypeChecker2 if (!instantiatedMm) reportError(CodeTooComplex{}, expr->location); - else if (const FunctionTypeVar* ftv = get(follow(instantiatedMm))) + else if (const FunctionType* ftv = get(follow(instantiatedMm))) { TypePackId expectedArgs; // For >= and > we invoke __lt and __le respectively with // swapped argument ordering. if (expr->op == AstExprBinary::Op::CompareGe || expr->op == AstExprBinary::Op::CompareGt) { - expectedArgs = module->internalTypes.addTypePack({rightType, leftType}); + expectedArgs = testArena.addTypePack({rightType, leftType}); } else { - expectedArgs = module->internalTypes.addTypePack({leftType, rightType}); + expectedArgs = testArena.addTypePack({leftType, rightType}); } reportErrors(tryUnify(scope, expr->location, ftv->argTypes, expectedArgs)); @@ -1145,7 +1151,7 @@ struct TypeChecker2 if (expr->op == AstExprBinary::CompareEq || expr->op == AstExprBinary::CompareNe || expr->op == AstExprBinary::CompareGe || expr->op == AstExprBinary::CompareGt || expr->op == AstExprBinary::Op::CompareLe || expr->op == AstExprBinary::Op::CompareLt) { - TypePackId expectedRets = module->internalTypes.addTypePack({singletonTypes->booleanType}); + TypePackId expectedRets = testArena.addTypePack({builtinTypes->booleanType}); if (!isSubtype(ftv->retTypes, expectedRets, scope)) { reportError(GenericError{format("Metamethod '%s' must return type 'boolean'", it->second)}, expr->location); @@ -1186,7 +1192,7 @@ struct TypeChecker2 return; } - else if (!leftMt && !rightMt && (get(leftType) || get(rightType))) + else if (!leftMt && !rightMt && (get(leftType) || get(rightType))) { if (isComparison) { @@ -1214,13 +1220,13 @@ struct TypeChecker2 case AstExprBinary::Op::Div: case AstExprBinary::Op::Pow: case AstExprBinary::Op::Mod: - reportErrors(tryUnify(scope, expr->left->location, leftType, singletonTypes->numberType)); - reportErrors(tryUnify(scope, expr->right->location, rightType, singletonTypes->numberType)); + reportErrors(tryUnify(scope, expr->left->location, leftType, builtinTypes->numberType)); + reportErrors(tryUnify(scope, expr->right->location, rightType, builtinTypes->numberType)); break; case AstExprBinary::Op::Concat: - reportErrors(tryUnify(scope, expr->left->location, leftType, singletonTypes->stringType)); - reportErrors(tryUnify(scope, expr->right->location, rightType, singletonTypes->stringType)); + reportErrors(tryUnify(scope, expr->left->location, leftType, builtinTypes->stringType)); + reportErrors(tryUnify(scope, expr->right->location, rightType, builtinTypes->stringType)); break; case AstExprBinary::Op::CompareGe: @@ -1228,9 +1234,9 @@ struct TypeChecker2 case AstExprBinary::Op::CompareLe: case AstExprBinary::Op::CompareLt: if (isNumber(leftType)) - reportErrors(tryUnify(scope, expr->right->location, rightType, singletonTypes->numberType)); + reportErrors(tryUnify(scope, expr->right->location, rightType, builtinTypes->numberType)); else if (isString(leftType)) - reportErrors(tryUnify(scope, expr->right->location, rightType, singletonTypes->stringType)); + reportErrors(tryUnify(scope, expr->right->location, rightType, builtinTypes->stringType)); else reportError(GenericError{format("Types '%s' and '%s' cannot be compared with relational operator %s", toString(leftType).c_str(), toString(rightType).c_str(), toString(expr->op).c_str())}, @@ -1304,8 +1310,8 @@ struct TypeChecker2 return vtp->ty; else if (auto ftp = get(pack)) { - TypeId result = module->internalTypes.addType(FreeTypeVar{ftp->scope}); - TypePackId freeTail = module->internalTypes.addTypePack(FreeTypePack{ftp->scope}); + TypeId result = testArena.addType(FreeType{ftp->scope}); + TypePackId freeTail = testArena.addTypePack(FreeTypePack{ftp->scope}); TypePack& resultPack = asMutable(pack)->ty.emplace(); resultPack.head.assign(1, result); @@ -1314,7 +1320,7 @@ struct TypeChecker2 return result; } else if (get(pack)) - return singletonTypes->errorRecoveryType(); + return builtinTypes->errorRecoveryType(); else ice.ice("flattenPack got a weird pack!"); } @@ -1337,6 +1343,11 @@ struct TypeChecker2 void visit(AstTypeReference* ty) { + // No further validation is necessary in this case. The main logic for + // _luau_print is contained in lookupAnnotation. + if (FFlag::DebugLuauMagicTypes && ty->name == "_luau_print" && ty->parameters.size > 0) + return; + for (const AstTypeOrPack& param : ty->parameters) { if (param.type) @@ -1613,18 +1624,29 @@ struct TypeChecker2 fetch(norm.tops); fetch(norm.booleans); - for (TypeId ty : norm.classes) - fetch(ty); + + if (FFlag::LuauNegatedClassTypes) + { + for (const auto& [ty, _negations] : norm.classes.classes) + { + fetch(ty); + } + } + else + { + for (TypeId ty : norm.DEPRECATED_classes) + fetch(ty); + } fetch(norm.errors); fetch(norm.nils); fetch(norm.numbers); if (!norm.strings.isNever()) - fetch(singletonTypes->stringType); + fetch(builtinTypes->stringType); fetch(norm.threads); for (TypeId ty : norm.tables) fetch(ty); if (norm.functions.isTop) - fetch(singletonTypes->functionType); + fetch(builtinTypes->functionType); else if (!norm.functions.isNever()) { if (norm.functions.parts->size() == 1) @@ -1633,15 +1655,15 @@ struct TypeChecker2 { std::vector parts; parts.insert(parts.end(), norm.functions.parts->begin(), norm.functions.parts->end()); - fetch(module->internalTypes.addType(IntersectionTypeVar{std::move(parts)})); + fetch(testArena.addType(IntersectionType{std::move(parts)})); } } for (const auto& [tyvar, intersect] : norm.tyvars) { - if (get(intersect->tops)) + if (get(intersect->tops)) { TypeId ty = normalizer.typeFromNormal(*intersect); - fetch(module->internalTypes.addType(IntersectionTypeVar{{tyvar, ty}})); + fetch(testArena.addType(IntersectionType{{tyvar, ty}})); } else fetch(tyvar); @@ -1658,23 +1680,23 @@ struct TypeChecker2 bool hasIndexTypeFromType(TypeId ty, const std::string& prop, const Location& location) { - if (get(ty) || get(ty) || get(ty)) + if (get(ty) || get(ty) || get(ty)) return true; if (isString(ty)) { - std::optional mtIndex = Luau::findMetatableEntry(singletonTypes, module->errors, singletonTypes->stringType, "__index", location); + std::optional mtIndex = Luau::findMetatableEntry(builtinTypes, module->errors, builtinTypes->stringType, "__index", location); LUAU_ASSERT(mtIndex); ty = *mtIndex; } if (getTableType(ty)) - return bool(findTablePropertyRespectingMeta(singletonTypes, module->errors, ty, prop, location)); - else if (const ClassTypeVar* cls = get(ty)) + return bool(findTablePropertyRespectingMeta(builtinTypes, module->errors, ty, prop, location)); + else if (const ClassType* cls = get(ty)) return bool(lookupClassProp(cls, prop)); - else if (const UnionTypeVar* utv = get(ty)) - ice.ice("getIndexTypeFromTypeHelper cannot take a UnionTypeVar"); - else if (const IntersectionTypeVar* itv = get(ty)) + else if (const UnionType* utv = get(ty)) + ice.ice("getIndexTypeFromTypeHelper cannot take a UnionType"); + else if (const IntersectionType* itv = get(ty)) return std::any_of(begin(itv), end(itv), [&](TypeId part) { return hasIndexTypeFromType(part, prop, location); }); @@ -1683,11 +1705,15 @@ struct TypeChecker2 } }; -void check(NotNull singletonTypes, DcrLogger* logger, const SourceModule& sourceModule, Module* module) +void check(NotNull builtinTypes, DcrLogger* logger, const SourceModule& sourceModule, Module* module) { - TypeChecker2 typeChecker{singletonTypes, logger, &sourceModule, module}; + TypeChecker2 typeChecker{builtinTypes, logger, &sourceModule, module}; typeChecker.visit(sourceModule.root); + + unfreeze(module->interfaceTypes); + copyErrors(module->errors, module->interfaceTypes); + freeze(module->interfaceTypes); } } // namespace Luau diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index aa738ad96..f31ea9381 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -18,8 +18,8 @@ #include "Luau/ToString.h" #include "Luau/TypePack.h" #include "Luau/TypeUtils.h" -#include "Luau/TypeVar.h" -#include "Luau/VisitTypeVar.h" +#include "Luau/Type.h" +#include "Luau/VisitType.h" #include #include @@ -50,6 +50,7 @@ LUAU_FASTFLAGVARIABLE(LuauReportShadowedTypeAlias, false) LUAU_FASTFLAGVARIABLE(LuauBetterMessagingOnCountMismatch, false) LUAU_FASTFLAGVARIABLE(LuauIntersectionTestForEquality, false) LUAU_FASTFLAGVARIABLE(LuauImplicitElseRefinement, false) +LUAU_FASTFLAG(LuauNegatedClassTypes) LUAU_FASTFLAGVARIABLE(LuauAllowIndexClassParameters, false) LUAU_FASTFLAGVARIABLE(LuauDeclareClassPrototype, false) LUAU_FASTFLAG(LuauUninhabitedSubAnything2) @@ -60,7 +61,7 @@ namespace Luau static bool typeCouldHaveMetatable(TypeId ty) { - return get(follow(ty)) || get(follow(ty)) || get(follow(ty)); + return get(follow(ty)) || get(follow(ty)) || get(follow(ty)); } static void defaultLuauPrintLine(const std::string& s) @@ -222,23 +223,23 @@ size_t HashBoolNamePair::operator()(const std::pair& pair) const return std::hash()(pair.first) ^ std::hash()(pair.second); } -TypeChecker::TypeChecker(ModuleResolver* resolver, NotNull singletonTypes, InternalErrorReporter* iceHandler) +TypeChecker::TypeChecker(ModuleResolver* resolver, NotNull builtinTypes, InternalErrorReporter* iceHandler) : resolver(resolver) - , singletonTypes(singletonTypes) + , builtinTypes(builtinTypes) , iceHandler(iceHandler) , unifierState(iceHandler) - , normalizer(nullptr, singletonTypes, NotNull{&unifierState}) - , nilType(singletonTypes->nilType) - , numberType(singletonTypes->numberType) - , stringType(singletonTypes->stringType) - , booleanType(singletonTypes->booleanType) - , threadType(singletonTypes->threadType) - , anyType(singletonTypes->anyType) - , unknownType(singletonTypes->unknownType) - , neverType(singletonTypes->neverType) - , anyTypePack(singletonTypes->anyTypePack) - , neverTypePack(singletonTypes->neverTypePack) - , uninhabitableTypePack(singletonTypes->uninhabitableTypePack) + , normalizer(nullptr, builtinTypes, NotNull{&unifierState}) + , nilType(builtinTypes->nilType) + , numberType(builtinTypes->numberType) + , stringType(builtinTypes->stringType) + , booleanType(builtinTypes->booleanType) + , threadType(builtinTypes->threadType) + , anyType(builtinTypes->anyType) + , unknownType(builtinTypes->unknownType) + , neverType(builtinTypes->neverType) + , anyTypePack(builtinTypes->anyTypePack) + , neverTypePack(builtinTypes->neverTypePack) + , uninhabitableTypePack(builtinTypes->uninhabitableTypePack) , duplicateTypeAliases{{false, {}}} { globalScope = std::make_shared(globalTypes.addTypePack(TypePackVar{FreeTypePack{TypeLevel{}}})); @@ -338,7 +339,7 @@ ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mo normalizer.arena = nullptr; } - currentModule->clonePublicInterface(singletonTypes, *iceHandler); + currentModule->clonePublicInterface(builtinTypes, *iceHandler); // Clear unifier cache since it's keyed off internal types that get deallocated // This avoids fake cross-module cache hits and keeps cache size at bay when typechecking large module graphs. @@ -447,13 +448,13 @@ void TypeChecker::checkBlock(const ScopePtr& scope, const AstStatBlock& block) } } -struct InplaceDemoter : TypeVarOnceVisitor +struct InplaceDemoter : TypeOnceVisitor { TypeLevel newLevel; TypeArena* arena; InplaceDemoter(TypeLevel level, TypeArena* arena) - : TypeVarOnceVisitor(/* skipBoundTypes= */ true) + : TypeOnceVisitor(/* skipBoundTypes= */ true) , newLevel(level) , arena(arena) { @@ -670,9 +671,9 @@ LUAU_NOINLINE void TypeChecker::checkBlockTypeAliases(const ScopePtr& scope, std continue; TypeId type = bindings[name].type; - if (get(follow(type))) + if (get(follow(type))) { - TypeVar* mty = asMutable(follow(type)); + Type* mty = asMutable(follow(type)); mty->reassign(*errorRecoveryType(anyType)); reportError(TypeError{typealias->location, OccursCheckFailed{}}); @@ -785,7 +786,7 @@ struct Demoter : Substitution bool isDirty(TypeId ty) override { - return get(ty); + return get(ty); } bool isDirty(TypePackId tp) override @@ -795,7 +796,7 @@ struct Demoter : Substitution bool ignoreChildren(TypeId ty) override { - if (FFlag::LuauClassTypeVarsInSubstitution && get(ty)) + if (FFlag::LuauClassTypeVarsInSubstitution && get(ty)) return true; return false; @@ -803,9 +804,9 @@ struct Demoter : Substitution TypeId clean(TypeId ty) override { - auto ftv = get(ty); + auto ftv = get(ty); LUAU_ASSERT(ftv); - return addType(FreeTypeVar{demotedLevel(ftv->level)}); + return addType(FreeType{demotedLevel(ftv->level)}); } TypePackId clean(TypePackId tp) override @@ -987,14 +988,14 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatAssign& assign) } // Setting a table entry to nil doesn't mean nil is the type of the indexer, it is just deleting the entry - const TableTypeVar* destTableTypeReceivingNil = nullptr; + const TableType* destTableTypeReceivingNil = nullptr; if (auto indexExpr = dest->as(); isNil(right) && indexExpr) destTableTypeReceivingNil = getTableType(checkExpr(scope, *indexExpr->expr).type); if (!destTableTypeReceivingNil || !destTableTypeReceivingNil->indexer) { - // In nonstrict mode, any assignments where the lhs is free and rhs isn't a function, we give it any typevar. - if (isNonstrictMode() && get(follow(left)) && !get(follow(right))) + // In nonstrict mode, any assignments where the lhs is free and rhs isn't a function, we give it any type. + if (isNonstrictMode() && get(follow(left)) && !get(follow(right))) unify(anyType, left, scope, loc); else unify(right, left, scope, loc); @@ -1046,7 +1047,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatLocal& local) ty = resolveType(scope, *annotation); // If the annotation type has an error, treat it as if there was no annotation - if (get(follow(ty))) + if (get(follow(ty))) ty = nullptr; } @@ -1102,7 +1103,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatLocal& local) { if (rhs->is()) { - TableTypeVar* ttv = getMutable(follow(*ty)); + TableType* ttv = getMutable(follow(*ty)); if (ttv && !ttv->name && scope == currentModule->getModuleScope()) ttv->syntheticName = vars[0]->name.value; } @@ -1110,7 +1111,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatLocal& local) { if (const AstExprGlobal* global = call->func->as(); global && global->name == "setmetatable") { - MetatableTypeVar* mtv = getMutable(follow(*ty)); + MetatableType* mtv = getMutable(follow(*ty)); if (mtv) mtv->syntheticName = vars[0]->name.value; } @@ -1261,7 +1262,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) return check(loopScope, *forin.body); } - if (const TableTypeVar* iterTable = get(iterTy)) + if (const TableType* iterTable = get(iterTy)) { // TODO: note that this doesn't cleanly handle iteration over mixed tables and tables without an indexer // this behavior is more or less consistent with what we do for pairs(), but really both are pretty wrong and need revisiting @@ -1294,15 +1295,15 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) return check(loopScope, *forin.body); } - const FunctionTypeVar* iterFunc = get(iterTy); + const FunctionType* iterFunc = get(iterTy); if (!iterFunc) { - TypeId varTy = get(iterTy) ? anyType : errorRecoveryType(loopScope); + TypeId varTy = get(iterTy) ? anyType : errorRecoveryType(loopScope); for (TypeId var : varTypes) unify(varTy, var, scope, forin.location); - if (!get(iterTy) && !get(iterTy) && !get(iterTy) && !get(iterTy)) + if (!get(iterTy) && !get(iterTy) && !get(iterTy) && !get(iterTy)) reportError(firstValue->location, CannotCallNonFunction{iterTy}); return check(loopScope, *forin.body); @@ -1354,7 +1355,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) { TypeId keyTy = follow(*fty); - if (get(keyTy)) + if (get(keyTy)) { if (std::optional ty = tryStripUnionFromNil(keyTy)) keyTy = *ty; @@ -1432,7 +1433,7 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco else if (auto name = function.name->as()) { TypeId exprTy = checkExpr(scope, *name->expr).type; - TableTypeVar* ttv = getMutableTableType(exprTy); + TableType* ttv = getMutableTableType(exprTy); if (!getIndexTypeFromType(scope, exprTy, name->index.value, name->indexLocation, /* addErrors= */ false)) { @@ -1449,7 +1450,7 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco if (function.func->self) { - const FunctionTypeVar* funTy = get(ty); + const FunctionType* funTy = get(ty); if (!funTy) ice("Methods should be functions"); @@ -1520,7 +1521,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias for (auto param : binding->typeParams) { - auto generic = get(param.ty); + auto generic = get(param.ty); LUAU_ASSERT(generic); aliasScope->privateTypeBindings[generic->name] = TypeFun{{}, param.ty}; } @@ -1533,7 +1534,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias } TypeId ty = resolveType(aliasScope, *typealias.type); - if (auto ttv = getMutable(follow(ty))) + if (auto ttv = getMutable(follow(ty))) { // If the table is already named and we want to rename the type function, we have to bind new alias to a copy // Additionally, we can't modify types that come from other modules @@ -1552,7 +1553,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias if (!ttv->name || ttv->name != name || !sameTys || !sameTps) { // This is a shallow clone, original recursive links to self are not updated - TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, ttv->level, ttv->state}; + TableType clone = TableType{ttv->props, ttv->indexer, ttv->level, ttv->state}; clone.definitionModuleName = ttv->definitionModuleName; clone.name = name; @@ -1578,7 +1579,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias ttv->instantiatedTypePackParams.push_back(param.tp); } } - else if (auto mtv = getMutable(follow(ty))) + else if (auto mtv = getMutable(follow(ty))) { // We can't modify types that come from other modules if (follow(ty)->owningArena == ¤tModule->internalTypes) @@ -1634,7 +1635,7 @@ void TypeChecker::prototype(const ScopePtr& scope, const AstStatTypeAlias& typea createGenericTypes(aliasScope, scope->level, typealias, typealias.generics, typealias.genericPacks, /* useCache = */ true); TypeId ty = freshType(aliasScope); - FreeTypeVar* ftv = getMutable(ty); + FreeType* ftv = getMutable(ty); LUAU_ASSERT(ftv); ftv->forwardedTypeAlias = true; bindingsMap[name] = {std::move(generics), std::move(genericPacks), ty}; @@ -1652,7 +1653,7 @@ void TypeChecker::prototype(const ScopePtr& scope, const AstStatTypeAlias& typea createGenericTypes(aliasScope, scope->level, typealias, typealias.generics, typealias.genericPacks, /* useCache = */ true); TypeId ty = freshType(aliasScope); - FreeTypeVar* ftv = getMutable(ty); + FreeType* ftv = getMutable(ty); LUAU_ASSERT(ftv); ftv->forwardedTypeAlias = true; bindingsMap[name] = {std::move(generics), std::move(genericPacks), ty}; @@ -1664,8 +1665,7 @@ void TypeChecker::prototype(const ScopePtr& scope, const AstStatTypeAlias& typea void TypeChecker::prototype(const ScopePtr& scope, const AstStatDeclareClass& declaredClass) { LUAU_ASSERT(FFlag::LuauDeclareClassPrototype); - - std::optional superTy = std::nullopt; + std::optional superTy = FFlag::LuauNegatedClassTypes ? std::make_optional(builtinTypes->classType) : std::nullopt; if (declaredClass.superName) { Name superName = Name(declaredClass.superName->value); @@ -1682,7 +1682,7 @@ void TypeChecker::prototype(const ScopePtr& scope, const AstStatDeclareClass& de LUAU_ASSERT(lookupType->typeParams.size() == 0 && lookupType->typePackParams.size() == 0); superTy = lookupType->type; - if (!get(follow(*superTy))) + if (!get(follow(*superTy))) { reportError(declaredClass.location, GenericError{format("Cannot use non-class type '%s' as a superclass of class '%s'", superName.c_str(), declaredClass.name.value)}); @@ -1693,9 +1693,9 @@ void TypeChecker::prototype(const ScopePtr& scope, const AstStatDeclareClass& de Name className(declaredClass.name.value); - TypeId classTy = addType(ClassTypeVar(className, {}, superTy, std::nullopt, {}, {}, currentModuleName)); - ClassTypeVar* ctv = getMutable(classTy); - TypeId metaTy = addType(TableTypeVar{TableState::Sealed, scope->level}); + TypeId classTy = addType(ClassType(className, {}, superTy, std::nullopt, {}, {}, currentModuleName)); + ClassType* ctv = getMutable(classTy); + TypeId metaTy = addType(TableType{TableState::Sealed, scope->level}); ctv->metatable = metaTy; scope->exportedTypeBindings[className] = TypeFun{{}, classTy}; @@ -1720,25 +1720,25 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& declar ice("Class not predeclared"); TypeId classTy = binding->type; - ClassTypeVar* ctv = getMutable(classTy); + ClassType* ctv = getMutable(classTy); if (!ctv->metatable) ice("No metatable for declared class"); - TableTypeVar* metatable = getMutable(*ctv->metatable); + TableType* metatable = getMutable(*ctv->metatable); for (const AstDeclaredClassProp& prop : declaredClass.props) { Name propName(prop.name.value); TypeId propTy = resolveType(scope, *prop.ty); bool assignToMetatable = isMetamethod(propName); - Luau::ClassTypeVar::Props& assignTo = assignToMetatable ? metatable->props : ctv->props; + Luau::ClassType::Props& assignTo = assignToMetatable ? metatable->props : ctv->props; // Function types always take 'self', but this isn't reflected in the // parsed annotation. Add it here. if (prop.isMethod) { - if (FunctionTypeVar* ftv = getMutable(propTy)) + if (FunctionType* ftv = getMutable(propTy)) { ftv->argNames.insert(ftv->argNames.begin(), FunctionArgument{"self", {}}); ftv->argTypes = addTypePack(TypePack{{classTy}, ftv->argTypes}); @@ -1756,17 +1756,17 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& declar // We special-case this logic to keep the intersection flat; otherwise we // would create a ton of nested intersection types. - if (const IntersectionTypeVar* itv = get(currentTy)) + if (const IntersectionType* itv = get(currentTy)) { std::vector options = itv->parts; options.push_back(propTy); - TypeId newItv = addType(IntersectionTypeVar{std::move(options)}); + TypeId newItv = addType(IntersectionType{std::move(options)}); assignTo[propName] = {newItv}; } - else if (get(currentTy)) + else if (get(currentTy)) { - TypeId intersection = addType(IntersectionTypeVar{{currentTy, propTy}}); + TypeId intersection = addType(IntersectionType{{currentTy, propTy}}); assignTo[propName] = {intersection}; } @@ -1779,7 +1779,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& declar } else { - std::optional superTy = std::nullopt; + std::optional superTy = FFlag::LuauNegatedClassTypes ? std::make_optional(builtinTypes->classType) : std::nullopt; if (declaredClass.superName) { Name superName = Name(declaredClass.superName->value); @@ -1795,7 +1795,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& declar LUAU_ASSERT(lookupType->typeParams.size() == 0 && lookupType->typePackParams.size() == 0); superTy = lookupType->type; - if (!get(follow(*superTy))) + if (!get(follow(*superTy))) { reportError(declaredClass.location, GenericError{format("Cannot use non-class type '%s' as a superclass of class '%s'", superName.c_str(), declaredClass.name.value)}); @@ -1805,11 +1805,11 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& declar Name className(declaredClass.name.value); - TypeId classTy = addType(ClassTypeVar(className, {}, superTy, std::nullopt, {}, {}, currentModuleName)); + TypeId classTy = addType(ClassType(className, {}, superTy, std::nullopt, {}, {}, currentModuleName)); - ClassTypeVar* ctv = getMutable(classTy); - TypeId metaTy = addType(TableTypeVar{TableState::Sealed, scope->level}); - TableTypeVar* metatable = getMutable(metaTy); + ClassType* ctv = getMutable(classTy); + TypeId metaTy = addType(TableType{TableState::Sealed, scope->level}); + TableType* metatable = getMutable(metaTy); ctv->metatable = metaTy; @@ -1821,13 +1821,13 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& declar TypeId propTy = resolveType(scope, *prop.ty); bool assignToMetatable = isMetamethod(propName); - Luau::ClassTypeVar::Props& assignTo = assignToMetatable ? metatable->props : ctv->props; + Luau::ClassType::Props& assignTo = assignToMetatable ? metatable->props : ctv->props; // Function types always take 'self', but this isn't reflected in the // parsed annotation. Add it here. if (prop.isMethod) { - if (FunctionTypeVar* ftv = getMutable(propTy)) + if (FunctionType* ftv = getMutable(propTy)) { ftv->argNames.insert(ftv->argNames.begin(), FunctionArgument{"self", {}}); ftv->argTypes = addTypePack(TypePack{{classTy}, ftv->argTypes}); @@ -1845,17 +1845,17 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& declar // We special-case this logic to keep the intersection flat; otherwise we // would create a ton of nested intersection types. - if (const IntersectionTypeVar* itv = get(currentTy)) + if (const IntersectionType* itv = get(currentTy)) { std::vector options = itv->parts; options.push_back(propTy); - TypeId newItv = addType(IntersectionTypeVar{std::move(options)}); + TypeId newItv = addType(IntersectionType{std::move(options)}); assignTo[propName] = {newItv}; } - else if (get(currentTy)) + else if (get(currentTy)) { - TypeId intersection = addType(IntersectionTypeVar{{currentTy, propTy}}); + TypeId intersection = addType(IntersectionType{{currentTy, propTy}}); assignTo[propName] = {intersection}; } @@ -1888,8 +1888,8 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareFunction& glo TypePackId argPack = resolveTypePack(funScope, global.params); TypePackId retPack = resolveTypePack(funScope, global.retTypes); - TypeId fnType = addType(FunctionTypeVar{funScope->level, std::move(genericTys), std::move(genericTps), argPack, retPack}); - FunctionTypeVar* ftv = getMutable(fnType); + TypeId fnType = addType(FunctionType{funScope->level, std::move(genericTys), std::move(genericTps), argPack, retPack}); + FunctionType* ftv = getMutable(fnType); ftv->argNames.reserve(global.paramNames.size); for (const auto& el : global.paramNames) @@ -2018,7 +2018,7 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp *asMutable(varargPack) = TypePack{{head}, tail}; return {head}; } - if (get(varargPack)) + if (get(varargPack)) return {errorRecoveryType(scope)}; else if (auto vtp = get(varargPack)) return {vtp->ty}; @@ -2085,7 +2085,7 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp std::optional TypeChecker::findTablePropertyRespectingMeta(TypeId lhsType, Name name, const Location& location, bool addErrors) { ErrorVec errors; - auto result = Luau::findTablePropertyRespectingMeta(singletonTypes, errors, lhsType, name, location); + auto result = Luau::findTablePropertyRespectingMeta(builtinTypes, errors, lhsType, name, location); if (addErrors) reportErrors(errors); return result; @@ -2094,7 +2094,7 @@ std::optional TypeChecker::findTablePropertyRespectingMeta(TypeId lhsTyp std::optional TypeChecker::findMetatableEntry(TypeId type, std::string entry, const Location& location, bool addErrors) { ErrorVec errors; - auto result = Luau::findMetatableEntry(singletonTypes, errors, type, entry, location); + auto result = Luau::findMetatableEntry(builtinTypes, errors, type, entry, location); if (addErrors) reportErrors(errors); return result; @@ -2118,7 +2118,7 @@ std::optional TypeChecker::getIndexTypeFromTypeImpl( { type = follow(type); - if (get(type) || get(type) || get(type)) + if (get(type) || get(type) || get(type)) return type; tablify(type); @@ -2130,7 +2130,7 @@ std::optional TypeChecker::getIndexTypeFromTypeImpl( type = *mtIndex; } - if (TableTypeVar* tableType = getMutableTableType(type)) + if (TableType* tableType = getMutableTableType(type)) { if (auto it = tableType->props.find(name); it != tableType->props.end()) return it->second.type; @@ -2157,13 +2157,13 @@ std::optional TypeChecker::getIndexTypeFromTypeImpl( if (auto found = findTablePropertyRespectingMeta(type, name, location, addErrors)) return *found; } - else if (const ClassTypeVar* cls = get(type)) + else if (const ClassType* cls = get(type)) { const Property* prop = lookupClassProp(cls, name); if (prop) return prop->type; } - else if (const UnionTypeVar* utv = get(type)) + else if (const UnionType* utv = get(type)) { std::vector goodOptions; std::vector badOptions; @@ -2173,7 +2173,7 @@ std::optional TypeChecker::getIndexTypeFromTypeImpl( RecursionLimiter _rl(&recursionCount, FInt::LuauTypeInferRecursionLimit); // Not needed when we normalize types. - if (get(follow(t))) + if (get(follow(t))) return t; if (std::optional ty = getIndexTypeFromType(scope, t, name, location, /* addErrors= */ false)) @@ -2201,9 +2201,9 @@ std::optional TypeChecker::getIndexTypeFromTypeImpl( if (result.size() == 1) return result[0]; - return addType(UnionTypeVar{std::move(result)}); + return addType(UnionType{std::move(result)}); } - else if (const IntersectionTypeVar* itv = get(type)) + else if (const IntersectionType* itv = get(type)) { std::vector parts; @@ -2226,7 +2226,7 @@ std::optional TypeChecker::getIndexTypeFromTypeImpl( if (parts.size() == 1) return parts[0]; - return addType(IntersectionTypeVar{std::move(parts)}); // Not at all correct. + return addType(IntersectionType{std::move(parts)}); // Not at all correct. } if (addErrors) @@ -2237,7 +2237,7 @@ std::optional TypeChecker::getIndexTypeFromTypeImpl( std::optional TypeChecker::tryStripUnionFromNil(TypeId ty) { - if (const UnionTypeVar* utv = get(ty)) + if (const UnionType* utv = get(ty)) { if (!std::any_of(begin(utv), end(utv), isNil)) return ty; @@ -2253,7 +2253,7 @@ std::optional TypeChecker::tryStripUnionFromNil(TypeId ty) if (result.empty()) return std::nullopt; - return result.size() == 1 ? result[0] : addType(UnionTypeVar{std::move(result)}); + return result.size() == 1 ? result[0] : addType(UnionType{std::move(result)}); } return std::nullopt; @@ -2263,7 +2263,7 @@ TypeId TypeChecker::stripFromNilAndReport(TypeId ty, const Location& location) { ty = follow(ty); - if (auto utv = get(ty)) + if (auto utv = get(ty)) { if (!std::any_of(begin(utv), end(utv), isNil)) return ty; @@ -2301,14 +2301,14 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp TypeId TypeChecker::checkExprTable( const ScopePtr& scope, const AstExprTable& expr, const std::vector>& fieldTypes, std::optional expectedType) { - TableTypeVar::Props props; + TableType::Props props; std::optional indexer; - const TableTypeVar* expectedTable = nullptr; + const TableType* expectedTable = nullptr; if (expectedType) { - if (auto ttv = get(follow(*expectedType))) + if (auto ttv = get(follow(*expectedType))) { if (ttv->state == TableState::Sealed) expectedTable = ttv; @@ -2342,7 +2342,7 @@ TypeId TypeChecker::checkExprTable( if (auto key = k->as()) { TypeId exprType = follow(valueType); - if (isNonstrictMode() && !getTableType(exprType) && !get(exprType)) + if (isNonstrictMode() && !getTableType(exprType) && !get(exprType)) exprType = anyType; if (expectedTable) @@ -2388,7 +2388,7 @@ TypeId TypeChecker::checkExprTable( } TableState state = TableState::Unsealed; - TableTypeVar table = TableTypeVar{std::move(props), indexer, scope->level, state}; + TableType table = TableType{std::move(props), indexer, scope->level, state}; table.definitionModuleName = currentModuleName; return addType(table); } @@ -2404,14 +2404,14 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp std::vector> fieldTypes(expr.items.size); - const TableTypeVar* expectedTable = nullptr; - const UnionTypeVar* expectedUnion = nullptr; + const TableType* expectedTable = nullptr; + const UnionType* expectedUnion = nullptr; std::optional expectedIndexType; std::optional expectedIndexResultType; if (expectedType) { - if (auto ttv = get(follow(*expectedType))) + if (auto ttv = get(follow(*expectedType))) { if (ttv->state == TableState::Sealed) { @@ -2424,7 +2424,7 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp } } } - else if (const UnionTypeVar* utv = get(follow(*expectedType))) + else if (const UnionType* utv = get(follow(*expectedType))) expectedUnion = utv; } @@ -2455,7 +2455,7 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp std::vector expectedResultTypes; for (TypeId expectedOption : expectedUnion) { - if (const TableTypeVar* ttv = get(follow(expectedOption))) + if (const TableType* ttv = get(follow(expectedOption))) { if (auto prop = ttv->props.find(key->value.data); prop != ttv->props.end()) expectedResultTypes.push_back(prop->second.type); @@ -2467,7 +2467,7 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp if (expectedResultTypes.size() == 1) expectedResultType = expectedResultTypes[0]; else if (expectedResultTypes.size() > 1) - expectedResultType = addType(UnionTypeVar{expectedResultTypes}); + expectedResultType = addType(UnionType{expectedResultTypes}); } } else @@ -2500,7 +2500,7 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp return {booleanType, {NotPredicate{std::move(result.predicates)}}}; case AstExprUnary::Minus: { - const bool operandIsAny = get(operandType) || get(operandType) || get(operandType); + const bool operandIsAny = get(operandType) || get(operandType) || get(operandType); if (operandIsAny) return {operandType}; @@ -2512,7 +2512,7 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp TypeId actualFunctionType = instantiate(scope, *fnt, expr.location); TypePackId arguments = addTypePack({operandType}); TypePackId retTypePack = freshTypePack(scope); - TypeId expectedFunctionType = addType(FunctionTypeVar(scope->level, arguments, retTypePack)); + TypeId expectedFunctionType = addType(FunctionType(scope->level, arguments, retTypePack)); Unifier state = mkUnifier(scope, expr.location); state.tryUnify(actualFunctionType, expectedFunctionType, /*isFunctionCall*/ true); @@ -2542,8 +2542,7 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp operandType = stripFromNilAndReport(operandType, expr.location); // # operator is guaranteed to return number - if ((FFlag::LuauNeverTypesAndOperatorsInference && get(operandType)) || get(operandType) || - get(operandType)) + if ((FFlag::LuauNeverTypesAndOperatorsInference && get(operandType)) || get(operandType) || get(operandType)) { if (FFlag::LuauNeverTypesAndOperatorsInference) return {numberType}; @@ -2560,7 +2559,7 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp TypeId actualFunctionType = instantiate(scope, *fnt, expr.location); TypePackId arguments = addTypePack({operandType}); TypePackId retTypePack = addTypePack({numberType}); - TypeId expectedFunctionType = addType(FunctionTypeVar(scope->level, arguments, retTypePack)); + TypeId expectedFunctionType = addType(FunctionType(scope->level, arguments, retTypePack)); Unifier state = mkUnifier(scope, expr.location); state.tryUnify(actualFunctionType, expectedFunctionType, /*isFunctionCall*/ true); @@ -2617,7 +2616,7 @@ TypeId TypeChecker::unionOfTypes(TypeId a, TypeId b, const ScopePtr& scope, cons a = follow(a); b = follow(b); - if (unifyFreeTypes && (get(a) || get(b))) + if (unifyFreeTypes && (get(a) || get(b))) { if (unify(b, a, scope, location)) return a; @@ -2635,7 +2634,7 @@ TypeId TypeChecker::unionOfTypes(TypeId a, TypeId b, const ScopePtr& scope, cons if (types.size() == 1) return types[0]; - return addType(UnionTypeVar{types}); + return addType(UnionType{types}); } static std::optional getIdentifierOfBaseVar(AstExpr* node) @@ -2663,7 +2662,7 @@ static std::optional getIdentifierOfBaseVar(AstExpr* node) * types have any overlap at all. * * In order to make things work smoothly with the greedy solver, this function - * exempts any and FreeTypeVars from this requirement. + * exempts any and FreeTypes from this requirement. * * This function does not (yet?) take into account extra Lua restrictions like * that two tables can only be compared if they have the same metatable. That @@ -2680,13 +2679,13 @@ static std::optional areEqComparable(NotNull arena, NotNull(t); + return isNil(t) || get(t); }; if (isExempt(a) || isExempt(b)) return true; - TypeId c = arena->addType(IntersectionTypeVar{{a, b}}); + TypeId c = arena->addType(IntersectionType{{a, b}}); const NormalizedType* n = normalizer->normalize(c); if (!n) return std::nullopt; @@ -2705,7 +2704,7 @@ TypeId TypeChecker::checkRelationalOperation( if (!isNonstrictMode() && !isOrOp) return ty; - if (get(ty)) + if (get(ty)) { std::optional cleaned = tryStripUnionFromNil(ty); @@ -2726,7 +2725,7 @@ TypeId TypeChecker::checkRelationalOperation( // If we know nothing at all about the lhs type, we can usually say nothing about the result. // The notable exception to this is the equality and inequality operators, which always produce a boolean. - const bool lhsIsAny = get(lhsType) || get(lhsType) || get(lhsType); + const bool lhsIsAny = get(lhsType) || get(lhsType) || get(lhsType); // Peephole check for `cond and a or b -> type(a)|type(b)` // TODO: Kill this when singleton types arrive. :( @@ -2749,7 +2748,7 @@ TypeId TypeChecker::checkRelationalOperation( if (isNonstrictMode() && (isNil(lhsType) || isNil(rhsType))) return booleanType; - const bool rhsIsAny = get(rhsType) || get(rhsType) || get(rhsType); + const bool rhsIsAny = get(rhsType) || get(rhsType) || get(rhsType); if (lhsIsAny || rhsIsAny) return booleanType; @@ -2763,7 +2762,7 @@ TypeId TypeChecker::checkRelationalOperation( if (FFlag::LuauNeverTypesAndOperatorsInference) { // If one of the operand is never, it doesn't make sense to unify these. - if (get(lhsType) || get(rhsType)) + if (get(lhsType) || get(rhsType)) return booleanType; } @@ -2804,11 +2803,11 @@ TypeId TypeChecker::checkRelationalOperation( const bool needsMetamethod = !isEquality; TypeId leftType = follow(lhsType); - if (get(leftType) || get(leftType) || get(leftType) || get(leftType)) + if (get(leftType) || get(leftType) || get(leftType) || get(leftType)) { reportErrors(state.errors); - if (!isEquality && state.errors.empty() && (get(leftType) || isBoolean(leftType))) + if (!isEquality && state.errors.empty() && (get(leftType) || isBoolean(leftType))) { reportError(expr.location, GenericError{format("Type '%s' cannot be compared with relational operator %s", toString(leftType).c_str(), toString(expr.op).c_str())}); @@ -2819,19 +2818,19 @@ TypeId TypeChecker::checkRelationalOperation( std::string metamethodName = opToMetaTableEntry(expr.op); - std::optional leftMetatable = isString(lhsType) ? std::nullopt : getMetatable(follow(lhsType), singletonTypes); - std::optional rightMetatable = isString(rhsType) ? std::nullopt : getMetatable(follow(rhsType), singletonTypes); + std::optional leftMetatable = isString(lhsType) ? std::nullopt : getMetatable(follow(lhsType), builtinTypes); + std::optional rightMetatable = isString(rhsType) ? std::nullopt : getMetatable(follow(rhsType), builtinTypes); if (leftMetatable != rightMetatable) { bool matches = false; if (isEquality) { - if (const UnionTypeVar* utv = get(leftType); utv && rightMetatable) + if (const UnionType* utv = get(leftType); utv && rightMetatable) { for (TypeId leftOption : utv) { - if (getMetatable(follow(leftOption), singletonTypes) == rightMetatable) + if (getMetatable(follow(leftOption), builtinTypes) == rightMetatable) { matches = true; break; @@ -2841,11 +2840,11 @@ TypeId TypeChecker::checkRelationalOperation( if (!matches) { - if (const UnionTypeVar* utv = get(rhsType); utv && leftMetatable) + if (const UnionType* utv = get(rhsType); utv && leftMetatable) { for (TypeId rightOption : utv) { - if (getMetatable(follow(rightOption), singletonTypes) == leftMetatable) + if (getMetatable(follow(rightOption), builtinTypes) == leftMetatable) { matches = true; break; @@ -2869,7 +2868,7 @@ TypeId TypeChecker::checkRelationalOperation( std::optional metamethod = findMetatableEntry(lhsType, metamethodName, expr.location, /* addErrors= */ true); if (metamethod) { - if (const FunctionTypeVar* ftv = get(*metamethod)) + if (const FunctionType* ftv = get(*metamethod)) { if (isEquality) { @@ -2888,7 +2887,7 @@ TypeId TypeChecker::checkRelationalOperation( reportErrors(state.errors); - TypeId actualFunctionType = addType(FunctionTypeVar(scope->level, addTypePack({lhsType, rhsType}), addTypePack({booleanType}))); + TypeId actualFunctionType = addType(FunctionType(scope->level, addTypePack({lhsType, rhsType}), addTypePack({booleanType}))); state.tryUnify( instantiate(scope, actualFunctionType, expr.location), instantiate(scope, *metamethod, expr.location), /*isFunctionCall*/ true); @@ -2905,7 +2904,7 @@ TypeId TypeChecker::checkRelationalOperation( } } - if (get(follow(lhsType)) && !isEquality) + if (get(follow(lhsType)) && !isEquality) { auto name = getIdentifierOfBaseVar(expr.left); reportError(expr.location, CannotInferBinaryOperation{expr.op, name, CannotInferBinaryOperation::Comparison}); @@ -2930,8 +2929,8 @@ TypeId TypeChecker::checkRelationalOperation( else if (FFlag::LuauTryhardAnd) { // If lhs is free, we can't tell which 'falsy' components it has, if any - if (get(lhsType)) - return unionOfTypes(addType(UnionTypeVar{{nilType, singletonType(false)}}), rhsType, scope, expr.location, false); + if (get(lhsType)) + return unionOfTypes(addType(UnionType{{nilType, singletonType(false)}}), rhsType, scope, expr.location, false); auto [oty, notNever] = pickTypesFromSense(lhsType, false, neverType); // Filter out falsy types @@ -2999,7 +2998,7 @@ TypeId TypeChecker::checkBinaryOperation( lhsType = follow(lhsType); rhsType = follow(rhsType); - if (!isNonstrictMode() && get(lhsType)) + if (!isNonstrictMode() && get(lhsType)) { auto name = getIdentifierOfBaseVar(expr.left); reportError(expr.location, CannotInferBinaryOperation{expr.op, name, CannotInferBinaryOperation::Operation}); @@ -3008,17 +3007,17 @@ TypeId TypeChecker::checkBinaryOperation( // If we know nothing at all about the lhs type, we can usually say nothing about the result. // The notable exception to this is the equality and inequality operators, which always produce a boolean. - const bool lhsIsAny = get(lhsType) || get(lhsType) || - (FFlag::LuauUnknownAndNeverType && FFlag::LuauNeverTypesAndOperatorsInference && get(lhsType)); - const bool rhsIsAny = get(rhsType) || get(rhsType) || - (FFlag::LuauUnknownAndNeverType && FFlag::LuauNeverTypesAndOperatorsInference && get(rhsType)); + const bool lhsIsAny = get(lhsType) || get(lhsType) || + (FFlag::LuauUnknownAndNeverType && FFlag::LuauNeverTypesAndOperatorsInference && get(lhsType)); + const bool rhsIsAny = get(rhsType) || get(rhsType) || + (FFlag::LuauUnknownAndNeverType && FFlag::LuauNeverTypesAndOperatorsInference && get(rhsType)); if (lhsIsAny) return lhsType; if (rhsIsAny) return rhsType; - if (get(lhsType)) + if (get(lhsType)) { // Inferring this accurately will get a bit weird. // If the lhs type is not known, it could be assumed that it is a table or class that has a metatable @@ -3027,7 +3026,7 @@ TypeId TypeChecker::checkBinaryOperation( return anyType; } - if (get(rhsType)) + if (get(rhsType)) unify(rhsType, lhsType, scope, expr.location); if (typeCouldHaveMetatable(lhsType) || typeCouldHaveMetatable(rhsType)) @@ -3036,7 +3035,7 @@ TypeId TypeChecker::checkBinaryOperation( TypeId actualFunctionType = instantiate(scope, fnt, expr.location); TypePackId arguments = addTypePack({lhst, rhst}); TypePackId retTypePack = freshTypePack(scope); - TypeId expectedFunctionType = addType(FunctionTypeVar(scope->level, arguments, retTypePack)); + TypeId expectedFunctionType = addType(FunctionType(scope->level, arguments, retTypePack)); Unifier state = mkUnifier(scope, expr.location); state.tryUnify(actualFunctionType, expectedFunctionType, /*isFunctionCall*/ true); @@ -3049,7 +3048,7 @@ TypeId TypeChecker::checkBinaryOperation( // If there are unification errors, the return type may still be unknown // so we loosen the argument types to see if that helps. TypePackId fallbackArguments = freshTypePack(scope); - TypeId fallbackFunctionType = addType(FunctionTypeVar(scope->level, fallbackArguments, retTypePack)); + TypeId fallbackFunctionType = addType(FunctionType(scope->level, fallbackArguments, retTypePack)); state.errors.clear(); state.log.clear(); @@ -3088,8 +3087,8 @@ TypeId TypeChecker::checkBinaryOperation( switch (expr.op) { case AstExprBinary::Concat: - reportErrors(tryUnify(lhsType, addType(UnionTypeVar{{stringType, numberType}}), scope, expr.left->location)); - reportErrors(tryUnify(rhsType, addType(UnionTypeVar{{stringType, numberType}}), scope, expr.right->location)); + reportErrors(tryUnify(lhsType, addType(UnionType{{stringType, numberType}}), scope, expr.left->location)); + reportErrors(tryUnify(rhsType, addType(UnionType{{stringType, numberType}}), scope, expr.right->location)); return stringType; case AstExprBinary::Add: case AstExprBinary::Sub: @@ -3215,7 +3214,7 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp std::vector types = reduceUnion({trueType.type, falseType.type}); if (FFlag::LuauUnknownAndNeverType && types.empty()) return {neverType}; - return {types.size() == 1 ? types[0] : addType(UnionTypeVar{std::move(types)})}; + return {types.size() == 1 ? types[0] : addType(UnionType{std::move(types)})}; } WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprInterpString& expr) @@ -3256,7 +3255,7 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprLocal if (std::optional ty = scope->lookup(expr.local)) { ty = follow(*ty); - return get(*ty) ? unknownType : *ty; + return get(*ty) ? unknownType : *ty; } reportError(expr.location, UnknownSymbol{expr.local->name.value, UnknownSymbol::Binding}); @@ -3273,7 +3272,7 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprGloba if (it != moduleScope->bindings.end()) { TypeId ty = follow(it->second.typeId); - return get(ty) ? unknownType : ty; + return get(ty) ? unknownType : ty; } TypeId result = freshType(scope); @@ -3292,10 +3291,10 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex { TypeId lhs = checkExpr(scope, *expr.expr).type; - if (get(lhs) || get(lhs)) + if (get(lhs) || get(lhs)) return lhs; - if (get(lhs)) + if (get(lhs)) return unknownType; tablify(lhs); @@ -3304,7 +3303,7 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex lhs = stripFromNilAndReport(lhs, expr.expr->location); - if (TableTypeVar* lhsTable = getMutableTableType(lhs)) + if (TableType* lhsTable = getMutableTableType(lhs)) { const auto& it = lhsTable->props.find(name); if (it != lhsTable->props.end()) @@ -3346,7 +3345,7 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex return errorRecoveryType(scope); } } - else if (const ClassTypeVar* lhsClass = get(lhs)) + else if (const ClassType* lhsClass = get(lhs)) { const Property* prop = lookupClassProp(lhsClass, name); if (!prop) @@ -3357,7 +3356,7 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex return prop->type; } - else if (get(lhs)) + else if (get(lhs)) { if (std::optional ty = getIndexTypeFromType(scope, lhs, name, expr.location, /* addErrors= */ false)) return *ty; @@ -3386,17 +3385,17 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex if (FFlag::LuauFollowInLvalueIndexCheck) exprType = follow(exprType); - if (get(exprType) || get(exprType)) + if (get(exprType) || get(exprType)) return exprType; - if (get(exprType)) + if (get(exprType)) return unknownType; AstExprConstantString* value = expr.index->as(); if (value) { - if (const ClassTypeVar* exprClass = get(exprType)) + if (const ClassType* exprClass = get(exprType)) { const Property* prop = lookupClassProp(exprClass, value->value.data); if (!prop) @@ -3409,7 +3408,7 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex } else if (FFlag::LuauAllowIndexClassParameters) { - if (const ClassTypeVar* exprClass = get(exprType)) + if (const ClassType* exprClass = get(exprType)) { if (isNonstrictMode()) return unknownType; @@ -3418,7 +3417,7 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex } } - TableTypeVar* exprTable = getMutableTableType(exprType); + TableType* exprTable = getMutableTableType(exprType); if (!exprTable) { @@ -3504,7 +3503,7 @@ TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName, T else if (auto indexName = funName.as()) { TypeId lhsType = checkExpr(scope, *indexName->expr).type; - TableTypeVar* ttv = getMutableTableType(lhsType); + TableType* ttv = getMutableTableType(lhsType); if (!ttv || ttv->state == TableState::Sealed) { @@ -3549,22 +3548,22 @@ std::pair TypeChecker::checkFunctionSignature(const ScopePtr& { ScopePtr funScope = childFunctionScope(scope, expr.location, subLevel); - const FunctionTypeVar* expectedFunctionType = nullptr; + const FunctionType* expectedFunctionType = nullptr; if (expectedType) { LUAU_ASSERT(!expr.self); - if (auto ftv = get(follow(*expectedType))) + if (auto ftv = get(follow(*expectedType))) { expectedFunctionType = ftv; } - else if (auto utv = get(follow(*expectedType))) + else if (auto utv = get(follow(*expectedType))) { // Look for function type in a union. Other types can be ignored since current expression is a function for (auto option : utv) { - if (auto ftv = get(follow(option))) + if (auto ftv = get(follow(option))) { if (!expectedFunctionType) { @@ -3677,7 +3676,7 @@ std::pair TypeChecker::checkFunctionSignature(const ScopePtr& argType = resolveType(funScope, *local->annotation); // If the annotation type has an error, treat it as if there was no annotation - if (get(follow(argType))) + if (get(follow(argType))) argType = anyIfNonstrict(freshType(funScope)); } else @@ -3741,9 +3740,9 @@ std::pair TypeChecker::checkFunctionSignature(const ScopePtr& } TypeId funTy = - addType(FunctionTypeVar(funScope->level, std::move(genericTys), std::move(genericTps), argPack, retPack, std::move(defn), bool(expr.self))); + addType(FunctionType(funScope->level, std::move(genericTys), std::move(genericTps), argPack, retPack, std::move(defn), bool(expr.self))); - FunctionTypeVar* ftv = getMutable(funTy); + FunctionType* ftv = getMutable(funTy); ftv->argNames.reserve(expr.args.size + (expr.self ? 1 : 0)); @@ -3760,7 +3759,7 @@ static bool allowsNoReturnValues(const TypePackId tp) { for (TypeId ty : tp) { - if (!get(follow(ty))) + if (!get(follow(ty))) { return false; } @@ -3791,7 +3790,7 @@ void TypeChecker::checkFunctionBody(const ScopePtr& scope, TypeId ty, const AstE else LUAU_TIMETRACE_ARGUMENT("line", std::to_string(function.location.begin.line).c_str()); - if (FunctionTypeVar* funTy = getMutable(ty)) + if (FunctionType* funTy = getMutable(ty)) { check(scope, *function.body); @@ -3970,7 +3969,7 @@ void TypeChecker::checkArgumentList(const ScopePtr& scope, const AstExpr& funNam if (isOptional(t)) { } // ok - else if (state.log.getMutable(t)) + else if (state.log.getMutable(t)) { } // ok else @@ -4117,11 +4116,11 @@ WithPredicate TypeChecker::checkExprPackHelper(const ScopePtr& scope } TypePackId retPack; - if (auto free = get(actualFunctionType)) + if (auto free = get(actualFunctionType)) { retPack = freshTypePack(free->level); TypePackId freshArgPack = freshTypePack(free->level); - asMutable(actualFunctionType)->ty.emplace(free->level, freshArgPack, retPack); + asMutable(actualFunctionType)->ty.emplace(free->level, freshArgPack, retPack); } else retPack = freshTypePack(scope->level); @@ -4188,11 +4187,11 @@ WithPredicate TypeChecker::checkExprPackHelper(const ScopePtr& scope reportOverloadResolutionError(scope, expr, retPack, argPack, argLocations, overloads, overloadsThatMatchArgCount, errors); - const FunctionTypeVar* overload = nullptr; + const FunctionType* overload = nullptr; if (!overloadsThatMatchArgCount.empty()) - overload = get(overloadsThatMatchArgCount[0]); + overload = get(overloadsThatMatchArgCount[0]); if (!overload && !overloadsThatDont.empty()) - overload = get(overloadsThatDont[0]); + overload = get(overloadsThatDont[0]); if (overload) return {errorRecoveryTypePack(overload->retTypes)}; @@ -4222,14 +4221,14 @@ std::vector> TypeChecker::getExpectedTypesForCall(const st if (FFlag::LuauUnknownAndNeverType && result.empty()) el = neverType; else - el = result.size() == 1 ? result[0] : addType(UnionTypeVar{std::move(result)}); + el = result.size() == 1 ? result[0] : addType(UnionType{std::move(result)}); } } }; for (const TypeId overload : overloads) { - if (const FunctionTypeVar* ftv = get(overload)) + if (const FunctionType* ftv = get(overload)) { auto [argsHead, argsTail] = flatten(ftv->argTypes); @@ -4265,26 +4264,26 @@ std::optional> TypeChecker::checkCallOverload(const Sc fn = stripFromNilAndReport(fn, expr.func->location); - if (get(fn)) + if (get(fn)) { unify(anyTypePack, argPack, scope, expr.location); return {{anyTypePack}}; } - if (get(fn)) + if (get(fn)) { return {{errorRecoveryTypePack(scope)}}; } - if (get(fn)) + if (get(fn)) return {{uninhabitableTypePack}}; - if (auto ftv = get(fn)) + if (auto ftv = get(fn)) { // fn is one of the overloads of actualFunctionType, which // has been instantiated, so is a monotype. We can therefore // unify it with a monomorphic function. - TypeId r = addType(FunctionTypeVar(scope->level, argPack, retPack)); + TypeId r = addType(FunctionType(scope->level, argPack, retPack)); UnifierOptions options; options.isFunctionCall = true; @@ -4297,11 +4296,11 @@ std::optional> TypeChecker::checkCallOverload(const Sc // Might be a callable table or class std::optional callTy = std::nullopt; - if (const MetatableTypeVar* mttv = get(fn)) + if (const MetatableType* mttv = get(fn)) { callTy = getIndexTypeFromType(scope, mttv->metatable, "__call", expr.func->location, /* addErrors= */ false); } - else if (const ClassTypeVar* ctv = get(fn); FFlag::LuauCallableClasses && ctv && ctv->metatable) + else if (const ClassType* ctv = get(fn); FFlag::LuauCallableClasses && ctv && ctv->metatable) { callTy = getIndexTypeFromType(scope, *ctv->metatable, "__call", expr.func->location, /* addErrors= */ false); } @@ -4324,7 +4323,7 @@ std::optional> TypeChecker::checkCallOverload(const Sc argLocations = &metaArgLocations; } - const FunctionTypeVar* ftv = get(fn); + const FunctionType* ftv = get(fn); if (!ftv) { reportError(TypeError{expr.func->location, CannotCallNonFunction{fn}}); @@ -4481,7 +4480,7 @@ void TypeChecker::reportOverloadResolutionError(const ScopePtr& scope, const Ast // Remove the overload we are reporting errors about, from the list of alternative overloadTypes.erase(std::remove(overloadTypes.begin(), overloadTypes.end(), overload), overloadTypes.end()); - const FunctionTypeVar* ftv = get(overload); + const FunctionType* ftv = get(overload); auto error = std::find_if(errors.begin(), errors.end(), [ftv](const OverloadErrorEntry& e) { return ftv == std::get<2>(e); @@ -4502,7 +4501,7 @@ void TypeChecker::reportOverloadResolutionError(const ScopePtr& scope, const Ast Unifier state = mkUnifier(scope, expr.location); // Unify return types - if (const FunctionTypeVar* ftv = get(overload)) + if (const FunctionType* ftv = get(overload)) { checkArgumentList(scope, *expr.func, state, retPack, ftv->retTypes, {}); checkArgumentList(scope, *expr.func, state, argPack, ftv->argTypes, argLocations); @@ -4586,7 +4585,7 @@ WithPredicate TypeChecker::checkExprList(const ScopePtr& scope, cons auto [type, exprPredicates] = checkExpr(scope, *expr, expectedType); insert(exprPredicates); - if (FFlag::LuauUnknownAndNeverType && get(type)) + if (FFlag::LuauUnknownAndNeverType && get(type)) { // f(), g() where f() returns (never, string) or (string, never) means this whole TypePackId is uninhabitable, so return (never, // ...never) @@ -4705,8 +4704,8 @@ void TypeChecker::tablify(TypeId type) { type = follow(type); - if (auto f = get(type)) - *asMutable(type) = TableTypeVar{TableState::Free, f->level}; + if (auto f = get(type)) + *asMutable(type) = TableType{TableState::Free, f->level}; } TypeId TypeChecker::anyIfNonstrict(TypeId ty) const @@ -4807,14 +4806,14 @@ TypeId TypeChecker::quantify(const ScopePtr& scope, TypeId ty, Location location if (FFlag::DebugLuauSharedSelf) { - if (auto ftv = get(ty)) + if (auto ftv = get(ty)) Luau::quantify(ty, scope->level); else if (auto ttv = getTableType(ty); ttv && ttv->selfTy) Luau::quantify(ty, scope->level); } else { - const FunctionTypeVar* ftv = get(ty); + const FunctionType* ftv = get(ty); if (ftv) Luau::quantify(ty, scope->level); @@ -4827,7 +4826,7 @@ TypeId TypeChecker::instantiate(const ScopePtr& scope, TypeId ty, Location locat { ty = follow(ty); - const FunctionTypeVar* ftv = get(ty); + const FunctionType* ftv = get(ty); if (ftv && ftv->hasNoGenerics) return ty; @@ -4848,7 +4847,7 @@ TypeId TypeChecker::instantiate(const ScopePtr& scope, TypeId ty, Location locat TypeId TypeChecker::anyify(const ScopePtr& scope, TypeId ty, Location location) { - Anyification anyification{¤tModule->internalTypes, scope, singletonTypes, iceHandler, anyType, anyTypePack}; + Anyification anyification{¤tModule->internalTypes, scope, builtinTypes, iceHandler, anyType, anyTypePack}; std::optional any = anyification.substitute(ty); if (anyification.normalizationTooComplex) reportError(location, NormalizationTooComplex{}); @@ -4863,7 +4862,7 @@ TypeId TypeChecker::anyify(const ScopePtr& scope, TypeId ty, Location location) TypePackId TypeChecker::anyify(const ScopePtr& scope, TypePackId ty, Location location) { - Anyification anyification{¤tModule->internalTypes, scope, singletonTypes, iceHandler, anyType, anyTypePack}; + Anyification anyification{¤tModule->internalTypes, scope, builtinTypes, iceHandler, anyType, anyTypePack}; std::optional any = anyification.substitute(ty); if (any.has_value()) return *any; @@ -4881,7 +4880,7 @@ TypePackId TypeChecker::anyifyModuleReturnTypePackGenerics(TypePackId tp) if (const VariadicTypePack* vtp = get(tp)) { TypeId ty = FFlag::LuauTypeInferMissingFollows ? follow(vtp->ty) : vtp->ty; - return get(ty) ? anyTypePack : tp; + return get(ty) ? anyTypePack : tp; } if (!get(follow(tp))) @@ -4895,7 +4894,7 @@ TypePackId TypeChecker::anyifyModuleReturnTypePackGenerics(TypePackId tp) for (TypePackIterator e = end(tp); it != e; ++it) { TypeId ty = follow(*it); - resultTypes.push_back(get(ty) ? anyType : ty); + resultTypes.push_back(get(ty) ? anyType : ty); } if (std::optional tail = it.tail()) @@ -4954,7 +4953,7 @@ void TypeChecker::diagnoseMissingTableKey(UnknownProperty* utk, TypeErrorData& d std::string_view sv(utk->key); std::set candidates; - auto accumulate = [&](const TableTypeVar::Props& props) { + auto accumulate = [&](const TableType::Props& props) { for (const auto& [name, ty] : props) { if (sv != name && equalsLower(sv, name)) @@ -4964,7 +4963,7 @@ void TypeChecker::diagnoseMissingTableKey(UnknownProperty* utk, TypeErrorData& d if (auto ttv = getTableType(utk->table)) accumulate(ttv->props); - else if (auto ctv = get(follow(utk->table))) + else if (auto ctv = get(follow(utk->table))) { while (ctv) { @@ -4973,7 +4972,7 @@ void TypeChecker::diagnoseMissingTableKey(UnknownProperty* utk, TypeErrorData& d if (!ctv->parent) break; - ctv = get(*ctv->parent); + ctv = get(*ctv->parent); LUAU_ASSERT(ctv); } } @@ -5009,15 +5008,15 @@ ScopePtr TypeChecker::childScope(const ScopePtr& parent, const Location& locatio void TypeChecker::merge(RefinementMap& l, const RefinementMap& r) { Luau::merge(l, r, [this](TypeId a, TypeId b) { - // TODO: normalize(UnionTypeVar{{a, b}}) + // TODO: normalize(UnionType{{a, b}}) std::unordered_set set; - if (auto utv = get(follow(a))) + if (auto utv = get(follow(a))) set.insert(begin(utv), end(utv)); else set.insert(a); - if (auto utv = get(follow(b))) + if (auto utv = get(follow(b))) set.insert(begin(utv), end(utv)); else set.insert(b); @@ -5025,7 +5024,7 @@ void TypeChecker::merge(RefinementMap& l, const RefinementMap& r) std::vector options(set.begin(), set.end()); if (set.size() == 1) return options[0]; - return addType(UnionTypeVar{std::move(options)}); + return addType(UnionType{std::move(options)}); }); } @@ -5041,53 +5040,53 @@ TypeId TypeChecker::freshType(const ScopePtr& scope) TypeId TypeChecker::freshType(TypeLevel level) { - return currentModule->internalTypes.addType(TypeVar(FreeTypeVar(level))); + return currentModule->internalTypes.addType(Type(FreeType(level))); } TypeId TypeChecker::singletonType(bool value) { - return value ? singletonTypes->trueType : singletonTypes->falseType; + return value ? builtinTypes->trueType : builtinTypes->falseType; } TypeId TypeChecker::singletonType(std::string value) { // TODO: cache singleton types - return currentModule->internalTypes.addType(TypeVar(SingletonTypeVar(StringSingleton{std::move(value)}))); + return currentModule->internalTypes.addType(Type(SingletonType(StringSingleton{std::move(value)}))); } TypeId TypeChecker::errorRecoveryType(const ScopePtr& scope) { - return singletonTypes->errorRecoveryType(); + return builtinTypes->errorRecoveryType(); } TypeId TypeChecker::errorRecoveryType(TypeId guess) { - return singletonTypes->errorRecoveryType(guess); + return builtinTypes->errorRecoveryType(guess); } TypePackId TypeChecker::errorRecoveryTypePack(const ScopePtr& scope) { - return singletonTypes->errorRecoveryTypePack(); + return builtinTypes->errorRecoveryTypePack(); } TypePackId TypeChecker::errorRecoveryTypePack(TypePackId guess) { - return singletonTypes->errorRecoveryTypePack(guess); + return builtinTypes->errorRecoveryTypePack(guess); } TypeIdPredicate TypeChecker::mkTruthyPredicate(bool sense, TypeId emptySetTy) { return [this, sense, emptySetTy](TypeId ty) -> std::optional { // any/error/free gets a special pass unconditionally because they can't be decided. - if (get(ty) || get(ty) || get(ty)) + if (get(ty) || get(ty) || get(ty)) return ty; // maps boolean primitive to the corresponding singleton equal to sense - if (isPrim(ty, PrimitiveTypeVar::Boolean)) + if (isPrim(ty, PrimitiveType::Boolean)) return singletonType(sense); // if we have boolean singleton, eliminate it if the sense doesn't match with that singleton - if (auto boolean = get(get(ty))) + if (auto boolean = get(get(ty))) return boolean->value == sense ? std::optional(ty) : std::nullopt; // if we have nil, eliminate it if sense is true, otherwise take it @@ -5103,7 +5102,7 @@ std::optional TypeChecker::filterMapImpl(TypeId type, TypeIdPredicate pr { std::vector types = Luau::filterMap(type, predicate); if (!types.empty()) - return types.size() == 1 ? types[0] : addType(UnionTypeVar{std::move(types)}); + return types.size() == 1 ? types[0] : addType(UnionType{std::move(types)}); return std::nullopt; } @@ -5112,7 +5111,7 @@ std::pair, bool> TypeChecker::filterMap(TypeId type, TypeI if (FFlag::LuauUnknownAndNeverType) { TypeId ty = filterMapImpl(type, predicate).value_or(neverType); - return {ty, !bool(get(ty))}; + return {ty, !bool(get(ty))}; } else { @@ -5126,7 +5125,7 @@ std::pair, bool> TypeChecker::pickTypesFromSense(TypeId ty return filterMap(type, mkTruthyPredicate(sense, emptySetTy)); } -TypeId TypeChecker::addTV(TypeVar&& tv) +TypeId TypeChecker::addTV(Type&& tv) { return currentModule->internalTypes.addType(std::move(tv)); } @@ -5381,7 +5380,7 @@ TypeId TypeChecker::resolveTypeWorker(const ScopePtr& scope, const AstType& anno } else if (const auto& table = annotation.as()) { - TableTypeVar::Props props; + TableType::Props props; std::optional tableIndexer; for (const auto& prop : table->props) @@ -5390,7 +5389,7 @@ TypeId TypeChecker::resolveTypeWorker(const ScopePtr& scope, const AstType& anno if (const auto& indexer = table->indexer) tableIndexer = TableIndexer(resolveType(scope, *indexer->indexType), resolveType(scope, *indexer->resultType)); - TableTypeVar ttv{props, tableIndexer, scope->level, TableState::Sealed}; + TableType ttv{props, tableIndexer, scope->level, TableState::Sealed}; ttv.definitionModuleName = currentModuleName; return addType(std::move(ttv)); } @@ -5416,9 +5415,9 @@ TypeId TypeChecker::resolveTypeWorker(const ScopePtr& scope, const AstType& anno return el.tp; }); - TypeId fnType = addType(FunctionTypeVar{funcScope->level, std::move(genericTys), std::move(genericTps), argTypes, retTypes}); + TypeId fnType = addType(FunctionType{funcScope->level, std::move(genericTys), std::move(genericTps), argTypes, retTypes}); - FunctionTypeVar* ftv = getMutable(fnType); + FunctionType* ftv = getMutable(fnType); ftv->argNames.reserve(func->argNames.size); for (const auto& el : func->argNames) @@ -5442,7 +5441,7 @@ TypeId TypeChecker::resolveTypeWorker(const ScopePtr& scope, const AstType& anno for (AstType* ann : un->types) types.push_back(resolveType(scope, *ann)); - return addType(UnionTypeVar{types}); + return addType(UnionType{types}); } else if (const auto& un = annotation.as()) { @@ -5450,7 +5449,7 @@ TypeId TypeChecker::resolveTypeWorker(const ScopePtr& scope, const AstType& anno for (AstType* ann : un->types) types.push_back(resolveType(scope, *ann)); - return addType(IntersectionTypeVar{types}); + return addType(IntersectionType{types}); } else if (const auto& tsb = annotation.as()) { @@ -5566,7 +5565,7 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, TypeId target = follow(instantiated); bool needsClone = follow(tf.type) == target; bool shouldMutate = getTableType(tf.type); - TableTypeVar* ttv = getMutableTableType(target); + TableType* ttv = getMutableTableType(target); if (shouldMutate && ttv && needsClone) { @@ -5574,17 +5573,17 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, // want to mutate its table, so we need to explicitly clone that table as // well. If we don't, we will mutate another module's type surface and cause // a use-after-free. - if (get(target)) + if (get(target)) { instantiated = applyTypeFunction.clone(tf.type); - MetatableTypeVar* mtv = getMutable(instantiated); + MetatableType* mtv = getMutable(instantiated); mtv->table = applyTypeFunction.clone(mtv->table); - ttv = getMutable(mtv->table); + ttv = getMutable(mtv->table); } - if (get(target)) + if (get(target)) { instantiated = applyTypeFunction.clone(tf.type); - ttv = getMutable(instantiated); + ttv = getMutable(instantiated); } } @@ -5617,7 +5616,7 @@ GenericTypeDefinitions TypeChecker::createGenericTypes(const ScopePtr& scope, st Name n = generic.name.value; // These generics are the only thing that will ever be added to scope, so we can be certain that - // a collision can only occur when two generic typevars have the same name. + // a collision can only occur when two generic types have the same name. if (scope->privateTypeBindings.count(n) || scope->privateTypePackBindings.count(n)) { // TODO(jhuelsman): report the exact span of the generic type parameter whose name is a duplicate. @@ -5629,7 +5628,7 @@ GenericTypeDefinitions TypeChecker::createGenericTypes(const ScopePtr& scope, st { TypeId& cached = scope->parent->typeAliasTypeParameters[n]; if (!cached) - cached = addType(GenericTypeVar{level, n}); + cached = addType(GenericType{level, n}); g = cached; } else @@ -5653,7 +5652,7 @@ GenericTypeDefinitions TypeChecker::createGenericTypes(const ScopePtr& scope, st Name n = genericPack.name.value; // These generics are the only thing that will ever be added to scope, so we can be certain that - // a collision can only occur when two generic typevars have the same name. + // a collision can only occur when two generic types have the same name. if (scope->privateTypePackBindings.count(n) || scope->privateTypeBindings.count(n)) { // TODO(jhuelsman): report the exact span of the generic type parameter whose name is a duplicate. @@ -5685,7 +5684,7 @@ void TypeChecker::refineLValue(const LValue& lvalue, RefinementMap& refis, const if (auto base = baseof(lvalue)) { std::optional baseTy = resolveLValue(scope, *base); - if (baseTy && get(follow(*baseTy))) + if (baseTy && get(follow(*baseTy))) { ty = baseTy; target = base; @@ -5713,7 +5712,7 @@ void TypeChecker::refineLValue(const LValue& lvalue, RefinementMap& refis, const } // Otherwise, we'll want to walk each option of ty, get its index type, and filter that. - auto utv = get(follow(*ty)); + auto utv = get(follow(*ty)); LUAU_ASSERT(utv); std::unordered_set viableTargetOptions; @@ -5733,7 +5732,7 @@ void TypeChecker::refineLValue(const LValue& lvalue, RefinementMap& refis, const auto [result, ok] = filterMap(*discriminantTy, predicate); if (FFlag::LuauUnknownAndNeverType) { - if (!get(*result)) + if (!get(*result)) { viableTargetOptions.insert(option); viableChildOptions.insert(*result); @@ -5753,12 +5752,12 @@ void TypeChecker::refineLValue(const LValue& lvalue, RefinementMap& refis, const if (s.empty()) return std::nullopt; - // TODO: allocate UnionTypeVar and just normalize. + // TODO: allocate UnionType and just normalize. std::vector options(s.begin(), s.end()); if (options.size() == 1) return options[0]; - return addType(UnionTypeVar{std::move(options)}); + return addType(UnionType{std::move(options)}); }; if (std::optional viableTargetType = intoType(viableTargetOptions)) @@ -5852,7 +5851,7 @@ std::optional TypeChecker::resolveLValue(const RefinementMap& refis, con static bool isUndecidable(TypeId ty) { ty = follow(ty); - return get(ty) || get(ty) || get(ty); + return get(ty) || get(ty) || get(ty); } void TypeChecker::resolve(const PredicateVec& predicates, const ScopePtr& scope, bool sense) @@ -5964,7 +5963,7 @@ void TypeChecker::resolve(const IsAPredicate& isaP, RefinementMap& refis, const // So we can just return the right hand side immediately. // typeof(x) == "Instance" where x : any - auto ttv = get(option); + auto ttv = get(option); if (isUndecidable(option) || (ttv && ttv->state == TableState::Free)) return sense ? isaP.ty : option; @@ -6001,7 +6000,7 @@ void TypeChecker::resolve(const TypeGuardPredicate& typeguardP, RefinementMap& r auto refine = [this, &lvalue = typeguardP.lvalue, &refis, &scope, sense](bool(f)(TypeId), std::optional mapsTo = std::nullopt) { TypeIdPredicate predicate = [f, mapsTo, sense](TypeId ty) -> std::optional { - if (FFlag::LuauUnknownAndNeverType && sense && get(ty)) + if (FFlag::LuauUnknownAndNeverType && sense && get(ty)) return mapsTo.value_or(ty); if (f(ty) == sense) @@ -6030,20 +6029,20 @@ void TypeChecker::resolve(const TypeGuardPredicate& typeguardP, RefinementMap& r else if (typeguardP.kind == "table") { return refine([](TypeId ty) -> bool { - return isTableIntersection(ty) || get(ty) || get(ty); + return isTableIntersection(ty) || get(ty) || get(ty); }); } else if (typeguardP.kind == "function") { return refine([](TypeId ty) -> bool { - return isOverloadedFunction(ty) || get(ty); + return isOverloadedFunction(ty) || get(ty); }); } else if (typeguardP.kind == "userdata") { // For now, we don't really care about being accurate with userdata if the typeguard was using typeof. return refine([](TypeId ty) -> bool { - return get(ty); + return get(ty); }); } @@ -6056,8 +6055,18 @@ void TypeChecker::resolve(const TypeGuardPredicate& typeguardP, RefinementMap& r TypeId type = follow(typeFun->type); + // You cannot refine to the top class type. + if (FFlag::LuauNegatedClassTypes) + { + if (type == builtinTypes->classType) + { + return addRefinement(refis, typeguardP.lvalue, errorRecoveryType(scope)); + } + } + // We're only interested in the root class of any classes. - if (auto ctv = get(type); !ctv || ctv->parent) + if (auto ctv = get(type); + !ctv || (FFlag::LuauNegatedClassTypes ? (ctv->parent != builtinTypes->classType) : (ctv->parent != std::nullopt))) return addRefinement(refis, typeguardP.lvalue, errorRecoveryType(scope)); // This probably hints at breaking out type filtering functions from the predicate solver so that typeof is not tightly coupled with IsA. @@ -6069,7 +6078,7 @@ void TypeChecker::resolve(const EqPredicate& eqP, RefinementMap& refis, const Sc { // This refinement will require success typing to do everything correctly. For now, we can get most of the way there. auto options = [](TypeId ty) -> std::vector { - if (auto utv = get(follow(ty))) + if (auto utv = get(follow(ty))) return std::vector(begin(utv), end(utv)); return {ty}; }; @@ -6120,7 +6129,7 @@ void TypeChecker::resolve(const EqPredicate& eqP, RefinementMap& refis, const Sc } else { - bool isOptionSingleton = get(option); + bool isOptionSingleton = get(option); if (!isOptionSingleton) return option; else if (optionIsSubtype && targetIsSubtype) diff --git a/Analysis/src/TypePack.cpp b/Analysis/src/TypePack.cpp index 0f75c3efc..e41bf2fe9 100644 --- a/Analysis/src/TypePack.cpp +++ b/Analysis/src/TypePack.cpp @@ -237,7 +237,7 @@ TypePackId follow(TypePackId tp, std::function mapper) cycleTester = nullptr; if (tp == cycleTester) - throw InternalCompilerError("Luau::follow detected a TypeVar cycle!!"); + throw InternalCompilerError("Luau::follow detected a Type cycle!!"); } } } @@ -381,14 +381,14 @@ bool containsNever(TypePackId tp) while (it != endIt) { - if (get(follow(*it))) + if (get(follow(*it))) return true; ++it; } if (auto tail = it.tail()) { - if (auto vtp = get(*tail); vtp && get(follow(vtp->ty))) + if (auto vtp = get(*tail); vtp && get(follow(vtp->ty))) return true; } diff --git a/Analysis/src/TypeUtils.cpp b/Analysis/src/TypeUtils.cpp index 7478ac22c..f8f51bcf1 100644 --- a/Analysis/src/TypeUtils.cpp +++ b/Analysis/src/TypeUtils.cpp @@ -12,20 +12,20 @@ namespace Luau { std::optional findMetatableEntry( - NotNull singletonTypes, ErrorVec& errors, TypeId type, const std::string& entry, Location location) + NotNull builtinTypes, ErrorVec& errors, TypeId type, const std::string& entry, Location location) { type = follow(type); - std::optional metatable = getMetatable(type, singletonTypes); + std::optional metatable = getMetatable(type, builtinTypes); if (!metatable) return std::nullopt; TypeId unwrapped = follow(*metatable); - if (get(unwrapped)) - return singletonTypes->anyType; + if (get(unwrapped)) + return builtinTypes->anyType; - const TableTypeVar* mtt = getTableType(unwrapped); + const TableType* mtt = getTableType(unwrapped); if (!mtt) { errors.push_back(TypeError{location, GenericError{"Metatable was not a table"}}); @@ -40,19 +40,19 @@ std::optional findMetatableEntry( } std::optional findTablePropertyRespectingMeta( - NotNull singletonTypes, ErrorVec& errors, TypeId ty, const std::string& name, Location location) + NotNull builtinTypes, ErrorVec& errors, TypeId ty, const std::string& name, Location location) { - if (get(ty)) + if (get(ty)) return ty; - if (const TableTypeVar* tableType = getTableType(ty)) + if (const TableType* tableType = getTableType(ty)) { const auto& it = tableType->props.find(name); if (it != tableType->props.end()) return it->second.type; } - std::optional mtIndex = findMetatableEntry(singletonTypes, errors, ty, "__index", location); + std::optional mtIndex = findMetatableEntry(builtinTypes, errors, ty, "__index", location); int count = 0; while (mtIndex) { @@ -69,20 +69,20 @@ std::optional findTablePropertyRespectingMeta( if (fit != itt->props.end()) return fit->second.type; } - else if (const auto& itf = get(index)) + else if (const auto& itf = get(index)) { std::optional r = first(follow(itf->retTypes)); if (!r) - return singletonTypes->nilType; + return builtinTypes->nilType; else return *r; } - else if (get(index)) - return singletonTypes->anyType; + else if (get(index)) + return builtinTypes->anyType; else errors.push_back(TypeError{location, GenericError{"__index should either be a function or table. Got " + toString(index)}}); - mtIndex = findMetatableEntry(singletonTypes, errors, *mtIndex, "__index", location); + mtIndex = findMetatableEntry(builtinTypes, errors, *mtIndex, "__index", location); } return std::nullopt; @@ -117,7 +117,7 @@ std::pair> getParameterExtents(const TxnLog* log, return {minCount, minCount + optionalCount}; } -TypePack extendTypePack(TypeArena& arena, NotNull singletonTypes, TypePackId pack, size_t length) +TypePack extendTypePack(TypeArena& arena, NotNull builtinTypes, TypePackId pack, size_t length) { TypePack result; @@ -193,7 +193,7 @@ TypePack extendTypePack(TypeArena& arena, NotNull singletonTypes else if (const Unifiable::Error* etp = getMutable(pack)) { while (result.head.size() < length) - result.head.push_back(singletonTypes->errorRecoveryType()); + result.head.push_back(builtinTypes->errorRecoveryType()); result.tail = pack; return result; @@ -214,20 +214,20 @@ std::vector reduceUnion(const std::vector& types) for (TypeId t : types) { t = follow(t); - if (get(t)) + if (get(t)) continue; - if (get(t) || get(t)) + if (get(t) || get(t)) return {t}; - if (const UnionTypeVar* utv = get(t)) + if (const UnionType* utv = get(t)) { for (TypeId ty : utv) { ty = follow(ty); - if (get(ty)) + if (get(ty)) continue; - if (get(ty) || get(ty)) + if (get(ty) || get(ty)) return {ty}; if (result.end() == std::find(result.begin(), result.end(), ty)) @@ -243,7 +243,7 @@ std::vector reduceUnion(const std::vector& types) static std::optional tryStripUnionFromNil(TypeArena& arena, TypeId ty) { - if (const UnionTypeVar* utv = get(ty)) + if (const UnionType* utv = get(ty)) { if (!std::any_of(begin(utv), end(utv), isNil)) return ty; @@ -259,23 +259,23 @@ static std::optional tryStripUnionFromNil(TypeArena& arena, TypeId ty) if (result.empty()) return std::nullopt; - return result.size() == 1 ? result[0] : arena.addType(UnionTypeVar{std::move(result)}); + return result.size() == 1 ? result[0] : arena.addType(UnionType{std::move(result)}); } return std::nullopt; } -TypeId stripNil(NotNull singletonTypes, TypeArena& arena, TypeId ty) +TypeId stripNil(NotNull builtinTypes, TypeArena& arena, TypeId ty) { ty = follow(ty); - if (get(ty)) + if (get(ty)) { std::optional cleaned = tryStripUnionFromNil(arena, ty); // If there is no union option without 'nil' if (!cleaned) - return singletonTypes->nilType; + return builtinTypes->nilType; return follow(*cleaned); } diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 428820054..d35e37710 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -10,8 +10,8 @@ #include "Luau/ToString.h" #include "Luau/TypePack.h" #include "Luau/TypeUtils.h" -#include "Luau/TypeVar.h" -#include "Luau/VisitTypeVar.h" +#include "Luau/Type.h" +#include "Luau/VisitType.h" #include @@ -29,11 +29,12 @@ LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution) LUAU_FASTFLAG(LuauTxnLogTypePackIterator) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) LUAU_FASTFLAG(LuauNegatedFunctionTypes) +LUAU_FASTFLAG(LuauNegatedClassTypes) namespace Luau { -struct PromoteTypeLevels final : TypeVarOnceVisitor +struct PromoteTypeLevels final : TypeOnceVisitor { TxnLog& log; const TypeArena* typeArena = nullptr; @@ -91,18 +92,18 @@ struct PromoteTypeLevels final : TypeVarOnceVisitor return true; } - bool visit(TypeId ty, const FreeTypeVar&) override + bool visit(TypeId ty, const FreeType&) override { - // Surprise, it's actually a BoundTypeVar that hasn't been committed yet. + // Surprise, it's actually a BoundType that hasn't been committed yet. // Calling getMutable on this will trigger an assertion. - if (!log.is(ty)) + if (!log.is(ty)) return true; - promote(ty, log.getMutable(ty)); + promote(ty, log.getMutable(ty)); return true; } - bool visit(TypeId ty, const FunctionTypeVar&) override + bool visit(TypeId ty, const FunctionType&) override { // Type levels of types from other modules are already global, so we don't need to promote anything inside if (ty->owningArena != typeArena) @@ -110,14 +111,14 @@ struct PromoteTypeLevels final : TypeVarOnceVisitor // Surprise, it's actually a BoundTypePack that hasn't been committed yet. // Calling getMutable on this will trigger an assertion. - if (FFlag::LuauScalarShapeUnifyToMtOwner2 && !log.is(ty)) + if (FFlag::LuauScalarShapeUnifyToMtOwner2 && !log.is(ty)) return true; - promote(ty, log.getMutable(ty)); + promote(ty, log.getMutable(ty)); return true; } - bool visit(TypeId ty, const TableTypeVar& ttv) override + bool visit(TypeId ty, const TableType& ttv) override { // Type levels of types from other modules are already global, so we don't need to promote anything inside if (ty->owningArena != typeArena) @@ -128,10 +129,10 @@ struct PromoteTypeLevels final : TypeVarOnceVisitor // Surprise, it's actually a BoundTypePack that hasn't been committed yet. // Calling getMutable on this will trigger an assertion. - if (FFlag::LuauScalarShapeUnifyToMtOwner2 && !log.is(ty)) + if (FFlag::LuauScalarShapeUnifyToMtOwner2 && !log.is(ty)) return true; - promote(ty, log.getMutable(ty)); + promote(ty, log.getMutable(ty)); return true; } @@ -167,7 +168,7 @@ void promoteTypeLevels(TxnLog& log, const TypeArena* typeArena, TypeLevel minLev ptl.traverse(tp); } -struct SkipCacheForType final : TypeVarOnceVisitor +struct SkipCacheForType final : TypeOnceVisitor { SkipCacheForType(const DenseHashMap& skipCacheForType, const TypeArena* typeArena) : skipCacheForType(skipCacheForType) @@ -175,31 +176,31 @@ struct SkipCacheForType final : TypeVarOnceVisitor { } - bool visit(TypeId, const FreeTypeVar&) override + bool visit(TypeId, const FreeType&) override { result = true; return false; } - bool visit(TypeId, const BoundTypeVar&) override + bool visit(TypeId, const BoundType&) override { result = true; return false; } - bool visit(TypeId, const GenericTypeVar&) override + bool visit(TypeId, const GenericType&) override { result = true; return false; } - bool visit(TypeId ty, const TableTypeVar&) override + bool visit(TypeId ty, const TableType&) override { // Types from other modules don't contain mutable elements and are ok to cache if (ty->owningArena != typeArena) return false; - TableTypeVar& ttv = *getMutable(ty); + TableType& ttv = *getMutable(ty); if (ttv.boundTo) { @@ -267,7 +268,7 @@ struct SkipCacheForType final : TypeVarOnceVisitor bool Widen::isDirty(TypeId ty) { - return log->is(ty); + return log->is(ty); } bool Widen::isDirty(TypePackId) @@ -278,16 +279,16 @@ bool Widen::isDirty(TypePackId) TypeId Widen::clean(TypeId ty) { LUAU_ASSERT(isDirty(ty)); - auto stv = log->getMutable(ty); + auto stv = log->getMutable(ty); LUAU_ASSERT(stv); if (get(stv)) - return singletonTypes->stringType; + return builtinTypes->stringType; else { // If this assert trips, it's likely we now have number singletons. LUAU_ASSERT(get(stv)); - return singletonTypes->booleanType; + return builtinTypes->booleanType; } } @@ -298,10 +299,10 @@ TypePackId Widen::clean(TypePackId) bool Widen::ignoreChildren(TypeId ty) { - if (FFlag::LuauClassTypeVarsInSubstitution && get(ty)) + if (FFlag::LuauClassTypeVarsInSubstitution && get(ty)) return true; - return !log->is(ty); + return !log->is(ty); } TypeId Widen::operator()(TypeId ty) @@ -328,13 +329,13 @@ static std::optional hasUnificationTooComplex(const ErrorVec& errors) } // Used for tagged union matching heuristic, returns first singleton type field -static std::optional> getTableMatchTag(TypeId type) +static std::optional> getTableMatchTag(TypeId type) { if (auto ttv = getTableType(type)) { for (auto&& [name, prop] : ttv->props) { - if (auto sing = get(follow(prop.type))) + if (auto sing = get(follow(prop.type))) return {{name, sing}}; } } @@ -368,7 +369,7 @@ TypeMismatch::Context Unifier::mismatchContext() Unifier::Unifier(NotNull normalizer, Mode mode, NotNull scope, const Location& location, Variance variance, TxnLog* parentLog) : types(normalizer->arena) - , singletonTypes(normalizer->singletonTypes) + , builtinTypes(normalizer->builtinTypes) , normalizer(normalizer) , mode(mode) , scope(scope) @@ -406,13 +407,13 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool if (superTy == subTy) return; - auto superFree = log.getMutable(superTy); - auto subFree = log.getMutable(subTy); + auto superFree = log.getMutable(superTy); + auto subFree = log.getMutable(subTy); if (superFree && subFree && subsumes(useScopes, superFree, subFree)) { if (!occursCheck(subTy, superTy)) - log.replace(subTy, BoundTypeVar(superTy)); + log.replace(subTy, BoundType(superTy)); return; } @@ -425,7 +426,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool log.changeLevel(subTy, superFree->level); } - log.replace(superTy, BoundTypeVar(subTy)); + log.replace(superTy, BoundType(subTy)); } return; @@ -433,7 +434,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool else if (superFree) { // Unification can't change the level of a generic. - auto subGeneric = log.getMutable(subTy); + auto subGeneric = log.getMutable(subTy); if (subGeneric && !subsumes(useScopes, subGeneric, superFree)) { // TODO: a more informative error message? CLI-39912 @@ -445,8 +446,8 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool { promoteTypeLevels(log, types, superFree->level, superFree->scope, useScopes, subTy); - Widen widen{types, singletonTypes}; - log.replace(superTy, BoundTypeVar(widen(subTy))); + Widen widen{types, builtinTypes}; + log.replace(superTy, BoundType(widen(subTy))); } return; @@ -457,12 +458,12 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool { // Normally, if the subtype is free, it should not be bound to any, unknown, or error types. // But for bug compatibility, we'll only apply this rule to unknown. Doing this will silence cascading type errors. - if (log.get(superTy)) + if (log.get(superTy)) return; } // Unification can't change the level of a generic. - auto superGeneric = log.getMutable(superTy); + auto superGeneric = log.getMutable(superTy); if (superGeneric && !subsumes(useScopes, superGeneric, subFree)) { // TODO: a more informative error message? CLI-39912 @@ -473,22 +474,22 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool if (!occursCheck(subTy, superTy)) { promoteTypeLevels(log, types, subFree->level, subFree->scope, useScopes, superTy); - log.replace(subTy, BoundTypeVar(superTy)); + log.replace(subTy, BoundType(superTy)); } return; } - if (get(superTy) || get(superTy) || get(superTy)) + if (get(superTy) || get(superTy) || get(superTy)) return tryUnifyWithAny(subTy, superTy); - if (get(subTy)) + if (get(subTy)) return tryUnifyWithAny(superTy, subTy); - if (log.get(subTy)) + if (log.get(subTy)) return tryUnifyWithAny(superTy, subTy); - if (log.get(subTy)) + if (log.get(subTy)) return tryUnifyWithAny(superTy, subTy); auto& cache = sharedState.cachedUnify; @@ -519,78 +520,77 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool size_t errorCount = errors.size(); - if (const UnionTypeVar* subUnion = log.getMutable(subTy)) + if (const UnionType* subUnion = log.getMutable(subTy)) { tryUnifyUnionWithType(subTy, subUnion, superTy); } - else if (const UnionTypeVar* uv = (FFlag::LuauSubtypeNormalizer ? nullptr : log.getMutable(superTy))) + else if (const UnionType* uv = (FFlag::LuauSubtypeNormalizer ? nullptr : log.getMutable(superTy))) { tryUnifyTypeWithUnion(subTy, superTy, uv, cacheEnabled, isFunctionCall); } - else if (const IntersectionTypeVar* uv = log.getMutable(superTy)) + else if (const IntersectionType* uv = log.getMutable(superTy)) { tryUnifyTypeWithIntersection(subTy, superTy, uv); } - else if (const UnionTypeVar* uv = log.getMutable(superTy)) + else if (const UnionType* uv = log.getMutable(superTy)) { tryUnifyTypeWithUnion(subTy, superTy, uv, cacheEnabled, isFunctionCall); } - else if (const IntersectionTypeVar* uv = log.getMutable(subTy)) + else if (const IntersectionType* uv = log.getMutable(subTy)) { tryUnifyIntersectionWithType(subTy, uv, superTy, cacheEnabled, isFunctionCall); } - else if (log.getMutable(superTy) && log.getMutable(subTy)) + else if (log.getMutable(superTy) && log.getMutable(subTy)) tryUnifyPrimitives(subTy, superTy); - else if ((log.getMutable(superTy) || log.getMutable(superTy)) && log.getMutable(subTy)) + else if ((log.getMutable(superTy) || log.getMutable(superTy)) && log.getMutable(subTy)) tryUnifySingletons(subTy, superTy); - else if (auto ptv = get(superTy); - FFlag::LuauNegatedFunctionTypes && ptv && ptv->type == PrimitiveTypeVar::Function && get(subTy)) + else if (auto ptv = get(superTy); + FFlag::LuauNegatedFunctionTypes && ptv && ptv->type == PrimitiveType::Function && get(subTy)) { // Ok. Do nothing. forall functions F, F <: function } - else if (log.getMutable(superTy) && log.getMutable(subTy)) + else if (log.getMutable(superTy) && log.getMutable(subTy)) tryUnifyFunctions(subTy, superTy, isFunctionCall); - else if (log.getMutable(superTy) && log.getMutable(subTy)) + else if (log.getMutable(superTy) && log.getMutable(subTy)) { tryUnifyTables(subTy, superTy, isIntersection); } - else if (FFlag::LuauScalarShapeSubtyping && log.get(superTy) && - (log.get(subTy) || log.get(subTy))) + else if (FFlag::LuauScalarShapeSubtyping && log.get(superTy) && (log.get(subTy) || log.get(subTy))) { tryUnifyScalarShape(subTy, superTy, /*reversed*/ false); } - else if (FFlag::LuauScalarShapeSubtyping && log.get(subTy) && - (log.get(superTy) || log.get(superTy))) + else if (FFlag::LuauScalarShapeSubtyping && log.get(subTy) && (log.get(superTy) || log.get(superTy))) { tryUnifyScalarShape(subTy, superTy, /*reversed*/ true); } - // tryUnifyWithMetatable assumes its first argument is a MetatableTypeVar. The check is otherwise symmetrical. - else if (log.getMutable(superTy)) + // tryUnifyWithMetatable assumes its first argument is a MetatableType. The check is otherwise symmetrical. + else if (log.getMutable(superTy)) tryUnifyWithMetatable(subTy, superTy, /*reversed*/ false); - else if (log.getMutable(subTy)) + else if (log.getMutable(subTy)) tryUnifyWithMetatable(superTy, subTy, /*reversed*/ true); - else if (log.getMutable(superTy)) + else if (log.getMutable(superTy)) tryUnifyWithClass(subTy, superTy, /*reversed*/ false); // Unification of nonclasses with classes is almost, but not quite symmetrical. // The order in which we perform this test is significant in the case that both types are classes. - else if (log.getMutable(subTy)) + else if (log.getMutable(subTy)) tryUnifyWithClass(subTy, superTy, /*reversed*/ true); - else if (log.get(superTy)) + else if (log.get(superTy)) tryUnifyTypeWithNegation(subTy, superTy); - else if (log.get(subTy)) + else if (log.get(subTy)) tryUnifyNegationWithType(subTy, superTy); else if (FFlag::LuauUninhabitedSubAnything2 && !normalizer->isInhabited(subTy)) - {} + { + } else reportError(location, TypeMismatch{superTy, subTy, mismatchContext()}); @@ -601,7 +601,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool log.popSeen(superTy, subTy); } -void Unifier::tryUnifyUnionWithType(TypeId subTy, const UnionTypeVar* subUnion, TypeId superTy) +void Unifier::tryUnifyUnionWithType(TypeId subTy, const UnionType* subUnion, TypeId superTy) { // A | B <: T if and only if A <: T and B <: T bool failed = false; @@ -639,13 +639,13 @@ void Unifier::tryUnifyUnionWithType(TypeId subTy, const UnionTypeVar* subUnion, superOption = log.follow(superOption); // just skip if the superOption is not free-ish. - auto ttv = log.getMutable(superOption); - if (!log.is(superOption) && (!ttv || ttv->state != TableState::Free)) + auto ttv = log.getMutable(superOption); + if (!log.is(superOption) && (!ttv || ttv->state != TableState::Free)) return; // If superOption is already present in subTy, do nothing. Nothing new has been learned, but the subtype // test is successful. - if (auto subUnion = get(subTy)) + if (auto subUnion = get(subTy)) { if (end(subUnion) != std::find(begin(subUnion), end(subUnion), superOption)) return; @@ -656,14 +656,14 @@ void Unifier::tryUnifyUnionWithType(TypeId subTy, const UnionTypeVar* subUnion, if (log.haveSeen(subTy, superOption)) { // TODO: would it be nice for TxnLog::replace to do this? - if (log.is(superOption)) + if (log.is(superOption)) log.bindTable(superOption, subTy); else log.replace(superOption, *subTy); } }; - if (auto superUnion = log.getMutable(superTy)) + if (auto superUnion = log.getMutable(superTy)) { for (TypeId ty : superUnion) tryBind(ty); @@ -683,7 +683,7 @@ void Unifier::tryUnifyUnionWithType(TypeId subTy, const UnionTypeVar* subUnion, } } -void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTypeVar* uv, bool cacheEnabled, bool isFunctionCall) +void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionType* uv, bool cacheEnabled, bool isFunctionCall) { // T <: A | B if T <: A or T <: B bool found = false; @@ -803,7 +803,7 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp } } -void Unifier::tryUnifyTypeWithIntersection(TypeId subTy, TypeId superTy, const IntersectionTypeVar* uv) +void Unifier::tryUnifyTypeWithIntersection(TypeId subTy, TypeId superTy, const IntersectionType* uv) { std::optional unificationTooComplex; std::optional firstFailedOption; @@ -839,7 +839,7 @@ void Unifier::tryUnifyTypeWithIntersection(TypeId subTy, TypeId superTy, const I reportError(location, TypeMismatch{superTy, subTy, "Not all intersection parts are compatible.", *firstFailedOption, mismatchContext()}); } -void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionTypeVar* uv, TypeId superTy, bool cacheEnabled, bool isFunctionCall) +void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionType* uv, TypeId superTy, bool cacheEnabled, bool isFunctionCall) { // A & B <: T if A <: T or B <: T bool found = false; @@ -917,56 +917,99 @@ void Unifier::tryUnifyNormalizedTypes( { LUAU_ASSERT(FFlag::LuauSubtypeNormalizer); - if (get(superNorm.tops) || get(superNorm.tops) || get(subNorm.tops)) + if (get(superNorm.tops) || get(superNorm.tops) || get(subNorm.tops)) return; - else if (get(subNorm.tops)) + else if (get(subNorm.tops)) return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); - if (get(subNorm.errors)) - if (!get(superNorm.errors)) + if (get(subNorm.errors)) + if (!get(superNorm.errors)) return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); - if (get(subNorm.booleans)) + if (get(subNorm.booleans)) { - if (!get(superNorm.booleans)) + if (!get(superNorm.booleans)) return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); } - else if (const SingletonTypeVar* stv = get(subNorm.booleans)) + else if (const SingletonType* stv = get(subNorm.booleans)) { - if (!get(superNorm.booleans) && stv != get(superNorm.booleans)) + if (!get(superNorm.booleans) && stv != get(superNorm.booleans)) return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); } - if (get(subNorm.nils)) - if (!get(superNorm.nils)) + if (get(subNorm.nils)) + if (!get(superNorm.nils)) return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); - if (get(subNorm.numbers)) - if (!get(superNorm.numbers)) + if (get(subNorm.numbers)) + if (!get(superNorm.numbers)) return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); if (!isSubtype(subNorm.strings, superNorm.strings)) return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); - if (get(subNorm.threads)) - if (!get(superNorm.errors)) + if (get(subNorm.threads)) + if (!get(superNorm.errors)) return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); - for (TypeId subClass : subNorm.classes) + if (FFlag::LuauNegatedClassTypes) { - bool found = false; - const ClassTypeVar* subCtv = get(subClass); - for (TypeId superClass : superNorm.classes) + for (const auto& [subClass, _] : subNorm.classes.classes) { - const ClassTypeVar* superCtv = get(superClass); - if (isSubclass(subCtv, superCtv)) + bool found = false; + const ClassType* subCtv = get(subClass); + LUAU_ASSERT(subCtv); + + for (const auto& [superClass, superNegations] : superNorm.classes.classes) { - found = true; - break; + const ClassType* superCtv = get(superClass); + LUAU_ASSERT(superCtv); + + if (isSubclass(subCtv, superCtv)) + { + found = true; + + for (TypeId negation : superNegations) + { + const ClassType* negationCtv = get(negation); + LUAU_ASSERT(negationCtv); + + if (isSubclass(subCtv, negationCtv)) + { + found = false; + break; + } + } + + if (found) + break; + } + } + + if (!found) + { + return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); } } - if (!found) - return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); + } + else + { + for (TypeId subClass : subNorm.DEPRECATED_classes) + { + bool found = false; + const ClassType* subCtv = get(subClass); + for (TypeId superClass : superNorm.DEPRECATED_classes) + { + const ClassType* superCtv = get(superClass); + if (isSubclass(subCtv, superCtv)) + { + found = true; + break; + } + } + if (!found) + return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); + } } for (TypeId subTable : subNorm.tables) @@ -975,9 +1018,9 @@ void Unifier::tryUnifyNormalizedTypes( for (TypeId superTable : superNorm.tables) { Unifier innerState = makeChildUnifier(); - if (get(superTable)) + if (get(superTable)) innerState.tryUnifyWithMetatable(subTable, superTable, /* reversed */ false); - else if (get(subTable)) + else if (get(subTable)) innerState.tryUnifyWithMetatable(superTable, subTable, /* reversed */ true); else innerState.tryUnifyTables(subTable, superTable); @@ -1001,7 +1044,7 @@ void Unifier::tryUnifyNormalizedTypes( for (TypeId superFun : *superNorm.functions.parts) { Unifier innerState = makeChildUnifier(); - const FunctionTypeVar* superFtv = get(superFun); + const FunctionType* superFtv = get(superFun); if (!superFtv) return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); TypePackId tgt = innerState.tryApplyOverloadedFunction(subTy, subNorm.functions, superFtv->argTypes); @@ -1032,14 +1075,14 @@ TypePackId Unifier::tryApplyOverloadedFunction(TypeId function, const Normalized if (overloads.isNever()) { reportError(location, CannotCallNonFunction{function}); - return singletonTypes->errorRecoveryTypePack(); + return builtinTypes->errorRecoveryTypePack(); } std::optional result; - const FunctionTypeVar* firstFun = nullptr; + const FunctionType* firstFun = nullptr; for (TypeId overload : *overloads.parts) { - if (const FunctionTypeVar* ftv = get(overload)) + if (const FunctionType* ftv = get(overload)) { // TODO: instantiate generics? if (ftv->generics.empty() && ftv->genericPacks.empty()) @@ -1072,7 +1115,7 @@ TypePackId Unifier::tryApplyOverloadedFunction(TypeId function, const Normalized else if (auto e = hasUnificationTooComplex(innerState.errors)) { reportError(*e); - return singletonTypes->errorRecoveryTypePack(args); + return builtinTypes->errorRecoveryTypePack(args); } } } @@ -1086,12 +1129,12 @@ TypePackId Unifier::tryApplyOverloadedFunction(TypeId function, const Normalized // The logic for error reporting overload resolution // is currently over in TypeInfer.cpp, should we move it? reportError(location, GenericError{"No matching overload."}); - return singletonTypes->errorRecoveryTypePack(firstFun->retTypes); + return builtinTypes->errorRecoveryTypePack(firstFun->retTypes); } else { reportError(location, CannotCallNonFunction{function}); - return singletonTypes->errorRecoveryTypePack(); + return builtinTypes->errorRecoveryTypePack(); } } @@ -1302,7 +1345,7 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal { if (!occursCheck(superTp, subTp)) { - Widen widen{types, singletonTypes}; + Widen widen{types, builtinTypes}; log.replace(superTp, Unifiable::Bound(widen(subTp))); } } @@ -1463,13 +1506,13 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal while (superIter.good()) { - tryUnify_(*superIter, singletonTypes->errorRecoveryType()); + tryUnify_(*superIter, builtinTypes->errorRecoveryType()); superIter.advance(); } while (subIter.good()) { - tryUnify_(*subIter, singletonTypes->errorRecoveryType()); + tryUnify_(*subIter, builtinTypes->errorRecoveryType()); subIter.advance(); } @@ -1489,8 +1532,8 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal void Unifier::tryUnifyPrimitives(TypeId subTy, TypeId superTy) { - const PrimitiveTypeVar* superPrim = get(superTy); - const PrimitiveTypeVar* subPrim = get(subTy); + const PrimitiveType* superPrim = get(superTy); + const PrimitiveType* subPrim = get(subTy); if (!superPrim || !subPrim) ice("passed non primitive types to unifyPrimitives"); @@ -1500,9 +1543,9 @@ void Unifier::tryUnifyPrimitives(TypeId subTy, TypeId superTy) void Unifier::tryUnifySingletons(TypeId subTy, TypeId superTy) { - const PrimitiveTypeVar* superPrim = get(superTy); - const SingletonTypeVar* superSingleton = get(superTy); - const SingletonTypeVar* subSingleton = get(subTy); + const PrimitiveType* superPrim = get(superTy); + const SingletonType* superSingleton = get(superTy); + const SingletonType* subSingleton = get(subTy); if ((!superPrim && !superSingleton) || !subSingleton) ice("passed non singleton/primitive types to unifySingletons"); @@ -1510,10 +1553,10 @@ void Unifier::tryUnifySingletons(TypeId subTy, TypeId superTy) if (superSingleton && *superSingleton == *subSingleton) return; - if (superPrim && superPrim->type == PrimitiveTypeVar::Boolean && get(subSingleton) && variance == Covariant) + if (superPrim && superPrim->type == PrimitiveType::Boolean && get(subSingleton) && variance == Covariant) return; - if (superPrim && superPrim->type == PrimitiveTypeVar::String && get(subSingleton) && variance == Covariant) + if (superPrim && superPrim->type == PrimitiveType::String && get(subSingleton) && variance == Covariant) return; reportError(location, TypeMismatch{superTy, subTy, mismatchContext()}); @@ -1521,8 +1564,8 @@ void Unifier::tryUnifySingletons(TypeId subTy, TypeId superTy) void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCall) { - FunctionTypeVar* superFunction = log.getMutable(superTy); - FunctionTypeVar* subFunction = log.getMutable(subTy); + FunctionType* superFunction = log.getMutable(superTy); + FunctionType* subFunction = log.getMutable(subTy); if (!superFunction || !subFunction) ice("passed non-function types to unifyFunction"); @@ -1542,7 +1585,7 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal std::optional instantiated = instantiation.substitute(subTy); if (instantiated.has_value()) { - subFunction = log.getMutable(*instantiated); + subFunction = log.getMutable(*instantiated); if (!subFunction) ice("instantiation made a function type into a non-function type in unifyFunction"); @@ -1626,8 +1669,8 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal } // Updating the log may have invalidated the function pointers - superFunction = log.getMutable(superTy); - subFunction = log.getMutable(subTy); + superFunction = log.getMutable(superTy); + subFunction = log.getMutable(subTy); ctx = context; @@ -1666,9 +1709,9 @@ struct Resetter void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) { - TableTypeVar* superTable = log.getMutable(superTy); - TableTypeVar* subTable = log.getMutable(subTy); - TableTypeVar* instantiatedSubTable = subTable; + TableType* superTable = log.getMutable(superTy); + TableType* subTable = log.getMutable(subTy); + TableType* instantiatedSubTable = subTable; if (!superTable || !subTable) ice("passed non-table types to unifyTables"); @@ -1685,7 +1728,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) std::optional instantiated = instantiation.substitute(subTy); if (instantiated.has_value()) { - subTable = log.getMutable(*instantiated); + subTable = log.getMutable(*instantiated); instantiatedSubTable = subTable; if (!subTable) @@ -1777,7 +1820,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) else if (subTable->state == TableState::Free) { PendingType* pendingSub = log.queue(subTy); - TableTypeVar* ttv = getMutable(pendingSub); + TableType* ttv = getMutable(pendingSub); LUAU_ASSERT(ttv); ttv->props[name] = prop; subTable = ttv; @@ -1799,8 +1842,8 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) } // Otherwise, restart only the table unification - TableTypeVar* newSuperTable = log.getMutable(superTyNew); - TableTypeVar* newSubTable = log.getMutable(subTyNew); + TableType* newSuperTable = log.getMutable(superTyNew); + TableType* newSubTable = log.getMutable(subTyNew); if (superTable != newSuperTable || (subTable != newSubTable && subTable != instantiatedSubTable)) { @@ -1842,7 +1885,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) clone.type = deeplyOptional(clone.type); PendingType* pendingSuper = log.queue(superTy); - TableTypeVar* pendingSuperTtv = getMutable(pendingSuper); + TableType* pendingSuperTtv = getMutable(pendingSuper); pendingSuperTtv->props[name] = clone; superTable = pendingSuperTtv; } @@ -1852,7 +1895,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) else if (superTable->state == TableState::Free) { PendingType* pendingSuper = log.queue(superTy); - TableTypeVar* pendingSuperTtv = getMutable(pendingSuper); + TableType* pendingSuperTtv = getMutable(pendingSuper); pendingSuperTtv->props[name] = prop; superTable = pendingSuperTtv; } @@ -1872,8 +1915,8 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) // Recursive unification can change the txn log, and invalidate the old // table. If we detect that this has happened, we start over, with the updated // txn log. - TableTypeVar* newSuperTable = log.getMutable(superTyNew); - TableTypeVar* newSubTable = log.getMutable(subTyNew); + TableType* newSuperTable = log.getMutable(superTyNew); + TableType* newSubTable = log.getMutable(subTyNew); if (superTable != newSuperTable || (subTable != newSubTable && subTable != instantiatedSubTable)) { if (errors.empty()) @@ -1929,16 +1972,16 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) // Changing the indexer can invalidate the table pointers. if (FFlag::LuauScalarShapeUnifyToMtOwner2) { - superTable = log.getMutable(log.follow(superTy)); - subTable = log.getMutable(log.follow(subTy)); + superTable = log.getMutable(log.follow(superTy)); + subTable = log.getMutable(log.follow(subTy)); if (!superTable || !subTable) return; } else { - superTable = log.getMutable(superTy); - subTable = log.getMutable(subTy); + superTable = log.getMutable(superTy); + subTable = log.getMutable(subTy); } if (!missingProperties.empty()) @@ -1954,7 +1997,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) } /* - * TypeVars are commonly cyclic, so it is entirely possible + * Types are commonly cyclic, so it is entirely possible * for unifying a property of a table to change the table itself! * We need to check for this and start over if we notice this occurring. * @@ -1987,7 +2030,7 @@ void Unifier::tryUnifyScalarShape(TypeId subTy, TypeId superTy, bool reversed) if (reversed) std::swap(subTy, superTy); - TableTypeVar* superTable = log.getMutable(superTy); + TableType* superTable = log.getMutable(superTy); if (!superTable || superTable->state != TableState::Free) return reportError(location, TypeMismatch{osuperTy, osubTy, mismatchContext()}); @@ -2002,9 +2045,9 @@ void Unifier::tryUnifyScalarShape(TypeId subTy, TypeId superTy, bool reversed) // Given t1 where t1 = { lower: (t1) -> (a, b...) } // It should be the case that `string <: t1` iff `(subtype's metatable).__index <: t1` - if (auto metatable = getMetatable(subTy, singletonTypes)) + if (auto metatable = getMetatable(subTy, builtinTypes)) { - auto mttv = log.get(*metatable); + auto mttv = log.get(*metatable); if (!mttv) fail(std::nullopt); @@ -2023,7 +2066,7 @@ void Unifier::tryUnifyScalarShape(TypeId subTy, TypeId superTy, bool reversed) if (superTy != newSuperTy && canUnify(subTy, newSuperTy).empty()) { - log.replace(superTy, BoundTypeVar{subTy}); + log.replace(superTy, BoundType{subTy}); return; } } @@ -2040,7 +2083,7 @@ void Unifier::tryUnifyScalarShape(TypeId subTy, TypeId superTy, bool reversed) // To perform subtype <: free table unification, we have tried to unify (subtype's metatable) <: free table // We return success because subtype <: free table which means that correct unification is to replace free table with the subtype if (child.errors.empty()) - log.replace(superTy, BoundTypeVar{subTy}); + log.replace(superTy, BoundType{subTy}); } return; @@ -2060,30 +2103,30 @@ TypeId Unifier::deeplyOptional(TypeId ty, std::unordered_map see ty = follow(ty); if (isOptional(ty)) return ty; - else if (const TableTypeVar* ttv = get(ty)) + else if (const TableType* ttv = get(ty)) { TypeId& result = seen[ty]; if (result) return result; result = types->addType(*ttv); - TableTypeVar* resultTtv = getMutable(result); + TableType* resultTtv = getMutable(result); for (auto& [name, prop] : resultTtv->props) prop.type = deeplyOptional(prop.type, seen); - return types->addType(UnionTypeVar{{singletonTypes->nilType, result}}); + return types->addType(UnionType{{builtinTypes->nilType, result}}); } else - return types->addType(UnionTypeVar{{singletonTypes->nilType, ty}}); + return types->addType(UnionType{{builtinTypes->nilType, ty}}); } void Unifier::tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed) { - const MetatableTypeVar* superMetatable = get(superTy); + const MetatableType* superMetatable = get(superTy); if (!superMetatable) - ice("tryUnifyMetatable invoked with non-metatable TypeVar"); + ice("tryUnifyMetatable invoked with non-metatable Type"); TypeError mismatchError = TypeError{location, TypeMismatch{reversed ? subTy : superTy, reversed ? superTy : subTy, mismatchContext()}}; - if (const MetatableTypeVar* subMetatable = log.getMutable(subTy)) + if (const MetatableType* subMetatable = log.getMutable(subTy)) { Unifier innerState = makeChildUnifier(); innerState.tryUnify_(subMetatable->table, superMetatable->table); @@ -2097,7 +2140,7 @@ void Unifier::tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed) log.concat(std::move(innerState.log)); } - else if (TableTypeVar* subTable = log.getMutable(subTy)) + else if (TableType* subTable = log.getMutable(subTy)) { switch (subTable->state) { @@ -2122,7 +2165,7 @@ void Unifier::tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed) } } - if (const TableTypeVar* superTable = log.get(log.follow(superMetatable->table))) + if (const TableType* superTable = log.get(log.follow(superMetatable->table))) { // TODO: Unify indexers. } @@ -2153,7 +2196,7 @@ void Unifier::tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed) reportError(mismatchError); } } - else if (log.getMutable(subTy) || log.getMutable(subTy)) + else if (log.getMutable(subTy) || log.getMutable(subTy)) { } else @@ -2175,11 +2218,11 @@ void Unifier::tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed) reportError(location, TypeMismatch{subTy, superTy, mismatchContext()}); }; - const ClassTypeVar* superClass = get(superTy); + const ClassType* superClass = get(superTy); if (!superClass) - ice("tryUnifyClass invoked with non-class TypeVar"); + ice("tryUnifyClass invoked with non-class Type"); - if (const ClassTypeVar* subClass = get(subTy)) + if (const ClassType* subClass = get(subTy)) { switch (variance) { @@ -2194,7 +2237,7 @@ void Unifier::tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed) } ice("Illegal variance setting!"); } - else if (TableTypeVar* subTable = getMutable(subTy)) + else if (TableType* subTable = getMutable(subTy)) { /** * A free table is something whose shape we do not exactly know yet. @@ -2255,7 +2298,7 @@ void Unifier::tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed) void Unifier::tryUnifyTypeWithNegation(TypeId subTy, TypeId superTy) { - const NegationTypeVar* ntv = get(superTy); + const NegationType* ntv = get(superTy); if (!ntv) ice("tryUnifyTypeWithNegation superTy must be a negation type"); @@ -2273,7 +2316,7 @@ void Unifier::tryUnifyTypeWithNegation(TypeId subTy, TypeId superTy) void Unifier::tryUnifyNegationWithType(TypeId subTy, TypeId superTy) { - const NegationTypeVar* ntv = get(subTy); + const NegationType* ntv = get(subTy); if (!ntv) ice("tryUnifyNegationWithType subTy must be a negation type"); @@ -2378,17 +2421,17 @@ static void tryUnifyWithAny(std::vector& queue, Unifier& state, DenseHas seen.insert(ty); - if (state.log.getMutable(ty)) + if (state.log.getMutable(ty)) { // TODO: Only bind if the anyType isn't any, unknown, or error (?) - state.log.replace(ty, BoundTypeVar{anyType}); + state.log.replace(ty, BoundType{anyType}); } - else if (auto fun = state.log.getMutable(ty)) + else if (auto fun = state.log.getMutable(ty)) { queueTypePack(queue, seenTypePacks, state, fun->argTypes, anyTypePack); queueTypePack(queue, seenTypePacks, state, fun->retTypes, anyTypePack); } - else if (auto table = state.log.getMutable(ty)) + else if (auto table = state.log.getMutable(ty)) { for (const auto& [_name, prop] : table->props) queue.push_back(prop.type); @@ -2399,18 +2442,18 @@ static void tryUnifyWithAny(std::vector& queue, Unifier& state, DenseHas queue.push_back(table->indexer->indexResultType); } } - else if (auto mt = state.log.getMutable(ty)) + else if (auto mt = state.log.getMutable(ty)) { queue.push_back(mt->table); queue.push_back(mt->metatable); } - else if (state.log.getMutable(ty)) + else if (state.log.getMutable(ty)) { - // ClassTypeVars never contain free typevars. + // ClassTypes never contain free types. } - else if (auto union_ = state.log.getMutable(ty)) + else if (auto union_ = state.log.getMutable(ty)) queue.insert(queue.end(), union_->options.begin(), union_->options.end()); - else if (auto intersection = state.log.getMutable(ty)) + else if (auto intersection = state.log.getMutable(ty)) queue.insert(queue.end(), intersection->parts.begin(), intersection->parts.end()); else { @@ -2420,10 +2463,10 @@ static void tryUnifyWithAny(std::vector& queue, Unifier& state, DenseHas void Unifier::tryUnifyWithAny(TypeId subTy, TypeId anyTy) { - LUAU_ASSERT(get(anyTy) || get(anyTy) || get(anyTy) || get(anyTy)); + LUAU_ASSERT(get(anyTy) || get(anyTy) || get(anyTy) || get(anyTy)); // These types are not visited in general loop below - if (get(subTy) || get(subTy) || get(subTy)) + if (get(subTy) || get(subTy) || get(subTy)) return; TypePackId anyTp; @@ -2431,8 +2474,8 @@ void Unifier::tryUnifyWithAny(TypeId subTy, TypeId anyTy) anyTp = types->addTypePack(TypePackVar{VariadicTypePack{anyTy}}); else { - const TypePackId anyTypePack = types->addTypePack(TypePackVar{VariadicTypePack{singletonTypes->anyType}}); - anyTp = get(anyTy) ? anyTypePack : types->addTypePack(TypePackVar{Unifiable::Error{}}); + const TypePackId anyTypePack = types->addTypePack(TypePackVar{VariadicTypePack{builtinTypes->anyType}}); + anyTp = get(anyTy) ? anyTypePack : types->addTypePack(TypePackVar{Unifiable::Error{}}); } std::vector queue = {subTy}; @@ -2441,14 +2484,14 @@ void Unifier::tryUnifyWithAny(TypeId subTy, TypeId anyTy) sharedState.tempSeenTp.clear(); Luau::tryUnifyWithAny( - queue, *this, sharedState.tempSeenTy, sharedState.tempSeenTp, types, FFlag::LuauUnknownAndNeverType ? anyTy : singletonTypes->anyType, anyTp); + queue, *this, sharedState.tempSeenTy, sharedState.tempSeenTp, types, FFlag::LuauUnknownAndNeverType ? anyTy : builtinTypes->anyType, anyTp); } void Unifier::tryUnifyWithAny(TypePackId subTy, TypePackId anyTp) { LUAU_ASSERT(get(anyTp)); - const TypeId anyTy = singletonTypes->errorRecoveryType(); + const TypeId anyTy = builtinTypes->errorRecoveryType(); std::vector queue; @@ -2462,7 +2505,7 @@ void Unifier::tryUnifyWithAny(TypePackId subTy, TypePackId anyTp) std::optional Unifier::findTablePropertyRespectingMeta(TypeId lhsType, Name name) { - return Luau::findTablePropertyRespectingMeta(singletonTypes, errors, lhsType, name, location); + return Luau::findTablePropertyRespectingMeta(builtinTypes, errors, lhsType, name, location); } TxnLog Unifier::combineLogsIntoIntersection(std::vector logs) @@ -2518,19 +2561,19 @@ bool Unifier::occursCheck(DenseHashSet& seen, TypeId needle, TypeId hays if (needle == haystack) { reportError(location, OccursCheckFailed{}); - log.replace(needle, *singletonTypes->errorRecoveryType()); + log.replace(needle, *builtinTypes->errorRecoveryType()); return true; } - if (log.getMutable(haystack)) + if (log.getMutable(haystack)) return false; - else if (auto a = log.getMutable(haystack)) + else if (auto a = log.getMutable(haystack)) { for (TypeId ty : a->options) check(ty); } - else if (auto a = log.getMutable(haystack)) + else if (auto a = log.getMutable(haystack)) { for (TypeId ty : a->parts) check(ty); @@ -2564,12 +2607,12 @@ bool Unifier::occursCheck(DenseHashSet& seen, TypePackId needle, Typ RecursionLimiter _ra(&sharedState.counters.recursionCount, sharedState.counters.recursionLimit); - while (!log.getMutable(haystack)) + while (!log.getMutable(haystack)) { if (needle == haystack) { reportError(location, OccursCheckFailed{}); - log.replace(needle, *singletonTypes->errorRecoveryTypePack()); + log.replace(needle, *builtinTypes->errorRecoveryTypePack()); return true; } diff --git a/Ast/src/Location.cpp b/Ast/src/Location.cpp index 67c2dd4b6..d01d8a186 100644 --- a/Ast/src/Location.cpp +++ b/Ast/src/Location.cpp @@ -5,8 +5,8 @@ namespace Luau { Position::Position(unsigned int line, unsigned int column) - : line(line) - , column(column) + : line(line) + , column(column) { } diff --git a/CodeGen/include/Luau/AssemblyBuilderX64.h b/CodeGen/include/Luau/AssemblyBuilderX64.h index 59f2c14ab..dbb366b23 100644 --- a/CodeGen/include/Luau/AssemblyBuilderX64.h +++ b/CodeGen/include/Luau/AssemblyBuilderX64.h @@ -107,6 +107,8 @@ class AssemblyBuilderX64 void vmulsd(OperandX64 dst, OperandX64 src1, OperandX64 src2); void vdivsd(OperandX64 dst, OperandX64 src1, OperandX64 src2); + void vandpd(OperandX64 dst, OperandX64 src1, OperandX64 src2); + void vxorpd(OperandX64 dst, OperandX64 src1, OperandX64 src2); void vucomisd(OperandX64 src1, OperandX64 src2); @@ -129,6 +131,11 @@ class AssemblyBuilderX64 void vmovaps(OperandX64 dst, OperandX64 src); void vmovupd(OperandX64 dst, OperandX64 src); void vmovups(OperandX64 dst, OperandX64 src); + void vmovq(OperandX64 lhs, OperandX64 rhs); + + void vmaxsd(OperandX64 dst, OperandX64 src1, OperandX64 src2); + void vminsd(OperandX64 dst, OperandX64 src1, OperandX64 src2); + // Run final checks void finalize(); diff --git a/CodeGen/src/AssemblyBuilderX64.cpp b/CodeGen/src/AssemblyBuilderX64.cpp index c17fab6bc..f23fe4634 100644 --- a/CodeGen/src/AssemblyBuilderX64.cpp +++ b/CodeGen/src/AssemblyBuilderX64.cpp @@ -614,6 +614,11 @@ void AssemblyBuilderX64::vdivsd(OperandX64 dst, OperandX64 src1, OperandX64 src2 placeAvx("vdivsd", dst, src1, src2, 0x5e, false, AVX_0F, AVX_F2); } +void AssemblyBuilderX64::vandpd(OperandX64 dst, OperandX64 src1, OperandX64 src2) +{ + placeAvx("vandpd", dst, src1, src2, 0x54, false, AVX_0F, AVX_66); +} + void AssemblyBuilderX64::vxorpd(OperandX64 dst, OperandX64 src1, OperandX64 src2) { placeAvx("vxorpd", dst, src1, src2, 0x57, false, AVX_0F, AVX_66); @@ -699,6 +704,36 @@ void AssemblyBuilderX64::vmovups(OperandX64 dst, OperandX64 src) placeAvx("vmovups", dst, src, 0x10, 0x11, false, AVX_0F, AVX_NP); } +void AssemblyBuilderX64::vmovq(OperandX64 dst, OperandX64 src) +{ + if (dst.base.size == SizeX64::xmmword) + { + LUAU_ASSERT(dst.cat == CategoryX64::reg); + LUAU_ASSERT(src.base.size == SizeX64::qword); + placeAvx("vmovq", dst, src, 0x6e, true, AVX_0F, AVX_66); + } + else if (dst.base.size == SizeX64::qword) + { + LUAU_ASSERT(src.cat == CategoryX64::reg); + LUAU_ASSERT(src.base.size == SizeX64::xmmword); + placeAvx("vmovq", src, dst, 0x7e, true, AVX_0F, AVX_66); + } + else + { + LUAU_ASSERT(!"No encoding for left operand of this category"); + } +} + +void AssemblyBuilderX64::vmaxsd(OperandX64 dst, OperandX64 src1, OperandX64 src2) +{ + placeAvx("vmaxsd", dst, src1, src2, 0x5f, false, AVX_0F, AVX_F2); +} + +void AssemblyBuilderX64::vminsd(OperandX64 dst, OperandX64 src1, OperandX64 src2) +{ + placeAvx("vminsd", dst, src1, src2, 0x5d, false, AVX_0F, AVX_F2); +} + void AssemblyBuilderX64::finalize() { code.resize(codePos - code.data()); diff --git a/CodeGen/src/EmitBuiltinsX64.cpp b/CodeGen/src/EmitBuiltinsX64.cpp index 04bd43aaf..41a2c260f 100644 --- a/CodeGen/src/EmitBuiltinsX64.cpp +++ b/CodeGen/src/EmitBuiltinsX64.cpp @@ -5,6 +5,7 @@ #include "Luau/Bytecode.h" #include "EmitCommonX64.h" +#include "NativeState.h" #include "lstate.h" @@ -88,6 +89,341 @@ BuiltinImplResult emitBuiltinMathSqrt(AssemblyBuilderX64& build, int nparams, in return {BuiltinImplType::UsesFallback, 1}; } +BuiltinImplResult emitBuiltinMathAbs(AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback) +{ + if (nparams < 1 || nresults > 1) + return {BuiltinImplType::None, -1}; + + if (build.logText) + build.logAppend("; inlined LBF_MATH_ABS\n"); + + jumpIfTagIsNot(build, arg, LUA_TNUMBER, fallback); + + build.vmovsd(xmm0, luauRegValue(arg)); + build.vandpd(xmm0, xmm0, build.i64(~(1LL << 63))); + build.vmovsd(luauRegValue(ra), xmm0); + + if (ra != arg) + build.mov(luauRegTag(ra), LUA_TNUMBER); + + return {BuiltinImplType::UsesFallback, 1}; +} + +static BuiltinImplResult emitBuiltinMathSingleArgFunc( + AssemblyBuilderX64& build, int nparams, int ra, int arg, int nresults, Label& fallback, const char* name, int32_t offset) +{ + if (nparams < 1 || nresults > 1) + return {BuiltinImplType::None, -1}; + + if (build.logText) + build.logAppend("; inlined %s\n", name); + + jumpIfTagIsNot(build, arg, LUA_TNUMBER, fallback); + + build.vmovsd(xmm0, luauRegValue(arg)); + build.call(qword[rNativeContext + offset]); + + build.vmovsd(luauRegValue(ra), xmm0); + + if (ra != arg) + build.mov(luauRegTag(ra), LUA_TNUMBER); + + return {BuiltinImplType::UsesFallback, 1}; +} + +BuiltinImplResult emitBuiltinMathExp(AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback) +{ + return emitBuiltinMathSingleArgFunc(build, nparams, ra, arg, nresults, fallback, "LBF_MATH_EXP", offsetof(NativeContext, libm_exp)); +} + +BuiltinImplResult emitBuiltinMathDeg(AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback) +{ + if (nparams < 1 || nresults > 1) + return {BuiltinImplType::None, -1}; + + if (build.logText) + build.logAppend("; inlined LBF_MATH_DEG\n"); + + jumpIfTagIsNot(build, arg, LUA_TNUMBER, fallback); + + const double rpd = (3.14159265358979323846 / 180.0); + + build.vmovsd(xmm0, luauRegValue(arg)); + build.vdivsd(xmm0, xmm0, build.f64(rpd)); + build.vmovsd(luauRegValue(ra), xmm0); + + if (ra != arg) + build.mov(luauRegTag(ra), LUA_TNUMBER); + + return {BuiltinImplType::UsesFallback, 1}; +} + +BuiltinImplResult emitBuiltinMathRad(AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback) +{ + if (nparams < 1 || nresults > 1) + return {BuiltinImplType::None, -1}; + + if (build.logText) + build.logAppend("; inlined LBF_MATH_RAD\n"); + + jumpIfTagIsNot(build, arg, LUA_TNUMBER, fallback); + + const double rpd = (3.14159265358979323846 / 180.0); + + build.vmovsd(xmm0, luauRegValue(arg)); + build.vmulsd(xmm0, xmm0, build.f64(rpd)); + build.vmovsd(luauRegValue(ra), xmm0); + + if (ra != arg) + build.mov(luauRegTag(ra), LUA_TNUMBER); + + return {BuiltinImplType::UsesFallback, 1}; +} + +BuiltinImplResult emitBuiltinMathFmod(AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback) +{ + if (nparams < 2 || nresults > 1) + return {BuiltinImplType::None, -1}; + + if (build.logText) + build.logAppend("; inlined LBF_MATH_FMOD\n"); + + jumpIfTagIsNot(build, arg, LUA_TNUMBER, fallback); + + // TODO: jumpIfTagIsNot can be generalized to take OperandX64 and then we can use it here; let's wait until we see this more though + build.cmp(dword[args + offsetof(TValue, tt)], LUA_TNUMBER); + build.jcc(ConditionX64::NotEqual, fallback); + + build.vmovsd(xmm0, luauRegValue(arg)); + build.vmovsd(xmm1, qword[args + offsetof(TValue, value)]); + build.call(qword[rNativeContext + offsetof(NativeContext, libm_fmod)]); + + build.vmovsd(luauRegValue(ra), xmm0); + + if (ra != arg) + build.mov(luauRegTag(ra), LUA_TNUMBER); + + return {BuiltinImplType::UsesFallback, 1}; +} + +BuiltinImplResult emitBuiltinMathPow(AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback) +{ + if (nparams < 2 || nresults > 1) + return {BuiltinImplType::None, -1}; + + if (build.logText) + build.logAppend("; inlined LBF_MATH_POW\n"); + + jumpIfTagIsNot(build, arg, LUA_TNUMBER, fallback); + + // TODO: jumpIfTagIsNot can be generalized to take OperandX64 and then we can use it here; let's wait until we see this more though + build.cmp(dword[args + offsetof(TValue, tt)], LUA_TNUMBER); + build.jcc(ConditionX64::NotEqual, fallback); + + build.vmovsd(xmm0, luauRegValue(arg)); + build.vmovsd(xmm1, qword[args + offsetof(TValue, value)]); + build.call(qword[rNativeContext + offsetof(NativeContext, libm_pow)]); + + build.vmovsd(luauRegValue(ra), xmm0); + + if (ra != arg) + build.mov(luauRegTag(ra), LUA_TNUMBER); + + return {BuiltinImplType::UsesFallback, 1}; +} + +BuiltinImplResult emitBuiltinMathMin(AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback) +{ + if (nparams != 2 || nresults > 1) + return {BuiltinImplType::None, -1}; + + if (build.logText) + build.logAppend("; inlined LBF_MATH_MIN\n"); + + jumpIfTagIsNot(build, arg, LUA_TNUMBER, fallback); + + // TODO: jumpIfTagIsNot can be generalized to take OperandX64 and then we can use it here; let's wait until we see this more though + build.cmp(dword[args + offsetof(TValue, tt)], LUA_TNUMBER); + build.jcc(ConditionX64::NotEqual, fallback); + + build.vmovsd(xmm0, qword[args + offsetof(TValue, value)]); + build.vminsd(xmm0, xmm0, luauRegValue(arg)); + + build.vmovsd(luauRegValue(ra), xmm0); + + if (ra != arg) + build.mov(luauRegTag(ra), LUA_TNUMBER); + + return {BuiltinImplType::UsesFallback, 1}; +} + +BuiltinImplResult emitBuiltinMathMax(AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback) +{ + if (nparams != 2 || nresults > 1) + return {BuiltinImplType::None, -1}; + + if (build.logText) + build.logAppend("; inlined LBF_MATH_MAX\n"); + + jumpIfTagIsNot(build, arg, LUA_TNUMBER, fallback); + + // TODO: jumpIfTagIsNot can be generalized to take OperandX64 and then we can use it here; let's wait until we see this more though + build.cmp(dword[args + offsetof(TValue, tt)], LUA_TNUMBER); + build.jcc(ConditionX64::NotEqual, fallback); + + build.vmovsd(xmm0, qword[args + offsetof(TValue, value)]); + build.vmaxsd(xmm0, xmm0, luauRegValue(arg)); + + build.vmovsd(luauRegValue(ra), xmm0); + + if (ra != arg) + build.mov(luauRegTag(ra), LUA_TNUMBER); + + return {BuiltinImplType::UsesFallback, 1}; +} + +BuiltinImplResult emitBuiltinMathAsin(AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback) +{ + return emitBuiltinMathSingleArgFunc(build, nparams, ra, arg, nresults, fallback, "LBF_MATH_ASIN", offsetof(NativeContext, libm_asin)); +} + +BuiltinImplResult emitBuiltinMathSin(AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback) +{ + return emitBuiltinMathSingleArgFunc(build, nparams, ra, arg, nresults, fallback, "LBF_MATH_SIN", offsetof(NativeContext, libm_sin)); +} + +BuiltinImplResult emitBuiltinMathSinh(AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback) +{ + return emitBuiltinMathSingleArgFunc(build, nparams, ra, arg, nresults, fallback, "LBF_MATH_SINH", offsetof(NativeContext, libm_sinh)); +} + +BuiltinImplResult emitBuiltinMathAcos(AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback) +{ + return emitBuiltinMathSingleArgFunc(build, nparams, ra, arg, nresults, fallback, "LBF_MATH_ACOS", offsetof(NativeContext, libm_acos)); +} + +BuiltinImplResult emitBuiltinMathCos(AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback) +{ + return emitBuiltinMathSingleArgFunc(build, nparams, ra, arg, nresults, fallback, "LBF_MATH_COS", offsetof(NativeContext, libm_cos)); +} + +BuiltinImplResult emitBuiltinMathCosh(AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback) +{ + return emitBuiltinMathSingleArgFunc(build, nparams, ra, arg, nresults, fallback, "LBF_MATH_COSH", offsetof(NativeContext, libm_cosh)); +} + +BuiltinImplResult emitBuiltinMathAtan(AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback) +{ + return emitBuiltinMathSingleArgFunc(build, nparams, ra, arg, nresults, fallback, "LBF_MATH_ATAN", offsetof(NativeContext, libm_atan)); +} + +BuiltinImplResult emitBuiltinMathTan(AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback) +{ + return emitBuiltinMathSingleArgFunc(build, nparams, ra, arg, nresults, fallback, "LBF_MATH_TAN", offsetof(NativeContext, libm_tan)); +} + +BuiltinImplResult emitBuiltinMathTanh(AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback) +{ + return emitBuiltinMathSingleArgFunc(build, nparams, ra, arg, nresults, fallback, "LBF_MATH_TANH", offsetof(NativeContext, libm_tanh)); +} + +BuiltinImplResult emitBuiltinMathAtan2(AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback) +{ + if (nparams < 2 || nresults > 1) + return {BuiltinImplType::None, -1}; + + if (build.logText) + build.logAppend("; inlined LBF_MATH_ATAN2\n"); + + jumpIfTagIsNot(build, arg, LUA_TNUMBER, fallback); + + // TODO: jumpIfTagIsNot can be generalized to take OperandX64 and then we can use it here; let's wait until we see this more though + build.cmp(dword[args + offsetof(TValue, tt)], LUA_TNUMBER); + build.jcc(ConditionX64::NotEqual, fallback); + + build.vmovsd(xmm0, luauRegValue(arg)); + build.vmovsd(xmm1, qword[args + offsetof(TValue, value)]); + build.call(qword[rNativeContext + offsetof(NativeContext, libm_atan2)]); + + build.vmovsd(luauRegValue(ra), xmm0); + + if (ra != arg) + build.mov(luauRegTag(ra), LUA_TNUMBER); + + return {BuiltinImplType::UsesFallback, 1}; +} + +BuiltinImplResult emitBuiltinMathLog10(AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback) +{ + return emitBuiltinMathSingleArgFunc(build, nparams, ra, arg, nresults, fallback, "LBF_MATH_LOG10", offsetof(NativeContext, libm_log10)); +} + +BuiltinImplResult emitBuiltinMathLog(AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback) +{ + if (nparams < 1 || nresults > 1) + return {BuiltinImplType::None, -1}; + + if (build.logText) + build.logAppend("; inlined LBF_MATH_LOG\n"); + + jumpIfTagIsNot(build, arg, LUA_TNUMBER, fallback); + + build.vmovsd(xmm0, luauRegValue(arg)); + + if (nparams == 1) + { + build.call(qword[rNativeContext + offsetof(NativeContext, libm_log)]); + } + else + { + Label log10check, logdivlog, exit; + + // Using 'rbx' for non-volatile temporary storage of log(arg1) result + RegisterX64 tmp = rbx; + OperandX64 arg2value = qword[args + offsetof(TValue, value)]; + + // TODO: jumpIfTagIsNot can be generalized to take OperandX64 and then we can use it here; let's wait until we see this more though + build.cmp(dword[args + offsetof(TValue, tt)], LUA_TNUMBER); + build.jcc(ConditionX64::NotEqual, fallback); + + build.vmovsd(xmm1, arg2value); + + jumpOnNumberCmp(build, noreg, build.f64(2.0), xmm1, ConditionX64::NotEqual, log10check); + + build.call(qword[rNativeContext + offsetof(NativeContext, libm_log2)]); + build.jmp(exit); + + build.setLabel(log10check); + jumpOnNumberCmp(build, noreg, build.f64(10.0), xmm1, ConditionX64::NotEqual, logdivlog); + + build.call(qword[rNativeContext + offsetof(NativeContext, libm_log10)]); + build.jmp(exit); + + build.setLabel(logdivlog); + + // log(arg1) + build.call(qword[rNativeContext + offsetof(NativeContext, libm_log)]); + build.vmovq(tmp, xmm0); + + // log(arg2) + build.vmovsd(xmm0, arg2value); + build.call(qword[rNativeContext + offsetof(NativeContext, libm_log)]); + + // log(arg1) / log(arg2) + build.vmovq(xmm1, tmp); + build.vdivsd(xmm0, xmm1, xmm0); + + build.setLabel(exit); + } + + build.vmovsd(luauRegValue(ra), xmm0); + + if (ra != arg) + build.mov(luauRegTag(ra), LUA_TNUMBER); + + return {BuiltinImplType::UsesFallback, 1}; +} + BuiltinImplResult emitBuiltin(AssemblyBuilderX64& build, int bfid, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback) { switch (bfid) @@ -100,6 +436,46 @@ BuiltinImplResult emitBuiltin(AssemblyBuilderX64& build, int bfid, int nparams, return emitBuiltinMathCeil(build, nparams, ra, arg, args, nresults, fallback); case LBF_MATH_SQRT: return emitBuiltinMathSqrt(build, nparams, ra, arg, args, nresults, fallback); + case LBF_MATH_ABS: + return emitBuiltinMathAbs(build, nparams, ra, arg, args, nresults, fallback); + case LBF_MATH_EXP: + return emitBuiltinMathExp(build, nparams, ra, arg, args, nresults, fallback); + case LBF_MATH_DEG: + return emitBuiltinMathDeg(build, nparams, ra, arg, args, nresults, fallback); + case LBF_MATH_RAD: + return emitBuiltinMathRad(build, nparams, ra, arg, args, nresults, fallback); + case LBF_MATH_FMOD: + return emitBuiltinMathFmod(build, nparams, ra, arg, args, nresults, fallback); + case LBF_MATH_POW: + return emitBuiltinMathPow(build, nparams, ra, arg, args, nresults, fallback); + case LBF_MATH_MIN: + return emitBuiltinMathMin(build, nparams, ra, arg, args, nresults, fallback); + case LBF_MATH_MAX: + return emitBuiltinMathMax(build, nparams, ra, arg, args, nresults, fallback); + case LBF_MATH_ASIN: + return emitBuiltinMathAsin(build, nparams, ra, arg, args, nresults, fallback); + case LBF_MATH_SIN: + return emitBuiltinMathSin(build, nparams, ra, arg, args, nresults, fallback); + case LBF_MATH_SINH: + return emitBuiltinMathSinh(build, nparams, ra, arg, args, nresults, fallback); + case LBF_MATH_ACOS: + return emitBuiltinMathAcos(build, nparams, ra, arg, args, nresults, fallback); + case LBF_MATH_COS: + return emitBuiltinMathCos(build, nparams, ra, arg, args, nresults, fallback); + case LBF_MATH_COSH: + return emitBuiltinMathCosh(build, nparams, ra, arg, args, nresults, fallback); + case LBF_MATH_ATAN: + return emitBuiltinMathAtan(build, nparams, ra, arg, args, nresults, fallback); + case LBF_MATH_TAN: + return emitBuiltinMathTan(build, nparams, ra, arg, args, nresults, fallback); + case LBF_MATH_TANH: + return emitBuiltinMathTanh(build, nparams, ra, arg, args, nresults, fallback); + case LBF_MATH_ATAN2: + return emitBuiltinMathAtan2(build, nparams, ra, arg, args, nresults, fallback); + case LBF_MATH_LOG10: + return emitBuiltinMathLog10(build, nparams, ra, arg, args, nresults, fallback); + case LBF_MATH_LOG: + return emitBuiltinMathLog(build, nparams, ra, arg, args, nresults, fallback); default: return {BuiltinImplType::None, -1}; } diff --git a/CodeGen/src/Fallbacks.cpp b/CodeGen/src/Fallbacks.cpp index a458c4eb0..41f2bc8c8 100644 --- a/CodeGen/src/Fallbacks.cpp +++ b/CodeGen/src/Fallbacks.cpp @@ -612,4 +612,3 @@ const Instruction* execute_LOP_BREAK(lua_State* L, const Instruction* pc, StkId LUAU_ASSERT(!"Unsupported deprecated opcode"); LUAU_UNREACHABLE(); } - diff --git a/CodeGen/src/NativeState.cpp b/CodeGen/src/NativeState.cpp index 542203755..62974fe3e 100644 --- a/CodeGen/src/NativeState.cpp +++ b/CodeGen/src/NativeState.cpp @@ -80,7 +80,23 @@ void initHelperFunctions(NativeState& data) data.context.luaF_close = luaF_close; + data.context.libm_exp = exp; data.context.libm_pow = pow; + data.context.libm_fmod = fmod; + data.context.libm_log = log; + data.context.libm_log2 = log2; + data.context.libm_log10 = log10; + + data.context.libm_asin = asin; + data.context.libm_sin = sin; + data.context.libm_sinh = sinh; + data.context.libm_acos = acos; + data.context.libm_cos = cos; + data.context.libm_cosh = cosh; + data.context.libm_atan = atan; + data.context.libm_atan2 = atan2; + data.context.libm_tan = tan; + data.context.libm_tanh = tanh; data.context.forgLoopNodeIter = forgLoopNodeIter; data.context.forgLoopNonTableFallback = forgLoopNonTableFallback; diff --git a/CodeGen/src/NativeState.h b/CodeGen/src/NativeState.h index e5a244162..9138ba472 100644 --- a/CodeGen/src/NativeState.h +++ b/CodeGen/src/NativeState.h @@ -78,7 +78,22 @@ struct NativeContext void (*luaF_close)(lua_State* L, StkId level) = nullptr; + double (*libm_exp)(double) = nullptr; double (*libm_pow)(double, double) = nullptr; + double (*libm_fmod)(double, double) = nullptr; + double (*libm_asin)(double) = nullptr; + double (*libm_sin)(double) = nullptr; + double (*libm_sinh)(double) = nullptr; + double (*libm_acos)(double) = nullptr; + double (*libm_cos)(double) = nullptr; + double (*libm_cosh)(double) = nullptr; + double (*libm_atan)(double) = nullptr; + double (*libm_atan2)(double, double) = nullptr; + double (*libm_tan)(double) = nullptr; + double (*libm_tanh)(double) = nullptr; + double (*libm_log)(double) = nullptr; + double (*libm_log2)(double) = nullptr; + double (*libm_log10)(double) = nullptr; // Helper functions bool (*forgLoopNodeIter)(lua_State* L, Table* h, int index, TValue* ra) = nullptr; diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index 5d6723669..b4dea4f56 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -350,7 +350,7 @@ struct Compiler { LUAU_ASSERT(!multRet || unsigned(target + targetCount) == regTop); - setDebugLine(expr); // normally compileExpr sets up line info, but compileExprCall can be called directly + setDebugLine(expr); // normally compileExpr sets up line info, but compileExprVarargs can be called directly bytecode.emitABC(LOP_GETVARARGS, target, multRet ? 0 : uint8_t(targetCount + 1), 0); } diff --git a/Sources.cmake b/Sources.cmake index e243ea74d..437ff9934 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -147,12 +147,12 @@ target_sources(Luau.Analysis PRIVATE Analysis/include/Luau/TypeInfer.h Analysis/include/Luau/TypePack.h Analysis/include/Luau/TypeUtils.h - Analysis/include/Luau/TypeVar.h + Analysis/include/Luau/Type.h Analysis/include/Luau/Unifiable.h Analysis/include/Luau/Unifier.h Analysis/include/Luau/UnifierSharedState.h Analysis/include/Luau/Variant.h - Analysis/include/Luau/VisitTypeVar.h + Analysis/include/Luau/VisitType.h Analysis/src/Anyification.cpp Analysis/src/ApplyTypeFunction.cpp @@ -196,7 +196,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/src/TypeInfer.cpp Analysis/src/TypePack.cpp Analysis/src/TypeUtils.cpp - Analysis/src/TypeVar.cpp + Analysis/src/Type.cpp Analysis/src/Unifiable.cpp Analysis/src/Unifier.cpp ) @@ -366,7 +366,7 @@ if(TARGET Luau.UnitTest) tests/TypePack.test.cpp tests/TypeVar.test.cpp tests/Variant.test.cpp - tests/VisitTypeVar.test.cpp + tests/VisitType.test.cpp tests/main.cpp) endif() diff --git a/fuzz/proto.cpp b/fuzz/proto.cpp index 9e0a68296..c8ba22681 100644 --- a/fuzz/proto.cpp +++ b/fuzz/proto.cpp @@ -101,32 +101,32 @@ int registerTypes(Luau::TypeChecker& env) TypeArena& arena = env.globalTypes; // Vector3 stub - TypeId vector3MetaType = arena.addType(TableTypeVar{}); + TypeId vector3MetaType = arena.addType(TableType{}); - TypeId vector3InstanceType = arena.addType(ClassTypeVar{"Vector3", {}, nullopt, vector3MetaType, {}, {}, "Test"}); - getMutable(vector3InstanceType)->props = { + TypeId vector3InstanceType = arena.addType(ClassType{"Vector3", {}, nullopt, vector3MetaType, {}, {}, "Test"}); + getMutable(vector3InstanceType)->props = { {"X", {env.numberType}}, {"Y", {env.numberType}}, {"Z", {env.numberType}}, }; - getMutable(vector3MetaType)->props = { + getMutable(vector3MetaType)->props = { {"__add", {makeFunction(arena, nullopt, {vector3InstanceType, vector3InstanceType}, {vector3InstanceType})}}, }; env.globalScope->exportedTypeBindings["Vector3"] = TypeFun{{}, vector3InstanceType}; // Instance stub - TypeId instanceType = arena.addType(ClassTypeVar{"Instance", {}, nullopt, nullopt, {}, {}, "Test"}); - getMutable(instanceType)->props = { + TypeId instanceType = arena.addType(ClassType{"Instance", {}, nullopt, nullopt, {}, {}, "Test"}); + getMutable(instanceType)->props = { {"Name", {env.stringType}}, }; env.globalScope->exportedTypeBindings["Instance"] = TypeFun{{}, instanceType}; // Part stub - TypeId partType = arena.addType(ClassTypeVar{"Part", {}, instanceType, nullopt, {}, {}, "Test"}); - getMutable(partType)->props = { + TypeId partType = arena.addType(ClassType{"Part", {}, instanceType, nullopt, {}, {}, "Test"}); + getMutable(partType)->props = { {"Position", {vector3InstanceType}}, }; diff --git a/tests/AssemblyBuilderX64.test.cpp b/tests/AssemblyBuilderX64.test.cpp index 1af03e4f5..b5dbf583b 100644 --- a/tests/AssemblyBuilderX64.test.cpp +++ b/tests/AssemblyBuilderX64.test.cpp @@ -436,6 +436,11 @@ TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "AVXBinaryInstructionForms") SINGLE_COMPARE(vdivsd(xmm8, xmm10, xmm14), 0xc4, 0x41, 0x2b, 0x5e, 0xc6); SINGLE_COMPARE(vxorpd(xmm8, xmm10, xmm14), 0xc4, 0x41, 0x29, 0x57, 0xc6); + + SINGLE_COMPARE(vandpd(xmm8, xmm10, xmm14), 0xc4, 0x41, 0x29, 0x54, 0xc6); + + SINGLE_COMPARE(vmaxsd(xmm8, xmm10, xmm14), 0xc4, 0x41, 0x2b, 0x5f, 0xc6); + SINGLE_COMPARE(vminsd(xmm8, xmm10, xmm14), 0xc4, 0x41, 0x2b, 0x5d, 0xc6); } TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "AVXUnaryMergeInstructionForms") @@ -475,6 +480,10 @@ TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "AVXMoveInstructionForms") SINGLE_COMPARE(vmovups(xmm8, xmmword[r9]), 0xc4, 0x41, 0x78, 0x10, 0x01); SINGLE_COMPARE(vmovups(xmmword[r9], xmm10), 0xc4, 0x41, 0x78, 0x11, 0x11); SINGLE_COMPARE(vmovups(ymm8, ymmword[r9]), 0xc4, 0x41, 0x7c, 0x10, 0x01); + SINGLE_COMPARE(vmovq(xmm1, rbx), 0xc4, 0xe1, 0xf9, 0x6e, 0xcb); + SINGLE_COMPARE(vmovq(rbx, xmm1), 0xc4, 0xe1, 0xf9, 0x7e, 0xcb); + SINGLE_COMPARE(vmovq(xmm1, qword[r9]), 0xc4, 0xc1, 0xf9, 0x6e, 0x09); + SINGLE_COMPARE(vmovq(qword[r9], xmm1), 0xc4, 0xc1, 0xf9, 0x7e, 0x09); } TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "AVXConversionInstructionForms") diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index 123708cab..105829473 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -2,8 +2,8 @@ #include "Luau/Autocomplete.h" #include "Luau/BuiltinDefinitions.h" #include "Luau/TypeInfer.h" -#include "Luau/TypeVar.h" -#include "Luau/VisitTypeVar.h" +#include "Luau/Type.h" +#include "Luau/VisitType.h" #include "Luau/StringUtils.h" #include "Fixture.h" @@ -18,7 +18,7 @@ LUAU_FASTFLAG(LuauSetMetatableDoesNotTimeTravel) using namespace Luau; -static std::optional nullCallback(std::string tag, std::optional ptr) +static std::optional nullCallback(std::string tag, std::optional ptr) { return std::nullopt; } @@ -499,7 +499,7 @@ TEST_CASE_FIXTURE(ACFixture, "bias_toward_inner_scope") CHECK_EQ(ac.context, AutocompleteContext::Statement); TypeId t = follow(*ac.entryMap["A"].type); - const TableTypeVar* tt = get(t); + const TableType* tt = get(t); REQUIRE(tt); CHECK(tt->props.count("two")); @@ -1244,7 +1244,7 @@ end REQUIRE(ac.entryMap.count("Table")); REQUIRE(ac.entryMap["Table"].type); - const TableTypeVar* tv = get(follow(*ac.entryMap["Table"].type)); + const TableType* tv = get(follow(*ac.entryMap["Table"].type)); REQUIRE(tv); CHECK(tv->props.count("x")); } diff --git a/tests/BuiltinDefinitions.test.cpp b/tests/BuiltinDefinitions.test.cpp index 496df4b42..188f2190f 100644 --- a/tests/BuiltinDefinitions.test.cpp +++ b/tests/BuiltinDefinitions.test.cpp @@ -1,6 +1,6 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/BuiltinDefinitions.h" -#include "Luau/TypeVar.h" +#include "Luau/Type.h" #include "Fixture.h" @@ -22,12 +22,12 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "lib_documentation_symbols") CHECK_MESSAGE( actualRootSymbol == expectedRootSymbol, "expected symbol ", expectedRootSymbol, " for global ", nameString, ", got ", actualRootSymbol); - const TableTypeVar::Props* props = nullptr; - if (const TableTypeVar* ttv = get(binding.typeId)) + const TableType::Props* props = nullptr; + if (const TableType* ttv = get(binding.typeId)) { props = &ttv->props; } - else if (const ClassTypeVar* ctv = get(binding.typeId)) + else if (const ClassType* ctv = get(binding.typeId)) { props = &ctv->props; } diff --git a/tests/ClassFixture.cpp b/tests/ClassFixture.cpp index 18939e24d..087b88d53 100644 --- a/tests/ClassFixture.cpp +++ b/tests/ClassFixture.cpp @@ -16,14 +16,14 @@ ClassFixture::ClassFixture() unfreeze(arena); - TypeId baseClassInstanceType = arena.addType(ClassTypeVar{"BaseClass", {}, nullopt, nullopt, {}, {}, "Test"}); - getMutable(baseClassInstanceType)->props = { + TypeId baseClassInstanceType = arena.addType(ClassType{"BaseClass", {}, nullopt, nullopt, {}, {}, "Test"}); + getMutable(baseClassInstanceType)->props = { {"BaseMethod", {makeFunction(arena, baseClassInstanceType, {numberType}, {})}}, {"BaseField", {numberType}}, }; - TypeId baseClassType = arena.addType(ClassTypeVar{"BaseClass", {}, nullopt, nullopt, {}, {}, "Test"}); - getMutable(baseClassType)->props = { + TypeId baseClassType = arena.addType(ClassType{"BaseClass", {}, nullopt, nullopt, {}, {}, "Test"}); + getMutable(baseClassType)->props = { {"StaticMethod", {makeFunction(arena, nullopt, {}, {numberType})}}, {"Clone", {makeFunction(arena, nullopt, {baseClassInstanceType}, {baseClassInstanceType})}}, {"New", {makeFunction(arena, nullopt, {}, {baseClassInstanceType})}}, @@ -31,75 +31,75 @@ ClassFixture::ClassFixture() typeChecker.globalScope->exportedTypeBindings["BaseClass"] = TypeFun{{}, baseClassInstanceType}; addGlobalBinding(frontend, "BaseClass", baseClassType, "@test"); - TypeId childClassInstanceType = arena.addType(ClassTypeVar{"ChildClass", {}, baseClassInstanceType, nullopt, {}, {}, "Test"}); + TypeId childClassInstanceType = arena.addType(ClassType{"ChildClass", {}, baseClassInstanceType, nullopt, {}, {}, "Test"}); - getMutable(childClassInstanceType)->props = { + getMutable(childClassInstanceType)->props = { {"Method", {makeFunction(arena, childClassInstanceType, {}, {typeChecker.stringType})}}, }; - TypeId childClassType = arena.addType(ClassTypeVar{"ChildClass", {}, baseClassType, nullopt, {}, {}, "Test"}); - getMutable(childClassType)->props = { + TypeId childClassType = arena.addType(ClassType{"ChildClass", {}, baseClassType, nullopt, {}, {}, "Test"}); + getMutable(childClassType)->props = { {"New", {makeFunction(arena, nullopt, {}, {childClassInstanceType})}}, }; typeChecker.globalScope->exportedTypeBindings["ChildClass"] = TypeFun{{}, childClassInstanceType}; addGlobalBinding(frontend, "ChildClass", childClassType, "@test"); - TypeId grandChildInstanceType = arena.addType(ClassTypeVar{"GrandChild", {}, childClassInstanceType, nullopt, {}, {}, "Test"}); + TypeId grandChildInstanceType = arena.addType(ClassType{"GrandChild", {}, childClassInstanceType, nullopt, {}, {}, "Test"}); - getMutable(grandChildInstanceType)->props = { + getMutable(grandChildInstanceType)->props = { {"Method", {makeFunction(arena, grandChildInstanceType, {}, {typeChecker.stringType})}}, }; - TypeId grandChildType = arena.addType(ClassTypeVar{"GrandChild", {}, baseClassType, nullopt, {}, {}, "Test"}); - getMutable(grandChildType)->props = { + TypeId grandChildType = arena.addType(ClassType{"GrandChild", {}, baseClassType, nullopt, {}, {}, "Test"}); + getMutable(grandChildType)->props = { {"New", {makeFunction(arena, nullopt, {}, {grandChildInstanceType})}}, }; typeChecker.globalScope->exportedTypeBindings["GrandChild"] = TypeFun{{}, grandChildInstanceType}; addGlobalBinding(frontend, "GrandChild", childClassType, "@test"); - TypeId anotherChildInstanceType = arena.addType(ClassTypeVar{"AnotherChild", {}, baseClassInstanceType, nullopt, {}, {}, "Test"}); + TypeId anotherChildInstanceType = arena.addType(ClassType{"AnotherChild", {}, baseClassInstanceType, nullopt, {}, {}, "Test"}); - getMutable(anotherChildInstanceType)->props = { + getMutable(anotherChildInstanceType)->props = { {"Method", {makeFunction(arena, anotherChildInstanceType, {}, {typeChecker.stringType})}}, }; - TypeId anotherChildType = arena.addType(ClassTypeVar{"AnotherChild", {}, baseClassType, nullopt, {}, {}, "Test"}); - getMutable(anotherChildType)->props = { + TypeId anotherChildType = arena.addType(ClassType{"AnotherChild", {}, baseClassType, nullopt, {}, {}, "Test"}); + getMutable(anotherChildType)->props = { {"New", {makeFunction(arena, nullopt, {}, {anotherChildInstanceType})}}, }; typeChecker.globalScope->exportedTypeBindings["AnotherChild"] = TypeFun{{}, anotherChildInstanceType}; addGlobalBinding(frontend, "AnotherChild", childClassType, "@test"); - TypeId unrelatedClassInstanceType = arena.addType(ClassTypeVar{"UnrelatedClass", {}, nullopt, nullopt, {}, {}, "Test"}); + TypeId unrelatedClassInstanceType = arena.addType(ClassType{"UnrelatedClass", {}, nullopt, nullopt, {}, {}, "Test"}); - TypeId unrelatedClassType = arena.addType(ClassTypeVar{"UnrelatedClass", {}, nullopt, nullopt, {}, {}, "Test"}); - getMutable(unrelatedClassType)->props = { + TypeId unrelatedClassType = arena.addType(ClassType{"UnrelatedClass", {}, nullopt, nullopt, {}, {}, "Test"}); + getMutable(unrelatedClassType)->props = { {"New", {makeFunction(arena, nullopt, {}, {unrelatedClassInstanceType})}}, }; typeChecker.globalScope->exportedTypeBindings["UnrelatedClass"] = TypeFun{{}, unrelatedClassInstanceType}; addGlobalBinding(frontend, "UnrelatedClass", unrelatedClassType, "@test"); - TypeId vector2MetaType = arena.addType(TableTypeVar{}); + TypeId vector2MetaType = arena.addType(TableType{}); - TypeId vector2InstanceType = arena.addType(ClassTypeVar{"Vector2", {}, nullopt, vector2MetaType, {}, {}, "Test"}); - getMutable(vector2InstanceType)->props = { + TypeId vector2InstanceType = arena.addType(ClassType{"Vector2", {}, nullopt, vector2MetaType, {}, {}, "Test"}); + getMutable(vector2InstanceType)->props = { {"X", {numberType}}, {"Y", {numberType}}, }; - TypeId vector2Type = arena.addType(ClassTypeVar{"Vector2", {}, nullopt, nullopt, {}, {}, "Test"}); - getMutable(vector2Type)->props = { + TypeId vector2Type = arena.addType(ClassType{"Vector2", {}, nullopt, nullopt, {}, {}, "Test"}); + getMutable(vector2Type)->props = { {"New", {makeFunction(arena, nullopt, {numberType, numberType}, {vector2InstanceType})}}, }; - getMutable(vector2MetaType)->props = { + getMutable(vector2MetaType)->props = { {"__add", {makeFunction(arena, nullopt, {vector2InstanceType, vector2InstanceType}, {vector2InstanceType})}}, }; typeChecker.globalScope->exportedTypeBindings["Vector2"] = TypeFun{{}, vector2InstanceType}; addGlobalBinding(frontend, "Vector2", vector2Type, "@test"); - TypeId callableClassMetaType = arena.addType(TableTypeVar{}); - TypeId callableClassType = arena.addType(ClassTypeVar{"CallableClass", {}, nullopt, callableClassMetaType, {}, {}, "Test"}); - getMutable(callableClassMetaType)->props = { + TypeId callableClassMetaType = arena.addType(TableType{}); + TypeId callableClassType = arena.addType(ClassType{"CallableClass", {}, nullopt, callableClassMetaType, {}, {}, "Test"}); + getMutable(callableClassMetaType)->props = { {"__call", {makeFunction(arena, nullopt, {callableClassType, typeChecker.stringType}, {typeChecker.numberType})}}, }; typeChecker.globalScope->exportedTypeBindings["CallableClass"] = TypeFun{{}, callableClassType}; diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index 6be838dd2..c32f2870f 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -440,27 +440,27 @@ TEST_CASE("Vector") static void populateRTTI(lua_State* L, Luau::TypeId type) { - if (auto p = Luau::get(type)) + if (auto p = Luau::get(type)) { switch (p->type) { - case Luau::PrimitiveTypeVar::Boolean: + case Luau::PrimitiveType::Boolean: lua_pushstring(L, "boolean"); break; - case Luau::PrimitiveTypeVar::NilType: + case Luau::PrimitiveType::NilType: lua_pushstring(L, "nil"); break; - case Luau::PrimitiveTypeVar::Number: + case Luau::PrimitiveType::Number: lua_pushstring(L, "number"); break; - case Luau::PrimitiveTypeVar::String: + case Luau::PrimitiveType::String: lua_pushstring(L, "string"); break; - case Luau::PrimitiveTypeVar::Thread: + case Luau::PrimitiveType::Thread: lua_pushstring(L, "thread"); break; @@ -468,7 +468,7 @@ static void populateRTTI(lua_State* L, Luau::TypeId type) LUAU_ASSERT(!"Unknown primitive type"); } } - else if (auto t = Luau::get(type)) + else if (auto t = Luau::get(type)) { lua_newtable(L); @@ -478,18 +478,18 @@ static void populateRTTI(lua_State* L, Luau::TypeId type) lua_setfield(L, -2, name.c_str()); } } - else if (Luau::get(type)) + else if (Luau::get(type)) { lua_pushstring(L, "function"); } - else if (Luau::get(type)) + else if (Luau::get(type)) { lua_pushstring(L, "any"); } - else if (auto i = Luau::get(type)) + else if (auto i = Luau::get(type)) { for (const auto& part : i->parts) - LUAU_ASSERT(Luau::get(part)); + LUAU_ASSERT(Luau::get(part)); lua_pushstring(L, "function"); } @@ -504,8 +504,8 @@ TEST_CASE("Types") runConformance("types.lua", [](lua_State* L) { Luau::NullModuleResolver moduleResolver; Luau::InternalErrorReporter iceHandler; - Luau::SingletonTypes singletonTypes; - Luau::TypeChecker env(&moduleResolver, Luau::NotNull{&singletonTypes}, &iceHandler); + Luau::BuiltinTypes builtinTypes; + Luau::TypeChecker env(&moduleResolver, Luau::NotNull{&builtinTypes}, &iceHandler); Luau::registerBuiltinGlobals(env); Luau::freeze(env.globalTypes); diff --git a/tests/ConstraintGraphBuilderFixture.cpp b/tests/ConstraintGraphBuilderFixture.cpp index 30e1b2e6e..64e6baaf0 100644 --- a/tests/ConstraintGraphBuilderFixture.cpp +++ b/tests/ConstraintGraphBuilderFixture.cpp @@ -9,7 +9,7 @@ ConstraintGraphBuilderFixture::ConstraintGraphBuilderFixture() , mainModule(new Module) , forceTheFlag{"DebugLuauDeferredConstraintResolution", true} { - BlockedTypeVar::nextIndex = 0; + BlockedType::nextIndex = 0; BlockedTypePack::nextIndex = 0; } @@ -17,7 +17,7 @@ void ConstraintGraphBuilderFixture::generateConstraints(const std::string& code) { AstStatBlock* root = parse(code); dfg = std::make_unique(DataFlowGraphBuilder::build(root, NotNull{&ice})); - cgb = std::make_unique("MainModule", mainModule, &arena, NotNull(&moduleResolver), singletonTypes, NotNull(&ice), + cgb = std::make_unique("MainModule", mainModule, &arena, NotNull(&moduleResolver), builtinTypes, NotNull(&ice), frontend.getGlobalScope(), &logger, NotNull{dfg.get()}); cgb->visit(root); rootScope = cgb->rootScope; diff --git a/tests/ConstraintGraphBuilderFixture.h b/tests/ConstraintGraphBuilderFixture.h index 9785a838a..5e7fedab5 100644 --- a/tests/ConstraintGraphBuilderFixture.h +++ b/tests/ConstraintGraphBuilderFixture.h @@ -19,7 +19,7 @@ struct ConstraintGraphBuilderFixture : Fixture ModulePtr mainModule; DcrLogger logger; UnifierSharedState sharedState{&ice}; - Normalizer normalizer{&arena, singletonTypes, NotNull{&sharedState}}; + Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}}; std::unique_ptr dfg; std::unique_ptr cgb; diff --git a/tests/Fixture.cpp b/tests/Fixture.cpp index eb77ce521..416292817 100644 --- a/tests/Fixture.cpp +++ b/tests/Fixture.cpp @@ -7,7 +7,7 @@ #include "Luau/ModuleResolver.h" #include "Luau/NotNull.h" #include "Luau/Parser.h" -#include "Luau/TypeVar.h" +#include "Luau/Type.h" #include "Luau/TypeAttach.h" #include "Luau/Transpiler.h" @@ -141,7 +141,7 @@ Fixture::Fixture(bool freeze, bool prepareAutocomplete) , frontend(&fileResolver, &configResolver, {/* retainFullTypeGraphs= */ true, /* forAutocomplete */ false, /* randomConstraintResolutionSeed */ randomSeed}) , typeChecker(frontend.typeChecker) - , singletonTypes(frontend.singletonTypes) + , builtinTypes(frontend.builtinTypes) { configResolver.defaultConfig.mode = Mode::Strict; configResolver.defaultConfig.enabledLint.warningMask = ~0ull; @@ -293,14 +293,14 @@ SourceModule* Fixture::getMainSourceModule() return frontend.getSourceModule(fromString(mainModuleName)); } -std::optional Fixture::getPrimitiveType(TypeId ty) +std::optional Fixture::getPrimitiveType(TypeId ty) { REQUIRE(ty != nullptr); TypeId aType = follow(ty); REQUIRE(aType != nullptr); - const PrimitiveTypeVar* pt = get(aType); + const PrimitiveType* pt = get(aType); if (pt != nullptr) return pt->type; else @@ -513,7 +513,7 @@ std::string rep(const std::string& s, size_t n) bool isInArena(TypeId t, const TypeArena& arena) { - return arena.typeVars.contains(t); + return arena.types.contains(t); } void dumpErrors(const ModulePtr& module) @@ -554,12 +554,13 @@ std::optional linearSearchForBinding(Scope* scope, const char* name) void registerHiddenTypes(Fixture& fixture, TypeArena& arena) { - TypeId t = arena.addType(GenericTypeVar{"T"}); + TypeId t = arena.addType(GenericType{"T"}); GenericTypeDefinition genericT{t}; ScopePtr moduleScope = fixture.frontend.getGlobalScope(); - moduleScope->exportedTypeBindings["Not"] = TypeFun{{genericT}, arena.addType(NegationTypeVar{t})}; - moduleScope->exportedTypeBindings["fun"] = TypeFun{{}, fixture.singletonTypes->functionType}; + moduleScope->exportedTypeBindings["Not"] = TypeFun{{genericT}, arena.addType(NegationType{t})}; + moduleScope->exportedTypeBindings["fun"] = TypeFun{{}, fixture.builtinTypes->functionType}; + moduleScope->exportedTypeBindings["cls"] = TypeFun{{}, fixture.builtinTypes->classType}; } void dump(const std::vector& constraints) diff --git a/tests/Fixture.h b/tests/Fixture.h index 5d838b163..3edd6b4c1 100644 --- a/tests/Fixture.h +++ b/tests/Fixture.h @@ -10,7 +10,7 @@ #include "Luau/ModuleResolver.h" #include "Luau/Scope.h" #include "Luau/ToString.h" -#include "Luau/TypeVar.h" +#include "Luau/Type.h" #include "IostreamOptional.h" #include "ScopedFlags.h" @@ -78,7 +78,7 @@ struct Fixture ModulePtr getMainModule(); SourceModule* getMainSourceModule(); - std::optional getPrimitiveType(TypeId ty); + std::optional getPrimitiveType(TypeId ty); std::optional getType(const std::string& name); TypeId requireType(const std::string& name); TypeId requireType(const ModuleName& moduleName, const std::string& name); @@ -102,7 +102,7 @@ struct Fixture Frontend frontend; InternalErrorReporter ice; TypeChecker& typeChecker; - NotNull singletonTypes; + NotNull builtinTypes; std::string decorateWithTypes(const std::string& code); diff --git a/tests/Frontend.test.cpp b/tests/Frontend.test.cpp index 6f92b6551..93df5605e 100644 --- a/tests/Frontend.test.cpp +++ b/tests/Frontend.test.cpp @@ -791,7 +791,7 @@ TEST_CASE_FIXTURE(FrontendFixture, "discard_type_graphs") ModulePtr module = fe.moduleResolver.getModule("Module/A"); - CHECK_EQ(0, module->internalTypes.typeVars.size()); + CHECK_EQ(0, module->internalTypes.types.size()); CHECK_EQ(0, module->internalTypes.typePacks.size()); CHECK_EQ(0, module->astTypes.size()); CHECK_EQ(0, module->astResolvedTypes.size()); diff --git a/tests/LValue.test.cpp b/tests/LValue.test.cpp index 0bb91ecef..c71d97d16 100644 --- a/tests/LValue.test.cpp +++ b/tests/LValue.test.cpp @@ -14,18 +14,18 @@ static void merge(TypeArena& arena, RefinementMap& l, const RefinementMap& r) // TODO: normalize here also. std::unordered_set s; - if (auto utv = get(follow(a))) + if (auto utv = get(follow(a))) s.insert(begin(utv), end(utv)); else s.insert(a); - if (auto utv = get(follow(b))) + if (auto utv = get(follow(b))) s.insert(begin(utv), end(utv)); else s.insert(b); std::vector options(s.begin(), s.end()); - return options.size() == 1 ? options[0] : arena.addType(UnionTypeVar{std::move(options)}); + return options.size() == 1 ? options[0] : arena.addType(UnionType{std::move(options)}); }); } @@ -36,7 +36,7 @@ static LValue mkSymbol(const std::string& s) struct LValueFixture { - SingletonTypes singletonTypes; + BuiltinTypes builtinTypes; }; TEST_SUITE_BEGIN("LValue"); @@ -48,14 +48,14 @@ TEST_CASE_FIXTURE(LValueFixture, "Luau_merge_hashmap_order") std::string c = "c"; RefinementMap m{{ - {mkSymbol(b), singletonTypes.stringType}, - {mkSymbol(c), singletonTypes.numberType}, + {mkSymbol(b), builtinTypes.stringType}, + {mkSymbol(c), builtinTypes.numberType}, }}; RefinementMap other{{ - {mkSymbol(a), singletonTypes.stringType}, - {mkSymbol(b), singletonTypes.stringType}, - {mkSymbol(c), singletonTypes.booleanType}, + {mkSymbol(a), builtinTypes.stringType}, + {mkSymbol(b), builtinTypes.stringType}, + {mkSymbol(c), builtinTypes.booleanType}, }}; TypeArena arena; @@ -78,14 +78,14 @@ TEST_CASE_FIXTURE(LValueFixture, "Luau_merge_hashmap_order2") std::string c = "c"; RefinementMap m{{ - {mkSymbol(a), singletonTypes.stringType}, - {mkSymbol(b), singletonTypes.stringType}, - {mkSymbol(c), singletonTypes.numberType}, + {mkSymbol(a), builtinTypes.stringType}, + {mkSymbol(b), builtinTypes.stringType}, + {mkSymbol(c), builtinTypes.numberType}, }}; RefinementMap other{{ - {mkSymbol(b), singletonTypes.stringType}, - {mkSymbol(c), singletonTypes.booleanType}, + {mkSymbol(b), builtinTypes.stringType}, + {mkSymbol(c), builtinTypes.booleanType}, }}; TypeArena arena; @@ -110,15 +110,15 @@ TEST_CASE_FIXTURE(LValueFixture, "one_map_has_overlap_at_end_whereas_other_has_i std::string e = "e"; RefinementMap m{{ - {mkSymbol(a), singletonTypes.stringType}, - {mkSymbol(b), singletonTypes.numberType}, - {mkSymbol(c), singletonTypes.booleanType}, + {mkSymbol(a), builtinTypes.stringType}, + {mkSymbol(b), builtinTypes.numberType}, + {mkSymbol(c), builtinTypes.booleanType}, }}; RefinementMap other{{ - {mkSymbol(c), singletonTypes.stringType}, - {mkSymbol(d), singletonTypes.numberType}, - {mkSymbol(e), singletonTypes.booleanType}, + {mkSymbol(c), builtinTypes.stringType}, + {mkSymbol(d), builtinTypes.numberType}, + {mkSymbol(e), builtinTypes.booleanType}, }}; TypeArena arena; @@ -159,8 +159,8 @@ TEST_CASE_FIXTURE(LValueFixture, "hashing_lvalue_global_prop_access") CHECK_EQ(LValueHasher{}(t_x2), LValueHasher{}(t_x2)); RefinementMap m; - m[t_x1] = singletonTypes.stringType; - m[t_x2] = singletonTypes.numberType; + m[t_x1] = builtinTypes.stringType; + m[t_x2] = builtinTypes.numberType; CHECK_EQ(1, m.size()); } @@ -188,8 +188,8 @@ TEST_CASE_FIXTURE(LValueFixture, "hashing_lvalue_local_prop_access") CHECK_EQ(LValueHasher{}(t_x2), LValueHasher{}(t_x2)); RefinementMap m; - m[t_x1] = singletonTypes.stringType; - m[t_x2] = singletonTypes.numberType; + m[t_x1] = builtinTypes.stringType; + m[t_x2] = builtinTypes.numberType; CHECK_EQ(2, m.size()); } diff --git a/tests/Linter.test.cpp b/tests/Linter.test.cpp index 78ff85c3e..426b520c6 100644 --- a/tests/Linter.test.cpp +++ b/tests/Linter.test.cpp @@ -610,11 +610,11 @@ return foo1 TEST_CASE_FIXTURE(Fixture, "UnknownType") { unfreeze(typeChecker.globalTypes); - TableTypeVar::Props instanceProps{ + TableType::Props instanceProps{ {"ClassName", {typeChecker.anyType}}, }; - TableTypeVar instanceTable{instanceProps, std::nullopt, typeChecker.globalScope->level, Luau::TableState::Sealed}; + TableType instanceTable{instanceProps, std::nullopt, typeChecker.globalScope->level, Luau::TableState::Sealed}; TypeId instanceType = typeChecker.globalTypes.addType(instanceTable); TypeFun instanceTypeFun{{}, instanceType}; @@ -1448,19 +1448,19 @@ TEST_CASE_FIXTURE(Fixture, "LintHygieneUAF") TEST_CASE_FIXTURE(Fixture, "DeprecatedApi") { unfreeze(typeChecker.globalTypes); - TypeId instanceType = typeChecker.globalTypes.addType(ClassTypeVar{"Instance", {}, std::nullopt, std::nullopt, {}, {}, "Test"}); + TypeId instanceType = typeChecker.globalTypes.addType(ClassType{"Instance", {}, std::nullopt, std::nullopt, {}, {}, "Test"}); persist(instanceType); typeChecker.globalScope->exportedTypeBindings["Instance"] = TypeFun{{}, instanceType}; - getMutable(instanceType)->props = { + getMutable(instanceType)->props = { {"Name", {typeChecker.stringType}}, {"DataCost", {typeChecker.numberType, /* deprecated= */ true}}, {"Wait", {typeChecker.anyType, /* deprecated= */ true}}, }; - TypeId colorType = typeChecker.globalTypes.addType(TableTypeVar{{}, std::nullopt, typeChecker.globalScope->level, Luau::TableState::Sealed}); + TypeId colorType = typeChecker.globalTypes.addType(TableType{{}, std::nullopt, typeChecker.globalScope->level, Luau::TableState::Sealed}); - getMutable(colorType)->props = {{"toHSV", {typeChecker.anyType, /* deprecated= */ true, "Color3:ToHSV"}}}; + getMutable(colorType)->props = {{"toHSV", {typeChecker.anyType, /* deprecated= */ true, "Color3:ToHSV"}}}; addGlobalBinding(frontend, "Color3", Binding{colorType, {}}); diff --git a/tests/Module.test.cpp b/tests/Module.test.cpp index 33d9c75a7..5f97fb6cd 100644 --- a/tests/Module.test.cpp +++ b/tests/Module.test.cpp @@ -59,14 +59,14 @@ TEST_CASE_FIXTURE(Fixture, "deepClone_non_persistent_primitive") // Create a new number type that isn't persistent unfreeze(typeChecker.globalTypes); - TypeId oldNumber = typeChecker.globalTypes.addType(PrimitiveTypeVar{PrimitiveTypeVar::Number}); + TypeId oldNumber = typeChecker.globalTypes.addType(PrimitiveType{PrimitiveType::Number}); freeze(typeChecker.globalTypes); TypeId newNumber = clone(oldNumber, dest, cloneState); CHECK_NE(newNumber, oldNumber); CHECK_EQ(*oldNumber, *newNumber); CHECK_EQ("number", toString(newNumber)); - CHECK_EQ(1, dest.typeVars.size()); + CHECK_EQ(1, dest.types.size()); } TEST_CASE_FIXTURE(Fixture, "deepClone_cyclic_table") @@ -91,7 +91,7 @@ TEST_CASE_FIXTURE(Fixture, "deepClone_cyclic_table") CloneState cloneState; TypeId counterCopy = clone(counterType, dest, cloneState); - TableTypeVar* ttv = getMutable(counterCopy); + TableType* ttv = getMutable(counterCopy); REQUIRE(ttv != nullptr); CHECK_EQ(std::optional{"Cyclic"}, ttv->syntheticName); @@ -99,7 +99,7 @@ TEST_CASE_FIXTURE(Fixture, "deepClone_cyclic_table") TypeId methodType = ttv->props["get"].type; REQUIRE(methodType != nullptr); - const FunctionTypeVar* ftv = get(methodType); + const FunctionType* ftv = get(methodType); REQUIRE(ftv != nullptr); std::optional methodReturnType = first(ftv->retTypes); @@ -107,7 +107,7 @@ TEST_CASE_FIXTURE(Fixture, "deepClone_cyclic_table") CHECK_EQ(methodReturnType, counterCopy); CHECK_EQ(2, dest.typePacks.size()); // one for the function args, and another for its return type - CHECK_EQ(2, dest.typeVars.size()); // One table and one function + CHECK_EQ(2, dest.types.size()); // One table and one function } TEST_CASE_FIXTURE(BuiltinsFixture, "builtin_types_point_into_globalTypes_arena") @@ -124,7 +124,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "builtin_types_point_into_globalTypes_arena") REQUIRE(isInArena(*exports, module->interfaceTypes)); - TableTypeVar* exportsTable = getMutable(*exports); + TableType* exportsTable = getMutable(*exports); REQUIRE(exportsTable != nullptr); TypeId signType = exportsTable->props["sign"].type; @@ -143,13 +143,13 @@ TEST_CASE_FIXTURE(Fixture, "deepClone_union") CloneState cloneState; unfreeze(typeChecker.globalTypes); - TypeId oldUnion = typeChecker.globalTypes.addType(UnionTypeVar{{typeChecker.numberType, typeChecker.stringType}}); + TypeId oldUnion = typeChecker.globalTypes.addType(UnionType{{typeChecker.numberType, typeChecker.stringType}}); freeze(typeChecker.globalTypes); TypeId newUnion = clone(oldUnion, dest, cloneState); CHECK_NE(newUnion, oldUnion); CHECK_EQ("number | string", toString(newUnion)); - CHECK_EQ(1, dest.typeVars.size()); + CHECK_EQ(1, dest.types.size()); } TEST_CASE_FIXTURE(Fixture, "deepClone_intersection") @@ -158,23 +158,23 @@ TEST_CASE_FIXTURE(Fixture, "deepClone_intersection") CloneState cloneState; unfreeze(typeChecker.globalTypes); - TypeId oldIntersection = typeChecker.globalTypes.addType(IntersectionTypeVar{{typeChecker.numberType, typeChecker.stringType}}); + TypeId oldIntersection = typeChecker.globalTypes.addType(IntersectionType{{typeChecker.numberType, typeChecker.stringType}}); freeze(typeChecker.globalTypes); TypeId newIntersection = clone(oldIntersection, dest, cloneState); CHECK_NE(newIntersection, oldIntersection); CHECK_EQ("number & string", toString(newIntersection)); - CHECK_EQ(1, dest.typeVars.size()); + CHECK_EQ(1, dest.types.size()); } TEST_CASE_FIXTURE(Fixture, "clone_class") { - TypeVar exampleMetaClass{ClassTypeVar{"ExampleClassMeta", + Type exampleMetaClass{ClassType{"ExampleClassMeta", { {"__add", {typeChecker.anyType}}, }, std::nullopt, std::nullopt, {}, {}, "Test"}}; - TypeVar exampleClass{ClassTypeVar{"ExampleClass", + Type exampleClass{ClassType{"ExampleClass", { {"PropOne", {typeChecker.numberType}}, {"PropTwo", {typeChecker.stringType}}, @@ -185,11 +185,11 @@ TEST_CASE_FIXTURE(Fixture, "clone_class") CloneState cloneState; TypeId cloned = clone(&exampleClass, dest, cloneState); - const ClassTypeVar* ctv = get(cloned); + const ClassType* ctv = get(cloned); REQUIRE(ctv != nullptr); REQUIRE(ctv->metatable); - const ClassTypeVar* metatable = get(*ctv->metatable); + const ClassType* metatable = get(*ctv->metatable); REQUIRE(metatable); CHECK_EQ("ExampleClass", ctv->name); @@ -198,14 +198,14 @@ TEST_CASE_FIXTURE(Fixture, "clone_class") TEST_CASE_FIXTURE(Fixture, "clone_free_types") { - TypeVar freeTy(FreeTypeVar{TypeLevel{}}); + Type freeTy(FreeType{TypeLevel{}}); TypePackVar freeTp(FreeTypePack{TypeLevel{}}); TypeArena dest; CloneState cloneState; TypeId clonedTy = clone(&freeTy, dest, cloneState); - CHECK(get(clonedTy)); + CHECK(get(clonedTy)); cloneState = {}; TypePackId clonedTp = clone(&freeTp, dest, cloneState); @@ -214,15 +214,15 @@ TEST_CASE_FIXTURE(Fixture, "clone_free_types") TEST_CASE_FIXTURE(Fixture, "clone_free_tables") { - TypeVar tableTy{TableTypeVar{}}; - TableTypeVar* ttv = getMutable(&tableTy); + Type tableTy{TableType{}}; + TableType* ttv = getMutable(&tableTy); ttv->state = TableState::Free; TypeArena dest; CloneState cloneState; TypeId cloned = clone(&tableTy, dest, cloneState); - const TableTypeVar* clonedTtv = get(cloned); + const TableType* clonedTtv = get(cloned); CHECK_EQ(clonedTtv->state, TableState::Free); } @@ -264,14 +264,14 @@ TEST_CASE_FIXTURE(Fixture, "clone_recursion_limit") TypeArena src; - TypeId table = src.addType(TableTypeVar{}); + TypeId table = src.addType(TableType{}); TypeId nested = table; for (int i = 0; i < limit + 100; i++) { - TableTypeVar* ttv = getMutable(nested); + TableType* ttv = getMutable(nested); - ttv->props["a"].type = src.addType(TableTypeVar{}); + ttv->props["a"].type = src.addType(TableType{}); nested = ttv->props["a"].type; } @@ -332,7 +332,7 @@ return {} REQUIRE(modBiter != modB->getModuleScope()->exportedTypeBindings.end()); TypeId typeA = modAiter->second.type; TypeId typeB = modBiter->second.type; - TableTypeVar* tableB = getMutable(typeB); + TableType* tableB = getMutable(typeB); REQUIRE(tableB); CHECK(typeA == tableB->props["q"].type); } @@ -368,8 +368,8 @@ return exports std::optional typeB = first(modB->getModuleScope()->returnType); REQUIRE(typeA); REQUIRE(typeB); - TableTypeVar* tableA = getMutable(*typeA); - TableTypeVar* tableB = getMutable(*typeB); + TableType* tableA = getMutable(*typeA); + TableType* tableB = getMutable(*typeB); CHECK(tableA->props["a"].type == tableB->props["b"].type); } diff --git a/tests/NonstrictMode.test.cpp b/tests/NonstrictMode.test.cpp index 89aab5eef..8a25a5e59 100644 --- a/tests/NonstrictMode.test.cpp +++ b/tests/NonstrictMode.test.cpp @@ -1,7 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Scope.h" #include "Luau/TypeInfer.h" -#include "Luau/TypeVar.h" +#include "Luau/Type.h" #include "Fixture.h" @@ -23,7 +23,7 @@ TEST_CASE_FIXTURE(Fixture, "infer_nullary_function") TypeId fooType = requireType("foo"); REQUIRE(fooType); - const FunctionTypeVar* ftv = get(fooType); + const FunctionType* ftv = get(fooType); REQUIRE_MESSAGE(ftv != nullptr, "Expected a function, got " << toString(fooType)); auto args = flatten(ftv->argTypes).first; @@ -165,7 +165,7 @@ TEST_CASE_FIXTURE(Fixture, "table_props_are_any") LUAU_REQUIRE_NO_ERRORS(result); - TableTypeVar* ttv = getMutable(requireType("T")); + TableType* ttv = getMutable(requireType("T")); REQUIRE(ttv != nullptr); @@ -189,12 +189,12 @@ TEST_CASE_FIXTURE(Fixture, "inline_table_props_are_also_any") LUAU_REQUIRE_NO_ERRORS(result); - TableTypeVar* ttv = getMutable(requireType("T")); + TableType* ttv = getMutable(requireType("T")); REQUIRE_MESSAGE(ttv, "Should be a table: " << toString(requireType("T"))); CHECK_EQ(*typeChecker.anyType, *ttv->props["one"].type); CHECK_EQ(*typeChecker.anyType, *ttv->props["two"].type); - CHECK_MESSAGE(get(ttv->props["three"].type), "Should be a function: " << *ttv->props["three"].type); + CHECK_MESSAGE(get(ttv->props["three"].type), "Should be a function: " << *ttv->props["three"].type); } TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_iterator_variables_are_any") diff --git a/tests/Normalize.test.cpp b/tests/Normalize.test.cpp index e6bf00a12..ba9f5c525 100644 --- a/tests/Normalize.test.cpp +++ b/tests/Normalize.test.cpp @@ -3,7 +3,7 @@ #include "Fixture.h" #include "Luau/Common.h" -#include "Luau/TypeVar.h" +#include "Luau/Type.h" #include "doctest.h" #include "Luau/Normalize.h" @@ -17,7 +17,7 @@ struct IsSubtypeFixture : Fixture { bool isSubtype(TypeId a, TypeId b) { - return ::Luau::isSubtype(a, b, NotNull{getMainModule()->getModuleScope().get()}, singletonTypes, ice); + return ::Luau::isSubtype(a, b, NotNull{getMainModule()->getModuleScope().get()}, builtinTypes, ice); } }; } // namespace @@ -28,9 +28,9 @@ void createSomeClasses(Frontend& frontend) unfreeze(arena); - TypeId parentType = arena.addType(ClassTypeVar{"Parent", {}, std::nullopt, std::nullopt, {}, nullptr, "Test"}); + TypeId parentType = arena.addType(ClassType{"Parent", {}, frontend.builtinTypes->classType, std::nullopt, {}, nullptr, "Test"}); - ClassTypeVar* parentClass = getMutable(parentType); + ClassType* parentClass = getMutable(parentType); parentClass->props["method"] = {makeFunction(arena, parentType, {}, {})}; parentClass->props["virtual_method"] = {makeFunction(arena, parentType, {}, {})}; @@ -38,15 +38,15 @@ void createSomeClasses(Frontend& frontend) addGlobalBinding(frontend, "Parent", {parentType}); frontend.getGlobalScope()->exportedTypeBindings["Parent"] = TypeFun{{}, parentType}; - TypeId childType = arena.addType(ClassTypeVar{"Child", {}, parentType, std::nullopt, {}, nullptr, "Test"}); + TypeId childType = arena.addType(ClassType{"Child", {}, parentType, std::nullopt, {}, nullptr, "Test"}); - ClassTypeVar* childClass = getMutable(childType); + ClassType* childClass = getMutable(childType); childClass->props["virtual_method"] = {makeFunction(arena, childType, {}, {})}; addGlobalBinding(frontend, "Child", {childType}); frontend.getGlobalScope()->exportedTypeBindings["Child"] = TypeFun{{}, childType}; - TypeId unrelatedType = arena.addType(ClassTypeVar{"Unrelated", {}, std::nullopt, std::nullopt, {}, nullptr, "Test"}); + TypeId unrelatedType = arena.addType(ClassType{"Unrelated", {}, frontend.builtinTypes->classType, std::nullopt, {}, nullptr, "Test"}); addGlobalBinding(frontend, "Unrelated", {unrelatedType}); frontend.getGlobalScope()->exportedTypeBindings["Unrelated"] = TypeFun{{}, unrelatedType}; @@ -394,11 +394,12 @@ TEST_SUITE_END(); struct NormalizeFixture : Fixture { ScopedFastFlag sff1{"LuauNegatedFunctionTypes", true}; + ScopedFastFlag sff2{"LuauNegatedClassTypes", true}; TypeArena arena; InternalErrorReporter iceHandler; UnifierSharedState unifierState{&iceHandler}; - Normalizer normalizer{&arena, singletonTypes, NotNull{&unifierState}}; + Normalizer normalizer{&arena, builtinTypes, NotNull{&unifierState}}; NormalizeFixture() { @@ -524,7 +525,9 @@ TEST_CASE_FIXTURE(NormalizeFixture, "union_function_and_top_function") TEST_CASE_FIXTURE(NormalizeFixture, "negated_function_is_anything_except_a_function") { - CHECK("(boolean | number | string | thread)?" == toString(normal(R"( + ScopedFastFlag{"LuauNegatedClassTypes", true}; + + CHECK("(boolean | class | number | string | thread)?" == toString(normal(R"( Not )"))); } @@ -536,8 +539,9 @@ TEST_CASE_FIXTURE(NormalizeFixture, "specific_functions_cannot_be_negated") TEST_CASE_FIXTURE(NormalizeFixture, "bare_negated_boolean") { + ScopedFastFlag{"LuauNegatedClassTypes", true}; // TODO: We don't yet have a way to say number | string | thread | nil | Class | Table | Function - CHECK("(function | number | string | thread)?" == toString(normal(R"( + CHECK("(class | function | number | string | thread)?" == toString(normal(R"( Not )"))); } @@ -603,4 +607,61 @@ export type t0 = (((any)&({_:l0.t0,n0:t0,_G:any,}))&({_:any,}))&(((any)&({_:l0.t LUAU_REQUIRE_ERRORS(result); } +TEST_CASE_FIXTURE(NormalizeFixture, "unions_of_classes") +{ + ScopedFastFlag sff{"LuauNegatedClassTypes", true}; + + createSomeClasses(frontend); + CHECK("Parent | Unrelated" == toString(normal("Parent | Unrelated"))); + CHECK("Parent" == toString(normal("Parent | Child"))); + CHECK("Parent | Unrelated" == toString(normal("Parent | Child | Unrelated"))); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "intersections_of_classes") +{ + ScopedFastFlag sff{"LuauNegatedClassTypes", true}; + + createSomeClasses(frontend); + CHECK("Child" == toString(normal("Parent & Child"))); + CHECK("never" == toString(normal("Child & Unrelated"))); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "narrow_union_of_classes_with_intersection") +{ + ScopedFastFlag sff{"LuauNegatedClassTypes", true}; + + createSomeClasses(frontend); + CHECK("Child" == toString(normal("(Child | Unrelated) & Child"))); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "negations_of_classes") +{ + ScopedFastFlag sff{"LuauNegatedClassTypes", true}; + + createSomeClasses(frontend); + CHECK("(Parent & ~Child) | Unrelated" == toString(normal("(Parent & Not) | Unrelated"))); + CHECK("((class & ~Child) | boolean | function | number | string | thread)?" == toString(normal("Not"))); + CHECK("Child" == toString(normal("Not & Child"))); + CHECK("((class & ~Parent) | Child | boolean | function | number | string | thread)?" == toString(normal("Not | Child"))); + CHECK("(boolean | function | number | string | thread)?" == toString(normal("Not"))); + CHECK("(Parent | Unrelated | boolean | function | number | string | thread)?" == + toString(normal("Not & Not & Not>"))); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "classes_and_unknown") +{ + ScopedFastFlag sff{"LuauNegatedClassTypes", true}; + + createSomeClasses(frontend); + CHECK("Parent" == toString(normal("Parent & unknown"))); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "classes_and_never") +{ + ScopedFastFlag sff{"LuauNegatedClassTypes", true}; + + createSomeClasses(frontend); + CHECK("never" == toString(normal("Parent & never"))); +} + TEST_SUITE_END(); diff --git a/tests/ToDot.test.cpp b/tests/ToDot.test.cpp index 26c9a1ee4..dc08ae1c5 100644 --- a/tests/ToDot.test.cpp +++ b/tests/ToDot.test.cpp @@ -17,16 +17,16 @@ struct ToDotClassFixture : Fixture unfreeze(arena); - TypeId baseClassMetaType = arena.addType(TableTypeVar{}); + TypeId baseClassMetaType = arena.addType(TableType{}); - TypeId baseClassInstanceType = arena.addType(ClassTypeVar{"BaseClass", {}, std::nullopt, baseClassMetaType, {}, {}, "Test"}); - getMutable(baseClassInstanceType)->props = { + TypeId baseClassInstanceType = arena.addType(ClassType{"BaseClass", {}, std::nullopt, baseClassMetaType, {}, {}, "Test"}); + getMutable(baseClassInstanceType)->props = { {"BaseField", {typeChecker.numberType}}, }; typeChecker.globalScope->exportedTypeBindings["BaseClass"] = TypeFun{{}, baseClassInstanceType}; - TypeId childClassInstanceType = arena.addType(ClassTypeVar{"ChildClass", {}, baseClassInstanceType, std::nullopt, {}, {}, "Test"}); - getMutable(childClassInstanceType)->props = { + TypeId childClassInstanceType = arena.addType(ClassType{"ChildClass", {}, baseClassInstanceType, std::nullopt, {}, {}, "Test"}); + getMutable(childClassInstanceType)->props = { {"ChildField", {typeChecker.stringType}}, }; typeChecker.globalScope->exportedTypeBindings["ChildClass"] = TypeFun{{}, childClassInstanceType}; @@ -66,12 +66,12 @@ n1 [label="any"]; opts.duplicatePrimitives = false; CHECK_EQ(R"(digraph graphname { -n1 [label="PrimitiveTypeVar number"]; +n1 [label="PrimitiveType number"]; })", toDot(requireType("b"), opts)); CHECK_EQ(R"(digraph graphname { -n1 [label="AnyTypeVar 1"]; +n1 [label="AnyType 1"]; })", toDot(requireType("c"), opts)); } @@ -90,7 +90,7 @@ local b = a() ToDotOptions opts; opts.showPointers = false; CHECK_EQ(R"(digraph graphname { -n1 [label="BoundTypeVar 1"]; +n1 [label="BoundType 1"]; n1 -> n2; n2 [label="number"]; })", @@ -110,11 +110,11 @@ local function f(a, ...: string) return a end opts.showPointers = false; CHECK_EQ(R"(digraph graphname { -n1 [label="FunctionTypeVar 1"]; +n1 [label="FunctionType 1"]; n1 -> n2 [label="arg"]; n2 [label="TypePack 2"]; n2 -> n3; -n3 [label="GenericTypeVar 3"]; +n3 [label="GenericType 3"]; n2 -> n4 [label="tail"]; n4 [label="VariadicTypePack 4"]; n4 -> n5; @@ -138,7 +138,7 @@ local a: string | number ToDotOptions opts; opts.showPointers = false; CHECK_EQ(R"(digraph graphname { -n1 [label="UnionTypeVar 1"]; +n1 [label="UnionType 1"]; n1 -> n2; n2 [label="string"]; n1 -> n3; @@ -157,7 +157,7 @@ local a: string & number -- uninhabited ToDotOptions opts; opts.showPointers = false; CHECK_EQ(R"(digraph graphname { -n1 [label="IntersectionTypeVar 1"]; +n1 [label="IntersectionType 1"]; n1 -> n2; n2 [label="string"]; n1 -> n3; @@ -177,11 +177,11 @@ local a: A ToDotOptions opts; opts.showPointers = false; CHECK_EQ(R"(digraph graphname { -n1 [label="TableTypeVar A"]; +n1 [label="TableType A"]; n1 -> n2 [label="x"]; n2 [label="number"]; n1 -> n3 [label="y"]; -n3 [label="FunctionTypeVar 3"]; +n3 [label="FunctionType 3"]; n3 -> n4 [label="arg"]; n4 [label="VariadicTypePack 4"]; n4 -> n5; @@ -212,47 +212,47 @@ local a: typeof(setmetatable({}, {})) ToDotOptions opts; opts.showPointers = false; CHECK_EQ(R"(digraph graphname { -n1 [label="MetatableTypeVar 1"]; +n1 [label="MetatableType 1"]; n1 -> n2 [label="table"]; -n2 [label="TableTypeVar 2"]; +n2 [label="TableType 2"]; n1 -> n3 [label="metatable"]; -n3 [label="TableTypeVar 3"]; +n3 [label="TableType 3"]; })", toDot(requireType("a"), opts)); } TEST_CASE_FIXTURE(Fixture, "free") { - TypeVar type{TypeVariant{FreeTypeVar{TypeLevel{0, 0}}}}; + Type type{TypeVariant{FreeType{TypeLevel{0, 0}}}}; ToDotOptions opts; opts.showPointers = false; CHECK_EQ(R"(digraph graphname { -n1 [label="FreeTypeVar 1"]; +n1 [label="FreeType 1"]; })", toDot(&type, opts)); } TEST_CASE_FIXTURE(Fixture, "error") { - TypeVar type{TypeVariant{ErrorTypeVar{}}}; + Type type{TypeVariant{ErrorType{}}}; ToDotOptions opts; opts.showPointers = false; CHECK_EQ(R"(digraph graphname { -n1 [label="ErrorTypeVar 1"]; +n1 [label="ErrorType 1"]; })", toDot(&type, opts)); } TEST_CASE_FIXTURE(Fixture, "generic") { - TypeVar type{TypeVariant{GenericTypeVar{"T"}}}; + Type type{TypeVariant{GenericType{"T"}}}; ToDotOptions opts; opts.showPointers = false; CHECK_EQ(R"(digraph graphname { -n1 [label="GenericTypeVar T"]; +n1 [label="GenericType T"]; })", toDot(&type, opts)); } @@ -267,15 +267,15 @@ local a: ChildClass ToDotOptions opts; opts.showPointers = false; CHECK_EQ(R"(digraph graphname { -n1 [label="ClassTypeVar ChildClass"]; +n1 [label="ClassType ChildClass"]; n1 -> n2 [label="ChildField"]; n2 [label="string"]; n1 -> n3 [label="[parent]"]; -n3 [label="ClassTypeVar BaseClass"]; +n3 [label="ClassType BaseClass"]; n3 -> n4 [label="BaseField"]; n4 [label="number"]; n3 -> n5 [label="[metatable]"]; -n5 [label="TableTypeVar 5"]; +n5 [label="TableType 5"]; })", toDot(requireType("a"), opts)); } @@ -358,16 +358,16 @@ b = a ToDotOptions opts; opts.showPointers = false; CHECK_EQ(R"(digraph graphname { -n1 [label="TableTypeVar 1"]; +n1 [label="TableType 1"]; n1 -> n2 [label="boundTo"]; -n2 [label="TableTypeVar a"]; +n2 [label="TableType a"]; n2 -> n3 [label="x"]; n3 [label="number"]; })", toDot(*ty, opts)); } -TEST_CASE_FIXTURE(Fixture, "singletontypes") +TEST_CASE_FIXTURE(Fixture, "builtintypes") { CheckResult result = check(R"( local x: "hi" | "\"hello\"" | true | false @@ -377,17 +377,17 @@ TEST_CASE_FIXTURE(Fixture, "singletontypes") opts.showPointers = false; CHECK_EQ(R"(digraph graphname { -n1 [label="UnionTypeVar 1"]; +n1 [label="UnionType 1"]; n1 -> n2; -n2 [label="SingletonTypeVar string: hi"]; +n2 [label="SingletonType string: hi"]; n1 -> n3; )" - "n3 [label=\"SingletonTypeVar string: \\\"hello\\\"\"];" + "n3 [label=\"SingletonType string: \\\"hello\\\"\"];" R"( n1 -> n4; -n4 [label="SingletonTypeVar boolean: true"]; +n4 [label="SingletonType boolean: true"]; n1 -> n5; -n5 [label="SingletonTypeVar boolean: false"]; +n5 [label="SingletonType boolean: false"]; })", toDot(requireType("x"), opts)); } diff --git a/tests/ToString.test.cpp b/tests/ToString.test.cpp index 05f49422b..7ccda8def 100644 --- a/tests/ToString.test.cpp +++ b/tests/ToString.test.cpp @@ -5,6 +5,7 @@ #include "Fixture.h" +#include "ScopedFlags.h" #include "doctest.h" using namespace Luau; @@ -44,8 +45,8 @@ TEST_CASE_FIXTURE(Fixture, "free_types") TEST_CASE_FIXTURE(Fixture, "cyclic_table") { - TypeVar cyclicTable{TypeVariant(TableTypeVar())}; - TableTypeVar* tableOne = getMutable(&cyclicTable); + Type cyclicTable{TypeVariant(TableType())}; + TableType* tableOne = getMutable(&cyclicTable); tableOne->props["self"] = {&cyclicTable}; CHECK_EQ("t1 where t1 = { self: t1 }", toString(&cyclicTable)); @@ -53,8 +54,8 @@ TEST_CASE_FIXTURE(Fixture, "cyclic_table") TEST_CASE_FIXTURE(Fixture, "named_table") { - TypeVar table{TypeVariant(TableTypeVar())}; - TableTypeVar* t = getMutable(&table); + Type table{TypeVariant(TableType())}; + TableType* t = getMutable(&table); t->name = "TheTable"; CHECK_EQ("TheTable", toString(&table)); @@ -94,19 +95,43 @@ TEST_CASE_FIXTURE(Fixture, "table_respects_use_line_break") //clang-format on } +TEST_CASE_FIXTURE(Fixture, "nil_or_nil_is_nil_not_question_mark") +{ + ScopedFastFlag sff("LuauSerializeNilUnionAsNil", true); + CheckResult result = check(R"( + type nil_ty = nil | nil + local a : nil_ty = nil + )"); + ToStringOptions opts; + opts.useLineBreaks = false; + CHECK_EQ("nil", toString(requireType("a"), opts)); +} + +TEST_CASE_FIXTURE(Fixture, "long_disjunct_of_nil_is_nil_not_question_mark") +{ + ScopedFastFlag sff("LuauSerializeNilUnionAsNil", true); + CheckResult result = check(R"( + type nil_ty = nil | nil | nil | nil | nil + local a : nil_ty = nil + )"); + ToStringOptions opts; + opts.useLineBreaks = false; + CHECK_EQ("nil", toString(requireType("a"), opts)); +} + TEST_CASE_FIXTURE(Fixture, "metatable") { - TypeVar table{TypeVariant(TableTypeVar())}; - TypeVar metatable{TypeVariant(TableTypeVar())}; - TypeVar mtv{TypeVariant(MetatableTypeVar{&table, &metatable})}; + Type table{TypeVariant(TableType())}; + Type metatable{TypeVariant(TableType())}; + Type mtv{TypeVariant(MetatableType{&table, &metatable})}; CHECK_EQ("{ @metatable { }, { } }", toString(&mtv)); } TEST_CASE_FIXTURE(Fixture, "named_metatable") { - TypeVar table{TypeVariant(TableTypeVar())}; - TypeVar metatable{TypeVariant(TableTypeVar())}; - TypeVar mtv{TypeVariant(MetatableTypeVar{&table, &metatable, "NamedMetatable"})}; + Type table{TypeVariant(TableType())}; + Type metatable{TypeVariant(TableType())}; + Type mtv{TypeVariant(MetatableType{&table, &metatable, "NamedMetatable"})}; CHECK_EQ("NamedMetatable", toString(&mtv)); } @@ -120,7 +145,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "named_metatable_toStringNamedFunction") )"); TypeId ty = requireType("createTbl"); - const FunctionTypeVar* ftv = get(follow(ty)); + const FunctionType* ftv = get(follow(ty)); REQUIRE(ftv); CHECK_EQ("createTbl(): NamedMetatable", toStringNamedFunction("createTbl", *ftv)); } @@ -162,16 +187,16 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "exhaustive_toString_of_cyclic_table") TEST_CASE_FIXTURE(Fixture, "intersection_parenthesized_only_if_needed") { - auto utv = TypeVar{UnionTypeVar{{typeChecker.numberType, typeChecker.stringType}}}; - auto itv = TypeVar{IntersectionTypeVar{{&utv, typeChecker.booleanType}}}; + auto utv = Type{UnionType{{typeChecker.numberType, typeChecker.stringType}}}; + auto itv = Type{IntersectionType{{&utv, typeChecker.booleanType}}}; CHECK_EQ(toString(&itv), "(number | string) & boolean"); } TEST_CASE_FIXTURE(Fixture, "union_parenthesized_only_if_needed") { - auto itv = TypeVar{IntersectionTypeVar{{typeChecker.numberType, typeChecker.stringType}}}; - auto utv = TypeVar{UnionTypeVar{{&itv, typeChecker.booleanType}}}; + auto itv = Type{IntersectionType{{typeChecker.numberType, typeChecker.stringType}}}; + auto utv = Type{UnionType{{&itv, typeChecker.booleanType}}}; CHECK_EQ(toString(&utv), "(number & string) | boolean"); } @@ -181,11 +206,11 @@ TEST_CASE_FIXTURE(Fixture, "functions_are_always_parenthesized_in_unions_or_inte auto stringAndNumberPack = TypePackVar{TypePack{{typeChecker.stringType, typeChecker.numberType}}}; auto numberAndStringPack = TypePackVar{TypePack{{typeChecker.numberType, typeChecker.stringType}}}; - auto sn2ns = TypeVar{FunctionTypeVar{&stringAndNumberPack, &numberAndStringPack}}; - auto ns2sn = TypeVar{FunctionTypeVar(typeChecker.globalScope->level, &numberAndStringPack, &stringAndNumberPack)}; + auto sn2ns = Type{FunctionType{&stringAndNumberPack, &numberAndStringPack}}; + auto ns2sn = Type{FunctionType(typeChecker.globalScope->level, &numberAndStringPack, &stringAndNumberPack)}; - auto utv = TypeVar{UnionTypeVar{{&ns2sn, &sn2ns}}}; - auto itv = TypeVar{IntersectionTypeVar{{&ns2sn, &sn2ns}}}; + auto utv = Type{UnionType{{&ns2sn, &sn2ns}}}; + auto itv = Type{IntersectionType{{&ns2sn, &sn2ns}}}; CHECK_EQ(toString(&utv), "((number, string) -> (string, number)) | ((string, number) -> (number, string))"); CHECK_EQ(toString(&itv), "((number, string) -> (string, number)) & ((string, number) -> (number, string))"); @@ -226,11 +251,11 @@ TEST_CASE_FIXTURE(Fixture, "unions_respects_use_line_breaks") TEST_CASE_FIXTURE(Fixture, "quit_stringifying_table_type_when_length_is_exceeded") { - TableTypeVar ttv{}; + TableType ttv{}; for (char c : std::string("abcdefghijklmno")) ttv.props[std::string(1, c)] = {typeChecker.numberType}; - TypeVar tv{ttv}; + Type tv{ttv}; ToStringOptions o; o.exhaustive = false; @@ -240,11 +265,11 @@ TEST_CASE_FIXTURE(Fixture, "quit_stringifying_table_type_when_length_is_exceeded TEST_CASE_FIXTURE(Fixture, "stringifying_table_type_is_still_capped_when_exhaustive") { - TableTypeVar ttv{}; + TableType ttv{}; for (char c : std::string("abcdefg")) ttv.props[std::string(1, c)] = {typeChecker.numberType}; - TypeVar tv{ttv}; + Type tv{ttv}; ToStringOptions o; o.exhaustive = true; @@ -315,11 +340,11 @@ TEST_CASE_FIXTURE(Fixture, "stringifying_type_is_still_capped_when_exhaustive") TEST_CASE_FIXTURE(Fixture, "stringifying_table_type_correctly_use_matching_table_state_braces") { - TableTypeVar ttv{TableState::Sealed, TypeLevel{}}; + TableType ttv{TableState::Sealed, TypeLevel{}}; for (char c : std::string("abcdefghij")) ttv.props[std::string(1, c)] = {typeChecker.numberType}; - TypeVar tv{ttv}; + Type tv{ttv}; ToStringOptions o; o.maxTableLength = 40; @@ -328,8 +353,8 @@ TEST_CASE_FIXTURE(Fixture, "stringifying_table_type_correctly_use_matching_table TEST_CASE_FIXTURE(Fixture, "stringifying_cyclic_union_type_bails_early") { - TypeVar tv{UnionTypeVar{{typeChecker.stringType, typeChecker.numberType}}}; - UnionTypeVar* utv = getMutable(&tv); + Type tv{UnionType{{typeChecker.stringType, typeChecker.numberType}}}; + UnionType* utv = getMutable(&tv); utv->options.push_back(&tv); utv->options.push_back(&tv); @@ -338,8 +363,8 @@ TEST_CASE_FIXTURE(Fixture, "stringifying_cyclic_union_type_bails_early") TEST_CASE_FIXTURE(Fixture, "stringifying_cyclic_intersection_type_bails_early") { - TypeVar tv{IntersectionTypeVar{}}; - IntersectionTypeVar* itv = getMutable(&tv); + Type tv{IntersectionType{}}; + IntersectionType* itv = getMutable(&tv); itv->parts.push_back(&tv); itv->parts.push_back(&tv); @@ -348,17 +373,17 @@ TEST_CASE_FIXTURE(Fixture, "stringifying_cyclic_intersection_type_bails_early") TEST_CASE_FIXTURE(Fixture, "stringifying_array_uses_array_syntax") { - TableTypeVar ttv{TableState::Sealed, TypeLevel{}}; + TableType ttv{TableState::Sealed, TypeLevel{}}; ttv.indexer = TableIndexer{typeChecker.numberType, typeChecker.stringType}; - CHECK_EQ("{string}", toString(TypeVar{ttv})); + CHECK_EQ("{string}", toString(Type{ttv})); ttv.props["A"] = {typeChecker.numberType}; - CHECK_EQ("{| [number]: string, A: number |}", toString(TypeVar{ttv})); + CHECK_EQ("{| [number]: string, A: number |}", toString(Type{ttv})); ttv.props.clear(); ttv.state = TableState::Unsealed; - CHECK_EQ("{string}", toString(TypeVar{ttv})); + CHECK_EQ("{string}", toString(Type{ttv})); } @@ -367,7 +392,7 @@ TEST_CASE_FIXTURE(Fixture, "generic_packs_are_stringified_differently_from_gener TypePackVar tpv{GenericTypePack{"a"}}; CHECK_EQ(toString(&tpv), "a..."); - TypeVar tv{GenericTypeVar{"a"}}; + Type tv{GenericType{"a"}}; CHECK_EQ(toString(&tv), "a"); } @@ -444,11 +469,11 @@ TEST_CASE_FIXTURE(Fixture, "toStringDetailed") TypeId id3Type = requireType("id3"); ToStringResult nameData = toStringDetailed(id3Type, opts); - REQUIRE(3 == opts.nameMap.typeVars.size()); + REQUIRE(3 == opts.nameMap.types.size()); REQUIRE_EQ("(a, b, c) -> (a, b, c)", nameData.name); - const FunctionTypeVar* ftv = get(follow(id3Type)); + const FunctionType* ftv = get(follow(id3Type)); REQUIRE(ftv != nullptr); auto params = flatten(ftv->argTypes).first; @@ -483,27 +508,27 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "toStringDetailed2") TypeId tType = requireType("inst"); ToStringResult r = toStringDetailed(tType, opts); CHECK_EQ("{ @metatable { __index: { @metatable {| __index: base |}, child } }, inst }", r.name); - CHECK(0 == opts.nameMap.typeVars.size()); + CHECK(0 == opts.nameMap.types.size()); - const MetatableTypeVar* tMeta = get(tType); + const MetatableType* tMeta = get(tType); REQUIRE(tMeta); - TableTypeVar* tMeta2 = getMutable(tMeta->metatable); + TableType* tMeta2 = getMutable(tMeta->metatable); REQUIRE(tMeta2); REQUIRE(tMeta2->props.count("__index")); - const MetatableTypeVar* tMeta3 = get(tMeta2->props["__index"].type); + const MetatableType* tMeta3 = get(tMeta2->props["__index"].type); REQUIRE(tMeta3); - TableTypeVar* tMeta4 = getMutable(tMeta3->metatable); + TableType* tMeta4 = getMutable(tMeta3->metatable); REQUIRE(tMeta4); REQUIRE(tMeta4->props.count("__index")); - TableTypeVar* tMeta5 = getMutable(tMeta4->props["__index"].type); + TableType* tMeta5 = getMutable(tMeta4->props["__index"].type); REQUIRE(tMeta5); REQUIRE(tMeta5->props.count("one") > 0); - TableTypeVar* tMeta6 = getMutable(tMeta3->table); + TableType* tMeta6 = getMutable(tMeta3->table); REQUIRE(tMeta6); REQUIRE(tMeta6->props.count("two") > 0); @@ -537,16 +562,16 @@ function foo(a, b) return a(b) end TEST_CASE_FIXTURE(Fixture, "toString_the_boundTo_table_type_contained_within_a_TypePack") { - TypeVar tv1{TableTypeVar{}}; - TableTypeVar* ttv = getMutable(&tv1); + Type tv1{TableType{}}; + TableType* ttv = getMutable(&tv1); ttv->state = TableState::Sealed; ttv->props["hello"] = {typeChecker.numberType}; ttv->props["world"] = {typeChecker.numberType}; TypePackVar tpv1{TypePack{{&tv1}}}; - TypeVar tv2{TableTypeVar{}}; - TableTypeVar* bttv = getMutable(&tv2); + Type tv2{TableType{}}; + TableType* bttv = getMutable(&tv2); bttv->state = TableState::Free; bttv->props["hello"] = {typeChecker.numberType}; bttv->boundTo = &tv1; @@ -560,12 +585,12 @@ TEST_CASE_FIXTURE(Fixture, "toString_the_boundTo_table_type_contained_within_a_T TEST_CASE_FIXTURE(Fixture, "no_parentheses_around_return_type_if_pack_has_an_empty_head_link") { TypeArena arena; - TypePackId realTail = arena.addTypePack({singletonTypes->stringType}); + TypePackId realTail = arena.addTypePack({builtinTypes->stringType}); TypePackId emptyTail = arena.addTypePack({}, realTail); - TypePackId argList = arena.addTypePack({singletonTypes->stringType}); + TypePackId argList = arena.addTypePack({builtinTypes->stringType}); - TypeId functionType = arena.addType(FunctionTypeVar{argList, emptyTail}); + TypeId functionType = arena.addType(FunctionType{argList, emptyTail}); CHECK("(string) -> string" == toString(functionType)); } @@ -597,8 +622,8 @@ TEST_CASE_FIXTURE(Fixture, "no_parentheses_around_cyclic_function_type_in_inters TEST_CASE_FIXTURE(Fixture, "self_recursive_instantiated_param") { - TypeVar tableTy{TableTypeVar{}}; - TableTypeVar* ttv = getMutable(&tableTy); + Type tableTy{TableType{}}; + TableType* ttv = getMutable(&tableTy); ttv->name = "Table"; ttv->instantiatedTypeParams.push_back(&tableTy); @@ -612,7 +637,7 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_id") )"); TypeId ty = requireType("id"); - const FunctionTypeVar* ftv = get(follow(ty)); + const FunctionType* ftv = get(follow(ty)); CHECK_EQ("id(x: a): a", toStringNamedFunction("id", *ftv)); } @@ -630,7 +655,7 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_map") )"); TypeId ty = requireType("map"); - const FunctionTypeVar* ftv = get(follow(ty)); + const FunctionType* ftv = get(follow(ty)); CHECK_EQ("map(arr: {a}, fn: (a) -> b): {b}", toStringNamedFunction("map", *ftv)); } @@ -646,7 +671,7 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_generic_pack") )"); TypeId ty = requireType("test"); - const FunctionTypeVar* ftv = get(follow(ty)); + const FunctionType* ftv = get(follow(ty)); CHECK_EQ("test(...: T...): U...", toStringNamedFunction("test", *ftv)); } @@ -654,7 +679,7 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_generic_pack") TEST_CASE("toStringNamedFunction_unit_f") { TypePackVar empty{TypePack{}}; - FunctionTypeVar ftv{&empty, &empty, {}, false}; + FunctionType ftv{&empty, &empty, {}, false}; CHECK_EQ("f(): ()", toStringNamedFunction("f", ftv)); } @@ -667,7 +692,7 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_variadics") )"); TypeId ty = requireType("f"); - auto ftv = get(follow(ty)); + auto ftv = get(follow(ty)); CHECK_EQ("f(x: a, ...: any): (a, a, b...)", toStringNamedFunction("f", *ftv)); } @@ -681,7 +706,7 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_variadics2") )"); TypeId ty = requireType("f"); - auto ftv = get(follow(ty)); + auto ftv = get(follow(ty)); CHECK_EQ("f(): ...number", toStringNamedFunction("f", *ftv)); } @@ -695,7 +720,7 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_variadics3") )"); TypeId ty = requireType("f"); - auto ftv = get(follow(ty)); + auto ftv = get(follow(ty)); CHECK_EQ("f(): (string, ...number)", toStringNamedFunction("f", *ftv)); } @@ -707,7 +732,7 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_type_annotation_has_partial_ar )"); TypeId ty = requireType("f"); - auto ftv = get(follow(ty)); + auto ftv = get(follow(ty)); CHECK_EQ("f(_: number, y: number): number", toStringNamedFunction("f", *ftv)); } @@ -720,7 +745,7 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_hide_type_params") )"); TypeId ty = requireType("f"); - auto ftv = get(follow(ty)); + auto ftv = get(follow(ty)); ToStringOptions opts; opts.hideNamedFunctionTypeParameters = true; @@ -734,7 +759,7 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_overrides_param_names") )"); TypeId ty = requireType("test"); - const FunctionTypeVar* ftv = get(follow(ty)); + const FunctionType* ftv = get(follow(ty)); ToStringOptions opts; opts.namedFunctionOverrideArgNames = {"first", "second", "third"}; @@ -763,8 +788,8 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_include_self_param") )"); TypeId parentTy = requireType("foo"); - auto ttv = get(follow(parentTy)); - auto ftv = get(ttv->props.at("method").type); + auto ttv = get(follow(parentTy)); + auto ftv = get(ttv->props.at("method").type); CHECK_EQ("foo:method(self: a, arg: string): ()", toStringNamedFunction("foo:method", *ftv)); } @@ -782,8 +807,8 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_hide_self_param") )"); TypeId parentTy = requireType("foo"); - auto ttv = get(follow(parentTy)); - auto ftv = get(ttv->props.at("method").type); + auto ttv = get(follow(parentTy)); + auto ftv = get(ttv->props.at("method").type); ToStringOptions opts; opts.hideFunctionSelfArgument = true; diff --git a/tests/Transpiler.test.cpp b/tests/Transpiler.test.cpp index e79bc9b76..bcaf088ee 100644 --- a/tests/Transpiler.test.cpp +++ b/tests/Transpiler.test.cpp @@ -2,7 +2,7 @@ #include "Luau/Parser.h" #include "Luau/TypeAttach.h" #include "Luau/TypeInfer.h" -#include "Luau/TypeVar.h" +#include "Luau/Type.h" #include "Luau/Transpiler.h" #include "Fixture.h" diff --git a/tests/TypeInfer.aliases.test.cpp b/tests/TypeInfer.aliases.test.cpp index 38e246c8d..4dd822690 100644 --- a/tests/TypeInfer.aliases.test.cpp +++ b/tests/TypeInfer.aliases.test.cpp @@ -78,8 +78,8 @@ TEST_CASE_FIXTURE(Fixture, "cannot_steal_hoisted_type_alias") Location{{1, 21}, {1, 26}}, getMainSourceModule()->name, TypeMismatch{ - singletonTypes->numberType, - singletonTypes->stringType, + builtinTypes->numberType, + builtinTypes->stringType, }, }); } @@ -89,8 +89,8 @@ TEST_CASE_FIXTURE(Fixture, "cannot_steal_hoisted_type_alias") Location{{1, 8}, {1, 26}}, getMainSourceModule()->name, TypeMismatch{ - singletonTypes->numberType, - singletonTypes->stringType, + builtinTypes->numberType, + builtinTypes->stringType, }, }); } @@ -512,13 +512,13 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "general_require_multi_assign") std::optional aTypeId = lookupName(m->getModuleScope(), "a"); REQUIRE(aTypeId); - const Luau::TableTypeVar* aType = get(follow(*aTypeId)); + const Luau::TableType* aType = get(follow(*aTypeId)); REQUIRE(aType); REQUIRE(aType->props.size() == 2); std::optional bTypeId = lookupName(m->getModuleScope(), "b"); REQUIRE(bTypeId); - const Luau::TableTypeVar* bType = get(follow(*bTypeId)); + const Luau::TableType* bType = get(follow(*bTypeId)); REQUIRE(bType); REQUIRE(bType->props.size() == 3); } @@ -535,7 +535,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "type_alias_import_mutation") else CHECK(toString(ty) == "table"); - const TableTypeVar* ttv = get(ty); + const TableType* ttv = get(ty); REQUIRE(ttv); CHECK(ttv->instantiatedTypeParams.empty()); @@ -554,7 +554,7 @@ type NotCool = Cool REQUIRE(ty); CHECK_EQ(toString(*ty), "Cool"); - const TableTypeVar* ttv = get(*ty); + const TableType* ttv = get(*ty); REQUIRE(ttv); CHECK(ttv->instantiatedTypeParams.empty()); @@ -590,7 +590,7 @@ type Cool = typeof(c) std::optional ty = requireType("c"); REQUIRE(ty); - const TableTypeVar* ttv = get(*ty); + const TableType* ttv = get(*ty); REQUIRE(ttv); CHECK_EQ(ttv->name, "Cool"); } @@ -801,9 +801,9 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "do_not_quantify_unresolved_aliases") } /* - * We keep a cache of type alias onto TypeVar to prevent infinite types from + * We keep a cache of type alias onto Type to prevent infinite types from * being constructed via recursive or corecursive aliases. We have to adjust - * the TypeLevels of those generic TypeVars so that the unifier doesn't think + * the TypeLevels of those generic Types so that the unifier doesn't think * they have improperly leaked out of their scope. */ TEST_CASE_FIXTURE(Fixture, "generic_typevars_are_not_considered_to_escape_their_scope_if_they_are_reused_in_multiple_aliases") @@ -817,7 +817,7 @@ TEST_CASE_FIXTURE(Fixture, "generic_typevars_are_not_considered_to_escape_their_ } /* - * The two-pass alias definition system starts by ascribing a free TypeVar to each alias. It then + * The two-pass alias definition system starts by ascribing a free Type to each alias. It then * circles back to fill in the actual type later on. * * If this free type is unified with something degenerate like `any`, we need to take extra care @@ -913,11 +913,11 @@ TEST_CASE_FIXTURE(Fixture, "report_shadowed_aliases") std::optional t1 = lookupType("MyString"); REQUIRE(t1); - CHECK(isPrim(*t1, PrimitiveTypeVar::String)); + CHECK(isPrim(*t1, PrimitiveType::String)); std::optional t2 = lookupType("string"); REQUIRE(t2); - CHECK(isPrim(*t2, PrimitiveTypeVar::String)); + CHECK(isPrim(*t2, PrimitiveType::String)); } TEST_CASE_FIXTURE(Fixture, "it_is_ok_to_shadow_user_defined_alias") diff --git a/tests/TypeInfer.annotations.test.cpp b/tests/TypeInfer.annotations.test.cpp index b94e1df04..bf66ecbc9 100644 --- a/tests/TypeInfer.annotations.test.cpp +++ b/tests/TypeInfer.annotations.test.cpp @@ -1,7 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/BuiltinDefinitions.h" #include "Luau/TypeInfer.h" -#include "Luau/TypeVar.h" +#include "Luau/Type.h" #include "Fixture.h" @@ -77,7 +77,7 @@ TEST_CASE_FIXTURE(Fixture, "function_return_annotations_are_checked") LUAU_REQUIRE_NO_ERRORS(result); TypeId fiftyType = requireType("fifty"); - const FunctionTypeVar* ftv = get(fiftyType); + const FunctionType* ftv = get(fiftyType); REQUIRE(ftv != nullptr); TypePackId retPack = follow(ftv->retTypes); @@ -182,8 +182,8 @@ TEST_CASE_FIXTURE(Fixture, "table_annotation") )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(PrimitiveTypeVar::Number, getPrimitiveType(follow(requireType("y")))); - CHECK_EQ(PrimitiveTypeVar::String, getPrimitiveType(follow(requireType("z")))); + CHECK_EQ(PrimitiveType::Number, getPrimitiveType(follow(requireType("y")))); + CHECK_EQ(PrimitiveType::String, getPrimitiveType(follow(requireType("z")))); } TEST_CASE_FIXTURE(Fixture, "function_annotation") @@ -196,7 +196,7 @@ TEST_CASE_FIXTURE(Fixture, "function_annotation") dumpErrors(result); TypeId fType = requireType("f"); - const FunctionTypeVar* ftv = get(follow(fType)); + const FunctionType* ftv = get(follow(fType)); REQUIRE(ftv != nullptr); } @@ -208,7 +208,7 @@ TEST_CASE_FIXTURE(Fixture, "function_annotation_with_a_defined_function") )"); TypeId fType = requireType("f"); - const FunctionTypeVar* ftv = get(follow(fType)); + const FunctionType* ftv = get(follow(fType)); REQUIRE(ftv != nullptr); LUAU_REQUIRE_NO_ERRORS(result); @@ -323,13 +323,13 @@ TEST_CASE_FIXTURE(Fixture, "self_referential_type_alias") REQUIRE(res); TypeId oType = follow(res->type); - const TableTypeVar* oTable = get(oType); + const TableType* oTable = get(oType); REQUIRE(oTable); std::optional incr = get(oTable->props, "incr"); REQUIRE(incr); - const FunctionTypeVar* incrFunc = get(incr->type); + const FunctionType* incrFunc = get(incr->type); REQUIRE(incrFunc); std::optional firstArg = first(incrFunc->argTypes); @@ -441,7 +441,7 @@ TEST_CASE_FIXTURE(Fixture, "corecursive_types_error_on_tight_loop") )"); TypeId fType = requireType("aa"); - const AnyTypeVar* ftv = get(follow(fType)); + const AnyType* ftv = get(follow(fType)); REQUIRE(ftv != nullptr); REQUIRE(!result.errors.empty()); } @@ -483,7 +483,7 @@ TEST_CASE_FIXTURE(Fixture, "interface_types_belong_to_interface_arena") std::optional exportsType = first(mod.getModuleScope()->returnType); REQUIRE(exportsType); - TableTypeVar* exportsTable = getMutable(*exportsType); + TableType* exportsTable = getMutable(*exportsType); REQUIRE(exportsTable != nullptr); TypeId n = exportsTable->props["n"].type; @@ -509,7 +509,7 @@ TEST_CASE_FIXTURE(Fixture, "generic_aliases_are_cloned_properly") REQUIRE_EQ(1, array.typeParams.size()); - const TableTypeVar* arrayTable = get(array.type); + const TableType* arrayTable = get(array.type); REQUIRE(arrayTable != nullptr); CHECK_EQ(0, arrayTable->props.size()); @@ -538,7 +538,7 @@ TEST_CASE_FIXTURE(Fixture, "cloned_interface_maintains_pointers_between_definiti std::optional exportsType = first(mod.getModuleScope()->returnType); REQUIRE(exportsType); - TableTypeVar* exportsTable = getMutable(*exportsType); + TableType* exportsTable = getMutable(*exportsType); REQUIRE(exportsTable != nullptr); TypeId aType = exportsTable->props["a"].type; @@ -740,8 +740,8 @@ TEST_CASE_FIXTURE(Fixture, "instantiate_type_fun_should_not_trip_rbxassert") } #if 0 -// This is because, after visiting all nodes in a block, we check if each type alias still points to a FreeTypeVar. -// Doing it that way is wrong, but I also tried to make typeof(x) return a BoundTypeVar, with no luck. +// This is because, after visiting all nodes in a block, we check if each type alias still points to a FreeType. +// Doing it that way is wrong, but I also tried to make typeof(x) return a BoundType, with no luck. // Not important enough to fix today. TEST_CASE_FIXTURE(Fixture, "pulling_a_type_from_value_dont_falsely_create_occurs_check_failed") { @@ -755,7 +755,7 @@ TEST_CASE_FIXTURE(Fixture, "pulling_a_type_from_value_dont_falsely_create_occurs } #endif -TEST_CASE_FIXTURE(Fixture, "occurs_check_on_cyclic_union_typevar") +TEST_CASE_FIXTURE(Fixture, "occurs_check_on_cyclic_union_type") { CheckResult result = check(R"( type T = T | T @@ -767,7 +767,7 @@ TEST_CASE_FIXTURE(Fixture, "occurs_check_on_cyclic_union_typevar") REQUIRE(ocf); } -TEST_CASE_FIXTURE(Fixture, "occurs_check_on_cyclic_intersection_typevar") +TEST_CASE_FIXTURE(Fixture, "occurs_check_on_cyclic_intersection_type") { CheckResult result = check(R"( type T = T & T diff --git a/tests/TypeInfer.anyerror.test.cpp b/tests/TypeInfer.anyerror.test.cpp index 91201812f..9988a1fc5 100644 --- a/tests/TypeInfer.anyerror.test.cpp +++ b/tests/TypeInfer.anyerror.test.cpp @@ -4,8 +4,8 @@ #include "Luau/BuiltinDefinitions.h" #include "Luau/Scope.h" #include "Luau/TypeInfer.h" -#include "Luau/TypeVar.h" -#include "Luau/VisitTypeVar.h" +#include "Luau/Type.h" +#include "Luau/VisitType.h" #include "Fixture.h" @@ -175,7 +175,7 @@ TEST_CASE_FIXTURE(Fixture, "can_get_length_of_any") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(PrimitiveTypeVar::Number, getPrimitiveType(requireType("bar"))); + CHECK_EQ(PrimitiveType::Number, getPrimitiveType(requireType("bar"))); } TEST_CASE_FIXTURE(Fixture, "assign_prop_to_table_by_calling_any_yields_any") @@ -191,7 +191,7 @@ TEST_CASE_FIXTURE(Fixture, "assign_prop_to_table_by_calling_any_yields_any") LUAU_REQUIRE_NO_ERRORS(result); - TableTypeVar* ttv = getMutable(requireType("T")); + TableType* ttv = getMutable(requireType("T")); REQUIRE(ttv); REQUIRE(ttv->props.count("prop")); diff --git a/tests/TypeInfer.builtins.test.cpp b/tests/TypeInfer.builtins.test.cpp index 32e31e16e..6c2d31088 100644 --- a/tests/TypeInfer.builtins.test.cpp +++ b/tests/TypeInfer.builtins.test.cpp @@ -185,7 +185,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "builtin_tables_sealed") )LUA"); TypeId bit32 = requireType("b"); REQUIRE(bit32 != nullptr); - const TableTypeVar* bit32t = get(bit32); + const TableType* bit32t = get(bit32); REQUIRE(bit32t != nullptr); CHECK_EQ(bit32t->state, TableState::Sealed); } @@ -508,7 +508,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "setmetatable_should_not_mutate_persisted_typ LUAU_REQUIRE_ERROR_COUNT(1, result); auto stringType = requireType("string"); - auto ttv = get(stringType); + auto ttv = get(stringType); REQUIRE(ttv); } @@ -915,13 +915,13 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "dont_add_definitions_to_persistent_types") LUAU_REQUIRE_NO_ERRORS(result); TypeId fType = requireType("f"); - const FunctionTypeVar* ftv = get(fType); + const FunctionType* ftv = get(fType); REQUIRE(fType); REQUIRE(fType->persistent); REQUIRE(!ftv->definition); TypeId gType = requireType("g"); - const FunctionTypeVar* gtv = get(gType); + const FunctionType* gtv = get(gType); REQUIRE(gType); REQUIRE(!gType->persistent); REQUIRE(gtv->definition); @@ -1029,9 +1029,9 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "no_persistent_typelevel_change") { TypeId mathTy = requireType(typeChecker.globalScope, "math"); REQUIRE(mathTy); - TableTypeVar* ttv = getMutable(mathTy); + TableType* ttv = getMutable(mathTy); REQUIRE(ttv); - const FunctionTypeVar* ftv = get(ttv->props["frexp"].type); + const FunctionType* ftv = get(ttv->props["frexp"].type); REQUIRE(ftv); auto original = ftv->level; diff --git a/tests/TypeInfer.classes.test.cpp b/tests/TypeInfer.classes.test.cpp index 07dfc33fe..28315b676 100644 --- a/tests/TypeInfer.classes.test.cpp +++ b/tests/TypeInfer.classes.test.cpp @@ -2,7 +2,7 @@ #include "Luau/BuiltinDefinitions.h" #include "Luau/Common.h" #include "Luau/TypeInfer.h" -#include "Luau/TypeVar.h" +#include "Luau/Type.h" #include "Fixture.h" #include "ClassFixture.h" diff --git a/tests/TypeInfer.definitions.test.cpp b/tests/TypeInfer.definitions.test.cpp index 26115046d..93b405c25 100644 --- a/tests/TypeInfer.definitions.test.cpp +++ b/tests/TypeInfer.definitions.test.cpp @@ -1,7 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/BuiltinDefinitions.h" #include "Luau/TypeInfer.h" -#include "Luau/TypeVar.h" +#include "Luau/Type.h" #include "Fixture.h" @@ -294,7 +294,7 @@ TEST_CASE_FIXTURE(Fixture, "definitions_documentation_symbols") REQUIRE(bool(barTy)); CHECK_EQ(barTy->type->documentationSymbol, "@test/globaltype/Bar"); - ClassTypeVar* barClass = getMutable(barTy->type); + ClassType* barClass = getMutable(barTy->type); REQUIRE(bool(barClass)); REQUIRE_EQ(barClass->props.count("prop"), 1); CHECK_EQ(barClass->props["prop"].documentationSymbol, "@test/globaltype/Bar.prop"); @@ -303,7 +303,7 @@ TEST_CASE_FIXTURE(Fixture, "definitions_documentation_symbols") REQUIRE(bool(yBinding)); CHECK_EQ(yBinding->documentationSymbol, "@test/global/y"); - TableTypeVar* yTtv = getMutable(yBinding->typeId); + TableType* yTtv = getMutable(yBinding->typeId); REQUIRE(bool(yTtv)); REQUIRE_EQ(yTtv->props.count("x"), 1); CHECK_EQ(yTtv->props["x"].documentationSymbol, "@test/global/y.x"); diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index 552180401..a97cea21c 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -5,8 +5,8 @@ #include "Luau/Error.h" #include "Luau/Scope.h" #include "Luau/TypeInfer.h" -#include "Luau/TypeVar.h" -#include "Luau/VisitTypeVar.h" +#include "Luau/Type.h" +#include "Luau/VisitType.h" #include "Fixture.h" @@ -23,7 +23,7 @@ TEST_CASE_FIXTURE(Fixture, "tc_function") CheckResult result = check("function five() return 5 end"); LUAU_REQUIRE_NO_ERRORS(result); - const FunctionTypeVar* fiveType = get(requireType("five")); + const FunctionType* fiveType = get(requireType("five")); REQUIRE(fiveType != nullptr); } @@ -64,7 +64,7 @@ TEST_CASE_FIXTURE(Fixture, "infer_return_type") CheckResult result = check("function take_five() return 5 end"); LUAU_REQUIRE_NO_ERRORS(result); - const FunctionTypeVar* takeFiveType = get(requireType("take_five")); + const FunctionType* takeFiveType = get(requireType("take_five")); REQUIRE(takeFiveType != nullptr); std::vector retVec = flatten(takeFiveType->retTypes).first; @@ -132,7 +132,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "vararg_function_is_quantified") auto r = first(getMainModule()->getModuleScope()->returnType); REQUIRE(r); - TableTypeVar* ttv = getMutable(*r); + TableType* ttv = getMutable(*r); REQUIRE(ttv); REQUIRE(ttv->props.count("f")); @@ -389,14 +389,14 @@ TEST_CASE_FIXTURE(Fixture, "local_function") TypeId h = follow(requireType("h")); - const FunctionTypeVar* ftv = get(h); + const FunctionType* ftv = get(h); REQUIRE(ftv != nullptr); std::optional rt = first(ftv->retTypes); REQUIRE(bool(rt)); TypeId retType = follow(*rt); - CHECK_EQ(PrimitiveTypeVar::String, getPrimitiveType(retType)); + CHECK_EQ(PrimitiveType::String, getPrimitiveType(retType)); } TEST_CASE_FIXTURE(Fixture, "func_expr_doesnt_leak_free") @@ -406,11 +406,11 @@ TEST_CASE_FIXTURE(Fixture, "func_expr_doesnt_leak_free") )"); LUAU_REQUIRE_NO_ERRORS(result); - const Luau::FunctionTypeVar* fn = get(requireType("p")); + const Luau::FunctionType* fn = get(requireType("p")); REQUIRE(fn); auto ret = first(fn->retTypes); REQUIRE(ret); - REQUIRE(get(follow(*ret))); + REQUIRE(get(follow(*ret))); } TEST_CASE_FIXTURE(Fixture, "first_argument_can_be_optional") @@ -506,12 +506,12 @@ TEST_CASE_FIXTURE(Fixture, "complicated_return_types_require_an_explicit_annotat LUAU_REQUIRE_NO_ERRORS(result); TypeId ty = requireType("most_of_the_natural_numbers"); - const FunctionTypeVar* functionType = get(ty); + const FunctionType* functionType = get(ty); REQUIRE_MESSAGE(functionType, "Expected function but got " << toString(ty)); std::optional retType = first(functionType->retTypes); REQUIRE(retType); - CHECK(get(*retType)); + CHECK(get(*retType)); } TEST_CASE_FIXTURE(Fixture, "infer_higher_order_function") @@ -524,14 +524,14 @@ TEST_CASE_FIXTURE(Fixture, "infer_higher_order_function") LUAU_REQUIRE_NO_ERRORS(result); - const FunctionTypeVar* ftv = get(requireType("apply")); + const FunctionType* ftv = get(requireType("apply")); REQUIRE(ftv != nullptr); std::vector argVec = flatten(ftv->argTypes).first; REQUIRE_EQ(2, argVec.size()); - const FunctionTypeVar* fType = get(follow(argVec[0])); + const FunctionType* fType = get(follow(argVec[0])); REQUIRE_MESSAGE(fType != nullptr, "Expected a function but got " << toString(argVec[0])); std::vector fArgs = flatten(fType->argTypes).first; @@ -561,14 +561,14 @@ TEST_CASE_FIXTURE(Fixture, "higher_order_function_2") LUAU_REQUIRE_NO_ERRORS(result); - const FunctionTypeVar* ftv = get(requireType("bottomupmerge")); + const FunctionType* ftv = get(requireType("bottomupmerge")); REQUIRE(ftv != nullptr); std::vector argVec = flatten(ftv->argTypes).first; REQUIRE_EQ(6, argVec.size()); - const FunctionTypeVar* fType = get(follow(argVec[0])); + const FunctionType* fType = get(follow(argVec[0])); REQUIRE(fType != nullptr); } @@ -591,14 +591,14 @@ TEST_CASE_FIXTURE(Fixture, "higher_order_function_3") LUAU_REQUIRE_NO_ERRORS(result); - const FunctionTypeVar* ftv = get(requireType("swapTwice")); + const FunctionType* ftv = get(requireType("swapTwice")); REQUIRE(ftv != nullptr); std::vector argVec = flatten(ftv->argTypes).first; REQUIRE_EQ(1, argVec.size()); - const TableTypeVar* argType = get(follow(argVec[0])); + const TableType* argType = get(follow(argVec[0])); REQUIRE(argType != nullptr); CHECK(bool(argType->indexer)); @@ -648,18 +648,18 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "higher_order_function_4") * In other words, comp(arr[x], arr[y]) is well-typed. */ - const FunctionTypeVar* ftv = get(requireType("mergesort")); + const FunctionType* ftv = get(requireType("mergesort")); REQUIRE(ftv != nullptr); std::vector argVec = flatten(ftv->argTypes).first; REQUIRE_EQ(2, argVec.size()); - const TableTypeVar* arg0 = get(follow(argVec[0])); + const TableType* arg0 = get(follow(argVec[0])); REQUIRE(arg0 != nullptr); REQUIRE(bool(arg0->indexer)); - const FunctionTypeVar* arg1 = get(follow(argVec[1])); + const FunctionType* arg1 = get(follow(argVec[1])); REQUIRE(arg1 != nullptr); REQUIRE_EQ(2, size(arg1->argTypes)); @@ -1003,7 +1003,7 @@ TEST_CASE_FIXTURE(Fixture, "no_lossy_function_type") LUAU_REQUIRE_NO_ERRORS(result); TypeId type = requireTypeAtPosition(Position(6, 14)); CHECK_EQ("(tbl, number, number) -> number", toString(type)); - auto ftv = get(type); + auto ftv = get(type); REQUIRE(ftv); CHECK(ftv->hasSelf); } diff --git a/tests/TypeInfer.generics.test.cpp b/tests/TypeInfer.generics.test.cpp index c25f8e5fc..7b4176211 100644 --- a/tests/TypeInfer.generics.test.cpp +++ b/tests/TypeInfer.generics.test.cpp @@ -1,6 +1,6 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/TypeInfer.h" -#include "Luau/TypeVar.h" +#include "Luau/Type.h" #include "Luau/Scope.h" #include @@ -224,7 +224,7 @@ TEST_CASE_FIXTURE(Fixture, "infer_generic_function") LUAU_REQUIRE_NO_ERRORS(result); TypeId idType = requireType("id"); - const FunctionTypeVar* idFun = get(idType); + const FunctionType* idFun = get(idType); REQUIRE(idFun); auto [args, varargs] = flatten(idFun->argTypes); auto [rets, varrets] = flatten(idFun->retTypes); @@ -247,7 +247,7 @@ TEST_CASE_FIXTURE(Fixture, "infer_generic_local_function") LUAU_REQUIRE_NO_ERRORS(result); TypeId idType = requireType("id"); - const FunctionTypeVar* idFun = get(idType); + const FunctionType* idFun = get(idType); REQUIRE(idFun); auto [args, varargs] = flatten(idFun->argTypes); auto [rets, varrets] = flatten(idFun->retTypes); @@ -854,14 +854,14 @@ TEST_CASE_FIXTURE(Fixture, "generic_table_method") LUAU_REQUIRE_NO_ERRORS(result); TypeId tType = requireType("T"); - TableTypeVar* tTable = getMutable(tType); + TableType* tTable = getMutable(tType); REQUIRE(tTable != nullptr); REQUIRE(tTable->props.count("bar")); TypeId barType = tTable->props["bar"].type; REQUIRE(barType != nullptr); - const FunctionTypeVar* ftv = get(follow(barType)); + const FunctionType* ftv = get(follow(barType)); REQUIRE_MESSAGE(ftv != nullptr, "Should be a function: " << *barType); std::vector args = flatten(ftv->argTypes).first; @@ -887,20 +887,20 @@ TEST_CASE_FIXTURE(Fixture, "correctly_instantiate_polymorphic_member_functions") LUAU_REQUIRE_NO_ERRORS(result); dumpErrors(result); - const TableTypeVar* t = get(requireType("T")); + const TableType* t = get(requireType("T")); REQUIRE(t != nullptr); std::optional fooProp = get(t->props, "foo"); REQUIRE(bool(fooProp)); - const FunctionTypeVar* foo = get(follow(fooProp->type)); + const FunctionType* foo = get(follow(fooProp->type)); REQUIRE(bool(foo)); std::optional ret_ = first(foo->retTypes); REQUIRE(bool(ret_)); TypeId ret = follow(*ret_); - REQUIRE_EQ(getPrimitiveType(ret), PrimitiveTypeVar::Number); + REQUIRE_EQ(getPrimitiveType(ret), PrimitiveType::Number); } /* @@ -927,20 +927,20 @@ TEST_CASE_FIXTURE(Fixture, "instantiate_cyclic_generic_function") )"); TypeId g = requireType("g"); - const FunctionTypeVar* gFun = get(g); + const FunctionType* gFun = get(g); REQUIRE(gFun != nullptr); auto optionArg = first(gFun->argTypes); REQUIRE(bool(optionArg)); TypeId arg = follow(*optionArg); - const TableTypeVar* argTable = get(arg); + const TableType* argTable = get(arg); REQUIRE(argTable != nullptr); std::optional methodProp = get(argTable->props, "method"); REQUIRE(bool(methodProp)); - const FunctionTypeVar* methodFunction = get(methodProp->type); + const FunctionType* methodFunction = get(methodProp->type); REQUIRE(methodFunction != nullptr); std::optional methodArg = first(methodFunction->argTypes); diff --git a/tests/TypeInfer.intersectionTypes.test.cpp b/tests/TypeInfer.intersectionTypes.test.cpp index 188be63c7..b57d88202 100644 --- a/tests/TypeInfer.intersectionTypes.test.cpp +++ b/tests/TypeInfer.intersectionTypes.test.cpp @@ -1,6 +1,6 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/TypeInfer.h" -#include "Luau/TypeVar.h" +#include "Luau/Type.h" #include "Fixture.h" @@ -162,14 +162,14 @@ TEST_CASE_FIXTURE(Fixture, "index_on_an_intersection_type_with_property_guarante LUAU_REQUIRE_NO_ERRORS(result); - const IntersectionTypeVar* r = get(requireType("r")); + const IntersectionType* r = get(requireType("r")); REQUIRE(r); - TableTypeVar* a = getMutable(r->parts[0]); + TableType* a = getMutable(r->parts[0]); REQUIRE(a); CHECK_EQ(typeChecker.numberType, a->props["y"].type); - TableTypeVar* b = getMutable(r->parts[1]); + TableType* b = getMutable(r->parts[1]); REQUIRE(b); CHECK_EQ(typeChecker.numberType, b->props["y"].type); } diff --git a/tests/TypeInfer.loops.test.cpp b/tests/TypeInfer.loops.test.cpp index 40912a95f..7a89df96c 100644 --- a/tests/TypeInfer.loops.test.cpp +++ b/tests/TypeInfer.loops.test.cpp @@ -4,8 +4,8 @@ #include "Luau/BuiltinDefinitions.h" #include "Luau/Scope.h" #include "Luau/TypeInfer.h" -#include "Luau/TypeVar.h" -#include "Luau/VisitTypeVar.h" +#include "Luau/Type.h" +#include "Luau/VisitType.h" #include "Fixture.h" diff --git a/tests/TypeInfer.modules.test.cpp b/tests/TypeInfer.modules.test.cpp index b06c80e92..fe52d1682 100644 --- a/tests/TypeInfer.modules.test.cpp +++ b/tests/TypeInfer.modules.test.cpp @@ -4,7 +4,7 @@ #include "Luau/BuiltinDefinitions.h" #include "Luau/Scope.h" #include "Luau/TypeInfer.h" -#include "Luau/TypeVar.h" +#include "Luau/Type.h" #include "Fixture.h" @@ -105,7 +105,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "require_types") REQUIRE(b != nullptr); TypeId hType = requireType(b, "h"); - REQUIRE_MESSAGE(bool(get(hType)), "Expected table but got " << toString(hType)); + REQUIRE_MESSAGE(bool(get(hType)), "Expected table but got " << toString(hType)); } TEST_CASE_FIXTURE(BuiltinsFixture, "require_a_variadic_function") @@ -128,7 +128,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "require_a_variadic_function") TypeId f = follow(requireType(bModule, "f")); - const FunctionTypeVar* ftv = get(f); + const FunctionType* ftv = get(f); REQUIRE(ftv); auto iter = begin(ftv->argTypes); @@ -351,7 +351,7 @@ local arrayops = require(game.A) local tbl = {} tbl.a = 2 function tbl:foo(b: number, c: number) - -- introduce BoundTypeVar to imported type + -- introduce BoundType to imported type arrayops.foo(self._regions) end -- this alias decreases function type level and causes a demotion of its type @@ -376,7 +376,7 @@ local arrayops = require(game.A) local tbl = {} tbl.a = 2 function tbl:foo(b: number, c: number) - -- introduce boundTo TableTypeVar to imported type + -- introduce boundTo TableType to imported type self.x.a = 2 arrayops.foo(self.x) end diff --git a/tests/TypeInfer.negations.test.cpp b/tests/TypeInfer.negations.test.cpp index 0e7fb03de..02350a728 100644 --- a/tests/TypeInfer.negations.test.cpp +++ b/tests/TypeInfer.negations.test.cpp @@ -10,6 +10,7 @@ using namespace Luau; namespace { + struct NegationFixture : Fixture { TypeArena arena; @@ -22,6 +23,7 @@ struct NegationFixture : Fixture registerHiddenTypes(*this, arena); } }; + } // namespace TEST_SUITE_BEGIN("Negations"); diff --git a/tests/TypeInfer.oop.test.cpp b/tests/TypeInfer.oop.test.cpp index 41690704a..088b4d56e 100644 --- a/tests/TypeInfer.oop.test.cpp +++ b/tests/TypeInfer.oop.test.cpp @@ -4,8 +4,8 @@ #include "Luau/BuiltinDefinitions.h" #include "Luau/Scope.h" #include "Luau/TypeInfer.h" -#include "Luau/TypeVar.h" -#include "Luau/VisitTypeVar.h" +#include "Luau/Type.h" +#include "Luau/VisitType.h" #include "Fixture.h" @@ -93,8 +93,8 @@ TEST_CASE_FIXTURE(Fixture, "methods_are_topologically_sorted") LUAU_REQUIRE_NO_ERRORS(result); dumpErrors(result); - CHECK_EQ(PrimitiveTypeVar::Number, getPrimitiveType(requireType("a"))); - CHECK_EQ(PrimitiveTypeVar::String, getPrimitiveType(requireType("b"))); + CHECK_EQ(PrimitiveType::Number, getPrimitiveType(requireType("a"))); + CHECK_EQ(PrimitiveType::String, getPrimitiveType(requireType("b"))); } TEST_CASE_FIXTURE(Fixture, "quantify_methods_defined_using_dot_syntax_and_explicit_self_parameter") @@ -139,7 +139,7 @@ TEST_CASE_FIXTURE(Fixture, "inferring_hundreds_of_self_calls_should_not_suffocat )"); ModulePtr module = getMainModule(); - CHECK_GE(50, module->internalTypes.typeVars.size()); + CHECK_GE(50, module->internalTypes.types.size()); } TEST_CASE_FIXTURE(BuiltinsFixture, "object_constructor_can_refer_to_method_of_self") diff --git a/tests/TypeInfer.operators.test.cpp b/tests/TypeInfer.operators.test.cpp index 93d7361bf..0196666a0 100644 --- a/tests/TypeInfer.operators.test.cpp +++ b/tests/TypeInfer.operators.test.cpp @@ -4,8 +4,8 @@ #include "Luau/BuiltinDefinitions.h" #include "Luau/Scope.h" #include "Luau/TypeInfer.h" -#include "Luau/TypeVar.h" -#include "Luau/VisitTypeVar.h" +#include "Luau/Type.h" +#include "Luau/VisitType.h" #include "Fixture.h" #include "ClassFixture.h" @@ -125,7 +125,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "primitive_arith_no_metatable") LUAU_REQUIRE_NO_ERRORS(result); - const FunctionTypeVar* functionType = get(requireType("add")); + const FunctionType* functionType = get(requireType("add")); std::optional retType = first(functionType->retTypes); REQUIRE(retType.has_value()); @@ -1017,7 +1017,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "mm_ops_must_return_a_value") LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK(requireType("y") == singletonTypes->errorRecoveryType()); + CHECK(requireType("y") == builtinTypes->errorRecoveryType()); const GenericError* ge = get(result.errors[0]); REQUIRE(ge); @@ -1051,8 +1051,8 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "mm_comparisons_must_return_a_boolean") LUAU_REQUIRE_ERROR_COUNT(2, result); - CHECK(requireType("v1") == singletonTypes->booleanType); - CHECK(requireType("v2") == singletonTypes->booleanType); + CHECK(requireType("v1") == builtinTypes->booleanType); + CHECK(requireType("v2") == builtinTypes->booleanType); CHECK(toString(result.errors[0]) == "Metamethod '__lt' must return type 'boolean'"); CHECK(toString(result.errors[1]) == "Metamethod '__lt' must return type 'boolean'"); diff --git a/tests/TypeInfer.primitives.test.cpp b/tests/TypeInfer.primitives.test.cpp index 3c2c8781d..7e99f0b02 100644 --- a/tests/TypeInfer.primitives.test.cpp +++ b/tests/TypeInfer.primitives.test.cpp @@ -4,8 +4,8 @@ #include "Luau/BuiltinDefinitions.h" #include "Luau/Scope.h" #include "Luau/TypeInfer.h" -#include "Luau/TypeVar.h" -#include "Luau/VisitTypeVar.h" +#include "Luau/Type.h" +#include "Luau/VisitType.h" #include "Fixture.h" diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index b7408f876..cf969f2d7 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -456,7 +456,7 @@ TEST_CASE_FIXTURE(Fixture, "dcr_can_partially_dispatch_a_constraint") // // (*blocked*) -> () <: (number) -> (b...) // - // We solve this by searching both types for BlockedTypeVars and block the + // We solve this by searching both types for BlockedTypes and block the // constraint on any we find. It also gets the job done, but I'm worried // about the efficiency of doing so many deep type traversals and it may // make us more prone to getting stuck on constraint cycles. @@ -473,19 +473,19 @@ TEST_CASE_FIXTURE(Fixture, "dcr_can_partially_dispatch_a_constraint") TEST_CASE_FIXTURE(Fixture, "free_options_cannot_be_unified_together") { TypeArena arena; - TypeId nilType = singletonTypes->nilType; + TypeId nilType = builtinTypes->nilType; - std::unique_ptr scope = std::make_unique(singletonTypes->anyTypePack); + std::unique_ptr scope = std::make_unique(builtinTypes->anyTypePack); TypeId free1 = arena.addType(FreeTypePack{scope.get()}); - TypeId option1 = arena.addType(UnionTypeVar{{nilType, free1}}); + TypeId option1 = arena.addType(UnionType{{nilType, free1}}); TypeId free2 = arena.addType(FreeTypePack{scope.get()}); - TypeId option2 = arena.addType(UnionTypeVar{{nilType, free2}}); + TypeId option2 = arena.addType(UnionType{{nilType, free2}}); InternalErrorReporter iceHandler; UnifierSharedState sharedState{&iceHandler}; - Normalizer normalizer{&arena, singletonTypes, NotNull{&sharedState}}; + Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}}; Unifier u{NotNull{&normalizer}, Mode::Strict, NotNull{scope.get()}, Location{}, Variance::Covariant}; u.tryUnify(option1, option2); @@ -550,7 +550,7 @@ return wrapStrictTable(Constants, "Constants") std::optional result = first(m->getModuleScope()->returnType); REQUIRE(result); - CHECK(get(*result)); + CHECK(get(*result)); } TEST_CASE_FIXTURE(BuiltinsFixture, "generic_type_leak_to_module_interface_variadic") @@ -589,7 +589,7 @@ return wrapStrictTable(Constants, "Constants") std::optional result = first(m->getModuleScope()->returnType); REQUIRE(result); - CHECK(get(*result)); + CHECK(get(*result)); } // We need a simplification step to make this do the right thing. ("normalization-lite") @@ -620,7 +620,7 @@ struct IsSubtypeFixture : Fixture { bool isSubtype(TypeId a, TypeId b) { - return ::Luau::isSubtype(a, b, NotNull{getMainModule()->getModuleScope().get()}, singletonTypes, ice); + return ::Luau::isSubtype(a, b, NotNull{getMainModule()->getModuleScope().get()}, builtinTypes, ice); } }; } // namespace diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index 5a7c8432a..f77cacfa9 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -8,6 +8,7 @@ #include "doctest.h" LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) +LUAU_FASTFLAG(LuauNegatedClassTypes) using namespace Luau; @@ -63,30 +64,32 @@ struct RefinementClassFixture : BuiltinsFixture TypeArena& arena = typeChecker.globalTypes; NotNull scope{typeChecker.globalScope.get()}; + std::optional rootSuper = FFlag::LuauNegatedClassTypes ? std::make_optional(typeChecker.builtinTypes->classType) : std::nullopt; + unfreeze(arena); - TypeId vec3 = arena.addType(ClassTypeVar{"Vector3", {}, std::nullopt, std::nullopt, {}, nullptr, "Test"}); - getMutable(vec3)->props = { + TypeId vec3 = arena.addType(ClassType{"Vector3", {}, rootSuper, std::nullopt, {}, nullptr, "Test"}); + getMutable(vec3)->props = { {"X", Property{typeChecker.numberType}}, {"Y", Property{typeChecker.numberType}}, {"Z", Property{typeChecker.numberType}}, }; - TypeId inst = arena.addType(ClassTypeVar{"Instance", {}, std::nullopt, std::nullopt, {}, nullptr, "Test"}); + TypeId inst = arena.addType(ClassType{"Instance", {}, rootSuper, std::nullopt, {}, nullptr, "Test"}); TypePackId isAParams = arena.addTypePack({inst, typeChecker.stringType}); TypePackId isARets = arena.addTypePack({typeChecker.booleanType}); - TypeId isA = arena.addType(FunctionTypeVar{isAParams, isARets}); - getMutable(isA)->magicFunction = magicFunctionInstanceIsA; - getMutable(isA)->dcrMagicRefinement = dcrMagicRefinementInstanceIsA; + TypeId isA = arena.addType(FunctionType{isAParams, isARets}); + getMutable(isA)->magicFunction = magicFunctionInstanceIsA; + getMutable(isA)->dcrMagicRefinement = dcrMagicRefinementInstanceIsA; - getMutable(inst)->props = { + getMutable(inst)->props = { {"Name", Property{typeChecker.stringType}}, {"IsA", Property{isA}}, }; - TypeId folder = typeChecker.globalTypes.addType(ClassTypeVar{"Folder", {}, inst, std::nullopt, {}, nullptr, "Test"}); - TypeId part = typeChecker.globalTypes.addType(ClassTypeVar{"Part", {}, inst, std::nullopt, {}, nullptr, "Test"}); - getMutable(part)->props = { + TypeId folder = typeChecker.globalTypes.addType(ClassType{"Folder", {}, inst, std::nullopt, {}, nullptr, "Test"}); + TypeId part = typeChecker.globalTypes.addType(ClassType{"Part", {}, inst, std::nullopt, {}, nullptr, "Test"}); + getMutable(part)->props = { {"Position", Property{vec3}}, }; diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index c379559dc..e2aa01f95 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -4,7 +4,7 @@ #include "Luau/Frontend.h" #include "Luau/ToString.h" #include "Luau/TypeInfer.h" -#include "Luau/TypeVar.h" +#include "Luau/Type.h" #include "Fixture.h" @@ -27,20 +27,20 @@ TEST_CASE_FIXTURE(Fixture, "basic") CheckResult result = check("local t = {foo = \"bar\", baz = 9, quux = nil}"); LUAU_REQUIRE_NO_ERRORS(result); - const TableTypeVar* tType = get(requireType("t")); + const TableType* tType = get(requireType("t")); REQUIRE(tType != nullptr); std::optional fooProp = get(tType->props, "foo"); REQUIRE(bool(fooProp)); - CHECK_EQ(PrimitiveTypeVar::String, getPrimitiveType(fooProp->type)); + CHECK_EQ(PrimitiveType::String, getPrimitiveType(fooProp->type)); std::optional bazProp = get(tType->props, "baz"); REQUIRE(bool(bazProp)); - CHECK_EQ(PrimitiveTypeVar::Number, getPrimitiveType(bazProp->type)); + CHECK_EQ(PrimitiveType::Number, getPrimitiveType(bazProp->type)); std::optional quuxProp = get(tType->props, "quux"); REQUIRE(bool(quuxProp)); - CHECK_EQ(PrimitiveTypeVar::NilType, getPrimitiveType(quuxProp->type)); + CHECK_EQ(PrimitiveType::NilType, getPrimitiveType(quuxProp->type)); } TEST_CASE_FIXTURE(Fixture, "augment_table") @@ -48,7 +48,7 @@ TEST_CASE_FIXTURE(Fixture, "augment_table") CheckResult result = check("local t = {} t.foo = 'bar'"); LUAU_REQUIRE_NO_ERRORS(result); - const TableTypeVar* tType = get(requireType("t")); + const TableType* tType = get(requireType("t")); REQUIRE(tType != nullptr); CHECK("{ foo: string }" == toString(requireType("t"), {true})); @@ -59,11 +59,11 @@ TEST_CASE_FIXTURE(Fixture, "augment_nested_table") CheckResult result = check("local t = { p = {} } t.p.foo = 'bar'"); LUAU_REQUIRE_NO_ERRORS(result); - TableTypeVar* tType = getMutable(requireType("t")); + TableType* tType = getMutable(requireType("t")); REQUIRE(tType != nullptr); REQUIRE(tType->props.find("p") != tType->props.end()); - const TableTypeVar* pType = get(tType->props["p"].type); + const TableType* pType = get(tType->props["p"].type); REQUIRE(pType != nullptr); CHECK("{ p: { foo: string } }" == toString(requireType("t"), {true})); @@ -142,13 +142,13 @@ TEST_CASE_FIXTURE(Fixture, "tc_member_function") CheckResult result = check("local T = {} function T:foo() return 5 end"); LUAU_REQUIRE_NO_ERRORS(result); - const TableTypeVar* tableType = get(requireType("T")); + const TableType* tableType = get(requireType("T")); REQUIRE(tableType != nullptr); std::optional fooProp = get(tableType->props, "foo"); REQUIRE(bool(fooProp)); - const FunctionTypeVar* methodType = get(follow(fooProp->type)); + const FunctionType* methodType = get(follow(fooProp->type)); REQUIRE(methodType != nullptr); } @@ -157,20 +157,20 @@ TEST_CASE_FIXTURE(Fixture, "tc_member_function_2") CheckResult result = check("local T = {U={}} function T.U:foo() return 5 end"); LUAU_REQUIRE_NO_ERRORS(result); - const TableTypeVar* tableType = get(requireType("T")); + const TableType* tableType = get(requireType("T")); REQUIRE(tableType != nullptr); std::optional uProp = get(tableType->props, "U"); REQUIRE(bool(uProp)); TypeId uType = uProp->type; - const TableTypeVar* uTable = get(uType); + const TableType* uTable = get(uType); REQUIRE(uTable != nullptr); std::optional fooProp = get(uTable->props, "foo"); REQUIRE(bool(fooProp)); - const FunctionTypeVar* methodType = get(follow(fooProp->type)); + const FunctionType* methodType = get(follow(fooProp->type)); REQUIRE(methodType != nullptr); std::vector methodArgs = flatten(methodType->argTypes).first; @@ -324,7 +324,7 @@ TEST_CASE_FIXTURE(Fixture, "open_table_unification_3") )"); TypeId fooType = requireType("foo"); - const FunctionTypeVar* fooFn = get(fooType); + const FunctionType* fooFn = get(fooType); REQUIRE(fooFn != nullptr); std::vector fooArgs = flatten(fooFn->argTypes).first; @@ -332,7 +332,7 @@ TEST_CASE_FIXTURE(Fixture, "open_table_unification_3") REQUIRE_EQ(1, fooArgs.size()); TypeId arg0 = fooArgs[0]; - const TableTypeVar* arg0Table = get(follow(arg0)); + const TableType* arg0Table = get(follow(arg0)); REQUIRE(arg0Table != nullptr); REQUIRE(arg0Table->props.find("bar") != arg0Table->props.end()); @@ -433,7 +433,7 @@ TEST_CASE_FIXTURE(Fixture, "table_param_row_polymorphism_2") std::cout << "Error: " << e << std::endl; TypeId qType = requireType("q"); - const TableTypeVar* qTable = get(qType); + const TableType* qTable = get(qType); REQUIRE(qType != nullptr); CHECK(qTable->props.find("x") != qTable->props.end()); @@ -442,7 +442,7 @@ TEST_CASE_FIXTURE(Fixture, "table_param_row_polymorphism_2") CHECK(qTable->props.find("w") != qTable->props.end()); TypeId wType = requireType("w"); - const TableTypeVar* wTable = get(wType); + const TableType* wTable = get(wType); REQUIRE(wTable != nullptr); CHECK(wTable->props.find("x") != wTable->props.end()); @@ -553,7 +553,7 @@ TEST_CASE_FIXTURE(Fixture, "infer_array") LUAU_REQUIRE_NO_ERRORS(result); - const TableTypeVar* ttv = get(requireType("t")); + const TableType* ttv = get(requireType("t")); REQUIRE(ttv != nullptr); REQUIRE(bool(ttv->indexer)); @@ -603,14 +603,14 @@ TEST_CASE_FIXTURE(Fixture, "indexers_get_quantified_too") LUAU_REQUIRE_NO_ERRORS(result); - const FunctionTypeVar* ftv = get(requireType("swap")); + const FunctionType* ftv = get(requireType("swap")); REQUIRE(ftv != nullptr); std::vector argVec = flatten(ftv->argTypes).first; REQUIRE_EQ(1, argVec.size()); - const TableTypeVar* ttv = get(follow(argVec[0])); + const TableType* ttv = get(follow(argVec[0])); REQUIRE(ttv != nullptr); REQUIRE(bool(ttv->indexer)); @@ -619,7 +619,7 @@ TEST_CASE_FIXTURE(Fixture, "indexers_get_quantified_too") REQUIRE_EQ(indexer.indexType, typeChecker.numberType); - REQUIRE(nullptr != get(follow(indexer.indexResultType))); + REQUIRE(nullptr != get(follow(indexer.indexResultType))); } TEST_CASE_FIXTURE(Fixture, "indexers_quantification_2") @@ -633,19 +633,19 @@ TEST_CASE_FIXTURE(Fixture, "indexers_quantification_2") LUAU_REQUIRE_NO_ERRORS(result); - const FunctionTypeVar* ftv = get(requireType("mergesort")); + const FunctionType* ftv = get(requireType("mergesort")); REQUIRE(ftv != nullptr); std::vector argVec = flatten(ftv->argTypes).first; REQUIRE_EQ(1, argVec.size()); - const TableTypeVar* argType = get(follow(argVec[0])); + const TableType* argType = get(follow(argVec[0])); REQUIRE(argType != nullptr); std::vector retVec = flatten(ftv->retTypes).first; - const TableTypeVar* retType = get(follow(retVec[0])); + const TableType* retType = get(follow(retVec[0])); REQUIRE(retType != nullptr); CHECK_EQ(argType->state, retType->state); @@ -661,7 +661,7 @@ TEST_CASE_FIXTURE(Fixture, "infer_indexer_from_array_like_table") LUAU_REQUIRE_NO_ERRORS(result); - const TableTypeVar* ttv = get(requireType("t")); + const TableType* ttv = get(requireType("t")); REQUIRE(ttv != nullptr); REQUIRE(bool(ttv->indexer)); @@ -689,13 +689,13 @@ TEST_CASE_FIXTURE(Fixture, "infer_indexer_from_value_property_in_literal") LUAU_REQUIRE_NO_ERRORS(result); - const FunctionTypeVar* fType = get(requireType("f")); + const FunctionType* fType = get(requireType("f")); REQUIRE(fType != nullptr); auto retType_ = first(fType->retTypes); REQUIRE(bool(retType_)); - auto retType = get(follow(*retType_)); + auto retType = get(follow(*retType_)); REQUIRE(retType != nullptr); CHECK(bool(retType->indexer)); @@ -718,7 +718,7 @@ TEST_CASE_FIXTURE(Fixture, "infer_indexer_from_its_variable_type_and_unifiable") TypeMismatch* tm = get(result.errors[0]); REQUIRE(tm != nullptr); - const TableTypeVar* tTy = get(requireType("t2")); + const TableType* tTy = get(requireType("t2")); REQUIRE(tTy != nullptr); REQUIRE(tTy->indexer); @@ -742,8 +742,8 @@ TEST_CASE_FIXTURE(Fixture, "indexer_mismatch") TypeMismatch* tm = get(result.errors[0]); REQUIRE(tm != nullptr); - CHECK_EQ(tm->wantedType, t2); - CHECK_EQ(tm->givenType, t1); + CHECK(toString(tm->wantedType) == "{number}"); + CHECK(toString(tm->givenType) == "{| [string]: string |}"); CHECK_NE(*t1, *t2); } @@ -871,7 +871,7 @@ TEST_CASE_FIXTURE(Fixture, "assigning_to_an_unsealed_table_with_string_literal_s CHECK("string" == toString(*typeChecker.stringType)); - TableTypeVar* tableType = getMutable(requireType("t")); + TableType* tableType = getMutable(requireType("t")); REQUIRE(tableType != nullptr); REQUIRE(tableType->indexer == std::nullopt); REQUIRE(0 != tableType->props.count("a")); @@ -998,11 +998,11 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "unification_of_unions_in_a_self_referential_ LUAU_REQUIRE_NO_ERRORS(result); - const MetatableTypeVar* amtv = get(requireType("a")); + const MetatableType* amtv = get(requireType("a")); REQUIRE(amtv); CHECK_EQ(amtv->metatable, requireType("amt")); - const MetatableTypeVar* bmtv = get(requireType("b")); + const MetatableType* bmtv = get(requireType("b")); REQUIRE(bmtv); CHECK_EQ(bmtv->metatable, requireType("bmt")); } @@ -1408,10 +1408,10 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "missing_metatable_for_sealed_tables_do_not_g CHECK_EQ(tm->wantedType, t); CHECK_EQ(tm->givenType, a); - const MetatableTypeVar* aTy = get(a); + const MetatableType* aTy = get(a); REQUIRE(aTy); - const TableTypeVar* tTy = get(t); + const TableType* tTy = get(t); REQUIRE(tTy); } @@ -1621,7 +1621,7 @@ TEST_CASE_FIXTURE(Fixture, "type_mismatch_on_massive_table_is_cut_short") TypeMismatch* tm = get(result.errors[0]); REQUIRE(tm); - CHECK_EQ(requireType("t"), tm->wantedType); + CHECK("{ a: number, b: number, c: number, d: number, e: number, ... 1 more ... }" == toString(requireType("t"))); CHECK_EQ("number", toString(tm->givenType)); CHECK_EQ("Type 'number' could not be converted into '{ a: number, b: number, c: number, d: number, e: number, ... 1 more ... }'", @@ -1755,7 +1755,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "persistent_sealed_table_is_immutable") else CHECK_EQ("Cannot add property 'bad' to table 'os'", toString(result.errors[0])); - const TableTypeVar* osType = get(requireType("os")); + const TableType* osType = get(requireType("os")); REQUIRE(osType != nullptr); CHECK(osType->props.find("bad") == osType->props.end()); } @@ -1865,19 +1865,19 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "quantifying_a_bound_var_works") LUAU_REQUIRE_NO_ERRORS(result); TypeId ty = requireType("clazz"); - TableTypeVar* ttv = getMutable(ty); + TableType* ttv = getMutable(ty); REQUIRE(ttv); REQUIRE(ttv->props.count("new")); Property& prop = ttv->props["new"]; REQUIRE(prop.type); - const FunctionTypeVar* ftv = get(follow(prop.type)); + const FunctionType* ftv = get(follow(prop.type)); REQUIRE(ftv); const TypePack* res = get(follow(ftv->retTypes)); REQUIRE(res); REQUIRE(res->head.size() == 1); - const MetatableTypeVar* mtv = get(follow(res->head[0])); + const MetatableType* mtv = get(follow(res->head[0])); REQUIRE(mtv); - ttv = getMutable(follow(mtv->table)); + ttv = getMutable(follow(mtv->table)); REQUIRE(ttv); REQUIRE_EQ(ttv->state, TableState::Sealed); } @@ -2424,7 +2424,7 @@ TEST_CASE_FIXTURE(Fixture, "table_length") LUAU_REQUIRE_NO_ERRORS(result); - CHECK(nullptr != get(requireType("t"))); + CHECK(nullptr != get(requireType("t"))); CHECK_EQ(*typeChecker.numberType, *requireType("s")); } @@ -2535,13 +2535,13 @@ TEST_CASE_FIXTURE(Fixture, "generalize_table_argument") LUAU_REQUIRE_NO_ERRORS(result); dumpErrors(result); - const FunctionTypeVar* fooType = get(requireType("foo")); + const FunctionType* fooType = get(requireType("foo")); REQUIRE(fooType); std::optional fooArg1 = first(fooType->argTypes); REQUIRE(fooArg1); - const TableTypeVar* fooArg1Table = get(*fooArg1); + const TableType* fooArg1Table = get(*fooArg1); REQUIRE(fooArg1Table); CHECK_EQ(fooArg1Table->state, TableState::Generic); @@ -2549,13 +2549,13 @@ TEST_CASE_FIXTURE(Fixture, "generalize_table_argument") /* * This test case exposed an oversight in the treatment of free tables. - * Free tables, like free TypeVars, need to record the scope depth where they were created so that + * Free tables, like free Types, need to record the scope depth where they were created so that * we do not erroneously let-generalize them when they are used in a nested lambda. * * For more information about let-generalization, see * * The important idea here is that the return type of Counter.new is a table with some metatable. - * That metatable *must* be the same TypeVar as the type of Counter. If it is a copy (produced by + * That metatable *must* be the same Type as the type of Counter. If it is a copy (produced by * the generalization process), then it loses the knowledge that its metatable will have an :incr() * method. */ @@ -2581,20 +2581,20 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "dont_quantify_table_that_belongs_to_outer_sc LUAU_REQUIRE_NO_ERRORS(result); - TableTypeVar* counterType = getMutable(requireType("Counter")); + TableType* counterType = getMutable(requireType("Counter")); REQUIRE(counterType); REQUIRE(counterType->props.count("new")); - const FunctionTypeVar* newType = get(follow(counterType->props["new"].type)); + const FunctionType* newType = get(follow(counterType->props["new"].type)); REQUIRE(newType); std::optional newRetType = *first(newType->retTypes); REQUIRE(newRetType); - const MetatableTypeVar* newRet = get(follow(*newRetType)); + const MetatableType* newRet = get(follow(*newRetType)); REQUIRE(newRet); - const TableTypeVar* newRetMeta = get(newRet->metatable); + const TableType* newRetMeta = get(newRet->metatable); REQUIRE(newRetMeta); CHECK(newRetMeta->props.count("incr")); @@ -2627,7 +2627,7 @@ TEST_CASE_FIXTURE(Fixture, "inferring_crazy_table_should_also_be_quick") )"); ModulePtr module = getMainModule(); - CHECK_GE(100, module->internalTypes.typeVars.size()); + CHECK_GE(100, module->internalTypes.types.size()); } TEST_CASE_FIXTURE(Fixture, "MixedPropertiesAndIndexers") @@ -2703,7 +2703,7 @@ type t0 = any std::optional ty = requireType("math"); REQUIRE(ty); - const TableTypeVar* ttv = get(*ty); + const TableType* ttv = get(*ty); REQUIRE(ttv); CHECK(ttv->instantiatedTypeParams.empty()); } @@ -2720,7 +2720,7 @@ type K = X std::optional ty = requireType("math"); REQUIRE(ty); - const TableTypeVar* ttv = get(*ty); + const TableType* ttv = get(*ty); REQUIRE(ttv); CHECK(ttv->instantiatedTypeParams.empty()); } @@ -2742,7 +2742,7 @@ c = b std::optional ty = requireType("a"); REQUIRE(ty); - const TableTypeVar* ttv = get(*ty); + const TableType* ttv = get(*ty); REQUIRE(ttv); CHECK(ttv->instantiatedTypeParams.empty()); } @@ -2778,7 +2778,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "table_call_metamethod_basic") )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK(requireType("foo") == singletonTypes->numberType); + CHECK(requireType("foo") == builtinTypes->numberType); } TEST_CASE_FIXTURE(BuiltinsFixture, "table_call_metamethod_must_be_callable") @@ -2794,7 +2794,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "table_call_metamethod_must_be_callable") LUAU_REQUIRE_ERROR_COUNT(1, result); CHECK(result.errors[0] == TypeError{ Location{{5, 20}, {5, 21}}, - CannotCallNonFunction{singletonTypes->numberType}, + CannotCallNonFunction{builtinTypes->numberType}, }); } @@ -2812,8 +2812,8 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "table_call_metamethod_generic") )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK(requireType("foo") == singletonTypes->numberType); - CHECK(requireType("bar") == singletonTypes->stringType); + CHECK(requireType("foo") == builtinTypes->numberType); + CHECK(requireType("bar") == builtinTypes->stringType); } TEST_CASE_FIXTURE(BuiltinsFixture, "table_simple_call") diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index e42cea638..f6279fa2c 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -4,8 +4,8 @@ #include "Luau/BuiltinDefinitions.h" #include "Luau/Scope.h" #include "Luau/TypeInfer.h" -#include "Luau/TypeVar.h" -#include "Luau/VisitTypeVar.h" +#include "Luau/Type.h" +#include "Luau/VisitType.h" #include "Fixture.h" #include "ScopedFlags.h" @@ -28,7 +28,7 @@ TEST_CASE_FIXTURE(Fixture, "tc_hello_world") LUAU_REQUIRE_NO_ERRORS(result); TypeId aType = requireType("a"); - CHECK_EQ(getPrimitiveType(aType), PrimitiveTypeVar::Number); + CHECK_EQ(getPrimitiveType(aType), PrimitiveType::Number); } TEST_CASE_FIXTURE(Fixture, "tc_propagation") @@ -37,7 +37,7 @@ TEST_CASE_FIXTURE(Fixture, "tc_propagation") LUAU_REQUIRE_NO_ERRORS(result); TypeId bType = requireType("b"); - CHECK_EQ(getPrimitiveType(bType), PrimitiveTypeVar::Number); + CHECK_EQ(getPrimitiveType(bType), PrimitiveType::Number); } TEST_CASE_FIXTURE(Fixture, "tc_error") @@ -65,7 +65,7 @@ TEST_CASE_FIXTURE(Fixture, "infer_locals_with_nil_value") LUAU_REQUIRE_NO_ERRORS(result); TypeId ty = requireType("f"); - CHECK_EQ(getPrimitiveType(ty), PrimitiveTypeVar::String); + CHECK_EQ(getPrimitiveType(ty), PrimitiveType::String); } TEST_CASE_FIXTURE(Fixture, "infer_locals_via_assignment_from_its_call_site") @@ -213,7 +213,7 @@ TEST_CASE_FIXTURE(Fixture, "crazy_complexity") A:A():A():A():A():A():A():A():A():A():A():A() )"); - std::cout << "OK! Allocated " << typeChecker.typeVars.size() << " typevars" << std::endl; + std::cout << "OK! Allocated " << typeChecker.types.size() << " types" << std::endl; } #endif @@ -294,7 +294,7 @@ TEST_CASE_FIXTURE(Fixture, "exponential_blowup_from_copying_types") // If we're not careful about copying, this ends up with O(2^N) types rather than O(N) // (in this case 5 vs 31). - CHECK_GE(5, module->interfaceTypes.typeVars.size()); + CHECK_GE(5, module->interfaceTypes.types.size()); } // In these tests, a successful parse is required, so we need the parser to return the AST and then we can test the recursion depth limit in type @@ -455,7 +455,7 @@ end )"); } -struct FindFreeTypeVars +struct FindFreeTypes { bool foundOne = false; @@ -487,7 +487,7 @@ TEST_CASE_FIXTURE(Fixture, "tc_after_error_recovery") LUAU_REQUIRE_ERRORS(result); TypeId aType = requireType("a"); - CHECK_EQ(getPrimitiveType(aType), PrimitiveTypeVar::Number); + CHECK_EQ(getPrimitiveType(aType), PrimitiveType::Number); } // Check that type checker knows about error expressions @@ -758,7 +758,7 @@ TEST_CASE_FIXTURE(Fixture, "tc_if_else_expressions1") CheckResult result = check(R"(local a = if true then "true" else "false")"); LUAU_REQUIRE_NO_ERRORS(result); TypeId aType = requireType("a"); - CHECK_EQ(getPrimitiveType(aType), PrimitiveTypeVar::String); + CHECK_EQ(getPrimitiveType(aType), PrimitiveType::String); } TEST_CASE_FIXTURE(Fixture, "tc_if_else_expressions2") @@ -769,7 +769,7 @@ local a = if false then "a" elseif false then "b" else "c" )"); LUAU_REQUIRE_NO_ERRORS(result); TypeId aType = requireType("a"); - CHECK_EQ(getPrimitiveType(aType), PrimitiveTypeVar::String); + CHECK_EQ(getPrimitiveType(aType), PrimitiveType::String); } TEST_CASE_FIXTURE(Fixture, "tc_if_else_expressions_type_union") @@ -854,11 +854,11 @@ TEST_CASE_FIXTURE(Fixture, "tc_interpolated_string_constant_type") * * We had an issue here where the scope for the `if` block here would * have an elevated TypeLevel even though there is no function nesting going on. - * This would result in a free typevar for the type of _ that was much higher than + * This would result in a free type for the type of _ that was much higher than * it should be. This type would be erroneously quantified in the definition of `aaa`. * This in turn caused an ice when evaluating `_()` in the while loop. */ -TEST_CASE_FIXTURE(Fixture, "free_typevars_introduced_within_control_flow_constructs_do_not_get_an_elevated_TypeLevel") +TEST_CASE_FIXTURE(Fixture, "free_types_introduced_within_control_flow_constructs_do_not_get_an_elevated_TypeLevel") { check(R"( --!strict diff --git a/tests/TypeInfer.tryUnify.test.cpp b/tests/TypeInfer.tryUnify.test.cpp index b1abdf7c9..80c7ab579 100644 --- a/tests/TypeInfer.tryUnify.test.cpp +++ b/tests/TypeInfer.tryUnify.test.cpp @@ -1,7 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Scope.h" #include "Luau/TypeInfer.h" -#include "Luau/TypeVar.h" +#include "Luau/Type.h" #include "Fixture.h" @@ -15,7 +15,7 @@ struct TryUnifyFixture : Fixture ScopePtr globalScope{new Scope{arena.addTypePack({TypeId{}})}}; InternalErrorReporter iceHandler; UnifierSharedState unifierState{&iceHandler}; - Normalizer normalizer{&arena, singletonTypes, NotNull{&unifierState}}; + Normalizer normalizer{&arena, builtinTypes, NotNull{&unifierState}}; Unifier state{NotNull{&normalizer}, Mode::Strict, NotNull{globalScope.get()}, Location{}, Variance::Covariant}; }; @@ -23,8 +23,8 @@ TEST_SUITE_BEGIN("TryUnifyTests"); TEST_CASE_FIXTURE(TryUnifyFixture, "primitives_unify") { - TypeVar numberOne{TypeVariant{PrimitiveTypeVar{PrimitiveTypeVar::Number}}}; - TypeVar numberTwo = numberOne; + Type numberOne{TypeVariant{PrimitiveType{PrimitiveType::Number}}}; + Type numberTwo = numberOne; state.tryUnify(&numberTwo, &numberOne); @@ -33,11 +33,11 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "primitives_unify") TEST_CASE_FIXTURE(TryUnifyFixture, "compatible_functions_are_unified") { - TypeVar functionOne{ - TypeVariant{FunctionTypeVar(arena.addTypePack({arena.freshType(globalScope->level)}), arena.addTypePack({typeChecker.numberType}))}}; + Type functionOne{ + TypeVariant{FunctionType(arena.addTypePack({arena.freshType(globalScope->level)}), arena.addTypePack({typeChecker.numberType}))}}; - TypeVar functionTwo{TypeVariant{ - FunctionTypeVar(arena.addTypePack({arena.freshType(globalScope->level)}), arena.addTypePack({arena.freshType(globalScope->level)}))}}; + Type functionTwo{TypeVariant{ + FunctionType(arena.addTypePack({arena.freshType(globalScope->level)}), arena.addTypePack({arena.freshType(globalScope->level)}))}}; state.tryUnify(&functionTwo, &functionOne); CHECK(state.errors.empty()); @@ -50,16 +50,16 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "compatible_functions_are_unified") TEST_CASE_FIXTURE(TryUnifyFixture, "incompatible_functions_are_preserved") { TypePackVar argPackOne{TypePack{{arena.freshType(globalScope->level)}, std::nullopt}}; - TypeVar functionOne{ - TypeVariant{FunctionTypeVar(arena.addTypePack({arena.freshType(globalScope->level)}), arena.addTypePack({typeChecker.numberType}))}}; + Type functionOne{ + TypeVariant{FunctionType(arena.addTypePack({arena.freshType(globalScope->level)}), arena.addTypePack({typeChecker.numberType}))}}; - TypeVar functionOneSaved = functionOne; + Type functionOneSaved = functionOne; TypePackVar argPackTwo{TypePack{{arena.freshType(globalScope->level)}, std::nullopt}}; - TypeVar functionTwo{ - TypeVariant{FunctionTypeVar(arena.addTypePack({arena.freshType(globalScope->level)}), arena.addTypePack({typeChecker.stringType}))}}; + Type functionTwo{ + TypeVariant{FunctionType(arena.addTypePack({arena.freshType(globalScope->level)}), arena.addTypePack({typeChecker.stringType}))}}; - TypeVar functionTwoSaved = functionTwo; + Type functionTwoSaved = functionTwo; state.tryUnify(&functionTwo, &functionOne); CHECK(!state.errors.empty()); @@ -70,15 +70,15 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "incompatible_functions_are_preserved") TEST_CASE_FIXTURE(TryUnifyFixture, "tables_can_be_unified") { - TypeVar tableOne{TypeVariant{ - TableTypeVar{{{"foo", {arena.freshType(globalScope->level)}}}, std::nullopt, globalScope->level, TableState::Unsealed}, + Type tableOne{TypeVariant{ + TableType{{{"foo", {arena.freshType(globalScope->level)}}}, std::nullopt, globalScope->level, TableState::Unsealed}, }}; - TypeVar tableTwo{TypeVariant{ - TableTypeVar{{{"foo", {arena.freshType(globalScope->level)}}}, std::nullopt, globalScope->level, TableState::Unsealed}, + Type tableTwo{TypeVariant{ + TableType{{{"foo", {arena.freshType(globalScope->level)}}}, std::nullopt, globalScope->level, TableState::Unsealed}, }}; - CHECK_NE(*getMutable(&tableOne)->props["foo"].type, *getMutable(&tableTwo)->props["foo"].type); + CHECK_NE(*getMutable(&tableOne)->props["foo"].type, *getMutable(&tableTwo)->props["foo"].type); state.tryUnify(&tableTwo, &tableOne); @@ -86,28 +86,28 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "tables_can_be_unified") state.log.commit(); - CHECK_EQ(*getMutable(&tableOne)->props["foo"].type, *getMutable(&tableTwo)->props["foo"].type); + CHECK_EQ(*getMutable(&tableOne)->props["foo"].type, *getMutable(&tableTwo)->props["foo"].type); } TEST_CASE_FIXTURE(TryUnifyFixture, "incompatible_tables_are_preserved") { - TypeVar tableOne{TypeVariant{ - TableTypeVar{{{"foo", {arena.freshType(globalScope->level)}}, {"bar", {typeChecker.numberType}}}, std::nullopt, globalScope->level, + Type tableOne{TypeVariant{ + TableType{{{"foo", {arena.freshType(globalScope->level)}}, {"bar", {typeChecker.numberType}}}, std::nullopt, globalScope->level, TableState::Unsealed}, }}; - TypeVar tableTwo{TypeVariant{ - TableTypeVar{{{"foo", {arena.freshType(globalScope->level)}}, {"bar", {typeChecker.stringType}}}, std::nullopt, globalScope->level, + Type tableTwo{TypeVariant{ + TableType{{{"foo", {arena.freshType(globalScope->level)}}, {"bar", {typeChecker.stringType}}}, std::nullopt, globalScope->level, TableState::Unsealed}, }}; - CHECK_NE(*getMutable(&tableOne)->props["foo"].type, *getMutable(&tableTwo)->props["foo"].type); + CHECK_NE(*getMutable(&tableOne)->props["foo"].type, *getMutable(&tableTwo)->props["foo"].type); state.tryUnify(&tableTwo, &tableOne); CHECK_EQ(1, state.errors.size()); - CHECK_NE(*getMutable(&tableOne)->props["foo"].type, *getMutable(&tableTwo)->props["foo"].type); + CHECK_NE(*getMutable(&tableOne)->props["foo"].type, *getMutable(&tableTwo)->props["foo"].type); } TEST_CASE_FIXTURE(TryUnifyFixture, "uninhabited_intersection_sub_never") @@ -282,11 +282,11 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "free_tail_is_grown_properly") TEST_CASE_FIXTURE(TryUnifyFixture, "recursive_metatable_getmatchtag") { - TypeVar redirect{FreeTypeVar{TypeLevel{}}}; - TypeVar table{TableTypeVar{}}; - TypeVar metatable{MetatableTypeVar{&redirect, &table}}; - redirect = BoundTypeVar{&metatable}; // Now we have a metatable that is recursive on the table type - TypeVar variant{UnionTypeVar{{&metatable, typeChecker.numberType}}}; + Type redirect{FreeType{TypeLevel{}}}; + Type table{TableType{}}; + Type metatable{MetatableType{&redirect, &table}}; + redirect = BoundType{&metatable}; // Now we have a metatable that is recursive on the table type + Type variant{UnionType{{&metatable, typeChecker.numberType}}}; state.tryUnify(&metatable, &variant); } @@ -296,7 +296,7 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "cli_50320_follow_in_any_unification") TypePackVar free{FreeTypePack{TypeLevel{}}}; TypePackVar target{TypePack{}}; - TypeVar func{FunctionTypeVar{&free, &free}}; + Type func{FunctionType{&free, &free}}; state.tryUnify(&free, &target); // Shouldn't assert or error. @@ -305,7 +305,7 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "cli_50320_follow_in_any_unification") TEST_CASE_FIXTURE(TryUnifyFixture, "txnlog_preserves_type_owner") { - TypeId a = arena.addType(TypeVar{FreeTypeVar{TypeLevel{}}}); + TypeId a = arena.addType(Type{FreeType{TypeLevel{}}}); TypeId b = typeChecker.numberType; state.tryUnify(a, b); @@ -329,26 +329,26 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "metatables_unify_against_shape_of_free_table { ScopedFastFlag sff("DebugLuauDeferredConstraintResolution", true); - TableTypeVar::Props freeProps{ + TableType::Props freeProps{ {"foo", {typeChecker.numberType}}, }; - TypeId free = arena.addType(TableTypeVar{freeProps, std::nullopt, TypeLevel{}, TableState::Free}); + TypeId free = arena.addType(TableType{freeProps, std::nullopt, TypeLevel{}, TableState::Free}); - TableTypeVar::Props indexProps{ + TableType::Props indexProps{ {"foo", {typeChecker.stringType}}, }; - TypeId index = arena.addType(TableTypeVar{indexProps, std::nullopt, TypeLevel{}, TableState::Sealed}); + TypeId index = arena.addType(TableType{indexProps, std::nullopt, TypeLevel{}, TableState::Sealed}); - TableTypeVar::Props mtProps{ + TableType::Props mtProps{ {"__index", {index}}, }; - TypeId mt = arena.addType(TableTypeVar{mtProps, std::nullopt, TypeLevel{}, TableState::Sealed}); + TypeId mt = arena.addType(TableType{mtProps, std::nullopt, TypeLevel{}, TableState::Sealed}); - TypeId target = arena.addType(TableTypeVar{TableState::Unsealed, TypeLevel{}}); - TypeId metatable = arena.addType(MetatableTypeVar{target, mt}); + TypeId target = arena.addType(TableType{TableState::Unsealed, TypeLevel{}}); + TypeId metatable = arena.addType(MetatableType{target, mt}); state.tryUnify(metatable, free); state.log.commit(); @@ -369,7 +369,7 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "fuzz_tail_unification_issue") TypePackVar packTmp{TypePack{{typeChecker.anyType}, &variadicAny}}; TypePackVar packSub{TypePack{{typeChecker.anyType, typeChecker.anyType}, &packTmp}}; - TypeVar freeTy{FreeTypeVar{TypeLevel{}}}; + Type freeTy{FreeType{TypeLevel{}}}; TypePackVar freeTp{FreeTypePack{TypeLevel{}}}; TypePackVar packSuper{TypePack{{&freeTy}, &freeTp}}; diff --git a/tests/TypeInfer.typePacks.cpp b/tests/TypeInfer.typePacks.cpp index 0e4074f75..b753d30e9 100644 --- a/tests/TypeInfer.typePacks.cpp +++ b/tests/TypeInfer.typePacks.cpp @@ -1,7 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/BuiltinDefinitions.h" #include "Luau/TypeInfer.h" -#include "Luau/TypeVar.h" +#include "Luau/Type.h" #include "Fixture.h" @@ -21,7 +21,7 @@ TEST_CASE_FIXTURE(Fixture, "infer_multi_return") LUAU_REQUIRE_NO_ERRORS(result); - const FunctionTypeVar* takeTwoType = get(requireType("take_two")); + const FunctionType* takeTwoType = get(requireType("take_two")); REQUIRE(takeTwoType != nullptr); const auto& [returns, tail] = flatten(takeTwoType->retTypes); @@ -68,7 +68,7 @@ TEST_CASE_FIXTURE(Fixture, "last_element_of_return_statement_can_itself_be_a_pac LUAU_REQUIRE_NO_ERRORS(result); dumpErrors(result); - const FunctionTypeVar* takeOneMoreType = get(requireType("take_three")); + const FunctionType* takeOneMoreType = get(requireType("take_three")); REQUIRE(takeOneMoreType != nullptr); const auto& [rets, tail] = flatten(takeOneMoreType->retTypes); @@ -101,10 +101,10 @@ TEST_CASE_FIXTURE(Fixture, "return_type_should_be_empty_if_nothing_is_returned") function g() return end )"); LUAU_REQUIRE_NO_ERRORS(result); - const FunctionTypeVar* fTy = get(requireType("f")); + const FunctionType* fTy = get(requireType("f")); REQUIRE(fTy != nullptr); CHECK_EQ(0, size(fTy->retTypes)); - const FunctionTypeVar* gTy = get(requireType("g")); + const FunctionType* gTy = get(requireType("g")); REQUIRE(gTy != nullptr); CHECK_EQ(0, size(gTy->retTypes)); } @@ -121,15 +121,15 @@ TEST_CASE_FIXTURE(Fixture, "no_return_size_should_be_zero") )"); LUAU_REQUIRE_NO_ERRORS(result); - const FunctionTypeVar* fTy = get(requireType("f")); + const FunctionType* fTy = get(requireType("f")); REQUIRE(fTy != nullptr); CHECK_EQ(1, size(follow(fTy->retTypes))); - const FunctionTypeVar* gTy = get(requireType("g")); + const FunctionType* gTy = get(requireType("g")); REQUIRE(gTy != nullptr); CHECK_EQ(0, size(gTy->retTypes)); - const FunctionTypeVar* hTy = get(requireType("h")); + const FunctionType* hTy = get(requireType("h")); REQUIRE(hTy != nullptr); CHECK_EQ(0, size(hTy->retTypes)); } @@ -194,7 +194,7 @@ TEST_CASE_FIXTURE(Fixture, "variadic_packs") // clang-format off addGlobalBinding(frontend, "foo", arena.addType( - FunctionTypeVar{ + FunctionType{ listOfNumbers, arena.addTypePack({typeChecker.numberType}) } @@ -203,7 +203,7 @@ TEST_CASE_FIXTURE(Fixture, "variadic_packs") ); addGlobalBinding(frontend, "bar", arena.addType( - FunctionTypeVar{ + FunctionType{ arena.addTypePack({{typeChecker.numberType}, listOfStrings}), arena.addTypePack({typeChecker.numberType}) } @@ -306,7 +306,7 @@ local c: Packed CHECK_EQ(toString(*tf), "Packed"); CHECK_EQ(toString(*tf, {true}), "{| f: (T, U...) -> (T, U...) |}"); - auto ttvA = get(requireType("a")); + auto ttvA = get(requireType("a")); REQUIRE(ttvA); CHECK_EQ(toString(requireType("a")), "Packed"); CHECK_EQ(toString(requireType("a"), {true}), "{| f: (number) -> number |}"); @@ -315,7 +315,7 @@ local c: Packed CHECK_EQ(toString(ttvA->instantiatedTypeParams[0], {true}), "number"); CHECK_EQ(toString(ttvA->instantiatedTypePackParams[0], {true}), ""); - auto ttvB = get(requireType("b")); + auto ttvB = get(requireType("b")); REQUIRE(ttvB); CHECK_EQ(toString(requireType("b")), "Packed"); CHECK_EQ(toString(requireType("b"), {true}), "{| f: (string, number) -> (string, number) |}"); @@ -324,7 +324,7 @@ local c: Packed CHECK_EQ(toString(ttvB->instantiatedTypeParams[0], {true}), "string"); CHECK_EQ(toString(ttvB->instantiatedTypePackParams[0], {true}), "number"); - auto ttvC = get(requireType("c")); + auto ttvC = get(requireType("c")); REQUIRE(ttvC); CHECK_EQ(toString(requireType("c")), "Packed"); CHECK_EQ(toString(requireType("c"), {true}), "{| f: (string, number, boolean) -> (string, number, boolean) |}"); diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp index adfc61b63..d30220953 100644 --- a/tests/TypeInfer.unionTypes.test.cpp +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -1,6 +1,6 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/TypeInfer.h" -#include "Luau/TypeVar.h" +#include "Luau/Type.h" #include "Fixture.h" @@ -26,7 +26,7 @@ TEST_CASE_FIXTURE(Fixture, "return_types_can_be_disjoint") LUAU_REQUIRE_NO_ERRORS(result); - const FunctionTypeVar* utv = get(requireType("most_of_the_natural_numbers")); + const FunctionType* utv = get(requireType("most_of_the_natural_numbers")); REQUIRE(utv != nullptr); } diff --git a/tests/TypePack.test.cpp b/tests/TypePack.test.cpp index 1087a24c8..20404434a 100644 --- a/tests/TypePack.test.cpp +++ b/tests/TypePack.test.cpp @@ -1,6 +1,6 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/TypeInfer.h" -#include "Luau/TypeVar.h" +#include "Luau/Type.h" #include "Fixture.h" @@ -14,10 +14,10 @@ struct TypePackFixture { TypePackFixture() { - typeVars.emplace_back(new TypeVar(PrimitiveTypeVar(PrimitiveTypeVar::NilType))); - typeVars.emplace_back(new TypeVar(PrimitiveTypeVar(PrimitiveTypeVar::Boolean))); - typeVars.emplace_back(new TypeVar(PrimitiveTypeVar(PrimitiveTypeVar::Number))); - typeVars.emplace_back(new TypeVar(PrimitiveTypeVar(PrimitiveTypeVar::String))); + typeVars.emplace_back(new Type(PrimitiveType(PrimitiveType::NilType))); + typeVars.emplace_back(new Type(PrimitiveType(PrimitiveType::Boolean))); + typeVars.emplace_back(new Type(PrimitiveType(PrimitiveType::Number))); + typeVars.emplace_back(new Type(PrimitiveType(PrimitiveType::String))); for (const auto& ptr : typeVars) types.push_back(ptr.get()); @@ -37,7 +37,7 @@ struct TypePackFixture std::vector> typePacks; - std::vector> typeVars; + std::vector> typeVars; std::vector types; }; @@ -54,7 +54,7 @@ TEST_CASE_FIXTURE(TypePackFixture, "type_pack_hello") TEST_CASE_FIXTURE(TypePackFixture, "first_chases_Bound_TypePackVars") { - TypeVar nilType{PrimitiveTypeVar{PrimitiveTypeVar::NilType}}; + Type nilType{PrimitiveType{PrimitiveType::NilType}}; auto tp1 = TypePackVar{TypePack{{&nilType}, std::nullopt}}; @@ -206,7 +206,7 @@ TEST_CASE("content_reassignment") TypePackId futureError = arena.addTypePack(TypePackVar{FreeTypePack{TypeLevel{}}}); asMutable(futureError)->reassign(myError); - CHECK(get(futureError) != nullptr); + CHECK(get(futureError) != nullptr); CHECK(!futureError->persistent); CHECK(futureError->owningArena == &arena); } diff --git a/tests/TypeVar.test.cpp b/tests/TypeVar.test.cpp index 5dd1b1bcc..ec0a2473c 100644 --- a/tests/TypeVar.test.cpp +++ b/tests/TypeVar.test.cpp @@ -1,8 +1,8 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Scope.h" #include "Luau/TypeInfer.h" -#include "Luau/TypeVar.h" -#include "Luau/VisitTypeVar.h" +#include "Luau/Type.h" +#include "Luau/VisitType.h" #include "Fixture.h" #include "ScopedFlags.h" @@ -11,7 +11,7 @@ using namespace Luau; -TEST_SUITE_BEGIN("TypeVarTests"); +TEST_SUITE_BEGIN("TypeTests"); TEST_CASE_FIXTURE(Fixture, "primitives_are_equal") { @@ -20,18 +20,18 @@ TEST_CASE_FIXTURE(Fixture, "primitives_are_equal") TEST_CASE_FIXTURE(Fixture, "bound_type_is_equal_to_that_which_it_is_bound") { - TypeVar bound(BoundTypeVar(typeChecker.booleanType)); + Type bound(BoundType(typeChecker.booleanType)); REQUIRE_EQ(bound, *typeChecker.booleanType); } TEST_CASE_FIXTURE(Fixture, "equivalent_cyclic_tables_are_equal") { - TypeVar cycleOne{TypeVariant(TableTypeVar())}; - TableTypeVar* tableOne = getMutable(&cycleOne); + Type cycleOne{TypeVariant(TableType())}; + TableType* tableOne = getMutable(&cycleOne); tableOne->props["self"] = {&cycleOne}; - TypeVar cycleTwo{TypeVariant(TableTypeVar())}; - TableTypeVar* tableTwo = getMutable(&cycleTwo); + Type cycleTwo{TypeVariant(TableType())}; + TableType* tableTwo = getMutable(&cycleTwo); tableTwo->props["self"] = {&cycleTwo}; CHECK_EQ(cycleOne, cycleTwo); @@ -39,12 +39,12 @@ TEST_CASE_FIXTURE(Fixture, "equivalent_cyclic_tables_are_equal") TEST_CASE_FIXTURE(Fixture, "different_cyclic_tables_are_not_equal") { - TypeVar cycleOne{TypeVariant(TableTypeVar())}; - TableTypeVar* tableOne = getMutable(&cycleOne); + Type cycleOne{TypeVariant(TableType())}; + TableType* tableOne = getMutable(&cycleOne); tableOne->props["self"] = {&cycleOne}; - TypeVar cycleTwo{TypeVariant(TableTypeVar())}; - TableTypeVar* tableTwo = getMutable(&cycleTwo); + Type cycleTwo{TypeVariant(TableType())}; + TableType* tableTwo = getMutable(&cycleTwo); tableTwo->props["this"] = {&cycleTwo}; CHECK_NE(cycleOne, cycleTwo); @@ -54,7 +54,7 @@ TEST_CASE_FIXTURE(Fixture, "return_type_of_function_is_not_parenthesized_if_just { auto emptyArgumentPack = TypePackVar{TypePack{}}; auto returnPack = TypePackVar{TypePack{{typeChecker.numberType}}}; - auto returnsTwo = TypeVar(FunctionTypeVar(typeChecker.globalScope->level, &emptyArgumentPack, &returnPack)); + auto returnsTwo = Type(FunctionType(typeChecker.globalScope->level, &emptyArgumentPack, &returnPack)); std::string res = toString(&returnsTwo); CHECK_EQ("() -> number", res); @@ -64,7 +64,7 @@ TEST_CASE_FIXTURE(Fixture, "return_type_of_function_is_parenthesized_if_not_just { auto emptyArgumentPack = TypePackVar{TypePack{}}; auto returnPack = TypePackVar{TypePack{{typeChecker.numberType, typeChecker.numberType}}}; - auto returnsTwo = TypeVar(FunctionTypeVar(typeChecker.globalScope->level, &emptyArgumentPack, &returnPack)); + auto returnsTwo = Type(FunctionType(typeChecker.globalScope->level, &emptyArgumentPack, &returnPack)); std::string res = toString(&returnsTwo); CHECK_EQ("() -> (number, number)", res); @@ -76,7 +76,7 @@ TEST_CASE_FIXTURE(Fixture, "return_type_of_function_is_parenthesized_if_tail_is_ auto free = Unifiable::Free(TypeLevel()); auto freePack = TypePackVar{TypePackVariant{free}}; auto returnPack = TypePackVar{TypePack{{typeChecker.numberType}, &freePack}}; - auto returnsTwo = TypeVar(FunctionTypeVar(typeChecker.globalScope->level, &emptyArgumentPack, &returnPack)); + auto returnsTwo = Type(FunctionType(typeChecker.globalScope->level, &emptyArgumentPack, &returnPack)); std::string res = toString(&returnsTwo); CHECK_EQ(res, "() -> (number, a...)"); @@ -84,7 +84,7 @@ TEST_CASE_FIXTURE(Fixture, "return_type_of_function_is_parenthesized_if_tail_is_ TEST_CASE_FIXTURE(Fixture, "subset_check") { - UnionTypeVar super, sub, notSub; + UnionType super, sub, notSub; super.options = {typeChecker.numberType, typeChecker.stringType, typeChecker.booleanType}; sub.options = {typeChecker.numberType, typeChecker.stringType}; notSub.options = {typeChecker.numberType, typeChecker.nilType}; @@ -93,9 +93,9 @@ TEST_CASE_FIXTURE(Fixture, "subset_check") CHECK(!isSubset(super, notSub)); } -TEST_CASE_FIXTURE(Fixture, "iterate_over_UnionTypeVar") +TEST_CASE_FIXTURE(Fixture, "iterate_over_UnionType") { - UnionTypeVar utv; + UnionType utv; utv.options = {typeChecker.numberType, typeChecker.stringType, typeChecker.anyType}; std::vector result; @@ -105,13 +105,13 @@ TEST_CASE_FIXTURE(Fixture, "iterate_over_UnionTypeVar") CHECK(result == utv.options); } -TEST_CASE_FIXTURE(Fixture, "iterating_over_nested_UnionTypeVars") +TEST_CASE_FIXTURE(Fixture, "iterating_over_nested_UnionTypes") { - TypeVar subunion{UnionTypeVar{}}; - UnionTypeVar* innerUtv = getMutable(&subunion); + Type subunion{UnionType{}}; + UnionType* innerUtv = getMutable(&subunion); innerUtv->options = {typeChecker.numberType, typeChecker.stringType}; - UnionTypeVar utv; + UnionType utv; utv.options = {typeChecker.anyType, &subunion}; std::vector result; @@ -124,13 +124,13 @@ TEST_CASE_FIXTURE(Fixture, "iterating_over_nested_UnionTypeVars") CHECK_EQ(result[1], typeChecker.numberType); } -TEST_CASE_FIXTURE(Fixture, "iterator_detects_cyclic_UnionTypeVars_and_skips_over_them") +TEST_CASE_FIXTURE(Fixture, "iterator_detects_cyclic_UnionTypes_and_skips_over_them") { - TypeVar atv{UnionTypeVar{}}; - UnionTypeVar* utv1 = getMutable(&atv); + Type atv{UnionType{}}; + UnionType* utv1 = getMutable(&atv); - TypeVar btv{UnionTypeVar{}}; - UnionTypeVar* utv2 = getMutable(&btv); + Type btv{UnionType{}}; + UnionType* utv2 = getMutable(&btv); utv2->options.push_back(typeChecker.numberType); utv2->options.push_back(typeChecker.stringType); utv2->options.push_back(&atv); @@ -148,9 +148,9 @@ TEST_CASE_FIXTURE(Fixture, "iterator_detects_cyclic_UnionTypeVars_and_skips_over TEST_CASE_FIXTURE(Fixture, "iterator_descends_on_nested_in_first_operator*") { - TypeVar tv1{UnionTypeVar{{typeChecker.stringType, typeChecker.numberType}}}; - TypeVar tv2{UnionTypeVar{{&tv1, typeChecker.booleanType}}}; - auto utv = get(&tv2); + Type tv1{UnionType{{typeChecker.stringType, typeChecker.numberType}}}; + Type tv2{UnionType{{&tv1, typeChecker.booleanType}}}; + auto utv = get(&tv2); std::vector result; for (TypeId ty : utv) @@ -162,30 +162,30 @@ TEST_CASE_FIXTURE(Fixture, "iterator_descends_on_nested_in_first_operator*") CHECK_EQ(result[2], typeChecker.booleanType); } -TEST_CASE_FIXTURE(Fixture, "UnionTypeVarIterator_with_vector_iter_ctor") +TEST_CASE_FIXTURE(Fixture, "UnionTypeIterator_with_vector_iter_ctor") { - TypeVar tv1{UnionTypeVar{{typeChecker.stringType, typeChecker.numberType}}}; - TypeVar tv2{UnionTypeVar{{&tv1, typeChecker.booleanType}}}; - auto utv = get(&tv2); + Type tv1{UnionType{{typeChecker.stringType, typeChecker.numberType}}}; + Type tv2{UnionType{{&tv1, typeChecker.booleanType}}}; + auto utv = get(&tv2); std::vector actual(begin(utv), end(utv)); std::vector expected{typeChecker.stringType, typeChecker.numberType, typeChecker.booleanType}; CHECK_EQ(actual, expected); } -TEST_CASE_FIXTURE(Fixture, "UnionTypeVarIterator_with_empty_union") +TEST_CASE_FIXTURE(Fixture, "UnionTypeIterator_with_empty_union") { - TypeVar tv{UnionTypeVar{}}; - auto utv = get(&tv); + Type tv{UnionType{}}; + auto utv = get(&tv); std::vector actual(begin(utv), end(utv)); CHECK(actual.empty()); } -TEST_CASE_FIXTURE(Fixture, "UnionTypeVarIterator_with_only_cyclic_union") +TEST_CASE_FIXTURE(Fixture, "UnionTypeIterator_with_only_cyclic_union") { - TypeVar tv{UnionTypeVar{}}; - auto utv = getMutable(&tv); + Type tv{UnionType{}}; + auto utv = getMutable(&tv); utv->options.push_back(&tv); utv->options.push_back(&tv); @@ -200,44 +200,44 @@ TEST_CASE_FIXTURE(Fixture, "UnionTypeVarIterator_with_only_cyclic_union") */ TEST_CASE_FIXTURE(Fixture, "substitution_skip_failure") { - TypeVar ftv11{FreeTypeVar{TypeLevel{}}}; + Type ftv11{FreeType{TypeLevel{}}}; TypePackVar tp24{TypePack{{&ftv11}}}; TypePackVar tp17{TypePack{}}; - TypeVar ftv23{FunctionTypeVar{&tp24, &tp17}}; + Type ftv23{FunctionType{&tp24, &tp17}}; - TypeVar ttvConnection2{TableTypeVar{}}; - TableTypeVar* ttvConnection2_ = getMutable(&ttvConnection2); + Type ttvConnection2{TableType{}}; + TableType* ttvConnection2_ = getMutable(&ttvConnection2); ttvConnection2_->instantiatedTypeParams.push_back(&ftv11); ttvConnection2_->props["f"] = {&ftv23}; TypePackVar tp21{TypePack{{&ftv11}}}; TypePackVar tp20{TypePack{}}; - TypeVar ftv19{FunctionTypeVar{&tp21, &tp20}}; + Type ftv19{FunctionType{&tp21, &tp20}}; - TypeVar ttvSignal{TableTypeVar{}}; - TableTypeVar* ttvSignal_ = getMutable(&ttvSignal); + Type ttvSignal{TableType{}}; + TableType* ttvSignal_ = getMutable(&ttvSignal); ttvSignal_->instantiatedTypeParams.push_back(&ftv11); ttvSignal_->props["f"] = {&ftv19}; // Back edge ttvConnection2_->props["signal"] = {&ttvSignal}; - TypeVar gtvK2{GenericTypeVar{}}; - TypeVar gtvV2{GenericTypeVar{}}; + Type gtvK2{GenericType{}}; + Type gtvV2{GenericType{}}; - TypeVar ttvTweenResult2{TableTypeVar{}}; - TableTypeVar* ttvTweenResult2_ = getMutable(&ttvTweenResult2); + Type ttvTweenResult2{TableType{}}; + TableType* ttvTweenResult2_ = getMutable(&ttvTweenResult2); ttvTweenResult2_->instantiatedTypeParams.push_back(>vK2); ttvTweenResult2_->instantiatedTypeParams.push_back(>vV2); TypePackVar tp13{TypePack{{&ttvTweenResult2}}}; - TypeVar ftv12{FunctionTypeVar{&tp13, &tp17}}; + Type ftv12{FunctionType{&tp13, &tp17}}; - TypeVar ttvConnection{TableTypeVar{}}; - TableTypeVar* ttvConnection_ = getMutable(&ttvConnection); + Type ttvConnection{TableType{}}; + TableType* ttvConnection_ = getMutable(&ttvConnection); ttvConnection_->instantiatedTypeParams.push_back(&ttvTweenResult2); ttvConnection_->props["f"] = {&ftv12}; ttvConnection_->props["signal"] = {&ttvSignal}; @@ -245,10 +245,10 @@ TEST_CASE_FIXTURE(Fixture, "substitution_skip_failure") TypePackVar tp9{TypePack{}}; TypePackVar tp10{TypePack{{&ttvConnection}}}; - TypeVar ftv8{FunctionTypeVar{&tp9, &tp10}}; + Type ftv8{FunctionType{&tp9, &tp10}}; - TypeVar ttvTween{TableTypeVar{}}; - TableTypeVar* ttvTween_ = getMutable(&ttvTween); + Type ttvTween{TableType{}}; + TableType* ttvTween_ = getMutable(&ttvTween); ttvTween_->instantiatedTypeParams.push_back(>vK2); ttvTween_->instantiatedTypeParams.push_back(>vV2); ttvTween_->props["f"] = {&ftv8}; @@ -256,16 +256,16 @@ TEST_CASE_FIXTURE(Fixture, "substitution_skip_failure") TypePackVar tp4{TypePack{}}; TypePackVar tp5{TypePack{{&ttvTween}}}; - TypeVar ftv3{FunctionTypeVar{&tp4, &tp5}}; + Type ftv3{FunctionType{&tp4, &tp5}}; // Back edge ttvTweenResult2_->props["f"] = {&ftv3}; - TypeVar gtvK{GenericTypeVar{}}; - TypeVar gtvV{GenericTypeVar{}}; + Type gtvK{GenericType{}}; + Type gtvV{GenericType{}}; - TypeVar ttvTweenResult{TableTypeVar{}}; - TableTypeVar* ttvTweenResult_ = getMutable(&ttvTweenResult); + Type ttvTweenResult{TableType{}}; + TableType* ttvTweenResult_ = getMutable(&ttvTweenResult); ttvTweenResult_->instantiatedTypeParams.push_back(>vK); ttvTweenResult_->instantiatedTypeParams.push_back(>vV); ttvTweenResult_->props["f"] = {&ftv3}; @@ -273,7 +273,7 @@ TEST_CASE_FIXTURE(Fixture, "substitution_skip_failure") TypeId root = &ttvTweenResult; typeChecker.currentModule = std::make_shared(); - typeChecker.currentModule->scopes.emplace_back(Location{}, std::make_shared(singletonTypes->anyTypePack)); + typeChecker.currentModule->scopes.emplace_back(Location{}, std::make_shared(builtinTypes->anyTypePack)); TypeId result = typeChecker.anyify(typeChecker.globalScope, root, Location{}); @@ -282,7 +282,7 @@ TEST_CASE_FIXTURE(Fixture, "substitution_skip_failure") TEST_CASE("tagging_tables") { - TypeVar ttv{TableTypeVar{}}; + Type ttv{TableType{}}; CHECK(!Luau::hasTag(&ttv, "foo")); Luau::attachTag(&ttv, "foo"); CHECK(Luau::hasTag(&ttv, "foo")); @@ -290,7 +290,7 @@ TEST_CASE("tagging_tables") TEST_CASE("tagging_classes") { - TypeVar base{ClassTypeVar{"Base", {}, std::nullopt, std::nullopt, {}, nullptr, "Test"}}; + Type base{ClassType{"Base", {}, std::nullopt, std::nullopt, {}, nullptr, "Test"}}; CHECK(!Luau::hasTag(&base, "foo")); Luau::attachTag(&base, "foo"); CHECK(Luau::hasTag(&base, "foo")); @@ -298,8 +298,8 @@ TEST_CASE("tagging_classes") TEST_CASE("tagging_subclasses") { - TypeVar base{ClassTypeVar{"Base", {}, std::nullopt, std::nullopt, {}, nullptr, "Test"}}; - TypeVar derived{ClassTypeVar{"Derived", {}, &base, std::nullopt, {}, nullptr, "Test"}}; + Type base{ClassType{"Base", {}, std::nullopt, std::nullopt, {}, nullptr, "Test"}}; + Type derived{ClassType{"Derived", {}, &base, std::nullopt, {}, nullptr, "Test"}}; CHECK(!Luau::hasTag(&base, "foo")); CHECK(!Luau::hasTag(&derived, "foo")); @@ -316,7 +316,7 @@ TEST_CASE("tagging_subclasses") TEST_CASE("tagging_functions") { TypePackVar empty{TypePack{}}; - TypeVar ftv{FunctionTypeVar{&empty, &empty}}; + Type ftv{FunctionType{&empty, &empty}}; CHECK(!Luau::hasTag(&ftv, "foo")); Luau::attachTag(&ftv, "foo"); CHECK(Luau::hasTag(&ftv, "foo")); @@ -330,7 +330,7 @@ TEST_CASE("tagging_props") CHECK(Luau::hasTag(prop, "foo")); } -struct VisitCountTracker final : TypeVarOnceVisitor +struct VisitCountTracker final : TypeOnceVisitor { std::unordered_map tyVisits; std::unordered_map tpVisits; @@ -385,65 +385,65 @@ local b: (T, T, T) -> T TEST_CASE("isString_on_string_singletons") { - TypeVar helloString{SingletonTypeVar{StringSingleton{"hello"}}}; + Type helloString{SingletonType{StringSingleton{"hello"}}}; CHECK(isString(&helloString)); } TEST_CASE("isString_on_unions_of_various_string_singletons") { - TypeVar helloString{SingletonTypeVar{StringSingleton{"hello"}}}; - TypeVar byeString{SingletonTypeVar{StringSingleton{"bye"}}}; - TypeVar union_{UnionTypeVar{{&helloString, &byeString}}}; + Type helloString{SingletonType{StringSingleton{"hello"}}}; + Type byeString{SingletonType{StringSingleton{"bye"}}}; + Type union_{UnionType{{&helloString, &byeString}}}; CHECK(isString(&union_)); } TEST_CASE("proof_that_isString_uses_all_of") { - TypeVar helloString{SingletonTypeVar{StringSingleton{"hello"}}}; - TypeVar byeString{SingletonTypeVar{StringSingleton{"bye"}}}; - TypeVar booleanType{PrimitiveTypeVar{PrimitiveTypeVar::Boolean}}; - TypeVar union_{UnionTypeVar{{&helloString, &byeString, &booleanType}}}; + Type helloString{SingletonType{StringSingleton{"hello"}}}; + Type byeString{SingletonType{StringSingleton{"bye"}}}; + Type booleanType{PrimitiveType{PrimitiveType::Boolean}}; + Type union_{UnionType{{&helloString, &byeString, &booleanType}}}; CHECK(!isString(&union_)); } TEST_CASE("isBoolean_on_boolean_singletons") { - TypeVar trueBool{SingletonTypeVar{BooleanSingleton{true}}}; + Type trueBool{SingletonType{BooleanSingleton{true}}}; CHECK(isBoolean(&trueBool)); } TEST_CASE("isBoolean_on_unions_of_true_or_false_singletons") { - TypeVar trueBool{SingletonTypeVar{BooleanSingleton{true}}}; - TypeVar falseBool{SingletonTypeVar{BooleanSingleton{false}}}; - TypeVar union_{UnionTypeVar{{&trueBool, &falseBool}}}; + Type trueBool{SingletonType{BooleanSingleton{true}}}; + Type falseBool{SingletonType{BooleanSingleton{false}}}; + Type union_{UnionType{{&trueBool, &falseBool}}}; CHECK(isBoolean(&union_)); } TEST_CASE("proof_that_isBoolean_uses_all_of") { - TypeVar trueBool{SingletonTypeVar{BooleanSingleton{true}}}; - TypeVar falseBool{SingletonTypeVar{BooleanSingleton{false}}}; - TypeVar stringType{PrimitiveTypeVar{PrimitiveTypeVar::String}}; - TypeVar union_{UnionTypeVar{{&trueBool, &falseBool, &stringType}}}; + Type trueBool{SingletonType{BooleanSingleton{true}}}; + Type falseBool{SingletonType{BooleanSingleton{false}}}; + Type stringType{PrimitiveType{PrimitiveType::String}}; + Type union_{UnionType{{&trueBool, &falseBool, &stringType}}}; CHECK(!isBoolean(&union_)); } TEST_CASE("content_reassignment") { - TypeVar myAny{AnyTypeVar{}, /*presistent*/ true}; + Type myAny{AnyType{}, /*presistent*/ true}; myAny.documentationSymbol = "@global/any"; TypeArena arena; - TypeId futureAny = arena.addType(FreeTypeVar{TypeLevel{}}); + TypeId futureAny = arena.addType(FreeType{TypeLevel{}}); asMutable(futureAny)->reassign(myAny); - CHECK(get(futureAny) != nullptr); + CHECK(get(futureAny) != nullptr); CHECK(!futureAny->persistent); CHECK(futureAny->documentationSymbol == "@global/any"); CHECK(futureAny->owningArena == &arena); diff --git a/tests/VisitTypeVar.test.cpp b/tests/VisitType.test.cpp similarity index 100% rename from tests/VisitTypeVar.test.cpp rename to tests/VisitType.test.cpp diff --git a/tools/faillist.txt b/tools/faillist.txt index 5d6779f48..233c75c1a 100644 --- a/tools/faillist.txt +++ b/tools/faillist.txt @@ -4,11 +4,10 @@ AnnotationTests.for_loop_counter_annotation_is_checked AnnotationTests.generic_aliases_are_cloned_properly AnnotationTests.instantiation_clone_has_to_follow AnnotationTests.luau_print_is_not_special_without_the_flag -AnnotationTests.occurs_check_on_cyclic_intersection_typevar -AnnotationTests.occurs_check_on_cyclic_union_typevar +AnnotationTests.occurs_check_on_cyclic_intersection_type +AnnotationTests.occurs_check_on_cyclic_union_type AnnotationTests.too_many_type_params AnnotationTests.two_type_params -AnnotationTests.unknown_type_reference_generates_error AstQuery.last_argument_function_call_type AstQuery::getDocumentationSymbolAtPosition.overloaded_class_method AstQuery::getDocumentationSymbolAtPosition.overloaded_fn diff --git a/tools/natvis/Analysis.natvis b/tools/natvis/Analysis.natvis index 7d03dd3f9..ca66cbe2c 100644 --- a/tools/natvis/Analysis.natvis +++ b/tools/natvis/Analysis.natvis @@ -1,8 +1,8 @@ - - AnyTypeVar + + AnyType From ee364a33b4e6353e1e5e1f049946bc1e042564d8 Mon Sep 17 00:00:00 2001 From: Vyacheslav Egorov Date: Wed, 4 Jan 2023 00:31:14 +0200 Subject: [PATCH 23/66] Fixed iterator invalidation issue --- Analysis/src/Normalize.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp index b66546595..e19d48f80 100644 --- a/Analysis/src/Normalize.cpp +++ b/Analysis/src/Normalize.cpp @@ -1835,6 +1835,7 @@ void Normalizer::intersectClassesWithClass(NormalizedClassType& heres, TypeId th it = heres.ordering.erase(it); heres.classes.erase(hereTy); heres.pushPair(there, std::move(negations)); + break; } // If the incoming class is a superclass of the current class, we don't // insert it into the map. From 1958676f2988324d5c996d23c1ed1106c5f58913 Mon Sep 17 00:00:00 2001 From: Vyacheslav Egorov Date: Wed, 4 Jan 2023 14:45:18 +0200 Subject: [PATCH 24/66] Re-using uncleared normalizer in unsafe --- tests/Normalize.test.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/Normalize.test.cpp b/tests/Normalize.test.cpp index ba9f5c525..2ee82623b 100644 --- a/tests/Normalize.test.cpp +++ b/tests/Normalize.test.cpp @@ -408,6 +408,7 @@ struct NormalizeFixture : Fixture const NormalizedType* toNormalizedType(const std::string& annotation) { + normalizer.clearCaches(); CheckResult result = check("type _Res = " + annotation); LUAU_REQUIRE_NO_ERRORS(result); std::optional ty = lookupType("_Res"); From 36f5009026e13fc5192943280a290198f7b81e0c Mon Sep 17 00:00:00 2001 From: Vyacheslav Egorov Date: Fri, 6 Jan 2023 18:07:19 +0200 Subject: [PATCH 25/66] Sync to upstream/release/558 --- Analysis/include/Luau/Constraint.h | 6 +- .../include/Luau/ConstraintGraphBuilder.h | 4 +- Analysis/include/Luau/Module.h | 15 +- Analysis/include/Luau/RecursionCounter.h | 2 +- Analysis/include/Luau/Type.h | 16 +- Analysis/include/Luau/TypeReduction.h | 40 + Analysis/src/AstQuery.cpp | 4 +- Analysis/src/Autocomplete.cpp | 36 +- Analysis/src/BuiltinDefinitions.cpp | 90 +- Analysis/src/ConstraintGraphBuilder.cpp | 21 +- Analysis/src/ConstraintSolver.cpp | 7 +- Analysis/src/EmbeddedBuiltinDefinitions.cpp | 8 +- Analysis/src/Frontend.cpp | 46 +- Analysis/src/Module.cpp | 32 +- Analysis/src/Normalize.cpp | 1 + Analysis/src/ToString.cpp | 7 +- Analysis/src/Type.cpp | 3 +- Analysis/src/TypeAttach.cpp | 7 +- Analysis/src/TypeChecker2.cpp | 103 +- Analysis/src/TypeInfer.cpp | 80 +- Analysis/src/TypeReduction.cpp | 838 +++++++++++ Compiler/src/Compiler.cpp | 5 +- Sources.cmake | 3 + VM/include/lua.h | 6 +- tests/Autocomplete.test.cpp | 32 + tests/Compiler.test.cpp | 40 +- tests/ConstraintGraphBuilderFixture.cpp | 4 + tests/Fixture.cpp | 78 +- tests/Fixture.h | 4 +- tests/Frontend.test.cpp | 20 +- tests/Module.test.cpp | 24 +- tests/NonstrictMode.test.cpp | 4 +- tests/Normalize.test.cpp | 62 +- tests/TypeInfer.aliases.test.cpp | 19 +- tests/TypeInfer.annotations.test.cpp | 22 +- tests/TypeInfer.functions.test.cpp | 10 +- tests/TypeInfer.generics.test.cpp | 4 +- tests/TypeInfer.negations.test.cpp | 2 +- tests/TypeInfer.operators.test.cpp | 40 +- tests/TypeInfer.provisional.test.cpp | 16 +- tests/TypeInfer.tables.test.cpp | 26 +- tests/TypeInfer.test.cpp | 4 +- tests/TypeInfer.typePacks.cpp | 16 +- tests/TypeInfer.unionTypes.test.cpp | 2 - tests/TypeReduction.test.cpp | 1249 +++++++++++++++++ tests/TypeVar.test.cpp | 3 +- tools/faillist.txt | 4 +- 47 files changed, 2684 insertions(+), 381 deletions(-) create mode 100644 Analysis/include/Luau/TypeReduction.h create mode 100644 Analysis/src/TypeReduction.cpp create mode 100644 tests/TypeReduction.test.cpp diff --git a/Analysis/include/Luau/Constraint.h b/Analysis/include/Luau/Constraint.h index b41329548..ec94eee96 100644 --- a/Analysis/include/Luau/Constraint.h +++ b/Analysis/include/Luau/Constraint.h @@ -71,9 +71,9 @@ struct BinaryConstraint // When we dispatch this constraint, we update the key at this map to record // the overload that we selected. - AstExpr* expr; - DenseHashMap* astOriginalCallTypes; - DenseHashMap* astOverloadResolvedTypes; + const void* astFragment; + DenseHashMap* astOriginalCallTypes; + DenseHashMap* astOverloadResolvedTypes; }; // iteratee is iterable diff --git a/Analysis/include/Luau/ConstraintGraphBuilder.h b/Analysis/include/Luau/ConstraintGraphBuilder.h index 65ea5e093..3a67610a8 100644 --- a/Analysis/include/Luau/ConstraintGraphBuilder.h +++ b/Analysis/include/Luau/ConstraintGraphBuilder.h @@ -82,11 +82,11 @@ struct ConstraintGraphBuilder // If the node was applied as a function, this is the unspecialized type of // that expression. - DenseHashMap astOriginalCallTypes{nullptr}; + DenseHashMap astOriginalCallTypes{nullptr}; // If overload resolution was performed on this element, this is the // overload that was selected. - DenseHashMap astOverloadResolvedTypes{nullptr}; + DenseHashMap astOverloadResolvedTypes{nullptr}; // Types resolved from type annotations. Analogous to astTypes. DenseHashMap astResolvedTypes{nullptr}; diff --git a/Analysis/include/Luau/Module.h b/Analysis/include/Luau/Module.h index d6d9f841b..2cd6802e3 100644 --- a/Analysis/include/Luau/Module.h +++ b/Analysis/include/Luau/Module.h @@ -73,19 +73,30 @@ struct Module DenseHashMap astTypes{nullptr}; DenseHashMap astTypePacks{nullptr}; DenseHashMap astExpectedTypes{nullptr}; - DenseHashMap astOriginalCallTypes{nullptr}; - DenseHashMap astOverloadResolvedTypes{nullptr}; + + // Pointers are either AstExpr or AstStat. + DenseHashMap astOriginalCallTypes{nullptr}; + + // Pointers are either AstExpr or AstStat. + DenseHashMap astOverloadResolvedTypes{nullptr}; + DenseHashMap astResolvedTypes{nullptr}; DenseHashMap astResolvedTypePacks{nullptr}; // Map AST nodes to the scope they create. Cannot be NotNull because we need a sentinel value for the map. DenseHashMap astScopes{nullptr}; + std::unique_ptr reduction; + std::unordered_map declaredGlobals; ErrorVec errors; Mode mode; SourceCode::Type type; bool timeout = false; + TypePackId returnType = nullptr; + std::unordered_map exportedTypeBindings; + + bool hasModuleScope() const; ScopePtr getModuleScope() const; // Once a module has been typechecked, we clone its public interface into a separate arena. diff --git a/Analysis/include/Luau/RecursionCounter.h b/Analysis/include/Luau/RecursionCounter.h index 77af10a0a..0dc557009 100644 --- a/Analysis/include/Luau/RecursionCounter.h +++ b/Analysis/include/Luau/RecursionCounter.h @@ -32,7 +32,7 @@ struct RecursionCounter --(*count); } -private: +protected: int* count; }; diff --git a/Analysis/include/Luau/Type.h b/Analysis/include/Luau/Type.h index fcc073d8b..734d40eac 100644 --- a/Analysis/include/Luau/Type.h +++ b/Analysis/include/Luau/Type.h @@ -494,13 +494,13 @@ struct AnyType { }; -// T | U +// `T | U` struct UnionType { std::vector options; }; -// T & U +// `T & U` struct IntersectionType { std::vector parts; @@ -519,9 +519,7 @@ struct NeverType { }; -// ~T -// TODO: Some simplification step that overwrites the type graph to make sure negation -// types disappear from the user's view, and (?) a debug flag to disable that +// `~T` struct NegationType { TypeId ty; @@ -676,6 +674,8 @@ TypeLevel* getMutableLevel(TypeId ty); std::optional getLevel(TypePackId tp); const Property* lookupClassProp(const ClassType* cls, const Name& name); + +// Whether `cls` is a subclass of `parent` bool isSubclass(const ClassType* cls, const ClassType* parent); Type* asMutable(TypeId ty); @@ -767,7 +767,7 @@ struct TypeIterator return !(*this == rhs); } - const TypeId& operator*() + TypeId operator*() { descend(); @@ -779,8 +779,8 @@ struct TypeIterator const std::vector& types = getTypes(t); LUAU_ASSERT(currentIndex < types.size()); - const TypeId& ty = types[currentIndex]; - LUAU_ASSERT(!get(follow(ty))); + TypeId ty = follow(types[currentIndex]); + LUAU_ASSERT(!get(ty)); return ty; } diff --git a/Analysis/include/Luau/TypeReduction.h b/Analysis/include/Luau/TypeReduction.h new file mode 100644 index 000000000..7df7edfa7 --- /dev/null +++ b/Analysis/include/Luau/TypeReduction.h @@ -0,0 +1,40 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Type.h" +#include "Luau/TypeArena.h" +#include "Luau/TypePack.h" +#include "Luau/Variant.h" + +namespace Luau +{ + +/// If it's desirable to allocate into a different arena than the TypeReduction instance you have, you will need +/// to create a temporary TypeReduction in that case. This is because TypeReduction caches the reduced type. +struct TypeReduction +{ + explicit TypeReduction(NotNull arena, NotNull builtinTypes, NotNull handle); + + std::optional reduce(TypeId ty); + std::optional reduce(TypePackId tp); + std::optional reduce(const TypeFun& fun); + +private: + NotNull arena; + NotNull builtinTypes; + NotNull handle; + + DenseHashMap cachedTypes{nullptr}; + DenseHashMap cachedTypePacks{nullptr}; + + std::optional reduceImpl(TypeId ty); + std::optional reduceImpl(TypePackId tp); + + // Computes an *estimated length* of the cartesian product of the given type. + size_t cartesianProductSize(TypeId ty) const; + + bool hasExceededCartesianProductLimit(TypeId ty) const; + bool hasExceededCartesianProductLimit(TypePackId tp) const; +}; + +} // namespace Luau diff --git a/Analysis/src/AstQuery.cpp b/Analysis/src/AstQuery.cpp index 39f613e55..ffab734ab 100644 --- a/Analysis/src/AstQuery.cpp +++ b/Analysis/src/AstQuery.cpp @@ -256,7 +256,8 @@ AstExpr* findExprAtPosition(const SourceModule& source, Position pos) ScopePtr findScopeAtPosition(const Module& module, Position pos) { - LUAU_ASSERT(!module.scopes.empty()); + if (module.scopes.empty()) + return nullptr; Location scopeLocation = module.scopes.front().first; ScopePtr scope = module.scopes.front().second; @@ -320,7 +321,6 @@ std::optional findBindingAtPosition(const Module& module, const SourceM return std::nullopt; ScopePtr currentScope = findScopeAtPosition(module, pos); - LUAU_ASSERT(currentScope); while (currentScope) { diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index 7a649546d..49c430e63 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -150,6 +150,8 @@ static TypeCorrectKind checkTypeCorrectKind( { ty = follow(ty); + LUAU_ASSERT(module.hasModuleScope()); + NotNull moduleScope{module.getModuleScope().get()}; auto typeAtPosition = findExpectedTypeAt(module, node, position); @@ -182,8 +184,7 @@ static TypeCorrectKind checkTypeCorrectKind( } } - return checkTypeMatch(ty, expectedType, NotNull{module.getModuleScope().get()}, typeArena, builtinTypes) ? TypeCorrectKind::Correct - : TypeCorrectKind::None; + return checkTypeMatch(ty, expectedType, moduleScope, typeArena, builtinTypes) ? TypeCorrectKind::Correct : TypeCorrectKind::None; } enum class PropIndexType @@ -1328,13 +1329,11 @@ static std::optional autocompleteStringParams(const Source } static AutocompleteResult autocomplete(const SourceModule& sourceModule, const ModulePtr& module, NotNull builtinTypes, - Scope* globalScope, Position position, StringCompletionCallback callback) + TypeArena* typeArena, Scope* globalScope, Position position, StringCompletionCallback callback) { if (isWithinComment(sourceModule, position)) return {}; - TypeArena typeArena; - std::vector ancestry = findAncestryAtPositionForAutocomplete(sourceModule, position); LUAU_ASSERT(!ancestry.empty()); AstNode* node = ancestry.back(); @@ -1360,7 +1359,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M TypeId ty = follow(*it); PropIndexType indexType = indexName->op == ':' ? PropIndexType::Colon : PropIndexType::Point; - return {autocompleteProps(*module, &typeArena, builtinTypes, ty, indexType, ancestry), ancestry, AutocompleteContext::Property}; + return {autocompleteProps(*module, typeArena, builtinTypes, ty, indexType, ancestry), ancestry, AutocompleteContext::Property}; } else if (auto typeReference = node->as()) { @@ -1378,7 +1377,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M if (statLocal->vars.size == 1 && (!statLocal->equalsSignLocation || position < statLocal->equalsSignLocation->begin)) return {{{"function", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Unknown}; else if (statLocal->equalsSignLocation && position >= statLocal->equalsSignLocation->end) - return autocompleteExpression(sourceModule, *module, builtinTypes, &typeArena, ancestry, position); + return autocompleteExpression(sourceModule, *module, builtinTypes, typeArena, ancestry, position); else return {}; } @@ -1392,7 +1391,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M if (statFor->from->location.containsClosed(position) || statFor->to->location.containsClosed(position) || (statFor->step && statFor->step->location.containsClosed(position))) - return autocompleteExpression(sourceModule, *module, builtinTypes, &typeArena, ancestry, position); + return autocompleteExpression(sourceModule, *module, builtinTypes, typeArena, ancestry, position); return {}; } @@ -1422,7 +1421,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M AstExpr* lastExpr = statForIn->values.data[statForIn->values.size - 1]; if (lastExpr->location.containsClosed(position)) - return autocompleteExpression(sourceModule, *module, builtinTypes, &typeArena, ancestry, position); + return autocompleteExpression(sourceModule, *module, builtinTypes, typeArena, ancestry, position); if (position > lastExpr->location.end) return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Keyword}; @@ -1446,7 +1445,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Keyword}; if (!statWhile->hasDo || position < statWhile->doLocation.begin) - return autocompleteExpression(sourceModule, *module, builtinTypes, &typeArena, ancestry, position); + return autocompleteExpression(sourceModule, *module, builtinTypes, typeArena, ancestry, position); if (statWhile->hasDo && position > statWhile->doLocation.end) return {autocompleteStatement(sourceModule, *module, ancestry, position), ancestry, AutocompleteContext::Statement}; @@ -1463,7 +1462,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M else if (AstStatIf* statIf = parent->as(); statIf && node->is()) { if (statIf->condition->is()) - return autocompleteExpression(sourceModule, *module, builtinTypes, &typeArena, ancestry, position); + return autocompleteExpression(sourceModule, *module, builtinTypes, typeArena, ancestry, position); else if (!statIf->thenLocation || statIf->thenLocation->containsClosed(position)) return {{{"then", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Keyword}; } @@ -1471,7 +1470,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M statIf && (!statIf->thenLocation || statIf->thenLocation->containsClosed(position))) return {{{"then", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Keyword}; else if (AstStatRepeat* statRepeat = node->as(); statRepeat && statRepeat->condition->is()) - return autocompleteExpression(sourceModule, *module, builtinTypes, &typeArena, ancestry, position); + return autocompleteExpression(sourceModule, *module, builtinTypes, typeArena, ancestry, position); else if (AstStatRepeat* statRepeat = extractStat(ancestry); statRepeat) return {autocompleteStatement(sourceModule, *module, ancestry, position), ancestry, AutocompleteContext::Statement}; else if (AstExprTable* exprTable = parent->as(); @@ -1484,7 +1483,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M { if (auto it = module->astExpectedTypes.find(exprTable)) { - auto result = autocompleteProps(*module, &typeArena, builtinTypes, *it, PropIndexType::Key, ancestry); + auto result = autocompleteProps(*module, typeArena, builtinTypes, *it, PropIndexType::Key, ancestry); if (FFlag::LuauCompleteTableKeysBetter) { @@ -1518,7 +1517,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M // If we know for sure that a key is being written, do not offer general expression suggestions if (!key) - autocompleteExpression(sourceModule, *module, builtinTypes, &typeArena, ancestry, position, result); + autocompleteExpression(sourceModule, *module, builtinTypes, typeArena, ancestry, position, result); return {result, ancestry, AutocompleteContext::Property}; } @@ -1546,7 +1545,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M if (auto idxExpr = ancestry.at(ancestry.size() - 2)->as()) { if (auto it = module->astTypes.find(idxExpr->expr)) - autocompleteProps(*module, &typeArena, builtinTypes, follow(*it), PropIndexType::Point, ancestry, result); + autocompleteProps(*module, typeArena, builtinTypes, follow(*it), PropIndexType::Point, ancestry, result); } else if (auto binExpr = ancestry.at(ancestry.size() - 2)->as()) { @@ -1572,7 +1571,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M return {}; if (node->asExpr()) - return autocompleteExpression(sourceModule, *module, builtinTypes, &typeArena, ancestry, position); + return autocompleteExpression(sourceModule, *module, builtinTypes, typeArena, ancestry, position); else if (node->asStat()) return {autocompleteStatement(sourceModule, *module, ancestry, position), ancestry, AutocompleteContext::Statement}; @@ -1599,9 +1598,8 @@ AutocompleteResult autocomplete(Frontend& frontend, const ModuleName& moduleName NotNull builtinTypes = frontend.builtinTypes; Scope* globalScope = frontend.typeCheckerForAutocomplete.globalScope.get(); - AutocompleteResult autocompleteResult = autocomplete(*sourceModule, module, builtinTypes, globalScope, position, callback); - - return autocompleteResult; + TypeArena typeArena; + return autocomplete(*sourceModule, module, builtinTypes, &typeArena, globalScope, position, callback); } } // namespace Luau diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index 81702ff65..26aaf54fe 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -18,9 +18,7 @@ LUAU_FASTFLAGVARIABLE(LuauSetMetaTableArgsCheck, false) LUAU_FASTFLAG(LuauUnknownAndNeverType) LUAU_FASTFLAGVARIABLE(LuauBuiltInMetatableNoBadSynthetic, false) -LUAU_FASTFLAG(LuauOptionalNextKey) LUAU_FASTFLAG(LuauReportShadowedTypeAlias) -LUAU_FASTFLAG(LuauNewLibraryTypeNames) /** FIXME: Many of these type definitions are not quite completely accurate. * @@ -289,38 +287,18 @@ void registerBuiltinGlobals(TypeChecker& typeChecker) addGlobalBinding(typeChecker, "string", it->second.type, "@luau"); - if (FFlag::LuauOptionalNextKey) - { - // next(t: Table, i: K?) -> (K?, V) - TypePackId nextArgsTypePack = arena.addTypePack(TypePack{{mapOfKtoV, makeOption(typeChecker, arena, genericK)}}); - TypePackId nextRetsTypePack = arena.addTypePack(TypePack{{makeOption(typeChecker, arena, genericK), genericV}}); - addGlobalBinding(typeChecker, "next", arena.addType(FunctionType{{genericK, genericV}, {}, nextArgsTypePack, nextRetsTypePack}), "@luau"); - - TypePackId pairsArgsTypePack = arena.addTypePack({mapOfKtoV}); - - TypeId pairsNext = arena.addType(FunctionType{nextArgsTypePack, nextRetsTypePack}); - TypePackId pairsReturnTypePack = arena.addTypePack(TypePack{{pairsNext, mapOfKtoV, nilType}}); + // next(t: Table, i: K?) -> (K?, V) + TypePackId nextArgsTypePack = arena.addTypePack(TypePack{{mapOfKtoV, makeOption(typeChecker, arena, genericK)}}); + TypePackId nextRetsTypePack = arena.addTypePack(TypePack{{makeOption(typeChecker, arena, genericK), genericV}}); + addGlobalBinding(typeChecker, "next", arena.addType(FunctionType{{genericK, genericV}, {}, nextArgsTypePack, nextRetsTypePack}), "@luau"); - // pairs(t: Table) -> ((Table, K?) -> (K, V), Table, nil) - addGlobalBinding( - typeChecker, "pairs", arena.addType(FunctionType{{genericK, genericV}, {}, pairsArgsTypePack, pairsReturnTypePack}), "@luau"); - } - else - { - // next(t: Table, i: K?) -> (K, V) - TypePackId nextArgsTypePack = arena.addTypePack(TypePack{{mapOfKtoV, makeOption(typeChecker, arena, genericK)}}); - addGlobalBinding(typeChecker, "next", - arena.addType(FunctionType{{genericK, genericV}, {}, nextArgsTypePack, arena.addTypePack(TypePack{{genericK, genericV}})}), "@luau"); + TypePackId pairsArgsTypePack = arena.addTypePack({mapOfKtoV}); - TypePackId pairsArgsTypePack = arena.addTypePack({mapOfKtoV}); + TypeId pairsNext = arena.addType(FunctionType{nextArgsTypePack, nextRetsTypePack}); + TypePackId pairsReturnTypePack = arena.addTypePack(TypePack{{pairsNext, mapOfKtoV, nilType}}); - TypeId pairsNext = arena.addType(FunctionType{nextArgsTypePack, arena.addTypePack(TypePack{{genericK, genericV}})}); - TypePackId pairsReturnTypePack = arena.addTypePack(TypePack{{pairsNext, mapOfKtoV, nilType}}); - - // pairs(t: Table) -> ((Table, K?) -> (K, V), Table, nil) - addGlobalBinding( - typeChecker, "pairs", arena.addType(FunctionType{{genericK, genericV}, {}, pairsArgsTypePack, pairsReturnTypePack}), "@luau"); - } + // pairs(t: Table) -> ((Table, K?) -> (K, V), Table, nil) + addGlobalBinding(typeChecker, "pairs", arena.addType(FunctionType{{genericK, genericV}, {}, pairsArgsTypePack, pairsReturnTypePack}), "@luau"); TypeId genericMT = arena.addType(GenericType{"MT"}); @@ -352,12 +330,7 @@ void registerBuiltinGlobals(TypeChecker& typeChecker) if (TableType* ttv = getMutable(pair.second.typeId)) { if (!ttv->name) - { - if (FFlag::LuauNewLibraryTypeNames) - ttv->name = "typeof(" + toString(pair.first) + ")"; - else - ttv->name = toString(pair.first); - } + ttv->name = "typeof(" + toString(pair.first) + ")"; } } @@ -408,36 +381,18 @@ void registerBuiltinGlobals(Frontend& frontend) addGlobalBinding(frontend, "string", it->second.type, "@luau"); - if (FFlag::LuauOptionalNextKey) - { - // next(t: Table, i: K?) -> (K?, V) - TypePackId nextArgsTypePack = arena.addTypePack(TypePack{{mapOfKtoV, makeOption(frontend, arena, genericK)}}); - TypePackId nextRetsTypePack = arena.addTypePack(TypePack{{makeOption(frontend, arena, genericK), genericV}}); - addGlobalBinding(frontend, "next", arena.addType(FunctionType{{genericK, genericV}, {}, nextArgsTypePack, nextRetsTypePack}), "@luau"); - - TypePackId pairsArgsTypePack = arena.addTypePack({mapOfKtoV}); - - TypeId pairsNext = arena.addType(FunctionType{nextArgsTypePack, nextRetsTypePack}); - TypePackId pairsReturnTypePack = arena.addTypePack(TypePack{{pairsNext, mapOfKtoV, frontend.builtinTypes->nilType}}); + // next(t: Table, i: K?) -> (K?, V) + TypePackId nextArgsTypePack = arena.addTypePack(TypePack{{mapOfKtoV, makeOption(frontend, arena, genericK)}}); + TypePackId nextRetsTypePack = arena.addTypePack(TypePack{{makeOption(frontend, arena, genericK), genericV}}); + addGlobalBinding(frontend, "next", arena.addType(FunctionType{{genericK, genericV}, {}, nextArgsTypePack, nextRetsTypePack}), "@luau"); - // pairs(t: Table) -> ((Table, K?) -> (K?, V), Table, nil) - addGlobalBinding(frontend, "pairs", arena.addType(FunctionType{{genericK, genericV}, {}, pairsArgsTypePack, pairsReturnTypePack}), "@luau"); - } - else - { - // next(t: Table, i: K?) -> (K, V) - TypePackId nextArgsTypePack = arena.addTypePack(TypePack{{mapOfKtoV, makeOption(frontend, arena, genericK)}}); - addGlobalBinding(frontend, "next", - arena.addType(FunctionType{{genericK, genericV}, {}, nextArgsTypePack, arena.addTypePack(TypePack{{genericK, genericV}})}), "@luau"); + TypePackId pairsArgsTypePack = arena.addTypePack({mapOfKtoV}); - TypePackId pairsArgsTypePack = arena.addTypePack({mapOfKtoV}); + TypeId pairsNext = arena.addType(FunctionType{nextArgsTypePack, nextRetsTypePack}); + TypePackId pairsReturnTypePack = arena.addTypePack(TypePack{{pairsNext, mapOfKtoV, frontend.builtinTypes->nilType}}); - TypeId pairsNext = arena.addType(FunctionType{nextArgsTypePack, arena.addTypePack(TypePack{{genericK, genericV}})}); - TypePackId pairsReturnTypePack = arena.addTypePack(TypePack{{pairsNext, mapOfKtoV, frontend.builtinTypes->nilType}}); - - // pairs(t: Table) -> ((Table, K?) -> (K, V), Table, nil) - addGlobalBinding(frontend, "pairs", arena.addType(FunctionType{{genericK, genericV}, {}, pairsArgsTypePack, pairsReturnTypePack}), "@luau"); - } + // pairs(t: Table) -> ((Table, K?) -> (K?, V), Table, nil) + addGlobalBinding(frontend, "pairs", arena.addType(FunctionType{{genericK, genericV}, {}, pairsArgsTypePack, pairsReturnTypePack}), "@luau"); TypeId genericMT = arena.addType(GenericType{"MT"}); @@ -469,12 +424,7 @@ void registerBuiltinGlobals(Frontend& frontend) if (TableType* ttv = getMutable(pair.second.typeId)) { if (!ttv->name) - { - if (FFlag::LuauNewLibraryTypeNames) - ttv->name = "typeof(" + toString(pair.first) + ")"; - else - ttv->name = toString(pair.first); - } + ttv->name = "typeof(" + toString(pair.first) + ")"; } } diff --git a/Analysis/src/ConstraintGraphBuilder.cpp b/Analysis/src/ConstraintGraphBuilder.cpp index 256eba54d..6a80fed2e 100644 --- a/Analysis/src/ConstraintGraphBuilder.cpp +++ b/Analysis/src/ConstraintGraphBuilder.cpp @@ -15,6 +15,7 @@ LUAU_FASTINT(LuauCheckRecursionLimit); LUAU_FASTFLAG(DebugLuauLogSolverToJson); LUAU_FASTFLAG(DebugLuauMagicTypes); LUAU_FASTFLAG(LuauNegatedClassTypes); +LUAU_FASTFLAG(LuauScopelessModule); namespace Luau { @@ -520,7 +521,8 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocal* local) const Name name{local->vars.data[i]->name.value}; if (ModulePtr module = moduleResolver->getModule(moduleInfo->name)) - scope->importedTypeBindings[name] = module->getModuleScope()->exportedTypeBindings; + scope->importedTypeBindings[name] = + FFlag::LuauScopelessModule ? module->exportedTypeBindings : module->getModuleScope()->exportedTypeBindings; } } } @@ -733,16 +735,15 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatAssign* assign) void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatCompoundAssign* assign) { - // Synthesize A = A op B from A op= B and then build constraints for that instead. + // We need to tweak the BinaryConstraint that we emit, so we cannot use the + // strategy of falsifying an AST fragment. + TypeId varId = checkLValue(scope, assign->var); + Inference valueInf = check(scope, assign->value); - AstExprBinary exprBinary{assign->location, assign->op, assign->var, assign->value}; - AstExpr* exprBinaryPtr = &exprBinary; - - AstArray vars{&assign->var, 1}; - AstArray values{&exprBinaryPtr, 1}; - AstStatAssign syntheticAssign{assign->location, vars, values}; - - visit(scope, &syntheticAssign); + TypeId resultType = arena->addType(BlockedType{}); + addConstraint(scope, assign->location, + BinaryConstraint{assign->op, varId, valueInf.ty, resultType, assign, &astOriginalCallTypes, &astOverloadResolvedTypes}); + addConstraint(scope, assign->location, SubtypeConstraint{resultType, varId}); } void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatIf* ifStatement) diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index 67c1732c1..8092144cc 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -18,6 +18,7 @@ LUAU_FASTFLAGVARIABLE(DebugLuauLogSolver, false); LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJson, false); +LUAU_FASTFLAG(LuauScopelessModule); namespace Luau { @@ -681,8 +682,8 @@ bool ConstraintSolver::tryDispatch(const BinaryConstraint& c, NotNullty.emplace(mmResult); unblock(resultType); - (*c.astOriginalCallTypes)[c.expr] = *mm; - (*c.astOverloadResolvedTypes)[c.expr] = *instantiatedMm; + (*c.astOriginalCallTypes)[c.astFragment] = *mm; + (*c.astOverloadResolvedTypes)[c.astFragment] = *instantiatedMm; return true; } } @@ -1895,7 +1896,7 @@ TypeId ConstraintSolver::resolveModule(const ModuleInfo& info, const Location& l return errorRecoveryType(); } - TypePackId modulePack = module->getModuleScope()->returnType; + TypePackId modulePack = FFlag::LuauScopelessModule ? module->returnType : module->getModuleScope()->returnType; if (get(modulePack)) return errorRecoveryType(); diff --git a/Analysis/src/EmbeddedBuiltinDefinitions.cpp b/Analysis/src/EmbeddedBuiltinDefinitions.cpp index b0f21737f..1fe09773c 100644 --- a/Analysis/src/EmbeddedBuiltinDefinitions.cpp +++ b/Analysis/src/EmbeddedBuiltinDefinitions.cpp @@ -2,7 +2,6 @@ #include "Luau/BuiltinDefinitions.h" LUAU_FASTFLAG(LuauUnknownAndNeverType) -LUAU_FASTFLAG(LuauOptionalNextKey) namespace Luau { @@ -127,7 +126,7 @@ declare function rawlen(obj: {[K]: V} | string): number declare function setfenv(target: number | (T...) -> R..., env: {[string]: any}): ((T...) -> R...)? --- TODO: place ipairs definition here with removal of FFlagLuauOptionalNextKey +declare function ipairs(tab: {V}): (({V}, number) -> (number?, V), {V}, number) declare function pcall(f: (A...) -> R..., ...: A...): (boolean, R...) @@ -208,11 +207,6 @@ std::string getBuiltinDefinitionSource() else result += "declare function error(message: T, level: number?)\n"; - if (FFlag::LuauOptionalNextKey) - result += "declare function ipairs(tab: {V}): (({V}, number) -> (number?, V), {V}, number)\n"; - else - result += "declare function ipairs(tab: {V}): (({V}, number) -> (number, V), {V}, number)\n"; - return result; } diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 5d2c15871..f4e529dbe 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -16,6 +16,7 @@ #include "Luau/TimeTrace.h" #include "Luau/TypeChecker2.h" #include "Luau/TypeInfer.h" +#include "Luau/TypeReduction.h" #include "Luau/Variant.h" #include @@ -30,6 +31,7 @@ LUAU_FASTFLAGVARIABLE(LuauKnowsTheDataModel3, false) LUAU_FASTINTVARIABLE(LuauAutocompleteCheckTimeoutMs, 100) LUAU_FASTFLAGVARIABLE(DebugLuauDeferredConstraintResolution, false) LUAU_FASTFLAG(DebugLuauLogSolverToJson); +LUAU_FASTFLAG(LuauScopelessModule); namespace Luau { @@ -111,7 +113,9 @@ LoadDefinitionFileResult Frontend::loadDefinitionFile(std::string_view source, c CloneState cloneState; std::vector typesToPersist; - typesToPersist.reserve(checkedModule->declaredGlobals.size() + checkedModule->getModuleScope()->exportedTypeBindings.size()); + typesToPersist.reserve( + checkedModule->declaredGlobals.size() + + (FFlag::LuauScopelessModule ? checkedModule->exportedTypeBindings.size() : checkedModule->getModuleScope()->exportedTypeBindings.size())); for (const auto& [name, ty] : checkedModule->declaredGlobals) { @@ -123,7 +127,8 @@ LoadDefinitionFileResult Frontend::loadDefinitionFile(std::string_view source, c typesToPersist.push_back(globalTy); } - for (const auto& [name, ty] : checkedModule->getModuleScope()->exportedTypeBindings) + for (const auto& [name, ty] : + FFlag::LuauScopelessModule ? checkedModule->exportedTypeBindings : checkedModule->getModuleScope()->exportedTypeBindings) { TypeFun globalTy = clone(ty, globalTypes, cloneState); std::string documentationSymbol = packageName + "/globaltype/" + name; @@ -168,7 +173,9 @@ LoadDefinitionFileResult loadDefinitionFile(TypeChecker& typeChecker, ScopePtr t CloneState cloneState; std::vector typesToPersist; - typesToPersist.reserve(checkedModule->declaredGlobals.size() + checkedModule->getModuleScope()->exportedTypeBindings.size()); + typesToPersist.reserve( + checkedModule->declaredGlobals.size() + + (FFlag::LuauScopelessModule ? checkedModule->exportedTypeBindings.size() : checkedModule->getModuleScope()->exportedTypeBindings.size())); for (const auto& [name, ty] : checkedModule->declaredGlobals) { @@ -180,7 +187,8 @@ LoadDefinitionFileResult loadDefinitionFile(TypeChecker& typeChecker, ScopePtr t typesToPersist.push_back(globalTy); } - for (const auto& [name, ty] : checkedModule->getModuleScope()->exportedTypeBindings) + for (const auto& [name, ty] : + FFlag::LuauScopelessModule ? checkedModule->exportedTypeBindings : checkedModule->getModuleScope()->exportedTypeBindings) { TypeFun globalTy = clone(ty, typeChecker.globalTypes, cloneState); std::string documentationSymbol = packageName + "/globaltype/" + name; @@ -562,12 +570,29 @@ CheckResult Frontend::check(const ModuleName& name, std::optionalinterfaceTypes); module->internalTypes.clear(); - module->astTypes.clear(); - module->astExpectedTypes.clear(); - module->astOriginalCallTypes.clear(); - module->astResolvedTypes.clear(); - module->astResolvedTypePacks.clear(); - module->scopes.resize(1); + + if (FFlag::LuauScopelessModule) + { + module->astTypes.clear(); + module->astTypePacks.clear(); + module->astExpectedTypes.clear(); + module->astOriginalCallTypes.clear(); + module->astOverloadResolvedTypes.clear(); + module->astResolvedTypes.clear(); + module->astResolvedTypePacks.clear(); + module->astScopes.clear(); + + module->scopes.clear(); + } + else + { + module->astTypes.clear(); + module->astExpectedTypes.clear(); + module->astOriginalCallTypes.clear(); + module->astResolvedTypes.clear(); + module->astResolvedTypePacks.clear(); + module->scopes.resize(1); + } } if (mode != Mode::NoCheck) @@ -852,6 +877,7 @@ ModulePtr Frontend::check( const SourceModule& sourceModule, Mode mode, const ScopePtr& environmentScope, std::vector requireCycles, bool forAutocomplete) { ModulePtr result = std::make_shared(); + result->reduction = std::make_unique(NotNull{&result->internalTypes}, builtinTypes, NotNull{&iceHandler}); std::unique_ptr logger; if (FFlag::DebugLuauLogSolverToJson) diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index a73b928bf..e54a44936 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -7,9 +7,10 @@ #include "Luau/Normalize.h" #include "Luau/RecursionCounter.h" #include "Luau/Scope.h" +#include "Luau/Type.h" #include "Luau/TypeInfer.h" #include "Luau/TypePack.h" -#include "Luau/Type.h" +#include "Luau/TypeReduction.h" #include "Luau/VisitType.h" #include @@ -17,6 +18,7 @@ LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); LUAU_FASTFLAGVARIABLE(LuauClonePublicInterfaceLess, false); LUAU_FASTFLAG(LuauSubstitutionReentrant); +LUAU_FASTFLAG(LuauScopelessModule); LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution); LUAU_FASTFLAG(LuauSubstitutionFixMissingFields); @@ -189,7 +191,6 @@ void Module::clonePublicInterface(NotNull builtinTypes, InternalEr TypePackId returnType = moduleScope->returnType; std::optional varargPack = FFlag::DebugLuauDeferredConstraintResolution ? std::nullopt : moduleScope->varargPack; - std::unordered_map* exportedTypeBindings = &moduleScope->exportedTypeBindings; TxnLog log; ClonePublicInterface clonePublicInterface{&log, builtinTypes, this}; @@ -209,15 +210,12 @@ void Module::clonePublicInterface(NotNull builtinTypes, InternalEr moduleScope->varargPack = varargPack; } - if (exportedTypeBindings) + for (auto& [name, tf] : moduleScope->exportedTypeBindings) { - for (auto& [name, tf] : *exportedTypeBindings) - { - if (FFlag::LuauClonePublicInterfaceLess) - tf = clonePublicInterface.cloneTypeFun(tf); - else - tf = clone(tf, interfaceTypes, cloneState); - } + if (FFlag::LuauClonePublicInterfaceLess) + tf = clonePublicInterface.cloneTypeFun(tf); + else + tf = clone(tf, interfaceTypes, cloneState); } for (auto& [name, ty] : declaredGlobals) @@ -228,13 +226,25 @@ void Module::clonePublicInterface(NotNull builtinTypes, InternalEr ty = clone(ty, interfaceTypes, cloneState); } + // Copy external stuff over to Module itself + if (FFlag::LuauScopelessModule) + { + this->returnType = moduleScope->returnType; + this->exportedTypeBindings = std::move(moduleScope->exportedTypeBindings); + } + freeze(internalTypes); freeze(interfaceTypes); } +bool Module::hasModuleScope() const +{ + return !scopes.empty(); +} + ScopePtr Module::getModuleScope() const { - LUAU_ASSERT(!scopes.empty()); + LUAU_ASSERT(hasModuleScope()); return scopes.front().second; } diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp index b66546595..e19d48f80 100644 --- a/Analysis/src/Normalize.cpp +++ b/Analysis/src/Normalize.cpp @@ -1835,6 +1835,7 @@ void Normalizer::intersectClassesWithClass(NormalizedClassType& heres, TypeId th it = heres.ordering.erase(it); heres.classes.erase(hereTy); heres.pushPair(there, std::move(negations)); + break; } // If the incoming class is a superclass of the current class, we don't // insert it into the map. diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index e80085089..c7d9b3733 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -26,6 +26,7 @@ LUAU_FASTFLAGVARIABLE(LuauSerializeNilUnionAsNil, false) * Fair warning: Setting this will break a lot of Luau unit tests. */ LUAU_FASTFLAGVARIABLE(DebugLuauVerboseTypeNames, false) +LUAU_FASTFLAGVARIABLE(DebugLuauToStringNoLexicalSort, false) namespace Luau { @@ -755,7 +756,8 @@ struct TypeStringifier state.unsee(&uv); - std::sort(results.begin(), results.end()); + if (!FFlag::DebugLuauToStringNoLexicalSort) + std::sort(results.begin(), results.end()); if (optional && results.size() > 1) state.emit("("); @@ -820,7 +822,8 @@ struct TypeStringifier state.unsee(&uv); - std::sort(results.begin(), results.end()); + if (!FFlag::DebugLuauToStringNoLexicalSort) + std::sort(results.begin(), results.end()); bool first = true; for (std::string& ss : results) diff --git a/Analysis/src/Type.cpp b/Analysis/src/Type.cpp index aba6bddc1..f03061a85 100644 --- a/Analysis/src/Type.cpp +++ b/Analysis/src/Type.cpp @@ -26,7 +26,6 @@ LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0) LUAU_FASTINT(LuauTypeInferRecursionLimit) LUAU_FASTFLAG(LuauUnknownAndNeverType) LUAU_FASTFLAGVARIABLE(LuauMaybeGenericIntersectionTypes, false) -LUAU_FASTFLAGVARIABLE(LuauNewLibraryTypeNames, false) LUAU_FASTFLAG(LuauInstantiateInSubtyping) namespace Luau @@ -865,7 +864,7 @@ TypeId BuiltinTypes::makeStringMetatable() TypeId tableType = arena->addType(TableType{std::move(stringLib), std::nullopt, TypeLevel{}, TableState::Sealed}); if (TableType* ttv = getMutable(tableType)) - ttv->name = FFlag::LuauNewLibraryTypeNames ? "typeof(string)" : "string"; + ttv->name = "typeof(string)"; return arena->addType(TableType{{{{"__index", {tableType}}}}, std::nullopt, TypeLevel{}, TableState::Sealed}); } diff --git a/Analysis/src/TypeAttach.cpp b/Analysis/src/TypeAttach.cpp index d1d89b25a..f9a162056 100644 --- a/Analysis/src/TypeAttach.cpp +++ b/Analysis/src/TypeAttach.cpp @@ -487,8 +487,11 @@ class TypeAttacher : public AstVisitor AstType* annotation = local->annotation; if (!annotation) { - if (auto result = getScope(local->location)->lookup(local)) - local->annotation = typeAst(*result); + if (auto scope = getScope(local->location)) + { + if (auto result = scope->lookup(local)) + local->annotation = typeAst(*result); + } } return true; } diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index 5451a454e..1d212851a 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -626,8 +626,11 @@ struct TypeChecker2 void visit(AstStatCompoundAssign* stat) { - visit(stat->var); - visit(stat->value); + AstExprBinary fake{stat->location, stat->op, stat->var, stat->value}; + TypeId resultTy = visit(&fake, stat); + TypeId varTy = lookupType(stat->var); + + reportErrors(tryUnify(stack.back(), stat->location, resultTy, varTy)); } void visit(AstStatFunction* stat) @@ -737,7 +740,10 @@ struct TypeChecker2 else if (auto e = expr->as()) return visit(e); else if (auto e = expr->as()) - return visit(e); + { + visit(e); + return; + } else if (auto e = expr->as()) return visit(e); else if (auto e = expr->as()) @@ -1045,7 +1051,7 @@ struct TypeChecker2 } } - void visit(AstExprBinary* expr) + TypeId visit(AstExprBinary* expr, void* overrideKey = nullptr) { visit(expr->left); visit(expr->right); @@ -1066,8 +1072,10 @@ struct TypeChecker2 bool isStringOperation = isString(leftType) && isString(rightType); - if (get(leftType) || get(leftType) || get(rightType) || get(rightType)) - return; + if (get(leftType) || get(leftType)) + return leftType; + else if (get(rightType) || get(rightType)) + return rightType; if ((get(leftType) || get(leftType)) && !isEquality && !isLogical) { @@ -1075,14 +1083,13 @@ struct TypeChecker2 reportError(CannotInferBinaryOperation{expr->op, name, isComparison ? CannotInferBinaryOperation::OpKind::Comparison : CannotInferBinaryOperation::OpKind::Operation}, expr->location); - return; + return leftType; } if (auto it = kBinaryOpMetamethods.find(expr->op); it != kBinaryOpMetamethods.end()) { std::optional leftMt = getMetatable(leftType, builtinTypes); std::optional rightMt = getMetatable(rightType, builtinTypes); - bool matches = leftMt == rightMt; if (isEquality && !matches) { @@ -1114,7 +1121,7 @@ struct TypeChecker2 toString(leftType).c_str(), toString(rightType).c_str(), toString(expr->op).c_str())}, expr->location); - return; + return builtinTypes->errorRecoveryType(); } std::optional mm; @@ -1128,7 +1135,11 @@ struct TypeChecker2 if (mm) { - TypeId instantiatedMm = module->astOverloadResolvedTypes[expr]; + void* key = expr; + if (overrideKey != nullptr) + key = overrideKey; + + TypeId instantiatedMm = module->astOverloadResolvedTypes[key]; if (!instantiatedMm) reportError(CodeTooComplex{}, expr->location); @@ -1146,20 +1157,50 @@ struct TypeChecker2 expectedArgs = testArena.addTypePack({leftType, rightType}); } - reportErrors(tryUnify(scope, expr->location, ftv->argTypes, expectedArgs)); - + TypePackId expectedRets; if (expr->op == AstExprBinary::CompareEq || expr->op == AstExprBinary::CompareNe || expr->op == AstExprBinary::CompareGe || expr->op == AstExprBinary::CompareGt || expr->op == AstExprBinary::Op::CompareLe || expr->op == AstExprBinary::Op::CompareLt) { - TypePackId expectedRets = testArena.addTypePack({builtinTypes->booleanType}); - if (!isSubtype(ftv->retTypes, expectedRets, scope)) + expectedRets = testArena.addTypePack({builtinTypes->booleanType}); + } + else + { + expectedRets = testArena.addTypePack({testArena.freshType(scope, TypeLevel{})}); + } + + TypeId expectedTy = testArena.addType(FunctionType(expectedArgs, expectedRets)); + + reportErrors(tryUnify(scope, expr->location, follow(*mm), expectedTy)); + + std::optional ret = first(ftv->retTypes); + if (ret) + { + if (isComparison) { - reportError(GenericError{format("Metamethod '%s' must return type 'boolean'", it->second)}, expr->location); + if (!isBoolean(follow(*ret))) + { + reportError(GenericError{format("Metamethod '%s' must return a boolean", it->second)}, expr->location); + } + + return builtinTypes->booleanType; + } + else + { + return follow(*ret); } } - else if (!first(ftv->retTypes)) + else { - reportError(GenericError{format("Metamethod '%s' must return a value", it->second)}, expr->location); + if (isComparison) + { + reportError(GenericError{format("Metamethod '%s' must return a boolean", it->second)}, expr->location); + } + else + { + reportError(GenericError{format("Metamethod '%s' must return a value", it->second)}, expr->location); + } + + return builtinTypes->errorRecoveryType(); } } else @@ -1167,13 +1208,13 @@ struct TypeChecker2 reportError(CannotCallNonFunction{*mm}, expr->location); } - return; + return builtinTypes->errorRecoveryType(); } // If this is a string comparison, or a concatenation of strings, we // want to fall through to primitive behavior. else if (!isEquality && !(isStringOperation && (expr->op == AstExprBinary::Op::Concat || isComparison))) { - if (leftMt || rightMt) + if ((leftMt && !isString(leftType)) || (rightMt && !isString(rightType))) { if (isComparison) { @@ -1190,7 +1231,7 @@ struct TypeChecker2 expr->location); } - return; + return builtinTypes->errorRecoveryType(); } else if (!leftMt && !rightMt && (get(leftType) || get(rightType))) { @@ -1207,7 +1248,7 @@ struct TypeChecker2 expr->location); } - return; + return builtinTypes->errorRecoveryType(); } } } @@ -1223,34 +1264,44 @@ struct TypeChecker2 reportErrors(tryUnify(scope, expr->left->location, leftType, builtinTypes->numberType)); reportErrors(tryUnify(scope, expr->right->location, rightType, builtinTypes->numberType)); - break; + return builtinTypes->numberType; case AstExprBinary::Op::Concat: reportErrors(tryUnify(scope, expr->left->location, leftType, builtinTypes->stringType)); reportErrors(tryUnify(scope, expr->right->location, rightType, builtinTypes->stringType)); - break; + return builtinTypes->stringType; case AstExprBinary::Op::CompareGe: case AstExprBinary::Op::CompareGt: case AstExprBinary::Op::CompareLe: case AstExprBinary::Op::CompareLt: if (isNumber(leftType)) + { reportErrors(tryUnify(scope, expr->right->location, rightType, builtinTypes->numberType)); + return builtinTypes->numberType; + } else if (isString(leftType)) + { reportErrors(tryUnify(scope, expr->right->location, rightType, builtinTypes->stringType)); + return builtinTypes->stringType; + } else + { reportError(GenericError{format("Types '%s' and '%s' cannot be compared with relational operator %s", toString(leftType).c_str(), toString(rightType).c_str(), toString(expr->op).c_str())}, expr->location); - - break; + return builtinTypes->errorRecoveryType(); + } case AstExprBinary::Op::And: case AstExprBinary::Op::Or: case AstExprBinary::Op::CompareEq: case AstExprBinary::Op::CompareNe: - break; + // Ugly case: we don't care about this possibility, because a + // compound assignment will never exist with one of these operators. + return builtinTypes->anyType; default: // Unhandled AstExprBinary::Op possibility. LUAU_ASSERT(false); + return builtinTypes->errorRecoveryType(); } } diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index f31ea9381..5c1ee3888 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -16,9 +16,10 @@ #include "Luau/TopoSortStatements.h" #include "Luau/ToString.h" #include "Luau/ToString.h" +#include "Luau/Type.h" #include "Luau/TypePack.h" +#include "Luau/TypeReduction.h" #include "Luau/TypeUtils.h" -#include "Luau/Type.h" #include "Luau/VisitType.h" #include @@ -35,17 +36,16 @@ LUAU_FASTFLAG(LuauTypeNormalization2) LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) LUAU_FASTFLAGVARIABLE(LuauReturnAnyInsteadOfICE, false) // Eventually removed as false. LUAU_FASTFLAGVARIABLE(DebugLuauSharedSelf, false) -LUAU_FASTFLAGVARIABLE(LuauNilIterator, false) LUAU_FASTFLAGVARIABLE(LuauUnknownAndNeverType, false) LUAU_FASTFLAGVARIABLE(LuauTypeInferMissingFollows, false) LUAU_FASTFLAGVARIABLE(LuauBinaryNeedsExpectedTypesToo, false) LUAU_FASTFLAGVARIABLE(LuauNeverTypesAndOperatorsInference, false) +LUAU_FASTFLAGVARIABLE(LuauScopelessModule, false) LUAU_FASTFLAGVARIABLE(LuauFollowInLvalueIndexCheck, false) LUAU_FASTFLAGVARIABLE(LuauReturnsFromCallsitesAreNotWidened, false) LUAU_FASTFLAGVARIABLE(LuauTryhardAnd, false) LUAU_FASTFLAG(LuauInstantiateInSubtyping) LUAU_FASTFLAGVARIABLE(LuauCompleteVisitor, false) -LUAU_FASTFLAGVARIABLE(LuauOptionalNextKey, false) LUAU_FASTFLAGVARIABLE(LuauReportShadowedTypeAlias, false) LUAU_FASTFLAGVARIABLE(LuauBetterMessagingOnCountMismatch, false) LUAU_FASTFLAGVARIABLE(LuauIntersectionTestForEquality, false) @@ -276,6 +276,7 @@ ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mo LUAU_TIMETRACE_ARGUMENT("module", module.name.c_str()); currentModule.reset(new Module); + currentModule->reduction = std::make_unique(NotNull{¤tModule->internalTypes}, builtinTypes, NotNull{iceHandler}); currentModule->type = module.type; currentModule->allocator = module.allocator; currentModule->names = module.names; @@ -1136,7 +1137,8 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatLocal& local) const Name name{local.vars.data[i]->name.value}; if (ModulePtr module = resolver->getModule(moduleInfo->name)) - scope->importedTypeBindings[name] = module->getModuleScope()->exportedTypeBindings; + scope->importedTypeBindings[name] = + FFlag::LuauScopelessModule ? module->exportedTypeBindings : module->getModuleScope()->exportedTypeBindings; // In non-strict mode we force the module type on the variable, in strict mode it is already unified if (isNonstrictMode()) @@ -1248,8 +1250,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) iterTy = instantiate(scope, checkExpr(scope, *firstValue).type, firstValue->location); } - if (FFlag::LuauNilIterator) - iterTy = stripFromNilAndReport(iterTy, firstValue->location); + iterTy = stripFromNilAndReport(iterTy, firstValue->location); if (std::optional iterMM = findMetatableEntry(iterTy, "__iter", firstValue->location, /* addErrors= */ true)) { @@ -1334,61 +1335,40 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) reportErrors(state.errors); } - if (FFlag::LuauOptionalNextKey) + TypePackId retPack = iterFunc->retTypes; + + if (forin.values.size >= 2) { - TypePackId retPack = iterFunc->retTypes; + AstArray arguments{forin.values.data + 1, forin.values.size - 1}; - if (forin.values.size >= 2) - { - AstArray arguments{forin.values.data + 1, forin.values.size - 1}; + Position start = firstValue->location.begin; + Position end = values[forin.values.size - 1]->location.end; + AstExprCall exprCall{Location(start, end), firstValue, arguments, /* self= */ false, Location()}; - Position start = firstValue->location.begin; - Position end = values[forin.values.size - 1]->location.end; - AstExprCall exprCall{Location(start, end), firstValue, arguments, /* self= */ false, Location()}; + retPack = checkExprPack(scope, exprCall).type; + } - retPack = checkExprPack(scope, exprCall).type; - } + // We need to remove 'nil' from the set of options of the first return value + // Because for loop stops when it gets 'nil', this result is never actually assigned to the first variable + if (std::optional fty = first(retPack); fty && !varTypes.empty()) + { + TypeId keyTy = follow(*fty); - // We need to remove 'nil' from the set of options of the first return value - // Because for loop stops when it gets 'nil', this result is never actually assigned to the first variable - if (std::optional fty = first(retPack); fty && !varTypes.empty()) + if (get(keyTy)) { - TypeId keyTy = follow(*fty); - - if (get(keyTy)) - { - if (std::optional ty = tryStripUnionFromNil(keyTy)) - keyTy = *ty; - } - - unify(keyTy, varTypes.front(), scope, forin.location); - - // We have already handled the first variable type, make it match in the pack check - varTypes.front() = *fty; + if (std::optional ty = tryStripUnionFromNil(keyTy)) + keyTy = *ty; } - TypePackId varPack = addTypePack(TypePackVar{TypePack{varTypes, freshTypePack(scope)}}); + unify(keyTy, varTypes.front(), scope, forin.location); - unify(retPack, varPack, scope, forin.location); + // We have already handled the first variable type, make it match in the pack check + varTypes.front() = *fty; } - else - { - TypePackId varPack = addTypePack(TypePackVar{TypePack{varTypes, freshTypePack(scope)}}); - if (forin.values.size >= 2) - { - AstArray arguments{forin.values.data + 1, forin.values.size - 1}; + TypePackId varPack = addTypePack(TypePackVar{TypePack{varTypes, freshTypePack(scope)}}); - Position start = firstValue->location.begin; - Position end = values[forin.values.size - 1]->location.end; - AstExprCall exprCall{Location(start, end), firstValue, arguments, /* self= */ false, Location()}; - - TypePackId retPack = checkExprPack(scope, exprCall).type; - unify(retPack, varPack, scope, forin.location); - } - else - unify(iterFunc->retTypes, varPack, scope, forin.location); - } + unify(retPack, varPack, scope, forin.location); check(loopScope, *forin.body); } @@ -4685,7 +4665,7 @@ TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& module return errorRecoveryType(scope); } - TypePackId modulePack = module->getModuleScope()->returnType; + TypePackId modulePack = FFlag::LuauScopelessModule ? module->returnType : module->getModuleScope()->returnType; if (get(modulePack)) return errorRecoveryType(scope); diff --git a/Analysis/src/TypeReduction.cpp b/Analysis/src/TypeReduction.cpp new file mode 100644 index 000000000..e47ee39cc --- /dev/null +++ b/Analysis/src/TypeReduction.cpp @@ -0,0 +1,838 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/TypeReduction.h" + +#include "Luau/Common.h" +#include "Luau/Error.h" +#include "Luau/RecursionCounter.h" + +#include +#include + +LUAU_FASTINTVARIABLE(LuauTypeReductionCartesianProductLimit, 100'000) +LUAU_FASTINTVARIABLE(LuauTypeReductionRecursionLimit, 900) +LUAU_FASTFLAGVARIABLE(DebugLuauDontReduceTypes, false) + +namespace Luau +{ + +namespace +{ + +struct RecursionGuard : RecursionLimiter +{ + std::deque* seen; + + RecursionGuard(int* count, int limit, std::deque* seen) + : RecursionLimiter(count, limit) + , seen(seen) + { + // count has been incremented, which should imply that seen has already had an element pushed in. + LUAU_ASSERT(*count == seen->size()); + } + + ~RecursionGuard() + { + LUAU_ASSERT(!seen->empty()); // It is UB to pop_back() on an empty deque. + seen->pop_back(); + } +}; + +template +std::pair get2(const Thing& one, const Thing& two) +{ + const A* a = get(one); + const B* b = get(two); + return a && b ? std::make_pair(a, b) : std::make_pair(nullptr, nullptr); +} + +struct TypeReducer +{ + NotNull arena; + NotNull builtinTypes; + NotNull handle; + + TypeId reduce(TypeId ty); + TypePackId reduce(TypePackId tp); + + std::optional intersectionType(TypeId left, TypeId right); + std::optional unionType(TypeId left, TypeId right); + TypeId tableType(TypeId ty); + TypeId functionType(TypeId ty); + TypeId negationType(TypeId ty); + + std::deque seen; + int depth = 0; + + RecursionGuard guard(TypeId ty); + RecursionGuard guard(TypePackId tp); + + std::unordered_map copies; + + template + LUAU_NOINLINE std::pair copy(TypeId ty, const T* t) + { + if (auto it = copies.find(ty); it != copies.end()) + return {it->second, getMutable(it->second)}; + + TypeId copiedTy = arena->addType(*t); + copies[ty] = copiedTy; + return {copiedTy, getMutable(copiedTy)}; + } + + using Folder = std::optional (TypeReducer::*)(TypeId, TypeId); + + template + void foldl_impl(Iter it, Iter endIt, Folder f, NotNull> result) + { + while (it != endIt) + { + bool replaced = false; + TypeId currentTy = reduce(*it); + RecursionGuard rg = guard(*it); + + // We're hitting a case where the `currentTy` returned a type that's the same as `T`. + // e.g. `(string?) & ~(false | nil)` became `(string?) & (~false & ~nil)` but the current iterator we're consuming doesn't know this. + // We will need to recurse and traverse that first. + if (auto t = get(currentTy)) + { + foldl_impl(begin(t), end(t), f, result); + ++it; + continue; + } + + auto resultIt = result->begin(); + while (resultIt != result->end()) + { + TypeId& ty = *resultIt; + + std::optional reduced = (this->*f)(ty, currentTy); + if (reduced && replaced) + { + // We want to erase any other elements that occurs after the first replacement too. + // e.g. `"a" | "b" | string` where `"a"` and `"b"` is in the `result` vector, then `string` replaces both `"a"` and `"b"`. + // If we don't erase redundant elements, `"b"` may be kept or be replaced by `string`, leaving us with `string | string`. + resultIt = result->erase(resultIt); + } + else if (reduced && !replaced) + { + ++resultIt; + replaced = true; + ty = *reduced; + } + else + { + ++resultIt; + continue; + } + } + + if (!replaced) + result->push_back(currentTy); + + ++it; + } + } + + template + TypeId foldl(Iter it, Iter endIt, Folder f) + { + std::vector result; + foldl_impl(it, endIt, f, NotNull{&result}); + if (result.size() == 1) + return result[0]; + else + return arena->addType(T{std::move(result)}); + } +}; + +TypeId TypeReducer::reduce(TypeId ty) +{ + ty = follow(ty); + + if (std::find(seen.begin(), seen.end(), ty) != seen.end()) + return ty; + + RecursionGuard rg = guard(ty); + + if (auto i = get(ty)) + return foldl(begin(i), end(i), &TypeReducer::intersectionType); + else if (auto u = get(ty)) + return foldl(begin(u), end(u), &TypeReducer::unionType); + else if (get(ty) || get(ty)) + return tableType(ty); + else if (get(ty)) + return functionType(ty); + else if (auto n = get(ty)) + return negationType(follow(n->ty)); + else + return ty; +} + +TypePackId TypeReducer::reduce(TypePackId tp) +{ + tp = follow(tp); + + if (std::find(seen.begin(), seen.end(), tp) != seen.end()) + return tp; + + RecursionGuard rg = guard(tp); + + TypePackIterator it = begin(tp); + + std::vector head; + while (it != end(tp)) + { + head.push_back(reduce(*it)); + ++it; + } + + std::optional tail = it.tail(); + if (tail) + { + if (auto vtp = get(follow(*it.tail()))) + tail = arena->addTypePack(VariadicTypePack{reduce(vtp->ty), vtp->hidden}); + } + + return arena->addTypePack(TypePack{std::move(head), tail}); +} + +std::optional TypeReducer::intersectionType(TypeId left, TypeId right) +{ + LUAU_ASSERT(!get(left)); + LUAU_ASSERT(!get(right)); + + if (get(left)) + return left; // never & T ~ never + else if (get(right)) + return right; // T & never ~ never + else if (get(left)) + return right; // unknown & T ~ T + else if (get(right)) + return left; // T & unknown ~ T + else if (get(left)) + return right; // any & T ~ T + else if (get(right)) + return left; // T & any ~ T + else if (get(left)) + return std::nullopt; // error & T ~ error & T + else if (get(right)) + return std::nullopt; // T & error ~ T & error + else if (auto ut = get(left)) + { + std::vector options; + for (TypeId option : ut) + { + if (auto result = intersectionType(option, right)) + options.push_back(*result); + else + options.push_back(arena->addType(IntersectionType{{option, right}})); + } + + return foldl(begin(options), end(options), &TypeReducer::unionType); // (A | B) & T ~ (A & T) | (B & T) + } + else if (get(right)) + return intersectionType(right, left); // T & (A | B) ~ (A | B) & T + else if (auto [p1, p2] = get2(left, right); p1 && p2) + { + if (p1->type == p2->type) + return left; // P1 & P2 ~ P1 iff P1 == P2 + else + return builtinTypes->neverType; // P1 & P2 ~ never iff P1 != P2 + } + else if (auto [p, s] = get2(left, right); p && s) + { + if (p->type == PrimitiveType::String && get(s)) + return right; // string & "A" ~ "A" + else if (p->type == PrimitiveType::Boolean && get(s)) + return right; // boolean & true ~ true + else + return builtinTypes->neverType; // string & true ~ never + } + else if (auto [s, p] = get2(left, right); s && p) + return intersectionType(right, left); // S & P ~ P & S + else if (auto [p, f] = get2(left, right); p && f) + { + if (p->type == PrimitiveType::Function) + return right; // function & () -> () ~ () -> () + else + return builtinTypes->neverType; // string & () -> () ~ never + } + else if (auto [f, p] = get2(left, right); f && p) + return intersectionType(right, left); // () -> () & P ~ P & () -> () + else if (auto [s1, s2] = get2(left, right); s1 && s2) + { + if (*s1 == *s2) + return left; // "a" & "a" ~ "a" + else + return builtinTypes->neverType; // "a" & "b" ~ never + } + else if (auto [c1, c2] = get2(left, right); c1 && c2) + { + if (isSubclass(c1, c2)) + return left; // Derived & Base ~ Derived + else if (isSubclass(c2, c1)) + return right; // Base & Derived ~ Derived + else + return builtinTypes->neverType; // Base & Unrelated ~ never + } + else if (auto [f1, f2] = get2(left, right); f1 && f2) + { + if (std::find(seen.begin(), seen.end(), left) != seen.end()) + return std::nullopt; + else if (std::find(seen.begin(), seen.end(), right) != seen.end()) + return std::nullopt; + + return std::nullopt; // TODO + } + else if (auto [t1, t2] = get2(left, right); t1 && t2) + { + if (t1->state == TableState::Free || t2->state == TableState::Free) + return std::nullopt; // '{ x: T } & { x: U } ~ '{ x: T } & { x: U } + else if (t1->state == TableState::Generic || t2->state == TableState::Generic) + return std::nullopt; // '{ x: T } & { x: U } ~ '{ x: T } & { x: U } + + if (std::find(seen.begin(), seen.end(), left) != seen.end()) + return std::nullopt; + else if (std::find(seen.begin(), seen.end(), right) != seen.end()) + return std::nullopt; + + TypeId resultTy = arena->addType(TableType{}); + TableType* table = getMutable(resultTy); + table->state = t1->state == TableState::Sealed || t2->state == TableState::Sealed ? TableState::Sealed : TableState::Unsealed; + + for (const auto& [name, prop] : t1->props) + { + // TODO: when t1 has properties, we should also intersect that with the indexer in t2 if it exists, + // even if we have the corresponding property in the other one. + if (auto other = t2->props.find(name); other != t2->props.end()) + { + std::vector parts{prop.type, other->second.type}; + TypeId propTy = foldl(begin(parts), end(parts), &TypeReducer::intersectionType); + if (get(propTy)) + return builtinTypes->neverType; // { p : string } & { p : number } ~ { p : string & number } ~ { p : never } ~ never + else + table->props[name] = {propTy}; // { p : string } & { p : ~"a" } ~ { p : string & ~"a" } + } + else + table->props[name] = prop; // { p : string } & {} ~ { p : string } + } + + for (const auto& [name, prop] : t2->props) + { + // TODO: And vice versa, t2 properties against t1 indexer if it exists, + // even if we have the corresponding property in the other one. + if (!t1->props.count(name)) + table->props[name] = prop; // {} & { p : string } ~ { p : string } + } + + if (t1->indexer && t2->indexer) + { + std::vector keyParts{t1->indexer->indexType, t2->indexer->indexType}; + TypeId keyTy = foldl(begin(keyParts), end(keyParts), &TypeReducer::intersectionType); + if (get(keyTy)) + return builtinTypes->neverType; // { [string]: _ } & { [number]: _ } ~ { [string & number]: _ } ~ { [never]: _ } ~ never + + std::vector valueParts{t1->indexer->indexResultType, t2->indexer->indexResultType}; + TypeId valueTy = foldl(begin(valueParts), end(valueParts), &TypeReducer::intersectionType); + if (get(valueTy)) + return builtinTypes->neverType; // { [_]: string } & { [_]: number } ~ { [_]: string & number } ~ { [_]: never } ~ never + + table->indexer = TableIndexer{keyTy, valueTy}; + } + else if (t1->indexer) + table->indexer = t1->indexer; // { [number]: boolean } & { p : string } ~ { p : string, [number]: boolean } + else if (t2->indexer) + table->indexer = t2->indexer; // { p : string } & { [number]: boolean } ~ { p : string, [number]: boolean } + + return resultTy; + } + else if (auto [mt, tt] = get2(left, right); mt && tt) + return std::nullopt; // TODO + else if (auto [tt, mt] = get2(left, right); tt && mt) + return intersectionType(right, left); // T & M ~ M & T + else if (auto [m1, m2] = get2(left, right); m1 && m2) + return std::nullopt; // TODO + else if (auto nl = get(left)) + { + // These should've been reduced already. + TypeId nlTy = follow(nl->ty); + LUAU_ASSERT(!get(nlTy)); + LUAU_ASSERT(!get(nlTy)); + LUAU_ASSERT(!get(nlTy)); + LUAU_ASSERT(!get(nlTy)); + LUAU_ASSERT(!get(nlTy)); + + if (auto [np, p] = get2(nlTy, right); np && p) + { + if (np->type == p->type) + return builtinTypes->neverType; // ~P1 & P2 ~ never iff P1 == P2 + else + return right; // ~P1 & P2 ~ P2 iff P1 != P2 + } + else if (auto [ns, s] = get2(nlTy, right); ns && s) + { + if (*ns == *s) + return builtinTypes->neverType; // ~"A" & "A" ~ never + else + return right; // ~"A" & "B" ~ "B" + } + else if (auto [ns, p] = get2(nlTy, right); ns && p) + { + if (get(ns) && p->type == PrimitiveType::String) + return std::nullopt; // ~"A" & string ~ ~"A" & string + else if (get(ns) && p->type == PrimitiveType::Boolean) + { + // Because booleans contain a fixed amount of values (2), we can do something cooler with this one. + const BooleanSingleton* b = get(ns); + return arena->addType(SingletonType{BooleanSingleton{!b->value}}); // ~false & boolean ~ true + } + else + return right; // ~"A" & number ~ number + } + else if (auto [np, s] = get2(nlTy, right); np && s) + { + if (np->type == PrimitiveType::String && get(s)) + return builtinTypes->neverType; // ~string & "A" ~ never + else if (np->type == PrimitiveType::Boolean && get(s)) + return builtinTypes->neverType; // ~boolean & true ~ never + else + return right; // ~P & "A" ~ "A" + } + else if (auto [np, f] = get2(nlTy, right); np && f) + { + if (np->type == PrimitiveType::Function) + return builtinTypes->neverType; // ~function & () -> () ~ never + else + return right; // ~string & () -> () ~ () -> () + } + else if (auto [nc, c] = get2(nlTy, right); nc && c) + { + if (isSubclass(nc, c)) + return std::nullopt; // ~Derived & Base ~ ~Derived & Base + else if (isSubclass(c, nc)) + return builtinTypes->neverType; // ~Base & Derived ~ never + else + return right; // ~Base & Unrelated ~ Unrelated + } + else + return std::nullopt; // TODO + } + else if (get(right)) + return intersectionType(right, left); // T & ~U ~ ~U & T + else + return builtinTypes->neverType; // for all T and U except the ones handled above, T & U ~ never +} + +std::optional TypeReducer::unionType(TypeId left, TypeId right) +{ + LUAU_ASSERT(!get(left)); + LUAU_ASSERT(!get(right)); + + if (get(left)) + return right; // never | T ~ T + else if (get(right)) + return left; // T | never ~ T + else if (get(left)) + return left; // unknown | T ~ unknown + else if (get(right)) + return right; // T | unknown ~ unknown + else if (get(left)) + return left; // any | T ~ any + else if (get(right)) + return right; // T | any ~ any + else if (get(left)) + return std::nullopt; // error | T ~ error | T + else if (get(right)) + return std::nullopt; // T | error ~ T | error + else if (auto [p1, p2] = get2(left, right); p1 && p2) + { + if (p1->type == p2->type) + return left; // P1 | P2 ~ P1 iff P1 == P2 + else + return std::nullopt; // P1 | P2 ~ P1 | P2 iff P1 != P2 + } + else if (auto [p, s] = get2(left, right); p && s) + { + if (p->type == PrimitiveType::String && get(s)) + return left; // string | "A" ~ string + else if (p->type == PrimitiveType::Boolean && get(s)) + return left; // boolean | true ~ boolean + else + return std::nullopt; // string | true ~ string | true + } + else if (auto [s, p] = get2(left, right); s && p) + return unionType(right, left); // S | P ~ P | S + else if (auto [p, f] = get2(left, right); p && f) + { + if (p->type == PrimitiveType::Function) + return left; // function | () -> () ~ function + else + return std::nullopt; // P | () -> () ~ P | () -> () + } + else if (auto [f, p] = get2(left, right); f && p) + return unionType(right, left); // () -> () | P ~ P | () -> () + else if (auto [s1, s2] = get2(left, right); s1 && s2) + { + if (*s1 == *s2) + return left; // "a" | "a" ~ "a" + else + return std::nullopt; // "a" | "b" ~ "a" | "b" + } + else if (auto [c1, c2] = get2(left, right); c1 && c2) + { + if (isSubclass(c1, c2)) + return right; // Derived | Base ~ Base + else if (isSubclass(c2, c1)) + return left; // Base | Derived ~ Base + else + return std::nullopt; // Base | Unrelated ~ Base | Unrelated + } + else if (auto [nt, it] = get2(left, right); nt && it) + { + std::vector parts; + for (TypeId option : it) + { + if (auto result = unionType(left, option)) + parts.push_back(*result); + else + { + // TODO: does there exist a reduced form such that `~T | A` hasn't already reduced it, if `A & B` is irreducible? + // I want to say yes, but I can't generate a case that hits this code path. + parts.push_back(arena->addType(UnionType{{left, option}})); + } + } + + return foldl(begin(parts), end(parts), &TypeReducer::intersectionType); // ~T | (A & B) ~ (~T | A) & (~T | B) + } + else if (auto [it, nt] = get2(left, right); it && nt) + return unionType(right, left); // (A & B) | ~T ~ ~T | (A & B) + else if (auto [nl, nr] = get2(left, right); nl && nr) + { + // These should've been reduced already. + TypeId nlTy = follow(nl->ty); + TypeId nrTy = follow(nr->ty); + LUAU_ASSERT(!get(nlTy) && !get(nrTy)); + LUAU_ASSERT(!get(nlTy) && !get(nrTy)); + LUAU_ASSERT(!get(nlTy) && !get(nrTy)); + LUAU_ASSERT(!get(nlTy) && !get(nrTy)); + LUAU_ASSERT(!get(nlTy) && !get(nrTy)); + + if (auto [npl, npr] = get2(nlTy, nrTy); npl && npr) + { + if (npl->type == npr->type) + return left; // ~P1 | ~P2 ~ ~P1 iff P1 == P2 + else + return builtinTypes->unknownType; // ~P1 | ~P2 ~ ~P1 iff P1 != P2 + } + else if (auto [nsl, nsr] = get2(nlTy, nrTy); nsl && nsr) + { + if (*nsl == *nsr) + return left; // ~"A" | ~"A" ~ ~"A" + else + return builtinTypes->unknownType; // ~"A" | ~"B" ~ unknown + } + else if (auto [ns, np] = get2(nlTy, nrTy); ns && np) + { + if (get(ns) && np->type == PrimitiveType::String) + return left; // ~"A" | ~string ~ ~"A" + else if (get(ns) && np->type == PrimitiveType::Boolean) + return left; // ~false | ~boolean ~ ~false + else + return builtinTypes->unknownType; // ~"A" | ~P ~ unknown + } + else if (auto [np, ns] = get2(nlTy, nrTy); np && ns) + return unionType(right, left); // ~P | ~S ~ ~S | ~P + else + return std::nullopt; // TODO! + } + else if (auto nl = get(left)) + { + // These should've been reduced already. + TypeId nlTy = follow(nl->ty); + LUAU_ASSERT(!get(nlTy)); + LUAU_ASSERT(!get(nlTy)); + LUAU_ASSERT(!get(nlTy)); + LUAU_ASSERT(!get(nlTy)); + LUAU_ASSERT(!get(nlTy)); + + if (auto [np, p] = get2(nlTy, right); np && p) + { + if (np->type == p->type) + return builtinTypes->unknownType; // ~P1 | P2 ~ unknown iff P1 == P2 + else + return left; // ~P1 | P2 ~ ~P1 iff P1 != P2 + } + else if (auto [ns, s] = get2(nlTy, right); ns && s) + { + if (*ns == *s) + return builtinTypes->unknownType; // ~"A" | "A" ~ unknown + else + return left; // ~"A" | "B" ~ ~"A" + } + else if (auto [ns, p] = get2(nlTy, right); ns && p) + { + if (get(ns) && p->type == PrimitiveType::String) + return builtinTypes->unknownType; // ~"A" | string ~ unknown + else if (get(ns) && p->type == PrimitiveType::Boolean) + return builtinTypes->unknownType; // ~false | boolean ~ unknown + else + return left; // ~"A" | T ~ ~"A" + } + else if (auto [np, s] = get2(nlTy, right); np && s) + { + if (np->type == PrimitiveType::String && get(s)) + return std::nullopt; // ~string | "A" ~ ~string | "A" + else if (np->type == PrimitiveType::Boolean && get(s)) + { + const BooleanSingleton* b = get(s); + return negationType(arena->addType(SingletonType{BooleanSingleton{!b->value}})); // ~boolean | false ~ ~true + } + else + return left; // ~P | "A" ~ ~P + } + else if (auto [nc, c] = get2(nlTy, right); nc && c) + { + if (isSubclass(nc, c)) + return builtinTypes->unknownType; // ~Derived | Base ~ unknown + else if (isSubclass(c, nc)) + return std::nullopt; // ~Base | Derived ~ ~Base | Derived + else + return left; // ~Base | Unrelated ~ ~Base + } + else + return std::nullopt; // TODO + } + else if (get(right)) + return unionType(right, left); // T | ~U ~ ~U | T + else + return std::nullopt; // for all T and U except the ones handled above, T | U ~ T | U +} + +TypeId TypeReducer::tableType(TypeId ty) +{ + RecursionGuard rg = guard(ty); + + if (auto mt = get(ty)) + { + auto [copiedTy, copied] = copy(ty, mt); + copied->table = reduce(mt->table); + copied->metatable = reduce(mt->metatable); + return copiedTy; + } + else if (auto tt = get(ty)) + { + auto [copiedTy, copied] = copy(ty, tt); + + for (auto& [name, prop] : copied->props) + prop.type = reduce(prop.type); + + if (auto& indexer = copied->indexer) + { + indexer->indexType = reduce(indexer->indexType); + indexer->indexResultType = reduce(indexer->indexResultType); + } + + for (TypeId& ty : copied->instantiatedTypeParams) + ty = reduce(ty); + + for (TypePackId& tp : copied->instantiatedTypePackParams) + tp = reduce(tp); + + return copiedTy; + } + else + handle->ice("Unexpected type in TypeReducer::tableType"); +} + +TypeId TypeReducer::functionType(TypeId ty) +{ + RecursionGuard rg = guard(ty); + + const FunctionType* f = get(ty); + if (!f) + handle->ice("TypeReducer::reduce expects a FunctionType"); + + // TODO: once we have bounded quantification, we need to be able to reduce the generic bounds. + auto [copiedTy, copied] = copy(ty, f); + copied->argTypes = reduce(f->argTypes); + copied->retTypes = reduce(f->retTypes); + return copiedTy; +} + +TypeId TypeReducer::negationType(TypeId ty) +{ + RecursionGuard rg = guard(ty); + + if (auto nn = get(ty)) + return nn->ty; // ~~T ~ T + else if (get(ty)) + return builtinTypes->unknownType; // ~never ~ unknown + else if (get(ty)) + return builtinTypes->neverType; // ~unknown ~ never + else if (get(ty)) + return builtinTypes->anyType; // ~any ~ any + else if (auto ni = get(ty)) + { + std::vector options; + for (TypeId part : ni) + options.push_back(negationType(part)); + return foldl(begin(options), end(options), &TypeReducer::unionType); // ~(T & U) ~ (~T | ~U) + } + else if (auto nu = get(ty)) + { + std::vector parts; + for (TypeId option : nu) + parts.push_back(negationType(option)); + return foldl(begin(parts), end(parts), &TypeReducer::intersectionType); // ~(T | U) ~ (~T & ~U) + } + else + return arena->addType(NegationType{ty}); // for all T except the ones handled above, ~T ~ ~T +} + +RecursionGuard TypeReducer::guard(TypeId ty) +{ + seen.push_back(ty); + return RecursionGuard{&depth, FInt::LuauTypeReductionRecursionLimit, &seen}; +} + +RecursionGuard TypeReducer::guard(TypePackId tp) +{ + seen.push_back(tp); + return RecursionGuard{&depth, FInt::LuauTypeReductionRecursionLimit, &seen}; +} + +} // namespace + +TypeReduction::TypeReduction(NotNull arena, NotNull builtinTypes, NotNull handle) + : arena(arena) + , builtinTypes(builtinTypes) + , handle(handle) +{ +} + +std::optional TypeReduction::reduce(TypeId ty) +{ + if (auto found = cachedTypes.find(ty)) + return *found; + + if (auto reduced = reduceImpl(ty)) + { + cachedTypes[ty] = *reduced; + return *reduced; + } + + return std::nullopt; +} + +std::optional TypeReduction::reduce(TypePackId tp) +{ + if (auto found = cachedTypePacks.find(tp)) + return *found; + + if (auto reduced = reduceImpl(tp)) + { + cachedTypePacks[tp] = *reduced; + return *reduced; + } + + return std::nullopt; +} + +std::optional TypeReduction::reduceImpl(TypeId ty) +{ + if (FFlag::DebugLuauDontReduceTypes) + return ty; + + if (hasExceededCartesianProductLimit(ty)) + return std::nullopt; + + try + { + TypeReducer reducer{arena, builtinTypes, handle}; + return reducer.reduce(ty); + } + catch (const RecursionLimitException&) + { + return std::nullopt; + } +} + +std::optional TypeReduction::reduceImpl(TypePackId tp) +{ + if (FFlag::DebugLuauDontReduceTypes) + return tp; + + if (hasExceededCartesianProductLimit(tp)) + return std::nullopt; + + try + { + TypeReducer reducer{arena, builtinTypes, handle}; + return reducer.reduce(tp); + } + catch (const RecursionLimitException&) + { + return std::nullopt; + } +} + +std::optional TypeReduction::reduce(const TypeFun& fun) +{ + if (FFlag::DebugLuauDontReduceTypes) + return fun; + + // TODO: once we have bounded quantification, we need to be able to reduce the generic bounds. + if (auto reducedTy = reduce(fun.type)) + return TypeFun{fun.typeParams, fun.typePackParams, *reducedTy}; + + return std::nullopt; +} + +size_t TypeReduction::cartesianProductSize(TypeId ty) const +{ + ty = follow(ty); + + auto it = get(follow(ty)); + if (!it) + return 1; + + return std::accumulate(begin(it), end(it), size_t(1), [](size_t acc, TypeId ty) { + if (auto ut = get(ty)) + return acc * std::distance(begin(ut), end(ut)); + else if (get(ty)) + return acc * 0; + else + return acc * 1; + }); +} + +bool TypeReduction::hasExceededCartesianProductLimit(TypeId ty) const +{ + return cartesianProductSize(ty) >= size_t(FInt::LuauTypeReductionCartesianProductLimit); +} + +bool TypeReduction::hasExceededCartesianProductLimit(TypePackId tp) const +{ + TypePackIterator it = begin(tp); + + while (it != end(tp)) + { + if (hasExceededCartesianProductLimit(*it)) + return true; + + ++it; + } + + if (auto tail = it.tail()) + { + if (auto vtp = get(follow(*tail))) + { + if (hasExceededCartesianProductLimit(vtp->ty)) + return true; + } + } + + return false; +} + +} // namespace Luau diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index b4dea4f56..8a1e80fc5 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -27,6 +27,7 @@ LUAU_FASTINTVARIABLE(LuauCompileInlineDepth, 5) LUAU_FASTFLAG(LuauInterpolatedStringBaseSupport) LUAU_FASTFLAGVARIABLE(LuauMultiAssignmentConflictFix, false) +LUAU_FASTFLAGVARIABLE(LuauSelfAssignmentSkip, false) namespace Luau { @@ -2027,7 +2028,9 @@ struct Compiler // note: this can't check expr->upvalue because upvalues may be upgraded to locals during inlining if (int reg = getExprLocalReg(expr); reg >= 0) { - bytecode.emitABC(LOP_MOVE, target, uint8_t(reg), 0); + // Optimization: we don't need to move if target happens to be in the same register + if (!FFlag::LuauSelfAssignmentSkip || options.optimizationLevel == 0 || target != reg) + bytecode.emitABC(LOP_MOVE, target, uint8_t(reg), 0); } else { diff --git a/Sources.cmake b/Sources.cmake index 437ff9934..87d76bf39 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -146,6 +146,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/include/Luau/TypedAllocator.h Analysis/include/Luau/TypeInfer.h Analysis/include/Luau/TypePack.h + Analysis/include/Luau/TypeReduction.h Analysis/include/Luau/TypeUtils.h Analysis/include/Luau/Type.h Analysis/include/Luau/Unifiable.h @@ -195,6 +196,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/src/TypedAllocator.cpp Analysis/src/TypeInfer.cpp Analysis/src/TypePack.cpp + Analysis/src/TypeReduction.cpp Analysis/src/TypeUtils.cpp Analysis/src/Type.cpp Analysis/src/Unifiable.cpp @@ -364,6 +366,7 @@ if(TARGET Luau.UnitTest) tests/TypeInfer.unionTypes.test.cpp tests/TypeInfer.unknownnever.test.cpp tests/TypePack.test.cpp + tests/TypeReduction.test.cpp tests/TypeVar.test.cpp tests/Variant.test.cpp tests/VisitType.test.cpp diff --git a/VM/include/lua.h b/VM/include/lua.h index d13305783..649c96c1a 100644 --- a/VM/include/lua.h +++ b/VM/include/lua.h @@ -17,9 +17,9 @@ /* ** pseudo-indices */ -#define LUA_REGISTRYINDEX (-10000) -#define LUA_ENVIRONINDEX (-10001) -#define LUA_GLOBALSINDEX (-10002) +#define LUA_REGISTRYINDEX (-LUAI_MAXCSTACK - 2000) +#define LUA_ENVIRONINDEX (-LUAI_MAXCSTACK - 2001) +#define LUA_GLOBALSINDEX (-LUAI_MAXCSTACK - 2002) #define lua_upvalueindex(i) (LUA_GLOBALSINDEX - (i)) #define lua_ispseudo(i) ((i) <= LUA_REGISTRYINDEX) diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index 105829473..f241963a6 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -3331,4 +3331,36 @@ TEST_CASE_FIXTURE(ACFixture, "globals_are_order_independent") CHECK(ac.entryMap.count("abc1")); } +TEST_CASE_FIXTURE(ACFixture, "type_reduction_is_hooked_up_to_autocomplete") +{ + check(R"( + type T = { x: (number & string)? } + + function f(thingamabob: T) + thingamabob.@1 + end + + function g(thingamabob: T) + thingama@2 + end + )"); + + ToStringOptions opts; + opts.exhaustive = true; + + auto ac1 = autocomplete('1'); + REQUIRE(ac1.entryMap.count("x")); + std::optional ty1 = ac1.entryMap.at("x").type; + REQUIRE(ty1); + CHECK("(number & string)?" == toString(*ty1, opts)); + // CHECK("nil" == toString(*ty1, opts)); + + auto ac2 = autocomplete('2'); + REQUIRE(ac2.entryMap.count("thingamabob")); + std::optional ty2 = ac2.entryMap.at("thingamabob").type; + REQUIRE(ty2); + CHECK("{| x: (number & string)? |}" == toString(*ty2, opts)); + // CHECK("{| x: nil |}" == toString(*ty2, opts)); +} + TEST_SUITE_END(); diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index d2cf0ae8e..1a6061267 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -1025,6 +1025,8 @@ L0: RETURN R0 0 TEST_CASE("AndOr") { + ScopedFastFlag luauSelfAssignmentSkip{"LuauSelfAssignmentSkip", true}; + // codegen for constant, local, global for and CHECK_EQ("\n" + compileFunction0("local a = 1 a = a and 2 return a"), R"( LOADN R0 1 @@ -1079,7 +1081,6 @@ RETURN R0 1 // note: `a = a` assignment is to disable constant folding for testing purposes CHECK_EQ("\n" + compileFunction0("local a = 1 a = a b = 2 local c = a and b return c"), R"( LOADN R0 1 -MOVE R0 R0 LOADN R1 2 SETGLOBAL R1 K0 MOVE R1 R0 @@ -1090,7 +1091,6 @@ L0: RETURN R1 1 CHECK_EQ("\n" + compileFunction0("local a = 1 a = a b = 2 local c = a or b return c"), R"( LOADN R0 1 -MOVE R0 R0 LOADN R1 2 SETGLOBAL R1 K0 MOVE R1 R0 @@ -2260,6 +2260,8 @@ L1: RETURN R3 -1 TEST_CASE("UpvaluesLoopsBytecode") { + ScopedFastFlag luauSelfAssignmentSkip{"LuauSelfAssignmentSkip", true}; + CHECK_EQ("\n" + compileFunction(R"( function test() for i=1,10 do @@ -2279,7 +2281,6 @@ LOADN R0 10 LOADN R1 1 FORNPREP R0 L2 L0: MOVE R3 R2 -MOVE R3 R3 GETIMPORT R4 1 NEWCLOSURE R5 P0 CAPTURE REF R3 @@ -2312,8 +2313,7 @@ GETIMPORT R0 1 GETIMPORT R1 3 CALL R0 1 3 FORGPREP_INEXT R0 L2 -L0: MOVE R3 R3 -GETIMPORT R5 5 +L0: GETIMPORT R5 5 NEWCLOSURE R6 P0 CAPTURE REF R3 CALL R5 1 0 @@ -5159,6 +5159,8 @@ RETURN R1 1 TEST_CASE("InlineMutate") { + ScopedFastFlag luauSelfAssignmentSkip{"LuauSelfAssignmentSkip", true}; + // if the argument is mutated, it gets a register even if the value is constant CHECK_EQ("\n" + compileFunction(R"( local function foo(a) @@ -5231,7 +5233,6 @@ return x 1, 2), R"( DUPCLOSURE R0 K0 -MOVE R0 R0 MOVE R1 R0 LOADN R2 42 CALL R1 1 1 @@ -6790,4 +6791,31 @@ L0: RETURN R1 -1 )"); } +TEST_CASE("SkipSelfAssignment") +{ + ScopedFastFlag luauSelfAssignmentSkip{"LuauSelfAssignmentSkip", true}; + + CHECK_EQ("\n" + compileFunction0("local a a = a"), R"( +LOADNIL R0 +RETURN R0 0 +)"); + + CHECK_EQ("\n" + compileFunction0("local a a = a :: number"), R"( +LOADNIL R0 +RETURN R0 0 +)"); + + CHECK_EQ("\n" + compileFunction0("local a a = (((a)))"), R"( +LOADNIL R0 +RETURN R0 0 +)"); + + // Keep it on optimization level 0 + CHECK_EQ("\n" + compileFunction("local a a = a", 0, 0), R"( +LOADNIL R0 +MOVE R0 R0 +RETURN R0 0 +)"); +} + TEST_SUITE_END(); diff --git a/tests/ConstraintGraphBuilderFixture.cpp b/tests/ConstraintGraphBuilderFixture.cpp index 64e6baaf0..4d7ee4fe6 100644 --- a/tests/ConstraintGraphBuilderFixture.cpp +++ b/tests/ConstraintGraphBuilderFixture.cpp @@ -1,6 +1,8 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "ConstraintGraphBuilderFixture.h" +#include "Luau/TypeReduction.h" + namespace Luau { @@ -9,6 +11,8 @@ ConstraintGraphBuilderFixture::ConstraintGraphBuilderFixture() , mainModule(new Module) , forceTheFlag{"DebugLuauDeferredConstraintResolution", true} { + mainModule->reduction = std::make_unique(NotNull{&mainModule->internalTypes}, builtinTypes, NotNull{&ice}); + BlockedType::nextIndex = 0; BlockedTypePack::nextIndex = 0; } diff --git a/tests/Fixture.cpp b/tests/Fixture.cpp index 416292817..5ff006277 100644 --- a/tests/Fixture.cpp +++ b/tests/Fixture.cpp @@ -312,6 +312,9 @@ std::optional Fixture::getType(const std::string& name) ModulePtr module = getMainModule(); REQUIRE(module); + if (!module->hasModuleScope()) + return std::nullopt; + if (FFlag::DebugLuauDeferredConstraintResolution) return linearSearchForBinding(module->getModuleScope().get(), name.c_str()); else @@ -329,11 +332,14 @@ TypeId Fixture::requireType(const ModuleName& moduleName, const std::string& nam { ModulePtr module = frontend.moduleResolver.getModule(moduleName); REQUIRE(module); - return requireType(module->getModuleScope(), name); + return requireType(module, name); } TypeId Fixture::requireType(const ModulePtr& module, const std::string& name) { + if (!module->hasModuleScope()) + FAIL("requireType: module scope data is not available"); + return requireType(module->getModuleScope(), name); } @@ -367,7 +373,12 @@ TypeId Fixture::requireTypeAtPosition(Position position) std::optional Fixture::lookupType(const std::string& name) { - if (auto typeFun = getMainModule()->getModuleScope()->lookupType(name)) + ModulePtr module = getMainModule(); + + if (!module->hasModuleScope()) + return std::nullopt; + + if (auto typeFun = module->getModuleScope()->lookupType(name)) return typeFun->type; return std::nullopt; @@ -375,12 +386,24 @@ std::optional Fixture::lookupType(const std::string& name) std::optional Fixture::lookupImportedType(const std::string& moduleAlias, const std::string& name) { - if (auto typeFun = getMainModule()->getModuleScope()->lookupImportedType(moduleAlias, name)) + ModulePtr module = getMainModule(); + + if (!module->hasModuleScope()) + FAIL("lookupImportedType: module scope data is not available"); + + if (auto typeFun = module->getModuleScope()->lookupImportedType(moduleAlias, name)) return typeFun->type; return std::nullopt; } +TypeId Fixture::requireTypeAlias(const std::string& name) +{ + std::optional ty = lookupType(name); + REQUIRE(ty); + return *ty; +} + std::string Fixture::decorateWithTypes(const std::string& code) { fileResolver.source[mainModuleName] = code; @@ -552,15 +575,52 @@ std::optional linearSearchForBinding(Scope* scope, const char* name) return std::nullopt; } -void registerHiddenTypes(Fixture& fixture, TypeArena& arena) +void registerHiddenTypes(Frontend* frontend) { - TypeId t = arena.addType(GenericType{"T"}); + TypeId t = frontend->globalTypes.addType(GenericType{"T"}); GenericTypeDefinition genericT{t}; - ScopePtr moduleScope = fixture.frontend.getGlobalScope(); - moduleScope->exportedTypeBindings["Not"] = TypeFun{{genericT}, arena.addType(NegationType{t})}; - moduleScope->exportedTypeBindings["fun"] = TypeFun{{}, fixture.builtinTypes->functionType}; - moduleScope->exportedTypeBindings["cls"] = TypeFun{{}, fixture.builtinTypes->classType}; + ScopePtr globalScope = frontend->getGlobalScope(); + globalScope->exportedTypeBindings["Not"] = TypeFun{{genericT}, frontend->globalTypes.addType(NegationType{t})}; + globalScope->exportedTypeBindings["fun"] = TypeFun{{}, frontend->builtinTypes->functionType}; + globalScope->exportedTypeBindings["cls"] = TypeFun{{}, frontend->builtinTypes->classType}; + globalScope->exportedTypeBindings["err"] = TypeFun{{}, frontend->builtinTypes->errorType}; +} + +void createSomeClasses(Frontend* frontend) +{ + TypeArena& arena = frontend->globalTypes; + unfreeze(arena); + + ScopePtr moduleScope = frontend->getGlobalScope(); + + TypeId parentType = arena.addType(ClassType{"Parent", {}, frontend->builtinTypes->classType, std::nullopt, {}, nullptr, "Test"}); + + ClassType* parentClass = getMutable(parentType); + parentClass->props["method"] = {makeFunction(arena, parentType, {}, {})}; + + parentClass->props["virtual_method"] = {makeFunction(arena, parentType, {}, {})}; + + addGlobalBinding(*frontend, "Parent", {parentType}); + moduleScope->exportedTypeBindings["Parent"] = TypeFun{{}, parentType}; + + TypeId childType = arena.addType(ClassType{"Child", {}, parentType, std::nullopt, {}, nullptr, "Test"}); + + ClassType* childClass = getMutable(childType); + childClass->props["virtual_method"] = {makeFunction(arena, childType, {}, {})}; + + addGlobalBinding(*frontend, "Child", {childType}); + moduleScope->exportedTypeBindings["Child"] = TypeFun{{}, childType}; + + TypeId unrelatedType = arena.addType(ClassType{"Unrelated", {}, frontend->builtinTypes->classType, std::nullopt, {}, nullptr, "Test"}); + + addGlobalBinding(*frontend, "Unrelated", {unrelatedType}); + moduleScope->exportedTypeBindings["Unrelated"] = TypeFun{{}, unrelatedType}; + + for (const auto& [name, ty] : moduleScope->exportedTypeBindings) + persist(ty.type); + + freeze(arena); } void dump(const std::vector& constraints) diff --git a/tests/Fixture.h b/tests/Fixture.h index 3edd6b4c1..6dc8abf2d 100644 --- a/tests/Fixture.h +++ b/tests/Fixture.h @@ -91,6 +91,7 @@ struct Fixture std::optional lookupType(const std::string& name); std::optional lookupImportedType(const std::string& moduleAlias, const std::string& name); + TypeId requireTypeAlias(const std::string& name); ScopedFastFlag sff_DebugLuauFreezeArena; ScopedFastFlag sff_UnknownNever{"LuauUnknownAndNeverType", true}; @@ -151,7 +152,8 @@ std::optional lookupName(ScopePtr scope, const std::string& name); // Wa std::optional linearSearchForBinding(Scope* scope, const char* name); -void registerHiddenTypes(Fixture& fixture, TypeArena& arena); +void registerHiddenTypes(Frontend* frontend); +void createSomeClasses(Frontend* frontend); } // namespace Luau diff --git a/tests/Frontend.test.cpp b/tests/Frontend.test.cpp index 93df5605e..a69965e04 100644 --- a/tests/Frontend.test.cpp +++ b/tests/Frontend.test.cpp @@ -10,6 +10,8 @@ #include +LUAU_FASTFLAG(LuauScopelessModule) + using namespace Luau; namespace @@ -143,6 +145,8 @@ TEST_CASE_FIXTURE(FrontendFixture, "real_source") TEST_CASE_FIXTURE(FrontendFixture, "automatically_check_dependent_scripts") { + ScopedFastFlag luauScopelessModule{"LuauScopelessModule", true}; + fileResolver.source["game/Gui/Modules/A"] = "return {hello=5, world=true}"; fileResolver.source["game/Gui/Modules/B"] = R"( local Modules = game:GetService('Gui').Modules @@ -157,7 +161,7 @@ TEST_CASE_FIXTURE(FrontendFixture, "automatically_check_dependent_scripts") CHECK(bModule->errors.empty()); Luau::dumpErrors(bModule); - auto bExports = first(bModule->getModuleScope()->returnType); + auto bExports = first(bModule->returnType); REQUIRE(!!bExports); CHECK_EQ("{| b_value: number |}", toString(*bExports)); @@ -220,6 +224,8 @@ TEST_CASE_FIXTURE(FrontendFixture, "any_annotation_breaks_cycle") TEST_CASE_FIXTURE(FrontendFixture, "nocheck_modules_are_typed") { + ScopedFastFlag luauScopelessModule{"LuauScopelessModule", true}; + fileResolver.source["game/Gui/Modules/A"] = R"( --!nocheck export type Foo = number @@ -243,13 +249,13 @@ TEST_CASE_FIXTURE(FrontendFixture, "nocheck_modules_are_typed") ModulePtr aModule = frontend.moduleResolver.modules["game/Gui/Modules/A"]; REQUIRE(bool(aModule)); - std::optional aExports = first(aModule->getModuleScope()->returnType); + std::optional aExports = first(aModule->returnType); REQUIRE(bool(aExports)); ModulePtr bModule = frontend.moduleResolver.modules["game/Gui/Modules/B"]; REQUIRE(bool(bModule)); - std::optional bExports = first(bModule->getModuleScope()->returnType); + std::optional bExports = first(bModule->returnType); REQUIRE(bool(bExports)); CHECK_EQ(toString(*aExports), toString(*bExports)); @@ -275,6 +281,8 @@ TEST_CASE_FIXTURE(FrontendFixture, "cycle_detection_between_check_and_nocheck") TEST_CASE_FIXTURE(FrontendFixture, "nocheck_cycle_used_by_checked") { + ScopedFastFlag luauScopelessModule{"LuauScopelessModule", true}; + fileResolver.source["game/Gui/Modules/A"] = R"( --!nocheck local Modules = game:GetService('Gui').Modules @@ -300,7 +308,7 @@ TEST_CASE_FIXTURE(FrontendFixture, "nocheck_cycle_used_by_checked") ModulePtr cModule = frontend.moduleResolver.modules["game/Gui/Modules/C"]; REQUIRE(bool(cModule)); - std::optional cExports = first(cModule->getModuleScope()->returnType); + std::optional cExports = first(cModule->returnType); REQUIRE(bool(cExports)); CHECK_EQ("{| a: any, b: any |}", toString(*cExports)); } @@ -493,6 +501,8 @@ TEST_CASE_FIXTURE(FrontendFixture, "dont_recheck_script_that_hasnt_been_marked_d TEST_CASE_FIXTURE(FrontendFixture, "recheck_if_dependent_script_is_dirty") { + ScopedFastFlag luauScopelessModule{"LuauScopelessModule", true}; + fileResolver.source["game/Gui/Modules/A"] = "return {hello=5, world=true}"; fileResolver.source["game/Gui/Modules/B"] = R"( local Modules = game:GetService('Gui').Modules @@ -511,7 +521,7 @@ TEST_CASE_FIXTURE(FrontendFixture, "recheck_if_dependent_script_is_dirty") CHECK(bModule->errors.empty()); Luau::dumpErrors(bModule); - auto bExports = first(bModule->getModuleScope()->returnType); + auto bExports = first(bModule->returnType); REQUIRE(!!bExports); CHECK_EQ("{| b_value: string |}", toString(*bExports)); diff --git a/tests/Module.test.cpp b/tests/Module.test.cpp index 5f97fb6cd..34c2e8fd8 100644 --- a/tests/Module.test.cpp +++ b/tests/Module.test.cpp @@ -112,6 +112,8 @@ TEST_CASE_FIXTURE(Fixture, "deepClone_cyclic_table") TEST_CASE_FIXTURE(BuiltinsFixture, "builtin_types_point_into_globalTypes_arena") { + ScopedFastFlag luauScopelessModule{"LuauScopelessModule", true}; + CheckResult result = check(R"( return {sign=math.sign} )"); @@ -119,7 +121,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "builtin_types_point_into_globalTypes_arena") LUAU_REQUIRE_NO_ERRORS(result); ModulePtr module = frontend.moduleResolver.getModule("MainModule"); - std::optional exports = first(module->getModuleScope()->returnType); + std::optional exports = first(module->returnType); REQUIRE(bool(exports)); REQUIRE(isInArena(*exports, module->interfaceTypes)); @@ -283,6 +285,8 @@ TEST_CASE_FIXTURE(Fixture, "clone_recursion_limit") TEST_CASE_FIXTURE(Fixture, "any_persistance_does_not_leak") { + ScopedFastFlag luauScopelessModule{"LuauScopelessModule", true}; + fileResolver.source["Module/A"] = R"( export type A = B type B = A @@ -294,8 +298,8 @@ type B = A LUAU_REQUIRE_ERRORS(result); auto mod = frontend.moduleResolver.getModule("Module/A"); - auto it = mod->getModuleScope()->exportedTypeBindings.find("A"); - REQUIRE(it != mod->getModuleScope()->exportedTypeBindings.end()); + auto it = mod->exportedTypeBindings.find("A"); + REQUIRE(it != mod->exportedTypeBindings.end()); CHECK(toString(it->second.type) == "any"); } @@ -306,6 +310,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "do_not_clone_reexports") {"LuauSubstitutionReentrant", true}, {"LuauClassTypeVarsInSubstitution", true}, {"LuauSubstitutionFixMissingFields", true}, + {"LuauScopelessModule", true}, }; fileResolver.source["Module/A"] = R"( @@ -326,10 +331,10 @@ return {} ModulePtr modB = frontend.moduleResolver.getModule("Module/B"); REQUIRE(modA); REQUIRE(modB); - auto modAiter = modA->getModuleScope()->exportedTypeBindings.find("A"); - auto modBiter = modB->getModuleScope()->exportedTypeBindings.find("B"); - REQUIRE(modAiter != modA->getModuleScope()->exportedTypeBindings.end()); - REQUIRE(modBiter != modB->getModuleScope()->exportedTypeBindings.end()); + auto modAiter = modA->exportedTypeBindings.find("A"); + auto modBiter = modB->exportedTypeBindings.find("B"); + REQUIRE(modAiter != modA->exportedTypeBindings.end()); + REQUIRE(modBiter != modB->exportedTypeBindings.end()); TypeId typeA = modAiter->second.type; TypeId typeB = modBiter->second.type; TableType* tableB = getMutable(typeB); @@ -344,6 +349,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "do_not_clone_types_of_reexported_values") {"LuauSubstitutionReentrant", true}, {"LuauClassTypeVarsInSubstitution", true}, {"LuauSubstitutionFixMissingFields", true}, + {"LuauScopelessModule", true}, }; fileResolver.source["Module/A"] = R"( @@ -364,8 +370,8 @@ return exports ModulePtr modB = frontend.moduleResolver.getModule("Module/B"); REQUIRE(modA); REQUIRE(modB); - std::optional typeA = first(modA->getModuleScope()->returnType); - std::optional typeB = first(modB->getModuleScope()->returnType); + std::optional typeA = first(modA->returnType); + std::optional typeB = first(modB->returnType); REQUIRE(typeA); REQUIRE(typeB); TableType* tableA = getMutable(*typeA); diff --git a/tests/NonstrictMode.test.cpp b/tests/NonstrictMode.test.cpp index 8a25a5e59..5deeb35dc 100644 --- a/tests/NonstrictMode.test.cpp +++ b/tests/NonstrictMode.test.cpp @@ -253,6 +253,8 @@ TEST_CASE_FIXTURE(Fixture, "delay_function_does_not_require_its_argument_to_retu TEST_CASE_FIXTURE(Fixture, "inconsistent_module_return_types_are_ok") { + ScopedFastFlag luauScopelessModule{"LuauScopelessModule", true}; + CheckResult result = check(R"( --!nonstrict @@ -269,7 +271,7 @@ TEST_CASE_FIXTURE(Fixture, "inconsistent_module_return_types_are_ok") LUAU_REQUIRE_NO_ERRORS(result); - REQUIRE_EQ("any", toString(getMainModule()->getModuleScope()->returnType)); + REQUIRE_EQ("any", toString(getMainModule()->returnType)); } TEST_CASE_FIXTURE(Fixture, "returning_insufficient_return_values") diff --git a/tests/Normalize.test.cpp b/tests/Normalize.test.cpp index ba9f5c525..615fc997c 100644 --- a/tests/Normalize.test.cpp +++ b/tests/Normalize.test.cpp @@ -17,46 +17,17 @@ struct IsSubtypeFixture : Fixture { bool isSubtype(TypeId a, TypeId b) { - return ::Luau::isSubtype(a, b, NotNull{getMainModule()->getModuleScope().get()}, builtinTypes, ice); + ModulePtr module = getMainModule(); + REQUIRE(module); + + if (!module->hasModuleScope()) + FAIL("isSubtype: module scope data is not available"); + + return ::Luau::isSubtype(a, b, NotNull{module->getModuleScope().get()}, builtinTypes, ice); } }; } // namespace -void createSomeClasses(Frontend& frontend) -{ - auto& arena = frontend.globalTypes; - - unfreeze(arena); - - TypeId parentType = arena.addType(ClassType{"Parent", {}, frontend.builtinTypes->classType, std::nullopt, {}, nullptr, "Test"}); - - ClassType* parentClass = getMutable(parentType); - parentClass->props["method"] = {makeFunction(arena, parentType, {}, {})}; - - parentClass->props["virtual_method"] = {makeFunction(arena, parentType, {}, {})}; - - addGlobalBinding(frontend, "Parent", {parentType}); - frontend.getGlobalScope()->exportedTypeBindings["Parent"] = TypeFun{{}, parentType}; - - TypeId childType = arena.addType(ClassType{"Child", {}, parentType, std::nullopt, {}, nullptr, "Test"}); - - ClassType* childClass = getMutable(childType); - childClass->props["virtual_method"] = {makeFunction(arena, childType, {}, {})}; - - addGlobalBinding(frontend, "Child", {childType}); - frontend.getGlobalScope()->exportedTypeBindings["Child"] = TypeFun{{}, childType}; - - TypeId unrelatedType = arena.addType(ClassType{"Unrelated", {}, frontend.builtinTypes->classType, std::nullopt, {}, nullptr, "Test"}); - - addGlobalBinding(frontend, "Unrelated", {unrelatedType}); - frontend.getGlobalScope()->exportedTypeBindings["Unrelated"] = TypeFun{{}, unrelatedType}; - - for (const auto& [name, ty] : frontend.getGlobalScope()->exportedTypeBindings) - persist(ty.type); - - freeze(arena); -} - TEST_SUITE_BEGIN("isSubtype"); TEST_CASE_FIXTURE(IsSubtypeFixture, "primitives") @@ -352,7 +323,7 @@ TEST_CASE_FIXTURE(IsSubtypeFixture, "cyclic_table") TEST_CASE_FIXTURE(IsSubtypeFixture, "classes") { - createSomeClasses(frontend); + createSomeClasses(&frontend); check(""); // Ensure that we have a main Module. @@ -403,11 +374,12 @@ struct NormalizeFixture : Fixture NormalizeFixture() { - registerHiddenTypes(*this, arena); + registerHiddenTypes(&frontend); } const NormalizedType* toNormalizedType(const std::string& annotation) { + normalizer.clearCaches(); CheckResult result = check("type _Res = " + annotation); LUAU_REQUIRE_NO_ERRORS(result); std::optional ty = lookupType("_Res"); @@ -588,7 +560,7 @@ TEST_CASE_FIXTURE(Fixture, "cyclic_table_normalizes_sensibly") TEST_CASE_FIXTURE(BuiltinsFixture, "skip_force_normal_on_external_types") { - createSomeClasses(frontend); + createSomeClasses(&frontend); CheckResult result = check(R"( export type t0 = { a: Child } @@ -611,7 +583,7 @@ TEST_CASE_FIXTURE(NormalizeFixture, "unions_of_classes") { ScopedFastFlag sff{"LuauNegatedClassTypes", true}; - createSomeClasses(frontend); + createSomeClasses(&frontend); CHECK("Parent | Unrelated" == toString(normal("Parent | Unrelated"))); CHECK("Parent" == toString(normal("Parent | Child"))); CHECK("Parent | Unrelated" == toString(normal("Parent | Child | Unrelated"))); @@ -621,7 +593,7 @@ TEST_CASE_FIXTURE(NormalizeFixture, "intersections_of_classes") { ScopedFastFlag sff{"LuauNegatedClassTypes", true}; - createSomeClasses(frontend); + createSomeClasses(&frontend); CHECK("Child" == toString(normal("Parent & Child"))); CHECK("never" == toString(normal("Child & Unrelated"))); } @@ -630,7 +602,7 @@ TEST_CASE_FIXTURE(NormalizeFixture, "narrow_union_of_classes_with_intersection") { ScopedFastFlag sff{"LuauNegatedClassTypes", true}; - createSomeClasses(frontend); + createSomeClasses(&frontend); CHECK("Child" == toString(normal("(Child | Unrelated) & Child"))); } @@ -638,7 +610,7 @@ TEST_CASE_FIXTURE(NormalizeFixture, "negations_of_classes") { ScopedFastFlag sff{"LuauNegatedClassTypes", true}; - createSomeClasses(frontend); + createSomeClasses(&frontend); CHECK("(Parent & ~Child) | Unrelated" == toString(normal("(Parent & Not) | Unrelated"))); CHECK("((class & ~Child) | boolean | function | number | string | thread)?" == toString(normal("Not"))); CHECK("Child" == toString(normal("Not & Child"))); @@ -652,7 +624,7 @@ TEST_CASE_FIXTURE(NormalizeFixture, "classes_and_unknown") { ScopedFastFlag sff{"LuauNegatedClassTypes", true}; - createSomeClasses(frontend); + createSomeClasses(&frontend); CHECK("Parent" == toString(normal("Parent & unknown"))); } @@ -660,7 +632,7 @@ TEST_CASE_FIXTURE(NormalizeFixture, "classes_and_never") { ScopedFastFlag sff{"LuauNegatedClassTypes", true}; - createSomeClasses(frontend); + createSomeClasses(&frontend); CHECK("never" == toString(normal("Parent & never"))); } diff --git a/tests/TypeInfer.aliases.test.cpp b/tests/TypeInfer.aliases.test.cpp index 4dd822690..9b65cf248 100644 --- a/tests/TypeInfer.aliases.test.cpp +++ b/tests/TypeInfer.aliases.test.cpp @@ -9,7 +9,6 @@ using namespace Luau; LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) LUAU_FASTFLAG(LuauTypeMismatchInvarianceInError) -LUAU_FASTFLAG(LuauNewLibraryTypeNames) TEST_SUITE_BEGIN("TypeAliases"); @@ -506,19 +505,14 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "general_require_multi_assign") CheckResult result = frontend.check("workspace/C"); LUAU_REQUIRE_NO_ERRORS(result); - ModulePtr m = frontend.moduleResolver.modules["workspace/C"]; - REQUIRE(m != nullptr); - - std::optional aTypeId = lookupName(m->getModuleScope(), "a"); - REQUIRE(aTypeId); - const Luau::TableType* aType = get(follow(*aTypeId)); + TypeId aTypeId = requireType("workspace/C", "a"); + const Luau::TableType* aType = get(follow(aTypeId)); REQUIRE(aType); REQUIRE(aType->props.size() == 2); - std::optional bTypeId = lookupName(m->getModuleScope(), "b"); - REQUIRE(bTypeId); - const Luau::TableType* bType = get(follow(*bTypeId)); + TypeId bTypeId = requireType("workspace/C", "b"); + const Luau::TableType* bType = get(follow(bTypeId)); REQUIRE(bType); REQUIRE(bType->props.size() == 3); } @@ -530,10 +524,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "type_alias_import_mutation") TypeId ty = getGlobalBinding(frontend, "table"); - if (FFlag::LuauNewLibraryTypeNames) - CHECK(toString(ty) == "typeof(table)"); - else - CHECK(toString(ty) == "table"); + CHECK(toString(ty) == "typeof(table)"); const TableType* ttv = get(ty); REQUIRE(ttv); diff --git a/tests/TypeInfer.annotations.test.cpp b/tests/TypeInfer.annotations.test.cpp index bf66ecbc9..3e98367cd 100644 --- a/tests/TypeInfer.annotations.test.cpp +++ b/tests/TypeInfer.annotations.test.cpp @@ -319,10 +319,10 @@ TEST_CASE_FIXTURE(Fixture, "self_referential_type_alias") LUAU_REQUIRE_NO_ERRORS(result); - std::optional res = getMainModule()->getModuleScope()->lookupType("O"); + std::optional res = lookupType("O"); REQUIRE(res); - TypeId oType = follow(res->type); + TypeId oType = follow(*res); const TableType* oTable = get(oType); REQUIRE(oTable); @@ -347,6 +347,8 @@ TEST_CASE_FIXTURE(Fixture, "define_generic_type_alias") LUAU_REQUIRE_NO_ERRORS(result); ModulePtr mainModule = getMainModule(); + REQUIRE(mainModule); + REQUIRE(mainModule->hasModuleScope()); auto it = mainModule->getModuleScope()->privateTypeBindings.find("Array"); REQUIRE(it != mainModule->getModuleScope()->privateTypeBindings.end()); @@ -463,6 +465,8 @@ TEST_CASE_FIXTURE(Fixture, "type_alias_always_resolve_to_a_real_type") TEST_CASE_FIXTURE(Fixture, "interface_types_belong_to_interface_arena") { + ScopedFastFlag luauScopelessModule{"LuauScopelessModule", true}; + CheckResult result = check(R"( export type A = {field: number} @@ -475,12 +479,12 @@ TEST_CASE_FIXTURE(Fixture, "interface_types_belong_to_interface_arena") Module& mod = *getMainModule(); - const TypeFun& a = mod.getModuleScope()->exportedTypeBindings["A"]; + const TypeFun& a = mod.exportedTypeBindings["A"]; CHECK(isInArena(a.type, mod.interfaceTypes)); CHECK(!isInArena(a.type, typeChecker.globalTypes)); - std::optional exportsType = first(mod.getModuleScope()->returnType); + std::optional exportsType = first(mod.returnType); REQUIRE(exportsType); TableType* exportsTable = getMutable(*exportsType); @@ -494,6 +498,8 @@ TEST_CASE_FIXTURE(Fixture, "interface_types_belong_to_interface_arena") TEST_CASE_FIXTURE(Fixture, "generic_aliases_are_cloned_properly") { + ScopedFastFlag luauScopelessModule{"LuauScopelessModule", true}; + CheckResult result = check(R"( export type Array = { [number]: T } )"); @@ -501,7 +507,7 @@ TEST_CASE_FIXTURE(Fixture, "generic_aliases_are_cloned_properly") dumpErrors(result); Module& mod = *getMainModule(); - const auto& typeBindings = mod.getModuleScope()->exportedTypeBindings; + const auto& typeBindings = mod.exportedTypeBindings; auto it = typeBindings.find("Array"); REQUIRE(typeBindings.end() != it); @@ -521,6 +527,8 @@ TEST_CASE_FIXTURE(Fixture, "generic_aliases_are_cloned_properly") TEST_CASE_FIXTURE(Fixture, "cloned_interface_maintains_pointers_between_definitions") { + ScopedFastFlag luauScopelessModule{"LuauScopelessModule", true}; + CheckResult result = check(R"( export type Record = { name: string, location: string } local a: Record = { name="Waldo", location="?????" } @@ -533,9 +541,9 @@ TEST_CASE_FIXTURE(Fixture, "cloned_interface_maintains_pointers_between_definiti Module& mod = *getMainModule(); - TypeId recordType = mod.getModuleScope()->exportedTypeBindings["Record"].type; + TypeId recordType = mod.exportedTypeBindings["Record"].type; - std::optional exportsType = first(mod.getModuleScope()->returnType); + std::optional exportsType = first(mod.returnType); REQUIRE(exportsType); TableType* exportsTable = getMutable(*exportsType); diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index a97cea21c..70de13d15 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -109,6 +109,8 @@ TEST_CASE_FIXTURE(Fixture, "vararg_functions_should_allow_calls_of_any_types_and TEST_CASE_FIXTURE(BuiltinsFixture, "vararg_function_is_quantified") { + ScopedFastFlag luauScopelessModule{"LuauScopelessModule", true}; + CheckResult result = check(R"( local T = {} function T.f(...) @@ -129,7 +131,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "vararg_function_is_quantified") LUAU_REQUIRE_NO_ERRORS(result); - auto r = first(getMainModule()->getModuleScope()->returnType); + auto r = first(getMainModule()->returnType); REQUIRE(r); TableType* ttv = getMutable(*r); @@ -1772,7 +1774,7 @@ z = y -- Not OK, so the line is colorable TEST_CASE_FIXTURE(Fixture, "function_is_supertype_of_concrete_functions") { ScopedFastFlag sff{"LuauNegatedFunctionTypes", true}; - registerHiddenTypes(*this, frontend.globalTypes); + registerHiddenTypes(&frontend); CheckResult result = check(R"( function foo(f: fun) end @@ -1791,7 +1793,7 @@ TEST_CASE_FIXTURE(Fixture, "function_is_supertype_of_concrete_functions") TEST_CASE_FIXTURE(Fixture, "concrete_functions_are_not_supertypes_of_function") { ScopedFastFlag sff{"LuauNegatedFunctionTypes", true}; - registerHiddenTypes(*this, frontend.globalTypes); + registerHiddenTypes(&frontend); CheckResult result = check(R"( local a: fun = function() end @@ -1812,7 +1814,7 @@ TEST_CASE_FIXTURE(Fixture, "concrete_functions_are_not_supertypes_of_function") TEST_CASE_FIXTURE(Fixture, "other_things_are_not_related_to_function") { ScopedFastFlag sff{"LuauNegatedFunctionTypes", true}; - registerHiddenTypes(*this, frontend.globalTypes); + registerHiddenTypes(&frontend); CheckResult result = check(R"( local a: fun = function() end diff --git a/tests/TypeInfer.generics.test.cpp b/tests/TypeInfer.generics.test.cpp index 7b4176211..3861a8b6c 100644 --- a/tests/TypeInfer.generics.test.cpp +++ b/tests/TypeInfer.generics.test.cpp @@ -1021,9 +1021,9 @@ TEST_CASE_FIXTURE(Fixture, "no_stack_overflow_from_quantifying") LUAU_REQUIRE_ERRORS(result); - std::optional t0 = getMainModule()->getModuleScope()->lookupType("t0"); + std::optional t0 = lookupType("t0"); REQUIRE(t0); - CHECK_EQ("*error-type*", toString(t0->type)); + CHECK_EQ("*error-type*", toString(*t0)); auto it = std::find_if(result.errors.begin(), result.errors.end(), [](TypeError& err) { return get(err); diff --git a/tests/TypeInfer.negations.test.cpp b/tests/TypeInfer.negations.test.cpp index 02350a728..261314a64 100644 --- a/tests/TypeInfer.negations.test.cpp +++ b/tests/TypeInfer.negations.test.cpp @@ -20,7 +20,7 @@ struct NegationFixture : Fixture NegationFixture() { - registerHiddenTypes(*this, arena); + registerHiddenTypes(&frontend); } }; diff --git a/tests/TypeInfer.operators.test.cpp b/tests/TypeInfer.operators.test.cpp index 0196666a0..5db5b880b 100644 --- a/tests/TypeInfer.operators.test.cpp +++ b/tests/TypeInfer.operators.test.cpp @@ -405,17 +405,41 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "compound_assign_metatable") type V2B = { x: number, y: number } local v2b: V2B = { x = 0, y = 0 } local VMT = {} - type V2 = typeof(setmetatable(v2b, VMT)) - function VMT.__add(a: V2, b: V2): V2 + VMT.__add = function(a: V2, b: V2): V2 return setmetatable({ x = a.x + b.x, y = a.y + b.y }, VMT) end + type V2 = typeof(setmetatable(v2b, VMT)) + local v1: V2 = setmetatable({ x = 1, y = 2 }, VMT) local v2: V2 = setmetatable({ x = 3, y = 4 }, VMT) v1 += v2 )"); - CHECK_EQ(0, result.errors.size()); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "compound_assign_result_must_be_compatible_with_var") +{ + CheckResult result = check(R"( + function __add(left, right) + return 123 + end + + local mt = { + __add = __add, + } + + local x = setmetatable({}, mt) + local v: number + + v += x -- okay: number + x -> number + x += v -- not okay: x numberType}}); } TEST_CASE_FIXTURE(BuiltinsFixture, "compound_assign_mismatch_metatable") @@ -1015,11 +1039,11 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "mm_ops_must_return_a_value") local y = x + 123 )"); - LUAU_REQUIRE_ERROR_COUNT(1, result); + LUAU_REQUIRE_ERROR_COUNT(2, result); CHECK(requireType("y") == builtinTypes->errorRecoveryType()); - const GenericError* ge = get(result.errors[0]); + const GenericError* ge = get(result.errors[1]); REQUIRE(ge); CHECK(ge->message == "Metamethod '__add' must return a value"); } @@ -1049,13 +1073,13 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "mm_comparisons_must_return_a_boolean") local v2 = o2 < o2 )"); - LUAU_REQUIRE_ERROR_COUNT(2, result); + LUAU_REQUIRE_ERROR_COUNT(4, result); CHECK(requireType("v1") == builtinTypes->booleanType); CHECK(requireType("v2") == builtinTypes->booleanType); - CHECK(toString(result.errors[0]) == "Metamethod '__lt' must return type 'boolean'"); - CHECK(toString(result.errors[1]) == "Metamethod '__lt' must return type 'boolean'"); + CHECK(toString(result.errors[1]) == "Metamethod '__lt' must return a boolean"); + CHECK(toString(result.errors[3]) == "Metamethod '__lt' must return a boolean"); } TEST_CASE_FIXTURE(BuiltinsFixture, "reworked_and") diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index cf969f2d7..3e278ca21 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -516,6 +516,8 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_loop_with_zero_iterators") // Ideally, we would not try to export a function type with generic types from incorrect scope TEST_CASE_FIXTURE(BuiltinsFixture, "generic_type_leak_to_module_interface") { + ScopedFastFlag luauScopelessModule{"LuauScopelessModule", true}; + fileResolver.source["game/A"] = R"( local wrapStrictTable @@ -548,13 +550,15 @@ return wrapStrictTable(Constants, "Constants") ModulePtr m = frontend.moduleResolver.modules["game/B"]; REQUIRE(m); - std::optional result = first(m->getModuleScope()->returnType); + std::optional result = first(m->returnType); REQUIRE(result); CHECK(get(*result)); } TEST_CASE_FIXTURE(BuiltinsFixture, "generic_type_leak_to_module_interface_variadic") { + ScopedFastFlag luauScopelessModule{"LuauScopelessModule", true}; + fileResolver.source["game/A"] = R"( local wrapStrictTable @@ -587,7 +591,7 @@ return wrapStrictTable(Constants, "Constants") ModulePtr m = frontend.moduleResolver.modules["game/B"]; REQUIRE(m); - std::optional result = first(m->getModuleScope()->returnType); + std::optional result = first(m->returnType); REQUIRE(result); CHECK(get(*result)); } @@ -620,7 +624,13 @@ struct IsSubtypeFixture : Fixture { bool isSubtype(TypeId a, TypeId b) { - return ::Luau::isSubtype(a, b, NotNull{getMainModule()->getModuleScope().get()}, builtinTypes, ice); + ModulePtr module = getMainModule(); + REQUIRE(module); + + if (!module->hasModuleScope()) + FAIL("isSubtype: module scope data is not available"); + + return ::Luau::isSubtype(a, b, NotNull{module->getModuleScope().get()}, builtinTypes, ice); } }; } // namespace diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index e2aa01f95..dc3b7ceb7 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -18,7 +18,6 @@ LUAU_FASTFLAG(LuauLowerBoundsCalculation); LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); LUAU_FASTFLAG(LuauInstantiateInSubtyping) LUAU_FASTFLAG(LuauTypeMismatchInvarianceInError) -LUAU_FASTFLAG(LuauNewLibraryTypeNames) TEST_SUITE_BEGIN("TableTests"); @@ -1730,16 +1729,8 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "builtin_table_names") LUAU_REQUIRE_ERROR_COUNT(2, result); - if (FFlag::LuauNewLibraryTypeNames) - { - CHECK_EQ("Cannot add property 'h' to table 'typeof(os)'", toString(result.errors[0])); - CHECK_EQ("Cannot add property 'k' to table 'typeof(string)'", toString(result.errors[1])); - } - else - { - CHECK_EQ("Cannot add property 'h' to table 'os'", toString(result.errors[0])); - CHECK_EQ("Cannot add property 'k' to table 'string'", toString(result.errors[1])); - } + CHECK_EQ("Cannot add property 'h' to table 'typeof(os)'", toString(result.errors[0])); + CHECK_EQ("Cannot add property 'k' to table 'typeof(string)'", toString(result.errors[1])); } TEST_CASE_FIXTURE(BuiltinsFixture, "persistent_sealed_table_is_immutable") @@ -1750,10 +1741,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "persistent_sealed_table_is_immutable") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - if (FFlag::LuauNewLibraryTypeNames) - CHECK_EQ("Cannot add property 'bad' to table 'typeof(os)'", toString(result.errors[0])); - else - CHECK_EQ("Cannot add property 'bad' to table 'os'", toString(result.errors[0])); + CHECK_EQ("Cannot add property 'bad' to table 'typeof(os)'", toString(result.errors[0])); const TableType* osType = get(requireType("os")); REQUIRE(osType != nullptr); @@ -2967,6 +2955,8 @@ TEST_CASE_FIXTURE(Fixture, "inferred_properties_of_a_table_should_start_with_the // The real bug here was that we weren't always uncondionally typechecking a trailing return statement last. TEST_CASE_FIXTURE(BuiltinsFixture, "dont_leak_free_table_props") { + ScopedFastFlag luauScopelessModule{"LuauScopelessModule", true}; + CheckResult result = check(R"( local function a(state) print(state.blah) @@ -2988,7 +2978,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "dont_leak_free_table_props") CHECK_EQ("({+ blah: a +}) -> ()", toString(requireType("a"))); CHECK_EQ("({+ gwar: a +}) -> ()", toString(requireType("b"))); - CHECK_EQ("() -> ({+ blah: a, gwar: b +}) -> ()", toString(getMainModule()->getModuleScope()->returnType)); + CHECK_EQ("() -> ({+ blah: a, gwar: b +}) -> ()", toString(getMainModule()->returnType)); } TEST_CASE_FIXTURE(Fixture, "inferred_return_type_of_free_table") @@ -3230,8 +3220,6 @@ TEST_CASE_FIXTURE(Fixture, "scalar_is_a_subtype_of_a_compatible_polymorphic_shap TEST_CASE_FIXTURE(Fixture, "scalar_is_not_a_subtype_of_a_compatible_polymorphic_shape_type") { ScopedFastFlag sff{"LuauScalarShapeSubtyping", true}; - if (!FFlag::LuauNewLibraryTypeNames) - return; CheckResult result = check(R"( local function f(s) @@ -3280,8 +3268,6 @@ TEST_CASE_FIXTURE(Fixture, "a_free_shape_can_turn_into_a_scalar_if_it_is_compati TEST_CASE_FIXTURE(Fixture, "a_free_shape_cannot_turn_into_a_scalar_if_it_is_not_compatible") { ScopedFastFlag sff{"LuauScalarShapeSubtyping", true}; - if (!FFlag::LuauNewLibraryTypeNames) - return; CheckResult result = check(R"( local function f(s): string diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index f6279fa2c..f4b84262c 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -648,10 +648,10 @@ TEST_CASE_FIXTURE(Fixture, "no_stack_overflow_from_isoptional") LUAU_REQUIRE_ERRORS(result); - std::optional t0 = getMainModule()->getModuleScope()->lookupType("t0"); + std::optional t0 = lookupType("t0"); REQUIRE(t0); - CHECK_EQ("*error-type*", toString(t0->type)); + CHECK_EQ("*error-type*", toString(*t0)); auto it = std::find_if(result.errors.begin(), result.errors.end(), [](TypeError& err) { return get(err); diff --git a/tests/TypeInfer.typePacks.cpp b/tests/TypeInfer.typePacks.cpp index b753d30e9..94448cfa5 100644 --- a/tests/TypeInfer.typePacks.cpp +++ b/tests/TypeInfer.typePacks.cpp @@ -428,8 +428,12 @@ type E = X<(number, ...string)> LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(toString(*lookupType("D")), "(...number) -> (string, ...number)"); - CHECK_EQ(toString(*lookupType("E")), "(number, ...string) -> (string, number, ...string)"); + auto d = lookupType("D"); + REQUIRE(d); + auto e = lookupType("E"); + REQUIRE(e); + CHECK_EQ(toString(*d), "(...number) -> (string, ...number)"); + CHECK_EQ(toString(*e), "(number, ...string) -> (string, number, ...string)"); } TEST_CASE_FIXTURE(Fixture, "type_alias_type_pack_multi") @@ -887,9 +891,13 @@ TEST_CASE_FIXTURE(Fixture, "unifying_vararg_pack_with_fixed_length_pack_produces LUAU_REQUIRE_NO_ERRORS(result); - REQUIRE(bool(getMainModule()->getModuleScope()->varargPack)); + ModulePtr mainModule = getMainModule(); + REQUIRE(mainModule); + REQUIRE(mainModule->hasModuleScope()); - TypePackId varargPack = *getMainModule()->getModuleScope()->varargPack; + REQUIRE(bool(mainModule->getModuleScope()->varargPack)); + + TypePackId varargPack = *mainModule->getModuleScope()->varargPack; auto iter = begin(varargPack); auto endIter = end(varargPack); diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp index d30220953..8831bb2ea 100644 --- a/tests/TypeInfer.unionTypes.test.cpp +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -397,8 +397,6 @@ local e = a.z TEST_CASE_FIXTURE(Fixture, "optional_iteration") { - ScopedFastFlag luauNilIterator{"LuauNilIterator", true}; - CheckResult result = check(R"( function foo(values: {number}?) local s = 0 diff --git a/tests/TypeReduction.test.cpp b/tests/TypeReduction.test.cpp new file mode 100644 index 000000000..c629b3e34 --- /dev/null +++ b/tests/TypeReduction.test.cpp @@ -0,0 +1,1249 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/TypeReduction.h" + +#include "Fixture.h" +#include "doctest.h" + +using namespace Luau; + +namespace +{ +struct ReductionFixture : Fixture +{ + TypeArena arena; + InternalErrorReporter iceHandler; + UnifierSharedState unifierState{&iceHandler}; + TypeReduction reduction{NotNull{&arena}, builtinTypes, NotNull{&iceHandler}}; + + ReductionFixture() + { + registerHiddenTypes(&frontend); + createSomeClasses(&frontend); + } + + TypeId reductionof(TypeId ty) + { + std::optional reducedTy = reduction.reduce(ty); + REQUIRE(reducedTy); + return *reducedTy; + } + + std::optional tryReduce(const std::string& annotation) + { + CheckResult result = check("type _Res = " + annotation); + LUAU_REQUIRE_NO_ERRORS(result); + return reduction.reduce(requireTypeAlias("_Res")); + } + + TypeId reductionof(const std::string& annotation) + { + std::optional reducedTy = tryReduce(annotation); + REQUIRE_MESSAGE(reducedTy, "Exceeded the cartesian product of the type"); + return *reducedTy; + } +}; +} // namespace + +TEST_SUITE_BEGIN("TypeReductionTests"); + +TEST_CASE_FIXTURE(ReductionFixture, "cartesian_product_exceeded") +{ + ScopedFastInt sfi{"LuauTypeReductionCartesianProductLimit", 5}; + + std::optional ty = tryReduce(R"( + string & (number | string | boolean) & (number | string | boolean) + )"); + + CHECK(!ty); +} + +TEST_CASE_FIXTURE(ReductionFixture, "cartesian_product_exceeded_with_normal_limit") +{ + std::optional ty = tryReduce(R"( + string -- 1 = 1 + & (number | string | boolean) -- 1 * 3 = 3 + & (number | string | boolean) -- 3 * 3 = 9 + & (number | string | boolean) -- 9 * 3 = 27 + & (number | string | boolean) -- 27 * 3 = 81 + & (number | string | boolean) -- 81 * 3 = 243 + & (number | string | boolean) -- 243 * 3 = 729 + & (number | string | boolean) -- 729 * 3 = 2187 + & (number | string | boolean) -- 2187 * 3 = 6561 + & (number | string | boolean) -- 6561 * 3 = 19683 + & (number | string | boolean) -- 19683 * 3 = 59049 + & (number | string) -- 59049 * 2 = 118098 + )"); + + CHECK(!ty); +} + +TEST_CASE_FIXTURE(ReductionFixture, "cartesian_product_is_zero") +{ + ScopedFastInt sfi{"LuauTypeReductionCartesianProductLimit", 5}; + + std::optional ty = tryReduce(R"( + string & (number | string | boolean) & (number | string | boolean) & never + )"); + + CHECK(ty); +} + +TEST_CASE_FIXTURE(ReductionFixture, "intersections_without_negations") +{ + SUBCASE("string_and_string") + { + TypeId ty = reductionof("string & string"); + CHECK("string" == toString(ty)); + } + + SUBCASE("never_and_string") + { + TypeId ty = reductionof("never & string"); + CHECK("never" == toString(ty)); + } + + SUBCASE("string_and_never") + { + TypeId ty = reductionof("string & never"); + CHECK("never" == toString(ty)); + } + + SUBCASE("unknown_and_string") + { + TypeId ty = reductionof("unknown & string"); + CHECK("string" == toString(ty)); + } + + SUBCASE("string_and_unknown") + { + TypeId ty = reductionof("string & unknown"); + CHECK("string" == toString(ty)); + } + + SUBCASE("any_and_string") + { + TypeId ty = reductionof("any & string"); + CHECK("string" == toString(ty)); + } + + SUBCASE("string_and_any") + { + TypeId ty = reductionof("string & any"); + CHECK("string" == toString(ty)); + } + + SUBCASE("string_or_number_and_string") + { + TypeId ty = reductionof("(string | number) & string"); + CHECK("string" == toString(ty)); + } + + SUBCASE("string_and_string_or_number") + { + TypeId ty = reductionof("string & (string | number)"); + CHECK("string" == toString(ty)); + } + + SUBCASE("string_and_a") + { + TypeId ty = reductionof(R"(string & "a")"); + CHECK(R"("a")" == toString(ty)); + } + + SUBCASE("boolean_and_true") + { + TypeId ty = reductionof("boolean & true"); + CHECK("true" == toString(ty)); + } + + SUBCASE("boolean_and_a") + { + TypeId ty = reductionof(R"(boolean & "a")"); + CHECK("never" == toString(ty)); + } + + SUBCASE("a_and_a") + { + TypeId ty = reductionof(R"("a" & "a")"); + CHECK(R"("a")" == toString(ty)); + } + + SUBCASE("a_and_b") + { + TypeId ty = reductionof(R"("a" & "b")"); + CHECK("never" == toString(ty)); + } + + SUBCASE("a_and_true") + { + TypeId ty = reductionof(R"("a" & true)"); + CHECK("never" == toString(ty)); + } + + SUBCASE("a_and_true") + { + TypeId ty = reductionof(R"(true & false)"); + CHECK("never" == toString(ty)); + } + + SUBCASE("function_type_and_function") + { + TypeId ty = reductionof("() -> () & fun"); + CHECK("() -> ()" == toString(ty)); + } + + SUBCASE("function_type_and_string") + { + TypeId ty = reductionof("() -> () & string"); + CHECK("never" == toString(ty)); + } + + SUBCASE("parent_and_child") + { + TypeId ty = reductionof("Parent & Child"); + CHECK("Child" == toString(ty)); + } + + SUBCASE("child_and_parent") + { + TypeId ty = reductionof("Child & Parent"); + CHECK("Child" == toString(ty)); + } + + SUBCASE("child_and_unrelated") + { + TypeId ty = reductionof("Child & Unrelated"); + CHECK("never" == toString(ty)); + } + + SUBCASE("string_and_table") + { + TypeId ty = reductionof("string & {}"); + CHECK("never" == toString(ty)); + } + + SUBCASE("string_and_child") + { + TypeId ty = reductionof("string & Child"); + CHECK("never" == toString(ty)); + } + + SUBCASE("string_and_function") + { + TypeId ty = reductionof("string & () -> ()"); + CHECK("never" == toString(ty)); + } + + SUBCASE("function_and_table") + { + TypeId ty = reductionof("() -> () & {}"); + CHECK("never" == toString(ty)); + } + + SUBCASE("function_and_class") + { + TypeId ty = reductionof("() -> () & Child"); + CHECK("never" == toString(ty)); + } + + SUBCASE("function_and_function") + { + TypeId ty = reductionof("() -> () & () -> ()"); + CHECK("(() -> ()) & (() -> ())" == toString(ty)); + } + + SUBCASE("table_and_table") + { + TypeId ty = reductionof("{} & {}"); + CHECK("{| |}" == toString(ty)); + } + + SUBCASE("table_and_metatable") + { + // No setmetatable in ReductionFixture, so we mix and match. + BuiltinsFixture fixture; + fixture.check(R"( + type Ty = {} & typeof(setmetatable({}, {})) + )"); + + TypeId ty = reductionof(fixture.requireTypeAlias("Ty")); + CHECK("{ @metatable { }, { } } & {| |}" == toString(ty)); + } + + SUBCASE("a_and_string") + { + TypeId ty = reductionof(R"("a" & string)"); + CHECK(R"("a")" == toString(ty)); + } + + SUBCASE("reducible_function_and_function") + { + TypeId ty = reductionof("((string | string) -> (number | number)) & fun"); + CHECK("(string) -> number" == toString(ty)); + } + + SUBCASE("string_and_error") + { + TypeId ty = reductionof("string & err"); + CHECK("*error-type* & string" == toString(ty)); + } + + SUBCASE("table_p_string_and_table_p_number") + { + TypeId ty = reductionof("{ p: string } & { p: number }"); + CHECK("never" == toString(ty)); + } + + SUBCASE("table_p_string_and_table_p_string") + { + TypeId ty = reductionof("{ p: string } & { p: string }"); + CHECK("{| p: string |}" == toString(ty)); + } + + SUBCASE("table_x_table_p_string_and_table_x_table_p_number") + { + TypeId ty = reductionof("{ x: { p: string } } & { x: { p: number } }"); + CHECK("never" == toString(ty)); + } + + SUBCASE("table_p_and_table_q") + { + TypeId ty = reductionof("{ p: string } & { q: number }"); + CHECK("{| p: string, q: number |}" == toString(ty)); + } + + SUBCASE("table_tag_a_or_table_tag_b_and_table_b") + { + TypeId ty = reductionof("({ tag: string, a: number } | { tag: number, b: string }) & { b: string }"); + CHECK("{| a: number, b: string, tag: string |} | {| b: string, tag: number |}" == toString(ty)); + } + + SUBCASE("table_string_number_indexer_and_table_string_number_indexer") + { + TypeId ty = reductionof("{ [string]: number } & { [string]: number }"); + CHECK("{| [string]: number |}" == toString(ty)); + } + + SUBCASE("table_string_number_indexer_and_empty_table") + { + TypeId ty = reductionof("{ [string]: number } & {}"); + CHECK("{| [string]: number |}" == toString(ty)); + } + + SUBCASE("empty_table_table_string_number_indexer") + { + TypeId ty = reductionof("{} & { [string]: number }"); + CHECK("{| [string]: number |}" == toString(ty)); + } + + SUBCASE("string_number_indexer_and_number_number_indexer") + { + TypeId ty = reductionof("{ [string]: number } & { [number]: number }"); + CHECK("never" == toString(ty)); + } + + SUBCASE("table_p_string_and_indexer_number_number") + { + TypeId ty = reductionof("{ p: string } & { [number]: number }"); + CHECK("{| [number]: number, p: string |}" == toString(ty)); + } + + SUBCASE("table_p_string_and_indexer_string_number") + { + TypeId ty = reductionof("{ p: string } & { [string]: number }"); + CHECK("{| [string]: number, p: string |}" == toString(ty)); + } + + SUBCASE("table_p_string_and_table_p_string_plus_indexer_string_number") + { + TypeId ty = reductionof("{ p: string } & { p: string, [string]: number }"); + CHECK("{| [string]: number, p: string |}" == toString(ty)); + } +} // intersections_without_negations + +TEST_CASE_FIXTURE(ReductionFixture, "intersections_with_negations") +{ + SUBCASE("nil_and_not_nil") + { + TypeId ty = reductionof("nil & Not"); + CHECK("never" == toString(ty)); + } + + SUBCASE("nil_and_not_false") + { + TypeId ty = reductionof("nil & Not"); + CHECK("nil" == toString(ty)); + } + + SUBCASE("string_or_nil_and_not_nil") + { + TypeId ty = reductionof("(string?) & Not"); + CHECK("string" == toString(ty)); + } + + SUBCASE("string_or_nil_and_not_false_or_nil") + { + TypeId ty = reductionof("(string?) & Not"); + CHECK("string" == toString(ty)); + } + + SUBCASE("string_or_nil_and_not_false_and_not_nil") + { + TypeId ty = reductionof("(string?) & Not & Not"); + CHECK("string" == toString(ty)); + } + + SUBCASE("not_false_and_bool") + { + TypeId ty = reductionof("Not & boolean"); + CHECK("true" == toString(ty)); + } + + SUBCASE("function_type_and_not_function") + { + TypeId ty = reductionof("() -> () & Not"); + CHECK("never" == toString(ty)); + } + + SUBCASE("function_type_and_not_string") + { + TypeId ty = reductionof("() -> () & Not"); + CHECK("() -> ()" == toString(ty)); + } + + SUBCASE("not_a_and_string_or_nil") + { + TypeId ty = reductionof(R"(Not<"a"> & (string | nil))"); + CHECK(R"((string & ~"a")?)" == toString(ty)); + } + + SUBCASE("not_a_and_a") + { + TypeId ty = reductionof(R"(Not<"a"> & "a")"); + CHECK("never" == toString(ty)); + } + + SUBCASE("not_a_and_b") + { + TypeId ty = reductionof(R"(Not<"a"> & "b")"); + CHECK(R"("b")" == toString(ty)); + } + + SUBCASE("not_string_and_a") + { + TypeId ty = reductionof(R"(Not & "a")"); + CHECK("never" == toString(ty)); + } + + SUBCASE("not_bool_and_true") + { + TypeId ty = reductionof("Not & true"); + CHECK("never" == toString(ty)); + } + + SUBCASE("not_string_and_true") + { + TypeId ty = reductionof("Not & true"); + CHECK("true" == toString(ty)); + } + + SUBCASE("parent_and_not_child") + { + TypeId ty = reductionof("Parent & Not"); + CHECK("Parent & ~Child" == toString(ty)); + } + + SUBCASE("not_child_and_parent") + { + TypeId ty = reductionof("Not & Parent"); + CHECK("Parent & ~Child" == toString(ty)); + } + + SUBCASE("child_and_not_parent") + { + TypeId ty = reductionof("Child & Not"); + CHECK("never" == toString(ty)); + } + + SUBCASE("not_parent_and_child") + { + TypeId ty = reductionof("Not & Child"); + CHECK("never" == toString(ty)); + } + + SUBCASE("not_parent_and_unrelated") + { + TypeId ty = reductionof("Not & Unrelated"); + CHECK("Unrelated" == toString(ty)); + } + + SUBCASE("unrelated_and_not_parent") + { + TypeId ty = reductionof("Unrelated & Not"); + CHECK("Unrelated" == toString(ty)); + } + + SUBCASE("not_unrelated_and_parent") + { + TypeId ty = reductionof("Not & Parent"); + CHECK("Parent" == toString(ty)); + } + + SUBCASE("parent_and_not_unrelated") + { + TypeId ty = reductionof("Parent & Not"); + CHECK("Parent" == toString(ty)); + } + + SUBCASE("reducible_function_and_not_function") + { + TypeId ty = reductionof("((string | string) -> (number | number)) & Not"); + CHECK("never" == toString(ty)); + } + + SUBCASE("string_and_not_error") + { + TypeId ty = reductionof("string & Not"); + CHECK("string & ~*error-type*" == toString(ty)); + } + + SUBCASE("table_p_string_and_table_p_not_number") + { + TypeId ty = reductionof("{ p: string } & { p: Not }"); + CHECK("{| p: string |}" == toString(ty)); + } + + SUBCASE("table_p_string_and_table_p_not_string") + { + TypeId ty = reductionof("{ p: string } & { p: Not }"); + CHECK("never" == toString(ty)); + } + + SUBCASE("table_x_table_p_string_and_table_x_table_p_not_number") + { + TypeId ty = reductionof("{ x: { p: string } } & { x: { p: Not } }"); + CHECK("{| x: {| p: string |} |}" == toString(ty)); + } +} // intersections_with_negations + +TEST_CASE_FIXTURE(ReductionFixture, "unions_without_negations") +{ + SUBCASE("never_or_string") + { + TypeId ty = reductionof("never | string"); + CHECK("string" == toString(ty)); + } + + SUBCASE("string_or_never") + { + TypeId ty = reductionof("string | never"); + CHECK("string" == toString(ty)); + } + + SUBCASE("unknown_or_string") + { + TypeId ty = reductionof("unknown | string"); + CHECK("unknown" == toString(ty)); + } + + SUBCASE("string_or_unknown") + { + TypeId ty = reductionof("string | unknown"); + CHECK("unknown" == toString(ty)); + } + + SUBCASE("any_or_string") + { + TypeId ty = reductionof("any | string"); + CHECK("any" == toString(ty)); + } + + SUBCASE("string_or_any") + { + TypeId ty = reductionof("string | any"); + CHECK("any" == toString(ty)); + } + + SUBCASE("string_or_string_and_number") + { + TypeId ty = reductionof("string | (string & number)"); + CHECK("string" == toString(ty)); + } + + SUBCASE("string_or_string") + { + TypeId ty = reductionof("string | string"); + CHECK("string" == toString(ty)); + } + + SUBCASE("string_or_number") + { + TypeId ty = reductionof("string | number"); + CHECK("number | string" == toString(ty)); + } + + SUBCASE("number_or_string") + { + TypeId ty = reductionof("number | string"); + CHECK("number | string" == toString(ty)); + } + + SUBCASE("string_or_number_or_string") + { + TypeId ty = reductionof("(string | number) | string"); + CHECK("number | string" == toString(ty)); + } + + SUBCASE("string_or_number_or_string_2") + { + TypeId ty = reductionof("string | (number | string)"); + CHECK("number | string" == toString(ty)); + } + + SUBCASE("string_or_string_or_number") + { + TypeId ty = reductionof("string | (string | number)"); + CHECK("number | string" == toString(ty)); + } + + SUBCASE("string_or_string_or_number_or_boolean") + { + TypeId ty = reductionof("string | (string | number | boolean)"); + CHECK("boolean | number | string" == toString(ty)); + } + + SUBCASE("string_or_string_or_boolean_or_number") + { + TypeId ty = reductionof("string | (string | boolean | number)"); + CHECK("boolean | number | string" == toString(ty)); + } + + SUBCASE("string_or_boolean_or_string_or_number") + { + TypeId ty = reductionof("string | (boolean | string | number)"); + CHECK("boolean | number | string" == toString(ty)); + } + + SUBCASE("boolean_or_string_or_number_or_string") + { + TypeId ty = reductionof("(boolean | string | number) | string"); + CHECK("boolean | number | string" == toString(ty)); + } + + SUBCASE("boolean_or_true") + { + TypeId ty = reductionof("boolean | true"); + CHECK("boolean" == toString(ty)); + } + + SUBCASE("boolean_or_false") + { + TypeId ty = reductionof("boolean | false"); + CHECK("boolean" == toString(ty)); + } + + SUBCASE("boolean_or_true_or_false") + { + TypeId ty = reductionof("boolean | true | false"); + CHECK("boolean" == toString(ty)); + } + + SUBCASE("string_or_a") + { + TypeId ty = reductionof(R"(string | "a")"); + CHECK("string" == toString(ty)); + } + + SUBCASE("a_or_a") + { + TypeId ty = reductionof(R"("a" | "a")"); + CHECK(R"("a")" == toString(ty)); + } + + SUBCASE("a_or_b") + { + TypeId ty = reductionof(R"("a" | "b")"); + CHECK(R"("a" | "b")" == toString(ty)); + } + + SUBCASE("a_or_b_or_string") + { + TypeId ty = reductionof(R"("a" | "b" | string)"); + CHECK("string" == toString(ty)); + } + + SUBCASE("unknown_or_any") + { + TypeId ty = reductionof("unknown | any"); + CHECK("unknown" == toString(ty)); + } + + SUBCASE("any_or_unknown") + { + TypeId ty = reductionof("any | unknown"); + CHECK("unknown" == toString(ty)); + } + + SUBCASE("function_type_or_function") + { + TypeId ty = reductionof("() -> () | fun"); + CHECK("function" == toString(ty)); + } + + SUBCASE("function_or_string") + { + TypeId ty = reductionof("fun | string"); + CHECK("function | string" == toString(ty)); + } + + SUBCASE("parent_or_child") + { + TypeId ty = reductionof("Parent | Child"); + CHECK("Parent" == toString(ty)); + } + + SUBCASE("child_or_parent") + { + TypeId ty = reductionof("Child | Parent"); + CHECK("Parent" == toString(ty)); + } + + SUBCASE("parent_or_unrelated") + { + TypeId ty = reductionof("Parent | Unrelated"); + CHECK("Parent | Unrelated" == toString(ty)); + } + + SUBCASE("parent_or_child_or_unrelated") + { + TypeId ty = reductionof("Parent | Child | Unrelated"); + CHECK("Parent | Unrelated" == toString(ty)); + } + + SUBCASE("parent_or_unrelated_or_child") + { + TypeId ty = reductionof("Parent | Unrelated | Child"); + CHECK("Parent | Unrelated" == toString(ty)); + } + + SUBCASE("parent_or_child_or_unrelated_or_child") + { + TypeId ty = reductionof("Parent | Child | Unrelated | Child"); + CHECK("Parent | Unrelated" == toString(ty)); + } + + SUBCASE("string_or_true") + { + TypeId ty = reductionof("string | true"); + CHECK("string | true" == toString(ty)); + } + + SUBCASE("string_or_function") + { + TypeId ty = reductionof("string | () -> ()"); + CHECK("(() -> ()) | string" == toString(ty)); + } + + SUBCASE("string_or_err") + { + TypeId ty = reductionof("string | err"); + CHECK("*error-type* | string" == toString(ty)); + } +} // unions_without_negations + +TEST_CASE_FIXTURE(ReductionFixture, "unions_with_negations") +{ + SUBCASE("string_or_not_string") + { + TypeId ty = reductionof("string | Not"); + CHECK("unknown" == toString(ty)); + } + + SUBCASE("not_string_or_string") + { + TypeId ty = reductionof("Not | string"); + CHECK("unknown" == toString(ty)); + } + + SUBCASE("not_number_or_string") + { + TypeId ty = reductionof("Not | string"); + CHECK("~number" == toString(ty)); + } + + SUBCASE("string_or_not_number") + { + TypeId ty = reductionof("string | Not"); + CHECK("~number" == toString(ty)); + } + + SUBCASE("not_hi_or_string_and_not_hi") + { + TypeId ty = reductionof(R"(Not<"hi"> | (string & Not<"hi">))"); + CHECK(R"(~"hi")" == toString(ty)); + } + + SUBCASE("string_and_not_hi_or_not_hi") + { + TypeId ty = reductionof(R"((string & Not<"hi">) | Not<"hi">)"); + CHECK(R"(~"hi")" == toString(ty)); + } + + SUBCASE("string_or_not_never") + { + TypeId ty = reductionof("string | Not"); + CHECK("unknown" == toString(ty)); + } + + SUBCASE("not_a_or_not_a") + { + TypeId ty = reductionof(R"(Not<"a"> | Not<"a">)"); + CHECK(R"(~"a")" == toString(ty)); + } + + SUBCASE("not_a_or_a") + { + TypeId ty = reductionof(R"(Not<"a"> | "a")"); + CHECK("unknown" == toString(ty)); + } + + SUBCASE("a_or_not_a") + { + TypeId ty = reductionof(R"("a" | Not<"a">)"); + CHECK("unknown" == toString(ty)); + } + + SUBCASE("not_a_or_string") + { + TypeId ty = reductionof(R"(Not<"a"> | string)"); + CHECK("unknown" == toString(ty)); + } + + SUBCASE("string_or_not_a") + { + TypeId ty = reductionof(R"(string | Not<"a">)"); + CHECK("unknown" == toString(ty)); + } + + SUBCASE("not_string_or_a") + { + TypeId ty = reductionof(R"(Not | "a")"); + CHECK(R"("a" | ~string)" == toString(ty)); + } + + SUBCASE("a_or_not_string") + { + TypeId ty = reductionof(R"("a" | Not)"); + CHECK(R"("a" | ~string)" == toString(ty)); + } + + SUBCASE("not_number_or_a") + { + TypeId ty = reductionof(R"(Not | "a")"); + CHECK("~number" == toString(ty)); + } + + SUBCASE("a_or_not_number") + { + TypeId ty = reductionof(R"("a" | Not)"); + CHECK("~number" == toString(ty)); + } + + SUBCASE("not_a_or_not_b") + { + TypeId ty = reductionof(R"(Not<"a"> | Not<"b">)"); + CHECK("unknown" == toString(ty)); + } + + SUBCASE("boolean_or_not_false") + { + TypeId ty = reductionof("boolean | Not"); + CHECK("unknown" == toString(ty)); + } + + SUBCASE("boolean_or_not_true") + { + TypeId ty = reductionof("boolean | Not"); + CHECK("unknown" == toString(ty)); + } + + SUBCASE("false_or_not_false") + { + TypeId ty = reductionof("false | Not"); + CHECK("unknown" == toString(ty)); + } + + SUBCASE("true_or_not_false") + { + TypeId ty = reductionof("true | Not"); + CHECK("~false" == toString(ty)); + } + + SUBCASE("not_boolean_or_true") + { + TypeId ty = reductionof("Not | true"); + CHECK("~false" == toString(ty)); + } + + SUBCASE("not_false_or_not_boolean") + { + TypeId ty = reductionof("Not | Not"); + CHECK("~false" == toString(ty)); + } + + SUBCASE("function_type_or_not_function") + { + TypeId ty = reductionof("() -> () | Not"); + CHECK("(() -> ()) | ~function" == toString(ty)); + } + + SUBCASE("not_parent_or_child") + { + TypeId ty = reductionof("Not | Child"); + CHECK("Child | ~Parent" == toString(ty)); + } + + SUBCASE("child_or_not_parent") + { + TypeId ty = reductionof("Child | Not"); + CHECK("Child | ~Parent" == toString(ty)); + } + + SUBCASE("parent_or_not_child") + { + TypeId ty = reductionof("Parent | Not"); + CHECK("unknown" == toString(ty)); + } + + SUBCASE("not_child_or_parent") + { + TypeId ty = reductionof("Not | Parent"); + CHECK("unknown" == toString(ty)); + } + + SUBCASE("parent_or_not_unrelated") + { + TypeId ty = reductionof("Parent | Not"); + CHECK("~Unrelated" == toString(ty)); + } + + SUBCASE("not_string_or_string_and_not_a") + { + TypeId ty = reductionof(R"(Not | (string & Not<"a">))"); + CHECK(R"(~"a")" == toString(ty)); + } + + SUBCASE("not_string_or_not_string") + { + TypeId ty = reductionof("Not | Not"); + CHECK("~string" == toString(ty)); + } + + SUBCASE("not_string_or_not_number") + { + TypeId ty = reductionof("Not | Not"); + CHECK("unknown" == toString(ty)); + } + + SUBCASE("not_a_or_not_boolean") + { + TypeId ty = reductionof(R"(Not<"a"> | Not)"); + CHECK("unknown" == toString(ty)); + } + + SUBCASE("not_a_or_boolean") + { + TypeId ty = reductionof(R"(Not<"a"> | boolean)"); + CHECK(R"(~"a")" == toString(ty)); + } + + SUBCASE("string_or_err") + { + TypeId ty = reductionof("string | Not"); + CHECK("string | ~*error-type*" == toString(ty)); + } +} // unions_with_negations + +TEST_CASE_FIXTURE(ReductionFixture, "tables") +{ + SUBCASE("reduce_props") + { + ToStringOptions opts; + opts.exhaustive = true; + + TypeId ty = reductionof("{ x: string | string, y: number | number }"); + CHECK("{| x: string, y: number |}" == toString(ty, opts)); + } + + SUBCASE("reduce_indexers") + { + ToStringOptions opts; + opts.exhaustive = true; + + TypeId ty = reductionof("{ [string | string]: number | number }"); + CHECK("{| [string]: number |}" == toString(ty, opts)); + } + + SUBCASE("reduce_instantiated_type_parameters") + { + check(R"( + type Foo = { x: T } + local foo: Foo = { x = "hello" } + )"); + + TypeId ty = reductionof(requireType("foo")); + CHECK("Foo" == toString(ty)); + } + + SUBCASE("reduce_instantiated_type_pack_parameters") + { + check(R"( + type Foo = { x: () -> T... } + local foo: Foo = { x = function() return "hi", 5 end } + )"); + + TypeId ty = reductionof(requireType("foo")); + CHECK("Foo" == toString(ty)); + } + + SUBCASE("reduce_tables_within_tables") + { + ToStringOptions opts; + opts.exhaustive = true; + + TypeId ty = reductionof("{ x: { y: string & number } }"); + CHECK("{| x: {| y: never |} |}" == toString(ty, opts)); + } +} + +TEST_CASE_FIXTURE(ReductionFixture, "metatables") +{ + SUBCASE("reduce_table_part") + { + TableType table; + table.props["x"] = {arena.addType(UnionType{{builtinTypes->stringType, builtinTypes->stringType}})}; + TypeId tableTy = arena.addType(std::move(table)); + + TypeId ty = reductionof(arena.addType(MetatableType{tableTy, arena.addType(TableType{})})); + CHECK("{ @metatable { }, { x: string } }" == toString(ty)); + } + + SUBCASE("reduce_metatable_part") + { + TableType table; + table.props["x"] = {arena.addType(UnionType{{builtinTypes->stringType, builtinTypes->stringType}})}; + TypeId tableTy = arena.addType(std::move(table)); + + TypeId ty = reductionof(arena.addType(MetatableType{arena.addType(TableType{}), tableTy})); + CHECK("{ @metatable { x: string }, { } }" == toString(ty)); + } +} + +TEST_CASE_FIXTURE(ReductionFixture, "functions") +{ + SUBCASE("reduce_parameters") + { + TypeId ty = reductionof("(string | string) -> ()"); + CHECK("(string) -> ()" == toString(ty)); + } + + SUBCASE("reduce_returns") + { + TypeId ty = reductionof("() -> (string | string)"); + CHECK("() -> string" == toString(ty)); + } + + SUBCASE("reduce_parameters_and_returns") + { + TypeId ty = reductionof("(string | string) -> (number | number)"); + CHECK("(string) -> number" == toString(ty)); + } + + SUBCASE("reduce_tail") + { + TypeId ty = reductionof("() -> ...(string | string)"); + CHECK("() -> (...string)" == toString(ty)); + } + + SUBCASE("reduce_head_and_tail") + { + TypeId ty = reductionof("() -> (string | string, number | number, ...(boolean | boolean))"); + CHECK("() -> (string, number, ...boolean)" == toString(ty)); + } + + SUBCASE("reduce_overloaded_functions") + { + TypeId ty = reductionof("((number | number) -> ()) & ((string | string) -> ())"); + CHECK("((number) -> ()) & ((string) -> ())" == toString(ty)); + } +} // functions + +TEST_CASE_FIXTURE(ReductionFixture, "negations") +{ + SUBCASE("not_unknown") + { + TypeId ty = reductionof("Not"); + CHECK("never" == toString(ty)); + } + + SUBCASE("not_never") + { + TypeId ty = reductionof("Not"); + CHECK("unknown" == toString(ty)); + } + + SUBCASE("not_any") + { + TypeId ty = reductionof("Not"); + CHECK("any" == toString(ty)); + } + + SUBCASE("not_not_reduction") + { + TypeId ty = reductionof("Not>"); + CHECK("never" == toString(ty)); + } + + SUBCASE("not_string") + { + TypeId ty = reductionof("Not"); + CHECK("~string" == toString(ty)); + } + + SUBCASE("not_string_or_number") + { + TypeId ty = reductionof("Not"); + CHECK("~number & ~string" == toString(ty)); + } + + SUBCASE("not_string_and_number") + { + TypeId ty = reductionof("Not"); + CHECK("unknown" == toString(ty)); + } + + SUBCASE("not_error") + { + TypeId ty = reductionof("Not"); + CHECK("~*error-type*" == toString(ty)); + } +} // negations + +TEST_CASE_FIXTURE(ReductionFixture, "discriminable_unions") +{ + SUBCASE("cat_or_dog_and_dog") + { + TypeId ty = reductionof(R"(({ tag: "cat", catfood: string } | { tag: "dog", dogfood: string }) & { tag: "dog" })"); + CHECK(R"({| dogfood: string, tag: "dog" |})" == toString(ty)); + } + + SUBCASE("cat_or_dog_and_not_dog") + { + TypeId ty = reductionof(R"(({ tag: "cat", catfood: string } | { tag: "dog", dogfood: string }) & { tag: Not<"dog"> })"); + CHECK(R"({| catfood: string, tag: "cat" |})" == toString(ty)); + } + + SUBCASE("string_or_number_and_number") + { + TypeId ty = reductionof("({ tag: string, a: number } | { tag: number, b: string }) & { tag: string }"); + CHECK("{| a: number, tag: string |}" == toString(ty)); + } + + SUBCASE("string_or_number_and_number") + { + TypeId ty = reductionof("({ tag: string, a: number } | { tag: number, b: string }) & { tag: number }"); + CHECK("{| b: string, tag: number |}" == toString(ty)); + } + + SUBCASE("child_or_unrelated_and_parent") + { + TypeId ty = reductionof("({ tag: Child, x: number } | { tag: Unrelated, y: string }) & { tag: Parent }"); + CHECK("{| tag: Child, x: number |}" == toString(ty)); + } + + SUBCASE("child_or_unrelated_and_not_parent") + { + TypeId ty = reductionof("({ tag: Child, x: number } | { tag: Unrelated, y: string }) & { tag: Not }"); + CHECK("{| tag: Unrelated, y: string |}" == toString(ty)); + } +} + +TEST_CASE_FIXTURE(ReductionFixture, "cycles") +{ + SUBCASE("recursively_defined_function") + { + check("type F = (f: F) -> ()"); + + TypeId ty = reductionof(requireTypeAlias("F")); + CHECK("(t1) -> () where t1 = (t1) -> ()" == toString(ty)); + } + + SUBCASE("recursively_defined_function_and_function") + { + check("type F = (f: F & fun) -> ()"); + + TypeId ty = reductionof(requireTypeAlias("F")); + CHECK("(t1) -> () where t1 = (function & t1) -> ()" == toString(ty)); + } + + SUBCASE("recursively_defined_table") + { + ToStringOptions opts; + opts.exhaustive = true; + + check("type T = { x: T }"); + + TypeId ty = reductionof(requireTypeAlias("T")); + CHECK("{| x: t1 |} where t1 = {| x: t1 |}" == toString(ty, opts)); + } + + SUBCASE("recursively_defined_table_and_table") + { + ToStringOptions opts; + opts.exhaustive = true; + + check("type T = { x: T & {} }"); + + TypeId ty = reductionof(requireTypeAlias("T")); + CHECK("{| x: t1 & {| |} |} where t1 = {| x: t1 & {| |} |}" == toString(ty, opts)); + } + + SUBCASE("recursively_defined_table_and_table_2") + { + ToStringOptions opts; + opts.exhaustive = true; + + check("type T = { x: T } & { x: number }"); + + TypeId ty = reductionof(requireTypeAlias("T")); + CHECK("never" == toString(ty)); + } + + SUBCASE("recursively_defined_table_and_table_3") + { + ToStringOptions opts; + opts.exhaustive = true; + + check("type T = { x: T } & { x: T }"); + + TypeId ty = reductionof(requireTypeAlias("T")); + CHECK("{| x: {| x: t1 |} & {| x: t1 |} & {| x: t2 & t2 & {| x: t1 |} & {| x: t1 |} |} |} where t1 = t2 & {| x: t1 |} ; t2 = {| x: t1 |}" == + toString(ty)); + } +} + +TEST_CASE_FIXTURE(ReductionFixture, "stress_test_recursion_limits") +{ + TypeId ty = arena.addType(IntersectionType{{builtinTypes->numberType, builtinTypes->stringType}}); + for (size_t i = 0; i < 20'000; ++i) + { + TableType table; + table.state = TableState::Sealed; + table.props["x"] = {ty}; + ty = arena.addType(IntersectionType{{arena.addType(table), arena.addType(table)}}); + } + + CHECK(!reduction.reduce(ty)); +} + +TEST_SUITE_END(); diff --git a/tests/TypeVar.test.cpp b/tests/TypeVar.test.cpp index ec0a2473c..36e437e24 100644 --- a/tests/TypeVar.test.cpp +++ b/tests/TypeVar.test.cpp @@ -1,7 +1,8 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Scope.h" -#include "Luau/TypeInfer.h" #include "Luau/Type.h" +#include "Luau/TypeInfer.h" +#include "Luau/TypeReduction.h" #include "Luau/VisitType.h" #include "Fixture.h" diff --git a/tools/faillist.txt b/tools/faillist.txt index 233c75c1a..f336bb222 100644 --- a/tools/faillist.txt +++ b/tools/faillist.txt @@ -321,6 +321,7 @@ TypeInfer.globals2 TypeInfer.infer_assignment_value_types_mutable_lval TypeInfer.it_is_ok_to_have_inconsistent_number_of_return_values_in_nonstrict TypeInfer.no_stack_overflow_from_isoptional +TypeInfer.no_stack_overflow_from_isoptional2 TypeInfer.tc_after_error_recovery_no_replacement_name_in_error TypeInfer.tc_if_else_expressions_expected_type_3 TypeInfer.tc_interpolated_string_basic @@ -408,10 +409,7 @@ TypeInferOperators.cannot_compare_tables_that_do_not_have_the_same_metatable TypeInferOperators.cannot_indirectly_compare_types_that_do_not_have_a_metatable TypeInferOperators.cannot_indirectly_compare_types_that_do_not_offer_overloaded_ordering_operators TypeInferOperators.cli_38355_recursive_union -TypeInferOperators.compound_assign_metatable TypeInferOperators.compound_assign_mismatch_metatable -TypeInferOperators.compound_assign_mismatch_op -TypeInferOperators.compound_assign_mismatch_result TypeInferOperators.disallow_string_and_types_without_metatables_from_arithmetic_binary_ops TypeInferOperators.in_nonstrict_mode_strip_nil_from_intersections_when_considering_relational_operators TypeInferOperators.infer_any_in_all_modes_when_lhs_is_unknown From 5db96755371e6b08465e496d96c08d965230de22 Mon Sep 17 00:00:00 2001 From: Vyacheslav Egorov Date: Fri, 6 Jan 2023 18:30:05 +0200 Subject: [PATCH 26/66] Smaller recursion limit to not hit stack overflow in debug on Windows --- Analysis/src/TypeReduction.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Analysis/src/TypeReduction.cpp b/Analysis/src/TypeReduction.cpp index e47ee39cc..0e66d0184 100644 --- a/Analysis/src/TypeReduction.cpp +++ b/Analysis/src/TypeReduction.cpp @@ -9,7 +9,7 @@ #include LUAU_FASTINTVARIABLE(LuauTypeReductionCartesianProductLimit, 100'000) -LUAU_FASTINTVARIABLE(LuauTypeReductionRecursionLimit, 900) +LUAU_FASTINTVARIABLE(LuauTypeReductionRecursionLimit, 700) LUAU_FASTFLAGVARIABLE(DebugLuauDontReduceTypes, false) namespace Luau From a2365f2adf83f722a8e9ff7c5775a2f4feef449d Mon Sep 17 00:00:00 2001 From: Vyacheslav Egorov Date: Fri, 6 Jan 2023 18:30:32 +0200 Subject: [PATCH 27/66] Fix build warning --- Analysis/src/TypeReduction.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Analysis/src/TypeReduction.cpp b/Analysis/src/TypeReduction.cpp index 0e66d0184..db79debf3 100644 --- a/Analysis/src/TypeReduction.cpp +++ b/Analysis/src/TypeReduction.cpp @@ -27,7 +27,7 @@ struct RecursionGuard : RecursionLimiter , seen(seen) { // count has been incremented, which should imply that seen has already had an element pushed in. - LUAU_ASSERT(*count == seen->size()); + LUAU_ASSERT(size_t(*count) == seen->size()); } ~RecursionGuard() From 96c1cafff28b70054be91938c6fd004d45dfa007 Mon Sep 17 00:00:00 2001 From: Andy Friesen Date: Fri, 13 Jan 2023 12:36:28 -0800 Subject: [PATCH 28/66] Sync to upstream/release/559 --- Analysis/include/Luau/Predicate.h | 25 +- Analysis/include/Luau/ToString.h | 1 - Analysis/include/Luau/TypeReduction.h | 4 +- Analysis/src/Autocomplete.cpp | 19 +- Analysis/src/ToString.cpp | 25 +- Analysis/src/TypeReduction.cpp | 96 +- Ast/src/Parser.cpp | 15 +- CodeGen/include/Luau/AssemblyBuilderX64.h | 5 +- CodeGen/src/AssemblyBuilderX64.cpp | 56 +- CodeGen/src/CodeGen.cpp | 10 + CodeGen/src/EmitCommonX64.cpp | 16 + CodeGen/src/EmitCommonX64.h | 17 +- CodeGen/src/EmitInstructionX64.cpp | 68 ++ CodeGen/src/EmitInstructionX64.h | 3 + CodeGen/src/Fallbacks.cpp | 13 - CodeGen/src/Fallbacks.h | 1 - CodeGen/src/IrData.h | 280 +++++ CodeGen/src/IrDump.cpp | 379 +++++++ CodeGen/src/IrDump.h | 32 + CodeGen/src/IrUtils.h | 161 +++ CodeGen/src/NativeState.cpp | 3 +- CodeGen/src/NativeState.h | 2 + Compiler/include/Luau/BytecodeBuilder.h | 4 + Compiler/src/BytecodeBuilder.cpp | 201 +++- Compiler/src/Compiler.cpp | 4 +- Sources.cmake | 4 + VM/include/lualib.h | 4 +- VM/src/laux.cpp | 4 +- VM/src/ldblib.cpp | 2 +- VM/src/loslib.cpp | 2 +- VM/src/lstrlib.cpp | 46 +- VM/src/ltablib.cpp | 2 +- VM/src/lutf8lib.cpp | 2 +- bench/bench.py | 2 +- tests/AssemblyBuilderX64.test.cpp | 6 + tests/Autocomplete.test.cpp | 48 +- tests/Compiler.test.cpp | 1255 +++++++++++---------- tests/Conformance.test.cpp | 2 + tests/Parser.test.cpp | 2 - tests/ToString.test.cpp | 3 - tests/TypeReduction.test.cpp | 132 ++- tests/conformance/strings.lua | 10 + tests/conformance/tables.lua | 14 + tools/faillist.txt | 12 +- tools/heapgraph.py | 2 +- tools/heapstat.py | 2 +- tools/lvmexecute_split.py | 4 +- tools/numprint.py | 2 +- tools/patchtests.py | 2 +- tools/perfgraph.py | 2 +- tools/perfstat.py | 2 +- tools/stack-usage-reporter.py | 2 +- tools/test_dcr.py | 1 + tools/tracegraph.py | 2 +- 54 files changed, 2188 insertions(+), 825 deletions(-) create mode 100644 CodeGen/src/IrData.h create mode 100644 CodeGen/src/IrDump.cpp create mode 100644 CodeGen/src/IrDump.h create mode 100644 CodeGen/src/IrUtils.h diff --git a/Analysis/include/Luau/Predicate.h b/Analysis/include/Luau/Predicate.h index 8d486ad51..50fd7edd8 100644 --- a/Analysis/include/Luau/Predicate.h +++ b/Analysis/include/Luau/Predicate.h @@ -57,11 +57,7 @@ struct AndPredicate PredicateVec lhs; PredicateVec rhs; - AndPredicate(PredicateVec&& lhs, PredicateVec&& rhs) - : lhs(std::move(lhs)) - , rhs(std::move(rhs)) - { - } + AndPredicate(PredicateVec&& lhs, PredicateVec&& rhs); }; struct OrPredicate @@ -69,11 +65,7 @@ struct OrPredicate PredicateVec lhs; PredicateVec rhs; - OrPredicate(PredicateVec&& lhs, PredicateVec&& rhs) - : lhs(std::move(lhs)) - , rhs(std::move(rhs)) - { - } + OrPredicate(PredicateVec&& lhs, PredicateVec&& rhs); }; struct NotPredicate @@ -81,6 +73,19 @@ struct NotPredicate PredicateVec predicates; }; +// Outside definition works around clang 15 issue where vector instantiation is triggered while Predicate is still incomplete +inline AndPredicate::AndPredicate(PredicateVec&& lhs, PredicateVec&& rhs) + : lhs(std::move(lhs)) + , rhs(std::move(rhs)) +{ +} + +inline OrPredicate::OrPredicate(PredicateVec&& lhs, PredicateVec&& rhs) + : lhs(std::move(lhs)) + , rhs(std::move(rhs)) +{ +} + template const T* get(const Predicate& predicate) { diff --git a/Analysis/include/Luau/ToString.h b/Analysis/include/Luau/ToString.h index dd8aef574..461a8fffb 100644 --- a/Analysis/include/Luau/ToString.h +++ b/Analysis/include/Luau/ToString.h @@ -45,7 +45,6 @@ struct ToStringOptions bool hideTableKind = false; // If true, all tables will be surrounded with plain '{}' bool hideNamedFunctionTypeParameters = false; // If true, type parameters of functions will be hidden at top-level. bool hideFunctionSelfArgument = false; // If true, `self: X` will be omitted from the function signature if the function has self - bool DEPRECATED_indent = false; // TODO Deprecated field, prune when clipping flag FFlagLuauLineBreaksDeterminIndents size_t maxTableLength = size_t(FInt::LuauTableTypeMaximumStringifierLength); // Only applied to TableTypes size_t maxTypeLength = size_t(FInt::LuauTypeMaximumStringifierLength); ToStringNameMap nameMap; diff --git a/Analysis/include/Luau/TypeReduction.h b/Analysis/include/Luau/TypeReduction.h index 7df7edfa7..a7cec9468 100644 --- a/Analysis/include/Luau/TypeReduction.h +++ b/Analysis/include/Luau/TypeReduction.h @@ -27,8 +27,8 @@ struct TypeReduction DenseHashMap cachedTypes{nullptr}; DenseHashMap cachedTypePacks{nullptr}; - std::optional reduceImpl(TypeId ty); - std::optional reduceImpl(TypePackId tp); + std::pair, bool> reduceImpl(TypeId ty); + std::pair, bool> reduceImpl(TypePackId tp); // Computes an *estimated length* of the cartesian product of the given type. size_t cartesianProductSize(TypeId ty) const; diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index 49c430e63..6fab97d52 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -13,6 +13,7 @@ #include LUAU_FASTFLAGVARIABLE(LuauCompleteTableKeysBetter, false); +LUAU_FASTFLAGVARIABLE(LuauFixAutocompleteInIf, false); static const std::unordered_set kStatementStartingKeywords = { "while", "if", "local", "repeat", "function", "do", "for", "return", "break", "continue", "type", "export"}; @@ -1467,8 +1468,22 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M return {{{"then", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Keyword}; } else if (AstStatIf* statIf = extractStat(ancestry); - statIf && (!statIf->thenLocation || statIf->thenLocation->containsClosed(position))) - return {{{"then", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Keyword}; + statIf && (!statIf->thenLocation || statIf->thenLocation->containsClosed(position)) && + (!FFlag::LuauFixAutocompleteInIf || (statIf->condition && !statIf->condition->location.containsClosed(position)))) + { + if (FFlag::LuauFixAutocompleteInIf) + { + AutocompleteEntryMap ret; + ret["then"] = {AutocompleteEntryKind::Keyword}; + ret["and"] = {AutocompleteEntryKind::Keyword}; + ret["or"] = {AutocompleteEntryKind::Keyword}; + return {std::move(ret), ancestry, AutocompleteContext::Keyword}; + } + else + { + return {{{"then", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Keyword}; + } + } else if (AstStatRepeat* statRepeat = node->as(); statRepeat && statRepeat->condition->is()) return autocompleteExpression(sourceModule, *module, builtinTypes, typeArena, ancestry, position); else if (AstStatRepeat* statRepeat = extractStat(ancestry); statRepeat) diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index c7d9b3733..86ae3cde7 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -15,10 +15,8 @@ LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) LUAU_FASTFLAG(LuauUnknownAndNeverType) -LUAU_FASTFLAGVARIABLE(LuauLineBreaksDetermineIndents, false) LUAU_FASTFLAGVARIABLE(LuauFunctionReturnStringificationFixup, false) LUAU_FASTFLAGVARIABLE(LuauUnseeArrayTtv, false) -LUAU_FASTFLAGVARIABLE(LuauSerializeNilUnionAsNil, false) /* * Prefix generic typenames with gen- @@ -277,20 +275,10 @@ struct StringifierState private: void emitIndentation() { - if (!FFlag::LuauLineBreaksDetermineIndents) - { - if (!opts.DEPRECATED_indent) - return; - - emit(std::string(indentation, ' ')); - } - else - { - if (!opts.useLineBreaks) - return; + if (!opts.useLineBreaks) + return; - emit(std::string(indentation, ' ')); - } + emit(std::string(indentation, ' ')); } }; @@ -780,11 +768,8 @@ struct TypeStringifier if (results.size() > 1) s = ")?"; - if (FFlag::LuauSerializeNilUnionAsNil) - { - if (!hasNonNilDisjunct) - s = "nil"; - } + if (!hasNonNilDisjunct) + s = "nil"; state.emit(s); } diff --git a/Analysis/src/TypeReduction.cpp b/Analysis/src/TypeReduction.cpp index e47ee39cc..8d837ddb2 100644 --- a/Analysis/src/TypeReduction.cpp +++ b/Analysis/src/TypeReduction.cpp @@ -9,7 +9,7 @@ #include LUAU_FASTINTVARIABLE(LuauTypeReductionCartesianProductLimit, 100'000) -LUAU_FASTINTVARIABLE(LuauTypeReductionRecursionLimit, 900) +LUAU_FASTINTVARIABLE(LuauTypeReductionRecursionLimit, 700) LUAU_FASTFLAGVARIABLE(DebugLuauDontReduceTypes, false) namespace Luau @@ -27,7 +27,7 @@ struct RecursionGuard : RecursionLimiter , seen(seen) { // count has been incremented, which should imply that seen has already had an element pushed in. - LUAU_ASSERT(*count == seen->size()); + LUAU_ASSERT(size_t(*count) == seen->size()); } ~RecursionGuard() @@ -51,6 +51,16 @@ struct TypeReducer NotNull builtinTypes; NotNull handle; + std::unordered_map copies; + std::deque seen; + int depth = 0; + + // When we encounter _any type_ that which is usually mutated in-place, we need to not cache the result. + // e.g. `'a & {} T` may have an upper bound constraint `{}` placed upon `'a`, but this constraint was not + // known when we decided to reduce this intersection type. By not caching, we'll always be forced to perform + // the reduction calculus over again. + bool cacheOk = true; + TypeId reduce(TypeId ty); TypePackId reduce(TypePackId tp); @@ -60,13 +70,11 @@ struct TypeReducer TypeId functionType(TypeId ty); TypeId negationType(TypeId ty); - std::deque seen; - int depth = 0; - RecursionGuard guard(TypeId ty); RecursionGuard guard(TypePackId tp); - std::unordered_map copies; + void checkCacheable(TypeId ty); + void checkCacheable(TypePackId tp); template LUAU_NOINLINE std::pair copy(TypeId ty, const T* t) @@ -153,6 +161,7 @@ TypeId TypeReducer::reduce(TypeId ty) return ty; RecursionGuard rg = guard(ty); + checkCacheable(ty); if (auto i = get(ty)) return foldl(begin(i), end(i), &TypeReducer::intersectionType); @@ -176,6 +185,7 @@ TypePackId TypeReducer::reduce(TypePackId tp) return tp; RecursionGuard rg = guard(tp); + checkCacheable(tp); TypePackIterator it = begin(tp); @@ -213,6 +223,14 @@ std::optional TypeReducer::intersectionType(TypeId left, TypeId right) return right; // any & T ~ T else if (get(right)) return left; // T & any ~ T + else if (get(left)) + return std::nullopt; // 'a & T ~ 'a & T + else if (get(right)) + return std::nullopt; // T & 'a ~ T & 'a + else if (get(left)) + return std::nullopt; // G & T ~ G & T + else if (get(right)) + return std::nullopt; // T & G ~ T & G else if (get(left)) return std::nullopt; // error & T ~ error & T else if (get(right)) @@ -701,6 +719,32 @@ RecursionGuard TypeReducer::guard(TypePackId tp) return RecursionGuard{&depth, FInt::LuauTypeReductionRecursionLimit, &seen}; } +void TypeReducer::checkCacheable(TypeId ty) +{ + if (!cacheOk) + return; + + ty = follow(ty); + + // Only does shallow check, the TypeReducer itself already does deep traversal. + if (get(ty) || get(ty) || get(ty)) + cacheOk = false; + else if (auto tt = get(ty); tt && (tt->state == TableState::Free || tt->state == TableState::Unsealed)) + cacheOk = false; +} + +void TypeReducer::checkCacheable(TypePackId tp) +{ + if (!cacheOk) + return; + + tp = follow(tp); + + // Only does shallow check, the TypeReducer itself already does deep traversal. + if (get(tp) || get(tp)) + cacheOk = false; +} + } // namespace TypeReduction::TypeReduction(NotNull arena, NotNull builtinTypes, NotNull handle) @@ -715,13 +759,11 @@ std::optional TypeReduction::reduce(TypeId ty) if (auto found = cachedTypes.find(ty)) return *found; - if (auto reduced = reduceImpl(ty)) - { - cachedTypes[ty] = *reduced; - return *reduced; - } + auto [reducedTy, cacheOk] = reduceImpl(ty); + if (cacheOk) + cachedTypes[ty] = *reducedTy; - return std::nullopt; + return reducedTy; } std::optional TypeReduction::reduce(TypePackId tp) @@ -729,50 +771,48 @@ std::optional TypeReduction::reduce(TypePackId tp) if (auto found = cachedTypePacks.find(tp)) return *found; - if (auto reduced = reduceImpl(tp)) - { - cachedTypePacks[tp] = *reduced; - return *reduced; - } + auto [reducedTp, cacheOk] = reduceImpl(tp); + if (cacheOk) + cachedTypePacks[tp] = *reducedTp; - return std::nullopt; + return reducedTp; } -std::optional TypeReduction::reduceImpl(TypeId ty) +std::pair, bool> TypeReduction::reduceImpl(TypeId ty) { if (FFlag::DebugLuauDontReduceTypes) - return ty; + return {ty, false}; if (hasExceededCartesianProductLimit(ty)) - return std::nullopt; + return {std::nullopt, false}; try { TypeReducer reducer{arena, builtinTypes, handle}; - return reducer.reduce(ty); + return {reducer.reduce(ty), reducer.cacheOk}; } catch (const RecursionLimitException&) { - return std::nullopt; + return {std::nullopt, false}; } } -std::optional TypeReduction::reduceImpl(TypePackId tp) +std::pair, bool> TypeReduction::reduceImpl(TypePackId tp) { if (FFlag::DebugLuauDontReduceTypes) - return tp; + return {tp, false}; if (hasExceededCartesianProductLimit(tp)) - return std::nullopt; + return {std::nullopt, false}; try { TypeReducer reducer{arena, builtinTypes, handle}; - return reducer.reduce(tp); + return {reducer.reduce(tp), reducer.cacheOk}; } catch (const RecursionLimitException&) { - return std::nullopt; + return {std::nullopt, false}; } } diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 5cd5f7437..dea54c168 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -14,11 +14,6 @@ LUAU_FASTINTVARIABLE(LuauRecursionLimit, 1000) LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100) -LUAU_FASTFLAGVARIABLE(LuauFixNamedFunctionParse, false) -LUAU_DYNAMIC_FASTFLAGVARIABLE(LuaReportParseWrongNamedType, false) - -bool lua_telemetry_parsed_named_non_function_type = false; - LUAU_FASTFLAGVARIABLE(LuauErrorDoubleHexPrefix, false) LUAU_DYNAMIC_FASTFLAGVARIABLE(LuaReportParseIntegerIssues, false) @@ -1423,7 +1418,7 @@ AstTypeOrPack Parser::parseFunctionTypeAnnotation(bool allowPack) AstArray paramTypes = copy(params); - if (FFlag::LuauFixNamedFunctionParse && !names.empty()) + if (!names.empty()) forceFunctionType = true; bool returnTypeIntroducer = lexer.current().type == Lexeme::SkinnyArrow || lexer.current().type == ':'; @@ -1431,9 +1426,6 @@ AstTypeOrPack Parser::parseFunctionTypeAnnotation(bool allowPack) // Not a function at all. Just a parenthesized type. Or maybe a type pack with a single element if (params.size() == 1 && !varargAnnotation && !forceFunctionType && !returnTypeIntroducer) { - if (DFFlag::LuaReportParseWrongNamedType && !names.empty()) - lua_telemetry_parsed_named_non_function_type = true; - if (allowPack) return {{}, allocator.alloc(begin.location, AstTypeList{paramTypes, nullptr})}; else @@ -1441,12 +1433,7 @@ AstTypeOrPack Parser::parseFunctionTypeAnnotation(bool allowPack) } if (!forceFunctionType && !returnTypeIntroducer && allowPack) - { - if (DFFlag::LuaReportParseWrongNamedType && !names.empty()) - lua_telemetry_parsed_named_non_function_type = true; - return {{}, allocator.alloc(begin.location, AstTypeList{paramTypes, varargAnnotation})}; - } AstArray> paramNames = copy(names); diff --git a/CodeGen/include/Luau/AssemblyBuilderX64.h b/CodeGen/include/Luau/AssemblyBuilderX64.h index dbb366b23..918c82669 100644 --- a/CodeGen/include/Luau/AssemblyBuilderX64.h +++ b/CodeGen/include/Luau/AssemblyBuilderX64.h @@ -78,6 +78,7 @@ class AssemblyBuilderX64 void test(OperandX64 lhs, OperandX64 rhs); void lea(OperandX64 lhs, OperandX64 rhs); + void setcc(ConditionX64 cond, OperandX64 op); void push(OperandX64 op); void pop(OperandX64 op); @@ -189,8 +190,8 @@ class AssemblyBuilderX64 const char* name, OperandX64 dst, OperandX64 src1, OperandX64 src2, uint8_t imm8, uint8_t code, bool setW, uint8_t mode, uint8_t prefix); // Instruction components - void placeRegAndModRegMem(OperandX64 lhs, OperandX64 rhs); - void placeModRegMem(OperandX64 rhs, uint8_t regop); + void placeRegAndModRegMem(OperandX64 lhs, OperandX64 rhs, int32_t extraCodeBytes = 0); + void placeModRegMem(OperandX64 rhs, uint8_t regop, int32_t extraCodeBytes = 0); void placeRex(RegisterX64 op); void placeRex(OperandX64 op); void placeRexNoW(OperandX64 op); diff --git a/CodeGen/src/AssemblyBuilderX64.cpp b/CodeGen/src/AssemblyBuilderX64.cpp index f23fe4634..77856e9e4 100644 --- a/CodeGen/src/AssemblyBuilderX64.cpp +++ b/CodeGen/src/AssemblyBuilderX64.cpp @@ -17,9 +17,13 @@ static const uint8_t codeForCondition[] = { 0x0, 0x1, 0x2, 0x3, 0x2, 0x6, 0x7, 0x3, 0x4, 0xc, 0xe, 0xf, 0xd, 0x3, 0x7, 0x6, 0x2, 0x5, 0xd, 0xf, 0xe, 0xc, 0x4, 0x5, 0xa, 0xb}; static_assert(sizeof(codeForCondition) / sizeof(codeForCondition[0]) == size_t(ConditionX64::Count), "all conditions have to be covered"); -static const char* textForCondition[] = {"jo", "jno", "jc", "jnc", "jb", "jbe", "ja", "jae", "je", "jl", "jle", "jg", "jge", "jnb", "jnbe", "jna", +static const char* jccTextForCondition[] = {"jo", "jno", "jc", "jnc", "jb", "jbe", "ja", "jae", "je", "jl", "jle", "jg", "jge", "jnb", "jnbe", "jna", "jnae", "jne", "jnl", "jnle", "jng", "jnge", "jz", "jnz", "jp", "jnp"}; -static_assert(sizeof(textForCondition) / sizeof(textForCondition[0]) == size_t(ConditionX64::Count), "all conditions have to be covered"); +static_assert(sizeof(jccTextForCondition) / sizeof(jccTextForCondition[0]) == size_t(ConditionX64::Count), "all conditions have to be covered"); + +static const char* setccTextForCondition[] = {"seto", "setno", "setc", "setnc", "setb", "setbe", "seta", "setae", "sete", "setl", "setle", "setg", + "setge", "setnb", "setnbe", "setna", "setnae", "setne", "setnl", "setnle", "setng", "setnge", "setz", "setnz", "setp", "setnp"}; +static_assert(sizeof(setccTextForCondition) / sizeof(setccTextForCondition[0]) == size_t(ConditionX64::Count), "all conditions have to be covered"); #define OP_PLUS_REG(op, reg) ((op) + (reg & 0x7)) #define OP_PLUS_CC(op, cc) ((op) + uint8_t(cc)) @@ -169,7 +173,7 @@ void AssemblyBuilderX64::mov(OperandX64 lhs, OperandX64 rhs) if (size == SizeX64::byte) { place(0xc6); - placeModRegMem(lhs, 0); + placeModRegMem(lhs, 0, /*extraCodeBytes=*/1); placeImm8(rhs.imm); } else @@ -177,7 +181,7 @@ void AssemblyBuilderX64::mov(OperandX64 lhs, OperandX64 rhs) LUAU_ASSERT(size == SizeX64::dword || size == SizeX64::qword); place(0xc7); - placeModRegMem(lhs, 0); + placeModRegMem(lhs, 0, /*extraCodeBytes=*/4); placeImm32(rhs.imm); } } @@ -304,13 +308,13 @@ void AssemblyBuilderX64::imul(OperandX64 dst, OperandX64 lhs, int32_t rhs) if (int8_t(rhs) == rhs) { place(0x6b); - placeRegAndModRegMem(dst, lhs); + placeRegAndModRegMem(dst, lhs, /*extraCodeBytes=*/1); placeImm8(rhs); } else { place(0x69); - placeRegAndModRegMem(dst, lhs); + placeRegAndModRegMem(dst, lhs, /*extraCodeBytes=*/4); placeImm32(rhs); } @@ -366,9 +370,24 @@ void AssemblyBuilderX64::ret() commit(); } +void AssemblyBuilderX64::setcc(ConditionX64 cond, OperandX64 op) +{ + SizeX64 size = op.cat == CategoryX64::reg ? op.base.size : op.memSize; + LUAU_ASSERT(size == SizeX64::byte); + + if (logText) + log(setccTextForCondition[size_t(cond)], op); + + placeRex(op); + place(0x0f); + place(0x90 | codeForCondition[size_t(cond)]); + placeModRegMem(op, 0); + commit(); +} + void AssemblyBuilderX64::jcc(ConditionX64 cond, Label& label) { - placeJcc(textForCondition[size_t(cond)], label, codeForCondition[size_t(cond)]); + placeJcc(jccTextForCondition[size_t(cond)], label, codeForCondition[size_t(cond)]); } void AssemblyBuilderX64::jmp(Label& label) @@ -866,7 +885,7 @@ void AssemblyBuilderX64::placeBinaryRegMemAndImm(OperandX64 lhs, OperandX64 rhs, if (size == SizeX64::byte) { place(code8); - placeModRegMem(lhs, opreg); + placeModRegMem(lhs, opreg, /*extraCodeBytes=*/1); placeImm8(rhs.imm); } else @@ -876,13 +895,13 @@ void AssemblyBuilderX64::placeBinaryRegMemAndImm(OperandX64 lhs, OperandX64 rhs, if (int8_t(rhs.imm) == rhs.imm && code != codeImm8) { place(codeImm8); - placeModRegMem(lhs, opreg); + placeModRegMem(lhs, opreg, /*extraCodeBytes=*/1); placeImm8(rhs.imm); } else { place(code); - placeModRegMem(lhs, opreg); + placeModRegMem(lhs, opreg, /*extraCodeBytes=*/4); placeImm32(rhs.imm); } } @@ -950,7 +969,7 @@ void AssemblyBuilderX64::placeShift(const char* name, OperandX64 lhs, OperandX64 LUAU_ASSERT(int8_t(rhs.imm) == rhs.imm); place(size == SizeX64::byte ? 0xc0 : 0xc1); - placeModRegMem(lhs, opreg); + placeModRegMem(lhs, opreg, /*extraCodeBytes=*/1); placeImm8(rhs.imm); } else @@ -1042,7 +1061,7 @@ void AssemblyBuilderX64::placeAvx( placeVex(dst, src1, src2, setW, mode, prefix); place(code); - placeRegAndModRegMem(dst, src2); + placeRegAndModRegMem(dst, src2, /*extraCodeBytes=*/1); placeImm8(imm8); commit(); @@ -1118,14 +1137,14 @@ static uint8_t getScaleEncoding(uint8_t scale) return scales[scale]; } -void AssemblyBuilderX64::placeRegAndModRegMem(OperandX64 lhs, OperandX64 rhs) +void AssemblyBuilderX64::placeRegAndModRegMem(OperandX64 lhs, OperandX64 rhs, int32_t extraCodeBytes) { LUAU_ASSERT(lhs.cat == CategoryX64::reg); - placeModRegMem(rhs, lhs.base.index); + placeModRegMem(rhs, lhs.base.index, extraCodeBytes); } -void AssemblyBuilderX64::placeModRegMem(OperandX64 rhs, uint8_t regop) +void AssemblyBuilderX64::placeModRegMem(OperandX64 rhs, uint8_t regop, int32_t extraCodeBytes) { if (rhs.cat == CategoryX64::reg) { @@ -1180,7 +1199,12 @@ void AssemblyBuilderX64::placeModRegMem(OperandX64 rhs, uint8_t regop) else if (base == rip) { place(MOD_RM(0b00, regop, 0b101)); - placeImm32(-int32_t(getCodeSize() + 4) + rhs.imm); + + // As a reminder: we do (getCodeSize() + 4) here to calculate the offset of the end of the current instruction we are placing. + // Since we have already placed all of the instruction bytes for this instruction, we add +4 to account for the imm32 displacement. + // Some instructions, however, are encoded such that an additional imm8 byte, or imm32 bytes, is placed after the ModRM byte, thus, + // we need to account for that case here as well. + placeImm32(-int32_t(getCodeSize() + 4 + extraCodeBytes) + rhs.imm); } else if (base != noreg) { diff --git a/CodeGen/src/CodeGen.cpp b/CodeGen/src/CodeGen.cpp index 1c05b2986..72c1294f4 100644 --- a/CodeGen/src/CodeGen.cpp +++ b/CodeGen/src/CodeGen.cpp @@ -91,6 +91,9 @@ static int emitInst(AssemblyBuilderX64& build, NativeState& data, ModuleHelpers& case LOP_SETGLOBAL: emitInstSetGlobal(build, pc, i, next, fallback); break; + case LOP_NAMECALL: + emitInstNameCall(build, pc, i, proto->k, next, fallback); + break; case LOP_CALL: emitInstCall(build, helpers, pc, i); break; @@ -270,6 +273,9 @@ static int emitInst(AssemblyBuilderX64& build, NativeState& data, ModuleHelpers& case LOP_CONCAT: emitInstConcat(build, pc, i, next); break; + case LOP_COVERAGE: + emitInstCoverage(build, i); + break; default: emitFallback(build, data, op, i); break; @@ -298,6 +304,10 @@ static void emitInstFallback(AssemblyBuilderX64& build, NativeState& data, LuauO case LOP_SETTABLEN: emitInstSetTableNFallback(build, pc, i); break; + case LOP_NAMECALL: + // TODO: fast-paths that we've handled can be removed from the fallback + emitFallback(build, data, op, i); + break; case LOP_JUMPIFEQ: emitInstJumpIfEqFallback(build, pc, i, labelarr, /* not_ */ false); break; diff --git a/CodeGen/src/EmitCommonX64.cpp b/CodeGen/src/EmitCommonX64.cpp index fe258ff83..2c410ae87 100644 --- a/CodeGen/src/EmitCommonX64.cpp +++ b/CodeGen/src/EmitCommonX64.cpp @@ -239,6 +239,22 @@ void callCheckGc(AssemblyBuilderX64& build, int pcpos, bool savepc, Label& skip) emitUpdateBase(build); } +void callGetFastTmOrFallback(AssemblyBuilderX64& build, RegisterX64 table, TMS tm, Label& fallback) +{ + build.mov(rArg1, qword[table + offsetof(Table, metatable)]); + build.test(rArg1, rArg1); + build.jcc(ConditionX64::Zero, fallback); // no metatable + + build.test(byte[rArg1 + offsetof(Table, tmcache)], 1 << tm); + build.jcc(ConditionX64::NotZero, fallback); // no tag method + + // rArg1 is already prepared + build.mov(rArg2, tm); + build.mov(rax, qword[rState + offsetof(lua_State, global)]); + build.mov(rArg3, qword[rax + offsetof(global_State, tmname[tm])]); + build.call(qword[rNativeContext + offsetof(NativeContext, luaT_gettm)]); +} + void emitExit(AssemblyBuilderX64& build, bool continueInVm) { if (continueInVm) diff --git a/CodeGen/src/EmitCommonX64.h b/CodeGen/src/EmitCommonX64.h index 238a0ed42..e475da606 100644 --- a/CodeGen/src/EmitCommonX64.h +++ b/CodeGen/src/EmitCommonX64.h @@ -67,8 +67,10 @@ constexpr OperandX64 sArg6 = noreg; constexpr unsigned kTValueSizeLog2 = 4; constexpr unsigned kLuaNodeSizeLog2 = 5; constexpr unsigned kLuaNodeTagMask = 0xf; +constexpr unsigned kNextBitOffset = 4; constexpr unsigned kOffsetOfLuaNodeTag = 12; // offsetof cannot be used on a bit field +constexpr unsigned kOffsetOfLuaNodeNext = 12; // offsetof cannot be used on a bit field constexpr unsigned kOffsetOfInstructionC = 3; // Leaf functions that are placed in every module to perform common instruction sequences @@ -168,6 +170,12 @@ inline void jumpIfTagIsNot(AssemblyBuilderX64& build, int ri, lua_Type tag, Labe build.jcc(ConditionX64::NotEqual, label); } +inline void jumpIfTagIsNot(AssemblyBuilderX64& build, RegisterX64 reg, lua_Type tag, Label& label) +{ + build.cmp(dword[reg + offsetof(TValue, tt)], tag); + build.jcc(ConditionX64::NotEqual, label); +} + // Note: fallthrough label should be placed after this condition inline void jumpIfFalsy(AssemblyBuilderX64& build, int ri, Label& target, Label& fallthrough) { @@ -224,6 +232,13 @@ inline void jumpIfNodeValueTagIs(AssemblyBuilderX64& build, RegisterX64 node, lu build.jcc(ConditionX64::Equal, label); } +inline void jumpIfNodeHasNext(AssemblyBuilderX64& build, RegisterX64 node, Label& label) +{ + build.mov(ecx, dword[node + offsetof(LuaNode, key) + kOffsetOfLuaNodeNext]); + build.shr(ecx, kNextBitOffset); + build.jcc(ConditionX64::NotZero, label); +} + inline void jumpIfNodeKeyNotInExpectedSlot(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 node, OperandX64 expectedKey, Label& label) { jumpIfNodeKeyTagIsNot(build, tmp, node, LUA_TSTRING, label); @@ -250,6 +265,7 @@ void callBarrierTable(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 ta void callBarrierObject(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 object, int ra, Label& skip); void callBarrierTableFast(AssemblyBuilderX64& build, RegisterX64 table, Label& skip); void callCheckGc(AssemblyBuilderX64& build, int pcpos, bool savepc, Label& skip); +void callGetFastTmOrFallback(AssemblyBuilderX64& build, RegisterX64 table, TMS tm, Label& fallback); void emitExit(AssemblyBuilderX64& build, bool continueInVm); void emitUpdateBase(AssemblyBuilderX64& build); @@ -258,7 +274,6 @@ void emitInterrupt(AssemblyBuilderX64& build, int pcpos); void emitFallback(AssemblyBuilderX64& build, NativeState& data, int op, int pcpos); void emitContinueCallInVm(AssemblyBuilderX64& build); -void emitExitFromLastReturn(AssemblyBuilderX64& build); } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/EmitInstructionX64.cpp b/CodeGen/src/EmitInstructionX64.cpp index abbdb65ca..7b6e1c643 100644 --- a/CodeGen/src/EmitInstructionX64.cpp +++ b/CodeGen/src/EmitInstructionX64.cpp @@ -69,6 +69,51 @@ void emitInstMove(AssemblyBuilderX64& build, const Instruction* pc) build.vmovups(luauReg(ra), xmm0); } +void emitInstNameCall(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, const TValue* k, Label& next, Label& fallback) +{ + int ra = LUAU_INSN_A(*pc); + int rb = LUAU_INSN_B(*pc); + uint32_t aux = pc[1]; + + Label secondfpath; + + jumpIfTagIsNot(build, rb, LUA_TTABLE, fallback); + + RegisterX64 table = r8; + build.mov(table, luauRegValue(rb)); + + // &h->node[tsvalue(kv)->hash & (sizenode(h) - 1)]; + RegisterX64 node = rdx; + build.mov(node, qword[table + offsetof(Table, node)]); + build.mov(eax, 1); + build.mov(cl, byte[table + offsetof(Table, lsizenode)]); + build.shl(eax, cl); + build.dec(eax); + build.and_(eax, tsvalue(&k[aux])->hash); + build.shl(rax, kLuaNodeSizeLog2); + build.add(node, rax); + + jumpIfNodeKeyNotInExpectedSlot(build, rax, node, luauConstantValue(aux), secondfpath); + + setLuauReg(build, xmm0, ra + 1, luauReg(rb)); + setLuauReg(build, xmm0, ra, luauNodeValue(node)); + build.jmp(next); + + build.setLabel(secondfpath); + + jumpIfNodeHasNext(build, node, fallback); + callGetFastTmOrFallback(build, table, TM_INDEX, fallback); + jumpIfTagIsNot(build, rax, LUA_TTABLE, fallback); + + build.mov(table, qword[rax + offsetof(TValue, value)]); + + getTableNodeAtCachedSlot(build, rax, node, table, pcpos); + jumpIfNodeKeyNotInExpectedSlot(build, rax, node, luauConstantValue(aux), fallback); + + setLuauReg(build, xmm0, ra + 1, luauReg(rb)); + setLuauReg(build, xmm0, ra, luauNodeValue(node)); +} + void emitInstCall(AssemblyBuilderX64& build, ModuleHelpers& helpers, const Instruction* pc, int pcpos) { int ra = LUAU_INSN_A(*pc); @@ -1627,5 +1672,28 @@ void emitInstConcat(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, callCheckGc(build, pcpos, /* savepc= */ false, next); } +void emitInstCoverage(AssemblyBuilderX64& build, int pcpos) +{ + build.mov(rcx, sCode); + build.add(rcx, pcpos * sizeof(Instruction)); + + // hits = LUAU_INSN_E(*pc) + build.mov(edx, dword[rcx]); + build.sar(edx, 8); + + // hits = (hits < (1 << 23) - 1) ? hits + 1 : hits; + build.xor_(eax, eax); + build.cmp(edx, (1 << 23) - 1); + build.setcc(ConditionX64::NotEqual, al); + build.add(edx, eax); + + + // VM_PATCH_E(pc, hits); + build.sal(edx, 8); + build.movzx(eax, byte[rcx]); + build.or_(eax, edx); + build.mov(dword[rcx], eax); +} + } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/EmitInstructionX64.h b/CodeGen/src/EmitInstructionX64.h index 1ecb06d4f..96501e63d 100644 --- a/CodeGen/src/EmitInstructionX64.h +++ b/CodeGen/src/EmitInstructionX64.h @@ -17,6 +17,7 @@ class AssemblyBuilderX64; enum class ConditionX64 : uint8_t; struct Label; struct ModuleHelpers; +struct NativeState; void emitInstLoadNil(AssemblyBuilderX64& build, const Instruction* pc); void emitInstLoadB(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr); @@ -24,6 +25,7 @@ void emitInstLoadN(AssemblyBuilderX64& build, const Instruction* pc); void emitInstLoadK(AssemblyBuilderX64& build, const Instruction* pc); void emitInstLoadKX(AssemblyBuilderX64& build, const Instruction* pc); void emitInstMove(AssemblyBuilderX64& build, const Instruction* pc); +void emitInstNameCall(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, const TValue* k, Label& next, Label& fallback); void emitInstCall(AssemblyBuilderX64& build, ModuleHelpers& helpers, const Instruction* pc, int pcpos); void emitInstReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers, const Instruction* pc, int pcpos); void emitInstJump(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr); @@ -84,6 +86,7 @@ void emitInstSetTableKS(AssemblyBuilderX64& build, const Instruction* pc, int pc void emitInstGetGlobal(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback); void emitInstSetGlobal(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& next, Label& fallback); void emitInstConcat(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& next); +void emitInstCoverage(AssemblyBuilderX64& build, int pcpos); } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/Fallbacks.cpp b/CodeGen/src/Fallbacks.cpp index 41f2bc8c8..e84ee2136 100644 --- a/CodeGen/src/Fallbacks.cpp +++ b/CodeGen/src/Fallbacks.cpp @@ -594,19 +594,6 @@ const Instruction* execute_LOP_PREPVARARGS(lua_State* L, const Instruction* pc, return pc; } -const Instruction* execute_LOP_COVERAGE(lua_State* L, const Instruction* pc, StkId base, TValue* k) -{ - [[maybe_unused]] Closure* cl = clvalue(L->ci->func); - Instruction insn = *pc++; - int hits = LUAU_INSN_E(insn); - - // update hits with saturated add and patch the instruction in place - hits = (hits < (1 << 23) - 1) ? hits + 1 : hits; - VM_PATCH_E(pc - 1, hits); - - return pc; -} - const Instruction* execute_LOP_BREAK(lua_State* L, const Instruction* pc, StkId base, TValue* k) { LUAU_ASSERT(!"Unsupported deprecated opcode"); diff --git a/CodeGen/src/Fallbacks.h b/CodeGen/src/Fallbacks.h index 72573b15a..bfc0e2b7c 100644 --- a/CodeGen/src/Fallbacks.h +++ b/CodeGen/src/Fallbacks.h @@ -20,5 +20,4 @@ const Instruction* execute_LOP_FORGPREP(lua_State* L, const Instruction* pc, Stk const Instruction* execute_LOP_GETVARARGS(lua_State* L, const Instruction* pc, StkId base, TValue* k); const Instruction* execute_LOP_DUPCLOSURE(lua_State* L, const Instruction* pc, StkId base, TValue* k); const Instruction* execute_LOP_PREPVARARGS(lua_State* L, const Instruction* pc, StkId base, TValue* k); -const Instruction* execute_LOP_COVERAGE(lua_State* L, const Instruction* pc, StkId base, TValue* k); const Instruction* execute_LOP_BREAK(lua_State* L, const Instruction* pc, StkId base, TValue* k); diff --git a/CodeGen/src/IrData.h b/CodeGen/src/IrData.h new file mode 100644 index 000000000..c4ed47ccb --- /dev/null +++ b/CodeGen/src/IrData.h @@ -0,0 +1,280 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Label.h" +#include "Luau/RegisterX64.h" +#include "Luau/RegisterA64.h" + +#include + +#include + +namespace Luau +{ +namespace CodeGen +{ + +enum class IrCmd : uint8_t +{ + NOP, + + LOAD_TAG, + LOAD_POINTER, + LOAD_DOUBLE, + LOAD_INT, + LOAD_TVALUE, + LOAD_NODE_VALUE_TV, // TODO: we should find a way to generalize LOAD_TVALUE + LOAD_ENV, + + GET_ARR_ADDR, + GET_SLOT_NODE_ADDR, + + STORE_TAG, + STORE_POINTER, + STORE_DOUBLE, + STORE_INT, + STORE_TVALUE, + STORE_NODE_VALUE_TV, // TODO: we should find a way to generalize STORE_TVALUE + + ADD_INT, + SUB_INT, + + ADD_NUM, + SUB_NUM, + MUL_NUM, + DIV_NUM, + MOD_NUM, + POW_NUM, + + UNM_NUM, + + NOT_ANY, // TODO: boolean specialization will be useful + + JUMP, + JUMP_IF_TRUTHY, + JUMP_IF_FALSY, + JUMP_EQ_TAG, + JUMP_EQ_BOOLEAN, + JUMP_EQ_POINTER, + + JUMP_CMP_NUM, + JUMP_CMP_STR, + JUMP_CMP_ANY, + + TABLE_LEN, + NEW_TABLE, + DUP_TABLE, + + NUM_TO_INDEX, + + // Fallback functions + DO_ARITH, + DO_LEN, + GET_TABLE, + SET_TABLE, + GET_IMPORT, + CONCAT, + GET_UPVALUE, + SET_UPVALUE, + + // Guards and checks + CHECK_TAG, + CHECK_READONLY, + CHECK_NO_METATABLE, + CHECK_SAFE_ENV, + CHECK_ARRAY_SIZE, + CHECK_SLOT_MATCH, + + // Special operations + INTERRUPT, + CHECK_GC, + BARRIER_OBJ, + BARRIER_TABLE_BACK, + BARRIER_TABLE_FORWARD, + SET_SAVEDPC, + CLOSE_UPVALS, + + // While capture is a no-op right now, it might be useful to track register/upvalue lifetimes + CAPTURE, + + // Operations that don't have an IR representation yet + LOP_SETLIST, + LOP_CALL, + LOP_RETURN, + LOP_FASTCALL, + LOP_FASTCALL1, + LOP_FASTCALL2, + LOP_FASTCALL2K, + LOP_FORNPREP, + LOP_FORNLOOP, + LOP_FORGLOOP, + LOP_FORGLOOP_FALLBACK, + LOP_FORGPREP_NEXT, + LOP_FORGPREP_INEXT, + LOP_FORGPREP_XNEXT_FALLBACK, + LOP_AND, + LOP_ANDK, + LOP_OR, + LOP_ORK, + + // Operations that have a translation, but use a full instruction fallback + FALLBACK_GETGLOBAL, + FALLBACK_SETGLOBAL, + FALLBACK_GETTABLEKS, + FALLBACK_SETTABLEKS, + + // Operations that don't have assembly lowering at all + FALLBACK_NAMECALL, + FALLBACK_PREPVARARGS, + FALLBACK_GETVARARGS, + FALLBACK_NEWCLOSURE, + FALLBACK_DUPCLOSURE, + FALLBACK_FORGPREP, + FALLBACK_COVERAGE, +}; + +enum class IrConstKind : uint8_t +{ + Bool, + Int, + Uint, + Double, + Tag, +}; + +struct IrConst +{ + IrConstKind kind; + + union + { + bool valueBool; + int valueInt; + unsigned valueUint; + double valueDouble; + uint8_t valueTag; + }; +}; + +enum class IrCondition : uint8_t +{ + Equal, + NotEqual, + Less, + NotLess, + LessEqual, + NotLessEqual, + Greater, + NotGreater, + GreaterEqual, + NotGreaterEqual, + + UnsignedLess, + UnsignedLessEqual, + UnsignedGreater, + UnsignedGreaterEqual, + + Count +}; + +enum class IrOpKind : uint32_t +{ + None, + + // To reference a constant value + Constant, + + // To specify a condition code + Condition, + + // To reference a result of a previous instruction + Inst, + + // To reference a basic block in control flow + Block, + + // To reference a VM register + VmReg, + + // To reference a VM constant + VmConst, + + // To reference a VM upvalue + VmUpvalue, +}; + +struct IrOp +{ + IrOpKind kind : 4; + uint32_t index : 28; + + IrOp() + : kind(IrOpKind::None) + , index(0) + { + } + + IrOp(IrOpKind kind, uint32_t index) + : kind(kind) + , index(index) + { + } +}; + +static_assert(sizeof(IrOp) == 4); + +struct IrInst +{ + IrCmd cmd; + + // Operands + IrOp a; + IrOp b; + IrOp c; + IrOp d; + IrOp e; + + uint32_t lastUse = 0; + uint16_t useCount = 0; + + // Location of the result (optional) + RegisterX64 regX64 = noreg; + RegisterA64 regA64{KindA64::none, 0}; + bool reusedReg = false; +}; + +enum class IrBlockKind : uint8_t +{ + Bytecode, + Fallback, + Internal, +}; + +struct IrBlock +{ + IrBlockKind kind; + + // Start points to an instruction index in a stream + // End is implicit + uint32_t start; + + Label label; +}; + +struct BytecodeMapping +{ + uint32_t irLocation; + uint32_t asmLocation; +}; + +struct IrFunction +{ + std::vector blocks; + std::vector instructions; + std::vector constants; + + std::vector bcMapping; +}; + +} // namespace CodeGen +} // namespace Luau diff --git a/CodeGen/src/IrDump.cpp b/CodeGen/src/IrDump.cpp new file mode 100644 index 000000000..5d54026a1 --- /dev/null +++ b/CodeGen/src/IrDump.cpp @@ -0,0 +1,379 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "IrDump.h" + +#include "IrUtils.h" + +#include "lua.h" + +#include + +namespace Luau +{ +namespace CodeGen +{ + +static const char* textForCondition[] = { + "eq", "not_eq", "lt", "not_lt", "le", "not_le", "gt", "not_gt", "ge", "not_ge", "u_lt", "u_le", "u_gt", "u_ge"}; +static_assert(sizeof(textForCondition) / sizeof(textForCondition[0]) == size_t(IrCondition::Count), "all conditions have to be covered"); + +const int kDetailsAlignColumn = 60; + +LUAU_PRINTF_ATTR(2, 3) +static void append(std::string& result, const char* fmt, ...) +{ + char buf[256]; + va_list args; + va_start(args, fmt); + vsnprintf(buf, sizeof(buf), fmt, args); + va_end(args); + result.append(buf); +} + +static const char* getTagName(uint8_t tag) +{ + switch (tag) + { + case LUA_TNIL: + return "tnil"; + case LUA_TBOOLEAN: + return "tboolean"; + case LUA_TLIGHTUSERDATA: + return "tlightuserdata"; + case LUA_TNUMBER: + return "tnumber"; + case LUA_TVECTOR: + return "tvector"; + case LUA_TSTRING: + return "tstring"; + case LUA_TTABLE: + return "ttable"; + case LUA_TFUNCTION: + return "tfunction"; + case LUA_TUSERDATA: + return "tuserdata"; + case LUA_TTHREAD: + return "tthread"; + default: + LUAU_UNREACHABLE(); + } +} + +const char* getCmdName(IrCmd cmd) +{ + switch (cmd) + { + case IrCmd::NOP: + return "NOP"; + case IrCmd::LOAD_TAG: + return "LOAD_TAG"; + case IrCmd::LOAD_POINTER: + return "LOAD_POINTER"; + case IrCmd::LOAD_DOUBLE: + return "LOAD_DOUBLE"; + case IrCmd::LOAD_INT: + return "LOAD_INT"; + case IrCmd::LOAD_TVALUE: + return "LOAD_TVALUE"; + case IrCmd::LOAD_NODE_VALUE_TV: + return "LOAD_NODE_VALUE_TV"; + case IrCmd::LOAD_ENV: + return "LOAD_ENV"; + case IrCmd::GET_ARR_ADDR: + return "GET_ARR_ADDR"; + case IrCmd::GET_SLOT_NODE_ADDR: + return "GET_SLOT_NODE_ADDR"; + case IrCmd::STORE_TAG: + return "STORE_TAG"; + case IrCmd::STORE_POINTER: + return "STORE_POINTER"; + case IrCmd::STORE_DOUBLE: + return "STORE_DOUBLE"; + case IrCmd::STORE_INT: + return "STORE_INT"; + case IrCmd::STORE_TVALUE: + return "STORE_TVALUE"; + case IrCmd::STORE_NODE_VALUE_TV: + return "STORE_NODE_VALUE_TV"; + case IrCmd::ADD_INT: + return "ADD_INT"; + case IrCmd::SUB_INT: + return "SUB_INT"; + case IrCmd::ADD_NUM: + return "ADD_NUM"; + case IrCmd::SUB_NUM: + return "SUB_NUM"; + case IrCmd::MUL_NUM: + return "MUL_NUM"; + case IrCmd::DIV_NUM: + return "DIV_NUM"; + case IrCmd::MOD_NUM: + return "MOD_NUM"; + case IrCmd::POW_NUM: + return "POW_NUM"; + case IrCmd::UNM_NUM: + return "UNM_NUM"; + case IrCmd::NOT_ANY: + return "NOT_ANY"; + case IrCmd::JUMP: + return "JUMP"; + case IrCmd::JUMP_IF_TRUTHY: + return "JUMP_IF_TRUTHY"; + case IrCmd::JUMP_IF_FALSY: + return "JUMP_IF_FALSY"; + case IrCmd::JUMP_EQ_TAG: + return "JUMP_EQ_TAG"; + case IrCmd::JUMP_EQ_BOOLEAN: + return "JUMP_EQ_BOOLEAN"; + case IrCmd::JUMP_EQ_POINTER: + return "JUMP_EQ_POINTER"; + case IrCmd::JUMP_CMP_NUM: + return "JUMP_CMP_NUM"; + case IrCmd::JUMP_CMP_STR: + return "JUMP_CMP_STR"; + case IrCmd::JUMP_CMP_ANY: + return "JUMP_CMP_ANY"; + case IrCmd::TABLE_LEN: + return "TABLE_LEN"; + case IrCmd::NEW_TABLE: + return "NEW_TABLE"; + case IrCmd::DUP_TABLE: + return "DUP_TABLE"; + case IrCmd::NUM_TO_INDEX: + return "NUM_TO_INDEX"; + case IrCmd::DO_ARITH: + return "DO_ARITH"; + case IrCmd::DO_LEN: + return "DO_LEN"; + case IrCmd::GET_TABLE: + return "GET_TABLE"; + case IrCmd::SET_TABLE: + return "SET_TABLE"; + case IrCmd::GET_IMPORT: + return "GET_IMPORT"; + case IrCmd::CONCAT: + return "CONCAT"; + case IrCmd::GET_UPVALUE: + return "GET_UPVALUE"; + case IrCmd::SET_UPVALUE: + return "SET_UPVALUE"; + case IrCmd::CHECK_TAG: + return "CHECK_TAG"; + case IrCmd::CHECK_READONLY: + return "CHECK_READONLY"; + case IrCmd::CHECK_NO_METATABLE: + return "CHECK_NO_METATABLE"; + case IrCmd::CHECK_SAFE_ENV: + return "CHECK_SAFE_ENV"; + case IrCmd::CHECK_ARRAY_SIZE: + return "CHECK_ARRAY_SIZE"; + case IrCmd::CHECK_SLOT_MATCH: + return "CHECK_SLOT_MATCH"; + case IrCmd::INTERRUPT: + return "INTERRUPT"; + case IrCmd::CHECK_GC: + return "CHECK_GC"; + case IrCmd::BARRIER_OBJ: + return "BARRIER_OBJ"; + case IrCmd::BARRIER_TABLE_BACK: + return "BARRIER_TABLE_BACK"; + case IrCmd::BARRIER_TABLE_FORWARD: + return "BARRIER_TABLE_FORWARD"; + case IrCmd::SET_SAVEDPC: + return "SET_SAVEDPC"; + case IrCmd::CLOSE_UPVALS: + return "CLOSE_UPVALS"; + case IrCmd::CAPTURE: + return "CAPTURE"; + case IrCmd::LOP_SETLIST: + return "LOP_SETLIST"; + case IrCmd::LOP_CALL: + return "LOP_CALL"; + case IrCmd::LOP_RETURN: + return "LOP_RETURN"; + case IrCmd::LOP_FASTCALL: + return "LOP_FASTCALL"; + case IrCmd::LOP_FASTCALL1: + return "LOP_FASTCALL1"; + case IrCmd::LOP_FASTCALL2: + return "LOP_FASTCALL2"; + case IrCmd::LOP_FASTCALL2K: + return "LOP_FASTCALL2K"; + case IrCmd::LOP_FORNPREP: + return "LOP_FORNPREP"; + case IrCmd::LOP_FORNLOOP: + return "LOP_FORNLOOP"; + case IrCmd::LOP_FORGLOOP: + return "LOP_FORGLOOP"; + case IrCmd::LOP_FORGLOOP_FALLBACK: + return "LOP_FORGLOOP_FALLBACK"; + case IrCmd::LOP_FORGPREP_NEXT: + return "LOP_FORGPREP_NEXT"; + case IrCmd::LOP_FORGPREP_INEXT: + return "LOP_FORGPREP_INEXT"; + case IrCmd::LOP_FORGPREP_XNEXT_FALLBACK: + return "LOP_FORGPREP_XNEXT_FALLBACK"; + case IrCmd::LOP_AND: + return "LOP_AND"; + case IrCmd::LOP_ANDK: + return "LOP_ANDK"; + case IrCmd::LOP_OR: + return "LOP_OR"; + case IrCmd::LOP_ORK: + return "LOP_ORK"; + case IrCmd::FALLBACK_GETGLOBAL: + return "FALLBACK_GETGLOBAL"; + case IrCmd::FALLBACK_SETGLOBAL: + return "FALLBACK_SETGLOBAL"; + case IrCmd::FALLBACK_GETTABLEKS: + return "FALLBACK_GETTABLEKS"; + case IrCmd::FALLBACK_SETTABLEKS: + return "FALLBACK_SETTABLEKS"; + case IrCmd::FALLBACK_NAMECALL: + return "FALLBACK_NAMECALL"; + case IrCmd::FALLBACK_PREPVARARGS: + return "FALLBACK_PREPVARARGS"; + case IrCmd::FALLBACK_GETVARARGS: + return "FALLBACK_GETVARARGS"; + case IrCmd::FALLBACK_NEWCLOSURE: + return "FALLBACK_NEWCLOSURE"; + case IrCmd::FALLBACK_DUPCLOSURE: + return "FALLBACK_DUPCLOSURE"; + case IrCmd::FALLBACK_FORGPREP: + return "FALLBACK_FORGPREP"; + case IrCmd::FALLBACK_COVERAGE: + return "FALLBACK_COVERAGE"; + } + + LUAU_UNREACHABLE(); +} + +const char* getBlockKindName(IrBlockKind kind) +{ + switch (kind) + { + case IrBlockKind::Bytecode: + return "bb_bytecode"; + case IrBlockKind::Fallback: + return "bb_fallback"; + case IrBlockKind::Internal: + return "bb"; + } + + LUAU_UNREACHABLE(); +} + +void toString(IrToStringContext& ctx, IrInst inst, uint32_t index) +{ + append(ctx.result, " "); + + // Instructions with a result display target virtual register + if (hasResult(inst.cmd)) + append(ctx.result, "%%%u = ", index); + + ctx.result.append(getCmdName(inst.cmd)); + + if (inst.a.kind != IrOpKind::None) + { + append(ctx.result, " "); + toString(ctx, inst.a); + } + + if (inst.b.kind != IrOpKind::None) + { + append(ctx.result, ", "); + toString(ctx, inst.b); + } + + if (inst.c.kind != IrOpKind::None) + { + append(ctx.result, ", "); + toString(ctx, inst.c); + } + + if (inst.d.kind != IrOpKind::None) + { + append(ctx.result, ", "); + toString(ctx, inst.d); + } + + if (inst.e.kind != IrOpKind::None) + { + append(ctx.result, ", "); + toString(ctx, inst.e); + } +} + +void toString(IrToStringContext& ctx, IrOp op) +{ + switch (op.kind) + { + case IrOpKind::None: + break; + case IrOpKind::Constant: + toString(ctx.result, ctx.constants[op.index]); + break; + case IrOpKind::Condition: + LUAU_ASSERT(op.index < uint32_t(IrCondition::Count)); + ctx.result.append(textForCondition[op.index]); + break; + case IrOpKind::Inst: + append(ctx.result, "%%%u", op.index); + break; + case IrOpKind::Block: + append(ctx.result, "%s_%u", getBlockKindName(ctx.blocks[op.index].kind), op.index); + break; + case IrOpKind::VmReg: + append(ctx.result, "R%u", op.index); + break; + case IrOpKind::VmConst: + append(ctx.result, "K%u", op.index); + break; + case IrOpKind::VmUpvalue: + append(ctx.result, "U%u", op.index); + break; + } +} + +void toString(std::string& result, IrConst constant) +{ + switch (constant.kind) + { + case IrConstKind::Bool: + append(result, constant.valueBool ? "true" : "false"); + break; + case IrConstKind::Int: + append(result, "%di", constant.valueInt); + break; + case IrConstKind::Uint: + append(result, "%uu", constant.valueUint); + break; + case IrConstKind::Double: + append(result, "%.17g", constant.valueDouble); + break; + case IrConstKind::Tag: + result.append(getTagName(constant.valueTag)); + break; + } +} + +void toStringDetailed(IrToStringContext& ctx, IrInst inst, uint32_t index) +{ + size_t start = ctx.result.size(); + + toString(ctx, inst, index); + + int pad = kDetailsAlignColumn - int(ctx.result.size() - start); + + if (pad > 0) + ctx.result.append(pad, ' '); + + LUAU_ASSERT(inst.useCount == 0 || inst.lastUse != 0); + + if (inst.useCount == 0 && hasSideEffects(inst.cmd)) + append(ctx.result, "; %%%u, has side-effects\n", index); + else + append(ctx.result, "; useCount: %d, lastUse: %%%u\n", inst.useCount, inst.lastUse); +} + +} // namespace CodeGen +} // namespace Luau diff --git a/CodeGen/src/IrDump.h b/CodeGen/src/IrDump.h new file mode 100644 index 000000000..8fb4d6e5f --- /dev/null +++ b/CodeGen/src/IrDump.h @@ -0,0 +1,32 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "IrData.h" + +#include +#include + +namespace Luau +{ +namespace CodeGen +{ + +const char* getCmdName(IrCmd cmd); +const char* getBlockKindName(IrBlockKind kind); + +struct IrToStringContext +{ + std::string& result; + std::vector& blocks; + std::vector& constants; +}; + +void toString(IrToStringContext& ctx, IrInst inst, uint32_t index); +void toString(IrToStringContext& ctx, IrOp op); + +void toString(std::string& result, IrConst constant); + +void toStringDetailed(IrToStringContext& ctx, IrInst inst, uint32_t index); + +} // namespace CodeGen +} // namespace Luau diff --git a/CodeGen/src/IrUtils.h b/CodeGen/src/IrUtils.h new file mode 100644 index 000000000..f0e4cee6c --- /dev/null +++ b/CodeGen/src/IrUtils.h @@ -0,0 +1,161 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Bytecode.h" +#include "Luau/Common.h" + +#include "IrData.h" + +namespace Luau +{ +namespace CodeGen +{ + +inline bool isJumpD(LuauOpcode op) +{ + switch (op) + { + case LOP_JUMP: + case LOP_JUMPIF: + case LOP_JUMPIFNOT: + case LOP_JUMPIFEQ: + case LOP_JUMPIFLE: + case LOP_JUMPIFLT: + case LOP_JUMPIFNOTEQ: + case LOP_JUMPIFNOTLE: + case LOP_JUMPIFNOTLT: + case LOP_FORNPREP: + case LOP_FORNLOOP: + case LOP_FORGPREP: + case LOP_FORGLOOP: + case LOP_FORGPREP_INEXT: + case LOP_FORGPREP_NEXT: + case LOP_JUMPBACK: + case LOP_JUMPXEQKNIL: + case LOP_JUMPXEQKB: + case LOP_JUMPXEQKN: + case LOP_JUMPXEQKS: + return true; + + default: + return false; + } +} + +inline bool isSkipC(LuauOpcode op) +{ + switch (op) + { + case LOP_LOADB: + return true; + + default: + return false; + } +} + +inline bool isFastCall(LuauOpcode op) +{ + switch (op) + { + case LOP_FASTCALL: + case LOP_FASTCALL1: + case LOP_FASTCALL2: + case LOP_FASTCALL2K: + return true; + + default: + return false; + } +} + +inline int getJumpTarget(uint32_t insn, uint32_t pc) +{ + LuauOpcode op = LuauOpcode(LUAU_INSN_OP(insn)); + + if (isJumpD(op)) + return int(pc + LUAU_INSN_D(insn) + 1); + else if (isFastCall(op)) + return int(pc + LUAU_INSN_C(insn) + 2); + else if (isSkipC(op) && LUAU_INSN_C(insn)) + return int(pc + LUAU_INSN_C(insn) + 1); + else if (op == LOP_JUMPX) + return int(pc + LUAU_INSN_E(insn) + 1); + else + return -1; +} + +inline bool isBlockTerminator(IrCmd cmd) +{ + switch (cmd) + { + case IrCmd::JUMP: + case IrCmd::JUMP_IF_TRUTHY: + case IrCmd::JUMP_IF_FALSY: + case IrCmd::JUMP_EQ_TAG: + case IrCmd::JUMP_EQ_BOOLEAN: + case IrCmd::JUMP_EQ_POINTER: + case IrCmd::JUMP_CMP_NUM: + case IrCmd::JUMP_CMP_STR: + case IrCmd::JUMP_CMP_ANY: + case IrCmd::LOP_RETURN: + case IrCmd::LOP_FORNPREP: + case IrCmd::LOP_FORNLOOP: + case IrCmd::LOP_FORGLOOP: + case IrCmd::LOP_FORGLOOP_FALLBACK: + case IrCmd::LOP_FORGPREP_NEXT: + case IrCmd::LOP_FORGPREP_INEXT: + case IrCmd::LOP_FORGPREP_XNEXT_FALLBACK: + case IrCmd::FALLBACK_FORGPREP: + return true; + default: + break; + } + + return false; +} + +inline bool hasResult(IrCmd cmd) +{ + switch (cmd) + { + case IrCmd::LOAD_TAG: + case IrCmd::LOAD_POINTER: + case IrCmd::LOAD_DOUBLE: + case IrCmd::LOAD_INT: + case IrCmd::LOAD_TVALUE: + case IrCmd::LOAD_NODE_VALUE_TV: + case IrCmd::LOAD_ENV: + case IrCmd::GET_ARR_ADDR: + case IrCmd::GET_SLOT_NODE_ADDR: + case IrCmd::ADD_INT: + case IrCmd::SUB_INT: + case IrCmd::ADD_NUM: + case IrCmd::SUB_NUM: + case IrCmd::MUL_NUM: + case IrCmd::DIV_NUM: + case IrCmd::MOD_NUM: + case IrCmd::POW_NUM: + case IrCmd::UNM_NUM: + case IrCmd::NOT_ANY: + case IrCmd::TABLE_LEN: + case IrCmd::NEW_TABLE: + case IrCmd::DUP_TABLE: + case IrCmd::NUM_TO_INDEX: + return true; + default: + break; + } + + return false; +} + +inline bool hasSideEffects(IrCmd cmd) +{ + // Instructions that don't produce a result most likely have other side-effects to make them useful + // Right now, a full switch would mirror the 'hasResult' function, so we use this simple condition + return !hasResult(cmd); +} + +} // namespace CodeGen +} // namespace Luau diff --git a/CodeGen/src/NativeState.cpp b/CodeGen/src/NativeState.cpp index 62974fe3e..22d38aa6e 100644 --- a/CodeGen/src/NativeState.cpp +++ b/CodeGen/src/NativeState.cpp @@ -42,7 +42,6 @@ void initFallbackTable(NativeState& data) CODEGEN_SET_FALLBACK(LOP_GETVARARGS, 0); CODEGEN_SET_FALLBACK(LOP_DUPCLOSURE, 0); CODEGEN_SET_FALLBACK(LOP_PREPVARARGS, 0); - CODEGEN_SET_FALLBACK(LOP_COVERAGE, 0); CODEGEN_SET_FALLBACK(LOP_BREAK, 0); // Fallbacks that are called from partial implementation of an instruction @@ -80,6 +79,8 @@ void initHelperFunctions(NativeState& data) data.context.luaF_close = luaF_close; + data.context.luaT_gettm = luaT_gettm; + data.context.libm_exp = exp; data.context.libm_pow = pow; data.context.libm_fmod = fmod; diff --git a/CodeGen/src/NativeState.h b/CodeGen/src/NativeState.h index 9138ba472..03d8ae3a5 100644 --- a/CodeGen/src/NativeState.h +++ b/CodeGen/src/NativeState.h @@ -78,6 +78,8 @@ struct NativeContext void (*luaF_close)(lua_State* L, StkId level) = nullptr; + const TValue* (*luaT_gettm)(Table* events, TMS event, TString* ename) = nullptr; + double (*libm_exp)(double) = nullptr; double (*libm_pow)(double, double) = nullptr; double (*libm_fmod)(double, double) = nullptr; diff --git a/Compiler/include/Luau/BytecodeBuilder.h b/Compiler/include/Luau/BytecodeBuilder.h index 0d4543f77..047c1b67f 100644 --- a/Compiler/include/Luau/BytecodeBuilder.h +++ b/Compiler/include/Luau/BytecodeBuilder.h @@ -123,6 +123,8 @@ class BytecodeBuilder static uint32_t getImportId(int32_t id0, int32_t id1); static uint32_t getImportId(int32_t id0, int32_t id1, int32_t id2); + static int decomposeImportId(uint32_t ids, int32_t& id0, int32_t& id1, int32_t& id2); + static uint32_t getStringHash(StringRef key); static std::string getError(const std::string& message); @@ -243,6 +245,7 @@ class BytecodeBuilder std::vector debugUpvals; DenseHashMap stringTable; + std::vector debugStrings; std::vector> debugRemarks; std::string debugRemarkBuffer; @@ -261,6 +264,7 @@ class BytecodeBuilder void validateVariadic() const; std::string dumpCurrentFunction(std::vector& dumpinstoffs) const; + void dumpConstant(std::string& result, int k) const; void dumpInstruction(const uint32_t* opcode, std::string& output, int targetLabel) const; void writeFunction(std::string& ss, uint32_t id) const; diff --git a/Compiler/src/BytecodeBuilder.cpp b/Compiler/src/BytecodeBuilder.cpp index 7d230738b..11bf24297 100644 --- a/Compiler/src/BytecodeBuilder.cpp +++ b/Compiler/src/BytecodeBuilder.cpp @@ -318,8 +318,13 @@ unsigned int BytecodeBuilder::addStringTableEntry(StringRef value) // note: bytecode serialization format uses 1-based table indices, 0 is reserved to mean nil if (index == 0) + { index = uint32_t(stringTable.size()); + if ((dumpFlags & Dump_Code) != 0) + debugStrings.push_back(value); + } + return index; } @@ -870,6 +875,15 @@ uint32_t BytecodeBuilder::getImportId(int32_t id0, int32_t id1, int32_t id2) return (3u << 30) | (id0 << 20) | (id1 << 10) | id2; } +int BytecodeBuilder::decomposeImportId(uint32_t ids, int32_t& id0, int32_t& id1, int32_t& id2) +{ + int count = ids >> 30; + id0 = count > 0 ? int(ids >> 20) & 1023 : -1; + id1 = count > 1 ? int(ids >> 10) & 1023 : -1; + id2 = count > 2 ? int(ids) & 1023 : -1; + return count; +} + uint32_t BytecodeBuilder::getStringHash(StringRef key) { // This hashing algorithm should match luaS_hash defined in VM/lstring.cpp for short inputs; we can't use that code directly to keep compiler and @@ -1598,6 +1612,95 @@ void BytecodeBuilder::validateVariadic() const } #endif +static bool printableStringConstant(const char* str, size_t len) +{ + for (size_t i = 0; i < len; ++i) + { + if (unsigned(str[i]) < ' ') + return false; + } + + return true; +} + +void BytecodeBuilder::dumpConstant(std::string& result, int k) const +{ + LUAU_ASSERT(unsigned(k) < constants.size()); + const Constant& data = constants[k]; + + switch (data.type) + { + case Constant::Type_Nil: + formatAppend(result, "nil"); + break; + case Constant::Type_Boolean: + formatAppend(result, "%s", data.valueBoolean ? "true" : "false"); + break; + case Constant::Type_Number: + formatAppend(result, "%.17g", data.valueNumber); + break; + case Constant::Type_String: + { + const StringRef& str = debugStrings[data.valueString - 1]; + + if (printableStringConstant(str.data, str.length)) + { + if (str.length < 32) + formatAppend(result, "'%.*s'", int(str.length), str.data); + else + formatAppend(result, "'%.*s'...", 32, str.data); + } + break; + } + case Constant::Type_Import: + { + int id0 = -1, id1 = -1, id2 = -1; + if (int count = decomposeImportId(data.valueImport, id0, id1, id2)) + { + { + const Constant& id = constants[id0]; + LUAU_ASSERT(id.type == Constant::Type_String && id.valueString <= debugStrings.size()); + + const StringRef& str = debugStrings[id.valueString - 1]; + formatAppend(result, "%.*s", int(str.length), str.data); + } + + if (count > 1) + { + const Constant& id = constants[id1]; + LUAU_ASSERT(id.type == Constant::Type_String && id.valueString <= debugStrings.size()); + + const StringRef& str = debugStrings[id.valueString - 1]; + formatAppend(result, ".%.*s", int(str.length), str.data); + } + + if (count > 2) + { + const Constant& id = constants[id2]; + LUAU_ASSERT(id.type == Constant::Type_String && id.valueString <= debugStrings.size()); + + const StringRef& str = debugStrings[id.valueString - 1]; + formatAppend(result, ".%.*s", int(str.length), str.data); + } + } + break; + } + case Constant::Type_Table: + formatAppend(result, "{...}"); + break; + case Constant::Type_Closure: + { + const Function& func = functions[data.valueClosure]; + + if (!func.dumpname.empty()) + formatAppend(result, "'%s'", func.dumpname.c_str()); + break; + } + default: + LUAU_UNREACHABLE(); + } +} + void BytecodeBuilder::dumpInstruction(const uint32_t* code, std::string& result, int targetLabel) const { uint32_t insn = *code++; @@ -1620,7 +1723,9 @@ void BytecodeBuilder::dumpInstruction(const uint32_t* code, std::string& result, break; case LOP_LOADK: - formatAppend(result, "LOADK R%d K%d\n", LUAU_INSN_A(insn), LUAU_INSN_D(insn)); + formatAppend(result, "LOADK R%d K%d [", LUAU_INSN_A(insn), LUAU_INSN_D(insn)); + dumpConstant(result, LUAU_INSN_D(insn)); + result.append("]\n"); break; case LOP_MOVE: @@ -1628,11 +1733,17 @@ void BytecodeBuilder::dumpInstruction(const uint32_t* code, std::string& result, break; case LOP_GETGLOBAL: - formatAppend(result, "GETGLOBAL R%d K%d\n", LUAU_INSN_A(insn), *code++); + formatAppend(result, "GETGLOBAL R%d K%d [", LUAU_INSN_A(insn), *code); + dumpConstant(result, *code); + result.append("]\n"); + code++; break; case LOP_SETGLOBAL: - formatAppend(result, "SETGLOBAL R%d K%d\n", LUAU_INSN_A(insn), *code++); + formatAppend(result, "SETGLOBAL R%d K%d [", LUAU_INSN_A(insn), *code); + dumpConstant(result, *code); + result.append("]\n"); + code++; break; case LOP_GETUPVAL: @@ -1648,7 +1759,9 @@ void BytecodeBuilder::dumpInstruction(const uint32_t* code, std::string& result, break; case LOP_GETIMPORT: - formatAppend(result, "GETIMPORT R%d %d\n", LUAU_INSN_A(insn), LUAU_INSN_D(insn)); + formatAppend(result, "GETIMPORT R%d %d [", LUAU_INSN_A(insn), LUAU_INSN_D(insn)); + dumpConstant(result, LUAU_INSN_D(insn)); + result.append("]\n"); code++; // AUX break; @@ -1661,11 +1774,17 @@ void BytecodeBuilder::dumpInstruction(const uint32_t* code, std::string& result, break; case LOP_GETTABLEKS: - formatAppend(result, "GETTABLEKS R%d R%d K%d\n", LUAU_INSN_A(insn), LUAU_INSN_B(insn), *code++); + formatAppend(result, "GETTABLEKS R%d R%d K%d [", LUAU_INSN_A(insn), LUAU_INSN_B(insn), *code); + dumpConstant(result, *code); + result.append("]\n"); + code++; break; case LOP_SETTABLEKS: - formatAppend(result, "SETTABLEKS R%d R%d K%d\n", LUAU_INSN_A(insn), LUAU_INSN_B(insn), *code++); + formatAppend(result, "SETTABLEKS R%d R%d K%d [", LUAU_INSN_A(insn), LUAU_INSN_B(insn), *code); + dumpConstant(result, *code); + result.append("]\n"); + code++; break; case LOP_GETTABLEN: @@ -1681,7 +1800,10 @@ void BytecodeBuilder::dumpInstruction(const uint32_t* code, std::string& result, break; case LOP_NAMECALL: - formatAppend(result, "NAMECALL R%d R%d K%d\n", LUAU_INSN_A(insn), LUAU_INSN_B(insn), *code++); + formatAppend(result, "NAMECALL R%d R%d K%d [", LUAU_INSN_A(insn), LUAU_INSN_B(insn), *code); + dumpConstant(result, *code); + result.append("]\n"); + code++; break; case LOP_CALL: @@ -1753,27 +1875,39 @@ void BytecodeBuilder::dumpInstruction(const uint32_t* code, std::string& result, break; case LOP_ADDK: - formatAppend(result, "ADDK R%d R%d K%d\n", LUAU_INSN_A(insn), LUAU_INSN_B(insn), LUAU_INSN_C(insn)); + formatAppend(result, "ADDK R%d R%d K%d [", LUAU_INSN_A(insn), LUAU_INSN_B(insn), LUAU_INSN_C(insn)); + dumpConstant(result, LUAU_INSN_C(insn)); + result.append("]\n"); break; case LOP_SUBK: - formatAppend(result, "SUBK R%d R%d K%d\n", LUAU_INSN_A(insn), LUAU_INSN_B(insn), LUAU_INSN_C(insn)); + formatAppend(result, "SUBK R%d R%d K%d [", LUAU_INSN_A(insn), LUAU_INSN_B(insn), LUAU_INSN_C(insn)); + dumpConstant(result, LUAU_INSN_C(insn)); + result.append("]\n"); break; case LOP_MULK: - formatAppend(result, "MULK R%d R%d K%d\n", LUAU_INSN_A(insn), LUAU_INSN_B(insn), LUAU_INSN_C(insn)); + formatAppend(result, "MULK R%d R%d K%d [", LUAU_INSN_A(insn), LUAU_INSN_B(insn), LUAU_INSN_C(insn)); + dumpConstant(result, LUAU_INSN_C(insn)); + result.append("]\n"); break; case LOP_DIVK: - formatAppend(result, "DIVK R%d R%d K%d\n", LUAU_INSN_A(insn), LUAU_INSN_B(insn), LUAU_INSN_C(insn)); + formatAppend(result, "DIVK R%d R%d K%d [", LUAU_INSN_A(insn), LUAU_INSN_B(insn), LUAU_INSN_C(insn)); + dumpConstant(result, LUAU_INSN_C(insn)); + result.append("]\n"); break; case LOP_MODK: - formatAppend(result, "MODK R%d R%d K%d\n", LUAU_INSN_A(insn), LUAU_INSN_B(insn), LUAU_INSN_C(insn)); + formatAppend(result, "MODK R%d R%d K%d [", LUAU_INSN_A(insn), LUAU_INSN_B(insn), LUAU_INSN_C(insn)); + dumpConstant(result, LUAU_INSN_C(insn)); + result.append("]\n"); break; case LOP_POWK: - formatAppend(result, "POWK R%d R%d K%d\n", LUAU_INSN_A(insn), LUAU_INSN_B(insn), LUAU_INSN_C(insn)); + formatAppend(result, "POWK R%d R%d K%d [", LUAU_INSN_A(insn), LUAU_INSN_B(insn), LUAU_INSN_C(insn)); + dumpConstant(result, LUAU_INSN_C(insn)); + result.append("]\n"); break; case LOP_AND: @@ -1785,11 +1919,15 @@ void BytecodeBuilder::dumpInstruction(const uint32_t* code, std::string& result, break; case LOP_ANDK: - formatAppend(result, "ANDK R%d R%d K%d\n", LUAU_INSN_A(insn), LUAU_INSN_B(insn), LUAU_INSN_C(insn)); + formatAppend(result, "ANDK R%d R%d K%d [", LUAU_INSN_A(insn), LUAU_INSN_B(insn), LUAU_INSN_C(insn)); + dumpConstant(result, LUAU_INSN_C(insn)); + result.append("]\n"); break; case LOP_ORK: - formatAppend(result, "ORK R%d R%d K%d\n", LUAU_INSN_A(insn), LUAU_INSN_B(insn), LUAU_INSN_C(insn)); + formatAppend(result, "ORK R%d R%d K%d [", LUAU_INSN_A(insn), LUAU_INSN_B(insn), LUAU_INSN_C(insn)); + dumpConstant(result, LUAU_INSN_C(insn)); + result.append("]\n"); break; case LOP_CONCAT: @@ -1850,7 +1988,9 @@ void BytecodeBuilder::dumpInstruction(const uint32_t* code, std::string& result, break; case LOP_DUPCLOSURE: - formatAppend(result, "DUPCLOSURE R%d K%d\n", LUAU_INSN_A(insn), LUAU_INSN_D(insn)); + formatAppend(result, "DUPCLOSURE R%d K%d [", LUAU_INSN_A(insn), LUAU_INSN_D(insn)); + dumpConstant(result, LUAU_INSN_D(insn)); + result.append("]\n"); break; case LOP_BREAK: @@ -1862,7 +2002,10 @@ void BytecodeBuilder::dumpInstruction(const uint32_t* code, std::string& result, break; case LOP_LOADKX: - formatAppend(result, "LOADKX R%d K%d\n", LUAU_INSN_A(insn), *code++); + formatAppend(result, "LOADKX R%d K%d [", LUAU_INSN_A(insn), *code); + dumpConstant(result, *code); + result.append("]\n"); + code++; break; case LOP_JUMPX: @@ -1876,18 +2019,18 @@ void BytecodeBuilder::dumpInstruction(const uint32_t* code, std::string& result, case LOP_FASTCALL1: formatAppend(result, "FASTCALL1 %d R%d L%d\n", LUAU_INSN_A(insn), LUAU_INSN_B(insn), targetLabel); break; + case LOP_FASTCALL2: - { - uint32_t aux = *code++; - formatAppend(result, "FASTCALL2 %d R%d R%d L%d\n", LUAU_INSN_A(insn), LUAU_INSN_B(insn), aux, targetLabel); + formatAppend(result, "FASTCALL2 %d R%d R%d L%d\n", LUAU_INSN_A(insn), LUAU_INSN_B(insn), *code, targetLabel); + code++; break; - } + case LOP_FASTCALL2K: - { - uint32_t aux = *code++; - formatAppend(result, "FASTCALL2K %d R%d K%d L%d\n", LUAU_INSN_A(insn), LUAU_INSN_B(insn), aux, targetLabel); + formatAppend(result, "FASTCALL2K %d R%d K%d L%d [", LUAU_INSN_A(insn), LUAU_INSN_B(insn), *code, targetLabel); + dumpConstant(result, *code); + result.append("]\n"); + code++; break; - } case LOP_COVERAGE: formatAppend(result, "COVERAGE\n"); @@ -1910,12 +2053,16 @@ void BytecodeBuilder::dumpInstruction(const uint32_t* code, std::string& result, break; case LOP_JUMPXEQKN: - formatAppend(result, "JUMPXEQKN R%d K%d L%d%s\n", LUAU_INSN_A(insn), *code & 0xffffff, targetLabel, *code >> 31 ? " NOT" : ""); + formatAppend(result, "JUMPXEQKN R%d K%d L%d%s [", LUAU_INSN_A(insn), *code & 0xffffff, targetLabel, *code >> 31 ? " NOT" : ""); + dumpConstant(result, *code & 0xffffff); + result.append("]\n"); code++; break; case LOP_JUMPXEQKS: - formatAppend(result, "JUMPXEQKS R%d K%d L%d%s\n", LUAU_INSN_A(insn), *code & 0xffffff, targetLabel, *code >> 31 ? " NOT" : ""); + formatAppend(result, "JUMPXEQKS R%d K%d L%d%s [", LUAU_INSN_A(insn), *code & 0xffffff, targetLabel, *code >> 31 ? " NOT" : ""); + dumpConstant(result, *code & 0xffffff); + result.append("]\n"); code++; break; diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index 8a1e80fc5..fbeef68ef 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -28,6 +28,7 @@ LUAU_FASTINTVARIABLE(LuauCompileInlineDepth, 5) LUAU_FASTFLAG(LuauInterpolatedStringBaseSupport) LUAU_FASTFLAGVARIABLE(LuauMultiAssignmentConflictFix, false) LUAU_FASTFLAGVARIABLE(LuauSelfAssignmentSkip, false) +LUAU_FASTFLAGVARIABLE(LuauCompileInterpStringLimit, false) namespace Luau { @@ -1580,7 +1581,8 @@ struct Compiler RegScope rs(this); - uint8_t baseReg = allocReg(expr, uint8_t(2 + expr->expressions.size)); + uint8_t baseReg = FFlag::LuauCompileInterpStringLimit ? allocReg(expr, unsigned(2 + expr->expressions.size)) + : allocReg(expr, uint8_t(2 + expr->expressions.size)); emitLoadK(baseReg, formatStringIndex); diff --git a/Sources.cmake b/Sources.cmake index 87d76bf39..36e4f04d9 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -82,6 +82,7 @@ target_sources(Luau.CodeGen PRIVATE CodeGen/src/EmitCommonX64.cpp CodeGen/src/EmitInstructionX64.cpp CodeGen/src/Fallbacks.cpp + CodeGen/src/IrDump.cpp CodeGen/src/NativeState.cpp CodeGen/src/UnwindBuilderDwarf2.cpp CodeGen/src/UnwindBuilderWin.cpp @@ -95,6 +96,9 @@ target_sources(Luau.CodeGen PRIVATE CodeGen/src/EmitInstructionX64.h CodeGen/src/Fallbacks.h CodeGen/src/FallbacksProlog.h + CodeGen/src/IrDump.h + CodeGen/src/IrData.h + CodeGen/src/IrUtils.h CodeGen/src/NativeState.h ) diff --git a/VM/include/lualib.h b/VM/include/lualib.h index 955604de7..190cf66a9 100644 --- a/VM/include/lualib.h +++ b/VM/include/lualib.h @@ -91,13 +91,13 @@ typedef struct luaL_Buffer luaL_Buffer; // all the buffer users we have in Luau match this pattern, but it's something to keep in mind for new uses of buffers #define luaL_addchar(B, c) ((void)((B)->p < (B)->end || luaL_extendbuffer(B, 1, -1)), (*(B)->p++ = (char)(c))) -#define luaL_addstring(B, s) luaL_addlstring(B, s, strlen(s)) +#define luaL_addstring(B, s) luaL_addlstring(B, s, strlen(s), -1) LUALIB_API void luaL_buffinit(lua_State* L, luaL_Buffer* B); LUALIB_API char* luaL_buffinitsize(lua_State* L, luaL_Buffer* B, size_t size); LUALIB_API char* luaL_extendbuffer(luaL_Buffer* B, size_t additionalsize, int boxloc); LUALIB_API void luaL_reservebuffer(luaL_Buffer* B, size_t size, int boxloc); -LUALIB_API void luaL_addlstring(luaL_Buffer* B, const char* s, size_t l); +LUALIB_API void luaL_addlstring(luaL_Buffer* B, const char* s, size_t l, int boxloc); LUALIB_API void luaL_addvalue(luaL_Buffer* B); LUALIB_API void luaL_pushresult(luaL_Buffer* B); LUALIB_API void luaL_pushresultsize(luaL_Buffer* B, size_t size); diff --git a/VM/src/laux.cpp b/VM/src/laux.cpp index c2d07dddd..b4490fff3 100644 --- a/VM/src/laux.cpp +++ b/VM/src/laux.cpp @@ -414,10 +414,10 @@ void luaL_reservebuffer(luaL_Buffer* B, size_t size, int boxloc) luaL_extendbuffer(B, size - (B->end - B->p), boxloc); } -void luaL_addlstring(luaL_Buffer* B, const char* s, size_t len) +void luaL_addlstring(luaL_Buffer* B, const char* s, size_t len, int boxloc) { if (size_t(B->end - B->p) < len) - luaL_extendbuffer(B, len - (B->end - B->p), -1); + luaL_extendbuffer(B, len - (B->end - B->p), boxloc); memcpy(B->p, s, len); B->p += len; diff --git a/VM/src/ldblib.cpp b/VM/src/ldblib.cpp index ece4f5511..97ddfa2ce 100644 --- a/VM/src/ldblib.cpp +++ b/VM/src/ldblib.cpp @@ -137,7 +137,7 @@ static int db_traceback(lua_State* L) *--lineptr = '0' + (r % 10); luaL_addchar(&buf, ':'); - luaL_addlstring(&buf, lineptr, lineend - lineptr); + luaL_addlstring(&buf, lineptr, lineend - lineptr, -1); } if (ar.name) diff --git a/VM/src/loslib.cpp b/VM/src/loslib.cpp index 62a5668b2..9b10229e3 100644 --- a/VM/src/loslib.cpp +++ b/VM/src/loslib.cpp @@ -151,7 +151,7 @@ static int os_date(lua_State* L) char buff[200]; // should be big enough for any conversion result cc[1] = *(++s); reslen = strftime(buff, sizeof(buff), cc, stm); - luaL_addlstring(&b, buff, reslen); + luaL_addlstring(&b, buff, reslen, -1); } } luaL_pushresult(&b); diff --git a/VM/src/lstrlib.cpp b/VM/src/lstrlib.cpp index 192ea0b5c..59f16e5cf 100644 --- a/VM/src/lstrlib.cpp +++ b/VM/src/lstrlib.cpp @@ -8,6 +8,8 @@ #include #include +LUAU_FASTFLAGVARIABLE(LuauStringFormatAnyFix, false) + // macro to `unsign' a character #define uchar(c) ((unsigned char)(c)) @@ -771,7 +773,7 @@ static void add_s(MatchState* ms, luaL_Buffer* b, const char* s, const char* e) luaL_addchar(b, news[i]); } else if (news[i] == '0') - luaL_addlstring(b, s, e - s); + luaL_addlstring(b, s, e - s, -1); else { push_onecapture(ms, news[i] - '1', s, e); @@ -854,7 +856,7 @@ static int str_gsub(lua_State* L) if (anchor) break; } - luaL_addlstring(&b, src, ms.src_end - src); + luaL_addlstring(&b, src, ms.src_end - src, -1); luaL_pushresult(&b); lua_pushinteger(L, n); // number of substitutions return 2; @@ -891,12 +893,12 @@ static void addquoted(lua_State* L, luaL_Buffer* b, int arg) } case '\r': { - luaL_addlstring(b, "\\r", 2); + luaL_addlstring(b, "\\r", 2, -1); break; } case '\0': { - luaL_addlstring(b, "\\000", 4); + luaL_addlstring(b, "\\000", 4, -1); break; } default: @@ -1012,7 +1014,7 @@ static int str_format(lua_State* L) case 'q': { addquoted(L, &b, arg); - continue; // skip the 'addsize' at the end + continue; // skip the 'luaL_addlstring' at the end } case 's': { @@ -1024,7 +1026,7 @@ static int str_format(lua_State* L) keep original string */ lua_pushvalue(L, arg); luaL_addvalue(&b); - continue; // skip the `addsize' at the end + continue; // skip the `luaL_addlstring' at the end } else { @@ -1037,18 +1039,30 @@ static int str_format(lua_State* L) if (formatItemSize != 1) luaL_error(L, "'%%*' does not take a form"); - size_t length; - const char* string = luaL_tolstring(L, arg, &length); + if (FFlag::LuauStringFormatAnyFix) + { + size_t length; + const char* string = luaL_tolstring(L, arg, &length); + + luaL_addlstring(&b, string, length, -2); + lua_pop(L, 1); + } + else + { + size_t length; + const char* string = luaL_tolstring(L, arg, &length); + + luaL_addlstring(&b, string, length, -1); + } - luaL_addlstring(&b, string, length); - continue; // skip the `addsize' at the end + continue; // skip the `luaL_addlstring' at the end } default: { // also treat cases `pnLlh' luaL_error(L, "invalid option '%%%c' to 'format'", *(strfrmt - 1)); } } - luaL_addlstring(&b, buff, strlen(buff)); + luaL_addlstring(&b, buff, strlen(buff), -1); } } luaL_pushresult(&b); @@ -1360,7 +1374,7 @@ static void packint(luaL_Buffer* b, unsigned long long n, int islittle, int size for (i = SZINT; i < size; i++) // correct extra bytes buff[islittle ? i : size - 1 - i] = (char)MC; } - luaL_addlstring(b, buff, size); // add result to buffer + luaL_addlstring(b, buff, size, -1); // add result to buffer } /* @@ -1434,7 +1448,7 @@ static int str_pack(lua_State* L) u.n = n; // move 'u' to final result, correcting endianness if needed copywithendian(buff, u.buff, size, h.islittle); - luaL_addlstring(&b, buff, size); + luaL_addlstring(&b, buff, size, -1); break; } case Kchar: @@ -1442,7 +1456,7 @@ static int str_pack(lua_State* L) size_t len; const char* s = luaL_checklstring(L, arg, &len); luaL_argcheck(L, len <= (size_t)size, arg, "string longer than given size"); - luaL_addlstring(&b, s, len); // add string + luaL_addlstring(&b, s, len, -1); // add string while (len++ < (size_t)size) // pad extra space luaL_addchar(&b, LUAL_PACKPADBYTE); break; @@ -1453,7 +1467,7 @@ static int str_pack(lua_State* L) const char* s = luaL_checklstring(L, arg, &len); luaL_argcheck(L, size >= (int)sizeof(size_t) || len < ((size_t)1 << (size * NB)), arg, "string length does not fit in given size"); packint(&b, len, h.islittle, size, 0); // pack length - luaL_addlstring(&b, s, len); + luaL_addlstring(&b, s, len, -1); totalsize += len; break; } @@ -1462,7 +1476,7 @@ static int str_pack(lua_State* L) size_t len; const char* s = luaL_checklstring(L, arg, &len); luaL_argcheck(L, strlen(s) == len, arg, "string contains zeros"); - luaL_addlstring(&b, s, len); + luaL_addlstring(&b, s, len, -1); luaL_addchar(&b, '\0'); // add zero at the end totalsize += len + 1; break; diff --git a/VM/src/ltablib.cpp b/VM/src/ltablib.cpp index 6dd941491..0efa9ee04 100644 --- a/VM/src/ltablib.cpp +++ b/VM/src/ltablib.cpp @@ -238,7 +238,7 @@ static int tconcat(lua_State* L) for (; i < last; i++) { addfield(L, &b, i); - luaL_addlstring(&b, sep, lsep); + luaL_addlstring(&b, sep, lsep, -1); } if (i == last) // add last value (if interval was not empty) addfield(L, &b, i); diff --git a/VM/src/lutf8lib.cpp b/VM/src/lutf8lib.cpp index 837d0e125..0bbce01f0 100644 --- a/VM/src/lutf8lib.cpp +++ b/VM/src/lutf8lib.cpp @@ -175,7 +175,7 @@ static int utfchar(lua_State* L) for (int i = 1; i <= n; i++) { int l = buffutfchar(L, i, buff, &charstr); - luaL_addlstring(&b, charstr, l); + luaL_addlstring(&b, charstr, l, -1); } luaL_pushresult(&b); } diff --git a/bench/bench.py b/bench/bench.py index 0db33950a..547e0d38d 100644 --- a/bench/bench.py +++ b/bench/bench.py @@ -1,4 +1,4 @@ -#!/usr/bin/python +#!/usr/bin/python3 # This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details import argparse import os diff --git a/tests/AssemblyBuilderX64.test.cpp b/tests/AssemblyBuilderX64.test.cpp index b5dbf583b..d6d165094 100644 --- a/tests/AssemblyBuilderX64.test.cpp +++ b/tests/AssemblyBuilderX64.test.cpp @@ -243,6 +243,12 @@ TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "FormsOfLea") SINGLE_COMPARE(lea(rax, addr[r13 + r12 * 4 + 4]), 0x4b, 0x8d, 0x44, 0xa5, 0x04); } +TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "FormsOfSetcc") +{ + SINGLE_COMPARE(setcc(ConditionX64::NotEqual, bl), 0x0f, 0x95, 0xc3); + SINGLE_COMPARE(setcc(ConditionX64::BelowEqual, byte[rcx]), 0x0f, 0x96, 0x01); +} + TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "FormsOfAbsoluteJumps") { SINGLE_COMPARE(jmp(rax), 0xff, 0xe0); diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index f241963a6..3b59bc338 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -15,6 +15,7 @@ LUAU_FASTFLAG(LuauTraceTypesInNonstrictMode2) LUAU_FASTFLAG(LuauSetMetatableDoesNotTimeTravel) +LUAU_FASTFLAG(LuauFixAutocompleteInIf) using namespace Luau; @@ -789,14 +790,30 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_if_middle_keywords") CHECK_EQ(ac2.entryMap.count("end"), 0); CHECK_EQ(ac2.context, AutocompleteContext::Keyword); - check(R"( - if x t@1 - )"); + if (FFlag::LuauFixAutocompleteInIf) + { + check(R"( + if x t@1 + )"); - auto ac3 = autocomplete('1'); - CHECK_EQ(1, ac3.entryMap.size()); - CHECK_EQ(ac3.entryMap.count("then"), 1); - CHECK_EQ(ac3.context, AutocompleteContext::Keyword); + auto ac3 = autocomplete('1'); + CHECK_EQ(3, ac3.entryMap.size()); + CHECK_EQ(ac3.entryMap.count("then"), 1); + CHECK_EQ(ac3.entryMap.count("and"), 1); + CHECK_EQ(ac3.entryMap.count("or"), 1); + CHECK_EQ(ac3.context, AutocompleteContext::Keyword); + } + else + { + check(R"( + if x t@1 + )"); + + auto ac3 = autocomplete('1'); + CHECK_EQ(1, ac3.entryMap.size()); + CHECK_EQ(ac3.entryMap.count("then"), 1); + CHECK_EQ(ac3.context, AutocompleteContext::Keyword); + } check(R"( if x then @@ -839,6 +856,23 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_if_middle_keywords") CHECK_EQ(ac5.entryMap.count("elseif"), 0); CHECK_EQ(ac5.entryMap.count("end"), 0); CHECK_EQ(ac5.context, AutocompleteContext::Statement); + + if (FFlag::LuauFixAutocompleteInIf) + { + check(R"( + if t@1 + )"); + + auto ac6 = autocomplete('1'); + CHECK_EQ(ac6.entryMap.count("true"), 1); + CHECK_EQ(ac6.entryMap.count("false"), 1); + CHECK_EQ(ac6.entryMap.count("then"), 0); + CHECK_EQ(ac6.entryMap.count("function"), 1); + CHECK_EQ(ac6.entryMap.count("else"), 0); + CHECK_EQ(ac6.entryMap.count("elseif"), 0); + CHECK_EQ(ac6.entryMap.count("end"), 0); + CHECK_EQ(ac6.context, AutocompleteContext::Expression); + } } TEST_CASE_FIXTURE(ACFixture, "autocomplete_until_in_repeat") diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index 1a6061267..42d88c7ea 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -59,14 +59,14 @@ TEST_CASE("CompileToBytecode") CHECK_EQ("\n" + bcb.dumpFunction(0), R"( LOADN R0 5 -LOADK R1 K0 +LOADK R1 K0 [6.5] RETURN R0 2 )"); CHECK_EQ("\n" + bcb.dumpEverything(), R"( Function 0 (??): LOADN R0 5 -LOADK R1 K0 +LOADK R1 K0 [6.5] RETURN R0 2 )"); @@ -102,7 +102,7 @@ TEST_CASE("BasicFunction") Luau::compileOrThrow(bcb, "local function foo(a, b) return b end"); CHECK_EQ("\n" + bcb.dumpFunction(1), R"( -DUPCLOSURE R0 K0 +DUPCLOSURE R0 K0 ['foo'] RETURN R0 0 )"); @@ -129,15 +129,15 @@ TEST_CASE("FunctionCallOptimization") { // direct call into local CHECK_EQ("\n" + compileFunction0("local foo = math.foo()"), R"( -GETIMPORT R0 2 +GETIMPORT R0 2 [math.foo] CALL R0 0 1 RETURN R0 0 )"); // direct call into temp CHECK_EQ("\n" + compileFunction0("local foo = math.foo(math.bar())"), R"( -GETIMPORT R0 2 -GETIMPORT R1 4 +GETIMPORT R0 2 [math.foo] +GETIMPORT R1 4 [math.bar] CALL R1 0 -1 CALL R0 -1 1 RETURN R0 0 @@ -146,7 +146,7 @@ RETURN R0 0 // can't directly call into local since foo might be used as arguments of caller CHECK_EQ("\n" + compileFunction0("local foo foo = math.foo(foo)"), R"( LOADNIL R0 -GETIMPORT R1 2 +GETIMPORT R1 2 [math.foo] MOVE R2 R0 CALL R1 1 1 MOVE R0 R1 @@ -162,19 +162,19 @@ part.Size = Vector3.new(1, 2, 3) return part.Size.Z * part:GetMass() )"), R"( -GETIMPORT R0 2 -LOADK R1 K3 -GETIMPORT R2 5 +GETIMPORT R0 2 [Instance.new] +LOADK R1 K3 ['Part'] +GETIMPORT R2 5 [workspace] CALL R0 2 1 -GETIMPORT R1 7 +GETIMPORT R1 7 [Vector3.new] LOADN R2 1 LOADN R3 2 LOADN R4 3 CALL R1 3 1 -SETTABLEKS R1 R0 K8 -GETTABLEKS R3 R0 K8 -GETTABLEKS R2 R3 K9 -NAMECALL R3 R0 K10 +SETTABLEKS R1 R0 K8 ['Size'] +GETTABLEKS R3 R0 K8 ['Size'] +GETTABLEKS R2 R3 K9 ['Z'] +NAMECALL R3 R0 K10 ['GetMass'] CALL R3 1 1 MUL R1 R2 R3 RETURN R1 1 @@ -185,9 +185,9 @@ TEST_CASE("ImportCall") { CHECK_EQ("\n" + compileFunction0("return math.max(1, 2)"), R"( LOADN R1 1 -FASTCALL2K 18 R1 K0 L0 -LOADK R2 K0 -GETIMPORT R0 3 +FASTCALL2K 18 R1 K0 L0 [2] +LOADK R2 K0 [2] +GETIMPORT R0 3 [math.max] CALL R0 2 -1 L0: RETURN R0 -1 )"); @@ -198,8 +198,8 @@ TEST_CASE("FakeImportCall") const char* source = "math = {} function math.max() return 0 end function test() return math.max(1, 2) end"; CHECK_EQ("\n" + compileFunction(source, 1), R"( -GETGLOBAL R1 K0 -GETTABLEKS R0 R1 K1 +GETGLOBAL R1 K0 ['math'] +GETTABLEKS R0 R1 K1 ['max'] LOADN R1 1 LOADN R2 2 CALL R0 2 -1 @@ -220,7 +220,7 @@ TEST_CASE("AssignmentGlobal") { CHECK_EQ("\n" + compileFunction0("a = 2"), R"( LOADN R0 2 -SETGLOBAL R0 K0 +SETGLOBAL R0 K0 ['a'] RETURN R0 0 )"); } @@ -233,8 +233,8 @@ TEST_CASE("AssignmentTable") GETVARARGS R0 1 NEWTABLE R1 1 0 LOADN R2 2 -SETTABLEKS R2 R1 K0 -SETTABLEKS R0 R1 K0 +SETTABLEKS R2 R1 K0 ['b'] +SETTABLEKS R0 R1 K0 ['b'] RETURN R0 0 )"); } @@ -242,25 +242,25 @@ RETURN R0 0 TEST_CASE("ConcatChainOptimization") { CHECK_EQ("\n" + compileFunction0("return '1' .. '2'"), R"( -LOADK R1 K0 -LOADK R2 K1 +LOADK R1 K0 ['1'] +LOADK R2 K1 ['2'] CONCAT R0 R1 R2 RETURN R0 1 )"); CHECK_EQ("\n" + compileFunction0("return '1' .. '2' .. '3'"), R"( -LOADK R1 K0 -LOADK R2 K1 -LOADK R3 K2 +LOADK R1 K0 ['1'] +LOADK R2 K1 ['2'] +LOADK R3 K2 ['3'] CONCAT R0 R1 R3 RETURN R0 1 )"); CHECK_EQ("\n" + compileFunction0("return ('1' .. '2') .. '3'"), R"( -LOADK R3 K0 -LOADK R4 K1 +LOADK R3 K0 ['1'] +LOADK R4 K1 ['2'] CONCAT R1 R3 R4 -LOADK R2 K2 +LOADK R2 K2 ['3'] CONCAT R0 R1 R2 RETURN R0 1 )"); @@ -271,10 +271,10 @@ TEST_CASE("RepeatLocals") CHECK_EQ("\n" + compileFunction0("repeat local a a = 5 until a - 4 < 0 or a - 4 >= 0"), R"( L0: LOADNIL R0 LOADN R0 5 -SUBK R1 R0 K0 +SUBK R1 R0 K0 [4] LOADN R2 0 JUMPIFLT R1 R2 L1 -SUBK R1 R0 K0 +SUBK R1 R0 K0 [4] LOADN R2 0 JUMPIFLE R2 R1 L1 JUMPBACK L0 @@ -290,7 +290,7 @@ LOADN R2 1 LOADN R0 5 LOADN R1 1 FORNPREP R0 L1 -L0: GETIMPORT R3 1 +L0: GETIMPORT R3 1 [print] MOVE R4 R2 CALL R3 1 0 FORNLOOP R0 L0 @@ -305,7 +305,7 @@ LOADN R1 1 FORNPREP R0 L1 L0: MOVE R3 R2 LOADN R3 7 -GETIMPORT R4 1 +GETIMPORT R4 1 [print] MOVE R5 R3 CALL R4 1 0 FORNLOOP R0 L0 @@ -314,12 +314,12 @@ L1: RETURN R0 0 // basic for-in loop, generic version CHECK_EQ("\n" + compileFunction0("for word in string.gmatch(\"Hello Lua user\", \"%a+\") do print(word) end"), R"( -GETIMPORT R0 2 -LOADK R1 K3 -LOADK R2 K4 +GETIMPORT R0 2 [string.gmatch] +LOADK R1 K3 ['Hello Lua user'] +LOADK R2 K4 ['%a+'] CALL R0 2 3 FORGPREP R0 L1 -L0: GETIMPORT R5 6 +L0: GETIMPORT R5 6 [print] MOVE R6 R3 CALL R5 1 0 L1: FORGLOOP R0 L0 1 @@ -328,11 +328,11 @@ RETURN R0 0 // basic for-in loop, using inext specialization CHECK_EQ("\n" + compileFunction0("for k,v in ipairs({}) do print(k,v) end"), R"( -GETIMPORT R0 1 +GETIMPORT R0 1 [ipairs] NEWTABLE R1 0 0 CALL R0 1 3 FORGPREP_INEXT R0 L1 -L0: GETIMPORT R5 3 +L0: GETIMPORT R5 3 [print] MOVE R6 R3 MOVE R7 R4 CALL R5 2 0 @@ -342,11 +342,11 @@ RETURN R0 0 // basic for-in loop, using next specialization CHECK_EQ("\n" + compileFunction0("for k,v in pairs({}) do print(k,v) end"), R"( -GETIMPORT R0 1 +GETIMPORT R0 1 [pairs] NEWTABLE R1 0 0 CALL R0 1 3 FORGPREP_NEXT R0 L1 -L0: GETIMPORT R5 3 +L0: GETIMPORT R5 3 [print] MOVE R6 R3 MOVE R7 R4 CALL R5 2 0 @@ -355,11 +355,11 @@ RETURN R0 0 )"); CHECK_EQ("\n" + compileFunction0("for k,v in next,{} do print(k,v) end"), R"( -GETIMPORT R0 1 +GETIMPORT R0 1 [next] NEWTABLE R1 0 0 LOADNIL R2 FORGPREP_NEXT R0 L1 -L0: GETIMPORT R5 3 +L0: GETIMPORT R5 3 [print] MOVE R6 R3 MOVE R7 R4 CALL R5 2 0 @@ -372,7 +372,7 @@ TEST_CASE("ForBytecodeBuiltin") { // we generally recognize builtins like pairs/ipairs and emit special opcodes CHECK_EQ("\n" + compileFunction0("for k,v in ipairs({}) do end"), R"( -GETIMPORT R0 1 +GETIMPORT R0 1 [ipairs] NEWTABLE R1 0 0 CALL R0 1 3 FORGPREP_INEXT R0 L0 @@ -382,7 +382,7 @@ RETURN R0 0 // ... even if they are using a local variable CHECK_EQ("\n" + compileFunction0("local ip = ipairs for k,v in ip({}) do end"), R"( -GETIMPORT R0 1 +GETIMPORT R0 1 [ipairs] MOVE R1 R0 NEWTABLE R2 0 0 CALL R1 1 3 @@ -403,8 +403,8 @@ RETURN R0 0 // but if it's reassigned then all bets are off CHECK_EQ("\n" + compileFunction0("local ip = ipairs ip = pairs for k,v in ip({}) do end"), R"( -GETIMPORT R0 1 -GETIMPORT R0 3 +GETIMPORT R0 1 [ipairs] +GETIMPORT R0 3 [pairs] MOVE R1 R0 NEWTABLE R2 0 0 CALL R1 1 3 @@ -415,9 +415,9 @@ RETURN R0 0 // or if the global is hijacked CHECK_EQ("\n" + compileFunction0("ipairs = pairs for k,v in ipairs({}) do end"), R"( -GETIMPORT R0 1 -SETGLOBAL R0 K2 -GETGLOBAL R0 K2 +GETIMPORT R0 1 [pairs] +SETGLOBAL R0 K2 ['ipairs'] +GETGLOBAL R0 K2 ['ipairs'] NEWTABLE R1 0 0 CALL R0 1 3 FORGPREP R0 L0 @@ -427,7 +427,7 @@ RETURN R0 0 // or if we don't even know the global to begin with CHECK_EQ("\n" + compileFunction0("for k,v in unknown({}) do end"), R"( -GETIMPORT R0 1 +GETIMPORT R0 1 [unknown] NEWTABLE R1 0 0 CALL R0 1 3 FORGPREP R0 L0 @@ -512,11 +512,11 @@ RETURN R0 1 CHECK_EQ("\n" + compileFunction0("return {a=1,b=2,c=3}"), R"( DUPTABLE R0 3 LOADN R1 1 -SETTABLEKS R1 R0 K0 +SETTABLEKS R1 R0 K0 ['a'] LOADN R1 2 -SETTABLEKS R1 R0 K1 +SETTABLEKS R1 R0 K1 ['b'] LOADN R1 3 -SETTABLEKS R1 R0 K2 +SETTABLEKS R1 R0 K2 ['c'] RETURN R0 1 )"); @@ -524,9 +524,9 @@ RETURN R0 1 CHECK_EQ("\n" + compileFunction0("return {a=1,b=2,3,4}"), R"( NEWTABLE R0 2 2 LOADN R3 1 -SETTABLEKS R3 R0 K0 +SETTABLEKS R3 R0 K0 ['a'] LOADN R3 2 -SETTABLEKS R3 R0 K1 +SETTABLEKS R3 R0 K1 ['b'] LOADN R1 3 LOADN R2 4 SETLIST R0 R1 2 [1] @@ -536,9 +536,9 @@ RETURN R0 1 // expression assignment CHECK_EQ("\n" + compileFunction0("a = 7 return {[a]=42}"), R"( LOADN R0 7 -SETGLOBAL R0 K0 +SETGLOBAL R0 K0 ['a'] NEWTABLE R0 1 0 -GETGLOBAL R1 K0 +GETGLOBAL R1 K0 ['a'] LOADN R2 42 SETTABLE R2 R0 R1 RETURN R0 1 @@ -548,19 +548,19 @@ RETURN R0 1 CHECK_EQ("\n" + compileFunction0("return {a=1,b=2},{b=3,a=4},{a=5,b=6}"), R"( DUPTABLE R0 2 LOADN R1 1 -SETTABLEKS R1 R0 K0 +SETTABLEKS R1 R0 K0 ['a'] LOADN R1 2 -SETTABLEKS R1 R0 K1 +SETTABLEKS R1 R0 K1 ['b'] DUPTABLE R1 3 LOADN R2 3 -SETTABLEKS R2 R1 K1 +SETTABLEKS R2 R1 K1 ['b'] LOADN R2 4 -SETTABLEKS R2 R1 K0 +SETTABLEKS R2 R1 K0 ['a'] DUPTABLE R2 2 LOADN R3 5 -SETTABLEKS R3 R2 K0 +SETTABLEKS R3 R2 K0 ['a'] LOADN R3 6 -SETTABLEKS R3 R2 K1 +SETTABLEKS R3 R2 K1 ['b'] RETURN R0 3 )"); } @@ -624,9 +624,9 @@ RETURN R0 1 CHECK_EQ("\n" + compileFunction0("return {key = 1, value = 2, [1] = 42}"), R"( NEWTABLE R0 2 1 LOADN R1 1 -SETTABLEKS R1 R0 K0 +SETTABLEKS R1 R0 K0 ['key'] LOADN R1 2 -SETTABLEKS R1 R0 K1 +SETTABLEKS R1 R0 K1 ['value'] LOADN R1 42 SETTABLEN R1 R0 1 RETURN R0 1 @@ -643,9 +643,9 @@ TEST_CASE("TableLiteralsIndexConstant") R"( NEWTABLE R0 2 0 LOADN R1 42 -SETTABLEKS R1 R0 K0 +SETTABLEKS R1 R0 K0 ['key'] LOADN R1 0 -SETTABLEKS R1 R0 K1 +SETTABLEKS R1 R0 K1 ['value'] RETURN R0 1 )"); @@ -681,23 +681,23 @@ t.i = 1 R"( NEWTABLE R0 16 0 LOADN R1 1 -SETTABLEKS R1 R0 K0 +SETTABLEKS R1 R0 K0 ['a'] LOADN R1 1 -SETTABLEKS R1 R0 K1 +SETTABLEKS R1 R0 K1 ['b'] LOADN R1 1 -SETTABLEKS R1 R0 K2 +SETTABLEKS R1 R0 K2 ['c'] LOADN R1 1 -SETTABLEKS R1 R0 K3 +SETTABLEKS R1 R0 K3 ['d'] LOADN R1 1 -SETTABLEKS R1 R0 K4 +SETTABLEKS R1 R0 K4 ['e'] LOADN R1 1 -SETTABLEKS R1 R0 K5 +SETTABLEKS R1 R0 K5 ['f'] LOADN R1 1 -SETTABLEKS R1 R0 K6 +SETTABLEKS R1 R0 K6 ['g'] LOADN R1 1 -SETTABLEKS R1 R0 K7 +SETTABLEKS R1 R0 K7 ['h'] LOADN R1 1 -SETTABLEKS R1 R0 K8 +SETTABLEKS R1 R0 K8 ['i'] RETURN R0 0 )"); @@ -716,23 +716,23 @@ t.x = 9 R"( NEWTABLE R0 1 0 LOADN R1 1 -SETTABLEKS R1 R0 K0 +SETTABLEKS R1 R0 K0 ['x'] LOADN R1 2 -SETTABLEKS R1 R0 K0 +SETTABLEKS R1 R0 K0 ['x'] LOADN R1 3 -SETTABLEKS R1 R0 K0 +SETTABLEKS R1 R0 K0 ['x'] LOADN R1 4 -SETTABLEKS R1 R0 K0 +SETTABLEKS R1 R0 K0 ['x'] LOADN R1 5 -SETTABLEKS R1 R0 K0 +SETTABLEKS R1 R0 K0 ['x'] LOADN R1 6 -SETTABLEKS R1 R0 K0 +SETTABLEKS R1 R0 K0 ['x'] LOADN R1 7 -SETTABLEKS R1 R0 K0 +SETTABLEKS R1 R0 K0 ['x'] LOADN R1 8 -SETTABLEKS R1 R0 K0 +SETTABLEKS R1 R0 K0 ['x'] LOADN R1 9 -SETTABLEKS R1 R0 K0 +SETTABLEKS R1 R0 K0 ['x'] RETURN R0 0 )"); @@ -789,9 +789,9 @@ return t R"( NEWTABLE R0 2 0 LOADN R1 1 -SETTABLEKS R1 R0 K0 -DUPCLOSURE R1 K1 -SETTABLEKS R1 R0 K2 +SETTABLEKS R1 R0 K0 ['field'] +DUPCLOSURE R1 K1 ['getfield'] +SETTABLEKS R1 R0 K2 ['getfield'] RETURN R0 1 )"); } @@ -806,14 +806,14 @@ return t )"), R"( NEWTABLE R1 2 0 -FASTCALL2K 61 R1 K0 L0 -LOADK R2 K0 -GETIMPORT R0 2 +FASTCALL2K 61 R1 K0 L0 [nil] +LOADK R2 K0 [nil] +GETIMPORT R0 2 [setmetatable] CALL R0 2 1 L0: LOADN R1 1 -SETTABLEKS R1 R0 K3 +SETTABLEKS R1 R0 K3 ['field1'] LOADN R1 2 -SETTABLEKS R1 R0 K4 +SETTABLEKS R1 R0 K4 ['field2'] RETURN R0 1 )"); } @@ -843,7 +843,7 @@ L1: RETURN R0 1 TEST_CASE("ReflectionEnums") { CHECK_EQ("\n" + compileFunction0("return Enum.EasingStyle.Linear"), R"( -GETIMPORT R0 3 +GETIMPORT R0 3 [Enum.EasingStyle.Linear] RETURN R0 1 )"); } @@ -877,7 +877,7 @@ RETURN R0 0 CHECK_EQ("\n" + bcb.dumpFunction(0), R"( GETUPVAL R0 0 LOADN R1 5 -SETTABLEKS R1 R0 K0 +SETTABLEKS R1 R0 K0 ['_tweakingTooltipFrame'] RETURN R0 0 )"); } @@ -1030,7 +1030,7 @@ TEST_CASE("AndOr") // codegen for constant, local, global for and CHECK_EQ("\n" + compileFunction0("local a = 1 a = a and 2 return a"), R"( LOADN R0 1 -ANDK R0 R0 K0 +ANDK R0 R0 K0 [2] RETURN R0 1 )"); @@ -1044,10 +1044,10 @@ RETURN R0 1 CHECK_EQ("\n" + compileFunction0("local a = 1 b = 2 a = a and b return a"), R"( LOADN R0 1 LOADN R1 2 -SETGLOBAL R1 K0 +SETGLOBAL R1 K0 ['b'] MOVE R1 R0 JUMPIFNOT R1 L0 -GETGLOBAL R1 K0 +GETGLOBAL R1 K0 ['b'] L0: MOVE R0 R1 RETURN R0 1 )"); @@ -1055,7 +1055,7 @@ RETURN R0 1 // codegen for constant, local, global for or CHECK_EQ("\n" + compileFunction0("local a = 1 a = a or 2 return a"), R"( LOADN R0 1 -ORK R0 R0 K0 +ORK R0 R0 K0 [2] RETURN R0 1 )"); @@ -1069,10 +1069,10 @@ RETURN R0 1 CHECK_EQ("\n" + compileFunction0("local a = 1 b = 2 a = a or b return a"), R"( LOADN R0 1 LOADN R1 2 -SETGLOBAL R1 K0 +SETGLOBAL R1 K0 ['b'] MOVE R1 R0 JUMPIF R1 L0 -GETGLOBAL R1 K0 +GETGLOBAL R1 K0 ['b'] L0: MOVE R0 R1 RETURN R0 1 )"); @@ -1082,20 +1082,20 @@ RETURN R0 1 CHECK_EQ("\n" + compileFunction0("local a = 1 a = a b = 2 local c = a and b return c"), R"( LOADN R0 1 LOADN R1 2 -SETGLOBAL R1 K0 +SETGLOBAL R1 K0 ['b'] MOVE R1 R0 JUMPIFNOT R1 L0 -GETGLOBAL R1 K0 +GETGLOBAL R1 K0 ['b'] L0: RETURN R1 1 )"); CHECK_EQ("\n" + compileFunction0("local a = 1 a = a b = 2 local c = a or b return c"), R"( LOADN R0 1 LOADN R1 2 -SETGLOBAL R1 K0 +SETGLOBAL R1 K0 ['b'] MOVE R1 R0 JUMPIF R1 L0 -GETGLOBAL R1 K0 +GETGLOBAL R1 K0 ['b'] L0: RETURN R1 1 )"); } @@ -1108,7 +1108,7 @@ RETURN R0 0 )"); CHECK_EQ("\n" + compileFunction0("local a = true if a or b then b() end"), R"( -GETIMPORT R0 1 +GETIMPORT R0 1 [b] CALL R0 0 0 RETURN R0 0 )"); @@ -1116,18 +1116,18 @@ RETURN R0 0 // however, if right hand side is constant we can't constant fold the entire expression // (note that we don't need to evaluate the right hand side, but we do need a branch) CHECK_EQ("\n" + compileFunction0("local a = false if b and a then b() end"), R"( -GETIMPORT R0 1 +GETIMPORT R0 1 [b] JUMPIFNOT R0 L0 RETURN R0 0 -GETIMPORT R0 1 +GETIMPORT R0 1 [b] CALL R0 0 0 L0: RETURN R0 0 )"); CHECK_EQ("\n" + compileFunction0("local a = true if b or a then b() end"), R"( -GETIMPORT R0 1 +GETIMPORT R0 1 [b] JUMPIF R0 L0 -L0: GETIMPORT R0 1 +L0: GETIMPORT R0 1 [b] CALL R0 0 0 RETURN R0 0 )"); @@ -1144,22 +1144,22 @@ TEST_CASE("AndOrChainCodegen") CHECK_EQ("\n" + compileFunction0(source), R"( LOADN R2 1 -GETIMPORT R3 1 +GETIMPORT R3 1 [verticalGradientTurbulence] SUB R1 R2 R3 -GETIMPORT R3 4 -ADDK R2 R3 K2 +GETIMPORT R3 4 [waterLevel] +ADDK R2 R3 K2 [0.014999999999999999] JUMPIFNOTLT R1 R2 L0 -GETIMPORT R0 8 +GETIMPORT R0 8 [Enum.Material.Sand] JUMPIF R0 L2 -L0: GETIMPORT R1 10 +L0: GETIMPORT R1 10 [sandbank] LOADN R2 0 JUMPIFNOTLT R2 R1 L1 -GETIMPORT R1 10 +GETIMPORT R1 10 [sandbank] LOADN R2 1 JUMPIFNOTLT R1 R2 L1 -GETIMPORT R0 8 +GETIMPORT R0 8 [Enum.Material.Sand] JUMPIF R0 L2 -L1: GETIMPORT R0 12 +L1: GETIMPORT R0 12 [Enum.Material.Sandstone] L2: RETURN R0 1 )"); } @@ -1205,7 +1205,7 @@ RETURN R0 1 // codegen for a non-constant condition CHECK_EQ("\n" + compileFunction0("return if condition then 10 else 20"), R"( -GETIMPORT R1 1 +GETIMPORT R1 1 [condition] JUMPIFNOT R1 L0 LOADN R0 10 RETURN R0 1 @@ -1215,18 +1215,18 @@ RETURN R0 1 // codegen for a non-constant condition using an assignment CHECK_EQ("\n" + compileFunction0("result = if condition then 10 else 20"), R"( -GETIMPORT R1 1 +GETIMPORT R1 1 [condition] JUMPIFNOT R1 L0 LOADN R0 10 JUMP L1 L0: LOADN R0 20 -L1: SETGLOBAL R0 K2 +L1: SETGLOBAL R0 K2 ['result'] RETURN R0 0 )"); // codegen for a non-constant condition using an assignment to a local variable CHECK_EQ("\n" + compileFunction0("local result = if condition then 10 else 20"), R"( -GETIMPORT R1 1 +GETIMPORT R1 1 [condition] JUMPIFNOT R1 L0 LOADN R0 10 RETURN R0 0 @@ -1236,20 +1236,20 @@ RETURN R0 0 // codegen for an if-else expression with multiple elseif's CHECK_EQ("\n" + compileFunction0("result = if condition1 then 10 elseif condition2 then 20 elseif condition3 then 30 else 40"), R"( -GETIMPORT R1 1 +GETIMPORT R1 1 [condition1] JUMPIFNOT R1 L0 LOADN R0 10 JUMP L3 -L0: GETIMPORT R1 3 +L0: GETIMPORT R1 3 [condition2] JUMPIFNOT R1 L1 LOADN R0 20 JUMP L3 -L1: GETIMPORT R1 5 +L1: GETIMPORT R1 5 [condition3] JUMPIFNOT R1 L2 LOADN R0 30 JUMP L3 L2: LOADN R0 40 -L3: SETGLOBAL R0 K6 +L3: SETGLOBAL R0 K6 ['result'] RETURN R0 0 )"); } @@ -1288,9 +1288,9 @@ TEST_CASE("InterpStringZeroCost") CHECK_EQ("\n" + compileFunction0(R"(local _ = `hello, {"world"}!`)"), R"( -LOADK R1 K0 -LOADK R3 K1 -NAMECALL R1 R1 K2 +LOADK R1 K0 ['hello, %*!'] +LOADK R3 K1 ['world'] +NAMECALL R1 R1 K2 ['format'] CALL R1 2 1 MOVE R0 R1 RETURN R0 0 @@ -1309,20 +1309,29 @@ TEST_CASE("InterpStringRegisterCleanup") R"( LOADNIL R0 -LOADK R1 K0 -LOADK R2 K1 -LOADK R3 K2 -LOADK R5 K3 -NAMECALL R3 R3 K4 +LOADK R1 K0 ['um'] +LOADK R2 K1 ['uh oh'] +LOADK R3 K2 ['foo%*'] +LOADK R5 K3 ['bar'] +NAMECALL R3 R3 K4 ['format'] CALL R3 2 1 MOVE R0 R3 -GETIMPORT R3 6 +GETIMPORT R3 6 [print] MOVE R4 R0 CALL R3 1 0 RETURN R0 0 )"); } +TEST_CASE("InterpStringRegisterLimit") +{ + ScopedFastFlag luauInterpolatedStringBaseSupport{"LuauInterpolatedStringBaseSupport", true}; + ScopedFastFlag luauCompileInterpStringLimit{"LuauCompileInterpStringLimit", true}; + + CHECK_THROWS_AS(compileFunction0(("local a = `" + rep("{1}", 254) + "`").c_str()), std::exception); + CHECK_THROWS_AS(compileFunction0(("local a = `" + rep("{1}", 253) + "`").c_str()), std::exception); +} + TEST_CASE("ConstantFoldArith") { CHECK_EQ("\n" + compileFunction0("return 10 + 2"), R"( @@ -1485,7 +1494,7 @@ RETURN R0 2 // local values for multiple assignments w/multret CHECK_EQ("\n" + compileFunction0("local a, b = ... return a + 1, b"), R"( GETVARARGS R0 2 -ADDK R2 R0 K0 +ADDK R2 R0 K0 [1] MOVE R3 R1 RETURN R2 2 )"); @@ -1534,7 +1543,7 @@ RETURN R0 1 // and/or constant folding when left hand side is constant CHECK_EQ("\n" + compileFunction0("return true and a"), R"( -GETIMPORT R0 1 +GETIMPORT R0 1 [a] RETURN R0 1 )"); @@ -1549,22 +1558,22 @@ RETURN R0 1 )"); CHECK_EQ("\n" + compileFunction0("return false or a"), R"( -GETIMPORT R0 1 +GETIMPORT R0 1 [a] RETURN R0 1 )"); // constant fold parts in chains of and/or statements CHECK_EQ("\n" + compileFunction0("return a and true and b"), R"( -GETIMPORT R0 1 +GETIMPORT R0 1 [a] JUMPIFNOT R0 L0 -GETIMPORT R0 3 +GETIMPORT R0 3 [b] L0: RETURN R0 1 )"); CHECK_EQ("\n" + compileFunction0("return a or false or b"), R"( -GETIMPORT R0 1 +GETIMPORT R0 1 [a] JUMPIF R0 L0 -GETIMPORT R0 3 +GETIMPORT R0 3 [b] L0: RETURN R0 1 )"); } @@ -1574,7 +1583,7 @@ TEST_CASE("ConstantFoldConditionalAndOr") CHECK_EQ("\n" + compileFunction0("local a = ... if false or a then print(1) end"), R"( GETVARARGS R0 1 JUMPIFNOT R0 L0 -GETIMPORT R1 1 +GETIMPORT R1 1 [print] LOADN R2 1 CALL R1 1 0 L0: RETURN R0 0 @@ -1583,7 +1592,7 @@ L0: RETURN R0 0 CHECK_EQ("\n" + compileFunction0("local a = ... if not (false or a) then print(1) end"), R"( GETVARARGS R0 1 JUMPIF R0 L0 -GETIMPORT R1 1 +GETIMPORT R1 1 [print] LOADN R2 1 CALL R1 1 0 L0: RETURN R0 0 @@ -1592,7 +1601,7 @@ L0: RETURN R0 0 CHECK_EQ("\n" + compileFunction0("local a = ... if true and a then print(1) end"), R"( GETVARARGS R0 1 JUMPIFNOT R0 L0 -GETIMPORT R1 1 +GETIMPORT R1 1 [print] LOADN R2 1 CALL R1 1 0 L0: RETURN R0 0 @@ -1601,7 +1610,7 @@ L0: RETURN R0 0 CHECK_EQ("\n" + compileFunction0("local a = ... if not (true and a) then print(1) end"), R"( GETVARARGS R0 1 JUMPIF R0 L0 -GETIMPORT R1 1 +GETIMPORT R1 1 [print] LOADN R2 1 CALL R1 1 0 L0: RETURN R0 0 @@ -1612,7 +1621,7 @@ TEST_CASE("ConstantFoldFlowControl") { // if CHECK_EQ("\n" + compileFunction0("if true then print(1) end"), R"( -GETIMPORT R0 1 +GETIMPORT R0 1 [print] LOADN R1 1 CALL R0 1 0 RETURN R0 0 @@ -1623,14 +1632,14 @@ RETURN R0 0 )"); CHECK_EQ("\n" + compileFunction0("if true then print(1) else print(2) end"), R"( -GETIMPORT R0 1 +GETIMPORT R0 1 [print] LOADN R1 1 CALL R0 1 0 RETURN R0 0 )"); CHECK_EQ("\n" + compileFunction0("if false then print(1) else print(2) end"), R"( -GETIMPORT R0 1 +GETIMPORT R0 1 [print] LOADN R1 2 CALL R0 1 0 RETURN R0 0 @@ -1638,7 +1647,7 @@ RETURN R0 0 // while CHECK_EQ("\n" + compileFunction0("while true do print(1) end"), R"( -L0: GETIMPORT R0 1 +L0: GETIMPORT R0 1 [print] LOADN R1 1 CALL R0 1 0 JUMPBACK L0 @@ -1651,14 +1660,14 @@ RETURN R0 0 // repeat CHECK_EQ("\n" + compileFunction0("repeat print(1) until true"), R"( -GETIMPORT R0 1 +GETIMPORT R0 1 [print] LOADN R1 1 CALL R0 1 0 RETURN R0 0 )"); CHECK_EQ("\n" + compileFunction0("repeat print(1) until false"), R"( -L0: GETIMPORT R0 1 +L0: GETIMPORT R0 1 [print] LOADN R1 1 CALL R0 1 0 JUMPBACK L0 @@ -1667,10 +1676,10 @@ RETURN R0 0 // there's an odd case in repeat..until compilation where we evaluate the expression that is always false for side-effects of the left hand side CHECK_EQ("\n" + compileFunction0("repeat print(1) until five and false"), R"( -L0: GETIMPORT R0 1 +L0: GETIMPORT R0 1 [print] LOADN R1 1 CALL R0 1 0 -GETIMPORT R0 3 +GETIMPORT R0 3 [five] JUMPIFNOT R0 L1 L1: JUMPBACK L0 RETURN R0 0 @@ -1681,9 +1690,9 @@ TEST_CASE("LoopBreak") { // default codegen: compile breaks as unconditional jumps CHECK_EQ("\n" + compileFunction0("while true do if math.random() < 0.5 then break else end end"), R"( -L0: GETIMPORT R0 2 +L0: GETIMPORT R0 2 [math.random] CALL R0 0 1 -LOADK R1 K3 +LOADK R1 K3 [0.5] JUMPIFNOTLT R0 R1 L1 RETURN R0 0 JUMP L1 @@ -1693,9 +1702,9 @@ RETURN R0 0 // optimization: if then body is a break statement, flip the branches CHECK_EQ("\n" + compileFunction0("while true do if math.random() < 0.5 then break end end"), R"( -L0: GETIMPORT R0 2 +L0: GETIMPORT R0 2 [math.random] CALL R0 0 1 -LOADK R1 K3 +LOADK R1 K3 [0.5] JUMPIFLT R0 R1 L1 JUMPBACK L0 L1: RETURN R0 0 @@ -1706,28 +1715,28 @@ TEST_CASE("LoopContinue") { // default codegen: compile continue as unconditional jumps CHECK_EQ("\n" + compileFunction0("repeat if math.random() < 0.5 then continue else end break until false error()"), R"( -L0: GETIMPORT R0 2 +L0: GETIMPORT R0 2 [math.random] CALL R0 0 1 -LOADK R1 K3 +LOADK R1 K3 [0.5] JUMPIFNOTLT R0 R1 L2 JUMP L1 JUMP L2 JUMP L2 L1: JUMPBACK L0 -L2: GETIMPORT R0 5 +L2: GETIMPORT R0 5 [error] CALL R0 0 0 RETURN R0 0 )"); // optimization: if then body is a continue statement, flip the branches CHECK_EQ("\n" + compileFunction0("repeat if math.random() < 0.5 then continue end break until false error()"), R"( -L0: GETIMPORT R0 2 +L0: GETIMPORT R0 2 [math.random] CALL R0 0 1 -LOADK R1 K3 +LOADK R1 K3 [0.5] JUMPIFLT R0 R1 L1 JUMP L2 L1: JUMPBACK L0 -L2: GETIMPORT R0 5 +L2: GETIMPORT R0 5 [error] CALL R0 0 0 RETURN R0 0 )"); @@ -1737,12 +1746,12 @@ TEST_CASE("LoopContinueUntil") { // it's valid to use locals defined inside the loop in until expression if they're defined before continue CHECK_EQ("\n" + compileFunction0("repeat local r = math.random() if r > 0.5 then continue end r = r + 0.3 until r < 0.5"), R"( -L0: GETIMPORT R0 2 +L0: GETIMPORT R0 2 [math.random] CALL R0 0 1 -LOADK R1 K3 +LOADK R1 K3 [0.5] JUMPIFLT R1 R0 L1 -ADDK R0 R0 K4 -L1: LOADK R1 K3 +ADDK R0 R0 K4 [0.29999999999999999] +L1: LOADK R1 K3 [0.5] JUMPIFLT R0 R1 L2 JUMPBACK L0 L2: RETURN R0 0 @@ -1776,13 +1785,13 @@ until rr < 0.5 CHECK_EQ("\n" + compileFunction0( "repeat local r = math.random() repeat if r > 0.5 then continue end r = r - 0.1 until true r = r + 0.3 until r < 0.5"), R"( -L0: GETIMPORT R0 2 +L0: GETIMPORT R0 2 [math.random] CALL R0 0 1 -LOADK R1 K3 +LOADK R1 K3 [0.5] JUMPIFLT R1 R0 L1 -SUBK R0 R0 K4 -L1: ADDK R0 R0 K5 -LOADK R1 K3 +SUBK R0 R0 K4 [0.10000000000000001] +L1: ADDK R0 R0 K5 [0.29999999999999999] +LOADK R1 K3 [0.5] JUMPIFLT R0 R1 L2 JUMPBACK L0 L2: RETURN R0 0 @@ -1793,13 +1802,13 @@ L2: RETURN R0 0 "\n" + compileFunction( "repeat local r = math.random() if r > 0.5 then continue end r = r + 0.3 until (function() local a = r return a < 0.5 end)()", 1), R"( -L0: GETIMPORT R0 2 +L0: GETIMPORT R0 2 [math.random] CALL R0 0 1 -LOADK R1 K3 +LOADK R1 K3 [0.5] JUMPIFNOTLT R1 R0 L1 CLOSEUPVALS R0 JUMP L2 -L1: ADDK R0 R0 K4 +L1: ADDK R0 R0 K4 [0.29999999999999999] L2: NEWCLOSURE R1 P0 CAPTURE REF R0 CALL R1 0 1 @@ -1837,14 +1846,14 @@ until (function() return rr end)() < 0.5 CHECK_EQ("\n" + compileFunction0("local stop = false stop = true function test() repeat local r = math.random() if r > 0.5 then " "continue end r = r + 0.3 until stop or r < 0.5 end"), R"( -L0: GETIMPORT R0 2 +L0: GETIMPORT R0 2 [math.random] CALL R0 0 1 -LOADK R1 K3 +LOADK R1 K3 [0.5] JUMPIFLT R1 R0 L1 -ADDK R0 R0 K4 +ADDK R0 R0 K4 [0.29999999999999999] L1: GETUPVAL R1 0 JUMPIF R1 L2 -LOADK R1 K3 +LOADK R1 K3 [0.5] JUMPIFLT R0 R1 L2 JUMPBACK L0 L2: RETURN R0 0 @@ -1855,13 +1864,13 @@ L2: RETURN R0 0 "end r = r + 0.3 until (function() return stop or r < 0.5 end)() end", 1), R"( -L0: GETIMPORT R0 2 +L0: GETIMPORT R0 2 [math.random] CALL R0 0 1 -LOADK R1 K3 +LOADK R1 K3 [0.5] JUMPIFNOTLT R1 R0 L1 CLOSEUPVALS R0 JUMP L2 -L1: ADDK R0 R0 K4 +L1: ADDK R0 R0 K4 [0.29999999999999999] L2: NEWCLOSURE R1 P0 CAPTURE UPVAL U0 CAPTURE REF R0 @@ -1906,7 +1915,7 @@ end )", 0), R"( -ORK R2 R1 K0 +ORK R2 R1 K0 [0.5] SUB R0 R0 R2 LOADN R4 1 LOADN R8 0 @@ -1914,7 +1923,7 @@ JUMPIFNOTLT R0 R8 L0 MINUS R7 R0 JUMPIF R7 L1 L0: MOVE R7 R0 -L1: MULK R6 R7 K1 +L1: MULK R6 R7 K1 [1] LOADN R8 1 SUB R7 R8 R2 DIV R5 R6 R7 @@ -1931,12 +1940,12 @@ end 0), R"( LOADB R2 0 -LOADK R4 K0 -MULK R5 R1 K1 +LOADK R4 K0 [0.5] +MULK R5 R1 K1 [0.40000000000000002] SUB R3 R4 R5 JUMPIFNOTLT R3 R0 L1 -LOADK R4 K0 -MULK R5 R1 K1 +LOADK R4 K0 [0.5] +MULK R5 R1 K1 [0.40000000000000002] ADD R3 R4 R5 JUMPIFLT R0 R3 L0 LOADB R2 0 +1 @@ -1953,12 +1962,12 @@ end 0), R"( LOADB R2 1 -LOADK R4 K0 -MULK R5 R1 K1 +LOADK R4 K0 [0.5] +MULK R5 R1 K1 [0.40000000000000002] SUB R3 R4 R5 JUMPIFLT R0 R3 L1 -LOADK R4 K0 -MULK R5 R1 K1 +LOADK R4 K0 [0.5] +MULK R5 R1 K1 [0.40000000000000002] ADD R3 R4 R5 JUMPIFLT R3 R0 L0 LOADB R2 0 +1 @@ -2006,7 +2015,7 @@ TEST_CASE("JumpFold") { // jump-to-return folding to return CHECK_EQ("\n" + compileFunction0("return a and 1 or 0"), R"( -GETIMPORT R1 1 +GETIMPORT R1 1 [a] JUMPIFNOT R1 L0 LOADN R0 1 RETURN R0 1 @@ -2016,26 +2025,26 @@ RETURN R0 1 // conditional jump in the inner if() folding to jump out of the expression (JUMPIFNOT+5 skips over all jumps, JUMP+1 skips over JUMP+0) CHECK_EQ("\n" + compileFunction0("if a then if b then b() else end else end d()"), R"( -GETIMPORT R0 1 +GETIMPORT R0 1 [a] JUMPIFNOT R0 L0 -GETIMPORT R0 3 +GETIMPORT R0 3 [b] JUMPIFNOT R0 L0 -GETIMPORT R0 3 +GETIMPORT R0 3 [b] CALL R0 0 0 JUMP L0 JUMP L0 -L0: GETIMPORT R0 5 +L0: GETIMPORT R0 5 [d] CALL R0 0 0 RETURN R0 0 )"); // same as example before but the unconditional jumps are folded with RETURN CHECK_EQ("\n" + compileFunction0("if a then if b then b() else end else end"), R"( -GETIMPORT R0 1 +GETIMPORT R0 1 [a] JUMPIFNOT R0 L0 -GETIMPORT R0 3 +GETIMPORT R0 3 [b] JUMPIFNOT R0 L0 -GETIMPORT R0 3 +GETIMPORT R0 3 [b] CALL R0 0 0 RETURN R0 0 RETURN R0 0 @@ -2057,33 +2066,33 @@ end )", 0), R"( -ORK R6 R3 K0 -ORK R7 R4 K1 +ORK R6 R3 K0 [0] +ORK R7 R4 K1 [1] JUMPIF R5 L0 -GETIMPORT R10 5 +GETIMPORT R10 5 [math.noise] DIV R13 R0 R7 -MULK R14 R6 K6 +MULK R14 R6 K6 [17] ADD R12 R13 R14 -GETIMPORT R13 8 +GETIMPORT R13 8 [masterSeed] ADD R11 R12 R13 DIV R13 R1 R7 -GETIMPORT R14 8 +GETIMPORT R14 8 [masterSeed] SUB R12 R13 R14 DIV R14 R2 R7 MUL R15 R6 R6 SUB R13 R14 R15 CALL R10 3 1 -MULK R9 R10 K2 -ADDK R8 R9 K2 +MULK R9 R10 K2 [0.5] +ADDK R8 R9 K2 [0.5] RETURN R8 1 -L0: GETIMPORT R8 5 +L0: GETIMPORT R8 5 [math.noise] DIV R11 R0 R7 -MULK R12 R6 K6 +MULK R12 R6 K6 [17] ADD R10 R11 R12 -GETIMPORT R11 8 +GETIMPORT R11 8 [masterSeed] ADD R9 R10 R11 DIV R11 R1 R7 -GETIMPORT R12 8 +GETIMPORT R12 8 [masterSeed] SUB R10 R11 R12 DIV R12 R2 R7 MUL R13 R6 R6 @@ -2248,11 +2257,11 @@ TEST_CASE("NestedFunctionCalls") FASTCALL2 18 R0 R1 L0 MOVE R5 R0 MOVE R6 R1 -GETIMPORT R4 2 +GETIMPORT R4 2 [math.max] CALL R4 2 1 L0: FASTCALL2 19 R4 R2 L1 MOVE R5 R2 -GETIMPORT R3 4 +GETIMPORT R3 4 [math.min] CALL R3 2 -1 L1: RETURN R3 -1 )"); @@ -2281,11 +2290,11 @@ LOADN R0 10 LOADN R1 1 FORNPREP R0 L2 L0: MOVE R3 R2 -GETIMPORT R4 1 +GETIMPORT R4 1 [foo] NEWCLOSURE R5 P0 CAPTURE REF R3 CALL R4 1 0 -GETIMPORT R4 3 +GETIMPORT R4 3 [bar] JUMPIFNOT R4 L1 CLOSEUPVALS R3 JUMP L2 @@ -2309,15 +2318,15 @@ end )", 1), R"( -GETIMPORT R0 1 -GETIMPORT R1 3 +GETIMPORT R0 1 [ipairs] +GETIMPORT R1 3 [data] CALL R0 1 3 FORGPREP_INEXT R0 L2 -L0: GETIMPORT R5 5 +L0: GETIMPORT R5 5 [foo] NEWCLOSURE R6 P0 CAPTURE REF R3 CALL R5 1 0 -GETIMPORT R5 7 +GETIMPORT R5 7 [bar] JUMPIFNOT R5 L1 CLOSEUPVALS R3 JUMP L3 @@ -2349,12 +2358,12 @@ L0: LOADN R1 5 JUMPIFNOTLT R0 R1 L2 LOADNIL R1 MOVE R1 R0 -GETIMPORT R2 1 +GETIMPORT R2 1 [foo] NEWCLOSURE R3 P0 CAPTURE REF R1 CALL R2 1 0 -ADDK R0 R0 K2 -GETIMPORT R2 4 +ADDK R0 R0 K2 [1] +GETIMPORT R2 4 [bar] JUMPIFNOT R2 L1 CLOSEUPVALS R1 JUMP L2 @@ -2384,12 +2393,12 @@ end LOADN R0 0 L0: LOADNIL R1 MOVE R1 R0 -GETIMPORT R2 1 +GETIMPORT R2 1 [foo] NEWCLOSURE R3 P0 CAPTURE REF R1 CALL R2 1 0 -ADDK R0 R0 K2 -GETIMPORT R2 4 +ADDK R0 R0 K2 [1] +GETIMPORT R2 4 [bar] JUMPIFNOT R2 L1 CLOSEUPVALS R1 JUMP L3 @@ -2437,25 +2446,25 @@ return result CHECK_EQ("\n" + bcb.dumpFunction(0), R"( 2: NEWTABLE R0 16 0 3: LOADB R1 1 -3: SETTABLEKS R1 R0 K0 +3: SETTABLEKS R1 R0 K0 ['Mountains'] 4: LOADB R1 1 -4: SETTABLEKS R1 R0 K1 +4: SETTABLEKS R1 R0 K1 ['Canyons'] 5: LOADB R1 1 -5: SETTABLEKS R1 R0 K2 +5: SETTABLEKS R1 R0 K2 ['Dunes'] 6: LOADB R1 1 -6: SETTABLEKS R1 R0 K3 +6: SETTABLEKS R1 R0 K3 ['Arctic'] 7: LOADB R1 1 -7: SETTABLEKS R1 R0 K4 +7: SETTABLEKS R1 R0 K4 ['Lavaflow'] 8: LOADB R1 1 -8: SETTABLEKS R1 R0 K5 +8: SETTABLEKS R1 R0 K5 ['Hills'] 9: LOADB R1 1 -9: SETTABLEKS R1 R0 K6 +9: SETTABLEKS R1 R0 K6 ['Plains'] 10: LOADB R1 1 -10: SETTABLEKS R1 R0 K7 +10: SETTABLEKS R1 R0 K7 ['Marsh'] 11: LOADB R1 1 -11: SETTABLEKS R1 R0 K8 -13: LOADK R1 K9 -14: GETIMPORT R2 11 +11: SETTABLEKS R1 R0 K8 ['Water'] +13: LOADK R1 K9 [''] +14: GETIMPORT R2 11 [pairs] 14: MOVE R3 R0 14: CALL R2 1 3 14: FORGPREP_NEXT R2 L1 @@ -2490,7 +2499,7 @@ end 7: LOADN R1 2 9: LOADN R2 3 9: FORGPREP R0 L1 -11: L0: GETIMPORT R5 1 +11: L0: GETIMPORT R5 1 [print] 11: MOVE R6 R3 11: CALL R5 1 0 2: L1: FORGLOOP R0 L0 1 @@ -2515,11 +2524,11 @@ end CHECK_EQ("\n" + bcb.dumpFunction(0), R"( 2: LOADN R0 0 -4: L0: ADDK R0 R0 K0 +4: L0: ADDK R0 R0 K0 [1] 5: LOADN R1 1 5: JUMPIFNOTLT R1 R0 L1 -6: GETIMPORT R1 2 -6: LOADK R2 K3 +6: GETIMPORT R1 2 [print] +6: LOADK R2 K3 ['done!'] 6: CALL R1 1 0 10: RETURN R0 0 3: L1: JUMPBACK L0 @@ -2543,14 +2552,14 @@ until f == 0 0), R"( 2: LOADN R0 0 -4: L0: ADDK R0 R0 K0 -5: JUMPXEQKN R0 K0 L1 NOT -6: GETIMPORT R1 2 +4: L0: ADDK R0 R0 K0 [1] +5: JUMPXEQKN R0 K0 L1 NOT [1] +6: GETIMPORT R1 2 [print] 6: MOVE R2 R0 6: CALL R1 1 0 6: JUMP L2 8: L1: LOADN R0 0 -10: L2: JUMPXEQKN R0 K3 L3 +10: L2: JUMPXEQKN R0 K3 L3 [0] 10: JUMPBACK L0 11: L3: RETURN R0 0 )"); @@ -2575,14 +2584,14 @@ Table.SubTable["Key"] = { CHECK_EQ("\n" + bcb.dumpFunction(0), R"( 2: GETVARARGS R0 3 3: NEWTABLE R3 0 0 -5: GETTABLEKS R4 R3 K0 +5: GETTABLEKS R4 R3 K0 ['SubTable'] 5: DUPTABLE R5 5 -6: SETTABLEKS R0 R5 K1 -7: SETTABLEKS R1 R5 K2 -8: SETTABLEKS R2 R5 K3 +6: SETTABLEKS R0 R5 K1 ['Key1'] +7: SETTABLEKS R1 R5 K2 ['Key2'] +8: SETTABLEKS R2 R5 K3 ['Key3'] 9: LOADB R6 1 -9: SETTABLEKS R6 R5 K4 -5: SETTABLEKS R5 R4 K6 +9: SETTABLEKS R6 R5 K4 ['Key4'] +5: SETTABLEKS R5 R4 K6 ['Key'] 11: RETURN R0 0 )"); } @@ -2605,7 +2614,7 @@ Foo:Bar( 5: LOADN R3 1 6: LOADN R4 2 7: LOADN R5 3 -4: NAMECALL R1 R0 K0 +4: NAMECALL R1 R0 K0 ['Bar'] 4: CALL R1 4 0 8: RETURN R0 0 )"); @@ -2627,12 +2636,12 @@ Foo CHECK_EQ("\n" + bcb.dumpFunction(0), R"( 2: GETVARARGS R0 1 5: LOADN R4 1 -5: NAMECALL R2 R0 K0 +5: NAMECALL R2 R0 K0 ['Bar'] 5: CALL R2 2 1 6: LOADN R4 2 -6: NAMECALL R2 R2 K1 +6: NAMECALL R2 R2 K1 ['Baz'] 6: CALL R2 2 1 -7: GETTABLEKS R1 R2 K2 +7: GETTABLEKS R1 R2 K2 ['Qux'] 7: LOADN R2 3 7: CALL R1 1 0 8: RETURN R0 0 @@ -2657,7 +2666,7 @@ return 5: FASTCALL2 18 R0 R1 L0 5: MOVE R3 R0 5: MOVE R4 R1 -5: GETIMPORT R2 2 +5: GETIMPORT R2 2 [math.max] 5: CALL R2 2 -1 5: L0: RETURN R2 -1 )"); @@ -2681,13 +2690,13 @@ a 2: DUPTABLE R1 3 2: DUPTABLE R2 5 2: LOADN R3 3 -2: SETTABLEKS R3 R2 K4 -2: SETTABLEKS R2 R1 K2 -2: SETTABLEKS R1 R0 K0 -5: GETTABLEKS R2 R0 K0 -6: GETTABLEKS R1 R2 K2 +2: SETTABLEKS R3 R2 K4 ['d'] +2: SETTABLEKS R2 R1 K2 ['c'] +2: SETTABLEKS R1 R0 K0 ['b'] +5: GETTABLEKS R2 R0 K0 ['b'] +6: GETTABLEKS R1 R2 K2 ['c'] 7: LOADN R2 4 -7: SETTABLEKS R2 R1 K4 +7: SETTABLEKS R2 R1 K4 ['d'] 8: RETURN R0 0 )"); } @@ -2724,35 +2733,35 @@ return result NEWTABLE R0 16 0 3: ['Mountains'] = true, LOADB R1 1 -SETTABLEKS R1 R0 K0 +SETTABLEKS R1 R0 K0 ['Mountains'] 4: ['Canyons'] = true, LOADB R1 1 -SETTABLEKS R1 R0 K1 +SETTABLEKS R1 R0 K1 ['Canyons'] 5: ['Dunes'] = true, LOADB R1 1 -SETTABLEKS R1 R0 K2 +SETTABLEKS R1 R0 K2 ['Dunes'] 6: ['Arctic'] = true, LOADB R1 1 -SETTABLEKS R1 R0 K3 +SETTABLEKS R1 R0 K3 ['Arctic'] 7: ['Lavaflow'] = true, LOADB R1 1 -SETTABLEKS R1 R0 K4 +SETTABLEKS R1 R0 K4 ['Lavaflow'] 8: ['Hills'] = true, LOADB R1 1 -SETTABLEKS R1 R0 K5 +SETTABLEKS R1 R0 K5 ['Hills'] 9: ['Plains'] = true, LOADB R1 1 -SETTABLEKS R1 R0 K6 +SETTABLEKS R1 R0 K6 ['Plains'] 10: ['Marsh'] = true, LOADB R1 1 -SETTABLEKS R1 R0 K7 +SETTABLEKS R1 R0 K7 ['Marsh'] 11: ['Water'] = true, LOADB R1 1 -SETTABLEKS R1 R0 K8 +SETTABLEKS R1 R0 K8 ['Water'] 13: local result = "" -LOADK R1 K9 +LOADK R1 K9 [''] 14: for k in pairs(kSelectedBiomes) do -GETIMPORT R2 11 +GETIMPORT R2 11 [pairs] MOVE R3 R0 CALL R2 1 3 FORGPREP_NEXT R2 L1 @@ -2817,25 +2826,25 @@ local 8: reg 3, start pc 35 line 21, end pc 35 line 21 4: LOADN R3 3 4: LOADN R4 1 4: FORNPREP R3 L1 -5: L0: GETIMPORT R6 1 +5: L0: GETIMPORT R6 1 [print] 5: MOVE R7 R5 5: CALL R6 1 0 4: FORNLOOP R3 L0 -7: L1: GETIMPORT R3 3 +7: L1: GETIMPORT R3 3 [pairs] 7: CALL R3 0 3 7: FORGPREP_NEXT R3 L3 -8: L2: GETIMPORT R8 1 +8: L2: GETIMPORT R8 1 [print] 8: MOVE R9 R6 8: MOVE R10 R7 8: CALL R8 2 0 7: L3: FORGLOOP R3 L2 2 11: LOADN R3 2 -12: GETIMPORT R4 1 +12: GETIMPORT R4 1 [print] 12: LOADN R5 2 12: CALL R4 1 0 15: LOADN R3 2 -16: GETIMPORT R4 1 -16: GETIMPORT R5 5 +16: GETIMPORT R4 1 [print] +16: GETIMPORT R5 5 [b] 16: CALL R4 1 0 18: NEWCLOSURE R3 P0 18: CAPTURE VAL R3 @@ -2944,7 +2953,7 @@ RETURN R0 0 LOADNIL R0 LOADN R1 1 LOADN R2 2 -SETTABLEKS R2 R0 K0 +SETTABLEKS R2 R0 K0 ['foo'] MOVE R0 R1 RETURN R0 0 )"); @@ -2952,7 +2961,7 @@ RETURN R0 0 // ... or a table index ... CHECK_EQ("\n" + compileFunction0("local a a, foo[a] = 1, 2"), R"( LOADNIL R0 -GETIMPORT R1 1 +GETIMPORT R1 1 [foo] LOADN R2 1 LOADN R3 2 SETTABLE R3 R1 R0 @@ -2987,8 +2996,8 @@ RETURN R0 0 // into a temp register CHECK_EQ("\n" + compileFunction0("local a a, foo[a + 1] = 1, 2"), R"( LOADNIL R0 -GETIMPORT R1 1 -ADDK R2 R0 K2 +GETIMPORT R1 1 [foo] +ADDK R2 R0 K2 [1] LOADN R0 1 LOADN R3 2 SETTABLE R3 R1 R2 @@ -3002,14 +3011,14 @@ TEST_CASE("FastcallBytecode") CHECK_EQ("\n" + compileFunction0("return math.abs(-5)"), R"( LOADN R1 -5 FASTCALL1 2 R1 L0 -GETIMPORT R0 2 +GETIMPORT R0 2 [math.abs] CALL R0 1 -1 L0: RETURN R0 -1 )"); // call through a local variable CHECK_EQ("\n" + compileFunction0("local abs = math.abs return abs(-5)"), R"( -GETIMPORT R0 2 +GETIMPORT R0 2 [math.abs] LOADN R2 -5 FASTCALL1 2 R2 L0 MOVE R1 R0 @@ -3029,9 +3038,9 @@ L0: RETURN R0 -1 // mutating the global in the script breaks the optimization CHECK_EQ("\n" + compileFunction0("math = {} return math.abs(-5)"), R"( NEWTABLE R0 0 0 -SETGLOBAL R0 K0 -GETGLOBAL R1 K0 -GETTABLEKS R0 R1 K1 +SETGLOBAL R0 K0 ['math'] +GETGLOBAL R1 K0 ['math'] +GETTABLEKS R0 R1 K1 ['abs'] LOADN R1 -5 CALL R0 1 -1 RETURN R0 -1 @@ -3039,7 +3048,7 @@ RETURN R0 -1 // mutating the local in the script breaks the optimization CHECK_EQ("\n" + compileFunction0("local abs = math.abs abs = nil return abs(-5)"), R"( -GETIMPORT R0 2 +GETIMPORT R0 2 [math.abs] LOADNIL R0 MOVE R1 R0 LOADN R2 -5 @@ -3049,10 +3058,10 @@ RETURN R1 -1 // mutating the global in the script breaks the optimization, even if you do this after computing the local (for simplicity) CHECK_EQ("\n" + compileFunction0("local abs = math.abs math = {} return abs(-5)"), R"( -GETGLOBAL R1 K0 -GETTABLEKS R0 R1 K1 +GETGLOBAL R1 K0 ['math'] +GETTABLEKS R0 R1 K1 ['abs'] NEWTABLE R1 0 0 -SETGLOBAL R1 K0 +SETGLOBAL R1 K0 ['math'] MOVE R1 R0 LOADN R2 -5 CALL R1 1 -1 @@ -3064,9 +3073,9 @@ TEST_CASE("FastcallSelect") { // select(_, ...) compiles to a builtin call CHECK_EQ("\n" + compileFunction0("return (select('#', ...))"), R"( -LOADK R1 K0 +LOADK R1 K0 ['#'] FASTCALL1 57 R1 L0 -GETIMPORT R0 2 +GETIMPORT R0 2 [select] GETVARARGS R2 -1 CALL R0 -1 1 L0: RETURN R0 1 @@ -3083,16 +3092,16 @@ return sum R"( LOADN R0 0 LOADN R3 1 -LOADK R5 K0 +LOADK R5 K0 ['#'] FASTCALL1 57 R5 L0 -GETIMPORT R4 2 +GETIMPORT R4 2 [select] GETVARARGS R6 -1 CALL R4 -1 1 L0: MOVE R1 R4 LOADN R2 1 FORNPREP R1 L3 L1: FASTCALL1 57 R3 L2 -GETIMPORT R4 2 +GETIMPORT R4 2 [select] MOVE R5 R3 GETVARARGS R6 -1 CALL R4 -1 1 @@ -3103,8 +3112,8 @@ L3: RETURN R0 1 // currently we assume a single value return to avoid dealing with stack resizing CHECK_EQ("\n" + compileFunction0("return select('#', ...)"), R"( -GETIMPORT R0 1 -LOADK R1 K2 +GETIMPORT R0 1 [select] +LOADK R1 K2 ['#'] GETVARARGS R2 -1 CALL R0 -1 -1 RETURN R0 -1 @@ -3112,17 +3121,17 @@ RETURN R0 -1 // note that select with a non-variadic second argument doesn't get optimized CHECK_EQ("\n" + compileFunction0("return select('#')"), R"( -GETIMPORT R0 1 -LOADK R1 K2 +GETIMPORT R0 1 [select] +LOADK R1 K2 ['#'] CALL R0 1 -1 RETURN R0 -1 )"); // note that select with a non-variadic second argument doesn't get optimized CHECK_EQ("\n" + compileFunction0("return select('#', foo())"), R"( -GETIMPORT R0 1 -LOADK R1 K2 -GETIMPORT R2 4 +GETIMPORT R0 1 [select] +LOADK R1 K2 ['#'] +GETIMPORT R2 4 [foo] CALL R2 0 -1 CALL R0 -1 -1 RETURN R0 -1 @@ -3194,7 +3203,7 @@ RETURN R0 1 )"); CHECK_EQ("\n" + compileFunction0("return -0"), R"( -LOADK R0 K0 +LOADK R0 K0 [-0] RETURN R0 1 )"); } @@ -3253,14 +3262,14 @@ RETURN R0 1 // recursive capture CHECK_EQ("\n" + compileFunction("local function foo() return foo() end", 1), R"( -DUPCLOSURE R0 K0 +DUPCLOSURE R0 K0 ['foo'] CAPTURE VAL R0 RETURN R0 0 )"); // multi-level recursive capture CHECK_EQ("\n" + compileFunction("local function foo() return function() return foo() end end", 1), R"( -DUPCLOSURE R0 K0 +DUPCLOSURE R0 K0 [] CAPTURE UPVAL U0 RETURN R0 1 )"); @@ -3411,12 +3420,12 @@ TEST_CASE("FastCallImportFallback") // note: it's important that GETGLOBAL below doesn't overwrite R2 CHECK_EQ("\n" + fragment, R"( LOADN R1 1024 -LOADK R2 K1023 +LOADK R2 K1023 ['1024'] SETTABLE R2 R0 R1 LOADN R2 -1 FASTCALL1 2 R2 L0 -GETGLOBAL R3 K1024 -GETTABLEKS R1 R3 K1025 +GETGLOBAL R3 K1024 ['math'] +GETTABLEKS R1 R3 K1025 ['abs'] CALL R1 1 -1 )"); } @@ -3425,25 +3434,25 @@ TEST_CASE("CompoundAssignment") { // globals vs constants CHECK_EQ("\n" + compileFunction0("a += 1"), R"( -GETGLOBAL R0 K0 -ADDK R0 R0 K1 -SETGLOBAL R0 K0 +GETGLOBAL R0 K0 ['a'] +ADDK R0 R0 K1 [1] +SETGLOBAL R0 K0 ['a'] RETURN R0 0 )"); // globals vs expressions CHECK_EQ("\n" + compileFunction0("a -= a"), R"( -GETGLOBAL R0 K0 -GETGLOBAL R1 K0 +GETGLOBAL R0 K0 ['a'] +GETGLOBAL R1 K0 ['a'] SUB R0 R0 R1 -SETGLOBAL R0 K0 +SETGLOBAL R0 K0 ['a'] RETURN R0 0 )"); // locals vs constants CHECK_EQ("\n" + compileFunction0("local a = 1 a *= 2"), R"( LOADN R0 1 -MULK R0 R0 K0 +MULK R0 R0 K0 [2] RETURN R0 0 )"); @@ -3457,7 +3466,7 @@ RETURN R0 0 // locals vs expressions CHECK_EQ("\n" + compileFunction0("local a = 1 a /= a + 1"), R"( LOADN R0 1 -ADDK R1 R0 K0 +ADDK R1 R0 K0 [1] DIV R0 R0 R1 RETURN R0 0 )"); @@ -3465,7 +3474,7 @@ RETURN R0 0 // upvalues CHECK_EQ("\n" + compileFunction0("local a = 1 function foo() a += 4 end"), R"( GETUPVAL R0 0 -ADDK R0 R0 K0 +ADDK R0 R0 K0 [4] SETUPVAL R0 0 RETURN R0 0 )"); @@ -3473,16 +3482,16 @@ RETURN R0 0 // table variants (indexed by string, number, variable) CHECK_EQ("\n" + compileFunction0("local a = {} a.foo += 5"), R"( NEWTABLE R0 0 0 -GETTABLEKS R1 R0 K0 -ADDK R1 R1 K1 -SETTABLEKS R1 R0 K0 +GETTABLEKS R1 R0 K0 ['foo'] +ADDK R1 R1 K1 [5] +SETTABLEKS R1 R0 K0 ['foo'] RETURN R0 0 )"); CHECK_EQ("\n" + compileFunction0("local a = {} a[1] += 5"), R"( NEWTABLE R0 0 0 GETTABLEN R1 R0 1 -ADDK R1 R1 K0 +ADDK R1 R1 K0 [5] SETTABLEN R1 R0 1 RETURN R0 0 )"); @@ -3490,19 +3499,19 @@ RETURN R0 0 CHECK_EQ("\n" + compileFunction0("local a = {} a[a] += 5"), R"( NEWTABLE R0 0 0 GETTABLE R1 R0 R0 -ADDK R1 R1 K0 +ADDK R1 R1 K0 [5] SETTABLE R1 R0 R0 RETURN R0 0 )"); // left hand side is evaluated once CHECK_EQ("\n" + compileFunction0("foo()[bar()] += 5"), R"( -GETIMPORT R0 1 +GETIMPORT R0 1 [foo] CALL R0 0 1 -GETIMPORT R1 3 +GETIMPORT R1 3 [bar] CALL R1 0 1 GETTABLE R2 R0 R1 -ADDK R2 R2 K4 +ADDK R2 R2 K4 [5] SETTABLE R2 R0 R1 RETURN R0 0 )"); @@ -3512,40 +3521,40 @@ TEST_CASE("CompoundAssignmentConcat") { // basic concat CHECK_EQ("\n" + compileFunction0("local a = '' a ..= 'a'"), R"( -LOADK R0 K0 +LOADK R0 K0 [''] MOVE R1 R0 -LOADK R2 K1 +LOADK R2 K1 ['a'] CONCAT R0 R1 R2 RETURN R0 0 )"); // concat chains CHECK_EQ("\n" + compileFunction0("local a = '' a ..= 'a' .. 'b'"), R"( -LOADK R0 K0 +LOADK R0 K0 [''] MOVE R1 R0 -LOADK R2 K1 -LOADK R3 K2 +LOADK R2 K1 ['a'] +LOADK R3 K2 ['b'] CONCAT R0 R1 R3 RETURN R0 0 )"); CHECK_EQ("\n" + compileFunction0("local a = '' a ..= 'a' .. 'b' .. 'c'"), R"( -LOADK R0 K0 +LOADK R0 K0 [''] MOVE R1 R0 -LOADK R2 K1 -LOADK R3 K2 -LOADK R4 K3 +LOADK R2 K1 ['a'] +LOADK R3 K2 ['b'] +LOADK R4 K3 ['c'] CONCAT R0 R1 R4 RETURN R0 0 )"); // concat on non-local CHECK_EQ("\n" + compileFunction0("_VERSION ..= 'a' .. 'b'"), R"( -GETGLOBAL R1 K0 -LOADK R2 K1 -LOADK R3 K2 +GETGLOBAL R1 K0 ['_VERSION'] +LOADK R2 K1 ['a'] +LOADK R3 K2 ['b'] CONCAT R0 R1 R3 -SETGLOBAL R0 K0 +SETGLOBAL R0 K0 ['_VERSION'] RETURN R0 0 )"); } @@ -3588,12 +3597,12 @@ JUMP L1 L0: JUMPX L14543 L1: FORNPREP R1 L0 L2: ADD R0 R0 R3 -LOADK R4 K0 +LOADK R4 K0 [150000] JUMP L4 L3: JUMPX L14543 L4: JUMPIFLT R4 R0 L3 ADD R0 R0 R3 -LOADK R4 K0 +LOADK R4 K0 [150000] JUMP L6 L5: JUMPX L14543 )"); @@ -3606,10 +3615,10 @@ L5: JUMPX L14543 CHECK_EQ("\n" + tail, R"( ADD R0 R0 R3 -LOADK R4 K0 +LOADK R4 K0 [150000] JUMPIFLT R4 R0 L14543 ADD R0 R0 R3 -LOADK R4 K0 +LOADK R4 K0 [150000] JUMPIFLT R4 R0 L14543 JUMP L14542 L14541: JUMPX L2 @@ -3634,13 +3643,13 @@ return obj:Method(1):Method(2):Method(3) R"( GETVARARGS R0 1 LOADN R3 1 -NAMECALL R1 R0 K0 +NAMECALL R1 R0 K0 ['Method'] CALL R1 2 1 LOADN R3 2 -NAMECALL R1 R1 K0 +NAMECALL R1 R1 K0 ['Method'] CALL R1 2 1 LOADN R3 3 -NAMECALL R1 R1 K0 +NAMECALL R1 R1 K0 ['Method'] CALL R1 2 -1 RETURN R1 -1 )"); @@ -3664,7 +3673,7 @@ local a = g() return a )"), R"( -GETIMPORT R0 1 +GETIMPORT R0 1 [g] CALL R0 0 1 RETURN R0 1 )"); @@ -3676,7 +3685,7 @@ return a )"), R"( LOADN R0 1 -GETIMPORT R1 1 +GETIMPORT R1 1 [g] CALL R1 0 1 RETURN R0 1 )"); @@ -3690,7 +3699,7 @@ local b = obj == 1 )"), R"( GETVARARGS R0 1 -JUMPXEQKN R0 K0 L0 +JUMPXEQKN R0 K0 L0 [1] LOADB R1 0 +1 L0: LOADB R1 1 L1: RETURN R0 0 @@ -3702,7 +3711,7 @@ local b = 1 == obj )"), R"( GETVARARGS R0 1 -JUMPXEQKN R0 K0 L0 +JUMPXEQKN R0 K0 L0 [1] LOADB R1 0 +1 L0: LOADB R1 1 L1: RETURN R0 0 @@ -3714,7 +3723,7 @@ local b = "Hello, Sailor!" == obj )"), R"( GETVARARGS R0 1 -JUMPXEQKS R0 K0 L0 +JUMPXEQKS R0 K0 L0 ['Hello, Sailor!'] LOADB R1 0 +1 L0: LOADB R1 1 L1: RETURN R0 0 @@ -3780,8 +3789,8 @@ return t['a'] R"( DUPTABLE R0 1 LOADN R1 2 -SETTABLEKS R1 R0 K0 -GETTABLEKS R1 R0 K0 +SETTABLEKS R1 R0 K0 ['a'] +GETTABLEKS R1 R0 K0 ['a'] RETURN R1 1 )"); @@ -3792,7 +3801,7 @@ t['a'] = 2 R"( NEWTABLE R0 0 0 LOADN R1 2 -SETTABLEKS R1 R0 K0 +SETTABLEKS R1 R0 K0 ['a'] RETURN R0 0 )"); } @@ -3807,11 +3816,11 @@ print(2) 1), R"( 2: COVERAGE -2: GETIMPORT R0 1 +2: GETIMPORT R0 1 [print] 2: LOADN R1 1 2: CALL R0 1 0 3: COVERAGE -3: GETIMPORT R0 1 +3: GETIMPORT R0 1 [print] 3: LOADN R1 2 3: CALL R0 1 0 4: RETURN R0 0 @@ -3828,15 +3837,15 @@ end 1), R"( 2: COVERAGE -2: GETIMPORT R0 1 +2: GETIMPORT R0 1 [x] 2: JUMPIFNOT R0 L0 3: COVERAGE -3: GETIMPORT R0 3 +3: GETIMPORT R0 3 [print] 3: LOADN R1 1 3: CALL R0 1 0 7: RETURN R0 0 5: L0: COVERAGE -5: GETIMPORT R0 3 +5: GETIMPORT R0 3 [print] 5: LOADN R1 2 5: CALL R0 1 0 7: RETURN R0 0 @@ -3856,15 +3865,15 @@ end 1), R"( 2: COVERAGE -2: GETIMPORT R0 1 +2: GETIMPORT R0 1 [x] 2: JUMPIFNOT R0 L0 4: COVERAGE -4: GETIMPORT R0 3 +4: GETIMPORT R0 3 [print] 4: LOADN R1 1 4: CALL R0 1 0 9: RETURN R0 0 7: L0: COVERAGE -7: GETIMPORT R0 3 +7: GETIMPORT R0 3 [print] 7: LOADN R1 2 7: CALL R0 1 0 9: RETURN R0 0 @@ -3891,13 +3900,13 @@ local t = { 4: COVERAGE 4: COVERAGE 4: LOADN R2 1 -4: SETTABLEKS R2 R1 K0 +4: SETTABLEKS R2 R1 K0 ['a'] 5: COVERAGE 5: COVERAGE 5: LOADN R2 2 -5: SETTABLEKS R2 R1 K1 +5: SETTABLEKS R2 R1 K1 ['b'] 6: COVERAGE -6: SETTABLEKS R0 R1 K2 +6: SETTABLEKS R0 R1 K2 ['c'] 8: RETURN R0 0 )"); } @@ -3910,7 +3919,7 @@ return function() end )", 1), R"( -DUPCLOSURE R0 K0 +DUPCLOSURE R0 K0 [] RETURN R0 1 )"); @@ -3920,7 +3929,7 @@ return function() print("hi") end )", 1), R"( -DUPCLOSURE R0 K0 +DUPCLOSURE R0 K0 [] RETURN R0 1 )"); @@ -3933,7 +3942,7 @@ end )", 1), R"( -GETIMPORT R0 1 +GETIMPORT R0 1 [print] NEWCLOSURE R1 P0 CAPTURE VAL R0 RETURN R1 1 @@ -3946,7 +3955,7 @@ return function() print("hi") end )", 1), R"( -GETIMPORT R0 1 +GETIMPORT R0 1 [setfenv] LOADN R1 1 NEWTABLE R2 0 0 CALL R0 2 0 @@ -3978,7 +3987,7 @@ end )", 1), R"( -DUPCLOSURE R0 K0 +DUPCLOSURE R0 K0 [] CAPTURE UPVAL U0 RETURN R0 1 )"); @@ -4027,9 +4036,9 @@ end )", 2), R"( -DUPCLOSURE R0 K0 +DUPCLOSURE R0 K0 ['bar'] CAPTURE UPVAL U0 -DUPCLOSURE R1 K1 +DUPCLOSURE R1 K1 [] CAPTURE VAL R0 RETURN R1 1 )"); @@ -4055,7 +4064,7 @@ RETURN R2 1 // we also allow recursive function captures to share the object, even when it's not top-level CHECK_EQ("\n" + compileFunction("function test() local function foo() return foo() end end", 1), R"( -DUPCLOSURE R0 K0 +DUPCLOSURE R0 K0 ['foo'] CAPTURE VAL R0 RETURN R0 0 )"); @@ -4097,16 +4106,16 @@ LOADN R2 1 LOADN R0 10 LOADN R1 1 FORNPREP R0 L1 -L0: GETIMPORT R3 1 +L0: GETIMPORT R3 1 [print] NEWCLOSURE R4 P0 CAPTURE VAL R2 CALL R3 1 0 FORNLOOP R0 L0 -L1: GETIMPORT R0 3 +L1: GETIMPORT R0 3 [pairs] GETVARARGS R1 -1 CALL R0 -1 3 FORGPREP_NEXT R0 L3 -L2: GETIMPORT R5 1 +L2: GETIMPORT R5 1 [print] NEWCLOSURE R6 P1 CAPTURE VAL R3 CALL R5 1 0 @@ -4115,7 +4124,7 @@ LOADN R2 1 LOADN R0 10 LOADN R1 1 FORNPREP R0 L5 -L4: GETIMPORT R3 1 +L4: GETIMPORT R3 1 [print] NEWCLOSURE R4 P2 CAPTURE VAL R2 CALL R3 1 0 @@ -4140,24 +4149,24 @@ workspace.print() // Check Roblox globals are no longer here CHECK_EQ("\n" + compileFunction0(source), R"( -GETIMPORT R0 1 +GETIMPORT R0 1 [print] CALL R0 0 0 -GETIMPORT R0 3 +GETIMPORT R0 3 [Game.print] CALL R0 0 0 -GETIMPORT R0 5 +GETIMPORT R0 5 [Workspace.print] CALL R0 0 0 -GETIMPORT R1 7 -GETTABLEKS R0 R1 K0 +GETIMPORT R1 7 [_G] +GETTABLEKS R0 R1 K0 ['print'] CALL R0 0 0 -GETIMPORT R0 9 +GETIMPORT R0 9 [game.print] CALL R0 0 0 -GETIMPORT R0 11 +GETIMPORT R0 11 [plugin.print] CALL R0 0 0 -GETIMPORT R0 13 +GETIMPORT R0 13 [script.print] CALL R0 0 0 -GETIMPORT R0 15 +GETIMPORT R0 15 [shared.print] CALL R0 0 0 -GETIMPORT R0 17 +GETIMPORT R0 17 [workspace.print] CALL R0 0 0 RETURN R0 0 )"); @@ -4171,31 +4180,31 @@ RETURN R0 0 Luau::compileOrThrow(bcb, source, options); CHECK_EQ("\n" + bcb.dumpFunction(0), R"( -GETIMPORT R0 1 +GETIMPORT R0 1 [print] CALL R0 0 0 -GETIMPORT R1 3 -GETTABLEKS R0 R1 K0 +GETIMPORT R1 3 [Game] +GETTABLEKS R0 R1 K0 ['print'] CALL R0 0 0 -GETIMPORT R1 5 -GETTABLEKS R0 R1 K0 +GETIMPORT R1 5 [Workspace] +GETTABLEKS R0 R1 K0 ['print'] CALL R0 0 0 -GETIMPORT R1 7 -GETTABLEKS R0 R1 K0 +GETIMPORT R1 7 [_G] +GETTABLEKS R0 R1 K0 ['print'] CALL R0 0 0 -GETIMPORT R1 9 -GETTABLEKS R0 R1 K0 +GETIMPORT R1 9 [game] +GETTABLEKS R0 R1 K0 ['print'] CALL R0 0 0 -GETIMPORT R1 11 -GETTABLEKS R0 R1 K0 +GETIMPORT R1 11 [plugin] +GETTABLEKS R0 R1 K0 ['print'] CALL R0 0 0 -GETIMPORT R1 13 -GETTABLEKS R0 R1 K0 +GETIMPORT R1 13 [script] +GETTABLEKS R0 R1 K0 ['print'] CALL R0 0 0 -GETIMPORT R1 15 -GETTABLEKS R0 R1 K0 +GETIMPORT R1 15 [shared] +GETTABLEKS R0 R1 K0 ['print'] CALL R0 0 0 -GETIMPORT R1 17 -GETTABLEKS R0 R1 K0 +GETIMPORT R1 17 [workspace] +GETTABLEKS R0 R1 K0 ['print'] CALL R0 0 0 RETURN R0 0 )"); @@ -4214,8 +4223,8 @@ TEST_CASE("ConstantsNoFolding") CHECK_EQ("\n" + bcb.dumpFunction(0), R"( LOADNIL R0 LOADB R1 1 -LOADK R2 K0 -LOADK R3 K1 +LOADK R2 K0 [42] +LOADK R3 K1 ['hello'] RETURN R0 4 )"); } @@ -4236,7 +4245,7 @@ LOADN R1 1 LOADN R2 2 LOADN R3 3 FASTCALL 54 L0 -GETIMPORT R0 2 +GETIMPORT R0 2 [Vector3.new] CALL R0 3 -1 L0: RETURN R0 -1 )"); @@ -4249,8 +4258,8 @@ TEST_CASE("TypeAssertion") print(foo() :: typeof(error("compile time"))) )"), R"( -GETIMPORT R0 1 -GETIMPORT R1 3 +GETIMPORT R0 1 [print] +GETIMPORT R1 3 [foo] CALL R1 0 1 CALL R0 1 0 RETURN R0 0 @@ -4261,8 +4270,8 @@ RETURN R0 0 print(foo()) )"), R"( -GETIMPORT R0 1 -GETIMPORT R1 3 +GETIMPORT R0 1 [print] +GETIMPORT R1 3 [foo] CALL R1 0 -1 CALL R0 -1 0 RETURN R0 0 @@ -4295,12 +4304,12 @@ return a + 1, a - 1, a / 1, a * 1, a % 1, a ^ 1 )"), R"( GETVARARGS R0 1 -ADDK R1 R0 K0 -SUBK R2 R0 K0 -DIVK R3 R0 K0 -MULK R4 R0 K0 -MODK R5 R0 K0 -POWK R6 R0 K0 +ADDK R1 R0 K0 [1] +SUBK R2 R0 K0 [1] +DIVK R3 R0 K0 [1] +MULK R4 R0 K0 [1] +MODK R5 R0 K0 [1] +POWK R6 R0 K0 [1] RETURN R1 6 )"); } @@ -4413,20 +4422,20 @@ LOADN R3 0 LOADN R1 3 LOADN R2 1 FORNPREP R1 L1 -L0: MULK R5 R3 K1 -ADDK R4 R5 K0 +L0: MULK R5 R3 K1 [4] +ADDK R4 R5 K0 [1] LOADN R5 0 SETTABLE R5 R0 R4 -MULK R5 R3 K1 -ADDK R4 R5 K2 +MULK R5 R3 K1 [4] +ADDK R4 R5 K2 [2] LOADN R5 0 SETTABLE R5 R0 R4 -MULK R5 R3 K1 -ADDK R4 R5 K3 +MULK R5 R3 K1 [4] +ADDK R4 R5 K3 [3] LOADN R5 0 SETTABLE R5 R0 R4 -MULK R5 R3 K1 -ADDK R4 R5 K1 +MULK R5 R3 K1 [4] +ADDK R4 R5 K1 [4] LOADN R5 0 SETTABLE R5 R0 R4 FORNLOOP R1 L0 @@ -4464,9 +4473,9 @@ end )", 0, 2), R"( -GETIMPORT R2 1 -GETIMPORT R0 3 -GETIMPORT R1 5 +GETIMPORT R2 1 [x] +GETIMPORT R0 3 [y] +GETIMPORT R1 5 [z] FORNPREP R0 L1 L0: FORNLOOP R0 L0 L1: RETURN R0 0 @@ -4496,7 +4505,7 @@ end R"( LOADN R2 1 LOADN R0 2 -LOADK R1 K0 +LOADK R1 K0 [0.10000000000000001] FORNPREP R0 L1 L0: FORNLOOP R0 L0 L1: RETURN R0 0 @@ -4509,8 +4518,8 @@ end )", 0, 2), R"( -LOADK R2 K0 -LOADK R0 K1 +LOADK R2 K0 [4294967295] +LOADK R0 K1 [4294967296] LOADN R1 1 FORNPREP R0 L1 L0: FORNLOOP R0 L0 @@ -4535,17 +4544,17 @@ end )", 0, 2), R"( -GETIMPORT R0 2 +GETIMPORT R0 2 [math.random] CALL R0 0 1 -LOADK R1 K3 +LOADK R1 K3 [0.5] JUMPIFLT R0 R1 L0 -GETIMPORT R0 2 +GETIMPORT R0 2 [math.random] CALL R0 0 1 -LOADK R1 K3 +LOADK R1 K3 [0.5] JUMPIFLT R0 R1 L0 -GETIMPORT R0 2 +GETIMPORT R0 2 [math.random] CALL R0 0 1 -LOADK R1 K3 +LOADK R1 K3 [0.5] JUMPIFLT R0 R1 L0 L0: RETURN R0 0 )"); @@ -4561,25 +4570,25 @@ end )", 0, 2), R"( -GETIMPORT R0 2 +GETIMPORT R0 2 [math.random] CALL R0 0 1 -LOADK R1 K3 +LOADK R1 K3 [0.5] JUMPIFLT R0 R1 L0 -GETIMPORT R0 5 +GETIMPORT R0 5 [print] LOADN R1 1 CALL R0 1 0 -L0: GETIMPORT R0 2 +L0: GETIMPORT R0 2 [math.random] CALL R0 0 1 -LOADK R1 K3 +LOADK R1 K3 [0.5] JUMPIFLT R0 R1 L1 -GETIMPORT R0 5 +GETIMPORT R0 5 [print] LOADN R1 2 CALL R0 1 0 -L1: GETIMPORT R0 2 +L1: GETIMPORT R0 2 [math.random] CALL R0 0 1 -LOADK R1 K3 +LOADK R1 K3 [0.5] JUMPIFLT R0 R1 L2 -GETIMPORT R0 5 +GETIMPORT R0 5 [print] LOADN R1 3 CALL R0 1 0 L2: RETURN R0 0 @@ -4598,20 +4607,20 @@ end )", 1, 2), R"( -GETIMPORT R0 1 +GETIMPORT R0 1 [global] LOADN R1 1 CALL R0 1 1 -GETIMPORT R1 3 +GETIMPORT R1 3 [print] NEWCLOSURE R2 P0 CAPTURE REF R0 CALL R1 1 0 -GETIMPORT R1 6 +GETIMPORT R1 6 [math.random] CALL R1 0 1 -LOADK R2 K7 +LOADK R2 K7 [0.5] JUMPIFNOTLT R1 R2 L0 CLOSEUPVALS R0 RETURN R0 0 -L0: ADDK R0 R0 K8 +L0: ADDK R0 R0 K8 [1] CLOSEUPVALS R0 RETURN R0 0 )"); @@ -4797,10 +4806,10 @@ LOADN R2 1 FORNPREP R1 L3 L0: FASTCALL1 24 R3 L1 MOVE R6 R3 -GETIMPORT R5 2 +GETIMPORT R5 2 [math.sin] CALL R5 1 -1 L1: FASTCALL 2 L2 -GETIMPORT R4 4 +GETIMPORT R4 4 [math.abs] CALL R4 -1 1 L2: SETTABLE R4 R0 R3 FORNLOOP R1 L0 @@ -4825,7 +4834,7 @@ LOADN R1 1 FORNPREP R0 L1 L0: MOVE R3 R2 LOADN R3 3 -GETIMPORT R4 1 +GETIMPORT R4 1 [print] MOVE R5 R3 CALL R4 1 0 FORNLOOP R0 L0 @@ -4850,44 +4859,44 @@ end )", 0, 2), R"( -FASTCALL2K 39 R1 K0 L0 +FASTCALL2K 39 R1 K0 L0 [0] MOVE R4 R1 -LOADK R5 K0 -GETIMPORT R3 3 +LOADK R5 K0 [0] +GETIMPORT R3 3 [bit32.rshift] CALL R3 2 1 -L0: FASTCALL2K 29 R3 K4 L1 -LOADK R4 K4 -GETIMPORT R2 6 +L0: FASTCALL2K 29 R3 K4 L1 [255] +LOADK R4 K4 [255] +GETIMPORT R2 6 [bit32.band] CALL R2 2 1 L1: SETTABLEN R2 R0 1 -FASTCALL2K 39 R1 K7 L2 +FASTCALL2K 39 R1 K7 L2 [8] MOVE R4 R1 -LOADK R5 K7 -GETIMPORT R3 3 +LOADK R5 K7 [8] +GETIMPORT R3 3 [bit32.rshift] CALL R3 2 1 -L2: FASTCALL2K 29 R3 K4 L3 -LOADK R4 K4 -GETIMPORT R2 6 +L2: FASTCALL2K 29 R3 K4 L3 [255] +LOADK R4 K4 [255] +GETIMPORT R2 6 [bit32.band] CALL R2 2 1 L3: SETTABLEN R2 R0 2 -FASTCALL2K 39 R1 K8 L4 +FASTCALL2K 39 R1 K8 L4 [16] MOVE R4 R1 -LOADK R5 K8 -GETIMPORT R3 3 +LOADK R5 K8 [16] +GETIMPORT R3 3 [bit32.rshift] CALL R3 2 1 -L4: FASTCALL2K 29 R3 K4 L5 -LOADK R4 K4 -GETIMPORT R2 6 +L4: FASTCALL2K 29 R3 K4 L5 [255] +LOADK R4 K4 [255] +GETIMPORT R2 6 [bit32.band] CALL R2 2 1 L5: SETTABLEN R2 R0 3 -FASTCALL2K 39 R1 K9 L6 +FASTCALL2K 39 R1 K9 L6 [24] MOVE R4 R1 -LOADK R5 K9 -GETIMPORT R3 3 +LOADK R5 K9 [24] +GETIMPORT R3 3 [bit32.rshift] CALL R3 2 1 -L6: FASTCALL2K 29 R3 K4 L7 -LOADK R4 K4 -GETIMPORT R2 6 +L6: FASTCALL2K 29 R3 K4 L7 [255] +LOADK R4 K4 [255] +GETIMPORT R2 6 [bit32.band] CALL R2 2 1 L7: SETTABLEN R2 R0 4 RETURN R0 0 @@ -4909,13 +4918,13 @@ LOADN R4 0 LOADN R2 3 LOADN R3 1 FORNPREP R2 L1 -L0: ADDK R5 R4 K0 -GETGLOBAL R7 K1 -GETTABLEKS R6 R7 K2 -GETGLOBAL R8 K1 -GETTABLEKS R7 R8 K3 +L0: ADDK R5 R4 K0 [1] +GETGLOBAL R7 K1 ['bit32'] +GETTABLEKS R6 R7 K2 ['band'] +GETGLOBAL R8 K1 ['bit32'] +GETTABLEKS R7 R8 K3 ['rshift'] MOVE R8 R1 -MULK R9 R4 K4 +MULK R9 R4 K4 [8] CALL R7 2 1 LOADN R8 255 CALL R6 2 1 @@ -4938,11 +4947,11 @@ LOADN R4 0 LOADN R2 3 LOADN R3 1 FORNPREP R2 L3 -L0: ADDK R5 R4 K0 -MULK R9 R4 K1 +L0: ADDK R5 R4 K0 [1] +MULK R9 R4 K1 [8] FASTCALL2 39 R1 R9 L1 MOVE R8 R1 -GETIMPORT R7 4 +GETIMPORT R7 4 [bit32.rshift] CALL R7 2 1 L1: LOADN R8 255 LOADN R9 255 @@ -4950,7 +4959,7 @@ LOADN R10 255 LOADN R11 255 LOADN R12 255 FASTCALL 29 L2 -GETIMPORT R6 6 +GETIMPORT R6 6 [bit32.band] CALL R6 6 1 L2: SETTABLE R6 R0 R5 FORNLOOP R2 L0 @@ -4971,7 +4980,7 @@ return x )", 1, 2), R"( -DUPCLOSURE R0 K0 +DUPCLOSURE R0 K0 ['foo'] LOADN R1 42 RETURN R1 1 )"); @@ -4987,7 +4996,7 @@ return x )", 1, 2), R"( -DUPCLOSURE R0 K0 +DUPCLOSURE R0 K0 ['foo'] LOADN R1 42 RETURN R1 1 )"); @@ -5007,8 +5016,8 @@ return x )", 1, 2), R"( -DUPCLOSURE R0 K0 -GETIMPORT R2 3 +DUPCLOSURE R0 K0 ['foo'] +GETIMPORT R2 3 [math.random] CALL R2 0 1 MOVE R1 R2 RETURN R1 1 @@ -5030,8 +5039,8 @@ return x )", 1, 2), R"( -DUPCLOSURE R0 K0 -GETIMPORT R2 3 +DUPCLOSURE R0 K0 ['foo'] +GETIMPORT R2 3 [math.random] CALL R2 0 1 LOADN R1 5 RETURN R1 1 @@ -5052,7 +5061,7 @@ return x )", 1, 2), R"( -DUPCLOSURE R0 K0 +DUPCLOSURE R0 K0 ['foo'] MOVE R1 R0 CALL R1 0 1 RETURN R1 1 @@ -5070,10 +5079,10 @@ return x )", 1, 2), R"( -DUPCLOSURE R0 K0 +DUPCLOSURE R0 K0 ['foo'] MOVE R1 R0 CALL R1 0 1 -GETIMPORT R2 2 +GETIMPORT R2 2 [getfenv] CALL R2 0 0 RETURN R1 1 )"); @@ -5095,7 +5104,7 @@ return x )", 1, 2), R"( -DUPCLOSURE R0 K0 +DUPCLOSURE R0 K0 ['foo'] NEWTABLE R2 0 0 LOADN R3 1 SETTABLEN R3 R2 1 @@ -5121,7 +5130,7 @@ return x )", 1, 2), R"( -DUPCLOSURE R0 K0 +DUPCLOSURE R0 K0 ['foo'] NEWTABLE R2 0 0 LOADN R3 1 SETTABLEN R3 R2 1 @@ -5147,7 +5156,7 @@ return x )", 2, 2), R"( -DUPCLOSURE R0 K0 +DUPCLOSURE R0 K0 ['foo'] LOADN R2 1 NEWCLOSURE R1 P1 CAPTURE VAL R2 @@ -5173,9 +5182,9 @@ return x )", 1, 2), R"( -DUPCLOSURE R0 K0 +DUPCLOSURE R0 K0 ['foo'] LOADN R2 42 -ORK R2 R2 K1 +ORK R2 R2 K1 [5] MOVE R1 R2 RETURN R1 1 )"); @@ -5192,7 +5201,7 @@ return y )", 1, 2), R"( -DUPCLOSURE R0 K0 +DUPCLOSURE R0 K0 ['foo'] GETVARARGS R1 1 MOVE R2 R1 RETURN R2 1 @@ -5211,7 +5220,7 @@ return y )", 1, 2), R"( -DUPCLOSURE R0 K0 +DUPCLOSURE R0 K0 ['foo'] GETVARARGS R1 1 LOADNIL R1 MOVE R3 R1 @@ -5232,7 +5241,7 @@ return x )", 1, 2), R"( -DUPCLOSURE R0 K0 +DUPCLOSURE R0 K0 ['foo'] MOVE R1 R0 LOADN R2 42 CALL R1 1 1 @@ -5276,7 +5285,7 @@ return x 1, 2), R"( GETVARARGS R0 1 -DUPCLOSURE R1 K0 +DUPCLOSURE R1 K0 ['foo'] CAPTURE VAL R0 LOADN R3 42 ADD R2 R3 R0 @@ -5298,7 +5307,7 @@ end )", 1, 2), R"( -DUPCLOSURE R0 K0 +DUPCLOSURE R0 K0 ['foo'] CAPTURE UPVAL U0 LOADN R2 42 GETUPVAL R3 0 @@ -5321,7 +5330,7 @@ return y )", 2, 2), R"( -DUPCLOSURE R0 K0 +DUPCLOSURE R0 K0 ['foo'] GETVARARGS R1 1 NEWCLOSURE R2 P1 CAPTURE VAL R1 @@ -5339,7 +5348,7 @@ return y )", 2, 2), R"( -DUPCLOSURE R0 K0 +DUPCLOSURE R0 K0 ['foo'] LOADN R2 42 NEWCLOSURE R1 P1 CAPTURE VAL R2 @@ -5358,7 +5367,7 @@ return y )", 2, 2), R"( -DUPCLOSURE R0 K0 +DUPCLOSURE R0 K0 ['foo'] LOADNIL R1 LOADN R1 42 MOVE R3 R1 @@ -5379,9 +5388,9 @@ return y )", 2, 2), R"( -DUPCLOSURE R0 K0 +DUPCLOSURE R0 K0 ['foo'] LOADNIL R2 -ORK R2 R2 K1 +ORK R2 R2 K1 [42] NEWCLOSURE R1 P1 CAPTURE REF R2 CLOSEUPVALS R2 @@ -5401,11 +5410,11 @@ return y )", 2, 2), R"( -DUPCLOSURE R0 K0 +DUPCLOSURE R0 K0 ['foo'] GETVARARGS R1 1 MOVE R3 R1 -ORK R3 R3 K1 -GETIMPORT R4 3 +ORK R3 R3 K1 [42] +GETIMPORT R4 3 [print] NEWCLOSURE R5 P1 CAPTURE REF R3 CALL R4 1 0 @@ -5431,7 +5440,7 @@ return y, x )", 2, 2), R"( -DUPCLOSURE R0 K0 +DUPCLOSURE R0 K0 ['foo'] GETVARARGS R1 1 JUMPIF R1 L0 LOADNIL R3 @@ -5460,7 +5469,7 @@ return a, b )", 1, 2), R"( -DUPCLOSURE R0 K0 +DUPCLOSURE R0 K0 ['foo'] LOADNIL R1 LOADNIL R2 RETURN R1 2 @@ -5477,7 +5486,7 @@ return a, b )", 1, 2), R"( -DUPCLOSURE R0 K0 +DUPCLOSURE R0 K0 ['foo'] LOADNIL R1 LOADNIL R2 RETURN R1 2 @@ -5493,7 +5502,7 @@ return foo() )", 1, 2), R"( -DUPCLOSURE R0 K0 +DUPCLOSURE R0 K0 ['foo'] MOVE R1 R0 CALL R1 0 -1 RETURN R1 -1 @@ -5515,7 +5524,7 @@ return x )", 1, 2), R"( -DUPCLOSURE R0 K0 +DUPCLOSURE R0 K0 ['foo'] LOADNIL R1 RETURN R1 1 )"); @@ -5531,10 +5540,10 @@ return x )", 1, 2), R"( -DUPCLOSURE R0 K0 -LOADK R3 K1 +DUPCLOSURE R0 K0 ['foo'] +LOADK R3 K1 [1.5] FASTCALL1 20 R3 L0 -GETIMPORT R2 4 +GETIMPORT R2 4 [math.modf] CALL R2 1 2 L0: ADD R1 R2 R3 RETURN R1 1 @@ -5551,7 +5560,7 @@ return x )", 1, 2), R"( -DUPCLOSURE R0 K0 +DUPCLOSURE R0 K0 ['foo'] GETVARARGS R2 2 ADD R1 R2 R3 RETURN R1 1 @@ -5568,8 +5577,8 @@ return x )", 1, 2), R"( -DUPCLOSURE R0 K0 -GETIMPORT R2 2 +DUPCLOSURE R0 K0 ['foo'] +GETIMPORT R2 2 [print] CALL R2 0 1 LOADN R1 42 RETURN R1 1 @@ -5587,7 +5596,7 @@ return x )", 1, 2), R"( -DUPCLOSURE R0 K0 +DUPCLOSURE R0 K0 ['foo'] LOADNIL R2 LOADN R2 42 MOVE R1 R2 @@ -5612,9 +5621,9 @@ return a, b, c, d )", 1, 2), R"( -DUPCLOSURE R0 K0 +DUPCLOSURE R0 K0 ['foo'] GETVARARGS R1 2 -ADDK R3 R1 K1 +ADDK R3 R1 K1 [1] LOADN R5 1 ADD R4 R5 R1 LOADN R5 3 @@ -5643,9 +5652,9 @@ return (baz()) )", 3, 2), R"( -DUPCLOSURE R0 K0 -DUPCLOSURE R1 K1 -DUPCLOSURE R2 K2 +DUPCLOSURE R0 K0 ['foo'] +DUPCLOSURE R1 K1 ['bar'] +DUPCLOSURE R2 K2 ['baz'] LOADN R4 43 LOADN R5 41 MUL R3 R4 R5 @@ -5671,7 +5680,7 @@ return (foo()) )", 1, 2), R"( -DUPCLOSURE R0 K0 +DUPCLOSURE R0 K0 ['foo'] MOVE R1 R0 CALL R1 0 1 RETURN R1 1 @@ -5687,7 +5696,7 @@ return (foo()) )", 1, 2), R"( -DUPCLOSURE R0 K0 +DUPCLOSURE R0 K0 ['foo'] MOVE R1 R0 CALL R1 0 1 RETURN R1 1 @@ -5711,9 +5720,9 @@ return (baz()) )", 3, 2), R"( -DUPCLOSURE R0 K0 -DUPCLOSURE R1 K1 -DUPCLOSURE R2 K2 +DUPCLOSURE R0 K0 ['foo'] +DUPCLOSURE R1 K1 ['bar'] +DUPCLOSURE R2 K2 ['baz'] MOVE R4 R0 LOADN R5 42 LOADN R6 1 @@ -5772,7 +5781,7 @@ foo(foo(foo,foo(foo,foo))[foo]) )", 1, 2), R"( -DUPCLOSURE R0 K0 +DUPCLOSURE R0 K0 ['foo'] MOVE R2 R0 MOVE R3 R0 MOVE R4 R0 @@ -5797,17 +5806,17 @@ set({}) )", 1, 2), R"( -DUPCLOSURE R0 K0 +DUPCLOSURE R0 K0 ['set'] NEWTABLE R2 0 0 -FASTCALL2K 49 R2 K1 L0 -LOADK R3 K1 -GETIMPORT R1 3 +FASTCALL2K 49 R2 K1 L0 [false] +LOADK R3 K1 [false] +GETIMPORT R1 3 [rawset] CALL R1 2 0 L0: NEWTABLE R1 0 0 NEWTABLE R3 0 0 FASTCALL2 49 R3 R1 L1 MOVE R4 R1 -GETIMPORT R2 3 +GETIMPORT R2 3 [rawset] CALL R2 2 0 L1: RETURN R0 0 )"); @@ -5838,7 +5847,7 @@ end )", 1, 2), R"( -DUPCLOSURE R0 K0 +DUPCLOSURE R0 K0 [] L0: LOADNIL R4 LOADNIL R5 CALL R4 1 1 @@ -5847,7 +5856,7 @@ GETTABLE R3 R4 R5 JUMPIFNOT R3 L1 JUMPBACK L0 L1: LOADNIL R2 -GETTABLEKS R1 R2 K1 +GETTABLEKS R1 R2 K1 [''] JUMPIFNOT R1 L2 RETURN R0 0 L2: JUMPIFNOT R1 L3 @@ -5889,7 +5898,7 @@ return y )", 1, 2), R"( -DUPCLOSURE R0 K0 +DUPCLOSURE R0 K0 ['foo'] GETVARARGS R1 1 MOVE R3 R1 LOADN R3 42 @@ -5912,13 +5921,13 @@ return y )", 2, 2), R"( -DUPCLOSURE R0 K0 +DUPCLOSURE R0 K0 ['foo'] GETVARARGS R1 1 NEWCLOSURE R2 P1 CAPTURE REF R1 -SETGLOBAL R2 K1 +SETGLOBAL R2 K1 ['mutator'] MOVE R3 R1 -GETGLOBAL R4 K1 +GETGLOBAL R4 K1 ['mutator'] CALL R4 0 0 MOVE R2 R3 CLOSEUPVALS R1 @@ -5938,7 +5947,7 @@ return foo(42) )", 1, 2), R"( -DUPCLOSURE R0 K0 +DUPCLOSURE R0 K0 ['foo'] MOVE R1 R0 LOADN R2 42 CALL R1 1 -1 @@ -5955,7 +5964,7 @@ return foo(42) )", 1, 2), R"( -DUPCLOSURE R0 K0 +DUPCLOSURE R0 K0 ['foo'] LOADN R1 42 RETURN R1 1 )"); @@ -5974,8 +5983,8 @@ return bar(42) )", 2, 2), R"( -DUPCLOSURE R0 K0 -DUPCLOSURE R1 K1 +DUPCLOSURE R0 K0 ['foo'] +DUPCLOSURE R1 K1 ['bar'] LOADN R2 42 RETURN R2 1 )"); @@ -5990,7 +5999,7 @@ return foo(42) )", 1, 2), R"( -DUPCLOSURE R0 K0 +DUPCLOSURE R0 K0 ['foo'] CAPTURE VAL R0 MOVE R1 R0 LOADN R2 42 @@ -6008,7 +6017,7 @@ return foo(42) )", 1, 2), R"( -DUPCLOSURE R0 K0 +DUPCLOSURE R0 K0 ['foo'] MOVE R1 R0 LOADN R2 42 CALL R1 1 -1 @@ -6046,7 +6055,7 @@ return x, y + 1 R"( GETVARARGS R0 2 MOVE R2 R0 -ADDK R3 R1 K0 +ADDK R3 R1 K0 [1] RETURN R2 2 )"); @@ -6093,7 +6102,7 @@ return foo(42) )", 1, 1), R"( -DUPCLOSURE R0 K0 +DUPCLOSURE R0 K0 ['foo'] MOVE R1 R0 LOADN R2 42 CALL R1 1 -1 @@ -6111,7 +6120,7 @@ return foo(42) )", 1, 1), R"( -DUPCLOSURE R0 K0 +DUPCLOSURE R0 K0 ['foo'] LOADN R1 42 RETURN R1 1 )"); @@ -6126,7 +6135,7 @@ return foo(42) )", 1, 2), R"( -DUPCLOSURE R0 K0 +DUPCLOSURE R0 K0 ['foo'] LOADN R1 42 RETURN R1 1 )"); @@ -6142,7 +6151,7 @@ return foo(42) )", 1, 2), R"( -DUPCLOSURE R0 K0 +DUPCLOSURE R0 K0 ['foo'] MOVE R1 R0 LOADN R2 42 CALL R1 1 -1 @@ -6235,7 +6244,7 @@ LOADN R22 0 LOADN R23 3 LOADN R24 0 LOADN R25 0 -LOADK R26 K0 +LOADK R26 K0 [4294967291] LOADN R27 5 LOADN R28 1 LOADN R29 1 @@ -6248,19 +6257,19 @@ LOADN R35 200 LOADN R36 106 LOADN R37 200 LOADN R38 50 -LOADK R39 K1 +LOADK R39 K1 ['number'] LOADN R40 97 LOADN R41 98 LOADN R42 3 -LOADK R43 K2 +LOADK R43 K2 ['boolean'] LOADN R44 0 LOADN R45 1 LOADN R46 8 LOADN R47 1 LOADN R48 101 LOADN R49 2 -LOADK R50 K3 -LOADK R51 K4 +LOADK R50 K3 ['nil'] +LOADK R51 K4 ['string'] RETURN R0 52 )"); } @@ -6283,53 +6292,53 @@ return 0, 2), R"( FASTCALL 2 L0 -GETIMPORT R0 2 +GETIMPORT R0 2 [math.abs] CALL R0 0 1 L0: LOADN R2 1 -FASTCALL2K 18 R2 K3 L1 -LOADK R3 K3 -GETIMPORT R1 5 +FASTCALL2K 18 R2 K3 L1 [true] +LOADK R3 K3 [true] +GETIMPORT R1 5 [math.max] CALL R1 2 1 -L1: LOADK R3 K6 -FASTCALL2K 41 R3 K7 L2 -LOADK R4 K7 -GETIMPORT R2 10 +L1: LOADK R3 K6 ['abc'] +FASTCALL2K 41 R3 K7 L2 [42] +LOADK R4 K7 [42] +GETIMPORT R2 10 [string.byte] CALL R2 2 1 L2: LOADN R4 10 -FASTCALL2K 39 R4 K7 L3 -LOADK R5 K7 -GETIMPORT R3 13 +FASTCALL2K 39 R4 K7 L3 [42] +LOADK R5 K7 [42] +GETIMPORT R3 13 [bit32.rshift] CALL R3 2 1 L3: LOADN R5 1 LOADN R6 2 -LOADK R7 K14 +LOADK R7 K14 ['3'] FASTCALL 34 L4 -GETIMPORT R4 16 +GETIMPORT R4 16 [bit32.extract] CALL R4 3 1 L4: LOADN R6 1 -FASTCALL2K 31 R6 K3 L5 -LOADK R7 K3 -GETIMPORT R5 18 +FASTCALL2K 31 R6 K3 L5 [true] +LOADK R7 K3 [true] +GETIMPORT R5 18 [bit32.bor] CALL R5 2 1 L5: LOADN R7 1 -FASTCALL2K 29 R7 K3 L6 -LOADK R8 K3 -GETIMPORT R6 20 +FASTCALL2K 29 R7 K3 L6 [true] +LOADK R8 K3 [true] +GETIMPORT R6 20 [bit32.band] CALL R6 2 1 L6: LOADN R8 1 -FASTCALL2K 32 R8 K3 L7 -LOADK R9 K3 -GETIMPORT R7 22 +FASTCALL2K 32 R8 K3 L7 [true] +LOADK R9 K3 [true] +GETIMPORT R7 22 [bit32.bxor] CALL R7 2 1 L7: LOADN R9 1 -FASTCALL2K 33 R9 K3 L8 -LOADK R10 K3 -GETIMPORT R8 24 +FASTCALL2K 33 R9 K3 L8 [true] +LOADK R10 K3 [true] +GETIMPORT R8 24 [bit32.btest] CALL R8 2 1 L8: LOADN R10 1 -FASTCALL2K 19 R10 K3 L9 -LOADK R11 K3 -GETIMPORT R9 26 +FASTCALL2K 19 R10 K3 L9 [true] +LOADK R11 K3 [true] +GETIMPORT R9 26 [math.min] CALL R9 2 -1 L9: RETURN R0 -1 )"); @@ -6414,20 +6423,20 @@ end )", 0, 2), R"( -GETTABLEKS R2 R0 K0 -FASTCALL2K 29 R2 K1 L0 -LOADK R3 K1 -GETIMPORT R1 4 +GETTABLEKS R2 R0 K0 ['pendingLanes'] +FASTCALL2K 29 R2 K1 L0 [3221225471] +LOADK R3 K1 [3221225471] +GETIMPORT R1 4 [bit32.band] CALL R1 2 1 -L0: JUMPXEQKN R1 K5 L1 +L0: JUMPXEQKN R1 K5 L1 [0] RETURN R1 1 -L1: FASTCALL2K 29 R1 K6 L2 +L1: FASTCALL2K 29 R1 K6 L2 [1073741824] MOVE R3 R1 -LOADK R4 K6 -GETIMPORT R2 4 +LOADK R4 K6 [1073741824] +GETIMPORT R2 4 [bit32.band] CALL R2 2 1 -L2: JUMPXEQKN R2 K5 L3 -LOADK R2 K6 +L2: JUMPXEQKN R2 K5 L3 [0] +LOADK R2 K6 [1073741824] RETURN R2 1 L3: LOADN R2 0 RETURN R2 1 @@ -6482,9 +6491,9 @@ end )"), R"( MOVE R2 R0 -ADDK R2 R2 K0 +ADDK R2 R2 K0 [0] MOVE R3 R1 -ADDK R1 R1 K0 +ADDK R1 R1 K0 [0] ADD R4 R2 R3 RETURN R4 1 )"); @@ -6543,11 +6552,11 @@ TEST_CASE("MultipleAssignments") R"( LOADNIL R0 LOADNIL R1 -GETIMPORT R2 1 +GETIMPORT R2 1 [f] LOADN R3 1 CALL R2 1 1 MOVE R0 R2 -GETIMPORT R2 1 +GETIMPORT R2 1 [f] LOADN R3 2 CALL R2 1 1 MOVE R1 R2 @@ -6597,8 +6606,8 @@ RETURN R0 0 GETVARARGS R0 4 MOVE R0 R1 MOVE R1 R2 -ADDK R2 R2 K0 -SUBK R3 R3 K0 +ADDK R2 R2 K0 [1] +SUBK R3 R3 K0 [1] RETURN R0 0 )"); @@ -6670,7 +6679,7 @@ RETURN R0 0 R"( GETVARARGS R0 4 LOADN R0 1 -GETIMPORT R4 1 +GETIMPORT R4 1 [foo] CALL R4 0 3 MOVE R1 R4 MOVE R2 R5 @@ -6686,7 +6695,7 @@ RETURN R0 0 R"( GETVARARGS R0 4 LOADN R4 1 -GETIMPORT R6 1 +GETIMPORT R6 1 [foo] CALL R6 0 3 SETTABLE R6 R1 R0 SETTABLE R7 R2 R3 @@ -6704,7 +6713,7 @@ RETURN R0 0 R"( GETVARARGS R0 4 LOADN R0 1 -GETIMPORT R4 1 +GETIMPORT R4 1 [foo] LOADNIL R5 LOADNIL R6 MOVE R1 R4 @@ -6720,7 +6729,7 @@ RETURN R0 0 )"), R"( GETVARARGS R0 2 -ADDK R2 R1 K0 +ADDK R2 R1 K0 [1] SETTABLEN R1 R0 1 SETTABLEN R2 R0 2 RETURN R0 0 @@ -6763,7 +6772,7 @@ RETURN R0 0 )"), R"( GETVARARGS R0 2 -ADDK R2 R1 K0 +ADDK R2 R1 K0 [1] SETTABLEN R1 R0 1 MOVE R1 R2 RETURN R0 0 @@ -6781,11 +6790,11 @@ return bit32.extract(v, 1, 3) )"), R"( GETVARARGS R0 1 -FASTCALL2K 59 R0 K0 L0 +FASTCALL2K 59 R0 K0 L0 [65] MOVE R2 R0 -LOADK R3 K1 -LOADK R4 K2 -GETIMPORT R1 5 +LOADK R3 K1 [1] +LOADK R4 K2 [3] +GETIMPORT R1 5 [bit32.extract] CALL R1 3 -1 L0: RETURN R1 -1 )"); diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index c32f2870f..d7340ce51 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -297,6 +297,8 @@ TEST_CASE("Clear") TEST_CASE("Strings") { + ScopedFastFlag luauStringFormatAnyFix{"LuauStringFormatAnyFix", true}; + runConformance("strings.lua"); } diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index 18e91e1ba..fae8bc4d8 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -2177,8 +2177,6 @@ type C = Packed<(number, X...)> TEST_CASE_FIXTURE(Fixture, "invalid_type_forms") { - ScopedFastFlag luauFixNamedFunctionParse{"LuauFixNamedFunctionParse", true}; - matchParseError("type A = (b: number)", "Expected '->' when parsing function type, got "); matchParseError("type P = () -> T... type B = P<(x: number, y: string)>", "Expected '->' when parsing function type, got '>'"); matchParseError("type F = (T...) -> ()", "Expected '->' when parsing function type, got '>'"); diff --git a/tests/ToString.test.cpp b/tests/ToString.test.cpp index 7ccda8def..7d27437d7 100644 --- a/tests/ToString.test.cpp +++ b/tests/ToString.test.cpp @@ -83,7 +83,6 @@ TEST_CASE_FIXTURE(Fixture, "table_respects_use_line_break") ToStringOptions opts; opts.useLineBreaks = true; - opts.DEPRECATED_indent = true; //clang-format off CHECK_EQ("{|\n" @@ -97,7 +96,6 @@ TEST_CASE_FIXTURE(Fixture, "table_respects_use_line_break") TEST_CASE_FIXTURE(Fixture, "nil_or_nil_is_nil_not_question_mark") { - ScopedFastFlag sff("LuauSerializeNilUnionAsNil", true); CheckResult result = check(R"( type nil_ty = nil | nil local a : nil_ty = nil @@ -109,7 +107,6 @@ TEST_CASE_FIXTURE(Fixture, "nil_or_nil_is_nil_not_question_mark") TEST_CASE_FIXTURE(Fixture, "long_disjunct_of_nil_is_nil_not_question_mark") { - ScopedFastFlag sff("LuauSerializeNilUnionAsNil", true); CheckResult result = check(R"( type nil_ty = nil | nil | nil | nil | nil local a : nil_ty = nil diff --git a/tests/TypeReduction.test.cpp b/tests/TypeReduction.test.cpp index c629b3e34..f2d7b027b 100644 --- a/tests/TypeReduction.test.cpp +++ b/tests/TypeReduction.test.cpp @@ -88,6 +88,96 @@ TEST_CASE_FIXTURE(ReductionFixture, "cartesian_product_is_zero") CHECK(ty); } +TEST_CASE_FIXTURE(ReductionFixture, "stress_test_recursion_limits") +{ + TypeId ty = arena.addType(IntersectionType{{builtinTypes->numberType, builtinTypes->stringType}}); + for (size_t i = 0; i < 20'000; ++i) + { + TableType table; + table.state = TableState::Sealed; + table.props["x"] = {ty}; + ty = arena.addType(IntersectionType{{arena.addType(table), arena.addType(table)}}); + } + + CHECK(!reduction.reduce(ty)); +} + +TEST_CASE_FIXTURE(ReductionFixture, "caching") +{ + SUBCASE("free_tables") + { + TypeId ty1 = arena.addType(TableType{}); + getMutable(ty1)->state = TableState::Free; + getMutable(ty1)->props["x"] = {builtinTypes->stringType}; + + TypeId ty2 = arena.addType(TableType{}); + getMutable(ty2)->state = TableState::Sealed; + + TypeId intersectionTy = arena.addType(IntersectionType{{ty1, ty2}}); + + ToStringOptions opts; + opts.exhaustive = true; + + CHECK("{- x: string -} & {| |}" == toString(reductionof(intersectionTy))); + + getMutable(ty1)->state = TableState::Sealed; + CHECK("{| x: string |}" == toString(reductionof(intersectionTy))); + } + + SUBCASE("unsealed_tables") + { + TypeId ty1 = arena.addType(TableType{}); + getMutable(ty1)->state = TableState::Unsealed; + getMutable(ty1)->props["x"] = {builtinTypes->stringType}; + + TypeId ty2 = arena.addType(TableType{}); + getMutable(ty2)->state = TableState::Sealed; + + TypeId intersectionTy = arena.addType(IntersectionType{{ty1, ty2}}); + + ToStringOptions opts; + opts.exhaustive = true; + + CHECK("{| x: string |}" == toString(reductionof(intersectionTy))); + + getMutable(ty1)->state = TableState::Sealed; + CHECK("{| x: string |}" == toString(reductionof(intersectionTy))); + } + + SUBCASE("free_types") + { + TypeId ty1 = arena.freshType(nullptr); + TypeId ty2 = arena.addType(TableType{}); + getMutable(ty2)->state = TableState::Sealed; + + TypeId intersectionTy = arena.addType(IntersectionType{{ty1, ty2}}); + + ToStringOptions opts; + opts.exhaustive = true; + + CHECK("a & {| |}" == toString(reductionof(intersectionTy))); + + *asMutable(ty1) = BoundType{ty2}; + CHECK("{| |}" == toString(reductionof(intersectionTy))); + } + + SUBCASE("we_can_see_that_the_cache_works_if_we_mutate_a_normally_not_mutated_type") + { + TypeId ty1 = arena.addType(BoundType{builtinTypes->stringType}); + TypeId ty2 = builtinTypes->numberType; + + TypeId intersectionTy = arena.addType(IntersectionType{{ty1, ty2}}); + + ToStringOptions opts; + opts.exhaustive = true; + + CHECK("never" == toString(reductionof(intersectionTy))); // Bound & number ~ never + + *asMutable(ty1) = BoundType{ty2}; + CHECK("never" == toString(reductionof(intersectionTy))); // Bound & number ~ number, but the cache is `never`. + } +} // caching + TEST_CASE_FIXTURE(ReductionFixture, "intersections_without_negations") { SUBCASE("string_and_string") @@ -359,6 +449,34 @@ TEST_CASE_FIXTURE(ReductionFixture, "intersections_without_negations") TypeId ty = reductionof("{ p: string } & { p: string, [string]: number }"); CHECK("{| [string]: number, p: string |}" == toString(ty)); } + + SUBCASE("fresh_type_and_string") + { + TypeId freshTy = arena.freshType(nullptr); + TypeId ty = reductionof(arena.addType(IntersectionType{{freshTy, builtinTypes->stringType}})); + CHECK("a & string" == toString(ty)); + } + + SUBCASE("string_and_fresh_type") + { + TypeId freshTy = arena.freshType(nullptr); + TypeId ty = reductionof(arena.addType(IntersectionType{{builtinTypes->stringType, freshTy}})); + CHECK("a & string" == toString(ty)); + } + + SUBCASE("generic_and_string") + { + TypeId genericTy = arena.addType(GenericType{"G"}); + TypeId ty = reductionof(arena.addType(IntersectionType{{genericTy, builtinTypes->stringType}})); + CHECK("G & string" == toString(ty)); + } + + SUBCASE("string_and_generic") + { + TypeId genericTy = arena.addType(GenericType{"G"}); + TypeId ty = reductionof(arena.addType(IntersectionType{{builtinTypes->stringType, genericTy}})); + CHECK("G & string" == toString(ty)); + } } // intersections_without_negations TEST_CASE_FIXTURE(ReductionFixture, "intersections_with_negations") @@ -1232,18 +1350,4 @@ TEST_CASE_FIXTURE(ReductionFixture, "cycles") } } -TEST_CASE_FIXTURE(ReductionFixture, "stress_test_recursion_limits") -{ - TypeId ty = arena.addType(IntersectionType{{builtinTypes->numberType, builtinTypes->stringType}}); - for (size_t i = 0; i < 20'000; ++i) - { - TableType table; - table.state = TableState::Sealed; - table.props["x"] = {ty}; - ty = arena.addType(IntersectionType{{arena.addType(table), arena.addType(table)}}); - } - - CHECK(!reduction.reduce(ty)); -} - TEST_SUITE_END(); diff --git a/tests/conformance/strings.lua b/tests/conformance/strings.lua index 61bac7266..59d342189 100644 --- a/tests/conformance/strings.lua +++ b/tests/conformance/strings.lua @@ -164,6 +164,16 @@ local ud = newproxy(true) getmetatable(ud).__tostring = function() return "good" end assert(string.format("%*", ud) == "good") +assert(string.format(string.rep("%*", 100), table.unpack(table.create(100, 1))) == string.rep("1", 100)) + +do + local a = "1234567890" + a = string.format("%*%*%*%*%*", a, a, a, a, a) + a = string.format("%*%*%*%*%*", a, a, a, a, a) + a = string.format("%*%*%*%*%*", a, a, a, a, a) + assert(a == string.rep("1234567890", 125)) +end + assert(pcall(function() string.format("%#*", "bad form") end) == false) diff --git a/tests/conformance/tables.lua b/tests/conformance/tables.lua index 7ae80cc4c..4b47ed26a 100644 --- a/tests/conformance/tables.lua +++ b/tests/conformance/tables.lua @@ -686,4 +686,18 @@ do assert(pcall(table.clear, table.freeze({})) == false) end +-- check that namecall lookup doesn't give up on entries missing from cached slot position +do + for i = 1,10 do + local t = setmetatable({}, { __index = { foo = 1 }}) + + assert(t.foo == 1) + + t[-i] = 2 + t.foo = function(t, i) return -i end + + assert(t:foo(i) == -i) + end +end + return"OK" diff --git a/tools/faillist.txt b/tools/faillist.txt index f336bb222..3fcd4200a 100644 --- a/tools/faillist.txt +++ b/tools/faillist.txt @@ -2,8 +2,6 @@ AnnotationTests.corecursive_types_error_on_tight_loop AnnotationTests.duplicate_type_param_name AnnotationTests.for_loop_counter_annotation_is_checked AnnotationTests.generic_aliases_are_cloned_properly -AnnotationTests.instantiation_clone_has_to_follow -AnnotationTests.luau_print_is_not_special_without_the_flag AnnotationTests.occurs_check_on_cyclic_intersection_type AnnotationTests.occurs_check_on_cyclic_union_type AnnotationTests.too_many_type_params @@ -87,6 +85,8 @@ DefinitionTests.class_definition_overload_metamethods DefinitionTests.class_definition_string_props DefinitionTests.declaring_generic_functions DefinitionTests.definition_file_classes +DefinitionTests.definitions_symbols_are_generated_for_recursively_referenced_types +DefinitionTests.single_class_type_identity_in_global_types FrontendTest.environments FrontendTest.it_should_be_safe_to_stringify_errors_when_full_type_graph_is_discarded FrontendTest.nocheck_cycle_used_by_checked @@ -140,6 +140,7 @@ NonstrictModeTests.parameters_having_type_any_are_optional NonstrictModeTests.table_dot_insert_and_recursive_calls NonstrictModeTests.table_props_are_any Normalize.cyclic_table_normalizes_sensibly +Normalize.negations_of_classes ParseErrorRecovery.generic_type_list_recovery ParseErrorRecovery.recovery_of_parenthesized_expressions ParserTests.parse_nesting_based_end_detection_failsafe_earlier @@ -165,8 +166,10 @@ RefinementTest.call_an_incompatible_function_after_using_typeguard RefinementTest.correctly_lookup_property_whose_base_was_previously_refined2 RefinementTest.discriminate_on_properties_of_disjoint_tables_where_that_property_is_true_or_false RefinementTest.discriminate_tag +RefinementTest.eliminate_subclasses_of_instance RefinementTest.else_with_no_explicit_expression_should_also_refine_the_tagged_union RefinementTest.falsiness_of_TruthyPredicate_narrows_into_nil +RefinementTest.narrow_from_subclasses_of_instance_or_string_or_vector3 RefinementTest.narrow_property_of_a_bounded_variable RefinementTest.nonoptional_type_can_narrow_to_nil_if_sense_is_true RefinementTest.refine_a_property_not_to_be_nil_through_an_intersection_table @@ -176,6 +179,7 @@ RefinementTest.type_guard_narrowed_into_nothingness RefinementTest.type_narrow_for_all_the_userdata RefinementTest.type_narrow_to_vector RefinementTest.typeguard_cast_free_table_to_vector +RefinementTest.typeguard_cast_instance_or_vector3_to_vector RefinementTest.typeguard_in_assert_position RefinementTest.typeguard_narrows_for_table RefinementTest.x_as_any_if_x_is_instance_elseif_x_is_table @@ -457,6 +461,10 @@ TypePackTests.type_pack_type_parameters TypePackTests.unify_variadic_tails_in_arguments TypePackTests.unify_variadic_tails_in_arguments_free TypePackTests.variadic_packs +TypeReductionTests.discriminable_unions +TypeReductionTests.intersections_with_negations +TypeReductionTests.negations +TypeReductionTests.unions_with_negations TypeSingletons.error_detailed_tagged_union_mismatch_bool TypeSingletons.error_detailed_tagged_union_mismatch_string TypeSingletons.function_call_with_singletons diff --git a/tools/heapgraph.py b/tools/heapgraph.py index b8dc207f0..17ce7a40a 100644 --- a/tools/heapgraph.py +++ b/tools/heapgraph.py @@ -1,4 +1,4 @@ -#!/usr/bin/python +#!/usr/bin/python3 # This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details # Given two heap snapshots (A & B), this tool performs reachability analysis on new objects allocated in B diff --git a/tools/heapstat.py b/tools/heapstat.py index 7337aa446..d9fd839a7 100644 --- a/tools/heapstat.py +++ b/tools/heapstat.py @@ -1,4 +1,4 @@ -#!/usr/bin/python +#!/usr/bin/python3 # This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details # Given a heap snapshot, this tool gathers basic statistics about the allocated objects diff --git a/tools/lvmexecute_split.py b/tools/lvmexecute_split.py index f4a78960b..16de45dcc 100644 --- a/tools/lvmexecute_split.py +++ b/tools/lvmexecute_split.py @@ -1,4 +1,4 @@ -#!/usr/bin/python +#!/usr/bin/python3 # This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details # This code can be used to split lvmexecute.cpp VM switch into separate functions for use as native code generation fallbacks @@ -34,7 +34,7 @@ function = "" signature = "" -includeInsts = ["LOP_NEWCLOSURE", "LOP_NAMECALL", "LOP_FORGPREP", "LOP_GETVARARGS", "LOP_DUPCLOSURE", "LOP_PREPVARARGS", "LOP_COVERAGE", "LOP_BREAK", "LOP_GETGLOBAL", "LOP_SETGLOBAL", "LOP_GETTABLEKS", "LOP_SETTABLEKS"] +includeInsts = ["LOP_NEWCLOSURE", "LOP_NAMECALL", "LOP_FORGPREP", "LOP_GETVARARGS", "LOP_DUPCLOSURE", "LOP_PREPVARARGS", "LOP_BREAK", "LOP_GETGLOBAL", "LOP_SETGLOBAL", "LOP_GETTABLEKS", "LOP_SETTABLEKS"] state = 0 diff --git a/tools/numprint.py b/tools/numprint.py index 47ad36d9c..4fb64e628 100644 --- a/tools/numprint.py +++ b/tools/numprint.py @@ -1,4 +1,4 @@ -#!/usr/bin/python +#!/usr/bin/python3 # This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details # This code can be used to generate power tables for Schubfach algorithm (see lnumprint.cpp) diff --git a/tools/patchtests.py b/tools/patchtests.py index 56970c9f6..82d8364cd 100644 --- a/tools/patchtests.py +++ b/tools/patchtests.py @@ -1,4 +1,4 @@ -#!/usr/bin/python +#!/usr/bin/python3 # This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details # This code can be used to patch Compiler.test.cpp following bytecode changes, based on error output diff --git a/tools/perfgraph.py b/tools/perfgraph.py index eb6b68ce1..94c57cc77 100644 --- a/tools/perfgraph.py +++ b/tools/perfgraph.py @@ -1,4 +1,4 @@ -#!/usr/bin/python +#!/usr/bin/python3 # This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details # Given a profile dump, this tool generates a flame graph based on the stacks listed in the profile diff --git a/tools/perfstat.py b/tools/perfstat.py index e5cfd1173..1af2473b5 100644 --- a/tools/perfstat.py +++ b/tools/perfstat.py @@ -1,4 +1,4 @@ -#!/usr/bin/python +#!/usr/bin/python3 # This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details # Given a profile dump, this tool displays top functions based on the stacks listed in the profile diff --git a/tools/stack-usage-reporter.py b/tools/stack-usage-reporter.py index 91e74887d..9f11b6504 100644 --- a/tools/stack-usage-reporter.py +++ b/tools/stack-usage-reporter.py @@ -1,4 +1,4 @@ -#!/usr/bin/python +#!/usr/bin/python3 # This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details # The purpose of this script is to analyze disassembly generated by objdump or diff --git a/tools/test_dcr.py b/tools/test_dcr.py index 6d553b648..d30490b30 100644 --- a/tools/test_dcr.py +++ b/tools/test_dcr.py @@ -1,3 +1,4 @@ +#!/usr/bin/python3 # This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details import argparse diff --git a/tools/tracegraph.py b/tools/tracegraph.py index a46423e7e..1cdd32d6a 100644 --- a/tools/tracegraph.py +++ b/tools/tracegraph.py @@ -1,4 +1,4 @@ -#!/usr/bin/python +#!/usr/bin/python3 # This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details # Given a trace event file, this tool generates a flame graph based on the event scopes present in the file From eec289ad1bf323cec82d7aa9bb9c86823d25529c Mon Sep 17 00:00:00 2001 From: Vyacheslav Egorov Date: Fri, 20 Jan 2023 14:02:39 +0200 Subject: [PATCH 29/66] Sync to upstream/release/560 --- Analysis/include/Luau/AstQuery.h | 2 +- Analysis/include/Luau/Autocomplete.h | 3 +- .../include/Luau/ConstraintGraphBuilder.h | 16 +- Analysis/include/Luau/ConstraintSolver.h | 2 +- Analysis/include/Luau/Scope.h | 2 + Analysis/include/Luau/ToString.h | 2 + Analysis/include/Luau/Type.h | 8 +- Analysis/include/Luau/TypeReduction.h | 31 +- Analysis/include/Luau/VisitType.h | 5 +- Analysis/src/Anyification.cpp | 1 + Analysis/src/AstQuery.cpp | 26 +- Analysis/src/Autocomplete.cpp | 87 +- Analysis/src/BuiltinDefinitions.cpp | 5 +- Analysis/src/Clone.cpp | 2 + Analysis/src/ConstraintGraphBuilder.cpp | 151 ++-- Analysis/src/ConstraintSolver.cpp | 33 +- Analysis/src/Instantiation.cpp | 1 + Analysis/src/Linter.cpp | 45 +- Analysis/src/Normalize.cpp | 5 +- Analysis/src/Quantify.cpp | 21 +- Analysis/src/ToString.cpp | 23 +- Analysis/src/Type.cpp | 12 +- Analysis/src/TypeChecker2.cpp | 128 +-- Analysis/src/TypeInfer.cpp | 227 ++--- Analysis/src/TypePack.cpp | 6 +- Analysis/src/TypeReduction.cpp | 546 +++++++----- Analysis/src/Unifiable.cpp | 20 +- Analysis/src/Unifier.cpp | 109 ++- Ast/include/Luau/Ast.h | 3 +- Ast/src/Ast.cpp | 5 +- Ast/src/Parser.cpp | 2 +- CodeGen/src/CodeGen.cpp | 2 +- CodeGen/src/EmitCommonX64.cpp | 2 +- CodeGen/src/EmitInstructionX64.cpp | 20 +- CodeGen/src/EmitInstructionX64.h | 2 +- CodeGen/src/IrBuilder.cpp | 563 +++++++++++++ CodeGen/src/IrBuilder.h | 63 ++ CodeGen/src/IrData.h | 9 +- CodeGen/src/IrDump.cpp | 49 +- CodeGen/src/IrDump.h | 2 + CodeGen/src/IrTranslation.cpp | 780 ++++++++++++++++++ CodeGen/src/IrTranslation.h | 58 ++ CodeGen/src/IrUtils.h | 1 + Common/include/Luau/Common.h | 5 +- {Ast => Common}/include/Luau/DenseHash.h | 0 Sources.cmake | 6 +- VM/src/lapi.cpp | 4 +- fuzz/luau.proto | 5 + fuzz/protoprint.cpp | 24 + tests/Autocomplete.test.cpp | 121 ++- tests/Conformance.test.cpp | 27 + tests/Fixture.cpp | 8 +- tests/Linter.test.cpp | 5 - tests/Normalize.test.cpp | 5 - tests/ToDot.test.cpp | 81 +- tests/ToString.test.cpp | 2 - tests/TypeInfer.builtins.test.cpp | 171 +++- tests/TypeInfer.classes.test.cpp | 12 - tests/TypeInfer.definitions.test.cpp | 2 - tests/TypeInfer.functions.test.cpp | 8 - tests/TypeInfer.intersectionTypes.test.cpp | 82 -- tests/TypeInfer.modules.test.cpp | 2 - tests/TypeInfer.negations.test.cpp | 3 - tests/TypeInfer.primitives.test.cpp | 11 - tests/TypeInfer.refinements.test.cpp | 32 +- tests/TypeInfer.tables.test.cpp | 20 +- tests/TypeInfer.test.cpp | 10 - tests/TypeInfer.tryUnify.test.cpp | 30 +- tests/TypeInfer.typePacks.cpp | 4 - tests/TypeInfer.unionTypes.test.cpp | 45 - tests/TypeReduction.test.cpp | 522 ++++++------ tests/conformance/basic.lua | 15 + tests/conformance/ndebug_upvalues.lua | 13 + tools/faillist.txt | 36 +- 74 files changed, 3174 insertions(+), 1217 deletions(-) create mode 100644 CodeGen/src/IrBuilder.cpp create mode 100644 CodeGen/src/IrBuilder.h create mode 100644 CodeGen/src/IrTranslation.cpp create mode 100644 CodeGen/src/IrTranslation.h rename {Ast => Common}/include/Luau/DenseHash.h (100%) create mode 100644 tests/conformance/ndebug_upvalues.lua diff --git a/Analysis/include/Luau/AstQuery.h b/Analysis/include/Luau/AstQuery.h index bf7384623..aa7ef8d3e 100644 --- a/Analysis/include/Luau/AstQuery.h +++ b/Analysis/include/Luau/AstQuery.h @@ -64,7 +64,7 @@ struct ExprOrLocal }; std::vector findAncestryAtPositionForAutocomplete(const SourceModule& source, Position pos); -std::vector findAstAncestryOfPosition(const SourceModule& source, Position pos); +std::vector findAstAncestryOfPosition(const SourceModule& source, Position pos, bool includeTypes = false); AstNode* findNodeAtPosition(const SourceModule& source, Position pos); AstExpr* findExprAtPosition(const SourceModule& source, Position pos); ScopePtr findScopeAtPosition(const Module& module, Position pos); diff --git a/Analysis/include/Luau/Autocomplete.h b/Analysis/include/Luau/Autocomplete.h index a4101e162..618325777 100644 --- a/Analysis/include/Luau/Autocomplete.h +++ b/Analysis/include/Luau/Autocomplete.h @@ -89,7 +89,8 @@ struct AutocompleteResult }; using ModuleName = std::string; -using StringCompletionCallback = std::function(std::string tag, std::optional ctx)>; +using StringCompletionCallback = + std::function(std::string tag, std::optional ctx, std::optional contents)>; AutocompleteResult autocomplete(Frontend& frontend, const ModuleName& moduleName, Position position, StringCompletionCallback callback); diff --git a/Analysis/include/Luau/ConstraintGraphBuilder.h b/Analysis/include/Luau/ConstraintGraphBuilder.h index 3a67610a8..a1caf85af 100644 --- a/Analysis/include/Luau/ConstraintGraphBuilder.h +++ b/Analysis/include/Luau/ConstraintGraphBuilder.h @@ -240,20 +240,28 @@ struct ConstraintGraphBuilder * Resolves a type from its AST annotation. * @param scope the scope that the type annotation appears within. * @param ty the AST annotation to resolve. - * @param topLevel whether the annotation is a "top-level" annotation. + * @param inTypeArguments whether we are resolving a type that's contained within type arguments, `<...>`. * @return the type of the AST annotation. **/ - TypeId resolveType(const ScopePtr& scope, AstType* ty, bool topLevel = false); + TypeId resolveType(const ScopePtr& scope, AstType* ty, bool inTypeArguments); /** * Resolves a type pack from its AST annotation. * @param scope the scope that the type annotation appears within. * @param tp the AST annotation to resolve. + * @param inTypeArguments whether we are resolving a type that's contained within type arguments, `<...>`. * @return the type pack of the AST annotation. **/ - TypePackId resolveTypePack(const ScopePtr& scope, AstTypePack* tp); + TypePackId resolveTypePack(const ScopePtr& scope, AstTypePack* tp, bool inTypeArguments); - TypePackId resolveTypePack(const ScopePtr& scope, const AstTypeList& list); + /** + * Resolves a type pack from its AST annotation. + * @param scope the scope that the type annotation appears within. + * @param list the AST annotation to resolve. + * @param inTypeArguments whether we are resolving a type that's contained within type arguments, `<...>`. + * @return the type pack of the AST annotation. + **/ + TypePackId resolveTypePack(const ScopePtr& scope, const AstTypeList& list, bool inTypeArguments); std::vector> createGenerics(const ScopePtr& scope, AstArray generics); std::vector> createGenericPacks(const ScopePtr& scope, AstArray packs); diff --git a/Analysis/include/Luau/ConstraintSolver.h b/Analysis/include/Luau/ConstraintSolver.h index 5c235a354..66d3e8f3f 100644 --- a/Analysis/include/Luau/ConstraintSolver.h +++ b/Analysis/include/Luau/ConstraintSolver.h @@ -111,7 +111,7 @@ struct ConstraintSolver bool tryDispatch(const FunctionCallConstraint& c, NotNull constraint); bool tryDispatch(const PrimitiveTypeConstraint& c, NotNull constraint); bool tryDispatch(const HasPropConstraint& c, NotNull constraint); - bool tryDispatch(const SetPropConstraint& c, NotNull constraint); + bool tryDispatch(const SetPropConstraint& c, NotNull constraint, bool force); bool tryDispatch(const SingletonOrTopTypeConstraint& c, NotNull constraint); // for a, ... in some_table do diff --git a/Analysis/include/Luau/Scope.h b/Analysis/include/Luau/Scope.h index 797c9cb04..a8f83e2f7 100644 --- a/Analysis/include/Luau/Scope.h +++ b/Analysis/include/Luau/Scope.h @@ -43,6 +43,8 @@ struct Scope std::unordered_map exportedTypeBindings; std::unordered_map privateTypeBindings; std::unordered_map typeAliasLocations; + std::unordered_map typeAliasNameLocations; + std::unordered_map importedModules; // Mapping from the name in the require statement to the internal moduleName. std::unordered_map> importedTypeBindings; DenseHashSet builtinTypeNames{""}; diff --git a/Analysis/include/Luau/ToString.h b/Analysis/include/Luau/ToString.h index 461a8fffb..7758e8f99 100644 --- a/Analysis/include/Luau/ToString.h +++ b/Analysis/include/Luau/ToString.h @@ -132,7 +132,9 @@ std::optional getFunctionNameAsString(const AstExpr& expr); // It could be useful to see the text representation of a type during a debugging session instead of exploring the content of the class // These functions will dump the type to stdout and can be evaluated in Watch/Immediate windows or as gdb/lldb expression std::string dump(TypeId ty); +std::string dump(const std::optional& ty); std::string dump(TypePackId ty); +std::string dump(const std::optional& ty); std::string dump(const Constraint& c); std::string dump(const std::shared_ptr& scope, const char* name); diff --git a/Analysis/include/Luau/Type.h b/Analysis/include/Luau/Type.h index 734d40eac..4962274c9 100644 --- a/Analysis/include/Luau/Type.h +++ b/Analysis/include/Luau/Type.h @@ -375,6 +375,7 @@ struct TableType std::vector instantiatedTypeParams; std::vector instantiatedTypePackParams; ModuleName definitionModuleName; + Location definitionLocation; std::optional boundTo; Tags tags; @@ -656,8 +657,11 @@ struct BuiltinTypes const TypeId unknownType; const TypeId neverType; const TypeId errorType; - const TypeId falsyType; // No type binding! - const TypeId truthyType; // No type binding! + const TypeId falsyType; + const TypeId truthyType; + + const TypeId optionalNumberType; + const TypeId optionalStringType; const TypePackId anyTypePack; const TypePackId neverTypePack; diff --git a/Analysis/include/Luau/TypeReduction.h b/Analysis/include/Luau/TypeReduction.h index a7cec9468..7cc169781 100644 --- a/Analysis/include/Luau/TypeReduction.h +++ b/Analysis/include/Luau/TypeReduction.h @@ -9,11 +9,28 @@ namespace Luau { -/// If it's desirable to allocate into a different arena than the TypeReduction instance you have, you will need -/// to create a temporary TypeReduction in that case. This is because TypeReduction caches the reduced type. +namespace detail +{ +template +struct ReductionContext +{ + T type = nullptr; + bool irreducible = false; +}; +} // namespace detail + +struct TypeReductionOptions +{ + /// If it's desirable for type reduction to allocate into a different arena than the TypeReduction instance you have, you will need + /// to create a temporary TypeReduction in that case, and set [`TypeReductionOptions::allowTypeReductionsFromOtherArenas`] to true. + /// This is because TypeReduction caches the reduced type. + bool allowTypeReductionsFromOtherArenas = false; +}; + struct TypeReduction { - explicit TypeReduction(NotNull arena, NotNull builtinTypes, NotNull handle); + explicit TypeReduction( + NotNull arena, NotNull builtinTypes, NotNull handle, const TypeReductionOptions& opts = {}); std::optional reduce(TypeId ty); std::optional reduce(TypePackId tp); @@ -23,12 +40,10 @@ struct TypeReduction NotNull arena; NotNull builtinTypes; NotNull handle; + TypeReductionOptions options; - DenseHashMap cachedTypes{nullptr}; - DenseHashMap cachedTypePacks{nullptr}; - - std::pair, bool> reduceImpl(TypeId ty); - std::pair, bool> reduceImpl(TypePackId tp); + DenseHashMap> memoizedTypes{nullptr}; + DenseHashMap> memoizedTypePacks{nullptr}; // Computes an *estimated length* of the cartesian product of the given type. size_t cartesianProductSize(TypeId ty) const; diff --git a/Analysis/include/Luau/VisitType.h b/Analysis/include/Luau/VisitType.h index fdac65856..e0ab12e7e 100644 --- a/Analysis/include/Luau/VisitType.h +++ b/Analysis/include/Luau/VisitType.h @@ -318,7 +318,10 @@ struct GenericTypeVisitor } } else if (auto ntv = get(ty)) - visit(ty, *ntv); + { + if (visit(ty, *ntv)) + traverse(ntv->ty); + } else if (!FFlag::LuauCompleteVisitor) return visit_detail::unsee(seen, ty); else diff --git a/Analysis/src/Anyification.cpp b/Analysis/src/Anyification.cpp index e0ddeacf2..15dd25cc5 100644 --- a/Analysis/src/Anyification.cpp +++ b/Analysis/src/Anyification.cpp @@ -59,6 +59,7 @@ TypeId Anyification::clean(TypeId ty) { TableType clone = TableType{ttv->props, ttv->indexer, ttv->level, TableState::Sealed}; clone.definitionModuleName = ttv->definitionModuleName; + clone.definitionLocation = ttv->definitionLocation; clone.name = ttv->name; clone.syntheticName = ttv->syntheticName; clone.tags = ttv->tags; diff --git a/Analysis/src/AstQuery.cpp b/Analysis/src/AstQuery.cpp index ffab734ab..e95b0017f 100644 --- a/Analysis/src/AstQuery.cpp +++ b/Analysis/src/AstQuery.cpp @@ -12,6 +12,7 @@ #include LUAU_FASTFLAG(LuauCompleteTableKeysBetter); +LUAU_FASTFLAGVARIABLE(SupportTypeAliasGoToDeclaration, false); namespace Luau { @@ -183,14 +184,31 @@ struct FindFullAncestry final : public AstVisitor std::vector nodes; Position pos; Position documentEnd; + bool includeTypes = false; - explicit FindFullAncestry(Position pos, Position documentEnd) + explicit FindFullAncestry(Position pos, Position documentEnd, bool includeTypes = false) : pos(pos) , documentEnd(documentEnd) + , includeTypes(includeTypes) { } - bool visit(AstNode* node) + bool visit(AstType* type) override + { + if (FFlag::SupportTypeAliasGoToDeclaration) + { + if (includeTypes) + return visit(static_cast(type)); + else + return false; + } + else + { + return AstVisitor::visit(type); + } + } + + bool visit(AstNode* node) override { if (node->location.contains(pos)) { @@ -220,13 +238,13 @@ std::vector findAncestryAtPositionForAutocomplete(const SourceModule& return finder.ancestry; } -std::vector findAstAncestryOfPosition(const SourceModule& source, Position pos) +std::vector findAstAncestryOfPosition(const SourceModule& source, Position pos, bool includeTypes) { const Position end = source.root->location.end; if (pos > end) pos = end; - FindFullAncestry finder(pos, end); + FindFullAncestry finder(pos, end, includeTypes); source.root->visit(&finder); return finder.nodes; } diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index 6fab97d52..4e5403f87 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -14,6 +14,9 @@ LUAU_FASTFLAGVARIABLE(LuauCompleteTableKeysBetter, false); LUAU_FASTFLAGVARIABLE(LuauFixAutocompleteInIf, false); +LUAU_FASTFLAGVARIABLE(LuauFixAutocompleteInWhile, false); +LUAU_FASTFLAGVARIABLE(LuauFixAutocompleteInFor, false); +LUAU_FASTFLAGVARIABLE(LuauAutocompleteStringContent, false); static const std::unordered_set kStatementStartingKeywords = { "while", "if", "local", "repeat", "function", "do", "for", "return", "break", "continue", "type", "export"}; @@ -1263,6 +1266,26 @@ static bool isSimpleInterpolatedString(const AstNode* node) return interpString != nullptr && interpString->expressions.size == 0; } +static std::optional getStringContents(const AstNode* node) +{ + if (!FFlag::LuauAutocompleteStringContent) + return std::nullopt; + + if (const AstExprConstantString* string = node->as()) + { + return std::string(string->value.data, string->value.size); + } + else if (const AstExprInterpString* interpString = node->as(); interpString && interpString->expressions.size == 0) + { + LUAU_ASSERT(interpString->strings.size == 1); + return std::string(interpString->strings.data->data, interpString->strings.data->size); + } + else + { + return std::nullopt; + } +} + static std::optional autocompleteStringParams(const SourceModule& sourceModule, const ModulePtr& module, const std::vector& nodes, Position position, StringCompletionCallback callback) { @@ -1295,10 +1318,12 @@ static std::optional autocompleteStringParams(const Source return std::nullopt; } + std::optional candidateString = getStringContents(nodes.back()); + auto performCallback = [&](const FunctionType* funcType) -> std::optional { for (const std::string& tag : funcType->tags) { - if (std::optional ret = callback(tag, getMethodContainingClass(module, candidate->func))) + if (std::optional ret = callback(tag, getMethodContainingClass(module, candidate->func), candidateString)) { return ret; } @@ -1329,6 +1354,15 @@ static std::optional autocompleteStringParams(const Source return std::nullopt; } +static AutocompleteResult autocompleteWhileLoopKeywords(std::vector ancestry) +{ + AutocompleteEntryMap ret; + ret["do"] = {AutocompleteEntryKind::Keyword}; + ret["and"] = {AutocompleteEntryKind::Keyword}; + ret["or"] = {AutocompleteEntryKind::Keyword}; + return {std::move(ret), std::move(ancestry), AutocompleteContext::Keyword}; +} + static AutocompleteResult autocomplete(const SourceModule& sourceModule, const ModulePtr& module, NotNull builtinTypes, TypeArena* typeArena, Scope* globalScope, Position position, StringCompletionCallback callback) { @@ -1387,13 +1421,24 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M { if (!statFor->hasDo || position < statFor->doLocation.begin) { - if (!statFor->from->is() && !statFor->to->is() && (!statFor->step || !statFor->step->is())) - return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Keyword}; + if (FFlag::LuauFixAutocompleteInFor) + { + if (statFor->from->location.containsClosed(position) || statFor->to->location.containsClosed(position) || + (statFor->step && statFor->step->location.containsClosed(position))) + return autocompleteExpression(sourceModule, *module, builtinTypes, typeArena, ancestry, position); - if (statFor->from->location.containsClosed(position) || statFor->to->location.containsClosed(position) || - (statFor->step && statFor->step->location.containsClosed(position))) - return autocompleteExpression(sourceModule, *module, builtinTypes, typeArena, ancestry, position); + if (!statFor->from->is() && !statFor->to->is() && (!statFor->step || !statFor->step->is())) + return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Keyword}; + } + else + { + if (!statFor->from->is() && !statFor->to->is() && (!statFor->step || !statFor->step->is())) + return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Keyword}; + if (statFor->from->location.containsClosed(position) || statFor->to->location.containsClosed(position) || + (statFor->step && statFor->step->location.containsClosed(position))) + return autocompleteExpression(sourceModule, *module, builtinTypes, typeArena, ancestry, position); + } return {}; } @@ -1443,7 +1488,16 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M else if (AstStatWhile* statWhile = parent->as(); node->is() && statWhile) { if (!statWhile->hasDo && !statWhile->condition->is() && position > statWhile->condition->location.end) - return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Keyword}; + { + if (FFlag::LuauFixAutocompleteInWhile) + { + return autocompleteWhileLoopKeywords(ancestry); + } + else + { + return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Keyword}; + } + } if (!statWhile->hasDo || position < statWhile->doLocation.begin) return autocompleteExpression(sourceModule, *module, builtinTypes, typeArena, ancestry, position); @@ -1452,9 +1506,20 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M return {autocompleteStatement(sourceModule, *module, ancestry, position), ancestry, AutocompleteContext::Statement}; } - else if (AstStatWhile* statWhile = extractStat(ancestry); statWhile && !statWhile->hasDo) - return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Keyword}; - + else if (AstStatWhile* statWhile = extractStat(ancestry); + FFlag::LuauFixAutocompleteInWhile ? (statWhile && (!statWhile->hasDo || statWhile->doLocation.containsClosed(position)) && + statWhile->condition && !statWhile->condition->location.containsClosed(position)) + : (statWhile && !statWhile->hasDo)) + { + if (FFlag::LuauFixAutocompleteInWhile) + { + return autocompleteWhileLoopKeywords(ancestry); + } + else + { + return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Keyword}; + } + } else if (AstStatIf* statIf = node->as(); statIf && !statIf->elseLocation.has_value()) { return {{{"else", AutocompleteEntry{AutocompleteEntryKind::Keyword}}, {"elseif", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, @@ -1468,7 +1533,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M return {{{"then", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Keyword}; } else if (AstStatIf* statIf = extractStat(ancestry); - statIf && (!statIf->thenLocation || statIf->thenLocation->containsClosed(position)) && + statIf && (!statIf->thenLocation || statIf->thenLocation->containsClosed(position)) && (!FFlag::LuauFixAutocompleteInIf || (statIf->condition && !statIf->condition->location.containsClosed(position)))) { if (FFlag::LuauFixAutocompleteInIf) diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index 26aaf54fe..1fb915e95 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -15,7 +15,6 @@ #include -LUAU_FASTFLAGVARIABLE(LuauSetMetaTableArgsCheck, false) LUAU_FASTFLAG(LuauUnknownAndNeverType) LUAU_FASTFLAGVARIABLE(LuauBuiltInMetatableNoBadSynthetic, false) LUAU_FASTFLAG(LuauReportShadowedTypeAlias) @@ -583,7 +582,7 @@ static std::optional> magicFunctionSetMetaTable( TypeId mtTy = arena.addType(mtv); - if (FFlag::LuauSetMetaTableArgsCheck && expr.args.size < 1) + if (expr.args.size < 1) { if (FFlag::LuauUnknownAndNeverType) return std::nullopt; @@ -591,7 +590,7 @@ static std::optional> magicFunctionSetMetaTable( return WithPredicate{}; } - if (!FFlag::LuauSetMetaTableArgsCheck || !expr.self) + if (!expr.self) { AstExpr* targetExpr = expr.args.data[0]; if (AstExprLocal* targetLocal = targetExpr->as()) diff --git a/Analysis/src/Clone.cpp b/Analysis/src/Clone.cpp index 870d29490..3dd8df870 100644 --- a/Analysis/src/Clone.cpp +++ b/Analysis/src/Clone.cpp @@ -263,6 +263,7 @@ void TypeCloner::operator()(const TableType& t) arg = clone(arg, dest, cloneState); ttv->definitionModuleName = t.definitionModuleName; + ttv->definitionLocation = t.definitionLocation; ttv->tags = t.tags; } @@ -446,6 +447,7 @@ TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool alwaysCl LUAU_ASSERT(!ttv->boundTo); TableType clone = TableType{ttv->props, ttv->indexer, ttv->level, ttv->scope, ttv->state}; clone.definitionModuleName = ttv->definitionModuleName; + clone.definitionLocation = ttv->definitionLocation; clone.name = ttv->name; clone.syntheticName = ttv->syntheticName; clone.instantiatedTypeParams = ttv->instantiatedTypeParams; diff --git a/Analysis/src/ConstraintGraphBuilder.cpp b/Analysis/src/ConstraintGraphBuilder.cpp index 6a80fed2e..7181e4f03 100644 --- a/Analysis/src/ConstraintGraphBuilder.cpp +++ b/Analysis/src/ConstraintGraphBuilder.cpp @@ -16,6 +16,7 @@ LUAU_FASTFLAG(DebugLuauLogSolverToJson); LUAU_FASTFLAG(DebugLuauMagicTypes); LUAU_FASTFLAG(LuauNegatedClassTypes); LUAU_FASTFLAG(LuauScopelessModule); +LUAU_FASTFLAG(SupportTypeAliasGoToDeclaration); namespace Luau { @@ -418,7 +419,7 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocal* local) TypeId ty = nullptr; if (local->annotation) - ty = resolveType(scope, local->annotation, /* topLevel */ true); + ty = resolveType(scope, local->annotation, /* inTypeArguments */ false); varTypes.push_back(ty); } @@ -521,8 +522,12 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocal* local) const Name name{local->vars.data[i]->name.value}; if (ModulePtr module = moduleResolver->getModule(moduleInfo->name)) + { scope->importedTypeBindings[name] = FFlag::LuauScopelessModule ? module->exportedTypeBindings : module->getModuleScope()->exportedTypeBindings; + if (FFlag::SupportTypeAliasGoToDeclaration) + scope->importedModules[name] = moduleName; + } } } } @@ -775,7 +780,7 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatTypeAlias* alia } ScopePtr resolvingScope = *defnIt; - TypeId ty = resolveType(resolvingScope, alias->type, /* topLevel */ true); + TypeId ty = resolveType(resolvingScope, alias->type, /* inTypeArguments */ false); if (alias->exported) { @@ -798,7 +803,7 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareGlobal* { LUAU_ASSERT(global->type); - TypeId globalTy = resolveType(scope, global->type); + TypeId globalTy = resolveType(scope, global->type, /* inTypeArguments */ false); Name globalName(global->name.value); module->declaredGlobals[globalName] = globalTy; @@ -854,7 +859,7 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareClass* d for (const AstDeclaredClassProp& prop : declaredClass->props) { Name propName(prop.name.value); - TypeId propTy = resolveType(scope, prop.ty); + TypeId propTy = resolveType(scope, prop.ty, /* inTypeArguments */ false); bool assignToMetatable = isMetamethod(propName); @@ -937,8 +942,8 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareFunction if (!generics.empty() || !genericPacks.empty()) funScope = childScope(global, scope); - TypePackId paramPack = resolveTypePack(funScope, global->params); - TypePackId retPack = resolveTypePack(funScope, global->retTypes); + TypePackId paramPack = resolveTypePack(funScope, global->params, /* inTypeArguments */ false); + TypePackId retPack = resolveTypePack(funScope, global->retTypes, /* inTypeArguments */ false); TypeId fnType = arena->addType(FunctionType{TypeLevel{}, funScope.get(), std::move(genericTys), std::move(genericTps), paramPack, retPack}); FunctionType* ftv = getMutable(fnType); @@ -1501,7 +1506,7 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIfElse* if Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprTypeAssertion* typeAssert) { check(scope, typeAssert->expr, std::nullopt); - return Inference{resolveType(scope, typeAssert->annotation)}; + return Inference{resolveType(scope, typeAssert->annotation, /* inTypeArguments */ false)}; } std::tuple ConstraintGraphBuilder::checkBinary( @@ -1563,7 +1568,7 @@ std::tuple ConstraintGraphBuilder::checkBinary( TypeId ty = follow(typeFun->type); // We're only interested in the root class of any classes. - if (auto ctv = get(ty); !ctv || !ctv->parent) + if (auto ctv = get(ty); !ctv || (FFlag::LuauNegatedClassTypes ? (ctv->parent == builtinTypes->classType) : !ctv->parent)) discriminantTy = ty; } @@ -1618,39 +1623,6 @@ TypePackId ConstraintGraphBuilder::checkLValues(const ScopePtr& scope, AstArray< return arena->addTypePack(std::move(types)); } -/** - * If the expr is a dotted set of names, and if the root symbol refers to an - * unsealed table, return that table type, plus the indeces that follow as a - * vector. - */ -static std::optional>> extractDottedName(AstExpr* expr) -{ - std::vector names; - - while (expr) - { - if (auto global = expr->as()) - { - std::reverse(begin(names), end(names)); - return std::pair{global->name, std::move(names)}; - } - else if (auto local = expr->as()) - { - std::reverse(begin(names), end(names)); - return std::pair{local->local, std::move(names)}; - } - else if (auto indexName = expr->as()) - { - names.push_back(indexName->index.value); - expr = indexName->expr; - } - else - return std::nullopt; - } - - return std::nullopt; -} - /** * This function is mostly about identifying properties that are being inserted into unsealed tables. * @@ -1671,13 +1643,38 @@ TypeId ConstraintGraphBuilder::checkLValue(const ScopePtr& scope, AstExpr* expr) else if (!expr->is()) return check(scope, expr).ty; - auto dottedPath = extractDottedName(expr); - if (!dottedPath) - return check(scope, expr).ty; - const auto [sym, segments] = std::move(*dottedPath); + Symbol sym; + std::vector segments; + std::vector exprs; + + AstExpr* e = expr; + while (e) + { + if (auto global = e->as()) + { + sym = global->name; + break; + } + else if (auto local = e->as()) + { + sym = local->local; + break; + } + else if (auto indexName = e->as()) + { + segments.push_back(indexName->index.value); + exprs.push_back(e); + e = indexName->expr; + } + else + return check(scope, expr).ty; + } LUAU_ASSERT(!segments.empty()); + std::reverse(begin(segments), end(segments)); + std::reverse(begin(exprs), end(exprs)); + auto lookupResult = scope->lookupEx(sym); if (!lookupResult) return check(scope, expr).ty; @@ -1695,7 +1692,18 @@ TypeId ConstraintGraphBuilder::checkLValue(const ScopePtr& scope, AstExpr* expr) symbolScope->bindings[sym].typeId = updatedType; symbolScope->dcrRefinements[*def] = updatedType; - astTypes[expr] = propTy; + TypeId prevSegmentTy = updatedType; + for (size_t i = 0; i < segments.size(); ++i) + { + TypeId segmentTy = arena->addType(BlockedType{}); + astTypes[exprs[i]] = segmentTy; + addConstraint(scope, expr->location, HasPropConstraint{segmentTy, prevSegmentTy, segments[i]}); + prevSegmentTy = segmentTy; + } + + astTypes[expr] = prevSegmentTy; + astTypes[e] = updatedType; + // astTypes[expr] = propTy; return propTy; } @@ -1845,7 +1853,7 @@ ConstraintGraphBuilder::FunctionSignature ConstraintGraphBuilder::checkFunctionS if (local->annotation) { - annotationTy = resolveType(signatureScope, local->annotation, /* topLevel */ true); + annotationTy = resolveType(signatureScope, local->annotation, /* inTypeArguments */ false); addConstraint(signatureScope, local->annotation->location, SubtypeConstraint{t, annotationTy}); } else if (i < expectedArgPack.head.size()) @@ -1866,7 +1874,7 @@ ConstraintGraphBuilder::FunctionSignature ConstraintGraphBuilder::checkFunctionS { if (fn->varargAnnotation) { - TypePackId annotationType = resolveTypePack(signatureScope, fn->varargAnnotation); + TypePackId annotationType = resolveTypePack(signatureScope, fn->varargAnnotation, /* inTypeArguments */ false); varargPack = annotationType; } else if (expectedArgPack.tail && get(*expectedArgPack.tail)) @@ -1893,7 +1901,7 @@ ConstraintGraphBuilder::FunctionSignature ConstraintGraphBuilder::checkFunctionS // Type checking will sort out any discrepancies later. if (fn->returnAnnotation) { - TypePackId annotatedRetType = resolveTypePack(signatureScope, *fn->returnAnnotation); + TypePackId annotatedRetType = resolveTypePack(signatureScope, *fn->returnAnnotation, /* inTypeArguments */ false); // We bind the annotated type directly here so that, when we need to // generate constraints for return types, we have a guarantee that we @@ -1942,7 +1950,7 @@ void ConstraintGraphBuilder::checkFunctionBody(const ScopePtr& scope, AstExprFun } } -TypeId ConstraintGraphBuilder::resolveType(const ScopePtr& scope, AstType* ty, bool topLevel) +TypeId ConstraintGraphBuilder::resolveType(const ScopePtr& scope, AstType* ty, bool inTypeArguments) { TypeId result = nullptr; @@ -1960,7 +1968,7 @@ TypeId ConstraintGraphBuilder::resolveType(const ScopePtr& scope, AstType* ty, b return builtinTypes->errorRecoveryType(); } else - return resolveType(scope, ref->parameters.data[0].type, topLevel); + return resolveType(scope, ref->parameters.data[0].type, inTypeArguments); } } @@ -1994,11 +2002,11 @@ TypeId ConstraintGraphBuilder::resolveType(const ScopePtr& scope, AstType* ty, b // that is done in the parser. if (p.type) { - parameters.push_back(resolveType(scope, p.type)); + parameters.push_back(resolveType(scope, p.type, /* inTypeArguments */ true)); } else if (p.typePack) { - packParameters.push_back(resolveTypePack(scope, p.typePack)); + packParameters.push_back(resolveTypePack(scope, p.typePack, /* inTypeArguments */ true)); } else { @@ -2010,10 +2018,11 @@ TypeId ConstraintGraphBuilder::resolveType(const ScopePtr& scope, AstType* ty, b result = arena->addType(PendingExpansionType{ref->prefix, ref->name, parameters, packParameters}); - if (topLevel) - { + // If we're not in a type argument context, we need to create a constraint that expands this. + // The dispatching of the above constraint will queue up additional constraints for nested + // type function applications. + if (!inTypeArguments) addConstraint(scope, ty->location, TypeAliasExpansionConstraint{/* target */ result}); - } } } else @@ -2035,7 +2044,7 @@ TypeId ConstraintGraphBuilder::resolveType(const ScopePtr& scope, AstType* ty, b { std::string name = prop.name.value; // TODO: Recursion limit. - TypeId propTy = resolveType(scope, prop.type); + TypeId propTy = resolveType(scope, prop.type, inTypeArguments); // TODO: Fill in location. props[name] = {propTy}; } @@ -2044,8 +2053,8 @@ TypeId ConstraintGraphBuilder::resolveType(const ScopePtr& scope, AstType* ty, b { // TODO: Recursion limit. indexer = TableIndexer{ - resolveType(scope, tab->indexer->indexType), - resolveType(scope, tab->indexer->resultType), + resolveType(scope, tab->indexer->indexType, inTypeArguments), + resolveType(scope, tab->indexer->resultType, inTypeArguments), }; } @@ -2089,8 +2098,8 @@ TypeId ConstraintGraphBuilder::resolveType(const ScopePtr& scope, AstType* ty, b signatureScope = scope; } - TypePackId argTypes = resolveTypePack(signatureScope, fn->argTypes); - TypePackId returnTypes = resolveTypePack(signatureScope, fn->returnTypes); + TypePackId argTypes = resolveTypePack(signatureScope, fn->argTypes, inTypeArguments); + TypePackId returnTypes = resolveTypePack(signatureScope, fn->returnTypes, inTypeArguments); // TODO: FunctionType needs a pointer to the scope so that we know // how to quantify/instantiate it. @@ -2130,7 +2139,7 @@ TypeId ConstraintGraphBuilder::resolveType(const ScopePtr& scope, AstType* ty, b for (AstType* part : unionAnnotation->types) { // TODO: Recursion limit. - parts.push_back(resolveType(scope, part, topLevel)); + parts.push_back(resolveType(scope, part, inTypeArguments)); } result = arena->addType(UnionType{parts}); @@ -2141,7 +2150,7 @@ TypeId ConstraintGraphBuilder::resolveType(const ScopePtr& scope, AstType* ty, b for (AstType* part : intersectionAnnotation->types) { // TODO: Recursion limit. - parts.push_back(resolveType(scope, part, topLevel)); + parts.push_back(resolveType(scope, part, inTypeArguments)); } result = arena->addType(IntersectionType{parts}); @@ -2168,16 +2177,16 @@ TypeId ConstraintGraphBuilder::resolveType(const ScopePtr& scope, AstType* ty, b return result; } -TypePackId ConstraintGraphBuilder::resolveTypePack(const ScopePtr& scope, AstTypePack* tp) +TypePackId ConstraintGraphBuilder::resolveTypePack(const ScopePtr& scope, AstTypePack* tp, bool inTypeArgument) { TypePackId result; if (auto expl = tp->as()) { - result = resolveTypePack(scope, expl->typeList); + result = resolveTypePack(scope, expl->typeList, inTypeArgument); } else if (auto var = tp->as()) { - TypeId ty = resolveType(scope, var->variadicType); + TypeId ty = resolveType(scope, var->variadicType, inTypeArgument); result = arena->addTypePack(TypePackVar{VariadicTypePack{ty}}); } else if (auto gen = tp->as()) @@ -2202,19 +2211,19 @@ TypePackId ConstraintGraphBuilder::resolveTypePack(const ScopePtr& scope, AstTyp return result; } -TypePackId ConstraintGraphBuilder::resolveTypePack(const ScopePtr& scope, const AstTypeList& list) +TypePackId ConstraintGraphBuilder::resolveTypePack(const ScopePtr& scope, const AstTypeList& list, bool inTypeArguments) { std::vector head; for (AstType* headTy : list.types) { - head.push_back(resolveType(scope, headTy)); + head.push_back(resolveType(scope, headTy, inTypeArguments)); } std::optional tail = std::nullopt; if (list.tailType) { - tail = resolveTypePack(scope, list.tailType); + tail = resolveTypePack(scope, list.tailType, inTypeArguments); } return arena->addTypePack(TypePack{head, tail}); @@ -2229,7 +2238,7 @@ std::vector> ConstraintGraphBuilder::crea std::optional defaultTy = std::nullopt; if (generic.defaultValue) - defaultTy = resolveType(scope, generic.defaultValue); + defaultTy = resolveType(scope, generic.defaultValue, /* inTypeArguments */ false); result.push_back({generic.name.value, GenericTypeDefinition{genericTy, defaultTy}}); } @@ -2247,7 +2256,7 @@ std::vector> ConstraintGraphBuilder:: std::optional defaultTy = std::nullopt; if (generic.defaultValue) - defaultTy = resolveTypePack(scope, generic.defaultValue); + defaultTy = resolveTypePack(scope, generic.defaultValue, /* inTypeArguments */ false); result.push_back({generic.name.value, GenericTypePackDefinition{genericTy, defaultTy}}); } diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index 8092144cc..3fbd7d9e2 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -417,7 +417,7 @@ bool ConstraintSolver::tryDispatch(NotNull constraint, bool fo else if (auto hpc = get(*constraint)) success = tryDispatch(*hpc, constraint); else if (auto spc = get(*constraint)) - success = tryDispatch(*spc, constraint); + success = tryDispatch(*spc, constraint, force); else if (auto sottc = get(*constraint)) success = tryDispatch(*sottc, constraint); else @@ -933,13 +933,11 @@ struct InfiniteTypeFinder : TypeOnceVisitor struct InstantiationQueuer : TypeOnceVisitor { ConstraintSolver* solver; - const InstantiationSignature& signature; NotNull scope; Location location; - explicit InstantiationQueuer(NotNull scope, const Location& location, ConstraintSolver* solver, const InstantiationSignature& signature) + explicit InstantiationQueuer(NotNull scope, const Location& location, ConstraintSolver* solver) : solver(solver) - , signature(signature) , scope(scope) , location(location) { @@ -1061,8 +1059,17 @@ bool ConstraintSolver::tryDispatch(const TypeAliasExpansionConstraint& c, NotNul TypeId instantiated = *maybeInstantiated; TypeId target = follow(instantiated); + // The application is not recursive, so we need to queue up application of + // any child type function instantiations within the result in order for it + // to be complete. + InstantiationQueuer queuer{constraint->scope, constraint->location, this}; + queuer.traverse(target); + if (target->persistent) + { + bindResult(target); return true; + } // Type function application will happily give us the exact same type if // there are e.g. generic saturatedTypeArguments that go unused. @@ -1102,12 +1109,6 @@ bool ConstraintSolver::tryDispatch(const TypeAliasExpansionConstraint& c, NotNul bindResult(target); - // The application is not recursive, so we need to queue up application of - // any child type function instantiations within the result in order for it - // to be complete. - InstantiationQueuer queuer{constraint->scope, constraint->location, this, signature}; - queuer.traverse(target); - instantiatedAliases[signature] = target; return true; @@ -1326,13 +1327,16 @@ static std::optional updateTheTableType(NotNull arena, TypeId return res; } -bool ConstraintSolver::tryDispatch(const SetPropConstraint& c, NotNull constraint) +bool ConstraintSolver::tryDispatch(const SetPropConstraint& c, NotNull constraint, bool force) { TypeId subjectType = follow(c.subjectType); if (isBlocked(subjectType)) return block(subjectType, constraint); + if (!force && get(subjectType)) + return block(subjectType, constraint); + std::optional existingPropType = subjectType; for (const std::string& segment : c.path) { @@ -1399,6 +1403,13 @@ bool ConstraintSolver::tryDispatch(const SetPropConstraint& c, NotNull(subjectType)) + { + // Classes never change shape as a result of property assignments. + // The result is always the subject. + bind(c.resultType, subjectType); + return true; + } else if (get(subjectType) || get(subjectType)) { bind(c.resultType, subjectType); diff --git a/Analysis/src/Instantiation.cpp b/Analysis/src/Instantiation.cpp index 209ba7e90..912c4155b 100644 --- a/Analysis/src/Instantiation.cpp +++ b/Analysis/src/Instantiation.cpp @@ -116,6 +116,7 @@ TypeId ReplaceGenerics::clean(TypeId ty) { TableType clone = TableType{ttv->props, ttv->indexer, level, scope, TableState::Free}; clone.definitionModuleName = ttv->definitionModuleName; + clone.definitionLocation = ttv->definitionLocation; return addType(std::move(clone)); } else diff --git a/Analysis/src/Linter.cpp b/Analysis/src/Linter.cpp index 4250b3117..752259bdf 100644 --- a/Analysis/src/Linter.cpp +++ b/Analysis/src/Linter.cpp @@ -13,7 +13,6 @@ #include LUAU_FASTINTVARIABLE(LuauSuggestionDistance, 4) -LUAU_FASTFLAGVARIABLE(LuauLintGlobalNeverReadBeforeWritten, false) namespace Luau { @@ -331,8 +330,7 @@ class LintGlobalLocal : AstVisitor "Global '%s' is only used in the enclosing function defined at line %d; consider changing it to local", g.firstRef->name.value, top->location.begin.line + 1); } - else if (FFlag::LuauLintGlobalNeverReadBeforeWritten && g.assigned && !g.readBeforeWritten && !g.definedInModuleScope && - g.firstRef->name != context->placeholder) + else if (g.assigned && !g.readBeforeWritten && !g.definedInModuleScope && g.firstRef->name != context->placeholder) { emitWarning(*context, LintWarning::Code_GlobalUsedAsLocal, g.firstRef->location, "Global '%s' is never read before being written. Consider changing it to local", g.firstRef->name.value); @@ -353,7 +351,7 @@ class LintGlobalLocal : AstVisitor bool visit(AstExprGlobal* node) override { - if (FFlag::LuauLintGlobalNeverReadBeforeWritten && !functionStack.empty() && !functionStack.back().dominatedGlobals.contains(node->name)) + if (!functionStack.empty() && !functionStack.back().dominatedGlobals.contains(node->name)) { Global& g = globals[node->name]; g.readBeforeWritten = true; @@ -386,18 +384,15 @@ class LintGlobalLocal : AstVisitor { Global& g = globals[gv->name]; - if (FFlag::LuauLintGlobalNeverReadBeforeWritten) + if (functionStack.empty()) { - if (functionStack.empty()) - { - g.definedInModuleScope = true; - } - else + g.definedInModuleScope = true; + } + else + { + if (!functionStack.back().conditionalExecution) { - if (!functionStack.back().conditionalExecution) - { - functionStack.back().dominatedGlobals.insert(gv->name); - } + functionStack.back().dominatedGlobals.insert(gv->name); } } @@ -437,11 +432,8 @@ class LintGlobalLocal : AstVisitor else { g.assigned = true; - if (FFlag::LuauLintGlobalNeverReadBeforeWritten) - { - g.definedAsFunction = true; - g.definedInModuleScope = functionStack.empty(); - } + g.definedAsFunction = true; + g.definedInModuleScope = functionStack.empty(); } trackGlobalRef(gv); @@ -475,9 +467,6 @@ class LintGlobalLocal : AstVisitor bool visit(AstStatIf* node) override { - if (!FFlag::LuauLintGlobalNeverReadBeforeWritten) - return true; - HoldConditionalExecution ce(*this); node->condition->visit(this); node->thenbody->visit(this); @@ -489,9 +478,6 @@ class LintGlobalLocal : AstVisitor bool visit(AstStatWhile* node) override { - if (!FFlag::LuauLintGlobalNeverReadBeforeWritten) - return true; - HoldConditionalExecution ce(*this); node->condition->visit(this); node->body->visit(this); @@ -501,9 +487,6 @@ class LintGlobalLocal : AstVisitor bool visit(AstStatRepeat* node) override { - if (!FFlag::LuauLintGlobalNeverReadBeforeWritten) - return true; - HoldConditionalExecution ce(*this); node->condition->visit(this); node->body->visit(this); @@ -513,9 +496,6 @@ class LintGlobalLocal : AstVisitor bool visit(AstStatFor* node) override { - if (!FFlag::LuauLintGlobalNeverReadBeforeWritten) - return true; - HoldConditionalExecution ce(*this); node->from->visit(this); node->to->visit(this); @@ -530,9 +510,6 @@ class LintGlobalLocal : AstVisitor bool visit(AstStatForIn* node) override { - if (!FFlag::LuauLintGlobalNeverReadBeforeWritten) - return true; - HoldConditionalExecution ce(*this); for (AstExpr* expr : node->values) expr->visit(this); diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp index e19d48f80..901144e4a 100644 --- a/Analysis/src/Normalize.cpp +++ b/Analysis/src/Normalize.cpp @@ -17,13 +17,10 @@ LUAU_FASTFLAGVARIABLE(DebugLuauCheckNormalizeInvariant, false) // This could theoretically be 2000 on amd64, but x86 requires this. LUAU_FASTINTVARIABLE(LuauNormalizeIterationLimit, 1200); LUAU_FASTINTVARIABLE(LuauNormalizeCacheLimit, 100000); -LUAU_FASTFLAGVARIABLE(LuauNormalizeCombineTableFix, false); -LUAU_FASTFLAGVARIABLE(LuauTypeNormalization2, false); LUAU_FASTFLAGVARIABLE(LuauNegatedClassTypes, false); LUAU_FASTFLAGVARIABLE(LuauNegatedFunctionTypes, false); LUAU_FASTFLAG(LuauUnknownAndNeverType) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) -LUAU_FASTFLAG(LuauOverloadedFunctionSubtypingPerf); LUAU_FASTFLAG(LuauUninhabitedSubAnything2) namespace Luau @@ -2165,7 +2162,7 @@ std::optional Normalizer::intersectionOfFunctions(TypeId here, TypeId th argTypes = *argTypesOpt; retTypes = hftv->retTypes; } - else if (FFlag::LuauOverloadedFunctionSubtypingPerf && hftv->argTypes == tftv->argTypes) + else if (hftv->argTypes == tftv->argTypes) { std::optional retTypesOpt = intersectionOfTypePacks(hftv->argTypes, tftv->argTypes); if (!retTypesOpt) diff --git a/Analysis/src/Quantify.cpp b/Analysis/src/Quantify.cpp index 22c5875be..aac7864a8 100644 --- a/Analysis/src/Quantify.cpp +++ b/Analysis/src/Quantify.cpp @@ -157,6 +157,8 @@ struct PureQuantifier : Substitution Scope* scope; std::vector insertedGenerics; std::vector insertedGenericPacks; + bool seenMutableType = false; + bool seenGenericType = false; PureQuantifier(TypeArena* arena, Scope* scope) : Substitution(TxnLog::empty(), arena) @@ -170,11 +172,18 @@ struct PureQuantifier : Substitution if (auto ftv = get(ty)) { - return subsumes(scope, ftv->scope); + bool result = subsumes(scope, ftv->scope); + seenMutableType |= result; + return result; } else if (auto ttv = get(ty)) { - return ttv->state == TableState::Free && subsumes(scope, ttv->scope); + if (ttv->state == TableState::Free) + seenMutableType = true; + else if (ttv->state == TableState::Generic) + seenGenericType = true; + + return ttv->state == TableState::Unsealed || (ttv->state == TableState::Free && subsumes(scope, ttv->scope)); } return false; @@ -207,7 +216,11 @@ struct PureQuantifier : Substitution *resultTable = *ttv; resultTable->level = TypeLevel{}; resultTable->scope = scope; - resultTable->state = TableState::Generic; + + if (ttv->state == TableState::Free) + resultTable->state = TableState::Generic; + else if (ttv->state == TableState::Unsealed) + resultTable->state = TableState::Sealed; return result; } @@ -251,7 +264,7 @@ TypeId quantify(TypeArena* arena, TypeId ty, Scope* scope) ftv->scope = scope; ftv->generics.insert(ftv->generics.end(), quantifier.insertedGenerics.begin(), quantifier.insertedGenerics.end()); ftv->genericPacks.insert(ftv->genericPacks.end(), quantifier.insertedGenericPacks.begin(), quantifier.insertedGenericPacks.end()); - ftv->hasNoGenerics = ftv->generics.empty() && ftv->genericPacks.empty(); + ftv->hasNoGenerics = ftv->generics.empty() && ftv->genericPacks.empty() && !quantifier.seenGenericType && !quantifier.seenMutableType; return *result; } diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index 86ae3cde7..89d3c5557 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -16,7 +16,6 @@ LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) LUAU_FASTFLAG(LuauUnknownAndNeverType) LUAU_FASTFLAGVARIABLE(LuauFunctionReturnStringificationFixup, false) -LUAU_FASTFLAGVARIABLE(LuauUnseeArrayTtv, false) /* * Prefix generic typenames with gen- @@ -606,9 +605,7 @@ struct TypeStringifier stringify(ttv.indexer->indexResultType); state.emit("}"); - if (FFlag::LuauUnseeArrayTtv) - state.unsee(&ttv); - + state.unsee(&ttv); return; } @@ -1403,6 +1400,15 @@ std::string dump(TypeId ty) return s; } +std::string dump(const std::optional& ty) +{ + if (ty) + return dump(*ty); + + printf("nullopt\n"); + return "nullopt"; +} + std::string dump(TypePackId ty) { std::string s = toString(ty, dumpOptions()); @@ -1410,6 +1416,15 @@ std::string dump(TypePackId ty) return s; } +std::string dump(const std::optional& ty) +{ + if (ty) + return dump(*ty); + + printf("nullopt\n"); + return "nullopt"; +} + std::string dump(const ScopePtr& scope, const char* name) { auto binding = scope->linearSearchForBinding(name); diff --git a/Analysis/src/Type.cpp b/Analysis/src/Type.cpp index f03061a85..4b2165187 100644 --- a/Analysis/src/Type.cpp +++ b/Analysis/src/Type.cpp @@ -27,6 +27,7 @@ LUAU_FASTINT(LuauTypeInferRecursionLimit) LUAU_FASTFLAG(LuauUnknownAndNeverType) LUAU_FASTFLAGVARIABLE(LuauMaybeGenericIntersectionTypes, false) LUAU_FASTFLAG(LuauInstantiateInSubtyping) +LUAU_FASTFLAGVARIABLE(LuauMatchReturnsOptionalString, false); namespace Luau { @@ -768,15 +769,16 @@ BuiltinTypes::BuiltinTypes() , errorType(arena->addType(Type{ErrorType{}, /*persistent*/ true})) , falsyType(arena->addType(Type{UnionType{{falseType, nilType}}, /*persistent*/ true})) , truthyType(arena->addType(Type{NegationType{falsyType}, /*persistent*/ true})) + , optionalNumberType(arena->addType(Type{UnionType{{numberType, nilType}}, /*persistent*/ true})) + , optionalStringType(arena->addType(Type{UnionType{{stringType, nilType}}, /*persistent*/ true})) , anyTypePack(arena->addTypePack(TypePackVar{VariadicTypePack{anyType}, /*persistent*/ true})) , neverTypePack(arena->addTypePack(TypePackVar{VariadicTypePack{neverType}, /*persistent*/ true})) - , uninhabitableTypePack(arena->addTypePack({neverType}, neverTypePack)) + , uninhabitableTypePack(arena->addTypePack(TypePackVar{TypePack{{neverType}, neverTypePack}, /*persistent*/ true})) , errorTypePack(arena->addTypePack(TypePackVar{Unifiable::Error{}, /*persistent*/ true})) { TypeId stringMetatable = makeStringMetatable(); asMutable(stringType)->ty = PrimitiveType{PrimitiveType::String, stringMetatable}; persist(stringMetatable); - persist(uninhabitableTypePack); freeze(*arena); } @@ -1231,12 +1233,12 @@ static std::vector parsePatternString(NotNull builtinTypes if (i + 1 < size && data[i + 1] == ')') { i++; - result.push_back(builtinTypes->numberType); + result.push_back(FFlag::LuauMatchReturnsOptionalString ? builtinTypes->optionalNumberType : builtinTypes->numberType); continue; } ++depth; - result.push_back(builtinTypes->stringType); + result.push_back(FFlag::LuauMatchReturnsOptionalString ? builtinTypes->optionalStringType : builtinTypes->stringType); } else if (data[i] == ')') { @@ -1254,7 +1256,7 @@ static std::vector parsePatternString(NotNull builtinTypes return std::vector(); if (result.empty()) - result.push_back(builtinTypes->stringType); + result.push_back(FFlag::LuauMatchReturnsOptionalString ? builtinTypes->optionalStringType : builtinTypes->stringType); return result; } diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index 1d212851a..1972f26f3 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -4,6 +4,7 @@ #include "Luau/Ast.h" #include "Luau/AstQuery.h" #include "Luau/Clone.h" +#include "Luau/Error.h" #include "Luau/Instantiation.h" #include "Luau/Metamethods.h" #include "Luau/Normalize.h" @@ -212,6 +213,12 @@ struct TypeChecker2 return bestScope; } + enum ValueContext + { + LValue, + RValue + }; + void visit(AstStat* stat) { auto pusher = pushStack(stat); @@ -273,7 +280,7 @@ struct TypeChecker2 void visit(AstStatIf* ifStatement) { - visit(ifStatement->condition); + visit(ifStatement->condition, RValue); visit(ifStatement->thenbody); if (ifStatement->elsebody) visit(ifStatement->elsebody); @@ -281,14 +288,14 @@ struct TypeChecker2 void visit(AstStatWhile* whileStatement) { - visit(whileStatement->condition); + visit(whileStatement->condition, RValue); visit(whileStatement->body); } void visit(AstStatRepeat* repeatStatement) { visit(repeatStatement->body); - visit(repeatStatement->condition); + visit(repeatStatement->condition, RValue); } void visit(AstStatBreak*) {} @@ -315,12 +322,12 @@ struct TypeChecker2 } for (AstExpr* expr : ret->list) - visit(expr); + visit(expr, RValue); } void visit(AstStatExpr* expr) { - visit(expr->expr); + visit(expr->expr, RValue); } void visit(AstStatLocal* local) @@ -331,7 +338,7 @@ struct TypeChecker2 AstExpr* value = i < local->values.size ? local->values.data[i] : nullptr; if (value) - visit(value); + visit(value, RValue); TypeId* maybeValueType = value ? module->astTypes.find(value) : nullptr; if (i != local->values.size - 1 || maybeValueType) @@ -387,10 +394,10 @@ struct TypeChecker2 if (forStatement->var->annotation) visit(forStatement->var->annotation); - visit(forStatement->from); - visit(forStatement->to); + visit(forStatement->from, RValue); + visit(forStatement->to, RValue); if (forStatement->step) - visit(forStatement->step); + visit(forStatement->step, RValue); visit(forStatement->body); } @@ -403,7 +410,7 @@ struct TypeChecker2 } for (AstExpr* expr : forInStatement->values) - visit(expr); + visit(expr, RValue); visit(forInStatement->body); @@ -610,11 +617,11 @@ struct TypeChecker2 for (size_t i = 0; i < count; ++i) { AstExpr* lhs = assign->vars.data[i]; - visit(lhs); + visit(lhs, LValue); TypeId lhsType = lookupType(lhs); AstExpr* rhs = assign->values.data[i]; - visit(rhs); + visit(rhs, RValue); TypeId rhsType = lookupType(rhs); if (!isSubtype(rhsType, lhsType, stack.back())) @@ -635,7 +642,7 @@ struct TypeChecker2 void visit(AstStatFunction* stat) { - visit(stat->name); + visit(stat->name, LValue); visit(stat->func); } @@ -698,13 +705,13 @@ struct TypeChecker2 void visit(AstStatError* stat) { for (AstExpr* expr : stat->expressions) - visit(expr); + visit(expr, RValue); for (AstStat* s : stat->statements) visit(s); } - void visit(AstExpr* expr) + void visit(AstExpr* expr, ValueContext context) { auto StackPusher = pushStack(expr); @@ -712,7 +719,7 @@ struct TypeChecker2 { } else if (auto e = expr->as()) - return visit(e); + return visit(e, context); else if (auto e = expr->as()) return visit(e); else if (auto e = expr->as()) @@ -730,9 +737,9 @@ struct TypeChecker2 else if (auto e = expr->as()) return visit(e); else if (auto e = expr->as()) - return visit(e); + return visit(e, context); else if (auto e = expr->as()) - return visit(e); + return visit(e, context); else if (auto e = expr->as()) return visit(e); else if (auto e = expr->as()) @@ -754,9 +761,9 @@ struct TypeChecker2 LUAU_ASSERT(!"TypeChecker2 encountered an unknown expression type"); } - void visit(AstExprGroup* expr) + void visit(AstExprGroup* expr, ValueContext context) { - visit(expr->expr); + visit(expr->expr, context); } void visit(AstExprConstantNil* expr) @@ -808,10 +815,10 @@ struct TypeChecker2 void visit(AstExprCall* call) { - visit(call->func); + visit(call->func, RValue); for (AstExpr* arg : call->args) - visit(arg); + visit(arg, RValue); TypeArena* arena = &testArena; Instantiation instantiation{TxnLog::empty(), arena, TypeLevel{}, stack.back()}; @@ -820,6 +827,8 @@ struct TypeChecker2 TypeId functionType = lookupType(call->func); TypeId testFunctionType = functionType; TypePack args; + std::vector argLocs; + argLocs.reserve(call->args.size + 1); if (get(functionType) || get(functionType)) return; @@ -830,6 +839,7 @@ struct TypeChecker2 if (std::optional instantiatedCallMm = instantiation.substitute(*callMm)) { args.head.push_back(functionType); + argLocs.push_back(call->func->location); testFunctionType = follow(*instantiatedCallMm); } else @@ -899,11 +909,13 @@ struct TypeChecker2 ice.ice("method call expression has no 'self'"); args.head.push_back(lookupType(indexExpr->expr)); + argLocs.push_back(indexExpr->expr->location); } for (size_t i = 0; i < call->args.size; ++i) { AstExpr* arg = call->args.data[i]; + argLocs.push_back(arg->location); TypeId* argTy = module->astTypes.find(arg); if (argTy) args.head.push_back(*argTy); @@ -919,19 +931,34 @@ struct TypeChecker2 args.head.push_back(builtinTypes->anyType); } - TypePackId argsTp = arena->addTypePack(args); - FunctionType ftv{argsTp, expectedRetType}; - TypeId expectedType = arena->addType(ftv); + TypePackId expectedArgTypes = arena->addTypePack(args); + + const FunctionType* inferredFunctionType = get(testFunctionType); + LUAU_ASSERT(inferredFunctionType); // testFunctionType should always be a FunctionType here - if (!isSubtype(testFunctionType, expectedType, stack.back())) + size_t argIndex = 0; + auto inferredArgIt = begin(inferredFunctionType->argTypes); + auto expectedArgIt = begin(expectedArgTypes); + while (inferredArgIt != end(inferredFunctionType->argTypes) && expectedArgIt != end(expectedArgTypes)) { - CloneState cloneState; - expectedType = clone(expectedType, testArena, cloneState); - reportError(TypeMismatch{expectedType, functionType}, call->location); + Location argLoc = (argIndex >= argLocs.size()) ? argLocs.back() : argLocs[argIndex]; + reportErrors(tryUnify(stack.back(), argLoc, *expectedArgIt, *inferredArgIt)); + + ++argIndex; + ++inferredArgIt; + ++expectedArgIt; } + + // piggyback on the unifier for arity checking, but we can't do this for checking the actual arguments since the locations would be bad + ErrorVec errors = tryUnify(stack.back(), call->location, expectedArgTypes, inferredFunctionType->argTypes); + for (TypeError e : errors) + if (get(e) != nullptr) + reportError(std::move(e)); + + reportErrors(tryUnify(stack.back(), call->location, inferredFunctionType->retTypes, expectedRetType, CountMismatch::FunctionResult)); } - void visit(AstExprIndexName* indexName) + void visit(AstExprIndexName* indexName, ValueContext context) { TypeId leftType = lookupType(indexName->expr); @@ -939,14 +966,14 @@ struct TypeChecker2 if (!norm) reportError(NormalizationTooComplex{}, indexName->indexLocation); - checkIndexTypeFromType(leftType, *norm, indexName->index.value, indexName->location); + checkIndexTypeFromType(leftType, *norm, indexName->index.value, indexName->location, context); } - void visit(AstExprIndexExpr* indexExpr) + void visit(AstExprIndexExpr* indexExpr, ValueContext context) { // TODO! - visit(indexExpr->expr); - visit(indexExpr->index); + visit(indexExpr->expr, LValue); + visit(indexExpr->index, RValue); } void visit(AstExprFunction* fn) @@ -986,14 +1013,14 @@ struct TypeChecker2 for (const AstExprTable::Item& item : expr->items) { if (item.key) - visit(item.key); - visit(item.value); + visit(item.key, LValue); + visit(item.value, RValue); } } void visit(AstExprUnary* expr) { - visit(expr->expr); + visit(expr->expr, RValue); NotNull scope = stack.back(); TypeId operandType = lookupType(expr->expr); @@ -1053,8 +1080,8 @@ struct TypeChecker2 TypeId visit(AstExprBinary* expr, void* overrideKey = nullptr) { - visit(expr->left); - visit(expr->right); + visit(expr->left, LValue); + visit(expr->right, LValue); NotNull scope = stack.back(); @@ -1307,7 +1334,7 @@ struct TypeChecker2 void visit(AstExprTypeAssertion* expr) { - visit(expr->expr); + visit(expr->expr, RValue); visit(expr->annotation); TypeId annotationType = lookupAnnotation(expr->annotation); @@ -1326,16 +1353,16 @@ struct TypeChecker2 void visit(AstExprIfElse* expr) { // TODO! - visit(expr->condition); - visit(expr->trueExpr); - visit(expr->falseExpr); + visit(expr->condition, RValue); + visit(expr->trueExpr, RValue); + visit(expr->falseExpr, RValue); } void visit(AstExprError* expr) { // TODO! for (AstExpr* e : expr->expressions) - visit(e); + visit(e, RValue); } /** Extract a TypeId for the first type of the provided pack. @@ -1550,7 +1577,7 @@ struct TypeChecker2 void visit(AstTypeTypeof* ty) { - visit(ty->expr); + visit(ty->expr, RValue); } void visit(AstTypeUnion* ty) @@ -1630,9 +1657,10 @@ struct TypeChecker2 } template - ErrorVec tryUnify(NotNull scope, const Location& location, TID subTy, TID superTy) + ErrorVec tryUnify(NotNull scope, const Location& location, TID subTy, TID superTy, CountMismatch::Context context = CountMismatch::Arg) { Unifier u{NotNull{&normalizer}, Mode::Strict, scope, location, Covariant}; + u.ctx = context; u.useScopes = true; u.tryUnify(subTy, superTy); @@ -1658,7 +1686,7 @@ struct TypeChecker2 reportError(std::move(e)); } - void checkIndexTypeFromType(TypeId denormalizedTy, const NormalizedType& norm, const std::string& prop, const Location& location) + void checkIndexTypeFromType(TypeId tableTy, const NormalizedType& norm, const std::string& prop, const Location& location, ValueContext context) { bool foundOneProp = false; std::vector typesMissingTheProp; @@ -1723,9 +1751,11 @@ struct TypeChecker2 if (!typesMissingTheProp.empty()) { if (foundOneProp) - reportError(TypeError{location, MissingUnionProperty{denormalizedTy, typesMissingTheProp, prop}}); + reportError(MissingUnionProperty{tableTy, typesMissingTheProp, prop}, location); + else if (context == LValue) + reportError(CannotExtendTable{tableTy, CannotExtendTable::Property, prop}, location); else - reportError(TypeError{location, UnknownProperty{denormalizedTy, prop}}); + reportError(UnknownProperty{tableTy, prop}, location); } } diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 5c1ee3888..de52a5261 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -32,16 +32,13 @@ LUAU_FASTINTVARIABLE(LuauTypeInferTypePackLoopLimit, 5000) LUAU_FASTINTVARIABLE(LuauCheckRecursionLimit, 300) LUAU_FASTINTVARIABLE(LuauVisitRecursionLimit, 500) LUAU_FASTFLAG(LuauKnowsTheDataModel3) -LUAU_FASTFLAG(LuauTypeNormalization2) LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) LUAU_FASTFLAGVARIABLE(LuauReturnAnyInsteadOfICE, false) // Eventually removed as false. LUAU_FASTFLAGVARIABLE(DebugLuauSharedSelf, false) LUAU_FASTFLAGVARIABLE(LuauUnknownAndNeverType, false) -LUAU_FASTFLAGVARIABLE(LuauTypeInferMissingFollows, false) LUAU_FASTFLAGVARIABLE(LuauBinaryNeedsExpectedTypesToo, false) LUAU_FASTFLAGVARIABLE(LuauNeverTypesAndOperatorsInference, false) LUAU_FASTFLAGVARIABLE(LuauScopelessModule, false) -LUAU_FASTFLAGVARIABLE(LuauFollowInLvalueIndexCheck, false) LUAU_FASTFLAGVARIABLE(LuauReturnsFromCallsitesAreNotWidened, false) LUAU_FASTFLAGVARIABLE(LuauTryhardAnd, false) LUAU_FASTFLAG(LuauInstantiateInSubtyping) @@ -52,9 +49,8 @@ LUAU_FASTFLAGVARIABLE(LuauIntersectionTestForEquality, false) LUAU_FASTFLAGVARIABLE(LuauImplicitElseRefinement, false) LUAU_FASTFLAG(LuauNegatedClassTypes) LUAU_FASTFLAGVARIABLE(LuauAllowIndexClassParameters, false) -LUAU_FASTFLAGVARIABLE(LuauDeclareClassPrototype, false) LUAU_FASTFLAG(LuauUninhabitedSubAnything2) -LUAU_FASTFLAGVARIABLE(LuauCallableClasses, false) +LUAU_FASTFLAG(SupportTypeAliasGoToDeclaration) namespace Luau { @@ -333,12 +329,9 @@ ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mo prepareErrorsForDisplay(currentModule->errors); - if (FFlag::LuauTypeNormalization2) - { - // Clear the normalizer caches, since they contain types from the internal type surface - normalizer.clearCaches(); - normalizer.arena = nullptr; - } + // Clear the normalizer caches, since they contain types from the internal type surface + normalizer.clearCaches(); + normalizer.arena = nullptr; currentModule->clonePublicInterface(builtinTypes, *iceHandler); @@ -512,7 +505,7 @@ void TypeChecker::checkBlockWithoutRecursionCheck(const ScopePtr& scope, const A prototype(scope, *typealias, subLevel); ++subLevel; } - else if (const auto& declaredClass = stat->as(); FFlag::LuauDeclareClassPrototype && declaredClass) + else if (const auto& declaredClass = stat->as()) { prototype(scope, *declaredClass); } @@ -1137,8 +1130,12 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatLocal& local) const Name name{local.vars.data[i]->name.value}; if (ModulePtr module = resolver->getModule(moduleInfo->name)) + { scope->importedTypeBindings[name] = FFlag::LuauScopelessModule ? module->exportedTypeBindings : module->getModuleScope()->exportedTypeBindings; + if (FFlag::SupportTypeAliasGoToDeclaration) + scope->importedModules[name] = moduleInfo->name; + } // In non-strict mode we force the module type on the variable, in strict mode it is already unified if (isNonstrictMode()) @@ -1535,6 +1532,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias // This is a shallow clone, original recursive links to self are not updated TableType clone = TableType{ttv->props, ttv->indexer, ttv->level, ttv->state}; clone.definitionModuleName = ttv->definitionModuleName; + clone.definitionLocation = ttv->definitionLocation; clone.name = name; for (auto param : binding->typeParams) @@ -1621,6 +1619,8 @@ void TypeChecker::prototype(const ScopePtr& scope, const AstStatTypeAlias& typea bindingsMap[name] = {std::move(generics), std::move(genericPacks), ty}; scope->typeAliasLocations[name] = typealias.location; + if (FFlag::SupportTypeAliasGoToDeclaration) + scope->typeAliasNameLocations[name] = typealias.nameLocation; } } else @@ -1639,12 +1639,13 @@ void TypeChecker::prototype(const ScopePtr& scope, const AstStatTypeAlias& typea bindingsMap[name] = {std::move(generics), std::move(genericPacks), ty}; scope->typeAliasLocations[name] = typealias.location; + if (FFlag::SupportTypeAliasGoToDeclaration) + scope->typeAliasNameLocations[name] = typealias.nameLocation; } } void TypeChecker::prototype(const ScopePtr& scope, const AstStatDeclareClass& declaredClass) { - LUAU_ASSERT(FFlag::LuauDeclareClassPrototype); std::optional superTy = FFlag::LuauNegatedClassTypes ? std::make_optional(builtinTypes->classType) : std::nullopt; if (declaredClass.superName) { @@ -1683,166 +1684,74 @@ void TypeChecker::prototype(const ScopePtr& scope, const AstStatDeclareClass& de void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& declaredClass) { - if (FFlag::LuauDeclareClassPrototype) - { - Name className(declaredClass.name.value); - - // Don't bother checking if the class definition was incorrect - if (incorrectClassDefinitions.find(&declaredClass)) - return; + Name className(declaredClass.name.value); - std::optional binding; - if (auto it = scope->exportedTypeBindings.find(className); it != scope->exportedTypeBindings.end()) - binding = it->second; + // Don't bother checking if the class definition was incorrect + if (incorrectClassDefinitions.find(&declaredClass)) + return; - // This class definition must have been `prototype()`d first. - if (!binding) - ice("Class not predeclared"); + std::optional binding; + if (auto it = scope->exportedTypeBindings.find(className); it != scope->exportedTypeBindings.end()) + binding = it->second; - TypeId classTy = binding->type; - ClassType* ctv = getMutable(classTy); + // This class definition must have been `prototype()`d first. + if (!binding) + ice("Class not predeclared"); - if (!ctv->metatable) - ice("No metatable for declared class"); + TypeId classTy = binding->type; + ClassType* ctv = getMutable(classTy); - TableType* metatable = getMutable(*ctv->metatable); - for (const AstDeclaredClassProp& prop : declaredClass.props) - { - Name propName(prop.name.value); - TypeId propTy = resolveType(scope, *prop.ty); + if (!ctv->metatable) + ice("No metatable for declared class"); - bool assignToMetatable = isMetamethod(propName); - Luau::ClassType::Props& assignTo = assignToMetatable ? metatable->props : ctv->props; + TableType* metatable = getMutable(*ctv->metatable); + for (const AstDeclaredClassProp& prop : declaredClass.props) + { + Name propName(prop.name.value); + TypeId propTy = resolveType(scope, *prop.ty); - // Function types always take 'self', but this isn't reflected in the - // parsed annotation. Add it here. - if (prop.isMethod) - { - if (FunctionType* ftv = getMutable(propTy)) - { - ftv->argNames.insert(ftv->argNames.begin(), FunctionArgument{"self", {}}); - ftv->argTypes = addTypePack(TypePack{{classTy}, ftv->argTypes}); - ftv->hasSelf = true; - } - } + bool assignToMetatable = isMetamethod(propName); + Luau::ClassType::Props& assignTo = assignToMetatable ? metatable->props : ctv->props; - if (assignTo.count(propName) == 0) + // Function types always take 'self', but this isn't reflected in the + // parsed annotation. Add it here. + if (prop.isMethod) + { + if (FunctionType* ftv = getMutable(propTy)) { - assignTo[propName] = {propTy}; + ftv->argNames.insert(ftv->argNames.begin(), FunctionArgument{"self", {}}); + ftv->argTypes = addTypePack(TypePack{{classTy}, ftv->argTypes}); + ftv->hasSelf = true; } - else - { - TypeId currentTy = assignTo[propName].type; - - // We special-case this logic to keep the intersection flat; otherwise we - // would create a ton of nested intersection types. - if (const IntersectionType* itv = get(currentTy)) - { - std::vector options = itv->parts; - options.push_back(propTy); - TypeId newItv = addType(IntersectionType{std::move(options)}); - - assignTo[propName] = {newItv}; - } - else if (get(currentTy)) - { - TypeId intersection = addType(IntersectionType{{currentTy, propTy}}); + } - assignTo[propName] = {intersection}; - } - else - { - reportError(declaredClass.location, GenericError{format("Cannot overload non-function class member '%s'", propName.c_str())}); - } - } + if (assignTo.count(propName) == 0) + { + assignTo[propName] = {propTy}; } - } - else - { - std::optional superTy = FFlag::LuauNegatedClassTypes ? std::make_optional(builtinTypes->classType) : std::nullopt; - if (declaredClass.superName) + else { - Name superName = Name(declaredClass.superName->value); - std::optional lookupType = scope->lookupType(superName); + TypeId currentTy = assignTo[propName].type; - if (!lookupType) + // We special-case this logic to keep the intersection flat; otherwise we + // would create a ton of nested intersection types. + if (const IntersectionType* itv = get(currentTy)) { - reportError(declaredClass.location, UnknownSymbol{superName, UnknownSymbol::Type}); - return; - } - - // We don't have generic classes, so this assertion _should_ never be hit. - LUAU_ASSERT(lookupType->typeParams.size() == 0 && lookupType->typePackParams.size() == 0); - superTy = lookupType->type; + std::vector options = itv->parts; + options.push_back(propTy); + TypeId newItv = addType(IntersectionType{std::move(options)}); - if (!get(follow(*superTy))) - { - reportError(declaredClass.location, GenericError{format("Cannot use non-class type '%s' as a superclass of class '%s'", - superName.c_str(), declaredClass.name.value)}); - return; + assignTo[propName] = {newItv}; } - } - - Name className(declaredClass.name.value); - - TypeId classTy = addType(ClassType(className, {}, superTy, std::nullopt, {}, {}, currentModuleName)); - - ClassType* ctv = getMutable(classTy); - TypeId metaTy = addType(TableType{TableState::Sealed, scope->level}); - TableType* metatable = getMutable(metaTy); - - ctv->metatable = metaTy; - - scope->exportedTypeBindings[className] = TypeFun{{}, classTy}; - - for (const AstDeclaredClassProp& prop : declaredClass.props) - { - Name propName(prop.name.value); - TypeId propTy = resolveType(scope, *prop.ty); - - bool assignToMetatable = isMetamethod(propName); - Luau::ClassType::Props& assignTo = assignToMetatable ? metatable->props : ctv->props; - - // Function types always take 'self', but this isn't reflected in the - // parsed annotation. Add it here. - if (prop.isMethod) + else if (get(currentTy)) { - if (FunctionType* ftv = getMutable(propTy)) - { - ftv->argNames.insert(ftv->argNames.begin(), FunctionArgument{"self", {}}); - ftv->argTypes = addTypePack(TypePack{{classTy}, ftv->argTypes}); - ftv->hasSelf = true; - } - } + TypeId intersection = addType(IntersectionType{{currentTy, propTy}}); - if (assignTo.count(propName) == 0) - { - assignTo[propName] = {propTy}; + assignTo[propName] = {intersection}; } else { - TypeId currentTy = assignTo[propName].type; - - // We special-case this logic to keep the intersection flat; otherwise we - // would create a ton of nested intersection types. - if (const IntersectionType* itv = get(currentTy)) - { - std::vector options = itv->parts; - options.push_back(propTy); - TypeId newItv = addType(IntersectionType{std::move(options)}); - - assignTo[propName] = {newItv}; - } - else if (get(currentTy)) - { - TypeId intersection = addType(IntersectionType{{currentTy, propTy}}); - - assignTo[propName] = {intersection}; - } - else - { - reportError(declaredClass.location, GenericError{format("Cannot overload non-function class member '%s'", propName.c_str())}); - } + reportError(declaredClass.location, GenericError{format("Cannot overload non-function class member '%s'", propName.c_str())}); } } } @@ -2370,6 +2279,7 @@ TypeId TypeChecker::checkExprTable( TableState state = TableState::Unsealed; TableType table = TableType{std::move(props), indexer, scope->level, state}; table.definitionModuleName = currentModuleName; + table.definitionLocation = expr.location; return addType(table); } @@ -3362,8 +3272,7 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex TypeId indexType = checkExpr(scope, *expr.index).type; - if (FFlag::LuauFollowInLvalueIndexCheck) - exprType = follow(exprType); + exprType = follow(exprType); if (get(exprType) || get(exprType)) return exprType; @@ -4280,7 +4189,7 @@ std::optional> TypeChecker::checkCallOverload(const Sc { callTy = getIndexTypeFromType(scope, mttv->metatable, "__call", expr.func->location, /* addErrors= */ false); } - else if (const ClassType* ctv = get(fn); FFlag::LuauCallableClasses && ctv && ctv->metatable) + else if (const ClassType* ctv = get(fn); ctv && ctv->metatable) { callTy = getIndexTypeFromType(scope, *ctv->metatable, "__call", expr.func->location, /* addErrors= */ false); } @@ -4477,7 +4386,7 @@ void TypeChecker::reportOverloadResolutionError(const ScopePtr& scope, const Ast std::string s; for (size_t i = 0; i < overloadTypes.size(); ++i) { - TypeId overload = FFlag::LuauTypeInferMissingFollows ? follow(overloadTypes[i]) : overloadTypes[i]; + TypeId overload = follow(overloadTypes[i]); Unifier state = mkUnifier(scope, expr.location); // Unify return types @@ -4859,7 +4768,7 @@ TypePackId TypeChecker::anyifyModuleReturnTypePackGenerics(TypePackId tp) if (const VariadicTypePack* vtp = get(tp)) { - TypeId ty = FFlag::LuauTypeInferMissingFollows ? follow(vtp->ty) : vtp->ty; + TypeId ty = follow(vtp->ty); return get(ty) ? anyTypePack : tp; } @@ -5371,6 +5280,7 @@ TypeId TypeChecker::resolveTypeWorker(const ScopePtr& scope, const AstType& anno TableType ttv{props, tableIndexer, scope->level, TableState::Sealed}; ttv.definitionModuleName = currentModuleName; + ttv.definitionLocation = annotation.location; return addType(std::move(ttv)); } else if (const auto& func = annotation.as()) @@ -5572,6 +5482,7 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, ttv->instantiatedTypeParams = typeParams; ttv->instantiatedTypePackParams = typePackParams; ttv->definitionModuleName = currentModuleName; + ttv->definitionLocation = location; } return instantiated; @@ -6101,11 +6012,11 @@ void TypeChecker::resolve(const EqPredicate& eqP, RefinementMap& refis, const Sc if (optionIsSubtype && !targetIsSubtype) return option; else if (!optionIsSubtype && targetIsSubtype) - return FFlag::LuauTypeInferMissingFollows ? follow(eqP.type) : eqP.type; + return follow(eqP.type); else if (!optionIsSubtype && !targetIsSubtype) return nope; else if (optionIsSubtype && targetIsSubtype) - return FFlag::LuauTypeInferMissingFollows ? follow(eqP.type) : eqP.type; + return follow(eqP.type); } else { diff --git a/Analysis/src/TypePack.cpp b/Analysis/src/TypePack.cpp index e41bf2fe9..ccea604ff 100644 --- a/Analysis/src/TypePack.cpp +++ b/Analysis/src/TypePack.cpp @@ -6,8 +6,6 @@ #include -LUAU_FASTFLAGVARIABLE(LuauTxnLogTypePackIterator, false) - namespace Luau { @@ -62,8 +60,8 @@ TypePackIterator::TypePackIterator(TypePackId typePack) } TypePackIterator::TypePackIterator(TypePackId typePack, const TxnLog* log) - : currentTypePack(FFlag::LuauTxnLogTypePackIterator ? log->follow(typePack) : follow(typePack)) - , tp(FFlag::LuauTxnLogTypePackIterator ? log->get(currentTypePack) : get(currentTypePack)) + : currentTypePack(log->follow(typePack)) + , tp(log->get(currentTypePack)) , currentIndex(0) , log(log) { diff --git a/Analysis/src/TypeReduction.cpp b/Analysis/src/TypeReduction.cpp index 8d837ddb2..c748f8632 100644 --- a/Analysis/src/TypeReduction.cpp +++ b/Analysis/src/TypeReduction.cpp @@ -4,6 +4,7 @@ #include "Luau/Common.h" #include "Luau/Error.h" #include "Luau/RecursionCounter.h" +#include "Luau/VisitType.h" #include #include @@ -18,24 +19,7 @@ namespace Luau namespace { -struct RecursionGuard : RecursionLimiter -{ - std::deque* seen; - - RecursionGuard(int* count, int limit, std::deque* seen) - : RecursionLimiter(count, limit) - , seen(seen) - { - // count has been incremented, which should imply that seen has already had an element pushed in. - LUAU_ASSERT(size_t(*count) == seen->size()); - } - - ~RecursionGuard() - { - LUAU_ASSERT(!seen->empty()); // It is UB to pop_back() on an empty deque. - seen->pop_back(); - } -}; +using detail::ReductionContext; template std::pair get2(const Thing& one, const Thing& two) @@ -51,15 +35,11 @@ struct TypeReducer NotNull builtinTypes; NotNull handle; - std::unordered_map copies; - std::deque seen; - int depth = 0; + DenseHashMap>* memoizedTypes; + DenseHashMap>* memoizedTypePacks; + DenseHashSet* cyclicTypes; - // When we encounter _any type_ that which is usually mutated in-place, we need to not cache the result. - // e.g. `'a & {} T` may have an upper bound constraint `{}` placed upon `'a`, but this constraint was not - // known when we decided to reduce this intersection type. By not caching, we'll always be forced to perform - // the reduction calculus over again. - bool cacheOk = true; + int depth = 0; TypeId reduce(TypeId ty); TypePackId reduce(TypePackId tp); @@ -70,62 +50,73 @@ struct TypeReducer TypeId functionType(TypeId ty); TypeId negationType(TypeId ty); - RecursionGuard guard(TypeId ty); - RecursionGuard guard(TypePackId tp); + bool isIrreducible(TypeId ty); + bool isIrreducible(TypePackId tp); + + TypeId memoize(TypeId ty, TypeId reducedTy); + TypePackId memoize(TypePackId tp, TypePackId reducedTp); - void checkCacheable(TypeId ty); - void checkCacheable(TypePackId tp); + // It's either cyclic with no memoized result, so we should terminate, or + // there is a memoized result but one that's being reduced top-down, so + // we need to return the root of that memoized result to tighten up things. + TypeId memoizedOr(TypeId ty) const; + TypePackId memoizedOr(TypePackId tp) const; + + using BinaryFold = std::optional (TypeReducer::*)(TypeId, TypeId); + using UnaryFold = TypeId (TypeReducer::*)(TypeId); template LUAU_NOINLINE std::pair copy(TypeId ty, const T* t) { - if (auto it = copies.find(ty); it != copies.end()) - return {it->second, getMutable(it->second)}; + ty = follow(ty); + + if (auto ctx = memoizedTypes->find(ty)) + return {ctx->type, getMutable(ctx->type)}; TypeId copiedTy = arena->addType(*t); - copies[ty] = copiedTy; + (*memoizedTypes)[ty] = {copiedTy, false}; + (*memoizedTypes)[copiedTy] = {copiedTy, false}; return {copiedTy, getMutable(copiedTy)}; } - using Folder = std::optional (TypeReducer::*)(TypeId, TypeId); - template - void foldl_impl(Iter it, Iter endIt, Folder f, NotNull> result) + void foldl_impl(Iter it, Iter endIt, BinaryFold f, std::vector* result, bool* didReduce) { + RecursionLimiter rl{&depth, FInt::LuauTypeReductionRecursionLimit}; + while (it != endIt) { - bool replaced = false; - TypeId currentTy = reduce(*it); - RecursionGuard rg = guard(*it); + TypeId right = reduce(*it); + *didReduce |= right != follow(*it); // We're hitting a case where the `currentTy` returned a type that's the same as `T`. // e.g. `(string?) & ~(false | nil)` became `(string?) & (~false & ~nil)` but the current iterator we're consuming doesn't know this. // We will need to recurse and traverse that first. - if (auto t = get(currentTy)) + if (auto t = get(right)) { - foldl_impl(begin(t), end(t), f, result); + foldl_impl(begin(t), end(t), f, result, didReduce); ++it; continue; } + bool replaced = false; auto resultIt = result->begin(); while (resultIt != result->end()) { - TypeId& ty = *resultIt; - - std::optional reduced = (this->*f)(ty, currentTy); - if (reduced && replaced) + TypeId left = *resultIt; + if (left == right) { - // We want to erase any other elements that occurs after the first replacement too. - // e.g. `"a" | "b" | string` where `"a"` and `"b"` is in the `result` vector, then `string` replaces both `"a"` and `"b"`. - // If we don't erase redundant elements, `"b"` may be kept or be replaced by `string`, leaving us with `string | string`. - resultIt = result->erase(resultIt); + replaced = true; + ++resultIt; + continue; } - else if (reduced && !replaced) + + std::optional reduced = (this->*f)(left, right); + if (reduced) { + *resultIt = *reduced; ++resultIt; replaced = true; - ty = *reduced; } else { @@ -135,21 +126,65 @@ struct TypeReducer } if (!replaced) - result->push_back(currentTy); + result->push_back(right); + *didReduce |= replaced; ++it; } } + template + TypeId flatten(std::vector&& types) + { + if (types.size() == 1) + return types[0]; + else + return arena->addType(T{std::move(types)}); + } + template - TypeId foldl(Iter it, Iter endIt, Folder f) + TypeId foldl(Iter it, Iter endIt, std::optional ty, BinaryFold f) { std::vector result; - foldl_impl(it, endIt, f, NotNull{&result}); - if (result.size() == 1) - return result[0]; + bool didReduce = false; + foldl_impl(it, endIt, f, &result, &didReduce); + if (!didReduce && ty) + return *ty; else - return arena->addType(T{std::move(result)}); + { + // If we've done any reduction, then we'll need to reduce it again, e.g. + // `"a" | "b" | string` is reduced into `string | string`, which is then reduced into `string`. + return reduce(flatten(std::move(result))); + } + } + + template + TypeId apply(BinaryFold f, TypeId left, TypeId right) + { + left = follow(left); + right = follow(right); + + if (get(left) || get(right)) + { + std::vector types{left, right}; + return foldl(begin(types), end(types), std::nullopt, f); + } + else if (auto reduced = (this->*f)(left, right)) + return *reduced; + else + return arena->addType(T{{left, right}}); + } + + template + TypeId distribute(TypeIterator it, TypeIterator endIt, BinaryFold f, TypeId ty) + { + std::vector result; + while (it != endIt) + { + result.push_back(apply(f, *it, ty)); + ++it; + } + return flatten(std::move(result)); } }; @@ -157,42 +192,48 @@ TypeId TypeReducer::reduce(TypeId ty) { ty = follow(ty); - if (std::find(seen.begin(), seen.end(), ty) != seen.end()) - return ty; + if (auto ctx = memoizedTypes->find(ty); ctx && ctx->irreducible) + return ctx->type; + else if (auto cyclicTy = cyclicTypes->find(ty)) + return *cyclicTy; - RecursionGuard rg = guard(ty); - checkCacheable(ty); + RecursionLimiter rl{&depth, FInt::LuauTypeReductionRecursionLimit}; + TypeId result = nullptr; if (auto i = get(ty)) - return foldl(begin(i), end(i), &TypeReducer::intersectionType); + result = foldl(begin(i), end(i), ty, &TypeReducer::intersectionType); else if (auto u = get(ty)) - return foldl(begin(u), end(u), &TypeReducer::unionType); + result = foldl(begin(u), end(u), ty, &TypeReducer::unionType); else if (get(ty) || get(ty)) - return tableType(ty); + result = tableType(ty); else if (get(ty)) - return functionType(ty); - else if (auto n = get(ty)) - return negationType(follow(n->ty)); + result = functionType(ty); + else if (get(ty)) + result = negationType(ty); else - return ty; + result = ty; + + return memoize(ty, result); } TypePackId TypeReducer::reduce(TypePackId tp) { tp = follow(tp); - if (std::find(seen.begin(), seen.end(), tp) != seen.end()) - return tp; + if (auto ctx = memoizedTypePacks->find(tp); ctx && ctx->irreducible) + return ctx->type; - RecursionGuard rg = guard(tp); - checkCacheable(tp); + RecursionLimiter rl{&depth, FInt::LuauTypeReductionRecursionLimit}; + bool didReduce = false; TypePackIterator it = begin(tp); std::vector head; while (it != end(tp)) { - head.push_back(reduce(*it)); + TypeId reducedTy = reduce(*it); + head.push_back(reducedTy); + didReduce |= follow(*it) != follow(reducedTy); ++it; } @@ -200,10 +241,22 @@ TypePackId TypeReducer::reduce(TypePackId tp) if (tail) { if (auto vtp = get(follow(*it.tail()))) - tail = arena->addTypePack(VariadicTypePack{reduce(vtp->ty), vtp->hidden}); + { + TypeId reducedTy = reduce(vtp->ty); + if (follow(vtp->ty) != follow(reducedTy)) + { + tail = arena->addTypePack(VariadicTypePack{reducedTy, vtp->hidden}); + didReduce = true; + } + } } - return arena->addTypePack(TypePack{std::move(head), tail}); + if (!didReduce) + return memoize(tp, tp); + else if (head.empty() && tail) + return memoize(tp, *tail); + else + return memoize(tp, arena->addTypePack(TypePack{std::move(head), tail})); } std::optional TypeReducer::intersectionType(TypeId left, TypeId right) @@ -236,18 +289,7 @@ std::optional TypeReducer::intersectionType(TypeId left, TypeId right) else if (get(right)) return std::nullopt; // T & error ~ T & error else if (auto ut = get(left)) - { - std::vector options; - for (TypeId option : ut) - { - if (auto result = intersectionType(option, right)) - options.push_back(*result); - else - options.push_back(arena->addType(IntersectionType{{option, right}})); - } - - return foldl(begin(options), end(options), &TypeReducer::unionType); // (A | B) & T ~ (A & T) | (B & T) - } + return reduce(distribute(begin(ut), end(ut), &TypeReducer::intersectionType, right)); // (A | B) & T ~ (A & T) | (B & T) else if (get(right)) return intersectionType(right, left); // T & (A | B) ~ (A | B) & T else if (auto [p1, p2] = get2(left, right); p1 && p2) @@ -294,14 +336,7 @@ std::optional TypeReducer::intersectionType(TypeId left, TypeId right) return builtinTypes->neverType; // Base & Unrelated ~ never } else if (auto [f1, f2] = get2(left, right); f1 && f2) - { - if (std::find(seen.begin(), seen.end(), left) != seen.end()) - return std::nullopt; - else if (std::find(seen.begin(), seen.end(), right) != seen.end()) - return std::nullopt; - return std::nullopt; // TODO - } else if (auto [t1, t2] = get2(left, right); t1 && t2) { if (t1->state == TableState::Free || t2->state == TableState::Free) @@ -309,10 +344,10 @@ std::optional TypeReducer::intersectionType(TypeId left, TypeId right) else if (t1->state == TableState::Generic || t2->state == TableState::Generic) return std::nullopt; // '{ x: T } & { x: U } ~ '{ x: T } & { x: U } - if (std::find(seen.begin(), seen.end(), left) != seen.end()) - return std::nullopt; - else if (std::find(seen.begin(), seen.end(), right) != seen.end()) - return std::nullopt; + if (cyclicTypes->find(left)) + return std::nullopt; // (t1 where t1 = { p: t1 }) & {} ~ t1 & {} + else if (cyclicTypes->find(right)) + return std::nullopt; // {} & (t1 where t1 = { p: t1 }) ~ {} & t1 TypeId resultTy = arena->addType(TableType{}); TableType* table = getMutable(resultTy); @@ -324,8 +359,7 @@ std::optional TypeReducer::intersectionType(TypeId left, TypeId right) // even if we have the corresponding property in the other one. if (auto other = t2->props.find(name); other != t2->props.end()) { - std::vector parts{prop.type, other->second.type}; - TypeId propTy = foldl(begin(parts), end(parts), &TypeReducer::intersectionType); + TypeId propTy = apply(&TypeReducer::intersectionType, prop.type, other->second.type); if (get(propTy)) return builtinTypes->neverType; // { p : string } & { p : number } ~ { p : string & number } ~ { p : never } ~ never else @@ -340,27 +374,33 @@ std::optional TypeReducer::intersectionType(TypeId left, TypeId right) // TODO: And vice versa, t2 properties against t1 indexer if it exists, // even if we have the corresponding property in the other one. if (!t1->props.count(name)) - table->props[name] = prop; // {} & { p : string } ~ { p : string } + table->props[name] = {reduce(prop.type)}; // {} & { p : string & string } ~ { p : string } } if (t1->indexer && t2->indexer) { - std::vector keyParts{t1->indexer->indexType, t2->indexer->indexType}; - TypeId keyTy = foldl(begin(keyParts), end(keyParts), &TypeReducer::intersectionType); + TypeId keyTy = apply(&TypeReducer::intersectionType, t1->indexer->indexType, t2->indexer->indexType); if (get(keyTy)) - return builtinTypes->neverType; // { [string]: _ } & { [number]: _ } ~ { [string & number]: _ } ~ { [never]: _ } ~ never + return std::nullopt; // { [string]: _ } & { [number]: _ } ~ { [string]: _ } & { [number]: _ } - std::vector valueParts{t1->indexer->indexResultType, t2->indexer->indexResultType}; - TypeId valueTy = foldl(begin(valueParts), end(valueParts), &TypeReducer::intersectionType); + TypeId valueTy = apply(&TypeReducer::intersectionType, t1->indexer->indexResultType, t2->indexer->indexResultType); if (get(valueTy)) return builtinTypes->neverType; // { [_]: string } & { [_]: number } ~ { [_]: string & number } ~ { [_]: never } ~ never table->indexer = TableIndexer{keyTy, valueTy}; } else if (t1->indexer) - table->indexer = t1->indexer; // { [number]: boolean } & { p : string } ~ { p : string, [number]: boolean } + { + TypeId keyTy = reduce(t1->indexer->indexType); + TypeId valueTy = reduce(t1->indexer->indexResultType); + table->indexer = TableIndexer{keyTy, valueTy}; // { [number]: boolean } & { p : string } ~ { p : string, [number]: boolean } + } else if (t2->indexer) - table->indexer = t2->indexer; // { p : string } & { [number]: boolean } ~ { p : string, [number]: boolean } + { + TypeId keyTy = reduce(t2->indexer->indexType); + TypeId valueTy = reduce(t2->indexer->indexResultType); + table->indexer = TableIndexer{keyTy, valueTy}; // { p : string } & { [number]: boolean } ~ { p : string, [number]: boolean } + } return resultTy; } @@ -506,22 +546,7 @@ std::optional TypeReducer::unionType(TypeId left, TypeId right) return std::nullopt; // Base | Unrelated ~ Base | Unrelated } else if (auto [nt, it] = get2(left, right); nt && it) - { - std::vector parts; - for (TypeId option : it) - { - if (auto result = unionType(left, option)) - parts.push_back(*result); - else - { - // TODO: does there exist a reduced form such that `~T | A` hasn't already reduced it, if `A & B` is irreducible? - // I want to say yes, but I can't generate a case that hits this code path. - parts.push_back(arena->addType(UnionType{{left, option}})); - } - } - - return foldl(begin(parts), end(parts), &TypeReducer::intersectionType); // ~T | (A & B) ~ (~T | A) & (~T | B) - } + return reduce(distribute(begin(it), end(it), &TypeReducer::unionType, left)); // ~T | (A & B) ~ (~T | A) & (~T | B) else if (auto [it, nt] = get2(left, right); it && nt) return unionType(right, left); // (A & B) | ~T ~ ~T | (A & B) else if (auto [nl, nr] = get2(left, right); nl && nr) @@ -628,8 +653,6 @@ std::optional TypeReducer::unionType(TypeId left, TypeId right) TypeId TypeReducer::tableType(TypeId ty) { - RecursionGuard rg = guard(ty); - if (auto mt = get(ty)) { auto [copiedTy, copied] = copy(ty, mt); @@ -639,15 +662,30 @@ TypeId TypeReducer::tableType(TypeId ty) } else if (auto tt = get(ty)) { + // Because of `typeof()`, we need to preserve pointer identity of free/unsealed tables so that + // all mutations that occurs on this will be applied without leaking the implementation details. + // As a result, we'll just use the type instead of cloning it if it's free/unsealed. + // + // We could choose to do in-place reductions here, but to be on the safer side, I propose that we do not. + if (tt->state == TableState::Free || tt->state == TableState::Unsealed) + return ty; + auto [copiedTy, copied] = copy(ty, tt); for (auto& [name, prop] : copied->props) - prop.type = reduce(prop.type); + { + TypeId propTy = reduce(prop.type); + if (get(propTy)) + return builtinTypes->neverType; + else + prop.type = propTy; + } - if (auto& indexer = copied->indexer) + if (copied->indexer) { - indexer->indexType = reduce(indexer->indexType); - indexer->indexResultType = reduce(indexer->indexResultType); + TypeId keyTy = reduce(copied->indexer->indexType); + TypeId valueTy = reduce(copied->indexer->indexResultType); + copied->indexer = TableIndexer{keyTy, valueTy}; } for (TypeId& ty : copied->instantiatedTypeParams) @@ -659,16 +697,14 @@ TypeId TypeReducer::tableType(TypeId ty) return copiedTy; } else - handle->ice("Unexpected type in TypeReducer::tableType"); + handle->ice("TypeReducer::tableType expects a TableType or MetatableType"); } TypeId TypeReducer::functionType(TypeId ty) { - RecursionGuard rg = guard(ty); - const FunctionType* f = get(ty); if (!f) - handle->ice("TypeReducer::reduce expects a FunctionType"); + handle->ice("TypeReducer::functionType expects a FunctionType"); // TODO: once we have bounded quantification, we need to be able to reduce the generic bounds. auto [copiedTy, copied] = copy(ty, f); @@ -679,140 +715,238 @@ TypeId TypeReducer::functionType(TypeId ty) TypeId TypeReducer::negationType(TypeId ty) { - RecursionGuard rg = guard(ty); + const NegationType* n = get(ty); + if (!n) + return arena->addType(NegationType{ty}); - if (auto nn = get(ty)) + if (auto nn = get(n->ty)) return nn->ty; // ~~T ~ T - else if (get(ty)) + else if (get(n->ty)) return builtinTypes->unknownType; // ~never ~ unknown - else if (get(ty)) + else if (get(n->ty)) return builtinTypes->neverType; // ~unknown ~ never - else if (get(ty)) + else if (get(n->ty)) return builtinTypes->anyType; // ~any ~ any - else if (auto ni = get(ty)) + else if (auto ni = get(n->ty)) { std::vector options; for (TypeId part : ni) - options.push_back(negationType(part)); - return foldl(begin(options), end(options), &TypeReducer::unionType); // ~(T & U) ~ (~T | ~U) + options.push_back(negationType(arena->addType(NegationType{part}))); + return reduce(flatten(std::move(options))); // ~(T & U) ~ (~T | ~U) } - else if (auto nu = get(ty)) + else if (auto nu = get(n->ty)) { std::vector parts; for (TypeId option : nu) - parts.push_back(negationType(option)); - return foldl(begin(parts), end(parts), &TypeReducer::intersectionType); // ~(T | U) ~ (~T & ~U) + parts.push_back(negationType(arena->addType(NegationType{option}))); + return reduce(flatten(std::move(parts))); // ~(T | U) ~ (~T & ~U) } else - return arena->addType(NegationType{ty}); // for all T except the ones handled above, ~T ~ ~T + return ty; // for all T except the ones handled above, ~T ~ ~T } -RecursionGuard TypeReducer::guard(TypeId ty) +bool TypeReducer::isIrreducible(TypeId ty) { - seen.push_back(ty); - return RecursionGuard{&depth, FInt::LuauTypeReductionRecursionLimit, &seen}; + ty = follow(ty); + + // Only does shallow check, the TypeReducer itself already does deep traversal. + if (auto ctx = memoizedTypes->find(ty); ctx && ctx->irreducible) + return true; + else if (get(ty) || get(ty) || get(ty)) + return false; + else if (auto tt = get(ty); tt && (tt->state == TableState::Free || tt->state == TableState::Unsealed)) + return false; + else + return true; } -RecursionGuard TypeReducer::guard(TypePackId tp) +bool TypeReducer::isIrreducible(TypePackId tp) { - seen.push_back(tp); - return RecursionGuard{&depth, FInt::LuauTypeReductionRecursionLimit, &seen}; + tp = follow(tp); + + // Only does shallow check, the TypeReducer itself already does deep traversal. + if (auto ctx = memoizedTypePacks->find(tp); ctx && ctx->irreducible) + return true; + else if (get(tp) || get(tp)) + return false; + else if (auto vtp = get(tp)) + return isIrreducible(vtp->ty); + else + return true; } -void TypeReducer::checkCacheable(TypeId ty) +TypeId TypeReducer::memoize(TypeId ty, TypeId reducedTy) { - if (!cacheOk) - return; - ty = follow(ty); + reducedTy = follow(reducedTy); - // Only does shallow check, the TypeReducer itself already does deep traversal. - if (get(ty) || get(ty) || get(ty)) - cacheOk = false; - else if (auto tt = get(ty); tt && (tt->state == TableState::Free || tt->state == TableState::Unsealed)) - cacheOk = false; + // The irreducibility of this [`reducedTy`] depends on whether its contents are themselves irreducible. + // We don't need to recurse much further than that, because we already record the irreducibility from + // the bottom up. + bool irreducible = isIrreducible(reducedTy); + if (auto it = get(reducedTy)) + { + for (TypeId part : it) + irreducible &= isIrreducible(part); + } + else if (auto ut = get(reducedTy)) + { + for (TypeId option : ut) + irreducible &= isIrreducible(option); + } + else if (auto tt = get(reducedTy)) + { + for (auto& [k, p] : tt->props) + irreducible &= isIrreducible(p.type); + + if (tt->indexer) + { + irreducible &= isIrreducible(tt->indexer->indexType); + irreducible &= isIrreducible(tt->indexer->indexResultType); + } + + for (auto ta : tt->instantiatedTypeParams) + irreducible &= isIrreducible(ta); + + for (auto tpa : tt->instantiatedTypePackParams) + irreducible &= isIrreducible(tpa); + } + else if (auto mt = get(reducedTy)) + { + irreducible &= isIrreducible(mt->table); + irreducible &= isIrreducible(mt->metatable); + } + else if (auto ft = get(reducedTy)) + { + irreducible &= isIrreducible(ft->argTypes); + irreducible &= isIrreducible(ft->retTypes); + } + else if (auto nt = get(reducedTy)) + irreducible &= isIrreducible(nt->ty); + + (*memoizedTypes)[ty] = {reducedTy, irreducible}; + (*memoizedTypes)[reducedTy] = {reducedTy, irreducible}; + return reducedTy; } -void TypeReducer::checkCacheable(TypePackId tp) +TypePackId TypeReducer::memoize(TypePackId tp, TypePackId reducedTp) { - if (!cacheOk) - return; - tp = follow(tp); + reducedTp = follow(reducedTp); - // Only does shallow check, the TypeReducer itself already does deep traversal. - if (get(tp) || get(tp)) - cacheOk = false; -} + bool irreducible = isIrreducible(reducedTp); + TypePackIterator it = begin(tp); + while (it != end(tp)) + { + irreducible &= isIrreducible(*it); + ++it; + } -} // namespace + if (it.tail()) + irreducible &= isIrreducible(*it.tail()); -TypeReduction::TypeReduction(NotNull arena, NotNull builtinTypes, NotNull handle) - : arena(arena) - , builtinTypes(builtinTypes) - , handle(handle) -{ + (*memoizedTypePacks)[tp] = {reducedTp, irreducible}; + (*memoizedTypePacks)[reducedTp] = {reducedTp, irreducible}; + return reducedTp; } -std::optional TypeReduction::reduce(TypeId ty) +TypeId TypeReducer::memoizedOr(TypeId ty) const { - if (auto found = cachedTypes.find(ty)) - return *found; + ty = follow(ty); - auto [reducedTy, cacheOk] = reduceImpl(ty); - if (cacheOk) - cachedTypes[ty] = *reducedTy; + if (auto ctx = memoizedTypes->find(ty)) + return ctx->type; + else + return ty; +}; - return reducedTy; -} +TypePackId TypeReducer::memoizedOr(TypePackId tp) const +{ + tp = follow(tp); -std::optional TypeReduction::reduce(TypePackId tp) + if (auto ctx = memoizedTypePacks->find(tp)) + return ctx->type; + else + return tp; +}; + +struct MarkCycles : TypeVisitor { - if (auto found = cachedTypePacks.find(tp)) - return *found; + DenseHashSet cyclicTypes{nullptr}; - auto [reducedTp, cacheOk] = reduceImpl(tp); - if (cacheOk) - cachedTypePacks[tp] = *reducedTp; + void cycle(TypeId ty) override + { + cyclicTypes.insert(ty); + } - return reducedTp; + bool visit(TypeId ty) override + { + return !cyclicTypes.find(ty); + } +}; + +} // namespace + +TypeReduction::TypeReduction( + NotNull arena, NotNull builtinTypes, NotNull handle, const TypeReductionOptions& opts) + : arena(arena) + , builtinTypes(builtinTypes) + , handle(handle) + , options(opts) +{ } -std::pair, bool> TypeReduction::reduceImpl(TypeId ty) +std::optional TypeReduction::reduce(TypeId ty) { - if (FFlag::DebugLuauDontReduceTypes) - return {ty, false}; + ty = follow(ty); - if (hasExceededCartesianProductLimit(ty)) - return {std::nullopt, false}; + if (FFlag::DebugLuauDontReduceTypes) + return ty; + else if (!options.allowTypeReductionsFromOtherArenas && ty->owningArena != arena) + return ty; + else if (auto ctx = memoizedTypes.find(ty); ctx && ctx->irreducible) + return ctx->type; + else if (hasExceededCartesianProductLimit(ty)) + return std::nullopt; try { - TypeReducer reducer{arena, builtinTypes, handle}; - return {reducer.reduce(ty), reducer.cacheOk}; + MarkCycles finder; + finder.traverse(ty); + + TypeReducer reducer{arena, builtinTypes, handle, &memoizedTypes, &memoizedTypePacks, &finder.cyclicTypes}; + return reducer.reduce(ty); } catch (const RecursionLimitException&) { - return {std::nullopt, false}; + return std::nullopt; } } -std::pair, bool> TypeReduction::reduceImpl(TypePackId tp) +std::optional TypeReduction::reduce(TypePackId tp) { - if (FFlag::DebugLuauDontReduceTypes) - return {tp, false}; + tp = follow(tp); - if (hasExceededCartesianProductLimit(tp)) - return {std::nullopt, false}; + if (FFlag::DebugLuauDontReduceTypes) + return tp; + else if (!options.allowTypeReductionsFromOtherArenas && tp->owningArena != arena) + return tp; + else if (auto ctx = memoizedTypePacks.find(tp); ctx && ctx->irreducible) + return ctx->type; + else if (hasExceededCartesianProductLimit(tp)) + return std::nullopt; try { - TypeReducer reducer{arena, builtinTypes, handle}; - return {reducer.reduce(tp), reducer.cacheOk}; + MarkCycles finder; + finder.traverse(tp); + + TypeReducer reducer{arena, builtinTypes, handle, &memoizedTypes, &memoizedTypePacks, &finder.cyclicTypes}; + return reducer.reduce(tp); } catch (const RecursionLimitException&) { - return {std::nullopt, false}; + return std::nullopt; } } diff --git a/Analysis/src/Unifiable.cpp b/Analysis/src/Unifiable.cpp index e0cc14149..9db8f7f00 100644 --- a/Analysis/src/Unifiable.cpp +++ b/Analysis/src/Unifiable.cpp @@ -1,8 +1,6 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Unifiable.h" -LUAU_FASTFLAG(LuauTypeNormalization2); - namespace Luau { namespace Unifiable @@ -11,19 +9,19 @@ namespace Unifiable static int nextIndex = 0; Free::Free(TypeLevel level) - : index(FFlag::LuauTypeNormalization2 ? ++nextIndex : ++DEPRECATED_nextIndex) + : index(++nextIndex) , level(level) { } Free::Free(Scope* scope) - : index(FFlag::LuauTypeNormalization2 ? ++nextIndex : ++DEPRECATED_nextIndex) + : index(++nextIndex) , scope(scope) { } Free::Free(Scope* scope, TypeLevel level) - : index(FFlag::LuauTypeNormalization2 ? ++nextIndex : ++DEPRECATED_nextIndex) + : index(++nextIndex) , level(level) , scope(scope) { @@ -32,33 +30,33 @@ Free::Free(Scope* scope, TypeLevel level) int Free::DEPRECATED_nextIndex = 0; Generic::Generic() - : index(FFlag::LuauTypeNormalization2 ? ++nextIndex : ++DEPRECATED_nextIndex) + : index(++nextIndex) , name("g" + std::to_string(index)) { } Generic::Generic(TypeLevel level) - : index(FFlag::LuauTypeNormalization2 ? ++nextIndex : ++DEPRECATED_nextIndex) + : index(++nextIndex) , level(level) , name("g" + std::to_string(index)) { } Generic::Generic(const Name& name) - : index(FFlag::LuauTypeNormalization2 ? ++nextIndex : ++DEPRECATED_nextIndex) + : index(++nextIndex) , name(name) , explicitName(true) { } Generic::Generic(Scope* scope) - : index(FFlag::LuauTypeNormalization2 ? ++nextIndex : ++DEPRECATED_nextIndex) + : index(++nextIndex) , scope(scope) { } Generic::Generic(TypeLevel level, const Name& name) - : index(FFlag::LuauTypeNormalization2 ? ++nextIndex : ++DEPRECATED_nextIndex) + : index(++nextIndex) , level(level) , name(name) , explicitName(true) @@ -66,7 +64,7 @@ Generic::Generic(TypeLevel level, const Name& name) } Generic::Generic(Scope* scope, const Name& name) - : index(FFlag::LuauTypeNormalization2 ? ++nextIndex : ++DEPRECATED_nextIndex) + : index(++nextIndex) , scope(scope) , name(name) , explicitName(true) diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index d35e37710..80f63f10a 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -18,15 +18,13 @@ LUAU_FASTINT(LuauTypeInferTypePackLoopLimit); LUAU_FASTFLAG(LuauErrorRecoveryType); LUAU_FASTFLAG(LuauUnknownAndNeverType) -LUAU_FASTFLAGVARIABLE(LuauReportTypeMismatchForTypePackUnificationFailure, false) -LUAU_FASTFLAGVARIABLE(LuauSubtypeNormalizer, false); LUAU_FASTFLAGVARIABLE(LuauScalarShapeSubtyping, false) +LUAU_FASTFLAGVARIABLE(LuauUnifyAnyTxnLog, false) LUAU_FASTFLAGVARIABLE(LuauInstantiateInSubtyping, false) -LUAU_FASTFLAGVARIABLE(LuauOverloadedFunctionSubtypingPerf, false); LUAU_FASTFLAGVARIABLE(LuauScalarShapeUnifyToMtOwner2, false) LUAU_FASTFLAGVARIABLE(LuauUninhabitedSubAnything2, false) +LUAU_FASTFLAGVARIABLE(LuauMaintainScopesInUnifier, false) LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution) -LUAU_FASTFLAG(LuauTxnLogTypePackIterator) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) LUAU_FASTFLAG(LuauNegatedFunctionTypes) LUAU_FASTFLAG(LuauNegatedClassTypes) @@ -378,7 +376,7 @@ Unifier::Unifier(NotNull normalizer, Mode mode, NotNull scope , variance(variance) , sharedState(*normalizer->sharedState) { - normalize = FFlag::LuauSubtypeNormalizer; + normalize = true; LUAU_ASSERT(sharedState.iceHandler); } @@ -480,17 +478,40 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool return; } - if (get(superTy) || get(superTy) || get(superTy)) - return tryUnifyWithAny(subTy, superTy); + if (FFlag::LuauUnifyAnyTxnLog) + { + if (log.get(superTy)) + return tryUnifyWithAny(subTy, builtinTypes->anyType); + + if (log.get(superTy)) + return tryUnifyWithAny(subTy, builtinTypes->errorType); - if (get(subTy)) - return tryUnifyWithAny(superTy, subTy); + if (log.get(superTy)) + return tryUnifyWithAny(subTy, builtinTypes->unknownType); - if (log.get(subTy)) - return tryUnifyWithAny(superTy, subTy); + if (log.get(subTy)) + return tryUnifyWithAny(superTy, builtinTypes->anyType); - if (log.get(subTy)) - return tryUnifyWithAny(superTy, subTy); + if (log.get(subTy)) + return tryUnifyWithAny(superTy, builtinTypes->errorType); + + if (log.get(subTy)) + return tryUnifyWithAny(superTy, builtinTypes->neverType); + } + else + { + if (get(superTy) || get(superTy) || get(superTy)) + return tryUnifyWithAny(subTy, superTy); + + if (get(subTy)) + return tryUnifyWithAny(superTy, subTy); + + if (log.get(subTy)) + return tryUnifyWithAny(superTy, subTy); + + if (log.get(subTy)) + return tryUnifyWithAny(superTy, subTy); + } auto& cache = sharedState.cachedUnify; @@ -524,10 +545,6 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool { tryUnifyUnionWithType(subTy, subUnion, superTy); } - else if (const UnionType* uv = (FFlag::LuauSubtypeNormalizer ? nullptr : log.getMutable(superTy))) - { - tryUnifyTypeWithUnion(subTy, superTy, uv, cacheEnabled, isFunctionCall); - } else if (const IntersectionType* uv = log.getMutable(superTy)) { tryUnifyTypeWithIntersection(subTy, superTy, uv); @@ -915,8 +932,6 @@ void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionType* void Unifier::tryUnifyNormalizedTypes( TypeId subTy, TypeId superTy, const NormalizedType& subNorm, const NormalizedType& superNorm, std::string reason, std::optional error) { - LUAU_ASSERT(FFlag::LuauSubtypeNormalizer); - if (get(superNorm.tops) || get(superNorm.tops) || get(subNorm.tops)) return; else if (get(subNorm.tops)) @@ -1096,12 +1111,9 @@ TypePackId Unifier::tryApplyOverloadedFunction(TypeId function, const Normalized log.concat(std::move(innerState.log)); if (result) { - if (FFlag::LuauOverloadedFunctionSubtypingPerf) - { - innerState.log.clear(); - innerState.tryUnify_(*result, ftv->retTypes); - } - if (FFlag::LuauOverloadedFunctionSubtypingPerf && innerState.errors.empty()) + innerState.log.clear(); + innerState.tryUnify_(*result, ftv->retTypes); + if (innerState.errors.empty()) log.concat(std::move(innerState.log)); // Annoyingly, since we don't support intersection of generic type packs, // the intersection may fail. We rather arbitrarily use the first matching overload @@ -1250,8 +1262,11 @@ struct WeirdIter LUAU_ASSERT(canGrow()); LUAU_ASSERT(log.getMutable(newTail)); - level = log.getMutable(packId)->level; - scope = log.getMutable(packId)->scope; + auto freePack = log.getMutable(packId); + + level = freePack->level; + if (FFlag::LuauMaintainScopesInUnifier && freePack->scope != nullptr) + scope = freePack->scope; log.replace(packId, BoundTypePack(newTail)); packId = newTail; pack = log.getMutable(newTail); @@ -1380,6 +1395,12 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal auto superIter = WeirdIter(superTp, log); auto subIter = WeirdIter(subTp, log); + if (FFlag::LuauMaintainScopesInUnifier) + { + superIter.scope = scope.get(); + subIter.scope = scope.get(); + } + auto mkFreshType = [this](Scope* scope, TypeLevel level) { return types->freshType(scope, level); }; @@ -1420,15 +1441,9 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal // If both are at the end, we're done if (!superIter.good() && !subIter.good()) { - if (!FFlag::LuauTxnLogTypePackIterator && subTpv->tail && superTpv->tail) - { - tryUnify_(*subTpv->tail, *superTpv->tail); - break; - } - const bool lFreeTail = superTpv->tail && log.getMutable(log.follow(*superTpv->tail)) != nullptr; const bool rFreeTail = subTpv->tail && log.getMutable(log.follow(*subTpv->tail)) != nullptr; - if (FFlag::LuauTxnLogTypePackIterator && lFreeTail && rFreeTail) + if (lFreeTail && rFreeTail) { tryUnify_(*subTpv->tail, *superTpv->tail); } @@ -1440,7 +1455,7 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal { tryUnify_(emptyTp, *subTpv->tail); } - else if (FFlag::LuauTxnLogTypePackIterator && subTpv->tail && superTpv->tail) + else if (subTpv->tail && superTpv->tail) { if (log.getMutable(superIter.packId)) tryUnifyVariadics(subIter.packId, superIter.packId, false, int(subIter.index)); @@ -1523,10 +1538,7 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal } else { - if (FFlag::LuauReportTypeMismatchForTypePackUnificationFailure) - reportError(location, TypePackMismatch{subTp, superTp}); - else - reportError(location, GenericError{"Failed to unify type packs"}); + reportError(location, TypePackMismatch{subTp, superTp}); } } @@ -2356,11 +2368,11 @@ void Unifier::tryUnifyVariadics(TypePackId subTp, TypePackId superTp, bool rever if (!superVariadic) ice("passed non-variadic pack to tryUnifyVariadics"); - if (const VariadicTypePack* subVariadic = FFlag::LuauTxnLogTypePackIterator ? log.get(subTp) : get(subTp)) + if (const VariadicTypePack* subVariadic = log.get(subTp)) { tryUnify_(reversed ? superVariadic->ty : subVariadic->ty, reversed ? subVariadic->ty : superVariadic->ty); } - else if (FFlag::LuauTxnLogTypePackIterator ? log.get(subTp) : get(subTp)) + else if (log.get(subTp)) { TypePackIterator subIter = begin(subTp, &log); TypePackIterator subEnd = end(subTp); @@ -2465,9 +2477,18 @@ void Unifier::tryUnifyWithAny(TypeId subTy, TypeId anyTy) { LUAU_ASSERT(get(anyTy) || get(anyTy) || get(anyTy) || get(anyTy)); - // These types are not visited in general loop below - if (get(subTy) || get(subTy) || get(subTy)) - return; + if (FFlag::LuauUnifyAnyTxnLog) + { + // These types are not visited in general loop below + if (log.get(subTy) || log.get(subTy) || log.get(subTy)) + return; + } + else + { + // These types are not visited in general loop below + if (get(subTy) || get(subTy) || get(subTy)) + return; + } TypePackId anyTp; if (FFlag::LuauUnknownAndNeverType) diff --git a/Ast/include/Luau/Ast.h b/Ast/include/Luau/Ast.h index aa87d9e86..7731312db 100644 --- a/Ast/include/Luau/Ast.h +++ b/Ast/include/Luau/Ast.h @@ -983,12 +983,13 @@ class AstStatTypeAlias : public AstStat public: LUAU_RTTI(AstStatTypeAlias) - AstStatTypeAlias(const Location& location, const AstName& name, const AstArray& generics, + AstStatTypeAlias(const Location& location, const AstName& name, const Location& nameLocation, const AstArray& generics, const AstArray& genericPacks, AstType* type, bool exported); void visit(AstVisitor* visitor) override; AstName name; + Location nameLocation; AstArray generics; AstArray genericPacks; AstType* type; diff --git a/Ast/src/Ast.cpp b/Ast/src/Ast.cpp index cbed8bae1..e01ced049 100644 --- a/Ast/src/Ast.cpp +++ b/Ast/src/Ast.cpp @@ -647,10 +647,11 @@ void AstStatLocalFunction::visit(AstVisitor* visitor) func->visit(visitor); } -AstStatTypeAlias::AstStatTypeAlias(const Location& location, const AstName& name, const AstArray& generics, - const AstArray& genericPacks, AstType* type, bool exported) +AstStatTypeAlias::AstStatTypeAlias(const Location& location, const AstName& name, const Location& nameLocation, + const AstArray& generics, const AstArray& genericPacks, AstType* type, bool exported) : AstStat(ClassIndex(), location) , name(name) + , nameLocation(nameLocation) , generics(generics) , genericPacks(genericPacks) , type(type) diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index dea54c168..c71bd7c58 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -768,7 +768,7 @@ AstStat* Parser::parseTypeAlias(const Location& start, bool exported) AstType* type = parseTypeAnnotation(); - return allocator.alloc(Location(start, type->location), name->name, generics, genericPacks, type, exported); + return allocator.alloc(Location(start, type->location), name->name, name->location, generics, genericPacks, type, exported); } AstDeclaredClassProp Parser::parseDeclaredClassMethod() diff --git a/CodeGen/src/CodeGen.cpp b/CodeGen/src/CodeGen.cpp index 72c1294f4..405b92ddd 100644 --- a/CodeGen/src/CodeGen.cpp +++ b/CodeGen/src/CodeGen.cpp @@ -244,7 +244,7 @@ static int emitInst(AssemblyBuilderX64& build, NativeState& data, ModuleHelpers& emitInstForNPrep(build, pc, i, labelarr[i + 1 + LUAU_INSN_D(*pc)]); break; case LOP_FORNLOOP: - emitInstForNLoop(build, pc, i, labelarr[i + 1 + LUAU_INSN_D(*pc)]); + emitInstForNLoop(build, pc, i, labelarr[i + 1 + LUAU_INSN_D(*pc)], next); break; case LOP_FORGLOOP: emitinstForGLoop(build, pc, i, labelarr[i + 1 + LUAU_INSN_D(*pc)], next, fallback); diff --git a/CodeGen/src/EmitCommonX64.cpp b/CodeGen/src/EmitCommonX64.cpp index 2c410ae87..f1f6ba66e 100644 --- a/CodeGen/src/EmitCommonX64.cpp +++ b/CodeGen/src/EmitCommonX64.cpp @@ -251,7 +251,7 @@ void callGetFastTmOrFallback(AssemblyBuilderX64& build, RegisterX64 table, TMS t // rArg1 is already prepared build.mov(rArg2, tm); build.mov(rax, qword[rState + offsetof(lua_State, global)]); - build.mov(rArg3, qword[rax + offsetof(global_State, tmname[tm])]); + build.mov(rArg3, qword[rax + offsetof(global_State, tmname) + tm * sizeof(TString*)]); build.call(qword[rNativeContext + offsetof(NativeContext, luaT_gettm)]); } diff --git a/CodeGen/src/EmitInstructionX64.cpp b/CodeGen/src/EmitInstructionX64.cpp index 7b6e1c643..e1212eab1 100644 --- a/CodeGen/src/EmitInstructionX64.cpp +++ b/CodeGen/src/EmitInstructionX64.cpp @@ -1010,17 +1010,15 @@ static int emitInstFastCallN( if (nparams == LUA_MULTRET) { - // TODO: for SystemV ABI we can compute the result directly into rArg6 // L->top - (ra + 1) - build.mov(rcx, qword[rState + offsetof(lua_State, top)]); + RegisterX64 reg = (build.abi == ABIX64::Windows) ? rcx : rArg6; + build.mov(reg, qword[rState + offsetof(lua_State, top)]); build.lea(rdx, addr[rBase + (ra + 1) * sizeof(TValue)]); - build.sub(rcx, rdx); - build.shr(rcx, kTValueSizeLog2); + build.sub(reg, rdx); + build.shr(reg, kTValueSizeLog2); if (build.abi == ABIX64::Windows) - build.mov(sArg6, rcx); - else - build.mov(rArg6, rcx); + build.mov(sArg6, reg); } else { @@ -1126,7 +1124,7 @@ void emitInstForNPrep(AssemblyBuilderX64& build, const Instruction* pc, int pcpo build.setLabel(exit); } -void emitInstForNLoop(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& loopRepeat) +void emitInstForNLoop(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& loopRepeat, Label& loopExit) { emitInterrupt(build, pcpos); @@ -1144,20 +1142,18 @@ void emitInstForNLoop(AssemblyBuilderX64& build, const Instruction* pc, int pcpo build.vaddsd(idx, idx, step); build.vmovsd(luauRegValue(ra + 2), idx); - Label reverse, exit; + Label reverse; // step <= 0 jumpOnNumberCmp(build, noreg, step, zero, ConditionX64::LessEqual, reverse); // false: idx <= limit jumpOnNumberCmp(build, noreg, idx, limit, ConditionX64::LessEqual, loopRepeat); - build.jmp(exit); + build.jmp(loopExit); // true: limit <= idx build.setLabel(reverse); jumpOnNumberCmp(build, noreg, limit, idx, ConditionX64::LessEqual, loopRepeat); - - build.setLabel(exit); } void emitinstForGLoop(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& loopRepeat, Label& loopExit, Label& fallback) diff --git a/CodeGen/src/EmitInstructionX64.h b/CodeGen/src/EmitInstructionX64.h index 96501e63d..83dfa8c84 100644 --- a/CodeGen/src/EmitInstructionX64.h +++ b/CodeGen/src/EmitInstructionX64.h @@ -61,7 +61,7 @@ int emitInstFastCall2(AssemblyBuilderX64& build, const Instruction* pc, int pcpo int emitInstFastCall2K(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback); int emitInstFastCall(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback); void emitInstForNPrep(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& loopExit); -void emitInstForNLoop(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& loopRepeat); +void emitInstForNLoop(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& loopRepeat, Label& loopExit); void emitinstForGLoop(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& loopRepeat, Label& loopExit, Label& fallback); void emitinstForGLoopFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& loopRepeat); void emitInstForGPrepNext(AssemblyBuilderX64& build, const Instruction* pc, Label& target, Label& fallback); diff --git a/CodeGen/src/IrBuilder.cpp b/CodeGen/src/IrBuilder.cpp new file mode 100644 index 000000000..56968c1aa --- /dev/null +++ b/CodeGen/src/IrBuilder.cpp @@ -0,0 +1,563 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "IrBuilder.h" + +#include "Luau/Common.h" + +#include "CustomExecUtils.h" +#include "IrTranslation.h" +#include "IrUtils.h" + +#include "lapi.h" + +namespace Luau +{ +namespace CodeGen +{ + +constexpr unsigned kNoAssociatedBlockIndex = ~0u; + +void IrBuilder::buildFunctionIr(Proto* proto) +{ + function.proto = proto; + + // Rebuild original control flow blocks + rebuildBytecodeBasicBlocks(proto); + + function.bcMapping.resize(proto->sizecode, {~0u, 0}); + + // Translate all instructions to IR inside blocks + for (int i = 0; i < proto->sizecode;) + { + const Instruction* pc = &proto->code[i]; + LuauOpcode op = LuauOpcode(LUAU_INSN_OP(*pc)); + + int nexti = i + getOpLength(op); + LUAU_ASSERT(nexti <= proto->sizecode); + + function.bcMapping[i] = {uint32_t(function.instructions.size()), 0}; + + // Begin new block at this instruction if it was in the bytecode or requested during translation + if (instIndexToBlock[i] != kNoAssociatedBlockIndex) + beginBlock(blockAtInst(i)); + + translateInst(op, pc, i); + + i = nexti; + LUAU_ASSERT(i <= proto->sizecode); + + // If we are going into a new block at the next instruction and it's a fallthrough, jump has to be placed to mark block termination + if (i < int(instIndexToBlock.size()) && instIndexToBlock[i] != kNoAssociatedBlockIndex) + { + if (!isBlockTerminator(function.instructions.back().cmd)) + inst(IrCmd::JUMP, blockAtInst(i)); + } + } +} + +void IrBuilder::rebuildBytecodeBasicBlocks(Proto* proto) +{ + instIndexToBlock.resize(proto->sizecode, kNoAssociatedBlockIndex); + + // Mark jump targets + std::vector jumpTargets(proto->sizecode, 0); + + for (int i = 0; i < proto->sizecode;) + { + const Instruction* pc = &proto->code[i]; + LuauOpcode op = LuauOpcode(LUAU_INSN_OP(*pc)); + + int target = getJumpTarget(*pc, uint32_t(i)); + + if (target >= 0 && !isFastCall(op)) + jumpTargets[target] = true; + + i += getOpLength(op); + LUAU_ASSERT(i <= proto->sizecode); + } + + + // Bytecode blocks are created at bytecode jump targets and the start of a function + jumpTargets[0] = true; + + for (int i = 0; i < proto->sizecode; i++) + { + if (jumpTargets[i]) + { + IrOp b = block(IrBlockKind::Bytecode); + instIndexToBlock[i] = b.index; + } + } +} + +void IrBuilder::translateInst(LuauOpcode op, const Instruction* pc, int i) +{ + switch (op) + { + case LOP_NOP: + break; + case LOP_LOADNIL: + translateInstLoadNil(*this, pc); + break; + case LOP_LOADB: + translateInstLoadB(*this, pc, i); + break; + case LOP_LOADN: + translateInstLoadN(*this, pc); + break; + case LOP_LOADK: + translateInstLoadK(*this, pc); + break; + case LOP_LOADKX: + translateInstLoadKX(*this, pc); + break; + case LOP_MOVE: + translateInstMove(*this, pc); + break; + case LOP_GETGLOBAL: + translateInstGetGlobal(*this, pc, i); + break; + case LOP_SETGLOBAL: + translateInstSetGlobal(*this, pc, i); + break; + case LOP_CALL: + inst(IrCmd::LOP_CALL, constUint(i)); + + if (activeFastcallFallback) + { + inst(IrCmd::JUMP, fastcallFallbackReturn); + + beginBlock(fastcallFallbackReturn); + + activeFastcallFallback = false; + } + break; + case LOP_RETURN: + inst(IrCmd::LOP_RETURN, constUint(i)); + break; + case LOP_GETTABLE: + translateInstGetTable(*this, pc, i); + break; + case LOP_SETTABLE: + translateInstSetTable(*this, pc, i); + break; + case LOP_GETTABLEKS: + translateInstGetTableKS(*this, pc, i); + break; + case LOP_SETTABLEKS: + translateInstSetTableKS(*this, pc, i); + break; + case LOP_GETTABLEN: + translateInstGetTableN(*this, pc, i); + break; + case LOP_SETTABLEN: + translateInstSetTableN(*this, pc, i); + break; + case LOP_JUMP: + translateInstJump(*this, pc, i); + break; + case LOP_JUMPBACK: + translateInstJumpBack(*this, pc, i); + break; + case LOP_JUMPIF: + translateInstJumpIf(*this, pc, i, /* not_ */ false); + break; + case LOP_JUMPIFNOT: + translateInstJumpIf(*this, pc, i, /* not_ */ true); + break; + case LOP_JUMPIFEQ: + translateInstJumpIfEq(*this, pc, i, /* not_ */ false); + break; + case LOP_JUMPIFLE: + translateInstJumpIfCond(*this, pc, i, IrCondition::LessEqual); + break; + case LOP_JUMPIFLT: + translateInstJumpIfCond(*this, pc, i, IrCondition::Less); + break; + case LOP_JUMPIFNOTEQ: + translateInstJumpIfEq(*this, pc, i, /* not_ */ true); + break; + case LOP_JUMPIFNOTLE: + translateInstJumpIfCond(*this, pc, i, IrCondition::NotLessEqual); + break; + case LOP_JUMPIFNOTLT: + translateInstJumpIfCond(*this, pc, i, IrCondition::NotLess); + break; + case LOP_JUMPX: + translateInstJumpX(*this, pc, i); + break; + case LOP_JUMPXEQKNIL: + translateInstJumpxEqNil(*this, pc, i); + break; + case LOP_JUMPXEQKB: + translateInstJumpxEqB(*this, pc, i); + break; + case LOP_JUMPXEQKN: + translateInstJumpxEqN(*this, pc, i); + break; + case LOP_JUMPXEQKS: + translateInstJumpxEqS(*this, pc, i); + break; + case LOP_ADD: + translateInstBinary(*this, pc, i, TM_ADD); + break; + case LOP_SUB: + translateInstBinary(*this, pc, i, TM_SUB); + break; + case LOP_MUL: + translateInstBinary(*this, pc, i, TM_MUL); + break; + case LOP_DIV: + translateInstBinary(*this, pc, i, TM_DIV); + break; + case LOP_MOD: + translateInstBinary(*this, pc, i, TM_MOD); + break; + case LOP_POW: + translateInstBinary(*this, pc, i, TM_POW); + break; + case LOP_ADDK: + translateInstBinaryK(*this, pc, i, TM_ADD); + break; + case LOP_SUBK: + translateInstBinaryK(*this, pc, i, TM_SUB); + break; + case LOP_MULK: + translateInstBinaryK(*this, pc, i, TM_MUL); + break; + case LOP_DIVK: + translateInstBinaryK(*this, pc, i, TM_DIV); + break; + case LOP_MODK: + translateInstBinaryK(*this, pc, i, TM_MOD); + break; + case LOP_POWK: + translateInstBinaryK(*this, pc, i, TM_POW); + break; + case LOP_NOT: + translateInstNot(*this, pc); + break; + case LOP_MINUS: + translateInstMinus(*this, pc, i); + break; + case LOP_LENGTH: + translateInstLength(*this, pc, i); + break; + case LOP_NEWTABLE: + translateInstNewTable(*this, pc, i); + break; + case LOP_DUPTABLE: + translateInstDupTable(*this, pc, i); + break; + case LOP_SETLIST: + inst(IrCmd::LOP_SETLIST, constUint(i)); + break; + case LOP_GETUPVAL: + translateInstGetUpval(*this, pc, i); + break; + case LOP_SETUPVAL: + translateInstSetUpval(*this, pc, i); + break; + case LOP_CLOSEUPVALS: + translateInstCloseUpvals(*this, pc); + break; + case LOP_FASTCALL: + { + IrOp fallback = block(IrBlockKind::Fallback); + IrOp next = blockAtInst(i + LUAU_INSN_C(*pc) + 2); + + inst(IrCmd::LOP_FASTCALL, constUint(i), fallback); + inst(IrCmd::JUMP, next); + + beginBlock(fallback); + + activeFastcallFallback = true; + fastcallFallbackReturn = next; + break; + } + case LOP_FASTCALL1: + { + IrOp fallback = block(IrBlockKind::Fallback); + IrOp next = blockAtInst(i + LUAU_INSN_C(*pc) + 2); + + inst(IrCmd::LOP_FASTCALL1, constUint(i), fallback); + inst(IrCmd::JUMP, next); + + beginBlock(fallback); + + activeFastcallFallback = true; + fastcallFallbackReturn = next; + break; + } + case LOP_FASTCALL2: + { + IrOp fallback = block(IrBlockKind::Fallback); + IrOp next = blockAtInst(i + LUAU_INSN_C(*pc) + 2); + + inst(IrCmd::LOP_FASTCALL2, constUint(i), fallback); + inst(IrCmd::JUMP, next); + + beginBlock(fallback); + + activeFastcallFallback = true; + fastcallFallbackReturn = next; + break; + } + case LOP_FASTCALL2K: + { + IrOp fallback = block(IrBlockKind::Fallback); + IrOp next = blockAtInst(i + LUAU_INSN_C(*pc) + 2); + + inst(IrCmd::LOP_FASTCALL2K, constUint(i), fallback); + inst(IrCmd::JUMP, next); + + beginBlock(fallback); + + activeFastcallFallback = true; + fastcallFallbackReturn = next; + break; + } + case LOP_FORNPREP: + { + IrOp loopExit = blockAtInst(i + 1 + LUAU_INSN_D(*pc)); + + inst(IrCmd::LOP_FORNPREP, constUint(i), loopExit); + break; + } + case LOP_FORNLOOP: + { + IrOp loopRepeat = blockAtInst(i + 1 + LUAU_INSN_D(*pc)); + IrOp loopExit = blockAtInst(i + getOpLength(LOP_FORNLOOP)); + + inst(IrCmd::LOP_FORNLOOP, constUint(i), loopRepeat, loopExit); + + beginBlock(loopExit); + break; + } + case LOP_FORGLOOP: + { + IrOp loopRepeat = blockAtInst(i + 1 + LUAU_INSN_D(*pc)); + IrOp loopExit = blockAtInst(i + getOpLength(LOP_FORGLOOP)); + IrOp fallback = block(IrBlockKind::Fallback); + + inst(IrCmd::LOP_FORGLOOP, constUint(i), loopRepeat, loopExit, fallback); + + beginBlock(fallback); + inst(IrCmd::LOP_FORGLOOP_FALLBACK, constUint(i), loopRepeat, loopExit); + + beginBlock(loopExit); + break; + } + case LOP_FORGPREP_NEXT: + { + IrOp target = blockAtInst(i + 1 + LUAU_INSN_D(*pc)); + IrOp fallback = block(IrBlockKind::Fallback); + + inst(IrCmd::LOP_FORGPREP_NEXT, constUint(i), target, fallback); + + beginBlock(fallback); + inst(IrCmd::LOP_FORGPREP_XNEXT_FALLBACK, constUint(i), target); + break; + } + case LOP_FORGPREP_INEXT: + { + IrOp target = blockAtInst(i + 1 + LUAU_INSN_D(*pc)); + IrOp fallback = block(IrBlockKind::Fallback); + + inst(IrCmd::LOP_FORGPREP_INEXT, constUint(i), target, fallback); + + beginBlock(fallback); + inst(IrCmd::LOP_FORGPREP_XNEXT_FALLBACK, constUint(i), target); + break; + } + case LOP_AND: + inst(IrCmd::LOP_AND, constUint(i)); + break; + case LOP_ANDK: + inst(IrCmd::LOP_ANDK, constUint(i)); + break; + case LOP_OR: + inst(IrCmd::LOP_OR, constUint(i)); + break; + case LOP_ORK: + inst(IrCmd::LOP_ORK, constUint(i)); + break; + case LOP_COVERAGE: + inst(IrCmd::LOP_COVERAGE, constUint(i)); + break; + case LOP_GETIMPORT: + translateInstGetImport(*this, pc, i); + break; + case LOP_CONCAT: + translateInstConcat(*this, pc, i); + break; + case LOP_CAPTURE: + translateInstCapture(*this, pc, i); + break; + case LOP_NAMECALL: + { + IrOp next = blockAtInst(i + getOpLength(LOP_NAMECALL)); + IrOp fallback = block(IrBlockKind::Fallback); + + inst(IrCmd::LOP_NAMECALL, constUint(i), next, fallback); + + beginBlock(fallback); + inst(IrCmd::FALLBACK_NAMECALL, constUint(i)); + inst(IrCmd::JUMP, next); + + beginBlock(next); + break; + } + case LOP_PREPVARARGS: + inst(IrCmd::FALLBACK_PREPVARARGS, constUint(i)); + break; + case LOP_GETVARARGS: + inst(IrCmd::FALLBACK_GETVARARGS, constUint(i)); + break; + case LOP_NEWCLOSURE: + inst(IrCmd::FALLBACK_NEWCLOSURE, constUint(i)); + break; + case LOP_DUPCLOSURE: + inst(IrCmd::FALLBACK_DUPCLOSURE, constUint(i)); + break; + case LOP_FORGPREP: + inst(IrCmd::FALLBACK_FORGPREP, constUint(i)); + break; + default: + LUAU_ASSERT(!"unknown instruction"); + break; + } +} + +bool IrBuilder::isInternalBlock(IrOp block) +{ + IrBlock& target = function.blocks[block.index]; + + return target.kind == IrBlockKind::Internal; +} + +void IrBuilder::beginBlock(IrOp block) +{ + function.blocks[block.index].start = uint32_t(function.instructions.size()); +} + +IrOp IrBuilder::constBool(bool value) +{ + IrConst constant; + constant.kind = IrConstKind::Bool; + constant.valueBool = value; + return constAny(constant); +} + +IrOp IrBuilder::constInt(int value) +{ + IrConst constant; + constant.kind = IrConstKind::Int; + constant.valueInt = value; + return constAny(constant); +} + +IrOp IrBuilder::constUint(unsigned value) +{ + IrConst constant; + constant.kind = IrConstKind::Uint; + constant.valueUint = value; + return constAny(constant); +} + +IrOp IrBuilder::constDouble(double value) +{ + IrConst constant; + constant.kind = IrConstKind::Double; + constant.valueDouble = value; + return constAny(constant); +} + +IrOp IrBuilder::constTag(uint8_t value) +{ + IrConst constant; + constant.kind = IrConstKind::Tag; + constant.valueTag = value; + return constAny(constant); +} + +IrOp IrBuilder::constAny(IrConst constant) +{ + uint32_t index = uint32_t(function.constants.size()); + function.constants.push_back(constant); + return {IrOpKind::Constant, index}; +} + +IrOp IrBuilder::cond(IrCondition cond) +{ + return {IrOpKind::Condition, uint32_t(cond)}; +} + +IrOp IrBuilder::inst(IrCmd cmd) +{ + return inst(cmd, {}, {}, {}, {}, {}); +} + +IrOp IrBuilder::inst(IrCmd cmd, IrOp a) +{ + return inst(cmd, a, {}, {}, {}, {}); +} + +IrOp IrBuilder::inst(IrCmd cmd, IrOp a, IrOp b) +{ + return inst(cmd, a, b, {}, {}, {}); +} + +IrOp IrBuilder::inst(IrCmd cmd, IrOp a, IrOp b, IrOp c) +{ + return inst(cmd, a, b, c, {}, {}); +} + +IrOp IrBuilder::inst(IrCmd cmd, IrOp a, IrOp b, IrOp c, IrOp d) +{ + return inst(cmd, a, b, c, d, {}); +} + +IrOp IrBuilder::inst(IrCmd cmd, IrOp a, IrOp b, IrOp c, IrOp d, IrOp e) +{ + uint32_t index = uint32_t(function.instructions.size()); + function.instructions.push_back({cmd, a, b, c, d, e}); + return {IrOpKind::Inst, index}; +} + +IrOp IrBuilder::block(IrBlockKind kind) +{ + if (kind == IrBlockKind::Internal && activeFastcallFallback) + kind = IrBlockKind::Fallback; + + uint32_t index = uint32_t(function.blocks.size()); + function.blocks.push_back(IrBlock{kind, ~0u}); + return IrOp{IrOpKind::Block, index}; +} + +IrOp IrBuilder::blockAtInst(uint32_t index) +{ + uint32_t blockIndex = instIndexToBlock[index]; + + if (blockIndex != kNoAssociatedBlockIndex) + return IrOp{IrOpKind::Block, blockIndex}; + + return block(IrBlockKind::Internal); +} + +IrOp IrBuilder::vmReg(uint8_t index) +{ + return {IrOpKind::VmReg, index}; +} + +IrOp IrBuilder::vmConst(uint32_t index) +{ + return {IrOpKind::VmConst, index}; +} + +IrOp IrBuilder::vmUpvalue(uint8_t index) +{ + return {IrOpKind::VmUpvalue, index}; +} + +} // namespace CodeGen +} // namespace Luau diff --git a/CodeGen/src/IrBuilder.h b/CodeGen/src/IrBuilder.h new file mode 100644 index 000000000..c8f9b4ec1 --- /dev/null +++ b/CodeGen/src/IrBuilder.h @@ -0,0 +1,63 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Common.h" +#include "Luau/Bytecode.h" + +#include "IrData.h" + +#include + +struct Proto; +typedef uint32_t Instruction; + +namespace Luau +{ +namespace CodeGen +{ + +struct AssemblyOptions; + +struct IrBuilder +{ + void buildFunctionIr(Proto* proto); + + void rebuildBytecodeBasicBlocks(Proto* proto); + void translateInst(LuauOpcode op, const Instruction* pc, int i); + + bool isInternalBlock(IrOp block); + void beginBlock(IrOp block); + + IrOp constBool(bool value); + IrOp constInt(int value); + IrOp constUint(unsigned value); + IrOp constDouble(double value); + IrOp constTag(uint8_t value); + IrOp constAny(IrConst constant); + + IrOp cond(IrCondition cond); + + IrOp inst(IrCmd cmd); + IrOp inst(IrCmd cmd, IrOp a); + IrOp inst(IrCmd cmd, IrOp a, IrOp b); + IrOp inst(IrCmd cmd, IrOp a, IrOp b, IrOp c); + IrOp inst(IrCmd cmd, IrOp a, IrOp b, IrOp c, IrOp d); + IrOp inst(IrCmd cmd, IrOp a, IrOp b, IrOp c, IrOp d, IrOp e); + + IrOp block(IrBlockKind kind); // Requested kind can be ignored if we are in an outlined sequence + IrOp blockAtInst(uint32_t index); + + IrOp vmReg(uint8_t index); + IrOp vmConst(uint32_t index); + IrOp vmUpvalue(uint8_t index); + + bool activeFastcallFallback = false; + IrOp fastcallFallbackReturn; + + IrFunction function; + + std::vector instIndexToBlock; // Block index at the bytecode instruction +}; + +} // namespace CodeGen +} // namespace Luau diff --git a/CodeGen/src/IrData.h b/CodeGen/src/IrData.h index c4ed47ccb..1c70c8017 100644 --- a/CodeGen/src/IrData.h +++ b/CodeGen/src/IrData.h @@ -9,6 +9,8 @@ #include +struct Proto; + namespace Luau { namespace CodeGen @@ -99,6 +101,7 @@ enum class IrCmd : uint8_t // Operations that don't have an IR representation yet LOP_SETLIST, + LOP_NAMECALL, LOP_CALL, LOP_RETURN, LOP_FASTCALL, @@ -116,21 +119,21 @@ enum class IrCmd : uint8_t LOP_ANDK, LOP_OR, LOP_ORK, + LOP_COVERAGE, // Operations that have a translation, but use a full instruction fallback FALLBACK_GETGLOBAL, FALLBACK_SETGLOBAL, FALLBACK_GETTABLEKS, FALLBACK_SETTABLEKS, + FALLBACK_NAMECALL, // Operations that don't have assembly lowering at all - FALLBACK_NAMECALL, FALLBACK_PREPVARARGS, FALLBACK_GETVARARGS, FALLBACK_NEWCLOSURE, FALLBACK_DUPCLOSURE, FALLBACK_FORGPREP, - FALLBACK_COVERAGE, }; enum class IrConstKind : uint8_t @@ -274,6 +277,8 @@ struct IrFunction std::vector constants; std::vector bcMapping; + + Proto* proto = nullptr; }; } // namespace CodeGen diff --git a/CodeGen/src/IrDump.cpp b/CodeGen/src/IrDump.cpp index 5d54026a1..4dc5c6c54 100644 --- a/CodeGen/src/IrDump.cpp +++ b/CodeGen/src/IrDump.cpp @@ -186,6 +186,8 @@ const char* getCmdName(IrCmd cmd) return "CAPTURE"; case IrCmd::LOP_SETLIST: return "LOP_SETLIST"; + case IrCmd::LOP_NAMECALL: + return "LOP_NAMECALL"; case IrCmd::LOP_CALL: return "LOP_CALL"; case IrCmd::LOP_RETURN: @@ -220,6 +222,8 @@ const char* getCmdName(IrCmd cmd) return "LOP_OR"; case IrCmd::LOP_ORK: return "LOP_ORK"; + case IrCmd::LOP_COVERAGE: + return "LOP_COVERAGE"; case IrCmd::FALLBACK_GETGLOBAL: return "FALLBACK_GETGLOBAL"; case IrCmd::FALLBACK_SETGLOBAL: @@ -240,8 +244,6 @@ const char* getCmdName(IrCmd cmd) return "FALLBACK_DUPCLOSURE"; case IrCmd::FALLBACK_FORGPREP: return "FALLBACK_FORGPREP"; - case IrCmd::FALLBACK_COVERAGE: - return "FALLBACK_COVERAGE"; } LUAU_UNREACHABLE(); @@ -375,5 +377,48 @@ void toStringDetailed(IrToStringContext& ctx, IrInst inst, uint32_t index) append(ctx.result, "; useCount: %d, lastUse: %%%u\n", inst.useCount, inst.lastUse); } +std::string dump(IrFunction& function) +{ + std::string result; + IrToStringContext ctx{result, function.blocks, function.constants}; + + for (size_t i = 0; i < function.blocks.size(); i++) + { + IrBlock& block = function.blocks[i]; + + append(ctx.result, "%s_%u:\n", getBlockKindName(block.kind), unsigned(i)); + + if (block.start == ~0u) + { + append(ctx.result, " *empty*\n\n"); + continue; + } + + for (uint32_t index = block.start; true; index++) + { + LUAU_ASSERT(index < function.instructions.size()); + + IrInst& inst = function.instructions[index]; + + // Nop is used to replace dead instructions in-place, so it's not that useful to see them + if (inst.cmd == IrCmd::NOP) + continue; + + append(ctx.result, " "); + toStringDetailed(ctx, inst, index); + + if (isBlockTerminator(inst.cmd)) + { + append(ctx.result, "\n"); + break; + } + } + } + + printf("%s\n", result.c_str()); + + return result; +} + } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/IrDump.h b/CodeGen/src/IrDump.h index 8fb4d6e5f..c803e8db1 100644 --- a/CodeGen/src/IrDump.h +++ b/CodeGen/src/IrDump.h @@ -28,5 +28,7 @@ void toString(std::string& result, IrConst constant); void toStringDetailed(IrToStringContext& ctx, IrInst inst, uint32_t index); +std::string dump(IrFunction& function); + } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/IrTranslation.cpp b/CodeGen/src/IrTranslation.cpp new file mode 100644 index 000000000..822578150 --- /dev/null +++ b/CodeGen/src/IrTranslation.cpp @@ -0,0 +1,780 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "IrTranslation.h" + +#include "Luau/Bytecode.h" + +#include "IrBuilder.h" + +#include "lobject.h" +#include "ltm.h" + +namespace Luau +{ +namespace CodeGen +{ + +// Helper to consistently define a switch to instruction fallback code +struct FallbackStreamScope +{ + FallbackStreamScope(IrBuilder& build, IrOp fallback, IrOp next) + : build(build) + , next(next) + { + LUAU_ASSERT(fallback.kind == IrOpKind::Block); + LUAU_ASSERT(next.kind == IrOpKind::Block); + + build.inst(IrCmd::JUMP, next); + build.beginBlock(fallback); + } + + ~FallbackStreamScope() + { + build.beginBlock(next); + } + + IrBuilder& build; + IrOp next; +}; + +void translateInstLoadNil(IrBuilder& build, const Instruction* pc) +{ + int ra = LUAU_INSN_A(*pc); + + build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNIL)); +} + +void translateInstLoadB(IrBuilder& build, const Instruction* pc, int pcpos) +{ + int ra = LUAU_INSN_A(*pc); + + build.inst(IrCmd::STORE_INT, build.vmReg(ra), build.constInt(LUAU_INSN_B(*pc))); + build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TBOOLEAN)); + + if (int target = LUAU_INSN_C(*pc)) + build.inst(IrCmd::JUMP, build.blockAtInst(pcpos + 1 + target)); +} + +void translateInstLoadN(IrBuilder& build, const Instruction* pc) +{ + int ra = LUAU_INSN_A(*pc); + + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), build.constDouble(double(LUAU_INSN_D(*pc)))); + build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); +} + +void translateInstLoadK(IrBuilder& build, const Instruction* pc) +{ + int ra = LUAU_INSN_A(*pc); + + // TODO: per-component loads and stores might be preferable + IrOp load = build.inst(IrCmd::LOAD_TVALUE, build.vmConst(LUAU_INSN_D(*pc))); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(ra), load); +} + +void translateInstLoadKX(IrBuilder& build, const Instruction* pc) +{ + int ra = LUAU_INSN_A(*pc); + uint32_t aux = pc[1]; + + // TODO: per-component loads and stores might be preferable + IrOp load = build.inst(IrCmd::LOAD_TVALUE, build.vmConst(aux)); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(ra), load); +} + +void translateInstMove(IrBuilder& build, const Instruction* pc) +{ + int ra = LUAU_INSN_A(*pc); + int rb = LUAU_INSN_B(*pc); + + // TODO: per-component loads and stores might be preferable + IrOp load = build.inst(IrCmd::LOAD_TVALUE, build.vmReg(rb)); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(ra), load); +} + +void translateInstJump(IrBuilder& build, const Instruction* pc, int pcpos) +{ + build.inst(IrCmd::JUMP, build.blockAtInst(pcpos + 1 + LUAU_INSN_D(*pc))); +} + +void translateInstJumpBack(IrBuilder& build, const Instruction* pc, int pcpos) +{ + build.inst(IrCmd::INTERRUPT, build.constUint(pcpos)); + build.inst(IrCmd::JUMP, build.blockAtInst(pcpos + 1 + LUAU_INSN_D(*pc))); +} + +void translateInstJumpIf(IrBuilder& build, const Instruction* pc, int pcpos, bool not_) +{ + int ra = LUAU_INSN_A(*pc); + + IrOp target = build.blockAtInst(pcpos + 1 + LUAU_INSN_D(*pc)); + IrOp next = build.blockAtInst(pcpos + 1); + + // TODO: falsy/truthy conditions should be deconstructed into more primitive operations + if (not_) + build.inst(IrCmd::JUMP_IF_FALSY, build.vmReg(ra), target, next); + else + build.inst(IrCmd::JUMP_IF_TRUTHY, build.vmReg(ra), target, next); + + // Fallthrough in original bytecode is implicit, so we start next internal block here + if (build.isInternalBlock(next)) + build.beginBlock(next); +} + +void translateInstJumpIfEq(IrBuilder& build, const Instruction* pc, int pcpos, bool not_) +{ + int ra = LUAU_INSN_A(*pc); + int rb = pc[1]; + + IrOp target = build.blockAtInst(pcpos + 1 + LUAU_INSN_D(*pc)); + IrOp next = build.blockAtInst(pcpos + 2); + IrOp numberCheck = build.block(IrBlockKind::Internal); + IrOp fallback = build.block(IrBlockKind::Fallback); + + IrOp ta = build.inst(IrCmd::LOAD_TAG, build.vmReg(ra)); + IrOp tb = build.inst(IrCmd::LOAD_TAG, build.vmReg(rb)); + build.inst(IrCmd::JUMP_EQ_TAG, ta, tb, numberCheck, not_ ? target : next); + + build.beginBlock(numberCheck); + + // fast-path: number + build.inst(IrCmd::CHECK_TAG, ta, build.constTag(LUA_TNUMBER), fallback); + + IrOp va = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra)); + IrOp vb = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(rb)); + + build.inst(IrCmd::JUMP_CMP_NUM, va, vb, build.cond(IrCondition::NotEqual), not_ ? target : next, not_ ? next : target); + + FallbackStreamScope scope(build, fallback, next); + + build.inst(IrCmd::SET_SAVEDPC, build.constUint(pcpos + 1)); + build.inst(IrCmd::JUMP_CMP_ANY, build.vmReg(ra), build.vmReg(rb), build.cond(not_ ? IrCondition::NotEqual : IrCondition::Equal), target, next); +} + +void translateInstJumpIfCond(IrBuilder& build, const Instruction* pc, int pcpos, IrCondition cond) +{ + int ra = LUAU_INSN_A(*pc); + int rb = pc[1]; + + IrOp target = build.blockAtInst(pcpos + 1 + LUAU_INSN_D(*pc)); + IrOp next = build.blockAtInst(pcpos + 2); + IrOp fallback = build.block(IrBlockKind::Fallback); + + // fast-path: number + IrOp ta = build.inst(IrCmd::LOAD_TAG, build.vmReg(ra)); + build.inst(IrCmd::CHECK_TAG, ta, build.constTag(LUA_TNUMBER), fallback); + + IrOp tb = build.inst(IrCmd::LOAD_TAG, build.vmReg(rb)); + build.inst(IrCmd::CHECK_TAG, tb, build.constTag(LUA_TNUMBER), fallback); + + IrOp va = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra)); + IrOp vb = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(rb)); + + build.inst(IrCmd::JUMP_CMP_NUM, va, vb, build.cond(cond), target, next); + + FallbackStreamScope scope(build, fallback, next); + + build.inst(IrCmd::SET_SAVEDPC, build.constUint(pcpos + 1)); + build.inst(IrCmd::JUMP_CMP_ANY, build.vmReg(ra), build.vmReg(rb), build.cond(cond), target, next); +} + +void translateInstJumpX(IrBuilder& build, const Instruction* pc, int pcpos) +{ + build.inst(IrCmd::INTERRUPT, build.constUint(pcpos)); + build.inst(IrCmd::JUMP, build.blockAtInst(pcpos + 1 + LUAU_INSN_E(*pc))); +} + +void translateInstJumpxEqNil(IrBuilder& build, const Instruction* pc, int pcpos) +{ + int ra = LUAU_INSN_A(*pc); + bool not_ = (pc[1] & 0x80000000) != 0; + + IrOp target = build.blockAtInst(pcpos + 1 + LUAU_INSN_D(*pc)); + IrOp next = build.blockAtInst(pcpos + 2); + + IrOp ta = build.inst(IrCmd::LOAD_TAG, build.vmReg(ra)); + build.inst(IrCmd::JUMP_EQ_TAG, ta, build.constTag(LUA_TNIL), not_ ? next : target, not_ ? target : next); + + // Fallthrough in original bytecode is implicit, so we start next internal block here + if (build.isInternalBlock(next)) + build.beginBlock(next); +} + +void translateInstJumpxEqB(IrBuilder& build, const Instruction* pc, int pcpos) +{ + int ra = LUAU_INSN_A(*pc); + uint32_t aux = pc[1]; + bool not_ = (aux & 0x80000000) != 0; + + IrOp target = build.blockAtInst(pcpos + 1 + LUAU_INSN_D(*pc)); + IrOp next = build.blockAtInst(pcpos + 2); + IrOp checkValue = build.block(IrBlockKind::Internal); + + IrOp ta = build.inst(IrCmd::LOAD_TAG, build.vmReg(ra)); + + build.inst(IrCmd::JUMP_EQ_TAG, ta, build.constTag(LUA_TBOOLEAN), checkValue, not_ ? target : next); + + build.beginBlock(checkValue); + IrOp va = build.inst(IrCmd::LOAD_INT, build.vmReg(ra)); + + build.inst(IrCmd::JUMP_EQ_BOOLEAN, va, build.constBool(aux & 0x1), not_ ? next : target, not_ ? target : next); + + // Fallthrough in original bytecode is implicit, so we start next internal block here + if (build.isInternalBlock(next)) + build.beginBlock(next); +} + +void translateInstJumpxEqN(IrBuilder& build, const Instruction* pc, int pcpos) +{ + int ra = LUAU_INSN_A(*pc); + uint32_t aux = pc[1]; + bool not_ = (aux & 0x80000000) != 0; + + IrOp target = build.blockAtInst(pcpos + 1 + LUAU_INSN_D(*pc)); + IrOp next = build.blockAtInst(pcpos + 2); + IrOp checkValue = build.block(IrBlockKind::Internal); + + IrOp ta = build.inst(IrCmd::LOAD_TAG, build.vmReg(ra)); + + build.inst(IrCmd::JUMP_EQ_TAG, ta, build.constTag(LUA_TNUMBER), checkValue, not_ ? target : next); + + build.beginBlock(checkValue); + IrOp va = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra)); + IrOp vb = build.inst(IrCmd::LOAD_DOUBLE, build.vmConst(aux & 0xffffff)); + + build.inst(IrCmd::JUMP_CMP_NUM, va, vb, build.cond(IrCondition::NotEqual), not_ ? target : next, not_ ? next : target); + + // Fallthrough in original bytecode is implicit, so we start next internal block here + if (build.isInternalBlock(next)) + build.beginBlock(next); +} + +void translateInstJumpxEqS(IrBuilder& build, const Instruction* pc, int pcpos) +{ + int ra = LUAU_INSN_A(*pc); + uint32_t aux = pc[1]; + bool not_ = (aux & 0x80000000) != 0; + + IrOp target = build.blockAtInst(pcpos + 1 + LUAU_INSN_D(*pc)); + IrOp next = build.blockAtInst(pcpos + 2); + IrOp checkValue = build.block(IrBlockKind::Internal); + + IrOp ta = build.inst(IrCmd::LOAD_TAG, build.vmReg(ra)); + build.inst(IrCmd::JUMP_EQ_TAG, ta, build.constTag(LUA_TSTRING), checkValue, not_ ? target : next); + + build.beginBlock(checkValue); + IrOp va = build.inst(IrCmd::LOAD_POINTER, build.vmReg(ra)); + IrOp vb = build.inst(IrCmd::LOAD_POINTER, build.vmConst(aux & 0xffffff)); + + build.inst(IrCmd::JUMP_EQ_POINTER, va, vb, not_ ? next : target, not_ ? target : next); + + // Fallthrough in original bytecode is implicit, so we start next internal block here + if (build.isInternalBlock(next)) + build.beginBlock(next); +} + +static void translateInstBinaryNumeric(IrBuilder& build, int ra, int rb, int rc, IrOp opc, int pcpos, TMS tm) +{ + IrOp fallback = build.block(IrBlockKind::Fallback); + + // fast-path: number + IrOp tb = build.inst(IrCmd::LOAD_TAG, build.vmReg(rb)); + build.inst(IrCmd::CHECK_TAG, tb, build.constTag(LUA_TNUMBER), fallback); + + if (rc != -1 && rc != rb) // TODO: optimization should handle second check, but we'll test it later + { + IrOp tc = build.inst(IrCmd::LOAD_TAG, build.vmReg(rc)); + build.inst(IrCmd::CHECK_TAG, tc, build.constTag(LUA_TNUMBER), fallback); + } + + IrOp vb = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(rb)); + IrOp vc = build.inst(IrCmd::LOAD_DOUBLE, opc); + + IrOp va; + + switch (tm) + { + case TM_ADD: + va = build.inst(IrCmd::ADD_NUM, vb, vc); + break; + case TM_SUB: + va = build.inst(IrCmd::SUB_NUM, vb, vc); + break; + case TM_MUL: + va = build.inst(IrCmd::MUL_NUM, vb, vc); + break; + case TM_DIV: + va = build.inst(IrCmd::DIV_NUM, vb, vc); + break; + case TM_MOD: + va = build.inst(IrCmd::MOD_NUM, vb, vc); + break; + case TM_POW: + va = build.inst(IrCmd::POW_NUM, vb, vc); + break; + default: + LUAU_ASSERT(!"unsupported binary op"); + } + + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), va); + + if (ra != rb && ra != rc) // TODO: optimization should handle second check, but we'll test this later + build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); + + IrOp next = build.blockAtInst(pcpos + 1); + FallbackStreamScope scope(build, fallback, next); + + build.inst(IrCmd::SET_SAVEDPC, build.constUint(pcpos + 1)); + build.inst(IrCmd::DO_ARITH, build.vmReg(ra), build.vmReg(rb), opc, build.constInt(tm)); + build.inst(IrCmd::JUMP, next); +} + +void translateInstBinary(IrBuilder& build, const Instruction* pc, int pcpos, TMS tm) +{ + translateInstBinaryNumeric(build, LUAU_INSN_A(*pc), LUAU_INSN_B(*pc), LUAU_INSN_C(*pc), build.vmReg(LUAU_INSN_C(*pc)), pcpos, tm); +} + +void translateInstBinaryK(IrBuilder& build, const Instruction* pc, int pcpos, TMS tm) +{ + translateInstBinaryNumeric(build, LUAU_INSN_A(*pc), LUAU_INSN_B(*pc), -1, build.vmConst(LUAU_INSN_C(*pc)), pcpos, tm); +} + +void translateInstNot(IrBuilder& build, const Instruction* pc) +{ + int ra = LUAU_INSN_A(*pc); + int rb = LUAU_INSN_B(*pc); + + IrOp tb = build.inst(IrCmd::LOAD_TAG, build.vmReg(rb)); + IrOp vb = build.inst(IrCmd::LOAD_INT, build.vmReg(rb)); + + IrOp va = build.inst(IrCmd::NOT_ANY, tb, vb); + + build.inst(IrCmd::STORE_INT, build.vmReg(ra), va); + build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TBOOLEAN)); +} + +void translateInstMinus(IrBuilder& build, const Instruction* pc, int pcpos) +{ + int ra = LUAU_INSN_A(*pc); + int rb = LUAU_INSN_B(*pc); + + IrOp fallback = build.block(IrBlockKind::Fallback); + + IrOp tb = build.inst(IrCmd::LOAD_TAG, build.vmReg(rb)); + build.inst(IrCmd::CHECK_TAG, tb, build.constTag(LUA_TNUMBER), fallback); + + // fast-path: number + IrOp vb = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(rb)); + IrOp va = build.inst(IrCmd::UNM_NUM, vb); + + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), va); + + if (ra != rb) + build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); + + IrOp next = build.blockAtInst(pcpos + 1); + FallbackStreamScope scope(build, fallback, next); + + build.inst(IrCmd::SET_SAVEDPC, build.constUint(pcpos + 1)); + build.inst(IrCmd::DO_ARITH, build.vmReg(LUAU_INSN_A(*pc)), build.vmReg(LUAU_INSN_B(*pc)), build.vmReg(LUAU_INSN_B(*pc)), build.constInt(TM_UNM)); + build.inst(IrCmd::JUMP, next); +} + +void translateInstLength(IrBuilder& build, const Instruction* pc, int pcpos) +{ + int ra = LUAU_INSN_A(*pc); + int rb = LUAU_INSN_B(*pc); + + IrOp fallback = build.block(IrBlockKind::Fallback); + + IrOp tb = build.inst(IrCmd::LOAD_TAG, build.vmReg(rb)); + build.inst(IrCmd::CHECK_TAG, tb, build.constTag(LUA_TTABLE), fallback); + + // fast-path: table without __len + IrOp vb = build.inst(IrCmd::LOAD_POINTER, build.vmReg(rb)); + build.inst(IrCmd::CHECK_NO_METATABLE, vb, fallback); + + IrOp va = build.inst(IrCmd::TABLE_LEN, vb); + + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), va); + build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); + + IrOp next = build.blockAtInst(pcpos + 1); + FallbackStreamScope scope(build, fallback, next); + + build.inst(IrCmd::SET_SAVEDPC, build.constUint(pcpos + 1)); + build.inst(IrCmd::DO_LEN, build.vmReg(LUAU_INSN_A(*pc)), build.vmReg(LUAU_INSN_B(*pc))); + build.inst(IrCmd::JUMP, next); +} + +void translateInstNewTable(IrBuilder& build, const Instruction* pc, int pcpos) +{ + int ra = LUAU_INSN_A(*pc); + int b = LUAU_INSN_B(*pc); + uint32_t aux = pc[1]; + + build.inst(IrCmd::SET_SAVEDPC, build.constUint(pcpos + 1)); + + IrOp va = build.inst(IrCmd::NEW_TABLE, build.constUint(aux), build.constUint(b == 0 ? 0 : 1 << (b - 1))); + build.inst(IrCmd::STORE_POINTER, build.vmReg(ra), va); + build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TTABLE)); + + build.inst(IrCmd::CHECK_GC); +} + +void translateInstDupTable(IrBuilder& build, const Instruction* pc, int pcpos) +{ + int ra = LUAU_INSN_A(*pc); + int k = LUAU_INSN_D(*pc); + + build.inst(IrCmd::SET_SAVEDPC, build.constUint(pcpos + 1)); + + IrOp table = build.inst(IrCmd::LOAD_POINTER, build.vmConst(k)); + IrOp va = build.inst(IrCmd::DUP_TABLE, table); + build.inst(IrCmd::STORE_POINTER, build.vmReg(ra), va); + build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TTABLE)); + + build.inst(IrCmd::CHECK_GC); +} + +void translateInstGetUpval(IrBuilder& build, const Instruction* pc, int pcpos) +{ + int ra = LUAU_INSN_A(*pc); + int up = LUAU_INSN_B(*pc); + + build.inst(IrCmd::GET_UPVALUE, build.vmReg(ra), build.vmUpvalue(up)); +} + +void translateInstSetUpval(IrBuilder& build, const Instruction* pc, int pcpos) +{ + int ra = LUAU_INSN_A(*pc); + int up = LUAU_INSN_B(*pc); + + build.inst(IrCmd::SET_UPVALUE, build.vmUpvalue(up), build.vmReg(ra)); +} + +void translateInstCloseUpvals(IrBuilder& build, const Instruction* pc) +{ + int ra = LUAU_INSN_A(*pc); + + build.inst(IrCmd::CLOSE_UPVALS, build.vmReg(ra)); +} + +void translateInstGetTableN(IrBuilder& build, const Instruction* pc, int pcpos) +{ + int ra = LUAU_INSN_A(*pc); + int rb = LUAU_INSN_B(*pc); + int c = LUAU_INSN_C(*pc); + + IrOp fallback = build.block(IrBlockKind::Fallback); + + IrOp tb = build.inst(IrCmd::LOAD_TAG, build.vmReg(rb)); + build.inst(IrCmd::CHECK_TAG, tb, build.constTag(LUA_TTABLE), fallback); + + IrOp vb = build.inst(IrCmd::LOAD_POINTER, build.vmReg(rb)); + + build.inst(IrCmd::CHECK_ARRAY_SIZE, vb, build.constUint(c), fallback); + build.inst(IrCmd::CHECK_NO_METATABLE, vb, fallback); + + IrOp arrEl = build.inst(IrCmd::GET_ARR_ADDR, vb, build.constUint(c)); + + // TODO: per-component loads and stores might be preferable + IrOp arrElTval = build.inst(IrCmd::LOAD_TVALUE, arrEl); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(ra), arrElTval); + + IrOp next = build.blockAtInst(pcpos + 1); + FallbackStreamScope scope(build, fallback, next); + + build.inst(IrCmd::SET_SAVEDPC, build.constUint(pcpos + 1)); + build.inst(IrCmd::GET_TABLE, build.vmReg(ra), build.vmReg(rb), build.constUint(c + 1)); + build.inst(IrCmd::JUMP, next); +} + +void translateInstSetTableN(IrBuilder& build, const Instruction* pc, int pcpos) +{ + int ra = LUAU_INSN_A(*pc); + int rb = LUAU_INSN_B(*pc); + int c = LUAU_INSN_C(*pc); + + IrOp fallback = build.block(IrBlockKind::Fallback); + + IrOp tb = build.inst(IrCmd::LOAD_TAG, build.vmReg(rb)); + build.inst(IrCmd::CHECK_TAG, tb, build.constTag(LUA_TTABLE), fallback); + + IrOp vb = build.inst(IrCmd::LOAD_POINTER, build.vmReg(rb)); + + build.inst(IrCmd::CHECK_ARRAY_SIZE, vb, build.constUint(c), fallback); + build.inst(IrCmd::CHECK_NO_METATABLE, vb, fallback); + build.inst(IrCmd::CHECK_READONLY, vb, fallback); + + IrOp arrEl = build.inst(IrCmd::GET_ARR_ADDR, vb, build.constUint(c)); + + // TODO: per-component loads and stores might be preferable + IrOp tva = build.inst(IrCmd::LOAD_TVALUE, build.vmReg(ra)); + build.inst(IrCmd::STORE_TVALUE, arrEl, tva); + + build.inst(IrCmd::BARRIER_TABLE_FORWARD, vb, build.vmReg(ra)); + + IrOp next = build.blockAtInst(pcpos + 1); + FallbackStreamScope scope(build, fallback, next); + + build.inst(IrCmd::SET_SAVEDPC, build.constUint(pcpos + 1)); + build.inst(IrCmd::SET_TABLE, build.vmReg(ra), build.vmReg(rb), build.constUint(c + 1)); + build.inst(IrCmd::JUMP, next); +} + +void translateInstGetTable(IrBuilder& build, const Instruction* pc, int pcpos) +{ + int ra = LUAU_INSN_A(*pc); + int rb = LUAU_INSN_B(*pc); + int rc = LUAU_INSN_C(*pc); + + IrOp fallback = build.block(IrBlockKind::Fallback); + + IrOp tb = build.inst(IrCmd::LOAD_TAG, build.vmReg(rb)); + build.inst(IrCmd::CHECK_TAG, tb, build.constTag(LUA_TTABLE), fallback); + IrOp tc = build.inst(IrCmd::LOAD_TAG, build.vmReg(rc)); + build.inst(IrCmd::CHECK_TAG, tc, build.constTag(LUA_TNUMBER), fallback); + + // fast-path: table with a number index + IrOp vb = build.inst(IrCmd::LOAD_POINTER, build.vmReg(rb)); + IrOp vc = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(rc)); + + IrOp index = build.inst(IrCmd::NUM_TO_INDEX, vc, fallback); + + index = build.inst(IrCmd::SUB_INT, index, build.constInt(1)); + + build.inst(IrCmd::CHECK_ARRAY_SIZE, vb, index, fallback); + build.inst(IrCmd::CHECK_NO_METATABLE, vb, fallback); + + IrOp arrEl = build.inst(IrCmd::GET_ARR_ADDR, vb, index); + + // TODO: per-component loads and stores might be preferable + IrOp arrElTval = build.inst(IrCmd::LOAD_TVALUE, arrEl); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(ra), arrElTval); + + IrOp next = build.blockAtInst(pcpos + 1); + FallbackStreamScope scope(build, fallback, next); + + build.inst(IrCmd::SET_SAVEDPC, build.constUint(pcpos + 1)); + build.inst(IrCmd::GET_TABLE, build.vmReg(ra), build.vmReg(rb), build.vmReg(rc)); + build.inst(IrCmd::JUMP, next); +} + +void translateInstSetTable(IrBuilder& build, const Instruction* pc, int pcpos) +{ + int ra = LUAU_INSN_A(*pc); + int rb = LUAU_INSN_B(*pc); + int rc = LUAU_INSN_C(*pc); + + IrOp fallback = build.block(IrBlockKind::Fallback); + + IrOp tb = build.inst(IrCmd::LOAD_TAG, build.vmReg(rb)); + build.inst(IrCmd::CHECK_TAG, tb, build.constTag(LUA_TTABLE), fallback); + IrOp tc = build.inst(IrCmd::LOAD_TAG, build.vmReg(rc)); + build.inst(IrCmd::CHECK_TAG, tc, build.constTag(LUA_TNUMBER), fallback); + + // fast-path: table with a number index + IrOp vb = build.inst(IrCmd::LOAD_POINTER, build.vmReg(rb)); + IrOp vc = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(rc)); + + IrOp index = build.inst(IrCmd::NUM_TO_INDEX, vc, fallback); + + index = build.inst(IrCmd::SUB_INT, index, build.constInt(1)); + + build.inst(IrCmd::CHECK_ARRAY_SIZE, vb, index, fallback); + build.inst(IrCmd::CHECK_NO_METATABLE, vb, fallback); + build.inst(IrCmd::CHECK_READONLY, vb, fallback); + + IrOp arrEl = build.inst(IrCmd::GET_ARR_ADDR, vb, index); + + // TODO: per-component loads and stores might be preferable + IrOp tva = build.inst(IrCmd::LOAD_TVALUE, build.vmReg(ra)); + build.inst(IrCmd::STORE_TVALUE, arrEl, tva); + + build.inst(IrCmd::BARRIER_TABLE_FORWARD, vb, build.vmReg(ra)); + + IrOp next = build.blockAtInst(pcpos + 1); + FallbackStreamScope scope(build, fallback, next); + + build.inst(IrCmd::SET_SAVEDPC, build.constUint(pcpos + 1)); + build.inst(IrCmd::SET_TABLE, build.vmReg(ra), build.vmReg(rb), build.vmReg(rc)); + build.inst(IrCmd::JUMP, next); +} + +void translateInstGetImport(IrBuilder& build, const Instruction* pc, int pcpos) +{ + int ra = LUAU_INSN_A(*pc); + int k = LUAU_INSN_D(*pc); + uint32_t aux = pc[1]; + + IrOp fastPath = build.block(IrBlockKind::Internal); + IrOp fallback = build.block(IrBlockKind::Fallback); + + build.inst(IrCmd::CHECK_SAFE_ENV, fallback); + + // note: if import failed, k[] is nil; we could check this during codegen, but we instead use runtime fallback + // this allows us to handle ahead-of-time codegen smoothly when an import fails to resolve at runtime + IrOp tk = build.inst(IrCmd::LOAD_TAG, build.vmConst(k)); + build.inst(IrCmd::JUMP_EQ_TAG, tk, build.constTag(LUA_TNIL), fallback, fastPath); + + build.beginBlock(fastPath); + + // TODO: per-component loads and stores might be preferable + IrOp tvk = build.inst(IrCmd::LOAD_TVALUE, build.vmConst(k)); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(ra), tvk); + + IrOp next = build.blockAtInst(pcpos + 2); + FallbackStreamScope scope(build, fallback, next); + + build.inst(IrCmd::SET_SAVEDPC, build.constUint(pcpos + 1)); + build.inst(IrCmd::GET_IMPORT, build.vmReg(ra), build.constUint(aux)); + build.inst(IrCmd::JUMP, next); +} + +void translateInstGetTableKS(IrBuilder& build, const Instruction* pc, int pcpos) +{ + int ra = LUAU_INSN_A(*pc); + int rb = LUAU_INSN_B(*pc); + uint32_t aux = pc[1]; + + IrOp fallback = build.block(IrBlockKind::Fallback); + + IrOp tb = build.inst(IrCmd::LOAD_TAG, build.vmReg(rb)); + build.inst(IrCmd::CHECK_TAG, tb, build.constTag(LUA_TTABLE), fallback); + + IrOp vb = build.inst(IrCmd::LOAD_POINTER, build.vmReg(rb)); + + IrOp addrSlotEl = build.inst(IrCmd::GET_SLOT_NODE_ADDR, vb, build.constUint(pcpos)); + + build.inst(IrCmd::CHECK_SLOT_MATCH, addrSlotEl, build.vmConst(aux), fallback); + + // TODO: per-component loads and stores might be preferable + IrOp tvn = build.inst(IrCmd::LOAD_NODE_VALUE_TV, addrSlotEl); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(ra), tvn); + + IrOp next = build.blockAtInst(pcpos + 2); + FallbackStreamScope scope(build, fallback, next); + + build.inst(IrCmd::FALLBACK_GETTABLEKS, build.constUint(pcpos)); + build.inst(IrCmd::JUMP, next); +} + +void translateInstSetTableKS(IrBuilder& build, const Instruction* pc, int pcpos) +{ + int ra = LUAU_INSN_A(*pc); + int rb = LUAU_INSN_B(*pc); + uint32_t aux = pc[1]; + + IrOp fallback = build.block(IrBlockKind::Fallback); + + IrOp tb = build.inst(IrCmd::LOAD_TAG, build.vmReg(rb)); + build.inst(IrCmd::CHECK_TAG, tb, build.constTag(LUA_TTABLE), fallback); + + IrOp vb = build.inst(IrCmd::LOAD_POINTER, build.vmReg(rb)); + + IrOp addrSlotEl = build.inst(IrCmd::GET_SLOT_NODE_ADDR, vb, build.constUint(pcpos)); + + build.inst(IrCmd::CHECK_SLOT_MATCH, addrSlotEl, build.vmConst(aux), fallback); + build.inst(IrCmd::CHECK_READONLY, vb, fallback); + + // TODO: per-component loads and stores might be preferable + IrOp tva = build.inst(IrCmd::LOAD_TVALUE, build.vmReg(ra)); + build.inst(IrCmd::STORE_NODE_VALUE_TV, addrSlotEl, tva); + + build.inst(IrCmd::BARRIER_TABLE_FORWARD, vb, build.vmReg(ra)); + + IrOp next = build.blockAtInst(pcpos + 2); + FallbackStreamScope scope(build, fallback, next); + + build.inst(IrCmd::FALLBACK_SETTABLEKS, build.constUint(pcpos)); + build.inst(IrCmd::JUMP, next); +} + +void translateInstGetGlobal(IrBuilder& build, const Instruction* pc, int pcpos) +{ + int ra = LUAU_INSN_A(*pc); + uint32_t aux = pc[1]; + + IrOp fallback = build.block(IrBlockKind::Fallback); + + IrOp env = build.inst(IrCmd::LOAD_ENV); + IrOp addrSlotEl = build.inst(IrCmd::GET_SLOT_NODE_ADDR, env, build.constUint(pcpos)); + + build.inst(IrCmd::CHECK_SLOT_MATCH, addrSlotEl, build.vmConst(aux), fallback); + + // TODO: per-component loads and stores might be preferable + IrOp tvn = build.inst(IrCmd::LOAD_NODE_VALUE_TV, addrSlotEl); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(ra), tvn); + + IrOp next = build.blockAtInst(pcpos + 2); + FallbackStreamScope scope(build, fallback, next); + + build.inst(IrCmd::FALLBACK_GETGLOBAL, build.constUint(pcpos)); + build.inst(IrCmd::JUMP, next); +} + +void translateInstSetGlobal(IrBuilder& build, const Instruction* pc, int pcpos) +{ + int ra = LUAU_INSN_A(*pc); + uint32_t aux = pc[1]; + + IrOp fallback = build.block(IrBlockKind::Fallback); + + IrOp env = build.inst(IrCmd::LOAD_ENV); + IrOp addrSlotEl = build.inst(IrCmd::GET_SLOT_NODE_ADDR, env, build.constUint(pcpos)); + + build.inst(IrCmd::CHECK_SLOT_MATCH, addrSlotEl, build.vmConst(aux), fallback); + build.inst(IrCmd::CHECK_READONLY, env, fallback); + + // TODO: per-component loads and stores might be preferable + IrOp tva = build.inst(IrCmd::LOAD_TVALUE, build.vmReg(ra)); + build.inst(IrCmd::STORE_NODE_VALUE_TV, addrSlotEl, tva); + + build.inst(IrCmd::BARRIER_TABLE_FORWARD, env, build.vmReg(ra)); + + IrOp next = build.blockAtInst(pcpos + 2); + FallbackStreamScope scope(build, fallback, next); + + build.inst(IrCmd::FALLBACK_SETGLOBAL, build.constUint(pcpos)); + build.inst(IrCmd::JUMP, next); +} + +void translateInstConcat(IrBuilder& build, const Instruction* pc, int pcpos) +{ + int ra = LUAU_INSN_A(*pc); + int rb = LUAU_INSN_B(*pc); + int rc = LUAU_INSN_C(*pc); + + build.inst(IrCmd::SET_SAVEDPC, build.constUint(pcpos + 1)); + build.inst(IrCmd::CONCAT, build.constUint(rc - rb + 1), build.constUint(rc)); + + // TODO: per-component loads and stores might be preferable + IrOp tvb = build.inst(IrCmd::LOAD_TVALUE, build.vmReg(rb)); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(ra), tvb); + + build.inst(IrCmd::CHECK_GC); +} + +void translateInstCapture(IrBuilder& build, const Instruction* pc, int pcpos) +{ + int type = LUAU_INSN_A(*pc); + int index = LUAU_INSN_B(*pc); + + switch (type) + { + case LCT_VAL: + build.inst(IrCmd::CAPTURE, build.vmReg(index), build.constBool(false)); + break; + case LCT_REF: + build.inst(IrCmd::CAPTURE, build.vmReg(index), build.constBool(true)); + break; + case LCT_UPVAL: + build.inst(IrCmd::CAPTURE, build.vmUpvalue(index), build.constBool(false)); + break; + default: + LUAU_ASSERT(!"Unknown upvalue capture type"); + } +} + +} // namespace CodeGen +} // namespace Luau diff --git a/CodeGen/src/IrTranslation.h b/CodeGen/src/IrTranslation.h new file mode 100644 index 000000000..53030a203 --- /dev/null +++ b/CodeGen/src/IrTranslation.h @@ -0,0 +1,58 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include + +#include "ltm.h" + +typedef uint32_t Instruction; + +namespace Luau +{ +namespace CodeGen +{ + +enum class IrCondition : uint8_t; +struct IrOp; +struct IrBuilder; + +void translateInstLoadNil(IrBuilder& build, const Instruction* pc); +void translateInstLoadB(IrBuilder& build, const Instruction* pc, int pcpos); +void translateInstLoadN(IrBuilder& build, const Instruction* pc); +void translateInstLoadK(IrBuilder& build, const Instruction* pc); +void translateInstLoadKX(IrBuilder& build, const Instruction* pc); +void translateInstMove(IrBuilder& build, const Instruction* pc); +void translateInstJump(IrBuilder& build, const Instruction* pc, int pcpos); +void translateInstJumpBack(IrBuilder& build, const Instruction* pc, int pcpos); +void translateInstJumpIf(IrBuilder& build, const Instruction* pc, int pcpos, bool not_); +void translateInstJumpIfEq(IrBuilder& build, const Instruction* pc, int pcpos, bool not_); +void translateInstJumpIfCond(IrBuilder& build, const Instruction* pc, int pcpos, IrCondition cond); +void translateInstJumpX(IrBuilder& build, const Instruction* pc, int pcpos); +void translateInstJumpxEqNil(IrBuilder& build, const Instruction* pc, int pcpos); +void translateInstJumpxEqB(IrBuilder& build, const Instruction* pc, int pcpos); +void translateInstJumpxEqN(IrBuilder& build, const Instruction* pc, int pcpos); +void translateInstJumpxEqS(IrBuilder& build, const Instruction* pc, int pcpos); +void translateInstBinary(IrBuilder& build, const Instruction* pc, int pcpos, TMS tm); +void translateInstBinaryK(IrBuilder& build, const Instruction* pc, int pcpos, TMS tm); +void translateInstNot(IrBuilder& build, const Instruction* pc); +void translateInstMinus(IrBuilder& build, const Instruction* pc, int pcpos); +void translateInstLength(IrBuilder& build, const Instruction* pc, int pcpos); +void translateInstNewTable(IrBuilder& build, const Instruction* pc, int pcpos); +void translateInstDupTable(IrBuilder& build, const Instruction* pc, int pcpos); +void translateInstGetUpval(IrBuilder& build, const Instruction* pc, int pcpos); +void translateInstSetUpval(IrBuilder& build, const Instruction* pc, int pcpos); +void translateInstCloseUpvals(IrBuilder& build, const Instruction* pc); +void translateInstGetTableN(IrBuilder& build, const Instruction* pc, int pcpos); +void translateInstSetTableN(IrBuilder& build, const Instruction* pc, int pcpos); +void translateInstGetTable(IrBuilder& build, const Instruction* pc, int pcpos); +void translateInstSetTable(IrBuilder& build, const Instruction* pc, int pcpos); +void translateInstGetImport(IrBuilder& build, const Instruction* pc, int pcpos); +void translateInstGetTableKS(IrBuilder& build, const Instruction* pc, int pcpos); +void translateInstSetTableKS(IrBuilder& build, const Instruction* pc, int pcpos); +void translateInstGetGlobal(IrBuilder& build, const Instruction* pc, int pcpos); +void translateInstSetGlobal(IrBuilder& build, const Instruction* pc, int pcpos); +void translateInstConcat(IrBuilder& build, const Instruction* pc, int pcpos); +void translateInstCapture(IrBuilder& build, const Instruction* pc, int pcpos); + +} // namespace CodeGen +} // namespace Luau diff --git a/CodeGen/src/IrUtils.h b/CodeGen/src/IrUtils.h index f0e4cee6c..558817896 100644 --- a/CodeGen/src/IrUtils.h +++ b/CodeGen/src/IrUtils.h @@ -98,6 +98,7 @@ inline bool isBlockTerminator(IrCmd cmd) case IrCmd::JUMP_CMP_NUM: case IrCmd::JUMP_CMP_STR: case IrCmd::JUMP_CMP_ANY: + case IrCmd::LOP_NAMECALL: case IrCmd::LOP_RETURN: case IrCmd::LOP_FORNPREP: case IrCmd::LOP_FORNLOOP: diff --git a/Common/include/Luau/Common.h b/Common/include/Luau/Common.h index e590987c3..31b416fb0 100644 --- a/Common/include/Luau/Common.h +++ b/Common/include/Luau/Common.h @@ -35,7 +35,10 @@ inline AssertHandler& assertHandler() return handler; } -inline int assertCallHandler(const char* expression, const char* file, int line, const char* function) +// We want 'inline' to correctly link this function declared in the header +// But we also want to prevent compiler from inlining this function when optimization and assertions are enabled together +// Reason for that is that compilation times can increase significantly in such a configuration +LUAU_NOINLINE inline int assertCallHandler(const char* expression, const char* file, int line, const char* function) { if (AssertHandler handler = assertHandler()) return handler(expression, file, line, function); diff --git a/Ast/include/Luau/DenseHash.h b/Common/include/Luau/DenseHash.h similarity index 100% rename from Ast/include/Luau/DenseHash.h rename to Common/include/Luau/DenseHash.h diff --git a/Sources.cmake b/Sources.cmake index 36e4f04d9..636b42f70 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -4,6 +4,7 @@ if(NOT ${CMAKE_VERSION} VERSION_LESS "3.19") target_sources(Luau.Common PRIVATE Common/include/Luau/Common.h Common/include/Luau/Bytecode.h + Common/include/Luau/DenseHash.h Common/include/Luau/ExperimentalFlags.h ) endif() @@ -12,7 +13,6 @@ endif() target_sources(Luau.Ast PRIVATE Ast/include/Luau/Ast.h Ast/include/Luau/Confusables.h - Ast/include/Luau/DenseHash.h Ast/include/Luau/Lexer.h Ast/include/Luau/Location.h Ast/include/Luau/ParseOptions.h @@ -82,7 +82,9 @@ target_sources(Luau.CodeGen PRIVATE CodeGen/src/EmitCommonX64.cpp CodeGen/src/EmitInstructionX64.cpp CodeGen/src/Fallbacks.cpp + CodeGen/src/IrBuilder.cpp CodeGen/src/IrDump.cpp + CodeGen/src/IrTranslation.cpp CodeGen/src/NativeState.cpp CodeGen/src/UnwindBuilderDwarf2.cpp CodeGen/src/UnwindBuilderWin.cpp @@ -96,8 +98,10 @@ target_sources(Luau.CodeGen PRIVATE CodeGen/src/EmitInstructionX64.h CodeGen/src/Fallbacks.h CodeGen/src/FallbacksProlog.h + CodeGen/src/IrBuilder.h CodeGen/src/IrDump.h CodeGen/src/IrData.h + CodeGen/src/IrTranslation.h CodeGen/src/IrUtils.h CodeGen/src/NativeState.h ) diff --git a/VM/src/lapi.cpp b/VM/src/lapi.cpp index d2091c6b5..1528aa39e 100644 --- a/VM/src/lapi.cpp +++ b/VM/src/lapi.cpp @@ -1285,10 +1285,12 @@ static const char* aux_upvalue(StkId fi, int n, TValue** val) else { Proto* p = f->l.p; - if (!(1 <= n && n <= p->sizeupvalues)) + if (!(1 <= n && n <= p->nups)) // not a valid upvalue return NULL; TValue* r = &f->l.uprefs[n - 1]; *val = ttisupval(r) ? upvalue(r)->v : r; + if (!(1 <= n && n <= p->sizeupvalues)) // don't have a name for this upvalue + return ""; return getstr(p->upvalues[n - 1]); } } diff --git a/fuzz/luau.proto b/fuzz/luau.proto index 190b8c5be..e51d687bd 100644 --- a/fuzz/luau.proto +++ b/fuzz/luau.proto @@ -20,6 +20,7 @@ message Expr { ExprUnary unary = 14; ExprBinary binary = 15; ExprIfElse ifelse = 16; + ExprInterpString interpstring = 17; } } @@ -161,6 +162,10 @@ message ExprIfElse { } } +message ExprInterpString { + repeated Expr parts = 1; +} + message LValue { oneof lvalue_oneof { ExprLocal local = 1; diff --git a/fuzz/protoprint.cpp b/fuzz/protoprint.cpp index d4d522765..5c7c5bf60 100644 --- a/fuzz/protoprint.cpp +++ b/fuzz/protoprint.cpp @@ -282,6 +282,8 @@ struct ProtoToLuau print(expr.binary()); else if (expr.has_ifelse()) print(expr.ifelse()); + else if (expr.has_interpstring()) + print(expr.interpstring()); else source += "_"; } @@ -538,6 +540,28 @@ struct ProtoToLuau } } + void print(const luau::ExprInterpString& expr) + { + source += "`"; + + for (int i = 0; i < expr.parts_size(); ++i) + { + if (expr.parts(i).has_string()) + { + // String literal is added surrounded with "", but that's ok + print(expr.parts(i)); + } + else + { + source += "{"; + print(expr.parts(i)); + source += "}"; + } + } + + source += "`"; + } + void print(const luau::LValue& expr) { if (expr.has_local()) diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index 3b59bc338..1b8bb3da3 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -16,10 +16,12 @@ LUAU_FASTFLAG(LuauTraceTypesInNonstrictMode2) LUAU_FASTFLAG(LuauSetMetatableDoesNotTimeTravel) LUAU_FASTFLAG(LuauFixAutocompleteInIf) +LUAU_FASTFLAG(LuauFixAutocompleteInWhile) +LUAU_FASTFLAG(LuauFixAutocompleteInFor) using namespace Luau; -static std::optional nullCallback(std::string tag, std::optional ptr) +static std::optional nullCallback(std::string tag, std::optional ptr, std::optional contents) { return std::nullopt; } @@ -37,9 +39,9 @@ struct ACFixtureImpl : BaseType return Luau::autocomplete(this->frontend, "MainModule", Position{row, column}, nullCallback); } - AutocompleteResult autocomplete(char marker) + AutocompleteResult autocomplete(char marker, StringCompletionCallback callback = nullCallback) { - return Luau::autocomplete(this->frontend, "MainModule", getPosition(marker), nullCallback); + return Luau::autocomplete(this->frontend, "MainModule", getPosition(marker), callback); } CheckResult check(const std::string& source) @@ -380,7 +382,7 @@ TEST_CASE_FIXTURE(ACFixture, "table_intersection") { check(R"( type t1 = { a1 : string, b2 : number } - type t2 = { b2 : string, c3 : string } + type t2 = { b2 : number, c3 : string } function func(abc : t1 & t2) abc. @1 end @@ -629,9 +631,19 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_for_middle_keywords") )"); auto ac5 = autocomplete('1'); - CHECK_EQ(ac5.entryMap.count("do"), 1); - CHECK_EQ(ac5.entryMap.count("end"), 0); - CHECK_EQ(ac5.context, AutocompleteContext::Keyword); + if (FFlag::LuauFixAutocompleteInFor) + { + CHECK_EQ(ac5.entryMap.count("math"), 1); + CHECK_EQ(ac5.entryMap.count("do"), 0); + CHECK_EQ(ac5.entryMap.count("end"), 0); + CHECK_EQ(ac5.context, AutocompleteContext::Expression); + } + else + { + CHECK_EQ(ac5.entryMap.count("do"), 1); + CHECK_EQ(ac5.entryMap.count("end"), 0); + CHECK_EQ(ac5.context, AutocompleteContext::Keyword); + } check(R"( for x = 1, 2, 5 f@1 @@ -649,6 +661,31 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_for_middle_keywords") auto ac7 = autocomplete('1'); CHECK_EQ(ac7.entryMap.count("end"), 1); CHECK_EQ(ac7.context, AutocompleteContext::Statement); + + if (FFlag::LuauFixAutocompleteInFor) + { + check(R"(local Foo = 1 + for x = @11, @22, @35 + )"); + + for (int i = 0; i < 3; ++i) + { + auto ac8 = autocomplete('1' + i); + CHECK_EQ(ac8.entryMap.count("Foo"), 1); + CHECK_EQ(ac8.entryMap.count("do"), 0); + } + + check(R"(local Foo = 1 + for x = @11, @22 + )"); + + for (int i = 0; i < 2; ++i) + { + auto ac9 = autocomplete('1' + i); + CHECK_EQ(ac9.entryMap.count("Foo"), 1); + CHECK_EQ(ac9.entryMap.count("do"), 0); + } + } } TEST_CASE_FIXTURE(ACFixture, "autocomplete_for_in_middle_keywords") @@ -740,8 +777,18 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_while_middle_keywords") )"); auto ac2 = autocomplete('1'); - CHECK_EQ(1, ac2.entryMap.size()); - CHECK_EQ(ac2.entryMap.count("do"), 1); + if (FFlag::LuauFixAutocompleteInWhile) + { + CHECK_EQ(3, ac2.entryMap.size()); + CHECK_EQ(ac2.entryMap.count("do"), 1); + CHECK_EQ(ac2.entryMap.count("and"), 1); + CHECK_EQ(ac2.entryMap.count("or"), 1); + } + else + { + CHECK_EQ(1, ac2.entryMap.size()); + CHECK_EQ(ac2.entryMap.count("do"), 1); + } CHECK_EQ(ac2.context, AutocompleteContext::Keyword); check(R"( @@ -757,9 +804,31 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_while_middle_keywords") )"); auto ac4 = autocomplete('1'); - CHECK_EQ(1, ac4.entryMap.size()); - CHECK_EQ(ac4.entryMap.count("do"), 1); + if (FFlag::LuauFixAutocompleteInWhile) + { + CHECK_EQ(3, ac4.entryMap.size()); + CHECK_EQ(ac4.entryMap.count("do"), 1); + CHECK_EQ(ac4.entryMap.count("and"), 1); + CHECK_EQ(ac4.entryMap.count("or"), 1); + } + else + { + CHECK_EQ(1, ac4.entryMap.size()); + CHECK_EQ(ac4.entryMap.count("do"), 1); + } CHECK_EQ(ac4.context, AutocompleteContext::Keyword); + + if (FFlag::LuauFixAutocompleteInWhile) + { + check(R"( + while t@1 + )"); + + auto ac5 = autocomplete('1'); + CHECK_EQ(ac5.entryMap.count("do"), 0); + CHECK_EQ(ac5.entryMap.count("true"), 1); + CHECK_EQ(ac5.entryMap.count("false"), 1); + } } TEST_CASE_FIXTURE(ACFixture, "autocomplete_if_middle_keywords") @@ -856,7 +925,7 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_if_middle_keywords") CHECK_EQ(ac5.entryMap.count("elseif"), 0); CHECK_EQ(ac5.entryMap.count("end"), 0); CHECK_EQ(ac5.context, AutocompleteContext::Statement); - + if (FFlag::LuauFixAutocompleteInIf) { check(R"( @@ -3397,4 +3466,32 @@ TEST_CASE_FIXTURE(ACFixture, "type_reduction_is_hooked_up_to_autocomplete") // CHECK("{| x: nil |}" == toString(*ty2, opts)); } +TEST_CASE_FIXTURE(ACFixture, "string_contents_is_available_to_callback") +{ + ScopedFastFlag luauAutocompleteStringContent{"LuauAutocompleteStringContent", true}; + + loadDefinition(R"( + declare function require(path: string): any + )"); + + std::optional require = frontend.typeCheckerForAutocomplete.globalScope->linearSearchForBinding("require"); + REQUIRE(require); + Luau::unfreeze(frontend.typeCheckerForAutocomplete.globalTypes); + attachTag(require->typeId, "RequireCall"); + Luau::freeze(frontend.typeCheckerForAutocomplete.globalTypes); + + check(R"( + local x = require("testing/@1") + )"); + + bool isCorrect = false; + auto ac1 = autocomplete( + '1', [&isCorrect](std::string, std::optional, std::optional contents) -> std::optional { + isCorrect = contents && *contents == "testing/"; + return std::nullopt; + }); + + CHECK(isCorrect); +} + TEST_SUITE_END(); diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index d7340ce51..4d3146b30 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -693,6 +693,33 @@ TEST_CASE("Debugger") CHECK(stephits > 100); // note; this will depend on number of instructions which can vary, so we just make sure the callback gets hit often } +TEST_CASE("NDebugGetUpValue") +{ + lua_CompileOptions copts = defaultOptions(); + copts.debugLevel = 0; + // Don't optimize away any upvalues + copts.optimizationLevel = 0; + + runConformance( + "ndebug_upvalues.lua", nullptr, + [](lua_State* L) { + lua_checkstack(L, LUA_MINSTACK); + + // push the second frame's closure to the stack + lua_Debug ar = {}; + REQUIRE(lua_getinfo(L, 1, "f", &ar)); + + // get the first upvalue + const char* u = lua_getupvalue(L, -1, 1); + REQUIRE(u); + // upvalue name is unknown without debug info + CHECK(strcmp(u, "") == 0); + CHECK(lua_tointeger(L, -1) == 5); + lua_pop(L, 2); + }, + nullptr, &copts, /* skipCodegen */ false); +} + TEST_CASE("SameHash") { extern unsigned int luaS_hash(const char* str, size_t len); // internal function, declared in lstring.h - not exposed via lua.h diff --git a/tests/Fixture.cpp b/tests/Fixture.cpp index 5ff006277..cb6eefc0d 100644 --- a/tests/Fixture.cpp +++ b/tests/Fixture.cpp @@ -606,12 +606,14 @@ void createSomeClasses(Frontend* frontend) TypeId childType = arena.addType(ClassType{"Child", {}, parentType, std::nullopt, {}, nullptr, "Test"}); - ClassType* childClass = getMutable(childType); - childClass->props["virtual_method"] = {makeFunction(arena, childType, {}, {})}; - addGlobalBinding(*frontend, "Child", {childType}); moduleScope->exportedTypeBindings["Child"] = TypeFun{{}, childType}; + TypeId anotherChildType = arena.addType(ClassType{"AnotherChild", {}, parentType, std::nullopt, {}, nullptr, "Test"}); + + addGlobalBinding(*frontend, "AnotherChild", {anotherChildType}); + moduleScope->exportedTypeBindings["AnotherChild"] = TypeFun{{}, anotherChildType}; + TypeId unrelatedType = arena.addType(ClassType{"Unrelated", {}, frontend->builtinTypes->classType, std::nullopt, {}, nullptr, "Test"}); addGlobalBinding(*frontend, "Unrelated", {unrelatedType}); diff --git a/tests/Linter.test.cpp b/tests/Linter.test.cpp index 426b520c6..84e286018 100644 --- a/tests/Linter.test.cpp +++ b/tests/Linter.test.cpp @@ -171,7 +171,6 @@ return bar() TEST_CASE_FIXTURE(Fixture, "GlobalAsLocalMultiFx") { - ScopedFastFlag sff{"LuauLintGlobalNeverReadBeforeWritten", true}; LintResult result = lint(R"( function bar() foo = 6 @@ -192,7 +191,6 @@ return bar() + baz() TEST_CASE_FIXTURE(Fixture, "GlobalAsLocalMultiFxWithRead") { - ScopedFastFlag sff{"LuauLintGlobalNeverReadBeforeWritten", true}; LintResult result = lint(R"( function bar() foo = 6 @@ -216,7 +214,6 @@ return bar() + baz() + read() TEST_CASE_FIXTURE(Fixture, "GlobalAsLocalWithConditional") { - ScopedFastFlag sff{"LuauLintGlobalNeverReadBeforeWritten", true}; LintResult result = lint(R"( function bar() if true then foo = 6 end @@ -236,7 +233,6 @@ return bar() + baz() TEST_CASE_FIXTURE(Fixture, "GlobalAsLocal3WithConditionalRead") { - ScopedFastFlag sff{"LuauLintGlobalNeverReadBeforeWritten", true}; LintResult result = lint(R"( function bar() foo = 6 @@ -260,7 +256,6 @@ return bar() + baz() + read() TEST_CASE_FIXTURE(Fixture, "GlobalAsLocalInnerRead") { - ScopedFastFlag sff{"LuauLintGlobalNeverReadBeforeWritten", true}; LintResult result = lint(R"( function foo() local f = function() return bar end diff --git a/tests/Normalize.test.cpp b/tests/Normalize.test.cpp index 615fc997c..13a956cff 100644 --- a/tests/Normalize.test.cpp +++ b/tests/Normalize.test.cpp @@ -174,11 +174,6 @@ TEST_CASE_FIXTURE(IsSubtypeFixture, "table_with_any_prop") TEST_CASE_FIXTURE(IsSubtypeFixture, "intersection") { - ScopedFastFlag sffs[]{ - {"LuauSubtypeNormalizer", true}, - {"LuauTypeNormalization2", true}, - }; - check(R"( local a: number & string local b: number diff --git a/tests/ToDot.test.cpp b/tests/ToDot.test.cpp index dc08ae1c5..80e82fdbb 100644 --- a/tests/ToDot.test.cpp +++ b/tests/ToDot.test.cpp @@ -9,6 +9,8 @@ using namespace Luau; +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); + struct ToDotClassFixture : Fixture { ToDotClassFixture() @@ -109,7 +111,27 @@ local function f(a, ...: string) return a end ToDotOptions opts; opts.showPointers = false; - CHECK_EQ(R"(digraph graphname { + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ(R"(digraph graphname { +n1 [label="FunctionType 1"]; +n1 -> n2 [label="arg"]; +n2 [label="TypePack 2"]; +n2 -> n3; +n3 [label="GenericType 3"]; +n2 -> n4 [label="tail"]; +n4 [label="VariadicTypePack 4"]; +n4 -> n5; +n5 [label="string"]; +n1 -> n6 [label="ret"]; +n6 [label="TypePack 6"]; +n6 -> n3; +})", + toDot(requireType("f"), opts)); + } + else + { + CHECK_EQ(R"(digraph graphname { n1 [label="FunctionType 1"]; n1 -> n2 [label="arg"]; n2 [label="TypePack 2"]; @@ -125,7 +147,8 @@ n6 -> n7; n7 [label="TypePack 7"]; n7 -> n3; })", - toDot(requireType("f"), opts)); + toDot(requireType("f"), opts)); + } } TEST_CASE_FIXTURE(Fixture, "union") @@ -176,7 +199,35 @@ local a: A ToDotOptions opts; opts.showPointers = false; - CHECK_EQ(R"(digraph graphname { + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ(R"(digraph graphname { +n1 [label="TableType A"]; +n1 -> n2 [label="x"]; +n2 [label="number"]; +n1 -> n3 [label="y"]; +n3 [label="FunctionType 3"]; +n3 -> n4 [label="arg"]; +n4 [label="TypePack 4"]; +n4 -> n5 [label="tail"]; +n5 [label="VariadicTypePack 5"]; +n5 -> n6; +n6 [label="string"]; +n3 -> n7 [label="ret"]; +n7 [label="TypePack 7"]; +n1 -> n8 [label="[index]"]; +n8 [label="string"]; +n1 -> n9 [label="[value]"]; +n9 [label="any"]; +n1 -> n10 [label="typeParam"]; +n10 [label="number"]; +n1 -> n5 [label="typePackParam"]; +})", + toDot(requireType("a"), opts)); + } + else + { + CHECK_EQ(R"(digraph graphname { n1 [label="TableType A"]; n1 -> n2 [label="x"]; n2 [label="number"]; @@ -196,7 +247,8 @@ n1 -> n9 [label="typeParam"]; n9 [label="number"]; n1 -> n4 [label="typePackParam"]; })", - toDot(requireType("a"), opts)); + toDot(requireType("a"), opts)); + } // Extra coverage with pointers (unstable values) (void)toDot(requireType("a")); @@ -357,14 +409,31 @@ b = a ToDotOptions opts; opts.showPointers = false; - CHECK_EQ(R"(digraph graphname { + + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ(R"(digraph graphname { +n1 [label="BoundType 1"]; +n1 -> n2; +n2 [label="TableType 2"]; +n2 -> n3 [label="boundTo"]; +n3 [label="TableType 3"]; +n3 -> n4 [label="x"]; +n4 [label="number"]; +})", + toDot(*ty, opts)); + } + else + { + CHECK_EQ(R"(digraph graphname { n1 [label="TableType 1"]; n1 -> n2 [label="boundTo"]; n2 [label="TableType a"]; n2 -> n3 [label="x"]; n3 [label="number"]; })", - toDot(*ty, opts)); + toDot(*ty, opts)); + } } TEST_CASE_FIXTURE(Fixture, "builtintypes") diff --git a/tests/ToString.test.cpp b/tests/ToString.test.cpp index 7d27437d7..0e51f976d 100644 --- a/tests/ToString.test.cpp +++ b/tests/ToString.test.cpp @@ -814,8 +814,6 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_hide_self_param") TEST_CASE_FIXTURE(Fixture, "tostring_unsee_ttv_if_array") { - ScopedFastFlag sff("LuauUnseeArrayTtv", true); - CheckResult result = check(R"( local x: {string} -- This code is constructed very specifically to use the same (by pointer diff --git a/tests/TypeInfer.builtins.test.cpp b/tests/TypeInfer.builtins.test.cpp index 6c2d31088..860dcfd03 100644 --- a/tests/TypeInfer.builtins.test.cpp +++ b/tests/TypeInfer.builtins.test.cpp @@ -8,7 +8,8 @@ using namespace Luau; -LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); +LUAU_FASTFLAG(LuauMatchReturnsOptionalString); TEST_SUITE_BEGIN("BuiltinTests"); @@ -174,7 +175,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "math_max_checks_for_numbers") local n = math.max(1,2,"3") )"); - CHECK(!result.errors.empty()); + LUAU_REQUIRE_ERRORS(result); CHECK_EQ("Type 'string' could not be converted into 'number'", toString(result.errors[0])); } @@ -1004,7 +1005,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "table_freeze_is_generic") TEST_CASE_FIXTURE(BuiltinsFixture, "set_metatable_needs_arguments") { - ScopedFastFlag sff{"LuauSetMetaTableArgsCheck", true}; CheckResult result = check(R"( local a = {b=setmetatable} a.b() @@ -1055,6 +1055,20 @@ end LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "string_match") +{ + CheckResult result = check(R"( + local s:string + local p = s:match("foo") + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + if (FFlag::LuauMatchReturnsOptionalString) + CHECK_EQ(toString(requireType("p")), "string?"); + else + CHECK_EQ(toString(requireType("p")), "string"); +} + TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types") { CheckResult result = check(R"END( @@ -1063,12 +1077,21 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(toString(requireType("a")), "string"); - CHECK_EQ(toString(requireType("b")), "number"); - CHECK_EQ(toString(requireType("c")), "string"); + if (FFlag::LuauMatchReturnsOptionalString) + { + CHECK_EQ(toString(requireType("a")), "string?"); + CHECK_EQ(toString(requireType("b")), "number?"); + CHECK_EQ(toString(requireType("c")), "string?"); + } + else + { + CHECK_EQ(toString(requireType("a")), "string"); + CHECK_EQ(toString(requireType("b")), "number"); + CHECK_EQ(toString(requireType("c")), "string"); + } } -TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types2") +TEST_CASE_FIXTURE(Fixture, "gmatch_capture_types2") { CheckResult result = check(R"END( local a, b, c = ("This is a string"):gmatch("(.()(%a+))")() @@ -1076,9 +1099,18 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types2") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(toString(requireType("a")), "string"); - CHECK_EQ(toString(requireType("b")), "number"); - CHECK_EQ(toString(requireType("c")), "string"); + if (FFlag::LuauMatchReturnsOptionalString) + { + CHECK_EQ(toString(requireType("a")), "string?"); + CHECK_EQ(toString(requireType("b")), "number?"); + CHECK_EQ(toString(requireType("c")), "string?"); + } + else + { + CHECK_EQ(toString(requireType("a")), "string"); + CHECK_EQ(toString(requireType("b")), "number"); + CHECK_EQ(toString(requireType("c")), "string"); + } } TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types_default_capture") @@ -1095,7 +1127,10 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types_default_capture") CHECK_EQ(acm->expected, 1); CHECK_EQ(acm->actual, 4); - CHECK_EQ(toString(requireType("a")), "string"); + if (FFlag::LuauMatchReturnsOptionalString) + CHECK_EQ(toString(requireType("a")), "string?"); + else + CHECK_EQ(toString(requireType("a")), "string"); } TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types_balanced_escaped_parens") @@ -1112,9 +1147,18 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types_balanced_escaped_parens CHECK_EQ(acm->expected, 3); CHECK_EQ(acm->actual, 4); - CHECK_EQ(toString(requireType("a")), "string"); - CHECK_EQ(toString(requireType("b")), "string"); - CHECK_EQ(toString(requireType("c")), "number"); + if (FFlag::LuauMatchReturnsOptionalString) + { + CHECK_EQ(toString(requireType("a")), "string?"); + CHECK_EQ(toString(requireType("b")), "string?"); + CHECK_EQ(toString(requireType("c")), "number?"); + } + else + { + CHECK_EQ(toString(requireType("a")), "string"); + CHECK_EQ(toString(requireType("b")), "string"); + CHECK_EQ(toString(requireType("c")), "number"); + } } TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types_parens_in_sets_are_ignored") @@ -1131,8 +1175,16 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types_parens_in_sets_are_igno CHECK_EQ(acm->expected, 2); CHECK_EQ(acm->actual, 3); - CHECK_EQ(toString(requireType("a")), "string"); - CHECK_EQ(toString(requireType("b")), "number"); + if (FFlag::LuauMatchReturnsOptionalString) + { + CHECK_EQ(toString(requireType("a")), "string?"); + CHECK_EQ(toString(requireType("b")), "number?"); + } + else + { + CHECK_EQ(toString(requireType("a")), "string"); + CHECK_EQ(toString(requireType("b")), "number"); + } } TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types_set_containing_lbracket") @@ -1143,8 +1195,16 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types_set_containing_lbracket LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(toString(requireType("a")), "number"); - CHECK_EQ(toString(requireType("b")), "string"); + if (FFlag::LuauMatchReturnsOptionalString) + { + CHECK_EQ(toString(requireType("a")), "number?"); + CHECK_EQ(toString(requireType("b")), "string?"); + } + else + { + CHECK_EQ(toString(requireType("a")), "number"); + CHECK_EQ(toString(requireType("b")), "string"); + } } TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types_leading_end_bracket_is_part_of_set") @@ -1192,9 +1252,18 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "match_capture_types") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(toString(requireType("a")), "string"); - CHECK_EQ(toString(requireType("b")), "number"); - CHECK_EQ(toString(requireType("c")), "string"); + if (FFlag::LuauMatchReturnsOptionalString) + { + CHECK_EQ(toString(requireType("a")), "string?"); + CHECK_EQ(toString(requireType("b")), "number?"); + CHECK_EQ(toString(requireType("c")), "string?"); + } + else + { + CHECK_EQ(toString(requireType("a")), "string"); + CHECK_EQ(toString(requireType("b")), "number"); + CHECK_EQ(toString(requireType("c")), "string"); + } } TEST_CASE_FIXTURE(BuiltinsFixture, "match_capture_types2") @@ -1210,9 +1279,18 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "match_capture_types2") CHECK_EQ(toString(tm->wantedType), "number?"); CHECK_EQ(toString(tm->givenType), "string"); - CHECK_EQ(toString(requireType("a")), "string"); - CHECK_EQ(toString(requireType("b")), "number"); - CHECK_EQ(toString(requireType("c")), "string"); + if (FFlag::LuauMatchReturnsOptionalString) + { + CHECK_EQ(toString(requireType("a")), "string?"); + CHECK_EQ(toString(requireType("b")), "number?"); + CHECK_EQ(toString(requireType("c")), "string?"); + } + else + { + CHECK_EQ(toString(requireType("a")), "string"); + CHECK_EQ(toString(requireType("b")), "number"); + CHECK_EQ(toString(requireType("c")), "string"); + } } TEST_CASE_FIXTURE(BuiltinsFixture, "find_capture_types") @@ -1223,9 +1301,18 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "find_capture_types") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(toString(requireType("a")), "string"); - CHECK_EQ(toString(requireType("b")), "number"); - CHECK_EQ(toString(requireType("c")), "string"); + if (FFlag::LuauMatchReturnsOptionalString) + { + CHECK_EQ(toString(requireType("a")), "string?"); + CHECK_EQ(toString(requireType("b")), "number?"); + CHECK_EQ(toString(requireType("c")), "string?"); + } + else + { + CHECK_EQ(toString(requireType("a")), "string"); + CHECK_EQ(toString(requireType("b")), "number"); + CHECK_EQ(toString(requireType("c")), "string"); + } CHECK_EQ(toString(requireType("d")), "number?"); CHECK_EQ(toString(requireType("e")), "number?"); } @@ -1243,9 +1330,18 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "find_capture_types2") CHECK_EQ(toString(tm->wantedType), "number?"); CHECK_EQ(toString(tm->givenType), "string"); - CHECK_EQ(toString(requireType("a")), "string"); - CHECK_EQ(toString(requireType("b")), "number"); - CHECK_EQ(toString(requireType("c")), "string"); + if (FFlag::LuauMatchReturnsOptionalString) + { + CHECK_EQ(toString(requireType("a")), "string?"); + CHECK_EQ(toString(requireType("b")), "number?"); + CHECK_EQ(toString(requireType("c")), "string?"); + } + else + { + CHECK_EQ(toString(requireType("a")), "string"); + CHECK_EQ(toString(requireType("b")), "number"); + CHECK_EQ(toString(requireType("c")), "string"); + } CHECK_EQ(toString(requireType("d")), "number?"); CHECK_EQ(toString(requireType("e")), "number?"); } @@ -1263,9 +1359,18 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "find_capture_types3") CHECK_EQ(toString(tm->wantedType), "boolean?"); CHECK_EQ(toString(tm->givenType), "string"); - CHECK_EQ(toString(requireType("a")), "string"); - CHECK_EQ(toString(requireType("b")), "number"); - CHECK_EQ(toString(requireType("c")), "string"); + if (FFlag::LuauMatchReturnsOptionalString) + { + CHECK_EQ(toString(requireType("a")), "string?"); + CHECK_EQ(toString(requireType("b")), "number?"); + CHECK_EQ(toString(requireType("c")), "string?"); + } + else + { + CHECK_EQ(toString(requireType("a")), "string"); + CHECK_EQ(toString(requireType("b")), "number"); + CHECK_EQ(toString(requireType("c")), "string"); + } CHECK_EQ(toString(requireType("d")), "number?"); CHECK_EQ(toString(requireType("e")), "number?"); } diff --git a/tests/TypeInfer.classes.test.cpp b/tests/TypeInfer.classes.test.cpp index 28315b676..becc88aa6 100644 --- a/tests/TypeInfer.classes.test.cpp +++ b/tests/TypeInfer.classes.test.cpp @@ -398,11 +398,6 @@ local a: ChildClass = i TEST_CASE_FIXTURE(ClassFixture, "intersections_of_unions_of_classes") { - ScopedFastFlag sffs[]{ - {"LuauSubtypeNormalizer", true}, - {"LuauTypeNormalization2", true}, - }; - CheckResult result = check(R"( local x : (BaseClass | Vector2) & (ChildClass | AnotherChild) local y : (ChildClass | AnotherChild) @@ -415,11 +410,6 @@ TEST_CASE_FIXTURE(ClassFixture, "intersections_of_unions_of_classes") TEST_CASE_FIXTURE(ClassFixture, "unions_of_intersections_of_classes") { - ScopedFastFlag sffs[]{ - {"LuauSubtypeNormalizer", true}, - {"LuauTypeNormalization2", true}, - }; - CheckResult result = check(R"( local x : (BaseClass & ChildClass) | (BaseClass & AnotherChild) | (BaseClass & Vector2) local y : (ChildClass | AnotherChild) @@ -482,8 +472,6 @@ caused by: TEST_CASE_FIXTURE(ClassFixture, "callable_classes") { - ScopedFastFlag luauCallableClasses{"LuauCallableClasses", true}; - CheckResult result = check(R"( local x : CallableClass local y = x("testing") diff --git a/tests/TypeInfer.definitions.test.cpp b/tests/TypeInfer.definitions.test.cpp index 93b405c25..2a681d1a6 100644 --- a/tests/TypeInfer.definitions.test.cpp +++ b/tests/TypeInfer.definitions.test.cpp @@ -396,8 +396,6 @@ TEST_CASE_FIXTURE(Fixture, "class_definition_string_props") TEST_CASE_FIXTURE(Fixture, "class_definitions_reference_other_classes") { - ScopedFastFlag LuauDeclareClassPrototype("LuauDeclareClassPrototype", true); - unfreeze(typeChecker.globalTypes); LoadDefinitionFileResult result = loadDefinitionFile(typeChecker, typeChecker.globalScope, R"( declare class Channel diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index 70de13d15..de338fe1f 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -1726,12 +1726,6 @@ foo(string.find("hello", "e")) TEST_CASE_FIXTURE(Fixture, "luau_subtyping_is_np_hard") { - ScopedFastFlag sffs[]{ - {"LuauSubtypeNormalizer", true}, - {"LuauTypeNormalization2", true}, - {"LuauOverloadedFunctionSubtypingPerf", true}, - }; - CheckResult result = check(R"( --!strict @@ -1834,8 +1828,6 @@ TEST_CASE_FIXTURE(Fixture, "other_things_are_not_related_to_function") TEST_CASE_FIXTURE(BuiltinsFixture, "fuzz_must_follow_in_overload_resolution") { - ScopedFastFlag luauTypeInferMissingFollows{"LuauTypeInferMissingFollows", true}; - CheckResult result = check(R"( for _ in function():(t0)&((()->())&(()->())) end do diff --git a/tests/TypeInfer.intersectionTypes.test.cpp b/tests/TypeInfer.intersectionTypes.test.cpp index b57d88202..e18a73788 100644 --- a/tests/TypeInfer.intersectionTypes.test.cpp +++ b/tests/TypeInfer.intersectionTypes.test.cpp @@ -463,11 +463,6 @@ TEST_CASE_FIXTURE(Fixture, "intersect_false_and_bool_and_false") TEST_CASE_FIXTURE(Fixture, "intersect_saturate_overloaded_functions") { - ScopedFastFlag sffs[]{ - {"LuauSubtypeNormalizer", true}, - {"LuauTypeNormalization2", true}, - }; - CheckResult result = check(R"( local x : ((number?) -> number?) & ((string?) -> string?) local y : (nil) -> nil = x -- OK @@ -481,11 +476,6 @@ TEST_CASE_FIXTURE(Fixture, "intersect_saturate_overloaded_functions") TEST_CASE_FIXTURE(Fixture, "union_saturate_overloaded_functions") { - ScopedFastFlag sffs[]{ - {"LuauSubtypeNormalizer", true}, - {"LuauTypeNormalization2", true}, - }; - CheckResult result = check(R"( local x : ((number) -> number) & ((string) -> string) local y : ((number | string) -> (number | string)) = x -- OK @@ -499,11 +489,6 @@ TEST_CASE_FIXTURE(Fixture, "union_saturate_overloaded_functions") TEST_CASE_FIXTURE(Fixture, "intersection_of_tables") { - ScopedFastFlag sffs[]{ - {"LuauSubtypeNormalizer", true}, - {"LuauTypeNormalization2", true}, - }; - CheckResult result = check(R"( local x : { p : number?, q : string? } & { p : number?, q : number?, r : number? } local y : { p : number?, q : nil, r : number? } = x -- OK @@ -531,8 +516,6 @@ TEST_CASE_FIXTURE(Fixture, "intersection_of_tables_with_top_properties") TEST_CASE_FIXTURE(Fixture, "intersection_of_tables_with_never_properties") { ScopedFastFlag sffs[]{ - {"LuauSubtypeNormalizer", true}, - {"LuauTypeNormalization2", true}, {"LuauUninhabitedSubAnything2", true}, }; @@ -547,11 +530,6 @@ TEST_CASE_FIXTURE(Fixture, "intersection_of_tables_with_never_properties") TEST_CASE_FIXTURE(Fixture, "overloaded_functions_returning_intersections") { - ScopedFastFlag sffs[]{ - {"LuauSubtypeNormalizer", true}, - {"LuauTypeNormalization2", true}, - }; - CheckResult result = check(R"( local x : ((number?) -> ({ p : number } & { q : number })) & ((string?) -> ({ p : number } & { r : number })) local y : (nil) -> { p : number, q : number, r : number} = x -- OK @@ -566,11 +544,6 @@ TEST_CASE_FIXTURE(Fixture, "overloaded_functions_returning_intersections") TEST_CASE_FIXTURE(Fixture, "overloaded_functions_mentioning_generic") { - ScopedFastFlag sffs[]{ - {"LuauSubtypeNormalizer", true}, - {"LuauTypeNormalization2", true}, - }; - CheckResult result = check(R"( function f() local x : ((number?) -> (a | number)) & ((string?) -> (a | string)) @@ -586,11 +559,6 @@ TEST_CASE_FIXTURE(Fixture, "overloaded_functions_mentioning_generic") TEST_CASE_FIXTURE(Fixture, "overloaded_functions_mentioning_generics") { - ScopedFastFlag sffs[]{ - {"LuauSubtypeNormalizer", true}, - {"LuauTypeNormalization2", true}, - }; - CheckResult result = check(R"( function f() local x : ((a?) -> (a | b)) & ((c?) -> (b | c)) @@ -606,11 +574,6 @@ TEST_CASE_FIXTURE(Fixture, "overloaded_functions_mentioning_generics") TEST_CASE_FIXTURE(Fixture, "overloaded_functions_mentioning_generic_packs") { - ScopedFastFlag sffs[]{ - {"LuauSubtypeNormalizer", true}, - {"LuauTypeNormalization2", true}, - }; - CheckResult result = check(R"( function f() local x : ((number?, a...) -> (number?, b...)) & ((string?, a...) -> (string?, b...)) @@ -626,11 +589,6 @@ TEST_CASE_FIXTURE(Fixture, "overloaded_functions_mentioning_generic_packs") TEST_CASE_FIXTURE(Fixture, "overloadeded_functions_with_unknown_result") { - ScopedFastFlag sffs[]{ - {"LuauSubtypeNormalizer", true}, - {"LuauTypeNormalization2", true}, - }; - CheckResult result = check(R"( function f() local x : ((number) -> number) & ((nil) -> unknown) @@ -646,11 +604,6 @@ TEST_CASE_FIXTURE(Fixture, "overloadeded_functions_with_unknown_result") TEST_CASE_FIXTURE(Fixture, "overloadeded_functions_with_unknown_arguments") { - ScopedFastFlag sffs[]{ - {"LuauSubtypeNormalizer", true}, - {"LuauTypeNormalization2", true}, - }; - CheckResult result = check(R"( function f() local x : ((number) -> number?) & ((unknown) -> string?) @@ -666,11 +619,6 @@ TEST_CASE_FIXTURE(Fixture, "overloadeded_functions_with_unknown_arguments") TEST_CASE_FIXTURE(Fixture, "overloadeded_functions_with_never_result") { - ScopedFastFlag sffs[]{ - {"LuauSubtypeNormalizer", true}, - {"LuauTypeNormalization2", true}, - }; - CheckResult result = check(R"( function f() local x : ((number) -> number) & ((nil) -> never) @@ -686,11 +634,6 @@ TEST_CASE_FIXTURE(Fixture, "overloadeded_functions_with_never_result") TEST_CASE_FIXTURE(Fixture, "overloadeded_functions_with_never_arguments") { - ScopedFastFlag sffs[]{ - {"LuauSubtypeNormalizer", true}, - {"LuauTypeNormalization2", true}, - }; - CheckResult result = check(R"( function f() local x : ((number) -> number?) & ((never) -> string?) @@ -779,11 +722,6 @@ TEST_CASE_FIXTURE(Fixture, "overloadeded_functions_with_weird_typepacks_4") TEST_CASE_FIXTURE(BuiltinsFixture, "intersect_metatables") { - ScopedFastFlag sffs[]{ - {"LuauSubtypeNormalizer", true}, - {"LuauTypeNormalization2", true}, - }; - CheckResult result = check(R"( local a : string? = nil local b : number? = nil @@ -807,11 +745,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "intersect_metatables") TEST_CASE_FIXTURE(BuiltinsFixture, "intersect_metatable_subtypes") { - ScopedFastFlag sffs[]{ - {"LuauSubtypeNormalizer", true}, - {"LuauTypeNormalization2", true}, - }; - CheckResult result = check(R"( local x = setmetatable({ a = 5 }, { p = 5 }); local y = setmetatable({ b = "hi" }, { p = 5, q = "hi" }); @@ -833,11 +766,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "intersect_metatable_subtypes") TEST_CASE_FIXTURE(BuiltinsFixture, "intersect_metatables_with_properties") { - ScopedFastFlag sffs[]{ - {"LuauSubtypeNormalizer", true}, - {"LuauTypeNormalization2", true}, - }; - CheckResult result = check(R"( local x = setmetatable({ a = 5 }, { p = 5 }); local y = setmetatable({ b = "hi" }, { q = "hi" }); @@ -856,11 +784,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "intersect_metatables_with_properties") TEST_CASE_FIXTURE(BuiltinsFixture, "intersect_metatable_with_table") { - ScopedFastFlag sffs[]{ - {"LuauSubtypeNormalizer", true}, - {"LuauTypeNormalization2", true}, - }; - CheckResult result = check(R"( local x = setmetatable({ a = 5 }, { p = 5 }); local z = setmetatable({ a = 5, b = "hi" }, { p = 5 }); @@ -881,11 +804,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "intersect_metatable_with_table") TEST_CASE_FIXTURE(Fixture, "CLI-44817") { - ScopedFastFlag sffs[]{ - {"LuauSubtypeNormalizer", true}, - {"LuauTypeNormalization2", true}, - }; - CheckResult result = check(R"( type X = {x: number} type Y = {y: number} diff --git a/tests/TypeInfer.modules.test.cpp b/tests/TypeInfer.modules.test.cpp index fe52d1682..ed3af11b4 100644 --- a/tests/TypeInfer.modules.test.cpp +++ b/tests/TypeInfer.modules.test.cpp @@ -475,8 +475,6 @@ return l0 TEST_CASE_FIXTURE(BuiltinsFixture, "fuzz_anyify_variadic_return_must_follow") { - ScopedFastFlag luauTypeInferMissingFollows{"LuauTypeInferMissingFollows", true}; - CheckResult result = check(R"( return unpack(l0[_]) )"); diff --git a/tests/TypeInfer.negations.test.cpp b/tests/TypeInfer.negations.test.cpp index 261314a64..adf036532 100644 --- a/tests/TypeInfer.negations.test.cpp +++ b/tests/TypeInfer.negations.test.cpp @@ -14,9 +14,6 @@ namespace struct NegationFixture : Fixture { TypeArena arena; - ScopedFastFlag sff[1]{ - {"LuauSubtypeNormalizer", true}, - }; NegationFixture() { diff --git a/tests/TypeInfer.primitives.test.cpp b/tests/TypeInfer.primitives.test.cpp index 7e99f0b02..02fdfa36e 100644 --- a/tests/TypeInfer.primitives.test.cpp +++ b/tests/TypeInfer.primitives.test.cpp @@ -72,17 +72,6 @@ TEST_CASE_FIXTURE(Fixture, "string_function_indirect") CHECK_EQ(*requireType("p"), *typeChecker.stringType); } -TEST_CASE_FIXTURE(Fixture, "string_function_other") -{ - CheckResult result = check(R"( - local s:string - local p = s:match("foo") - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(toString(requireType("p")), "string"); -} - TEST_CASE_FIXTURE(Fixture, "CheckMethodsOfNumber") { CheckResult result = check(R"( diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index f77cacfa9..dced3f587 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -800,7 +800,9 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "type_guard_can_filter_for_intersection_of_ta LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("{| x: number |} & {| y: number |}", toString(requireTypeAtPosition({4, 28}))); + ToStringOptions opts; + opts.exhaustive = true; + CHECK_EQ("{| x: number |} & {| y: number |}", toString(requireTypeAtPosition({4, 28}), opts)); CHECK_EQ("nil", toString(requireTypeAtPosition({6, 28}))); } @@ -1436,6 +1438,32 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "type_narrow_for_all_the_userdata") CHECK_EQ("number | string", toString(requireTypeAtPosition({5, 28}))); } +TEST_CASE_FIXTURE(RefinementClassFixture, "type_narrow_but_the_discriminant_type_isnt_a_class") +{ + CheckResult result = check(R"( + local function f(x: string | number | Instance | Vector3) + if type(x) == "any" then + local foo = x + else + local foo = x + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("(Instance | Vector3 | number | string) & never", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("(Instance | Vector3 | number | string) & ~never", toString(requireTypeAtPosition({5, 28}))); + } + else + { + CHECK_EQ("*error-type*", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("*error-type*", toString(requireTypeAtPosition({5, 28}))); + } +} + TEST_CASE_FIXTURE(RefinementClassFixture, "eliminate_subclasses_of_instance") { CheckResult result = check(R"( @@ -1721,8 +1749,6 @@ TEST_CASE_FIXTURE(Fixture, "else_with_no_explicit_expression_should_also_refine_ TEST_CASE_FIXTURE(Fixture, "fuzz_filtered_refined_types_are_followed") { - ScopedFastFlag luauTypeInferMissingFollows{"LuauTypeInferMissingFollows", true}; - CheckResult result = check(R"( local _ do diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index dc3b7ceb7..01e1ead78 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -55,7 +55,10 @@ TEST_CASE_FIXTURE(Fixture, "augment_table") TEST_CASE_FIXTURE(Fixture, "augment_nested_table") { - CheckResult result = check("local t = { p = {} } t.p.foo = 'bar'"); + CheckResult result = check(R"( + local t = { p = {} } + t.p.foo = 'bar' + )"); LUAU_REQUIRE_NO_ERRORS(result); TableType* tType = getMutable(requireType("t")); @@ -70,19 +73,28 @@ TEST_CASE_FIXTURE(Fixture, "augment_nested_table") TEST_CASE_FIXTURE(Fixture, "cannot_augment_sealed_table") { - CheckResult result = check("function mkt() return {prop=999} end local t = mkt() t.foo = 'bar'"); + CheckResult result = check(R"( + function mkt() + return {prop=999} + end + + local t = mkt() + t.foo = 'bar' + )"); LUAU_REQUIRE_ERROR_COUNT(1, result); TypeError& err = result.errors[0]; + + CHECK(err.location == Location{Position{6, 8}, Position{6, 13}}); + CannotExtendTable* error = get(err); - REQUIRE(error != nullptr); + REQUIRE_MESSAGE(error != nullptr, "Expected CannotExtendTable but got: " << toString(err)); // TODO: better, more robust comparison of type vars auto s = toString(error->tableType, ToStringOptions{/*exhaustive*/ true}); CHECK_EQ(s, "{| prop: number |}"); CHECK_EQ(error->prop, "foo"); CHECK_EQ(error->context, CannotExtendTable::Property); - CHECK_EQ(err.location, (Location{Position{0, 59}, Position{0, 64}})); } TEST_CASE_FIXTURE(Fixture, "dont_seal_an_unsealed_table_by_passing_it_to_a_function_that_takes_a_sealed_table") diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index f4b84262c..fcbe2b147 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -1029,10 +1029,6 @@ TEST_CASE_FIXTURE(Fixture, "type_infer_recursion_limit_no_ice") TEST_CASE_FIXTURE(Fixture, "type_infer_recursion_limit_normalizer") { ScopedFastInt sfi("LuauTypeInferRecursionLimit", 10); - ScopedFastFlag sffs[]{ - {"LuauSubtypeNormalizer", true}, - {"LuauTypeNormalization2", true}, - }; CheckResult result = check(R"( function f() @@ -1048,10 +1044,6 @@ TEST_CASE_FIXTURE(Fixture, "type_infer_recursion_limit_normalizer") TEST_CASE_FIXTURE(Fixture, "type_infer_cache_limit_normalizer") { ScopedFastInt sfi("LuauNormalizeCacheLimit", 10); - ScopedFastFlag sffs[]{ - {"LuauSubtypeNormalizer", true}, - {"LuauTypeNormalization2", true}, - }; CheckResult result = check(R"( local x : ((number) -> number) & ((string) -> string) & ((nil) -> nil) & (({}) -> {}) @@ -1161,8 +1153,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "it_is_ok_to_have_inconsistent_number_of_retu TEST_CASE_FIXTURE(Fixture, "fuzz_free_table_type_change_during_index_check") { - ScopedFastFlag luauFollowInLvalueIndexCheck{"LuauFollowInLvalueIndexCheck", true}; - CheckResult result = check(R"( local _ = nil while _["" >= _] do diff --git a/tests/TypeInfer.tryUnify.test.cpp b/tests/TypeInfer.tryUnify.test.cpp index 80c7ab579..8a55c5cf1 100644 --- a/tests/TypeInfer.tryUnify.test.cpp +++ b/tests/TypeInfer.tryUnify.test.cpp @@ -112,11 +112,6 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "incompatible_tables_are_preserved") TEST_CASE_FIXTURE(TryUnifyFixture, "uninhabited_intersection_sub_never") { - ScopedFastFlag sffs[]{ - {"LuauSubtypeNormalizer", true}, - {"LuauTypeNormalization2", true}, - }; - CheckResult result = check(R"( function f(arg : string & number) : never return arg @@ -127,11 +122,6 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "uninhabited_intersection_sub_never") TEST_CASE_FIXTURE(TryUnifyFixture, "uninhabited_intersection_sub_anything") { - ScopedFastFlag sffs[]{ - {"LuauSubtypeNormalizer", true}, - {"LuauTypeNormalization2", true}, - }; - CheckResult result = check(R"( function f(arg : string & number) : boolean return arg @@ -143,8 +133,6 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "uninhabited_intersection_sub_anything") TEST_CASE_FIXTURE(TryUnifyFixture, "uninhabited_table_sub_never") { ScopedFastFlag sffs[]{ - {"LuauSubtypeNormalizer", true}, - {"LuauTypeNormalization2", true}, {"LuauUninhabitedSubAnything2", true}, }; @@ -159,8 +147,6 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "uninhabited_table_sub_never") TEST_CASE_FIXTURE(TryUnifyFixture, "uninhabited_table_sub_anything") { ScopedFastFlag sffs[]{ - {"LuauSubtypeNormalizer", true}, - {"LuauTypeNormalization2", true}, {"LuauUninhabitedSubAnything2", true}, }; @@ -363,8 +349,6 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "metatables_unify_against_shape_of_free_table TEST_CASE_FIXTURE(TryUnifyFixture, "fuzz_tail_unification_issue") { - ScopedFastFlag luauTxnLogTypePackIterator{"LuauTxnLogTypePackIterator", true}; - TypePackVar variadicAny{VariadicTypePack{typeChecker.anyType}}; TypePackVar packTmp{TypePack{{typeChecker.anyType}, &variadicAny}}; TypePackVar packSub{TypePack{{typeChecker.anyType, typeChecker.anyType}, &packTmp}}; @@ -376,4 +360,18 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "fuzz_tail_unification_issue") state.tryUnify(&packSub, &packSuper); } +TEST_CASE_FIXTURE(BuiltinsFixture, "fuzz_unify_any_should_check_log") +{ + ScopedFastFlag luauUnifyAnyTxnLog{"LuauUnifyAnyTxnLog", true}; + + CheckResult result = check(R"( +repeat +_._,_ = nil +until _ +local l0:(any)&(typeof(_)),l0:(any)|(any) = _,_ + )"); + + LUAU_REQUIRE_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.typePacks.cpp b/tests/TypeInfer.typePacks.cpp index 94448cfa5..5486b9699 100644 --- a/tests/TypeInfer.typePacks.cpp +++ b/tests/TypeInfer.typePacks.cpp @@ -1039,8 +1039,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "generalize_expectedTypes_with_proper_scope") TEST_CASE_FIXTURE(Fixture, "fuzz_typepack_iter_follow") { - ScopedFastFlag luauTxnLogTypePackIterator{"LuauTxnLogTypePackIterator", true}; - CheckResult result = check(R"( local _ local _ = _,_(),_(_) @@ -1051,8 +1049,6 @@ local _ = _,_(),_(_) TEST_CASE_FIXTURE(BuiltinsFixture, "fuzz_typepack_iter_follow_2") { - ScopedFastFlag luauTxnLogTypePackIterator{"LuauTxnLogTypePackIterator", true}; - CheckResult result = check(R"( function test(name, searchTerm) local found = string.find(name:lower(), searchTerm:lower()) diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp index 8831bb2ea..6f69d6827 100644 --- a/tests/TypeInfer.unionTypes.test.cpp +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -544,11 +544,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "table_union_write_indirect") TEST_CASE_FIXTURE(Fixture, "union_true_and_false") { - ScopedFastFlag sffs[]{ - {"LuauSubtypeNormalizer", true}, - {"LuauTypeNormalization2", true}, - }; - CheckResult result = check(R"( local x : boolean local y1 : (true | false) = x -- OK @@ -562,11 +557,6 @@ TEST_CASE_FIXTURE(Fixture, "union_true_and_false") TEST_CASE_FIXTURE(Fixture, "union_of_functions") { - ScopedFastFlag sffs[]{ - {"LuauSubtypeNormalizer", true}, - {"LuauTypeNormalization2", true}, - }; - CheckResult result = check(R"( local x : (number) -> number? local y : ((number?) -> number?) | ((number) -> number) = x -- OK @@ -599,11 +589,6 @@ TEST_CASE_FIXTURE(Fixture, "union_of_generic_typepack_functions") TEST_CASE_FIXTURE(Fixture, "union_of_functions_mentioning_generics") { - ScopedFastFlag sffs[]{ - {"LuauSubtypeNormalizer", true}, - {"LuauTypeNormalization2", true}, - }; - CheckResult result = check(R"( function f() local x : (a) -> a? @@ -619,11 +604,6 @@ TEST_CASE_FIXTURE(Fixture, "union_of_functions_mentioning_generics") TEST_CASE_FIXTURE(Fixture, "union_of_functions_mentioning_generic_typepacks") { - ScopedFastFlag sffs[]{ - {"LuauSubtypeNormalizer", true}, - {"LuauTypeNormalization2", true}, - }; - CheckResult result = check(R"( function f() local x : (number, a...) -> (number?, a...) @@ -639,11 +619,6 @@ TEST_CASE_FIXTURE(Fixture, "union_of_functions_mentioning_generic_typepacks") TEST_CASE_FIXTURE(Fixture, "union_of_functions_with_mismatching_arg_arities") { - ScopedFastFlag sffs[]{ - {"LuauSubtypeNormalizer", true}, - {"LuauTypeNormalization2", true}, - }; - CheckResult result = check(R"( local x : (number) -> number? local y : ((number?) -> number) | ((number | string) -> nil) = x -- OK @@ -657,11 +632,6 @@ TEST_CASE_FIXTURE(Fixture, "union_of_functions_with_mismatching_arg_arities") TEST_CASE_FIXTURE(Fixture, "union_of_functions_with_mismatching_result_arities") { - ScopedFastFlag sffs[]{ - {"LuauSubtypeNormalizer", true}, - {"LuauTypeNormalization2", true}, - }; - CheckResult result = check(R"( local x : () -> (number | string) local y : (() -> number) | (() -> string) = x -- OK @@ -675,11 +645,6 @@ TEST_CASE_FIXTURE(Fixture, "union_of_functions_with_mismatching_result_arities") TEST_CASE_FIXTURE(Fixture, "union_of_functions_with_variadics") { - ScopedFastFlag sffs[]{ - {"LuauSubtypeNormalizer", true}, - {"LuauTypeNormalization2", true}, - }; - CheckResult result = check(R"( local x : (...nil) -> (...number?) local y : ((...string?) -> (...number)) | ((...number?) -> nil) = x -- OK @@ -693,11 +658,6 @@ TEST_CASE_FIXTURE(Fixture, "union_of_functions_with_variadics") TEST_CASE_FIXTURE(Fixture, "union_of_functions_with_mismatching_arg_variadics") { - ScopedFastFlag sffs[]{ - {"LuauSubtypeNormalizer", true}, - {"LuauTypeNormalization2", true}, - }; - CheckResult result = check(R"( local x : (number) -> () local y : ((number?) -> ()) | ((...number) -> ()) = x -- OK @@ -711,11 +671,6 @@ TEST_CASE_FIXTURE(Fixture, "union_of_functions_with_mismatching_arg_variadics") TEST_CASE_FIXTURE(Fixture, "union_of_functions_with_mismatching_result_variadics") { - ScopedFastFlag sffs[]{ - {"LuauSubtypeNormalizer", true}, - {"LuauTypeNormalization2", true}, - }; - CheckResult result = check(R"( local x : () -> (number?, ...number) local y : (() -> (...number)) | (() -> nil) = x -- OK diff --git a/tests/TypeReduction.test.cpp b/tests/TypeReduction.test.cpp index f2d7b027b..f1d5eae91 100644 --- a/tests/TypeReduction.test.cpp +++ b/tests/TypeReduction.test.cpp @@ -10,10 +10,13 @@ namespace { struct ReductionFixture : Fixture { + TypeReductionOptions typeReductionOpts{/* allowTypeReductionsFromOtherArenas */ true}; + ToStringOptions toStringOpts{true}; + TypeArena arena; InternalErrorReporter iceHandler; UnifierSharedState unifierState{&iceHandler}; - TypeReduction reduction{NotNull{&arena}, builtinTypes, NotNull{&iceHandler}}; + TypeReduction reduction{NotNull{&arena}, builtinTypes, NotNull{&iceHandler}, typeReductionOpts}; ReductionFixture() { @@ -28,18 +31,15 @@ struct ReductionFixture : Fixture return *reducedTy; } - std::optional tryReduce(const std::string& annotation) + TypeId reductionof(const std::string& annotation) { - CheckResult result = check("type _Res = " + annotation); - LUAU_REQUIRE_NO_ERRORS(result); - return reduction.reduce(requireTypeAlias("_Res")); + check("type _Res = " + annotation); + return reductionof(requireTypeAlias("_Res")); } - TypeId reductionof(const std::string& annotation) + std::string toStringFull(TypeId ty) { - std::optional reducedTy = tryReduce(annotation); - REQUIRE_MESSAGE(reducedTy, "Exceeded the cartesian product of the type"); - return *reducedTy; + return toString(ty, toStringOpts); } }; } // namespace @@ -50,42 +50,54 @@ TEST_CASE_FIXTURE(ReductionFixture, "cartesian_product_exceeded") { ScopedFastInt sfi{"LuauTypeReductionCartesianProductLimit", 5}; - std::optional ty = tryReduce(R"( - string & (number | string | boolean) & (number | string | boolean) + CheckResult result = check(R"( + type T + = string + & (number | string | boolean) + & (number | string | boolean) )"); - CHECK(!ty); + CHECK(!reduction.reduce(requireTypeAlias("T"))); + // LUAU_REQUIRE_ERROR_COUNT(1, result); + // CHECK("Code is too complex to typecheck! Consider simplifying the code around this area" == toString(result.errors[0])); } TEST_CASE_FIXTURE(ReductionFixture, "cartesian_product_exceeded_with_normal_limit") { - std::optional ty = tryReduce(R"( - string -- 1 = 1 - & (number | string | boolean) -- 1 * 3 = 3 - & (number | string | boolean) -- 3 * 3 = 9 - & (number | string | boolean) -- 9 * 3 = 27 - & (number | string | boolean) -- 27 * 3 = 81 - & (number | string | boolean) -- 81 * 3 = 243 - & (number | string | boolean) -- 243 * 3 = 729 - & (number | string | boolean) -- 729 * 3 = 2187 - & (number | string | boolean) -- 2187 * 3 = 6561 - & (number | string | boolean) -- 6561 * 3 = 19683 - & (number | string | boolean) -- 19683 * 3 = 59049 - & (number | string) -- 59049 * 2 = 118098 + CheckResult result = check(R"( + type T + = string -- 1 = 1 + & (number | string | boolean) -- 1 * 3 = 3 + & (number | string | boolean) -- 3 * 3 = 9 + & (number | string | boolean) -- 9 * 3 = 27 + & (number | string | boolean) -- 27 * 3 = 81 + & (number | string | boolean) -- 81 * 3 = 243 + & (number | string | boolean) -- 243 * 3 = 729 + & (number | string | boolean) -- 729 * 3 = 2187 + & (number | string | boolean) -- 2187 * 3 = 6561 + & (number | string | boolean) -- 6561 * 3 = 19683 + & (number | string | boolean) -- 19683 * 3 = 59049 + & (number | string) -- 59049 * 2 = 118098 )"); - CHECK(!ty); + CHECK(!reduction.reduce(requireTypeAlias("T"))); + // LUAU_REQUIRE_ERROR_COUNT(1, result); + // CHECK("Code is too complex to typecheck! Consider simplifying the code around this area" == toString(result.errors[0])); } TEST_CASE_FIXTURE(ReductionFixture, "cartesian_product_is_zero") { ScopedFastInt sfi{"LuauTypeReductionCartesianProductLimit", 5}; - std::optional ty = tryReduce(R"( - string & (number | string | boolean) & (number | string | boolean) & never + CheckResult result = check(R"( + type T + = string + & (number | string | boolean) + & (number | string | boolean) + & never )"); - CHECK(ty); + LUAU_REQUIRE_NO_ERRORS(result); } TEST_CASE_FIXTURE(ReductionFixture, "stress_test_recursion_limits") @@ -115,13 +127,10 @@ TEST_CASE_FIXTURE(ReductionFixture, "caching") TypeId intersectionTy = arena.addType(IntersectionType{{ty1, ty2}}); - ToStringOptions opts; - opts.exhaustive = true; - - CHECK("{- x: string -} & {| |}" == toString(reductionof(intersectionTy))); + CHECK("{- x: string -} & {| |}" == toStringFull(reductionof(intersectionTy))); getMutable(ty1)->state = TableState::Sealed; - CHECK("{| x: string |}" == toString(reductionof(intersectionTy))); + CHECK("{| x: string |}" == toStringFull(reductionof(intersectionTy))); } SUBCASE("unsealed_tables") @@ -135,13 +144,10 @@ TEST_CASE_FIXTURE(ReductionFixture, "caching") TypeId intersectionTy = arena.addType(IntersectionType{{ty1, ty2}}); - ToStringOptions opts; - opts.exhaustive = true; - - CHECK("{| x: string |}" == toString(reductionof(intersectionTy))); + CHECK("{| x: string |}" == toStringFull(reductionof(intersectionTy))); getMutable(ty1)->state = TableState::Sealed; - CHECK("{| x: string |}" == toString(reductionof(intersectionTy))); + CHECK("{| x: string |}" == toStringFull(reductionof(intersectionTy))); } SUBCASE("free_types") @@ -152,13 +158,10 @@ TEST_CASE_FIXTURE(ReductionFixture, "caching") TypeId intersectionTy = arena.addType(IntersectionType{{ty1, ty2}}); - ToStringOptions opts; - opts.exhaustive = true; - - CHECK("a & {| |}" == toString(reductionof(intersectionTy))); + CHECK("a & {| |}" == toStringFull(reductionof(intersectionTy))); *asMutable(ty1) = BoundType{ty2}; - CHECK("{| |}" == toString(reductionof(intersectionTy))); + CHECK("{| |}" == toStringFull(reductionof(intersectionTy))); } SUBCASE("we_can_see_that_the_cache_works_if_we_mutate_a_normally_not_mutated_type") @@ -168,13 +171,42 @@ TEST_CASE_FIXTURE(ReductionFixture, "caching") TypeId intersectionTy = arena.addType(IntersectionType{{ty1, ty2}}); - ToStringOptions opts; - opts.exhaustive = true; - - CHECK("never" == toString(reductionof(intersectionTy))); // Bound & number ~ never + CHECK("never" == toStringFull(reductionof(intersectionTy))); // Bound & number ~ never *asMutable(ty1) = BoundType{ty2}; - CHECK("never" == toString(reductionof(intersectionTy))); // Bound & number ~ number, but the cache is `never`. + CHECK("never" == toStringFull(reductionof(intersectionTy))); // Bound & number ~ number, but the cache is `never`. + } + + SUBCASE("ptr_eq_irreducible_unions") + { + TypeId unionTy = arena.addType(UnionType{{builtinTypes->stringType, builtinTypes->numberType}}); + TypeId reducedTy = reductionof(unionTy); + REQUIRE(unionTy == reducedTy); + } + + SUBCASE("ptr_eq_irreducible_intersections") + { + TypeId intersectionTy = arena.addType(IntersectionType{{builtinTypes->stringType, arena.addType(GenericType{"G"})}}); + TypeId reducedTy = reductionof(intersectionTy); + REQUIRE(intersectionTy == reducedTy); + } + + SUBCASE("ptr_eq_free_table") + { + TypeId tableTy = arena.addType(TableType{}); + getMutable(tableTy)->state = TableState::Free; + + TypeId reducedTy = reductionof(tableTy); + REQUIRE(tableTy == reducedTy); + } + + SUBCASE("ptr_eq_unsealed_table") + { + TypeId tableTy = arena.addType(TableType{}); + getMutable(tableTy)->state = TableState::Unsealed; + + TypeId reducedTy = reductionof(tableTy); + REQUIRE(tableTy == reducedTy); } } // caching @@ -183,169 +215,169 @@ TEST_CASE_FIXTURE(ReductionFixture, "intersections_without_negations") SUBCASE("string_and_string") { TypeId ty = reductionof("string & string"); - CHECK("string" == toString(ty)); + CHECK("string" == toStringFull(ty)); } SUBCASE("never_and_string") { TypeId ty = reductionof("never & string"); - CHECK("never" == toString(ty)); + CHECK("never" == toStringFull(ty)); } SUBCASE("string_and_never") { TypeId ty = reductionof("string & never"); - CHECK("never" == toString(ty)); + CHECK("never" == toStringFull(ty)); } SUBCASE("unknown_and_string") { TypeId ty = reductionof("unknown & string"); - CHECK("string" == toString(ty)); + CHECK("string" == toStringFull(ty)); } SUBCASE("string_and_unknown") { TypeId ty = reductionof("string & unknown"); - CHECK("string" == toString(ty)); + CHECK("string" == toStringFull(ty)); } SUBCASE("any_and_string") { TypeId ty = reductionof("any & string"); - CHECK("string" == toString(ty)); + CHECK("string" == toStringFull(ty)); } SUBCASE("string_and_any") { TypeId ty = reductionof("string & any"); - CHECK("string" == toString(ty)); + CHECK("string" == toStringFull(ty)); } SUBCASE("string_or_number_and_string") { TypeId ty = reductionof("(string | number) & string"); - CHECK("string" == toString(ty)); + CHECK("string" == toStringFull(ty)); } SUBCASE("string_and_string_or_number") { TypeId ty = reductionof("string & (string | number)"); - CHECK("string" == toString(ty)); + CHECK("string" == toStringFull(ty)); } SUBCASE("string_and_a") { TypeId ty = reductionof(R"(string & "a")"); - CHECK(R"("a")" == toString(ty)); + CHECK(R"("a")" == toStringFull(ty)); } SUBCASE("boolean_and_true") { TypeId ty = reductionof("boolean & true"); - CHECK("true" == toString(ty)); + CHECK("true" == toStringFull(ty)); } SUBCASE("boolean_and_a") { TypeId ty = reductionof(R"(boolean & "a")"); - CHECK("never" == toString(ty)); + CHECK("never" == toStringFull(ty)); } SUBCASE("a_and_a") { TypeId ty = reductionof(R"("a" & "a")"); - CHECK(R"("a")" == toString(ty)); + CHECK(R"("a")" == toStringFull(ty)); } SUBCASE("a_and_b") { TypeId ty = reductionof(R"("a" & "b")"); - CHECK("never" == toString(ty)); + CHECK("never" == toStringFull(ty)); } SUBCASE("a_and_true") { TypeId ty = reductionof(R"("a" & true)"); - CHECK("never" == toString(ty)); + CHECK("never" == toStringFull(ty)); } SUBCASE("a_and_true") { TypeId ty = reductionof(R"(true & false)"); - CHECK("never" == toString(ty)); + CHECK("never" == toStringFull(ty)); } SUBCASE("function_type_and_function") { TypeId ty = reductionof("() -> () & fun"); - CHECK("() -> ()" == toString(ty)); + CHECK("() -> ()" == toStringFull(ty)); } SUBCASE("function_type_and_string") { TypeId ty = reductionof("() -> () & string"); - CHECK("never" == toString(ty)); + CHECK("never" == toStringFull(ty)); } SUBCASE("parent_and_child") { TypeId ty = reductionof("Parent & Child"); - CHECK("Child" == toString(ty)); + CHECK("Child" == toStringFull(ty)); } SUBCASE("child_and_parent") { TypeId ty = reductionof("Child & Parent"); - CHECK("Child" == toString(ty)); + CHECK("Child" == toStringFull(ty)); } SUBCASE("child_and_unrelated") { TypeId ty = reductionof("Child & Unrelated"); - CHECK("never" == toString(ty)); + CHECK("never" == toStringFull(ty)); } SUBCASE("string_and_table") { TypeId ty = reductionof("string & {}"); - CHECK("never" == toString(ty)); + CHECK("never" == toStringFull(ty)); } SUBCASE("string_and_child") { TypeId ty = reductionof("string & Child"); - CHECK("never" == toString(ty)); + CHECK("never" == toStringFull(ty)); } SUBCASE("string_and_function") { TypeId ty = reductionof("string & () -> ()"); - CHECK("never" == toString(ty)); + CHECK("never" == toStringFull(ty)); } SUBCASE("function_and_table") { TypeId ty = reductionof("() -> () & {}"); - CHECK("never" == toString(ty)); + CHECK("never" == toStringFull(ty)); } SUBCASE("function_and_class") { TypeId ty = reductionof("() -> () & Child"); - CHECK("never" == toString(ty)); + CHECK("never" == toStringFull(ty)); } SUBCASE("function_and_function") { TypeId ty = reductionof("() -> () & () -> ()"); - CHECK("(() -> ()) & (() -> ())" == toString(ty)); + CHECK("(() -> ()) & (() -> ())" == toStringFull(ty)); } SUBCASE("table_and_table") { TypeId ty = reductionof("{} & {}"); - CHECK("{| |}" == toString(ty)); + CHECK("{| |}" == toStringFull(ty)); } SUBCASE("table_and_metatable") @@ -357,125 +389,137 @@ TEST_CASE_FIXTURE(ReductionFixture, "intersections_without_negations") )"); TypeId ty = reductionof(fixture.requireTypeAlias("Ty")); - CHECK("{ @metatable { }, { } } & {| |}" == toString(ty)); + CHECK("{ @metatable { }, { } } & {| |}" == toStringFull(ty)); } SUBCASE("a_and_string") { TypeId ty = reductionof(R"("a" & string)"); - CHECK(R"("a")" == toString(ty)); + CHECK(R"("a")" == toStringFull(ty)); } SUBCASE("reducible_function_and_function") { TypeId ty = reductionof("((string | string) -> (number | number)) & fun"); - CHECK("(string) -> number" == toString(ty)); + CHECK("(string) -> number" == toStringFull(ty)); } SUBCASE("string_and_error") { TypeId ty = reductionof("string & err"); - CHECK("*error-type* & string" == toString(ty)); + CHECK("*error-type* & string" == toStringFull(ty)); } SUBCASE("table_p_string_and_table_p_number") { TypeId ty = reductionof("{ p: string } & { p: number }"); - CHECK("never" == toString(ty)); + CHECK("never" == toStringFull(ty)); } SUBCASE("table_p_string_and_table_p_string") { TypeId ty = reductionof("{ p: string } & { p: string }"); - CHECK("{| p: string |}" == toString(ty)); + CHECK("{| p: string |}" == toStringFull(ty)); } SUBCASE("table_x_table_p_string_and_table_x_table_p_number") { TypeId ty = reductionof("{ x: { p: string } } & { x: { p: number } }"); - CHECK("never" == toString(ty)); + CHECK("never" == toStringFull(ty)); } SUBCASE("table_p_and_table_q") { TypeId ty = reductionof("{ p: string } & { q: number }"); - CHECK("{| p: string, q: number |}" == toString(ty)); + CHECK("{| p: string, q: number |}" == toStringFull(ty)); } SUBCASE("table_tag_a_or_table_tag_b_and_table_b") { TypeId ty = reductionof("({ tag: string, a: number } | { tag: number, b: string }) & { b: string }"); - CHECK("{| a: number, b: string, tag: string |} | {| b: string, tag: number |}" == toString(ty)); + CHECK("{| a: number, b: string, tag: string |} | {| b: string, tag: number |}" == toStringFull(ty)); } SUBCASE("table_string_number_indexer_and_table_string_number_indexer") { TypeId ty = reductionof("{ [string]: number } & { [string]: number }"); - CHECK("{| [string]: number |}" == toString(ty)); + CHECK("{| [string]: number |}" == toStringFull(ty)); } SUBCASE("table_string_number_indexer_and_empty_table") { TypeId ty = reductionof("{ [string]: number } & {}"); - CHECK("{| [string]: number |}" == toString(ty)); + CHECK("{| [string]: number |}" == toStringFull(ty)); } SUBCASE("empty_table_table_string_number_indexer") { TypeId ty = reductionof("{} & { [string]: number }"); - CHECK("{| [string]: number |}" == toString(ty)); + CHECK("{| [string]: number |}" == toStringFull(ty)); } SUBCASE("string_number_indexer_and_number_number_indexer") { TypeId ty = reductionof("{ [string]: number } & { [number]: number }"); - CHECK("never" == toString(ty)); + CHECK("{number} & {| [string]: number |}" == toStringFull(ty)); } SUBCASE("table_p_string_and_indexer_number_number") { TypeId ty = reductionof("{ p: string } & { [number]: number }"); - CHECK("{| [number]: number, p: string |}" == toString(ty)); + CHECK("{| [number]: number, p: string |}" == toStringFull(ty)); } SUBCASE("table_p_string_and_indexer_string_number") { TypeId ty = reductionof("{ p: string } & { [string]: number }"); - CHECK("{| [string]: number, p: string |}" == toString(ty)); + CHECK("{| [string]: number, p: string |}" == toStringFull(ty)); } SUBCASE("table_p_string_and_table_p_string_plus_indexer_string_number") { TypeId ty = reductionof("{ p: string } & { p: string, [string]: number }"); - CHECK("{| [string]: number, p: string |}" == toString(ty)); + CHECK("{| [string]: number, p: string |}" == toStringFull(ty)); } SUBCASE("fresh_type_and_string") { TypeId freshTy = arena.freshType(nullptr); TypeId ty = reductionof(arena.addType(IntersectionType{{freshTy, builtinTypes->stringType}})); - CHECK("a & string" == toString(ty)); + CHECK("a & string" == toStringFull(ty)); } SUBCASE("string_and_fresh_type") { TypeId freshTy = arena.freshType(nullptr); TypeId ty = reductionof(arena.addType(IntersectionType{{builtinTypes->stringType, freshTy}})); - CHECK("a & string" == toString(ty)); + CHECK("a & string" == toStringFull(ty)); } SUBCASE("generic_and_string") { TypeId genericTy = arena.addType(GenericType{"G"}); TypeId ty = reductionof(arena.addType(IntersectionType{{genericTy, builtinTypes->stringType}})); - CHECK("G & string" == toString(ty)); + CHECK("G & string" == toStringFull(ty)); } SUBCASE("string_and_generic") { TypeId genericTy = arena.addType(GenericType{"G"}); TypeId ty = reductionof(arena.addType(IntersectionType{{builtinTypes->stringType, genericTy}})); - CHECK("G & string" == toString(ty)); + CHECK("G & string" == toStringFull(ty)); + } + + SUBCASE("parent_and_child_or_parent_and_anotherchild_or_parent_and_unrelated") + { + TypeId ty = reductionof("Parent & (Child | AnotherChild | Unrelated)"); + CHECK("AnotherChild | Child" == toString(ty)); + } + + SUBCASE("parent_and_child_or_parent_and_anotherchild_or_parent_and_unrelated_2") + { + TypeId ty = reductionof("(Parent & Child) | (Parent & AnotherChild) | (Parent & Unrelated)"); + CHECK("AnotherChild | Child" == toString(ty)); } } // intersections_without_negations @@ -484,163 +528,163 @@ TEST_CASE_FIXTURE(ReductionFixture, "intersections_with_negations") SUBCASE("nil_and_not_nil") { TypeId ty = reductionof("nil & Not"); - CHECK("never" == toString(ty)); + CHECK("never" == toStringFull(ty)); } SUBCASE("nil_and_not_false") { TypeId ty = reductionof("nil & Not"); - CHECK("nil" == toString(ty)); + CHECK("nil" == toStringFull(ty)); } SUBCASE("string_or_nil_and_not_nil") { TypeId ty = reductionof("(string?) & Not"); - CHECK("string" == toString(ty)); + CHECK("string" == toStringFull(ty)); } SUBCASE("string_or_nil_and_not_false_or_nil") { TypeId ty = reductionof("(string?) & Not"); - CHECK("string" == toString(ty)); + CHECK("string" == toStringFull(ty)); } SUBCASE("string_or_nil_and_not_false_and_not_nil") { TypeId ty = reductionof("(string?) & Not & Not"); - CHECK("string" == toString(ty)); + CHECK("string" == toStringFull(ty)); } SUBCASE("not_false_and_bool") { TypeId ty = reductionof("Not & boolean"); - CHECK("true" == toString(ty)); + CHECK("true" == toStringFull(ty)); } SUBCASE("function_type_and_not_function") { TypeId ty = reductionof("() -> () & Not"); - CHECK("never" == toString(ty)); + CHECK("never" == toStringFull(ty)); } SUBCASE("function_type_and_not_string") { TypeId ty = reductionof("() -> () & Not"); - CHECK("() -> ()" == toString(ty)); + CHECK("() -> ()" == toStringFull(ty)); } SUBCASE("not_a_and_string_or_nil") { TypeId ty = reductionof(R"(Not<"a"> & (string | nil))"); - CHECK(R"((string & ~"a")?)" == toString(ty)); + CHECK(R"((string & ~"a")?)" == toStringFull(ty)); } SUBCASE("not_a_and_a") { TypeId ty = reductionof(R"(Not<"a"> & "a")"); - CHECK("never" == toString(ty)); + CHECK("never" == toStringFull(ty)); } SUBCASE("not_a_and_b") { TypeId ty = reductionof(R"(Not<"a"> & "b")"); - CHECK(R"("b")" == toString(ty)); + CHECK(R"("b")" == toStringFull(ty)); } SUBCASE("not_string_and_a") { TypeId ty = reductionof(R"(Not & "a")"); - CHECK("never" == toString(ty)); + CHECK("never" == toStringFull(ty)); } SUBCASE("not_bool_and_true") { TypeId ty = reductionof("Not & true"); - CHECK("never" == toString(ty)); + CHECK("never" == toStringFull(ty)); } SUBCASE("not_string_and_true") { TypeId ty = reductionof("Not & true"); - CHECK("true" == toString(ty)); + CHECK("true" == toStringFull(ty)); } SUBCASE("parent_and_not_child") { TypeId ty = reductionof("Parent & Not"); - CHECK("Parent & ~Child" == toString(ty)); + CHECK("Parent & ~Child" == toStringFull(ty)); } SUBCASE("not_child_and_parent") { TypeId ty = reductionof("Not & Parent"); - CHECK("Parent & ~Child" == toString(ty)); + CHECK("Parent & ~Child" == toStringFull(ty)); } SUBCASE("child_and_not_parent") { TypeId ty = reductionof("Child & Not"); - CHECK("never" == toString(ty)); + CHECK("never" == toStringFull(ty)); } SUBCASE("not_parent_and_child") { TypeId ty = reductionof("Not & Child"); - CHECK("never" == toString(ty)); + CHECK("never" == toStringFull(ty)); } SUBCASE("not_parent_and_unrelated") { TypeId ty = reductionof("Not & Unrelated"); - CHECK("Unrelated" == toString(ty)); + CHECK("Unrelated" == toStringFull(ty)); } SUBCASE("unrelated_and_not_parent") { TypeId ty = reductionof("Unrelated & Not"); - CHECK("Unrelated" == toString(ty)); + CHECK("Unrelated" == toStringFull(ty)); } SUBCASE("not_unrelated_and_parent") { TypeId ty = reductionof("Not & Parent"); - CHECK("Parent" == toString(ty)); + CHECK("Parent" == toStringFull(ty)); } SUBCASE("parent_and_not_unrelated") { TypeId ty = reductionof("Parent & Not"); - CHECK("Parent" == toString(ty)); + CHECK("Parent" == toStringFull(ty)); } SUBCASE("reducible_function_and_not_function") { TypeId ty = reductionof("((string | string) -> (number | number)) & Not"); - CHECK("never" == toString(ty)); + CHECK("never" == toStringFull(ty)); } SUBCASE("string_and_not_error") { TypeId ty = reductionof("string & Not"); - CHECK("string & ~*error-type*" == toString(ty)); + CHECK("string & ~*error-type*" == toStringFull(ty)); } SUBCASE("table_p_string_and_table_p_not_number") { TypeId ty = reductionof("{ p: string } & { p: Not }"); - CHECK("{| p: string |}" == toString(ty)); + CHECK("{| p: string |}" == toStringFull(ty)); } SUBCASE("table_p_string_and_table_p_not_string") { TypeId ty = reductionof("{ p: string } & { p: Not }"); - CHECK("never" == toString(ty)); + CHECK("never" == toStringFull(ty)); } SUBCASE("table_x_table_p_string_and_table_x_table_p_not_number") { TypeId ty = reductionof("{ x: { p: string } } & { x: { p: Not } }"); - CHECK("{| x: {| p: string |} |}" == toString(ty)); + CHECK("{| x: {| p: string |} |}" == toStringFull(ty)); } } // intersections_with_negations @@ -649,223 +693,223 @@ TEST_CASE_FIXTURE(ReductionFixture, "unions_without_negations") SUBCASE("never_or_string") { TypeId ty = reductionof("never | string"); - CHECK("string" == toString(ty)); + CHECK("string" == toStringFull(ty)); } SUBCASE("string_or_never") { TypeId ty = reductionof("string | never"); - CHECK("string" == toString(ty)); + CHECK("string" == toStringFull(ty)); } SUBCASE("unknown_or_string") { TypeId ty = reductionof("unknown | string"); - CHECK("unknown" == toString(ty)); + CHECK("unknown" == toStringFull(ty)); } SUBCASE("string_or_unknown") { TypeId ty = reductionof("string | unknown"); - CHECK("unknown" == toString(ty)); + CHECK("unknown" == toStringFull(ty)); } SUBCASE("any_or_string") { TypeId ty = reductionof("any | string"); - CHECK("any" == toString(ty)); + CHECK("any" == toStringFull(ty)); } SUBCASE("string_or_any") { TypeId ty = reductionof("string | any"); - CHECK("any" == toString(ty)); + CHECK("any" == toStringFull(ty)); } SUBCASE("string_or_string_and_number") { TypeId ty = reductionof("string | (string & number)"); - CHECK("string" == toString(ty)); + CHECK("string" == toStringFull(ty)); } SUBCASE("string_or_string") { TypeId ty = reductionof("string | string"); - CHECK("string" == toString(ty)); + CHECK("string" == toStringFull(ty)); } SUBCASE("string_or_number") { TypeId ty = reductionof("string | number"); - CHECK("number | string" == toString(ty)); + CHECK("number | string" == toStringFull(ty)); } SUBCASE("number_or_string") { TypeId ty = reductionof("number | string"); - CHECK("number | string" == toString(ty)); + CHECK("number | string" == toStringFull(ty)); } SUBCASE("string_or_number_or_string") { TypeId ty = reductionof("(string | number) | string"); - CHECK("number | string" == toString(ty)); + CHECK("number | string" == toStringFull(ty)); } SUBCASE("string_or_number_or_string_2") { TypeId ty = reductionof("string | (number | string)"); - CHECK("number | string" == toString(ty)); + CHECK("number | string" == toStringFull(ty)); } SUBCASE("string_or_string_or_number") { TypeId ty = reductionof("string | (string | number)"); - CHECK("number | string" == toString(ty)); + CHECK("number | string" == toStringFull(ty)); } SUBCASE("string_or_string_or_number_or_boolean") { TypeId ty = reductionof("string | (string | number | boolean)"); - CHECK("boolean | number | string" == toString(ty)); + CHECK("boolean | number | string" == toStringFull(ty)); } SUBCASE("string_or_string_or_boolean_or_number") { TypeId ty = reductionof("string | (string | boolean | number)"); - CHECK("boolean | number | string" == toString(ty)); + CHECK("boolean | number | string" == toStringFull(ty)); } SUBCASE("string_or_boolean_or_string_or_number") { TypeId ty = reductionof("string | (boolean | string | number)"); - CHECK("boolean | number | string" == toString(ty)); + CHECK("boolean | number | string" == toStringFull(ty)); } SUBCASE("boolean_or_string_or_number_or_string") { TypeId ty = reductionof("(boolean | string | number) | string"); - CHECK("boolean | number | string" == toString(ty)); + CHECK("boolean | number | string" == toStringFull(ty)); } SUBCASE("boolean_or_true") { TypeId ty = reductionof("boolean | true"); - CHECK("boolean" == toString(ty)); + CHECK("boolean" == toStringFull(ty)); } SUBCASE("boolean_or_false") { TypeId ty = reductionof("boolean | false"); - CHECK("boolean" == toString(ty)); + CHECK("boolean" == toStringFull(ty)); } SUBCASE("boolean_or_true_or_false") { TypeId ty = reductionof("boolean | true | false"); - CHECK("boolean" == toString(ty)); + CHECK("boolean" == toStringFull(ty)); } SUBCASE("string_or_a") { TypeId ty = reductionof(R"(string | "a")"); - CHECK("string" == toString(ty)); + CHECK("string" == toStringFull(ty)); } SUBCASE("a_or_a") { TypeId ty = reductionof(R"("a" | "a")"); - CHECK(R"("a")" == toString(ty)); + CHECK(R"("a")" == toStringFull(ty)); } SUBCASE("a_or_b") { TypeId ty = reductionof(R"("a" | "b")"); - CHECK(R"("a" | "b")" == toString(ty)); + CHECK(R"("a" | "b")" == toStringFull(ty)); } SUBCASE("a_or_b_or_string") { TypeId ty = reductionof(R"("a" | "b" | string)"); - CHECK("string" == toString(ty)); + CHECK("string" == toStringFull(ty)); } SUBCASE("unknown_or_any") { TypeId ty = reductionof("unknown | any"); - CHECK("unknown" == toString(ty)); + CHECK("unknown" == toStringFull(ty)); } SUBCASE("any_or_unknown") { TypeId ty = reductionof("any | unknown"); - CHECK("unknown" == toString(ty)); + CHECK("unknown" == toStringFull(ty)); } SUBCASE("function_type_or_function") { TypeId ty = reductionof("() -> () | fun"); - CHECK("function" == toString(ty)); + CHECK("function" == toStringFull(ty)); } SUBCASE("function_or_string") { TypeId ty = reductionof("fun | string"); - CHECK("function | string" == toString(ty)); + CHECK("function | string" == toStringFull(ty)); } SUBCASE("parent_or_child") { TypeId ty = reductionof("Parent | Child"); - CHECK("Parent" == toString(ty)); + CHECK("Parent" == toStringFull(ty)); } SUBCASE("child_or_parent") { TypeId ty = reductionof("Child | Parent"); - CHECK("Parent" == toString(ty)); + CHECK("Parent" == toStringFull(ty)); } SUBCASE("parent_or_unrelated") { TypeId ty = reductionof("Parent | Unrelated"); - CHECK("Parent | Unrelated" == toString(ty)); + CHECK("Parent | Unrelated" == toStringFull(ty)); } SUBCASE("parent_or_child_or_unrelated") { TypeId ty = reductionof("Parent | Child | Unrelated"); - CHECK("Parent | Unrelated" == toString(ty)); + CHECK("Parent | Unrelated" == toStringFull(ty)); } SUBCASE("parent_or_unrelated_or_child") { TypeId ty = reductionof("Parent | Unrelated | Child"); - CHECK("Parent | Unrelated" == toString(ty)); + CHECK("Parent | Unrelated" == toStringFull(ty)); } SUBCASE("parent_or_child_or_unrelated_or_child") { TypeId ty = reductionof("Parent | Child | Unrelated | Child"); - CHECK("Parent | Unrelated" == toString(ty)); + CHECK("Parent | Unrelated" == toStringFull(ty)); } SUBCASE("string_or_true") { TypeId ty = reductionof("string | true"); - CHECK("string | true" == toString(ty)); + CHECK("string | true" == toStringFull(ty)); } SUBCASE("string_or_function") { TypeId ty = reductionof("string | () -> ()"); - CHECK("(() -> ()) | string" == toString(ty)); + CHECK("(() -> ()) | string" == toStringFull(ty)); } SUBCASE("string_or_err") { TypeId ty = reductionof("string | err"); - CHECK("*error-type* | string" == toString(ty)); + CHECK("*error-type* | string" == toStringFull(ty)); } } // unions_without_negations @@ -874,211 +918,211 @@ TEST_CASE_FIXTURE(ReductionFixture, "unions_with_negations") SUBCASE("string_or_not_string") { TypeId ty = reductionof("string | Not"); - CHECK("unknown" == toString(ty)); + CHECK("unknown" == toStringFull(ty)); } SUBCASE("not_string_or_string") { TypeId ty = reductionof("Not | string"); - CHECK("unknown" == toString(ty)); + CHECK("unknown" == toStringFull(ty)); } SUBCASE("not_number_or_string") { TypeId ty = reductionof("Not | string"); - CHECK("~number" == toString(ty)); + CHECK("~number" == toStringFull(ty)); } SUBCASE("string_or_not_number") { TypeId ty = reductionof("string | Not"); - CHECK("~number" == toString(ty)); + CHECK("~number" == toStringFull(ty)); } SUBCASE("not_hi_or_string_and_not_hi") { TypeId ty = reductionof(R"(Not<"hi"> | (string & Not<"hi">))"); - CHECK(R"(~"hi")" == toString(ty)); + CHECK(R"(~"hi")" == toStringFull(ty)); } SUBCASE("string_and_not_hi_or_not_hi") { TypeId ty = reductionof(R"((string & Not<"hi">) | Not<"hi">)"); - CHECK(R"(~"hi")" == toString(ty)); + CHECK(R"(~"hi")" == toStringFull(ty)); } SUBCASE("string_or_not_never") { TypeId ty = reductionof("string | Not"); - CHECK("unknown" == toString(ty)); + CHECK("unknown" == toStringFull(ty)); } SUBCASE("not_a_or_not_a") { TypeId ty = reductionof(R"(Not<"a"> | Not<"a">)"); - CHECK(R"(~"a")" == toString(ty)); + CHECK(R"(~"a")" == toStringFull(ty)); } SUBCASE("not_a_or_a") { TypeId ty = reductionof(R"(Not<"a"> | "a")"); - CHECK("unknown" == toString(ty)); + CHECK("unknown" == toStringFull(ty)); } SUBCASE("a_or_not_a") { TypeId ty = reductionof(R"("a" | Not<"a">)"); - CHECK("unknown" == toString(ty)); + CHECK("unknown" == toStringFull(ty)); } SUBCASE("not_a_or_string") { TypeId ty = reductionof(R"(Not<"a"> | string)"); - CHECK("unknown" == toString(ty)); + CHECK("unknown" == toStringFull(ty)); } SUBCASE("string_or_not_a") { TypeId ty = reductionof(R"(string | Not<"a">)"); - CHECK("unknown" == toString(ty)); + CHECK("unknown" == toStringFull(ty)); } SUBCASE("not_string_or_a") { TypeId ty = reductionof(R"(Not | "a")"); - CHECK(R"("a" | ~string)" == toString(ty)); + CHECK(R"("a" | ~string)" == toStringFull(ty)); } SUBCASE("a_or_not_string") { TypeId ty = reductionof(R"("a" | Not)"); - CHECK(R"("a" | ~string)" == toString(ty)); + CHECK(R"("a" | ~string)" == toStringFull(ty)); } SUBCASE("not_number_or_a") { TypeId ty = reductionof(R"(Not | "a")"); - CHECK("~number" == toString(ty)); + CHECK("~number" == toStringFull(ty)); } SUBCASE("a_or_not_number") { TypeId ty = reductionof(R"("a" | Not)"); - CHECK("~number" == toString(ty)); + CHECK("~number" == toStringFull(ty)); } SUBCASE("not_a_or_not_b") { TypeId ty = reductionof(R"(Not<"a"> | Not<"b">)"); - CHECK("unknown" == toString(ty)); + CHECK("unknown" == toStringFull(ty)); } SUBCASE("boolean_or_not_false") { TypeId ty = reductionof("boolean | Not"); - CHECK("unknown" == toString(ty)); + CHECK("unknown" == toStringFull(ty)); } SUBCASE("boolean_or_not_true") { TypeId ty = reductionof("boolean | Not"); - CHECK("unknown" == toString(ty)); + CHECK("unknown" == toStringFull(ty)); } SUBCASE("false_or_not_false") { TypeId ty = reductionof("false | Not"); - CHECK("unknown" == toString(ty)); + CHECK("unknown" == toStringFull(ty)); } SUBCASE("true_or_not_false") { TypeId ty = reductionof("true | Not"); - CHECK("~false" == toString(ty)); + CHECK("~false" == toStringFull(ty)); } SUBCASE("not_boolean_or_true") { TypeId ty = reductionof("Not | true"); - CHECK("~false" == toString(ty)); + CHECK("~false" == toStringFull(ty)); } SUBCASE("not_false_or_not_boolean") { TypeId ty = reductionof("Not | Not"); - CHECK("~false" == toString(ty)); + CHECK("~false" == toStringFull(ty)); } SUBCASE("function_type_or_not_function") { TypeId ty = reductionof("() -> () | Not"); - CHECK("(() -> ()) | ~function" == toString(ty)); + CHECK("(() -> ()) | ~function" == toStringFull(ty)); } SUBCASE("not_parent_or_child") { TypeId ty = reductionof("Not | Child"); - CHECK("Child | ~Parent" == toString(ty)); + CHECK("Child | ~Parent" == toStringFull(ty)); } SUBCASE("child_or_not_parent") { TypeId ty = reductionof("Child | Not"); - CHECK("Child | ~Parent" == toString(ty)); + CHECK("Child | ~Parent" == toStringFull(ty)); } SUBCASE("parent_or_not_child") { TypeId ty = reductionof("Parent | Not"); - CHECK("unknown" == toString(ty)); + CHECK("unknown" == toStringFull(ty)); } SUBCASE("not_child_or_parent") { TypeId ty = reductionof("Not | Parent"); - CHECK("unknown" == toString(ty)); + CHECK("unknown" == toStringFull(ty)); } SUBCASE("parent_or_not_unrelated") { TypeId ty = reductionof("Parent | Not"); - CHECK("~Unrelated" == toString(ty)); + CHECK("~Unrelated" == toStringFull(ty)); } SUBCASE("not_string_or_string_and_not_a") { TypeId ty = reductionof(R"(Not | (string & Not<"a">))"); - CHECK(R"(~"a")" == toString(ty)); + CHECK(R"(~"a")" == toStringFull(ty)); } SUBCASE("not_string_or_not_string") { TypeId ty = reductionof("Not | Not"); - CHECK("~string" == toString(ty)); + CHECK("~string" == toStringFull(ty)); } SUBCASE("not_string_or_not_number") { TypeId ty = reductionof("Not | Not"); - CHECK("unknown" == toString(ty)); + CHECK("unknown" == toStringFull(ty)); } SUBCASE("not_a_or_not_boolean") { TypeId ty = reductionof(R"(Not<"a"> | Not)"); - CHECK("unknown" == toString(ty)); + CHECK("unknown" == toStringFull(ty)); } SUBCASE("not_a_or_boolean") { TypeId ty = reductionof(R"(Not<"a"> | boolean)"); - CHECK(R"(~"a")" == toString(ty)); + CHECK(R"(~"a")" == toStringFull(ty)); } SUBCASE("string_or_err") { TypeId ty = reductionof("string | Not"); - CHECK("string | ~*error-type*" == toString(ty)); + CHECK("string | ~*error-type*" == toStringFull(ty)); } } // unions_with_negations @@ -1086,20 +1130,14 @@ TEST_CASE_FIXTURE(ReductionFixture, "tables") { SUBCASE("reduce_props") { - ToStringOptions opts; - opts.exhaustive = true; - TypeId ty = reductionof("{ x: string | string, y: number | number }"); - CHECK("{| x: string, y: number |}" == toString(ty, opts)); + CHECK("{| x: string, y: number |}" == toStringFull(ty)); } SUBCASE("reduce_indexers") { - ToStringOptions opts; - opts.exhaustive = true; - TypeId ty = reductionof("{ [string | string]: number | number }"); - CHECK("{| [string]: number |}" == toString(ty, opts)); + CHECK("{| [string]: number |}" == toStringFull(ty)); } SUBCASE("reduce_instantiated_type_parameters") @@ -1126,11 +1164,8 @@ TEST_CASE_FIXTURE(ReductionFixture, "tables") SUBCASE("reduce_tables_within_tables") { - ToStringOptions opts; - opts.exhaustive = true; - TypeId ty = reductionof("{ x: { y: string & number } }"); - CHECK("{| x: {| y: never |} |}" == toString(ty, opts)); + CHECK("never" == toStringFull(ty)); } } @@ -1139,21 +1174,23 @@ TEST_CASE_FIXTURE(ReductionFixture, "metatables") SUBCASE("reduce_table_part") { TableType table; + table.state = TableState::Sealed; table.props["x"] = {arena.addType(UnionType{{builtinTypes->stringType, builtinTypes->stringType}})}; TypeId tableTy = arena.addType(std::move(table)); TypeId ty = reductionof(arena.addType(MetatableType{tableTy, arena.addType(TableType{})})); - CHECK("{ @metatable { }, { x: string } }" == toString(ty)); + CHECK("{ @metatable { }, {| x: string |} }" == toStringFull(ty)); } SUBCASE("reduce_metatable_part") { TableType table; + table.state = TableState::Sealed; table.props["x"] = {arena.addType(UnionType{{builtinTypes->stringType, builtinTypes->stringType}})}; TypeId tableTy = arena.addType(std::move(table)); TypeId ty = reductionof(arena.addType(MetatableType{arena.addType(TableType{}), tableTy})); - CHECK("{ @metatable { x: string }, { } }" == toString(ty)); + CHECK("{ @metatable {| x: string |}, { } }" == toStringFull(ty)); } } @@ -1162,37 +1199,37 @@ TEST_CASE_FIXTURE(ReductionFixture, "functions") SUBCASE("reduce_parameters") { TypeId ty = reductionof("(string | string) -> ()"); - CHECK("(string) -> ()" == toString(ty)); + CHECK("(string) -> ()" == toStringFull(ty)); } SUBCASE("reduce_returns") { TypeId ty = reductionof("() -> (string | string)"); - CHECK("() -> string" == toString(ty)); + CHECK("() -> string" == toStringFull(ty)); } SUBCASE("reduce_parameters_and_returns") { TypeId ty = reductionof("(string | string) -> (number | number)"); - CHECK("(string) -> number" == toString(ty)); + CHECK("(string) -> number" == toStringFull(ty)); } SUBCASE("reduce_tail") { TypeId ty = reductionof("() -> ...(string | string)"); - CHECK("() -> (...string)" == toString(ty)); + CHECK("() -> (...string)" == toStringFull(ty)); } SUBCASE("reduce_head_and_tail") { TypeId ty = reductionof("() -> (string | string, number | number, ...(boolean | boolean))"); - CHECK("() -> (string, number, ...boolean)" == toString(ty)); + CHECK("() -> (string, number, ...boolean)" == toStringFull(ty)); } SUBCASE("reduce_overloaded_functions") { TypeId ty = reductionof("((number | number) -> ()) & ((string | string) -> ())"); - CHECK("((number) -> ()) & ((string) -> ())" == toString(ty)); + CHECK("((number) -> ()) & ((string) -> ())" == toStringFull(ty)); } } // functions @@ -1201,49 +1238,49 @@ TEST_CASE_FIXTURE(ReductionFixture, "negations") SUBCASE("not_unknown") { TypeId ty = reductionof("Not"); - CHECK("never" == toString(ty)); + CHECK("never" == toStringFull(ty)); } SUBCASE("not_never") { TypeId ty = reductionof("Not"); - CHECK("unknown" == toString(ty)); + CHECK("unknown" == toStringFull(ty)); } SUBCASE("not_any") { TypeId ty = reductionof("Not"); - CHECK("any" == toString(ty)); + CHECK("any" == toStringFull(ty)); } SUBCASE("not_not_reduction") { TypeId ty = reductionof("Not>"); - CHECK("never" == toString(ty)); + CHECK("never" == toStringFull(ty)); } SUBCASE("not_string") { TypeId ty = reductionof("Not"); - CHECK("~string" == toString(ty)); + CHECK("~string" == toStringFull(ty)); } SUBCASE("not_string_or_number") { TypeId ty = reductionof("Not"); - CHECK("~number & ~string" == toString(ty)); + CHECK("~number & ~string" == toStringFull(ty)); } SUBCASE("not_string_and_number") { TypeId ty = reductionof("Not"); - CHECK("unknown" == toString(ty)); + CHECK("unknown" == toStringFull(ty)); } SUBCASE("not_error") { TypeId ty = reductionof("Not"); - CHECK("~*error-type*" == toString(ty)); + CHECK("~*error-type*" == toStringFull(ty)); } } // negations @@ -1252,37 +1289,37 @@ TEST_CASE_FIXTURE(ReductionFixture, "discriminable_unions") SUBCASE("cat_or_dog_and_dog") { TypeId ty = reductionof(R"(({ tag: "cat", catfood: string } | { tag: "dog", dogfood: string }) & { tag: "dog" })"); - CHECK(R"({| dogfood: string, tag: "dog" |})" == toString(ty)); + CHECK(R"({| dogfood: string, tag: "dog" |})" == toStringFull(ty)); } SUBCASE("cat_or_dog_and_not_dog") { TypeId ty = reductionof(R"(({ tag: "cat", catfood: string } | { tag: "dog", dogfood: string }) & { tag: Not<"dog"> })"); - CHECK(R"({| catfood: string, tag: "cat" |})" == toString(ty)); + CHECK(R"({| catfood: string, tag: "cat" |})" == toStringFull(ty)); } SUBCASE("string_or_number_and_number") { TypeId ty = reductionof("({ tag: string, a: number } | { tag: number, b: string }) & { tag: string }"); - CHECK("{| a: number, tag: string |}" == toString(ty)); + CHECK("{| a: number, tag: string |}" == toStringFull(ty)); } SUBCASE("string_or_number_and_number") { TypeId ty = reductionof("({ tag: string, a: number } | { tag: number, b: string }) & { tag: number }"); - CHECK("{| b: string, tag: number |}" == toString(ty)); + CHECK("{| b: string, tag: number |}" == toStringFull(ty)); } SUBCASE("child_or_unrelated_and_parent") { TypeId ty = reductionof("({ tag: Child, x: number } | { tag: Unrelated, y: string }) & { tag: Parent }"); - CHECK("{| tag: Child, x: number |}" == toString(ty)); + CHECK("{| tag: Child, x: number |}" == toStringFull(ty)); } SUBCASE("child_or_unrelated_and_not_parent") { TypeId ty = reductionof("({ tag: Child, x: number } | { tag: Unrelated, y: string }) & { tag: Not }"); - CHECK("{| tag: Unrelated, y: string |}" == toString(ty)); + CHECK("{| tag: Unrelated, y: string |}" == toStringFull(ty)); } } @@ -1293,7 +1330,7 @@ TEST_CASE_FIXTURE(ReductionFixture, "cycles") check("type F = (f: F) -> ()"); TypeId ty = reductionof(requireTypeAlias("F")); - CHECK("(t1) -> () where t1 = (t1) -> ()" == toString(ty)); + CHECK("t1 where t1 = (t1) -> ()" == toStringFull(ty)); } SUBCASE("recursively_defined_function_and_function") @@ -1301,52 +1338,39 @@ TEST_CASE_FIXTURE(ReductionFixture, "cycles") check("type F = (f: F & fun) -> ()"); TypeId ty = reductionof(requireTypeAlias("F")); - CHECK("(t1) -> () where t1 = (function & t1) -> ()" == toString(ty)); + CHECK("t1 where t1 = (function & t1) -> ()" == toStringFull(ty)); } SUBCASE("recursively_defined_table") { - ToStringOptions opts; - opts.exhaustive = true; - check("type T = { x: T }"); TypeId ty = reductionof(requireTypeAlias("T")); - CHECK("{| x: t1 |} where t1 = {| x: t1 |}" == toString(ty, opts)); + CHECK("t1 where t1 = {| x: t1 |}" == toStringFull(ty)); } SUBCASE("recursively_defined_table_and_table") { - ToStringOptions opts; - opts.exhaustive = true; - check("type T = { x: T & {} }"); TypeId ty = reductionof(requireTypeAlias("T")); - CHECK("{| x: t1 & {| |} |} where t1 = {| x: t1 & {| |} |}" == toString(ty, opts)); + CHECK("t1 where t1 = {| x: t1 & {| |} |}" == toStringFull(ty)); } SUBCASE("recursively_defined_table_and_table_2") { - ToStringOptions opts; - opts.exhaustive = true; - check("type T = { x: T } & { x: number }"); TypeId ty = reductionof(requireTypeAlias("T")); - CHECK("never" == toString(ty)); + CHECK("t1 where t1 = {| x: number |} & {| x: t1 |}" == toStringFull(ty)); } SUBCASE("recursively_defined_table_and_table_3") { - ToStringOptions opts; - opts.exhaustive = true; - check("type T = { x: T } & { x: T }"); TypeId ty = reductionof(requireTypeAlias("T")); - CHECK("{| x: {| x: t1 |} & {| x: t1 |} & {| x: t2 & t2 & {| x: t1 |} & {| x: t1 |} |} |} where t1 = t2 & {| x: t1 |} ; t2 = {| x: t1 |}" == - toString(ty)); + CHECK("t1 where t1 = {| x: t1 |} & {| x: t1 |}" == toStringFull(ty)); } } diff --git a/tests/conformance/basic.lua b/tests/conformance/basic.lua index 7a05f8e9c..e23c1a53f 100644 --- a/tests/conformance/basic.lua +++ b/tests/conformance/basic.lua @@ -923,6 +923,21 @@ assert((function() return table.concat(res, ',') end)() == "6,8,10") +-- checking for a CFG issue that was missed in IR +assert((function(b) + local res = 0 + + if b then + for i = 1, 100 do + res += i + end + else + res += 100000 + end + + return res +end)(true) == 5050) + -- typeof and type require an argument assert(pcall(typeof) == false) assert(pcall(type) == false) diff --git a/tests/conformance/ndebug_upvalues.lua b/tests/conformance/ndebug_upvalues.lua new file mode 100644 index 000000000..bdb67f2ed --- /dev/null +++ b/tests/conformance/ndebug_upvalues.lua @@ -0,0 +1,13 @@ +-- This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +-- This tests that the lua_*upval() APIs work correctly even with debug info disabled +local foo = 5 +function clo_test() + -- so `foo` gets captured as an upval + print(foo) + -- yield so we can look at clo_test's upvalues + coroutine.yield() +end + +clo_test() + +return 'OK' diff --git a/tools/faillist.txt b/tools/faillist.txt index 3fcd4200a..2ea178dad 100644 --- a/tools/faillist.txt +++ b/tools/faillist.txt @@ -42,7 +42,6 @@ AutocompleteTest.type_correct_suggestion_in_argument AutocompleteTest.type_correct_suggestion_in_table BuiltinTests.aliased_string_format BuiltinTests.assert_removes_falsy_types -BuiltinTests.assert_removes_falsy_types2 BuiltinTests.assert_removes_falsy_types_even_from_type_pack_tail_but_only_for_the_first_type BuiltinTests.assert_returns_false_and_string_iff_it_knows_the_first_argument_cannot_be_truthy BuiltinTests.bad_select_should_not_crash @@ -53,9 +52,6 @@ BuiltinTests.dont_add_definitions_to_persistent_types BuiltinTests.find_capture_types BuiltinTests.find_capture_types2 BuiltinTests.find_capture_types3 -BuiltinTests.gmatch_capture_types_balanced_escaped_parens -BuiltinTests.gmatch_capture_types_default_capture -BuiltinTests.gmatch_capture_types_parens_in_sets_are_ignored BuiltinTests.gmatch_definition BuiltinTests.ipairs_iterator_should_infer_types_and_type_check BuiltinTests.match_capture_types @@ -80,7 +76,6 @@ BuiltinTests.table_insert_correctly_infers_type_of_array_3_args_overload BuiltinTests.table_pack BuiltinTests.table_pack_reduce BuiltinTests.table_pack_variadic -BuiltinTests.tonumber_returns_optional_number_type DefinitionTests.class_definition_overload_metamethods DefinitionTests.class_definition_string_props DefinitionTests.declaring_generic_functions @@ -103,7 +98,6 @@ GenericsTests.duplicate_generic_type_packs GenericsTests.duplicate_generic_types GenericsTests.generic_argument_count_too_few GenericsTests.generic_argument_count_too_many -GenericsTests.generic_factories GenericsTests.generic_functions_should_be_memory_safe GenericsTests.generic_table_method GenericsTests.generic_type_pack_parentheses @@ -140,7 +134,6 @@ NonstrictModeTests.parameters_having_type_any_are_optional NonstrictModeTests.table_dot_insert_and_recursive_calls NonstrictModeTests.table_props_are_any Normalize.cyclic_table_normalizes_sensibly -Normalize.negations_of_classes ParseErrorRecovery.generic_type_list_recovery ParseErrorRecovery.recovery_of_parenthesized_expressions ParserTests.parse_nesting_based_end_detection_failsafe_earlier @@ -160,16 +153,13 @@ ProvisionalTests.specialization_binds_with_prototypes_too_early ProvisionalTests.table_insert_with_a_singleton_argument ProvisionalTests.typeguard_inference_incomplete ProvisionalTests.weirditer_should_not_loop_forever -ProvisionalTests.while_body_are_also_refined RefinementTest.apply_refinements_on_astexprindexexpr_whose_subscript_expr_is_constant_string RefinementTest.call_an_incompatible_function_after_using_typeguard RefinementTest.correctly_lookup_property_whose_base_was_previously_refined2 RefinementTest.discriminate_on_properties_of_disjoint_tables_where_that_property_is_true_or_false RefinementTest.discriminate_tag -RefinementTest.eliminate_subclasses_of_instance RefinementTest.else_with_no_explicit_expression_should_also_refine_the_tagged_union RefinementTest.falsiness_of_TruthyPredicate_narrows_into_nil -RefinementTest.narrow_from_subclasses_of_instance_or_string_or_vector3 RefinementTest.narrow_property_of_a_bounded_variable RefinementTest.nonoptional_type_can_narrow_to_nil_if_sense_is_true RefinementTest.refine_a_property_not_to_be_nil_through_an_intersection_table @@ -179,7 +169,6 @@ RefinementTest.type_guard_narrowed_into_nothingness RefinementTest.type_narrow_for_all_the_userdata RefinementTest.type_narrow_to_vector RefinementTest.typeguard_cast_free_table_to_vector -RefinementTest.typeguard_cast_instance_or_vector3_to_vector RefinementTest.typeguard_in_assert_position RefinementTest.typeguard_narrows_for_table RefinementTest.x_as_any_if_x_is_instance_elseif_x_is_table @@ -192,8 +181,6 @@ TableTests.accidentally_checked_prop_in_opposite_branch TableTests.builtin_table_names TableTests.call_method TableTests.call_method_with_explicit_self_argument -TableTests.cannot_augment_sealed_table -TableTests.casting_sealed_tables_with_props_into_table_with_indexer TableTests.casting_tables_with_props_into_table_with_indexer3 TableTests.casting_tables_with_props_into_table_with_indexer4 TableTests.checked_prop_too_early @@ -218,12 +205,10 @@ TableTests.function_calls_produces_sealed_table_given_unsealed_table TableTests.generic_table_instantiation_potential_regression TableTests.getmetatable_returns_pointer_to_metatable TableTests.give_up_after_one_metatable_index_look_up -TableTests.hide_table_error_properties TableTests.indexer_on_sealed_table_must_unify_with_free_table TableTests.indexing_from_a_table_should_prefer_properties_when_possible TableTests.inequality_operators_imply_exactly_matching_types TableTests.infer_array_2 -TableTests.infer_indexer_from_value_property_in_literal TableTests.inferred_return_type_of_free_table TableTests.inferring_crazy_table_should_also_be_quick TableTests.instantiate_table_cloning_3 @@ -243,7 +228,6 @@ TableTests.only_ascribe_synthetic_names_at_module_scope TableTests.oop_indexer_works TableTests.oop_polymorphic TableTests.open_table_unification_2 -TableTests.persistent_sealed_table_is_immutable TableTests.property_lookup_through_tabletypevar_metatable TableTests.quantify_even_that_table_was_never_exported_at_all TableTests.quantify_metatables_of_metatables_of_table @@ -252,7 +236,6 @@ TableTests.reasonable_error_when_adding_a_nonexistent_property_to_an_array_like_ TableTests.result_is_always_any_if_lhs_is_any TableTests.result_is_bool_for_equality_operators_if_lhs_is_any TableTests.right_table_missing_key2 -TableTests.scalar_is_not_a_subtype_of_a_compatible_polymorphic_shape_type TableTests.shared_selfs TableTests.shared_selfs_from_free_param TableTests.shared_selfs_through_metatables @@ -261,7 +244,6 @@ TableTests.table_function_check_use_after_free TableTests.table_indexing_error_location TableTests.table_insert_should_cope_with_optional_properties_in_nonstrict TableTests.table_insert_should_cope_with_optional_properties_in_strict -TableTests.table_param_row_polymorphism_2 TableTests.table_param_row_polymorphism_3 TableTests.table_simple_call TableTests.table_subtyping_with_extra_props_dont_report_multiple_errors @@ -276,12 +258,10 @@ TableTests.used_colon_correctly TableTests.used_colon_instead_of_dot TableTests.used_dot_instead_of_colon TableTests.used_dot_instead_of_colon_but_correctly -ToDot.bound_table -ToDot.function -ToDot.table ToString.exhaustive_toString_of_cyclic_table ToString.function_type_with_argument_names_and_self ToString.function_type_with_argument_names_generic +ToString.named_metatable_toStringNamedFunction ToString.toStringDetailed2 ToString.toStringErrorPack ToString.toStringNamedFunction_generic_pack @@ -297,10 +277,12 @@ TryUnifyTests.result_of_failed_typepack_unification_is_constrained TryUnifyTests.typepack_unification_should_trim_free_tails TryUnifyTests.variadics_should_use_reversed_properly TypeAliases.cannot_create_cyclic_type_with_unknown_module +TypeAliases.corecursive_types_generic TypeAliases.forward_declared_alias_is_not_clobbered_by_prior_unification_with_any TypeAliases.forward_declared_alias_is_not_clobbered_by_prior_unification_with_any_2 TypeAliases.generic_param_remap TypeAliases.mismatched_generic_type_param +TypeAliases.mutually_recursive_types_errors TypeAliases.mutually_recursive_types_restriction_not_ok_1 TypeAliases.mutually_recursive_types_restriction_not_ok_2 TypeAliases.mutually_recursive_types_swapsies_not_ok @@ -308,8 +290,6 @@ TypeAliases.recursive_types_restriction_not_ok TypeAliases.report_shadowed_aliases TypeAliases.stringify_optional_parameterized_alias TypeAliases.stringify_type_alias_of_recursive_template_table_type -TypeAliases.stringify_type_alias_of_recursive_template_table_type2 -TypeAliases.type_alias_fwd_declaration_is_precise TypeAliases.type_alias_local_mutation TypeAliases.type_alias_local_rename TypeAliases.type_alias_of_an_imported_recursive_generic_type @@ -337,15 +317,12 @@ TypeInferAnyError.metatable_of_any_can_be_a_table TypeInferClasses.can_read_prop_of_base_class_using_string TypeInferClasses.class_type_mismatch_with_name_conflict TypeInferClasses.classes_without_overloaded_operators_cannot_be_added -TypeInferClasses.detailed_class_unification_error TypeInferClasses.higher_order_function_arguments_are_contravariant TypeInferClasses.index_instance_property TypeInferClasses.optional_class_field_access_error TypeInferClasses.table_class_unification_reports_sane_errors_for_missing_properties TypeInferClasses.warn_when_prop_almost_matches -TypeInferClasses.we_can_report_when_someone_is_trying_to_use_a_table_rather_than_a_class TypeInferFunctions.calling_function_with_anytypepack_doesnt_leak_free_types -TypeInferFunctions.calling_function_with_incorrect_argument_type_yields_errors_spanning_argument TypeInferFunctions.cannot_hoist_interior_defns_into_signature TypeInferFunctions.dont_give_other_overloads_message_if_only_one_argument_matching_overload_exists TypeInferFunctions.dont_infer_parameter_types_for_functions_from_their_call_site @@ -374,9 +351,7 @@ TypeInferFunctions.return_type_by_overload TypeInferFunctions.too_few_arguments_variadic TypeInferFunctions.too_few_arguments_variadic_generic TypeInferFunctions.too_few_arguments_variadic_generic2 -TypeInferFunctions.too_many_arguments TypeInferFunctions.too_many_arguments_error_location -TypeInferFunctions.too_many_return_values TypeInferFunctions.too_many_return_values_in_parentheses TypeInferFunctions.too_many_return_values_no_function TypeInferFunctions.vararg_function_is_quantified @@ -399,8 +374,6 @@ TypeInferModules.module_type_conflict_instantiated TypeInferModules.require_a_variadic_function TypeInferModules.type_error_of_unknown_qualified_type TypeInferOOP.dont_suggest_using_colon_rather_than_dot_if_another_overload_works -TypeInferOOP.dont_suggest_using_colon_rather_than_dot_if_it_wont_help_2 -TypeInferOOP.dont_suggest_using_colon_rather_than_dot_if_not_defined_with_colon TypeInferOOP.inferring_hundreds_of_self_calls_should_not_suffocate_memory TypeInferOOP.method_depends_on_table TypeInferOOP.methods_are_topologically_sorted @@ -461,10 +434,7 @@ TypePackTests.type_pack_type_parameters TypePackTests.unify_variadic_tails_in_arguments TypePackTests.unify_variadic_tails_in_arguments_free TypePackTests.variadic_packs -TypeReductionTests.discriminable_unions -TypeReductionTests.intersections_with_negations TypeReductionTests.negations -TypeReductionTests.unions_with_negations TypeSingletons.error_detailed_tagged_union_mismatch_bool TypeSingletons.error_detailed_tagged_union_mismatch_string TypeSingletons.function_call_with_singletons From b0b7dfb71446738aff288c601eb152b06cc061e1 Mon Sep 17 00:00:00 2001 From: Vyacheslav Egorov Date: Fri, 20 Jan 2023 14:18:59 +0200 Subject: [PATCH 30/66] Fix a few style changes that went out-of-sync --- Analysis/src/ToString.cpp | 36 +++++++++++++----------------------- 1 file changed, 13 insertions(+), 23 deletions(-) diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index 37f806100..89d3c5557 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -310,8 +310,7 @@ struct TypeStringifier } Luau::visit( - [this, tv](auto&& t) - { + [this, tv](auto&& t) { return (*this)(tv, t); }, tv->ty); @@ -907,8 +906,7 @@ struct TypePackStringifier } Luau::visit( - [this, tp](auto&& t) - { + [this, tp](auto&& t) { return (*this)(tp, t); }, tp->ty); @@ -1058,11 +1056,9 @@ static void assignCycleNames(const std::set& cycles, const std::set(follow(cycleTy)); !exhaustive && ttv && (ttv->syntheticName || ttv->name)) { // If we have a cycle type in type parameters, assign a cycle name for this named table - if (std::find_if(ttv->instantiatedTypeParams.begin(), ttv->instantiatedTypeParams.end(), - [&](auto&& el) - { - return cycles.count(follow(el)); - }) != ttv->instantiatedTypeParams.end()) + if (std::find_if(ttv->instantiatedTypeParams.begin(), ttv->instantiatedTypeParams.end(), [&](auto&& el) { + return cycles.count(follow(el)); + }) != ttv->instantiatedTypeParams.end()) cycleNames[cycleTy] = ttv->name ? *ttv->name : *ttv->syntheticName; continue; @@ -1157,11 +1153,9 @@ ToStringResult toStringDetailed(TypeId ty, ToStringOptions& opts) state.exhaustive = true; std::vector> sortedCycleNames{state.cycleNames.begin(), state.cycleNames.end()}; - std::sort(sortedCycleNames.begin(), sortedCycleNames.end(), - [](const auto& a, const auto& b) - { - return a.second < b.second; - }); + std::sort(sortedCycleNames.begin(), sortedCycleNames.end(), [](const auto& a, const auto& b) { + return a.second < b.second; + }); bool semi = false; for (const auto& [cycleTy, name] : sortedCycleNames) @@ -1172,8 +1166,7 @@ ToStringResult toStringDetailed(TypeId ty, ToStringOptions& opts) state.emit(name); state.emit(" = "); Luau::visit( - [&tvs, cycleTy = cycleTy](auto&& t) - { + [&tvs, cycleTy = cycleTy](auto&& t) { return tvs(cycleTy, t); }, cycleTy->ty); @@ -1254,11 +1247,9 @@ ToStringResult toStringDetailed(TypePackId tp, ToStringOptions& opts) state.exhaustive = true; std::vector> sortedCycleNames{state.cycleNames.begin(), state.cycleNames.end()}; - std::sort(sortedCycleNames.begin(), sortedCycleNames.end(), - [](const auto& a, const auto& b) - { - return a.second < b.second; - }); + std::sort(sortedCycleNames.begin(), sortedCycleNames.end(), [](const auto& a, const auto& b) { + return a.second < b.second; + }); bool semi = false; for (const auto& [cycleTy, name] : sortedCycleNames) @@ -1269,8 +1260,7 @@ ToStringResult toStringDetailed(TypePackId tp, ToStringOptions& opts) state.emit(name); state.emit(" = "); Luau::visit( - [&tvs, cycleTy = cycleTy](auto t) - { + [&tvs, cycleTy = cycleTy](auto t) { return tvs(cycleTy, t); }, cycleTy->ty); From 53d03f94f7a166b50a9a84c1dcfc0f325ac24a2d Mon Sep 17 00:00:00 2001 From: Andy Friesen Date: Fri, 27 Jan 2023 13:28:45 -0800 Subject: [PATCH 31/66] Sync to upstream/release/561 --- Analysis/include/Luau/Constraint.h | 1 + .../include/Luau/ConstraintGraphBuilder.h | 5 + Analysis/include/Luau/Normalize.h | 3 +- Analysis/include/Luau/Type.h | 2 + Analysis/include/Luau/VisitType.h | 3 - Analysis/src/BuiltinDefinitions.cpp | 56 +- Analysis/src/ConstraintGraphBuilder.cpp | 43 +- Analysis/src/ConstraintSolver.cpp | 37 +- Analysis/src/EmbeddedBuiltinDefinitions.cpp | 10 +- Analysis/src/Frontend.cpp | 2 +- Analysis/src/Normalize.cpp | 71 +- Analysis/src/Substitution.cpp | 3 +- Analysis/src/ToString.cpp | 8 +- Analysis/src/TxnLog.cpp | 2 - Analysis/src/Type.cpp | 4 +- Analysis/src/TypeChecker2.cpp | 10 +- Analysis/src/TypeInfer.cpp | 271 ++-- Analysis/src/TypeReduction.cpp | 90 +- Analysis/src/Unifier.cpp | 47 +- Ast/include/Luau/Ast.h | 483 +++---- Ast/src/Lexer.cpp | 10 +- Ast/src/Parser.cpp | 6 +- CLI/Repl.cpp | 33 +- CodeGen/include/Luau/CodeGen.h | 5 +- CodeGen/src/CodeGen.cpp | 53 +- CodeGen/src/EmitCommonX64.cpp | 16 +- CodeGen/src/EmitInstructionX64.cpp | 10 +- CodeGen/src/EmitInstructionX64.h | 2 +- CodeGen/src/IrAnalysis.cpp | 50 + CodeGen/src/IrAnalysis.h | 14 + CodeGen/src/IrBuilder.cpp | 11 +- CodeGen/src/IrLoweringX64.cpp | 1242 +++++++++++++++++ CodeGen/src/IrLoweringX64.h | 89 ++ Common/include/Luau/DenseHash.h | 8 +- Common/include/Luau/ExperimentalFlags.h | 1 - Compiler/src/Compiler.cpp | 3 +- Sources.cmake | 5 + VM/src/ldebug.cpp | 17 +- tests/AstJsonEncoder.test.cpp | 2 - tests/Autocomplete.test.cpp | 8 - tests/Compiler.test.cpp | 7 - tests/Conformance.test.cpp | 2 - tests/CostModel.test.cpp | 2 - tests/DenseHash.test.cpp | 79 ++ tests/Fixture.cpp | 3 +- tests/Fixture.h | 1 - tests/Lexer.test.cpp | 10 - tests/Linter.test.cpp | 2 - tests/Normalize.test.cpp | 46 +- tests/Parser.test.cpp | 14 - tests/ToDot.test.cpp | 2 +- tests/Transpiler.test.cpp | 4 - tests/TypeInfer.aliases.test.cpp | 10 +- tests/TypeInfer.functions.test.cpp | 6 - tests/TypeInfer.operators.test.cpp | 6 +- tests/TypeInfer.provisional.test.cpp | 2 - tests/TypeInfer.refinements.test.cpp | 8 +- tests/TypeInfer.singletons.test.cpp | 6 +- tests/TypeInfer.tables.test.cpp | 10 +- tests/TypeInfer.test.cpp | 8 +- tests/TypeInfer.unknownnever.test.cpp | 9 - tests/TypeReduction.test.cpp | 84 ++ tools/faillist.txt | 39 +- 63 files changed, 2334 insertions(+), 762 deletions(-) create mode 100644 CodeGen/src/IrAnalysis.cpp create mode 100644 CodeGen/src/IrAnalysis.h create mode 100644 CodeGen/src/IrLoweringX64.cpp create mode 100644 CodeGen/src/IrLoweringX64.h create mode 100644 tests/DenseHash.test.cpp diff --git a/Analysis/include/Luau/Constraint.h b/Analysis/include/Luau/Constraint.h index ec94eee96..f814cb9f7 100644 --- a/Analysis/include/Luau/Constraint.h +++ b/Analysis/include/Luau/Constraint.h @@ -89,6 +89,7 @@ struct NameConstraint { TypeId namedType; std::string name; + bool synthetic = false; }; // target ~ inst target diff --git a/Analysis/include/Luau/ConstraintGraphBuilder.h b/Analysis/include/Luau/ConstraintGraphBuilder.h index a1caf85af..aac99afc3 100644 --- a/Analysis/include/Luau/ConstraintGraphBuilder.h +++ b/Analysis/include/Luau/ConstraintGraphBuilder.h @@ -80,6 +80,8 @@ struct ConstraintGraphBuilder // A mapping of AST node to TypePackId. DenseHashMap astTypePacks{nullptr}; + DenseHashMap astExpectedTypes{nullptr}; + // If the node was applied as a function, this is the unspecialized type of // that expression. DenseHashMap astOriginalCallTypes{nullptr}; @@ -88,6 +90,8 @@ struct ConstraintGraphBuilder // overload that was selected. DenseHashMap astOverloadResolvedTypes{nullptr}; + + // Types resolved from type annotations. Analogous to astTypes. DenseHashMap astResolvedTypes{nullptr}; @@ -207,6 +211,7 @@ struct ConstraintGraphBuilder Inference check(const ScopePtr& scope, AstExprBinary* binary, std::optional expectedType); Inference check(const ScopePtr& scope, AstExprIfElse* ifElse, std::optional expectedType); Inference check(const ScopePtr& scope, AstExprTypeAssertion* typeAssert); + Inference check(const ScopePtr& scope, AstExprInterpString* interpString); Inference check(const ScopePtr& scope, AstExprTable* expr, std::optional expectedType); std::tuple checkBinary(const ScopePtr& scope, AstExprBinary* binary, std::optional expectedType); diff --git a/Analysis/include/Luau/Normalize.h b/Analysis/include/Luau/Normalize.h index 865a9c4d3..15dc7d4a1 100644 --- a/Analysis/include/Luau/Normalize.h +++ b/Analysis/include/Luau/Normalize.h @@ -253,7 +253,8 @@ struct NormalizedType TypeId threads; // The (meta)table part of the type. - // Each element of this set is a (meta)table type. + // Each element of this set is a (meta)table type, or the top `table` type. + // An empty set denotes never. TypeIds tables; // The function part of the type. diff --git a/Analysis/include/Luau/Type.h b/Analysis/include/Luau/Type.h index 4962274c9..0136327dc 100644 --- a/Analysis/include/Luau/Type.h +++ b/Analysis/include/Luau/Type.h @@ -117,6 +117,7 @@ struct PrimitiveType String, Thread, Function, + Table, }; Type type; @@ -651,6 +652,7 @@ struct BuiltinTypes const TypeId threadType; const TypeId functionType; const TypeId classType; + const TypeId tableType; const TypeId trueType; const TypeId falseType; const TypeId anyType; diff --git a/Analysis/include/Luau/VisitType.h b/Analysis/include/Luau/VisitType.h index e0ab12e7e..ff4dfc3c3 100644 --- a/Analysis/include/Luau/VisitType.h +++ b/Analysis/include/Luau/VisitType.h @@ -9,7 +9,6 @@ #include "Luau/Type.h" LUAU_FASTINT(LuauVisitRecursionLimit) -LUAU_FASTFLAG(LuauCompleteVisitor); namespace Luau { @@ -322,8 +321,6 @@ struct GenericTypeVisitor if (visit(ty, *ntv)) traverse(ntv->ty); } - else if (!FFlag::LuauCompleteVisitor) - return visit_detail::unsee(seen, ty); else LUAU_ASSERT(!"GenericTypeVisitor::traverse(TypeId) is not exhaustive!"); diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index 1fb915e95..1a5a6bf65 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -15,9 +15,7 @@ #include -LUAU_FASTFLAG(LuauUnknownAndNeverType) LUAU_FASTFLAGVARIABLE(LuauBuiltInMetatableNoBadSynthetic, false) -LUAU_FASTFLAG(LuauReportShadowedTypeAlias) /** FIXME: Many of these type definitions are not quite completely accurate. * @@ -252,11 +250,8 @@ void registerBuiltinTypes(Frontend& frontend) frontend.getGlobalScope()->addBuiltinTypeBinding("string", TypeFun{{}, frontend.builtinTypes->stringType}); frontend.getGlobalScope()->addBuiltinTypeBinding("boolean", TypeFun{{}, frontend.builtinTypes->booleanType}); frontend.getGlobalScope()->addBuiltinTypeBinding("thread", TypeFun{{}, frontend.builtinTypes->threadType}); - if (FFlag::LuauUnknownAndNeverType) - { - frontend.getGlobalScope()->addBuiltinTypeBinding("unknown", TypeFun{{}, frontend.builtinTypes->unknownType}); - frontend.getGlobalScope()->addBuiltinTypeBinding("never", TypeFun{{}, frontend.builtinTypes->neverType}); - } + frontend.getGlobalScope()->addBuiltinTypeBinding("unknown", TypeFun{{}, frontend.builtinTypes->unknownType}); + frontend.getGlobalScope()->addBuiltinTypeBinding("never", TypeFun{{}, frontend.builtinTypes->neverType}); } void registerBuiltinGlobals(TypeChecker& typeChecker) @@ -315,7 +310,7 @@ void registerBuiltinGlobals(TypeChecker& typeChecker) FunctionType{ {genericMT}, {}, - arena.addTypePack(TypePack{{FFlag::LuauUnknownAndNeverType ? tabTy : tableMetaMT, genericMT}}), + arena.addTypePack(TypePack{{tabTy, genericMT}}), arena.addTypePack(TypePack{{tableMetaMT}}) } ), "@luau" @@ -357,8 +352,7 @@ void registerBuiltinGlobals(Frontend& frontend) LUAU_ASSERT(!frontend.globalTypes.types.isFrozen()); LUAU_ASSERT(!frontend.globalTypes.typePacks.isFrozen()); - if (FFlag::LuauReportShadowedTypeAlias) - registerBuiltinTypes(frontend); + registerBuiltinTypes(frontend); TypeArena& arena = frontend.globalTypes; NotNull builtinTypes = frontend.builtinTypes; @@ -409,7 +403,7 @@ void registerBuiltinGlobals(Frontend& frontend) FunctionType{ {genericMT}, {}, - arena.addTypePack(TypePack{{FFlag::LuauUnknownAndNeverType ? tabTy : tableMetaMT, genericMT}}), + arena.addTypePack(TypePack{{tabTy, genericMT}}), arena.addTypePack(TypePack{{tableMetaMT}}) } ), "@luau" @@ -537,11 +531,8 @@ static std::optional> magicFunctionSetMetaTable( { auto [paramPack, _predicates] = withPredicate; - if (FFlag::LuauUnknownAndNeverType) - { - if (size(paramPack) < 2 && finite(paramPack)) - return std::nullopt; - } + if (size(paramPack) < 2 && finite(paramPack)) + return std::nullopt; TypeArena& arena = typechecker.currentModule->internalTypes; @@ -550,11 +541,8 @@ static std::optional> magicFunctionSetMetaTable( TypeId target = follow(expectedArgs[0]); TypeId mt = follow(expectedArgs[1]); - if (FFlag::LuauUnknownAndNeverType) - { - typechecker.tablify(target); - typechecker.tablify(mt); - } + typechecker.tablify(target); + typechecker.tablify(mt); if (const auto& tab = get(target)) { @@ -564,9 +552,6 @@ static std::optional> magicFunctionSetMetaTable( } else { - if (!FFlag::LuauUnknownAndNeverType) - typechecker.tablify(mt); - const TableType* mtTtv = get(mt); MetatableType mtv{target, mt}; if ((tab->name || tab->syntheticName) && (mtTtv && (mtTtv->name || mtTtv->syntheticName))) @@ -583,12 +568,7 @@ static std::optional> magicFunctionSetMetaTable( TypeId mtTy = arena.addType(mtv); if (expr.args.size < 1) - { - if (FFlag::LuauUnknownAndNeverType) - return std::nullopt; - else - return WithPredicate{}; - } + return std::nullopt; if (!expr.self) { @@ -635,20 +615,10 @@ static std::optional> magicFunctionAssert( if (head.size() > 0) { auto [ty, ok] = typechecker.pickTypesFromSense(head[0], true, typechecker.builtinTypes->nilType); - if (FFlag::LuauUnknownAndNeverType) - { - if (get(*ty)) - head = {*ty}; - else - head[0] = *ty; - } + if (get(*ty)) + head = {*ty}; else - { - if (!ty) - head = {typechecker.nilType}; - else - head[0] = *ty; - } + head[0] = *ty; } return WithPredicate{arena.addTypePack(TypePack{std::move(head), tail})}; diff --git a/Analysis/src/ConstraintGraphBuilder.cpp b/Analysis/src/ConstraintGraphBuilder.cpp index 7181e4f03..b6184e36e 100644 --- a/Analysis/src/ConstraintGraphBuilder.cpp +++ b/Analysis/src/ConstraintGraphBuilder.cpp @@ -414,6 +414,10 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocal* local) std::vector varTypes; varTypes.reserve(local->vars.size); + // Used to name the first value type, even if it's not placed in varTypes, + // for the purpose of synthetic name attribution. + std::optional firstValueType; + for (AstLocal* local : local->vars) { TypeId ty = nullptr; @@ -456,6 +460,9 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocal* local) else varTypes[i] = exprType; } + + if (i == 0) + firstValueType = exprType; } else { @@ -488,6 +495,22 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocal* local) } } + if (local->vars.size == 1 && local->values.size == 1 && firstValueType) + { + AstLocal* var = local->vars.data[0]; + AstExpr* value = local->values.data[0]; + + if (value->is()) + addConstraint(scope, value->location, NameConstraint{*firstValueType, var->name.value, /*synthetic*/ true}); + else if (const AstExprCall* call = value->as()) + { + if (const AstExprGlobal* global = call->func->as(); global && global->name == "setmetatable") + { + addConstraint(scope, value->location, NameConstraint{*firstValueType, var->name.value, /*synthetic*/ true}); + } + } + } + for (size_t i = 0; i < local->vars.size; ++i) { AstLocal* l = local->vars.data[i]; @@ -1138,7 +1161,13 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCa TypeId resultTy = arena->addType(mtv); if (AstExprLocal* targetLocal = targetExpr->as()) + { scope->bindings[targetLocal->local].typeId = resultTy; + auto def = dfg->getDef(targetLocal->local); + if (def) + scope->dcrRefinements[*def] = resultTy; // TODO: typestates: track this as an assignment + } + return InferencePack{arena->addTypePack({resultTy}), std::move(returnConnectives)}; } @@ -1248,6 +1277,8 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExpr* expr, st result = check(scope, ifElse, expectedType); else if (auto typeAssert = expr->as()) result = check(scope, typeAssert); + else if (auto interpString = expr->as()) + result = check(scope, interpString); else if (auto err = expr->as()) { // Open question: Should we traverse into this? @@ -1264,6 +1295,8 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExpr* expr, st LUAU_ASSERT(result.ty); astTypes[expr] = result.ty; + if (expectedType) + astExpectedTypes[expr] = *expectedType; return result; } @@ -1509,6 +1542,14 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprTypeAssert return Inference{resolveType(scope, typeAssert->annotation, /* inTypeArguments */ false)}; } +Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprInterpString* interpString) +{ + for (AstExpr* expr : interpString->expressions) + check(scope, expr); + + return Inference{builtinTypes->stringType}; +} + std::tuple ConstraintGraphBuilder::checkBinary( const ScopePtr& scope, AstExprBinary* binary, std::optional expectedType) { @@ -1551,7 +1592,7 @@ std::tuple ConstraintGraphBuilder::checkBinary( else if (typeguard->type == "boolean") discriminantTy = builtinTypes->threadType; else if (typeguard->type == "table") - discriminantTy = builtinTypes->neverType; // TODO: replace with top table type + discriminantTy = builtinTypes->tableType; else if (typeguard->type == "function") discriminantTy = builtinTypes->functionType; else if (typeguard->type == "userdata") diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index 3fbd7d9e2..0a9b82bad 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -883,7 +883,12 @@ bool ConstraintSolver::tryDispatch(const NameConstraint& c, NotNull(target)) - ttv->name = c.name; + { + if (c.synthetic && !ttv->name) + ttv->syntheticName = c.name; + else + ttv->name = c.name; + } else if (MetatableType* mtv = getMutable(target)) mtv->syntheticName = c.name; else if (get(target) || get(target)) @@ -1594,7 +1599,7 @@ bool ConstraintSolver::tryDispatchIterableFunction( const TypeId firstIndex = isNil(firstIndexTy) ? arena->freshType(constraint->scope) // FIXME: Surely this should be a union (free | nil) : firstIndexTy; - // nextTy : (tableTy, indexTy?) -> (indexTy, valueTailTy...) + // nextTy : (tableTy, indexTy?) -> (indexTy?, valueTailTy...) const TypePackId nextArgPack = arena->addTypePack({tableTy, arena->addType(UnionType{{firstIndex, builtinTypes->nilType}})}); const TypePackId valueTailTy = arena->addTypePack(FreeTypePack{constraint->scope}); const TypePackId nextRetPack = arena->addTypePack(TypePack{{firstIndex}, valueTailTy}); @@ -1602,7 +1607,25 @@ bool ConstraintSolver::tryDispatchIterableFunction( const TypeId expectedNextTy = arena->addType(FunctionType{TypeLevel{}, constraint->scope, nextArgPack, nextRetPack}); unify(nextTy, expectedNextTy, constraint->scope); - pushConstraint(constraint->scope, constraint->location, PackSubtypeConstraint{c.variables, nextRetPack}); + auto it = begin(nextRetPack); + std::vector modifiedNextRetHead; + + // The first value is never nil in the context of the loop, even if it's nil + // in the next function's return type, because the loop will not advance if + // it's nil. + if (it != end(nextRetPack)) + { + TypeId firstRet = *it; + TypeId modifiedFirstRet = stripNil(builtinTypes, *arena, firstRet); + modifiedNextRetHead.push_back(modifiedFirstRet); + ++it; + } + + for (; it != end(nextRetPack); ++it) + modifiedNextRetHead.push_back(*it); + + TypePackId modifiedNextRetPack = arena->addTypePack(std::move(modifiedNextRetHead), it.tail()); + pushConstraint(constraint->scope, constraint->location, PackSubtypeConstraint{c.variables, modifiedNextRetPack}); return true; } @@ -1649,8 +1672,8 @@ std::optional ConstraintSolver::lookupTableProp(TypeId subjectType, cons resultType = parts[0]; else if (parts.size() > 1) resultType = arena->addType(UnionType{std::move(parts)}); - else - LUAU_ASSERT(false); // parts.size() == 0 + + // otherwise, nothing: no matching property } else if (auto itv = get(subjectType)) { @@ -1662,8 +1685,8 @@ std::optional ConstraintSolver::lookupTableProp(TypeId subjectType, cons resultType = parts[0]; else if (parts.size() > 1) resultType = arena->addType(IntersectionType{std::move(parts)}); - else - LUAU_ASSERT(false); // parts.size() == 0 + + // otherwise, nothing: no matching property } return resultType; diff --git a/Analysis/src/EmbeddedBuiltinDefinitions.cpp b/Analysis/src/EmbeddedBuiltinDefinitions.cpp index 1fe09773c..364244ad3 100644 --- a/Analysis/src/EmbeddedBuiltinDefinitions.cpp +++ b/Analysis/src/EmbeddedBuiltinDefinitions.cpp @@ -1,8 +1,6 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/BuiltinDefinitions.h" -LUAU_FASTFLAG(LuauUnknownAndNeverType) - namespace Luau { @@ -115,6 +113,7 @@ declare function typeof(value: T): string -- `assert` has a magic function attached that will give more detailed type information declare function assert(value: T, errorMessage: string?): T +declare function error(message: T, level: number?): never declare function tostring(value: T): string declare function tonumber(value: T, radix: number?): number? @@ -199,14 +198,7 @@ declare function unpack(tab: {V}, i: number?, j: number?): ...V std::string getBuiltinDefinitionSource() { - std::string result = kBuiltinDefinitionLuaSrc; - - if (FFlag::LuauUnknownAndNeverType) - result += "declare function error(message: T, level: number?): never\n"; - else - result += "declare function error(message: T, level: number?)\n"; - return result; } diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index f4e529dbe..d200df343 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -922,10 +922,10 @@ ModulePtr Frontend::check( for (TypeError& e : cs.errors) result->errors.emplace_back(std::move(e)); - result->scopes = std::move(cgb.scopes); result->astTypes = std::move(cgb.astTypes); result->astTypePacks = std::move(cgb.astTypePacks); + result->astExpectedTypes = std::move(cgb.astExpectedTypes); result->astOriginalCallTypes = std::move(cgb.astOriginalCallTypes); result->astOverloadResolvedTypes = std::move(cgb.astOverloadResolvedTypes); result->astResolvedTypes = std::move(cgb.astResolvedTypes); diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp index 901144e4a..a7b2b7276 100644 --- a/Analysis/src/Normalize.cpp +++ b/Analysis/src/Normalize.cpp @@ -19,7 +19,7 @@ LUAU_FASTINTVARIABLE(LuauNormalizeIterationLimit, 1200); LUAU_FASTINTVARIABLE(LuauNormalizeCacheLimit, 100000); LUAU_FASTFLAGVARIABLE(LuauNegatedClassTypes, false); LUAU_FASTFLAGVARIABLE(LuauNegatedFunctionTypes, false); -LUAU_FASTFLAG(LuauUnknownAndNeverType) +LUAU_FASTFLAGVARIABLE(LuauNegatedTableTypes, false); LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) LUAU_FASTFLAG(LuauUninhabitedSubAnything2) @@ -448,8 +448,20 @@ static bool areNormalizedFunctions(const NormalizedFunctionType& tys) static bool areNormalizedTables(const TypeIds& tys) { for (TypeId ty : tys) - if (!get(ty) && !get(ty)) + { + if (get(ty) || get(ty)) + continue; + + const PrimitiveType* pt = get(ty); + if (!pt) return false; + + if (pt->type == PrimitiveType::Table && FFlag::LuauNegatedTableTypes) + continue; + + return false; + } + return true; } @@ -1216,7 +1228,25 @@ void Normalizer::unionTablesWithTable(TypeIds& heres, TypeId there) void Normalizer::unionTables(TypeIds& heres, const TypeIds& theres) { for (TypeId there : theres) - unionTablesWithTable(heres, there); + { + if (FFlag::LuauNegatedTableTypes) + { + if (there == builtinTypes->tableType) + { + heres.clear(); + heres.insert(there); + return; + } + else + { + unionTablesWithTable(heres, there); + } + } + else + { + unionTablesWithTable(heres, there); + } + } } // So why `ignoreSmallerTyvars`? @@ -1375,6 +1405,11 @@ bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, int ignor LUAU_ASSERT(FFlag::LuauNegatedFunctionTypes); here.functions.resetToTop(); } + else if (ptv->type == PrimitiveType::Table && FFlag::LuauNegatedTableTypes) + { + here.tables.clear(); + here.tables.insert(there); + } else LUAU_ASSERT(!"Unreachable"); } @@ -1504,6 +1539,21 @@ std::optional Normalizer::negateNormal(const NormalizedType& her return std::nullopt; } + /* + * It is not possible to negate an arbitrary table type, because function + * types are not runtime-testable. Thus, we prohibit negation of anything + * other than `table` and `never`. + */ + if (FFlag::LuauNegatedTableTypes) + { + if (here.tables.empty()) + result.tables.insert(builtinTypes->tableType); + else if (here.tables.size() == 1 && here.tables.front() == builtinTypes->tableType) + result.tables.clear(); + else + return std::nullopt; + } + // TODO: negating tables // TODO: negating tyvars? @@ -1571,6 +1621,10 @@ void Normalizer::subtractPrimitive(NormalizedType& here, TypeId ty) case PrimitiveType::Function: here.functions.resetToNever(); break; + case PrimitiveType::Table: + LUAU_ASSERT(FFlag::LuauNegatedTableTypes); + here.tables.clear(); + break; } } @@ -1995,6 +2049,11 @@ std::optional Normalizer::intersectionOfTables(TypeId here, TypeId there if (sharedState->counters.recursionLimit > 0 && sharedState->counters.recursionLimit < sharedState->counters.recursionCount) return std::nullopt; + if (isPrim(here, PrimitiveType::Table)) + return there; + else if (isPrim(there, PrimitiveType::Table)) + return here; + TypeId htable = here; TypeId hmtable = nullptr; if (const MetatableType* hmtv = get(here)) @@ -2522,6 +2581,7 @@ bool Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there) NormalizedStringType strings = std::move(here.strings); NormalizedFunctionType functions = std::move(here.functions); TypeId threads = here.threads; + TypeIds tables = std::move(here.tables); clearNormal(here); @@ -2540,6 +2600,11 @@ bool Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there) LUAU_ASSERT(FFlag::LuauNegatedFunctionTypes); here.functions = std::move(functions); } + else if (ptv->type == PrimitiveType::Table) + { + LUAU_ASSERT(FFlag::LuauNegatedTableTypes); + here.tables = std::move(tables); + } else LUAU_ASSERT(!"Unreachable"); } diff --git a/Analysis/src/Substitution.cpp b/Analysis/src/Substitution.cpp index 2469152eb..160647a05 100644 --- a/Analysis/src/Substitution.cpp +++ b/Analysis/src/Substitution.cpp @@ -12,7 +12,6 @@ LUAU_FASTFLAGVARIABLE(LuauSubstitutionFixMissingFields, false) LUAU_FASTFLAG(LuauClonePublicInterfaceLess) LUAU_FASTINTVARIABLE(LuauTarjanChildLimit, 10000) LUAU_FASTFLAGVARIABLE(LuauClassTypeVarsInSubstitution, false) -LUAU_FASTFLAG(LuauUnknownAndNeverType) LUAU_FASTFLAGVARIABLE(LuauSubstitutionReentrant, false) namespace Luau @@ -184,7 +183,7 @@ TarjanResult Tarjan::loop() if (currEdge == -1) { ++childCount; - if (childLimit > 0 && (FFlag::LuauUnknownAndNeverType ? childLimit <= childCount : childLimit < childCount)) + if (childLimit > 0 && childLimit <= childCount) return TarjanResult::TooManyChildren; stack.push_back(index); diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index 89d3c5557..1972177cf 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -14,7 +14,6 @@ #include LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) -LUAU_FASTFLAG(LuauUnknownAndNeverType) LUAU_FASTFLAGVARIABLE(LuauFunctionReturnStringificationFixup, false) /* @@ -444,6 +443,9 @@ struct TypeStringifier case PrimitiveType::Function: state.emit("function"); return; + case PrimitiveType::Table: + state.emit("table"); + return; default: LUAU_ASSERT(!"Unknown primitive type"); throw InternalCompilerError("Unknown primitive type " + std::to_string(ptv.type)); @@ -823,7 +825,7 @@ struct TypeStringifier void operator()(TypeId, const ErrorType& tv) { state.result.error = true; - state.emit(FFlag::LuauUnknownAndNeverType ? "*error-type*" : "*unknown*"); + state.emit("*error-type*"); } void operator()(TypeId, const LazyType& ltv) @@ -962,7 +964,7 @@ struct TypePackStringifier void operator()(TypePackId, const Unifiable::Error& error) { state.result.error = true; - state.emit(FFlag::LuauUnknownAndNeverType ? "*error-type*" : "*unknown*"); + state.emit("*error-type*"); } void operator()(TypePackId, const VariadicTypePack& pack) diff --git a/Analysis/src/TxnLog.cpp b/Analysis/src/TxnLog.cpp index dacd82dc1..5040952e8 100644 --- a/Analysis/src/TxnLog.cpp +++ b/Analysis/src/TxnLog.cpp @@ -8,8 +8,6 @@ #include #include -LUAU_FASTFLAG(LuauUnknownAndNeverType) - namespace Luau { diff --git a/Analysis/src/Type.cpp b/Analysis/src/Type.cpp index 4b2165187..f29a02241 100644 --- a/Analysis/src/Type.cpp +++ b/Analysis/src/Type.cpp @@ -24,7 +24,6 @@ LUAU_FASTFLAG(DebugLuauFreezeArena) LUAU_FASTINTVARIABLE(LuauTypeMaximumStringifierLength, 500) LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0) LUAU_FASTINT(LuauTypeInferRecursionLimit) -LUAU_FASTFLAG(LuauUnknownAndNeverType) LUAU_FASTFLAGVARIABLE(LuauMaybeGenericIntersectionTypes, false) LUAU_FASTFLAG(LuauInstantiateInSubtyping) LUAU_FASTFLAGVARIABLE(LuauMatchReturnsOptionalString, false); @@ -213,7 +212,7 @@ bool isOptional(TypeId ty) ty = follow(ty); - if (get(ty) || (FFlag::LuauUnknownAndNeverType && get(ty))) + if (get(ty) || get(ty)) return true; auto utv = get(ty); @@ -761,6 +760,7 @@ BuiltinTypes::BuiltinTypes() , threadType(arena->addType(Type{PrimitiveType{PrimitiveType::Thread}, /*persistent*/ true})) , functionType(arena->addType(Type{PrimitiveType{PrimitiveType::Function}, /*persistent*/ true})) , classType(arena->addType(Type{ClassType{"class", {}, std::nullopt, std::nullopt, {}, {}, {}}, /*persistent*/ true})) + , tableType(arena->addType(Type{PrimitiveType{PrimitiveType::Table}, /*persistent*/ true})) , trueType(arena->addType(Type{SingletonType{BooleanSingleton{true}}, /*persistent*/ true})) , falseType(arena->addType(Type{SingletonType{BooleanSingleton{false}}, /*persistent*/ true})) , anyType(arena->addType(Type{AnyType{}, /*persistent*/ true})) diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index 1972f26f3..59c488fdf 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -372,7 +372,7 @@ struct TypeChecker2 break; } - AstLocal* var = local->vars.data[i]; + AstLocal* var = local->vars.data[j]; if (var->annotation) { TypeId varType = lookupAnnotation(var->annotation); @@ -755,6 +755,8 @@ struct TypeChecker2 return visit(e); else if (auto e = expr->as()) return visit(e); + else if (auto e = expr->as()) + return visit(e); else if (auto e = expr->as()) return visit(e); else @@ -1358,6 +1360,12 @@ struct TypeChecker2 visit(expr->falseExpr, RValue); } + void visit(AstExprInterpString* interpString) + { + for (AstExpr* expr : interpString->expressions) + visit(expr, RValue); + } + void visit(AstExprError* expr) { // TODO! diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index de52a5261..a25ddc7c4 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -35,18 +35,10 @@ LUAU_FASTFLAG(LuauKnowsTheDataModel3) LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) LUAU_FASTFLAGVARIABLE(LuauReturnAnyInsteadOfICE, false) // Eventually removed as false. LUAU_FASTFLAGVARIABLE(DebugLuauSharedSelf, false) -LUAU_FASTFLAGVARIABLE(LuauUnknownAndNeverType, false) -LUAU_FASTFLAGVARIABLE(LuauBinaryNeedsExpectedTypesToo, false) -LUAU_FASTFLAGVARIABLE(LuauNeverTypesAndOperatorsInference, false) LUAU_FASTFLAGVARIABLE(LuauScopelessModule, false) -LUAU_FASTFLAGVARIABLE(LuauReturnsFromCallsitesAreNotWidened, false) LUAU_FASTFLAGVARIABLE(LuauTryhardAnd, false) LUAU_FASTFLAG(LuauInstantiateInSubtyping) -LUAU_FASTFLAGVARIABLE(LuauCompleteVisitor, false) -LUAU_FASTFLAGVARIABLE(LuauReportShadowedTypeAlias, false) -LUAU_FASTFLAGVARIABLE(LuauBetterMessagingOnCountMismatch, false) LUAU_FASTFLAGVARIABLE(LuauIntersectionTestForEquality, false) -LUAU_FASTFLAGVARIABLE(LuauImplicitElseRefinement, false) LUAU_FASTFLAG(LuauNegatedClassTypes) LUAU_FASTFLAGVARIABLE(LuauAllowIndexClassParameters, false) LUAU_FASTFLAG(LuauUninhabitedSubAnything2) @@ -246,11 +238,8 @@ TypeChecker::TypeChecker(ModuleResolver* resolver, NotNull builtin globalScope->addBuiltinTypeBinding("string", TypeFun{{}, stringType}); globalScope->addBuiltinTypeBinding("boolean", TypeFun{{}, booleanType}); globalScope->addBuiltinTypeBinding("thread", TypeFun{{}, threadType}); - if (FFlag::LuauUnknownAndNeverType) - { - globalScope->addBuiltinTypeBinding("unknown", TypeFun{{}, unknownType}); - globalScope->addBuiltinTypeBinding("never", TypeFun{{}, neverType}); - } + globalScope->addBuiltinTypeBinding("unknown", TypeFun{{}, unknownType}); + globalScope->addBuiltinTypeBinding("never", TypeFun{{}, neverType}); } ModulePtr TypeChecker::check(const SourceModule& module, Mode mode, std::optional environmentScope) @@ -661,7 +650,7 @@ LUAU_NOINLINE void TypeChecker::checkBlockTypeAliases(const ScopePtr& scope, std Name name = typealias->name.value; - if (FFlag::LuauReportShadowedTypeAlias && duplicateTypeAliases.contains({typealias->exported, name})) + if (duplicateTypeAliases.contains({typealias->exported, name})) continue; TypeId type = bindings[name].type; @@ -1066,17 +1055,14 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatLocal& local) // If the expression list only contains one expression and it's a function call or is otherwise within parentheses, use FunctionResult. // Otherwise, we'll want to use ExprListResult to make the error messaging more general. - CountMismatch::Context ctx = FFlag::LuauBetterMessagingOnCountMismatch ? CountMismatch::ExprListResult : CountMismatch::FunctionResult; - if (FFlag::LuauBetterMessagingOnCountMismatch) + CountMismatch::Context ctx = CountMismatch::ExprListResult; + if (local.values.size == 1) { - if (local.values.size == 1) - { - AstExpr* e = local.values.data[0]; - while (auto group = e->as()) - e = group->expr; - if (e->is()) - ctx = CountMismatch::FunctionResult; - } + AstExpr* e = local.values.data[0]; + while (auto group = e->as()) + e = group->expr; + if (e->is()) + ctx = CountMismatch::FunctionResult; } Unifier state = mkUnifier(scope, local.location); @@ -1438,11 +1424,8 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco checkFunctionBody(funScope, ty, *function.func); - if (FFlag::LuauUnknownAndNeverType) - { - InplaceDemoter demoter{funScope->level, ¤tModule->internalTypes}; - demoter.traverse(ty); - } + InplaceDemoter demoter{funScope->level, ¤tModule->internalTypes}; + demoter.traverse(ty); if (ttv && ttv->state != TableState::Sealed) ttv->props[name->index.value] = {follow(quantify(funScope, ty, name->indexLocation)), /* deprecated */ false, {}, name->indexLocation}; @@ -1591,12 +1574,9 @@ void TypeChecker::prototype(const ScopePtr& scope, const AstStatTypeAlias& typea Location location = scope->typeAliasLocations[name]; reportError(TypeError{typealias.location, DuplicateTypeDefinition{name, location}}); - if (!FFlag::LuauReportShadowedTypeAlias) - bindingsMap[name] = TypeFun{binding->typeParams, binding->typePackParams, errorRecoveryType(anyType)}; - duplicateTypeAliases.insert({typealias.exported, name}); } - else if (FFlag::LuauReportShadowedTypeAlias) + else { if (globalScope->builtinTypeNames.contains(name)) { @@ -1623,25 +1603,6 @@ void TypeChecker::prototype(const ScopePtr& scope, const AstStatTypeAlias& typea scope->typeAliasNameLocations[name] = typealias.nameLocation; } } - else - { - ScopePtr aliasScope = childScope(scope, typealias.location); - aliasScope->level = scope->level.incr(); - aliasScope->level.subLevel = subLevel; - - auto [generics, genericPacks] = - createGenericTypes(aliasScope, scope->level, typealias, typealias.generics, typealias.genericPacks, /* useCache = */ true); - - TypeId ty = freshType(aliasScope); - FreeType* ftv = getMutable(ty); - LUAU_ASSERT(ftv); - ftv->forwardedTypeAlias = true; - bindingsMap[name] = {std::move(generics), std::move(genericPacks), ty}; - - scope->typeAliasLocations[name] = typealias.location; - if (FFlag::SupportTypeAliasGoToDeclaration) - scope->typeAliasNameLocations[name] = typealias.nameLocation; - } } void TypeChecker::prototype(const ScopePtr& scope, const AstStatDeclareClass& declaredClass) @@ -1840,7 +1801,7 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp else if (auto a = expr.as()) result = checkExpr(scope, *a); else if (auto a = expr.as()) - result = checkExpr(scope, *a, FFlag::LuauBinaryNeedsExpectedTypesToo ? expectedType : std::nullopt); + result = checkExpr(scope, *a, expectedType); else if (auto a = expr.as()) result = checkExpr(scope, *a); else if (auto a = expr.as()) @@ -2084,7 +2045,7 @@ std::optional TypeChecker::getIndexTypeFromTypeImpl( } std::vector result = reduceUnion(goodOptions); - if (FFlag::LuauUnknownAndNeverType && result.empty()) + if (result.empty()) return neverType; if (result.size() == 1) @@ -2432,13 +2393,8 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp operandType = stripFromNilAndReport(operandType, expr.location); // # operator is guaranteed to return number - if ((FFlag::LuauNeverTypesAndOperatorsInference && get(operandType)) || get(operandType) || get(operandType)) - { - if (FFlag::LuauNeverTypesAndOperatorsInference) - return {numberType}; - else - return {!FFlag::LuauUnknownAndNeverType ? errorRecoveryType(scope) : operandType}; - } + if (get(operandType) || get(operandType) || get(operandType)) + return {numberType}; DenseHashSet seen{nullptr}; @@ -2518,7 +2474,7 @@ TypeId TypeChecker::unionOfTypes(TypeId a, TypeId b, const ScopePtr& scope, cons return a; std::vector types = reduceUnion({a, b}); - if (FFlag::LuauUnknownAndNeverType && types.empty()) + if (types.empty()) return neverType; if (types.size() == 1) @@ -2649,12 +2605,9 @@ TypeId TypeChecker::checkRelationalOperation( case AstExprBinary::CompareGe: case AstExprBinary::CompareLe: { - if (FFlag::LuauNeverTypesAndOperatorsInference) - { - // If one of the operand is never, it doesn't make sense to unify these. - if (get(lhsType) || get(rhsType)) - return booleanType; - } + // If one of the operand is never, it doesn't make sense to unify these. + if (get(lhsType) || get(rhsType)) + return booleanType; if (FFlag::LuauIntersectionTestForEquality && isEquality) { @@ -2897,10 +2850,8 @@ TypeId TypeChecker::checkBinaryOperation( // If we know nothing at all about the lhs type, we can usually say nothing about the result. // The notable exception to this is the equality and inequality operators, which always produce a boolean. - const bool lhsIsAny = get(lhsType) || get(lhsType) || - (FFlag::LuauUnknownAndNeverType && FFlag::LuauNeverTypesAndOperatorsInference && get(lhsType)); - const bool rhsIsAny = get(rhsType) || get(rhsType) || - (FFlag::LuauUnknownAndNeverType && FFlag::LuauNeverTypesAndOperatorsInference && get(rhsType)); + const bool lhsIsAny = get(lhsType) || get(lhsType) || get(lhsType); + const bool rhsIsAny = get(rhsType) || get(rhsType) || get(rhsType); if (lhsIsAny) return lhsType; @@ -3102,7 +3053,7 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp return {trueType.type}; std::vector types = reduceUnion({trueType.type, falseType.type}); - if (FFlag::LuauUnknownAndNeverType && types.empty()) + if (types.empty()) return {neverType}; return {types.size() == 1 ? types[0] : addType(UnionType{std::move(types)})}; } @@ -3709,15 +3660,10 @@ void TypeChecker::checkFunctionBody(const ScopePtr& scope, TypeId ty, const AstE WithPredicate TypeChecker::checkExprPack(const ScopePtr& scope, const AstExpr& expr) { - if (FFlag::LuauUnknownAndNeverType) - { - WithPredicate result = checkExprPackHelper(scope, expr); - if (containsNever(result.type)) - return {uninhabitableTypePack}; - return result; - } - else - return checkExprPackHelper(scope, expr); + WithPredicate result = checkExprPackHelper(scope, expr); + if (containsNever(result.type)) + return {uninhabitableTypePack}; + return result; } WithPredicate TypeChecker::checkExprPackHelper(const ScopePtr& scope, const AstExpr& expr) @@ -3843,10 +3789,7 @@ void TypeChecker::checkArgumentList(const ScopePtr& scope, const AstExpr& funNam } TypePackId varPack = addTypePack(TypePackVar{TypePack{rest, paramIter.tail()}}); - if (FFlag::LuauReturnsFromCallsitesAreNotWidened) - state.tryUnify(tail, varPack); - else - state.tryUnify(varPack, tail); + state.tryUnify(tail, varPack); return; } } @@ -4031,24 +3974,13 @@ WithPredicate TypeChecker::checkExprPackHelper(const ScopePtr& scope return {errorRecoveryTypePack(scope)}; TypePack* args = nullptr; - if (FFlag::LuauUnknownAndNeverType) - { - if (expr.self) - { - argPack = addTypePack(TypePack{{selfType}, argPack}); - argListResult.type = argPack; - } - args = getMutable(argPack); - LUAU_ASSERT(args); - } - else + if (expr.self) { - args = getMutable(argPack); - LUAU_ASSERT(args != nullptr); - - if (expr.self) - args->head.insert(args->head.begin(), selfType); + argPack = addTypePack(TypePack{{selfType}, argPack}); + argListResult.type = argPack; } + args = getMutable(argPack); + LUAU_ASSERT(args); std::vector argLocations; argLocations.reserve(expr.args.size + 1); @@ -4107,7 +4039,7 @@ std::vector> TypeChecker::getExpectedTypesForCall(const st else { std::vector result = reduceUnion({*el, ty}); - if (FFlag::LuauUnknownAndNeverType && result.empty()) + if (result.empty()) el = neverType; else el = result.size() == 1 ? result[0] : addType(UnionType{std::move(result)}); @@ -4451,7 +4383,7 @@ WithPredicate TypeChecker::checkExprList(const ScopePtr& scope, cons auto [typePack, exprPredicates] = checkExprPack(scope, *expr); insert(exprPredicates); - if (FFlag::LuauUnknownAndNeverType && containsNever(typePack)) + if (containsNever(typePack)) { // f(), g() where f() returns (never, string) or (string, never) means this whole TypePackId is uninhabitable, so return (never, // ...never) @@ -4474,7 +4406,7 @@ WithPredicate TypeChecker::checkExprList(const ScopePtr& scope, cons auto [type, exprPredicates] = checkExpr(scope, *expr, expectedType); insert(exprPredicates); - if (FFlag::LuauUnknownAndNeverType && get(type)) + if (get(type)) { // f(), g() where f() returns (never, string) or (string, never) means this whole TypePackId is uninhabitable, so return (never, // ...never) @@ -4509,7 +4441,7 @@ WithPredicate TypeChecker::checkExprList(const ScopePtr& scope, cons for (TxnLog& log : inverseLogs) log.commit(); - if (FFlag::LuauUnknownAndNeverType && uninhabitable) + if (uninhabitable) return {uninhabitableTypePack}; return {pack, predicates}; } @@ -4997,16 +4929,8 @@ std::optional TypeChecker::filterMapImpl(TypeId type, TypeIdPredicate pr std::pair, bool> TypeChecker::filterMap(TypeId type, TypeIdPredicate predicate) { - if (FFlag::LuauUnknownAndNeverType) - { - TypeId ty = filterMapImpl(type, predicate).value_or(neverType); - return {ty, !bool(get(ty))}; - } - else - { - std::optional ty = filterMapImpl(type, predicate); - return {ty, bool(ty)}; - } + TypeId ty = filterMapImpl(type, predicate).value_or(neverType); + return {ty, !bool(get(ty))}; } std::pair, bool> TypeChecker::pickTypesFromSense(TypeId type, bool sense, TypeId emptySetTy) @@ -5587,18 +5511,7 @@ void TypeChecker::refineLValue(const LValue& lvalue, RefinementMap& refis, const if (!key) { auto [result, ok] = filterMap(*ty, predicate); - if (FFlag::LuauUnknownAndNeverType) - { - addRefinement(refis, *target, *result); - } - else - { - if (ok) - addRefinement(refis, *target, *result); - else - addRefinement(refis, *target, errorRecoveryType(scope)); - } - + addRefinement(refis, *target, *result); return; } @@ -5621,21 +5534,10 @@ void TypeChecker::refineLValue(const LValue& lvalue, RefinementMap& refis, const return; // Do nothing. An error was already reported, as per usual. auto [result, ok] = filterMap(*discriminantTy, predicate); - if (FFlag::LuauUnknownAndNeverType) + if (!get(*result)) { - if (!get(*result)) - { - viableTargetOptions.insert(option); - viableChildOptions.insert(*result); - } - } - else - { - if (ok) - { - viableTargetOptions.insert(option); - viableChildOptions.insert(*result); - } + viableTargetOptions.insert(option); + viableChildOptions.insert(*result); } } @@ -5891,7 +5793,7 @@ void TypeChecker::resolve(const TypeGuardPredicate& typeguardP, RefinementMap& r auto refine = [this, &lvalue = typeguardP.lvalue, &refis, &scope, sense](bool(f)(TypeId), std::optional mapsTo = std::nullopt) { TypeIdPredicate predicate = [f, mapsTo, sense](TypeId ty) -> std::optional { - if (FFlag::LuauUnknownAndNeverType && sense && get(ty)) + if (sense && get(ty)) return mapsTo.value_or(ty); if (f(ty) == sense) @@ -5985,56 +5887,44 @@ void TypeChecker::resolve(const EqPredicate& eqP, RefinementMap& refis, const Sc if (maybeSingleton(eqP.type)) { - if (FFlag::LuauImplicitElseRefinement) + bool optionIsSubtype = canUnify(option, eqP.type, scope, eqP.location).empty(); + bool targetIsSubtype = canUnify(eqP.type, option, scope, eqP.location).empty(); + + // terminology refresher: + // - option is the type of the expression `x`, and + // - eqP.type is the type of the expression `"hello"` + // + // "hello" == x where + // x : "hello" | "world" -> x : "hello" + // x : number | string -> x : "hello" + // x : number -> x : never + // + // "hello" ~= x where + // x : "hello" | "world" -> x : "world" + // x : number | string -> x : number | string + // x : number -> x : number + + // local variable works around an odd gcc 9.3 warning: may be used uninitialized + std::optional nope = std::nullopt; + + if (sense) { - bool optionIsSubtype = canUnify(option, eqP.type, scope, eqP.location).empty(); - bool targetIsSubtype = canUnify(eqP.type, option, scope, eqP.location).empty(); - - // terminology refresher: - // - option is the type of the expression `x`, and - // - eqP.type is the type of the expression `"hello"` - // - // "hello" == x where - // x : "hello" | "world" -> x : "hello" - // x : number | string -> x : "hello" - // x : number -> x : never - // - // "hello" ~= x where - // x : "hello" | "world" -> x : "world" - // x : number | string -> x : number | string - // x : number -> x : number - - // local variable works around an odd gcc 9.3 warning: may be used uninitialized - std::optional nope = std::nullopt; - - if (sense) - { - if (optionIsSubtype && !targetIsSubtype) - return option; - else if (!optionIsSubtype && targetIsSubtype) - return follow(eqP.type); - else if (!optionIsSubtype && !targetIsSubtype) - return nope; - else if (optionIsSubtype && targetIsSubtype) - return follow(eqP.type); - } - else - { - bool isOptionSingleton = get(option); - if (!isOptionSingleton) - return option; - else if (optionIsSubtype && targetIsSubtype) - return nope; - } + if (optionIsSubtype && !targetIsSubtype) + return option; + else if (!optionIsSubtype && targetIsSubtype) + return follow(eqP.type); + else if (!optionIsSubtype && !targetIsSubtype) + return nope; + else if (optionIsSubtype && targetIsSubtype) + return follow(eqP.type); } else { - if (!sense || canUnify(eqP.type, option, scope, eqP.location).empty()) - return sense ? eqP.type : option; - - // local variable works around an odd gcc 9.3 warning: may be used uninitialized - std::optional res = std::nullopt; - return res; + bool isOptionSingleton = get(option); + if (!isOptionSingleton) + return option; + else if (optionIsSubtype && targetIsSubtype) + return nope; } } @@ -6063,8 +5953,7 @@ std::vector TypeChecker::unTypePack(const ScopePtr& scope, TypePackId tp // HACK: tryUnify would undo the changes to the expectedTypePack if the length mismatches, but // we want to tie up free types to be error types, so we do this instead. - if (FFlag::LuauUnknownAndNeverType) - currentModule->errors.resize(oldErrorsSize); + currentModule->errors.resize(oldErrorsSize); for (TypeId& tp : expectedPack->head) tp = follow(tp); diff --git a/Analysis/src/TypeReduction.cpp b/Analysis/src/TypeReduction.cpp index c748f8632..3404c71d8 100644 --- a/Analysis/src/TypeReduction.cpp +++ b/Analysis/src/TypeReduction.cpp @@ -56,12 +56,6 @@ struct TypeReducer TypeId memoize(TypeId ty, TypeId reducedTy); TypePackId memoize(TypePackId tp, TypePackId reducedTp); - // It's either cyclic with no memoized result, so we should terminate, or - // there is a memoized result but one that's being reduced top-down, so - // we need to return the root of that memoized result to tighten up things. - TypeId memoizedOr(TypeId ty) const; - TypePackId memoizedOr(TypePackId tp) const; - using BinaryFold = std::optional (TypeReducer::*)(TypeId, TypeId); using UnaryFold = TypeId (TypeReducer::*)(TypeId); @@ -319,6 +313,24 @@ std::optional TypeReducer::intersectionType(TypeId left, TypeId right) } else if (auto [f, p] = get2(left, right); f && p) return intersectionType(right, left); // () -> () & P ~ P & () -> () + else if (auto [p, t] = get2(left, right); p && t) + { + if (p->type == PrimitiveType::Table) + return right; // table & {} ~ {} + else + return builtinTypes->neverType; // string & {} ~ never + } + else if (auto [p, t] = get2(left, right); p && t) + { + if (p->type == PrimitiveType::Table) + return right; // table & {} ~ {} + else + return builtinTypes->neverType; // string & {} ~ never + } + else if (auto [t, p] = get2(left, right); t && p) + return intersectionType(right, left); // {} & P ~ P & {} + else if (auto [t, p] = get2(left, right); t && p) + return intersectionType(right, left); // M & P ~ P & M else if (auto [s1, s2] = get2(left, right); s1 && s2) { if (*s1 == *s2) @@ -472,6 +484,20 @@ std::optional TypeReducer::intersectionType(TypeId left, TypeId right) else return right; // ~Base & Unrelated ~ Unrelated } + else if (auto [np, t] = get2(nlTy, right); np && t) + { + if (np->type == PrimitiveType::Table) + return builtinTypes->neverType; // ~table & {} ~ never + else + return right; // ~string & {} ~ {} + } + else if (auto [np, t] = get2(nlTy, right); np && t) + { + if (np->type == PrimitiveType::Table) + return builtinTypes->neverType; // ~table & {} ~ never + else + return right; // ~string & {} ~ {} + } else return std::nullopt; // TODO } @@ -529,6 +555,24 @@ std::optional TypeReducer::unionType(TypeId left, TypeId right) } else if (auto [f, p] = get2(left, right); f && p) return unionType(right, left); // () -> () | P ~ P | () -> () + else if (auto [p, t] = get2(left, right); p && t) + { + if (p->type == PrimitiveType::Table) + return left; // table | {} ~ table + else + return std::nullopt; // P | {} ~ P | {} + } + else if (auto [p, t] = get2(left, right); p && t) + { + if (p->type == PrimitiveType::Table) + return left; // table | {} ~ table + else + return std::nullopt; // P | {} ~ P | {} + } + else if (auto [t, p] = get2(left, right); t && p) + return unionType(right, left); // {} | P ~ P | {} + else if (auto [t, p] = get2(left, right); t && p) + return unionType(right, left); // M | P ~ P | M else if (auto [s1, s2] = get2(left, right); s1 && s2) { if (*s1 == *s2) @@ -642,6 +686,20 @@ std::optional TypeReducer::unionType(TypeId left, TypeId right) else return left; // ~Base | Unrelated ~ ~Base } + else if (auto [np, t] = get2(nlTy, right); np && t) + { + if (np->type == PrimitiveType::Table) + return std::nullopt; // ~table | {} ~ ~table | {} + else + return right; // ~P | {} ~ ~P | {} + } + else if (auto [np, t] = get2(nlTy, right); np && t) + { + if (np->type == PrimitiveType::Table) + return std::nullopt; // ~table | {} ~ ~table | {} + else + return right; // ~P | M ~ ~P | M + } else return std::nullopt; // TODO } @@ -850,26 +908,6 @@ TypePackId TypeReducer::memoize(TypePackId tp, TypePackId reducedTp) return reducedTp; } -TypeId TypeReducer::memoizedOr(TypeId ty) const -{ - ty = follow(ty); - - if (auto ctx = memoizedTypes->find(ty)) - return ctx->type; - else - return ty; -}; - -TypePackId TypeReducer::memoizedOr(TypePackId tp) const -{ - tp = follow(tp); - - if (auto ctx = memoizedTypePacks->find(tp)) - return ctx->type; - else - return tp; -}; - struct MarkCycles : TypeVisitor { DenseHashSet cyclicTypes{nullptr}; diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 80f63f10a..e6d614411 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -15,10 +15,8 @@ #include -LUAU_FASTINT(LuauTypeInferTypePackLoopLimit); -LUAU_FASTFLAG(LuauErrorRecoveryType); -LUAU_FASTFLAG(LuauUnknownAndNeverType) -LUAU_FASTFLAGVARIABLE(LuauScalarShapeSubtyping, false) +LUAU_FASTINT(LuauTypeInferTypePackLoopLimit) +LUAU_FASTFLAG(LuauErrorRecoveryType) LUAU_FASTFLAGVARIABLE(LuauUnifyAnyTxnLog, false) LUAU_FASTFLAGVARIABLE(LuauInstantiateInSubtyping, false) LUAU_FASTFLAGVARIABLE(LuauScalarShapeUnifyToMtOwner2, false) @@ -28,6 +26,7 @@ LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) LUAU_FASTFLAG(LuauNegatedFunctionTypes) LUAU_FASTFLAG(LuauNegatedClassTypes) +LUAU_FASTFLAG(LuauNegatedTableTypes) namespace Luau { @@ -452,13 +451,10 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool } else if (subFree) { - if (FFlag::LuauUnknownAndNeverType) - { - // Normally, if the subtype is free, it should not be bound to any, unknown, or error types. - // But for bug compatibility, we'll only apply this rule to unknown. Doing this will silence cascading type errors. - if (log.get(superTy)) - return; - } + // Normally, if the subtype is free, it should not be bound to any, unknown, or error types. + // But for bug compatibility, we'll only apply this rule to unknown. Doing this will silence cascading type errors. + if (log.get(superTy)) + return; // Unification can't change the level of a generic. auto superGeneric = log.getMutable(superTy); @@ -569,6 +565,11 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool // Ok. Do nothing. forall functions F, F <: function } + else if (FFlag::LuauNegatedTableTypes && isPrim(superTy, PrimitiveType::Table) && (get(subTy) || get(subTy))) + { + // Ok, do nothing: forall tables T, T <: table + } + else if (log.getMutable(superTy) && log.getMutable(subTy)) tryUnifyFunctions(subTy, superTy, isFunctionCall); @@ -576,11 +577,11 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool { tryUnifyTables(subTy, superTy, isIntersection); } - else if (FFlag::LuauScalarShapeSubtyping && log.get(superTy) && (log.get(subTy) || log.get(subTy))) + else if (log.get(superTy) && (log.get(subTy) || log.get(subTy))) { tryUnifyScalarShape(subTy, superTy, /*reversed*/ false); } - else if (FFlag::LuauScalarShapeSubtyping && log.get(subTy) && (log.get(superTy) || log.get(superTy))) + else if (log.get(subTy) && (log.get(superTy) || log.get(superTy))) { tryUnifyScalarShape(subTy, superTy, /*reversed*/ true); } @@ -1032,6 +1033,12 @@ void Unifier::tryUnifyNormalizedTypes( bool found = false; for (TypeId superTable : superNorm.tables) { + if (FFlag::LuauNegatedTableTypes && isPrim(superTable, PrimitiveType::Table)) + { + found = true; + break; + } + Unifier innerState = makeChildUnifier(); if (get(superTable)) innerState.tryUnifyWithMetatable(subTable, superTable, /* reversed */ false); @@ -2031,8 +2038,6 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) void Unifier::tryUnifyScalarShape(TypeId subTy, TypeId superTy, bool reversed) { - LUAU_ASSERT(FFlag::LuauScalarShapeSubtyping); - TypeId osubTy = subTy; TypeId osuperTy = superTy; @@ -2490,22 +2495,14 @@ void Unifier::tryUnifyWithAny(TypeId subTy, TypeId anyTy) return; } - TypePackId anyTp; - if (FFlag::LuauUnknownAndNeverType) - anyTp = types->addTypePack(TypePackVar{VariadicTypePack{anyTy}}); - else - { - const TypePackId anyTypePack = types->addTypePack(TypePackVar{VariadicTypePack{builtinTypes->anyType}}); - anyTp = get(anyTy) ? anyTypePack : types->addTypePack(TypePackVar{Unifiable::Error{}}); - } + TypePackId anyTp = types->addTypePack(TypePackVar{VariadicTypePack{anyTy}}); std::vector queue = {subTy}; sharedState.tempSeenTy.clear(); sharedState.tempSeenTp.clear(); - Luau::tryUnifyWithAny( - queue, *this, sharedState.tempSeenTy, sharedState.tempSeenTp, types, FFlag::LuauUnknownAndNeverType ? anyTy : builtinTypes->anyType, anyTp); + Luau::tryUnifyWithAny(queue, *this, sharedState.tempSeenTy, sharedState.tempSeenTp, types, anyTy, anyTp); } void Unifier::tryUnifyWithAny(TypePackId subTy, TypePackId anyTp) diff --git a/Ast/include/Luau/Ast.h b/Ast/include/Luau/Ast.h index 7731312db..81221dd10 100644 --- a/Ast/include/Luau/Ast.h +++ b/Ast/include/Luau/Ast.h @@ -52,246 +52,12 @@ struct AstName } }; -class AstVisitor -{ -public: - virtual ~AstVisitor() {} - - virtual bool visit(class AstNode*) - { - return true; - } - - virtual bool visit(class AstExpr* node) - { - return visit((class AstNode*)node); - } - - virtual bool visit(class AstExprGroup* node) - { - return visit((class AstExpr*)node); - } - virtual bool visit(class AstExprConstantNil* node) - { - return visit((class AstExpr*)node); - } - virtual bool visit(class AstExprConstantBool* node) - { - return visit((class AstExpr*)node); - } - virtual bool visit(class AstExprConstantNumber* node) - { - return visit((class AstExpr*)node); - } - virtual bool visit(class AstExprConstantString* node) - { - return visit((class AstExpr*)node); - } - virtual bool visit(class AstExprLocal* node) - { - return visit((class AstExpr*)node); - } - virtual bool visit(class AstExprGlobal* node) - { - return visit((class AstExpr*)node); - } - virtual bool visit(class AstExprVarargs* node) - { - return visit((class AstExpr*)node); - } - virtual bool visit(class AstExprCall* node) - { - return visit((class AstExpr*)node); - } - virtual bool visit(class AstExprIndexName* node) - { - return visit((class AstExpr*)node); - } - virtual bool visit(class AstExprIndexExpr* node) - { - return visit((class AstExpr*)node); - } - virtual bool visit(class AstExprFunction* node) - { - return visit((class AstExpr*)node); - } - virtual bool visit(class AstExprTable* node) - { - return visit((class AstExpr*)node); - } - virtual bool visit(class AstExprUnary* node) - { - return visit((class AstExpr*)node); - } - virtual bool visit(class AstExprBinary* node) - { - return visit((class AstExpr*)node); - } - virtual bool visit(class AstExprTypeAssertion* node) - { - return visit((class AstExpr*)node); - } - virtual bool visit(class AstExprIfElse* node) - { - return visit((class AstExpr*)node); - } - virtual bool visit(class AstExprInterpString* node) - { - return visit((class AstExpr*)node); - } - virtual bool visit(class AstExprError* node) - { - return visit((class AstExpr*)node); - } - - virtual bool visit(class AstStat* node) - { - return visit((class AstNode*)node); - } - - virtual bool visit(class AstStatBlock* node) - { - return visit((class AstStat*)node); - } - virtual bool visit(class AstStatIf* node) - { - return visit((class AstStat*)node); - } - virtual bool visit(class AstStatWhile* node) - { - return visit((class AstStat*)node); - } - virtual bool visit(class AstStatRepeat* node) - { - return visit((class AstStat*)node); - } - virtual bool visit(class AstStatBreak* node) - { - return visit((class AstStat*)node); - } - virtual bool visit(class AstStatContinue* node) - { - return visit((class AstStat*)node); - } - virtual bool visit(class AstStatReturn* node) - { - return visit((class AstStat*)node); - } - virtual bool visit(class AstStatExpr* node) - { - return visit((class AstStat*)node); - } - virtual bool visit(class AstStatLocal* node) - { - return visit((class AstStat*)node); - } - virtual bool visit(class AstStatFor* node) - { - return visit((class AstStat*)node); - } - virtual bool visit(class AstStatForIn* node) - { - return visit((class AstStat*)node); - } - virtual bool visit(class AstStatAssign* node) - { - return visit((class AstStat*)node); - } - virtual bool visit(class AstStatCompoundAssign* node) - { - return visit((class AstStat*)node); - } - virtual bool visit(class AstStatFunction* node) - { - return visit((class AstStat*)node); - } - virtual bool visit(class AstStatLocalFunction* node) - { - return visit((class AstStat*)node); - } - virtual bool visit(class AstStatTypeAlias* node) - { - return visit((class AstStat*)node); - } - virtual bool visit(class AstStatDeclareFunction* node) - { - return visit((class AstStat*)node); - } - virtual bool visit(class AstStatDeclareGlobal* node) - { - return visit((class AstStat*)node); - } - virtual bool visit(class AstStatDeclareClass* node) - { - return visit((class AstStat*)node); - } - virtual bool visit(class AstStatError* node) - { - return visit((class AstStat*)node); - } - - // By default visiting type annotations is disabled; override this in your visitor if you need to! - virtual bool visit(class AstType* node) - { - return false; - } - - virtual bool visit(class AstTypeReference* node) - { - return visit((class AstType*)node); - } - virtual bool visit(class AstTypeTable* node) - { - return visit((class AstType*)node); - } - virtual bool visit(class AstTypeFunction* node) - { - return visit((class AstType*)node); - } - virtual bool visit(class AstTypeTypeof* node) - { - return visit((class AstType*)node); - } - virtual bool visit(class AstTypeUnion* node) - { - return visit((class AstType*)node); - } - virtual bool visit(class AstTypeIntersection* node) - { - return visit((class AstType*)node); - } - virtual bool visit(class AstTypeSingletonBool* node) - { - return visit((class AstType*)node); - } - virtual bool visit(class AstTypeSingletonString* node) - { - return visit((class AstType*)node); - } - virtual bool visit(class AstTypeError* node) - { - return visit((class AstType*)node); - } - - virtual bool visit(class AstTypePack* node) - { - return false; - } - virtual bool visit(class AstTypePackExplicit* node) - { - return visit((class AstTypePack*)node); - } - virtual bool visit(class AstTypePackVariadic* node) - { - return visit((class AstTypePack*)node); - } - virtual bool visit(class AstTypePackGeneric* node) - { - return visit((class AstTypePack*)node); - } -}; - class AstType; +class AstVisitor; +class AstStat; +class AstStatBlock; +class AstExpr; +class AstTypePack; struct AstLocal { @@ -1277,6 +1043,245 @@ class AstTypePackGeneric : public AstTypePack AstName genericName; }; +class AstVisitor +{ +public: + virtual ~AstVisitor() {} + + virtual bool visit(class AstNode*) + { + return true; + } + + virtual bool visit(class AstExpr* node) + { + return visit(static_cast(node)); + } + + virtual bool visit(class AstExprGroup* node) + { + return visit(static_cast(node)); + } + virtual bool visit(class AstExprConstantNil* node) + { + return visit(static_cast(node)); + } + virtual bool visit(class AstExprConstantBool* node) + { + return visit(static_cast(node)); + } + virtual bool visit(class AstExprConstantNumber* node) + { + return visit(static_cast(node)); + } + virtual bool visit(class AstExprConstantString* node) + { + return visit(static_cast(node)); + } + virtual bool visit(class AstExprLocal* node) + { + return visit(static_cast(node)); + } + virtual bool visit(class AstExprGlobal* node) + { + return visit(static_cast(node)); + } + virtual bool visit(class AstExprVarargs* node) + { + return visit(static_cast(node)); + } + virtual bool visit(class AstExprCall* node) + { + return visit(static_cast(node)); + } + virtual bool visit(class AstExprIndexName* node) + { + return visit(static_cast(node)); + } + virtual bool visit(class AstExprIndexExpr* node) + { + return visit(static_cast(node)); + } + virtual bool visit(class AstExprFunction* node) + { + return visit(static_cast(node)); + } + virtual bool visit(class AstExprTable* node) + { + return visit(static_cast(node)); + } + virtual bool visit(class AstExprUnary* node) + { + return visit(static_cast(node)); + } + virtual bool visit(class AstExprBinary* node) + { + return visit(static_cast(node)); + } + virtual bool visit(class AstExprTypeAssertion* node) + { + return visit(static_cast(node)); + } + virtual bool visit(class AstExprIfElse* node) + { + return visit(static_cast(node)); + } + virtual bool visit(class AstExprInterpString* node) + { + return visit(static_cast(node)); + } + virtual bool visit(class AstExprError* node) + { + return visit(static_cast(node)); + } + + virtual bool visit(class AstStat* node) + { + return visit(static_cast(node)); + } + + virtual bool visit(class AstStatBlock* node) + { + return visit(static_cast(node)); + } + virtual bool visit(class AstStatIf* node) + { + return visit(static_cast(node)); + } + virtual bool visit(class AstStatWhile* node) + { + return visit(static_cast(node)); + } + virtual bool visit(class AstStatRepeat* node) + { + return visit(static_cast(node)); + } + virtual bool visit(class AstStatBreak* node) + { + return visit(static_cast(node)); + } + virtual bool visit(class AstStatContinue* node) + { + return visit(static_cast(node)); + } + virtual bool visit(class AstStatReturn* node) + { + return visit(static_cast(node)); + } + virtual bool visit(class AstStatExpr* node) + { + return visit(static_cast(node)); + } + virtual bool visit(class AstStatLocal* node) + { + return visit(static_cast(node)); + } + virtual bool visit(class AstStatFor* node) + { + return visit(static_cast(node)); + } + virtual bool visit(class AstStatForIn* node) + { + return visit(static_cast(node)); + } + virtual bool visit(class AstStatAssign* node) + { + return visit(static_cast(node)); + } + virtual bool visit(class AstStatCompoundAssign* node) + { + return visit(static_cast(node)); + } + virtual bool visit(class AstStatFunction* node) + { + return visit(static_cast(node)); + } + virtual bool visit(class AstStatLocalFunction* node) + { + return visit(static_cast(node)); + } + virtual bool visit(class AstStatTypeAlias* node) + { + return visit(static_cast(node)); + } + virtual bool visit(class AstStatDeclareFunction* node) + { + return visit(static_cast(node)); + } + virtual bool visit(class AstStatDeclareGlobal* node) + { + return visit(static_cast(node)); + } + virtual bool visit(class AstStatDeclareClass* node) + { + return visit(static_cast(node)); + } + virtual bool visit(class AstStatError* node) + { + return visit(static_cast(node)); + } + + // By default visiting type annotations is disabled; override this in your visitor if you need to! + virtual bool visit(class AstType* node) + { + return false; + } + + virtual bool visit(class AstTypeReference* node) + { + return visit(static_cast(node)); + } + virtual bool visit(class AstTypeTable* node) + { + return visit(static_cast(node)); + } + virtual bool visit(class AstTypeFunction* node) + { + return visit(static_cast(node)); + } + virtual bool visit(class AstTypeTypeof* node) + { + return visit(static_cast(node)); + } + virtual bool visit(class AstTypeUnion* node) + { + return visit(static_cast(node)); + } + virtual bool visit(class AstTypeIntersection* node) + { + return visit(static_cast(node)); + } + virtual bool visit(class AstTypeSingletonBool* node) + { + return visit(static_cast(node)); + } + virtual bool visit(class AstTypeSingletonString* node) + { + return visit(static_cast(node)); + } + virtual bool visit(class AstTypeError* node) + { + return visit(static_cast(node)); + } + + virtual bool visit(class AstTypePack* node) + { + return false; + } + virtual bool visit(class AstTypePackExplicit* node) + { + return visit(static_cast(node)); + } + virtual bool visit(class AstTypePackVariadic* node) + { + return visit(static_cast(node)); + } + virtual bool visit(class AstTypePackGeneric* node) + { + return visit(static_cast(node)); + } +}; + AstName getIdentifier(AstExpr*); Location getLocation(const AstTypeList& typeList); diff --git a/Ast/src/Lexer.cpp b/Ast/src/Lexer.cpp index 66436acde..118b06798 100644 --- a/Ast/src/Lexer.cpp +++ b/Ast/src/Lexer.cpp @@ -6,8 +6,6 @@ #include -LUAU_FASTFLAG(LuauInterpolatedStringBaseSupport) - namespace Luau { @@ -835,13 +833,7 @@ Lexeme Lexer::readNext() return readQuotedString(); case '`': - if (FFlag::LuauInterpolatedStringBaseSupport) - return readInterpolatedStringBegin(); - else - { - consume(); - return Lexeme(Location(start, 1), '`'); - } + return readInterpolatedStringBegin(); case '.': consume(); diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index c71bd7c58..99a41938b 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -17,8 +17,6 @@ LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100) LUAU_FASTFLAGVARIABLE(LuauErrorDoubleHexPrefix, false) LUAU_DYNAMIC_FASTFLAGVARIABLE(LuaReportParseIntegerIssues, false) -LUAU_FASTFLAGVARIABLE(LuauInterpolatedStringBaseSupport, false) - LUAU_FASTFLAGVARIABLE(LuauParserErrorsOnMissingDefaultTypePackArgument, false) bool lua_telemetry_parsed_out_of_range_bin_integer = false; @@ -2174,11 +2172,11 @@ AstExpr* Parser::parseSimpleExpr() return parseNumber(); } else if (lexer.current().type == Lexeme::RawString || lexer.current().type == Lexeme::QuotedString || - (FFlag::LuauInterpolatedStringBaseSupport && lexer.current().type == Lexeme::InterpStringSimple)) + lexer.current().type == Lexeme::InterpStringSimple) { return parseString(); } - else if (FFlag::LuauInterpolatedStringBaseSupport && lexer.current().type == Lexeme::InterpStringBegin) + else if (lexer.current().type == Lexeme::InterpStringBegin) { return parseInterpString(); } diff --git a/CLI/Repl.cpp b/CLI/Repl.cpp index e567725e5..69a40356b 100644 --- a/CLI/Repl.cpp +++ b/CLI/Repl.cpp @@ -48,8 +48,10 @@ enum class CompileFormat Text, Binary, Remarks, - Codegen, - CodegenVerbose, + Codegen, // Prints annotated native code including IR and assembly + CodegenAsm, // Prints annotated native code assembly + CodegenIr, // Prints annotated native code IR + CodegenVerbose, // Prints annotated native code including IR, assembly and outlined code CodegenNull, Null }; @@ -716,7 +718,19 @@ static bool compileFile(const char* name, CompileFormat format, CompileStats& st try { Luau::BytecodeBuilder bcb; - Luau::CodeGen::AssemblyOptions options = {format == CompileFormat::CodegenNull, format == CompileFormat::Codegen, annotateInstruction, &bcb}; + + Luau::CodeGen::AssemblyOptions options; + options.outputBinary = format == CompileFormat::CodegenNull; + + if (!options.outputBinary) + { + options.includeAssembly = format != CompileFormat::CodegenIr; + options.includeIr = format != CompileFormat::CodegenAsm; + options.includeOutlinedCode = format == CompileFormat::CodegenVerbose; + } + + options.annotator = annotateInstruction; + options.annotatorContext = &bcb; if (format == CompileFormat::Text) { @@ -729,7 +743,8 @@ static bool compileFile(const char* name, CompileFormat format, CompileStats& st bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Source | Luau::BytecodeBuilder::Dump_Remarks); bcb.setDumpSource(*source); } - else if (format == CompileFormat::Codegen || format == CompileFormat::CodegenVerbose) + else if (format == CompileFormat::Codegen || format == CompileFormat::CodegenAsm || format == CompileFormat::CodegenIr || + format == CompileFormat::CodegenVerbose) { bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Source | Luau::BytecodeBuilder::Dump_Locals | Luau::BytecodeBuilder::Dump_Remarks); @@ -760,6 +775,8 @@ static bool compileFile(const char* name, CompileFormat format, CompileStats& st fwrite(bcb.getBytecode().data(), 1, bcb.getBytecode().size(), stdout); break; case CompileFormat::Codegen: + case CompileFormat::CodegenAsm: + case CompileFormat::CodegenIr: case CompileFormat::CodegenVerbose: printf("%s", getCodegenAssembly(name, bcb.getBytecode(), options).c_str()); break; @@ -850,6 +867,14 @@ int replMain(int argc, char** argv) { compileFormat = CompileFormat::Codegen; } + else if (strcmp(argv[1], "--compile=codegenasm") == 0) + { + compileFormat = CompileFormat::CodegenAsm; + } + else if (strcmp(argv[1], "--compile=codegenir") == 0) + { + compileFormat = CompileFormat::CodegenIr; + } else if (strcmp(argv[1], "--compile=codegenverbose") == 0) { compileFormat = CompileFormat::CodegenVerbose; diff --git a/CodeGen/include/Luau/CodeGen.h b/CodeGen/include/Luau/CodeGen.h index cef9ec7cb..84cf682f6 100644 --- a/CodeGen/include/Luau/CodeGen.h +++ b/CodeGen/include/Luau/CodeGen.h @@ -22,7 +22,10 @@ using annotatorFn = void (*)(void* context, std::string& result, int fid, int in struct AssemblyOptions { bool outputBinary = false; - bool skipOutlinedCode = false; + + bool includeAssembly = false; + bool includeIr = false; + bool includeOutlinedCode = false; // Optional annotator function can be provided to describe each instruction, it takes function id and sequential instruction id annotatorFn annotator = nullptr; diff --git a/CodeGen/src/CodeGen.cpp b/CodeGen/src/CodeGen.cpp index 405b92ddd..4ed950c07 100644 --- a/CodeGen/src/CodeGen.cpp +++ b/CodeGen/src/CodeGen.cpp @@ -13,6 +13,9 @@ #include "CodeGenX64.h" #include "EmitCommonX64.h" #include "EmitInstructionX64.h" +#include "IrAnalysis.h" +#include "IrBuilder.h" +#include "IrLoweringX64.h" #include "NativeState.h" #include "lapi.h" @@ -27,6 +30,8 @@ #endif #endif +LUAU_FASTFLAGVARIABLE(DebugUseOldCodegen, false) + namespace Luau { namespace CodeGen @@ -241,7 +246,7 @@ static int emitInst(AssemblyBuilderX64& build, NativeState& data, ModuleHelpers& skip = emitInstFastCall2K(build, pc, i, next); break; case LOP_FORNPREP: - emitInstForNPrep(build, pc, i, labelarr[i + 1 + LUAU_INSN_D(*pc)]); + emitInstForNPrep(build, pc, i, next, labelarr[i + 1 + LUAU_INSN_D(*pc)]); break; case LOP_FORNLOOP: emitInstForNLoop(build, pc, i, labelarr[i + 1 + LUAU_INSN_D(*pc)], next); @@ -404,7 +409,7 @@ static NativeProto* assembleFunction(AssemblyBuilderX64& build, NativeState& dat result->proto = proto; - if (build.logText) + if (options.includeAssembly || options.includeIr) { if (proto->debugname) build.logAppend("; function %s()", getstr(proto->debugname)); @@ -417,6 +422,38 @@ static NativeProto* assembleFunction(AssemblyBuilderX64& build, NativeState& dat build.logAppend("\n"); } + if (!FFlag::DebugUseOldCodegen) + { + build.align(kFunctionAlignment, AlignmentDataX64::Ud2); + + Label start = build.setLabel(); + + IrBuilder builder; + builder.buildFunctionIr(proto); + + updateUseInfo(builder.function); + + IrLoweringX64 lowering(build, helpers, data, proto, builder.function); + + lowering.lower(options); + + result->instTargets = new uintptr_t[proto->sizecode]; + + for (int i = 0; i < proto->sizecode; i++) + { + auto [irLocation, asmLocation] = builder.function.bcMapping[i]; + + result->instTargets[i] = irLocation == ~0u ? 0 : asmLocation - start.location; + } + + result->location = start.location; + + if (build.logText) + build.logAppend("\n"); + + return result; + } + std::vector(a) -> (() -> ()) | (a & ~(false?))... *TRUNCATED*"); - CHECK_EQ(toString(requireType("f2"), o), "(b) -> ((a) -> (() -> ()) | (a & ~(false?))... *TRUNCATED*"); - CHECK_EQ(toString(requireType("f3"), o), "(c) -> ((b) -> ((a) -> (() -> ()) | (a & ~(false?))... *TRUNCATED*"); + CHECK_EQ(toString(requireType("f1"), o), "(a) -> (() -> ()) | (a & ~false & ~nil)... *TRUNCATED*"); + CHECK_EQ(toString(requireType("f2"), o), "(b) -> ((a) -> (() -> ()) | (a & ~false & ~nil)... *TRUNCATED*"); + CHECK_EQ(toString(requireType("f3"), o), "(c) -> ((b) -> ((a) -> (() -> ()) | (a & ~false & ~nil)... *TRUNCATED*"); } else { @@ -321,9 +321,9 @@ TEST_CASE_FIXTURE(Fixture, "stringifying_type_is_still_capped_when_exhaustive") { o.maxTypeLength = 30; CHECK_EQ(toString(requireType("f0"), o), "() -> ()"); - CHECK_EQ(toString(requireType("f1"), o), "(a) -> (() -> ()) | (a & ~(false?))... *TRUNCATED*"); - CHECK_EQ(toString(requireType("f2"), o), "(b) -> ((a) -> (() -> ()) | (a & ~(false?))... *TRUNCATED*"); - CHECK_EQ(toString(requireType("f3"), o), "(c) -> ((b) -> ((a) -> (() -> ()) | (a & ~(false?))... *TRUNCATED*"); + CHECK_EQ(toString(requireType("f1"), o), "(a) -> (() -> ()) | (a & ~false & ~nil)... *TRUNCATED*"); + CHECK_EQ(toString(requireType("f2"), o), "(b) -> ((a) -> (() -> ()) | (a & ~false & ~nil)... *TRUNCATED*"); + CHECK_EQ(toString(requireType("f3"), o), "(c) -> ((b) -> ((a) -> (() -> ()) | (a & ~false & ~nil)... *TRUNCATED*"); } else { diff --git a/tests/TypeInfer.aliases.test.cpp b/tests/TypeInfer.aliases.test.cpp index cd0a06308..1eaec909c 100644 --- a/tests/TypeInfer.aliases.test.cpp +++ b/tests/TypeInfer.aliases.test.cpp @@ -357,8 +357,6 @@ TEST_CASE_FIXTURE(Fixture, "cli_38393_recursive_intersection_oom") type t0 = ((typeof(_))&((t0)&(((typeof(_))&(t0))->typeof(_))),{n163:any,})->(any,typeof(_)) _(_) )"); - - LUAU_REQUIRE_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "type_alias_fwd_declaration_is_precise") diff --git a/tests/TypeInfer.intersectionTypes.test.cpp b/tests/TypeInfer.intersectionTypes.test.cpp index e18a73788..ea6fff773 100644 --- a/tests/TypeInfer.intersectionTypes.test.cpp +++ b/tests/TypeInfer.intersectionTypes.test.cpp @@ -132,22 +132,40 @@ TEST_CASE_FIXTURE(Fixture, "should_still_pick_an_overload_whose_arguments_are_un TEST_CASE_FIXTURE(Fixture, "propagates_name") { - const std::string code = R"( - type A={a:number} - type B={b:string} + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CheckResult result = check(R"( + type A={a:number} + type B={b:string} - local c:A&B - local b = c - )"; - const std::string expected = R"( - type A={a:number} - type B={b:string} + local c:A&B + local b = c + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK("{| a: number, b: string |}" == toString(requireType("b"))); + } + else + { + const std::string code = R"( + type A={a:number} + type B={b:string} + + local c:A&B + local b = c + )"; + + const std::string expected = R"( + type A={a:number} + type B={b:string} - local c:A&B - local b:A&B=c - )"; + local c:A&B + local b:A&B=c + )"; - CHECK_EQ(expected, decorateWithTypes(code)); + CHECK_EQ(expected, decorateWithTypes(code)); + } } TEST_CASE_FIXTURE(Fixture, "index_on_an_intersection_type_with_property_guaranteed_to_exist") @@ -161,17 +179,10 @@ TEST_CASE_FIXTURE(Fixture, "index_on_an_intersection_type_with_property_guarante )"); LUAU_REQUIRE_NO_ERRORS(result); - - const IntersectionType* r = get(requireType("r")); - REQUIRE(r); - - TableType* a = getMutable(r->parts[0]); - REQUIRE(a); - CHECK_EQ(typeChecker.numberType, a->props["y"].type); - - TableType* b = getMutable(r->parts[1]); - REQUIRE(b); - CHECK_EQ(typeChecker.numberType, b->props["y"].type); + if (FFlag::DebugLuauDeferredConstraintResolution) + CHECK("{| y: number |}" == toString(requireType("r"))); + else + CHECK("{| y: number |} & {| y: number |}" == toString(requireType("r"))); } TEST_CASE_FIXTURE(Fixture, "index_on_an_intersection_type_works_at_arbitrary_depth") @@ -207,7 +218,10 @@ TEST_CASE_FIXTURE(Fixture, "index_on_an_intersection_type_with_mixed_types") )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("number & string", toString(requireType("r"))); + if (FFlag::DebugLuauDeferredConstraintResolution) + CHECK_EQ("never", toString(requireType("r"))); + else + CHECK_EQ("number & string", toString(requireType("r"))); } TEST_CASE_FIXTURE(Fixture, "index_on_an_intersection_type_with_one_part_missing_the_property") @@ -387,7 +401,10 @@ local a: XYZ = 3 )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), R"(Type 'number' could not be converted into 'X & Y & Z' + if (FFlag::DebugLuauDeferredConstraintResolution) + CHECK_EQ(toString(result.errors[0]), R"(Type 'number' could not be converted into '{| x: number, y: number, z: number |}')"); + else + CHECK_EQ(toString(result.errors[0]), R"(Type 'number' could not be converted into 'X & Y & Z' caused by: Not all intersection parts are compatible. Type 'number' could not be converted into 'X')"); } @@ -404,7 +421,11 @@ local b: number = a )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), R"(Type 'X & Y & Z' could not be converted into 'number'; none of the intersection parts are compatible)"); + if (FFlag::DebugLuauDeferredConstraintResolution) + CHECK_EQ(toString(result.errors[0]), R"(Type '{| x: number, y: number, z: number |}' could not be converted into 'number')"); + else + CHECK_EQ( + toString(result.errors[0]), R"(Type 'X & Y & Z' could not be converted into 'number'; none of the intersection parts are compatible)"); } TEST_CASE_FIXTURE(Fixture, "overload_is_not_a_function") @@ -444,7 +465,11 @@ TEST_CASE_FIXTURE(Fixture, "intersect_bool_and_false") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), "Type 'boolean & false' could not be converted into 'true'; none of the intersection parts are compatible"); + if (FFlag::DebugLuauDeferredConstraintResolution) + CHECK_EQ(toString(result.errors[0]), "Type 'false' could not be converted into 'true'"); + else + CHECK_EQ( + toString(result.errors[0]), "Type 'boolean & false' could not be converted into 'true'; none of the intersection parts are compatible"); } TEST_CASE_FIXTURE(Fixture, "intersect_false_and_bool_and_false") @@ -456,9 +481,14 @@ TEST_CASE_FIXTURE(Fixture, "intersect_false_and_bool_and_false") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - // TODO: odd stringification of `false & (boolean & false)`.) - CHECK_EQ(toString(result.errors[0]), - "Type 'boolean & false & false' could not be converted into 'true'; none of the intersection parts are compatible"); + if (FFlag::DebugLuauDeferredConstraintResolution) + CHECK_EQ(toString(result.errors[0]), "Type 'false' could not be converted into 'true'"); + else + { + // TODO: odd stringification of `false & (boolean & false)`.) + CHECK_EQ(toString(result.errors[0]), + "Type 'boolean & false & false' could not be converted into 'true'; none of the intersection parts are compatible"); + } } TEST_CASE_FIXTURE(Fixture, "intersect_saturate_overloaded_functions") @@ -496,8 +526,21 @@ TEST_CASE_FIXTURE(Fixture, "intersection_of_tables") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), "Type '{| p: number?, q: number?, r: number? |} & {| p: number?, q: string? |}' could not be converted into " - "'{| p: nil |}'; none of the intersection parts are compatible"); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ(toString(result.errors[0]), + "Type '{| p: number?, q: nil, r: number? |}' could not be converted into '{| p: nil |}'\n" + "caused by:\n" + " Property 'p' is not compatible. Type 'number?' could not be converted into 'nil'\n" + "caused by:\n" + " Not all union options are compatible. Type 'number' could not be converted into 'nil' in an invariant context"); + } + else + { + CHECK_EQ(toString(result.errors[0]), + "Type '{| p: number?, q: number?, r: number? |} & {| p: number?, q: string? |}' could not be converted into " + "'{| p: nil |}'; none of the intersection parts are compatible"); + } } TEST_CASE_FIXTURE(Fixture, "intersection_of_tables_with_top_properties") @@ -508,9 +551,35 @@ TEST_CASE_FIXTURE(Fixture, "intersection_of_tables_with_top_properties") local z : { p : string?, q : number? } = x -- Not OK )"); - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), "Type '{| p: number?, q: any |} & {| p: unknown, q: string? |}' could not be converted into '{| p: string?, " - "q: number? |}'; none of the intersection parts are compatible"); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + LUAU_REQUIRE_ERROR_COUNT(2, result); + + CHECK_EQ(toString(result.errors[0]), + "Type '{| p: number?, q: string? |}' could not be converted into '{| p: string?, q: number? |}'\n" + "caused by:\n" + " Property 'p' is not compatible. Type 'number?' could not be converted into 'string?'\n" + "caused by:\n" + " Not all union options are compatible. Type 'number' could not be converted into 'string?'\n" + "caused by:\n" + " None of the union options are compatible. For example: Type 'number' could not be converted into 'string' in an invariant context"); + + CHECK_EQ(toString(result.errors[1]), + "Type '{| p: number?, q: string? |}' could not be converted into '{| p: string?, q: number? |}'\n" + "caused by:\n" + " Property 'q' is not compatible. Type 'string?' could not be converted into 'number?'\n" + "caused by:\n" + " Not all union options are compatible. Type 'string' could not be converted into 'number?'\n" + "caused by:\n" + " None of the union options are compatible. For example: Type 'string' could not be converted into 'number' in an invariant context"); + } + else + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), + "Type '{| p: number?, q: any |} & {| p: unknown, q: string? |}' could not be converted into '{| p: string?, " + "q: number? |}'; none of the intersection parts are compatible"); + } } TEST_CASE_FIXTURE(Fixture, "intersection_of_tables_with_never_properties") @@ -537,9 +606,18 @@ TEST_CASE_FIXTURE(Fixture, "overloaded_functions_returning_intersections") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), - "Type '((number?) -> {| p: number |} & {| q: number |}) & ((string?) -> {| p: number |} & {| r: number |})' could not be converted into " - "'(number?) -> {| p: number, q: number, r: number |}'; none of the intersection parts are compatible"); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ(toString(result.errors[0]), + "Type '((number?) -> {| p: number, q: number |}) & ((string?) -> {| p: number, r: number |})' could not be converted into " + "'(number?) -> {| p: number, q: number, r: number |}'; none of the intersection parts are compatible"); + } + else + { + CHECK_EQ(toString(result.errors[0]), + "Type '((number?) -> {| p: number |} & {| q: number |}) & ((string?) -> {| p: number |} & {| r: number |})' could not be converted into " + "'(number?) -> {| p: number, q: number, r: number |}'; none of the intersection parts are compatible"); + } } TEST_CASE_FIXTURE(Fixture, "overloaded_functions_mentioning_generic") @@ -840,7 +918,7 @@ TEST_CASE_FIXTURE(Fixture, "less_greedy_unification_with_intersection_types") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("({| x: number |} & {| x: string |}) -> {| x: number |} & {| x: string |}", toString(requireType("f"))); + CHECK_EQ("(never) -> never", toString(requireType("f"))); } TEST_CASE_FIXTURE(Fixture, "less_greedy_unification_with_intersection_types_2") @@ -856,7 +934,7 @@ TEST_CASE_FIXTURE(Fixture, "less_greedy_unification_with_intersection_types_2") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("({| x: number |} & {| x: string |}) -> number & string", toString(requireType("f"))); + CHECK_EQ("(never) -> never", toString(requireType("f"))); } TEST_SUITE_END(); diff --git a/tests/TypeInfer.loops.test.cpp b/tests/TypeInfer.loops.test.cpp index 7a89df96c..30cbe1d5b 100644 --- a/tests/TypeInfer.loops.test.cpp +++ b/tests/TypeInfer.loops.test.cpp @@ -661,4 +661,32 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "loop_iter_metamethod_ok_with_inference") CHECK(toString(requireType("b")) == "string"); } +TEST_CASE_FIXTURE(Fixture, "for_loop_lower_bound_is_string") +{ + CheckResult result = check(R"( + for i: unknown = 1, 10 do end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "for_loop_lower_bound_is_string_2") +{ + CheckResult result = check(R"( + for i: never = 1, 10 do end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type 'number' could not be converted into 'never'", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "for_loop_lower_bound_is_string_3") +{ + CheckResult result = check(R"( + for i: number | string = 1, 10 do end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.operators.test.cpp b/tests/TypeInfer.operators.test.cpp index d4350fdea..feb04c29b 100644 --- a/tests/TypeInfer.operators.test.cpp +++ b/tests/TypeInfer.operators.test.cpp @@ -25,16 +25,8 @@ TEST_CASE_FIXTURE(Fixture, "or_joins_types") local x:string|number = s )"); LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::DebugLuauDeferredConstraintResolution) - { - CHECK_EQ(toString(*requireType("s")), "(string & ~(false?)) | number"); - CHECK_EQ(toString(*requireType("x")), "number | string"); - } - else - { - CHECK_EQ(toString(*requireType("s")), "number | string"); - CHECK_EQ(toString(*requireType("x")), "number | string"); - } + CHECK_EQ(toString(*requireType("s")), "number | string"); + CHECK_EQ(toString(*requireType("x")), "number | string"); } TEST_CASE_FIXTURE(Fixture, "or_joins_types_with_no_extras") @@ -45,16 +37,8 @@ TEST_CASE_FIXTURE(Fixture, "or_joins_types_with_no_extras") local y = x or "s" )"); LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::DebugLuauDeferredConstraintResolution) - { - CHECK_EQ(toString(*requireType("s")), "(string & ~(false?)) | number"); - CHECK_EQ(toString(*requireType("y")), "((number | string) & ~(false?)) | string"); - } - else - { - CHECK_EQ(toString(*requireType("s")), "number | string"); - CHECK_EQ(toString(*requireType("y")), "number | string"); - } + CHECK_EQ(toString(*requireType("s")), "number | string"); + CHECK_EQ(toString(*requireType("y")), "number | string"); } TEST_CASE_FIXTURE(Fixture, "or_joins_types_with_no_superfluous_union") @@ -78,14 +62,7 @@ TEST_CASE_FIXTURE(Fixture, "and_does_not_always_add_boolean") local x:boolean|number = s )"); LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::DebugLuauDeferredConstraintResolution) - { - CHECK_EQ(toString(*requireType("s")), "((false?) & string) | number"); - } - else - { - CHECK_EQ(toString(*requireType("s")), "number"); - } + CHECK_EQ(toString(*requireType("s")), "number"); } TEST_CASE_FIXTURE(Fixture, "and_adds_boolean_no_superfluous_union") @@ -104,14 +81,7 @@ TEST_CASE_FIXTURE(Fixture, "and_or_ternary") local s = (1/2) > 0.5 and "a" or 10 )"); LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::DebugLuauDeferredConstraintResolution) - { - CHECK_EQ(toString(*requireType("s")), "((((false?) & boolean) | string) & ~(false?)) | number"); - } - else - { - CHECK_EQ(toString(*requireType("s")), "number | string"); - } + CHECK_EQ(toString(*requireType("s")), "number | string"); } TEST_CASE_FIXTURE(BuiltinsFixture, "primitive_arith_no_metatable") @@ -833,14 +803,7 @@ local b: number = 1 or a TypeMismatch* tm = get(result.errors[0]); REQUIRE(tm); CHECK_EQ(typeChecker.numberType, tm->wantedType); - if (FFlag::DebugLuauDeferredConstraintResolution) - { - CHECK_EQ("((number & ~(false?)) | number)?", toString(tm->givenType)); - } - else - { - CHECK_EQ("number?", toString(tm->givenType)); - } + CHECK_EQ("number?", toString(tm->givenType)); } TEST_CASE_FIXTURE(Fixture, "operator_eq_verifies_types_do_intersect") @@ -901,14 +864,7 @@ TEST_CASE_FIXTURE(Fixture, "refine_and_or") LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::DebugLuauDeferredConstraintResolution) - { - CHECK_EQ("((((false?) & ({| x: number? |}?)) | a) & ~(false?)) | number", toString(requireType("u"))); - } - else - { - CHECK_EQ("number", toString(requireType("u"))); - } + CHECK_EQ("number", toString(requireType("u"))); } TEST_CASE_FIXTURE(Fixture, "infer_any_in_all_modes_when_lhs_is_unknown") @@ -1095,20 +1051,16 @@ local z = b and 1 local w = c and 1 )"); + CHECK("number?" == toString(requireType("x"))); + CHECK("number" == toString(requireType("y"))); if (FFlag::DebugLuauDeferredConstraintResolution) - { - CHECK("((false?) & (number?)) | number" == toString(requireType("x"))); - CHECK("((false?) & string) | number" == toString(requireType("y"))); - CHECK("((false?) & boolean) | number" == toString(requireType("z"))); - CHECK("((false?) & a) | number" == toString(requireType("w"))); - } + CHECK("false | number" == toString(requireType("z"))); else - { - CHECK("number?" == toString(requireType("x"))); - CHECK("number" == toString(requireType("y"))); CHECK("boolean | number" == toString(requireType("z"))); // 'false' widened to boolean + if (FFlag::DebugLuauDeferredConstraintResolution) + CHECK("((false?) & a) | number" == toString(requireType("w"))); + else CHECK("(boolean | number)?" == toString(requireType("w"))); - } } TEST_CASE_FIXTURE(BuiltinsFixture, "reworked_or") @@ -1133,24 +1085,20 @@ local e1 = e or 'e' local f1 = f or 'f' )"); + CHECK("number | string" == toString(requireType("a1"))); + CHECK("number" == toString(requireType("b1"))); if (FFlag::DebugLuauDeferredConstraintResolution) { - CHECK("((false | number) & ~(false?)) | string" == toString(requireType("a1"))); - CHECK("((number?) & ~(false?)) | number" == toString(requireType("b1"))); - CHECK("(boolean & ~(false?)) | string" == toString(requireType("c1"))); - CHECK("(true & ~(false?)) | string" == toString(requireType("d1"))); - CHECK("(false & ~(false?)) | string" == toString(requireType("e1"))); - CHECK("(nil & ~(false?)) | string" == toString(requireType("f1"))); + CHECK("string | true" == toString(requireType("c1"))); + CHECK("string | true" == toString(requireType("d1"))); } else { - CHECK("number | string" == toString(requireType("a1"))); - CHECK("number" == toString(requireType("b1"))); CHECK("boolean | string" == toString(requireType("c1"))); // 'true' widened to boolean CHECK("boolean | string" == toString(requireType("d1"))); // 'true' widened to boolean - CHECK("string" == toString(requireType("e1"))); - CHECK("string" == toString(requireType("f1"))); } + CHECK("string" == toString(requireType("e1"))); + CHECK("string" == toString(requireType("f1"))); } TEST_SUITE_END(); diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index 937520219..7d629f715 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -594,28 +594,6 @@ return wrapStrictTable(Constants, "Constants") CHECK(get(*result)); } -// We need a simplification step to make this do the right thing. ("normalization-lite") -TEST_CASE_FIXTURE(BuiltinsFixture, "table_insert_with_a_singleton_argument") -{ - CheckResult result = check(R"( - local function foo(t, x) - if x == "hi" or x == "bye" then - table.insert(t, x) - end - - return t - end - - local t = foo({}, "hi") - table.insert(t, "totally_unrelated_type" :: "totally_unrelated_type") - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - // We'd really like for this to be {string} - CHECK_EQ("{string | string}", toString(requireType("t"))); -} - namespace { struct IsSubtypeFixture : Fixture @@ -814,4 +792,44 @@ caused by: } } +TEST_CASE_FIXTURE(BuiltinsFixture, "table_insert_with_a_singleton_argument") +{ + CheckResult result = check(R"( + local function foo(t, x) + if x == "hi" or x == "bye" then + table.insert(t, x) + end + + return t + end + + local t = foo({}, "hi") + table.insert(t, "totally_unrelated_type" :: "totally_unrelated_type") + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + if (FFlag::DebugLuauDeferredConstraintResolution) + CHECK_EQ("{string}", toString(requireType("t"))); + else + { + // We'd really like for this to be {string} + CHECK_EQ("{string | string}", toString(requireType("t"))); + } +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "refine_unknown_to_table_then_clone_it") +{ + CheckResult result = check(R"( + local function f(x: unknown) + if typeof(x) == "table" then + local cloned: {} = table.clone(x) + end + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + // LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index 786b07d1b..43c0b38e8 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -36,7 +36,7 @@ std::optional> magicFunctionInstanceIsA( return WithPredicate{booleanPack, {IsAPredicate{std::move(*lvalue), expr.location, tfun->type}}}; } -std::vector dcrMagicRefinementInstanceIsA(const MagicRefinementContext& ctx) +std::vector dcrMagicRefinementInstanceIsA(const MagicRefinementContext& ctx) { if (ctx.callSite->args.size != 1) return {}; @@ -54,7 +54,7 @@ std::vector dcrMagicRefinementInstanceIsA(const MagicRefinementCon if (!tfun) return {}; - return {ctx.connectiveArena->proposition(*def, tfun->type)}; + return {ctx.refinementArena->proposition(*def, tfun->type)}; } struct RefinementClassFixture : BuiltinsFixture @@ -122,16 +122,8 @@ TEST_CASE_FIXTURE(Fixture, "is_truthy_constraint") LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::DebugLuauDeferredConstraintResolution) - { - CHECK_EQ("(string?) & ~(false?)", toString(requireTypeAtPosition({3, 26}))); - CHECK_EQ("(string?) & ~~(false?)", toString(requireTypeAtPosition({5, 26}))); - } - else - { - CHECK_EQ("string", toString(requireTypeAtPosition({3, 26}))); - CHECK_EQ("nil", toString(requireTypeAtPosition({5, 26}))); - } + CHECK_EQ("string", toString(requireTypeAtPosition({3, 26}))); + CHECK_EQ("nil", toString(requireTypeAtPosition({5, 26}))); } TEST_CASE_FIXTURE(Fixture, "invert_is_truthy_constraint") @@ -148,16 +140,8 @@ TEST_CASE_FIXTURE(Fixture, "invert_is_truthy_constraint") LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::DebugLuauDeferredConstraintResolution) - { - CHECK_EQ("(string?) & ~~(false?)", toString(requireTypeAtPosition({3, 26}))); - CHECK_EQ("(string?) & ~(false?)", toString(requireTypeAtPosition({5, 26}))); - } - else - { - CHECK_EQ("nil", toString(requireTypeAtPosition({3, 26}))); - CHECK_EQ("string", toString(requireTypeAtPosition({5, 26}))); - } + CHECK_EQ("nil", toString(requireTypeAtPosition({3, 26}))); + CHECK_EQ("string", toString(requireTypeAtPosition({5, 26}))); } TEST_CASE_FIXTURE(Fixture, "parenthesized_expressions_are_followed_through") @@ -174,16 +158,8 @@ TEST_CASE_FIXTURE(Fixture, "parenthesized_expressions_are_followed_through") LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::DebugLuauDeferredConstraintResolution) - { - CHECK_EQ("(string?) & ~~(false?)", toString(requireTypeAtPosition({3, 26}))); - CHECK_EQ("(string?) & ~(false?)", toString(requireTypeAtPosition({5, 26}))); - } - else - { - CHECK_EQ("nil", toString(requireTypeAtPosition({3, 26}))); - CHECK_EQ("string", toString(requireTypeAtPosition({5, 26}))); - } + CHECK_EQ("nil", toString(requireTypeAtPosition({3, 26}))); + CHECK_EQ("string", toString(requireTypeAtPosition({5, 26}))); } TEST_CASE_FIXTURE(Fixture, "and_constraint") @@ -202,16 +178,8 @@ TEST_CASE_FIXTURE(Fixture, "and_constraint") LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::DebugLuauDeferredConstraintResolution) - { - CHECK_EQ("(string?) & ~(false?)", toString(requireTypeAtPosition({3, 26}))); - CHECK_EQ("(number?) & ~(false?)", toString(requireTypeAtPosition({4, 26}))); - } - else - { - CHECK_EQ("string", toString(requireTypeAtPosition({3, 26}))); - CHECK_EQ("number", toString(requireTypeAtPosition({4, 26}))); - } + CHECK_EQ("string", toString(requireTypeAtPosition({3, 26}))); + CHECK_EQ("number", toString(requireTypeAtPosition({4, 26}))); CHECK_EQ("string?", toString(requireTypeAtPosition({6, 26}))); CHECK_EQ("number?", toString(requireTypeAtPosition({7, 26}))); @@ -236,16 +204,8 @@ TEST_CASE_FIXTURE(Fixture, "not_and_constraint") CHECK_EQ("string?", toString(requireTypeAtPosition({3, 26}))); CHECK_EQ("number?", toString(requireTypeAtPosition({4, 26}))); - if (FFlag::DebugLuauDeferredConstraintResolution) - { - CHECK_EQ("(string?) & ~(false?)", toString(requireTypeAtPosition({6, 26}))); - CHECK_EQ("(number?) & ~(false?)", toString(requireTypeAtPosition({7, 26}))); - } - else - { - CHECK_EQ("string", toString(requireTypeAtPosition({6, 26}))); - CHECK_EQ("number", toString(requireTypeAtPosition({7, 26}))); - } + CHECK_EQ("string", toString(requireTypeAtPosition({6, 26}))); + CHECK_EQ("number", toString(requireTypeAtPosition({7, 26}))); } TEST_CASE_FIXTURE(Fixture, "or_predicate_with_truthy_predicates") @@ -267,16 +227,8 @@ TEST_CASE_FIXTURE(Fixture, "or_predicate_with_truthy_predicates") CHECK_EQ("string?", toString(requireTypeAtPosition({3, 26}))); CHECK_EQ("number?", toString(requireTypeAtPosition({4, 26}))); - if (FFlag::DebugLuauDeferredConstraintResolution) - { - CHECK_EQ("(string?) & ~~(false?)", toString(requireTypeAtPosition({6, 26}))); - CHECK_EQ("(number?) & ~~(false?)", toString(requireTypeAtPosition({7, 26}))); - } - else - { - CHECK_EQ("nil", toString(requireTypeAtPosition({6, 26}))); - CHECK_EQ("nil", toString(requireTypeAtPosition({7, 26}))); - } + CHECK_EQ("nil", toString(requireTypeAtPosition({6, 26}))); + CHECK_EQ("nil", toString(requireTypeAtPosition({7, 26}))); } TEST_CASE_FIXTURE(Fixture, "a_and_b_or_a_and_c") @@ -297,26 +249,17 @@ TEST_CASE_FIXTURE(Fixture, "a_and_b_or_a_and_c") LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("string", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("number?", toString(requireTypeAtPosition({4, 28}))); if (FFlag::DebugLuauDeferredConstraintResolution) - { - CHECK_EQ("(string?) & (~(false?) | ~(false?))", toString(requireTypeAtPosition({3, 28}))); - CHECK_EQ("number?", toString(requireTypeAtPosition({4, 28}))); CHECK_EQ("boolean", toString(requireTypeAtPosition({5, 28}))); - - CHECK_EQ("string?", toString(requireTypeAtPosition({7, 28}))); - CHECK_EQ("number?", toString(requireTypeAtPosition({8, 28}))); - CHECK_EQ("boolean", toString(requireTypeAtPosition({9, 28}))); - } else - { - CHECK_EQ("string", toString(requireTypeAtPosition({3, 28}))); - CHECK_EQ("number?", toString(requireTypeAtPosition({4, 28}))); CHECK_EQ("true", toString(requireTypeAtPosition({5, 28}))); // oh no! :( - CHECK_EQ("string?", toString(requireTypeAtPosition({7, 28}))); - CHECK_EQ("number?", toString(requireTypeAtPosition({8, 28}))); - CHECK_EQ("boolean", toString(requireTypeAtPosition({9, 28}))); - } + CHECK_EQ("string?", toString(requireTypeAtPosition({7, 28}))); + CHECK_EQ("number?", toString(requireTypeAtPosition({8, 28}))); + CHECK_EQ("boolean", toString(requireTypeAtPosition({9, 28}))); } TEST_CASE_FIXTURE(Fixture, "type_assertion_expr_carry_its_constraints") @@ -357,14 +300,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "typeguard_in_if_condition_position") LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::DebugLuauDeferredConstraintResolution) - { - CHECK_EQ("any & number", toString(requireTypeAtPosition({3, 26}))); - } - else - { - CHECK_EQ("number", toString(requireTypeAtPosition({3, 26}))); - } + CHECK_EQ("number", toString(requireTypeAtPosition({3, 26}))); } TEST_CASE_FIXTURE(BuiltinsFixture, "typeguard_in_assert_position") @@ -433,8 +369,8 @@ TEST_CASE_FIXTURE(Fixture, "truthy_constraint_on_properties") if (FFlag::DebugLuauDeferredConstraintResolution) { - CHECK("{| x: number? |} & {| x: ~(false?) |}" == toString(requireTypeAtPosition({4, 23}))); - CHECK("(number?) & ~(false?)" == toString(requireTypeAtPosition({5, 26}))); + CHECK("{| x: number |}" == toString(requireTypeAtPosition({4, 23}))); + CHECK("number" == toString(requireTypeAtPosition({5, 26}))); } CHECK_EQ("number?", toString(requireType("bar"))); @@ -478,22 +414,11 @@ TEST_CASE_FIXTURE(Fixture, "lvalue_is_equal_to_another_lvalue") LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::DebugLuauDeferredConstraintResolution) - { - CHECK_EQ(toString(requireTypeAtPosition({3, 33})), "((number | string)?) & unknown"); // a == b - CHECK_EQ(toString(requireTypeAtPosition({3, 36})), "(boolean?) & unknown"); // a == b + CHECK_EQ(toString(requireTypeAtPosition({3, 33})), "(number | string)?"); // a == b + CHECK_EQ(toString(requireTypeAtPosition({3, 36})), "boolean?"); // a == b - CHECK_EQ(toString(requireTypeAtPosition({5, 33})), "((number | string)?) & unknown"); // a ~= b - CHECK_EQ(toString(requireTypeAtPosition({5, 36})), "(boolean?) & unknown"); // a ~= b - } - else - { - CHECK_EQ(toString(requireTypeAtPosition({3, 33})), "(number | string)?"); // a == b - CHECK_EQ(toString(requireTypeAtPosition({3, 36})), "boolean?"); // a == b - - CHECK_EQ(toString(requireTypeAtPosition({5, 33})), "(number | string)?"); // a ~= b - CHECK_EQ(toString(requireTypeAtPosition({5, 36})), "boolean?"); // a ~= b - } + CHECK_EQ(toString(requireTypeAtPosition({5, 33})), "(number | string)?"); // a ~= b + CHECK_EQ(toString(requireTypeAtPosition({5, 36})), "boolean?"); // a ~= b } TEST_CASE_FIXTURE(Fixture, "lvalue_is_equal_to_a_term") @@ -510,16 +435,8 @@ TEST_CASE_FIXTURE(Fixture, "lvalue_is_equal_to_a_term") LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::DebugLuauDeferredConstraintResolution) - { - CHECK_EQ(toString(requireTypeAtPosition({3, 28})), "((number | string)?) & unknown"); // a == 1 - CHECK_EQ(toString(requireTypeAtPosition({5, 28})), "((number | string)?) & unknown"); // a ~= 1 - } - else - { - CHECK_EQ(toString(requireTypeAtPosition({3, 28})), "(number | string)?"); // a == 1; - CHECK_EQ(toString(requireTypeAtPosition({5, 28})), "(number | string)?"); // a ~= 1 - } + CHECK_EQ(toString(requireTypeAtPosition({3, 28})), "(number | string)?"); // a == 1; + CHECK_EQ(toString(requireTypeAtPosition({5, 28})), "(number | string)?"); // a ~= 1 } TEST_CASE_FIXTURE(Fixture, "term_is_equal_to_an_lvalue") @@ -538,8 +455,8 @@ TEST_CASE_FIXTURE(Fixture, "term_is_equal_to_an_lvalue") if (FFlag::DebugLuauDeferredConstraintResolution) { - CHECK_EQ(toString(requireTypeAtPosition({3, 28})), R"("hello" & ((number | string)?))"); // a == "hello" - CHECK_EQ(toString(requireTypeAtPosition({5, 28})), R"(((number | string)?) & ~"hello")"); // a ~= "hello" + CHECK_EQ(toString(requireTypeAtPosition({3, 28})), R"("hello")"); // a == "hello" + CHECK_EQ(toString(requireTypeAtPosition({5, 28})), R"(((string & ~"hello") | number)?)"); // a ~= "hello" } else { @@ -562,16 +479,8 @@ TEST_CASE_FIXTURE(Fixture, "lvalue_is_not_nil") LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::DebugLuauDeferredConstraintResolution) - { - CHECK_EQ(toString(requireTypeAtPosition({3, 28})), "((number | string)?) & ~nil"); // a ~= nil - CHECK_EQ(toString(requireTypeAtPosition({5, 28})), "((number | string)?) & unknown"); // a == nil - } - else - { - CHECK_EQ(toString(requireTypeAtPosition({3, 28})), "number | string"); // a ~= nil - CHECK_EQ(toString(requireTypeAtPosition({5, 28})), "(number | string)?"); // a == nil - } + CHECK_EQ(toString(requireTypeAtPosition({3, 28})), "number | string"); // a ~= nil + CHECK_EQ(toString(requireTypeAtPosition({5, 28})), "(number | string)?"); // a == nil } TEST_CASE_FIXTURE(Fixture, "free_type_is_equal_to_an_lvalue") @@ -586,17 +495,8 @@ TEST_CASE_FIXTURE(Fixture, "free_type_is_equal_to_an_lvalue") LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::DebugLuauDeferredConstraintResolution) - { - ToStringOptions opts; - CHECK_EQ(toString(requireTypeAtPosition({3, 33}), opts), "a & unknown"); // a == b - CHECK_EQ(toString(requireTypeAtPosition({3, 36}), opts), "(string?) & unknown"); // a == b - } - else - { - CHECK_EQ(toString(requireTypeAtPosition({3, 33})), "a"); // a == b - CHECK_EQ(toString(requireTypeAtPosition({3, 36})), "string?"); // a == b - } + CHECK_EQ(toString(requireTypeAtPosition({3, 33})), "a"); // a == b + CHECK_EQ(toString(requireTypeAtPosition({3, 36})), "string?"); // a == b } TEST_CASE_FIXTURE(Fixture, "unknown_lvalue_is_not_synonymous_with_other_on_not_equal") @@ -611,16 +511,8 @@ TEST_CASE_FIXTURE(Fixture, "unknown_lvalue_is_not_synonymous_with_other_on_not_e LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::DebugLuauDeferredConstraintResolution) - { - CHECK_EQ(toString(requireTypeAtPosition({3, 33})), "any & unknown"); // a ~= b - CHECK_EQ(toString(requireTypeAtPosition({3, 36})), "({| x: number |}?) & unknown"); // a ~= b - } - else - { - CHECK_EQ(toString(requireTypeAtPosition({3, 33})), "any"); // a ~= b - CHECK_EQ(toString(requireTypeAtPosition({3, 36})), "{| x: number |}?"); // a ~= b - } + CHECK_EQ(toString(requireTypeAtPosition({3, 33})), "any"); // a ~= b + CHECK_EQ(toString(requireTypeAtPosition({3, 36})), "{| x: number |}?"); // a ~= b } TEST_CASE_FIXTURE(Fixture, "string_not_equal_to_string_or_nil") @@ -639,22 +531,11 @@ TEST_CASE_FIXTURE(Fixture, "string_not_equal_to_string_or_nil") LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::DebugLuauDeferredConstraintResolution) - { - CHECK_EQ(toString(requireTypeAtPosition({6, 29})), "string & unknown"); // a ~= b - CHECK_EQ(toString(requireTypeAtPosition({6, 32})), "(string?) & unknown"); // a ~= b + CHECK_EQ(toString(requireTypeAtPosition({6, 29})), "string"); // a ~= b + CHECK_EQ(toString(requireTypeAtPosition({6, 32})), "string?"); // a ~= b - CHECK_EQ(toString(requireTypeAtPosition({8, 29})), "string & unknown"); // a == b - CHECK_EQ(toString(requireTypeAtPosition({8, 32})), "(string?) & unknown"); // a == b - } - else - { - CHECK_EQ(toString(requireTypeAtPosition({6, 29})), "string"); // a ~= b - CHECK_EQ(toString(requireTypeAtPosition({6, 32})), "string?"); // a ~= b - - CHECK_EQ(toString(requireTypeAtPosition({8, 29})), "string"); // a == b - CHECK_EQ(toString(requireTypeAtPosition({8, 32})), "string?"); // a == b - } + CHECK_EQ(toString(requireTypeAtPosition({8, 29})), "string"); // a == b + CHECK_EQ(toString(requireTypeAtPosition({8, 32})), "string?"); // a == b } TEST_CASE_FIXTURE(Fixture, "narrow_property_of_a_bounded_variable") @@ -729,16 +610,8 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "typeguard_not_to_be_string") LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::DebugLuauDeferredConstraintResolution) - { - CHECK_EQ("(boolean | number | string) & ~string", toString(requireTypeAtPosition({3, 28}))); // type(x) ~= "string" - CHECK_EQ("(boolean | number | string) & string", toString(requireTypeAtPosition({5, 28}))); // type(x) == "string" - } - else - { - CHECK_EQ("boolean | number", toString(requireTypeAtPosition({3, 28}))); // type(x) ~= "string" - CHECK_EQ("string", toString(requireTypeAtPosition({5, 28}))); // type(x) == "string" - } + CHECK_EQ("boolean | number", toString(requireTypeAtPosition({3, 28}))); // type(x) ~= "string" + CHECK_EQ("string", toString(requireTypeAtPosition({5, 28}))); // type(x) == "string" } TEST_CASE_FIXTURE(BuiltinsFixture, "typeguard_narrows_for_table") @@ -773,16 +646,8 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "typeguard_narrows_for_functions") LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::DebugLuauDeferredConstraintResolution) - { - CHECK_EQ("(((number) -> string) | string) & function", toString(requireTypeAtPosition({3, 28}))); // type(x) == "function" - CHECK_EQ("(((number) -> string) | string) & ~function", toString(requireTypeAtPosition({5, 28}))); // type(x) ~= "function" - } - else - { - CHECK_EQ("(number) -> string", toString(requireTypeAtPosition({3, 28}))); // type(x) == "function" - CHECK_EQ("string", toString(requireTypeAtPosition({5, 28}))); // type(x) ~= "function" - } + CHECK_EQ("(number) -> string", toString(requireTypeAtPosition({3, 28}))); // type(x) == "function" + CHECK_EQ("string", toString(requireTypeAtPosition({5, 28}))); // type(x) ~= "function" } TEST_CASE_FIXTURE(BuiltinsFixture, "type_guard_can_filter_for_intersection_of_tables") @@ -821,16 +686,8 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "type_guard_can_filter_for_overloaded_functio LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::DebugLuauDeferredConstraintResolution) - { - CHECK_EQ("((((number) -> string) & ((string) -> number))?) & function", toString(requireTypeAtPosition({4, 28}))); - CHECK_EQ("((((number) -> string) & ((string) -> number))?) & ~function", toString(requireTypeAtPosition({6, 28}))); - } - else - { - CHECK_EQ("((number) -> string) & ((string) -> number)", toString(requireTypeAtPosition({4, 28}))); - CHECK_EQ("nil", toString(requireTypeAtPosition({6, 28}))); - } + CHECK_EQ("((number) -> string) & ((string) -> number)", toString(requireTypeAtPosition({4, 28}))); + CHECK_EQ("nil", toString(requireTypeAtPosition({6, 28}))); } TEST_CASE_FIXTURE(BuiltinsFixture, "type_guard_narrowed_into_nothingness") @@ -898,16 +755,8 @@ TEST_CASE_FIXTURE(Fixture, "not_a_and_not_b") LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::DebugLuauDeferredConstraintResolution) - { - CHECK_EQ("(number?) & ~~(false?)", toString(requireTypeAtPosition({3, 28}))); - CHECK_EQ("(number?) & ~~(false?)", toString(requireTypeAtPosition({4, 28}))); - } - else - { - CHECK_EQ("nil", toString(requireTypeAtPosition({3, 28}))); - CHECK_EQ("nil", toString(requireTypeAtPosition({4, 28}))); - } + CHECK_EQ("nil", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("nil", toString(requireTypeAtPosition({4, 28}))); } TEST_CASE_FIXTURE(Fixture, "not_a_and_not_b2") @@ -923,16 +772,8 @@ TEST_CASE_FIXTURE(Fixture, "not_a_and_not_b2") LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::DebugLuauDeferredConstraintResolution) - { - CHECK_EQ("(number?) & ~~(false?)", toString(requireTypeAtPosition({3, 28}))); - CHECK_EQ("(number?) & ~~(false?)", toString(requireTypeAtPosition({4, 28}))); - } - else - { - CHECK_EQ("nil", toString(requireTypeAtPosition({3, 28}))); - CHECK_EQ("nil", toString(requireTypeAtPosition({4, 28}))); - } + CHECK_EQ("nil", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("nil", toString(requireTypeAtPosition({4, 28}))); } TEST_CASE_FIXTURE(BuiltinsFixture, "either_number_or_string") @@ -947,14 +788,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "either_number_or_string") LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::DebugLuauDeferredConstraintResolution) - { - CHECK_EQ("(number | string) & any", toString(requireTypeAtPosition({3, 28}))); - } - else - { - CHECK_EQ("number | string", toString(requireTypeAtPosition({3, 28}))); - } + CHECK_EQ("number | string", toString(requireTypeAtPosition({3, 28}))); } TEST_CASE_FIXTURE(Fixture, "not_t_or_some_prop_of_t") @@ -984,16 +818,8 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "assert_a_to_be_truthy_then_assert_a_to_be_nu LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::DebugLuauDeferredConstraintResolution) - { - CHECK_EQ("((number | string)?) & ~(false?)", toString(requireTypeAtPosition({3, 18}))); - CHECK_EQ("((number | string)?) & ~(false?) & number", toString(requireTypeAtPosition({5, 18}))); - } - else - { - CHECK_EQ("number | string", toString(requireTypeAtPosition({3, 18}))); - CHECK_EQ("number", toString(requireTypeAtPosition({5, 18}))); - } + CHECK_EQ("number | string", toString(requireTypeAtPosition({3, 18}))); + CHECK_EQ("number", toString(requireTypeAtPosition({5, 18}))); } TEST_CASE_FIXTURE(BuiltinsFixture, "merge_should_be_fully_agnostic_of_hashmap_ordering") @@ -1012,14 +838,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "merge_should_be_fully_agnostic_of_hashmap_or LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::DebugLuauDeferredConstraintResolution) - { - CHECK_EQ("(string | table) & (string | {| x: string |}) & string", toString(requireTypeAtPosition({6, 28}))); - } - else - { - CHECK_EQ("string", toString(requireTypeAtPosition({6, 28}))); - } + CHECK_EQ("string", toString(requireTypeAtPosition({6, 28}))); } TEST_CASE_FIXTURE(BuiltinsFixture, "refine_the_correct_types_opposite_of_when_a_is_not_number_or_string") @@ -1036,16 +855,8 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "refine_the_correct_types_opposite_of_when_a_ LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::DebugLuauDeferredConstraintResolution) - { - CHECK_EQ("(boolean | number | string) & ~number & ~string", toString(requireTypeAtPosition({3, 28}))); - CHECK_EQ("(boolean | number | string) & (number | string)", toString(requireTypeAtPosition({5, 28}))); - } - else - { - CHECK_EQ("boolean", toString(requireTypeAtPosition({3, 28}))); - CHECK_EQ("number | string", toString(requireTypeAtPosition({5, 28}))); - } + CHECK_EQ("boolean", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("number | string", toString(requireTypeAtPosition({5, 28}))); } TEST_CASE_FIXTURE(BuiltinsFixture, "is_truthy_constraint_ifelse_expression") @@ -1058,16 +869,8 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "is_truthy_constraint_ifelse_expression") LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::DebugLuauDeferredConstraintResolution) - { - CHECK_EQ("(string?) & ~(false?)", toString(requireTypeAtPosition({2, 29}))); - CHECK_EQ("(string?) & ~~(false?)", toString(requireTypeAtPosition({2, 45}))); - } - else - { - CHECK_EQ("string", toString(requireTypeAtPosition({2, 29}))); - CHECK_EQ("nil", toString(requireTypeAtPosition({2, 45}))); - } + CHECK_EQ("string", toString(requireTypeAtPosition({2, 29}))); + CHECK_EQ("nil", toString(requireTypeAtPosition({2, 45}))); } TEST_CASE_FIXTURE(BuiltinsFixture, "invert_is_truthy_constraint_ifelse_expression") @@ -1080,16 +883,8 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "invert_is_truthy_constraint_ifelse_expressio LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::DebugLuauDeferredConstraintResolution) - { - CHECK_EQ("(string?) & ~~(false?)", toString(requireTypeAtPosition({2, 42}))); - CHECK_EQ("(string?) & ~(false?)", toString(requireTypeAtPosition({2, 50}))); - } - else - { - CHECK_EQ("nil", toString(requireTypeAtPosition({2, 42}))); - CHECK_EQ("string", toString(requireTypeAtPosition({2, 50}))); - } + CHECK_EQ("nil", toString(requireTypeAtPosition({2, 42}))); + CHECK_EQ("string", toString(requireTypeAtPosition({2, 50}))); } TEST_CASE_FIXTURE(BuiltinsFixture, "type_comparison_ifelse_expression") @@ -1106,16 +901,11 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "type_comparison_ifelse_expression") LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("number", toString(requireTypeAtPosition({6, 49}))); if (FFlag::DebugLuauDeferredConstraintResolution) - { - CHECK_EQ("any & number", toString(requireTypeAtPosition({6, 49}))); - CHECK_EQ("any & ~number", toString(requireTypeAtPosition({6, 66}))); - } + CHECK_EQ("~number", toString(requireTypeAtPosition({6, 66}))); else - { - CHECK_EQ("number", toString(requireTypeAtPosition({6, 49}))); CHECK_EQ("any", toString(requireTypeAtPosition({6, 66}))); - } } TEST_CASE_FIXTURE(BuiltinsFixture, "correctly_lookup_a_shadowed_local_that_which_was_previously_refined") @@ -1196,17 +986,11 @@ TEST_CASE_FIXTURE(Fixture, "discriminate_from_truthiness_of_x") LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ(R"({| tag: "exists", x: string |})", toString(requireTypeAtPosition({5, 28}))); if (FFlag::DebugLuauDeferredConstraintResolution) - { - CHECK_EQ(R"(({| tag: "exists", x: string |} | {| tag: "missing", x: nil |}) & {| x: ~(false?) |})", toString(requireTypeAtPosition({5, 28}))); - CHECK_EQ( - R"(({| tag: "exists", x: string |} | {| tag: "missing", x: nil |}) & {| x: ~~(false?) |})", toString(requireTypeAtPosition({7, 28}))); - } + CHECK_EQ(R"({| tag: "missing", x: nil |})", toString(requireTypeAtPosition({7, 28}))); else - { - CHECK_EQ(R"({| tag: "exists", x: string |})", toString(requireTypeAtPosition({5, 28}))); CHECK_EQ(R"({| tag: "exists", x: string |} | {| tag: "missing", x: nil |})", toString(requireTypeAtPosition({7, 28}))); - } } TEST_CASE_FIXTURE(Fixture, "discriminate_tag") @@ -1229,8 +1013,8 @@ TEST_CASE_FIXTURE(Fixture, "discriminate_tag") if (FFlag::DebugLuauDeferredConstraintResolution) { - CHECK_EQ(R"((Cat | Dog) & {| tag: "Cat" |})", toString(requireTypeAtPosition({7, 33}))); - CHECK_EQ(R"((Cat | Dog) & {| tag: ~"Cat" |} & {| tag: "Dog" |})", toString(requireTypeAtPosition({9, 33}))); + CHECK_EQ(R"({| catfood: string, name: string, tag: "Cat" |})", toString(requireTypeAtPosition({7, 33}))); + CHECK_EQ(R"({| dogfood: string, name: string, tag: "Dog" |})", toString(requireTypeAtPosition({9, 33}))); } else { @@ -1259,8 +1043,8 @@ TEST_CASE_FIXTURE(Fixture, "discriminate_tag_with_implicit_else") if (FFlag::DebugLuauDeferredConstraintResolution) { - CHECK_EQ(R"((Cat | Dog) & {| tag: "Cat" |})", toString(requireTypeAtPosition({7, 33}))); - CHECK_EQ(R"((Cat | Dog) & {| tag: ~"Cat" |})", toString(requireTypeAtPosition({9, 33}))); + CHECK_EQ(R"({| catfood: string, name: string, tag: "Cat" |})", toString(requireTypeAtPosition({7, 33}))); + CHECK_EQ(R"({| dogfood: string, name: string, tag: "Dog" |})", toString(requireTypeAtPosition({9, 33}))); } else { @@ -1294,16 +1078,8 @@ TEST_CASE_FIXTURE(Fixture, "narrow_boolean_to_true_or_false") LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::DebugLuauDeferredConstraintResolution) - { - CHECK_EQ("boolean & ~(false?)", toString(requireTypeAtPosition({3, 28}))); - CHECK_EQ("boolean & ~~(false?)", toString(requireTypeAtPosition({5, 28}))); - } - else - { - CHECK_EQ("true", toString(requireTypeAtPosition({3, 28}))); - CHECK_EQ("false", toString(requireTypeAtPosition({5, 28}))); - } + CHECK_EQ("true", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("false", toString(requireTypeAtPosition({5, 28}))); } TEST_CASE_FIXTURE(Fixture, "discriminate_on_properties_of_disjoint_tables_where_that_property_is_true_or_false") @@ -1355,16 +1131,8 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "discriminate_from_isa_of_x") LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::DebugLuauDeferredConstraintResolution) - { - CHECK_EQ(R"(({| tag: "Folder", x: Folder |} | {| tag: "Part", x: Part |}) & {| x: Part |})", toString(requireTypeAtPosition({5, 28}))); - CHECK_EQ(R"(({| tag: "Folder", x: Folder |} | {| tag: "Part", x: Part |}) & {| x: ~Part |})", toString(requireTypeAtPosition({7, 28}))); - } - else - { - CHECK_EQ(R"({| tag: "Part", x: Part |})", toString(requireTypeAtPosition({5, 28}))); - CHECK_EQ(R"({| tag: "Folder", x: Folder |})", toString(requireTypeAtPosition({7, 28}))); - } + CHECK_EQ(R"({| tag: "Part", x: Part |})", toString(requireTypeAtPosition({5, 28}))); + CHECK_EQ(R"({| tag: "Folder", x: Folder |})", toString(requireTypeAtPosition({7, 28}))); } TEST_CASE_FIXTURE(RefinementClassFixture, "typeguard_cast_free_table_to_vector") @@ -1406,16 +1174,8 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "typeguard_cast_instance_or_vector3_to LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::DebugLuauDeferredConstraintResolution) - { - CHECK_EQ("(Instance | Vector3) & Vector3", toString(requireTypeAtPosition({3, 28}))); - CHECK_EQ("(Instance | Vector3) & ~Vector3", toString(requireTypeAtPosition({5, 28}))); - } - else - { - CHECK_EQ("Vector3", toString(requireTypeAtPosition({3, 28}))); - CHECK_EQ("Instance", toString(requireTypeAtPosition({5, 28}))); - } + CHECK_EQ("Vector3", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("Instance", toString(requireTypeAtPosition({5, 28}))); } TEST_CASE_FIXTURE(RefinementClassFixture, "type_narrow_for_all_the_userdata") @@ -1452,8 +1212,8 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "type_narrow_but_the_discriminant_type if (FFlag::DebugLuauDeferredConstraintResolution) { - CHECK_EQ("(Instance | Vector3 | number | string) & never", toString(requireTypeAtPosition({3, 28}))); - CHECK_EQ("(Instance | Vector3 | number | string) & ~never", toString(requireTypeAtPosition({5, 28}))); + CHECK_EQ("never", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("Instance | Vector3 | number | string", toString(requireTypeAtPosition({5, 28}))); } else { @@ -1476,16 +1236,8 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "eliminate_subclasses_of_instance") LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::DebugLuauDeferredConstraintResolution) - { - CHECK_EQ("(Folder | Part | string) & Instance", toString(requireTypeAtPosition({3, 28}))); - CHECK_EQ("(Folder | Part | string) & ~Instance", toString(requireTypeAtPosition({5, 28}))); - } - else - { - CHECK_EQ("Folder | Part", toString(requireTypeAtPosition({3, 28}))); - CHECK_EQ("string", toString(requireTypeAtPosition({5, 28}))); - } + CHECK_EQ("Folder | Part", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("string", toString(requireTypeAtPosition({5, 28}))); } TEST_CASE_FIXTURE(RefinementClassFixture, "narrow_from_subclasses_of_instance_or_string_or_vector3") @@ -1502,16 +1254,8 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "narrow_from_subclasses_of_instance_or LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::DebugLuauDeferredConstraintResolution) - { - CHECK_EQ("(Folder | Part | Vector3 | string) & Instance", toString(requireTypeAtPosition({3, 28}))); - CHECK_EQ("(Folder | Part | Vector3 | string) & ~Instance", toString(requireTypeAtPosition({5, 28}))); - } - else - { - CHECK_EQ("Folder | Part", toString(requireTypeAtPosition({3, 28}))); - CHECK_EQ("Vector3 | string", toString(requireTypeAtPosition({5, 28}))); - } + CHECK_EQ("Folder | Part", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("Vector3 | string", toString(requireTypeAtPosition({5, 28}))); } TEST_CASE_FIXTURE(RefinementClassFixture, "x_as_any_if_x_is_instance_elseif_x_is_table") @@ -1556,16 +1300,8 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "refine_param_of_type_instance_without LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::DebugLuauDeferredConstraintResolution) - { - CHECK_EQ("Folder & Instance", toString(requireTypeAtPosition({3, 28}))); - CHECK_EQ("Instance & ~Folder & table", toString(requireTypeAtPosition({5, 28}))); - } - else - { - CHECK_EQ("Folder", toString(requireTypeAtPosition({3, 28}))); - CHECK_EQ("never", toString(requireTypeAtPosition({5, 28}))); - } + CHECK_EQ("Folder", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("never", toString(requireTypeAtPosition({5, 28}))); } TEST_CASE_FIXTURE(RefinementClassFixture, "refine_param_of_type_folder_or_part_without_using_typeof") @@ -1582,16 +1318,8 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "refine_param_of_type_folder_or_part_w LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::DebugLuauDeferredConstraintResolution) - { - CHECK_EQ("(Folder | Part) & Folder", toString(requireTypeAtPosition({3, 28}))); - CHECK_EQ("(Folder | Part) & ~Folder", toString(requireTypeAtPosition({5, 28}))); - } - else - { - CHECK_EQ("Folder", toString(requireTypeAtPosition({3, 28}))); - CHECK_EQ("Part", toString(requireTypeAtPosition({5, 28}))); - } + CHECK_EQ("Folder", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("Part", toString(requireTypeAtPosition({5, 28}))); } TEST_CASE_FIXTURE(RefinementClassFixture, "isa_type_refinement_must_be_known_ahead_of_time") @@ -1610,16 +1338,8 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "isa_type_refinement_must_be_known_ahe LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::DebugLuauDeferredConstraintResolution) - { - CHECK_EQ("Instance", toString(requireTypeAtPosition({3, 28}))); - CHECK_EQ("Instance", toString(requireTypeAtPosition({5, 28}))); - } - else - { - CHECK_EQ("Instance", toString(requireTypeAtPosition({3, 28}))); - CHECK_EQ("Instance", toString(requireTypeAtPosition({5, 28}))); - } + CHECK_EQ("Instance", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("Instance", toString(requireTypeAtPosition({5, 28}))); } TEST_CASE_FIXTURE(RefinementClassFixture, "x_is_not_instance_or_else_not_part") @@ -1673,8 +1393,8 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "refine_unknowns") if (FFlag::DebugLuauDeferredConstraintResolution) { - CHECK_EQ("unknown & string", toString(requireTypeAtPosition({3, 28}))); - CHECK_EQ("unknown & ~string", toString(requireTypeAtPosition({5, 28}))); + CHECK_EQ("string", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("~string", toString(requireTypeAtPosition({5, 28}))); } else { @@ -1714,14 +1434,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "what_nonsensical_condition") LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::DebugLuauDeferredConstraintResolution) - { - CHECK_EQ("a & number & string", toString(requireTypeAtPosition({3, 28}))); - } - else - { - CHECK_EQ("never", toString(requireTypeAtPosition({3, 28}))); - } + CHECK_EQ("never", toString(requireTypeAtPosition({3, 28}))); } TEST_CASE_FIXTURE(Fixture, "else_with_no_explicit_expression_should_also_refine_the_tagged_union") @@ -1752,7 +1465,30 @@ local _ = _ ~= _ or _ or _ end )"); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + // Without a realistic motivating case, it's hard to tell if it's important for this to work without errors. + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK(get(result.errors[0])); + } + else + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "refine_unknown_to_table_then_take_the_length") +{ + ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true}; + + CheckResult result = check(R"( + local function f(x: unknown) + if typeof(x) == "table" then + local len = #x + end + end + )"); + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("table", toString(requireTypeAtPosition({3, 29}))); } TEST_SUITE_END(); diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index 626a4c546..e3c1ab10c 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -3435,4 +3435,62 @@ _ = _._ LUAU_REQUIRE_ERRORS(result); } +TEST_CASE_FIXTURE(BuiltinsFixture, "fuzz_table_unify_instantiated_table") +{ + ScopedFastFlag sff[]{ + {"LuauInstantiateInSubtyping", true}, + {"LuauScalarShapeUnifyToMtOwner2", true}, + {"LuauTableUnifyInstantiationFix", true}, + }; + + CheckResult result = check(R"( +function _(...) +end +local function l0():typeof(_()()[_()()[_]]) +end +return _[_()()[_]] <= _ + )"); + + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "fuzz_table_unify_instantiated_table_with_prop_realloc") +{ + ScopedFastFlag sff[]{ + {"LuauInstantiateInSubtyping", true}, + {"LuauScalarShapeUnifyToMtOwner2", true}, + {"LuauTableUnifyInstantiationFix", true}, + }; + + CheckResult result = check(R"( +function _(l0,l0) +do +_ = _().n0 +end +l0(_()._,_) +end +_(_,function(...) +end) + )"); + + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "fuzz_table_unify_prop_realloc") +{ + // For this test, we don't need LuauInstantiateInSubtyping + ScopedFastFlag sff[]{ + {"LuauScalarShapeUnifyToMtOwner2", true}, + {"LuauTableUnifyInstantiationFix", true}, + }; + + CheckResult result = check(R"( +n3,_ = nil +_ = _[""]._,_[l0][_._][{[_]=_,_=_,}][_G].number +_ = {_,} + )"); + + LUAU_REQUIRE_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.typePacks.cpp b/tests/TypeInfer.typePacks.cpp index 5486b9699..78eb6d477 100644 --- a/tests/TypeInfer.typePacks.cpp +++ b/tests/TypeInfer.typePacks.cpp @@ -1014,7 +1014,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "type_packs_with_tails_in_vararg_adjustment") end) return result end - end + end )"); LUAU_REQUIRE_NO_ERRORS(result); diff --git a/tests/TypeInfer.unknownnever.test.cpp b/tests/TypeInfer.unknownnever.test.cpp index 6caa46eeb..6bfb93b2a 100644 --- a/tests/TypeInfer.unknownnever.test.cpp +++ b/tests/TypeInfer.unknownnever.test.cpp @@ -116,11 +116,23 @@ TEST_CASE_FIXTURE(Fixture, "type_packs_containing_never_is_itself_uninhabitable" local x, y, z = f() )"); - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ("never", toString(requireType("x"))); - CHECK_EQ("never", toString(requireType("y"))); - CHECK_EQ("never", toString(requireType("z"))); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Function only returns 2 values, but 3 are required here", toString(result.errors[0])); + + CHECK_EQ("string", toString(requireType("x"))); + CHECK_EQ("never", toString(requireType("y"))); + CHECK_EQ("*error-type*", toString(requireType("z"))); + } + else + { + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("never", toString(requireType("x"))); + CHECK_EQ("never", toString(requireType("y"))); + CHECK_EQ("never", toString(requireType("z"))); + } } TEST_CASE_FIXTURE(Fixture, "type_packs_containing_never_is_itself_uninhabitable2") @@ -135,10 +147,20 @@ TEST_CASE_FIXTURE(Fixture, "type_packs_containing_never_is_itself_uninhabitable2 LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("never", toString(requireType("x1"))); - CHECK_EQ("never", toString(requireType("x2"))); - CHECK_EQ("never", toString(requireType("y1"))); - CHECK_EQ("never", toString(requireType("y2"))); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("string", toString(requireType("x1"))); + CHECK_EQ("never", toString(requireType("x2"))); + CHECK_EQ("never", toString(requireType("y1"))); + CHECK_EQ("string", toString(requireType("y2"))); + } + else + { + CHECK_EQ("never", toString(requireType("x1"))); + CHECK_EQ("never", toString(requireType("x2"))); + CHECK_EQ("never", toString(requireType("y1"))); + CHECK_EQ("never", toString(requireType("y2"))); + } } TEST_CASE_FIXTURE(Fixture, "index_on_never") @@ -290,8 +312,14 @@ TEST_CASE_FIXTURE(Fixture, "dont_unify_operands_if_one_of_the_operand_is_never_i )"); LUAU_REQUIRE_NO_ERRORS(result); - // Widening doesn't normalize yet, so the result is a bit strange - CHECK_EQ("(nil, a) -> boolean | boolean", toString(requireType("ord"))); + + if (FFlag::DebugLuauDeferredConstraintResolution) + CHECK_EQ("(nil, a) -> boolean", toString(requireType("ord"))); + else + { + // Widening doesn't normalize yet, so the result is a bit strange + CHECK_EQ("(nil, a) -> boolean | boolean", toString(requireType("ord"))); + } } TEST_CASE_FIXTURE(Fixture, "math_operators_and_never") diff --git a/tests/TypeReduction.test.cpp b/tests/TypeReduction.test.cpp index e078f25f2..582725b74 100644 --- a/tests/TypeReduction.test.cpp +++ b/tests/TypeReduction.test.cpp @@ -482,6 +482,24 @@ TEST_CASE_FIXTURE(ReductionFixture, "intersections_without_negations") CHECK("{| [string]: number, p: string |}" == toStringFull(ty)); } + SUBCASE("array_number_and_array_string") + { + TypeId ty = reductionof("{number} & {string}"); + CHECK("{never}" == toStringFull(ty)); + } + + SUBCASE("array_string_and_array_string") + { + TypeId ty = reductionof("{string} & {string}"); + CHECK("{string}" == toStringFull(ty)); + } + + SUBCASE("array_string_or_number_and_array_string") + { + TypeId ty = reductionof("{string | number} & {string}"); + CHECK("{string}" == toStringFull(ty)); + } + SUBCASE("fresh_type_and_string") { TypeId freshTy = arena.freshType(nullptr); @@ -690,7 +708,7 @@ TEST_CASE_FIXTURE(ReductionFixture, "intersections_with_negations") SUBCASE("string_and_not_error") { TypeId ty = reductionof("string & Not"); - CHECK("string & ~*error-type*" == toStringFull(ty)); + CHECK("string" == toStringFull(ty)); } SUBCASE("table_p_string_and_table_p_not_number") @@ -711,6 +729,12 @@ TEST_CASE_FIXTURE(ReductionFixture, "intersections_with_negations") CHECK("{| x: {| p: string |} |}" == toStringFull(ty)); } + SUBCASE("table_or_nil_and_truthy") + { + TypeId ty = reductionof("({ x: number | string }?) & Not"); + CHECK("{| x: number | string |}" == toString(ty)); + } + SUBCASE("not_top_table_and_table") { TypeId ty = reductionof("Not & {}"); @@ -1251,6 +1275,12 @@ TEST_CASE_FIXTURE(ReductionFixture, "tables") TypeId ty = reductionof("{ x: { y: string & number } }"); CHECK("never" == toStringFull(ty)); } + + SUBCASE("array_of_never") + { + TypeId ty = reductionof("{never}"); + CHECK("{never}" == toStringFull(ty)); + } } TEST_CASE_FIXTURE(ReductionFixture, "metatables") diff --git a/tools/faillist.txt b/tools/faillist.txt index 198483878..37666878a 100644 --- a/tools/faillist.txt +++ b/tools/faillist.txt @@ -1,6 +1,5 @@ AnnotationTests.corecursive_types_error_on_tight_loop AnnotationTests.duplicate_type_param_name -AnnotationTests.for_loop_counter_annotation_is_checked AnnotationTests.generic_aliases_are_cloned_properly AnnotationTests.occurs_check_on_cyclic_intersection_type AnnotationTests.occurs_check_on_cyclic_union_type @@ -18,12 +17,8 @@ AutocompleteTest.keyword_methods AutocompleteTest.no_incompatible_self_calls AutocompleteTest.no_wrong_compatible_self_calls_with_generics AutocompleteTest.string_singleton_as_table_key -AutocompleteTest.suggest_external_module_type AutocompleteTest.suggest_table_keys -AutocompleteTest.type_correct_argument_type_suggestion AutocompleteTest.type_correct_expected_argument_type_pack_suggestion -AutocompleteTest.type_correct_expected_argument_type_suggestion -AutocompleteTest.type_correct_expected_argument_type_suggestion_optional AutocompleteTest.type_correct_expected_argument_type_suggestion_self AutocompleteTest.type_correct_expected_return_type_pack_suggestion AutocompleteTest.type_correct_expected_return_type_suggestion @@ -118,37 +113,28 @@ ParserTests.parse_nesting_based_end_detection_failsafe_earlier ParserTests.parse_nesting_based_end_detection_local_function ProvisionalTests.assign_table_with_refined_property_with_a_similar_type_is_illegal ProvisionalTests.bail_early_if_unification_is_too_complicated -ProvisionalTests.discriminate_from_x_not_equal_to_nil ProvisionalTests.do_not_ice_when_trying_to_pick_first_of_generic_type_pack ProvisionalTests.error_on_eq_metamethod_returning_a_type_other_than_boolean ProvisionalTests.free_options_cannot_be_unified_together ProvisionalTests.generic_type_leak_to_module_interface_variadic ProvisionalTests.greedy_inference_with_shared_self_triggers_function_with_no_returns -ProvisionalTests.lvalue_equals_another_lvalue_with_no_overlap ProvisionalTests.pcall_returns_at_least_two_value_but_function_returns_nothing +ProvisionalTests.refine_unknown_to_table_then_clone_it ProvisionalTests.setmetatable_constrains_free_type_into_free_table ProvisionalTests.specialization_binds_with_prototypes_too_early ProvisionalTests.table_insert_with_a_singleton_argument ProvisionalTests.typeguard_inference_incomplete ProvisionalTests.weirditer_should_not_loop_forever RefinementTest.apply_refinements_on_astexprindexexpr_whose_subscript_expr_is_constant_string -RefinementTest.call_an_incompatible_function_after_using_typeguard -RefinementTest.correctly_lookup_property_whose_base_was_previously_refined2 -RefinementTest.discriminate_on_properties_of_disjoint_tables_where_that_property_is_true_or_false -RefinementTest.discriminate_tag RefinementTest.else_with_no_explicit_expression_should_also_refine_the_tagged_union RefinementTest.falsiness_of_TruthyPredicate_narrows_into_nil RefinementTest.narrow_property_of_a_bounded_variable RefinementTest.nonoptional_type_can_narrow_to_nil_if_sense_is_true -RefinementTest.refine_a_property_not_to_be_nil_through_an_intersection_table -RefinementTest.refine_unknowns RefinementTest.type_guard_can_filter_for_intersection_of_tables -RefinementTest.type_guard_narrowed_into_nothingness RefinementTest.type_narrow_for_all_the_userdata RefinementTest.type_narrow_to_vector RefinementTest.typeguard_cast_free_table_to_vector RefinementTest.typeguard_in_assert_position -RefinementTest.typeguard_narrows_for_table RefinementTest.x_as_any_if_x_is_instance_elseif_x_is_table RefinementTest.x_is_not_instance_or_else_not_part RuntimeLimits.typescript_port_of_Result_type @@ -178,6 +164,8 @@ TableTests.found_like_key_in_table_function_call TableTests.found_like_key_in_table_property_access TableTests.found_multiple_like_keys TableTests.function_calls_produces_sealed_table_given_unsealed_table +TableTests.fuzz_table_unify_instantiated_table +TableTests.fuzz_table_unify_instantiated_table_with_prop_realloc TableTests.generic_table_instantiation_potential_regression TableTests.give_up_after_one_metatable_index_look_up TableTests.indexer_on_sealed_table_must_unify_with_free_table @@ -220,9 +208,9 @@ TableTests.table_param_row_polymorphism_3 TableTests.table_simple_call TableTests.table_subtyping_with_extra_props_dont_report_multiple_errors TableTests.table_subtyping_with_missing_props_dont_report_multiple_errors +TableTests.table_unification_4 TableTests.tc_member_function TableTests.tc_member_function_2 -TableTests.unification_of_unions_in_a_self_referential_type TableTests.unifying_tables_shouldnt_uaf1 TableTests.unifying_tables_shouldnt_uaf2 TableTests.used_colon_correctly @@ -357,9 +345,7 @@ TypeInferOperators.disallow_string_and_types_without_metatables_from_arithmetic_ TypeInferOperators.in_nonstrict_mode_strip_nil_from_intersections_when_considering_relational_operators TypeInferOperators.infer_any_in_all_modes_when_lhs_is_unknown TypeInferOperators.operator_eq_completely_incompatible -TypeInferOperators.or_joins_types_with_no_superfluous_union TypeInferOperators.produce_the_correct_error_message_when_comparing_a_table_with_a_metatable_with_one_that_does_not -TypeInferOperators.refine_and_or TypeInferOperators.typecheck_overloaded_multiply_that_is_an_intersection TypeInferOperators.typecheck_overloaded_multiply_that_is_an_intersection_on_rhs TypeInferOperators.UnknownGlobalCompoundAssign @@ -368,16 +354,8 @@ TypeInferOperators.unrelated_primitives_cannot_be_compared TypeInferPrimitives.CheckMethodsOfNumber TypeInferPrimitives.string_index TypeInferUnknownNever.assign_to_global_which_is_never -TypeInferUnknownNever.assign_to_local_which_is_never -TypeInferUnknownNever.assign_to_prop_which_is_never -TypeInferUnknownNever.assign_to_subscript_which_is_never -TypeInferUnknownNever.call_never TypeInferUnknownNever.dont_unify_operands_if_one_of_the_operand_is_never_in_any_ordering_operators -TypeInferUnknownNever.index_on_union_of_tables_for_properties_that_is_sorta_never TypeInferUnknownNever.math_operators_and_never -TypeInferUnknownNever.type_packs_containing_never_is_itself_uninhabitable -TypeInferUnknownNever.type_packs_containing_never_is_itself_uninhabitable2 -TypeInferUnknownNever.unary_minus_of_never TypePackTests.detect_cyclic_typepacks2 TypePackTests.pack_tail_unification_check TypePackTests.self_and_varargs_should_work @@ -401,24 +379,19 @@ TypePackTests.type_pack_type_parameters TypePackTests.unify_variadic_tails_in_arguments TypePackTests.unify_variadic_tails_in_arguments_free TypePackTests.variadic_packs -TypeReductionTests.negations TypeSingletons.function_call_with_singletons TypeSingletons.function_call_with_singletons_mismatch -TypeSingletons.indexing_on_string_singletons TypeSingletons.indexing_on_union_of_string_singletons TypeSingletons.overloaded_function_call_with_singletons TypeSingletons.overloaded_function_call_with_singletons_mismatch TypeSingletons.return_type_of_f_is_not_widened TypeSingletons.table_properties_singleton_strings_mismatch TypeSingletons.table_properties_type_error_escapes -TypeSingletons.taking_the_length_of_string_singleton TypeSingletons.taking_the_length_of_union_of_string_singleton TypeSingletons.widen_the_supertype_if_it_is_free_and_subtype_has_singleton TypeSingletons.widening_happens_almost_everywhere TypeSingletons.widening_happens_almost_everywhere_except_for_tables UnionTypes.index_on_a_union_type_with_missing_property -UnionTypes.index_on_a_union_type_with_one_optional_property -UnionTypes.index_on_a_union_type_with_one_property_of_type_any UnionTypes.optional_assignment_errors UnionTypes.optional_call_error UnionTypes.optional_field_access_error From b388e2799520a4afc4326d0e78e2380762ccbe11 Mon Sep 17 00:00:00 2001 From: Andy Friesen Date: Fri, 10 Feb 2023 10:50:54 -0800 Subject: [PATCH 35/66] Sync to upstream/release/563 --- Analysis/include/Luau/Constraint.h | 3 +- Analysis/include/Luau/DcrLogger.h | 1 + Analysis/include/Luau/Refinement.h | 14 +- Analysis/include/Luau/Type.h | 13 +- Analysis/include/Luau/TypeInfer.h | 14 +- Analysis/src/BuiltinDefinitions.cpp | 12 - Analysis/src/Clone.cpp | 1 + Analysis/src/ConstraintGraphBuilder.cpp | 224 ++++++------ Analysis/src/ConstraintSolver.cpp | 121 ++++--- Analysis/src/DcrLogger.cpp | 4 + Analysis/src/Frontend.cpp | 50 +-- Analysis/src/Instantiation.cpp | 1 + Analysis/src/Module.cpp | 8 +- Analysis/src/Refinement.cpp | 5 + Analysis/src/TypeInfer.cpp | 86 +++-- Analysis/src/Unifier.cpp | 18 + CMakeLists.txt | 1 + CodeGen/include/Luau/IrAnalysis.h | 4 +- CodeGen/include/Luau/IrBuilder.h | 2 + CodeGen/include/Luau/IrData.h | 396 +++++++++++++++++++- CodeGen/include/Luau/IrDump.h | 8 +- CodeGen/include/Luau/IrUtils.h | 27 +- CodeGen/include/Luau/OptimizeFinalX64.h | 14 + CodeGen/src/CodeGen.cpp | 3 +- CodeGen/src/EmitCommonX64.cpp | 15 +- CodeGen/src/IrAnalysis.cpp | 49 ++- CodeGen/src/IrBuilder.cpp | 144 ++++---- CodeGen/src/IrDump.cpp | 93 +++-- CodeGen/src/IrLoweringX64.cpp | 167 ++++++--- CodeGen/src/IrTranslation.cpp | 238 +++++++++++- CodeGen/src/IrTranslation.h | 5 + CodeGen/src/IrUtils.cpp | 133 +++++++ CodeGen/src/OptimizeFinalX64.cpp | 111 ++++++ Common/include/Luau/Bytecode.h | 8 +- Common/include/Luau/ExperimentalFlags.h | 1 + Sources.cmake | 4 + tests/Frontend.test.cpp | 10 - tests/IrBuilder.test.cpp | 223 ++++++++++++ tests/Module.test.cpp | 6 - tests/NonstrictMode.test.cpp | 2 - tests/TypeInfer.annotations.test.cpp | 6 - tests/TypeInfer.functions.test.cpp | 2 - tests/TypeInfer.generics.test.cpp | 20 +- tests/TypeInfer.provisional.test.cpp | 4 - tests/TypeInfer.refinements.test.cpp | 60 +++- tests/TypeInfer.tables.test.cpp | 105 +++++- tests/TypeInfer.test.cpp | 14 + tools/faillist.txt | 38 +- tools/flag-bisect.py | 458 ++++++++++++++++++++++++ 49 files changed, 2433 insertions(+), 513 deletions(-) create mode 100644 CodeGen/include/Luau/OptimizeFinalX64.h create mode 100644 CodeGen/src/IrUtils.cpp create mode 100644 CodeGen/src/OptimizeFinalX64.cpp create mode 100644 tests/IrBuilder.test.cpp create mode 100644 tools/flag-bisect.py diff --git a/Analysis/include/Luau/Constraint.h b/Analysis/include/Luau/Constraint.h index 8159b76b7..18ff30921 100644 --- a/Analysis/include/Luau/Constraint.h +++ b/Analysis/include/Luau/Constraint.h @@ -106,6 +106,7 @@ struct FunctionCallConstraint TypePackId argsPack; TypePackId result; class AstExprCall* callSite; + std::vector> discriminantTypes; }; // result ~ prim ExpectedType SomeSingletonType MultitonType @@ -180,7 +181,7 @@ struct Constraint Constraint& operator=(const Constraint&) = delete; NotNull scope; - Location location; // TODO: Extract this out into only the constraints that needs a location. Not all constraints needs locations. + Location location; ConstraintV c; std::vector> dependencies; diff --git a/Analysis/include/Luau/DcrLogger.h b/Analysis/include/Luau/DcrLogger.h index 30d2e15ec..45c84c66e 100644 --- a/Analysis/include/Luau/DcrLogger.h +++ b/Analysis/include/Luau/DcrLogger.h @@ -65,6 +65,7 @@ struct ConstraintBlock struct ConstraintSnapshot { std::string stringification; + Location location; std::vector blocks; }; diff --git a/Analysis/include/Luau/Refinement.h b/Analysis/include/Luau/Refinement.h index 3e1f234a1..e7d3cf23b 100644 --- a/Analysis/include/Luau/Refinement.h +++ b/Analysis/include/Luau/Refinement.h @@ -11,14 +11,20 @@ namespace Luau struct Type; using TypeId = const Type*; +struct Variadic; struct Negation; struct Conjunction; struct Disjunction; struct Equivalence; struct Proposition; -using Refinement = Variant; +using Refinement = Variant; using RefinementId = Refinement*; // Can and most likely is nullptr. +struct Variadic +{ + std::vector refinements; +}; + struct Negation { RefinementId refinement; @@ -56,13 +62,15 @@ const T* get(RefinementId refinement) struct RefinementArena { - TypedAllocator allocator; - + RefinementId variadic(const std::vector& refis); RefinementId negation(RefinementId refinement); RefinementId conjunction(RefinementId lhs, RefinementId rhs); RefinementId disjunction(RefinementId lhs, RefinementId rhs); RefinementId equivalence(RefinementId lhs, RefinementId rhs); RefinementId proposition(DefId def, TypeId discriminantTy); + +private: + TypedAllocator allocator; }; } // namespace Luau diff --git a/Analysis/include/Luau/Type.h b/Analysis/include/Luau/Type.h index 6c8e1bc3a..00e6d6c65 100644 --- a/Analysis/include/Luau/Type.h +++ b/Analysis/include/Luau/Type.h @@ -263,15 +263,12 @@ using DcrMagicFunction = bool (*)(MagicFunctionCallContext); struct MagicRefinementContext { - ScopePtr scope; - NotNull cgb; - NotNull dfg; - NotNull refinementArena; - std::vector argumentRefinements; + NotNull scope; const class AstExprCall* callSite; + std::vector> discriminantTypes; }; -using DcrMagicRefinement = std::vector (*)(const MagicRefinementContext&); +using DcrMagicRefinement = void (*)(const MagicRefinementContext&); struct FunctionType { @@ -304,8 +301,8 @@ struct FunctionType TypePackId argTypes; TypePackId retTypes; MagicFunction magicFunction = nullptr; - DcrMagicFunction dcrMagicFunction = nullptr; // Fired only while solving constraints - DcrMagicRefinement dcrMagicRefinement = nullptr; // Fired only while generating constraints + DcrMagicFunction dcrMagicFunction = nullptr; + DcrMagicRefinement dcrMagicRefinement = nullptr; bool hasSelf; bool hasNoGenerics = false; }; diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index 4c2d38ad1..d748a1f50 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -57,6 +57,12 @@ class TimeLimitError : public InternalCompilerError } }; +enum class ValueContext +{ + LValue, + RValue +}; + // All Types are retained via Environment::types. All TypeIds // within a program are borrowed pointers into this set. struct TypeChecker @@ -119,14 +125,14 @@ struct TypeChecker std::optional expectedType); // Returns the type of the lvalue. - TypeId checkLValue(const ScopePtr& scope, const AstExpr& expr); + TypeId checkLValue(const ScopePtr& scope, const AstExpr& expr, ValueContext ctx); // Returns the type of the lvalue. - TypeId checkLValueBinding(const ScopePtr& scope, const AstExpr& expr); + TypeId checkLValueBinding(const ScopePtr& scope, const AstExpr& expr, ValueContext ctx); TypeId checkLValueBinding(const ScopePtr& scope, const AstExprLocal& expr); TypeId checkLValueBinding(const ScopePtr& scope, const AstExprGlobal& expr); - TypeId checkLValueBinding(const ScopePtr& scope, const AstExprIndexName& expr); - TypeId checkLValueBinding(const ScopePtr& scope, const AstExprIndexExpr& expr); + TypeId checkLValueBinding(const ScopePtr& scope, const AstExprIndexName& expr, ValueContext ctx); + TypeId checkLValueBinding(const ScopePtr& scope, const AstExprIndexExpr& expr, ValueContext ctx); TypeId checkFunctionName(const ScopePtr& scope, AstExpr& funName, TypeLevel level); std::pair checkFunctionSignature(const ScopePtr& scope, int subLevel, const AstExprFunction& expr, diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index 006df6e43..c17169f45 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -42,8 +42,6 @@ static bool dcrMagicFunctionSelect(MagicFunctionCallContext context); static bool dcrMagicFunctionRequire(MagicFunctionCallContext context); static bool dcrMagicFunctionPack(MagicFunctionCallContext context); -static std::vector dcrMagicRefinementAssert(const MagicRefinementContext& context); - TypeId makeUnion(TypeArena& arena, std::vector&& types) { return arena.addType(UnionType{std::move(types)}); @@ -422,7 +420,6 @@ void registerBuiltinGlobals(Frontend& frontend) } attachMagicFunction(getGlobalBinding(frontend, "assert"), magicFunctionAssert); - attachDcrMagicRefinement(getGlobalBinding(frontend, "assert"), dcrMagicRefinementAssert); attachMagicFunction(getGlobalBinding(frontend, "setmetatable"), magicFunctionSetMetaTable); attachMagicFunction(getGlobalBinding(frontend, "select"), magicFunctionSelect); attachDcrMagicFunction(getGlobalBinding(frontend, "select"), dcrMagicFunctionSelect); @@ -624,15 +621,6 @@ static std::optional> magicFunctionAssert( return WithPredicate{arena.addTypePack(TypePack{std::move(head), tail})}; } -static std::vector dcrMagicRefinementAssert(const MagicRefinementContext& ctx) -{ - if (ctx.argumentRefinements.empty()) - return {}; - - ctx.cgb->applyRefinements(ctx.scope, ctx.callSite->location, ctx.argumentRefinements[0]); - return {}; -} - static std::optional> magicFunctionPack( TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) { diff --git a/Analysis/src/Clone.cpp b/Analysis/src/Clone.cpp index 3dd8df870..ff8e0c3c2 100644 --- a/Analysis/src/Clone.cpp +++ b/Analysis/src/Clone.cpp @@ -438,6 +438,7 @@ TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool alwaysCl clone.genericPacks = ftv->genericPacks; clone.magicFunction = ftv->magicFunction; clone.dcrMagicFunction = ftv->dcrMagicFunction; + clone.dcrMagicRefinement = ftv->dcrMagicRefinement; clone.tags = ftv->tags; clone.argNames = ftv->argNames; result = dest.addType(std::move(clone)); diff --git a/Analysis/src/ConstraintGraphBuilder.cpp b/Analysis/src/ConstraintGraphBuilder.cpp index 09182f577..f773863c9 100644 --- a/Analysis/src/ConstraintGraphBuilder.cpp +++ b/Analysis/src/ConstraintGraphBuilder.cpp @@ -15,7 +15,6 @@ LUAU_FASTINT(LuauCheckRecursionLimit); LUAU_FASTFLAG(DebugLuauLogSolverToJson); LUAU_FASTFLAG(DebugLuauMagicTypes); LUAU_FASTFLAG(LuauNegatedClassTypes); -LUAU_FASTFLAG(LuauScopelessModule); LUAU_FASTFLAG(SupportTypeAliasGoToDeclaration); namespace Luau @@ -96,6 +95,18 @@ static std::optional matchTypeGuard(const AstExprBinary* binary) }; } +static bool matchAssert(const AstExprCall& call) +{ + if (call.args.size < 1) + return false; + + const AstExprGlobal* funcAsGlobal = call.func->as(); + if (!funcAsGlobal || funcAsGlobal->name != "assert") + return false; + + return true; +} + namespace { @@ -198,6 +209,11 @@ static void computeRefinement(const ScopePtr& scope, RefinementId refinement, st if (!refinement) return; + else if (auto variadic = get(refinement)) + { + for (RefinementId refi : variadic->refinements) + computeRefinement(scope, refi, refis, sense, arena, eq, constraints); + } else if (auto negation = get(refinement)) return computeRefinement(scope, negation->refinement, refis, !sense, arena, eq, constraints); else if (auto conjunction = get(refinement)) @@ -546,8 +562,7 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocal* local) if (ModulePtr module = moduleResolver->getModule(moduleInfo->name)) { - scope->importedTypeBindings[name] = - FFlag::LuauScopelessModule ? module->exportedTypeBindings : module->getModuleScope()->exportedTypeBindings; + scope->importedTypeBindings[name] = module->exportedTypeBindings; if (FFlag::SupportTypeAliasGoToDeclaration) scope->importedModules[name] = moduleName; } @@ -697,18 +712,9 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFunction* funct } else if (AstExprIndexName* indexName = function->name->as()) { - TypeId containingTableType = check(scope, indexName->expr).ty; - - // TODO look into stack utilization. This is probably ok because it scales with AST depth. - TypeId prospectiveTableType = arena->addType(TableType{TableState::Unsealed, TypeLevel{}, scope.get()}); - - NotNull prospectiveTable{getMutable(prospectiveTableType)}; - - Property& prop = prospectiveTable->props[indexName->index.value]; - prop.type = generalizedType; - prop.location = function->name->location; - - addConstraint(scope, indexName->location, SubtypeConstraint{containingTableType, prospectiveTableType}); + TypeId lvalueType = checkLValue(scope, indexName); + // TODO figure out how to populate the location field of the table Property. + addConstraint(scope, indexName->location, SubtypeConstraint{lvalueType, generalizedType}); } else if (AstExprError* err = function->name->as()) { @@ -783,13 +789,13 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatIf* ifStatement auto [_, refinement] = check(condScope, ifStatement->condition, std::nullopt); ScopePtr thenScope = childScope(ifStatement->thenbody, scope); - applyRefinements(thenScope, Location{}, refinement); + applyRefinements(thenScope, ifStatement->condition->location, refinement); visit(thenScope, ifStatement->thenbody); if (ifStatement->elsebody) { ScopePtr elseScope = childScope(ifStatement->elsebody, scope); - applyRefinements(elseScope, Location{}, refinementArena.negation(refinement)); + applyRefinements(elseScope, ifStatement->elseLocation.value_or(ifStatement->condition->location), refinementArena.negation(refinement)); visit(elseScope, ifStatement->elsebody); } } @@ -1059,6 +1065,10 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExpr* InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCall* call, const std::vector& expectedTypes) { std::vector exprArgs; + + std::vector returnRefinements; + std::vector> discriminantTypes; + if (call->self) { AstExprIndexName* indexExpr = call->func->as(); @@ -1066,13 +1076,37 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCa ice->ice("method call expression has no 'self'"); exprArgs.push_back(indexExpr->expr); + + if (auto def = dfg->getDef(indexExpr->expr)) + { + TypeId discriminantTy = arena->addType(BlockedType{}); + returnRefinements.push_back(refinementArena.proposition(*def, discriminantTy)); + discriminantTypes.push_back(discriminantTy); + } + else + discriminantTypes.push_back(std::nullopt); + } + + for (AstExpr* arg : call->args) + { + exprArgs.push_back(arg); + + if (auto def = dfg->getDef(arg)) + { + TypeId discriminantTy = arena->addType(BlockedType{}); + returnRefinements.push_back(refinementArena.proposition(*def, discriminantTy)); + discriminantTypes.push_back(discriminantTy); + } + else + discriminantTypes.push_back(std::nullopt); } - exprArgs.insert(exprArgs.end(), call->args.begin(), call->args.end()); Checkpoint startCheckpoint = checkpoint(this); TypeId fnType = check(scope, call->func).ty; Checkpoint fnEndCheckpoint = checkpoint(this); + module->astOriginalCallTypes[call->func] = fnType; + TypePackId expectedArgPack = arena->freshTypePack(scope.get()); TypePackId expectedRetPack = arena->freshTypePack(scope.get()); TypeId expectedFunctionType = arena->addType(FunctionType{expectedArgPack, expectedRetPack}); @@ -1129,7 +1163,11 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCa argumentRefinements.push_back(refinement); } else - argTail = checkPack(scope, arg, {}).tp; // FIXME? not sure about expectedTypes here + { + auto [tp, refis] = checkPack(scope, arg, {}); // FIXME? not sure about expectedTypes here + argTail = tp; + argumentRefinements.insert(argumentRefinements.end(), refis.begin(), refis.end()); + } } Checkpoint argEndCheckpoint = checkpoint(this); @@ -1140,13 +1178,6 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCa constraint->dependencies.push_back(extractArgsConstraint); }); - std::vector returnRefinements; - if (auto ftv = get(follow(fnType)); ftv && ftv->dcrMagicRefinement) - { - MagicRefinementContext ctx{scope, NotNull{this}, dfg, NotNull{&refinementArena}, std::move(argumentRefinements), call}; - returnRefinements = ftv->dcrMagicRefinement(ctx); - } - if (matchSetmetatable(*call)) { TypePack argTailPack; @@ -1171,12 +1202,12 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCa scope->dcrRefinements[*def] = resultTy; // TODO: typestates: track this as an assignment } - - return InferencePack{arena->addTypePack({resultTy}), std::move(returnRefinements)}; + return InferencePack{arena->addTypePack({resultTy}), {refinementArena.variadic(returnRefinements)}}; } else { - module->astOriginalCallTypes[call->func] = fnType; + if (matchAssert(*call) && !argumentRefinements.empty()) + applyRefinements(scope, call->args.data[0]->location, argumentRefinements[0]); TypeId instantiatedType = arena->addType(BlockedType{}); // TODO: How do expectedTypes play into this? Do they? @@ -1200,6 +1231,7 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCa argPack, rets, call, + std::move(discriminantTypes), }); // We force constraints produced by checking function arguments to wait @@ -1211,7 +1243,7 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCa fcc->dependencies.emplace_back(constraint.get()); }); - return InferencePack{rets, std::move(returnRefinements)}; + return InferencePack{rets, {refinementArena.variadic(returnRefinements)}}; } } @@ -1386,74 +1418,10 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprGlobal* gl return Inference{builtinTypes->errorRecoveryType()}; } -static std::optional lookupProp(TypeId ty, const std::string& propName, NotNull arena) -{ - ty = follow(ty); - - if (auto ctv = get(ty)) - { - if (auto prop = lookupClassProp(ctv, propName)) - return prop->type; - } - else if (auto ttv = get(ty)) - { - if (auto it = ttv->props.find(propName); it != ttv->props.end()) - return it->second.type; - } - else if (auto utv = get(ty)) - { - std::vector types; - - for (TypeId ty : utv) - { - if (auto prop = lookupProp(ty, propName, arena)) - { - if (std::find(begin(types), end(types), *prop) == end(types)) - types.push_back(*prop); - } - else - return std::nullopt; - } - - if (types.size() == 1) - return types[0]; - else - return arena->addType(IntersectionType{std::move(types)}); - } - else if (auto utv = get(ty)) - { - std::vector types; - - for (TypeId ty : utv) - { - if (auto prop = lookupProp(ty, propName, arena)) - { - if (std::find(begin(types), end(types), *prop) == end(types)) - types.push_back(*prop); - } - else - return std::nullopt; - } - - if (types.size() == 1) - return types[0]; - else - return arena->addType(UnionType{std::move(types)}); - } - - return std::nullopt; -} - Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIndexName* indexName) { TypeId obj = check(scope, indexName->expr).ty; - - // HACK: We need to return the actual type for type refinements so that it can invoke the dcrMagicRefinement function. - TypeId result; - if (auto prop = lookupProp(obj, indexName->index.value, arena)) - result = *prop; - else - result = freshType(scope); + TypeId result = freshType(scope); std::optional def = dfg->getDef(indexName); if (def) @@ -1723,11 +1691,6 @@ TypeId ConstraintGraphBuilder::checkLValue(const ScopePtr& scope, AstExpr* expr) TypeId updatedType = arena->addType(BlockedType{}); addConstraint(scope, expr->location, SetPropConstraint{updatedType, subjectType, std::move(segmentStrings), propTy}); - std::optional def = dfg->getDef(sym); - LUAU_ASSERT(def); - symbolScope->bindings[sym].typeId = updatedType; - symbolScope->dcrRefinements[*def] = updatedType; - TypeId prevSegmentTy = updatedType; for (size_t i = 0; i < segments.size(); ++i) { @@ -1739,7 +1702,16 @@ TypeId ConstraintGraphBuilder::checkLValue(const ScopePtr& scope, AstExpr* expr) module->astTypes[expr] = prevSegmentTy; module->astTypes[e] = updatedType; - // astTypes[expr] = propTy; + + symbolScope->bindings[sym].typeId = updatedType; + + std::optional def = dfg->getDef(sym); + if (def) + { + // This can fail if the user is erroneously trying to augment a builtin + // table like os or string. + symbolScope->dcrRefinements[*def] = updatedType; + } return propTy; } @@ -1765,9 +1737,30 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprTable* exp addConstraint(scope, location, SubtypeConstraint{ttv->indexer->indexResultType, currentResultType}); }; + std::optional annotatedKeyType; + std::optional annotatedIndexResultType; + + if (expectedType) + { + if (const TableType* ttv = get(follow(*expectedType))) + { + if (ttv->indexer) + { + annotatedKeyType.emplace(follow(ttv->indexer->indexType)); + annotatedIndexResultType.emplace(ttv->indexer->indexResultType); + } + } + } + + bool isIndexedResultType = false; + std::optional pinnedIndexResultType; + + for (const AstExprTable::Item& item : expr->items) { std::optional expectedValueType; + if (item.kind == AstExprTable::Item::Kind::General || item.kind == AstExprTable::Item::Kind::List) + isIndexedResultType = true; if (item.key && expectedType) { @@ -1786,14 +1779,39 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprTable* exp } } - TypeId itemTy = check(scope, item.value, expectedValueType).ty; + + // We'll resolve the expected index result type here with the following priority: + // 1. Record table types - in which key, value pairs must be handled on a k,v pair basis. + // In this case, the above if-statement will populate expectedValueType + // 2. Someone places an annotation on a General or List table + // Trust the annotation and have the solver inform them if they get it wrong + // 3. Someone omits the annotation on a general or List table + // Use the type of the first indexResultType as the expected type + std::optional checkExpectedIndexResultType; + if (expectedValueType) + { + checkExpectedIndexResultType = expectedValueType; + } + else if (annotatedIndexResultType) + { + checkExpectedIndexResultType = annotatedIndexResultType; + } + else if (pinnedIndexResultType) + { + checkExpectedIndexResultType = pinnedIndexResultType; + } + + TypeId itemTy = check(scope, item.value, checkExpectedIndexResultType).ty; + + if (isIndexedResultType && !pinnedIndexResultType) + pinnedIndexResultType = itemTy; if (item.key) { // Even though we don't need to use the type of the item's key if // it's a string constant, we still want to check it to populate // astTypes. - TypeId keyTy = check(scope, item.key).ty; + TypeId keyTy = check(scope, item.key, annotatedKeyType).ty; if (AstExprConstantString* key = item.key->as()) { diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index 96d16c438..76fd0bca8 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -18,7 +18,6 @@ LUAU_FASTFLAGVARIABLE(DebugLuauLogSolver, false); LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJson, false); -LUAU_FASTFLAG(LuauScopelessModule); namespace Luau { @@ -424,9 +423,7 @@ bool ConstraintSolver::tryDispatch(NotNull constraint, bool fo LUAU_ASSERT(false); if (success) - { unblock(constraint); - } return success; } @@ -1129,6 +1126,28 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNull std::optional { + auto it = begin(t); + auto endIt = end(t); + + LUAU_ASSERT(it != endIt); + TypeId fst = follow(*it); + while (it != endIt) + { + if (follow(*it) != fst) + return std::nullopt; + ++it; + } + + return fst; + }; + + // Sometimes the `fn` type is a union/intersection, but whose constituents are all the same pointer. + if (auto ut = get(fn)) + fn = collapse(ut).value_or(fn); + else if (auto it = get(fn)) + fn = collapse(it).value_or(fn); + // We don't support magic __call metamethods. if (std::optional callMm = findMetatableEntry(builtinTypes, errors, fn, "__call", constraint->location)) { @@ -1140,69 +1159,73 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNulladdType(BlockedType{}); TypeId inferredFnType = arena->addType(FunctionType(TypeLevel{}, constraint->scope.get(), arena->addTypePack(TypePack{args, {}}), c.result)); - // Alter the inner constraints. - LUAU_ASSERT(c.innerConstraints.size() == 2); - - // Anything that is blocked on this constraint must also be blocked on our inner constraints - auto blockedIt = blocked.find(constraint.get()); - if (blockedIt != blocked.end()) - { - for (const auto& ic : c.innerConstraints) - { - for (const auto& blockedConstraint : blockedIt->second) - block(ic, blockedConstraint); - } - } - asMutable(*c.innerConstraints.at(0)).c = InstantiationConstraint{instantiatedType, *callMm}; asMutable(*c.innerConstraints.at(1)).c = SubtypeConstraint{inferredFnType, instantiatedType}; - unsolvedConstraints.insert(end(unsolvedConstraints), begin(c.innerConstraints), end(c.innerConstraints)); - asMutable(c.result)->ty.emplace(constraint->scope); - unblock(c.result); - return true; } + else + { + const FunctionType* ftv = get(fn); + bool usedMagic = false; + + if (ftv) + { + if (ftv->dcrMagicFunction) + usedMagic = ftv->dcrMagicFunction(MagicFunctionCallContext{NotNull(this), c.callSite, c.argsPack, result}); - const FunctionType* ftv = get(fn); - bool usedMagic = false; + if (ftv->dcrMagicRefinement) + ftv->dcrMagicRefinement(MagicRefinementContext{constraint->scope, c.callSite, c.discriminantTypes}); + } - if (ftv && ftv->dcrMagicFunction != nullptr) - { - usedMagic = ftv->dcrMagicFunction(MagicFunctionCallContext{NotNull(this), c.callSite, c.argsPack, result}); + if (usedMagic) + { + // There are constraints that are blocked on these constraints. If we + // are never going to even examine them, then we should not block + // anything else on them. + // + // TODO CLI-58842 +#if 0 + for (auto& c: c.innerConstraints) + unblock(c); +#endif + } + else + asMutable(c.result)->ty.emplace(constraint->scope); } - if (usedMagic) + for (std::optional ty : c.discriminantTypes) { - // There are constraints that are blocked on these constraints. If we - // are never going to even examine them, then we should not block - // anything else on them. + if (!ty || !isBlocked(*ty)) + continue; + + // We use `any` here because the discriminant type may be pointed at by both branches, + // where the discriminant type is not negated, and the other where it is negated, i.e. + // `unknown ~ unknown` and `~unknown ~ never`, so `T & unknown ~ T` and `T & ~unknown ~ never` + // v.s. + // `any ~ any` and `~any ~ any`, so `T & any ~ T` and `T & ~any ~ T` // - // TODO CLI-58842 -#if 0 - for (auto& c: c.innerConstraints) - unblock(c); -#endif + // In practice, users cannot negate `any`, so this is an implementation detail we can always change. + *asMutable(follow(*ty)) = BoundType{builtinTypes->anyType}; } - else + + // Alter the inner constraints. + LUAU_ASSERT(c.innerConstraints.size() == 2); + + // Anything that is blocked on this constraint must also be blocked on our inner constraints + auto blockedIt = blocked.find(constraint.get()); + if (blockedIt != blocked.end()) { - // Anything that is blocked on this constraint must also be blocked on our inner constraints - auto blockedIt = blocked.find(constraint.get()); - if (blockedIt != blocked.end()) + for (const auto& ic : c.innerConstraints) { - for (const auto& ic : c.innerConstraints) - { - for (const auto& blockedConstraint : blockedIt->second) - block(ic, blockedConstraint); - } + for (const auto& blockedConstraint : blockedIt->second) + block(ic, blockedConstraint); } - - unsolvedConstraints.insert(end(unsolvedConstraints), begin(c.innerConstraints), end(c.innerConstraints)); - asMutable(c.result)->ty.emplace(constraint->scope); } - unblock(c.result); + unsolvedConstraints.insert(end(unsolvedConstraints), begin(c.innerConstraints), end(c.innerConstraints)); + unblock(c.result); return true; } @@ -1930,7 +1953,7 @@ TypeId ConstraintSolver::resolveModule(const ModuleInfo& info, const Location& l return errorRecoveryType(); } - TypePackId modulePack = FFlag::LuauScopelessModule ? module->returnType : module->getModuleScope()->returnType; + TypePackId modulePack = module->returnType; if (get(modulePack)) return errorRecoveryType(); diff --git a/Analysis/src/DcrLogger.cpp b/Analysis/src/DcrLogger.cpp index ef33aa606..a1ef650b8 100644 --- a/Analysis/src/DcrLogger.cpp +++ b/Analysis/src/DcrLogger.cpp @@ -105,6 +105,7 @@ void write(JsonEmitter& emitter, const ConstraintSnapshot& snapshot) { ObjectEmitter o = emitter.writeObject(); o.writePair("stringification", snapshot.stringification); + o.writePair("location", snapshot.location); o.writePair("blocks", snapshot.blocks); o.finish(); } @@ -293,6 +294,7 @@ void DcrLogger::captureInitialSolverState(const Scope* rootScope, const std::vec std::string id = toPointerId(c); solveLog.initialState.constraints[id] = { toString(*c.get(), opts), + c->location, snapshotBlocks(c), }; } @@ -310,6 +312,7 @@ StepSnapshot DcrLogger::prepareStepSnapshot( std::string id = toPointerId(c); constraints[id] = { toString(*c.get(), opts), + c->location, snapshotBlocks(c), }; } @@ -337,6 +340,7 @@ void DcrLogger::captureFinalSolverState(const Scope* rootScope, const std::vecto std::string id = toPointerId(c); solveLog.finalState.constraints[id] = { toString(*c.get(), opts), + c->location, snapshotBlocks(c), }; } diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 94342cca1..a70d6dda7 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -31,7 +31,6 @@ LUAU_FASTFLAGVARIABLE(LuauKnowsTheDataModel3, false) LUAU_FASTINTVARIABLE(LuauAutocompleteCheckTimeoutMs, 100) LUAU_FASTFLAGVARIABLE(DebugLuauDeferredConstraintResolution, false) LUAU_FASTFLAG(DebugLuauLogSolverToJson); -LUAU_FASTFLAG(LuauScopelessModule); namespace Luau { @@ -113,9 +112,7 @@ LoadDefinitionFileResult Frontend::loadDefinitionFile(std::string_view source, c CloneState cloneState; std::vector typesToPersist; - typesToPersist.reserve( - checkedModule->declaredGlobals.size() + - (FFlag::LuauScopelessModule ? checkedModule->exportedTypeBindings.size() : checkedModule->getModuleScope()->exportedTypeBindings.size())); + typesToPersist.reserve(checkedModule->declaredGlobals.size() + checkedModule->exportedTypeBindings.size()); for (const auto& [name, ty] : checkedModule->declaredGlobals) { @@ -127,8 +124,7 @@ LoadDefinitionFileResult Frontend::loadDefinitionFile(std::string_view source, c typesToPersist.push_back(globalTy); } - for (const auto& [name, ty] : - FFlag::LuauScopelessModule ? checkedModule->exportedTypeBindings : checkedModule->getModuleScope()->exportedTypeBindings) + for (const auto& [name, ty] : checkedModule->exportedTypeBindings) { TypeFun globalTy = clone(ty, globalTypes, cloneState); std::string documentationSymbol = packageName + "/globaltype/" + name; @@ -173,9 +169,7 @@ LoadDefinitionFileResult loadDefinitionFile(TypeChecker& typeChecker, ScopePtr t CloneState cloneState; std::vector typesToPersist; - typesToPersist.reserve( - checkedModule->declaredGlobals.size() + - (FFlag::LuauScopelessModule ? checkedModule->exportedTypeBindings.size() : checkedModule->getModuleScope()->exportedTypeBindings.size())); + typesToPersist.reserve(checkedModule->declaredGlobals.size() + checkedModule->exportedTypeBindings.size()); for (const auto& [name, ty] : checkedModule->declaredGlobals) { @@ -187,8 +181,7 @@ LoadDefinitionFileResult loadDefinitionFile(TypeChecker& typeChecker, ScopePtr t typesToPersist.push_back(globalTy); } - for (const auto& [name, ty] : - FFlag::LuauScopelessModule ? checkedModule->exportedTypeBindings : checkedModule->getModuleScope()->exportedTypeBindings) + for (const auto& [name, ty] : checkedModule->exportedTypeBindings) { TypeFun globalTy = clone(ty, typeChecker.globalTypes, cloneState); std::string documentationSymbol = packageName + "/globaltype/" + name; @@ -571,30 +564,17 @@ CheckResult Frontend::check(const ModuleName& name, std::optionalinternalTypes.clear(); - if (FFlag::LuauScopelessModule) - { - module->astTypes.clear(); - module->astTypePacks.clear(); - module->astExpectedTypes.clear(); - module->astOriginalCallTypes.clear(); - module->astOverloadResolvedTypes.clear(); - module->astResolvedTypes.clear(); - module->astOriginalResolvedTypes.clear(); - module->astResolvedTypePacks.clear(); - module->astScopes.clear(); - - module->scopes.clear(); - } - else - { - module->astTypes.clear(); - module->astExpectedTypes.clear(); - module->astOriginalCallTypes.clear(); - module->astResolvedTypes.clear(); - module->astResolvedTypePacks.clear(); - module->astOriginalResolvedTypes.clear(); - module->scopes.resize(1); - } + module->astTypes.clear(); + module->astTypePacks.clear(); + module->astExpectedTypes.clear(); + module->astOriginalCallTypes.clear(); + module->astOverloadResolvedTypes.clear(); + module->astResolvedTypes.clear(); + module->astOriginalResolvedTypes.clear(); + module->astResolvedTypePacks.clear(); + module->astScopes.clear(); + + module->scopes.clear(); } if (mode != Mode::NoCheck) diff --git a/Analysis/src/Instantiation.cpp b/Analysis/src/Instantiation.cpp index 912c4155b..9c3ae0771 100644 --- a/Analysis/src/Instantiation.cpp +++ b/Analysis/src/Instantiation.cpp @@ -47,6 +47,7 @@ TypeId Instantiation::clean(TypeId ty) FunctionType clone = FunctionType{level, scope, ftv->argTypes, ftv->retTypes, ftv->definition, ftv->hasSelf}; clone.magicFunction = ftv->magicFunction; clone.dcrMagicFunction = ftv->dcrMagicFunction; + clone.dcrMagicRefinement = ftv->dcrMagicRefinement; clone.tags = ftv->tags; clone.argNames = ftv->argNames; TypeId result = addType(std::move(clone)); diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index a9faded53..c0f4405c6 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -18,7 +18,6 @@ LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); LUAU_FASTFLAGVARIABLE(LuauClonePublicInterfaceLess, false); LUAU_FASTFLAG(LuauSubstitutionReentrant); -LUAU_FASTFLAG(LuauScopelessModule); LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution); LUAU_FASTFLAG(LuauSubstitutionFixMissingFields); @@ -227,11 +226,8 @@ void Module::clonePublicInterface(NotNull builtinTypes, InternalEr } // Copy external stuff over to Module itself - if (FFlag::LuauScopelessModule) - { - this->returnType = moduleScope->returnType; - this->exportedTypeBindings = std::move(moduleScope->exportedTypeBindings); - } + this->returnType = moduleScope->returnType; + this->exportedTypeBindings = std::move(moduleScope->exportedTypeBindings); } bool Module::hasModuleScope() const diff --git a/Analysis/src/Refinement.cpp b/Analysis/src/Refinement.cpp index fb019f1df..459379ad9 100644 --- a/Analysis/src/Refinement.cpp +++ b/Analysis/src/Refinement.cpp @@ -29,4 +29,9 @@ RefinementId RefinementArena::proposition(DefId def, TypeId discriminantTy) return NotNull{allocator.allocate(Proposition{def, discriminantTy})}; } +RefinementId RefinementArena::variadic(const std::vector& refis) +{ + return NotNull{allocator.allocate(Variadic{refis})}; +} + } // namespace Luau diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 07bdbd4e3..e59c7e0ee 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -26,6 +26,7 @@ #include LUAU_FASTFLAGVARIABLE(DebugLuauMagicTypes, false) +LUAU_FASTFLAGVARIABLE(LuauDontExtendUnsealedRValueTables, false) LUAU_FASTINTVARIABLE(LuauTypeInferRecursionLimit, 165) LUAU_FASTINTVARIABLE(LuauTypeInferIterationLimit, 20000) LUAU_FASTINTVARIABLE(LuauTypeInferTypePackLoopLimit, 5000) @@ -35,7 +36,6 @@ LUAU_FASTFLAG(LuauKnowsTheDataModel3) LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) LUAU_FASTFLAGVARIABLE(LuauReturnAnyInsteadOfICE, false) // Eventually removed as false. LUAU_FASTFLAGVARIABLE(DebugLuauSharedSelf, false) -LUAU_FASTFLAGVARIABLE(LuauScopelessModule, false) LUAU_FASTFLAGVARIABLE(LuauTryhardAnd, false) LUAU_FASTFLAG(LuauInstantiateInSubtyping) LUAU_FASTFLAGVARIABLE(LuauIntersectionTestForEquality, false) @@ -43,6 +43,7 @@ LUAU_FASTFLAG(LuauNegatedClassTypes) LUAU_FASTFLAGVARIABLE(LuauAllowIndexClassParameters, false) LUAU_FASTFLAG(LuauUninhabitedSubAnything2) LUAU_FASTFLAG(SupportTypeAliasGoToDeclaration) +LUAU_FASTFLAGVARIABLE(LuauTypecheckTypeguards, false) namespace Luau { @@ -913,7 +914,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatAssign& assign) } else { - expectedTypes.push_back(checkLValue(scope, *dest)); + expectedTypes.push_back(checkLValue(scope, *dest, ValueContext::LValue)); } } @@ -930,7 +931,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatAssign& assign) TypeId left = nullptr; if (dest->is() || dest->is()) - left = checkLValue(scope, *dest); + left = checkLValue(scope, *dest, ValueContext::LValue); else left = *expectedTypes[i]; @@ -1119,8 +1120,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatLocal& local) if (ModulePtr module = resolver->getModule(moduleInfo->name)) { - scope->importedTypeBindings[name] = - FFlag::LuauScopelessModule ? module->exportedTypeBindings : module->getModuleScope()->exportedTypeBindings; + scope->importedTypeBindings[name] = module->exportedTypeBindings; if (FFlag::SupportTypeAliasGoToDeclaration) scope->importedModules[name] = moduleInfo->name; } @@ -2132,7 +2132,7 @@ TypeId TypeChecker::stripFromNilAndReport(TypeId ty, const Location& location) WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIndexExpr& expr) { - TypeId ty = checkLValue(scope, expr); + TypeId ty = checkLValue(scope, expr, ValueContext::RValue); if (std::optional lvalue = tryGetLValue(expr)) if (std::optional refiTy = resolveLValue(scope, *lvalue)) @@ -2977,14 +2977,23 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp } else if (expr.op == AstExprBinary::CompareEq || expr.op == AstExprBinary::CompareNe) { - if (auto predicate = tryGetTypeGuardPredicate(expr)) - return {booleanType, {std::move(*predicate)}}; + if (!FFlag::LuauTypecheckTypeguards) + { + if (auto predicate = tryGetTypeGuardPredicate(expr)) + return {booleanType, {std::move(*predicate)}}; + } // For these, passing expectedType is worse than simply forcing them, because their implementation // may inadvertently check if expectedTypes exist first and use it, instead of forceSingleton first. WithPredicate lhs = checkExpr(scope, *expr.left, std::nullopt, /*forceSingleton=*/true); WithPredicate rhs = checkExpr(scope, *expr.right, std::nullopt, /*forceSingleton=*/true); + if (FFlag::LuauTypecheckTypeguards) + { + if (auto predicate = tryGetTypeGuardPredicate(expr)) + return {booleanType, {std::move(*predicate)}}; + } + PredicateVec predicates; if (auto lvalue = tryGetLValue(*expr.left)) @@ -3068,21 +3077,21 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp return {stringType}; } -TypeId TypeChecker::checkLValue(const ScopePtr& scope, const AstExpr& expr) +TypeId TypeChecker::checkLValue(const ScopePtr& scope, const AstExpr& expr, ValueContext ctx) { - return checkLValueBinding(scope, expr); + return checkLValueBinding(scope, expr, ctx); } -TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExpr& expr) +TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExpr& expr, ValueContext ctx) { if (auto a = expr.as()) return checkLValueBinding(scope, *a); else if (auto a = expr.as()) return checkLValueBinding(scope, *a); else if (auto a = expr.as()) - return checkLValueBinding(scope, *a); + return checkLValueBinding(scope, *a, ctx); else if (auto a = expr.as()) - return checkLValueBinding(scope, *a); + return checkLValueBinding(scope, *a, ctx); else if (auto a = expr.as()) { for (AstExpr* expr : a->expressions) @@ -3130,7 +3139,7 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprGloba return result; } -TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndexName& expr) +TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndexName& expr, ValueContext ctx) { TypeId lhs = checkExpr(scope, *expr.expr).type; @@ -3153,7 +3162,15 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex { return it->second.type; } - else if (lhsTable->state == TableState::Unsealed || lhsTable->state == TableState::Free) + else if (!FFlag::LuauDontExtendUnsealedRValueTables && (lhsTable->state == TableState::Unsealed || lhsTable->state == TableState::Free)) + { + TypeId theType = freshType(scope); + Property& property = lhsTable->props[name]; + property.type = theType; + property.location = expr.indexLocation; + return theType; + } + else if (FFlag::LuauDontExtendUnsealedRValueTables && ((ctx == ValueContext::LValue && lhsTable->state == TableState::Unsealed) || lhsTable->state == TableState::Free)) { TypeId theType = freshType(scope); Property& property = lhsTable->props[name]; @@ -3216,7 +3233,7 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex return errorRecoveryType(scope); } -TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndexExpr& expr) +TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndexExpr& expr, ValueContext ctx) { TypeId exprType = checkExpr(scope, *expr.expr).type; tablify(exprType); @@ -3274,7 +3291,15 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex { return it->second.type; } - else if (exprTable->state == TableState::Unsealed || exprTable->state == TableState::Free) + else if (!FFlag::LuauDontExtendUnsealedRValueTables && (exprTable->state == TableState::Unsealed || exprTable->state == TableState::Free)) + { + TypeId resultType = freshType(scope); + Property& property = exprTable->props[value->value.data]; + property.type = resultType; + property.location = expr.index->location; + return resultType; + } + else if (FFlag::LuauDontExtendUnsealedRValueTables && ((ctx == ValueContext::LValue && exprTable->state == TableState::Unsealed) || exprTable->state == TableState::Free)) { TypeId resultType = freshType(scope); Property& property = exprTable->props[value->value.data]; @@ -3290,20 +3315,35 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex unify(indexType, indexer.indexType, scope, expr.index->location); return indexer.indexResultType; } - else if (exprTable->state == TableState::Unsealed || exprTable->state == TableState::Free) + else if (!FFlag::LuauDontExtendUnsealedRValueTables && (exprTable->state == TableState::Unsealed || exprTable->state == TableState::Free)) { TypeId resultType = freshType(exprTable->level); exprTable->indexer = TableIndexer{anyIfNonstrict(indexType), anyIfNonstrict(resultType)}; return resultType; } + else if (FFlag::LuauDontExtendUnsealedRValueTables && ((ctx == ValueContext::LValue && exprTable->state == TableState::Unsealed) || exprTable->state == TableState::Free)) + { + TypeId indexerType = freshType(exprTable->level); + unify(indexType, indexerType, scope, expr.location); + TypeId indexResultType = freshType(exprTable->level); + + exprTable->indexer = TableIndexer{anyIfNonstrict(indexerType), anyIfNonstrict(indexResultType)}; + return indexResultType; + } else { /* - * If we use [] indexing to fetch a property from a sealed table that has no indexer, we have no idea if it will - * work, so we just mint a fresh type, return that, and hope for the best. + * If we use [] indexing to fetch a property from a sealed table that + * has no indexer, we have no idea if it will work so we just return any + * and hope for the best. */ - TypeId resultType = freshType(scope); - return resultType; + if (FFlag::LuauDontExtendUnsealedRValueTables) + return anyType; + else + { + TypeId resultType = freshType(scope); + return resultType; + } } } @@ -4508,7 +4548,7 @@ TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& module return errorRecoveryType(scope); } - TypePackId modulePack = FFlag::LuauScopelessModule ? module->returnType : module->getModuleScope()->returnType; + TypePackId modulePack = module->returnType; if (get(modulePack)) return errorRecoveryType(scope); diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index d48c72f73..bda062af8 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -1038,6 +1038,24 @@ void Unifier::tryUnifyNormalizedTypes( } } + if (FFlag::DebugLuauDeferredConstraintResolution) + { + for (TypeId superTable : superNorm.tables) + { + Unifier innerState = makeChildUnifier(); + innerState.tryUnify(subClass, superTable); + + if (innerState.errors.empty()) + { + found = true; + log.concat(std::move(innerState.log)); + break; + } + else if (auto e = hasUnificationTooComplex(innerState.errors)) + return reportError(*e); + } + } + if (!found) { return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); diff --git a/CMakeLists.txt b/CMakeLists.txt index 05d701ee4..4255c7c25 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -172,6 +172,7 @@ endif() if(LUAU_BUILD_CLI) target_compile_options(Luau.Repl.CLI PRIVATE ${LUAU_OPTIONS}) + target_compile_options(Luau.Reduce.CLI PRIVATE ${LUAU_OPTIONS}) target_compile_options(Luau.Analyze.CLI PRIVATE ${LUAU_OPTIONS}) target_compile_options(Luau.Ast.CLI PRIVATE ${LUAU_OPTIONS}) diff --git a/CodeGen/include/Luau/IrAnalysis.h b/CodeGen/include/Luau/IrAnalysis.h index 7482b0ad6..0941d475d 100644 --- a/CodeGen/include/Luau/IrAnalysis.h +++ b/CodeGen/include/Luau/IrAnalysis.h @@ -8,7 +8,9 @@ namespace CodeGen struct IrFunction; -void updateUseInfo(IrFunction& function); +void updateUseCounts(IrFunction& function); + +void updateLastUseLocations(IrFunction& function); } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/include/Luau/IrBuilder.h b/CodeGen/include/Luau/IrBuilder.h index 5b51e0ad6..ebbba6893 100644 --- a/CodeGen/include/Luau/IrBuilder.h +++ b/CodeGen/include/Luau/IrBuilder.h @@ -50,6 +50,8 @@ struct IrBuilder IrOp vmConst(uint32_t index); IrOp vmUpvalue(uint8_t index); + bool inTerminatedBlock = false; + bool activeFastcallFallback = false; IrOp fastcallFallbackReturn; diff --git a/CodeGen/include/Luau/IrData.h b/CodeGen/include/Luau/IrData.h index 1c70c8017..28f5b29bd 100644 --- a/CodeGen/include/Luau/IrData.h +++ b/CodeGen/include/Luau/IrData.h @@ -16,31 +16,90 @@ namespace Luau namespace CodeGen { +// IR instruction command. +// In the command description, following abbreviations are used: +// * Rn - VM stack register slot, n in 0..254 +// * Kn - VM proto constant slot, n in 0..2^23-1 +// * UPn - VM function upvalue slot, n in 0..254 +// * A, B, C, D, E are instruction arguments enum class IrCmd : uint8_t { NOP, + // Load a tag from TValue + // A: Rn or Kn LOAD_TAG, + + // Load a pointer (*) from TValue + // A: Rn or Kn LOAD_POINTER, + + // Load a double number from TValue + // A: Rn or Kn LOAD_DOUBLE, + + // Load an int from TValue + // A: Rn LOAD_INT, + + // Load a TValue from memory + // A: Rn or Kn or pointer (TValue) LOAD_TVALUE, + + // Load a TValue from table node value + // A: pointer (LuaNode) LOAD_NODE_VALUE_TV, // TODO: we should find a way to generalize LOAD_TVALUE + + // Load current environment table LOAD_ENV, + // Get pointer (TValue) to table array at index + // A: pointer (Table) + // B: unsigned int GET_ARR_ADDR, + + // Get pointer (LuaNode) to table node element at the active cached slot index + // A: pointer (Table) GET_SLOT_NODE_ADDR, + // Store a tag into TValue + // A: Rn + // B: tag STORE_TAG, + + // Store a pointer (*) into TValue + // A: Rn + // B: pointer STORE_POINTER, + + // Store a double number into TValue + // A: Rn + // B: double STORE_DOUBLE, + + // Store an int into TValue + // A: Rn + // B: int STORE_INT, + + // Store a TValue into memory + // A: Rn or pointer (TValue) + // B: TValue STORE_TVALUE, + + // Store a TValue into table node value + // A: pointer (LuaNode) + // B: TValue STORE_NODE_VALUE_TV, // TODO: we should find a way to generalize STORE_TVALUE + // Add/Sub two integers together + // A, B: int ADD_INT, SUB_INT, + // Add/Sub/Mul/Div/Mod/Pow two double numbers + // A, B: double + // In final x64 lowering, B can also be Rn or Kn ADD_NUM, SUB_NUM, MUL_NUM, @@ -48,91 +107,351 @@ enum class IrCmd : uint8_t MOD_NUM, POW_NUM, + // Negate a double number + // A: double UNM_NUM, + // Compute Luau 'not' operation on destructured TValue + // A: tag + // B: double NOT_ANY, // TODO: boolean specialization will be useful + // Unconditional jump + // A: block JUMP, + + // Jump if TValue is truthy + // A: Rn + // B: block (if true) + // C: block (if false) JUMP_IF_TRUTHY, + + // Jump if TValue is falsy + // A: Rn + // B: block (if true) + // C: block (if false) JUMP_IF_FALSY, + + // Jump if tags are equal + // A, B: tag + // C: block (if true) + // D: block (if false) JUMP_EQ_TAG, - JUMP_EQ_BOOLEAN, + + // Jump if two int numbers are equal + // A, B: int + // C: block (if true) + // D: block (if false) + JUMP_EQ_INT, + + // Jump if pointers are equal + // A, B: pointer (*) + // C: block (if true) + // D: block (if false) JUMP_EQ_POINTER, + // Perform a conditional jump based on the result of double comparison + // A, B: double + // C: condition + // D: block (if true) + // E: block (if false) JUMP_CMP_NUM, - JUMP_CMP_STR, + + // Perform a conditional jump based on the result of TValue comparison + // A, B: Rn + // C: condition + // D: block (if true) + // E: block (if false) JUMP_CMP_ANY, + // Get table length + // A: pointer (Table) TABLE_LEN, + + // Allocate new table + // A: int (array element count) + // B: int (node element count) NEW_TABLE, + + // Duplicate a table + // A: pointer (Table) DUP_TABLE, + // Try to convert a double number into a table index or jump if it's not an integer + // A: double + // B: block NUM_TO_INDEX, + // Convert integer into a double number + // A: int + INT_TO_NUM, + // Fallback functions + + // Perform an arithmetic operation on TValues of any type + // A: Rn (where to store the result) + // B: Rn (lhs) + // C: Rn or Kn (rhs) DO_ARITH, + + // Get length of a TValue of any type + // A: Rn (where to store the result) + // B: Rn DO_LEN, + + // Lookup a value in TValue of any type using a key of any type + // A: Rn (where to store the result) + // B: Rn + // C: Rn or unsigned int (key) GET_TABLE, + + // Store a value into TValue of any type using a key of any type + // A: Rn (value to store) + // B: Rn + // C: Rn or unsigned int (key) SET_TABLE, + + // Lookup a value in the environment + // A: Rn (where to store the result) + // B: unsigned int (import path) GET_IMPORT, + + // Concatenate multiple TValues + // A: Rn (where to store the result) + // B: unsigned int (index of the first VM stack slot) + // C: unsigned int (number of stack slots to go over) CONCAT, + + // Load function upvalue into stack slot + // A: Rn + // B: UPn GET_UPVALUE, + + // Store TValue from stack slot into a function upvalue + // A: UPn + // B: Rn SET_UPVALUE, - // Guards and checks + // Convert TValues into numbers for a numerical for loop + // A: Rn (start) + // B: Rn (end) + // C: Rn (step) + PREPARE_FORN, + + // Guards and checks (these instructions are not block terminators even though they jump to fallback) + + // Guard against tag mismatch + // A, B: tag + // C: block + // In final x64 lowering, A can also be Rn CHECK_TAG, + + // Guard against readonly table + // A: pointer (Table) + // B: block CHECK_READONLY, + + // Guard against table having a metatable + // A: pointer (Table) + // B: block CHECK_NO_METATABLE, + + // Guard against executing in unsafe environment + // A: block CHECK_SAFE_ENV, + + // Guard against index overflowing the table array size + // A: pointer (Table) + // B: block CHECK_ARRAY_SIZE, + + // Guard against cached table node slot not matching the actual table node slot for a key + // A: pointer (LuaNode) + // B: Kn + // C: block CHECK_SLOT_MATCH, // Special operations + + // Check interrupt handler + // A: unsigned int (pcpos) INTERRUPT, + + // Check and run GC assist if necessary CHECK_GC, + + // Handle GC write barrier (forward) + // A: pointer (GCObject) + // B: Rn (TValue that was written to the object) BARRIER_OBJ, + + // Handle GC write barrier (backwards) for a write into a table + // A: pointer (Table) BARRIER_TABLE_BACK, + + // Handle GC write barrier (forward) for a write into a table + // A: pointer (Table) + // B: Rn (TValue that was written to the object) BARRIER_TABLE_FORWARD, + + // Update savedpc value + // A: unsigned int (pcpos) SET_SAVEDPC, + + // Close open upvalues for registers at specified index or higher + // A: Rn (starting register index) CLOSE_UPVALS, // While capture is a no-op right now, it might be useful to track register/upvalue lifetimes + // A: Rn or UPn + // B: boolean (true for reference capture, false for value capture) CAPTURE, // Operations that don't have an IR representation yet + + // Set a list of values to table in target register + // A: unsigned int (bytecode instruction index) + // B: Rn (target) + // C: Rn (source start) + // D: int (count or -1 to assign values up to stack top) + // E: unsigned int (table index to start from) LOP_SETLIST, + + // Load function from source register using name into target register and copying source register into target register + 1 + // A: unsigned int (bytecode instruction index) + // B: Rn (target) + // C: Rn (source) + // D: block (next) + // E: block (fallback) LOP_NAMECALL, + + // Call specified function + // A: unsigned int (bytecode instruction index) + // B: Rn (function, followed by arguments) + // C: int (argument count or -1 to preserve all arguments up to stack top) + // D: int (result count or -1 to preserve all results and adjust stack top) + // Note: return values are placed starting from Rn specified in 'B' LOP_CALL, + + // Return specified values from the function + // A: unsigned int (bytecode instruction index) + // B: Rn (value start) + // B: int (result count or -1 to return all values up to stack top) LOP_RETURN, + + // Perform a fast call of a built-in function + // A: unsigned int (bytecode instruction index) + // B: Rn (argument start) + // C: int (argument count or -1 preserve all arguments up to stack top) + // D: block (fallback) + // Note: return values are placed starting from Rn specified in 'B' LOP_FASTCALL, + + // Perform a fast call of a built-in function using 1 register argument + // A: unsigned int (bytecode instruction index) + // B: Rn (result start) + // C: Rn (arg1) + // D: block (fallback) LOP_FASTCALL1, + + // Perform a fast call of a built-in function using 2 register arguments + // A: unsigned int (bytecode instruction index) + // B: Rn (result start) + // C: Rn (arg1) + // D: Rn (arg2) + // E: block (fallback) LOP_FASTCALL2, + + // Perform a fast call of a built-in function using 1 register argument and 1 constant argument + // A: unsigned int (bytecode instruction index) + // B: Rn (result start) + // C: Rn (arg1) + // D: Kn (arg2) + // E: block (fallback) LOP_FASTCALL2K, - LOP_FORNPREP, - LOP_FORNLOOP, + LOP_FORGLOOP, LOP_FORGLOOP_FALLBACK, - LOP_FORGPREP_NEXT, - LOP_FORGPREP_INEXT, LOP_FORGPREP_XNEXT_FALLBACK, + + // Perform `and` or `or` operation (selecting lhs or rhs based on whether the lhs is truthy) and put the result into target register + // A: unsigned int (bytecode instruction index) + // B: Rn (target) + // C: Rn (lhs) + // D: Rn or Kn (rhs) LOP_AND, LOP_ANDK, LOP_OR, LOP_ORK, + + // Increment coverage data (saturating 24 bit add) + // A: unsigned int (bytecode instruction index) LOP_COVERAGE, // Operations that have a translation, but use a full instruction fallback + + // Load a value from global table at specified key + // A: unsigned int (bytecode instruction index) + // B: Rn (dest) + // C: Kn (key) FALLBACK_GETGLOBAL, + + // Store a value into global table at specified key + // A: unsigned int (bytecode instruction index) + // B: Rn (value) + // C: Kn (key) FALLBACK_SETGLOBAL, + + // Load a value from table at specified key + // A: unsigned int (bytecode instruction index) + // B: Rn (dest) + // C: Rn (table) + // D: Kn (key) FALLBACK_GETTABLEKS, + + // Store a value into a table at specified key + // A: unsigned int (bytecode instruction index) + // B: Rn (value) + // C: Rn (table) + // D: Kn (key) FALLBACK_SETTABLEKS, + + // Load function from source register using name into target register and copying source register into target register + 1 + // A: unsigned int (bytecode instruction index) + // B: Rn (target) + // C: Rn (source) + // D: Kn (name) FALLBACK_NAMECALL, // Operations that don't have assembly lowering at all + + // Prepare stack for variadic functions so that GETVARARGS works correctly + // A: unsigned int (bytecode instruction index) + // B: int (numparams) FALLBACK_PREPVARARGS, + + // Copy variables into the target registers from vararg storage for current function + // A: unsigned int (bytecode instruction index) + // B: Rn (dest start) + // C: int (count) FALLBACK_GETVARARGS, + + // Create closure from a child proto + // A: unsigned int (bytecode instruction index) + // B: Rn (dest) + // C: unsigned int (protoid) FALLBACK_NEWCLOSURE, + + // Create closure from a pre-created function object (reusing it unless environments diverge) + // A: unsigned int (bytecode instruction index) + // B: Rn (dest) + // C: Kn (prototype) FALLBACK_DUPCLOSURE, + + // Prepare loop variables for a generic for loop, jump to the loop backedge unconditionally + // A: unsigned int (bytecode instruction index) + // B: Rn (loop state, updates Rn Rn+1 Rn+2) + // B: block FALLBACK_FORGPREP, }; @@ -251,15 +570,18 @@ enum class IrBlockKind : uint8_t Bytecode, Fallback, Internal, + Dead, }; struct IrBlock { IrBlockKind kind; + uint16_t useCount = 0; + // Start points to an instruction index in a stream // End is implicit - uint32_t start; + uint32_t start = ~0u; Label label; }; @@ -279,6 +601,64 @@ struct IrFunction std::vector bcMapping; Proto* proto = nullptr; + + IrBlock& blockOp(IrOp op) + { + LUAU_ASSERT(op.kind == IrOpKind::Block); + return blocks[op.index]; + } + + IrInst& instOp(IrOp op) + { + LUAU_ASSERT(op.kind == IrOpKind::Inst); + return instructions[op.index]; + } + + IrConst& constOp(IrOp op) + { + LUAU_ASSERT(op.kind == IrOpKind::Constant); + return constants[op.index]; + } + + uint8_t tagOp(IrOp op) + { + IrConst& value = constOp(op); + + LUAU_ASSERT(value.kind == IrConstKind::Tag); + return value.valueTag; + } + + bool boolOp(IrOp op) + { + IrConst& value = constOp(op); + + LUAU_ASSERT(value.kind == IrConstKind::Bool); + return value.valueBool; + } + + int intOp(IrOp op) + { + IrConst& value = constOp(op); + + LUAU_ASSERT(value.kind == IrConstKind::Int); + return value.valueInt; + } + + unsigned uintOp(IrOp op) + { + IrConst& value = constOp(op); + + LUAU_ASSERT(value.kind == IrConstKind::Uint); + return value.valueUint; + } + + double doubleOp(IrOp op) + { + IrConst& value = constOp(op); + + LUAU_ASSERT(value.kind == IrConstKind::Double); + return value.valueDouble; + } }; } // namespace CodeGen diff --git a/CodeGen/include/Luau/IrDump.h b/CodeGen/include/Luau/IrDump.h index 2f44ea852..47a5f9e92 100644 --- a/CodeGen/include/Luau/IrDump.h +++ b/CodeGen/include/Luau/IrDump.h @@ -21,12 +21,16 @@ struct IrToStringContext std::vector& constants; }; -void toString(IrToStringContext& ctx, IrInst inst, uint32_t index); +void toString(IrToStringContext& ctx, const IrInst& inst, uint32_t index); +void toString(IrToStringContext& ctx, const IrBlock& block, uint32_t index); // Block title void toString(IrToStringContext& ctx, IrOp op); void toString(std::string& result, IrConst constant); -void toStringDetailed(IrToStringContext& ctx, IrInst inst, uint32_t index); +void toStringDetailed(IrToStringContext& ctx, const IrInst& inst, uint32_t index); +void toStringDetailed(IrToStringContext& ctx, const IrBlock& block, uint32_t index); // Block title + +std::string toString(IrFunction& function, bool includeDetails); std::string dump(IrFunction& function); diff --git a/CodeGen/include/Luau/IrUtils.h b/CodeGen/include/Luau/IrUtils.h index 843820556..1aef9a3fc 100644 --- a/CodeGen/include/Luau/IrUtils.h +++ b/CodeGen/include/Luau/IrUtils.h @@ -92,19 +92,14 @@ inline bool isBlockTerminator(IrCmd cmd) case IrCmd::JUMP_IF_TRUTHY: case IrCmd::JUMP_IF_FALSY: case IrCmd::JUMP_EQ_TAG: - case IrCmd::JUMP_EQ_BOOLEAN: + case IrCmd::JUMP_EQ_INT: case IrCmd::JUMP_EQ_POINTER: case IrCmd::JUMP_CMP_NUM: - case IrCmd::JUMP_CMP_STR: case IrCmd::JUMP_CMP_ANY: case IrCmd::LOP_NAMECALL: case IrCmd::LOP_RETURN: - case IrCmd::LOP_FORNPREP: - case IrCmd::LOP_FORNLOOP: case IrCmd::LOP_FORGLOOP: case IrCmd::LOP_FORGLOOP_FALLBACK: - case IrCmd::LOP_FORGPREP_NEXT: - case IrCmd::LOP_FORGPREP_INEXT: case IrCmd::LOP_FORGPREP_XNEXT_FALLBACK: case IrCmd::FALLBACK_FORGPREP: return true; @@ -142,6 +137,7 @@ inline bool hasResult(IrCmd cmd) case IrCmd::NEW_TABLE: case IrCmd::DUP_TABLE: case IrCmd::NUM_TO_INDEX: + case IrCmd::INT_TO_NUM: return true; default: break; @@ -157,5 +153,24 @@ inline bool hasSideEffects(IrCmd cmd) return !hasResult(cmd); } +// Remove a single instruction +void kill(IrFunction& function, IrInst& inst); + +// Remove a range of instructions +void kill(IrFunction& function, uint32_t start, uint32_t end); + +// Remove a block, including all instructions inside +void kill(IrFunction& function, IrBlock& block); + +void removeUse(IrFunction& function, IrInst& inst); +void removeUse(IrFunction& function, IrBlock& block); + +// Replace a single operand and update use counts (can cause chain removal of dead code) +void replace(IrFunction& function, IrOp& original, IrOp replacement); + +// Replace a single instruction +// Target instruction index instead of reference is used to handle introduction of a new block terminator +void replace(IrFunction& function, uint32_t instIdx, IrInst replacement); + } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/include/Luau/OptimizeFinalX64.h b/CodeGen/include/Luau/OptimizeFinalX64.h new file mode 100644 index 000000000..bc50dd74f --- /dev/null +++ b/CodeGen/include/Luau/OptimizeFinalX64.h @@ -0,0 +1,14 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/IrData.h" + +namespace Luau +{ +namespace CodeGen +{ + +void optimizeMemoryOperandsX64(IrFunction& function); + +} // namespace CodeGen +} // namespace Luau diff --git a/CodeGen/src/CodeGen.cpp b/CodeGen/src/CodeGen.cpp index 72b2cbb34..78f001f17 100644 --- a/CodeGen/src/CodeGen.cpp +++ b/CodeGen/src/CodeGen.cpp @@ -7,6 +7,7 @@ #include "Luau/CodeBlockUnwind.h" #include "Luau/IrAnalysis.h" #include "Luau/IrBuilder.h" +#include "Luau/OptimizeFinalX64.h" #include "Luau/UnwindBuilder.h" #include "Luau/UnwindBuilderDwarf2.h" #include "Luau/UnwindBuilderWin.h" @@ -431,7 +432,7 @@ static NativeProto* assembleFunction(AssemblyBuilderX64& build, NativeState& dat IrBuilder builder; builder.buildFunctionIr(proto); - updateUseInfo(builder.function); + optimizeMemoryOperandsX64(builder.function); IrLoweringX64 lowering(build, helpers, data, proto, builder.function); diff --git a/CodeGen/src/EmitCommonX64.cpp b/CodeGen/src/EmitCommonX64.cpp index e0dae6699..7d36e17de 100644 --- a/CodeGen/src/EmitCommonX64.cpp +++ b/CodeGen/src/EmitCommonX64.cpp @@ -195,13 +195,20 @@ static void callBarrierImpl(AssemblyBuilderX64& build, RegisterX64 tmp, Register if (object == rArg3) { LUAU_ASSERT(tmp != rArg2); - build.mov(rArg2, object); - build.mov(rArg3, tmp); + + if (rArg2 != object) + build.mov(rArg2, object); + + if (rArg3 != tmp) + build.mov(rArg3, tmp); } else { - build.mov(rArg3, tmp); - build.mov(rArg2, object); + if (rArg3 != tmp) + build.mov(rArg3, tmp); + + if (rArg2 != object) + build.mov(rArg2, object); } build.mov(rArg1, rState); diff --git a/CodeGen/src/IrAnalysis.cpp b/CodeGen/src/IrAnalysis.cpp index b9f3953a6..a27d78aa4 100644 --- a/CodeGen/src/IrAnalysis.cpp +++ b/CodeGen/src/IrAnalysis.cpp @@ -11,31 +11,56 @@ namespace Luau namespace CodeGen { -static void recordUse(IrInst& inst, size_t index) +void updateUseCounts(IrFunction& function) { - LUAU_ASSERT(inst.useCount < 0xffff); + std::vector& blocks = function.blocks; + std::vector& instructions = function.instructions; + + for (IrBlock& block : blocks) + block.useCount = 0; + + for (IrInst& inst : instructions) + inst.useCount = 0; - inst.useCount++; - inst.lastUse = uint32_t(index); + auto checkOp = [&](IrOp op) { + if (op.kind == IrOpKind::Inst) + { + IrInst& target = instructions[op.index]; + LUAU_ASSERT(target.useCount < 0xffff); + target.useCount++; + } + else if (op.kind == IrOpKind::Block) + { + IrBlock& target = blocks[op.index]; + LUAU_ASSERT(target.useCount < 0xffff); + target.useCount++; + } + }; + + for (IrInst& inst : instructions) + { + checkOp(inst.a); + checkOp(inst.b); + checkOp(inst.c); + checkOp(inst.d); + checkOp(inst.e); + } } -void updateUseInfo(IrFunction& function) +void updateLastUseLocations(IrFunction& function) { std::vector& instructions = function.instructions; for (IrInst& inst : instructions) - { - inst.useCount = 0; inst.lastUse = 0; - } - for (size_t i = 0; i < instructions.size(); ++i) + for (size_t instIdx = 0; instIdx < instructions.size(); ++instIdx) { - IrInst& inst = instructions[i]; + IrInst& inst = instructions[instIdx]; - auto checkOp = [&instructions, i](IrOp op) { + auto checkOp = [&](IrOp op) { if (op.kind == IrOpKind::Inst) - recordUse(instructions[op.index], i); + instructions[op.index].lastUse = uint32_t(instIdx); }; checkOp(inst.a); diff --git a/CodeGen/src/IrBuilder.cpp b/CodeGen/src/IrBuilder.cpp index 25f6d451a..9c5731067 100644 --- a/CodeGen/src/IrBuilder.cpp +++ b/CodeGen/src/IrBuilder.cpp @@ -2,6 +2,7 @@ #include "Luau/IrBuilder.h" #include "Luau/Common.h" +#include "Luau/IrAnalysis.h" #include "Luau/IrUtils.h" #include "CustomExecUtils.h" @@ -40,7 +41,9 @@ void IrBuilder::buildFunctionIr(Proto* proto) if (instIndexToBlock[i] != kNoAssociatedBlockIndex) beginBlock(blockAtInst(i)); - translateInst(op, pc, i); + // We skip dead bytecode instructions when they appear after block was already terminated + if (!inTerminatedBlock) + translateInst(op, pc, i); i = nexti; LUAU_ASSERT(i <= proto->sizecode); @@ -52,6 +55,9 @@ void IrBuilder::buildFunctionIr(Proto* proto) inst(IrCmd::JUMP, blockAtInst(i)); } } + + // Now that all has been generated, compute use counts + updateUseCounts(function); } void IrBuilder::rebuildBytecodeBasicBlocks(Proto* proto) @@ -120,7 +126,7 @@ void IrBuilder::translateInst(LuauOpcode op, const Instruction* pc, int i) translateInstSetGlobal(*this, pc, i); break; case LOP_CALL: - inst(IrCmd::LOP_CALL, constUint(i)); + inst(IrCmd::LOP_CALL, constUint(i), vmReg(LUAU_INSN_A(*pc)), constInt(LUAU_INSN_B(*pc) - 1), constInt(LUAU_INSN_C(*pc) - 1)); if (activeFastcallFallback) { @@ -132,7 +138,7 @@ void IrBuilder::translateInst(LuauOpcode op, const Instruction* pc, int i) } break; case LOP_RETURN: - inst(IrCmd::LOP_RETURN, constUint(i)); + inst(IrCmd::LOP_RETURN, constUint(i), vmReg(LUAU_INSN_A(*pc)), constInt(LUAU_INSN_B(*pc) - 1)); break; case LOP_GETTABLE: translateInstGetTable(*this, pc, i); @@ -249,7 +255,7 @@ void IrBuilder::translateInst(LuauOpcode op, const Instruction* pc, int i) translateInstDupTable(*this, pc, i); break; case LOP_SETLIST: - inst(IrCmd::LOP_SETLIST, constUint(i)); + inst(IrCmd::LOP_SETLIST, constUint(i), vmReg(LUAU_INSN_A(*pc)), vmReg(LUAU_INSN_A(*pc)), constInt(LUAU_INSN_C(*pc) - 1), constUint(pc[1])); break; case LOP_GETUPVAL: translateInstGetUpval(*this, pc, i); @@ -262,10 +268,15 @@ void IrBuilder::translateInst(LuauOpcode op, const Instruction* pc, int i) break; case LOP_FASTCALL: { + int skip = LUAU_INSN_C(*pc); + IrOp fallback = block(IrBlockKind::Fallback); - IrOp next = blockAtInst(i + LUAU_INSN_C(*pc) + 2); + IrOp next = blockAtInst(i + skip + 2); + + Instruction call = pc[skip + 1]; + LUAU_ASSERT(LUAU_INSN_OP(call) == LOP_CALL); - inst(IrCmd::LOP_FASTCALL, constUint(i), fallback); + inst(IrCmd::LOP_FASTCALL, constUint(i), vmReg(LUAU_INSN_A(call)), constInt(LUAU_INSN_B(call) - 1), fallback); inst(IrCmd::JUMP, next); beginBlock(fallback); @@ -276,10 +287,15 @@ void IrBuilder::translateInst(LuauOpcode op, const Instruction* pc, int i) } case LOP_FASTCALL1: { + int skip = LUAU_INSN_C(*pc); + IrOp fallback = block(IrBlockKind::Fallback); - IrOp next = blockAtInst(i + LUAU_INSN_C(*pc) + 2); + IrOp next = blockAtInst(i + skip + 2); + + Instruction call = pc[skip + 1]; + LUAU_ASSERT(LUAU_INSN_OP(call) == LOP_CALL); - inst(IrCmd::LOP_FASTCALL1, constUint(i), fallback); + inst(IrCmd::LOP_FASTCALL1, constUint(i), vmReg(LUAU_INSN_A(call)), vmReg(LUAU_INSN_B(*pc)), fallback); inst(IrCmd::JUMP, next); beginBlock(fallback); @@ -290,10 +306,15 @@ void IrBuilder::translateInst(LuauOpcode op, const Instruction* pc, int i) } case LOP_FASTCALL2: { + int skip = LUAU_INSN_C(*pc); + IrOp fallback = block(IrBlockKind::Fallback); - IrOp next = blockAtInst(i + LUAU_INSN_C(*pc) + 2); + IrOp next = blockAtInst(i + skip + 2); - inst(IrCmd::LOP_FASTCALL2, constUint(i), fallback); + Instruction call = pc[skip + 1]; + LUAU_ASSERT(LUAU_INSN_OP(call) == LOP_CALL); + + inst(IrCmd::LOP_FASTCALL2, constUint(i), vmReg(LUAU_INSN_A(call)), vmReg(LUAU_INSN_B(*pc)), vmReg(pc[1]), fallback); inst(IrCmd::JUMP, next); beginBlock(fallback); @@ -304,10 +325,15 @@ void IrBuilder::translateInst(LuauOpcode op, const Instruction* pc, int i) } case LOP_FASTCALL2K: { + int skip = LUAU_INSN_C(*pc); + IrOp fallback = block(IrBlockKind::Fallback); - IrOp next = blockAtInst(i + LUAU_INSN_C(*pc) + 2); + IrOp next = blockAtInst(i + skip + 2); + + Instruction call = pc[skip + 1]; + LUAU_ASSERT(LUAU_INSN_OP(call) == LOP_CALL); - inst(IrCmd::LOP_FASTCALL2K, constUint(i), fallback); + inst(IrCmd::LOP_FASTCALL2K, constUint(i), vmReg(LUAU_INSN_A(call)), vmReg(LUAU_INSN_B(*pc)), vmConst(pc[1]), fallback); inst(IrCmd::JUMP, next); beginBlock(fallback); @@ -317,72 +343,50 @@ void IrBuilder::translateInst(LuauOpcode op, const Instruction* pc, int i) break; } case LOP_FORNPREP: - { - IrOp loopStart = blockAtInst(i + getOpLength(LOP_FORNPREP)); - IrOp loopExit = blockAtInst(i + 1 + LUAU_INSN_D(*pc)); - - inst(IrCmd::LOP_FORNPREP, constUint(i), loopStart, loopExit); - - beginBlock(loopStart); + translateInstForNPrep(*this, pc, i); break; - } case LOP_FORNLOOP: - { - IrOp loopRepeat = blockAtInst(i + 1 + LUAU_INSN_D(*pc)); - IrOp loopExit = blockAtInst(i + getOpLength(LOP_FORNLOOP)); - - inst(IrCmd::LOP_FORNLOOP, constUint(i), loopRepeat, loopExit); - - beginBlock(loopExit); + translateInstForNLoop(*this, pc, i); break; - } case LOP_FORGLOOP: { - IrOp loopRepeat = blockAtInst(i + 1 + LUAU_INSN_D(*pc)); - IrOp loopExit = blockAtInst(i + getOpLength(LOP_FORGLOOP)); - IrOp fallback = block(IrBlockKind::Fallback); + // We have a translation for ipairs-style traversal, general loop iteration is still too complex + if (int(pc[1]) < 0) + { + translateInstForGLoopIpairs(*this, pc, i); + } + else + { + IrOp loopRepeat = blockAtInst(i + 1 + LUAU_INSN_D(*pc)); + IrOp loopExit = blockAtInst(i + getOpLength(LOP_FORGLOOP)); + IrOp fallback = block(IrBlockKind::Fallback); - inst(IrCmd::LOP_FORGLOOP, constUint(i), loopRepeat, loopExit, fallback); + inst(IrCmd::LOP_FORGLOOP, constUint(i), loopRepeat, loopExit, fallback); - beginBlock(fallback); - inst(IrCmd::LOP_FORGLOOP_FALLBACK, constUint(i), loopRepeat, loopExit); + beginBlock(fallback); + inst(IrCmd::LOP_FORGLOOP_FALLBACK, constUint(i), loopRepeat, loopExit); - beginBlock(loopExit); + beginBlock(loopExit); + } break; } case LOP_FORGPREP_NEXT: - { - IrOp target = blockAtInst(i + 1 + LUAU_INSN_D(*pc)); - IrOp fallback = block(IrBlockKind::Fallback); - - inst(IrCmd::LOP_FORGPREP_NEXT, constUint(i), target, fallback); - - beginBlock(fallback); - inst(IrCmd::LOP_FORGPREP_XNEXT_FALLBACK, constUint(i), target); + translateInstForGPrepNext(*this, pc, i); break; - } case LOP_FORGPREP_INEXT: - { - IrOp target = blockAtInst(i + 1 + LUAU_INSN_D(*pc)); - IrOp fallback = block(IrBlockKind::Fallback); - - inst(IrCmd::LOP_FORGPREP_INEXT, constUint(i), target, fallback); - - beginBlock(fallback); - inst(IrCmd::LOP_FORGPREP_XNEXT_FALLBACK, constUint(i), target); + translateInstForGPrepInext(*this, pc, i); break; - } case LOP_AND: - inst(IrCmd::LOP_AND, constUint(i)); + inst(IrCmd::LOP_AND, constUint(i), vmReg(LUAU_INSN_A(*pc)), vmReg(LUAU_INSN_B(*pc)), vmReg(LUAU_INSN_C(*pc))); break; case LOP_ANDK: - inst(IrCmd::LOP_ANDK, constUint(i)); + inst(IrCmd::LOP_ANDK, constUint(i), vmReg(LUAU_INSN_A(*pc)), vmReg(LUAU_INSN_B(*pc)), vmConst(LUAU_INSN_C(*pc))); break; case LOP_OR: - inst(IrCmd::LOP_OR, constUint(i)); + inst(IrCmd::LOP_OR, constUint(i), vmReg(LUAU_INSN_A(*pc)), vmReg(LUAU_INSN_B(*pc)), vmReg(LUAU_INSN_C(*pc))); break; case LOP_ORK: - inst(IrCmd::LOP_ORK, constUint(i)); + inst(IrCmd::LOP_ORK, constUint(i), vmReg(LUAU_INSN_A(*pc)), vmReg(LUAU_INSN_B(*pc)), vmConst(LUAU_INSN_C(*pc))); break; case LOP_COVERAGE: inst(IrCmd::LOP_COVERAGE, constUint(i)); @@ -401,30 +405,34 @@ void IrBuilder::translateInst(LuauOpcode op, const Instruction* pc, int i) IrOp next = blockAtInst(i + getOpLength(LOP_NAMECALL)); IrOp fallback = block(IrBlockKind::Fallback); - inst(IrCmd::LOP_NAMECALL, constUint(i), next, fallback); + inst(IrCmd::LOP_NAMECALL, constUint(i), vmReg(LUAU_INSN_A(*pc)), vmReg(LUAU_INSN_B(*pc)), next, fallback); beginBlock(fallback); - inst(IrCmd::FALLBACK_NAMECALL, constUint(i)); + inst(IrCmd::FALLBACK_NAMECALL, constUint(i), vmReg(LUAU_INSN_A(*pc)), vmReg(LUAU_INSN_B(*pc)), vmConst(pc[1])); inst(IrCmd::JUMP, next); beginBlock(next); break; } case LOP_PREPVARARGS: - inst(IrCmd::FALLBACK_PREPVARARGS, constUint(i)); + inst(IrCmd::FALLBACK_PREPVARARGS, constUint(i), constInt(LUAU_INSN_A(*pc))); break; case LOP_GETVARARGS: - inst(IrCmd::FALLBACK_GETVARARGS, constUint(i)); + inst(IrCmd::FALLBACK_GETVARARGS, constUint(i), vmReg(LUAU_INSN_A(*pc)), constInt(LUAU_INSN_B(*pc) - 1)); break; case LOP_NEWCLOSURE: - inst(IrCmd::FALLBACK_NEWCLOSURE, constUint(i)); + inst(IrCmd::FALLBACK_NEWCLOSURE, constUint(i), vmReg(LUAU_INSN_A(*pc)), constUint(LUAU_INSN_D(*pc))); break; case LOP_DUPCLOSURE: - inst(IrCmd::FALLBACK_DUPCLOSURE, constUint(i)); + inst(IrCmd::FALLBACK_DUPCLOSURE, constUint(i), vmReg(LUAU_INSN_A(*pc)), vmConst(LUAU_INSN_D(*pc))); break; case LOP_FORGPREP: - inst(IrCmd::FALLBACK_FORGPREP, constUint(i)); + { + IrOp loopStart = blockAtInst(i + 1 + LUAU_INSN_D(*pc)); + + inst(IrCmd::FALLBACK_FORGPREP, constUint(i), vmReg(LUAU_INSN_A(*pc)), loopStart); break; + } default: LUAU_ASSERT(!"unknown instruction"); break; @@ -445,6 +453,8 @@ void IrBuilder::beginBlock(IrOp block) LUAU_ASSERT(target.start == ~0u || target.start == uint32_t(function.instructions.size())); target.start = uint32_t(function.instructions.size()); + + inTerminatedBlock = false; } IrOp IrBuilder::constBool(bool value) @@ -528,6 +538,10 @@ IrOp IrBuilder::inst(IrCmd cmd, IrOp a, IrOp b, IrOp c, IrOp d, IrOp e) { uint32_t index = uint32_t(function.instructions.size()); function.instructions.push_back({cmd, a, b, c, d, e}); + + if (isBlockTerminator(cmd)) + inTerminatedBlock = true; + return {IrOpKind::Inst, index}; } @@ -537,7 +551,7 @@ IrOp IrBuilder::block(IrBlockKind kind) kind = IrBlockKind::Fallback; uint32_t index = uint32_t(function.blocks.size()); - function.blocks.push_back(IrBlock{kind, ~0u}); + function.blocks.push_back(IrBlock{kind}); return IrOp{IrOpKind::Block, index}; } diff --git a/CodeGen/src/IrDump.cpp b/CodeGen/src/IrDump.cpp index eb5a07440..5a23861e4 100644 --- a/CodeGen/src/IrDump.cpp +++ b/CodeGen/src/IrDump.cpp @@ -29,6 +29,14 @@ static void append(std::string& result, const char* fmt, ...) result.append(buf); } +static void padToDetailColumn(std::string& result, size_t lineStart) +{ + int pad = kDetailsAlignColumn - int(result.size() - lineStart); + + if (pad > 0) + result.append(pad, ' '); +} + static const char* getTagName(uint8_t tag) { switch (tag) @@ -122,14 +130,12 @@ const char* getCmdName(IrCmd cmd) return "JUMP_IF_FALSY"; case IrCmd::JUMP_EQ_TAG: return "JUMP_EQ_TAG"; - case IrCmd::JUMP_EQ_BOOLEAN: - return "JUMP_EQ_BOOLEAN"; + case IrCmd::JUMP_EQ_INT: + return "JUMP_EQ_INT"; case IrCmd::JUMP_EQ_POINTER: return "JUMP_EQ_POINTER"; case IrCmd::JUMP_CMP_NUM: return "JUMP_CMP_NUM"; - case IrCmd::JUMP_CMP_STR: - return "JUMP_CMP_STR"; case IrCmd::JUMP_CMP_ANY: return "JUMP_CMP_ANY"; case IrCmd::TABLE_LEN: @@ -140,6 +146,8 @@ const char* getCmdName(IrCmd cmd) return "DUP_TABLE"; case IrCmd::NUM_TO_INDEX: return "NUM_TO_INDEX"; + case IrCmd::INT_TO_NUM: + return "INT_TO_NUM"; case IrCmd::DO_ARITH: return "DO_ARITH"; case IrCmd::DO_LEN: @@ -156,6 +164,8 @@ const char* getCmdName(IrCmd cmd) return "GET_UPVALUE"; case IrCmd::SET_UPVALUE: return "SET_UPVALUE"; + case IrCmd::PREPARE_FORN: + return "PREPARE_FORN"; case IrCmd::CHECK_TAG: return "CHECK_TAG"; case IrCmd::CHECK_READONLY: @@ -200,18 +210,10 @@ const char* getCmdName(IrCmd cmd) return "LOP_FASTCALL2"; case IrCmd::LOP_FASTCALL2K: return "LOP_FASTCALL2K"; - case IrCmd::LOP_FORNPREP: - return "LOP_FORNPREP"; - case IrCmd::LOP_FORNLOOP: - return "LOP_FORNLOOP"; case IrCmd::LOP_FORGLOOP: return "LOP_FORGLOOP"; case IrCmd::LOP_FORGLOOP_FALLBACK: return "LOP_FORGLOOP_FALLBACK"; - case IrCmd::LOP_FORGPREP_NEXT: - return "LOP_FORGPREP_NEXT"; - case IrCmd::LOP_FORGPREP_INEXT: - return "LOP_FORGPREP_INEXT"; case IrCmd::LOP_FORGPREP_XNEXT_FALLBACK: return "LOP_FORGPREP_XNEXT_FALLBACK"; case IrCmd::LOP_AND: @@ -259,12 +261,14 @@ const char* getBlockKindName(IrBlockKind kind) return "bb_fallback"; case IrBlockKind::Internal: return "bb"; + case IrBlockKind::Dead: + return "dead"; } LUAU_UNREACHABLE(); } -void toString(IrToStringContext& ctx, IrInst inst, uint32_t index) +void toString(IrToStringContext& ctx, const IrInst& inst, uint32_t index) { append(ctx.result, " "); @@ -305,6 +309,11 @@ void toString(IrToStringContext& ctx, IrInst inst, uint32_t index) } } +void toString(IrToStringContext& ctx, const IrBlock& block, uint32_t index) +{ + append(ctx.result, "%s_%u:", getBlockKindName(block.kind), index); +} + void toString(IrToStringContext& ctx, IrOp op) { switch (op.kind) @@ -358,18 +367,12 @@ void toString(std::string& result, IrConst constant) } } -void toStringDetailed(IrToStringContext& ctx, IrInst inst, uint32_t index) +void toStringDetailed(IrToStringContext& ctx, const IrInst& inst, uint32_t index) { size_t start = ctx.result.size(); toString(ctx, inst, index); - - int pad = kDetailsAlignColumn - int(ctx.result.size() - start); - - if (pad > 0) - ctx.result.append(pad, ' '); - - LUAU_ASSERT(inst.useCount == 0 || inst.lastUse != 0); + padToDetailColumn(ctx.result, start); if (inst.useCount == 0 && hasSideEffects(inst.cmd)) append(ctx.result, "; %%%u, has side-effects\n", index); @@ -377,7 +380,17 @@ void toStringDetailed(IrToStringContext& ctx, IrInst inst, uint32_t index) append(ctx.result, "; useCount: %d, lastUse: %%%u\n", inst.useCount, inst.lastUse); } -std::string dump(IrFunction& function) +void toStringDetailed(IrToStringContext& ctx, const IrBlock& block, uint32_t index) +{ + size_t start = ctx.result.size(); + + toString(ctx, block, index); + padToDetailColumn(ctx.result, start); + + append(ctx.result, "; useCount: %d\n", block.useCount); +} + +std::string toString(IrFunction& function, bool includeDetails) { std::string result; IrToStringContext ctx{result, function.blocks, function.constants}; @@ -386,7 +399,18 @@ std::string dump(IrFunction& function) { IrBlock& block = function.blocks[i]; - append(ctx.result, "%s_%u:\n", getBlockKindName(block.kind), unsigned(i)); + if (block.kind == IrBlockKind::Dead) + continue; + + if (includeDetails) + { + toStringDetailed(ctx, block, uint32_t(i)); + } + else + { + toString(ctx, block, uint32_t(i)); + ctx.result.append("\n"); + } if (block.start == ~0u) { @@ -394,10 +418,9 @@ std::string dump(IrFunction& function) continue; } - for (uint32_t index = block.start; true; index++) + // To allow dumping blocks that are still being constructed, we can't rely on terminator and need a bounds check + for (uint32_t index = block.start; index < uint32_t(function.instructions.size()); index++) { - LUAU_ASSERT(index < function.instructions.size()); - IrInst& inst = function.instructions[index]; // Nop is used to replace dead instructions in-place, so it's not that useful to see them @@ -405,7 +428,16 @@ std::string dump(IrFunction& function) continue; append(ctx.result, " "); - toStringDetailed(ctx, inst, index); + + if (includeDetails) + { + toStringDetailed(ctx, inst, index); + } + else + { + toString(ctx, inst, index); + ctx.result.append("\n"); + } if (isBlockTerminator(inst.cmd)) { @@ -415,6 +447,13 @@ std::string dump(IrFunction& function) } } + return result; +} + +std::string dump(IrFunction& function) +{ + std::string result = toString(function, /* includeDetails */ true); + printf("%s\n", result.c_str()); return result; diff --git a/CodeGen/src/IrLoweringX64.cpp b/CodeGen/src/IrLoweringX64.cpp index 62b5dea59..03bb18146 100644 --- a/CodeGen/src/IrLoweringX64.cpp +++ b/CodeGen/src/IrLoweringX64.cpp @@ -3,6 +3,7 @@ #include "Luau/CodeGen.h" #include "Luau/DenseHash.h" +#include "Luau/IrAnalysis.h" #include "Luau/IrDump.h" #include "Luau/IrUtils.h" @@ -30,6 +31,9 @@ IrLoweringX64::IrLoweringX64(AssemblyBuilderX64& build, ModuleHelpers& helpers, { freeGprMap.fill(true); freeXmmMap.fill(true); + + // In order to allocate registers during lowering, we need to know where instruction results are last used + updateLastUseLocations(function); } void IrLoweringX64::lower(AssemblyOptions options) @@ -93,6 +97,9 @@ void IrLoweringX64::lower(AssemblyOptions options) IrBlock& block = function.blocks[blockIndex]; LUAU_ASSERT(block.start != ~0u); + if (block.kind == IrBlockKind::Dead) + continue; + // If we want to skip fallback code IR/asm, we'll record when those blocks start once we see them if (block.kind == IrBlockKind::Fallback && !seenFallback) { @@ -102,7 +109,10 @@ void IrLoweringX64::lower(AssemblyOptions options) } if (options.includeIr) - build.logAppend("# %s_%u:\n", getBlockKindName(block.kind), blockIndex); + { + build.logAppend("# "); + toStringDetailed(ctx, block, uint32_t(i)); + } build.setLabel(block.label); @@ -179,6 +189,10 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.mov(inst.regX64, luauRegTag(inst.a.index)); else if (inst.a.kind == IrOpKind::VmConst) build.mov(inst.regX64, luauConstantTag(inst.a.index)); + // If we have a register, we assume it's a pointer to TValue + // We might introduce explicit operand types in the future to make this more robust + else if (inst.a.kind == IrOpKind::Inst) + build.mov(inst.regX64, dword[regOp(inst.a) + offsetof(TValue, tt)]); else LUAU_ASSERT(!"Unsupported instruction form"); break; @@ -237,7 +251,9 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) { inst.regX64 = allocGprRegOrReuse(SizeX64::qword, index, {inst.b}); - build.mov(dwordReg(inst.regX64), regOp(inst.b)); + if (dwordReg(inst.regX64) != regOp(inst.b)) + build.mov(dwordReg(inst.regX64), regOp(inst.b)); + build.shl(dwordReg(inst.regX64), kTValueSizeLog2); build.add(inst.regX64, qword[regOp(inst.a) + offsetof(Table, array)]); } @@ -442,7 +458,14 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) } else { - LUAU_ASSERT(!"Unsupported instruction form"); + if (lhs != xmm0) + build.vmovsd(xmm0, lhs, lhs); + + build.vmovsd(xmm1, memRegDoubleOp(inst.b)); + build.call(qword[rNativeContext + offsetof(NativeContext, libm_pow)]); + + if (inst.regX64 != xmm0) + build.vmovsd(inst.regX64, xmm0, xmm0); } break; @@ -525,8 +548,8 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) } break; } - case IrCmd::JUMP_EQ_BOOLEAN: - build.cmp(regOp(inst.a), boolOp(inst.b) ? 1 : 0); + case IrCmd::JUMP_EQ_INT: + build.cmp(regOp(inst.a), intOp(inst.b)); build.jcc(ConditionX64::Equal, labelOp(inst.c)); jumpOrFallthrough(blockOp(inst.d), next); @@ -576,7 +599,9 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.mov(dwordReg(rArg2), uintOp(inst.a)); build.mov(dwordReg(rArg3), uintOp(inst.b)); build.call(qword[rNativeContext + offsetof(NativeContext, luaH_new)]); - build.mov(inst.regX64, rax); + + if (inst.regX64 != rax) + build.mov(inst.regX64, rax); break; case IrCmd::DUP_TABLE: inst.regX64 = allocGprReg(SizeX64::qword); @@ -585,7 +610,9 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.mov(rArg2, regOp(inst.a)); build.mov(rArg1, rState); build.call(qword[rNativeContext + offsetof(NativeContext, luaH_clone)]); - build.mov(inst.regX64, rax); + + if (inst.regX64 != rax) + build.mov(inst.regX64, rax); break; case IrCmd::NUM_TO_INDEX: { @@ -596,6 +623,11 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) convertNumberToIndexOrJump(build, tmp.reg, regOp(inst.a), inst.regX64, labelOp(inst.b)); break; } + case IrCmd::INT_TO_NUM: + inst.regX64 = allocXmmReg(); + + build.vcvtsi2sd(inst.regX64, inst.regX64, regOp(inst.a)); + break; case IrCmd::DO_ARITH: LUAU_ASSERT(inst.a.kind == IrOpKind::VmReg); LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); @@ -711,6 +743,9 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.setLabel(next); break; } + case IrCmd::PREPARE_FORN: + callPrepareForN(build, inst.a.index, inst.b.index, inst.c.index); + break; case IrCmd::CHECK_TAG: if (inst.a.kind == IrOpKind::Inst) { @@ -828,7 +863,9 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.cmp(tmp2.reg, qword[tmp1.reg + offsetof(UpVal, v)]); build.jcc(ConditionX64::Above, next); - build.mov(rArg2, tmp2.reg); + if (rArg2 != tmp2.reg) + build.mov(rArg2, tmp2.reg); + build.mov(rArg1, rState); build.call(qword[rNativeContext + offsetof(NativeContext, luaF_close)]); @@ -843,6 +880,10 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) case IrCmd::LOP_SETLIST: { const Instruction* pc = proto->code + uintOp(inst.a); + LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); + LUAU_ASSERT(inst.c.kind == IrOpKind::VmReg); + LUAU_ASSERT(inst.d.kind == IrOpKind::Constant); + LUAU_ASSERT(inst.e.kind == IrOpKind::Constant); Label next; emitInstSetList(build, pc, next); @@ -852,13 +893,18 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) case IrCmd::LOP_NAMECALL: { const Instruction* pc = proto->code + uintOp(inst.a); + LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); + LUAU_ASSERT(inst.c.kind == IrOpKind::VmReg); - emitInstNameCall(build, pc, uintOp(inst.a), proto->k, blockOp(inst.b).label, blockOp(inst.c).label); + emitInstNameCall(build, pc, uintOp(inst.a), proto->k, blockOp(inst.d).label, blockOp(inst.e).label); break; } case IrCmd::LOP_CALL: { const Instruction* pc = proto->code + uintOp(inst.a); + LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); + LUAU_ASSERT(inst.c.kind == IrOpKind::Constant); + LUAU_ASSERT(inst.d.kind == IrOpKind::Constant); emitInstCall(build, helpers, pc, uintOp(inst.a)); break; @@ -866,27 +912,37 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) case IrCmd::LOP_RETURN: { const Instruction* pc = proto->code + uintOp(inst.a); + LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); + LUAU_ASSERT(inst.c.kind == IrOpKind::Constant); emitInstReturn(build, helpers, pc, uintOp(inst.a)); break; } case IrCmd::LOP_FASTCALL: - emitInstFastCall(build, proto->code + uintOp(inst.a), uintOp(inst.a), labelOp(inst.b)); + LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); + LUAU_ASSERT(inst.c.kind == IrOpKind::Constant); + + emitInstFastCall(build, proto->code + uintOp(inst.a), uintOp(inst.a), labelOp(inst.d)); break; case IrCmd::LOP_FASTCALL1: - emitInstFastCall1(build, proto->code + uintOp(inst.a), uintOp(inst.a), labelOp(inst.b)); + LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); + LUAU_ASSERT(inst.c.kind == IrOpKind::VmReg); + + emitInstFastCall1(build, proto->code + uintOp(inst.a), uintOp(inst.a), labelOp(inst.d)); break; case IrCmd::LOP_FASTCALL2: - emitInstFastCall2(build, proto->code + uintOp(inst.a), uintOp(inst.a), labelOp(inst.b)); + LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); + LUAU_ASSERT(inst.c.kind == IrOpKind::VmReg); + LUAU_ASSERT(inst.d.kind == IrOpKind::VmReg); + + emitInstFastCall2(build, proto->code + uintOp(inst.a), uintOp(inst.a), labelOp(inst.e)); break; case IrCmd::LOP_FASTCALL2K: - emitInstFastCall2K(build, proto->code + uintOp(inst.a), uintOp(inst.a), labelOp(inst.b)); - break; - case IrCmd::LOP_FORNPREP: - emitInstForNPrep(build, proto->code + uintOp(inst.a), uintOp(inst.a), labelOp(inst.b), labelOp(inst.c)); - break; - case IrCmd::LOP_FORNLOOP: - emitInstForNLoop(build, proto->code + uintOp(inst.a), uintOp(inst.a), labelOp(inst.b), labelOp(inst.c)); + LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); + LUAU_ASSERT(inst.c.kind == IrOpKind::VmReg); + LUAU_ASSERT(inst.d.kind == IrOpKind::VmConst); + + emitInstFastCall2K(build, proto->code + uintOp(inst.a), uintOp(inst.a), labelOp(inst.e)); break; case IrCmd::LOP_FORGLOOP: emitinstForGLoop(build, proto->code + uintOp(inst.a), uintOp(inst.a), labelOp(inst.b), labelOp(inst.c), labelOp(inst.d)); @@ -895,12 +951,6 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) emitinstForGLoopFallback(build, proto->code + uintOp(inst.a), uintOp(inst.a), labelOp(inst.b)); build.jmp(labelOp(inst.c)); break; - case IrCmd::LOP_FORGPREP_NEXT: - emitInstForGPrepNext(build, proto->code + uintOp(inst.a), labelOp(inst.b), labelOp(inst.c)); - break; - case IrCmd::LOP_FORGPREP_INEXT: - emitInstForGPrepInext(build, proto->code + uintOp(inst.a), labelOp(inst.b), labelOp(inst.c)); - break; case IrCmd::LOP_FORGPREP_XNEXT_FALLBACK: emitInstForGPrepXnextFallback(build, proto->code + uintOp(inst.a), uintOp(inst.a), labelOp(inst.b)); break; @@ -922,30 +972,59 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) // Full instruction fallbacks case IrCmd::FALLBACK_GETGLOBAL: + LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); + LUAU_ASSERT(inst.c.kind == IrOpKind::VmConst); + emitFallback(build, data, LOP_GETGLOBAL, uintOp(inst.a)); break; case IrCmd::FALLBACK_SETGLOBAL: + LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); + LUAU_ASSERT(inst.c.kind == IrOpKind::VmConst); + emitFallback(build, data, LOP_SETGLOBAL, uintOp(inst.a)); break; case IrCmd::FALLBACK_GETTABLEKS: + LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); + LUAU_ASSERT(inst.c.kind == IrOpKind::VmReg); + LUAU_ASSERT(inst.d.kind == IrOpKind::VmConst); + emitFallback(build, data, LOP_GETTABLEKS, uintOp(inst.a)); break; case IrCmd::FALLBACK_SETTABLEKS: + LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); + LUAU_ASSERT(inst.c.kind == IrOpKind::VmReg); + LUAU_ASSERT(inst.d.kind == IrOpKind::VmConst); + emitFallback(build, data, LOP_SETTABLEKS, uintOp(inst.a)); break; case IrCmd::FALLBACK_NAMECALL: + LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); + LUAU_ASSERT(inst.c.kind == IrOpKind::VmReg); + LUAU_ASSERT(inst.d.kind == IrOpKind::VmConst); + emitFallback(build, data, LOP_NAMECALL, uintOp(inst.a)); break; case IrCmd::FALLBACK_PREPVARARGS: + LUAU_ASSERT(inst.b.kind == IrOpKind::Constant); + emitFallback(build, data, LOP_PREPVARARGS, uintOp(inst.a)); break; case IrCmd::FALLBACK_GETVARARGS: + LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); + LUAU_ASSERT(inst.c.kind == IrOpKind::Constant); + emitFallback(build, data, LOP_GETVARARGS, uintOp(inst.a)); break; case IrCmd::FALLBACK_NEWCLOSURE: + LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); + LUAU_ASSERT(inst.c.kind == IrOpKind::Constant); + emitFallback(build, data, LOP_NEWCLOSURE, uintOp(inst.a)); break; case IrCmd::FALLBACK_DUPCLOSURE: + LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); + LUAU_ASSERT(inst.c.kind == IrOpKind::VmConst); + emitFallback(build, data, LOP_DUPCLOSURE, uintOp(inst.a)); break; case IrCmd::FALLBACK_FORGPREP: @@ -1006,60 +1085,42 @@ OperandX64 IrLoweringX64::memRegTagOp(IrOp op) const RegisterX64 IrLoweringX64::regOp(IrOp op) const { - LUAU_ASSERT(op.kind == IrOpKind::Inst); - return function.instructions[op.index].regX64; + return function.instOp(op).regX64; } IrConst IrLoweringX64::constOp(IrOp op) const { - LUAU_ASSERT(op.kind == IrOpKind::Constant); - return function.constants[op.index]; + return function.constOp(op); } uint8_t IrLoweringX64::tagOp(IrOp op) const { - IrConst value = constOp(op); - - LUAU_ASSERT(value.kind == IrConstKind::Tag); - return value.valueTag; + return function.tagOp(op); } bool IrLoweringX64::boolOp(IrOp op) const { - IrConst value = constOp(op); - - LUAU_ASSERT(value.kind == IrConstKind::Bool); - return value.valueBool; + return function.boolOp(op); } int IrLoweringX64::intOp(IrOp op) const { - IrConst value = constOp(op); - - LUAU_ASSERT(value.kind == IrConstKind::Int); - return value.valueInt; + return function.intOp(op); } unsigned IrLoweringX64::uintOp(IrOp op) const { - IrConst value = constOp(op); - - LUAU_ASSERT(value.kind == IrConstKind::Uint); - return value.valueUint; + return function.uintOp(op); } double IrLoweringX64::doubleOp(IrOp op) const { - IrConst value = constOp(op); - - LUAU_ASSERT(value.kind == IrConstKind::Double); - return value.valueDouble; + return function.doubleOp(op); } IrBlock& IrLoweringX64::blockOp(IrOp op) const { - LUAU_ASSERT(op.kind == IrOpKind::Block); - return function.blocks[op.index]; + return function.blockOp(op); } Label& IrLoweringX64::labelOp(IrOp op) const @@ -1162,7 +1223,9 @@ void IrLoweringX64::freeLastUseReg(IrInst& target, uint32_t index) { if (target.lastUse == index && !target.reusedReg) { - LUAU_ASSERT(target.regX64 != noreg); + // Register might have already been freed if it had multiple uses inside a single instruction + if (target.regX64 == noreg) + return; freeReg(target.regX64); target.regX64 = noreg; diff --git a/CodeGen/src/IrTranslation.cpp b/CodeGen/src/IrTranslation.cpp index 32958b56e..fdbdf6670 100644 --- a/CodeGen/src/IrTranslation.cpp +++ b/CodeGen/src/IrTranslation.cpp @@ -3,6 +3,9 @@ #include "Luau/Bytecode.h" #include "Luau/IrBuilder.h" +#include "Luau/IrUtils.h" + +#include "CustomExecUtils.h" #include "lobject.h" #include "ltm.h" @@ -215,7 +218,7 @@ void translateInstJumpxEqB(IrBuilder& build, const Instruction* pc, int pcpos) build.beginBlock(checkValue); IrOp va = build.inst(IrCmd::LOAD_INT, build.vmReg(ra)); - build.inst(IrCmd::JUMP_EQ_BOOLEAN, va, build.constBool(aux & 0x1), not_ ? next : target, not_ ? target : next); + build.inst(IrCmd::JUMP_EQ_INT, va, build.constInt(aux & 0x1), not_ ? next : target, not_ ? target : next); // Fallthrough in original bytecode is implicit, so we start next internal block here if (build.isInternalBlock(next)) @@ -238,7 +241,12 @@ void translateInstJumpxEqN(IrBuilder& build, const Instruction* pc, int pcpos) build.beginBlock(checkValue); IrOp va = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra)); - IrOp vb = build.inst(IrCmd::LOAD_DOUBLE, build.vmConst(aux & 0xffffff)); + + LUAU_ASSERT(build.function.proto); + TValue protok = build.function.proto->k[aux & 0xffffff]; + + LUAU_ASSERT(protok.tt == LUA_TNUMBER); + IrOp vb = build.constDouble(protok.value.n); build.inst(IrCmd::JUMP_CMP_NUM, va, vb, build.cond(IrCondition::NotEqual), not_ ? target : next, not_ ? next : target); @@ -286,7 +294,20 @@ static void translateInstBinaryNumeric(IrBuilder& build, int ra, int rb, int rc, } IrOp vb = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(rb)); - IrOp vc = build.inst(IrCmd::LOAD_DOUBLE, opc); + IrOp vc; + + if (opc.kind == IrOpKind::VmConst) + { + LUAU_ASSERT(build.function.proto); + TValue protok = build.function.proto->k[opc.index]; + + LUAU_ASSERT(protok.tt == LUA_TNUMBER); + vc = build.constDouble(protok.value.n); + } + else + { + vc = build.inst(IrCmd::LOAD_DOUBLE, opc); + } IrOp va; @@ -458,6 +479,209 @@ void translateInstCloseUpvals(IrBuilder& build, const Instruction* pc) build.inst(IrCmd::CLOSE_UPVALS, build.vmReg(ra)); } +void translateInstForNPrep(IrBuilder& build, const Instruction* pc, int pcpos) +{ + int ra = LUAU_INSN_A(*pc); + + IrOp loopStart = build.blockAtInst(pcpos + getOpLength(LuauOpcode(LUAU_INSN_OP(*pc)))); + IrOp loopExit = build.blockAtInst(getJumpTarget(*pc, pcpos)); + IrOp fallback = build.block(IrBlockKind::Fallback); + + IrOp nextStep = build.block(IrBlockKind::Internal); + IrOp direct = build.block(IrBlockKind::Internal); + IrOp reverse = build.block(IrBlockKind::Internal); + + IrOp tagLimit = build.inst(IrCmd::LOAD_TAG, build.vmReg(ra + 0)); + build.inst(IrCmd::CHECK_TAG, tagLimit, build.constTag(LUA_TNUMBER), fallback); + IrOp tagStep = build.inst(IrCmd::LOAD_TAG, build.vmReg(ra + 1)); + build.inst(IrCmd::CHECK_TAG, tagStep, build.constTag(LUA_TNUMBER), fallback); + IrOp tagIdx = build.inst(IrCmd::LOAD_TAG, build.vmReg(ra + 2)); + build.inst(IrCmd::CHECK_TAG, tagIdx, build.constTag(LUA_TNUMBER), fallback); + build.inst(IrCmd::JUMP, nextStep); + + // After successful conversion of arguments to number in a fallback, we return here + build.beginBlock(nextStep); + + IrOp zero = build.constDouble(0.0); + IrOp limit = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 0)); + IrOp step = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 1)); + IrOp idx = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 2)); + + // step <= 0 + build.inst(IrCmd::JUMP_CMP_NUM, step, zero, build.cond(IrCondition::LessEqual), reverse, direct); + + // TODO: target branches can probably be arranged better, but we need tests for NaN behavior preservation + + // step <= 0 is false, check idx <= limit + build.beginBlock(direct); + build.inst(IrCmd::JUMP_CMP_NUM, idx, limit, build.cond(IrCondition::LessEqual), loopStart, loopExit); + + // step <= 0 is true, check limit <= idx + build.beginBlock(reverse); + build.inst(IrCmd::JUMP_CMP_NUM, limit, idx, build.cond(IrCondition::LessEqual), loopStart, loopExit); + + // Fallback will try to convert loop variables to numbers or throw an error + build.beginBlock(fallback); + build.inst(IrCmd::SET_SAVEDPC, build.constUint(pcpos + 1)); + build.inst(IrCmd::PREPARE_FORN, build.vmReg(ra + 0), build.vmReg(ra + 1), build.vmReg(ra + 2)); + build.inst(IrCmd::JUMP, nextStep); + + // Fallthrough in original bytecode is implicit, so we start next internal block here + if (build.isInternalBlock(loopStart)) + build.beginBlock(loopStart); +} + +void translateInstForNLoop(IrBuilder& build, const Instruction* pc, int pcpos) +{ + int ra = LUAU_INSN_A(*pc); + + IrOp loopRepeat = build.blockAtInst(getJumpTarget(*pc, pcpos)); + IrOp loopExit = build.blockAtInst(pcpos + getOpLength(LuauOpcode(LUAU_INSN_OP(*pc)))); + + build.inst(IrCmd::INTERRUPT, build.constUint(pcpos)); + + IrOp zero = build.constDouble(0.0); + IrOp limit = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 0)); + IrOp step = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 1)); + + IrOp idx = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 2)); + idx = build.inst(IrCmd::ADD_NUM, idx, step); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra + 2), idx); + + IrOp direct = build.block(IrBlockKind::Internal); + IrOp reverse = build.block(IrBlockKind::Internal); + + // step <= 0 + build.inst(IrCmd::JUMP_CMP_NUM, step, zero, build.cond(IrCondition::LessEqual), reverse, direct); + + // step <= 0 is false, check idx <= limit + build.beginBlock(direct); + build.inst(IrCmd::JUMP_CMP_NUM, idx, limit, build.cond(IrCondition::LessEqual), loopRepeat, loopExit); + + // step <= 0 is true, check limit <= idx + build.beginBlock(reverse); + build.inst(IrCmd::JUMP_CMP_NUM, limit, idx, build.cond(IrCondition::LessEqual), loopRepeat, loopExit); + + // Fallthrough in original bytecode is implicit, so we start next internal block here + if (build.isInternalBlock(loopExit)) + build.beginBlock(loopExit); +} + +void translateInstForGPrepNext(IrBuilder& build, const Instruction* pc, int pcpos) +{ + int ra = LUAU_INSN_A(*pc); + + IrOp target = build.blockAtInst(pcpos + 1 + LUAU_INSN_D(*pc)); + IrOp fallback = build.block(IrBlockKind::Fallback); + + // fast-path: pairs/next + build.inst(IrCmd::CHECK_SAFE_ENV, fallback); + IrOp tagB = build.inst(IrCmd::LOAD_TAG, build.vmReg(ra + 1)); + build.inst(IrCmd::CHECK_TAG, tagB, build.constTag(LUA_TTABLE), fallback); + IrOp tagC = build.inst(IrCmd::LOAD_TAG, build.vmReg(ra + 2)); + build.inst(IrCmd::CHECK_TAG, tagC, build.constTag(LUA_TNIL), fallback); + + build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNIL)); + + // setpvalue(ra + 2, reinterpret_cast(uintptr_t(0))); + build.inst(IrCmd::STORE_INT, build.vmReg(ra + 2), build.constInt(0)); + build.inst(IrCmd::STORE_TAG, build.vmReg(ra + 2), build.constTag(LUA_TLIGHTUSERDATA)); + + build.inst(IrCmd::JUMP, target); + + // FallbackStreamScope not used here because this instruction doesn't fallthrough to next instruction + build.beginBlock(fallback); + build.inst(IrCmd::LOP_FORGPREP_XNEXT_FALLBACK, build.constUint(pcpos), target); +} + +void translateInstForGPrepInext(IrBuilder& build, const Instruction* pc, int pcpos) +{ + int ra = LUAU_INSN_A(*pc); + + IrOp target = build.blockAtInst(pcpos + 1 + LUAU_INSN_D(*pc)); + IrOp fallback = build.block(IrBlockKind::Fallback); + IrOp finish = build.block(IrBlockKind::Internal); + + // fast-path: ipairs/inext + build.inst(IrCmd::CHECK_SAFE_ENV, fallback); + IrOp tagB = build.inst(IrCmd::LOAD_TAG, build.vmReg(ra + 1)); + build.inst(IrCmd::CHECK_TAG, tagB, build.constTag(LUA_TTABLE), fallback); + IrOp tagC = build.inst(IrCmd::LOAD_TAG, build.vmReg(ra + 2)); + build.inst(IrCmd::CHECK_TAG, tagC, build.constTag(LUA_TNUMBER), fallback); + + IrOp numC = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 2)); + build.inst(IrCmd::JUMP_CMP_NUM, numC, build.constDouble(0.0), build.cond(IrCondition::NotEqual), fallback, finish); + + build.beginBlock(finish); + + build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNIL)); + + // setpvalue(ra + 2, reinterpret_cast(uintptr_t(0))); + build.inst(IrCmd::STORE_INT, build.vmReg(ra + 2), build.constInt(0)); + build.inst(IrCmd::STORE_TAG, build.vmReg(ra + 2), build.constTag(LUA_TLIGHTUSERDATA)); + + build.inst(IrCmd::JUMP, target); + + // FallbackStreamScope not used here because this instruction doesn't fallthrough to next instruction + build.beginBlock(fallback); + build.inst(IrCmd::LOP_FORGPREP_XNEXT_FALLBACK, build.constUint(pcpos), target); +} + +void translateInstForGLoopIpairs(IrBuilder& build, const Instruction* pc, int pcpos) +{ + int ra = LUAU_INSN_A(*pc); + LUAU_ASSERT(int(pc[1]) < 0); + + IrOp loopRepeat = build.blockAtInst(getJumpTarget(*pc, pcpos)); + IrOp loopExit = build.blockAtInst(pcpos + getOpLength(LuauOpcode(LUAU_INSN_OP(*pc)))); + IrOp fallback = build.block(IrBlockKind::Fallback); + + IrOp hasElem = build.block(IrBlockKind::Internal); + + build.inst(IrCmd::INTERRUPT, build.constUint(pcpos)); + + // fast-path: builtin table iteration + IrOp tagA = build.inst(IrCmd::LOAD_TAG, build.vmReg(ra)); + build.inst(IrCmd::CHECK_TAG, tagA, build.constTag(LUA_TNIL), fallback); + + IrOp table = build.inst(IrCmd::LOAD_POINTER, build.vmReg(ra + 1)); + IrOp index = build.inst(IrCmd::LOAD_INT, build.vmReg(ra + 2)); + + IrOp elemPtr = build.inst(IrCmd::GET_ARR_ADDR, table, index); + + // Terminate if array has ended + build.inst(IrCmd::CHECK_ARRAY_SIZE, table, index, loopExit); + + // Terminate if element is nil + IrOp elemTag = build.inst(IrCmd::LOAD_TAG, elemPtr); + build.inst(IrCmd::JUMP_EQ_TAG, elemTag, build.constTag(LUA_TNIL), loopExit, hasElem); + build.beginBlock(hasElem); + + IrOp nextIndex = build.inst(IrCmd::ADD_INT, index, build.constInt(1)); + + // We update only a dword part of the userdata pointer that's reused in loop iteration as an index + // Upper bits start and remain to be 0 + build.inst(IrCmd::STORE_INT, build.vmReg(ra + 2), nextIndex); + // Tag should already be set to lightuserdata + + // setnvalue(ra + 3, double(index + 1)); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra + 3), build.inst(IrCmd::INT_TO_NUM, nextIndex)); + build.inst(IrCmd::STORE_TAG, build.vmReg(ra + 3), build.constTag(LUA_TNUMBER)); + + // setobj2s(L, ra + 4, e); + IrOp elemTV = build.inst(IrCmd::LOAD_TVALUE, elemPtr); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(ra + 4), elemTV); + + build.inst(IrCmd::JUMP, loopRepeat); + + build.beginBlock(fallback); + build.inst(IrCmd::LOP_FORGLOOP_FALLBACK, build.constUint(pcpos), loopRepeat, loopExit); + + // Fallthrough in original bytecode is implicit, so we start next internal block here + if (build.isInternalBlock(loopExit)) + build.beginBlock(loopExit); +} + void translateInstGetTableN(IrBuilder& build, const Instruction* pc, int pcpos) { int ra = LUAU_INSN_A(*pc); @@ -654,7 +878,7 @@ void translateInstGetTableKS(IrBuilder& build, const Instruction* pc, int pcpos) IrOp next = build.blockAtInst(pcpos + 2); FallbackStreamScope scope(build, fallback, next); - build.inst(IrCmd::FALLBACK_GETTABLEKS, build.constUint(pcpos)); + build.inst(IrCmd::FALLBACK_GETTABLEKS, build.constUint(pcpos), build.vmReg(ra), build.vmReg(rb), build.vmConst(aux)); build.inst(IrCmd::JUMP, next); } @@ -685,7 +909,7 @@ void translateInstSetTableKS(IrBuilder& build, const Instruction* pc, int pcpos) IrOp next = build.blockAtInst(pcpos + 2); FallbackStreamScope scope(build, fallback, next); - build.inst(IrCmd::FALLBACK_SETTABLEKS, build.constUint(pcpos)); + build.inst(IrCmd::FALLBACK_SETTABLEKS, build.constUint(pcpos), build.vmReg(ra), build.vmReg(rb), build.vmConst(aux)); build.inst(IrCmd::JUMP, next); } @@ -708,7 +932,7 @@ void translateInstGetGlobal(IrBuilder& build, const Instruction* pc, int pcpos) IrOp next = build.blockAtInst(pcpos + 2); FallbackStreamScope scope(build, fallback, next); - build.inst(IrCmd::FALLBACK_GETGLOBAL, build.constUint(pcpos)); + build.inst(IrCmd::FALLBACK_GETGLOBAL, build.constUint(pcpos), build.vmReg(ra), build.vmConst(aux)); build.inst(IrCmd::JUMP, next); } @@ -734,7 +958,7 @@ void translateInstSetGlobal(IrBuilder& build, const Instruction* pc, int pcpos) IrOp next = build.blockAtInst(pcpos + 2); FallbackStreamScope scope(build, fallback, next); - build.inst(IrCmd::FALLBACK_SETGLOBAL, build.constUint(pcpos)); + build.inst(IrCmd::FALLBACK_SETGLOBAL, build.constUint(pcpos), build.vmReg(ra), build.vmConst(aux)); build.inst(IrCmd::JUMP, next); } diff --git a/CodeGen/src/IrTranslation.h b/CodeGen/src/IrTranslation.h index 53030a203..6ffc911d2 100644 --- a/CodeGen/src/IrTranslation.h +++ b/CodeGen/src/IrTranslation.h @@ -42,6 +42,11 @@ void translateInstDupTable(IrBuilder& build, const Instruction* pc, int pcpos); void translateInstGetUpval(IrBuilder& build, const Instruction* pc, int pcpos); void translateInstSetUpval(IrBuilder& build, const Instruction* pc, int pcpos); void translateInstCloseUpvals(IrBuilder& build, const Instruction* pc); +void translateInstForNPrep(IrBuilder& build, const Instruction* pc, int pcpos); +void translateInstForNLoop(IrBuilder& build, const Instruction* pc, int pcpos); +void translateInstForGPrepNext(IrBuilder& build, const Instruction* pc, int pcpos); +void translateInstForGPrepInext(IrBuilder& build, const Instruction* pc, int pcpos); +void translateInstForGLoopIpairs(IrBuilder& build, const Instruction* pc, int pcpos); void translateInstGetTableN(IrBuilder& build, const Instruction* pc, int pcpos); void translateInstSetTableN(IrBuilder& build, const Instruction* pc, int pcpos); void translateInstGetTable(IrBuilder& build, const Instruction* pc, int pcpos); diff --git a/CodeGen/src/IrUtils.cpp b/CodeGen/src/IrUtils.cpp new file mode 100644 index 000000000..0c1a89668 --- /dev/null +++ b/CodeGen/src/IrUtils.cpp @@ -0,0 +1,133 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/IrUtils.h" + +namespace Luau +{ +namespace CodeGen +{ + +static uint32_t getBlockEnd(IrFunction& function, uint32_t start) +{ + uint32_t end = start; + + // Find previous block terminator + while (!isBlockTerminator(function.instructions[end].cmd)) + end++; + + return end; +} + +static void addUse(IrFunction& function, IrOp op) +{ + if (op.kind == IrOpKind::Inst) + function.instructions[op.index].useCount++; + else if (op.kind == IrOpKind::Block) + function.blocks[op.index].useCount++; +} + +static void removeUse(IrFunction& function, IrOp op) +{ + if (op.kind == IrOpKind::Inst) + removeUse(function, function.instructions[op.index]); + else if (op.kind == IrOpKind::Block) + removeUse(function, function.blocks[op.index]); +} + +void kill(IrFunction& function, IrInst& inst) +{ + LUAU_ASSERT(inst.useCount == 0); + + inst.cmd = IrCmd::NOP; + + removeUse(function, inst.a); + removeUse(function, inst.b); + removeUse(function, inst.c); + removeUse(function, inst.d); + removeUse(function, inst.e); +} + +void kill(IrFunction& function, uint32_t start, uint32_t end) +{ + // Kill instructions in reverse order to avoid killing instructions that are still marked as used + for (int i = int(end); i >= int(start); i--) + { + IrInst& curr = function.instructions[i]; + + if (curr.cmd == IrCmd::NOP) + continue; + + kill(function, curr); + } +} + +void kill(IrFunction& function, IrBlock& block) +{ + LUAU_ASSERT(block.useCount == 0); + + block.kind = IrBlockKind::Dead; + + uint32_t start = block.start; + uint32_t end = getBlockEnd(function, start); + + kill(function, start, end); +} + +void removeUse(IrFunction& function, IrInst& inst) +{ + LUAU_ASSERT(inst.useCount); + inst.useCount--; + + if (inst.useCount == 0) + kill(function, inst); +} + +void removeUse(IrFunction& function, IrBlock& block) +{ + LUAU_ASSERT(block.useCount); + block.useCount--; + + if (block.useCount == 0) + kill(function, block); +} + +void replace(IrFunction& function, IrOp& original, IrOp replacement) +{ + // Add use before removing new one if that's the last one keeping target operand alive + addUse(function, replacement); + removeUse(function, original); + + original = replacement; +} + +void replace(IrFunction& function, uint32_t instIdx, IrInst replacement) +{ + IrInst& inst = function.instructions[instIdx]; + IrCmd prevCmd = inst.cmd; + + // Add uses before removing new ones if those are the last ones keeping target operand alive + addUse(function, replacement.a); + addUse(function, replacement.b); + addUse(function, replacement.c); + addUse(function, replacement.d); + addUse(function, replacement.e); + + removeUse(function, inst.a); + removeUse(function, inst.b); + removeUse(function, inst.c); + removeUse(function, inst.d); + removeUse(function, inst.e); + + inst = replacement; + + // If we introduced an earlier terminating instruction, all following instructions become dead + if (!isBlockTerminator(prevCmd) && isBlockTerminator(inst.cmd)) + { + uint32_t start = instIdx + 1; + uint32_t end = getBlockEnd(function, start); + + kill(function, start, end); + } +} + +} // namespace CodeGen +} // namespace Luau diff --git a/CodeGen/src/OptimizeFinalX64.cpp b/CodeGen/src/OptimizeFinalX64.cpp new file mode 100644 index 000000000..57f9a5c42 --- /dev/null +++ b/CodeGen/src/OptimizeFinalX64.cpp @@ -0,0 +1,111 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/OptimizeFinalX64.h" + +#include "Luau/IrUtils.h" + +#include + +namespace Luau +{ +namespace CodeGen +{ + +// x64 assembly allows memory operands, but IR separates loads from uses +// To improve final x64 lowering, we try to 'inline' single-use register/constant loads into some of our instructions +// This pass might not be useful on different architectures +static void optimizeMemoryOperandsX64(IrFunction& function, IrBlock& block) +{ + LUAU_ASSERT(block.kind != IrBlockKind::Dead); + + for (uint32_t index = block.start; true; index++) + { + LUAU_ASSERT(index < function.instructions.size()); + IrInst& inst = function.instructions[index]; + + switch (inst.cmd) + { + case IrCmd::CHECK_TAG: + { + if (inst.a.kind == IrOpKind::Inst) + { + IrInst& tag = function.instOp(inst.a); + + if (tag.useCount == 1 && tag.cmd == IrCmd::LOAD_TAG && (tag.a.kind == IrOpKind::VmReg || tag.a.kind == IrOpKind::VmConst)) + replace(function, inst.a, tag.a); + } + break; + } + case IrCmd::ADD_NUM: + case IrCmd::SUB_NUM: + case IrCmd::MUL_NUM: + case IrCmd::DIV_NUM: + case IrCmd::MOD_NUM: + case IrCmd::POW_NUM: + { + if (inst.b.kind == IrOpKind::Inst) + { + IrInst& rhs = function.instOp(inst.b); + + if (rhs.useCount == 1 && rhs.cmd == IrCmd::LOAD_DOUBLE && (rhs.a.kind == IrOpKind::VmReg || rhs.a.kind == IrOpKind::VmConst)) + replace(function, inst.b, rhs.a); + } + break; + } + case IrCmd::JUMP_EQ_TAG: + { + if (inst.a.kind == IrOpKind::Inst) + { + IrInst& tagA = function.instOp(inst.a); + + if (tagA.useCount == 1 && tagA.cmd == IrCmd::LOAD_TAG && (tagA.a.kind == IrOpKind::VmReg || tagA.a.kind == IrOpKind::VmConst)) + { + replace(function, inst.a, tagA.a); + break; + } + } + + if (inst.b.kind == IrOpKind::Inst) + { + IrInst& tagB = function.instOp(inst.b); + + if (tagB.useCount == 1 && tagB.cmd == IrCmd::LOAD_TAG && (tagB.a.kind == IrOpKind::VmReg || tagB.a.kind == IrOpKind::VmConst)) + { + std::swap(inst.a, inst.b); + replace(function, inst.a, tagB.a); + } + } + break; + } + case IrCmd::JUMP_CMP_NUM: + { + if (inst.a.kind == IrOpKind::Inst) + { + IrInst& num = function.instOp(inst.a); + + if (num.useCount == 1 && num.cmd == IrCmd::LOAD_DOUBLE) + replace(function, inst.a, num.a); + } + break; + } + default: + break; + } + + if (isBlockTerminator(inst.cmd)) + break; + } +} + +void optimizeMemoryOperandsX64(IrFunction& function) +{ + for (IrBlock& block : function.blocks) + { + if (block.kind == IrBlockKind::Dead) + continue; + + optimizeMemoryOperandsX64(function, block); + } +} + +} // namespace CodeGen +} // namespace Luau diff --git a/Common/include/Luau/Bytecode.h b/Common/include/Luau/Bytecode.h index 32e5ba9bb..ff9a5da61 100644 --- a/Common/include/Luau/Bytecode.h +++ b/Common/include/Luau/Bytecode.h @@ -28,7 +28,7 @@ // Upvalues: 0-254. Upvalues refer to the values stored in the closure object. // Constants: 0-2^23-1. Constants are stored in a table allocated with each proto; to allow for future bytecode tweaks the encodable value is limited to 23 bits. // Closures: 0-2^15-1. Closures are created from child protos via a child index; the limit is for the number of closures immediately referenced in each function. -// Jumps: -2^23..2^23. Jump offsets are specified in word increments, so jumping over an instruction may sometimes require an offset of 2 or more. +// Jumps: -2^23..2^23. Jump offsets are specified in word increments, so jumping over an instruction may sometimes require an offset of 2 or more. Note that for jump instructions with AUX, the AUX word is included as part of the jump offset. // # Bytecode versions // Bytecode serialized format embeds a version number, that dictates both the serialized form as well as the allowed instructions. As long as the bytecode version falls into supported @@ -194,7 +194,7 @@ enum LuauOpcode // JUMPIFEQ, JUMPIFLE, JUMPIFLT, JUMPIFNOTEQ, JUMPIFNOTLE, JUMPIFNOTLT: jumps to target offset if the comparison is true (or false, for NOT variants) // A: source register 1 - // D: jump offset (-32768..32767; 0 means "next instruction" aka "don't jump") + // D: jump offset (-32768..32767; 1 means "next instruction" aka "don't jump") // AUX: source register 2 LOP_JUMPIFEQ, LOP_JUMPIFLE, @@ -376,14 +376,14 @@ enum LuauOpcode // JUMPXEQKNIL, JUMPXEQKB: jumps to target offset if the comparison with constant is true (or false, see AUX) // A: source register 1 - // D: jump offset (-32768..32767; 0 means "next instruction" aka "don't jump") + // D: jump offset (-32768..32767; 1 means "next instruction" aka "don't jump") // AUX: constant value (for boolean) in low bit, NOT flag (that flips comparison result) in high bit LOP_JUMPXEQKNIL, LOP_JUMPXEQKB, // JUMPXEQKN, JUMPXEQKS: jumps to target offset if the comparison with constant is true (or false, see AUX) // A: source register 1 - // D: jump offset (-32768..32767; 0 means "next instruction" aka "don't jump") + // D: jump offset (-32768..32767; 1 means "next instruction" aka "don't jump") // AUX: constant table index in low 24 bits, NOT flag (that flips comparison result) in high bit LOP_JUMPXEQKN, LOP_JUMPXEQKS, diff --git a/Common/include/Luau/ExperimentalFlags.h b/Common/include/Luau/ExperimentalFlags.h index 94dce41ac..a14cc1e65 100644 --- a/Common/include/Luau/ExperimentalFlags.h +++ b/Common/include/Luau/ExperimentalFlags.h @@ -13,6 +13,7 @@ inline bool isFlagExperimental(const char* flag) static const char* kList[] = { "LuauInstantiateInSubtyping", // requires some fixes to lua-apps code "LuauTryhardAnd", // waiting for a fix in graphql-lua -> apollo-client-lia -> lua-apps + "LuauTypecheckTypeguards", // requires some fixes to lua-apps code (CLI-67030) // makes sure we always have at least one entry nullptr, }; diff --git a/Sources.cmake b/Sources.cmake index 815301bc8..aef55e6b2 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -70,6 +70,7 @@ target_sources(Luau.CodeGen PRIVATE CodeGen/include/Luau/IrUtils.h CodeGen/include/Luau/Label.h CodeGen/include/Luau/OperandX64.h + CodeGen/include/Luau/OptimizeFinalX64.h CodeGen/include/Luau/RegisterA64.h CodeGen/include/Luau/RegisterX64.h CodeGen/include/Luau/UnwindBuilder.h @@ -92,7 +93,9 @@ target_sources(Luau.CodeGen PRIVATE CodeGen/src/IrDump.cpp CodeGen/src/IrLoweringX64.cpp CodeGen/src/IrTranslation.cpp + CodeGen/src/IrUtils.cpp CodeGen/src/NativeState.cpp + CodeGen/src/OptimizeFinalX64.cpp CodeGen/src/UnwindBuilderDwarf2.cpp CodeGen/src/UnwindBuilderWin.cpp @@ -337,6 +340,7 @@ if(TARGET Luau.UnitTest) tests/DenseHash.test.cpp tests/Error.test.cpp tests/Frontend.test.cpp + tests/IrBuilder.test.cpp tests/JsonEmitter.test.cpp tests/Lexer.test.cpp tests/Linter.test.cpp diff --git a/tests/Frontend.test.cpp b/tests/Frontend.test.cpp index a69965e04..1d31b2813 100644 --- a/tests/Frontend.test.cpp +++ b/tests/Frontend.test.cpp @@ -10,8 +10,6 @@ #include -LUAU_FASTFLAG(LuauScopelessModule) - using namespace Luau; namespace @@ -145,8 +143,6 @@ TEST_CASE_FIXTURE(FrontendFixture, "real_source") TEST_CASE_FIXTURE(FrontendFixture, "automatically_check_dependent_scripts") { - ScopedFastFlag luauScopelessModule{"LuauScopelessModule", true}; - fileResolver.source["game/Gui/Modules/A"] = "return {hello=5, world=true}"; fileResolver.source["game/Gui/Modules/B"] = R"( local Modules = game:GetService('Gui').Modules @@ -224,8 +220,6 @@ TEST_CASE_FIXTURE(FrontendFixture, "any_annotation_breaks_cycle") TEST_CASE_FIXTURE(FrontendFixture, "nocheck_modules_are_typed") { - ScopedFastFlag luauScopelessModule{"LuauScopelessModule", true}; - fileResolver.source["game/Gui/Modules/A"] = R"( --!nocheck export type Foo = number @@ -281,8 +275,6 @@ TEST_CASE_FIXTURE(FrontendFixture, "cycle_detection_between_check_and_nocheck") TEST_CASE_FIXTURE(FrontendFixture, "nocheck_cycle_used_by_checked") { - ScopedFastFlag luauScopelessModule{"LuauScopelessModule", true}; - fileResolver.source["game/Gui/Modules/A"] = R"( --!nocheck local Modules = game:GetService('Gui').Modules @@ -501,8 +493,6 @@ TEST_CASE_FIXTURE(FrontendFixture, "dont_recheck_script_that_hasnt_been_marked_d TEST_CASE_FIXTURE(FrontendFixture, "recheck_if_dependent_script_is_dirty") { - ScopedFastFlag luauScopelessModule{"LuauScopelessModule", true}; - fileResolver.source["game/Gui/Modules/A"] = "return {hello=5, world=true}"; fileResolver.source["game/Gui/Modules/B"] = R"( local Modules = game:GetService('Gui').Modules diff --git a/tests/IrBuilder.test.cpp b/tests/IrBuilder.test.cpp new file mode 100644 index 000000000..4ed872862 --- /dev/null +++ b/tests/IrBuilder.test.cpp @@ -0,0 +1,223 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/IrBuilder.h" +#include "Luau/IrAnalysis.h" +#include "Luau/IrDump.h" +#include "Luau/OptimizeFinalX64.h" + +#include "doctest.h" + +using namespace Luau::CodeGen; + +class IrBuilderFixture +{ +public: + IrBuilder build; +}; + +TEST_SUITE_BEGIN("Optimization"); + +TEST_CASE_FIXTURE(IrBuilderFixture, "FinalX64OptCheckTag") +{ + IrOp block = build.block(IrBlockKind::Internal); + IrOp fallback = build.block(IrBlockKind::Fallback); + + build.beginBlock(block); + IrOp tag1 = build.inst(IrCmd::LOAD_TAG, build.vmReg(2)); + build.inst(IrCmd::CHECK_TAG, tag1, build.constTag(0), fallback); + IrOp tag2 = build.inst(IrCmd::LOAD_TAG, build.vmConst(5)); + build.inst(IrCmd::CHECK_TAG, tag2, build.constTag(0), fallback); + build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + + build.beginBlock(fallback); + build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + + updateUseCounts(build.function); + optimizeMemoryOperandsX64(build.function); + + // Load from memory is 'inlined' into CHECK_TAG + CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( +bb_0: + CHECK_TAG R2, tnil, bb_fallback_1 + CHECK_TAG K5, tnil, bb_fallback_1 + LOP_RETURN 0u + +bb_fallback_1: + LOP_RETURN 0u + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "FinalX64OptBinaryArith") +{ + IrOp block = build.block(IrBlockKind::Internal); + + build.beginBlock(block); + IrOp opA = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(1)); + IrOp opB = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(2)); + build.inst(IrCmd::ADD_NUM, opA, opB); + build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + + updateUseCounts(build.function); + optimizeMemoryOperandsX64(build.function); + + // Load from memory is 'inlined' into second argument + CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( +bb_0: + %0 = LOAD_DOUBLE R1 + %2 = ADD_NUM %0, R2 + LOP_RETURN 0u + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "FinalX64OptEqTag1") +{ + IrOp block = build.block(IrBlockKind::Internal); + IrOp trueBlock = build.block(IrBlockKind::Internal); + IrOp falseBlock = build.block(IrBlockKind::Internal); + + build.beginBlock(block); + IrOp opA = build.inst(IrCmd::LOAD_TAG, build.vmReg(1)); + IrOp opB = build.inst(IrCmd::LOAD_TAG, build.vmReg(2)); + build.inst(IrCmd::JUMP_EQ_TAG, opA, opB, trueBlock, falseBlock); + build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + + build.beginBlock(trueBlock); + build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + + build.beginBlock(falseBlock); + build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + + updateUseCounts(build.function); + optimizeMemoryOperandsX64(build.function); + + // Load from memory is 'inlined' into first argument + CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( +bb_0: + %1 = LOAD_TAG R2 + JUMP_EQ_TAG R1, %1, bb_1, bb_2 + +bb_1: + LOP_RETURN 0u + +bb_2: + LOP_RETURN 0u + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "FinalX64OptEqTag2") +{ + IrOp block = build.block(IrBlockKind::Internal); + IrOp trueBlock = build.block(IrBlockKind::Internal); + IrOp falseBlock = build.block(IrBlockKind::Internal); + + build.beginBlock(block); + IrOp opA = build.inst(IrCmd::LOAD_TAG, build.vmReg(1)); + IrOp opB = build.inst(IrCmd::LOAD_TAG, build.vmReg(2)); + build.inst(IrCmd::STORE_TAG, build.vmReg(6), opA); + build.inst(IrCmd::JUMP_EQ_TAG, opA, opB, trueBlock, falseBlock); + build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + + build.beginBlock(trueBlock); + build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + + build.beginBlock(falseBlock); + build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + + updateUseCounts(build.function); + optimizeMemoryOperandsX64(build.function); + + // Load from memory is 'inlined' into second argument is it can't be done for the first one + // We also swap first and second argument to generate memory access on the LHS + CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( +bb_0: + %0 = LOAD_TAG R1 + STORE_TAG R6, %0 + JUMP_EQ_TAG R2, %0, bb_1, bb_2 + +bb_1: + LOP_RETURN 0u + +bb_2: + LOP_RETURN 0u + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "FinalX64OptEqTag3") +{ + IrOp block = build.block(IrBlockKind::Internal); + IrOp trueBlock = build.block(IrBlockKind::Internal); + IrOp falseBlock = build.block(IrBlockKind::Internal); + + build.beginBlock(block); + IrOp table = build.inst(IrCmd::LOAD_POINTER, build.vmReg(1)); + IrOp arrElem = build.inst(IrCmd::GET_ARR_ADDR, table, build.constUint(0)); + IrOp opA = build.inst(IrCmd::LOAD_TAG, arrElem); + build.inst(IrCmd::JUMP_EQ_TAG, opA, build.constTag(0), trueBlock, falseBlock); + build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + + build.beginBlock(trueBlock); + build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + + build.beginBlock(falseBlock); + build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + + updateUseCounts(build.function); + optimizeMemoryOperandsX64(build.function); + + // Load from memory is 'inlined' into first argument + CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( +bb_0: + %0 = LOAD_POINTER R1 + %1 = GET_ARR_ADDR %0, 0u + %2 = LOAD_TAG %1 + JUMP_EQ_TAG %2, tnil, bb_1, bb_2 + +bb_1: + LOP_RETURN 0u + +bb_2: + LOP_RETURN 0u + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "FinalX64OptJumpCmpNum") +{ + IrOp block = build.block(IrBlockKind::Internal); + IrOp trueBlock = build.block(IrBlockKind::Internal); + IrOp falseBlock = build.block(IrBlockKind::Internal); + + build.beginBlock(block); + IrOp opA = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(1)); + IrOp opB = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(2)); + build.inst(IrCmd::JUMP_CMP_NUM, opA, opB, trueBlock, falseBlock); + build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + + build.beginBlock(trueBlock); + build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + + build.beginBlock(falseBlock); + build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + + updateUseCounts(build.function); + optimizeMemoryOperandsX64(build.function); + + // Load from memory is 'inlined' into first argument + CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( +bb_0: + %1 = LOAD_DOUBLE R2 + JUMP_CMP_NUM R1, %1, bb_1, bb_2 + +bb_1: + LOP_RETURN 0u + +bb_2: + LOP_RETURN 0u + +)"); +} + +TEST_SUITE_END(); diff --git a/tests/Module.test.cpp b/tests/Module.test.cpp index 34c2e8fd8..8557913a6 100644 --- a/tests/Module.test.cpp +++ b/tests/Module.test.cpp @@ -112,8 +112,6 @@ TEST_CASE_FIXTURE(Fixture, "deepClone_cyclic_table") TEST_CASE_FIXTURE(BuiltinsFixture, "builtin_types_point_into_globalTypes_arena") { - ScopedFastFlag luauScopelessModule{"LuauScopelessModule", true}; - CheckResult result = check(R"( return {sign=math.sign} )"); @@ -285,8 +283,6 @@ TEST_CASE_FIXTURE(Fixture, "clone_recursion_limit") TEST_CASE_FIXTURE(Fixture, "any_persistance_does_not_leak") { - ScopedFastFlag luauScopelessModule{"LuauScopelessModule", true}; - fileResolver.source["Module/A"] = R"( export type A = B type B = A @@ -310,7 +306,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "do_not_clone_reexports") {"LuauSubstitutionReentrant", true}, {"LuauClassTypeVarsInSubstitution", true}, {"LuauSubstitutionFixMissingFields", true}, - {"LuauScopelessModule", true}, }; fileResolver.source["Module/A"] = R"( @@ -349,7 +344,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "do_not_clone_types_of_reexported_values") {"LuauSubstitutionReentrant", true}, {"LuauClassTypeVarsInSubstitution", true}, {"LuauSubstitutionFixMissingFields", true}, - {"LuauScopelessModule", true}, }; fileResolver.source["Module/A"] = R"( diff --git a/tests/NonstrictMode.test.cpp b/tests/NonstrictMode.test.cpp index 5deeb35dc..28c5bba06 100644 --- a/tests/NonstrictMode.test.cpp +++ b/tests/NonstrictMode.test.cpp @@ -253,8 +253,6 @@ TEST_CASE_FIXTURE(Fixture, "delay_function_does_not_require_its_argument_to_retu TEST_CASE_FIXTURE(Fixture, "inconsistent_module_return_types_are_ok") { - ScopedFastFlag luauScopelessModule{"LuauScopelessModule", true}; - CheckResult result = check(R"( --!nonstrict diff --git a/tests/TypeInfer.annotations.test.cpp b/tests/TypeInfer.annotations.test.cpp index 3e98367cd..d5f953746 100644 --- a/tests/TypeInfer.annotations.test.cpp +++ b/tests/TypeInfer.annotations.test.cpp @@ -465,8 +465,6 @@ TEST_CASE_FIXTURE(Fixture, "type_alias_always_resolve_to_a_real_type") TEST_CASE_FIXTURE(Fixture, "interface_types_belong_to_interface_arena") { - ScopedFastFlag luauScopelessModule{"LuauScopelessModule", true}; - CheckResult result = check(R"( export type A = {field: number} @@ -498,8 +496,6 @@ TEST_CASE_FIXTURE(Fixture, "interface_types_belong_to_interface_arena") TEST_CASE_FIXTURE(Fixture, "generic_aliases_are_cloned_properly") { - ScopedFastFlag luauScopelessModule{"LuauScopelessModule", true}; - CheckResult result = check(R"( export type Array = { [number]: T } )"); @@ -527,8 +523,6 @@ TEST_CASE_FIXTURE(Fixture, "generic_aliases_are_cloned_properly") TEST_CASE_FIXTURE(Fixture, "cloned_interface_maintains_pointers_between_definitions") { - ScopedFastFlag luauScopelessModule{"LuauScopelessModule", true}; - CheckResult result = check(R"( export type Record = { name: string, location: string } local a: Record = { name="Waldo", location="?????" } diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index 71586f9f2..683469a82 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -109,8 +109,6 @@ TEST_CASE_FIXTURE(Fixture, "vararg_functions_should_allow_calls_of_any_types_and TEST_CASE_FIXTURE(BuiltinsFixture, "vararg_function_is_quantified") { - ScopedFastFlag luauScopelessModule{"LuauScopelessModule", true}; - CheckResult result = check(R"( local T = {} function T.f(...) diff --git a/tests/TypeInfer.generics.test.cpp b/tests/TypeInfer.generics.test.cpp index 3861a8b6c..d7b0bdb4e 100644 --- a/tests/TypeInfer.generics.test.cpp +++ b/tests/TypeInfer.generics.test.cpp @@ -282,8 +282,14 @@ TEST_CASE_FIXTURE(Fixture, "infer_generic_methods") function x:f(): string return self:id("hello") end function x:g(): number return self:id(37) end )"); - // TODO: Quantification should be doing the conversion, not normalization. - LUAU_REQUIRE_ERRORS(result); + + if (FFlag::DebugLuauDeferredConstraintResolution) + LUAU_REQUIRE_NO_ERRORS(result); + else + { + // TODO: Quantification should be doing the conversion, not normalization. + LUAU_REQUIRE_ERRORS(result); + } } TEST_CASE_FIXTURE(Fixture, "calling_self_generic_methods") @@ -296,8 +302,14 @@ TEST_CASE_FIXTURE(Fixture, "calling_self_generic_methods") local y: number = self:id(37) end )"); - // TODO: Should typecheck but currently errors CLI-39916 - LUAU_REQUIRE_ERRORS(result); + + if (FFlag::DebugLuauDeferredConstraintResolution) + LUAU_REQUIRE_NO_ERRORS(result); + else + { + // TODO: Should typecheck but currently errors CLI-39916 + LUAU_REQUIRE_ERRORS(result); + } } TEST_CASE_FIXTURE(Fixture, "infer_generic_property") diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index 7d629f715..c389f325f 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -514,8 +514,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_loop_with_zero_iterators") // Ideally, we would not try to export a function type with generic types from incorrect scope TEST_CASE_FIXTURE(BuiltinsFixture, "generic_type_leak_to_module_interface") { - ScopedFastFlag luauScopelessModule{"LuauScopelessModule", true}; - fileResolver.source["game/A"] = R"( local wrapStrictTable @@ -555,8 +553,6 @@ return wrapStrictTable(Constants, "Constants") TEST_CASE_FIXTURE(BuiltinsFixture, "generic_type_leak_to_module_interface_variadic") { - ScopedFastFlag luauScopelessModule{"LuauScopelessModule", true}; - fileResolver.source["game/A"] = R"( local wrapStrictTable diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index 43c0b38e8..fb44ec4d4 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -36,25 +36,26 @@ std::optional> magicFunctionInstanceIsA( return WithPredicate{booleanPack, {IsAPredicate{std::move(*lvalue), expr.location, tfun->type}}}; } -std::vector dcrMagicRefinementInstanceIsA(const MagicRefinementContext& ctx) +void dcrMagicRefinementInstanceIsA(const MagicRefinementContext& ctx) { - if (ctx.callSite->args.size != 1) - return {}; + if (ctx.callSite->args.size != 1 || ctx.discriminantTypes.empty()) + return; auto index = ctx.callSite->func->as(); auto str = ctx.callSite->args.data[0]->as(); if (!index || !str) - return {}; + return; - std::optional def = ctx.dfg->getDef(index->expr); - if (!def) - return {}; + std::optional discriminantTy = ctx.discriminantTypes[0]; + if (!discriminantTy) + return; std::optional tfun = ctx.scope->lookupType(std::string(str->value.data, str->value.size)); if (!tfun) - return {}; + return; - return {ctx.refinementArena->proposition(*def, tfun->type)}; + LUAU_ASSERT(get(*discriminantTy)); + asMutable(*discriminantTy)->ty.emplace(tfun->type); } struct RefinementClassFixture : BuiltinsFixture @@ -1491,4 +1492,45 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "refine_unknown_to_table_then_take_the_length CHECK_EQ("table", toString(requireTypeAtPosition({3, 29}))); } +TEST_CASE_FIXTURE(RefinementClassFixture, "refine_a_param_that_got_resolved_during_constraint_solving_stage") +{ + CheckResult result = check(R"( + type Id = T + + local function f(x: Id | Id>) + if typeof(x) ~= "string" and x:IsA("Part") then + local foo = x + else + local foo = x + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("Part", toString(requireTypeAtPosition({5, 28}))); + CHECK_EQ("Folder | string", toString(requireTypeAtPosition({7, 28}))); +} + +TEST_CASE_FIXTURE(RefinementClassFixture, "refine_a_param_that_got_resolved_during_constraint_solving_stage_2") +{ + CheckResult result = check(R"( + local function hof(f: (Instance) -> ()) end + + hof(function(inst) + if inst:IsA("Part") then + local foo = inst + else + local foo = inst + end + end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("Part", toString(requireTypeAtPosition({5, 28}))); + if (FFlag::DebugLuauDeferredConstraintResolution) + CHECK_EQ("Instance & ~Part", toString(requireTypeAtPosition({7, 28}))); + else + CHECK_EQ("Instance", toString(requireTypeAtPosition({7, 28}))); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index e3c1ab10c..fcebd1fed 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -18,6 +18,7 @@ LUAU_FASTFLAG(LuauLowerBoundsCalculation); LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); LUAU_FASTFLAG(LuauInstantiateInSubtyping) LUAU_FASTFLAG(LuauTypeMismatchInvarianceInError) +LUAU_FASTFLAG(LuauDontExtendUnsealedRValueTables) TEST_SUITE_BEGIN("TableTests"); @@ -628,7 +629,7 @@ TEST_CASE_FIXTURE(Fixture, "indexers_get_quantified_too") const TableIndexer& indexer = *ttv->indexer; - REQUIRE_EQ(indexer.indexType, typeChecker.numberType); + REQUIRE("number" == toString(indexer.indexType)); REQUIRE(nullptr != get(follow(indexer.indexResultType))); } @@ -869,6 +870,51 @@ TEST_CASE_FIXTURE(Fixture, "indexing_from_a_table_should_prefer_properties_when_ CHECK_MESSAGE(nullptr != get(result.errors[0]), "Expected a TypeMismatch but got " << result.errors[0]); } +TEST_CASE_FIXTURE(Fixture, "any_when_indexing_into_an_unsealed_table_with_no_indexer_in_nonstrict_mode") +{ + CheckResult result = check(R"( + --!nonstrict + + local constants = { + key1 = "value1", + key2 = "value2" + } + + local function getKey() + return "key1" + end + + local k1 = constants[getKey()] + )"); + + CHECK("any" == toString(requireType("k1"))); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "disallow_indexing_into_an_unsealed_table_with_no_indexer_in_strict_mode") +{ + CheckResult result = check(R"( + local constants = { + key1 = "value1", + key2 = "value2" + } + + function getConstant(key) + return constants[key] + end + + local k1 = getConstant("key1") + )"); + + if (FFlag::LuauDontExtendUnsealedRValueTables) + CHECK("any" == toString(requireType("k1"))); + else + CHECK("a" == toString(requireType("k1"))); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_CASE_FIXTURE(Fixture, "assigning_to_an_unsealed_table_with_string_literal_should_infer_new_properties_over_indexer") { CheckResult result = check(R"( @@ -2967,8 +3013,6 @@ TEST_CASE_FIXTURE(Fixture, "inferred_properties_of_a_table_should_start_with_the // The real bug here was that we weren't always uncondionally typechecking a trailing return statement last. TEST_CASE_FIXTURE(BuiltinsFixture, "dont_leak_free_table_props") { - ScopedFastFlag luauScopelessModule{"LuauScopelessModule", true}; - CheckResult result = check(R"( local function a(state) print(state.blah) @@ -3493,4 +3537,59 @@ _ = {_,} LUAU_REQUIRE_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "when_augmenting_an_unsealed_table_with_an_indexer_apply_the_correct_scope_to_the_indexer_type") +{ + ScopedFastFlag sff{"LuauDontExtendUnsealedRValueTables", true}; + + CheckResult result = check(R"( + local events = {} + local mockObserveEvent = function(_, key, callback) + events[key] = callback + end + + events['FriendshipNotifications']({ + EventArgs = { + UserId2 = '2' + }, + Type = 'FriendshipDeclined' + }) + )"); + + TypeId ty = follow(requireType("events")); + const TableType* tt = get(ty); + REQUIRE_MESSAGE(tt, "Expected table but got " << toString(ty, {true})); + + CHECK(tt->props.empty()); + REQUIRE(tt->indexer); + + CHECK("string" == toString(tt->indexer->indexType)); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "dont_extend_unsealed_tables_in_rvalue_position") +{ + ScopedFastFlag sff{"LuauDontExtendUnsealedRValueTables", true}; + + CheckResult result = check(R"( + local testDictionary = { + FruitName = "Lemon", + FruitColor = "Yellow", + Sour = true + } + + local print: any + + print(testDictionary[""]) + )"); + + TypeId ty = follow(requireType("testDictionary")); + const TableType* ttv = get(ty); + REQUIRE(ttv); + + CHECK(0 == ttv->props.count("")); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index 4f0afc35d..16797ee4d 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -1158,4 +1158,18 @@ end LUAU_REQUIRE_ERRORS(result); } +TEST_CASE_FIXTURE(BuiltinsFixture, "typechecking_in_type_guards") +{ + ScopedFastFlag sff{"LuauTypecheckTypeguards", true}; + + CheckResult result = check(R"( +local a = type(foo) == 'nil' +local b = typeof(foo) ~= 'nil' + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK(toString(result.errors[0]) == "Unknown global 'foo'"); + CHECK(toString(result.errors[1]) == "Unknown global 'foo'"); +} + TEST_SUITE_END(); diff --git a/tools/faillist.txt b/tools/faillist.txt index 37666878a..565982cf2 100644 --- a/tools/faillist.txt +++ b/tools/faillist.txt @@ -11,13 +11,9 @@ AstQuery::getDocumentationSymbolAtPosition.overloaded_fn AstQuery::getDocumentationSymbolAtPosition.table_overloaded_function_prop AutocompleteTest.autocomplete_first_function_arg_expected_type AutocompleteTest.autocomplete_oop_implicit_self +AutocompleteTest.autocomplete_string_singleton_equality AutocompleteTest.do_compatible_self_calls AutocompleteTest.do_wrong_compatible_self_calls -AutocompleteTest.keyword_methods -AutocompleteTest.no_incompatible_self_calls -AutocompleteTest.no_wrong_compatible_self_calls_with_generics -AutocompleteTest.string_singleton_as_table_key -AutocompleteTest.suggest_table_keys AutocompleteTest.type_correct_expected_argument_type_pack_suggestion AutocompleteTest.type_correct_expected_argument_type_suggestion_self AutocompleteTest.type_correct_expected_return_type_pack_suggestion @@ -60,7 +56,6 @@ DefinitionTests.single_class_type_identity_in_global_types FrontendTest.environments FrontendTest.nocheck_cycle_used_by_checked FrontendTest.reexport_cyclic_type -FrontendTest.trace_requires_in_nonstrict_mode GenericsTests.apply_type_function_nested_generics1 GenericsTests.apply_type_function_nested_generics2 GenericsTests.better_mismatch_error_messages @@ -126,6 +121,7 @@ ProvisionalTests.table_insert_with_a_singleton_argument ProvisionalTests.typeguard_inference_incomplete ProvisionalTests.weirditer_should_not_loop_forever RefinementTest.apply_refinements_on_astexprindexexpr_whose_subscript_expr_is_constant_string +RefinementTest.discriminate_tag RefinementTest.else_with_no_explicit_expression_should_also_refine_the_tagged_union RefinementTest.falsiness_of_TruthyPredicate_narrows_into_nil RefinementTest.narrow_property_of_a_bounded_variable @@ -136,23 +132,19 @@ RefinementTest.type_narrow_to_vector RefinementTest.typeguard_cast_free_table_to_vector RefinementTest.typeguard_in_assert_position RefinementTest.x_as_any_if_x_is_instance_elseif_x_is_table -RefinementTest.x_is_not_instance_or_else_not_part RuntimeLimits.typescript_port_of_Result_type TableTests.a_free_shape_can_turn_into_a_scalar_directly TableTests.a_free_shape_cannot_turn_into_a_scalar_if_it_is_not_compatible TableTests.accidentally_checked_prop_in_opposite_branch -TableTests.builtin_table_names +TableTests.any_when_indexing_into_an_unsealed_table_with_no_indexer_in_nonstrict_mode TableTests.call_method -TableTests.call_method_with_explicit_self_argument TableTests.casting_tables_with_props_into_table_with_indexer3 TableTests.casting_tables_with_props_into_table_with_indexer4 TableTests.checked_prop_too_early -TableTests.defining_a_method_for_a_local_unsealed_table_is_ok -TableTests.defining_a_self_method_for_a_local_unsealed_table_is_ok +TableTests.disallow_indexing_into_an_unsealed_table_with_no_indexer_in_strict_mode TableTests.dont_crash_when_setmetatable_does_not_produce_a_metatabletypevar TableTests.dont_hang_when_trying_to_look_up_in_cyclic_metatable_index TableTests.dont_quantify_table_that_belongs_to_outer_scope -TableTests.dont_seal_an_unsealed_table_by_passing_it_to_a_function_that_takes_a_sealed_table TableTests.dont_suggest_exact_match_keys TableTests.error_detailed_metatable_prop TableTests.expected_indexer_from_table_union @@ -175,7 +167,6 @@ TableTests.infer_array_2 TableTests.inferred_return_type_of_free_table TableTests.inferring_crazy_table_should_also_be_quick TableTests.instantiate_table_cloning_3 -TableTests.instantiate_tables_at_scope_level TableTests.invariant_table_properties_means_instantiating_tables_in_assignment_is_unsound TableTests.invariant_table_properties_means_instantiating_tables_in_call_is_unsound TableTests.leaking_bad_metatable_errors @@ -191,7 +182,6 @@ TableTests.oop_polymorphic TableTests.open_table_unification_2 TableTests.quantify_even_that_table_was_never_exported_at_all TableTests.quantify_metatables_of_metatables_of_table -TableTests.quantifying_a_bound_var_works TableTests.reasonable_error_when_adding_a_nonexistent_property_to_an_array_like_table TableTests.result_is_always_any_if_lhs_is_any TableTests.result_is_bool_for_equality_operators_if_lhs_is_any @@ -200,7 +190,6 @@ TableTests.shared_selfs TableTests.shared_selfs_from_free_param TableTests.shared_selfs_through_metatables TableTests.table_call_metamethod_basic -TableTests.table_function_check_use_after_free TableTests.table_indexing_error_location TableTests.table_insert_should_cope_with_optional_properties_in_nonstrict TableTests.table_insert_should_cope_with_optional_properties_in_strict @@ -209,16 +198,11 @@ TableTests.table_simple_call TableTests.table_subtyping_with_extra_props_dont_report_multiple_errors TableTests.table_subtyping_with_missing_props_dont_report_multiple_errors TableTests.table_unification_4 -TableTests.tc_member_function TableTests.tc_member_function_2 -TableTests.unifying_tables_shouldnt_uaf1 TableTests.unifying_tables_shouldnt_uaf2 -TableTests.used_colon_correctly TableTests.used_colon_instead_of_dot TableTests.used_dot_instead_of_colon -TableTests.used_dot_instead_of_colon_but_correctly ToString.exhaustive_toString_of_cyclic_table -ToString.function_type_with_argument_names_and_self ToString.function_type_with_argument_names_generic ToString.named_metatable_toStringNamedFunction ToString.toStringDetailed2 @@ -238,7 +222,6 @@ TryUnifyTests.variadics_should_use_reversed_properly TypeAliases.cannot_create_cyclic_type_with_unknown_module TypeAliases.corecursive_types_generic TypeAliases.forward_declared_alias_is_not_clobbered_by_prior_unification_with_any -TypeAliases.forward_declared_alias_is_not_clobbered_by_prior_unification_with_any_2 TypeAliases.generic_param_remap TypeAliases.mismatched_generic_type_param TypeAliases.mutually_recursive_types_errors @@ -256,7 +239,6 @@ TypeInfer.checking_should_not_ice TypeInfer.cli_50041_committing_txnlog_in_apollo_client_error TypeInfer.dont_report_type_errors_within_an_AstExprError TypeInfer.dont_report_type_errors_within_an_AstStatError -TypeInfer.follow_on_new_types_in_substitution TypeInfer.fuzz_free_table_type_change_during_index_check TypeInfer.globals TypeInfer.globals2 @@ -281,9 +263,7 @@ TypeInferFunctions.calling_function_with_anytypepack_doesnt_leak_free_types TypeInferFunctions.cannot_hoist_interior_defns_into_signature TypeInferFunctions.dont_give_other_overloads_message_if_only_one_argument_matching_overload_exists TypeInferFunctions.dont_infer_parameter_types_for_functions_from_their_call_site -TypeInferFunctions.dont_mutate_the_underlying_head_of_typepack_when_calling_with_self TypeInferFunctions.duplicate_functions_with_different_signatures_not_allowed_in_nonstrict -TypeInferFunctions.first_argument_can_be_optional TypeInferFunctions.function_cast_error_uses_correct_language TypeInferFunctions.function_decl_non_self_sealed_overwrite_2 TypeInferFunctions.function_decl_non_self_unsealed_overwrite @@ -309,7 +289,6 @@ TypeInferFunctions.too_few_arguments_variadic_generic2 TypeInferFunctions.too_many_arguments_error_location TypeInferFunctions.too_many_return_values_in_parentheses TypeInferFunctions.too_many_return_values_no_function -TypeInferFunctions.vararg_function_is_quantified TypeInferLoops.for_in_loop_error_on_factory_not_returning_the_right_amount_of_values TypeInferLoops.for_in_loop_with_next TypeInferLoops.for_in_with_generic_next @@ -318,22 +297,15 @@ TypeInferLoops.loop_iter_no_indexer_nonstrict TypeInferLoops.loop_iter_trailing_nil TypeInferLoops.properly_infer_iteratee_is_a_free_table TypeInferLoops.unreachable_code_after_infinite_loop -TypeInferLoops.varlist_declared_by_for_in_loop_should_be_free -TypeInferModules.bound_free_table_export_is_ok TypeInferModules.custom_require_global TypeInferModules.do_not_modify_imported_types_4 -TypeInferModules.do_not_modify_imported_types_5 TypeInferModules.module_type_conflict TypeInferModules.module_type_conflict_instantiated -TypeInferModules.require_a_variadic_function TypeInferModules.type_error_of_unknown_qualified_type TypeInferOOP.dont_suggest_using_colon_rather_than_dot_if_another_overload_works TypeInferOOP.inferring_hundreds_of_self_calls_should_not_suffocate_memory -TypeInferOOP.method_depends_on_table TypeInferOOP.methods_are_topologically_sorted -TypeInferOOP.nonstrict_self_mismatch_tail TypeInferOOP.object_constructor_can_refer_to_method_of_self -TypeInferOOP.table_oop TypeInferOperators.CallAndOrOfFunctions TypeInferOperators.CallOrOfFunctions TypeInferOperators.cannot_compare_tables_that_do_not_have_the_same_metatable @@ -345,7 +317,6 @@ TypeInferOperators.disallow_string_and_types_without_metatables_from_arithmetic_ TypeInferOperators.in_nonstrict_mode_strip_nil_from_intersections_when_considering_relational_operators TypeInferOperators.infer_any_in_all_modes_when_lhs_is_unknown TypeInferOperators.operator_eq_completely_incompatible -TypeInferOperators.produce_the_correct_error_message_when_comparing_a_table_with_a_metatable_with_one_that_does_not TypeInferOperators.typecheck_overloaded_multiply_that_is_an_intersection TypeInferOperators.typecheck_overloaded_multiply_that_is_an_intersection_on_rhs TypeInferOperators.UnknownGlobalCompoundAssign @@ -358,7 +329,6 @@ TypeInferUnknownNever.dont_unify_operands_if_one_of_the_operand_is_never_in_any_ TypeInferUnknownNever.math_operators_and_never TypePackTests.detect_cyclic_typepacks2 TypePackTests.pack_tail_unification_check -TypePackTests.self_and_varargs_should_work TypePackTests.type_alias_backwards_compatible TypePackTests.type_alias_default_export TypePackTests.type_alias_default_mixed_self diff --git a/tools/flag-bisect.py b/tools/flag-bisect.py new file mode 100644 index 000000000..01f3ef7ce --- /dev/null +++ b/tools/flag-bisect.py @@ -0,0 +1,458 @@ +#!/usr/bin/python3 +# This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +import argparse +import asyncio +import copy +import json +import math +import os +import platform +import re +import subprocess +import sys +import textwrap +from enum import Enum + +def add_parser(subparsers): + flag_bisect_command = subparsers.add_parser('flag-bisect', + help=help(), + description=help(), + epilog=epilog(), + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + + add_argument_parsers(flag_bisect_command) + flag_bisect_command.set_defaults(func=flag_bisect_main) + return flag_bisect_command + +def help(): + return 'Search for a set of flags triggering the faulty behavior in unit tests' + +def get_terminal_width(): + try: + return os.get_terminal_size().columns + except: + # Return a reasonable default when a terminal is not available + return 80 +def wrap_text(text, width): + leading_whitespace_re = re.compile('( *)') + + def get_paragraphs_and_indent(string): + lines = string.split('\n') + result = '' + line_count = 0 + initial_indent = '' + subsequent_indent = '' + for line in lines: + if len(line.strip()) == 0: + if line_count > 0: + yield result, initial_indent, subsequent_indent + result = '' + line_count = 0 + else: + line_count += 1 + if line_count == 1: + initial_indent = leading_whitespace_re.match(line).group(1) + subsequent_indent = initial_indent + elif line_count == 2: + subsequent_indent = leading_whitespace_re.match(line).group(1) + result += line.strip() + '\n' + + result = '' + for paragraph, initial_indent, subsequent_indent in get_paragraphs_and_indent(text): + result += textwrap.fill(paragraph, width=width, initial_indent=initial_indent, subsequent_indent=subsequent_indent, break_on_hyphens=False) + '\n\n' + return result + +def wrap_text_for_terminal(text): + right_margin = 2 # This margin matches what argparse uses when formatting argument documentation + min_width = 20 + width = max(min_width, get_terminal_width() - right_margin) + return wrap_text(text, width) + +def epilog(): + return wrap_text_for_terminal(''' + This tool uses the delta debugging algorithm to minimize the set of flags to the ones that are faulty in your unit tests, + and the usage is trivial. Just provide a path to the unit test and you're done, the tool will do the rest. + + There are many use cases with flag-bisect. Included but not limited to: + + 1: If your test is failing when you omit `--fflags=true` but it works when passing `--fflags=true`, then you can + use this tool to find that set of flag requirements to see which flags are missing that will help to fix it. Ditto + for the opposite too, this tool is generalized for that case. + + 2: If you happen to run into a problem on production, and you're not sure which flags is the problem and you can easily + create a unit test, you can run flag-bisect on that unit test to rapidly find the set of flags. + + 3: If you have a flag that causes a performance regression, there's also the `--timeout=N` where `N` is in seconds. + + 4: If you have tests that are demonstrating flakiness behavior, you can also use `--tries=N` where `N` is the number of + attempts to run the same set of flags before moving on to the new set. This will eventually drill down to the flaky flag(s). + Generally 8 tries should be more than enough, but it depends on the rarity. The more rare it is, the higher the attempts count + needs to be. Note that this comes with a performance cost the higher you go, but certainly still faster than manual search. + This argument will disable parallel mode by default. If this is not desired, explicitly write `--parallel=on`. + + 5: By default flag-bisect runs in parallel mode which uses a slightly modified version of delta debugging algorithm to support + trying multiple sets of flags concurrently. This means that the number of sets the algorithm will try at once is equal to the + number of concurrent jobs. There is currently no upper bound to that, so heed this warning that your machine may slow down + significantly. In this mode, we display the number of jobs it is running in parallel. Use `--parallel=off` to disable parallel + mode. + + Be aware that this introduces some level of *non-determinism*, and it is fundamental due to the interaction with flag dependencies + and the fact one job may finish faster than another job that got ran in the same cycle. However, it generally shouldn't matter + if your test is deterministic and has no implicit flag dependencies in the codebase. + + The tool will try to automatically figure out which of `--pass` or `--fail` to use if you omit them or use `--auto` by applying + heuristics. For example, if the tests works using `--fflags=true` and crashes if omitting `--fflags=true`, then it knows + to use `--pass` to give you set of flags that will cause that crash. As usual, vice versa is also true. Since this is a + heuristic, if it gets that guess wrong, you can override with `--pass` or `--fail`. + + You can speed this process up by scoping it to as few tests as possible, for example if you're using doctest then you'd + pass `--tc=my_test` as an argument after `--`, so `flag-bisect ./path/to/binary -- --tc=my_test`. + ''') + +class InterestnessMode(Enum): + AUTO = 0, + FAIL = 1, + PASS = 2, + +def add_argument_parsers(parser): + parser.add_argument('binary_path', help='Path to the unit test binary that will be bisected for a set of flags') + + parser.add_argument('--tries', dest='attempts', type=int, default=1, metavar='N', + help='If the tests are flaky, flag-bisect will try again with the same set by N amount of times before moving on') + + parser.add_argument('--parallel', dest='parallel', choices=['on', 'off'], default='default', + help='Test multiple sets of flags in parallel, useful when the test takes a while to run.') + + parser.add_argument('--explicit', dest='explicit', action='store_true', default=False, help='Explicitly set flags to false') + + parser.add_argument('--filter', dest='filter', default=None, help='Regular expression to filter for a subset of flags to test') + + parser.add_argument('--verbose', dest='verbose', action='store_true', default=False, help='Show stdout and stderr of the program being run') + + interestness_parser = parser.add_mutually_exclusive_group() + interestness_parser.add_argument('--auto', dest='mode', action='store_const', const=InterestnessMode.AUTO, + default=InterestnessMode.AUTO, help='Automatically figure out which one of --pass or --fail should be used') + interestness_parser.add_argument('--fail', dest='mode', action='store_const', const=InterestnessMode.FAIL, + help='You want this if omitting --fflags=true causes tests to fail') + interestness_parser.add_argument('--pass', dest='mode', action='store_const', const=InterestnessMode.PASS, + help='You want this if passing --fflags=true causes tests to pass') + interestness_parser.add_argument('--timeout', dest='timeout', type=int, default=0, metavar='SECONDS', + help='Find the flag(s) causing performance regression if time to run exceeds the timeout in seconds') + +class Options: + def __init__(self, args, other_args, sense): + self.path = args.binary_path + self.explicit = args.explicit + self.sense = sense + self.timeout = args.timeout + self.interested_in_timeouts = args.timeout != 0 + self.attempts = args.attempts + self.parallel = (args.parallel == 'on' or args.parallel == 'default') if args.attempts == 1 else args.parallel == 'on' + self.filter = re.compile(".*" + args.filter + ".*") if args.filter else None + self.verbose = args.verbose + self.other_args = [arg for arg in other_args if arg != '--'] # Useless to have -- here, discard. + + def copy_with_sense(self, sense): + new_copy = copy.copy(self) + new_copy.sense = sense + return new_copy + +class InterestnessResult(Enum): + FAIL = 0, + PASS = 1, + TIMED_OUT = 2, + +class Progress: + def __init__(self, count, n_of_jobs=None): + self.count = count + self.steps = 0 + self.n_of_jobs = n_of_jobs + self.buffer = None + + def show(self): + # remaining is actually the height of the current search tree. + remain = int(math.log2(self.count)) + flag_plural = 'flag' if self.count == 1 else 'flags' + node_plural = 'node' if remain == 1 else 'nodes' + jobs_info = f', running {self.n_of_jobs} jobs' if self.n_of_jobs is not None else '' + return f'flag bisection: testing {self.count} {flag_plural} (step {self.steps}, {remain} {node_plural} remain{jobs_info})' + + def hide(self): + if self.buffer: + sys.stdout.write('\b \b' * len(self.buffer)) + + def update(self, len, n_of_jobs=None): + self.hide() + self.count = len + self.steps += 1 + self.n_of_jobs = n_of_jobs + self.buffer = self.show() + sys.stdout.write(self.buffer) + sys.stdout.flush() + +def list_fflags(options): + try: + out = subprocess.check_output([options.path, '--list-fflags'], encoding='UTF-8') + flag_names = [] + + # It's unlikely that a program we're going to test has no flags. + # So if the output doesn't start with FFlag, assume it doesn't support --list-fflags and therefore cannot be bisected. + if not out.startswith('FFlag') and not out.startswith('DFFlag') and not out.startswith('SFFlag'): + return None + + flag_names = out.split('\n')[:-1] + + subset = [flag for flag in flag_names if options.filter.match(flag) is not None] if options.filter else flag_names + return subset if subset else None + except: + return None + +def mk_flags_argument(options, flags, initial_flags): + lst = [flag + '=true' for flag in flags] + + # When --explicit is provided, we'd like to find the set of flags from initial_flags that's not in active flags. + # This is so that we can provide a =false value instead of leaving them out to be the default value. + if options.explicit: + for flag in initial_flags: + if flag not in flags: + lst.append(flag + '=false') + + return '--fflags=' + ','.join(lst) + +def mk_command_line(options, flags_argument): + arguments = [options.path, *options.other_args] + if flags_argument is not None: + arguments.append(flags_argument) + return arguments + +async def get_interestness(options, flags_argument): + try: + timeout = options.timeout if options.interested_in_timeouts else None + cmd = mk_command_line(options, flags_argument) + stdout = subprocess.PIPE if not options.verbose else None + stderr = subprocess.PIPE if not options.verbose else None + process = subprocess.run(cmd, stdout=stdout, stderr=stderr, timeout=timeout) + return InterestnessResult.PASS if process.returncode == 0 else InterestnessResult.FAIL + except subprocess.TimeoutExpired: + return InterestnessResult.TIMED_OUT + +async def is_hot(options, flags_argument, pred=any): + results = await asyncio.gather(*[get_interestness(options, flags_argument) for _ in range(options.attempts)]) + + if options.interested_in_timeouts: + return pred([InterestnessResult.TIMED_OUT == x for x in results]) + else: + return pred([(InterestnessResult.PASS if options.sense else InterestnessResult.FAIL) == x for x in results]) + +def pairwise_disjoints(flags, granularity): + offset = 0 + per_slice_len = len(flags) // granularity + while offset < len(flags): + yield flags[offset:offset + per_slice_len] + offset += per_slice_len + +def subsets_and_complements(flags, granularity): + for disjoint_set in pairwise_disjoints(flags, granularity): + yield disjoint_set, [flag for flag in flags if flag not in disjoint_set] + +# https://www.cs.purdue.edu/homes/xyzhang/fall07/Papers/delta-debugging.pdf +async def ddmin(options, initial_flags): + current = initial_flags + granularity = 2 + + progress = Progress(len(current)) + progress.update(len(current)) + + while len(current) >= 2: + changed = False + + for (subset, complement) in subsets_and_complements(current, granularity): + progress.update(len(current)) + if await is_hot(options, mk_flags_argument(options, complement, initial_flags)): + current = complement + granularity = max(granularity - 1, 2) + changed = True + break + elif await is_hot(options, mk_flags_argument(options, subset, initial_flags)): + current = subset + granularity = 2 + changed = True + break + + if not changed: + if granularity == len(current): + break + granularity = min(granularity * 2, len(current)) + + progress.hide() + return current + +async def ddmin_parallel(options, initial_flags): + current = initial_flags + granularity = 2 + + progress = Progress(len(current)) + progress.update(len(current), granularity) + + while len(current) >= 2: + changed = False + + subset_jobs = [] + complement_jobs = [] + + def advance(task): + nonlocal current + nonlocal granularity + nonlocal changed + # task.cancel() calls the callback passed to add_done_callback... + if task.cancelled(): + return + hot, new_delta, new_granularity = task.result() + if hot and not changed: + current = new_delta + granularity = new_granularity + changed = True + for job in subset_jobs: + job.cancel() + for job in complement_jobs: + job.cancel() + + for (subset, complement) in subsets_and_complements(current, granularity): + async def work(flags, new_granularity): + hot = await is_hot(options, mk_flags_argument(options, flags, initial_flags)) + return (hot, flags, new_granularity) + + # We want to run subset jobs in parallel first. + subset_job = asyncio.create_task(work(subset, 2)) + subset_job.add_done_callback(advance) + subset_jobs.append(subset_job) + + # Then the complements afterwards, but only if we didn't find a new subset. + complement_job = asyncio.create_task(work(complement, max(granularity - 1, 2))) + complement_job.add_done_callback(advance) + complement_jobs.append(complement_job) + + # When we cancel jobs, the asyncio.gather will be waiting pointlessly. + # In that case, we'd like to return the control to this routine. + await asyncio.gather(*subset_jobs, return_exceptions=True) + if not changed: + await asyncio.gather(*complement_jobs, return_exceptions=True) + progress.update(len(current), granularity) + + if not changed: + if granularity == len(current): + break + granularity = min(granularity * 2, len(current)) + + progress.hide() + return current + +def search(options, initial_flags): + if options.parallel: + return ddmin_parallel(options, initial_flags) + else: + return ddmin(options, initial_flags) + +async def do_work(args, other_args): + sense = None + + # If --timeout isn't used, try to apply a heuristic to figure out which of --pass or --fail we want. + if args.timeout == 0 and args.mode == InterestnessMode.AUTO: + inner_options = Options(args, other_args, sense) + + # We aren't interested in timeout for this heuristic. It just makes no sense to assume timeouts. + # This actually cannot happen by this point, but if we make timeout a non-exclusive switch to --auto, this will go wrong. + inner_options.timeout = 0 + inner_options.interested_in_timeouts = False + + all_tasks = asyncio.gather( + is_hot(inner_options.copy_with_sense(True), '--fflags=true', all), + is_hot(inner_options.copy_with_sense(False), '--fflags=false' if inner_options.explicit else None, all), + ) + + # If it times out, we can print a message saying that this is still working. We intentionally want to continue doing work. + done, pending = await asyncio.wait([all_tasks], timeout=1.5) + if all_tasks not in done: + print('Hang on! I\'m running your program to try and figure out which of --pass or --fail to use!') + print('Need to find out faster? Cancel the work and explicitly write --pass or --fail') + + is_pass_hot, is_fail_hot = await all_tasks + + # This is a bit counter-intuitive, but the following table tells us which of the sense we want. + # Because when you omit --fflags=true argument and it fails, then is_fail_hot is True. + # Consequently, you need to use --pass to find out what that set of flags is. And vice versa. + # + # Also, when is_pass_hot is True and is_fail_hot is False, then that program is working as expected. + # There should be no reason to run flag bisection. + # However, this can be ambiguous in the opposite of the aforementioned outcome! + # + # is_pass_hot | is_fail_hot | is ambiguous? + #-------------|-------------|--------------- + # True | True | No! Pick --pass. + # False | False | No! Pick --fail. + # True | False | No! But this is the exact situation where you shouldn't need to flag-bisect. Raise an error. + # False | True | Yes! But we'll pragmatically pick --fail here in the hope it gives the correct set of flags. + + if is_pass_hot and not is_fail_hot: + print('The tests seems to be working fine for me. If you really need to flag-bisect, please try again with an explicit --pass or --fail', file=sys.stderr) + return 1 + + if not is_pass_hot and is_fail_hot: + print('I couldn\'t quite figure out which of --pass or --fail to use, but I\'ll carry on anyway') + + sense = is_pass_hot + argument = '--pass' if sense else '--fail' + print(f'I\'m bisecting flags as if {argument} was used') + else: + sense = True if args.mode == InterestnessMode.PASS else False + + options = Options(args, other_args, sense) + + initial_flags = list_fflags(options) + if initial_flags is None: + print('I cannot bisect flags with ' + options.path, file=sys.stderr) + print('These are required for me to be able to cooperate:', file=sys.stderr) + print('\t--list-fflags must print a list of flags separated by newlines, including FFlag prefix', file=sys.stderr) + print('\t--fflags=... to accept a comma-separated pair of flag names and their value in the form FFlagFoo=true', file=sys.stderr) + return 1 + + # On Windows, there is an upper bound on the numbers of characters for a command line incantation. + # If we don't handle this ourselves, the runtime error is going to look nothing like the actual problem. + # It'd say "file name way too long" or something to that effect. We can teed up a better error message and + # tell the user how to work around it by using --filter. + if platform.system() == 'Windows': + cmd_line = ' '.join(mk_command_line(options, mk_flags_argument(options, initial_flags, []))) + if len(cmd_line) >= 8191: + print(f'Never mind! The command line is too long because we have {len(initial_flags)} flags to test', file=sys.stderr) + print('Consider using `--filter=` to narrow it down upfront, or use any version of WSL instead', file=sys.stderr) + return 1 + + hot_flags = await search(options, initial_flags) + if hot_flags: + print('I narrowed down to these flags:') + print(textwrap.indent('\n'.join(hot_flags), prefix='\t')) + + # If we showed the command line in explicit mode, all flags would be listed here. + # This would pollute the terminal with 3000 flags. We don't want that. Don't show it. + # Ditto for when the number flags we bisected are equal. + if not options.explicit and len(hot_flags) != len(initial_flags): + print('$ ' + ' '.join(mk_command_line(options, mk_flags_argument(options, hot_flags, initial_flags)))) + + return 0 + + print('I found nothing, sorry', file=sys.stderr) + return 1 + +def flag_bisect_main(args, other_args): + return asyncio.run(do_work(args, other_args)) + +def main(): + parser = argparse.ArgumentParser(description=help(), epilog=epilog(), formatter_class=argparse.RawTextHelpFormatter) + add_argument_parsers(parser) + args, other_args = parser.parse_known_args() + return flag_bisect_main(args, other_args) + +if __name__ == '__main__': + sys.exit(main()) From 5c77305609862faa45241bfad94d6e309e698364 Mon Sep 17 00:00:00 2001 From: Vyacheslav Egorov Date: Fri, 17 Feb 2023 16:53:37 +0200 Subject: [PATCH 36/66] Sync to upstream/release/564 --- Analysis/include/Luau/Constraint.h | 3 +- .../include/Luau/ConstraintGraphBuilder.h | 12 +- Analysis/include/Luau/ConstraintSolver.h | 2 +- Analysis/include/Luau/Frontend.h | 14 +- Analysis/src/Autocomplete.cpp | 4 - Analysis/src/BuiltinDefinitions.cpp | 4 - Analysis/src/ConstraintGraphBuilder.cpp | 254 ++++++++------ Analysis/src/ConstraintSolver.cpp | 87 ++--- Analysis/src/DataFlowGraph.cpp | 9 +- Analysis/src/Frontend.cpp | 53 ++- Analysis/src/Module.cpp | 5 +- Analysis/src/Type.cpp | 36 +- Analysis/src/TypeChecker2.cpp | 83 +++-- Analysis/src/Unifier.cpp | 57 +--- CLI/Repl.cpp | 9 +- CodeGen/include/Luau/IrData.h | 27 +- CodeGen/include/Luau/IrUtils.h | 21 ++ CodeGen/src/IrDump.cpp | 6 +- CodeGen/src/IrLoweringX64.cpp | 25 +- CodeGen/src/IrTranslation.cpp | 10 +- CodeGen/src/IrUtils.cpp | 260 ++++++++++++++- Common/include/Luau/ExperimentalFlags.h | 2 +- tests/Autocomplete.test.cpp | 2 - tests/ClassFixture.h | 1 + tests/Fixture.cpp | 17 +- tests/IrBuilder.test.cpp | 315 +++++++++++++++++- tests/TypeInfer.aliases.test.cpp | 2 +- tests/TypeInfer.functions.test.cpp | 20 ++ tests/TypeInfer.generics.test.cpp | 23 +- tests/TypeInfer.operators.test.cpp | 2 +- tests/TypeInfer.refinements.test.cpp | 14 +- tests/TypeInfer.tryUnify.test.cpp | 2 - tools/faillist.txt | 49 +-- 33 files changed, 1047 insertions(+), 383 deletions(-) diff --git a/Analysis/include/Luau/Constraint.h b/Analysis/include/Luau/Constraint.h index 18ff30921..65599e498 100644 --- a/Analysis/include/Luau/Constraint.h +++ b/Analysis/include/Luau/Constraint.h @@ -90,6 +90,8 @@ struct NameConstraint TypeId namedType; std::string name; bool synthetic = false; + std::vector typeParameters; + std::vector typePackParameters; }; // target ~ inst target @@ -101,7 +103,6 @@ struct TypeAliasExpansionConstraint struct FunctionCallConstraint { - std::vector> innerConstraints; TypeId fn; TypePackId argsPack; TypePackId result; diff --git a/Analysis/include/Luau/ConstraintGraphBuilder.h b/Analysis/include/Luau/ConstraintGraphBuilder.h index 29afabf38..085b67328 100644 --- a/Analysis/include/Luau/ConstraintGraphBuilder.h +++ b/Analysis/include/Luau/ConstraintGraphBuilder.h @@ -162,10 +162,10 @@ struct ConstraintGraphBuilder void visit(const ScopePtr& scope, AstStatDeclareFunction* declareFunction); void visit(const ScopePtr& scope, AstStatError* error); - InferencePack checkPack(const ScopePtr& scope, AstArray exprs, const std::vector& expectedTypes = {}); - InferencePack checkPack(const ScopePtr& scope, AstExpr* expr, const std::vector& expectedTypes = {}); + InferencePack checkPack(const ScopePtr& scope, AstArray exprs, const std::vector>& expectedTypes = {}); + InferencePack checkPack(const ScopePtr& scope, AstExpr* expr, const std::vector>& expectedTypes = {}); - InferencePack checkPack(const ScopePtr& scope, AstExprCall* call, const std::vector& expectedTypes); + InferencePack checkPack(const ScopePtr& scope, AstExprCall* call); /** * Checks an expression that is expected to evaluate to one type. @@ -244,8 +244,10 @@ struct ConstraintGraphBuilder **/ TypePackId resolveTypePack(const ScopePtr& scope, const AstTypeList& list, bool inTypeArguments); - std::vector> createGenerics(const ScopePtr& scope, AstArray generics); - std::vector> createGenericPacks(const ScopePtr& scope, AstArray packs); + std::vector> createGenerics( + const ScopePtr& scope, AstArray generics, bool useCache = false); + std::vector> createGenericPacks( + const ScopePtr& scope, AstArray packs, bool useCache = false); Inference flattenPack(const ScopePtr& scope, Location location, InferencePack pack); diff --git a/Analysis/include/Luau/ConstraintSolver.h b/Analysis/include/Luau/ConstraintSolver.h index 66d3e8f3f..de7b3a044 100644 --- a/Analysis/include/Luau/ConstraintSolver.h +++ b/Analysis/include/Luau/ConstraintSolver.h @@ -183,7 +183,7 @@ struct ConstraintSolver /** Pushes a new solver constraint to the solver. * @param cv the body of the constraint. **/ - void pushConstraint(NotNull scope, const Location& location, ConstraintV cv); + NotNull pushConstraint(NotNull scope, const Location& location, ConstraintV cv); /** * Attempts to resolve a module from its module information. Returns the diff --git a/Analysis/include/Luau/Frontend.h b/Analysis/include/Luau/Frontend.h index dfb35cbdb..403551f67 100644 --- a/Analysis/include/Luau/Frontend.h +++ b/Analysis/include/Luau/Frontend.h @@ -162,7 +162,7 @@ struct Frontend ScopePtr getGlobalScope(); private: - ModulePtr check(const SourceModule& sourceModule, Mode mode, const ScopePtr& environmentScope, std::vector requireCycles, + ModulePtr check(const SourceModule& sourceModule, Mode mode, std::vector requireCycles, bool forAutocomplete = false); std::pair getSourceNode(const ModuleName& name); @@ -202,4 +202,16 @@ struct Frontend ScopePtr globalScope; }; +ModulePtr check( + const SourceModule& sourceModule, + const std::vector& requireCycles, + NotNull builtinTypes, + NotNull iceHandler, + NotNull moduleResolver, + NotNull fileResolver, + const ScopePtr& globalScope, + NotNull unifierState, + FrontendOptions options +); + } // namespace Luau diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index dd6e11468..85e27168a 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -16,7 +16,6 @@ LUAU_FASTFLAGVARIABLE(LuauCompleteTableKeysBetter, false); LUAU_FASTFLAGVARIABLE(LuauFixAutocompleteInWhile, false); LUAU_FASTFLAGVARIABLE(LuauFixAutocompleteInFor, false); -LUAU_FASTFLAGVARIABLE(LuauAutocompleteStringContent, false); static const std::unordered_set kStatementStartingKeywords = { "while", "if", "local", "repeat", "function", "do", "for", "return", "break", "continue", "type", "export"}; @@ -1268,9 +1267,6 @@ static bool isSimpleInterpolatedString(const AstNode* node) static std::optional getStringContents(const AstNode* node) { - if (!FFlag::LuauAutocompleteStringContent) - return std::nullopt; - if (const AstExprConstantString* string = node->as()) { return std::string(string->value.data, string->value.size); diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index c17169f45..7bb57208c 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -15,8 +15,6 @@ #include -LUAU_FASTFLAGVARIABLE(LuauBuiltInMetatableNoBadSynthetic, false) - /** FIXME: Many of these type definitions are not quite completely accurate. * * Some of them require richer generics than we have. For instance, we do not yet have a way to talk @@ -558,8 +556,6 @@ static std::optional> magicFunctionSetMetaTable( if (tableName == metatableName) mtv.syntheticName = tableName; - else if (!FFlag::LuauBuiltInMetatableNoBadSynthetic) - mtv.syntheticName = "{ @metatable: " + metatableName + ", " + tableName + " }"; } TypeId mtTy = arena.addType(mtv); diff --git a/Analysis/src/ConstraintGraphBuilder.cpp b/Analysis/src/ConstraintGraphBuilder.cpp index f773863c9..aa605bdf0 100644 --- a/Analysis/src/ConstraintGraphBuilder.cpp +++ b/Analysis/src/ConstraintGraphBuilder.cpp @@ -11,6 +11,8 @@ #include "Luau/TypeUtils.h" #include "Luau/Type.h" +#include + LUAU_FASTINT(LuauCheckRecursionLimit); LUAU_FASTFLAG(DebugLuauLogSolverToJson); LUAU_FASTFLAG(DebugLuauMagicTypes); @@ -334,13 +336,12 @@ void ConstraintGraphBuilder::visitBlockWithoutChildScope(const ScopePtr& scope, // In order to enable mutually-recursive type aliases, we need to // populate the type bindings before we actually check any of the - // alias statements. Since we're not ready to actually resolve - // any of the annotations, we just use a fresh type for now. + // alias statements. for (AstStat* stat : block->body) { if (auto alias = stat->as()) { - if (scope->privateTypeBindings.count(alias->name.value) != 0) + if (scope->exportedTypeBindings.count(alias->name.value) || scope->privateTypeBindings.count(alias->name.value)) { auto it = aliasDefinitionLocations.find(alias->name.value); LUAU_ASSERT(it != aliasDefinitionLocations.end()); @@ -348,30 +349,28 @@ void ConstraintGraphBuilder::visitBlockWithoutChildScope(const ScopePtr& scope, continue; } - bool hasGenerics = alias->generics.size > 0 || alias->genericPacks.size > 0; - - ScopePtr defnScope = scope; - if (hasGenerics) - { - defnScope = childScope(alias, scope); - } + ScopePtr defnScope = childScope(alias, scope); - TypeId initialType = freshType(scope); - TypeFun initialFun = TypeFun{initialType}; + TypeId initialType = arena->addType(BlockedType{}); + TypeFun initialFun{initialType}; - for (const auto& [name, gen] : createGenerics(defnScope, alias->generics)) + for (const auto& [name, gen] : createGenerics(defnScope, alias->generics, /* useCache */ true)) { initialFun.typeParams.push_back(gen); defnScope->privateTypeBindings[name] = TypeFun{gen.ty}; } - for (const auto& [name, genPack] : createGenericPacks(defnScope, alias->genericPacks)) + for (const auto& [name, genPack] : createGenericPacks(defnScope, alias->genericPacks, /* useCache */ true)) { initialFun.typePackParams.push_back(genPack); defnScope->privateTypePackBindings[name] = genPack.tp; } - scope->privateTypeBindings[alias->name.value] = std::move(initialFun); + if (alias->exported) + scope->exportedTypeBindings[alias->name.value] = std::move(initialFun); + else + scope->privateTypeBindings[alias->name.value] = std::move(initialFun); + astTypeAliasDefiningScopes[alias] = defnScope; aliasDefinitionLocations[alias->name.value] = alias->location; } @@ -387,42 +386,46 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStat* stat) if (auto s = stat->as()) visit(scope, s); - else if (auto s = stat->as()) - visit(scope, s); - else if (auto s = stat->as()) - visit(scope, s); - else if (auto s = stat->as()) - visit(scope, s); + else if (auto i = stat->as()) + visit(scope, i); else if (auto s = stat->as()) visit(scope, s); else if (auto s = stat->as()) visit(scope, s); - else if (auto f = stat->as()) - visit(scope, f); - else if (auto f = stat->as()) - visit(scope, f); + else if (stat->is() || stat->is()) + { + // Nothing + } else if (auto r = stat->as()) visit(scope, r); + else if (auto e = stat->as()) + checkPack(scope, e->expr); + else if (auto s = stat->as()) + visit(scope, s); + else if (auto s = stat->as()) + visit(scope, s); + else if (auto s = stat->as()) + visit(scope, s); else if (auto a = stat->as()) visit(scope, a); else if (auto a = stat->as()) visit(scope, a); - else if (auto e = stat->as()) - checkPack(scope, e->expr); - else if (auto i = stat->as()) - visit(scope, i); + else if (auto f = stat->as()) + visit(scope, f); + else if (auto f = stat->as()) + visit(scope, f); else if (auto a = stat->as()) visit(scope, a); else if (auto s = stat->as()) visit(scope, s); - else if (auto s = stat->as()) - visit(scope, s); else if (auto s = stat->as()) visit(scope, s); + else if (auto s = stat->as()) + visit(scope, s); else if (auto s = stat->as()) visit(scope, s); else - LUAU_ASSERT(0); + LUAU_ASSERT(0 && "Internal error: Unknown AstStat type"); } void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocal* local) @@ -482,7 +485,7 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocal* local) } else { - std::vector expectedTypes; + std::vector> expectedTypes; if (hasAnnotation) expectedTypes.insert(begin(expectedTypes), begin(varTypes) + i, end(varTypes)); @@ -680,6 +683,7 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFunction* funct TypeId generalizedType = arena->addType(BlockedType{}); + Checkpoint start = checkpoint(this); FunctionSignature sig = checkFunctionSignature(scope, function->func); if (AstExprLocal* localName = function->name->as()) @@ -724,7 +728,6 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFunction* funct if (generalizedType == nullptr) ice->ice("generalizedType == nullptr", function->location); - Checkpoint start = checkpoint(this); checkFunctionBody(sig.bodyScope, function->func); Checkpoint end = checkpoint(this); @@ -745,7 +748,7 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatReturn* ret) // interesting in it is if the function has an explicit return annotation. // If this is the case, then we can expect that the return expression // conforms to that. - std::vector expectedTypes; + std::vector> expectedTypes; for (TypeId ty : scope->returnType) expectedTypes.push_back(ty); @@ -764,8 +767,21 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatAssign* assign) { TypePackId varPackId = checkLValues(scope, assign->vars); - TypePack expectedTypes = extendTypePack(*arena, builtinTypes, varPackId, assign->values.size); - TypePackId valuePack = checkPack(scope, assign->values, expectedTypes.head).tp; + TypePack expectedPack = extendTypePack(*arena, builtinTypes, varPackId, assign->values.size); + + std::vector> expectedTypes; + expectedTypes.reserve(expectedPack.head.size()); + + for (TypeId ty : expectedPack.head) + { + ty = follow(ty); + if (get(ty)) + expectedTypes.push_back(std::nullopt); + else + expectedTypes.push_back(ty); + } + + TypePackId valuePack = checkPack(scope, assign->values, expectedTypes).tp; addConstraint(scope, assign->location, PackSubtypeConstraint{valuePack, varPackId}); } @@ -800,35 +816,70 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatIf* ifStatement } } +static bool occursCheck(TypeId needle, TypeId haystack) +{ + LUAU_ASSERT(get(needle)); + haystack = follow(haystack); + + auto checkHaystack = [needle](TypeId haystack) { + return occursCheck(needle, haystack); + }; + + if (needle == haystack) + return true; + else if (auto ut = get(haystack)) + return std::any_of(begin(ut), end(ut), checkHaystack); + else if (auto it = get(haystack)) + return std::any_of(begin(it), end(it), checkHaystack); + + return false; +} + void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatTypeAlias* alias) { - auto bindingIt = scope->privateTypeBindings.find(alias->name.value); - ScopePtr* defnIt = astTypeAliasDefiningScopes.find(alias); + ScopePtr* defnScope = astTypeAliasDefiningScopes.find(alias); + + std::unordered_map* typeBindings; + if (alias->exported) + typeBindings = &scope->exportedTypeBindings; + else + typeBindings = &scope->privateTypeBindings; + // These will be undefined if the alias was a duplicate definition, in which // case we just skip over it. - if (bindingIt == scope->privateTypeBindings.end() || defnIt == nullptr) - { + auto bindingIt = typeBindings->find(alias->name.value); + if (bindingIt == typeBindings->end() || defnScope == nullptr) return; - } - ScopePtr resolvingScope = *defnIt; - TypeId ty = resolveType(resolvingScope, alias->type, /* inTypeArguments */ false); + TypeId ty = resolveType(*defnScope, alias->type, /* inTypeArguments */ false); - if (alias->exported) + TypeId aliasTy = bindingIt->second.type; + LUAU_ASSERT(get(aliasTy)); + + if (occursCheck(aliasTy, ty)) { - Name typeName(alias->name.value); - scope->exportedTypeBindings[typeName] = TypeFun{ty}; + asMutable(aliasTy)->ty.emplace(builtinTypes->anyType); + reportError(alias->nameLocation, OccursCheckFailed{}); } - - LUAU_ASSERT(get(bindingIt->second.type)); - - // Rather than using a subtype constraint, we instead directly bind - // the free type we generated in the first pass to the resolved type. - // This prevents a case where you could cause another constraint to - // bind the free alias type to an unrelated type, causing havoc. - asMutable(bindingIt->second.type)->ty.emplace(ty); - - addConstraint(scope, alias->location, NameConstraint{ty, alias->name.value}); + else + asMutable(aliasTy)->ty.emplace(ty); + + std::vector typeParams; + for (auto tyParam : createGenerics(*defnScope, alias->generics, /* useCache */ true)) + typeParams.push_back(tyParam.second.ty); + + std::vector typePackParams; + for (auto tpParam : createGenericPacks(*defnScope, alias->genericPacks, /* useCache */ true)) + typePackParams.push_back(tpParam.second.tp); + + addConstraint(scope, alias->type->location, + NameConstraint{ + ty, + alias->name.value, + /*synthetic=*/false, + std::move(typeParams), + std::move(typePackParams), + }); } void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareGlobal* global) @@ -997,7 +1048,8 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatError* error) check(scope, expr); } -InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstArray exprs, const std::vector& expectedTypes) +InferencePack ConstraintGraphBuilder::checkPack( + const ScopePtr& scope, AstArray exprs, const std::vector>& expectedTypes) { std::vector head; std::optional tail; @@ -1010,11 +1062,11 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstArray< std::optional expectedType; if (i < expectedTypes.size()) expectedType = expectedTypes[i]; - head.push_back(check(scope, expr).ty); + head.push_back(check(scope, expr, expectedType).ty); } else { - std::vector expectedTailTypes; + std::vector> expectedTailTypes; if (i < expectedTypes.size()) expectedTailTypes.assign(begin(expectedTypes) + i, end(expectedTypes)); tail = checkPack(scope, expr, expectedTailTypes).tp; @@ -1027,7 +1079,7 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstArray< return InferencePack{arena->addTypePack(TypePack{std::move(head), tail})}; } -InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExpr* expr, const std::vector& expectedTypes) +InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExpr* expr, const std::vector>& expectedTypes) { RecursionCounter counter{&recursionCount}; @@ -1040,7 +1092,7 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExpr* InferencePack result; if (AstExprCall* call = expr->as()) - result = checkPack(scope, call, expectedTypes); + result = checkPack(scope, call); else if (AstExprVarargs* varargs = expr->as()) { if (scope->varargPack) @@ -1062,7 +1114,7 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExpr* return result; } -InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCall* call, const std::vector& expectedTypes) +InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCall* call) { std::vector exprArgs; @@ -1164,7 +1216,7 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCa } else { - auto [tp, refis] = checkPack(scope, arg, {}); // FIXME? not sure about expectedTypes here + auto [tp, refis] = checkPack(scope, arg, {}); argTail = tp; argumentRefinements.insert(argumentRefinements.end(), refis.begin(), refis.end()); } @@ -1209,24 +1261,13 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCa if (matchAssert(*call) && !argumentRefinements.empty()) applyRefinements(scope, call->args.data[0]->location, argumentRefinements[0]); - TypeId instantiatedType = arena->addType(BlockedType{}); // TODO: How do expectedTypes play into this? Do they? TypePackId rets = arena->addTypePack(BlockedTypePack{}); TypePackId argPack = arena->addTypePack(TypePack{args, argTail}); FunctionType ftv(TypeLevel{}, scope.get(), argPack, rets); - TypeId inferredFnType = arena->addType(ftv); - - unqueuedConstraints.push_back( - std::make_unique(NotNull{scope.get()}, call->func->location, InstantiationConstraint{instantiatedType, fnType})); - NotNull ic(unqueuedConstraints.back().get()); - - unqueuedConstraints.push_back( - std::make_unique(NotNull{scope.get()}, call->func->location, SubtypeConstraint{instantiatedType, inferredFnType})); - NotNull sc(unqueuedConstraints.back().get()); NotNull fcc = addConstraint(scope, call->func->location, FunctionCallConstraint{ - {ic, sc}, fnType, argPack, rets, @@ -1276,12 +1317,7 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExpr* expr, st else if (expr->is()) result = flattenPack(scope, expr->location, checkPack(scope, expr)); else if (auto call = expr->as()) - { - std::vector expectedTypes; - if (expectedType) - expectedTypes.push_back(*expectedType); - result = flattenPack(scope, expr->location, checkPack(scope, call, expectedTypes)); // TODO: needs predicates too - } + result = flattenPack(scope, expr->location, checkPack(scope, call)); // TODO: needs predicates too else if (auto a = expr->as()) { Checkpoint startCheckpoint = checkpoint(this); @@ -1883,6 +1919,7 @@ ConstraintGraphBuilder::FunctionSignature ConstraintGraphBuilder::checkFunctionS } std::vector argTypes; + std::vector> argNames; TypePack expectedArgPack; const FunctionType* expectedFunction = expectedType ? get(*expectedType) : nullptr; @@ -1895,14 +1932,27 @@ ConstraintGraphBuilder::FunctionSignature ConstraintGraphBuilder::checkFunctionS genericTypePacks = expectedFunction->genericPacks; } + if (fn->self) + { + TypeId selfType = freshType(signatureScope); + argTypes.push_back(selfType); + argNames.emplace_back(FunctionArgument{fn->self->name.value, fn->self->location}); + signatureScope->bindings[fn->self] = Binding{selfType, fn->self->location}; + } + for (size_t i = 0; i < fn->args.size; ++i) { AstLocal* local = fn->args.data[i]; TypeId t = freshType(signatureScope); argTypes.push_back(t); + argNames.emplace_back(FunctionArgument{local->name.value, local->location}); signatureScope->bindings[local] = Binding{t, local->location}; + auto def = dfg->getDef(local); + LUAU_ASSERT(def); + signatureScope->dcrRefinements[*def] = t; + TypeId annotationTy = t; if (local->annotation) @@ -1918,12 +1968,6 @@ ConstraintGraphBuilder::FunctionSignature ConstraintGraphBuilder::checkFunctionS { addConstraint(signatureScope, local->location, SubtypeConstraint{t, expectedArgPack.head[i]}); } - - // HACK: This is the one case where the type of the definition will diverge from the type of the binding. - // We need to do this because there are cases where type refinements needs to have the information available - // at constraint generation time. - if (auto def = dfg->getDef(local)) - signatureScope->dcrRefinements[*def] = annotationTy; } TypePackId varargPack = nullptr; @@ -1978,6 +2022,7 @@ ConstraintGraphBuilder::FunctionSignature ConstraintGraphBuilder::checkFunctionS actualFunction.hasNoGenerics = !hasGenerics; actualFunction.generics = std::move(genericTypes); actualFunction.genericPacks = std::move(genericTypePacks); + actualFunction.argNames = std::move(argNames); TypeId actualFunctionType = arena->addType(std::move(actualFunction)); LUAU_ASSERT(actualFunctionType); @@ -2085,11 +2130,6 @@ TypeId ConstraintGraphBuilder::resolveType(const ScopePtr& scope, AstType* ty, b } else { - std::string typeName; - if (ref->prefix) - typeName = std::string(ref->prefix->value) + "."; - typeName += ref->name.value; - result = builtinTypes->errorRecoveryType(); } } @@ -2245,6 +2285,8 @@ TypePackId ConstraintGraphBuilder::resolveTypePack(const ScopePtr& scope, AstTyp else if (auto var = tp->as()) { TypeId ty = resolveType(scope, var->variadicType, inTypeArgument); + if (get(follow(ty))) + ty = freshType(scope); result = arena->addTypePack(TypePackVar{VariadicTypePack{ty}}); } else if (auto gen = tp->as()) @@ -2287,12 +2329,22 @@ TypePackId ConstraintGraphBuilder::resolveTypePack(const ScopePtr& scope, const return arena->addTypePack(TypePack{head, tail}); } -std::vector> ConstraintGraphBuilder::createGenerics(const ScopePtr& scope, AstArray generics) +std::vector> ConstraintGraphBuilder::createGenerics( + const ScopePtr& scope, AstArray generics, bool useCache) { std::vector> result; for (const auto& generic : generics) { - TypeId genericTy = arena->addType(GenericType{scope.get(), generic.name.value}); + TypeId genericTy = nullptr; + + if (auto it = scope->parent->typeAliasTypeParameters.find(generic.name.value); useCache && it != scope->parent->typeAliasTypeParameters.end()) + genericTy = it->second; + else + { + genericTy = arena->addType(GenericType{scope.get(), generic.name.value}); + scope->parent->typeAliasTypeParameters[generic.name.value] = genericTy; + } + std::optional defaultTy = std::nullopt; if (generic.defaultValue) @@ -2305,12 +2357,22 @@ std::vector> ConstraintGraphBuilder::crea } std::vector> ConstraintGraphBuilder::createGenericPacks( - const ScopePtr& scope, AstArray generics) + const ScopePtr& scope, AstArray generics, bool useCache) { std::vector> result; for (const auto& generic : generics) { - TypePackId genericTy = arena->addTypePack(TypePackVar{GenericTypePack{scope.get(), generic.name.value}}); + TypePackId genericTy; + + if (auto it = scope->parent->typeAliasTypePackParameters.find(generic.name.value); + useCache && it != scope->parent->typeAliasTypePackParameters.end()) + genericTy = it->second; + else + { + genericTy = arena->addTypePack(TypePackVar{GenericTypePack{scope.get(), generic.name.value}}); + scope->parent->typeAliasTypePackParameters[generic.name.value] = genericTy; + } + std::optional defaultTy = std::nullopt; if (generic.defaultValue) diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index 76fd0bca8..879dac39b 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -217,12 +217,6 @@ void dump(ConstraintSolver* cs, ToStringOptions& opts) int blockCount = it == cs->blockedConstraints.end() ? 0 : int(it->second); printf("\t%d\t\t%s\n", blockCount, toString(*dep, opts).c_str()); } - - if (auto fcc = get(*c)) - { - for (NotNull inner : fcc->innerConstraints) - printf("\t ->\t\t%s\n", toString(*inner, opts).c_str()); - } } } @@ -531,32 +525,19 @@ bool ConstraintSolver::tryDispatch(const UnaryConstraint& c, NotNull mm = findMetatableEntry(builtinTypes, errors, operandType, "__unm", constraint->location)) { - const FunctionType* ftv = get(follow(*mm)); + TypeId mmTy = follow(*mm); - if (!ftv) - { - if (std::optional callMm = findMetatableEntry(builtinTypes, errors, follow(*mm), "__call", constraint->location)) - { - ftv = get(follow(*callMm)); - } - } + if (get(mmTy) && !force) + return block(mmTy, constraint); - if (!ftv) - { - asMutable(c.resultType)->ty.emplace(builtinTypes->errorRecoveryType()); - return true; - } + TypePackId argPack = arena->addTypePack(TypePack{{operandType}, {}}); + TypePackId retPack = arena->addTypePack(BlockedTypePack{}); - TypePackId argsPack = arena->addTypePack({operandType}); - unify(ftv->argTypes, argsPack, constraint->scope); + asMutable(c.resultType)->ty.emplace(constraint->scope); - TypeId result = builtinTypes->errorRecoveryType(); - if (ftv) - { - result = first(ftv->retTypes).value_or(builtinTypes->errorRecoveryType()); - } + pushConstraint(constraint->scope, constraint->location, PackSubtypeConstraint{retPack, arena->addTypePack(TypePack{{c.resultType}})}); - asMutable(c.resultType)->ty.emplace(result); + pushConstraint(constraint->scope, constraint->location, FunctionCallConstraint{mmTy, argPack, retPack, nullptr}); } else { @@ -884,7 +865,11 @@ bool ConstraintSolver::tryDispatch(const NameConstraint& c, NotNullname) ttv->syntheticName = c.name; else + { ttv->name = c.name; + ttv->instantiatedTypeParams = c.typeParameters; + ttv->instantiatedTypePackParams = c.typePackParameters; + } } else if (MetatableType* mtv = getMutable(target)) mtv->syntheticName = c.name; @@ -1032,6 +1017,7 @@ bool ConstraintSolver::tryDispatch(const TypeAliasExpansionConstraint& c, NotNul { // TODO (CLI-56761): Report an error. bindResult(errorRecoveryType()); + reportError(GenericError{"Recursive type being used with different parameters"}, constraint->location); return true; } @@ -1119,9 +1105,10 @@ bool ConstraintSolver::tryDispatch(const TypeAliasExpansionConstraint& c, NotNul bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNull constraint) { TypeId fn = follow(c.fn); + TypePackId argsPack = follow(c.argsPack); TypePackId result = follow(c.result); - if (isBlocked(c.fn)) + if (isBlocked(fn)) { return block(c.fn, constraint); } @@ -1156,12 +1143,8 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNulladdType(BlockedType{}); - TypeId inferredFnType = arena->addType(FunctionType(TypeLevel{}, constraint->scope.get(), arena->addTypePack(TypePack{args, {}}), c.result)); - - asMutable(*c.innerConstraints.at(0)).c = InstantiationConstraint{instantiatedType, *callMm}; - asMutable(*c.innerConstraints.at(1)).c = SubtypeConstraint{inferredFnType, instantiatedType}; - + argsPack = arena->addTypePack(TypePack{args, {}}); + fn = *callMm; asMutable(c.result)->ty.emplace(constraint->scope); } else @@ -1178,19 +1161,7 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNulldcrMagicRefinement(MagicRefinementContext{constraint->scope, c.callSite, c.discriminantTypes}); } - if (usedMagic) - { - // There are constraints that are blocked on these constraints. If we - // are never going to even examine them, then we should not block - // anything else on them. - // - // TODO CLI-58842 -#if 0 - for (auto& c: c.innerConstraints) - unblock(c); -#endif - } - else + if (!usedMagic) asMutable(c.result)->ty.emplace(constraint->scope); } @@ -1209,22 +1180,24 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNullanyType}; } - // Alter the inner constraints. - LUAU_ASSERT(c.innerConstraints.size() == 2); + TypeId instantiatedTy = arena->addType(BlockedType{}); + TypeId inferredTy = arena->addType(FunctionType{TypeLevel{}, constraint->scope.get(), argsPack, c.result}); + + auto ic = pushConstraint(constraint->scope, constraint->location, InstantiationConstraint{instantiatedTy, fn}); + auto sc = pushConstraint(constraint->scope, constraint->location, SubtypeConstraint{instantiatedTy, inferredTy}); - // Anything that is blocked on this constraint must also be blocked on our inner constraints + // Anything that is blocked on this constraint must also be blocked on our + // synthesized constraints. auto blockedIt = blocked.find(constraint.get()); if (blockedIt != blocked.end()) { - for (const auto& ic : c.innerConstraints) + for (const auto& blockedConstraint : blockedIt->second) { - for (const auto& blockedConstraint : blockedIt->second) - block(ic, blockedConstraint); + block(ic, blockedConstraint); + block(sc, blockedConstraint); } } - unsolvedConstraints.insert(end(unsolvedConstraints), begin(c.innerConstraints), end(c.innerConstraints)); - unblock(c.result); return true; } @@ -1914,12 +1887,14 @@ void ConstraintSolver::unify(TypePackId subPack, TypePackId superPack, NotNull scope, const Location& location, ConstraintV cv) +NotNull ConstraintSolver::pushConstraint(NotNull scope, const Location& location, ConstraintV cv) { std::unique_ptr c = std::make_unique(scope, location, std::move(cv)); NotNull borrow = NotNull(c.get()); solverConstraints.push_back(std::move(c)); unsolvedConstraints.push_back(borrow); + + return borrow; } TypeId ConstraintSolver::resolveModule(const ModuleInfo& info, const Location& location) diff --git a/Analysis/src/DataFlowGraph.cpp b/Analysis/src/DataFlowGraph.cpp index cffd00c91..7e7166037 100644 --- a/Analysis/src/DataFlowGraph.cpp +++ b/Analysis/src/DataFlowGraph.cpp @@ -372,7 +372,7 @@ ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprInde ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprIndexExpr* i) { visitExpr(scope, i->expr); - visitExpr(scope, i->expr); + visitExpr(scope, i->index); if (i->index->as()) { @@ -405,6 +405,13 @@ ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprFunc ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprTable* t) { + for (AstExprTable::Item item : t->items) + { + if (item.key) + visitExpr(scope, item.key); + visitExpr(scope, item.value); + } + return {}; } diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index a70d6dda7..fb61b4ab3 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -104,7 +104,7 @@ LoadDefinitionFileResult Frontend::loadDefinitionFile(std::string_view source, c module.root = parseResult.root; module.mode = Mode::Definition; - ModulePtr checkedModule = check(module, Mode::Definition, globalScope, {}); + ModulePtr checkedModule = check(module, Mode::Definition, {}); if (checkedModule->errors.size() > 0) return LoadDefinitionFileResult{false, parseResult, checkedModule}; @@ -517,7 +517,7 @@ CheckResult Frontend::check(const ModuleName& name, std::optional requireCycles, bool forAutocomplete) -{ +ModulePtr check( + const SourceModule& sourceModule, + const std::vector& requireCycles, + NotNull builtinTypes, + NotNull iceHandler, + NotNull moduleResolver, + NotNull fileResolver, + const ScopePtr& globalScope, + NotNull unifierState, + FrontendOptions options +) { ModulePtr result = std::make_shared(); - result->reduction = std::make_unique(NotNull{&result->internalTypes}, builtinTypes, NotNull{&iceHandler}); + result->reduction = std::make_unique(NotNull{&result->internalTypes}, builtinTypes, iceHandler); std::unique_ptr logger; if (FFlag::DebugLuauLogSolverToJson) @@ -872,20 +880,17 @@ ModulePtr Frontend::check( } } - DataFlowGraph dfg = DataFlowGraphBuilder::build(sourceModule.root, NotNull{&iceHandler}); + DataFlowGraph dfg = DataFlowGraphBuilder::build(sourceModule.root, iceHandler); - const NotNull mr{forAutocomplete ? &moduleResolverForAutocomplete : &moduleResolver}; - const ScopePtr& globalScope{forAutocomplete ? typeCheckerForAutocomplete.globalScope : typeChecker.globalScope}; - - Normalizer normalizer{&result->internalTypes, builtinTypes, NotNull{&typeChecker.unifierState}}; + Normalizer normalizer{&result->internalTypes, builtinTypes, unifierState}; ConstraintGraphBuilder cgb{ sourceModule.name, result, &result->internalTypes, - mr, + moduleResolver, builtinTypes, - NotNull(&iceHandler), + iceHandler, globalScope, logger.get(), NotNull{&dfg}, @@ -894,7 +899,7 @@ ModulePtr Frontend::check( cgb.visit(sourceModule.root); result->errors = std::move(cgb.errors); - ConstraintSolver cs{NotNull{&normalizer}, NotNull(cgb.rootScope), borrowConstraints(cgb.constraints), sourceModule.name, NotNull(&moduleResolver), + ConstraintSolver cs{NotNull{&normalizer}, NotNull(cgb.rootScope), borrowConstraints(cgb.constraints), sourceModule.name, moduleResolver, requireCycles, logger.get()}; if (options.randomizeConstraintResolutionSeed) @@ -908,7 +913,7 @@ ModulePtr Frontend::check( result->scopes = std::move(cgb.scopes); result->type = sourceModule.type; - result->clonePublicInterface(builtinTypes, iceHandler); + result->clonePublicInterface(builtinTypes, *iceHandler); Luau::check(builtinTypes, logger.get(), sourceModule, result.get()); @@ -929,6 +934,22 @@ ModulePtr Frontend::check( return result; } +ModulePtr Frontend::check( + const SourceModule& sourceModule, Mode mode, std::vector requireCycles, bool forAutocomplete) +{ + return Luau::check( + sourceModule, + requireCycles, + builtinTypes, + NotNull{&iceHandler}, + NotNull{forAutocomplete ? &moduleResolverForAutocomplete : &moduleResolver}, + NotNull{fileResolver}, + forAutocomplete ? typeCheckerForAutocomplete.globalScope : typeChecker.globalScope, + NotNull{&typeChecker.unifierState}, + options + ); +} + // Read AST into sourceModules if necessary. Trace require()s. Report parse errors. std::pair Frontend::getSourceNode(const ModuleName& name) { diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index c0f4405c6..b51b7c9a6 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -227,7 +227,10 @@ void Module::clonePublicInterface(NotNull builtinTypes, InternalEr // Copy external stuff over to Module itself this->returnType = moduleScope->returnType; - this->exportedTypeBindings = std::move(moduleScope->exportedTypeBindings); + if (FFlag::DebugLuauDeferredConstraintResolution) + this->exportedTypeBindings = moduleScope->exportedTypeBindings; + else + this->exportedTypeBindings = std::move(moduleScope->exportedTypeBindings); } bool Module::hasModuleScope() const diff --git a/Analysis/src/Type.cpp b/Analysis/src/Type.cpp index f874a0b7b..2f69f6980 100644 --- a/Analysis/src/Type.cpp +++ b/Analysis/src/Type.cpp @@ -24,7 +24,6 @@ LUAU_FASTFLAG(DebugLuauFreezeArena) LUAU_FASTINTVARIABLE(LuauTypeMaximumStringifierLength, 500) LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0) LUAU_FASTINT(LuauTypeInferRecursionLimit) -LUAU_FASTFLAGVARIABLE(LuauMaybeGenericIntersectionTypes, false) LUAU_FASTFLAG(LuauInstantiateInSubtyping) LUAU_FASTFLAGVARIABLE(LuauMatchReturnsOptionalString, false); @@ -358,39 +357,24 @@ bool maybeGeneric(TypeId ty) { LUAU_ASSERT(!FFlag::LuauInstantiateInSubtyping); - if (FFlag::LuauMaybeGenericIntersectionTypes) - { - ty = follow(ty); - - if (get(ty)) - return true; - - if (auto ttv = get(ty)) - { - // TODO: recurse on table types CLI-39914 - (void)ttv; - return true; - } - - if (auto itv = get(ty)) - { - return std::any_of(begin(itv), end(itv), maybeGeneric); - } - - return isGeneric(ty); - } - ty = follow(ty); + if (get(ty)) return true; - else if (auto ttv = get(ty)) + + if (auto ttv = get(ty)) { // TODO: recurse on table types CLI-39914 (void)ttv; return true; } - else - return isGeneric(ty); + + if (auto itv = get(ty)) + { + return std::any_of(begin(itv), end(itv), maybeGeneric); + } + + return isGeneric(ty); } bool maybeSingleton(TypeId ty) diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index 133e324b2..4322a0daa 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -204,12 +204,6 @@ struct TypeChecker2 bestLocation = scopeBounds; } } - else if (scopeBounds.begin > location.end) - { - // TODO: Is this sound? This relies on the fact that scopes are inserted - // into the scope list in the order that they appear in the AST. - break; - } } return bestScope; @@ -676,18 +670,7 @@ struct TypeChecker2 void visit(AstStatTypeAlias* stat) { - for (const AstGenericType& el : stat->generics) - { - if (el.defaultValue) - visit(el.defaultValue); - } - - for (const AstGenericTypePack& el : stat->genericPacks) - { - if (el.defaultValue) - visit(el.defaultValue); - } - + visitGenerics(stat->generics, stat->genericPacks); visit(stat->type); } @@ -701,6 +684,7 @@ struct TypeChecker2 void visit(AstStatDeclareFunction* stat) { + visitGenerics(stat->generics, stat->genericPacks); visit(stat->params); visit(stat->retTypes); } @@ -973,8 +957,9 @@ struct TypeChecker2 void visit(AstExprIndexName* indexName, ValueContext context) { - TypeId leftType = lookupType(indexName->expr); + visit(indexName->expr, RValue); + TypeId leftType = lookupType(indexName->expr); const NormalizedType* norm = normalizer.normalize(leftType); if (!norm) reportError(NormalizationTooComplex{}, indexName->indexLocation); @@ -993,11 +978,18 @@ struct TypeChecker2 { auto StackPusher = pushStack(fn); + visitGenerics(fn->generics, fn->genericPacks); + TypeId inferredFnTy = lookupType(fn); const FunctionType* inferredFtv = get(inferredFnTy); LUAU_ASSERT(inferredFtv); + // There is no way to write an annotation for the self argument, so we + // cannot do anything to check it. auto argIt = begin(inferredFtv->argTypes); + if (fn->self) + ++argIt; + for (const auto& arg : fn->args) { if (argIt == end(inferredFtv->argTypes)) @@ -1037,6 +1029,7 @@ struct TypeChecker2 NotNull scope = stack.back(); TypeId operandType = lookupType(expr->expr); + TypeId resultType = lookupType(expr); if (get(operandType) || get(operandType) || get(operandType)) return; @@ -1048,9 +1041,6 @@ struct TypeChecker2 { if (const FunctionType* ftv = get(follow(*mm))) { - TypePackId expectedArgs = testArena.addTypePack({operandType}); - reportErrors(tryUnify(scope, expr->location, expectedArgs, ftv->argTypes)); - if (std::optional ret = first(ftv->retTypes)) { if (expr->op == AstExprUnary::Op::Len) @@ -1062,6 +1052,25 @@ struct TypeChecker2 { reportError(GenericError{format("Metamethod '%s' must return a value", it->second)}, expr->location); } + + std::optional firstArg = first(ftv->argTypes); + if (!firstArg) + { + reportError(GenericError{"__unm metamethod must accept one argument"}, expr->location); + return; + } + + TypePackId expectedArgs = testArena.addTypePack({operandType}); + TypePackId expectedRet = testArena.addTypePack({resultType}); + + TypeId expectedFunction = testArena.addType(FunctionType{expectedArgs, expectedRet}); + + ErrorVec errors = tryUnify(scope, expr->location, *mm, expectedFunction); + if (!errors.empty()) + { + reportError(TypeMismatch{*firstArg, operandType}, expr->location); + return; + } } return; @@ -1413,6 +1422,33 @@ struct TypeChecker2 ice.ice("flattenPack got a weird pack!"); } + void visitGenerics(AstArray generics, AstArray genericPacks) + { + DenseHashSet seen{AstName{}}; + + for (const auto& g : generics) + { + if (seen.contains(g.name)) + reportError(DuplicateGenericParameter{g.name.value}, g.location); + else + seen.insert(g.name); + + if (g.defaultValue) + visit(g.defaultValue); + } + + for (const auto& g : genericPacks) + { + if (seen.contains(g.name)) + reportError(DuplicateGenericParameter{g.name.value}, g.location); + else + seen.insert(g.name); + + if (g.defaultValue) + visit(g.defaultValue); + } + } + void visit(AstType* ty) { if (auto t = ty->as()) @@ -1579,8 +1615,7 @@ struct TypeChecker2 void visit(AstTypeFunction* ty) { - // TODO! - + visitGenerics(ty->generics, ty->genericPacks); visit(ty->argTypes); visit(ty->returnTypes); } diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index bda062af8..7104f2e73 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -17,7 +17,6 @@ LUAU_FASTINT(LuauTypeInferTypePackLoopLimit) LUAU_FASTFLAG(LuauErrorRecoveryType) -LUAU_FASTFLAGVARIABLE(LuauUnifyAnyTxnLog, false) LUAU_FASTFLAGVARIABLE(LuauInstantiateInSubtyping, false) LUAU_FASTFLAGVARIABLE(LuauScalarShapeUnifyToMtOwner2, false) LUAU_FASTFLAGVARIABLE(LuauUninhabitedSubAnything2, false) @@ -475,40 +474,23 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool return; } - if (FFlag::LuauUnifyAnyTxnLog) - { - if (log.get(superTy)) - return tryUnifyWithAny(subTy, builtinTypes->anyType); + if (log.get(superTy)) + return tryUnifyWithAny(subTy, builtinTypes->anyType); - if (log.get(superTy)) - return tryUnifyWithAny(subTy, builtinTypes->errorType); + if (log.get(superTy)) + return tryUnifyWithAny(subTy, builtinTypes->errorType); - if (log.get(superTy)) - return tryUnifyWithAny(subTy, builtinTypes->unknownType); + if (log.get(superTy)) + return tryUnifyWithAny(subTy, builtinTypes->unknownType); - if (log.get(subTy)) - return tryUnifyWithAny(superTy, builtinTypes->anyType); + if (log.get(subTy)) + return tryUnifyWithAny(superTy, builtinTypes->anyType); - if (log.get(subTy)) - return tryUnifyWithAny(superTy, builtinTypes->errorType); + if (log.get(subTy)) + return tryUnifyWithAny(superTy, builtinTypes->errorType); - if (log.get(subTy)) - return tryUnifyWithAny(superTy, builtinTypes->neverType); - } - else - { - if (get(superTy) || get(superTy) || get(superTy)) - return tryUnifyWithAny(subTy, superTy); - - if (get(subTy)) - return tryUnifyWithAny(superTy, subTy); - - if (log.get(subTy)) - return tryUnifyWithAny(superTy, subTy); - - if (log.get(subTy)) - return tryUnifyWithAny(superTy, subTy); - } + if (log.get(subTy)) + return tryUnifyWithAny(superTy, builtinTypes->neverType); auto& cache = sharedState.cachedUnify; @@ -2535,18 +2517,9 @@ void Unifier::tryUnifyWithAny(TypeId subTy, TypeId anyTy) { LUAU_ASSERT(get(anyTy) || get(anyTy) || get(anyTy) || get(anyTy)); - if (FFlag::LuauUnifyAnyTxnLog) - { - // These types are not visited in general loop below - if (log.get(subTy) || log.get(subTy) || log.get(subTy)) - return; - } - else - { - // These types are not visited in general loop below - if (get(subTy) || get(subTy) || get(subTy)) - return; - } + // These types are not visited in general loop below + if (log.get(subTy) || log.get(subTy) || log.get(subTy)) + return; TypePackId anyTp = types->addTypePack(TypePackVar{VariadicTypePack{anyTy}}); diff --git a/CLI/Repl.cpp b/CLI/Repl.cpp index 69a40356b..63baea8b7 100644 --- a/CLI/Repl.cpp +++ b/CLI/Repl.cpp @@ -319,6 +319,9 @@ std::string runCode(lua_State* L, const std::string& source) lua_insert(T, 1); lua_pcall(T, n, 0, 0); } + + lua_pop(L, 1); + return std::string(); } else { @@ -336,11 +339,9 @@ std::string runCode(lua_State* L, const std::string& source) error += "\nstack backtrace:\n"; error += lua_debugtrace(T); - fprintf(stdout, "%s", error.c_str()); + lua_pop(L, 1); + return error; } - - lua_pop(L, 1); - return std::string(); } // Replaces the top of the lua stack with the metatable __index for the value diff --git a/CodeGen/include/Luau/IrData.h b/CodeGen/include/Luau/IrData.h index 28f5b29bd..6a7094684 100644 --- a/CodeGen/include/Luau/IrData.h +++ b/CodeGen/include/Luau/IrData.h @@ -55,7 +55,7 @@ enum class IrCmd : uint8_t // Get pointer (TValue) to table array at index // A: pointer (Table) - // B: unsigned int + // B: int GET_ARR_ADDR, // Get pointer (LuaNode) to table node element at the active cached slot index @@ -177,7 +177,7 @@ enum class IrCmd : uint8_t // A: pointer (Table) DUP_TABLE, - // Try to convert a double number into a table index or jump if it's not an integer + // Try to convert a double number into a table index (int) or jump if it's not an integer // A: double // B: block NUM_TO_INDEX, @@ -216,10 +216,10 @@ enum class IrCmd : uint8_t // B: unsigned int (import path) GET_IMPORT, - // Concatenate multiple TValues - // A: Rn (where to store the result) - // B: unsigned int (index of the first VM stack slot) - // C: unsigned int (number of stack slots to go over) + // Concatenate multiple TValues into a string + // A: Rn (value start) + // B: unsigned int (number of registers to go over) + // Note: result is stored in the register specified in 'A' CONCAT, // Load function upvalue into stack slot @@ -262,7 +262,8 @@ enum class IrCmd : uint8_t // Guard against index overflowing the table array size // A: pointer (Table) - // B: block + // B: int (index) + // C: block CHECK_ARRAY_SIZE, // Guard against cached table node slot not matching the actual table node slot for a key @@ -451,8 +452,12 @@ enum class IrCmd : uint8_t // Prepare loop variables for a generic for loop, jump to the loop backedge unconditionally // A: unsigned int (bytecode instruction index) // B: Rn (loop state, updates Rn Rn+1 Rn+2) - // B: block + // C: block FALLBACK_FORGPREP, + + // Instruction that passes value through, it is produced by constant folding and users substitute it with the value + SUBSTITUTE, + // A: operand of any type }; enum class IrConstKind : uint8_t @@ -659,6 +664,12 @@ struct IrFunction LUAU_ASSERT(value.kind == IrConstKind::Double); return value.valueDouble; } + + IrCondition conditionOp(IrOp op) + { + LUAU_ASSERT(op.kind == IrOpKind::Condition); + return IrCondition(op.index); + } }; } // namespace CodeGen diff --git a/CodeGen/include/Luau/IrUtils.h b/CodeGen/include/Luau/IrUtils.h index 1aef9a3fc..3e95813bb 100644 --- a/CodeGen/include/Luau/IrUtils.h +++ b/CodeGen/include/Luau/IrUtils.h @@ -10,6 +10,8 @@ namespace Luau namespace CodeGen { +struct IrBuilder; + inline bool isJumpD(LuauOpcode op) { switch (op) @@ -138,6 +140,7 @@ inline bool hasResult(IrCmd cmd) case IrCmd::DUP_TABLE: case IrCmd::NUM_TO_INDEX: case IrCmd::INT_TO_NUM: + case IrCmd::SUBSTITUTE: return true; default: break; @@ -153,6 +156,12 @@ inline bool hasSideEffects(IrCmd cmd) return !hasResult(cmd); } +inline bool isPseudo(IrCmd cmd) +{ + // Instructions that are used for internal needs and are not a part of final lowering + return cmd == IrCmd::NOP || cmd == IrCmd::SUBSTITUTE; +} + // Remove a single instruction void kill(IrFunction& function, IrInst& inst); @@ -172,5 +181,17 @@ void replace(IrFunction& function, IrOp& original, IrOp replacement); // Target instruction index instead of reference is used to handle introduction of a new block terminator void replace(IrFunction& function, uint32_t instIdx, IrInst replacement); +// Replace instruction with a different value (using IrCmd::SUBSTITUTE) +void substitute(IrFunction& function, IrInst& inst, IrOp replacement); + +// Replace instruction arguments that point to substitutions with target values +void applySubstitutions(IrFunction& function, IrOp& op); +void applySubstitutions(IrFunction& function, IrInst& inst); + +// Perform constant folding on instruction at index +// For most instructions, successful folding results in a IrCmd::SUBSTITUTE +// But it can also be successful on conditional control-flow, replacing it with an unconditional IrCmd::JUMP +void foldConstants(IrBuilder& build, IrFunction& function, uint32_t instIdx); + } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/IrDump.cpp b/CodeGen/src/IrDump.cpp index 5a23861e4..918a8829a 100644 --- a/CodeGen/src/IrDump.cpp +++ b/CodeGen/src/IrDump.cpp @@ -246,6 +246,8 @@ const char* getCmdName(IrCmd cmd) return "FALLBACK_DUPCLOSURE"; case IrCmd::FALLBACK_FORGPREP: return "FALLBACK_FORGPREP"; + case IrCmd::SUBSTITUTE: + return "SUBSTITUTE"; } LUAU_UNREACHABLE(); @@ -423,8 +425,8 @@ std::string toString(IrFunction& function, bool includeDetails) { IrInst& inst = function.instructions[index]; - // Nop is used to replace dead instructions in-place, so it's not that useful to see them - if (inst.cmd == IrCmd::NOP) + // Skip pseudo instructions unless they are still referenced + if (isPseudo(inst.cmd) && inst.useCount == 0) continue; append(ctx.result, " "); diff --git a/CodeGen/src/IrLoweringX64.cpp b/CodeGen/src/IrLoweringX64.cpp index 03bb18146..3c816554a 100644 --- a/CodeGen/src/IrLoweringX64.cpp +++ b/CodeGen/src/IrLoweringX64.cpp @@ -20,7 +20,7 @@ namespace Luau namespace CodeGen { -static RegisterX64 gprAlocOrder[] = {rax, rdx, rcx, rbx, rsi, rdi, r8, r9, r10, r11}; +static const RegisterX64 kGprAllocOrder[] = {rax, rdx, rcx, rbx, rsi, rdi, r8, r9, r10, r11}; IrLoweringX64::IrLoweringX64(AssemblyBuilderX64& build, ModuleHelpers& helpers, NativeState& data, Proto* proto, IrFunction& function) : build(build) @@ -111,7 +111,7 @@ void IrLoweringX64::lower(AssemblyOptions options) if (options.includeIr) { build.logAppend("# "); - toStringDetailed(ctx, block, uint32_t(i)); + toStringDetailed(ctx, block, blockIndex); } build.setLabel(block.label); @@ -133,9 +133,9 @@ void IrLoweringX64::lower(AssemblyOptions options) IrInst& inst = function.instructions[index]; - // Nop is used to replace dead instructions in-place - // Because it doesn't have any effects aside from output (when enabled), we skip it completely - if (inst.cmd == IrCmd::NOP) + // Skip pseudo instructions, but make sure they are not used at this stage + // This also prevents them from getting into text output when that's enabled + if (isPseudo(inst.cmd)) { LUAU_ASSERT(inst.useCount == 0); continue; @@ -263,8 +263,8 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.mov(inst.regX64, qword[regOp(inst.a) + offsetof(Table, array)]); - if (uintOp(inst.b) != 0) - build.lea(inst.regX64, addr[inst.regX64 + uintOp(inst.b) * sizeof(TValue)]); + if (intOp(inst.b) != 0) + build.lea(inst.regX64, addr[inst.regX64 + intOp(inst.b) * sizeof(TValue)]); } else { @@ -688,9 +688,11 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) emitInstGetImportFallback(build, inst.a.index, uintOp(inst.b)); break; case IrCmd::CONCAT: + LUAU_ASSERT(inst.a.kind == IrOpKind::VmReg); + build.mov(rArg1, rState); - build.mov(dwordReg(rArg2), uintOp(inst.a)); - build.mov(dwordReg(rArg3), uintOp(inst.b)); + build.mov(dwordReg(rArg2), uintOp(inst.b)); + build.mov(dwordReg(rArg3), inst.a.index + uintOp(inst.b) - 1); build.call(qword[rNativeContext + offsetof(NativeContext, luaV_concat)]); emitUpdateBase(build); @@ -778,7 +780,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) if (inst.b.kind == IrOpKind::Inst) build.cmp(dword[regOp(inst.a) + offsetof(Table, sizearray)], regOp(inst.b)); else if (inst.b.kind == IrOpKind::Constant) - build.cmp(dword[regOp(inst.a) + offsetof(Table, sizearray)], uintOp(inst.b)); + build.cmp(dword[regOp(inst.a) + offsetof(Table, sizearray)], intOp(inst.b)); else LUAU_ASSERT(!"Unsupported instruction form"); @@ -897,6 +899,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) LUAU_ASSERT(inst.c.kind == IrOpKind::VmReg); emitInstNameCall(build, pc, uintOp(inst.a), proto->k, blockOp(inst.d).label, blockOp(inst.e).label); + jumpOrFallthrough(blockOp(inst.d), next); break; } case IrCmd::LOP_CALL: @@ -1133,7 +1136,7 @@ RegisterX64 IrLoweringX64::allocGprReg(SizeX64 preferredSize) LUAU_ASSERT( preferredSize == SizeX64::byte || preferredSize == SizeX64::word || preferredSize == SizeX64::dword || preferredSize == SizeX64::qword); - for (RegisterX64 reg : gprAlocOrder) + for (RegisterX64 reg : kGprAllocOrder) { if (freeGprMap[reg.index]) { diff --git a/CodeGen/src/IrTranslation.cpp b/CodeGen/src/IrTranslation.cpp index fdbdf6670..0885c0562 100644 --- a/CodeGen/src/IrTranslation.cpp +++ b/CodeGen/src/IrTranslation.cpp @@ -695,10 +695,10 @@ void translateInstGetTableN(IrBuilder& build, const Instruction* pc, int pcpos) IrOp vb = build.inst(IrCmd::LOAD_POINTER, build.vmReg(rb)); - build.inst(IrCmd::CHECK_ARRAY_SIZE, vb, build.constUint(c), fallback); + build.inst(IrCmd::CHECK_ARRAY_SIZE, vb, build.constInt(c), fallback); build.inst(IrCmd::CHECK_NO_METATABLE, vb, fallback); - IrOp arrEl = build.inst(IrCmd::GET_ARR_ADDR, vb, build.constUint(c)); + IrOp arrEl = build.inst(IrCmd::GET_ARR_ADDR, vb, build.constInt(c)); // TODO: per-component loads and stores might be preferable IrOp arrElTval = build.inst(IrCmd::LOAD_TVALUE, arrEl); @@ -725,11 +725,11 @@ void translateInstSetTableN(IrBuilder& build, const Instruction* pc, int pcpos) IrOp vb = build.inst(IrCmd::LOAD_POINTER, build.vmReg(rb)); - build.inst(IrCmd::CHECK_ARRAY_SIZE, vb, build.constUint(c), fallback); + build.inst(IrCmd::CHECK_ARRAY_SIZE, vb, build.constInt(c), fallback); build.inst(IrCmd::CHECK_NO_METATABLE, vb, fallback); build.inst(IrCmd::CHECK_READONLY, vb, fallback); - IrOp arrEl = build.inst(IrCmd::GET_ARR_ADDR, vb, build.constUint(c)); + IrOp arrEl = build.inst(IrCmd::GET_ARR_ADDR, vb, build.constInt(c)); // TODO: per-component loads and stores might be preferable IrOp tva = build.inst(IrCmd::LOAD_TVALUE, build.vmReg(ra)); @@ -969,7 +969,7 @@ void translateInstConcat(IrBuilder& build, const Instruction* pc, int pcpos) int rc = LUAU_INSN_C(*pc); build.inst(IrCmd::SET_SAVEDPC, build.constUint(pcpos + 1)); - build.inst(IrCmd::CONCAT, build.constUint(rc - rb + 1), build.constUint(rc)); + build.inst(IrCmd::CONCAT, build.vmReg(rb), build.constUint(rc - rb + 1)); // TODO: per-component loads and stores might be preferable IrOp tvb = build.inst(IrCmd::LOAD_TVALUE, build.vmReg(rb)); diff --git a/CodeGen/src/IrUtils.cpp b/CodeGen/src/IrUtils.cpp index 0c1a89668..2ff1c0d14 100644 --- a/CodeGen/src/IrUtils.cpp +++ b/CodeGen/src/IrUtils.cpp @@ -1,6 +1,14 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/IrUtils.h" +#include "Luau/IrBuilder.h" + +#include "lua.h" +#include "lnumutils.h" + +#include +#include + namespace Luau { namespace CodeGen @@ -8,16 +16,19 @@ namespace CodeGen static uint32_t getBlockEnd(IrFunction& function, uint32_t start) { + LUAU_ASSERT(start < function.instructions.size()); + uint32_t end = start; // Find previous block terminator while (!isBlockTerminator(function.instructions[end].cmd)) end++; + LUAU_ASSERT(end < function.instructions.size()); return end; } -static void addUse(IrFunction& function, IrOp op) +void addUse(IrFunction& function, IrOp op) { if (op.kind == IrOpKind::Inst) function.instructions[op.index].useCount++; @@ -25,7 +36,7 @@ static void addUse(IrFunction& function, IrOp op) function.blocks[op.index].useCount++; } -static void removeUse(IrFunction& function, IrOp op) +void removeUse(IrFunction& function, IrOp op) { if (op.kind == IrOpKind::Inst) removeUse(function, function.instructions[op.index]); @@ -44,6 +55,12 @@ void kill(IrFunction& function, IrInst& inst) removeUse(function, inst.c); removeUse(function, inst.d); removeUse(function, inst.e); + + inst.a = {}; + inst.b = {}; + inst.c = {}; + inst.d = {}; + inst.e = {}; } void kill(IrFunction& function, uint32_t start, uint32_t end) @@ -51,6 +68,7 @@ void kill(IrFunction& function, uint32_t start, uint32_t end) // Kill instructions in reverse order to avoid killing instructions that are still marked as used for (int i = int(end); i >= int(start); i--) { + LUAU_ASSERT(unsigned(i) < function.instructions.size()); IrInst& curr = function.instructions[i]; if (curr.cmd == IrCmd::NOP) @@ -102,7 +120,6 @@ void replace(IrFunction& function, IrOp& original, IrOp replacement) void replace(IrFunction& function, uint32_t instIdx, IrInst replacement) { IrInst& inst = function.instructions[instIdx]; - IrCmd prevCmd = inst.cmd; // Add uses before removing new ones if those are the last ones keeping target operand alive addUse(function, replacement.a); @@ -111,6 +128,20 @@ void replace(IrFunction& function, uint32_t instIdx, IrInst replacement) addUse(function, replacement.d); addUse(function, replacement.e); + // If we introduced an earlier terminating instruction, all following instructions become dead + if (!isBlockTerminator(inst.cmd) && isBlockTerminator(replacement.cmd)) + { + uint32_t start = instIdx + 1; + + // If we are in the process of constructing a block, replacement might happen at the last instruction + if (start < function.instructions.size()) + { + uint32_t end = getBlockEnd(function, start); + + kill(function, start, end); + } + } + removeUse(function, inst.a); removeUse(function, inst.b); removeUse(function, inst.c); @@ -118,14 +149,227 @@ void replace(IrFunction& function, uint32_t instIdx, IrInst replacement) removeUse(function, inst.e); inst = replacement; +} - // If we introduced an earlier terminating instruction, all following instructions become dead - if (!isBlockTerminator(prevCmd) && isBlockTerminator(inst.cmd)) +void substitute(IrFunction& function, IrInst& inst, IrOp replacement) +{ + LUAU_ASSERT(!isBlockTerminator(inst.cmd)); + + inst.cmd = IrCmd::SUBSTITUTE; + + removeUse(function, inst.a); + removeUse(function, inst.b); + removeUse(function, inst.c); + removeUse(function, inst.d); + removeUse(function, inst.e); + + inst.a = replacement; + inst.b = {}; + inst.c = {}; + inst.d = {}; + inst.e = {}; +} + +void applySubstitutions(IrFunction& function, IrOp& op) +{ + if (op.kind == IrOpKind::Inst) { - uint32_t start = instIdx + 1; - uint32_t end = getBlockEnd(function, start); + IrInst& src = function.instructions[op.index]; + + if (src.cmd == IrCmd::SUBSTITUTE) + { + op.kind = src.a.kind; + op.index = src.a.index; + + // If we substitute with the result of a different instruction, update the use count + if (op.kind == IrOpKind::Inst) + { + IrInst& dst = function.instructions[op.index]; + LUAU_ASSERT(dst.cmd != IrCmd::SUBSTITUTE && "chained substitutions are not allowed"); + + dst.useCount++; + } + + LUAU_ASSERT(src.useCount > 0); + src.useCount--; + } + } +} + +void applySubstitutions(IrFunction& function, IrInst& inst) +{ + applySubstitutions(function, inst.a); + applySubstitutions(function, inst.b); + applySubstitutions(function, inst.c); + applySubstitutions(function, inst.d); + applySubstitutions(function, inst.e); +} + +static bool compare(double a, double b, IrCondition cond) +{ + switch (cond) + { + case IrCondition::Equal: + return a == b; + case IrCondition::NotEqual: + return a != b; + case IrCondition::Less: + return a < b; + case IrCondition::NotLess: + return !(a < b); + case IrCondition::LessEqual: + return a <= b; + case IrCondition::NotLessEqual: + return !(a <= b); + case IrCondition::Greater: + return a > b; + case IrCondition::NotGreater: + return !(a > b); + case IrCondition::GreaterEqual: + return a >= b; + case IrCondition::NotGreaterEqual: + return !(a >= b); + default: + LUAU_ASSERT(!"unsupported conidtion"); + } + + return false; +} + +void foldConstants(IrBuilder& build, IrFunction& function, uint32_t index) +{ + IrInst& inst = function.instructions[index]; + + switch (inst.cmd) + { + case IrCmd::ADD_INT: + if (inst.a.kind == IrOpKind::Constant && inst.b.kind == IrOpKind::Constant) + { + // We need to avoid signed integer overflow, but we also have to produce a result + // So we add numbers as unsigned and use fixed-width integer types to force a two's complement evaluation + int32_t lhs = function.intOp(inst.a); + int32_t rhs = function.intOp(inst.b); + int sum = int32_t(uint32_t(lhs) + uint32_t(rhs)); + + substitute(function, inst, build.constInt(sum)); + } + break; + case IrCmd::SUB_INT: + if (inst.a.kind == IrOpKind::Constant && inst.b.kind == IrOpKind::Constant) + { + // We need to avoid signed integer overflow, but we also have to produce a result + // So we subtract numbers as unsigned and use fixed-width integer types to force a two's complement evaluation + int32_t lhs = function.intOp(inst.a); + int32_t rhs = function.intOp(inst.b); + int sum = int32_t(uint32_t(lhs) - uint32_t(rhs)); + + substitute(function, inst, build.constInt(sum)); + } + break; + case IrCmd::ADD_NUM: + if (inst.a.kind == IrOpKind::Constant && inst.b.kind == IrOpKind::Constant) + substitute(function, inst, build.constDouble(function.doubleOp(inst.a) + function.doubleOp(inst.b))); + break; + case IrCmd::SUB_NUM: + if (inst.a.kind == IrOpKind::Constant && inst.b.kind == IrOpKind::Constant) + substitute(function, inst, build.constDouble(function.doubleOp(inst.a) - function.doubleOp(inst.b))); + break; + case IrCmd::MUL_NUM: + if (inst.a.kind == IrOpKind::Constant && inst.b.kind == IrOpKind::Constant) + substitute(function, inst, build.constDouble(function.doubleOp(inst.a) * function.doubleOp(inst.b))); + break; + case IrCmd::DIV_NUM: + if (inst.a.kind == IrOpKind::Constant && inst.b.kind == IrOpKind::Constant) + substitute(function, inst, build.constDouble(function.doubleOp(inst.a) / function.doubleOp(inst.b))); + break; + case IrCmd::MOD_NUM: + if (inst.a.kind == IrOpKind::Constant && inst.b.kind == IrOpKind::Constant) + substitute(function, inst, build.constDouble(luai_nummod(function.doubleOp(inst.a), function.doubleOp(inst.b)))); + break; + case IrCmd::POW_NUM: + if (inst.a.kind == IrOpKind::Constant && inst.b.kind == IrOpKind::Constant) + substitute(function, inst, build.constDouble(pow(function.doubleOp(inst.a), function.doubleOp(inst.b)))); + break; + case IrCmd::UNM_NUM: + if (inst.a.kind == IrOpKind::Constant) + substitute(function, inst, build.constDouble(-function.doubleOp(inst.a))); + break; + case IrCmd::NOT_ANY: + if (inst.a.kind == IrOpKind::Constant) + { + uint8_t a = function.tagOp(inst.a); + + if (a == LUA_TNIL) + substitute(function, inst, build.constInt(1)); + else if (a != LUA_TBOOLEAN) + substitute(function, inst, build.constInt(0)); + else if (inst.b.kind == IrOpKind::Constant) + substitute(function, inst, build.constInt(function.intOp(inst.b) == 1 ? 0 : 1)); + } + break; + case IrCmd::JUMP_EQ_TAG: + if (inst.a.kind == IrOpKind::Constant && inst.b.kind == IrOpKind::Constant) + { + if (function.tagOp(inst.a) == function.tagOp(inst.b)) + replace(function, index, {IrCmd::JUMP, inst.c}); + else + replace(function, index, {IrCmd::JUMP, inst.d}); + } + break; + case IrCmd::JUMP_EQ_INT: + if (inst.a.kind == IrOpKind::Constant && inst.b.kind == IrOpKind::Constant) + { + if (function.intOp(inst.a) == function.intOp(inst.b)) + replace(function, index, {IrCmd::JUMP, inst.c}); + else + replace(function, index, {IrCmd::JUMP, inst.d}); + } + break; + case IrCmd::JUMP_CMP_NUM: + if (inst.a.kind == IrOpKind::Constant && inst.b.kind == IrOpKind::Constant) + { + if (compare(function.doubleOp(inst.a), function.doubleOp(inst.b), function.conditionOp(inst.c))) + replace(function, index, {IrCmd::JUMP, inst.d}); + else + replace(function, index, {IrCmd::JUMP, inst.e}); + } + break; + case IrCmd::NUM_TO_INDEX: + if (inst.a.kind == IrOpKind::Constant) + { + double value = function.doubleOp(inst.a); + + // To avoid undefined behavior of casting a value not representable in the target type, we check the range + if (value >= INT_MIN && value <= INT_MAX) + { + int arrIndex = int(value); - kill(function, start, end); + if (double(arrIndex) == value) + substitute(function, inst, build.constInt(arrIndex)); + else + replace(function, index, {IrCmd::JUMP, inst.b}); + } + else + { + replace(function, index, {IrCmd::JUMP, inst.b}); + } + } + break; + case IrCmd::INT_TO_NUM: + if (inst.a.kind == IrOpKind::Constant) + substitute(function, inst, build.constDouble(double(function.intOp(inst.a)))); + break; + case IrCmd::CHECK_TAG: + if (inst.a.kind == IrOpKind::Constant && inst.b.kind == IrOpKind::Constant) + { + if (function.tagOp(inst.a) == function.tagOp(inst.b)) + kill(function, inst); + else + replace(function, index, {IrCmd::JUMP, inst.c}); // Shows a conflict in assumptions on this path + } + break; + default: + break; } } diff --git a/Common/include/Luau/ExperimentalFlags.h b/Common/include/Luau/ExperimentalFlags.h index a14cc1e65..afd364018 100644 --- a/Common/include/Luau/ExperimentalFlags.h +++ b/Common/include/Luau/ExperimentalFlags.h @@ -10,7 +10,7 @@ inline bool isFlagExperimental(const char* flag) { // Flags in this list are disabled by default in various command-line tools. They may have behavior that is not fully final, // or critical bugs that are found after the code has been submitted. - static const char* kList[] = { + static const char* const kList[] = { "LuauInstantiateInSubtyping", // requires some fixes to lua-apps code "LuauTryhardAnd", // waiting for a fix in graphql-lua -> apollo-client-lia -> lua-apps "LuauTypecheckTypeguards", // requires some fixes to lua-apps code (CLI-67030) diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index 3fd0e9cd9..d238e9eca 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -3442,8 +3442,6 @@ TEST_CASE_FIXTURE(ACFixture, "type_reduction_is_hooked_up_to_autocomplete") TEST_CASE_FIXTURE(ACFixture, "string_contents_is_available_to_callback") { - ScopedFastFlag luauAutocompleteStringContent{"LuauAutocompleteStringContent", true}; - loadDefinition(R"( declare function require(path: string): any )"); diff --git a/tests/ClassFixture.h b/tests/ClassFixture.h index 66aec7646..c46697a26 100644 --- a/tests/ClassFixture.h +++ b/tests/ClassFixture.h @@ -1,4 +1,5 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once #include "Fixture.h" diff --git a/tests/Fixture.cpp b/tests/Fixture.cpp index 9d21973dc..f245ca933 100644 --- a/tests/Fixture.cpp +++ b/tests/Fixture.cpp @@ -176,7 +176,22 @@ AstStatBlock* Fixture::parse(const std::string& source, const ParseOptions& pars { frontend.lint(*sourceModule); - typeChecker.check(*sourceModule, sourceModule->mode.value_or(Luau::Mode::Nonstrict)); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + Luau::check( + *sourceModule, + {}, + frontend.builtinTypes, + NotNull{&ice}, + NotNull{&moduleResolver}, + NotNull{&fileResolver}, + typeChecker.globalScope, + NotNull{&typeChecker.unifierState}, + frontend.options + ); + } + else + typeChecker.check(*sourceModule, sourceModule->mode.value_or(Luau::Mode::Nonstrict)); } throw ParseErrors(result.errors); diff --git a/tests/IrBuilder.test.cpp b/tests/IrBuilder.test.cpp index 4ed872862..4bb638e83 100644 --- a/tests/IrBuilder.test.cpp +++ b/tests/IrBuilder.test.cpp @@ -2,16 +2,83 @@ #include "Luau/IrBuilder.h" #include "Luau/IrAnalysis.h" #include "Luau/IrDump.h" +#include "Luau/IrUtils.h" #include "Luau/OptimizeFinalX64.h" #include "doctest.h" +#include + using namespace Luau::CodeGen; class IrBuilderFixture { public: + void constantFold() + { + for (size_t i = 0; i < build.function.instructions.size(); i++) + { + IrInst& inst = build.function.instructions[i]; + + applySubstitutions(build.function, inst); + foldConstants(build, build.function, uint32_t(i)); + } + } + + template + void withOneBlock(F&& f) + { + IrOp main = build.block(IrBlockKind::Internal); + IrOp a = build.block(IrBlockKind::Internal); + + build.beginBlock(main); + f(a); + + build.beginBlock(a); + build.inst(IrCmd::LOP_RETURN, build.constUint(1)); + }; + + template + void withTwoBlocks(F&& f) + { + IrOp main = build.block(IrBlockKind::Internal); + IrOp a = build.block(IrBlockKind::Internal); + IrOp b = build.block(IrBlockKind::Internal); + + build.beginBlock(main); + f(a, b); + + build.beginBlock(a); + build.inst(IrCmd::LOP_RETURN, build.constUint(1)); + + build.beginBlock(b); + build.inst(IrCmd::LOP_RETURN, build.constUint(2)); + }; + + void checkEq(IrOp lhs, IrOp rhs) + { + CHECK_EQ(lhs.kind, rhs.kind); + LUAU_ASSERT(lhs.kind != IrOpKind::Constant && "can't compare constants, each ref is unique"); + CHECK_EQ(lhs.index, rhs.index); + } + + void checkEq(IrOp instOp, const IrInst& inst) + { + const IrInst& target = build.function.instOp(instOp); + CHECK(target.cmd == inst.cmd); + checkEq(target.a, inst.a); + checkEq(target.b, inst.b); + checkEq(target.c, inst.c); + checkEq(target.d, inst.d); + checkEq(target.e, inst.e); + } + IrBuilder build; + + // Luau.VM headers are not accessible + static const int tnil = 0; + static const int tboolean = 1; + static const int tnumber = 3; }; TEST_SUITE_BEGIN("Optimization"); @@ -153,7 +220,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "FinalX64OptEqTag3") build.beginBlock(block); IrOp table = build.inst(IrCmd::LOAD_POINTER, build.vmReg(1)); - IrOp arrElem = build.inst(IrCmd::GET_ARR_ADDR, table, build.constUint(0)); + IrOp arrElem = build.inst(IrCmd::GET_ARR_ADDR, table, build.constInt(0)); IrOp opA = build.inst(IrCmd::LOAD_TAG, arrElem); build.inst(IrCmd::JUMP_EQ_TAG, opA, build.constTag(0), trueBlock, falseBlock); build.inst(IrCmd::LOP_RETURN, build.constUint(0)); @@ -171,7 +238,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "FinalX64OptEqTag3") CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( bb_0: %0 = LOAD_POINTER R1 - %1 = GET_ARR_ADDR %0, 0u + %1 = GET_ARR_ADDR %0, 0i %2 = LOAD_TAG %1 JUMP_EQ_TAG %2, tnil, bb_1, bb_2 @@ -221,3 +288,247 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "FinalX64OptJumpCmpNum") } TEST_SUITE_END(); + +TEST_SUITE_BEGIN("ConstantFolding"); + +TEST_CASE_FIXTURE(IrBuilderFixture, "Numeric") +{ + IrOp block = build.block(IrBlockKind::Internal); + + build.beginBlock(block); + build.inst(IrCmd::STORE_INT, build.vmReg(0), build.inst(IrCmd::ADD_INT, build.constInt(10), build.constInt(20))); + build.inst(IrCmd::STORE_INT, build.vmReg(0), build.inst(IrCmd::ADD_INT, build.constInt(INT_MAX), build.constInt(1))); + + build.inst(IrCmd::STORE_INT, build.vmReg(0), build.inst(IrCmd::SUB_INT, build.constInt(10), build.constInt(20))); + build.inst(IrCmd::STORE_INT, build.vmReg(0), build.inst(IrCmd::SUB_INT, build.constInt(INT_MIN), build.constInt(1))); + + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.inst(IrCmd::ADD_NUM, build.constDouble(2), build.constDouble(5))); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.inst(IrCmd::SUB_NUM, build.constDouble(2), build.constDouble(5))); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.inst(IrCmd::MUL_NUM, build.constDouble(2), build.constDouble(5))); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.inst(IrCmd::DIV_NUM, build.constDouble(2), build.constDouble(5))); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.inst(IrCmd::MOD_NUM, build.constDouble(5), build.constDouble(2))); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.inst(IrCmd::POW_NUM, build.constDouble(5), build.constDouble(2))); + + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.inst(IrCmd::UNM_NUM, build.constDouble(5))); + + build.inst(IrCmd::STORE_INT, build.vmReg(0), build.inst(IrCmd::NOT_ANY, build.constTag(tnil), build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(1)))); + build.inst(IrCmd::STORE_INT, build.vmReg(0), build.inst(IrCmd::NOT_ANY, build.constTag(tnumber), build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(1)))); + build.inst(IrCmd::STORE_INT, build.vmReg(0), build.inst(IrCmd::NOT_ANY, build.constTag(tboolean), build.constInt(0))); + build.inst(IrCmd::STORE_INT, build.vmReg(0), build.inst(IrCmd::NOT_ANY, build.constTag(tboolean), build.constInt(1))); + + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.inst(IrCmd::INT_TO_NUM, build.constInt(8))); + + build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + + updateUseCounts(build.function); + constantFold(); + + CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( +bb_0: + STORE_INT R0, 30i + STORE_INT R0, -2147483648i + STORE_INT R0, -10i + STORE_INT R0, 2147483647i + STORE_DOUBLE R0, 7 + STORE_DOUBLE R0, -3 + STORE_DOUBLE R0, 10 + STORE_DOUBLE R0, 0.40000000000000002 + STORE_DOUBLE R0, 1 + STORE_DOUBLE R0, 25 + STORE_DOUBLE R0, -5 + STORE_INT R0, 1i + STORE_INT R0, 0i + STORE_INT R0, 1i + STORE_INT R0, 0i + STORE_DOUBLE R0, 8 + LOP_RETURN 0u + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "ControlFlowEq") +{ + withTwoBlocks([this](IrOp a, IrOp b) { + build.inst(IrCmd::JUMP_EQ_TAG, build.constTag(tnil), build.constTag(tnil), a, b); + }); + + withTwoBlocks([this](IrOp a, IrOp b) { + build.inst(IrCmd::JUMP_EQ_TAG, build.constTag(tnil), build.constTag(tnumber), a, b); + }); + + withTwoBlocks([this](IrOp a, IrOp b) { + build.inst(IrCmd::JUMP_EQ_INT, build.constInt(0), build.constInt(0), a, b); + }); + + withTwoBlocks([this](IrOp a, IrOp b) { + build.inst(IrCmd::JUMP_EQ_INT, build.constInt(0), build.constInt(1), a, b); + }); + + updateUseCounts(build.function); + constantFold(); + + CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( +bb_0: + JUMP bb_1 + +bb_1: + LOP_RETURN 1u + +bb_3: + JUMP bb_5 + +bb_5: + LOP_RETURN 2u + +bb_6: + JUMP bb_7 + +bb_7: + LOP_RETURN 1u + +bb_9: + JUMP bb_11 + +bb_11: + LOP_RETURN 2u + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "NumToIndex") +{ + withOneBlock([this](IrOp a) { + build.inst(IrCmd::STORE_INT, build.vmReg(0), build.inst(IrCmd::NUM_TO_INDEX, build.constDouble(4), a)); + build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + }); + + withOneBlock([this](IrOp a) { + build.inst(IrCmd::STORE_INT, build.vmReg(0), build.inst(IrCmd::NUM_TO_INDEX, build.constDouble(1.2), a)); + build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + }); + + withOneBlock([this](IrOp a) { + IrOp nan = build.inst(IrCmd::DIV_NUM, build.constDouble(0.0), build.constDouble(0.0)); + build.inst(IrCmd::STORE_INT, build.vmReg(0), build.inst(IrCmd::NUM_TO_INDEX, nan, a)); + build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + }); + + updateUseCounts(build.function); + constantFold(); + + CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( +bb_0: + STORE_INT R0, 4i + LOP_RETURN 0u + +bb_2: + JUMP bb_3 + +bb_3: + LOP_RETURN 1u + +bb_4: + JUMP bb_5 + +bb_5: + LOP_RETURN 1u + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "Guards") +{ + withOneBlock([this](IrOp a) { + build.inst(IrCmd::CHECK_TAG, build.constTag(tnumber), build.constTag(tnumber), a); + build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + }); + + withOneBlock([this](IrOp a) { + build.inst(IrCmd::CHECK_TAG, build.constTag(tnil), build.constTag(tnumber), a); + build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + }); + + updateUseCounts(build.function); + constantFold(); + + CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( +bb_0: + LOP_RETURN 0u + +bb_2: + JUMP bb_3 + +bb_3: + LOP_RETURN 1u + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "ControlFlowCmpNum") +{ + IrOp nan = build.inst(IrCmd::DIV_NUM, build.constDouble(0.0), build.constDouble(0.0)); + + auto compareFold = [this](IrOp lhs, IrOp rhs, IrCondition cond, bool result) { + IrOp instOp; + IrInst instExpected; + + withTwoBlocks([&](IrOp a, IrOp b) { + instOp = build.inst(IrCmd::JUMP_CMP_NUM, lhs, rhs, build.cond(cond), a, b); + instExpected = IrInst{IrCmd::JUMP, result ? a : b}; + }); + + updateUseCounts(build.function); + constantFold(); + checkEq(instOp, instExpected); + }; + + compareFold(build.constDouble(1), build.constDouble(1), IrCondition::Equal, true); + compareFold(build.constDouble(1), build.constDouble(2), IrCondition::Equal, false); + compareFold(nan, nan, IrCondition::Equal, false); + + compareFold(build.constDouble(1), build.constDouble(1), IrCondition::NotEqual, false); + compareFold(build.constDouble(1), build.constDouble(2), IrCondition::NotEqual, true); + compareFold(nan, nan, IrCondition::NotEqual, true); + + compareFold(build.constDouble(1), build.constDouble(1), IrCondition::Less, false); + compareFold(build.constDouble(1), build.constDouble(2), IrCondition::Less, true); + compareFold(build.constDouble(2), build.constDouble(1), IrCondition::Less, false); + compareFold(build.constDouble(1), nan, IrCondition::Less, false); + + compareFold(build.constDouble(1), build.constDouble(1), IrCondition::NotLess, true); + compareFold(build.constDouble(1), build.constDouble(2), IrCondition::NotLess, false); + compareFold(build.constDouble(2), build.constDouble(1), IrCondition::NotLess, true); + compareFold(build.constDouble(1), nan, IrCondition::NotLess, true); + + compareFold(build.constDouble(1), build.constDouble(1), IrCondition::LessEqual, true); + compareFold(build.constDouble(1), build.constDouble(2), IrCondition::LessEqual, true); + compareFold(build.constDouble(2), build.constDouble(1), IrCondition::LessEqual, false); + compareFold(build.constDouble(1), nan, IrCondition::LessEqual, false); + + compareFold(build.constDouble(1), build.constDouble(1), IrCondition::NotLessEqual, false); + compareFold(build.constDouble(1), build.constDouble(2), IrCondition::NotLessEqual, false); + compareFold(build.constDouble(2), build.constDouble(1), IrCondition::NotLessEqual, true); + compareFold(build.constDouble(1), nan, IrCondition::NotLessEqual, true); + + compareFold(build.constDouble(1), build.constDouble(1), IrCondition::Greater, false); + compareFold(build.constDouble(1), build.constDouble(2), IrCondition::Greater, false); + compareFold(build.constDouble(2), build.constDouble(1), IrCondition::Greater, true); + compareFold(build.constDouble(1), nan, IrCondition::Greater, false); + + compareFold(build.constDouble(1), build.constDouble(1), IrCondition::NotGreater, true); + compareFold(build.constDouble(1), build.constDouble(2), IrCondition::NotGreater, true); + compareFold(build.constDouble(2), build.constDouble(1), IrCondition::NotGreater, false); + compareFold(build.constDouble(1), nan, IrCondition::NotGreater, true); + + compareFold(build.constDouble(1), build.constDouble(1), IrCondition::GreaterEqual, true); + compareFold(build.constDouble(1), build.constDouble(2), IrCondition::GreaterEqual, false); + compareFold(build.constDouble(2), build.constDouble(1), IrCondition::GreaterEqual, true); + compareFold(build.constDouble(1), nan, IrCondition::GreaterEqual, false); + + compareFold(build.constDouble(1), build.constDouble(1), IrCondition::NotGreaterEqual, false); + compareFold(build.constDouble(1), build.constDouble(2), IrCondition::NotGreaterEqual, true); + compareFold(build.constDouble(2), build.constDouble(1), IrCondition::NotGreaterEqual, false); + compareFold(build.constDouble(1), nan, IrCondition::NotGreaterEqual, true); +} + +TEST_SUITE_END(); diff --git a/tests/TypeInfer.aliases.test.cpp b/tests/TypeInfer.aliases.test.cpp index 1eaec909c..c70ef5226 100644 --- a/tests/TypeInfer.aliases.test.cpp +++ b/tests/TypeInfer.aliases.test.cpp @@ -226,7 +226,7 @@ TEST_CASE_FIXTURE(Fixture, "dependent_generic_aliases") LUAU_REQUIRE_ERROR_COUNT(1, result); - const char* expectedError; + std::string expectedError; if (FFlag::LuauTypeMismatchInvarianceInError) expectedError = "Type 'bad' could not be converted into 'U'\n" "caused by:\n" diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index 683469a82..a267419e0 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -95,6 +95,26 @@ TEST_CASE_FIXTURE(Fixture, "infer_that_function_does_not_return_a_table") CHECK_EQ(result.errors[0], (TypeError{Location{Position{5, 8}, Position{5, 24}}, NotATable{typeChecker.numberType}})); } +TEST_CASE_FIXTURE(Fixture, "generalize_table_property") +{ + CheckResult result = check(R"( + local T = {} + + T.foo = function(x) + return x + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + TypeId t = requireType("T"); + const TableType* tt = get(follow(t)); + REQUIRE(tt); + + TypeId fooTy = tt->props.at("foo").type; + CHECK("(a) -> a" == toString(fooTy)); +} + TEST_CASE_FIXTURE(Fixture, "vararg_functions_should_allow_calls_of_any_types_and_size") { CheckResult result = check(R"( diff --git a/tests/TypeInfer.generics.test.cpp b/tests/TypeInfer.generics.test.cpp index d7b0bdb4e..0ba889c89 100644 --- a/tests/TypeInfer.generics.test.cpp +++ b/tests/TypeInfer.generics.test.cpp @@ -303,13 +303,8 @@ TEST_CASE_FIXTURE(Fixture, "calling_self_generic_methods") end )"); - if (FFlag::DebugLuauDeferredConstraintResolution) - LUAU_REQUIRE_NO_ERRORS(result); - else - { - // TODO: Should typecheck but currently errors CLI-39916 - LUAU_REQUIRE_ERRORS(result); - } + // TODO: Should typecheck but currently errors CLI-54277 + LUAU_REQUIRE_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "infer_generic_property") @@ -1053,8 +1048,11 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "infer_generic_function_function_argument") )"); LUAU_REQUIRE_NO_ERRORS(result); +} - result = check(R"( +TEST_CASE_FIXTURE(BuiltinsFixture, "infer_generic_function_function_argument_2") +{ + CheckResult result = check(R"( local function map(arr: {a}, f: (a) -> b) local r = {} for i,v in ipairs(arr) do @@ -1068,8 +1066,11 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "infer_generic_function_function_argument") LUAU_REQUIRE_NO_ERRORS(result); REQUIRE_EQ("{boolean}", toString(requireType("r"))); +} - check(R"( +TEST_CASE_FIXTURE(BuiltinsFixture, "infer_generic_function_function_argument_3") +{ + CheckResult result = check(R"( local function foldl(arr: {a}, init: b, f: (b, a) -> b) local r = init for i,v in ipairs(arr) do @@ -1214,10 +1215,6 @@ TEST_CASE_FIXTURE(Fixture, "quantify_functions_even_if_they_have_an_explicit_gen TEST_CASE_FIXTURE(Fixture, "do_not_always_instantiate_generic_intersection_types") { - ScopedFastFlag sff[] = { - {"LuauMaybeGenericIntersectionTypes", true}, - }; - CheckResult result = check(R"( --!strict type Array = { [number]: T } diff --git a/tests/TypeInfer.operators.test.cpp b/tests/TypeInfer.operators.test.cpp index feb04c29b..d75f00a2d 100644 --- a/tests/TypeInfer.operators.test.cpp +++ b/tests/TypeInfer.operators.test.cpp @@ -464,7 +464,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "typecheck_unary_minus") local foo local mt = {} - mt.__unm = function(val: typeof(foo)): string + mt.__unm = function(val): string return tostring(val.value) .. "test" end diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index fb44ec4d4..50056290b 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -1478,8 +1478,6 @@ end TEST_CASE_FIXTURE(BuiltinsFixture, "refine_unknown_to_table_then_take_the_length") { - ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true}; - CheckResult result = check(R"( local function f(x: unknown) if typeof(x) == "table" then @@ -1488,8 +1486,16 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "refine_unknown_to_table_then_take_the_length end )"); - LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("table", toString(requireTypeAtPosition({3, 29}))); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("a & table", toString(requireTypeAtPosition({3, 29}))); + } + else + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("unknown", toString(requireTypeAtPosition({3, 29}))); + } } TEST_CASE_FIXTURE(RefinementClassFixture, "refine_a_param_that_got_resolved_during_constraint_solving_stage") diff --git a/tests/TypeInfer.tryUnify.test.cpp b/tests/TypeInfer.tryUnify.test.cpp index 8a55c5cf1..a22149c71 100644 --- a/tests/TypeInfer.tryUnify.test.cpp +++ b/tests/TypeInfer.tryUnify.test.cpp @@ -362,8 +362,6 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "fuzz_tail_unification_issue") TEST_CASE_FIXTURE(BuiltinsFixture, "fuzz_unify_any_should_check_log") { - ScopedFastFlag luauUnifyAnyTxnLog{"LuauUnifyAnyTxnLog", true}; - CheckResult result = check(R"( repeat _._,_ = nil diff --git a/tools/faillist.txt b/tools/faillist.txt index 565982cf2..0a09f3f64 100644 --- a/tools/faillist.txt +++ b/tools/faillist.txt @@ -1,8 +1,4 @@ -AnnotationTests.corecursive_types_error_on_tight_loop -AnnotationTests.duplicate_type_param_name -AnnotationTests.generic_aliases_are_cloned_properly -AnnotationTests.occurs_check_on_cyclic_intersection_type -AnnotationTests.occurs_check_on_cyclic_union_type +AnnotationTests.instantiate_type_fun_should_not_trip_rbxassert AnnotationTests.too_many_type_params AnnotationTests.two_type_params AstQuery.last_argument_function_call_type @@ -14,9 +10,6 @@ AutocompleteTest.autocomplete_oop_implicit_self AutocompleteTest.autocomplete_string_singleton_equality AutocompleteTest.do_compatible_self_calls AutocompleteTest.do_wrong_compatible_self_calls -AutocompleteTest.type_correct_expected_argument_type_pack_suggestion -AutocompleteTest.type_correct_expected_argument_type_suggestion_self -AutocompleteTest.type_correct_expected_return_type_pack_suggestion AutocompleteTest.type_correct_expected_return_type_suggestion AutocompleteTest.type_correct_suggestion_for_overloads BuiltinTests.aliased_string_format @@ -51,41 +44,32 @@ BuiltinTests.table_pack_variadic DefinitionTests.class_definition_overload_metamethods DefinitionTests.class_definition_string_props DefinitionTests.definition_file_classes -DefinitionTests.definitions_symbols_are_generated_for_recursively_referenced_types -DefinitionTests.single_class_type_identity_in_global_types FrontendTest.environments FrontendTest.nocheck_cycle_used_by_checked -FrontendTest.reexport_cyclic_type -GenericsTests.apply_type_function_nested_generics1 GenericsTests.apply_type_function_nested_generics2 GenericsTests.better_mismatch_error_messages +GenericsTests.bound_tables_do_not_clone_original_fields GenericsTests.check_mutual_generic_functions GenericsTests.correctly_instantiate_polymorphic_member_functions GenericsTests.do_not_infer_generic_functions -GenericsTests.duplicate_generic_type_packs -GenericsTests.duplicate_generic_types GenericsTests.generic_argument_count_too_few GenericsTests.generic_argument_count_too_many GenericsTests.generic_functions_should_be_memory_safe -GenericsTests.generic_table_method GenericsTests.generic_type_pack_parentheses GenericsTests.higher_rank_polymorphism_should_not_accept_instantiated_arguments -GenericsTests.infer_generic_function_function_argument +GenericsTests.infer_generic_function_function_argument_2 +GenericsTests.infer_generic_function_function_argument_3 GenericsTests.infer_generic_function_function_argument_overloaded GenericsTests.infer_generic_lib_function_function_argument -GenericsTests.infer_generic_property GenericsTests.instantiated_function_argument_names GenericsTests.instantiation_sharing_types GenericsTests.no_stack_overflow_from_quantifying -GenericsTests.reject_clashing_generic_and_pack_names GenericsTests.self_recursive_instantiated_param -IntersectionTypes.no_stack_overflow_from_flattenintersection IntersectionTypes.select_correct_union_fn IntersectionTypes.should_still_pick_an_overload_whose_arguments_are_unions IntersectionTypes.table_intersection_write_sealed IntersectionTypes.table_intersection_write_sealed_indirect IntersectionTypes.table_write_sealed_indirect -ModuleTests.any_persistance_does_not_leak ModuleTests.clone_self_property ModuleTests.deepClone_cyclic_table NonstrictModeTests.for_in_iterator_variables_are_any @@ -102,10 +86,6 @@ NonstrictModeTests.parameters_having_type_any_are_optional NonstrictModeTests.table_dot_insert_and_recursive_calls NonstrictModeTests.table_props_are_any Normalize.cyclic_table_normalizes_sensibly -ParseErrorRecovery.generic_type_list_recovery -ParseErrorRecovery.recovery_of_parenthesized_expressions -ParserTests.parse_nesting_based_end_detection_failsafe_earlier -ParserTests.parse_nesting_based_end_detection_local_function ProvisionalTests.assign_table_with_refined_property_with_a_similar_type_is_illegal ProvisionalTests.bail_early_if_unification_is_too_complicated ProvisionalTests.do_not_ice_when_trying_to_pick_first_of_generic_type_pack @@ -114,7 +94,6 @@ ProvisionalTests.free_options_cannot_be_unified_together ProvisionalTests.generic_type_leak_to_module_interface_variadic ProvisionalTests.greedy_inference_with_shared_self_triggers_function_with_no_returns ProvisionalTests.pcall_returns_at_least_two_value_but_function_returns_nothing -ProvisionalTests.refine_unknown_to_table_then_clone_it ProvisionalTests.setmetatable_constrains_free_type_into_free_table ProvisionalTests.specialization_binds_with_prototypes_too_early ProvisionalTests.table_insert_with_a_singleton_argument @@ -122,10 +101,11 @@ ProvisionalTests.typeguard_inference_incomplete ProvisionalTests.weirditer_should_not_loop_forever RefinementTest.apply_refinements_on_astexprindexexpr_whose_subscript_expr_is_constant_string RefinementTest.discriminate_tag -RefinementTest.else_with_no_explicit_expression_should_also_refine_the_tagged_union RefinementTest.falsiness_of_TruthyPredicate_narrows_into_nil RefinementTest.narrow_property_of_a_bounded_variable RefinementTest.nonoptional_type_can_narrow_to_nil_if_sense_is_true +RefinementTest.refine_a_param_that_got_resolved_during_constraint_solving_stage +RefinementTest.refine_unknowns RefinementTest.type_guard_can_filter_for_intersection_of_tables RefinementTest.type_narrow_for_all_the_userdata RefinementTest.type_narrow_to_vector @@ -157,7 +137,6 @@ TableTests.found_like_key_in_table_property_access TableTests.found_multiple_like_keys TableTests.function_calls_produces_sealed_table_given_unsealed_table TableTests.fuzz_table_unify_instantiated_table -TableTests.fuzz_table_unify_instantiated_table_with_prop_realloc TableTests.generic_table_instantiation_potential_regression TableTests.give_up_after_one_metatable_index_look_up TableTests.indexer_on_sealed_table_must_unify_with_free_table @@ -198,39 +177,31 @@ TableTests.table_simple_call TableTests.table_subtyping_with_extra_props_dont_report_multiple_errors TableTests.table_subtyping_with_missing_props_dont_report_multiple_errors TableTests.table_unification_4 -TableTests.tc_member_function_2 TableTests.unifying_tables_shouldnt_uaf2 TableTests.used_colon_instead_of_dot TableTests.used_dot_instead_of_colon ToString.exhaustive_toString_of_cyclic_table -ToString.function_type_with_argument_names_generic ToString.named_metatable_toStringNamedFunction ToString.toStringDetailed2 ToString.toStringErrorPack ToString.toStringNamedFunction_generic_pack ToString.toStringNamedFunction_hide_self_param -ToString.toStringNamedFunction_hide_type_params -ToString.toStringNamedFunction_id ToString.toStringNamedFunction_include_self_param ToString.toStringNamedFunction_map -ToString.toStringNamedFunction_variadics TryUnifyTests.cli_41095_concat_log_in_sealed_table_unification TryUnifyTests.members_of_failed_typepack_unification_are_unified_with_errorType TryUnifyTests.result_of_failed_typepack_unification_is_constrained TryUnifyTests.typepack_unification_should_trim_free_tails TryUnifyTests.variadics_should_use_reversed_properly TypeAliases.cannot_create_cyclic_type_with_unknown_module -TypeAliases.corecursive_types_generic TypeAliases.forward_declared_alias_is_not_clobbered_by_prior_unification_with_any TypeAliases.generic_param_remap TypeAliases.mismatched_generic_type_param -TypeAliases.mutually_recursive_types_errors TypeAliases.mutually_recursive_types_restriction_not_ok_1 TypeAliases.mutually_recursive_types_restriction_not_ok_2 TypeAliases.mutually_recursive_types_swapsies_not_ok TypeAliases.recursive_types_restriction_not_ok TypeAliases.report_shadowed_aliases -TypeAliases.stringify_type_alias_of_recursive_template_table_type TypeAliases.type_alias_local_mutation TypeAliases.type_alias_local_rename TypeAliases.type_alias_of_an_imported_recursive_generic_type @@ -298,7 +269,7 @@ TypeInferLoops.loop_iter_trailing_nil TypeInferLoops.properly_infer_iteratee_is_a_free_table TypeInferLoops.unreachable_code_after_infinite_loop TypeInferModules.custom_require_global -TypeInferModules.do_not_modify_imported_types_4 +TypeInferModules.do_not_modify_imported_types_5 TypeInferModules.module_type_conflict TypeInferModules.module_type_conflict_instantiated TypeInferModules.type_error_of_unknown_qualified_type @@ -312,6 +283,7 @@ TypeInferOperators.cannot_compare_tables_that_do_not_have_the_same_metatable TypeInferOperators.cannot_indirectly_compare_types_that_do_not_have_a_metatable TypeInferOperators.cannot_indirectly_compare_types_that_do_not_offer_overloaded_ordering_operators TypeInferOperators.cli_38355_recursive_union +TypeInferOperators.compound_assign_metatable TypeInferOperators.compound_assign_mismatch_metatable TypeInferOperators.disallow_string_and_types_without_metatables_from_arithmetic_binary_ops TypeInferOperators.in_nonstrict_mode_strip_nil_from_intersections_when_considering_relational_operators @@ -319,6 +291,7 @@ TypeInferOperators.infer_any_in_all_modes_when_lhs_is_unknown TypeInferOperators.operator_eq_completely_incompatible TypeInferOperators.typecheck_overloaded_multiply_that_is_an_intersection TypeInferOperators.typecheck_overloaded_multiply_that_is_an_intersection_on_rhs +TypeInferOperators.typecheck_unary_len_error TypeInferOperators.UnknownGlobalCompoundAssign TypeInferOperators.unrelated_classes_cannot_be_compared TypeInferOperators.unrelated_primitives_cannot_be_compared @@ -341,11 +314,8 @@ TypePackTests.type_alias_defaults_confusing_types TypePackTests.type_alias_defaults_recursive_type TypePackTests.type_alias_type_pack_multi TypePackTests.type_alias_type_pack_variadic -TypePackTests.type_alias_type_packs TypePackTests.type_alias_type_packs_errors -TypePackTests.type_alias_type_packs_import TypePackTests.type_alias_type_packs_nested -TypePackTests.type_pack_type_parameters TypePackTests.unify_variadic_tails_in_arguments TypePackTests.unify_variadic_tails_in_arguments_free TypePackTests.variadic_packs @@ -360,7 +330,6 @@ TypeSingletons.table_properties_type_error_escapes TypeSingletons.taking_the_length_of_union_of_string_singleton TypeSingletons.widen_the_supertype_if_it_is_free_and_subtype_has_singleton TypeSingletons.widening_happens_almost_everywhere -TypeSingletons.widening_happens_almost_everywhere_except_for_tables UnionTypes.index_on_a_union_type_with_missing_property UnionTypes.optional_assignment_errors UnionTypes.optional_call_error From 1e7b23fbfc3a8681f867d613bb4845db83b3715f Mon Sep 17 00:00:00 2001 From: Andy Friesen Date: Fri, 24 Feb 2023 10:24:22 -0800 Subject: [PATCH 37/66] Sync to upstream/release/565 --- Analysis/include/Luau/Constraint.h | 26 +- .../include/Luau/ConstraintGraphBuilder.h | 27 +- Analysis/include/Luau/ConstraintSolver.h | 17 +- Analysis/include/Luau/DcrLogger.h | 41 +- Analysis/include/Luau/Scope.h | 2 +- Analysis/include/Luau/Symbol.h | 5 + Analysis/include/Luau/Type.h | 23 + Analysis/include/Luau/TypeInfer.h | 4 +- Analysis/include/Luau/TypeReduction.h | 41 +- Analysis/include/Luau/Unifier.h | 6 + Analysis/src/ConstraintGraphBuilder.cpp | 163 ++-- Analysis/src/ConstraintSolver.cpp | 354 ++++++- Analysis/src/DcrLogger.cpp | 223 +++-- Analysis/src/Frontend.cpp | 4 +- Analysis/src/Normalize.cpp | 2 + Analysis/src/Quantify.cpp | 2 +- Analysis/src/Scope.cpp | 6 +- Analysis/src/ToString.cpp | 6 + Analysis/src/TypeChecker2.cpp | 244 ++++- Analysis/src/TypeInfer.cpp | 205 +++-- Analysis/src/TypeReduction.cpp | 362 ++++---- Analysis/src/Unifier.cpp | 7 +- CodeGen/include/Luau/IrBuilder.h | 3 + CodeGen/include/Luau/IrData.h | 95 +- CodeGen/include/Luau/IrUtils.h | 9 +- CodeGen/include/Luau/OptimizeConstProp.h | 16 + CodeGen/src/CodeGen.cpp | 492 +--------- CodeGen/src/EmitBuiltinsX64.cpp | 24 +- CodeGen/src/EmitBuiltinsX64.h | 13 +- CodeGen/src/EmitCommonX64.h | 8 - CodeGen/src/EmitInstructionX64.cpp | 867 +----------------- CodeGen/src/EmitInstructionX64.h | 61 +- CodeGen/src/IrAnalysis.cpp | 2 + CodeGen/src/IrBuilder.cpp | 63 +- CodeGen/src/IrDump.cpp | 57 +- CodeGen/src/IrLoweringX64.cpp | 373 ++++---- CodeGen/src/IrLoweringX64.h | 32 +- CodeGen/src/IrRegAllocX64.cpp | 181 ++++ CodeGen/src/IrRegAllocX64.h | 51 ++ CodeGen/src/IrTranslateBuiltins.cpp | 40 + CodeGen/src/IrTranslateBuiltins.h | 27 + CodeGen/src/IrTranslation.cpp | 81 +- CodeGen/src/IrTranslation.h | 3 + CodeGen/src/IrUtils.cpp | 70 +- CodeGen/src/OptimizeConstProp.cpp | 565 ++++++++++++ CodeGen/src/OptimizeFinalX64.cpp | 5 +- Common/include/Luau/Bytecode.h | 2 +- Compiler/src/Compiler.cpp | 22 +- Sources.cmake | 6 + tests/Compiler.test.cpp | 61 +- tests/ConstraintGraphBuilderFixture.cpp | 3 +- tests/IrBuilder.test.cpp | 689 +++++++++++++- tests/Module.test.cpp | 47 +- tests/NonstrictMode.test.cpp | 2 +- tests/ToString.test.cpp | 12 +- tests/TypeInfer.aliases.test.cpp | 12 +- tests/TypeInfer.functions.test.cpp | 15 +- tests/TypeInfer.refinements.test.cpp | 34 + tests/TypeInfer.tables.test.cpp | 25 +- tests/TypeInfer.tryUnify.test.cpp | 9 +- tests/TypeInfer.unknownnever.test.cpp | 2 +- tools/faillist.txt | 64 +- 62 files changed, 3492 insertions(+), 2421 deletions(-) create mode 100644 CodeGen/include/Luau/OptimizeConstProp.h create mode 100644 CodeGen/src/IrRegAllocX64.cpp create mode 100644 CodeGen/src/IrRegAllocX64.h create mode 100644 CodeGen/src/IrTranslateBuiltins.cpp create mode 100644 CodeGen/src/IrTranslateBuiltins.h create mode 100644 CodeGen/src/OptimizeConstProp.cpp diff --git a/Analysis/include/Luau/Constraint.h b/Analysis/include/Luau/Constraint.h index 65599e498..1c41bbb7f 100644 --- a/Analysis/include/Luau/Constraint.h +++ b/Analysis/include/Luau/Constraint.h @@ -159,6 +159,20 @@ struct SetPropConstraint TypeId propType; }; +// result ~ setIndexer subjectType indexType propType +// +// If the subject is a table or table-like thing that already has an indexer, +// unify its indexType and propType with those from this constraint. +// +// If the table is a free or unsealed table, we augment it with a new indexer. +struct SetIndexerConstraint +{ + TypeId resultType; + TypeId subjectType; + TypeId indexType; + TypeId propType; +}; + // if negation: // result ~ if isSingleton D then ~D else unknown where D = discriminantType // if not negation: @@ -170,9 +184,19 @@ struct SingletonOrTopTypeConstraint bool negated; }; +// resultType ~ unpack sourceTypePack +// +// Similar to PackSubtypeConstraint, but with one important difference: If the +// sourcePack is blocked, this constraint blocks. +struct UnpackConstraint +{ + TypePackId resultPack; + TypePackId sourcePack; +}; + using ConstraintV = Variant; + HasPropConstraint, SetPropConstraint, SetIndexerConstraint, SingletonOrTopTypeConstraint, UnpackConstraint>; struct Constraint { diff --git a/Analysis/include/Luau/ConstraintGraphBuilder.h b/Analysis/include/Luau/ConstraintGraphBuilder.h index 085b67328..7b2711f89 100644 --- a/Analysis/include/Luau/ConstraintGraphBuilder.h +++ b/Analysis/include/Luau/ConstraintGraphBuilder.h @@ -191,7 +191,7 @@ struct ConstraintGraphBuilder Inference check(const ScopePtr& scope, AstExprTable* expr, std::optional expectedType); std::tuple checkBinary(const ScopePtr& scope, AstExprBinary* binary, std::optional expectedType); - TypePackId checkLValues(const ScopePtr& scope, AstArray exprs); + std::vector checkLValues(const ScopePtr& scope, AstArray exprs); TypeId checkLValue(const ScopePtr& scope, AstExpr* expr); @@ -244,10 +244,31 @@ struct ConstraintGraphBuilder **/ TypePackId resolveTypePack(const ScopePtr& scope, const AstTypeList& list, bool inTypeArguments); + /** + * Creates generic types given a list of AST definitions, resolving default + * types as required. + * @param scope the scope that the generics should belong to. + * @param generics the AST generics to create types for. + * @param useCache whether to use the generic type cache for the given + * scope. + * @param addTypes whether to add the types to the scope's + * privateTypeBindings map. + **/ std::vector> createGenerics( - const ScopePtr& scope, AstArray generics, bool useCache = false); + const ScopePtr& scope, AstArray generics, bool useCache = false, bool addTypes = true); + + /** + * Creates generic type packs given a list of AST definitions, resolving + * default type packs as required. + * @param scope the scope that the generic packs should belong to. + * @param generics the AST generics to create type packs for. + * @param useCache whether to use the generic type pack cache for the given + * scope. + * @param addTypes whether to add the types to the scope's + * privateTypePackBindings map. + **/ std::vector> createGenericPacks( - const ScopePtr& scope, AstArray packs, bool useCache = false); + const ScopePtr& scope, AstArray packs, bool useCache = false, bool addTypes = true); Inference flattenPack(const ScopePtr& scope, Location location, InferencePack pack); diff --git a/Analysis/include/Luau/ConstraintSolver.h b/Analysis/include/Luau/ConstraintSolver.h index de7b3a044..62687ae47 100644 --- a/Analysis/include/Luau/ConstraintSolver.h +++ b/Analysis/include/Luau/ConstraintSolver.h @@ -8,6 +8,7 @@ #include "Luau/Normalize.h" #include "Luau/ToString.h" #include "Luau/Type.h" +#include "Luau/TypeReduction.h" #include "Luau/Variant.h" #include @@ -19,7 +20,12 @@ struct DcrLogger; // TypeId, TypePackId, or Constraint*. It is impossible to know which, but we // never dereference this pointer. -using BlockedConstraintId = const void*; +using BlockedConstraintId = Variant; + +struct HashBlockedConstraintId +{ + size_t operator()(const BlockedConstraintId& bci) const; +}; struct ModuleResolver; @@ -47,6 +53,7 @@ struct ConstraintSolver NotNull builtinTypes; InternalErrorReporter iceReporter; NotNull normalizer; + NotNull reducer; // The entire set of constraints that the solver is trying to resolve. std::vector> constraints; NotNull rootScope; @@ -65,7 +72,7 @@ struct ConstraintSolver // anything. std::unordered_map, size_t> blockedConstraints; // A mapping of type/pack pointers to the constraints they block. - std::unordered_map>> blocked; + std::unordered_map>, HashBlockedConstraintId> blocked; // Memoized instantiations of type aliases. DenseHashMap instantiatedAliases{{}}; @@ -78,7 +85,8 @@ struct ConstraintSolver DcrLogger* logger; explicit ConstraintSolver(NotNull normalizer, NotNull rootScope, std::vector> constraints, - ModuleName moduleName, NotNull moduleResolver, std::vector requireCycles, DcrLogger* logger); + ModuleName moduleName, NotNull reducer, NotNull moduleResolver, std::vector requireCycles, + DcrLogger* logger); // Randomize the order in which to dispatch constraints void randomize(unsigned seed); @@ -112,7 +120,9 @@ struct ConstraintSolver bool tryDispatch(const PrimitiveTypeConstraint& c, NotNull constraint); bool tryDispatch(const HasPropConstraint& c, NotNull constraint); bool tryDispatch(const SetPropConstraint& c, NotNull constraint, bool force); + bool tryDispatch(const SetIndexerConstraint& c, NotNull constraint, bool force); bool tryDispatch(const SingletonOrTopTypeConstraint& c, NotNull constraint); + bool tryDispatch(const UnpackConstraint& c, NotNull constraint); // for a, ... in some_table do // also handles __iter metamethod @@ -123,6 +133,7 @@ struct ConstraintSolver TypeId nextTy, TypeId tableTy, TypeId firstIndexTy, const IterableConstraint& c, NotNull constraint, bool force); std::optional lookupTableProp(TypeId subjectType, const std::string& propName); + std::optional lookupTableProp(TypeId subjectType, const std::string& propName, std::unordered_set& seen); void block(NotNull target, NotNull constraint); /** diff --git a/Analysis/include/Luau/DcrLogger.h b/Analysis/include/Luau/DcrLogger.h index 45c84c66e..1e170d5bb 100644 --- a/Analysis/include/Luau/DcrLogger.h +++ b/Analysis/include/Luau/DcrLogger.h @@ -4,6 +4,7 @@ #include "Luau/Constraint.h" #include "Luau/NotNull.h" #include "Luau/Scope.h" +#include "Luau/Module.h" #include "Luau/ToString.h" #include "Luau/Error.h" #include "Luau/Variant.h" @@ -34,11 +35,26 @@ struct TypeBindingSnapshot std::string typeString; }; +struct ExprTypesAtLocation +{ + Location location; + TypeId ty; + std::optional expectedTy; +}; + +struct AnnotationTypesAtLocation +{ + Location location; + TypeId resolvedTy; +}; + struct ConstraintGenerationLog { std::string source; - std::unordered_map constraintLocations; std::vector errors; + + std::vector exprTypeLocations; + std::vector annotationTypeLocations; }; struct ScopeSnapshot @@ -49,16 +65,11 @@ struct ScopeSnapshot std::vector children; }; -enum class ConstraintBlockKind -{ - TypeId, - TypePackId, - ConstraintId, -}; +using ConstraintBlockTarget = Variant>; struct ConstraintBlock { - ConstraintBlockKind kind; + ConstraintBlockTarget target; std::string stringification; }; @@ -71,16 +82,18 @@ struct ConstraintSnapshot struct BoundarySnapshot { - std::unordered_map constraints; + DenseHashMap unsolvedConstraints{nullptr}; ScopeSnapshot rootScope; + DenseHashMap typeStrings{nullptr}; }; struct StepSnapshot { - std::string currentConstraint; + const Constraint* currentConstraint; bool forced; - std::unordered_map unsolvedConstraints; + DenseHashMap unsolvedConstraints{nullptr}; ScopeSnapshot rootScope; + DenseHashMap typeStrings{nullptr}; }; struct TypeSolveLog @@ -95,8 +108,6 @@ struct TypeCheckLog std::vector errors; }; -using ConstraintBlockTarget = Variant>; - struct DcrLogger { std::string compileOutput(); @@ -104,6 +115,7 @@ struct DcrLogger void captureSource(std::string source); void captureGenerationError(const TypeError& error); void captureConstraintLocation(NotNull constraint, Location location); + void captureGenerationModule(const ModulePtr& module); void pushBlock(NotNull constraint, TypeId block); void pushBlock(NotNull constraint, TypePackId block); @@ -126,9 +138,10 @@ struct DcrLogger TypeSolveLog solveLog; TypeCheckLog checkLog; - ToStringOptions opts; + ToStringOptions opts{true}; std::vector snapshotBlocks(NotNull constraint); + void captureBoundaryState(BoundarySnapshot& target, const Scope* rootScope, const std::vector>& unsolvedConstraints); }; } // namespace Luau diff --git a/Analysis/include/Luau/Scope.h b/Analysis/include/Luau/Scope.h index a8f83e2f7..85a36fc90 100644 --- a/Analysis/include/Luau/Scope.h +++ b/Analysis/include/Luau/Scope.h @@ -52,7 +52,7 @@ struct Scope std::optional lookup(Symbol sym) const; std::optional lookup(DefId def) const; - std::optional> lookupEx(Symbol sym); + std::optional> lookupEx(Symbol sym); std::optional lookupType(const Name& name); std::optional lookupImportedType(const Name& moduleAlias, const Name& name); diff --git a/Analysis/include/Luau/Symbol.h b/Analysis/include/Luau/Symbol.h index 0432946cc..b47554e0d 100644 --- a/Analysis/include/Luau/Symbol.h +++ b/Analysis/include/Luau/Symbol.h @@ -37,6 +37,11 @@ struct Symbol AstLocal* local; AstName global; + explicit operator bool() const + { + return local != nullptr || global.value != nullptr; + } + bool operator==(const Symbol& rhs) const { if (local) diff --git a/Analysis/include/Luau/Type.h b/Analysis/include/Luau/Type.h index 00e6d6c65..d009001b6 100644 --- a/Analysis/include/Luau/Type.h +++ b/Analysis/include/Luau/Type.h @@ -246,6 +246,18 @@ struct WithPredicate { T type; PredicateVec predicates; + + WithPredicate() = default; + explicit WithPredicate(T type) + : type(type) + { + } + + WithPredicate(T type, PredicateVec predicates) + : type(type) + , predicates(std::move(predicates)) + { + } }; using MagicFunction = std::function>( @@ -853,4 +865,15 @@ bool hasTag(TypeId ty, const std::string& tagName); bool hasTag(const Property& prop, const std::string& tagName); bool hasTag(const Tags& tags, const std::string& tagName); // Do not use in new work. +/* + * Use this to change the kind of a particular type. + * + * LUAU_NOINLINE so that the calling frame doesn't have to pay the stack storage for the new variant. + */ +template +LUAU_NOINLINE T* emplaceType(Type* ty, Args&&... args) +{ + return &ty->ty.emplace(std::forward(args)...); +} + } // namespace Luau diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index d748a1f50..678bd419d 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -146,10 +146,12 @@ struct TypeChecker WithPredicate checkExprPackHelper(const ScopePtr& scope, const AstExpr& expr); WithPredicate checkExprPackHelper(const ScopePtr& scope, const AstExprCall& expr); + WithPredicate checkExprPackHelper2( + const ScopePtr& scope, const AstExprCall& expr, TypeId selfType, TypeId actualFunctionType, TypeId functionType, TypePackId retPack); std::vector> getExpectedTypesForCall(const std::vector& overloads, size_t argumentCount, bool selfCall); - std::optional> checkCallOverload(const ScopePtr& scope, const AstExprCall& expr, TypeId fn, TypePackId retPack, + std::unique_ptr> checkCallOverload(const ScopePtr& scope, const AstExprCall& expr, TypeId fn, TypePackId retPack, TypePackId argPack, TypePack* args, const std::vector* argLocations, const WithPredicate& argListResult, std::vector& overloadsThatMatchArgCount, std::vector& overloadsThatDont, std::vector& errors); bool handleSelfCallMismatch(const ScopePtr& scope, const AstExprCall& expr, TypePack* args, const std::vector& argLocations, diff --git a/Analysis/include/Luau/TypeReduction.h b/Analysis/include/Luau/TypeReduction.h index 0ad034a49..80a7ac596 100644 --- a/Analysis/include/Luau/TypeReduction.h +++ b/Analysis/include/Luau/TypeReduction.h @@ -12,11 +12,36 @@ namespace Luau namespace detail { template -struct ReductionContext +struct ReductionEdge { T type = nullptr; bool irreducible = false; }; + +struct TypeReductionMemoization +{ + TypeReductionMemoization() = default; + + TypeReductionMemoization(const TypeReductionMemoization&) = delete; + TypeReductionMemoization& operator=(const TypeReductionMemoization&) = delete; + + TypeReductionMemoization(TypeReductionMemoization&&) = default; + TypeReductionMemoization& operator=(TypeReductionMemoization&&) = default; + + DenseHashMap> types{nullptr}; + DenseHashMap> typePacks{nullptr}; + + bool isIrreducible(TypeId ty); + bool isIrreducible(TypePackId tp); + + TypeId memoize(TypeId ty, TypeId reducedTy); + TypePackId memoize(TypePackId tp, TypePackId reducedTp); + + // Reducing A into B may have a non-irreducible edge A to B for which B is not irreducible, which means B could be reduced into C. + // Because reduction should always be transitive, A should point to C if A points to B and B points to C. + std::optional> memoizedof(TypeId ty) const; + std::optional> memoizedof(TypePackId tp) const; +}; } // namespace detail struct TypeReductionOptions @@ -42,29 +67,19 @@ struct TypeReduction std::optional reduce(TypePackId tp); std::optional reduce(const TypeFun& fun); - /// Creating a child TypeReduction will allow the parent TypeReduction to share its memoization with the child TypeReductions. - /// This is safe as long as the parent's TypeArena continues to outlive both TypeReduction memoization. - TypeReduction fork(NotNull arena, const TypeReductionOptions& opts = {}) const; - private: - const TypeReduction* parent = nullptr; - NotNull arena; NotNull builtinTypes; NotNull handle; - TypeReductionOptions options; - DenseHashMap> memoizedTypes{nullptr}; - DenseHashMap> memoizedTypePacks{nullptr}; + TypeReductionOptions options; + detail::TypeReductionMemoization memoization; // Computes an *estimated length* of the cartesian product of the given type. size_t cartesianProductSize(TypeId ty) const; bool hasExceededCartesianProductLimit(TypeId ty) const; bool hasExceededCartesianProductLimit(TypePackId tp) const; - - std::optional memoizedof(TypeId ty) const; - std::optional memoizedof(TypePackId tp) const; }; } // namespace Luau diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index 988ad9c69..ebfff4c29 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -67,6 +67,12 @@ struct Unifier UnifierSharedState& sharedState; + // When the Unifier is forced to unify two blocked types (or packs), they + // get added to these vectors. The ConstraintSolver can use this to know + // when it is safe to reattempt dispatching a constraint. + std::vector blockedTypes; + std::vector blockedTypePacks; + Unifier( NotNull normalizer, Mode mode, NotNull scope, const Location& location, Variance variance, TxnLog* parentLog = nullptr); diff --git a/Analysis/src/ConstraintGraphBuilder.cpp b/Analysis/src/ConstraintGraphBuilder.cpp index aa605bdf0..fe412632c 100644 --- a/Analysis/src/ConstraintGraphBuilder.cpp +++ b/Analysis/src/ConstraintGraphBuilder.cpp @@ -320,6 +320,9 @@ void ConstraintGraphBuilder::visit(AstStatBlock* block) prepopulateGlobalScope(scope, block); visitBlockWithoutChildScope(scope, block); + + if (FFlag::DebugLuauLogSolverToJson) + logger->captureGenerationModule(module); } void ConstraintGraphBuilder::visitBlockWithoutChildScope(const ScopePtr& scope, AstStatBlock* block) @@ -357,13 +360,11 @@ void ConstraintGraphBuilder::visitBlockWithoutChildScope(const ScopePtr& scope, for (const auto& [name, gen] : createGenerics(defnScope, alias->generics, /* useCache */ true)) { initialFun.typeParams.push_back(gen); - defnScope->privateTypeBindings[name] = TypeFun{gen.ty}; } for (const auto& [name, genPack] : createGenericPacks(defnScope, alias->genericPacks, /* useCache */ true)) { initialFun.typePackParams.push_back(genPack); - defnScope->privateTypePackBindings[name] = genPack.tp; } if (alias->exported) @@ -503,13 +504,13 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocal* local) if (j - i < packTypes.head.size()) varTypes[j] = packTypes.head[j - i]; else - varTypes[j] = freshType(scope); + varTypes[j] = arena->addType(BlockedType{}); } } std::vector tailValues{varTypes.begin() + i, varTypes.end()}; TypePackId tailPack = arena->addTypePack(std::move(tailValues)); - addConstraint(scope, local->location, PackSubtypeConstraint{exprPack, tailPack}); + addConstraint(scope, local->location, UnpackConstraint{tailPack, exprPack}); } } } @@ -686,6 +687,8 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFunction* funct Checkpoint start = checkpoint(this); FunctionSignature sig = checkFunctionSignature(scope, function->func); + std::unordered_set excludeList; + if (AstExprLocal* localName = function->name->as()) { std::optional existingFunctionTy = scope->lookup(localName->local); @@ -716,9 +719,20 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFunction* funct } else if (AstExprIndexName* indexName = function->name->as()) { + Checkpoint check1 = checkpoint(this); TypeId lvalueType = checkLValue(scope, indexName); + Checkpoint check2 = checkpoint(this); + + forEachConstraint(check1, check2, this, [&excludeList](const ConstraintPtr& c) { + excludeList.insert(c.get()); + }); + // TODO figure out how to populate the location field of the table Property. - addConstraint(scope, indexName->location, SubtypeConstraint{lvalueType, generalizedType}); + + if (get(lvalueType)) + asMutable(lvalueType)->ty.emplace(generalizedType); + else + addConstraint(scope, indexName->location, SubtypeConstraint{lvalueType, generalizedType}); } else if (AstExprError* err = function->name->as()) { @@ -735,8 +749,9 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFunction* funct std::unique_ptr c = std::make_unique(constraintScope, function->name->location, GeneralizationConstraint{generalizedType, sig.signature}); - forEachConstraint(start, end, this, [&c](const ConstraintPtr& constraint) { - c->dependencies.push_back(NotNull{constraint.get()}); + forEachConstraint(start, end, this, [&c, &excludeList](const ConstraintPtr& constraint) { + if (!excludeList.count(constraint.get())) + c->dependencies.push_back(NotNull{constraint.get()}); }); addConstraint(scope, std::move(c)); @@ -763,16 +778,31 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatBlock* block) visitBlockWithoutChildScope(innerScope, block); } -void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatAssign* assign) +static void bindFreeType(TypeId a, TypeId b) { - TypePackId varPackId = checkLValues(scope, assign->vars); + FreeType* af = getMutable(a); + FreeType* bf = getMutable(b); - TypePack expectedPack = extendTypePack(*arena, builtinTypes, varPackId, assign->values.size); + LUAU_ASSERT(af || bf); + + if (!bf) + asMutable(a)->ty.emplace(b); + else if (!af) + asMutable(b)->ty.emplace(a); + else if (subsumes(bf->scope, af->scope)) + asMutable(a)->ty.emplace(b); + else if (subsumes(af->scope, bf->scope)) + asMutable(b)->ty.emplace(a); +} + +void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatAssign* assign) +{ + std::vector varTypes = checkLValues(scope, assign->vars); std::vector> expectedTypes; - expectedTypes.reserve(expectedPack.head.size()); + expectedTypes.reserve(varTypes.size()); - for (TypeId ty : expectedPack.head) + for (TypeId ty : varTypes) { ty = follow(ty); if (get(ty)) @@ -781,9 +811,10 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatAssign* assign) expectedTypes.push_back(ty); } - TypePackId valuePack = checkPack(scope, assign->values, expectedTypes).tp; + TypePackId exprPack = checkPack(scope, assign->values, expectedTypes).tp; + TypePackId varPack = arena->addTypePack({varTypes}); - addConstraint(scope, assign->location, PackSubtypeConstraint{valuePack, varPackId}); + addConstraint(scope, assign->location, PackSubtypeConstraint{exprPack, varPack}); } void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatCompoundAssign* assign) @@ -865,11 +896,11 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatTypeAlias* alia asMutable(aliasTy)->ty.emplace(ty); std::vector typeParams; - for (auto tyParam : createGenerics(*defnScope, alias->generics, /* useCache */ true)) + for (auto tyParam : createGenerics(*defnScope, alias->generics, /* useCache */ true, /* addTypes */ false)) typeParams.push_back(tyParam.second.ty); std::vector typePackParams; - for (auto tpParam : createGenericPacks(*defnScope, alias->genericPacks, /* useCache */ true)) + for (auto tpParam : createGenericPacks(*defnScope, alias->genericPacks, /* useCache */ true, /* addTypes */ false)) typePackParams.push_back(tpParam.second.tp); addConstraint(scope, alias->type->location, @@ -1010,7 +1041,6 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareFunction for (auto& [name, generic] : generics) { genericTys.push_back(generic.ty); - scope->privateTypeBindings[name] = TypeFun{generic.ty}; } std::vector genericTps; @@ -1018,7 +1048,6 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareFunction for (auto& [name, generic] : genericPacks) { genericTps.push_back(generic.tp); - scope->privateTypePackBindings[name] = generic.tp; } ScopePtr funScope = scope; @@ -1161,7 +1190,7 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCa TypePackId expectedArgPack = arena->freshTypePack(scope.get()); TypePackId expectedRetPack = arena->freshTypePack(scope.get()); - TypeId expectedFunctionType = arena->addType(FunctionType{expectedArgPack, expectedRetPack}); + TypeId expectedFunctionType = arena->addType(FunctionType{expectedArgPack, expectedRetPack, std::nullopt, call->self}); TypeId instantiatedFnType = arena->addType(BlockedType{}); addConstraint(scope, call->location, InstantiationConstraint{instantiatedFnType, fnType}); @@ -1264,7 +1293,7 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCa // TODO: How do expectedTypes play into this? Do they? TypePackId rets = arena->addTypePack(BlockedTypePack{}); TypePackId argPack = arena->addTypePack(TypePack{args, argTail}); - FunctionType ftv(TypeLevel{}, scope.get(), argPack, rets); + FunctionType ftv(TypeLevel{}, scope.get(), argPack, rets, std::nullopt, call->self); NotNull fcc = addConstraint(scope, call->func->location, FunctionCallConstraint{ @@ -1457,7 +1486,7 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprGlobal* gl Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIndexName* indexName) { TypeId obj = check(scope, indexName->expr).ty; - TypeId result = freshType(scope); + TypeId result = arena->addType(BlockedType{}); std::optional def = dfg->getDef(indexName); if (def) @@ -1468,13 +1497,7 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIndexName* scope->dcrRefinements[*def] = result; } - TableType::Props props{{indexName->index.value, Property{result}}}; - const std::optional indexer; - TableType ttv{std::move(props), indexer, TypeLevel{}, scope.get(), TableState::Free}; - - TypeId expectedTableType = arena->addType(std::move(ttv)); - - addConstraint(scope, indexName->expr->location, SubtypeConstraint{obj, expectedTableType}); + addConstraint(scope, indexName->expr->location, HasPropConstraint{result, obj, indexName->index.value}); if (def) return Inference{result, refinementArena.proposition(*def, builtinTypes->truthyType)}; @@ -1589,6 +1612,8 @@ std::tuple ConstraintGraphBuilder::checkBinary( else if (typeguard->type == "number") discriminantTy = builtinTypes->numberType; else if (typeguard->type == "boolean") + discriminantTy = builtinTypes->booleanType; + else if (typeguard->type == "thread") discriminantTy = builtinTypes->threadType; else if (typeguard->type == "table") discriminantTy = builtinTypes->tableType; @@ -1596,8 +1621,8 @@ std::tuple ConstraintGraphBuilder::checkBinary( discriminantTy = builtinTypes->functionType; else if (typeguard->type == "userdata") { - // For now, we don't really care about being accurate with userdata if the typeguard was using typeof - discriminantTy = builtinTypes->neverType; // TODO: replace with top class type + // For now, we don't really care about being accurate with userdata if the typeguard was using typeof. + discriminantTy = builtinTypes->classType; } else if (!typeguard->isTypeof && typeguard->type == "vector") discriminantTy = builtinTypes->neverType; // TODO: figure out a way to deal with this quirky type @@ -1649,18 +1674,15 @@ std::tuple ConstraintGraphBuilder::checkBinary( } } -TypePackId ConstraintGraphBuilder::checkLValues(const ScopePtr& scope, AstArray exprs) +std::vector ConstraintGraphBuilder::checkLValues(const ScopePtr& scope, AstArray exprs) { std::vector types; types.reserve(exprs.size); - for (size_t i = 0; i < exprs.size; ++i) - { - AstExpr* const expr = exprs.data[i]; + for (AstExpr* expr : exprs) types.push_back(checkLValue(scope, expr)); - } - return arena->addTypePack(std::move(types)); + return types; } /** @@ -1679,6 +1701,28 @@ TypeId ConstraintGraphBuilder::checkLValue(const ScopePtr& scope, AstExpr* expr) indexExpr->location, indexExpr->expr, syntheticIndex, constantString->location, indexExpr->expr->location.end, '.'}; return checkLValue(scope, &synthetic); } + + // An indexer is only interesting in an lvalue-ey way if it is at the + // tail of an expression. + // + // If the indexer is not at the tail, then we are not interested in + // augmenting the lhs data structure with a new indexer. Constraint + // generation can treat it as an ordinary lvalue. + // + // eg + // + // a.b.c[1] = 44 -- lvalue + // a.b[4].c = 2 -- rvalue + + TypeId resultType = arena->addType(BlockedType{}); + TypeId subjectType = check(scope, indexExpr->expr).ty; + TypeId indexType = check(scope, indexExpr->index).ty; + TypeId propType = arena->addType(BlockedType{}); + addConstraint(scope, expr->location, SetIndexerConstraint{resultType, subjectType, indexType, propType}); + + module->astTypes[expr] = propType; + + return propType; } else if (!expr->is()) return check(scope, expr).ty; @@ -1718,7 +1762,8 @@ TypeId ConstraintGraphBuilder::checkLValue(const ScopePtr& scope, AstExpr* expr) auto lookupResult = scope->lookupEx(sym); if (!lookupResult) return check(scope, expr).ty; - const auto [subjectType, symbolScope] = std::move(*lookupResult); + const auto [subjectBinding, symbolScope] = std::move(*lookupResult); + TypeId subjectType = subjectBinding->typeId; TypeId propTy = freshType(scope); @@ -1739,14 +1784,17 @@ TypeId ConstraintGraphBuilder::checkLValue(const ScopePtr& scope, AstExpr* expr) module->astTypes[expr] = prevSegmentTy; module->astTypes[e] = updatedType; - symbolScope->bindings[sym].typeId = updatedType; - - std::optional def = dfg->getDef(sym); - if (def) + if (!subjectType->persistent) { - // This can fail if the user is erroneously trying to augment a builtin - // table like os or string. - symbolScope->dcrRefinements[*def] = updatedType; + symbolScope->bindings[sym].typeId = updatedType; + + std::optional def = dfg->getDef(sym); + if (def) + { + // This can fail if the user is erroneously trying to augment a builtin + // table like os or string. + symbolScope->dcrRefinements[*def] = updatedType; + } } return propTy; @@ -1904,13 +1952,11 @@ ConstraintGraphBuilder::FunctionSignature ConstraintGraphBuilder::checkFunctionS for (const auto& [name, g] : genericDefinitions) { genericTypes.push_back(g.ty); - signatureScope->privateTypeBindings[name] = TypeFun{g.ty}; } for (const auto& [name, g] : genericPackDefinitions) { genericTypePacks.push_back(g.tp); - signatureScope->privateTypePackBindings[name] = g.tp; } // Local variable works around an odd gcc 11.3 warning: may be used uninitialized @@ -2023,15 +2069,14 @@ ConstraintGraphBuilder::FunctionSignature ConstraintGraphBuilder::checkFunctionS actualFunction.generics = std::move(genericTypes); actualFunction.genericPacks = std::move(genericTypePacks); actualFunction.argNames = std::move(argNames); + actualFunction.hasSelf = fn->self != nullptr; TypeId actualFunctionType = arena->addType(std::move(actualFunction)); LUAU_ASSERT(actualFunctionType); module->astTypes[fn] = actualFunctionType; if (expectedType && get(*expectedType)) - { - asMutable(*expectedType)->ty.emplace(actualFunctionType); - } + bindFreeType(*expectedType, actualFunctionType); return { /* signature */ actualFunctionType, @@ -2179,13 +2224,11 @@ TypeId ConstraintGraphBuilder::resolveType(const ScopePtr& scope, AstType* ty, b for (const auto& [name, g] : genericDefinitions) { genericTypes.push_back(g.ty); - signatureScope->privateTypeBindings[name] = TypeFun{g.ty}; } for (const auto& [name, g] : genericPackDefinitions) { genericTypePacks.push_back(g.tp); - signatureScope->privateTypePackBindings[name] = g.tp; } } else @@ -2330,7 +2373,7 @@ TypePackId ConstraintGraphBuilder::resolveTypePack(const ScopePtr& scope, const } std::vector> ConstraintGraphBuilder::createGenerics( - const ScopePtr& scope, AstArray generics, bool useCache) + const ScopePtr& scope, AstArray generics, bool useCache, bool addTypes) { std::vector> result; for (const auto& generic : generics) @@ -2350,6 +2393,9 @@ std::vector> ConstraintGraphBuilder::crea if (generic.defaultValue) defaultTy = resolveType(scope, generic.defaultValue, /* inTypeArguments */ false); + if (addTypes) + scope->privateTypeBindings[generic.name.value] = TypeFun{genericTy}; + result.push_back({generic.name.value, GenericTypeDefinition{genericTy, defaultTy}}); } @@ -2357,7 +2403,7 @@ std::vector> ConstraintGraphBuilder::crea } std::vector> ConstraintGraphBuilder::createGenericPacks( - const ScopePtr& scope, AstArray generics, bool useCache) + const ScopePtr& scope, AstArray generics, bool useCache, bool addTypes) { std::vector> result; for (const auto& generic : generics) @@ -2378,6 +2424,9 @@ std::vector> ConstraintGraphBuilder:: if (generic.defaultValue) defaultTy = resolveTypePack(scope, generic.defaultValue, /* inTypeArguments */ false); + if (addTypes) + scope->privateTypePackBindings[generic.name.value] = genericTy; + result.push_back({generic.name.value, GenericTypePackDefinition{genericTy, defaultTy}}); } @@ -2394,11 +2443,9 @@ Inference ConstraintGraphBuilder::flattenPack(const ScopePtr& scope, Location lo if (auto f = first(tp)) return Inference{*f, refinement}; - TypeId typeResult = freshType(scope); - TypePack onePack{{typeResult}, freshTypePack(scope)}; - TypePackId oneTypePack = arena->addTypePack(std::move(onePack)); - - addConstraint(scope, location, PackSubtypeConstraint{tp, oneTypePack}); + TypeId typeResult = arena->addType(BlockedType{}); + TypePackId resultPack = arena->addTypePack({typeResult}, arena->freshTypePack(scope.get())); + addConstraint(scope, location, UnpackConstraint{resultPack, tp}); return Inference{typeResult, refinement}; } diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index 879dac39b..96673e3dc 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -22,6 +22,22 @@ LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJson, false); namespace Luau { +size_t HashBlockedConstraintId::operator()(const BlockedConstraintId& bci) const +{ + size_t result = 0; + + if (const TypeId* ty = get_if(&bci)) + result = std::hash()(*ty); + else if (const TypePackId* tp = get_if(&bci)) + result = std::hash()(*tp); + else if (Constraint const* const* c = get_if(&bci)) + result = std::hash()(*c); + else + LUAU_ASSERT(!"Should be unreachable"); + + return result; +} + [[maybe_unused]] static void dumpBindings(NotNull scope, ToStringOptions& opts) { for (const auto& [k, v] : scope->bindings) @@ -221,10 +237,12 @@ void dump(ConstraintSolver* cs, ToStringOptions& opts) } ConstraintSolver::ConstraintSolver(NotNull normalizer, NotNull rootScope, std::vector> constraints, - ModuleName moduleName, NotNull moduleResolver, std::vector requireCycles, DcrLogger* logger) + ModuleName moduleName, NotNull reducer, NotNull moduleResolver, std::vector requireCycles, + DcrLogger* logger) : arena(normalizer->arena) , builtinTypes(normalizer->builtinTypes) , normalizer(normalizer) + , reducer(reducer) , constraints(std::move(constraints)) , rootScope(rootScope) , currentModuleName(std::move(moduleName)) @@ -326,6 +344,27 @@ void ConstraintSolver::run() if (force) printf("Force "); printf("Dispatched\n\t%s\n", saveMe.c_str()); + + if (force) + { + printf("Blocked on:\n"); + + for (const auto& [bci, cv] : blocked) + { + if (end(cv) == std::find(begin(cv), end(cv), c)) + continue; + + if (auto bty = get_if(&bci)) + printf("\tType %s\n", toString(*bty, opts).c_str()); + else if (auto btp = get_if(&bci)) + printf("\tPack %s\n", toString(*btp, opts).c_str()); + else if (auto cc = get_if(&bci)) + printf("\tCons %s\n", toString(**cc, opts).c_str()); + else + LUAU_ASSERT(!"Unreachable??"); + } + } + dump(this, opts); } } @@ -411,8 +450,12 @@ bool ConstraintSolver::tryDispatch(NotNull constraint, bool fo success = tryDispatch(*hpc, constraint); else if (auto spc = get(*constraint)) success = tryDispatch(*spc, constraint, force); + else if (auto spc = get(*constraint)) + success = tryDispatch(*spc, constraint, force); else if (auto sottc = get(*constraint)) success = tryDispatch(*sottc, constraint); + else if (auto uc = get(*constraint)) + success = tryDispatch(*uc, constraint); else LUAU_ASSERT(false); @@ -424,26 +467,46 @@ bool ConstraintSolver::tryDispatch(NotNull constraint, bool fo bool ConstraintSolver::tryDispatch(const SubtypeConstraint& c, NotNull constraint, bool force) { - if (!recursiveBlock(c.subType, constraint)) - return false; - if (!recursiveBlock(c.superType, constraint)) - return false; - if (isBlocked(c.subType)) return block(c.subType, constraint); else if (isBlocked(c.superType)) return block(c.superType, constraint); - unify(c.subType, c.superType, constraint->scope); + Unifier u{normalizer, Mode::Strict, constraint->scope, Location{}, Covariant}; + u.useScopes = true; + + u.tryUnify(c.subType, c.superType); + + if (!u.blockedTypes.empty() || !u.blockedTypePacks.empty()) + { + for (TypeId bt : u.blockedTypes) + block(bt, constraint); + for (TypePackId btp : u.blockedTypePacks) + block(btp, constraint); + return false; + } + + if (!u.errors.empty()) + { + TypeId errorType = errorRecoveryType(); + u.tryUnify(c.subType, errorType); + u.tryUnify(c.superType, errorType); + } + + const auto [changedTypes, changedPacks] = u.log.getChanges(); + + u.log.commit(); + + unblock(changedTypes); + unblock(changedPacks); + + // unify(c.subType, c.superType, constraint->scope); return true; } bool ConstraintSolver::tryDispatch(const PackSubtypeConstraint& c, NotNull constraint, bool force) { - if (!recursiveBlock(c.subPack, constraint) || !recursiveBlock(c.superPack, constraint)) - return false; - if (isBlocked(c.subPack)) return block(c.subPack, constraint); else if (isBlocked(c.superPack)) @@ -1183,8 +1246,26 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNulladdType(BlockedType{}); TypeId inferredTy = arena->addType(FunctionType{TypeLevel{}, constraint->scope.get(), argsPack, c.result}); - auto ic = pushConstraint(constraint->scope, constraint->location, InstantiationConstraint{instantiatedTy, fn}); - auto sc = pushConstraint(constraint->scope, constraint->location, SubtypeConstraint{instantiatedTy, inferredTy}); + auto pushConstraintGreedy = [this, constraint](ConstraintV cv) -> Constraint* { + std::unique_ptr c = std::make_unique(constraint->scope, constraint->location, std::move(cv)); + NotNull borrow{c.get()}; + + bool ok = tryDispatch(borrow, false); + if (ok) + return nullptr; + + solverConstraints.push_back(std::move(c)); + unsolvedConstraints.push_back(borrow); + + return borrow; + }; + + // HACK: We don't want other constraints to act on the free type pack + // created above until after these two constraints are solved, so we try to + // dispatch them directly. + + auto ic = pushConstraintGreedy(InstantiationConstraint{instantiatedTy, fn}); + auto sc = pushConstraintGreedy(SubtypeConstraint{instantiatedTy, inferredTy}); // Anything that is blocked on this constraint must also be blocked on our // synthesized constraints. @@ -1193,8 +1274,10 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNullsecond) { - block(ic, blockedConstraint); - block(sc, blockedConstraint); + if (ic) + block(NotNull{ic}, blockedConstraint); + if (sc) + block(NotNull{sc}, blockedConstraint); } } @@ -1230,6 +1313,8 @@ bool ConstraintSolver::tryDispatch(const HasPropConstraint& c, NotNullreduce(subjectType).value_or(subjectType); + std::optional resultType = lookupTableProp(subjectType, c.prop); if (!resultType) { @@ -1360,11 +1445,18 @@ bool ConstraintSolver::tryDispatch(const SetPropConstraint& c, NotNullscope); + if (!isBlocked(c.propType)) + unify(c.propType, *existingPropType, constraint->scope); bind(c.resultType, c.subjectType); return true; } + if (get(subjectType) || get(subjectType) || get(subjectType)) + { + bind(c.resultType, subjectType); + return true; + } + if (get(subjectType)) { TypeId ty = arena->freshType(constraint->scope); @@ -1381,21 +1473,27 @@ bool ConstraintSolver::tryDispatch(const SetPropConstraint& c, NotNull(subjectType)) { if (ttv->state == TableState::Free) { + LUAU_ASSERT(!subjectType->persistent); + ttv->props[c.path[0]] = Property{c.propType}; bind(c.resultType, c.subjectType); return true; } else if (ttv->state == TableState::Unsealed) { + LUAU_ASSERT(!subjectType->persistent); + std::optional augmented = updateTheTableType(NotNull{arena}, subjectType, c.path, c.propType); bind(c.resultType, augmented.value_or(subjectType)); + bind(subjectType, c.resultType); return true; } else @@ -1411,13 +1509,59 @@ bool ConstraintSolver::tryDispatch(const SetPropConstraint& c, NotNull(subjectType) || get(subjectType) || get(subjectType)) + + LUAU_ASSERT(0); + return true; +} + +bool ConstraintSolver::tryDispatch(const SetIndexerConstraint& c, NotNull constraint, bool force) +{ + TypeId subjectType = follow(c.subjectType); + if (isBlocked(subjectType)) + return block(subjectType, constraint); + + if (auto ft = get(subjectType)) { - bind(c.resultType, subjectType); + Scope* scope = ft->scope; + TableType* tt = &asMutable(subjectType)->ty.emplace(TableState::Free, TypeLevel{}, scope); + tt->indexer = TableIndexer{c.indexType, c.propType}; + + asMutable(c.resultType)->ty.emplace(subjectType); + asMutable(c.propType)->ty.emplace(scope); + unblock(c.propType); + unblock(c.resultType); + return true; } + else if (auto tt = get(subjectType)) + { + if (tt->indexer) + { + // TODO This probably has to be invariant. + unify(c.indexType, tt->indexer->indexType, constraint->scope); + asMutable(c.propType)->ty.emplace(tt->indexer->indexResultType); + asMutable(c.resultType)->ty.emplace(subjectType); + unblock(c.propType); + unblock(c.resultType); + return true; + } + else if (tt->state == TableState::Free || tt->state == TableState::Unsealed) + { + auto mtt = getMutable(subjectType); + mtt->indexer = TableIndexer{c.indexType, c.propType}; + asMutable(c.propType)->ty.emplace(tt->scope); + asMutable(c.resultType)->ty.emplace(subjectType); + unblock(c.propType); + unblock(c.resultType); + return true; + } + // Do not augment sealed or generic tables that lack indexers + } - LUAU_ASSERT(0); + asMutable(c.propType)->ty.emplace(builtinTypes->errorRecoveryType()); + asMutable(c.resultType)->ty.emplace(builtinTypes->errorRecoveryType()); + unblock(c.propType); + unblock(c.resultType); return true; } @@ -1439,6 +1583,69 @@ bool ConstraintSolver::tryDispatch(const SingletonOrTopTypeConstraint& c, NotNul return true; } +bool ConstraintSolver::tryDispatch(const UnpackConstraint& c, NotNull constraint) +{ + TypePackId sourcePack = follow(c.sourcePack); + TypePackId resultPack = follow(c.resultPack); + + if (isBlocked(sourcePack)) + return block(sourcePack, constraint); + + if (isBlocked(resultPack)) + { + asMutable(resultPack)->ty.emplace(sourcePack); + unblock(resultPack); + return true; + } + + TypePack srcPack = extendTypePack(*arena, builtinTypes, sourcePack, size(resultPack)); + + auto destIter = begin(resultPack); + auto destEnd = end(resultPack); + + size_t i = 0; + while (destIter != destEnd) + { + if (i >= srcPack.head.size()) + break; + TypeId srcTy = follow(srcPack.head[i]); + + if (isBlocked(*destIter)) + { + if (follow(srcTy) == *destIter) + { + // Cyclic type dependency. (????) + asMutable(*destIter)->ty.emplace(constraint->scope); + } + else + asMutable(*destIter)->ty.emplace(srcTy); + unblock(*destIter); + } + else + unify(*destIter, srcTy, constraint->scope); + + ++destIter; + ++i; + } + + // We know that resultPack does not have a tail, but we don't know if + // sourcePack is long enough to fill every value. Replace every remaining + // result TypeId with the error recovery type. + + while (destIter != destEnd) + { + if (isBlocked(*destIter)) + { + asMutable(*destIter)->ty.emplace(builtinTypes->errorRecoveryType()); + unblock(*destIter); + } + + ++destIter; + } + + return true; +} + bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const IterableConstraint& c, NotNull constraint, bool force) { auto block_ = [&](auto&& t) { @@ -1628,10 +1835,20 @@ bool ConstraintSolver::tryDispatchIterableFunction( std::optional ConstraintSolver::lookupTableProp(TypeId subjectType, const std::string& propName) { + std::unordered_set seen; + return lookupTableProp(subjectType, propName, seen); +} + +std::optional ConstraintSolver::lookupTableProp(TypeId subjectType, const std::string& propName, std::unordered_set& seen) +{ + if (!seen.insert(subjectType).second) + return std::nullopt; + auto collectParts = [&](auto&& unionOrIntersection) -> std::pair, std::vector> { std::optional blocked; std::vector parts; + std::vector freeParts; for (TypeId expectedPart : unionOrIntersection) { expectedPart = follow(expectedPart); @@ -1644,6 +1861,29 @@ std::optional ConstraintSolver::lookupTableProp(TypeId subjectType, cons else if (ttv->indexer && maybeString(ttv->indexer->indexType)) parts.push_back(ttv->indexer->indexResultType); } + else if (get(expectedPart)) + { + freeParts.push_back(expectedPart); + } + } + + // If the only thing resembling a match is a single fresh type, we can + // confidently tablify it. If other types match or if there are more + // than one free type, we can't do anything. + if (parts.empty() && 1 == freeParts.size()) + { + TypeId freePart = freeParts.front(); + const FreeType* ft = get(freePart); + LUAU_ASSERT(ft); + Scope* scope = ft->scope; + + TableType* tt = &asMutable(freePart)->ty.emplace(); + tt->state = TableState::Free; + tt->scope = scope; + TypeId propType = arena->freshType(scope); + tt->props[propName] = Property{propType}; + + parts.push_back(propType); } return {blocked, parts}; @@ -1651,12 +1891,75 @@ std::optional ConstraintSolver::lookupTableProp(TypeId subjectType, cons std::optional resultType; - if (auto ttv = get(subjectType)) + if (get(subjectType) || get(subjectType)) + { + return subjectType; + } + else if (auto ttv = getMutable(subjectType)) { if (auto prop = ttv->props.find(propName); prop != ttv->props.end()) resultType = prop->second.type; else if (ttv->indexer && maybeString(ttv->indexer->indexType)) resultType = ttv->indexer->indexResultType; + else if (ttv->state == TableState::Free) + { + resultType = arena->addType(FreeType{ttv->scope}); + ttv->props[propName] = Property{*resultType}; + } + } + else if (auto mt = get(subjectType)) + { + if (auto p = lookupTableProp(mt->table, propName, seen)) + return p; + + TypeId mtt = follow(mt->metatable); + + if (get(mtt)) + return mtt; + else if (auto metatable = get(mtt)) + { + auto indexProp = metatable->props.find("__index"); + if (indexProp == metatable->props.end()) + return std::nullopt; + + // TODO: __index can be an overloaded function. + + TypeId indexType = follow(indexProp->second.type); + + if (auto ft = get(indexType)) + { + std::optional ret = first(ft->retTypes); + if (ret) + return *ret; + else + return std::nullopt; + } + + return lookupTableProp(indexType, propName, seen); + } + } + else if (auto ct = get(subjectType)) + { + while (ct) + { + if (auto prop = ct->props.find(propName); prop != ct->props.end()) + return prop->second.type; + else if (ct->parent) + ct = get(follow(*ct->parent)); + else + break; + } + } + else if (auto pt = get(subjectType); pt && pt->metatable) + { + const TableType* metatable = get(follow(*pt->metatable)); + LUAU_ASSERT(metatable); + + auto indexProp = metatable->props.find("__index"); + if (indexProp == metatable->props.end()) + return std::nullopt; + + return lookupTableProp(indexProp->second.type, propName, seen); } else if (auto utv = get(subjectType)) { @@ -1704,7 +2007,7 @@ void ConstraintSolver::block(NotNull target, NotNull constraint) @@ -1715,7 +2018,7 @@ bool ConstraintSolver::block(TypeId target, NotNull constraint if (FFlag::DebugLuauLogSolver) printf("block TypeId %s on\t%s\n", toString(target, opts).c_str(), toString(*constraint, opts).c_str()); - block_(target, constraint); + block_(follow(target), constraint); return false; } @@ -1802,7 +2105,7 @@ void ConstraintSolver::unblock(NotNull progressed) if (FFlag::DebugLuauLogSolverToJson) logger->popBlock(progressed); - return unblock_(progressed); + return unblock_(progressed.get()); } void ConstraintSolver::unblock(TypeId progressed) @@ -1810,7 +2113,10 @@ void ConstraintSolver::unblock(TypeId progressed) if (FFlag::DebugLuauLogSolverToJson) logger->popBlock(progressed); - return unblock_(progressed); + unblock_(progressed); + + if (auto bt = get(progressed)) + unblock(bt->boundTo); } void ConstraintSolver::unblock(TypePackId progressed) diff --git a/Analysis/src/DcrLogger.cpp b/Analysis/src/DcrLogger.cpp index a1ef650b8..9f66b022a 100644 --- a/Analysis/src/DcrLogger.cpp +++ b/Analysis/src/DcrLogger.cpp @@ -9,17 +9,39 @@ namespace Luau { +template +static std::string toPointerId(const T* ptr) +{ + return std::to_string(reinterpret_cast(ptr)); +} + +static std::string toPointerId(NotNull ptr) +{ + return std::to_string(reinterpret_cast(ptr.get())); +} + namespace Json { +template +void write(JsonEmitter& emitter, const T* ptr) +{ + write(emitter, toPointerId(ptr)); +} + +void write(JsonEmitter& emitter, NotNull ptr) +{ + write(emitter, toPointerId(ptr)); +} + void write(JsonEmitter& emitter, const Location& location) { - ObjectEmitter o = emitter.writeObject(); - o.writePair("beginLine", location.begin.line); - o.writePair("beginColumn", location.begin.column); - o.writePair("endLine", location.end.line); - o.writePair("endColumn", location.end.column); - o.finish(); + ArrayEmitter a = emitter.writeArray(); + a.writeValue(location.begin.line); + a.writeValue(location.begin.column); + a.writeValue(location.end.line); + a.writeValue(location.end.column); + a.finish(); } void write(JsonEmitter& emitter, const ErrorSnapshot& snapshot) @@ -47,24 +69,43 @@ void write(JsonEmitter& emitter, const TypeBindingSnapshot& snapshot) o.finish(); } -void write(JsonEmitter& emitter, const ConstraintGenerationLog& log) +template +void write(JsonEmitter& emitter, const DenseHashMap& map) { ObjectEmitter o = emitter.writeObject(); - o.writePair("source", log.source); + for (const auto& [k, v] : map) + o.writePair(toPointerId(k), v); + o.finish(); +} - emitter.writeComma(); - write(emitter, "constraintLocations"); - emitter.writeRaw(":"); +void write(JsonEmitter& emitter, const ExprTypesAtLocation& tys) +{ + ObjectEmitter o = emitter.writeObject(); + o.writePair("location", tys.location); + o.writePair("ty", toPointerId(tys.ty)); - ObjectEmitter locationEmitter = emitter.writeObject(); + if (tys.expectedTy) + o.writePair("expectedTy", toPointerId(*tys.expectedTy)); - for (const auto& [id, location] : log.constraintLocations) - { - locationEmitter.writePair(id, location); - } + o.finish(); +} + +void write(JsonEmitter& emitter, const AnnotationTypesAtLocation& tys) +{ + ObjectEmitter o = emitter.writeObject(); + o.writePair("location", tys.location); + o.writePair("resolvedTy", toPointerId(tys.resolvedTy)); + o.finish(); +} - locationEmitter.finish(); +void write(JsonEmitter& emitter, const ConstraintGenerationLog& log) +{ + ObjectEmitter o = emitter.writeObject(); + o.writePair("source", log.source); o.writePair("errors", log.errors); + o.writePair("exprTypeLocations", log.exprTypeLocations); + o.writePair("annotationTypeLocations", log.annotationTypeLocations); + o.finish(); } @@ -78,26 +119,34 @@ void write(JsonEmitter& emitter, const ScopeSnapshot& snapshot) o.finish(); } -void write(JsonEmitter& emitter, const ConstraintBlockKind& kind) -{ - switch (kind) - { - case ConstraintBlockKind::TypeId: - return write(emitter, "type"); - case ConstraintBlockKind::TypePackId: - return write(emitter, "typePack"); - case ConstraintBlockKind::ConstraintId: - return write(emitter, "constraint"); - default: - LUAU_ASSERT(0); - } -} - void write(JsonEmitter& emitter, const ConstraintBlock& block) { ObjectEmitter o = emitter.writeObject(); - o.writePair("kind", block.kind); o.writePair("stringification", block.stringification); + + auto go = [&o](auto&& t) { + using T = std::decay_t; + + o.writePair("id", toPointerId(t)); + + if constexpr (std::is_same_v) + { + o.writePair("kind", "type"); + } + else if constexpr (std::is_same_v) + { + o.writePair("kind", "typePack"); + } + else if constexpr (std::is_same_v>) + { + o.writePair("kind", "constraint"); + } + else + static_assert(always_false_v, "non-exhaustive possibility switch"); + }; + + visit(go, block.target); + o.finish(); } @@ -114,7 +163,8 @@ void write(JsonEmitter& emitter, const BoundarySnapshot& snapshot) { ObjectEmitter o = emitter.writeObject(); o.writePair("rootScope", snapshot.rootScope); - o.writePair("constraints", snapshot.constraints); + o.writePair("unsolvedConstraints", snapshot.unsolvedConstraints); + o.writePair("typeStrings", snapshot.typeStrings); o.finish(); } @@ -125,6 +175,7 @@ void write(JsonEmitter& emitter, const StepSnapshot& snapshot) o.writePair("forced", snapshot.forced); o.writePair("unsolvedConstraints", snapshot.unsolvedConstraints); o.writePair("rootScope", snapshot.rootScope); + o.writePair("typeStrings", snapshot.typeStrings); o.finish(); } @@ -146,11 +197,6 @@ void write(JsonEmitter& emitter, const TypeCheckLog& log) } // namespace Json -static std::string toPointerId(NotNull ptr) -{ - return std::to_string(reinterpret_cast(ptr.get())); -} - static ScopeSnapshot snapshotScope(const Scope* scope, ToStringOptions& opts) { std::unordered_map bindings; @@ -230,6 +276,32 @@ void DcrLogger::captureSource(std::string source) generationLog.source = std::move(source); } +void DcrLogger::captureGenerationModule(const ModulePtr& module) +{ + generationLog.exprTypeLocations.reserve(module->astTypes.size()); + for (const auto& [expr, ty] : module->astTypes) + { + ExprTypesAtLocation tys; + tys.location = expr->location; + tys.ty = ty; + + if (auto expectedTy = module->astExpectedTypes.find(expr)) + tys.expectedTy = *expectedTy; + + generationLog.exprTypeLocations.push_back(tys); + } + + generationLog.annotationTypeLocations.reserve(module->astResolvedTypes.size()); + for (const auto& [annot, ty] : module->astResolvedTypes) + { + AnnotationTypesAtLocation tys; + tys.location = annot->location; + tys.resolvedTy = ty; + + generationLog.annotationTypeLocations.push_back(tys); + } +} + void DcrLogger::captureGenerationError(const TypeError& error) { std::string stringifiedError = toString(error); @@ -239,12 +311,6 @@ void DcrLogger::captureGenerationError(const TypeError& error) }); } -void DcrLogger::captureConstraintLocation(NotNull constraint, Location location) -{ - std::string id = toPointerId(constraint); - generationLog.constraintLocations[id] = location; -} - void DcrLogger::pushBlock(NotNull constraint, TypeId block) { constraintBlocks[constraint].push_back(block); @@ -284,44 +350,70 @@ void DcrLogger::popBlock(NotNull block) } } -void DcrLogger::captureInitialSolverState(const Scope* rootScope, const std::vector>& unsolvedConstraints) +static void snapshotTypeStrings(const std::vector& interestedExprs, + const std::vector& interestedAnnots, DenseHashMap& map, ToStringOptions& opts) +{ + for (const ExprTypesAtLocation& tys : interestedExprs) + { + map[tys.ty] = toString(tys.ty, opts); + + if (tys.expectedTy) + map[*tys.expectedTy] = toString(*tys.expectedTy, opts); + } + + for (const AnnotationTypesAtLocation& tys : interestedAnnots) + { + map[tys.resolvedTy] = toString(tys.resolvedTy, opts); + } +} + +void DcrLogger::captureBoundaryState( + BoundarySnapshot& target, const Scope* rootScope, const std::vector>& unsolvedConstraints) { - solveLog.initialState.rootScope = snapshotScope(rootScope, opts); - solveLog.initialState.constraints.clear(); + target.rootScope = snapshotScope(rootScope, opts); + target.unsolvedConstraints.clear(); for (NotNull c : unsolvedConstraints) { - std::string id = toPointerId(c); - solveLog.initialState.constraints[id] = { + target.unsolvedConstraints[c.get()] = { toString(*c.get(), opts), c->location, snapshotBlocks(c), }; } + + snapshotTypeStrings(generationLog.exprTypeLocations, generationLog.annotationTypeLocations, target.typeStrings, opts); +} + +void DcrLogger::captureInitialSolverState(const Scope* rootScope, const std::vector>& unsolvedConstraints) +{ + captureBoundaryState(solveLog.initialState, rootScope, unsolvedConstraints); } StepSnapshot DcrLogger::prepareStepSnapshot( const Scope* rootScope, NotNull current, bool force, const std::vector>& unsolvedConstraints) { ScopeSnapshot scopeSnapshot = snapshotScope(rootScope, opts); - std::string currentId = toPointerId(current); - std::unordered_map constraints; + DenseHashMap constraints{nullptr}; for (NotNull c : unsolvedConstraints) { - std::string id = toPointerId(c); - constraints[id] = { + constraints[c.get()] = { toString(*c.get(), opts), c->location, snapshotBlocks(c), }; } + DenseHashMap typeStrings{nullptr}; + snapshotTypeStrings(generationLog.exprTypeLocations, generationLog.annotationTypeLocations, typeStrings, opts); + return StepSnapshot{ - currentId, + current, force, - constraints, + std::move(constraints), scopeSnapshot, + std::move(typeStrings), }; } @@ -332,18 +424,7 @@ void DcrLogger::commitStepSnapshot(StepSnapshot snapshot) void DcrLogger::captureFinalSolverState(const Scope* rootScope, const std::vector>& unsolvedConstraints) { - solveLog.finalState.rootScope = snapshotScope(rootScope, opts); - solveLog.finalState.constraints.clear(); - - for (NotNull c : unsolvedConstraints) - { - std::string id = toPointerId(c); - solveLog.finalState.constraints[id] = { - toString(*c.get(), opts), - c->location, - snapshotBlocks(c), - }; - } + captureBoundaryState(solveLog.finalState, rootScope, unsolvedConstraints); } void DcrLogger::captureTypeCheckError(const TypeError& error) @@ -370,21 +451,21 @@ std::vector DcrLogger::snapshotBlocks(NotNull if (const TypeId* ty = get_if(&target)) { snapshot.push_back({ - ConstraintBlockKind::TypeId, + *ty, toString(*ty, opts), }); } else if (const TypePackId* tp = get_if(&target)) { snapshot.push_back({ - ConstraintBlockKind::TypePackId, + *tp, toString(*tp, opts), }); } else if (const NotNull* c = get_if>(&target)) { snapshot.push_back({ - ConstraintBlockKind::ConstraintId, + *c, toString(*(c->get()), opts), }); } diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index fb61b4ab3..91c72e447 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -899,8 +899,8 @@ ModulePtr check( cgb.visit(sourceModule.root); result->errors = std::move(cgb.errors); - ConstraintSolver cs{NotNull{&normalizer}, NotNull(cgb.rootScope), borrowConstraints(cgb.constraints), sourceModule.name, moduleResolver, - requireCycles, logger.get()}; + ConstraintSolver cs{NotNull{&normalizer}, NotNull(cgb.rootScope), borrowConstraints(cgb.constraints), sourceModule.name, + NotNull{result->reduction.get()}, moduleResolver, requireCycles, logger.get()}; if (options.randomizeConstraintResolutionSeed) cs.randomize(*options.randomizeConstraintResolutionSeed); diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp index a7b2b7276..0b7608104 100644 --- a/Analysis/src/Normalize.cpp +++ b/Analysis/src/Normalize.cpp @@ -1441,6 +1441,8 @@ bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, int ignor if (!unionNormals(here, *tn)) return false; } + else if (get(there)) + LUAU_ASSERT(!"Internal error: Trying to normalize a BlockedType"); else LUAU_ASSERT(!"Unreachable"); diff --git a/Analysis/src/Quantify.cpp b/Analysis/src/Quantify.cpp index aac7864a8..845ae3a36 100644 --- a/Analysis/src/Quantify.cpp +++ b/Analysis/src/Quantify.cpp @@ -183,7 +183,7 @@ struct PureQuantifier : Substitution else if (ttv->state == TableState::Generic) seenGenericType = true; - return ttv->state == TableState::Unsealed || (ttv->state == TableState::Free && subsumes(scope, ttv->scope)); + return (ttv->state == TableState::Unsealed || ttv->state == TableState::Free) && subsumes(scope, ttv->scope); } return false; diff --git a/Analysis/src/Scope.cpp b/Analysis/src/Scope.cpp index 84925f790..cac72124e 100644 --- a/Analysis/src/Scope.cpp +++ b/Analysis/src/Scope.cpp @@ -31,12 +31,12 @@ std::optional Scope::lookup(Symbol sym) const { auto r = const_cast(this)->lookupEx(sym); if (r) - return r->first; + return r->first->typeId; else return std::nullopt; } -std::optional> Scope::lookupEx(Symbol sym) +std::optional> Scope::lookupEx(Symbol sym) { Scope* s = this; @@ -44,7 +44,7 @@ std::optional> Scope::lookupEx(Symbol sym) { auto it = s->bindings.find(sym); if (it != s->bindings.end()) - return std::pair{it->second.typeId, s}; + return std::pair{&it->second, s}; if (s->parent) s = s->parent.get(); diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index 1972177cf..d0c539845 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -1533,6 +1533,10 @@ std::string toString(const Constraint& constraint, ToStringOptions& opts) const std::string pathStr = c.path.size() == 1 ? "\"" + c.path[0] + "\"" : "[\"" + join(c.path, "\", \"") + "\"]"; return tos(c.resultType) + " ~ setProp " + tos(c.subjectType) + ", " + pathStr + " " + tos(c.propType); } + else if constexpr (std::is_same_v) + { + return tos(c.resultType) + " ~ setIndexer " + tos(c.subjectType) + " [ " + tos(c.indexType) + " ] " + tos(c.propType); + } else if constexpr (std::is_same_v) { std::string result = tos(c.resultType); @@ -1543,6 +1547,8 @@ std::string toString(const Constraint& constraint, ToStringOptions& opts) else return result + " ~ if isSingleton D then D else unknown where D = " + discriminant; } + else if constexpr (std::is_same_v) + return tos(c.resultPack) + " ~ unpack " + tos(c.sourcePack); else static_assert(always_false_v, "Non-exhaustive constraint switch"); }; diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index 4322a0daa..f23fad780 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -4,6 +4,7 @@ #include "Luau/Ast.h" #include "Luau/AstQuery.h" #include "Luau/Clone.h" +#include "Luau/Common.h" #include "Luau/DcrLogger.h" #include "Luau/Error.h" #include "Luau/Instantiation.h" @@ -329,11 +330,12 @@ struct TypeChecker2 for (size_t i = 0; i < count; ++i) { AstExpr* value = i < local->values.size ? local->values.data[i] : nullptr; + const bool isPack = value && (value->is() || value->is()); if (value) visit(value, RValue); - if (i != local->values.size - 1 || value) + if (i != local->values.size - 1 || !isPack) { AstLocal* var = i < local->vars.size ? local->vars.data[i] : nullptr; @@ -351,16 +353,19 @@ struct TypeChecker2 visit(var->annotation); } } - else + else if (value) { - LUAU_ASSERT(value); + TypePackId valuePack = lookupPack(value); + TypePack valueTypes; + if (i < local->vars.size) + valueTypes = extendTypePack(module->internalTypes, builtinTypes, valuePack, local->vars.size - i); - TypePackId valueTypes = lookupPack(value); - auto it = begin(valueTypes); + Location errorLocation; for (size_t j = i; j < local->vars.size; ++j) { - if (it == end(valueTypes)) + if (j - i >= valueTypes.head.size()) { + errorLocation = local->vars.data[j]->location; break; } @@ -368,14 +373,28 @@ struct TypeChecker2 if (var->annotation) { TypeId varType = lookupAnnotation(var->annotation); - ErrorVec errors = tryUnify(stack.back(), value->location, *it, varType); + ErrorVec errors = tryUnify(stack.back(), value->location, valueTypes.head[j - i], varType); if (!errors.empty()) reportErrors(std::move(errors)); visit(var->annotation); } + } - ++it; + if (valueTypes.head.size() < local->vars.size - i) + { + reportError( + CountMismatch{ + // We subtract 1 here because the final AST + // expression is not worth one value. It is worth 0 + // or more depending on valueTypes.head + local->values.size - 1 + valueTypes.head.size(), + std::nullopt, + local->vars.size, + local->values.data[local->values.size - 1]->is() ? CountMismatch::FunctionResult + : CountMismatch::ExprListResult, + }, + errorLocation); } } } @@ -810,6 +829,95 @@ struct TypeChecker2 // TODO! } + ErrorVec visitOverload(AstExprCall* call, NotNull overloadFunctionType, const std::vector& argLocs, + TypePackId expectedArgTypes, TypePackId expectedRetType) + { + ErrorVec overloadErrors = + tryUnify(stack.back(), call->location, overloadFunctionType->retTypes, expectedRetType, CountMismatch::FunctionResult); + + size_t argIndex = 0; + auto inferredArgIt = begin(overloadFunctionType->argTypes); + auto expectedArgIt = begin(expectedArgTypes); + while (inferredArgIt != end(overloadFunctionType->argTypes) && expectedArgIt != end(expectedArgTypes)) + { + Location argLoc = (argIndex >= argLocs.size()) ? argLocs.back() : argLocs[argIndex]; + ErrorVec argErrors = tryUnify(stack.back(), argLoc, *expectedArgIt, *inferredArgIt); + for (TypeError e : argErrors) + overloadErrors.emplace_back(e); + + ++argIndex; + ++inferredArgIt; + ++expectedArgIt; + } + + // piggyback on the unifier for arity checking, but we can't do this for checking the actual arguments since the locations would be bad + ErrorVec argumentErrors = tryUnify(stack.back(), call->location, expectedArgTypes, overloadFunctionType->argTypes); + for (TypeError e : argumentErrors) + if (get(e) != nullptr) + overloadErrors.emplace_back(std::move(e)); + + return overloadErrors; + } + + void reportOverloadResolutionErrors(AstExprCall* call, std::vector overloads, TypePackId expectedArgTypes, + const std::vector& overloadsThatMatchArgCount, std::vector> overloadsErrors) + { + if (overloads.size() == 1) + { + reportErrors(std::get<0>(overloadsErrors.front())); + return; + } + + std::vector overloadTypes = overloadsThatMatchArgCount; + if (overloadsThatMatchArgCount.size() == 0) + { + reportError(GenericError{"No overload for function accepts " + std::to_string(size(expectedArgTypes)) + " arguments."}, call->location); + // If no overloads match argument count, just list all overloads. + overloadTypes = overloads; + } + else + { + // Report errors of the first argument-count-matching, but failing overload + TypeId overload = overloadsThatMatchArgCount[0]; + + // Remove the overload we are reporting errors about from the list of alternatives + overloadTypes.erase(std::remove(overloadTypes.begin(), overloadTypes.end(), overload), overloadTypes.end()); + + const FunctionType* ftv = get(overload); + LUAU_ASSERT(ftv); // overload must be a function type here + + auto error = std::find_if(overloadsErrors.begin(), overloadsErrors.end(), [ftv](const std::pair& e) { + return ftv == std::get<1>(e); + }); + + LUAU_ASSERT(error != overloadsErrors.end()); + reportErrors(std::get<0>(*error)); + + // If only one overload matched, we don't need this error because we provided the previous errors. + if (overloadsThatMatchArgCount.size() == 1) + return; + } + + std::string s; + for (size_t i = 0; i < overloadTypes.size(); ++i) + { + TypeId overload = follow(overloadTypes[i]); + + if (i > 0) + s += "; "; + + if (i > 0 && i == overloadTypes.size() - 1) + s += "and "; + + s += toString(overload); + } + + if (overloadsThatMatchArgCount.size() == 0) + reportError(ExtraInformation{"Available overloads: " + s}, call->func->location); + else + reportError(ExtraInformation{"Other overloads are also not viable: " + s}, call->func->location); + } + void visit(AstExprCall* call) { visit(call->func, RValue); @@ -865,6 +973,10 @@ struct TypeChecker2 return; } } + else if (auto itv = get(functionType)) + { + // We do nothing here because we'll flatten the intersection later, but we don't want to report it as a non-function. + } else if (auto utv = get(functionType)) { // Sometimes it's okay to call a union of functions, but only if all of the functions are the same. @@ -930,48 +1042,105 @@ struct TypeChecker2 TypePackId expectedArgTypes = arena->addTypePack(args); - const FunctionType* inferredFunctionType = get(testFunctionType); - LUAU_ASSERT(inferredFunctionType); // testFunctionType should always be a FunctionType here + std::vector overloads = flattenIntersection(testFunctionType); + std::vector> overloadsErrors; + overloadsErrors.reserve(overloads.size()); - size_t argIndex = 0; - auto inferredArgIt = begin(inferredFunctionType->argTypes); - auto expectedArgIt = begin(expectedArgTypes); - while (inferredArgIt != end(inferredFunctionType->argTypes) && expectedArgIt != end(expectedArgTypes)) + std::vector overloadsThatMatchArgCount; + + for (TypeId overload : overloads) { - Location argLoc = (argIndex >= argLocs.size()) ? argLocs.back() : argLocs[argIndex]; - reportErrors(tryUnify(stack.back(), argLoc, *expectedArgIt, *inferredArgIt)); + overload = follow(overload); - ++argIndex; - ++inferredArgIt; - ++expectedArgIt; - } + const FunctionType* overloadFn = get(overload); + if (!overloadFn) + { + reportError(CannotCallNonFunction{overload}, call->func->location); + return; + } + else + { + // We may have to instantiate the overload in order for it to typecheck. + if (std::optional instantiatedFunctionType = instantiation.substitute(overload)) + { + overloadFn = get(*instantiatedFunctionType); + } + else + { + overloadsErrors.emplace_back(std::vector{TypeError{call->func->location, UnificationTooComplex{}}}, overloadFn); + return; + } + } - // piggyback on the unifier for arity checking, but we can't do this for checking the actual arguments since the locations would be bad - ErrorVec errors = tryUnify(stack.back(), call->location, expectedArgTypes, inferredFunctionType->argTypes); - for (TypeError e : errors) - if (get(e) != nullptr) - reportError(std::move(e)); + ErrorVec overloadErrors = visitOverload(call, NotNull{overloadFn}, argLocs, expectedArgTypes, expectedRetType); + if (overloadErrors.empty()) + return; + + bool argMismatch = false; + for (auto error : overloadErrors) + { + CountMismatch* cm = get(error); + if (!cm) + continue; + + if (cm->context == CountMismatch::Arg) + { + argMismatch = true; + break; + } + } - reportErrors(tryUnify(stack.back(), call->location, inferredFunctionType->retTypes, expectedRetType, CountMismatch::FunctionResult)); + if (!argMismatch) + overloadsThatMatchArgCount.push_back(overload); + + overloadsErrors.emplace_back(std::move(overloadErrors), overloadFn); + } + + reportOverloadResolutionErrors(call, overloads, expectedArgTypes, overloadsThatMatchArgCount, overloadsErrors); } - void visit(AstExprIndexName* indexName, ValueContext context) + void visitExprName(AstExpr* expr, Location location, const std::string& propName, ValueContext context) { - visit(indexName->expr, RValue); + visit(expr, RValue); - TypeId leftType = lookupType(indexName->expr); + TypeId leftType = lookupType(expr); const NormalizedType* norm = normalizer.normalize(leftType); if (!norm) - reportError(NormalizationTooComplex{}, indexName->indexLocation); + reportError(NormalizationTooComplex{}, location); - checkIndexTypeFromType(leftType, *norm, indexName->index.value, indexName->location, context); + checkIndexTypeFromType(leftType, *norm, propName, location, context); + } + + void visit(AstExprIndexName* indexName, ValueContext context) + { + visitExprName(indexName->expr, indexName->location, indexName->index.value, context); } void visit(AstExprIndexExpr* indexExpr, ValueContext context) { + if (auto str = indexExpr->index->as()) + { + const std::string stringValue(str->value.data, str->value.size); + visitExprName(indexExpr->expr, indexExpr->location, stringValue, context); + return; + } + // TODO! visit(indexExpr->expr, LValue); visit(indexExpr->index, RValue); + + NotNull scope = stack.back(); + + TypeId exprType = lookupType(indexExpr->expr); + TypeId indexType = lookupType(indexExpr->index); + + if (auto tt = get(exprType)) + { + if (tt->indexer) + reportErrors(tryUnify(scope, indexExpr->index->location, indexType, tt->indexer->indexType)); + else + reportError(CannotExtendTable{exprType, CannotExtendTable::Indexer, "indexer??"}, indexExpr->location); + } } void visit(AstExprFunction* fn) @@ -1879,8 +2048,17 @@ struct TypeChecker2 ty = *mtIndex; } - if (getTableType(ty)) - return bool(findTablePropertyRespectingMeta(builtinTypes, module->errors, ty, prop, location)); + if (auto tt = getTableType(ty)) + { + if (findTablePropertyRespectingMeta(builtinTypes, module->errors, ty, prop, location)) + return true; + + else if (tt->indexer && isPrim(tt->indexer->indexResultType, PrimitiveType::String)) + return tt->indexer->indexResultType; + + else + return false; + } else if (const ClassType* cls = get(ty)) return bool(lookupClassProp(cls, prop)); else if (const UnionType* utv = get(ty)) diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index e59c7e0ee..adca034c4 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -1759,7 +1759,7 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp if (FInt::LuauCheckRecursionLimit > 0 && checkRecursionCount >= FInt::LuauCheckRecursionLimit) { reportErrorCodeTooComplex(expr.location); - return {errorRecoveryType(scope)}; + return WithPredicate{errorRecoveryType(scope)}; } WithPredicate result; @@ -1767,23 +1767,23 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp if (auto a = expr.as()) result = checkExpr(scope, *a->expr, expectedType); else if (expr.is()) - result = {nilType}; + result = WithPredicate{nilType}; else if (const AstExprConstantBool* bexpr = expr.as()) { if (forceSingleton || (expectedType && maybeSingleton(*expectedType))) - result = {singletonType(bexpr->value)}; + result = WithPredicate{singletonType(bexpr->value)}; else - result = {booleanType}; + result = WithPredicate{booleanType}; } else if (const AstExprConstantString* sexpr = expr.as()) { if (forceSingleton || (expectedType && maybeSingleton(*expectedType))) - result = {singletonType(std::string(sexpr->value.data, sexpr->value.size))}; + result = WithPredicate{singletonType(std::string(sexpr->value.data, sexpr->value.size))}; else - result = {stringType}; + result = WithPredicate{stringType}; } else if (expr.is()) - result = {numberType}; + result = WithPredicate{numberType}; else if (auto a = expr.as()) result = checkExpr(scope, *a); else if (auto a = expr.as()) @@ -1837,7 +1837,7 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp // TODO: tempting to ice here, but this breaks very often because our toposort doesn't enforce this constraint // ice("AstExprLocal exists but no binding definition for it?", expr.location); reportError(TypeError{expr.location, UnknownSymbol{expr.local->name.value, UnknownSymbol::Binding}}); - return {errorRecoveryType(scope)}; + return WithPredicate{errorRecoveryType(scope)}; } WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprGlobal& expr) @@ -1849,7 +1849,7 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp return {*ty, {TruthyPredicate{std::move(*lvalue), expr.location}}}; reportError(TypeError{expr.location, UnknownSymbol{expr.name.value, UnknownSymbol::Binding}}); - return {errorRecoveryType(scope)}; + return WithPredicate{errorRecoveryType(scope)}; } WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprVarargs& expr) @@ -1859,26 +1859,26 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp if (get(varargPack)) { if (std::optional ty = first(varargPack)) - return {*ty}; + return WithPredicate{*ty}; - return {nilType}; + return WithPredicate{nilType}; } else if (get(varargPack)) { TypeId head = freshType(scope); TypePackId tail = freshTypePack(scope); *asMutable(varargPack) = TypePack{{head}, tail}; - return {head}; + return WithPredicate{head}; } if (get(varargPack)) - return {errorRecoveryType(scope)}; + return WithPredicate{errorRecoveryType(scope)}; else if (auto vtp = get(varargPack)) - return {vtp->ty}; + return WithPredicate{vtp->ty}; else if (get(varargPack)) { // TODO: Better error? reportError(expr.location, GenericError{"Trying to get a type from a variadic type parameter"}); - return {errorRecoveryType(scope)}; + return WithPredicate{errorRecoveryType(scope)}; } else ice("Unknown TypePack type in checkExpr(AstExprVarargs)!"); @@ -1929,9 +1929,9 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp lhsType = stripFromNilAndReport(lhsType, expr.expr->location); if (std::optional ty = getIndexTypeFromType(scope, lhsType, name, expr.location, /* addErrors= */ true)) - return {*ty}; + return WithPredicate{*ty}; - return {errorRecoveryType(scope)}; + return WithPredicate{errorRecoveryType(scope)}; } std::optional TypeChecker::findTablePropertyRespectingMeta(TypeId lhsType, Name name, const Location& location, bool addErrors) @@ -2138,7 +2138,7 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp if (std::optional refiTy = resolveLValue(scope, *lvalue)) return {*refiTy, {TruthyPredicate{std::move(*lvalue), expr.location}}}; - return {ty}; + return WithPredicate{ty}; } WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprFunction& expr, std::optional expectedType) @@ -2147,7 +2147,7 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp checkFunctionBody(funScope, funTy, expr); - return {quantify(funScope, funTy, expr.location)}; + return WithPredicate{quantify(funScope, funTy, expr.location)}; } TypeId TypeChecker::checkExprTable( @@ -2252,7 +2252,7 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp if (FInt::LuauCheckRecursionLimit > 0 && checkRecursionCount >= FInt::LuauCheckRecursionLimit) { reportErrorCodeTooComplex(expr.location); - return {errorRecoveryType(scope)}; + return WithPredicate{errorRecoveryType(scope)}; } std::vector> fieldTypes(expr.items.size); @@ -2339,7 +2339,7 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp expectedIndexResultType = fieldTypes[i].second; } - return {checkExprTable(scope, expr, fieldTypes, expectedType)}; + return WithPredicate{checkExprTable(scope, expr, fieldTypes, expectedType)}; } WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprUnary& expr) @@ -2356,7 +2356,7 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp const bool operandIsAny = get(operandType) || get(operandType) || get(operandType); if (operandIsAny) - return {operandType}; + return WithPredicate{operandType}; if (typeCouldHaveMetatable(operandType)) { @@ -2377,16 +2377,16 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp if (!state.errors.empty()) retType = errorRecoveryType(retType); - return {retType}; + return WithPredicate{retType}; } reportError(expr.location, GenericError{format("Unary operator '%s' not supported by type '%s'", toString(expr.op).c_str(), toString(operandType).c_str())}); - return {errorRecoveryType(scope)}; + return WithPredicate{errorRecoveryType(scope)}; } reportErrors(tryUnify(operandType, numberType, scope, expr.location)); - return {numberType}; + return WithPredicate{numberType}; } case AstExprUnary::Len: { @@ -2396,7 +2396,7 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp // # operator is guaranteed to return number if (get(operandType) || get(operandType) || get(operandType)) - return {numberType}; + return WithPredicate{numberType}; DenseHashSet seen{nullptr}; @@ -2420,7 +2420,7 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp if (!hasLength(operandType, seen, &recursionCount)) reportError(TypeError{expr.location, NotATable{operandType}}); - return {numberType}; + return WithPredicate{numberType}; } default: ice("Unknown AstExprUnary " + std::to_string(int(expr.op))); @@ -3014,7 +3014,7 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp WithPredicate rhs = checkExpr(scope, *expr.right); // Intentionally discarding predicates with other operators. - return {checkBinaryOperation(scope, expr, lhs.type, rhs.type, lhs.predicates)}; + return WithPredicate{checkBinaryOperation(scope, expr, lhs.type, rhs.type, lhs.predicates)}; } } @@ -3045,7 +3045,7 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp // any type errors that may arise from it are going to be useless. currentModule->errors.resize(oldSize); - return {errorRecoveryType(scope)}; + return WithPredicate{errorRecoveryType(scope)}; } WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIfElse& expr, std::optional expectedType) @@ -3061,12 +3061,12 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp WithPredicate falseType = checkExpr(falseScope, *expr.falseExpr, expectedType); if (falseType.type == trueType.type) - return {trueType.type}; + return WithPredicate{trueType.type}; std::vector types = reduceUnion({trueType.type, falseType.type}); if (types.empty()) - return {neverType}; - return {types.size() == 1 ? types[0] : addType(UnionType{std::move(types)})}; + return WithPredicate{neverType}; + return WithPredicate{types.size() == 1 ? types[0] : addType(UnionType{std::move(types)})}; } WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprInterpString& expr) @@ -3074,7 +3074,7 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp for (AstExpr* expr : expr.expressions) checkExpr(scope, *expr); - return {stringType}; + return WithPredicate{stringType}; } TypeId TypeChecker::checkLValue(const ScopePtr& scope, const AstExpr& expr, ValueContext ctx) @@ -3704,7 +3704,7 @@ WithPredicate TypeChecker::checkExprPack(const ScopePtr& scope, cons { WithPredicate result = checkExprPackHelper(scope, expr); if (containsNever(result.type)) - return {uninhabitableTypePack}; + return WithPredicate{uninhabitableTypePack}; return result; } @@ -3715,14 +3715,14 @@ WithPredicate TypeChecker::checkExprPackHelper(const ScopePtr& scope else if (expr.is()) { if (!scope->varargPack) - return {errorRecoveryTypePack(scope)}; + return WithPredicate{errorRecoveryTypePack(scope)}; - return {*scope->varargPack}; + return WithPredicate{*scope->varargPack}; } else { TypeId type = checkExpr(scope, expr).type; - return {addTypePack({type})}; + return WithPredicate{addTypePack({type})}; } } @@ -3994,71 +3994,77 @@ WithPredicate TypeChecker::checkExprPackHelper(const ScopePtr& scope { retPack = freshTypePack(free->level); TypePackId freshArgPack = freshTypePack(free->level); - asMutable(actualFunctionType)->ty.emplace(free->level, freshArgPack, retPack); + emplaceType(asMutable(actualFunctionType), free->level, freshArgPack, retPack); } else retPack = freshTypePack(scope->level); - // checkExpr will log the pre-instantiated type of the function. - // That's not nearly as interesting as the instantiated type, which will include details about how - // generic functions are being instantiated for this particular callsite. - currentModule->astOriginalCallTypes[expr.func] = follow(functionType); - currentModule->astTypes[expr.func] = actualFunctionType; + // We break this function up into a lambda here to limit our stack footprint. + // The vectors used by this function aren't allocated until the lambda is actually called. + auto the_rest = [&]() -> WithPredicate { + // checkExpr will log the pre-instantiated type of the function. + // That's not nearly as interesting as the instantiated type, which will include details about how + // generic functions are being instantiated for this particular callsite. + currentModule->astOriginalCallTypes[expr.func] = follow(functionType); + currentModule->astTypes[expr.func] = actualFunctionType; - std::vector overloads = flattenIntersection(actualFunctionType); + std::vector overloads = flattenIntersection(actualFunctionType); - std::vector> expectedTypes = getExpectedTypesForCall(overloads, expr.args.size, expr.self); + std::vector> expectedTypes = getExpectedTypesForCall(overloads, expr.args.size, expr.self); - WithPredicate argListResult = checkExprList(scope, expr.location, expr.args, false, {}, expectedTypes); - TypePackId argPack = argListResult.type; + WithPredicate argListResult = checkExprList(scope, expr.location, expr.args, false, {}, expectedTypes); + TypePackId argPack = argListResult.type; - if (get(argPack)) - return {errorRecoveryTypePack(scope)}; + if (get(argPack)) + return WithPredicate{errorRecoveryTypePack(scope)}; - TypePack* args = nullptr; - if (expr.self) - { - argPack = addTypePack(TypePack{{selfType}, argPack}); - argListResult.type = argPack; - } - args = getMutable(argPack); - LUAU_ASSERT(args); + TypePack* args = nullptr; + if (expr.self) + { + argPack = addTypePack(TypePack{{selfType}, argPack}); + argListResult.type = argPack; + } + args = getMutable(argPack); + LUAU_ASSERT(args); - std::vector argLocations; - argLocations.reserve(expr.args.size + 1); - if (expr.self) - argLocations.push_back(expr.func->as()->expr->location); - for (AstExpr* arg : expr.args) - argLocations.push_back(arg->location); + std::vector argLocations; + argLocations.reserve(expr.args.size + 1); + if (expr.self) + argLocations.push_back(expr.func->as()->expr->location); + for (AstExpr* arg : expr.args) + argLocations.push_back(arg->location); - std::vector errors; // errors encountered for each overload + std::vector errors; // errors encountered for each overload - std::vector overloadsThatMatchArgCount; - std::vector overloadsThatDont; + std::vector overloadsThatMatchArgCount; + std::vector overloadsThatDont; - for (TypeId fn : overloads) - { - fn = follow(fn); + for (TypeId fn : overloads) + { + fn = follow(fn); - if (auto ret = checkCallOverload( - scope, expr, fn, retPack, argPack, args, &argLocations, argListResult, overloadsThatMatchArgCount, overloadsThatDont, errors)) - return *ret; - } + if (auto ret = checkCallOverload( + scope, expr, fn, retPack, argPack, args, &argLocations, argListResult, overloadsThatMatchArgCount, overloadsThatDont, errors)) + return *ret; + } - if (handleSelfCallMismatch(scope, expr, args, argLocations, errors)) - return {retPack}; + if (handleSelfCallMismatch(scope, expr, args, argLocations, errors)) + return WithPredicate{retPack}; - reportOverloadResolutionError(scope, expr, retPack, argPack, argLocations, overloads, overloadsThatMatchArgCount, errors); + reportOverloadResolutionError(scope, expr, retPack, argPack, argLocations, overloads, overloadsThatMatchArgCount, errors); - const FunctionType* overload = nullptr; - if (!overloadsThatMatchArgCount.empty()) - overload = get(overloadsThatMatchArgCount[0]); - if (!overload && !overloadsThatDont.empty()) - overload = get(overloadsThatDont[0]); - if (overload) - return {errorRecoveryTypePack(overload->retTypes)}; + const FunctionType* overload = nullptr; + if (!overloadsThatMatchArgCount.empty()) + overload = get(overloadsThatMatchArgCount[0]); + if (!overload && !overloadsThatDont.empty()) + overload = get(overloadsThatDont[0]); + if (overload) + return WithPredicate{errorRecoveryTypePack(overload->retTypes)}; - return {errorRecoveryTypePack(retPack)}; + return WithPredicate{errorRecoveryTypePack(retPack)}; + }; + + return the_rest(); } std::vector> TypeChecker::getExpectedTypesForCall(const std::vector& overloads, size_t argumentCount, bool selfCall) @@ -4119,8 +4125,13 @@ std::vector> TypeChecker::getExpectedTypesForCall(const st return expectedTypes; } -std::optional> TypeChecker::checkCallOverload(const ScopePtr& scope, const AstExprCall& expr, TypeId fn, TypePackId retPack, - TypePackId argPack, TypePack* args, const std::vector* argLocations, const WithPredicate& argListResult, +/* + * Note: We return a std::unique_ptr here rather than an optional to manage our stack consumption. + * If this was an optional, callers would have to pay the stack cost for the result. This is problematic + * for functions that need to support recursion up to 600 levels deep. + */ +std::unique_ptr> TypeChecker::checkCallOverload(const ScopePtr& scope, const AstExprCall& expr, TypeId fn, + TypePackId retPack, TypePackId argPack, TypePack* args, const std::vector* argLocations, const WithPredicate& argListResult, std::vector& overloadsThatMatchArgCount, std::vector& overloadsThatDont, std::vector& errors) { LUAU_ASSERT(argLocations); @@ -4130,16 +4141,16 @@ std::optional> TypeChecker::checkCallOverload(const Sc if (get(fn)) { unify(anyTypePack, argPack, scope, expr.location); - return {{anyTypePack}}; + return std::make_unique>(anyTypePack); } if (get(fn)) { - return {{errorRecoveryTypePack(scope)}}; + return std::make_unique>(errorRecoveryTypePack(scope)); } if (get(fn)) - return {{uninhabitableTypePack}}; + return std::make_unique>(uninhabitableTypePack); if (auto ftv = get(fn)) { @@ -4152,7 +4163,7 @@ std::optional> TypeChecker::checkCallOverload(const Sc options.isFunctionCall = true; unify(r, fn, scope, expr.location, options); - return {{retPack}}; + return std::make_unique>(retPack); } std::vector metaArgLocations; @@ -4191,7 +4202,7 @@ std::optional> TypeChecker::checkCallOverload(const Sc { reportError(TypeError{expr.func->location, CannotCallNonFunction{fn}}); unify(errorRecoveryTypePack(scope), retPack, scope, expr.func->location); - return {{errorRecoveryTypePack(retPack)}}; + return std::make_unique>(errorRecoveryTypePack(retPack)); } // When this function type has magic functions and did return something, we select that overload instead. @@ -4200,7 +4211,7 @@ std::optional> TypeChecker::checkCallOverload(const Sc { // TODO: We're passing in the wrong TypePackId. Should be argPack, but a unit test fails otherwise. CLI-40458 if (std::optional> ret = ftv->magicFunction(*this, scope, expr, argListResult)) - return *ret; + return std::make_unique>(std::move(*ret)); } Unifier state = mkUnifier(scope, expr.location); @@ -4209,7 +4220,7 @@ std::optional> TypeChecker::checkCallOverload(const Sc checkArgumentList(scope, *expr.func, state, retPack, ftv->retTypes, /*argLocations*/ {}); if (!state.errors.empty()) { - return {}; + return nullptr; } checkArgumentList(scope, *expr.func, state, argPack, ftv->argTypes, *argLocations); @@ -4244,10 +4255,10 @@ std::optional> TypeChecker::checkCallOverload(const Sc currentModule->astOverloadResolvedTypes[&expr] = fn; // We select this overload - return {{retPack}}; + return std::make_unique>(retPack); } - return {}; + return nullptr; } bool TypeChecker::handleSelfCallMismatch(const ScopePtr& scope, const AstExprCall& expr, TypePack* args, const std::vector& argLocations, @@ -4404,7 +4415,7 @@ WithPredicate TypeChecker::checkExprList(const ScopePtr& scope, cons }; if (exprs.size == 0) - return {pack}; + return WithPredicate{pack}; TypePack* tp = getMutable(pack); @@ -4484,7 +4495,7 @@ WithPredicate TypeChecker::checkExprList(const ScopePtr& scope, cons log.commit(); if (uninhabitable) - return {uninhabitableTypePack}; + return WithPredicate{uninhabitableTypePack}; return {pack, predicates}; } diff --git a/Analysis/src/TypeReduction.cpp b/Analysis/src/TypeReduction.cpp index 94fb4ad3e..2393829df 100644 --- a/Analysis/src/TypeReduction.cpp +++ b/Analysis/src/TypeReduction.cpp @@ -16,10 +16,166 @@ LUAU_FASTFLAGVARIABLE(DebugLuauDontReduceTypes, false) namespace Luau { -namespace +namespace detail +{ +bool TypeReductionMemoization::isIrreducible(TypeId ty) +{ + ty = follow(ty); + + // Only does shallow check, the TypeReducer itself already does deep traversal. + if (auto edge = types.find(ty); edge && edge->irreducible) + return true; + else if (get(ty) || get(ty) || get(ty)) + return false; + else if (auto tt = get(ty); tt && (tt->state == TableState::Free || tt->state == TableState::Unsealed)) + return false; + else + return true; +} + +bool TypeReductionMemoization::isIrreducible(TypePackId tp) +{ + tp = follow(tp); + + // Only does shallow check, the TypeReducer itself already does deep traversal. + if (auto edge = typePacks.find(tp); edge && edge->irreducible) + return true; + else if (get(tp) || get(tp)) + return false; + else if (auto vtp = get(tp)) + return isIrreducible(vtp->ty); + else + return true; +} + +TypeId TypeReductionMemoization::memoize(TypeId ty, TypeId reducedTy) +{ + ty = follow(ty); + reducedTy = follow(reducedTy); + + // The irreducibility of this [`reducedTy`] depends on whether its contents are themselves irreducible. + // We don't need to recurse much further than that, because we already record the irreducibility from + // the bottom up. + bool irreducible = isIrreducible(reducedTy); + if (auto it = get(reducedTy)) + { + for (TypeId part : it) + irreducible &= isIrreducible(part); + } + else if (auto ut = get(reducedTy)) + { + for (TypeId option : ut) + irreducible &= isIrreducible(option); + } + else if (auto tt = get(reducedTy)) + { + for (auto& [k, p] : tt->props) + irreducible &= isIrreducible(p.type); + + if (tt->indexer) + { + irreducible &= isIrreducible(tt->indexer->indexType); + irreducible &= isIrreducible(tt->indexer->indexResultType); + } + + for (auto ta : tt->instantiatedTypeParams) + irreducible &= isIrreducible(ta); + + for (auto tpa : tt->instantiatedTypePackParams) + irreducible &= isIrreducible(tpa); + } + else if (auto mt = get(reducedTy)) + { + irreducible &= isIrreducible(mt->table); + irreducible &= isIrreducible(mt->metatable); + } + else if (auto ft = get(reducedTy)) + { + irreducible &= isIrreducible(ft->argTypes); + irreducible &= isIrreducible(ft->retTypes); + } + else if (auto nt = get(reducedTy)) + irreducible &= isIrreducible(nt->ty); + + types[ty] = {reducedTy, irreducible}; + types[reducedTy] = {reducedTy, irreducible}; + return reducedTy; +} + +TypePackId TypeReductionMemoization::memoize(TypePackId tp, TypePackId reducedTp) +{ + tp = follow(tp); + reducedTp = follow(reducedTp); + + bool irreducible = isIrreducible(reducedTp); + TypePackIterator it = begin(tp); + while (it != end(tp)) + { + irreducible &= isIrreducible(*it); + ++it; + } + + if (it.tail()) + irreducible &= isIrreducible(*it.tail()); + + typePacks[tp] = {reducedTp, irreducible}; + typePacks[reducedTp] = {reducedTp, irreducible}; + return reducedTp; +} + +std::optional> TypeReductionMemoization::memoizedof(TypeId ty) const +{ + auto fetchContext = [this](TypeId ty) -> std::optional> { + if (auto edge = types.find(ty)) + return *edge; + else + return std::nullopt; + }; + + TypeId currentTy = ty; + std::optional> lastEdge; + while (auto edge = fetchContext(currentTy)) + { + lastEdge = edge; + if (edge->irreducible) + return edge; + else if (edge->type == currentTy) + return edge; + else + currentTy = edge->type; + } + + return lastEdge; +} + +std::optional> TypeReductionMemoization::memoizedof(TypePackId tp) const { + auto fetchContext = [this](TypePackId tp) -> std::optional> { + if (auto edge = typePacks.find(tp)) + return *edge; + else + return std::nullopt; + }; + + TypePackId currentTp = tp; + std::optional> lastEdge; + while (auto edge = fetchContext(currentTp)) + { + lastEdge = edge; + if (edge->irreducible) + return edge; + else if (edge->type == currentTp) + return edge; + else + currentTp = edge->type; + } -using detail::ReductionContext; + return lastEdge; +} +} // namespace detail + +namespace +{ template std::pair get2(const Thing& one, const Thing& two) @@ -34,9 +190,7 @@ struct TypeReducer NotNull arena; NotNull builtinTypes; NotNull handle; - - DenseHashMap>* memoizedTypes; - DenseHashMap>* memoizedTypePacks; + NotNull memoization; DenseHashSet* cyclics; int depth = 0; @@ -50,12 +204,6 @@ struct TypeReducer TypeId functionType(TypeId ty); TypeId negationType(TypeId ty); - bool isIrreducible(TypeId ty); - bool isIrreducible(TypePackId tp); - - TypeId memoize(TypeId ty, TypeId reducedTy); - TypePackId memoize(TypePackId tp, TypePackId reducedTp); - using BinaryFold = std::optional (TypeReducer::*)(TypeId, TypeId); using UnaryFold = TypeId (TypeReducer::*)(TypeId); @@ -64,12 +212,15 @@ struct TypeReducer { ty = follow(ty); - if (auto ctx = memoizedTypes->find(ty)) - return {ctx->type, getMutable(ctx->type)}; + if (auto edge = memoization->memoizedof(ty)) + return {edge->type, getMutable(edge->type)}; + // We specifically do not want to use [`detail::TypeReductionMemoization::memoize`] because that will + // potentially consider these copiedTy to be reducible, but we need this to resolve cyclic references + // without attempting to recursively reduce it, causing copies of copies of copies of... TypeId copiedTy = arena->addType(*t); - (*memoizedTypes)[ty] = {copiedTy, true}; - (*memoizedTypes)[copiedTy] = {copiedTy, true}; + memoization->types[ty] = {copiedTy, true}; + memoization->types[copiedTy] = {copiedTy, true}; return {copiedTy, getMutable(copiedTy)}; } @@ -175,8 +326,13 @@ TypeId TypeReducer::reduce(TypeId ty) { ty = follow(ty); - if (auto ctx = memoizedTypes->find(ty); ctx && ctx->irreducible) - return ctx->type; + if (auto edge = memoization->memoizedof(ty)) + { + if (edge->irreducible) + return edge->type; + else + ty = edge->type; + } else if (cyclics->contains(ty)) return ty; @@ -196,15 +352,20 @@ TypeId TypeReducer::reduce(TypeId ty) else result = ty; - return memoize(ty, result); + return memoization->memoize(ty, result); } TypePackId TypeReducer::reduce(TypePackId tp) { tp = follow(tp); - if (auto ctx = memoizedTypePacks->find(tp); ctx && ctx->irreducible) - return ctx->type; + if (auto edge = memoization->memoizedof(tp)) + { + if (edge->irreducible) + return edge->type; + else + tp = edge->type; + } else if (cyclics->contains(tp)) return tp; @@ -237,11 +398,11 @@ TypePackId TypeReducer::reduce(TypePackId tp) } if (!didReduce) - return memoize(tp, tp); + return memoization->memoize(tp, tp); else if (head.empty() && tail) - return memoize(tp, *tail); + return memoization->memoize(tp, *tail); else - return memoize(tp, arena->addTypePack(TypePack{std::move(head), tail})); + return memoization->memoize(tp, arena->addTypePack(TypePack{std::move(head), tail})); } std::optional TypeReducer::intersectionType(TypeId left, TypeId right) @@ -832,111 +993,6 @@ TypeId TypeReducer::negationType(TypeId ty) return ty; // for all T except the ones handled above, ~T ~ ~T } -bool TypeReducer::isIrreducible(TypeId ty) -{ - ty = follow(ty); - - // Only does shallow check, the TypeReducer itself already does deep traversal. - if (auto ctx = memoizedTypes->find(ty); ctx && ctx->irreducible) - return true; - else if (get(ty) || get(ty) || get(ty)) - return false; - else if (auto tt = get(ty); tt && (tt->state == TableState::Free || tt->state == TableState::Unsealed)) - return false; - else - return true; -} - -bool TypeReducer::isIrreducible(TypePackId tp) -{ - tp = follow(tp); - - // Only does shallow check, the TypeReducer itself already does deep traversal. - if (auto ctx = memoizedTypePacks->find(tp); ctx && ctx->irreducible) - return true; - else if (get(tp) || get(tp)) - return false; - else if (auto vtp = get(tp)) - return isIrreducible(vtp->ty); - else - return true; -} - -TypeId TypeReducer::memoize(TypeId ty, TypeId reducedTy) -{ - ty = follow(ty); - reducedTy = follow(reducedTy); - - // The irreducibility of this [`reducedTy`] depends on whether its contents are themselves irreducible. - // We don't need to recurse much further than that, because we already record the irreducibility from - // the bottom up. - bool irreducible = isIrreducible(reducedTy); - if (auto it = get(reducedTy)) - { - for (TypeId part : it) - irreducible &= isIrreducible(part); - } - else if (auto ut = get(reducedTy)) - { - for (TypeId option : ut) - irreducible &= isIrreducible(option); - } - else if (auto tt = get(reducedTy)) - { - for (auto& [k, p] : tt->props) - irreducible &= isIrreducible(p.type); - - if (tt->indexer) - { - irreducible &= isIrreducible(tt->indexer->indexType); - irreducible &= isIrreducible(tt->indexer->indexResultType); - } - - for (auto ta : tt->instantiatedTypeParams) - irreducible &= isIrreducible(ta); - - for (auto tpa : tt->instantiatedTypePackParams) - irreducible &= isIrreducible(tpa); - } - else if (auto mt = get(reducedTy)) - { - irreducible &= isIrreducible(mt->table); - irreducible &= isIrreducible(mt->metatable); - } - else if (auto ft = get(reducedTy)) - { - irreducible &= isIrreducible(ft->argTypes); - irreducible &= isIrreducible(ft->retTypes); - } - else if (auto nt = get(reducedTy)) - irreducible &= isIrreducible(nt->ty); - - (*memoizedTypes)[ty] = {reducedTy, irreducible}; - (*memoizedTypes)[reducedTy] = {reducedTy, irreducible}; - return reducedTy; -} - -TypePackId TypeReducer::memoize(TypePackId tp, TypePackId reducedTp) -{ - tp = follow(tp); - reducedTp = follow(reducedTp); - - bool irreducible = isIrreducible(reducedTp); - TypePackIterator it = begin(tp); - while (it != end(tp)) - { - irreducible &= isIrreducible(*it); - ++it; - } - - if (it.tail()) - irreducible &= isIrreducible(*it.tail()); - - (*memoizedTypePacks)[tp] = {reducedTp, irreducible}; - (*memoizedTypePacks)[reducedTp] = {reducedTp, irreducible}; - return reducedTp; -} - struct MarkCycles : TypeVisitor { DenseHashSet cyclics{nullptr}; @@ -961,7 +1017,6 @@ struct MarkCycles : TypeVisitor return !cyclics.find(follow(tp)); } }; - } // namespace TypeReduction::TypeReduction( @@ -981,8 +1036,13 @@ std::optional TypeReduction::reduce(TypeId ty) return ty; else if (!options.allowTypeReductionsFromOtherArenas && ty->owningArena != arena) return ty; - else if (auto memoized = memoizedof(ty)) - return *memoized; + else if (auto edge = memoization.memoizedof(ty)) + { + if (edge->irreducible) + return edge->type; + else + ty = edge->type; + } else if (hasExceededCartesianProductLimit(ty)) return std::nullopt; @@ -991,7 +1051,7 @@ std::optional TypeReduction::reduce(TypeId ty) MarkCycles finder; finder.traverse(ty); - TypeReducer reducer{arena, builtinTypes, handle, &memoizedTypes, &memoizedTypePacks, &finder.cyclics}; + TypeReducer reducer{arena, builtinTypes, handle, NotNull{&memoization}, &finder.cyclics}; return reducer.reduce(ty); } catch (const RecursionLimitException&) @@ -1008,8 +1068,13 @@ std::optional TypeReduction::reduce(TypePackId tp) return tp; else if (!options.allowTypeReductionsFromOtherArenas && tp->owningArena != arena) return tp; - else if (auto memoized = memoizedof(tp)) - return *memoized; + else if (auto edge = memoization.memoizedof(tp)) + { + if (edge->irreducible) + return edge->type; + else + tp = edge->type; + } else if (hasExceededCartesianProductLimit(tp)) return std::nullopt; @@ -1018,7 +1083,7 @@ std::optional TypeReduction::reduce(TypePackId tp) MarkCycles finder; finder.traverse(tp); - TypeReducer reducer{arena, builtinTypes, handle, &memoizedTypes, &memoizedTypePacks, &finder.cyclics}; + TypeReducer reducer{arena, builtinTypes, handle, NotNull{&memoization}, &finder.cyclics}; return reducer.reduce(tp); } catch (const RecursionLimitException&) @@ -1039,13 +1104,6 @@ std::optional TypeReduction::reduce(const TypeFun& fun) return std::nullopt; } -TypeReduction TypeReduction::fork(NotNull arena, const TypeReductionOptions& opts) const -{ - TypeReduction child{arena, builtinTypes, handle, opts}; - child.parent = this; - return child; -} - size_t TypeReduction::cartesianProductSize(TypeId ty) const { ty = follow(ty); @@ -1093,24 +1151,4 @@ bool TypeReduction::hasExceededCartesianProductLimit(TypePackId tp) const return false; } -std::optional TypeReduction::memoizedof(TypeId ty) const -{ - if (auto ctx = memoizedTypes.find(ty); ctx && ctx->irreducible) - return ctx->type; - else if (parent) - return parent->memoizedof(ty); - else - return std::nullopt; -} - -std::optional TypeReduction::memoizedof(TypePackId tp) const -{ - if (auto ctx = memoizedTypePacks.find(tp); ctx && ctx->irreducible) - return ctx->type; - else if (parent) - return parent->memoizedof(tp); - else - return std::nullopt; -} - } // namespace Luau diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 7104f2e73..6364a5aa4 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -520,7 +520,12 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool size_t errorCount = errors.size(); - if (const UnionType* subUnion = log.getMutable(subTy)) + if (log.getMutable(subTy) && log.getMutable(superTy)) + { + blockedTypes.push_back(subTy); + blockedTypes.push_back(superTy); + } + else if (const UnionType* subUnion = log.getMutable(subTy)) { tryUnifyUnionWithType(subTy, subUnion, superTy); } diff --git a/CodeGen/include/Luau/IrBuilder.h b/CodeGen/include/Luau/IrBuilder.h index ebbba6893..295534214 100644 --- a/CodeGen/include/Luau/IrBuilder.h +++ b/CodeGen/include/Luau/IrBuilder.h @@ -42,6 +42,7 @@ struct IrBuilder IrOp inst(IrCmd cmd, IrOp a, IrOp b, IrOp c); IrOp inst(IrCmd cmd, IrOp a, IrOp b, IrOp c, IrOp d); IrOp inst(IrCmd cmd, IrOp a, IrOp b, IrOp c, IrOp d, IrOp e); + IrOp inst(IrCmd cmd, IrOp a, IrOp b, IrOp c, IrOp d, IrOp e, IrOp f); IrOp block(IrBlockKind kind); // Requested kind can be ignored if we are in an outlined sequence IrOp blockAtInst(uint32_t index); @@ -57,6 +58,8 @@ struct IrBuilder IrFunction function; + uint32_t activeBlockIdx = ~0u; + std::vector instIndexToBlock; // Block index at the bytecode instruction }; diff --git a/CodeGen/include/Luau/IrData.h b/CodeGen/include/Luau/IrData.h index 6a7094684..18d510cc9 100644 --- a/CodeGen/include/Luau/IrData.h +++ b/CodeGen/include/Luau/IrData.h @@ -5,6 +5,7 @@ #include "Luau/RegisterX64.h" #include "Luau/RegisterA64.h" +#include #include #include @@ -186,6 +187,16 @@ enum class IrCmd : uint8_t // A: int INT_TO_NUM, + // Adjust stack top (L->top) to point at 'B' TValues *after* the specified register + // This is used to return muliple values + // A: Rn + // B: int (offset) + ADJUST_STACK_TO_REG, + + // Restore stack top (L->top) to point to the function stack top (L->ci->top) + // This is used to recover after calling a variadic function + ADJUST_STACK_TO_TOP, + // Fallback functions // Perform an arithmetic operation on TValues of any type @@ -329,7 +340,7 @@ enum class IrCmd : uint8_t // Call specified function // A: unsigned int (bytecode instruction index) // B: Rn (function, followed by arguments) - // C: int (argument count or -1 to preserve all arguments up to stack top) + // C: int (argument count or -1 to use all arguments up to stack top) // D: int (result count or -1 to preserve all results and adjust stack top) // Note: return values are placed starting from Rn specified in 'B' LOP_CALL, @@ -337,13 +348,13 @@ enum class IrCmd : uint8_t // Return specified values from the function // A: unsigned int (bytecode instruction index) // B: Rn (value start) - // B: int (result count or -1 to return all values up to stack top) + // C: int (result count or -1 to return all values up to stack top) LOP_RETURN, // Perform a fast call of a built-in function // A: unsigned int (bytecode instruction index) // B: Rn (argument start) - // C: int (argument count or -1 preserve all arguments up to stack top) + // C: int (argument count or -1 use all arguments up to stack top) // D: block (fallback) // Note: return values are placed starting from Rn specified in 'B' LOP_FASTCALL, @@ -560,6 +571,7 @@ struct IrInst IrOp c; IrOp d; IrOp e; + IrOp f; uint32_t lastUse = 0; uint16_t useCount = 0; @@ -584,9 +596,10 @@ struct IrBlock uint16_t useCount = 0; - // Start points to an instruction index in a stream - // End is implicit + // 'start' and 'finish' define an inclusive range of instructions which belong to this block inside the function + // When block has been constructed, 'finish' always points to the first and only terminating instruction uint32_t start = ~0u; + uint32_t finish = ~0u; Label label; }; @@ -633,6 +646,19 @@ struct IrFunction return value.valueTag; } + std::optional asTagOp(IrOp op) + { + if (op.kind != IrOpKind::Constant) + return std::nullopt; + + IrConst& value = constOp(op); + + if (value.kind != IrConstKind::Tag) + return std::nullopt; + + return value.valueTag; + } + bool boolOp(IrOp op) { IrConst& value = constOp(op); @@ -641,6 +667,19 @@ struct IrFunction return value.valueBool; } + std::optional asBoolOp(IrOp op) + { + if (op.kind != IrOpKind::Constant) + return std::nullopt; + + IrConst& value = constOp(op); + + if (value.kind != IrConstKind::Bool) + return std::nullopt; + + return value.valueBool; + } + int intOp(IrOp op) { IrConst& value = constOp(op); @@ -649,6 +688,19 @@ struct IrFunction return value.valueInt; } + std::optional asIntOp(IrOp op) + { + if (op.kind != IrOpKind::Constant) + return std::nullopt; + + IrConst& value = constOp(op); + + if (value.kind != IrConstKind::Int) + return std::nullopt; + + return value.valueInt; + } + unsigned uintOp(IrOp op) { IrConst& value = constOp(op); @@ -657,6 +709,19 @@ struct IrFunction return value.valueUint; } + std::optional asUintOp(IrOp op) + { + if (op.kind != IrOpKind::Constant) + return std::nullopt; + + IrConst& value = constOp(op); + + if (value.kind != IrConstKind::Uint) + return std::nullopt; + + return value.valueUint; + } + double doubleOp(IrOp op) { IrConst& value = constOp(op); @@ -665,11 +730,31 @@ struct IrFunction return value.valueDouble; } + std::optional asDoubleOp(IrOp op) + { + if (op.kind != IrOpKind::Constant) + return std::nullopt; + + IrConst& value = constOp(op); + + if (value.kind != IrConstKind::Double) + return std::nullopt; + + return value.valueDouble; + } + IrCondition conditionOp(IrOp op) { LUAU_ASSERT(op.kind == IrOpKind::Condition); return IrCondition(op.index); } + + uint32_t getBlockIndex(const IrBlock& block) + { + // Can only be called with blocks from our vector + LUAU_ASSERT(&block >= blocks.data() && &block <= blocks.data() + blocks.size()); + return uint32_t(&block - blocks.data()); + } }; } // namespace CodeGen diff --git a/CodeGen/include/Luau/IrUtils.h b/CodeGen/include/Luau/IrUtils.h index 3e95813bb..0a23b3f77 100644 --- a/CodeGen/include/Luau/IrUtils.h +++ b/CodeGen/include/Luau/IrUtils.h @@ -162,6 +162,8 @@ inline bool isPseudo(IrCmd cmd) return cmd == IrCmd::NOP || cmd == IrCmd::SUBSTITUTE; } +bool isGCO(uint8_t tag); + // Remove a single instruction void kill(IrFunction& function, IrInst& inst); @@ -179,7 +181,7 @@ void replace(IrFunction& function, IrOp& original, IrOp replacement); // Replace a single instruction // Target instruction index instead of reference is used to handle introduction of a new block terminator -void replace(IrFunction& function, uint32_t instIdx, IrInst replacement); +void replace(IrFunction& function, IrBlock& block, uint32_t instIdx, IrInst replacement); // Replace instruction with a different value (using IrCmd::SUBSTITUTE) void substitute(IrFunction& function, IrInst& inst, IrOp replacement); @@ -188,10 +190,13 @@ void substitute(IrFunction& function, IrInst& inst, IrOp replacement); void applySubstitutions(IrFunction& function, IrOp& op); void applySubstitutions(IrFunction& function, IrInst& inst); +// Compare numbers using IR condition value +bool compare(double a, double b, IrCondition cond); + // Perform constant folding on instruction at index // For most instructions, successful folding results in a IrCmd::SUBSTITUTE // But it can also be successful on conditional control-flow, replacing it with an unconditional IrCmd::JUMP -void foldConstants(IrBuilder& build, IrFunction& function, uint32_t instIdx); +void foldConstants(IrBuilder& build, IrFunction& function, IrBlock& block, uint32_t instIdx); } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/include/Luau/OptimizeConstProp.h b/CodeGen/include/Luau/OptimizeConstProp.h new file mode 100644 index 000000000..3be044128 --- /dev/null +++ b/CodeGen/include/Luau/OptimizeConstProp.h @@ -0,0 +1,16 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/IrData.h" + +namespace Luau +{ +namespace CodeGen +{ + +struct IrBuilder; + +void constPropInBlockChains(IrBuilder& build); + +} // namespace CodeGen +} // namespace Luau diff --git a/CodeGen/src/CodeGen.cpp b/CodeGen/src/CodeGen.cpp index 78f001f17..5076cba2d 100644 --- a/CodeGen/src/CodeGen.cpp +++ b/CodeGen/src/CodeGen.cpp @@ -7,6 +7,7 @@ #include "Luau/CodeBlockUnwind.h" #include "Luau/IrAnalysis.h" #include "Luau/IrBuilder.h" +#include "Luau/OptimizeConstProp.h" #include "Luau/OptimizeFinalX64.h" #include "Luau/UnwindBuilder.h" #include "Luau/UnwindBuilderDwarf2.h" @@ -31,7 +32,7 @@ #endif #endif -LUAU_FASTFLAGVARIABLE(DebugUseOldCodegen, false) +LUAU_FASTFLAGVARIABLE(DebugCodegenNoOpt, false) namespace Luau { @@ -40,12 +41,6 @@ namespace CodeGen constexpr uint32_t kFunctionAlignment = 32; -struct InstructionOutline -{ - int pcpos; - int length; -}; - static void assembleHelpers(AssemblyBuilderX64& build, ModuleHelpers& helpers) { if (build.logText) @@ -64,346 +59,6 @@ static void assembleHelpers(AssemblyBuilderX64& build, ModuleHelpers& helpers) emitContinueCallInVm(build); } -static int emitInst(AssemblyBuilderX64& build, NativeState& data, ModuleHelpers& helpers, Proto* proto, LuauOpcode op, const Instruction* pc, int i, - Label* labelarr, Label& next, Label& fallback) -{ - int skip = 0; - - switch (op) - { - case LOP_NOP: - break; - case LOP_LOADNIL: - emitInstLoadNil(build, pc); - break; - case LOP_LOADB: - emitInstLoadB(build, pc, i, labelarr); - break; - case LOP_LOADN: - emitInstLoadN(build, pc); - break; - case LOP_LOADK: - emitInstLoadK(build, pc); - break; - case LOP_LOADKX: - emitInstLoadKX(build, pc); - break; - case LOP_MOVE: - emitInstMove(build, pc); - break; - case LOP_GETGLOBAL: - emitInstGetGlobal(build, pc, i, fallback); - break; - case LOP_SETGLOBAL: - emitInstSetGlobal(build, pc, i, next, fallback); - break; - case LOP_NAMECALL: - emitInstNameCall(build, pc, i, proto->k, next, fallback); - break; - case LOP_CALL: - emitInstCall(build, helpers, pc, i); - break; - case LOP_RETURN: - emitInstReturn(build, helpers, pc, i); - break; - case LOP_GETTABLE: - emitInstGetTable(build, pc, fallback); - break; - case LOP_SETTABLE: - emitInstSetTable(build, pc, next, fallback); - break; - case LOP_GETTABLEKS: - emitInstGetTableKS(build, pc, i, fallback); - break; - case LOP_SETTABLEKS: - emitInstSetTableKS(build, pc, i, next, fallback); - break; - case LOP_GETTABLEN: - emitInstGetTableN(build, pc, fallback); - break; - case LOP_SETTABLEN: - emitInstSetTableN(build, pc, next, fallback); - break; - case LOP_JUMP: - emitInstJump(build, pc, i, labelarr); - break; - case LOP_JUMPBACK: - emitInstJumpBack(build, pc, i, labelarr); - break; - case LOP_JUMPIF: - emitInstJumpIf(build, pc, i, labelarr, /* not_ */ false); - break; - case LOP_JUMPIFNOT: - emitInstJumpIf(build, pc, i, labelarr, /* not_ */ true); - break; - case LOP_JUMPIFEQ: - emitInstJumpIfEq(build, pc, i, labelarr, /* not_ */ false, fallback); - break; - case LOP_JUMPIFLE: - emitInstJumpIfCond(build, pc, i, labelarr, ConditionX64::LessEqual, fallback); - break; - case LOP_JUMPIFLT: - emitInstJumpIfCond(build, pc, i, labelarr, ConditionX64::Less, fallback); - break; - case LOP_JUMPIFNOTEQ: - emitInstJumpIfEq(build, pc, i, labelarr, /* not_ */ true, fallback); - break; - case LOP_JUMPIFNOTLE: - emitInstJumpIfCond(build, pc, i, labelarr, ConditionX64::NotLessEqual, fallback); - break; - case LOP_JUMPIFNOTLT: - emitInstJumpIfCond(build, pc, i, labelarr, ConditionX64::NotLess, fallback); - break; - case LOP_JUMPX: - emitInstJumpX(build, pc, i, labelarr); - break; - case LOP_JUMPXEQKNIL: - emitInstJumpxEqNil(build, pc, i, labelarr); - break; - case LOP_JUMPXEQKB: - emitInstJumpxEqB(build, pc, i, labelarr); - break; - case LOP_JUMPXEQKN: - emitInstJumpxEqN(build, pc, proto->k, i, labelarr); - break; - case LOP_JUMPXEQKS: - emitInstJumpxEqS(build, pc, i, labelarr); - break; - case LOP_ADD: - emitInstBinary(build, pc, TM_ADD, fallback); - break; - case LOP_SUB: - emitInstBinary(build, pc, TM_SUB, fallback); - break; - case LOP_MUL: - emitInstBinary(build, pc, TM_MUL, fallback); - break; - case LOP_DIV: - emitInstBinary(build, pc, TM_DIV, fallback); - break; - case LOP_MOD: - emitInstBinary(build, pc, TM_MOD, fallback); - break; - case LOP_POW: - emitInstBinary(build, pc, TM_POW, fallback); - break; - case LOP_ADDK: - emitInstBinaryK(build, pc, TM_ADD, fallback); - break; - case LOP_SUBK: - emitInstBinaryK(build, pc, TM_SUB, fallback); - break; - case LOP_MULK: - emitInstBinaryK(build, pc, TM_MUL, fallback); - break; - case LOP_DIVK: - emitInstBinaryK(build, pc, TM_DIV, fallback); - break; - case LOP_MODK: - emitInstBinaryK(build, pc, TM_MOD, fallback); - break; - case LOP_POWK: - emitInstPowK(build, pc, proto->k, fallback); - break; - case LOP_NOT: - emitInstNot(build, pc); - break; - case LOP_MINUS: - emitInstMinus(build, pc, fallback); - break; - case LOP_LENGTH: - emitInstLength(build, pc, fallback); - break; - case LOP_NEWTABLE: - emitInstNewTable(build, pc, i, next); - break; - case LOP_DUPTABLE: - emitInstDupTable(build, pc, i, next); - break; - case LOP_SETLIST: - emitInstSetList(build, pc, next); - break; - case LOP_GETUPVAL: - emitInstGetUpval(build, pc); - break; - case LOP_SETUPVAL: - emitInstSetUpval(build, pc, next); - break; - case LOP_CLOSEUPVALS: - emitInstCloseUpvals(build, pc, next); - break; - case LOP_FASTCALL: - // We want to lower next instruction at skip+2, but this instruction is only 1 long, so we need to add 1 - skip = emitInstFastCall(build, pc, i, next) + 1; - break; - case LOP_FASTCALL1: - // We want to lower next instruction at skip+2, but this instruction is only 1 long, so we need to add 1 - skip = emitInstFastCall1(build, pc, i, next) + 1; - break; - case LOP_FASTCALL2: - skip = emitInstFastCall2(build, pc, i, next); - break; - case LOP_FASTCALL2K: - skip = emitInstFastCall2K(build, pc, i, next); - break; - case LOP_FORNPREP: - emitInstForNPrep(build, pc, i, next, labelarr[i + 1 + LUAU_INSN_D(*pc)]); - break; - case LOP_FORNLOOP: - emitInstForNLoop(build, pc, i, labelarr[i + 1 + LUAU_INSN_D(*pc)], next); - break; - case LOP_FORGLOOP: - emitinstForGLoop(build, pc, i, labelarr[i + 1 + LUAU_INSN_D(*pc)], next, fallback); - break; - case LOP_FORGPREP_NEXT: - emitInstForGPrepNext(build, pc, labelarr[i + 1 + LUAU_INSN_D(*pc)], fallback); - break; - case LOP_FORGPREP_INEXT: - emitInstForGPrepInext(build, pc, labelarr[i + 1 + LUAU_INSN_D(*pc)], fallback); - break; - case LOP_AND: - emitInstAnd(build, pc); - break; - case LOP_ANDK: - emitInstAndK(build, pc); - break; - case LOP_OR: - emitInstOr(build, pc); - break; - case LOP_ORK: - emitInstOrK(build, pc); - break; - case LOP_GETIMPORT: - emitInstGetImport(build, pc, fallback); - break; - case LOP_CONCAT: - emitInstConcat(build, pc, i, next); - break; - case LOP_COVERAGE: - emitInstCoverage(build, i); - break; - default: - emitFallback(build, data, op, i); - break; - } - - return skip; -} - -static void emitInstFallback(AssemblyBuilderX64& build, NativeState& data, LuauOpcode op, const Instruction* pc, int i, Label* labelarr) -{ - switch (op) - { - case LOP_GETIMPORT: - emitSetSavedPc(build, i + 1); - emitInstGetImportFallback(build, LUAU_INSN_A(*pc), pc[1]); - break; - case LOP_GETTABLE: - emitInstGetTableFallback(build, pc, i); - break; - case LOP_SETTABLE: - emitInstSetTableFallback(build, pc, i); - break; - case LOP_GETTABLEN: - emitInstGetTableNFallback(build, pc, i); - break; - case LOP_SETTABLEN: - emitInstSetTableNFallback(build, pc, i); - break; - case LOP_NAMECALL: - // TODO: fast-paths that we've handled can be removed from the fallback - emitFallback(build, data, op, i); - break; - case LOP_JUMPIFEQ: - emitInstJumpIfEqFallback(build, pc, i, labelarr, /* not_ */ false); - break; - case LOP_JUMPIFLE: - emitInstJumpIfCondFallback(build, pc, i, labelarr, ConditionX64::LessEqual); - break; - case LOP_JUMPIFLT: - emitInstJumpIfCondFallback(build, pc, i, labelarr, ConditionX64::Less); - break; - case LOP_JUMPIFNOTEQ: - emitInstJumpIfEqFallback(build, pc, i, labelarr, /* not_ */ true); - break; - case LOP_JUMPIFNOTLE: - emitInstJumpIfCondFallback(build, pc, i, labelarr, ConditionX64::NotLessEqual); - break; - case LOP_JUMPIFNOTLT: - emitInstJumpIfCondFallback(build, pc, i, labelarr, ConditionX64::NotLess); - break; - case LOP_ADD: - emitInstBinaryFallback(build, pc, i, TM_ADD); - break; - case LOP_SUB: - emitInstBinaryFallback(build, pc, i, TM_SUB); - break; - case LOP_MUL: - emitInstBinaryFallback(build, pc, i, TM_MUL); - break; - case LOP_DIV: - emitInstBinaryFallback(build, pc, i, TM_DIV); - break; - case LOP_MOD: - emitInstBinaryFallback(build, pc, i, TM_MOD); - break; - case LOP_POW: - emitInstBinaryFallback(build, pc, i, TM_POW); - break; - case LOP_ADDK: - emitInstBinaryKFallback(build, pc, i, TM_ADD); - break; - case LOP_SUBK: - emitInstBinaryKFallback(build, pc, i, TM_SUB); - break; - case LOP_MULK: - emitInstBinaryKFallback(build, pc, i, TM_MUL); - break; - case LOP_DIVK: - emitInstBinaryKFallback(build, pc, i, TM_DIV); - break; - case LOP_MODK: - emitInstBinaryKFallback(build, pc, i, TM_MOD); - break; - case LOP_POWK: - emitInstBinaryKFallback(build, pc, i, TM_POW); - break; - case LOP_MINUS: - emitInstMinusFallback(build, pc, i); - break; - case LOP_LENGTH: - emitInstLengthFallback(build, pc, i); - break; - case LOP_FORGLOOP: - emitinstForGLoopFallback(build, pc, i, labelarr[i + 1 + LUAU_INSN_D(*pc)]); - break; - case LOP_FORGPREP_NEXT: - case LOP_FORGPREP_INEXT: - emitInstForGPrepXnextFallback(build, pc, i, labelarr[i + 1 + LUAU_INSN_D(*pc)]); - break; - case LOP_GETGLOBAL: - // TODO: luaV_gettable + cachedslot update instead of full fallback - emitFallback(build, data, op, i); - break; - case LOP_SETGLOBAL: - // TODO: luaV_settable + cachedslot update instead of full fallback - emitFallback(build, data, op, i); - break; - case LOP_GETTABLEKS: - // Full fallback required for LOP_GETTABLEKS because 'luaV_gettable' doesn't handle builtin vector field access - // It is also required to perform cached slot update - // TODO: extra fast-paths could be lowered before the full fallback - emitFallback(build, data, op, i); - break; - case LOP_SETTABLEKS: - // TODO: luaV_settable + cachedslot update instead of full fallback - emitFallback(build, data, op, i); - break; - default: - LUAU_ASSERT(!"Expected fallback for instruction"); - } -} - static NativeProto* assembleFunction(AssemblyBuilderX64& build, NativeState& data, ModuleHelpers& helpers, Proto* proto, AssemblyOptions options) { NativeProto* result = new NativeProto(); @@ -423,154 +78,33 @@ static NativeProto* assembleFunction(AssemblyBuilderX64& build, NativeState& dat build.logAppend("\n"); } - if (!FFlag::DebugUseOldCodegen) - { - build.align(kFunctionAlignment, AlignmentDataX64::Ud2); - - Label start = build.setLabel(); - - IrBuilder builder; - builder.buildFunctionIr(proto); - - optimizeMemoryOperandsX64(builder.function); - - IrLoweringX64 lowering(build, helpers, data, proto, builder.function); - - lowering.lower(options); - - result->instTargets = new uintptr_t[proto->sizecode]; - - for (int i = 0; i < proto->sizecode; i++) - { - auto [irLocation, asmLocation] = builder.function.bcMapping[i]; - - result->instTargets[i] = irLocation == ~0u ? 0 : asmLocation - start.location; - } - - result->location = start.location; - - if (build.logText) - build.logAppend("\n"); - - return result; - } - - std::vector(self: a, arg: string): ()", toStringNamedFunction("foo:method", *ftv)); } @@ -803,12 +803,16 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_hide_self_param") end )"); + ToStringOptions opts; + opts.hideFunctionSelfArgument = true; + TypeId parentTy = requireType("foo"); auto ttv = get(follow(parentTy)); - auto ftv = get(ttv->props.at("method").type); + REQUIRE_MESSAGE(ttv, "Expected a table but got " << toString(parentTy, opts)); + TypeId methodTy = follow(ttv->props.at("method").type); + auto ftv = get(methodTy); + REQUIRE_MESSAGE(ftv, "Expected a function but got " << toString(methodTy, opts)); - ToStringOptions opts; - opts.hideFunctionSelfArgument = true; CHECK_EQ("foo:method(arg: string): ()", toStringNamedFunction("foo:method", *ftv, opts)); } diff --git a/tests/TypeInfer.aliases.test.cpp b/tests/TypeInfer.aliases.test.cpp index c70ef5226..a2fc0c75e 100644 --- a/tests/TypeInfer.aliases.test.cpp +++ b/tests/TypeInfer.aliases.test.cpp @@ -855,16 +855,8 @@ TEST_CASE_FIXTURE(Fixture, "forward_declared_alias_is_not_clobbered_by_prior_uni type FutureIntersection = A & B )"); - if (FFlag::DebugLuauDeferredConstraintResolution) - { - // To be quite honest, I don't know exactly why DCR fixes this. - LUAU_REQUIRE_NO_ERRORS(result); - } - else - { - // TODO: shared self causes this test to break in bizarre ways. - LUAU_REQUIRE_ERRORS(result); - } + // TODO: shared self causes this test to break in bizarre ways. + LUAU_REQUIRE_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "recursive_types_restriction_ok") diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index a267419e0..7c2e451a6 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -660,7 +660,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "higher_order_function_4") )"); LUAU_REQUIRE_NO_ERRORS(result); - dumpErrors(result); /* * mergesort takes two arguments: an array of some type T and a function that takes two Ts. @@ -1424,9 +1423,11 @@ end TEST_CASE_FIXTURE(BuiltinsFixture, "function_decl_non_self_sealed_overwrite") { CheckResult result = check(R"( -function string.len(): number - return 1 -end + function string.len(): number + return 1 + end + + local s = string )"); LUAU_REQUIRE_NO_ERRORS(result); @@ -1434,11 +1435,11 @@ end // if 'string' library property was replaced with an internal module type, it will be freed and the next check will crash frontend.clear(); - result = check(R"( -print(string.len('hello')) + CheckResult result2 = check(R"( + print(string.len('hello')) )"); - LUAU_REQUIRE_NO_ERRORS(result); + LUAU_REQUIRE_NO_ERRORS(result2); } TEST_CASE_FIXTURE(BuiltinsFixture, "function_decl_non_self_sealed_overwrite_2") diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index 50056290b..ba0f975ee 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -1404,6 +1404,40 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "refine_unknowns") } } +TEST_CASE_FIXTURE(BuiltinsFixture, "refine_boolean") +{ + CheckResult result = check(R"( + local function f(x: number | boolean) + if typeof(x) == "boolean" then + local foo = x + else + local foo = x + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("boolean", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("number", toString(requireTypeAtPosition({5, 28}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "refine_thread") +{ + CheckResult result = check(R"( + local function f(x: number | thread) + if typeof(x) == "thread" then + local foo = x + else + local foo = x + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("thread", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("number", toString(requireTypeAtPosition({5, 28}))); +} + TEST_CASE_FIXTURE(BuiltinsFixture, "falsiness_of_TruthyPredicate_narrows_into_nil") { CheckResult result = check(R"( diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index fcebd1fed..27b43aa9e 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -347,8 +347,8 @@ TEST_CASE_FIXTURE(Fixture, "open_table_unification_3") const TableType* arg0Table = get(follow(arg0)); REQUIRE(arg0Table != nullptr); - REQUIRE(arg0Table->props.find("bar") != arg0Table->props.end()); - REQUIRE(arg0Table->props.find("baz") != arg0Table->props.end()); + CHECK(arg0Table->props.count("bar")); + CHECK(arg0Table->props.count("baz")); } TEST_CASE_FIXTURE(Fixture, "table_param_row_polymorphism_1") @@ -2482,12 +2482,18 @@ TEST_CASE_FIXTURE(Fixture, "nil_assign_doesnt_hit_indexer") TEST_CASE_FIXTURE(Fixture, "wrong_assign_does_hit_indexer") { - CheckResult result = check("local a = {} a[0] = 7 a[0] = 't'"); + CheckResult result = check(R"( + local a = {} + a[0] = 7 + a[0] = 't' + )"); + LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(result.errors[0], (TypeError{Location{Position{0, 30}, Position{0, 33}}, TypeMismatch{ - typeChecker.numberType, - typeChecker.stringType, - }})); + CHECK((Location{Position{3, 15}, Position{3, 18}}) == result.errors[0].location); + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK(tm->wantedType == typeChecker.numberType); + CHECK(tm->givenType == typeChecker.stringType); } TEST_CASE_FIXTURE(Fixture, "nil_assign_doesnt_hit_no_indexer") @@ -2673,7 +2679,10 @@ TEST_CASE_FIXTURE(Fixture, "inferring_crazy_table_should_also_be_quick") )"); ModulePtr module = getMainModule(); - CHECK_GE(100, module->internalTypes.types.size()); + if (FFlag::DebugLuauDeferredConstraintResolution) + CHECK_GE(500, module->internalTypes.types.size()); + else + CHECK_GE(100, module->internalTypes.types.size()); } TEST_CASE_FIXTURE(Fixture, "MixedPropertiesAndIndexers") diff --git a/tests/TypeInfer.tryUnify.test.cpp b/tests/TypeInfer.tryUnify.test.cpp index a22149c71..47b140a14 100644 --- a/tests/TypeInfer.tryUnify.test.cpp +++ b/tests/TypeInfer.tryUnify.test.cpp @@ -1,5 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Common.h" #include "Luau/Scope.h" +#include "Luau/Symbol.h" #include "Luau/TypeInfer.h" #include "Luau/Type.h" @@ -9,6 +11,8 @@ using namespace Luau; +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) + struct TryUnifyFixture : Fixture { TypeArena arena; @@ -254,7 +258,10 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "cli_41095_concat_log_in_sealed_table_unifica LUAU_REQUIRE_ERROR_COUNT(2, result); CHECK_EQ(toString(result.errors[0]), "No overload for function accepts 0 arguments."); - CHECK_EQ(toString(result.errors[1]), "Available overloads: ({a}, a) -> (); and ({a}, number, a) -> ()"); + if (FFlag::DebugLuauDeferredConstraintResolution) + CHECK_EQ(toString(result.errors[1]), "Available overloads: ({V}, V) -> (); and ({V}, number, V) -> ()"); + else + CHECK_EQ(toString(result.errors[1]), "Available overloads: ({a}, a) -> (); and ({a}, number, a) -> ()"); } TEST_CASE_FIXTURE(TryUnifyFixture, "free_tail_is_grown_properly") diff --git a/tests/TypeInfer.unknownnever.test.cpp b/tests/TypeInfer.unknownnever.test.cpp index 6bfb93b2a..f17ada20e 100644 --- a/tests/TypeInfer.unknownnever.test.cpp +++ b/tests/TypeInfer.unknownnever.test.cpp @@ -230,7 +230,7 @@ TEST_CASE_FIXTURE(Fixture, "assign_to_subscript_which_is_never") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "assign_to_subscript_which_is_never") +TEST_CASE_FIXTURE(Fixture, "for_loop_over_never") { CheckResult result = check(R"( for i, v in (5 :: never) do diff --git a/tools/faillist.txt b/tools/faillist.txt index 0a09f3f64..5c84d1687 100644 --- a/tools/faillist.txt +++ b/tools/faillist.txt @@ -5,23 +5,16 @@ AstQuery.last_argument_function_call_type AstQuery::getDocumentationSymbolAtPosition.overloaded_class_method AstQuery::getDocumentationSymbolAtPosition.overloaded_fn AstQuery::getDocumentationSymbolAtPosition.table_overloaded_function_prop -AutocompleteTest.autocomplete_first_function_arg_expected_type AutocompleteTest.autocomplete_oop_implicit_self -AutocompleteTest.autocomplete_string_singleton_equality -AutocompleteTest.do_compatible_self_calls -AutocompleteTest.do_wrong_compatible_self_calls AutocompleteTest.type_correct_expected_return_type_suggestion AutocompleteTest.type_correct_suggestion_for_overloads BuiltinTests.aliased_string_format BuiltinTests.assert_removes_falsy_types +BuiltinTests.assert_removes_falsy_types2 BuiltinTests.assert_removes_falsy_types_even_from_type_pack_tail_but_only_for_the_first_type BuiltinTests.assert_returns_false_and_string_iff_it_knows_the_first_argument_cannot_be_truthy BuiltinTests.bad_select_should_not_crash -BuiltinTests.coroutine_wrap_anything_goes -BuiltinTests.debug_info_is_crazy -BuiltinTests.debug_traceback_is_crazy BuiltinTests.dont_add_definitions_to_persistent_types -BuiltinTests.find_capture_types3 BuiltinTests.gmatch_definition BuiltinTests.match_capture_types BuiltinTests.match_capture_types2 @@ -34,16 +27,14 @@ BuiltinTests.sort_with_bad_predicate BuiltinTests.string_format_as_method BuiltinTests.string_format_correctly_ordered_types BuiltinTests.string_format_report_all_type_errors_at_correct_positions +BuiltinTests.string_format_tostring_specifier_type_constraint BuiltinTests.string_format_use_correct_argument2 -BuiltinTests.table_freeze_is_generic -BuiltinTests.table_insert_correctly_infers_type_of_array_2_args_overload BuiltinTests.table_insert_correctly_infers_type_of_array_3_args_overload BuiltinTests.table_pack BuiltinTests.table_pack_reduce BuiltinTests.table_pack_variadic DefinitionTests.class_definition_overload_metamethods DefinitionTests.class_definition_string_props -DefinitionTests.definition_file_classes FrontendTest.environments FrontendTest.nocheck_cycle_used_by_checked GenericsTests.apply_type_function_nested_generics2 @@ -52,6 +43,7 @@ GenericsTests.bound_tables_do_not_clone_original_fields GenericsTests.check_mutual_generic_functions GenericsTests.correctly_instantiate_polymorphic_member_functions GenericsTests.do_not_infer_generic_functions +GenericsTests.dont_unify_bound_types GenericsTests.generic_argument_count_too_few GenericsTests.generic_argument_count_too_many GenericsTests.generic_functions_should_be_memory_safe @@ -62,16 +54,13 @@ GenericsTests.infer_generic_function_function_argument_3 GenericsTests.infer_generic_function_function_argument_overloaded GenericsTests.infer_generic_lib_function_function_argument GenericsTests.instantiated_function_argument_names -GenericsTests.instantiation_sharing_types GenericsTests.no_stack_overflow_from_quantifying GenericsTests.self_recursive_instantiated_param -IntersectionTypes.select_correct_union_fn -IntersectionTypes.should_still_pick_an_overload_whose_arguments_are_unions +IntersectionTypes.overload_is_not_a_function IntersectionTypes.table_intersection_write_sealed IntersectionTypes.table_intersection_write_sealed_indirect IntersectionTypes.table_write_sealed_indirect ModuleTests.clone_self_property -ModuleTests.deepClone_cyclic_table NonstrictModeTests.for_in_iterator_variables_are_any NonstrictModeTests.function_parameters_are_any NonstrictModeTests.inconsistent_module_return_types_are_ok @@ -85,7 +74,6 @@ NonstrictModeTests.offer_a_hint_if_you_use_a_dot_instead_of_a_colon NonstrictModeTests.parameters_having_type_any_are_optional NonstrictModeTests.table_dot_insert_and_recursive_calls NonstrictModeTests.table_props_are_any -Normalize.cyclic_table_normalizes_sensibly ProvisionalTests.assign_table_with_refined_property_with_a_similar_type_is_illegal ProvisionalTests.bail_early_if_unification_is_too_complicated ProvisionalTests.do_not_ice_when_trying_to_pick_first_of_generic_type_pack @@ -93,31 +81,28 @@ ProvisionalTests.error_on_eq_metamethod_returning_a_type_other_than_boolean ProvisionalTests.free_options_cannot_be_unified_together ProvisionalTests.generic_type_leak_to_module_interface_variadic ProvisionalTests.greedy_inference_with_shared_self_triggers_function_with_no_returns -ProvisionalTests.pcall_returns_at_least_two_value_but_function_returns_nothing ProvisionalTests.setmetatable_constrains_free_type_into_free_table ProvisionalTests.specialization_binds_with_prototypes_too_early ProvisionalTests.table_insert_with_a_singleton_argument ProvisionalTests.typeguard_inference_incomplete -ProvisionalTests.weirditer_should_not_loop_forever RefinementTest.apply_refinements_on_astexprindexexpr_whose_subscript_expr_is_constant_string -RefinementTest.discriminate_tag +RefinementTest.discriminate_from_isa_of_x RefinementTest.falsiness_of_TruthyPredicate_narrows_into_nil RefinementTest.narrow_property_of_a_bounded_variable RefinementTest.nonoptional_type_can_narrow_to_nil_if_sense_is_true RefinementTest.refine_a_param_that_got_resolved_during_constraint_solving_stage +RefinementTest.refine_param_of_type_folder_or_part_without_using_typeof RefinementTest.refine_unknowns RefinementTest.type_guard_can_filter_for_intersection_of_tables -RefinementTest.type_narrow_for_all_the_userdata RefinementTest.type_narrow_to_vector RefinementTest.typeguard_cast_free_table_to_vector RefinementTest.typeguard_in_assert_position RefinementTest.x_as_any_if_x_is_instance_elseif_x_is_table +RefinementTest.x_is_not_instance_or_else_not_part RuntimeLimits.typescript_port_of_Result_type -TableTests.a_free_shape_can_turn_into_a_scalar_directly TableTests.a_free_shape_cannot_turn_into_a_scalar_if_it_is_not_compatible TableTests.accidentally_checked_prop_in_opposite_branch TableTests.any_when_indexing_into_an_unsealed_table_with_no_indexer_in_nonstrict_mode -TableTests.call_method TableTests.casting_tables_with_props_into_table_with_indexer3 TableTests.casting_tables_with_props_into_table_with_indexer4 TableTests.checked_prop_too_early @@ -135,7 +120,6 @@ TableTests.explicitly_typed_table_with_indexer TableTests.found_like_key_in_table_function_call TableTests.found_like_key_in_table_property_access TableTests.found_multiple_like_keys -TableTests.function_calls_produces_sealed_table_given_unsealed_table TableTests.fuzz_table_unify_instantiated_table TableTests.generic_table_instantiation_potential_regression TableTests.give_up_after_one_metatable_index_look_up @@ -144,21 +128,16 @@ TableTests.indexing_from_a_table_should_prefer_properties_when_possible TableTests.inequality_operators_imply_exactly_matching_types TableTests.infer_array_2 TableTests.inferred_return_type_of_free_table -TableTests.inferring_crazy_table_should_also_be_quick TableTests.instantiate_table_cloning_3 -TableTests.invariant_table_properties_means_instantiating_tables_in_assignment_is_unsound -TableTests.invariant_table_properties_means_instantiating_tables_in_call_is_unsound TableTests.leaking_bad_metatable_errors TableTests.less_exponential_blowup_please TableTests.missing_metatable_for_sealed_tables_do_not_get_inferred TableTests.mixed_tables_with_implicit_numbered_keys TableTests.nil_assign_doesnt_hit_indexer TableTests.nil_assign_doesnt_hit_no_indexer -TableTests.okay_to_add_property_to_unsealed_tables_by_function_call +TableTests.ok_to_set_nil_even_on_non_lvalue_base_expr TableTests.only_ascribe_synthetic_names_at_module_scope -TableTests.oop_indexer_works TableTests.oop_polymorphic -TableTests.open_table_unification_2 TableTests.quantify_even_that_table_was_never_exported_at_all TableTests.quantify_metatables_of_metatables_of_table TableTests.reasonable_error_when_adding_a_nonexistent_property_to_an_array_like_table @@ -169,32 +148,21 @@ TableTests.shared_selfs TableTests.shared_selfs_from_free_param TableTests.shared_selfs_through_metatables TableTests.table_call_metamethod_basic -TableTests.table_indexing_error_location -TableTests.table_insert_should_cope_with_optional_properties_in_nonstrict -TableTests.table_insert_should_cope_with_optional_properties_in_strict -TableTests.table_param_row_polymorphism_3 TableTests.table_simple_call -TableTests.table_subtyping_with_extra_props_dont_report_multiple_errors TableTests.table_subtyping_with_missing_props_dont_report_multiple_errors TableTests.table_unification_4 -TableTests.unifying_tables_shouldnt_uaf2 TableTests.used_colon_instead_of_dot TableTests.used_dot_instead_of_colon -ToString.exhaustive_toString_of_cyclic_table +TableTests.when_augmenting_an_unsealed_table_with_an_indexer_apply_the_correct_scope_to_the_indexer_type ToString.named_metatable_toStringNamedFunction ToString.toStringDetailed2 ToString.toStringErrorPack ToString.toStringNamedFunction_generic_pack -ToString.toStringNamedFunction_hide_self_param -ToString.toStringNamedFunction_include_self_param ToString.toStringNamedFunction_map -TryUnifyTests.cli_41095_concat_log_in_sealed_table_unification TryUnifyTests.members_of_failed_typepack_unification_are_unified_with_errorType TryUnifyTests.result_of_failed_typepack_unification_is_constrained TryUnifyTests.typepack_unification_should_trim_free_tails -TryUnifyTests.variadics_should_use_reversed_properly TypeAliases.cannot_create_cyclic_type_with_unknown_module -TypeAliases.forward_declared_alias_is_not_clobbered_by_prior_unification_with_any TypeAliases.generic_param_remap TypeAliases.mismatched_generic_type_param TypeAliases.mutually_recursive_types_restriction_not_ok_1 @@ -218,11 +186,9 @@ TypeInfer.it_is_ok_to_have_inconsistent_number_of_return_values_in_nonstrict TypeInfer.no_stack_overflow_from_isoptional TypeInfer.no_stack_overflow_from_isoptional2 TypeInfer.tc_after_error_recovery_no_replacement_name_in_error -TypeInfer.tc_if_else_expressions_expected_type_3 TypeInfer.type_infer_recursion_limit_no_ice TypeInfer.type_infer_recursion_limit_normalizer TypeInferAnyError.for_in_loop_iterator_is_any2 -TypeInferClasses.can_read_prop_of_base_class_using_string TypeInferClasses.class_type_mismatch_with_name_conflict TypeInferClasses.classes_without_overloaded_operators_cannot_be_added TypeInferClasses.higher_order_function_arguments_are_contravariant @@ -232,6 +198,7 @@ TypeInferClasses.table_class_unification_reports_sane_errors_for_missing_propert TypeInferClasses.warn_when_prop_almost_matches TypeInferFunctions.calling_function_with_anytypepack_doesnt_leak_free_types TypeInferFunctions.cannot_hoist_interior_defns_into_signature +TypeInferFunctions.check_function_before_lambda_that_uses_it TypeInferFunctions.dont_give_other_overloads_message_if_only_one_argument_matching_overload_exists TypeInferFunctions.dont_infer_parameter_types_for_functions_from_their_call_site TypeInferFunctions.duplicate_functions_with_different_signatures_not_allowed_in_nonstrict @@ -243,10 +210,7 @@ TypeInferFunctions.function_statement_sealed_table_assignment_through_indexer TypeInferFunctions.improved_function_arg_mismatch_error_nonstrict TypeInferFunctions.improved_function_arg_mismatch_errors TypeInferFunctions.infer_anonymous_function_arguments -TypeInferFunctions.infer_return_type_from_selected_overload TypeInferFunctions.infer_that_function_does_not_return_a_table -TypeInferFunctions.list_all_overloads_if_no_overload_takes_given_argument_count -TypeInferFunctions.list_only_alternative_overloads_that_match_argument_count TypeInferFunctions.luau_subtyping_is_np_hard TypeInferFunctions.no_lossy_function_type TypeInferFunctions.occurs_check_failure_in_function_return_type @@ -273,13 +237,11 @@ TypeInferModules.do_not_modify_imported_types_5 TypeInferModules.module_type_conflict TypeInferModules.module_type_conflict_instantiated TypeInferModules.type_error_of_unknown_qualified_type -TypeInferOOP.dont_suggest_using_colon_rather_than_dot_if_another_overload_works TypeInferOOP.inferring_hundreds_of_self_calls_should_not_suffocate_memory TypeInferOOP.methods_are_topologically_sorted TypeInferOOP.object_constructor_can_refer_to_method_of_self TypeInferOperators.CallAndOrOfFunctions TypeInferOperators.CallOrOfFunctions -TypeInferOperators.cannot_compare_tables_that_do_not_have_the_same_metatable TypeInferOperators.cannot_indirectly_compare_types_that_do_not_have_a_metatable TypeInferOperators.cannot_indirectly_compare_types_that_do_not_offer_overloaded_ordering_operators TypeInferOperators.cli_38355_recursive_union @@ -303,25 +265,21 @@ TypeInferUnknownNever.math_operators_and_never TypePackTests.detect_cyclic_typepacks2 TypePackTests.pack_tail_unification_check TypePackTests.type_alias_backwards_compatible -TypePackTests.type_alias_default_export TypePackTests.type_alias_default_mixed_self -TypePackTests.type_alias_default_type_chained TypePackTests.type_alias_default_type_errors TypePackTests.type_alias_default_type_pack_self_chained_tp TypePackTests.type_alias_default_type_pack_self_tp -TypePackTests.type_alias_default_type_self TypePackTests.type_alias_defaults_confusing_types -TypePackTests.type_alias_defaults_recursive_type TypePackTests.type_alias_type_pack_multi TypePackTests.type_alias_type_pack_variadic TypePackTests.type_alias_type_packs_errors TypePackTests.type_alias_type_packs_nested TypePackTests.unify_variadic_tails_in_arguments -TypePackTests.unify_variadic_tails_in_arguments_free TypePackTests.variadic_packs TypeSingletons.function_call_with_singletons TypeSingletons.function_call_with_singletons_mismatch TypeSingletons.indexing_on_union_of_string_singletons +TypeSingletons.no_widening_from_callsites TypeSingletons.overloaded_function_call_with_singletons TypeSingletons.overloaded_function_call_with_singletons_mismatch TypeSingletons.return_type_of_f_is_not_widened From e58bb1b27ff356cb53bff94f9e45b80bac076dac Mon Sep 17 00:00:00 2001 From: Andy Friesen Date: Fri, 24 Feb 2023 10:29:24 -0800 Subject: [PATCH 38/66] GCC fix. --- CLI/Reduce.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CLI/Reduce.cpp b/CLI/Reduce.cpp index d24c9874c..b7c780128 100644 --- a/CLI/Reduce.cpp +++ b/CLI/Reduce.cpp @@ -487,7 +487,7 @@ int main(int argc, char** argv) if (args.size() < 4) help(args); - for (int i = 1; i < args.size(); ++i) + for (size_t i = 1; i < args.size(); ++i) { if (args[i] == "--help") help(args); From 9a281f04923af6382dcebd65ebfcfede4c697358 Mon Sep 17 00:00:00 2001 From: Vyacheslav Egorov Date: Fri, 3 Mar 2023 15:45:38 +0200 Subject: [PATCH 39/66] Sync to upstream/release/566 --- Analysis/include/Luau/Breadcrumb.h | 75 ++ Analysis/include/Luau/Constraint.h | 1 - .../include/Luau/ConstraintGraphBuilder.h | 13 +- Analysis/include/Luau/ConstraintSolver.h | 4 +- Analysis/include/Luau/DataFlowGraph.h | 120 ++-- Analysis/include/Luau/Def.h | 8 - Analysis/include/Luau/Frontend.h | 21 +- Analysis/include/Luau/Refinement.h | 8 +- Analysis/include/Luau/Scope.h | 1 + Analysis/include/Luau/Type.h | 3 +- Analysis/include/Luau/TypeReduction.h | 4 +- Analysis/include/Luau/TypeUtils.h | 3 +- Analysis/include/Luau/Unifier.h | 6 +- Analysis/src/Autocomplete.cpp | 16 +- Analysis/src/BuiltinDefinitions.cpp | 18 + Analysis/src/ConstraintGraphBuilder.cpp | 455 +++++++----- Analysis/src/ConstraintSolver.cpp | 237 +++---- Analysis/src/DataFlowGraph.cpp | 654 +++++++++++++----- Analysis/src/Def.cpp | 7 +- Analysis/src/Frontend.cpp | 59 +- Analysis/src/Linter.cpp | 43 +- Analysis/src/Normalize.cpp | 8 + Analysis/src/Refinement.cpp | 14 +- Analysis/src/Type.cpp | 1 + Analysis/src/TypeChecker2.cpp | 33 +- Analysis/src/TypeInfer.cpp | 15 +- Analysis/src/TypeReduction.cpp | 8 + Analysis/src/TypeUtils.cpp | 18 +- Analysis/src/Unifier.cpp | 19 +- Ast/src/Lexer.cpp | 5 +- Ast/src/Parser.cpp | 8 +- CLI/Reduce.cpp | 2 +- CMakeLists.txt | 16 +- CodeGen/include/Luau/AddressA64.h | 3 + CodeGen/include/Luau/AssemblyBuilderA64.h | 3 + CodeGen/include/Luau/AssemblyBuilderX64.h | 3 + CodeGen/include/Luau/ConditionA64.h | 3 + CodeGen/include/Luau/IrAnalysis.h | 10 + CodeGen/include/Luau/IrBuilder.h | 6 + CodeGen/include/Luau/IrData.h | 82 ++- CodeGen/include/Luau/IrUtils.h | 10 + CodeGen/include/Luau/OperandX64.h | 3 + CodeGen/include/Luau/RegisterA64.h | 5 + CodeGen/include/Luau/RegisterX64.h | 3 + CodeGen/include/Luau/UnwindBuilder.h | 6 +- CodeGen/include/Luau/UnwindBuilderDwarf2.h | 6 +- CodeGen/include/Luau/UnwindBuilderWin.h | 8 +- CodeGen/src/AssemblyBuilderA64.cpp | 5 +- CodeGen/src/AssemblyBuilderX64.cpp | 4 + CodeGen/src/CodeGen.cpp | 14 +- CodeGen/src/CodeGenX64.cpp | 4 +- CodeGen/src/CodeGenX64.h | 4 +- CodeGen/src/EmitBuiltinsX64.cpp | 568 ++++----------- CodeGen/src/EmitBuiltinsX64.h | 12 +- CodeGen/src/EmitCommon.h | 29 + CodeGen/src/EmitCommonX64.cpp | 28 +- CodeGen/src/EmitCommonX64.h | 28 +- CodeGen/src/EmitInstructionX64.cpp | 180 +---- CodeGen/src/EmitInstructionX64.h | 17 +- CodeGen/src/IrAnalysis.cpp | 43 ++ CodeGen/src/IrBuilder.cpp | 84 ++- CodeGen/src/IrDump.cpp | 20 +- CodeGen/src/IrLoweringX64.cpp | 190 +++-- CodeGen/src/IrLoweringX64.h | 6 +- CodeGen/src/IrRegAllocX64.cpp | 23 +- CodeGen/src/IrRegAllocX64.h | 7 + CodeGen/src/IrTranslateBuiltins.cpp | 223 ++++++ CodeGen/src/IrTranslation.cpp | 37 +- CodeGen/src/IrTranslation.h | 3 +- CodeGen/src/IrUtils.cpp | 18 + CodeGen/src/OptimizeConstProp.cpp | 221 +++++- CodeGen/src/OptimizeFinalX64.cpp | 2 + CodeGen/src/UnwindBuilderDwarf2.cpp | 6 +- CodeGen/src/UnwindBuilderWin.cpp | 6 +- Compiler/src/Builtins.cpp | 153 ++++ Compiler/src/Builtins.h | 8 + Compiler/src/BytecodeBuilder.cpp | 5 +- Compiler/src/Compiler.cpp | 17 +- Sources.cmake | 2 + VM/src/lstrlib.cpp | 2 +- tests/AssemblyBuilderA64.test.cpp | 1 + tests/AssemblyBuilderX64.test.cpp | 1 + tests/AstJsonEncoder.test.cpp | 2 +- tests/Autocomplete.test.cpp | 43 ++ tests/CodeAllocator.test.cpp | 28 +- tests/Compiler.test.cpp | 131 +++- tests/DataFlowGraph.test.cpp | 29 +- tests/Fixture.cpp | 13 +- tests/IrBuilder.test.cpp | 188 ++++- tests/Lexer.test.cpp | 30 + tests/Linter.test.cpp | 48 +- tests/Normalize.test.cpp | 14 + tests/Parser.test.cpp | 47 ++ tests/TypeInfer.oop.test.cpp | 18 + tests/TypeInfer.provisional.test.cpp | 14 - tests/TypeInfer.refinements.test.cpp | 186 ++++- tests/TypeInfer.tables.test.cpp | 54 +- tests/TypeInfer.test.cpp | 3 + tests/conformance/math.lua | 5 + tools/faillist.txt | 34 +- 100 files changed, 3357 insertions(+), 1555 deletions(-) create mode 100644 Analysis/include/Luau/Breadcrumb.h create mode 100644 CodeGen/src/EmitCommon.h diff --git a/Analysis/include/Luau/Breadcrumb.h b/Analysis/include/Luau/Breadcrumb.h new file mode 100644 index 000000000..59b293a0b --- /dev/null +++ b/Analysis/include/Luau/Breadcrumb.h @@ -0,0 +1,75 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Def.h" +#include "Luau/NotNull.h" +#include "Luau/Variant.h" + +#include +#include + +namespace Luau +{ + +using NullableBreadcrumbId = const struct Breadcrumb*; +using BreadcrumbId = NotNull; + +struct FieldMetadata +{ + std::string prop; +}; + +struct SubscriptMetadata +{ + BreadcrumbId key; +}; + +using Metadata = Variant; + +struct Breadcrumb +{ + NullableBreadcrumbId previous; + DefId def; + std::optional metadata; + std::vector children; +}; + +inline Breadcrumb* asMutable(NullableBreadcrumbId breadcrumb) +{ + LUAU_ASSERT(breadcrumb); + return const_cast(breadcrumb); +} + +template +const T* getMetadata(NullableBreadcrumbId breadcrumb) +{ + if (!breadcrumb || !breadcrumb->metadata) + return nullptr; + + return get_if(&*breadcrumb->metadata); +} + +struct BreadcrumbArena +{ + TypedAllocator allocator; + + template + BreadcrumbId add(NullableBreadcrumbId previous, DefId def, Args&&... args) + { + Breadcrumb* bc = allocator.allocate(Breadcrumb{previous, def, std::forward(args)...}); + if (previous) + asMutable(previous)->children.push_back(NotNull{bc}); + return NotNull{bc}; + } + + template + BreadcrumbId emplace(NullableBreadcrumbId previous, DefId def, Args&&... args) + { + Breadcrumb* bc = allocator.allocate(Breadcrumb{previous, def, Metadata{T{std::forward(args)...}}}); + if (previous) + asMutable(previous)->children.push_back(NotNull{bc}); + return NotNull{bc}; + } +}; + +} // namespace Luau diff --git a/Analysis/include/Luau/Constraint.h b/Analysis/include/Luau/Constraint.h index 1c41bbb7f..2223c29e0 100644 --- a/Analysis/include/Luau/Constraint.h +++ b/Analysis/include/Luau/Constraint.h @@ -2,7 +2,6 @@ #pragma once #include "Luau/Ast.h" // Used for some of the enumerations -#include "Luau/Def.h" #include "Luau/DenseHash.h" #include "Luau/NotNull.h" #include "Luau/Type.h" diff --git a/Analysis/include/Luau/ConstraintGraphBuilder.h b/Analysis/include/Luau/ConstraintGraphBuilder.h index 7b2711f89..e79c4c91e 100644 --- a/Analysis/include/Luau/ConstraintGraphBuilder.h +++ b/Analysis/include/Luau/ConstraintGraphBuilder.h @@ -224,7 +224,7 @@ struct ConstraintGraphBuilder * @param inTypeArguments whether we are resolving a type that's contained within type arguments, `<...>`. * @return the type of the AST annotation. **/ - TypeId resolveType(const ScopePtr& scope, AstType* ty, bool inTypeArguments); + TypeId resolveType(const ScopePtr& scope, AstType* ty, bool inTypeArguments, bool replaceErrorWithFresh = false); /** * Resolves a type pack from its AST annotation. @@ -233,7 +233,7 @@ struct ConstraintGraphBuilder * @param inTypeArguments whether we are resolving a type that's contained within type arguments, `<...>`. * @return the type pack of the AST annotation. **/ - TypePackId resolveTypePack(const ScopePtr& scope, AstTypePack* tp, bool inTypeArguments); + TypePackId resolveTypePack(const ScopePtr& scope, AstTypePack* tp, bool inTypeArguments, bool replaceErrorWithFresh = false); /** * Resolves a type pack from its AST annotation. @@ -242,7 +242,7 @@ struct ConstraintGraphBuilder * @param inTypeArguments whether we are resolving a type that's contained within type arguments, `<...>`. * @return the type pack of the AST annotation. **/ - TypePackId resolveTypePack(const ScopePtr& scope, const AstTypeList& list, bool inTypeArguments); + TypePackId resolveTypePack(const ScopePtr& scope, const AstTypeList& list, bool inTypeArguments, bool replaceErrorWithFresh = false); /** * Creates generic types given a list of AST definitions, resolving default @@ -282,10 +282,17 @@ struct ConstraintGraphBuilder * initial scan of the AST and note what globals are defined. */ void prepopulateGlobalScope(const ScopePtr& globalScope, AstStatBlock* program); + + /** Given a function type annotation, return a vector describing the expected types of the calls to the function + * For example, calling a function with annotation ((number) -> string & ((string) -> number)) + * yields a vector of size 1, with value: [number | string] + */ + std::vector> getExpectedCallTypesForFunctionOverloads(const TypeId fnType); }; /** Borrow a vector of pointers from a vector of owning pointers to constraints. */ std::vector> borrowConstraints(const std::vector& constraints); + } // namespace Luau diff --git a/Analysis/include/Luau/ConstraintSolver.h b/Analysis/include/Luau/ConstraintSolver.h index 62687ae47..4fd7d0d10 100644 --- a/Analysis/include/Luau/ConstraintSolver.h +++ b/Analysis/include/Luau/ConstraintSolver.h @@ -132,8 +132,8 @@ struct ConstraintSolver bool tryDispatchIterableFunction( TypeId nextTy, TypeId tableTy, TypeId firstIndexTy, const IterableConstraint& c, NotNull constraint, bool force); - std::optional lookupTableProp(TypeId subjectType, const std::string& propName); - std::optional lookupTableProp(TypeId subjectType, const std::string& propName, std::unordered_set& seen); + std::pair, std::optional> lookupTableProp(TypeId subjectType, const std::string& propName); + std::pair, std::optional> lookupTableProp(TypeId subjectType, const std::string& propName, std::unordered_set& seen); void block(NotNull target, NotNull constraint); /** diff --git a/Analysis/include/Luau/DataFlowGraph.h b/Analysis/include/Luau/DataFlowGraph.h index bd096ea90..ce4ecb04c 100644 --- a/Analysis/include/Luau/DataFlowGraph.h +++ b/Analysis/include/Luau/DataFlowGraph.h @@ -3,6 +3,7 @@ // Do not include LValue. It should never be used here. #include "Luau/Ast.h" +#include "Luau/Breadcrumb.h" #include "Luau/DenseHash.h" #include "Luau/Def.h" #include "Luau/Symbol.h" @@ -17,16 +18,14 @@ struct DataFlowGraph DataFlowGraph(DataFlowGraph&&) = default; DataFlowGraph& operator=(DataFlowGraph&&) = default; - // TODO: AstExprLocal, AstExprGlobal, and AstLocal* are guaranteed never to return nullopt. - // We leave them to return an optional as we build it out, but the end state is for them to return a non-optional DefId. - std::optional getDef(const AstExpr* expr) const; - std::optional getDef(const AstLocal* local) const; + NullableBreadcrumbId getBreadcrumb(const AstExpr* expr) const; - /// Retrieve the Def that corresponds to the given Symbol. - /// - /// We do not perform dataflow analysis on globals, so this function always - /// yields nullopt when passed a global Symbol. - std::optional getDef(const Symbol& symbol) const; + BreadcrumbId getBreadcrumb(const AstLocal* local) const; + BreadcrumbId getBreadcrumb(const AstExprLocal* local) const; + BreadcrumbId getBreadcrumb(const AstExprGlobal* global) const; + + BreadcrumbId getBreadcrumb(const AstStatDeclareGlobal* global) const; + BreadcrumbId getBreadcrumb(const AstStatDeclareFunction* func) const; private: DataFlowGraph() = default; @@ -34,9 +33,17 @@ struct DataFlowGraph DataFlowGraph(const DataFlowGraph&) = delete; DataFlowGraph& operator=(const DataFlowGraph&) = delete; - DefArena arena; - DenseHashMap astDefs{nullptr}; - DenseHashMap localDefs{nullptr}; + DefArena defs; + BreadcrumbArena breadcrumbs; + + DenseHashMap astBreadcrumbs{nullptr}; + + // Sometimes we don't have the AstExprLocal* but we have AstLocal*, and sometimes we need to extract that DefId. + DenseHashMap localBreadcrumbs{nullptr}; + + // There's no AstStatDeclaration, and it feels useless to introduce it just to enforce an invariant in one place. + // All keys in this maps are really only statements that ambiently declares a symbol. + DenseHashMap declaredBreadcrumbs{nullptr}; friend struct DataFlowGraphBuilder; }; @@ -44,12 +51,11 @@ struct DataFlowGraph struct DfgScope { DfgScope* parent; - DenseHashMap bindings{Symbol{}}; -}; + DenseHashMap bindings{Symbol{}}; + DenseHashMap> props{nullptr}; -struct ExpressionFlowGraph -{ - std::optional def; + NullableBreadcrumbId lookup(Symbol symbol) const; + NullableBreadcrumbId lookup(DefId def, const std::string& key) const; }; // Currently unsound. We do not presently track the control flow of the program. @@ -65,23 +71,19 @@ struct DataFlowGraphBuilder DataFlowGraphBuilder& operator=(const DataFlowGraphBuilder&) = delete; DataFlowGraph graph; - NotNull arena{&graph.arena}; - struct InternalErrorReporter* handle; - std::vector> scopes; + NotNull defs{&graph.defs}; + NotNull breadcrumbs{&graph.breadcrumbs}; - // Does not belong in DataFlowGraphBuilder, but the old solver allows properties to escape the scope they were defined in, - // so we will need to be able to emulate this same behavior here too. We can kill this once we have better flow sensitivity. - DenseHashMap> props{nullptr}; + struct InternalErrorReporter* handle = nullptr; + DfgScope* moduleScope = nullptr; - DfgScope* childScope(DfgScope* scope); + std::vector> scopes; - std::optional use(DfgScope* scope, Symbol symbol, AstExpr* e); - DefId use(DefId def, AstExprIndexName* e); + DfgScope* childScope(DfgScope* scope); void visit(DfgScope* scope, AstStatBlock* b); void visitBlockWithoutChildScope(DfgScope* scope, AstStatBlock* b); - // TODO: visit type aliases void visit(DfgScope* scope, AstStat* s); void visit(DfgScope* scope, AstStatIf* i); void visit(DfgScope* scope, AstStatWhile* w); @@ -97,24 +99,52 @@ struct DataFlowGraphBuilder void visit(DfgScope* scope, AstStatCompoundAssign* c); void visit(DfgScope* scope, AstStatFunction* f); void visit(DfgScope* scope, AstStatLocalFunction* l); - - ExpressionFlowGraph visitExpr(DfgScope* scope, AstExpr* e); - ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprLocal* l); - ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprGlobal* g); - ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprCall* c); - ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprIndexName* i); - ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprIndexExpr* i); - ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprFunction* f); - ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprTable* t); - ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprUnary* u); - ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprBinary* b); - ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprTypeAssertion* t); - ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprIfElse* i); - ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprInterpString* i); - - // TODO: visitLValue - // TODO: visitTypes (because of typeof which has access to values namespace, needs unreachable scope) - // TODO: visitTypePacks (because of typeof which has access to values namespace, needs unreachable scope) + void visit(DfgScope* scope, AstStatTypeAlias* t); + void visit(DfgScope* scope, AstStatDeclareGlobal* d); + void visit(DfgScope* scope, AstStatDeclareFunction* d); + void visit(DfgScope* scope, AstStatDeclareClass* d); + void visit(DfgScope* scope, AstStatError* error); + + BreadcrumbId visitExpr(DfgScope* scope, AstExpr* e); + BreadcrumbId visitExpr(DfgScope* scope, AstExprLocal* l); + BreadcrumbId visitExpr(DfgScope* scope, AstExprGlobal* g); + BreadcrumbId visitExpr(DfgScope* scope, AstExprCall* c); + BreadcrumbId visitExpr(DfgScope* scope, AstExprIndexName* i); + BreadcrumbId visitExpr(DfgScope* scope, AstExprIndexExpr* i); + BreadcrumbId visitExpr(DfgScope* scope, AstExprFunction* f); + BreadcrumbId visitExpr(DfgScope* scope, AstExprTable* t); + BreadcrumbId visitExpr(DfgScope* scope, AstExprUnary* u); + BreadcrumbId visitExpr(DfgScope* scope, AstExprBinary* b); + BreadcrumbId visitExpr(DfgScope* scope, AstExprTypeAssertion* t); + BreadcrumbId visitExpr(DfgScope* scope, AstExprIfElse* i); + BreadcrumbId visitExpr(DfgScope* scope, AstExprInterpString* i); + BreadcrumbId visitExpr(DfgScope* scope, AstExprError* error); + + void visitLValue(DfgScope* scope, AstExpr* e); + void visitLValue(DfgScope* scope, AstExprLocal* l); + void visitLValue(DfgScope* scope, AstExprGlobal* g); + void visitLValue(DfgScope* scope, AstExprIndexName* i); + void visitLValue(DfgScope* scope, AstExprIndexExpr* i); + void visitLValue(DfgScope* scope, AstExprError* e); + + void visitType(DfgScope* scope, AstType* t); + void visitType(DfgScope* scope, AstTypeReference* r); + void visitType(DfgScope* scope, AstTypeTable* t); + void visitType(DfgScope* scope, AstTypeFunction* f); + void visitType(DfgScope* scope, AstTypeTypeof* t); + void visitType(DfgScope* scope, AstTypeUnion* u); + void visitType(DfgScope* scope, AstTypeIntersection* i); + void visitType(DfgScope* scope, AstTypeError* error); + + void visitTypePack(DfgScope* scope, AstTypePack* p); + void visitTypePack(DfgScope* scope, AstTypePackExplicit* e); + void visitTypePack(DfgScope* scope, AstTypePackVariadic* v); + void visitTypePack(DfgScope* scope, AstTypePackGeneric* g); + + void visitTypeList(DfgScope* scope, AstTypeList l); + + void visitGenerics(DfgScope* scope, AstArray g); + void visitGenericPacks(DfgScope* scope, AstArray g); }; } // namespace Luau diff --git a/Analysis/include/Luau/Def.h b/Analysis/include/Luau/Def.h index 1eef7dfdc..10d81367e 100644 --- a/Analysis/include/Luau/Def.h +++ b/Analysis/include/Luau/Def.h @@ -14,12 +14,6 @@ namespace Luau struct Def; using DefId = NotNull; -struct FieldMetadata -{ - DefId parent; - std::string propName; -}; - /** * A cell is a "single-object" value. * @@ -29,7 +23,6 @@ struct FieldMetadata */ struct Cell { - std::optional field; }; /** @@ -83,7 +76,6 @@ struct DefArena TypedAllocator allocator; DefId freshCell(); - DefId freshCell(DefId parent, const std::string& prop); // TODO: implement once we have cases where we need to merge in definitions // DefId phi(const std::vector& defs); }; diff --git a/Analysis/include/Luau/Frontend.h b/Analysis/include/Luau/Frontend.h index 403551f67..7c5dc4a0d 100644 --- a/Analysis/include/Luau/Frontend.h +++ b/Analysis/include/Luau/Frontend.h @@ -162,8 +162,7 @@ struct Frontend ScopePtr getGlobalScope(); private: - ModulePtr check(const SourceModule& sourceModule, Mode mode, std::vector requireCycles, - bool forAutocomplete = false); + ModulePtr check(const SourceModule& sourceModule, Mode mode, std::vector requireCycles, bool forAutocomplete = false, bool recordJsonLog = false); std::pair getSourceNode(const ModuleName& name); SourceModule parse(const ModuleName& name, std::string_view src, const ParseOptions& parseOptions); @@ -202,16 +201,12 @@ struct Frontend ScopePtr globalScope; }; -ModulePtr check( - const SourceModule& sourceModule, - const std::vector& requireCycles, - NotNull builtinTypes, - NotNull iceHandler, - NotNull moduleResolver, - NotNull fileResolver, - const ScopePtr& globalScope, - NotNull unifierState, - FrontendOptions options -); +ModulePtr check(const SourceModule& sourceModule, const std::vector& requireCycles, NotNull builtinTypes, + NotNull iceHandler, NotNull moduleResolver, NotNull fileResolver, + const ScopePtr& globalScope, FrontendOptions options); + +ModulePtr check(const SourceModule& sourceModule, const std::vector& requireCycles, NotNull builtinTypes, + NotNull iceHandler, NotNull moduleResolver, NotNull fileResolver, + const ScopePtr& globalScope, FrontendOptions options, bool recordJsonLog); } // namespace Luau diff --git a/Analysis/include/Luau/Refinement.h b/Analysis/include/Luau/Refinement.h index e7d3cf23b..fecf459ad 100644 --- a/Analysis/include/Luau/Refinement.h +++ b/Analysis/include/Luau/Refinement.h @@ -1,13 +1,15 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once -#include "Luau/Def.h" +#include "Luau/NotNull.h" #include "Luau/TypedAllocator.h" #include "Luau/Variant.h" namespace Luau { +using BreadcrumbId = NotNull; + struct Type; using TypeId = const Type*; @@ -50,7 +52,7 @@ struct Equivalence struct Proposition { - DefId def; + BreadcrumbId breadcrumb; TypeId discriminantTy; }; @@ -67,7 +69,7 @@ struct RefinementArena RefinementId conjunction(RefinementId lhs, RefinementId rhs); RefinementId disjunction(RefinementId lhs, RefinementId rhs); RefinementId equivalence(RefinementId lhs, RefinementId rhs); - RefinementId proposition(DefId def, TypeId discriminantTy); + RefinementId proposition(BreadcrumbId breadcrumb, TypeId discriminantTy); private: TypedAllocator allocator; diff --git a/Analysis/include/Luau/Scope.h b/Analysis/include/Luau/Scope.h index 85a36fc90..0d3972672 100644 --- a/Analysis/include/Luau/Scope.h +++ b/Analysis/include/Luau/Scope.h @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include "Luau/Def.h" #include "Luau/Location.h" #include "Luau/NotNull.h" #include "Luau/Type.h" diff --git a/Analysis/include/Luau/Type.h b/Analysis/include/Luau/Type.h index d009001b6..cf1f8dae4 100644 --- a/Analysis/include/Luau/Type.h +++ b/Analysis/include/Luau/Type.h @@ -4,9 +4,7 @@ #include "Luau/Ast.h" #include "Luau/Common.h" #include "Luau/Refinement.h" -#include "Luau/DataFlowGraph.h" #include "Luau/DenseHash.h" -#include "Luau/Def.h" #include "Luau/NotNull.h" #include "Luau/Predicate.h" #include "Luau/Unifiable.h" @@ -662,6 +660,7 @@ struct BuiltinTypes const TypeId functionType; const TypeId classType; const TypeId tableType; + const TypeId emptyTableType; const TypeId trueType; const TypeId falseType; const TypeId anyType; diff --git a/Analysis/include/Luau/TypeReduction.h b/Analysis/include/Luau/TypeReduction.h index 80a7ac596..3f64870ab 100644 --- a/Analysis/include/Luau/TypeReduction.h +++ b/Analysis/include/Luau/TypeReduction.h @@ -54,8 +54,8 @@ struct TypeReductionOptions struct TypeReduction { - explicit TypeReduction( - NotNull arena, NotNull builtinTypes, NotNull handle, const TypeReductionOptions& opts = {}); + explicit TypeReduction(NotNull arena, NotNull builtinTypes, NotNull handle, + const TypeReductionOptions& opts = {}); TypeReduction(const TypeReduction&) = delete; TypeReduction& operator=(const TypeReduction&) = delete; diff --git a/Analysis/include/Luau/TypeUtils.h b/Analysis/include/Luau/TypeUtils.h index 3f535a03f..42ba40522 100644 --- a/Analysis/include/Luau/TypeUtils.h +++ b/Analysis/include/Luau/TypeUtils.h @@ -27,7 +27,8 @@ std::pair> getParameterExtents(const TxnLog* log, // Extend the provided pack to at least `length` types. // Returns a temporary TypePack that contains those types plus a tail. -TypePack extendTypePack(TypeArena& arena, NotNull builtinTypes, TypePackId pack, size_t length); +TypePack extendTypePack( + TypeArena& arena, NotNull builtinTypes, TypePackId pack, size_t length, std::vector> overrides = {}); /** * Reduces a union by decomposing to the any/error type if it appears in the diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index ebfff4c29..50024e3fd 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -61,8 +61,9 @@ struct Unifier ErrorVec errors; Location location; Variance variance = Covariant; - bool normalize; // Normalize unions and intersections if necessary - bool useScopes = false; // If true, we use the scope hierarchy rather than TypeLevels + bool normalize = true; // Normalize unions and intersections if necessary + bool checkInhabited = true; // Normalize types to check if they are inhabited + bool useScopes = false; // If true, we use the scope hierarchy rather than TypeLevels CountMismatch::Context ctx = CountMismatch::Arg; UnifierSharedState& sharedState; @@ -155,5 +156,6 @@ struct Unifier }; void promoteTypeLevels(TxnLog& log, const TypeArena* arena, TypeLevel minLevel, Scope* outerScope, bool useScope, TypePackId tp); +std::optional hasUnificationTooComplex(const ErrorVec& errors); } // namespace Luau diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index 85e27168a..1e0949711 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -16,6 +16,7 @@ LUAU_FASTFLAGVARIABLE(LuauCompleteTableKeysBetter, false); LUAU_FASTFLAGVARIABLE(LuauFixAutocompleteInWhile, false); LUAU_FASTFLAGVARIABLE(LuauFixAutocompleteInFor, false); +LUAU_FASTFLAGVARIABLE(LuauAutocompleteSkipNormalization, false); static const std::unordered_set kStatementStartingKeywords = { "while", "if", "local", "repeat", "function", "do", "for", "return", "break", "continue", "type", "export"}; @@ -145,6 +146,13 @@ static bool checkTypeMatch(TypeId subTy, TypeId superTy, NotNull scope, T Normalizer normalizer{typeArena, builtinTypes, NotNull{&unifierState}}; Unifier unifier(NotNull{&normalizer}, Mode::Strict, scope, Location(), Variance::Covariant); + if (FFlag::LuauAutocompleteSkipNormalization) + { + // Cost of normalization can be too high for autocomplete response time requirements + unifier.normalize = false; + unifier.checkInhabited = false; + } + return unifier.canUnify(subTy, superTy).empty(); } @@ -314,7 +322,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, NotNul { autocompleteProps(module, typeArena, builtinTypes, rootTy, mt->table, indexType, nodes, result, seen); - if (auto mtable = get(mt->metatable)) + if (auto mtable = get(follow(mt->metatable))) fillMetatableProps(mtable); } else if (auto i = get(ty)) @@ -1528,9 +1536,9 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M else if (!statIf->thenLocation || statIf->thenLocation->containsClosed(position)) return {{{"then", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Keyword}; } - else if (AstStatIf* statIf = extractStat(ancestry); - statIf && (!statIf->thenLocation || statIf->thenLocation->containsClosed(position)) && - (statIf->condition && !statIf->condition->location.containsClosed(position))) + else if (AstStatIf* statIf = extractStat(ancestry); statIf && + (!statIf->thenLocation || statIf->thenLocation->containsClosed(position)) && + (statIf->condition && !statIf->condition->location.containsClosed(position))) { AutocompleteEntryMap ret; ret["then"] = {AutocompleteEntryKind::Keyword}; diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index 7bb57208c..b111c504a 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -15,6 +15,8 @@ #include +LUAU_FASTFLAGVARIABLE(LuauDeprecateTableGetnForeach, false) + /** FIXME: Many of these type definitions are not quite completely accurate. * * Some of them require richer generics than we have. For instance, we do not yet have a way to talk @@ -335,6 +337,14 @@ void registerBuiltinGlobals(TypeChecker& typeChecker) ttv->props["freeze"] = makeProperty(makeFunction(arena, std::nullopt, {tabTy}, {tabTy}), "@luau/global/table.freeze"); ttv->props["clone"] = makeProperty(makeFunction(arena, std::nullopt, {tabTy}, {tabTy}), "@luau/global/table.clone"); + if (FFlag::LuauDeprecateTableGetnForeach) + { + ttv->props["getn"].deprecated = true; + ttv->props["getn"].deprecatedSuggestion = "#"; + ttv->props["foreach"].deprecated = true; + ttv->props["foreachi"].deprecated = true; + } + attachMagicFunction(ttv->props["pack"].type, magicFunctionPack); attachDcrMagicFunction(ttv->props["pack"].type, dcrMagicFunctionPack); } @@ -428,6 +438,14 @@ void registerBuiltinGlobals(Frontend& frontend) ttv->props["freeze"] = makeProperty(makeFunction(arena, std::nullopt, {tabTy}, {tabTy}), "@luau/global/table.freeze"); ttv->props["clone"] = makeProperty(makeFunction(arena, std::nullopt, {tabTy}, {tabTy}), "@luau/global/table.clone"); + if (FFlag::LuauDeprecateTableGetnForeach) + { + ttv->props["getn"].deprecated = true; + ttv->props["getn"].deprecatedSuggestion = "#"; + ttv->props["foreach"].deprecated = true; + ttv->props["foreachi"].deprecated = true; + } + attachMagicFunction(ttv->props["pack"].type, magicFunctionPack); } diff --git a/Analysis/src/ConstraintGraphBuilder.cpp b/Analysis/src/ConstraintGraphBuilder.cpp index fe412632c..9ee2b0882 100644 --- a/Analysis/src/ConstraintGraphBuilder.cpp +++ b/Analysis/src/ConstraintGraphBuilder.cpp @@ -2,11 +2,13 @@ #include "Luau/ConstraintGraphBuilder.h" #include "Luau/Ast.h" +#include "Luau/Breadcrumb.h" #include "Luau/Common.h" #include "Luau/Constraint.h" #include "Luau/DcrLogger.h" #include "Luau/ModuleResolver.h" #include "Luau/RecursionCounter.h" +#include "Luau/Refinement.h" #include "Luau/Scope.h" #include "Luau/TypeUtils.h" #include "Luau/Type.h" @@ -14,7 +16,6 @@ #include LUAU_FASTINT(LuauCheckRecursionLimit); -LUAU_FASTFLAG(DebugLuauLogSolverToJson); LUAU_FASTFLAG(DebugLuauMagicTypes); LUAU_FASTFLAG(LuauNegatedClassTypes); LUAU_FASTFLAG(SupportTypeAliasGoToDeclaration); @@ -145,9 +146,6 @@ ConstraintGraphBuilder::ConstraintGraphBuilder(const ModuleName& moduleName, Mod , globalScope(globalScope) , logger(logger) { - if (FFlag::DebugLuauLogSolverToJson) - LUAU_ASSERT(logger); - LUAU_ASSERT(module); } @@ -186,29 +184,42 @@ NotNull ConstraintGraphBuilder::addConstraint(const ScopePtr& scope, return NotNull{constraints.emplace_back(std::move(c)).get()}; } -static void unionRefinements(const std::unordered_map& lhs, const std::unordered_map& rhs, - std::unordered_map& dest, NotNull arena) +struct RefinementPartition +{ + // Types that we want to intersect against the type of the expression. + std::vector discriminantTypes; + + // Sometimes the type we're discriminating against is implicitly nil. + bool shouldAppendNilType = false; +}; + +using RefinementContext = std::unordered_map; + +static void unionRefinements(const RefinementContext& lhs, const RefinementContext& rhs, RefinementContext& dest, NotNull arena) { - for (auto [def, ty] : lhs) + for (auto& [def, partition] : lhs) { auto rhsIt = rhs.find(def); if (rhsIt == rhs.end()) continue; - std::vector discriminants{{ty, rhsIt->second}}; + LUAU_ASSERT(!partition.discriminantTypes.empty()); + LUAU_ASSERT(!rhsIt->second.discriminantTypes.empty()); - if (auto destIt = dest.find(def); destIt != dest.end()) - discriminants.push_back(destIt->second); + TypeId leftDiscriminantTy = + partition.discriminantTypes.size() == 1 ? partition.discriminantTypes[0] : arena->addType(IntersectionType{partition.discriminantTypes}); - dest[def] = arena->addType(UnionType{std::move(discriminants)}); + TypeId rightDiscriminantTy = rhsIt->second.discriminantTypes.size() == 1 ? rhsIt->second.discriminantTypes[0] + : arena->addType(IntersectionType{rhsIt->second.discriminantTypes}); + + dest[def].discriminantTypes.push_back(arena->addType(UnionType{{leftDiscriminantTy, rightDiscriminantTy}})); + dest[def].shouldAppendNilType |= partition.shouldAppendNilType || rhsIt->second.shouldAppendNilType; } } -static void computeRefinement(const ScopePtr& scope, RefinementId refinement, std::unordered_map* refis, bool sense, - NotNull arena, bool eq, std::vector* constraints) +static void computeRefinement(const ScopePtr& scope, RefinementId refinement, RefinementContext* refis, bool sense, NotNull arena, bool eq, + std::vector* constraints) { - using RefinementMap = std::unordered_map; - if (!refinement) return; else if (auto variadic = get(refinement)) @@ -220,8 +231,8 @@ static void computeRefinement(const ScopePtr& scope, RefinementId refinement, st return computeRefinement(scope, negation->refinement, refis, !sense, arena, eq, constraints); else if (auto conjunction = get(refinement)) { - RefinementMap lhsRefis; - RefinementMap rhsRefis; + RefinementContext lhsRefis; + RefinementContext rhsRefis; computeRefinement(scope, conjunction->lhs, sense ? refis : &lhsRefis, sense, arena, eq, constraints); computeRefinement(scope, conjunction->rhs, sense ? refis : &rhsRefis, sense, arena, eq, constraints); @@ -231,8 +242,8 @@ static void computeRefinement(const ScopePtr& scope, RefinementId refinement, st } else if (auto disjunction = get(refinement)) { - RefinementMap lhsRefis; - RefinementMap rhsRefis; + RefinementContext lhsRefis; + RefinementContext rhsRefis; computeRefinement(scope, disjunction->lhs, sense ? &lhsRefis : refis, sense, arena, eq, constraints); computeRefinement(scope, disjunction->rhs, sense ? &rhsRefis : refis, sense, arena, eq, constraints); @@ -256,30 +267,37 @@ static void computeRefinement(const ScopePtr& scope, RefinementId refinement, st constraints->push_back(SingletonOrTopTypeConstraint{discriminantTy, proposition->discriminantTy, !sense}); } - if (auto it = refis->find(proposition->def); it != refis->end()) - (*refis)[proposition->def] = arena->addType(IntersectionType{{discriminantTy, it->second}}); - else - (*refis)[proposition->def] = discriminantTy; - } -} + RefinementContext uncommittedRefis; + uncommittedRefis[proposition->breadcrumb->def].discriminantTypes.push_back(discriminantTy); -static std::pair computeDiscriminantType(NotNull arena, const ScopePtr& scope, DefId def, TypeId discriminantTy) -{ - LUAU_ASSERT(get(def)); + // When the top-level expression is `t[x]`, we want to refine it into `nil`, not `never`. + if ((sense || !eq) && getMetadata(proposition->breadcrumb)) + uncommittedRefis[proposition->breadcrumb->def].shouldAppendNilType = true; - while (const Cell* current = get(def)) - { - if (!current->field) - break; + for (NullableBreadcrumbId current = proposition->breadcrumb; current && current->previous; current = current->previous) + { + LUAU_ASSERT(get(current->def)); - TableType::Props props{{current->field->propName, Property{discriminantTy}}}; - discriminantTy = arena->addType(TableType{std::move(props), std::nullopt, TypeLevel{}, scope.get(), TableState::Sealed}); + // If this current breadcrumb has no metadata, it's no-op for the purpose of building a discriminant type. + if (!current->metadata) + continue; + else if (auto field = getMetadata(current)) + { + TableType::Props props{{field->prop, Property{discriminantTy}}}; + discriminantTy = arena->addType(TableType{std::move(props), std::nullopt, TypeLevel{}, scope.get(), TableState::Sealed}); + uncommittedRefis[current->previous->def].discriminantTypes.push_back(discriminantTy); + } + } - def = current->field->parent; - current = get(def); - } + // And now it's time to commit it. + for (auto& [def, partition] : uncommittedRefis) + { + for (TypeId discriminantTy : partition.discriminantTypes) + (*refis)[def].discriminantTypes.push_back(discriminantTy); - return {def, discriminantTy}; + (*refis)[def].shouldAppendNilType |= partition.shouldAppendNilType; + } + } } void ConstraintGraphBuilder::applyRefinements(const ScopePtr& scope, Location location, RefinementId refinement) @@ -287,19 +305,21 @@ void ConstraintGraphBuilder::applyRefinements(const ScopePtr& scope, Location lo if (!refinement) return; - std::unordered_map refinements; + RefinementContext refinements; std::vector constraints; computeRefinement(scope, refinement, &refinements, /*sense*/ true, arena, /*eq*/ false, &constraints); - for (auto [def, discriminantTy] : refinements) + for (auto& [def, partition] : refinements) { - auto [def2, discriminantTy2] = computeDiscriminantType(arena, scope, def, discriminantTy); - std::optional defTy = scope->lookup(def2); - if (!defTy) - ice->ice("Every DefId must map to a type!"); + if (std::optional defTy = scope->lookup(def)) + { + TypeId ty = *defTy; + if (partition.shouldAppendNilType) + ty = arena->addType(UnionType{{ty, builtinTypes->nilType}}); - TypeId resultTy = arena->addType(IntersectionType{{*defTy, discriminantTy2}}); - scope->dcrRefinements[def2] = resultTy; + partition.discriminantTypes.push_back(ty); + scope->dcrRefinements[def] = arena->addType(IntersectionType{std::move(partition.discriminantTypes)}); + } } for (auto& c : constraints) @@ -321,7 +341,7 @@ void ConstraintGraphBuilder::visit(AstStatBlock* block) visitBlockWithoutChildScope(scope, block); - if (FFlag::DebugLuauLogSolverToJson) + if (logger) logger->captureGenerationModule(module); } @@ -543,8 +563,8 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocal* local) // HACK: In the greedy solver, we say the type state of a variable is the type annotation itself, but // the actual type state is the corresponding initializer expression (if it exists) or nil otherwise. - if (auto def = dfg->getDef(l)) - scope->dcrRefinements[*def] = varTypes[i]; + BreadcrumbId bc = dfg->getBreadcrumb(l); + scope->dcrRefinements[bc->def] = varTypes[i]; } if (local->values.size > 0) @@ -578,8 +598,9 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocal* local) void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFor* for_) { + TypeId annotationTy = builtinTypes->numberType; if (for_->var->annotation) - resolveType(scope, for_->var->annotation, /* inTypeArguments */ false); + annotationTy = resolveType(scope, for_->var->annotation, /* inTypeArguments */ false); auto inferNumber = [&](AstExpr* expr) { if (!expr) @@ -594,7 +615,10 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFor* for_) inferNumber(for_->step); ScopePtr forScope = childScope(for_, scope); - forScope->bindings[for_->var] = Binding{builtinTypes->numberType, for_->var->location}; + forScope->bindings[for_->var] = Binding{annotationTy, for_->var->location}; + + BreadcrumbId bc = dfg->getBreadcrumb(for_->var); + forScope->dcrRefinements[bc->def] = annotationTy; visit(forScope, for_->body); } @@ -613,8 +637,8 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatForIn* forIn) loopScope->bindings[var] = Binding{ty, var->location}; variableTypes.push_back(ty); - if (auto def = dfg->getDef(var)) - loopScope->dcrRefinements[*def] = ty; + BreadcrumbId bc = dfg->getBreadcrumb(var); + loopScope->dcrRefinements[bc->def] = ty; } // It is always ok to provide too few variables, so we give this pack a free tail. @@ -638,10 +662,8 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatRepeat* repeat) { ScopePtr repeatScope = childScope(repeat, scope); - visit(repeatScope, repeat->body); + visitBlockWithoutChildScope(repeatScope, repeat->body); - // The condition does indeed have access to bindings from within the body of - // the loop. check(repeatScope, repeat->condition); } @@ -662,6 +684,10 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocalFunction* FunctionSignature sig = checkFunctionSignature(scope, function->func); sig.bodyScope->bindings[function->name] = Binding{sig.signature, function->func->location}; + BreadcrumbId bc = dfg->getBreadcrumb(function->name); + scope->dcrRefinements[bc->def] = functionType; + sig.bodyScope->dcrRefinements[bc->def] = sig.signature; + Checkpoint start = checkpoint(this); checkFunctionBody(sig.bodyScope, function->func); Checkpoint end = checkpoint(this); @@ -697,10 +723,7 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFunction* funct addConstraint(scope, function->name->location, SubtypeConstraint{generalizedType, *existingFunctionTy}); Symbol sym{localName->local}; - std::optional def = dfg->getDef(sym); - LUAU_ASSERT(def); scope->bindings[sym].typeId = generalizedType; - scope->dcrRefinements[*def] = generalizedType; } else scope->bindings[localName->local] = Binding{generalizedType, localName->location}; @@ -742,6 +765,9 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFunction* funct if (generalizedType == nullptr) ice->ice("generalizedType == nullptr", function->location); + if (NullableBreadcrumbId bc = dfg->getBreadcrumb(function->name)) + scope->dcrRefinements[bc->def] = generalizedType; + checkFunctionBody(sig.bodyScope, function->func); Checkpoint end = checkpoint(this); @@ -821,19 +847,19 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatCompoundAssign* { // We need to tweak the BinaryConstraint that we emit, so we cannot use the // strategy of falsifying an AST fragment. - TypeId varId = checkLValue(scope, assign->var); - Inference valueInf = check(scope, assign->value); + TypeId varTy = checkLValue(scope, assign->var); + TypeId valueTy = check(scope, assign->value).ty; TypeId resultType = arena->addType(BlockedType{}); addConstraint(scope, assign->location, - BinaryConstraint{assign->op, varId, valueInf.ty, resultType, assign, &module->astOriginalCallTypes, &module->astOverloadResolvedTypes}); - addConstraint(scope, assign->location, SubtypeConstraint{resultType, varId}); + BinaryConstraint{assign->op, varTy, valueTy, resultType, assign, &module->astOriginalCallTypes, &module->astOverloadResolvedTypes}); + addConstraint(scope, assign->location, SubtypeConstraint{resultType, varTy}); } void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatIf* ifStatement) { ScopePtr condScope = childScope(ifStatement->condition, scope); - auto [_, refinement] = check(condScope, ifStatement->condition, std::nullopt); + RefinementId refinement = check(condScope, ifStatement->condition, std::nullopt).refinement; ScopePtr thenScope = childScope(ifStatement->thenbody, scope); applyRefinements(thenScope, ifStatement->condition->location, refinement); @@ -921,7 +947,10 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareGlobal* Name globalName(global->name.value); module->declaredGlobals[globalName] = globalTy; - scope->bindings[global->name] = Binding{globalTy, global->location}; + rootScope->bindings[global->name] = Binding{globalTy, global->location}; + + BreadcrumbId bc = dfg->getBreadcrumb(global); + rootScope->dcrRefinements[bc->def] = globalTy; } static bool isMetamethod(const Name& name) @@ -1067,6 +1096,9 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareFunction module->declaredGlobals[fnName] = fnType; scope->bindings[global->name] = Binding{fnType, global->location}; + + BreadcrumbId bc = dfg->getBreadcrumb(global); + rootScope->dcrRefinements[bc->def] = fnType; } void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatError* error) @@ -1158,10 +1190,10 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCa exprArgs.push_back(indexExpr->expr); - if (auto def = dfg->getDef(indexExpr->expr)) + if (auto bc = dfg->getBreadcrumb(indexExpr->expr)) { TypeId discriminantTy = arena->addType(BlockedType{}); - returnRefinements.push_back(refinementArena.proposition(*def, discriminantTy)); + returnRefinements.push_back(refinementArena.proposition(NotNull{bc}, discriminantTy)); discriminantTypes.push_back(discriminantTy); } else @@ -1172,10 +1204,10 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCa { exprArgs.push_back(arg); - if (auto def = dfg->getDef(arg)) + if (auto bc = dfg->getBreadcrumb(arg)) { TypeId discriminantTy = arena->addType(BlockedType{}); - returnRefinements.push_back(refinementArena.proposition(*def, discriminantTy)); + returnRefinements.push_back(refinementArena.proposition(NotNull{bc}, discriminantTy)); discriminantTypes.push_back(discriminantTy); } else @@ -1186,6 +1218,8 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCa TypeId fnType = check(scope, call->func).ty; Checkpoint fnEndCheckpoint = checkpoint(this); + std::vector> expectedTypesForCall = getExpectedCallTypesForFunctionOverloads(fnType); + module->astOriginalCallTypes[call->func] = fnType; TypePackId expectedArgPack = arena->freshTypePack(scope.get()); @@ -1208,9 +1242,9 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCa TypePack expectedArgs; if (!needTail) - expectedArgs = extendTypePack(*arena, builtinTypes, expectedArgPack, exprArgs.size()); + expectedArgs = extendTypePack(*arena, builtinTypes, expectedArgPack, exprArgs.size(), expectedTypesForCall); else - expectedArgs = extendTypePack(*arena, builtinTypes, expectedArgPack, exprArgs.size() - 1); + expectedArgs = extendTypePack(*arena, builtinTypes, expectedArgPack, exprArgs.size() - 1, expectedTypesForCall); std::vector args; std::optional argTail; @@ -1278,9 +1312,9 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCa if (AstExprLocal* targetLocal = targetExpr->as()) { scope->bindings[targetLocal->local].typeId = resultTy; - auto def = dfg->getDef(targetLocal->local); - if (def) - scope->dcrRefinements[*def] = resultTy; // TODO: typestates: track this as an assignment + + BreadcrumbId bc = dfg->getBreadcrumb(targetLocal); + scope->dcrRefinements[bc->def] = resultTy; // TODO: typestates: track this as an assignment } return InferencePack{arena->addTypePack({resultTy}), {refinementArena.variadic(returnRefinements)}}; @@ -1451,36 +1485,35 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprConstantBo Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprLocal* local) { - std::optional resultTy; - auto def = dfg->getDef(local); - if (def) - resultTy = scope->lookup(*def); - - if (!resultTy) - { - if (auto ty = scope->lookup(local->local)) - resultTy = *ty; - } + BreadcrumbId bc = dfg->getBreadcrumb(local); - if (!resultTy) - return Inference{builtinTypes->errorRecoveryType()}; // TODO: replace with ice, locals should never exist before its definition. - - if (def) - return Inference{*resultTy, refinementArena.proposition(*def, builtinTypes->truthyType)}; + if (auto ty = scope->lookup(bc->def)) + return Inference{*ty, refinementArena.proposition(bc, builtinTypes->truthyType)}; + else if (auto ty = scope->lookup(local->local)) + return Inference{*ty, refinementArena.proposition(bc, builtinTypes->truthyType)}; else - return Inference{*resultTy}; + ice->ice("AstExprLocal came before its declaration?"); } Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprGlobal* global) { - if (std::optional ty = scope->lookup(global->name)) - return Inference{*ty}; + BreadcrumbId bc = dfg->getBreadcrumb(global); /* prepopulateGlobalScope() has already added all global functions to the environment by this point, so any * global that is not already in-scope is definitely an unknown symbol. */ - reportError(global->location, UnknownSymbol{global->name.value}); - return Inference{builtinTypes->errorRecoveryType()}; + if (auto ty = scope->lookup(bc->def)) + return Inference{*ty, refinementArena.proposition(bc, builtinTypes->truthyType)}; + else if (auto ty = scope->lookup(global->name)) + { + rootScope->dcrRefinements[bc->def] = *ty; + return Inference{*ty, refinementArena.proposition(bc, builtinTypes->truthyType)}; + } + else + { + reportError(global->location, UnknownSymbol{global->name.value}); + return Inference{builtinTypes->errorRecoveryType()}; + } } Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIndexName* indexName) @@ -1488,19 +1521,19 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIndexName* TypeId obj = check(scope, indexName->expr).ty; TypeId result = arena->addType(BlockedType{}); - std::optional def = dfg->getDef(indexName); - if (def) + NullableBreadcrumbId bc = dfg->getBreadcrumb(indexName); + if (bc) { - if (auto ty = scope->lookup(*def)) - return Inference{*ty, refinementArena.proposition(*def, builtinTypes->truthyType)}; - else - scope->dcrRefinements[*def] = result; + if (auto ty = scope->lookup(bc->def)) + return Inference{*ty, refinementArena.proposition(NotNull{bc}, builtinTypes->truthyType)}; + + scope->dcrRefinements[bc->def] = result; } addConstraint(scope, indexName->expr->location, HasPropConstraint{result, obj, indexName->index.value}); - if (def) - return Inference{result, refinementArena.proposition(*def, builtinTypes->truthyType)}; + if (bc) + return Inference{result, refinementArena.proposition(NotNull{bc}, builtinTypes->truthyType)}; else return Inference{result}; } @@ -1509,15 +1542,26 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIndexExpr* { TypeId obj = check(scope, indexExpr->expr).ty; TypeId indexType = check(scope, indexExpr->index).ty; - TypeId result = freshType(scope); + NullableBreadcrumbId bc = dfg->getBreadcrumb(indexExpr); + if (bc) + { + if (auto ty = scope->lookup(bc->def)) + return Inference{*ty, refinementArena.proposition(NotNull{bc}, builtinTypes->truthyType)}; + + scope->dcrRefinements[bc->def] = result; + } + TableIndexer indexer{indexType, result}; TypeId tableType = arena->addType(TableType{TableType::Props{}, TableIndexer{indexType, result}, TypeLevel{}, scope.get(), TableState::Free}); addConstraint(scope, indexExpr->expr->location, SubtypeConstraint{obj, tableType}); - return Inference{result}; + if (bc) + return Inference{result, refinementArena.proposition(NotNull{bc}, builtinTypes->truthyType)}; + else + return Inference{result}; } Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprUnary* unary) @@ -1545,7 +1589,7 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprBinary* bi Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIfElse* ifElse, std::optional expectedType) { ScopePtr condScope = childScope(ifElse->condition, scope); - auto [_, refinement] = check(scope, ifElse->condition); + RefinementId refinement = check(condScope, ifElse->condition).refinement; ScopePtr thenScope = childScope(ifElse->trueExpr, scope); applyRefinements(thenScope, ifElse->trueExpr->location, refinement); @@ -1600,8 +1644,8 @@ std::tuple ConstraintGraphBuilder::checkBinary( TypeId leftType = check(scope, binary->left).ty; TypeId rightType = check(scope, binary->right).ty; - std::optional def = dfg->getDef(typeguard->target); - if (!def) + NullableBreadcrumbId bc = dfg->getBreadcrumb(typeguard->target); + if (!bc) return {leftType, rightType, nullptr}; TypeId discriminantTy = builtinTypes->neverType; @@ -1637,7 +1681,7 @@ std::tuple ConstraintGraphBuilder::checkBinary( discriminantTy = ty; } - RefinementId proposition = refinementArena.proposition(*def, discriminantTy); + RefinementId proposition = refinementArena.proposition(NotNull{bc}, discriminantTy); if (binary->op == AstExprBinary::CompareEq) return {leftType, rightType, proposition}; else if (binary->op == AstExprBinary::CompareNe) @@ -1651,12 +1695,12 @@ std::tuple ConstraintGraphBuilder::checkBinary( TypeId rightType = check(scope, binary->right, expectedType, true).ty; RefinementId leftRefinement = nullptr; - if (auto def = dfg->getDef(binary->left)) - leftRefinement = refinementArena.proposition(*def, rightType); + if (auto bc = dfg->getBreadcrumb(binary->left)) + leftRefinement = refinementArena.proposition(NotNull{bc}, rightType); RefinementId rightRefinement = nullptr; - if (auto def = dfg->getDef(binary->right)) - rightRefinement = refinementArena.proposition(*def, leftType); + if (auto bc = dfg->getBreadcrumb(binary->right)) + rightRefinement = refinementArena.proposition(NotNull{bc}, leftType); if (binary->op == AstExprBinary::CompareNe) { @@ -1685,6 +1729,21 @@ std::vector ConstraintGraphBuilder::checkLValues(const ScopePtr& scope, return types; } +static bool isIndexNameEquivalent(AstExpr* expr) +{ + if (expr->is()) + return true; + + AstExprIndexExpr* e = expr->as(); + if (e == nullptr) + return false; + + if (!e->index->is()) + return false; + + return true; +} + /** * This function is mostly about identifying properties that are being inserted into unsealed tables. * @@ -1692,16 +1751,8 @@ std::vector ConstraintGraphBuilder::checkLValues(const ScopePtr& scope, */ TypeId ConstraintGraphBuilder::checkLValue(const ScopePtr& scope, AstExpr* expr) { - if (auto indexExpr = expr->as()) + if (auto indexExpr = expr->as(); indexExpr && !indexExpr->index->is()) { - if (auto constantString = indexExpr->index->as()) - { - AstName syntheticIndex{constantString->value.data}; - AstExprIndexName synthetic{ - indexExpr->location, indexExpr->expr, syntheticIndex, constantString->location, indexExpr->expr->location.end, '.'}; - return checkLValue(scope, &synthetic); - } - // An indexer is only interesting in an lvalue-ey way if it is at the // tail of an expression. // @@ -1724,7 +1775,7 @@ TypeId ConstraintGraphBuilder::checkLValue(const ScopePtr& scope, AstExpr* expr) return propType; } - else if (!expr->is()) + else if (!isIndexNameEquivalent(expr)) return check(scope, expr).ty; Symbol sym; @@ -1750,6 +1801,19 @@ TypeId ConstraintGraphBuilder::checkLValue(const ScopePtr& scope, AstExpr* expr) exprs.push_back(e); e = indexName->expr; } + else if (auto indexExpr = e->as()) + { + if (auto strIndex = indexExpr->index->as()) + { + segments.push_back(std::string(strIndex->value.data, strIndex->value.size)); + exprs.push_back(e); + e = indexExpr->expr; + } + else + { + return check(scope, expr).ty; + } + } else return check(scope, expr).ty; } @@ -1788,13 +1852,10 @@ TypeId ConstraintGraphBuilder::checkLValue(const ScopePtr& scope, AstExpr* expr) { symbolScope->bindings[sym].typeId = updatedType; - std::optional def = dfg->getDef(sym); - if (def) - { - // This can fail if the user is erroneously trying to augment a builtin - // table like os or string. - symbolScope->dcrRefinements[*def] = updatedType; - } + // This can fail if the user is erroneously trying to augment a builtin + // table like os or string. + if (auto bc = dfg->getBreadcrumb(e)) + symbolScope->dcrRefinements[bc->def] = updatedType; } return propTy; @@ -1984,36 +2045,32 @@ ConstraintGraphBuilder::FunctionSignature ConstraintGraphBuilder::checkFunctionS argTypes.push_back(selfType); argNames.emplace_back(FunctionArgument{fn->self->name.value, fn->self->location}); signatureScope->bindings[fn->self] = Binding{selfType, fn->self->location}; + + BreadcrumbId bc = dfg->getBreadcrumb(fn->self); + signatureScope->dcrRefinements[bc->def] = selfType; } for (size_t i = 0; i < fn->args.size; ++i) { AstLocal* local = fn->args.data[i]; - TypeId t = freshType(signatureScope); - argTypes.push_back(t); - argNames.emplace_back(FunctionArgument{local->name.value, local->location}); - signatureScope->bindings[local] = Binding{t, local->location}; - - auto def = dfg->getDef(local); - LUAU_ASSERT(def); - signatureScope->dcrRefinements[*def] = t; - - TypeId annotationTy = t; - + TypeId argTy = nullptr; if (local->annotation) + argTy = resolveType(signatureScope, local->annotation, /* inTypeArguments */ false, /* replaceErrorWithFresh*/ true); + else { - annotationTy = resolveType(signatureScope, local->annotation, /* inTypeArguments */ false); - // If we provide an annotation that is wrong, type inference should ignore the annotation - // and try to infer a fresh type, like in the old solver - if (get(follow(annotationTy))) - annotationTy = freshType(signatureScope); - addConstraint(signatureScope, local->annotation->location, SubtypeConstraint{t, annotationTy}); - } - else if (i < expectedArgPack.head.size()) - { - addConstraint(signatureScope, local->location, SubtypeConstraint{t, expectedArgPack.head[i]}); + argTy = freshType(signatureScope); + + if (i < expectedArgPack.head.size()) + addConstraint(signatureScope, local->location, SubtypeConstraint{argTy, expectedArgPack.head[i]}); } + + argTypes.push_back(argTy); + argNames.emplace_back(FunctionArgument{local->name.value, local->location}); + signatureScope->bindings[local] = Binding{argTy, local->location}; + + BreadcrumbId bc = dfg->getBreadcrumb(local); + signatureScope->dcrRefinements[bc->def] = argTy; } TypePackId varargPack = nullptr; @@ -2022,7 +2079,8 @@ ConstraintGraphBuilder::FunctionSignature ConstraintGraphBuilder::checkFunctionS { if (fn->varargAnnotation) { - TypePackId annotationType = resolveTypePack(signatureScope, fn->varargAnnotation, /* inTypeArguments */ false); + TypePackId annotationType = + resolveTypePack(signatureScope, fn->varargAnnotation, /* inTypeArguments */ false, /* replaceErrorWithFresh */ true); varargPack = annotationType; } else if (expectedArgPack.tail && get(*expectedArgPack.tail)) @@ -2049,8 +2107,8 @@ ConstraintGraphBuilder::FunctionSignature ConstraintGraphBuilder::checkFunctionS // Type checking will sort out any discrepancies later. if (fn->returnAnnotation) { - TypePackId annotatedRetType = resolveTypePack(signatureScope, *fn->returnAnnotation, /* inTypeArguments */ false); - + TypePackId annotatedRetType = + resolveTypePack(signatureScope, *fn->returnAnnotation, /* inTypeArguments */ false, /* replaceErrorWithFresh*/ true); // We bind the annotated type directly here so that, when we need to // generate constraints for return types, we have a guarantee that we // know the annotated return type already, if one was provided. @@ -2098,7 +2156,7 @@ void ConstraintGraphBuilder::checkFunctionBody(const ScopePtr& scope, AstExprFun } } -TypeId ConstraintGraphBuilder::resolveType(const ScopePtr& scope, AstType* ty, bool inTypeArguments) +TypeId ConstraintGraphBuilder::resolveType(const ScopePtr& scope, AstType* ty, bool inTypeArguments, bool replaceErrorWithFresh) { TypeId result = nullptr; @@ -2176,6 +2234,8 @@ TypeId ConstraintGraphBuilder::resolveType(const ScopePtr& scope, AstType* ty, b else { result = builtinTypes->errorRecoveryType(); + if (replaceErrorWithFresh) + result = freshType(scope); } } else if (auto tab = ty->as()) @@ -2239,8 +2299,8 @@ TypeId ConstraintGraphBuilder::resolveType(const ScopePtr& scope, AstType* ty, b signatureScope = scope; } - TypePackId argTypes = resolveTypePack(signatureScope, fn->argTypes, inTypeArguments); - TypePackId returnTypes = resolveTypePack(signatureScope, fn->returnTypes, inTypeArguments); + TypePackId argTypes = resolveTypePack(signatureScope, fn->argTypes, inTypeArguments, replaceErrorWithFresh); + TypePackId returnTypes = resolveTypePack(signatureScope, fn->returnTypes, inTypeArguments, replaceErrorWithFresh); // TODO: FunctionType needs a pointer to the scope so that we know // how to quantify/instantiate it. @@ -2307,6 +2367,8 @@ TypeId ConstraintGraphBuilder::resolveType(const ScopePtr& scope, AstType* ty, b else if (ty->is()) { result = builtinTypes->errorRecoveryType(); + if (replaceErrorWithFresh) + result = freshType(scope); } else { @@ -2318,18 +2380,16 @@ TypeId ConstraintGraphBuilder::resolveType(const ScopePtr& scope, AstType* ty, b return result; } -TypePackId ConstraintGraphBuilder::resolveTypePack(const ScopePtr& scope, AstTypePack* tp, bool inTypeArgument) +TypePackId ConstraintGraphBuilder::resolveTypePack(const ScopePtr& scope, AstTypePack* tp, bool inTypeArgument, bool replaceErrorWithFresh) { TypePackId result; if (auto expl = tp->as()) { - result = resolveTypePack(scope, expl->typeList, inTypeArgument); + result = resolveTypePack(scope, expl->typeList, inTypeArgument, replaceErrorWithFresh); } else if (auto var = tp->as()) { - TypeId ty = resolveType(scope, var->variadicType, inTypeArgument); - if (get(follow(ty))) - ty = freshType(scope); + TypeId ty = resolveType(scope, var->variadicType, inTypeArgument, replaceErrorWithFresh); result = arena->addTypePack(TypePackVar{VariadicTypePack{ty}}); } else if (auto gen = tp->as()) @@ -2354,19 +2414,19 @@ TypePackId ConstraintGraphBuilder::resolveTypePack(const ScopePtr& scope, AstTyp return result; } -TypePackId ConstraintGraphBuilder::resolveTypePack(const ScopePtr& scope, const AstTypeList& list, bool inTypeArguments) +TypePackId ConstraintGraphBuilder::resolveTypePack(const ScopePtr& scope, const AstTypeList& list, bool inTypeArguments, bool replaceErrorWithFresh) { std::vector head; for (AstType* headTy : list.types) { - head.push_back(resolveType(scope, headTy, inTypeArguments)); + head.push_back(resolveType(scope, headTy, inTypeArguments, replaceErrorWithFresh)); } std::optional tail = std::nullopt; if (list.tailType) { - tail = resolveTypePack(scope, list.tailType, inTypeArguments); + tail = resolveTypePack(scope, list.tailType, inTypeArguments, replaceErrorWithFresh); } return arena->addTypePack(TypePack{head, tail}); @@ -2454,7 +2514,7 @@ void ConstraintGraphBuilder::reportError(Location location, TypeErrorData err) { errors.push_back(TypeError{location, moduleName, std::move(err)}); - if (FFlag::DebugLuauLogSolverToJson) + if (logger) logger->captureGenerationError(errors.back()); } @@ -2462,7 +2522,7 @@ void ConstraintGraphBuilder::reportCodeTooComplex(Location location) { errors.push_back(TypeError{location, moduleName, CodeTooComplex{}}); - if (FFlag::DebugLuauLogSolverToJson) + if (logger) logger->captureGenerationError(errors.back()); } @@ -2493,6 +2553,69 @@ void ConstraintGraphBuilder::prepopulateGlobalScope(const ScopePtr& globalScope, program->visit(&gp); } +std::vector> ConstraintGraphBuilder::getExpectedCallTypesForFunctionOverloads(const TypeId fnType) +{ + std::vector funTys; + if (auto it = get(follow(fnType))) + { + for (TypeId intersectionComponent : it) + { + funTys.push_back(intersectionComponent); + } + } + + std::vector> expectedTypes; + // For a list of functions f_0 : e_0 -> r_0, ... f_n : e_n -> r_n, + // emit a list of arguments that the function could take at each position + // by unioning the arguments at each place + auto assignOption = [this, &expectedTypes](size_t index, TypeId ty) { + if (index == expectedTypes.size()) + expectedTypes.push_back(ty); + else if (ty) + { + auto& el = expectedTypes[index]; + + if (!el) + el = ty; + else + { + std::vector result = reduceUnion({*el, ty}); + if (result.empty()) + el = builtinTypes->neverType; + else if (result.size() == 1) + el = result[0]; + else + el = module->internalTypes.addType(UnionType{std::move(result)}); + } + } + }; + + for (const TypeId overload : funTys) + { + if (const FunctionType* ftv = get(follow(overload))) + { + auto [argsHead, argsTail] = flatten(ftv->argTypes); + size_t start = ftv->hasSelf ? 1 : 0; + size_t index = 0; + for (size_t i = start; i < argsHead.size(); ++i) + assignOption(index++, argsHead[i]); + if (argsTail) + { + argsTail = follow(*argsTail); + if (const VariadicTypePack* vtp = get(*argsTail)) + { + while (index < funTys.size()) + assignOption(index++, vtp->ty); + } + } + } + } + + // TODO vvijay Feb 24, 2023 apparently we have to demote the types here? + + return expectedTypes; +} + std::vector> borrowConstraints(const std::vector& constraints) { std::vector> result; diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index 96673e3dc..3cb4e4e7e 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -17,7 +17,6 @@ #include "Luau/VisitType.h" LUAU_FASTFLAGVARIABLE(DebugLuauLogSolver, false); -LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJson, false); namespace Luau { @@ -261,9 +260,6 @@ ConstraintSolver::ConstraintSolver(NotNull normalizer, NotNullcaptureInitialSolverState(rootScope, unsolvedConstraints); } @@ -320,7 +316,7 @@ void ConstraintSolver::run() std::string saveMe = FFlag::DebugLuauLogSolver ? toString(*c, opts) : std::string{}; StepSnapshot snapshot; - if (FFlag::DebugLuauLogSolverToJson) + if (logger) { snapshot = logger->prepareStepSnapshot(rootScope, c, force, unsolvedConstraints); } @@ -334,7 +330,7 @@ void ConstraintSolver::run() unblock(c); unsolvedConstraints.erase(unsolvedConstraints.begin() + i); - if (FFlag::DebugLuauLogSolverToJson) + if (logger) { logger->commitStepSnapshot(snapshot); } @@ -393,7 +389,7 @@ void ConstraintSolver::run() dumpBindings(rootScope, opts); } - if (FFlag::DebugLuauLogSolverToJson) + if (logger) { logger->captureFinalSolverState(rootScope, unsolvedConstraints); } @@ -486,6 +482,9 @@ bool ConstraintSolver::tryDispatch(const SubtypeConstraint& c, NotNullreduce(subjectType).value_or(subjectType); - std::optional resultType = lookupTableProp(subjectType, c.prop); - if (!resultType) + auto [blocked, result] = lookupTableProp(subjectType, c.prop); + if (!blocked.empty()) { - asMutable(c.resultType)->ty.emplace(builtinTypes->errorRecoveryType()); - unblock(c.resultType); - return true; - } + for (TypeId blocked : blocked) + block(blocked, constraint); - if (isBlocked(*resultType)) - { - block(*resultType, constraint); return false; } - asMutable(c.resultType)->ty.emplace(*resultType); + asMutable(c.resultType)->ty.emplace(result.value_or(builtinTypes->errorRecoveryType())); + unblock(c.resultType); return true; } @@ -1426,17 +1421,18 @@ bool ConstraintSolver::tryDispatch(const SetPropConstraint& c, NotNull existingPropType = subjectType; for (const std::string& segment : c.path) { - ErrorVec e; - std::optional propTy = lookupTableProp(*existingPropType, segment); - if (!propTy) - { - existingPropType = std::nullopt; + if (!existingPropType) break; + + auto [blocked, result] = lookupTableProp(*existingPropType, segment); + if (!blocked.empty()) + { + for (TypeId blocked : blocked) + block(blocked, constraint); + return false; } - else if (isBlocked(*propTy)) - return block(*propTy, constraint); - else - existingPropType = follow(*propTy); + + existingPropType = result; } auto bind = [](TypeId a, TypeId b) { @@ -1451,6 +1447,9 @@ bool ConstraintSolver::tryDispatch(const SetPropConstraint& c, NotNull(subjectType)) + subjectType = follow(mt->table); + if (get(subjectType) || get(subjectType) || get(subjectType)) { bind(c.resultType, subjectType); @@ -1504,8 +1503,8 @@ bool ConstraintSolver::tryDispatch(const SetPropConstraint& c, NotNull(subjectType)) { - // Classes never change shape as a result of property assignments. - // The result is always the subject. + // Classes and intersections never change shape as a result of property + // assignments. The result is always the subject. bind(c.resultType, subjectType); return true; } @@ -1833,122 +1832,68 @@ bool ConstraintSolver::tryDispatchIterableFunction( return true; } -std::optional ConstraintSolver::lookupTableProp(TypeId subjectType, const std::string& propName) +std::pair, std::optional> ConstraintSolver::lookupTableProp(TypeId subjectType, const std::string& propName) { std::unordered_set seen; return lookupTableProp(subjectType, propName, seen); } -std::optional ConstraintSolver::lookupTableProp(TypeId subjectType, const std::string& propName, std::unordered_set& seen) +std::pair, std::optional> ConstraintSolver::lookupTableProp(TypeId subjectType, const std::string& propName, std::unordered_set& seen) { if (!seen.insert(subjectType).second) - return std::nullopt; - - auto collectParts = [&](auto&& unionOrIntersection) -> std::pair, std::vector> { - std::optional blocked; - - std::vector parts; - std::vector freeParts; - for (TypeId expectedPart : unionOrIntersection) - { - expectedPart = follow(expectedPart); - if (isBlocked(expectedPart) || get(expectedPart)) - blocked = expectedPart; - else if (const TableType* ttv = get(follow(expectedPart))) - { - if (auto prop = ttv->props.find(propName); prop != ttv->props.end()) - parts.push_back(prop->second.type); - else if (ttv->indexer && maybeString(ttv->indexer->indexType)) - parts.push_back(ttv->indexer->indexResultType); - } - else if (get(expectedPart)) - { - freeParts.push_back(expectedPart); - } - } - - // If the only thing resembling a match is a single fresh type, we can - // confidently tablify it. If other types match or if there are more - // than one free type, we can't do anything. - if (parts.empty() && 1 == freeParts.size()) - { - TypeId freePart = freeParts.front(); - const FreeType* ft = get(freePart); - LUAU_ASSERT(ft); - Scope* scope = ft->scope; - - TableType* tt = &asMutable(freePart)->ty.emplace(); - tt->state = TableState::Free; - tt->scope = scope; - TypeId propType = arena->freshType(scope); - tt->props[propName] = Property{propType}; - - parts.push_back(propType); - } - - return {blocked, parts}; - }; + return {}; - std::optional resultType; + subjectType = follow(subjectType); - if (get(subjectType) || get(subjectType)) + if (isBlocked(subjectType)) + return {{subjectType}, std::nullopt}; + else if (get(subjectType) || get(subjectType)) { - return subjectType; + return {{}, subjectType}; } else if (auto ttv = getMutable(subjectType)) { if (auto prop = ttv->props.find(propName); prop != ttv->props.end()) - resultType = prop->second.type; + return {{}, prop->second.type}; else if (ttv->indexer && maybeString(ttv->indexer->indexType)) - resultType = ttv->indexer->indexResultType; + return {{}, ttv->indexer->indexResultType}; else if (ttv->state == TableState::Free) { - resultType = arena->addType(FreeType{ttv->scope}); - ttv->props[propName] = Property{*resultType}; + TypeId result = arena->freshType(ttv->scope); + ttv->props[propName] = Property{result}; + return {{}, result}; } } else if (auto mt = get(subjectType)) { - if (auto p = lookupTableProp(mt->table, propName, seen)) - return p; + auto [blocked, result] = lookupTableProp(mt->table, propName, seen); + if (!blocked.empty() || result) + return {blocked, result}; TypeId mtt = follow(mt->metatable); if (get(mtt)) - return mtt; + return {{mtt}, std::nullopt}; else if (auto metatable = get(mtt)) { auto indexProp = metatable->props.find("__index"); if (indexProp == metatable->props.end()) - return std::nullopt; + return {{}, result}; // TODO: __index can be an overloaded function. TypeId indexType = follow(indexProp->second.type); if (auto ft = get(indexType)) - { - std::optional ret = first(ft->retTypes); - if (ret) - return *ret; - else - return std::nullopt; - } - - return lookupTableProp(indexType, propName, seen); + return {{}, first(ft->retTypes)}; + else + return lookupTableProp(indexType, propName, seen); } } else if (auto ct = get(subjectType)) { - while (ct) - { - if (auto prop = ct->props.find(propName); prop != ct->props.end()) - return prop->second.type; - else if (ct->parent) - ct = get(follow(*ct->parent)); - else - break; - } + if (auto p = lookupClassProp(ct, propName)) + return {{}, p->type}; } else if (auto pt = get(subjectType); pt && pt->metatable) { @@ -1957,38 +1902,70 @@ std::optional ConstraintSolver::lookupTableProp(TypeId subjectType, cons auto indexProp = metatable->props.find("__index"); if (indexProp == metatable->props.end()) - return std::nullopt; + return {{}, std::nullopt}; return lookupTableProp(indexProp->second.type, propName, seen); } + else if (auto ft = get(subjectType)) + { + Scope* scope = ft->scope; + + TableType* tt = &asMutable(subjectType)->ty.emplace(); + tt->state = TableState::Free; + tt->scope = scope; + TypeId propType = arena->freshType(scope); + tt->props[propName] = Property{propType}; + + return {{}, propType}; + } else if (auto utv = get(subjectType)) { - auto [blocked, parts] = collectParts(utv); + std::vector blocked; + std::vector options; - if (blocked) - resultType = *blocked; - else if (parts.size() == 1) - resultType = parts[0]; - else if (parts.size() > 1) - resultType = arena->addType(UnionType{std::move(parts)}); + for (TypeId ty : utv) + { + auto [innerBlocked, innerResult] = lookupTableProp(ty, propName, seen); + blocked.insert(blocked.end(), innerBlocked.begin(), innerBlocked.end()); + if (innerResult) + options.push_back(*innerResult); + } + + if (!blocked.empty()) + return {blocked, std::nullopt}; - // otherwise, nothing: no matching property + if (options.empty()) + return {{}, std::nullopt}; + else if (options.size() == 1) + return {{}, options[0]}; + else + return {{}, arena->addType(UnionType{std::move(options)})}; } else if (auto itv = get(subjectType)) { - auto [blocked, parts] = collectParts(itv); + std::vector blocked; + std::vector options; - if (blocked) - resultType = *blocked; - else if (parts.size() == 1) - resultType = parts[0]; - else if (parts.size() > 1) - resultType = arena->addType(IntersectionType{std::move(parts)}); + for (TypeId ty : itv) + { + auto [innerBlocked, innerResult] = lookupTableProp(ty, propName, seen); + blocked.insert(blocked.end(), innerBlocked.begin(), innerBlocked.end()); + if (innerResult) + options.push_back(*innerResult); + } + + if (!blocked.empty()) + return {blocked, std::nullopt}; - // otherwise, nothing: no matching property + if (options.empty()) + return {{}, std::nullopt}; + else if (options.size() == 1) + return {{}, options[0]}; + else + return {{}, arena->addType(IntersectionType{std::move(options)})}; } - return resultType; + return {{}, std::nullopt}; } void ConstraintSolver::block_(BlockedConstraintId target, NotNull constraint) @@ -2001,7 +1978,7 @@ void ConstraintSolver::block_(BlockedConstraintId target, NotNull target, NotNull constraint) { - if (FFlag::DebugLuauLogSolverToJson) + if (logger) logger->pushBlock(constraint, target); if (FFlag::DebugLuauLogSolver) @@ -2012,7 +1989,7 @@ void ConstraintSolver::block(NotNull target, NotNull constraint) { - if (FFlag::DebugLuauLogSolverToJson) + if (logger) logger->pushBlock(constraint, target); if (FFlag::DebugLuauLogSolver) @@ -2024,7 +2001,7 @@ bool ConstraintSolver::block(TypeId target, NotNull constraint bool ConstraintSolver::block(TypePackId target, NotNull constraint) { - if (FFlag::DebugLuauLogSolverToJson) + if (logger) logger->pushBlock(constraint, target); if (FFlag::DebugLuauLogSolver) @@ -2102,7 +2079,7 @@ void ConstraintSolver::unblock_(BlockedConstraintId progressed) void ConstraintSolver::unblock(NotNull progressed) { - if (FFlag::DebugLuauLogSolverToJson) + if (logger) logger->popBlock(progressed); return unblock_(progressed.get()); @@ -2110,7 +2087,7 @@ void ConstraintSolver::unblock(NotNull progressed) void ConstraintSolver::unblock(TypeId progressed) { - if (FFlag::DebugLuauLogSolverToJson) + if (logger) logger->popBlock(progressed); unblock_(progressed); @@ -2121,7 +2098,7 @@ void ConstraintSolver::unblock(TypeId progressed) void ConstraintSolver::unblock(TypePackId progressed) { - if (FFlag::DebugLuauLogSolverToJson) + if (logger) logger->popBlock(progressed); return unblock_(progressed); diff --git a/Analysis/src/DataFlowGraph.cpp b/Analysis/src/DataFlowGraph.cpp index 7e7166037..e73c7e8c9 100644 --- a/Analysis/src/DataFlowGraph.cpp +++ b/Analysis/src/DataFlowGraph.cpp @@ -1,7 +1,9 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/DataFlowGraph.h" +#include "Luau/Breadcrumb.h" #include "Luau/Error.h" +#include "Luau/Refinement.h" LUAU_FASTFLAG(DebugLuauFreezeArena) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) @@ -9,69 +11,97 @@ LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) namespace Luau { -std::optional DataFlowGraph::getDef(const AstExpr* expr) const +NullableBreadcrumbId DataFlowGraph::getBreadcrumb(const AstExpr* expr) const { // We need to skip through AstExprGroup because DFG doesn't try its best to transitively while (auto group = expr->as()) expr = group->expr; - if (auto def = astDefs.find(expr)) - return NotNull{*def}; - return std::nullopt; + if (auto bc = astBreadcrumbs.find(expr)) + return *bc; + return nullptr; } -std::optional DataFlowGraph::getDef(const AstLocal* local) const +BreadcrumbId DataFlowGraph::getBreadcrumb(const AstLocal* local) const { - if (auto def = localDefs.find(local)) - return NotNull{*def}; - return std::nullopt; + auto bc = localBreadcrumbs.find(local); + LUAU_ASSERT(bc); + return NotNull{*bc}; } -std::optional DataFlowGraph::getDef(const Symbol& symbol) const +BreadcrumbId DataFlowGraph::getBreadcrumb(const AstExprLocal* local) const { - if (symbol.local) - return getDef(symbol.local); - else - return std::nullopt; + auto bc = astBreadcrumbs.find(local); + LUAU_ASSERT(bc); + return NotNull{*bc}; } -DataFlowGraph DataFlowGraphBuilder::build(AstStatBlock* block, NotNull handle) +BreadcrumbId DataFlowGraph::getBreadcrumb(const AstExprGlobal* global) const { - LUAU_ASSERT(FFlag::DebugLuauDeferredConstraintResolution); + auto bc = astBreadcrumbs.find(global); + LUAU_ASSERT(bc); + return NotNull{*bc}; +} - DataFlowGraphBuilder builder; - builder.handle = handle; - builder.visit(nullptr, block); // nullptr is the root DFG scope. - if (FFlag::DebugLuauFreezeArena) - builder.arena->allocator.freeze(); - return std::move(builder.graph); +BreadcrumbId DataFlowGraph::getBreadcrumb(const AstStatDeclareGlobal* global) const +{ + auto bc = declaredBreadcrumbs.find(global); + LUAU_ASSERT(bc); + return NotNull{*bc}; } -DfgScope* DataFlowGraphBuilder::childScope(DfgScope* scope) +BreadcrumbId DataFlowGraph::getBreadcrumb(const AstStatDeclareFunction* func) const { - return scopes.emplace_back(new DfgScope{scope}).get(); + auto bc = declaredBreadcrumbs.find(func); + LUAU_ASSERT(bc); + return NotNull{*bc}; } -std::optional DataFlowGraphBuilder::use(DfgScope* scope, Symbol symbol, AstExpr* e) +NullableBreadcrumbId DfgScope::lookup(Symbol symbol) const { - for (DfgScope* current = scope; current; current = current->parent) + for (const DfgScope* current = this; current; current = current->parent) { - if (auto def = current->bindings.find(symbol)) + if (auto breadcrumb = current->bindings.find(symbol)) + return *breadcrumb; + } + + return nullptr; +} + +NullableBreadcrumbId DfgScope::lookup(DefId def, const std::string& key) const +{ + for (const DfgScope* current = this; current; current = current->parent) + { + if (auto map = props.find(def)) { - graph.astDefs[e] = *def; - return NotNull{*def}; + if (auto it = map->find(key); it != map->end()) + return it->second; } } - return std::nullopt; + return nullptr; +} + +DataFlowGraph DataFlowGraphBuilder::build(AstStatBlock* block, NotNull handle) +{ + LUAU_ASSERT(FFlag::DebugLuauDeferredConstraintResolution); + + DataFlowGraphBuilder builder; + builder.handle = handle; + builder.moduleScope = builder.childScope(nullptr); // nullptr is the root DFG scope. + builder.visitBlockWithoutChildScope(builder.moduleScope, block); + + if (FFlag::DebugLuauFreezeArena) + { + builder.defs->allocator.freeze(); + builder.breadcrumbs->allocator.freeze(); + } + + return std::move(builder.graph); } -DefId DataFlowGraphBuilder::use(DefId def, AstExprIndexName* e) +DfgScope* DataFlowGraphBuilder::childScope(DfgScope* scope) { - auto& propertyDef = props[def][e->index.value]; - if (!propertyDef) - propertyDef = arena->freshCell(def, e->index.value); - graph.astDefs[e] = propertyDef; - return NotNull{propertyDef}; + return scopes.emplace_back(new DfgScope{scope}).get(); } void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatBlock* b) @@ -119,27 +149,24 @@ void DataFlowGraphBuilder::visit(DfgScope* scope, AstStat* s) else if (auto l = s->as()) return visit(scope, l); else if (auto t = s->as()) - return; // ok - else if (auto d = s->as()) - return; // ok + return visit(scope, t); else if (auto d = s->as()) - return; // ok + return visit(scope, d); else if (auto d = s->as()) - return; // ok + return visit(scope, d); else if (auto d = s->as()) - return; // ok - else if (auto _ = s->as()) - return; // ok + return visit(scope, d); + else if (auto error = s->as()) + return visit(scope, error); else - handle->ice("Unknown AstStat in DataFlowGraphBuilder"); + handle->ice("Unknown AstStat in DataFlowGraphBuilder::visit"); } void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatIf* i) { - DfgScope* condScope = childScope(scope); - visitExpr(condScope, i->condition); - visit(condScope, i->thenbody); - + // TODO: type states and control flow analysis + visitExpr(scope, i->condition); + visit(scope, i->thenbody); if (i->elsebody) visit(scope, i->elsebody); } @@ -186,24 +213,41 @@ void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatExpr* e) void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatLocal* l) { - // TODO: alias tracking + // We're gonna need a `visitExprList` and `visitVariadicExpr` (function calls and `...`) + std::vector bcs; + bcs.reserve(l->values.size); for (AstExpr* e : l->values) - visitExpr(scope, e); + bcs.push_back(visitExpr(scope, e)); - for (AstLocal* local : l->vars) + for (size_t i = 0; i < l->vars.size; ++i) { - DefId def = arena->freshCell(); - graph.localDefs[local] = def; - scope->bindings[local] = def; + AstLocal* local = l->vars.data[i]; + if (local->annotation) + visitType(scope, local->annotation); + + // We need to create a new breadcrumb with new defs to intentionally avoid alias tracking. + BreadcrumbId bc = breadcrumbs->add(nullptr, defs->freshCell(), i < bcs.size() ? bcs[i]->metadata : std::nullopt); + graph.localBreadcrumbs[local] = bc; + scope->bindings[local] = bc; } } void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatFor* f) { DfgScope* forScope = childScope(scope); // TODO: loop scope. - DefId def = arena->freshCell(); - graph.localDefs[f->var] = def; - scope->bindings[f->var] = def; + + visitExpr(scope, f->from); + visitExpr(scope, f->to); + if (f->step) + visitExpr(scope, f->step); + + if (f->var->annotation) + visitType(forScope, f->var->annotation); + + // TODO: RangeMetadata. + BreadcrumbId bc = breadcrumbs->add(nullptr, defs->freshCell()); + graph.localBreadcrumbs[f->var] = bc; + scope->bindings[f->var] = bc; // TODO(controlflow): entry point has a back edge from exit point visit(forScope, f->body); @@ -215,12 +259,17 @@ void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatForIn* f) for (AstLocal* local : f->vars) { - DefId def = arena->freshCell(); - graph.localDefs[local] = def; - forScope->bindings[local] = def; + if (local->annotation) + visitType(forScope, local->annotation); + + // TODO: IterMetadata (different from RangeMetadata) + BreadcrumbId bc = breadcrumbs->add(nullptr, defs->freshCell()); + graph.localBreadcrumbs[local] = bc; + forScope->bindings[local] = bc; } // TODO(controlflow): entry point has a back edge from exit point + // We're gonna need a `visitExprList` and `visitVariadicExpr` (function calls and `...`) for (AstExpr* e : f->values) visitExpr(forScope, e); @@ -233,87 +282,117 @@ void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatAssign* a) visitExpr(scope, r); for (AstExpr* l : a->vars) - { - AstExpr* root = l; - - bool isUpdatable = true; - while (true) - { - if (root->is() || root->is()) - break; - - AstExprIndexName* indexName = root->as(); - if (!indexName) - { - isUpdatable = false; - break; - } - - root = indexName->expr; - } - - if (isUpdatable) - { - // TODO global? - if (auto exprLocal = root->as()) - { - DefId def = arena->freshCell(); - graph.astDefs[exprLocal] = def; - - // Update the def in the scope that introduced the local. Not - // the current scope. - AstLocal* local = exprLocal->local; - DfgScope* s = scope; - while (s && !s->bindings.find(local)) - s = s->parent; - LUAU_ASSERT(s && s->bindings.find(local)); - s->bindings[local] = def; - } - } - - visitExpr(scope, l); // TODO: they point to a new def!! - } + visitLValue(scope, l); } void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatCompoundAssign* c) { - // TODO(typestates): The lhs is being read and written to. This might or might not be annoying. + // TODO: This needs revisiting because this is incorrect. The `c->var` part is both being read and written to, + // but the `c->var` only has one pointer address, so we need to come up with a way to store both. + // For now, it's not important because we don't have type states, but it is going to be important, e.g. + // + // local a = 5 -- a[1] + // a += 5 -- a[2] = a[1] + 5 + // + // We can't just visit `c->var` as a rvalue and then separately traverse `c->var` as an lvalue, since that's O(n^2). + visitLValue(scope, c->var); visitExpr(scope, c->value); } void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatFunction* f) { - visitExpr(scope, f->name); + // In the old solver, we assumed that the name of the function is always a function in the body + // but this isn't true, e.g. the following example will print `5`, not a function address. + // + // local function f() print(f) end + // local g = f + // f = 5 + // g() --> 5 + // + // which is evidence that references to variables must be a phi node of all possible definitions, + // but for bug compatibility, we'll assume the same thing here. + visitLValue(scope, f->name); visitExpr(scope, f->func); } void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatLocalFunction* l) { - DefId def = arena->freshCell(); - graph.localDefs[l->name] = def; - scope->bindings[l->name] = def; + BreadcrumbId bc = breadcrumbs->add(nullptr, defs->freshCell()); + graph.localBreadcrumbs[l->name] = bc; + scope->bindings[l->name] = bc; visitExpr(scope, l->func); } -ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExpr* e) +void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatTypeAlias* t) +{ + DfgScope* unreachable = childScope(scope); + visitGenerics(unreachable, t->generics); + visitGenericPacks(unreachable, t->genericPacks); + visitType(unreachable, t->type); +} + +void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatDeclareGlobal* d) +{ + // TODO: AmbientDeclarationMetadata. + BreadcrumbId bc = breadcrumbs->add(nullptr, defs->freshCell()); + graph.declaredBreadcrumbs[d] = bc; + scope->bindings[d->name] = bc; + + visitType(scope, d->type); +} + +void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatDeclareFunction* d) +{ + // TODO: AmbientDeclarationMetadata. + BreadcrumbId bc = breadcrumbs->add(nullptr, defs->freshCell()); + graph.declaredBreadcrumbs[d] = bc; + scope->bindings[d->name] = bc; + + DfgScope* unreachable = childScope(scope); + visitGenerics(unreachable, d->generics); + visitGenericPacks(unreachable, d->genericPacks); + visitTypeList(unreachable, d->params); + visitTypeList(unreachable, d->retTypes); +} + +void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatDeclareClass* d) +{ + // This declaration does not "introduce" any bindings in value namespace, + // so there's no symbolic value to begin with. We'll traverse the properties + // because their type annotations may depend on something in the value namespace. + DfgScope* unreachable = childScope(scope); + for (AstDeclaredClassProp prop : d->props) + visitType(unreachable, prop.ty); +} + +void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatError* error) +{ + DfgScope* unreachable = childScope(scope); + for (AstStat* s : error->statements) + visit(unreachable, s); + for (AstExpr* e : error->expressions) + visitExpr(unreachable, e); +} + +BreadcrumbId DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExpr* e) { if (auto g = e->as()) return visitExpr(scope, g->expr); else if (auto c = e->as()) - return {}; // ok + return breadcrumbs->add(nullptr, defs->freshCell()); // ok else if (auto c = e->as()) - return {}; // ok + return breadcrumbs->add(nullptr, defs->freshCell()); // ok else if (auto c = e->as()) - return {}; // ok + return breadcrumbs->add(nullptr, defs->freshCell()); // ok else if (auto c = e->as()) - return {}; // ok + return breadcrumbs->add(nullptr, defs->freshCell()); // ok else if (auto l = e->as()) return visitExpr(scope, l); else if (auto g = e->as()) return visitExpr(scope, g); else if (auto v = e->as()) - return {}; // ok + return breadcrumbs->add(nullptr, defs->freshCell()); // ok else if (auto c = e->as()) return visitExpr(scope, c); else if (auto i = e->as()) @@ -334,76 +413,123 @@ ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExpr* e) return visitExpr(scope, i); else if (auto i = e->as()) return visitExpr(scope, i); - else if (auto _ = e->as()) - return {}; // ok + else if (auto error = e->as()) + return visitExpr(scope, error); else - handle->ice("Unknown AstExpr in DataFlowGraphBuilder"); + handle->ice("Unknown AstExpr in DataFlowGraphBuilder::visitExpr"); } -ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprLocal* l) +BreadcrumbId DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprLocal* l) { - return {use(scope, l->local, l)}; + NullableBreadcrumbId breadcrumb = scope->lookup(l->local); + if (!breadcrumb) + handle->ice("AstExprLocal came before its declaration?"); + + graph.astBreadcrumbs[l] = breadcrumb; + return NotNull{breadcrumb}; } -ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprGlobal* g) +BreadcrumbId DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprGlobal* g) { - return {use(scope, g->name, g)}; + NullableBreadcrumbId bc = scope->lookup(g->name); + if (!bc) + { + bc = breadcrumbs->add(nullptr, defs->freshCell()); + moduleScope->bindings[g->name] = bc; + } + + graph.astBreadcrumbs[g] = bc; + return NotNull{bc}; } -ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprCall* c) +BreadcrumbId DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprCall* c) { visitExpr(scope, c->func); for (AstExpr* arg : c->args) visitExpr(scope, arg); - return {}; + return breadcrumbs->add(nullptr, defs->freshCell()); } -ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprIndexName* i) +BreadcrumbId DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprIndexName* i) { - std::optional def = visitExpr(scope, i->expr).def; - if (!def) - return {}; + BreadcrumbId parentBreadcrumb = visitExpr(scope, i->expr); - return {use(*def, i)}; + std::string key = i->index.value; + NullableBreadcrumbId& propBreadcrumb = moduleScope->props[parentBreadcrumb->def][key]; + if (!propBreadcrumb) + propBreadcrumb = breadcrumbs->emplace(parentBreadcrumb, defs->freshCell(), key); + + graph.astBreadcrumbs[i] = propBreadcrumb; + return NotNull{propBreadcrumb}; } -ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprIndexExpr* i) +BreadcrumbId DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprIndexExpr* i) { - visitExpr(scope, i->expr); - visitExpr(scope, i->index); + BreadcrumbId parentBreadcrumb = visitExpr(scope, i->expr); + BreadcrumbId key = visitExpr(scope, i->index); - if (i->index->as()) + if (auto string = i->index->as()) { - // TODO: properties for the def + std::string key{string->value.data, string->value.size}; + NullableBreadcrumbId& propBreadcrumb = moduleScope->props[parentBreadcrumb->def][key]; + if (!propBreadcrumb) + propBreadcrumb = breadcrumbs->emplace(parentBreadcrumb, defs->freshCell(), key); + + graph.astBreadcrumbs[i] = NotNull{propBreadcrumb}; + return NotNull{propBreadcrumb}; } - return {}; + return breadcrumbs->emplace(nullptr, defs->freshCell(), key); } -ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprFunction* f) +BreadcrumbId DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprFunction* f) { + DfgScope* signatureScope = childScope(scope); + if (AstLocal* self = f->self) { - DefId def = arena->freshCell(); - graph.localDefs[self] = def; - scope->bindings[self] = def; + // There's no syntax for `self` to have an annotation if using `function t:m()` + LUAU_ASSERT(!self->annotation); + + // TODO: ParameterMetadata. + BreadcrumbId bc = breadcrumbs->add(nullptr, defs->freshCell()); + graph.localBreadcrumbs[self] = bc; + signatureScope->bindings[self] = bc; } for (AstLocal* param : f->args) { - DefId def = arena->freshCell(); - graph.localDefs[param] = def; - scope->bindings[param] = def; + if (param->annotation) + visitType(signatureScope, param->annotation); + + // TODO: ParameterMetadata. + BreadcrumbId bc = breadcrumbs->add(nullptr, defs->freshCell()); + graph.localBreadcrumbs[param] = bc; + signatureScope->bindings[param] = bc; } - visit(scope, f->body); + if (f->varargAnnotation) + visitTypePack(scope, f->varargAnnotation); - return {}; + if (f->returnAnnotation) + visitTypeList(signatureScope, *f->returnAnnotation); + + // TODO: function body can be re-entrant, as in mutations that occurs at the end of the function can also be + // visible to the beginning of the function, so statically speaking, the body of the function has an exit point + // that points back to itself, e.g. + // + // local function f() print(f) f = 5 end + // local g = f + // g() --> function: address + // g() --> 5 + visit(signatureScope, f->body); + + return breadcrumbs->add(nullptr, defs->freshCell()); } -ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprTable* t) +BreadcrumbId DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprTable* t) { for (AstExprTable::Item item : t->items) { @@ -412,47 +538,259 @@ ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprTabl visitExpr(scope, item.value); } - return {}; + return breadcrumbs->add(nullptr, defs->freshCell()); } -ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprUnary* u) +BreadcrumbId DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprUnary* u) { visitExpr(scope, u->expr); - return {}; + return breadcrumbs->add(nullptr, defs->freshCell()); } -ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprBinary* b) +BreadcrumbId DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprBinary* b) { visitExpr(scope, b->left); visitExpr(scope, b->right); - return {}; + return breadcrumbs->add(nullptr, defs->freshCell()); } -ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprTypeAssertion* t) +BreadcrumbId DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprTypeAssertion* t) { - ExpressionFlowGraph result = visitExpr(scope, t->expr); - // TODO: visit type - return result; + // TODO: TypeAssertionMetadata? + BreadcrumbId bc = visitExpr(scope, t->expr); + visitType(scope, t->annotation); + + return bc; } -ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprIfElse* i) +BreadcrumbId DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprIfElse* i) { - DfgScope* condScope = childScope(scope); - visitExpr(condScope, i->condition); - visitExpr(condScope, i->trueExpr); - + visitExpr(scope, i->condition); + visitExpr(scope, i->trueExpr); visitExpr(scope, i->falseExpr); - return {}; + return breadcrumbs->add(nullptr, defs->freshCell()); } -ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprInterpString* i) +BreadcrumbId DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprInterpString* i) { for (AstExpr* e : i->expressions) visitExpr(scope, e); - return {}; + + return breadcrumbs->add(nullptr, defs->freshCell()); +} + +BreadcrumbId DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprError* error) +{ + DfgScope* unreachable = childScope(scope); + for (AstExpr* e : error->expressions) + visitExpr(unreachable, e); + + return breadcrumbs->add(nullptr, defs->freshCell()); +} + +void DataFlowGraphBuilder::visitLValue(DfgScope* scope, AstExpr* e) +{ + if (auto l = e->as()) + return visitLValue(scope, l); + else if (auto g = e->as()) + return visitLValue(scope, g); + else if (auto i = e->as()) + return visitLValue(scope, i); + else if (auto i = e->as()) + return visitLValue(scope, i); + else if (auto error = e->as()) + { + visitExpr(scope, error); // TODO: is this right? + return; + } + else + handle->ice("Unknown AstExpr in DataFlowGraphBuilder::visitLValue"); +} + +void DataFlowGraphBuilder::visitLValue(DfgScope* scope, AstExprLocal* l) +{ + // Bug compatibility: we don't support type states yet, so we need to do this. + NullableBreadcrumbId bc = scope->lookup(l->local); + LUAU_ASSERT(bc); + + graph.astBreadcrumbs[l] = bc; + scope->bindings[l->local] = bc; +} + +void DataFlowGraphBuilder::visitLValue(DfgScope* scope, AstExprGlobal* g) +{ + // Bug compatibility: we don't support type states yet, so we need to do this. + NullableBreadcrumbId bc = scope->lookup(g->name); + if (!bc) + bc = breadcrumbs->add(nullptr, defs->freshCell()); + + graph.astBreadcrumbs[g] = bc; + scope->bindings[g->name] = bc; +} + +void DataFlowGraphBuilder::visitLValue(DfgScope* scope, AstExprIndexName* i) +{ + // Bug compatibility: we don't support type states yet, so we need to do this. + BreadcrumbId parentBreadcrumb = visitExpr(scope, i->expr); + + std::string key = i->index.value; + NullableBreadcrumbId propBreadcrumb = scope->lookup(parentBreadcrumb->def, key); + if (!propBreadcrumb) + { + propBreadcrumb = breadcrumbs->emplace(parentBreadcrumb, defs->freshCell(), key); + moduleScope->props[parentBreadcrumb->def][key] = propBreadcrumb; + } + + graph.astBreadcrumbs[i] = propBreadcrumb; +} + +void DataFlowGraphBuilder::visitLValue(DfgScope* scope, AstExprIndexExpr* i) +{ + BreadcrumbId parentBreadcrumb = visitExpr(scope, i->expr); + visitExpr(scope, i->index); + + if (auto string = i->index->as()) + { + std::string key{string->value.data, string->value.size}; + NullableBreadcrumbId propBreadcrumb = scope->lookup(parentBreadcrumb->def, key); + if (!propBreadcrumb) + { + propBreadcrumb = breadcrumbs->add(parentBreadcrumb, parentBreadcrumb->def); + moduleScope->props[parentBreadcrumb->def][key] = propBreadcrumb; + } + + graph.astBreadcrumbs[i] = propBreadcrumb; + } +} + +void DataFlowGraphBuilder::visitType(DfgScope* scope, AstType* t) +{ + if (auto r = t->as()) + return visitType(scope, r); + else if (auto table = t->as()) + return visitType(scope, table); + else if (auto f = t->as()) + return visitType(scope, f); + else if (auto tyof = t->as()) + return visitType(scope, tyof); + else if (auto u = t->as()) + return visitType(scope, u); + else if (auto i = t->as()) + return visitType(scope, i); + else if (auto e = t->as()) + return visitType(scope, e); + else if (auto s = t->as()) + return; // ok + else if (auto s = t->as()) + return; // ok + else + handle->ice("Unknown AstType in DataFlowGraphBuilder::visitType"); +} + +void DataFlowGraphBuilder::visitType(DfgScope* scope, AstTypeReference* r) +{ + for (AstTypeOrPack param : r->parameters) + { + if (param.type) + visitType(scope, param.type); + else + visitTypePack(scope, param.typePack); + } +} + +void DataFlowGraphBuilder::visitType(DfgScope* scope, AstTypeTable* t) +{ + for (AstTableProp p : t->props) + visitType(scope, p.type); + + if (t->indexer) + { + visitType(scope, t->indexer->indexType); + visitType(scope, t->indexer->resultType); + } +} + +void DataFlowGraphBuilder::visitType(DfgScope* scope, AstTypeFunction* f) +{ + visitGenerics(scope, f->generics); + visitGenericPacks(scope, f->genericPacks); + visitTypeList(scope, f->argTypes); + visitTypeList(scope, f->returnTypes); +} + +void DataFlowGraphBuilder::visitType(DfgScope* scope, AstTypeTypeof* t) +{ + visitExpr(scope, t->expr); +} + +void DataFlowGraphBuilder::visitType(DfgScope* scope, AstTypeUnion* u) +{ + for (AstType* t : u->types) + visitType(scope, t); +} + +void DataFlowGraphBuilder::visitType(DfgScope* scope, AstTypeIntersection* i) +{ + for (AstType* t : i->types) + visitType(scope, t); +} + +void DataFlowGraphBuilder::visitType(DfgScope* scope, AstTypeError* error) +{ + for (AstType* t : error->types) + visitType(scope, t); +} + +void DataFlowGraphBuilder::visitTypePack(DfgScope* scope, AstTypePack* p) +{ + if (auto e = p->as()) + return visitTypePack(scope, e); + else if (auto v = p->as()) + return visitTypePack(scope, v); + else if (auto g = p->as()) + return; // ok + else + handle->ice("Unknown AstTypePack in DataFlowGraphBuilder::visitTypePack"); +} + +void DataFlowGraphBuilder::visitTypePack(DfgScope* scope, AstTypePackExplicit* e) +{ + visitTypeList(scope, e->typeList); +} + +void DataFlowGraphBuilder::visitTypePack(DfgScope* scope, AstTypePackVariadic* v) +{ + visitType(scope, v->variadicType); +} + +void DataFlowGraphBuilder::visitTypeList(DfgScope* scope, AstTypeList l) +{ + for (AstType* t : l.types) + visitType(scope, t); + + if (l.tailType) + visitTypePack(scope, l.tailType); +} + +void DataFlowGraphBuilder::visitGenerics(DfgScope* scope, AstArray g) +{ + for (AstGenericType generic : g) + { + if (generic.defaultValue) + visitType(scope, generic.defaultValue); + } +} + +void DataFlowGraphBuilder::visitGenericPacks(DfgScope* scope, AstArray g) +{ + for (AstGenericTypePack generic : g) + { + if (generic.defaultValue) + visitTypePack(scope, generic.defaultValue); + } } } // namespace Luau diff --git a/Analysis/src/Def.cpp b/Analysis/src/Def.cpp index 8ce1129c6..7be075c25 100644 --- a/Analysis/src/Def.cpp +++ b/Analysis/src/Def.cpp @@ -6,12 +6,7 @@ namespace Luau DefId DefArena::freshCell() { - return NotNull{allocator.allocate(Def{Cell{std::nullopt}})}; -} - -DefId DefArena::freshCell(DefId parent, const std::string& prop) -{ - return NotNull{allocator.allocate(Def{Cell{FieldMetadata{parent, prop}}})}; + return NotNull{allocator.allocate(Def{Cell{}})}; } } // namespace Luau diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 91c72e447..b3e453db0 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -25,12 +25,13 @@ #include LUAU_FASTINT(LuauTypeInferIterationLimit) +LUAU_FASTINT(LuauTypeInferRecursionLimit) LUAU_FASTINT(LuauTarjanChildLimit) LUAU_FASTFLAG(LuauInferInNoCheckMode) LUAU_FASTFLAGVARIABLE(LuauKnowsTheDataModel3, false) LUAU_FASTINTVARIABLE(LuauAutocompleteCheckTimeoutMs, 100) LUAU_FASTFLAGVARIABLE(DebugLuauDeferredConstraintResolution, false) -LUAU_FASTFLAG(DebugLuauLogSolverToJson); +LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJson, false); namespace Luau { @@ -517,7 +518,7 @@ CheckResult Frontend::check(const ModuleName& name, std::optional& requireCycles, - NotNull builtinTypes, - NotNull iceHandler, - NotNull moduleResolver, - NotNull fileResolver, - const ScopePtr& globalScope, - NotNull unifierState, - FrontendOptions options -) { +ModulePtr check(const SourceModule& sourceModule, const std::vector& requireCycles, NotNull builtinTypes, + NotNull iceHandler, NotNull moduleResolver, NotNull fileResolver, + const ScopePtr& globalScope, FrontendOptions options) +{ + const bool recordJsonLog = FFlag::DebugLuauLogSolverToJson; + return check(sourceModule, requireCycles, builtinTypes, iceHandler, moduleResolver, fileResolver, globalScope, options, recordJsonLog); +} + +ModulePtr check(const SourceModule& sourceModule, const std::vector& requireCycles, NotNull builtinTypes, + NotNull iceHandler, NotNull moduleResolver, NotNull fileResolver, + const ScopePtr& globalScope, FrontendOptions options, bool recordJsonLog) +{ ModulePtr result = std::make_shared(); result->reduction = std::make_unique(NotNull{&result->internalTypes}, builtinTypes, iceHandler); std::unique_ptr logger; - if (FFlag::DebugLuauLogSolverToJson) + if (recordJsonLog) { logger = std::make_unique(); std::optional source = fileResolver->readSource(sourceModule.name); @@ -882,7 +886,11 @@ ModulePtr check( DataFlowGraph dfg = DataFlowGraphBuilder::build(sourceModule.root, iceHandler); - Normalizer normalizer{&result->internalTypes, builtinTypes, unifierState}; + UnifierSharedState unifierState{iceHandler}; + unifierState.counters.recursionLimit = FInt::LuauTypeInferRecursionLimit; + unifierState.counters.iterationLimit = FInt::LuauTypeInferIterationLimit; + + Normalizer normalizer{&result->internalTypes, builtinTypes, NotNull{&unifierState}}; ConstraintGraphBuilder cgb{ sourceModule.name, @@ -925,7 +933,7 @@ ModulePtr check( freeze(result->internalTypes); freeze(result->interfaceTypes); - if (FFlag::DebugLuauLogSolverToJson) + if (recordJsonLog) { std::string output = logger->compileOutput(); printf("%s\n", output.c_str()); @@ -934,20 +942,11 @@ ModulePtr check( return result; } -ModulePtr Frontend::check( - const SourceModule& sourceModule, Mode mode, std::vector requireCycles, bool forAutocomplete) +ModulePtr Frontend::check(const SourceModule& sourceModule, Mode mode, std::vector requireCycles, bool forAutocomplete, bool recordJsonLog) { - return Luau::check( - sourceModule, - requireCycles, - builtinTypes, - NotNull{&iceHandler}, - NotNull{forAutocomplete ? &moduleResolverForAutocomplete : &moduleResolver}, - NotNull{fileResolver}, - forAutocomplete ? typeCheckerForAutocomplete.globalScope : typeChecker.globalScope, - NotNull{&typeChecker.unifierState}, - options - ); + return Luau::check(sourceModule, requireCycles, builtinTypes, NotNull{&iceHandler}, + NotNull{forAutocomplete ? &moduleResolverForAutocomplete : &moduleResolver}, NotNull{fileResolver}, + forAutocomplete ? typeCheckerForAutocomplete.globalScope : typeChecker.globalScope, options, recordJsonLog); } // Read AST into sourceModules if necessary. Trace require()s. Report parse errors. diff --git a/Analysis/src/Linter.cpp b/Analysis/src/Linter.cpp index 65ad8a825..f850bd3d1 100644 --- a/Analysis/src/Linter.cpp +++ b/Analysis/src/Linter.cpp @@ -14,6 +14,8 @@ LUAU_FASTINTVARIABLE(LuauSuggestionDistance, 4) +LUAU_FASTFLAGVARIABLE(LuauImproveDeprecatedApiLint, false) + namespace Luau { @@ -2100,7 +2102,7 @@ class LintDeprecatedApi : AstVisitor public: LUAU_NOINLINE static void process(LintContext& context) { - if (!context.module) + if (!FFlag::LuauImproveDeprecatedApiLint && !context.module) return; LintDeprecatedApi pass{&context}; @@ -2117,26 +2119,51 @@ class LintDeprecatedApi : AstVisitor bool visit(AstExprIndexName* node) override { - std::optional ty = context->getType(node->expr); - if (!ty) - return true; + if (std::optional ty = context->getType(node->expr)) + check(node, follow(*ty)); + else if (AstExprGlobal* global = node->expr->as()) + if (FFlag::LuauImproveDeprecatedApiLint) + check(node->location, global->name, node->index); - if (const ClassType* cty = get(follow(*ty))) + return true; + } + + void check(AstExprIndexName* node, TypeId ty) + { + if (const ClassType* cty = get(ty)) { const Property* prop = lookupClassProp(cty, node->index.value); if (prop && prop->deprecated) report(node->location, *prop, cty->name.c_str(), node->index.value); } - else if (const TableType* tty = get(follow(*ty))) + else if (const TableType* tty = get(ty)) { auto prop = tty->props.find(node->index.value); if (prop != tty->props.end() && prop->second.deprecated) - report(node->location, prop->second, tty->name ? tty->name->c_str() : nullptr, node->index.value); + { + // strip synthetic typeof() for builtin tables + if (FFlag::LuauImproveDeprecatedApiLint && tty->name && tty->name->compare(0, 7, "typeof(") == 0 && tty->name->back() == ')') + report(node->location, prop->second, tty->name->substr(7, tty->name->length() - 8).c_str(), node->index.value); + else + report(node->location, prop->second, tty->name ? tty->name->c_str() : nullptr, node->index.value); + } } + } - return true; + void check(const Location& location, AstName global, AstName index) + { + if (const LintContext::Global* gv = context->builtinGlobals.find(global)) + { + if (const TableType* tty = get(gv->type)) + { + auto prop = tty->props.find(index.value); + + if (prop != tty->props.end() && prop->second.deprecated) + report(location, prop->second, global.value, index.value); + } + } } void report(const Location& location, const Property& prop, const char* container, const char* field) diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp index 0b7608104..0552bec03 100644 --- a/Analysis/src/Normalize.cpp +++ b/Analysis/src/Normalize.cpp @@ -2653,6 +2653,14 @@ bool Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there) intersectNormals(here, *negated); } } + else if (get(t)) + { + // HACK: Refinements sometimes intersect with ~any under the + // assumption that it is the same as any. + return true; + } + else if (auto nt = get(t)) + return intersectNormalWithTy(here, nt->ty); else { // TODO negated unions, intersections, table, and function. diff --git a/Analysis/src/Refinement.cpp b/Analysis/src/Refinement.cpp index 459379ad9..a81063c7b 100644 --- a/Analysis/src/Refinement.cpp +++ b/Analysis/src/Refinement.cpp @@ -4,6 +4,11 @@ namespace Luau { +RefinementId RefinementArena::variadic(const std::vector& refis) +{ + return NotNull{allocator.allocate(Variadic{refis})}; +} + RefinementId RefinementArena::negation(RefinementId refinement) { return NotNull{allocator.allocate(Negation{refinement})}; @@ -24,14 +29,9 @@ RefinementId RefinementArena::equivalence(RefinementId lhs, RefinementId rhs) return NotNull{allocator.allocate(Equivalence{lhs, rhs})}; } -RefinementId RefinementArena::proposition(DefId def, TypeId discriminantTy) +RefinementId RefinementArena::proposition(BreadcrumbId breadcrumb, TypeId discriminantTy) { - return NotNull{allocator.allocate(Proposition{def, discriminantTy})}; -} - -RefinementId RefinementArena::variadic(const std::vector& refis) -{ - return NotNull{allocator.allocate(Variadic{refis})}; + return NotNull{allocator.allocate(Proposition{breadcrumb, discriminantTy})}; } } // namespace Luau diff --git a/Analysis/src/Type.cpp b/Analysis/src/Type.cpp index 2f69f6980..f15f8c4cf 100644 --- a/Analysis/src/Type.cpp +++ b/Analysis/src/Type.cpp @@ -745,6 +745,7 @@ BuiltinTypes::BuiltinTypes() , functionType(arena->addType(Type{PrimitiveType{PrimitiveType::Function}, /*persistent*/ true})) , classType(arena->addType(Type{ClassType{"class", {}, std::nullopt, std::nullopt, {}, {}, {}}, /*persistent*/ true})) , tableType(arena->addType(Type{PrimitiveType{PrimitiveType::Table}, /*persistent*/ true})) + , emptyTableType(arena->addType(Type{TableType{TableState::Sealed, TypeLevel{}, nullptr}, /*persistent*/ true})) , trueType(arena->addType(Type{SingletonType{BooleanSingleton{true}}, /*persistent*/ true})) , falseType(arena->addType(Type{SingletonType{BooleanSingleton{false}}, /*persistent*/ true})) , anyType(arena->addType(Type{AnyType{}, /*persistent*/ true})) diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index f23fad780..aacfd7295 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -19,7 +19,6 @@ #include -LUAU_FASTFLAG(DebugLuauLogSolverToJson) LUAU_FASTFLAG(DebugLuauMagicTypes) LUAU_FASTFLAG(DebugLuauDontReduceTypes) @@ -105,8 +104,6 @@ struct TypeChecker2 , sourceModule(sourceModule) , module(module) { - if (FFlag::DebugLuauLogSolverToJson) - LUAU_ASSERT(logger); } std::optional pushStack(AstNode* node) @@ -918,13 +915,9 @@ struct TypeChecker2 reportError(ExtraInformation{"Other overloads are also not viable: " + s}, call->func->location); } - void visit(AstExprCall* call) + // Note: this is intentionally separated from `visit(AstExprCall*)` for stack allocation purposes. + void visitCall(AstExprCall* call) { - visit(call->func, RValue); - - for (AstExpr* arg : call->args) - visit(arg, RValue); - TypeArena* arena = &testArena; Instantiation instantiation{TxnLog::empty(), arena, TypeLevel{}, stack.back()}; @@ -1099,6 +1092,16 @@ struct TypeChecker2 reportOverloadResolutionErrors(call, overloads, expectedArgTypes, overloadsThatMatchArgCount, overloadsErrors); } + void visit(AstExprCall* call) + { + visit(call->func, RValue); + + for (AstExpr* arg : call->args) + visit(arg, RValue); + + visitCall(call); + } + void visitExprName(AstExpr* expr, Location location, const std::string& propName, ValueContext context) { visit(expr, RValue); @@ -1169,9 +1172,9 @@ struct TypeChecker2 TypeId inferredArgTy = *argIt; TypeId annotatedArgTy = lookupAnnotation(arg->annotation); - if (!isSubtype(annotatedArgTy, inferredArgTy, stack.back())) + if (!isSubtype(inferredArgTy, annotatedArgTy, stack.back())) { - reportError(TypeMismatch{annotatedArgTy, inferredArgTy}, arg->location); + reportError(TypeMismatch{inferredArgTy, annotatedArgTy}, arg->location); } } @@ -1726,7 +1729,7 @@ struct TypeChecker2 } } - for (size_t i = packsProvided; i < packsProvided; ++i) + for (size_t i = packsProvided; i < packsRequired; ++i) { if (alias->typePackParams[i].defaultValue) { @@ -1948,7 +1951,7 @@ struct TypeChecker2 { module->errors.emplace_back(location, sourceModule->name, std::move(data)); - if (FFlag::DebugLuauLogSolverToJson) + if (logger) logger->captureTypeCheckError(module->errors.back()); } @@ -2053,8 +2056,8 @@ struct TypeChecker2 if (findTablePropertyRespectingMeta(builtinTypes, module->errors, ty, prop, location)) return true; - else if (tt->indexer && isPrim(tt->indexer->indexResultType, PrimitiveType::String)) - return tt->indexer->indexResultType; + else if (tt->indexer && isPrim(tt->indexer->indexType, PrimitiveType::String)) + return true; else return false; diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index adca034c4..6aa8e6cac 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -937,9 +937,9 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatAssign& assign) TypeId right = nullptr; - Location loc = 0 == assign.values.size - ? assign.location - : i < assign.values.size ? assign.values.data[i]->location : assign.values.data[assign.values.size - 1]->location; + Location loc = 0 == assign.values.size ? assign.location + : i < assign.values.size ? assign.values.data[i]->location + : assign.values.data[assign.values.size - 1]->location; if (valueIter != valueEnd) { @@ -3170,7 +3170,8 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex property.location = expr.indexLocation; return theType; } - else if (FFlag::LuauDontExtendUnsealedRValueTables && ((ctx == ValueContext::LValue && lhsTable->state == TableState::Unsealed) || lhsTable->state == TableState::Free)) + else if (FFlag::LuauDontExtendUnsealedRValueTables && + ((ctx == ValueContext::LValue && lhsTable->state == TableState::Unsealed) || lhsTable->state == TableState::Free)) { TypeId theType = freshType(scope); Property& property = lhsTable->props[name]; @@ -3299,7 +3300,8 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex property.location = expr.index->location; return resultType; } - else if (FFlag::LuauDontExtendUnsealedRValueTables && ((ctx == ValueContext::LValue && exprTable->state == TableState::Unsealed) || exprTable->state == TableState::Free)) + else if (FFlag::LuauDontExtendUnsealedRValueTables && + ((ctx == ValueContext::LValue && exprTable->state == TableState::Unsealed) || exprTable->state == TableState::Free)) { TypeId resultType = freshType(scope); Property& property = exprTable->props[value->value.data]; @@ -3321,7 +3323,8 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex exprTable->indexer = TableIndexer{anyIfNonstrict(indexType), anyIfNonstrict(resultType)}; return resultType; } - else if (FFlag::LuauDontExtendUnsealedRValueTables && ((ctx == ValueContext::LValue && exprTable->state == TableState::Unsealed) || exprTable->state == TableState::Free)) + else if (FFlag::LuauDontExtendUnsealedRValueTables && + ((ctx == ValueContext::LValue && exprTable->state == TableState::Unsealed) || exprTable->state == TableState::Free)) { TypeId indexerType = freshType(exprTable->level); unify(indexType, indexerType, scope, expr.location); diff --git a/Analysis/src/TypeReduction.cpp b/Analysis/src/TypeReduction.cpp index 2393829df..abafa9fbc 100644 --- a/Analysis/src/TypeReduction.cpp +++ b/Analysis/src/TypeReduction.cpp @@ -434,6 +434,14 @@ std::optional TypeReducer::intersectionType(TypeId left, TypeId right) return std::nullopt; // error & T ~ error & T else if (get(right)) return std::nullopt; // T & error ~ T & error + else if (get(left)) + return std::nullopt; // *blocked* & T ~ *blocked* & T + else if (get(right)) + return std::nullopt; // T & *blocked* ~ T & *blocked* + else if (get(left)) + return std::nullopt; // *pending* & T ~ *pending* & T + else if (get(right)) + return std::nullopt; // T & *pending* ~ T & *pending* else if (auto ut = get(left)) return reduce(distribute(begin(ut), end(ut), &TypeReducer::intersectionType, right)); // (A | B) & T ~ (A & T) | (B & T) else if (get(right)) diff --git a/Analysis/src/TypeUtils.cpp b/Analysis/src/TypeUtils.cpp index f8f51bcf1..e5029e587 100644 --- a/Analysis/src/TypeUtils.cpp +++ b/Analysis/src/TypeUtils.cpp @@ -117,7 +117,8 @@ std::pair> getParameterExtents(const TxnLog* log, return {minCount, minCount + optionalCount}; } -TypePack extendTypePack(TypeArena& arena, NotNull builtinTypes, TypePackId pack, size_t length) +TypePack extendTypePack( + TypeArena& arena, NotNull builtinTypes, TypePackId pack, size_t length, std::vector> overrides) { TypePack result; @@ -179,11 +180,22 @@ TypePack extendTypePack(TypeArena& arena, NotNull builtinTypes, Ty TypePack newPack; newPack.tail = arena.freshTypePack(ftp->scope); - + size_t overridesIndex = 0; while (result.head.size() < length) { - newPack.head.push_back(arena.freshType(ftp->scope)); + TypeId t; + if (overridesIndex < overrides.size() && overrides[overridesIndex]) + { + t = *overrides[overridesIndex]; + } + else + { + t = arena.freshType(ftp->scope); + } + + newPack.head.push_back(t); result.head.push_back(newPack.head.back()); + overridesIndex++; } asMutable(pack)->ty.emplace(std::move(newPack)); diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 6364a5aa4..aba642714 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -312,7 +312,7 @@ TypePackId Widen::operator()(TypePackId tp) return substitute(tp).value_or(tp); } -static std::optional hasUnificationTooComplex(const ErrorVec& errors) +std::optional hasUnificationTooComplex(const ErrorVec& errors) { auto isUnificationTooComplex = [](const TypeError& te) { return nullptr != get(te); @@ -375,7 +375,6 @@ Unifier::Unifier(NotNull normalizer, Mode mode, NotNull scope , variance(variance) , sharedState(*normalizer->sharedState) { - normalize = true; LUAU_ASSERT(sharedState.iceHandler); } @@ -561,6 +560,11 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool else if (log.getMutable(superTy) && log.getMutable(subTy)) tryUnifyFunctions(subTy, superTy, isFunctionCall); + else if (auto table = log.get(superTy); table && table->type == PrimitiveType::Table) + tryUnify(subTy, builtinTypes->emptyTableType, isFunctionCall, isIntersection); + else if (auto table = log.get(subTy); table && table->type == PrimitiveType::Table) + tryUnify(builtinTypes->emptyTableType, superTy, isFunctionCall, isIntersection); + else if (log.getMutable(superTy) && log.getMutable(subTy)) { tryUnifyTables(subTy, superTy, isIntersection); @@ -591,7 +595,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool else if (log.get(superTy) || log.get(subTy)) tryUnifyNegations(subTy, superTy); - else if (FFlag::LuauUninhabitedSubAnything2 && !normalizer->isInhabited(subTy)) + else if (FFlag::LuauUninhabitedSubAnything2 && checkInhabited && !normalizer->isInhabited(subTy)) { } @@ -1769,6 +1773,12 @@ struct Resetter void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) { + if (isPrim(log.follow(subTy), PrimitiveType::Table)) + subTy = builtinTypes->emptyTableType; + + if (isPrim(log.follow(superTy), PrimitiveType::Table)) + superTy = builtinTypes->emptyTableType; + TypeId activeSubTy = subTy; TableType* superTable = log.getMutable(superTy); TableType* subTable = log.getMutable(subTy); @@ -2092,7 +2102,7 @@ void Unifier::tryUnifyScalarShape(TypeId subTy, TypeId superTy, bool reversed) TypeId osubTy = subTy; TypeId osuperTy = superTy; - if (FFlag::LuauUninhabitedSubAnything2 && !normalizer->isInhabited(subTy)) + if (FFlag::LuauUninhabitedSubAnything2 && checkInhabited && !normalizer->isInhabited(subTy)) return; if (reversed) @@ -2682,6 +2692,7 @@ Unifier Unifier::makeChildUnifier() { Unifier u = Unifier{normalizer, mode, scope, location, variance, &log}; u.normalize = normalize; + u.checkInhabited = checkInhabited; u.useScopes = useScopes; return u; } diff --git a/Ast/src/Lexer.cpp b/Ast/src/Lexer.cpp index 118b06798..dac3b95b6 100644 --- a/Ast/src/Lexer.cpp +++ b/Ast/src/Lexer.cpp @@ -6,6 +6,8 @@ #include +LUAU_FASTFLAGVARIABLE(LuauFixInterpStringMid, false) + namespace Luau { @@ -640,7 +642,8 @@ Lexeme Lexer::readInterpolatedStringSection(Position start, Lexeme::Type formatT } consume(); - Lexeme lexemeOutput(Location(start, position()), Lexeme::InterpStringBegin, &buffer[startOffset], offset - startOffset - 1); + Lexeme lexemeOutput(Location(start, position()), FFlag::LuauFixInterpStringMid ? formatType : Lexeme::InterpStringBegin, + &buffer[startOffset], offset - startOffset - 1); return lexemeOutput; } diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 4d61914f7..4c347712f 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -1248,7 +1248,11 @@ std::pair Parser::parseReturnTypeAnnotation() { AstType* returnType = parseTypeAnnotation(result, innerBegin); - return {Location{location, returnType->location}, AstTypeList{copy(&returnType, 1), varargAnnotation}}; + // If parseTypeAnnotation parses nothing, then returnType->location.end only points at the last non-type-pack + // type to successfully parse. We need the span of the whole annotation. + Position endPos = result.size() == 1 ? location.end : returnType->location.end; + + return {Location{location.begin, endPos}, AstTypeList{copy(&returnType, 1), varargAnnotation}}; } return {location, AstTypeList{copy(result), varargAnnotation}}; @@ -2623,8 +2627,6 @@ AstExpr* Parser::parseInterpString() endLocation = currentLexeme.location; - Location startOfBrace = Location(endLocation.end, 1); - scratchData.assign(currentLexeme.data, currentLexeme.length); if (!Lexer::fixupQuotedString(scratchData)) diff --git a/CLI/Reduce.cpp b/CLI/Reduce.cpp index d24c9874c..b7c780128 100644 --- a/CLI/Reduce.cpp +++ b/CLI/Reduce.cpp @@ -487,7 +487,7 @@ int main(int argc, char** argv) if (args.size() < 4) help(args); - for (int i = 1; i < args.size(); ++i) + for (size_t i = 1; i < args.size(); ++i) { if (args[i] == "--help") help(args); diff --git a/CMakeLists.txt b/CMakeLists.txt index 4255c7c25..6e15e5f88 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -14,9 +14,11 @@ option(LUAU_STATIC_CRT "Link with the static CRT (/MT)" OFF) option(LUAU_EXTERN_C "Use extern C for all APIs" OFF) option(LUAU_NATIVE "Enable support for native code generation" OFF) +cmake_policy(SET CMP0054 NEW) +cmake_policy(SET CMP0091 NEW) + if(LUAU_STATIC_CRT) cmake_minimum_required(VERSION 3.15) - cmake_policy(SET CMP0091 NEW) set(CMAKE_MSVC_RUNTIME_LIBRARY "MultiThreaded$<$:Debug>") endif() @@ -88,9 +90,15 @@ set(LUAU_OPTIONS) if(MSVC) list(APPEND LUAU_OPTIONS /D_CRT_SECURE_NO_WARNINGS) # We need to use the portable CRT functions. - list(APPEND LUAU_OPTIONS /MP) # Distribute single project compilation across multiple cores + list(APPEND LUAU_OPTIONS "/we4018") # Signed/unsigned mismatch + list(APPEND LUAU_OPTIONS "/we4388") # Also signed/unsigned mismatch else() list(APPEND LUAU_OPTIONS -Wall) # All warnings + list(APPEND LUAU_OPTIONS -Wsign-compare) # This looks to be included in -Wall for GCC but not clang +endif() + +if (CMAKE_CXX_COMPILER_ID STREQUAL "MSVC") + list(APPEND LUAU_OPTIONS /MP) # Distribute single project compilation across multiple cores endif() if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU") @@ -115,7 +123,7 @@ endif() set(ISOCLINE_OPTIONS) -if (NOT MSVC) +if (NOT CMAKE_CXX_COMPILER_ID STREQUAL "MSVC") list(APPEND ISOCLINE_OPTIONS -Wno-unused-function) endif() @@ -137,7 +145,7 @@ if(LUAU_NATIVE) target_compile_definitions(Luau.VM PUBLIC LUA_CUSTOM_EXECUTION=1) endif() -if (MSVC AND MSVC_VERSION GREATER_EQUAL 1924) +if (CMAKE_CXX_COMPILER_ID STREQUAL "MSVC" AND MSVC_VERSION GREATER_EQUAL 1924) # disable partial redundancy elimination which regresses interpreter codegen substantially in VS2022: # https://developercommunity.visualstudio.com/t/performance-regression-on-a-complex-interpreter-lo/1631863 set_source_files_properties(VM/src/lvmexecute.cpp PROPERTIES COMPILE_FLAGS /d2ssa-pre-) diff --git a/CodeGen/include/Luau/AddressA64.h b/CodeGen/include/Luau/AddressA64.h index 53efd3c37..2c852046c 100644 --- a/CodeGen/include/Luau/AddressA64.h +++ b/CodeGen/include/Luau/AddressA64.h @@ -7,6 +7,8 @@ namespace Luau { namespace CodeGen { +namespace A64 +{ enum class AddressKindA64 : uint8_t { @@ -49,5 +51,6 @@ struct AddressA64 using mem = AddressA64; +} // namespace A64 } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/include/Luau/AssemblyBuilderA64.h b/CodeGen/include/Luau/AssemblyBuilderA64.h index 9e12168a0..94d8f8114 100644 --- a/CodeGen/include/Luau/AssemblyBuilderA64.h +++ b/CodeGen/include/Luau/AssemblyBuilderA64.h @@ -13,6 +13,8 @@ namespace Luau { namespace CodeGen { +namespace A64 +{ class AssemblyBuilderA64 { @@ -157,5 +159,6 @@ class AssemblyBuilderA64 uint32_t* codeEnd = nullptr; }; +} // namespace A64 } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/include/Luau/AssemblyBuilderX64.h b/CodeGen/include/Luau/AssemblyBuilderX64.h index 235f1a84e..597f2b2c3 100644 --- a/CodeGen/include/Luau/AssemblyBuilderX64.h +++ b/CodeGen/include/Luau/AssemblyBuilderX64.h @@ -14,6 +14,8 @@ namespace Luau { namespace CodeGen { +namespace X64 +{ enum class RoundingModeX64 { @@ -242,5 +244,6 @@ class AssemblyBuilderX64 uint8_t* codeEnd = nullptr; }; +} // namespace X64 } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/include/Luau/ConditionA64.h b/CodeGen/include/Luau/ConditionA64.h index e208d8cb0..0beadad52 100644 --- a/CodeGen/include/Luau/ConditionA64.h +++ b/CodeGen/include/Luau/ConditionA64.h @@ -5,6 +5,8 @@ namespace Luau { namespace CodeGen { +namespace A64 +{ enum class ConditionA64 { @@ -33,5 +35,6 @@ enum class ConditionA64 Count }; +} // namespace A64 } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/include/Luau/IrAnalysis.h b/CodeGen/include/Luau/IrAnalysis.h index 0941d475d..d3e1a9334 100644 --- a/CodeGen/include/Luau/IrAnalysis.h +++ b/CodeGen/include/Luau/IrAnalysis.h @@ -1,16 +1,26 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include + +#include + namespace Luau { namespace CodeGen { +struct IrBlock; struct IrFunction; void updateUseCounts(IrFunction& function); void updateLastUseLocations(IrFunction& function); +// Returns how many values are coming into the block (live in) and how many are coming out of the block (live out) +std::pair getLiveInOutValueCount(IrFunction& function, IrBlock& block); +uint32_t getLiveInValueCount(IrFunction& function, IrBlock& block); +uint32_t getLiveOutValueCount(IrFunction& function, IrBlock& block); + } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/include/Luau/IrBuilder.h b/CodeGen/include/Luau/IrBuilder.h index 295534214..916c6eeda 100644 --- a/CodeGen/include/Luau/IrBuilder.h +++ b/CodeGen/include/Luau/IrBuilder.h @@ -27,6 +27,12 @@ struct IrBuilder bool isInternalBlock(IrOp block); void beginBlock(IrOp block); + void loadAndCheckTag(IrOp loc, uint8_t tag, IrOp fallback); + + // Clones all instructions into the current block + // Source block that is cloned cannot use values coming in from a predecessor + void clone(const IrBlock& source, bool removeCurrentTerminator); + IrOp constBool(bool value); IrOp constInt(int value); IrOp constUint(unsigned value); diff --git a/CodeGen/include/Luau/IrData.h b/CodeGen/include/Luau/IrData.h index 18d510cc9..049d700af 100644 --- a/CodeGen/include/Luau/IrData.h +++ b/CodeGen/include/Luau/IrData.h @@ -108,6 +108,13 @@ enum class IrCmd : uint8_t MOD_NUM, POW_NUM, + // Get the minimum/maximum of two numbers + // If one of the values is NaN, 'B' is returned as the result + // A, B: double + // In final x64 lowering, B can also be Rn or Kn + MIN_NUM, + MAX_NUM, + // Negate a double number // A: double UNM_NUM, @@ -197,6 +204,29 @@ enum class IrCmd : uint8_t // This is used to recover after calling a variadic function ADJUST_STACK_TO_TOP, + // Execute fastcall builtin function in-place + // A: builtin + // B: Rn (result start) + // C: Rn (argument start) + // D: Rn or Kn or a boolean that's false (optional second argument) + // E: int (argument count or -1 to use all arguments up to stack top) + // F: int (result count or -1 to preserve all results and adjust stack top) + FASTCALL, + + // Call the fastcall builtin function + // A: builtin + // B: Rn (result start) + // C: Rn (argument start) + // D: Rn or Kn or a boolean that's false (optional second argument) + // E: int (argument count or -1 to use all arguments up to stack top) + // F: int (result count or -1 to preserve all results and adjust stack top) + INVOKE_FASTCALL, + + // Check that fastcall builtin function invocation was successful (negative result count jumps to fallback) + // A: int (result count) + // B: block (fallback) + CHECK_FASTCALL_RES, + // Fallback functions // Perform an arithmetic operation on TValues of any type @@ -351,39 +381,26 @@ enum class IrCmd : uint8_t // C: int (result count or -1 to return all values up to stack top) LOP_RETURN, - // Perform a fast call of a built-in function - // A: unsigned int (bytecode instruction index) - // B: Rn (argument start) - // C: int (argument count or -1 use all arguments up to stack top) - // D: block (fallback) - // Note: return values are placed starting from Rn specified in 'B' - LOP_FASTCALL, - - // Perform a fast call of a built-in function using 1 register argument - // A: unsigned int (bytecode instruction index) - // B: Rn (result start) - // C: Rn (arg1) - // D: block (fallback) - LOP_FASTCALL1, + // Adjust loop variables for one iteration of a generic for loop, jump back to the loop header if loop needs to continue + // A: Rn (loop variable start, updates Rn+2 Rn+3 Rn+4) + // B: int (loop variable count, is more than 2, additional registers are set to nil) + // C: block (repeat) + // D: block (exit) + LOP_FORGLOOP, - // Perform a fast call of a built-in function using 2 register arguments + // Handle LOP_FORGLOOP fallback when variable being iterated is not a table // A: unsigned int (bytecode instruction index) - // B: Rn (result start) - // C: Rn (arg1) - // D: Rn (arg2) - // E: block (fallback) - LOP_FASTCALL2, + // B: Rn (loop state start, updates Rn+2 Rn+3 Rn+4 Rn+5) + // C: int (extra variable count or -1 for ipairs-style iteration) + // D: block (repeat) + // E: block (exit) + LOP_FORGLOOP_FALLBACK, - // Perform a fast call of a built-in function using 1 register argument and 1 constant argument + // Fallback for generic for loop preparation when iterating over builtin pairs/ipairs + // It raises an error if 'B' register is not a function // A: unsigned int (bytecode instruction index) - // B: Rn (result start) - // C: Rn (arg1) - // D: Kn (arg2) - // E: block (fallback) - LOP_FASTCALL2K, - - LOP_FORGLOOP, - LOP_FORGLOOP_FALLBACK, + // B: Rn + // C: block (forgloop location) LOP_FORGPREP_XNEXT_FALLBACK, // Perform `and` or `or` operation (selecting lhs or rhs based on whether the lhs is truthy) and put the result into target register @@ -462,7 +479,7 @@ enum class IrCmd : uint8_t // Prepare loop variables for a generic for loop, jump to the loop backedge unconditionally // A: unsigned int (bytecode instruction index) - // B: Rn (loop state, updates Rn Rn+1 Rn+2) + // B: Rn (loop state start, updates Rn Rn+1 Rn+2) // C: block FALLBACK_FORGPREP, @@ -577,8 +594,8 @@ struct IrInst uint16_t useCount = 0; // Location of the result (optional) - RegisterX64 regX64 = noreg; - RegisterA64 regA64{KindA64::none, 0}; + X64::RegisterX64 regX64 = X64::noreg; + A64::RegisterA64 regA64 = A64::noreg; bool reusedReg = false; }; @@ -587,6 +604,7 @@ enum class IrBlockKind : uint8_t Bytecode, Fallback, Internal, + Linearized, Dead, }; diff --git a/CodeGen/include/Luau/IrUtils.h b/CodeGen/include/Luau/IrUtils.h index 0a23b3f77..153cf7ade 100644 --- a/CodeGen/include/Luau/IrUtils.h +++ b/CodeGen/include/Luau/IrUtils.h @@ -133,6 +133,8 @@ inline bool hasResult(IrCmd cmd) case IrCmd::DIV_NUM: case IrCmd::MOD_NUM: case IrCmd::POW_NUM: + case IrCmd::MIN_NUM: + case IrCmd::MAX_NUM: case IrCmd::UNM_NUM: case IrCmd::NOT_ANY: case IrCmd::TABLE_LEN: @@ -141,6 +143,7 @@ inline bool hasResult(IrCmd cmd) case IrCmd::NUM_TO_INDEX: case IrCmd::INT_TO_NUM: case IrCmd::SUBSTITUTE: + case IrCmd::INVOKE_FASTCALL: return true; default: break; @@ -151,6 +154,9 @@ inline bool hasResult(IrCmd cmd) inline bool hasSideEffects(IrCmd cmd) { + if (cmd == IrCmd::INVOKE_FASTCALL) + return true; + // Instructions that don't produce a result most likely have other side-effects to make them useful // Right now, a full switch would mirror the 'hasResult' function, so we use this simple condition return !hasResult(cmd); @@ -164,6 +170,10 @@ inline bool isPseudo(IrCmd cmd) bool isGCO(uint8_t tag); +// Manually add or remove use of an operand +void addUse(IrFunction& function, IrOp op); +void removeUse(IrFunction& function, IrOp op); + // Remove a single instruction void kill(IrFunction& function, IrInst& inst); diff --git a/CodeGen/include/Luau/OperandX64.h b/CodeGen/include/Luau/OperandX64.h index 5ad38e907..b9aa8f54b 100644 --- a/CodeGen/include/Luau/OperandX64.h +++ b/CodeGen/include/Luau/OperandX64.h @@ -10,6 +10,8 @@ namespace Luau { namespace CodeGen { +namespace X64 +{ enum class CategoryX64 : uint8_t { @@ -138,5 +140,6 @@ constexpr OperandX64 operator+(RegisterX64 base, OperandX64 op) return op; } +} // namespace X64 } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/include/Luau/RegisterA64.h b/CodeGen/include/Luau/RegisterA64.h index 2d56276f7..519e83fcf 100644 --- a/CodeGen/include/Luau/RegisterA64.h +++ b/CodeGen/include/Luau/RegisterA64.h @@ -9,6 +9,8 @@ namespace Luau { namespace CodeGen { +namespace A64 +{ enum class KindA64 : uint8_t { @@ -33,6 +35,8 @@ struct RegisterA64 } }; +constexpr RegisterA64 noreg{KindA64::none, 0}; + constexpr RegisterA64 w0{KindA64::w, 0}; constexpr RegisterA64 w1{KindA64::w, 1}; constexpr RegisterA64 w2{KindA64::w, 2}; @@ -101,5 +105,6 @@ constexpr RegisterA64 xzr{KindA64::x, 31}; constexpr RegisterA64 sp{KindA64::none, 31}; +} // namespace A64 } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/include/Luau/RegisterX64.h b/CodeGen/include/Luau/RegisterX64.h index adc2db0cc..9d76b1169 100644 --- a/CodeGen/include/Luau/RegisterX64.h +++ b/CodeGen/include/Luau/RegisterX64.h @@ -9,6 +9,8 @@ namespace Luau { namespace CodeGen { +namespace X64 +{ enum class SizeX64 : uint8_t { @@ -133,5 +135,6 @@ constexpr RegisterX64 qwordReg(RegisterX64 reg) return RegisterX64{SizeX64::qword, reg.index}; } +} // namespace X64 } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/include/Luau/UnwindBuilder.h b/CodeGen/include/Luau/UnwindBuilder.h index b7237318a..98e604982 100644 --- a/CodeGen/include/Luau/UnwindBuilder.h +++ b/CodeGen/include/Luau/UnwindBuilder.h @@ -21,10 +21,10 @@ class UnwindBuilder virtual void start() = 0; - virtual void spill(int espOffset, RegisterX64 reg) = 0; - virtual void save(RegisterX64 reg) = 0; + virtual void spill(int espOffset, X64::RegisterX64 reg) = 0; + virtual void save(X64::RegisterX64 reg) = 0; virtual void allocStack(int size) = 0; - virtual void setupFrameReg(RegisterX64 reg, int espOffset) = 0; + virtual void setupFrameReg(X64::RegisterX64 reg, int espOffset) = 0; virtual void finish() = 0; diff --git a/CodeGen/include/Luau/UnwindBuilderDwarf2.h b/CodeGen/include/Luau/UnwindBuilderDwarf2.h index dab6e9573..972f7423b 100644 --- a/CodeGen/include/Luau/UnwindBuilderDwarf2.h +++ b/CodeGen/include/Luau/UnwindBuilderDwarf2.h @@ -17,10 +17,10 @@ class UnwindBuilderDwarf2 : public UnwindBuilder void start() override; - void spill(int espOffset, RegisterX64 reg) override; - void save(RegisterX64 reg) override; + void spill(int espOffset, X64::RegisterX64 reg) override; + void save(X64::RegisterX64 reg) override; void allocStack(int size) override; - void setupFrameReg(RegisterX64 reg, int espOffset) override; + void setupFrameReg(X64::RegisterX64 reg, int espOffset) override; void finish() override; diff --git a/CodeGen/include/Luau/UnwindBuilderWin.h b/CodeGen/include/Luau/UnwindBuilderWin.h index 005137712..1cd750a1d 100644 --- a/CodeGen/include/Luau/UnwindBuilderWin.h +++ b/CodeGen/include/Luau/UnwindBuilderWin.h @@ -27,10 +27,10 @@ class UnwindBuilderWin : public UnwindBuilder void start() override; - void spill(int espOffset, RegisterX64 reg) override; - void save(RegisterX64 reg) override; + void spill(int espOffset, X64::RegisterX64 reg) override; + void save(X64::RegisterX64 reg) override; void allocStack(int size) override; - void setupFrameReg(RegisterX64 reg, int espOffset) override; + void setupFrameReg(X64::RegisterX64 reg, int espOffset) override; void finish() override; @@ -45,7 +45,7 @@ class UnwindBuilderWin : public UnwindBuilder std::vector unwindCodes; uint8_t prologSize = 0; - RegisterX64 frameReg = rax; // rax means that frame register is not used + X64::RegisterX64 frameReg = X64::rax; // rax means that frame register is not used uint8_t frameRegOffset = 0; uint32_t stackOffset = 0; diff --git a/CodeGen/src/AssemblyBuilderA64.cpp b/CodeGen/src/AssemblyBuilderA64.cpp index 286800d6d..308747d26 100644 --- a/CodeGen/src/AssemblyBuilderA64.cpp +++ b/CodeGen/src/AssemblyBuilderA64.cpp @@ -9,6 +9,8 @@ namespace Luau { namespace CodeGen { +namespace A64 +{ static const uint8_t codeForCondition[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14}; static_assert(sizeof(codeForCondition) / sizeof(codeForCondition[0]) == size_t(ConditionA64::Count), "all conditions have to be covered"); @@ -719,5 +721,6 @@ void AssemblyBuilderA64::log(AddressA64 addr) text.append("]"); } +} // namespace A64 } // namespace CodeGen -} // namespace Luau \ No newline at end of file +} // namespace Luau diff --git a/CodeGen/src/AssemblyBuilderX64.cpp b/CodeGen/src/AssemblyBuilderX64.cpp index 71bfaec11..bf7889b89 100644 --- a/CodeGen/src/AssemblyBuilderX64.cpp +++ b/CodeGen/src/AssemblyBuilderX64.cpp @@ -11,6 +11,9 @@ namespace Luau { namespace CodeGen { +namespace X64 +{ + // TODO: more assertions on operand sizes static const uint8_t codeForCondition[] = { @@ -1475,5 +1478,6 @@ const char* AssemblyBuilderX64::getRegisterName(RegisterX64 reg) const return names[size_t(reg.size)][reg.index]; } +} // namespace X64 } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/CodeGen.cpp b/CodeGen/src/CodeGen.cpp index 5076cba2d..51bf17461 100644 --- a/CodeGen/src/CodeGen.cpp +++ b/CodeGen/src/CodeGen.cpp @@ -41,7 +41,7 @@ namespace CodeGen constexpr uint32_t kFunctionAlignment = 32; -static void assembleHelpers(AssemblyBuilderX64& build, ModuleHelpers& helpers) +static void assembleHelpers(X64::AssemblyBuilderX64& build, ModuleHelpers& helpers) { if (build.logText) build.logAppend("; exitContinueVm\n"); @@ -59,7 +59,7 @@ static void assembleHelpers(AssemblyBuilderX64& build, ModuleHelpers& helpers) emitContinueCallInVm(build); } -static NativeProto* assembleFunction(AssemblyBuilderX64& build, NativeState& data, ModuleHelpers& helpers, Proto* proto, AssemblyOptions options) +static NativeProto* assembleFunction(X64::AssemblyBuilderX64& build, NativeState& data, ModuleHelpers& helpers, Proto* proto, AssemblyOptions options) { NativeProto* result = new NativeProto(); @@ -78,7 +78,7 @@ static NativeProto* assembleFunction(AssemblyBuilderX64& build, NativeState& dat build.logAppend("\n"); } - build.align(kFunctionAlignment, AlignmentDataX64::Ud2); + build.align(kFunctionAlignment, X64::AlignmentDataX64::Ud2); Label start = build.setLabel(); @@ -92,7 +92,7 @@ static NativeProto* assembleFunction(AssemblyBuilderX64& build, NativeState& dat optimizeMemoryOperandsX64(builder.function); - IrLoweringX64 lowering(build, helpers, data, proto, builder.function); + X64::IrLoweringX64 lowering(build, helpers, data, proto, builder.function); lowering.lower(options); @@ -213,7 +213,7 @@ void create(lua_State* L) initFallbackTable(data); initHelperFunctions(data); - if (!x64::initEntryFunction(data)) + if (!X64::initEntryFunction(data)) { destroyNativeState(L); return; @@ -251,7 +251,7 @@ void compile(lua_State* L, int idx) if (!getNativeState(L)) return; - AssemblyBuilderX64 build(/* logText= */ false); + X64::AssemblyBuilderX64 build(/* logText= */ false); NativeState* data = getNativeState(L); std::vector protos; @@ -302,7 +302,7 @@ std::string getAssembly(lua_State* L, int idx, AssemblyOptions options) LUAU_ASSERT(lua_isLfunction(L, idx)); const TValue* func = luaA_toobject(L, idx); - AssemblyBuilderX64 build(/* logText= */ options.includeAssembly); + X64::AssemblyBuilderX64 build(/* logText= */ options.includeAssembly); NativeState data; initFallbackTable(data); diff --git a/CodeGen/src/CodeGenX64.cpp b/CodeGen/src/CodeGenX64.cpp index b23d2b38c..ac6c9416c 100644 --- a/CodeGen/src/CodeGenX64.cpp +++ b/CodeGen/src/CodeGenX64.cpp @@ -38,7 +38,7 @@ namespace Luau { namespace CodeGen { -namespace x64 +namespace X64 { bool initEntryFunction(NativeState& data) @@ -143,6 +143,6 @@ bool initEntryFunction(NativeState& data) return true; } -} // namespace x64 +} // namespace X64 } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/CodeGenX64.h b/CodeGen/src/CodeGenX64.h index 6791f7f3b..b82266af7 100644 --- a/CodeGen/src/CodeGenX64.h +++ b/CodeGen/src/CodeGenX64.h @@ -8,11 +8,11 @@ namespace CodeGen struct NativeState; -namespace x64 +namespace X64 { bool initEntryFunction(NativeState& data); -} // namespace x64 +} // namespace X64 } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/EmitBuiltinsX64.cpp b/CodeGen/src/EmitBuiltinsX64.cpp index 0a3b3609d..05b63551b 100644 --- a/CodeGen/src/EmitBuiltinsX64.cpp +++ b/CodeGen/src/EmitBuiltinsX64.cpp @@ -5,7 +5,7 @@ #include "Luau/Bytecode.h" #include "EmitCommonX64.h" -#include "IrTranslateBuiltins.h" // Used temporarily for shared definition of BuiltinImplResult +#include "IrRegAllocX64.h" #include "NativeState.h" #include "lstate.h" @@ -16,343 +16,135 @@ namespace Luau { namespace CodeGen { - -BuiltinImplResult emitBuiltinMathFloor(AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback) +namespace X64 { - if (nparams < 1 || nresults > 1) - return {BuiltinImplType::None, -1}; - - if (build.logText) - build.logAppend("; inlined LBF_MATH_FLOOR\n"); - - jumpIfTagIsNot(build, arg, LUA_TNUMBER, fallback); - - build.vroundsd(xmm0, xmm0, luauRegValue(arg), RoundingModeX64::RoundToNegativeInfinity); - build.vmovsd(luauRegValue(ra), xmm0); - if (ra != arg) - build.mov(luauRegTag(ra), LUA_TNUMBER); - - return {BuiltinImplType::UsesFallback, 1}; +void emitBuiltinMathFloor(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) +{ + ScopedRegX64 tmp{regs, SizeX64::xmmword}; + build.vroundsd(tmp.reg, tmp.reg, luauRegValue(arg), RoundingModeX64::RoundToNegativeInfinity); + build.vmovsd(luauRegValue(ra), tmp.reg); } -BuiltinImplResult emitBuiltinMathCeil(AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback) +void emitBuiltinMathCeil(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) { - if (nparams < 1 || nresults > 1) - return {BuiltinImplType::None, -1}; - - if (build.logText) - build.logAppend("; inlined LBF_MATH_CEIL\n"); - - jumpIfTagIsNot(build, arg, LUA_TNUMBER, fallback); - - build.vroundsd(xmm0, xmm0, luauRegValue(arg), RoundingModeX64::RoundToPositiveInfinity); - build.vmovsd(luauRegValue(ra), xmm0); - - if (ra != arg) - build.mov(luauRegTag(ra), LUA_TNUMBER); - - return {BuiltinImplType::UsesFallback, 1}; + ScopedRegX64 tmp{regs, SizeX64::xmmword}; + build.vroundsd(tmp.reg, tmp.reg, luauRegValue(arg), RoundingModeX64::RoundToPositiveInfinity); + build.vmovsd(luauRegValue(ra), tmp.reg); } -BuiltinImplResult emitBuiltinMathSqrt(AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback) +void emitBuiltinMathSqrt(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) { - if (nparams < 1 || nresults > 1) - return {BuiltinImplType::None, -1}; - - if (build.logText) - build.logAppend("; inlined LBF_MATH_SQRT\n"); - - jumpIfTagIsNot(build, arg, LUA_TNUMBER, fallback); - - build.vsqrtsd(xmm0, xmm0, luauRegValue(arg)); - build.vmovsd(luauRegValue(ra), xmm0); - - if (ra != arg) - build.mov(luauRegTag(ra), LUA_TNUMBER); - - return {BuiltinImplType::UsesFallback, 1}; + ScopedRegX64 tmp{regs, SizeX64::xmmword}; + build.vsqrtsd(tmp.reg, tmp.reg, luauRegValue(arg)); + build.vmovsd(luauRegValue(ra), tmp.reg); } -BuiltinImplResult emitBuiltinMathAbs(AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback) +void emitBuiltinMathAbs(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) { - if (nparams < 1 || nresults > 1) - return {BuiltinImplType::None, -1}; - - if (build.logText) - build.logAppend("; inlined LBF_MATH_ABS\n"); - - jumpIfTagIsNot(build, arg, LUA_TNUMBER, fallback); - - build.vmovsd(xmm0, luauRegValue(arg)); - build.vandpd(xmm0, xmm0, build.i64(~(1LL << 63))); - build.vmovsd(luauRegValue(ra), xmm0); - - if (ra != arg) - build.mov(luauRegTag(ra), LUA_TNUMBER); - - return {BuiltinImplType::UsesFallback, 1}; + ScopedRegX64 tmp{regs, SizeX64::xmmword}; + build.vmovsd(tmp.reg, luauRegValue(arg)); + build.vandpd(tmp.reg, tmp.reg, build.i64(~(1LL << 63))); + build.vmovsd(luauRegValue(ra), tmp.reg); } -static BuiltinImplResult emitBuiltinMathSingleArgFunc( - AssemblyBuilderX64& build, int nparams, int ra, int arg, int nresults, Label& fallback, const char* name, int32_t offset) +static void emitBuiltinMathSingleArgFunc(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int arg, int32_t offset) { - if (nparams < 1 || nresults > 1) - return {BuiltinImplType::None, -1}; - - if (build.logText) - build.logAppend("; inlined %s\n", name); - - jumpIfTagIsNot(build, arg, LUA_TNUMBER, fallback); - + regs.assertAllFree(); build.vmovsd(xmm0, luauRegValue(arg)); build.call(qword[rNativeContext + offset]); build.vmovsd(luauRegValue(ra), xmm0); - - if (ra != arg) - build.mov(luauRegTag(ra), LUA_TNUMBER); - - return {BuiltinImplType::UsesFallback, 1}; -} - -BuiltinImplResult emitBuiltinMathExp(AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback) -{ - return emitBuiltinMathSingleArgFunc(build, nparams, ra, arg, nresults, fallback, "LBF_MATH_EXP", offsetof(NativeContext, libm_exp)); -} - -BuiltinImplResult emitBuiltinMathDeg(AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback) -{ - if (nparams < 1 || nresults > 1) - return {BuiltinImplType::None, -1}; - - if (build.logText) - build.logAppend("; inlined LBF_MATH_DEG\n"); - - jumpIfTagIsNot(build, arg, LUA_TNUMBER, fallback); - - const double rpd = (3.14159265358979323846 / 180.0); - - build.vmovsd(xmm0, luauRegValue(arg)); - build.vdivsd(xmm0, xmm0, build.f64(rpd)); - build.vmovsd(luauRegValue(ra), xmm0); - - if (ra != arg) - build.mov(luauRegTag(ra), LUA_TNUMBER); - - return {BuiltinImplType::UsesFallback, 1}; } -BuiltinImplResult emitBuiltinMathRad(AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback) +void emitBuiltinMathExp(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) { - if (nparams < 1 || nresults > 1) - return {BuiltinImplType::None, -1}; - - if (build.logText) - build.logAppend("; inlined LBF_MATH_RAD\n"); - - jumpIfTagIsNot(build, arg, LUA_TNUMBER, fallback); - - const double rpd = (3.14159265358979323846 / 180.0); - - build.vmovsd(xmm0, luauRegValue(arg)); - build.vmulsd(xmm0, xmm0, build.f64(rpd)); - build.vmovsd(luauRegValue(ra), xmm0); - - if (ra != arg) - build.mov(luauRegTag(ra), LUA_TNUMBER); - - return {BuiltinImplType::UsesFallback, 1}; + emitBuiltinMathSingleArgFunc(regs, build, ra, arg, offsetof(NativeContext, libm_exp)); } -BuiltinImplResult emitBuiltinMathFmod(AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback) +void emitBuiltinMathFmod(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) { - if (nparams < 2 || nresults > 1) - return {BuiltinImplType::None, -1}; - - if (build.logText) - build.logAppend("; inlined LBF_MATH_FMOD\n"); - - jumpIfTagIsNot(build, arg, LUA_TNUMBER, fallback); - - // TODO: jumpIfTagIsNot can be generalized to take OperandX64 and then we can use it here; let's wait until we see this more though - build.cmp(dword[args + offsetof(TValue, tt)], LUA_TNUMBER); - build.jcc(ConditionX64::NotEqual, fallback); - + regs.assertAllFree(); build.vmovsd(xmm0, luauRegValue(arg)); build.vmovsd(xmm1, qword[args + offsetof(TValue, value)]); build.call(qword[rNativeContext + offsetof(NativeContext, libm_fmod)]); build.vmovsd(luauRegValue(ra), xmm0); - - if (ra != arg) - build.mov(luauRegTag(ra), LUA_TNUMBER); - - return {BuiltinImplType::UsesFallback, 1}; } -BuiltinImplResult emitBuiltinMathPow(AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback) +void emitBuiltinMathPow(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) { - if (nparams < 2 || nresults > 1) - return {BuiltinImplType::None, -1}; - - if (build.logText) - build.logAppend("; inlined LBF_MATH_POW\n"); - - jumpIfTagIsNot(build, arg, LUA_TNUMBER, fallback); - - // TODO: jumpIfTagIsNot can be generalized to take OperandX64 and then we can use it here; let's wait until we see this more though - build.cmp(dword[args + offsetof(TValue, tt)], LUA_TNUMBER); - build.jcc(ConditionX64::NotEqual, fallback); - + regs.assertAllFree(); build.vmovsd(xmm0, luauRegValue(arg)); build.vmovsd(xmm1, qword[args + offsetof(TValue, value)]); build.call(qword[rNativeContext + offsetof(NativeContext, libm_pow)]); build.vmovsd(luauRegValue(ra), xmm0); - - if (ra != arg) - build.mov(luauRegTag(ra), LUA_TNUMBER); - - return {BuiltinImplType::UsesFallback, 1}; } -BuiltinImplResult emitBuiltinMathMin(AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback) +void emitBuiltinMathAsin(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) { - if (nparams != 2 || nresults > 1) - return {BuiltinImplType::None, -1}; - - if (build.logText) - build.logAppend("; inlined LBF_MATH_MIN\n"); - - jumpIfTagIsNot(build, arg, LUA_TNUMBER, fallback); - - // TODO: jumpIfTagIsNot can be generalized to take OperandX64 and then we can use it here; let's wait until we see this more though - build.cmp(dword[args + offsetof(TValue, tt)], LUA_TNUMBER); - build.jcc(ConditionX64::NotEqual, fallback); - - build.vmovsd(xmm0, qword[args + offsetof(TValue, value)]); - build.vminsd(xmm0, xmm0, luauRegValue(arg)); - - build.vmovsd(luauRegValue(ra), xmm0); - - if (ra != arg) - build.mov(luauRegTag(ra), LUA_TNUMBER); - - return {BuiltinImplType::UsesFallback, 1}; + emitBuiltinMathSingleArgFunc(regs, build, ra, arg, offsetof(NativeContext, libm_asin)); } -BuiltinImplResult emitBuiltinMathMax(AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback) +void emitBuiltinMathSin(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) { - if (nparams != 2 || nresults > 1) - return {BuiltinImplType::None, -1}; - - if (build.logText) - build.logAppend("; inlined LBF_MATH_MAX\n"); - - jumpIfTagIsNot(build, arg, LUA_TNUMBER, fallback); - - // TODO: jumpIfTagIsNot can be generalized to take OperandX64 and then we can use it here; let's wait until we see this more though - build.cmp(dword[args + offsetof(TValue, tt)], LUA_TNUMBER); - build.jcc(ConditionX64::NotEqual, fallback); - - build.vmovsd(xmm0, qword[args + offsetof(TValue, value)]); - build.vmaxsd(xmm0, xmm0, luauRegValue(arg)); - - build.vmovsd(luauRegValue(ra), xmm0); - - if (ra != arg) - build.mov(luauRegTag(ra), LUA_TNUMBER); - - return {BuiltinImplType::UsesFallback, 1}; + emitBuiltinMathSingleArgFunc(regs, build, ra, arg, offsetof(NativeContext, libm_sin)); } -BuiltinImplResult emitBuiltinMathAsin(AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback) +void emitBuiltinMathSinh(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) { - return emitBuiltinMathSingleArgFunc(build, nparams, ra, arg, nresults, fallback, "LBF_MATH_ASIN", offsetof(NativeContext, libm_asin)); + emitBuiltinMathSingleArgFunc(regs, build, ra, arg, offsetof(NativeContext, libm_sinh)); } -BuiltinImplResult emitBuiltinMathSin(AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback) +void emitBuiltinMathAcos(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) { - return emitBuiltinMathSingleArgFunc(build, nparams, ra, arg, nresults, fallback, "LBF_MATH_SIN", offsetof(NativeContext, libm_sin)); + emitBuiltinMathSingleArgFunc(regs, build, ra, arg, offsetof(NativeContext, libm_acos)); } -BuiltinImplResult emitBuiltinMathSinh(AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback) +void emitBuiltinMathCos(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) { - return emitBuiltinMathSingleArgFunc(build, nparams, ra, arg, nresults, fallback, "LBF_MATH_SINH", offsetof(NativeContext, libm_sinh)); + emitBuiltinMathSingleArgFunc(regs, build, ra, arg, offsetof(NativeContext, libm_cos)); } -BuiltinImplResult emitBuiltinMathAcos(AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback) +void emitBuiltinMathCosh(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) { - return emitBuiltinMathSingleArgFunc(build, nparams, ra, arg, nresults, fallback, "LBF_MATH_ACOS", offsetof(NativeContext, libm_acos)); + emitBuiltinMathSingleArgFunc(regs, build, ra, arg, offsetof(NativeContext, libm_cosh)); } -BuiltinImplResult emitBuiltinMathCos(AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback) +void emitBuiltinMathAtan(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) { - return emitBuiltinMathSingleArgFunc(build, nparams, ra, arg, nresults, fallback, "LBF_MATH_COS", offsetof(NativeContext, libm_cos)); + emitBuiltinMathSingleArgFunc(regs, build, ra, arg, offsetof(NativeContext, libm_atan)); } -BuiltinImplResult emitBuiltinMathCosh(AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback) +void emitBuiltinMathTan(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) { - return emitBuiltinMathSingleArgFunc(build, nparams, ra, arg, nresults, fallback, "LBF_MATH_COSH", offsetof(NativeContext, libm_cosh)); + emitBuiltinMathSingleArgFunc(regs, build, ra, arg, offsetof(NativeContext, libm_tan)); } -BuiltinImplResult emitBuiltinMathAtan(AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback) +void emitBuiltinMathTanh(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) { - return emitBuiltinMathSingleArgFunc(build, nparams, ra, arg, nresults, fallback, "LBF_MATH_ATAN", offsetof(NativeContext, libm_atan)); + emitBuiltinMathSingleArgFunc(regs, build, ra, arg, offsetof(NativeContext, libm_tanh)); } -BuiltinImplResult emitBuiltinMathTan(AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback) +void emitBuiltinMathAtan2(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) { - return emitBuiltinMathSingleArgFunc(build, nparams, ra, arg, nresults, fallback, "LBF_MATH_TAN", offsetof(NativeContext, libm_tan)); -} - -BuiltinImplResult emitBuiltinMathTanh(AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback) -{ - return emitBuiltinMathSingleArgFunc(build, nparams, ra, arg, nresults, fallback, "LBF_MATH_TANH", offsetof(NativeContext, libm_tanh)); -} - -BuiltinImplResult emitBuiltinMathAtan2(AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback) -{ - if (nparams < 2 || nresults > 1) - return {BuiltinImplType::None, -1}; - - if (build.logText) - build.logAppend("; inlined LBF_MATH_ATAN2\n"); - - jumpIfTagIsNot(build, arg, LUA_TNUMBER, fallback); - - // TODO: jumpIfTagIsNot can be generalized to take OperandX64 and then we can use it here; let's wait until we see this more though - build.cmp(dword[args + offsetof(TValue, tt)], LUA_TNUMBER); - build.jcc(ConditionX64::NotEqual, fallback); - + regs.assertAllFree(); build.vmovsd(xmm0, luauRegValue(arg)); build.vmovsd(xmm1, qword[args + offsetof(TValue, value)]); build.call(qword[rNativeContext + offsetof(NativeContext, libm_atan2)]); build.vmovsd(luauRegValue(ra), xmm0); - - if (ra != arg) - build.mov(luauRegTag(ra), LUA_TNUMBER); - - return {BuiltinImplType::UsesFallback, 1}; } -BuiltinImplResult emitBuiltinMathLog10(AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback) +void emitBuiltinMathLog10(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) { - return emitBuiltinMathSingleArgFunc(build, nparams, ra, arg, nresults, fallback, "LBF_MATH_LOG10", offsetof(NativeContext, libm_log10)); + emitBuiltinMathSingleArgFunc(regs, build, ra, arg, offsetof(NativeContext, libm_log10)); } -BuiltinImplResult emitBuiltinMathLog(AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback) +void emitBuiltinMathLog(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) { - if (nparams < 1 || nresults > 1) - return {BuiltinImplType::None, -1}; - - if (build.logText) - build.logAppend("; inlined LBF_MATH_LOG\n"); - - jumpIfTagIsNot(build, arg, LUA_TNUMBER, fallback); - + regs.assertAllFree(); build.vmovsd(xmm0, luauRegValue(arg)); if (nparams == 1) @@ -367,19 +159,15 @@ BuiltinImplResult emitBuiltinMathLog(AssemblyBuilderX64& build, int nparams, int RegisterX64 tmp = rbx; OperandX64 arg2value = qword[args + offsetof(TValue, value)]; - // TODO: jumpIfTagIsNot can be generalized to take OperandX64 and then we can use it here; let's wait until we see this more though - build.cmp(dword[args + offsetof(TValue, tt)], LUA_TNUMBER); - build.jcc(ConditionX64::NotEqual, fallback); - build.vmovsd(xmm1, arg2value); - jumpOnNumberCmp(build, noreg, build.f64(2.0), xmm1, ConditionX64::NotEqual, log10check); + jumpOnNumberCmp(build, noreg, build.f64(2.0), xmm1, IrCondition::NotEqual, log10check); build.call(qword[rNativeContext + offsetof(NativeContext, libm_log2)]); build.jmp(exit); build.setLabel(log10check); - jumpOnNumberCmp(build, noreg, build.f64(10.0), xmm1, ConditionX64::NotEqual, logdivlog); + jumpOnNumberCmp(build, noreg, build.f64(10.0), xmm1, IrCondition::NotEqual, logdivlog); build.call(qword[rNativeContext + offsetof(NativeContext, libm_log10)]); build.jmp(exit); @@ -402,28 +190,11 @@ BuiltinImplResult emitBuiltinMathLog(AssemblyBuilderX64& build, int nparams, int } build.vmovsd(luauRegValue(ra), xmm0); - - if (ra != arg) - build.mov(luauRegTag(ra), LUA_TNUMBER); - - return {BuiltinImplType::UsesFallback, 1}; } - -BuiltinImplResult emitBuiltinMathLdexp(AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback) +void emitBuiltinMathLdexp(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) { - if (nparams < 2 || nresults > 1) - return {BuiltinImplType::None, -1}; - - if (build.logText) - build.logAppend("; inlined LBF_MATH_LDEXP\n"); - - jumpIfTagIsNot(build, arg, LUA_TNUMBER, fallback); - - // TODO: jumpIfTagIsNot can be generalized to take OperandX64 and then we can use it here; let's wait until we see this more though - build.cmp(dword[args + offsetof(TValue, tt)], LUA_TNUMBER); - build.jcc(ConditionX64::NotEqual, fallback); - + regs.assertAllFree(); build.vmovsd(xmm0, luauRegValue(arg)); if (build.abi == ABIX64::Windows) @@ -434,48 +205,27 @@ BuiltinImplResult emitBuiltinMathLdexp(AssemblyBuilderX64& build, int nparams, i build.call(qword[rNativeContext + offsetof(NativeContext, libm_ldexp)]); build.vmovsd(luauRegValue(ra), xmm0); - - if (ra != arg) - build.mov(luauRegTag(ra), LUA_TNUMBER); - - return {BuiltinImplType::UsesFallback, 1}; } -BuiltinImplResult emitBuiltinMathRound(AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback) +void emitBuiltinMathRound(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) { - if (nparams < 1 || nresults > 1) - return {BuiltinImplType::None, -1}; - - if (build.logText) - build.logAppend("; inlined LBF_MATH_ROUND\n"); + ScopedRegX64 tmp0{regs, SizeX64::xmmword}; + ScopedRegX64 tmp1{regs, SizeX64::xmmword}; + ScopedRegX64 tmp2{regs, SizeX64::xmmword}; - jumpIfTagIsNot(build, arg, LUA_TNUMBER, fallback); + build.vmovsd(tmp0.reg, luauRegValue(arg)); + build.vandpd(tmp1.reg, tmp0.reg, build.f64x2(-0.0, -0.0)); + build.vmovsd(tmp2.reg, build.i64(0x3fdfffffffffffff)); // 0.49999999999999994 + build.vorpd(tmp1.reg, tmp1.reg, tmp2.reg); + build.vaddsd(tmp0.reg, tmp0.reg, tmp1.reg); + build.vroundsd(tmp0.reg, tmp0.reg, tmp0.reg, RoundingModeX64::RoundToZero); - build.vmovsd(xmm0, luauRegValue(arg)); - build.vandpd(xmm1, xmm0, build.f64x2(-0.0, -0.0)); - build.vmovsd(xmm2, build.i64(0x3fdfffffffffffff)); // 0.49999999999999994 - build.vorpd(xmm1, xmm1, xmm2); - build.vaddsd(xmm0, xmm0, xmm1); - build.vroundsd(xmm0, xmm0, xmm0, RoundingModeX64::RoundToZero); - - build.vmovsd(luauRegValue(ra), xmm0); - - if (ra != arg) - build.mov(luauRegTag(ra), LUA_TNUMBER); - - return {BuiltinImplType::UsesFallback, 1}; + build.vmovsd(luauRegValue(ra), tmp0.reg); } -BuiltinImplResult emitBuiltinMathFrexp(AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback) +void emitBuiltinMathFrexp(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) { - if (nparams < 1 || nresults > 2) - return {BuiltinImplType::None, -1}; - - if (build.logText) - build.logAppend("; inlined LBF_MATH_FREXP\n"); - - jumpIfTagIsNot(build, arg, LUA_TNUMBER, fallback); - + regs.assertAllFree(); build.vmovsd(xmm0, luauRegValue(arg)); if (build.abi == ABIX64::Windows) @@ -487,26 +237,13 @@ BuiltinImplResult emitBuiltinMathFrexp(AssemblyBuilderX64& build, int nparams, i build.vmovsd(luauRegValue(ra), xmm0); - if (ra != arg) - build.mov(luauRegTag(ra), LUA_TNUMBER); - build.vcvtsi2sd(xmm0, xmm0, dword[sTemporarySlot + 0]); build.vmovsd(luauRegValue(ra + 1), xmm0); - build.mov(luauRegTag(ra + 1), LUA_TNUMBER); - - return {BuiltinImplType::UsesFallback, 2}; } -BuiltinImplResult emitBuiltinMathModf(AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback) +void emitBuiltinMathModf(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) { - if (nparams < 1 || nresults > 2) - return {BuiltinImplType::None, -1}; - - if (build.logText) - build.logAppend("; inlined LBF_MATH_MODF\n"); - - jumpIfTagIsNot(build, arg, LUA_TNUMBER, fallback); - + regs.assertAllFree(); build.vmovsd(xmm0, luauRegValue(arg)); if (build.abi == ABIX64::Windows) @@ -519,156 +256,109 @@ BuiltinImplResult emitBuiltinMathModf(AssemblyBuilderX64& build, int nparams, in build.vmovsd(xmm1, qword[sTemporarySlot + 0]); build.vmovsd(luauRegValue(ra), xmm1); - if (ra != arg) - build.mov(luauRegTag(ra), LUA_TNUMBER); - build.vmovsd(luauRegValue(ra + 1), xmm0); - build.mov(luauRegTag(ra + 1), LUA_TNUMBER); - - return {BuiltinImplType::UsesFallback, 2}; } -BuiltinImplResult emitBuiltinMathSign(AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback) +void emitBuiltinMathSign(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) { - if (nparams < 1 || nresults > 1) - return {BuiltinImplType::None, -1}; + ScopedRegX64 tmp0{regs, SizeX64::xmmword}; + ScopedRegX64 tmp1{regs, SizeX64::xmmword}; + ScopedRegX64 tmp2{regs, SizeX64::xmmword}; + ScopedRegX64 tmp3{regs, SizeX64::xmmword}; - if (build.logText) - build.logAppend("; inlined LBF_MATH_SIGN\n"); + build.vmovsd(tmp0.reg, luauRegValue(arg)); + build.vxorpd(tmp1.reg, tmp1.reg, tmp1.reg); - jumpIfTagIsNot(build, arg, LUA_TNUMBER, fallback); - - build.vmovsd(xmm0, luauRegValue(arg)); - build.vxorpd(xmm1, xmm1, xmm1); - - // Set xmm2 to -1 if arg < 0, else 0 - build.vcmpltsd(xmm2, xmm0, xmm1); - build.vmovsd(xmm3, build.f64(-1)); - build.vandpd(xmm2, xmm2, xmm3); + // Set tmp2 to -1 if arg < 0, else 0 + build.vcmpltsd(tmp2.reg, tmp0.reg, tmp1.reg); + build.vmovsd(tmp3.reg, build.f64(-1)); + build.vandpd(tmp2.reg, tmp2.reg, tmp3.reg); // Set mask bit to 1 if 0 < arg, else 0 - build.vcmpltsd(xmm0, xmm1, xmm0); - - // Result = (mask-bit == 1) ? 1.0 : xmm2 - // If arg < 0 then xmm2 is -1 and mask-bit is 0, result is -1 - // If arg == 0 then xmm2 is 0 and mask-bit is 0, result is 0 - // If arg > 0 then xmm2 is 0 and mask-bit is 1, result is 1 - build.vblendvpd(xmm0, xmm2, build.f64x2(1, 1), xmm0); + build.vcmpltsd(tmp0.reg, tmp1.reg, tmp0.reg); - build.vmovsd(luauRegValue(ra), xmm0); + // Result = (mask-bit == 1) ? 1.0 : tmp2 + // If arg < 0 then tmp2 is -1 and mask-bit is 0, result is -1 + // If arg == 0 then tmp2 is 0 and mask-bit is 0, result is 0 + // If arg > 0 then tmp2 is 0 and mask-bit is 1, result is 1 + build.vblendvpd(tmp0.reg, tmp2.reg, build.f64x2(1, 1), tmp0.reg); - if (ra != arg) - build.mov(luauRegTag(ra), LUA_TNUMBER); - - return {BuiltinImplType::UsesFallback, 1}; + build.vmovsd(luauRegValue(ra), tmp0.reg); } -BuiltinImplResult emitBuiltinMathClamp(AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback) +void emitBuiltin(IrRegAllocX64& regs, AssemblyBuilderX64& build, int bfid, int ra, int arg, IrOp args, int nparams, int nresults) { - if (nparams < 3 || nresults > 1) - return {BuiltinImplType::None, -1}; - - if (build.logText) - build.logAppend("; inlined LBF_MATH_CLAMP\n"); - - jumpIfTagIsNot(build, arg, LUA_TNUMBER, fallback); - - // TODO: jumpIfTagIsNot can be generalized to take OperandX64 and then we can use it here; let's wait until we see this more though - build.cmp(dword[args + offsetof(TValue, tt)], LUA_TNUMBER); - build.jcc(ConditionX64::NotEqual, fallback); - - // TODO: jumpIfTagIsNot can be generalized to take OperandX64 and then we can use it here; let's wait until we see this more though - build.cmp(dword[args + sizeof(TValue) + offsetof(TValue, tt)], LUA_TNUMBER); - build.jcc(ConditionX64::NotEqual, fallback); - - RegisterX64 min = xmm1; - RegisterX64 max = xmm2; - build.vmovsd(min, qword[args + offsetof(TValue, value)]); - build.vmovsd(max, qword[args + sizeof(TValue) + offsetof(TValue, value)]); + OperandX64 argsOp = 0; - jumpOnNumberCmp(build, noreg, min, max, ConditionX64::NotLessEqual, fallback); + if (args.kind == IrOpKind::VmReg) + argsOp = luauRegAddress(args.index); + else if (args.kind == IrOpKind::VmConst) + argsOp = luauConstantAddress(args.index); - build.vmaxsd(xmm0, min, luauRegValue(arg)); - build.vminsd(xmm0, max, xmm0); - - build.vmovsd(luauRegValue(ra), xmm0); - - if (ra != arg) - build.mov(luauRegTag(ra), LUA_TNUMBER); - - return {BuiltinImplType::UsesFallback, 1}; -} - - -BuiltinImplResult emitBuiltin(AssemblyBuilderX64& build, int bfid, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback) -{ switch (bfid) { case LBF_ASSERT: - // This builtin fast-path was already translated to IR - return {BuiltinImplType::None, -1}; + case LBF_MATH_DEG: + case LBF_MATH_RAD: + case LBF_MATH_MIN: + case LBF_MATH_MAX: + case LBF_MATH_CLAMP: + // These instructions are fully translated to IR + break; case LBF_MATH_FLOOR: - return emitBuiltinMathFloor(build, nparams, ra, arg, args, nresults, fallback); + return emitBuiltinMathFloor(regs, build, nparams, ra, arg, argsOp, nresults); case LBF_MATH_CEIL: - return emitBuiltinMathCeil(build, nparams, ra, arg, args, nresults, fallback); + return emitBuiltinMathCeil(regs, build, nparams, ra, arg, argsOp, nresults); case LBF_MATH_SQRT: - return emitBuiltinMathSqrt(build, nparams, ra, arg, args, nresults, fallback); + return emitBuiltinMathSqrt(regs, build, nparams, ra, arg, argsOp, nresults); case LBF_MATH_ABS: - return emitBuiltinMathAbs(build, nparams, ra, arg, args, nresults, fallback); + return emitBuiltinMathAbs(regs, build, nparams, ra, arg, argsOp, nresults); case LBF_MATH_EXP: - return emitBuiltinMathExp(build, nparams, ra, arg, args, nresults, fallback); - case LBF_MATH_DEG: - return emitBuiltinMathDeg(build, nparams, ra, arg, args, nresults, fallback); - case LBF_MATH_RAD: - return emitBuiltinMathRad(build, nparams, ra, arg, args, nresults, fallback); + return emitBuiltinMathExp(regs, build, nparams, ra, arg, argsOp, nresults); case LBF_MATH_FMOD: - return emitBuiltinMathFmod(build, nparams, ra, arg, args, nresults, fallback); + return emitBuiltinMathFmod(regs, build, nparams, ra, arg, argsOp, nresults); case LBF_MATH_POW: - return emitBuiltinMathPow(build, nparams, ra, arg, args, nresults, fallback); - case LBF_MATH_MIN: - return emitBuiltinMathMin(build, nparams, ra, arg, args, nresults, fallback); - case LBF_MATH_MAX: - return emitBuiltinMathMax(build, nparams, ra, arg, args, nresults, fallback); + return emitBuiltinMathPow(regs, build, nparams, ra, arg, argsOp, nresults); case LBF_MATH_ASIN: - return emitBuiltinMathAsin(build, nparams, ra, arg, args, nresults, fallback); + return emitBuiltinMathAsin(regs, build, nparams, ra, arg, argsOp, nresults); case LBF_MATH_SIN: - return emitBuiltinMathSin(build, nparams, ra, arg, args, nresults, fallback); + return emitBuiltinMathSin(regs, build, nparams, ra, arg, argsOp, nresults); case LBF_MATH_SINH: - return emitBuiltinMathSinh(build, nparams, ra, arg, args, nresults, fallback); + return emitBuiltinMathSinh(regs, build, nparams, ra, arg, argsOp, nresults); case LBF_MATH_ACOS: - return emitBuiltinMathAcos(build, nparams, ra, arg, args, nresults, fallback); + return emitBuiltinMathAcos(regs, build, nparams, ra, arg, argsOp, nresults); case LBF_MATH_COS: - return emitBuiltinMathCos(build, nparams, ra, arg, args, nresults, fallback); + return emitBuiltinMathCos(regs, build, nparams, ra, arg, argsOp, nresults); case LBF_MATH_COSH: - return emitBuiltinMathCosh(build, nparams, ra, arg, args, nresults, fallback); + return emitBuiltinMathCosh(regs, build, nparams, ra, arg, argsOp, nresults); case LBF_MATH_ATAN: - return emitBuiltinMathAtan(build, nparams, ra, arg, args, nresults, fallback); + return emitBuiltinMathAtan(regs, build, nparams, ra, arg, argsOp, nresults); case LBF_MATH_TAN: - return emitBuiltinMathTan(build, nparams, ra, arg, args, nresults, fallback); + return emitBuiltinMathTan(regs, build, nparams, ra, arg, argsOp, nresults); case LBF_MATH_TANH: - return emitBuiltinMathTanh(build, nparams, ra, arg, args, nresults, fallback); + return emitBuiltinMathTanh(regs, build, nparams, ra, arg, argsOp, nresults); case LBF_MATH_ATAN2: - return emitBuiltinMathAtan2(build, nparams, ra, arg, args, nresults, fallback); + return emitBuiltinMathAtan2(regs, build, nparams, ra, arg, argsOp, nresults); case LBF_MATH_LOG10: - return emitBuiltinMathLog10(build, nparams, ra, arg, args, nresults, fallback); + return emitBuiltinMathLog10(regs, build, nparams, ra, arg, argsOp, nresults); case LBF_MATH_LOG: - return emitBuiltinMathLog(build, nparams, ra, arg, args, nresults, fallback); + return emitBuiltinMathLog(regs, build, nparams, ra, arg, argsOp, nresults); case LBF_MATH_LDEXP: - return emitBuiltinMathLdexp(build, nparams, ra, arg, args, nresults, fallback); + return emitBuiltinMathLdexp(regs, build, nparams, ra, arg, argsOp, nresults); case LBF_MATH_ROUND: - return emitBuiltinMathRound(build, nparams, ra, arg, args, nresults, fallback); + return emitBuiltinMathRound(regs, build, nparams, ra, arg, argsOp, nresults); case LBF_MATH_FREXP: - return emitBuiltinMathFrexp(build, nparams, ra, arg, args, nresults, fallback); + return emitBuiltinMathFrexp(regs, build, nparams, ra, arg, argsOp, nresults); case LBF_MATH_MODF: - return emitBuiltinMathModf(build, nparams, ra, arg, args, nresults, fallback); + return emitBuiltinMathModf(regs, build, nparams, ra, arg, argsOp, nresults); case LBF_MATH_SIGN: - return emitBuiltinMathSign(build, nparams, ra, arg, args, nresults, fallback); - case LBF_MATH_CLAMP: - return emitBuiltinMathClamp(build, nparams, ra, arg, args, nresults, fallback); + return emitBuiltinMathSign(regs, build, nparams, ra, arg, argsOp, nresults); default: - return {BuiltinImplType::None, -1}; + LUAU_ASSERT(!"missing x64 lowering"); + break; } } +} // namespace X64 } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/EmitBuiltinsX64.h b/CodeGen/src/EmitBuiltinsX64.h index 1ee04e92e..5925a2b3d 100644 --- a/CodeGen/src/EmitBuiltinsX64.h +++ b/CodeGen/src/EmitBuiltinsX64.h @@ -6,12 +6,18 @@ namespace Luau namespace CodeGen { -class AssemblyBuilderX64; struct Label; +struct IrOp; + +namespace X64 +{ + +class AssemblyBuilderX64; struct OperandX64; -struct BuiltinImplResult; +struct IrRegAllocX64; -BuiltinImplResult emitBuiltin(AssemblyBuilderX64& build, int bfid, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback); +void emitBuiltin(IrRegAllocX64& regs, AssemblyBuilderX64& build, int bfid, int ra, int arg, IrOp args, int nparams, int nresults); +} // namespace X64 } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/EmitCommon.h b/CodeGen/src/EmitCommon.h new file mode 100644 index 000000000..3c41c271d --- /dev/null +++ b/CodeGen/src/EmitCommon.h @@ -0,0 +1,29 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Label.h" + +namespace Luau +{ +namespace CodeGen +{ + +constexpr unsigned kTValueSizeLog2 = 4; +constexpr unsigned kLuaNodeSizeLog2 = 5; +constexpr unsigned kLuaNodeTagMask = 0xf; +constexpr unsigned kNextBitOffset = 4; + +constexpr unsigned kOffsetOfLuaNodeTag = 12; // offsetof cannot be used on a bit field +constexpr unsigned kOffsetOfLuaNodeNext = 12; // offsetof cannot be used on a bit field +constexpr unsigned kOffsetOfInstructionC = 3; + +// Leaf functions that are placed in every module to perform common instruction sequences +struct ModuleHelpers +{ + Label exitContinueVm; + Label exitNoContinueVm; + Label continueCallInVm; +}; + +} // namespace CodeGen +} // namespace Luau diff --git a/CodeGen/src/EmitCommonX64.cpp b/CodeGen/src/EmitCommonX64.cpp index 7d36e17de..e9cfdc486 100644 --- a/CodeGen/src/EmitCommonX64.cpp +++ b/CodeGen/src/EmitCommonX64.cpp @@ -2,6 +2,7 @@ #include "EmitCommonX64.h" #include "Luau/AssemblyBuilderX64.h" +#include "Luau/IrData.h" #include "CustomExecUtils.h" #include "NativeState.h" @@ -13,8 +14,10 @@ namespace Luau { namespace CodeGen { +namespace X64 +{ -void jumpOnNumberCmp(AssemblyBuilderX64& build, RegisterX64 tmp, OperandX64 lhs, OperandX64 rhs, ConditionX64 cond, Label& label) +void jumpOnNumberCmp(AssemblyBuilderX64& build, RegisterX64 tmp, OperandX64 lhs, OperandX64 rhs, IrCondition cond, Label& label) { // Refresher on comi/ucomi EFLAGS: // CF only: less @@ -35,23 +38,23 @@ void jumpOnNumberCmp(AssemblyBuilderX64& build, RegisterX64 tmp, OperandX64 lhs, // And because of NaN, integer check interchangeability like 'not less or equal' <-> 'greater' does not hold switch (cond) { - case ConditionX64::NotLessEqual: + case IrCondition::NotLessEqual: // (b < a) is the same as !(a <= b). jnae checks CF=1 which means < or NaN build.jcc(ConditionX64::NotAboveEqual, label); break; - case ConditionX64::LessEqual: + case IrCondition::LessEqual: // (b >= a) is the same as (a <= b). jae checks CF=0 which means >= and not NaN build.jcc(ConditionX64::AboveEqual, label); break; - case ConditionX64::NotLess: + case IrCondition::NotLess: // (b <= a) is the same as !(a < b). jna checks CF=1 or ZF=1 which means <= or NaN build.jcc(ConditionX64::NotAbove, label); break; - case ConditionX64::Less: + case IrCondition::Less: // (b > a) is the same as (a < b). ja checks CF=0 and ZF=0 which means > and not NaN build.jcc(ConditionX64::Above, label); break; - case ConditionX64::NotEqual: + case IrCondition::NotEqual: // ZF=0 or PF=1 means != or NaN build.jcc(ConditionX64::NotZero, label); build.jcc(ConditionX64::Parity, label); @@ -61,25 +64,25 @@ void jumpOnNumberCmp(AssemblyBuilderX64& build, RegisterX64 tmp, OperandX64 lhs, } } -void jumpOnAnyCmpFallback(AssemblyBuilderX64& build, int ra, int rb, ConditionX64 cond, Label& label) +void jumpOnAnyCmpFallback(AssemblyBuilderX64& build, int ra, int rb, IrCondition cond, Label& label) { build.mov(rArg1, rState); build.lea(rArg2, luauRegAddress(ra)); build.lea(rArg3, luauRegAddress(rb)); - if (cond == ConditionX64::NotLessEqual || cond == ConditionX64::LessEqual) + if (cond == IrCondition::NotLessEqual || cond == IrCondition::LessEqual) build.call(qword[rNativeContext + offsetof(NativeContext, luaV_lessequal)]); - else if (cond == ConditionX64::NotLess || cond == ConditionX64::Less) + else if (cond == IrCondition::NotLess || cond == IrCondition::Less) build.call(qword[rNativeContext + offsetof(NativeContext, luaV_lessthan)]); - else if (cond == ConditionX64::NotEqual || cond == ConditionX64::Equal) + else if (cond == IrCondition::NotEqual || cond == IrCondition::Equal) build.call(qword[rNativeContext + offsetof(NativeContext, luaV_equalval)]); else LUAU_ASSERT(!"Unsupported condition"); emitUpdateBase(build); build.test(eax, eax); - build.jcc(cond == ConditionX64::NotLessEqual || cond == ConditionX64::NotLess || cond == ConditionX64::NotEqual ? ConditionX64::Zero - : ConditionX64::NotZero, + build.jcc(cond == IrCondition::NotLessEqual || cond == IrCondition::NotLess || cond == IrCondition::NotEqual ? ConditionX64::Zero + : ConditionX64::NotZero, label); } @@ -377,5 +380,6 @@ void emitContinueCallInVm(AssemblyBuilderX64& build) emitExit(build, /* continueInVm */ true); } +} // namespace X64 } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/EmitCommonX64.h b/CodeGen/src/EmitCommonX64.h index 8d6e36d6e..6b6762550 100644 --- a/CodeGen/src/EmitCommonX64.h +++ b/CodeGen/src/EmitCommonX64.h @@ -3,6 +3,8 @@ #include "Luau/AssemblyBuilderX64.h" +#include "EmitCommon.h" + #include "lobject.h" #include "ltm.h" @@ -23,8 +25,12 @@ namespace Luau namespace CodeGen { +enum class IrCondition : uint8_t; struct NativeState; +namespace X64 +{ + // Data that is very common to access is placed in non-volatile registers constexpr RegisterX64 rState = r15; // lua_State* L constexpr RegisterX64 rBase = r14; // StkId base @@ -65,23 +71,6 @@ constexpr OperandX64 sArg6 = noreg; #endif -constexpr unsigned kTValueSizeLog2 = 4; -constexpr unsigned kLuaNodeSizeLog2 = 5; -constexpr unsigned kLuaNodeTagMask = 0xf; -constexpr unsigned kNextBitOffset = 4; - -constexpr unsigned kOffsetOfLuaNodeTag = 12; // offsetof cannot be used on a bit field -constexpr unsigned kOffsetOfLuaNodeNext = 12; // offsetof cannot be used on a bit field -constexpr unsigned kOffsetOfInstructionC = 3; - -// Leaf functions that are placed in every module to perform common instruction sequences -struct ModuleHelpers -{ - Label exitContinueVm; - Label exitNoContinueVm; - Label continueCallInVm; -}; - inline OperandX64 luauReg(int ri) { return xmmword[rBase + ri * sizeof(TValue)]; @@ -243,8 +232,8 @@ inline void jumpIfNodeKeyNotInExpectedSlot(AssemblyBuilderX64& build, RegisterX6 jumpIfNodeValueTagIs(build, node, LUA_TNIL, label); } -void jumpOnNumberCmp(AssemblyBuilderX64& build, RegisterX64 tmp, OperandX64 lhs, OperandX64 rhs, ConditionX64 cond, Label& label); -void jumpOnAnyCmpFallback(AssemblyBuilderX64& build, int ra, int rb, ConditionX64 cond, Label& label); +void jumpOnNumberCmp(AssemblyBuilderX64& build, RegisterX64 tmp, OperandX64 lhs, OperandX64 rhs, IrCondition cond, Label& label); +void jumpOnAnyCmpFallback(AssemblyBuilderX64& build, int ra, int rb, IrCondition cond, Label& label); void getTableNodeAtCachedSlot(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 node, RegisterX64 table, int pcpos); void convertNumberToIndexOrJump(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 numd, RegisterX64 numi, Label& label); @@ -268,5 +257,6 @@ void emitFallback(AssemblyBuilderX64& build, NativeState& data, int op, int pcpo void emitContinueCallInVm(AssemblyBuilderX64& build); +} // namespace X64 } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/EmitInstructionX64.cpp b/CodeGen/src/EmitInstructionX64.cpp index 97ff9f59d..3b0aa258b 100644 --- a/CodeGen/src/EmitInstructionX64.cpp +++ b/CodeGen/src/EmitInstructionX64.cpp @@ -7,7 +7,6 @@ #include "EmitBuiltinsX64.h" #include "EmitCommonX64.h" #include "NativeState.h" -#include "IrTranslateBuiltins.h" // Used temporarily until emitInstFastCallN is removed #include "lobject.h" #include "ltm.h" @@ -16,6 +15,8 @@ namespace Luau { namespace CodeGen { +namespace X64 +{ void emitInstNameCall(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, const TValue* k, Label& next, Label& fallback) { @@ -481,137 +482,12 @@ void emitInstSetList(AssemblyBuilderX64& build, const Instruction* pc, Label& ne callBarrierTableFast(build, table, next); } -static void emitInstFastCallN( - AssemblyBuilderX64& build, const Instruction* pc, bool customParams, int customParamCount, OperandX64 customArgs, int pcpos, Label& fallback) -{ - int bfid = LUAU_INSN_A(*pc); - int skip = LUAU_INSN_C(*pc); - - Instruction call = pc[skip + 1]; - LUAU_ASSERT(LUAU_INSN_OP(call) == LOP_CALL); - int ra = LUAU_INSN_A(call); - - int nparams = customParams ? customParamCount : LUAU_INSN_B(call) - 1; - int nresults = LUAU_INSN_C(call) - 1; - int arg = customParams ? LUAU_INSN_B(*pc) : ra + 1; - OperandX64 args = customParams ? customArgs : luauRegAddress(ra + 2); - - BuiltinImplResult br = emitBuiltin(build, LuauBuiltinFunction(bfid), nparams, ra, arg, args, nresults, fallback); - - if (br.type == BuiltinImplType::UsesFallback) - { - if (nresults == LUA_MULTRET) - { - // L->top = ra + n; - build.lea(rax, addr[rBase + (ra + br.actualResultCount) * sizeof(TValue)]); - build.mov(qword[rState + offsetof(lua_State, top)], rax); - } - else if (nparams == LUA_MULTRET) - { - // L->top = L->ci->top; - build.mov(rax, qword[rState + offsetof(lua_State, ci)]); - build.mov(rax, qword[rax + offsetof(CallInfo, top)]); - build.mov(qword[rState + offsetof(lua_State, top)], rax); - } - - return; - } - - // TODO: we can skip saving pc for some well-behaved builtins which we didn't inline - emitSetSavedPc(build, pcpos + 1); // uses rax/rdx - - build.mov(rax, qword[rNativeContext + offsetof(NativeContext, luauF_table) + bfid * sizeof(luau_FastFunction)]); - - // 5th parameter (args) is left unset for LOP_FASTCALL1 - if (args.cat == CategoryX64::mem) - { - if (build.abi == ABIX64::Windows) - { - build.lea(rcx, args); - build.mov(sArg5, rcx); - } - else - { - build.lea(rArg5, args); - } - } - - if (nparams == LUA_MULTRET) - { - // L->top - (ra + 1) - RegisterX64 reg = (build.abi == ABIX64::Windows) ? rcx : rArg6; - build.mov(reg, qword[rState + offsetof(lua_State, top)]); - build.lea(rdx, addr[rBase + (ra + 1) * sizeof(TValue)]); - build.sub(reg, rdx); - build.shr(reg, kTValueSizeLog2); - - if (build.abi == ABIX64::Windows) - build.mov(sArg6, reg); - } - else - { - if (build.abi == ABIX64::Windows) - build.mov(sArg6, nparams); - else - build.mov(rArg6, nparams); - } - - build.mov(rArg1, rState); - build.lea(rArg2, luauRegAddress(ra)); - build.lea(rArg3, luauRegAddress(arg)); - build.mov(dwordReg(rArg4), nresults); - - build.call(rax); - - build.test(eax, eax); // test here will set SF=1 for a negative number and it always sets OF to 0 - build.jcc(ConditionX64::Less, fallback); // jl jumps if SF != OF - - if (nresults == LUA_MULTRET) - { - // L->top = ra + n; - build.shl(rax, kTValueSizeLog2); - build.lea(rax, addr[rBase + rax + ra * sizeof(TValue)]); - build.mov(qword[rState + offsetof(lua_State, top)], rax); - } - else if (nparams == LUA_MULTRET) - { - // L->top = L->ci->top; - build.mov(rax, qword[rState + offsetof(lua_State, ci)]); - build.mov(rax, qword[rax + offsetof(CallInfo, top)]); - build.mov(qword[rState + offsetof(lua_State, top)], rax); - } -} - -void emitInstFastCall1(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback) -{ - return emitInstFastCallN(build, pc, /* customParams */ true, /* customParamCount */ 1, /* customArgs */ 0, pcpos, fallback); -} - -void emitInstFastCall2(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback) -{ - return emitInstFastCallN(build, pc, /* customParams */ true, /* customParamCount */ 2, /* customArgs */ luauRegAddress(pc[1]), pcpos, fallback); -} - -void emitInstFastCall2K(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback) -{ - return emitInstFastCallN( - build, pc, /* customParams */ true, /* customParamCount */ 2, /* customArgs */ luauConstantAddress(pc[1]), pcpos, fallback); -} - -void emitInstFastCall(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback) +void emitinstForGLoop(AssemblyBuilderX64& build, int ra, int aux, Label& loopRepeat, Label& loopExit) { - return emitInstFastCallN(build, pc, /* customParams */ false, /* customParamCount */ 0, /* customArgs */ 0, pcpos, fallback); -} + // ipairs-style traversal is handled in IR + LUAU_ASSERT(aux >= 0); -void emitinstForGLoop(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& loopRepeat, Label& loopExit, Label& fallback) -{ - int ra = LUAU_INSN_A(*pc); - int aux = pc[1]; - - emitInterrupt(build, pcpos); - - // fast-path: builtin table iteration - jumpIfTagIsNot(build, ra, LUA_TNIL, fallback); + // This is a fast-path for builtin table iteration, tag check for 'ra' has to be performed before emitting this instruction // Registers are chosen in this way to simplify fallback code for the node part RegisterX64 table = rArg2; @@ -630,22 +506,19 @@ void emitinstForGLoop(AssemblyBuilderX64& build, const Instruction* pc, int pcpo for (int i = 2; i < aux; ++i) build.mov(luauRegTag(ra + 3 + i), LUA_TNIL); - // ipairs-style traversal is terminated early when array part ends of nil array element is encountered - bool isIpairsIter = aux < 0; - Label skipArray, skipArrayNil; // First we advance index through the array portion // while (unsigned(index) < unsigned(sizearray)) Label arrayLoop = build.setLabel(); build.cmp(dwordReg(index), dword[table + offsetof(Table, sizearray)]); - build.jcc(ConditionX64::NotBelow, isIpairsIter ? loopExit : skipArray); + build.jcc(ConditionX64::NotBelow, skipArray); // If element is nil, we increment the index; if it's not, we still need 'index + 1' inside build.inc(index); build.cmp(dword[elemPtr + offsetof(TValue, tt)], LUA_TNIL); - build.jcc(ConditionX64::Equal, isIpairsIter ? loopExit : skipArrayNil); + build.jcc(ConditionX64::Equal, skipArrayNil); // setpvalue(ra + 2, reinterpret_cast(uintptr_t(index + 1))); build.mov(luauRegValue(ra + 2), index); @@ -661,31 +534,25 @@ void emitinstForGLoop(AssemblyBuilderX64& build, const Instruction* pc, int pcpo build.jmp(loopRepeat); - if (!isIpairsIter) - { - build.setLabel(skipArrayNil); + build.setLabel(skipArrayNil); - // Index already incremented, advance to next array element - build.add(elemPtr, sizeof(TValue)); - build.jmp(arrayLoop); + // Index already incremented, advance to next array element + build.add(elemPtr, sizeof(TValue)); + build.jmp(arrayLoop); - build.setLabel(skipArray); + build.setLabel(skipArray); - // Call helper to assign next node value or to signal loop exit - build.mov(rArg1, rState); - // rArg2 and rArg3 are already set - build.lea(rArg4, luauRegAddress(ra)); - build.call(qword[rNativeContext + offsetof(NativeContext, forgLoopNodeIter)]); - build.test(al, al); - build.jcc(ConditionX64::NotZero, loopRepeat); - } + // Call helper to assign next node value or to signal loop exit + build.mov(rArg1, rState); + // rArg2 and rArg3 are already set + build.lea(rArg4, luauRegAddress(ra)); + build.call(qword[rNativeContext + offsetof(NativeContext, forgLoopNodeIter)]); + build.test(al, al); + build.jcc(ConditionX64::NotZero, loopRepeat); } -void emitinstForGLoopFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& loopRepeat) +void emitinstForGLoopFallback(AssemblyBuilderX64& build, int pcpos, int ra, int aux, Label& loopRepeat) { - int ra = LUAU_INSN_A(*pc); - int aux = pc[1]; - emitSetSavedPc(build, pcpos + 1); build.mov(rArg1, rState); @@ -697,10 +564,8 @@ void emitinstForGLoopFallback(AssemblyBuilderX64& build, const Instruction* pc, build.jcc(ConditionX64::NotZero, loopRepeat); } -void emitInstForGPrepXnextFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& target) +void emitInstForGPrepXnextFallback(AssemblyBuilderX64& build, int pcpos, int ra, Label& target) { - int ra = LUAU_INSN_A(*pc); - build.mov(rArg1, rState); build.lea(rArg2, luauRegAddress(ra)); build.mov(dwordReg(rArg3), pcpos + 1); @@ -836,5 +701,6 @@ void emitInstCoverage(AssemblyBuilderX64& build, int pcpos) build.mov(dword[rcx], eax); } +} // namespace X64 } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/EmitInstructionX64.h b/CodeGen/src/EmitInstructionX64.h index 5fbfb56d6..dcca52ab6 100644 --- a/CodeGen/src/EmitInstructionX64.h +++ b/CodeGen/src/EmitInstructionX64.h @@ -13,21 +13,21 @@ namespace Luau namespace CodeGen { -class AssemblyBuilderX64; struct Label; struct ModuleHelpers; +namespace X64 +{ + +class AssemblyBuilderX64; + void emitInstNameCall(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, const TValue* k, Label& next, Label& fallback); void emitInstCall(AssemblyBuilderX64& build, ModuleHelpers& helpers, const Instruction* pc, int pcpos); void emitInstReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers, const Instruction* pc, int pcpos); void emitInstSetList(AssemblyBuilderX64& build, const Instruction* pc, Label& next); -void emitInstFastCall1(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback); -void emitInstFastCall2(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback); -void emitInstFastCall2K(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback); -void emitInstFastCall(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback); -void emitinstForGLoop(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& loopRepeat, Label& loopExit, Label& fallback); -void emitinstForGLoopFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& loopRepeat); -void emitInstForGPrepXnextFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& target); +void emitinstForGLoop(AssemblyBuilderX64& build, int ra, int aux, Label& loopRepeat, Label& loopExit); +void emitinstForGLoopFallback(AssemblyBuilderX64& build, int pcpos, int ra, int aux, Label& loopRepeat); +void emitInstForGPrepXnextFallback(AssemblyBuilderX64& build, int pcpos, int ra, Label& target); void emitInstAnd(AssemblyBuilderX64& build, const Instruction* pc); void emitInstAndK(AssemblyBuilderX64& build, const Instruction* pc); void emitInstOr(AssemblyBuilderX64& build, const Instruction* pc); @@ -35,5 +35,6 @@ void emitInstOrK(AssemblyBuilderX64& build, const Instruction* pc); void emitInstGetImportFallback(AssemblyBuilderX64& build, int ra, uint32_t aux); void emitInstCoverage(AssemblyBuilderX64& build, int pcpos); +} // namespace X64 } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/IrAnalysis.cpp b/CodeGen/src/IrAnalysis.cpp index b494f2afc..aa3e19f7e 100644 --- a/CodeGen/src/IrAnalysis.cpp +++ b/CodeGen/src/IrAnalysis.cpp @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/IrAnalysis.h" +#include "Luau/DenseHash.h" #include "Luau/IrData.h" #include "Luau/IrUtils.h" @@ -73,5 +74,47 @@ void updateLastUseLocations(IrFunction& function) } } +std::pair getLiveInOutValueCount(IrFunction& function, IrBlock& block) +{ + uint32_t liveIns = 0; + uint32_t liveOuts = 0; + + auto checkOp = [&](IrOp op) { + if (op.kind == IrOpKind::Inst) + { + if (op.index >= block.start && op.index <= block.finish) + liveOuts--; + else + liveIns++; + } + }; + + for (uint32_t instIdx = block.start; instIdx <= block.finish; instIdx++) + { + IrInst& inst = function.instructions[instIdx]; + + liveOuts += inst.useCount; + + checkOp(inst.a); + checkOp(inst.b); + checkOp(inst.c); + checkOp(inst.d); + checkOp(inst.e); + checkOp(inst.f); + } + + return std::make_pair(liveIns, liveOuts); +} + +uint32_t getLiveInValueCount(IrFunction& function, IrBlock& block) +{ + return getLiveInOutValueCount(function, block).first; +} + +uint32_t getLiveOutValueCount(IrFunction& function, IrBlock& block) +{ + return getLiveInOutValueCount(function, block).second; +} + } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/IrBuilder.cpp b/CodeGen/src/IrBuilder.cpp index 2e7c75d19..056ea6007 100644 --- a/CodeGen/src/IrBuilder.cpp +++ b/CodeGen/src/IrBuilder.cpp @@ -2,6 +2,7 @@ #include "Luau/IrBuilder.h" #include "Luau/Common.h" +#include "Luau/DenseHash.h" #include "Luau/IrAnalysis.h" #include "Luau/IrUtils.h" @@ -271,7 +272,7 @@ void IrBuilder::translateInst(LuauOpcode op, const Instruction* pc, int i) int skip = LUAU_INSN_C(*pc); IrOp next = blockAtInst(i + skip + 2); - translateFastCallN(*this, pc, i, false, 0, {}, next, IrCmd::LOP_FASTCALL); + translateFastCallN(*this, pc, i, false, 0, {}, next); activeFastcallFallback = true; fastcallFallbackReturn = next; @@ -282,7 +283,7 @@ void IrBuilder::translateInst(LuauOpcode op, const Instruction* pc, int i) int skip = LUAU_INSN_C(*pc); IrOp next = blockAtInst(i + skip + 2); - translateFastCallN(*this, pc, i, true, 1, constBool(false), next, IrCmd::LOP_FASTCALL1); + translateFastCallN(*this, pc, i, true, 1, constBool(false), next); activeFastcallFallback = true; fastcallFallbackReturn = next; @@ -293,7 +294,7 @@ void IrBuilder::translateInst(LuauOpcode op, const Instruction* pc, int i) int skip = LUAU_INSN_C(*pc); IrOp next = blockAtInst(i + skip + 2); - translateFastCallN(*this, pc, i, true, 2, vmReg(pc[1]), next, IrCmd::LOP_FASTCALL2); + translateFastCallN(*this, pc, i, true, 2, vmReg(pc[1]), next); activeFastcallFallback = true; fastcallFallbackReturn = next; @@ -304,7 +305,7 @@ void IrBuilder::translateInst(LuauOpcode op, const Instruction* pc, int i) int skip = LUAU_INSN_C(*pc); IrOp next = blockAtInst(i + skip + 2); - translateFastCallN(*this, pc, i, true, 2, vmConst(pc[1]), next, IrCmd::LOP_FASTCALL2K); + translateFastCallN(*this, pc, i, true, 2, vmConst(pc[1]), next); activeFastcallFallback = true; fastcallFallbackReturn = next; @@ -318,21 +319,28 @@ void IrBuilder::translateInst(LuauOpcode op, const Instruction* pc, int i) break; case LOP_FORGLOOP: { + int aux = int(pc[1]); + // We have a translation for ipairs-style traversal, general loop iteration is still too complex - if (int(pc[1]) < 0) + if (aux < 0) { translateInstForGLoopIpairs(*this, pc, i); } else { + int ra = LUAU_INSN_A(*pc); + IrOp loopRepeat = blockAtInst(i + 1 + LUAU_INSN_D(*pc)); IrOp loopExit = blockAtInst(i + getOpLength(LOP_FORGLOOP)); IrOp fallback = block(IrBlockKind::Fallback); - inst(IrCmd::LOP_FORGLOOP, constUint(i), loopRepeat, loopExit, fallback); + inst(IrCmd::INTERRUPT, constUint(i)); + loadAndCheckTag(vmReg(ra), LUA_TNIL, fallback); + + inst(IrCmd::LOP_FORGLOOP, vmReg(ra), constInt(aux), loopRepeat, loopExit); beginBlock(fallback); - inst(IrCmd::LOP_FORGLOOP_FALLBACK, constUint(i), loopRepeat, loopExit); + inst(IrCmd::LOP_FORGLOOP_FALLBACK, constUint(i), vmReg(ra), constInt(aux), loopRepeat, loopExit); beginBlock(loopExit); } @@ -426,6 +434,68 @@ void IrBuilder::beginBlock(IrOp block) inTerminatedBlock = false; } +void IrBuilder::loadAndCheckTag(IrOp loc, uint8_t tag, IrOp fallback) +{ + inst(IrCmd::CHECK_TAG, inst(IrCmd::LOAD_TAG, loc), constTag(tag), fallback); +} + +void IrBuilder::clone(const IrBlock& source, bool removeCurrentTerminator) +{ + DenseHashMap instRedir{~0u}; + + auto redirect = [&instRedir](IrOp& op) { + if (op.kind == IrOpKind::Inst) + { + if (const uint32_t* newIndex = instRedir.find(op.index)) + op.index = *newIndex; + else + LUAU_ASSERT(!"values can only be used if they are defined in the same block"); + } + }; + + if (removeCurrentTerminator && inTerminatedBlock) + { + IrBlock& active = function.blocks[activeBlockIdx]; + IrInst& term = function.instructions[active.finish]; + + kill(function, term); + inTerminatedBlock = false; + } + + for (uint32_t index = source.start; index <= source.finish; index++) + { + LUAU_ASSERT(index < function.instructions.size()); + IrInst clone = function.instructions[index]; + + // Skip pseudo instructions to make clone more compact, but validate that they have no users + if (isPseudo(clone.cmd)) + { + LUAU_ASSERT(clone.useCount == 0); + continue; + } + + redirect(clone.a); + redirect(clone.b); + redirect(clone.c); + redirect(clone.d); + redirect(clone.e); + redirect(clone.f); + + addUse(function, clone.a); + addUse(function, clone.b); + addUse(function, clone.c); + addUse(function, clone.d); + addUse(function, clone.e); + addUse(function, clone.f); + + // Instructions that referenced the original will have to be adjusted to use the clone + instRedir[index] = uint32_t(function.instructions.size()); + + // Reconstruct the fresh clone + inst(clone.cmd, clone.a, clone.b, clone.c, clone.d, clone.e, clone.f); + } +} + IrOp IrBuilder::constBool(bool value) { IrConst constant; diff --git a/CodeGen/src/IrDump.cpp b/CodeGen/src/IrDump.cpp index 681de2867..cb203f7a7 100644 --- a/CodeGen/src/IrDump.cpp +++ b/CodeGen/src/IrDump.cpp @@ -118,6 +118,10 @@ const char* getCmdName(IrCmd cmd) return "MOD_NUM"; case IrCmd::POW_NUM: return "POW_NUM"; + case IrCmd::MIN_NUM: + return "MIN_NUM"; + case IrCmd::MAX_NUM: + return "MAX_NUM"; case IrCmd::UNM_NUM: return "UNM_NUM"; case IrCmd::NOT_ANY: @@ -152,6 +156,12 @@ const char* getCmdName(IrCmd cmd) return "ADJUST_STACK_TO_REG"; case IrCmd::ADJUST_STACK_TO_TOP: return "ADJUST_STACK_TO_TOP"; + case IrCmd::FASTCALL: + return "FASTCALL"; + case IrCmd::INVOKE_FASTCALL: + return "INVOKE_FASTCALL"; + case IrCmd::CHECK_FASTCALL_RES: + return "CHECK_FASTCALL_RES"; case IrCmd::DO_ARITH: return "DO_ARITH"; case IrCmd::DO_LEN: @@ -206,14 +216,6 @@ const char* getCmdName(IrCmd cmd) return "LOP_CALL"; case IrCmd::LOP_RETURN: return "LOP_RETURN"; - case IrCmd::LOP_FASTCALL: - return "LOP_FASTCALL"; - case IrCmd::LOP_FASTCALL1: - return "LOP_FASTCALL1"; - case IrCmd::LOP_FASTCALL2: - return "LOP_FASTCALL2"; - case IrCmd::LOP_FASTCALL2K: - return "LOP_FASTCALL2K"; case IrCmd::LOP_FORGLOOP: return "LOP_FORGLOOP"; case IrCmd::LOP_FORGLOOP_FALLBACK: @@ -267,6 +269,8 @@ const char* getBlockKindName(IrBlockKind kind) return "bb_fallback"; case IrBlockKind::Internal: return "bb"; + case IrBlockKind::Linearized: + return "bb_linear"; case IrBlockKind::Dead: return "dead"; } diff --git a/CodeGen/src/IrLoweringX64.cpp b/CodeGen/src/IrLoweringX64.cpp index d81240ffd..383375753 100644 --- a/CodeGen/src/IrLoweringX64.cpp +++ b/CodeGen/src/IrLoweringX64.cpp @@ -20,6 +20,8 @@ namespace Luau { namespace CodeGen { +namespace X64 +{ IrLoweringX64::IrLoweringX64(AssemblyBuilderX64& build, ModuleHelpers& helpers, NativeState& data, Proto* proto, IrFunction& function) : build(build) @@ -517,6 +519,36 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) break; } + case IrCmd::MIN_NUM: + inst.regX64 = regs.allocXmmRegOrReuse(index, {inst.a, inst.b}); + + if (inst.a.kind == IrOpKind::Constant) + { + ScopedRegX64 tmp{regs, SizeX64::xmmword}; + + build.vmovsd(tmp.reg, memRegDoubleOp(inst.a)); + build.vminsd(inst.regX64, tmp.reg, memRegDoubleOp(inst.b)); + } + else + { + build.vminsd(inst.regX64, regOp(inst.a), memRegDoubleOp(inst.b)); + } + break; + case IrCmd::MAX_NUM: + inst.regX64 = regs.allocXmmRegOrReuse(index, {inst.a, inst.b}); + + if (inst.a.kind == IrOpKind::Constant) + { + ScopedRegX64 tmp{regs, SizeX64::xmmword}; + + build.vmovsd(tmp.reg, memRegDoubleOp(inst.a)); + build.vmaxsd(inst.regX64, tmp.reg, memRegDoubleOp(inst.b)); + } + else + { + build.vmaxsd(inst.regX64, regOp(inst.a), memRegDoubleOp(inst.b)); + } + break; case IrCmd::UNM_NUM: { inst.regX64 = regs.allocXmmRegOrReuse(index, {inst.a}); @@ -624,7 +656,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) ScopedRegX64 tmp{regs, SizeX64::xmmword}; // TODO: jumpOnNumberCmp should work on IrCondition directly - jumpOnNumberCmp(build, tmp.reg, memRegDoubleOp(inst.a), memRegDoubleOp(inst.b), getX64Condition(cond), labelOp(inst.d)); + jumpOnNumberCmp(build, tmp.reg, memRegDoubleOp(inst.a), memRegDoubleOp(inst.b), cond, labelOp(inst.d)); jumpOrFallthrough(blockOp(inst.e), next); break; } @@ -636,7 +668,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) IrCondition cond = IrCondition(inst.c.index); - jumpOnAnyCmpFallback(build, inst.a.index, inst.b.index, getX64Condition(cond), labelOp(inst.d)); + jumpOnAnyCmpFallback(build, inst.a.index, inst.b.index, cond, labelOp(inst.d)); jumpOrFallthrough(blockOp(inst.e), next); break; } @@ -716,6 +748,89 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.mov(qword[rState + offsetof(lua_State, top)], tmp.reg); break; } + + case IrCmd::FASTCALL: + LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); + LUAU_ASSERT(inst.c.kind == IrOpKind::VmReg); + emitBuiltin(regs, build, uintOp(inst.a), inst.b.index, inst.c.index, inst.d, intOp(inst.e), intOp(inst.f)); + break; + case IrCmd::INVOKE_FASTCALL: + { + LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); + LUAU_ASSERT(inst.c.kind == IrOpKind::VmReg); + + unsigned bfid = uintOp(inst.a); + + OperandX64 args = 0; + + if (inst.d.kind == IrOpKind::VmReg) + args = luauRegAddress(inst.d.index); + else if (inst.d.kind == IrOpKind::VmConst) + args = luauConstantAddress(inst.d.index); + else + LUAU_ASSERT(boolOp(inst.d) == false); + + int ra = inst.b.index; + int arg = inst.c.index; + int nparams = intOp(inst.e); + int nresults = intOp(inst.f); + + regs.assertAllFree(); + + build.mov(rax, qword[rNativeContext + offsetof(NativeContext, luauF_table) + bfid * sizeof(luau_FastFunction)]); + + // 5th parameter (args) is left unset for LOP_FASTCALL1 + if (args.cat == CategoryX64::mem) + { + if (build.abi == ABIX64::Windows) + { + build.lea(rcx, args); + build.mov(sArg5, rcx); + } + else + { + build.lea(rArg5, args); + } + } + + if (nparams == LUA_MULTRET) + { + // L->top - (ra + 1) + RegisterX64 reg = (build.abi == ABIX64::Windows) ? rcx : rArg6; + build.mov(reg, qword[rState + offsetof(lua_State, top)]); + build.lea(rdx, addr[rBase + (ra + 1) * sizeof(TValue)]); + build.sub(reg, rdx); + build.shr(reg, kTValueSizeLog2); + + if (build.abi == ABIX64::Windows) + build.mov(sArg6, reg); + } + else + { + if (build.abi == ABIX64::Windows) + build.mov(sArg6, nparams); + else + build.mov(rArg6, nparams); + } + + build.mov(rArg1, rState); + build.lea(rArg2, luauRegAddress(ra)); + build.lea(rArg3, luauRegAddress(arg)); + build.mov(dwordReg(rArg4), nresults); + + build.call(rax); + + inst.regX64 = regs.takeGprReg(eax); // Result of a builtin call is returned in eax + break; + } + case IrCmd::CHECK_FASTCALL_RES: + { + RegisterX64 res = regOp(inst.a); + + build.test(res, res); // test here will set SF=1 for a negative number and it always sets OF to 0 + build.jcc(ConditionX64::Less, labelOp(inst.b)); // jl jumps if SF != OF + break; + } case IrCmd::DO_ARITH: LUAU_ASSERT(inst.a.kind == IrOpKind::VmReg); LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); @@ -1014,41 +1129,18 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) emitInstReturn(build, helpers, pc, uintOp(inst.a)); break; } - case IrCmd::LOP_FASTCALL: - LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); - LUAU_ASSERT(inst.c.kind == IrOpKind::Constant); - - emitInstFastCall(build, proto->code + uintOp(inst.a), uintOp(inst.a), labelOp(inst.d)); - break; - case IrCmd::LOP_FASTCALL1: - LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); - LUAU_ASSERT(inst.c.kind == IrOpKind::VmReg); - - emitInstFastCall1(build, proto->code + uintOp(inst.a), uintOp(inst.a), labelOp(inst.d)); - break; - case IrCmd::LOP_FASTCALL2: - LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); - LUAU_ASSERT(inst.c.kind == IrOpKind::VmReg); - LUAU_ASSERT(inst.d.kind == IrOpKind::VmReg); - - emitInstFastCall2(build, proto->code + uintOp(inst.a), uintOp(inst.a), labelOp(inst.e)); - break; - case IrCmd::LOP_FASTCALL2K: - LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); - LUAU_ASSERT(inst.c.kind == IrOpKind::VmReg); - LUAU_ASSERT(inst.d.kind == IrOpKind::VmConst); - - emitInstFastCall2K(build, proto->code + uintOp(inst.a), uintOp(inst.a), labelOp(inst.e)); - break; case IrCmd::LOP_FORGLOOP: - emitinstForGLoop(build, proto->code + uintOp(inst.a), uintOp(inst.a), labelOp(inst.b), labelOp(inst.c), labelOp(inst.d)); + LUAU_ASSERT(inst.a.kind == IrOpKind::VmReg); + emitinstForGLoop(build, inst.a.index, intOp(inst.b), labelOp(inst.c), labelOp(inst.d)); break; case IrCmd::LOP_FORGLOOP_FALLBACK: - emitinstForGLoopFallback(build, proto->code + uintOp(inst.a), uintOp(inst.a), labelOp(inst.b)); - build.jmp(labelOp(inst.c)); + LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); + emitinstForGLoopFallback(build, uintOp(inst.a), inst.b.index, intOp(inst.c), labelOp(inst.d)); + build.jmp(labelOp(inst.e)); break; case IrCmd::LOP_FORGPREP_XNEXT_FALLBACK: - emitInstForGPrepXnextFallback(build, proto->code + uintOp(inst.a), uintOp(inst.a), labelOp(inst.b)); + LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); + emitInstForGPrepXnextFallback(build, uintOp(inst.a), inst.b.index, labelOp(inst.c)); break; case IrCmd::LOP_AND: emitInstAnd(build, proto->code + uintOp(inst.a)); @@ -1224,38 +1316,6 @@ Label& IrLoweringX64::labelOp(IrOp op) const return blockOp(op).label; } -ConditionX64 IrLoweringX64::getX64Condition(IrCondition cond) const -{ - // TODO: this function will not be required when jumpOnNumberCmp starts accepting an IrCondition - switch (cond) - { - case IrCondition::Equal: - return ConditionX64::Equal; - case IrCondition::NotEqual: - return ConditionX64::NotEqual; - case IrCondition::Less: - return ConditionX64::Less; - case IrCondition::NotLess: - return ConditionX64::NotLess; - case IrCondition::LessEqual: - return ConditionX64::LessEqual; - case IrCondition::NotLessEqual: - return ConditionX64::NotLessEqual; - case IrCondition::Greater: - return ConditionX64::Greater; - case IrCondition::NotGreater: - return ConditionX64::NotGreater; - case IrCondition::GreaterEqual: - return ConditionX64::GreaterEqual; - case IrCondition::NotGreaterEqual: - return ConditionX64::NotGreaterEqual; - default: - LUAU_ASSERT(!"unsupported condition"); - break; - } - - return ConditionX64::Count; -} - +} // namespace X64 } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/IrLoweringX64.h b/CodeGen/src/IrLoweringX64.h index e47c39783..a0ad3eabd 100644 --- a/CodeGen/src/IrLoweringX64.h +++ b/CodeGen/src/IrLoweringX64.h @@ -19,6 +19,9 @@ struct ModuleHelpers; struct NativeState; struct AssemblyOptions; +namespace X64 +{ + struct IrLoweringX64 { // Some of these arguments are only required while we re-use old direct bytecode to x64 lowering @@ -46,8 +49,6 @@ struct IrLoweringX64 IrBlock& blockOp(IrOp op) const; Label& labelOp(IrOp op) const; - ConditionX64 getX64Condition(IrCondition cond) const; - AssemblyBuilderX64& build; ModuleHelpers& helpers; NativeState& data; @@ -58,5 +59,6 @@ struct IrLoweringX64 IrRegAllocX64 regs; }; +} // namespace X64 } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/IrRegAllocX64.cpp b/CodeGen/src/IrRegAllocX64.cpp index de60159fa..91867806a 100644 --- a/CodeGen/src/IrRegAllocX64.cpp +++ b/CodeGen/src/IrRegAllocX64.cpp @@ -19,6 +19,8 @@ namespace Luau { namespace CodeGen { +namespace X64 +{ static const RegisterX64 kGprAllocOrder[] = {rax, rdx, rcx, rbx, rsi, rdi, r8, r9, r10, r11}; @@ -106,6 +108,16 @@ RegisterX64 IrRegAllocX64::allocXmmRegOrReuse(uint32_t index, std::initializer_l return allocXmmReg(); } +RegisterX64 IrRegAllocX64::takeGprReg(RegisterX64 reg) +{ + // In a more advanced register allocator, this would require a spill for the current register user + // But at the current stage we don't have register live ranges intersecting forced register uses + LUAU_ASSERT(freeGprMap[reg.index]); + + freeGprMap[reg.index] = false; + return reg; +} + void IrRegAllocX64::freeReg(RegisterX64 reg) { if (reg.size == SizeX64::xmmword) @@ -148,6 +160,15 @@ void IrRegAllocX64::freeLastUseRegs(const IrInst& inst, uint32_t index) checkOp(inst.f); } +void IrRegAllocX64::assertAllFree() const +{ + for (RegisterX64 reg : kGprAllocOrder) + LUAU_ASSERT(freeGprMap[reg.index]); + + for (bool free : freeXmmMap) + LUAU_ASSERT(free); +} + ScopedRegX64::ScopedRegX64(IrRegAllocX64& owner, SizeX64 size) : owner(owner) { @@ -157,7 +178,6 @@ ScopedRegX64::ScopedRegX64(IrRegAllocX64& owner, SizeX64 size) reg = owner.allocGprReg(size); } - ScopedRegX64::ScopedRegX64(IrRegAllocX64& owner, RegisterX64 reg) : owner(owner) , reg(reg) @@ -177,5 +197,6 @@ void ScopedRegX64::free() reg = noreg; } +} // namespace X64 } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/IrRegAllocX64.h b/CodeGen/src/IrRegAllocX64.h index a532c3b3a..ac072a32f 100644 --- a/CodeGen/src/IrRegAllocX64.h +++ b/CodeGen/src/IrRegAllocX64.h @@ -11,6 +11,8 @@ namespace Luau { namespace CodeGen { +namespace X64 +{ struct IrRegAllocX64 { @@ -22,10 +24,14 @@ struct IrRegAllocX64 RegisterX64 allocGprRegOrReuse(SizeX64 preferredSize, uint32_t index, std::initializer_list oprefs); RegisterX64 allocXmmRegOrReuse(uint32_t index, std::initializer_list oprefs); + RegisterX64 takeGprReg(RegisterX64 reg); + void freeReg(RegisterX64 reg); void freeLastUseReg(IrInst& target, uint32_t index); void freeLastUseRegs(const IrInst& inst, uint32_t index); + void assertAllFree() const; + IrFunction& function; std::array freeGprMap; @@ -47,5 +53,6 @@ struct ScopedRegX64 RegisterX64 reg; }; +} // namespace X64 } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/IrTranslateBuiltins.cpp b/CodeGen/src/IrTranslateBuiltins.cpp index ccd743ed1..bc909105d 100644 --- a/CodeGen/src/IrTranslateBuiltins.cpp +++ b/CodeGen/src/IrTranslateBuiltins.cpp @@ -6,11 +6,66 @@ #include "lstate.h" +// TODO: should be possible to handle fastcalls in contexts where nresults is -1 by adding the adjustment instruction +// TODO: when nresults is less than our actual result count, we can skip computing/writing unused results + namespace Luau { namespace CodeGen { +// Wrapper code for all builtins with a fixed signature and manual assembly lowering of the body + +// (number, ...) -> number +BuiltinImplResult translateBuiltinNumberToNumber( + IrBuilder& build, LuauBuiltinFunction bfid, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) +{ + if (nparams < 1 || nresults > 1) + return {BuiltinImplType::None, -1}; + + build.loadAndCheckTag(build.vmReg(arg), LUA_TNUMBER, fallback); + build.inst(IrCmd::FASTCALL, build.constUint(bfid), build.vmReg(ra), build.vmReg(arg), args, build.constInt(nparams), build.constInt(nresults)); + + // TODO: tag update might not be required, we place it here now because FASTCALL is not modeled in constant propagation yet + build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); + + return {BuiltinImplType::UsesFallback, 1}; +} + +// (number, number, ...) -> number +BuiltinImplResult translateBuiltin2NumberToNumber( + IrBuilder& build, LuauBuiltinFunction bfid, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) +{ + if (nparams < 2 || nresults > 1) + return {BuiltinImplType::None, -1}; + + build.loadAndCheckTag(build.vmReg(arg), LUA_TNUMBER, fallback); + build.loadAndCheckTag(args, LUA_TNUMBER, fallback); + build.inst(IrCmd::FASTCALL, build.constUint(bfid), build.vmReg(ra), build.vmReg(arg), args, build.constInt(nparams), build.constInt(nresults)); + + // TODO:tag update might not be required, we place it here now because FASTCALL is not modeled in constant propagation yet + build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); + + return {BuiltinImplType::UsesFallback, 1}; +} + +// (number, ...) -> (number, number) +BuiltinImplResult translateBuiltinNumberTo2Number( + IrBuilder& build, LuauBuiltinFunction bfid, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) +{ + if (nparams < 1 || nresults > 2) + return {BuiltinImplType::None, -1}; + + build.loadAndCheckTag(build.vmReg(arg), LUA_TNUMBER, fallback); + build.inst(IrCmd::FASTCALL, build.constUint(bfid), build.vmReg(ra), build.vmReg(arg), args, build.constInt(nparams), build.constInt(nresults)); + + // TODO: some tag updates might not be required, we place them here now because FASTCALL is not modeled in constant propagation yet + build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); + build.inst(IrCmd::STORE_TAG, build.vmReg(ra + 1), build.constTag(LUA_TNUMBER)); + + return {BuiltinImplType::UsesFallback, 2}; +} + BuiltinImplResult translateBuiltinAssert(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) { if (nparams < 1 || nresults != 0) @@ -25,12 +80,180 @@ BuiltinImplResult translateBuiltinAssert(IrBuilder& build, int nparams, int ra, return {BuiltinImplType::UsesFallback, 0}; } +BuiltinImplResult translateBuiltinMathDeg(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) +{ + if (nparams < 1 || nresults > 1) + return {BuiltinImplType::None, -1}; + + build.loadAndCheckTag(build.vmReg(arg), LUA_TNUMBER, fallback); + + const double rpd = (3.14159265358979323846 / 180.0); + + IrOp varg = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(arg)); + IrOp value = build.inst(IrCmd::DIV_NUM, varg, build.constDouble(rpd)); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), value); + + if (ra != arg) + build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); + + return {BuiltinImplType::UsesFallback, 1}; +} + +BuiltinImplResult translateBuiltinMathRad(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) +{ + if (nparams < 1 || nresults > 1) + return {BuiltinImplType::None, -1}; + + build.loadAndCheckTag(build.vmReg(arg), LUA_TNUMBER, fallback); + + const double rpd = (3.14159265358979323846 / 180.0); + + IrOp varg = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(arg)); + IrOp value = build.inst(IrCmd::MUL_NUM, varg, build.constDouble(rpd)); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), value); + + if (ra != arg) + build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); + + return {BuiltinImplType::UsesFallback, 1}; +} + +BuiltinImplResult translateBuiltinMathLog( + IrBuilder& build, LuauBuiltinFunction bfid, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) +{ + if (nparams < 1 || nresults > 1) + return {BuiltinImplType::None, -1}; + + build.loadAndCheckTag(build.vmReg(arg), LUA_TNUMBER, fallback); + + if (nparams != 1) + build.loadAndCheckTag(args, LUA_TNUMBER, fallback); + + build.inst(IrCmd::FASTCALL, build.constUint(bfid), build.vmReg(ra), build.vmReg(arg), args, build.constInt(nparams), build.constInt(nresults)); + + // TODO: tag update might not be required, we place it here now because FASTCALL is not modeled in constant propagation yet + build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); + + return {BuiltinImplType::UsesFallback, 1}; +} + +BuiltinImplResult translateBuiltinMathMin(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) +{ + // TODO: this can be extended for other number of arguments + if (nparams != 2 || nresults > 1) + return {BuiltinImplType::None, -1}; + + build.loadAndCheckTag(build.vmReg(arg), LUA_TNUMBER, fallback); + build.loadAndCheckTag(args, LUA_TNUMBER, fallback); + + IrOp varg1 = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(arg)); + IrOp varg2 = build.inst(IrCmd::LOAD_DOUBLE, args); + + IrOp res = build.inst(IrCmd::MIN_NUM, varg2, varg1); // Swapped arguments are required for consistency with VM builtins + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), res); + + if (ra != arg) + build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); + + return {BuiltinImplType::UsesFallback, 1}; +} + +BuiltinImplResult translateBuiltinMathMax(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) +{ + // TODO: this can be extended for other number of arguments + if (nparams != 2 || nresults > 1) + return {BuiltinImplType::None, -1}; + + build.loadAndCheckTag(build.vmReg(arg), LUA_TNUMBER, fallback); + build.loadAndCheckTag(args, LUA_TNUMBER, fallback); + + IrOp varg1 = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(arg)); + IrOp varg2 = build.inst(IrCmd::LOAD_DOUBLE, args); + + IrOp res = build.inst(IrCmd::MAX_NUM, varg2, varg1); // Swapped arguments are required for consistency with VM builtins + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), res); + + if (ra != arg) + build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); + + return {BuiltinImplType::UsesFallback, 1}; +} + +BuiltinImplResult translateBuiltinMathClamp(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) +{ + if (nparams < 3 || nresults > 1) + return {BuiltinImplType::None, -1}; + + IrOp block = build.block(IrBlockKind::Internal); + + LUAU_ASSERT(args.kind == IrOpKind::VmReg); + + build.loadAndCheckTag(build.vmReg(arg), LUA_TNUMBER, fallback); + build.loadAndCheckTag(args, LUA_TNUMBER, fallback); + build.loadAndCheckTag(build.vmReg(args.index + 1), LUA_TNUMBER, fallback); + + IrOp min = build.inst(IrCmd::LOAD_DOUBLE, args); + IrOp max = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(args.index + 1)); + + build.inst(IrCmd::JUMP_CMP_NUM, min, max, build.cond(IrCondition::NotLessEqual), fallback, block); + build.beginBlock(block); + + IrOp v = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(arg)); + IrOp r = build.inst(IrCmd::MAX_NUM, min, v); + IrOp clamped = build.inst(IrCmd::MIN_NUM, max, r); + + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), clamped); + + if (ra != arg) + build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); + + return {BuiltinImplType::UsesFallback, 1}; +} + BuiltinImplResult translateBuiltin(IrBuilder& build, int bfid, int ra, int arg, IrOp args, int nparams, int nresults, IrOp fallback) { switch (bfid) { case LBF_ASSERT: return translateBuiltinAssert(build, nparams, ra, arg, args, nresults, fallback); + case LBF_MATH_DEG: + return translateBuiltinMathDeg(build, nparams, ra, arg, args, nresults, fallback); + case LBF_MATH_RAD: + return translateBuiltinMathRad(build, nparams, ra, arg, args, nresults, fallback); + case LBF_MATH_LOG: + return translateBuiltinMathLog(build, LuauBuiltinFunction(bfid), nparams, ra, arg, args, nresults, fallback); + case LBF_MATH_MIN: + return translateBuiltinMathMin(build, nparams, ra, arg, args, nresults, fallback); + case LBF_MATH_MAX: + return translateBuiltinMathMax(build, nparams, ra, arg, args, nresults, fallback); + case LBF_MATH_CLAMP: + return translateBuiltinMathClamp(build, nparams, ra, arg, args, nresults, fallback); + case LBF_MATH_FLOOR: + case LBF_MATH_CEIL: + case LBF_MATH_SQRT: + case LBF_MATH_ABS: + case LBF_MATH_EXP: + case LBF_MATH_ASIN: + case LBF_MATH_SIN: + case LBF_MATH_SINH: + case LBF_MATH_ACOS: + case LBF_MATH_COS: + case LBF_MATH_COSH: + case LBF_MATH_ATAN: + case LBF_MATH_TAN: + case LBF_MATH_TANH: + case LBF_MATH_LOG10: + case LBF_MATH_ROUND: + case LBF_MATH_SIGN: + return translateBuiltinNumberToNumber(build, LuauBuiltinFunction(bfid), nparams, ra, arg, args, nresults, fallback); + case LBF_MATH_FMOD: + case LBF_MATH_POW: + case LBF_MATH_ATAN2: + case LBF_MATH_LDEXP: + return translateBuiltin2NumberToNumber(build, LuauBuiltinFunction(bfid), nparams, ra, arg, args, nresults, fallback); + case LBF_MATH_FREXP: + case LBF_MATH_MODF: + return translateBuiltinNumberTo2Number(build, LuauBuiltinFunction(bfid), nparams, ra, arg, args, nresults, fallback); default: return {BuiltinImplType::None, -1}; } diff --git a/CodeGen/src/IrTranslation.cpp b/CodeGen/src/IrTranslation.cpp index 68c6c402c..48ca3975b 100644 --- a/CodeGen/src/IrTranslation.cpp +++ b/CodeGen/src/IrTranslation.cpp @@ -479,8 +479,7 @@ void translateInstCloseUpvals(IrBuilder& build, const Instruction* pc) build.inst(IrCmd::CLOSE_UPVALS, build.vmReg(ra)); } -void translateFastCallN( - IrBuilder& build, const Instruction* pc, int pcpos, bool customParams, int customParamCount, IrOp customArgs, IrOp next, IrCmd fallbackCmd) +void translateFastCallN(IrBuilder& build, const Instruction* pc, int pcpos, bool customParams, int customParamCount, IrOp customArgs, IrOp next) { int bfid = LUAU_INSN_A(*pc); int skip = LUAU_INSN_C(*pc); @@ -509,23 +508,17 @@ void translateFastCallN( } else { - switch (fallbackCmd) - { - case IrCmd::LOP_FASTCALL: - build.inst(IrCmd::LOP_FASTCALL, build.constUint(pcpos), build.vmReg(ra), build.constInt(nparams), fallback); - break; - case IrCmd::LOP_FASTCALL1: - build.inst(IrCmd::LOP_FASTCALL1, build.constUint(pcpos), build.vmReg(ra), build.vmReg(arg), fallback); - break; - case IrCmd::LOP_FASTCALL2: - build.inst(IrCmd::LOP_FASTCALL2, build.constUint(pcpos), build.vmReg(ra), build.vmReg(arg), build.vmReg(pc[1]), fallback); - break; - case IrCmd::LOP_FASTCALL2K: - build.inst(IrCmd::LOP_FASTCALL2K, build.constUint(pcpos), build.vmReg(ra), build.vmReg(arg), build.vmConst(pc[1]), fallback); - break; - default: - LUAU_ASSERT(!"unexpected command"); - } + // TODO: we can skip saving pc for some well-behaved builtins which we didn't inline + build.inst(IrCmd::SET_SAVEDPC, build.constUint(pcpos + 1)); + + IrOp res = build.inst(IrCmd::INVOKE_FASTCALL, build.constUint(bfid), build.vmReg(ra), build.vmReg(arg), args, build.constInt(nparams), + build.constInt(nresults)); + build.inst(IrCmd::CHECK_FASTCALL_RES, res, fallback); + + if (nresults == LUA_MULTRET) + build.inst(IrCmd::ADJUST_STACK_TO_REG, build.vmReg(ra), res); + else if (nparams == LUA_MULTRET) + build.inst(IrCmd::ADJUST_STACK_TO_TOP); } build.inst(IrCmd::JUMP, next); @@ -645,7 +638,7 @@ void translateInstForGPrepNext(IrBuilder& build, const Instruction* pc, int pcpo build.inst(IrCmd::JUMP, target); build.beginBlock(fallback); - build.inst(IrCmd::LOP_FORGPREP_XNEXT_FALLBACK, build.constUint(pcpos), target); + build.inst(IrCmd::LOP_FORGPREP_XNEXT_FALLBACK, build.constUint(pcpos), build.vmReg(ra), target); } void translateInstForGPrepInext(IrBuilder& build, const Instruction* pc, int pcpos) @@ -677,7 +670,7 @@ void translateInstForGPrepInext(IrBuilder& build, const Instruction* pc, int pcp build.inst(IrCmd::JUMP, target); build.beginBlock(fallback); - build.inst(IrCmd::LOP_FORGPREP_XNEXT_FALLBACK, build.constUint(pcpos), target); + build.inst(IrCmd::LOP_FORGPREP_XNEXT_FALLBACK, build.constUint(pcpos), build.vmReg(ra), target); } void translateInstForGLoopIpairs(IrBuilder& build, const Instruction* pc, int pcpos) @@ -728,7 +721,7 @@ void translateInstForGLoopIpairs(IrBuilder& build, const Instruction* pc, int pc build.inst(IrCmd::JUMP, loopRepeat); build.beginBlock(fallback); - build.inst(IrCmd::LOP_FORGLOOP_FALLBACK, build.constUint(pcpos), loopRepeat, loopExit); + build.inst(IrCmd::LOP_FORGLOOP_FALLBACK, build.constUint(pcpos), build.vmReg(ra), build.constInt(int(pc[1])), loopRepeat, loopExit); // Fallthrough in original bytecode is implicit, so we start next internal block here if (build.isInternalBlock(loopExit)) diff --git a/CodeGen/src/IrTranslation.h b/CodeGen/src/IrTranslation.h index 5b3f78f28..0d4a5096c 100644 --- a/CodeGen/src/IrTranslation.h +++ b/CodeGen/src/IrTranslation.h @@ -43,8 +43,7 @@ void translateInstDupTable(IrBuilder& build, const Instruction* pc, int pcpos); void translateInstGetUpval(IrBuilder& build, const Instruction* pc, int pcpos); void translateInstSetUpval(IrBuilder& build, const Instruction* pc, int pcpos); void translateInstCloseUpvals(IrBuilder& build, const Instruction* pc); -void translateFastCallN( - IrBuilder& build, const Instruction* pc, int pcpos, bool customParams, int customParamCount, IrOp customArgs, IrOp next, IrCmd fallbackCmd); +void translateFastCallN(IrBuilder& build, const Instruction* pc, int pcpos, bool customParams, int customParamCount, IrOp customArgs, IrOp next); void translateInstForNPrep(IrBuilder& build, const Instruction* pc, int pcpos); void translateInstForNLoop(IrBuilder& build, const Instruction* pc, int pcpos); void translateInstForGPrepNext(IrBuilder& build, const Instruction* pc, int pcpos); diff --git a/CodeGen/src/IrUtils.cpp b/CodeGen/src/IrUtils.cpp index 6ccbc8ce0..0808ad076 100644 --- a/CodeGen/src/IrUtils.cpp +++ b/CodeGen/src/IrUtils.cpp @@ -286,6 +286,24 @@ void foldConstants(IrBuilder& build, IrFunction& function, IrBlock& block, uint3 if (inst.a.kind == IrOpKind::Constant && inst.b.kind == IrOpKind::Constant) substitute(function, inst, build.constDouble(pow(function.doubleOp(inst.a), function.doubleOp(inst.b)))); break; + case IrCmd::MIN_NUM: + if (inst.a.kind == IrOpKind::Constant && inst.b.kind == IrOpKind::Constant) + { + double a1 = function.doubleOp(inst.a); + double a2 = function.doubleOp(inst.b); + + substitute(function, inst, build.constDouble((a2 < a1) ? a2 : a1)); + } + break; + case IrCmd::MAX_NUM: + if (inst.a.kind == IrOpKind::Constant && inst.b.kind == IrOpKind::Constant) + { + double a1 = function.doubleOp(inst.a); + double a2 = function.doubleOp(inst.b); + + substitute(function, inst, build.constDouble((a2 > a1) ? a2 : a1)); + } + break; case IrCmd::UNM_NUM: if (inst.a.kind == IrOpKind::Constant) substitute(function, inst, build.constDouble(-function.doubleOp(inst.a))); diff --git a/CodeGen/src/OptimizeConstProp.cpp b/CodeGen/src/OptimizeConstProp.cpp index c9c7f6c4a..956c96d63 100644 --- a/CodeGen/src/OptimizeConstProp.cpp +++ b/CodeGen/src/OptimizeConstProp.cpp @@ -2,11 +2,16 @@ #include "Luau/OptimizeConstProp.h" #include "Luau/DenseHash.h" +#include "Luau/IrAnalysis.h" #include "Luau/IrBuilder.h" #include "Luau/IrUtils.h" #include "lua.h" +#include + +LUAU_FASTINTVARIABLE(LuauCodeGenMinLinearBlockPath, 3) + namespace Luau { namespace CodeGen @@ -181,6 +186,82 @@ struct ConstPropState DenseHashMap instLink{~0u}; }; +static void handleBuiltinEffects(ConstPropState& state, LuauBuiltinFunction bfid, uint32_t firstReturnReg, int nresults) +{ + // Switch over all values is used to force new items to be handled + switch (bfid) + { + case LBF_NONE: + case LBF_ASSERT: + case LBF_MATH_ABS: + case LBF_MATH_ACOS: + case LBF_MATH_ASIN: + case LBF_MATH_ATAN2: + case LBF_MATH_ATAN: + case LBF_MATH_CEIL: + case LBF_MATH_COSH: + case LBF_MATH_COS: + case LBF_MATH_DEG: + case LBF_MATH_EXP: + case LBF_MATH_FLOOR: + case LBF_MATH_FMOD: + case LBF_MATH_FREXP: + case LBF_MATH_LDEXP: + case LBF_MATH_LOG10: + case LBF_MATH_LOG: + case LBF_MATH_MAX: + case LBF_MATH_MIN: + case LBF_MATH_MODF: + case LBF_MATH_POW: + case LBF_MATH_RAD: + case LBF_MATH_SINH: + case LBF_MATH_SIN: + case LBF_MATH_SQRT: + case LBF_MATH_TANH: + case LBF_MATH_TAN: + case LBF_BIT32_ARSHIFT: + case LBF_BIT32_BAND: + case LBF_BIT32_BNOT: + case LBF_BIT32_BOR: + case LBF_BIT32_BXOR: + case LBF_BIT32_BTEST: + case LBF_BIT32_EXTRACT: + case LBF_BIT32_LROTATE: + case LBF_BIT32_LSHIFT: + case LBF_BIT32_REPLACE: + case LBF_BIT32_RROTATE: + case LBF_BIT32_RSHIFT: + case LBF_TYPE: + case LBF_STRING_BYTE: + case LBF_STRING_CHAR: + case LBF_STRING_LEN: + case LBF_TYPEOF: + case LBF_STRING_SUB: + case LBF_MATH_CLAMP: + case LBF_MATH_SIGN: + case LBF_MATH_ROUND: + case LBF_RAWSET: + case LBF_RAWGET: + case LBF_RAWEQUAL: + case LBF_TABLE_INSERT: + case LBF_TABLE_UNPACK: + case LBF_VECTOR: + case LBF_BIT32_COUNTLZ: + case LBF_BIT32_COUNTRZ: + case LBF_SELECT_VARARG: + case LBF_RAWLEN: + case LBF_BIT32_EXTRACTK: + case LBF_GETMETATABLE: + break; + case LBF_SETMETATABLE: + state.invalidateHeap(); // TODO: only knownNoMetatable is affected and we might know which one + break; + } + + // TODO: classify further using switch above, some fastcalls only modify the value, not the tag + state.invalidateRegistersFrom(firstReturnReg); +} + static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& function, IrBlock& block, IrInst& inst, uint32_t index) { switch (inst.cmd) @@ -406,20 +487,16 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& } } break; - case IrCmd::LOP_FASTCALL: - case IrCmd::LOP_FASTCALL1: - case IrCmd::LOP_FASTCALL2: - case IrCmd::LOP_FASTCALL2K: - // TODO: classify fast call behaviors to avoid heap invalidation - state.invalidateHeap(); // Even a builtin method can change table properties - state.invalidateRegistersFrom(inst.b.index); - break; case IrCmd::LOP_AND: case IrCmd::LOP_ANDK: case IrCmd::LOP_OR: case IrCmd::LOP_ORK: state.invalidate(inst.b); break; + case IrCmd::FASTCALL: + case IrCmd::INVOKE_FASTCALL: + handleBuiltinEffects(state, LuauBuiltinFunction(function.uintOp(inst.a)), inst.b.index, function.intOp(inst.f)); + break; // These instructions don't have an effect on register/memory state we are tracking case IrCmd::NOP: @@ -436,6 +513,8 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& case IrCmd::DIV_NUM: case IrCmd::MOD_NUM: case IrCmd::POW_NUM: + case IrCmd::MIN_NUM: + case IrCmd::MAX_NUM: case IrCmd::UNM_NUM: case IrCmd::NOT_ANY: case IrCmd::JUMP: @@ -458,6 +537,7 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& case IrCmd::SUBSTITUTE: case IrCmd::ADJUST_STACK_TO_REG: // Changes stack top, but not the values case IrCmd::ADJUST_STACK_TO_TOP: // Changes stack top, but not the values + case IrCmd::CHECK_FASTCALL_RES: // Changes stack top, but not the values break; // We don't model the following instructions, so we just clear all the knowledge we have built up @@ -534,8 +614,9 @@ static void constPropInBlockChain(IrBuilder& build, std::vector& visite if (termInst.cmd == IrCmd::JUMP) { IrBlock& target = function.blockOp(termInst.a); + uint32_t targetIdx = function.getBlockIndex(target); - if (target.useCount == 1 && !visited[function.getBlockIndex(target)] && target.kind != IrBlockKind::Fallback) + if (target.useCount == 1 && !visited[targetIdx] && target.kind != IrBlockKind::Fallback) nextBlock = ⌖ } @@ -543,12 +624,114 @@ static void constPropInBlockChain(IrBuilder& build, std::vector& visite } } +// Note that blocks in the collected path are marked as visited +static std::vector collectDirectBlockJumpPath(IrFunction& function, std::vector& visited, IrBlock* block) +{ + // Path shouldn't be built starting with a block that has 'live out' values. + // One theoretical way to get it is if we had 'block' jumping unconditionally into a successor that uses values from 'block' + // * if the successor has only one use, the starting conditions of 'tryCreateLinearBlock' would prevent this + // * if the successor has multiple uses, it can't have such 'live in' values without phi nodes that we don't have yet + // Another possibility is to have two paths from 'block' into the target through two intermediate blocks + // Usually that would mean that we would have a conditional jump at the end of 'block' + // But using check guards and fallback clocks it becomes a possible setup + // We avoid this by making sure fallbacks rejoin the other immediate successor of 'block' + LUAU_ASSERT(getLiveOutValueCount(function, *block) == 0); + + std::vector path; + + while (block) + { + IrInst& termInst = function.instructions[block->finish]; + IrBlock* nextBlock = nullptr; + + // A chain is made from internal blocks that were not a part of bytecode CFG + if (termInst.cmd == IrCmd::JUMP) + { + IrBlock& target = function.blockOp(termInst.a); + uint32_t targetIdx = function.getBlockIndex(target); + + if (!visited[targetIdx] && target.kind == IrBlockKind::Internal) + { + // Additional restriction is that to join a block, it cannot produce values that are used in other blocks + // And it also can't use values produced in other blocks + auto [liveIns, liveOuts] = getLiveInOutValueCount(function, target); + + if (liveIns == 0 && liveOuts == 0) + { + visited[targetIdx] = true; + path.push_back(targetIdx); + + nextBlock = ⌖ + } + } + } + + block = nextBlock; + } + + return path; +} + +static void tryCreateLinearBlock(IrBuilder& build, std::vector& visited, IrBlock& startingBlock) +{ + IrFunction& function = build.function; + + uint32_t blockIdx = function.getBlockIndex(startingBlock); + LUAU_ASSERT(!visited[blockIdx]); + visited[blockIdx] = true; + + IrInst& termInst = function.instructions[startingBlock.finish]; + + // Block has to end with an unconditional jump + if (termInst.cmd != IrCmd::JUMP) + return; + + // And it has to jump to a block with more than one user + // If there's only one use, it should already be optimized by constPropInBlockChain + if (function.blockOp(termInst.a).useCount == 1) + return; + + uint32_t targetBlockIdx = termInst.a.index; + + // Check if this path is worth it (and it will also mark path blocks as visited) + std::vector path = collectDirectBlockJumpPath(function, visited, &startingBlock); + + // If path is too small, we are not going to linearize it + if (int(path.size()) < FInt::LuauCodeGenMinLinearBlockPath) + return; + + // Initialize state with the knowledge of our current block + ConstPropState state; + constPropInBlock(build, startingBlock, state); + + // Veryfy that target hasn't changed + LUAU_ASSERT(function.instructions[startingBlock.finish].a.index == targetBlockIdx); + + // Create new linearized block into which we are going to redirect starting block jump + IrOp newBlock = build.block(IrBlockKind::Linearized); + visited.push_back(false); + + // TODO: placement of linear blocks in final lowering is sub-optimal, it should follow our predecessor + build.beginBlock(newBlock); + + replace(function, termInst.a, newBlock); + + // Clone the collected path int our fresh block + for (uint32_t pathBlockIdx : path) + build.clone(function.blocks[pathBlockIdx], /* removeCurrentTerminator */ true); + + // Optimize our linear block + IrBlock& linearBlock = function.blockOp(newBlock); + constPropInBlock(build, linearBlock, state); +} + void constPropInBlockChains(IrBuilder& build) { IrFunction& function = build.function; std::vector visited(function.blocks.size(), false); + // First pass: go over existing blocks once and propagate constants for (IrBlock& block : function.blocks) { if (block.kind == IrBlockKind::Fallback || block.kind == IrBlockKind::Dead) @@ -559,6 +742,26 @@ void constPropInBlockChains(IrBuilder& build) constPropInBlockChain(build, visited, &block); } + + // Second pass: go through internal block chains and outline them into a single new block + // Outlining will be able to linearize the execution, even if there was a jump to a block with multiple users, + // new 'block' will only be reachable from a single one and all gathered information can be preserved. + std::fill(visited.begin(), visited.end(), false); + + // This next loop can create new 'linear' blocks, so index-based loop has to be used (and it intentionally won't reach those new blocks) + size_t originalBlockCount = function.blocks.size(); + for (size_t i = 0; i < originalBlockCount; i++) + { + IrBlock& block = function.blocks[i]; + + if (block.kind == IrBlockKind::Fallback || block.kind == IrBlockKind::Dead) + continue; + + if (visited[function.getBlockIndex(block)]) + continue; + + tryCreateLinearBlock(build, visited, block); + } } } // namespace CodeGen diff --git a/CodeGen/src/OptimizeFinalX64.cpp b/CodeGen/src/OptimizeFinalX64.cpp index 2b7c96527..dd31fcc4f 100644 --- a/CodeGen/src/OptimizeFinalX64.cpp +++ b/CodeGen/src/OptimizeFinalX64.cpp @@ -41,6 +41,8 @@ static void optimizeMemoryOperandsX64(IrFunction& function, IrBlock& block) case IrCmd::DIV_NUM: case IrCmd::MOD_NUM: case IrCmd::POW_NUM: + case IrCmd::MIN_NUM: + case IrCmd::MAX_NUM: { if (inst.b.kind == IrOpKind::Inst) { diff --git a/CodeGen/src/UnwindBuilderDwarf2.cpp b/CodeGen/src/UnwindBuilderDwarf2.cpp index 7dc86d3ec..a95ed0941 100644 --- a/CodeGen/src/UnwindBuilderDwarf2.cpp +++ b/CodeGen/src/UnwindBuilderDwarf2.cpp @@ -168,12 +168,12 @@ void UnwindBuilderDwarf2::start() // Function call frame instructions to follow } -void UnwindBuilderDwarf2::spill(int espOffset, RegisterX64 reg) +void UnwindBuilderDwarf2::spill(int espOffset, X64::RegisterX64 reg) { pos = advanceLocation(pos, 5); // REX.W mov [rsp + imm8], reg } -void UnwindBuilderDwarf2::save(RegisterX64 reg) +void UnwindBuilderDwarf2::save(X64::RegisterX64 reg) { stackOffset += 8; pos = advanceLocation(pos, 2); // REX.W push reg @@ -188,7 +188,7 @@ void UnwindBuilderDwarf2::allocStack(int size) pos = defineCfaExpressionOffset(pos, stackOffset); } -void UnwindBuilderDwarf2::setupFrameReg(RegisterX64 reg, int espOffset) +void UnwindBuilderDwarf2::setupFrameReg(X64::RegisterX64 reg, int espOffset) { if (espOffset != 0) pos = advanceLocation(pos, 5); // REX.W lea rbp, [rsp + imm8] diff --git a/CodeGen/src/UnwindBuilderWin.cpp b/CodeGen/src/UnwindBuilderWin.cpp index 13e92ab0a..217330013 100644 --- a/CodeGen/src/UnwindBuilderWin.cpp +++ b/CodeGen/src/UnwindBuilderWin.cpp @@ -49,12 +49,12 @@ void UnwindBuilderWin::start() unwindCodes.reserve(16); } -void UnwindBuilderWin::spill(int espOffset, RegisterX64 reg) +void UnwindBuilderWin::spill(int espOffset, X64::RegisterX64 reg) { prologSize += 5; // REX.W mov [rsp + imm8], reg } -void UnwindBuilderWin::save(RegisterX64 reg) +void UnwindBuilderWin::save(X64::RegisterX64 reg) { prologSize += 2; // REX.W push reg stackOffset += 8; @@ -70,7 +70,7 @@ void UnwindBuilderWin::allocStack(int size) unwindCodes.push_back({prologSize, UWOP_ALLOC_SMALL, uint8_t((size - 8) / 8)}); } -void UnwindBuilderWin::setupFrameReg(RegisterX64 reg, int espOffset) +void UnwindBuilderWin::setupFrameReg(X64::RegisterX64 reg, int espOffset) { LUAU_ASSERT(espOffset < 256 && espOffset % 16 == 0); diff --git a/Compiler/src/Builtins.cpp b/Compiler/src/Builtins.cpp index ae9633708..073bb1c79 100644 --- a/Compiler/src/Builtins.cpp +++ b/Compiler/src/Builtins.cpp @@ -244,5 +244,158 @@ void analyzeBuiltins(DenseHashMap& result, const DenseHashMap root->visit(&visitor); } +BuiltinInfo getBuiltinInfo(int bfid) +{ + switch (LuauBuiltinFunction(bfid)) + { + case LBF_NONE: + return {-1, -1}; + + case LBF_ASSERT: + return {-1, -1}; + ; // assert() returns all values when first value is truthy + + case LBF_MATH_ABS: + case LBF_MATH_ACOS: + case LBF_MATH_ASIN: + return {1, 1}; + + case LBF_MATH_ATAN2: + return {2, 1}; + + case LBF_MATH_ATAN: + case LBF_MATH_CEIL: + case LBF_MATH_COSH: + case LBF_MATH_COS: + case LBF_MATH_DEG: + case LBF_MATH_EXP: + case LBF_MATH_FLOOR: + return {1, 1}; + + case LBF_MATH_FMOD: + return {2, 1}; + + case LBF_MATH_FREXP: + return {1, 2}; + + case LBF_MATH_LDEXP: + return {2, 1}; + + case LBF_MATH_LOG10: + return {1, 1}; + + case LBF_MATH_LOG: + return {-1, 1}; // 1 or 2 parameters + + case LBF_MATH_MAX: + case LBF_MATH_MIN: + return {-1, 1}; // variadic + + case LBF_MATH_MODF: + return {1, 2}; + + case LBF_MATH_POW: + return {2, 1}; + + case LBF_MATH_RAD: + case LBF_MATH_SINH: + case LBF_MATH_SIN: + case LBF_MATH_SQRT: + case LBF_MATH_TANH: + case LBF_MATH_TAN: + return {1, 1}; + + case LBF_BIT32_ARSHIFT: + return {2, 1}; + + case LBF_BIT32_BAND: + return {-1, 1}; // variadic + + case LBF_BIT32_BNOT: + return {1, 1}; + + case LBF_BIT32_BOR: + case LBF_BIT32_BXOR: + case LBF_BIT32_BTEST: + return {-1, 1}; // variadic + + case LBF_BIT32_EXTRACT: + return {-1, 1}; // 2 or 3 parameters + + case LBF_BIT32_LROTATE: + case LBF_BIT32_LSHIFT: + return {2, 1}; + + case LBF_BIT32_REPLACE: + return {-1, 1}; // 3 or 4 parameters + + case LBF_BIT32_RROTATE: + case LBF_BIT32_RSHIFT: + return {2, 1}; + + case LBF_TYPE: + return {1, 1}; + + case LBF_STRING_BYTE: + return {-1, -1}; // 1, 2 or 3 parameters + + case LBF_STRING_CHAR: + return {-1, 1}; // variadic + + case LBF_STRING_LEN: + return {1, 1}; + + case LBF_TYPEOF: + return {1, 1}; + + case LBF_STRING_SUB: + return {-1, 1}; // 2 or 3 parameters + + case LBF_MATH_CLAMP: + return {3, 1}; + + case LBF_MATH_SIGN: + case LBF_MATH_ROUND: + return {1, 1}; + + case LBF_RAWSET: + return {3, 1}; + + case LBF_RAWGET: + case LBF_RAWEQUAL: + return {2, 1}; + + case LBF_TABLE_INSERT: + return {-1, 0}; // 2 or 3 parameters + + case LBF_TABLE_UNPACK: + return {-1, -1}; // 1, 2 or 3 parameters + + case LBF_VECTOR: + return {-1, 1}; // 3 or 4 parameters in some configurations + + case LBF_BIT32_COUNTLZ: + case LBF_BIT32_COUNTRZ: + return {1, 1}; + + case LBF_SELECT_VARARG: + return {-1, -1}; // variadic + + case LBF_RAWLEN: + return {1, 1}; + + case LBF_BIT32_EXTRACTK: + return {3, 1}; + + case LBF_GETMETATABLE: + return {1, 1}; + + case LBF_SETMETATABLE: + return {2, 1}; + }; + + LUAU_UNREACHABLE(); +} + } // namespace Compile } // namespace Luau diff --git a/Compiler/src/Builtins.h b/Compiler/src/Builtins.h index 4399c5321..e179218aa 100644 --- a/Compiler/src/Builtins.h +++ b/Compiler/src/Builtins.h @@ -39,5 +39,13 @@ Builtin getBuiltin(AstExpr* node, const DenseHashMap& globals, void analyzeBuiltins(DenseHashMap& result, const DenseHashMap& globals, const DenseHashMap& variables, const CompileOptions& options, AstNode* root); +struct BuiltinInfo +{ + int params; + int results; +}; + +BuiltinInfo getBuiltinInfo(int bfid); + } // namespace Compile } // namespace Luau diff --git a/Compiler/src/BytecodeBuilder.cpp b/Compiler/src/BytecodeBuilder.cpp index 11bf24297..8e450f4f8 100644 --- a/Compiler/src/BytecodeBuilder.cpp +++ b/Compiler/src/BytecodeBuilder.cpp @@ -2038,7 +2038,10 @@ void BytecodeBuilder::dumpInstruction(const uint32_t* code, std::string& result, case LOP_CAPTURE: formatAppend(result, "CAPTURE %s %c%d\n", - LUAU_INSN_A(insn) == LCT_UPVAL ? "UPVAL" : LUAU_INSN_A(insn) == LCT_REF ? "REF" : LUAU_INSN_A(insn) == LCT_VAL ? "VAL" : "", + LUAU_INSN_A(insn) == LCT_UPVAL ? "UPVAL" + : LUAU_INSN_A(insn) == LCT_REF ? "REF" + : LUAU_INSN_A(insn) == LCT_VAL ? "VAL" + : "", LUAU_INSN_A(insn) == LCT_UPVAL ? 'U' : 'R', LUAU_INSN_B(insn)); break; diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index 8a017f488..78896d311 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -26,6 +26,7 @@ LUAU_FASTINTVARIABLE(LuauCompileInlineThresholdMaxBoost, 300) LUAU_FASTINTVARIABLE(LuauCompileInlineDepth, 5) LUAU_FASTFLAGVARIABLE(LuauCompileTerminateBC, false) +LUAU_FASTFLAGVARIABLE(LuauCompileBuiltinArity, false) namespace Luau { @@ -293,6 +294,12 @@ struct Compiler if (isConstant(expr)) return false; + // handles builtin calls that can't be constant-folded but are known to return one value + // note: optimizationLevel check is technically redundant but it's important that we never optimize based on builtins in O1 + if (FFlag::LuauCompileBuiltinArity && options.optimizationLevel >= 2) + if (int* bfid = builtins.find(expr)) + return getBuiltinInfo(*bfid).results != 1; + // handles local function calls where we know only one argument is returned AstExprFunction* func = getFunctionExpr(expr->func); Function* fi = func ? functions.find(func) : nullptr; @@ -506,6 +513,7 @@ struct Compiler // we can't inline multret functions because the caller expects L->top to be adjusted: // - inlined return compiles to a JUMP, and we don't have an instruction that adjusts L->top arbitrarily // - even if we did, right now all L->top adjustments are immediately consumed by the next instruction, and for now we want to preserve that + // - additionally, we can't easily compile multret expressions into designated target as computed call arguments will get clobbered if (multRet) { bytecode.addDebugRemark("inlining failed: can't convert fixed returns to multret"); @@ -755,8 +763,13 @@ struct Compiler } // Optimization: for 1/2 argument fast calls use specialized opcodes - if (bfid >= 0 && expr->args.size >= 1 && expr->args.size <= 2 && !isExprMultRet(expr->args.data[expr->args.size - 1])) - return compileExprFastcallN(expr, target, targetCount, targetTop, multRet, regs, bfid); + if (bfid >= 0 && expr->args.size >= 1 && expr->args.size <= 2) + { + if (!isExprMultRet(expr->args.data[expr->args.size - 1])) + return compileExprFastcallN(expr, target, targetCount, targetTop, multRet, regs, bfid); + else if (FFlag::LuauCompileBuiltinArity && options.optimizationLevel >= 2 && int(expr->args.size) == getBuiltinInfo(bfid).params) + return compileExprFastcallN(expr, target, targetCount, targetTop, multRet, regs, bfid); + } if (expr->self) { diff --git a/Sources.cmake b/Sources.cmake index 22197e0e8..88c6e9b63 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -108,6 +108,7 @@ target_sources(Luau.CodeGen PRIVATE CodeGen/src/CodeGenUtils.h CodeGen/src/CodeGenX64.h CodeGen/src/EmitBuiltinsX64.h + CodeGen/src/EmitCommon.h CodeGen/src/EmitCommonX64.h CodeGen/src/EmitInstructionX64.h CodeGen/src/Fallbacks.h @@ -126,6 +127,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/include/Luau/AstJsonEncoder.h Analysis/include/Luau/AstQuery.h Analysis/include/Luau/Autocomplete.h + Analysis/include/Luau/Breadcrumb.h Analysis/include/Luau/BuiltinDefinitions.h Analysis/include/Luau/Clone.h Analysis/include/Luau/Config.h diff --git a/VM/src/lstrlib.cpp b/VM/src/lstrlib.cpp index cf7381ae1..875a479a9 100644 --- a/VM/src/lstrlib.cpp +++ b/VM/src/lstrlib.cpp @@ -1445,7 +1445,7 @@ static int str_pack(lua_State* L) const char* s = luaL_checklstring(L, arg, &len); luaL_argcheck(L, len <= (size_t)size, arg, "string longer than given size"); luaL_addlstring(&b, s, len, -1); // add string - while (len++ < (size_t)size) // pad extra space + while (len++ < (size_t)size) // pad extra space luaL_addchar(&b, LUAL_PACKPADBYTE); break; } diff --git a/tests/AssemblyBuilderA64.test.cpp b/tests/AssemblyBuilderA64.test.cpp index d808ac491..e23b965bc 100644 --- a/tests/AssemblyBuilderA64.test.cpp +++ b/tests/AssemblyBuilderA64.test.cpp @@ -7,6 +7,7 @@ #include using namespace Luau::CodeGen; +using namespace Luau::CodeGen::A64; static std::string bytecodeAsArray(const std::vector& bytecode) { diff --git a/tests/AssemblyBuilderX64.test.cpp b/tests/AssemblyBuilderX64.test.cpp index c4d2a1c70..6aa7aa561 100644 --- a/tests/AssemblyBuilderX64.test.cpp +++ b/tests/AssemblyBuilderX64.test.cpp @@ -7,6 +7,7 @@ #include using namespace Luau::CodeGen; +using namespace Luau::CodeGen::X64; static std::string bytecodeAsArray(const std::vector& bytecode) { diff --git a/tests/AstJsonEncoder.test.cpp b/tests/AstJsonEncoder.test.cpp index 44e9e5e49..a0127eef7 100644 --- a/tests/AstJsonEncoder.test.cpp +++ b/tests/AstJsonEncoder.test.cpp @@ -445,7 +445,7 @@ TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_annotation") AstStat* statement = expectParseStatement("type T = ((number) -> (string | nil)) & ((string) -> ())"); std::string_view expected = - R"({"type":"AstStatTypeAlias","location":"0,0 - 0,55","name":"T","generics":[],"genericPacks":[],"type":{"type":"AstTypeIntersection","location":"0,9 - 0,55","types":[{"type":"AstTypeFunction","location":"0,10 - 0,35","generics":[],"genericPacks":[],"argTypes":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"0,11 - 0,17","name":"number","parameters":[]}]},"returnTypes":{"type":"AstTypeList","types":[{"type":"AstTypeUnion","location":"0,23 - 0,35","types":[{"type":"AstTypeReference","location":"0,23 - 0,29","name":"string","parameters":[]},{"type":"AstTypeReference","location":"0,32 - 0,35","name":"nil","parameters":[]}]}]}},{"type":"AstTypeFunction","location":"0,41 - 0,55","generics":[],"genericPacks":[],"argTypes":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"0,42 - 0,48","name":"string","parameters":[]}]},"returnTypes":{"type":"AstTypeList","types":[]}}]},"exported":false})"; + R"({"type":"AstStatTypeAlias","location":"0,0 - 0,55","name":"T","generics":[],"genericPacks":[],"type":{"type":"AstTypeIntersection","location":"0,9 - 0,55","types":[{"type":"AstTypeFunction","location":"0,10 - 0,36","generics":[],"genericPacks":[],"argTypes":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"0,11 - 0,17","name":"number","parameters":[]}]},"returnTypes":{"type":"AstTypeList","types":[{"type":"AstTypeUnion","location":"0,23 - 0,35","types":[{"type":"AstTypeReference","location":"0,23 - 0,29","name":"string","parameters":[]},{"type":"AstTypeReference","location":"0,32 - 0,35","name":"nil","parameters":[]}]}]}},{"type":"AstTypeFunction","location":"0,41 - 0,55","generics":[],"genericPacks":[],"argTypes":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"0,42 - 0,48","name":"string","parameters":[]}]},"returnTypes":{"type":"AstTypeList","types":[]}}]},"exported":false})"; CHECK(toJson(statement) == expected); } diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index d238e9eca..85bd55077 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -3198,6 +3198,20 @@ t.@1 } } +TEST_CASE_FIXTURE(ACFixture, "simple") +{ + check(R"( +local t = {} +function t:m() end +t:m() + )"); + + // auto ac = autocomplete('1'); + + // REQUIRE(ac.entryMap.count("m")); + // CHECK(!ac.entryMap["m"].wrongIndexType); +} + TEST_CASE_FIXTURE(ACFixture, "do_compatible_self_calls") { check(R"( @@ -3466,4 +3480,33 @@ TEST_CASE_FIXTURE(ACFixture, "string_contents_is_available_to_callback") CHECK(isCorrect); } +TEST_CASE_FIXTURE(ACFixture, "autocomplete_response_perf1" * doctest::timeout(0.5)) +{ + ScopedFastFlag luauAutocompleteSkipNormalization{"LuauAutocompleteSkipNormalization", true}; + + // Build a function type with a large overload set + const int parts = 100; + std::string source; + + for (int i = 0; i < parts; i++) + formatAppend(source, "type T%d = { f%d: number }\n", i, i); + + source += "type Instance = { new: (('s0', extra: Instance?) -> T0)"; + + for (int i = 1; i < parts; i++) + formatAppend(source, " & (('s%d', extra: Instance?) -> T%d)", i, i); + + source += " }\n"; + + source += "local Instance: Instance = {} :: any\n"; + source += "local function c(): boolean return t@1 end\n"; + + check(source); + + auto ac = autocomplete('1'); + + CHECK(ac.entryMap.count("true")); + CHECK(ac.entryMap.count("Instance")); +} + TEST_SUITE_END(); diff --git a/tests/CodeAllocator.test.cpp b/tests/CodeAllocator.test.cpp index 65b485a7f..a6ed96f02 100644 --- a/tests/CodeAllocator.test.cpp +++ b/tests/CodeAllocator.test.cpp @@ -131,6 +131,8 @@ TEST_CASE("CodeAllocationWithUnwindCallbacks") #if !defined(LUAU_BIG_ENDIAN) TEST_CASE("WindowsUnwindCodesX64") { + using namespace X64; + UnwindBuilderWin unwind; unwind.start(); @@ -162,6 +164,8 @@ TEST_CASE("WindowsUnwindCodesX64") TEST_CASE("Dwarf2UnwindCodesX64") { + using namespace X64; + UnwindBuilderDwarf2 unwind; unwind.start(); @@ -195,21 +199,23 @@ TEST_CASE("Dwarf2UnwindCodesX64") #if defined(_WIN32) // Windows x64 ABI -constexpr RegisterX64 rArg1 = rcx; -constexpr RegisterX64 rArg2 = rdx; -constexpr RegisterX64 rArg3 = r8; +constexpr X64::RegisterX64 rArg1 = X64::rcx; +constexpr X64::RegisterX64 rArg2 = X64::rdx; +constexpr X64::RegisterX64 rArg3 = X64::r8; #else // System V AMD64 ABI -constexpr RegisterX64 rArg1 = rdi; -constexpr RegisterX64 rArg2 = rsi; -constexpr RegisterX64 rArg3 = rdx; +constexpr X64::RegisterX64 rArg1 = X64::rdi; +constexpr X64::RegisterX64 rArg2 = X64::rsi; +constexpr X64::RegisterX64 rArg3 = X64::rdx; #endif -constexpr RegisterX64 rNonVol1 = r12; -constexpr RegisterX64 rNonVol2 = rbx; +constexpr X64::RegisterX64 rNonVol1 = X64::r12; +constexpr X64::RegisterX64 rNonVol2 = X64::rbx; TEST_CASE("GeneratedCodeExecutionX64") { + using namespace X64; + AssemblyBuilderX64 build(/* logText= */ false); build.mov(rax, rArg1); @@ -244,6 +250,8 @@ void throwing(int64_t arg) TEST_CASE("GeneratedCodeExecutionWithThrowX64") { + using namespace X64; + AssemblyBuilderX64 build(/* logText= */ false); #if defined(_WIN32) @@ -320,6 +328,8 @@ TEST_CASE("GeneratedCodeExecutionWithThrowX64") TEST_CASE("GeneratedCodeExecutionWithThrowOutsideTheGateX64") { + using namespace X64; + AssemblyBuilderX64 build(/* logText= */ false); #if defined(_WIN32) @@ -437,6 +447,8 @@ TEST_CASE("GeneratedCodeExecutionWithThrowOutsideTheGateX64") TEST_CASE("GeneratedCodeExecutionA64") { + using namespace A64; + AssemblyBuilderA64 build(/* logText= */ false); Label skip; diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index 87a782adb..135a555ab 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -4655,6 +4655,8 @@ RETURN R0 0 TEST_CASE("LoopUnrollCost") { + ScopedFastFlag sff("LuauCompileBuiltinArity", true); + ScopedFastInt sfis[] = { {"LuauCompileLoopUnrollThreshold", 25}, {"LuauCompileLoopUnrollThresholdMaxBoost", 300}, @@ -4796,10 +4798,10 @@ FORNPREP R1 L3 L0: FASTCALL1 24 R3 L1 MOVE R6 R3 GETIMPORT R5 2 [math.sin] -CALL R5 1 -1 -L1: FASTCALL 2 L2 +CALL R5 1 1 +L1: FASTCALL1 2 R5 L2 GETIMPORT R4 4 [math.abs] -CALL R4 -1 1 +CALL R4 1 1 L2: SETTABLE R4 R0 R3 FORNLOOP R1 L0 L3: RETURN R0 1 @@ -5924,6 +5926,8 @@ RETURN R2 1 TEST_CASE("InlineMultret") { + ScopedFastFlag sff("LuauCompileBuiltinArity", true); + // inlining a function in multret context is prohibited since we can't adjust L->top outside of CALL/GETVARARGS CHECK_EQ("\n" + compileFunction(R"( local function foo(a) @@ -5994,7 +5998,7 @@ CALL R1 1 -1 RETURN R1 -1 )"); - // and unfortunately we can't do this analysis for builtins or method calls due to getfenv + // we do this for builtins though as we assume getfenv is not used or is not changing arity CHECK_EQ("\n" + compileFunction(R"( local function foo(a) return math.abs(a) @@ -6005,10 +6009,8 @@ return foo(42) 1, 2), R"( DUPCLOSURE R0 K0 ['foo'] -MOVE R1 R0 -LOADN R2 42 -CALL R1 1 -1 -RETURN R1 -1 +LOADN R1 42 +RETURN R1 1 )"); } @@ -6263,6 +6265,8 @@ RETURN R0 52 TEST_CASE("BuiltinFoldingProhibited") { + ScopedFastFlag sff("LuauCompileBuiltinArity", true); + CHECK_EQ("\n" + compileFunction(R"( return math.abs(), @@ -6326,8 +6330,8 @@ L8: LOADN R10 1 FASTCALL2K 19 R10 K3 L9 [true] LOADK R11 K3 [true] GETIMPORT R9 26 [math.min] -CALL R9 2 -1 -L9: RETURN R0 -1 +CALL R9 2 1 +L9: RETURN R0 10 )"); } @@ -6865,4 +6869,111 @@ L3: RETURN R0 0 )"); } +TEST_CASE("BuiltinArity") +{ + ScopedFastFlag sff("LuauCompileBuiltinArity", true); + + // by default we can't assume that we know parameter/result count for builtins as they can be overridden at runtime + CHECK_EQ("\n" + compileFunction(R"( +return math.abs(unknown()) +)", + 0, 1), + R"( +GETIMPORT R1 1 [unknown] +CALL R1 0 -1 +FASTCALL 2 L0 +GETIMPORT R0 4 [math.abs] +CALL R0 -1 -1 +L0: RETURN R0 -1 +)"); + + // however, when using optimization level 2, we assume compile time knowledge about builtin behavior even if we can't deoptimize that with fenv + // in the test case below, this allows us to synthesize a more efficient FASTCALL1 (and use a fixed-return call to unknown) + CHECK_EQ("\n" + compileFunction(R"( +return math.abs(unknown()) +)", + 0, 2), + R"( +GETIMPORT R1 1 [unknown] +CALL R1 0 1 +FASTCALL1 2 R1 L0 +GETIMPORT R0 4 [math.abs] +CALL R0 1 1 +L0: RETURN R0 1 +)"); + + // some builtins are variadic, and as such they can't use fixed-length fastcall variants + CHECK_EQ("\n" + compileFunction(R"( +return math.max(0, unknown()) +)", + 0, 2), + R"( +LOADN R1 0 +GETIMPORT R2 1 [unknown] +CALL R2 0 -1 +FASTCALL 18 L0 +GETIMPORT R0 4 [math.max] +CALL R0 -1 1 +L0: RETURN R0 1 +)"); + + // some builtins are not variadic but don't have a fixed number of arguments; we currently don't optimize this although we might start to in the + // future + CHECK_EQ("\n" + compileFunction(R"( +return bit32.extract(0, 1, unknown()) +)", + 0, 2), + R"( +LOADN R1 0 +LOADN R2 1 +GETIMPORT R3 1 [unknown] +CALL R3 0 -1 +FASTCALL 34 L0 +GETIMPORT R0 4 [bit32.extract] +CALL R0 -1 1 +L0: RETURN R0 1 +)"); + + // importantly, this optimization also helps us get around the multret inlining restriction for builtin wrappers + CHECK_EQ("\n" + compileFunction(R"( +local function new() + return setmetatable({}, MT) +end + +return new() +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 ['new'] +NEWTABLE R2 0 0 +GETIMPORT R3 2 [MT] +FASTCALL2 61 R2 R3 L0 +GETIMPORT R1 4 [setmetatable] +CALL R1 2 1 +L0: RETURN R1 1 +)"); + + // note that the results of this optimization are benign in fixed-arg contexts which dampens the effect of fenv substitutions on correctness in + // practice + CHECK_EQ("\n" + compileFunction(R"( +local x = ... +local y, z = type(x) +return type(y, z) +)", + 0, 2), + R"( +GETVARARGS R0 1 +FASTCALL1 40 R0 L0 +MOVE R2 R0 +GETIMPORT R1 1 [type] +CALL R1 1 2 +L0: FASTCALL2 40 R1 R2 L1 +MOVE R4 R1 +MOVE R5 R2 +GETIMPORT R3 1 [type] +CALL R3 2 1 +L1: RETURN R3 1 +)"); +} + TEST_SUITE_END(); diff --git a/tests/DataFlowGraph.test.cpp b/tests/DataFlowGraph.test.cpp index d8230700a..bd5fe5628 100644 --- a/tests/DataFlowGraph.test.cpp +++ b/tests/DataFlowGraph.test.cpp @@ -10,7 +10,7 @@ using namespace Luau; -class DataFlowGraphFixture +struct DataFlowGraphFixture { // Only needed to fix the operator== reflexivity of an empty Symbol. ScopedFastFlag dcr{"DebugLuauDeferredConstraintResolution", true}; @@ -23,7 +23,6 @@ class DataFlowGraphFixture std::optional graph; -public: void dfg(const std::string& code) { ParseResult parseResult = Parser::parse(code.c_str(), code.size(), names, allocator); @@ -34,19 +33,19 @@ class DataFlowGraphFixture } template - std::optional getDef(const std::vector& nths = {nth(N)}) + NullableBreadcrumbId getBreadcrumb(const std::vector& nths = {nth(N)}) { T* node = query(module, nths); REQUIRE(node); - return graph->getDef(node); + return graph->getBreadcrumb(node); } template - DefId requireDef(const std::vector& nths = {nth(N)}) + BreadcrumbId requireBreadcrumb(const std::vector& nths = {nth(N)}) { - auto loc = getDef(nths); - REQUIRE(loc); - return NotNull{*loc}; + auto bc = getBreadcrumb(nths); + REQUIRE(bc); + return NotNull{bc}; } }; @@ -59,7 +58,7 @@ TEST_CASE_FIXTURE(DataFlowGraphFixture, "define_locals_in_local_stat") local y = x )"); - REQUIRE(getDef()); + REQUIRE(getBreadcrumb()); } TEST_CASE_FIXTURE(DataFlowGraphFixture, "define_parameters_in_functions") @@ -70,7 +69,7 @@ TEST_CASE_FIXTURE(DataFlowGraphFixture, "define_parameters_in_functions") end )"); - REQUIRE(getDef()); + REQUIRE(getBreadcrumb()); } TEST_CASE_FIXTURE(DataFlowGraphFixture, "find_aliases") @@ -81,9 +80,9 @@ TEST_CASE_FIXTURE(DataFlowGraphFixture, "find_aliases") local z = y )"); - DefId x = requireDef(); - DefId y = requireDef(); - REQUIRE(x != y); // TODO: they should be equal but it's not just locals that can alias, so we'll support this later. + BreadcrumbId x = requireBreadcrumb(); + BreadcrumbId y = requireBreadcrumb(); + REQUIRE(x != y); } TEST_CASE_FIXTURE(DataFlowGraphFixture, "independent_locals") @@ -96,8 +95,8 @@ TEST_CASE_FIXTURE(DataFlowGraphFixture, "independent_locals") local b = y )"); - DefId x = requireDef(); - DefId y = requireDef(); + BreadcrumbId x = requireBreadcrumb(); + BreadcrumbId y = requireBreadcrumb(); REQUIRE(x != y); } diff --git a/tests/Fixture.cpp b/tests/Fixture.cpp index f245ca933..cbceabbdc 100644 --- a/tests/Fixture.cpp +++ b/tests/Fixture.cpp @@ -178,17 +178,8 @@ AstStatBlock* Fixture::parse(const std::string& source, const ParseOptions& pars if (FFlag::DebugLuauDeferredConstraintResolution) { - Luau::check( - *sourceModule, - {}, - frontend.builtinTypes, - NotNull{&ice}, - NotNull{&moduleResolver}, - NotNull{&fileResolver}, - typeChecker.globalScope, - NotNull{&typeChecker.unifierState}, - frontend.options - ); + Luau::check(*sourceModule, {}, frontend.builtinTypes, NotNull{&ice}, NotNull{&moduleResolver}, NotNull{&fileResolver}, + typeChecker.globalScope, frontend.options); } else typeChecker.check(*sourceModule, sourceModule->mode.value_or(Luau::Mode::Nonstrict)); diff --git a/tests/IrBuilder.test.cpp b/tests/IrBuilder.test.cpp index 9da34367b..0896517f9 100644 --- a/tests/IrBuilder.test.cpp +++ b/tests/IrBuilder.test.cpp @@ -311,6 +311,8 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "Numeric") build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.inst(IrCmd::DIV_NUM, build.constDouble(2), build.constDouble(5))); build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.inst(IrCmd::MOD_NUM, build.constDouble(5), build.constDouble(2))); build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.inst(IrCmd::POW_NUM, build.constDouble(5), build.constDouble(2))); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.inst(IrCmd::MIN_NUM, build.constDouble(5), build.constDouble(2))); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.inst(IrCmd::MAX_NUM, build.constDouble(5), build.constDouble(2))); build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.inst(IrCmd::UNM_NUM, build.constDouble(5))); @@ -338,6 +340,8 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "Numeric") STORE_DOUBLE R0, 0.40000000000000002 STORE_DOUBLE R0, 1 STORE_DOUBLE R0, 25 + STORE_DOUBLE R0, 2 + STORE_DOUBLE R0, 5 STORE_DOUBLE R0, -5 STORE_INT R0, 1i STORE_INT R0, 0i @@ -809,7 +813,8 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "BuiltinFastcallsMayInvalidateMemory") build.inst(IrCmd::CHECK_NO_METATABLE, table, fallback); build.inst(IrCmd::CHECK_READONLY, table, fallback); - build.inst(IrCmd::LOP_FASTCALL1, build.constUint(0), build.vmReg(1), build.vmReg(2), fallback); + build.inst(IrCmd::INVOKE_FASTCALL, build.constUint(LBF_SETMETATABLE), build.vmReg(1), build.vmReg(2), build.vmReg(3), build.constInt(3), + build.constInt(1)); build.inst(IrCmd::CHECK_NO_METATABLE, table, fallback); build.inst(IrCmd::CHECK_READONLY, table, fallback); @@ -830,7 +835,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "BuiltinFastcallsMayInvalidateMemory") %1 = LOAD_POINTER R0 CHECK_NO_METATABLE %1, bb_fallback_1 CHECK_READONLY %1, bb_fallback_1 - LOP_FASTCALL1 0u, R1, R2, bb_fallback_1 + %4 = INVOKE_FASTCALL 61u, R1, R2, R3, 3i, 1i CHECK_NO_METATABLE %1, bb_fallback_1 CHECK_READONLY %1, bb_fallback_1 STORE_DOUBLE R1, 0.5 @@ -1195,3 +1200,182 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "DataDoesNotFlowThroughDirectJumpToNonUnique } TEST_SUITE_END(); + +TEST_SUITE_BEGIN("LinearExecutionFlowExtraction"); + +TEST_CASE_FIXTURE(IrBuilderFixture, "SimplePathExtraction") +{ + IrOp block1 = build.block(IrBlockKind::Internal); + IrOp fallback1 = build.block(IrBlockKind::Fallback); + IrOp block2 = build.block(IrBlockKind::Internal); + IrOp fallback2 = build.block(IrBlockKind::Fallback); + IrOp block3 = build.block(IrBlockKind::Internal); + IrOp block4 = build.block(IrBlockKind::Internal); + + build.beginBlock(block1); + + IrOp tag1 = build.inst(IrCmd::LOAD_TAG, build.vmReg(2)); + build.inst(IrCmd::CHECK_TAG, tag1, build.constTag(tnumber), fallback1); + build.inst(IrCmd::JUMP, block2); + + build.beginBlock(fallback1); + build.inst(IrCmd::DO_LEN, build.vmReg(1), build.vmReg(2)); + build.inst(IrCmd::JUMP, block2); + + build.beginBlock(block2); + IrOp tag2 = build.inst(IrCmd::LOAD_TAG, build.vmReg(2)); + build.inst(IrCmd::CHECK_TAG, tag2, build.constTag(tnumber), fallback2); + build.inst(IrCmd::JUMP, block3); + + build.beginBlock(fallback2); + build.inst(IrCmd::DO_LEN, build.vmReg(0), build.vmReg(2)); + build.inst(IrCmd::JUMP, block3); + + build.beginBlock(block3); + build.inst(IrCmd::JUMP, block4); + + build.beginBlock(block4); + build.inst(IrCmd::LOP_RETURN, build.constUint(0), build.vmReg(0), build.constInt(0)); + + updateUseCounts(build.function); + constPropInBlockChains(build); + + CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( +bb_0: + %0 = LOAD_TAG R2 + CHECK_TAG %0, tnumber, bb_fallback_1 + JUMP bb_linear_6 + +bb_fallback_1: + DO_LEN R1, R2 + JUMP bb_2 + +bb_2: + %5 = LOAD_TAG R2 + CHECK_TAG %5, tnumber, bb_fallback_3 + JUMP bb_4 + +bb_fallback_3: + DO_LEN R0, R2 + JUMP bb_4 + +bb_4: + JUMP bb_5 + +bb_5: + LOP_RETURN 0u, R0, 0i + +bb_linear_6: + LOP_RETURN 0u, R0, 0i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "NoPathExtractionForBlocksWithLiveOutValues") +{ + IrOp block1 = build.block(IrBlockKind::Internal); + IrOp fallback1 = build.block(IrBlockKind::Fallback); + IrOp block2 = build.block(IrBlockKind::Internal); + IrOp fallback2 = build.block(IrBlockKind::Fallback); + IrOp block3 = build.block(IrBlockKind::Internal); + IrOp block4a = build.block(IrBlockKind::Internal); + IrOp block4b = build.block(IrBlockKind::Internal); + + build.beginBlock(block1); + + IrOp tag1 = build.inst(IrCmd::LOAD_TAG, build.vmReg(2)); + build.inst(IrCmd::CHECK_TAG, tag1, build.constTag(tnumber), fallback1); + build.inst(IrCmd::JUMP, block2); + + build.beginBlock(fallback1); + build.inst(IrCmd::DO_LEN, build.vmReg(1), build.vmReg(2)); + build.inst(IrCmd::JUMP, block2); + + build.beginBlock(block2); + IrOp tag2 = build.inst(IrCmd::LOAD_TAG, build.vmReg(2)); + build.inst(IrCmd::CHECK_TAG, tag2, build.constTag(tnumber), fallback2); + build.inst(IrCmd::JUMP, block3); + + build.beginBlock(fallback2); + build.inst(IrCmd::DO_LEN, build.vmReg(0), build.vmReg(2)); + build.inst(IrCmd::JUMP, block3); + + build.beginBlock(block3); + IrOp tag3a = build.inst(IrCmd::LOAD_TAG, build.vmReg(3)); + build.inst(IrCmd::JUMP_EQ_TAG, tag3a, build.constTag(tnil), block4a, block4b); + + build.beginBlock(block4a); + build.inst(IrCmd::STORE_TAG, build.vmReg(0), tag3a); + build.inst(IrCmd::LOP_RETURN, build.constUint(0), build.vmReg(0), build.constInt(0)); + + build.beginBlock(block4b); + build.inst(IrCmd::STORE_TAG, build.vmReg(0), tag3a); + build.inst(IrCmd::LOP_RETURN, build.constUint(0), build.vmReg(0), build.constInt(0)); + + updateUseCounts(build.function); + constPropInBlockChains(build); + + CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( +bb_0: + %0 = LOAD_TAG R2 + CHECK_TAG %0, tnumber, bb_fallback_1 + JUMP bb_2 + +bb_fallback_1: + DO_LEN R1, R2 + JUMP bb_2 + +bb_2: + %5 = LOAD_TAG R2 + CHECK_TAG %5, tnumber, bb_fallback_3 + JUMP bb_4 + +bb_fallback_3: + DO_LEN R0, R2 + JUMP bb_4 + +bb_4: + %10 = LOAD_TAG R3 + JUMP_EQ_TAG %10, tnil, bb_5, bb_6 + +bb_5: + STORE_TAG R0, %10 + LOP_RETURN 0u, R0, 0i + +bb_6: + STORE_TAG R0, %10 + LOP_RETURN 0u, R0, 0i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "InfiniteLoopInPathAnalysis") +{ + IrOp block1 = build.block(IrBlockKind::Internal); + IrOp block2 = build.block(IrBlockKind::Internal); + + build.beginBlock(block1); + + build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tnumber)); + build.inst(IrCmd::JUMP, block2); + + build.beginBlock(block2); + build.inst(IrCmd::STORE_TAG, build.vmReg(1), build.constTag(tboolean)); + build.inst(IrCmd::JUMP, block2); + + updateUseCounts(build.function); + constPropInBlockChains(build); + + CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( +bb_0: + STORE_TAG R0, tnumber + JUMP bb_1 + +bb_1: + STORE_TAG R1, tboolean + JUMP bb_1 + +)"); +} + +TEST_SUITE_END(); diff --git a/tests/Lexer.test.cpp b/tests/Lexer.test.cpp index 784fadad2..7fcc1e542 100644 --- a/tests/Lexer.test.cpp +++ b/tests/Lexer.test.cpp @@ -155,6 +155,36 @@ TEST_CASE("string_interpolation_basic") CHECK_EQ(interpEnd.type, Lexeme::InterpStringEnd); } +TEST_CASE("string_interpolation_full") +{ + ScopedFastFlag sff("LuauFixInterpStringMid", true); + + const std::string testInput = R"(`foo {"bar"} {"baz"} end`)"; + Luau::Allocator alloc; + AstNameTable table(alloc); + Lexer lexer(testInput.c_str(), testInput.size(), table); + + Lexeme interpBegin = lexer.next(); + CHECK_EQ(interpBegin.type, Lexeme::InterpStringBegin); + CHECK_EQ(interpBegin.toString(), "`foo {"); + + Lexeme quote1 = lexer.next(); + CHECK_EQ(quote1.type, Lexeme::QuotedString); + CHECK_EQ(quote1.toString(), "\"bar\""); + + Lexeme interpMid = lexer.next(); + CHECK_EQ(interpMid.type, Lexeme::InterpStringMid); + CHECK_EQ(interpMid.toString(), "} {"); + + Lexeme quote2 = lexer.next(); + CHECK_EQ(quote2.type, Lexeme::QuotedString); + CHECK_EQ(quote2.toString(), "\"baz\""); + + Lexeme interpEnd = lexer.next(); + CHECK_EQ(interpEnd.type, Lexeme::InterpStringEnd); + CHECK_EQ(interpEnd.toString(), "} end`"); +} + TEST_CASE("string_interpolation_double_brace") { const std::string testInput = R"(`foo{{bad}}bar`)"; diff --git a/tests/Linter.test.cpp b/tests/Linter.test.cpp index afd5a4e43..c716982ee 100644 --- a/tests/Linter.test.cpp +++ b/tests/Linter.test.cpp @@ -37,7 +37,7 @@ TEST_CASE_FIXTURE(Fixture, "DeprecatedGlobal") // Normally this would be defined externally, so hack it in for testing addGlobalBinding(frontend, "Wait", Binding{typeChecker.anyType, {}, true, "wait", "@test/global/Wait"}); - LintResult result = lintTyped("Wait(5)"); + LintResult result = lint("Wait(5)"); REQUIRE(1 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Global 'Wait' is deprecated, use 'wait' instead"); @@ -49,7 +49,7 @@ TEST_CASE_FIXTURE(Fixture, "DeprecatedGlobalNoReplacement") const char* deprecationReplacementString = ""; addGlobalBinding(frontend, "Version", Binding{typeChecker.anyType, {}, true, deprecationReplacementString}); - LintResult result = lintTyped("Version()"); + LintResult result = lint("Version()"); REQUIRE(1 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Global 'Version' is deprecated"); @@ -1440,8 +1440,10 @@ TEST_CASE_FIXTURE(Fixture, "LintHygieneUAF") REQUIRE(12 == result.warnings.size()); } -TEST_CASE_FIXTURE(Fixture, "DeprecatedApi") +TEST_CASE_FIXTURE(BuiltinsFixture, "DeprecatedApiTyped") { + ScopedFastFlag sff("LuauImproveDeprecatedApiLint", true); + unfreeze(typeChecker.globalTypes); TypeId instanceType = typeChecker.globalTypes.addType(ClassType{"Instance", {}, std::nullopt, std::nullopt, {}, {}, "Test"}); persist(instanceType); @@ -1459,6 +1461,13 @@ TEST_CASE_FIXTURE(Fixture, "DeprecatedApi") addGlobalBinding(frontend, "Color3", Binding{colorType, {}}); + if (TableType* ttv = getMutable(getGlobalBinding(typeChecker, "table"))) + { + ttv->props["foreach"].deprecated = true; + ttv->props["getn"].deprecated = true; + ttv->props["getn"].deprecatedSuggestion = "#"; + } + freeze(typeChecker.globalTypes); LintResult result = lintTyped(R"( @@ -1467,14 +1476,43 @@ return function (i: Instance) print(i.Name) print(Color3.toHSV()) print(Color3.doesntexist, i.doesntexist) -- type error, but this verifies we correctly handle non-existent members + print(table.getn({})) + table.foreach({}, function() end) + print(table.nogetn()) -- verify that we correctly handle non-existent members return i.DataCost end )"); - REQUIRE(3 == result.warnings.size()); + REQUIRE(5 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Member 'Instance.Wait' is deprecated"); CHECK_EQ(result.warnings[1].text, "Member 'toHSV' is deprecated, use 'Color3:ToHSV' instead"); - CHECK_EQ(result.warnings[2].text, "Member 'Instance.DataCost' is deprecated"); + CHECK_EQ(result.warnings[2].text, "Member 'table.getn' is deprecated, use '#' instead"); + CHECK_EQ(result.warnings[3].text, "Member 'table.foreach' is deprecated"); + CHECK_EQ(result.warnings[4].text, "Member 'Instance.DataCost' is deprecated"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "DeprecatedApiUntyped") +{ + ScopedFastFlag sff("LuauImproveDeprecatedApiLint", true); + + if (TableType* ttv = getMutable(getGlobalBinding(typeChecker, "table"))) + { + ttv->props["foreach"].deprecated = true; + ttv->props["getn"].deprecated = true; + ttv->props["getn"].deprecatedSuggestion = "#"; + } + + LintResult result = lint(R"( +return function () + print(table.getn({})) + table.foreach({}, function() end) + print(table.nogetn()) -- verify that we correctly handle non-existent members +end +)"); + + REQUIRE(2 == result.warnings.size()); + CHECK_EQ(result.warnings[0].text, "Member 'table.getn' is deprecated, use '#' instead"); + CHECK_EQ(result.warnings[1].text, "Member 'table.foreach' is deprecated"); } TEST_CASE_FIXTURE(BuiltinsFixture, "TableOperations") diff --git a/tests/Normalize.test.cpp b/tests/Normalize.test.cpp index ca06046a4..c45932c6f 100644 --- a/tests/Normalize.test.cpp +++ b/tests/Normalize.test.cpp @@ -488,6 +488,20 @@ TEST_CASE_FIXTURE(NormalizeFixture, "negate_boolean_2") )"))); } +TEST_CASE_FIXTURE(NormalizeFixture, "double_negation") +{ + CHECK("number" == toString(normal(R"( + number & Not> + )"))); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "negate_any") +{ + CHECK("number" == toString(normal(R"( + number & Not + )"))); +} + TEST_CASE_FIXTURE(NormalizeFixture, "intersect_function_and_top_function") { CHECK("() -> ()" == toString(normal(R"( diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index c72cbcce6..9ff16d16b 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -458,6 +458,24 @@ TEST_CASE_FIXTURE(Fixture, "type_alias_should_work_when_name_is_also_local") REQUIRE(block->body.data[1]->is()); } +TEST_CASE_FIXTURE(Fixture, "type_alias_span_is_correct") +{ + AstStatBlock* block = parse(R"( + type Packed1 = (T...) -> (T...) + type Packed2 = (Packed1, T...) -> (Packed1, T...) + )"); + + REQUIRE(block != nullptr); + REQUIRE(2 == block->body.size); + AstStatTypeAlias* t1 = block->body.data[0]->as(); + REQUIRE(t1); + REQUIRE(Location{Position{1, 8}, Position{1, 45}} == t1->location); + + AstStatTypeAlias* t2 = block->body.data[1]->as(); + REQUIRE(t2); + REQUIRE(Location{Position{2, 8}, Position{2, 75}} == t2->location); +} + TEST_CASE_FIXTURE(Fixture, "parse_error_messages") { CHECK_EQ(getParseError(R"( @@ -1020,6 +1038,35 @@ TEST_CASE_FIXTURE(Fixture, "parse_interpolated_string_call_without_parens") } } +TEST_CASE_FIXTURE(Fixture, "parse_interpolated_string_without_expression") +{ + ScopedFastFlag sff("LuauFixInterpStringMid", true); + + try + { + parse(R"( + print(`{}`) + )"); + FAIL("Expected ParseErrors to be thrown"); + } + catch (const ParseErrors& e) + { + CHECK_EQ("Malformed interpolated string, expected expression inside '{}'", e.getErrors().front().getMessage()); + } + + try + { + parse(R"( + print(`{}{1}`) + )"); + FAIL("Expected ParseErrors to be thrown"); + } + catch (const ParseErrors& e) + { + CHECK_EQ("Malformed interpolated string, expected expression inside '{}'", e.getErrors().front().getMessage()); + } +} + TEST_CASE_FIXTURE(Fixture, "parse_nesting_based_end_detection") { try diff --git a/tests/TypeInfer.oop.test.cpp b/tests/TypeInfer.oop.test.cpp index 088b4d56e..cf27518a6 100644 --- a/tests/TypeInfer.oop.test.cpp +++ b/tests/TypeInfer.oop.test.cpp @@ -272,4 +272,22 @@ end LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(BuiltinsFixture, "set_prop_of_intersection_containing_metatable") +{ + CheckResult result = check(R"( + export type Set = typeof(setmetatable( + {} :: { + add: (self: Set, T) -> Set, + }, + {} + )) + + local Set = {} :: Set & {} + + function Set:add(t) + return self + end + )"); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index c389f325f..0aacb8aec 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -814,18 +814,4 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "table_insert_with_a_singleton_argument") } } -TEST_CASE_FIXTURE(BuiltinsFixture, "refine_unknown_to_table_then_clone_it") -{ - CheckResult result = check(R"( - local function f(x: unknown) - if typeof(x) == "table" then - local cloned: {} = table.clone(x) - end - end - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - // LUAU_REQUIRE_NO_ERRORS(result); -} - TEST_SUITE_END(); diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index ba0f975ee..570cf278e 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -535,13 +535,20 @@ TEST_CASE_FIXTURE(Fixture, "string_not_equal_to_string_or_nil") CHECK_EQ(toString(requireTypeAtPosition({6, 29})), "string"); // a ~= b CHECK_EQ(toString(requireTypeAtPosition({6, 32})), "string?"); // a ~= b - CHECK_EQ(toString(requireTypeAtPosition({8, 29})), "string"); // a == b - CHECK_EQ(toString(requireTypeAtPosition({8, 32})), "string?"); // a == b + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ(toString(requireTypeAtPosition({8, 29})), "string?"); // a == b + CHECK_EQ(toString(requireTypeAtPosition({8, 32})), "string?"); // a == b + } + else + { + CHECK_EQ(toString(requireTypeAtPosition({8, 29})), "string"); // a == b + CHECK_EQ(toString(requireTypeAtPosition({8, 32})), "string?"); // a == b + } } TEST_CASE_FIXTURE(Fixture, "narrow_property_of_a_bounded_variable") { - CheckResult result = check(R"( local t local u: {x: number?} = {x = nil} @@ -804,7 +811,10 @@ TEST_CASE_FIXTURE(Fixture, "not_t_or_some_prop_of_t") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("{| x: boolean |}?", toString(requireTypeAtPosition({3, 28}))); + if (FFlag::DebugLuauDeferredConstraintResolution) + CHECK_EQ("{| x: true |}?", toString(requireTypeAtPosition({3, 28}))); + else + CHECK_EQ("{| x: boolean |}?", toString(requireTypeAtPosition({3, 28}))); } TEST_CASE_FIXTURE(BuiltinsFixture, "assert_a_to_be_truthy_then_assert_a_to_be_number") @@ -1523,7 +1533,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "refine_unknown_to_table_then_take_the_length if (FFlag::DebugLuauDeferredConstraintResolution) { LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("a & table", toString(requireTypeAtPosition({3, 29}))); + CHECK_EQ("table", toString(requireTypeAtPosition({3, 29}))); } else { @@ -1532,6 +1542,26 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "refine_unknown_to_table_then_take_the_length } } +TEST_CASE_FIXTURE(BuiltinsFixture, "refine_unknown_to_table_then_clone_it") +{ + CheckResult result = check(R"( + local function f(x: unknown) + if typeof(x) == "table" then + local cloned: {} = table.clone(x) + end + end + )"); + + if (FFlag::DebugLuauDeferredConstraintResolution) + { + LUAU_REQUIRE_NO_ERRORS(result); + } + else + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + } +} + TEST_CASE_FIXTURE(RefinementClassFixture, "refine_a_param_that_got_resolved_during_constraint_solving_stage") { CheckResult result = check(R"( @@ -1573,4 +1603,150 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "refine_a_param_that_got_resolved_duri CHECK_EQ("Instance", toString(requireTypeAtPosition({7, 28}))); } +TEST_CASE_FIXTURE(Fixture, "refine_a_property_of_some_global") +{ + ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true}; + + CheckResult result = check(R"( + foo = { bar = 5 :: number? } + + if foo.bar then + local bar = foo.bar + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(3, result); + CHECK_EQ("*error-type*", toString(requireTypeAtPosition({4, 30}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "dataflow_analysis_can_tell_refinements_when_its_appropriate_to_refine_into_nil_or_never") +{ + CheckResult result = check(R"( + local function f(t: {string}, s: string) + local v1 = t[5] + local v2 = v1 + + if typeof(v1) == "nil" then + local foo = v1 + else + local foo = v1 + end + + if typeof(v2) == "nil" then + local foo = v2 + else + local foo = v2 + end + + if typeof(s) == "nil" then + local foo = s + else + local foo = s + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("nil", toString(requireTypeAtPosition({6, 28}))); + CHECK_EQ("string", toString(requireTypeAtPosition({8, 28}))); + + CHECK_EQ("nil", toString(requireTypeAtPosition({12, 28}))); + CHECK_EQ("string", toString(requireTypeAtPosition({14, 28}))); + + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("never", toString(requireTypeAtPosition({18, 28}))); + CHECK_EQ("string", toString(requireTypeAtPosition({20, 28}))); + } + else + { + CHECK_EQ("nil", toString(requireTypeAtPosition({18, 28}))); + CHECK_EQ("string", toString(requireTypeAtPosition({20, 28}))); + } +} + +TEST_CASE_FIXTURE(Fixture, "cat_or_dog_through_a_local") +{ + CheckResult result = check(R"( + type Cat = { tag: "cat", catfood: string } + type Dog = { tag: "dog", dogfood: string } + type Animal = Cat | Dog + + local function f(animal: Animal) + local tag = animal.tag + if tag == "dog" then + local dog = animal + elseif tag == "cat" then + local cat = animal + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("Cat | Dog", toString(requireTypeAtPosition({8, 28}))); + CHECK_EQ("Cat | Dog", toString(requireTypeAtPosition({10, 28}))); +} + +TEST_CASE_FIXTURE(Fixture, "prove_that_dataflow_analysis_isnt_doing_alias_tracking_yet") +{ + CheckResult result = check(R"( + local function f(tag: "cat" | "dog") + local tag2 = tag + + if tag2 == "cat" then + local foo = tag + else + local foo = tag + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(R"("cat" | "dog")", toString(requireTypeAtPosition({5, 28}))); + CHECK_EQ(R"("cat" | "dog")", toString(requireTypeAtPosition({7, 28}))); +} + +TEST_CASE_FIXTURE(Fixture, "fail_to_refine_a_property_of_subscript_expression") +{ + CheckResult result = check(R"( + type Foo = { foo: number? } + local function f(t: {Foo}) + if t[1].foo then + local foo = t[1].foo + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("number?", toString(requireTypeAtPosition({4, 34}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "type_annotations_arent_relevant_when_doing_dataflow_analysis") +{ + CheckResult result = check(R"( + local function s() return "hello" end + + local function f(t: {string}) + local s1: string = t[5] + local s2: string = s() + + if typeof(s1) == "nil" and typeof(s2) == "nil" then + local foo = s1 + local bar = s2 + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("nil", toString(requireTypeAtPosition({8, 28}))); + if (FFlag::DebugLuauDeferredConstraintResolution) + CHECK_EQ("never", toString(requireTypeAtPosition({9, 28}))); + else + CHECK_EQ("nil", toString(requireTypeAtPosition({9, 28}))); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index 27b43aa9e..2a87f0e3b 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -519,10 +519,16 @@ TEST_CASE_FIXTURE(Fixture, "okay_to_add_property_to_unsealed_tables_by_function_ local x = get(t) )"); - // Currently this errors but it shouldn't, since set only needs write access - // TODO: file a JIRA for this - LUAU_REQUIRE_ERRORS(result); - // CHECK_EQ("number?", toString(requireType("x"))); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("number", toString(requireType("x"))); + } + else + { + LUAU_REQUIRE_ERRORS(result); + // CHECK_EQ("number?", toString(requireType("x"))); + } } TEST_CASE_FIXTURE(Fixture, "width_subtyping") @@ -2646,7 +2652,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "dont_quantify_table_that_belongs_to_outer_sc const MetatableType* newRet = get(follow(*newRetType)); REQUIRE(newRet); - const TableType* newRetMeta = get(newRet->metatable); + const TableType* newRetMeta = get(follow(newRet->metatable)); REQUIRE(newRetMeta); CHECK(newRetMeta->props.count("incr")); @@ -3601,4 +3607,42 @@ TEST_CASE_FIXTURE(Fixture, "dont_extend_unsealed_tables_in_rvalue_position") LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(BuiltinsFixture, "extend_unsealed_table_with_metatable") +{ + CheckResult result = check(R"( + local T = setmetatable({}, { + __call = function(_, name: string?) + end, + }) + + T.for_ = "for_" + + return T + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "top_table_type_is_isomorphic_to_empty_sealed_table_type") +{ + CheckResult result = check(R"( + local None = newproxy(true) + local mt = getmetatable(None) + mt.__tostring = function() + return "Object.None" + end + + function assign(...) + for index = 1, select("#", ...) do + local rest = select(index, ...) + + if rest ~= nil and typeof(rest) == "table" then + for key, value in pairs(rest) do + end + end + end + end + )"); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index 16797ee4d..3865e83a8 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -439,7 +439,10 @@ end _ += not _ do end )"); +} +TEST_CASE_FIXTURE(Fixture, "cyclic_follow_2") +{ check(R"( --!nonstrict n13,_,table,_,l0,_,_ = ... diff --git a/tests/conformance/math.lua b/tests/conformance/math.lua index 972c399b2..e2f68e654 100644 --- a/tests/conformance/math.lua +++ b/tests/conformance/math.lua @@ -261,6 +261,11 @@ assert(math.sign(inf) == 1) assert(math.sign(-inf) == -1) assert(math.sign(nan) == 0) +assert(math.min(nan, 2) ~= math.min(nan, 2)) +assert(math.min(1, nan) == 1) +assert(math.max(nan, 2) ~= math.max(nan, 2)) +assert(math.max(1, nan) == 1) + -- clamp assert(math.clamp(-1, 0, 1) == 0) assert(math.clamp(0.5, 0, 1) == 0.5) diff --git a/tools/faillist.txt b/tools/faillist.txt index 5c84d1687..c68312985 100644 --- a/tools/faillist.txt +++ b/tools/faillist.txt @@ -1,13 +1,9 @@ -AnnotationTests.instantiate_type_fun_should_not_trip_rbxassert AnnotationTests.too_many_type_params -AnnotationTests.two_type_params AstQuery.last_argument_function_call_type AstQuery::getDocumentationSymbolAtPosition.overloaded_class_method AstQuery::getDocumentationSymbolAtPosition.overloaded_fn AstQuery::getDocumentationSymbolAtPosition.table_overloaded_function_prop -AutocompleteTest.autocomplete_oop_implicit_self -AutocompleteTest.type_correct_expected_return_type_suggestion -AutocompleteTest.type_correct_suggestion_for_overloads +AutocompleteTest.autocomplete_response_perf1 BuiltinTests.aliased_string_format BuiltinTests.assert_removes_falsy_types BuiltinTests.assert_removes_falsy_types2 @@ -43,7 +39,6 @@ GenericsTests.bound_tables_do_not_clone_original_fields GenericsTests.check_mutual_generic_functions GenericsTests.correctly_instantiate_polymorphic_member_functions GenericsTests.do_not_infer_generic_functions -GenericsTests.dont_unify_bound_types GenericsTests.generic_argument_count_too_few GenericsTests.generic_argument_count_too_many GenericsTests.generic_functions_should_be_memory_safe @@ -56,7 +51,6 @@ GenericsTests.infer_generic_lib_function_function_argument GenericsTests.instantiated_function_argument_names GenericsTests.no_stack_overflow_from_quantifying GenericsTests.self_recursive_instantiated_param -IntersectionTypes.overload_is_not_a_function IntersectionTypes.table_intersection_write_sealed IntersectionTypes.table_intersection_write_sealed_indirect IntersectionTypes.table_write_sealed_indirect @@ -72,7 +66,6 @@ NonstrictModeTests.local_tables_are_not_any NonstrictModeTests.locals_are_any_by_default NonstrictModeTests.offer_a_hint_if_you_use_a_dot_instead_of_a_colon NonstrictModeTests.parameters_having_type_any_are_optional -NonstrictModeTests.table_dot_insert_and_recursive_calls NonstrictModeTests.table_props_are_any ProvisionalTests.assign_table_with_refined_property_with_a_similar_type_is_illegal ProvisionalTests.bail_early_if_unification_is_too_complicated @@ -85,20 +78,11 @@ ProvisionalTests.setmetatable_constrains_free_type_into_free_table ProvisionalTests.specialization_binds_with_prototypes_too_early ProvisionalTests.table_insert_with_a_singleton_argument ProvisionalTests.typeguard_inference_incomplete -RefinementTest.apply_refinements_on_astexprindexexpr_whose_subscript_expr_is_constant_string -RefinementTest.discriminate_from_isa_of_x -RefinementTest.falsiness_of_TruthyPredicate_narrows_into_nil -RefinementTest.narrow_property_of_a_bounded_variable -RefinementTest.nonoptional_type_can_narrow_to_nil_if_sense_is_true -RefinementTest.refine_a_param_that_got_resolved_during_constraint_solving_stage -RefinementTest.refine_param_of_type_folder_or_part_without_using_typeof -RefinementTest.refine_unknowns RefinementTest.type_guard_can_filter_for_intersection_of_tables RefinementTest.type_narrow_to_vector RefinementTest.typeguard_cast_free_table_to_vector RefinementTest.typeguard_in_assert_position RefinementTest.x_as_any_if_x_is_instance_elseif_x_is_table -RefinementTest.x_is_not_instance_or_else_not_part RuntimeLimits.typescript_port_of_Result_type TableTests.a_free_shape_cannot_turn_into_a_scalar_if_it_is_not_compatible TableTests.accidentally_checked_prop_in_opposite_branch @@ -109,7 +93,6 @@ TableTests.checked_prop_too_early TableTests.disallow_indexing_into_an_unsealed_table_with_no_indexer_in_strict_mode TableTests.dont_crash_when_setmetatable_does_not_produce_a_metatabletypevar TableTests.dont_hang_when_trying_to_look_up_in_cyclic_metatable_index -TableTests.dont_quantify_table_that_belongs_to_outer_scope TableTests.dont_suggest_exact_match_keys TableTests.error_detailed_metatable_prop TableTests.expected_indexer_from_table_union @@ -134,7 +117,6 @@ TableTests.less_exponential_blowup_please TableTests.missing_metatable_for_sealed_tables_do_not_get_inferred TableTests.mixed_tables_with_implicit_numbered_keys TableTests.nil_assign_doesnt_hit_indexer -TableTests.nil_assign_doesnt_hit_no_indexer TableTests.ok_to_set_nil_even_on_non_lvalue_base_expr TableTests.only_ascribe_synthetic_names_at_module_scope TableTests.oop_polymorphic @@ -153,7 +135,6 @@ TableTests.table_subtyping_with_missing_props_dont_report_multiple_errors TableTests.table_unification_4 TableTests.used_colon_instead_of_dot TableTests.used_dot_instead_of_colon -TableTests.when_augmenting_an_unsealed_table_with_an_indexer_apply_the_correct_scope_to_the_indexer_type ToString.named_metatable_toStringNamedFunction ToString.toStringDetailed2 ToString.toStringErrorPack @@ -187,11 +168,9 @@ TypeInfer.no_stack_overflow_from_isoptional TypeInfer.no_stack_overflow_from_isoptional2 TypeInfer.tc_after_error_recovery_no_replacement_name_in_error TypeInfer.type_infer_recursion_limit_no_ice -TypeInfer.type_infer_recursion_limit_normalizer TypeInferAnyError.for_in_loop_iterator_is_any2 TypeInferClasses.class_type_mismatch_with_name_conflict TypeInferClasses.classes_without_overloaded_operators_cannot_be_added -TypeInferClasses.higher_order_function_arguments_are_contravariant TypeInferClasses.index_instance_property TypeInferClasses.optional_class_field_access_error TypeInferClasses.table_class_unification_reports_sane_errors_for_missing_properties @@ -239,7 +218,6 @@ TypeInferModules.module_type_conflict_instantiated TypeInferModules.type_error_of_unknown_qualified_type TypeInferOOP.inferring_hundreds_of_self_calls_should_not_suffocate_memory TypeInferOOP.methods_are_topologically_sorted -TypeInferOOP.object_constructor_can_refer_to_method_of_self TypeInferOperators.CallAndOrOfFunctions TypeInferOperators.CallOrOfFunctions TypeInferOperators.cannot_indirectly_compare_types_that_do_not_have_a_metatable @@ -265,25 +243,15 @@ TypeInferUnknownNever.math_operators_and_never TypePackTests.detect_cyclic_typepacks2 TypePackTests.pack_tail_unification_check TypePackTests.type_alias_backwards_compatible -TypePackTests.type_alias_default_mixed_self TypePackTests.type_alias_default_type_errors -TypePackTests.type_alias_default_type_pack_self_chained_tp -TypePackTests.type_alias_default_type_pack_self_tp -TypePackTests.type_alias_defaults_confusing_types -TypePackTests.type_alias_type_pack_multi -TypePackTests.type_alias_type_pack_variadic TypePackTests.type_alias_type_packs_errors -TypePackTests.type_alias_type_packs_nested TypePackTests.unify_variadic_tails_in_arguments TypePackTests.variadic_packs TypeSingletons.function_call_with_singletons TypeSingletons.function_call_with_singletons_mismatch TypeSingletons.indexing_on_union_of_string_singletons TypeSingletons.no_widening_from_callsites -TypeSingletons.overloaded_function_call_with_singletons -TypeSingletons.overloaded_function_call_with_singletons_mismatch TypeSingletons.return_type_of_f_is_not_widened -TypeSingletons.table_properties_singleton_strings_mismatch TypeSingletons.table_properties_type_error_escapes TypeSingletons.taking_the_length_of_union_of_string_singleton TypeSingletons.widen_the_supertype_if_it_is_free_and_subtype_has_singleton From 4653484913613fd4d84cab1595256a42c5e6bb75 Mon Sep 17 00:00:00 2001 From: Andy Friesen Date: Fri, 10 Mar 2023 11:20:04 -0800 Subject: [PATCH 40/66] Sync to upstream/release/567 --- Analysis/include/Luau/BuiltinDefinitions.h | 27 +- Analysis/include/Luau/Error.h | 4 +- Analysis/include/Luau/Frontend.h | 24 +- Analysis/include/Luau/Quantify.h | 2 +- Analysis/include/Luau/Type.h | 8 +- Analysis/include/Luau/TypeInfer.h | 18 +- Analysis/include/Luau/Unifier.h | 5 + Analysis/src/AstQuery.cpp | 14 +- Analysis/src/Autocomplete.cpp | 49 +- Analysis/src/BuiltinDefinitions.cpp | 197 ++++---- Analysis/src/ConstraintGraphBuilder.cpp | 4 +- Analysis/src/ConstraintSolver.cpp | 34 +- Analysis/src/Error.cpp | 4 +- Analysis/src/Frontend.cpp | 146 ++++-- Analysis/src/Quantify.cpp | 5 +- Analysis/src/TxnLog.cpp | 1 + Analysis/src/Type.cpp | 8 +- Analysis/src/TypeChecker2.cpp | 73 ++- Analysis/src/TypeInfer.cpp | 51 +- Analysis/src/TypeReduction.cpp | 44 +- Analysis/src/Unifier.cpp | 158 ++++--- CLI/Analyze.cpp | 5 +- CodeGen/include/Luau/IrAnalysis.h | 56 +++ CodeGen/include/Luau/IrData.h | 12 +- CodeGen/include/Luau/IrDump.h | 9 +- CodeGen/include/Luau/IrUtils.h | 3 - CodeGen/src/CodeGen.cpp | 23 +- CodeGen/src/IrAnalysis.cpp | 517 +++++++++++++++++++++ CodeGen/src/IrBuilder.cpp | 2 +- CodeGen/src/IrDump.cpp | 163 +++++-- CodeGen/src/IrLoweringX64.cpp | 30 +- CodeGen/src/IrRegAllocX64.cpp | 22 +- CodeGen/src/IrRegAllocX64.h | 2 + CodeGen/src/IrUtils.cpp | 52 ++- Common/include/Luau/ExperimentalFlags.h | 1 - VM/include/luaconf.h | 12 + VM/src/lbuiltins.cpp | 105 +++++ VM/src/ldebug.cpp | 17 +- VM/src/lmathlib.cpp | 11 +- VM/src/ltablib.cpp | 154 +++++- VM/src/lvmexecute.cpp | 9 +- VM/src/lvmutils.cpp | 20 +- fuzz/proto.cpp | 34 +- tests/AstQuery.test.cpp | 13 + tests/Autocomplete.test.cpp | 128 ++--- tests/BuiltinDefinitions.test.cpp | 4 +- tests/ClassFixture.cpp | 41 +- tests/Conformance.test.cpp | 9 +- tests/ConstraintGraphBuilderFixture.cpp | 2 +- tests/Fixture.cpp | 61 +-- tests/Fixture.h | 1 - tests/Frontend.test.cpp | 12 +- tests/IrBuilder.test.cpp | 410 +++++++++++++--- tests/Linter.test.cpp | 49 +- tests/Module.test.cpp | 33 +- tests/NonstrictMode.test.cpp | 10 +- tests/Normalize.test.cpp | 6 +- tests/ToDot.test.cpp | 14 +- tests/ToString.test.cpp | 32 +- tests/TypeInfer.aliases.test.cpp | 33 +- tests/TypeInfer.annotations.test.cpp | 19 +- tests/TypeInfer.anyerror.test.cpp | 4 +- tests/TypeInfer.builtins.test.cpp | 32 +- tests/TypeInfer.definitions.test.cpp | 76 +-- tests/TypeInfer.functions.test.cpp | 55 ++- tests/TypeInfer.generics.test.cpp | 4 +- tests/TypeInfer.intersectionTypes.test.cpp | 14 +- tests/TypeInfer.loops.test.cpp | 26 +- tests/TypeInfer.modules.test.cpp | 39 ++ tests/TypeInfer.oop.test.cpp | 19 + tests/TypeInfer.operators.test.cpp | 44 +- tests/TypeInfer.primitives.test.cpp | 4 +- tests/TypeInfer.refinements.test.cpp | 39 +- tests/TypeInfer.tables.test.cpp | 78 ++-- tests/TypeInfer.test.cpp | 41 +- tests/TypeInfer.tryUnify.test.cpp | 43 +- tests/TypeInfer.typePacks.cpp | 32 +- tests/TypeInfer.unionTypes.test.cpp | 10 +- tests/TypeReduction.test.cpp | 18 + tests/TypeVar.test.cpp | 66 +-- tests/conformance/basic.lua | 2 + tests/conformance/math.lua | 2 + tests/conformance/sort.lua | 45 +- tools/faillist.txt | 13 +- 84 files changed, 2638 insertions(+), 1080 deletions(-) diff --git a/Analysis/include/Luau/BuiltinDefinitions.h b/Analysis/include/Luau/BuiltinDefinitions.h index 0604b40e2..162139581 100644 --- a/Analysis/include/Luau/BuiltinDefinitions.h +++ b/Analysis/include/Luau/BuiltinDefinitions.h @@ -10,12 +10,13 @@ namespace Luau { struct Frontend; +struct GlobalTypes; struct TypeChecker; struct TypeArena; -void registerBuiltinTypes(Frontend& frontend); +void registerBuiltinTypes(GlobalTypes& globals); -void registerBuiltinGlobals(TypeChecker& typeChecker); +void registerBuiltinGlobals(TypeChecker& typeChecker, GlobalTypes& globals); void registerBuiltinGlobals(Frontend& frontend); TypeId makeUnion(TypeArena& arena, std::vector&& types); @@ -23,8 +24,7 @@ TypeId makeIntersection(TypeArena& arena, std::vector&& types); /** Build an optional 't' */ -TypeId makeOption(TypeChecker& typeChecker, TypeArena& arena, TypeId t); -TypeId makeOption(Frontend& frontend, TypeArena& arena, TypeId t); +TypeId makeOption(NotNull builtinTypes, TypeArena& arena, TypeId t); /** Small utility function for building up type definitions from C++. */ @@ -52,17 +52,12 @@ void assignPropDocumentationSymbols(TableType::Props& props, const std::string& std::string getBuiltinDefinitionSource(); -void addGlobalBinding(TypeChecker& typeChecker, const std::string& name, Binding binding); -void addGlobalBinding(TypeChecker& typeChecker, const std::string& name, TypeId ty, const std::string& packageName); -void addGlobalBinding(TypeChecker& typeChecker, const ScopePtr& scope, const std::string& name, TypeId ty, const std::string& packageName); -void addGlobalBinding(TypeChecker& typeChecker, const ScopePtr& scope, const std::string& name, Binding binding); -void addGlobalBinding(Frontend& frontend, const std::string& name, TypeId ty, const std::string& packageName); -void addGlobalBinding(Frontend& frontend, const std::string& name, Binding binding); -void addGlobalBinding(Frontend& frontend, const ScopePtr& scope, const std::string& name, TypeId ty, const std::string& packageName); -void addGlobalBinding(Frontend& frontend, const ScopePtr& scope, const std::string& name, Binding binding); -std::optional tryGetGlobalBinding(Frontend& frontend, const std::string& name); -Binding* tryGetGlobalBindingRef(TypeChecker& typeChecker, const std::string& name); -TypeId getGlobalBinding(Frontend& frontend, const std::string& name); -TypeId getGlobalBinding(TypeChecker& typeChecker, const std::string& name); +void addGlobalBinding(GlobalTypes& globals, const std::string& name, TypeId ty, const std::string& packageName); +void addGlobalBinding(GlobalTypes& globals, const std::string& name, Binding binding); +void addGlobalBinding(GlobalTypes& globals, const ScopePtr& scope, const std::string& name, TypeId ty, const std::string& packageName); +void addGlobalBinding(GlobalTypes& globals, const ScopePtr& scope, const std::string& name, Binding binding); +std::optional tryGetGlobalBinding(GlobalTypes& globals, const std::string& name); +Binding* tryGetGlobalBindingRef(GlobalTypes& globals, const std::string& name); +TypeId getGlobalBinding(GlobalTypes& globals, const std::string& name); } // namespace Luau diff --git a/Analysis/include/Luau/Error.h b/Analysis/include/Luau/Error.h index 69d4cca3c..8571430bf 100644 --- a/Analysis/include/Luau/Error.h +++ b/Analysis/include/Luau/Error.h @@ -411,8 +411,8 @@ struct InternalErrorReporter std::function onInternalError; std::string moduleName; - [[noreturn]] void ice(const std::string& message, const Location& location); - [[noreturn]] void ice(const std::string& message); + [[noreturn]] void ice(const std::string& message, const Location& location) const; + [[noreturn]] void ice(const std::string& message) const; }; class InternalCompilerError : public std::exception diff --git a/Analysis/include/Luau/Frontend.h b/Analysis/include/Luau/Frontend.h index 7c5dc4a0d..9c0366a6c 100644 --- a/Analysis/include/Luau/Frontend.h +++ b/Analysis/include/Luau/Frontend.h @@ -21,6 +21,7 @@ class ParseError; struct Frontend; struct TypeError; struct LintWarning; +struct GlobalTypes; struct TypeChecker; struct FileResolver; struct ModuleResolver; @@ -31,11 +32,12 @@ struct LoadDefinitionFileResult { bool success; ParseResult parseResult; + SourceModule sourceModule; ModulePtr module; }; -LoadDefinitionFileResult loadDefinitionFile( - TypeChecker& typeChecker, ScopePtr targetScope, std::string_view definition, const std::string& packageName); +LoadDefinitionFileResult loadDefinitionFile(TypeChecker& typeChecker, GlobalTypes& globals, ScopePtr targetScope, std::string_view definition, + const std::string& packageName, bool captureComments); std::optional parseMode(const std::vector& hotcomments); @@ -152,14 +154,12 @@ struct Frontend void clear(); ScopePtr addEnvironment(const std::string& environmentName); - ScopePtr getEnvironmentScope(const std::string& environmentName); + ScopePtr getEnvironmentScope(const std::string& environmentName) const; - void registerBuiltinDefinition(const std::string& name, std::function); + void registerBuiltinDefinition(const std::string& name, std::function); void applyBuiltinDefinitionToEnvironment(const std::string& environmentName, const std::string& definitionName); - LoadDefinitionFileResult loadDefinitionFile(std::string_view source, const std::string& packageName); - - ScopePtr getGlobalScope(); + LoadDefinitionFileResult loadDefinitionFile(std::string_view source, const std::string& packageName, bool captureComments); private: ModulePtr check(const SourceModule& sourceModule, Mode mode, std::vector requireCycles, bool forAutocomplete = false, bool recordJsonLog = false); @@ -171,10 +171,10 @@ struct Frontend static LintResult classifyLints(const std::vector& warnings, const Config& config); - ScopePtr getModuleEnvironment(const SourceModule& module, const Config& config, bool forAutocomplete); + ScopePtr getModuleEnvironment(const SourceModule& module, const Config& config, bool forAutocomplete) const; std::unordered_map environments; - std::unordered_map> builtinDefinitions; + std::unordered_map> builtinDefinitions; BuiltinTypes builtinTypes_; @@ -184,21 +184,19 @@ struct Frontend FileResolver* fileResolver; FrontendModuleResolver moduleResolver; FrontendModuleResolver moduleResolverForAutocomplete; + GlobalTypes globals; + GlobalTypes globalsForAutocomplete; TypeChecker typeChecker; TypeChecker typeCheckerForAutocomplete; ConfigResolver* configResolver; FrontendOptions options; InternalErrorReporter iceHandler; - TypeArena globalTypes; std::unordered_map sourceNodes; std::unordered_map sourceModules; std::unordered_map requireTrace; Stats stats = {}; - -private: - ScopePtr globalScope; }; ModulePtr check(const SourceModule& sourceModule, const std::vector& requireCycles, NotNull builtinTypes, diff --git a/Analysis/include/Luau/Quantify.h b/Analysis/include/Luau/Quantify.h index b350fab52..c86512f1f 100644 --- a/Analysis/include/Luau/Quantify.h +++ b/Analysis/include/Luau/Quantify.h @@ -10,6 +10,6 @@ struct TypeArena; struct Scope; void quantify(TypeId ty, TypeLevel level); -TypeId quantify(TypeArena* arena, TypeId ty, Scope* scope); +std::optional quantify(TypeArena* arena, TypeId ty, Scope* scope); } // namespace Luau diff --git a/Analysis/include/Luau/Type.h b/Analysis/include/Luau/Type.h index cf1f8dae4..ef2d4c6a4 100644 --- a/Analysis/include/Luau/Type.h +++ b/Analysis/include/Luau/Type.h @@ -640,10 +640,10 @@ struct BuiltinTypes BuiltinTypes(const BuiltinTypes&) = delete; void operator=(const BuiltinTypes&) = delete; - TypeId errorRecoveryType(TypeId guess); - TypePackId errorRecoveryTypePack(TypePackId guess); - TypeId errorRecoveryType(); - TypePackId errorRecoveryTypePack(); + TypeId errorRecoveryType(TypeId guess) const; + TypePackId errorRecoveryTypePack(TypePackId guess) const; + TypeId errorRecoveryType() const; + TypePackId errorRecoveryTypePack() const; private: std::unique_ptr arena; diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index 678bd419d..21cb26371 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -63,11 +63,22 @@ enum class ValueContext RValue }; +struct GlobalTypes +{ + GlobalTypes(NotNull builtinTypes); + + NotNull builtinTypes; // Global types are based on builtin types + + TypeArena globalTypes; + SourceModule globalNames; // names for symbols entered into globalScope + ScopePtr globalScope; // shared by all modules +}; + // All Types are retained via Environment::types. All TypeIds // within a program are borrowed pointers into this set. struct TypeChecker { - explicit TypeChecker(ModuleResolver* resolver, NotNull builtinTypes, InternalErrorReporter* iceHandler); + explicit TypeChecker(const GlobalTypes& globals, ModuleResolver* resolver, NotNull builtinTypes, InternalErrorReporter* iceHandler); TypeChecker(const TypeChecker&) = delete; TypeChecker& operator=(const TypeChecker&) = delete; @@ -355,11 +366,10 @@ struct TypeChecker */ std::vector unTypePack(const ScopePtr& scope, TypePackId pack, size_t expectedLength, const Location& location); - TypeArena globalTypes; + // TODO: only const version of global scope should be available to make sure nothing else is modified inside of from users of TypeChecker + const GlobalTypes& globals; ModuleResolver* resolver; - SourceModule globalNames; // names for symbols entered into globalScope - ScopePtr globalScope; // shared by all modules ModulePtr currentModule; ModuleName currentModuleName; diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index 50024e3fd..fc886ac0c 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -90,6 +90,11 @@ struct Unifier private: void tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall = false, bool isIntersection = false); void tryUnifyUnionWithType(TypeId subTy, const UnionType* uv, TypeId superTy); + + // Traverse the two types provided and block on any BlockedTypes we find. + // Returns true if any types were blocked on. + bool blockOnBlockedTypes(TypeId subTy, TypeId superTy); + void tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionType* uv, bool cacheEnabled, bool isFunctionCall); void tryUnifyTypeWithIntersection(TypeId subTy, TypeId superTy, const IntersectionType* uv); void tryUnifyIntersectionWithType(TypeId subTy, const IntersectionType* uv, TypeId superTy, bool cacheEnabled, bool isFunctionCall); diff --git a/Analysis/src/AstQuery.cpp b/Analysis/src/AstQuery.cpp index e95b0017f..b0c3750b1 100644 --- a/Analysis/src/AstQuery.cpp +++ b/Analysis/src/AstQuery.cpp @@ -12,7 +12,6 @@ #include LUAU_FASTFLAG(LuauCompleteTableKeysBetter); -LUAU_FASTFLAGVARIABLE(SupportTypeAliasGoToDeclaration, false); namespace Luau { @@ -195,17 +194,10 @@ struct FindFullAncestry final : public AstVisitor bool visit(AstType* type) override { - if (FFlag::SupportTypeAliasGoToDeclaration) - { - if (includeTypes) - return visit(static_cast(type)); - else - return false; - } + if (includeTypes) + return visit(static_cast(type)); else - { - return AstVisitor::visit(type); - } + return false; } bool visit(AstNode* node) override diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index 1e0949711..1df4d3d75 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -14,8 +14,6 @@ #include LUAU_FASTFLAGVARIABLE(LuauCompleteTableKeysBetter, false); -LUAU_FASTFLAGVARIABLE(LuauFixAutocompleteInWhile, false); -LUAU_FASTFLAGVARIABLE(LuauFixAutocompleteInFor, false); LUAU_FASTFLAGVARIABLE(LuauAutocompleteSkipNormalization, false); static const std::unordered_set kStatementStartingKeywords = { @@ -1425,24 +1423,12 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M { if (!statFor->hasDo || position < statFor->doLocation.begin) { - if (FFlag::LuauFixAutocompleteInFor) - { - if (statFor->from->location.containsClosed(position) || statFor->to->location.containsClosed(position) || - (statFor->step && statFor->step->location.containsClosed(position))) - return autocompleteExpression(sourceModule, *module, builtinTypes, typeArena, ancestry, position); - - if (!statFor->from->is() && !statFor->to->is() && (!statFor->step || !statFor->step->is())) - return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Keyword}; - } - else - { - if (!statFor->from->is() && !statFor->to->is() && (!statFor->step || !statFor->step->is())) - return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Keyword}; + if (statFor->from->location.containsClosed(position) || statFor->to->location.containsClosed(position) || + (statFor->step && statFor->step->location.containsClosed(position))) + return autocompleteExpression(sourceModule, *module, builtinTypes, typeArena, ancestry, position); - if (statFor->from->location.containsClosed(position) || statFor->to->location.containsClosed(position) || - (statFor->step && statFor->step->location.containsClosed(position))) - return autocompleteExpression(sourceModule, *module, builtinTypes, typeArena, ancestry, position); - } + if (!statFor->from->is() && !statFor->to->is() && (!statFor->step || !statFor->step->is())) + return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Keyword}; return {}; } @@ -1493,14 +1479,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M { if (!statWhile->hasDo && !statWhile->condition->is() && position > statWhile->condition->location.end) { - if (FFlag::LuauFixAutocompleteInWhile) - { - return autocompleteWhileLoopKeywords(ancestry); - } - else - { - return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Keyword}; - } + return autocompleteWhileLoopKeywords(ancestry); } if (!statWhile->hasDo || position < statWhile->doLocation.begin) @@ -1511,18 +1490,10 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M } else if (AstStatWhile* statWhile = extractStat(ancestry); - FFlag::LuauFixAutocompleteInWhile ? (statWhile && (!statWhile->hasDo || statWhile->doLocation.containsClosed(position)) && - statWhile->condition && !statWhile->condition->location.containsClosed(position)) - : (statWhile && !statWhile->hasDo)) + (statWhile && (!statWhile->hasDo || statWhile->doLocation.containsClosed(position)) && statWhile->condition && + !statWhile->condition->location.containsClosed(position))) { - if (FFlag::LuauFixAutocompleteInWhile) - { - return autocompleteWhileLoopKeywords(ancestry); - } - else - { - return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Keyword}; - } + return autocompleteWhileLoopKeywords(ancestry); } else if (AstStatIf* statIf = node->as(); statIf && !statIf->elseLocation.has_value()) { @@ -1672,7 +1643,7 @@ AutocompleteResult autocomplete(Frontend& frontend, const ModuleName& moduleName return {}; NotNull builtinTypes = frontend.builtinTypes; - Scope* globalScope = frontend.typeCheckerForAutocomplete.globalScope.get(); + Scope* globalScope = frontend.globalsForAutocomplete.globalScope.get(); TypeArena typeArena; return autocomplete(*sourceModule, module, builtinTypes, &typeArena, globalScope, position, callback); diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index b111c504a..d2ace49b9 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -52,14 +52,9 @@ TypeId makeIntersection(TypeArena& arena, std::vector&& types) return arena.addType(IntersectionType{std::move(types)}); } -TypeId makeOption(Frontend& frontend, TypeArena& arena, TypeId t) +TypeId makeOption(NotNull builtinTypes, TypeArena& arena, TypeId t) { - return makeUnion(arena, {frontend.typeChecker.nilType, t}); -} - -TypeId makeOption(TypeChecker& typeChecker, TypeArena& arena, TypeId t) -{ - return makeUnion(arena, {typeChecker.nilType, t}); + return makeUnion(arena, {builtinTypes->nilType, t}); } TypeId makeFunction( @@ -148,85 +143,52 @@ Property makeProperty(TypeId ty, std::optional documentationSymbol) }; } -void addGlobalBinding(Frontend& frontend, const std::string& name, TypeId ty, const std::string& packageName) -{ - addGlobalBinding(frontend, frontend.getGlobalScope(), name, ty, packageName); -} - -void addGlobalBinding(TypeChecker& typeChecker, const ScopePtr& scope, const std::string& name, TypeId ty, const std::string& packageName); - -void addGlobalBinding(TypeChecker& typeChecker, const std::string& name, TypeId ty, const std::string& packageName) +void addGlobalBinding(GlobalTypes& globals, const std::string& name, TypeId ty, const std::string& packageName) { - addGlobalBinding(typeChecker, typeChecker.globalScope, name, ty, packageName); + addGlobalBinding(globals, globals.globalScope, name, ty, packageName); } -void addGlobalBinding(Frontend& frontend, const std::string& name, Binding binding) +void addGlobalBinding(GlobalTypes& globals, const std::string& name, Binding binding) { - addGlobalBinding(frontend, frontend.getGlobalScope(), name, binding); + addGlobalBinding(globals, globals.globalScope, name, binding); } -void addGlobalBinding(TypeChecker& typeChecker, const std::string& name, Binding binding) -{ - addGlobalBinding(typeChecker, typeChecker.globalScope, name, binding); -} - -void addGlobalBinding(Frontend& frontend, const ScopePtr& scope, const std::string& name, TypeId ty, const std::string& packageName) +void addGlobalBinding(GlobalTypes& globals, const ScopePtr& scope, const std::string& name, TypeId ty, const std::string& packageName) { std::string documentationSymbol = packageName + "/global/" + name; - addGlobalBinding(frontend, scope, name, Binding{ty, Location{}, {}, {}, documentationSymbol}); + addGlobalBinding(globals, scope, name, Binding{ty, Location{}, {}, {}, documentationSymbol}); } -void addGlobalBinding(TypeChecker& typeChecker, const ScopePtr& scope, const std::string& name, TypeId ty, const std::string& packageName) +void addGlobalBinding(GlobalTypes& globals, const ScopePtr& scope, const std::string& name, Binding binding) { - std::string documentationSymbol = packageName + "/global/" + name; - addGlobalBinding(typeChecker, scope, name, Binding{ty, Location{}, {}, {}, documentationSymbol}); + scope->bindings[globals.globalNames.names->getOrAdd(name.c_str())] = binding; } -void addGlobalBinding(Frontend& frontend, const ScopePtr& scope, const std::string& name, Binding binding) +std::optional tryGetGlobalBinding(GlobalTypes& globals, const std::string& name) { - addGlobalBinding(frontend.typeChecker, scope, name, binding); -} - -void addGlobalBinding(TypeChecker& typeChecker, const ScopePtr& scope, const std::string& name, Binding binding) -{ - scope->bindings[typeChecker.globalNames.names->getOrAdd(name.c_str())] = binding; -} - -std::optional tryGetGlobalBinding(TypeChecker& typeChecker, const std::string& name) -{ - AstName astName = typeChecker.globalNames.names->getOrAdd(name.c_str()); - auto it = typeChecker.globalScope->bindings.find(astName); - if (it != typeChecker.globalScope->bindings.end()) + AstName astName = globals.globalNames.names->getOrAdd(name.c_str()); + auto it = globals.globalScope->bindings.find(astName); + if (it != globals.globalScope->bindings.end()) return it->second; return std::nullopt; } -TypeId getGlobalBinding(TypeChecker& typeChecker, const std::string& name) +TypeId getGlobalBinding(GlobalTypes& globals, const std::string& name) { - auto t = tryGetGlobalBinding(typeChecker, name); + auto t = tryGetGlobalBinding(globals, name); LUAU_ASSERT(t.has_value()); return t->typeId; } -TypeId getGlobalBinding(Frontend& frontend, const std::string& name) -{ - return getGlobalBinding(frontend.typeChecker, name); -} - -std::optional tryGetGlobalBinding(Frontend& frontend, const std::string& name) -{ - return tryGetGlobalBinding(frontend.typeChecker, name); -} - -Binding* tryGetGlobalBindingRef(TypeChecker& typeChecker, const std::string& name) +Binding* tryGetGlobalBindingRef(GlobalTypes& globals, const std::string& name) { - AstName astName = typeChecker.globalNames.names->get(name.c_str()); + AstName astName = globals.globalNames.names->get(name.c_str()); if (astName == AstName()) return nullptr; - auto it = typeChecker.globalScope->bindings.find(astName); - if (it != typeChecker.globalScope->bindings.end()) + auto it = globals.globalScope->bindings.find(astName); + if (it != globals.globalScope->bindings.end()) return &it->second; return nullptr; @@ -240,34 +202,33 @@ void assignPropDocumentationSymbols(TableType::Props& props, const std::string& } } -void registerBuiltinTypes(Frontend& frontend) +void registerBuiltinTypes(GlobalTypes& globals) { - frontend.getGlobalScope()->addBuiltinTypeBinding("any", TypeFun{{}, frontend.builtinTypes->anyType}); - frontend.getGlobalScope()->addBuiltinTypeBinding("nil", TypeFun{{}, frontend.builtinTypes->nilType}); - frontend.getGlobalScope()->addBuiltinTypeBinding("number", TypeFun{{}, frontend.builtinTypes->numberType}); - frontend.getGlobalScope()->addBuiltinTypeBinding("string", TypeFun{{}, frontend.builtinTypes->stringType}); - frontend.getGlobalScope()->addBuiltinTypeBinding("boolean", TypeFun{{}, frontend.builtinTypes->booleanType}); - frontend.getGlobalScope()->addBuiltinTypeBinding("thread", TypeFun{{}, frontend.builtinTypes->threadType}); - frontend.getGlobalScope()->addBuiltinTypeBinding("unknown", TypeFun{{}, frontend.builtinTypes->unknownType}); - frontend.getGlobalScope()->addBuiltinTypeBinding("never", TypeFun{{}, frontend.builtinTypes->neverType}); + globals.globalScope->addBuiltinTypeBinding("any", TypeFun{{}, globals.builtinTypes->anyType}); + globals.globalScope->addBuiltinTypeBinding("nil", TypeFun{{}, globals.builtinTypes->nilType}); + globals.globalScope->addBuiltinTypeBinding("number", TypeFun{{}, globals.builtinTypes->numberType}); + globals.globalScope->addBuiltinTypeBinding("string", TypeFun{{}, globals.builtinTypes->stringType}); + globals.globalScope->addBuiltinTypeBinding("boolean", TypeFun{{}, globals.builtinTypes->booleanType}); + globals.globalScope->addBuiltinTypeBinding("thread", TypeFun{{}, globals.builtinTypes->threadType}); + globals.globalScope->addBuiltinTypeBinding("unknown", TypeFun{{}, globals.builtinTypes->unknownType}); + globals.globalScope->addBuiltinTypeBinding("never", TypeFun{{}, globals.builtinTypes->neverType}); } -void registerBuiltinGlobals(TypeChecker& typeChecker) +void registerBuiltinGlobals(TypeChecker& typeChecker, GlobalTypes& globals) { - LUAU_ASSERT(!typeChecker.globalTypes.types.isFrozen()); - LUAU_ASSERT(!typeChecker.globalTypes.typePacks.isFrozen()); - - TypeId nilType = typeChecker.nilType; + LUAU_ASSERT(!globals.globalTypes.types.isFrozen()); + LUAU_ASSERT(!globals.globalTypes.typePacks.isFrozen()); - TypeArena& arena = typeChecker.globalTypes; - NotNull builtinTypes = typeChecker.builtinTypes; + TypeArena& arena = globals.globalTypes; + NotNull builtinTypes = globals.builtinTypes; - LoadDefinitionFileResult loadResult = Luau::loadDefinitionFile(typeChecker, typeChecker.globalScope, getBuiltinDefinitionSource(), "@luau"); + LoadDefinitionFileResult loadResult = + Luau::loadDefinitionFile(typeChecker, globals, globals.globalScope, getBuiltinDefinitionSource(), "@luau", /* captureComments */ false); LUAU_ASSERT(loadResult.success); TypeId genericK = arena.addType(GenericType{"K"}); TypeId genericV = arena.addType(GenericType{"V"}); - TypeId mapOfKtoV = arena.addType(TableType{{}, TableIndexer(genericK, genericV), typeChecker.globalScope->level, TableState::Generic}); + TypeId mapOfKtoV = arena.addType(TableType{{}, TableIndexer(genericK, genericV), globals.globalScope->level, TableState::Generic}); std::optional stringMetatableTy = getMetatable(builtinTypes->stringType, builtinTypes); LUAU_ASSERT(stringMetatableTy); @@ -277,33 +238,33 @@ void registerBuiltinGlobals(TypeChecker& typeChecker) auto it = stringMetatableTable->props.find("__index"); LUAU_ASSERT(it != stringMetatableTable->props.end()); - addGlobalBinding(typeChecker, "string", it->second.type, "@luau"); + addGlobalBinding(globals, "string", it->second.type, "@luau"); // next(t: Table, i: K?) -> (K?, V) - TypePackId nextArgsTypePack = arena.addTypePack(TypePack{{mapOfKtoV, makeOption(typeChecker, arena, genericK)}}); - TypePackId nextRetsTypePack = arena.addTypePack(TypePack{{makeOption(typeChecker, arena, genericK), genericV}}); - addGlobalBinding(typeChecker, "next", arena.addType(FunctionType{{genericK, genericV}, {}, nextArgsTypePack, nextRetsTypePack}), "@luau"); + TypePackId nextArgsTypePack = arena.addTypePack(TypePack{{mapOfKtoV, makeOption(builtinTypes, arena, genericK)}}); + TypePackId nextRetsTypePack = arena.addTypePack(TypePack{{makeOption(builtinTypes, arena, genericK), genericV}}); + addGlobalBinding(globals, "next", arena.addType(FunctionType{{genericK, genericV}, {}, nextArgsTypePack, nextRetsTypePack}), "@luau"); TypePackId pairsArgsTypePack = arena.addTypePack({mapOfKtoV}); TypeId pairsNext = arena.addType(FunctionType{nextArgsTypePack, nextRetsTypePack}); - TypePackId pairsReturnTypePack = arena.addTypePack(TypePack{{pairsNext, mapOfKtoV, nilType}}); + TypePackId pairsReturnTypePack = arena.addTypePack(TypePack{{pairsNext, mapOfKtoV, builtinTypes->nilType}}); // pairs(t: Table) -> ((Table, K?) -> (K, V), Table, nil) - addGlobalBinding(typeChecker, "pairs", arena.addType(FunctionType{{genericK, genericV}, {}, pairsArgsTypePack, pairsReturnTypePack}), "@luau"); + addGlobalBinding(globals, "pairs", arena.addType(FunctionType{{genericK, genericV}, {}, pairsArgsTypePack, pairsReturnTypePack}), "@luau"); TypeId genericMT = arena.addType(GenericType{"MT"}); - TableType tab{TableState::Generic, typeChecker.globalScope->level}; + TableType tab{TableState::Generic, globals.globalScope->level}; TypeId tabTy = arena.addType(tab); TypeId tableMetaMT = arena.addType(MetatableType{tabTy, genericMT}); - addGlobalBinding(typeChecker, "getmetatable", makeFunction(arena, std::nullopt, {genericMT}, {}, {tableMetaMT}, {genericMT}), "@luau"); + addGlobalBinding(globals, "getmetatable", makeFunction(arena, std::nullopt, {genericMT}, {}, {tableMetaMT}, {genericMT}), "@luau"); // clang-format off // setmetatable(T, MT) -> { @metatable MT, T } - addGlobalBinding(typeChecker, "setmetatable", + addGlobalBinding(globals, "setmetatable", arena.addType( FunctionType{ {genericMT}, @@ -315,7 +276,7 @@ void registerBuiltinGlobals(TypeChecker& typeChecker) ); // clang-format on - for (const auto& pair : typeChecker.globalScope->bindings) + for (const auto& pair : globals.globalScope->bindings) { persist(pair.second.typeId); @@ -326,12 +287,12 @@ void registerBuiltinGlobals(TypeChecker& typeChecker) } } - attachMagicFunction(getGlobalBinding(typeChecker, "assert"), magicFunctionAssert); - attachMagicFunction(getGlobalBinding(typeChecker, "setmetatable"), magicFunctionSetMetaTable); - attachMagicFunction(getGlobalBinding(typeChecker, "select"), magicFunctionSelect); - attachDcrMagicFunction(getGlobalBinding(typeChecker, "select"), dcrMagicFunctionSelect); + attachMagicFunction(getGlobalBinding(globals, "assert"), magicFunctionAssert); + attachMagicFunction(getGlobalBinding(globals, "setmetatable"), magicFunctionSetMetaTable); + attachMagicFunction(getGlobalBinding(globals, "select"), magicFunctionSelect); + attachDcrMagicFunction(getGlobalBinding(globals, "select"), dcrMagicFunctionSelect); - if (TableType* ttv = getMutable(getGlobalBinding(typeChecker, "table"))) + if (TableType* ttv = getMutable(getGlobalBinding(globals, "table"))) { // tabTy is a generic table type which we can't express via declaration syntax yet ttv->props["freeze"] = makeProperty(makeFunction(arena, std::nullopt, {tabTy}, {tabTy}), "@luau/global/table.freeze"); @@ -349,26 +310,28 @@ void registerBuiltinGlobals(TypeChecker& typeChecker) attachDcrMagicFunction(ttv->props["pack"].type, dcrMagicFunctionPack); } - attachMagicFunction(getGlobalBinding(typeChecker, "require"), magicFunctionRequire); - attachDcrMagicFunction(getGlobalBinding(typeChecker, "require"), dcrMagicFunctionRequire); + attachMagicFunction(getGlobalBinding(globals, "require"), magicFunctionRequire); + attachDcrMagicFunction(getGlobalBinding(globals, "require"), dcrMagicFunctionRequire); } void registerBuiltinGlobals(Frontend& frontend) { - LUAU_ASSERT(!frontend.globalTypes.types.isFrozen()); - LUAU_ASSERT(!frontend.globalTypes.typePacks.isFrozen()); + GlobalTypes& globals = frontend.globals; + + LUAU_ASSERT(!globals.globalTypes.types.isFrozen()); + LUAU_ASSERT(!globals.globalTypes.typePacks.isFrozen()); - registerBuiltinTypes(frontend); + registerBuiltinTypes(globals); - TypeArena& arena = frontend.globalTypes; - NotNull builtinTypes = frontend.builtinTypes; + TypeArena& arena = globals.globalTypes; + NotNull builtinTypes = globals.builtinTypes; - LoadDefinitionFileResult loadResult = frontend.loadDefinitionFile(getBuiltinDefinitionSource(), "@luau"); + LoadDefinitionFileResult loadResult = frontend.loadDefinitionFile(getBuiltinDefinitionSource(), "@luau", /* captureComments */ false); LUAU_ASSERT(loadResult.success); TypeId genericK = arena.addType(GenericType{"K"}); TypeId genericV = arena.addType(GenericType{"V"}); - TypeId mapOfKtoV = arena.addType(TableType{{}, TableIndexer(genericK, genericV), frontend.getGlobalScope()->level, TableState::Generic}); + TypeId mapOfKtoV = arena.addType(TableType{{}, TableIndexer(genericK, genericV), globals.globalScope->level, TableState::Generic}); std::optional stringMetatableTy = getMetatable(builtinTypes->stringType, builtinTypes); LUAU_ASSERT(stringMetatableTy); @@ -378,33 +341,33 @@ void registerBuiltinGlobals(Frontend& frontend) auto it = stringMetatableTable->props.find("__index"); LUAU_ASSERT(it != stringMetatableTable->props.end()); - addGlobalBinding(frontend, "string", it->second.type, "@luau"); + addGlobalBinding(globals, "string", it->second.type, "@luau"); // next(t: Table, i: K?) -> (K?, V) - TypePackId nextArgsTypePack = arena.addTypePack(TypePack{{mapOfKtoV, makeOption(frontend, arena, genericK)}}); - TypePackId nextRetsTypePack = arena.addTypePack(TypePack{{makeOption(frontend, arena, genericK), genericV}}); - addGlobalBinding(frontend, "next", arena.addType(FunctionType{{genericK, genericV}, {}, nextArgsTypePack, nextRetsTypePack}), "@luau"); + TypePackId nextArgsTypePack = arena.addTypePack(TypePack{{mapOfKtoV, makeOption(builtinTypes, arena, genericK)}}); + TypePackId nextRetsTypePack = arena.addTypePack(TypePack{{makeOption(builtinTypes, arena, genericK), genericV}}); + addGlobalBinding(globals, "next", arena.addType(FunctionType{{genericK, genericV}, {}, nextArgsTypePack, nextRetsTypePack}), "@luau"); TypePackId pairsArgsTypePack = arena.addTypePack({mapOfKtoV}); TypeId pairsNext = arena.addType(FunctionType{nextArgsTypePack, nextRetsTypePack}); - TypePackId pairsReturnTypePack = arena.addTypePack(TypePack{{pairsNext, mapOfKtoV, frontend.builtinTypes->nilType}}); + TypePackId pairsReturnTypePack = arena.addTypePack(TypePack{{pairsNext, mapOfKtoV, builtinTypes->nilType}}); // pairs(t: Table) -> ((Table, K?) -> (K?, V), Table, nil) - addGlobalBinding(frontend, "pairs", arena.addType(FunctionType{{genericK, genericV}, {}, pairsArgsTypePack, pairsReturnTypePack}), "@luau"); + addGlobalBinding(globals, "pairs", arena.addType(FunctionType{{genericK, genericV}, {}, pairsArgsTypePack, pairsReturnTypePack}), "@luau"); TypeId genericMT = arena.addType(GenericType{"MT"}); - TableType tab{TableState::Generic, frontend.getGlobalScope()->level}; + TableType tab{TableState::Generic, globals.globalScope->level}; TypeId tabTy = arena.addType(tab); TypeId tableMetaMT = arena.addType(MetatableType{tabTy, genericMT}); - addGlobalBinding(frontend, "getmetatable", makeFunction(arena, std::nullopt, {genericMT}, {}, {tableMetaMT}, {genericMT}), "@luau"); + addGlobalBinding(globals, "getmetatable", makeFunction(arena, std::nullopt, {genericMT}, {}, {tableMetaMT}, {genericMT}), "@luau"); // clang-format off // setmetatable(T, MT) -> { @metatable MT, T } - addGlobalBinding(frontend, "setmetatable", + addGlobalBinding(globals, "setmetatable", arena.addType( FunctionType{ {genericMT}, @@ -416,7 +379,7 @@ void registerBuiltinGlobals(Frontend& frontend) ); // clang-format on - for (const auto& pair : frontend.getGlobalScope()->bindings) + for (const auto& pair : globals.globalScope->bindings) { persist(pair.second.typeId); @@ -427,12 +390,12 @@ void registerBuiltinGlobals(Frontend& frontend) } } - attachMagicFunction(getGlobalBinding(frontend, "assert"), magicFunctionAssert); - attachMagicFunction(getGlobalBinding(frontend, "setmetatable"), magicFunctionSetMetaTable); - attachMagicFunction(getGlobalBinding(frontend, "select"), magicFunctionSelect); - attachDcrMagicFunction(getGlobalBinding(frontend, "select"), dcrMagicFunctionSelect); + attachMagicFunction(getGlobalBinding(globals, "assert"), magicFunctionAssert); + attachMagicFunction(getGlobalBinding(globals, "setmetatable"), magicFunctionSetMetaTable); + attachMagicFunction(getGlobalBinding(globals, "select"), magicFunctionSelect); + attachDcrMagicFunction(getGlobalBinding(globals, "select"), dcrMagicFunctionSelect); - if (TableType* ttv = getMutable(getGlobalBinding(frontend, "table"))) + if (TableType* ttv = getMutable(getGlobalBinding(globals, "table"))) { // tabTy is a generic table type which we can't express via declaration syntax yet ttv->props["freeze"] = makeProperty(makeFunction(arena, std::nullopt, {tabTy}, {tabTy}), "@luau/global/table.freeze"); @@ -449,8 +412,8 @@ void registerBuiltinGlobals(Frontend& frontend) attachMagicFunction(ttv->props["pack"].type, magicFunctionPack); } - attachMagicFunction(getGlobalBinding(frontend, "require"), magicFunctionRequire); - attachDcrMagicFunction(getGlobalBinding(frontend, "require"), dcrMagicFunctionRequire); + attachMagicFunction(getGlobalBinding(globals, "require"), magicFunctionRequire); + attachDcrMagicFunction(getGlobalBinding(globals, "require"), dcrMagicFunctionRequire); } static std::optional> magicFunctionSelect( diff --git a/Analysis/src/ConstraintGraphBuilder.cpp b/Analysis/src/ConstraintGraphBuilder.cpp index 9ee2b0882..711d357f0 100644 --- a/Analysis/src/ConstraintGraphBuilder.cpp +++ b/Analysis/src/ConstraintGraphBuilder.cpp @@ -18,7 +18,6 @@ LUAU_FASTINT(LuauCheckRecursionLimit); LUAU_FASTFLAG(DebugLuauMagicTypes); LUAU_FASTFLAG(LuauNegatedClassTypes); -LUAU_FASTFLAG(SupportTypeAliasGoToDeclaration); namespace Luau { @@ -587,8 +586,7 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocal* local) if (ModulePtr module = moduleResolver->getModule(moduleInfo->name)) { scope->importedTypeBindings[name] = module->exportedTypeBindings; - if (FFlag::SupportTypeAliasGoToDeclaration) - scope->importedModules[name] = moduleName; + scope->importedModules[name] = moduleInfo->name; } } } diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index 3cb4e4e7e..3c306b40e 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -521,12 +521,19 @@ bool ConstraintSolver::tryDispatch(const GeneralizationConstraint& c, NotNullscope); - - if (isBlocked(c.generalizedType)) - asMutable(c.generalizedType)->ty.emplace(generalized); + std::optional generalized = quantify(arena, c.sourceType, constraint->scope); + if (generalized) + { + if (isBlocked(c.generalizedType)) + asMutable(c.generalizedType)->ty.emplace(*generalized); + else + unify(c.generalizedType, *generalized, constraint->scope); + } else - unify(c.generalizedType, generalized, constraint->scope); + { + reportError(CodeTooComplex{}, constraint->location); + asMutable(c.generalizedType)->ty.emplace(builtinTypes->errorRecoveryType()); + } unblock(c.generalizedType); unblock(c.sourceType); @@ -1365,7 +1372,7 @@ static std::optional updateTheTableType(NotNull arena, TypeId if (it == tbl->props.end()) return std::nullopt; - t = it->second.type; + t = follow(it->second.type); } // The last path segment should not be a property of the table at all. @@ -1450,12 +1457,6 @@ bool ConstraintSolver::tryDispatch(const SetPropConstraint& c, NotNull(subjectType)) subjectType = follow(mt->table); - if (get(subjectType) || get(subjectType) || get(subjectType)) - { - bind(c.resultType, subjectType); - return true; - } - if (get(subjectType)) { TypeId ty = arena->freshType(constraint->scope); @@ -1501,16 +1502,13 @@ bool ConstraintSolver::tryDispatch(const SetPropConstraint& c, NotNull(subjectType)) + else { - // Classes and intersections never change shape as a result of property - // assignments. The result is always the subject. + // Other kinds of types don't change shape when properties are assigned + // to them. (if they allow properties at all!) bind(c.resultType, subjectType); return true; } - - LUAU_ASSERT(0); - return true; } bool ConstraintSolver::tryDispatch(const SetIndexerConstraint& c, NotNull constraint, bool force) diff --git a/Analysis/src/Error.cpp b/Analysis/src/Error.cpp index a527b2440..84b9cb37d 100644 --- a/Analysis/src/Error.cpp +++ b/Analysis/src/Error.cpp @@ -959,7 +959,7 @@ void copyErrors(ErrorVec& errors, TypeArena& destArena) visit(visitErrorData, error.data); } -void InternalErrorReporter::ice(const std::string& message, const Location& location) +void InternalErrorReporter::ice(const std::string& message, const Location& location) const { InternalCompilerError error(message, moduleName, location); @@ -969,7 +969,7 @@ void InternalErrorReporter::ice(const std::string& message, const Location& loca throw error; } -void InternalErrorReporter::ice(const std::string& message) +void InternalErrorReporter::ice(const std::string& message) const { InternalCompilerError error(message, moduleName); diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index b3e453db0..722f1a2c8 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -31,6 +31,7 @@ LUAU_FASTFLAG(LuauInferInNoCheckMode) LUAU_FASTFLAGVARIABLE(LuauKnowsTheDataModel3, false) LUAU_FASTINTVARIABLE(LuauAutocompleteCheckTimeoutMs, 100) LUAU_FASTFLAGVARIABLE(DebugLuauDeferredConstraintResolution, false) +LUAU_FASTFLAGVARIABLE(LuauDefinitionFileSourceModule, false) LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJson, false); namespace Luau @@ -83,32 +84,31 @@ static void generateDocumentationSymbols(TypeId ty, const std::string& rootName) } } -LoadDefinitionFileResult Frontend::loadDefinitionFile(std::string_view source, const std::string& packageName) +LoadDefinitionFileResult Frontend::loadDefinitionFile(std::string_view source, const std::string& packageName, bool captureComments) { if (!FFlag::DebugLuauDeferredConstraintResolution) - return Luau::loadDefinitionFile(typeChecker, typeChecker.globalScope, source, packageName); + return Luau::loadDefinitionFile(typeChecker, globals, globals.globalScope, source, packageName, captureComments); LUAU_TIMETRACE_SCOPE("loadDefinitionFile", "Frontend"); - Luau::Allocator allocator; - Luau::AstNameTable names(allocator); + Luau::SourceModule sourceModule; ParseOptions options; options.allowDeclarationSyntax = true; + options.captureComments = captureComments; - Luau::ParseResult parseResult = Luau::Parser::parse(source.data(), source.size(), names, allocator, options); + Luau::ParseResult parseResult = Luau::Parser::parse(source.data(), source.size(), *sourceModule.names, *sourceModule.allocator, options); if (parseResult.errors.size() > 0) - return LoadDefinitionFileResult{false, parseResult, nullptr}; + return LoadDefinitionFileResult{false, parseResult, sourceModule, nullptr}; - Luau::SourceModule module; - module.root = parseResult.root; - module.mode = Mode::Definition; + sourceModule.root = parseResult.root; + sourceModule.mode = Mode::Definition; - ModulePtr checkedModule = check(module, Mode::Definition, {}); + ModulePtr checkedModule = check(sourceModule, Mode::Definition, {}); if (checkedModule->errors.size() > 0) - return LoadDefinitionFileResult{false, parseResult, checkedModule}; + return LoadDefinitionFileResult{false, parseResult, sourceModule, checkedModule}; CloneState cloneState; @@ -117,20 +117,20 @@ LoadDefinitionFileResult Frontend::loadDefinitionFile(std::string_view source, c for (const auto& [name, ty] : checkedModule->declaredGlobals) { - TypeId globalTy = clone(ty, globalTypes, cloneState); + TypeId globalTy = clone(ty, globals.globalTypes, cloneState); std::string documentationSymbol = packageName + "/global/" + name; generateDocumentationSymbols(globalTy, documentationSymbol); - globalScope->bindings[typeChecker.globalNames.names->getOrAdd(name.c_str())] = {globalTy, Location(), false, {}, documentationSymbol}; + globals.globalScope->bindings[globals.globalNames.names->getOrAdd(name.c_str())] = {globalTy, Location(), false, {}, documentationSymbol}; typesToPersist.push_back(globalTy); } for (const auto& [name, ty] : checkedModule->exportedTypeBindings) { - TypeFun globalTy = clone(ty, globalTypes, cloneState); + TypeFun globalTy = clone(ty, globals.globalTypes, cloneState); std::string documentationSymbol = packageName + "/globaltype/" + name; generateDocumentationSymbols(globalTy.type, documentationSymbol); - globalScope->exportedTypeBindings[name] = globalTy; + globals.globalScope->exportedTypeBindings[name] = globalTy; typesToPersist.push_back(globalTy.type); } @@ -140,10 +140,11 @@ LoadDefinitionFileResult Frontend::loadDefinitionFile(std::string_view source, c persist(ty); } - return LoadDefinitionFileResult{true, parseResult, checkedModule}; + return LoadDefinitionFileResult{true, parseResult, sourceModule, checkedModule}; } -LoadDefinitionFileResult loadDefinitionFile(TypeChecker& typeChecker, ScopePtr targetScope, std::string_view source, const std::string& packageName) +LoadDefinitionFileResult loadDefinitionFile_DEPRECATED( + TypeChecker& typeChecker, GlobalTypes& globals, ScopePtr targetScope, std::string_view source, const std::string& packageName) { LUAU_TIMETRACE_SCOPE("loadDefinitionFile", "Frontend"); @@ -156,7 +157,7 @@ LoadDefinitionFileResult loadDefinitionFile(TypeChecker& typeChecker, ScopePtr t Luau::ParseResult parseResult = Luau::Parser::parse(source.data(), source.size(), names, allocator, options); if (parseResult.errors.size() > 0) - return LoadDefinitionFileResult{false, parseResult, nullptr}; + return LoadDefinitionFileResult{false, parseResult, {}, nullptr}; Luau::SourceModule module; module.root = parseResult.root; @@ -165,7 +166,7 @@ LoadDefinitionFileResult loadDefinitionFile(TypeChecker& typeChecker, ScopePtr t ModulePtr checkedModule = typeChecker.check(module, Mode::Definition); if (checkedModule->errors.size() > 0) - return LoadDefinitionFileResult{false, parseResult, checkedModule}; + return LoadDefinitionFileResult{false, parseResult, {}, checkedModule}; CloneState cloneState; @@ -174,17 +175,17 @@ LoadDefinitionFileResult loadDefinitionFile(TypeChecker& typeChecker, ScopePtr t for (const auto& [name, ty] : checkedModule->declaredGlobals) { - TypeId globalTy = clone(ty, typeChecker.globalTypes, cloneState); + TypeId globalTy = clone(ty, globals.globalTypes, cloneState); std::string documentationSymbol = packageName + "/global/" + name; generateDocumentationSymbols(globalTy, documentationSymbol); - targetScope->bindings[typeChecker.globalNames.names->getOrAdd(name.c_str())] = {globalTy, Location(), false, {}, documentationSymbol}; + targetScope->bindings[globals.globalNames.names->getOrAdd(name.c_str())] = {globalTy, Location(), false, {}, documentationSymbol}; typesToPersist.push_back(globalTy); } for (const auto& [name, ty] : checkedModule->exportedTypeBindings) { - TypeFun globalTy = clone(ty, typeChecker.globalTypes, cloneState); + TypeFun globalTy = clone(ty, globals.globalTypes, cloneState); std::string documentationSymbol = packageName + "/globaltype/" + name; generateDocumentationSymbols(globalTy.type, documentationSymbol); targetScope->exportedTypeBindings[name] = globalTy; @@ -197,7 +198,67 @@ LoadDefinitionFileResult loadDefinitionFile(TypeChecker& typeChecker, ScopePtr t persist(ty); } - return LoadDefinitionFileResult{true, parseResult, checkedModule}; + return LoadDefinitionFileResult{true, parseResult, {}, checkedModule}; +} + +LoadDefinitionFileResult loadDefinitionFile(TypeChecker& typeChecker, GlobalTypes& globals, ScopePtr targetScope, std::string_view source, + const std::string& packageName, bool captureComments) +{ + if (!FFlag::LuauDefinitionFileSourceModule) + return loadDefinitionFile_DEPRECATED(typeChecker, globals, targetScope, source, packageName); + + LUAU_TIMETRACE_SCOPE("loadDefinitionFile", "Frontend"); + + Luau::SourceModule sourceModule; + + ParseOptions options; + options.allowDeclarationSyntax = true; + options.captureComments = captureComments; + + Luau::ParseResult parseResult = Luau::Parser::parse(source.data(), source.size(), *sourceModule.names, *sourceModule.allocator, options); + + if (parseResult.errors.size() > 0) + return LoadDefinitionFileResult{false, parseResult, sourceModule, nullptr}; + + sourceModule.root = parseResult.root; + sourceModule.mode = Mode::Definition; + + ModulePtr checkedModule = typeChecker.check(sourceModule, Mode::Definition); + + if (checkedModule->errors.size() > 0) + return LoadDefinitionFileResult{false, parseResult, sourceModule, checkedModule}; + + CloneState cloneState; + + std::vector typesToPersist; + typesToPersist.reserve(checkedModule->declaredGlobals.size() + checkedModule->exportedTypeBindings.size()); + + for (const auto& [name, ty] : checkedModule->declaredGlobals) + { + TypeId globalTy = clone(ty, globals.globalTypes, cloneState); + std::string documentationSymbol = packageName + "/global/" + name; + generateDocumentationSymbols(globalTy, documentationSymbol); + targetScope->bindings[globals.globalNames.names->getOrAdd(name.c_str())] = {globalTy, Location(), false, {}, documentationSymbol}; + + typesToPersist.push_back(globalTy); + } + + for (const auto& [name, ty] : checkedModule->exportedTypeBindings) + { + TypeFun globalTy = clone(ty, globals.globalTypes, cloneState); + std::string documentationSymbol = packageName + "/globaltype/" + name; + generateDocumentationSymbols(globalTy.type, documentationSymbol); + targetScope->exportedTypeBindings[name] = globalTy; + + typesToPersist.push_back(globalTy.type); + } + + for (TypeId ty : typesToPersist) + { + persist(ty); + } + + return LoadDefinitionFileResult{true, parseResult, sourceModule, checkedModule}; } std::vector parsePathExpr(const AstExpr& pathExpr) @@ -414,11 +475,12 @@ Frontend::Frontend(FileResolver* fileResolver, ConfigResolver* configResolver, c , fileResolver(fileResolver) , moduleResolver(this) , moduleResolverForAutocomplete(this) - , typeChecker(&moduleResolver, builtinTypes, &iceHandler) - , typeCheckerForAutocomplete(&moduleResolverForAutocomplete, builtinTypes, &iceHandler) + , globals(builtinTypes) + , globalsForAutocomplete(builtinTypes) + , typeChecker(globals, &moduleResolver, builtinTypes, &iceHandler) + , typeCheckerForAutocomplete(globalsForAutocomplete, &moduleResolverForAutocomplete, builtinTypes, &iceHandler) , configResolver(configResolver) , options(options) - , globalScope(typeChecker.globalScope) { } @@ -704,13 +766,13 @@ bool Frontend::parseGraph(std::vector& buildQueue, const ModuleName& return cyclic; } -ScopePtr Frontend::getModuleEnvironment(const SourceModule& module, const Config& config, bool forAutocomplete) +ScopePtr Frontend::getModuleEnvironment(const SourceModule& module, const Config& config, bool forAutocomplete) const { ScopePtr result; if (forAutocomplete) - result = typeCheckerForAutocomplete.globalScope; + result = globalsForAutocomplete.globalScope; else - result = typeChecker.globalScope; + result = globals.globalScope; if (module.environmentName) result = getEnvironmentScope(*module.environmentName); @@ -848,16 +910,6 @@ const SourceModule* Frontend::getSourceModule(const ModuleName& moduleName) cons return const_cast(this)->getSourceModule(moduleName); } -ScopePtr Frontend::getGlobalScope() -{ - if (!globalScope) - { - globalScope = typeChecker.globalScope; - } - - return globalScope; -} - ModulePtr check(const SourceModule& sourceModule, const std::vector& requireCycles, NotNull builtinTypes, NotNull iceHandler, NotNull moduleResolver, NotNull fileResolver, const ScopePtr& globalScope, FrontendOptions options) @@ -946,7 +998,7 @@ ModulePtr Frontend::check(const SourceModule& sourceModule, Mode mode, std::vect { return Luau::check(sourceModule, requireCycles, builtinTypes, NotNull{&iceHandler}, NotNull{forAutocomplete ? &moduleResolverForAutocomplete : &moduleResolver}, NotNull{fileResolver}, - forAutocomplete ? typeCheckerForAutocomplete.globalScope : typeChecker.globalScope, options, recordJsonLog); + forAutocomplete ? globalsForAutocomplete.globalScope : globals.globalScope, options, recordJsonLog); } // Read AST into sourceModules if necessary. Trace require()s. Report parse errors. @@ -1115,7 +1167,7 @@ ScopePtr Frontend::addEnvironment(const std::string& environmentName) if (environments.count(environmentName) == 0) { - ScopePtr scope = std::make_shared(typeChecker.globalScope); + ScopePtr scope = std::make_shared(globals.globalScope); environments[environmentName] = scope; return scope; } @@ -1123,14 +1175,16 @@ ScopePtr Frontend::addEnvironment(const std::string& environmentName) return environments[environmentName]; } -ScopePtr Frontend::getEnvironmentScope(const std::string& environmentName) +ScopePtr Frontend::getEnvironmentScope(const std::string& environmentName) const { - LUAU_ASSERT(environments.count(environmentName) > 0); + if (auto it = environments.find(environmentName); it != environments.end()) + return it->second; - return environments[environmentName]; + LUAU_ASSERT(!"environment doesn't exist"); + return {}; } -void Frontend::registerBuiltinDefinition(const std::string& name, std::function applicator) +void Frontend::registerBuiltinDefinition(const std::string& name, std::function applicator) { LUAU_ASSERT(builtinDefinitions.count(name) == 0); @@ -1143,7 +1197,7 @@ void Frontend::applyBuiltinDefinitionToEnvironment(const std::string& environmen LUAU_ASSERT(builtinDefinitions.count(definitionName) > 0); if (builtinDefinitions.count(definitionName) > 0) - builtinDefinitions[definitionName](typeChecker, getEnvironmentScope(environmentName)); + builtinDefinitions[definitionName](typeChecker, globals, getEnvironmentScope(environmentName)); } LintResult Frontend::classifyLints(const std::vector& warnings, const Config& config) diff --git a/Analysis/src/Quantify.cpp b/Analysis/src/Quantify.cpp index 845ae3a36..9da43ed2d 100644 --- a/Analysis/src/Quantify.cpp +++ b/Analysis/src/Quantify.cpp @@ -253,11 +253,12 @@ struct PureQuantifier : Substitution } }; -TypeId quantify(TypeArena* arena, TypeId ty, Scope* scope) +std::optional quantify(TypeArena* arena, TypeId ty, Scope* scope) { PureQuantifier quantifier{arena, scope}; std::optional result = quantifier.substitute(ty); - LUAU_ASSERT(result); + if (!result) + return std::nullopt; FunctionType* ftv = getMutable(*result); LUAU_ASSERT(ftv); diff --git a/Analysis/src/TxnLog.cpp b/Analysis/src/TxnLog.cpp index 5040952e8..26618313b 100644 --- a/Analysis/src/TxnLog.cpp +++ b/Analysis/src/TxnLog.cpp @@ -276,6 +276,7 @@ PendingTypePack* TxnLog::replace(TypePackId tp, TypePackVar replacement) PendingType* TxnLog::bindTable(TypeId ty, std::optional newBoundTo) { LUAU_ASSERT(get(ty)); + LUAU_ASSERT(ty != newBoundTo); PendingType* newTy = queue(ty); if (TableType* ttv = Luau::getMutable(newTy)) diff --git a/Analysis/src/Type.cpp b/Analysis/src/Type.cpp index f15f8c4cf..4bc1223de 100644 --- a/Analysis/src/Type.cpp +++ b/Analysis/src/Type.cpp @@ -856,22 +856,22 @@ TypeId BuiltinTypes::makeStringMetatable() return arena->addType(TableType{{{{"__index", {tableType}}}}, std::nullopt, TypeLevel{}, TableState::Sealed}); } -TypeId BuiltinTypes::errorRecoveryType() +TypeId BuiltinTypes::errorRecoveryType() const { return errorType; } -TypePackId BuiltinTypes::errorRecoveryTypePack() +TypePackId BuiltinTypes::errorRecoveryTypePack() const { return errorTypePack; } -TypeId BuiltinTypes::errorRecoveryType(TypeId guess) +TypeId BuiltinTypes::errorRecoveryType(TypeId guess) const { return guess; } -TypePackId BuiltinTypes::errorRecoveryTypePack(TypePackId guess) +TypePackId BuiltinTypes::errorRecoveryTypePack(TypePackId guess) const { return guess; } diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index aacfd7295..a160a1d26 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -857,7 +857,7 @@ struct TypeChecker2 } void reportOverloadResolutionErrors(AstExprCall* call, std::vector overloads, TypePackId expectedArgTypes, - const std::vector& overloadsThatMatchArgCount, std::vector> overloadsErrors) + const std::vector& overloadsThatMatchArgCount, std::vector> overloadsErrors) { if (overloads.size() == 1) { @@ -883,8 +883,8 @@ struct TypeChecker2 const FunctionType* ftv = get(overload); LUAU_ASSERT(ftv); // overload must be a function type here - auto error = std::find_if(overloadsErrors.begin(), overloadsErrors.end(), [ftv](const std::pair& e) { - return ftv == std::get<1>(e); + auto error = std::find_if(overloadsErrors.begin(), overloadsErrors.end(), [overload](const std::pair& e) { + return overload == e.second; }); LUAU_ASSERT(error != overloadsErrors.end()); @@ -1036,7 +1036,7 @@ struct TypeChecker2 TypePackId expectedArgTypes = arena->addTypePack(args); std::vector overloads = flattenIntersection(testFunctionType); - std::vector> overloadsErrors; + std::vector> overloadsErrors; overloadsErrors.reserve(overloads.size()); std::vector overloadsThatMatchArgCount; @@ -1060,7 +1060,7 @@ struct TypeChecker2 } else { - overloadsErrors.emplace_back(std::vector{TypeError{call->func->location, UnificationTooComplex{}}}, overloadFn); + overloadsErrors.emplace_back(std::vector{TypeError{call->func->location, UnificationTooComplex{}}}, overload); return; } } @@ -1086,7 +1086,7 @@ struct TypeChecker2 if (!argMismatch) overloadsThatMatchArgCount.push_back(overload); - overloadsErrors.emplace_back(std::move(overloadErrors), overloadFn); + overloadsErrors.emplace_back(std::move(overloadErrors), overload); } reportOverloadResolutionErrors(call, overloads, expectedArgTypes, overloadsThatMatchArgCount, overloadsErrors); @@ -1102,11 +1102,54 @@ struct TypeChecker2 visitCall(call); } + std::optional tryStripUnionFromNil(TypeId ty) + { + if (const UnionType* utv = get(ty)) + { + if (!std::any_of(begin(utv), end(utv), isNil)) + return ty; + + std::vector result; + + for (TypeId option : utv) + { + if (!isNil(option)) + result.push_back(option); + } + + if (result.empty()) + return std::nullopt; + + return result.size() == 1 ? result[0] : module->internalTypes.addType(UnionType{std::move(result)}); + } + + return std::nullopt; + } + + TypeId stripFromNilAndReport(TypeId ty, const Location& location) + { + ty = follow(ty); + + if (auto utv = get(ty)) + { + if (!std::any_of(begin(utv), end(utv), isNil)) + return ty; + } + + if (std::optional strippedUnion = tryStripUnionFromNil(ty)) + { + reportError(OptionalValueAccess{ty}, location); + return follow(*strippedUnion); + } + + return ty; + } + void visitExprName(AstExpr* expr, Location location, const std::string& propName, ValueContext context) { visit(expr, RValue); - TypeId leftType = lookupType(expr); + TypeId leftType = stripFromNilAndReport(lookupType(expr), location); const NormalizedType* norm = normalizer.normalize(leftType); if (!norm) reportError(NormalizationTooComplex{}, location); @@ -1766,7 +1809,15 @@ struct TypeChecker2 } else { - reportError(UnknownSymbol{ty->name.value, UnknownSymbol::Context::Type}, ty->location); + std::string symbol = ""; + if (ty->prefix) + { + symbol += (*(ty->prefix)).value; + symbol += "."; + } + symbol += ty->name.value; + + reportError(UnknownSymbol{symbol, UnknownSymbol::Context::Type}, ty->location); } } } @@ -2032,7 +2083,11 @@ struct TypeChecker2 { if (foundOneProp) reportError(MissingUnionProperty{tableTy, typesMissingTheProp, prop}, location); - else if (context == LValue) + // For class LValues, we don't want to report an extension error, + // because classes come into being with full knowledge of their + // shape. We instead want to report the unknown property error of + // the `else` branch. + else if (context == LValue && !get(tableTy)) reportError(CannotExtendTable{tableTy, CannotExtendTable::Property, prop}, location); else reportError(UnknownProperty{tableTy, prop}, location); diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 6aa8e6cac..87d5686fa 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -42,7 +42,6 @@ LUAU_FASTFLAGVARIABLE(LuauIntersectionTestForEquality, false) LUAU_FASTFLAG(LuauNegatedClassTypes) LUAU_FASTFLAGVARIABLE(LuauAllowIndexClassParameters, false) LUAU_FASTFLAG(LuauUninhabitedSubAnything2) -LUAU_FASTFLAG(SupportTypeAliasGoToDeclaration) LUAU_FASTFLAGVARIABLE(LuauTypecheckTypeguards, false) namespace Luau @@ -212,8 +211,24 @@ size_t HashBoolNamePair::operator()(const std::pair& pair) const return std::hash()(pair.first) ^ std::hash()(pair.second); } -TypeChecker::TypeChecker(ModuleResolver* resolver, NotNull builtinTypes, InternalErrorReporter* iceHandler) - : resolver(resolver) +GlobalTypes::GlobalTypes(NotNull builtinTypes) + : builtinTypes(builtinTypes) +{ + globalScope = std::make_shared(globalTypes.addTypePack(TypePackVar{FreeTypePack{TypeLevel{}}})); + + globalScope->addBuiltinTypeBinding("any", TypeFun{{}, builtinTypes->anyType}); + globalScope->addBuiltinTypeBinding("nil", TypeFun{{}, builtinTypes->nilType}); + globalScope->addBuiltinTypeBinding("number", TypeFun{{}, builtinTypes->numberType}); + globalScope->addBuiltinTypeBinding("string", TypeFun{{}, builtinTypes->stringType}); + globalScope->addBuiltinTypeBinding("boolean", TypeFun{{}, builtinTypes->booleanType}); + globalScope->addBuiltinTypeBinding("thread", TypeFun{{}, builtinTypes->threadType}); + globalScope->addBuiltinTypeBinding("unknown", TypeFun{{}, builtinTypes->unknownType}); + globalScope->addBuiltinTypeBinding("never", TypeFun{{}, builtinTypes->neverType}); +} + +TypeChecker::TypeChecker(const GlobalTypes& globals, ModuleResolver* resolver, NotNull builtinTypes, InternalErrorReporter* iceHandler) + : globals(globals) + , resolver(resolver) , builtinTypes(builtinTypes) , iceHandler(iceHandler) , unifierState(iceHandler) @@ -231,16 +246,6 @@ TypeChecker::TypeChecker(ModuleResolver* resolver, NotNull builtin , uninhabitableTypePack(builtinTypes->uninhabitableTypePack) , duplicateTypeAliases{{false, {}}} { - globalScope = std::make_shared(globalTypes.addTypePack(TypePackVar{FreeTypePack{TypeLevel{}}})); - - globalScope->addBuiltinTypeBinding("any", TypeFun{{}, anyType}); - globalScope->addBuiltinTypeBinding("nil", TypeFun{{}, nilType}); - globalScope->addBuiltinTypeBinding("number", TypeFun{{}, numberType}); - globalScope->addBuiltinTypeBinding("string", TypeFun{{}, stringType}); - globalScope->addBuiltinTypeBinding("boolean", TypeFun{{}, booleanType}); - globalScope->addBuiltinTypeBinding("thread", TypeFun{{}, threadType}); - globalScope->addBuiltinTypeBinding("unknown", TypeFun{{}, unknownType}); - globalScope->addBuiltinTypeBinding("never", TypeFun{{}, neverType}); } ModulePtr TypeChecker::check(const SourceModule& module, Mode mode, std::optional environmentScope) @@ -273,7 +278,7 @@ ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mo unifierState.counters.recursionLimit = FInt::LuauTypeInferRecursionLimit; unifierState.counters.iterationLimit = unifierIterationLimit ? *unifierIterationLimit : FInt::LuauTypeInferIterationLimit; - ScopePtr parentScope = environmentScope.value_or(globalScope); + ScopePtr parentScope = environmentScope.value_or(globals.globalScope); ScopePtr moduleScope = std::make_shared(parentScope); if (module.cyclic) @@ -1121,8 +1126,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatLocal& local) if (ModulePtr module = resolver->getModule(moduleInfo->name)) { scope->importedTypeBindings[name] = module->exportedTypeBindings; - if (FFlag::SupportTypeAliasGoToDeclaration) - scope->importedModules[name] = moduleInfo->name; + scope->importedModules[name] = moduleInfo->name; } // In non-strict mode we force the module type on the variable, in strict mode it is already unified @@ -1580,7 +1584,7 @@ void TypeChecker::prototype(const ScopePtr& scope, const AstStatTypeAlias& typea } else { - if (globalScope->builtinTypeNames.contains(name)) + if (globals.globalScope->builtinTypeNames.contains(name)) { reportError(typealias.location, DuplicateTypeDefinition{name}); duplicateTypeAliases.insert({typealias.exported, name}); @@ -1601,8 +1605,7 @@ void TypeChecker::prototype(const ScopePtr& scope, const AstStatTypeAlias& typea bindingsMap[name] = {std::move(generics), std::move(genericPacks), ty}; scope->typeAliasLocations[name] = typealias.location; - if (FFlag::SupportTypeAliasGoToDeclaration) - scope->typeAliasNameLocations[name] = typealias.nameLocation; + scope->typeAliasNameLocations[name] = typealias.nameLocation; } } } @@ -3360,19 +3363,19 @@ TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName, T if (auto globalName = funName.as()) { - const ScopePtr& globalScope = currentModule->getModuleScope(); + const ScopePtr& moduleScope = currentModule->getModuleScope(); Symbol name = globalName->name; - if (globalScope->bindings.count(name)) + if (moduleScope->bindings.count(name)) { if (isNonstrictMode()) - return globalScope->bindings[name].typeId; + return moduleScope->bindings[name].typeId; return errorRecoveryType(scope); } else { TypeId ty = freshTy(); - globalScope->bindings[name] = {ty, funName.location}; + moduleScope->bindings[name] = {ty, funName.location}; return ty; } } @@ -5898,7 +5901,7 @@ void TypeChecker::resolve(const TypeGuardPredicate& typeguardP, RefinementMap& r if (!typeguardP.isTypeof) return addRefinement(refis, typeguardP.lvalue, errorRecoveryType(scope)); - auto typeFun = globalScope->lookupType(typeguardP.kind); + auto typeFun = globals.globalScope->lookupType(typeguardP.kind); if (!typeFun || !typeFun->typeParams.empty() || !typeFun->typePackParams.empty()) return addRefinement(refis, typeguardP.lvalue, errorRecoveryType(scope)); diff --git a/Analysis/src/TypeReduction.cpp b/Analysis/src/TypeReduction.cpp index abafa9fbc..6a9fadfad 100644 --- a/Analysis/src/TypeReduction.cpp +++ b/Analysis/src/TypeReduction.cpp @@ -407,9 +407,6 @@ TypePackId TypeReducer::reduce(TypePackId tp) std::optional TypeReducer::intersectionType(TypeId left, TypeId right) { - LUAU_ASSERT(!get(left)); - LUAU_ASSERT(!get(right)); - if (get(left)) return left; // never & T ~ never else if (get(right)) @@ -442,6 +439,17 @@ std::optional TypeReducer::intersectionType(TypeId left, TypeId right) return std::nullopt; // *pending* & T ~ *pending* & T else if (get(right)) return std::nullopt; // T & *pending* ~ T & *pending* + else if (auto [utl, utr] = get2(left, right); utl && utr) + { + std::vector parts; + for (TypeId optionl : utl) + { + for (TypeId optionr : utr) + parts.push_back(apply(&TypeReducer::intersectionType, optionl, optionr)); + } + + return reduce(flatten(std::move(parts))); // (T | U) & (A | B) ~ (T & A) | (T & B) | (U & A) | (U & B) + } else if (auto ut = get(left)) return reduce(distribute(begin(ut), end(ut), &TypeReducer::intersectionType, right)); // (A | B) & T ~ (A & T) | (B & T) else if (get(right)) @@ -789,6 +797,36 @@ std::optional TypeReducer::unionType(TypeId left, TypeId right) return reduce(distribute(begin(it), end(it), &TypeReducer::unionType, left)); // ~T | (A & B) ~ (~T | A) & (~T | B) else if (auto [it, nt] = get2(left, right); it && nt) return unionType(right, left); // (A & B) | ~T ~ ~T | (A & B) + else if (auto it = get(left)) + { + bool didReduce = false; + std::vector parts; + for (TypeId part : it) + { + auto nt = get(part); + if (!nt) + { + parts.push_back(part); + continue; + } + + auto redex = unionType(part, right); + if (redex && get(*redex)) + { + didReduce = true; + continue; + } + + parts.push_back(part); + } + + if (didReduce) + return flatten(std::move(parts)); // (T & ~nil) | nil ~ T + else + return std::nullopt; // (T & ~nil) | U + } + else if (get(right)) + return unionType(right, left); // A | (T & U) ~ (T & U) | A else if (auto [nl, nr] = get2(left, right); nl && nr) { // These should've been reduced already. diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index aba642714..b53401dce 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -18,10 +18,9 @@ LUAU_FASTINT(LuauTypeInferTypePackLoopLimit) LUAU_FASTFLAG(LuauErrorRecoveryType) LUAU_FASTFLAGVARIABLE(LuauInstantiateInSubtyping, false) -LUAU_FASTFLAGVARIABLE(LuauScalarShapeUnifyToMtOwner2, false) LUAU_FASTFLAGVARIABLE(LuauUninhabitedSubAnything2, false) LUAU_FASTFLAGVARIABLE(LuauMaintainScopesInUnifier, false) -LUAU_FASTFLAGVARIABLE(LuauTableUnifyInstantiationFix, false) +LUAU_FASTFLAGVARIABLE(LuauTinyUnifyNormalsFix, false) LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) LUAU_FASTFLAG(LuauNegatedFunctionTypes) @@ -108,7 +107,7 @@ struct PromoteTypeLevels final : TypeOnceVisitor // Surprise, it's actually a BoundTypePack that hasn't been committed yet. // Calling getMutable on this will trigger an assertion. - if (FFlag::LuauScalarShapeUnifyToMtOwner2 && !log.is(ty)) + if (!log.is(ty)) return true; promote(ty, log.getMutable(ty)); @@ -126,7 +125,7 @@ struct PromoteTypeLevels final : TypeOnceVisitor // Surprise, it's actually a BoundTypePack that hasn't been committed yet. // Calling getMutable on this will trigger an assertion. - if (FFlag::LuauScalarShapeUnifyToMtOwner2 && !log.is(ty)) + if (!log.is(ty)) return true; promote(ty, log.getMutable(ty)); @@ -690,6 +689,31 @@ void Unifier::tryUnifyUnionWithType(TypeId subTy, const UnionType* subUnion, Typ } } +struct BlockedTypeFinder : TypeOnceVisitor +{ + std::unordered_set blockedTypes; + + bool visit(TypeId ty, const BlockedType&) override + { + blockedTypes.insert(ty); + return true; + } +}; + +bool Unifier::blockOnBlockedTypes(TypeId subTy, TypeId superTy) +{ + BlockedTypeFinder blockedTypeFinder; + blockedTypeFinder.traverse(subTy); + blockedTypeFinder.traverse(superTy); + if (!blockedTypeFinder.blockedTypes.empty()) + { + blockedTypes.insert(end(blockedTypes), begin(blockedTypeFinder.blockedTypes), end(blockedTypeFinder.blockedTypes)); + return true; + } + + return false; +} + void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionType* uv, bool cacheEnabled, bool isFunctionCall) { // T <: A | B if T <: A or T <: B @@ -788,6 +812,11 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp } else if (!found && normalize) { + // We cannot normalize a type that contains blocked types. We have to + // stop for now if we find any. + if (blockOnBlockedTypes(subTy, superTy)) + return; + // It is possible that T <: A | B even though T (superTable)) - innerState.tryUnifyWithMetatable(subTable, superTable, /* reversed */ false); - else if (get(subTable)) - innerState.tryUnifyWithMetatable(superTable, subTable, /* reversed */ true); + + if (FFlag::LuauTinyUnifyNormalsFix) + innerState.tryUnify(subTable, superTable); else - innerState.tryUnifyTables(subTable, superTable); + { + if (get(superTable)) + innerState.tryUnifyWithMetatable(subTable, superTable, /* reversed */ false); + else if (get(subTable)) + innerState.tryUnifyWithMetatable(superTable, subTable, /* reversed */ true); + else + innerState.tryUnifyTables(subTable, superTable); + } + if (innerState.errors.empty()) { found = true; @@ -1782,7 +1828,6 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) TypeId activeSubTy = subTy; TableType* superTable = log.getMutable(superTy); TableType* subTable = log.getMutable(subTy); - TableType* instantiatedSubTable = subTable; // TODO: remove with FFlagLuauTableUnifyInstantiationFix if (!superTable || !subTable) ice("passed non-table types to unifyTables"); @@ -1799,16 +1844,8 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) std::optional instantiated = instantiation.substitute(subTy); if (instantiated.has_value()) { - if (FFlag::LuauTableUnifyInstantiationFix) - { - activeSubTy = *instantiated; - subTable = log.getMutable(activeSubTy); - } - else - { - subTable = log.getMutable(*instantiated); - instantiatedSubTable = subTable; - } + activeSubTy = *instantiated; + subTable = log.getMutable(activeSubTy); if (!subTable) ice("instantiation made a table type into a non-table type in tryUnifyTables"); @@ -1910,21 +1947,18 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) // Recursive unification can change the txn log, and invalidate the old // table. If we detect that this has happened, we start over, with the updated // txn log. - TypeId superTyNew = FFlag::LuauScalarShapeUnifyToMtOwner2 ? log.follow(superTy) : superTy; - TypeId subTyNew = FFlag::LuauScalarShapeUnifyToMtOwner2 ? log.follow(activeSubTy) : activeSubTy; + TypeId superTyNew = log.follow(superTy); + TypeId subTyNew = log.follow(activeSubTy); - if (FFlag::LuauScalarShapeUnifyToMtOwner2) - { - // If one of the types stopped being a table altogether, we need to restart from the top - if ((superTy != superTyNew || activeSubTy != subTyNew) && errors.empty()) - return tryUnify(subTy, superTy, false, isIntersection); - } + // If one of the types stopped being a table altogether, we need to restart from the top + if ((superTy != superTyNew || activeSubTy != subTyNew) && errors.empty()) + return tryUnify(subTy, superTy, false, isIntersection); // Otherwise, restart only the table unification TableType* newSuperTable = log.getMutable(superTyNew); TableType* newSubTable = log.getMutable(subTyNew); - if (superTable != newSuperTable || (subTable != newSubTable && (FFlag::LuauTableUnifyInstantiationFix || subTable != instantiatedSubTable))) + if (superTable != newSuperTable || subTable != newSubTable) { if (errors.empty()) return tryUnifyTables(subTy, superTy, isIntersection); @@ -1981,15 +2015,12 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) else extraProperties.push_back(name); - TypeId superTyNew = FFlag::LuauScalarShapeUnifyToMtOwner2 ? log.follow(superTy) : superTy; - TypeId subTyNew = FFlag::LuauScalarShapeUnifyToMtOwner2 ? log.follow(activeSubTy) : activeSubTy; + TypeId superTyNew = log.follow(superTy); + TypeId subTyNew = log.follow(activeSubTy); - if (FFlag::LuauScalarShapeUnifyToMtOwner2) - { - // If one of the types stopped being a table altogether, we need to restart from the top - if ((superTy != superTyNew || activeSubTy != subTyNew) && errors.empty()) - return tryUnify(subTy, superTy, false, isIntersection); - } + // If one of the types stopped being a table altogether, we need to restart from the top + if ((superTy != superTyNew || activeSubTy != subTyNew) && errors.empty()) + return tryUnify(subTy, superTy, false, isIntersection); // Recursive unification can change the txn log, and invalidate the old // table. If we detect that this has happened, we start over, with the updated @@ -1997,7 +2028,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) TableType* newSuperTable = log.getMutable(superTyNew); TableType* newSubTable = log.getMutable(subTyNew); - if (superTable != newSuperTable || (subTable != newSubTable && (FFlag::LuauTableUnifyInstantiationFix || subTable != instantiatedSubTable))) + if (superTable != newSuperTable || subTable != newSubTable) { if (errors.empty()) return tryUnifyTables(subTy, superTy, isIntersection); @@ -2050,19 +2081,11 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) } // Changing the indexer can invalidate the table pointers. - if (FFlag::LuauScalarShapeUnifyToMtOwner2) - { - superTable = log.getMutable(log.follow(superTy)); - subTable = log.getMutable(log.follow(activeSubTy)); + superTable = log.getMutable(log.follow(superTy)); + subTable = log.getMutable(log.follow(activeSubTy)); - if (!superTable || !subTable) - return; - } - else - { - superTable = log.getMutable(superTy); - subTable = log.getMutable(activeSubTy); - } + if (!superTable || !subTable) + return; if (!missingProperties.empty()) { @@ -2135,18 +2158,15 @@ void Unifier::tryUnifyScalarShape(TypeId subTy, TypeId superTy, bool reversed) Unifier child = makeChildUnifier(); child.tryUnify_(ty, superTy); - if (FFlag::LuauScalarShapeUnifyToMtOwner2) - { - // To perform subtype <: free table unification, we have tried to unify (subtype's metatable) <: free table - // There is a chance that it was unified with the origial subtype, but then, (subtype's metatable) <: subtype could've failed - // Here we check if we have a new supertype instead of the original free table and try original subtype <: new supertype check - TypeId newSuperTy = child.log.follow(superTy); + // To perform subtype <: free table unification, we have tried to unify (subtype's metatable) <: free table + // There is a chance that it was unified with the origial subtype, but then, (subtype's metatable) <: subtype could've failed + // Here we check if we have a new supertype instead of the original free table and try original subtype <: new supertype check + TypeId newSuperTy = child.log.follow(superTy); - if (superTy != newSuperTy && canUnify(subTy, newSuperTy).empty()) - { - log.replace(superTy, BoundType{subTy}); - return; - } + if (superTy != newSuperTy && canUnify(subTy, newSuperTy).empty()) + { + log.replace(superTy, BoundType{subTy}); + return; } if (auto e = hasUnificationTooComplex(child.errors)) @@ -2156,13 +2176,10 @@ void Unifier::tryUnifyScalarShape(TypeId subTy, TypeId superTy, bool reversed) log.concat(std::move(child.log)); - if (FFlag::LuauScalarShapeUnifyToMtOwner2) - { - // To perform subtype <: free table unification, we have tried to unify (subtype's metatable) <: free table - // We return success because subtype <: free table which means that correct unification is to replace free table with the subtype - if (child.errors.empty()) - log.replace(superTy, BoundType{subTy}); - } + // To perform subtype <: free table unification, we have tried to unify (subtype's metatable) <: free table + // We return success because subtype <: free table which means that correct unification is to replace free table with the subtype + if (child.errors.empty()) + log.replace(superTy, BoundType{subTy}); return; } @@ -2379,6 +2396,11 @@ void Unifier::tryUnifyNegations(TypeId subTy, TypeId superTy) if (!log.get(subTy) && !log.get(superTy)) ice("tryUnifyNegations superTy or subTy must be a negation type"); + // We cannot normalize a type that contains blocked types. We have to + // stop for now if we find any. + if (blockOnBlockedTypes(subTy, superTy)) + return; + const NormalizedType* subNorm = normalizer->normalize(subTy); const NormalizedType* superNorm = normalizer->normalize(superTy); if (!subNorm || !superNorm) diff --git a/CLI/Analyze.cpp b/CLI/Analyze.cpp index 6257e2f3a..d6f1822d4 100644 --- a/CLI/Analyze.cpp +++ b/CLI/Analyze.cpp @@ -121,6 +121,7 @@ static void displayHelp(const char* argv0) static int assertionHandler(const char* expr, const char* file, int line, const char* function) { printf("%s(%d): ASSERTION FAILED: %s\n", file, line, expr); + fflush(stdout); return 1; } @@ -267,8 +268,8 @@ int main(int argc, char** argv) CliConfigResolver configResolver(mode); Luau::Frontend frontend(&fileResolver, &configResolver, frontendOptions); - Luau::registerBuiltinGlobals(frontend.typeChecker); - Luau::freeze(frontend.typeChecker.globalTypes); + Luau::registerBuiltinGlobals(frontend.typeChecker, frontend.globals); + Luau::freeze(frontend.globals.globalTypes); #ifdef CALLGRIND CALLGRIND_ZERO_STATS; diff --git a/CodeGen/include/Luau/IrAnalysis.h b/CodeGen/include/Luau/IrAnalysis.h index d3e1a9334..21fa755ca 100644 --- a/CodeGen/include/Luau/IrAnalysis.h +++ b/CodeGen/include/Luau/IrAnalysis.h @@ -1,7 +1,9 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include #include +#include #include @@ -22,5 +24,59 @@ std::pair getLiveInOutValueCount(IrFunction& function, IrBlo uint32_t getLiveInValueCount(IrFunction& function, IrBlock& block); uint32_t getLiveOutValueCount(IrFunction& function, IrBlock& block); +struct RegisterSet +{ + std::bitset<256> regs; + + // If variadic sequence is active, we track register from which it starts + bool varargSeq = false; + uint8_t varargStart = 0; +}; + +struct CfgInfo +{ + std::vector predecessors; + std::vector predecessorsOffsets; + + std::vector successors; + std::vector successorsOffsets; + + std::vector in; + std::vector out; + + RegisterSet captured; +}; + +void computeCfgInfo(IrFunction& function); + +struct BlockIteratorWrapper +{ + uint32_t* itBegin = nullptr; + uint32_t* itEnd = nullptr; + + bool empty() const + { + return itBegin == itEnd; + } + + size_t size() const + { + return size_t(itEnd - itBegin); + } + + uint32_t* begin() const + { + return itBegin; + } + + uint32_t* end() const + { + return itEnd; + } +}; + +BlockIteratorWrapper predecessors(CfgInfo& cfg, uint32_t blockIdx); +BlockIteratorWrapper successors(CfgInfo& cfg, uint32_t blockIdx); + } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/include/Luau/IrData.h b/CodeGen/include/Luau/IrData.h index 049d700af..439abb9bd 100644 --- a/CodeGen/include/Luau/IrData.h +++ b/CodeGen/include/Luau/IrData.h @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include "Luau/IrAnalysis.h" #include "Luau/Label.h" #include "Luau/RegisterX64.h" #include "Luau/RegisterA64.h" @@ -261,6 +262,7 @@ enum class IrCmd : uint8_t // A: Rn (value start) // B: unsigned int (number of registers to go over) // Note: result is stored in the register specified in 'A' + // Note: all referenced registers might be modified in the operation CONCAT, // Load function upvalue into stack slot @@ -382,16 +384,16 @@ enum class IrCmd : uint8_t LOP_RETURN, // Adjust loop variables for one iteration of a generic for loop, jump back to the loop header if loop needs to continue - // A: Rn (loop variable start, updates Rn+2 Rn+3 Rn+4) - // B: int (loop variable count, is more than 2, additional registers are set to nil) + // A: Rn (loop variable start, updates Rn+2 and 'B' number of registers starting from Rn+3) + // B: int (loop variable count, if more than 2, registers starting from Rn+5 are set to nil) // C: block (repeat) // D: block (exit) LOP_FORGLOOP, // Handle LOP_FORGLOOP fallback when variable being iterated is not a table // A: unsigned int (bytecode instruction index) - // B: Rn (loop state start, updates Rn+2 Rn+3 Rn+4 Rn+5) - // C: int (extra variable count or -1 for ipairs-style iteration) + // B: Rn (loop state start, updates Rn+2 and 'C' number of registers starting from Rn+3) + // C: int (loop variable count and a MSB set when it's an ipairs-like iteration loop) // D: block (repeat) // E: block (exit) LOP_FORGLOOP_FALLBACK, @@ -638,6 +640,8 @@ struct IrFunction Proto* proto = nullptr; + CfgInfo cfg; + IrBlock& blockOp(IrOp op) { LUAU_ASSERT(op.kind == IrOpKind::Block); diff --git a/CodeGen/include/Luau/IrDump.h b/CodeGen/include/Luau/IrDump.h index 47a5f9e92..a6329ecf5 100644 --- a/CodeGen/include/Luau/IrDump.h +++ b/CodeGen/include/Luau/IrDump.h @@ -11,6 +11,8 @@ namespace Luau namespace CodeGen { +struct CfgInfo; + const char* getCmdName(IrCmd cmd); const char* getBlockKindName(IrBlockKind kind); @@ -19,6 +21,7 @@ struct IrToStringContext std::string& result; std::vector& blocks; std::vector& constants; + CfgInfo& cfg; }; void toString(IrToStringContext& ctx, const IrInst& inst, uint32_t index); @@ -27,10 +30,10 @@ void toString(IrToStringContext& ctx, IrOp op); void toString(std::string& result, IrConst constant); -void toStringDetailed(IrToStringContext& ctx, const IrInst& inst, uint32_t index); -void toStringDetailed(IrToStringContext& ctx, const IrBlock& block, uint32_t index); // Block title +void toStringDetailed(IrToStringContext& ctx, const IrInst& inst, uint32_t index, bool includeUseInfo); +void toStringDetailed(IrToStringContext& ctx, const IrBlock& block, uint32_t index, bool includeUseInfo); // Block title -std::string toString(IrFunction& function, bool includeDetails); +std::string toString(IrFunction& function, bool includeUseInfo); std::string dump(IrFunction& function); diff --git a/CodeGen/include/Luau/IrUtils.h b/CodeGen/include/Luau/IrUtils.h index 153cf7ade..3b14a8c80 100644 --- a/CodeGen/include/Luau/IrUtils.h +++ b/CodeGen/include/Luau/IrUtils.h @@ -183,9 +183,6 @@ void kill(IrFunction& function, uint32_t start, uint32_t end); // Remove a block, including all instructions inside void kill(IrFunction& function, IrBlock& block); -void removeUse(IrFunction& function, IrInst& inst); -void removeUse(IrFunction& function, IrBlock& block); - // Replace a single operand and update use counts (can cause chain removal of dead code) void replace(IrFunction& function, IrOp& original, IrOp replacement); diff --git a/CodeGen/src/CodeGen.cpp b/CodeGen/src/CodeGen.cpp index 51bf17461..c794972d0 100644 --- a/CodeGen/src/CodeGen.cpp +++ b/CodeGen/src/CodeGen.cpp @@ -68,9 +68,24 @@ static NativeProto* assembleFunction(X64::AssemblyBuilderX64& build, NativeState if (options.includeAssembly || options.includeIr) { if (proto->debugname) - build.logAppend("; function %s()", getstr(proto->debugname)); + build.logAppend("; function %s(", getstr(proto->debugname)); else - build.logAppend("; function()"); + build.logAppend("; function("); + + for (int i = 0; i < proto->numparams; i++) + { + LocVar* var = proto->locvars ? &proto->locvars[proto->sizelocvars - proto->numparams + i] : nullptr; + + if (var && var->varname) + build.logAppend("%s%s", i == 0 ? "" : ", ", getstr(var->varname)); + else + build.logAppend("%s$arg%d", i == 0 ? "" : ", ", i); + } + + if (proto->numparams != 0 && proto->is_vararg) + build.logAppend(", ...)"); + else + build.logAppend(")"); if (proto->linedefined >= 0) build.logAppend(" line %d\n", proto->linedefined); @@ -90,6 +105,10 @@ static NativeProto* assembleFunction(X64::AssemblyBuilderX64& build, NativeState constPropInBlockChains(builder); } + // TODO: cfg info has to be computed earlier to use in optimizations + // It's done here to appear in text output and to measure performance impact on code generation + computeCfgInfo(builder.function); + optimizeMemoryOperandsX64(builder.function); X64::IrLoweringX64 lowering(build, helpers, data, proto, builder.function); diff --git a/CodeGen/src/IrAnalysis.cpp b/CodeGen/src/IrAnalysis.cpp index aa3e19f7e..dc7d771ec 100644 --- a/CodeGen/src/IrAnalysis.cpp +++ b/CodeGen/src/IrAnalysis.cpp @@ -5,6 +5,10 @@ #include "Luau/IrData.h" #include "Luau/IrUtils.h" +#include "lobject.h" + +#include + #include namespace Luau @@ -116,5 +120,518 @@ uint32_t getLiveOutValueCount(IrFunction& function, IrBlock& block) return getLiveInOutValueCount(function, block).second; } +static void requireVariadicSequence(RegisterSet& sourceRs, const RegisterSet& defRs, uint8_t varargStart) +{ + if (!defRs.varargSeq) + { + LUAU_ASSERT(!sourceRs.varargSeq || sourceRs.varargStart == varargStart); + + sourceRs.varargSeq = true; + sourceRs.varargStart = varargStart; + } + else + { + // Variadic use sequence might include registers before def sequence + for (int i = varargStart; i < defRs.varargStart; i++) + { + if (!defRs.regs.test(i)) + sourceRs.regs.set(i); + } + } +} + +static RegisterSet computeBlockLiveInRegSet(IrFunction& function, const IrBlock& block, RegisterSet& defRs, std::bitset<256>& capturedRegs) +{ + RegisterSet inRs; + + auto def = [&](IrOp op, int offset = 0) { + LUAU_ASSERT(op.kind == IrOpKind::VmReg); + defRs.regs.set(op.index + offset, true); + }; + + auto use = [&](IrOp op, int offset = 0) { + LUAU_ASSERT(op.kind == IrOpKind::VmReg); + if (!defRs.regs.test(op.index + offset)) + inRs.regs.set(op.index + offset, true); + }; + + auto maybeDef = [&](IrOp op) { + if (op.kind == IrOpKind::VmReg) + defRs.regs.set(op.index, true); + }; + + auto maybeUse = [&](IrOp op) { + if (op.kind == IrOpKind::VmReg) + { + if (!defRs.regs.test(op.index)) + inRs.regs.set(op.index, true); + } + }; + + auto defVarargs = [&](uint8_t varargStart) { + defRs.varargSeq = true; + defRs.varargStart = varargStart; + }; + + auto useVarargs = [&](uint8_t varargStart) { + requireVariadicSequence(inRs, defRs, varargStart); + + // Variadic sequence has been consumed + defRs.varargSeq = false; + defRs.varargStart = 0; + }; + + auto defRange = [&](int start, int count) { + if (count == -1) + { + defVarargs(start); + } + else + { + for (int i = start; i < start + count; i++) + defRs.regs.set(i, true); + } + }; + + auto useRange = [&](int start, int count) { + if (count == -1) + { + useVarargs(start); + } + else + { + for (int i = start; i < start + count; i++) + { + if (!defRs.regs.test(i)) + inRs.regs.set(i, true); + } + } + }; + + for (uint32_t instIdx = block.start; instIdx <= block.finish; instIdx++) + { + const IrInst& inst = function.instructions[instIdx]; + + // For correct analysis, all instruction uses must be handled before handling the definitions + switch (inst.cmd) + { + case IrCmd::LOAD_TAG: + case IrCmd::LOAD_POINTER: + case IrCmd::LOAD_DOUBLE: + case IrCmd::LOAD_INT: + case IrCmd::LOAD_TVALUE: + maybeUse(inst.a); // Argument can also be a VmConst + break; + case IrCmd::STORE_TAG: + case IrCmd::STORE_POINTER: + case IrCmd::STORE_DOUBLE: + case IrCmd::STORE_INT: + case IrCmd::STORE_TVALUE: + maybeDef(inst.a); // Argument can also be a pointer value + break; + case IrCmd::JUMP_IF_TRUTHY: + case IrCmd::JUMP_IF_FALSY: + use(inst.a); + break; + case IrCmd::JUMP_CMP_ANY: + use(inst.a); + use(inst.b); + break; + // A <- B, C + case IrCmd::DO_ARITH: + case IrCmd::GET_TABLE: + case IrCmd::SET_TABLE: + use(inst.b); + maybeUse(inst.c); // Argument can also be a VmConst + + def(inst.a); + break; + // A <- B + case IrCmd::DO_LEN: + use(inst.b); + + def(inst.a); + break; + case IrCmd::GET_IMPORT: + def(inst.a); + break; + case IrCmd::CONCAT: + useRange(inst.a.index, function.uintOp(inst.b)); + + defRange(inst.a.index, function.uintOp(inst.b)); + break; + case IrCmd::GET_UPVALUE: + def(inst.a); + break; + case IrCmd::SET_UPVALUE: + use(inst.b); + break; + case IrCmd::PREPARE_FORN: + use(inst.a); + use(inst.b); + use(inst.c); + + def(inst.a); + def(inst.b); + def(inst.c); + break; + case IrCmd::INTERRUPT: + break; + case IrCmd::BARRIER_OBJ: + case IrCmd::BARRIER_TABLE_FORWARD: + use(inst.b); + break; + case IrCmd::CLOSE_UPVALS: + // Closing an upvalue should be counted as a register use (it copies the fresh register value) + // But we lack the required information about the specific set of registers that are affected + // Because we don't plan to optimize captured registers atm, we skip full dataflow analysis for them right now + break; + case IrCmd::CAPTURE: + maybeUse(inst.a); + + if (function.boolOp(inst.b)) + capturedRegs.set(inst.a.index, true); + break; + case IrCmd::LOP_SETLIST: + use(inst.b); + useRange(inst.c.index, function.intOp(inst.d)); + break; + case IrCmd::LOP_NAMECALL: + use(inst.c); + + defRange(inst.b.index, 2); + break; + case IrCmd::LOP_CALL: + use(inst.b); + useRange(inst.b.index + 1, function.intOp(inst.c)); + + defRange(inst.b.index, function.intOp(inst.d)); + break; + case IrCmd::LOP_RETURN: + useRange(inst.b.index, function.intOp(inst.c)); + break; + case IrCmd::FASTCALL: + case IrCmd::INVOKE_FASTCALL: + if (int count = function.intOp(inst.e); count != -1) + { + if (count >= 3) + { + LUAU_ASSERT(inst.d.kind == IrOpKind::VmReg && inst.d.index == inst.c.index + 1); + + useRange(inst.c.index, count); + } + else + { + if (count >= 1) + use(inst.c); + + if (count >= 2) + maybeUse(inst.d); // Argument can also be a VmConst + } + } + else + { + useVarargs(inst.c.index); + } + + defRange(inst.b.index, function.intOp(inst.f)); + break; + case IrCmd::LOP_FORGLOOP: + // First register is not used by instruction, we check that it's still 'nil' with CHECK_TAG + use(inst.a, 1); + use(inst.a, 2); + + def(inst.a, 2); + defRange(inst.a.index + 3, function.intOp(inst.b)); + break; + case IrCmd::LOP_FORGLOOP_FALLBACK: + useRange(inst.b.index, 3); + + def(inst.b, 2); + defRange(inst.b.index + 3, uint8_t(function.intOp(inst.c))); // ignore most significant bit + break; + case IrCmd::LOP_FORGPREP_XNEXT_FALLBACK: + use(inst.b); + break; + // B <- C, D + case IrCmd::LOP_AND: + case IrCmd::LOP_OR: + use(inst.c); + use(inst.d); + + def(inst.b); + break; + // B <- C + case IrCmd::LOP_ANDK: + case IrCmd::LOP_ORK: + use(inst.c); + + def(inst.b); + break; + case IrCmd::FALLBACK_GETGLOBAL: + def(inst.b); + break; + case IrCmd::FALLBACK_SETGLOBAL: + use(inst.b); + break; + case IrCmd::FALLBACK_GETTABLEKS: + use(inst.c); + + def(inst.b); + break; + case IrCmd::FALLBACK_SETTABLEKS: + use(inst.b); + use(inst.c); + break; + case IrCmd::FALLBACK_NAMECALL: + use(inst.c); + + defRange(inst.b.index, 2); + break; + case IrCmd::FALLBACK_PREPVARARGS: + // No effect on explicitly referenced registers + break; + case IrCmd::FALLBACK_GETVARARGS: + defRange(inst.b.index, function.intOp(inst.c)); + break; + case IrCmd::FALLBACK_NEWCLOSURE: + def(inst.b); + break; + case IrCmd::FALLBACK_DUPCLOSURE: + def(inst.b); + break; + case IrCmd::FALLBACK_FORGPREP: + use(inst.b); + + defRange(inst.b.index, 3); + break; + case IrCmd::ADJUST_STACK_TO_REG: + case IrCmd::ADJUST_STACK_TO_TOP: + // While these can be considered as vararg producers and consumers, it is already handled in fastcall instruction + break; + + default: + break; + } + } + + return inRs; +} + +// The algorithm used here is commonly known as backwards data-flow analysis. +// For each block, we track 'upward-exposed' (live-in) uses of registers - a use of a register that hasn't been defined in the block yet. +// We also track the set of registers that were defined in the block. +// When initial live-in sets of registers are computed, propagation of those uses upwards through predecessors is performed. +// If predecessor doesn't define the register, we have to add it to the live-in set. +// Extending the set of live-in registers of a block requires re-checking of that block. +// Propagation runs iteratively, using a worklist of blocks to visit until a fixed point is reached. +// This algorithm can be easily extended to cover phi instructions, but we don't use those yet. +static void computeCfgLiveInOutRegSets(IrFunction& function) +{ + CfgInfo& info = function.cfg; + + // Try to compute Luau VM register use-def info + info.in.resize(function.blocks.size()); + info.out.resize(function.blocks.size()); + + // Captured registers are tracked for the whole function + // It should be possible to have a more precise analysis for them in the future + std::bitset<256> capturedRegs; + + std::vector defRss; + defRss.resize(function.blocks.size()); + + // First we compute live-in set of each block + for (size_t blockIdx = 0; blockIdx < function.blocks.size(); blockIdx++) + { + const IrBlock& block = function.blocks[blockIdx]; + + if (block.kind == IrBlockKind::Dead) + continue; + + info.in[blockIdx] = computeBlockLiveInRegSet(function, block, defRss[blockIdx], capturedRegs); + } + + info.captured.regs = capturedRegs; + + // With live-in sets ready, we can arrive at a fixed point for both in/out registers by requesting required registers from predecessors + std::vector worklist; + + std::vector inWorklist; + inWorklist.resize(function.blocks.size(), false); + + // We will have to visit each block at least once, so we add all of them to the worklist immediately + for (size_t blockIdx = 0; blockIdx < function.blocks.size(); blockIdx++) + { + const IrBlock& block = function.blocks[blockIdx]; + + if (block.kind == IrBlockKind::Dead) + continue; + + worklist.push_back(uint32_t(blockIdx)); + inWorklist[blockIdx] = true; + } + + while (!worklist.empty()) + { + uint32_t blockIdx = worklist.back(); + worklist.pop_back(); + inWorklist[blockIdx] = false; + + IrBlock& curr = function.blocks[blockIdx]; + RegisterSet& inRs = info.in[blockIdx]; + RegisterSet& outRs = info.out[blockIdx]; + RegisterSet& defRs = defRss[blockIdx]; + + // Current block has to provide all registers in successor blocks + for (uint32_t succIdx : successors(info, blockIdx)) + { + IrBlock& succ = function.blocks[succIdx]; + + // This is a step away from the usual definition of live range flow through CFG + // Exit from a regular block to a fallback block is not considered a block terminator + // This is because fallback blocks define an alternative implementation of the same operations + // This can cause the current block to define more registers that actually were available at fallback entry + if (curr.kind != IrBlockKind::Fallback && succ.kind == IrBlockKind::Fallback) + continue; + + const RegisterSet& succRs = info.in[succIdx]; + + outRs.regs |= succRs.regs; + + if (succRs.varargSeq) + { + LUAU_ASSERT(!outRs.varargSeq || outRs.varargStart == succRs.varargStart); + + outRs.varargSeq = true; + outRs.varargStart = succRs.varargStart; + } + } + + RegisterSet oldInRs = inRs; + + // If current block didn't define a live-out, it has to be live-in + inRs.regs |= outRs.regs & ~defRs.regs; + + if (outRs.varargSeq) + requireVariadicSequence(inRs, defRs, outRs.varargStart); + + // If we have new live-ins, we have to notify all predecessors + // We don't allow changes to the start of the variadic sequence, so we skip checking that member + if (inRs.regs != oldInRs.regs || inRs.varargSeq != oldInRs.varargSeq) + { + for (uint32_t predIdx : predecessors(info, blockIdx)) + { + if (!inWorklist[predIdx]) + { + worklist.push_back(predIdx); + inWorklist[predIdx] = true; + } + } + } + } + + // If Proto data is available, validate that entry block arguments match required registers + if (function.proto) + { + RegisterSet& entryIn = info.in[0]; + + LUAU_ASSERT(!entryIn.varargSeq); + + for (size_t i = 0; i < entryIn.regs.size(); i++) + LUAU_ASSERT(!entryIn.regs.test(i) || i < function.proto->numparams); + } +} + +static void computeCfgBlockEdges(IrFunction& function) +{ + CfgInfo& info = function.cfg; + + // Compute predecessors block edges + info.predecessorsOffsets.reserve(function.blocks.size()); + info.successorsOffsets.reserve(function.blocks.size()); + + int edgeCount = 0; + + for (const IrBlock& block : function.blocks) + { + info.predecessorsOffsets.push_back(edgeCount); + edgeCount += block.useCount; + } + + info.predecessors.resize(edgeCount); + info.successors.resize(edgeCount); + + edgeCount = 0; + + for (size_t blockIdx = 0; blockIdx < function.blocks.size(); blockIdx++) + { + const IrBlock& block = function.blocks[blockIdx]; + + info.successorsOffsets.push_back(edgeCount); + + if (block.kind == IrBlockKind::Dead) + continue; + + for (uint32_t instIdx = block.start; instIdx <= block.finish; instIdx++) + { + const IrInst& inst = function.instructions[instIdx]; + + auto checkOp = [&](IrOp op) { + if (op.kind == IrOpKind::Block) + { + // We use a trick here, where we use the starting offset of the predecessor list as the position where to write next predecessor + // The values will be adjusted back in a separate loop later + info.predecessors[info.predecessorsOffsets[op.index]++] = uint32_t(blockIdx); + + info.successors[edgeCount++] = op.index; + } + }; + + checkOp(inst.a); + checkOp(inst.b); + checkOp(inst.c); + checkOp(inst.d); + checkOp(inst.e); + checkOp(inst.f); + } + } + + // Offsets into the predecessor list were used as iterators in the previous loop + // To adjust them back, block use count is subtracted (predecessor count is equal to how many uses block has) + for (size_t blockIdx = 0; blockIdx < function.blocks.size(); blockIdx++) + { + const IrBlock& block = function.blocks[blockIdx]; + + info.predecessorsOffsets[blockIdx] -= block.useCount; + } +} + +void computeCfgInfo(IrFunction& function) +{ + computeCfgBlockEdges(function); + computeCfgLiveInOutRegSets(function); +} + +BlockIteratorWrapper predecessors(CfgInfo& cfg, uint32_t blockIdx) +{ + LUAU_ASSERT(blockIdx < cfg.predecessorsOffsets.size()); + + uint32_t start = cfg.predecessorsOffsets[blockIdx]; + uint32_t end = blockIdx + 1 < cfg.predecessorsOffsets.size() ? cfg.predecessorsOffsets[blockIdx + 1] : uint32_t(cfg.predecessors.size()); + + return BlockIteratorWrapper{cfg.predecessors.data() + start, cfg.predecessors.data() + end}; +} + +BlockIteratorWrapper successors(CfgInfo& cfg, uint32_t blockIdx) +{ + LUAU_ASSERT(blockIdx < cfg.successorsOffsets.size()); + + uint32_t start = cfg.successorsOffsets[blockIdx]; + uint32_t end = blockIdx + 1 < cfg.successorsOffsets.size() ? cfg.successorsOffsets[blockIdx + 1] : uint32_t(cfg.successors.size()); + + return BlockIteratorWrapper{cfg.successors.data() + start, cfg.successors.data() + end}; +} + } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/IrBuilder.cpp b/CodeGen/src/IrBuilder.cpp index 056ea6007..0a700dba6 100644 --- a/CodeGen/src/IrBuilder.cpp +++ b/CodeGen/src/IrBuilder.cpp @@ -256,7 +256,7 @@ void IrBuilder::translateInst(LuauOpcode op, const Instruction* pc, int i) translateInstDupTable(*this, pc, i); break; case LOP_SETLIST: - inst(IrCmd::LOP_SETLIST, constUint(i), vmReg(LUAU_INSN_A(*pc)), vmReg(LUAU_INSN_A(*pc)), constInt(LUAU_INSN_C(*pc) - 1), constUint(pc[1])); + inst(IrCmd::LOP_SETLIST, constUint(i), vmReg(LUAU_INSN_A(*pc)), vmReg(LUAU_INSN_B(*pc)), constInt(LUAU_INSN_C(*pc) - 1), constUint(pc[1])); break; case LOP_GETUPVAL: translateInstGetUpval(*this, pc, i); diff --git a/CodeGen/src/IrDump.cpp b/CodeGen/src/IrDump.cpp index cb203f7a7..2787fb11f 100644 --- a/CodeGen/src/IrDump.cpp +++ b/CodeGen/src/IrDump.cpp @@ -306,7 +306,7 @@ void toString(IrToStringContext& ctx, const IrInst& inst, uint32_t index) void toString(IrToStringContext& ctx, const IrBlock& block, uint32_t index) { - append(ctx.result, "%s_%u:", getBlockKindName(block.kind), index); + append(ctx.result, "%s_%u", getBlockKindName(block.kind), index); } void toString(IrToStringContext& ctx, IrOp op) @@ -362,33 +362,151 @@ void toString(std::string& result, IrConst constant) } } -void toStringDetailed(IrToStringContext& ctx, const IrInst& inst, uint32_t index) +void toStringDetailed(IrToStringContext& ctx, const IrInst& inst, uint32_t index, bool includeUseInfo) { size_t start = ctx.result.size(); toString(ctx, inst, index); - padToDetailColumn(ctx.result, start); - if (inst.useCount == 0 && hasSideEffects(inst.cmd)) - append(ctx.result, "; %%%u, has side-effects\n", index); + if (includeUseInfo) + { + padToDetailColumn(ctx.result, start); + + if (inst.useCount == 0 && hasSideEffects(inst.cmd)) + append(ctx.result, "; %%%u, has side-effects\n", index); + else + append(ctx.result, "; useCount: %d, lastUse: %%%u\n", inst.useCount, inst.lastUse); + } else - append(ctx.result, "; useCount: %d, lastUse: %%%u\n", inst.useCount, inst.lastUse); + { + ctx.result.append("\n"); + } +} + +static void appendBlockSet(IrToStringContext& ctx, BlockIteratorWrapper blocks) +{ + bool comma = false; + + for (uint32_t target : blocks) + { + if (comma) + append(ctx.result, ", "); + comma = true; + + toString(ctx, ctx.blocks[target], target); + } +} + +static void appendRegisterSet(IrToStringContext& ctx, const RegisterSet& rs) +{ + bool comma = false; + + for (size_t i = 0; i < rs.regs.size(); i++) + { + if (rs.regs.test(i)) + { + if (comma) + append(ctx.result, ", "); + comma = true; + + append(ctx.result, "R%d", int(i)); + } + } + + if (rs.varargSeq) + { + if (comma) + append(ctx.result, ", "); + + append(ctx.result, "R%d...", rs.varargStart); + } } -void toStringDetailed(IrToStringContext& ctx, const IrBlock& block, uint32_t index) +void toStringDetailed(IrToStringContext& ctx, const IrBlock& block, uint32_t index, bool includeUseInfo) { + // Report captured registers for entry block + if (block.useCount == 0 && block.kind != IrBlockKind::Dead && ctx.cfg.captured.regs.any()) + { + append(ctx.result, "; captured regs: "); + appendRegisterSet(ctx, ctx.cfg.captured); + append(ctx.result, "\n\n"); + } + size_t start = ctx.result.size(); toString(ctx, block, index); - padToDetailColumn(ctx.result, start); + append(ctx.result, ":"); + + if (includeUseInfo) + { + padToDetailColumn(ctx.result, start); - append(ctx.result, "; useCount: %d\n", block.useCount); + append(ctx.result, "; useCount: %d\n", block.useCount); + } + else + { + ctx.result.append("\n"); + } + + // Predecessor list + if (!ctx.cfg.predecessors.empty()) + { + BlockIteratorWrapper pred = predecessors(ctx.cfg, index); + + if (!pred.empty()) + { + append(ctx.result, "; predecessors: "); + + appendBlockSet(ctx, pred); + append(ctx.result, "\n"); + } + } + + // Successor list + if (!ctx.cfg.successors.empty()) + { + BlockIteratorWrapper succ = successors(ctx.cfg, index); + + if (!succ.empty()) + { + append(ctx.result, "; successors: "); + + appendBlockSet(ctx, succ); + append(ctx.result, "\n"); + } + } + + // Live-in VM regs + if (index < ctx.cfg.in.size()) + { + const RegisterSet& in = ctx.cfg.in[index]; + + if (in.regs.any() || in.varargSeq) + { + append(ctx.result, "; in regs: "); + appendRegisterSet(ctx, in); + append(ctx.result, "\n"); + } + } + + // Live-out VM regs + if (index < ctx.cfg.out.size()) + { + const RegisterSet& out = ctx.cfg.out[index]; + + if (out.regs.any() || out.varargSeq) + { + append(ctx.result, "; out regs: "); + appendRegisterSet(ctx, out); + append(ctx.result, "\n"); + } + } } -std::string toString(IrFunction& function, bool includeDetails) +std::string toString(IrFunction& function, bool includeUseInfo) { std::string result; - IrToStringContext ctx{result, function.blocks, function.constants}; + IrToStringContext ctx{result, function.blocks, function.constants, function.cfg}; for (size_t i = 0; i < function.blocks.size(); i++) { @@ -397,15 +515,7 @@ std::string toString(IrFunction& function, bool includeDetails) if (block.kind == IrBlockKind::Dead) continue; - if (includeDetails) - { - toStringDetailed(ctx, block, uint32_t(i)); - } - else - { - toString(ctx, block, uint32_t(i)); - ctx.result.append("\n"); - } + toStringDetailed(ctx, block, uint32_t(i), includeUseInfo); if (block.start == ~0u) { @@ -423,16 +533,7 @@ std::string toString(IrFunction& function, bool includeDetails) continue; append(ctx.result, " "); - - if (includeDetails) - { - toStringDetailed(ctx, inst, index); - } - else - { - toString(ctx, inst, index); - ctx.result.append("\n"); - } + toStringDetailed(ctx, inst, index, includeUseInfo); } append(ctx.result, "\n"); @@ -443,7 +544,7 @@ std::string toString(IrFunction& function, bool includeDetails) std::string dump(IrFunction& function) { - std::string result = toString(function, /* includeDetails */ true); + std::string result = toString(function, /* includeUseInfo */ true); printf("%s\n", result.c_str()); diff --git a/CodeGen/src/IrLoweringX64.cpp b/CodeGen/src/IrLoweringX64.cpp index 383375753..3b27d09fc 100644 --- a/CodeGen/src/IrLoweringX64.cpp +++ b/CodeGen/src/IrLoweringX64.cpp @@ -79,7 +79,7 @@ void IrLoweringX64::lower(AssemblyOptions options) } } - IrToStringContext ctx{build.text, function.blocks, function.constants}; + IrToStringContext ctx{build.text, function.blocks, function.constants, function.cfg}; // We use this to skip outlined fallback blocks from IR/asm text output size_t textSize = build.text.length(); @@ -112,7 +112,7 @@ void IrLoweringX64::lower(AssemblyOptions options) if (options.includeIr) { build.logAppend("# "); - toStringDetailed(ctx, block, blockIndex); + toStringDetailed(ctx, block, blockIndex, /* includeUseInfo */ true); } build.setLabel(block.label); @@ -145,7 +145,7 @@ void IrLoweringX64::lower(AssemblyOptions options) if (options.includeIr) { build.logAppend("# "); - toStringDetailed(ctx, inst, index); + toStringDetailed(ctx, inst, index, /* includeUseInfo */ true); } IrBlock& next = i + 1 < sortedBlocks.size() ? function.blocks[sortedBlocks[i + 1]] : dummy; @@ -416,7 +416,20 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) { inst.regX64 = regs.allocXmmRegOrReuse(index, {inst.a, inst.b}); - RegisterX64 lhs = regOp(inst.a); + ScopedRegX64 optLhsTmp{regs}; + RegisterX64 lhs; + + if (inst.a.kind == IrOpKind::Constant) + { + optLhsTmp.alloc(SizeX64::xmmword); + + build.vmovsd(optLhsTmp.reg, memRegDoubleOp(inst.a)); + lhs = optLhsTmp.reg; + } + else + { + lhs = regOp(inst.a); + } if (inst.b.kind == IrOpKind::Inst) { @@ -444,14 +457,15 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) { inst.regX64 = regs.allocXmmRegOrReuse(index, {inst.a, inst.b}); - ScopedRegX64 tmp{regs, SizeX64::xmmword}; - + ScopedRegX64 optLhsTmp{regs}; RegisterX64 lhs; if (inst.a.kind == IrOpKind::Constant) { - build.vmovsd(tmp.reg, memRegDoubleOp(inst.a)); - lhs = tmp.reg; + optLhsTmp.alloc(SizeX64::xmmword); + + build.vmovsd(optLhsTmp.reg, memRegDoubleOp(inst.a)); + lhs = optLhsTmp.reg; } else { diff --git a/CodeGen/src/IrRegAllocX64.cpp b/CodeGen/src/IrRegAllocX64.cpp index 91867806a..c527d033f 100644 --- a/CodeGen/src/IrRegAllocX64.cpp +++ b/CodeGen/src/IrRegAllocX64.cpp @@ -169,13 +169,17 @@ void IrRegAllocX64::assertAllFree() const LUAU_ASSERT(free); } +ScopedRegX64::ScopedRegX64(IrRegAllocX64& owner) + : owner(owner) + , reg(noreg) +{ +} + ScopedRegX64::ScopedRegX64(IrRegAllocX64& owner, SizeX64 size) : owner(owner) + , reg(noreg) { - if (size == SizeX64::xmmword) - reg = owner.allocXmmReg(); - else - reg = owner.allocGprReg(size); + alloc(size); } ScopedRegX64::ScopedRegX64(IrRegAllocX64& owner, RegisterX64 reg) @@ -190,6 +194,16 @@ ScopedRegX64::~ScopedRegX64() owner.freeReg(reg); } +void ScopedRegX64::alloc(SizeX64 size) +{ + LUAU_ASSERT(reg == noreg); + + if (size == SizeX64::xmmword) + reg = owner.allocXmmReg(); + else + reg = owner.allocGprReg(size); +} + void ScopedRegX64::free() { LUAU_ASSERT(reg != noreg); diff --git a/CodeGen/src/IrRegAllocX64.h b/CodeGen/src/IrRegAllocX64.h index ac072a32f..497bb035c 100644 --- a/CodeGen/src/IrRegAllocX64.h +++ b/CodeGen/src/IrRegAllocX64.h @@ -40,6 +40,7 @@ struct IrRegAllocX64 struct ScopedRegX64 { + explicit ScopedRegX64(IrRegAllocX64& owner); ScopedRegX64(IrRegAllocX64& owner, SizeX64 size); ScopedRegX64(IrRegAllocX64& owner, RegisterX64 reg); ~ScopedRegX64(); @@ -47,6 +48,7 @@ struct ScopedRegX64 ScopedRegX64(const ScopedRegX64&) = delete; ScopedRegX64& operator=(const ScopedRegX64&) = delete; + void alloc(SizeX64 size); void free(); IrRegAllocX64& owner; diff --git a/CodeGen/src/IrUtils.cpp b/CodeGen/src/IrUtils.cpp index 0808ad076..d8115be9f 100644 --- a/CodeGen/src/IrUtils.cpp +++ b/CodeGen/src/IrUtils.cpp @@ -14,6 +14,29 @@ namespace Luau namespace CodeGen { +static void removeInstUse(IrFunction& function, uint32_t instIdx) +{ + IrInst& inst = function.instructions[instIdx]; + + LUAU_ASSERT(inst.useCount); + inst.useCount--; + + if (inst.useCount == 0) + kill(function, inst); +} + +static void removeBlockUse(IrFunction& function, uint32_t blockIdx) +{ + IrBlock& block = function.blocks[blockIdx]; + + LUAU_ASSERT(block.useCount); + block.useCount--; + + // Entry block is never removed because is has an implicit use + if (block.useCount == 0 && blockIdx != 0) + kill(function, block); +} + void addUse(IrFunction& function, IrOp op) { if (op.kind == IrOpKind::Inst) @@ -25,9 +48,9 @@ void addUse(IrFunction& function, IrOp op) void removeUse(IrFunction& function, IrOp op) { if (op.kind == IrOpKind::Inst) - removeUse(function, function.instructions[op.index]); + removeInstUse(function, op.index); else if (op.kind == IrOpKind::Block) - removeUse(function, function.blocks[op.index]); + removeBlockUse(function, op.index); } bool isGCO(uint8_t tag) @@ -83,24 +106,6 @@ void kill(IrFunction& function, IrBlock& block) block.finish = ~0u; } -void removeUse(IrFunction& function, IrInst& inst) -{ - LUAU_ASSERT(inst.useCount); - inst.useCount--; - - if (inst.useCount == 0) - kill(function, inst); -} - -void removeUse(IrFunction& function, IrBlock& block) -{ - LUAU_ASSERT(block.useCount); - block.useCount--; - - if (block.useCount == 0) - kill(function, block); -} - void replace(IrFunction& function, IrOp& original, IrOp replacement) { // Add use before removing new one if that's the last one keeping target operand alive @@ -122,6 +127,9 @@ void replace(IrFunction& function, IrBlock& block, uint32_t instIdx, IrInst repl addUse(function, replacement.e); addUse(function, replacement.f); + // An extra reference is added so block will not remove itself + block.useCount++; + // If we introduced an earlier terminating instruction, all following instructions become dead if (!isBlockTerminator(inst.cmd) && isBlockTerminator(replacement.cmd)) { @@ -142,6 +150,10 @@ void replace(IrFunction& function, IrBlock& block, uint32_t instIdx, IrInst repl removeUse(function, inst.f); inst = replacement; + + // Removing the earlier extra reference, this might leave the block without users without marking it as dead + // This will have to be handled by separate dead code elimination + block.useCount--; } void substitute(IrFunction& function, IrInst& inst, IrOp replacement) diff --git a/Common/include/Luau/ExperimentalFlags.h b/Common/include/Luau/ExperimentalFlags.h index afd364018..35e11ca50 100644 --- a/Common/include/Luau/ExperimentalFlags.h +++ b/Common/include/Luau/ExperimentalFlags.h @@ -12,7 +12,6 @@ inline bool isFlagExperimental(const char* flag) // or critical bugs that are found after the code has been submitted. static const char* const kList[] = { "LuauInstantiateInSubtyping", // requires some fixes to lua-apps code - "LuauTryhardAnd", // waiting for a fix in graphql-lua -> apollo-client-lia -> lua-apps "LuauTypecheckTypeguards", // requires some fixes to lua-apps code (CLI-67030) // makes sure we always have at least one entry nullptr, diff --git a/VM/include/luaconf.h b/VM/include/luaconf.h index dcf569230..5d6b760eb 100644 --- a/VM/include/luaconf.h +++ b/VM/include/luaconf.h @@ -20,6 +20,18 @@ #define LUAU_FASTMATH_END #endif +// Some functions like floor/ceil have SSE4.1 equivalents but we currently support systems without SSE4.1 +// Note that we only need to do this when SSE4.1 support is not guaranteed by compiler settings, as otherwise compiler will optimize these for us. +#if (defined(__x86_64__) || defined(_M_X64)) && !defined(__SSE4_1__) && !defined(__AVX__) +#if defined(_MSC_VER) && !defined(__clang__) +#define LUAU_TARGET_SSE41 +#elif defined(__GNUC__) && defined(__has_attribute) +#if __has_attribute(target) +#define LUAU_TARGET_SSE41 __attribute__((target("sse4.1"))) +#endif +#endif +#endif + // Used on functions that have a printf-like interface to validate them statically #if defined(__GNUC__) #define LUA_PRINTF_ATTR(fmt, arg) __attribute__((format(printf, fmt, arg))) diff --git a/VM/src/lbuiltins.cpp b/VM/src/lbuiltins.cpp index 71869b118..3c669bff9 100644 --- a/VM/src/lbuiltins.cpp +++ b/VM/src/lbuiltins.cpp @@ -15,6 +15,16 @@ #include #endif +#ifdef LUAU_TARGET_SSE41 +#include + +#ifndef _MSC_VER +#include // on MSVC this comes from intrin.h +#endif +#endif + +LUAU_FASTFLAGVARIABLE(LuauBuiltinSSE41, false) + // luauF functions implement FASTCALL instruction that performs a direct execution of some builtin functions from the VM // The rule of thumb is that FASTCALL functions can not call user code, yield, fail, or reallocate stack. // If types of the arguments mismatch, luauF_* needs to return -1 and the execution will fall back to the usual call path @@ -95,7 +105,9 @@ static int luauF_atan(lua_State* L, StkId res, TValue* arg0, int nresults, StkId return -1; } +// TODO: LUAU_NOINLINE can be removed with LuauBuiltinSSE41 LUAU_FASTMATH_BEGIN +LUAU_NOINLINE static int luauF_ceil(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) { if (nparams >= 1 && nresults <= 1 && ttisnumber(arg0)) @@ -158,7 +170,9 @@ static int luauF_exp(lua_State* L, StkId res, TValue* arg0, int nresults, StkId return -1; } +// TODO: LUAU_NOINLINE can be removed with LuauBuiltinSSE41 LUAU_FASTMATH_BEGIN +LUAU_NOINLINE static int luauF_floor(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) { if (nparams >= 1 && nresults <= 1 && ttisnumber(arg0)) @@ -935,7 +949,9 @@ static int luauF_sign(lua_State* L, StkId res, TValue* arg0, int nresults, StkId return -1; } +// TODO: LUAU_NOINLINE can be removed with LuauBuiltinSSE41 LUAU_FASTMATH_BEGIN +LUAU_NOINLINE static int luauF_round(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) { if (nparams >= 1 && nresults <= 1 && ttisnumber(arg0)) @@ -1244,6 +1260,78 @@ static int luauF_missing(lua_State* L, StkId res, TValue* arg0, int nresults, St return -1; } +#ifdef LUAU_TARGET_SSE41 +template +LUAU_TARGET_SSE41 inline double roundsd_sse41(double v) +{ + __m128d av = _mm_set_sd(v); + __m128d rv = _mm_round_sd(av, av, Rounding | _MM_FROUND_NO_EXC); + return _mm_cvtsd_f64(rv); +} + +LUAU_TARGET_SSE41 static int luauF_floor_sse41(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (!FFlag::LuauBuiltinSSE41) + return luauF_floor(L, res, arg0, nresults, args, nparams); + + if (nparams >= 1 && nresults <= 1 && ttisnumber(arg0)) + { + double a1 = nvalue(arg0); + setnvalue(res, roundsd_sse41<_MM_FROUND_TO_NEG_INF>(a1)); + return 1; + } + + return -1; +} + +LUAU_TARGET_SSE41 static int luauF_ceil_sse41(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (!FFlag::LuauBuiltinSSE41) + return luauF_ceil(L, res, arg0, nresults, args, nparams); + + if (nparams >= 1 && nresults <= 1 && ttisnumber(arg0)) + { + double a1 = nvalue(arg0); + setnvalue(res, roundsd_sse41<_MM_FROUND_TO_POS_INF>(a1)); + return 1; + } + + return -1; +} + +LUAU_TARGET_SSE41 static int luauF_round_sse41(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (!FFlag::LuauBuiltinSSE41) + return luauF_round(L, res, arg0, nresults, args, nparams); + + if (nparams >= 1 && nresults <= 1 && ttisnumber(arg0)) + { + double a1 = nvalue(arg0); + // roundsd only supports bankers rounding natively, so we need to emulate rounding by using truncation + // offset is prevfloat(0.5), which is important so that we round prevfloat(0.5) to 0. + const double offset = 0.49999999999999994; + setnvalue(res, roundsd_sse41<_MM_FROUND_TO_ZERO>(a1 + (a1 < 0 ? -offset : offset))); + return 1; + } + + return -1; +} + +static bool luau_hassse41() +{ + int cpuinfo[4] = {}; +#ifdef _MSC_VER + __cpuid(cpuinfo, 1); +#else + __cpuid(1, cpuinfo[0], cpuinfo[1], cpuinfo[2], cpuinfo[3]); +#endif + + // We requre SSE4.1 support for ROUNDSD + // https://en.wikipedia.org/wiki/CPUID#EAX=1:_Processor_Info_and_Feature_Bits + return (cpuinfo[2] & (1 << 19)) != 0; +} +#endif + const luau_FastFunction luauF_table[256] = { NULL, luauF_assert, @@ -1253,12 +1341,24 @@ const luau_FastFunction luauF_table[256] = { luauF_asin, luauF_atan2, luauF_atan, + +#ifdef LUAU_TARGET_SSE41 + luau_hassse41() ? luauF_ceil_sse41 : luauF_ceil, +#else luauF_ceil, +#endif + luauF_cosh, luauF_cos, luauF_deg, luauF_exp, + +#ifdef LUAU_TARGET_SSE41 + luau_hassse41() ? luauF_floor_sse41 : luauF_floor, +#else luauF_floor, +#endif + luauF_fmod, luauF_frexp, luauF_ldexp, @@ -1300,7 +1400,12 @@ const luau_FastFunction luauF_table[256] = { luauF_clamp, luauF_sign, + +#ifdef LUAU_TARGET_SSE41 + luau_hassse41() ? luauF_round_sse41 : luauF_round, +#else luauF_round, +#endif luauF_rawset, luauF_rawget, diff --git a/VM/src/ldebug.cpp b/VM/src/ldebug.cpp index 779ab4cdb..9d086f562 100644 --- a/VM/src/ldebug.cpp +++ b/VM/src/ldebug.cpp @@ -12,8 +12,6 @@ #include #include -LUAU_FASTFLAGVARIABLE(LuauCheckGetInfoIndex, false) - static const char* getfuncname(Closure* f); static int currentpc(lua_State* L, CallInfo* ci) @@ -176,18 +174,9 @@ int lua_getinfo(lua_State* L, int level, const char* what, lua_Debug* ar) CallInfo* ci = NULL; if (level < 0) { - if (FFlag::LuauCheckGetInfoIndex) - { - const TValue* func = luaA_toobject(L, level); - api_check(L, ttisfunction(func)); - f = clvalue(func); - } - else - { - StkId func = L->top + level; - api_check(L, ttisfunction(func)); - f = clvalue(func); - } + const TValue* func = luaA_toobject(L, level); + api_check(L, ttisfunction(func)); + f = clvalue(func); } else if (unsigned(level) < unsigned(L->ci - L->base_ci)) { diff --git a/VM/src/lmathlib.cpp b/VM/src/lmathlib.cpp index 0693b846f..2d4e3277a 100644 --- a/VM/src/lmathlib.cpp +++ b/VM/src/lmathlib.cpp @@ -300,7 +300,7 @@ static float fade(float t) return t * t * t * (t * (t * 6 - 15) + 10); } -static float lerp(float t, float a, float b) +static float math_lerp(float t, float a, float b) { return a + t * (b - a); } @@ -342,10 +342,11 @@ static float perlin(float x, float y, float z) int ba = p[b] + zi; int bb = p[b + 1] + zi; - return lerp(w, - lerp(v, lerp(u, grad(p[aa], xf, yf, zf), grad(p[ba], xf - 1, yf, zf)), lerp(u, grad(p[ab], xf, yf - 1, zf), grad(p[bb], xf - 1, yf - 1, zf))), - lerp(v, lerp(u, grad(p[aa + 1], xf, yf, zf - 1), grad(p[ba + 1], xf - 1, yf, zf - 1)), - lerp(u, grad(p[ab + 1], xf, yf - 1, zf - 1), grad(p[bb + 1], xf - 1, yf - 1, zf - 1)))); + return math_lerp(w, + math_lerp(v, math_lerp(u, grad(p[aa], xf, yf, zf), grad(p[ba], xf - 1, yf, zf)), + math_lerp(u, grad(p[ab], xf, yf - 1, zf), grad(p[bb], xf - 1, yf - 1, zf))), + math_lerp(v, math_lerp(u, grad(p[aa + 1], xf, yf, zf - 1), grad(p[ba + 1], xf - 1, yf, zf - 1)), + math_lerp(u, grad(p[ab + 1], xf, yf - 1, zf - 1), grad(p[bb + 1], xf - 1, yf - 1, zf - 1)))); } static int math_noise(lua_State* L) diff --git a/VM/src/ltablib.cpp b/VM/src/ltablib.cpp index 0efa9ee04..ddee3a71e 100644 --- a/VM/src/ltablib.cpp +++ b/VM/src/ltablib.cpp @@ -10,6 +10,8 @@ #include "ldebug.h" #include "lvm.h" +LUAU_FASTFLAGVARIABLE(LuauOptimizedSort, false) + static int foreachi(lua_State* L) { luaL_checktype(L, 1, LUA_TTABLE); @@ -305,12 +307,14 @@ static int tunpack(lua_State* L) static void set2(lua_State* L, int i, int j) { + LUAU_ASSERT(!FFlag::LuauOptimizedSort); lua_rawseti(L, 1, i); lua_rawseti(L, 1, j); } static int sort_comp(lua_State* L, int a, int b) { + LUAU_ASSERT(!FFlag::LuauOptimizedSort); if (!lua_isnil(L, 2)) { // function? int res; @@ -328,6 +332,7 @@ static int sort_comp(lua_State* L, int a, int b) static void auxsort(lua_State* L, int l, int u) { + LUAU_ASSERT(!FFlag::LuauOptimizedSort); while (l < u) { // for tail recursion int i, j; @@ -407,16 +412,145 @@ static void auxsort(lua_State* L, int l, int u) } // repeat the routine for the larger one } -static int sort(lua_State* L) +typedef int (*SortPredicate)(lua_State* L, const TValue* l, const TValue* r); + +static int sort_func(lua_State* L, const TValue* l, const TValue* r) { - luaL_checktype(L, 1, LUA_TTABLE); - int n = lua_objlen(L, 1); - luaL_checkstack(L, 40, ""); // assume array is smaller than 2^40 - if (!lua_isnoneornil(L, 2)) // is there a 2nd argument? - luaL_checktype(L, 2, LUA_TFUNCTION); - lua_settop(L, 2); // make sure there is two arguments - auxsort(L, 1, n); - return 0; + LUAU_ASSERT(L->top == L->base + 2); // table, function + + setobj2s(L, L->top, &L->base[1]); + setobj2s(L, L->top + 1, l); + setobj2s(L, L->top + 2, r); + L->top += 3; // safe because of LUA_MINSTACK guarantee + luaD_call(L, L->top - 3, 1); + L->top -= 1; // maintain stack depth + + return !l_isfalse(L->top); +} + +inline void sort_swap(lua_State* L, Table* t, int i, int j) +{ + TValue* arr = t->array; + int n = t->sizearray; + LUAU_ASSERT(unsigned(i) < unsigned(n) && unsigned(j) < unsigned(n)); // contract maintained in sort_less after predicate call + + // no barrier required because both elements are in the array before and after the swap + TValue temp; + setobj2s(L, &temp, &arr[i]); + setobj2t(L, &arr[i], &arr[j]); + setobj2t(L, &arr[j], &temp); +} + +inline int sort_less(lua_State* L, Table* t, int i, int j, SortPredicate pred) +{ + TValue* arr = t->array; + int n = t->sizearray; + LUAU_ASSERT(unsigned(i) < unsigned(n) && unsigned(j) < unsigned(n)); // contract maintained in sort_less after predicate call + + int res = pred(L, &arr[i], &arr[j]); + + // predicate call may resize the table, which is invalid + if (t->sizearray != n) + luaL_error(L, "table modified during sorting"); + + return res; +} + +static void sort_rec(lua_State* L, Table* t, int l, int u, SortPredicate pred) +{ + // sort range [l..u] (inclusive, 0-based) + while (l < u) + { + int i, j; + // sort elements a[l], a[(l+u)/2] and a[u] + if (sort_less(L, t, u, l, pred)) // a[u] < a[l]? + sort_swap(L, t, u, l); // swap a[l] - a[u] + if (u - l == 1) + break; // only 2 elements + i = l + ((u - l) >> 1); // midpoint + if (sort_less(L, t, i, l, pred)) // a[i]= P + while (sort_less(L, t, ++i, p, pred)) + { + if (i >= u) + luaL_error(L, "invalid order function for sorting"); + } + // repeat --j until a[j] <= P + while (sort_less(L, t, p, --j, pred)) + { + if (j <= l) + luaL_error(L, "invalid order function for sorting"); + } + if (j < i) + break; + sort_swap(L, t, i, j); + } + // swap pivot (a[u-1]) with a[i], which is the new midpoint + sort_swap(L, t, u - 1, i); + // a[l..i-1] <= a[i] == P <= a[i+1..u] + // adjust so that smaller half is in [j..i] and larger one in [l..u] + if (i - l < u - i) + { + j = l; + i = i - 1; + l = i + 2; + } + else + { + j = i + 1; + i = u; + u = j - 2; + } + sort_rec(L, t, j, i, pred); // call recursively the smaller one + } // repeat the routine for the larger one +} + +static int tsort(lua_State* L) +{ + if (FFlag::LuauOptimizedSort) + { + luaL_checktype(L, 1, LUA_TTABLE); + Table* t = hvalue(L->base); + int n = luaH_getn(t); + if (t->readonly) + luaG_readonlyerror(L); + + SortPredicate pred = luaV_lessthan; + if (!lua_isnoneornil(L, 2)) // is there a 2nd argument? + { + luaL_checktype(L, 2, LUA_TFUNCTION); + pred = sort_func; + } + lua_settop(L, 2); // make sure there are two arguments + + if (n > 0) + sort_rec(L, t, 0, n - 1, pred); + return 0; + } + else + { + luaL_checktype(L, 1, LUA_TTABLE); + int n = lua_objlen(L, 1); + luaL_checkstack(L, 40, ""); // assume array is smaller than 2^40 + if (!lua_isnoneornil(L, 2)) // is there a 2nd argument? + luaL_checktype(L, 2, LUA_TFUNCTION); + lua_settop(L, 2); // make sure there is two arguments + auxsort(L, 1, n); + return 0; + } } // }====================================================== @@ -530,7 +664,7 @@ static const luaL_Reg tab_funcs[] = { {"maxn", maxn}, {"insert", tinsert}, {"remove", tremove}, - {"sort", sort}, + {"sort", tsort}, {"pack", tpack}, {"unpack", tunpack}, {"move", tmove}, diff --git a/VM/src/lvmexecute.cpp b/VM/src/lvmexecute.cpp index 6f600e97d..c8a184a17 100644 --- a/VM/src/lvmexecute.cpp +++ b/VM/src/lvmexecute.cpp @@ -145,15 +145,16 @@ LUAU_NOINLINE void luau_callhook(lua_State* L, lua_Hook hook, void* userdata) L->base = L->ci->base; } - luaD_checkstack(L, LUA_MINSTACK); // ensure minimum stack size - L->ci->top = L->top + LUA_MINSTACK; - LUAU_ASSERT(L->ci->top <= L->stack_last); - // note: the pc expectations of the hook are matching the general "pc points to next instruction" // however, for the hook to be able to continue execution from the same point, this is called with savedpc at the *current* instruction + // this needs to be called before luaD_checkstack in case it fails to reallocate stack if (L->ci->savedpc) L->ci->savedpc++; + luaD_checkstack(L, LUA_MINSTACK); // ensure minimum stack size + L->ci->top = L->top + LUA_MINSTACK; + LUAU_ASSERT(L->ci->top <= L->stack_last); + Closure* cl = clvalue(L->ci->func); lua_Debug ar; diff --git a/VM/src/lvmutils.cpp b/VM/src/lvmutils.cpp index 05d397540..b77207dae 100644 --- a/VM/src/lvmutils.cpp +++ b/VM/src/lvmutils.cpp @@ -201,15 +201,23 @@ static const TValue* get_compTM(lua_State* L, Table* mt1, Table* mt2, TMS event) return NULL; } -static int call_orderTM(lua_State* L, const TValue* p1, const TValue* p2, TMS event) +static int call_orderTM(lua_State* L, const TValue* p1, const TValue* p2, TMS event, bool error = false) { const TValue* tm1 = luaT_gettmbyobj(L, p1, event); const TValue* tm2; if (ttisnil(tm1)) + { + if (error) + luaG_ordererror(L, p1, p2, event); return -1; // no metamethod? + } tm2 = luaT_gettmbyobj(L, p2, event); if (!luaO_rawequalObj(tm1, tm2)) // different metamethods? + { + if (error) + luaG_ordererror(L, p1, p2, event); return -1; + } callTMres(L, L->top, tm1, p1, p2); return !l_isfalse(L->top); } @@ -239,16 +247,14 @@ int luaV_strcmp(const TString* ls, const TString* rs) int luaV_lessthan(lua_State* L, const TValue* l, const TValue* r) { - int res; - if (ttype(l) != ttype(r)) + if (LUAU_UNLIKELY(ttype(l) != ttype(r))) luaG_ordererror(L, l, r, TM_LT); - else if (ttisnumber(l)) + else if (LUAU_LIKELY(ttisnumber(l))) return luai_numlt(nvalue(l), nvalue(r)); else if (ttisstring(l)) return luaV_strcmp(tsvalue(l), tsvalue(r)) < 0; - else if ((res = call_orderTM(L, l, r, TM_LT)) == -1) - luaG_ordererror(L, l, r, TM_LT); - return res; + else + return call_orderTM(L, l, r, TM_LT, /* error= */ true); } int luaV_lessequal(lua_State* L, const TValue* l, const TValue* r) diff --git a/fuzz/proto.cpp b/fuzz/proto.cpp index 5ca165765..c94f0889b 100644 --- a/fuzz/proto.cpp +++ b/fuzz/proto.cpp @@ -97,38 +97,39 @@ lua_State* createGlobalState() return L; } -int registerTypes(Luau::TypeChecker& env) +int registerTypes(Luau::TypeChecker& typeChecker, Luau::GlobalTypes& globals) { using namespace Luau; using std::nullopt; - Luau::registerBuiltinGlobals(env); + Luau::registerBuiltinGlobals(typeChecker, globals); - TypeArena& arena = env.globalTypes; + TypeArena& arena = globals.globalTypes; + BuiltinTypes& builtinTypes = *globals.builtinTypes; // Vector3 stub TypeId vector3MetaType = arena.addType(TableType{}); TypeId vector3InstanceType = arena.addType(ClassType{"Vector3", {}, nullopt, vector3MetaType, {}, {}, "Test"}); getMutable(vector3InstanceType)->props = { - {"X", {env.numberType}}, - {"Y", {env.numberType}}, - {"Z", {env.numberType}}, + {"X", {builtinTypes.numberType}}, + {"Y", {builtinTypes.numberType}}, + {"Z", {builtinTypes.numberType}}, }; getMutable(vector3MetaType)->props = { {"__add", {makeFunction(arena, nullopt, {vector3InstanceType, vector3InstanceType}, {vector3InstanceType})}}, }; - env.globalScope->exportedTypeBindings["Vector3"] = TypeFun{{}, vector3InstanceType}; + globals.globalScope->exportedTypeBindings["Vector3"] = TypeFun{{}, vector3InstanceType}; // Instance stub TypeId instanceType = arena.addType(ClassType{"Instance", {}, nullopt, nullopt, {}, {}, "Test"}); getMutable(instanceType)->props = { - {"Name", {env.stringType}}, + {"Name", {builtinTypes.stringType}}, }; - env.globalScope->exportedTypeBindings["Instance"] = TypeFun{{}, instanceType}; + globals.globalScope->exportedTypeBindings["Instance"] = TypeFun{{}, instanceType}; // Part stub TypeId partType = arena.addType(ClassType{"Part", {}, instanceType, nullopt, {}, {}, "Test"}); @@ -136,9 +137,9 @@ int registerTypes(Luau::TypeChecker& env) {"Position", {vector3InstanceType}}, }; - env.globalScope->exportedTypeBindings["Part"] = TypeFun{{}, partType}; + globals.globalScope->exportedTypeBindings["Part"] = TypeFun{{}, partType}; - for (const auto& [_, fun] : env.globalScope->exportedTypeBindings) + for (const auto& [_, fun] : globals.globalScope->exportedTypeBindings) persist(fun.type); return 0; @@ -146,11 +147,11 @@ int registerTypes(Luau::TypeChecker& env) static void setupFrontend(Luau::Frontend& frontend) { - registerTypes(frontend.typeChecker); - Luau::freeze(frontend.typeChecker.globalTypes); + registerTypes(frontend.typeChecker, frontend.globals); + Luau::freeze(frontend.globals.globalTypes); - registerTypes(frontend.typeCheckerForAutocomplete); - Luau::freeze(frontend.typeCheckerForAutocomplete.globalTypes); + registerTypes(frontend.typeCheckerForAutocomplete, frontend.globalsForAutocomplete); + Luau::freeze(frontend.globalsForAutocomplete.globalTypes); frontend.iceHandler.onInternalError = [](const char* error) { printf("ICE: %s\n", error); @@ -264,6 +265,7 @@ DEFINE_PROTO_FUZZER(const luau::ModuleSet& message) static Luau::Frontend frontend(&fileResolver, &configResolver, options); static int once = (setupFrontend(frontend), 0); + (void)once; // restart frontend.clear(); @@ -302,7 +304,7 @@ DEFINE_PROTO_FUZZER(const luau::ModuleSet& message) // validate sharedEnv post-typecheck; valuable for debugging some typeck crashes but slows fuzzing down // note: it's important for typeck to be destroyed at this point! - for (auto& p : frontend.typeChecker.globalScope->bindings) + for (auto& p : frontend.globals.globalScope->bindings) { Luau::ToStringOptions opts; opts.exhaustive = true; diff --git a/tests/AstQuery.test.cpp b/tests/AstQuery.test.cpp index a642334af..25521e35b 100644 --- a/tests/AstQuery.test.cpp +++ b/tests/AstQuery.test.cpp @@ -282,4 +282,17 @@ TEST_CASE_FIXTURE(Fixture, "Luau_selectively_query_for_a_different_boolean_2") REQUIRE(snd->value == true); } +TEST_CASE_FIXTURE(Fixture, "include_types_ancestry") +{ + check("local x: number = 4;"); + const Position pos(0, 10); + + std::vector ancestryNoTypes = findAstAncestryOfPosition(*getMainSourceModule(), pos); + std::vector ancestryTypes = findAstAncestryOfPosition(*getMainSourceModule(), pos, true); + + CHECK(ancestryTypes.size() > ancestryNoTypes.size()); + CHECK(!ancestryNoTypes.back()->asType()); + CHECK(ancestryTypes.back()->asType()); +} + TEST_SUITE_END(); diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index 85bd55077..aedb50ab6 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -15,8 +15,6 @@ LUAU_FASTFLAG(LuauTraceTypesInNonstrictMode2) LUAU_FASTFLAG(LuauSetMetatableDoesNotTimeTravel) -LUAU_FASTFLAG(LuauFixAutocompleteInWhile) -LUAU_FASTFLAG(LuauFixAutocompleteInFor) using namespace Luau; @@ -85,10 +83,11 @@ struct ACFixtureImpl : BaseType LoadDefinitionFileResult loadDefinition(const std::string& source) { - TypeChecker& typeChecker = this->frontend.typeCheckerForAutocomplete; - unfreeze(typeChecker.globalTypes); - LoadDefinitionFileResult result = loadDefinitionFile(typeChecker, typeChecker.globalScope, source, "@test"); - freeze(typeChecker.globalTypes); + GlobalTypes& globals = this->frontend.globalsForAutocomplete; + unfreeze(globals.globalTypes); + LoadDefinitionFileResult result = + loadDefinitionFile(this->frontend.typeChecker, globals, globals.globalScope, source, "@test", /* captureComments */ false); + freeze(globals.globalTypes); REQUIRE_MESSAGE(result.success, "loadDefinition: unable to load definition file"); return result; @@ -110,10 +109,10 @@ struct ACFixture : ACFixtureImpl ACFixture() : ACFixtureImpl() { - addGlobalBinding(frontend, "table", Binding{typeChecker.anyType}); - addGlobalBinding(frontend, "math", Binding{typeChecker.anyType}); - addGlobalBinding(frontend.typeCheckerForAutocomplete, "table", Binding{typeChecker.anyType}); - addGlobalBinding(frontend.typeCheckerForAutocomplete, "math", Binding{typeChecker.anyType}); + addGlobalBinding(frontend.globals, "table", Binding{builtinTypes->anyType}); + addGlobalBinding(frontend.globals, "math", Binding{builtinTypes->anyType}); + addGlobalBinding(frontend.globalsForAutocomplete, "table", Binding{builtinTypes->anyType}); + addGlobalBinding(frontend.globalsForAutocomplete, "math", Binding{builtinTypes->anyType}); } }; @@ -630,19 +629,10 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_for_middle_keywords") )"); auto ac5 = autocomplete('1'); - if (FFlag::LuauFixAutocompleteInFor) - { - CHECK_EQ(ac5.entryMap.count("math"), 1); - CHECK_EQ(ac5.entryMap.count("do"), 0); - CHECK_EQ(ac5.entryMap.count("end"), 0); - CHECK_EQ(ac5.context, AutocompleteContext::Expression); - } - else - { - CHECK_EQ(ac5.entryMap.count("do"), 1); - CHECK_EQ(ac5.entryMap.count("end"), 0); - CHECK_EQ(ac5.context, AutocompleteContext::Keyword); - } + CHECK_EQ(ac5.entryMap.count("math"), 1); + CHECK_EQ(ac5.entryMap.count("do"), 0); + CHECK_EQ(ac5.entryMap.count("end"), 0); + CHECK_EQ(ac5.context, AutocompleteContext::Expression); check(R"( for x = 1, 2, 5 f@1 @@ -661,29 +651,26 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_for_middle_keywords") CHECK_EQ(ac7.entryMap.count("end"), 1); CHECK_EQ(ac7.context, AutocompleteContext::Statement); - if (FFlag::LuauFixAutocompleteInFor) - { - check(R"(local Foo = 1 - for x = @11, @22, @35 - )"); + check(R"(local Foo = 1 + for x = @11, @22, @35 + )"); - for (int i = 0; i < 3; ++i) - { - auto ac8 = autocomplete('1' + i); - CHECK_EQ(ac8.entryMap.count("Foo"), 1); - CHECK_EQ(ac8.entryMap.count("do"), 0); - } + for (int i = 0; i < 3; ++i) + { + auto ac8 = autocomplete('1' + i); + CHECK_EQ(ac8.entryMap.count("Foo"), 1); + CHECK_EQ(ac8.entryMap.count("do"), 0); + } - check(R"(local Foo = 1 - for x = @11, @22 - )"); + check(R"(local Foo = 1 + for x = @11, @22 + )"); - for (int i = 0; i < 2; ++i) - { - auto ac9 = autocomplete('1' + i); - CHECK_EQ(ac9.entryMap.count("Foo"), 1); - CHECK_EQ(ac9.entryMap.count("do"), 0); - } + for (int i = 0; i < 2; ++i) + { + auto ac9 = autocomplete('1' + i); + CHECK_EQ(ac9.entryMap.count("Foo"), 1); + CHECK_EQ(ac9.entryMap.count("do"), 0); } } @@ -776,18 +763,10 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_while_middle_keywords") )"); auto ac2 = autocomplete('1'); - if (FFlag::LuauFixAutocompleteInWhile) - { - CHECK_EQ(3, ac2.entryMap.size()); - CHECK_EQ(ac2.entryMap.count("do"), 1); - CHECK_EQ(ac2.entryMap.count("and"), 1); - CHECK_EQ(ac2.entryMap.count("or"), 1); - } - else - { - CHECK_EQ(1, ac2.entryMap.size()); - CHECK_EQ(ac2.entryMap.count("do"), 1); - } + CHECK_EQ(3, ac2.entryMap.size()); + CHECK_EQ(ac2.entryMap.count("do"), 1); + CHECK_EQ(ac2.entryMap.count("and"), 1); + CHECK_EQ(ac2.entryMap.count("or"), 1); CHECK_EQ(ac2.context, AutocompleteContext::Keyword); check(R"( @@ -803,31 +782,20 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_while_middle_keywords") )"); auto ac4 = autocomplete('1'); - if (FFlag::LuauFixAutocompleteInWhile) - { - CHECK_EQ(3, ac4.entryMap.size()); - CHECK_EQ(ac4.entryMap.count("do"), 1); - CHECK_EQ(ac4.entryMap.count("and"), 1); - CHECK_EQ(ac4.entryMap.count("or"), 1); - } - else - { - CHECK_EQ(1, ac4.entryMap.size()); - CHECK_EQ(ac4.entryMap.count("do"), 1); - } + CHECK_EQ(3, ac4.entryMap.size()); + CHECK_EQ(ac4.entryMap.count("do"), 1); + CHECK_EQ(ac4.entryMap.count("and"), 1); + CHECK_EQ(ac4.entryMap.count("or"), 1); CHECK_EQ(ac4.context, AutocompleteContext::Keyword); - if (FFlag::LuauFixAutocompleteInWhile) - { - check(R"( - while t@1 - )"); + check(R"( + while t@1 + )"); - auto ac5 = autocomplete('1'); - CHECK_EQ(ac5.entryMap.count("do"), 0); - CHECK_EQ(ac5.entryMap.count("true"), 1); - CHECK_EQ(ac5.entryMap.count("false"), 1); - } + auto ac5 = autocomplete('1'); + CHECK_EQ(ac5.entryMap.count("do"), 0); + CHECK_EQ(ac5.entryMap.count("true"), 1); + CHECK_EQ(ac5.entryMap.count("false"), 1); } TEST_CASE_FIXTURE(ACFixture, "autocomplete_if_middle_keywords") @@ -3460,11 +3428,11 @@ TEST_CASE_FIXTURE(ACFixture, "string_contents_is_available_to_callback") declare function require(path: string): any )"); - std::optional require = frontend.typeCheckerForAutocomplete.globalScope->linearSearchForBinding("require"); + std::optional require = frontend.globalsForAutocomplete.globalScope->linearSearchForBinding("require"); REQUIRE(require); - Luau::unfreeze(frontend.typeCheckerForAutocomplete.globalTypes); + Luau::unfreeze(frontend.globalsForAutocomplete.globalTypes); attachTag(require->typeId, "RequireCall"); - Luau::freeze(frontend.typeCheckerForAutocomplete.globalTypes); + Luau::freeze(frontend.globalsForAutocomplete.globalTypes); check(R"( local x = require("testing/@1") diff --git a/tests/BuiltinDefinitions.test.cpp b/tests/BuiltinDefinitions.test.cpp index 188f2190f..75d054bd2 100644 --- a/tests/BuiltinDefinitions.test.cpp +++ b/tests/BuiltinDefinitions.test.cpp @@ -12,9 +12,9 @@ TEST_SUITE_BEGIN("BuiltinDefinitionsTest"); TEST_CASE_FIXTURE(BuiltinsFixture, "lib_documentation_symbols") { - CHECK(!typeChecker.globalScope->bindings.empty()); + CHECK(!frontend.globals.globalScope->bindings.empty()); - for (const auto& [name, binding] : typeChecker.globalScope->bindings) + for (const auto& [name, binding] : frontend.globals.globalScope->bindings) { std::string nameString(name.c_str()); std::string expectedRootSymbol = "@luau/global/" + nameString; diff --git a/tests/ClassFixture.cpp b/tests/ClassFixture.cpp index 087b88d53..caf773b82 100644 --- a/tests/ClassFixture.cpp +++ b/tests/ClassFixture.cpp @@ -11,8 +11,9 @@ namespace Luau ClassFixture::ClassFixture() { - TypeArena& arena = typeChecker.globalTypes; - TypeId numberType = typeChecker.numberType; + GlobalTypes& globals = frontend.globals; + TypeArena& arena = globals.globalTypes; + TypeId numberType = builtinTypes->numberType; unfreeze(arena); @@ -28,47 +29,47 @@ ClassFixture::ClassFixture() {"Clone", {makeFunction(arena, nullopt, {baseClassInstanceType}, {baseClassInstanceType})}}, {"New", {makeFunction(arena, nullopt, {}, {baseClassInstanceType})}}, }; - typeChecker.globalScope->exportedTypeBindings["BaseClass"] = TypeFun{{}, baseClassInstanceType}; - addGlobalBinding(frontend, "BaseClass", baseClassType, "@test"); + globals.globalScope->exportedTypeBindings["BaseClass"] = TypeFun{{}, baseClassInstanceType}; + addGlobalBinding(globals, "BaseClass", baseClassType, "@test"); TypeId childClassInstanceType = arena.addType(ClassType{"ChildClass", {}, baseClassInstanceType, nullopt, {}, {}, "Test"}); getMutable(childClassInstanceType)->props = { - {"Method", {makeFunction(arena, childClassInstanceType, {}, {typeChecker.stringType})}}, + {"Method", {makeFunction(arena, childClassInstanceType, {}, {builtinTypes->stringType})}}, }; TypeId childClassType = arena.addType(ClassType{"ChildClass", {}, baseClassType, nullopt, {}, {}, "Test"}); getMutable(childClassType)->props = { {"New", {makeFunction(arena, nullopt, {}, {childClassInstanceType})}}, }; - typeChecker.globalScope->exportedTypeBindings["ChildClass"] = TypeFun{{}, childClassInstanceType}; - addGlobalBinding(frontend, "ChildClass", childClassType, "@test"); + globals.globalScope->exportedTypeBindings["ChildClass"] = TypeFun{{}, childClassInstanceType}; + addGlobalBinding(globals, "ChildClass", childClassType, "@test"); TypeId grandChildInstanceType = arena.addType(ClassType{"GrandChild", {}, childClassInstanceType, nullopt, {}, {}, "Test"}); getMutable(grandChildInstanceType)->props = { - {"Method", {makeFunction(arena, grandChildInstanceType, {}, {typeChecker.stringType})}}, + {"Method", {makeFunction(arena, grandChildInstanceType, {}, {builtinTypes->stringType})}}, }; TypeId grandChildType = arena.addType(ClassType{"GrandChild", {}, baseClassType, nullopt, {}, {}, "Test"}); getMutable(grandChildType)->props = { {"New", {makeFunction(arena, nullopt, {}, {grandChildInstanceType})}}, }; - typeChecker.globalScope->exportedTypeBindings["GrandChild"] = TypeFun{{}, grandChildInstanceType}; - addGlobalBinding(frontend, "GrandChild", childClassType, "@test"); + globals.globalScope->exportedTypeBindings["GrandChild"] = TypeFun{{}, grandChildInstanceType}; + addGlobalBinding(globals, "GrandChild", childClassType, "@test"); TypeId anotherChildInstanceType = arena.addType(ClassType{"AnotherChild", {}, baseClassInstanceType, nullopt, {}, {}, "Test"}); getMutable(anotherChildInstanceType)->props = { - {"Method", {makeFunction(arena, anotherChildInstanceType, {}, {typeChecker.stringType})}}, + {"Method", {makeFunction(arena, anotherChildInstanceType, {}, {builtinTypes->stringType})}}, }; TypeId anotherChildType = arena.addType(ClassType{"AnotherChild", {}, baseClassType, nullopt, {}, {}, "Test"}); getMutable(anotherChildType)->props = { {"New", {makeFunction(arena, nullopt, {}, {anotherChildInstanceType})}}, }; - typeChecker.globalScope->exportedTypeBindings["AnotherChild"] = TypeFun{{}, anotherChildInstanceType}; - addGlobalBinding(frontend, "AnotherChild", childClassType, "@test"); + globals.globalScope->exportedTypeBindings["AnotherChild"] = TypeFun{{}, anotherChildInstanceType}; + addGlobalBinding(globals, "AnotherChild", childClassType, "@test"); TypeId unrelatedClassInstanceType = arena.addType(ClassType{"UnrelatedClass", {}, nullopt, nullopt, {}, {}, "Test"}); @@ -76,8 +77,8 @@ ClassFixture::ClassFixture() getMutable(unrelatedClassType)->props = { {"New", {makeFunction(arena, nullopt, {}, {unrelatedClassInstanceType})}}, }; - typeChecker.globalScope->exportedTypeBindings["UnrelatedClass"] = TypeFun{{}, unrelatedClassInstanceType}; - addGlobalBinding(frontend, "UnrelatedClass", unrelatedClassType, "@test"); + globals.globalScope->exportedTypeBindings["UnrelatedClass"] = TypeFun{{}, unrelatedClassInstanceType}; + addGlobalBinding(globals, "UnrelatedClass", unrelatedClassType, "@test"); TypeId vector2MetaType = arena.addType(TableType{}); @@ -94,17 +95,17 @@ ClassFixture::ClassFixture() getMutable(vector2MetaType)->props = { {"__add", {makeFunction(arena, nullopt, {vector2InstanceType, vector2InstanceType}, {vector2InstanceType})}}, }; - typeChecker.globalScope->exportedTypeBindings["Vector2"] = TypeFun{{}, vector2InstanceType}; - addGlobalBinding(frontend, "Vector2", vector2Type, "@test"); + globals.globalScope->exportedTypeBindings["Vector2"] = TypeFun{{}, vector2InstanceType}; + addGlobalBinding(globals, "Vector2", vector2Type, "@test"); TypeId callableClassMetaType = arena.addType(TableType{}); TypeId callableClassType = arena.addType(ClassType{"CallableClass", {}, nullopt, callableClassMetaType, {}, {}, "Test"}); getMutable(callableClassMetaType)->props = { - {"__call", {makeFunction(arena, nullopt, {callableClassType, typeChecker.stringType}, {typeChecker.numberType})}}, + {"__call", {makeFunction(arena, nullopt, {callableClassType, builtinTypes->stringType}, {builtinTypes->numberType})}}, }; - typeChecker.globalScope->exportedTypeBindings["CallableClass"] = TypeFun{{}, callableClassType}; + globals.globalScope->exportedTypeBindings["CallableClass"] = TypeFun{{}, callableClassType}; - for (const auto& [name, tf] : typeChecker.globalScope->exportedTypeBindings) + for (const auto& [name, tf] : globals.globalScope->exportedTypeBindings) persist(tf.type); freeze(arena); diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index 077310ac8..7d5f41a15 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -503,14 +503,15 @@ TEST_CASE("Types") Luau::NullModuleResolver moduleResolver; Luau::InternalErrorReporter iceHandler; Luau::BuiltinTypes builtinTypes; - Luau::TypeChecker env(&moduleResolver, Luau::NotNull{&builtinTypes}, &iceHandler); + Luau::GlobalTypes globals{Luau::NotNull{&builtinTypes}}; + Luau::TypeChecker env(globals, &moduleResolver, Luau::NotNull{&builtinTypes}, &iceHandler); - Luau::registerBuiltinGlobals(env); - Luau::freeze(env.globalTypes); + Luau::registerBuiltinGlobals(env, globals); + Luau::freeze(globals.globalTypes); lua_newtable(L); - for (const auto& [name, binding] : env.globalScope->bindings) + for (const auto& [name, binding] : globals.globalScope->bindings) { populateRTTI(L, binding.typeId); lua_setfield(L, -2, toString(name).c_str()); diff --git a/tests/ConstraintGraphBuilderFixture.cpp b/tests/ConstraintGraphBuilderFixture.cpp index a9a43f0b6..81e5c41b7 100644 --- a/tests/ConstraintGraphBuilderFixture.cpp +++ b/tests/ConstraintGraphBuilderFixture.cpp @@ -22,7 +22,7 @@ void ConstraintGraphBuilderFixture::generateConstraints(const std::string& code) AstStatBlock* root = parse(code); dfg = std::make_unique(DataFlowGraphBuilder::build(root, NotNull{&ice})); cgb = std::make_unique("MainModule", mainModule, &arena, NotNull(&moduleResolver), builtinTypes, NotNull(&ice), - frontend.getGlobalScope(), &logger, NotNull{dfg.get()}); + frontend.globals.globalScope, &logger, NotNull{dfg.get()}); cgb->visit(root); rootScope = cgb->rootScope; constraints = Luau::borrowConstraints(cgb->constraints); diff --git a/tests/Fixture.cpp b/tests/Fixture.cpp index cbceabbdc..a9c94eefc 100644 --- a/tests/Fixture.cpp +++ b/tests/Fixture.cpp @@ -138,17 +138,16 @@ Fixture::Fixture(bool freeze, bool prepareAutocomplete) : sff_DebugLuauFreezeArena("DebugLuauFreezeArena", freeze) , frontend(&fileResolver, &configResolver, {/* retainFullTypeGraphs= */ true, /* forAutocomplete */ false, /* randomConstraintResolutionSeed */ randomSeed}) - , typeChecker(frontend.typeChecker) , builtinTypes(frontend.builtinTypes) { configResolver.defaultConfig.mode = Mode::Strict; configResolver.defaultConfig.enabledLint.warningMask = ~0ull; configResolver.defaultConfig.parseOptions.captureComments = true; - registerBuiltinTypes(frontend); + registerBuiltinTypes(frontend.globals); - Luau::freeze(frontend.typeChecker.globalTypes); - Luau::freeze(frontend.typeCheckerForAutocomplete.globalTypes); + Luau::freeze(frontend.globals.globalTypes); + Luau::freeze(frontend.globalsForAutocomplete.globalTypes); Luau::setPrintLine([](auto s) {}); } @@ -178,11 +177,11 @@ AstStatBlock* Fixture::parse(const std::string& source, const ParseOptions& pars if (FFlag::DebugLuauDeferredConstraintResolution) { - Luau::check(*sourceModule, {}, frontend.builtinTypes, NotNull{&ice}, NotNull{&moduleResolver}, NotNull{&fileResolver}, - typeChecker.globalScope, frontend.options); + Luau::check(*sourceModule, {}, builtinTypes, NotNull{&ice}, NotNull{&moduleResolver}, NotNull{&fileResolver}, + frontend.globals.globalScope, frontend.options); } else - typeChecker.check(*sourceModule, sourceModule->mode.value_or(Luau::Mode::Nonstrict)); + frontend.typeChecker.check(*sourceModule, sourceModule->mode.value_or(Luau::Mode::Nonstrict)); } throw ParseErrors(result.errors); @@ -447,9 +446,9 @@ void Fixture::dumpErrors(std::ostream& os, const std::vector& errors) void Fixture::registerTestTypes() { - addGlobalBinding(frontend, "game", typeChecker.anyType, "@luau"); - addGlobalBinding(frontend, "workspace", typeChecker.anyType, "@luau"); - addGlobalBinding(frontend, "script", typeChecker.anyType, "@luau"); + addGlobalBinding(frontend.globals, "game", builtinTypes->anyType, "@luau"); + addGlobalBinding(frontend.globals, "workspace", builtinTypes->anyType, "@luau"); + addGlobalBinding(frontend.globals, "script", builtinTypes->anyType, "@luau"); } void Fixture::dumpErrors(const CheckResult& cr) @@ -499,9 +498,9 @@ void Fixture::validateErrors(const std::vector& errors) LoadDefinitionFileResult Fixture::loadDefinition(const std::string& source) { - unfreeze(typeChecker.globalTypes); - LoadDefinitionFileResult result = frontend.loadDefinitionFile(source, "@test"); - freeze(typeChecker.globalTypes); + unfreeze(frontend.globals.globalTypes); + LoadDefinitionFileResult result = frontend.loadDefinitionFile(source, "@test", /* captureComments */ false); + freeze(frontend.globals.globalTypes); if (result.module) dumpErrors(result.module); @@ -512,16 +511,16 @@ LoadDefinitionFileResult Fixture::loadDefinition(const std::string& source) BuiltinsFixture::BuiltinsFixture(bool freeze, bool prepareAutocomplete) : Fixture(freeze, prepareAutocomplete) { - Luau::unfreeze(frontend.typeChecker.globalTypes); - Luau::unfreeze(frontend.typeCheckerForAutocomplete.globalTypes); + Luau::unfreeze(frontend.globals.globalTypes); + Luau::unfreeze(frontend.globalsForAutocomplete.globalTypes); registerBuiltinGlobals(frontend); if (prepareAutocomplete) - registerBuiltinGlobals(frontend.typeCheckerForAutocomplete); + registerBuiltinGlobals(frontend.typeCheckerForAutocomplete, frontend.globalsForAutocomplete); registerTestTypes(); - Luau::freeze(frontend.typeChecker.globalTypes); - Luau::freeze(frontend.typeCheckerForAutocomplete.globalTypes); + Luau::freeze(frontend.globals.globalTypes); + Luau::freeze(frontend.globalsForAutocomplete.globalTypes); } ModuleName fromString(std::string_view name) @@ -581,23 +580,31 @@ std::optional linearSearchForBinding(Scope* scope, const char* name) void registerHiddenTypes(Frontend* frontend) { - TypeId t = frontend->globalTypes.addType(GenericType{"T"}); + GlobalTypes& globals = frontend->globals; + + unfreeze(globals.globalTypes); + + TypeId t = globals.globalTypes.addType(GenericType{"T"}); GenericTypeDefinition genericT{t}; - ScopePtr globalScope = frontend->getGlobalScope(); - globalScope->exportedTypeBindings["Not"] = TypeFun{{genericT}, frontend->globalTypes.addType(NegationType{t})}; + ScopePtr globalScope = globals.globalScope; + globalScope->exportedTypeBindings["Not"] = TypeFun{{genericT}, globals.globalTypes.addType(NegationType{t})}; globalScope->exportedTypeBindings["fun"] = TypeFun{{}, frontend->builtinTypes->functionType}; globalScope->exportedTypeBindings["cls"] = TypeFun{{}, frontend->builtinTypes->classType}; globalScope->exportedTypeBindings["err"] = TypeFun{{}, frontend->builtinTypes->errorType}; globalScope->exportedTypeBindings["tbl"] = TypeFun{{}, frontend->builtinTypes->tableType}; + + freeze(globals.globalTypes); } void createSomeClasses(Frontend* frontend) { - TypeArena& arena = frontend->globalTypes; + GlobalTypes& globals = frontend->globals; + + TypeArena& arena = globals.globalTypes; unfreeze(arena); - ScopePtr moduleScope = frontend->getGlobalScope(); + ScopePtr moduleScope = globals.globalScope; TypeId parentType = arena.addType(ClassType{"Parent", {}, frontend->builtinTypes->classType, std::nullopt, {}, nullptr, "Test"}); @@ -606,22 +613,22 @@ void createSomeClasses(Frontend* frontend) parentClass->props["virtual_method"] = {makeFunction(arena, parentType, {}, {})}; - addGlobalBinding(*frontend, "Parent", {parentType}); + addGlobalBinding(globals, "Parent", {parentType}); moduleScope->exportedTypeBindings["Parent"] = TypeFun{{}, parentType}; TypeId childType = arena.addType(ClassType{"Child", {}, parentType, std::nullopt, {}, nullptr, "Test"}); - addGlobalBinding(*frontend, "Child", {childType}); + addGlobalBinding(globals, "Child", {childType}); moduleScope->exportedTypeBindings["Child"] = TypeFun{{}, childType}; TypeId anotherChildType = arena.addType(ClassType{"AnotherChild", {}, parentType, std::nullopt, {}, nullptr, "Test"}); - addGlobalBinding(*frontend, "AnotherChild", {anotherChildType}); + addGlobalBinding(globals, "AnotherChild", {anotherChildType}); moduleScope->exportedTypeBindings["AnotherChild"] = TypeFun{{}, anotherChildType}; TypeId unrelatedType = arena.addType(ClassType{"Unrelated", {}, frontend->builtinTypes->classType, std::nullopt, {}, nullptr, "Test"}); - addGlobalBinding(*frontend, "Unrelated", {unrelatedType}); + addGlobalBinding(globals, "Unrelated", {unrelatedType}); moduleScope->exportedTypeBindings["Unrelated"] = TypeFun{{}, unrelatedType}; for (const auto& [name, ty] : moduleScope->exportedTypeBindings) diff --git a/tests/Fixture.h b/tests/Fixture.h index a81a5e783..5db6ed165 100644 --- a/tests/Fixture.h +++ b/tests/Fixture.h @@ -101,7 +101,6 @@ struct Fixture std::unique_ptr sourceModule; Frontend frontend; InternalErrorReporter ice; - TypeChecker& typeChecker; NotNull builtinTypes; std::string decorateWithTypes(const std::string& code); diff --git a/tests/Frontend.test.cpp b/tests/Frontend.test.cpp index 1d31b2813..e09990fb8 100644 --- a/tests/Frontend.test.cpp +++ b/tests/Frontend.test.cpp @@ -81,8 +81,8 @@ struct FrontendFixture : BuiltinsFixture { FrontendFixture() { - addGlobalBinding(frontend, "game", frontend.typeChecker.anyType, "@test"); - addGlobalBinding(frontend, "script", frontend.typeChecker.anyType, "@test"); + addGlobalBinding(frontend.globals, "game", builtinTypes->anyType, "@test"); + addGlobalBinding(frontend.globals, "script", builtinTypes->anyType, "@test"); } }; @@ -852,12 +852,12 @@ TEST_CASE_FIXTURE(FrontendFixture, "environments") { ScopePtr testScope = frontend.addEnvironment("test"); - unfreeze(typeChecker.globalTypes); - loadDefinitionFile(typeChecker, testScope, R"( + unfreeze(frontend.globals.globalTypes); + loadDefinitionFile(frontend.typeChecker, frontend.globals, testScope, R"( export type Foo = number | string )", - "@test"); - freeze(typeChecker.globalTypes); + "@test", /* captureComments */ false); + freeze(frontend.globals.globalTypes); fileResolver.source["A"] = R"( --!nonstrict diff --git a/tests/IrBuilder.test.cpp b/tests/IrBuilder.test.cpp index 0896517f9..41146d77a 100644 --- a/tests/IrBuilder.test.cpp +++ b/tests/IrBuilder.test.cpp @@ -109,7 +109,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "FinalX64OptCheckTag") optimizeMemoryOperandsX64(build.function); // Load from memory is 'inlined' into CHECK_TAG - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: CHECK_TAG R2, tnil, bb_fallback_1 CHECK_TAG K5, tnil, bb_fallback_1 @@ -135,7 +135,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "FinalX64OptBinaryArith") optimizeMemoryOperandsX64(build.function); // Load from memory is 'inlined' into second argument - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: %0 = LOAD_DOUBLE R1 %2 = ADD_NUM %0, R2 @@ -165,7 +165,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "FinalX64OptEqTag1") optimizeMemoryOperandsX64(build.function); // Load from memory is 'inlined' into first argument - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: %1 = LOAD_TAG R2 JUMP_EQ_TAG R1, %1, bb_1, bb_2 @@ -202,7 +202,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "FinalX64OptEqTag2") // Load from memory is 'inlined' into second argument is it can't be done for the first one // We also swap first and second argument to generate memory access on the LHS - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: %0 = LOAD_TAG R1 STORE_TAG R6, %0 @@ -239,7 +239,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "FinalX64OptEqTag3") optimizeMemoryOperandsX64(build.function); // Load from memory is 'inlined' into first argument - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: %0 = LOAD_POINTER R1 %1 = GET_ARR_ADDR %0, 0i @@ -276,7 +276,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "FinalX64OptJumpCmpNum") optimizeMemoryOperandsX64(build.function); // Load from memory is 'inlined' into first argument - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: %1 = LOAD_DOUBLE R2 JUMP_CMP_NUM R1, %1, bb_1, bb_2 @@ -328,7 +328,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "Numeric") updateUseCounts(build.function); constantFold(); - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: STORE_INT R0, 30i STORE_INT R0, -2147483648i @@ -374,7 +374,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "ControlFlowEq") updateUseCounts(build.function); constantFold(); - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: JUMP bb_1 @@ -423,7 +423,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "NumToIndex") updateUseCounts(build.function); constantFold(); - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: STORE_INT R0, 4i LOP_RETURN 0u @@ -458,7 +458,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "Guards") updateUseCounts(build.function); constantFold(); - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: LOP_RETURN 0u @@ -579,7 +579,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "RememberTagsAndValues") updateUseCounts(build.function); constPropInBlockChains(build); - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: STORE_TAG R0, tnumber STORE_INT R1, 10i @@ -625,7 +625,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "PropagateThroughTvalue") updateUseCounts(build.function); constPropInBlockChains(build); - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: STORE_TAG R0, tnumber STORE_DOUBLE R0, 0.5 @@ -655,7 +655,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "SkipCheckTag") updateUseCounts(build.function); constPropInBlockChains(build); - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: STORE_TAG R0, tnumber LOP_RETURN 0u @@ -682,7 +682,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "SkipOncePerBlockChecks") updateUseCounts(build.function); constPropInBlockChains(build); - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: CHECK_SAFE_ENV CHECK_GC @@ -721,7 +721,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "RememberTableState") updateUseCounts(build.function); constPropInBlockChains(build); - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: %0 = LOAD_POINTER R0 CHECK_NO_METATABLE %0, bb_fallback_1 @@ -753,7 +753,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "SkipUselessBarriers") updateUseCounts(build.function); constPropInBlockChains(build); - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: STORE_TAG R0, tnumber LOP_RETURN 0u @@ -782,7 +782,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "ConcatInvalidation") updateUseCounts(build.function); constPropInBlockChains(build); - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: STORE_TAG R0, tnumber STORE_INT R1, 10i @@ -829,7 +829,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "BuiltinFastcallsMayInvalidateMemory") updateUseCounts(build.function); constPropInBlockChains(build); - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: STORE_DOUBLE R0, 0.5 %1 = LOAD_POINTER R0 @@ -862,32 +862,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "RedundantStoreCheckConstantType") updateUseCounts(build.function); constPropInBlockChains(build); - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( -bb_0: - STORE_INT R0, 10i - STORE_DOUBLE R0, 0.5 - STORE_INT R0, 10i - LOP_RETURN 0u - -)"); -} - -TEST_CASE_FIXTURE(IrBuilderFixture, "RedundantStoreCheckConstantType") -{ - IrOp block = build.block(IrBlockKind::Internal); - - build.beginBlock(block); - - build.inst(IrCmd::STORE_INT, build.vmReg(0), build.constInt(10)); - build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.constDouble(0.5)); - build.inst(IrCmd::STORE_INT, build.vmReg(0), build.constInt(10)); - - build.inst(IrCmd::LOP_RETURN, build.constUint(0)); - - updateUseCounts(build.function); - constPropInBlockChains(build); - - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: STORE_INT R0, 10i STORE_DOUBLE R0, 0.5 @@ -917,7 +892,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "TagCheckPropagation") updateUseCounts(build.function); constPropInBlockChains(build); - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: %0 = LOAD_TAG R0 CHECK_TAG %0, tnumber, bb_fallback_1 @@ -949,7 +924,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "TagCheckPropagationConflicting") updateUseCounts(build.function); constPropInBlockChains(build); - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: %0 = LOAD_TAG R0 CHECK_TAG %0, tnumber, bb_fallback_1 @@ -985,7 +960,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "TruthyTestRemoval") updateUseCounts(build.function); constPropInBlockChains(build); - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: %0 = LOAD_TAG R1 CHECK_TAG %0, tnumber, bb_fallback_3 @@ -1024,7 +999,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "FalsyTestRemoval") updateUseCounts(build.function); constPropInBlockChains(build); - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: %0 = LOAD_TAG R1 CHECK_TAG %0, tnumber, bb_fallback_3 @@ -1059,7 +1034,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "TagEqRemoval") updateUseCounts(build.function); constPropInBlockChains(build); - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: %0 = LOAD_TAG R1 CHECK_TAG %0, tboolean @@ -1091,7 +1066,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "IntEqRemoval") updateUseCounts(build.function); constPropInBlockChains(build); - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: STORE_INT R1, 5i JUMP bb_1 @@ -1122,7 +1097,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "NumCmpRemoval") updateUseCounts(build.function); constPropInBlockChains(build); - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: STORE_DOUBLE R1, 4 JUMP bb_2 @@ -1150,7 +1125,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "DataFlowsThroughDirectJumpToUniqueSuccessor updateUseCounts(build.function); constPropInBlockChains(build); - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: STORE_TAG R0, tnumber JUMP bb_1 @@ -1183,7 +1158,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "DataDoesNotFlowThroughDirectJumpToNonUnique updateUseCounts(build.function); constPropInBlockChains(build); - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: STORE_TAG R0, tnumber JUMP bb_1 @@ -1199,6 +1174,120 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "DataDoesNotFlowThroughDirectJumpToNonUnique )"); } +TEST_CASE_FIXTURE(IrBuilderFixture, "EntryBlockUseRemoval") +{ + IrOp entry = build.block(IrBlockKind::Internal); + IrOp exit = build.block(IrBlockKind::Internal); + IrOp repeat = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tnumber)); + build.inst(IrCmd::JUMP_IF_TRUTHY, build.vmReg(0), exit, repeat); + + build.beginBlock(exit); + build.inst(IrCmd::LOP_RETURN, build.constUint(0), build.vmReg(0), build.constInt(0)); + + build.beginBlock(repeat); + build.inst(IrCmd::INTERRUPT, build.constUint(0)); + build.inst(IrCmd::JUMP, entry); + + updateUseCounts(build.function); + constPropInBlockChains(build); + + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( +bb_0: + STORE_TAG R0, tnumber + JUMP bb_1 + +bb_1: + LOP_RETURN 0u, R0, 0i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "RecursiveSccUseRemoval1") +{ + IrOp entry = build.block(IrBlockKind::Internal); + IrOp block = build.block(IrBlockKind::Internal); + IrOp exit = build.block(IrBlockKind::Internal); + IrOp repeat = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::LOP_RETURN, build.constUint(0), build.vmReg(0), build.constInt(0)); + + build.beginBlock(block); + build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tnumber)); + build.inst(IrCmd::JUMP_IF_TRUTHY, build.vmReg(0), exit, repeat); + + build.beginBlock(exit); + build.inst(IrCmd::LOP_RETURN, build.constUint(0), build.vmReg(0), build.constInt(0)); + + build.beginBlock(repeat); + build.inst(IrCmd::INTERRUPT, build.constUint(0)); + build.inst(IrCmd::JUMP, block); + + updateUseCounts(build.function); + constPropInBlockChains(build); + + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( +bb_0: + LOP_RETURN 0u, R0, 0i + +bb_1: + STORE_TAG R0, tnumber + JUMP bb_2 + +bb_2: + LOP_RETURN 0u, R0, 0i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "RecursiveSccUseRemoval2") +{ + IrOp entry = build.block(IrBlockKind::Internal); + IrOp exit1 = build.block(IrBlockKind::Internal); + IrOp block = build.block(IrBlockKind::Internal); + IrOp exit2 = build.block(IrBlockKind::Internal); + IrOp repeat = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::JUMP_EQ_INT, build.constInt(0), build.constInt(1), block, exit1); + + build.beginBlock(exit1); + build.inst(IrCmd::LOP_RETURN, build.constUint(0), build.vmReg(0), build.constInt(0)); + + build.beginBlock(block); + build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tnumber)); + build.inst(IrCmd::JUMP_IF_TRUTHY, build.vmReg(0), exit2, repeat); + + build.beginBlock(exit2); + build.inst(IrCmd::LOP_RETURN, build.constUint(0), build.vmReg(0), build.constInt(0)); + + build.beginBlock(repeat); + build.inst(IrCmd::INTERRUPT, build.constUint(0)); + build.inst(IrCmd::JUMP, block); + + updateUseCounts(build.function); + constPropInBlockChains(build); + + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( +bb_0: + JUMP bb_1 + +bb_1: + LOP_RETURN 0u, R0, 0i + +bb_2: + STORE_TAG R0, tnumber + JUMP bb_3 + +bb_3: + LOP_RETURN 0u, R0, 0i + +)"); +} + TEST_SUITE_END(); TEST_SUITE_BEGIN("LinearExecutionFlowExtraction"); @@ -1240,7 +1329,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "SimplePathExtraction") updateUseCounts(build.function); constPropInBlockChains(build); - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: %0 = LOAD_TAG R2 CHECK_TAG %0, tnumber, bb_fallback_1 @@ -1315,7 +1404,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "NoPathExtractionForBlocksWithLiveOutValues" updateUseCounts(build.function); constPropInBlockChains(build); - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: %0 = LOAD_TAG R2 CHECK_TAG %0, tnumber, bb_fallback_1 @@ -1366,7 +1455,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "InfiniteLoopInPathAnalysis") updateUseCounts(build.function); constPropInBlockChains(build); - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: STORE_TAG R0, tnumber JUMP bb_1 @@ -1379,3 +1468,212 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "InfiniteLoopInPathAnalysis") } TEST_SUITE_END(); + +TEST_SUITE_BEGIN("Analysis"); + +TEST_CASE_FIXTURE(IrBuilderFixture, "SimpleDiamond") +{ + IrOp entry = build.block(IrBlockKind::Internal); + IrOp a = build.block(IrBlockKind::Internal); + IrOp b = build.block(IrBlockKind::Internal); + IrOp exit = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::JUMP_EQ_TAG, build.inst(IrCmd::LOAD_TAG, build.vmReg(0)), build.constTag(tnumber), a, b); + + build.beginBlock(a); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(2), build.inst(IrCmd::LOAD_TVALUE, build.vmReg(1))); + build.inst(IrCmd::JUMP, exit); + + build.beginBlock(b); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(3), build.inst(IrCmd::LOAD_TVALUE, build.vmReg(1))); + build.inst(IrCmd::JUMP, exit); + + build.beginBlock(exit); + build.inst(IrCmd::LOP_RETURN, build.constUint(0), build.vmReg(2), build.constInt(2)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( +bb_0: +; successors: bb_1, bb_2 +; in regs: R0, R1, R2, R3 +; out regs: R1, R2, R3 + %0 = LOAD_TAG R0 + JUMP_EQ_TAG %0, tnumber, bb_1, bb_2 + +bb_1: +; predecessors: bb_0 +; successors: bb_3 +; in regs: R1, R3 +; out regs: R2, R3 + %2 = LOAD_TVALUE R1 + STORE_TVALUE R2, %2 + JUMP bb_3 + +bb_2: +; predecessors: bb_0 +; successors: bb_3 +; in regs: R1, R2 +; out regs: R2, R3 + %5 = LOAD_TVALUE R1 + STORE_TVALUE R3, %5 + JUMP bb_3 + +bb_3: +; predecessors: bb_1, bb_2 +; in regs: R2, R3 + LOP_RETURN 0u, R2, 2i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "ImplicitFixedRegistersInVarargCall") +{ + IrOp entry = build.block(IrBlockKind::Internal); + IrOp exit = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::FALLBACK_GETVARARGS, build.constUint(0), build.vmReg(3), build.constInt(-1)); + build.inst(IrCmd::LOP_CALL, build.constUint(0), build.vmReg(0), build.constInt(-1), build.constInt(5)); + build.inst(IrCmd::JUMP, exit); + + build.beginBlock(exit); + build.inst(IrCmd::LOP_RETURN, build.constUint(0), build.vmReg(0), build.constInt(5)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( +bb_0: +; successors: bb_1 +; in regs: R0, R1, R2 +; out regs: R0, R1, R2, R3, R4 + FALLBACK_GETVARARGS 0u, R3, -1i + LOP_CALL 0u, R0, -1i, 5i + JUMP bb_1 + +bb_1: +; predecessors: bb_0 +; in regs: R0, R1, R2, R3, R4 + LOP_RETURN 0u, R0, 5i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "ExplicitUseOfRegisterInVarargSequence") +{ + IrOp entry = build.block(IrBlockKind::Internal); + IrOp exit = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::FALLBACK_GETVARARGS, build.constUint(0), build.vmReg(1), build.constInt(-1)); + build.inst(IrCmd::INVOKE_FASTCALL, build.constUint(0), build.vmReg(0), build.vmReg(1), build.vmReg(2), build.constInt(-1), build.constInt(-1)); + build.inst(IrCmd::JUMP, exit); + + build.beginBlock(exit); + build.inst(IrCmd::LOP_RETURN, build.constUint(0), build.vmReg(0), build.constInt(-1)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( +bb_0: +; successors: bb_1 +; out regs: R0... + FALLBACK_GETVARARGS 0u, R1, -1i + %1 = INVOKE_FASTCALL 0u, R0, R1, R2, -1i, -1i + JUMP bb_1 + +bb_1: +; predecessors: bb_0 +; in regs: R0... + LOP_RETURN 0u, R0, -1i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "VariadicSequenceRestart") +{ + IrOp entry = build.block(IrBlockKind::Internal); + IrOp exit = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::LOP_CALL, build.constUint(0), build.vmReg(1), build.constInt(0), build.constInt(-1)); + build.inst(IrCmd::LOP_CALL, build.constUint(0), build.vmReg(0), build.constInt(-1), build.constInt(-1)); + build.inst(IrCmd::JUMP, exit); + + build.beginBlock(exit); + build.inst(IrCmd::LOP_RETURN, build.constUint(0), build.vmReg(0), build.constInt(-1)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( +bb_0: +; successors: bb_1 +; in regs: R0, R1 +; out regs: R0... + LOP_CALL 0u, R1, 0i, -1i + LOP_CALL 0u, R0, -1i, -1i + JUMP bb_1 + +bb_1: +; predecessors: bb_0 +; in regs: R0... + LOP_RETURN 0u, R0, -1i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "FallbackDoesNotFlowUp") +{ + IrOp entry = build.block(IrBlockKind::Internal); + IrOp fallback = build.block(IrBlockKind::Fallback); + IrOp exit = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::FALLBACK_GETVARARGS, build.constUint(0), build.vmReg(1), build.constInt(-1)); + build.inst(IrCmd::CHECK_TAG, build.inst(IrCmd::LOAD_TAG, build.vmReg(0)), build.constTag(tnumber), fallback); + build.inst(IrCmd::LOP_CALL, build.constUint(0), build.vmReg(0), build.constInt(-1), build.constInt(-1)); + build.inst(IrCmd::JUMP, exit); + + build.beginBlock(fallback); + build.inst(IrCmd::LOP_CALL, build.constUint(0), build.vmReg(0), build.constInt(-1), build.constInt(-1)); + build.inst(IrCmd::JUMP, exit); + + build.beginBlock(exit); + build.inst(IrCmd::LOP_RETURN, build.constUint(0), build.vmReg(0), build.constInt(-1)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( +bb_0: +; successors: bb_fallback_1, bb_2 +; in regs: R0 +; out regs: R0... + FALLBACK_GETVARARGS 0u, R1, -1i + %1 = LOAD_TAG R0 + CHECK_TAG %1, tnumber, bb_fallback_1 + LOP_CALL 0u, R0, -1i, -1i + JUMP bb_2 + +bb_fallback_1: +; predecessors: bb_0 +; successors: bb_2 +; in regs: R0, R1... +; out regs: R0... + LOP_CALL 0u, R0, -1i, -1i + JUMP bb_2 + +bb_2: +; predecessors: bb_0, bb_fallback_1 +; in regs: R0... + LOP_RETURN 0u, R0, -1i + +)"); +} + +TEST_SUITE_END(); diff --git a/tests/Linter.test.cpp b/tests/Linter.test.cpp index c716982ee..ebd004d38 100644 --- a/tests/Linter.test.cpp +++ b/tests/Linter.test.cpp @@ -35,7 +35,7 @@ TEST_CASE_FIXTURE(Fixture, "UnknownGlobal") TEST_CASE_FIXTURE(Fixture, "DeprecatedGlobal") { // Normally this would be defined externally, so hack it in for testing - addGlobalBinding(frontend, "Wait", Binding{typeChecker.anyType, {}, true, "wait", "@test/global/Wait"}); + addGlobalBinding(frontend.globals, "Wait", Binding{builtinTypes->anyType, {}, true, "wait", "@test/global/Wait"}); LintResult result = lint("Wait(5)"); @@ -47,7 +47,7 @@ TEST_CASE_FIXTURE(Fixture, "DeprecatedGlobalNoReplacement") { // Normally this would be defined externally, so hack it in for testing const char* deprecationReplacementString = ""; - addGlobalBinding(frontend, "Version", Binding{typeChecker.anyType, {}, true, deprecationReplacementString}); + addGlobalBinding(frontend.globals, "Version", Binding{builtinTypes->anyType, {}, true, deprecationReplacementString}); LintResult result = lint("Version()"); @@ -373,7 +373,7 @@ return bar() TEST_CASE_FIXTURE(Fixture, "ImportUnused") { // Normally this would be defined externally, so hack it in for testing - addGlobalBinding(frontend, "game", typeChecker.anyType, "@test"); + addGlobalBinding(frontend.globals, "game", builtinTypes->anyType, "@test"); LintResult result = lint(R"( local Roact = require(game.Packages.Roact) @@ -604,16 +604,16 @@ return foo1 TEST_CASE_FIXTURE(Fixture, "UnknownType") { - unfreeze(typeChecker.globalTypes); + unfreeze(frontend.globals.globalTypes); TableType::Props instanceProps{ - {"ClassName", {typeChecker.anyType}}, + {"ClassName", {builtinTypes->anyType}}, }; - TableType instanceTable{instanceProps, std::nullopt, typeChecker.globalScope->level, Luau::TableState::Sealed}; - TypeId instanceType = typeChecker.globalTypes.addType(instanceTable); + TableType instanceTable{instanceProps, std::nullopt, frontend.globals.globalScope->level, Luau::TableState::Sealed}; + TypeId instanceType = frontend.globals.globalTypes.addType(instanceTable); TypeFun instanceTypeFun{{}, instanceType}; - typeChecker.globalScope->exportedTypeBindings["Part"] = instanceTypeFun; + frontend.globals.globalScope->exportedTypeBindings["Part"] = instanceTypeFun; LintResult result = lint(R"( local game = ... @@ -1270,12 +1270,12 @@ TEST_CASE_FIXTURE(Fixture, "no_spurious_warning_after_a_function_type_alias") TEST_CASE_FIXTURE(Fixture, "use_all_parent_scopes_for_globals") { ScopePtr testScope = frontend.addEnvironment("Test"); - unfreeze(typeChecker.globalTypes); - loadDefinitionFile(frontend.typeChecker, testScope, R"( + unfreeze(frontend.globals.globalTypes); + loadDefinitionFile(frontend.typeChecker, frontend.globals, testScope, R"( declare Foo: number )", - "@test"); - freeze(typeChecker.globalTypes); + "@test", /* captureComments */ false); + freeze(frontend.globals.globalTypes); fileResolver.environments["A"] = "Test"; @@ -1444,31 +1444,32 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "DeprecatedApiTyped") { ScopedFastFlag sff("LuauImproveDeprecatedApiLint", true); - unfreeze(typeChecker.globalTypes); - TypeId instanceType = typeChecker.globalTypes.addType(ClassType{"Instance", {}, std::nullopt, std::nullopt, {}, {}, "Test"}); + unfreeze(frontend.globals.globalTypes); + TypeId instanceType = frontend.globals.globalTypes.addType(ClassType{"Instance", {}, std::nullopt, std::nullopt, {}, {}, "Test"}); persist(instanceType); - typeChecker.globalScope->exportedTypeBindings["Instance"] = TypeFun{{}, instanceType}; + frontend.globals.globalScope->exportedTypeBindings["Instance"] = TypeFun{{}, instanceType}; getMutable(instanceType)->props = { - {"Name", {typeChecker.stringType}}, - {"DataCost", {typeChecker.numberType, /* deprecated= */ true}}, - {"Wait", {typeChecker.anyType, /* deprecated= */ true}}, + {"Name", {builtinTypes->stringType}}, + {"DataCost", {builtinTypes->numberType, /* deprecated= */ true}}, + {"Wait", {builtinTypes->anyType, /* deprecated= */ true}}, }; - TypeId colorType = typeChecker.globalTypes.addType(TableType{{}, std::nullopt, typeChecker.globalScope->level, Luau::TableState::Sealed}); + TypeId colorType = + frontend.globals.globalTypes.addType(TableType{{}, std::nullopt, frontend.globals.globalScope->level, Luau::TableState::Sealed}); - getMutable(colorType)->props = {{"toHSV", {typeChecker.anyType, /* deprecated= */ true, "Color3:ToHSV"}}}; + getMutable(colorType)->props = {{"toHSV", {builtinTypes->anyType, /* deprecated= */ true, "Color3:ToHSV"}}}; - addGlobalBinding(frontend, "Color3", Binding{colorType, {}}); + addGlobalBinding(frontend.globals, "Color3", Binding{colorType, {}}); - if (TableType* ttv = getMutable(getGlobalBinding(typeChecker, "table"))) + if (TableType* ttv = getMutable(getGlobalBinding(frontend.globals, "table"))) { ttv->props["foreach"].deprecated = true; ttv->props["getn"].deprecated = true; ttv->props["getn"].deprecatedSuggestion = "#"; } - freeze(typeChecker.globalTypes); + freeze(frontend.globals.globalTypes); LintResult result = lintTyped(R"( return function (i: Instance) @@ -1495,7 +1496,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "DeprecatedApiUntyped") { ScopedFastFlag sff("LuauImproveDeprecatedApiLint", true); - if (TableType* ttv = getMutable(getGlobalBinding(typeChecker, "table"))) + if (TableType* ttv = getMutable(getGlobalBinding(frontend.globals, "table"))) { ttv->props["foreach"].deprecated = true; ttv->props["getn"].deprecated = true; diff --git a/tests/Module.test.cpp b/tests/Module.test.cpp index 2c45cc385..d2796b6d0 100644 --- a/tests/Module.test.cpp +++ b/tests/Module.test.cpp @@ -48,8 +48,8 @@ TEST_CASE_FIXTURE(Fixture, "dont_clone_persistent_primitive") CloneState cloneState; // numberType is persistent. We leave it as-is. - TypeId newNumber = clone(typeChecker.numberType, dest, cloneState); - CHECK_EQ(newNumber, typeChecker.numberType); + TypeId newNumber = clone(builtinTypes->numberType, dest, cloneState); + CHECK_EQ(newNumber, builtinTypes->numberType); } TEST_CASE_FIXTURE(Fixture, "deepClone_non_persistent_primitive") @@ -58,9 +58,9 @@ TEST_CASE_FIXTURE(Fixture, "deepClone_non_persistent_primitive") CloneState cloneState; // Create a new number type that isn't persistent - unfreeze(typeChecker.globalTypes); - TypeId oldNumber = typeChecker.globalTypes.addType(PrimitiveType{PrimitiveType::Number}); - freeze(typeChecker.globalTypes); + unfreeze(frontend.globals.globalTypes); + TypeId oldNumber = frontend.globals.globalTypes.addType(PrimitiveType{PrimitiveType::Number}); + freeze(frontend.globals.globalTypes); TypeId newNumber = clone(oldNumber, dest, cloneState); CHECK_NE(newNumber, oldNumber); @@ -170,10 +170,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "builtin_types_point_into_globalTypes_arena") REQUIRE(signType != nullptr); CHECK(!isInArena(signType, module->interfaceTypes)); - if (FFlag::DebugLuauDeferredConstraintResolution) - CHECK(isInArena(signType, frontend.globalTypes)); - else - CHECK(isInArena(signType, typeChecker.globalTypes)); + CHECK(isInArena(signType, frontend.globals.globalTypes)); } TEST_CASE_FIXTURE(Fixture, "deepClone_union") @@ -181,9 +178,9 @@ TEST_CASE_FIXTURE(Fixture, "deepClone_union") TypeArena dest; CloneState cloneState; - unfreeze(typeChecker.globalTypes); - TypeId oldUnion = typeChecker.globalTypes.addType(UnionType{{typeChecker.numberType, typeChecker.stringType}}); - freeze(typeChecker.globalTypes); + unfreeze(frontend.globals.globalTypes); + TypeId oldUnion = frontend.globals.globalTypes.addType(UnionType{{builtinTypes->numberType, builtinTypes->stringType}}); + freeze(frontend.globals.globalTypes); TypeId newUnion = clone(oldUnion, dest, cloneState); CHECK_NE(newUnion, oldUnion); @@ -196,9 +193,9 @@ TEST_CASE_FIXTURE(Fixture, "deepClone_intersection") TypeArena dest; CloneState cloneState; - unfreeze(typeChecker.globalTypes); - TypeId oldIntersection = typeChecker.globalTypes.addType(IntersectionType{{typeChecker.numberType, typeChecker.stringType}}); - freeze(typeChecker.globalTypes); + unfreeze(frontend.globals.globalTypes); + TypeId oldIntersection = frontend.globals.globalTypes.addType(IntersectionType{{builtinTypes->numberType, builtinTypes->stringType}}); + freeze(frontend.globals.globalTypes); TypeId newIntersection = clone(oldIntersection, dest, cloneState); CHECK_NE(newIntersection, oldIntersection); @@ -210,13 +207,13 @@ TEST_CASE_FIXTURE(Fixture, "clone_class") { Type exampleMetaClass{ClassType{"ExampleClassMeta", { - {"__add", {typeChecker.anyType}}, + {"__add", {builtinTypes->anyType}}, }, std::nullopt, std::nullopt, {}, {}, "Test"}}; Type exampleClass{ClassType{"ExampleClass", { - {"PropOne", {typeChecker.numberType}}, - {"PropTwo", {typeChecker.stringType}}, + {"PropOne", {builtinTypes->numberType}}, + {"PropTwo", {builtinTypes->stringType}}, }, std::nullopt, &exampleMetaClass, {}, {}, "Test"}}; diff --git a/tests/NonstrictMode.test.cpp b/tests/NonstrictMode.test.cpp index a84e26381..fddab8002 100644 --- a/tests/NonstrictMode.test.cpp +++ b/tests/NonstrictMode.test.cpp @@ -64,7 +64,7 @@ TEST_CASE_FIXTURE(Fixture, "return_annotation_is_still_checked") LUAU_REQUIRE_ERROR_COUNT(1, result); - REQUIRE_NE(*typeChecker.anyType, *requireType("foo")); + REQUIRE_NE(*builtinTypes->anyType, *requireType("foo")); } #endif @@ -107,7 +107,7 @@ TEST_CASE_FIXTURE(Fixture, "locals_are_any_by_default") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*typeChecker.anyType, *requireType("m")); + CHECK_EQ(*builtinTypes->anyType, *requireType("m")); } TEST_CASE_FIXTURE(Fixture, "parameters_having_type_any_are_optional") @@ -173,7 +173,7 @@ TEST_CASE_FIXTURE(Fixture, "table_props_are_any") TypeId fooProp = ttv->props["foo"].type; REQUIRE(fooProp != nullptr); - CHECK_EQ(*fooProp, *typeChecker.anyType); + CHECK_EQ(*fooProp, *builtinTypes->anyType); } TEST_CASE_FIXTURE(Fixture, "inline_table_props_are_also_any") @@ -192,8 +192,8 @@ TEST_CASE_FIXTURE(Fixture, "inline_table_props_are_also_any") TableType* ttv = getMutable(requireType("T")); REQUIRE_MESSAGE(ttv, "Should be a table: " << toString(requireType("T"))); - CHECK_EQ(*typeChecker.anyType, *ttv->props["one"].type); - CHECK_EQ(*typeChecker.anyType, *ttv->props["two"].type); + CHECK_EQ(*builtinTypes->anyType, *ttv->props["one"].type); + CHECK_EQ(*builtinTypes->anyType, *ttv->props["two"].type); CHECK_MESSAGE(get(follow(ttv->props["three"].type)), "Should be a function: " << *ttv->props["three"].type); } diff --git a/tests/Normalize.test.cpp b/tests/Normalize.test.cpp index c45932c6f..b86af0ebc 100644 --- a/tests/Normalize.test.cpp +++ b/tests/Normalize.test.cpp @@ -325,9 +325,9 @@ TEST_CASE_FIXTURE(IsSubtypeFixture, "classes") check(""); // Ensure that we have a main Module. - TypeId p = typeChecker.globalScope->lookupType("Parent")->type; - TypeId c = typeChecker.globalScope->lookupType("Child")->type; - TypeId u = typeChecker.globalScope->lookupType("Unrelated")->type; + TypeId p = frontend.globals.globalScope->lookupType("Parent")->type; + TypeId c = frontend.globals.globalScope->lookupType("Child")->type; + TypeId u = frontend.globals.globalScope->lookupType("Unrelated")->type; CHECK(isSubtype(c, p)); CHECK(!isSubtype(p, c)); diff --git a/tests/ToDot.test.cpp b/tests/ToDot.test.cpp index 11dca1106..73ae47739 100644 --- a/tests/ToDot.test.cpp +++ b/tests/ToDot.test.cpp @@ -15,7 +15,7 @@ struct ToDotClassFixture : Fixture { ToDotClassFixture() { - TypeArena& arena = typeChecker.globalTypes; + TypeArena& arena = frontend.globals.globalTypes; unfreeze(arena); @@ -23,17 +23,17 @@ struct ToDotClassFixture : Fixture TypeId baseClassInstanceType = arena.addType(ClassType{"BaseClass", {}, std::nullopt, baseClassMetaType, {}, {}, "Test"}); getMutable(baseClassInstanceType)->props = { - {"BaseField", {typeChecker.numberType}}, + {"BaseField", {builtinTypes->numberType}}, }; - typeChecker.globalScope->exportedTypeBindings["BaseClass"] = TypeFun{{}, baseClassInstanceType}; + frontend.globals.globalScope->exportedTypeBindings["BaseClass"] = TypeFun{{}, baseClassInstanceType}; TypeId childClassInstanceType = arena.addType(ClassType{"ChildClass", {}, baseClassInstanceType, std::nullopt, {}, {}, "Test"}); getMutable(childClassInstanceType)->props = { - {"ChildField", {typeChecker.stringType}}, + {"ChildField", {builtinTypes->stringType}}, }; - typeChecker.globalScope->exportedTypeBindings["ChildClass"] = TypeFun{{}, childClassInstanceType}; + frontend.globals.globalScope->exportedTypeBindings["ChildClass"] = TypeFun{{}, childClassInstanceType}; - for (const auto& [name, ty] : typeChecker.globalScope->exportedTypeBindings) + for (const auto& [name, ty] : frontend.globals.globalScope->exportedTypeBindings) persist(ty.type); freeze(arena); @@ -373,7 +373,7 @@ n1 [label="GenericTypePack T"]; TEST_CASE_FIXTURE(Fixture, "bound_pack") { - TypePackVar pack{TypePackVariant{TypePack{{typeChecker.numberType}, {}}}}; + TypePackVar pack{TypePackVariant{TypePack{{builtinTypes->numberType}, {}}}}; TypePackVar bound{TypePackVariant{BoundTypePack{&pack}}}; ToDotOptions opts; diff --git a/tests/ToString.test.cpp b/tests/ToString.test.cpp index 2fc5187b8..fd245395b 100644 --- a/tests/ToString.test.cpp +++ b/tests/ToString.test.cpp @@ -184,27 +184,27 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "exhaustive_toString_of_cyclic_table") TEST_CASE_FIXTURE(Fixture, "intersection_parenthesized_only_if_needed") { - auto utv = Type{UnionType{{typeChecker.numberType, typeChecker.stringType}}}; - auto itv = Type{IntersectionType{{&utv, typeChecker.booleanType}}}; + auto utv = Type{UnionType{{builtinTypes->numberType, builtinTypes->stringType}}}; + auto itv = Type{IntersectionType{{&utv, builtinTypes->booleanType}}}; CHECK_EQ(toString(&itv), "(number | string) & boolean"); } TEST_CASE_FIXTURE(Fixture, "union_parenthesized_only_if_needed") { - auto itv = Type{IntersectionType{{typeChecker.numberType, typeChecker.stringType}}}; - auto utv = Type{UnionType{{&itv, typeChecker.booleanType}}}; + auto itv = Type{IntersectionType{{builtinTypes->numberType, builtinTypes->stringType}}}; + auto utv = Type{UnionType{{&itv, builtinTypes->booleanType}}}; CHECK_EQ(toString(&utv), "(number & string) | boolean"); } TEST_CASE_FIXTURE(Fixture, "functions_are_always_parenthesized_in_unions_or_intersections") { - auto stringAndNumberPack = TypePackVar{TypePack{{typeChecker.stringType, typeChecker.numberType}}}; - auto numberAndStringPack = TypePackVar{TypePack{{typeChecker.numberType, typeChecker.stringType}}}; + auto stringAndNumberPack = TypePackVar{TypePack{{builtinTypes->stringType, builtinTypes->numberType}}}; + auto numberAndStringPack = TypePackVar{TypePack{{builtinTypes->numberType, builtinTypes->stringType}}}; auto sn2ns = Type{FunctionType{&stringAndNumberPack, &numberAndStringPack}}; - auto ns2sn = Type{FunctionType(typeChecker.globalScope->level, &numberAndStringPack, &stringAndNumberPack)}; + auto ns2sn = Type{FunctionType(frontend.globals.globalScope->level, &numberAndStringPack, &stringAndNumberPack)}; auto utv = Type{UnionType{{&ns2sn, &sn2ns}}}; auto itv = Type{IntersectionType{{&ns2sn, &sn2ns}}}; @@ -250,7 +250,7 @@ TEST_CASE_FIXTURE(Fixture, "quit_stringifying_table_type_when_length_is_exceeded { TableType ttv{}; for (char c : std::string("abcdefghijklmno")) - ttv.props[std::string(1, c)] = {typeChecker.numberType}; + ttv.props[std::string(1, c)] = {builtinTypes->numberType}; Type tv{ttv}; @@ -264,7 +264,7 @@ TEST_CASE_FIXTURE(Fixture, "stringifying_table_type_is_still_capped_when_exhaust { TableType ttv{}; for (char c : std::string("abcdefg")) - ttv.props[std::string(1, c)] = {typeChecker.numberType}; + ttv.props[std::string(1, c)] = {builtinTypes->numberType}; Type tv{ttv}; @@ -339,7 +339,7 @@ TEST_CASE_FIXTURE(Fixture, "stringifying_table_type_correctly_use_matching_table { TableType ttv{TableState::Sealed, TypeLevel{}}; for (char c : std::string("abcdefghij")) - ttv.props[std::string(1, c)] = {typeChecker.numberType}; + ttv.props[std::string(1, c)] = {builtinTypes->numberType}; Type tv{ttv}; @@ -350,7 +350,7 @@ TEST_CASE_FIXTURE(Fixture, "stringifying_table_type_correctly_use_matching_table TEST_CASE_FIXTURE(Fixture, "stringifying_cyclic_union_type_bails_early") { - Type tv{UnionType{{typeChecker.stringType, typeChecker.numberType}}}; + Type tv{UnionType{{builtinTypes->stringType, builtinTypes->numberType}}}; UnionType* utv = getMutable(&tv); utv->options.push_back(&tv); utv->options.push_back(&tv); @@ -371,11 +371,11 @@ TEST_CASE_FIXTURE(Fixture, "stringifying_cyclic_intersection_type_bails_early") TEST_CASE_FIXTURE(Fixture, "stringifying_array_uses_array_syntax") { TableType ttv{TableState::Sealed, TypeLevel{}}; - ttv.indexer = TableIndexer{typeChecker.numberType, typeChecker.stringType}; + ttv.indexer = TableIndexer{builtinTypes->numberType, builtinTypes->stringType}; CHECK_EQ("{string}", toString(Type{ttv})); - ttv.props["A"] = {typeChecker.numberType}; + ttv.props["A"] = {builtinTypes->numberType}; CHECK_EQ("{| [number]: string, A: number |}", toString(Type{ttv})); ttv.props.clear(); @@ -562,15 +562,15 @@ TEST_CASE_FIXTURE(Fixture, "toString_the_boundTo_table_type_contained_within_a_T Type tv1{TableType{}}; TableType* ttv = getMutable(&tv1); ttv->state = TableState::Sealed; - ttv->props["hello"] = {typeChecker.numberType}; - ttv->props["world"] = {typeChecker.numberType}; + ttv->props["hello"] = {builtinTypes->numberType}; + ttv->props["world"] = {builtinTypes->numberType}; TypePackVar tpv1{TypePack{{&tv1}}}; Type tv2{TableType{}}; TableType* bttv = getMutable(&tv2); bttv->state = TableState::Free; - bttv->props["hello"] = {typeChecker.numberType}; + bttv->props["hello"] = {builtinTypes->numberType}; bttv->boundTo = &tv1; TypePackVar tpv2{TypePack{{&tv2}}}; diff --git a/tests/TypeInfer.aliases.test.cpp b/tests/TypeInfer.aliases.test.cpp index a2fc0c75e..b55c77460 100644 --- a/tests/TypeInfer.aliases.test.cpp +++ b/tests/TypeInfer.aliases.test.cpp @@ -168,7 +168,7 @@ TEST_CASE_FIXTURE(Fixture, "cyclic_types_of_named_table_fields_do_not_expand_whe TypeMismatch* tm = get(result.errors[0]); REQUIRE(tm); CHECK_EQ("Node?", toString(tm->wantedType)); - CHECK_EQ(typeChecker.numberType, tm->givenType); + CHECK_EQ(builtinTypes->numberType, tm->givenType); } TEST_CASE_FIXTURE(Fixture, "mutually_recursive_aliases") @@ -329,7 +329,7 @@ TEST_CASE_FIXTURE(Fixture, "stringify_type_alias_of_recursive_template_table_typ TypeMismatch* tm = get(result.errors[0]); REQUIRE(tm); CHECK_EQ("Wrapped", toString(tm->wantedType)); - CHECK_EQ(typeChecker.numberType, tm->givenType); + CHECK_EQ(builtinTypes->numberType, tm->givenType); } TEST_CASE_FIXTURE(Fixture, "stringify_type_alias_of_recursive_template_table_type2") @@ -345,7 +345,7 @@ TEST_CASE_FIXTURE(Fixture, "stringify_type_alias_of_recursive_template_table_typ TypeMismatch* tm = get(result.errors[0]); REQUIRE(tm); CHECK_EQ("t1 where t1 = ({| a: t1 |}) -> string", toString(tm->wantedType)); - CHECK_EQ(typeChecker.numberType, tm->givenType); + CHECK_EQ(builtinTypes->numberType, tm->givenType); } // Check that recursive intersection type doesn't generate an OOM @@ -520,7 +520,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "type_alias_import_mutation") CheckResult result = check("type t10 = typeof(table)"); LUAU_REQUIRE_NO_ERRORS(result); - TypeId ty = getGlobalBinding(frontend, "table"); + TypeId ty = getGlobalBinding(frontend.globals, "table"); CHECK(toString(ty) == "typeof(table)"); @@ -922,4 +922,29 @@ TEST_CASE_FIXTURE(Fixture, "cannot_create_cyclic_type_with_unknown_module") CHECK(toString(result.errors[0]) == "Unknown type 'B.AAA'"); } +TEST_CASE_FIXTURE(Fixture, "type_alias_locations") +{ + check(R"( + type T = number + + do + type T = string + type X = boolean + end + )"); + + ModulePtr mod = getMainModule(); + REQUIRE(mod); + REQUIRE(mod->scopes.size() == 8); + + REQUIRE(mod->scopes[0].second->typeAliasNameLocations.count("T") > 0); + CHECK(mod->scopes[0].second->typeAliasNameLocations["T"] == Location(Position(1, 13), 1)); + + REQUIRE(mod->scopes[3].second->typeAliasNameLocations.count("T") > 0); + CHECK(mod->scopes[3].second->typeAliasNameLocations["T"] == Location(Position(4, 17), 1)); + + REQUIRE(mod->scopes[3].second->typeAliasNameLocations.count("X") > 0); + CHECK(mod->scopes[3].second->typeAliasNameLocations["X"] == Location(Position(5, 17), 1)); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.annotations.test.cpp b/tests/TypeInfer.annotations.test.cpp index d5f953746..2c87cb419 100644 --- a/tests/TypeInfer.annotations.test.cpp +++ b/tests/TypeInfer.annotations.test.cpp @@ -86,7 +86,7 @@ TEST_CASE_FIXTURE(Fixture, "function_return_annotations_are_checked") REQUIRE_EQ(1, tp->head.size()); - REQUIRE_EQ(typeChecker.anyType, follow(tp->head[0])); + REQUIRE_EQ(builtinTypes->anyType, follow(tp->head[0])); } TEST_CASE_FIXTURE(Fixture, "function_return_multret_annotations_are_checked") @@ -166,11 +166,12 @@ TEST_CASE_FIXTURE(Fixture, "infer_type_of_value_a_via_typeof_with_assignment") a = "foo" )"); - CHECK_EQ(*typeChecker.numberType, *requireType("a")); - CHECK_EQ(*typeChecker.numberType, *requireType("b")); + CHECK_EQ(*builtinTypes->numberType, *requireType("a")); + CHECK_EQ(*builtinTypes->numberType, *requireType("b")); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(result.errors[0], (TypeError{Location{Position{4, 12}, Position{4, 17}}, TypeMismatch{typeChecker.numberType, typeChecker.stringType}})); + CHECK_EQ( + result.errors[0], (TypeError{Location{Position{4, 12}, Position{4, 17}}, TypeMismatch{builtinTypes->numberType, builtinTypes->stringType}})); } TEST_CASE_FIXTURE(Fixture, "table_annotation") @@ -459,7 +460,7 @@ TEST_CASE_FIXTURE(Fixture, "type_alias_always_resolve_to_a_real_type") )"); TypeId fType = requireType("aa"); - REQUIRE(follow(fType) == typeChecker.numberType); + REQUIRE(follow(fType) == builtinTypes->numberType); LUAU_REQUIRE_NO_ERRORS(result); } @@ -480,7 +481,7 @@ TEST_CASE_FIXTURE(Fixture, "interface_types_belong_to_interface_arena") const TypeFun& a = mod.exportedTypeBindings["A"]; CHECK(isInArena(a.type, mod.interfaceTypes)); - CHECK(!isInArena(a.type, typeChecker.globalTypes)); + CHECK(!isInArena(a.type, frontend.globals.globalTypes)); std::optional exportsType = first(mod.returnType); REQUIRE(exportsType); @@ -559,7 +560,7 @@ TEST_CASE_FIXTURE(Fixture, "cloned_interface_maintains_pointers_between_definiti TEST_CASE_FIXTURE(BuiltinsFixture, "use_type_required_from_another_file") { - addGlobalBinding(frontend, "script", frontend.typeChecker.anyType, "@test"); + addGlobalBinding(frontend.globals, "script", builtinTypes->anyType, "@test"); fileResolver.source["Modules/Main"] = R"( --!strict @@ -585,7 +586,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "use_type_required_from_another_file") TEST_CASE_FIXTURE(BuiltinsFixture, "cannot_use_nonexported_type") { - addGlobalBinding(frontend, "script", frontend.typeChecker.anyType, "@test"); + addGlobalBinding(frontend.globals, "script", builtinTypes->anyType, "@test"); fileResolver.source["Modules/Main"] = R"( --!strict @@ -611,7 +612,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "cannot_use_nonexported_type") TEST_CASE_FIXTURE(BuiltinsFixture, "builtin_types_are_not_exported") { - addGlobalBinding(frontend, "script", frontend.typeChecker.anyType, "@test"); + addGlobalBinding(frontend.globals, "script", builtinTypes->anyType, "@test"); fileResolver.source["Modules/Main"] = R"( --!strict diff --git a/tests/TypeInfer.anyerror.test.cpp b/tests/TypeInfer.anyerror.test.cpp index 9988a1fc5..0488196bb 100644 --- a/tests/TypeInfer.anyerror.test.cpp +++ b/tests/TypeInfer.anyerror.test.cpp @@ -30,7 +30,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_returns_any") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(typeChecker.anyType, requireType("a")); + CHECK_EQ(builtinTypes->anyType, requireType("a")); } TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_returns_any2") @@ -209,7 +209,7 @@ TEST_CASE_FIXTURE(Fixture, "quantify_any_does_not_bind_to_itself") LUAU_REQUIRE_NO_ERRORS(result); TypeId aType = requireType("A"); - CHECK_EQ(aType, typeChecker.anyType); + CHECK_EQ(aType, builtinTypes->anyType); } TEST_CASE_FIXTURE(Fixture, "calling_error_type_yields_error") diff --git a/tests/TypeInfer.builtins.test.cpp b/tests/TypeInfer.builtins.test.cpp index 860dcfd03..5318b402e 100644 --- a/tests/TypeInfer.builtins.test.cpp +++ b/tests/TypeInfer.builtins.test.cpp @@ -106,7 +106,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "table_concat_returns_string") )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*typeChecker.stringType, *requireType("r")); + CHECK_EQ(*builtinTypes->stringType, *requireType("r")); } TEST_CASE_FIXTURE(BuiltinsFixture, "sort") @@ -156,7 +156,7 @@ TEST_CASE_FIXTURE(Fixture, "strings_have_methods") )LUA"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*typeChecker.stringType, *requireType("s")); + CHECK_EQ(*builtinTypes->stringType, *requireType("s")); } TEST_CASE_FIXTURE(BuiltinsFixture, "math_max_variatic") @@ -166,7 +166,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "math_max_variatic") )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*typeChecker.numberType, *requireType("n")); + CHECK_EQ(*builtinTypes->numberType, *requireType("n")); } TEST_CASE_FIXTURE(BuiltinsFixture, "math_max_checks_for_numbers") @@ -365,7 +365,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "table_insert_correctly_infers_type_of_array_ )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(typeChecker.stringType, requireType("s")); + CHECK_EQ(builtinTypes->stringType, requireType("s")); } TEST_CASE_FIXTURE(BuiltinsFixture, "table_insert_correctly_infers_type_of_array_3_args_overload") @@ -429,7 +429,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "gcinfo") )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*typeChecker.numberType, *requireType("n")); + CHECK_EQ(*builtinTypes->numberType, *requireType("n")); } TEST_CASE_FIXTURE(BuiltinsFixture, "getfenv") @@ -446,9 +446,9 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "os_time_takes_optional_date_table") )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*typeChecker.numberType, *requireType("n1")); - CHECK_EQ(*typeChecker.numberType, *requireType("n2")); - CHECK_EQ(*typeChecker.numberType, *requireType("n3")); + CHECK_EQ(*builtinTypes->numberType, *requireType("n1")); + CHECK_EQ(*builtinTypes->numberType, *requireType("n2")); + CHECK_EQ(*builtinTypes->numberType, *requireType("n3")); } TEST_CASE_FIXTURE(BuiltinsFixture, "thread_is_a_type") @@ -552,8 +552,8 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "string_format_correctly_ordered_types") LUAU_REQUIRE_ERROR_COUNT(1, result); TypeMismatch* tm = get(result.errors[0]); REQUIRE(tm); - CHECK_EQ(tm->wantedType, typeChecker.stringType); - CHECK_EQ(tm->givenType, typeChecker.numberType); + CHECK_EQ(tm->wantedType, builtinTypes->stringType); + CHECK_EQ(tm->givenType, builtinTypes->numberType); } TEST_CASE_FIXTURE(BuiltinsFixture, "string_format_tostring_specifier") @@ -722,8 +722,8 @@ TEST_CASE_FIXTURE(Fixture, "string_format_as_method") TypeMismatch* tm = get(result.errors[0]); REQUIRE(tm); - CHECK_EQ(tm->wantedType, typeChecker.stringType); - CHECK_EQ(tm->givenType, typeChecker.numberType); + CHECK_EQ(tm->wantedType, builtinTypes->stringType); + CHECK_EQ(tm->givenType, builtinTypes->numberType); } TEST_CASE_FIXTURE(Fixture, "string_format_use_correct_argument") @@ -860,9 +860,9 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "string_format_report_all_type_errors_at_corr string.format("%s%d%s", 1, "hello", true) )"); - TypeId stringType = typeChecker.stringType; - TypeId numberType = typeChecker.numberType; - TypeId booleanType = typeChecker.booleanType; + TypeId stringType = builtinTypes->stringType; + TypeId numberType = builtinTypes->numberType; + TypeId booleanType = builtinTypes->booleanType; LUAU_REQUIRE_ERROR_COUNT(6, result); @@ -1027,7 +1027,7 @@ local function f(a: typeof(f)) end TEST_CASE_FIXTURE(BuiltinsFixture, "no_persistent_typelevel_change") { - TypeId mathTy = requireType(typeChecker.globalScope, "math"); + TypeId mathTy = requireType(frontend.globals.globalScope, "math"); REQUIRE(mathTy); TableType* ttv = getMutable(mathTy); REQUIRE(ttv); diff --git a/tests/TypeInfer.definitions.test.cpp b/tests/TypeInfer.definitions.test.cpp index 2a681d1a6..f3f464130 100644 --- a/tests/TypeInfer.definitions.test.cpp +++ b/tests/TypeInfer.definitions.test.cpp @@ -19,13 +19,13 @@ TEST_CASE_FIXTURE(Fixture, "definition_file_simple") declare foo2: typeof(foo) )"); - TypeId globalFooTy = getGlobalBinding(frontend, "foo"); + TypeId globalFooTy = getGlobalBinding(frontend.globals, "foo"); CHECK_EQ(toString(globalFooTy), "number"); - TypeId globalBarTy = getGlobalBinding(frontend, "bar"); + TypeId globalBarTy = getGlobalBinding(frontend.globals, "bar"); CHECK_EQ(toString(globalBarTy), "(number) -> string"); - TypeId globalFoo2Ty = getGlobalBinding(frontend, "foo2"); + TypeId globalFoo2Ty = getGlobalBinding(frontend.globals, "foo2"); CHECK_EQ(toString(globalFoo2Ty), "number"); CheckResult result = check(R"( @@ -48,20 +48,20 @@ TEST_CASE_FIXTURE(Fixture, "definition_file_loading") declare function var(...: any): string )"); - TypeId globalFooTy = getGlobalBinding(frontend, "foo"); + TypeId globalFooTy = getGlobalBinding(frontend.globals, "foo"); CHECK_EQ(toString(globalFooTy), "number"); - std::optional globalAsdfTy = frontend.getGlobalScope()->lookupType("Asdf"); + std::optional globalAsdfTy = frontend.globals.globalScope->lookupType("Asdf"); REQUIRE(bool(globalAsdfTy)); CHECK_EQ(toString(globalAsdfTy->type), "number | string"); - TypeId globalBarTy = getGlobalBinding(frontend, "bar"); + TypeId globalBarTy = getGlobalBinding(frontend.globals, "bar"); CHECK_EQ(toString(globalBarTy), "(number) -> string"); - TypeId globalFoo2Ty = getGlobalBinding(frontend, "foo2"); + TypeId globalFoo2Ty = getGlobalBinding(frontend.globals, "foo2"); CHECK_EQ(toString(globalFoo2Ty), "number"); - TypeId globalVarTy = getGlobalBinding(frontend, "var"); + TypeId globalVarTy = getGlobalBinding(frontend.globals, "var"); CHECK_EQ(toString(globalVarTy), "(...any) -> string"); @@ -77,25 +77,25 @@ TEST_CASE_FIXTURE(Fixture, "definition_file_loading") TEST_CASE_FIXTURE(Fixture, "load_definition_file_errors_do_not_pollute_global_scope") { - unfreeze(typeChecker.globalTypes); - LoadDefinitionFileResult parseFailResult = loadDefinitionFile(typeChecker, typeChecker.globalScope, R"( + unfreeze(frontend.globals.globalTypes); + LoadDefinitionFileResult parseFailResult = loadDefinitionFile(frontend.typeChecker, frontend.globals, frontend.globals.globalScope, R"( declare foo )", - "@test"); - freeze(typeChecker.globalTypes); + "@test", /* captureComments */ false); + freeze(frontend.globals.globalTypes); REQUIRE(!parseFailResult.success); - std::optional fooTy = tryGetGlobalBinding(frontend, "foo"); + std::optional fooTy = tryGetGlobalBinding(frontend.globals, "foo"); CHECK(!fooTy.has_value()); - LoadDefinitionFileResult checkFailResult = loadDefinitionFile(typeChecker, typeChecker.globalScope, R"( + LoadDefinitionFileResult checkFailResult = loadDefinitionFile(frontend.typeChecker, frontend.globals, frontend.globals.globalScope, R"( local foo: string = 123 declare bar: typeof(foo) )", - "@test"); + "@test", /* captureComments */ false); REQUIRE(!checkFailResult.success); - std::optional barTy = tryGetGlobalBinding(frontend, "bar"); + std::optional barTy = tryGetGlobalBinding(frontend.globals, "bar"); CHECK(!barTy.has_value()); } @@ -139,15 +139,15 @@ TEST_CASE_FIXTURE(Fixture, "definition_file_classes") TEST_CASE_FIXTURE(Fixture, "class_definitions_cannot_overload_non_function") { - unfreeze(typeChecker.globalTypes); - LoadDefinitionFileResult result = loadDefinitionFile(typeChecker, typeChecker.globalScope, R"( + unfreeze(frontend.globals.globalTypes); + LoadDefinitionFileResult result = loadDefinitionFile(frontend.typeChecker, frontend.globals, frontend.globals.globalScope, R"( declare class A X: number X: string end )", - "@test"); - freeze(typeChecker.globalTypes); + "@test", /* captureComments */ false); + freeze(frontend.globals.globalTypes); REQUIRE(!result.success); CHECK_EQ(result.parseResult.errors.size(), 0); @@ -160,15 +160,15 @@ TEST_CASE_FIXTURE(Fixture, "class_definitions_cannot_overload_non_function") TEST_CASE_FIXTURE(Fixture, "class_definitions_cannot_extend_non_class") { - unfreeze(typeChecker.globalTypes); - LoadDefinitionFileResult result = loadDefinitionFile(typeChecker, typeChecker.globalScope, R"( + unfreeze(frontend.globals.globalTypes); + LoadDefinitionFileResult result = loadDefinitionFile(frontend.typeChecker, frontend.globals, frontend.globals.globalScope, R"( type NotAClass = {} declare class Foo extends NotAClass end )", - "@test"); - freeze(typeChecker.globalTypes); + "@test", /* captureComments */ false); + freeze(frontend.globals.globalTypes); REQUIRE(!result.success); CHECK_EQ(result.parseResult.errors.size(), 0); @@ -181,16 +181,16 @@ TEST_CASE_FIXTURE(Fixture, "class_definitions_cannot_extend_non_class") TEST_CASE_FIXTURE(Fixture, "no_cyclic_defined_classes") { - unfreeze(typeChecker.globalTypes); - LoadDefinitionFileResult result = loadDefinitionFile(typeChecker, typeChecker.globalScope, R"( + unfreeze(frontend.globals.globalTypes); + LoadDefinitionFileResult result = loadDefinitionFile(frontend.typeChecker, frontend.globals, frontend.globals.globalScope, R"( declare class Foo extends Bar end declare class Bar extends Foo end )", - "@test"); - freeze(typeChecker.globalTypes); + "@test", /* captureComments */ false); + freeze(frontend.globals.globalTypes); REQUIRE(!result.success); } @@ -281,16 +281,16 @@ TEST_CASE_FIXTURE(Fixture, "definitions_documentation_symbols") } )"); - std::optional xBinding = typeChecker.globalScope->linearSearchForBinding("x"); + std::optional xBinding = frontend.globals.globalScope->linearSearchForBinding("x"); REQUIRE(bool(xBinding)); // note: loadDefinition uses the @test package name. CHECK_EQ(xBinding->documentationSymbol, "@test/global/x"); - std::optional fooTy = typeChecker.globalScope->lookupType("Foo"); + std::optional fooTy = frontend.globals.globalScope->lookupType("Foo"); REQUIRE(bool(fooTy)); CHECK_EQ(fooTy->type->documentationSymbol, "@test/globaltype/Foo"); - std::optional barTy = typeChecker.globalScope->lookupType("Bar"); + std::optional barTy = frontend.globals.globalScope->lookupType("Bar"); REQUIRE(bool(barTy)); CHECK_EQ(barTy->type->documentationSymbol, "@test/globaltype/Bar"); @@ -299,7 +299,7 @@ TEST_CASE_FIXTURE(Fixture, "definitions_documentation_symbols") REQUIRE_EQ(barClass->props.count("prop"), 1); CHECK_EQ(barClass->props["prop"].documentationSymbol, "@test/globaltype/Bar.prop"); - std::optional yBinding = typeChecker.globalScope->linearSearchForBinding("y"); + std::optional yBinding = frontend.globals.globalScope->linearSearchForBinding("y"); REQUIRE(bool(yBinding)); CHECK_EQ(yBinding->documentationSymbol, "@test/global/y"); @@ -319,7 +319,7 @@ TEST_CASE_FIXTURE(Fixture, "definitions_symbols_are_generated_for_recursively_re declare function myFunc(): MyClass )"); - std::optional myClassTy = typeChecker.globalScope->lookupType("MyClass"); + std::optional myClassTy = frontend.globals.globalScope->lookupType("MyClass"); REQUIRE(bool(myClassTy)); CHECK_EQ(myClassTy->type->documentationSymbol, "@test/globaltype/MyClass"); } @@ -330,7 +330,7 @@ TEST_CASE_FIXTURE(Fixture, "documentation_symbols_dont_attach_to_persistent_type export type Evil = string )"); - std::optional ty = typeChecker.globalScope->lookupType("Evil"); + std::optional ty = frontend.globals.globalScope->lookupType("Evil"); REQUIRE(bool(ty)); CHECK_EQ(ty->type->documentationSymbol, std::nullopt); } @@ -396,8 +396,8 @@ TEST_CASE_FIXTURE(Fixture, "class_definition_string_props") TEST_CASE_FIXTURE(Fixture, "class_definitions_reference_other_classes") { - unfreeze(typeChecker.globalTypes); - LoadDefinitionFileResult result = loadDefinitionFile(typeChecker, typeChecker.globalScope, R"( + unfreeze(frontend.globals.globalTypes); + LoadDefinitionFileResult result = loadDefinitionFile(frontend.typeChecker, frontend.globals, frontend.globals.globalScope, R"( declare class Channel Messages: { Message } OnMessage: (message: Message) -> () @@ -408,8 +408,8 @@ TEST_CASE_FIXTURE(Fixture, "class_definitions_reference_other_classes") Channel: Channel end )", - "@test"); - freeze(typeChecker.globalTypes); + "@test", /* captureComments */ false); + freeze(frontend.globals.globalTypes); REQUIRE(result.success); } diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index 7c2e451a6..482a6b7f5 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -33,8 +33,8 @@ TEST_CASE_FIXTURE(Fixture, "check_function_bodies") LUAU_REQUIRE_ERROR_COUNT(1, result); CHECK_EQ(result.errors[0], (TypeError{Location{Position{0, 44}, Position{0, 48}}, TypeMismatch{ - typeChecker.numberType, - typeChecker.booleanType, + builtinTypes->numberType, + builtinTypes->booleanType, }})); } @@ -70,7 +70,7 @@ TEST_CASE_FIXTURE(Fixture, "infer_return_type") std::vector retVec = flatten(takeFiveType->retTypes).first; REQUIRE(!retVec.empty()); - REQUIRE_EQ(*follow(retVec[0]), *typeChecker.numberType); + REQUIRE_EQ(*follow(retVec[0]), *builtinTypes->numberType); } TEST_CASE_FIXTURE(Fixture, "infer_from_function_return_type") @@ -78,7 +78,7 @@ TEST_CASE_FIXTURE(Fixture, "infer_from_function_return_type") CheckResult result = check("function take_five() return 5 end local five = take_five()"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*typeChecker.numberType, *follow(requireType("five"))); + CHECK_EQ(*builtinTypes->numberType, *follow(requireType("five"))); } TEST_CASE_FIXTURE(Fixture, "infer_that_function_does_not_return_a_table") @@ -92,7 +92,7 @@ TEST_CASE_FIXTURE(Fixture, "infer_that_function_does_not_return_a_table") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(result.errors[0], (TypeError{Location{Position{5, 8}, Position{5, 24}}, NotATable{typeChecker.numberType}})); + CHECK_EQ(result.errors[0], (TypeError{Location{Position{5, 8}, Position{5, 24}}, NotATable{builtinTypes->numberType}})); } TEST_CASE_FIXTURE(Fixture, "generalize_table_property") @@ -171,8 +171,8 @@ TEST_CASE_FIXTURE(Fixture, "list_only_alternative_overloads_that_match_argument_ TypeMismatch* tm = get(result.errors[0]); REQUIRE(tm); - CHECK_EQ(typeChecker.numberType, tm->wantedType); - CHECK_EQ(typeChecker.stringType, tm->givenType); + CHECK_EQ(builtinTypes->numberType, tm->wantedType); + CHECK_EQ(builtinTypes->stringType, tm->givenType); ExtraInformation* ei = get(result.errors[1]); REQUIRE(ei); @@ -208,8 +208,8 @@ TEST_CASE_FIXTURE(Fixture, "dont_give_other_overloads_message_if_only_one_argume TypeMismatch* tm = get(result.errors[0]); REQUIRE(tm); - CHECK_EQ(typeChecker.numberType, tm->wantedType); - CHECK_EQ(typeChecker.stringType, tm->givenType); + CHECK_EQ(builtinTypes->numberType, tm->wantedType); + CHECK_EQ(builtinTypes->stringType, tm->givenType); } TEST_CASE_FIXTURE(Fixture, "infer_return_type_from_selected_overload") @@ -847,13 +847,13 @@ TEST_CASE_FIXTURE(Fixture, "calling_function_with_incorrect_argument_type_yields LUAU_REQUIRE_ERROR_COUNT(2, result); CHECK_EQ(result.errors[0], (TypeError{Location{Position{3, 12}, Position{3, 18}}, TypeMismatch{ - typeChecker.numberType, - typeChecker.stringType, + builtinTypes->numberType, + builtinTypes->stringType, }})); CHECK_EQ(result.errors[1], (TypeError{Location{Position{3, 20}, Position{3, 23}}, TypeMismatch{ - typeChecker.stringType, - typeChecker.numberType, + builtinTypes->stringType, + builtinTypes->numberType, }})); } @@ -1669,6 +1669,10 @@ TEST_CASE_FIXTURE(Fixture, "dont_infer_parameter_types_for_functions_from_their_ LUAU_REQUIRE_NO_ERRORS(result); CHECK_EQ("(a) -> a", toString(requireType("f"))); + if (FFlag::DebugLuauDeferredConstraintResolution) + CHECK_EQ("({+ p: {+ q: a +} +}) -> a & ~false", toString(requireType("g"))); + else + CHECK_EQ("({+ p: {+ q: nil +} +}) -> nil", toString(requireType("g"))); } TEST_CASE_FIXTURE(Fixture, "dont_mutate_the_underlying_head_of_typepack_when_calling_with_self") @@ -1851,4 +1855,29 @@ end LUAU_REQUIRE_ERRORS(result); } +TEST_CASE_FIXTURE(BuiltinsFixture, "dont_assert_when_the_tarjan_limit_is_exceeded_during_generalization") +{ + ScopedFastInt sfi{"LuauTarjanChildLimit", 2}; + ScopedFastFlag sff[] = { + {"DebugLuauDeferredConstraintResolution", true}, + {"LuauClonePublicInterfaceLess", true}, + {"LuauSubstitutionReentrant", true}, + {"LuauSubstitutionFixMissingFields", true}, + }; + + CheckResult result = check(R"( + function f(t) + t.x.y.z = 441 + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + + CHECK_MESSAGE(get(result.errors[0]), "Expected CodeTooComplex but got: " << toString(result.errors[0])); + CHECK(Location({1, 17}, {1, 18}) == result.errors[0].location); + + CHECK_MESSAGE(get(result.errors[1]), "Expected UnificationTooComplex but got: " << toString(result.errors[1])); + CHECK(Location({0, 0}, {4, 4}) == result.errors[1].location); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.generics.test.cpp b/tests/TypeInfer.generics.test.cpp index 0ba889c89..b3b2e4c94 100644 --- a/tests/TypeInfer.generics.test.cpp +++ b/tests/TypeInfer.generics.test.cpp @@ -844,8 +844,8 @@ TEST_CASE_FIXTURE(Fixture, "generic_function") LUAU_REQUIRE_NO_ERRORS(result); CHECK_EQ("(a) -> a", toString(requireType("id"))); - CHECK_EQ(*typeChecker.numberType, *requireType("a")); - CHECK_EQ(*typeChecker.nilType, *requireType("b")); + CHECK_EQ(*builtinTypes->numberType, *requireType("a")); + CHECK_EQ(*builtinTypes->nilType, *requireType("b")); } TEST_CASE_FIXTURE(Fixture, "generic_table_method") diff --git a/tests/TypeInfer.intersectionTypes.test.cpp b/tests/TypeInfer.intersectionTypes.test.cpp index ea6fff773..f6d04a952 100644 --- a/tests/TypeInfer.intersectionTypes.test.cpp +++ b/tests/TypeInfer.intersectionTypes.test.cpp @@ -22,8 +22,8 @@ TEST_CASE_FIXTURE(Fixture, "select_correct_union_fn") )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(requireType("b"), typeChecker.stringType); - CHECK_EQ(requireType("c"), typeChecker.numberType); + CHECK_EQ(requireType("b"), builtinTypes->stringType); + CHECK_EQ(requireType("c"), builtinTypes->numberType); } TEST_CASE_FIXTURE(Fixture, "table_combines") @@ -123,11 +123,11 @@ TEST_CASE_FIXTURE(Fixture, "should_still_pick_an_overload_whose_arguments_are_un LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*requireType("a1"), *typeChecker.numberType); - CHECK_EQ(*requireType("a2"), *typeChecker.numberType); + CHECK_EQ(*requireType("a1"), *builtinTypes->numberType); + CHECK_EQ(*requireType("a2"), *builtinTypes->numberType); - CHECK_EQ(*requireType("b1"), *typeChecker.stringType); - CHECK_EQ(*requireType("b2"), *typeChecker.stringType); + CHECK_EQ(*requireType("b1"), *builtinTypes->stringType); + CHECK_EQ(*requireType("b2"), *builtinTypes->stringType); } TEST_CASE_FIXTURE(Fixture, "propagates_name") @@ -249,7 +249,7 @@ TEST_CASE_FIXTURE(Fixture, "index_on_an_intersection_type_with_one_property_of_t )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*typeChecker.anyType, *requireType("r")); + CHECK_EQ(*builtinTypes->anyType, *requireType("r")); } TEST_CASE_FIXTURE(Fixture, "index_on_an_intersection_type_with_all_parts_missing_the_property") diff --git a/tests/TypeInfer.loops.test.cpp b/tests/TypeInfer.loops.test.cpp index 30cbe1d5b..50e9f802f 100644 --- a/tests/TypeInfer.loops.test.cpp +++ b/tests/TypeInfer.loops.test.cpp @@ -28,7 +28,7 @@ TEST_CASE_FIXTURE(Fixture, "for_loop") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*typeChecker.numberType, *requireType("q")); + CHECK_EQ(*builtinTypes->numberType, *requireType("q")); } TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_loop") @@ -44,8 +44,8 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_loop") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*typeChecker.numberType, *requireType("n")); - CHECK_EQ(*typeChecker.stringType, *requireType("s")); + CHECK_EQ(*builtinTypes->numberType, *requireType("n")); + CHECK_EQ(*builtinTypes->stringType, *requireType("s")); } TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_loop_with_next") @@ -61,8 +61,8 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_loop_with_next") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*typeChecker.numberType, *requireType("n")); - CHECK_EQ(*typeChecker.stringType, *requireType("s")); + CHECK_EQ(*builtinTypes->numberType, *requireType("n")); + CHECK_EQ(*builtinTypes->stringType, *requireType("s")); } TEST_CASE_FIXTURE(Fixture, "for_in_with_an_iterator_of_type_any") @@ -218,8 +218,8 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_loop_error_on_factory_not_returning_t TypeMismatch* tm = get(result.errors[1]); REQUIRE(tm); - CHECK_EQ(typeChecker.numberType, tm->wantedType); - CHECK_EQ(typeChecker.stringType, tm->givenType); + CHECK_EQ(builtinTypes->numberType, tm->wantedType); + CHECK_EQ(builtinTypes->stringType, tm->givenType); } TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_loop_error_on_iterator_requiring_args_but_none_given") @@ -281,8 +281,8 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_with_custom_iterator") TypeMismatch* tm = get(result.errors[0]); REQUIRE(tm); - CHECK_EQ(typeChecker.numberType, tm->wantedType); - CHECK_EQ(typeChecker.stringType, tm->givenType); + CHECK_EQ(builtinTypes->numberType, tm->wantedType); + CHECK_EQ(builtinTypes->stringType, tm->givenType); } TEST_CASE_FIXTURE(Fixture, "while_loop") @@ -296,7 +296,7 @@ TEST_CASE_FIXTURE(Fixture, "while_loop") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*typeChecker.numberType, *requireType("i")); + CHECK_EQ(*builtinTypes->numberType, *requireType("i")); } TEST_CASE_FIXTURE(Fixture, "repeat_loop") @@ -310,7 +310,7 @@ TEST_CASE_FIXTURE(Fixture, "repeat_loop") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*typeChecker.stringType, *requireType("i")); + CHECK_EQ(*builtinTypes->stringType, *requireType("i")); } TEST_CASE_FIXTURE(Fixture, "repeat_loop_condition_binds_to_its_block") @@ -547,7 +547,7 @@ TEST_CASE_FIXTURE(Fixture, "loop_iter_basic") )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*typeChecker.numberType, *requireType("key")); + CHECK_EQ(*builtinTypes->numberType, *requireType("key")); } TEST_CASE_FIXTURE(Fixture, "loop_iter_trailing_nil") @@ -561,7 +561,7 @@ TEST_CASE_FIXTURE(Fixture, "loop_iter_trailing_nil") )"); LUAU_REQUIRE_ERROR_COUNT(0, result); - CHECK_EQ(*typeChecker.nilType, *requireType("extra")); + CHECK_EQ(*builtinTypes->nilType, *requireType("extra")); } TEST_CASE_FIXTURE(Fixture, "loop_iter_no_indexer_strict") diff --git a/tests/TypeInfer.modules.test.cpp b/tests/TypeInfer.modules.test.cpp index ed3af11b4..ab07ee2da 100644 --- a/tests/TypeInfer.modules.test.cpp +++ b/tests/TypeInfer.modules.test.cpp @@ -12,6 +12,7 @@ LUAU_FASTFLAG(LuauInstantiateInSubtyping) LUAU_FASTFLAG(LuauTypeMismatchInvarianceInError) +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) using namespace Luau; @@ -482,4 +483,42 @@ return unpack(l0[_]) LUAU_REQUIRE_ERRORS(result); } +TEST_CASE_FIXTURE(BuiltinsFixture, "check_imported_module_names") +{ + fileResolver.source["game/A"] = R"( +return function(...) end + )"; + + fileResolver.source["game/B"] = R"( +local l0 = require(game.A) +return l0 + )"; + + CheckResult result = check(R"( +local l0 = require(game.B) +if true then + local l1 = require(game.A) +end +return l0 + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + ModulePtr mod = getMainModule(); + REQUIRE(mod); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + REQUIRE(mod->scopes.size() >= 4); + CHECK(mod->scopes[0].second->importedModules["l0"] == "game/B"); + CHECK(mod->scopes[3].second->importedModules["l1"] == "game/A"); + } + else + { + + REQUIRE(mod->scopes.size() >= 3); + CHECK(mod->scopes[0].second->importedModules["l0"] == "game/B"); + CHECK(mod->scopes[2].second->importedModules["l1"] == "game/A"); + } +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.oop.test.cpp b/tests/TypeInfer.oop.test.cpp index cf27518a6..ab41ce37e 100644 --- a/tests/TypeInfer.oop.test.cpp +++ b/tests/TypeInfer.oop.test.cpp @@ -290,4 +290,23 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "set_prop_of_intersection_containing_metatabl )"); } +// DCR once had a bug in the following code where it would erroneously bind the 'self' table to itself. +TEST_CASE_FIXTURE(Fixture, "dont_bind_free_tables_to_themselves") +{ + CheckResult result = check(R"( + local T = {} + local b: any + + function T:m() + local a = b[i] + if a then + self:n() + if self:p(a) then + self:n() + end + end + end + )"); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.operators.test.cpp b/tests/TypeInfer.operators.test.cpp index d75f00a2d..dcdc2e313 100644 --- a/tests/TypeInfer.operators.test.cpp +++ b/tests/TypeInfer.operators.test.cpp @@ -48,7 +48,7 @@ TEST_CASE_FIXTURE(Fixture, "or_joins_types_with_no_superfluous_union") local x:string = s )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*requireType("s"), *typeChecker.stringType); + CHECK_EQ(*requireType("s"), *builtinTypes->stringType); } TEST_CASE_FIXTURE(Fixture, "and_does_not_always_add_boolean") @@ -72,7 +72,7 @@ TEST_CASE_FIXTURE(Fixture, "and_adds_boolean_no_superfluous_union") local x:boolean = s )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*requireType("x"), *typeChecker.booleanType); + CHECK_EQ(*requireType("x"), *builtinTypes->booleanType); } TEST_CASE_FIXTURE(Fixture, "and_or_ternary") @@ -99,9 +99,9 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "primitive_arith_no_metatable") std::optional retType = first(functionType->retTypes); REQUIRE(retType.has_value()); - CHECK_EQ(typeChecker.numberType, follow(*retType)); - CHECK_EQ(requireType("n"), typeChecker.numberType); - CHECK_EQ(requireType("s"), typeChecker.stringType); + CHECK_EQ(builtinTypes->numberType, follow(*retType)); + CHECK_EQ(requireType("n"), builtinTypes->numberType); + CHECK_EQ(requireType("s"), builtinTypes->stringType); } TEST_CASE_FIXTURE(Fixture, "primitive_arith_no_metatable_with_follows") @@ -112,7 +112,7 @@ TEST_CASE_FIXTURE(Fixture, "primitive_arith_no_metatable_with_follows") )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(requireType("SOLAR_MASS"), typeChecker.numberType); + CHECK_EQ(requireType("SOLAR_MASS"), builtinTypes->numberType); } TEST_CASE_FIXTURE(Fixture, "primitive_arith_possible_metatable") @@ -248,8 +248,12 @@ TEST_CASE_FIXTURE(Fixture, "cannot_indirectly_compare_types_that_do_not_have_a_m LUAU_REQUIRE_ERROR_COUNT(1, result); GenericError* gen = get(result.errors[0]); + REQUIRE(gen != nullptr); - REQUIRE_EQ(gen->message, "Type a cannot be compared with < because it has no metatable"); + if (FFlag::DebugLuauDeferredConstraintResolution) + CHECK(gen->message == "Types 'a' and 'b' cannot be compared with < because neither type has a metatable"); + else + REQUIRE_EQ(gen->message, "Type a cannot be compared with < because it has no metatable"); } TEST_CASE_FIXTURE(BuiltinsFixture, "cannot_indirectly_compare_types_that_do_not_offer_overloaded_ordering_operators") @@ -270,7 +274,11 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "cannot_indirectly_compare_types_that_do_not_ GenericError* gen = get(result.errors[0]); REQUIRE(gen != nullptr); - REQUIRE_EQ(gen->message, "Table M does not offer metamethod __lt"); + + if (FFlag::DebugLuauDeferredConstraintResolution) + CHECK(gen->message == "Types 'M' and 'M' cannot be compared with < because neither type's metatable has a '__lt' metamethod"); + else + REQUIRE_EQ(gen->message, "Table M does not offer metamethod __lt"); } TEST_CASE_FIXTURE(BuiltinsFixture, "cannot_compare_tables_that_do_not_have_the_same_metatable") @@ -353,7 +361,7 @@ TEST_CASE_FIXTURE(Fixture, "compound_assign_mismatch_op") s += true )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(result.errors[0], (TypeError{Location{{2, 13}, {2, 17}}, TypeMismatch{typeChecker.numberType, typeChecker.booleanType}})); + CHECK_EQ(result.errors[0], (TypeError{Location{{2, 13}, {2, 17}}, TypeMismatch{builtinTypes->numberType, builtinTypes->booleanType}})); } TEST_CASE_FIXTURE(Fixture, "compound_assign_mismatch_result") @@ -364,8 +372,8 @@ TEST_CASE_FIXTURE(Fixture, "compound_assign_mismatch_result") )"); LUAU_REQUIRE_ERROR_COUNT(2, result); - CHECK_EQ(result.errors[0], (TypeError{Location{{2, 8}, {2, 9}}, TypeMismatch{typeChecker.numberType, typeChecker.stringType}})); - CHECK_EQ(result.errors[1], (TypeError{Location{{2, 8}, {2, 15}}, TypeMismatch{typeChecker.stringType, typeChecker.numberType}})); + CHECK_EQ(result.errors[0], (TypeError{Location{{2, 8}, {2, 9}}, TypeMismatch{builtinTypes->numberType, builtinTypes->stringType}})); + CHECK_EQ(result.errors[1], (TypeError{Location{{2, 8}, {2, 15}}, TypeMismatch{builtinTypes->stringType, builtinTypes->numberType}})); } TEST_CASE_FIXTURE(BuiltinsFixture, "compound_assign_metatable") @@ -521,7 +529,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "typecheck_unary_minus_error") CHECK_EQ("string", toString(requireType("a"))); TypeMismatch* tm = get(result.errors[0]); - REQUIRE_EQ(*tm->wantedType, *typeChecker.booleanType); + REQUIRE_EQ(*tm->wantedType, *builtinTypes->booleanType); // given type is the typeof(foo) which is complex to compare against } @@ -547,8 +555,8 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "typecheck_unary_len_error") CHECK_EQ("number", toString(requireType("a"))); TypeMismatch* tm = get(result.errors[0]); - REQUIRE_EQ(*tm->wantedType, *typeChecker.numberType); - REQUIRE_EQ(*tm->givenType, *typeChecker.stringType); + REQUIRE_EQ(*tm->wantedType, *builtinTypes->numberType); + REQUIRE_EQ(*tm->givenType, *builtinTypes->stringType); } TEST_CASE_FIXTURE(BuiltinsFixture, "unary_not_is_boolean") @@ -596,8 +604,8 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "disallow_string_and_types_without_metatables TypeMismatch* tm = get(result.errors[0]); REQUIRE(tm); - CHECK_EQ(*tm->wantedType, *typeChecker.numberType); - CHECK_EQ(*tm->givenType, *typeChecker.stringType); + CHECK_EQ(*tm->wantedType, *builtinTypes->numberType); + CHECK_EQ(*tm->givenType, *builtinTypes->stringType); GenericError* gen1 = get(result.errors[1]); REQUIRE(gen1); @@ -608,7 +616,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "disallow_string_and_types_without_metatables TypeMismatch* tm2 = get(result.errors[2]); REQUIRE(tm2); - CHECK_EQ(*tm2->wantedType, *typeChecker.numberType); + CHECK_EQ(*tm2->wantedType, *builtinTypes->numberType); CHECK_EQ(*tm2->givenType, *requireType("foo")); } @@ -802,7 +810,7 @@ local b: number = 1 or a TypeMismatch* tm = get(result.errors[0]); REQUIRE(tm); - CHECK_EQ(typeChecker.numberType, tm->wantedType); + CHECK_EQ(builtinTypes->numberType, tm->wantedType); CHECK_EQ("number?", toString(tm->givenType)); } diff --git a/tests/TypeInfer.primitives.test.cpp b/tests/TypeInfer.primitives.test.cpp index 02fdfa36e..949e64a57 100644 --- a/tests/TypeInfer.primitives.test.cpp +++ b/tests/TypeInfer.primitives.test.cpp @@ -57,7 +57,7 @@ TEST_CASE_FIXTURE(Fixture, "string_method") )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*requireType("p"), *typeChecker.numberType); + CHECK_EQ(*requireType("p"), *builtinTypes->numberType); } TEST_CASE_FIXTURE(Fixture, "string_function_indirect") @@ -69,7 +69,7 @@ TEST_CASE_FIXTURE(Fixture, "string_function_indirect") )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*requireType("p"), *typeChecker.stringType); + CHECK_EQ(*requireType("p"), *builtinTypes->stringType); } TEST_CASE_FIXTURE(Fixture, "CheckMethodsOfNumber") diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index 570cf278e..064ec164a 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -30,9 +30,8 @@ std::optional> magicFunctionInstanceIsA( if (!lvalue || !tfun) return std::nullopt; - unfreeze(typeChecker.globalTypes); - TypePackId booleanPack = typeChecker.globalTypes.addTypePack({typeChecker.booleanType}); - freeze(typeChecker.globalTypes); + ModulePtr module = typeChecker.currentModule; + TypePackId booleanPack = module->internalTypes.addTypePack({typeChecker.booleanType}); return WithPredicate{booleanPack, {IsAPredicate{std::move(*lvalue), expr.location, tfun->type}}}; } @@ -62,47 +61,47 @@ struct RefinementClassFixture : BuiltinsFixture { RefinementClassFixture() { - TypeArena& arena = typeChecker.globalTypes; - NotNull scope{typeChecker.globalScope.get()}; + TypeArena& arena = frontend.globals.globalTypes; + NotNull scope{frontend.globals.globalScope.get()}; - std::optional rootSuper = FFlag::LuauNegatedClassTypes ? std::make_optional(typeChecker.builtinTypes->classType) : std::nullopt; + std::optional rootSuper = FFlag::LuauNegatedClassTypes ? std::make_optional(builtinTypes->classType) : std::nullopt; unfreeze(arena); TypeId vec3 = arena.addType(ClassType{"Vector3", {}, rootSuper, std::nullopt, {}, nullptr, "Test"}); getMutable(vec3)->props = { - {"X", Property{typeChecker.numberType}}, - {"Y", Property{typeChecker.numberType}}, - {"Z", Property{typeChecker.numberType}}, + {"X", Property{builtinTypes->numberType}}, + {"Y", Property{builtinTypes->numberType}}, + {"Z", Property{builtinTypes->numberType}}, }; TypeId inst = arena.addType(ClassType{"Instance", {}, rootSuper, std::nullopt, {}, nullptr, "Test"}); - TypePackId isAParams = arena.addTypePack({inst, typeChecker.stringType}); - TypePackId isARets = arena.addTypePack({typeChecker.booleanType}); + TypePackId isAParams = arena.addTypePack({inst, builtinTypes->stringType}); + TypePackId isARets = arena.addTypePack({builtinTypes->booleanType}); TypeId isA = arena.addType(FunctionType{isAParams, isARets}); getMutable(isA)->magicFunction = magicFunctionInstanceIsA; getMutable(isA)->dcrMagicRefinement = dcrMagicRefinementInstanceIsA; getMutable(inst)->props = { - {"Name", Property{typeChecker.stringType}}, + {"Name", Property{builtinTypes->stringType}}, {"IsA", Property{isA}}, }; - TypeId folder = typeChecker.globalTypes.addType(ClassType{"Folder", {}, inst, std::nullopt, {}, nullptr, "Test"}); - TypeId part = typeChecker.globalTypes.addType(ClassType{"Part", {}, inst, std::nullopt, {}, nullptr, "Test"}); + TypeId folder = frontend.globals.globalTypes.addType(ClassType{"Folder", {}, inst, std::nullopt, {}, nullptr, "Test"}); + TypeId part = frontend.globals.globalTypes.addType(ClassType{"Part", {}, inst, std::nullopt, {}, nullptr, "Test"}); getMutable(part)->props = { {"Position", Property{vec3}}, }; - typeChecker.globalScope->exportedTypeBindings["Vector3"] = TypeFun{{}, vec3}; - typeChecker.globalScope->exportedTypeBindings["Instance"] = TypeFun{{}, inst}; - typeChecker.globalScope->exportedTypeBindings["Folder"] = TypeFun{{}, folder}; - typeChecker.globalScope->exportedTypeBindings["Part"] = TypeFun{{}, part}; + frontend.globals.globalScope->exportedTypeBindings["Vector3"] = TypeFun{{}, vec3}; + frontend.globals.globalScope->exportedTypeBindings["Instance"] = TypeFun{{}, inst}; + frontend.globals.globalScope->exportedTypeBindings["Folder"] = TypeFun{{}, folder}; + frontend.globals.globalScope->exportedTypeBindings["Part"] = TypeFun{{}, part}; - for (const auto& [name, ty] : typeChecker.globalScope->exportedTypeBindings) + for (const auto& [name, ty] : frontend.globals.globalScope->exportedTypeBindings) persist(ty.type); - freeze(typeChecker.globalTypes); + freeze(frontend.globals.globalTypes); } }; } // namespace diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index 2a87f0e3b..0f5e3d310 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -198,7 +198,7 @@ TEST_CASE_FIXTURE(Fixture, "call_method") CheckResult result = check("local T = {} T.x = 0 function T:method() return self.x end local a = T:method()"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*typeChecker.numberType, *requireType("a")); + CHECK_EQ(*builtinTypes->numberType, *requireType("a")); } TEST_CASE_FIXTURE(Fixture, "call_method_with_explicit_self_argument") @@ -576,8 +576,8 @@ TEST_CASE_FIXTURE(Fixture, "infer_array") REQUIRE(bool(ttv->indexer)); - CHECK_EQ(*ttv->indexer->indexType, *typeChecker.numberType); - CHECK_EQ(*ttv->indexer->indexResultType, *typeChecker.stringType); + CHECK_EQ(*ttv->indexer->indexType, *builtinTypes->numberType); + CHECK_EQ(*ttv->indexer->indexResultType, *builtinTypes->stringType); } /* This is a bit weird. @@ -685,8 +685,8 @@ TEST_CASE_FIXTURE(Fixture, "infer_indexer_from_array_like_table") REQUIRE(bool(ttv->indexer)); const TableIndexer& indexer = *ttv->indexer; - CHECK_EQ(*typeChecker.numberType, *indexer.indexType); - CHECK_EQ(*typeChecker.stringType, *indexer.indexResultType); + CHECK_EQ(*builtinTypes->numberType, *indexer.indexType); + CHECK_EQ(*builtinTypes->stringType, *indexer.indexResultType); } TEST_CASE_FIXTURE(Fixture, "infer_indexer_from_value_property_in_literal") @@ -740,8 +740,8 @@ TEST_CASE_FIXTURE(Fixture, "infer_indexer_from_its_variable_type_and_unifiable") REQUIRE(tTy != nullptr); REQUIRE(tTy->indexer); - CHECK_EQ(*typeChecker.numberType, *tTy->indexer->indexType); - CHECK_EQ(*typeChecker.stringType, *tTy->indexer->indexResultType); + CHECK_EQ(*builtinTypes->numberType, *tTy->indexer->indexType); + CHECK_EQ(*builtinTypes->stringType, *tTy->indexer->indexResultType); } TEST_CASE_FIXTURE(Fixture, "indexer_mismatch") @@ -844,7 +844,7 @@ TEST_CASE_FIXTURE(Fixture, "infer_type_when_indexing_from_a_table_indexer") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*typeChecker.stringType, *requireType("s")); + CHECK_EQ(*builtinTypes->stringType, *requireType("s")); } TEST_CASE_FIXTURE(Fixture, "indexing_from_a_table_should_prefer_properties_when_possible") @@ -865,13 +865,13 @@ TEST_CASE_FIXTURE(Fixture, "indexing_from_a_table_should_prefer_properties_when_ LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(*typeChecker.stringType, *requireType("a1")); - CHECK_EQ(*typeChecker.stringType, *requireType("a2")); + CHECK_EQ(*builtinTypes->stringType, *requireType("a1")); + CHECK_EQ(*builtinTypes->stringType, *requireType("a2")); - CHECK_EQ(*typeChecker.numberType, *requireType("b1")); - CHECK_EQ(*typeChecker.numberType, *requireType("b2")); + CHECK_EQ(*builtinTypes->numberType, *requireType("b1")); + CHECK_EQ(*builtinTypes->numberType, *requireType("b2")); - CHECK_EQ(*typeChecker.numberType, *requireType("c")); + CHECK_EQ(*builtinTypes->numberType, *requireType("c")); CHECK_MESSAGE(nullptr != get(result.errors[0]), "Expected a TypeMismatch but got " << result.errors[0]); } @@ -932,7 +932,7 @@ TEST_CASE_FIXTURE(Fixture, "assigning_to_an_unsealed_table_with_string_literal_s LUAU_REQUIRE_NO_ERRORS(result); - CHECK("string" == toString(*typeChecker.stringType)); + CHECK("string" == toString(*builtinTypes->stringType)); TableType* tableType = getMutable(requireType("t")); REQUIRE(tableType != nullptr); @@ -941,7 +941,7 @@ TEST_CASE_FIXTURE(Fixture, "assigning_to_an_unsealed_table_with_string_literal_s TypeId propertyA = tableType->props["a"].type; REQUIRE(propertyA != nullptr); - CHECK_EQ(*typeChecker.stringType, *propertyA); + CHECK_EQ(*builtinTypes->stringType, *propertyA); } TEST_CASE_FIXTURE(BuiltinsFixture, "oop_indexer_works") @@ -964,7 +964,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "oop_indexer_works") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*typeChecker.stringType, *requireType("words")); + CHECK_EQ(*builtinTypes->stringType, *requireType("words")); } TEST_CASE_FIXTURE(BuiltinsFixture, "indexer_table") @@ -977,7 +977,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "indexer_table") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*typeChecker.stringType, *requireType("b")); + CHECK_EQ(*builtinTypes->stringType, *requireType("b")); } TEST_CASE_FIXTURE(BuiltinsFixture, "indexer_fn") @@ -988,7 +988,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "indexer_fn") )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*typeChecker.numberType, *requireType("b")); + CHECK_EQ(*builtinTypes->numberType, *requireType("b")); } TEST_CASE_FIXTURE(BuiltinsFixture, "meta_add") @@ -1102,10 +1102,10 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "oop_polymorphic") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*typeChecker.booleanType, *requireType("alive")); - CHECK_EQ(*typeChecker.stringType, *requireType("movement")); - CHECK_EQ(*typeChecker.stringType, *requireType("name")); - CHECK_EQ(*typeChecker.numberType, *requireType("speed")); + CHECK_EQ(*builtinTypes->booleanType, *requireType("alive")); + CHECK_EQ(*builtinTypes->stringType, *requireType("movement")); + CHECK_EQ(*builtinTypes->stringType, *requireType("name")); + CHECK_EQ(*builtinTypes->numberType, *requireType("speed")); } TEST_CASE_FIXTURE(Fixture, "user_defined_table_types_are_named") @@ -2477,7 +2477,7 @@ TEST_CASE_FIXTURE(Fixture, "table_length") LUAU_REQUIRE_NO_ERRORS(result); CHECK(nullptr != get(requireType("t"))); - CHECK_EQ(*typeChecker.numberType, *requireType("s")); + CHECK_EQ(*builtinTypes->numberType, *requireType("s")); } TEST_CASE_FIXTURE(Fixture, "nil_assign_doesnt_hit_indexer") @@ -2498,8 +2498,8 @@ TEST_CASE_FIXTURE(Fixture, "wrong_assign_does_hit_indexer") CHECK((Location{Position{3, 15}, Position{3, 18}}) == result.errors[0].location); TypeMismatch* tm = get(result.errors[0]); REQUIRE(tm); - CHECK(tm->wantedType == typeChecker.numberType); - CHECK(tm->givenType == typeChecker.stringType); + CHECK(tm->wantedType == builtinTypes->numberType); + CHECK(tm->givenType == builtinTypes->stringType); } TEST_CASE_FIXTURE(Fixture, "nil_assign_doesnt_hit_no_indexer") @@ -2510,8 +2510,8 @@ TEST_CASE_FIXTURE(Fixture, "nil_assign_doesnt_hit_no_indexer") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); CHECK_EQ(result.errors[0], (TypeError{Location{Position{2, 17}, Position{2, 20}}, TypeMismatch{ - typeChecker.numberType, - typeChecker.nilType, + builtinTypes->numberType, + builtinTypes->nilType, }})); } @@ -2709,7 +2709,7 @@ TEST_CASE_FIXTURE(Fixture, "setmetatable_cant_be_used_to_mutate_global_types") Fixture fix; // inherit env from parent fixture checker - fix.typeChecker.globalScope = typeChecker.globalScope; + fix.frontend.globals.globalScope = frontend.globals.globalScope; fix.check(R"( --!nonstrict @@ -2723,7 +2723,7 @@ end // validate sharedEnv post-typecheck; valuable for debugging some typeck crashes but slows fuzzing down // note: it's important for typeck to be destroyed at this point! { - for (auto& p : typeChecker.globalScope->bindings) + for (auto& p : frontend.globals.globalScope->bindings) { toString(p.second.typeId); // toString walks the entire type, making sure ASAN catches access to destroyed type arenas } @@ -3318,8 +3318,6 @@ caused by: TEST_CASE_FIXTURE(Fixture, "a_free_shape_can_turn_into_a_scalar_if_it_is_compatible") { - ScopedFastFlag luauScalarShapeUnifyToMtOwner{"LuauScalarShapeUnifyToMtOwner2", true}; // Changes argument from table type to primitive - CheckResult result = check(R"( local function f(s): string local foo = s:lower() @@ -3350,8 +3348,6 @@ caused by: TEST_CASE_FIXTURE(BuiltinsFixture, "a_free_shape_can_turn_into_a_scalar_directly") { - ScopedFastFlag luauScalarShapeUnifyToMtOwner{"LuauScalarShapeUnifyToMtOwner2", true}; - CheckResult result = check(R"( local function stringByteList(str) local out = {} @@ -3457,8 +3453,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "tables_should_be_fully_populated") TEST_CASE_FIXTURE(Fixture, "fuzz_table_indexer_unification_can_bound_owner_to_string") { - ScopedFastFlag luauScalarShapeUnifyToMtOwner{"LuauScalarShapeUnifyToMtOwner2", true}; - CheckResult result = check(R"( sin,_ = nil _ = _[_.sin][_._][_][_]._ @@ -3470,8 +3464,6 @@ _[_] = _ TEST_CASE_FIXTURE(BuiltinsFixture, "fuzz_table_extra_prop_unification_can_bound_owner_to_string") { - ScopedFastFlag luauScalarShapeUnifyToMtOwner{"LuauScalarShapeUnifyToMtOwner2", true}; - CheckResult result = check(R"( l0,_ = nil _ = _,_[_.n5]._[_][_][_]._ @@ -3483,8 +3475,6 @@ _._.foreach[_],_ = _[_],_._ TEST_CASE_FIXTURE(BuiltinsFixture, "fuzz_typelevel_promote_on_changed_table_type") { - ScopedFastFlag luauScalarShapeUnifyToMtOwner{"LuauScalarShapeUnifyToMtOwner2", true}; - CheckResult result = check(R"( _._,_ = nil _ = _.foreach[_]._,_[_.n5]._[_.foreach][_][_]._ @@ -3498,8 +3488,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "fuzz_table_unify_instantiated_table") { ScopedFastFlag sff[]{ {"LuauInstantiateInSubtyping", true}, - {"LuauScalarShapeUnifyToMtOwner2", true}, - {"LuauTableUnifyInstantiationFix", true}, }; CheckResult result = check(R"( @@ -3517,8 +3505,6 @@ TEST_CASE_FIXTURE(Fixture, "fuzz_table_unify_instantiated_table_with_prop_reallo { ScopedFastFlag sff[]{ {"LuauInstantiateInSubtyping", true}, - {"LuauScalarShapeUnifyToMtOwner2", true}, - {"LuauTableUnifyInstantiationFix", true}, }; CheckResult result = check(R"( @@ -3537,12 +3523,6 @@ end) TEST_CASE_FIXTURE(BuiltinsFixture, "fuzz_table_unify_prop_realloc") { - // For this test, we don't need LuauInstantiateInSubtyping - ScopedFastFlag sff[]{ - {"LuauScalarShapeUnifyToMtOwner2", true}, - {"LuauTableUnifyInstantiationFix", true}, - }; - CheckResult result = check(R"( n3,_ = nil _ = _[""]._,_[l0][_._][{[_]=_,_=_,}][_G].number diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index 3865e83a8..417f80a84 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -45,7 +45,8 @@ TEST_CASE_FIXTURE(Fixture, "tc_error") CheckResult result = check("local a = 7 local b = 'hi' a = b"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(result.errors[0], (TypeError{Location{Position{0, 35}, Position{0, 36}}, TypeMismatch{typeChecker.numberType, typeChecker.stringType}})); + CHECK_EQ( + result.errors[0], (TypeError{Location{Position{0, 35}, Position{0, 36}}, TypeMismatch{builtinTypes->numberType, builtinTypes->stringType}})); } TEST_CASE_FIXTURE(Fixture, "tc_error_2") @@ -55,7 +56,7 @@ TEST_CASE_FIXTURE(Fixture, "tc_error_2") CHECK_EQ(result.errors[0], (TypeError{Location{Position{0, 18}, Position{0, 22}}, TypeMismatch{ requireType("a"), - typeChecker.stringType, + builtinTypes->stringType, }})); } @@ -123,8 +124,8 @@ TEST_CASE_FIXTURE(Fixture, "if_statement") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*typeChecker.stringType, *requireType("a")); - CHECK_EQ(*typeChecker.numberType, *requireType("b")); + CHECK_EQ(*builtinTypes->stringType, *requireType("a")); + CHECK_EQ(*builtinTypes->numberType, *requireType("b")); } TEST_CASE_FIXTURE(Fixture, "statements_are_topologically_sorted") @@ -256,7 +257,13 @@ TEST_CASE_FIXTURE(Fixture, "should_be_able_to_infer_this_without_stack_overflowi end )"); - LUAU_REQUIRE_NO_ERRORS(result); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK(get(result.errors[0])); + } + else + LUAU_REQUIRE_NO_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "exponential_blowup_from_copying_types") @@ -580,7 +587,7 @@ TEST_CASE_FIXTURE(Fixture, "stringify_nested_unions_with_optionals") LUAU_REQUIRE_ERROR_COUNT(1, result); TypeMismatch* tm = get(result.errors[0]); REQUIRE(tm); - CHECK_EQ(typeChecker.numberType, tm->wantedType); + CHECK_EQ(builtinTypes->numberType, tm->wantedType); CHECK_EQ("(boolean | number | string)?", toString(tm->givenType)); } @@ -1150,8 +1157,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "it_is_ok_to_have_inconsistent_number_of_retu TEST_CASE_FIXTURE(Fixture, "fuzz_free_table_type_change_during_index_check") { - ScopedFastFlag sff{"LuauScalarShapeUnifyToMtOwner2", true}; - CheckResult result = check(R"( local _ = nil while _["" >= _] do @@ -1175,4 +1180,24 @@ local b = typeof(foo) ~= 'nil' CHECK(toString(result.errors[1]) == "Unknown global 'foo'"); } +TEST_CASE_FIXTURE(Fixture, "dcr_delays_expansion_of_function_containing_blocked_parameter_type") +{ + ScopedFastFlag sff[] = { + {"DebugLuauDeferredConstraintResolution", true}, + {"LuauTinyUnifyNormalsFix", true}, + }; + + CheckResult result = check(R"( + local b: any + + function f(x) + local a = b[1] or 'Cn' + local c = x[1] + + if a:sub(1, #c) == c then + end + end + )"); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.tryUnify.test.cpp b/tests/TypeInfer.tryUnify.test.cpp index 47b140a14..66e070139 100644 --- a/tests/TypeInfer.tryUnify.test.cpp +++ b/tests/TypeInfer.tryUnify.test.cpp @@ -38,7 +38,7 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "primitives_unify") TEST_CASE_FIXTURE(TryUnifyFixture, "compatible_functions_are_unified") { Type functionOne{ - TypeVariant{FunctionType(arena.addTypePack({arena.freshType(globalScope->level)}), arena.addTypePack({typeChecker.numberType}))}}; + TypeVariant{FunctionType(arena.addTypePack({arena.freshType(globalScope->level)}), arena.addTypePack({builtinTypes->numberType}))}}; Type functionTwo{TypeVariant{ FunctionType(arena.addTypePack({arena.freshType(globalScope->level)}), arena.addTypePack({arena.freshType(globalScope->level)}))}}; @@ -55,13 +55,13 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "incompatible_functions_are_preserved") { TypePackVar argPackOne{TypePack{{arena.freshType(globalScope->level)}, std::nullopt}}; Type functionOne{ - TypeVariant{FunctionType(arena.addTypePack({arena.freshType(globalScope->level)}), arena.addTypePack({typeChecker.numberType}))}}; + TypeVariant{FunctionType(arena.addTypePack({arena.freshType(globalScope->level)}), arena.addTypePack({builtinTypes->numberType}))}}; Type functionOneSaved = functionOne; TypePackVar argPackTwo{TypePack{{arena.freshType(globalScope->level)}, std::nullopt}}; Type functionTwo{ - TypeVariant{FunctionType(arena.addTypePack({arena.freshType(globalScope->level)}), arena.addTypePack({typeChecker.stringType}))}}; + TypeVariant{FunctionType(arena.addTypePack({arena.freshType(globalScope->level)}), arena.addTypePack({builtinTypes->stringType}))}}; Type functionTwoSaved = functionTwo; @@ -96,12 +96,12 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "tables_can_be_unified") TEST_CASE_FIXTURE(TryUnifyFixture, "incompatible_tables_are_preserved") { Type tableOne{TypeVariant{ - TableType{{{"foo", {arena.freshType(globalScope->level)}}, {"bar", {typeChecker.numberType}}}, std::nullopt, globalScope->level, + TableType{{{"foo", {arena.freshType(globalScope->level)}}, {"bar", {builtinTypes->numberType}}}, std::nullopt, globalScope->level, TableState::Unsealed}, }}; Type tableTwo{TypeVariant{ - TableType{{{"foo", {arena.freshType(globalScope->level)}}, {"bar", {typeChecker.stringType}}}, std::nullopt, globalScope->level, + TableType{{{"foo", {arena.freshType(globalScope->level)}}, {"bar", {builtinTypes->stringType}}}, std::nullopt, globalScope->level, TableState::Unsealed}, }}; @@ -214,8 +214,8 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "typepack_unification_should_trim_free_tails" TEST_CASE_FIXTURE(TryUnifyFixture, "variadic_type_pack_unification") { - TypePackVar testPack{TypePack{{typeChecker.numberType, typeChecker.stringType}, std::nullopt}}; - TypePackVar variadicPack{VariadicTypePack{typeChecker.numberType}}; + TypePackVar testPack{TypePack{{builtinTypes->numberType, builtinTypes->stringType}, std::nullopt}}; + TypePackVar variadicPack{VariadicTypePack{builtinTypes->numberType}}; state.tryUnify(&testPack, &variadicPack); CHECK(!state.errors.empty()); @@ -223,9 +223,9 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "variadic_type_pack_unification") TEST_CASE_FIXTURE(TryUnifyFixture, "variadic_tails_respect_progress") { - TypePackVar variadicPack{VariadicTypePack{typeChecker.booleanType}}; - TypePackVar a{TypePack{{typeChecker.numberType, typeChecker.stringType, typeChecker.booleanType, typeChecker.booleanType}}}; - TypePackVar b{TypePack{{typeChecker.numberType, typeChecker.stringType}, &variadicPack}}; + TypePackVar variadicPack{VariadicTypePack{builtinTypes->booleanType}}; + TypePackVar a{TypePack{{builtinTypes->numberType, builtinTypes->stringType, builtinTypes->booleanType, builtinTypes->booleanType}}}; + TypePackVar b{TypePack{{builtinTypes->numberType, builtinTypes->stringType}, &variadicPack}}; state.tryUnify(&b, &a); CHECK(state.errors.empty()); @@ -266,8 +266,9 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "cli_41095_concat_log_in_sealed_table_unifica TEST_CASE_FIXTURE(TryUnifyFixture, "free_tail_is_grown_properly") { - TypePackId threeNumbers = arena.addTypePack(TypePack{{typeChecker.numberType, typeChecker.numberType, typeChecker.numberType}, std::nullopt}); - TypePackId numberAndFreeTail = arena.addTypePack(TypePack{{typeChecker.numberType}, arena.addTypePack(TypePackVar{FreeTypePack{TypeLevel{}}})}); + TypePackId threeNumbers = + arena.addTypePack(TypePack{{builtinTypes->numberType, builtinTypes->numberType, builtinTypes->numberType}, std::nullopt}); + TypePackId numberAndFreeTail = arena.addTypePack(TypePack{{builtinTypes->numberType}, arena.addTypePack(TypePackVar{FreeTypePack{TypeLevel{}}})}); ErrorVec unifyErrors = state.canUnify(numberAndFreeTail, threeNumbers); CHECK(unifyErrors.size() == 0); @@ -279,7 +280,7 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "recursive_metatable_getmatchtag") Type table{TableType{}}; Type metatable{MetatableType{&redirect, &table}}; redirect = BoundType{&metatable}; // Now we have a metatable that is recursive on the table type - Type variant{UnionType{{&metatable, typeChecker.numberType}}}; + Type variant{UnionType{{&metatable, builtinTypes->numberType}}}; state.tryUnify(&metatable, &variant); } @@ -293,13 +294,13 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "cli_50320_follow_in_any_unification") state.tryUnify(&free, &target); // Shouldn't assert or error. - state.tryUnify(&func, typeChecker.anyType); + state.tryUnify(&func, builtinTypes->anyType); } TEST_CASE_FIXTURE(TryUnifyFixture, "txnlog_preserves_type_owner") { TypeId a = arena.addType(Type{FreeType{TypeLevel{}}}); - TypeId b = typeChecker.numberType; + TypeId b = builtinTypes->numberType; state.tryUnify(a, b); state.log.commit(); @@ -310,7 +311,7 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "txnlog_preserves_type_owner") TEST_CASE_FIXTURE(TryUnifyFixture, "txnlog_preserves_pack_owner") { TypePackId a = arena.addTypePack(TypePackVar{FreeTypePack{TypeLevel{}}}); - TypePackId b = typeChecker.anyTypePack; + TypePackId b = builtinTypes->anyTypePack; state.tryUnify(a, b); state.log.commit(); @@ -323,13 +324,13 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "metatables_unify_against_shape_of_free_table ScopedFastFlag sff("DebugLuauDeferredConstraintResolution", true); TableType::Props freeProps{ - {"foo", {typeChecker.numberType}}, + {"foo", {builtinTypes->numberType}}, }; TypeId free = arena.addType(TableType{freeProps, std::nullopt, TypeLevel{}, TableState::Free}); TableType::Props indexProps{ - {"foo", {typeChecker.stringType}}, + {"foo", {builtinTypes->stringType}}, }; TypeId index = arena.addType(TableType{indexProps, std::nullopt, TypeLevel{}, TableState::Sealed}); @@ -356,9 +357,9 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "metatables_unify_against_shape_of_free_table TEST_CASE_FIXTURE(TryUnifyFixture, "fuzz_tail_unification_issue") { - TypePackVar variadicAny{VariadicTypePack{typeChecker.anyType}}; - TypePackVar packTmp{TypePack{{typeChecker.anyType}, &variadicAny}}; - TypePackVar packSub{TypePack{{typeChecker.anyType, typeChecker.anyType}, &packTmp}}; + TypePackVar variadicAny{VariadicTypePack{builtinTypes->anyType}}; + TypePackVar packTmp{TypePack{{builtinTypes->anyType}, &variadicAny}}; + TypePackVar packSub{TypePack{{builtinTypes->anyType, builtinTypes->anyType}, &packTmp}}; Type freeTy{FreeType{TypeLevel{}}}; TypePackVar freeTp{FreeTypePack{TypeLevel{}}}; diff --git a/tests/TypeInfer.typePacks.cpp b/tests/TypeInfer.typePacks.cpp index 78eb6d477..441191664 100644 --- a/tests/TypeInfer.typePacks.cpp +++ b/tests/TypeInfer.typePacks.cpp @@ -27,8 +27,8 @@ TEST_CASE_FIXTURE(Fixture, "infer_multi_return") const auto& [returns, tail] = flatten(takeTwoType->retTypes); CHECK_EQ(2, returns.size()); - CHECK_EQ(typeChecker.numberType, follow(returns[0])); - CHECK_EQ(typeChecker.numberType, follow(returns[1])); + CHECK_EQ(builtinTypes->numberType, follow(returns[0])); + CHECK_EQ(builtinTypes->numberType, follow(returns[1])); CHECK(!tail); } @@ -74,9 +74,9 @@ TEST_CASE_FIXTURE(Fixture, "last_element_of_return_statement_can_itself_be_a_pac const auto& [rets, tail] = flatten(takeOneMoreType->retTypes); REQUIRE_EQ(3, rets.size()); - CHECK_EQ(typeChecker.numberType, follow(rets[0])); - CHECK_EQ(typeChecker.numberType, follow(rets[1])); - CHECK_EQ(typeChecker.numberType, follow(rets[2])); + CHECK_EQ(builtinTypes->numberType, follow(rets[0])); + CHECK_EQ(builtinTypes->numberType, follow(rets[1])); + CHECK_EQ(builtinTypes->numberType, follow(rets[2])); CHECK(!tail); } @@ -184,28 +184,28 @@ TEST_CASE_FIXTURE(Fixture, "parenthesized_varargs_returns_any") TEST_CASE_FIXTURE(Fixture, "variadic_packs") { - TypeArena& arena = typeChecker.globalTypes; + TypeArena& arena = frontend.globals.globalTypes; unfreeze(arena); - TypePackId listOfNumbers = arena.addTypePack(TypePackVar{VariadicTypePack{typeChecker.numberType}}); - TypePackId listOfStrings = arena.addTypePack(TypePackVar{VariadicTypePack{typeChecker.stringType}}); + TypePackId listOfNumbers = arena.addTypePack(TypePackVar{VariadicTypePack{builtinTypes->numberType}}); + TypePackId listOfStrings = arena.addTypePack(TypePackVar{VariadicTypePack{builtinTypes->stringType}}); // clang-format off - addGlobalBinding(frontend, "foo", + addGlobalBinding(frontend.globals, "foo", arena.addType( FunctionType{ listOfNumbers, - arena.addTypePack({typeChecker.numberType}) + arena.addTypePack({builtinTypes->numberType}) } ), "@test" ); - addGlobalBinding(frontend, "bar", + addGlobalBinding(frontend.globals, "bar", arena.addType( FunctionType{ - arena.addTypePack({{typeChecker.numberType}, listOfStrings}), - arena.addTypePack({typeChecker.numberType}) + arena.addTypePack({{builtinTypes->numberType}, listOfStrings}), + arena.addTypePack({builtinTypes->numberType}) } ), "@test" @@ -223,9 +223,11 @@ TEST_CASE_FIXTURE(Fixture, "variadic_packs") LUAU_REQUIRE_ERROR_COUNT(2, result); - CHECK_EQ(result.errors[0], (TypeError{Location(Position{3, 21}, Position{3, 26}), TypeMismatch{typeChecker.numberType, typeChecker.stringType}})); + CHECK_EQ( + result.errors[0], (TypeError{Location(Position{3, 21}, Position{3, 26}), TypeMismatch{builtinTypes->numberType, builtinTypes->stringType}})); - CHECK_EQ(result.errors[1], (TypeError{Location(Position{4, 29}, Position{4, 30}), TypeMismatch{typeChecker.stringType, typeChecker.numberType}})); + CHECK_EQ( + result.errors[1], (TypeError{Location(Position{4, 29}, Position{4, 30}), TypeMismatch{builtinTypes->stringType, builtinTypes->numberType}})); } TEST_CASE_FIXTURE(Fixture, "variadic_pack_syntax") diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp index 6f69d6827..704e2a3b4 100644 --- a/tests/TypeInfer.unionTypes.test.cpp +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -131,7 +131,7 @@ TEST_CASE_FIXTURE(Fixture, "index_on_a_union_type_with_property_guaranteed_to_ex )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*typeChecker.numberType, *requireType("r")); + CHECK_EQ(*builtinTypes->numberType, *requireType("r")); } TEST_CASE_FIXTURE(Fixture, "index_on_a_union_type_with_mixed_types") @@ -211,7 +211,7 @@ TEST_CASE_FIXTURE(Fixture, "index_on_a_union_type_with_one_property_of_type_any" )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*typeChecker.anyType, *requireType("r")); + CHECK_EQ(*builtinTypes->anyType, *requireType("r")); } TEST_CASE_FIXTURE(Fixture, "union_equality_comparisons") @@ -245,7 +245,7 @@ local c = bf.a.y )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(*typeChecker.numberType, *requireType("c")); + CHECK_EQ(*builtinTypes->numberType, *requireType("c")); CHECK_EQ("Value of type 'A?' could be nil", toString(result.errors[0])); } @@ -260,7 +260,7 @@ TEST_CASE_FIXTURE(Fixture, "optional_union_functions") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(*typeChecker.numberType, *requireType("c")); + CHECK_EQ(*builtinTypes->numberType, *requireType("c")); CHECK_EQ("Value of type 'A?' could be nil", toString(result.errors[0])); } @@ -275,7 +275,7 @@ local c = b:foo(1, 2) )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(*typeChecker.numberType, *requireType("c")); + CHECK_EQ(*builtinTypes->numberType, *requireType("c")); CHECK_EQ("Value of type 'A?' could be nil", toString(result.errors[0])); } diff --git a/tests/TypeReduction.test.cpp b/tests/TypeReduction.test.cpp index 582725b74..5f11a71b7 100644 --- a/tests/TypeReduction.test.cpp +++ b/tests/TypeReduction.test.cpp @@ -31,6 +31,12 @@ struct ReductionFixture : Fixture return *reducedTy; } + std::optional tryReduce(const std::string& annotation) + { + check("type _Res = " + annotation); + return reduction.reduce(requireTypeAlias("_Res")); + } + TypeId reductionof(const std::string& annotation) { check("type _Res = " + annotation); @@ -1488,4 +1494,16 @@ TEST_CASE_FIXTURE(ReductionFixture, "cycles") } } +TEST_CASE_FIXTURE(ReductionFixture, "string_singletons") +{ + TypeId ty = reductionof("(string & Not<\"A\">)?"); + CHECK("(string & ~\"A\")?" == toStringFull(ty)); +} + +TEST_CASE_FIXTURE(ReductionFixture, "string_singletons_2") +{ + TypeId ty = reductionof("Not<\"A\"> & Not<\"B\"> & (string?)"); + CHECK("(string & ~\"A\" & ~\"B\")?" == toStringFull(ty)); +} + TEST_SUITE_END(); diff --git a/tests/TypeVar.test.cpp b/tests/TypeVar.test.cpp index 36e437e24..64ba63c8d 100644 --- a/tests/TypeVar.test.cpp +++ b/tests/TypeVar.test.cpp @@ -16,13 +16,13 @@ TEST_SUITE_BEGIN("TypeTests"); TEST_CASE_FIXTURE(Fixture, "primitives_are_equal") { - REQUIRE_EQ(typeChecker.booleanType, typeChecker.booleanType); + REQUIRE_EQ(builtinTypes->booleanType, builtinTypes->booleanType); } TEST_CASE_FIXTURE(Fixture, "bound_type_is_equal_to_that_which_it_is_bound") { - Type bound(BoundType(typeChecker.booleanType)); - REQUIRE_EQ(bound, *typeChecker.booleanType); + Type bound(BoundType(builtinTypes->booleanType)); + REQUIRE_EQ(bound, *builtinTypes->booleanType); } TEST_CASE_FIXTURE(Fixture, "equivalent_cyclic_tables_are_equal") @@ -54,8 +54,8 @@ TEST_CASE_FIXTURE(Fixture, "different_cyclic_tables_are_not_equal") TEST_CASE_FIXTURE(Fixture, "return_type_of_function_is_not_parenthesized_if_just_one_value") { auto emptyArgumentPack = TypePackVar{TypePack{}}; - auto returnPack = TypePackVar{TypePack{{typeChecker.numberType}}}; - auto returnsTwo = Type(FunctionType(typeChecker.globalScope->level, &emptyArgumentPack, &returnPack)); + auto returnPack = TypePackVar{TypePack{{builtinTypes->numberType}}}; + auto returnsTwo = Type(FunctionType(frontend.globals.globalScope->level, &emptyArgumentPack, &returnPack)); std::string res = toString(&returnsTwo); CHECK_EQ("() -> number", res); @@ -64,8 +64,8 @@ TEST_CASE_FIXTURE(Fixture, "return_type_of_function_is_not_parenthesized_if_just TEST_CASE_FIXTURE(Fixture, "return_type_of_function_is_parenthesized_if_not_just_one_value") { auto emptyArgumentPack = TypePackVar{TypePack{}}; - auto returnPack = TypePackVar{TypePack{{typeChecker.numberType, typeChecker.numberType}}}; - auto returnsTwo = Type(FunctionType(typeChecker.globalScope->level, &emptyArgumentPack, &returnPack)); + auto returnPack = TypePackVar{TypePack{{builtinTypes->numberType, builtinTypes->numberType}}}; + auto returnsTwo = Type(FunctionType(frontend.globals.globalScope->level, &emptyArgumentPack, &returnPack)); std::string res = toString(&returnsTwo); CHECK_EQ("() -> (number, number)", res); @@ -76,8 +76,8 @@ TEST_CASE_FIXTURE(Fixture, "return_type_of_function_is_parenthesized_if_tail_is_ auto emptyArgumentPack = TypePackVar{TypePack{}}; auto free = Unifiable::Free(TypeLevel()); auto freePack = TypePackVar{TypePackVariant{free}}; - auto returnPack = TypePackVar{TypePack{{typeChecker.numberType}, &freePack}}; - auto returnsTwo = Type(FunctionType(typeChecker.globalScope->level, &emptyArgumentPack, &returnPack)); + auto returnPack = TypePackVar{TypePack{{builtinTypes->numberType}, &freePack}}; + auto returnsTwo = Type(FunctionType(frontend.globals.globalScope->level, &emptyArgumentPack, &returnPack)); std::string res = toString(&returnsTwo); CHECK_EQ(res, "() -> (number, a...)"); @@ -86,9 +86,9 @@ TEST_CASE_FIXTURE(Fixture, "return_type_of_function_is_parenthesized_if_tail_is_ TEST_CASE_FIXTURE(Fixture, "subset_check") { UnionType super, sub, notSub; - super.options = {typeChecker.numberType, typeChecker.stringType, typeChecker.booleanType}; - sub.options = {typeChecker.numberType, typeChecker.stringType}; - notSub.options = {typeChecker.numberType, typeChecker.nilType}; + super.options = {builtinTypes->numberType, builtinTypes->stringType, builtinTypes->booleanType}; + sub.options = {builtinTypes->numberType, builtinTypes->stringType}; + notSub.options = {builtinTypes->numberType, builtinTypes->nilType}; CHECK(isSubset(super, sub)); CHECK(!isSubset(super, notSub)); @@ -97,7 +97,7 @@ TEST_CASE_FIXTURE(Fixture, "subset_check") TEST_CASE_FIXTURE(Fixture, "iterate_over_UnionType") { UnionType utv; - utv.options = {typeChecker.numberType, typeChecker.stringType, typeChecker.anyType}; + utv.options = {builtinTypes->numberType, builtinTypes->stringType, builtinTypes->anyType}; std::vector result; for (TypeId ty : &utv) @@ -110,19 +110,19 @@ TEST_CASE_FIXTURE(Fixture, "iterating_over_nested_UnionTypes") { Type subunion{UnionType{}}; UnionType* innerUtv = getMutable(&subunion); - innerUtv->options = {typeChecker.numberType, typeChecker.stringType}; + innerUtv->options = {builtinTypes->numberType, builtinTypes->stringType}; UnionType utv; - utv.options = {typeChecker.anyType, &subunion}; + utv.options = {builtinTypes->anyType, &subunion}; std::vector result; for (TypeId ty : &utv) result.push_back(ty); REQUIRE_EQ(result.size(), 3); - CHECK_EQ(result[0], typeChecker.anyType); - CHECK_EQ(result[2], typeChecker.stringType); - CHECK_EQ(result[1], typeChecker.numberType); + CHECK_EQ(result[0], builtinTypes->anyType); + CHECK_EQ(result[2], builtinTypes->stringType); + CHECK_EQ(result[1], builtinTypes->numberType); } TEST_CASE_FIXTURE(Fixture, "iterator_detects_cyclic_UnionTypes_and_skips_over_them") @@ -132,8 +132,8 @@ TEST_CASE_FIXTURE(Fixture, "iterator_detects_cyclic_UnionTypes_and_skips_over_th Type btv{UnionType{}}; UnionType* utv2 = getMutable(&btv); - utv2->options.push_back(typeChecker.numberType); - utv2->options.push_back(typeChecker.stringType); + utv2->options.push_back(builtinTypes->numberType); + utv2->options.push_back(builtinTypes->stringType); utv2->options.push_back(&atv); utv1->options.push_back(&btv); @@ -143,14 +143,14 @@ TEST_CASE_FIXTURE(Fixture, "iterator_detects_cyclic_UnionTypes_and_skips_over_th result.push_back(ty); REQUIRE_EQ(result.size(), 2); - CHECK_EQ(result[0], typeChecker.numberType); - CHECK_EQ(result[1], typeChecker.stringType); + CHECK_EQ(result[0], builtinTypes->numberType); + CHECK_EQ(result[1], builtinTypes->stringType); } TEST_CASE_FIXTURE(Fixture, "iterator_descends_on_nested_in_first_operator*") { - Type tv1{UnionType{{typeChecker.stringType, typeChecker.numberType}}}; - Type tv2{UnionType{{&tv1, typeChecker.booleanType}}}; + Type tv1{UnionType{{builtinTypes->stringType, builtinTypes->numberType}}}; + Type tv2{UnionType{{&tv1, builtinTypes->booleanType}}}; auto utv = get(&tv2); std::vector result; @@ -158,19 +158,19 @@ TEST_CASE_FIXTURE(Fixture, "iterator_descends_on_nested_in_first_operator*") result.push_back(ty); REQUIRE_EQ(result.size(), 3); - CHECK_EQ(result[0], typeChecker.stringType); - CHECK_EQ(result[1], typeChecker.numberType); - CHECK_EQ(result[2], typeChecker.booleanType); + CHECK_EQ(result[0], builtinTypes->stringType); + CHECK_EQ(result[1], builtinTypes->numberType); + CHECK_EQ(result[2], builtinTypes->booleanType); } TEST_CASE_FIXTURE(Fixture, "UnionTypeIterator_with_vector_iter_ctor") { - Type tv1{UnionType{{typeChecker.stringType, typeChecker.numberType}}}; - Type tv2{UnionType{{&tv1, typeChecker.booleanType}}}; + Type tv1{UnionType{{builtinTypes->stringType, builtinTypes->numberType}}}; + Type tv2{UnionType{{&tv1, builtinTypes->booleanType}}}; auto utv = get(&tv2); std::vector actual(begin(utv), end(utv)); - std::vector expected{typeChecker.stringType, typeChecker.numberType, typeChecker.booleanType}; + std::vector expected{builtinTypes->stringType, builtinTypes->numberType, builtinTypes->booleanType}; CHECK_EQ(actual, expected); } @@ -273,10 +273,10 @@ TEST_CASE_FIXTURE(Fixture, "substitution_skip_failure") TypeId root = &ttvTweenResult; - typeChecker.currentModule = std::make_shared(); - typeChecker.currentModule->scopes.emplace_back(Location{}, std::make_shared(builtinTypes->anyTypePack)); + frontend.typeChecker.currentModule = std::make_shared(); + frontend.typeChecker.currentModule->scopes.emplace_back(Location{}, std::make_shared(builtinTypes->anyTypePack)); - TypeId result = typeChecker.anyify(typeChecker.globalScope, root, Location{}); + TypeId result = frontend.typeChecker.anyify(frontend.globals.globalScope, root, Location{}); CHECK_EQ("{| f: t1 |} where t1 = () -> {| f: () -> {| f: ({| f: t1 |}) -> (), signal: {| f: (any) -> () |} |} |}", toString(result)); } diff --git a/tests/conformance/basic.lua b/tests/conformance/basic.lua index e23c1a53f..f4a91fc38 100644 --- a/tests/conformance/basic.lua +++ b/tests/conformance/basic.lua @@ -120,6 +120,8 @@ assert((function() local a a = nil local b = 2 b = a and b return b end)() == ni assert((function() local a a = 1 local b = 2 b = a or b return b end)() == 1) assert((function() local a a = nil local b = 2 b = a or b return b end)() == 2) +assert((function(a) return 12 % a end)(5) == 2) + -- binary arithmetics coerces strings to numbers (sadly) assert(1 + "2" == 3) assert(2 * "0xa" == 20) diff --git a/tests/conformance/math.lua b/tests/conformance/math.lua index e2f68e654..18ed13706 100644 --- a/tests/conformance/math.lua +++ b/tests/conformance/math.lua @@ -281,6 +281,8 @@ assert(math.round(-0.4) == 0) assert(math.round(-0.5) == -1) assert(math.round(-3.5) == -4) assert(math.round(math.huge) == math.huge) +assert(math.round(0.49999999999999994) == 0) +assert(math.round(-0.49999999999999994) == 0) -- fmod assert(math.fmod(3, 2) == 1) diff --git a/tests/conformance/sort.lua b/tests/conformance/sort.lua index 95940e111..693a10dc5 100644 --- a/tests/conformance/sort.lua +++ b/tests/conformance/sort.lua @@ -2,6 +2,47 @@ -- This file is based on Lua 5.x tests -- https://github.com/lua/lua/tree/master/testes print"testing sort" +function checksort(t, f, ...) + assert(#t == select('#', ...)) + local copy = table.clone(t) + table.sort(copy, f) + for i=1,#t do assert(copy[i] == select(i, ...)) end +end + +-- basic edge cases +checksort({}, nil) +checksort({1}, nil, 1) + +-- small inputs +checksort({1, 2}, nil, 1, 2) +checksort({2, 1}, nil, 1, 2) + +checksort({1, 2, 3}, nil, 1, 2, 3) +checksort({2, 1, 3}, nil, 1, 2, 3) +checksort({1, 3, 2}, nil, 1, 2, 3) +checksort({3, 2, 1}, nil, 1, 2, 3) +checksort({3, 1, 2}, nil, 1, 2, 3) + +-- "large" input +checksort({3, 8, 1, 7, 10, 2, 5, 4, 9, 6}, nil, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10) +checksort({"Jan", "Feb", "Mar", "Apr", "May", "Jun", "Jul", "Aug", "Sep", "Oct", "Nov", "Dec"}, nil, "Apr", "Aug", "Dec", "Feb", "Jan", "Jul", "Jun", "Mar", "May", "Nov", "Oct", "Sep") + +-- duplicates +checksort({3, 1, 1, 7, 1, 3, 5, 1, 9, 3}, nil, 1, 1, 1, 1, 3, 3, 3, 5, 7, 9) + +-- predicates +checksort({3, 8, 1, 7, 10, 2, 5, 4, 9, 6}, function (a, b) return a > b end, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1) + +-- can't sort readonly tables +assert(pcall(table.sort, table.freeze({2, 1})) == false) + +-- first argument must be a table, second argument must be nil or function +assert(pcall(table.sort) == false) +assert(pcall(table.sort, "abc") == false) +assert(pcall(table.sort, {}, 42) == false) +assert(pcall(table.sort, {}, {}) == false) + +-- legacy Lua tests function check (a, f) f = f or function (x,y) return x Date: Fri, 17 Mar 2023 16:59:30 +0200 Subject: [PATCH 41/66] Sync to upstream/release/568 --- .../include/Luau/ConstraintGraphBuilder.h | 43 +- Analysis/include/Luau/ConstraintSolver.h | 8 + Analysis/include/Luau/ControlFlow.h | 36 ++ Analysis/include/Luau/Frontend.h | 12 +- Analysis/include/Luau/Module.h | 2 + Analysis/include/Luau/Normalize.h | 6 +- Analysis/include/Luau/Scope.h | 1 + Analysis/include/Luau/Type.h | 2 +- Analysis/include/Luau/TypeInfer.h | 43 +- Analysis/include/Luau/Unifiable.h | 2 + Analysis/include/Luau/Unifier.h | 3 +- Analysis/src/ConstraintGraphBuilder.cpp | 164 +++++--- Analysis/src/ConstraintSolver.cpp | 67 ++- Analysis/src/Frontend.cpp | 91 ++++- Analysis/src/Normalize.cpp | 55 ++- Analysis/src/Quantify.cpp | 1 - Analysis/src/Scope.cpp | 22 + Analysis/src/Type.cpp | 5 +- Analysis/src/TypeInfer.cpp | 213 +++++++--- Analysis/src/Unifiable.cpp | 5 + Analysis/src/Unifier.cpp | 168 +++++++- Ast/include/Luau/Parser.h | 62 +-- Ast/src/Parser.cpp | 223 +++++----- CLI/Analyze.cpp | 4 +- CodeGen/include/Luau/AssemblyBuilderA64.h | 7 + CodeGen/include/Luau/AssemblyBuilderX64.h | 7 + CodeGen/include/Luau/CodeAllocator.h | 3 +- CodeGen/include/Luau/IrAnalysis.h | 1 + CodeGen/include/Luau/IrBuilder.h | 46 ++- CodeGen/include/Luau/IrData.h | 45 ++- CodeGen/include/Luau/IrDump.h | 4 + CodeGen/include/Luau/IrUtils.h | 6 +- CodeGen/src/CodeAllocator.cpp | 2 +- CodeGen/src/CodeGen.cpp | 131 +++--- CodeGen/src/CodeGenA64.cpp | 69 ++++ CodeGen/src/CodeGenA64.h | 18 + CodeGen/src/CodeGenX64.cpp | 18 + CodeGen/src/CodeGenX64.h | 4 + CodeGen/src/EmitBuiltinsX64.cpp | 29 ++ CodeGen/src/EmitInstructionX64.cpp | 45 --- CodeGen/src/EmitInstructionX64.h | 1 - CodeGen/src/IrAnalysis.cpp | 33 +- CodeGen/src/IrBuilder.cpp | 47 ++- CodeGen/src/IrDump.cpp | 129 +++++- CodeGen/src/IrLoweringX64.cpp | 58 ++- CodeGen/src/IrTranslateBuiltins.cpp | 32 ++ CodeGen/src/IrTranslation.cpp | 62 ++- CodeGen/src/IrTranslation.h | 1 + CodeGen/src/IrUtils.cpp | 2 +- CodeGen/src/NativeState.cpp | 1 + CodeGen/src/NativeState.h | 2 +- CodeGen/src/OptimizeConstProp.cpp | 19 +- CodeGen/src/UnwindBuilderDwarf2.cpp | 3 + Common/include/Luau/Bytecode.h | 6 +- Common/include/Luau/ExperimentalFlags.h | 5 +- Compiler/src/Compiler.cpp | 3 +- Makefile | 3 + Sources.cmake | 2 + VM/src/ltable.cpp | 34 +- tests/Compiler.test.cpp | 46 ++- tests/ConstraintGraphBuilderFixture.cpp | 2 +- tests/Fixture.cpp | 35 +- tests/Fixture.h | 3 +- tests/Frontend.test.cpp | 42 +- tests/IrBuilder.test.cpp | 84 +++- tests/Linter.test.cpp | 30 +- tests/Normalize.test.cpp | 129 +++++- tests/RuntimeLimits.test.cpp | 3 +- tests/TypeInfer.anyerror.test.cpp | 25 +- tests/TypeInfer.builtins.test.cpp | 13 +- tests/TypeInfer.cfa.test.cpp | 380 ++++++++++++++++++ tests/TypeInfer.loops.test.cpp | 18 + tests/TypeInfer.modules.test.cpp | 17 +- tests/TypeInfer.oop.test.cpp | 17 + tests/TypeInfer.operators.test.cpp | 24 ++ tests/TypeInfer.provisional.test.cpp | 11 +- tests/TypeInfer.refinements.test.cpp | 3 +- tests/TypeInfer.tables.test.cpp | 12 +- tests/TypeInfer.test.cpp | 13 + tests/TypeInfer.tryUnify.test.cpp | 35 +- tests/TypeInfer.unionTypes.test.cpp | 58 +++ tests/TypeInfer.unknownnever.test.cpp | 9 +- tests/conformance/tables.lua | 15 + tools/faillist.txt | 35 +- 84 files changed, 2493 insertions(+), 682 deletions(-) create mode 100644 Analysis/include/Luau/ControlFlow.h create mode 100644 CodeGen/src/CodeGenA64.cpp create mode 100644 CodeGen/src/CodeGenA64.h create mode 100644 tests/TypeInfer.cfa.test.cpp diff --git a/Analysis/include/Luau/ConstraintGraphBuilder.h b/Analysis/include/Luau/ConstraintGraphBuilder.h index e79c4c91e..204470489 100644 --- a/Analysis/include/Luau/ConstraintGraphBuilder.h +++ b/Analysis/include/Luau/ConstraintGraphBuilder.h @@ -2,12 +2,13 @@ #pragma once #include "Luau/Ast.h" -#include "Luau/Refinement.h" #include "Luau/Constraint.h" +#include "Luau/ControlFlow.h" #include "Luau/DataFlowGraph.h" #include "Luau/Module.h" #include "Luau/ModuleResolver.h" #include "Luau/NotNull.h" +#include "Luau/Refinement.h" #include "Luau/Symbol.h" #include "Luau/Type.h" #include "Luau/Variant.h" @@ -141,26 +142,26 @@ struct ConstraintGraphBuilder */ void visit(AstStatBlock* block); - void visitBlockWithoutChildScope(const ScopePtr& scope, AstStatBlock* block); - - void visit(const ScopePtr& scope, AstStat* stat); - void visit(const ScopePtr& scope, AstStatBlock* block); - void visit(const ScopePtr& scope, AstStatLocal* local); - void visit(const ScopePtr& scope, AstStatFor* for_); - void visit(const ScopePtr& scope, AstStatForIn* forIn); - void visit(const ScopePtr& scope, AstStatWhile* while_); - void visit(const ScopePtr& scope, AstStatRepeat* repeat); - void visit(const ScopePtr& scope, AstStatLocalFunction* function); - void visit(const ScopePtr& scope, AstStatFunction* function); - void visit(const ScopePtr& scope, AstStatReturn* ret); - void visit(const ScopePtr& scope, AstStatAssign* assign); - void visit(const ScopePtr& scope, AstStatCompoundAssign* assign); - void visit(const ScopePtr& scope, AstStatIf* ifStatement); - void visit(const ScopePtr& scope, AstStatTypeAlias* alias); - void visit(const ScopePtr& scope, AstStatDeclareGlobal* declareGlobal); - void visit(const ScopePtr& scope, AstStatDeclareClass* declareClass); - void visit(const ScopePtr& scope, AstStatDeclareFunction* declareFunction); - void visit(const ScopePtr& scope, AstStatError* error); + ControlFlow visitBlockWithoutChildScope(const ScopePtr& scope, AstStatBlock* block); + + ControlFlow visit(const ScopePtr& scope, AstStat* stat); + ControlFlow visit(const ScopePtr& scope, AstStatBlock* block); + ControlFlow visit(const ScopePtr& scope, AstStatLocal* local); + ControlFlow visit(const ScopePtr& scope, AstStatFor* for_); + ControlFlow visit(const ScopePtr& scope, AstStatForIn* forIn); + ControlFlow visit(const ScopePtr& scope, AstStatWhile* while_); + ControlFlow visit(const ScopePtr& scope, AstStatRepeat* repeat); + ControlFlow visit(const ScopePtr& scope, AstStatLocalFunction* function); + ControlFlow visit(const ScopePtr& scope, AstStatFunction* function); + ControlFlow visit(const ScopePtr& scope, AstStatReturn* ret); + ControlFlow visit(const ScopePtr& scope, AstStatAssign* assign); + ControlFlow visit(const ScopePtr& scope, AstStatCompoundAssign* assign); + ControlFlow visit(const ScopePtr& scope, AstStatIf* ifStatement); + ControlFlow visit(const ScopePtr& scope, AstStatTypeAlias* alias); + ControlFlow visit(const ScopePtr& scope, AstStatDeclareGlobal* declareGlobal); + ControlFlow visit(const ScopePtr& scope, AstStatDeclareClass* declareClass); + ControlFlow visit(const ScopePtr& scope, AstStatDeclareFunction* declareFunction); + ControlFlow visit(const ScopePtr& scope, AstStatError* error); InferencePack checkPack(const ScopePtr& scope, AstArray exprs, const std::vector>& expectedTypes = {}); InferencePack checkPack(const ScopePtr& scope, AstExpr* expr, const std::vector>& expectedTypes = {}); diff --git a/Analysis/include/Luau/ConstraintSolver.h b/Analysis/include/Luau/ConstraintSolver.h index 4fd7d0d10..e9e1e884a 100644 --- a/Analysis/include/Luau/ConstraintSolver.h +++ b/Analysis/include/Luau/ConstraintSolver.h @@ -143,6 +143,14 @@ struct ConstraintSolver bool block(TypeId target, NotNull constraint); bool block(TypePackId target, NotNull constraint); + /** + * For all constraints that are blocked on one constraint, make them block + * on a new constraint. + * @param source the constraint to copy blocks from. + * @param addition the constraint that other constraints should now block on. + */ + void inheritBlocks(NotNull source, NotNull addition); + // Traverse the type. If any blocked or pending types are found, block // the constraint on them. // diff --git a/Analysis/include/Luau/ControlFlow.h b/Analysis/include/Luau/ControlFlow.h new file mode 100644 index 000000000..8272bd53e --- /dev/null +++ b/Analysis/include/Luau/ControlFlow.h @@ -0,0 +1,36 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include + +namespace Luau +{ + +struct Scope; +using ScopePtr = std::shared_ptr; + +enum class ControlFlow +{ + None = 0b00001, + Returns = 0b00010, + Throws = 0b00100, + Break = 0b01000, // Currently unused. + Continue = 0b10000, // Currently unused. +}; + +inline ControlFlow operator&(ControlFlow a, ControlFlow b) +{ + return ControlFlow(int(a) & int(b)); +} + +inline ControlFlow operator|(ControlFlow a, ControlFlow b) +{ + return ControlFlow(int(a) | int(b)); +} + +inline bool matches(ControlFlow a, ControlFlow b) +{ + return (a & b) != ControlFlow(0); +} + +} // namespace Luau diff --git a/Analysis/include/Luau/Frontend.h b/Analysis/include/Luau/Frontend.h index 9c0366a6c..68ba8ff5d 100644 --- a/Analysis/include/Luau/Frontend.h +++ b/Analysis/include/Luau/Frontend.h @@ -89,14 +89,21 @@ struct FrontendOptions // order to get more precise type information) bool forAutocomplete = false; + bool runLintChecks = false; + // If not empty, randomly shuffle the constraint set before attempting to // solve. Use this value to seed the random number generator. std::optional randomizeConstraintResolutionSeed; + + std::optional enabledLintWarnings; }; struct CheckResult { std::vector errors; + + LintResult lintResult; + std::vector timeoutHits; }; @@ -133,8 +140,9 @@ struct Frontend CheckResult check(const ModuleName& name, std::optional optionOverride = {}); // new shininess - LintResult lint(const ModuleName& name, std::optional enabledLintWarnings = {}); - LintResult lint(const SourceModule& module, std::optional enabledLintWarnings = {}); + // Use 'check' with 'runLintChecks' set to true in FrontendOptions (enabledLintWarnings be set there as well) + LintResult lint_DEPRECATED(const ModuleName& name, std::optional enabledLintWarnings = {}); + LintResult lint_DEPRECATED(const SourceModule& module, std::optional enabledLintWarnings = {}); bool isDirty(const ModuleName& name, bool forAutocomplete = false) const; void markDirty(const ModuleName& name, std::vector* markedDirty = nullptr); diff --git a/Analysis/include/Luau/Module.h b/Analysis/include/Luau/Module.h index 2faa0297f..72f87601d 100644 --- a/Analysis/include/Luau/Module.h +++ b/Analysis/include/Luau/Module.h @@ -2,6 +2,7 @@ #pragma once #include "Luau/Error.h" +#include "Luau/Linter.h" #include "Luau/FileResolver.h" #include "Luau/ParseOptions.h" #include "Luau/ParseResult.h" @@ -88,6 +89,7 @@ struct Module std::unordered_map declaredGlobals; ErrorVec errors; + LintResult lintResult; Mode mode; SourceCode::Type type; bool timeout = false; diff --git a/Analysis/include/Luau/Normalize.h b/Analysis/include/Luau/Normalize.h index 15dc7d4a1..15404707d 100644 --- a/Analysis/include/Luau/Normalize.h +++ b/Analysis/include/Luau/Normalize.h @@ -19,6 +19,8 @@ using ModulePtr = std::shared_ptr; bool isSubtype(TypeId subTy, TypeId superTy, NotNull scope, NotNull builtinTypes, InternalErrorReporter& ice); bool isSubtype(TypePackId subTy, TypePackId superTy, NotNull scope, NotNull builtinTypes, InternalErrorReporter& ice); +bool isConsistentSubtype(TypeId subTy, TypeId superTy, NotNull scope, NotNull builtinTypes, InternalErrorReporter& ice); +bool isConsistentSubtype(TypePackId subTy, TypePackId superTy, NotNull scope, NotNull builtinTypes, InternalErrorReporter& ice); class TypeIds { @@ -203,7 +205,7 @@ struct NormalizedFunctionType }; // A normalized generic/free type is a union, where each option is of the form (X & T) where -// * X is either a free type or a generic +// * X is either a free type, a generic or a blocked type. // * T is a normalized type. struct NormalizedType; using NormalizedTyvars = std::unordered_map>; @@ -214,7 +216,7 @@ bool isInhabited_DEPRECATED(const NormalizedType& norm); // * P is a union of primitive types (including singletons, classes and the error type) // * T is a union of table types // * F is a union of an intersection of function types -// * G is a union of generic/free normalized types, intersected with a normalized type +// * G is a union of generic/free/blocked types, intersected with a normalized type struct NormalizedType { // The top part of the type. diff --git a/Analysis/include/Luau/Scope.h b/Analysis/include/Luau/Scope.h index 0d3972672..745ea47ab 100644 --- a/Analysis/include/Luau/Scope.h +++ b/Analysis/include/Luau/Scope.h @@ -66,6 +66,7 @@ struct Scope RefinementMap refinements; DenseHashMap dcrRefinements{nullptr}; + void inheritRefinements(const ScopePtr& childScope); // For mutually recursive type aliases, it's important that // they use the same types for the same names. diff --git a/Analysis/include/Luau/Type.h b/Analysis/include/Luau/Type.h index ef2d4c6a4..dba2a8de2 100644 --- a/Analysis/include/Luau/Type.h +++ b/Analysis/include/Luau/Type.h @@ -102,7 +102,7 @@ struct BlockedType BlockedType(); int index; - static int nextIndex; + static int DEPRECATED_nextIndex; }; struct PrimitiveType diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index 21cb26371..68161794a 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -2,14 +2,15 @@ #pragma once #include "Luau/Anyification.h" -#include "Luau/Predicate.h" +#include "Luau/ControlFlow.h" #include "Luau/Error.h" #include "Luau/Module.h" -#include "Luau/Symbol.h" +#include "Luau/Predicate.h" #include "Luau/Substitution.h" +#include "Luau/Symbol.h" #include "Luau/TxnLog.h" -#include "Luau/TypePack.h" #include "Luau/Type.h" +#include "Luau/TypePack.h" #include "Luau/Unifier.h" #include "Luau/UnifierSharedState.h" @@ -87,28 +88,28 @@ struct TypeChecker std::vector> getScopes() const; - void check(const ScopePtr& scope, const AstStat& statement); - void check(const ScopePtr& scope, const AstStatBlock& statement); - void check(const ScopePtr& scope, const AstStatIf& statement); - void check(const ScopePtr& scope, const AstStatWhile& statement); - void check(const ScopePtr& scope, const AstStatRepeat& statement); - void check(const ScopePtr& scope, const AstStatReturn& return_); - void check(const ScopePtr& scope, const AstStatAssign& assign); - void check(const ScopePtr& scope, const AstStatCompoundAssign& assign); - void check(const ScopePtr& scope, const AstStatLocal& local); - void check(const ScopePtr& scope, const AstStatFor& local); - void check(const ScopePtr& scope, const AstStatForIn& forin); - void check(const ScopePtr& scope, TypeId ty, const ScopePtr& funScope, const AstStatFunction& function); - void check(const ScopePtr& scope, TypeId ty, const ScopePtr& funScope, const AstStatLocalFunction& function); - void check(const ScopePtr& scope, const AstStatTypeAlias& typealias); - void check(const ScopePtr& scope, const AstStatDeclareClass& declaredClass); - void check(const ScopePtr& scope, const AstStatDeclareFunction& declaredFunction); + ControlFlow check(const ScopePtr& scope, const AstStat& statement); + ControlFlow check(const ScopePtr& scope, const AstStatBlock& statement); + ControlFlow check(const ScopePtr& scope, const AstStatIf& statement); + ControlFlow check(const ScopePtr& scope, const AstStatWhile& statement); + ControlFlow check(const ScopePtr& scope, const AstStatRepeat& statement); + ControlFlow check(const ScopePtr& scope, const AstStatReturn& return_); + ControlFlow check(const ScopePtr& scope, const AstStatAssign& assign); + ControlFlow check(const ScopePtr& scope, const AstStatCompoundAssign& assign); + ControlFlow check(const ScopePtr& scope, const AstStatLocal& local); + ControlFlow check(const ScopePtr& scope, const AstStatFor& local); + ControlFlow check(const ScopePtr& scope, const AstStatForIn& forin); + ControlFlow check(const ScopePtr& scope, TypeId ty, const ScopePtr& funScope, const AstStatFunction& function); + ControlFlow check(const ScopePtr& scope, TypeId ty, const ScopePtr& funScope, const AstStatLocalFunction& function); + ControlFlow check(const ScopePtr& scope, const AstStatTypeAlias& typealias); + ControlFlow check(const ScopePtr& scope, const AstStatDeclareClass& declaredClass); + ControlFlow check(const ScopePtr& scope, const AstStatDeclareFunction& declaredFunction); void prototype(const ScopePtr& scope, const AstStatTypeAlias& typealias, int subLevel = 0); void prototype(const ScopePtr& scope, const AstStatDeclareClass& declaredClass); - void checkBlock(const ScopePtr& scope, const AstStatBlock& statement); - void checkBlockWithoutRecursionCheck(const ScopePtr& scope, const AstStatBlock& statement); + ControlFlow checkBlock(const ScopePtr& scope, const AstStatBlock& statement); + ControlFlow checkBlockWithoutRecursionCheck(const ScopePtr& scope, const AstStatBlock& statement); void checkBlockTypeAliases(const ScopePtr& scope, std::vector& sorted); WithPredicate checkExpr( diff --git a/Analysis/include/Luau/Unifiable.h b/Analysis/include/Luau/Unifiable.h index 15e501f02..9c4f01326 100644 --- a/Analysis/include/Luau/Unifiable.h +++ b/Analysis/include/Luau/Unifiable.h @@ -81,6 +81,8 @@ namespace Luau::Unifiable using Name = std::string; +int freshIndex(); + struct Free { explicit Free(TypeLevel level); diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index fc886ac0c..e7817e57c 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -58,6 +58,7 @@ struct Unifier NotNull scope; // const Scope maybe TxnLog log; + bool failure = false; ErrorVec errors; Location location; Variance variance = Covariant; @@ -93,7 +94,7 @@ struct Unifier // Traverse the two types provided and block on any BlockedTypes we find. // Returns true if any types were blocked on. - bool blockOnBlockedTypes(TypeId subTy, TypeId superTy); + bool DEPRECATED_blockOnBlockedTypes(TypeId subTy, TypeId superTy); void tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionType* uv, bool cacheEnabled, bool isFunctionCall); void tryUnifyTypeWithIntersection(TypeId subTy, TypeId superTy, const IntersectionType* uv); diff --git a/Analysis/src/ConstraintGraphBuilder.cpp b/Analysis/src/ConstraintGraphBuilder.cpp index 711d357f0..e90cb7d3a 100644 --- a/Analysis/src/ConstraintGraphBuilder.cpp +++ b/Analysis/src/ConstraintGraphBuilder.cpp @@ -5,6 +5,7 @@ #include "Luau/Breadcrumb.h" #include "Luau/Common.h" #include "Luau/Constraint.h" +#include "Luau/ControlFlow.h" #include "Luau/DcrLogger.h" #include "Luau/ModuleResolver.h" #include "Luau/RecursionCounter.h" @@ -22,6 +23,7 @@ LUAU_FASTFLAG(LuauNegatedClassTypes); namespace Luau { +bool doesCallError(const AstExprCall* call); // TypeInfer.cpp const AstStat* getFallthrough(const AstStat* node); // TypeInfer.cpp static std::optional matchRequire(const AstExprCall& call) @@ -344,14 +346,14 @@ void ConstraintGraphBuilder::visit(AstStatBlock* block) logger->captureGenerationModule(module); } -void ConstraintGraphBuilder::visitBlockWithoutChildScope(const ScopePtr& scope, AstStatBlock* block) +ControlFlow ConstraintGraphBuilder::visitBlockWithoutChildScope(const ScopePtr& scope, AstStatBlock* block) { RecursionCounter counter{&recursionCount}; if (recursionCount >= FInt::LuauCheckRecursionLimit) { reportCodeTooComplex(block->location); - return; + return ControlFlow::None; } std::unordered_map aliasDefinitionLocations; @@ -396,59 +398,77 @@ void ConstraintGraphBuilder::visitBlockWithoutChildScope(const ScopePtr& scope, } } + std::optional firstControlFlow; for (AstStat* stat : block->body) - visit(scope, stat); + { + ControlFlow cf = visit(scope, stat); + if (cf != ControlFlow::None && !firstControlFlow) + firstControlFlow = cf; + } + + return firstControlFlow.value_or(ControlFlow::None); } -void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStat* stat) +ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStat* stat) { RecursionLimiter limiter{&recursionCount, FInt::LuauCheckRecursionLimit}; if (auto s = stat->as()) - visit(scope, s); + return visit(scope, s); else if (auto i = stat->as()) - visit(scope, i); + return visit(scope, i); else if (auto s = stat->as()) - visit(scope, s); + return visit(scope, s); else if (auto s = stat->as()) - visit(scope, s); + return visit(scope, s); else if (stat->is() || stat->is()) { // Nothing + return ControlFlow::None; // TODO: ControlFlow::Break/Continue } else if (auto r = stat->as()) - visit(scope, r); + return visit(scope, r); else if (auto e = stat->as()) + { checkPack(scope, e->expr); + + if (auto call = e->expr->as(); call && doesCallError(call)) + return ControlFlow::Throws; + + return ControlFlow::None; + } else if (auto s = stat->as()) - visit(scope, s); + return visit(scope, s); else if (auto s = stat->as()) - visit(scope, s); + return visit(scope, s); else if (auto s = stat->as()) - visit(scope, s); + return visit(scope, s); else if (auto a = stat->as()) - visit(scope, a); + return visit(scope, a); else if (auto a = stat->as()) - visit(scope, a); + return visit(scope, a); else if (auto f = stat->as()) - visit(scope, f); + return visit(scope, f); else if (auto f = stat->as()) - visit(scope, f); + return visit(scope, f); else if (auto a = stat->as()) - visit(scope, a); + return visit(scope, a); else if (auto s = stat->as()) - visit(scope, s); + return visit(scope, s); else if (auto s = stat->as()) - visit(scope, s); + return visit(scope, s); else if (auto s = stat->as()) - visit(scope, s); + return visit(scope, s); else if (auto s = stat->as()) - visit(scope, s); + return visit(scope, s); else + { LUAU_ASSERT(0 && "Internal error: Unknown AstStat type"); + return ControlFlow::None; + } } -void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocal* local) +ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocal* local) { std::vector varTypes; varTypes.reserve(local->vars.size); @@ -534,7 +554,7 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocal* local) } } - if (local->vars.size == 1 && local->values.size == 1 && firstValueType) + if (local->vars.size == 1 && local->values.size == 1 && firstValueType && scope.get() == rootScope) { AstLocal* var = local->vars.data[0]; AstExpr* value = local->values.data[0]; @@ -592,9 +612,11 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocal* local) } } } + + return ControlFlow::None; } -void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFor* for_) +ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFor* for_) { TypeId annotationTy = builtinTypes->numberType; if (for_->var->annotation) @@ -619,9 +641,11 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFor* for_) forScope->dcrRefinements[bc->def] = annotationTy; visit(forScope, for_->body); + + return ControlFlow::None; } -void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatForIn* forIn) +ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatForIn* forIn) { ScopePtr loopScope = childScope(forIn, scope); @@ -645,27 +669,33 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatForIn* forIn) addConstraint(loopScope, getLocation(forIn->values), IterableConstraint{iterator, variablePack}); visit(loopScope, forIn->body); + + return ControlFlow::None; } -void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatWhile* while_) +ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatWhile* while_) { check(scope, while_->condition); ScopePtr whileScope = childScope(while_, scope); visit(whileScope, while_->body); + + return ControlFlow::None; } -void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatRepeat* repeat) +ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatRepeat* repeat) { ScopePtr repeatScope = childScope(repeat, scope); visitBlockWithoutChildScope(repeatScope, repeat->body); check(repeatScope, repeat->condition); + + return ControlFlow::None; } -void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocalFunction* function) +ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocalFunction* function) { // Local // Global @@ -699,9 +729,11 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocalFunction* }); addConstraint(scope, std::move(c)); + + return ControlFlow::None; } -void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFunction* function) +ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFunction* function) { // Name could be AstStatLocal, AstStatGlobal, AstStatIndexName. // With or without self @@ -779,9 +811,11 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFunction* funct }); addConstraint(scope, std::move(c)); + + return ControlFlow::None; } -void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatReturn* ret) +ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatReturn* ret) { // At this point, the only way scope->returnType should have anything // interesting in it is if the function has an explicit return annotation. @@ -793,13 +827,18 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatReturn* ret) TypePackId exprTypes = checkPack(scope, ret->list, expectedTypes).tp; addConstraint(scope, ret->location, PackSubtypeConstraint{exprTypes, scope->returnType}); + + return ControlFlow::Returns; } -void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatBlock* block) +ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatBlock* block) { ScopePtr innerScope = childScope(block, scope); - visitBlockWithoutChildScope(innerScope, block); + ControlFlow flow = visitBlockWithoutChildScope(innerScope, block); + scope->inheritRefinements(innerScope); + + return flow; } static void bindFreeType(TypeId a, TypeId b) @@ -819,7 +858,7 @@ static void bindFreeType(TypeId a, TypeId b) asMutable(b)->ty.emplace(a); } -void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatAssign* assign) +ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatAssign* assign) { std::vector varTypes = checkLValues(scope, assign->vars); @@ -839,9 +878,11 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatAssign* assign) TypePackId varPack = arena->addTypePack({varTypes}); addConstraint(scope, assign->location, PackSubtypeConstraint{exprPack, varPack}); + + return ControlFlow::None; } -void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatCompoundAssign* assign) +ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatCompoundAssign* assign) { // We need to tweak the BinaryConstraint that we emit, so we cannot use the // strategy of falsifying an AST fragment. @@ -852,23 +893,34 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatCompoundAssign* addConstraint(scope, assign->location, BinaryConstraint{assign->op, varTy, valueTy, resultType, assign, &module->astOriginalCallTypes, &module->astOverloadResolvedTypes}); addConstraint(scope, assign->location, SubtypeConstraint{resultType, varTy}); + + return ControlFlow::None; } -void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatIf* ifStatement) +ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatIf* ifStatement) { - ScopePtr condScope = childScope(ifStatement->condition, scope); - RefinementId refinement = check(condScope, ifStatement->condition, std::nullopt).refinement; + RefinementId refinement = check(scope, ifStatement->condition, std::nullopt).refinement; ScopePtr thenScope = childScope(ifStatement->thenbody, scope); applyRefinements(thenScope, ifStatement->condition->location, refinement); - visit(thenScope, ifStatement->thenbody); + ScopePtr elseScope = childScope(ifStatement->elsebody ? ifStatement->elsebody : ifStatement, scope); + applyRefinements(elseScope, ifStatement->elseLocation.value_or(ifStatement->condition->location), refinementArena.negation(refinement)); + + ControlFlow thencf = visit(thenScope, ifStatement->thenbody); + ControlFlow elsecf = ControlFlow::None; if (ifStatement->elsebody) - { - ScopePtr elseScope = childScope(ifStatement->elsebody, scope); - applyRefinements(elseScope, ifStatement->elseLocation.value_or(ifStatement->condition->location), refinementArena.negation(refinement)); - visit(elseScope, ifStatement->elsebody); - } + elsecf = visit(elseScope, ifStatement->elsebody); + + if (matches(thencf, ControlFlow::Returns | ControlFlow::Throws) && elsecf == ControlFlow::None) + scope->inheritRefinements(elseScope); + else if (thencf == ControlFlow::None && matches(elsecf, ControlFlow::Returns | ControlFlow::Throws)) + scope->inheritRefinements(thenScope); + + if (matches(thencf, ControlFlow::Returns | ControlFlow::Throws) && matches(elsecf, ControlFlow::Returns | ControlFlow::Throws)) + return ControlFlow::Returns; + else + return ControlFlow::None; } static bool occursCheck(TypeId needle, TypeId haystack) @@ -890,7 +942,7 @@ static bool occursCheck(TypeId needle, TypeId haystack) return false; } -void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatTypeAlias* alias) +ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatTypeAlias* alias) { ScopePtr* defnScope = astTypeAliasDefiningScopes.find(alias); @@ -904,7 +956,7 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatTypeAlias* alia // case we just skip over it. auto bindingIt = typeBindings->find(alias->name.value); if (bindingIt == typeBindings->end() || defnScope == nullptr) - return; + return ControlFlow::None; TypeId ty = resolveType(*defnScope, alias->type, /* inTypeArguments */ false); @@ -935,9 +987,11 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatTypeAlias* alia std::move(typeParams), std::move(typePackParams), }); + + return ControlFlow::None; } -void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareGlobal* global) +ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareGlobal* global) { LUAU_ASSERT(global->type); @@ -949,6 +1003,8 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareGlobal* BreadcrumbId bc = dfg->getBreadcrumb(global); rootScope->dcrRefinements[bc->def] = globalTy; + + return ControlFlow::None; } static bool isMetamethod(const Name& name) @@ -958,7 +1014,7 @@ static bool isMetamethod(const Name& name) name == "__metatable" || name == "__eq" || name == "__lt" || name == "__le" || name == "__mode" || name == "__iter" || name == "__len"; } -void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareClass* declaredClass) +ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareClass* declaredClass) { std::optional superTy = FFlag::LuauNegatedClassTypes ? std::make_optional(builtinTypes->classType) : std::nullopt; if (declaredClass->superName) @@ -969,7 +1025,7 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareClass* d if (!lookupType) { reportError(declaredClass->location, UnknownSymbol{superName, UnknownSymbol::Type}); - return; + return ControlFlow::None; } // We don't have generic classes, so this assertion _should_ never be hit. @@ -981,7 +1037,7 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareClass* d reportError(declaredClass->location, GenericError{format("Cannot use non-class type '%s' as a superclass of class '%s'", superName.c_str(), declaredClass->name.value)}); - return; + return ControlFlow::None; } } @@ -1056,9 +1112,11 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareClass* d } } } + + return ControlFlow::None; } -void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareFunction* global) +ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareFunction* global) { std::vector> generics = createGenerics(scope, global->generics); std::vector> genericPacks = createGenericPacks(scope, global->genericPacks); @@ -1097,14 +1155,18 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareFunction BreadcrumbId bc = dfg->getBreadcrumb(global); rootScope->dcrRefinements[bc->def] = fnType; + + return ControlFlow::None; } -void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatError* error) +ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatError* error) { for (AstStat* stat : error->statements) visit(scope, stat); for (AstExpr* expr : error->expressions) check(scope, expr); + + return ControlFlow::None; } InferencePack ConstraintGraphBuilder::checkPack( diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index 3c306b40e..5662cf04b 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -1273,19 +1273,11 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNullsecond) - { - if (ic) - block(NotNull{ic}, blockedConstraint); - if (sc) - block(NotNull{sc}, blockedConstraint); - } - } + if (ic) + inheritBlocks(constraint, NotNull{ic}); + + if (sc) + inheritBlocks(constraint, NotNull{sc}); unblock(c.result); return true; @@ -1330,7 +1322,7 @@ bool ConstraintSolver::tryDispatch(const HasPropConstraint& c, NotNullty.emplace(result.value_or(builtinTypes->errorRecoveryType())); + asMutable(c.resultType)->ty.emplace(result.value_or(builtinTypes->anyType)); unblock(c.resultType); return true; } @@ -1796,13 +1788,23 @@ bool ConstraintSolver::tryDispatchIterableFunction( return false; } - const TypeId firstIndex = isNil(firstIndexTy) ? arena->freshType(constraint->scope) // FIXME: Surely this should be a union (free | nil) - : firstIndexTy; + TypeId firstIndex; + TypeId retIndex; + if (isNil(firstIndexTy) || isOptional(firstIndexTy)) + { + firstIndex = arena->addType(UnionType{{arena->freshType(constraint->scope), builtinTypes->nilType}}); + retIndex = firstIndex; + } + else + { + firstIndex = firstIndexTy; + retIndex = arena->addType(UnionType{{firstIndexTy, builtinTypes->nilType}}); + } // nextTy : (tableTy, indexTy?) -> (indexTy?, valueTailTy...) - const TypePackId nextArgPack = arena->addTypePack({tableTy, arena->addType(UnionType{{firstIndex, builtinTypes->nilType}})}); + const TypePackId nextArgPack = arena->addTypePack({tableTy, firstIndex}); const TypePackId valueTailTy = arena->addTypePack(FreeTypePack{constraint->scope}); - const TypePackId nextRetPack = arena->addTypePack(TypePack{{firstIndex}, valueTailTy}); + const TypePackId nextRetPack = arena->addTypePack(TypePack{{retIndex}, valueTailTy}); const TypeId expectedNextTy = arena->addType(FunctionType{TypeLevel{}, constraint->scope, nextArgPack, nextRetPack}); unify(nextTy, expectedNextTy, constraint->scope); @@ -1825,7 +1827,8 @@ bool ConstraintSolver::tryDispatchIterableFunction( modifiedNextRetHead.push_back(*it); TypePackId modifiedNextRetPack = arena->addTypePack(std::move(modifiedNextRetHead), it.tail()); - pushConstraint(constraint->scope, constraint->location, PackSubtypeConstraint{c.variables, modifiedNextRetPack}); + auto psc = pushConstraint(constraint->scope, constraint->location, PackSubtypeConstraint{c.variables, modifiedNextRetPack}); + inheritBlocks(constraint, psc); return true; } @@ -1883,7 +1886,17 @@ std::pair, std::optional> ConstraintSolver::lookupTa TypeId indexType = follow(indexProp->second.type); if (auto ft = get(indexType)) - return {{}, first(ft->retTypes)}; + { + TypePack rets = extendTypePack(*arena, builtinTypes, ft->retTypes, 1); + if (1 == rets.head.size()) + return {{}, rets.head[0]}; + else + { + // This should probably be an error: We need the first result of the MT.__index method, + // but it returns 0 values. See CLI-68672 + return {{}, builtinTypes->nilType}; + } + } else return lookupTableProp(indexType, propName, seen); } @@ -2009,6 +2022,20 @@ bool ConstraintSolver::block(TypePackId target, NotNull constr return false; } +void ConstraintSolver::inheritBlocks(NotNull source, NotNull addition) +{ + // Anything that is blocked on this constraint must also be blocked on our + // synthesized constraints. + auto blockedIt = blocked.find(source.get()); + if (blockedIt != blocked.end()) + { + for (const auto& blockedConstraint : blockedIt->second) + { + block(addition, blockedConstraint); + } + } +} + struct Blocker : TypeOnceVisitor { NotNull solver; diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 722f1a2c8..de79e0be8 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -29,6 +29,7 @@ LUAU_FASTINT(LuauTypeInferRecursionLimit) LUAU_FASTINT(LuauTarjanChildLimit) LUAU_FASTFLAG(LuauInferInNoCheckMode) LUAU_FASTFLAGVARIABLE(LuauKnowsTheDataModel3, false) +LUAU_FASTFLAGVARIABLE(LuauLintInTypecheck, false) LUAU_FASTINTVARIABLE(LuauAutocompleteCheckTimeoutMs, 100) LUAU_FASTFLAGVARIABLE(DebugLuauDeferredConstraintResolution, false) LUAU_FASTFLAGVARIABLE(LuauDefinitionFileSourceModule, false) @@ -330,7 +331,7 @@ std::optional pathExprToModuleName(const ModuleName& currentModuleN namespace { -ErrorVec accumulateErrors( +static ErrorVec accumulateErrors( const std::unordered_map& sourceNodes, const std::unordered_map& modules, const ModuleName& name) { std::unordered_set seen; @@ -375,6 +376,25 @@ ErrorVec accumulateErrors( return result; } +static void filterLintOptions(LintOptions& lintOptions, const std::vector& hotcomments, Mode mode) +{ + LUAU_ASSERT(FFlag::LuauLintInTypecheck); + + uint64_t ignoreLints = LintWarning::parseMask(hotcomments); + + lintOptions.warningMask &= ~ignoreLints; + + if (mode != Mode::NoCheck) + { + lintOptions.disableWarning(Luau::LintWarning::Code_UnknownGlobal); + } + + if (mode == Mode::Strict) + { + lintOptions.disableWarning(Luau::LintWarning::Code_ImplicitReturn); + } +} + // Given a source node (start), find all requires that start a transitive dependency path that ends back at start // For each such path, record the full path and the location of the require in the starting module. // Note that this is O(V^2) for a fully connected graph and produces O(V) paths of length O(V) @@ -514,8 +534,24 @@ CheckResult Frontend::check(const ModuleName& name, std::optional& modules = + frontendOptions.forAutocomplete ? moduleResolverForAutocomplete.modules : moduleResolver.modules; + + checkResult.errors = accumulateErrors(sourceNodes, modules, name); + + // Get lint result only for top checked module + if (auto it = modules.find(name); it != modules.end()) + checkResult.lintResult = it->second->lintResult; + + return checkResult; + } + else + { + return CheckResult{accumulateErrors( + sourceNodes, frontendOptions.forAutocomplete ? moduleResolverForAutocomplete.modules : moduleResolver.modules, name)}; + } } std::vector buildQueue; @@ -579,7 +615,7 @@ CheckResult Frontend::check(const ModuleName& name, std::optionalerrors.clear(); + + if (frontendOptions.runLintChecks) + { + LUAU_TIMETRACE_SCOPE("lint", "Frontend"); + + LUAU_ASSERT(FFlag::LuauLintInTypecheck); + + LintOptions lintOptions = frontendOptions.enabledLintWarnings.value_or(config.enabledLint); + filterLintOptions(lintOptions, sourceModule.hotcomments, mode); + + double timestamp = getTimestamp(); + + std::vector warnings = + Luau::lint(sourceModule.root, *sourceModule.names, environmentScope, module.get(), sourceModule.hotcomments, lintOptions); + + stats.timeLint += getTimestamp() - timestamp; + + module->lintResult = classifyLints(warnings, config); + } + if (!frontendOptions.retainFullTypeGraphs) { // copyErrors needs to allocate into interfaceTypes as it copies @@ -665,6 +724,16 @@ CheckResult Frontend::check(const ModuleName& name, std::optional& modules = + frontendOptions.forAutocomplete ? moduleResolverForAutocomplete.modules : moduleResolver.modules; + + if (auto it = modules.find(name); it != modules.end()) + checkResult.lintResult = it->second->lintResult; + } + return checkResult; } @@ -793,8 +862,10 @@ ScopePtr Frontend::getModuleEnvironment(const SourceModule& module, const Config return result; } -LintResult Frontend::lint(const ModuleName& name, std::optional enabledLintWarnings) +LintResult Frontend::lint_DEPRECATED(const ModuleName& name, std::optional enabledLintWarnings) { + LUAU_ASSERT(!FFlag::LuauLintInTypecheck); + LUAU_TIMETRACE_SCOPE("Frontend::lint", "Frontend"); LUAU_TIMETRACE_ARGUMENT("name", name.c_str()); @@ -803,11 +874,13 @@ LintResult Frontend::lint(const ModuleName& name, std::optional enabledLintWarnings) +LintResult Frontend::lint_DEPRECATED(const SourceModule& module, std::optional enabledLintWarnings) { + LUAU_ASSERT(!FFlag::LuauLintInTypecheck); + LUAU_TIMETRACE_SCOPE("Frontend::lint", "Frontend"); LUAU_TIMETRACE_ARGUMENT("module", module.name.c_str()); diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp index 0552bec03..f8f8b97f8 100644 --- a/Analysis/src/Normalize.cpp +++ b/Analysis/src/Normalize.cpp @@ -20,8 +20,10 @@ LUAU_FASTINTVARIABLE(LuauNormalizeCacheLimit, 100000); LUAU_FASTFLAGVARIABLE(LuauNegatedClassTypes, false); LUAU_FASTFLAGVARIABLE(LuauNegatedFunctionTypes, false); LUAU_FASTFLAGVARIABLE(LuauNegatedTableTypes, false); +LUAU_FASTFLAGVARIABLE(LuauNormalizeBlockedTypes, false); LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) LUAU_FASTFLAG(LuauUninhabitedSubAnything2) +LUAU_FASTFLAG(LuauTransitiveSubtyping) namespace Luau { @@ -325,6 +327,8 @@ static int tyvarIndex(TypeId ty) return gtv->index; else if (const FreeType* ftv = get(ty)) return ftv->index; + else if (const BlockedType* btv = get(ty)) + return btv->index; else return 0; } @@ -529,7 +533,7 @@ static bool areNormalizedClasses(const NormalizedClassType& tys) static bool isPlainTyvar(TypeId ty) { - return (get(ty) || get(ty)); + return (get(ty) || get(ty) || (FFlag::LuauNormalizeBlockedTypes && get(ty))); } static bool isNormalizedTyvar(const NormalizedTyvars& tyvars) @@ -1271,6 +1275,8 @@ void Normalizer::unionTables(TypeIds& heres, const TypeIds& theres) bool Normalizer::unionNormals(NormalizedType& here, const NormalizedType& there, int ignoreSmallerTyvars) { TypeId tops = unionOfTops(here.tops, there.tops); + if (FFlag::LuauTransitiveSubtyping && get(tops) && (get(here.errors) || get(there.errors))) + tops = builtinTypes->anyType; if (!get(tops)) { clearNormal(here); @@ -1341,12 +1347,21 @@ bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, int ignor if (get(there) || get(there)) { TypeId tops = unionOfTops(here.tops, there); + if (FFlag::LuauTransitiveSubtyping && get(tops) && get(here.errors)) + tops = builtinTypes->anyType; clearNormal(here); here.tops = tops; return true; } - else if (get(there) || !get(here.tops)) + else if (!FFlag::LuauTransitiveSubtyping && (get(there) || !get(here.tops))) return true; + else if (FFlag::LuauTransitiveSubtyping && (get(there) || get(here.tops))) + return true; + else if (FFlag::LuauTransitiveSubtyping && get(there) && get(here.tops)) + { + here.tops = builtinTypes->anyType; + return true; + } else if (const UnionType* utv = get(there)) { for (UnionTypeIterator it = begin(utv); it != end(utv); ++it) @@ -1363,7 +1378,9 @@ bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, int ignor return false; return unionNormals(here, norm); } - else if (get(there) || get(there)) + else if (FFlag::LuauTransitiveSubtyping && get(here.tops)) + return true; + else if (get(there) || get(there) || (FFlag::LuauNormalizeBlockedTypes && get(there))) { if (tyvarIndex(there) <= ignoreSmallerTyvars) return true; @@ -1441,7 +1458,7 @@ bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, int ignor if (!unionNormals(here, *tn)) return false; } - else if (get(there)) + else if (!FFlag::LuauNormalizeBlockedTypes && get(there)) LUAU_ASSERT(!"Internal error: Trying to normalize a BlockedType"); else LUAU_ASSERT(!"Unreachable"); @@ -2527,7 +2544,7 @@ bool Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there) return false; return true; } - else if (get(there) || get(there)) + else if (get(there) || get(there) || (FFlag::LuauNormalizeBlockedTypes && get(there))) { NormalizedType thereNorm{builtinTypes}; NormalizedType topNorm{builtinTypes}; @@ -2802,6 +2819,32 @@ TypeId Normalizer::typeFromNormal(const NormalizedType& norm) } bool isSubtype(TypeId subTy, TypeId superTy, NotNull scope, NotNull builtinTypes, InternalErrorReporter& ice) +{ + if (!FFlag::LuauTransitiveSubtyping) + return isConsistentSubtype(subTy, superTy, scope, builtinTypes, ice); + UnifierSharedState sharedState{&ice}; + TypeArena arena; + Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}}; + Unifier u{NotNull{&normalizer}, Mode::Strict, scope, Location{}, Covariant}; + + u.tryUnify(subTy, superTy); + return !u.failure; +} + +bool isSubtype(TypePackId subPack, TypePackId superPack, NotNull scope, NotNull builtinTypes, InternalErrorReporter& ice) +{ + if (!FFlag::LuauTransitiveSubtyping) + return isConsistentSubtype(subPack, superPack, scope, builtinTypes, ice); + UnifierSharedState sharedState{&ice}; + TypeArena arena; + Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}}; + Unifier u{NotNull{&normalizer}, Mode::Strict, scope, Location{}, Covariant}; + + u.tryUnify(subPack, superPack); + return !u.failure; +} + +bool isConsistentSubtype(TypeId subTy, TypeId superTy, NotNull scope, NotNull builtinTypes, InternalErrorReporter& ice) { UnifierSharedState sharedState{&ice}; TypeArena arena; @@ -2813,7 +2856,7 @@ bool isSubtype(TypeId subTy, TypeId superTy, NotNull scope, NotNull scope, NotNull builtinTypes, InternalErrorReporter& ice) +bool isConsistentSubtype(TypePackId subPack, TypePackId superPack, NotNull scope, NotNull builtinTypes, InternalErrorReporter& ice) { UnifierSharedState sharedState{&ice}; TypeArena arena; diff --git a/Analysis/src/Quantify.cpp b/Analysis/src/Quantify.cpp index 9da43ed2d..0b8f46248 100644 --- a/Analysis/src/Quantify.cpp +++ b/Analysis/src/Quantify.cpp @@ -27,7 +27,6 @@ struct Quantifier final : TypeOnceVisitor explicit Quantifier(TypeLevel level) : level(level) { - LUAU_ASSERT(!FFlag::DebugLuauDeferredConstraintResolution); } /// @return true if outer encloses inner diff --git a/Analysis/src/Scope.cpp b/Analysis/src/Scope.cpp index cac72124e..f54ebe2a9 100644 --- a/Analysis/src/Scope.cpp +++ b/Analysis/src/Scope.cpp @@ -149,6 +149,28 @@ std::optional Scope::linearSearchForBinding(const std::string& name, bo return std::nullopt; } +// Updates the `this` scope with the refinements from the `childScope` excluding ones that doesn't exist in `this`. +void Scope::inheritRefinements(const ScopePtr& childScope) +{ + if (FFlag::DebugLuauDeferredConstraintResolution) + { + for (const auto& [k, a] : childScope->dcrRefinements) + { + if (lookup(NotNull{k})) + dcrRefinements[k] = a; + } + } + else + { + for (const auto& [k, a] : childScope->refinements) + { + Symbol symbol = getBaseSymbol(k); + if (lookup(symbol)) + refinements[k] = a; + } + } +} + bool subsumesStrict(Scope* left, Scope* right) { while (right) diff --git a/Analysis/src/Type.cpp b/Analysis/src/Type.cpp index 4bc1223de..42fa40a54 100644 --- a/Analysis/src/Type.cpp +++ b/Analysis/src/Type.cpp @@ -25,6 +25,7 @@ LUAU_FASTINTVARIABLE(LuauTypeMaximumStringifierLength, 500) LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0) LUAU_FASTINT(LuauTypeInferRecursionLimit) LUAU_FASTFLAG(LuauInstantiateInSubtyping) +LUAU_FASTFLAG(LuauNormalizeBlockedTypes) LUAU_FASTFLAGVARIABLE(LuauMatchReturnsOptionalString, false); namespace Luau @@ -431,11 +432,11 @@ bool hasLength(TypeId ty, DenseHashSet& seen, int* recursionCount) } BlockedType::BlockedType() - : index(++nextIndex) + : index(FFlag::LuauNormalizeBlockedTypes ? Unifiable::freshIndex() : ++DEPRECATED_nextIndex) { } -int BlockedType::nextIndex = 0; +int BlockedType::DEPRECATED_nextIndex = 0; PendingExpansionType::PendingExpansionType( std::optional prefix, AstName name, std::vector typeArguments, std::vector packArguments) diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 87d5686fa..abc652861 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -43,6 +43,8 @@ LUAU_FASTFLAG(LuauNegatedClassTypes) LUAU_FASTFLAGVARIABLE(LuauAllowIndexClassParameters, false) LUAU_FASTFLAG(LuauUninhabitedSubAnything2) LUAU_FASTFLAGVARIABLE(LuauTypecheckTypeguards, false) +LUAU_FASTFLAGVARIABLE(LuauTinyControlFlowAnalysis, false) +LUAU_FASTFLAGVARIABLE(LuauReducingAndOr, false) namespace Luau { @@ -344,42 +346,54 @@ ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mo return std::move(currentModule); } -void TypeChecker::check(const ScopePtr& scope, const AstStat& program) +ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStat& program) { + if (finishTime && TimeTrace::getClock() > *finishTime) + throw TimeLimitError(iceHandler->moduleName); + if (auto block = program.as()) - check(scope, *block); + return check(scope, *block); else if (auto if_ = program.as()) - check(scope, *if_); + return check(scope, *if_); else if (auto while_ = program.as()) - check(scope, *while_); + return check(scope, *while_); else if (auto repeat = program.as()) - check(scope, *repeat); - else if (program.is()) - { - } // Nothing to do - else if (program.is()) + return check(scope, *repeat); + else if (program.is() || program.is()) { - } // Nothing to do + // Nothing to do + return ControlFlow::None; + } else if (auto return_ = program.as()) - check(scope, *return_); + return check(scope, *return_); else if (auto expr = program.as()) + { checkExprPack(scope, *expr->expr); + + if (FFlag::LuauTinyControlFlowAnalysis) + { + if (auto call = expr->expr->as(); call && doesCallError(call)) + return ControlFlow::Throws; + } + + return ControlFlow::None; + } else if (auto local = program.as()) - check(scope, *local); + return check(scope, *local); else if (auto for_ = program.as()) - check(scope, *for_); + return check(scope, *for_); else if (auto forIn = program.as()) - check(scope, *forIn); + return check(scope, *forIn); else if (auto assign = program.as()) - check(scope, *assign); + return check(scope, *assign); else if (auto assign = program.as()) - check(scope, *assign); + return check(scope, *assign); else if (program.is()) ice("Should not be calling two-argument check() on a function statement", program.location); else if (program.is()) ice("Should not be calling two-argument check() on a function statement", program.location); else if (auto typealias = program.as()) - check(scope, *typealias); + return check(scope, *typealias); else if (auto global = program.as()) { TypeId globalType = resolveType(scope, *global->type); @@ -387,11 +401,13 @@ void TypeChecker::check(const ScopePtr& scope, const AstStat& program) currentModule->declaredGlobals[globalName] = globalType; currentModule->getModuleScope()->bindings[global->name] = Binding{globalType, global->location}; + + return ControlFlow::None; } else if (auto global = program.as()) - check(scope, *global); + return check(scope, *global); else if (auto global = program.as()) - check(scope, *global); + return check(scope, *global); else if (auto errorStatement = program.as()) { const size_t oldSize = currentModule->errors.size(); @@ -405,37 +421,40 @@ void TypeChecker::check(const ScopePtr& scope, const AstStat& program) // HACK: We want to run typechecking on the contents of the AstStatError, but // we don't think the type errors will be useful most of the time. currentModule->errors.resize(oldSize); + + return ControlFlow::None; } else ice("Unknown AstStat"); - - if (finishTime && TimeTrace::getClock() > *finishTime) - throw TimeLimitError(iceHandler->moduleName); } // This particular overload is for do...end. If you need to not increase the scope level, use checkBlock directly. -void TypeChecker::check(const ScopePtr& scope, const AstStatBlock& block) +ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatBlock& block) { ScopePtr child = childScope(scope, block.location); - checkBlock(child, block); + + ControlFlow flow = checkBlock(child, block); + scope->inheritRefinements(child); + + return flow; } -void TypeChecker::checkBlock(const ScopePtr& scope, const AstStatBlock& block) +ControlFlow TypeChecker::checkBlock(const ScopePtr& scope, const AstStatBlock& block) { RecursionCounter _rc(&checkRecursionCount); if (FInt::LuauCheckRecursionLimit > 0 && checkRecursionCount >= FInt::LuauCheckRecursionLimit) { reportErrorCodeTooComplex(block.location); - return; + return ControlFlow::None; } try { - checkBlockWithoutRecursionCheck(scope, block); + return checkBlockWithoutRecursionCheck(scope, block); } catch (const RecursionLimitException&) { reportErrorCodeTooComplex(block.location); - return; + return ControlFlow::None; } } @@ -488,7 +507,7 @@ struct InplaceDemoter : TypeOnceVisitor } }; -void TypeChecker::checkBlockWithoutRecursionCheck(const ScopePtr& scope, const AstStatBlock& block) +ControlFlow TypeChecker::checkBlockWithoutRecursionCheck(const ScopePtr& scope, const AstStatBlock& block) { int subLevel = 0; @@ -528,6 +547,7 @@ void TypeChecker::checkBlockWithoutRecursionCheck(const ScopePtr& scope, const A } }; + std::optional firstFlow; while (protoIter != sorted.end()) { // protoIter walks forward @@ -570,7 +590,9 @@ void TypeChecker::checkBlockWithoutRecursionCheck(const ScopePtr& scope, const A // We do check the current element, so advance checkIter beyond it. ++checkIter; - check(scope, **protoIter); + ControlFlow flow = check(scope, **protoIter); + if (flow != ControlFlow::None && !firstFlow) + firstFlow = flow; } else if (auto fun = (*protoIter)->as()) { @@ -631,7 +653,11 @@ void TypeChecker::checkBlockWithoutRecursionCheck(const ScopePtr& scope, const A scope->bindings[fun->name] = {funTy, fun->name->location}; } else - check(scope, **protoIter); + { + ControlFlow flow = check(scope, **protoIter); + if (flow != ControlFlow::None && !firstFlow) + firstFlow = flow; + } ++protoIter; } @@ -643,6 +669,8 @@ void TypeChecker::checkBlockWithoutRecursionCheck(const ScopePtr& scope, const A } checkBlockTypeAliases(scope, sorted); + + return firstFlow.value_or(ControlFlow::None); } LUAU_NOINLINE void TypeChecker::checkBlockTypeAliases(const ScopePtr& scope, std::vector& sorted) @@ -717,19 +745,45 @@ static std::optional tryGetTypeGuardPredicate(const AstExprBinary& ex return predicate; } -void TypeChecker::check(const ScopePtr& scope, const AstStatIf& statement) +ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatIf& statement) { WithPredicate result = checkExpr(scope, *statement.condition); - ScopePtr ifScope = childScope(scope, statement.thenbody->location); - resolve(result.predicates, ifScope, true); - check(ifScope, *statement.thenbody); + ScopePtr thenScope = childScope(scope, statement.thenbody->location); + resolve(result.predicates, thenScope, true); - if (statement.elsebody) + if (FFlag::LuauTinyControlFlowAnalysis) { - ScopePtr elseScope = childScope(scope, statement.elsebody->location); + ScopePtr elseScope = childScope(scope, statement.elsebody ? statement.elsebody->location : statement.location); resolve(result.predicates, elseScope, false); - check(elseScope, *statement.elsebody); + + ControlFlow thencf = check(thenScope, *statement.thenbody); + ControlFlow elsecf = ControlFlow::None; + if (statement.elsebody) + elsecf = check(elseScope, *statement.elsebody); + + if (matches(thencf, ControlFlow::Returns | ControlFlow::Throws) && elsecf == ControlFlow::None) + scope->inheritRefinements(elseScope); + else if (thencf == ControlFlow::None && matches(elsecf, ControlFlow::Returns | ControlFlow::Throws)) + scope->inheritRefinements(thenScope); + + if (matches(thencf, ControlFlow::Returns | ControlFlow::Throws) && matches(elsecf, ControlFlow::Returns | ControlFlow::Throws)) + return ControlFlow::Returns; + else + return ControlFlow::None; + } + else + { + check(thenScope, *statement.thenbody); + + if (statement.elsebody) + { + ScopePtr elseScope = childScope(scope, statement.elsebody->location); + resolve(result.predicates, elseScope, false); + check(elseScope, *statement.elsebody); + } + + return ControlFlow::None; } } @@ -750,22 +804,26 @@ ErrorVec TypeChecker::canUnify(TypePackId subTy, TypePackId superTy, const Scope return canUnify_(subTy, superTy, scope, location); } -void TypeChecker::check(const ScopePtr& scope, const AstStatWhile& statement) +ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatWhile& statement) { WithPredicate result = checkExpr(scope, *statement.condition); ScopePtr whileScope = childScope(scope, statement.body->location); resolve(result.predicates, whileScope, true); check(whileScope, *statement.body); + + return ControlFlow::None; } -void TypeChecker::check(const ScopePtr& scope, const AstStatRepeat& statement) +ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatRepeat& statement) { ScopePtr repScope = childScope(scope, statement.location); checkBlock(repScope, *statement.body); checkExpr(repScope, *statement.condition); + + return ControlFlow::None; } struct Demoter : Substitution @@ -822,7 +880,7 @@ struct Demoter : Substitution } }; -void TypeChecker::check(const ScopePtr& scope, const AstStatReturn& return_) +ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatReturn& return_) { std::vector> expectedTypes; expectedTypes.reserve(return_.list.size); @@ -858,10 +916,12 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatReturn& return_) if (!errors.empty()) currentModule->getModuleScope()->returnType = addTypePack({anyType}); - return; + return FFlag::LuauTinyControlFlowAnalysis ? ControlFlow::Returns : ControlFlow::None; } unify(retPack, scope->returnType, scope, return_.location, CountMismatch::Context::Return); + + return FFlag::LuauTinyControlFlowAnalysis ? ControlFlow::Returns : ControlFlow::None; } template @@ -893,7 +953,7 @@ ErrorVec TypeChecker::tryUnify(TypePackId subTy, TypePackId superTy, const Scope return tryUnify_(subTy, superTy, scope, location); } -void TypeChecker::check(const ScopePtr& scope, const AstStatAssign& assign) +ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatAssign& assign) { std::vector> expectedTypes; expectedTypes.reserve(assign.vars.size); @@ -993,9 +1053,11 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatAssign& assign) } } } + + return ControlFlow::None; } -void TypeChecker::check(const ScopePtr& scope, const AstStatCompoundAssign& assign) +ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatCompoundAssign& assign) { AstExprBinary expr(assign.location, assign.op, assign.var, assign.value); @@ -1005,9 +1067,11 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatCompoundAssign& assi TypeId result = checkBinaryOperation(scope, expr, left, right); unify(result, left, scope, assign.location); + + return ControlFlow::None; } -void TypeChecker::check(const ScopePtr& scope, const AstStatLocal& local) +ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatLocal& local) { // Important subtlety: A local variable is not in scope while its initializer is being evaluated. // For instance, you cannot do this: @@ -1144,9 +1208,11 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatLocal& local) for (const auto& [local, binding] : varBindings) scope->bindings[local] = binding; + + return ControlFlow::None; } -void TypeChecker::check(const ScopePtr& scope, const AstStatFor& expr) +ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatFor& expr) { ScopePtr loopScope = childScope(scope, expr.location); @@ -1169,9 +1235,11 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatFor& expr) unify(checkExpr(loopScope, *expr.step).type, loopVarType, scope, expr.step->location); check(loopScope, *expr.body); + + return ControlFlow::None; } -void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) +ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) { ScopePtr loopScope = childScope(scope, forin.location); @@ -1360,9 +1428,11 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) unify(retPack, varPack, scope, forin.location); check(loopScope, *forin.body); + + return ControlFlow::None; } -void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funScope, const AstStatFunction& function) +ControlFlow TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funScope, const AstStatFunction& function) { if (auto exprName = function.name->as()) { @@ -1387,8 +1457,6 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco globalBindings[name] = oldBinding; else globalBindings[name] = {quantify(funScope, ty, exprName->location), exprName->location}; - - return; } else if (auto name = function.name->as()) { @@ -1397,7 +1465,6 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco checkFunctionBody(funScope, ty, *function.func); scope->bindings[name->local] = {anyIfNonstrict(quantify(funScope, ty, name->local->location)), name->local->location}; - return; } else if (auto name = function.name->as()) { @@ -1444,9 +1511,11 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco checkFunctionBody(funScope, ty, *function.func); } + + return ControlFlow::None; } -void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funScope, const AstStatLocalFunction& function) +ControlFlow TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funScope, const AstStatLocalFunction& function) { Name name = function.name->name.value; @@ -1455,15 +1524,17 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco checkFunctionBody(funScope, ty, *function.func); scope->bindings[function.name] = {quantify(funScope, ty, function.name->location), function.name->location}; + + return ControlFlow::None; } -void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias) +ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias) { Name name = typealias.name.value; // If the alias is missing a name, we can't do anything with it. Ignore it. if (name == kParseNameError) - return; + return ControlFlow::None; std::optional binding; if (auto it = scope->exportedTypeBindings.find(name); it != scope->exportedTypeBindings.end()) @@ -1476,7 +1547,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias // If the first pass failed (this should mean a duplicate definition), the second pass isn't going to be // interesting. if (duplicateTypeAliases.find({typealias.exported, name})) - return; + return ControlFlow::None; // By now this alias must have been `prototype()`d first. if (!binding) @@ -1557,6 +1628,8 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias if (unify(ty, bindingType, aliasScope, typealias.location)) bindingType = ty; + + return ControlFlow::None; } void TypeChecker::prototype(const ScopePtr& scope, const AstStatTypeAlias& typealias, int subLevel) @@ -1648,13 +1721,13 @@ void TypeChecker::prototype(const ScopePtr& scope, const AstStatDeclareClass& de scope->exportedTypeBindings[className] = TypeFun{{}, classTy}; } -void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& declaredClass) +ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& declaredClass) { Name className(declaredClass.name.value); // Don't bother checking if the class definition was incorrect if (incorrectClassDefinitions.find(&declaredClass)) - return; + return ControlFlow::None; std::optional binding; if (auto it = scope->exportedTypeBindings.find(className); it != scope->exportedTypeBindings.end()) @@ -1721,9 +1794,11 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& declar } } } + + return ControlFlow::None; } -void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareFunction& global) +ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatDeclareFunction& global) { ScopePtr funScope = childFunctionScope(scope, global.location); @@ -1754,6 +1829,8 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareFunction& glo currentModule->declaredGlobals[fnName] = fnType; currentModule->getModuleScope()->bindings[global.name] = Binding{fnType, global.location}; + + return ControlFlow::None; } WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExpr& expr, std::optional expectedType, bool forceSingleton) @@ -2785,6 +2862,16 @@ TypeId TypeChecker::checkRelationalOperation( if (notNever) { LUAU_ASSERT(oty); + + if (FFlag::LuauReducingAndOr) + { + // Perform a limited form of type reduction for booleans + if (isPrim(*oty, PrimitiveType::Boolean) && get(get(follow(rhsType)))) + return booleanType; + if (isPrim(rhsType, PrimitiveType::Boolean) && get(get(follow(*oty)))) + return booleanType; + } + return unionOfTypes(*oty, rhsType, scope, expr.location, false); } else @@ -2808,6 +2895,16 @@ TypeId TypeChecker::checkRelationalOperation( if (notNever) { LUAU_ASSERT(oty); + + if (FFlag::LuauReducingAndOr) + { + // Perform a limited form of type reduction for booleans + if (isPrim(*oty, PrimitiveType::Boolean) && get(get(follow(rhsType)))) + return booleanType; + if (isPrim(rhsType, PrimitiveType::Boolean) && get(get(follow(*oty)))) + return booleanType; + } + return unionOfTypes(*oty, rhsType, scope, expr.location); } else diff --git a/Analysis/src/Unifiable.cpp b/Analysis/src/Unifiable.cpp index 9db8f7f00..dcb2d3673 100644 --- a/Analysis/src/Unifiable.cpp +++ b/Analysis/src/Unifiable.cpp @@ -8,6 +8,11 @@ namespace Unifiable static int nextIndex = 0; +int freshIndex() +{ + return ++nextIndex; +} + Free::Free(TypeLevel level) : index(++nextIndex) , level(level) diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index b53401dce..9f30d11ba 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -20,9 +20,11 @@ LUAU_FASTFLAG(LuauErrorRecoveryType) LUAU_FASTFLAGVARIABLE(LuauInstantiateInSubtyping, false) LUAU_FASTFLAGVARIABLE(LuauUninhabitedSubAnything2, false) LUAU_FASTFLAGVARIABLE(LuauMaintainScopesInUnifier, false) +LUAU_FASTFLAGVARIABLE(LuauTransitiveSubtyping, false) LUAU_FASTFLAGVARIABLE(LuauTinyUnifyNormalsFix, false) LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) +LUAU_FASTFLAG(LuauNormalizeBlockedTypes) LUAU_FASTFLAG(LuauNegatedFunctionTypes) LUAU_FASTFLAG(LuauNegatedClassTypes) LUAU_FASTFLAG(LuauNegatedTableTypes) @@ -475,16 +477,27 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool if (log.get(superTy)) return tryUnifyWithAny(subTy, builtinTypes->anyType); - if (log.get(superTy)) + if (!FFlag::LuauTransitiveSubtyping && log.get(superTy)) return tryUnifyWithAny(subTy, builtinTypes->errorType); - if (log.get(superTy)) + if (!FFlag::LuauTransitiveSubtyping && log.get(superTy)) return tryUnifyWithAny(subTy, builtinTypes->unknownType); if (log.get(subTy)) + { + if (FFlag::LuauTransitiveSubtyping && normalize) + { + // TODO: there are probably cheaper ways to check if any <: T. + const NormalizedType* superNorm = normalizer->normalize(superTy); + if (!log.get(superNorm->tops)) + failure = true; + } + else + failure = true; return tryUnifyWithAny(superTy, builtinTypes->anyType); + } - if (log.get(subTy)) + if (!FFlag::LuauTransitiveSubtyping && log.get(subTy)) return tryUnifyWithAny(superTy, builtinTypes->errorType); if (log.get(subTy)) @@ -539,6 +552,35 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool { tryUnifyIntersectionWithType(subTy, uv, superTy, cacheEnabled, isFunctionCall); } + else if (FFlag::LuauTransitiveSubtyping && log.get(subTy)) + { + tryUnifyWithAny(superTy, builtinTypes->unknownType); + failure = true; + } + else if (FFlag::LuauTransitiveSubtyping && log.get(subTy) && log.get(superTy)) + { + // error <: error + } + else if (FFlag::LuauTransitiveSubtyping && log.get(superTy)) + { + tryUnifyWithAny(subTy, builtinTypes->errorType); + failure = true; + } + else if (FFlag::LuauTransitiveSubtyping && log.get(subTy)) + { + tryUnifyWithAny(superTy, builtinTypes->errorType); + failure = true; + } + else if (FFlag::LuauTransitiveSubtyping && log.get(superTy)) + { + // At this point, all the supertypes of `error` have been handled, + // and if `error unknownType); + } + else if (FFlag::LuauTransitiveSubtyping && log.get(superTy)) + { + tryUnifyWithAny(subTy, builtinTypes->unknownType); + } else if (log.getMutable(superTy) && log.getMutable(subTy)) tryUnifyPrimitives(subTy, superTy); @@ -611,6 +653,7 @@ void Unifier::tryUnifyUnionWithType(TypeId subTy, const UnionType* subUnion, Typ { // A | B <: T if and only if A <: T and B <: T bool failed = false; + bool errorsSuppressed = true; std::optional unificationTooComplex; std::optional firstFailedOption; @@ -626,13 +669,17 @@ void Unifier::tryUnifyUnionWithType(TypeId subTy, const UnionType* subUnion, Typ if (auto e = hasUnificationTooComplex(innerState.errors)) unificationTooComplex = e; - else if (!innerState.errors.empty()) + else if (FFlag::LuauTransitiveSubtyping ? innerState.failure : !innerState.errors.empty()) { + // If errors were suppressed, we store the log up, so we can commit it if no other option succeeds. + if (FFlag::LuauTransitiveSubtyping && innerState.errors.empty()) + logs.push_back(std::move(innerState.log)); // 'nil' option is skipped from extended report because we present the type in a special way - 'T?' - if (!firstFailedOption && !isNil(type)) + else if (!firstFailedOption && !isNil(type)) firstFailedOption = {innerState.errors.front()}; failed = true; + errorsSuppressed &= innerState.errors.empty(); } } @@ -684,12 +731,13 @@ void Unifier::tryUnifyUnionWithType(TypeId subTy, const UnionType* subUnion, Typ { if (firstFailedOption) reportError(location, TypeMismatch{superTy, subTy, "Not all union options are compatible.", *firstFailedOption, mismatchContext()}); - else + else if (!FFlag::LuauTransitiveSubtyping || !errorsSuppressed) reportError(location, TypeMismatch{superTy, subTy, mismatchContext()}); + failure = true; } } -struct BlockedTypeFinder : TypeOnceVisitor +struct DEPRECATED_BlockedTypeFinder : TypeOnceVisitor { std::unordered_set blockedTypes; @@ -700,9 +748,10 @@ struct BlockedTypeFinder : TypeOnceVisitor } }; -bool Unifier::blockOnBlockedTypes(TypeId subTy, TypeId superTy) +bool Unifier::DEPRECATED_blockOnBlockedTypes(TypeId subTy, TypeId superTy) { - BlockedTypeFinder blockedTypeFinder; + LUAU_ASSERT(!FFlag::LuauNormalizeBlockedTypes); + DEPRECATED_BlockedTypeFinder blockedTypeFinder; blockedTypeFinder.traverse(subTy); blockedTypeFinder.traverse(superTy); if (!blockedTypeFinder.blockedTypes.empty()) @@ -718,6 +767,7 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp { // T <: A | B if T <: A or T <: B bool found = false; + bool errorsSuppressed = false; std::optional unificationTooComplex; size_t failedOptionCount = 0; @@ -754,6 +804,21 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp } } + if (FFlag::LuauTransitiveSubtyping && !foundHeuristic) + { + for (size_t i = 0; i < uv->options.size(); ++i) + { + TypeId type = uv->options[i]; + + if (subTy == type) + { + foundHeuristic = true; + startIndex = i; + break; + } + } + } + if (!foundHeuristic && cacheEnabled) { auto& cache = sharedState.cachedUnify; @@ -779,7 +844,7 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp innerState.normalize = false; innerState.tryUnify_(subTy, type, isFunctionCall); - if (innerState.errors.empty()) + if (FFlag::LuauTransitiveSubtyping ? !innerState.failure : innerState.errors.empty()) { found = true; if (FFlag::DebugLuauDeferredConstraintResolution) @@ -790,6 +855,10 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp break; } } + else if (FFlag::LuauTransitiveSubtyping && innerState.errors.empty()) + { + errorsSuppressed = true; + } else if (auto e = hasUnificationTooComplex(innerState.errors)) { unificationTooComplex = e; @@ -810,11 +879,32 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp { reportError(*unificationTooComplex); } + else if (FFlag::LuauTransitiveSubtyping && !found && normalize) + { + // It is possible that T <: A | B even though T normalize(subTy); + const NormalizedType* superNorm = normalizer->normalize(superTy); + Unifier innerState = makeChildUnifier(); + if (!subNorm || !superNorm) + return reportError(location, UnificationTooComplex{}); + else if ((failedOptionCount == 1 || foundHeuristic) && failedOption) + innerState.tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "None of the union options are compatible. For example:", *failedOption); + else + innerState.tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the union options are compatible"); + if (!innerState.failure) + log.concat(std::move(innerState.log)); + else if (errorsSuppressed || innerState.errors.empty()) + failure = true; + else + reportError(std::move(innerState.errors.front())); + } else if (!found && normalize) { // We cannot normalize a type that contains blocked types. We have to // stop for now if we find any. - if (blockOnBlockedTypes(subTy, superTy)) + if (!FFlag::LuauNormalizeBlockedTypes && DEPRECATED_blockOnBlockedTypes(subTy, superTy)) return; // It is possible that T <: A | B even though T unificationTooComplex; size_t startIndex = 0; @@ -919,7 +1013,7 @@ void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionType* { // We cannot normalize a type that contains blocked types. We have to // stop for now if we find any. - if (blockOnBlockedTypes(subTy, superTy)) + if (!FFlag::LuauNormalizeBlockedTypes && DEPRECATED_blockOnBlockedTypes(subTy, superTy)) return; // Sometimes a negation type is inside one of the types, e.g. { p: number } & { p: ~number }. @@ -951,13 +1045,18 @@ void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionType* innerState.normalize = false; innerState.tryUnify_(type, superTy, isFunctionCall); + // TODO: This sets errorSuppressed to true if any of the parts is error-suppressing, + // in paricular any & T is error-suppressing. Really, errorSuppressed should be true if + // all of the parts are error-suppressing, but that fails to typecheck lua-apps. if (innerState.errors.empty()) { found = true; - if (FFlag::DebugLuauDeferredConstraintResolution) + errorsSuppressed = innerState.failure; + if (FFlag::DebugLuauDeferredConstraintResolution || (FFlag::LuauTransitiveSubtyping && innerState.failure)) logs.push_back(std::move(innerState.log)); else { + errorsSuppressed = false; log.concat(std::move(innerState.log)); break; } @@ -970,6 +1069,8 @@ void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionType* if (FFlag::DebugLuauDeferredConstraintResolution) log.concat(combineLogsIntoIntersection(std::move(logs))); + else if (FFlag::LuauTransitiveSubtyping && errorsSuppressed) + log.concat(std::move(logs.front())); if (unificationTooComplex) reportError(*unificationTooComplex); @@ -977,7 +1078,7 @@ void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionType* { // We cannot normalize a type that contains blocked types. We have to // stop for now if we find any. - if (blockOnBlockedTypes(subTy, superTy)) + if (!FFlag::LuauNormalizeBlockedTypes && DEPRECATED_blockOnBlockedTypes(subTy, superTy)) return; // It is possible that A & B <: T even though A error) { - if (get(superNorm.tops) || get(superNorm.tops) || get(subNorm.tops)) + if (!FFlag::LuauTransitiveSubtyping && get(superNorm.tops)) return; - else if (get(subNorm.tops)) + else if (get(superNorm.tops)) + return; + else if (get(subNorm.tops)) + { + failure = true; + return; + } + else if (!FFlag::LuauTransitiveSubtyping && get(subNorm.tops)) return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); if (get(subNorm.errors)) if (!get(superNorm.errors)) - return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); + { + failure = true; + if (!FFlag::LuauTransitiveSubtyping) + reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); + return; + } + + if (FFlag::LuauTransitiveSubtyping && get(superNorm.tops)) + return; + + if (FFlag::LuauTransitiveSubtyping && get(subNorm.tops)) + return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); if (get(subNorm.booleans)) { @@ -1911,6 +2032,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) if (innerState.errors.empty()) log.concat(std::move(innerState.log)); + failure |= innerState.failure; } else if (subTable->indexer && maybeString(subTable->indexer->indexType)) { @@ -1926,6 +2048,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) if (innerState.errors.empty()) log.concat(std::move(innerState.log)); + failure |= innerState.failure; } else if (subTable->state == TableState::Unsealed && isOptional(prop.type)) // This is sound because unsealed table types are precise, so `{ p : T } <: { p : T, q : U? }` @@ -1988,6 +2111,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) if (innerState.errors.empty()) log.concat(std::move(innerState.log)); + failure |= innerState.failure; } else if (superTable->state == TableState::Unsealed) { @@ -2059,6 +2183,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) if (innerState.errors.empty()) log.concat(std::move(innerState.log)); + failure |= innerState.failure; } else if (superTable->indexer) { @@ -2234,6 +2359,7 @@ void Unifier::tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed) location, TypeMismatch{reversed ? subTy : superTy, reversed ? superTy : subTy, "", innerState.errors.front(), mismatchContext()}); log.concat(std::move(innerState.log)); + failure |= innerState.failure; } else if (TableType* subTable = log.getMutable(subTy)) { @@ -2274,6 +2400,7 @@ void Unifier::tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed) { log.concat(std::move(innerState.log)); log.bindTable(subTy, superTy); + failure |= innerState.failure; } } else @@ -2367,6 +2494,7 @@ void Unifier::tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed) if (innerState.errors.empty()) { log.concat(std::move(innerState.log)); + failure |= innerState.failure; } else { @@ -2398,7 +2526,7 @@ void Unifier::tryUnifyNegations(TypeId subTy, TypeId superTy) // We cannot normalize a type that contains blocked types. We have to // stop for now if we find any. - if (blockOnBlockedTypes(subTy, superTy)) + if (!FFlag::LuauNormalizeBlockedTypes && DEPRECATED_blockOnBlockedTypes(subTy, superTy)) return; const NormalizedType* subNorm = normalizer->normalize(subTy); @@ -2726,6 +2854,7 @@ Unifier Unifier::makeChildUnifier() void Unifier::reportError(Location location, TypeErrorData data) { errors.emplace_back(std::move(location), std::move(data)); + failure = true; } // A utility function that appends the given error to the unifier's error log. @@ -2736,6 +2865,7 @@ void Unifier::reportError(Location location, TypeErrorData data) void Unifier::reportError(TypeError err) { errors.push_back(std::move(err)); + failure = true; } diff --git a/Ast/include/Luau/Parser.h b/Ast/include/Luau/Parser.h index 8b7eb73cf..0b9d8c467 100644 --- a/Ast/include/Luau/Parser.h +++ b/Ast/include/Luau/Parser.h @@ -123,13 +123,13 @@ class Parser // return [explist] AstStat* parseReturn(); - // type Name `=' typeannotation + // type Name `=' Type AstStat* parseTypeAlias(const Location& start, bool exported); AstDeclaredClassProp parseDeclaredClassMethod(); - // `declare global' Name: typeannotation | - // `declare function' Name`(' [parlist] `)' [`:` TypeAnnotation] + // `declare global' Name: Type | + // `declare function' Name`(' [parlist] `)' [`:` Type] AstStat* parseDeclaration(const Location& start); // varlist `=' explist @@ -140,7 +140,7 @@ class Parser std::pair> prepareFunctionArguments(const Location& start, bool hasself, const TempVector& args); - // funcbodyhead ::= `(' [namelist [`,' `...'] | `...'] `)' [`:` TypeAnnotation] + // funcbodyhead ::= `(' [namelist [`,' `...'] | `...'] `)' [`:` Type] // funcbody ::= funcbodyhead block end std::pair parseFunctionBody( bool hasself, const Lexeme& matchFunction, const AstName& debugname, const Name* localName); @@ -148,21 +148,21 @@ class Parser // explist ::= {exp `,'} exp void parseExprList(TempVector& result); - // binding ::= Name [`:` TypeAnnotation] + // binding ::= Name [`:` Type] Binding parseBinding(); // bindinglist ::= (binding | `...') {`,' bindinglist} // Returns the location of the vararg ..., or std::nullopt if the function is not vararg. std::tuple parseBindingList(TempVector& result, bool allowDot3 = false); - AstType* parseOptionalTypeAnnotation(); + AstType* parseOptionalType(); - // TypeList ::= TypeAnnotation [`,' TypeList] - // ReturnType ::= TypeAnnotation | `(' TypeList `)' - // TableProp ::= Name `:' TypeAnnotation - // TableIndexer ::= `[' TypeAnnotation `]' `:' TypeAnnotation + // TypeList ::= Type [`,' TypeList] + // ReturnType ::= Type | `(' TypeList `)' + // TableProp ::= Name `:' Type + // TableIndexer ::= `[' Type `]' `:' Type // PropList ::= (TableProp | TableIndexer) [`,' PropList] - // TypeAnnotation + // Type // ::= Name // | `nil` // | `{' [PropList] `}' @@ -171,24 +171,25 @@ class Parser // Returns the variadic annotation, if it exists. AstTypePack* parseTypeList(TempVector& result, TempVector>& resultNames); - std::optional parseOptionalReturnTypeAnnotation(); - std::pair parseReturnTypeAnnotation(); + std::optional parseOptionalReturnType(); + std::pair parseReturnType(); - AstTableIndexer* parseTableIndexerAnnotation(); + AstTableIndexer* parseTableIndexer(); - AstTypeOrPack parseFunctionTypeAnnotation(bool allowPack); - AstType* parseFunctionTypeAnnotationTail(const Lexeme& begin, AstArray generics, AstArray genericPacks, - AstArray& params, AstArray>& paramNames, AstTypePack* varargAnnotation); + AstTypeOrPack parseFunctionType(bool allowPack); + AstType* parseFunctionTypeTail(const Lexeme& begin, AstArray generics, AstArray genericPacks, + AstArray params, AstArray> paramNames, AstTypePack* varargAnnotation); - AstType* parseTableTypeAnnotation(); - AstTypeOrPack parseSimpleTypeAnnotation(bool allowPack); + AstType* parseTableType(); + AstTypeOrPack parseSimpleType(bool allowPack); - AstTypeOrPack parseTypeOrPackAnnotation(); - AstType* parseTypeAnnotation(TempVector& parts, const Location& begin); - AstType* parseTypeAnnotation(); + AstTypeOrPack parseTypeOrPack(); + AstType* parseType(); - AstTypePack* parseTypePackAnnotation(); - AstTypePack* parseVariadicArgumentAnnotation(); + AstTypePack* parseTypePack(); + AstTypePack* parseVariadicArgumentTypePack(); + + AstType* parseTypeSuffix(AstType* type, const Location& begin); static std::optional parseUnaryOp(const Lexeme& l); static std::optional parseBinaryOp(const Lexeme& l); @@ -215,7 +216,7 @@ class Parser // primaryexp -> prefixexp { `.' NAME | `[' exp `]' | `:' NAME funcargs | funcargs } AstExpr* parsePrimaryExpr(bool asStatement); - // asexp -> simpleexp [`::' typeAnnotation] + // asexp -> simpleexp [`::' Type] AstExpr* parseAssertionExpr(); // simpleexp -> NUMBER | STRING | NIL | true | false | ... | constructor | FUNCTION body | primaryexp @@ -244,7 +245,7 @@ class Parser // `<' namelist `>' std::pair, AstArray> parseGenericTypeList(bool withDefaultValues); - // `<' typeAnnotation[, ...] `>' + // `<' Type[, ...] `>' AstArray parseTypeParams(); std::optional> parseCharArray(); @@ -302,13 +303,12 @@ class Parser AstStatError* reportStatError(const Location& location, const AstArray& expressions, const AstArray& statements, const char* format, ...) LUAU_PRINTF_ATTR(5, 6); AstExprError* reportExprError(const Location& location, const AstArray& expressions, const char* format, ...) LUAU_PRINTF_ATTR(4, 5); - AstTypeError* reportTypeAnnotationError(const Location& location, const AstArray& types, const char* format, ...) - LUAU_PRINTF_ATTR(4, 5); + AstTypeError* reportTypeError(const Location& location, const AstArray& types, const char* format, ...) LUAU_PRINTF_ATTR(4, 5); // `parseErrorLocation` is associated with the parser error // `astErrorLocation` is associated with the AstTypeError created // It can be useful to have different error locations so that the parse error can include the next lexeme, while the AstTypeError can precisely // define the location (possibly of zero size) where a type annotation is expected. - AstTypeError* reportMissingTypeAnnotationError(const Location& parseErrorLocation, const Location& astErrorLocation, const char* format, ...) + AstTypeError* reportMissingTypeError(const Location& parseErrorLocation, const Location& astErrorLocation, const char* format, ...) LUAU_PRINTF_ATTR(4, 5); AstExpr* reportFunctionArgsError(AstExpr* func, bool self); @@ -401,8 +401,8 @@ class Parser std::vector scratchBinding; std::vector scratchLocal; std::vector scratchTableTypeProps; - std::vector scratchAnnotation; - std::vector scratchTypeOrPackAnnotation; + std::vector scratchType; + std::vector scratchTypeOrPack; std::vector scratchDeclaredClassProps; std::vector scratchItem; std::vector scratchArgName; diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 4c347712f..40fa754e6 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -130,7 +130,7 @@ void TempVector::push_back(const T& item) size_++; } -static bool shouldParseTypePackAnnotation(Lexer& lexer) +static bool shouldParseTypePack(Lexer& lexer) { if (lexer.current().type == Lexeme::Dot3) return true; @@ -330,11 +330,12 @@ AstStat* Parser::parseStat() if (options.allowTypeAnnotations) { if (ident == "type") - return parseTypeAlias(expr->location, /* exported =*/false); + return parseTypeAlias(expr->location, /* exported= */ false); + if (ident == "export" && lexer.current().type == Lexeme::Name && AstName(lexer.current().name) == "type") { nextLexeme(); - return parseTypeAlias(expr->location, /* exported =*/true); + return parseTypeAlias(expr->location, /* exported= */ true); } } @@ -742,7 +743,7 @@ AstStat* Parser::parseReturn() return allocator.alloc(Location(start, end), copy(list)); } -// type Name [`<' varlist `>'] `=' typeannotation +// type Name [`<' varlist `>'] `=' Type AstStat* Parser::parseTypeAlias(const Location& start, bool exported) { // note: `type` token is already parsed for us, so we just need to parse the rest @@ -757,7 +758,7 @@ AstStat* Parser::parseTypeAlias(const Location& start, bool exported) expectAndConsume('=', "type alias"); - AstType* type = parseTypeAnnotation(); + AstType* type = parseType(); return allocator.alloc(Location(start, type->location), name->name, name->location, generics, genericPacks, type, exported); } @@ -789,16 +790,16 @@ AstDeclaredClassProp Parser::parseDeclaredClassMethod() expectMatchAndConsume(')', matchParen); - AstTypeList retTypes = parseOptionalReturnTypeAnnotation().value_or(AstTypeList{copy(nullptr, 0), nullptr}); + AstTypeList retTypes = parseOptionalReturnType().value_or(AstTypeList{copy(nullptr, 0), nullptr}); Location end = lexer.current().location; - TempVector vars(scratchAnnotation); + TempVector vars(scratchType); TempVector> varNames(scratchOptArgName); if (args.size() == 0 || args[0].name.name != "self" || args[0].annotation != nullptr) { return AstDeclaredClassProp{ - fnName.name, reportTypeAnnotationError(Location(start, end), {}, "'self' must be present as the unannotated first parameter"), true}; + fnName.name, reportTypeError(Location(start, end), {}, "'self' must be present as the unannotated first parameter"), true}; } // Skip the first index. @@ -809,7 +810,7 @@ AstDeclaredClassProp Parser::parseDeclaredClassMethod() if (args[i].annotation) vars.push_back(args[i].annotation); else - vars.push_back(reportTypeAnnotationError(Location(start, end), {}, "All declaration parameters aside from 'self' must be annotated")); + vars.push_back(reportTypeError(Location(start, end), {}, "All declaration parameters aside from 'self' must be annotated")); } if (vararg && !varargAnnotation) @@ -846,10 +847,10 @@ AstStat* Parser::parseDeclaration(const Location& start) expectMatchAndConsume(')', matchParen); - AstTypeList retTypes = parseOptionalReturnTypeAnnotation().value_or(AstTypeList{copy(nullptr, 0)}); + AstTypeList retTypes = parseOptionalReturnType().value_or(AstTypeList{copy(nullptr, 0)}); Location end = lexer.current().location; - TempVector vars(scratchAnnotation); + TempVector vars(scratchType); TempVector varNames(scratchArgName); for (size_t i = 0; i < args.size(); ++i) @@ -898,7 +899,7 @@ AstStat* Parser::parseDeclaration(const Location& start) expectMatchAndConsume(']', begin); expectAndConsume(':', "property type annotation"); - AstType* type = parseTypeAnnotation(); + AstType* type = parseType(); // TODO: since AstName conains a char*, it can't contain null bool containsNull = chars && (strnlen(chars->data, chars->size) < chars->size); @@ -912,7 +913,7 @@ AstStat* Parser::parseDeclaration(const Location& start) { Name propName = parseName("property name"); expectAndConsume(':', "property type annotation"); - AstType* propType = parseTypeAnnotation(); + AstType* propType = parseType(); props.push_back(AstDeclaredClassProp{propName.name, propType, false}); } } @@ -926,7 +927,7 @@ AstStat* Parser::parseDeclaration(const Location& start) { expectAndConsume(':', "global variable declaration"); - AstType* type = parseTypeAnnotation(); + AstType* type = parseType(); return allocator.alloc(Location(start, type->location), globalName->name, type); } else @@ -1027,7 +1028,7 @@ std::pair Parser::parseFunctionBody( expectMatchAndConsume(')', matchParen, true); - std::optional typelist = parseOptionalReturnTypeAnnotation(); + std::optional typelist = parseOptionalReturnType(); AstLocal* funLocal = nullptr; @@ -1085,7 +1086,7 @@ Parser::Binding Parser::parseBinding() if (!name) name = Name(nameError, lexer.current().location); - AstType* annotation = parseOptionalTypeAnnotation(); + AstType* annotation = parseOptionalType(); return Binding(*name, annotation); } @@ -1104,7 +1105,7 @@ std::tuple Parser::parseBindingList(TempVector Parser::parseBindingList(TempVector& result, TempVector>& resultNames) { while (true) { - if (shouldParseTypePackAnnotation(lexer)) - return parseTypePackAnnotation(); + if (shouldParseTypePack(lexer)) + return parseTypePack(); if (lexer.current().type == Lexeme::Name && lexer.lookahead().type == ':') { @@ -1156,7 +1157,7 @@ AstTypePack* Parser::parseTypeList(TempVector& result, TempVector& result, TempVector Parser::parseOptionalReturnTypeAnnotation() +std::optional Parser::parseOptionalReturnType() { if (options.allowTypeAnnotations && (lexer.current().type == ':' || lexer.current().type == Lexeme::SkinnyArrow)) { @@ -1183,7 +1184,7 @@ std::optional Parser::parseOptionalReturnTypeAnnotation() unsigned int oldRecursionCount = recursionCounter; - auto [_location, result] = parseReturnTypeAnnotation(); + auto [_location, result] = parseReturnType(); // At this point, if we find a , character, it indicates that there are multiple return types // in this type annotation, but the list wasn't wrapped in parentheses. @@ -1202,27 +1203,27 @@ std::optional Parser::parseOptionalReturnTypeAnnotation() return std::nullopt; } -// ReturnType ::= TypeAnnotation | `(' TypeList `)' -std::pair Parser::parseReturnTypeAnnotation() +// ReturnType ::= Type | `(' TypeList `)' +std::pair Parser::parseReturnType() { incrementRecursionCounter("type annotation"); - TempVector result(scratchAnnotation); - TempVector> resultNames(scratchOptArgName); - AstTypePack* varargAnnotation = nullptr; - Lexeme begin = lexer.current(); if (lexer.current().type != '(') { - if (shouldParseTypePackAnnotation(lexer)) - varargAnnotation = parseTypePackAnnotation(); - else - result.push_back(parseTypeAnnotation()); + if (shouldParseTypePack(lexer)) + { + AstTypePack* typePack = parseTypePack(); - Location resultLocation = result.size() == 0 ? varargAnnotation->location : result[0]->location; + return {typePack->location, AstTypeList{{}, typePack}}; + } + else + { + AstType* type = parseType(); - return {resultLocation, AstTypeList{copy(result), varargAnnotation}}; + return {type->location, AstTypeList{copy(&type, 1), nullptr}}; + } } nextLexeme(); @@ -1231,6 +1232,10 @@ std::pair Parser::parseReturnTypeAnnotation() matchRecoveryStopOnToken[Lexeme::SkinnyArrow]++; + TempVector result(scratchType); + TempVector> resultNames(scratchOptArgName); + AstTypePack* varargAnnotation = nullptr; + // possibly () -> ReturnType if (lexer.current().type != ')') varargAnnotation = parseTypeList(result, resultNames); @@ -1246,9 +1251,9 @@ std::pair Parser::parseReturnTypeAnnotation() // If it turns out that it's just '(A)', it's possible that there are unions/intersections to follow, so fold over it. if (result.size() == 1) { - AstType* returnType = parseTypeAnnotation(result, innerBegin); + AstType* returnType = parseTypeSuffix(result[0], innerBegin); - // If parseTypeAnnotation parses nothing, then returnType->location.end only points at the last non-type-pack + // If parseType parses nothing, then returnType->location.end only points at the last non-type-pack // type to successfully parse. We need the span of the whole annotation. Position endPos = result.size() == 1 ? location.end : returnType->location.end; @@ -1258,39 +1263,33 @@ std::pair Parser::parseReturnTypeAnnotation() return {location, AstTypeList{copy(result), varargAnnotation}}; } - AstArray generics{nullptr, 0}; - AstArray genericPacks{nullptr, 0}; - AstArray types = copy(result); - AstArray> names = copy(resultNames); + AstType* tail = parseFunctionTypeTail(begin, {}, {}, copy(result), copy(resultNames), varargAnnotation); - TempVector fallbackReturnTypes(scratchAnnotation); - fallbackReturnTypes.push_back(parseFunctionTypeAnnotationTail(begin, generics, genericPacks, types, names, varargAnnotation)); - - return {Location{location, fallbackReturnTypes[0]->location}, AstTypeList{copy(fallbackReturnTypes), varargAnnotation}}; + return {Location{location, tail->location}, AstTypeList{copy(&tail, 1), varargAnnotation}}; } -// TableIndexer ::= `[' TypeAnnotation `]' `:' TypeAnnotation -AstTableIndexer* Parser::parseTableIndexerAnnotation() +// TableIndexer ::= `[' Type `]' `:' Type +AstTableIndexer* Parser::parseTableIndexer() { const Lexeme begin = lexer.current(); nextLexeme(); // [ - AstType* index = parseTypeAnnotation(); + AstType* index = parseType(); expectMatchAndConsume(']', begin); expectAndConsume(':', "table field"); - AstType* result = parseTypeAnnotation(); + AstType* result = parseType(); return allocator.alloc(AstTableIndexer{index, result, Location(begin.location, result->location)}); } -// TableProp ::= Name `:' TypeAnnotation +// TableProp ::= Name `:' Type // TablePropOrIndexer ::= TableProp | TableIndexer // PropList ::= TablePropOrIndexer {fieldsep TablePropOrIndexer} [fieldsep] -// TableTypeAnnotation ::= `{' PropList `}' -AstType* Parser::parseTableTypeAnnotation() +// TableType ::= `{' PropList `}' +AstType* Parser::parseTableType() { incrementRecursionCounter("type annotation"); @@ -1313,7 +1312,7 @@ AstType* Parser::parseTableTypeAnnotation() expectMatchAndConsume(']', begin); expectAndConsume(':', "table field"); - AstType* type = parseTypeAnnotation(); + AstType* type = parseType(); // TODO: since AstName conains a char*, it can't contain null bool containsNull = chars && (strnlen(chars->data, chars->size) < chars->size); @@ -1329,19 +1328,19 @@ AstType* Parser::parseTableTypeAnnotation() { // maybe we don't need to parse the entire badIndexer... // however, we either have { or [ to lint, not the entire table type or the bad indexer. - AstTableIndexer* badIndexer = parseTableIndexerAnnotation(); + AstTableIndexer* badIndexer = parseTableIndexer(); // we lose all additional indexer expressions from the AST after error recovery here report(badIndexer->location, "Cannot have more than one table indexer"); } else { - indexer = parseTableIndexerAnnotation(); + indexer = parseTableIndexer(); } } else if (props.empty() && !indexer && !(lexer.current().type == Lexeme::Name && lexer.lookahead().type == ':')) { - AstType* type = parseTypeAnnotation(); + AstType* type = parseType(); // array-like table type: {T} desugars into {[number]: T} AstType* index = allocator.alloc(type->location, std::nullopt, nameNumber); @@ -1358,7 +1357,7 @@ AstType* Parser::parseTableTypeAnnotation() expectAndConsume(':', "table field"); - AstType* type = parseTypeAnnotation(); + AstType* type = parseType(); props.push_back({name->name, name->location, type}); } @@ -1382,9 +1381,9 @@ AstType* Parser::parseTableTypeAnnotation() return allocator.alloc(Location(start, end), copy(props), indexer); } -// ReturnType ::= TypeAnnotation | `(' TypeList `)' -// FunctionTypeAnnotation ::= [`<' varlist `>'] `(' [TypeList] `)' `->` ReturnType -AstTypeOrPack Parser::parseFunctionTypeAnnotation(bool allowPack) +// ReturnType ::= Type | `(' TypeList `)' +// FunctionType ::= [`<' varlist `>'] `(' [TypeList] `)' `->` ReturnType +AstTypeOrPack Parser::parseFunctionType(bool allowPack) { incrementRecursionCounter("type annotation"); @@ -1400,7 +1399,7 @@ AstTypeOrPack Parser::parseFunctionTypeAnnotation(bool allowPack) matchRecoveryStopOnToken[Lexeme::SkinnyArrow]++; - TempVector params(scratchAnnotation); + TempVector params(scratchType); TempVector> names(scratchOptArgName); AstTypePack* varargAnnotation = nullptr; @@ -1432,12 +1431,11 @@ AstTypeOrPack Parser::parseFunctionTypeAnnotation(bool allowPack) AstArray> paramNames = copy(names); - return {parseFunctionTypeAnnotationTail(begin, generics, genericPacks, paramTypes, paramNames, varargAnnotation), {}}; + return {parseFunctionTypeTail(begin, generics, genericPacks, paramTypes, paramNames, varargAnnotation), {}}; } -AstType* Parser::parseFunctionTypeAnnotationTail(const Lexeme& begin, AstArray generics, AstArray genericPacks, - AstArray& params, AstArray>& paramNames, AstTypePack* varargAnnotation) - +AstType* Parser::parseFunctionTypeTail(const Lexeme& begin, AstArray generics, AstArray genericPacks, + AstArray params, AstArray> paramNames, AstTypePack* varargAnnotation) { incrementRecursionCounter("type annotation"); @@ -1458,21 +1456,22 @@ AstType* Parser::parseFunctionTypeAnnotationTail(const Lexeme& begin, AstArray(Location(begin.location, endLocation), generics, genericPacks, paramTypes, paramNames, returnTypeList); } -// typeannotation ::= +// Type ::= // nil | // Name[`.' Name] [`<' namelist `>'] | // `{' [PropList] `}' | // `(' [TypeList] `)' `->` ReturnType -// `typeof` typeannotation -AstType* Parser::parseTypeAnnotation(TempVector& parts, const Location& begin) +// `typeof` Type +AstType* Parser::parseTypeSuffix(AstType* type, const Location& begin) { - LUAU_ASSERT(!parts.empty()); + TempVector parts(scratchType); + parts.push_back(type); incrementRecursionCounter("type annotation"); @@ -1487,7 +1486,7 @@ AstType* Parser::parseTypeAnnotation(TempVector& parts, const Location if (c == '|') { nextLexeme(); - parts.push_back(parseSimpleTypeAnnotation(/* allowPack= */ false).type); + parts.push_back(parseSimpleType(/* allowPack= */ false).type); isUnion = true; } else if (c == '?') @@ -1500,7 +1499,7 @@ AstType* Parser::parseTypeAnnotation(TempVector& parts, const Location else if (c == '&') { nextLexeme(); - parts.push_back(parseSimpleTypeAnnotation(/* allowPack= */ false).type); + parts.push_back(parseSimpleType(/* allowPack= */ false).type); isIntersection = true; } else if (c == Lexeme::Dot3) @@ -1513,11 +1512,11 @@ AstType* Parser::parseTypeAnnotation(TempVector& parts, const Location } if (parts.size() == 1) - return parts[0]; + return type; if (isUnion && isIntersection) { - return reportTypeAnnotationError(Location(begin, parts.back()->location), copy(parts), + return reportTypeError(Location(begin, parts.back()->location), copy(parts), "Mixing union and intersection types is not allowed; consider wrapping in parentheses."); } @@ -1533,16 +1532,14 @@ AstType* Parser::parseTypeAnnotation(TempVector& parts, const Location ParseError::raise(begin, "Composite type was not an intersection or union."); } -AstTypeOrPack Parser::parseTypeOrPackAnnotation() +AstTypeOrPack Parser::parseTypeOrPack() { unsigned int oldRecursionCount = recursionCounter; incrementRecursionCounter("type annotation"); Location begin = lexer.current().location; - TempVector parts(scratchAnnotation); - - auto [type, typePack] = parseSimpleTypeAnnotation(/* allowPack= */ true); + auto [type, typePack] = parseSimpleType(/* allowPack= */ true); if (typePack) { @@ -1550,31 +1547,28 @@ AstTypeOrPack Parser::parseTypeOrPackAnnotation() return {{}, typePack}; } - parts.push_back(type); - recursionCounter = oldRecursionCount; - return {parseTypeAnnotation(parts, begin), {}}; + return {parseTypeSuffix(type, begin), {}}; } -AstType* Parser::parseTypeAnnotation() +AstType* Parser::parseType() { unsigned int oldRecursionCount = recursionCounter; incrementRecursionCounter("type annotation"); Location begin = lexer.current().location; - TempVector parts(scratchAnnotation); - parts.push_back(parseSimpleTypeAnnotation(/* allowPack= */ false).type); + AstType* type = parseSimpleType(/* allowPack= */ false).type; recursionCounter = oldRecursionCount; - return parseTypeAnnotation(parts, begin); + return parseTypeSuffix(type, begin); } -// typeannotation ::= nil | Name[`.' Name] [ `<' typeannotation [`,' ...] `>' ] | `typeof' `(' expr `)' | `{' [PropList] `}' +// Type ::= nil | Name[`.' Name] [ `<' Type [`,' ...] `>' ] | `typeof' `(' expr `)' | `{' [PropList] `}' // | [`<' varlist `>'] `(' [TypeList] `)' `->` ReturnType -AstTypeOrPack Parser::parseSimpleTypeAnnotation(bool allowPack) +AstTypeOrPack Parser::parseSimpleType(bool allowPack) { incrementRecursionCounter("type annotation"); @@ -1603,18 +1597,18 @@ AstTypeOrPack Parser::parseSimpleTypeAnnotation(bool allowPack) return {allocator.alloc(start, svalue)}; } else - return {reportTypeAnnotationError(start, {}, "String literal contains malformed escape sequence")}; + return {reportTypeError(start, {}, "String literal contains malformed escape sequence")}; } else if (lexer.current().type == Lexeme::InterpStringBegin || lexer.current().type == Lexeme::InterpStringSimple) { parseInterpString(); - return {reportTypeAnnotationError(start, {}, "Interpolated string literals cannot be used as types")}; + return {reportTypeError(start, {}, "Interpolated string literals cannot be used as types")}; } else if (lexer.current().type == Lexeme::BrokenString) { nextLexeme(); - return {reportTypeAnnotationError(start, {}, "Malformed string")}; + return {reportTypeError(start, {}, "Malformed string")}; } else if (lexer.current().type == Lexeme::Name) { @@ -1663,17 +1657,17 @@ AstTypeOrPack Parser::parseSimpleTypeAnnotation(bool allowPack) } else if (lexer.current().type == '{') { - return {parseTableTypeAnnotation(), {}}; + return {parseTableType(), {}}; } else if (lexer.current().type == '(' || lexer.current().type == '<') { - return parseFunctionTypeAnnotation(allowPack); + return parseFunctionType(allowPack); } else if (lexer.current().type == Lexeme::ReservedFunction) { nextLexeme(); - return {reportTypeAnnotationError(start, {}, + return {reportTypeError(start, {}, "Using 'function' as a type annotation is not supported, consider replacing with a function type annotation e.g. '(...any) -> " "...any'"), {}}; @@ -1685,12 +1679,11 @@ AstTypeOrPack Parser::parseSimpleTypeAnnotation(bool allowPack) // The parse error includes the next lexeme to make it easier to display where the error is (e.g. in an IDE or a CLI error message). // Including the current lexeme also makes the parse error consistent with other parse errors returned by Luau. Location parseErrorLocation(lexer.previousLocation().end, start.end); - return { - reportMissingTypeAnnotationError(parseErrorLocation, astErrorlocation, "Expected type, got %s", lexer.current().toString().c_str()), {}}; + return {reportMissingTypeError(parseErrorLocation, astErrorlocation, "Expected type, got %s", lexer.current().toString().c_str()), {}}; } } -AstTypePack* Parser::parseVariadicArgumentAnnotation() +AstTypePack* Parser::parseVariadicArgumentTypePack() { // Generic: a... if (lexer.current().type == Lexeme::Name && lexer.lookahead().type == Lexeme::Dot3) @@ -1705,19 +1698,19 @@ AstTypePack* Parser::parseVariadicArgumentAnnotation() // Variadic: T else { - AstType* variadicAnnotation = parseTypeAnnotation(); + AstType* variadicAnnotation = parseType(); return allocator.alloc(variadicAnnotation->location, variadicAnnotation); } } -AstTypePack* Parser::parseTypePackAnnotation() +AstTypePack* Parser::parseTypePack() { // Variadic: ...T if (lexer.current().type == Lexeme::Dot3) { Location start = lexer.current().location; nextLexeme(); - AstType* varargTy = parseTypeAnnotation(); + AstType* varargTy = parseType(); return allocator.alloc(Location(start, varargTy->location), varargTy); } // Generic: a... @@ -2054,7 +2047,7 @@ AstExpr* Parser::parsePrimaryExpr(bool asStatement) return expr; } -// asexp -> simpleexp [`::' typeannotation] +// asexp -> simpleexp [`::' Type] AstExpr* Parser::parseAssertionExpr() { Location start = lexer.current().location; @@ -2063,7 +2056,7 @@ AstExpr* Parser::parseAssertionExpr() if (options.allowTypeAnnotations && lexer.current().type == Lexeme::DoubleColon) { nextLexeme(); - AstType* annotation = parseTypeAnnotation(); + AstType* annotation = parseType(); return allocator.alloc(Location(start, annotation->location), expr, annotation); } else @@ -2455,15 +2448,15 @@ std::pair, AstArray> Parser::parseG Lexeme packBegin = lexer.current(); - if (shouldParseTypePackAnnotation(lexer)) + if (shouldParseTypePack(lexer)) { - AstTypePack* typePack = parseTypePackAnnotation(); + AstTypePack* typePack = parseTypePack(); namePacks.push_back({name, nameLocation, typePack}); } else if (!FFlag::LuauParserErrorsOnMissingDefaultTypePackArgument && lexer.current().type == '(') { - auto [type, typePack] = parseTypeOrPackAnnotation(); + auto [type, typePack] = parseTypeOrPack(); if (type) report(Location(packBegin.location.begin, lexer.previousLocation().end), "Expected type pack after '=', got type"); @@ -2472,7 +2465,7 @@ std::pair, AstArray> Parser::parseG } else if (FFlag::LuauParserErrorsOnMissingDefaultTypePackArgument) { - auto [type, typePack] = parseTypeOrPackAnnotation(); + auto [type, typePack] = parseTypeOrPack(); if (type) report(type->location, "Expected type pack after '=', got type"); @@ -2495,7 +2488,7 @@ std::pair, AstArray> Parser::parseG seenDefault = true; nextLexeme(); - AstType* defaultType = parseTypeAnnotation(); + AstType* defaultType = parseType(); names.push_back({name, nameLocation, defaultType}); } @@ -2532,7 +2525,7 @@ std::pair, AstArray> Parser::parseG AstArray Parser::parseTypeParams() { - TempVector parameters{scratchTypeOrPackAnnotation}; + TempVector parameters{scratchTypeOrPack}; if (lexer.current().type == '<') { @@ -2541,15 +2534,15 @@ AstArray Parser::parseTypeParams() while (true) { - if (shouldParseTypePackAnnotation(lexer)) + if (shouldParseTypePack(lexer)) { - AstTypePack* typePack = parseTypePackAnnotation(); + AstTypePack* typePack = parseTypePack(); parameters.push_back({{}, typePack}); } else if (lexer.current().type == '(') { - auto [type, typePack] = parseTypeOrPackAnnotation(); + auto [type, typePack] = parseTypeOrPack(); if (typePack) parameters.push_back({{}, typePack}); @@ -2562,7 +2555,7 @@ AstArray Parser::parseTypeParams() } else { - parameters.push_back({parseTypeAnnotation(), {}}); + parameters.push_back({parseType(), {}}); } if (lexer.current().type == ',') @@ -3018,7 +3011,7 @@ AstExprError* Parser::reportExprError(const Location& location, const AstArray(location, expressions, unsigned(parseErrors.size() - 1)); } -AstTypeError* Parser::reportTypeAnnotationError(const Location& location, const AstArray& types, const char* format, ...) +AstTypeError* Parser::reportTypeError(const Location& location, const AstArray& types, const char* format, ...) { va_list args; va_start(args, format); @@ -3028,7 +3021,7 @@ AstTypeError* Parser::reportTypeAnnotationError(const Location& location, const return allocator.alloc(location, types, false, unsigned(parseErrors.size() - 1)); } -AstTypeError* Parser::reportMissingTypeAnnotationError(const Location& parseErrorLocation, const Location& astErrorLocation, const char* format, ...) +AstTypeError* Parser::reportMissingTypeError(const Location& parseErrorLocation, const Location& astErrorLocation, const char* format, ...) { va_list args; va_start(args, format); diff --git a/CLI/Analyze.cpp b/CLI/Analyze.cpp index d6f1822d4..4fdb04439 100644 --- a/CLI/Analyze.cpp +++ b/CLI/Analyze.cpp @@ -14,6 +14,7 @@ #endif LUAU_FASTFLAG(DebugLuauTimeTracing) +LUAU_FASTFLAG(LuauLintInTypecheck) enum class ReportFormat { @@ -80,7 +81,7 @@ static bool analyzeFile(Luau::Frontend& frontend, const char* name, ReportFormat for (auto& error : cr.errors) reportError(frontend, format, error); - Luau::LintResult lr = frontend.lint(name); + Luau::LintResult lr = FFlag::LuauLintInTypecheck ? cr.lintResult : frontend.lint_DEPRECATED(name); std::string humanReadableName = frontend.fileResolver->getHumanReadableModuleName(name); for (auto& error : lr.errors) @@ -263,6 +264,7 @@ int main(int argc, char** argv) Luau::FrontendOptions frontendOptions; frontendOptions.retainFullTypeGraphs = annotate; + frontendOptions.runLintChecks = FFlag::LuauLintInTypecheck; CliFileResolver fileResolver; CliConfigResolver configResolver(mode); diff --git a/CodeGen/include/Luau/AssemblyBuilderA64.h b/CodeGen/include/Luau/AssemblyBuilderA64.h index 94d8f8114..0179967af 100644 --- a/CodeGen/include/Luau/AssemblyBuilderA64.h +++ b/CodeGen/include/Luau/AssemblyBuilderA64.h @@ -93,6 +93,13 @@ class AssemblyBuilderA64 // Assigns label position to the current location void setLabel(Label& label); + // Extracts code offset (in bytes) from label + uint32_t getLabelOffset(const Label& label) + { + LUAU_ASSERT(label.location != ~0u); + return label.location * 4; + } + void logAppend(const char* fmt, ...) LUAU_PRINTF_ATTR(2, 3); uint32_t getCodeSize() const; diff --git a/CodeGen/include/Luau/AssemblyBuilderX64.h b/CodeGen/include/Luau/AssemblyBuilderX64.h index 597f2b2c3..17076ed69 100644 --- a/CodeGen/include/Luau/AssemblyBuilderX64.h +++ b/CodeGen/include/Luau/AssemblyBuilderX64.h @@ -155,6 +155,13 @@ class AssemblyBuilderX64 // Assigns label position to the current location void setLabel(Label& label); + // Extracts code offset (in bytes) from label + uint32_t getLabelOffset(const Label& label) + { + LUAU_ASSERT(label.location != ~0u); + return label.location; + } + // Constant allocation (uses rip-relative addressing) OperandX64 i64(int64_t value); OperandX64 f32(float value); diff --git a/CodeGen/include/Luau/CodeAllocator.h b/CodeGen/include/Luau/CodeAllocator.h index a6cab4ad3..e0537b644 100644 --- a/CodeGen/include/Luau/CodeAllocator.h +++ b/CodeGen/include/Luau/CodeAllocator.h @@ -21,7 +21,8 @@ struct CodeAllocator // Places data and code into the executable page area // To allow allocation while previously allocated code is already running, allocation has page granularity // It's important to group functions together so that page alignment won't result in a lot of wasted space - bool allocate(uint8_t* data, size_t dataSize, uint8_t* code, size_t codeSize, uint8_t*& result, size_t& resultSize, uint8_t*& resultCodeStart); + bool allocate( + const uint8_t* data, size_t dataSize, const uint8_t* code, size_t codeSize, uint8_t*& result, size_t& resultSize, uint8_t*& resultCodeStart); // Provided to callbacks void* context = nullptr; diff --git a/CodeGen/include/Luau/IrAnalysis.h b/CodeGen/include/Luau/IrAnalysis.h index 21fa755ca..5c2bc4dfc 100644 --- a/CodeGen/include/Luau/IrAnalysis.h +++ b/CodeGen/include/Luau/IrAnalysis.h @@ -42,6 +42,7 @@ struct CfgInfo std::vector successorsOffsets; std::vector in; + std::vector def; std::vector out; RegisterSet captured; diff --git a/CodeGen/include/Luau/IrBuilder.h b/CodeGen/include/Luau/IrBuilder.h index 916c6eeda..e6202c777 100644 --- a/CodeGen/include/Luau/IrBuilder.h +++ b/CodeGen/include/Luau/IrBuilder.h @@ -1,8 +1,9 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once -#include "Luau/Common.h" #include "Luau/Bytecode.h" +#include "Luau/Common.h" +#include "Luau/DenseHash.h" #include "Luau/IrData.h" #include @@ -19,6 +20,8 @@ struct AssemblyOptions; struct IrBuilder { + IrBuilder(); + void buildFunctionIr(Proto* proto); void rebuildBytecodeBasicBlocks(Proto* proto); @@ -38,7 +41,7 @@ struct IrBuilder IrOp constUint(unsigned value); IrOp constDouble(double value); IrOp constTag(uint8_t value); - IrOp constAny(IrConst constant); + IrOp constAny(IrConst constant, uint64_t asCommonKey); IrOp cond(IrCondition cond); @@ -67,6 +70,45 @@ struct IrBuilder uint32_t activeBlockIdx = ~0u; std::vector instIndexToBlock; // Block index at the bytecode instruction + + // Similar to BytecodeBuilder, duplicate constants are removed used the same method + struct ConstantKey + { + IrConstKind kind; + // Note: this stores value* from IrConst; when kind is Double, this stores the same bits as double does but in uint64_t. + uint64_t value; + + bool operator==(const ConstantKey& key) const + { + return kind == key.kind && value == key.value; + } + }; + + struct ConstantKeyHash + { + size_t operator()(const ConstantKey& key) const + { + // finalizer from MurmurHash64B + const uint32_t m = 0x5bd1e995; + + uint32_t h1 = uint32_t(key.value); + uint32_t h2 = uint32_t(key.value >> 32) ^ (int(key.kind) * m); + + h1 ^= h2 >> 18; + h1 *= m; + h2 ^= h1 >> 22; + h2 *= m; + h1 ^= h2 >> 17; + h1 *= m; + h2 ^= h1 >> 19; + h2 *= m; + + // ... truncated to 32-bit output (normally hash is equal to (uint64_t(h1) << 32) | h2, but we only really need the lower 32-bit half) + return size_t(h2); + } + }; + + DenseHashMap constantMap; }; } // namespace CodeGen diff --git a/CodeGen/include/Luau/IrData.h b/CodeGen/include/Luau/IrData.h index 439abb9bd..67e706324 100644 --- a/CodeGen/include/Luau/IrData.h +++ b/CodeGen/include/Luau/IrData.h @@ -22,7 +22,7 @@ namespace CodeGen // In the command description, following abbreviations are used: // * Rn - VM stack register slot, n in 0..254 // * Kn - VM proto constant slot, n in 0..2^23-1 -// * UPn - VM function upvalue slot, n in 0..254 +// * UPn - VM function upvalue slot, n in 0..199 // * A, B, C, D, E are instruction arguments enum class IrCmd : uint8_t { @@ -64,6 +64,11 @@ enum class IrCmd : uint8_t // A: pointer (Table) GET_SLOT_NODE_ADDR, + // Get pointer (LuaNode) to table node element at the main position of the specified key hash + // A: pointer (Table) + // B: unsigned int + GET_HASH_NODE_ADDR, + // Store a tag into TValue // A: Rn // B: tag @@ -173,6 +178,13 @@ enum class IrCmd : uint8_t // E: block (if false) JUMP_CMP_ANY, + // Perform a conditional jump based on cached table node slot matching the actual table node slot for a key + // A: pointer (LuaNode) + // B: Kn + // C: block (if matches) + // D: block (if it doesn't) + JUMP_SLOT_MATCH, + // Get table length // A: pointer (Table) TABLE_LEN, @@ -189,7 +201,13 @@ enum class IrCmd : uint8_t // Try to convert a double number into a table index (int) or jump if it's not an integer // A: double // B: block - NUM_TO_INDEX, + TRY_NUM_TO_INDEX, + + // Try to get pointer to tag method TValue inside the table's metatable or jump if there is no such value or metatable + // A: table + // B: int + // C: block + TRY_CALL_FASTGETTM, // Convert integer into a double number // A: int @@ -315,6 +333,11 @@ enum class IrCmd : uint8_t // C: block CHECK_SLOT_MATCH, + // Guard against table node with a linked next node to ensure that our lookup hits the main position of the key + // A: pointer (LuaNode) + // B: block + CHECK_NODE_NO_NEXT, + // Special operations // Check interrupt handler @@ -361,14 +384,6 @@ enum class IrCmd : uint8_t // E: unsigned int (table index to start from) LOP_SETLIST, - // Load function from source register using name into target register and copying source register into target register + 1 - // A: unsigned int (bytecode instruction index) - // B: Rn (target) - // C: Rn (source) - // D: block (next) - // E: block (fallback) - LOP_NAMECALL, - // Call specified function // A: unsigned int (bytecode instruction index) // B: Rn (function, followed by arguments) @@ -576,6 +591,16 @@ struct IrOp , index(index) { } + + bool operator==(const IrOp& rhs) const + { + return kind == rhs.kind && index == rhs.index; + } + + bool operator!=(const IrOp& rhs) const + { + return !(*this == rhs); + } }; static_assert(sizeof(IrOp) == 4); diff --git a/CodeGen/include/Luau/IrDump.h b/CodeGen/include/Luau/IrDump.h index a6329ecf5..ae517e894 100644 --- a/CodeGen/include/Luau/IrDump.h +++ b/CodeGen/include/Luau/IrDump.h @@ -37,5 +37,9 @@ std::string toString(IrFunction& function, bool includeUseInfo); std::string dump(IrFunction& function); +std::string toDot(IrFunction& function, bool includeInst); + +std::string dumpDot(IrFunction& function, bool includeInst); + } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/include/Luau/IrUtils.h b/CodeGen/include/Luau/IrUtils.h index 3b14a8c80..0fc140250 100644 --- a/CodeGen/include/Luau/IrUtils.h +++ b/CodeGen/include/Luau/IrUtils.h @@ -98,7 +98,7 @@ inline bool isBlockTerminator(IrCmd cmd) case IrCmd::JUMP_EQ_POINTER: case IrCmd::JUMP_CMP_NUM: case IrCmd::JUMP_CMP_ANY: - case IrCmd::LOP_NAMECALL: + case IrCmd::JUMP_SLOT_MATCH: case IrCmd::LOP_RETURN: case IrCmd::LOP_FORGLOOP: case IrCmd::LOP_FORGLOOP_FALLBACK: @@ -125,6 +125,7 @@ inline bool hasResult(IrCmd cmd) case IrCmd::LOAD_ENV: case IrCmd::GET_ARR_ADDR: case IrCmd::GET_SLOT_NODE_ADDR: + case IrCmd::GET_HASH_NODE_ADDR: case IrCmd::ADD_INT: case IrCmd::SUB_INT: case IrCmd::ADD_NUM: @@ -140,7 +141,8 @@ inline bool hasResult(IrCmd cmd) case IrCmd::TABLE_LEN: case IrCmd::NEW_TABLE: case IrCmd::DUP_TABLE: - case IrCmd::NUM_TO_INDEX: + case IrCmd::TRY_NUM_TO_INDEX: + case IrCmd::TRY_CALL_FASTGETTM: case IrCmd::INT_TO_NUM: case IrCmd::SUBSTITUTE: case IrCmd::INVOKE_FASTCALL: diff --git a/CodeGen/src/CodeAllocator.cpp b/CodeGen/src/CodeAllocator.cpp index e1950dbc7..4d04a249f 100644 --- a/CodeGen/src/CodeAllocator.cpp +++ b/CodeGen/src/CodeAllocator.cpp @@ -112,7 +112,7 @@ CodeAllocator::~CodeAllocator() } bool CodeAllocator::allocate( - uint8_t* data, size_t dataSize, uint8_t* code, size_t codeSize, uint8_t*& result, size_t& resultSize, uint8_t*& resultCodeStart) + const uint8_t* data, size_t dataSize, const uint8_t* code, size_t codeSize, uint8_t*& result, size_t& resultSize, uint8_t*& resultCodeStart) { // 'Round up' to preserve code alignment size_t alignedDataSize = (dataSize + (kCodeAlignment - 1)) & ~(kCodeAlignment - 1); diff --git a/CodeGen/src/CodeGen.cpp b/CodeGen/src/CodeGen.cpp index c794972d0..ce490f916 100644 --- a/CodeGen/src/CodeGen.cpp +++ b/CodeGen/src/CodeGen.cpp @@ -1,7 +1,6 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/CodeGen.h" -#include "Luau/AssemblyBuilderX64.h" #include "Luau/Common.h" #include "Luau/CodeAllocator.h" #include "Luau/CodeBlockUnwind.h" @@ -9,12 +8,17 @@ #include "Luau/IrBuilder.h" #include "Luau/OptimizeConstProp.h" #include "Luau/OptimizeFinalX64.h" + #include "Luau/UnwindBuilder.h" #include "Luau/UnwindBuilderDwarf2.h" #include "Luau/UnwindBuilderWin.h" +#include "Luau/AssemblyBuilderX64.h" +#include "Luau/AssemblyBuilderA64.h" + #include "CustomExecUtils.h" #include "CodeGenX64.h" +#include "CodeGenA64.h" #include "EmitCommonX64.h" #include "EmitInstructionX64.h" #include "IrLoweringX64.h" @@ -39,32 +43,55 @@ namespace Luau namespace CodeGen { -constexpr uint32_t kFunctionAlignment = 32; +static NativeProto* createNativeProto(Proto* proto, const IrBuilder& ir) +{ + NativeProto* result = new NativeProto(); + + result->proto = proto; + result->instTargets = new uintptr_t[proto->sizecode]; + + for (int i = 0; i < proto->sizecode; i++) + { + auto [irLocation, asmLocation] = ir.function.bcMapping[i]; + + result->instTargets[i] = irLocation == ~0u ? 0 : asmLocation; + } -static void assembleHelpers(X64::AssemblyBuilderX64& build, ModuleHelpers& helpers) + return result; +} + +[[maybe_unused]] static void lowerIr( + X64::AssemblyBuilderX64& build, IrBuilder& ir, NativeState& data, ModuleHelpers& helpers, Proto* proto, AssemblyOptions options) { - if (build.logText) - build.logAppend("; exitContinueVm\n"); - helpers.exitContinueVm = build.setLabel(); - emitExit(build, /* continueInVm */ true); + constexpr uint32_t kFunctionAlignment = 32; - if (build.logText) - build.logAppend("; exitNoContinueVm\n"); - helpers.exitNoContinueVm = build.setLabel(); - emitExit(build, /* continueInVm */ false); + optimizeMemoryOperandsX64(ir.function); - if (build.logText) - build.logAppend("; continueCallInVm\n"); - helpers.continueCallInVm = build.setLabel(); - emitContinueCallInVm(build); + build.align(kFunctionAlignment, X64::AlignmentDataX64::Ud2); + + X64::IrLoweringX64 lowering(build, helpers, data, proto, ir.function); + + lowering.lower(options); } -static NativeProto* assembleFunction(X64::AssemblyBuilderX64& build, NativeState& data, ModuleHelpers& helpers, Proto* proto, AssemblyOptions options) +[[maybe_unused]] static void lowerIr( + A64::AssemblyBuilderA64& build, IrBuilder& ir, NativeState& data, ModuleHelpers& helpers, Proto* proto, AssemblyOptions options) { - NativeProto* result = new NativeProto(); + Label start = build.setLabel(); - result->proto = proto; + build.mov(A64::x0, 1); // finish function in VM + build.ret(); + // TODO: This is only needed while we don't support all IR opcodes + // When we can't translate some parts of the function, we instead encode a dummy assembly sequence that hands off control to VM + // In the future we could return nullptr from assembleFunction and handle it because there may be other reasons for why we refuse to assemble. + for (int i = 0; i < proto->sizecode; i++) + ir.function.bcMapping[i].asmLocation = build.getLabelOffset(start); +} + +template +static NativeProto* assembleFunction(AssemblyBuilder& build, NativeState& data, ModuleHelpers& helpers, Proto* proto, AssemblyOptions options) +{ if (options.includeAssembly || options.includeIr) { if (proto->debugname) @@ -93,43 +120,24 @@ static NativeProto* assembleFunction(X64::AssemblyBuilderX64& build, NativeState build.logAppend("\n"); } - build.align(kFunctionAlignment, X64::AlignmentDataX64::Ud2); - - Label start = build.setLabel(); - - IrBuilder builder; - builder.buildFunctionIr(proto); + IrBuilder ir; + ir.buildFunctionIr(proto); if (!FFlag::DebugCodegenNoOpt) { - constPropInBlockChains(builder); + constPropInBlockChains(ir); } // TODO: cfg info has to be computed earlier to use in optimizations // It's done here to appear in text output and to measure performance impact on code generation - computeCfgInfo(builder.function); - - optimizeMemoryOperandsX64(builder.function); - - X64::IrLoweringX64 lowering(build, helpers, data, proto, builder.function); - - lowering.lower(options); - - result->instTargets = new uintptr_t[proto->sizecode]; - - for (int i = 0; i < proto->sizecode; i++) - { - auto [irLocation, asmLocation] = builder.function.bcMapping[i]; - - result->instTargets[i] = irLocation == ~0u ? 0 : asmLocation - start.location; - } + computeCfgInfo(ir.function); - result->location = start.location; + lowerIr(build, ir, data, helpers, proto, options); if (build.logText) build.logAppend("\n"); - return result; + return createNativeProto(proto, ir); } static void destroyNativeProto(NativeProto* nativeProto) @@ -207,6 +215,8 @@ bool isSupported() if ((cpuinfo[2] & (1 << 28)) == 0) return false; + return true; +#elif defined(__aarch64__) return true; #else return false; @@ -232,11 +242,19 @@ void create(lua_State* L) initFallbackTable(data); initHelperFunctions(data); +#if defined(__x86_64__) || defined(_M_X64) if (!X64::initEntryFunction(data)) { destroyNativeState(L); return; } +#elif defined(__aarch64__) + if (!A64::initEntryFunction(data)) + { + destroyNativeState(L); + return; + } +#endif lua_ExecutionCallbacks* ecb = getExecutionCallbacks(L); @@ -270,14 +288,21 @@ void compile(lua_State* L, int idx) if (!getNativeState(L)) return; +#if defined(__aarch64__) + A64::AssemblyBuilderA64 build(/* logText= */ false); +#else X64::AssemblyBuilderX64 build(/* logText= */ false); +#endif + NativeState* data = getNativeState(L); std::vector protos; gatherFunctions(protos, clvalue(func)->l.p); ModuleHelpers helpers; - assembleHelpers(build, helpers); +#if !defined(__aarch64__) + X64::assembleHelpers(build, helpers); +#endif std::vector results; results.reserve(protos.size()); @@ -292,8 +317,8 @@ void compile(lua_State* L, int idx) uint8_t* nativeData = nullptr; size_t sizeNativeData = 0; uint8_t* codeStart = nullptr; - if (!data->codeAllocator.allocate( - build.data.data(), int(build.data.size()), build.code.data(), int(build.code.size()), nativeData, sizeNativeData, codeStart)) + if (!data->codeAllocator.allocate(build.data.data(), int(build.data.size()), reinterpret_cast(build.code.data()), + int(build.code.size() * sizeof(build.code[0])), nativeData, sizeNativeData, codeStart)) { for (NativeProto* result : results) destroyNativeProto(result); @@ -305,7 +330,7 @@ void compile(lua_State* L, int idx) for (NativeProto* result : results) { for (int i = 0; i < result->proto->sizecode; i++) - result->instTargets[i] += uintptr_t(codeStart + result->location); + result->instTargets[i] += uintptr_t(codeStart); LUAU_ASSERT(result->proto->sizecode); result->entryTarget = result->instTargets[0]; @@ -321,7 +346,11 @@ std::string getAssembly(lua_State* L, int idx, AssemblyOptions options) LUAU_ASSERT(lua_isLfunction(L, idx)); const TValue* func = luaA_toobject(L, idx); +#if defined(__aarch64__) + A64::AssemblyBuilderA64 build(/* logText= */ options.includeAssembly); +#else X64::AssemblyBuilderX64 build(/* logText= */ options.includeAssembly); +#endif NativeState data; initFallbackTable(data); @@ -330,7 +359,9 @@ std::string getAssembly(lua_State* L, int idx, AssemblyOptions options) gatherFunctions(protos, clvalue(func)->l.p); ModuleHelpers helpers; - assembleHelpers(build, helpers); +#if !defined(__aarch64__) + X64::assembleHelpers(build, helpers); +#endif for (Proto* p : protos) if (p) @@ -342,7 +373,9 @@ std::string getAssembly(lua_State* L, int idx, AssemblyOptions options) build.finalize(); if (options.outputBinary) - return std::string(build.code.begin(), build.code.end()) + std::string(build.data.begin(), build.data.end()); + return std::string( + reinterpret_cast(build.code.data()), reinterpret_cast(build.code.data() + build.code.size())) + + std::string(build.data.begin(), build.data.end()); else return build.text; } diff --git a/CodeGen/src/CodeGenA64.cpp b/CodeGen/src/CodeGenA64.cpp new file mode 100644 index 000000000..94d6f2e3f --- /dev/null +++ b/CodeGen/src/CodeGenA64.cpp @@ -0,0 +1,69 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "CodeGenA64.h" + +#include "Luau/AssemblyBuilderA64.h" +#include "Luau/UnwindBuilder.h" + +#include "CustomExecUtils.h" +#include "NativeState.h" + +#include "lstate.h" + +namespace Luau +{ +namespace CodeGen +{ +namespace A64 +{ + +bool initEntryFunction(NativeState& data) +{ + AssemblyBuilderA64 build(/* logText= */ false); + UnwindBuilder& unwind = *data.unwindBuilder.get(); + + unwind.start(); + unwind.allocStack(8); // TODO: this is only necessary to align stack by 16 bytes, as start() allocates 8b return pointer + + // TODO: prologue goes here + + unwind.finish(); + + size_t prologueSize = build.setLabel().location; + + // Setup native execution environment + // TODO: figure out state layout + + // Jump to the specified instruction; further control flow will be handled with custom ABI with register setup from EmitCommonX64.h + build.br(x2); + + // Even though we jumped away, we will return here in the end + Label returnOff = build.setLabel(); + + // Cleanup and exit + // TODO: epilogue + + build.ret(); + + build.finalize(); + + LUAU_ASSERT(build.data.empty()); + + if (!data.codeAllocator.allocate(build.data.data(), int(build.data.size()), reinterpret_cast(build.code.data()), + int(build.code.size() * sizeof(build.code[0])), data.gateData, data.gateDataSize, data.context.gateEntry)) + { + LUAU_ASSERT(!"failed to create entry function"); + return false; + } + + // Set the offset at the begining so that functions in new blocks will not overlay the locations + // specified by the unwind information of the entry function + unwind.setBeginOffset(prologueSize); + + data.context.gateExit = data.context.gateEntry + returnOff.location; + + return true; +} + +} // namespace A64 +} // namespace CodeGen +} // namespace Luau diff --git a/CodeGen/src/CodeGenA64.h b/CodeGen/src/CodeGenA64.h new file mode 100644 index 000000000..5043e5c67 --- /dev/null +++ b/CodeGen/src/CodeGenA64.h @@ -0,0 +1,18 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +namespace Luau +{ +namespace CodeGen +{ + +struct NativeState; + +namespace A64 +{ + +bool initEntryFunction(NativeState& data); + +} // namespace A64 +} // namespace CodeGen +} // namespace Luau diff --git a/CodeGen/src/CodeGenX64.cpp b/CodeGen/src/CodeGenX64.cpp index ac6c9416c..7df1a909d 100644 --- a/CodeGen/src/CodeGenX64.cpp +++ b/CodeGen/src/CodeGenX64.cpp @@ -143,6 +143,24 @@ bool initEntryFunction(NativeState& data) return true; } +void assembleHelpers(X64::AssemblyBuilderX64& build, ModuleHelpers& helpers) +{ + if (build.logText) + build.logAppend("; exitContinueVm\n"); + helpers.exitContinueVm = build.setLabel(); + emitExit(build, /* continueInVm */ true); + + if (build.logText) + build.logAppend("; exitNoContinueVm\n"); + helpers.exitNoContinueVm = build.setLabel(); + emitExit(build, /* continueInVm */ false); + + if (build.logText) + build.logAppend("; continueCallInVm\n"); + helpers.continueCallInVm = build.setLabel(); + emitContinueCallInVm(build); +} + } // namespace X64 } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/CodeGenX64.h b/CodeGen/src/CodeGenX64.h index b82266af7..1f4831138 100644 --- a/CodeGen/src/CodeGenX64.h +++ b/CodeGen/src/CodeGenX64.h @@ -7,11 +7,15 @@ namespace CodeGen { struct NativeState; +struct ModuleHelpers; namespace X64 { +class AssemblyBuilderX64; + bool initEntryFunction(NativeState& data); +void assembleHelpers(AssemblyBuilderX64& build, ModuleHelpers& helpers); } // namespace X64 } // namespace CodeGen diff --git a/CodeGen/src/EmitBuiltinsX64.cpp b/CodeGen/src/EmitBuiltinsX64.cpp index 05b63551b..d70b6ed8b 100644 --- a/CodeGen/src/EmitBuiltinsX64.cpp +++ b/CodeGen/src/EmitBuiltinsX64.cpp @@ -286,6 +286,31 @@ void emitBuiltinMathSign(IrRegAllocX64& regs, AssemblyBuilderX64& build, int npa build.vmovsd(luauRegValue(ra), tmp0.reg); } +void emitBuiltinType(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) +{ + ScopedRegX64 tmp0{regs, SizeX64::qword}; + ScopedRegX64 tag{regs, SizeX64::dword}; + + build.mov(tag.reg, luauRegTag(arg)); + + build.mov(tmp0.reg, qword[rState + offsetof(lua_State, global)]); + build.mov(tmp0.reg, qword[tmp0.reg + qwordReg(tag.reg) * sizeof(TString*) + offsetof(global_State, ttname)]); + + build.mov(luauRegValue(ra), tmp0.reg); +} + +void emitBuiltinTypeof(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) +{ + regs.assertAllFree(); + + build.mov(rArg1, rState); + build.lea(rArg2, luauRegAddress(arg)); + + build.call(qword[rNativeContext + offsetof(NativeContext, luaT_objtypenamestr)]); + + build.mov(luauRegValue(ra), rax); +} + void emitBuiltin(IrRegAllocX64& regs, AssemblyBuilderX64& build, int bfid, int ra, int arg, IrOp args, int nparams, int nresults) { OperandX64 argsOp = 0; @@ -353,6 +378,10 @@ void emitBuiltin(IrRegAllocX64& regs, AssemblyBuilderX64& build, int bfid, int r return emitBuiltinMathModf(regs, build, nparams, ra, arg, argsOp, nresults); case LBF_MATH_SIGN: return emitBuiltinMathSign(regs, build, nparams, ra, arg, argsOp, nresults); + case LBF_TYPE: + return emitBuiltinType(regs, build, nparams, ra, arg, argsOp, nresults); + case LBF_TYPEOF: + return emitBuiltinTypeof(regs, build, nparams, ra, arg, argsOp, nresults); default: LUAU_ASSERT(!"missing x64 lowering"); break; diff --git a/CodeGen/src/EmitInstructionX64.cpp b/CodeGen/src/EmitInstructionX64.cpp index 3b0aa258b..e8f61ebb0 100644 --- a/CodeGen/src/EmitInstructionX64.cpp +++ b/CodeGen/src/EmitInstructionX64.cpp @@ -18,51 +18,6 @@ namespace CodeGen namespace X64 { -void emitInstNameCall(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, const TValue* k, Label& next, Label& fallback) -{ - int ra = LUAU_INSN_A(*pc); - int rb = LUAU_INSN_B(*pc); - uint32_t aux = pc[1]; - - Label secondfpath; - - jumpIfTagIsNot(build, rb, LUA_TTABLE, fallback); - - RegisterX64 table = r8; - build.mov(table, luauRegValue(rb)); - - // &h->node[tsvalue(kv)->hash & (sizenode(h) - 1)]; - RegisterX64 node = rdx; - build.mov(node, qword[table + offsetof(Table, node)]); - build.mov(eax, 1); - build.mov(cl, byte[table + offsetof(Table, lsizenode)]); - build.shl(eax, cl); - build.dec(eax); - build.and_(eax, tsvalue(&k[aux])->hash); - build.shl(rax, kLuaNodeSizeLog2); - build.add(node, rax); - - jumpIfNodeKeyNotInExpectedSlot(build, rax, node, luauConstantValue(aux), secondfpath); - - setLuauReg(build, xmm0, ra + 1, luauReg(rb)); - setLuauReg(build, xmm0, ra, luauNodeValue(node)); - build.jmp(next); - - build.setLabel(secondfpath); - - jumpIfNodeHasNext(build, node, fallback); - callGetFastTmOrFallback(build, table, TM_INDEX, fallback); - jumpIfTagIsNot(build, rax, LUA_TTABLE, fallback); - - build.mov(table, qword[rax + offsetof(TValue, value)]); - - getTableNodeAtCachedSlot(build, rax, node, table, pcpos); - jumpIfNodeKeyNotInExpectedSlot(build, rax, node, luauConstantValue(aux), fallback); - - setLuauReg(build, xmm0, ra + 1, luauReg(rb)); - setLuauReg(build, xmm0, ra, luauNodeValue(node)); -} - void emitInstCall(AssemblyBuilderX64& build, ModuleHelpers& helpers, const Instruction* pc, int pcpos) { int ra = LUAU_INSN_A(*pc); diff --git a/CodeGen/src/EmitInstructionX64.h b/CodeGen/src/EmitInstructionX64.h index dcca52ab6..6a8a3c0ee 100644 --- a/CodeGen/src/EmitInstructionX64.h +++ b/CodeGen/src/EmitInstructionX64.h @@ -21,7 +21,6 @@ namespace X64 class AssemblyBuilderX64; -void emitInstNameCall(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, const TValue* k, Label& next, Label& fallback); void emitInstCall(AssemblyBuilderX64& build, ModuleHelpers& helpers, const Instruction* pc, int pcpos); void emitInstReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers, const Instruction* pc, int pcpos); void emitInstSetList(AssemblyBuilderX64& build, const Instruction* pc, Label& next); diff --git a/CodeGen/src/IrAnalysis.cpp b/CodeGen/src/IrAnalysis.cpp index dc7d771ec..b998487f9 100644 --- a/CodeGen/src/IrAnalysis.cpp +++ b/CodeGen/src/IrAnalysis.cpp @@ -124,6 +124,10 @@ static void requireVariadicSequence(RegisterSet& sourceRs, const RegisterSet& de { if (!defRs.varargSeq) { + // Peel away registers from variadic sequence that we define + while (defRs.regs.test(varargStart)) + varargStart++; + LUAU_ASSERT(!sourceRs.varargSeq || sourceRs.varargStart == varargStart); sourceRs.varargSeq = true; @@ -296,11 +300,6 @@ static RegisterSet computeBlockLiveInRegSet(IrFunction& function, const IrBlock& use(inst.b); useRange(inst.c.index, function.intOp(inst.d)); break; - case IrCmd::LOP_NAMECALL: - use(inst.c); - - defRange(inst.b.index, 2); - break; case IrCmd::LOP_CALL: use(inst.b); useRange(inst.b.index + 1, function.intOp(inst.c)); @@ -411,6 +410,13 @@ static RegisterSet computeBlockLiveInRegSet(IrFunction& function, const IrBlock& break; default: + // All instructions which reference registers have to be handled explicitly + LUAU_ASSERT(inst.a.kind != IrOpKind::VmReg); + LUAU_ASSERT(inst.b.kind != IrOpKind::VmReg); + LUAU_ASSERT(inst.c.kind != IrOpKind::VmReg); + LUAU_ASSERT(inst.d.kind != IrOpKind::VmReg); + LUAU_ASSERT(inst.e.kind != IrOpKind::VmReg); + LUAU_ASSERT(inst.f.kind != IrOpKind::VmReg); break; } } @@ -430,17 +436,20 @@ static void computeCfgLiveInOutRegSets(IrFunction& function) { CfgInfo& info = function.cfg; + // Clear existing data + // 'in' and 'captured' data is not cleared because it will be overwritten below + info.def.clear(); + info.out.clear(); + // Try to compute Luau VM register use-def info info.in.resize(function.blocks.size()); + info.def.resize(function.blocks.size()); info.out.resize(function.blocks.size()); // Captured registers are tracked for the whole function // It should be possible to have a more precise analysis for them in the future std::bitset<256> capturedRegs; - std::vector defRss; - defRss.resize(function.blocks.size()); - // First we compute live-in set of each block for (size_t blockIdx = 0; blockIdx < function.blocks.size(); blockIdx++) { @@ -449,7 +458,7 @@ static void computeCfgLiveInOutRegSets(IrFunction& function) if (block.kind == IrBlockKind::Dead) continue; - info.in[blockIdx] = computeBlockLiveInRegSet(function, block, defRss[blockIdx], capturedRegs); + info.in[blockIdx] = computeBlockLiveInRegSet(function, block, info.def[blockIdx], capturedRegs); } info.captured.regs = capturedRegs; @@ -480,8 +489,8 @@ static void computeCfgLiveInOutRegSets(IrFunction& function) IrBlock& curr = function.blocks[blockIdx]; RegisterSet& inRs = info.in[blockIdx]; + RegisterSet& defRs = info.def[blockIdx]; RegisterSet& outRs = info.out[blockIdx]; - RegisterSet& defRs = defRss[blockIdx]; // Current block has to provide all registers in successor blocks for (uint32_t succIdx : successors(info, blockIdx)) @@ -547,6 +556,10 @@ static void computeCfgBlockEdges(IrFunction& function) { CfgInfo& info = function.cfg; + // Clear existing data + info.predecessorsOffsets.clear(); + info.successorsOffsets.clear(); + // Compute predecessors block edges info.predecessorsOffsets.reserve(function.blocks.size()); info.successorsOffsets.reserve(function.blocks.size()); diff --git a/CodeGen/src/IrBuilder.cpp b/CodeGen/src/IrBuilder.cpp index 0a700dba6..f1099cfac 100644 --- a/CodeGen/src/IrBuilder.cpp +++ b/CodeGen/src/IrBuilder.cpp @@ -1,8 +1,6 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/IrBuilder.h" -#include "Luau/Common.h" -#include "Luau/DenseHash.h" #include "Luau/IrAnalysis.h" #include "Luau/IrUtils.h" @@ -11,6 +9,8 @@ #include "lapi.h" +#include + namespace Luau { namespace CodeGen @@ -18,6 +18,11 @@ namespace CodeGen constexpr unsigned kNoAssociatedBlockIndex = ~0u; +IrBuilder::IrBuilder() + : constantMap({IrConstKind::Bool, ~0ull}) +{ +} + void IrBuilder::buildFunctionIr(Proto* proto) { function.proto = proto; @@ -377,19 +382,8 @@ void IrBuilder::translateInst(LuauOpcode op, const Instruction* pc, int i) translateInstCapture(*this, pc, i); break; case LOP_NAMECALL: - { - IrOp next = blockAtInst(i + getOpLength(LOP_NAMECALL)); - IrOp fallback = block(IrBlockKind::Fallback); - - inst(IrCmd::LOP_NAMECALL, constUint(i), vmReg(LUAU_INSN_A(*pc)), vmReg(LUAU_INSN_B(*pc)), next, fallback); - - beginBlock(fallback); - inst(IrCmd::FALLBACK_NAMECALL, constUint(i), vmReg(LUAU_INSN_A(*pc)), vmReg(LUAU_INSN_B(*pc)), vmConst(pc[1])); - inst(IrCmd::JUMP, next); - - beginBlock(next); + translateInstNamecall(*this, pc, i); break; - } case LOP_PREPVARARGS: inst(IrCmd::FALLBACK_PREPVARARGS, constUint(i), constInt(LUAU_INSN_A(*pc))); break; @@ -501,7 +495,7 @@ IrOp IrBuilder::constBool(bool value) IrConst constant; constant.kind = IrConstKind::Bool; constant.valueBool = value; - return constAny(constant); + return constAny(constant, uint64_t(value)); } IrOp IrBuilder::constInt(int value) @@ -509,7 +503,7 @@ IrOp IrBuilder::constInt(int value) IrConst constant; constant.kind = IrConstKind::Int; constant.valueInt = value; - return constAny(constant); + return constAny(constant, uint64_t(value)); } IrOp IrBuilder::constUint(unsigned value) @@ -517,7 +511,7 @@ IrOp IrBuilder::constUint(unsigned value) IrConst constant; constant.kind = IrConstKind::Uint; constant.valueUint = value; - return constAny(constant); + return constAny(constant, uint64_t(value)); } IrOp IrBuilder::constDouble(double value) @@ -525,7 +519,12 @@ IrOp IrBuilder::constDouble(double value) IrConst constant; constant.kind = IrConstKind::Double; constant.valueDouble = value; - return constAny(constant); + + uint64_t asCommonKey; + static_assert(sizeof(asCommonKey) == sizeof(value), "Expecting double to be 64-bit"); + memcpy(&asCommonKey, &value, sizeof(value)); + + return constAny(constant, asCommonKey); } IrOp IrBuilder::constTag(uint8_t value) @@ -533,13 +532,21 @@ IrOp IrBuilder::constTag(uint8_t value) IrConst constant; constant.kind = IrConstKind::Tag; constant.valueTag = value; - return constAny(constant); + return constAny(constant, uint64_t(value)); } -IrOp IrBuilder::constAny(IrConst constant) +IrOp IrBuilder::constAny(IrConst constant, uint64_t asCommonKey) { + ConstantKey key{constant.kind, asCommonKey}; + + if (uint32_t* cache = constantMap.find(key)) + return {IrOpKind::Constant, *cache}; + uint32_t index = uint32_t(function.constants.size()); function.constants.push_back(constant); + + constantMap[key] = index; + return {IrOpKind::Constant, index}; } diff --git a/CodeGen/src/IrDump.cpp b/CodeGen/src/IrDump.cpp index 2787fb11f..3c4e420d8 100644 --- a/CodeGen/src/IrDump.cpp +++ b/CodeGen/src/IrDump.cpp @@ -90,6 +90,8 @@ const char* getCmdName(IrCmd cmd) return "GET_ARR_ADDR"; case IrCmd::GET_SLOT_NODE_ADDR: return "GET_SLOT_NODE_ADDR"; + case IrCmd::GET_HASH_NODE_ADDR: + return "GET_HASH_NODE_ADDR"; case IrCmd::STORE_TAG: return "STORE_TAG"; case IrCmd::STORE_POINTER: @@ -142,14 +144,18 @@ const char* getCmdName(IrCmd cmd) return "JUMP_CMP_NUM"; case IrCmd::JUMP_CMP_ANY: return "JUMP_CMP_ANY"; + case IrCmd::JUMP_SLOT_MATCH: + return "JUMP_SLOT_MATCH"; case IrCmd::TABLE_LEN: return "TABLE_LEN"; case IrCmd::NEW_TABLE: return "NEW_TABLE"; case IrCmd::DUP_TABLE: return "DUP_TABLE"; - case IrCmd::NUM_TO_INDEX: - return "NUM_TO_INDEX"; + case IrCmd::TRY_NUM_TO_INDEX: + return "TRY_NUM_TO_INDEX"; + case IrCmd::TRY_CALL_FASTGETTM: + return "TRY_CALL_FASTGETTM"; case IrCmd::INT_TO_NUM: return "INT_TO_NUM"; case IrCmd::ADJUST_STACK_TO_REG: @@ -192,6 +198,8 @@ const char* getCmdName(IrCmd cmd) return "CHECK_ARRAY_SIZE"; case IrCmd::CHECK_SLOT_MATCH: return "CHECK_SLOT_MATCH"; + case IrCmd::CHECK_NODE_NO_NEXT: + return "CHECK_NODE_NO_NEXT"; case IrCmd::INTERRUPT: return "INTERRUPT"; case IrCmd::CHECK_GC: @@ -210,8 +218,6 @@ const char* getCmdName(IrCmd cmd) return "CAPTURE"; case IrCmd::LOP_SETLIST: return "LOP_SETLIST"; - case IrCmd::LOP_NAMECALL: - return "LOP_NAMECALL"; case IrCmd::LOP_CALL: return "LOP_CALL"; case IrCmd::LOP_RETURN: @@ -397,7 +403,7 @@ static void appendBlockSet(IrToStringContext& ctx, BlockIteratorWrapper blocks) } } -static void appendRegisterSet(IrToStringContext& ctx, const RegisterSet& rs) +static void appendRegisterSet(IrToStringContext& ctx, const RegisterSet& rs, const char* separator) { bool comma = false; @@ -406,7 +412,7 @@ static void appendRegisterSet(IrToStringContext& ctx, const RegisterSet& rs) if (rs.regs.test(i)) { if (comma) - append(ctx.result, ", "); + ctx.result.append(separator); comma = true; append(ctx.result, "R%d", int(i)); @@ -416,7 +422,7 @@ static void appendRegisterSet(IrToStringContext& ctx, const RegisterSet& rs) if (rs.varargSeq) { if (comma) - append(ctx.result, ", "); + ctx.result.append(separator); append(ctx.result, "R%d...", rs.varargStart); } @@ -428,7 +434,7 @@ void toStringDetailed(IrToStringContext& ctx, const IrBlock& block, uint32_t ind if (block.useCount == 0 && block.kind != IrBlockKind::Dead && ctx.cfg.captured.regs.any()) { append(ctx.result, "; captured regs: "); - appendRegisterSet(ctx, ctx.cfg.captured); + appendRegisterSet(ctx, ctx.cfg.captured, ", "); append(ctx.result, "\n\n"); } @@ -484,7 +490,7 @@ void toStringDetailed(IrToStringContext& ctx, const IrBlock& block, uint32_t ind if (in.regs.any() || in.varargSeq) { append(ctx.result, "; in regs: "); - appendRegisterSet(ctx, in); + appendRegisterSet(ctx, in, ", "); append(ctx.result, "\n"); } } @@ -497,7 +503,7 @@ void toStringDetailed(IrToStringContext& ctx, const IrBlock& block, uint32_t ind if (out.regs.any() || out.varargSeq) { append(ctx.result, "; out regs: "); - appendRegisterSet(ctx, out); + appendRegisterSet(ctx, out, ", "); append(ctx.result, "\n"); } } @@ -551,5 +557,108 @@ std::string dump(IrFunction& function) return result; } +std::string toDot(IrFunction& function, bool includeInst) +{ + std::string result; + IrToStringContext ctx{result, function.blocks, function.constants, function.cfg}; + + auto appendLabelRegset = [&ctx](std::vector& regSets, size_t blockIdx, const char* name) { + if (blockIdx < regSets.size()) + { + const RegisterSet& rs = regSets[blockIdx]; + + if (rs.regs.any() || rs.varargSeq) + { + append(ctx.result, "|{%s|", name); + appendRegisterSet(ctx, rs, "|"); + append(ctx.result, "}"); + } + } + }; + + append(ctx.result, "digraph CFG {\n"); + append(ctx.result, "node[shape=record]\n"); + + for (size_t i = 0; i < function.blocks.size(); i++) + { + IrBlock& block = function.blocks[i]; + + append(ctx.result, "b%u [", unsigned(i)); + + if (block.kind == IrBlockKind::Fallback) + append(ctx.result, "style=filled;fillcolor=salmon;"); + else if (block.kind == IrBlockKind::Bytecode) + append(ctx.result, "style=filled;fillcolor=palegreen;"); + + append(ctx.result, "label=\"{"); + toString(ctx, block, uint32_t(i)); + + appendLabelRegset(ctx.cfg.in, i, "in"); + + if (includeInst && block.start != ~0u) + { + for (uint32_t instIdx = block.start; instIdx <= block.finish; instIdx++) + { + IrInst& inst = function.instructions[instIdx]; + + // Skip pseudo instructions unless they are still referenced + if (isPseudo(inst.cmd) && inst.useCount == 0) + continue; + + append(ctx.result, "|"); + toString(ctx, inst, instIdx); + } + } + + appendLabelRegset(ctx.cfg.def, i, "def"); + appendLabelRegset(ctx.cfg.out, i, "out"); + + append(ctx.result, "}\"];\n"); + } + + for (size_t i = 0; i < function.blocks.size(); i++) + { + IrBlock& block = function.blocks[i]; + + if (block.start == ~0u) + continue; + + for (uint32_t instIdx = block.start; instIdx != ~0u && instIdx <= block.finish; instIdx++) + { + IrInst& inst = function.instructions[instIdx]; + + auto checkOp = [&](IrOp op) { + if (op.kind == IrOpKind::Block) + { + if (function.blocks[op.index].kind != IrBlockKind::Fallback) + append(ctx.result, "b%u -> b%u [weight=10];\n", unsigned(i), op.index); + else + append(ctx.result, "b%u -> b%u;\n", unsigned(i), op.index); + } + }; + + checkOp(inst.a); + checkOp(inst.b); + checkOp(inst.c); + checkOp(inst.d); + checkOp(inst.e); + checkOp(inst.f); + } + } + + append(ctx.result, "}\n"); + + return result; +} + +std::string dumpDot(IrFunction& function, bool includeInst) +{ + std::string result = toDot(function, includeInst); + + printf("%s\n", result.c_str()); + + return result; +} + } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/IrLoweringX64.cpp b/CodeGen/src/IrLoweringX64.cpp index 3b27d09fc..b45ce2261 100644 --- a/CodeGen/src/IrLoweringX64.cpp +++ b/CodeGen/src/IrLoweringX64.cpp @@ -200,6 +200,10 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.mov(inst.regX64, luauRegValue(inst.a.index)); else if (inst.a.kind == IrOpKind::VmConst) build.mov(inst.regX64, luauConstantValue(inst.a.index)); + // If we have a register, we assume it's a pointer to TValue + // We might introduce explicit operand types in the future to make this more robust + else if (inst.a.kind == IrOpKind::Inst) + build.mov(inst.regX64, qword[regOp(inst.a) + offsetof(TValue, value)]); else LUAU_ASSERT(!"Unsupported instruction form"); break; @@ -277,6 +281,25 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) getTableNodeAtCachedSlot(build, tmp.reg, inst.regX64, regOp(inst.a), uintOp(inst.b)); break; } + case IrCmd::GET_HASH_NODE_ADDR: + { + inst.regX64 = regs.allocGprReg(SizeX64::qword); + + // Custom bit shift value can only be placed in cl + ScopedRegX64 shiftTmp{regs, regs.takeGprReg(rcx)}; + + ScopedRegX64 tmp{regs, SizeX64::qword}; + + build.mov(inst.regX64, qword[regOp(inst.a) + offsetof(Table, node)]); + build.mov(dwordReg(tmp.reg), 1); + build.mov(byteReg(shiftTmp.reg), byte[regOp(inst.a) + offsetof(Table, lsizenode)]); + build.shl(dwordReg(tmp.reg), byteReg(shiftTmp.reg)); + build.dec(dwordReg(tmp.reg)); + build.and_(dwordReg(tmp.reg), uintOp(inst.b)); + build.shl(tmp.reg, kLuaNodeSizeLog2); + build.add(inst.regX64, tmp.reg); + break; + }; case IrCmd::STORE_TAG: LUAU_ASSERT(inst.a.kind == IrOpKind::VmReg); @@ -686,6 +709,16 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) jumpOrFallthrough(blockOp(inst.e), next); break; } + case IrCmd::JUMP_SLOT_MATCH: + { + LUAU_ASSERT(inst.b.kind == IrOpKind::VmConst); + + ScopedRegX64 tmp{regs, SizeX64::qword}; + + jumpIfNodeKeyNotInExpectedSlot(build, tmp.reg, regOp(inst.a), luauConstantValue(inst.b.index), labelOp(inst.d)); + jumpOrFallthrough(blockOp(inst.c), next); + break; + } case IrCmd::TABLE_LEN: inst.regX64 = regs.allocXmmReg(); @@ -715,7 +748,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) if (inst.regX64 != rax) build.mov(inst.regX64, rax); break; - case IrCmd::NUM_TO_INDEX: + case IrCmd::TRY_NUM_TO_INDEX: { inst.regX64 = regs.allocGprReg(SizeX64::dword); @@ -724,6 +757,16 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) convertNumberToIndexOrJump(build, tmp.reg, regOp(inst.a), inst.regX64, labelOp(inst.b)); break; } + case IrCmd::TRY_CALL_FASTGETTM: + { + inst.regX64 = regs.allocGprReg(SizeX64::qword); + + callGetFastTmOrFallback(build, regOp(inst.a), TMS(intOp(inst.b)), labelOp(inst.c)); + + if (inst.regX64 != rax) + build.mov(inst.regX64, rax); + break; + } case IrCmd::INT_TO_NUM: inst.regX64 = regs.allocXmmReg(); @@ -1017,6 +1060,9 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) jumpIfNodeKeyNotInExpectedSlot(build, tmp.reg, regOp(inst.a), luauConstantValue(inst.b.index), labelOp(inst.c)); break; } + case IrCmd::CHECK_NODE_NO_NEXT: + jumpIfNodeHasNext(build, regOp(inst.a), labelOp(inst.b)); + break; case IrCmd::INTERRUPT: emitInterrupt(build, uintOp(inst.a)); break; @@ -1114,16 +1160,6 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.setLabel(next); break; } - case IrCmd::LOP_NAMECALL: - { - const Instruction* pc = proto->code + uintOp(inst.a); - LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); - LUAU_ASSERT(inst.c.kind == IrOpKind::VmReg); - - emitInstNameCall(build, pc, uintOp(inst.a), proto->k, blockOp(inst.d).label, blockOp(inst.e).label); - jumpOrFallthrough(blockOp(inst.d), next); - break; - } case IrCmd::LOP_CALL: { const Instruction* pc = proto->code + uintOp(inst.a); diff --git a/CodeGen/src/IrTranslateBuiltins.cpp b/CodeGen/src/IrTranslateBuiltins.cpp index bc909105d..d9f935c49 100644 --- a/CodeGen/src/IrTranslateBuiltins.cpp +++ b/CodeGen/src/IrTranslateBuiltins.cpp @@ -210,6 +210,34 @@ BuiltinImplResult translateBuiltinMathClamp(IrBuilder& build, int nparams, int r return {BuiltinImplType::UsesFallback, 1}; } +BuiltinImplResult translateBuiltinType(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) +{ + if (nparams < 1 || nresults > 1) + return {BuiltinImplType::None, -1}; + + build.inst( + IrCmd::FASTCALL, build.constUint(LBF_TYPE), build.vmReg(ra), build.vmReg(arg), args, build.constInt(nparams), build.constInt(nresults)); + + // TODO: tag update might not be required, we place it here now because FASTCALL is not modeled in constant propagation yet + build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TSTRING)); + + return {BuiltinImplType::UsesFallback, 1}; +} + +BuiltinImplResult translateBuiltinTypeof(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) +{ + if (nparams < 1 || nresults > 1) + return {BuiltinImplType::None, -1}; + + build.inst( + IrCmd::FASTCALL, build.constUint(LBF_TYPEOF), build.vmReg(ra), build.vmReg(arg), args, build.constInt(nparams), build.constInt(nresults)); + + // TODO: tag update might not be required, we place it here now because FASTCALL is not modeled in constant propagation yet + build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TSTRING)); + + return {BuiltinImplType::UsesFallback, 1}; +} + BuiltinImplResult translateBuiltin(IrBuilder& build, int bfid, int ra, int arg, IrOp args, int nparams, int nresults, IrOp fallback) { switch (bfid) @@ -254,6 +282,10 @@ BuiltinImplResult translateBuiltin(IrBuilder& build, int bfid, int ra, int arg, case LBF_MATH_FREXP: case LBF_MATH_MODF: return translateBuiltinNumberTo2Number(build, LuauBuiltinFunction(bfid), nparams, ra, arg, args, nresults, fallback); + case LBF_TYPE: + return translateBuiltinType(build, nparams, ra, arg, args, nresults, fallback); + case LBF_TYPEOF: + return translateBuiltinTypeof(build, nparams, ra, arg, args, nresults, fallback); default: return {BuiltinImplType::None, -1}; } diff --git a/CodeGen/src/IrTranslation.cpp b/CodeGen/src/IrTranslation.cpp index 48ca3975b..28c6aca19 100644 --- a/CodeGen/src/IrTranslation.cpp +++ b/CodeGen/src/IrTranslation.cpp @@ -806,7 +806,7 @@ void translateInstGetTable(IrBuilder& build, const Instruction* pc, int pcpos) IrOp vb = build.inst(IrCmd::LOAD_POINTER, build.vmReg(rb)); IrOp vc = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(rc)); - IrOp index = build.inst(IrCmd::NUM_TO_INDEX, vc, fallback); + IrOp index = build.inst(IrCmd::TRY_NUM_TO_INDEX, vc, fallback); index = build.inst(IrCmd::SUB_INT, index, build.constInt(1)); @@ -843,7 +843,7 @@ void translateInstSetTable(IrBuilder& build, const Instruction* pc, int pcpos) IrOp vb = build.inst(IrCmd::LOAD_POINTER, build.vmReg(rb)); IrOp vc = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(rc)); - IrOp index = build.inst(IrCmd::NUM_TO_INDEX, vc, fallback); + IrOp index = build.inst(IrCmd::TRY_NUM_TO_INDEX, vc, fallback); index = build.inst(IrCmd::SUB_INT, index, build.constInt(1)); @@ -1035,5 +1035,63 @@ void translateInstCapture(IrBuilder& build, const Instruction* pc, int pcpos) } } +void translateInstNamecall(IrBuilder& build, const Instruction* pc, int pcpos) +{ + int ra = LUAU_INSN_A(*pc); + int rb = LUAU_INSN_B(*pc); + uint32_t aux = pc[1]; + + IrOp next = build.blockAtInst(pcpos + getOpLength(LOP_NAMECALL)); + IrOp fallback = build.block(IrBlockKind::Fallback); + IrOp firstFastPathSuccess = build.block(IrBlockKind::Internal); + IrOp secondFastPath = build.block(IrBlockKind::Internal); + + build.loadAndCheckTag(build.vmReg(rb), LUA_TTABLE, fallback); + IrOp table = build.inst(IrCmd::LOAD_POINTER, build.vmReg(rb)); + + LUAU_ASSERT(build.function.proto); + IrOp addrNodeEl = build.inst(IrCmd::GET_HASH_NODE_ADDR, table, build.constUint(tsvalue(&build.function.proto->k[aux])->hash)); + + // We use 'jump' version instead of 'check' guard because we are jumping away into a non-fallback block + // This is required by CFG live range analysis because both non-fallback blocks define the same registers + build.inst(IrCmd::JUMP_SLOT_MATCH, addrNodeEl, build.vmConst(aux), firstFastPathSuccess, secondFastPath); + + build.beginBlock(firstFastPathSuccess); + build.inst(IrCmd::STORE_POINTER, build.vmReg(ra + 1), table); + build.inst(IrCmd::STORE_TAG, build.vmReg(ra + 1), build.constTag(LUA_TTABLE)); + + IrOp nodeEl = build.inst(IrCmd::LOAD_NODE_VALUE_TV, addrNodeEl); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(ra), nodeEl); + build.inst(IrCmd::JUMP, next); + + build.beginBlock(secondFastPath); + + build.inst(IrCmd::CHECK_NODE_NO_NEXT, addrNodeEl, fallback); + + IrOp indexPtr = build.inst(IrCmd::TRY_CALL_FASTGETTM, table, build.constInt(TM_INDEX), fallback); + + build.loadAndCheckTag(indexPtr, LUA_TTABLE, fallback); + IrOp index = build.inst(IrCmd::LOAD_POINTER, indexPtr); + + IrOp addrIndexNodeEl = build.inst(IrCmd::GET_SLOT_NODE_ADDR, index, build.constUint(pcpos)); + build.inst(IrCmd::CHECK_SLOT_MATCH, addrIndexNodeEl, build.vmConst(aux), fallback); + + // TODO: original 'table' was clobbered by a call inside 'FASTGETTM' + // Ideally, such calls should have to effect on SSA IR values, but simple register allocator doesn't support it + IrOp table2 = build.inst(IrCmd::LOAD_POINTER, build.vmReg(rb)); + build.inst(IrCmd::STORE_POINTER, build.vmReg(ra + 1), table2); + build.inst(IrCmd::STORE_TAG, build.vmReg(ra + 1), build.constTag(LUA_TTABLE)); + + IrOp indexNodeEl = build.inst(IrCmd::LOAD_NODE_VALUE_TV, addrIndexNodeEl); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(ra), indexNodeEl); + build.inst(IrCmd::JUMP, next); + + build.beginBlock(fallback); + build.inst(IrCmd::FALLBACK_NAMECALL, build.constUint(pcpos), build.vmReg(ra), build.vmReg(rb), build.vmConst(aux)); + build.inst(IrCmd::JUMP, next); + + build.beginBlock(next); +} + } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/IrTranslation.h b/CodeGen/src/IrTranslation.h index 0d4a5096c..0be111dca 100644 --- a/CodeGen/src/IrTranslation.h +++ b/CodeGen/src/IrTranslation.h @@ -60,6 +60,7 @@ void translateInstGetGlobal(IrBuilder& build, const Instruction* pc, int pcpos); void translateInstSetGlobal(IrBuilder& build, const Instruction* pc, int pcpos); void translateInstConcat(IrBuilder& build, const Instruction* pc, int pcpos); void translateInstCapture(IrBuilder& build, const Instruction* pc, int pcpos); +void translateInstNamecall(IrBuilder& build, const Instruction* pc, int pcpos); } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/IrUtils.cpp b/CodeGen/src/IrUtils.cpp index d8115be9f..e29a5b029 100644 --- a/CodeGen/src/IrUtils.cpp +++ b/CodeGen/src/IrUtils.cpp @@ -360,7 +360,7 @@ void foldConstants(IrBuilder& build, IrFunction& function, IrBlock& block, uint3 replace(function, block, index, {IrCmd::JUMP, inst.e}); } break; - case IrCmd::NUM_TO_INDEX: + case IrCmd::TRY_NUM_TO_INDEX: if (inst.a.kind == IrOpKind::Constant) { double value = function.doubleOp(inst.a); diff --git a/CodeGen/src/NativeState.cpp b/CodeGen/src/NativeState.cpp index 33db54e55..f79bcab85 100644 --- a/CodeGen/src/NativeState.cpp +++ b/CodeGen/src/NativeState.cpp @@ -80,6 +80,7 @@ void initHelperFunctions(NativeState& data) data.context.luaF_close = luaF_close; data.context.luaT_gettm = luaT_gettm; + data.context.luaT_objtypenamestr = luaT_objtypenamestr; data.context.libm_exp = exp; data.context.libm_pow = pow; diff --git a/CodeGen/src/NativeState.h b/CodeGen/src/NativeState.h index ad5aca66d..bebf421b9 100644 --- a/CodeGen/src/NativeState.h +++ b/CodeGen/src/NativeState.h @@ -39,7 +39,6 @@ struct NativeProto uintptr_t* instTargets = nullptr; // TODO: NativeProto should be variable-size with all target embedded Proto* proto = nullptr; - uint32_t location = 0; }; struct NativeContext @@ -79,6 +78,7 @@ struct NativeContext void (*luaF_close)(lua_State* L, StkId level) = nullptr; const TValue* (*luaT_gettm)(Table* events, TMS event, TString* ename) = nullptr; + const TString* (*luaT_objtypenamestr)(lua_State* L, const TValue* o) = nullptr; double (*libm_exp)(double) = nullptr; double (*libm_pow)(double, double) = nullptr; diff --git a/CodeGen/src/OptimizeConstProp.cpp b/CodeGen/src/OptimizeConstProp.cpp index 956c96d63..b12a9b946 100644 --- a/CodeGen/src/OptimizeConstProp.cpp +++ b/CodeGen/src/OptimizeConstProp.cpp @@ -319,10 +319,7 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& { if (inst.b.kind == IrOpKind::Constant) { - std::optional oldValue = function.asDoubleOp(state.tryGetValue(inst.a)); - double newValue = function.doubleOp(inst.b); - - if (oldValue && *oldValue == newValue) + if (state.tryGetValue(inst.a) == inst.b) kill(function, inst); else state.saveValue(inst.a, inst.b); @@ -338,10 +335,7 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& { if (inst.b.kind == IrOpKind::Constant) { - std::optional oldValue = function.asIntOp(state.tryGetValue(inst.a)); - int newValue = function.intOp(inst.b); - - if (oldValue && *oldValue == newValue) + if (state.tryGetValue(inst.a) == inst.b) kill(function, inst); else state.saveValue(inst.a, inst.b); @@ -504,6 +498,7 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& case IrCmd::LOAD_ENV: case IrCmd::GET_ARR_ADDR: case IrCmd::GET_SLOT_NODE_ADDR: + case IrCmd::GET_HASH_NODE_ADDR: case IrCmd::STORE_NODE_VALUE_TV: case IrCmd::ADD_INT: case IrCmd::SUB_INT: @@ -519,13 +514,16 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& case IrCmd::NOT_ANY: case IrCmd::JUMP: case IrCmd::JUMP_EQ_POINTER: + case IrCmd::JUMP_SLOT_MATCH: case IrCmd::TABLE_LEN: case IrCmd::NEW_TABLE: case IrCmd::DUP_TABLE: - case IrCmd::NUM_TO_INDEX: + case IrCmd::TRY_NUM_TO_INDEX: + case IrCmd::TRY_CALL_FASTGETTM: case IrCmd::INT_TO_NUM: case IrCmd::CHECK_ARRAY_SIZE: case IrCmd::CHECK_SLOT_MATCH: + case IrCmd::CHECK_NODE_NO_NEXT: case IrCmd::BARRIER_TABLE_BACK: case IrCmd::LOP_RETURN: case IrCmd::LOP_COVERAGE: @@ -552,7 +550,6 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& case IrCmd::CONCAT: case IrCmd::PREPARE_FORN: case IrCmd::INTERRUPT: // TODO: it will be important to keep tag/value state, but we have to track register capture - case IrCmd::LOP_NAMECALL: case IrCmd::LOP_CALL: case IrCmd::LOP_FORGLOOP: case IrCmd::LOP_FORGLOOP_FALLBACK: @@ -633,7 +630,7 @@ static std::vector collectDirectBlockJumpPath(IrFunction& function, st // * if the successor has multiple uses, it can't have such 'live in' values without phi nodes that we don't have yet // Another possibility is to have two paths from 'block' into the target through two intermediate blocks // Usually that would mean that we would have a conditional jump at the end of 'block' - // But using check guards and fallback clocks it becomes a possible setup + // But using check guards and fallback blocks it becomes a possible setup // We avoid this by making sure fallbacks rejoin the other immediate successor of 'block' LUAU_ASSERT(getLiveOutValueCount(function, *block) == 0); diff --git a/CodeGen/src/UnwindBuilderDwarf2.cpp b/CodeGen/src/UnwindBuilderDwarf2.cpp index a95ed0941..0b3134ba3 100644 --- a/CodeGen/src/UnwindBuilderDwarf2.cpp +++ b/CodeGen/src/UnwindBuilderDwarf2.cpp @@ -201,6 +201,7 @@ void UnwindBuilderDwarf2::setupFrameReg(X64::RegisterX64 reg, int espOffset) void UnwindBuilderDwarf2::finish() { LUAU_ASSERT(stackOffset % 16 == 0 && "stack has to be aligned to 16 bytes after prologue"); + LUAU_ASSERT(fdeEntryStart != nullptr); pos = alignPosition(fdeEntryStart, pos); writeu32(fdeEntryStart, unsigned(pos - fdeEntryStart - 4)); // Length field itself is excluded from length @@ -220,7 +221,9 @@ void UnwindBuilderDwarf2::finalize(char* target, void* funcAddress, size_t funcS { memcpy(target, rawData, getSize()); + LUAU_ASSERT(fdeEntryStart != nullptr); unsigned fdeEntryStartPos = unsigned(fdeEntryStart - rawData); + writeu64((uint8_t*)target + fdeEntryStartPos + kFdeInitialLocationOffset, uintptr_t(funcAddress)); writeu64((uint8_t*)target + fdeEntryStartPos + kFdeAddressRangeOffset, funcSize); } diff --git a/Common/include/Luau/Bytecode.h b/Common/include/Luau/Bytecode.h index 67f9fbeab..82bf6e5a3 100644 --- a/Common/include/Luau/Bytecode.h +++ b/Common/include/Luau/Bytecode.h @@ -25,7 +25,7 @@ // Additionally, in some specific instructions such as ANDK, the limit on the encoded value is smaller; this means that if a value is larger, a different instruction must be selected. // // Registers: 0-254. Registers refer to the values on the function's stack frame, including arguments. -// Upvalues: 0-254. Upvalues refer to the values stored in the closure object. +// Upvalues: 0-199. Upvalues refer to the values stored in the closure object. // Constants: 0-2^23-1. Constants are stored in a table allocated with each proto; to allow for future bytecode tweaks the encodable value is limited to 23 bits. // Closures: 0-2^15-1. Closures are created from child protos via a child index; the limit is for the number of closures immediately referenced in each function. // Jumps: -2^23..2^23. Jump offsets are specified in word increments, so jumping over an instruction may sometimes require an offset of 2 or more. Note that for jump instructions with AUX, the AUX word is included as part of the jump offset. @@ -93,12 +93,12 @@ enum LuauOpcode // GETUPVAL: load upvalue from the upvalue table for the current function // A: target register - // B: upvalue index (0..255) + // B: upvalue index LOP_GETUPVAL, // SETUPVAL: store value into the upvalue table for the current function // A: target register - // B: upvalue index (0..255) + // B: upvalue index LOP_SETUPVAL, // CLOSEUPVALS: close (migrate to heap) all upvalues that were captured for registers >= target diff --git a/Common/include/Luau/ExperimentalFlags.h b/Common/include/Luau/ExperimentalFlags.h index 35e11ca50..8eca1050a 100644 --- a/Common/include/Luau/ExperimentalFlags.h +++ b/Common/include/Luau/ExperimentalFlags.h @@ -11,8 +11,9 @@ inline bool isFlagExperimental(const char* flag) // Flags in this list are disabled by default in various command-line tools. They may have behavior that is not fully final, // or critical bugs that are found after the code has been submitted. static const char* const kList[] = { - "LuauInstantiateInSubtyping", // requires some fixes to lua-apps code - "LuauTypecheckTypeguards", // requires some fixes to lua-apps code (CLI-67030) + "LuauInstantiateInSubtyping", // requires some fixes to lua-apps code + "LuauTypecheckTypeguards", // requires some fixes to lua-apps code (CLI-67030) + "LuauTinyControlFlowAnalysis", // waiting for updates to packages depended by internal builtin plugins // makes sure we always have at least one entry nullptr, }; diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index 78896d311..03f4b3e69 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -25,7 +25,6 @@ LUAU_FASTINTVARIABLE(LuauCompileInlineThreshold, 25) LUAU_FASTINTVARIABLE(LuauCompileInlineThresholdMaxBoost, 300) LUAU_FASTINTVARIABLE(LuauCompileInlineDepth, 5) -LUAU_FASTFLAGVARIABLE(LuauCompileTerminateBC, false) LUAU_FASTFLAGVARIABLE(LuauCompileBuiltinArity, false) namespace Luau @@ -143,7 +142,7 @@ struct Compiler return stat->body.size > 0 && alwaysTerminates(stat->body.data[stat->body.size - 1]); else if (node->is()) return true; - else if (FFlag::LuauCompileTerminateBC && (node->is() || node->is())) + else if (node->is() || node->is()) return true; else if (AstStatIf* stat = node->as()) return stat->elsebody && alwaysTerminates(stat->thenbody) && alwaysTerminates(stat->elsebody); diff --git a/Makefile b/Makefile index 66d6016d5..585122938 100644 --- a/Makefile +++ b/Makefile @@ -143,6 +143,9 @@ aliases: $(EXECUTABLE_ALIASES) test: $(TESTS_TARGET) $(TESTS_TARGET) $(TESTS_ARGS) +conformance: $(TESTS_TARGET) + $(TESTS_TARGET) $(TESTS_ARGS) -ts=Conformance + clean: rm -rf $(BUILD) rm -rf $(EXECUTABLE_ALIASES) diff --git a/Sources.cmake b/Sources.cmake index 88c6e9b63..6e0a32ed7 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -135,6 +135,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/include/Luau/Constraint.h Analysis/include/Luau/ConstraintGraphBuilder.h Analysis/include/Luau/ConstraintSolver.h + Analysis/include/Luau/ControlFlow.h Analysis/include/Luau/DataFlowGraph.h Analysis/include/Luau/DcrLogger.h Analysis/include/Luau/Def.h @@ -370,6 +371,7 @@ if(TARGET Luau.UnitTest) tests/TypeInfer.annotations.test.cpp tests/TypeInfer.anyerror.test.cpp tests/TypeInfer.builtins.test.cpp + tests/TypeInfer.cfa.test.cpp tests/TypeInfer.classes.test.cpp tests/TypeInfer.definitions.test.cpp tests/TypeInfer.functions.test.cpp diff --git a/VM/src/ltable.cpp b/VM/src/ltable.cpp index 8d59ecbc8..5eceea746 100644 --- a/VM/src/ltable.cpp +++ b/VM/src/ltable.cpp @@ -33,6 +33,8 @@ #include +LUAU_FASTFLAGVARIABLE(LuauArrBoundResizeFix, false) + // max size of both array and hash part is 2^MAXBITS #define MAXBITS 26 #define MAXSIZE (1 << MAXBITS) @@ -454,15 +456,43 @@ static void rehash(lua_State* L, Table* t, const TValue* ek) int nasize = numusearray(t, nums); // count keys in array part int totaluse = nasize; // all those keys are integer keys totaluse += numusehash(t, nums, &nasize); // count keys in hash part + // count extra key if (ttisnumber(ek)) nasize += countint(nvalue(ek), nums); totaluse++; + // compute new size for array part int na = computesizes(nums, &nasize); int nh = totaluse - na; - // enforce the boundary invariant; for performance, only do hash lookups if we must - nasize = adjustasize(t, nasize, ek); + + if (FFlag::LuauArrBoundResizeFix) + { + // enforce the boundary invariant; for performance, only do hash lookups if we must + int nadjusted = adjustasize(t, nasize, ek); + + // count how many extra elements belong to array part instead of hash part + int aextra = nadjusted - nasize; + + if (aextra != 0) + { + // we no longer need to store those extra array elements in hash part + nh -= aextra; + + // because hash nodes are twice as large as array nodes, the memory we saved for hash parts can be used by array part + // this follows the general sparse array part optimization where array is allocated when 50% occupation is reached + nasize = nadjusted + aextra; + + // since the size was changed, it's again important to enforce the boundary invariant at the new size + nasize = adjustasize(t, nasize, ek); + } + } + else + { + // enforce the boundary invariant; for performance, only do hash lookups if we must + nasize = adjustasize(t, nasize, ek); + } + // resize the table to new computed sizes resize(L, t, nasize, nh); } diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index 135a555ab..c9d0c01d1 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -1677,8 +1677,6 @@ RETURN R0 0 TEST_CASE("LoopBreak") { - ScopedFastFlag sff("LuauCompileTerminateBC", true); - // default codegen: compile breaks as unconditional jumps CHECK_EQ("\n" + compileFunction0("while true do if math.random() < 0.5 then break else end end"), R"( L0: GETIMPORT R0 2 [math.random] @@ -1703,8 +1701,6 @@ L1: RETURN R0 0 TEST_CASE("LoopContinue") { - ScopedFastFlag sff("LuauCompileTerminateBC", true); - // default codegen: compile continue as unconditional jumps CHECK_EQ("\n" + compileFunction0("repeat if math.random() < 0.5 then continue else end break until false error()"), R"( L0: GETIMPORT R0 2 [math.random] @@ -2214,6 +2210,46 @@ TEST_CASE("RecursionParse") { CHECK_EQ(std::string(e.what()), "Exceeded allowed recursion depth; simplify your block to make the code compile"); } + + try + { + Luau::compileOrThrow(bcb, "local f: " + rep("(", 1500) + "nil" + rep(")", 1500)); + CHECK(!"Expected exception"); + } + catch (std::exception& e) + { + CHECK_EQ(std::string(e.what()), "Exceeded allowed recursion depth; simplify your type annotation to make the code compile"); + } + + try + { + Luau::compileOrThrow(bcb, "local f: () " + rep("-> ()", 1500)); + CHECK(!"Expected exception"); + } + catch (std::exception& e) + { + CHECK_EQ(std::string(e.what()), "Exceeded allowed recursion depth; simplify your type annotation to make the code compile"); + } + + try + { + Luau::compileOrThrow(bcb, "local f: " + rep("{x:", 1500) + "nil" + rep("}", 1500)); + CHECK(!"Expected exception"); + } + catch (std::exception& e) + { + CHECK_EQ(std::string(e.what()), "Exceeded allowed recursion depth; simplify your type annotation to make the code compile"); + } + + try + { + Luau::compileOrThrow(bcb, "local f: " + rep("(nil & ", 1500) + "nil" + rep(")", 1500)); + CHECK(!"Expected exception"); + } + catch (std::exception& e) + { + CHECK_EQ(std::string(e.what()), "Exceeded allowed recursion depth; simplify your type annotation to make the code compile"); + } } TEST_CASE("ArrayIndexLiteral") @@ -6816,8 +6852,6 @@ RETURN R0 0 TEST_CASE("ElideJumpAfterIf") { - ScopedFastFlag sff("LuauCompileTerminateBC", true); - // break refers to outer loop => we can elide unconditional branches CHECK_EQ("\n" + compileFunction0(R"( local foo, bar = ... diff --git a/tests/ConstraintGraphBuilderFixture.cpp b/tests/ConstraintGraphBuilderFixture.cpp index 81e5c41b7..cc239b7ec 100644 --- a/tests/ConstraintGraphBuilderFixture.cpp +++ b/tests/ConstraintGraphBuilderFixture.cpp @@ -13,7 +13,7 @@ ConstraintGraphBuilderFixture::ConstraintGraphBuilderFixture() { mainModule->reduction = std::make_unique(NotNull{&mainModule->internalTypes}, builtinTypes, NotNull{&ice}); - BlockedType::nextIndex = 0; + BlockedType::DEPRECATED_nextIndex = 0; BlockedTypePack::nextIndex = 0; } diff --git a/tests/Fixture.cpp b/tests/Fixture.cpp index a9c94eefc..4d2e83fc2 100644 --- a/tests/Fixture.cpp +++ b/tests/Fixture.cpp @@ -137,7 +137,7 @@ const Config& TestConfigResolver::getConfig(const ModuleName& name) const Fixture::Fixture(bool freeze, bool prepareAutocomplete) : sff_DebugLuauFreezeArena("DebugLuauFreezeArena", freeze) , frontend(&fileResolver, &configResolver, - {/* retainFullTypeGraphs= */ true, /* forAutocomplete */ false, /* randomConstraintResolutionSeed */ randomSeed}) + {/* retainFullTypeGraphs= */ true, /* forAutocomplete */ false, /* runLintChecks */ false, /* randomConstraintResolutionSeed */ randomSeed}) , builtinTypes(frontend.builtinTypes) { configResolver.defaultConfig.mode = Mode::Strict; @@ -173,15 +173,19 @@ AstStatBlock* Fixture::parse(const std::string& source, const ParseOptions& pars // if AST is available, check how lint and typecheck handle error nodes if (result.root) { - frontend.lint(*sourceModule); - if (FFlag::DebugLuauDeferredConstraintResolution) { - Luau::check(*sourceModule, {}, builtinTypes, NotNull{&ice}, NotNull{&moduleResolver}, NotNull{&fileResolver}, + ModulePtr module = Luau::check(*sourceModule, {}, builtinTypes, NotNull{&ice}, NotNull{&moduleResolver}, NotNull{&fileResolver}, frontend.globals.globalScope, frontend.options); + + Luau::lint(sourceModule->root, *sourceModule->names, frontend.globals.globalScope, module.get(), sourceModule->hotcomments, {}); } else - frontend.typeChecker.check(*sourceModule, sourceModule->mode.value_or(Luau::Mode::Nonstrict)); + { + ModulePtr module = frontend.typeChecker.check(*sourceModule, sourceModule->mode.value_or(Luau::Mode::Nonstrict)); + + Luau::lint(sourceModule->root, *sourceModule->names, frontend.globals.globalScope, module.get(), sourceModule->hotcomments, {}); + } } throw ParseErrors(result.errors); @@ -209,20 +213,23 @@ CheckResult Fixture::check(const std::string& source) LintResult Fixture::lint(const std::string& source, const std::optional& lintOptions) { - ParseOptions parseOptions; - parseOptions.captureComments = true; - configResolver.defaultConfig.mode = Mode::Nonstrict; - parse(source, parseOptions); + ModuleName mm = fromString(mainModuleName); + configResolver.defaultConfig.mode = Mode::Strict; + fileResolver.source[mm] = std::move(source); + frontend.markDirty(mm); - return frontend.lint(*sourceModule, lintOptions); + return lintModule(mm); } -LintResult Fixture::lintTyped(const std::string& source, const std::optional& lintOptions) +LintResult Fixture::lintModule(const ModuleName& moduleName, const std::optional& lintOptions) { - check(source); - ModuleName mm = fromString(mainModuleName); + FrontendOptions options = frontend.options; + options.runLintChecks = true; + options.enabledLintWarnings = lintOptions; + + CheckResult result = frontend.check(moduleName, options); - return frontend.lint(mm, lintOptions); + return result.lintResult; } ParseResult Fixture::parseEx(const std::string& source, const ParseOptions& options) diff --git a/tests/Fixture.h b/tests/Fixture.h index 5db6ed165..4c49593cc 100644 --- a/tests/Fixture.h +++ b/tests/Fixture.h @@ -66,7 +66,7 @@ struct Fixture CheckResult check(const std::string& source); LintResult lint(const std::string& source, const std::optional& lintOptions = {}); - LintResult lintTyped(const std::string& source, const std::optional& lintOptions = {}); + LintResult lintModule(const ModuleName& moduleName, const std::optional& lintOptions = {}); /// Parse with all language extensions enabled ParseResult parseEx(const std::string& source, const ParseOptions& parseOptions = {}); @@ -94,6 +94,7 @@ struct Fixture TypeId requireTypeAlias(const std::string& name); ScopedFastFlag sff_DebugLuauFreezeArena; + ScopedFastFlag luauLintInTypecheck{"LuauLintInTypecheck", true}; TestFileResolver fileResolver; TestConfigResolver configResolver; diff --git a/tests/Frontend.test.cpp b/tests/Frontend.test.cpp index e09990fb8..3b1ec4ad1 100644 --- a/tests/Frontend.test.cpp +++ b/tests/Frontend.test.cpp @@ -456,16 +456,16 @@ TEST_CASE_FIXTURE(FrontendFixture, "dont_reparse_clean_file_when_linting") end )"; - frontend.check("Modules/A"); + configResolver.defaultConfig.enabledLint.enableWarning(LintWarning::Code_ForRange); + + lintModule("Modules/A"); fileResolver.source["Modules/A"] = R"( -- We have fixed the lint error, but we did not tell the Frontend that the file is changed! - -- Therefore, we expect Frontend to reuse the parse tree. + -- Therefore, we expect Frontend to reuse the results from previous lint. )"; - configResolver.defaultConfig.enabledLint.enableWarning(LintWarning::Code_ForRange); - - LintResult lintResult = frontend.lint("Modules/A"); + LintResult lintResult = lintModule("Modules/A"); CHECK_EQ(1, lintResult.warnings.size()); } @@ -760,25 +760,49 @@ TEST_CASE_FIXTURE(FrontendFixture, "test_lint_uses_correct_config") configResolver.configFiles["Module/A"].enabledLint.enableWarning(LintWarning::Code_ForRange); - auto result = frontend.lint("Module/A"); + auto result = lintModule("Module/A"); CHECK_EQ(1, result.warnings.size()); configResolver.configFiles["Module/A"].enabledLint.disableWarning(LintWarning::Code_ForRange); + frontend.markDirty("Module/A"); - auto result2 = frontend.lint("Module/A"); + auto result2 = lintModule("Module/A"); CHECK_EQ(0, result2.warnings.size()); LintOptions overrideOptions; overrideOptions.enableWarning(LintWarning::Code_ForRange); - auto result3 = frontend.lint("Module/A", overrideOptions); + frontend.markDirty("Module/A"); + + auto result3 = lintModule("Module/A", overrideOptions); CHECK_EQ(1, result3.warnings.size()); overrideOptions.disableWarning(LintWarning::Code_ForRange); - auto result4 = frontend.lint("Module/A", overrideOptions); + frontend.markDirty("Module/A"); + + auto result4 = lintModule("Module/A", overrideOptions); CHECK_EQ(0, result4.warnings.size()); } +TEST_CASE_FIXTURE(FrontendFixture, "lint_results_are_only_for_checked_module") +{ + fileResolver.source["Module/A"] = R"( +local _ = 0b10000000000000000000000000000000000000000000000000000000000000000 + )"; + + fileResolver.source["Module/B"] = R"( +require(script.Parent.A) +local _ = 0x10000000000000000 + )"; + + LintResult lintResult = lintModule("Module/B"); + CHECK_EQ(1, lintResult.warnings.size()); + + // Check cached result + lintResult = lintModule("Module/B"); + CHECK_EQ(1, lintResult.warnings.size()); +} + TEST_CASE_FIXTURE(FrontendFixture, "discard_type_graphs") { Frontend fe{&fileResolver, &configResolver, {false}}; diff --git a/tests/IrBuilder.test.cpp b/tests/IrBuilder.test.cpp index 41146d77a..37c12dc97 100644 --- a/tests/IrBuilder.test.cpp +++ b/tests/IrBuilder.test.cpp @@ -62,22 +62,16 @@ class IrBuilderFixture build.inst(IrCmd::LOP_RETURN, build.constUint(2)); }; - void checkEq(IrOp lhs, IrOp rhs) - { - CHECK_EQ(lhs.kind, rhs.kind); - LUAU_ASSERT(lhs.kind != IrOpKind::Constant && "can't compare constants, each ref is unique"); - CHECK_EQ(lhs.index, rhs.index); - } - void checkEq(IrOp instOp, const IrInst& inst) { const IrInst& target = build.function.instOp(instOp); CHECK(target.cmd == inst.cmd); - checkEq(target.a, inst.a); - checkEq(target.b, inst.b); - checkEq(target.c, inst.c); - checkEq(target.d, inst.d); - checkEq(target.e, inst.e); + CHECK(target.a == inst.a); + CHECK(target.b == inst.b); + CHECK(target.c == inst.c); + CHECK(target.d == inst.d); + CHECK(target.e == inst.e); + CHECK(target.f == inst.f); } IrBuilder build; @@ -405,18 +399,18 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "ControlFlowEq") TEST_CASE_FIXTURE(IrBuilderFixture, "NumToIndex") { withOneBlock([this](IrOp a) { - build.inst(IrCmd::STORE_INT, build.vmReg(0), build.inst(IrCmd::NUM_TO_INDEX, build.constDouble(4), a)); + build.inst(IrCmd::STORE_INT, build.vmReg(0), build.inst(IrCmd::TRY_NUM_TO_INDEX, build.constDouble(4), a)); build.inst(IrCmd::LOP_RETURN, build.constUint(0)); }); withOneBlock([this](IrOp a) { - build.inst(IrCmd::STORE_INT, build.vmReg(0), build.inst(IrCmd::NUM_TO_INDEX, build.constDouble(1.2), a)); + build.inst(IrCmd::STORE_INT, build.vmReg(0), build.inst(IrCmd::TRY_NUM_TO_INDEX, build.constDouble(1.2), a)); build.inst(IrCmd::LOP_RETURN, build.constUint(0)); }); withOneBlock([this](IrOp a) { IrOp nan = build.inst(IrCmd::DIV_NUM, build.constDouble(0.0), build.constDouble(0.0)); - build.inst(IrCmd::STORE_INT, build.vmReg(0), build.inst(IrCmd::NUM_TO_INDEX, nan, a)); + build.inst(IrCmd::STORE_INT, build.vmReg(0), build.inst(IrCmd::TRY_NUM_TO_INDEX, nan, a)); build.inst(IrCmd::LOP_RETURN, build.constUint(0)); }); @@ -1676,4 +1670,64 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "FallbackDoesNotFlowUp") )"); } +TEST_CASE_FIXTURE(IrBuilderFixture, "VariadicSequencePeeling") +{ + IrOp entry = build.block(IrBlockKind::Internal); + IrOp a = build.block(IrBlockKind::Internal); + IrOp b = build.block(IrBlockKind::Internal); + IrOp exit = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::FALLBACK_GETVARARGS, build.constUint(0), build.vmReg(3), build.constInt(-1)); + build.inst(IrCmd::JUMP_EQ_TAG, build.inst(IrCmd::LOAD_TAG, build.vmReg(0)), build.constTag(tnumber), a, b); + + build.beginBlock(a); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(2), build.inst(IrCmd::LOAD_TVALUE, build.vmReg(0))); + build.inst(IrCmd::JUMP, exit); + + build.beginBlock(b); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(2), build.inst(IrCmd::LOAD_TVALUE, build.vmReg(1))); + build.inst(IrCmd::JUMP, exit); + + build.beginBlock(exit); + build.inst(IrCmd::LOP_RETURN, build.constUint(0), build.vmReg(2), build.constInt(-1)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( +bb_0: +; successors: bb_1, bb_2 +; in regs: R0, R1 +; out regs: R0, R1, R3... + FALLBACK_GETVARARGS 0u, R3, -1i + %1 = LOAD_TAG R0 + JUMP_EQ_TAG %1, tnumber, bb_1, bb_2 + +bb_1: +; predecessors: bb_0 +; successors: bb_3 +; in regs: R0, R3... +; out regs: R2... + %3 = LOAD_TVALUE R0 + STORE_TVALUE R2, %3 + JUMP bb_3 + +bb_2: +; predecessors: bb_0 +; successors: bb_3 +; in regs: R1, R3... +; out regs: R2... + %6 = LOAD_TVALUE R1 + STORE_TVALUE R2, %6 + JUMP bb_3 + +bb_3: +; predecessors: bb_1, bb_2 +; in regs: R2... + LOP_RETURN 0u, R2, -1i + +)"); +} + TEST_SUITE_END(); diff --git a/tests/Linter.test.cpp b/tests/Linter.test.cpp index ebd004d38..0f1346161 100644 --- a/tests/Linter.test.cpp +++ b/tests/Linter.test.cpp @@ -733,6 +733,7 @@ end TEST_CASE_FIXTURE(Fixture, "ImplicitReturn") { LintResult result = lint(R"( +--!nonstrict function f1(a) if not a then return 5 @@ -789,20 +790,21 @@ return f1,f2,f3,f4,f5,f6,f7 )"); REQUIRE(3 == result.warnings.size()); - CHECK_EQ(result.warnings[0].location.begin.line, 4); + CHECK_EQ(result.warnings[0].location.begin.line, 5); CHECK_EQ(result.warnings[0].text, - "Function 'f1' can implicitly return no values even though there's an explicit return at line 4; add explicit return to silence"); - CHECK_EQ(result.warnings[1].location.begin.line, 28); + "Function 'f1' can implicitly return no values even though there's an explicit return at line 5; add explicit return to silence"); + CHECK_EQ(result.warnings[1].location.begin.line, 29); CHECK_EQ(result.warnings[1].text, - "Function 'f4' can implicitly return no values even though there's an explicit return at line 25; add explicit return to silence"); - CHECK_EQ(result.warnings[2].location.begin.line, 44); + "Function 'f4' can implicitly return no values even though there's an explicit return at line 26; add explicit return to silence"); + CHECK_EQ(result.warnings[2].location.begin.line, 45); CHECK_EQ(result.warnings[2].text, - "Function can implicitly return no values even though there's an explicit return at line 44; add explicit return to silence"); + "Function can implicitly return no values even though there's an explicit return at line 45; add explicit return to silence"); } TEST_CASE_FIXTURE(Fixture, "ImplicitReturnInfiniteLoop") { LintResult result = lint(R"( +--!nonstrict function f1(a) while true do if math.random() > 0.5 then @@ -845,12 +847,12 @@ return f1,f2,f3,f4 )"); REQUIRE(2 == result.warnings.size()); - CHECK_EQ(result.warnings[0].location.begin.line, 25); + CHECK_EQ(result.warnings[0].location.begin.line, 26); CHECK_EQ(result.warnings[0].text, - "Function 'f3' can implicitly return no values even though there's an explicit return at line 21; add explicit return to silence"); - CHECK_EQ(result.warnings[1].location.begin.line, 36); + "Function 'f3' can implicitly return no values even though there's an explicit return at line 22; add explicit return to silence"); + CHECK_EQ(result.warnings[1].location.begin.line, 37); CHECK_EQ(result.warnings[1].text, - "Function 'f4' can implicitly return no values even though there's an explicit return at line 32; add explicit return to silence"); + "Function 'f4' can implicitly return no values even though there's an explicit return at line 33; add explicit return to silence"); } TEST_CASE_FIXTURE(Fixture, "TypeAnnotationsShouldNotProduceWarnings") @@ -1164,7 +1166,7 @@ os.date("!*t") TEST_CASE_FIXTURE(Fixture, "FormatStringTyped") { - LintResult result = lintTyped(R"~( + LintResult result = lint(R"~( local s: string, nons = ... string.match(s, "[]") @@ -1285,7 +1287,7 @@ TEST_CASE_FIXTURE(Fixture, "use_all_parent_scopes_for_globals") local _bar: typeof(os.clock) = os.clock )"; - LintResult result = frontend.lint("A"); + LintResult result = lintModule("A"); REQUIRE(0 == result.warnings.size()); } @@ -1471,7 +1473,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "DeprecatedApiTyped") freeze(frontend.globals.globalTypes); - LintResult result = lintTyped(R"( + LintResult result = lint(R"( return function (i: Instance) i:Wait(1.0) print(i.Name) @@ -1518,7 +1520,7 @@ end TEST_CASE_FIXTURE(BuiltinsFixture, "TableOperations") { - LintResult result = lintTyped(R"( + LintResult result = lint(R"( local t = {} local tt = {} diff --git a/tests/Normalize.test.cpp b/tests/Normalize.test.cpp index b86af0ebc..384a39fea 100644 --- a/tests/Normalize.test.cpp +++ b/tests/Normalize.test.cpp @@ -28,6 +28,18 @@ struct IsSubtypeFixture : Fixture return ::Luau::isSubtype(a, b, NotNull{module->getModuleScope().get()}, builtinTypes, ice); } + + bool isConsistentSubtype(TypeId a, TypeId b) + { + Location location; + ModulePtr module = getMainModule(); + REQUIRE(module); + + if (!module->hasModuleScope()) + FAIL("isSubtype: module scope data is not available"); + + return ::Luau::isConsistentSubtype(a, b, NotNull{module->getModuleScope().get()}, builtinTypes, ice); + } }; } // namespace @@ -86,8 +98,8 @@ TEST_CASE_FIXTURE(IsSubtypeFixture, "functions_and_any") // any makes things work even when it makes no sense. - CHECK(isSubtype(b, a)); - CHECK(isSubtype(a, b)); + CHECK(isConsistentSubtype(b, a)); + CHECK(isConsistentSubtype(a, b)); } TEST_CASE_FIXTURE(IsSubtypeFixture, "variadic_functions_with_no_head") @@ -163,6 +175,10 @@ TEST_CASE_FIXTURE(IsSubtypeFixture, "table_with_union_prop") TEST_CASE_FIXTURE(IsSubtypeFixture, "table_with_any_prop") { + ScopedFastFlag sffs[] = { + {"LuauTransitiveSubtyping", true}, + }; + check(R"( local a: {x: number} local b: {x: any} @@ -172,7 +188,8 @@ TEST_CASE_FIXTURE(IsSubtypeFixture, "table_with_any_prop") TypeId b = requireType("b"); CHECK(isSubtype(a, b)); - CHECK(isSubtype(b, a)); + CHECK(!isSubtype(b, a)); + CHECK(isConsistentSubtype(b, a)); } TEST_CASE_FIXTURE(IsSubtypeFixture, "intersection") @@ -216,6 +233,10 @@ TEST_CASE_FIXTURE(IsSubtypeFixture, "union_and_intersection") TEST_CASE_FIXTURE(IsSubtypeFixture, "tables") { + ScopedFastFlag sffs[] = { + {"LuauTransitiveSubtyping", true}, + }; + check(R"( local a: {x: number} local b: {x: any} @@ -229,7 +250,8 @@ TEST_CASE_FIXTURE(IsSubtypeFixture, "tables") TypeId d = requireType("d"); CHECK(isSubtype(a, b)); - CHECK(isSubtype(b, a)); + CHECK(!isSubtype(b, a)); + CHECK(isConsistentSubtype(b, a)); CHECK(!isSubtype(c, a)); CHECK(!isSubtype(a, c)); @@ -358,6 +380,92 @@ TEST_CASE_FIXTURE(IsSubtypeFixture, "metatable" * doctest::expected_failures{1}) } #endif +TEST_CASE_FIXTURE(IsSubtypeFixture, "any_is_unknown_union_error") +{ + ScopedFastFlag sffs[] = { + {"LuauTransitiveSubtyping", true}, + }; + + check(R"( + local err = 5.nope.nope -- err is now an error type + local a : any + local b : (unknown | typeof(err)) + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + + CHECK(isSubtype(a, b)); + CHECK(isSubtype(b, a)); + CHECK_EQ("*error-type*", toString(requireType("err"))); +} + +TEST_CASE_FIXTURE(IsSubtypeFixture, "any_intersect_T_is_T") +{ + ScopedFastFlag sffs[] = { + {"LuauTransitiveSubtyping", true}, + }; + + check(R"( + local a : (any & string) + local b : string + local c : number + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + TypeId c = requireType("c"); + + CHECK(isSubtype(a, b)); + CHECK(isSubtype(b, a)); + CHECK(!isSubtype(a, c)); + CHECK(!isSubtype(c, a)); +} + +TEST_CASE_FIXTURE(IsSubtypeFixture, "error_suppression") +{ + ScopedFastFlag sffs[] = { + {"LuauTransitiveSubtyping", true}, + }; + + check(""); + + TypeId any = builtinTypes->anyType; + TypeId err = builtinTypes->errorType; + TypeId str = builtinTypes->stringType; + TypeId unk = builtinTypes->unknownType; + + CHECK(!isSubtype(any, err)); + CHECK(isSubtype(err, any)); + CHECK(isConsistentSubtype(any, err)); + CHECK(isConsistentSubtype(err, any)); + + CHECK(!isSubtype(any, str)); + CHECK(isSubtype(str, any)); + CHECK(isConsistentSubtype(any, str)); + CHECK(isConsistentSubtype(str, any)); + + CHECK(!isSubtype(any, unk)); + CHECK(isSubtype(unk, any)); + CHECK(isConsistentSubtype(any, unk)); + CHECK(isConsistentSubtype(unk, any)); + + CHECK(!isSubtype(err, str)); + CHECK(!isSubtype(str, err)); + CHECK(isConsistentSubtype(err, str)); + CHECK(isConsistentSubtype(str, err)); + + CHECK(!isSubtype(err, unk)); + CHECK(!isSubtype(unk, err)); + CHECK(isConsistentSubtype(err, unk)); + CHECK(isConsistentSubtype(unk, err)); + + CHECK(isSubtype(str, unk)); + CHECK(!isSubtype(unk, str)); + CHECK(isConsistentSubtype(str, unk)); + CHECK(!isConsistentSubtype(unk, str)); +} + TEST_SUITE_END(); struct NormalizeFixture : Fixture @@ -692,4 +800,17 @@ TEST_CASE_FIXTURE(NormalizeFixture, "negations_of_tables") CHECK("table" == toString(normal("Not>"))); } +TEST_CASE_FIXTURE(NormalizeFixture, "normalize_blocked_types") +{ + ScopedFastFlag sff[] { + {"LuauNormalizeBlockedTypes", true}, + }; + + Type blocked{BlockedType{}}; + + const NormalizedType* norm = normalizer.normalize(&blocked); + + CHECK_EQ(normalizer.typeFromNormal(*norm), &blocked); +} + TEST_SUITE_END(); diff --git a/tests/RuntimeLimits.test.cpp b/tests/RuntimeLimits.test.cpp index 7e50d5b64..093570d30 100644 --- a/tests/RuntimeLimits.test.cpp +++ b/tests/RuntimeLimits.test.cpp @@ -263,9 +263,8 @@ TEST_CASE_FIXTURE(LimitFixture, "typescript_port_of_Result_type") )LUA"; CheckResult result = check(src); - CodeTooComplex ctc; - CHECK(hasError(result, &ctc)); + CHECK(hasError(result)); } TEST_SUITE_END(); diff --git a/tests/TypeInfer.anyerror.test.cpp b/tests/TypeInfer.anyerror.test.cpp index 0488196bb..c6766cada 100644 --- a/tests/TypeInfer.anyerror.test.cpp +++ b/tests/TypeInfer.anyerror.test.cpp @@ -225,7 +225,10 @@ TEST_CASE_FIXTURE(Fixture, "calling_error_type_yields_error") CHECK_EQ("unknown", err->name); - CHECK_EQ("*error-type*", toString(requireType("a"))); + if (FFlag::DebugLuauDeferredConstraintResolution) + CHECK_EQ("any", toString(requireType("a"))); + else + CHECK_EQ("*error-type*", toString(requireType("a"))); } TEST_CASE_FIXTURE(Fixture, "chain_calling_error_type_yields_error") @@ -234,7 +237,10 @@ TEST_CASE_FIXTURE(Fixture, "chain_calling_error_type_yields_error") local a = Utility.Create "Foo" {} )"); - CHECK_EQ("*error-type*", toString(requireType("a"))); + if (FFlag::DebugLuauDeferredConstraintResolution) + CHECK_EQ("any", toString(requireType("a"))); + else + CHECK_EQ("*error-type*", toString(requireType("a"))); } TEST_CASE_FIXTURE(BuiltinsFixture, "replace_every_free_type_when_unifying_a_complex_function_with_any") @@ -343,4 +349,19 @@ stat = stat and tonumber(stat) or stat LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "intersection_of_any_can_have_props") +{ + // *blocked-130* ~ hasProp any & ~(false?), "_status" + CheckResult result = check(R"( +function foo(x: any, y) + if x then + return x._status + end + return y +end +)"); + + CHECK("(any, any) -> any" == toString(requireType("foo"))); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.builtins.test.cpp b/tests/TypeInfer.builtins.test.cpp index 5318b402e..49209a4d4 100644 --- a/tests/TypeInfer.builtins.test.cpp +++ b/tests/TypeInfer.builtins.test.cpp @@ -704,11 +704,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "select_with_variadic_typepack_tail_and_strin LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::DebugLuauDeferredConstraintResolution) - CHECK_EQ("string", toString(requireType("foo"))); - else - CHECK_EQ("any", toString(requireType("foo"))); - + CHECK_EQ("any", toString(requireType("foo"))); CHECK_EQ("any", toString(requireType("bar"))); CHECK_EQ("any", toString(requireType("baz"))); CHECK_EQ("any", toString(requireType("quux"))); @@ -996,11 +992,16 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "table_freeze_is_generic") LUAU_REQUIRE_ERROR_COUNT(1, result); CHECK_EQ("Key 'b' not found in table '{| a: number |}'", toString(result.errors[0])); + CHECK(Location({13, 18}, {13, 23}) == result.errors[0].location); CHECK_EQ("number", toString(requireType("a"))); CHECK_EQ("string", toString(requireType("b"))); CHECK_EQ("boolean", toString(requireType("c"))); - CHECK_EQ("*error-type*", toString(requireType("d"))); + + if (FFlag::DebugLuauDeferredConstraintResolution) + CHECK_EQ("any", toString(requireType("d"))); + else + CHECK_EQ("*error-type*", toString(requireType("d"))); } TEST_CASE_FIXTURE(BuiltinsFixture, "set_metatable_needs_arguments") diff --git a/tests/TypeInfer.cfa.test.cpp b/tests/TypeInfer.cfa.test.cpp new file mode 100644 index 000000000..737429583 --- /dev/null +++ b/tests/TypeInfer.cfa.test.cpp @@ -0,0 +1,380 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Fixture.h" + +#include "Luau/Symbol.h" +#include "doctest.h" + +using namespace Luau; + +TEST_SUITE_BEGIN("ControlFlowAnalysis"); + +TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_return") +{ + ScopedFastFlag sff{"LuauTinyControlFlowAnalysis", true}; + + CheckResult result = check(R"( + local function f(x: string?) + if not x then + return + end + + local foo = x + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("string", toString(requireTypeAtPosition({6, 24}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_return_elif_not_y_return") +{ + ScopedFastFlag sff{"LuauTinyControlFlowAnalysis", true}; + + CheckResult result = check(R"( + local function f(x: string?, y: string?) + if not x then + return + elseif not y then + return + end + + local foo = x + local bar = y + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("string", toString(requireTypeAtPosition({8, 24}))); + CHECK_EQ("string", toString(requireTypeAtPosition({9, 24}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_return_elif_rand_return_elif_not_y_return") +{ + ScopedFastFlag sff{"LuauTinyControlFlowAnalysis", true}; + + CheckResult result = check(R"( + local function f(x: string?, y: string?) + if not x then + return + elseif math.random() > 0.5 then + return + elseif not y then + return + end + + local foo = x + local bar = y + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("string", toString(requireTypeAtPosition({10, 24}))); + CHECK_EQ("string", toString(requireTypeAtPosition({11, 24}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_return_elif_not_rand_return_elif_not_y_fallthrough") +{ + ScopedFastFlag sff{"LuauTinyControlFlowAnalysis", true}; + + CheckResult result = check(R"( + local function f(x: string?, y: string?) + if not x then + return + elseif math.random() > 0.5 then + return + elseif not y then + + end + + local foo = x + local bar = y + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("string", toString(requireTypeAtPosition({10, 24}))); + CHECK_EQ("string?", toString(requireTypeAtPosition({11, 24}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_return_elif_not_y_fallthrough_elif_not_z_return") +{ + ScopedFastFlag sff{"LuauTinyControlFlowAnalysis", true}; + + CheckResult result = check(R"( + local function f(x: string?, y: string?, z: string?) + if not x then + return + elseif not y then + + elseif not z then + return + end + + local foo = x + local bar = y + local baz = z + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("string", toString(requireTypeAtPosition({10, 24}))); + CHECK_EQ("string?", toString(requireTypeAtPosition({11, 24}))); + CHECK_EQ("string?", toString(requireTypeAtPosition({12, 24}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "do_if_not_x_return") +{ + ScopedFastFlag sff{"LuauTinyControlFlowAnalysis", true}; + + CheckResult result = check(R"( + local function f(x: string?) + do + if not x then + return + end + end + + local foo = x + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("string", toString(requireTypeAtPosition({8, 24}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "early_return_in_a_loop_which_isnt_guaranteed_to_run_first") +{ + ScopedFastFlag sff{"LuauTinyControlFlowAnalysis", true}; + + CheckResult result = check(R"( + local function f(x: string?) + while math.random() > 0.5 do + if not x then + return + end + + local foo = x + end + + local bar = x + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("string", toString(requireTypeAtPosition({7, 28}))); + CHECK_EQ("string?", toString(requireTypeAtPosition({10, 24}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "early_return_in_a_loop_which_is_guaranteed_to_run_first") +{ + ScopedFastFlag sff{"LuauTinyControlFlowAnalysis", true}; + + CheckResult result = check(R"( + local function f(x: string?) + repeat + if not x then + return + end + + local foo = x + until math.random() > 0.5 + + local bar = x + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("string", toString(requireTypeAtPosition({7, 28}))); + CHECK_EQ("string?", toString(requireTypeAtPosition({10, 24}))); // TODO: This is wrong, should be `string`. +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "early_return_in_a_loop_which_is_guaranteed_to_run_first_2") +{ + ScopedFastFlag sff{"LuauTinyControlFlowAnalysis", true}; + + CheckResult result = check(R"( + local function f(x: string?) + for i = 1, 10 do + if not x then + return + end + + local foo = x + end + + local bar = x + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("string", toString(requireTypeAtPosition({7, 28}))); + CHECK_EQ("string?", toString(requireTypeAtPosition({10, 24}))); // TODO: This is wrong, should be `string`. +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_then_error") +{ + ScopedFastFlag sff{"LuauTinyControlFlowAnalysis", true}; + + CheckResult result = check(R"( + local function f(x: string?) + if not x then + error("oops") + end + + local foo = x + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("string", toString(requireTypeAtPosition({6, 24}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_then_assert_false") +{ + ScopedFastFlag sff{"LuauTinyControlFlowAnalysis", true}; + + CheckResult result = check(R"( + local function f(x: string?) + if not x then + assert(false) + end + + local foo = x + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("string", toString(requireTypeAtPosition({6, 24}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_return_if_not_y_return") +{ + ScopedFastFlag sff{"LuauTinyControlFlowAnalysis", true}; + + CheckResult result = check(R"( + local function f(x: string?, y: string?) + if not x then + return + end + + if not y then + return + end + + local foo = x + local bar = y + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("string", toString(requireTypeAtPosition({10, 24}))); + CHECK_EQ("string", toString(requireTypeAtPosition({11, 24}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "type_alias_does_not_leak_out") +{ + ScopedFastFlag sff{"LuauTinyControlFlowAnalysis", true}; + + CheckResult result = check(R"( + local function f(x: string?) + if typeof(x) == "string" then + return + else + type Foo = number + end + + local foo: Foo = x + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ("Unknown type 'Foo'", toString(result.errors[0])); + + CHECK_EQ("nil", toString(requireTypeAtPosition({8, 29}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "prototyping_and_visiting_alias_has_the_same_scope") +{ + ScopedFastFlag sff{"LuauTinyControlFlowAnalysis", true}; + + // In CGB, we walk the block to prototype aliases. We then visit the block in-order, which will resolve the prototype to a real type. + // That second walk assumes that the name occurs in the same `Scope` that the prototype walk had. If we arbitrarily change scope midway + // through, we'd invoke UB. + CheckResult result = check(R"( + local function f(x: string?) + type Foo = number + + if typeof(x) == "string" then + return + end + + local foo: Foo = x + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ("Type 'nil' could not be converted into 'number'", toString(result.errors[0])); + + CHECK_EQ("nil", toString(requireTypeAtPosition({8, 29}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "tagged_unions") +{ + ScopedFastFlag sff{"LuauTinyControlFlowAnalysis", true}; + + CheckResult result = check(R"( + type Ok = { tag: "ok", value: T } + type Err = { tag: "err", error: E } + type Result = Ok | Err + + local function map(result: Result, f: (T) -> U): Result + if result.tag == "ok" then + local tag = result.tag + local val = result.value + + return { tag = "ok", value = f(result.value) } + end + + local tag = result.tag + local err = result.error + + return result + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("\"ok\"", toString(requireTypeAtPosition({7, 35}))); + CHECK_EQ("T", toString(requireTypeAtPosition({8, 35}))); + + CHECK_EQ("\"err\"", toString(requireTypeAtPosition({13, 31}))); + CHECK_EQ("E", toString(requireTypeAtPosition({14, 31}))); + + if (FFlag::DebugLuauDeferredConstraintResolution) + CHECK_EQ("{| error: E, tag: \"err\" |}", toString(requireTypeAtPosition({16, 19}))); + else + CHECK_EQ("Err", toString(requireTypeAtPosition({16, 19}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "do_assert_x") +{ + ScopedFastFlag sff{"LuauTinyControlFlowAnalysis", true}; + + CheckResult result = check(R"( + local function f(x: string?) + do + assert(x) + end + + local foo = x + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("string", toString(requireTypeAtPosition({6, 24}))); +} + +TEST_SUITE_END(); diff --git a/tests/TypeInfer.loops.test.cpp b/tests/TypeInfer.loops.test.cpp index 50e9f802f..511cbc763 100644 --- a/tests/TypeInfer.loops.test.cpp +++ b/tests/TypeInfer.loops.test.cpp @@ -689,4 +689,22 @@ TEST_CASE_FIXTURE(Fixture, "for_loop_lower_bound_is_string_3") LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(BuiltinsFixture, "cli_68448_iterators_need_not_accept_nil") +{ + CheckResult result = check(R"( + local function makeEnum(members) + local enum = {} + for _, memberName in ipairs(members) do + enum[memberName] = memberName + end + return enum + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + // HACK (CLI-68453): We name this inner table `enum`. For now, use the + // exhaustive switch to see past it. + CHECK(toString(requireType("makeEnum"), {true}) == "({a}) -> {| [a]: a |}"); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.modules.test.cpp b/tests/TypeInfer.modules.test.cpp index ab07ee2da..8670729a7 100644 --- a/tests/TypeInfer.modules.test.cpp +++ b/tests/TypeInfer.modules.test.cpp @@ -485,6 +485,8 @@ return unpack(l0[_]) TEST_CASE_FIXTURE(BuiltinsFixture, "check_imported_module_names") { + ScopedFastFlag sff{"LuauTinyControlFlowAnalysis", true}; + fileResolver.source["game/A"] = R"( return function(...) end )"; @@ -506,19 +508,10 @@ return l0 ModulePtr mod = getMainModule(); REQUIRE(mod); - if (FFlag::DebugLuauDeferredConstraintResolution) - { - REQUIRE(mod->scopes.size() >= 4); - CHECK(mod->scopes[0].second->importedModules["l0"] == "game/B"); - CHECK(mod->scopes[3].second->importedModules["l1"] == "game/A"); - } - else - { - REQUIRE(mod->scopes.size() >= 3); - CHECK(mod->scopes[0].second->importedModules["l0"] == "game/B"); - CHECK(mod->scopes[2].second->importedModules["l1"] == "game/A"); - } + REQUIRE(mod->scopes.size() == 4); + CHECK(mod->scopes[0].second->importedModules["l0"] == "game/B"); + CHECK(mod->scopes[3].second->importedModules["l1"] == "game/A"); } TEST_SUITE_END(); diff --git a/tests/TypeInfer.oop.test.cpp b/tests/TypeInfer.oop.test.cpp index ab41ce37e..0f540f683 100644 --- a/tests/TypeInfer.oop.test.cpp +++ b/tests/TypeInfer.oop.test.cpp @@ -309,4 +309,21 @@ TEST_CASE_FIXTURE(Fixture, "dont_bind_free_tables_to_themselves") )"); } +// We should probably flag an error on this. See CLI-68672 +TEST_CASE_FIXTURE(BuiltinsFixture, "flag_when_index_metamethod_returns_0_values") +{ + CheckResult result = check(R"( + local T = {} + function T.__index() + end + + local a = setmetatable({}, T) + local p = a.prop + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK("nil" == toString(requireType("p"))); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.operators.test.cpp b/tests/TypeInfer.operators.test.cpp index dcdc2e313..720784c35 100644 --- a/tests/TypeInfer.operators.test.cpp +++ b/tests/TypeInfer.operators.test.cpp @@ -1109,4 +1109,28 @@ local f1 = f or 'f' CHECK("string" == toString(requireType("f1"))); } +TEST_CASE_FIXTURE(BuiltinsFixture, "reducing_and") +{ + ScopedFastFlag sff[]{ + {"LuauTryhardAnd", true}, + {"LuauReducingAndOr", true}, + }; + + CheckResult result = check(R"( +type Foo = { name: string?, flag: boolean? } +local arr: {Foo} = {} + +local function foo(arg: {name: string}?) + local name = if arg and arg.name then arg.name else nil + + table.insert(arr, { + name = name or "", + flag = name ~= nil and name ~= "", + }) +end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index 0aacb8aec..30f77d681 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -470,6 +470,10 @@ TEST_CASE_FIXTURE(Fixture, "dcr_can_partially_dispatch_a_constraint") TEST_CASE_FIXTURE(Fixture, "free_options_cannot_be_unified_together") { + ScopedFastFlag sff[] = { + {"LuauTransitiveSubtyping", true}, + }; + TypeArena arena; TypeId nilType = builtinTypes->nilType; @@ -488,7 +492,7 @@ TEST_CASE_FIXTURE(Fixture, "free_options_cannot_be_unified_together") u.tryUnify(option1, option2); - CHECK(u.errors.empty()); + CHECK(!u.failure); u.log.commit(); @@ -548,7 +552,10 @@ return wrapStrictTable(Constants, "Constants") std::optional result = first(m->returnType); REQUIRE(result); - CHECK(get(*result)); + if (FFlag::DebugLuauDeferredConstraintResolution) + CHECK_EQ("(any?) & ~table", toString(*result)); + else + CHECK_MESSAGE(get(*result), *result); } TEST_CASE_FIXTURE(BuiltinsFixture, "generic_type_leak_to_module_interface_variadic") diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index 064ec164a..890e9b693 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -1615,7 +1615,8 @@ TEST_CASE_FIXTURE(Fixture, "refine_a_property_of_some_global") )"); LUAU_REQUIRE_ERROR_COUNT(3, result); - CHECK_EQ("*error-type*", toString(requireTypeAtPosition({4, 30}))); + + CHECK_EQ("~false & ~nil", toString(requireTypeAtPosition({4, 30}))); } TEST_CASE_FIXTURE(BuiltinsFixture, "dataflow_analysis_can_tell_refinements_when_its_appropriate_to_refine_into_nil_or_never") diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index 0f5e3d310..21ac6421b 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -1590,8 +1590,16 @@ TEST_CASE_FIXTURE(Fixture, "casting_tables_with_props_into_table_with_indexer4") local hi: number = foo({ a = "hi" }, "a") -- shouldn't typecheck since at runtime hi is "hi" )"); - // This typechecks but shouldn't - LUAU_REQUIRE_NO_ERRORS(result); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK(toString(result.errors[0]) == "Type 'string' could not be converted into 'number'"); + } + else + { + // This typechecks but shouldn't + LUAU_REQUIRE_NO_ERRORS(result); + } } TEST_CASE_FIXTURE(Fixture, "table_subtyping_with_missing_props_dont_report_multiple_errors") diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index 417f80a84..7c4bfb2e9 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -103,6 +103,16 @@ TEST_CASE_FIXTURE(Fixture, "infer_in_nocheck_mode") LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "obvious_type_error_in_nocheck_mode") +{ + CheckResult result = check(R"( + --!nocheck + local x: string = 5 + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_CASE_FIXTURE(Fixture, "expr_statement") { CheckResult result = check("local foo = 5 foo()"); @@ -1185,6 +1195,9 @@ TEST_CASE_FIXTURE(Fixture, "dcr_delays_expansion_of_function_containing_blocked_ ScopedFastFlag sff[] = { {"DebugLuauDeferredConstraintResolution", true}, {"LuauTinyUnifyNormalsFix", true}, + // If we run this with error-suppression, it triggers an assertion. + // FATAL ERROR: Assertion failed: !"Internal error: Trying to normalize a BlockedType" + {"LuauTransitiveSubtyping", false}, }; CheckResult result = check(R"( diff --git a/tests/TypeInfer.tryUnify.test.cpp b/tests/TypeInfer.tryUnify.test.cpp index 66e070139..5a9c77d40 100644 --- a/tests/TypeInfer.tryUnify.test.cpp +++ b/tests/TypeInfer.tryUnify.test.cpp @@ -27,16 +27,25 @@ TEST_SUITE_BEGIN("TryUnifyTests"); TEST_CASE_FIXTURE(TryUnifyFixture, "primitives_unify") { + ScopedFastFlag sff[] = { + {"LuauTransitiveSubtyping", true}, + }; + Type numberOne{TypeVariant{PrimitiveType{PrimitiveType::Number}}}; Type numberTwo = numberOne; state.tryUnify(&numberTwo, &numberOne); + CHECK(!state.failure); CHECK(state.errors.empty()); } TEST_CASE_FIXTURE(TryUnifyFixture, "compatible_functions_are_unified") { + ScopedFastFlag sff[] = { + {"LuauTransitiveSubtyping", true}, + }; + Type functionOne{ TypeVariant{FunctionType(arena.addTypePack({arena.freshType(globalScope->level)}), arena.addTypePack({builtinTypes->numberType}))}}; @@ -44,6 +53,7 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "compatible_functions_are_unified") FunctionType(arena.addTypePack({arena.freshType(globalScope->level)}), arena.addTypePack({arena.freshType(globalScope->level)}))}}; state.tryUnify(&functionTwo, &functionOne); + CHECK(!state.failure); CHECK(state.errors.empty()); state.log.commit(); @@ -53,6 +63,10 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "compatible_functions_are_unified") TEST_CASE_FIXTURE(TryUnifyFixture, "incompatible_functions_are_preserved") { + ScopedFastFlag sff[] = { + {"LuauTransitiveSubtyping", true}, + }; + TypePackVar argPackOne{TypePack{{arena.freshType(globalScope->level)}, std::nullopt}}; Type functionOne{ TypeVariant{FunctionType(arena.addTypePack({arena.freshType(globalScope->level)}), arena.addTypePack({builtinTypes->numberType}))}}; @@ -66,6 +80,7 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "incompatible_functions_are_preserved") Type functionTwoSaved = functionTwo; state.tryUnify(&functionTwo, &functionOne); + CHECK(state.failure); CHECK(!state.errors.empty()); CHECK_EQ(functionOne, functionOneSaved); @@ -74,6 +89,10 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "incompatible_functions_are_preserved") TEST_CASE_FIXTURE(TryUnifyFixture, "tables_can_be_unified") { + ScopedFastFlag sff[] = { + {"LuauTransitiveSubtyping", true}, + }; + Type tableOne{TypeVariant{ TableType{{{"foo", {arena.freshType(globalScope->level)}}}, std::nullopt, globalScope->level, TableState::Unsealed}, }}; @@ -86,6 +105,7 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "tables_can_be_unified") state.tryUnify(&tableTwo, &tableOne); + CHECK(!state.failure); CHECK(state.errors.empty()); state.log.commit(); @@ -95,6 +115,10 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "tables_can_be_unified") TEST_CASE_FIXTURE(TryUnifyFixture, "incompatible_tables_are_preserved") { + ScopedFastFlag sff[] = { + {"LuauTransitiveSubtyping", true}, + }; + Type tableOne{TypeVariant{ TableType{{{"foo", {arena.freshType(globalScope->level)}}, {"bar", {builtinTypes->numberType}}}, std::nullopt, globalScope->level, TableState::Unsealed}, @@ -109,6 +133,7 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "incompatible_tables_are_preserved") state.tryUnify(&tableTwo, &tableOne); + CHECK(state.failure); CHECK_EQ(1, state.errors.size()); CHECK_NE(*getMutable(&tableOne)->props["foo"].type, *getMutable(&tableTwo)->props["foo"].type); @@ -218,6 +243,7 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "variadic_type_pack_unification") TypePackVar variadicPack{VariadicTypePack{builtinTypes->numberType}}; state.tryUnify(&testPack, &variadicPack); + CHECK(state.failure); CHECK(!state.errors.empty()); } @@ -228,6 +254,7 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "variadic_tails_respect_progress") TypePackVar b{TypePack{{builtinTypes->numberType, builtinTypes->stringType}, &variadicPack}}; state.tryUnify(&b, &a); + CHECK(!state.failure); CHECK(state.errors.empty()); } @@ -270,8 +297,7 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "free_tail_is_grown_properly") arena.addTypePack(TypePack{{builtinTypes->numberType, builtinTypes->numberType, builtinTypes->numberType}, std::nullopt}); TypePackId numberAndFreeTail = arena.addTypePack(TypePack{{builtinTypes->numberType}, arena.addTypePack(TypePackVar{FreeTypePack{TypeLevel{}}})}); - ErrorVec unifyErrors = state.canUnify(numberAndFreeTail, threeNumbers); - CHECK(unifyErrors.size() == 0); + CHECK(state.canUnify(numberAndFreeTail, threeNumbers).empty()); } TEST_CASE_FIXTURE(TryUnifyFixture, "recursive_metatable_getmatchtag") @@ -321,7 +347,10 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "txnlog_preserves_pack_owner") TEST_CASE_FIXTURE(TryUnifyFixture, "metatables_unify_against_shape_of_free_table") { - ScopedFastFlag sff("DebugLuauDeferredConstraintResolution", true); + ScopedFastFlag sff[] = { + {"LuauTransitiveSubtyping", true}, + {"DebugLuauDeferredConstraintResolution", true}, + }; TableType::Props freeProps{ {"foo", {builtinTypes->numberType}}, diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp index 704e2a3b4..d49f00443 100644 --- a/tests/TypeInfer.unionTypes.test.cpp +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -715,4 +715,62 @@ TEST_CASE_FIXTURE(Fixture, "less_greedy_unification_with_union_types_2") CHECK_EQ("({| x: number |} | {| x: string |}) -> number | string", toString(requireType("f"))); } +TEST_CASE_FIXTURE(Fixture, "union_table_any_property") +{ + CheckResult result = check(R"( + function f(x) + -- x : X + -- sup : { p : { q : X } }? + local sup = if true then { p = { q = x } } else nil + local sub : { p : any } + sup = nil + sup = sub + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "union_function_any_args") +{ + CheckResult result = check(R"( + local sup : ((...any) -> (...any))? + local sub : ((number) -> (...any)) + sup = sub + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "optional_any") +{ + CheckResult result = check(R"( + local sup : any? + local sub : number + sup = sub + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "generic_function_with_optional_arg") +{ + ScopedFastFlag sff[] = { + {"LuauTransitiveSubtyping", true}, + }; + + CheckResult result = check(R"( + function f(x : T?) : {T} + local result = {} + if x then + result[1] = x + end + return result + end + local t : {string} = f(nil) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.unknownnever.test.cpp b/tests/TypeInfer.unknownnever.test.cpp index f17ada20e..410fd52de 100644 --- a/tests/TypeInfer.unknownnever.test.cpp +++ b/tests/TypeInfer.unknownnever.test.cpp @@ -303,6 +303,7 @@ TEST_CASE_FIXTURE(Fixture, "dont_unify_operands_if_one_of_the_operand_is_never_i { ScopedFastFlag sff[]{ {"LuauTryhardAnd", true}, + {"LuauReducingAndOr", true}, }; CheckResult result = check(R"( @@ -313,13 +314,7 @@ TEST_CASE_FIXTURE(Fixture, "dont_unify_operands_if_one_of_the_operand_is_never_i LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::DebugLuauDeferredConstraintResolution) - CHECK_EQ("(nil, a) -> boolean", toString(requireType("ord"))); - else - { - // Widening doesn't normalize yet, so the result is a bit strange - CHECK_EQ("(nil, a) -> boolean | boolean", toString(requireType("ord"))); - } + CHECK_EQ("(nil, a) -> boolean", toString(requireType("ord"))); } TEST_CASE_FIXTURE(Fixture, "math_operators_and_never") diff --git a/tests/conformance/tables.lua b/tests/conformance/tables.lua index 4b47ed26a..596eed3db 100644 --- a/tests/conformance/tables.lua +++ b/tests/conformance/tables.lua @@ -578,6 +578,21 @@ do assert(#t2 == 6) end +-- test boundary invariant in sparse arrays or various kinds +do + local function obscuredalloc() return {} end + + local bits = 16 + + for i = 1, 2^bits - 1 do + local t1 = obscuredalloc() -- to avoid NEWTABLE guessing correct size + + for k = 1, bits do + t1[k] = if bit32.extract(i, k - 1) == 1 then true else nil + end + end +end + -- test table.unpack fastcall for rejecting large unpacks do local ok, res = pcall(function() diff --git a/tools/faillist.txt b/tools/faillist.txt index bcc177739..d513af142 100644 --- a/tools/faillist.txt +++ b/tools/faillist.txt @@ -31,8 +31,6 @@ BuiltinTests.table_pack_reduce BuiltinTests.table_pack_variadic DefinitionTests.class_definition_overload_metamethods DefinitionTests.class_definition_string_props -FrontendTest.environments -FrontendTest.nocheck_cycle_used_by_checked GenericsTests.apply_type_function_nested_generics2 GenericsTests.better_mismatch_error_messages GenericsTests.bound_tables_do_not_clone_original_fields @@ -54,19 +52,6 @@ GenericsTests.self_recursive_instantiated_param IntersectionTypes.table_intersection_write_sealed IntersectionTypes.table_intersection_write_sealed_indirect IntersectionTypes.table_write_sealed_indirect -ModuleTests.clone_self_property -NonstrictModeTests.for_in_iterator_variables_are_any -NonstrictModeTests.function_parameters_are_any -NonstrictModeTests.inconsistent_module_return_types_are_ok -NonstrictModeTests.inconsistent_return_types_are_ok -NonstrictModeTests.infer_nullary_function -NonstrictModeTests.infer_the_maximum_number_of_values_the_function_could_return -NonstrictModeTests.inline_table_props_are_also_any -NonstrictModeTests.local_tables_are_not_any -NonstrictModeTests.locals_are_any_by_default -NonstrictModeTests.offer_a_hint_if_you_use_a_dot_instead_of_a_colon -NonstrictModeTests.parameters_having_type_any_are_optional -NonstrictModeTests.table_props_are_any ProvisionalTests.assign_table_with_refined_property_with_a_similar_type_is_illegal ProvisionalTests.bail_early_if_unification_is_too_complicated ProvisionalTests.do_not_ice_when_trying_to_pick_first_of_generic_type_pack @@ -85,9 +70,7 @@ RefinementTest.typeguard_in_assert_position RefinementTest.x_as_any_if_x_is_instance_elseif_x_is_table RuntimeLimits.typescript_port_of_Result_type TableTests.a_free_shape_cannot_turn_into_a_scalar_if_it_is_not_compatible -TableTests.any_when_indexing_into_an_unsealed_table_with_no_indexer_in_nonstrict_mode TableTests.casting_tables_with_props_into_table_with_indexer3 -TableTests.casting_tables_with_props_into_table_with_indexer4 TableTests.checked_prop_too_early TableTests.disallow_indexing_into_an_unsealed_table_with_no_indexer_in_strict_mode TableTests.dont_crash_when_setmetatable_does_not_produce_a_metatabletypevar @@ -117,7 +100,6 @@ TableTests.missing_metatable_for_sealed_tables_do_not_get_inferred TableTests.mixed_tables_with_implicit_numbered_keys TableTests.nil_assign_doesnt_hit_indexer TableTests.ok_to_set_nil_even_on_non_lvalue_base_expr -TableTests.only_ascribe_synthetic_names_at_module_scope TableTests.oop_polymorphic TableTests.quantify_even_that_table_was_never_exported_at_all TableTests.quantify_metatables_of_metatables_of_table @@ -138,7 +120,6 @@ ToString.named_metatable_toStringNamedFunction ToString.toStringDetailed2 ToString.toStringErrorPack ToString.toStringNamedFunction_generic_pack -ToString.toStringNamedFunction_map TryUnifyTests.members_of_failed_typepack_unification_are_unified_with_errorType TryUnifyTests.result_of_failed_typepack_unification_is_constrained TryUnifyTests.typepack_unification_should_trim_free_tails @@ -154,15 +135,11 @@ TypeAliases.type_alias_local_rename TypeAliases.type_alias_locations TypeAliases.type_alias_of_an_imported_recursive_generic_type TypeInfer.check_type_infer_recursion_count -TypeInfer.checking_should_not_ice TypeInfer.cli_50041_committing_txnlog_in_apollo_client_error TypeInfer.dont_report_type_errors_within_an_AstExprError TypeInfer.dont_report_type_errors_within_an_AstStatError TypeInfer.fuzz_free_table_type_change_during_index_check -TypeInfer.globals -TypeInfer.globals2 TypeInfer.infer_assignment_value_types_mutable_lval -TypeInfer.it_is_ok_to_have_inconsistent_number_of_return_values_in_nonstrict TypeInfer.no_stack_overflow_from_isoptional TypeInfer.no_stack_overflow_from_isoptional2 TypeInfer.tc_after_error_recovery_no_replacement_name_in_error @@ -173,17 +150,13 @@ TypeInferClasses.classes_without_overloaded_operators_cannot_be_added TypeInferClasses.index_instance_property TypeInferClasses.table_class_unification_reports_sane_errors_for_missing_properties TypeInferClasses.warn_when_prop_almost_matches -TypeInferFunctions.calling_function_with_anytypepack_doesnt_leak_free_types TypeInferFunctions.cannot_hoist_interior_defns_into_signature -TypeInferFunctions.check_function_before_lambda_that_uses_it TypeInferFunctions.dont_give_other_overloads_message_if_only_one_argument_matching_overload_exists -TypeInferFunctions.duplicate_functions_with_different_signatures_not_allowed_in_nonstrict TypeInferFunctions.function_cast_error_uses_correct_language TypeInferFunctions.function_decl_non_self_sealed_overwrite_2 TypeInferFunctions.function_decl_non_self_unsealed_overwrite TypeInferFunctions.function_does_not_return_enough_values TypeInferFunctions.function_statement_sealed_table_assignment_through_indexer -TypeInferFunctions.improved_function_arg_mismatch_error_nonstrict TypeInferFunctions.improved_function_arg_mismatch_errors TypeInferFunctions.infer_anonymous_function_arguments TypeInferFunctions.infer_that_function_does_not_return_a_table @@ -191,7 +164,6 @@ TypeInferFunctions.luau_subtyping_is_np_hard TypeInferFunctions.no_lossy_function_type TypeInferFunctions.occurs_check_failure_in_function_return_type TypeInferFunctions.record_matching_overload -TypeInferFunctions.report_exiting_without_return_nonstrict TypeInferFunctions.report_exiting_without_return_strict TypeInferFunctions.return_type_by_overload TypeInferFunctions.too_few_arguments_variadic @@ -204,11 +176,9 @@ TypeInferLoops.for_in_loop_error_on_factory_not_returning_the_right_amount_of_va TypeInferLoops.for_in_loop_with_next TypeInferLoops.for_in_with_generic_next TypeInferLoops.loop_iter_metamethod_ok_with_inference -TypeInferLoops.loop_iter_no_indexer_nonstrict TypeInferLoops.loop_iter_trailing_nil TypeInferLoops.properly_infer_iteratee_is_a_free_table TypeInferLoops.unreachable_code_after_infinite_loop -TypeInferModules.custom_require_global TypeInferModules.do_not_modify_imported_types_5 TypeInferModules.module_type_conflict TypeInferModules.module_type_conflict_instantiated @@ -220,18 +190,14 @@ TypeInferOperators.cli_38355_recursive_union TypeInferOperators.compound_assign_metatable TypeInferOperators.compound_assign_mismatch_metatable TypeInferOperators.disallow_string_and_types_without_metatables_from_arithmetic_binary_ops -TypeInferOperators.in_nonstrict_mode_strip_nil_from_intersections_when_considering_relational_operators -TypeInferOperators.infer_any_in_all_modes_when_lhs_is_unknown TypeInferOperators.operator_eq_completely_incompatible TypeInferOperators.typecheck_overloaded_multiply_that_is_an_intersection TypeInferOperators.typecheck_overloaded_multiply_that_is_an_intersection_on_rhs TypeInferOperators.typecheck_unary_len_error -TypeInferOperators.UnknownGlobalCompoundAssign TypeInferOperators.unrelated_classes_cannot_be_compared TypeInferOperators.unrelated_primitives_cannot_be_compared TypeInferPrimitives.CheckMethodsOfNumber TypeInferPrimitives.string_index -TypeInferUnknownNever.assign_to_global_which_is_never TypeInferUnknownNever.dont_unify_operands_if_one_of_the_operand_is_never_in_any_ordering_operators TypeInferUnknownNever.math_operators_and_never TypePackTests.detect_cyclic_typepacks2 @@ -250,6 +216,7 @@ TypeSingletons.table_properties_type_error_escapes TypeSingletons.taking_the_length_of_union_of_string_singleton TypeSingletons.widen_the_supertype_if_it_is_free_and_subtype_has_singleton TypeSingletons.widening_happens_almost_everywhere +UnionTypes.generic_function_with_optional_arg UnionTypes.index_on_a_union_type_with_missing_property UnionTypes.optional_assignment_errors UnionTypes.optional_call_error From 81200e13f6146f7f03d9dfa88974ef3f8f966656 Mon Sep 17 00:00:00 2001 From: Andy Friesen Date: Fri, 24 Mar 2023 10:34:14 -0700 Subject: [PATCH 42/66] Sync to upstream/release/569 --- Analysis/include/Luau/Clone.h | 3 - Analysis/include/Luau/ControlFlow.h | 8 +- Analysis/include/Luau/Unifiable.h | 2 +- Analysis/src/Clone.cpp | 84 +---- Analysis/src/ConstraintGraphBuilder.cpp | 45 ++- Analysis/src/ConstraintSolver.cpp | 161 +++++---- Analysis/src/Frontend.cpp | 62 ---- Analysis/src/Instantiation.cpp | 2 +- Analysis/src/Module.cpp | 10 +- Analysis/src/Normalize.cpp | 15 +- Analysis/src/Substitution.cpp | 181 +++++++++- Analysis/src/ToString.cpp | 7 +- Analysis/src/Type.cpp | 9 +- Analysis/src/TypeChecker2.cpp | 18 +- Analysis/src/TypeInfer.cpp | 7 +- Analysis/src/TypeReduction.cpp | 2 +- Analysis/src/Unifiable.cpp | 2 +- Analysis/src/Unifier.cpp | 33 +- CodeGen/include/Luau/AssemblyBuilderA64.h | 23 +- CodeGen/include/Luau/IrAnalysis.h | 12 +- CodeGen/include/Luau/IrData.h | 51 ++- CodeGen/include/Luau/IrDump.h | 14 +- CodeGen/src/AssemblyBuilderA64.cpp | 89 ++++- CodeGen/src/CodeGen.cpp | 205 ++++++++++-- CodeGen/src/CodeGenA64.cpp | 50 ++- CodeGen/src/CodeGenA64.h | 4 + CodeGen/src/CodeGenUtils.cpp | 37 +++ CodeGen/src/CodeGenUtils.h | 2 + CodeGen/src/EmitCommonA64.cpp | 75 +++++ CodeGen/src/EmitCommonA64.h | 77 +++++ CodeGen/src/EmitInstructionA64.cpp | 59 ++++ CodeGen/src/EmitInstructionA64.h | 20 ++ CodeGen/src/EmitInstructionX64.cpp | 75 ++--- CodeGen/src/EmitInstructionX64.h | 19 +- CodeGen/src/IrAnalysis.cpp | 38 ++- CodeGen/src/IrBuilder.cpp | 17 +- CodeGen/src/IrDump.cpp | 26 +- CodeGen/src/IrLoweringA64.cpp | 137 ++++++++ CodeGen/src/IrLoweringA64.h | 60 ++++ CodeGen/src/IrLoweringX64.cpp | 363 ++++----------------- CodeGen/src/IrLoweringX64.h | 6 +- CodeGen/src/IrTranslateBuiltins.cpp | 4 + CodeGen/src/IrTranslation.cpp | 4 +- CodeGen/src/IrUtils.cpp | 2 +- CodeGen/src/NativeState.cpp | 1 + CodeGen/src/NativeState.h | 13 +- CodeGen/src/OptimizeConstProp.cpp | 112 ++++++- Makefile | 5 + Sources.cmake | 8 + VM/src/lstate.h | 4 +- tests/AssemblyBuilderA64.test.cpp | 72 +++- tests/CodeAllocator.test.cpp | 1 + tests/Conformance.test.cpp | 4 +- tests/IrBuilder.test.cpp | 167 +++++++--- tests/Module.test.cpp | 4 +- tests/Normalize.test.cpp | 12 +- tests/TypeInfer.aliases.test.cpp | 67 ++++ tests/TypeInfer.builtins.test.cpp | 155 ++------- tests/TypeInfer.functions.test.cpp | 31 +- tests/TypeInfer.intersectionTypes.test.cpp | 7 +- tests/TypeInfer.oop.test.cpp | 55 ++++ tests/TypeInfer.operators.test.cpp | 12 +- tests/TypeInfer.unionTypes.test.cpp | 7 +- tests/conformance/interrupt.lua | 5 + tests/conformance/math.lua | 2 + tests/main.cpp | 21 +- tools/faillist.txt | 10 +- 67 files changed, 1957 insertions(+), 938 deletions(-) create mode 100644 CodeGen/src/EmitCommonA64.cpp create mode 100644 CodeGen/src/EmitCommonA64.h create mode 100644 CodeGen/src/EmitInstructionA64.cpp create mode 100644 CodeGen/src/EmitInstructionA64.h create mode 100644 CodeGen/src/IrLoweringA64.cpp create mode 100644 CodeGen/src/IrLoweringA64.h diff --git a/Analysis/include/Luau/Clone.h b/Analysis/include/Luau/Clone.h index 51f1e7a67..b3cbe467c 100644 --- a/Analysis/include/Luau/Clone.h +++ b/Analysis/include/Luau/Clone.h @@ -26,7 +26,4 @@ TypePackId clone(TypePackId tp, TypeArena& dest, CloneState& cloneState); TypeId clone(TypeId tp, TypeArena& dest, CloneState& cloneState); TypeFun clone(const TypeFun& typeFun, TypeArena& dest, CloneState& cloneState); -TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool alwaysClone = false); -TypeId shallowClone(TypeId ty, NotNull dest); - } // namespace Luau diff --git a/Analysis/include/Luau/ControlFlow.h b/Analysis/include/Luau/ControlFlow.h index 8272bd53e..566d77bd4 100644 --- a/Analysis/include/Luau/ControlFlow.h +++ b/Analysis/include/Luau/ControlFlow.h @@ -11,10 +11,10 @@ using ScopePtr = std::shared_ptr; enum class ControlFlow { - None = 0b00001, - Returns = 0b00010, - Throws = 0b00100, - Break = 0b01000, // Currently unused. + None = 0b00001, + Returns = 0b00010, + Throws = 0b00100, + Break = 0b01000, // Currently unused. Continue = 0b10000, // Currently unused. }; diff --git a/Analysis/include/Luau/Unifiable.h b/Analysis/include/Luau/Unifiable.h index 9c4f01326..ae55f3734 100644 --- a/Analysis/include/Luau/Unifiable.h +++ b/Analysis/include/Luau/Unifiable.h @@ -82,7 +82,7 @@ namespace Luau::Unifiable using Name = std::string; int freshIndex(); - + struct Free { explicit Free(TypeLevel level); diff --git a/Analysis/src/Clone.cpp b/Analysis/src/Clone.cpp index ff8e0c3c2..2645209d5 100644 --- a/Analysis/src/Clone.cpp +++ b/Analysis/src/Clone.cpp @@ -7,7 +7,7 @@ #include "Luau/Unifiable.h" LUAU_FASTFLAG(DebugLuauCopyBeforeNormalizing) -LUAU_FASTFLAG(LuauClonePublicInterfaceLess) +LUAU_FASTFLAG(LuauClonePublicInterfaceLess2) LUAU_FASTINTVARIABLE(LuauTypeCloneRecursionLimit, 300) @@ -422,86 +422,4 @@ TypeFun clone(const TypeFun& typeFun, TypeArena& dest, CloneState& cloneState) return result; } -TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool alwaysClone) -{ - ty = log->follow(ty); - - TypeId result = ty; - - if (auto pty = log->pending(ty)) - ty = &pty->pending; - - if (const FunctionType* ftv = get(ty)) - { - FunctionType clone = FunctionType{ftv->level, ftv->scope, ftv->argTypes, ftv->retTypes, ftv->definition, ftv->hasSelf}; - clone.generics = ftv->generics; - clone.genericPacks = ftv->genericPacks; - clone.magicFunction = ftv->magicFunction; - clone.dcrMagicFunction = ftv->dcrMagicFunction; - clone.dcrMagicRefinement = ftv->dcrMagicRefinement; - clone.tags = ftv->tags; - clone.argNames = ftv->argNames; - result = dest.addType(std::move(clone)); - } - else if (const TableType* ttv = get(ty)) - { - LUAU_ASSERT(!ttv->boundTo); - TableType clone = TableType{ttv->props, ttv->indexer, ttv->level, ttv->scope, ttv->state}; - clone.definitionModuleName = ttv->definitionModuleName; - clone.definitionLocation = ttv->definitionLocation; - clone.name = ttv->name; - clone.syntheticName = ttv->syntheticName; - clone.instantiatedTypeParams = ttv->instantiatedTypeParams; - clone.instantiatedTypePackParams = ttv->instantiatedTypePackParams; - clone.tags = ttv->tags; - result = dest.addType(std::move(clone)); - } - else if (const MetatableType* mtv = get(ty)) - { - MetatableType clone = MetatableType{mtv->table, mtv->metatable}; - clone.syntheticName = mtv->syntheticName; - result = dest.addType(std::move(clone)); - } - else if (const UnionType* utv = get(ty)) - { - UnionType clone; - clone.options = utv->options; - result = dest.addType(std::move(clone)); - } - else if (const IntersectionType* itv = get(ty)) - { - IntersectionType clone; - clone.parts = itv->parts; - result = dest.addType(std::move(clone)); - } - else if (const PendingExpansionType* petv = get(ty)) - { - PendingExpansionType clone{petv->prefix, petv->name, petv->typeArguments, petv->packArguments}; - result = dest.addType(std::move(clone)); - } - else if (const ClassType* ctv = get(ty); FFlag::LuauClonePublicInterfaceLess && ctv && alwaysClone) - { - ClassType clone{ctv->name, ctv->props, ctv->parent, ctv->metatable, ctv->tags, ctv->userData, ctv->definitionModuleName}; - result = dest.addType(std::move(clone)); - } - else if (FFlag::LuauClonePublicInterfaceLess && alwaysClone) - { - result = dest.addType(*ty); - } - else if (const NegationType* ntv = get(ty)) - { - result = dest.addType(NegationType{ntv->ty}); - } - else - return result; - - asMutable(result)->documentationSymbol = ty->documentationSymbol; - return result; -} - -TypeId shallowClone(TypeId ty, NotNull dest) -{ - return shallowClone(ty, *dest, TxnLog::empty()); -} - } // namespace Luau diff --git a/Analysis/src/ConstraintGraphBuilder.cpp b/Analysis/src/ConstraintGraphBuilder.cpp index e90cb7d3a..474d39235 100644 --- a/Analysis/src/ConstraintGraphBuilder.cpp +++ b/Analysis/src/ConstraintGraphBuilder.cpp @@ -23,7 +23,7 @@ LUAU_FASTFLAG(LuauNegatedClassTypes); namespace Luau { -bool doesCallError(const AstExprCall* call); // TypeInfer.cpp +bool doesCallError(const AstExprCall* call); // TypeInfer.cpp const AstStat* getFallthrough(const AstStat* node); // TypeInfer.cpp static std::optional matchRequire(const AstExprCall& call) @@ -1359,10 +1359,34 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCa if (argTail && args.size() < 2) argTailPack = extendTypePack(*arena, builtinTypes, *argTail, 2 - args.size()); - LUAU_ASSERT(args.size() + argTailPack.head.size() == 2); + TypeId target = nullptr; + TypeId mt = nullptr; - TypeId target = args.size() > 0 ? args[0] : argTailPack.head[0]; - TypeId mt = args.size() > 1 ? args[1] : argTailPack.head[args.size() == 0 ? 1 : 0]; + if (args.size() + argTailPack.head.size() == 2) + { + target = args.size() > 0 ? args[0] : argTailPack.head[0]; + mt = args.size() > 1 ? args[1] : argTailPack.head[args.size() == 0 ? 1 : 0]; + } + else + { + std::vector unpackedTypes; + if (args.size() > 0) + target = args[0]; + else + { + target = arena->addType(BlockedType{}); + unpackedTypes.emplace_back(target); + } + + mt = arena->addType(BlockedType{}); + unpackedTypes.emplace_back(mt); + TypePackId mtPack = arena->addTypePack(std::move(unpackedTypes)); + + addConstraint(scope, call->location, UnpackConstraint{mtPack, *argTail}); + } + + LUAU_ASSERT(target); + LUAU_ASSERT(mt); AstExpr* targetExpr = call->args.data[0]; @@ -2090,6 +2114,19 @@ ConstraintGraphBuilder::FunctionSignature ConstraintGraphBuilder::checkFunctionS TypePack expectedArgPack; const FunctionType* expectedFunction = expectedType ? get(*expectedType) : nullptr; + // This check ensures that expectedType is precisely optional and not any (since any is also an optional type) + if (expectedType && isOptional(*expectedType) && !get(*expectedType)) + { + auto ut = get(*expectedType); + for (auto u : ut) + { + if (get(u) && !isNil(u)) + { + expectedFunction = get(u); + break; + } + } + } if (expectedFunction) { diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index 5662cf04b..d5853932e 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -3,6 +3,7 @@ #include "Luau/Anyification.h" #include "Luau/ApplyTypeFunction.h" #include "Luau/Clone.h" +#include "Luau/Common.h" #include "Luau/ConstraintSolver.h" #include "Luau/DcrLogger.h" #include "Luau/Instantiation.h" @@ -221,17 +222,6 @@ void dump(ConstraintSolver* cs, ToStringOptions& opts) auto it = cs->blockedConstraints.find(c); int blockCount = it == cs->blockedConstraints.end() ? 0 : int(it->second); printf("\t%d\t%s\n", blockCount, toString(*c, opts).c_str()); - - for (NotNull dep : c->dependencies) - { - auto unsolvedIter = std::find(begin(cs->unsolvedConstraints), end(cs->unsolvedConstraints), dep); - if (unsolvedIter == cs->unsolvedConstraints.end()) - continue; - - auto it = cs->blockedConstraints.find(dep); - int blockCount = it == cs->blockedConstraints.end() ? 0 : int(it->second); - printf("\t%d\t\t%s\n", blockCount, toString(*dep, opts).c_str()); - } } } @@ -578,12 +568,16 @@ bool ConstraintSolver::tryDispatch(const UnaryConstraint& c, NotNullty.emplace(builtinTypes->booleanType); + + unblock(c.resultType); return true; } case AstExprUnary::Len: { // __len must return a number. asMutable(c.resultType)->ty.emplace(builtinTypes->numberType); + + unblock(c.resultType); return true; } case AstExprUnary::Minus: @@ -613,6 +607,7 @@ bool ConstraintSolver::tryDispatch(const UnaryConstraint& c, NotNullty.emplace(builtinTypes->errorRecoveryType()); } + unblock(c.resultType); return true; } } @@ -868,7 +863,7 @@ bool ConstraintSolver::tryDispatch(const IterableConstraint& c, NotNullanyType}; } - TypeId instantiatedTy = arena->addType(BlockedType{}); TypeId inferredTy = arena->addType(FunctionType{TypeLevel{}, constraint->scope.get(), argsPack, c.result}); - auto pushConstraintGreedy = [this, constraint](ConstraintV cv) -> Constraint* { - std::unique_ptr c = std::make_unique(constraint->scope, constraint->location, std::move(cv)); - NotNull borrow{c.get()}; + std::vector overloads = flattenIntersection(fn); - bool ok = tryDispatch(borrow, false); - if (ok) - return nullptr; + Instantiation inst(TxnLog::empty(), arena, TypeLevel{}, constraint->scope); - solverConstraints.push_back(std::move(c)); - unsolvedConstraints.push_back(borrow); + for (TypeId overload : overloads) + { + overload = follow(overload); - return borrow; - }; + std::optional instantiated = inst.substitute(overload); + LUAU_ASSERT(instantiated); // TODO FIXME HANDLE THIS - // HACK: We don't want other constraints to act on the free type pack - // created above until after these two constraints are solved, so we try to - // dispatch them directly. + Unifier u{normalizer, Mode::Strict, constraint->scope, Location{}, Covariant}; + u.useScopes = true; - auto ic = pushConstraintGreedy(InstantiationConstraint{instantiatedTy, fn}); - auto sc = pushConstraintGreedy(SubtypeConstraint{instantiatedTy, inferredTy}); + u.tryUnify(*instantiated, inferredTy, /* isFunctionCall */ true); - if (ic) - inheritBlocks(constraint, NotNull{ic}); + if (!u.blockedTypes.empty() || !u.blockedTypePacks.empty()) + { + for (TypeId bt : u.blockedTypes) + block(bt, constraint); + for (TypePackId btp : u.blockedTypePacks) + block(btp, constraint); + return false; + } + + if (const auto& e = hasUnificationTooComplex(u.errors)) + reportError(*e); - if (sc) - inheritBlocks(constraint, NotNull{sc}); + if (u.errors.empty()) + { + // We found a matching overload. + const auto [changedTypes, changedPacks] = u.log.getChanges(); + u.log.commit(); + unblock(changedTypes); + unblock(changedPacks); + + unblock(c.result); + return true; + } + } + + // We found no matching overloads. + Unifier u{normalizer, Mode::Strict, constraint->scope, Location{}, Covariant}; + u.useScopes = true; + + u.tryUnify(inferredTy, builtinTypes->anyType); + u.tryUnify(fn, builtinTypes->anyType); + + LUAU_ASSERT(u.errors.empty()); // unifying with any should never fail + + const auto [changedTypes, changedPacks] = u.log.getChanges(); + u.log.commit(); + + unblock(changedTypes); + unblock(changedPacks); unblock(c.result); return true; @@ -1291,6 +1314,7 @@ bool ConstraintSolver::tryDispatch(const PrimitiveTypeConstraint& c, NotNullty.emplace(bindTo); + unblock(c.resultType); return true; } @@ -1335,20 +1359,17 @@ static bool isUnsealedTable(TypeId ty) } /** - * Create a shallow copy of `ty` and its properties along `path`. Insert a new - * property (the last segment of `path`) into the tail table with the value `t`. + * Given a path into a set of nested unsealed tables `ty`, insert a new property `replaceTy` as the leaf-most property. * - * On success, returns the new outermost table type. If the root table or any - * of its subkeys are not unsealed tables, the function fails and returns - * std::nullopt. + * Fails and does nothing if every table along the way is not unsealed. * - * TODO: Prove that we completely give up in the face of indexers and - * metatables. + * Mutates the innermost table type in-place. */ -static std::optional updateTheTableType(NotNull arena, TypeId ty, const std::vector& path, TypeId replaceTy) +static void updateTheTableType( + NotNull builtinTypes, NotNull arena, TypeId ty, const std::vector& path, TypeId replaceTy) { if (path.empty()) - return std::nullopt; + return; // First walk the path and ensure that it's unsealed tables all the way // to the end. @@ -1357,12 +1378,12 @@ static std::optional updateTheTableType(NotNull arena, TypeId for (size_t i = 0; i < path.size() - 1; ++i) { if (!isUnsealedTable(t)) - return std::nullopt; + return; const TableType* tbl = get(t); auto it = tbl->props.find(path[i]); if (it == tbl->props.end()) - return std::nullopt; + return; t = follow(it->second.type); } @@ -1371,40 +1392,37 @@ static std::optional updateTheTableType(NotNull arena, TypeId // We are not changing property types. We are only admitting this one // new property to be appended. if (!isUnsealedTable(t)) - return std::nullopt; + return; const TableType* tbl = get(t); if (0 != tbl->props.count(path.back())) - return std::nullopt; + return; } - const TypeId res = shallowClone(ty, arena); - TypeId t = res; + TypeId t = ty; + ErrorVec dummy; for (size_t i = 0; i < path.size() - 1; ++i) { - const std::string segment = path[i]; + auto propTy = findTablePropertyRespectingMeta(builtinTypes, dummy, t, path[i], Location{}); + dummy.clear(); - TableType* ttv = getMutable(t); - LUAU_ASSERT(ttv); + if (!propTy) + return; - auto propIt = ttv->props.find(segment); - if (propIt != ttv->props.end()) - { - LUAU_ASSERT(isUnsealedTable(propIt->second.type)); - t = shallowClone(follow(propIt->second.type), arena); - ttv->props[segment].type = t; - } - else - return std::nullopt; + t = *propTy; } - TableType* ttv = getMutable(t); - LUAU_ASSERT(ttv); + const std::string& lastSegment = path.back(); + + t = follow(t); + TableType* tt = getMutable(t); + if (auto mt = get(t)) + tt = getMutable(mt->table); - const std::string lastSegment = path.back(); - LUAU_ASSERT(0 == ttv->props.count(lastSegment)); - ttv->props[lastSegment] = Property{replaceTy}; - return res; + if (!tt) + return; + + tt->props[lastSegment].type = replaceTy; } bool ConstraintSolver::tryDispatch(const SetPropConstraint& c, NotNull constraint, bool force) @@ -1443,6 +1461,7 @@ bool ConstraintSolver::tryDispatch(const SetPropConstraint& c, NotNullscope); bind(c.resultType, c.subjectType); + unblock(c.resultType); return true; } @@ -1467,6 +1486,8 @@ bool ConstraintSolver::tryDispatch(const SetPropConstraint& c, NotNull(subjectType)) @@ -1477,20 +1498,23 @@ bool ConstraintSolver::tryDispatch(const SetPropConstraint& c, NotNullprops[c.path[0]] = Property{c.propType}; bind(c.resultType, c.subjectType); + unblock(c.resultType); return true; } else if (ttv->state == TableState::Unsealed) { LUAU_ASSERT(!subjectType->persistent); - std::optional augmented = updateTheTableType(NotNull{arena}, subjectType, c.path, c.propType); - bind(c.resultType, augmented.value_or(subjectType)); - bind(subjectType, c.resultType); + updateTheTableType(builtinTypes, NotNull{arena}, subjectType, c.path, c.propType); + bind(c.resultType, c.subjectType); + unblock(subjectType); + unblock(c.resultType); return true; } else { bind(c.resultType, subjectType); + unblock(c.resultType); return true; } } @@ -1499,6 +1523,7 @@ bool ConstraintSolver::tryDispatch(const SetPropConstraint& c, NotNull 0) - return LoadDefinitionFileResult{false, parseResult, {}, nullptr}; - - Luau::SourceModule module; - module.root = parseResult.root; - module.mode = Mode::Definition; - - ModulePtr checkedModule = typeChecker.check(module, Mode::Definition); - - if (checkedModule->errors.size() > 0) - return LoadDefinitionFileResult{false, parseResult, {}, checkedModule}; - - CloneState cloneState; - - std::vector typesToPersist; - typesToPersist.reserve(checkedModule->declaredGlobals.size() + checkedModule->exportedTypeBindings.size()); - - for (const auto& [name, ty] : checkedModule->declaredGlobals) - { - TypeId globalTy = clone(ty, globals.globalTypes, cloneState); - std::string documentationSymbol = packageName + "/global/" + name; - generateDocumentationSymbols(globalTy, documentationSymbol); - targetScope->bindings[globals.globalNames.names->getOrAdd(name.c_str())] = {globalTy, Location(), false, {}, documentationSymbol}; - - typesToPersist.push_back(globalTy); - } - - for (const auto& [name, ty] : checkedModule->exportedTypeBindings) - { - TypeFun globalTy = clone(ty, globals.globalTypes, cloneState); - std::string documentationSymbol = packageName + "/globaltype/" + name; - generateDocumentationSymbols(globalTy.type, documentationSymbol); - targetScope->exportedTypeBindings[name] = globalTy; - - typesToPersist.push_back(globalTy.type); - } - - for (TypeId ty : typesToPersist) - { - persist(ty); - } - - return LoadDefinitionFileResult{true, parseResult, {}, checkedModule}; -} - LoadDefinitionFileResult loadDefinitionFile(TypeChecker& typeChecker, GlobalTypes& globals, ScopePtr targetScope, std::string_view source, const std::string& packageName, bool captureComments) { - if (!FFlag::LuauDefinitionFileSourceModule) - return loadDefinitionFile_DEPRECATED(typeChecker, globals, targetScope, source, packageName); - LUAU_TIMETRACE_SCOPE("loadDefinitionFile", "Frontend"); Luau::SourceModule sourceModule; diff --git a/Analysis/src/Instantiation.cpp b/Analysis/src/Instantiation.cpp index 9c3ae0771..7d0f0f72f 100644 --- a/Analysis/src/Instantiation.cpp +++ b/Analysis/src/Instantiation.cpp @@ -127,7 +127,7 @@ TypeId ReplaceGenerics::clean(TypeId ty) TypePackId ReplaceGenerics::clean(TypePackId tp) { LUAU_ASSERT(isDirty(tp)); - return addTypePack(TypePackVar(FreeTypePack{level})); + return addTypePack(TypePackVar(FreeTypePack{scope, level})); } } // namespace Luau diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index b51b7c9a6..fd9484038 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -16,7 +16,7 @@ #include LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); -LUAU_FASTFLAGVARIABLE(LuauClonePublicInterfaceLess, false); +LUAU_FASTFLAGVARIABLE(LuauClonePublicInterfaceLess2, false); LUAU_FASTFLAG(LuauSubstitutionReentrant); LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution); LUAU_FASTFLAG(LuauSubstitutionFixMissingFields); @@ -194,7 +194,7 @@ void Module::clonePublicInterface(NotNull builtinTypes, InternalEr TxnLog log; ClonePublicInterface clonePublicInterface{&log, builtinTypes, this}; - if (FFlag::LuauClonePublicInterfaceLess) + if (FFlag::LuauClonePublicInterfaceLess2) returnType = clonePublicInterface.cloneTypePack(returnType); else returnType = clone(returnType, interfaceTypes, cloneState); @@ -202,7 +202,7 @@ void Module::clonePublicInterface(NotNull builtinTypes, InternalEr moduleScope->returnType = returnType; if (varargPack) { - if (FFlag::LuauClonePublicInterfaceLess) + if (FFlag::LuauClonePublicInterfaceLess2) varargPack = clonePublicInterface.cloneTypePack(*varargPack); else varargPack = clone(*varargPack, interfaceTypes, cloneState); @@ -211,7 +211,7 @@ void Module::clonePublicInterface(NotNull builtinTypes, InternalEr for (auto& [name, tf] : moduleScope->exportedTypeBindings) { - if (FFlag::LuauClonePublicInterfaceLess) + if (FFlag::LuauClonePublicInterfaceLess2) tf = clonePublicInterface.cloneTypeFun(tf); else tf = clone(tf, interfaceTypes, cloneState); @@ -219,7 +219,7 @@ void Module::clonePublicInterface(NotNull builtinTypes, InternalEr for (auto& [name, ty] : declaredGlobals) { - if (FFlag::LuauClonePublicInterfaceLess) + if (FFlag::LuauClonePublicInterfaceLess2) ty = clonePublicInterface.cloneType(ty); else ty = clone(ty, interfaceTypes, cloneState); diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp index f8f8b97f8..f383f5eae 100644 --- a/Analysis/src/Normalize.cpp +++ b/Analysis/src/Normalize.cpp @@ -533,7 +533,7 @@ static bool areNormalizedClasses(const NormalizedClassType& tys) static bool isPlainTyvar(TypeId ty) { - return (get(ty) || get(ty) || (FFlag::LuauNormalizeBlockedTypes && get(ty))); + return (get(ty) || get(ty) || (FFlag::LuauNormalizeBlockedTypes && get(ty)) || get(ty)); } static bool isNormalizedTyvar(const NormalizedTyvars& tyvars) @@ -1380,7 +1380,8 @@ bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, int ignor } else if (FFlag::LuauTransitiveSubtyping && get(here.tops)) return true; - else if (get(there) || get(there) || (FFlag::LuauNormalizeBlockedTypes && get(there))) + else if (get(there) || get(there) || (FFlag::LuauNormalizeBlockedTypes && get(there)) || + get(there)) { if (tyvarIndex(there) <= ignoreSmallerTyvars) return true; @@ -1460,6 +1461,10 @@ bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, int ignor } else if (!FFlag::LuauNormalizeBlockedTypes && get(there)) LUAU_ASSERT(!"Internal error: Trying to normalize a BlockedType"); + else if (get(there)) + { + // nothing + } else LUAU_ASSERT(!"Unreachable"); @@ -2544,7 +2549,8 @@ bool Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there) return false; return true; } - else if (get(there) || get(there) || (FFlag::LuauNormalizeBlockedTypes && get(there))) + else if (get(there) || get(there) || (FFlag::LuauNormalizeBlockedTypes && get(there)) || + get(there)) { NormalizedType thereNorm{builtinTypes}; NormalizedType topNorm{builtinTypes}; @@ -2856,7 +2862,8 @@ bool isConsistentSubtype(TypeId subTy, TypeId superTy, NotNull scope, Not return ok; } -bool isConsistentSubtype(TypePackId subPack, TypePackId superPack, NotNull scope, NotNull builtinTypes, InternalErrorReporter& ice) +bool isConsistentSubtype( + TypePackId subPack, TypePackId superPack, NotNull scope, NotNull builtinTypes, InternalErrorReporter& ice) { UnifierSharedState sharedState{&ice}; TypeArena arena; diff --git a/Analysis/src/Substitution.cpp b/Analysis/src/Substitution.cpp index 160647a05..935d85d71 100644 --- a/Analysis/src/Substitution.cpp +++ b/Analysis/src/Substitution.cpp @@ -9,7 +9,7 @@ #include LUAU_FASTFLAGVARIABLE(LuauSubstitutionFixMissingFields, false) -LUAU_FASTFLAG(LuauClonePublicInterfaceLess) +LUAU_FASTFLAG(LuauClonePublicInterfaceLess2) LUAU_FASTINTVARIABLE(LuauTarjanChildLimit, 10000) LUAU_FASTFLAGVARIABLE(LuauClassTypeVarsInSubstitution, false) LUAU_FASTFLAGVARIABLE(LuauSubstitutionReentrant, false) @@ -17,6 +17,181 @@ LUAU_FASTFLAGVARIABLE(LuauSubstitutionReentrant, false) namespace Luau { +static TypeId DEPRECATED_shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool alwaysClone) +{ + ty = log->follow(ty); + + TypeId result = ty; + + if (auto pty = log->pending(ty)) + ty = &pty->pending; + + if (const FunctionType* ftv = get(ty)) + { + FunctionType clone = FunctionType{ftv->level, ftv->scope, ftv->argTypes, ftv->retTypes, ftv->definition, ftv->hasSelf}; + clone.generics = ftv->generics; + clone.genericPacks = ftv->genericPacks; + clone.magicFunction = ftv->magicFunction; + clone.dcrMagicFunction = ftv->dcrMagicFunction; + clone.dcrMagicRefinement = ftv->dcrMagicRefinement; + clone.tags = ftv->tags; + clone.argNames = ftv->argNames; + result = dest.addType(std::move(clone)); + } + else if (const TableType* ttv = get(ty)) + { + LUAU_ASSERT(!ttv->boundTo); + TableType clone = TableType{ttv->props, ttv->indexer, ttv->level, ttv->scope, ttv->state}; + clone.definitionModuleName = ttv->definitionModuleName; + clone.definitionLocation = ttv->definitionLocation; + clone.name = ttv->name; + clone.syntheticName = ttv->syntheticName; + clone.instantiatedTypeParams = ttv->instantiatedTypeParams; + clone.instantiatedTypePackParams = ttv->instantiatedTypePackParams; + clone.tags = ttv->tags; + result = dest.addType(std::move(clone)); + } + else if (const MetatableType* mtv = get(ty)) + { + MetatableType clone = MetatableType{mtv->table, mtv->metatable}; + clone.syntheticName = mtv->syntheticName; + result = dest.addType(std::move(clone)); + } + else if (const UnionType* utv = get(ty)) + { + UnionType clone; + clone.options = utv->options; + result = dest.addType(std::move(clone)); + } + else if (const IntersectionType* itv = get(ty)) + { + IntersectionType clone; + clone.parts = itv->parts; + result = dest.addType(std::move(clone)); + } + else if (const PendingExpansionType* petv = get(ty)) + { + PendingExpansionType clone{petv->prefix, petv->name, petv->typeArguments, petv->packArguments}; + result = dest.addType(std::move(clone)); + } + else if (const NegationType* ntv = get(ty)) + { + result = dest.addType(NegationType{ntv->ty}); + } + else + return result; + + asMutable(result)->documentationSymbol = ty->documentationSymbol; + return result; +} + +static TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool alwaysClone) +{ + if (!FFlag::LuauClonePublicInterfaceLess2) + return DEPRECATED_shallowClone(ty, dest, log, alwaysClone); + + auto go = [ty, &dest, alwaysClone](auto&& a) { + using T = std::decay_t; + + if constexpr (std::is_same_v) + return ty; + else if constexpr (std::is_same_v) + { + // This should never happen, but visit() cannot see it. + LUAU_ASSERT(!"shallowClone didn't follow its argument!"); + return dest.addType(BoundType{a.boundTo}); + } + else if constexpr (std::is_same_v) + return dest.addType(a); + else if constexpr (std::is_same_v) + return ty; + else if constexpr (std::is_same_v) + return ty; + else if constexpr (std::is_same_v) + return ty; + else if constexpr (std::is_same_v) + return ty; + else if constexpr (std::is_same_v) + return ty; + else if constexpr (std::is_same_v) + return ty; + else if constexpr (std::is_same_v) + return ty; + else if constexpr (std::is_same_v) + return ty; + else if constexpr (std::is_same_v) + return dest.addType(a); + else if constexpr (std::is_same_v) + { + FunctionType clone = FunctionType{a.level, a.scope, a.argTypes, a.retTypes, a.definition, a.hasSelf}; + clone.generics = a.generics; + clone.genericPacks = a.genericPacks; + clone.magicFunction = a.magicFunction; + clone.dcrMagicFunction = a.dcrMagicFunction; + clone.dcrMagicRefinement = a.dcrMagicRefinement; + clone.tags = a.tags; + clone.argNames = a.argNames; + return dest.addType(std::move(clone)); + } + else if constexpr (std::is_same_v) + { + LUAU_ASSERT(!a.boundTo); + TableType clone = TableType{a.props, a.indexer, a.level, a.scope, a.state}; + clone.definitionModuleName = a.definitionModuleName; + clone.definitionLocation = a.definitionLocation; + clone.name = a.name; + clone.syntheticName = a.syntheticName; + clone.instantiatedTypeParams = a.instantiatedTypeParams; + clone.instantiatedTypePackParams = a.instantiatedTypePackParams; + clone.tags = a.tags; + return dest.addType(std::move(clone)); + } + else if constexpr (std::is_same_v) + { + MetatableType clone = MetatableType{a.table, a.metatable}; + clone.syntheticName = a.syntheticName; + return dest.addType(std::move(clone)); + } + else if constexpr (std::is_same_v) + { + UnionType clone; + clone.options = a.options; + return dest.addType(std::move(clone)); + } + else if constexpr (std::is_same_v) + { + IntersectionType clone; + clone.parts = a.parts; + return dest.addType(std::move(clone)); + } + else if constexpr (std::is_same_v) + { + if (alwaysClone) + { + ClassType clone{a.name, a.props, a.parent, a.metatable, a.tags, a.userData, a.definitionModuleName}; + return dest.addType(std::move(clone)); + } + else + return ty; + } + else if constexpr (std::is_same_v) + return dest.addType(NegationType{a.ty}); + else + static_assert(always_false_v, "Non-exhaustive shallowClone switch"); + }; + + ty = log->follow(ty); + + if (auto pty = log->pending(ty)) + ty = &pty->pending; + + TypeId resTy = visit(go, ty->ty); + if (resTy != ty) + asMutable(resTy)->documentationSymbol = ty->documentationSymbol; + + return resTy; +} + void Tarjan::visitChildren(TypeId ty, int index) { LUAU_ASSERT(ty == log->follow(ty)); @@ -469,7 +644,7 @@ std::optional Substitution::substitute(TypePackId tp) TypeId Substitution::clone(TypeId ty) { - return shallowClone(ty, *arena, log, /* alwaysClone */ FFlag::LuauClonePublicInterfaceLess); + return shallowClone(ty, *arena, log, /* alwaysClone */ FFlag::LuauClonePublicInterfaceLess2); } TypePackId Substitution::clone(TypePackId tp) @@ -494,7 +669,7 @@ TypePackId Substitution::clone(TypePackId tp) clone.hidden = vtp->hidden; return addTypePack(std::move(clone)); } - else if (FFlag::LuauClonePublicInterfaceLess) + else if (FFlag::LuauClonePublicInterfaceLess2) { return addTypePack(*tp); } diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index d0c539845..5c0f48fae 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -85,6 +85,11 @@ struct FindCyclicTypes final : TypeVisitor { return false; } + + bool visit(TypeId, const PendingExpansionType&) override + { + return false; + } }; template @@ -1518,7 +1523,7 @@ std::string toString(const Constraint& constraint, ToStringOptions& opts) } else if constexpr (std::is_same_v) { - return "call " + tos(c.fn) + " with { result = " + tos(c.result) + " }"; + return "call " + tos(c.fn) + "( " + tos(c.argsPack) + " )" + " with { result = " + tos(c.result) + " }"; } else if constexpr (std::is_same_v) { diff --git a/Analysis/src/Type.cpp b/Analysis/src/Type.cpp index 42fa40a54..021d95285 100644 --- a/Analysis/src/Type.cpp +++ b/Analysis/src/Type.cpp @@ -26,7 +26,6 @@ LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0) LUAU_FASTINT(LuauTypeInferRecursionLimit) LUAU_FASTFLAG(LuauInstantiateInSubtyping) LUAU_FASTFLAG(LuauNormalizeBlockedTypes) -LUAU_FASTFLAGVARIABLE(LuauMatchReturnsOptionalString, false); namespace Luau { @@ -432,7 +431,7 @@ bool hasLength(TypeId ty, DenseHashSet& seen, int* recursionCount) } BlockedType::BlockedType() - : index(FFlag::LuauNormalizeBlockedTypes ? Unifiable::freshIndex() : ++DEPRECATED_nextIndex) + : index(FFlag::LuauNormalizeBlockedTypes ? Unifiable::freshIndex() : ++DEPRECATED_nextIndex) { } @@ -1219,12 +1218,12 @@ static std::vector parsePatternString(NotNull builtinTypes if (i + 1 < size && data[i + 1] == ')') { i++; - result.push_back(FFlag::LuauMatchReturnsOptionalString ? builtinTypes->optionalNumberType : builtinTypes->numberType); + result.push_back(builtinTypes->optionalNumberType); continue; } ++depth; - result.push_back(FFlag::LuauMatchReturnsOptionalString ? builtinTypes->optionalStringType : builtinTypes->stringType); + result.push_back(builtinTypes->optionalStringType); } else if (data[i] == ')') { @@ -1242,7 +1241,7 @@ static std::vector parsePatternString(NotNull builtinTypes return std::vector(); if (result.empty()) - result.push_back(FFlag::LuauMatchReturnsOptionalString ? builtinTypes->optionalStringType : builtinTypes->stringType); + result.push_back(builtinTypes->optionalStringType); return result; } diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index a160a1d26..ec71a583a 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -568,6 +568,10 @@ struct TypeChecker2 { // nothing } + else if (isOptional(iteratorTy)) + { + reportError(OptionalValueAccess{iteratorTy}, forInStatement->values.data[0]->location); + } else if (std::optional iterMmTy = findMetatableEntry(builtinTypes, module->errors, iteratorTy, "__iter", forInStatement->values.data[0]->location)) { @@ -973,6 +977,12 @@ struct TypeChecker2 else if (auto utv = get(functionType)) { // Sometimes it's okay to call a union of functions, but only if all of the functions are the same. + // Another scenario we might run into it is if the union has a nil member. In this case, we want to throw an error + if (isOptional(functionType)) + { + reportError(OptionalValueAccess{functionType}, call->location); + return; + } std::optional fst; for (TypeId ty : utv) { @@ -1187,6 +1197,8 @@ struct TypeChecker2 else reportError(CannotExtendTable{exprType, CannotExtendTable::Indexer, "indexer??"}, indexExpr->location); } + else if (get(exprType) && isOptional(exprType)) + reportError(OptionalValueAccess{exprType}, indexExpr->location); } void visit(AstExprFunction* fn) @@ -1297,9 +1309,13 @@ struct TypeChecker2 DenseHashSet seen{nullptr}; int recursionCount = 0; + if (!hasLength(operandType, seen, &recursionCount)) { - reportError(NotATable{operandType}, expr->location); + if (isOptional(operandType)) + reportError(OptionalValueAccess{operandType}, expr->location); + else + reportError(NotATable{operandType}, expr->location); } } else if (expr->op == AstExprUnary::Op::Minus) diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index abc652861..48ff6a209 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -689,11 +689,10 @@ LUAU_NOINLINE void TypeChecker::checkBlockTypeAliases(const ScopePtr& scope, std if (duplicateTypeAliases.contains({typealias->exported, name})) continue; - TypeId type = bindings[name].type; - if (get(follow(type))) + TypeId type = follow(bindings[name].type); + if (get(type)) { - Type* mty = asMutable(follow(type)); - mty->reassign(*errorRecoveryType(anyType)); + asMutable(type)->ty.emplace(errorRecoveryType(anyType)); reportError(TypeError{typealias->location, OccursCheckFailed{}}); } diff --git a/Analysis/src/TypeReduction.cpp b/Analysis/src/TypeReduction.cpp index 6a9fadfad..310df766f 100644 --- a/Analysis/src/TypeReduction.cpp +++ b/Analysis/src/TypeReduction.cpp @@ -331,7 +331,7 @@ TypeId TypeReducer::reduce(TypeId ty) if (edge->irreducible) return edge->type; else - ty = edge->type; + ty = follow(edge->type); } else if (cyclics->contains(ty)) return ty; diff --git a/Analysis/src/Unifiable.cpp b/Analysis/src/Unifiable.cpp index dcb2d3673..abdc6c329 100644 --- a/Analysis/src/Unifiable.cpp +++ b/Analysis/src/Unifiable.cpp @@ -12,7 +12,7 @@ int freshIndex() { return ++nextIndex; } - + Free::Free(TypeLevel level) : index(++nextIndex) , level(level) diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 9f30d11ba..5f01a6062 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -192,6 +192,18 @@ struct SkipCacheForType final : TypeOnceVisitor return false; } + bool visit(TypeId, const BlockedType&) override + { + result = true; + return false; + } + + bool visit(TypeId, const PendingExpansionType&) override + { + result = true; + return false; + } + bool visit(TypeId ty, const TableType&) override { // Types from other modules don't contain mutable elements and are ok to cache @@ -259,6 +271,12 @@ struct SkipCacheForType final : TypeOnceVisitor return false; } + bool visit(TypePackId tp, const BlockedTypePack&) override + { + result = true; + return false; + } + const DenseHashMap& skipCacheForType; const TypeArena* typeArena = nullptr; bool result = false; @@ -386,6 +404,12 @@ void Unifier::tryUnify(TypeId subTy, TypeId superTy, bool isFunctionCall, bool i tryUnify_(subTy, superTy, isFunctionCall, isIntersection); } +static bool isBlocked(const TxnLog& log, TypeId ty) +{ + ty = log.follow(ty); + return get(ty) || get(ty); +} + void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool isIntersection) { RecursionLimiter _ra(&sharedState.counters.recursionCount, sharedState.counters.recursionLimit); @@ -531,11 +555,15 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool size_t errorCount = errors.size(); - if (log.getMutable(subTy) && log.getMutable(superTy)) + if (isBlocked(log, subTy) && isBlocked(log, superTy)) { blockedTypes.push_back(subTy); blockedTypes.push_back(superTy); } + else if (isBlocked(log, subTy)) + blockedTypes.push_back(subTy); + else if (isBlocked(log, superTy)) + blockedTypes.push_back(superTy); else if (const UnionType* subUnion = log.getMutable(subTy)) { tryUnifyUnionWithType(subTy, subUnion, superTy); @@ -890,7 +918,8 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp if (!subNorm || !superNorm) return reportError(location, UnificationTooComplex{}); else if ((failedOptionCount == 1 || foundHeuristic) && failedOption) - innerState.tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "None of the union options are compatible. For example:", *failedOption); + innerState.tryUnifyNormalizedTypes( + subTy, superTy, *subNorm, *superNorm, "None of the union options are compatible. For example:", *failedOption); else innerState.tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the union options are compatible"); if (!innerState.failure) diff --git a/CodeGen/include/Luau/AssemblyBuilderA64.h b/CodeGen/include/Luau/AssemblyBuilderA64.h index 0179967af..1190e9754 100644 --- a/CodeGen/include/Luau/AssemblyBuilderA64.h +++ b/CodeGen/include/Luau/AssemblyBuilderA64.h @@ -24,20 +24,24 @@ class AssemblyBuilderA64 // Moves void mov(RegisterA64 dst, RegisterA64 src); - void mov(RegisterA64 dst, uint16_t src, int shift = 0); + void mov(RegisterA64 dst, int src); // macro + + // Moves of 32-bit immediates get decomposed into one or more of these + void movz(RegisterA64 dst, uint16_t src, int shift = 0); + void movn(RegisterA64 dst, uint16_t src, int shift = 0); void movk(RegisterA64 dst, uint16_t src, int shift = 0); // Arithmetics void add(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, int shift = 0); - void add(RegisterA64 dst, RegisterA64 src1, int src2); + void add(RegisterA64 dst, RegisterA64 src1, uint16_t src2); void sub(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, int shift = 0); - void sub(RegisterA64 dst, RegisterA64 src1, int src2); + void sub(RegisterA64 dst, RegisterA64 src1, uint16_t src2); void neg(RegisterA64 dst, RegisterA64 src); // Comparisons // Note: some arithmetic instructions also have versions that update flags (ADDS etc) but we aren't using them atm void cmp(RegisterA64 src1, RegisterA64 src2); - void cmp(RegisterA64 src1, int src2); + void cmp(RegisterA64 src1, uint16_t src2); // Bitwise // Note: shifted-register support and bitfield operations are omitted for simplicity @@ -63,11 +67,13 @@ class AssemblyBuilderA64 void ldrsb(RegisterA64 dst, AddressA64 src); void ldrsh(RegisterA64 dst, AddressA64 src); void ldrsw(RegisterA64 dst, AddressA64 src); + void ldp(RegisterA64 dst1, RegisterA64 dst2, AddressA64 src); // Store void str(RegisterA64 src, AddressA64 dst); void strb(RegisterA64 src, AddressA64 dst); void strh(RegisterA64 src, AddressA64 dst); + void stp(RegisterA64 src1, RegisterA64 src2, AddressA64 dst); // Control flow // Note: tbz/tbnz are currently not supported because they have 15-bit offsets and we don't support branch thunks @@ -84,6 +90,9 @@ class AssemblyBuilderA64 void adr(RegisterA64 dst, uint64_t value); void adr(RegisterA64 dst, double value); + // Address of code (label) + void adr(RegisterA64 dst, Label& label); + // Run final checks bool finalize(); @@ -113,6 +122,9 @@ class AssemblyBuilderA64 const bool logText = false; + // Maximum immediate argument to functions like add/sub/cmp + static constexpr size_t kMaxImmediate = (1 << 12) - 1; + private: // Instruction archetypes void place0(const char* name, uint32_t word); @@ -127,6 +139,8 @@ class AssemblyBuilderA64 void placeBCR(const char* name, Label& label, uint8_t op, RegisterA64 cond); void placeBR(const char* name, RegisterA64 src, uint32_t op); void placeADR(const char* name, RegisterA64 src, uint8_t op); + void placeADR(const char* name, RegisterA64 src, uint8_t op, Label& label); + void placeP(const char* name, RegisterA64 dst1, RegisterA64 dst2, AddressA64 src, uint8_t op, uint8_t size); void place(uint32_t word); @@ -146,6 +160,7 @@ class AssemblyBuilderA64 LUAU_NOINLINE void log(const char* opcode, RegisterA64 dst, RegisterA64 src); LUAU_NOINLINE void log(const char* opcode, RegisterA64 dst, int src, int shift = 0); LUAU_NOINLINE void log(const char* opcode, RegisterA64 dst, AddressA64 src); + LUAU_NOINLINE void log(const char* opcode, RegisterA64 dst1, RegisterA64 dst2, AddressA64 src); LUAU_NOINLINE void log(const char* opcode, RegisterA64 src, Label label); LUAU_NOINLINE void log(const char* opcode, RegisterA64 src); LUAU_NOINLINE void log(const char* opcode, Label label); diff --git a/CodeGen/include/Luau/IrAnalysis.h b/CodeGen/include/Luau/IrAnalysis.h index 5c2bc4dfc..470690b95 100644 --- a/CodeGen/include/Luau/IrAnalysis.h +++ b/CodeGen/include/Luau/IrAnalysis.h @@ -52,8 +52,8 @@ void computeCfgInfo(IrFunction& function); struct BlockIteratorWrapper { - uint32_t* itBegin = nullptr; - uint32_t* itEnd = nullptr; + const uint32_t* itBegin = nullptr; + const uint32_t* itEnd = nullptr; bool empty() const { @@ -65,19 +65,19 @@ struct BlockIteratorWrapper return size_t(itEnd - itBegin); } - uint32_t* begin() const + const uint32_t* begin() const { return itBegin; } - uint32_t* end() const + const uint32_t* end() const { return itEnd; } }; -BlockIteratorWrapper predecessors(CfgInfo& cfg, uint32_t blockIdx); -BlockIteratorWrapper successors(CfgInfo& cfg, uint32_t blockIdx); +BlockIteratorWrapper predecessors(const CfgInfo& cfg, uint32_t blockIdx); +BlockIteratorWrapper successors(const CfgInfo& cfg, uint32_t blockIdx); } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/include/Luau/IrData.h b/CodeGen/include/Luau/IrData.h index 67e706324..e8b2bc621 100644 --- a/CodeGen/include/Luau/IrData.h +++ b/CodeGen/include/Luau/IrData.h @@ -385,17 +385,15 @@ enum class IrCmd : uint8_t LOP_SETLIST, // Call specified function - // A: unsigned int (bytecode instruction index) - // B: Rn (function, followed by arguments) - // C: int (argument count or -1 to use all arguments up to stack top) - // D: int (result count or -1 to preserve all results and adjust stack top) - // Note: return values are placed starting from Rn specified in 'B' + // A: Rn (function, followed by arguments) + // B: int (argument count or -1 to use all arguments up to stack top) + // C: int (result count or -1 to preserve all results and adjust stack top) + // Note: return values are placed starting from Rn specified in 'A' LOP_CALL, // Return specified values from the function - // A: unsigned int (bytecode instruction index) - // B: Rn (value start) - // C: int (result count or -1 to return all values up to stack top) + // A: Rn (value start) + // B: int (result count or -1 to return all values up to stack top) LOP_RETURN, // Adjust loop variables for one iteration of a generic for loop, jump back to the loop header if loop needs to continue @@ -421,10 +419,9 @@ enum class IrCmd : uint8_t LOP_FORGPREP_XNEXT_FALLBACK, // Perform `and` or `or` operation (selecting lhs or rhs based on whether the lhs is truthy) and put the result into target register - // A: unsigned int (bytecode instruction index) - // B: Rn (target) - // C: Rn (lhs) - // D: Rn or Kn (rhs) + // A: Rn (target) + // B: Rn (lhs) + // C: Rn or Kn (rhs) LOP_AND, LOP_ANDK, LOP_OR, @@ -790,12 +787,6 @@ struct IrFunction return value.valueDouble; } - IrCondition conditionOp(IrOp op) - { - LUAU_ASSERT(op.kind == IrOpKind::Condition); - return IrCondition(op.index); - } - uint32_t getBlockIndex(const IrBlock& block) { // Can only be called with blocks from our vector @@ -804,5 +795,29 @@ struct IrFunction } }; +inline IrCondition conditionOp(IrOp op) +{ + LUAU_ASSERT(op.kind == IrOpKind::Condition); + return IrCondition(op.index); +} + +inline int vmRegOp(IrOp op) +{ + LUAU_ASSERT(op.kind == IrOpKind::VmReg); + return op.index; +} + +inline int vmConstOp(IrOp op) +{ + LUAU_ASSERT(op.kind == IrOpKind::VmConst); + return op.index; +} + +inline int vmUpvalueOp(IrOp op) +{ + LUAU_ASSERT(op.kind == IrOpKind::VmUpvalue); + return op.index; +} + } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/include/Luau/IrDump.h b/CodeGen/include/Luau/IrDump.h index ae517e894..1bc31d9d7 100644 --- a/CodeGen/include/Luau/IrDump.h +++ b/CodeGen/include/Luau/IrDump.h @@ -19,9 +19,9 @@ const char* getBlockKindName(IrBlockKind kind); struct IrToStringContext { std::string& result; - std::vector& blocks; - std::vector& constants; - CfgInfo& cfg; + const std::vector& blocks; + const std::vector& constants; + const CfgInfo& cfg; }; void toString(IrToStringContext& ctx, const IrInst& inst, uint32_t index); @@ -33,13 +33,13 @@ void toString(std::string& result, IrConst constant); void toStringDetailed(IrToStringContext& ctx, const IrInst& inst, uint32_t index, bool includeUseInfo); void toStringDetailed(IrToStringContext& ctx, const IrBlock& block, uint32_t index, bool includeUseInfo); // Block title -std::string toString(IrFunction& function, bool includeUseInfo); +std::string toString(const IrFunction& function, bool includeUseInfo); -std::string dump(IrFunction& function); +std::string dump(const IrFunction& function); -std::string toDot(IrFunction& function, bool includeInst); +std::string toDot(const IrFunction& function, bool includeInst); -std::string dumpDot(IrFunction& function, bool includeInst); +std::string dumpDot(const IrFunction& function, bool includeInst); } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/AssemblyBuilderA64.cpp b/CodeGen/src/AssemblyBuilderA64.cpp index 308747d26..bedd27409 100644 --- a/CodeGen/src/AssemblyBuilderA64.cpp +++ b/CodeGen/src/AssemblyBuilderA64.cpp @@ -45,9 +45,30 @@ void AssemblyBuilderA64::mov(RegisterA64 dst, RegisterA64 src) placeSR2("mov", dst, src, 0b01'01010); } -void AssemblyBuilderA64::mov(RegisterA64 dst, uint16_t src, int shift) +void AssemblyBuilderA64::mov(RegisterA64 dst, int src) { - placeI16("mov", dst, src, 0b10'100101, shift); + if (src >= 0) + { + movz(dst, src & 0xffff); + if (src > 0xffff) + movk(dst, src >> 16, 16); + } + else + { + movn(dst, ~src & 0xffff); + if (src < -0x10000) + movk(dst, (src >> 16) & 0xffff, 16); + } +} + +void AssemblyBuilderA64::movz(RegisterA64 dst, uint16_t src, int shift) +{ + placeI16("movz", dst, src, 0b10'100101, shift); +} + +void AssemblyBuilderA64::movn(RegisterA64 dst, uint16_t src, int shift) +{ + placeI16("movn", dst, src, 0b00'100101, shift); } void AssemblyBuilderA64::movk(RegisterA64 dst, uint16_t src, int shift) @@ -60,7 +81,7 @@ void AssemblyBuilderA64::add(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2 placeSR3("add", dst, src1, src2, 0b00'01011, shift); } -void AssemblyBuilderA64::add(RegisterA64 dst, RegisterA64 src1, int src2) +void AssemblyBuilderA64::add(RegisterA64 dst, RegisterA64 src1, uint16_t src2) { placeI12("add", dst, src1, src2, 0b00'10001); } @@ -70,7 +91,7 @@ void AssemblyBuilderA64::sub(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2 placeSR3("sub", dst, src1, src2, 0b10'01011, shift); } -void AssemblyBuilderA64::sub(RegisterA64 dst, RegisterA64 src1, int src2) +void AssemblyBuilderA64::sub(RegisterA64 dst, RegisterA64 src1, uint16_t src2) { placeI12("sub", dst, src1, src2, 0b10'10001); } @@ -87,7 +108,7 @@ void AssemblyBuilderA64::cmp(RegisterA64 src1, RegisterA64 src2) placeSR3("cmp", dst, src1, src2, 0b11'01011); } -void AssemblyBuilderA64::cmp(RegisterA64 src1, int src2) +void AssemblyBuilderA64::cmp(RegisterA64 src1, uint16_t src2) { RegisterA64 dst = src1.kind == KindA64::x ? xzr : wzr; @@ -186,6 +207,14 @@ void AssemblyBuilderA64::ldrsw(RegisterA64 dst, AddressA64 src) placeA("ldrsw", dst, src, 0b11100010, 0b10); } +void AssemblyBuilderA64::ldp(RegisterA64 dst1, RegisterA64 dst2, AddressA64 src) +{ + LUAU_ASSERT(dst1.kind == KindA64::x || dst1.kind == KindA64::w); + LUAU_ASSERT(dst1.kind == dst2.kind); + + placeP("ldp", dst1, dst2, src, 0b101'0'010'1, 0b10 | uint8_t(dst1.kind == KindA64::x)); +} + void AssemblyBuilderA64::str(RegisterA64 src, AddressA64 dst) { LUAU_ASSERT(src.kind == KindA64::x || src.kind == KindA64::w); @@ -207,6 +236,14 @@ void AssemblyBuilderA64::strh(RegisterA64 src, AddressA64 dst) placeA("strh", src, dst, 0b11100000, 0b01); } +void AssemblyBuilderA64::stp(RegisterA64 src1, RegisterA64 src2, AddressA64 dst) +{ + LUAU_ASSERT(src1.kind == KindA64::x || src1.kind == KindA64::w); + LUAU_ASSERT(src1.kind == src2.kind); + + placeP("stp", src1, src2, dst, 0b101'0'010'0, 0b10 | uint8_t(src1.kind == KindA64::x)); +} + void AssemblyBuilderA64::b(Label& label) { // Note: we aren't using 'b' form since it has a 26-bit immediate which requires custom fixup logic @@ -276,6 +313,11 @@ void AssemblyBuilderA64::adr(RegisterA64 dst, double value) patchImm19(location, -int(location) - int((data.size() - pos) / 4)); } +void AssemblyBuilderA64::adr(RegisterA64 dst, Label& label) +{ + placeADR("adr", dst, 0b10000, label); +} + bool AssemblyBuilderA64::finalize() { code.resize(codePos - code.data()); @@ -511,6 +553,32 @@ void AssemblyBuilderA64::placeADR(const char* name, RegisterA64 dst, uint8_t op) commit(); } +void AssemblyBuilderA64::placeADR(const char* name, RegisterA64 dst, uint8_t op, Label& label) +{ + LUAU_ASSERT(dst.kind == KindA64::x); + + place(dst.index | (op << 24)); + commit(); + + patchLabel(label); + + if (logText) + log(name, dst, label); +} + +void AssemblyBuilderA64::placeP(const char* name, RegisterA64 src1, RegisterA64 src2, AddressA64 dst, uint8_t op, uint8_t size) +{ + if (logText) + log(name, src1, src2, dst); + + LUAU_ASSERT(dst.kind == AddressKindA64::imm); + LUAU_ASSERT(dst.data >= -128 * (1 << size) && dst.data <= 127 * (1 << size)); + LUAU_ASSERT(dst.data % (1 << size) == 0); + + place(src1.index | (dst.base.index << 5) | (src2.index << 10) | (((dst.data >> size) & 127) << 15) | (op << 22) | (size << 31)); + commit(); +} + void AssemblyBuilderA64::place(uint32_t word) { LUAU_ASSERT(codePos < codeEnd); @@ -628,6 +696,17 @@ void AssemblyBuilderA64::log(const char* opcode, RegisterA64 dst, AddressA64 src text.append("\n"); } +void AssemblyBuilderA64::log(const char* opcode, RegisterA64 dst1, RegisterA64 dst2, AddressA64 src) +{ + logAppend(" %-12s", opcode); + log(dst1); + text.append(","); + log(dst2); + text.append(","); + log(src); + text.append("\n"); +} + void AssemblyBuilderA64::log(const char* opcode, RegisterA64 dst, RegisterA64 src) { logAppend(" %-12s", opcode); diff --git a/CodeGen/src/CodeGen.cpp b/CodeGen/src/CodeGen.cpp index ce490f916..5ef5ba64f 100644 --- a/CodeGen/src/CodeGen.cpp +++ b/CodeGen/src/CodeGen.cpp @@ -6,6 +6,8 @@ #include "Luau/CodeBlockUnwind.h" #include "Luau/IrAnalysis.h" #include "Luau/IrBuilder.h" +#include "Luau/IrDump.h" +#include "Luau/IrUtils.h" #include "Luau/OptimizeConstProp.h" #include "Luau/OptimizeFinalX64.h" @@ -13,19 +15,24 @@ #include "Luau/UnwindBuilderDwarf2.h" #include "Luau/UnwindBuilderWin.h" -#include "Luau/AssemblyBuilderX64.h" #include "Luau/AssemblyBuilderA64.h" +#include "Luau/AssemblyBuilderX64.h" #include "CustomExecUtils.h" -#include "CodeGenX64.h" +#include "NativeState.h" + #include "CodeGenA64.h" +#include "EmitCommonA64.h" +#include "IrLoweringA64.h" + +#include "CodeGenX64.h" #include "EmitCommonX64.h" #include "EmitInstructionX64.h" #include "IrLoweringX64.h" -#include "NativeState.h" #include "lapi.h" +#include #include #if defined(__x86_64__) || defined(_M_X64) @@ -60,6 +67,148 @@ static NativeProto* createNativeProto(Proto* proto, const IrBuilder& ir) return result; } +template +static void lowerImpl(AssemblyBuilder& build, IrLowering& lowering, IrFunction& function, int bytecodeid, AssemblyOptions options) +{ + // While we will need a better block ordering in the future, right now we want to mostly preserve build order with fallbacks outlined + std::vector sortedBlocks; + sortedBlocks.reserve(function.blocks.size()); + for (uint32_t i = 0; i < function.blocks.size(); i++) + sortedBlocks.push_back(i); + + std::sort(sortedBlocks.begin(), sortedBlocks.end(), [&](uint32_t idxA, uint32_t idxB) { + const IrBlock& a = function.blocks[idxA]; + const IrBlock& b = function.blocks[idxB]; + + // Place fallback blocks at the end + if ((a.kind == IrBlockKind::Fallback) != (b.kind == IrBlockKind::Fallback)) + return (a.kind == IrBlockKind::Fallback) < (b.kind == IrBlockKind::Fallback); + + // Try to order by instruction order + return a.start < b.start; + }); + + DenseHashMap bcLocations{~0u}; + + // Create keys for IR assembly locations that original bytecode instruction are interested in + for (const auto& [irLocation, asmLocation] : function.bcMapping) + { + if (irLocation != ~0u) + bcLocations[irLocation] = 0; + } + + DenseHashMap indexIrToBc{~0u}; + bool outputEnabled = options.includeAssembly || options.includeIr; + + if (outputEnabled && options.annotator) + { + // Create reverse mapping from IR location to bytecode location + for (size_t i = 0; i < function.bcMapping.size(); ++i) + { + uint32_t irLocation = function.bcMapping[i].irLocation; + + if (irLocation != ~0u) + indexIrToBc[irLocation] = uint32_t(i); + } + } + + IrToStringContext ctx{build.text, function.blocks, function.constants, function.cfg}; + + // We use this to skip outlined fallback blocks from IR/asm text output + size_t textSize = build.text.length(); + uint32_t codeSize = build.getCodeSize(); + bool seenFallback = false; + + IrBlock dummy; + dummy.start = ~0u; + + for (size_t i = 0; i < sortedBlocks.size(); ++i) + { + uint32_t blockIndex = sortedBlocks[i]; + + IrBlock& block = function.blocks[blockIndex]; + + if (block.kind == IrBlockKind::Dead) + continue; + + LUAU_ASSERT(block.start != ~0u); + LUAU_ASSERT(block.finish != ~0u); + + // If we want to skip fallback code IR/asm, we'll record when those blocks start once we see them + if (block.kind == IrBlockKind::Fallback && !seenFallback) + { + textSize = build.text.length(); + codeSize = build.getCodeSize(); + seenFallback = true; + } + + if (options.includeIr) + { + build.logAppend("# "); + toStringDetailed(ctx, block, blockIndex, /* includeUseInfo */ true); + } + + build.setLabel(block.label); + + for (uint32_t index = block.start; index <= block.finish; index++) + { + LUAU_ASSERT(index < function.instructions.size()); + + // If IR instruction is the first one for the original bytecode, we can annotate it with source code text + if (outputEnabled && options.annotator) + { + if (uint32_t* bcIndex = indexIrToBc.find(index)) + options.annotator(options.annotatorContext, build.text, bytecodeid, *bcIndex); + } + + // If bytecode needs the location of this instruction for jumps, record it + if (uint32_t* bcLocation = bcLocations.find(index)) + { + Label label = (index == block.start) ? block.label : build.setLabel(); + *bcLocation = build.getLabelOffset(label); + } + + IrInst& inst = function.instructions[index]; + + // Skip pseudo instructions, but make sure they are not used at this stage + // This also prevents them from getting into text output when that's enabled + if (isPseudo(inst.cmd)) + { + LUAU_ASSERT(inst.useCount == 0); + continue; + } + + if (options.includeIr) + { + build.logAppend("# "); + toStringDetailed(ctx, inst, index, /* includeUseInfo */ true); + } + + IrBlock& next = i + 1 < sortedBlocks.size() ? function.blocks[sortedBlocks[i + 1]] : dummy; + + lowering.lowerInst(inst, index, next); + } + + if (options.includeIr) + build.logAppend("#\n"); + } + + if (outputEnabled && !options.includeOutlinedCode && seenFallback) + { + build.text.resize(textSize); + + if (options.includeAssembly) + build.logAppend("; skipping %u bytes of outlined code\n", unsigned((build.getCodeSize() - codeSize) * sizeof(build.code[0]))); + } + + // Copy assembly locations of IR instructions that are mapped to bytecode instructions + for (auto& [irLocation, asmLocation] : function.bcMapping) + { + if (irLocation != ~0u) + asmLocation = bcLocations[irLocation]; + } +} + [[maybe_unused]] static void lowerIr( X64::AssemblyBuilderX64& build, IrBuilder& ir, NativeState& data, ModuleHelpers& helpers, Proto* proto, AssemblyOptions options) { @@ -69,24 +218,34 @@ static NativeProto* createNativeProto(Proto* proto, const IrBuilder& ir) build.align(kFunctionAlignment, X64::AlignmentDataX64::Ud2); - X64::IrLoweringX64 lowering(build, helpers, data, proto, ir.function); + X64::IrLoweringX64 lowering(build, helpers, data, ir.function); - lowering.lower(options); + lowerImpl(build, lowering, ir.function, proto->bytecodeid, options); } [[maybe_unused]] static void lowerIr( A64::AssemblyBuilderA64& build, IrBuilder& ir, NativeState& data, ModuleHelpers& helpers, Proto* proto, AssemblyOptions options) { - Label start = build.setLabel(); + if (A64::IrLoweringA64::canLower(ir.function)) + { + A64::IrLoweringA64 lowering(build, helpers, data, proto, ir.function); - build.mov(A64::x0, 1); // finish function in VM - build.ret(); + lowerImpl(build, lowering, ir.function, proto->bytecodeid, options); + } + else + { + // TODO: This is only needed while we don't support all IR opcodes + // When we can't translate some parts of the function, we instead encode a dummy assembly sequence that hands off control to VM + // In the future we could return nullptr from assembleFunction and handle it because there may be other reasons for why we refuse to assemble. + Label start = build.setLabel(); - // TODO: This is only needed while we don't support all IR opcodes - // When we can't translate some parts of the function, we instead encode a dummy assembly sequence that hands off control to VM - // In the future we could return nullptr from assembleFunction and handle it because there may be other reasons for why we refuse to assemble. - for (int i = 0; i < proto->sizecode; i++) - ir.function.bcMapping[i].asmLocation = build.getLabelOffset(start); + build.mov(A64::x0, 1); // finish function in VM + build.ldr(A64::x1, A64::mem(A64::rNativeContext, offsetof(NativeContext, gateExit))); + build.br(A64::x1); + + for (int i = 0; i < proto->sizecode; i++) + ir.function.bcMapping[i].asmLocation = build.getLabelOffset(start); + } } template @@ -123,15 +282,13 @@ static NativeProto* assembleFunction(AssemblyBuilder& build, NativeState& data, IrBuilder ir; ir.buildFunctionIr(proto); + computeCfgInfo(ir.function); + if (!FFlag::DebugCodegenNoOpt) { constPropInBlockChains(ir); } - // TODO: cfg info has to be computed earlier to use in optimizations - // It's done here to appear in text output and to measure performance impact on code generation - computeCfgInfo(ir.function); - lowerIr(build, ir, data, helpers, proto, options); if (build.logText) @@ -217,7 +374,8 @@ bool isSupported() return true; #elif defined(__aarch64__) - return true; + // TODO: A64 codegen does not generate correct unwind info at the moment so it requires longjmp instead of C++ exceptions + return bool(LUA_USE_LONGJMP); #else return false; #endif @@ -300,7 +458,9 @@ void compile(lua_State* L, int idx) gatherFunctions(protos, clvalue(func)->l.p); ModuleHelpers helpers; -#if !defined(__aarch64__) +#if defined(__aarch64__) + A64::assembleHelpers(build, helpers); +#else X64::assembleHelpers(build, helpers); #endif @@ -359,7 +519,9 @@ std::string getAssembly(lua_State* L, int idx, AssemblyOptions options) gatherFunctions(protos, clvalue(func)->l.p); ModuleHelpers helpers; -#if !defined(__aarch64__) +#if defined(__aarch64__) + A64::assembleHelpers(build, helpers); +#else X64::assembleHelpers(build, helpers); #endif @@ -373,8 +535,7 @@ std::string getAssembly(lua_State* L, int idx, AssemblyOptions options) build.finalize(); if (options.outputBinary) - return std::string( - reinterpret_cast(build.code.data()), reinterpret_cast(build.code.data() + build.code.size())) + + return std::string(reinterpret_cast(build.code.data()), reinterpret_cast(build.code.data() + build.code.size())) + std::string(build.data.begin(), build.data.end()); else return build.text; diff --git a/CodeGen/src/CodeGenA64.cpp b/CodeGen/src/CodeGenA64.cpp index 94d6f2e3f..028b3327c 100644 --- a/CodeGen/src/CodeGenA64.cpp +++ b/CodeGen/src/CodeGenA64.cpp @@ -6,6 +6,7 @@ #include "CustomExecUtils.h" #include "NativeState.h" +#include "EmitCommonA64.h" #include "lstate.h" @@ -21,26 +22,50 @@ bool initEntryFunction(NativeState& data) AssemblyBuilderA64 build(/* logText= */ false); UnwindBuilder& unwind = *data.unwindBuilder.get(); + // Arguments: x0 = lua_State*, x1 = Proto*, x2 = native code pointer to jump to, x3 = NativeContext* + unwind.start(); - unwind.allocStack(8); // TODO: this is only necessary to align stack by 16 bytes, as start() allocates 8b return pointer + unwind.allocStack(8); // TODO: this is just a hack to make UnwindBuilder assertions cooperate + + // prologue + build.sub(sp, sp, kStackSize); + build.stp(x29, x30, mem(sp)); // fp, lr - // TODO: prologue goes here + // stash non-volatile registers used for execution environment + build.stp(x19, x20, mem(sp, 16)); + build.stp(x21, x22, mem(sp, 32)); + build.stp(x23, x24, mem(sp, 48)); + + build.mov(x29, sp); // this is only necessary if we maintain frame pointers, which we do in the JIT for now unwind.finish(); size_t prologueSize = build.setLabel().location; // Setup native execution environment - // TODO: figure out state layout + build.mov(rState, x0); + build.mov(rNativeContext, x3); + + build.ldr(rBase, mem(x0, offsetof(lua_State, base))); // L->base + build.ldr(rConstants, mem(x1, offsetof(Proto, k))); // proto->k + build.ldr(rCode, mem(x1, offsetof(Proto, code))); // proto->code - // Jump to the specified instruction; further control flow will be handled with custom ABI with register setup from EmitCommonX64.h + build.ldr(x9, mem(x0, offsetof(lua_State, ci))); // L->ci + build.ldr(x9, mem(x9, offsetof(CallInfo, func))); // L->ci->func + build.ldr(rClosure, mem(x9, offsetof(TValue, value.gc))); // L->ci->func->value.gc aka cl + + // Jump to the specified instruction; further control flow will be handled with custom ABI with register setup from EmitCommonA64.h build.br(x2); // Even though we jumped away, we will return here in the end Label returnOff = build.setLabel(); // Cleanup and exit - // TODO: epilogue + build.ldp(x23, x24, mem(sp, 48)); + build.ldp(x21, x22, mem(sp, 32)); + build.ldp(x19, x20, mem(sp, 16)); + build.ldp(x29, x30, mem(sp)); // fp, lr + build.add(sp, sp, kStackSize); build.ret(); @@ -59,11 +84,24 @@ bool initEntryFunction(NativeState& data) // specified by the unwind information of the entry function unwind.setBeginOffset(prologueSize); - data.context.gateExit = data.context.gateEntry + returnOff.location; + data.context.gateExit = data.context.gateEntry + build.getLabelOffset(returnOff); return true; } +void assembleHelpers(AssemblyBuilderA64& build, ModuleHelpers& helpers) +{ + if (build.logText) + build.logAppend("; exitContinueVm\n"); + helpers.exitContinueVm = build.setLabel(); + emitExit(build, /* continueInVm */ true); + + if (build.logText) + build.logAppend("; exitNoContinueVm\n"); + helpers.exitNoContinueVm = build.setLabel(); + emitExit(build, /* continueInVm */ false); +} + } // namespace A64 } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/CodeGenA64.h b/CodeGen/src/CodeGenA64.h index 5043e5c67..7b792cc1b 100644 --- a/CodeGen/src/CodeGenA64.h +++ b/CodeGen/src/CodeGenA64.h @@ -7,11 +7,15 @@ namespace CodeGen { struct NativeState; +struct ModuleHelpers; namespace A64 { +class AssemblyBuilderA64; + bool initEntryFunction(NativeState& data); +void assembleHelpers(AssemblyBuilderA64& build, ModuleHelpers& helpers); } // namespace A64 } // namespace CodeGen diff --git a/CodeGen/src/CodeGenUtils.cpp b/CodeGen/src/CodeGenUtils.cpp index 77047a7ab..26568c300 100644 --- a/CodeGen/src/CodeGenUtils.cpp +++ b/CodeGen/src/CodeGenUtils.cpp @@ -126,5 +126,42 @@ void callEpilogC(lua_State* L, int nresults, int n) L->top = (nresults == LUA_MULTRET) ? res : cip->top; } +const Instruction* returnFallback(lua_State* L, StkId ra, int n) +{ + // ci is our callinfo, cip is our parent + CallInfo* ci = L->ci; + CallInfo* cip = ci - 1; + + StkId res = ci->func; // note: we assume CALL always puts func+args and expects results to start at func + + StkId vali = ra; + StkId valend = (n == LUA_MULTRET) ? L->top : ra + n; // copy as much as possible for MULTRET calls, and only as much as needed otherwise + + int nresults = ci->nresults; + + // copy return values into parent stack (but only up to nresults!), fill the rest with nil + // note: in MULTRET context nresults starts as -1 so i != 0 condition never activates intentionally + int i; + for (i = nresults; i != 0 && vali < valend; i--) + setobj2s(L, res++, vali++); + while (i-- > 0) + setnilvalue(res++); + + // pop the stack frame + L->ci = cip; + L->base = cip->base; + L->top = (nresults == LUA_MULTRET) ? res : cip->top; + + // we're done! + if (LUAU_UNLIKELY(ci->flags & LUA_CALLINFO_RETURN)) + { + L->top = res; + return NULL; + } + + LUAU_ASSERT(isLua(cip)); + return cip->savedpc; +} + } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/CodeGenUtils.h b/CodeGen/src/CodeGenUtils.h index ca190213e..5d37bfd16 100644 --- a/CodeGen/src/CodeGenUtils.h +++ b/CodeGen/src/CodeGenUtils.h @@ -16,5 +16,7 @@ void forgPrepXnextFallback(lua_State* L, TValue* ra, int pc); Closure* callProlog(lua_State* L, TValue* ra, StkId argtop, int nresults); void callEpilogC(lua_State* L, int nresults, int n); +const Instruction* returnFallback(lua_State* L, StkId ra, int n); + } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/EmitCommonA64.cpp b/CodeGen/src/EmitCommonA64.cpp new file mode 100644 index 000000000..66810d379 --- /dev/null +++ b/CodeGen/src/EmitCommonA64.cpp @@ -0,0 +1,75 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "EmitCommonA64.h" + +#include "NativeState.h" +#include "CustomExecUtils.h" + +namespace Luau +{ +namespace CodeGen +{ +namespace A64 +{ + +void emitExit(AssemblyBuilderA64& build, bool continueInVm) +{ + build.mov(x0, continueInVm); + build.ldr(x1, mem(rNativeContext, offsetof(NativeContext, gateExit))); + build.br(x1); +} + +void emitUpdateBase(AssemblyBuilderA64& build) +{ + build.ldr(rBase, mem(rState, offsetof(lua_State, base))); +} + +void emitSetSavedPc(AssemblyBuilderA64& build, int pcpos) +{ + if (pcpos * sizeof(Instruction) <= AssemblyBuilderA64::kMaxImmediate) + { + build.add(x0, rCode, uint16_t(pcpos * sizeof(Instruction))); + } + else + { + build.mov(x0, pcpos * sizeof(Instruction)); + build.add(x0, rCode, x0); + } + + build.ldr(x1, mem(rState, offsetof(lua_State, ci))); + build.str(x0, mem(x1, offsetof(CallInfo, savedpc))); +} + +void emitInterrupt(AssemblyBuilderA64& build, int pcpos) +{ + Label skip; + + build.ldr(x2, mem(rState, offsetof(lua_State, global))); + build.ldr(x2, mem(x2, offsetof(global_State, cb.interrupt))); + build.cbz(x2, skip); + + emitSetSavedPc(build, pcpos + 1); // uses x0/x1 + + // Call interrupt + // TODO: This code should be outlined so that it can be shared by multiple interruptible instructions + build.mov(x0, rState); + build.mov(w1, -1); + build.blr(x2); + + // Check if we need to exit + build.ldrb(w0, mem(rState, offsetof(lua_State, status))); + build.cbz(w0, skip); + + // L->ci->savedpc-- + build.ldr(x0, mem(rState, offsetof(lua_State, ci))); + build.ldr(x1, mem(x0, offsetof(CallInfo, savedpc))); + build.sub(x1, x1, sizeof(Instruction)); + build.str(x1, mem(x0, offsetof(CallInfo, savedpc))); + + emitExit(build, /* continueInVm */ false); + + build.setLabel(skip); +} + +} // namespace A64 +} // namespace CodeGen +} // namespace Luau diff --git a/CodeGen/src/EmitCommonA64.h b/CodeGen/src/EmitCommonA64.h new file mode 100644 index 000000000..251f6a351 --- /dev/null +++ b/CodeGen/src/EmitCommonA64.h @@ -0,0 +1,77 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/AssemblyBuilderA64.h" + +#include "EmitCommon.h" + +#include "lobject.h" +#include "ltm.h" + +// AArch64 ABI reminder: +// Arguments: x0-x7, v0-v7 +// Return: x0, v0 (or x8 that points to the address of the resulting structure) +// Volatile: x9-x14, v16-v31 ("caller-saved", any call may change them) +// Non-volatile: x19-x28, v8-v15 ("callee-saved", preserved after calls, only bottom half of SIMD registers is preserved!) +// Reserved: x16-x18: reserved for linker/platform use; x29: frame pointer (unless omitted); x30: link register; x31: stack pointer + +namespace Luau +{ +namespace CodeGen +{ + +struct NativeState; + +namespace A64 +{ + +// Data that is very common to access is placed in non-volatile registers +constexpr RegisterA64 rState = x19; // lua_State* L +constexpr RegisterA64 rBase = x20; // StkId base +constexpr RegisterA64 rNativeContext = x21; // NativeContext* context +constexpr RegisterA64 rConstants = x22; // TValue* k +constexpr RegisterA64 rClosure = x23; // Closure* cl +constexpr RegisterA64 rCode = x24; // Instruction* code + +// Native code is as stackless as the interpreter, so we can place some data on the stack once and have it accessible at any point +// See CodeGenA64.cpp for layout +constexpr unsigned kStackSize = 64; // 8 stashed registers + +inline AddressA64 luauReg(int ri) +{ + return mem(rBase, ri * sizeof(TValue)); +} + +inline AddressA64 luauRegValue(int ri) +{ + return mem(rBase, ri * sizeof(TValue) + offsetof(TValue, value)); +} + +inline AddressA64 luauRegTag(int ri) +{ + return mem(rBase, ri * sizeof(TValue) + offsetof(TValue, tt)); +} + +inline AddressA64 luauConstant(int ki) +{ + return mem(rConstants, ki * sizeof(TValue)); +} + +inline AddressA64 luauConstantTag(int ki) +{ + return mem(rConstants, ki * sizeof(TValue) + offsetof(TValue, tt)); +} + +inline AddressA64 luauConstantValue(int ki) +{ + return mem(rConstants, ki * sizeof(TValue) + offsetof(TValue, value)); +} + +void emitExit(AssemblyBuilderA64& build, bool continueInVm); +void emitUpdateBase(AssemblyBuilderA64& build); +void emitSetSavedPc(AssemblyBuilderA64& build, int pcpos); // invalidates x0/x1 +void emitInterrupt(AssemblyBuilderA64& build, int pcpos); + +} // namespace A64 +} // namespace CodeGen +} // namespace Luau diff --git a/CodeGen/src/EmitInstructionA64.cpp b/CodeGen/src/EmitInstructionA64.cpp new file mode 100644 index 000000000..8289ee2ee --- /dev/null +++ b/CodeGen/src/EmitInstructionA64.cpp @@ -0,0 +1,59 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "EmitInstructionA64.h" + +#include "Luau/AssemblyBuilderA64.h" + +#include "EmitCommonA64.h" +#include "NativeState.h" +#include "CustomExecUtils.h" + +namespace Luau +{ +namespace CodeGen +{ +namespace A64 +{ + +void emitInstReturn(AssemblyBuilderA64& build, ModuleHelpers& helpers, int ra, int n) +{ + // callFallback(L, ra, n) + build.mov(x0, rState); + build.add(x1, rBase, uint16_t(ra * sizeof(TValue))); + build.mov(w2, n); + build.ldr(x3, mem(rNativeContext, offsetof(NativeContext, returnFallback))); + build.blr(x3); + + emitUpdateBase(build); + + // If the fallback requested an exit, we need to do this right away + build.cbz(x0, helpers.exitNoContinueVm); + + // Need to update state of the current function before we jump away + build.ldr(x1, mem(rState, offsetof(lua_State, ci))); // L->ci + build.ldr(x1, mem(x1, offsetof(CallInfo, func))); // L->ci->func + build.ldr(rClosure, mem(x1, offsetof(TValue, value.gc))); // L->ci->func->value.gc aka cl + + build.ldr(x1, mem(rClosure, offsetof(Closure, l.p))); // cl->l.p aka proto + + build.ldr(rConstants, mem(x1, offsetof(Proto, k))); // proto->k + build.ldr(rCode, mem(x1, offsetof(Proto, code))); // proto->code + + // Get instruction index from instruction pointer + // To get instruction index from instruction pointer, we need to divide byte offset by 4 + // But we will actually need to scale instruction index by 8 back to byte offset later so it cancels out + build.sub(x2, x0, rCode); + build.add(x2, x2, x2); // TODO: this would not be necessary if we supported shifted register offsets in loads + + // We need to check if the new function can be executed natively + build.ldr(x1, mem(x1, offsetofProtoExecData)); + build.cbz(x1, helpers.exitContinueVm); + + // Get new instruction location and jump to it + build.ldr(x1, mem(x1, offsetof(NativeProto, instTargets))); + build.ldr(x1, mem(x1, x2)); + build.br(x1); +} + +} // namespace A64 +} // namespace CodeGen +} // namespace Luau diff --git a/CodeGen/src/EmitInstructionA64.h b/CodeGen/src/EmitInstructionA64.h new file mode 100644 index 000000000..7f15d819b --- /dev/null +++ b/CodeGen/src/EmitInstructionA64.h @@ -0,0 +1,20 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +namespace Luau +{ +namespace CodeGen +{ + +struct ModuleHelpers; + +namespace A64 +{ + +class AssemblyBuilderA64; + +void emitInstReturn(AssemblyBuilderA64& build, ModuleHelpers& helpers, int ra, int n); + +} // namespace A64 +} // namespace CodeGen +} // namespace Luau diff --git a/CodeGen/src/EmitInstructionX64.cpp b/CodeGen/src/EmitInstructionX64.cpp index e8f61ebb0..649498f55 100644 --- a/CodeGen/src/EmitInstructionX64.cpp +++ b/CodeGen/src/EmitInstructionX64.cpp @@ -4,12 +4,7 @@ #include "Luau/AssemblyBuilderX64.h" #include "CustomExecUtils.h" -#include "EmitBuiltinsX64.h" #include "EmitCommonX64.h" -#include "NativeState.h" - -#include "lobject.h" -#include "ltm.h" namespace Luau { @@ -18,16 +13,8 @@ namespace CodeGen namespace X64 { -void emitInstCall(AssemblyBuilderX64& build, ModuleHelpers& helpers, const Instruction* pc, int pcpos) +void emitInstCall(AssemblyBuilderX64& build, ModuleHelpers& helpers, int ra, int nparams, int nresults) { - int ra = LUAU_INSN_A(*pc); - int nparams = LUAU_INSN_B(*pc) - 1; - int nresults = LUAU_INSN_C(*pc) - 1; - - emitInterrupt(build, pcpos); - - emitSetSavedPc(build, pcpos + 1); - build.mov(rArg1, rState); build.lea(rArg2, luauRegAddress(ra)); @@ -171,13 +158,8 @@ void emitInstCall(AssemblyBuilderX64& build, ModuleHelpers& helpers, const Instr } } -void emitInstReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers, const Instruction* pc, int pcpos) +void emitInstReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers, int ra, int actualResults) { - emitInterrupt(build, pcpos); - - int ra = LUAU_INSN_A(*pc); - int b = LUAU_INSN_B(*pc) - 1; - RegisterX64 ci = r8; RegisterX64 cip = r9; RegisterX64 res = rdi; @@ -196,7 +178,7 @@ void emitInstReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers, const Ins RegisterX64 counter = ecx; - if (b == 0) + if (actualResults == 0) { // Our instruction doesn't have any results, so just fill results expected in parent with 'nil' build.test(nresults, nresults); // test here will set SF=1 for a negative number, ZF=1 for zero and OF=0 @@ -210,7 +192,7 @@ void emitInstReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers, const Ins build.dec(counter); build.jcc(ConditionX64::NotZero, repeatNilLoop); } - else if (b == 1) + else if (actualResults == 1) { // Try setting our 1 result build.test(nresults, nresults); @@ -245,10 +227,10 @@ void emitInstReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers, const Ins build.lea(vali, luauRegAddress(ra)); // Copy as much as possible for MULTRET calls, and only as much as needed otherwise - if (b == LUA_MULTRET) + if (actualResults == LUA_MULTRET) build.mov(valend, qword[rState + offsetof(lua_State, top)]); // valend = L->top else - build.lea(valend, luauRegAddress(ra + b)); // valend = ra + b + build.lea(valend, luauRegAddress(ra + actualResults)); // valend = ra + actualResults build.mov(counter, nresults); @@ -333,24 +315,19 @@ void emitInstReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers, const Ins build.jmp(qword[rdx + rax * 2]); } -void emitInstSetList(AssemblyBuilderX64& build, const Instruction* pc, Label& next) +void emitInstSetList(AssemblyBuilderX64& build, Label& next, int ra, int rb, int count, uint32_t index) { - int ra = LUAU_INSN_A(*pc); - int rb = LUAU_INSN_B(*pc); - int c = LUAU_INSN_C(*pc) - 1; - uint32_t index = pc[1]; - - OperandX64 last = index + c - 1; + OperandX64 last = index + count - 1; - // Using non-volatile 'rbx' for dynamic 'c' value (for LUA_MULTRET) to skip later recomputation - // We also keep 'c' scaled by sizeof(TValue) here as it helps in the loop below + // Using non-volatile 'rbx' for dynamic 'count' value (for LUA_MULTRET) to skip later recomputation + // We also keep 'count' scaled by sizeof(TValue) here as it helps in the loop below RegisterX64 cscaled = rbx; - if (c == LUA_MULTRET) + if (count == LUA_MULTRET) { RegisterX64 tmp = rax; - // c = L->top - rb + // count = L->top - rb build.mov(cscaled, qword[rState + offsetof(lua_State, top)]); build.lea(tmp, luauRegAddress(rb)); build.sub(cscaled, tmp); // Using byte difference @@ -360,7 +337,7 @@ void emitInstSetList(AssemblyBuilderX64& build, const Instruction* pc, Label& ne build.mov(tmp, qword[tmp + offsetof(CallInfo, top)]); build.mov(qword[rState + offsetof(lua_State, top)], tmp); - // last = index + c - 1; + // last = index + count - 1; last = edx; build.mov(last, dwordReg(cscaled)); build.shr(last, kTValueSizeLog2); @@ -394,9 +371,9 @@ void emitInstSetList(AssemblyBuilderX64& build, const Instruction* pc, Label& ne const int kUnrollSetListLimit = 4; - if (c != LUA_MULTRET && c <= kUnrollSetListLimit) + if (count != LUA_MULTRET && count <= kUnrollSetListLimit) { - for (int i = 0; i < c; ++i) + for (int i = 0; i < count; ++i) { // setobj2t(L, &array[index + i - 1], rb + i); build.vmovups(xmm0, luauRegValue(rb + i)); @@ -405,17 +382,17 @@ void emitInstSetList(AssemblyBuilderX64& build, const Instruction* pc, Label& ne } else { - LUAU_ASSERT(c != 0); + LUAU_ASSERT(count != 0); build.xor_(offset, offset); if (index != 1) build.add(arrayDst, (index - 1) * sizeof(TValue)); Label repeatLoop, endLoop; - OperandX64 limit = c == LUA_MULTRET ? cscaled : OperandX64(c * sizeof(TValue)); + OperandX64 limit = count == LUA_MULTRET ? cscaled : OperandX64(count * sizeof(TValue)); // If c is static, we will always do at least one iteration - if (c == LUA_MULTRET) + if (count == LUA_MULTRET) { build.cmp(offset, limit); build.jcc(ConditionX64::NotBelow, endLoop); @@ -556,14 +533,14 @@ static void emitInstAndX(AssemblyBuilderX64& build, int ra, int rb, OperandX64 c } } -void emitInstAnd(AssemblyBuilderX64& build, const Instruction* pc) +void emitInstAnd(AssemblyBuilderX64& build, int ra, int rb, int rc) { - emitInstAndX(build, LUAU_INSN_A(*pc), LUAU_INSN_B(*pc), luauReg(LUAU_INSN_C(*pc))); + emitInstAndX(build, ra, rb, luauReg(rc)); } -void emitInstAndK(AssemblyBuilderX64& build, const Instruction* pc) +void emitInstAndK(AssemblyBuilderX64& build, int ra, int rb, int kc) { - emitInstAndX(build, LUAU_INSN_A(*pc), LUAU_INSN_B(*pc), luauConstant(LUAU_INSN_C(*pc))); + emitInstAndX(build, ra, rb, luauConstant(kc)); } static void emitInstOrX(AssemblyBuilderX64& build, int ra, int rb, OperandX64 c) @@ -594,14 +571,14 @@ static void emitInstOrX(AssemblyBuilderX64& build, int ra, int rb, OperandX64 c) } } -void emitInstOr(AssemblyBuilderX64& build, const Instruction* pc) +void emitInstOr(AssemblyBuilderX64& build, int ra, int rb, int rc) { - emitInstOrX(build, LUAU_INSN_A(*pc), LUAU_INSN_B(*pc), luauReg(LUAU_INSN_C(*pc))); + emitInstOrX(build, ra, rb, luauReg(rc)); } -void emitInstOrK(AssemblyBuilderX64& build, const Instruction* pc) +void emitInstOrK(AssemblyBuilderX64& build, int ra, int rb, int kc) { - emitInstOrX(build, LUAU_INSN_A(*pc), LUAU_INSN_B(*pc), luauConstant(LUAU_INSN_C(*pc))); + emitInstOrX(build, ra, rb, luauConstant(kc)); } void emitInstGetImportFallback(AssemblyBuilderX64& build, int ra, uint32_t aux) diff --git a/CodeGen/src/EmitInstructionX64.h b/CodeGen/src/EmitInstructionX64.h index 6a8a3c0ee..880c9fa4f 100644 --- a/CodeGen/src/EmitInstructionX64.h +++ b/CodeGen/src/EmitInstructionX64.h @@ -3,11 +3,6 @@ #include -#include "ltm.h" - -typedef uint32_t Instruction; -typedef struct lua_TValue TValue; - namespace Luau { namespace CodeGen @@ -21,16 +16,16 @@ namespace X64 class AssemblyBuilderX64; -void emitInstCall(AssemblyBuilderX64& build, ModuleHelpers& helpers, const Instruction* pc, int pcpos); -void emitInstReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers, const Instruction* pc, int pcpos); -void emitInstSetList(AssemblyBuilderX64& build, const Instruction* pc, Label& next); +void emitInstCall(AssemblyBuilderX64& build, ModuleHelpers& helpers, int ra, int nparams, int nresults); +void emitInstReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers, int ra, int actualResults); +void emitInstSetList(AssemblyBuilderX64& build, Label& next, int ra, int rb, int count, uint32_t index); void emitinstForGLoop(AssemblyBuilderX64& build, int ra, int aux, Label& loopRepeat, Label& loopExit); void emitinstForGLoopFallback(AssemblyBuilderX64& build, int pcpos, int ra, int aux, Label& loopRepeat); void emitInstForGPrepXnextFallback(AssemblyBuilderX64& build, int pcpos, int ra, Label& target); -void emitInstAnd(AssemblyBuilderX64& build, const Instruction* pc); -void emitInstAndK(AssemblyBuilderX64& build, const Instruction* pc); -void emitInstOr(AssemblyBuilderX64& build, const Instruction* pc); -void emitInstOrK(AssemblyBuilderX64& build, const Instruction* pc); +void emitInstAnd(AssemblyBuilderX64& build, int ra, int rb, int rc); +void emitInstAndK(AssemblyBuilderX64& build, int ra, int rb, int kc); +void emitInstOr(AssemblyBuilderX64& build, int ra, int rb, int rc); +void emitInstOrK(AssemblyBuilderX64& build, int ra, int rb, int kc); void emitInstGetImportFallback(AssemblyBuilderX64& build, int ra, uint32_t aux); void emitInstCoverage(AssemblyBuilderX64& build, int pcpos); diff --git a/CodeGen/src/IrAnalysis.cpp b/CodeGen/src/IrAnalysis.cpp index b998487f9..6e77dfe44 100644 --- a/CodeGen/src/IrAnalysis.cpp +++ b/CodeGen/src/IrAnalysis.cpp @@ -244,12 +244,16 @@ static RegisterSet computeBlockLiveInRegSet(IrFunction& function, const IrBlock& // A <- B, C case IrCmd::DO_ARITH: case IrCmd::GET_TABLE: - case IrCmd::SET_TABLE: use(inst.b); maybeUse(inst.c); // Argument can also be a VmConst def(inst.a); break; + case IrCmd::SET_TABLE: + use(inst.a); + use(inst.b); + maybeUse(inst.c); // Argument can also be a VmConst + break; // A <- B case IrCmd::DO_LEN: use(inst.b); @@ -301,13 +305,13 @@ static RegisterSet computeBlockLiveInRegSet(IrFunction& function, const IrBlock& useRange(inst.c.index, function.intOp(inst.d)); break; case IrCmd::LOP_CALL: - use(inst.b); - useRange(inst.b.index + 1, function.intOp(inst.c)); + use(inst.a); + useRange(inst.a.index + 1, function.intOp(inst.b)); - defRange(inst.b.index, function.intOp(inst.d)); + defRange(inst.a.index, function.intOp(inst.c)); break; case IrCmd::LOP_RETURN: - useRange(inst.b.index, function.intOp(inst.c)); + useRange(inst.a.index, function.intOp(inst.b)); break; case IrCmd::FASTCALL: case IrCmd::INVOKE_FASTCALL: @@ -333,7 +337,9 @@ static RegisterSet computeBlockLiveInRegSet(IrFunction& function, const IrBlock& useVarargs(inst.c.index); } - defRange(inst.b.index, function.intOp(inst.f)); + // Multiple return sequences (count == -1) are defined by ADJUST_STACK_TO_REG + if (int count = function.intOp(inst.f); count != -1) + defRange(inst.b.index, count); break; case IrCmd::LOP_FORGLOOP: // First register is not used by instruction, we check that it's still 'nil' with CHECK_TAG @@ -352,20 +358,20 @@ static RegisterSet computeBlockLiveInRegSet(IrFunction& function, const IrBlock& case IrCmd::LOP_FORGPREP_XNEXT_FALLBACK: use(inst.b); break; - // B <- C, D + // A <- B, C case IrCmd::LOP_AND: case IrCmd::LOP_OR: + use(inst.b); use(inst.c); - use(inst.d); - def(inst.b); + def(inst.a); break; - // B <- C + // A <- B case IrCmd::LOP_ANDK: case IrCmd::LOP_ORK: - use(inst.c); + use(inst.b); - def(inst.b); + def(inst.a); break; case IrCmd::FALLBACK_GETGLOBAL: def(inst.b); @@ -405,8 +411,10 @@ static RegisterSet computeBlockLiveInRegSet(IrFunction& function, const IrBlock& defRange(inst.b.index, 3); break; case IrCmd::ADJUST_STACK_TO_REG: + defRange(inst.a.index, -1); + break; case IrCmd::ADJUST_STACK_TO_TOP: - // While these can be considered as vararg producers and consumers, it is already handled in fastcall instruction + // While this can be considered to be a vararg consumer, it is already handled in fastcall instructions break; default: @@ -626,7 +634,7 @@ void computeCfgInfo(IrFunction& function) computeCfgLiveInOutRegSets(function); } -BlockIteratorWrapper predecessors(CfgInfo& cfg, uint32_t blockIdx) +BlockIteratorWrapper predecessors(const CfgInfo& cfg, uint32_t blockIdx) { LUAU_ASSERT(blockIdx < cfg.predecessorsOffsets.size()); @@ -636,7 +644,7 @@ BlockIteratorWrapper predecessors(CfgInfo& cfg, uint32_t blockIdx) return BlockIteratorWrapper{cfg.predecessors.data() + start, cfg.predecessors.data() + end}; } -BlockIteratorWrapper successors(CfgInfo& cfg, uint32_t blockIdx) +BlockIteratorWrapper successors(const CfgInfo& cfg, uint32_t blockIdx) { LUAU_ASSERT(blockIdx < cfg.successorsOffsets.size()); diff --git a/CodeGen/src/IrBuilder.cpp b/CodeGen/src/IrBuilder.cpp index f1099cfac..239f7a8e6 100644 --- a/CodeGen/src/IrBuilder.cpp +++ b/CodeGen/src/IrBuilder.cpp @@ -132,7 +132,10 @@ void IrBuilder::translateInst(LuauOpcode op, const Instruction* pc, int i) translateInstSetGlobal(*this, pc, i); break; case LOP_CALL: - inst(IrCmd::LOP_CALL, constUint(i), vmReg(LUAU_INSN_A(*pc)), constInt(LUAU_INSN_B(*pc) - 1), constInt(LUAU_INSN_C(*pc) - 1)); + inst(IrCmd::INTERRUPT, constUint(i)); + inst(IrCmd::SET_SAVEDPC, constUint(i + 1)); + + inst(IrCmd::LOP_CALL, vmReg(LUAU_INSN_A(*pc)), constInt(LUAU_INSN_B(*pc) - 1), constInt(LUAU_INSN_C(*pc) - 1)); if (activeFastcallFallback) { @@ -144,7 +147,9 @@ void IrBuilder::translateInst(LuauOpcode op, const Instruction* pc, int i) } break; case LOP_RETURN: - inst(IrCmd::LOP_RETURN, constUint(i), vmReg(LUAU_INSN_A(*pc)), constInt(LUAU_INSN_B(*pc) - 1)); + inst(IrCmd::INTERRUPT, constUint(i)); + + inst(IrCmd::LOP_RETURN, vmReg(LUAU_INSN_A(*pc)), constInt(LUAU_INSN_B(*pc) - 1)); break; case LOP_GETTABLE: translateInstGetTable(*this, pc, i); @@ -358,16 +363,16 @@ void IrBuilder::translateInst(LuauOpcode op, const Instruction* pc, int i) translateInstForGPrepInext(*this, pc, i); break; case LOP_AND: - inst(IrCmd::LOP_AND, constUint(i), vmReg(LUAU_INSN_A(*pc)), vmReg(LUAU_INSN_B(*pc)), vmReg(LUAU_INSN_C(*pc))); + inst(IrCmd::LOP_AND, vmReg(LUAU_INSN_A(*pc)), vmReg(LUAU_INSN_B(*pc)), vmReg(LUAU_INSN_C(*pc))); break; case LOP_ANDK: - inst(IrCmd::LOP_ANDK, constUint(i), vmReg(LUAU_INSN_A(*pc)), vmReg(LUAU_INSN_B(*pc)), vmConst(LUAU_INSN_C(*pc))); + inst(IrCmd::LOP_ANDK, vmReg(LUAU_INSN_A(*pc)), vmReg(LUAU_INSN_B(*pc)), vmConst(LUAU_INSN_C(*pc))); break; case LOP_OR: - inst(IrCmd::LOP_OR, constUint(i), vmReg(LUAU_INSN_A(*pc)), vmReg(LUAU_INSN_B(*pc)), vmReg(LUAU_INSN_C(*pc))); + inst(IrCmd::LOP_OR, vmReg(LUAU_INSN_A(*pc)), vmReg(LUAU_INSN_B(*pc)), vmReg(LUAU_INSN_C(*pc))); break; case LOP_ORK: - inst(IrCmd::LOP_ORK, constUint(i), vmReg(LUAU_INSN_A(*pc)), vmReg(LUAU_INSN_B(*pc)), vmConst(LUAU_INSN_C(*pc))); + inst(IrCmd::LOP_ORK, vmReg(LUAU_INSN_A(*pc)), vmReg(LUAU_INSN_B(*pc)), vmConst(LUAU_INSN_C(*pc))); break; case LOP_COVERAGE: inst(IrCmd::LOP_COVERAGE, constUint(i)); diff --git a/CodeGen/src/IrDump.cpp b/CodeGen/src/IrDump.cpp index 3c4e420d8..53654d6a2 100644 --- a/CodeGen/src/IrDump.cpp +++ b/CodeGen/src/IrDump.cpp @@ -455,7 +455,7 @@ void toStringDetailed(IrToStringContext& ctx, const IrBlock& block, uint32_t ind } // Predecessor list - if (!ctx.cfg.predecessors.empty()) + if (index < ctx.cfg.predecessorsOffsets.size()) { BlockIteratorWrapper pred = predecessors(ctx.cfg, index); @@ -469,7 +469,7 @@ void toStringDetailed(IrToStringContext& ctx, const IrBlock& block, uint32_t ind } // Successor list - if (!ctx.cfg.successors.empty()) + if (index < ctx.cfg.successorsOffsets.size()) { BlockIteratorWrapper succ = successors(ctx.cfg, index); @@ -509,14 +509,14 @@ void toStringDetailed(IrToStringContext& ctx, const IrBlock& block, uint32_t ind } } -std::string toString(IrFunction& function, bool includeUseInfo) +std::string toString(const IrFunction& function, bool includeUseInfo) { std::string result; IrToStringContext ctx{result, function.blocks, function.constants, function.cfg}; for (size_t i = 0; i < function.blocks.size(); i++) { - IrBlock& block = function.blocks[i]; + const IrBlock& block = function.blocks[i]; if (block.kind == IrBlockKind::Dead) continue; @@ -532,7 +532,7 @@ std::string toString(IrFunction& function, bool includeUseInfo) // To allow dumping blocks that are still being constructed, we can't rely on terminator and need a bounds check for (uint32_t index = block.start; index <= block.finish && index < uint32_t(function.instructions.size()); index++) { - IrInst& inst = function.instructions[index]; + const IrInst& inst = function.instructions[index]; // Skip pseudo instructions unless they are still referenced if (isPseudo(inst.cmd) && inst.useCount == 0) @@ -548,7 +548,7 @@ std::string toString(IrFunction& function, bool includeUseInfo) return result; } -std::string dump(IrFunction& function) +std::string dump(const IrFunction& function) { std::string result = toString(function, /* includeUseInfo */ true); @@ -557,12 +557,12 @@ std::string dump(IrFunction& function) return result; } -std::string toDot(IrFunction& function, bool includeInst) +std::string toDot(const IrFunction& function, bool includeInst) { std::string result; IrToStringContext ctx{result, function.blocks, function.constants, function.cfg}; - auto appendLabelRegset = [&ctx](std::vector& regSets, size_t blockIdx, const char* name) { + auto appendLabelRegset = [&ctx](const std::vector& regSets, size_t blockIdx, const char* name) { if (blockIdx < regSets.size()) { const RegisterSet& rs = regSets[blockIdx]; @@ -581,7 +581,7 @@ std::string toDot(IrFunction& function, bool includeInst) for (size_t i = 0; i < function.blocks.size(); i++) { - IrBlock& block = function.blocks[i]; + const IrBlock& block = function.blocks[i]; append(ctx.result, "b%u [", unsigned(i)); @@ -599,7 +599,7 @@ std::string toDot(IrFunction& function, bool includeInst) { for (uint32_t instIdx = block.start; instIdx <= block.finish; instIdx++) { - IrInst& inst = function.instructions[instIdx]; + const IrInst& inst = function.instructions[instIdx]; // Skip pseudo instructions unless they are still referenced if (isPseudo(inst.cmd) && inst.useCount == 0) @@ -618,14 +618,14 @@ std::string toDot(IrFunction& function, bool includeInst) for (size_t i = 0; i < function.blocks.size(); i++) { - IrBlock& block = function.blocks[i]; + const IrBlock& block = function.blocks[i]; if (block.start == ~0u) continue; for (uint32_t instIdx = block.start; instIdx != ~0u && instIdx <= block.finish; instIdx++) { - IrInst& inst = function.instructions[instIdx]; + const IrInst& inst = function.instructions[instIdx]; auto checkOp = [&](IrOp op) { if (op.kind == IrOpKind::Block) @@ -651,7 +651,7 @@ std::string toDot(IrFunction& function, bool includeInst) return result; } -std::string dumpDot(IrFunction& function, bool includeInst) +std::string dumpDot(const IrFunction& function, bool includeInst) { std::string result = toDot(function, includeInst); diff --git a/CodeGen/src/IrLoweringA64.cpp b/CodeGen/src/IrLoweringA64.cpp new file mode 100644 index 000000000..ae4bc017d --- /dev/null +++ b/CodeGen/src/IrLoweringA64.cpp @@ -0,0 +1,137 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "IrLoweringA64.h" + +#include "Luau/CodeGen.h" +#include "Luau/DenseHash.h" +#include "Luau/IrAnalysis.h" +#include "Luau/IrDump.h" +#include "Luau/IrUtils.h" + +#include "EmitCommonA64.h" +#include "EmitInstructionA64.h" +#include "NativeState.h" + +#include "lstate.h" + +namespace Luau +{ +namespace CodeGen +{ +namespace A64 +{ + +IrLoweringA64::IrLoweringA64(AssemblyBuilderA64& build, ModuleHelpers& helpers, NativeState& data, Proto* proto, IrFunction& function) + : build(build) + , helpers(helpers) + , data(data) + , proto(proto) + , function(function) +{ + // In order to allocate registers during lowering, we need to know where instruction results are last used + updateLastUseLocations(function); +} + +// TODO: Eventually this can go away +bool IrLoweringA64::canLower(const IrFunction& function) +{ + for (const IrInst& inst : function.instructions) + { + switch (inst.cmd) + { + case IrCmd::NOP: + case IrCmd::SUBSTITUTE: + case IrCmd::INTERRUPT: + case IrCmd::LOP_RETURN: + continue; + default: + return false; + } + } + + return true; +} + +void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) +{ + switch (inst.cmd) + { + case IrCmd::INTERRUPT: + { + emitInterrupt(build, uintOp(inst.a)); + break; + } + case IrCmd::LOP_RETURN: + { + emitInstReturn(build, helpers, vmRegOp(inst.a), intOp(inst.b)); + break; + } + default: + LUAU_ASSERT(!"Not supported yet"); + break; + } + + // TODO + // regs.freeLastUseRegs(inst, index); +} + +bool IrLoweringA64::isFallthroughBlock(IrBlock target, IrBlock next) +{ + return target.start == next.start; +} + +void IrLoweringA64::jumpOrFallthrough(IrBlock& target, IrBlock& next) +{ + if (!isFallthroughBlock(target, next)) + build.b(target.label); +} + +RegisterA64 IrLoweringA64::regOp(IrOp op) const +{ + IrInst& inst = function.instOp(op); + LUAU_ASSERT(inst.regA64 != noreg); + return inst.regA64; +} + +IrConst IrLoweringA64::constOp(IrOp op) const +{ + return function.constOp(op); +} + +uint8_t IrLoweringA64::tagOp(IrOp op) const +{ + return function.tagOp(op); +} + +bool IrLoweringA64::boolOp(IrOp op) const +{ + return function.boolOp(op); +} + +int IrLoweringA64::intOp(IrOp op) const +{ + return function.intOp(op); +} + +unsigned IrLoweringA64::uintOp(IrOp op) const +{ + return function.uintOp(op); +} + +double IrLoweringA64::doubleOp(IrOp op) const +{ + return function.doubleOp(op); +} + +IrBlock& IrLoweringA64::blockOp(IrOp op) const +{ + return function.blockOp(op); +} + +Label& IrLoweringA64::labelOp(IrOp op) const +{ + return blockOp(op).label; +} + +} // namespace A64 +} // namespace CodeGen +} // namespace Luau diff --git a/CodeGen/src/IrLoweringA64.h b/CodeGen/src/IrLoweringA64.h new file mode 100644 index 000000000..aa9eba422 --- /dev/null +++ b/CodeGen/src/IrLoweringA64.h @@ -0,0 +1,60 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/AssemblyBuilderA64.h" +#include "Luau/IrData.h" + +#include + +struct Proto; + +namespace Luau +{ +namespace CodeGen +{ + +struct ModuleHelpers; +struct NativeState; +struct AssemblyOptions; + +namespace A64 +{ + +struct IrLoweringA64 +{ + IrLoweringA64(AssemblyBuilderA64& build, ModuleHelpers& helpers, NativeState& data, Proto* proto, IrFunction& function); + + static bool canLower(const IrFunction& function); + + void lowerInst(IrInst& inst, uint32_t index, IrBlock& next); + + bool isFallthroughBlock(IrBlock target, IrBlock next); + void jumpOrFallthrough(IrBlock& target, IrBlock& next); + + // Operand data lookup helpers + RegisterA64 regOp(IrOp op) const; + + IrConst constOp(IrOp op) const; + uint8_t tagOp(IrOp op) const; + bool boolOp(IrOp op) const; + int intOp(IrOp op) const; + unsigned uintOp(IrOp op) const; + double doubleOp(IrOp op) const; + + IrBlock& blockOp(IrOp op) const; + Label& labelOp(IrOp op) const; + + AssemblyBuilderA64& build; + ModuleHelpers& helpers; + NativeState& data; + Proto* proto = nullptr; // Temporarily required to provide 'Instruction* pc' to old emitInst* methods + + IrFunction& function; + + // TODO: + // IrRegAllocA64 regs; +}; + +} // namespace A64 +} // namespace CodeGen +} // namespace Luau diff --git a/CodeGen/src/IrLoweringX64.cpp b/CodeGen/src/IrLoweringX64.cpp index b45ce2261..1cc56fe31 100644 --- a/CodeGen/src/IrLoweringX64.cpp +++ b/CodeGen/src/IrLoweringX64.cpp @@ -14,8 +14,6 @@ #include "lstate.h" -#include - namespace Luau { namespace CodeGen @@ -23,11 +21,10 @@ namespace CodeGen namespace X64 { -IrLoweringX64::IrLoweringX64(AssemblyBuilderX64& build, ModuleHelpers& helpers, NativeState& data, Proto* proto, IrFunction& function) +IrLoweringX64::IrLoweringX64(AssemblyBuilderX64& build, ModuleHelpers& helpers, NativeState& data, IrFunction& function) : build(build) , helpers(helpers) , data(data) - , proto(proto) , function(function) , regs(function) { @@ -35,146 +32,6 @@ IrLoweringX64::IrLoweringX64(AssemblyBuilderX64& build, ModuleHelpers& helpers, updateLastUseLocations(function); } -void IrLoweringX64::lower(AssemblyOptions options) -{ - // While we will need a better block ordering in the future, right now we want to mostly preserve build order with fallbacks outlined - std::vector sortedBlocks; - sortedBlocks.reserve(function.blocks.size()); - for (uint32_t i = 0; i < function.blocks.size(); i++) - sortedBlocks.push_back(i); - - std::sort(sortedBlocks.begin(), sortedBlocks.end(), [&](uint32_t idxA, uint32_t idxB) { - const IrBlock& a = function.blocks[idxA]; - const IrBlock& b = function.blocks[idxB]; - - // Place fallback blocks at the end - if ((a.kind == IrBlockKind::Fallback) != (b.kind == IrBlockKind::Fallback)) - return (a.kind == IrBlockKind::Fallback) < (b.kind == IrBlockKind::Fallback); - - // Try to order by instruction order - return a.start < b.start; - }); - - DenseHashMap bcLocations{~0u}; - - // Create keys for IR assembly locations that original bytecode instruction are interested in - for (const auto& [irLocation, asmLocation] : function.bcMapping) - { - if (irLocation != ~0u) - bcLocations[irLocation] = 0; - } - - DenseHashMap indexIrToBc{~0u}; - bool outputEnabled = options.includeAssembly || options.includeIr; - - if (outputEnabled && options.annotator) - { - // Create reverse mapping from IR location to bytecode location - for (size_t i = 0; i < function.bcMapping.size(); ++i) - { - uint32_t irLocation = function.bcMapping[i].irLocation; - - if (irLocation != ~0u) - indexIrToBc[irLocation] = uint32_t(i); - } - } - - IrToStringContext ctx{build.text, function.blocks, function.constants, function.cfg}; - - // We use this to skip outlined fallback blocks from IR/asm text output - size_t textSize = build.text.length(); - uint32_t codeSize = build.getCodeSize(); - bool seenFallback = false; - - IrBlock dummy; - dummy.start = ~0u; - - for (size_t i = 0; i < sortedBlocks.size(); ++i) - { - uint32_t blockIndex = sortedBlocks[i]; - - IrBlock& block = function.blocks[blockIndex]; - - if (block.kind == IrBlockKind::Dead) - continue; - - LUAU_ASSERT(block.start != ~0u); - LUAU_ASSERT(block.finish != ~0u); - - // If we want to skip fallback code IR/asm, we'll record when those blocks start once we see them - if (block.kind == IrBlockKind::Fallback && !seenFallback) - { - textSize = build.text.length(); - codeSize = build.getCodeSize(); - seenFallback = true; - } - - if (options.includeIr) - { - build.logAppend("# "); - toStringDetailed(ctx, block, blockIndex, /* includeUseInfo */ true); - } - - build.setLabel(block.label); - - for (uint32_t index = block.start; index <= block.finish; index++) - { - LUAU_ASSERT(index < function.instructions.size()); - - // If IR instruction is the first one for the original bytecode, we can annotate it with source code text - if (outputEnabled && options.annotator) - { - if (uint32_t* bcIndex = indexIrToBc.find(index)) - options.annotator(options.annotatorContext, build.text, proto->bytecodeid, *bcIndex); - } - - // If bytecode needs the location of this instruction for jumps, record it - if (uint32_t* bcLocation = bcLocations.find(index)) - *bcLocation = build.getCodeSize(); - - IrInst& inst = function.instructions[index]; - - // Skip pseudo instructions, but make sure they are not used at this stage - // This also prevents them from getting into text output when that's enabled - if (isPseudo(inst.cmd)) - { - LUAU_ASSERT(inst.useCount == 0); - continue; - } - - if (options.includeIr) - { - build.logAppend("# "); - toStringDetailed(ctx, inst, index, /* includeUseInfo */ true); - } - - IrBlock& next = i + 1 < sortedBlocks.size() ? function.blocks[sortedBlocks[i + 1]] : dummy; - - lowerInst(inst, index, next); - - regs.freeLastUseRegs(inst, index); - } - - if (options.includeIr) - build.logAppend("#\n"); - } - - if (outputEnabled && !options.includeOutlinedCode && seenFallback) - { - build.text.resize(textSize); - - if (options.includeAssembly) - build.logAppend("; skipping %u bytes of outlined code\n", build.getCodeSize() - codeSize); - } - - // Copy assembly locations of IR instructions that are mapped to bytecode instructions - for (auto& [irLocation, asmLocation] : function.bcMapping) - { - if (irLocation != ~0u) - asmLocation = bcLocations[irLocation]; - } -} - void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) { switch (inst.cmd) @@ -183,9 +40,9 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) inst.regX64 = regs.allocGprReg(SizeX64::dword); if (inst.a.kind == IrOpKind::VmReg) - build.mov(inst.regX64, luauRegTag(inst.a.index)); + build.mov(inst.regX64, luauRegTag(vmRegOp(inst.a))); else if (inst.a.kind == IrOpKind::VmConst) - build.mov(inst.regX64, luauConstantTag(inst.a.index)); + build.mov(inst.regX64, luauConstantTag(vmConstOp(inst.a))); // If we have a register, we assume it's a pointer to TValue // We might introduce explicit operand types in the future to make this more robust else if (inst.a.kind == IrOpKind::Inst) @@ -197,9 +54,9 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) inst.regX64 = regs.allocGprReg(SizeX64::qword); if (inst.a.kind == IrOpKind::VmReg) - build.mov(inst.regX64, luauRegValue(inst.a.index)); + build.mov(inst.regX64, luauRegValue(vmRegOp(inst.a))); else if (inst.a.kind == IrOpKind::VmConst) - build.mov(inst.regX64, luauConstantValue(inst.a.index)); + build.mov(inst.regX64, luauConstantValue(vmConstOp(inst.a))); // If we have a register, we assume it's a pointer to TValue // We might introduce explicit operand types in the future to make this more robust else if (inst.a.kind == IrOpKind::Inst) @@ -211,26 +68,24 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) inst.regX64 = regs.allocXmmReg(); if (inst.a.kind == IrOpKind::VmReg) - build.vmovsd(inst.regX64, luauRegValue(inst.a.index)); + build.vmovsd(inst.regX64, luauRegValue(vmRegOp(inst.a))); else if (inst.a.kind == IrOpKind::VmConst) - build.vmovsd(inst.regX64, luauConstantValue(inst.a.index)); + build.vmovsd(inst.regX64, luauConstantValue(vmConstOp(inst.a))); else LUAU_ASSERT(!"Unsupported instruction form"); break; case IrCmd::LOAD_INT: - LUAU_ASSERT(inst.a.kind == IrOpKind::VmReg); - inst.regX64 = regs.allocGprReg(SizeX64::dword); - build.mov(inst.regX64, luauRegValueInt(inst.a.index)); + build.mov(inst.regX64, luauRegValueInt(vmRegOp(inst.a))); break; case IrCmd::LOAD_TVALUE: inst.regX64 = regs.allocXmmReg(); if (inst.a.kind == IrOpKind::VmReg) - build.vmovups(inst.regX64, luauReg(inst.a.index)); + build.vmovups(inst.regX64, luauReg(vmRegOp(inst.a))); else if (inst.a.kind == IrOpKind::VmConst) - build.vmovups(inst.regX64, luauConstant(inst.a.index)); + build.vmovups(inst.regX64, luauConstant(vmConstOp(inst.a))); else if (inst.a.kind == IrOpKind::Inst) build.vmovups(inst.regX64, xmmword[regOp(inst.a)]); else @@ -301,31 +156,25 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) break; }; case IrCmd::STORE_TAG: - LUAU_ASSERT(inst.a.kind == IrOpKind::VmReg); - if (inst.b.kind == IrOpKind::Constant) - build.mov(luauRegTag(inst.a.index), tagOp(inst.b)); + build.mov(luauRegTag(vmRegOp(inst.a)), tagOp(inst.b)); else LUAU_ASSERT(!"Unsupported instruction form"); break; case IrCmd::STORE_POINTER: - LUAU_ASSERT(inst.a.kind == IrOpKind::VmReg); - - build.mov(luauRegValue(inst.a.index), regOp(inst.b)); + build.mov(luauRegValue(vmRegOp(inst.a)), regOp(inst.b)); break; case IrCmd::STORE_DOUBLE: - LUAU_ASSERT(inst.a.kind == IrOpKind::VmReg); - if (inst.b.kind == IrOpKind::Constant) { ScopedRegX64 tmp{regs, SizeX64::xmmword}; build.vmovsd(tmp.reg, build.f64(doubleOp(inst.b))); - build.vmovsd(luauRegValue(inst.a.index), tmp.reg); + build.vmovsd(luauRegValue(vmRegOp(inst.a)), tmp.reg); } else if (inst.b.kind == IrOpKind::Inst) { - build.vmovsd(luauRegValue(inst.a.index), regOp(inst.b)); + build.vmovsd(luauRegValue(vmRegOp(inst.a)), regOp(inst.b)); } else { @@ -334,19 +183,17 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) break; case IrCmd::STORE_INT: { - LUAU_ASSERT(inst.a.kind == IrOpKind::VmReg); - if (inst.b.kind == IrOpKind::Constant) - build.mov(luauRegValueInt(inst.a.index), intOp(inst.b)); + build.mov(luauRegValueInt(vmRegOp(inst.a)), intOp(inst.b)); else if (inst.b.kind == IrOpKind::Inst) - build.mov(luauRegValueInt(inst.a.index), regOp(inst.b)); + build.mov(luauRegValueInt(vmRegOp(inst.a)), regOp(inst.b)); else LUAU_ASSERT(!"Unsupported instruction form"); break; } case IrCmd::STORE_TVALUE: if (inst.a.kind == IrOpKind::VmReg) - build.vmovups(luauReg(inst.a.index), regOp(inst.b)); + build.vmovups(luauReg(vmRegOp(inst.a)), regOp(inst.b)); else if (inst.a.kind == IrOpKind::Inst) build.vmovups(xmmword[regOp(inst.a)], regOp(inst.b)); else @@ -642,15 +489,11 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) jumpOrFallthrough(blockOp(inst.a), next); break; case IrCmd::JUMP_IF_TRUTHY: - LUAU_ASSERT(inst.a.kind == IrOpKind::VmReg); - - jumpIfTruthy(build, inst.a.index, labelOp(inst.b), labelOp(inst.c)); + jumpIfTruthy(build, vmRegOp(inst.a), labelOp(inst.b), labelOp(inst.c)); jumpOrFallthrough(blockOp(inst.c), next); break; case IrCmd::JUMP_IF_FALSY: - LUAU_ASSERT(inst.a.kind == IrOpKind::VmReg); - - jumpIfFalsy(build, inst.a.index, labelOp(inst.b), labelOp(inst.c)); + jumpIfFalsy(build, vmRegOp(inst.a), labelOp(inst.b), labelOp(inst.c)); jumpOrFallthrough(blockOp(inst.c), next); break; case IrCmd::JUMP_EQ_TAG: @@ -686,9 +529,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) break; case IrCmd::JUMP_CMP_NUM: { - LUAU_ASSERT(inst.c.kind == IrOpKind::Condition); - - IrCondition cond = IrCondition(inst.c.index); + IrCondition cond = conditionOp(inst.c); ScopedRegX64 tmp{regs, SizeX64::xmmword}; @@ -698,24 +539,14 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) break; } case IrCmd::JUMP_CMP_ANY: - { - LUAU_ASSERT(inst.a.kind == IrOpKind::VmReg); - LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); - LUAU_ASSERT(inst.c.kind == IrOpKind::Condition); - - IrCondition cond = IrCondition(inst.c.index); - - jumpOnAnyCmpFallback(build, inst.a.index, inst.b.index, cond, labelOp(inst.d)); + jumpOnAnyCmpFallback(build, vmRegOp(inst.a), vmRegOp(inst.b), conditionOp(inst.c), labelOp(inst.d)); jumpOrFallthrough(blockOp(inst.e), next); break; - } case IrCmd::JUMP_SLOT_MATCH: { - LUAU_ASSERT(inst.b.kind == IrOpKind::VmConst); - ScopedRegX64 tmp{regs, SizeX64::qword}; - jumpIfNodeKeyNotInExpectedSlot(build, tmp.reg, regOp(inst.a), luauConstantValue(inst.b.index), labelOp(inst.d)); + jumpIfNodeKeyNotInExpectedSlot(build, tmp.reg, regOp(inst.a), luauConstantValue(vmConstOp(inst.b)), labelOp(inst.d)); jumpOrFallthrough(blockOp(inst.c), next); break; } @@ -774,13 +605,11 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) break; case IrCmd::ADJUST_STACK_TO_REG: { - LUAU_ASSERT(inst.a.kind == IrOpKind::VmReg); - if (inst.b.kind == IrOpKind::Constant) { ScopedRegX64 tmp{regs, SizeX64::qword}; - build.lea(tmp.reg, addr[rBase + (inst.a.index + intOp(inst.b)) * sizeof(TValue)]); + build.lea(tmp.reg, addr[rBase + (vmRegOp(inst.a) + intOp(inst.b)) * sizeof(TValue)]); build.mov(qword[rState + offsetof(lua_State, top)], tmp.reg); } else if (inst.b.kind == IrOpKind::Inst) @@ -788,7 +617,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) ScopedRegX64 tmp(regs, regs.allocGprRegOrReuse(SizeX64::dword, index, {inst.b})); build.shl(qwordReg(tmp.reg), kTValueSizeLog2); - build.lea(qwordReg(tmp.reg), addr[rBase + qwordReg(tmp.reg) + inst.a.index * sizeof(TValue)]); + build.lea(qwordReg(tmp.reg), addr[rBase + qwordReg(tmp.reg) + vmRegOp(inst.a) * sizeof(TValue)]); build.mov(qword[rState + offsetof(lua_State, top)], qwordReg(tmp.reg)); } else @@ -807,28 +636,23 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) } case IrCmd::FASTCALL: - LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); - LUAU_ASSERT(inst.c.kind == IrOpKind::VmReg); - emitBuiltin(regs, build, uintOp(inst.a), inst.b.index, inst.c.index, inst.d, intOp(inst.e), intOp(inst.f)); + emitBuiltin(regs, build, uintOp(inst.a), vmRegOp(inst.b), vmRegOp(inst.c), inst.d, intOp(inst.e), intOp(inst.f)); break; case IrCmd::INVOKE_FASTCALL: { - LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); - LUAU_ASSERT(inst.c.kind == IrOpKind::VmReg); - unsigned bfid = uintOp(inst.a); OperandX64 args = 0; if (inst.d.kind == IrOpKind::VmReg) - args = luauRegAddress(inst.d.index); + args = luauRegAddress(vmRegOp(inst.d)); else if (inst.d.kind == IrOpKind::VmConst) - args = luauConstantAddress(inst.d.index); + args = luauConstantAddress(vmConstOp(inst.d)); else LUAU_ASSERT(boolOp(inst.d) == false); - int ra = inst.b.index; - int arg = inst.c.index; + int ra = vmRegOp(inst.b); + int arg = vmRegOp(inst.c); int nparams = intOp(inst.e); int nresults = intOp(inst.f); @@ -889,34 +713,24 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) break; } case IrCmd::DO_ARITH: - LUAU_ASSERT(inst.a.kind == IrOpKind::VmReg); - LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); - LUAU_ASSERT(inst.c.kind == IrOpKind::VmReg || inst.c.kind == IrOpKind::VmConst); - if (inst.c.kind == IrOpKind::VmReg) - callArithHelper(build, inst.a.index, inst.b.index, luauRegAddress(inst.c.index), TMS(intOp(inst.d))); + callArithHelper(build, vmRegOp(inst.a), vmRegOp(inst.b), luauRegAddress(vmRegOp(inst.c)), TMS(intOp(inst.d))); else - callArithHelper(build, inst.a.index, inst.b.index, luauConstantAddress(inst.c.index), TMS(intOp(inst.d))); + callArithHelper(build, vmRegOp(inst.a), vmRegOp(inst.b), luauConstantAddress(vmConstOp(inst.c)), TMS(intOp(inst.d))); break; case IrCmd::DO_LEN: - LUAU_ASSERT(inst.a.kind == IrOpKind::VmReg); - LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); - - callLengthHelper(build, inst.a.index, inst.b.index); + callLengthHelper(build, vmRegOp(inst.a), vmRegOp(inst.b)); break; case IrCmd::GET_TABLE: - LUAU_ASSERT(inst.a.kind == IrOpKind::VmReg); - LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); - if (inst.c.kind == IrOpKind::VmReg) { - callGetTable(build, inst.b.index, luauRegAddress(inst.c.index), inst.a.index); + callGetTable(build, vmRegOp(inst.b), luauRegAddress(vmRegOp(inst.c)), vmRegOp(inst.a)); } else if (inst.c.kind == IrOpKind::Constant) { TValue n; setnvalue(&n, uintOp(inst.c)); - callGetTable(build, inst.b.index, build.bytes(&n, sizeof(n)), inst.a.index); + callGetTable(build, vmRegOp(inst.b), build.bytes(&n, sizeof(n)), vmRegOp(inst.a)); } else { @@ -924,18 +738,15 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) } break; case IrCmd::SET_TABLE: - LUAU_ASSERT(inst.a.kind == IrOpKind::VmReg); - LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); - if (inst.c.kind == IrOpKind::VmReg) { - callSetTable(build, inst.b.index, luauRegAddress(inst.c.index), inst.a.index); + callSetTable(build, vmRegOp(inst.b), luauRegAddress(vmRegOp(inst.c)), vmRegOp(inst.a)); } else if (inst.c.kind == IrOpKind::Constant) { TValue n; setnvalue(&n, uintOp(inst.c)); - callSetTable(build, inst.b.index, build.bytes(&n, sizeof(n)), inst.a.index); + callSetTable(build, vmRegOp(inst.b), build.bytes(&n, sizeof(n)), vmRegOp(inst.a)); } else { @@ -943,30 +754,23 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) } break; case IrCmd::GET_IMPORT: - LUAU_ASSERT(inst.a.kind == IrOpKind::VmReg); - - emitInstGetImportFallback(build, inst.a.index, uintOp(inst.b)); + emitInstGetImportFallback(build, vmRegOp(inst.a), uintOp(inst.b)); break; case IrCmd::CONCAT: - LUAU_ASSERT(inst.a.kind == IrOpKind::VmReg); - build.mov(rArg1, rState); build.mov(dwordReg(rArg2), uintOp(inst.b)); - build.mov(dwordReg(rArg3), inst.a.index + uintOp(inst.b) - 1); + build.mov(dwordReg(rArg3), vmRegOp(inst.a) + uintOp(inst.b) - 1); build.call(qword[rNativeContext + offsetof(NativeContext, luaV_concat)]); emitUpdateBase(build); break; case IrCmd::GET_UPVALUE: { - LUAU_ASSERT(inst.a.kind == IrOpKind::VmReg); - LUAU_ASSERT(inst.b.kind == IrOpKind::VmUpvalue); - ScopedRegX64 tmp1{regs, SizeX64::qword}; ScopedRegX64 tmp2{regs, SizeX64::xmmword}; build.mov(tmp1.reg, sClosure); - build.add(tmp1.reg, offsetof(Closure, l.uprefs) + sizeof(TValue) * inst.b.index); + build.add(tmp1.reg, offsetof(Closure, l.uprefs) + sizeof(TValue) * vmUpvalueOp(inst.b)); // uprefs[] is either an actual value, or it points to UpVal object which has a pointer to value Label skip; @@ -981,32 +785,29 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.setLabel(skip); build.vmovups(tmp2.reg, xmmword[tmp1.reg]); - build.vmovups(luauReg(inst.a.index), tmp2.reg); + build.vmovups(luauReg(vmRegOp(inst.a)), tmp2.reg); break; } case IrCmd::SET_UPVALUE: { - LUAU_ASSERT(inst.a.kind == IrOpKind::VmUpvalue); - LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); - Label next; ScopedRegX64 tmp1{regs, SizeX64::qword}; ScopedRegX64 tmp2{regs, SizeX64::qword}; ScopedRegX64 tmp3{regs, SizeX64::xmmword}; build.mov(tmp1.reg, sClosure); - build.mov(tmp2.reg, qword[tmp1.reg + offsetof(Closure, l.uprefs) + sizeof(TValue) * inst.a.index + offsetof(TValue, value.gc)]); + build.mov(tmp2.reg, qword[tmp1.reg + offsetof(Closure, l.uprefs) + sizeof(TValue) * vmUpvalueOp(inst.a) + offsetof(TValue, value.gc)]); build.mov(tmp1.reg, qword[tmp2.reg + offsetof(UpVal, v)]); - build.vmovups(tmp3.reg, luauReg(inst.b.index)); + build.vmovups(tmp3.reg, luauReg(vmRegOp(inst.b))); build.vmovups(xmmword[tmp1.reg], tmp3.reg); - callBarrierObject(build, tmp1.reg, tmp2.reg, inst.b.index, next); + callBarrierObject(build, tmp1.reg, tmp2.reg, vmRegOp(inst.b), next); build.setLabel(next); break; } case IrCmd::PREPARE_FORN: - callPrepareForN(build, inst.a.index, inst.b.index, inst.c.index); + callPrepareForN(build, vmRegOp(inst.a), vmRegOp(inst.b), vmRegOp(inst.c)); break; case IrCmd::CHECK_TAG: if (inst.a.kind == IrOpKind::Inst) @@ -1016,11 +817,11 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) } else if (inst.a.kind == IrOpKind::VmReg) { - jumpIfTagIsNot(build, inst.a.index, lua_Type(tagOp(inst.b)), labelOp(inst.c)); + jumpIfTagIsNot(build, vmRegOp(inst.a), lua_Type(tagOp(inst.b)), labelOp(inst.c)); } else if (inst.a.kind == IrOpKind::VmConst) { - build.cmp(luauConstantTag(inst.a.index), tagOp(inst.b)); + build.cmp(luauConstantTag(vmConstOp(inst.a)), tagOp(inst.b)); build.jcc(ConditionX64::NotEqual, labelOp(inst.c)); } else @@ -1053,11 +854,9 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) break; case IrCmd::CHECK_SLOT_MATCH: { - LUAU_ASSERT(inst.b.kind == IrOpKind::VmConst); - ScopedRegX64 tmp{regs, SizeX64::qword}; - jumpIfNodeKeyNotInExpectedSlot(build, tmp.reg, regOp(inst.a), luauConstantValue(inst.b.index), labelOp(inst.c)); + jumpIfNodeKeyNotInExpectedSlot(build, tmp.reg, regOp(inst.a), luauConstantValue(vmConstOp(inst.b)), labelOp(inst.c)); break; } case IrCmd::CHECK_NODE_NO_NEXT: @@ -1075,12 +874,10 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) } case IrCmd::BARRIER_OBJ: { - LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); - Label skip; ScopedRegX64 tmp{regs, SizeX64::qword}; - callBarrierObject(build, tmp.reg, regOp(inst.a), inst.b.index, skip); + callBarrierObject(build, tmp.reg, regOp(inst.a), vmRegOp(inst.b), skip); build.setLabel(skip); break; } @@ -1094,12 +891,10 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) } case IrCmd::BARRIER_TABLE_FORWARD: { - LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); - Label skip; ScopedRegX64 tmp{regs, SizeX64::qword}; - callBarrierTable(build, tmp.reg, regOp(inst.a), inst.b.index, skip); + callBarrierTable(build, tmp.reg, regOp(inst.a), vmRegOp(inst.b), skip); build.setLabel(skip); break; } @@ -1117,8 +912,6 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) } case IrCmd::CLOSE_UPVALS: { - LUAU_ASSERT(inst.a.kind == IrOpKind::VmReg); - Label next; ScopedRegX64 tmp1{regs, SizeX64::qword}; ScopedRegX64 tmp2{regs, SizeX64::qword}; @@ -1129,7 +922,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.jcc(ConditionX64::Zero, next); // ra <= L->openuval->v - build.lea(tmp2.reg, addr[rBase + inst.a.index * sizeof(TValue)]); + build.lea(tmp2.reg, addr[rBase + vmRegOp(inst.a) * sizeof(TValue)]); build.cmp(tmp2.reg, qword[tmp1.reg + offsetof(UpVal, v)]); build.jcc(ConditionX64::Above, next); @@ -1149,60 +942,38 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) // Fallbacks to non-IR instruction implementations case IrCmd::LOP_SETLIST: { - const Instruction* pc = proto->code + uintOp(inst.a); - LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); - LUAU_ASSERT(inst.c.kind == IrOpKind::VmReg); - LUAU_ASSERT(inst.d.kind == IrOpKind::Constant); - LUAU_ASSERT(inst.e.kind == IrOpKind::Constant); - Label next; - emitInstSetList(build, pc, next); + emitInstSetList(build, next, vmRegOp(inst.b), vmRegOp(inst.c), intOp(inst.d), uintOp(inst.e)); build.setLabel(next); break; } case IrCmd::LOP_CALL: - { - const Instruction* pc = proto->code + uintOp(inst.a); - LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); - LUAU_ASSERT(inst.c.kind == IrOpKind::Constant); - LUAU_ASSERT(inst.d.kind == IrOpKind::Constant); - - emitInstCall(build, helpers, pc, uintOp(inst.a)); + emitInstCall(build, helpers, vmRegOp(inst.a), intOp(inst.b), intOp(inst.c)); break; - } case IrCmd::LOP_RETURN: - { - const Instruction* pc = proto->code + uintOp(inst.a); - LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); - LUAU_ASSERT(inst.c.kind == IrOpKind::Constant); - - emitInstReturn(build, helpers, pc, uintOp(inst.a)); + emitInstReturn(build, helpers, vmRegOp(inst.a), intOp(inst.b)); break; - } case IrCmd::LOP_FORGLOOP: - LUAU_ASSERT(inst.a.kind == IrOpKind::VmReg); - emitinstForGLoop(build, inst.a.index, intOp(inst.b), labelOp(inst.c), labelOp(inst.d)); + emitinstForGLoop(build, vmRegOp(inst.a), intOp(inst.b), labelOp(inst.c), labelOp(inst.d)); break; case IrCmd::LOP_FORGLOOP_FALLBACK: - LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); - emitinstForGLoopFallback(build, uintOp(inst.a), inst.b.index, intOp(inst.c), labelOp(inst.d)); + emitinstForGLoopFallback(build, uintOp(inst.a), vmRegOp(inst.b), intOp(inst.c), labelOp(inst.d)); build.jmp(labelOp(inst.e)); break; case IrCmd::LOP_FORGPREP_XNEXT_FALLBACK: - LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); - emitInstForGPrepXnextFallback(build, uintOp(inst.a), inst.b.index, labelOp(inst.c)); + emitInstForGPrepXnextFallback(build, uintOp(inst.a), vmRegOp(inst.b), labelOp(inst.c)); break; case IrCmd::LOP_AND: - emitInstAnd(build, proto->code + uintOp(inst.a)); + emitInstAnd(build, vmRegOp(inst.a), vmRegOp(inst.b), vmRegOp(inst.c)); break; case IrCmd::LOP_ANDK: - emitInstAndK(build, proto->code + uintOp(inst.a)); + emitInstAndK(build, vmRegOp(inst.a), vmRegOp(inst.b), vmConstOp(inst.c)); break; case IrCmd::LOP_OR: - emitInstOr(build, proto->code + uintOp(inst.a)); + emitInstOr(build, vmRegOp(inst.a), vmRegOp(inst.b), vmRegOp(inst.c)); break; case IrCmd::LOP_ORK: - emitInstOrK(build, proto->code + uintOp(inst.a)); + emitInstOrK(build, vmRegOp(inst.a), vmRegOp(inst.b), vmConstOp(inst.c)); break; case IrCmd::LOP_COVERAGE: emitInstCoverage(build, uintOp(inst.a)); @@ -1272,6 +1043,8 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) LUAU_ASSERT(!"Not supported yet"); break; } + + regs.freeLastUseRegs(inst, index); } bool IrLoweringX64::isFallthroughBlock(IrBlock target, IrBlock next) @@ -1294,9 +1067,9 @@ OperandX64 IrLoweringX64::memRegDoubleOp(IrOp op) const case IrOpKind::Constant: return build.f64(doubleOp(op)); case IrOpKind::VmReg: - return luauRegValue(op.index); + return luauRegValue(vmRegOp(op)); case IrOpKind::VmConst: - return luauConstantValue(op.index); + return luauConstantValue(vmConstOp(op)); default: LUAU_ASSERT(!"Unsupported operand kind"); } @@ -1311,9 +1084,9 @@ OperandX64 IrLoweringX64::memRegTagOp(IrOp op) const case IrOpKind::Inst: return regOp(op); case IrOpKind::VmReg: - return luauRegTag(op.index); + return luauRegTag(vmRegOp(op)); case IrOpKind::VmConst: - return luauConstantTag(op.index); + return luauConstantTag(vmConstOp(op)); default: LUAU_ASSERT(!"Unsupported operand kind"); } @@ -1323,7 +1096,9 @@ OperandX64 IrLoweringX64::memRegTagOp(IrOp op) const RegisterX64 IrLoweringX64::regOp(IrOp op) const { - return function.instOp(op).regX64; + IrInst& inst = function.instOp(op); + LUAU_ASSERT(inst.regX64 != noreg); + return inst.regX64; } IrConst IrLoweringX64::constOp(IrOp op) const diff --git a/CodeGen/src/IrLoweringX64.h b/CodeGen/src/IrLoweringX64.h index a0ad3eabd..c8ebd1f18 100644 --- a/CodeGen/src/IrLoweringX64.h +++ b/CodeGen/src/IrLoweringX64.h @@ -24,10 +24,7 @@ namespace X64 struct IrLoweringX64 { - // Some of these arguments are only required while we re-use old direct bytecode to x64 lowering - IrLoweringX64(AssemblyBuilderX64& build, ModuleHelpers& helpers, NativeState& data, Proto* proto, IrFunction& function); - - void lower(AssemblyOptions options); + IrLoweringX64(AssemblyBuilderX64& build, ModuleHelpers& helpers, NativeState& data, IrFunction& function); void lowerInst(IrInst& inst, uint32_t index, IrBlock& next); @@ -52,7 +49,6 @@ struct IrLoweringX64 AssemblyBuilderX64& build; ModuleHelpers& helpers; NativeState& data; - Proto* proto = nullptr; // Temporarily required to provide 'Instruction* pc' to old emitInst* methods IrFunction& function; diff --git a/CodeGen/src/IrTranslateBuiltins.cpp b/CodeGen/src/IrTranslateBuiltins.cpp index d9f935c49..cb8e41482 100644 --- a/CodeGen/src/IrTranslateBuiltins.cpp +++ b/CodeGen/src/IrTranslateBuiltins.cpp @@ -240,6 +240,10 @@ BuiltinImplResult translateBuiltinTypeof(IrBuilder& build, int nparams, int ra, BuiltinImplResult translateBuiltin(IrBuilder& build, int bfid, int ra, int arg, IrOp args, int nparams, int nresults, IrOp fallback) { + // Builtins are not allowed to handle variadic arguments + if (nparams == LUA_MULTRET) + return {BuiltinImplType::None, -1}; + switch (bfid) { case LBF_ASSERT: diff --git a/CodeGen/src/IrTranslation.cpp b/CodeGen/src/IrTranslation.cpp index 28c6aca19..d90841ce3 100644 --- a/CodeGen/src/IrTranslation.cpp +++ b/CodeGen/src/IrTranslation.cpp @@ -501,10 +501,10 @@ void translateFastCallN(IrBuilder& build, const Instruction* pc, int pcpos, bool if (br.type == BuiltinImplType::UsesFallback) { + LUAU_ASSERT(nparams != LUA_MULTRET && "builtins are not allowed to handle variadic arguments"); + if (nresults == LUA_MULTRET) build.inst(IrCmd::ADJUST_STACK_TO_REG, build.vmReg(ra), build.constInt(br.actualResultCount)); - else if (nparams == LUA_MULTRET) - build.inst(IrCmd::ADJUST_STACK_TO_TOP); } else { diff --git a/CodeGen/src/IrUtils.cpp b/CodeGen/src/IrUtils.cpp index e29a5b029..b28ce596e 100644 --- a/CodeGen/src/IrUtils.cpp +++ b/CodeGen/src/IrUtils.cpp @@ -354,7 +354,7 @@ void foldConstants(IrBuilder& build, IrFunction& function, IrBlock& block, uint3 case IrCmd::JUMP_CMP_NUM: if (inst.a.kind == IrOpKind::Constant && inst.b.kind == IrOpKind::Constant) { - if (compare(function.doubleOp(inst.a), function.doubleOp(inst.b), function.conditionOp(inst.c))) + if (compare(function.doubleOp(inst.a), function.doubleOp(inst.b), conditionOp(inst.c))) replace(function, block, index, {IrCmd::JUMP, inst.d}); else replace(function, block, index, {IrCmd::JUMP, inst.e}); diff --git a/CodeGen/src/NativeState.cpp b/CodeGen/src/NativeState.cpp index f79bcab85..f1497890c 100644 --- a/CodeGen/src/NativeState.cpp +++ b/CodeGen/src/NativeState.cpp @@ -109,6 +109,7 @@ void initHelperFunctions(NativeState& data) data.context.forgPrepXnextFallback = forgPrepXnextFallback; data.context.callProlog = callProlog; data.context.callEpilogC = callEpilogC; + data.context.returnFallback = returnFallback; } } // namespace CodeGen diff --git a/CodeGen/src/NativeState.h b/CodeGen/src/NativeState.h index bebf421b9..6d8331896 100644 --- a/CodeGen/src/NativeState.h +++ b/CodeGen/src/NativeState.h @@ -47,12 +47,6 @@ struct NativeContext uint8_t* gateEntry = nullptr; uint8_t* gateExit = nullptr; - // Opcode fallbacks, implemented in C - NativeFallback fallback[LOP__COUNT] = {}; - - // Fast call methods, implemented in C - luau_FastFunction luauF_table[256] = {}; - // Helper functions, implemented in C int (*luaV_lessthan)(lua_State* L, const TValue* l, const TValue* r) = nullptr; int (*luaV_lessequal)(lua_State* L, const TValue* l, const TValue* r) = nullptr; @@ -107,6 +101,13 @@ struct NativeContext void (*forgPrepXnextFallback)(lua_State* L, TValue* ra, int pc) = nullptr; Closure* (*callProlog)(lua_State* L, TValue* ra, StkId argtop, int nresults) = nullptr; void (*callEpilogC)(lua_State* L, int nresults, int n) = nullptr; + const Instruction* (*returnFallback)(lua_State* L, StkId ra, int n) = nullptr; + + // Opcode fallbacks, implemented in C + NativeFallback fallback[LOP__COUNT] = {}; + + // Fast call methods, implemented in C + luau_FastFunction luauF_table[256] = {}; }; struct NativeState diff --git a/CodeGen/src/OptimizeConstProp.cpp b/CodeGen/src/OptimizeConstProp.cpp index b12a9b946..672364764 100644 --- a/CodeGen/src/OptimizeConstProp.cpp +++ b/CodeGen/src/OptimizeConstProp.cpp @@ -42,6 +42,11 @@ struct RegisterLink // Data we know about the current VM state struct ConstPropState { + ConstPropState(const IrFunction& function) + : function(function) + { + } + uint8_t tryGetTag(IrOp op) { if (RegisterInfo* info = tryGetRegisterInfo(op)) @@ -107,14 +112,29 @@ struct ConstPropState invalidate(regs[regOp.index], /* invalidateTag */ true, /* invalidateValue */ true); } - void invalidateRegistersFrom(uint32_t firstReg) + void invalidateRegistersFrom(int firstReg) { - for (int i = int(firstReg); i <= maxReg; ++i) + for (int i = firstReg; i <= maxReg; ++i) invalidate(regs[i], /* invalidateTag */ true, /* invalidateValue */ true); maxReg = int(firstReg) - 1; } + void invalidateRegisterRange(int firstReg, int count) + { + for (int i = firstReg; i < firstReg + count && i <= maxReg; ++i) + invalidate(regs[i], /* invalidateTag */ true, /* invalidateValue */ true); + } + + void invalidateCapturedRegisters() + { + for (int i = 0; i <= maxReg; ++i) + { + if (function.cfg.captured.regs.test(i)) + invalidate(regs[i], /* invalidateTag */ true, /* invalidateValue */ true); + } + } + void invalidateHeap() { for (int i = 0; i <= maxReg; ++i) @@ -127,10 +147,10 @@ struct ConstPropState reg.knownNoMetatable = false; } - void invalidateAll() + void invalidateUserCall() { - // Invalidating registers also invalidates what we know about the heap (stored in RegisterInfo) - invalidateRegistersFrom(0u); + invalidateHeap(); + invalidateCapturedRegisters(); inSafeEnv = false; } @@ -175,6 +195,8 @@ struct ConstPropState return nullptr; } + const IrFunction& function; + RegisterInfo regs[256]; // For range/full invalidations, we only want to visit a limited number of data that we have recorded @@ -411,7 +433,7 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& if (valueA && valueB) { - if (compare(*valueA, *valueB, function.conditionOp(inst.c))) + if (compare(*valueA, *valueB, conditionOp(inst.c))) replace(function, block, index, {IrCmd::JUMP, inst.d}); else replace(function, block, index, {IrCmd::JUMP, inst.e}); @@ -485,7 +507,7 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& case IrCmd::LOP_ANDK: case IrCmd::LOP_OR: case IrCmd::LOP_ORK: - state.invalidate(inst.b); + state.invalidate(inst.a); break; case IrCmd::FASTCALL: case IrCmd::INVOKE_FASTCALL: @@ -538,35 +560,93 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& case IrCmd::CHECK_FASTCALL_RES: // Changes stack top, but not the values break; - // We don't model the following instructions, so we just clear all the knowledge we have built up - // Many of these call user functions that can change memory and captured registers - // Some of these might yield with similar effects case IrCmd::JUMP_CMP_ANY: + state.invalidateUserCall(); // TODO: if arguments are strings, there will be no user calls + break; case IrCmd::DO_ARITH: + state.invalidate(inst.a); + state.invalidateUserCall(); + break; case IrCmd::DO_LEN: + state.invalidate(inst.a); + state.invalidateUserCall(); // TODO: if argument is a string, there will be no user call + + state.saveTag(inst.a, LUA_TNUMBER); + break; case IrCmd::GET_TABLE: + state.invalidate(inst.a); + state.invalidateUserCall(); + break; case IrCmd::SET_TABLE: + state.invalidateUserCall(); + break; case IrCmd::GET_IMPORT: + state.invalidate(inst.a); + state.invalidateUserCall(); + break; case IrCmd::CONCAT: + state.invalidateRegisterRange(inst.a.index, function.uintOp(inst.b)); + state.invalidateUserCall(); // TODO: if only strings and numbers are concatenated, there will be no user calls + break; case IrCmd::PREPARE_FORN: - case IrCmd::INTERRUPT: // TODO: it will be important to keep tag/value state, but we have to track register capture + state.invalidateValue(inst.a); + state.saveTag(inst.a, LUA_TNUMBER); + state.invalidateValue(inst.b); + state.saveTag(inst.b, LUA_TNUMBER); + state.invalidateValue(inst.c); + state.saveTag(inst.c, LUA_TNUMBER); + break; + case IrCmd::INTERRUPT: + state.invalidateUserCall(); + break; case IrCmd::LOP_CALL: + state.invalidateRegistersFrom(inst.a.index); + state.invalidateUserCall(); + break; case IrCmd::LOP_FORGLOOP: + state.invalidateRegistersFrom(inst.a.index + 2); // Rn and Rn+1 are not modified + break; case IrCmd::LOP_FORGLOOP_FALLBACK: + state.invalidateRegistersFrom(inst.b.index + 2); // Rn and Rn+1 are not modified + state.invalidateUserCall(); + break; case IrCmd::LOP_FORGPREP_XNEXT_FALLBACK: + // This fallback only conditionally throws an exception + break; case IrCmd::FALLBACK_GETGLOBAL: + state.invalidate(inst.b); + state.invalidateUserCall(); + break; case IrCmd::FALLBACK_SETGLOBAL: + state.invalidateUserCall(); + break; case IrCmd::FALLBACK_GETTABLEKS: + state.invalidate(inst.b); + state.invalidateUserCall(); + break; case IrCmd::FALLBACK_SETTABLEKS: + state.invalidateUserCall(); + break; case IrCmd::FALLBACK_NAMECALL: + state.invalidate(IrOp{inst.b.kind, inst.b.index + 0u}); + state.invalidate(IrOp{inst.b.kind, inst.b.index + 1u}); + state.invalidateUserCall(); + break; case IrCmd::FALLBACK_PREPVARARGS: + break; case IrCmd::FALLBACK_GETVARARGS: + state.invalidateRegistersFrom(inst.b.index); + break; case IrCmd::FALLBACK_NEWCLOSURE: + state.invalidate(inst.b); + break; case IrCmd::FALLBACK_DUPCLOSURE: + state.invalidate(inst.b); + break; case IrCmd::FALLBACK_FORGPREP: - // TODO: this is very conservative, some of there instructions can be tracked better - // TODO: non-captured register tags and values should not be cleared here - state.invalidateAll(); + state.invalidate(IrOp{inst.b.kind, inst.b.index + 0u}); + state.invalidate(IrOp{inst.b.kind, inst.b.index + 1u}); + state.invalidate(IrOp{inst.b.kind, inst.b.index + 2u}); break; } } @@ -592,7 +672,7 @@ static void constPropInBlockChain(IrBuilder& build, std::vector& visite { IrFunction& function = build.function; - ConstPropState state; + ConstPropState state{function}; while (block) { @@ -698,7 +778,7 @@ static void tryCreateLinearBlock(IrBuilder& build, std::vector& visited return; // Initialize state with the knowledge of our current block - ConstPropState state; + ConstPropState state{function}; constPropInBlock(build, startingBlock, state); // Veryfy that target hasn't changed diff --git a/Makefile b/Makefile index 585122938..bbc66c2e7 100644 --- a/Makefile +++ b/Makefile @@ -117,6 +117,11 @@ ifneq ($(native),) TESTS_ARGS+=--codegen endif +ifneq ($(nativelj),) + CXXFLAGS+=-DLUA_CUSTOM_EXECUTION=1 -DLUA_USE_LONGJMP=1 + TESTS_ARGS+=--codegen +endif + # target-specific flags $(AST_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include $(COMPILER_OBJECTS): CXXFLAGS+=-std=c++17 -ICompiler/include -ICommon/include -IAst/include diff --git a/Sources.cmake b/Sources.cmake index 6e0a32ed7..3f32aab83 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -84,14 +84,18 @@ target_sources(Luau.CodeGen PRIVATE CodeGen/src/CodeBlockUnwind.cpp CodeGen/src/CodeGen.cpp CodeGen/src/CodeGenUtils.cpp + CodeGen/src/CodeGenA64.cpp CodeGen/src/CodeGenX64.cpp CodeGen/src/EmitBuiltinsX64.cpp + CodeGen/src/EmitCommonA64.cpp CodeGen/src/EmitCommonX64.cpp + CodeGen/src/EmitInstructionA64.cpp CodeGen/src/EmitInstructionX64.cpp CodeGen/src/Fallbacks.cpp CodeGen/src/IrAnalysis.cpp CodeGen/src/IrBuilder.cpp CodeGen/src/IrDump.cpp + CodeGen/src/IrLoweringA64.cpp CodeGen/src/IrLoweringX64.cpp CodeGen/src/IrRegAllocX64.cpp CodeGen/src/IrTranslateBuiltins.cpp @@ -106,13 +110,17 @@ target_sources(Luau.CodeGen PRIVATE CodeGen/src/ByteUtils.h CodeGen/src/CustomExecUtils.h CodeGen/src/CodeGenUtils.h + CodeGen/src/CodeGenA64.h CodeGen/src/CodeGenX64.h CodeGen/src/EmitBuiltinsX64.h CodeGen/src/EmitCommon.h + CodeGen/src/EmitCommonA64.h CodeGen/src/EmitCommonX64.h + CodeGen/src/EmitInstructionA64.h CodeGen/src/EmitInstructionX64.h CodeGen/src/Fallbacks.h CodeGen/src/FallbacksProlog.h + CodeGen/src/IrLoweringA64.h CodeGen/src/IrLoweringX64.h CodeGen/src/IrRegAllocX64.h CodeGen/src/IrTranslateBuiltins.h diff --git a/VM/src/lstate.h b/VM/src/lstate.h index 1d324896c..32a240bfb 100644 --- a/VM/src/lstate.h +++ b/VM/src/lstate.h @@ -208,14 +208,14 @@ typedef struct global_State uint64_t rngstate; // PCG random number generator state uint64_t ptrenckey[4]; // pointer encoding key for display - void (*udatagc[LUA_UTAG_LIMIT])(lua_State*, void*); // for each userdata tag, a gc callback to be called immediately before freeing memory - lua_Callbacks cb; #if LUA_CUSTOM_EXECUTION lua_ExecutionCallbacks ecb; #endif + void (*udatagc[LUA_UTAG_LIMIT])(lua_State*, void*); // for each userdata tag, a gc callback to be called immediately before freeing memory + GCStats gcstats; #ifdef LUAI_GCMETRICS diff --git a/tests/AssemblyBuilderA64.test.cpp b/tests/AssemblyBuilderA64.test.cpp index e23b965bc..a68932bac 100644 --- a/tests/AssemblyBuilderA64.test.cpp +++ b/tests/AssemblyBuilderA64.test.cpp @@ -120,6 +120,10 @@ TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "Loads") SINGLE_COMPARE(ldrsh(x0, x1), 0x79800020); SINGLE_COMPARE(ldrsh(w0, x1), 0x79C00020); SINGLE_COMPARE(ldrsw(x0, x1), 0xB9800020); + + // paired loads + SINGLE_COMPARE(ldp(x0, x1, mem(x2, 8)), 0xA9408440); + SINGLE_COMPARE(ldp(w0, w1, mem(x2, -8)), 0x297F0440); } TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "Stores") @@ -135,15 +139,58 @@ TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "Stores") SINGLE_COMPARE(str(w0, x1), 0xB9000020); SINGLE_COMPARE(strb(w0, x1), 0x39000020); SINGLE_COMPARE(strh(w0, x1), 0x79000020); + + // paired stores + SINGLE_COMPARE(stp(x0, x1, mem(x2, 8)), 0xA9008440); + SINGLE_COMPARE(stp(w0, w1, mem(x2, -8)), 0x293F0440); } TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "Moves") { SINGLE_COMPARE(mov(x0, x1), 0xAA0103E0); SINGLE_COMPARE(mov(w0, w1), 0x2A0103E0); - SINGLE_COMPARE(mov(x0, 42), 0xD2800540); - SINGLE_COMPARE(mov(w0, 42), 0x52800540); + + SINGLE_COMPARE(movz(x0, 42), 0xD2800540); + SINGLE_COMPARE(movz(w0, 42), 0x52800540); + SINGLE_COMPARE(movn(x0, 42), 0x92800540); + SINGLE_COMPARE(movn(w0, 42), 0x12800540); SINGLE_COMPARE(movk(x0, 42, 16), 0xF2A00540); + + CHECK(check( + [](AssemblyBuilderA64& build) { + build.mov(x0, 42); + }, + {0xD2800540})); + + CHECK(check( + [](AssemblyBuilderA64& build) { + build.mov(x0, 424242); + }, + {0xD28F2640, 0xF2A000C0})); + + CHECK(check( + [](AssemblyBuilderA64& build) { + build.mov(x0, -42); + }, + {0x92800520})); + + CHECK(check( + [](AssemblyBuilderA64& build) { + build.mov(x0, -424242); + }, + {0x928F2620, 0xF2BFFF20})); + + CHECK(check( + [](AssemblyBuilderA64& build) { + build.mov(x0, -65536); + }, + {0x929FFFE0})); + + CHECK(check( + [](AssemblyBuilderA64& build) { + build.mov(x0, -65537); + }, + {0x92800000, 0xF2BFFFC0})); } TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "ControlFlow") @@ -222,6 +269,22 @@ TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "Constants") // clang-format on } +TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "AddressOfLabel") +{ + // clang-format off + CHECK(check( + [](AssemblyBuilderA64& build) { + Label label; + build.adr(x0, label); + build.add(x0, x0, x0); + build.setLabel(label); + }, + { + 0x10000040, 0x8b000000, + })); + // clang-format on +} + TEST_CASE("LogTest") { AssemblyBuilderA64 build(/* logText= */ true); @@ -243,6 +306,9 @@ TEST_CASE("LogTest") build.b(ConditionA64::Plus, l); build.cbz(x7, l); + build.ldp(x0, x1, mem(x8, 8)); + build.adr(x0, l); + build.setLabel(l); build.ret(); @@ -263,6 +329,8 @@ TEST_CASE("LogTest") blr x0 b.pl .L1 cbz x7,.L1 + ldp x0,x1,[x8,#8] + adr x0,.L1 .L1: ret )"; diff --git a/tests/CodeAllocator.test.cpp b/tests/CodeAllocator.test.cpp index a6ed96f02..359f2ba1c 100644 --- a/tests/CodeAllocator.test.cpp +++ b/tests/CodeAllocator.test.cpp @@ -465,6 +465,7 @@ TEST_CASE("GeneratedCodeExecutionA64") build.add(x1, x1, 2); build.add(x0, x0, x1, /* LSL */ 1); + build.ret(); build.finalize(); diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index 7d5f41a15..1072b95df 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -1250,7 +1250,9 @@ TEST_CASE("Interrupt") 13, 13, 16, - 20, + 23, + 21, + 25, }; static int index; diff --git a/tests/IrBuilder.test.cpp b/tests/IrBuilder.test.cpp index 37c12dc97..f4c9cdca9 100644 --- a/tests/IrBuilder.test.cpp +++ b/tests/IrBuilder.test.cpp @@ -764,12 +764,14 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "ConcatInvalidation") build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tnumber)); build.inst(IrCmd::STORE_INT, build.vmReg(1), build.constInt(10)); build.inst(IrCmd::STORE_DOUBLE, build.vmReg(2), build.constDouble(0.5)); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(3), build.constDouble(2.0)); - build.inst(IrCmd::CONCAT, build.vmReg(0), build.vmReg(3)); // Concat invalidates more than the target register + build.inst(IrCmd::CONCAT, build.vmReg(0), build.constUint(3)); - build.inst(IrCmd::STORE_TAG, build.vmReg(3), build.inst(IrCmd::LOAD_TAG, build.vmReg(0))); - build.inst(IrCmd::STORE_INT, build.vmReg(4), build.inst(IrCmd::LOAD_INT, build.vmReg(1))); - build.inst(IrCmd::STORE_DOUBLE, build.vmReg(5), build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(2))); + build.inst(IrCmd::STORE_TAG, build.vmReg(4), build.inst(IrCmd::LOAD_TAG, build.vmReg(0))); + build.inst(IrCmd::STORE_INT, build.vmReg(5), build.inst(IrCmd::LOAD_INT, build.vmReg(1))); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(6), build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(2))); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(7), build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(3))); build.inst(IrCmd::LOP_RETURN, build.constUint(0)); @@ -781,13 +783,15 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "ConcatInvalidation") STORE_TAG R0, tnumber STORE_INT R1, 10i STORE_DOUBLE R2, 0.5 - CONCAT R0, R3 - %4 = LOAD_TAG R0 - STORE_TAG R3, %4 - %6 = LOAD_INT R1 - STORE_INT R4, %6 - %8 = LOAD_DOUBLE R2 - STORE_DOUBLE R5, %8 + STORE_DOUBLE R3, 2 + CONCAT R0, 3u + %5 = LOAD_TAG R0 + STORE_TAG R4, %5 + %7 = LOAD_INT R1 + STORE_INT R5, %7 + %9 = LOAD_DOUBLE R2 + STORE_DOUBLE R6, %9 + STORE_DOUBLE R7, 2 LOP_RETURN 0u )"); @@ -1179,7 +1183,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "EntryBlockUseRemoval") build.inst(IrCmd::JUMP_IF_TRUTHY, build.vmReg(0), exit, repeat); build.beginBlock(exit); - build.inst(IrCmd::LOP_RETURN, build.constUint(0), build.vmReg(0), build.constInt(0)); + build.inst(IrCmd::LOP_RETURN, build.vmReg(0), build.constInt(0)); build.beginBlock(repeat); build.inst(IrCmd::INTERRUPT, build.constUint(0)); @@ -1194,7 +1198,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "EntryBlockUseRemoval") JUMP bb_1 bb_1: - LOP_RETURN 0u, R0, 0i + LOP_RETURN R0, 0i )"); } @@ -1207,14 +1211,14 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "RecursiveSccUseRemoval1") IrOp repeat = build.block(IrBlockKind::Internal); build.beginBlock(entry); - build.inst(IrCmd::LOP_RETURN, build.constUint(0), build.vmReg(0), build.constInt(0)); + build.inst(IrCmd::LOP_RETURN, build.vmReg(0), build.constInt(0)); build.beginBlock(block); build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tnumber)); build.inst(IrCmd::JUMP_IF_TRUTHY, build.vmReg(0), exit, repeat); build.beginBlock(exit); - build.inst(IrCmd::LOP_RETURN, build.constUint(0), build.vmReg(0), build.constInt(0)); + build.inst(IrCmd::LOP_RETURN, build.vmReg(0), build.constInt(0)); build.beginBlock(repeat); build.inst(IrCmd::INTERRUPT, build.constUint(0)); @@ -1225,14 +1229,14 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "RecursiveSccUseRemoval1") CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: - LOP_RETURN 0u, R0, 0i + LOP_RETURN R0, 0i bb_1: STORE_TAG R0, tnumber JUMP bb_2 bb_2: - LOP_RETURN 0u, R0, 0i + LOP_RETURN R0, 0i )"); } @@ -1249,14 +1253,14 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "RecursiveSccUseRemoval2") build.inst(IrCmd::JUMP_EQ_INT, build.constInt(0), build.constInt(1), block, exit1); build.beginBlock(exit1); - build.inst(IrCmd::LOP_RETURN, build.constUint(0), build.vmReg(0), build.constInt(0)); + build.inst(IrCmd::LOP_RETURN, build.vmReg(0), build.constInt(0)); build.beginBlock(block); build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tnumber)); build.inst(IrCmd::JUMP_IF_TRUTHY, build.vmReg(0), exit2, repeat); build.beginBlock(exit2); - build.inst(IrCmd::LOP_RETURN, build.constUint(0), build.vmReg(0), build.constInt(0)); + build.inst(IrCmd::LOP_RETURN, build.vmReg(0), build.constInt(0)); build.beginBlock(repeat); build.inst(IrCmd::INTERRUPT, build.constUint(0)); @@ -1270,14 +1274,14 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "RecursiveSccUseRemoval2") JUMP bb_1 bb_1: - LOP_RETURN 0u, R0, 0i + LOP_RETURN R0, 0i bb_2: STORE_TAG R0, tnumber JUMP bb_3 bb_3: - LOP_RETURN 0u, R0, 0i + LOP_RETURN R0, 0i )"); } @@ -1318,7 +1322,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "SimplePathExtraction") build.inst(IrCmd::JUMP, block4); build.beginBlock(block4); - build.inst(IrCmd::LOP_RETURN, build.constUint(0), build.vmReg(0), build.constInt(0)); + build.inst(IrCmd::LOP_RETURN, build.vmReg(0), build.constInt(0)); updateUseCounts(build.function); constPropInBlockChains(build); @@ -1346,10 +1350,10 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "SimplePathExtraction") JUMP bb_5 bb_5: - LOP_RETURN 0u, R0, 0i + LOP_RETURN R0, 0i bb_linear_6: - LOP_RETURN 0u, R0, 0i + LOP_RETURN R0, 0i )"); } @@ -1389,11 +1393,11 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "NoPathExtractionForBlocksWithLiveOutValues" build.beginBlock(block4a); build.inst(IrCmd::STORE_TAG, build.vmReg(0), tag3a); - build.inst(IrCmd::LOP_RETURN, build.constUint(0), build.vmReg(0), build.constInt(0)); + build.inst(IrCmd::LOP_RETURN, build.vmReg(0), build.constInt(0)); build.beginBlock(block4b); build.inst(IrCmd::STORE_TAG, build.vmReg(0), tag3a); - build.inst(IrCmd::LOP_RETURN, build.constUint(0), build.vmReg(0), build.constInt(0)); + build.inst(IrCmd::LOP_RETURN, build.vmReg(0), build.constInt(0)); updateUseCounts(build.function); constPropInBlockChains(build); @@ -1423,11 +1427,11 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "NoPathExtractionForBlocksWithLiveOutValues" bb_5: STORE_TAG R0, %10 - LOP_RETURN 0u, R0, 0i + LOP_RETURN R0, 0i bb_6: STORE_TAG R0, %10 - LOP_RETURN 0u, R0, 0i + LOP_RETURN R0, 0i )"); } @@ -1484,7 +1488,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "SimpleDiamond") build.inst(IrCmd::JUMP, exit); build.beginBlock(exit); - build.inst(IrCmd::LOP_RETURN, build.constUint(0), build.vmReg(2), build.constInt(2)); + build.inst(IrCmd::LOP_RETURN, build.vmReg(2), build.constInt(2)); updateUseCounts(build.function); computeCfgInfo(build.function); @@ -1518,7 +1522,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "SimpleDiamond") bb_3: ; predecessors: bb_1, bb_2 ; in regs: R2, R3 - LOP_RETURN 0u, R2, 2i + LOP_RETURN R2, 2i )"); } @@ -1530,11 +1534,11 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "ImplicitFixedRegistersInVarargCall") build.beginBlock(entry); build.inst(IrCmd::FALLBACK_GETVARARGS, build.constUint(0), build.vmReg(3), build.constInt(-1)); - build.inst(IrCmd::LOP_CALL, build.constUint(0), build.vmReg(0), build.constInt(-1), build.constInt(5)); + build.inst(IrCmd::LOP_CALL, build.vmReg(0), build.constInt(-1), build.constInt(5)); build.inst(IrCmd::JUMP, exit); build.beginBlock(exit); - build.inst(IrCmd::LOP_RETURN, build.constUint(0), build.vmReg(0), build.constInt(5)); + build.inst(IrCmd::LOP_RETURN, build.vmReg(0), build.constInt(5)); updateUseCounts(build.function); computeCfgInfo(build.function); @@ -1545,13 +1549,13 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "ImplicitFixedRegistersInVarargCall") ; in regs: R0, R1, R2 ; out regs: R0, R1, R2, R3, R4 FALLBACK_GETVARARGS 0u, R3, -1i - LOP_CALL 0u, R0, -1i, 5i + LOP_CALL R0, -1i, 5i JUMP bb_1 bb_1: ; predecessors: bb_0 ; in regs: R0, R1, R2, R3, R4 - LOP_RETURN 0u, R0, 5i + LOP_RETURN R0, 5i )"); } @@ -1563,11 +1567,13 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "ExplicitUseOfRegisterInVarargSequence") build.beginBlock(entry); build.inst(IrCmd::FALLBACK_GETVARARGS, build.constUint(0), build.vmReg(1), build.constInt(-1)); - build.inst(IrCmd::INVOKE_FASTCALL, build.constUint(0), build.vmReg(0), build.vmReg(1), build.vmReg(2), build.constInt(-1), build.constInt(-1)); + IrOp results = build.inst( + IrCmd::INVOKE_FASTCALL, build.constUint(0), build.vmReg(0), build.vmReg(1), build.vmReg(2), build.constInt(-1), build.constInt(-1)); + build.inst(IrCmd::ADJUST_STACK_TO_REG, build.vmReg(0), results); build.inst(IrCmd::JUMP, exit); build.beginBlock(exit); - build.inst(IrCmd::LOP_RETURN, build.constUint(0), build.vmReg(0), build.constInt(-1)); + build.inst(IrCmd::LOP_RETURN, build.vmReg(0), build.constInt(-1)); updateUseCounts(build.function); computeCfgInfo(build.function); @@ -1578,12 +1584,13 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "ExplicitUseOfRegisterInVarargSequence") ; out regs: R0... FALLBACK_GETVARARGS 0u, R1, -1i %1 = INVOKE_FASTCALL 0u, R0, R1, R2, -1i, -1i + ADJUST_STACK_TO_REG R0, %1 JUMP bb_1 bb_1: ; predecessors: bb_0 ; in regs: R0... - LOP_RETURN 0u, R0, -1i + LOP_RETURN R0, -1i )"); } @@ -1594,12 +1601,12 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "VariadicSequenceRestart") IrOp exit = build.block(IrBlockKind::Internal); build.beginBlock(entry); - build.inst(IrCmd::LOP_CALL, build.constUint(0), build.vmReg(1), build.constInt(0), build.constInt(-1)); - build.inst(IrCmd::LOP_CALL, build.constUint(0), build.vmReg(0), build.constInt(-1), build.constInt(-1)); + build.inst(IrCmd::LOP_CALL, build.vmReg(1), build.constInt(0), build.constInt(-1)); + build.inst(IrCmd::LOP_CALL, build.vmReg(0), build.constInt(-1), build.constInt(-1)); build.inst(IrCmd::JUMP, exit); build.beginBlock(exit); - build.inst(IrCmd::LOP_RETURN, build.constUint(0), build.vmReg(0), build.constInt(-1)); + build.inst(IrCmd::LOP_RETURN, build.vmReg(0), build.constInt(-1)); updateUseCounts(build.function); computeCfgInfo(build.function); @@ -1609,14 +1616,14 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "VariadicSequenceRestart") ; successors: bb_1 ; in regs: R0, R1 ; out regs: R0... - LOP_CALL 0u, R1, 0i, -1i - LOP_CALL 0u, R0, -1i, -1i + LOP_CALL R1, 0i, -1i + LOP_CALL R0, -1i, -1i JUMP bb_1 bb_1: ; predecessors: bb_0 ; in regs: R0... - LOP_RETURN 0u, R0, -1i + LOP_RETURN R0, -1i )"); } @@ -1630,15 +1637,15 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "FallbackDoesNotFlowUp") build.beginBlock(entry); build.inst(IrCmd::FALLBACK_GETVARARGS, build.constUint(0), build.vmReg(1), build.constInt(-1)); build.inst(IrCmd::CHECK_TAG, build.inst(IrCmd::LOAD_TAG, build.vmReg(0)), build.constTag(tnumber), fallback); - build.inst(IrCmd::LOP_CALL, build.constUint(0), build.vmReg(0), build.constInt(-1), build.constInt(-1)); + build.inst(IrCmd::LOP_CALL, build.vmReg(0), build.constInt(-1), build.constInt(-1)); build.inst(IrCmd::JUMP, exit); build.beginBlock(fallback); - build.inst(IrCmd::LOP_CALL, build.constUint(0), build.vmReg(0), build.constInt(-1), build.constInt(-1)); + build.inst(IrCmd::LOP_CALL, build.vmReg(0), build.constInt(-1), build.constInt(-1)); build.inst(IrCmd::JUMP, exit); build.beginBlock(exit); - build.inst(IrCmd::LOP_RETURN, build.constUint(0), build.vmReg(0), build.constInt(-1)); + build.inst(IrCmd::LOP_RETURN, build.vmReg(0), build.constInt(-1)); updateUseCounts(build.function); computeCfgInfo(build.function); @@ -1651,7 +1658,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "FallbackDoesNotFlowUp") FALLBACK_GETVARARGS 0u, R1, -1i %1 = LOAD_TAG R0 CHECK_TAG %1, tnumber, bb_fallback_1 - LOP_CALL 0u, R0, -1i, -1i + LOP_CALL R0, -1i, -1i JUMP bb_2 bb_fallback_1: @@ -1659,13 +1666,13 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "FallbackDoesNotFlowUp") ; successors: bb_2 ; in regs: R0, R1... ; out regs: R0... - LOP_CALL 0u, R0, -1i, -1i + LOP_CALL R0, -1i, -1i JUMP bb_2 bb_2: ; predecessors: bb_0, bb_fallback_1 ; in regs: R0... - LOP_RETURN 0u, R0, -1i + LOP_RETURN R0, -1i )"); } @@ -1690,7 +1697,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "VariadicSequencePeeling") build.inst(IrCmd::JUMP, exit); build.beginBlock(exit); - build.inst(IrCmd::LOP_RETURN, build.constUint(0), build.vmReg(2), build.constInt(-1)); + build.inst(IrCmd::LOP_RETURN, build.vmReg(2), build.constInt(-1)); updateUseCounts(build.function); computeCfgInfo(build.function); @@ -1725,7 +1732,65 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "VariadicSequencePeeling") bb_3: ; predecessors: bb_1, bb_2 ; in regs: R2... - LOP_RETURN 0u, R2, -1i + LOP_RETURN R2, -1i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "BuiltinVariadicStart") +{ + IrOp entry = build.block(IrBlockKind::Internal); + IrOp exit = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(1), build.constDouble(1.0)); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(2), build.constDouble(2.0)); + build.inst(IrCmd::ADJUST_STACK_TO_REG, build.vmReg(2), build.constInt(1)); + build.inst(IrCmd::LOP_CALL, build.vmReg(1), build.constInt(-1), build.constInt(1)); + build.inst(IrCmd::JUMP, exit); + + build.beginBlock(exit); + build.inst(IrCmd::LOP_RETURN, build.vmReg(0), build.constInt(2)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( +bb_0: +; successors: bb_1 +; in regs: R0 +; out regs: R0, R1 + STORE_DOUBLE R1, 1 + STORE_DOUBLE R2, 2 + ADJUST_STACK_TO_REG R2, 1i + LOP_CALL R1, -1i, 1i + JUMP bb_1 + +bb_1: +; predecessors: bb_0 +; in regs: R0, R1 + LOP_RETURN R0, 2i + +)"); +} + + +TEST_CASE_FIXTURE(IrBuilderFixture, "SetTable") +{ + IrOp entry = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::SET_TABLE, build.vmReg(0), build.vmReg(1), build.constUint(1)); + build.inst(IrCmd::LOP_RETURN, build.vmReg(0), build.constInt(1)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( +bb_0: +; in regs: R0, R1 + SET_TABLE R0, R1, 1u + LOP_RETURN R0, 1i )"); } diff --git a/tests/Module.test.cpp b/tests/Module.test.cpp index d2796b6d0..7e61235a8 100644 --- a/tests/Module.test.cpp +++ b/tests/Module.test.cpp @@ -338,7 +338,7 @@ type B = A TEST_CASE_FIXTURE(BuiltinsFixture, "do_not_clone_reexports") { ScopedFastFlag flags[] = { - {"LuauClonePublicInterfaceLess", true}, + {"LuauClonePublicInterfaceLess2", true}, {"LuauSubstitutionReentrant", true}, {"LuauClassTypeVarsInSubstitution", true}, {"LuauSubstitutionFixMissingFields", true}, @@ -376,7 +376,7 @@ return {} TEST_CASE_FIXTURE(BuiltinsFixture, "do_not_clone_types_of_reexported_values") { ScopedFastFlag flags[] = { - {"LuauClonePublicInterfaceLess", true}, + {"LuauClonePublicInterfaceLess2", true}, {"LuauSubstitutionReentrant", true}, {"LuauClassTypeVarsInSubstitution", true}, {"LuauSubstitutionFixMissingFields", true}, diff --git a/tests/Normalize.test.cpp b/tests/Normalize.test.cpp index 384a39fea..a495ee231 100644 --- a/tests/Normalize.test.cpp +++ b/tests/Normalize.test.cpp @@ -802,7 +802,7 @@ TEST_CASE_FIXTURE(NormalizeFixture, "negations_of_tables") TEST_CASE_FIXTURE(NormalizeFixture, "normalize_blocked_types") { - ScopedFastFlag sff[] { + ScopedFastFlag sff[]{ {"LuauNormalizeBlockedTypes", true}, }; @@ -813,4 +813,14 @@ TEST_CASE_FIXTURE(NormalizeFixture, "normalize_blocked_types") CHECK_EQ(normalizer.typeFromNormal(*norm), &blocked); } +TEST_CASE_FIXTURE(NormalizeFixture, "normalize_pending_expansion_types") +{ + AstName name; + Type pending{PendingExpansionType{std::nullopt, name, {}, {}}}; + + const NormalizedType* norm = normalizer.normalize(&pending); + + CHECK_EQ(normalizer.typeFromNormal(*norm), &pending); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.aliases.test.cpp b/tests/TypeInfer.aliases.test.cpp index b55c77460..022abea0b 100644 --- a/tests/TypeInfer.aliases.test.cpp +++ b/tests/TypeInfer.aliases.test.cpp @@ -947,4 +947,71 @@ TEST_CASE_FIXTURE(Fixture, "type_alias_locations") CHECK(mod->scopes[3].second->typeAliasNameLocations["X"] == Location(Position(5, 17), 1)); } +/* + * We had a bug in DCR where substitution would improperly clone a + * PendingExpansionType. + * + * This cloned type did not have a matching constraint to expand it, so it was + * left dangling and unexpanded forever. + * + * We must also delay the dispatch a constraint if doing so would require + * unifying a PendingExpansionType. + */ +TEST_CASE_FIXTURE(BuiltinsFixture, "dont_lose_track_of_PendingExpansionTypes_after_substitution") +{ + fileResolver.source["game/ReactCurrentDispatcher"] = R"( + export type BasicStateAction = ((S) -> S) | S + export type Dispatch = (A) -> () + + export type Dispatcher = { + useState: (initialState: (() -> S) | S) -> (S, Dispatch>), + } + + return {} + )"; + + // Note: This script path is actually as short as it can be. Any shorter + // and we somehow fail to surface the bug. + fileResolver.source["game/React/React/ReactHooks"] = R"( + local RCD = require(script.Parent.Parent.Parent.ReactCurrentDispatcher) + + local function resolveDispatcher(): RCD.Dispatcher + return (nil :: any) :: RCD.Dispatcher + end + + function useState( + initialState: (() -> S) | S + ): (S, RCD.Dispatch>) + local dispatcher = resolveDispatcher() + return dispatcher.useState(initialState) + end + )"; + + CheckResult result = frontend.check("game/React/React/ReactHooks"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "another_thing_from_roact") +{ + CheckResult result = check(R"( + type Map = { [K]: V } + type Set = { [T]: boolean } + + type FiberRoot = { + pingCache: Map | Map>)> | nil, + } + + type Wakeable = { + andThen: (self: Wakeable) -> nil | Wakeable, + } + + local function attachPingListener(root: FiberRoot, wakeable: Wakeable, lanes: number) + local pingCache: Map | Map>)> | nil = root.pingCache + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.builtins.test.cpp b/tests/TypeInfer.builtins.test.cpp index 49209a4d4..79d9108d3 100644 --- a/tests/TypeInfer.builtins.test.cpp +++ b/tests/TypeInfer.builtins.test.cpp @@ -9,7 +9,6 @@ using namespace Luau; LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); -LUAU_FASTFLAG(LuauMatchReturnsOptionalString); TEST_SUITE_BEGIN("BuiltinTests"); @@ -1064,10 +1063,7 @@ TEST_CASE_FIXTURE(Fixture, "string_match") )"); LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::LuauMatchReturnsOptionalString) - CHECK_EQ(toString(requireType("p")), "string?"); - else - CHECK_EQ(toString(requireType("p")), "string"); + CHECK_EQ(toString(requireType("p")), "string?"); } TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types") @@ -1078,18 +1074,9 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types") LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::LuauMatchReturnsOptionalString) - { - CHECK_EQ(toString(requireType("a")), "string?"); - CHECK_EQ(toString(requireType("b")), "number?"); - CHECK_EQ(toString(requireType("c")), "string?"); - } - else - { - CHECK_EQ(toString(requireType("a")), "string"); - CHECK_EQ(toString(requireType("b")), "number"); - CHECK_EQ(toString(requireType("c")), "string"); - } + CHECK_EQ(toString(requireType("a")), "string?"); + CHECK_EQ(toString(requireType("b")), "number?"); + CHECK_EQ(toString(requireType("c")), "string?"); } TEST_CASE_FIXTURE(Fixture, "gmatch_capture_types2") @@ -1100,18 +1087,9 @@ TEST_CASE_FIXTURE(Fixture, "gmatch_capture_types2") LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::LuauMatchReturnsOptionalString) - { - CHECK_EQ(toString(requireType("a")), "string?"); - CHECK_EQ(toString(requireType("b")), "number?"); - CHECK_EQ(toString(requireType("c")), "string?"); - } - else - { - CHECK_EQ(toString(requireType("a")), "string"); - CHECK_EQ(toString(requireType("b")), "number"); - CHECK_EQ(toString(requireType("c")), "string"); - } + CHECK_EQ(toString(requireType("a")), "string?"); + CHECK_EQ(toString(requireType("b")), "number?"); + CHECK_EQ(toString(requireType("c")), "string?"); } TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types_default_capture") @@ -1128,10 +1106,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types_default_capture") CHECK_EQ(acm->expected, 1); CHECK_EQ(acm->actual, 4); - if (FFlag::LuauMatchReturnsOptionalString) - CHECK_EQ(toString(requireType("a")), "string?"); - else - CHECK_EQ(toString(requireType("a")), "string"); + CHECK_EQ(toString(requireType("a")), "string?"); } TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types_balanced_escaped_parens") @@ -1148,18 +1123,9 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types_balanced_escaped_parens CHECK_EQ(acm->expected, 3); CHECK_EQ(acm->actual, 4); - if (FFlag::LuauMatchReturnsOptionalString) - { - CHECK_EQ(toString(requireType("a")), "string?"); - CHECK_EQ(toString(requireType("b")), "string?"); - CHECK_EQ(toString(requireType("c")), "number?"); - } - else - { - CHECK_EQ(toString(requireType("a")), "string"); - CHECK_EQ(toString(requireType("b")), "string"); - CHECK_EQ(toString(requireType("c")), "number"); - } + CHECK_EQ(toString(requireType("a")), "string?"); + CHECK_EQ(toString(requireType("b")), "string?"); + CHECK_EQ(toString(requireType("c")), "number?"); } TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types_parens_in_sets_are_ignored") @@ -1176,16 +1142,8 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types_parens_in_sets_are_igno CHECK_EQ(acm->expected, 2); CHECK_EQ(acm->actual, 3); - if (FFlag::LuauMatchReturnsOptionalString) - { - CHECK_EQ(toString(requireType("a")), "string?"); - CHECK_EQ(toString(requireType("b")), "number?"); - } - else - { - CHECK_EQ(toString(requireType("a")), "string"); - CHECK_EQ(toString(requireType("b")), "number"); - } + CHECK_EQ(toString(requireType("a")), "string?"); + CHECK_EQ(toString(requireType("b")), "number?"); } TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types_set_containing_lbracket") @@ -1196,16 +1154,8 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types_set_containing_lbracket LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::LuauMatchReturnsOptionalString) - { - CHECK_EQ(toString(requireType("a")), "number?"); - CHECK_EQ(toString(requireType("b")), "string?"); - } - else - { - CHECK_EQ(toString(requireType("a")), "number"); - CHECK_EQ(toString(requireType("b")), "string"); - } + CHECK_EQ(toString(requireType("a")), "number?"); + CHECK_EQ(toString(requireType("b")), "string?"); } TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types_leading_end_bracket_is_part_of_set") @@ -1253,18 +1203,9 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "match_capture_types") LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::LuauMatchReturnsOptionalString) - { - CHECK_EQ(toString(requireType("a")), "string?"); - CHECK_EQ(toString(requireType("b")), "number?"); - CHECK_EQ(toString(requireType("c")), "string?"); - } - else - { - CHECK_EQ(toString(requireType("a")), "string"); - CHECK_EQ(toString(requireType("b")), "number"); - CHECK_EQ(toString(requireType("c")), "string"); - } + CHECK_EQ(toString(requireType("a")), "string?"); + CHECK_EQ(toString(requireType("b")), "number?"); + CHECK_EQ(toString(requireType("c")), "string?"); } TEST_CASE_FIXTURE(BuiltinsFixture, "match_capture_types2") @@ -1280,18 +1221,9 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "match_capture_types2") CHECK_EQ(toString(tm->wantedType), "number?"); CHECK_EQ(toString(tm->givenType), "string"); - if (FFlag::LuauMatchReturnsOptionalString) - { - CHECK_EQ(toString(requireType("a")), "string?"); - CHECK_EQ(toString(requireType("b")), "number?"); - CHECK_EQ(toString(requireType("c")), "string?"); - } - else - { - CHECK_EQ(toString(requireType("a")), "string"); - CHECK_EQ(toString(requireType("b")), "number"); - CHECK_EQ(toString(requireType("c")), "string"); - } + CHECK_EQ(toString(requireType("a")), "string?"); + CHECK_EQ(toString(requireType("b")), "number?"); + CHECK_EQ(toString(requireType("c")), "string?"); } TEST_CASE_FIXTURE(BuiltinsFixture, "find_capture_types") @@ -1302,18 +1234,9 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "find_capture_types") LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::LuauMatchReturnsOptionalString) - { - CHECK_EQ(toString(requireType("a")), "string?"); - CHECK_EQ(toString(requireType("b")), "number?"); - CHECK_EQ(toString(requireType("c")), "string?"); - } - else - { - CHECK_EQ(toString(requireType("a")), "string"); - CHECK_EQ(toString(requireType("b")), "number"); - CHECK_EQ(toString(requireType("c")), "string"); - } + CHECK_EQ(toString(requireType("a")), "string?"); + CHECK_EQ(toString(requireType("b")), "number?"); + CHECK_EQ(toString(requireType("c")), "string?"); CHECK_EQ(toString(requireType("d")), "number?"); CHECK_EQ(toString(requireType("e")), "number?"); } @@ -1331,18 +1254,9 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "find_capture_types2") CHECK_EQ(toString(tm->wantedType), "number?"); CHECK_EQ(toString(tm->givenType), "string"); - if (FFlag::LuauMatchReturnsOptionalString) - { - CHECK_EQ(toString(requireType("a")), "string?"); - CHECK_EQ(toString(requireType("b")), "number?"); - CHECK_EQ(toString(requireType("c")), "string?"); - } - else - { - CHECK_EQ(toString(requireType("a")), "string"); - CHECK_EQ(toString(requireType("b")), "number"); - CHECK_EQ(toString(requireType("c")), "string"); - } + CHECK_EQ(toString(requireType("a")), "string?"); + CHECK_EQ(toString(requireType("b")), "number?"); + CHECK_EQ(toString(requireType("c")), "string?"); CHECK_EQ(toString(requireType("d")), "number?"); CHECK_EQ(toString(requireType("e")), "number?"); } @@ -1360,18 +1274,9 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "find_capture_types3") CHECK_EQ(toString(tm->wantedType), "boolean?"); CHECK_EQ(toString(tm->givenType), "string"); - if (FFlag::LuauMatchReturnsOptionalString) - { - CHECK_EQ(toString(requireType("a")), "string?"); - CHECK_EQ(toString(requireType("b")), "number?"); - CHECK_EQ(toString(requireType("c")), "string?"); - } - else - { - CHECK_EQ(toString(requireType("a")), "string"); - CHECK_EQ(toString(requireType("b")), "number"); - CHECK_EQ(toString(requireType("c")), "string"); - } + CHECK_EQ(toString(requireType("a")), "string?"); + CHECK_EQ(toString(requireType("b")), "number?"); + CHECK_EQ(toString(requireType("c")), "string?"); CHECK_EQ(toString(requireType("d")), "number?"); CHECK_EQ(toString(requireType("e")), "number?"); } diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index 482a6b7f5..c7f9684b3 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -1860,7 +1860,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "dont_assert_when_the_tarjan_limit_is_exceede ScopedFastInt sfi{"LuauTarjanChildLimit", 2}; ScopedFastFlag sff[] = { {"DebugLuauDeferredConstraintResolution", true}, - {"LuauClonePublicInterfaceLess", true}, + {"LuauClonePublicInterfaceLess2", true}, {"LuauSubstitutionReentrant", true}, {"LuauSubstitutionFixMissingFields", true}, }; @@ -1880,4 +1880,33 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "dont_assert_when_the_tarjan_limit_is_exceede CHECK(Location({0, 0}, {4, 4}) == result.errors[1].location); } +/* We had a bug under DCR where instantiated type packs had a nullptr scope. + * + * This caused an issue with promotion. + */ +TEST_CASE_FIXTURE(Fixture, "instantiated_type_packs_must_have_a_non_null_scope") +{ + CheckResult result = check(R"( + function pcall(...: A...): R... + end + + type Dispatch = (A) -> () + + function mountReducer() + dispatchAction() + return nil :: any + end + + function dispatchAction() + end + + function useReducer(): Dispatch + local result, setResult = pcall(mountReducer) + return setResult + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.intersectionTypes.test.cpp b/tests/TypeInfer.intersectionTypes.test.cpp index f6d04a952..b682e5f6c 100644 --- a/tests/TypeInfer.intersectionTypes.test.cpp +++ b/tests/TypeInfer.intersectionTypes.test.cpp @@ -327,7 +327,12 @@ TEST_CASE_FIXTURE(Fixture, "table_intersection_write_sealed") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), "Cannot add property 'z' to table 'X & Y'"); + auto e = toString(result.errors[0]); + // In DCR, because of type normalization, we print a different error message + if (FFlag::DebugLuauDeferredConstraintResolution) + CHECK_EQ("Cannot add property 'z' to table '{| x: number, y: number |}'", e); + else + CHECK_EQ("Cannot add property 'z' to table 'X & Y'", e); } TEST_CASE_FIXTURE(Fixture, "table_intersection_write_sealed_indirect") diff --git a/tests/TypeInfer.oop.test.cpp b/tests/TypeInfer.oop.test.cpp index 0f540f683..eb4937fde 100644 --- a/tests/TypeInfer.oop.test.cpp +++ b/tests/TypeInfer.oop.test.cpp @@ -326,4 +326,59 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "flag_when_index_metamethod_returns_0_values" CHECK("nil" == toString(requireType("p"))); } +TEST_CASE_FIXTURE(BuiltinsFixture, "augmenting_an_unsealed_table_with_a_metatable") +{ + CheckResult result = check(R"( + local A = {number = 8} + + local B = setmetatable({}, A) + + function B:method() + return "hello!!" + end + )"); + + CHECK("{ @metatable { number: number }, { method: (a) -> string } }" == toString(requireType("B"), {true})); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "react_style_oo") +{ + CheckResult result = check(R"( + local Prototype = {} + + local ClassMetatable = { + __index = Prototype + } + + local BaseClass = (setmetatable({}, ClassMetatable)) + + function BaseClass:extend(name) + local class = { + name=name + } + + class.__index = class + + function class.ctor(props) + return setmetatable({props=props}, class) + end + + return setmetatable(class, getmetatable(self)) + end + + local C = BaseClass:extend('C') + local i = C.ctor({hello='world'}) + + local iName = i.name + local cName = C.name + local hello = i.props.hello + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK("string" == toString(requireType("iName"))); + CHECK("string" == toString(requireType("cName"))); + CHECK("string" == toString(requireType("hello"))); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.operators.test.cpp b/tests/TypeInfer.operators.test.cpp index 720784c35..8c289c7be 100644 --- a/tests/TypeInfer.operators.test.cpp +++ b/tests/TypeInfer.operators.test.cpp @@ -526,7 +526,17 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "typecheck_unary_minus_error") LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("string", toString(requireType("a"))); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + // Under DCR, this currently functions as a failed overload resolution, and so we can't say + // anything about the result type of the unary minus. + CHECK_EQ("any", toString(requireType("a"))); + } + else + { + + CHECK_EQ("string", toString(requireType("a"))); + } TypeMismatch* tm = get(result.errors[0]); REQUIRE_EQ(*tm->wantedType, *builtinTypes->booleanType); diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp index d49f00443..19a19e450 100644 --- a/tests/TypeInfer.unionTypes.test.cpp +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -196,7 +196,6 @@ TEST_CASE_FIXTURE(Fixture, "index_on_a_union_type_with_missing_property") REQUIRE(bTy); CHECK_EQ(mup->missing[0], *bTy); CHECK_EQ(mup->key, "x"); - CHECK_EQ("*error-type*", toString(requireType("r"))); } @@ -354,7 +353,11 @@ a.x = 2 )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Value of type '({| x: number |} & {| y: number |})?' could be nil", toString(result.errors[0])); + auto s = toString(result.errors[0]); + if (FFlag::DebugLuauDeferredConstraintResolution) + CHECK_EQ("Value of type '{| x: number, y: number |}?' could be nil", s); + else + CHECK_EQ("Value of type '({| x: number |} & {| y: number |})?' could be nil", s); } TEST_CASE_FIXTURE(Fixture, "optional_length_error") diff --git a/tests/conformance/interrupt.lua b/tests/conformance/interrupt.lua index d4b7c80a4..c07f57e7d 100644 --- a/tests/conformance/interrupt.lua +++ b/tests/conformance/interrupt.lua @@ -17,4 +17,9 @@ end bar() +function baz() +end + +baz() + return "OK" diff --git a/tests/conformance/math.lua b/tests/conformance/math.lua index 18ed13706..ea3b5c87a 100644 --- a/tests/conformance/math.lua +++ b/tests/conformance/math.lua @@ -345,5 +345,7 @@ assert(math.round("1.8") == 2) assert(select('#', math.floor(1.4)) == 1) assert(select('#', math.ceil(1.6)) == 1) assert(select('#', math.sqrt(9)) == 1) +assert(select('#', math.deg(9)) == 1) +assert(select('#', math.rad(9)) == 1) return('OK') diff --git a/tests/main.cpp b/tests/main.cpp index 82ce4e16a..5395e7c60 100644 --- a/tests/main.cpp +++ b/tests/main.cpp @@ -59,6 +59,25 @@ static bool debuggerPresent() int ret = sysctl(mib, sizeof(mib) / sizeof(*mib), &info, &size, nullptr, 0); // debugger is attached if the P_TRACED flag is set return ret == 0 && (info.kp_proc.p_flag & P_TRACED) != 0; +#elif defined(__linux__) + FILE* st = fopen("/proc/self/status", "r"); + if (!st) + return false; // assume no debugger is attached. + + int tpid = 0; + char buf[256]; + + while (fgets(buf, sizeof(buf), st)) + { + if (strncmp(buf, "TracerPid:\t", 11) == 0) + { + tpid = atoi(buf + 11); + break; + } + } + + fclose(st); + return tpid != 0; #else return false; // assume no debugger is attached. #endif @@ -67,7 +86,7 @@ static bool debuggerPresent() static int testAssertionHandler(const char* expr, const char* file, int line, const char* function) { if (debuggerPresent()) - LUAU_DEBUGBREAK(); + return 1; // LUAU_ASSERT will trigger LUAU_DEBUGBREAK for a more convenient debugging experience ADD_FAIL_AT(file, line, "Assertion failed: ", std::string(expr)); return 1; diff --git a/tools/faillist.txt b/tools/faillist.txt index d513af142..76e5972dc 100644 --- a/tools/faillist.txt +++ b/tools/faillist.txt @@ -25,7 +25,6 @@ BuiltinTests.string_format_correctly_ordered_types BuiltinTests.string_format_report_all_type_errors_at_correct_positions BuiltinTests.string_format_tostring_specifier_type_constraint BuiltinTests.string_format_use_correct_argument2 -BuiltinTests.table_insert_correctly_infers_type_of_array_3_args_overload BuiltinTests.table_pack BuiltinTests.table_pack_reduce BuiltinTests.table_pack_variadic @@ -49,9 +48,9 @@ GenericsTests.infer_generic_lib_function_function_argument GenericsTests.instantiated_function_argument_names GenericsTests.no_stack_overflow_from_quantifying GenericsTests.self_recursive_instantiated_param -IntersectionTypes.table_intersection_write_sealed IntersectionTypes.table_intersection_write_sealed_indirect IntersectionTypes.table_write_sealed_indirect +isSubtype.any_is_unknown_union_error ProvisionalTests.assign_table_with_refined_property_with_a_similar_type_is_illegal ProvisionalTests.bail_early_if_unification_is_too_complicated ProvisionalTests.do_not_ice_when_trying_to_pick_first_of_generic_type_pack @@ -70,7 +69,6 @@ RefinementTest.typeguard_in_assert_position RefinementTest.x_as_any_if_x_is_instance_elseif_x_is_table RuntimeLimits.typescript_port_of_Result_type TableTests.a_free_shape_cannot_turn_into_a_scalar_if_it_is_not_compatible -TableTests.casting_tables_with_props_into_table_with_indexer3 TableTests.checked_prop_too_early TableTests.disallow_indexing_into_an_unsealed_table_with_no_indexer_in_strict_mode TableTests.dont_crash_when_setmetatable_does_not_produce_a_metatabletypevar @@ -123,6 +121,7 @@ ToString.toStringNamedFunction_generic_pack TryUnifyTests.members_of_failed_typepack_unification_are_unified_with_errorType TryUnifyTests.result_of_failed_typepack_unification_is_constrained TryUnifyTests.typepack_unification_should_trim_free_tails +TypeAliases.dont_lose_track_of_PendingExpansionTypes_after_substitution TypeAliases.generic_param_remap TypeAliases.mismatched_generic_type_param TypeAliases.mutually_recursive_types_restriction_not_ok_1 @@ -218,10 +217,5 @@ TypeSingletons.widen_the_supertype_if_it_is_free_and_subtype_has_singleton TypeSingletons.widening_happens_almost_everywhere UnionTypes.generic_function_with_optional_arg UnionTypes.index_on_a_union_type_with_missing_property -UnionTypes.optional_assignment_errors -UnionTypes.optional_call_error -UnionTypes.optional_index_error -UnionTypes.optional_iteration -UnionTypes.optional_length_error UnionTypes.optional_union_follow UnionTypes.table_union_write_indirect From d1acde36bb84d045e705174c52f83ce3ff735ba4 Mon Sep 17 00:00:00 2001 From: Vyacheslav Egorov Date: Fri, 31 Mar 2023 15:21:14 +0300 Subject: [PATCH 43/66] Sync to upstream/release/570 --- Analysis/include/Luau/ConstraintSolver.h | 18 +- Analysis/include/Luau/Normalize.h | 6 +- Analysis/include/Luau/Scope.h | 6 +- Analysis/include/Luau/TypeInfer.h | 6 +- Analysis/src/AstQuery.cpp | 22 +- Analysis/src/Autocomplete.cpp | 51 +- Analysis/src/BuiltinDefinitions.cpp | 25 +- Analysis/src/ConstraintSolver.cpp | 121 +++-- Analysis/src/Frontend.cpp | 8 +- Analysis/src/Linter.cpp | 10 +- Analysis/src/Normalize.cpp | 81 ++- Analysis/src/Scope.cpp | 6 +- Analysis/src/TypeChecker2.cpp | 6 +- Analysis/src/TypeInfer.cpp | 53 +- Analysis/src/Unifier.cpp | 21 +- Ast/src/Lexer.cpp | 6 +- CodeGen/include/Luau/AddressA64.h | 7 +- CodeGen/include/Luau/AssemblyBuilderA64.h | 46 +- CodeGen/include/Luau/AssemblyBuilderX64.h | 1 + CodeGen/include/Luau/ConditionA64.h | 17 + CodeGen/include/Luau/IrCallWrapperX64.h | 82 +++ CodeGen/include/Luau/IrData.h | 60 ++- CodeGen/{src => include/Luau}/IrRegAllocX64.h | 9 +- CodeGen/include/Luau/IrUtils.h | 13 +- CodeGen/include/Luau/RegisterA64.h | 68 +++ CodeGen/src/AssemblyBuilderA64.cpp | 287 +++++++++- CodeGen/src/AssemblyBuilderX64.cpp | 9 +- CodeGen/src/CodeGen.cpp | 89 ++- CodeGen/src/CodeGenA64.cpp | 10 + CodeGen/src/CodeGenUtils.cpp | 87 ++- CodeGen/src/CodeGenUtils.h | 3 +- CodeGen/src/EmitBuiltinsX64.cpp | 144 ++--- CodeGen/src/EmitCommon.h | 7 + CodeGen/src/EmitCommonA64.cpp | 91 ++-- CodeGen/src/EmitCommonA64.h | 53 +- CodeGen/src/EmitCommonX64.cpp | 185 +++---- CodeGen/src/EmitCommonX64.h | 24 +- CodeGen/src/EmitInstructionA64.cpp | 59 +- CodeGen/src/EmitInstructionA64.h | 4 + CodeGen/src/EmitInstructionX64.cpp | 11 +- CodeGen/src/EmitInstructionX64.h | 5 +- CodeGen/src/IrAnalysis.cpp | 26 +- CodeGen/src/IrBuilder.cpp | 21 +- CodeGen/src/IrCallWrapperX64.cpp | 400 ++++++++++++++ CodeGen/src/IrDump.cpp | 54 +- CodeGen/src/IrLoweringA64.cpp | 505 +++++++++++++++++- CodeGen/src/IrLoweringA64.h | 10 +- CodeGen/src/IrLoweringX64.cpp | 309 ++++++----- CodeGen/src/IrLoweringX64.h | 3 +- CodeGen/src/IrRegAllocA64.cpp | 174 ++++++ CodeGen/src/IrRegAllocA64.h | 55 ++ CodeGen/src/IrRegAllocX64.cpp | 68 ++- CodeGen/src/IrTranslateBuiltins.cpp | 68 ++- CodeGen/src/IrTranslation.cpp | 71 ++- CodeGen/src/IrUtils.cpp | 20 + CodeGen/src/NativeState.cpp | 2 + CodeGen/src/NativeState.h | 4 +- CodeGen/src/OptimizeConstProp.cpp | 29 +- Compiler/src/Compiler.cpp | 6 +- Sources.cmake | 8 +- VM/src/lbuiltins.cpp | 17 - fuzz/format.cpp | 1 + fuzz/linter.cpp | 15 +- fuzz/proto.cpp | 16 +- fuzz/typeck.cpp | 17 +- tests/AssemblyBuilderA64.test.cpp | 98 +++- tests/Autocomplete.test.cpp | 2 - tests/Compiler.test.cpp | 8 - tests/Conformance.test.cpp | 2 +- tests/ConstraintGraphBuilderFixture.cpp | 3 +- tests/IrBuilder.test.cpp | 278 +++++----- tests/IrCallWrapperX64.test.cpp | 484 +++++++++++++++++ tests/Lexer.test.cpp | 2 - tests/Linter.test.cpp | 4 - tests/Normalize.test.cpp | 1 - tests/Parser.test.cpp | 2 - tests/TypeInfer.aliases.test.cpp | 30 ++ tests/TypeInfer.functions.test.cpp | 3 - tests/TypeInfer.loops.test.cpp | 22 + tests/TypeInfer.oop.test.cpp | 25 + tests/TypeInfer.operators.test.cpp | 6 - tests/TypeInfer.provisional.test.cpp | 2 - tests/TypeInfer.tables.test.cpp | 10 +- tests/TypeInfer.test.cpp | 1 - tools/faillist.txt | 5 +- tools/natvis/CodeGen.natvis | 57 +- 86 files changed, 3543 insertions(+), 1218 deletions(-) create mode 100644 CodeGen/include/Luau/IrCallWrapperX64.h rename CodeGen/{src => include/Luau}/IrRegAllocX64.h (85%) create mode 100644 CodeGen/src/IrCallWrapperX64.cpp create mode 100644 CodeGen/src/IrRegAllocA64.cpp create mode 100644 CodeGen/src/IrRegAllocA64.h create mode 100644 tests/IrCallWrapperX64.test.cpp diff --git a/Analysis/include/Luau/ConstraintSolver.h b/Analysis/include/Luau/ConstraintSolver.h index e9e1e884a..2feee2368 100644 --- a/Analysis/include/Luau/ConstraintSolver.h +++ b/Analysis/include/Luau/ConstraintSolver.h @@ -53,7 +53,6 @@ struct ConstraintSolver NotNull builtinTypes; InternalErrorReporter iceReporter; NotNull normalizer; - NotNull reducer; // The entire set of constraints that the solver is trying to resolve. std::vector> constraints; NotNull rootScope; @@ -85,8 +84,7 @@ struct ConstraintSolver DcrLogger* logger; explicit ConstraintSolver(NotNull normalizer, NotNull rootScope, std::vector> constraints, - ModuleName moduleName, NotNull reducer, NotNull moduleResolver, std::vector requireCycles, - DcrLogger* logger); + ModuleName moduleName, NotNull moduleResolver, std::vector requireCycles, DcrLogger* logger); // Randomize the order in which to dispatch constraints void randomize(unsigned seed); @@ -219,6 +217,20 @@ struct ConstraintSolver void reportError(TypeError e); private: + + /** Helper used by tryDispatch(SubtypeConstraint) and + * tryDispatch(PackSubtypeConstraint) + * + * Attempts to unify subTy with superTy. If doing so would require unifying + * BlockedTypes, fail and block the constraint on those BlockedTypes. + * + * If unification fails, replace all free types with errorType. + * + * If unification succeeds, unblock every type changed by the unification. + */ + template + bool tryUnify(NotNull constraint, TID subTy, TID superTy); + /** * Marks a constraint as being blocked on a type or type pack. The constraint * solver will not attempt to dispatch blocked constraints until their diff --git a/Analysis/include/Luau/Normalize.h b/Analysis/include/Luau/Normalize.h index 15404707d..efcb51085 100644 --- a/Analysis/include/Luau/Normalize.h +++ b/Analysis/include/Luau/Normalize.h @@ -191,12 +191,8 @@ struct NormalizedClassType // this type may contain `error`. struct NormalizedFunctionType { - NormalizedFunctionType(); - bool isTop = false; - // TODO: Remove this wrapping optional when clipping - // FFlagLuauNegatedFunctionTypes. - std::optional parts; + TypeIds parts; void resetToNever(); void resetToTop(); diff --git a/Analysis/include/Luau/Scope.h b/Analysis/include/Luau/Scope.h index 745ea47ab..c3038face 100644 --- a/Analysis/include/Luau/Scope.h +++ b/Analysis/include/Luau/Scope.h @@ -55,11 +55,11 @@ struct Scope std::optional lookup(DefId def) const; std::optional> lookupEx(Symbol sym); - std::optional lookupType(const Name& name); - std::optional lookupImportedType(const Name& moduleAlias, const Name& name); + std::optional lookupType(const Name& name) const; + std::optional lookupImportedType(const Name& moduleAlias, const Name& name) const; std::unordered_map privateTypePackBindings; - std::optional lookupPack(const Name& name); + std::optional lookupPack(const Name& name) const; // WARNING: This function linearly scans for a string key of equal value! It is thus O(n**2) std::optional linearSearchForBinding(const std::string& name, bool traverseScopeChain = true) const; diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index 68161794a..7dae79c31 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -79,7 +79,8 @@ struct GlobalTypes // within a program are borrowed pointers into this set. struct TypeChecker { - explicit TypeChecker(const GlobalTypes& globals, ModuleResolver* resolver, NotNull builtinTypes, InternalErrorReporter* iceHandler); + explicit TypeChecker( + const ScopePtr& globalScope, ModuleResolver* resolver, NotNull builtinTypes, InternalErrorReporter* iceHandler); TypeChecker(const TypeChecker&) = delete; TypeChecker& operator=(const TypeChecker&) = delete; @@ -367,8 +368,7 @@ struct TypeChecker */ std::vector unTypePack(const ScopePtr& scope, TypePackId pack, size_t expectedLength, const Location& location); - // TODO: only const version of global scope should be available to make sure nothing else is modified inside of from users of TypeChecker - const GlobalTypes& globals; + const ScopePtr& globalScope; ModuleResolver* resolver; ModulePtr currentModule; diff --git a/Analysis/src/AstQuery.cpp b/Analysis/src/AstQuery.cpp index b0c3750b1..dc07a35ca 100644 --- a/Analysis/src/AstQuery.cpp +++ b/Analysis/src/AstQuery.cpp @@ -11,8 +11,6 @@ #include -LUAU_FASTFLAG(LuauCompleteTableKeysBetter); - namespace Luau { @@ -31,24 +29,12 @@ struct AutocompleteNodeFinder : public AstVisitor bool visit(AstExpr* expr) override { - if (FFlag::LuauCompleteTableKeysBetter) - { - if (expr->location.begin <= pos && pos <= expr->location.end) - { - ancestry.push_back(expr); - return true; - } - return false; - } - else + if (expr->location.begin <= pos && pos <= expr->location.end) { - if (expr->location.begin < pos && pos <= expr->location.end) - { - ancestry.push_back(expr); - return true; - } - return false; + ancestry.push_back(expr); + return true; } + return false; } bool visit(AstStat* stat) override diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index 1df4d3d75..3fdd93190 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -13,7 +13,6 @@ #include #include -LUAU_FASTFLAGVARIABLE(LuauCompleteTableKeysBetter, false); LUAU_FASTFLAGVARIABLE(LuauAutocompleteSkipNormalization, false); static const std::unordered_set kStatementStartingKeywords = { @@ -981,25 +980,14 @@ T* extractStat(const std::vector& ancestry) AstNode* grandParent = ancestry.size() >= 3 ? ancestry.rbegin()[2] : nullptr; AstNode* greatGrandParent = ancestry.size() >= 4 ? ancestry.rbegin()[3] : nullptr; - if (FFlag::LuauCompleteTableKeysBetter) - { - if (!grandParent) - return nullptr; - - if (T* t = parent->as(); t && grandParent->is()) - return t; + if (!grandParent) + return nullptr; - if (!greatGrandParent) - return nullptr; - } - else - { - if (T* t = parent->as(); t && parent->is()) - return t; + if (T* t = parent->as(); t && grandParent->is()) + return t; - if (!grandParent || !greatGrandParent) - return nullptr; - } + if (!greatGrandParent) + return nullptr; if (T* t = greatGrandParent->as(); t && grandParent->is() && parent->is() && isIdentifier(node)) return t; @@ -1533,23 +1521,20 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M { auto result = autocompleteProps(*module, typeArena, builtinTypes, *it, PropIndexType::Key, ancestry); - if (FFlag::LuauCompleteTableKeysBetter) - { - if (auto nodeIt = module->astExpectedTypes.find(node->asExpr())) - autocompleteStringSingleton(*nodeIt, !node->is(), result); + if (auto nodeIt = module->astExpectedTypes.find(node->asExpr())) + autocompleteStringSingleton(*nodeIt, !node->is(), result); - if (!key) + if (!key) + { + // If there is "no key," it may be that the user + // intends for the current token to be the key, but + // has yet to type the `=` sign. + // + // If the key type is a union of singleton strings, + // suggest those too. + if (auto ttv = get(follow(*it)); ttv && ttv->indexer) { - // If there is "no key," it may be that the user - // intends for the current token to be the key, but - // has yet to type the `=` sign. - // - // If the key type is a union of singleton strings, - // suggest those too. - if (auto ttv = get(follow(*it)); ttv && ttv->indexer) - { - autocompleteStringSingleton(ttv->indexer->indexType, false, result); - } + autocompleteStringSingleton(ttv->indexer->indexType, false, result); } } diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index d2ace49b9..2108b160f 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -15,8 +15,6 @@ #include -LUAU_FASTFLAGVARIABLE(LuauDeprecateTableGetnForeach, false) - /** FIXME: Many of these type definitions are not quite completely accurate. * * Some of them require richer generics than we have. For instance, we do not yet have a way to talk @@ -298,13 +296,10 @@ void registerBuiltinGlobals(TypeChecker& typeChecker, GlobalTypes& globals) ttv->props["freeze"] = makeProperty(makeFunction(arena, std::nullopt, {tabTy}, {tabTy}), "@luau/global/table.freeze"); ttv->props["clone"] = makeProperty(makeFunction(arena, std::nullopt, {tabTy}, {tabTy}), "@luau/global/table.clone"); - if (FFlag::LuauDeprecateTableGetnForeach) - { - ttv->props["getn"].deprecated = true; - ttv->props["getn"].deprecatedSuggestion = "#"; - ttv->props["foreach"].deprecated = true; - ttv->props["foreachi"].deprecated = true; - } + ttv->props["getn"].deprecated = true; + ttv->props["getn"].deprecatedSuggestion = "#"; + ttv->props["foreach"].deprecated = true; + ttv->props["foreachi"].deprecated = true; attachMagicFunction(ttv->props["pack"].type, magicFunctionPack); attachDcrMagicFunction(ttv->props["pack"].type, dcrMagicFunctionPack); @@ -401,15 +396,13 @@ void registerBuiltinGlobals(Frontend& frontend) ttv->props["freeze"] = makeProperty(makeFunction(arena, std::nullopt, {tabTy}, {tabTy}), "@luau/global/table.freeze"); ttv->props["clone"] = makeProperty(makeFunction(arena, std::nullopt, {tabTy}, {tabTy}), "@luau/global/table.clone"); - if (FFlag::LuauDeprecateTableGetnForeach) - { - ttv->props["getn"].deprecated = true; - ttv->props["getn"].deprecatedSuggestion = "#"; - ttv->props["foreach"].deprecated = true; - ttv->props["foreachi"].deprecated = true; - } + ttv->props["getn"].deprecated = true; + ttv->props["getn"].deprecatedSuggestion = "#"; + ttv->props["foreach"].deprecated = true; + ttv->props["foreachi"].deprecated = true; attachMagicFunction(ttv->props["pack"].type, magicFunctionPack); + attachDcrMagicFunction(ttv->props["pack"].type, dcrMagicFunctionPack); } attachMagicFunction(getGlobalBinding(globals, "require"), magicFunctionRequire); diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index d5853932e..d2bed2da3 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -226,12 +226,10 @@ void dump(ConstraintSolver* cs, ToStringOptions& opts) } ConstraintSolver::ConstraintSolver(NotNull normalizer, NotNull rootScope, std::vector> constraints, - ModuleName moduleName, NotNull reducer, NotNull moduleResolver, std::vector requireCycles, - DcrLogger* logger) + ModuleName moduleName, NotNull moduleResolver, std::vector requireCycles, DcrLogger* logger) : arena(normalizer->arena) , builtinTypes(normalizer->builtinTypes) , normalizer(normalizer) - , reducer(reducer) , constraints(std::move(constraints)) , rootScope(rootScope) , currentModuleName(std::move(moduleName)) @@ -458,40 +456,7 @@ bool ConstraintSolver::tryDispatch(const SubtypeConstraint& c, NotNullscope, Location{}, Covariant}; - u.useScopes = true; - - u.tryUnify(c.subType, c.superType); - - if (!u.blockedTypes.empty() || !u.blockedTypePacks.empty()) - { - for (TypeId bt : u.blockedTypes) - block(bt, constraint); - for (TypePackId btp : u.blockedTypePacks) - block(btp, constraint); - return false; - } - - if (const auto& e = hasUnificationTooComplex(u.errors)) - reportError(*e); - - if (!u.errors.empty()) - { - TypeId errorType = errorRecoveryType(); - u.tryUnify(c.subType, errorType); - u.tryUnify(c.superType, errorType); - } - - const auto [changedTypes, changedPacks] = u.log.getChanges(); - - u.log.commit(); - - unblock(changedTypes); - unblock(changedPacks); - - // unify(c.subType, c.superType, constraint->scope); - - return true; + return tryUnify(constraint, c.subType, c.superType); } bool ConstraintSolver::tryDispatch(const PackSubtypeConstraint& c, NotNull constraint, bool force) @@ -501,9 +466,7 @@ bool ConstraintSolver::tryDispatch(const PackSubtypeConstraint& c, NotNullscope); - - return true; + return tryUnify(constraint, c.subPack, c.superPack); } bool ConstraintSolver::tryDispatch(const GeneralizationConstraint& c, NotNull constraint, bool force) @@ -1117,7 +1080,7 @@ bool ConstraintSolver::tryDispatch(const TypeAliasExpansionConstraint& c, NotNul InstantiationQueuer queuer{constraint->scope, constraint->location, this}; queuer.traverse(target); - if (target->persistent) + if (target->persistent || target->owningArena != arena) { bindResult(target); return true; @@ -1335,8 +1298,6 @@ bool ConstraintSolver::tryDispatch(const HasPropConstraint& c, NotNullreduce(subjectType).value_or(subjectType); - auto [blocked, result] = lookupTableProp(subjectType, c.prop); if (!blocked.empty()) { @@ -1716,8 +1677,15 @@ bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const Iterabl if (auto iteratorTable = get(iteratorTy)) { - if (iteratorTable->state == TableState::Free) - return block_(iteratorTy); + /* + * We try not to dispatch IterableConstraints over free tables because + * it's possible that there are other constraints on the table that will + * clarify what we should do. + * + * We should eventually introduce a type family to talk about iteration. + */ + if (iteratorTable->state == TableState::Free && !force) + return block(iteratorTy, constraint); if (iteratorTable->indexer) { @@ -1957,14 +1925,14 @@ std::pair, std::optional> ConstraintSolver::lookupTa else if (auto utv = get(subjectType)) { std::vector blocked; - std::vector options; + std::set options; for (TypeId ty : utv) { auto [innerBlocked, innerResult] = lookupTableProp(ty, propName, seen); blocked.insert(blocked.end(), innerBlocked.begin(), innerBlocked.end()); if (innerResult) - options.push_back(*innerResult); + options.insert(*innerResult); } if (!blocked.empty()) @@ -1973,21 +1941,21 @@ std::pair, std::optional> ConstraintSolver::lookupTa if (options.empty()) return {{}, std::nullopt}; else if (options.size() == 1) - return {{}, options[0]}; + return {{}, *begin(options)}; else - return {{}, arena->addType(UnionType{std::move(options)})}; + return {{}, arena->addType(UnionType{std::vector(begin(options), end(options))})}; } else if (auto itv = get(subjectType)) { std::vector blocked; - std::vector options; + std::set options; for (TypeId ty : itv) { auto [innerBlocked, innerResult] = lookupTableProp(ty, propName, seen); blocked.insert(blocked.end(), innerBlocked.begin(), innerBlocked.end()); if (innerResult) - options.push_back(*innerResult); + options.insert(*innerResult); } if (!blocked.empty()) @@ -1996,14 +1964,61 @@ std::pair, std::optional> ConstraintSolver::lookupTa if (options.empty()) return {{}, std::nullopt}; else if (options.size() == 1) - return {{}, options[0]}; + return {{}, *begin(options)}; else - return {{}, arena->addType(IntersectionType{std::move(options)})}; + return {{}, arena->addType(IntersectionType{std::vector(begin(options), end(options))})}; } return {{}, std::nullopt}; } +static TypeId getErrorType(NotNull builtinTypes, TypeId) +{ + return builtinTypes->errorRecoveryType(); +} + +static TypePackId getErrorType(NotNull builtinTypes, TypePackId) +{ + return builtinTypes->errorRecoveryTypePack(); +} + +template +bool ConstraintSolver::tryUnify(NotNull constraint, TID subTy, TID superTy) +{ + Unifier u{normalizer, Mode::Strict, constraint->scope, Location{}, Covariant}; + u.useScopes = true; + + u.tryUnify(subTy, superTy); + + if (!u.blockedTypes.empty() || !u.blockedTypePacks.empty()) + { + for (TypeId bt : u.blockedTypes) + block(bt, constraint); + for (TypePackId btp : u.blockedTypePacks) + block(btp, constraint); + return false; + } + + if (const auto& e = hasUnificationTooComplex(u.errors)) + reportError(*e); + + if (!u.errors.empty()) + { + TID errorType = getErrorType(builtinTypes, TID{}); + u.tryUnify(subTy, errorType); + u.tryUnify(superTy, errorType); + } + + const auto [changedTypes, changedPacks] = u.log.getChanges(); + + u.log.commit(); + + unblock(changedTypes); + unblock(changedPacks); + + return true; +} + void ConstraintSolver::block_(BlockedConstraintId target, NotNull constraint) { blocked[target].push_back(constraint); diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index a50933b77..191e94f4d 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -435,8 +435,8 @@ Frontend::Frontend(FileResolver* fileResolver, ConfigResolver* configResolver, c , moduleResolverForAutocomplete(this) , globals(builtinTypes) , globalsForAutocomplete(builtinTypes) - , typeChecker(globals, &moduleResolver, builtinTypes, &iceHandler) - , typeCheckerForAutocomplete(globalsForAutocomplete, &moduleResolverForAutocomplete, builtinTypes, &iceHandler) + , typeChecker(globals.globalScope, &moduleResolver, builtinTypes, &iceHandler) + , typeCheckerForAutocomplete(globalsForAutocomplete.globalScope, &moduleResolverForAutocomplete, builtinTypes, &iceHandler) , configResolver(configResolver) , options(options) { @@ -970,8 +970,8 @@ ModulePtr check(const SourceModule& sourceModule, const std::vectorerrors = std::move(cgb.errors); - ConstraintSolver cs{NotNull{&normalizer}, NotNull(cgb.rootScope), borrowConstraints(cgb.constraints), sourceModule.name, - NotNull{result->reduction.get()}, moduleResolver, requireCycles, logger.get()}; + ConstraintSolver cs{NotNull{&normalizer}, NotNull(cgb.rootScope), borrowConstraints(cgb.constraints), sourceModule.name, moduleResolver, + requireCycles, logger.get()}; if (options.randomizeConstraintResolutionSeed) cs.randomize(*options.randomizeConstraintResolutionSeed); diff --git a/Analysis/src/Linter.cpp b/Analysis/src/Linter.cpp index f850bd3d1..d6aafda62 100644 --- a/Analysis/src/Linter.cpp +++ b/Analysis/src/Linter.cpp @@ -14,8 +14,6 @@ LUAU_FASTINTVARIABLE(LuauSuggestionDistance, 4) -LUAU_FASTFLAGVARIABLE(LuauImproveDeprecatedApiLint, false) - namespace Luau { @@ -2102,9 +2100,6 @@ class LintDeprecatedApi : AstVisitor public: LUAU_NOINLINE static void process(LintContext& context) { - if (!FFlag::LuauImproveDeprecatedApiLint && !context.module) - return; - LintDeprecatedApi pass{&context}; context.root->visit(&pass); } @@ -2122,8 +2117,7 @@ class LintDeprecatedApi : AstVisitor if (std::optional ty = context->getType(node->expr)) check(node, follow(*ty)); else if (AstExprGlobal* global = node->expr->as()) - if (FFlag::LuauImproveDeprecatedApiLint) - check(node->location, global->name, node->index); + check(node->location, global->name, node->index); return true; } @@ -2144,7 +2138,7 @@ class LintDeprecatedApi : AstVisitor if (prop != tty->props.end() && prop->second.deprecated) { // strip synthetic typeof() for builtin tables - if (FFlag::LuauImproveDeprecatedApiLint && tty->name && tty->name->compare(0, 7, "typeof(") == 0 && tty->name->back() == ')') + if (tty->name && tty->name->compare(0, 7, "typeof(") == 0 && tty->name->back() == ')') report(node->location, prop->second, tty->name->substr(7, tty->name->length() - 8).c_str(), node->index.value); else report(node->location, prop->second, tty->name ? tty->name->c_str() : nullptr, node->index.value); diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp index f383f5eae..7c56a4b8f 100644 --- a/Analysis/src/Normalize.cpp +++ b/Analysis/src/Normalize.cpp @@ -18,7 +18,6 @@ LUAU_FASTFLAGVARIABLE(DebugLuauCheckNormalizeInvariant, false) LUAU_FASTINTVARIABLE(LuauNormalizeIterationLimit, 1200); LUAU_FASTINTVARIABLE(LuauNormalizeCacheLimit, 100000); LUAU_FASTFLAGVARIABLE(LuauNegatedClassTypes, false); -LUAU_FASTFLAGVARIABLE(LuauNegatedFunctionTypes, false); LUAU_FASTFLAGVARIABLE(LuauNegatedTableTypes, false); LUAU_FASTFLAGVARIABLE(LuauNormalizeBlockedTypes, false); LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) @@ -202,26 +201,21 @@ bool NormalizedClassType::isNever() const return classes.empty(); } -NormalizedFunctionType::NormalizedFunctionType() - : parts(FFlag::LuauNegatedFunctionTypes ? std::optional{TypeIds{}} : std::nullopt) -{ -} - void NormalizedFunctionType::resetToTop() { isTop = true; - parts.emplace(); + parts.clear(); } void NormalizedFunctionType::resetToNever() { isTop = false; - parts.emplace(); + parts.clear(); } bool NormalizedFunctionType::isNever() const { - return !isTop && (!parts || parts->empty()); + return !isTop && parts.empty(); } NormalizedType::NormalizedType(NotNull builtinTypes) @@ -438,13 +432,10 @@ static bool isNormalizedThread(TypeId ty) static bool areNormalizedFunctions(const NormalizedFunctionType& tys) { - if (tys.parts) + for (TypeId ty : tys.parts) { - for (TypeId ty : *tys.parts) - { - if (!get(ty) && !get(ty)) - return false; - } + if (!get(ty) && !get(ty)) + return false; } return true; } @@ -1170,13 +1161,10 @@ std::optional Normalizer::unionOfFunctions(TypeId here, TypeId there) void Normalizer::unionFunctions(NormalizedFunctionType& heres, const NormalizedFunctionType& theres) { - if (FFlag::LuauNegatedFunctionTypes) - { - if (heres.isTop) - return; - if (theres.isTop) - heres.resetToTop(); - } + if (heres.isTop) + return; + if (theres.isTop) + heres.resetToTop(); if (theres.isNever()) return; @@ -1185,13 +1173,13 @@ void Normalizer::unionFunctions(NormalizedFunctionType& heres, const NormalizedF if (heres.isNever()) { - tmps.insert(theres.parts->begin(), theres.parts->end()); + tmps.insert(theres.parts.begin(), theres.parts.end()); heres.parts = std::move(tmps); return; } - for (TypeId here : *heres.parts) - for (TypeId there : *theres.parts) + for (TypeId here : heres.parts) + for (TypeId there : theres.parts) { if (std::optional fun = unionOfFunctions(here, there)) tmps.insert(*fun); @@ -1213,7 +1201,7 @@ void Normalizer::unionFunctionsWithFunction(NormalizedFunctionType& heres, TypeI } TypeIds tmps; - for (TypeId here : *heres.parts) + for (TypeId here : heres.parts) { if (std::optional fun = unionOfFunctions(here, there)) tmps.insert(*fun); @@ -1420,7 +1408,6 @@ bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, int ignor here.threads = there; else if (ptv->type == PrimitiveType::Function) { - LUAU_ASSERT(FFlag::LuauNegatedFunctionTypes); here.functions.resetToTop(); } else if (ptv->type == PrimitiveType::Table && FFlag::LuauNegatedTableTypes) @@ -1553,15 +1540,12 @@ std::optional Normalizer::negateNormal(const NormalizedType& her * arbitrary function types. Ordinary code can never form these kinds of * types, so we decline to negate them. */ - if (FFlag::LuauNegatedFunctionTypes) - { - if (here.functions.isNever()) - result.functions.resetToTop(); - else if (here.functions.isTop) - result.functions.resetToNever(); - else - return std::nullopt; - } + if (here.functions.isNever()) + result.functions.resetToTop(); + else if (here.functions.isTop) + result.functions.resetToNever(); + else + return std::nullopt; /* * It is not possible to negate an arbitrary table type, because function @@ -2390,15 +2374,15 @@ void Normalizer::intersectFunctionsWithFunction(NormalizedFunctionType& heres, T heres.isTop = false; - for (auto it = heres.parts->begin(); it != heres.parts->end();) + for (auto it = heres.parts.begin(); it != heres.parts.end();) { TypeId here = *it; if (get(here)) it++; else if (std::optional tmp = intersectionOfFunctions(here, there)) { - heres.parts->erase(it); - heres.parts->insert(*tmp); + heres.parts.erase(it); + heres.parts.insert(*tmp); return; } else @@ -2406,13 +2390,13 @@ void Normalizer::intersectFunctionsWithFunction(NormalizedFunctionType& heres, T } TypeIds tmps; - for (TypeId here : *heres.parts) + for (TypeId here : heres.parts) { if (std::optional tmp = unionSaturatedFunctions(here, there)) tmps.insert(*tmp); } - heres.parts->insert(there); - heres.parts->insert(tmps.begin(), tmps.end()); + heres.parts.insert(there); + heres.parts.insert(tmps.begin(), tmps.end()); } void Normalizer::intersectFunctions(NormalizedFunctionType& heres, const NormalizedFunctionType& theres) @@ -2426,7 +2410,7 @@ void Normalizer::intersectFunctions(NormalizedFunctionType& heres, const Normali } else { - for (TypeId there : *theres.parts) + for (TypeId there : theres.parts) intersectFunctionsWithFunction(heres, there); } } @@ -2621,10 +2605,7 @@ bool Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there) else if (ptv->type == PrimitiveType::Thread) here.threads = threads; else if (ptv->type == PrimitiveType::Function) - { - LUAU_ASSERT(FFlag::LuauNegatedFunctionTypes); here.functions = std::move(functions); - } else if (ptv->type == PrimitiveType::Table) { LUAU_ASSERT(FFlag::LuauNegatedTableTypes); @@ -2768,16 +2749,16 @@ TypeId Normalizer::typeFromNormal(const NormalizedType& norm) if (!get(norm.errors)) result.push_back(norm.errors); - if (FFlag::LuauNegatedFunctionTypes && norm.functions.isTop) + if (norm.functions.isTop) result.push_back(builtinTypes->functionType); else if (!norm.functions.isNever()) { - if (norm.functions.parts->size() == 1) - result.push_back(*norm.functions.parts->begin()); + if (norm.functions.parts.size() == 1) + result.push_back(*norm.functions.parts.begin()); else { std::vector parts; - parts.insert(parts.end(), norm.functions.parts->begin(), norm.functions.parts->end()); + parts.insert(parts.end(), norm.functions.parts.begin(), norm.functions.parts.end()); result.push_back(arena->addType(IntersectionType{std::move(parts)})); } } diff --git a/Analysis/src/Scope.cpp b/Analysis/src/Scope.cpp index f54ebe2a9..2de381be2 100644 --- a/Analysis/src/Scope.cpp +++ b/Analysis/src/Scope.cpp @@ -65,7 +65,7 @@ std::optional Scope::lookup(DefId def) const return std::nullopt; } -std::optional Scope::lookupType(const Name& name) +std::optional Scope::lookupType(const Name& name) const { const Scope* scope = this; while (true) @@ -85,7 +85,7 @@ std::optional Scope::lookupType(const Name& name) } } -std::optional Scope::lookupImportedType(const Name& moduleAlias, const Name& name) +std::optional Scope::lookupImportedType(const Name& moduleAlias, const Name& name) const { const Scope* scope = this; while (scope) @@ -110,7 +110,7 @@ std::optional Scope::lookupImportedType(const Name& moduleAlias, const return std::nullopt; } -std::optional Scope::lookupPack(const Name& name) +std::optional Scope::lookupPack(const Name& name) const { const Scope* scope = this; while (true) diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index ec71a583a..c7d30f437 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -2075,12 +2075,12 @@ struct TypeChecker2 fetch(builtinTypes->functionType); else if (!norm.functions.isNever()) { - if (norm.functions.parts->size() == 1) - fetch(norm.functions.parts->front()); + if (norm.functions.parts.size() == 1) + fetch(norm.functions.parts.front()); else { std::vector parts; - parts.insert(parts.end(), norm.functions.parts->begin(), norm.functions.parts->end()); + parts.insert(parts.end(), norm.functions.parts.begin(), norm.functions.parts.end()); fetch(testArena.addType(IntersectionType{std::move(parts)})); } } diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 48ff6a209..f47815588 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -26,7 +26,6 @@ #include LUAU_FASTFLAGVARIABLE(DebugLuauMagicTypes, false) -LUAU_FASTFLAGVARIABLE(LuauDontExtendUnsealedRValueTables, false) LUAU_FASTINTVARIABLE(LuauTypeInferRecursionLimit, 165) LUAU_FASTINTVARIABLE(LuauTypeInferIterationLimit, 20000) LUAU_FASTINTVARIABLE(LuauTypeInferTypePackLoopLimit, 5000) @@ -38,7 +37,6 @@ LUAU_FASTFLAGVARIABLE(LuauReturnAnyInsteadOfICE, false) // Eventually removed as LUAU_FASTFLAGVARIABLE(DebugLuauSharedSelf, false) LUAU_FASTFLAGVARIABLE(LuauTryhardAnd, false) LUAU_FASTFLAG(LuauInstantiateInSubtyping) -LUAU_FASTFLAGVARIABLE(LuauIntersectionTestForEquality, false) LUAU_FASTFLAG(LuauNegatedClassTypes) LUAU_FASTFLAGVARIABLE(LuauAllowIndexClassParameters, false) LUAU_FASTFLAG(LuauUninhabitedSubAnything2) @@ -228,8 +226,8 @@ GlobalTypes::GlobalTypes(NotNull builtinTypes) globalScope->addBuiltinTypeBinding("never", TypeFun{{}, builtinTypes->neverType}); } -TypeChecker::TypeChecker(const GlobalTypes& globals, ModuleResolver* resolver, NotNull builtinTypes, InternalErrorReporter* iceHandler) - : globals(globals) +TypeChecker::TypeChecker(const ScopePtr& globalScope, ModuleResolver* resolver, NotNull builtinTypes, InternalErrorReporter* iceHandler) + : globalScope(globalScope) , resolver(resolver) , builtinTypes(builtinTypes) , iceHandler(iceHandler) @@ -280,7 +278,7 @@ ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mo unifierState.counters.recursionLimit = FInt::LuauTypeInferRecursionLimit; unifierState.counters.iterationLimit = unifierIterationLimit ? *unifierIterationLimit : FInt::LuauTypeInferIterationLimit; - ScopePtr parentScope = environmentScope.value_or(globals.globalScope); + ScopePtr parentScope = environmentScope.value_or(globalScope); ScopePtr moduleScope = std::make_shared(parentScope); if (module.cyclic) @@ -1656,7 +1654,7 @@ void TypeChecker::prototype(const ScopePtr& scope, const AstStatTypeAlias& typea } else { - if (globals.globalScope->builtinTypeNames.contains(name)) + if (globalScope->builtinTypeNames.contains(name)) { reportError(typealias.location, DuplicateTypeDefinition{name}); duplicateTypeAliases.insert({typealias.exported, name}); @@ -2690,7 +2688,7 @@ TypeId TypeChecker::checkRelationalOperation( if (get(lhsType) || get(rhsType)) return booleanType; - if (FFlag::LuauIntersectionTestForEquality && isEquality) + if (isEquality) { // Unless either type is free or any, an equality comparison is only // valid when the intersection of the two operands is non-empty. @@ -3261,16 +3259,7 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex { return it->second.type; } - else if (!FFlag::LuauDontExtendUnsealedRValueTables && (lhsTable->state == TableState::Unsealed || lhsTable->state == TableState::Free)) - { - TypeId theType = freshType(scope); - Property& property = lhsTable->props[name]; - property.type = theType; - property.location = expr.indexLocation; - return theType; - } - else if (FFlag::LuauDontExtendUnsealedRValueTables && - ((ctx == ValueContext::LValue && lhsTable->state == TableState::Unsealed) || lhsTable->state == TableState::Free)) + else if ((ctx == ValueContext::LValue && lhsTable->state == TableState::Unsealed) || lhsTable->state == TableState::Free) { TypeId theType = freshType(scope); Property& property = lhsTable->props[name]; @@ -3391,16 +3380,7 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex { return it->second.type; } - else if (!FFlag::LuauDontExtendUnsealedRValueTables && (exprTable->state == TableState::Unsealed || exprTable->state == TableState::Free)) - { - TypeId resultType = freshType(scope); - Property& property = exprTable->props[value->value.data]; - property.type = resultType; - property.location = expr.index->location; - return resultType; - } - else if (FFlag::LuauDontExtendUnsealedRValueTables && - ((ctx == ValueContext::LValue && exprTable->state == TableState::Unsealed) || exprTable->state == TableState::Free)) + else if ((ctx == ValueContext::LValue && exprTable->state == TableState::Unsealed) || exprTable->state == TableState::Free) { TypeId resultType = freshType(scope); Property& property = exprTable->props[value->value.data]; @@ -3416,14 +3396,7 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex unify(indexType, indexer.indexType, scope, expr.index->location); return indexer.indexResultType; } - else if (!FFlag::LuauDontExtendUnsealedRValueTables && (exprTable->state == TableState::Unsealed || exprTable->state == TableState::Free)) - { - TypeId resultType = freshType(exprTable->level); - exprTable->indexer = TableIndexer{anyIfNonstrict(indexType), anyIfNonstrict(resultType)}; - return resultType; - } - else if (FFlag::LuauDontExtendUnsealedRValueTables && - ((ctx == ValueContext::LValue && exprTable->state == TableState::Unsealed) || exprTable->state == TableState::Free)) + else if ((ctx == ValueContext::LValue && exprTable->state == TableState::Unsealed) || exprTable->state == TableState::Free) { TypeId indexerType = freshType(exprTable->level); unify(indexType, indexerType, scope, expr.location); @@ -3439,13 +3412,7 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex * has no indexer, we have no idea if it will work so we just return any * and hope for the best. */ - if (FFlag::LuauDontExtendUnsealedRValueTables) - return anyType; - else - { - TypeId resultType = freshType(scope); - return resultType; - } + return anyType; } } @@ -5997,7 +5964,7 @@ void TypeChecker::resolve(const TypeGuardPredicate& typeguardP, RefinementMap& r if (!typeguardP.isTypeof) return addRefinement(refis, typeguardP.lvalue, errorRecoveryType(scope)); - auto typeFun = globals.globalScope->lookupType(typeguardP.kind); + auto typeFun = globalScope->lookupType(typeguardP.kind); if (!typeFun || !typeFun->typeParams.empty() || !typeFun->typePackParams.empty()) return addRefinement(refis, typeguardP.lvalue, errorRecoveryType(scope)); diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 5f01a6062..b748d115f 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -21,11 +21,9 @@ LUAU_FASTFLAGVARIABLE(LuauInstantiateInSubtyping, false) LUAU_FASTFLAGVARIABLE(LuauUninhabitedSubAnything2, false) LUAU_FASTFLAGVARIABLE(LuauMaintainScopesInUnifier, false) LUAU_FASTFLAGVARIABLE(LuauTransitiveSubtyping, false) -LUAU_FASTFLAGVARIABLE(LuauTinyUnifyNormalsFix, false) LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) LUAU_FASTFLAG(LuauNormalizeBlockedTypes) -LUAU_FASTFLAG(LuauNegatedFunctionTypes) LUAU_FASTFLAG(LuauNegatedClassTypes) LUAU_FASTFLAG(LuauNegatedTableTypes) @@ -615,8 +613,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool else if ((log.getMutable(superTy) || log.getMutable(superTy)) && log.getMutable(subTy)) tryUnifySingletons(subTy, superTy); - else if (auto ptv = get(superTy); - FFlag::LuauNegatedFunctionTypes && ptv && ptv->type == PrimitiveType::Function && get(subTy)) + else if (auto ptv = get(superTy); ptv && ptv->type == PrimitiveType::Function && get(subTy)) { // Ok. Do nothing. forall functions F, F <: function } @@ -1275,17 +1272,7 @@ void Unifier::tryUnifyNormalizedTypes( Unifier innerState = makeChildUnifier(); - if (FFlag::LuauTinyUnifyNormalsFix) - innerState.tryUnify(subTable, superTable); - else - { - if (get(superTable)) - innerState.tryUnifyWithMetatable(subTable, superTable, /* reversed */ false); - else if (get(subTable)) - innerState.tryUnifyWithMetatable(superTable, subTable, /* reversed */ true); - else - innerState.tryUnifyTables(subTable, superTable); - } + innerState.tryUnify(subTable, superTable); if (innerState.errors.empty()) { @@ -1304,7 +1291,7 @@ void Unifier::tryUnifyNormalizedTypes( { if (superNorm.functions.isNever()) return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); - for (TypeId superFun : *superNorm.functions.parts) + for (TypeId superFun : superNorm.functions.parts) { Unifier innerState = makeChildUnifier(); const FunctionType* superFtv = get(superFun); @@ -1343,7 +1330,7 @@ TypePackId Unifier::tryApplyOverloadedFunction(TypeId function, const Normalized std::optional result; const FunctionType* firstFun = nullptr; - for (TypeId overload : *overloads.parts) + for (TypeId overload : overloads.parts) { if (const FunctionType* ftv = get(overload)) { diff --git a/Ast/src/Lexer.cpp b/Ast/src/Lexer.cpp index dac3b95b6..75b4fe301 100644 --- a/Ast/src/Lexer.cpp +++ b/Ast/src/Lexer.cpp @@ -6,8 +6,6 @@ #include -LUAU_FASTFLAGVARIABLE(LuauFixInterpStringMid, false) - namespace Luau { @@ -642,9 +640,7 @@ Lexeme Lexer::readInterpolatedStringSection(Position start, Lexeme::Type formatT } consume(); - Lexeme lexemeOutput(Location(start, position()), FFlag::LuauFixInterpStringMid ? formatType : Lexeme::InterpStringBegin, - &buffer[startOffset], offset - startOffset - 1); - return lexemeOutput; + return Lexeme(Location(start, position()), formatType, &buffer[startOffset], offset - startOffset - 1); } default: diff --git a/CodeGen/include/Luau/AddressA64.h b/CodeGen/include/Luau/AddressA64.h index 2c852046c..2796ef708 100644 --- a/CodeGen/include/Luau/AddressA64.h +++ b/CodeGen/include/Luau/AddressA64.h @@ -3,6 +3,8 @@ #include "Luau/RegisterA64.h" +#include + namespace Luau { namespace CodeGen @@ -23,6 +25,10 @@ enum class AddressKindA64 : uint8_t struct AddressA64 { + // This is a little misleading since AddressA64 can encode offsets up to 1023*size where size depends on the load/store size + // For example, ldr x0, [reg+imm] is limited to 8 KB offsets assuming imm is divisible by 8, but loading into w0 reduces the range to 4 KB + static constexpr size_t kMaxOffset = 1023; + AddressA64(RegisterA64 base, int off = 0) : kind(AddressKindA64::imm) , base(base) @@ -30,7 +36,6 @@ struct AddressA64 , data(off) { LUAU_ASSERT(base.kind == KindA64::x || base == sp); - LUAU_ASSERT(off >= -256 && off < 4096); } AddressA64(RegisterA64 base, RegisterA64 offset) diff --git a/CodeGen/include/Luau/AssemblyBuilderA64.h b/CodeGen/include/Luau/AssemblyBuilderA64.h index 1190e9754..0c7387128 100644 --- a/CodeGen/include/Luau/AssemblyBuilderA64.h +++ b/CodeGen/include/Luau/AssemblyBuilderA64.h @@ -16,10 +16,15 @@ namespace CodeGen namespace A64 { +enum FeaturesA64 +{ + Feature_JSCVT = 1 << 0, +}; + class AssemblyBuilderA64 { public: - explicit AssemblyBuilderA64(bool logText); + explicit AssemblyBuilderA64(bool logText, unsigned int features = 0); ~AssemblyBuilderA64(); // Moves @@ -42,6 +47,7 @@ class AssemblyBuilderA64 // Note: some arithmetic instructions also have versions that update flags (ADDS etc) but we aren't using them atm void cmp(RegisterA64 src1, RegisterA64 src2); void cmp(RegisterA64 src1, uint16_t src2); + void csel(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, ConditionA64 cond); // Bitwise // Note: shifted-register support and bitfield operations are omitted for simplicity @@ -93,6 +99,36 @@ class AssemblyBuilderA64 // Address of code (label) void adr(RegisterA64 dst, Label& label); + // Floating-point scalar moves + void fmov(RegisterA64 dst, RegisterA64 src); + + // Floating-point scalar math + void fabs(RegisterA64 dst, RegisterA64 src); + void fadd(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); + void fdiv(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); + void fmul(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); + void fneg(RegisterA64 dst, RegisterA64 src); + void fsqrt(RegisterA64 dst, RegisterA64 src); + void fsub(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); + + // Floating-point rounding and conversions + void frinta(RegisterA64 dst, RegisterA64 src); + void frintm(RegisterA64 dst, RegisterA64 src); + void frintp(RegisterA64 dst, RegisterA64 src); + void fcvtzs(RegisterA64 dst, RegisterA64 src); + void fcvtzu(RegisterA64 dst, RegisterA64 src); + void scvtf(RegisterA64 dst, RegisterA64 src); + void ucvtf(RegisterA64 dst, RegisterA64 src); + + // Floating-point conversion to integer using JS rules (wrap around 2^32) and set Z flag + // note: this is part of ARM8.3 (JSCVT feature); support of this instruction needs to be checked at runtime + void fjcvtzs(RegisterA64 dst, RegisterA64 src); + + // Floating-point comparisons + void fcmp(RegisterA64 src1, RegisterA64 src2); + void fcmpz(RegisterA64 src); + void fcsel(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, ConditionA64 cond); + // Run final checks bool finalize(); @@ -121,6 +157,7 @@ class AssemblyBuilderA64 std::string text; const bool logText = false; + const unsigned int features = 0; // Maximum immediate argument to functions like add/sub/cmp static constexpr size_t kMaxImmediate = (1 << 12) - 1; @@ -134,13 +171,15 @@ class AssemblyBuilderA64 void placeR1(const char* name, RegisterA64 dst, RegisterA64 src, uint32_t op); void placeI12(const char* name, RegisterA64 dst, RegisterA64 src1, int src2, uint8_t op); void placeI16(const char* name, RegisterA64 dst, int src, uint8_t op, int shift = 0); - void placeA(const char* name, RegisterA64 dst, AddressA64 src, uint8_t op, uint8_t size); + void placeA(const char* name, RegisterA64 dst, AddressA64 src, uint8_t op, uint8_t size, int sizelog); void placeBC(const char* name, Label& label, uint8_t op, uint8_t cond); void placeBCR(const char* name, Label& label, uint8_t op, RegisterA64 cond); void placeBR(const char* name, RegisterA64 src, uint32_t op); void placeADR(const char* name, RegisterA64 src, uint8_t op); void placeADR(const char* name, RegisterA64 src, uint8_t op, Label& label); - void placeP(const char* name, RegisterA64 dst1, RegisterA64 dst2, AddressA64 src, uint8_t op, uint8_t size); + void placeP(const char* name, RegisterA64 dst1, RegisterA64 dst2, AddressA64 src, uint8_t op, uint8_t opc, int sizelog); + void placeCS(const char* name, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, ConditionA64 cond, uint8_t op, uint8_t opc); + void placeFCMP(const char* name, RegisterA64 src1, RegisterA64 src2, uint8_t op, uint8_t opc); void place(uint32_t word); @@ -164,6 +203,7 @@ class AssemblyBuilderA64 LUAU_NOINLINE void log(const char* opcode, RegisterA64 src, Label label); LUAU_NOINLINE void log(const char* opcode, RegisterA64 src); LUAU_NOINLINE void log(const char* opcode, Label label); + LUAU_NOINLINE void log(const char* opcode, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, ConditionA64 cond); LUAU_NOINLINE void log(Label label); LUAU_NOINLINE void log(RegisterA64 reg); LUAU_NOINLINE void log(AddressA64 addr); diff --git a/CodeGen/include/Luau/AssemblyBuilderX64.h b/CodeGen/include/Luau/AssemblyBuilderX64.h index 17076ed69..2b2a849c6 100644 --- a/CodeGen/include/Luau/AssemblyBuilderX64.h +++ b/CodeGen/include/Luau/AssemblyBuilderX64.h @@ -41,6 +41,7 @@ enum class ABIX64 class AssemblyBuilderX64 { public: + explicit AssemblyBuilderX64(bool logText, ABIX64 abi); explicit AssemblyBuilderX64(bool logText); ~AssemblyBuilderX64(); diff --git a/CodeGen/include/Luau/ConditionA64.h b/CodeGen/include/Luau/ConditionA64.h index 0beadad52..e94adbcf7 100644 --- a/CodeGen/include/Luau/ConditionA64.h +++ b/CodeGen/include/Luau/ConditionA64.h @@ -8,28 +8,45 @@ namespace CodeGen namespace A64 { +// See Table C1-1 on page C1-229 of Arm ARM for A-profile architecture enum class ConditionA64 { + // EQ: integer (equal), floating-point (equal) Equal, + // NE: integer (not equal), floating-point (not equal or unordered) NotEqual, + // CS: integer (carry set), floating-point (greater than, equal or unordered) CarrySet, + // CC: integer (carry clear), floating-point (less than) CarryClear, + // MI: integer (negative), floating-point (less than) Minus, + // PL: integer (positive or zero), floating-point (greater than, equal or unordered) Plus, + // VS: integer (overflow), floating-point (unordered) Overflow, + // VC: integer (no overflow), floating-point (ordered) NoOverflow, + // HI: integer (unsigned higher), floating-point (greater than, or unordered) UnsignedGreater, + // LS: integer (unsigned lower or same), floating-point (less than or equal) UnsignedLessEqual, + // GE: integer (signed greater than or equal), floating-point (greater than or equal) GreaterEqual, + // LT: integer (signed less than), floating-point (less than, or unordered) Less, + + // GT: integer (signed greater than), floating-point (greater than) Greater, + // LE: integer (signed less than or equal), floating-point (less than, equal or unordered) LessEqual, + // AL: always Always, Count diff --git a/CodeGen/include/Luau/IrCallWrapperX64.h b/CodeGen/include/Luau/IrCallWrapperX64.h new file mode 100644 index 000000000..b70c8da62 --- /dev/null +++ b/CodeGen/include/Luau/IrCallWrapperX64.h @@ -0,0 +1,82 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/AssemblyBuilderX64.h" +#include "Luau/IrData.h" +#include "Luau/OperandX64.h" +#include "Luau/RegisterX64.h" + +#include + +// TODO: call wrapper can be used to suggest target registers for ScopedRegX64 to compute data into argument registers directly + +namespace Luau +{ +namespace CodeGen +{ +namespace X64 +{ + +// When IrInst operands are used, current instruction index is required to track lifetime +// In all other calls it is ok to omit the argument +constexpr uint32_t kInvalidInstIdx = ~0u; + +struct IrRegAllocX64; +struct ScopedRegX64; + +struct CallArgument +{ + SizeX64 targetSize = SizeX64::none; + + OperandX64 source = noreg; + IrOp sourceOp; + + OperandX64 target = noreg; + bool candidate = true; +}; + +class IrCallWrapperX64 +{ +public: + IrCallWrapperX64(IrRegAllocX64& regs, AssemblyBuilderX64& build, uint32_t instIdx = kInvalidInstIdx); + + void addArgument(SizeX64 targetSize, OperandX64 source, IrOp sourceOp = {}); + void addArgument(SizeX64 targetSize, ScopedRegX64& scopedReg); + + void call(const OperandX64& func); + + IrRegAllocX64& regs; + AssemblyBuilderX64& build; + uint32_t instIdx = ~0u; + +private: + void assignTargetRegisters(); + void countRegisterUses(); + CallArgument* findNonInterferingArgument(); + bool interferesWithOperand(const OperandX64& op, RegisterX64 reg) const; + bool interferesWithActiveSources(const CallArgument& targetArg, int targetArgIndex) const; + bool interferesWithActiveTarget(RegisterX64 sourceReg) const; + void moveToTarget(CallArgument& arg); + void freeSourceRegisters(CallArgument& arg); + void renameRegister(RegisterX64& target, RegisterX64 reg, RegisterX64 replacement); + void renameSourceRegisters(RegisterX64 reg, RegisterX64 replacement); + RegisterX64 findConflictingTarget() const; + + int getRegisterUses(RegisterX64 reg) const; + void addRegisterUse(RegisterX64 reg); + void removeRegisterUse(RegisterX64 reg); + + static const int kMaxCallArguments = 6; + std::array args; + int argCount = 0; + + OperandX64 funcOp; + + // Internal counters for remaining register use counts + std::array gprUses; + std::array xmmUses; +}; + +} // namespace X64 +} // namespace CodeGen +} // namespace Luau diff --git a/CodeGen/include/Luau/IrData.h b/CodeGen/include/Luau/IrData.h index e8b2bc621..752160817 100644 --- a/CodeGen/include/Luau/IrData.h +++ b/CodeGen/include/Luau/IrData.h @@ -125,6 +125,26 @@ enum class IrCmd : uint8_t // A: double UNM_NUM, + // Round number to negative infinity (math.floor) + // A: double + FLOOR_NUM, + + // Round number to positive infinity (math.ceil) + // A: double + CEIL_NUM, + + // Round number to nearest integer number, rounding half-way cases away from zero (math.round) + // A: double + ROUND_NUM, + + // Get square root of the argument (math.sqrt) + // A: double + SQRT_NUM, + + // Get absolute value of the argument (math.abs) + // A: double + ABS_NUM, + // Compute Luau 'not' operation on destructured TValue // A: tag // B: double @@ -252,6 +272,7 @@ enum class IrCmd : uint8_t // A: Rn (where to store the result) // B: Rn (lhs) // C: Rn or Kn (rhs) + // D: int (TMS enum with arithmetic type) DO_ARITH, // Get length of a TValue of any type @@ -382,54 +403,53 @@ enum class IrCmd : uint8_t // C: Rn (source start) // D: int (count or -1 to assign values up to stack top) // E: unsigned int (table index to start from) - LOP_SETLIST, + SETLIST, // Call specified function // A: Rn (function, followed by arguments) // B: int (argument count or -1 to use all arguments up to stack top) // C: int (result count or -1 to preserve all results and adjust stack top) // Note: return values are placed starting from Rn specified in 'A' - LOP_CALL, + CALL, // Return specified values from the function // A: Rn (value start) // B: int (result count or -1 to return all values up to stack top) - LOP_RETURN, + RETURN, // Adjust loop variables for one iteration of a generic for loop, jump back to the loop header if loop needs to continue // A: Rn (loop variable start, updates Rn+2 and 'B' number of registers starting from Rn+3) // B: int (loop variable count, if more than 2, registers starting from Rn+5 are set to nil) // C: block (repeat) // D: block (exit) - LOP_FORGLOOP, + FORGLOOP, // Handle LOP_FORGLOOP fallback when variable being iterated is not a table - // A: unsigned int (bytecode instruction index) - // B: Rn (loop state start, updates Rn+2 and 'C' number of registers starting from Rn+3) - // C: int (loop variable count and a MSB set when it's an ipairs-like iteration loop) - // D: block (repeat) - // E: block (exit) - LOP_FORGLOOP_FALLBACK, + // A: Rn (loop state start, updates Rn+2 and 'B' number of registers starting from Rn+3) + // B: int (loop variable count and a MSB set when it's an ipairs-like iteration loop) + // C: block (repeat) + // D: block (exit) + FORGLOOP_FALLBACK, // Fallback for generic for loop preparation when iterating over builtin pairs/ipairs // It raises an error if 'B' register is not a function // A: unsigned int (bytecode instruction index) // B: Rn // C: block (forgloop location) - LOP_FORGPREP_XNEXT_FALLBACK, + FORGPREP_XNEXT_FALLBACK, // Perform `and` or `or` operation (selecting lhs or rhs based on whether the lhs is truthy) and put the result into target register // A: Rn (target) // B: Rn (lhs) // C: Rn or Kn (rhs) - LOP_AND, - LOP_ANDK, - LOP_OR, - LOP_ORK, + AND, + ANDK, + OR, + ORK, // Increment coverage data (saturating 24 bit add) // A: unsigned int (bytecode instruction index) - LOP_COVERAGE, + COVERAGE, // Operations that have a translation, but use a full instruction fallback @@ -676,6 +696,14 @@ struct IrFunction return instructions[op.index]; } + IrInst* asInstOp(IrOp op) + { + if (op.kind == IrOpKind::Inst) + return &instructions[op.index]; + + return nullptr; + } + IrConst& constOp(IrOp op) { LUAU_ASSERT(op.kind == IrOpKind::Constant); diff --git a/CodeGen/src/IrRegAllocX64.h b/CodeGen/include/Luau/IrRegAllocX64.h similarity index 85% rename from CodeGen/src/IrRegAllocX64.h rename to CodeGen/include/Luau/IrRegAllocX64.h index 497bb035c..c2486faf8 100644 --- a/CodeGen/src/IrRegAllocX64.h +++ b/CodeGen/include/Luau/IrRegAllocX64.h @@ -24,12 +24,17 @@ struct IrRegAllocX64 RegisterX64 allocGprRegOrReuse(SizeX64 preferredSize, uint32_t index, std::initializer_list oprefs); RegisterX64 allocXmmRegOrReuse(uint32_t index, std::initializer_list oprefs); - RegisterX64 takeGprReg(RegisterX64 reg); + RegisterX64 takeReg(RegisterX64 reg); void freeReg(RegisterX64 reg); void freeLastUseReg(IrInst& target, uint32_t index); void freeLastUseRegs(const IrInst& inst, uint32_t index); + bool isLastUseReg(const IrInst& target, uint32_t index) const; + + bool shouldFreeGpr(RegisterX64 reg) const; + + void assertFree(RegisterX64 reg) const; void assertAllFree() const; IrFunction& function; @@ -51,6 +56,8 @@ struct ScopedRegX64 void alloc(SizeX64 size); void free(); + RegisterX64 release(); + IrRegAllocX64& owner; RegisterX64 reg; }; diff --git a/CodeGen/include/Luau/IrUtils.h b/CodeGen/include/Luau/IrUtils.h index 0fc140250..6e73e47a6 100644 --- a/CodeGen/include/Luau/IrUtils.h +++ b/CodeGen/include/Luau/IrUtils.h @@ -99,10 +99,10 @@ inline bool isBlockTerminator(IrCmd cmd) case IrCmd::JUMP_CMP_NUM: case IrCmd::JUMP_CMP_ANY: case IrCmd::JUMP_SLOT_MATCH: - case IrCmd::LOP_RETURN: - case IrCmd::LOP_FORGLOOP: - case IrCmd::LOP_FORGLOOP_FALLBACK: - case IrCmd::LOP_FORGPREP_XNEXT_FALLBACK: + case IrCmd::RETURN: + case IrCmd::FORGLOOP: + case IrCmd::FORGLOOP_FALLBACK: + case IrCmd::FORGPREP_XNEXT_FALLBACK: case IrCmd::FALLBACK_FORGPREP: return true; default: @@ -137,6 +137,11 @@ inline bool hasResult(IrCmd cmd) case IrCmd::MIN_NUM: case IrCmd::MAX_NUM: case IrCmd::UNM_NUM: + case IrCmd::FLOOR_NUM: + case IrCmd::CEIL_NUM: + case IrCmd::ROUND_NUM: + case IrCmd::SQRT_NUM: + case IrCmd::ABS_NUM: case IrCmd::NOT_ANY: case IrCmd::TABLE_LEN: case IrCmd::NEW_TABLE: diff --git a/CodeGen/include/Luau/RegisterA64.h b/CodeGen/include/Luau/RegisterA64.h index 519e83fcf..242e8b793 100644 --- a/CodeGen/include/Luau/RegisterA64.h +++ b/CodeGen/include/Luau/RegisterA64.h @@ -17,6 +17,8 @@ enum class KindA64 : uint8_t none, w, // 32-bit GPR x, // 64-bit GPR + d, // 64-bit SIMD&FP scalar + q, // 128-bit SIMD&FP vector }; struct RegisterA64 @@ -105,6 +107,72 @@ constexpr RegisterA64 xzr{KindA64::x, 31}; constexpr RegisterA64 sp{KindA64::none, 31}; +constexpr RegisterA64 d0{KindA64::d, 0}; +constexpr RegisterA64 d1{KindA64::d, 1}; +constexpr RegisterA64 d2{KindA64::d, 2}; +constexpr RegisterA64 d3{KindA64::d, 3}; +constexpr RegisterA64 d4{KindA64::d, 4}; +constexpr RegisterA64 d5{KindA64::d, 5}; +constexpr RegisterA64 d6{KindA64::d, 6}; +constexpr RegisterA64 d7{KindA64::d, 7}; +constexpr RegisterA64 d8{KindA64::d, 8}; +constexpr RegisterA64 d9{KindA64::d, 9}; +constexpr RegisterA64 d10{KindA64::d, 10}; +constexpr RegisterA64 d11{KindA64::d, 11}; +constexpr RegisterA64 d12{KindA64::d, 12}; +constexpr RegisterA64 d13{KindA64::d, 13}; +constexpr RegisterA64 d14{KindA64::d, 14}; +constexpr RegisterA64 d15{KindA64::d, 15}; +constexpr RegisterA64 d16{KindA64::d, 16}; +constexpr RegisterA64 d17{KindA64::d, 17}; +constexpr RegisterA64 d18{KindA64::d, 18}; +constexpr RegisterA64 d19{KindA64::d, 19}; +constexpr RegisterA64 d20{KindA64::d, 20}; +constexpr RegisterA64 d21{KindA64::d, 21}; +constexpr RegisterA64 d22{KindA64::d, 22}; +constexpr RegisterA64 d23{KindA64::d, 23}; +constexpr RegisterA64 d24{KindA64::d, 24}; +constexpr RegisterA64 d25{KindA64::d, 25}; +constexpr RegisterA64 d26{KindA64::d, 26}; +constexpr RegisterA64 d27{KindA64::d, 27}; +constexpr RegisterA64 d28{KindA64::d, 28}; +constexpr RegisterA64 d29{KindA64::d, 29}; +constexpr RegisterA64 d30{KindA64::d, 30}; +constexpr RegisterA64 d31{KindA64::d, 31}; + +constexpr RegisterA64 q0{KindA64::q, 0}; +constexpr RegisterA64 q1{KindA64::q, 1}; +constexpr RegisterA64 q2{KindA64::q, 2}; +constexpr RegisterA64 q3{KindA64::q, 3}; +constexpr RegisterA64 q4{KindA64::q, 4}; +constexpr RegisterA64 q5{KindA64::q, 5}; +constexpr RegisterA64 q6{KindA64::q, 6}; +constexpr RegisterA64 q7{KindA64::q, 7}; +constexpr RegisterA64 q8{KindA64::q, 8}; +constexpr RegisterA64 q9{KindA64::q, 9}; +constexpr RegisterA64 q10{KindA64::q, 10}; +constexpr RegisterA64 q11{KindA64::q, 11}; +constexpr RegisterA64 q12{KindA64::q, 12}; +constexpr RegisterA64 q13{KindA64::q, 13}; +constexpr RegisterA64 q14{KindA64::q, 14}; +constexpr RegisterA64 q15{KindA64::q, 15}; +constexpr RegisterA64 q16{KindA64::q, 16}; +constexpr RegisterA64 q17{KindA64::q, 17}; +constexpr RegisterA64 q18{KindA64::q, 18}; +constexpr RegisterA64 q19{KindA64::q, 19}; +constexpr RegisterA64 q20{KindA64::q, 20}; +constexpr RegisterA64 q21{KindA64::q, 21}; +constexpr RegisterA64 q22{KindA64::q, 22}; +constexpr RegisterA64 q23{KindA64::q, 23}; +constexpr RegisterA64 q24{KindA64::q, 24}; +constexpr RegisterA64 q25{KindA64::q, 25}; +constexpr RegisterA64 q26{KindA64::q, 26}; +constexpr RegisterA64 q27{KindA64::q, 27}; +constexpr RegisterA64 q28{KindA64::q, 28}; +constexpr RegisterA64 q29{KindA64::q, 29}; +constexpr RegisterA64 q30{KindA64::q, 30}; +constexpr RegisterA64 q31{KindA64::q, 31}; + } // namespace A64 } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/AssemblyBuilderA64.cpp b/CodeGen/src/AssemblyBuilderA64.cpp index bedd27409..e7f50b142 100644 --- a/CodeGen/src/AssemblyBuilderA64.cpp +++ b/CodeGen/src/AssemblyBuilderA64.cpp @@ -21,8 +21,9 @@ static_assert(sizeof(textForCondition) / sizeof(textForCondition[0]) == size_t(C const unsigned kMaxAlign = 32; -AssemblyBuilderA64::AssemblyBuilderA64(bool logText) +AssemblyBuilderA64::AssemblyBuilderA64(bool logText, unsigned int features) : logText(logText) + , features(features) { data.resize(4096); dataPos = data.size(); // data is filled backwards @@ -39,6 +40,9 @@ AssemblyBuilderA64::~AssemblyBuilderA64() void AssemblyBuilderA64::mov(RegisterA64 dst, RegisterA64 src) { + LUAU_ASSERT(dst.kind == KindA64::w || dst.kind == KindA64::x || dst == sp); + LUAU_ASSERT(dst.kind == src.kind || (dst.kind == KindA64::x && src == sp) || (dst == sp && src.kind == KindA64::x)); + if (dst == sp || src == sp) placeR1("mov", dst, src, 0b00'100010'0'000000000000); else @@ -115,6 +119,13 @@ void AssemblyBuilderA64::cmp(RegisterA64 src1, uint16_t src2) placeI12("cmp", dst, src1, src2, 0b11'10001); } +void AssemblyBuilderA64::csel(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, ConditionA64 cond) +{ + LUAU_ASSERT(dst.kind == KindA64::x || dst.kind == KindA64::w); + + placeCS("csel", dst, src1, src2, cond, 0b11010'10'0, 0b00); +} + void AssemblyBuilderA64::and_(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2) { placeSR3("and", dst, src1, src2, 0b00'01010); @@ -157,54 +168,76 @@ void AssemblyBuilderA64::ror(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2 void AssemblyBuilderA64::clz(RegisterA64 dst, RegisterA64 src) { + LUAU_ASSERT(dst.kind == KindA64::w || dst.kind == KindA64::x); + LUAU_ASSERT(dst.kind == src.kind); + placeR1("clz", dst, src, 0b10'11010110'00000'00010'0); } void AssemblyBuilderA64::rbit(RegisterA64 dst, RegisterA64 src) { + LUAU_ASSERT(dst.kind == KindA64::w || dst.kind == KindA64::x); + LUAU_ASSERT(dst.kind == src.kind); + placeR1("rbit", dst, src, 0b10'11010110'00000'0000'00); } void AssemblyBuilderA64::ldr(RegisterA64 dst, AddressA64 src) { - LUAU_ASSERT(dst.kind == KindA64::x || dst.kind == KindA64::w); + LUAU_ASSERT(dst.kind == KindA64::x || dst.kind == KindA64::w || dst.kind == KindA64::d || dst.kind == KindA64::q); - placeA("ldr", dst, src, 0b11100001, 0b10 | uint8_t(dst.kind == KindA64::x)); + switch (dst.kind) + { + case KindA64::w: + placeA("ldr", dst, src, 0b11100001, 0b10, 2); + break; + case KindA64::x: + placeA("ldr", dst, src, 0b11100001, 0b11, 3); + break; + case KindA64::d: + placeA("ldr", dst, src, 0b11110001, 0b11, 3); + break; + case KindA64::q: + placeA("ldr", dst, src, 0b11110011, 0b00, 4); + break; + case KindA64::none: + LUAU_ASSERT(!"Unexpected register kind"); + } } void AssemblyBuilderA64::ldrb(RegisterA64 dst, AddressA64 src) { LUAU_ASSERT(dst.kind == KindA64::w); - placeA("ldrb", dst, src, 0b11100001, 0b00); + placeA("ldrb", dst, src, 0b11100001, 0b00, 2); } void AssemblyBuilderA64::ldrh(RegisterA64 dst, AddressA64 src) { LUAU_ASSERT(dst.kind == KindA64::w); - placeA("ldrh", dst, src, 0b11100001, 0b01); + placeA("ldrh", dst, src, 0b11100001, 0b01, 2); } void AssemblyBuilderA64::ldrsb(RegisterA64 dst, AddressA64 src) { LUAU_ASSERT(dst.kind == KindA64::x || dst.kind == KindA64::w); - placeA("ldrsb", dst, src, 0b11100010 | uint8_t(dst.kind == KindA64::w), 0b00); + placeA("ldrsb", dst, src, 0b11100010 | uint8_t(dst.kind == KindA64::w), 0b00, 0); } void AssemblyBuilderA64::ldrsh(RegisterA64 dst, AddressA64 src) { LUAU_ASSERT(dst.kind == KindA64::x || dst.kind == KindA64::w); - placeA("ldrsh", dst, src, 0b11100010 | uint8_t(dst.kind == KindA64::w), 0b01); + placeA("ldrsh", dst, src, 0b11100010 | uint8_t(dst.kind == KindA64::w), 0b01, 1); } void AssemblyBuilderA64::ldrsw(RegisterA64 dst, AddressA64 src) { LUAU_ASSERT(dst.kind == KindA64::x); - placeA("ldrsw", dst, src, 0b11100010, 0b10); + placeA("ldrsw", dst, src, 0b11100010, 0b10, 2); } void AssemblyBuilderA64::ldp(RegisterA64 dst1, RegisterA64 dst2, AddressA64 src) @@ -212,28 +245,44 @@ void AssemblyBuilderA64::ldp(RegisterA64 dst1, RegisterA64 dst2, AddressA64 src) LUAU_ASSERT(dst1.kind == KindA64::x || dst1.kind == KindA64::w); LUAU_ASSERT(dst1.kind == dst2.kind); - placeP("ldp", dst1, dst2, src, 0b101'0'010'1, 0b10 | uint8_t(dst1.kind == KindA64::x)); + placeP("ldp", dst1, dst2, src, 0b101'0'010'1, uint8_t(dst1.kind == KindA64::x) << 1, dst1.kind == KindA64::x ? 3 : 2); } void AssemblyBuilderA64::str(RegisterA64 src, AddressA64 dst) { - LUAU_ASSERT(src.kind == KindA64::x || src.kind == KindA64::w); + LUAU_ASSERT(src.kind == KindA64::x || src.kind == KindA64::w || src.kind == KindA64::d || src.kind == KindA64::q); - placeA("str", src, dst, 0b11100000, 0b10 | uint8_t(src.kind == KindA64::x)); + switch (src.kind) + { + case KindA64::w: + placeA("str", src, dst, 0b11100000, 0b10, 2); + break; + case KindA64::x: + placeA("str", src, dst, 0b11100000, 0b11, 3); + break; + case KindA64::d: + placeA("str", src, dst, 0b11110000, 0b11, 3); + break; + case KindA64::q: + placeA("str", src, dst, 0b11110010, 0b00, 4); + break; + case KindA64::none: + LUAU_ASSERT(!"Unexpected register kind"); + } } void AssemblyBuilderA64::strb(RegisterA64 src, AddressA64 dst) { LUAU_ASSERT(src.kind == KindA64::w); - placeA("strb", src, dst, 0b11100000, 0b00); + placeA("strb", src, dst, 0b11100000, 0b00, 2); } void AssemblyBuilderA64::strh(RegisterA64 src, AddressA64 dst) { LUAU_ASSERT(src.kind == KindA64::w); - placeA("strh", src, dst, 0b11100000, 0b01); + placeA("strh", src, dst, 0b11100000, 0b01, 2); } void AssemblyBuilderA64::stp(RegisterA64 src1, RegisterA64 src2, AddressA64 dst) @@ -241,7 +290,7 @@ void AssemblyBuilderA64::stp(RegisterA64 src1, RegisterA64 src2, AddressA64 dst) LUAU_ASSERT(src1.kind == KindA64::x || src1.kind == KindA64::w); LUAU_ASSERT(src1.kind == src2.kind); - placeP("stp", src1, src2, dst, 0b101'0'010'0, 0b10 | uint8_t(src1.kind == KindA64::x)); + placeP("stp", src1, src2, dst, 0b101'0'010'0, uint8_t(src1.kind == KindA64::x) << 1, src1.kind == KindA64::x ? 3 : 2); } void AssemblyBuilderA64::b(Label& label) @@ -318,6 +367,145 @@ void AssemblyBuilderA64::adr(RegisterA64 dst, Label& label) placeADR("adr", dst, 0b10000, label); } +void AssemblyBuilderA64::fmov(RegisterA64 dst, RegisterA64 src) +{ + LUAU_ASSERT(dst.kind == KindA64::d && src.kind == KindA64::d); + + placeR1("fmov", dst, src, 0b000'11110'01'1'0000'00'10000); +} + +void AssemblyBuilderA64::fabs(RegisterA64 dst, RegisterA64 src) +{ + LUAU_ASSERT(dst.kind == KindA64::d && src.kind == KindA64::d); + + placeR1("fabs", dst, src, 0b000'11110'01'1'0000'01'10000); +} + +void AssemblyBuilderA64::fadd(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2) +{ + LUAU_ASSERT(dst.kind == KindA64::d && src1.kind == KindA64::d && src2.kind == KindA64::d); + + placeR3("fadd", dst, src1, src2, 0b11110'01'1, 0b0010'10); +} + +void AssemblyBuilderA64::fdiv(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2) +{ + LUAU_ASSERT(dst.kind == KindA64::d && src1.kind == KindA64::d && src2.kind == KindA64::d); + + placeR3("fdiv", dst, src1, src2, 0b11110'01'1, 0b0001'10); +} + +void AssemblyBuilderA64::fmul(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2) +{ + LUAU_ASSERT(dst.kind == KindA64::d && src1.kind == KindA64::d && src2.kind == KindA64::d); + + placeR3("fmul", dst, src1, src2, 0b11110'01'1, 0b0000'10); +} + +void AssemblyBuilderA64::fneg(RegisterA64 dst, RegisterA64 src) +{ + LUAU_ASSERT(dst.kind == KindA64::d && src.kind == KindA64::d); + + placeR1("fneg", dst, src, 0b000'11110'01'1'0000'10'10000); +} + +void AssemblyBuilderA64::fsqrt(RegisterA64 dst, RegisterA64 src) +{ + LUAU_ASSERT(dst.kind == KindA64::d && src.kind == KindA64::d); + + placeR1("fsqrt", dst, src, 0b000'11110'01'1'0000'11'10000); +} + +void AssemblyBuilderA64::fsub(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2) +{ + LUAU_ASSERT(dst.kind == KindA64::d && src1.kind == KindA64::d && src2.kind == KindA64::d); + + placeR3("fsub", dst, src1, src2, 0b11110'01'1, 0b0011'10); +} + +void AssemblyBuilderA64::frinta(RegisterA64 dst, RegisterA64 src) +{ + LUAU_ASSERT(dst.kind == KindA64::d && src.kind == KindA64::d); + + placeR1("frinta", dst, src, 0b000'11110'01'1'001'100'10000); +} + +void AssemblyBuilderA64::frintm(RegisterA64 dst, RegisterA64 src) +{ + LUAU_ASSERT(dst.kind == KindA64::d && src.kind == KindA64::d); + + placeR1("frintm", dst, src, 0b000'11110'01'1'001'010'10000); +} + +void AssemblyBuilderA64::frintp(RegisterA64 dst, RegisterA64 src) +{ + LUAU_ASSERT(dst.kind == KindA64::d && src.kind == KindA64::d); + + placeR1("frintp", dst, src, 0b000'11110'01'1'001'001'10000); +} + +void AssemblyBuilderA64::fcvtzs(RegisterA64 dst, RegisterA64 src) +{ + LUAU_ASSERT(dst.kind == KindA64::w || dst.kind == KindA64::x); + LUAU_ASSERT(src.kind == KindA64::d); + + placeR1("fcvtzs", dst, src, 0b000'11110'01'1'11'000'000000); +} + +void AssemblyBuilderA64::fcvtzu(RegisterA64 dst, RegisterA64 src) +{ + LUAU_ASSERT(dst.kind == KindA64::w || dst.kind == KindA64::x); + LUAU_ASSERT(src.kind == KindA64::d); + + placeR1("fcvtzu", dst, src, 0b000'11110'01'1'11'001'000000); +} + +void AssemblyBuilderA64::scvtf(RegisterA64 dst, RegisterA64 src) +{ + LUAU_ASSERT(dst.kind == KindA64::d); + LUAU_ASSERT(src.kind == KindA64::w || src.kind == KindA64::x); + + placeR1("scvtf", dst, src, 0b000'11110'01'1'00'010'000000); +} + +void AssemblyBuilderA64::ucvtf(RegisterA64 dst, RegisterA64 src) +{ + LUAU_ASSERT(dst.kind == KindA64::d); + LUAU_ASSERT(src.kind == KindA64::w || src.kind == KindA64::x); + + placeR1("ucvtf", dst, src, 0b000'11110'01'1'00'011'000000); +} + +void AssemblyBuilderA64::fjcvtzs(RegisterA64 dst, RegisterA64 src) +{ + LUAU_ASSERT(dst.kind == KindA64::w); + LUAU_ASSERT(src.kind == KindA64::d); + LUAU_ASSERT(features & Feature_JSCVT); + + placeR1("fjcvtzs", dst, src, 0b000'11110'01'1'11'110'000000); +} + +void AssemblyBuilderA64::fcmp(RegisterA64 src1, RegisterA64 src2) +{ + LUAU_ASSERT(src1.kind == KindA64::d && src2.kind == KindA64::d); + + placeFCMP("fcmp", src1, src2, 0b11110'01'1, 0b00); +} + +void AssemblyBuilderA64::fcmpz(RegisterA64 src) +{ + LUAU_ASSERT(src.kind == KindA64::d); + + placeFCMP("fcmp", src, {src.kind, 0}, 0b11110'01'1, 0b01); +} + +void AssemblyBuilderA64::fcsel(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, ConditionA64 cond) +{ + LUAU_ASSERT(dst.kind == KindA64::d); + + placeCS("fcsel", dst, src1, src2, cond, 0b11110'01'1, 0b11); +} + bool AssemblyBuilderA64::finalize() { code.resize(codePos - code.data()); @@ -429,7 +617,7 @@ void AssemblyBuilderA64::placeR3(const char* name, RegisterA64 dst, RegisterA64 if (logText) log(name, dst, src1, src2); - LUAU_ASSERT(dst.kind == KindA64::w || dst.kind == KindA64::x); + LUAU_ASSERT(dst.kind == KindA64::w || dst.kind == KindA64::x || dst.kind == KindA64::d); LUAU_ASSERT(dst.kind == src1.kind && dst.kind == src2.kind); uint32_t sf = (dst.kind == KindA64::x) ? 0x80000000 : 0; @@ -443,10 +631,7 @@ void AssemblyBuilderA64::placeR1(const char* name, RegisterA64 dst, RegisterA64 if (logText) log(name, dst, src); - LUAU_ASSERT(dst.kind == KindA64::w || dst.kind == KindA64::x || dst == sp); - LUAU_ASSERT(dst.kind == src.kind || (dst.kind == KindA64::x && src == sp) || (dst == sp && src.kind == KindA64::x)); - - uint32_t sf = (dst.kind != KindA64::w) ? 0x80000000 : 0; + uint32_t sf = (dst.kind == KindA64::x || src.kind == KindA64::x) ? 0x80000000 : 0; place(dst.index | (src.index << 5) | (op << 10) | sf); commit(); @@ -482,7 +667,7 @@ void AssemblyBuilderA64::placeI16(const char* name, RegisterA64 dst, int src, ui commit(); } -void AssemblyBuilderA64::placeA(const char* name, RegisterA64 dst, AddressA64 src, uint8_t op, uint8_t size) +void AssemblyBuilderA64::placeA(const char* name, RegisterA64 dst, AddressA64 src, uint8_t op, uint8_t size, int sizelog) { if (logText) log(name, dst, src); @@ -490,8 +675,8 @@ void AssemblyBuilderA64::placeA(const char* name, RegisterA64 dst, AddressA64 sr switch (src.kind) { case AddressKindA64::imm: - if (src.data >= 0 && src.data % (1 << size) == 0) - place(dst.index | (src.base.index << 5) | ((src.data >> size) << 10) | (op << 22) | (1 << 24) | (size << 30)); + if (src.data >= 0 && (src.data >> sizelog) < 1024 && (src.data & ((1 << sizelog) - 1)) == 0) + place(dst.index | (src.base.index << 5) | ((src.data >> sizelog) << 10) | (op << 22) | (1 << 24) | (size << 30)); else if (src.data >= -256 && src.data <= 255) place(dst.index | (src.base.index << 5) | ((src.data & ((1 << 9) - 1)) << 12) | (op << 22) | (size << 30)); else @@ -566,16 +751,45 @@ void AssemblyBuilderA64::placeADR(const char* name, RegisterA64 dst, uint8_t op, log(name, dst, label); } -void AssemblyBuilderA64::placeP(const char* name, RegisterA64 src1, RegisterA64 src2, AddressA64 dst, uint8_t op, uint8_t size) +void AssemblyBuilderA64::placeP(const char* name, RegisterA64 src1, RegisterA64 src2, AddressA64 dst, uint8_t op, uint8_t opc, int sizelog) { if (logText) log(name, src1, src2, dst); LUAU_ASSERT(dst.kind == AddressKindA64::imm); - LUAU_ASSERT(dst.data >= -128 * (1 << size) && dst.data <= 127 * (1 << size)); - LUAU_ASSERT(dst.data % (1 << size) == 0); + LUAU_ASSERT(dst.data >= -128 * (1 << sizelog) && dst.data <= 127 * (1 << sizelog)); + LUAU_ASSERT(dst.data % (1 << sizelog) == 0); + + place(src1.index | (dst.base.index << 5) | (src2.index << 10) | (((dst.data >> sizelog) & 127) << 15) | (op << 22) | (opc << 30)); + commit(); +} + +void AssemblyBuilderA64::placeCS(const char* name, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, ConditionA64 cond, uint8_t op, uint8_t opc) +{ + if (logText) + log(name, dst, src1, src2, cond); + + LUAU_ASSERT(dst.kind == src1.kind && dst.kind == src2.kind); + + uint32_t sf = (dst.kind == KindA64::x) ? 0x80000000 : 0; + + place(dst.index | (src1.index << 5) | (opc << 10) | (codeForCondition[int(cond)] << 12) | (src2.index << 16) | (op << 21) | sf); + commit(); +} + +void AssemblyBuilderA64::placeFCMP(const char* name, RegisterA64 src1, RegisterA64 src2, uint8_t op, uint8_t opc) +{ + if (logText) + { + if (opc) + log(name, src1, 0); + else + log(name, src1, src2); + } - place(src1.index | (dst.base.index << 5) | (src2.index << 10) | (((dst.data >> size) & 127) << 15) | (op << 22) | (size << 31)); + LUAU_ASSERT(src1.kind == src2.kind); + + place((opc << 3) | (src1.index << 5) | (0b1000 << 10) | (src2.index << 16) | (op << 21)); commit(); } @@ -747,6 +961,19 @@ void AssemblyBuilderA64::log(const char* opcode, Label label) logAppend(" %-12s.L%d\n", opcode, label.id); } +void AssemblyBuilderA64::log(const char* opcode, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, ConditionA64 cond) +{ + logAppend(" %-12s", opcode); + log(dst); + text.append(","); + log(src1); + text.append(","); + log(src2); + text.append(","); + text.append(textForCondition[int(cond)] + 2); // skip b. + text.append("\n"); +} + void AssemblyBuilderA64::log(Label label) { logAppend(".L%d:\n", label.id); @@ -770,6 +997,14 @@ void AssemblyBuilderA64::log(RegisterA64 reg) logAppend("x%d", reg.index); break; + case KindA64::d: + logAppend("d%d", reg.index); + break; + + case KindA64::q: + logAppend("q%d", reg.index); + break; + case KindA64::none: if (reg.index == 31) text.append("sp"); diff --git a/CodeGen/src/AssemblyBuilderX64.cpp b/CodeGen/src/AssemblyBuilderX64.cpp index bf7889b89..0285c2a16 100644 --- a/CodeGen/src/AssemblyBuilderX64.cpp +++ b/CodeGen/src/AssemblyBuilderX64.cpp @@ -71,9 +71,9 @@ static ABIX64 getCurrentX64ABI() #endif } -AssemblyBuilderX64::AssemblyBuilderX64(bool logText) +AssemblyBuilderX64::AssemblyBuilderX64(bool logText, ABIX64 abi) : logText(logText) - , abi(getCurrentX64ABI()) + , abi(abi) { data.resize(4096); dataPos = data.size(); // data is filled backwards @@ -83,6 +83,11 @@ AssemblyBuilderX64::AssemblyBuilderX64(bool logText) codeEnd = code.data() + code.size(); } +AssemblyBuilderX64::AssemblyBuilderX64(bool logText) + : AssemblyBuilderX64(logText, getCurrentX64ABI()) +{ +} + AssemblyBuilderX64::~AssemblyBuilderX64() { LUAU_ASSERT(finalized); diff --git a/CodeGen/src/CodeGen.cpp b/CodeGen/src/CodeGen.cpp index 5ef5ba64f..b0cc8d9cd 100644 --- a/CodeGen/src/CodeGen.cpp +++ b/CodeGen/src/CodeGen.cpp @@ -43,6 +43,12 @@ #endif #endif +#if defined(__aarch64__) +#ifdef __APPLE__ +#include +#endif +#endif + LUAU_FASTFLAGVARIABLE(DebugCodegenNoOpt, false) namespace Luau @@ -209,7 +215,7 @@ static void lowerImpl(AssemblyBuilder& build, IrLowering& lowering, IrFunction& } } -[[maybe_unused]] static void lowerIr( +[[maybe_unused]] static bool lowerIr( X64::AssemblyBuilderX64& build, IrBuilder& ir, NativeState& data, ModuleHelpers& helpers, Proto* proto, AssemblyOptions options) { constexpr uint32_t kFunctionAlignment = 32; @@ -221,31 +227,21 @@ static void lowerImpl(AssemblyBuilder& build, IrLowering& lowering, IrFunction& X64::IrLoweringX64 lowering(build, helpers, data, ir.function); lowerImpl(build, lowering, ir.function, proto->bytecodeid, options); + + return true; } -[[maybe_unused]] static void lowerIr( +[[maybe_unused]] static bool lowerIr( A64::AssemblyBuilderA64& build, IrBuilder& ir, NativeState& data, ModuleHelpers& helpers, Proto* proto, AssemblyOptions options) { - if (A64::IrLoweringA64::canLower(ir.function)) - { - A64::IrLoweringA64 lowering(build, helpers, data, proto, ir.function); + if (!A64::IrLoweringA64::canLower(ir.function)) + return false; - lowerImpl(build, lowering, ir.function, proto->bytecodeid, options); - } - else - { - // TODO: This is only needed while we don't support all IR opcodes - // When we can't translate some parts of the function, we instead encode a dummy assembly sequence that hands off control to VM - // In the future we could return nullptr from assembleFunction and handle it because there may be other reasons for why we refuse to assemble. - Label start = build.setLabel(); + A64::IrLoweringA64 lowering(build, helpers, data, proto, ir.function); - build.mov(A64::x0, 1); // finish function in VM - build.ldr(A64::x1, A64::mem(A64::rNativeContext, offsetof(NativeContext, gateExit))); - build.br(A64::x1); + lowerImpl(build, lowering, ir.function, proto->bytecodeid, options); - for (int i = 0; i < proto->sizecode; i++) - ir.function.bcMapping[i].asmLocation = build.getLabelOffset(start); - } + return true; } template @@ -289,7 +285,13 @@ static NativeProto* assembleFunction(AssemblyBuilder& build, NativeState& data, constPropInBlockChains(ir); } - lowerIr(build, ir, data, helpers, proto, options); + if (!lowerIr(build, ir, data, helpers, proto, options)) + { + if (build.logText) + build.logAppend("; skipping (can't lower)\n\n"); + + return nullptr; + } if (build.logText) build.logAppend("\n"); @@ -345,6 +347,22 @@ static void onSetBreakpoint(lua_State* L, Proto* proto, int instruction) LUAU_ASSERT(!"native breakpoints are not implemented"); } +#if defined(__aarch64__) +static unsigned int getCpuFeaturesA64() +{ + unsigned int result = 0; + +#ifdef __APPLE__ + int jscvt = 0; + size_t jscvtLen = sizeof(jscvt); + if (sysctlbyname("hw.optional.arm.FEAT_JSCVT", &jscvt, &jscvtLen, nullptr, 0) == 0 && jscvt == 1) + result |= A64::Feature_JSCVT; +#endif + + return result; +} +#endif + bool isSupported() { #if !LUA_CUSTOM_EXECUTION @@ -374,8 +392,20 @@ bool isSupported() return true; #elif defined(__aarch64__) + if (LUA_EXTRA_SIZE != 1) + return false; + + if (sizeof(TValue) != 16) + return false; + + if (sizeof(LuaNode) != 32) + return false; + // TODO: A64 codegen does not generate correct unwind info at the moment so it requires longjmp instead of C++ exceptions - return bool(LUA_USE_LONGJMP); + if (!LUA_USE_LONGJMP) + return false; + + return true; #else return false; #endif @@ -447,7 +477,7 @@ void compile(lua_State* L, int idx) return; #if defined(__aarch64__) - A64::AssemblyBuilderA64 build(/* logText= */ false); + A64::AssemblyBuilderA64 build(/* logText= */ false, getCpuFeaturesA64()); #else X64::AssemblyBuilderX64 build(/* logText= */ false); #endif @@ -470,10 +500,15 @@ void compile(lua_State* L, int idx) // Skip protos that have been compiled during previous invocations of CodeGen::compile for (Proto* p : protos) if (p && getProtoExecData(p) == nullptr) - results.push_back(assembleFunction(build, *data, helpers, p, {})); + if (NativeProto* np = assembleFunction(build, *data, helpers, p, {})) + results.push_back(np); build.finalize(); + // If no functions were assembled, we don't need to allocate/copy executable pages for helpers + if (results.empty()) + return; + uint8_t* nativeData = nullptr; size_t sizeNativeData = 0; uint8_t* codeStart = nullptr; @@ -507,7 +542,7 @@ std::string getAssembly(lua_State* L, int idx, AssemblyOptions options) const TValue* func = luaA_toobject(L, idx); #if defined(__aarch64__) - A64::AssemblyBuilderA64 build(/* logText= */ options.includeAssembly); + A64::AssemblyBuilderA64 build(/* logText= */ options.includeAssembly, getCpuFeaturesA64()); #else X64::AssemblyBuilderX64 build(/* logText= */ options.includeAssembly); #endif @@ -527,10 +562,8 @@ std::string getAssembly(lua_State* L, int idx, AssemblyOptions options) for (Proto* p : protos) if (p) - { - NativeProto* nativeProto = assembleFunction(build, data, helpers, p, options); - destroyNativeProto(nativeProto); - } + if (NativeProto* np = assembleFunction(build, data, helpers, p, options)) + destroyNativeProto(np); build.finalize(); diff --git a/CodeGen/src/CodeGenA64.cpp b/CodeGen/src/CodeGenA64.cpp index 028b3327c..e7a1e2e21 100644 --- a/CodeGen/src/CodeGenA64.cpp +++ b/CodeGen/src/CodeGenA64.cpp @@ -100,6 +100,16 @@ void assembleHelpers(AssemblyBuilderA64& build, ModuleHelpers& helpers) build.logAppend("; exitNoContinueVm\n"); helpers.exitNoContinueVm = build.setLabel(); emitExit(build, /* continueInVm */ false); + + if (build.logText) + build.logAppend("; reentry\n"); + helpers.reentry = build.setLabel(); + emitReentry(build, helpers); + + if (build.logText) + build.logAppend("; interrupt\n"); + helpers.interrupt = build.setLabel(); + emitInterrupt(build); } } // namespace A64 diff --git a/CodeGen/src/CodeGenUtils.cpp b/CodeGen/src/CodeGenUtils.cpp index 26568c300..ae3dbd452 100644 --- a/CodeGen/src/CodeGenUtils.cpp +++ b/CodeGen/src/CodeGenUtils.cpp @@ -126,7 +126,89 @@ void callEpilogC(lua_State* L, int nresults, int n) L->top = (nresults == LUA_MULTRET) ? res : cip->top; } -const Instruction* returnFallback(lua_State* L, StkId ra, int n) +// Extracted as-is from lvmexecute.cpp with the exception of control flow (reentry) and removed interrupts/savedpc +Closure* callFallback(lua_State* L, StkId ra, StkId argtop, int nresults) +{ + // slow-path: not a function call + if (LUAU_UNLIKELY(!ttisfunction(ra))) + { + luaV_tryfuncTM(L, ra); + argtop++; // __call adds an extra self + } + + Closure* ccl = clvalue(ra); + + CallInfo* ci = incr_ci(L); + ci->func = ra; + ci->base = ra + 1; + ci->top = argtop + ccl->stacksize; // note: technically UB since we haven't reallocated the stack yet + ci->savedpc = NULL; + ci->flags = 0; + ci->nresults = nresults; + + L->base = ci->base; + L->top = argtop; + + // note: this reallocs stack, but we don't need to VM_PROTECT this + // this is because we're going to modify base/savedpc manually anyhow + // crucially, we can't use ra/argtop after this line + luaD_checkstack(L, ccl->stacksize); + + LUAU_ASSERT(ci->top <= L->stack_last); + + if (!ccl->isC) + { + Proto* p = ccl->l.p; + + // fill unused parameters with nil + StkId argi = L->top; + StkId argend = L->base + p->numparams; + while (argi < argend) + setnilvalue(argi++); // complete missing arguments + L->top = p->is_vararg ? argi : ci->top; + + // keep executing new function + ci->savedpc = p->code; + return ccl; + } + else + { + lua_CFunction func = ccl->c.f; + int n = func(L); + + // yield + if (n < 0) + return NULL; + + // ci is our callinfo, cip is our parent + CallInfo* ci = L->ci; + CallInfo* cip = ci - 1; + + // copy return values into parent stack (but only up to nresults!), fill the rest with nil + // note: in MULTRET context nresults starts as -1 so i != 0 condition never activates intentionally + StkId res = ci->func; + StkId vali = L->top - n; + StkId valend = L->top; + + int i; + for (i = nresults; i != 0 && vali < valend; i--) + setobj2s(L, res++, vali++); + while (i-- > 0) + setnilvalue(res++); + + // pop the stack frame + L->ci = cip; + L->base = cip->base; + L->top = (nresults == LUA_MULTRET) ? res : cip->top; + + // keep executing current function + LUAU_ASSERT(isLua(cip)); + return clvalue(cip->func); + } +} + +// Extracted as-is from lvmexecute.cpp with the exception of control flow (reentry) and removed interrupts +Closure* returnFallback(lua_State* L, StkId ra, int n) { // ci is our callinfo, cip is our parent CallInfo* ci = L->ci; @@ -159,8 +241,9 @@ const Instruction* returnFallback(lua_State* L, StkId ra, int n) return NULL; } + // keep executing new function LUAU_ASSERT(isLua(cip)); - return cip->savedpc; + return clvalue(cip->func); } } // namespace CodeGen diff --git a/CodeGen/src/CodeGenUtils.h b/CodeGen/src/CodeGenUtils.h index 5d37bfd16..6066a691c 100644 --- a/CodeGen/src/CodeGenUtils.h +++ b/CodeGen/src/CodeGenUtils.h @@ -16,7 +16,8 @@ void forgPrepXnextFallback(lua_State* L, TValue* ra, int pc); Closure* callProlog(lua_State* L, TValue* ra, StkId argtop, int nresults); void callEpilogC(lua_State* L, int nresults, int n); -const Instruction* returnFallback(lua_State* L, StkId ra, int n); +Closure* callFallback(lua_State* L, StkId ra, StkId argtop, int nresults); +Closure* returnFallback(lua_State* L, StkId ra, int n); } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/EmitBuiltinsX64.cpp b/CodeGen/src/EmitBuiltinsX64.cpp index d70b6ed8b..2e745cbf2 100644 --- a/CodeGen/src/EmitBuiltinsX64.cpp +++ b/CodeGen/src/EmitBuiltinsX64.cpp @@ -3,9 +3,10 @@ #include "Luau/AssemblyBuilderX64.h" #include "Luau/Bytecode.h" +#include "Luau/IrCallWrapperX64.h" +#include "Luau/IrRegAllocX64.h" #include "EmitCommonX64.h" -#include "IrRegAllocX64.h" #include "NativeState.h" #include "lstate.h" @@ -19,40 +20,11 @@ namespace CodeGen namespace X64 { -void emitBuiltinMathFloor(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) -{ - ScopedRegX64 tmp{regs, SizeX64::xmmword}; - build.vroundsd(tmp.reg, tmp.reg, luauRegValue(arg), RoundingModeX64::RoundToNegativeInfinity); - build.vmovsd(luauRegValue(ra), tmp.reg); -} - -void emitBuiltinMathCeil(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) -{ - ScopedRegX64 tmp{regs, SizeX64::xmmword}; - build.vroundsd(tmp.reg, tmp.reg, luauRegValue(arg), RoundingModeX64::RoundToPositiveInfinity); - build.vmovsd(luauRegValue(ra), tmp.reg); -} - -void emitBuiltinMathSqrt(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) -{ - ScopedRegX64 tmp{regs, SizeX64::xmmword}; - build.vsqrtsd(tmp.reg, tmp.reg, luauRegValue(arg)); - build.vmovsd(luauRegValue(ra), tmp.reg); -} - -void emitBuiltinMathAbs(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) -{ - ScopedRegX64 tmp{regs, SizeX64::xmmword}; - build.vmovsd(tmp.reg, luauRegValue(arg)); - build.vandpd(tmp.reg, tmp.reg, build.i64(~(1LL << 63))); - build.vmovsd(luauRegValue(ra), tmp.reg); -} - static void emitBuiltinMathSingleArgFunc(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int arg, int32_t offset) { - regs.assertAllFree(); - build.vmovsd(xmm0, luauRegValue(arg)); - build.call(qword[rNativeContext + offset]); + IrCallWrapperX64 callWrap(regs, build); + callWrap.addArgument(SizeX64::xmmword, luauRegValue(arg)); + callWrap.call(qword[rNativeContext + offset]); build.vmovsd(luauRegValue(ra), xmm0); } @@ -64,20 +36,10 @@ void emitBuiltinMathExp(IrRegAllocX64& regs, AssemblyBuilderX64& build, int npar void emitBuiltinMathFmod(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) { - regs.assertAllFree(); - build.vmovsd(xmm0, luauRegValue(arg)); - build.vmovsd(xmm1, qword[args + offsetof(TValue, value)]); - build.call(qword[rNativeContext + offsetof(NativeContext, libm_fmod)]); - - build.vmovsd(luauRegValue(ra), xmm0); -} - -void emitBuiltinMathPow(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) -{ - regs.assertAllFree(); - build.vmovsd(xmm0, luauRegValue(arg)); - build.vmovsd(xmm1, qword[args + offsetof(TValue, value)]); - build.call(qword[rNativeContext + offsetof(NativeContext, libm_pow)]); + IrCallWrapperX64 callWrap(regs, build); + callWrap.addArgument(SizeX64::xmmword, luauRegValue(arg)); + callWrap.addArgument(SizeX64::xmmword, qword[args + offsetof(TValue, value)]); + callWrap.call(qword[rNativeContext + offsetof(NativeContext, libm_fmod)]); build.vmovsd(luauRegValue(ra), xmm0); } @@ -129,10 +91,10 @@ void emitBuiltinMathTanh(IrRegAllocX64& regs, AssemblyBuilderX64& build, int npa void emitBuiltinMathAtan2(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) { - regs.assertAllFree(); - build.vmovsd(xmm0, luauRegValue(arg)); - build.vmovsd(xmm1, qword[args + offsetof(TValue, value)]); - build.call(qword[rNativeContext + offsetof(NativeContext, libm_atan2)]); + IrCallWrapperX64 callWrap(regs, build); + callWrap.addArgument(SizeX64::xmmword, luauRegValue(arg)); + callWrap.addArgument(SizeX64::xmmword, qword[args + offsetof(TValue, value)]); + callWrap.call(qword[rNativeContext + offsetof(NativeContext, libm_atan2)]); build.vmovsd(luauRegValue(ra), xmm0); } @@ -194,46 +156,23 @@ void emitBuiltinMathLog(IrRegAllocX64& regs, AssemblyBuilderX64& build, int npar void emitBuiltinMathLdexp(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) { - regs.assertAllFree(); - build.vmovsd(xmm0, luauRegValue(arg)); + ScopedRegX64 tmp{regs, SizeX64::qword}; + build.vcvttsd2si(tmp.reg, qword[args + offsetof(TValue, value)]); - if (build.abi == ABIX64::Windows) - build.vcvttsd2si(rArg2, qword[args + offsetof(TValue, value)]); - else - build.vcvttsd2si(rArg1, qword[args + offsetof(TValue, value)]); - - build.call(qword[rNativeContext + offsetof(NativeContext, libm_ldexp)]); + IrCallWrapperX64 callWrap(regs, build); + callWrap.addArgument(SizeX64::xmmword, luauRegValue(arg)); + callWrap.addArgument(SizeX64::qword, tmp); + callWrap.call(qword[rNativeContext + offsetof(NativeContext, libm_ldexp)]); build.vmovsd(luauRegValue(ra), xmm0); } -void emitBuiltinMathRound(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) -{ - ScopedRegX64 tmp0{regs, SizeX64::xmmword}; - ScopedRegX64 tmp1{regs, SizeX64::xmmword}; - ScopedRegX64 tmp2{regs, SizeX64::xmmword}; - - build.vmovsd(tmp0.reg, luauRegValue(arg)); - build.vandpd(tmp1.reg, tmp0.reg, build.f64x2(-0.0, -0.0)); - build.vmovsd(tmp2.reg, build.i64(0x3fdfffffffffffff)); // 0.49999999999999994 - build.vorpd(tmp1.reg, tmp1.reg, tmp2.reg); - build.vaddsd(tmp0.reg, tmp0.reg, tmp1.reg); - build.vroundsd(tmp0.reg, tmp0.reg, tmp0.reg, RoundingModeX64::RoundToZero); - - build.vmovsd(luauRegValue(ra), tmp0.reg); -} - void emitBuiltinMathFrexp(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) { - regs.assertAllFree(); - build.vmovsd(xmm0, luauRegValue(arg)); - - if (build.abi == ABIX64::Windows) - build.lea(rArg2, sTemporarySlot); - else - build.lea(rArg1, sTemporarySlot); - - build.call(qword[rNativeContext + offsetof(NativeContext, libm_frexp)]); + IrCallWrapperX64 callWrap(regs, build); + callWrap.addArgument(SizeX64::xmmword, luauRegValue(arg)); + callWrap.addArgument(SizeX64::qword, sTemporarySlot); + callWrap.call(qword[rNativeContext + offsetof(NativeContext, libm_frexp)]); build.vmovsd(luauRegValue(ra), xmm0); @@ -243,15 +182,10 @@ void emitBuiltinMathFrexp(IrRegAllocX64& regs, AssemblyBuilderX64& build, int np void emitBuiltinMathModf(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) { - regs.assertAllFree(); - build.vmovsd(xmm0, luauRegValue(arg)); - - if (build.abi == ABIX64::Windows) - build.lea(rArg2, sTemporarySlot); - else - build.lea(rArg1, sTemporarySlot); - - build.call(qword[rNativeContext + offsetof(NativeContext, libm_modf)]); + IrCallWrapperX64 callWrap(regs, build); + callWrap.addArgument(SizeX64::xmmword, luauRegValue(arg)); + callWrap.addArgument(SizeX64::qword, sTemporarySlot); + callWrap.call(qword[rNativeContext + offsetof(NativeContext, libm_modf)]); build.vmovsd(xmm1, qword[sTemporarySlot + 0]); build.vmovsd(luauRegValue(ra), xmm1); @@ -301,12 +235,10 @@ void emitBuiltinType(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams void emitBuiltinTypeof(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) { - regs.assertAllFree(); - - build.mov(rArg1, rState); - build.lea(rArg2, luauRegAddress(arg)); - - build.call(qword[rNativeContext + offsetof(NativeContext, luaT_objtypenamestr)]); + IrCallWrapperX64 callWrap(regs, build); + callWrap.addArgument(SizeX64::qword, rState); + callWrap.addArgument(SizeX64::qword, luauRegAddress(arg)); + callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaT_objtypenamestr)]); build.mov(luauRegValue(ra), rax); } @@ -328,22 +260,18 @@ void emitBuiltin(IrRegAllocX64& regs, AssemblyBuilderX64& build, int bfid, int r case LBF_MATH_MIN: case LBF_MATH_MAX: case LBF_MATH_CLAMP: - // These instructions are fully translated to IR - break; case LBF_MATH_FLOOR: - return emitBuiltinMathFloor(regs, build, nparams, ra, arg, argsOp, nresults); case LBF_MATH_CEIL: - return emitBuiltinMathCeil(regs, build, nparams, ra, arg, argsOp, nresults); case LBF_MATH_SQRT: - return emitBuiltinMathSqrt(regs, build, nparams, ra, arg, argsOp, nresults); + case LBF_MATH_POW: case LBF_MATH_ABS: - return emitBuiltinMathAbs(regs, build, nparams, ra, arg, argsOp, nresults); + case LBF_MATH_ROUND: + // These instructions are fully translated to IR + break; case LBF_MATH_EXP: return emitBuiltinMathExp(regs, build, nparams, ra, arg, argsOp, nresults); case LBF_MATH_FMOD: return emitBuiltinMathFmod(regs, build, nparams, ra, arg, argsOp, nresults); - case LBF_MATH_POW: - return emitBuiltinMathPow(regs, build, nparams, ra, arg, argsOp, nresults); case LBF_MATH_ASIN: return emitBuiltinMathAsin(regs, build, nparams, ra, arg, argsOp, nresults); case LBF_MATH_SIN: @@ -370,8 +298,6 @@ void emitBuiltin(IrRegAllocX64& regs, AssemblyBuilderX64& build, int bfid, int r return emitBuiltinMathLog(regs, build, nparams, ra, arg, argsOp, nresults); case LBF_MATH_LDEXP: return emitBuiltinMathLdexp(regs, build, nparams, ra, arg, argsOp, nresults); - case LBF_MATH_ROUND: - return emitBuiltinMathRound(regs, build, nparams, ra, arg, argsOp, nresults); case LBF_MATH_FREXP: return emitBuiltinMathFrexp(regs, build, nparams, ra, arg, argsOp, nresults); case LBF_MATH_MODF: diff --git a/CodeGen/src/EmitCommon.h b/CodeGen/src/EmitCommon.h index 3c41c271d..a71eafd4c 100644 --- a/CodeGen/src/EmitCommon.h +++ b/CodeGen/src/EmitCommon.h @@ -20,9 +20,16 @@ constexpr unsigned kOffsetOfInstructionC = 3; // Leaf functions that are placed in every module to perform common instruction sequences struct ModuleHelpers { + // A64/X64 Label exitContinueVm; Label exitNoContinueVm; + + // X64 Label continueCallInVm; + + // A64 + Label reentry; // x0: closure + Label interrupt; // x0: pc offset, x1: return address, x2: interrupt }; } // namespace CodeGen diff --git a/CodeGen/src/EmitCommonA64.cpp b/CodeGen/src/EmitCommonA64.cpp index 66810d379..2b4bbaba1 100644 --- a/CodeGen/src/EmitCommonA64.cpp +++ b/CodeGen/src/EmitCommonA64.cpp @@ -11,6 +11,11 @@ namespace CodeGen namespace A64 { +void emitUpdateBase(AssemblyBuilderA64& build) +{ + build.ldr(rBase, mem(rState, offsetof(lua_State, base))); +} + void emitExit(AssemblyBuilderA64& build, bool continueInVm) { build.mov(x0, continueInVm); @@ -18,56 +23,82 @@ void emitExit(AssemblyBuilderA64& build, bool continueInVm) build.br(x1); } -void emitUpdateBase(AssemblyBuilderA64& build) +void emitInterrupt(AssemblyBuilderA64& build) { - build.ldr(rBase, mem(rState, offsetof(lua_State, base))); -} + // x0 = pc offset + // x1 = return address in native code + // x2 = interrupt -void emitSetSavedPc(AssemblyBuilderA64& build, int pcpos) -{ - if (pcpos * sizeof(Instruction) <= AssemblyBuilderA64::kMaxImmediate) - { - build.add(x0, rCode, uint16_t(pcpos * sizeof(Instruction))); - } - else - { - build.mov(x0, pcpos * sizeof(Instruction)); - build.add(x0, rCode, x0); - } + // Stash return address in rBase; we need to reload rBase anyway + build.mov(rBase, x1); + // Update savedpc; required in case interrupt errors + build.add(x0, rCode, x0); build.ldr(x1, mem(rState, offsetof(lua_State, ci))); build.str(x0, mem(x1, offsetof(CallInfo, savedpc))); -} - -void emitInterrupt(AssemblyBuilderA64& build, int pcpos) -{ - Label skip; - - build.ldr(x2, mem(rState, offsetof(lua_State, global))); - build.ldr(x2, mem(x2, offsetof(global_State, cb.interrupt))); - build.cbz(x2, skip); - - emitSetSavedPc(build, pcpos + 1); // uses x0/x1 // Call interrupt - // TODO: This code should be outlined so that it can be shared by multiple interruptible instructions build.mov(x0, rState); build.mov(w1, -1); build.blr(x2); // Check if we need to exit + Label skip; build.ldrb(w0, mem(rState, offsetof(lua_State, status))); build.cbz(w0, skip); // L->ci->savedpc-- - build.ldr(x0, mem(rState, offsetof(lua_State, ci))); - build.ldr(x1, mem(x0, offsetof(CallInfo, savedpc))); - build.sub(x1, x1, sizeof(Instruction)); - build.str(x1, mem(x0, offsetof(CallInfo, savedpc))); + // note: recomputing this avoids having to stash x0 + build.ldr(x1, mem(rState, offsetof(lua_State, ci))); + build.ldr(x0, mem(x1, offsetof(CallInfo, savedpc))); + build.sub(x0, x0, sizeof(Instruction)); + build.str(x0, mem(x1, offsetof(CallInfo, savedpc))); emitExit(build, /* continueInVm */ false); build.setLabel(skip); + + // Return back to caller; rBase has stashed return address + build.mov(x0, rBase); + + emitUpdateBase(build); // interrupt may have reallocated stack + + build.br(x0); +} + +void emitReentry(AssemblyBuilderA64& build, ModuleHelpers& helpers) +{ + // x0 = closure object to reentry (equal to clvalue(L->ci->func)) + + // If the fallback requested an exit, we need to do this right away + build.cbz(x0, helpers.exitNoContinueVm); + + emitUpdateBase(build); + + // Need to update state of the current function before we jump away + build.ldr(x1, mem(x0, offsetof(Closure, l.p))); // cl->l.p aka proto + + build.mov(rClosure, x0); + build.ldr(rConstants, mem(x1, offsetof(Proto, k))); // proto->k + build.ldr(rCode, mem(x1, offsetof(Proto, code))); // proto->code + + // Get instruction index from instruction pointer + // To get instruction index from instruction pointer, we need to divide byte offset by 4 + // But we will actually need to scale instruction index by 8 back to byte offset later so it cancels out + build.ldr(x2, mem(rState, offsetof(lua_State, ci))); // L->ci + build.ldr(x2, mem(x2, offsetof(CallInfo, savedpc))); // L->ci->savedpc + build.sub(x2, x2, rCode); + build.add(x2, x2, x2); // TODO: this would not be necessary if we supported shifted register offsets in loads + + // We need to check if the new function can be executed natively + // TODO: This can be done earlier in the function flow, to reduce the JIT->VM transition penalty + build.ldr(x1, mem(x1, offsetofProtoExecData)); + build.cbz(x1, helpers.exitContinueVm); + + // Get new instruction location and jump to it + build.ldr(x1, mem(x1, offsetof(NativeProto, instTargets))); + build.ldr(x1, mem(x1, x2)); + build.br(x1); } } // namespace A64 diff --git a/CodeGen/src/EmitCommonA64.h b/CodeGen/src/EmitCommonA64.h index 251f6a351..5ca9c5586 100644 --- a/CodeGen/src/EmitCommonA64.h +++ b/CodeGen/src/EmitCommonA64.h @@ -11,7 +11,7 @@ // AArch64 ABI reminder: // Arguments: x0-x7, v0-v7 // Return: x0, v0 (or x8 that points to the address of the resulting structure) -// Volatile: x9-x14, v16-v31 ("caller-saved", any call may change them) +// Volatile: x9-x15, v16-v31 ("caller-saved", any call may change them) // Non-volatile: x19-x28, v8-v15 ("callee-saved", preserved after calls, only bottom half of SIMD registers is preserved!) // Reserved: x16-x18: reserved for linker/platform use; x29: frame pointer (unless omitted); x30: link register; x31: stack pointer @@ -25,52 +25,27 @@ struct NativeState; namespace A64 { -// Data that is very common to access is placed in non-volatile registers +// Data that is very common to access is placed in non-volatile registers: +// 1. Constant registers (only loaded during codegen entry) constexpr RegisterA64 rState = x19; // lua_State* L -constexpr RegisterA64 rBase = x20; // StkId base -constexpr RegisterA64 rNativeContext = x21; // NativeContext* context -constexpr RegisterA64 rConstants = x22; // TValue* k -constexpr RegisterA64 rClosure = x23; // Closure* cl -constexpr RegisterA64 rCode = x24; // Instruction* code +constexpr RegisterA64 rNativeContext = x20; // NativeContext* context + +// 2. Frame registers (reloaded when call frame changes; rBase is also reloaded after all calls that may reallocate stack) +constexpr RegisterA64 rConstants = x21; // TValue* k +constexpr RegisterA64 rClosure = x22; // Closure* cl +constexpr RegisterA64 rCode = x23; // Instruction* code +constexpr RegisterA64 rBase = x24; // StkId base // Native code is as stackless as the interpreter, so we can place some data on the stack once and have it accessible at any point // See CodeGenA64.cpp for layout constexpr unsigned kStackSize = 64; // 8 stashed registers -inline AddressA64 luauReg(int ri) -{ - return mem(rBase, ri * sizeof(TValue)); -} - -inline AddressA64 luauRegValue(int ri) -{ - return mem(rBase, ri * sizeof(TValue) + offsetof(TValue, value)); -} - -inline AddressA64 luauRegTag(int ri) -{ - return mem(rBase, ri * sizeof(TValue) + offsetof(TValue, tt)); -} - -inline AddressA64 luauConstant(int ki) -{ - return mem(rConstants, ki * sizeof(TValue)); -} - -inline AddressA64 luauConstantTag(int ki) -{ - return mem(rConstants, ki * sizeof(TValue) + offsetof(TValue, tt)); -} - -inline AddressA64 luauConstantValue(int ki) -{ - return mem(rConstants, ki * sizeof(TValue) + offsetof(TValue, value)); -} +void emitUpdateBase(AssemblyBuilderA64& build); +// TODO: Move these to CodeGenA64 so that they can't be accidentally called during lowering void emitExit(AssemblyBuilderA64& build, bool continueInVm); -void emitUpdateBase(AssemblyBuilderA64& build); -void emitSetSavedPc(AssemblyBuilderA64& build, int pcpos); // invalidates x0/x1 -void emitInterrupt(AssemblyBuilderA64& build, int pcpos); +void emitInterrupt(AssemblyBuilderA64& build); +void emitReentry(AssemblyBuilderA64& build, ModuleHelpers& helpers); } // namespace A64 } // namespace CodeGen diff --git a/CodeGen/src/EmitCommonX64.cpp b/CodeGen/src/EmitCommonX64.cpp index e9cfdc486..7db4068d0 100644 --- a/CodeGen/src/EmitCommonX64.cpp +++ b/CodeGen/src/EmitCommonX64.cpp @@ -2,7 +2,9 @@ #include "EmitCommonX64.h" #include "Luau/AssemblyBuilderX64.h" +#include "Luau/IrCallWrapperX64.h" #include "Luau/IrData.h" +#include "Luau/IrRegAllocX64.h" #include "CustomExecUtils.h" #include "NativeState.h" @@ -64,18 +66,19 @@ void jumpOnNumberCmp(AssemblyBuilderX64& build, RegisterX64 tmp, OperandX64 lhs, } } -void jumpOnAnyCmpFallback(AssemblyBuilderX64& build, int ra, int rb, IrCondition cond, Label& label) +void jumpOnAnyCmpFallback(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int rb, IrCondition cond, Label& label) { - build.mov(rArg1, rState); - build.lea(rArg2, luauRegAddress(ra)); - build.lea(rArg3, luauRegAddress(rb)); + IrCallWrapperX64 callWrap(regs, build); + callWrap.addArgument(SizeX64::qword, rState); + callWrap.addArgument(SizeX64::qword, luauRegAddress(ra)); + callWrap.addArgument(SizeX64::qword, luauRegAddress(rb)); if (cond == IrCondition::NotLessEqual || cond == IrCondition::LessEqual) - build.call(qword[rNativeContext + offsetof(NativeContext, luaV_lessequal)]); + callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaV_lessequal)]); else if (cond == IrCondition::NotLess || cond == IrCondition::Less) - build.call(qword[rNativeContext + offsetof(NativeContext, luaV_lessthan)]); + callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaV_lessthan)]); else if (cond == IrCondition::NotEqual || cond == IrCondition::Equal) - build.call(qword[rNativeContext + offsetof(NativeContext, luaV_equalval)]); + callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaV_equalval)]); else LUAU_ASSERT(!"Unsupported condition"); @@ -119,68 +122,66 @@ void convertNumberToIndexOrJump(AssemblyBuilderX64& build, RegisterX64 tmp, Regi build.jcc(ConditionX64::NotZero, label); } -void callArithHelper(AssemblyBuilderX64& build, int ra, int rb, OperandX64 c, TMS tm) +void callArithHelper(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int rb, OperandX64 c, TMS tm) { - if (build.abi == ABIX64::Windows) - build.mov(sArg5, tm); - else - build.mov(rArg5, tm); - - build.mov(rArg1, rState); - build.lea(rArg2, luauRegAddress(ra)); - build.lea(rArg3, luauRegAddress(rb)); - build.lea(rArg4, c); - build.call(qword[rNativeContext + offsetof(NativeContext, luaV_doarith)]); + IrCallWrapperX64 callWrap(regs, build); + callWrap.addArgument(SizeX64::qword, rState); + callWrap.addArgument(SizeX64::qword, luauRegAddress(ra)); + callWrap.addArgument(SizeX64::qword, luauRegAddress(rb)); + callWrap.addArgument(SizeX64::qword, c); + callWrap.addArgument(SizeX64::dword, tm); + callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaV_doarith)]); emitUpdateBase(build); } -void callLengthHelper(AssemblyBuilderX64& build, int ra, int rb) +void callLengthHelper(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int rb) { - build.mov(rArg1, rState); - build.lea(rArg2, luauRegAddress(ra)); - build.lea(rArg3, luauRegAddress(rb)); - build.call(qword[rNativeContext + offsetof(NativeContext, luaV_dolen)]); + IrCallWrapperX64 callWrap(regs, build); + callWrap.addArgument(SizeX64::qword, rState); + callWrap.addArgument(SizeX64::qword, luauRegAddress(ra)); + callWrap.addArgument(SizeX64::qword, luauRegAddress(rb)); + callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaV_dolen)]); emitUpdateBase(build); } -void callPrepareForN(AssemblyBuilderX64& build, int limit, int step, int init) +void callPrepareForN(IrRegAllocX64& regs, AssemblyBuilderX64& build, int limit, int step, int init) { - build.mov(rArg1, rState); - build.lea(rArg2, luauRegAddress(limit)); - build.lea(rArg3, luauRegAddress(step)); - build.lea(rArg4, luauRegAddress(init)); - build.call(qword[rNativeContext + offsetof(NativeContext, luaV_prepareFORN)]); + IrCallWrapperX64 callWrap(regs, build); + callWrap.addArgument(SizeX64::qword, rState); + callWrap.addArgument(SizeX64::qword, luauRegAddress(limit)); + callWrap.addArgument(SizeX64::qword, luauRegAddress(step)); + callWrap.addArgument(SizeX64::qword, luauRegAddress(init)); + callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaV_prepareFORN)]); } -void callGetTable(AssemblyBuilderX64& build, int rb, OperandX64 c, int ra) +void callGetTable(IrRegAllocX64& regs, AssemblyBuilderX64& build, int rb, OperandX64 c, int ra) { - build.mov(rArg1, rState); - build.lea(rArg2, luauRegAddress(rb)); - build.lea(rArg3, c); - build.lea(rArg4, luauRegAddress(ra)); - build.call(qword[rNativeContext + offsetof(NativeContext, luaV_gettable)]); + IrCallWrapperX64 callWrap(regs, build); + callWrap.addArgument(SizeX64::qword, rState); + callWrap.addArgument(SizeX64::qword, luauRegAddress(rb)); + callWrap.addArgument(SizeX64::qword, c); + callWrap.addArgument(SizeX64::qword, luauRegAddress(ra)); + callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaV_gettable)]); emitUpdateBase(build); } -void callSetTable(AssemblyBuilderX64& build, int rb, OperandX64 c, int ra) +void callSetTable(IrRegAllocX64& regs, AssemblyBuilderX64& build, int rb, OperandX64 c, int ra) { - build.mov(rArg1, rState); - build.lea(rArg2, luauRegAddress(rb)); - build.lea(rArg3, c); - build.lea(rArg4, luauRegAddress(ra)); - build.call(qword[rNativeContext + offsetof(NativeContext, luaV_settable)]); + IrCallWrapperX64 callWrap(regs, build); + callWrap.addArgument(SizeX64::qword, rState); + callWrap.addArgument(SizeX64::qword, luauRegAddress(rb)); + callWrap.addArgument(SizeX64::qword, c); + callWrap.addArgument(SizeX64::qword, luauRegAddress(ra)); + callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaV_settable)]); emitUpdateBase(build); } -// works for luaC_barriertable, luaC_barrierf -static void callBarrierImpl(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 object, int ra, Label& skip, int contextOffset) +void checkObjectBarrierConditions(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 object, int ra, Label& skip) { - LUAU_ASSERT(tmp != object); - // iscollectable(ra) build.cmp(luauRegTag(ra), LUA_TSTRING); build.jcc(ConditionX64::Less, skip); @@ -193,88 +194,52 @@ static void callBarrierImpl(AssemblyBuilderX64& build, RegisterX64 tmp, Register build.mov(tmp, luauRegValue(ra)); build.test(byte[tmp + offsetof(GCheader, marked)], bit2mask(WHITE0BIT, WHITE1BIT)); build.jcc(ConditionX64::Zero, skip); - - // TODO: even with re-ordering we have a chance of failure, we have a task to fix this in the future - if (object == rArg3) - { - LUAU_ASSERT(tmp != rArg2); - - if (rArg2 != object) - build.mov(rArg2, object); - - if (rArg3 != tmp) - build.mov(rArg3, tmp); - } - else - { - if (rArg3 != tmp) - build.mov(rArg3, tmp); - - if (rArg2 != object) - build.mov(rArg2, object); - } - - build.mov(rArg1, rState); - build.call(qword[rNativeContext + contextOffset]); } -void callBarrierTable(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 table, int ra, Label& skip) +void callBarrierObject(IrRegAllocX64& regs, AssemblyBuilderX64& build, RegisterX64 object, IrOp objectOp, int ra, Label& skip) { - callBarrierImpl(build, tmp, table, ra, skip, offsetof(NativeContext, luaC_barriertable)); + ScopedRegX64 tmp{regs, SizeX64::qword}; + checkObjectBarrierConditions(build, tmp.reg, object, ra, skip); + + IrCallWrapperX64 callWrap(regs, build); + callWrap.addArgument(SizeX64::qword, rState); + callWrap.addArgument(SizeX64::qword, object, objectOp); + callWrap.addArgument(SizeX64::qword, tmp); + callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaC_barrierf)]); } -void callBarrierObject(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 object, int ra, Label& skip) -{ - callBarrierImpl(build, tmp, object, ra, skip, offsetof(NativeContext, luaC_barrierf)); -} - -void callBarrierTableFast(AssemblyBuilderX64& build, RegisterX64 table, Label& skip) +void callBarrierTableFast(IrRegAllocX64& regs, AssemblyBuilderX64& build, RegisterX64 table, IrOp tableOp, Label& skip) { // isblack(obj2gco(t)) build.test(byte[table + offsetof(GCheader, marked)], bitmask(BLACKBIT)); build.jcc(ConditionX64::Zero, skip); - // Argument setup re-ordered to avoid conflicts with table register - if (table != rArg2) - build.mov(rArg2, table); - build.lea(rArg3, addr[rArg2 + offsetof(Table, gclist)]); - build.mov(rArg1, rState); - build.call(qword[rNativeContext + offsetof(NativeContext, luaC_barrierback)]); + IrCallWrapperX64 callWrap(regs, build); + callWrap.addArgument(SizeX64::qword, rState); + callWrap.addArgument(SizeX64::qword, table, tableOp); + callWrap.addArgument(SizeX64::qword, addr[table + offsetof(Table, gclist)]); + callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaC_barrierback)]); } -void callCheckGc(AssemblyBuilderX64& build, int pcpos, bool savepc, Label& skip) +void callCheckGc(IrRegAllocX64& regs, AssemblyBuilderX64& build, Label& skip) { - build.mov(rax, qword[rState + offsetof(lua_State, global)]); - build.mov(rdx, qword[rax + offsetof(global_State, totalbytes)]); - build.cmp(rdx, qword[rax + offsetof(global_State, GCthreshold)]); - build.jcc(ConditionX64::Below, skip); - - if (savepc) - emitSetSavedPc(build, pcpos + 1); + { + ScopedRegX64 tmp1{regs, SizeX64::qword}; + ScopedRegX64 tmp2{regs, SizeX64::qword}; - build.mov(rArg1, rState); - build.mov(dwordReg(rArg2), 1); - build.call(qword[rNativeContext + offsetof(NativeContext, luaC_step)]); + build.mov(tmp1.reg, qword[rState + offsetof(lua_State, global)]); + build.mov(tmp2.reg, qword[tmp1.reg + offsetof(global_State, totalbytes)]); + build.cmp(tmp2.reg, qword[tmp1.reg + offsetof(global_State, GCthreshold)]); + build.jcc(ConditionX64::Below, skip); + } + IrCallWrapperX64 callWrap(regs, build); + callWrap.addArgument(SizeX64::qword, rState); + callWrap.addArgument(SizeX64::dword, 1); + callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaC_step)]); emitUpdateBase(build); } -void callGetFastTmOrFallback(AssemblyBuilderX64& build, RegisterX64 table, TMS tm, Label& fallback) -{ - build.mov(rArg1, qword[table + offsetof(Table, metatable)]); - build.test(rArg1, rArg1); - build.jcc(ConditionX64::Zero, fallback); // no metatable - - build.test(byte[rArg1 + offsetof(Table, tmcache)], 1 << tm); - build.jcc(ConditionX64::NotZero, fallback); // no tag method - - // rArg1 is already prepared - build.mov(rArg2, tm); - build.mov(rax, qword[rState + offsetof(lua_State, global)]); - build.mov(rArg3, qword[rax + offsetof(global_State, tmname) + tm * sizeof(TString*)]); - build.call(qword[rNativeContext + offsetof(NativeContext, luaT_gettm)]); -} - void emitExit(AssemblyBuilderX64& build, bool continueInVm) { if (continueInVm) @@ -317,6 +282,8 @@ void emitInterrupt(AssemblyBuilderX64& build, int pcpos) build.mov(dwordReg(rArg2), -1); // function accepts 'int' here and using qword reg would've forced 8 byte constant here build.call(r8); + emitUpdateBase(build); // interrupt may have reallocated stack + // Check if we need to exit build.mov(al, byte[rState + offsetof(lua_State, status)]); build.test(al, al); diff --git a/CodeGen/src/EmitCommonX64.h b/CodeGen/src/EmitCommonX64.h index 6b6762550..85045ad5b 100644 --- a/CodeGen/src/EmitCommonX64.h +++ b/CodeGen/src/EmitCommonX64.h @@ -27,10 +27,13 @@ namespace CodeGen enum class IrCondition : uint8_t; struct NativeState; +struct IrOp; namespace X64 { +struct IrRegAllocX64; + // Data that is very common to access is placed in non-volatile registers constexpr RegisterX64 rState = r15; // lua_State* L constexpr RegisterX64 rBase = r14; // StkId base @@ -233,21 +236,20 @@ inline void jumpIfNodeKeyNotInExpectedSlot(AssemblyBuilderX64& build, RegisterX6 } void jumpOnNumberCmp(AssemblyBuilderX64& build, RegisterX64 tmp, OperandX64 lhs, OperandX64 rhs, IrCondition cond, Label& label); -void jumpOnAnyCmpFallback(AssemblyBuilderX64& build, int ra, int rb, IrCondition cond, Label& label); +void jumpOnAnyCmpFallback(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int rb, IrCondition cond, Label& label); void getTableNodeAtCachedSlot(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 node, RegisterX64 table, int pcpos); void convertNumberToIndexOrJump(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 numd, RegisterX64 numi, Label& label); -void callArithHelper(AssemblyBuilderX64& build, int ra, int rb, OperandX64 c, TMS tm); -void callLengthHelper(AssemblyBuilderX64& build, int ra, int rb); -void callPrepareForN(AssemblyBuilderX64& build, int limit, int step, int init); -void callGetTable(AssemblyBuilderX64& build, int rb, OperandX64 c, int ra); -void callSetTable(AssemblyBuilderX64& build, int rb, OperandX64 c, int ra); -void callBarrierTable(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 table, int ra, Label& skip); -void callBarrierObject(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 object, int ra, Label& skip); -void callBarrierTableFast(AssemblyBuilderX64& build, RegisterX64 table, Label& skip); -void callCheckGc(AssemblyBuilderX64& build, int pcpos, bool savepc, Label& skip); -void callGetFastTmOrFallback(AssemblyBuilderX64& build, RegisterX64 table, TMS tm, Label& fallback); +void callArithHelper(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int rb, OperandX64 c, TMS tm); +void callLengthHelper(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int rb); +void callPrepareForN(IrRegAllocX64& regs, AssemblyBuilderX64& build, int limit, int step, int init); +void callGetTable(IrRegAllocX64& regs, AssemblyBuilderX64& build, int rb, OperandX64 c, int ra); +void callSetTable(IrRegAllocX64& regs, AssemblyBuilderX64& build, int rb, OperandX64 c, int ra); +void checkObjectBarrierConditions(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 object, int ra, Label& skip); +void callBarrierObject(IrRegAllocX64& regs, AssemblyBuilderX64& build, RegisterX64 object, IrOp objectOp, int ra, Label& skip); +void callBarrierTableFast(IrRegAllocX64& regs, AssemblyBuilderX64& build, RegisterX64 table, IrOp tableOp, Label& skip); +void callCheckGc(IrRegAllocX64& regs, AssemblyBuilderX64& build, Label& skip); void emitExit(AssemblyBuilderX64& build, bool continueInVm); void emitUpdateBase(AssemblyBuilderX64& build); diff --git a/CodeGen/src/EmitInstructionA64.cpp b/CodeGen/src/EmitInstructionA64.cpp index 8289ee2ee..400ba77e0 100644 --- a/CodeGen/src/EmitInstructionA64.cpp +++ b/CodeGen/src/EmitInstructionA64.cpp @@ -23,35 +23,50 @@ void emitInstReturn(AssemblyBuilderA64& build, ModuleHelpers& helpers, int ra, i build.ldr(x3, mem(rNativeContext, offsetof(NativeContext, returnFallback))); build.blr(x3); - emitUpdateBase(build); + // reentry with x0=closure (NULL will trigger exit) + build.b(helpers.reentry); +} - // If the fallback requested an exit, we need to do this right away - build.cbz(x0, helpers.exitNoContinueVm); +void emitInstCall(AssemblyBuilderA64& build, ModuleHelpers& helpers, int ra, int nparams, int nresults) +{ + // argtop = (nparams == LUA_MULTRET) ? L->top : ra + 1 + nparams; + if (nparams == LUA_MULTRET) + build.ldr(x2, mem(rState, offsetof(lua_State, top))); + else + build.add(x2, rBase, uint16_t((ra + 1 + nparams) * sizeof(TValue))); - // Need to update state of the current function before we jump away - build.ldr(x1, mem(rState, offsetof(lua_State, ci))); // L->ci - build.ldr(x1, mem(x1, offsetof(CallInfo, func))); // L->ci->func - build.ldr(rClosure, mem(x1, offsetof(TValue, value.gc))); // L->ci->func->value.gc aka cl + // callFallback(L, ra, argtop, nresults) + build.mov(x0, rState); + build.add(x1, rBase, uint16_t(ra * sizeof(TValue))); + build.mov(w3, nresults); + build.ldr(x4, mem(rNativeContext, offsetof(NativeContext, callFallback))); + build.blr(x4); - build.ldr(x1, mem(rClosure, offsetof(Closure, l.p))); // cl->l.p aka proto + // reentry with x0=closure (NULL will trigger exit) + build.b(helpers.reentry); +} - build.ldr(rConstants, mem(x1, offsetof(Proto, k))); // proto->k - build.ldr(rCode, mem(x1, offsetof(Proto, code))); // proto->code +void emitInstGetImport(AssemblyBuilderA64& build, int ra, uint32_t aux) +{ + // luaV_getimport(L, cl->env, k, aux, /* propagatenil= */ false) + build.mov(x0, rState); + build.ldr(x1, mem(rClosure, offsetof(Closure, env))); + build.mov(x2, rConstants); + build.mov(w3, aux); + build.mov(w4, 0); + build.ldr(x5, mem(rNativeContext, offsetof(NativeContext, luaV_getimport))); + build.blr(x5); - // Get instruction index from instruction pointer - // To get instruction index from instruction pointer, we need to divide byte offset by 4 - // But we will actually need to scale instruction index by 8 back to byte offset later so it cancels out - build.sub(x2, x0, rCode); - build.add(x2, x2, x2); // TODO: this would not be necessary if we supported shifted register offsets in loads + emitUpdateBase(build); - // We need to check if the new function can be executed natively - build.ldr(x1, mem(x1, offsetofProtoExecData)); - build.cbz(x1, helpers.exitContinueVm); + // setobj2s(L, ra, L->top - 1) + build.ldr(x0, mem(rState, offsetof(lua_State, top))); + build.sub(x0, x0, sizeof(TValue)); + build.ldr(q0, x0); + build.str(q0, mem(rBase, ra * sizeof(TValue))); - // Get new instruction location and jump to it - build.ldr(x1, mem(x1, offsetof(NativeProto, instTargets))); - build.ldr(x1, mem(x1, x2)); - build.br(x1); + // L->top-- + build.str(x0, mem(rState, offsetof(lua_State, top))); } } // namespace A64 diff --git a/CodeGen/src/EmitInstructionA64.h b/CodeGen/src/EmitInstructionA64.h index 7f15d819b..278d8e8e3 100644 --- a/CodeGen/src/EmitInstructionA64.h +++ b/CodeGen/src/EmitInstructionA64.h @@ -1,6 +1,8 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include + namespace Luau { namespace CodeGen @@ -14,6 +16,8 @@ namespace A64 class AssemblyBuilderA64; void emitInstReturn(AssemblyBuilderA64& build, ModuleHelpers& helpers, int ra, int n); +void emitInstCall(AssemblyBuilderA64& build, ModuleHelpers& helpers, int ra, int nparams, int nresults); +void emitInstGetImport(AssemblyBuilderA64& build, int ra, uint32_t aux); } // namespace A64 } // namespace CodeGen diff --git a/CodeGen/src/EmitInstructionX64.cpp b/CodeGen/src/EmitInstructionX64.cpp index 649498f55..b645f9f7a 100644 --- a/CodeGen/src/EmitInstructionX64.cpp +++ b/CodeGen/src/EmitInstructionX64.cpp @@ -2,6 +2,7 @@ #include "EmitInstructionX64.h" #include "Luau/AssemblyBuilderX64.h" +#include "Luau/IrRegAllocX64.h" #include "CustomExecUtils.h" #include "EmitCommonX64.h" @@ -315,7 +316,7 @@ void emitInstReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers, int ra, i build.jmp(qword[rdx + rax * 2]); } -void emitInstSetList(AssemblyBuilderX64& build, Label& next, int ra, int rb, int count, uint32_t index) +void emitInstSetList(IrRegAllocX64& regs, AssemblyBuilderX64& build, Label& next, int ra, int rb, int count, uint32_t index) { OperandX64 last = index + count - 1; @@ -346,7 +347,7 @@ void emitInstSetList(AssemblyBuilderX64& build, Label& next, int ra, int rb, int Label skipResize; - RegisterX64 table = rax; + RegisterX64 table = regs.takeReg(rax); build.mov(table, luauRegValue(ra)); @@ -411,7 +412,7 @@ void emitInstSetList(AssemblyBuilderX64& build, Label& next, int ra, int rb, int build.setLabel(endLoop); } - callBarrierTableFast(build, table, next); + callBarrierTableFast(regs, build, table, {}, next); } void emitinstForGLoop(AssemblyBuilderX64& build, int ra, int aux, Label& loopRepeat, Label& loopExit) @@ -483,10 +484,8 @@ void emitinstForGLoop(AssemblyBuilderX64& build, int ra, int aux, Label& loopRep build.jcc(ConditionX64::NotZero, loopRepeat); } -void emitinstForGLoopFallback(AssemblyBuilderX64& build, int pcpos, int ra, int aux, Label& loopRepeat) +void emitinstForGLoopFallback(AssemblyBuilderX64& build, int ra, int aux, Label& loopRepeat) { - emitSetSavedPc(build, pcpos + 1); - build.mov(rArg1, rState); build.mov(dwordReg(rArg2), ra); build.mov(dwordReg(rArg3), aux); diff --git a/CodeGen/src/EmitInstructionX64.h b/CodeGen/src/EmitInstructionX64.h index 880c9fa4f..cc1b86456 100644 --- a/CodeGen/src/EmitInstructionX64.h +++ b/CodeGen/src/EmitInstructionX64.h @@ -15,12 +15,13 @@ namespace X64 { class AssemblyBuilderX64; +struct IrRegAllocX64; void emitInstCall(AssemblyBuilderX64& build, ModuleHelpers& helpers, int ra, int nparams, int nresults); void emitInstReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers, int ra, int actualResults); -void emitInstSetList(AssemblyBuilderX64& build, Label& next, int ra, int rb, int count, uint32_t index); +void emitInstSetList(IrRegAllocX64& regs, AssemblyBuilderX64& build, Label& next, int ra, int rb, int count, uint32_t index); void emitinstForGLoop(AssemblyBuilderX64& build, int ra, int aux, Label& loopRepeat, Label& loopExit); -void emitinstForGLoopFallback(AssemblyBuilderX64& build, int pcpos, int ra, int aux, Label& loopRepeat); +void emitinstForGLoopFallback(AssemblyBuilderX64& build, int ra, int aux, Label& loopRepeat); void emitInstForGPrepXnextFallback(AssemblyBuilderX64& build, int pcpos, int ra, Label& target); void emitInstAnd(AssemblyBuilderX64& build, int ra, int rb, int rc); void emitInstAndK(AssemblyBuilderX64& build, int ra, int rb, int kc); diff --git a/CodeGen/src/IrAnalysis.cpp b/CodeGen/src/IrAnalysis.cpp index 6e77dfe44..b248b97d5 100644 --- a/CodeGen/src/IrAnalysis.cpp +++ b/CodeGen/src/IrAnalysis.cpp @@ -300,17 +300,17 @@ static RegisterSet computeBlockLiveInRegSet(IrFunction& function, const IrBlock& if (function.boolOp(inst.b)) capturedRegs.set(inst.a.index, true); break; - case IrCmd::LOP_SETLIST: + case IrCmd::SETLIST: use(inst.b); useRange(inst.c.index, function.intOp(inst.d)); break; - case IrCmd::LOP_CALL: + case IrCmd::CALL: use(inst.a); useRange(inst.a.index + 1, function.intOp(inst.b)); defRange(inst.a.index, function.intOp(inst.c)); break; - case IrCmd::LOP_RETURN: + case IrCmd::RETURN: useRange(inst.a.index, function.intOp(inst.b)); break; case IrCmd::FASTCALL: @@ -341,7 +341,7 @@ static RegisterSet computeBlockLiveInRegSet(IrFunction& function, const IrBlock& if (int count = function.intOp(inst.f); count != -1) defRange(inst.b.index, count); break; - case IrCmd::LOP_FORGLOOP: + case IrCmd::FORGLOOP: // First register is not used by instruction, we check that it's still 'nil' with CHECK_TAG use(inst.a, 1); use(inst.a, 2); @@ -349,26 +349,26 @@ static RegisterSet computeBlockLiveInRegSet(IrFunction& function, const IrBlock& def(inst.a, 2); defRange(inst.a.index + 3, function.intOp(inst.b)); break; - case IrCmd::LOP_FORGLOOP_FALLBACK: - useRange(inst.b.index, 3); + case IrCmd::FORGLOOP_FALLBACK: + useRange(inst.a.index, 3); - def(inst.b, 2); - defRange(inst.b.index + 3, uint8_t(function.intOp(inst.c))); // ignore most significant bit + def(inst.a, 2); + defRange(inst.a.index + 3, uint8_t(function.intOp(inst.b))); // ignore most significant bit break; - case IrCmd::LOP_FORGPREP_XNEXT_FALLBACK: + case IrCmd::FORGPREP_XNEXT_FALLBACK: use(inst.b); break; // A <- B, C - case IrCmd::LOP_AND: - case IrCmd::LOP_OR: + case IrCmd::AND: + case IrCmd::OR: use(inst.b); use(inst.c); def(inst.a); break; // A <- B - case IrCmd::LOP_ANDK: - case IrCmd::LOP_ORK: + case IrCmd::ANDK: + case IrCmd::ORK: use(inst.b); def(inst.a); diff --git a/CodeGen/src/IrBuilder.cpp b/CodeGen/src/IrBuilder.cpp index 239f7a8e6..4fee080ba 100644 --- a/CodeGen/src/IrBuilder.cpp +++ b/CodeGen/src/IrBuilder.cpp @@ -135,7 +135,7 @@ void IrBuilder::translateInst(LuauOpcode op, const Instruction* pc, int i) inst(IrCmd::INTERRUPT, constUint(i)); inst(IrCmd::SET_SAVEDPC, constUint(i + 1)); - inst(IrCmd::LOP_CALL, vmReg(LUAU_INSN_A(*pc)), constInt(LUAU_INSN_B(*pc) - 1), constInt(LUAU_INSN_C(*pc) - 1)); + inst(IrCmd::CALL, vmReg(LUAU_INSN_A(*pc)), constInt(LUAU_INSN_B(*pc) - 1), constInt(LUAU_INSN_C(*pc) - 1)); if (activeFastcallFallback) { @@ -149,7 +149,7 @@ void IrBuilder::translateInst(LuauOpcode op, const Instruction* pc, int i) case LOP_RETURN: inst(IrCmd::INTERRUPT, constUint(i)); - inst(IrCmd::LOP_RETURN, vmReg(LUAU_INSN_A(*pc)), constInt(LUAU_INSN_B(*pc) - 1)); + inst(IrCmd::RETURN, vmReg(LUAU_INSN_A(*pc)), constInt(LUAU_INSN_B(*pc) - 1)); break; case LOP_GETTABLE: translateInstGetTable(*this, pc, i); @@ -266,7 +266,7 @@ void IrBuilder::translateInst(LuauOpcode op, const Instruction* pc, int i) translateInstDupTable(*this, pc, i); break; case LOP_SETLIST: - inst(IrCmd::LOP_SETLIST, constUint(i), vmReg(LUAU_INSN_A(*pc)), vmReg(LUAU_INSN_B(*pc)), constInt(LUAU_INSN_C(*pc) - 1), constUint(pc[1])); + inst(IrCmd::SETLIST, constUint(i), vmReg(LUAU_INSN_A(*pc)), vmReg(LUAU_INSN_B(*pc)), constInt(LUAU_INSN_C(*pc) - 1), constUint(pc[1])); break; case LOP_GETUPVAL: translateInstGetUpval(*this, pc, i); @@ -347,10 +347,11 @@ void IrBuilder::translateInst(LuauOpcode op, const Instruction* pc, int i) inst(IrCmd::INTERRUPT, constUint(i)); loadAndCheckTag(vmReg(ra), LUA_TNIL, fallback); - inst(IrCmd::LOP_FORGLOOP, vmReg(ra), constInt(aux), loopRepeat, loopExit); + inst(IrCmd::FORGLOOP, vmReg(ra), constInt(aux), loopRepeat, loopExit); beginBlock(fallback); - inst(IrCmd::LOP_FORGLOOP_FALLBACK, constUint(i), vmReg(ra), constInt(aux), loopRepeat, loopExit); + inst(IrCmd::SET_SAVEDPC, constUint(i + 1)); + inst(IrCmd::FORGLOOP_FALLBACK, vmReg(ra), constInt(aux), loopRepeat, loopExit); beginBlock(loopExit); } @@ -363,19 +364,19 @@ void IrBuilder::translateInst(LuauOpcode op, const Instruction* pc, int i) translateInstForGPrepInext(*this, pc, i); break; case LOP_AND: - inst(IrCmd::LOP_AND, vmReg(LUAU_INSN_A(*pc)), vmReg(LUAU_INSN_B(*pc)), vmReg(LUAU_INSN_C(*pc))); + inst(IrCmd::AND, vmReg(LUAU_INSN_A(*pc)), vmReg(LUAU_INSN_B(*pc)), vmReg(LUAU_INSN_C(*pc))); break; case LOP_ANDK: - inst(IrCmd::LOP_ANDK, vmReg(LUAU_INSN_A(*pc)), vmReg(LUAU_INSN_B(*pc)), vmConst(LUAU_INSN_C(*pc))); + inst(IrCmd::ANDK, vmReg(LUAU_INSN_A(*pc)), vmReg(LUAU_INSN_B(*pc)), vmConst(LUAU_INSN_C(*pc))); break; case LOP_OR: - inst(IrCmd::LOP_OR, vmReg(LUAU_INSN_A(*pc)), vmReg(LUAU_INSN_B(*pc)), vmReg(LUAU_INSN_C(*pc))); + inst(IrCmd::OR, vmReg(LUAU_INSN_A(*pc)), vmReg(LUAU_INSN_B(*pc)), vmReg(LUAU_INSN_C(*pc))); break; case LOP_ORK: - inst(IrCmd::LOP_ORK, vmReg(LUAU_INSN_A(*pc)), vmReg(LUAU_INSN_B(*pc)), vmConst(LUAU_INSN_C(*pc))); + inst(IrCmd::ORK, vmReg(LUAU_INSN_A(*pc)), vmReg(LUAU_INSN_B(*pc)), vmConst(LUAU_INSN_C(*pc))); break; case LOP_COVERAGE: - inst(IrCmd::LOP_COVERAGE, constUint(i)); + inst(IrCmd::COVERAGE, constUint(i)); break; case LOP_GETIMPORT: translateInstGetImport(*this, pc, i); diff --git a/CodeGen/src/IrCallWrapperX64.cpp b/CodeGen/src/IrCallWrapperX64.cpp new file mode 100644 index 000000000..4f0c0cf66 --- /dev/null +++ b/CodeGen/src/IrCallWrapperX64.cpp @@ -0,0 +1,400 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/IrCallWrapperX64.h" + +#include "Luau/AssemblyBuilderX64.h" +#include "Luau/IrRegAllocX64.h" + +#include "EmitCommonX64.h" + +namespace Luau +{ +namespace CodeGen +{ +namespace X64 +{ + +static bool sameUnderlyingRegister(RegisterX64 a, RegisterX64 b) +{ + SizeX64 underlyingSizeA = a.size == SizeX64::xmmword ? SizeX64::xmmword : SizeX64::qword; + SizeX64 underlyingSizeB = b.size == SizeX64::xmmword ? SizeX64::xmmword : SizeX64::qword; + + return underlyingSizeA == underlyingSizeB && a.index == b.index; +} + +IrCallWrapperX64::IrCallWrapperX64(IrRegAllocX64& regs, AssemblyBuilderX64& build, uint32_t instIdx) + : regs(regs) + , build(build) + , instIdx(instIdx) + , funcOp(noreg) +{ + gprUses.fill(0); + xmmUses.fill(0); +} + +void IrCallWrapperX64::addArgument(SizeX64 targetSize, OperandX64 source, IrOp sourceOp) +{ + // Instruction operands rely on current instruction index for lifetime tracking + LUAU_ASSERT(instIdx != kInvalidInstIdx || sourceOp.kind == IrOpKind::None); + + LUAU_ASSERT(argCount < kMaxCallArguments); + args[argCount++] = {targetSize, source, sourceOp}; +} + +void IrCallWrapperX64::addArgument(SizeX64 targetSize, ScopedRegX64& scopedReg) +{ + LUAU_ASSERT(argCount < kMaxCallArguments); + args[argCount++] = {targetSize, scopedReg.release(), {}}; +} + +void IrCallWrapperX64::call(const OperandX64& func) +{ + funcOp = func; + + assignTargetRegisters(); + + countRegisterUses(); + + for (int i = 0; i < argCount; ++i) + { + CallArgument& arg = args[i]; + + // If source is the last use of IrInst, clear the register + // Source registers are recorded separately in CallArgument + if (arg.sourceOp.kind != IrOpKind::None) + { + if (IrInst* inst = regs.function.asInstOp(arg.sourceOp)) + { + if (regs.isLastUseReg(*inst, instIdx)) + inst->regX64 = noreg; + } + } + + // Immediate values are stored at the end since they are not interfering and target register can still be used temporarily + if (arg.source.cat == CategoryX64::imm) + { + arg.candidate = false; + } + // Arguments passed through stack can be handled immediately + else if (arg.target.cat == CategoryX64::mem) + { + if (arg.source.cat == CategoryX64::mem) + { + ScopedRegX64 tmp{regs, arg.target.memSize}; + + freeSourceRegisters(arg); + + build.mov(tmp.reg, arg.source); + build.mov(arg.target, tmp.reg); + } + else + { + freeSourceRegisters(arg); + + build.mov(arg.target, arg.source); + } + + arg.candidate = false; + } + // Skip arguments that are already in their place + else if (arg.source.cat == CategoryX64::reg && sameUnderlyingRegister(arg.target.base, arg.source.base)) + { + freeSourceRegisters(arg); + + // If target is not used as source in other arguments, prevent register allocator from giving it out + if (getRegisterUses(arg.target.base) == 0) + regs.takeReg(arg.target.base); + else // Otherwise, make sure we won't free it when last source use is completed + addRegisterUse(arg.target.base); + + arg.candidate = false; + } + } + + // Repeat until we run out of arguments to pass + while (true) + { + // Find target argument register that is not an active source + if (CallArgument* candidate = findNonInterferingArgument()) + { + // This section is only for handling register targets + LUAU_ASSERT(candidate->target.cat == CategoryX64::reg); + + freeSourceRegisters(*candidate); + + LUAU_ASSERT(getRegisterUses(candidate->target.base) == 0); + regs.takeReg(candidate->target.base); + + moveToTarget(*candidate); + + candidate->candidate = false; + } + // If all registers cross-interfere (rcx <- rdx, rdx <- rcx), one has to be renamed + else if (RegisterX64 conflict = findConflictingTarget(); conflict != noreg) + { + // Get a fresh register + RegisterX64 freshReg = conflict.size == SizeX64::xmmword ? regs.allocXmmReg() : regs.allocGprReg(conflict.size); + + if (conflict.size == SizeX64::xmmword) + build.vmovsd(freshReg, conflict, conflict); + else + build.mov(freshReg, conflict); + + renameSourceRegisters(conflict, freshReg); + } + else + { + for (int i = 0; i < argCount; ++i) + LUAU_ASSERT(!args[i].candidate); + break; + } + } + + // Handle immediate arguments last + for (int i = 0; i < argCount; ++i) + { + CallArgument& arg = args[i]; + + if (arg.source.cat == CategoryX64::imm) + { + if (arg.target.cat == CategoryX64::reg) + regs.takeReg(arg.target.base); + + moveToTarget(arg); + } + } + + // Free registers used in the function call + removeRegisterUse(funcOp.base); + removeRegisterUse(funcOp.index); + + // Just before the call is made, argument registers are all marked as free in register allocator + for (int i = 0; i < argCount; ++i) + { + CallArgument& arg = args[i]; + + if (arg.target.cat == CategoryX64::reg) + regs.freeReg(arg.target.base); + } + + build.call(funcOp); +} + +void IrCallWrapperX64::assignTargetRegisters() +{ + static const std::array kWindowsGprOrder = {rcx, rdx, r8, r9, addr[rsp + 32], addr[rsp + 40]}; + static const std::array kSystemvGprOrder = {rdi, rsi, rdx, rcx, r8, r9}; + + const std::array& gprOrder = build.abi == ABIX64::Windows ? kWindowsGprOrder : kSystemvGprOrder; + static const std::array kXmmOrder = {xmm0, xmm1, xmm2, xmm3}; // Common order for first 4 fp arguments on Windows/SystemV + + int gprPos = 0; + int xmmPos = 0; + + for (int i = 0; i < argCount; i++) + { + CallArgument& arg = args[i]; + + if (arg.targetSize == SizeX64::xmmword) + { + LUAU_ASSERT(size_t(xmmPos) < kXmmOrder.size()); + arg.target = kXmmOrder[xmmPos++]; + + if (build.abi == ABIX64::Windows) + gprPos++; // On Windows, gpr/xmm register positions move in sync + } + else + { + LUAU_ASSERT(size_t(gprPos) < gprOrder.size()); + arg.target = gprOrder[gprPos++]; + + if (build.abi == ABIX64::Windows) + xmmPos++; // On Windows, gpr/xmm register positions move in sync + + // Keep requested argument size + if (arg.target.cat == CategoryX64::reg) + arg.target.base.size = arg.targetSize; + else if (arg.target.cat == CategoryX64::mem) + arg.target.memSize = arg.targetSize; + } + } +} + +void IrCallWrapperX64::countRegisterUses() +{ + for (int i = 0; i < argCount; ++i) + { + addRegisterUse(args[i].source.base); + addRegisterUse(args[i].source.index); + } + + addRegisterUse(funcOp.base); + addRegisterUse(funcOp.index); +} + +CallArgument* IrCallWrapperX64::findNonInterferingArgument() +{ + for (int i = 0; i < argCount; ++i) + { + CallArgument& arg = args[i]; + + if (arg.candidate && !interferesWithActiveSources(arg, i) && !interferesWithOperand(funcOp, arg.target.base)) + return &arg; + } + + return nullptr; +} + +bool IrCallWrapperX64::interferesWithOperand(const OperandX64& op, RegisterX64 reg) const +{ + return sameUnderlyingRegister(op.base, reg) || sameUnderlyingRegister(op.index, reg); +} + +bool IrCallWrapperX64::interferesWithActiveSources(const CallArgument& targetArg, int targetArgIndex) const +{ + for (int i = 0; i < argCount; ++i) + { + const CallArgument& arg = args[i]; + + if (arg.candidate && i != targetArgIndex && interferesWithOperand(arg.source, targetArg.target.base)) + return true; + } + + return false; +} + +bool IrCallWrapperX64::interferesWithActiveTarget(RegisterX64 sourceReg) const +{ + for (int i = 0; i < argCount; ++i) + { + const CallArgument& arg = args[i]; + + if (arg.candidate && sameUnderlyingRegister(arg.target.base, sourceReg)) + return true; + } + + return false; +} + +void IrCallWrapperX64::moveToTarget(CallArgument& arg) +{ + if (arg.source.cat == CategoryX64::reg) + { + RegisterX64 source = arg.source.base; + + if (source.size == SizeX64::xmmword) + build.vmovsd(arg.target, source, source); + else + build.mov(arg.target, source); + } + else if (arg.source.cat == CategoryX64::imm) + { + build.mov(arg.target, arg.source); + } + else + { + if (arg.source.memSize == SizeX64::none) + build.lea(arg.target, arg.source); + else if (arg.target.base.size == SizeX64::xmmword && arg.source.memSize == SizeX64::xmmword) + build.vmovups(arg.target, arg.source); + else if (arg.target.base.size == SizeX64::xmmword) + build.vmovsd(arg.target, arg.source); + else + build.mov(arg.target, arg.source); + } +} + +void IrCallWrapperX64::freeSourceRegisters(CallArgument& arg) +{ + removeRegisterUse(arg.source.base); + removeRegisterUse(arg.source.index); +} + +void IrCallWrapperX64::renameRegister(RegisterX64& target, RegisterX64 reg, RegisterX64 replacement) +{ + if (sameUnderlyingRegister(target, reg)) + { + addRegisterUse(replacement); + removeRegisterUse(target); + + target.index = replacement.index; // Only change index, size is preserved + } +} + +void IrCallWrapperX64::renameSourceRegisters(RegisterX64 reg, RegisterX64 replacement) +{ + for (int i = 0; i < argCount; ++i) + { + CallArgument& arg = args[i]; + + if (arg.candidate) + { + renameRegister(arg.source.base, reg, replacement); + renameRegister(arg.source.index, reg, replacement); + } + } + + renameRegister(funcOp.base, reg, replacement); + renameRegister(funcOp.index, reg, replacement); +} + +RegisterX64 IrCallWrapperX64::findConflictingTarget() const +{ + for (int i = 0; i < argCount; ++i) + { + const CallArgument& arg = args[i]; + + if (arg.candidate) + { + if (interferesWithActiveTarget(arg.source.base)) + return arg.source.base; + + if (interferesWithActiveTarget(arg.source.index)) + return arg.source.index; + } + } + + if (interferesWithActiveTarget(funcOp.base)) + return funcOp.base; + + if (interferesWithActiveTarget(funcOp.index)) + return funcOp.index; + + return noreg; +} + +int IrCallWrapperX64::getRegisterUses(RegisterX64 reg) const +{ + return reg.size == SizeX64::xmmword ? xmmUses[reg.index] : (reg.size != SizeX64::none ? gprUses[reg.index] : 0); +} + +void IrCallWrapperX64::addRegisterUse(RegisterX64 reg) +{ + if (reg.size == SizeX64::xmmword) + xmmUses[reg.index]++; + else if (reg.size != SizeX64::none) + gprUses[reg.index]++; +} + +void IrCallWrapperX64::removeRegisterUse(RegisterX64 reg) +{ + if (reg.size == SizeX64::xmmword) + { + LUAU_ASSERT(xmmUses[reg.index] != 0); + xmmUses[reg.index]--; + + if (xmmUses[reg.index] == 0) // we don't use persistent xmm regs so no need to call shouldFreeRegister + regs.freeReg(reg); + } + else if (reg.size != SizeX64::none) + { + LUAU_ASSERT(gprUses[reg.index] != 0); + gprUses[reg.index]--; + + if (gprUses[reg.index] == 0 && regs.shouldFreeGpr(reg)) + regs.freeReg(reg); + } +} + +} // namespace X64 +} // namespace CodeGen +} // namespace Luau diff --git a/CodeGen/src/IrDump.cpp b/CodeGen/src/IrDump.cpp index 53654d6a2..fb56df8c5 100644 --- a/CodeGen/src/IrDump.cpp +++ b/CodeGen/src/IrDump.cpp @@ -126,6 +126,16 @@ const char* getCmdName(IrCmd cmd) return "MAX_NUM"; case IrCmd::UNM_NUM: return "UNM_NUM"; + case IrCmd::FLOOR_NUM: + return "FLOOR_NUM"; + case IrCmd::CEIL_NUM: + return "CEIL_NUM"; + case IrCmd::ROUND_NUM: + return "ROUND_NUM"; + case IrCmd::SQRT_NUM: + return "SQRT_NUM"; + case IrCmd::ABS_NUM: + return "ABS_NUM"; case IrCmd::NOT_ANY: return "NOT_ANY"; case IrCmd::JUMP: @@ -216,28 +226,28 @@ const char* getCmdName(IrCmd cmd) return "CLOSE_UPVALS"; case IrCmd::CAPTURE: return "CAPTURE"; - case IrCmd::LOP_SETLIST: - return "LOP_SETLIST"; - case IrCmd::LOP_CALL: - return "LOP_CALL"; - case IrCmd::LOP_RETURN: - return "LOP_RETURN"; - case IrCmd::LOP_FORGLOOP: - return "LOP_FORGLOOP"; - case IrCmd::LOP_FORGLOOP_FALLBACK: - return "LOP_FORGLOOP_FALLBACK"; - case IrCmd::LOP_FORGPREP_XNEXT_FALLBACK: - return "LOP_FORGPREP_XNEXT_FALLBACK"; - case IrCmd::LOP_AND: - return "LOP_AND"; - case IrCmd::LOP_ANDK: - return "LOP_ANDK"; - case IrCmd::LOP_OR: - return "LOP_OR"; - case IrCmd::LOP_ORK: - return "LOP_ORK"; - case IrCmd::LOP_COVERAGE: - return "LOP_COVERAGE"; + case IrCmd::SETLIST: + return "SETLIST"; + case IrCmd::CALL: + return "CALL"; + case IrCmd::RETURN: + return "RETURN"; + case IrCmd::FORGLOOP: + return "FORGLOOP"; + case IrCmd::FORGLOOP_FALLBACK: + return "FORGLOOP_FALLBACK"; + case IrCmd::FORGPREP_XNEXT_FALLBACK: + return "FORGPREP_XNEXT_FALLBACK"; + case IrCmd::AND: + return "AND"; + case IrCmd::ANDK: + return "ANDK"; + case IrCmd::OR: + return "OR"; + case IrCmd::ORK: + return "ORK"; + case IrCmd::COVERAGE: + return "COVERAGE"; case IrCmd::FALLBACK_GETGLOBAL: return "FALLBACK_GETGLOBAL"; case IrCmd::FALLBACK_SETGLOBAL: diff --git a/CodeGen/src/IrLoweringA64.cpp b/CodeGen/src/IrLoweringA64.cpp index ae4bc017d..37f381572 100644 --- a/CodeGen/src/IrLoweringA64.cpp +++ b/CodeGen/src/IrLoweringA64.cpp @@ -13,6 +13,9 @@ #include "lstate.h" +// TODO: Eventually this can go away +// #define TRACE + namespace Luau { namespace CodeGen @@ -20,12 +23,67 @@ namespace CodeGen namespace A64 { +#ifdef TRACE +struct LoweringStatsA64 +{ + size_t can; + size_t total; + + ~LoweringStatsA64() + { + if (total) + printf("A64 lowering succeded for %.1f%% functions (%d/%d)\n", double(can) / double(total) * 100, int(can), int(total)); + } +} gStatsA64; +#endif + +inline ConditionA64 getConditionFP(IrCondition cond) +{ + switch (cond) + { + case IrCondition::Equal: + return ConditionA64::Equal; + + case IrCondition::NotEqual: + return ConditionA64::NotEqual; + + case IrCondition::Less: + return ConditionA64::Minus; + + case IrCondition::NotLess: + return ConditionA64::Plus; + + case IrCondition::LessEqual: + return ConditionA64::UnsignedLessEqual; + + case IrCondition::NotLessEqual: + return ConditionA64::UnsignedGreater; + + case IrCondition::Greater: + return ConditionA64::Greater; + + case IrCondition::NotGreater: + return ConditionA64::LessEqual; + + case IrCondition::GreaterEqual: + return ConditionA64::GreaterEqual; + + case IrCondition::NotGreaterEqual: + return ConditionA64::Less; + + default: + LUAU_ASSERT(!"Unexpected condition code"); + return ConditionA64::Always; + } +} + IrLoweringA64::IrLoweringA64(AssemblyBuilderA64& build, ModuleHelpers& helpers, NativeState& data, Proto* proto, IrFunction& function) : build(build) , helpers(helpers) , data(data) , proto(proto) , function(function) + , regs(function, {{x0, x15}, {q0, q7}, {q16, q31}}) { // In order to allocate registers during lowering, we need to know where instruction results are last used updateLastUseLocations(function); @@ -34,20 +92,61 @@ IrLoweringA64::IrLoweringA64(AssemblyBuilderA64& build, ModuleHelpers& helpers, // TODO: Eventually this can go away bool IrLoweringA64::canLower(const IrFunction& function) { +#ifdef TRACE + gStatsA64.total++; +#endif + for (const IrInst& inst : function.instructions) { switch (inst.cmd) { case IrCmd::NOP: - case IrCmd::SUBSTITUTE: + case IrCmd::LOAD_TAG: + case IrCmd::LOAD_POINTER: + case IrCmd::LOAD_DOUBLE: + case IrCmd::LOAD_INT: + case IrCmd::LOAD_TVALUE: + case IrCmd::LOAD_NODE_VALUE_TV: + case IrCmd::LOAD_ENV: + case IrCmd::STORE_TAG: + case IrCmd::STORE_POINTER: + case IrCmd::STORE_DOUBLE: + case IrCmd::STORE_INT: + case IrCmd::STORE_TVALUE: + case IrCmd::STORE_NODE_VALUE_TV: + case IrCmd::ADD_NUM: + case IrCmd::SUB_NUM: + case IrCmd::MUL_NUM: + case IrCmd::DIV_NUM: + case IrCmd::MOD_NUM: + case IrCmd::UNM_NUM: + case IrCmd::JUMP: + case IrCmd::JUMP_EQ_TAG: + case IrCmd::JUMP_CMP_NUM: + case IrCmd::JUMP_CMP_ANY: + case IrCmd::DO_ARITH: + case IrCmd::GET_IMPORT: + case IrCmd::GET_UPVALUE: + case IrCmd::CHECK_TAG: + case IrCmd::CHECK_READONLY: + case IrCmd::CHECK_NO_METATABLE: + case IrCmd::CHECK_SAFE_ENV: case IrCmd::INTERRUPT: - case IrCmd::LOP_RETURN: + case IrCmd::SET_SAVEDPC: + case IrCmd::CALL: + case IrCmd::RETURN: + case IrCmd::SUBSTITUTE: continue; + default: return false; } } +#ifdef TRACE + gStatsA64.can++; +#endif + return true; } @@ -55,23 +154,338 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) { switch (inst.cmd) { + case IrCmd::LOAD_TAG: + { + inst.regA64 = regs.allocReg(KindA64::w); + AddressA64 addr = tempAddr(inst.a, offsetof(TValue, tt)); + build.ldr(inst.regA64, addr); + break; + } + case IrCmd::LOAD_POINTER: + { + inst.regA64 = regs.allocReg(KindA64::x); + AddressA64 addr = tempAddr(inst.a, offsetof(TValue, value)); + build.ldr(inst.regA64, addr); + break; + } + case IrCmd::LOAD_DOUBLE: + { + inst.regA64 = regs.allocReg(KindA64::d); + AddressA64 addr = tempAddr(inst.a, offsetof(TValue, value)); + build.ldr(inst.regA64, addr); + break; + } + case IrCmd::LOAD_INT: + { + inst.regA64 = regs.allocReg(KindA64::w); + AddressA64 addr = tempAddr(inst.a, offsetof(TValue, value)); + build.ldr(inst.regA64, addr); + break; + } + case IrCmd::LOAD_TVALUE: + { + inst.regA64 = regs.allocReg(KindA64::q); + AddressA64 addr = tempAddr(inst.a, 0); + build.ldr(inst.regA64, addr); + break; + } + case IrCmd::LOAD_NODE_VALUE_TV: + { + inst.regA64 = regs.allocReg(KindA64::q); + build.ldr(inst.regA64, mem(regOp(inst.a), offsetof(LuaNode, val))); + break; + } + case IrCmd::LOAD_ENV: + inst.regA64 = regs.allocReg(KindA64::x); + build.ldr(inst.regA64, mem(rClosure, offsetof(Closure, env))); + break; + case IrCmd::STORE_TAG: + { + RegisterA64 temp = regs.allocTemp(KindA64::w); + AddressA64 addr = tempAddr(inst.a, offsetof(TValue, tt)); + build.mov(temp, tagOp(inst.b)); + build.str(temp, addr); + break; + } + case IrCmd::STORE_POINTER: + { + AddressA64 addr = tempAddr(inst.a, offsetof(TValue, value)); + build.str(regOp(inst.b), addr); + break; + } + case IrCmd::STORE_DOUBLE: + { + RegisterA64 temp = tempDouble(inst.b); + AddressA64 addr = tempAddr(inst.a, offsetof(TValue, value)); + build.str(temp, addr); + break; + } + case IrCmd::STORE_INT: + { + RegisterA64 temp = tempInt(inst.b); + AddressA64 addr = tempAddr(inst.a, offsetof(TValue, value)); + build.str(temp, addr); + break; + } + case IrCmd::STORE_TVALUE: + { + AddressA64 addr = tempAddr(inst.a, 0); + build.str(regOp(inst.b), addr); + break; + } + case IrCmd::STORE_NODE_VALUE_TV: + build.str(regOp(inst.b), mem(regOp(inst.a), offsetof(LuaNode, val))); + break; + case IrCmd::ADD_NUM: + { + inst.regA64 = regs.allocReuse(KindA64::d, index, {inst.a, inst.b}); + RegisterA64 temp1 = tempDouble(inst.a); + RegisterA64 temp2 = tempDouble(inst.b); + build.fadd(inst.regA64, temp1, temp2); + break; + } + case IrCmd::SUB_NUM: + { + inst.regA64 = regs.allocReuse(KindA64::d, index, {inst.a, inst.b}); + RegisterA64 temp1 = tempDouble(inst.a); + RegisterA64 temp2 = tempDouble(inst.b); + build.fsub(inst.regA64, temp1, temp2); + break; + } + case IrCmd::MUL_NUM: + { + inst.regA64 = regs.allocReuse(KindA64::d, index, {inst.a, inst.b}); + RegisterA64 temp1 = tempDouble(inst.a); + RegisterA64 temp2 = tempDouble(inst.b); + build.fmul(inst.regA64, temp1, temp2); + break; + } + case IrCmd::DIV_NUM: + { + inst.regA64 = regs.allocReuse(KindA64::d, index, {inst.a, inst.b}); + RegisterA64 temp1 = tempDouble(inst.a); + RegisterA64 temp2 = tempDouble(inst.b); + build.fdiv(inst.regA64, temp1, temp2); + break; + } + case IrCmd::MOD_NUM: + { + inst.regA64 = regs.allocReg(KindA64::d); + RegisterA64 temp1 = tempDouble(inst.a); + RegisterA64 temp2 = tempDouble(inst.b); + build.fdiv(inst.regA64, temp1, temp2); + build.frintm(inst.regA64, inst.regA64); + build.fmul(inst.regA64, inst.regA64, temp2); + build.fsub(inst.regA64, temp1, inst.regA64); + break; + } + case IrCmd::UNM_NUM: + { + inst.regA64 = regs.allocReuse(KindA64::d, index, {inst.a}); + RegisterA64 temp = tempDouble(inst.a); + build.fneg(inst.regA64, temp); + break; + } + case IrCmd::JUMP: + jumpOrFallthrough(blockOp(inst.a), next); + break; + case IrCmd::JUMP_EQ_TAG: + if (inst.b.kind == IrOpKind::Constant) + build.cmp(regOp(inst.a), tagOp(inst.b)); + else if (inst.b.kind == IrOpKind::Inst) + build.cmp(regOp(inst.a), regOp(inst.b)); + else + LUAU_ASSERT(!"Unsupported instruction form"); + + if (isFallthroughBlock(blockOp(inst.d), next)) + { + build.b(ConditionA64::Equal, labelOp(inst.c)); + jumpOrFallthrough(blockOp(inst.d), next); + } + else + { + build.b(ConditionA64::NotEqual, labelOp(inst.d)); + jumpOrFallthrough(blockOp(inst.c), next); + } + break; + case IrCmd::JUMP_CMP_NUM: + { + IrCondition cond = conditionOp(inst.c); + + RegisterA64 temp1 = tempDouble(inst.a); + RegisterA64 temp2 = tempDouble(inst.b); + + build.fcmp(temp1, temp2); + build.b(getConditionFP(cond), labelOp(inst.d)); + jumpOrFallthrough(blockOp(inst.e), next); + break; + } + case IrCmd::JUMP_CMP_ANY: + { + IrCondition cond = conditionOp(inst.c); + + regs.assertAllFree(); + build.mov(x0, rState); + build.add(x1, rBase, uint16_t(vmRegOp(inst.a) * sizeof(TValue))); + build.add(x2, rBase, uint16_t(vmRegOp(inst.b) * sizeof(TValue))); + + if (cond == IrCondition::NotLessEqual || cond == IrCondition::LessEqual) + build.ldr(x3, mem(rNativeContext, offsetof(NativeContext, luaV_lessequal))); + else if (cond == IrCondition::NotLess || cond == IrCondition::Less) + build.ldr(x3, mem(rNativeContext, offsetof(NativeContext, luaV_lessthan))); + else if (cond == IrCondition::NotEqual || cond == IrCondition::Equal) + build.ldr(x3, mem(rNativeContext, offsetof(NativeContext, luaV_equalval))); + else + LUAU_ASSERT(!"Unsupported condition"); + + build.blr(x3); + + emitUpdateBase(build); + + if (cond == IrCondition::NotLessEqual || cond == IrCondition::NotLess || cond == IrCondition::NotEqual) + build.cbz(x0, labelOp(inst.d)); + else + build.cbnz(x0, labelOp(inst.d)); + jumpOrFallthrough(blockOp(inst.e), next); + break; + } + case IrCmd::DO_ARITH: + regs.assertAllFree(); + build.mov(x0, rState); + build.add(x1, rBase, uint16_t(vmRegOp(inst.a) * sizeof(TValue))); + build.add(x2, rBase, uint16_t(vmRegOp(inst.b) * sizeof(TValue))); + + if (inst.c.kind == IrOpKind::VmConst) + { + // TODO: refactor into a common helper + if (vmConstOp(inst.c) * sizeof(TValue) <= AssemblyBuilderA64::kMaxImmediate) + { + build.add(x3, rConstants, uint16_t(vmConstOp(inst.c) * sizeof(TValue))); + } + else + { + build.mov(x3, vmConstOp(inst.c) * sizeof(TValue)); + build.add(x3, rConstants, x3); + } + } + else + build.add(x3, rBase, uint16_t(vmRegOp(inst.c) * sizeof(TValue))); + + build.mov(w4, TMS(intOp(inst.d))); + build.ldr(x5, mem(rNativeContext, offsetof(NativeContext, luaV_doarith))); + build.blr(x5); + + emitUpdateBase(build); + break; + case IrCmd::GET_IMPORT: + regs.assertAllFree(); + emitInstGetImport(build, vmRegOp(inst.a), uintOp(inst.b)); + break; + case IrCmd::GET_UPVALUE: + { + RegisterA64 temp1 = regs.allocTemp(KindA64::x); + RegisterA64 temp2 = regs.allocTemp(KindA64::q); + RegisterA64 temp3 = regs.allocTemp(KindA64::w); + + build.add(temp1, rClosure, uint16_t(offsetof(Closure, l.uprefs) + sizeof(TValue) * vmUpvalueOp(inst.b))); + + // uprefs[] is either an actual value, or it points to UpVal object which has a pointer to value + Label skip; + build.ldr(temp3, mem(temp1, offsetof(TValue, tt))); + build.cmp(temp3, LUA_TUPVAL); + build.b(ConditionA64::NotEqual, skip); + + // UpVal.v points to the value (either on stack, or on heap inside each UpVal, but we can deref it unconditionally) + build.ldr(temp1, mem(temp1, offsetof(TValue, value.gc))); + build.ldr(temp1, mem(temp1, offsetof(UpVal, v))); + + build.setLabel(skip); + + build.ldr(temp2, temp1); + build.str(temp2, mem(rBase, vmRegOp(inst.a) * sizeof(TValue))); + break; + } + case IrCmd::CHECK_TAG: + build.cmp(regOp(inst.a), tagOp(inst.b)); + build.b(ConditionA64::NotEqual, labelOp(inst.c)); + break; + case IrCmd::CHECK_READONLY: + { + RegisterA64 temp = regs.allocTemp(KindA64::w); + build.ldrb(temp, mem(regOp(inst.a), offsetof(Table, readonly))); + build.cbnz(temp, labelOp(inst.b)); + break; + } + case IrCmd::CHECK_NO_METATABLE: + { + RegisterA64 temp = regs.allocTemp(KindA64::x); + build.ldr(temp, mem(regOp(inst.a), offsetof(Table, metatable))); + build.cbnz(temp, labelOp(inst.b)); + break; + } + case IrCmd::CHECK_SAFE_ENV: + { + RegisterA64 temp = regs.allocTemp(KindA64::x); + RegisterA64 tempw{KindA64::w, temp.index}; + build.ldr(temp, mem(rClosure, offsetof(Closure, env))); + build.ldrb(tempw, mem(temp, offsetof(Table, safeenv))); + build.cbz(tempw, labelOp(inst.a)); + break; + } case IrCmd::INTERRUPT: { - emitInterrupt(build, uintOp(inst.a)); + unsigned int pcpos = uintOp(inst.a); + regs.assertAllFree(); + + Label skip; + build.ldr(x2, mem(rState, offsetof(lua_State, global))); + build.ldr(x2, mem(x2, offsetof(global_State, cb.interrupt))); + build.cbz(x2, skip); + + // Jump to outlined interrupt handler, it will give back control to x1 + build.mov(x0, (pcpos + 1) * sizeof(Instruction)); + build.adr(x1, skip); + build.b(helpers.interrupt); + + build.setLabel(skip); break; } - case IrCmd::LOP_RETURN: + case IrCmd::SET_SAVEDPC: { - emitInstReturn(build, helpers, vmRegOp(inst.a), intOp(inst.b)); + unsigned int pcpos = uintOp(inst.a); + RegisterA64 temp1 = regs.allocTemp(KindA64::x); + RegisterA64 temp2 = regs.allocTemp(KindA64::x); + + // TODO: refactor into a common helper + if (pcpos * sizeof(Instruction) <= AssemblyBuilderA64::kMaxImmediate) + { + build.add(temp1, rCode, uint16_t(pcpos * sizeof(Instruction))); + } + else + { + build.mov(temp1, pcpos * sizeof(Instruction)); + build.add(temp1, rCode, temp1); + } + + build.ldr(temp2, mem(rState, offsetof(lua_State, ci))); + build.str(temp1, mem(temp2, offsetof(CallInfo, savedpc))); break; } + case IrCmd::CALL: + regs.assertAllFree(); + emitInstCall(build, helpers, vmRegOp(inst.a), intOp(inst.b), intOp(inst.c)); + break; + case IrCmd::RETURN: + regs.assertAllFree(); + emitInstReturn(build, helpers, vmRegOp(inst.a), intOp(inst.b)); + break; default: LUAU_ASSERT(!"Not supported yet"); break; } - // TODO - // regs.freeLastUseRegs(inst, index); + regs.freeLastUseRegs(inst, index); + regs.freeTempRegs(); } bool IrLoweringA64::isFallthroughBlock(IrBlock target, IrBlock next) @@ -85,6 +499,83 @@ void IrLoweringA64::jumpOrFallthrough(IrBlock& target, IrBlock& next) build.b(target.label); } +RegisterA64 IrLoweringA64::tempDouble(IrOp op) +{ + if (op.kind == IrOpKind::Inst) + return regOp(op); + else if (op.kind == IrOpKind::Constant) + { + RegisterA64 temp1 = regs.allocTemp(KindA64::x); + RegisterA64 temp2 = regs.allocTemp(KindA64::d); + build.adr(temp1, doubleOp(op)); + build.ldr(temp2, temp1); + return temp2; + } + else + { + LUAU_ASSERT(!"Unsupported instruction form"); + return noreg; + } +} + +RegisterA64 IrLoweringA64::tempInt(IrOp op) +{ + if (op.kind == IrOpKind::Inst) + return regOp(op); + else if (op.kind == IrOpKind::Constant) + { + RegisterA64 temp = regs.allocTemp(KindA64::w); + build.mov(temp, intOp(op)); + return temp; + } + else + { + LUAU_ASSERT(!"Unsupported instruction form"); + return noreg; + } +} + +AddressA64 IrLoweringA64::tempAddr(IrOp op, int offset) +{ + // This is needed to tighten the bounds checks in the VmConst case below + LUAU_ASSERT(offset % 4 == 0); + + if (op.kind == IrOpKind::VmReg) + return mem(rBase, vmRegOp(op) * sizeof(TValue) + offset); + else if (op.kind == IrOpKind::VmConst) + { + size_t constantOffset = vmConstOp(op) * sizeof(TValue) + offset; + + // Note: cumulative offset is guaranteed to be divisible by 4; we can use that to expand the useful range that doesn't require temporaries + if (constantOffset / 4 <= AddressA64::kMaxOffset) + return mem(rConstants, int(constantOffset)); + + RegisterA64 temp = regs.allocTemp(KindA64::x); + + // TODO: refactor into a common helper + if (constantOffset <= AssemblyBuilderA64::kMaxImmediate) + { + build.add(temp, rConstants, uint16_t(constantOffset)); + } + else + { + build.mov(temp, int(constantOffset)); + build.add(temp, rConstants, temp); + } + + return temp; + } + // If we have a register, we assume it's a pointer to TValue + // We might introduce explicit operand types in the future to make this more robust + else if (op.kind == IrOpKind::Inst) + return mem(regOp(op), offset); + else + { + LUAU_ASSERT(!"Unsupported instruction form"); + return noreg; + } +} + RegisterA64 IrLoweringA64::regOp(IrOp op) const { IrInst& inst = function.instOp(op); diff --git a/CodeGen/src/IrLoweringA64.h b/CodeGen/src/IrLoweringA64.h index aa9eba422..f638432ff 100644 --- a/CodeGen/src/IrLoweringA64.h +++ b/CodeGen/src/IrLoweringA64.h @@ -4,6 +4,8 @@ #include "Luau/AssemblyBuilderA64.h" #include "Luau/IrData.h" +#include "IrRegAllocA64.h" + #include struct Proto; @@ -31,6 +33,11 @@ struct IrLoweringA64 bool isFallthroughBlock(IrBlock target, IrBlock next); void jumpOrFallthrough(IrBlock& target, IrBlock& next); + // Operand data build helpers + RegisterA64 tempDouble(IrOp op); + RegisterA64 tempInt(IrOp op); + AddressA64 tempAddr(IrOp op, int offset); + // Operand data lookup helpers RegisterA64 regOp(IrOp op) const; @@ -51,8 +58,7 @@ struct IrLoweringA64 IrFunction& function; - // TODO: - // IrRegAllocA64 regs; + IrRegAllocA64 regs; }; } // namespace A64 diff --git a/CodeGen/src/IrLoweringX64.cpp b/CodeGen/src/IrLoweringX64.cpp index 1cc56fe31..8c45f36ad 100644 --- a/CodeGen/src/IrLoweringX64.cpp +++ b/CodeGen/src/IrLoweringX64.cpp @@ -4,6 +4,7 @@ #include "Luau/CodeGen.h" #include "Luau/DenseHash.h" #include "Luau/IrAnalysis.h" +#include "Luau/IrCallWrapperX64.h" #include "Luau/IrDump.h" #include "Luau/IrUtils.h" @@ -141,7 +142,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) inst.regX64 = regs.allocGprReg(SizeX64::qword); // Custom bit shift value can only be placed in cl - ScopedRegX64 shiftTmp{regs, regs.takeGprReg(rcx)}; + ScopedRegX64 shiftTmp{regs, regs.takeReg(rcx)}; ScopedRegX64 tmp{regs, SizeX64::qword}; @@ -325,82 +326,11 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) } case IrCmd::POW_NUM: { - inst.regX64 = regs.allocXmmRegOrReuse(index, {inst.a, inst.b}); - - ScopedRegX64 optLhsTmp{regs}; - RegisterX64 lhs; - - if (inst.a.kind == IrOpKind::Constant) - { - optLhsTmp.alloc(SizeX64::xmmword); - - build.vmovsd(optLhsTmp.reg, memRegDoubleOp(inst.a)); - lhs = optLhsTmp.reg; - } - else - { - lhs = regOp(inst.a); - } - - if (inst.b.kind == IrOpKind::Inst) - { - // TODO: this doesn't happen with current local-only register allocation, but has to be handled in the future - LUAU_ASSERT(regOp(inst.b) != xmm0); - - if (lhs != xmm0) - build.vmovsd(xmm0, lhs, lhs); - - if (regOp(inst.b) != xmm1) - build.vmovsd(xmm1, regOp(inst.b), regOp(inst.b)); - - build.call(qword[rNativeContext + offsetof(NativeContext, libm_pow)]); - - if (inst.regX64 != xmm0) - build.vmovsd(inst.regX64, xmm0, xmm0); - } - else if (inst.b.kind == IrOpKind::Constant) - { - double rhs = doubleOp(inst.b); - - if (rhs == 2.0) - { - build.vmulsd(inst.regX64, lhs, lhs); - } - else if (rhs == 0.5) - { - build.vsqrtsd(inst.regX64, lhs, lhs); - } - else if (rhs == 3.0) - { - ScopedRegX64 tmp{regs, SizeX64::xmmword}; - - build.vmulsd(tmp.reg, lhs, lhs); - build.vmulsd(inst.regX64, lhs, tmp.reg); - } - else - { - if (lhs != xmm0) - build.vmovsd(xmm0, xmm0, lhs); - - build.vmovsd(xmm1, build.f64(rhs)); - build.call(qword[rNativeContext + offsetof(NativeContext, libm_pow)]); - - if (inst.regX64 != xmm0) - build.vmovsd(inst.regX64, xmm0, xmm0); - } - } - else - { - if (lhs != xmm0) - build.vmovsd(xmm0, lhs, lhs); - - build.vmovsd(xmm1, memRegDoubleOp(inst.b)); - build.call(qword[rNativeContext + offsetof(NativeContext, libm_pow)]); - - if (inst.regX64 != xmm0) - build.vmovsd(inst.regX64, xmm0, xmm0); - } - + IrCallWrapperX64 callWrap(regs, build, index); + callWrap.addArgument(SizeX64::xmmword, memRegDoubleOp(inst.a), inst.a); + callWrap.addArgument(SizeX64::xmmword, memRegDoubleOp(inst.b), inst.b); + callWrap.call(qword[rNativeContext + offsetof(NativeContext, libm_pow)]); + inst.regX64 = regs.takeReg(xmm0); break; } case IrCmd::MIN_NUM: @@ -451,6 +381,46 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) break; } + case IrCmd::FLOOR_NUM: + inst.regX64 = regs.allocXmmRegOrReuse(index, {inst.a}); + + build.vroundsd(inst.regX64, inst.regX64, memRegDoubleOp(inst.a), RoundingModeX64::RoundToNegativeInfinity); + break; + case IrCmd::CEIL_NUM: + inst.regX64 = regs.allocXmmRegOrReuse(index, {inst.a}); + + build.vroundsd(inst.regX64, inst.regX64, memRegDoubleOp(inst.a), RoundingModeX64::RoundToPositiveInfinity); + break; + case IrCmd::ROUND_NUM: + { + inst.regX64 = regs.allocXmmRegOrReuse(index, {inst.a}); + + ScopedRegX64 tmp1{regs, SizeX64::xmmword}; + ScopedRegX64 tmp2{regs, SizeX64::xmmword}; + + if (inst.a.kind != IrOpKind::Inst || regOp(inst.a) != inst.regX64) + build.vmovsd(inst.regX64, memRegDoubleOp(inst.a)); + + build.vandpd(tmp1.reg, inst.regX64, build.f64x2(-0.0, -0.0)); + build.vmovsd(tmp2.reg, build.i64(0x3fdfffffffffffff)); // 0.49999999999999994 + build.vorpd(tmp1.reg, tmp1.reg, tmp2.reg); + build.vaddsd(inst.regX64, inst.regX64, tmp1.reg); + build.vroundsd(inst.regX64, inst.regX64, inst.regX64, RoundingModeX64::RoundToZero); + break; + } + case IrCmd::SQRT_NUM: + inst.regX64 = regs.allocXmmRegOrReuse(index, {inst.a}); + + build.vsqrtsd(inst.regX64, inst.regX64, memRegDoubleOp(inst.a)); + break; + case IrCmd::ABS_NUM: + inst.regX64 = regs.allocXmmRegOrReuse(index, {inst.a}); + + if (inst.a.kind != IrOpKind::Inst || regOp(inst.a) != inst.regX64) + build.vmovsd(inst.regX64, memRegDoubleOp(inst.a)); + + build.vandpd(inst.regX64, inst.regX64, build.i64(~(1LL << 63))); + break; case IrCmd::NOT_ANY: { // TODO: if we have a single user which is a STORE_INT, we are missing the opportunity to write directly to target @@ -539,7 +509,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) break; } case IrCmd::JUMP_CMP_ANY: - jumpOnAnyCmpFallback(build, vmRegOp(inst.a), vmRegOp(inst.b), conditionOp(inst.c), labelOp(inst.d)); + jumpOnAnyCmpFallback(regs, build, vmRegOp(inst.a), vmRegOp(inst.b), conditionOp(inst.c), labelOp(inst.d)); jumpOrFallthrough(blockOp(inst.e), next); break; case IrCmd::JUMP_SLOT_MATCH: @@ -551,34 +521,34 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) break; } case IrCmd::TABLE_LEN: - inst.regX64 = regs.allocXmmReg(); + { + IrCallWrapperX64 callWrap(regs, build, index); + callWrap.addArgument(SizeX64::qword, regOp(inst.a), inst.a); + callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaH_getn)]); - build.mov(rArg1, regOp(inst.a)); - build.call(qword[rNativeContext + offsetof(NativeContext, luaH_getn)]); + inst.regX64 = regs.allocXmmReg(); build.vcvtsi2sd(inst.regX64, inst.regX64, eax); break; + } case IrCmd::NEW_TABLE: - inst.regX64 = regs.allocGprReg(SizeX64::qword); - - build.mov(rArg1, rState); - build.mov(dwordReg(rArg2), uintOp(inst.a)); - build.mov(dwordReg(rArg3), uintOp(inst.b)); - build.call(qword[rNativeContext + offsetof(NativeContext, luaH_new)]); - - if (inst.regX64 != rax) - build.mov(inst.regX64, rax); + { + IrCallWrapperX64 callWrap(regs, build, index); + callWrap.addArgument(SizeX64::qword, rState); + callWrap.addArgument(SizeX64::dword, int32_t(uintOp(inst.a)), inst.a); + callWrap.addArgument(SizeX64::dword, int32_t(uintOp(inst.b)), inst.b); + callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaH_new)]); + inst.regX64 = regs.takeReg(rax); break; + } case IrCmd::DUP_TABLE: - inst.regX64 = regs.allocGprReg(SizeX64::qword); - - // Re-ordered to avoid register conflict - build.mov(rArg2, regOp(inst.a)); - build.mov(rArg1, rState); - build.call(qword[rNativeContext + offsetof(NativeContext, luaH_clone)]); - - if (inst.regX64 != rax) - build.mov(inst.regX64, rax); + { + IrCallWrapperX64 callWrap(regs, build, index); + callWrap.addArgument(SizeX64::qword, rState); + callWrap.addArgument(SizeX64::qword, regOp(inst.a), inst.a); + callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaH_clone)]); + inst.regX64 = regs.takeReg(rax); break; + } case IrCmd::TRY_NUM_TO_INDEX: { inst.regX64 = regs.allocGprReg(SizeX64::dword); @@ -590,12 +560,26 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) } case IrCmd::TRY_CALL_FASTGETTM: { - inst.regX64 = regs.allocGprReg(SizeX64::qword); + ScopedRegX64 tmp{regs, SizeX64::qword}; + + build.mov(tmp.reg, qword[regOp(inst.a) + offsetof(Table, metatable)]); + regs.freeLastUseReg(function.instOp(inst.a), index); // Release before the call if it's the last use - callGetFastTmOrFallback(build, regOp(inst.a), TMS(intOp(inst.b)), labelOp(inst.c)); + build.test(tmp.reg, tmp.reg); + build.jcc(ConditionX64::Zero, labelOp(inst.c)); // No metatable - if (inst.regX64 != rax) - build.mov(inst.regX64, rax); + build.test(byte[tmp.reg + offsetof(Table, tmcache)], 1 << intOp(inst.b)); + build.jcc(ConditionX64::NotZero, labelOp(inst.c)); // No tag method + + ScopedRegX64 tmp2{regs, SizeX64::qword}; + build.mov(tmp2.reg, qword[rState + offsetof(lua_State, global)]); + + IrCallWrapperX64 callWrap(regs, build, index); + callWrap.addArgument(SizeX64::qword, tmp); + callWrap.addArgument(SizeX64::qword, intOp(inst.b)); + callWrap.addArgument(SizeX64::qword, qword[tmp2.release() + offsetof(global_State, tmname) + intOp(inst.b) * sizeof(TString*)]); + callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaT_gettm)]); + inst.regX64 = regs.takeReg(rax); break; } case IrCmd::INT_TO_NUM: @@ -701,7 +685,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.call(rax); - inst.regX64 = regs.takeGprReg(eax); // Result of a builtin call is returned in eax + inst.regX64 = regs.takeReg(eax); // Result of a builtin call is returned in eax break; } case IrCmd::CHECK_FASTCALL_RES: @@ -714,23 +698,23 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) } case IrCmd::DO_ARITH: if (inst.c.kind == IrOpKind::VmReg) - callArithHelper(build, vmRegOp(inst.a), vmRegOp(inst.b), luauRegAddress(vmRegOp(inst.c)), TMS(intOp(inst.d))); + callArithHelper(regs, build, vmRegOp(inst.a), vmRegOp(inst.b), luauRegAddress(vmRegOp(inst.c)), TMS(intOp(inst.d))); else - callArithHelper(build, vmRegOp(inst.a), vmRegOp(inst.b), luauConstantAddress(vmConstOp(inst.c)), TMS(intOp(inst.d))); + callArithHelper(regs, build, vmRegOp(inst.a), vmRegOp(inst.b), luauConstantAddress(vmConstOp(inst.c)), TMS(intOp(inst.d))); break; case IrCmd::DO_LEN: - callLengthHelper(build, vmRegOp(inst.a), vmRegOp(inst.b)); + callLengthHelper(regs, build, vmRegOp(inst.a), vmRegOp(inst.b)); break; case IrCmd::GET_TABLE: if (inst.c.kind == IrOpKind::VmReg) { - callGetTable(build, vmRegOp(inst.b), luauRegAddress(vmRegOp(inst.c)), vmRegOp(inst.a)); + callGetTable(regs, build, vmRegOp(inst.b), luauRegAddress(vmRegOp(inst.c)), vmRegOp(inst.a)); } else if (inst.c.kind == IrOpKind::Constant) { TValue n; setnvalue(&n, uintOp(inst.c)); - callGetTable(build, vmRegOp(inst.b), build.bytes(&n, sizeof(n)), vmRegOp(inst.a)); + callGetTable(regs, build, vmRegOp(inst.b), build.bytes(&n, sizeof(n)), vmRegOp(inst.a)); } else { @@ -740,13 +724,13 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) case IrCmd::SET_TABLE: if (inst.c.kind == IrOpKind::VmReg) { - callSetTable(build, vmRegOp(inst.b), luauRegAddress(vmRegOp(inst.c)), vmRegOp(inst.a)); + callSetTable(regs, build, vmRegOp(inst.b), luauRegAddress(vmRegOp(inst.c)), vmRegOp(inst.a)); } else if (inst.c.kind == IrOpKind::Constant) { TValue n; setnvalue(&n, uintOp(inst.c)); - callSetTable(build, vmRegOp(inst.b), build.bytes(&n, sizeof(n)), vmRegOp(inst.a)); + callSetTable(regs, build, vmRegOp(inst.b), build.bytes(&n, sizeof(n)), vmRegOp(inst.a)); } else { @@ -757,13 +741,16 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) emitInstGetImportFallback(build, vmRegOp(inst.a), uintOp(inst.b)); break; case IrCmd::CONCAT: - build.mov(rArg1, rState); - build.mov(dwordReg(rArg2), uintOp(inst.b)); - build.mov(dwordReg(rArg3), vmRegOp(inst.a) + uintOp(inst.b) - 1); - build.call(qword[rNativeContext + offsetof(NativeContext, luaV_concat)]); + { + IrCallWrapperX64 callWrap(regs, build, index); + callWrap.addArgument(SizeX64::qword, rState); + callWrap.addArgument(SizeX64::dword, int32_t(uintOp(inst.b))); + callWrap.addArgument(SizeX64::dword, int32_t(vmRegOp(inst.a) + uintOp(inst.b) - 1)); + callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaV_concat)]); emitUpdateBase(build); break; + } case IrCmd::GET_UPVALUE: { ScopedRegX64 tmp1{regs, SizeX64::qword}; @@ -793,21 +780,26 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) Label next; ScopedRegX64 tmp1{regs, SizeX64::qword}; ScopedRegX64 tmp2{regs, SizeX64::qword}; - ScopedRegX64 tmp3{regs, SizeX64::xmmword}; build.mov(tmp1.reg, sClosure); build.mov(tmp2.reg, qword[tmp1.reg + offsetof(Closure, l.uprefs) + sizeof(TValue) * vmUpvalueOp(inst.a) + offsetof(TValue, value.gc)]); build.mov(tmp1.reg, qword[tmp2.reg + offsetof(UpVal, v)]); - build.vmovups(tmp3.reg, luauReg(vmRegOp(inst.b))); - build.vmovups(xmmword[tmp1.reg], tmp3.reg); - callBarrierObject(build, tmp1.reg, tmp2.reg, vmRegOp(inst.b), next); + { + ScopedRegX64 tmp3{regs, SizeX64::xmmword}; + build.vmovups(tmp3.reg, luauReg(vmRegOp(inst.b))); + build.vmovups(xmmword[tmp1.reg], tmp3.reg); + } + + tmp1.free(); + + callBarrierObject(regs, build, tmp2.release(), {}, vmRegOp(inst.b), next); build.setLabel(next); break; } case IrCmd::PREPARE_FORN: - callPrepareForN(build, vmRegOp(inst.a), vmRegOp(inst.b), vmRegOp(inst.c)); + callPrepareForN(regs, build, vmRegOp(inst.a), vmRegOp(inst.b), vmRegOp(inst.c)); break; case IrCmd::CHECK_TAG: if (inst.a.kind == IrOpKind::Inst) @@ -863,38 +855,43 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) jumpIfNodeHasNext(build, regOp(inst.a), labelOp(inst.b)); break; case IrCmd::INTERRUPT: + regs.assertAllFree(); emitInterrupt(build, uintOp(inst.a)); break; case IrCmd::CHECK_GC: { Label skip; - callCheckGc(build, -1, false, skip); + callCheckGc(regs, build, skip); build.setLabel(skip); break; } case IrCmd::BARRIER_OBJ: { Label skip; - ScopedRegX64 tmp{regs, SizeX64::qword}; - - callBarrierObject(build, tmp.reg, regOp(inst.a), vmRegOp(inst.b), skip); + callBarrierObject(regs, build, regOp(inst.a), inst.a, vmRegOp(inst.b), skip); build.setLabel(skip); break; } case IrCmd::BARRIER_TABLE_BACK: { Label skip; - - callBarrierTableFast(build, regOp(inst.a), skip); + callBarrierTableFast(regs, build, regOp(inst.a), inst.a, skip); build.setLabel(skip); break; } case IrCmd::BARRIER_TABLE_FORWARD: { Label skip; + ScopedRegX64 tmp{regs, SizeX64::qword}; + checkObjectBarrierConditions(build, tmp.reg, regOp(inst.a), vmRegOp(inst.b), skip); + + IrCallWrapperX64 callWrap(regs, build, index); + callWrap.addArgument(SizeX64::qword, rState); + callWrap.addArgument(SizeX64::qword, regOp(inst.a), inst.a); + callWrap.addArgument(SizeX64::qword, tmp); + callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaC_barriertable)]); - callBarrierTable(build, tmp.reg, regOp(inst.a), vmRegOp(inst.b), skip); build.setLabel(skip); break; } @@ -926,11 +923,12 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.cmp(tmp2.reg, qword[tmp1.reg + offsetof(UpVal, v)]); build.jcc(ConditionX64::Above, next); - if (rArg2 != tmp2.reg) - build.mov(rArg2, tmp2.reg); + tmp1.free(); - build.mov(rArg1, rState); - build.call(qword[rNativeContext + offsetof(NativeContext, luaF_close)]); + IrCallWrapperX64 callWrap(regs, build, index); + callWrap.addArgument(SizeX64::qword, rState); + callWrap.addArgument(SizeX64::qword, tmp2); + callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaF_close)]); build.setLabel(next); break; @@ -940,42 +938,53 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) break; // Fallbacks to non-IR instruction implementations - case IrCmd::LOP_SETLIST: + case IrCmd::SETLIST: { Label next; - emitInstSetList(build, next, vmRegOp(inst.b), vmRegOp(inst.c), intOp(inst.d), uintOp(inst.e)); + regs.assertAllFree(); + emitInstSetList(regs, build, next, vmRegOp(inst.b), vmRegOp(inst.c), intOp(inst.d), uintOp(inst.e)); build.setLabel(next); break; } - case IrCmd::LOP_CALL: + case IrCmd::CALL: + regs.assertAllFree(); emitInstCall(build, helpers, vmRegOp(inst.a), intOp(inst.b), intOp(inst.c)); break; - case IrCmd::LOP_RETURN: + case IrCmd::RETURN: + regs.assertAllFree(); emitInstReturn(build, helpers, vmRegOp(inst.a), intOp(inst.b)); break; - case IrCmd::LOP_FORGLOOP: + case IrCmd::FORGLOOP: + regs.assertAllFree(); emitinstForGLoop(build, vmRegOp(inst.a), intOp(inst.b), labelOp(inst.c), labelOp(inst.d)); break; - case IrCmd::LOP_FORGLOOP_FALLBACK: - emitinstForGLoopFallback(build, uintOp(inst.a), vmRegOp(inst.b), intOp(inst.c), labelOp(inst.d)); - build.jmp(labelOp(inst.e)); + case IrCmd::FORGLOOP_FALLBACK: + regs.assertAllFree(); + emitinstForGLoopFallback(build, vmRegOp(inst.a), intOp(inst.b), labelOp(inst.c)); + build.jmp(labelOp(inst.d)); break; - case IrCmd::LOP_FORGPREP_XNEXT_FALLBACK: + case IrCmd::FORGPREP_XNEXT_FALLBACK: + regs.assertAllFree(); emitInstForGPrepXnextFallback(build, uintOp(inst.a), vmRegOp(inst.b), labelOp(inst.c)); break; - case IrCmd::LOP_AND: + case IrCmd::AND: + regs.assertAllFree(); emitInstAnd(build, vmRegOp(inst.a), vmRegOp(inst.b), vmRegOp(inst.c)); break; - case IrCmd::LOP_ANDK: + case IrCmd::ANDK: + regs.assertAllFree(); emitInstAndK(build, vmRegOp(inst.a), vmRegOp(inst.b), vmConstOp(inst.c)); break; - case IrCmd::LOP_OR: + case IrCmd::OR: + regs.assertAllFree(); emitInstOr(build, vmRegOp(inst.a), vmRegOp(inst.b), vmRegOp(inst.c)); break; - case IrCmd::LOP_ORK: + case IrCmd::ORK: + regs.assertAllFree(); emitInstOrK(build, vmRegOp(inst.a), vmRegOp(inst.b), vmConstOp(inst.c)); break; - case IrCmd::LOP_COVERAGE: + case IrCmd::COVERAGE: + regs.assertAllFree(); emitInstCoverage(build, uintOp(inst.a)); break; @@ -984,12 +993,14 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); LUAU_ASSERT(inst.c.kind == IrOpKind::VmConst); + regs.assertAllFree(); emitFallback(build, data, LOP_GETGLOBAL, uintOp(inst.a)); break; case IrCmd::FALLBACK_SETGLOBAL: LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); LUAU_ASSERT(inst.c.kind == IrOpKind::VmConst); + regs.assertAllFree(); emitFallback(build, data, LOP_SETGLOBAL, uintOp(inst.a)); break; case IrCmd::FALLBACK_GETTABLEKS: @@ -997,6 +1008,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) LUAU_ASSERT(inst.c.kind == IrOpKind::VmReg); LUAU_ASSERT(inst.d.kind == IrOpKind::VmConst); + regs.assertAllFree(); emitFallback(build, data, LOP_GETTABLEKS, uintOp(inst.a)); break; case IrCmd::FALLBACK_SETTABLEKS: @@ -1004,6 +1016,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) LUAU_ASSERT(inst.c.kind == IrOpKind::VmReg); LUAU_ASSERT(inst.d.kind == IrOpKind::VmConst); + regs.assertAllFree(); emitFallback(build, data, LOP_SETTABLEKS, uintOp(inst.a)); break; case IrCmd::FALLBACK_NAMECALL: @@ -1011,32 +1024,38 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) LUAU_ASSERT(inst.c.kind == IrOpKind::VmReg); LUAU_ASSERT(inst.d.kind == IrOpKind::VmConst); + regs.assertAllFree(); emitFallback(build, data, LOP_NAMECALL, uintOp(inst.a)); break; case IrCmd::FALLBACK_PREPVARARGS: LUAU_ASSERT(inst.b.kind == IrOpKind::Constant); + regs.assertAllFree(); emitFallback(build, data, LOP_PREPVARARGS, uintOp(inst.a)); break; case IrCmd::FALLBACK_GETVARARGS: LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); LUAU_ASSERT(inst.c.kind == IrOpKind::Constant); + regs.assertAllFree(); emitFallback(build, data, LOP_GETVARARGS, uintOp(inst.a)); break; case IrCmd::FALLBACK_NEWCLOSURE: LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); LUAU_ASSERT(inst.c.kind == IrOpKind::Constant); + regs.assertAllFree(); emitFallback(build, data, LOP_NEWCLOSURE, uintOp(inst.a)); break; case IrCmd::FALLBACK_DUPCLOSURE: LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); LUAU_ASSERT(inst.c.kind == IrOpKind::VmConst); + regs.assertAllFree(); emitFallback(build, data, LOP_DUPCLOSURE, uintOp(inst.a)); break; case IrCmd::FALLBACK_FORGPREP: + regs.assertAllFree(); emitFallback(build, data, LOP_FORGPREP, uintOp(inst.a)); break; default: diff --git a/CodeGen/src/IrLoweringX64.h b/CodeGen/src/IrLoweringX64.h index c8ebd1f18..ecaa6a1d5 100644 --- a/CodeGen/src/IrLoweringX64.h +++ b/CodeGen/src/IrLoweringX64.h @@ -3,8 +3,7 @@ #include "Luau/AssemblyBuilderX64.h" #include "Luau/IrData.h" - -#include "IrRegAllocX64.h" +#include "Luau/IrRegAllocX64.h" #include diff --git a/CodeGen/src/IrRegAllocA64.cpp b/CodeGen/src/IrRegAllocA64.cpp new file mode 100644 index 000000000..3609c8e25 --- /dev/null +++ b/CodeGen/src/IrRegAllocA64.cpp @@ -0,0 +1,174 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "IrRegAllocA64.h" + +#ifdef _MSC_VER +#include +#endif + +namespace Luau +{ +namespace CodeGen +{ +namespace A64 +{ + +inline int setBit(uint32_t n) +{ + LUAU_ASSERT(n); + +#ifdef _MSC_VER + unsigned long rl; + _BitScanReverse(&rl, n); + return int(rl); +#else + return 31 - __builtin_clz(n); +#endif +} + +IrRegAllocA64::IrRegAllocA64(IrFunction& function, std::initializer_list> regs) + : function(function) +{ + for (auto& p : regs) + { + LUAU_ASSERT(p.first.kind == p.second.kind && p.first.index <= p.second.index); + + Set& set = getSet(p.first.kind); + + for (int i = p.first.index; i <= p.second.index; ++i) + set.base |= 1u << i; + } + + gpr.free = gpr.base; + simd.free = simd.base; +} + +RegisterA64 IrRegAllocA64::allocReg(KindA64 kind) +{ + Set& set = getSet(kind); + + if (set.free == 0) + { + LUAU_ASSERT(!"Out of registers to allocate"); + return noreg; + } + + int index = setBit(set.free); + set.free &= ~(1u << index); + + return {kind, uint8_t(index)}; +} + +RegisterA64 IrRegAllocA64::allocTemp(KindA64 kind) +{ + Set& set = getSet(kind); + + if (set.free == 0) + { + LUAU_ASSERT(!"Out of registers to allocate"); + return noreg; + } + + int index = setBit(set.free); + + set.free &= ~(1u << index); + set.temp |= 1u << index; + + return {kind, uint8_t(index)}; +} + +RegisterA64 IrRegAllocA64::allocReuse(KindA64 kind, uint32_t index, std::initializer_list oprefs) +{ + for (IrOp op : oprefs) + { + if (op.kind != IrOpKind::Inst) + continue; + + IrInst& source = function.instructions[op.index]; + + if (source.lastUse == index && !source.reusedReg) + { + LUAU_ASSERT(source.regA64.kind == kind); + + source.reusedReg = true; + return source.regA64; + } + } + + return allocReg(kind); +} + +void IrRegAllocA64::freeReg(RegisterA64 reg) +{ + Set& set = getSet(reg.kind); + + LUAU_ASSERT((set.base & (1u << reg.index)) != 0); + LUAU_ASSERT((set.free & (1u << reg.index)) == 0); + set.free |= 1u << reg.index; +} + +void IrRegAllocA64::freeLastUseReg(IrInst& target, uint32_t index) +{ + if (target.lastUse == index && !target.reusedReg) + { + // Register might have already been freed if it had multiple uses inside a single instruction + if (target.regA64 == noreg) + return; + + freeReg(target.regA64); + target.regA64 = noreg; + } +} + +void IrRegAllocA64::freeLastUseRegs(const IrInst& inst, uint32_t index) +{ + auto checkOp = [this, index](IrOp op) { + if (op.kind == IrOpKind::Inst) + freeLastUseReg(function.instructions[op.index], index); + }; + + checkOp(inst.a); + checkOp(inst.b); + checkOp(inst.c); + checkOp(inst.d); + checkOp(inst.e); + checkOp(inst.f); +} + +void IrRegAllocA64::freeTempRegs() +{ + LUAU_ASSERT((gpr.free & gpr.temp) == 0); + gpr.free |= gpr.temp; + gpr.temp = 0; + + LUAU_ASSERT((simd.free & simd.temp) == 0); + simd.free |= simd.temp; + simd.temp = 0; +} + +void IrRegAllocA64::assertAllFree() const +{ + LUAU_ASSERT(gpr.free == gpr.base); + LUAU_ASSERT(simd.free == simd.base); +} + +IrRegAllocA64::Set& IrRegAllocA64::getSet(KindA64 kind) +{ + switch (kind) + { + case KindA64::x: + case KindA64::w: + return gpr; + + case KindA64::d: + case KindA64::q: + return simd; + + default: + LUAU_ASSERT(!"Unexpected register kind"); + LUAU_UNREACHABLE(); + } +} + +} // namespace A64 +} // namespace CodeGen +} // namespace Luau diff --git a/CodeGen/src/IrRegAllocA64.h b/CodeGen/src/IrRegAllocA64.h new file mode 100644 index 000000000..2ed0787aa --- /dev/null +++ b/CodeGen/src/IrRegAllocA64.h @@ -0,0 +1,55 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/IrData.h" +#include "Luau/RegisterA64.h" + +#include +#include + +namespace Luau +{ +namespace CodeGen +{ +namespace A64 +{ + +struct IrRegAllocA64 +{ + IrRegAllocA64(IrFunction& function, std::initializer_list> regs); + + RegisterA64 allocReg(KindA64 kind); + RegisterA64 allocTemp(KindA64 kind); + RegisterA64 allocReuse(KindA64 kind, uint32_t index, std::initializer_list oprefs); + + void freeReg(RegisterA64 reg); + + void freeLastUseReg(IrInst& target, uint32_t index); + void freeLastUseRegs(const IrInst& inst, uint32_t index); + + void freeTempRegs(); + + void assertAllFree() const; + + IrFunction& function; + + struct Set + { + // which registers are in the set that the allocator manages (initialized at construction) + uint32_t base = 0; + + // which subset of initial set is free + uint32_t free = 0; + + // which subset of initial set is allocated as temporary + uint32_t temp = 0; + }; + + Set gpr, simd; + + Set& getSet(KindA64 kind); +}; + +} // namespace A64 +} // namespace CodeGen +} // namespace Luau diff --git a/CodeGen/src/IrRegAllocX64.cpp b/CodeGen/src/IrRegAllocX64.cpp index c527d033f..eeb6cfe69 100644 --- a/CodeGen/src/IrRegAllocX64.cpp +++ b/CodeGen/src/IrRegAllocX64.cpp @@ -1,19 +1,5 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "IrRegAllocX64.h" - -#include "Luau/CodeGen.h" -#include "Luau/DenseHash.h" -#include "Luau/IrAnalysis.h" -#include "Luau/IrDump.h" -#include "Luau/IrUtils.h" - -#include "EmitCommonX64.h" -#include "EmitInstructionX64.h" -#include "NativeState.h" - -#include "lstate.h" - -#include +#include "Luau/IrRegAllocX64.h" namespace Luau { @@ -108,13 +94,21 @@ RegisterX64 IrRegAllocX64::allocXmmRegOrReuse(uint32_t index, std::initializer_l return allocXmmReg(); } -RegisterX64 IrRegAllocX64::takeGprReg(RegisterX64 reg) +RegisterX64 IrRegAllocX64::takeReg(RegisterX64 reg) { // In a more advanced register allocator, this would require a spill for the current register user // But at the current stage we don't have register live ranges intersecting forced register uses - LUAU_ASSERT(freeGprMap[reg.index]); + if (reg.size == SizeX64::xmmword) + { + LUAU_ASSERT(freeXmmMap[reg.index]); + freeXmmMap[reg.index] = false; + } + else + { + LUAU_ASSERT(freeGprMap[reg.index]); + freeGprMap[reg.index] = false; + } - freeGprMap[reg.index] = false; return reg; } @@ -134,7 +128,7 @@ void IrRegAllocX64::freeReg(RegisterX64 reg) void IrRegAllocX64::freeLastUseReg(IrInst& target, uint32_t index) { - if (target.lastUse == index && !target.reusedReg) + if (isLastUseReg(target, index)) { // Register might have already been freed if it had multiple uses inside a single instruction if (target.regX64 == noreg) @@ -160,6 +154,35 @@ void IrRegAllocX64::freeLastUseRegs(const IrInst& inst, uint32_t index) checkOp(inst.f); } +bool IrRegAllocX64::isLastUseReg(const IrInst& target, uint32_t index) const +{ + return target.lastUse == index && !target.reusedReg; +} + +bool IrRegAllocX64::shouldFreeGpr(RegisterX64 reg) const +{ + if (reg == noreg) + return false; + + LUAU_ASSERT(reg.size != SizeX64::xmmword); + + for (RegisterX64 gpr : kGprAllocOrder) + { + if (reg.index == gpr.index) + return true; + } + + return false; +} + +void IrRegAllocX64::assertFree(RegisterX64 reg) const +{ + if (reg.size == SizeX64::xmmword) + LUAU_ASSERT(freeXmmMap[reg.index]); + else + LUAU_ASSERT(freeGprMap[reg.index]); +} + void IrRegAllocX64::assertAllFree() const { for (RegisterX64 reg : kGprAllocOrder) @@ -211,6 +234,13 @@ void ScopedRegX64::free() reg = noreg; } +RegisterX64 ScopedRegX64::release() +{ + RegisterX64 tmp = reg; + reg = noreg; + return tmp; +} + } // namespace X64 } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/IrTranslateBuiltins.cpp b/CodeGen/src/IrTranslateBuiltins.cpp index cb8e41482..2955aaffb 100644 --- a/CodeGen/src/IrTranslateBuiltins.cpp +++ b/CodeGen/src/IrTranslateBuiltins.cpp @@ -6,7 +6,6 @@ #include "lstate.h" -// TODO: should be possible to handle fastcalls in contexts where nresults is -1 by adding the adjustment instruction // TODO: when nresults is less than our actual result count, we can skip computing/writing unused results namespace Luau @@ -26,8 +25,8 @@ BuiltinImplResult translateBuiltinNumberToNumber( build.loadAndCheckTag(build.vmReg(arg), LUA_TNUMBER, fallback); build.inst(IrCmd::FASTCALL, build.constUint(bfid), build.vmReg(ra), build.vmReg(arg), args, build.constInt(nparams), build.constInt(nresults)); - // TODO: tag update might not be required, we place it here now because FASTCALL is not modeled in constant propagation yet - build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); + if (ra != arg) + build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); return {BuiltinImplType::UsesFallback, 1}; } @@ -43,8 +42,8 @@ BuiltinImplResult translateBuiltin2NumberToNumber( build.loadAndCheckTag(args, LUA_TNUMBER, fallback); build.inst(IrCmd::FASTCALL, build.constUint(bfid), build.vmReg(ra), build.vmReg(arg), args, build.constInt(nparams), build.constInt(nresults)); - // TODO:tag update might not be required, we place it here now because FASTCALL is not modeled in constant propagation yet - build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); + if (ra != arg) + build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); return {BuiltinImplType::UsesFallback, 1}; } @@ -59,8 +58,9 @@ BuiltinImplResult translateBuiltinNumberTo2Number( build.loadAndCheckTag(build.vmReg(arg), LUA_TNUMBER, fallback); build.inst(IrCmd::FASTCALL, build.constUint(bfid), build.vmReg(ra), build.vmReg(arg), args, build.constInt(nparams), build.constInt(nresults)); - // TODO: some tag updates might not be required, we place them here now because FASTCALL is not modeled in constant propagation yet - build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); + if (ra != arg) + build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); + build.inst(IrCmd::STORE_TAG, build.vmReg(ra + 1), build.constTag(LUA_TNUMBER)); return {BuiltinImplType::UsesFallback, 2}; @@ -131,8 +131,8 @@ BuiltinImplResult translateBuiltinMathLog( build.inst(IrCmd::FASTCALL, build.constUint(bfid), build.vmReg(ra), build.vmReg(arg), args, build.constInt(nparams), build.constInt(nresults)); - // TODO: tag update might not be required, we place it here now because FASTCALL is not modeled in constant propagation yet - build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); + if (ra != arg) + build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); return {BuiltinImplType::UsesFallback, 1}; } @@ -210,6 +210,44 @@ BuiltinImplResult translateBuiltinMathClamp(IrBuilder& build, int nparams, int r return {BuiltinImplType::UsesFallback, 1}; } +BuiltinImplResult translateBuiltinMathUnary(IrBuilder& build, IrCmd cmd, int nparams, int ra, int arg, int nresults, IrOp fallback) +{ + if (nparams < 1 || nresults > 1) + return {BuiltinImplType::None, -1}; + + build.loadAndCheckTag(build.vmReg(arg), LUA_TNUMBER, fallback); + + IrOp varg = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(arg)); + IrOp result = build.inst(cmd, varg); + + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), result); + + if (ra != arg) + build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); + + return {BuiltinImplType::UsesFallback, 1}; +} + +BuiltinImplResult translateBuiltinMathBinary(IrBuilder& build, IrCmd cmd, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) +{ + if (nparams < 2 || nresults > 1) + return {BuiltinImplType::None, -1}; + + build.loadAndCheckTag(build.vmReg(arg), LUA_TNUMBER, fallback); + build.loadAndCheckTag(args, LUA_TNUMBER, fallback); + + IrOp lhs = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(arg)); + IrOp rhs = build.inst(IrCmd::LOAD_DOUBLE, args); + IrOp result = build.inst(cmd, lhs, rhs); + + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), result); + + if (ra != arg) + build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); + + return {BuiltinImplType::UsesFallback, 1}; +} + BuiltinImplResult translateBuiltinType(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) { if (nparams < 1 || nresults > 1) @@ -218,7 +256,6 @@ BuiltinImplResult translateBuiltinType(IrBuilder& build, int nparams, int ra, in build.inst( IrCmd::FASTCALL, build.constUint(LBF_TYPE), build.vmReg(ra), build.vmReg(arg), args, build.constInt(nparams), build.constInt(nresults)); - // TODO: tag update might not be required, we place it here now because FASTCALL is not modeled in constant propagation yet build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TSTRING)); return {BuiltinImplType::UsesFallback, 1}; @@ -232,7 +269,6 @@ BuiltinImplResult translateBuiltinTypeof(IrBuilder& build, int nparams, int ra, build.inst( IrCmd::FASTCALL, build.constUint(LBF_TYPEOF), build.vmReg(ra), build.vmReg(arg), args, build.constInt(nparams), build.constInt(nresults)); - // TODO: tag update might not be required, we place it here now because FASTCALL is not modeled in constant propagation yet build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TSTRING)); return {BuiltinImplType::UsesFallback, 1}; @@ -261,9 +297,17 @@ BuiltinImplResult translateBuiltin(IrBuilder& build, int bfid, int ra, int arg, case LBF_MATH_CLAMP: return translateBuiltinMathClamp(build, nparams, ra, arg, args, nresults, fallback); case LBF_MATH_FLOOR: + return translateBuiltinMathUnary(build, IrCmd::FLOOR_NUM, nparams, ra, arg, nresults, fallback); case LBF_MATH_CEIL: + return translateBuiltinMathUnary(build, IrCmd::CEIL_NUM, nparams, ra, arg, nresults, fallback); case LBF_MATH_SQRT: + return translateBuiltinMathUnary(build, IrCmd::SQRT_NUM, nparams, ra, arg, nresults, fallback); case LBF_MATH_ABS: + return translateBuiltinMathUnary(build, IrCmd::ABS_NUM, nparams, ra, arg, nresults, fallback); + case LBF_MATH_ROUND: + return translateBuiltinMathUnary(build, IrCmd::ROUND_NUM, nparams, ra, arg, nresults, fallback); + case LBF_MATH_POW: + return translateBuiltinMathBinary(build, IrCmd::POW_NUM, nparams, ra, arg, args, nresults, fallback); case LBF_MATH_EXP: case LBF_MATH_ASIN: case LBF_MATH_SIN: @@ -275,11 +319,9 @@ BuiltinImplResult translateBuiltin(IrBuilder& build, int bfid, int ra, int arg, case LBF_MATH_TAN: case LBF_MATH_TANH: case LBF_MATH_LOG10: - case LBF_MATH_ROUND: case LBF_MATH_SIGN: return translateBuiltinNumberToNumber(build, LuauBuiltinFunction(bfid), nparams, ra, arg, args, nresults, fallback); case LBF_MATH_FMOD: - case LBF_MATH_POW: case LBF_MATH_ATAN2: case LBF_MATH_LDEXP: return translateBuiltin2NumberToNumber(build, LuauBuiltinFunction(bfid), nparams, ra, arg, args, nresults, fallback); diff --git a/CodeGen/src/IrTranslation.cpp b/CodeGen/src/IrTranslation.cpp index d90841ce3..e366888e6 100644 --- a/CodeGen/src/IrTranslation.cpp +++ b/CodeGen/src/IrTranslation.cpp @@ -296,46 +296,60 @@ static void translateInstBinaryNumeric(IrBuilder& build, int ra, int rb, int rc, IrOp vb = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(rb)); IrOp vc; + IrOp result; + if (opc.kind == IrOpKind::VmConst) { LUAU_ASSERT(build.function.proto); TValue protok = build.function.proto->k[opc.index]; LUAU_ASSERT(protok.tt == LUA_TNUMBER); - vc = build.constDouble(protok.value.n); + + // VM has special cases for exponentiation with constants + if (tm == TM_POW && protok.value.n == 0.5) + result = build.inst(IrCmd::SQRT_NUM, vb); + else if (tm == TM_POW && protok.value.n == 2.0) + result = build.inst(IrCmd::MUL_NUM, vb, vb); + else if (tm == TM_POW && protok.value.n == 3.0) + result = build.inst(IrCmd::MUL_NUM, vb, build.inst(IrCmd::MUL_NUM, vb, vb)); + else + vc = build.constDouble(protok.value.n); } else { vc = build.inst(IrCmd::LOAD_DOUBLE, opc); } - IrOp va; - - switch (tm) + if (result.kind == IrOpKind::None) { - case TM_ADD: - va = build.inst(IrCmd::ADD_NUM, vb, vc); - break; - case TM_SUB: - va = build.inst(IrCmd::SUB_NUM, vb, vc); - break; - case TM_MUL: - va = build.inst(IrCmd::MUL_NUM, vb, vc); - break; - case TM_DIV: - va = build.inst(IrCmd::DIV_NUM, vb, vc); - break; - case TM_MOD: - va = build.inst(IrCmd::MOD_NUM, vb, vc); - break; - case TM_POW: - va = build.inst(IrCmd::POW_NUM, vb, vc); - break; - default: - LUAU_ASSERT(!"unsupported binary op"); + LUAU_ASSERT(vc.kind != IrOpKind::None); + + switch (tm) + { + case TM_ADD: + result = build.inst(IrCmd::ADD_NUM, vb, vc); + break; + case TM_SUB: + result = build.inst(IrCmd::SUB_NUM, vb, vc); + break; + case TM_MUL: + result = build.inst(IrCmd::MUL_NUM, vb, vc); + break; + case TM_DIV: + result = build.inst(IrCmd::DIV_NUM, vb, vc); + break; + case TM_MOD: + result = build.inst(IrCmd::MOD_NUM, vb, vc); + break; + case TM_POW: + result = build.inst(IrCmd::POW_NUM, vb, vc); + break; + default: + LUAU_ASSERT(!"unsupported binary op"); + } } - build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), va); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), result); if (ra != rb && ra != rc) // TODO: optimization should handle second check, but we'll test this later build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); @@ -638,7 +652,7 @@ void translateInstForGPrepNext(IrBuilder& build, const Instruction* pc, int pcpo build.inst(IrCmd::JUMP, target); build.beginBlock(fallback); - build.inst(IrCmd::LOP_FORGPREP_XNEXT_FALLBACK, build.constUint(pcpos), build.vmReg(ra), target); + build.inst(IrCmd::FORGPREP_XNEXT_FALLBACK, build.constUint(pcpos), build.vmReg(ra), target); } void translateInstForGPrepInext(IrBuilder& build, const Instruction* pc, int pcpos) @@ -670,7 +684,7 @@ void translateInstForGPrepInext(IrBuilder& build, const Instruction* pc, int pcp build.inst(IrCmd::JUMP, target); build.beginBlock(fallback); - build.inst(IrCmd::LOP_FORGPREP_XNEXT_FALLBACK, build.constUint(pcpos), build.vmReg(ra), target); + build.inst(IrCmd::FORGPREP_XNEXT_FALLBACK, build.constUint(pcpos), build.vmReg(ra), target); } void translateInstForGLoopIpairs(IrBuilder& build, const Instruction* pc, int pcpos) @@ -721,7 +735,8 @@ void translateInstForGLoopIpairs(IrBuilder& build, const Instruction* pc, int pc build.inst(IrCmd::JUMP, loopRepeat); build.beginBlock(fallback); - build.inst(IrCmd::LOP_FORGLOOP_FALLBACK, build.constUint(pcpos), build.vmReg(ra), build.constInt(int(pc[1])), loopRepeat, loopExit); + build.inst(IrCmd::SET_SAVEDPC, build.constUint(pcpos + 1)); + build.inst(IrCmd::FORGLOOP_FALLBACK, build.vmReg(ra), build.constInt(int(pc[1])), loopRepeat, loopExit); // Fallthrough in original bytecode is implicit, so we start next internal block here if (build.isInternalBlock(loopExit)) diff --git a/CodeGen/src/IrUtils.cpp b/CodeGen/src/IrUtils.cpp index b28ce596e..45e2bae09 100644 --- a/CodeGen/src/IrUtils.cpp +++ b/CodeGen/src/IrUtils.cpp @@ -320,6 +320,26 @@ void foldConstants(IrBuilder& build, IrFunction& function, IrBlock& block, uint3 if (inst.a.kind == IrOpKind::Constant) substitute(function, inst, build.constDouble(-function.doubleOp(inst.a))); break; + case IrCmd::FLOOR_NUM: + if (inst.a.kind == IrOpKind::Constant) + substitute(function, inst, build.constDouble(floor(function.doubleOp(inst.a)))); + break; + case IrCmd::CEIL_NUM: + if (inst.a.kind == IrOpKind::Constant) + substitute(function, inst, build.constDouble(ceil(function.doubleOp(inst.a)))); + break; + case IrCmd::ROUND_NUM: + if (inst.a.kind == IrOpKind::Constant) + substitute(function, inst, build.constDouble(round(function.doubleOp(inst.a)))); + break; + case IrCmd::SQRT_NUM: + if (inst.a.kind == IrOpKind::Constant) + substitute(function, inst, build.constDouble(sqrt(function.doubleOp(inst.a)))); + break; + case IrCmd::ABS_NUM: + if (inst.a.kind == IrOpKind::Constant) + substitute(function, inst, build.constDouble(fabs(function.doubleOp(inst.a)))); + break; case IrCmd::NOT_ANY: if (inst.a.kind == IrOpKind::Constant) { diff --git a/CodeGen/src/NativeState.cpp b/CodeGen/src/NativeState.cpp index f1497890c..ddc9c03d1 100644 --- a/CodeGen/src/NativeState.cpp +++ b/CodeGen/src/NativeState.cpp @@ -109,6 +109,8 @@ void initHelperFunctions(NativeState& data) data.context.forgPrepXnextFallback = forgPrepXnextFallback; data.context.callProlog = callProlog; data.context.callEpilogC = callEpilogC; + + data.context.callFallback = callFallback; data.context.returnFallback = returnFallback; } diff --git a/CodeGen/src/NativeState.h b/CodeGen/src/NativeState.h index 6d8331896..2d97e63ca 100644 --- a/CodeGen/src/NativeState.h +++ b/CodeGen/src/NativeState.h @@ -101,7 +101,9 @@ struct NativeContext void (*forgPrepXnextFallback)(lua_State* L, TValue* ra, int pc) = nullptr; Closure* (*callProlog)(lua_State* L, TValue* ra, StkId argtop, int nresults) = nullptr; void (*callEpilogC)(lua_State* L, int nresults, int n) = nullptr; - const Instruction* (*returnFallback)(lua_State* L, StkId ra, int n) = nullptr; + + Closure* (*callFallback)(lua_State* L, StkId ra, StkId argtop, int nresults) = nullptr; + Closure* (*returnFallback)(lua_State* L, StkId ra, int n) = nullptr; // Opcode fallbacks, implemented in C NativeFallback fallback[LOP__COUNT] = {}; diff --git a/CodeGen/src/OptimizeConstProp.cpp b/CodeGen/src/OptimizeConstProp.cpp index 672364764..f767f5496 100644 --- a/CodeGen/src/OptimizeConstProp.cpp +++ b/CodeGen/src/OptimizeConstProp.cpp @@ -503,10 +503,10 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& } } break; - case IrCmd::LOP_AND: - case IrCmd::LOP_ANDK: - case IrCmd::LOP_OR: - case IrCmd::LOP_ORK: + case IrCmd::AND: + case IrCmd::ANDK: + case IrCmd::OR: + case IrCmd::ORK: state.invalidate(inst.a); break; case IrCmd::FASTCALL: @@ -533,6 +533,11 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& case IrCmd::MIN_NUM: case IrCmd::MAX_NUM: case IrCmd::UNM_NUM: + case IrCmd::FLOOR_NUM: + case IrCmd::CEIL_NUM: + case IrCmd::ROUND_NUM: + case IrCmd::SQRT_NUM: + case IrCmd::ABS_NUM: case IrCmd::NOT_ANY: case IrCmd::JUMP: case IrCmd::JUMP_EQ_POINTER: @@ -547,10 +552,10 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& case IrCmd::CHECK_SLOT_MATCH: case IrCmd::CHECK_NODE_NO_NEXT: case IrCmd::BARRIER_TABLE_BACK: - case IrCmd::LOP_RETURN: - case IrCmd::LOP_COVERAGE: + case IrCmd::RETURN: + case IrCmd::COVERAGE: case IrCmd::SET_UPVALUE: - case IrCmd::LOP_SETLIST: // We don't track table state that this can invalidate + case IrCmd::SETLIST: // We don't track table state that this can invalidate case IrCmd::SET_SAVEDPC: // TODO: we may be able to remove some updates to PC case IrCmd::CLOSE_UPVALS: // Doesn't change memory that we track case IrCmd::CAPTURE: @@ -599,18 +604,18 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& case IrCmd::INTERRUPT: state.invalidateUserCall(); break; - case IrCmd::LOP_CALL: + case IrCmd::CALL: state.invalidateRegistersFrom(inst.a.index); state.invalidateUserCall(); break; - case IrCmd::LOP_FORGLOOP: + case IrCmd::FORGLOOP: state.invalidateRegistersFrom(inst.a.index + 2); // Rn and Rn+1 are not modified break; - case IrCmd::LOP_FORGLOOP_FALLBACK: - state.invalidateRegistersFrom(inst.b.index + 2); // Rn and Rn+1 are not modified + case IrCmd::FORGLOOP_FALLBACK: + state.invalidateRegistersFrom(inst.a.index + 2); // Rn and Rn+1 are not modified state.invalidateUserCall(); break; - case IrCmd::LOP_FORGPREP_XNEXT_FALLBACK: + case IrCmd::FORGPREP_XNEXT_FALLBACK: // This fallback only conditionally throws an exception break; case IrCmd::FALLBACK_GETGLOBAL: diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index 03f4b3e69..9478404a0 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -25,8 +25,6 @@ LUAU_FASTINTVARIABLE(LuauCompileInlineThreshold, 25) LUAU_FASTINTVARIABLE(LuauCompileInlineThresholdMaxBoost, 300) LUAU_FASTINTVARIABLE(LuauCompileInlineDepth, 5) -LUAU_FASTFLAGVARIABLE(LuauCompileBuiltinArity, false) - namespace Luau { @@ -295,7 +293,7 @@ struct Compiler // handles builtin calls that can't be constant-folded but are known to return one value // note: optimizationLevel check is technically redundant but it's important that we never optimize based on builtins in O1 - if (FFlag::LuauCompileBuiltinArity && options.optimizationLevel >= 2) + if (options.optimizationLevel >= 2) if (int* bfid = builtins.find(expr)) return getBuiltinInfo(*bfid).results != 1; @@ -766,7 +764,7 @@ struct Compiler { if (!isExprMultRet(expr->args.data[expr->args.size - 1])) return compileExprFastcallN(expr, target, targetCount, targetTop, multRet, regs, bfid); - else if (FFlag::LuauCompileBuiltinArity && options.optimizationLevel >= 2 && int(expr->args.size) == getBuiltinInfo(bfid).params) + else if (options.optimizationLevel >= 2 && int(expr->args.size) == getBuiltinInfo(bfid).params) return compileExprFastcallN(expr, target, targetCount, targetTop, multRet, regs, bfid); } diff --git a/Sources.cmake b/Sources.cmake index 3f32aab83..3508ec39e 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -65,8 +65,10 @@ target_sources(Luau.CodeGen PRIVATE CodeGen/include/Luau/ConditionX64.h CodeGen/include/Luau/IrAnalysis.h CodeGen/include/Luau/IrBuilder.h + CodeGen/include/Luau/IrCallWrapperX64.h CodeGen/include/Luau/IrDump.h CodeGen/include/Luau/IrData.h + CodeGen/include/Luau/IrRegAllocX64.h CodeGen/include/Luau/IrUtils.h CodeGen/include/Luau/Label.h CodeGen/include/Luau/OperandX64.h @@ -94,9 +96,11 @@ target_sources(Luau.CodeGen PRIVATE CodeGen/src/Fallbacks.cpp CodeGen/src/IrAnalysis.cpp CodeGen/src/IrBuilder.cpp + CodeGen/src/IrCallWrapperX64.cpp CodeGen/src/IrDump.cpp CodeGen/src/IrLoweringA64.cpp CodeGen/src/IrLoweringX64.cpp + CodeGen/src/IrRegAllocA64.cpp CodeGen/src/IrRegAllocX64.cpp CodeGen/src/IrTranslateBuiltins.cpp CodeGen/src/IrTranslation.cpp @@ -122,7 +126,7 @@ target_sources(Luau.CodeGen PRIVATE CodeGen/src/FallbacksProlog.h CodeGen/src/IrLoweringA64.h CodeGen/src/IrLoweringX64.h - CodeGen/src/IrRegAllocX64.h + CodeGen/src/IrRegAllocA64.h CodeGen/src/IrTranslateBuiltins.h CodeGen/src/IrTranslation.h CodeGen/src/NativeState.h @@ -342,6 +346,7 @@ if(TARGET Luau.UnitTest) tests/Fixture.h tests/IostreamOptional.h tests/ScopedFlags.h + tests/AssemblyBuilderA64.test.cpp tests/AssemblyBuilderX64.test.cpp tests/AstJsonEncoder.test.cpp tests/AstQuery.test.cpp @@ -358,6 +363,7 @@ if(TARGET Luau.UnitTest) tests/Error.test.cpp tests/Frontend.test.cpp tests/IrBuilder.test.cpp + tests/IrCallWrapperX64.test.cpp tests/JsonEmitter.test.cpp tests/Lexer.test.cpp tests/Linter.test.cpp diff --git a/VM/src/lbuiltins.cpp b/VM/src/lbuiltins.cpp index 3c669bff9..e0dc8a38f 100644 --- a/VM/src/lbuiltins.cpp +++ b/VM/src/lbuiltins.cpp @@ -23,8 +23,6 @@ #endif #endif -LUAU_FASTFLAGVARIABLE(LuauBuiltinSSE41, false) - // luauF functions implement FASTCALL instruction that performs a direct execution of some builtin functions from the VM // The rule of thumb is that FASTCALL functions can not call user code, yield, fail, or reallocate stack. // If types of the arguments mismatch, luauF_* needs to return -1 and the execution will fall back to the usual call path @@ -105,9 +103,7 @@ static int luauF_atan(lua_State* L, StkId res, TValue* arg0, int nresults, StkId return -1; } -// TODO: LUAU_NOINLINE can be removed with LuauBuiltinSSE41 LUAU_FASTMATH_BEGIN -LUAU_NOINLINE static int luauF_ceil(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) { if (nparams >= 1 && nresults <= 1 && ttisnumber(arg0)) @@ -170,9 +166,7 @@ static int luauF_exp(lua_State* L, StkId res, TValue* arg0, int nresults, StkId return -1; } -// TODO: LUAU_NOINLINE can be removed with LuauBuiltinSSE41 LUAU_FASTMATH_BEGIN -LUAU_NOINLINE static int luauF_floor(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) { if (nparams >= 1 && nresults <= 1 && ttisnumber(arg0)) @@ -949,9 +943,7 @@ static int luauF_sign(lua_State* L, StkId res, TValue* arg0, int nresults, StkId return -1; } -// TODO: LUAU_NOINLINE can be removed with LuauBuiltinSSE41 LUAU_FASTMATH_BEGIN -LUAU_NOINLINE static int luauF_round(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) { if (nparams >= 1 && nresults <= 1 && ttisnumber(arg0)) @@ -1271,9 +1263,6 @@ LUAU_TARGET_SSE41 inline double roundsd_sse41(double v) LUAU_TARGET_SSE41 static int luauF_floor_sse41(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) { - if (!FFlag::LuauBuiltinSSE41) - return luauF_floor(L, res, arg0, nresults, args, nparams); - if (nparams >= 1 && nresults <= 1 && ttisnumber(arg0)) { double a1 = nvalue(arg0); @@ -1286,9 +1275,6 @@ LUAU_TARGET_SSE41 static int luauF_floor_sse41(lua_State* L, StkId res, TValue* LUAU_TARGET_SSE41 static int luauF_ceil_sse41(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) { - if (!FFlag::LuauBuiltinSSE41) - return luauF_ceil(L, res, arg0, nresults, args, nparams); - if (nparams >= 1 && nresults <= 1 && ttisnumber(arg0)) { double a1 = nvalue(arg0); @@ -1301,9 +1287,6 @@ LUAU_TARGET_SSE41 static int luauF_ceil_sse41(lua_State* L, StkId res, TValue* a LUAU_TARGET_SSE41 static int luauF_round_sse41(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) { - if (!FFlag::LuauBuiltinSSE41) - return luauF_round(L, res, arg0, nresults, args, nparams); - if (nparams >= 1 && nresults <= 1 && ttisnumber(arg0)) { double a1 = nvalue(arg0); diff --git a/fuzz/format.cpp b/fuzz/format.cpp index 3ad3912f3..4b943bf1b 100644 --- a/fuzz/format.cpp +++ b/fuzz/format.cpp @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Common.h" +#include #include #include diff --git a/fuzz/linter.cpp b/fuzz/linter.cpp index 66ca5bb14..854c63277 100644 --- a/fuzz/linter.cpp +++ b/fuzz/linter.cpp @@ -3,10 +3,10 @@ #include "Luau/BuiltinDefinitions.h" #include "Luau/Common.h" +#include "Luau/Frontend.h" #include "Luau/Linter.h" #include "Luau/ModuleResolver.h" #include "Luau/Parser.h" -#include "Luau/TypeInfer.h" extern "C" int LLVMFuzzerTestOneInput(const uint8_t* Data, size_t Size) { @@ -18,18 +18,17 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t* Data, size_t Size) Luau::ParseResult parseResult = Luau::Parser::parse(reinterpret_cast(Data), Size, names, allocator, options); // "static" here is to accelerate fuzzing process by only creating and populating the type environment once - static Luau::NullModuleResolver moduleResolver; - static Luau::InternalErrorReporter iceHandler; - static Luau::TypeChecker sharedEnv(&moduleResolver, &iceHandler); - static int once = (Luau::registerBuiltinGlobals(sharedEnv), 1); + static Luau::NullFileResolver fileResolver; + static Luau::NullConfigResolver configResolver; + static Luau::Frontend frontend{&fileResolver, &configResolver}; + static int once = (Luau::registerBuiltinGlobals(frontend), 1); (void)once; - static int once2 = (Luau::freeze(sharedEnv.globalTypes), 1); + static int once2 = (Luau::freeze(frontend.globals.globalTypes), 1); (void)once2; if (parseResult.errors.empty()) { - Luau::TypeChecker typeck(&moduleResolver, &iceHandler); - typeck.globalScope = sharedEnv.globalScope; + Luau::TypeChecker typeck(frontend.globals.globalScope, &frontend.moduleResolver, frontend.builtinTypes, &frontend.iceHandler); Luau::LintOptions lintOptions; lintOptions.warningMask = ~0ull; diff --git a/fuzz/proto.cpp b/fuzz/proto.cpp index c94f0889b..ffeb49195 100644 --- a/fuzz/proto.cpp +++ b/fuzz/proto.cpp @@ -261,8 +261,8 @@ DEFINE_PROTO_FUZZER(const luau::ModuleSet& message) { static FuzzFileResolver fileResolver; static FuzzConfigResolver configResolver; - static Luau::FrontendOptions options{true, true}; - static Luau::Frontend frontend(&fileResolver, &configResolver, options); + static Luau::FrontendOptions defaultOptions{/*retainFullTypeGraphs*/ true, /*forAutocomplete*/ false, /*runLintChecks*/ kFuzzLinter}; + static Luau::Frontend frontend(&fileResolver, &configResolver, defaultOptions); static int once = (setupFrontend(frontend), 0); (void)once; @@ -285,16 +285,12 @@ DEFINE_PROTO_FUZZER(const luau::ModuleSet& message) try { - Luau::CheckResult result = frontend.check(name, std::nullopt); - - // lint (note that we need access to types so we need to do this with typeck in scope) - if (kFuzzLinter && result.errors.empty()) - frontend.lint(name, std::nullopt); + frontend.check(name); // Second pass in strict mode (forced by auto-complete) - Luau::FrontendOptions opts; - opts.forAutocomplete = true; - frontend.check(name, opts); + Luau::FrontendOptions options = defaultOptions; + options.forAutocomplete = true; + frontend.check(name, options); } catch (std::exception&) { diff --git a/fuzz/typeck.cpp b/fuzz/typeck.cpp index a6c9ae284..4f8f88575 100644 --- a/fuzz/typeck.cpp +++ b/fuzz/typeck.cpp @@ -3,9 +3,9 @@ #include "Luau/BuiltinDefinitions.h" #include "Luau/Common.h" +#include "Luau/Frontend.h" #include "Luau/ModuleResolver.h" #include "Luau/Parser.h" -#include "Luau/TypeInfer.h" LUAU_FASTINT(LuauTypeInferRecursionLimit) LUAU_FASTINT(LuauTypeInferTypePackLoopLimit) @@ -23,23 +23,22 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t* Data, size_t Size) Luau::ParseResult parseResult = Luau::Parser::parse(reinterpret_cast(Data), Size, names, allocator, options); // "static" here is to accelerate fuzzing process by only creating and populating the type environment once - static Luau::NullModuleResolver moduleResolver; - static Luau::InternalErrorReporter iceHandler; - static Luau::TypeChecker sharedEnv(&moduleResolver, &iceHandler); - static int once = (Luau::registerBuiltinGlobals(sharedEnv), 1); + static Luau::NullFileResolver fileResolver; + static Luau::NullConfigResolver configResolver; + static Luau::Frontend frontend{&fileResolver, &configResolver}; + static int once = (Luau::registerBuiltinGlobals(frontend), 1); (void)once; - static int once2 = (Luau::freeze(sharedEnv.globalTypes), 1); + static int once2 = (Luau::freeze(frontend.globals.globalTypes), 1); (void)once2; if (parseResult.errors.empty()) { + Luau::TypeChecker typeck(frontend.globals.globalScope, &frontend.moduleResolver, frontend.builtinTypes, &frontend.iceHandler); + Luau::SourceModule module; module.root = parseResult.root; module.mode = Luau::Mode::Nonstrict; - Luau::TypeChecker typeck(&moduleResolver, &iceHandler); - typeck.globalScope = sharedEnv.globalScope; - try { typeck.check(module, Luau::Mode::Nonstrict); diff --git a/tests/AssemblyBuilderA64.test.cpp b/tests/AssemblyBuilderA64.test.cpp index a68932bac..1690c748c 100644 --- a/tests/AssemblyBuilderA64.test.cpp +++ b/tests/AssemblyBuilderA64.test.cpp @@ -32,9 +32,9 @@ static std::string bytecodeAsArray(const std::vector& code) class AssemblyBuilderA64Fixture { public: - bool check(void (*f)(AssemblyBuilderA64& build), std::vector code, std::vector data = {}) + bool check(void (*f)(AssemblyBuilderA64& build), std::vector code, std::vector data = {}, unsigned int features = 0) { - AssemblyBuilderA64 build(/* logText= */ false); + AssemblyBuilderA64 build(/* logText= */ false, features); f(build); @@ -285,6 +285,87 @@ TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "AddressOfLabel") // clang-format on } +TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "FPBasic") +{ + SINGLE_COMPARE(fmov(d0, d1), 0x1E604020); +} + +TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "FPMath") +{ + SINGLE_COMPARE(fabs(d1, d2), 0x1E60C041); + SINGLE_COMPARE(fadd(d1, d2, d3), 0x1E632841); + SINGLE_COMPARE(fdiv(d1, d2, d3), 0x1E631841); + SINGLE_COMPARE(fmul(d1, d2, d3), 0x1E630841); + SINGLE_COMPARE(fneg(d1, d2), 0x1E614041); + SINGLE_COMPARE(fsqrt(d1, d2), 0x1E61C041); + SINGLE_COMPARE(fsub(d1, d2, d3), 0x1E633841); + + SINGLE_COMPARE(frinta(d1, d2), 0x1E664041); + SINGLE_COMPARE(frintm(d1, d2), 0x1E654041); + SINGLE_COMPARE(frintp(d1, d2), 0x1E64C041); + + SINGLE_COMPARE(fcvtzs(w1, d2), 0x1E780041); + SINGLE_COMPARE(fcvtzs(x1, d2), 0x9E780041); + SINGLE_COMPARE(fcvtzu(w1, d2), 0x1E790041); + SINGLE_COMPARE(fcvtzu(x1, d2), 0x9E790041); + + SINGLE_COMPARE(scvtf(d1, w2), 0x1E620041); + SINGLE_COMPARE(scvtf(d1, x2), 0x9E620041); + SINGLE_COMPARE(ucvtf(d1, w2), 0x1E630041); + SINGLE_COMPARE(ucvtf(d1, x2), 0x9E630041); + + CHECK(check( + [](AssemblyBuilderA64& build) { + build.fjcvtzs(w1, d2); + }, + {0x1E7E0041}, {}, A64::Feature_JSCVT)); +} + +TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "FPLoadStore") +{ + // address forms + SINGLE_COMPARE(ldr(d0, x1), 0xFD400020); + SINGLE_COMPARE(ldr(d0, mem(x1, 8)), 0xFD400420); + SINGLE_COMPARE(ldr(d0, mem(x1, x7)), 0xFC676820); + SINGLE_COMPARE(ldr(d0, mem(x1, -7)), 0xFC5F9020); + SINGLE_COMPARE(str(d0, x1), 0xFD000020); + SINGLE_COMPARE(str(d0, mem(x1, 8)), 0xFD000420); + SINGLE_COMPARE(str(d0, mem(x1, x7)), 0xFC276820); + SINGLE_COMPARE(str(d0, mem(x1, -7)), 0xFC1F9020); + + // load/store sizes + SINGLE_COMPARE(ldr(d0, x1), 0xFD400020); + SINGLE_COMPARE(ldr(q0, x1), 0x3DC00020); + SINGLE_COMPARE(str(d0, x1), 0xFD000020); + SINGLE_COMPARE(str(q0, x1), 0x3D800020); +} + +TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "FPCompare") +{ + SINGLE_COMPARE(fcmp(d0, d1), 0x1E612000); + SINGLE_COMPARE(fcmpz(d1), 0x1E602028); +} + +TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "AddressOffsetSize") +{ + SINGLE_COMPARE(ldr(w0, mem(x1, 16)), 0xB9401020); + SINGLE_COMPARE(ldr(x0, mem(x1, 16)), 0xF9400820); + SINGLE_COMPARE(ldr(d0, mem(x1, 16)), 0xFD400820); + SINGLE_COMPARE(ldr(q0, mem(x1, 16)), 0x3DC00420); + + SINGLE_COMPARE(str(w0, mem(x1, 16)), 0xB9001020); + SINGLE_COMPARE(str(x0, mem(x1, 16)), 0xF9000820); + SINGLE_COMPARE(str(d0, mem(x1, 16)), 0xFD000820); + SINGLE_COMPARE(str(q0, mem(x1, 16)), 0x3D800420); +} + +TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "ConditionalSelect") +{ + SINGLE_COMPARE(csel(x0, x1, x2, ConditionA64::Equal), 0x9A820020); + SINGLE_COMPARE(csel(w0, w1, w2, ConditionA64::Equal), 0x1A820020); + SINGLE_COMPARE(fcsel(d0, d1, d2, ConditionA64::Equal), 0x1E620C20); +} + TEST_CASE("LogTest") { AssemblyBuilderA64 build(/* logText= */ true); @@ -309,6 +390,14 @@ TEST_CASE("LogTest") build.ldp(x0, x1, mem(x8, 8)); build.adr(x0, l); + build.fabs(d1, d2); + build.ldr(q1, x2); + + build.csel(x0, x1, x2, ConditionA64::Equal); + + build.fcmp(d0, d1); + build.fcmpz(d0); + build.setLabel(l); build.ret(); @@ -331,6 +420,11 @@ TEST_CASE("LogTest") cbz x7,.L1 ldp x0,x1,[x8,#8] adr x0,.L1 + fabs d1,d2 + ldr q1,[x2] + csel x0,x1,x2,eq + fcmp d0,d1 + fcmp d0,#0 .L1: ret )"; diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index aedb50ab6..53dc99e15 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -2995,8 +2995,6 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_string_singletons") TEST_CASE_FIXTURE(ACFixture, "string_singleton_as_table_key") { - ScopedFastFlag sff{"LuauCompleteTableKeysBetter", true}; - check(R"( type Direction = "up" | "down" diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index c9d0c01d1..cabf1ccea 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -4691,8 +4691,6 @@ RETURN R0 0 TEST_CASE("LoopUnrollCost") { - ScopedFastFlag sff("LuauCompileBuiltinArity", true); - ScopedFastInt sfis[] = { {"LuauCompileLoopUnrollThreshold", 25}, {"LuauCompileLoopUnrollThresholdMaxBoost", 300}, @@ -5962,8 +5960,6 @@ RETURN R2 1 TEST_CASE("InlineMultret") { - ScopedFastFlag sff("LuauCompileBuiltinArity", true); - // inlining a function in multret context is prohibited since we can't adjust L->top outside of CALL/GETVARARGS CHECK_EQ("\n" + compileFunction(R"( local function foo(a) @@ -6301,8 +6297,6 @@ RETURN R0 52 TEST_CASE("BuiltinFoldingProhibited") { - ScopedFastFlag sff("LuauCompileBuiltinArity", true); - CHECK_EQ("\n" + compileFunction(R"( return math.abs(), @@ -6905,8 +6899,6 @@ L3: RETURN R0 0 TEST_CASE("BuiltinArity") { - ScopedFastFlag sff("LuauCompileBuiltinArity", true); - // by default we can't assume that we know parameter/result count for builtins as they can be overridden at runtime CHECK_EQ("\n" + compileFunction(R"( return math.abs(unknown()) diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index 1072b95df..957d32719 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -504,7 +504,7 @@ TEST_CASE("Types") Luau::InternalErrorReporter iceHandler; Luau::BuiltinTypes builtinTypes; Luau::GlobalTypes globals{Luau::NotNull{&builtinTypes}}; - Luau::TypeChecker env(globals, &moduleResolver, Luau::NotNull{&builtinTypes}, &iceHandler); + Luau::TypeChecker env(globals.globalScope, &moduleResolver, Luau::NotNull{&builtinTypes}, &iceHandler); Luau::registerBuiltinGlobals(env, globals); Luau::freeze(globals.globalTypes); diff --git a/tests/ConstraintGraphBuilderFixture.cpp b/tests/ConstraintGraphBuilderFixture.cpp index cc239b7ec..d34b86bdd 100644 --- a/tests/ConstraintGraphBuilderFixture.cpp +++ b/tests/ConstraintGraphBuilderFixture.cpp @@ -31,8 +31,7 @@ void ConstraintGraphBuilderFixture::generateConstraints(const std::string& code) void ConstraintGraphBuilderFixture::solve(const std::string& code) { generateConstraints(code); - ConstraintSolver cs{NotNull{&normalizer}, NotNull{rootScope}, constraints, "MainModule", NotNull{mainModule->reduction.get()}, - NotNull(&moduleResolver), {}, &logger}; + ConstraintSolver cs{NotNull{&normalizer}, NotNull{rootScope}, constraints, "MainModule", NotNull(&moduleResolver), {}, &logger}; cs.run(); } diff --git a/tests/IrBuilder.test.cpp b/tests/IrBuilder.test.cpp index f4c9cdca9..c1392c9d9 100644 --- a/tests/IrBuilder.test.cpp +++ b/tests/IrBuilder.test.cpp @@ -42,7 +42,7 @@ class IrBuilderFixture f(a); build.beginBlock(a); - build.inst(IrCmd::LOP_RETURN, build.constUint(1)); + build.inst(IrCmd::RETURN, build.constUint(1)); }; template @@ -56,10 +56,10 @@ class IrBuilderFixture f(a, b); build.beginBlock(a); - build.inst(IrCmd::LOP_RETURN, build.constUint(1)); + build.inst(IrCmd::RETURN, build.constUint(1)); build.beginBlock(b); - build.inst(IrCmd::LOP_RETURN, build.constUint(2)); + build.inst(IrCmd::RETURN, build.constUint(2)); }; void checkEq(IrOp instOp, const IrInst& inst) @@ -94,10 +94,10 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "FinalX64OptCheckTag") build.inst(IrCmd::CHECK_TAG, tag1, build.constTag(0), fallback); IrOp tag2 = build.inst(IrCmd::LOAD_TAG, build.vmConst(5)); build.inst(IrCmd::CHECK_TAG, tag2, build.constTag(0), fallback); - build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + build.inst(IrCmd::RETURN, build.constUint(0)); build.beginBlock(fallback); - build.inst(IrCmd::LOP_RETURN, build.constUint(1)); + build.inst(IrCmd::RETURN, build.constUint(1)); updateUseCounts(build.function); optimizeMemoryOperandsX64(build.function); @@ -107,10 +107,10 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "FinalX64OptCheckTag") bb_0: CHECK_TAG R2, tnil, bb_fallback_1 CHECK_TAG K5, tnil, bb_fallback_1 - LOP_RETURN 0u + RETURN 0u bb_fallback_1: - LOP_RETURN 1u + RETURN 1u )"); } @@ -123,7 +123,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "FinalX64OptBinaryArith") IrOp opA = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(1)); IrOp opB = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(2)); build.inst(IrCmd::ADD_NUM, opA, opB); - build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + build.inst(IrCmd::RETURN, build.constUint(0)); updateUseCounts(build.function); optimizeMemoryOperandsX64(build.function); @@ -133,7 +133,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "FinalX64OptBinaryArith") bb_0: %0 = LOAD_DOUBLE R1 %2 = ADD_NUM %0, R2 - LOP_RETURN 0u + RETURN 0u )"); } @@ -150,10 +150,10 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "FinalX64OptEqTag1") build.inst(IrCmd::JUMP_EQ_TAG, opA, opB, trueBlock, falseBlock); build.beginBlock(trueBlock); - build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + build.inst(IrCmd::RETURN, build.constUint(0)); build.beginBlock(falseBlock); - build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + build.inst(IrCmd::RETURN, build.constUint(0)); updateUseCounts(build.function); optimizeMemoryOperandsX64(build.function); @@ -165,10 +165,10 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "FinalX64OptEqTag1") JUMP_EQ_TAG R1, %1, bb_1, bb_2 bb_1: - LOP_RETURN 0u + RETURN 0u bb_2: - LOP_RETURN 0u + RETURN 0u )"); } @@ -186,10 +186,10 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "FinalX64OptEqTag2") build.inst(IrCmd::JUMP_EQ_TAG, opA, opB, trueBlock, falseBlock); build.beginBlock(trueBlock); - build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + build.inst(IrCmd::RETURN, build.constUint(0)); build.beginBlock(falseBlock); - build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + build.inst(IrCmd::RETURN, build.constUint(0)); updateUseCounts(build.function); optimizeMemoryOperandsX64(build.function); @@ -203,10 +203,10 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "FinalX64OptEqTag2") JUMP_EQ_TAG R2, %0, bb_1, bb_2 bb_1: - LOP_RETURN 0u + RETURN 0u bb_2: - LOP_RETURN 0u + RETURN 0u )"); } @@ -224,10 +224,10 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "FinalX64OptEqTag3") build.inst(IrCmd::JUMP_EQ_TAG, opA, build.constTag(0), trueBlock, falseBlock); build.beginBlock(trueBlock); - build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + build.inst(IrCmd::RETURN, build.constUint(0)); build.beginBlock(falseBlock); - build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + build.inst(IrCmd::RETURN, build.constUint(0)); updateUseCounts(build.function); optimizeMemoryOperandsX64(build.function); @@ -241,10 +241,10 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "FinalX64OptEqTag3") JUMP_EQ_TAG %2, tnil, bb_1, bb_2 bb_1: - LOP_RETURN 0u + RETURN 0u bb_2: - LOP_RETURN 0u + RETURN 0u )"); } @@ -261,10 +261,10 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "FinalX64OptJumpCmpNum") build.inst(IrCmd::JUMP_CMP_NUM, opA, opB, trueBlock, falseBlock); build.beginBlock(trueBlock); - build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + build.inst(IrCmd::RETURN, build.constUint(0)); build.beginBlock(falseBlock); - build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + build.inst(IrCmd::RETURN, build.constUint(0)); updateUseCounts(build.function); optimizeMemoryOperandsX64(build.function); @@ -276,10 +276,10 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "FinalX64OptJumpCmpNum") JUMP_CMP_NUM R1, %1, bb_1, bb_2 bb_1: - LOP_RETURN 0u + RETURN 0u bb_2: - LOP_RETURN 0u + RETURN 0u )"); } @@ -317,7 +317,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "Numeric") build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.inst(IrCmd::INT_TO_NUM, build.constInt(8))); - build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + build.inst(IrCmd::RETURN, build.constUint(0)); updateUseCounts(build.function); constantFold(); @@ -342,7 +342,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "Numeric") STORE_INT R0, 1i STORE_INT R0, 0i STORE_DOUBLE R0, 8 - LOP_RETURN 0u + RETURN 0u )"); } @@ -373,25 +373,25 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "ControlFlowEq") JUMP bb_1 bb_1: - LOP_RETURN 1u + RETURN 1u bb_3: JUMP bb_5 bb_5: - LOP_RETURN 2u + RETURN 2u bb_6: JUMP bb_7 bb_7: - LOP_RETURN 1u + RETURN 1u bb_9: JUMP bb_11 bb_11: - LOP_RETURN 2u + RETURN 2u )"); } @@ -400,18 +400,18 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "NumToIndex") { withOneBlock([this](IrOp a) { build.inst(IrCmd::STORE_INT, build.vmReg(0), build.inst(IrCmd::TRY_NUM_TO_INDEX, build.constDouble(4), a)); - build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + build.inst(IrCmd::RETURN, build.constUint(0)); }); withOneBlock([this](IrOp a) { build.inst(IrCmd::STORE_INT, build.vmReg(0), build.inst(IrCmd::TRY_NUM_TO_INDEX, build.constDouble(1.2), a)); - build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + build.inst(IrCmd::RETURN, build.constUint(0)); }); withOneBlock([this](IrOp a) { IrOp nan = build.inst(IrCmd::DIV_NUM, build.constDouble(0.0), build.constDouble(0.0)); build.inst(IrCmd::STORE_INT, build.vmReg(0), build.inst(IrCmd::TRY_NUM_TO_INDEX, nan, a)); - build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + build.inst(IrCmd::RETURN, build.constUint(0)); }); updateUseCounts(build.function); @@ -420,19 +420,19 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "NumToIndex") CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: STORE_INT R0, 4i - LOP_RETURN 0u + RETURN 0u bb_2: JUMP bb_3 bb_3: - LOP_RETURN 1u + RETURN 1u bb_4: JUMP bb_5 bb_5: - LOP_RETURN 1u + RETURN 1u )"); } @@ -441,12 +441,12 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "Guards") { withOneBlock([this](IrOp a) { build.inst(IrCmd::CHECK_TAG, build.constTag(tnumber), build.constTag(tnumber), a); - build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + build.inst(IrCmd::RETURN, build.constUint(0)); }); withOneBlock([this](IrOp a) { build.inst(IrCmd::CHECK_TAG, build.constTag(tnil), build.constTag(tnumber), a); - build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + build.inst(IrCmd::RETURN, build.constUint(0)); }); updateUseCounts(build.function); @@ -454,13 +454,13 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "Guards") CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: - LOP_RETURN 0u + RETURN 0u bb_2: JUMP bb_3 bb_3: - LOP_RETURN 1u + RETURN 1u )"); } @@ -568,7 +568,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "RememberTagsAndValues") build.inst(IrCmd::STORE_INT, build.vmReg(10), build.inst(IrCmd::LOAD_INT, build.vmReg(1))); build.inst(IrCmd::STORE_DOUBLE, build.vmReg(11), build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(2))); - build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + build.inst(IrCmd::RETURN, build.constUint(0)); updateUseCounts(build.function); constPropInBlockChains(build); @@ -593,7 +593,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "RememberTagsAndValues") STORE_INT R10, %20 %22 = LOAD_DOUBLE R2 STORE_DOUBLE R11, %22 - LOP_RETURN 0u + RETURN 0u )"); } @@ -614,7 +614,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "PropagateThroughTvalue") build.inst(IrCmd::STORE_TAG, build.vmReg(3), build.inst(IrCmd::LOAD_TAG, build.vmReg(1))); build.inst(IrCmd::STORE_DOUBLE, build.vmReg(3), build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(1))); - build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + build.inst(IrCmd::RETURN, build.constUint(0)); updateUseCounts(build.function); constPropInBlockChains(build); @@ -627,7 +627,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "PropagateThroughTvalue") STORE_TVALUE R1, %2 STORE_TAG R3, tnumber STORE_DOUBLE R3, 0.5 - LOP_RETURN 0u + RETURN 0u )"); } @@ -641,10 +641,10 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "SkipCheckTag") build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tnumber)); build.inst(IrCmd::CHECK_TAG, build.inst(IrCmd::LOAD_TAG, build.vmReg(0)), build.constTag(tnumber), fallback); - build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + build.inst(IrCmd::RETURN, build.constUint(0)); build.beginBlock(fallback); - build.inst(IrCmd::LOP_RETURN, build.constUint(1)); + build.inst(IrCmd::RETURN, build.constUint(1)); updateUseCounts(build.function); constPropInBlockChains(build); @@ -652,7 +652,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "SkipCheckTag") CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: STORE_TAG R0, tnumber - LOP_RETURN 0u + RETURN 0u )"); } @@ -671,7 +671,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "SkipOncePerBlockChecks") build.inst(IrCmd::DO_LEN, build.vmReg(1), build.vmReg(2)); // Can make env unsafe build.inst(IrCmd::CHECK_SAFE_ENV); - build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + build.inst(IrCmd::RETURN, build.constUint(0)); updateUseCounts(build.function); constPropInBlockChains(build); @@ -682,7 +682,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "SkipOncePerBlockChecks") CHECK_GC DO_LEN R1, R2 CHECK_SAFE_ENV - LOP_RETURN 0u + RETURN 0u )"); } @@ -707,10 +707,10 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "RememberTableState") build.inst(IrCmd::CHECK_NO_METATABLE, table, fallback); build.inst(IrCmd::CHECK_READONLY, table, fallback); - build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + build.inst(IrCmd::RETURN, build.constUint(0)); build.beginBlock(fallback); - build.inst(IrCmd::LOP_RETURN, build.constUint(1)); + build.inst(IrCmd::RETURN, build.constUint(1)); updateUseCounts(build.function); constPropInBlockChains(build); @@ -723,10 +723,10 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "RememberTableState") DO_LEN R1, R2 CHECK_NO_METATABLE %0, bb_fallback_1 CHECK_READONLY %0, bb_fallback_1 - LOP_RETURN 0u + RETURN 0u bb_fallback_1: - LOP_RETURN 1u + RETURN 1u )"); } @@ -742,7 +742,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "SkipUselessBarriers") build.inst(IrCmd::BARRIER_TABLE_FORWARD, table, build.vmReg(0)); IrOp something = build.inst(IrCmd::LOAD_POINTER, build.vmReg(2)); build.inst(IrCmd::BARRIER_OBJ, something, build.vmReg(0)); - build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + build.inst(IrCmd::RETURN, build.constUint(0)); updateUseCounts(build.function); constPropInBlockChains(build); @@ -750,7 +750,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "SkipUselessBarriers") CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: STORE_TAG R0, tnumber - LOP_RETURN 0u + RETURN 0u )"); } @@ -773,7 +773,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "ConcatInvalidation") build.inst(IrCmd::STORE_DOUBLE, build.vmReg(6), build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(2))); build.inst(IrCmd::STORE_DOUBLE, build.vmReg(7), build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(3))); - build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + build.inst(IrCmd::RETURN, build.constUint(0)); updateUseCounts(build.function); constPropInBlockChains(build); @@ -792,7 +792,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "ConcatInvalidation") %9 = LOAD_DOUBLE R2 STORE_DOUBLE R6, %9 STORE_DOUBLE R7, 2 - LOP_RETURN 0u + RETURN 0u )"); } @@ -819,10 +819,10 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "BuiltinFastcallsMayInvalidateMemory") build.inst(IrCmd::STORE_DOUBLE, build.vmReg(1), build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(0))); // At least R0 wasn't touched - build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + build.inst(IrCmd::RETURN, build.constUint(0)); build.beginBlock(fallback); - build.inst(IrCmd::LOP_RETURN, build.constUint(1)); + build.inst(IrCmd::RETURN, build.constUint(1)); updateUseCounts(build.function); constPropInBlockChains(build); @@ -837,10 +837,10 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "BuiltinFastcallsMayInvalidateMemory") CHECK_NO_METATABLE %1, bb_fallback_1 CHECK_READONLY %1, bb_fallback_1 STORE_DOUBLE R1, 0.5 - LOP_RETURN 0u + RETURN 0u bb_fallback_1: - LOP_RETURN 1u + RETURN 1u )"); } @@ -855,7 +855,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "RedundantStoreCheckConstantType") build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.constDouble(0.5)); build.inst(IrCmd::STORE_INT, build.vmReg(0), build.constInt(10)); - build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + build.inst(IrCmd::RETURN, build.constUint(0)); updateUseCounts(build.function); constPropInBlockChains(build); @@ -865,7 +865,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "RedundantStoreCheckConstantType") STORE_INT R0, 10i STORE_DOUBLE R0, 0.5 STORE_INT R0, 10i - LOP_RETURN 0u + RETURN 0u )"); } @@ -882,10 +882,10 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "TagCheckPropagation") build.inst(IrCmd::CHECK_TAG, unknown, build.constTag(tnumber), fallback); build.inst(IrCmd::CHECK_TAG, unknown, build.constTag(tnumber), fallback); - build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + build.inst(IrCmd::RETURN, build.constUint(0)); build.beginBlock(fallback); - build.inst(IrCmd::LOP_RETURN, build.constUint(1)); + build.inst(IrCmd::RETURN, build.constUint(1)); updateUseCounts(build.function); constPropInBlockChains(build); @@ -894,10 +894,10 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "TagCheckPropagation") bb_0: %0 = LOAD_TAG R0 CHECK_TAG %0, tnumber, bb_fallback_1 - LOP_RETURN 0u + RETURN 0u bb_fallback_1: - LOP_RETURN 1u + RETURN 1u )"); } @@ -914,10 +914,10 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "TagCheckPropagationConflicting") build.inst(IrCmd::CHECK_TAG, unknown, build.constTag(tnumber), fallback); build.inst(IrCmd::CHECK_TAG, unknown, build.constTag(tnil), fallback); - build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + build.inst(IrCmd::RETURN, build.constUint(0)); build.beginBlock(fallback); - build.inst(IrCmd::LOP_RETURN, build.constUint(1)); + build.inst(IrCmd::RETURN, build.constUint(1)); updateUseCounts(build.function); constPropInBlockChains(build); @@ -929,7 +929,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "TagCheckPropagationConflicting") JUMP bb_fallback_1 bb_fallback_1: - LOP_RETURN 1u + RETURN 1u )"); } @@ -947,13 +947,13 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "TruthyTestRemoval") build.inst(IrCmd::JUMP_IF_TRUTHY, build.vmReg(1), trueBlock, falseBlock); build.beginBlock(trueBlock); - build.inst(IrCmd::LOP_RETURN, build.constUint(1)); + build.inst(IrCmd::RETURN, build.constUint(1)); build.beginBlock(falseBlock); - build.inst(IrCmd::LOP_RETURN, build.constUint(2)); + build.inst(IrCmd::RETURN, build.constUint(2)); build.beginBlock(fallback); - build.inst(IrCmd::LOP_RETURN, build.constUint(3)); + build.inst(IrCmd::RETURN, build.constUint(3)); updateUseCounts(build.function); constPropInBlockChains(build); @@ -965,10 +965,10 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "TruthyTestRemoval") JUMP bb_1 bb_1: - LOP_RETURN 1u + RETURN 1u bb_fallback_3: - LOP_RETURN 3u + RETURN 3u )"); } @@ -986,13 +986,13 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "FalsyTestRemoval") build.inst(IrCmd::JUMP_IF_FALSY, build.vmReg(1), trueBlock, falseBlock); build.beginBlock(trueBlock); - build.inst(IrCmd::LOP_RETURN, build.constUint(1)); + build.inst(IrCmd::RETURN, build.constUint(1)); build.beginBlock(falseBlock); - build.inst(IrCmd::LOP_RETURN, build.constUint(2)); + build.inst(IrCmd::RETURN, build.constUint(2)); build.beginBlock(fallback); - build.inst(IrCmd::LOP_RETURN, build.constUint(3)); + build.inst(IrCmd::RETURN, build.constUint(3)); updateUseCounts(build.function); constPropInBlockChains(build); @@ -1004,10 +1004,10 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "FalsyTestRemoval") JUMP bb_2 bb_2: - LOP_RETURN 2u + RETURN 2u bb_fallback_3: - LOP_RETURN 3u + RETURN 3u )"); } @@ -1024,10 +1024,10 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "TagEqRemoval") build.inst(IrCmd::JUMP_EQ_TAG, tag, build.constTag(tnumber), trueBlock, falseBlock); build.beginBlock(trueBlock); - build.inst(IrCmd::LOP_RETURN, build.constUint(1)); + build.inst(IrCmd::RETURN, build.constUint(1)); build.beginBlock(falseBlock); - build.inst(IrCmd::LOP_RETURN, build.constUint(2)); + build.inst(IrCmd::RETURN, build.constUint(2)); updateUseCounts(build.function); constPropInBlockChains(build); @@ -1039,7 +1039,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "TagEqRemoval") JUMP bb_2 bb_2: - LOP_RETURN 2u + RETURN 2u )"); } @@ -1056,10 +1056,10 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "IntEqRemoval") build.inst(IrCmd::JUMP_EQ_INT, value, build.constInt(5), trueBlock, falseBlock); build.beginBlock(trueBlock); - build.inst(IrCmd::LOP_RETURN, build.constUint(1)); + build.inst(IrCmd::RETURN, build.constUint(1)); build.beginBlock(falseBlock); - build.inst(IrCmd::LOP_RETURN, build.constUint(2)); + build.inst(IrCmd::RETURN, build.constUint(2)); updateUseCounts(build.function); constPropInBlockChains(build); @@ -1070,7 +1070,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "IntEqRemoval") JUMP bb_1 bb_1: - LOP_RETURN 1u + RETURN 1u )"); } @@ -1087,10 +1087,10 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "NumCmpRemoval") build.inst(IrCmd::JUMP_CMP_NUM, value, build.constDouble(8.0), build.cond(IrCondition::Greater), trueBlock, falseBlock); build.beginBlock(trueBlock); - build.inst(IrCmd::LOP_RETURN, build.constUint(1)); + build.inst(IrCmd::RETURN, build.constUint(1)); build.beginBlock(falseBlock); - build.inst(IrCmd::LOP_RETURN, build.constUint(2)); + build.inst(IrCmd::RETURN, build.constUint(2)); updateUseCounts(build.function); constPropInBlockChains(build); @@ -1101,7 +1101,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "NumCmpRemoval") JUMP bb_2 bb_2: - LOP_RETURN 2u + RETURN 2u )"); } @@ -1118,7 +1118,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "DataFlowsThroughDirectJumpToUniqueSuccessor build.beginBlock(block2); build.inst(IrCmd::STORE_TAG, build.vmReg(1), build.inst(IrCmd::LOAD_TAG, build.vmReg(0))); - build.inst(IrCmd::LOP_RETURN, build.constUint(1)); + build.inst(IrCmd::RETURN, build.constUint(1)); updateUseCounts(build.function); constPropInBlockChains(build); @@ -1130,7 +1130,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "DataFlowsThroughDirectJumpToUniqueSuccessor bb_1: STORE_TAG R1, tnumber - LOP_RETURN 1u + RETURN 1u )"); } @@ -1148,7 +1148,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "DataDoesNotFlowThroughDirectJumpToNonUnique build.beginBlock(block2); build.inst(IrCmd::STORE_TAG, build.vmReg(1), build.inst(IrCmd::LOAD_TAG, build.vmReg(0))); - build.inst(IrCmd::LOP_RETURN, build.constUint(1)); + build.inst(IrCmd::RETURN, build.constUint(1)); build.beginBlock(block3); build.inst(IrCmd::JUMP, block2); @@ -1164,7 +1164,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "DataDoesNotFlowThroughDirectJumpToNonUnique bb_1: %2 = LOAD_TAG R0 STORE_TAG R1, %2 - LOP_RETURN 1u + RETURN 1u bb_2: JUMP bb_1 @@ -1183,7 +1183,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "EntryBlockUseRemoval") build.inst(IrCmd::JUMP_IF_TRUTHY, build.vmReg(0), exit, repeat); build.beginBlock(exit); - build.inst(IrCmd::LOP_RETURN, build.vmReg(0), build.constInt(0)); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(0)); build.beginBlock(repeat); build.inst(IrCmd::INTERRUPT, build.constUint(0)); @@ -1198,7 +1198,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "EntryBlockUseRemoval") JUMP bb_1 bb_1: - LOP_RETURN R0, 0i + RETURN R0, 0i )"); } @@ -1211,14 +1211,14 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "RecursiveSccUseRemoval1") IrOp repeat = build.block(IrBlockKind::Internal); build.beginBlock(entry); - build.inst(IrCmd::LOP_RETURN, build.vmReg(0), build.constInt(0)); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(0)); build.beginBlock(block); build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tnumber)); build.inst(IrCmd::JUMP_IF_TRUTHY, build.vmReg(0), exit, repeat); build.beginBlock(exit); - build.inst(IrCmd::LOP_RETURN, build.vmReg(0), build.constInt(0)); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(0)); build.beginBlock(repeat); build.inst(IrCmd::INTERRUPT, build.constUint(0)); @@ -1229,14 +1229,14 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "RecursiveSccUseRemoval1") CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: - LOP_RETURN R0, 0i + RETURN R0, 0i bb_1: STORE_TAG R0, tnumber JUMP bb_2 bb_2: - LOP_RETURN R0, 0i + RETURN R0, 0i )"); } @@ -1253,14 +1253,14 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "RecursiveSccUseRemoval2") build.inst(IrCmd::JUMP_EQ_INT, build.constInt(0), build.constInt(1), block, exit1); build.beginBlock(exit1); - build.inst(IrCmd::LOP_RETURN, build.vmReg(0), build.constInt(0)); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(0)); build.beginBlock(block); build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tnumber)); build.inst(IrCmd::JUMP_IF_TRUTHY, build.vmReg(0), exit2, repeat); build.beginBlock(exit2); - build.inst(IrCmd::LOP_RETURN, build.vmReg(0), build.constInt(0)); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(0)); build.beginBlock(repeat); build.inst(IrCmd::INTERRUPT, build.constUint(0)); @@ -1274,14 +1274,14 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "RecursiveSccUseRemoval2") JUMP bb_1 bb_1: - LOP_RETURN R0, 0i + RETURN R0, 0i bb_2: STORE_TAG R0, tnumber JUMP bb_3 bb_3: - LOP_RETURN R0, 0i + RETURN R0, 0i )"); } @@ -1322,7 +1322,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "SimplePathExtraction") build.inst(IrCmd::JUMP, block4); build.beginBlock(block4); - build.inst(IrCmd::LOP_RETURN, build.vmReg(0), build.constInt(0)); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(0)); updateUseCounts(build.function); constPropInBlockChains(build); @@ -1350,10 +1350,10 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "SimplePathExtraction") JUMP bb_5 bb_5: - LOP_RETURN R0, 0i + RETURN R0, 0i bb_linear_6: - LOP_RETURN R0, 0i + RETURN R0, 0i )"); } @@ -1393,11 +1393,11 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "NoPathExtractionForBlocksWithLiveOutValues" build.beginBlock(block4a); build.inst(IrCmd::STORE_TAG, build.vmReg(0), tag3a); - build.inst(IrCmd::LOP_RETURN, build.vmReg(0), build.constInt(0)); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(0)); build.beginBlock(block4b); build.inst(IrCmd::STORE_TAG, build.vmReg(0), tag3a); - build.inst(IrCmd::LOP_RETURN, build.vmReg(0), build.constInt(0)); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(0)); updateUseCounts(build.function); constPropInBlockChains(build); @@ -1427,11 +1427,11 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "NoPathExtractionForBlocksWithLiveOutValues" bb_5: STORE_TAG R0, %10 - LOP_RETURN R0, 0i + RETURN R0, 0i bb_6: STORE_TAG R0, %10 - LOP_RETURN R0, 0i + RETURN R0, 0i )"); } @@ -1488,7 +1488,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "SimpleDiamond") build.inst(IrCmd::JUMP, exit); build.beginBlock(exit); - build.inst(IrCmd::LOP_RETURN, build.vmReg(2), build.constInt(2)); + build.inst(IrCmd::RETURN, build.vmReg(2), build.constInt(2)); updateUseCounts(build.function); computeCfgInfo(build.function); @@ -1522,7 +1522,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "SimpleDiamond") bb_3: ; predecessors: bb_1, bb_2 ; in regs: R2, R3 - LOP_RETURN R2, 2i + RETURN R2, 2i )"); } @@ -1534,11 +1534,11 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "ImplicitFixedRegistersInVarargCall") build.beginBlock(entry); build.inst(IrCmd::FALLBACK_GETVARARGS, build.constUint(0), build.vmReg(3), build.constInt(-1)); - build.inst(IrCmd::LOP_CALL, build.vmReg(0), build.constInt(-1), build.constInt(5)); + build.inst(IrCmd::CALL, build.vmReg(0), build.constInt(-1), build.constInt(5)); build.inst(IrCmd::JUMP, exit); build.beginBlock(exit); - build.inst(IrCmd::LOP_RETURN, build.vmReg(0), build.constInt(5)); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(5)); updateUseCounts(build.function); computeCfgInfo(build.function); @@ -1549,13 +1549,13 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "ImplicitFixedRegistersInVarargCall") ; in regs: R0, R1, R2 ; out regs: R0, R1, R2, R3, R4 FALLBACK_GETVARARGS 0u, R3, -1i - LOP_CALL R0, -1i, 5i + CALL R0, -1i, 5i JUMP bb_1 bb_1: ; predecessors: bb_0 ; in regs: R0, R1, R2, R3, R4 - LOP_RETURN R0, 5i + RETURN R0, 5i )"); } @@ -1573,7 +1573,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "ExplicitUseOfRegisterInVarargSequence") build.inst(IrCmd::JUMP, exit); build.beginBlock(exit); - build.inst(IrCmd::LOP_RETURN, build.vmReg(0), build.constInt(-1)); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(-1)); updateUseCounts(build.function); computeCfgInfo(build.function); @@ -1590,7 +1590,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "ExplicitUseOfRegisterInVarargSequence") bb_1: ; predecessors: bb_0 ; in regs: R0... - LOP_RETURN R0, -1i + RETURN R0, -1i )"); } @@ -1601,12 +1601,12 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "VariadicSequenceRestart") IrOp exit = build.block(IrBlockKind::Internal); build.beginBlock(entry); - build.inst(IrCmd::LOP_CALL, build.vmReg(1), build.constInt(0), build.constInt(-1)); - build.inst(IrCmd::LOP_CALL, build.vmReg(0), build.constInt(-1), build.constInt(-1)); + build.inst(IrCmd::CALL, build.vmReg(1), build.constInt(0), build.constInt(-1)); + build.inst(IrCmd::CALL, build.vmReg(0), build.constInt(-1), build.constInt(-1)); build.inst(IrCmd::JUMP, exit); build.beginBlock(exit); - build.inst(IrCmd::LOP_RETURN, build.vmReg(0), build.constInt(-1)); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(-1)); updateUseCounts(build.function); computeCfgInfo(build.function); @@ -1616,14 +1616,14 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "VariadicSequenceRestart") ; successors: bb_1 ; in regs: R0, R1 ; out regs: R0... - LOP_CALL R1, 0i, -1i - LOP_CALL R0, -1i, -1i + CALL R1, 0i, -1i + CALL R0, -1i, -1i JUMP bb_1 bb_1: ; predecessors: bb_0 ; in regs: R0... - LOP_RETURN R0, -1i + RETURN R0, -1i )"); } @@ -1637,15 +1637,15 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "FallbackDoesNotFlowUp") build.beginBlock(entry); build.inst(IrCmd::FALLBACK_GETVARARGS, build.constUint(0), build.vmReg(1), build.constInt(-1)); build.inst(IrCmd::CHECK_TAG, build.inst(IrCmd::LOAD_TAG, build.vmReg(0)), build.constTag(tnumber), fallback); - build.inst(IrCmd::LOP_CALL, build.vmReg(0), build.constInt(-1), build.constInt(-1)); + build.inst(IrCmd::CALL, build.vmReg(0), build.constInt(-1), build.constInt(-1)); build.inst(IrCmd::JUMP, exit); build.beginBlock(fallback); - build.inst(IrCmd::LOP_CALL, build.vmReg(0), build.constInt(-1), build.constInt(-1)); + build.inst(IrCmd::CALL, build.vmReg(0), build.constInt(-1), build.constInt(-1)); build.inst(IrCmd::JUMP, exit); build.beginBlock(exit); - build.inst(IrCmd::LOP_RETURN, build.vmReg(0), build.constInt(-1)); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(-1)); updateUseCounts(build.function); computeCfgInfo(build.function); @@ -1658,7 +1658,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "FallbackDoesNotFlowUp") FALLBACK_GETVARARGS 0u, R1, -1i %1 = LOAD_TAG R0 CHECK_TAG %1, tnumber, bb_fallback_1 - LOP_CALL R0, -1i, -1i + CALL R0, -1i, -1i JUMP bb_2 bb_fallback_1: @@ -1666,13 +1666,13 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "FallbackDoesNotFlowUp") ; successors: bb_2 ; in regs: R0, R1... ; out regs: R0... - LOP_CALL R0, -1i, -1i + CALL R0, -1i, -1i JUMP bb_2 bb_2: ; predecessors: bb_0, bb_fallback_1 ; in regs: R0... - LOP_RETURN R0, -1i + RETURN R0, -1i )"); } @@ -1697,7 +1697,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "VariadicSequencePeeling") build.inst(IrCmd::JUMP, exit); build.beginBlock(exit); - build.inst(IrCmd::LOP_RETURN, build.vmReg(2), build.constInt(-1)); + build.inst(IrCmd::RETURN, build.vmReg(2), build.constInt(-1)); updateUseCounts(build.function); computeCfgInfo(build.function); @@ -1732,7 +1732,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "VariadicSequencePeeling") bb_3: ; predecessors: bb_1, bb_2 ; in regs: R2... - LOP_RETURN R2, -1i + RETURN R2, -1i )"); } @@ -1746,11 +1746,11 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "BuiltinVariadicStart") build.inst(IrCmd::STORE_DOUBLE, build.vmReg(1), build.constDouble(1.0)); build.inst(IrCmd::STORE_DOUBLE, build.vmReg(2), build.constDouble(2.0)); build.inst(IrCmd::ADJUST_STACK_TO_REG, build.vmReg(2), build.constInt(1)); - build.inst(IrCmd::LOP_CALL, build.vmReg(1), build.constInt(-1), build.constInt(1)); + build.inst(IrCmd::CALL, build.vmReg(1), build.constInt(-1), build.constInt(1)); build.inst(IrCmd::JUMP, exit); build.beginBlock(exit); - build.inst(IrCmd::LOP_RETURN, build.vmReg(0), build.constInt(2)); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(2)); updateUseCounts(build.function); computeCfgInfo(build.function); @@ -1763,13 +1763,13 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "BuiltinVariadicStart") STORE_DOUBLE R1, 1 STORE_DOUBLE R2, 2 ADJUST_STACK_TO_REG R2, 1i - LOP_CALL R1, -1i, 1i + CALL R1, -1i, 1i JUMP bb_1 bb_1: ; predecessors: bb_0 ; in regs: R0, R1 - LOP_RETURN R0, 2i + RETURN R0, 2i )"); } @@ -1781,7 +1781,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "SetTable") build.beginBlock(entry); build.inst(IrCmd::SET_TABLE, build.vmReg(0), build.vmReg(1), build.constUint(1)); - build.inst(IrCmd::LOP_RETURN, build.vmReg(0), build.constInt(1)); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(1)); updateUseCounts(build.function); computeCfgInfo(build.function); @@ -1790,7 +1790,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "SetTable") bb_0: ; in regs: R0, R1 SET_TABLE R0, R1, 1u - LOP_RETURN R0, 1i + RETURN R0, 1i )"); } diff --git a/tests/IrCallWrapperX64.test.cpp b/tests/IrCallWrapperX64.test.cpp new file mode 100644 index 000000000..8c7b1393f --- /dev/null +++ b/tests/IrCallWrapperX64.test.cpp @@ -0,0 +1,484 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/IrCallWrapperX64.h" +#include "Luau/IrRegAllocX64.h" + +#include "doctest.h" + +using namespace Luau::CodeGen; +using namespace Luau::CodeGen::X64; + +class IrCallWrapperX64Fixture +{ +public: + IrCallWrapperX64Fixture() + : build(/* logText */ true, ABIX64::Windows) + , regs(function) + , callWrap(regs, build, ~0u) + { + } + + void checkMatch(std::string expected) + { + regs.assertAllFree(); + + build.finalize(); + + CHECK("\n" + build.text == expected); + } + + AssemblyBuilderX64 build; + IrFunction function; + IrRegAllocX64 regs; + IrCallWrapperX64 callWrap; + + // Tests rely on these to force interference between registers + static constexpr RegisterX64 rArg1 = rcx; + static constexpr RegisterX64 rArg1d = ecx; + static constexpr RegisterX64 rArg2 = rdx; + static constexpr RegisterX64 rArg2d = edx; + static constexpr RegisterX64 rArg3 = r8; + static constexpr RegisterX64 rArg3d = r8d; + static constexpr RegisterX64 rArg4 = r9; + static constexpr RegisterX64 rArg4d = r9d; +}; + +TEST_SUITE_BEGIN("IrCallWrapperX64"); + +TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "SimpleRegs") +{ + ScopedRegX64 tmp1{regs, regs.takeReg(rax)}; + ScopedRegX64 tmp2{regs, regs.takeReg(rArg2)}; + callWrap.addArgument(SizeX64::qword, tmp1); + callWrap.addArgument(SizeX64::qword, tmp2); // Already in its place + callWrap.call(qword[r12]); + + checkMatch(R"( + mov rcx,rax + call qword ptr [r12] +)"); +} + +TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "TrickyUse1") +{ + ScopedRegX64 tmp1{regs, regs.takeReg(rArg1)}; + callWrap.addArgument(SizeX64::qword, tmp1.reg); // Already in its place + callWrap.addArgument(SizeX64::qword, tmp1.release()); + callWrap.call(qword[r12]); + + checkMatch(R"( + mov rdx,rcx + call qword ptr [r12] +)"); +} + +TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "TrickyUse2") +{ + ScopedRegX64 tmp1{regs, regs.takeReg(rArg1)}; + callWrap.addArgument(SizeX64::qword, qword[tmp1.reg]); + callWrap.addArgument(SizeX64::qword, tmp1.release()); + callWrap.call(qword[r12]); + + checkMatch(R"( + mov rdx,rcx + mov rcx,qword ptr [rcx] + call qword ptr [r12] +)"); +} + +TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "SimpleMemImm") +{ + ScopedRegX64 tmp1{regs, regs.takeReg(rax)}; + ScopedRegX64 tmp2{regs, regs.takeReg(rsi)}; + callWrap.addArgument(SizeX64::dword, 32); + callWrap.addArgument(SizeX64::dword, -1); + callWrap.addArgument(SizeX64::qword, qword[r14 + 32]); + callWrap.addArgument(SizeX64::qword, qword[tmp1.release() + tmp2.release()]); + callWrap.call(qword[r12]); + + checkMatch(R"( + mov r8,qword ptr [r14+020h] + mov r9,qword ptr [rax+rsi] + mov ecx,20h + mov edx,FFFFFFFFh + call qword ptr [r12] +)"); +} + +TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "SimpleStackArgs") +{ + ScopedRegX64 tmp{regs, regs.takeReg(rax)}; + callWrap.addArgument(SizeX64::qword, tmp); + callWrap.addArgument(SizeX64::qword, qword[r14 + 16]); + callWrap.addArgument(SizeX64::qword, qword[r14 + 32]); + callWrap.addArgument(SizeX64::qword, qword[r14 + 48]); + callWrap.addArgument(SizeX64::dword, 1); + callWrap.addArgument(SizeX64::qword, qword[r13]); + callWrap.call(qword[r12]); + + checkMatch(R"( + mov rdx,qword ptr [r13] + mov qword ptr [rsp+028h],rdx + mov rcx,rax + mov rdx,qword ptr [r14+010h] + mov r8,qword ptr [r14+020h] + mov r9,qword ptr [r14+030h] + mov dword ptr [rsp+020h],1 + call qword ptr [r12] +)"); +} + +TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "FixedRegisters") +{ + callWrap.addArgument(SizeX64::dword, 1); + callWrap.addArgument(SizeX64::qword, 2); + callWrap.addArgument(SizeX64::qword, 3); + callWrap.addArgument(SizeX64::qword, 4); + callWrap.addArgument(SizeX64::qword, r14); + callWrap.call(qword[r12]); + + checkMatch(R"( + mov qword ptr [rsp+020h],r14 + mov ecx,1 + mov rdx,2 + mov r8,3 + mov r9,4 + call qword ptr [r12] +)"); +} + +TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "EasyInterference") +{ + ScopedRegX64 tmp1{regs, regs.takeReg(rdi)}; + ScopedRegX64 tmp2{regs, regs.takeReg(rsi)}; + ScopedRegX64 tmp3{regs, regs.takeReg(rArg2)}; + ScopedRegX64 tmp4{regs, regs.takeReg(rArg1)}; + callWrap.addArgument(SizeX64::qword, tmp1); + callWrap.addArgument(SizeX64::qword, tmp2); + callWrap.addArgument(SizeX64::qword, tmp3); + callWrap.addArgument(SizeX64::qword, tmp4); + callWrap.call(qword[r12]); + + checkMatch(R"( + mov r8,rdx + mov rdx,rsi + mov r9,rcx + mov rcx,rdi + call qword ptr [r12] +)"); +} + +TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "FakeInterference") +{ + ScopedRegX64 tmp1{regs, regs.takeReg(rArg1)}; + ScopedRegX64 tmp2{regs, regs.takeReg(rArg2)}; + callWrap.addArgument(SizeX64::qword, qword[tmp1.release() + 8]); + callWrap.addArgument(SizeX64::qword, qword[tmp2.release() + 8]); + callWrap.call(qword[r12]); + + checkMatch(R"( + mov rcx,qword ptr [rcx+8] + mov rdx,qword ptr [rdx+8] + call qword ptr [r12] +)"); +} + +TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "HardInterferenceInt") +{ + ScopedRegX64 tmp1{regs, regs.takeReg(rArg4)}; + ScopedRegX64 tmp2{regs, regs.takeReg(rArg3)}; + ScopedRegX64 tmp3{regs, regs.takeReg(rArg2)}; + ScopedRegX64 tmp4{regs, regs.takeReg(rArg1)}; + callWrap.addArgument(SizeX64::qword, tmp1); + callWrap.addArgument(SizeX64::qword, tmp2); + callWrap.addArgument(SizeX64::qword, tmp3); + callWrap.addArgument(SizeX64::qword, tmp4); + callWrap.call(qword[r12]); + + checkMatch(R"( + mov rax,r9 + mov r9,rcx + mov rcx,rax + mov rax,r8 + mov r8,rdx + mov rdx,rax + call qword ptr [r12] +)"); +} + +TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "HardInterferenceInt2") +{ + ScopedRegX64 tmp1{regs, regs.takeReg(rArg4d)}; + ScopedRegX64 tmp2{regs, regs.takeReg(rArg3d)}; + ScopedRegX64 tmp3{regs, regs.takeReg(rArg2d)}; + ScopedRegX64 tmp4{regs, regs.takeReg(rArg1d)}; + callWrap.addArgument(SizeX64::dword, tmp1); + callWrap.addArgument(SizeX64::dword, tmp2); + callWrap.addArgument(SizeX64::dword, tmp3); + callWrap.addArgument(SizeX64::dword, tmp4); + callWrap.call(qword[r12]); + + checkMatch(R"( + mov eax,r9d + mov r9d,ecx + mov ecx,eax + mov eax,r8d + mov r8d,edx + mov edx,eax + call qword ptr [r12] +)"); +} + +TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "HardInterferenceFp") +{ + ScopedRegX64 tmp1{regs, regs.takeReg(xmm1)}; + ScopedRegX64 tmp2{regs, regs.takeReg(xmm0)}; + callWrap.addArgument(SizeX64::xmmword, tmp1); + callWrap.addArgument(SizeX64::xmmword, tmp2); + callWrap.call(qword[r12]); + + checkMatch(R"( + vmovsd xmm2,xmm1,xmm1 + vmovsd xmm1,xmm0,xmm0 + vmovsd xmm0,xmm2,xmm2 + call qword ptr [r12] +)"); +} + +TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "HardInterferenceBoth") +{ + ScopedRegX64 int1{regs, regs.takeReg(rArg2)}; + ScopedRegX64 int2{regs, regs.takeReg(rArg1)}; + ScopedRegX64 fp1{regs, regs.takeReg(xmm3)}; + ScopedRegX64 fp2{regs, regs.takeReg(xmm2)}; + callWrap.addArgument(SizeX64::qword, int1); + callWrap.addArgument(SizeX64::qword, int2); + callWrap.addArgument(SizeX64::xmmword, fp1); + callWrap.addArgument(SizeX64::xmmword, fp2); + callWrap.call(qword[r12]); + + checkMatch(R"( + mov rax,rdx + mov rdx,rcx + mov rcx,rax + vmovsd xmm0,xmm3,xmm3 + vmovsd xmm3,xmm2,xmm2 + vmovsd xmm2,xmm0,xmm0 + call qword ptr [r12] +)"); +} + +TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "FakeMultiuseInterferenceMem") +{ + ScopedRegX64 tmp1{regs, regs.takeReg(rArg1)}; + ScopedRegX64 tmp2{regs, regs.takeReg(rArg2)}; + callWrap.addArgument(SizeX64::qword, qword[tmp1.reg + tmp2.reg + 8]); + callWrap.addArgument(SizeX64::qword, qword[tmp2.reg + 16]); + tmp1.release(); + tmp2.release(); + callWrap.call(qword[r12]); + + checkMatch(R"( + mov rcx,qword ptr [rcx+rdx+8] + mov rdx,qword ptr [rdx+010h] + call qword ptr [r12] +)"); +} + +TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "HardMultiuseInterferenceMem1") +{ + ScopedRegX64 tmp1{regs, regs.takeReg(rArg1)}; + ScopedRegX64 tmp2{regs, regs.takeReg(rArg2)}; + callWrap.addArgument(SizeX64::qword, qword[tmp1.reg + tmp2.reg + 8]); + callWrap.addArgument(SizeX64::qword, qword[tmp1.reg + 16]); + tmp1.release(); + tmp2.release(); + callWrap.call(qword[r12]); + + checkMatch(R"( + mov rax,rcx + mov rcx,qword ptr [rax+rdx+8] + mov rdx,qword ptr [rax+010h] + call qword ptr [r12] +)"); +} + +TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "HardMultiuseInterferenceMem2") +{ + ScopedRegX64 tmp1{regs, regs.takeReg(rArg1)}; + ScopedRegX64 tmp2{regs, regs.takeReg(rArg2)}; + callWrap.addArgument(SizeX64::qword, qword[tmp1.reg + tmp2.reg + 8]); + callWrap.addArgument(SizeX64::qword, qword[tmp1.reg + tmp2.reg + 16]); + tmp1.release(); + tmp2.release(); + callWrap.call(qword[r12]); + + checkMatch(R"( + mov rax,rcx + mov rcx,qword ptr [rax+rdx+8] + mov rdx,qword ptr [rax+rdx+010h] + call qword ptr [r12] +)"); +} + +TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "HardMultiuseInterferenceMem3") +{ + ScopedRegX64 tmp1{regs, regs.takeReg(rArg3)}; + ScopedRegX64 tmp2{regs, regs.takeReg(rArg2)}; + ScopedRegX64 tmp3{regs, regs.takeReg(rArg1)}; + callWrap.addArgument(SizeX64::qword, qword[tmp1.reg + tmp2.reg + 8]); + callWrap.addArgument(SizeX64::qword, qword[tmp2.reg + tmp3.reg + 16]); + callWrap.addArgument(SizeX64::qword, qword[tmp3.reg + tmp1.reg + 16]); + tmp1.release(); + tmp2.release(); + tmp3.release(); + callWrap.call(qword[r12]); + + checkMatch(R"( + mov rax,r8 + mov r8,qword ptr [rcx+rax+010h] + mov rbx,rdx + mov rdx,qword ptr [rbx+rcx+010h] + mov rcx,qword ptr [rax+rbx+8] + call qword ptr [r12] +)"); +} + +TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "InterferenceWithCallArg1") +{ + ScopedRegX64 tmp1{regs, regs.takeReg(rArg1)}; + callWrap.addArgument(SizeX64::qword, qword[tmp1.reg + 8]); + callWrap.call(qword[tmp1.release() + 16]); + + checkMatch(R"( + mov rax,rcx + mov rcx,qword ptr [rax+8] + call qword ptr [rax+010h] +)"); +} + +TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "InterferenceWithCallArg2") +{ + ScopedRegX64 tmp1{regs, regs.takeReg(rArg1)}; + ScopedRegX64 tmp2{regs, regs.takeReg(rArg2)}; + callWrap.addArgument(SizeX64::qword, tmp2); + callWrap.call(qword[tmp1.release() + 16]); + + checkMatch(R"( + mov rax,rcx + mov rcx,rdx + call qword ptr [rax+010h] +)"); +} + +TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "InterferenceWithCallArg3") +{ + ScopedRegX64 tmp1{regs, regs.takeReg(rArg1)}; + callWrap.addArgument(SizeX64::qword, tmp1.reg); + callWrap.call(qword[tmp1.release() + 16]); + + checkMatch(R"( + call qword ptr [rcx+010h] +)"); +} + +TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "WithLastIrInstUse1") +{ + IrInst irInst1; + IrOp irOp1 = {IrOpKind::Inst, 0}; + irInst1.regX64 = regs.takeReg(xmm0); + irInst1.lastUse = 1; + function.instructions.push_back(irInst1); + callWrap.instIdx = irInst1.lastUse; + + callWrap.addArgument(SizeX64::xmmword, irInst1.regX64, irOp1); // Already in its place + callWrap.addArgument(SizeX64::xmmword, qword[r12 + 8]); + callWrap.call(qword[r12]); + + checkMatch(R"( + vmovsd xmm1,qword ptr [r12+8] + call qword ptr [r12] +)"); +} + +TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "WithLastIrInstUse2") +{ + IrInst irInst1; + IrOp irOp1 = {IrOpKind::Inst, 0}; + irInst1.regX64 = regs.takeReg(xmm0); + irInst1.lastUse = 1; + function.instructions.push_back(irInst1); + callWrap.instIdx = irInst1.lastUse; + + callWrap.addArgument(SizeX64::xmmword, qword[r12 + 8]); + callWrap.addArgument(SizeX64::xmmword, irInst1.regX64, irOp1); + callWrap.call(qword[r12]); + + checkMatch(R"( + vmovsd xmm1,xmm0,xmm0 + vmovsd xmm0,qword ptr [r12+8] + call qword ptr [r12] +)"); +} + +TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "WithLastIrInstUse3") +{ + IrInst irInst1; + IrOp irOp1 = {IrOpKind::Inst, 0}; + irInst1.regX64 = regs.takeReg(xmm0); + irInst1.lastUse = 1; + function.instructions.push_back(irInst1); + callWrap.instIdx = irInst1.lastUse; + + callWrap.addArgument(SizeX64::xmmword, irInst1.regX64, irOp1); + callWrap.addArgument(SizeX64::xmmword, irInst1.regX64, irOp1); + callWrap.call(qword[r12]); + + checkMatch(R"( + vmovsd xmm1,xmm0,xmm0 + call qword ptr [r12] +)"); +} + +TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "WithLastIrInstUse4") +{ + IrInst irInst1; + IrOp irOp1 = {IrOpKind::Inst, 0}; + irInst1.regX64 = regs.takeReg(rax); + irInst1.lastUse = 1; + function.instructions.push_back(irInst1); + callWrap.instIdx = irInst1.lastUse; + + ScopedRegX64 tmp{regs, regs.takeReg(rdx)}; + callWrap.addArgument(SizeX64::qword, r15); + callWrap.addArgument(SizeX64::qword, irInst1.regX64, irOp1); + callWrap.addArgument(SizeX64::qword, tmp); + callWrap.call(qword[r12]); + + checkMatch(R"( + mov rcx,r15 + mov r8,rdx + mov rdx,rax + call qword ptr [r12] +)"); +} + +TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "ExtraCoverage") +{ + ScopedRegX64 tmp1{regs, regs.takeReg(rArg1)}; + ScopedRegX64 tmp2{regs, regs.takeReg(rArg2)}; + callWrap.addArgument(SizeX64::qword, addr[r12 + 8]); + callWrap.addArgument(SizeX64::qword, addr[r12 + 16]); + callWrap.addArgument(SizeX64::xmmword, xmmword[r13]); + callWrap.call(qword[tmp1.release() + tmp2.release()]); + + checkMatch(R"( + vmovups xmm2,xmmword ptr [r13] + mov rax,rcx + lea rcx,none ptr [r12+8] + mov rbx,rdx + lea rdx,none ptr [r12+010h] + call qword ptr [rax+rbx] +)"); +} + +TEST_SUITE_END(); diff --git a/tests/Lexer.test.cpp b/tests/Lexer.test.cpp index 7fcc1e542..78d1389a6 100644 --- a/tests/Lexer.test.cpp +++ b/tests/Lexer.test.cpp @@ -157,8 +157,6 @@ TEST_CASE("string_interpolation_basic") TEST_CASE("string_interpolation_full") { - ScopedFastFlag sff("LuauFixInterpStringMid", true); - const std::string testInput = R"(`foo {"bar"} {"baz"} end`)"; Luau::Allocator alloc; AstNameTable table(alloc); diff --git a/tests/Linter.test.cpp b/tests/Linter.test.cpp index 0f1346161..8bef5922f 100644 --- a/tests/Linter.test.cpp +++ b/tests/Linter.test.cpp @@ -1444,8 +1444,6 @@ TEST_CASE_FIXTURE(Fixture, "LintHygieneUAF") TEST_CASE_FIXTURE(BuiltinsFixture, "DeprecatedApiTyped") { - ScopedFastFlag sff("LuauImproveDeprecatedApiLint", true); - unfreeze(frontend.globals.globalTypes); TypeId instanceType = frontend.globals.globalTypes.addType(ClassType{"Instance", {}, std::nullopt, std::nullopt, {}, {}, "Test"}); persist(instanceType); @@ -1496,8 +1494,6 @@ end TEST_CASE_FIXTURE(BuiltinsFixture, "DeprecatedApiUntyped") { - ScopedFastFlag sff("LuauImproveDeprecatedApiLint", true); - if (TableType* ttv = getMutable(getGlobalBinding(frontend.globals, "table"))) { ttv->props["foreach"].deprecated = true; diff --git a/tests/Normalize.test.cpp b/tests/Normalize.test.cpp index a495ee231..4378bab8b 100644 --- a/tests/Normalize.test.cpp +++ b/tests/Normalize.test.cpp @@ -470,7 +470,6 @@ TEST_SUITE_END(); struct NormalizeFixture : Fixture { - ScopedFastFlag sff1{"LuauNegatedFunctionTypes", true}; ScopedFastFlag sff2{"LuauNegatedClassTypes", true}; TypeArena arena; diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index 9ff16d16b..ef5aabbe3 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -1040,8 +1040,6 @@ TEST_CASE_FIXTURE(Fixture, "parse_interpolated_string_call_without_parens") TEST_CASE_FIXTURE(Fixture, "parse_interpolated_string_without_expression") { - ScopedFastFlag sff("LuauFixInterpStringMid", true); - try { parse(R"( diff --git a/tests/TypeInfer.aliases.test.cpp b/tests/TypeInfer.aliases.test.cpp index 022abea0b..52de15c75 100644 --- a/tests/TypeInfer.aliases.test.cpp +++ b/tests/TypeInfer.aliases.test.cpp @@ -1014,4 +1014,34 @@ TEST_CASE_FIXTURE(Fixture, "another_thing_from_roact") LUAU_REQUIRE_NO_ERRORS(result); } +/* + * It is sometimes possible for type alias resolution to produce a TypeId that + * belongs to a different module. + * + * We must not mutate any fields of the resulting type when this happens. The + * memory has been frozen. + */ +TEST_CASE_FIXTURE(BuiltinsFixture, "alias_expands_to_bare_reference_to_imported_type") +{ + fileResolver.source["game/A"] = R"( + --!strict + export type Object = {[string]: any} + return {} + )"; + + fileResolver.source["game/B"] = R"( + local A = require(script.Parent.A) + + type Object = A.Object + type ReadOnly = T + + local function f(): ReadOnly + return nil :: any + end + )"; + + CheckResult result = frontend.check("game/B"); + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index c7f9684b3..f1d42c6a4 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -1784,7 +1784,6 @@ z = y -- Not OK, so the line is colorable TEST_CASE_FIXTURE(Fixture, "function_is_supertype_of_concrete_functions") { - ScopedFastFlag sff{"LuauNegatedFunctionTypes", true}; registerHiddenTypes(&frontend); CheckResult result = check(R"( @@ -1803,7 +1802,6 @@ TEST_CASE_FIXTURE(Fixture, "function_is_supertype_of_concrete_functions") TEST_CASE_FIXTURE(Fixture, "concrete_functions_are_not_supertypes_of_function") { - ScopedFastFlag sff{"LuauNegatedFunctionTypes", true}; registerHiddenTypes(&frontend); CheckResult result = check(R"( @@ -1824,7 +1822,6 @@ TEST_CASE_FIXTURE(Fixture, "concrete_functions_are_not_supertypes_of_function") TEST_CASE_FIXTURE(Fixture, "other_things_are_not_related_to_function") { - ScopedFastFlag sff{"LuauNegatedFunctionTypes", true}; registerHiddenTypes(&frontend); CheckResult result = check(R"( diff --git a/tests/TypeInfer.loops.test.cpp b/tests/TypeInfer.loops.test.cpp index 511cbc763..7a1343584 100644 --- a/tests/TypeInfer.loops.test.cpp +++ b/tests/TypeInfer.loops.test.cpp @@ -707,4 +707,26 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "cli_68448_iterators_need_not_accept_nil") CHECK(toString(requireType("makeEnum"), {true}) == "({a}) -> {| [a]: a |}"); } +TEST_CASE_FIXTURE(Fixture, "iterate_over_free_table") +{ + CheckResult result = check(R"( + function print(x) end + + function dump(tbl) + print(tbl.whatever) + for k, v in tbl do + print(k) + print(v) + end + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + GenericError* ge = get(result.errors[0]); + REQUIRE(ge); + + CHECK("Cannot iterate over a table without indexer" == ge->message); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.oop.test.cpp b/tests/TypeInfer.oop.test.cpp index eb4937fde..f2b3d0559 100644 --- a/tests/TypeInfer.oop.test.cpp +++ b/tests/TypeInfer.oop.test.cpp @@ -381,4 +381,29 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "react_style_oo") CHECK("string" == toString(requireType("hello"))); } +TEST_CASE_FIXTURE(BuiltinsFixture, "cycle_between_object_constructor_and_alias") +{ + CheckResult result = check(R"( + local T = {} + T.__index = T + + function T.new(): T + return setmetatable({}, T) + end + + export type T = typeof(T.new()) + + return T + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + auto module = getMainModule(); + + REQUIRE(module->exportedTypeBindings.count("T")); + + TypeId aliasType = module->exportedTypeBindings["T"].type; + CHECK_MESSAGE(get(follow(aliasType)), "Expected metatable type but got: " << toString(aliasType)); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.operators.test.cpp b/tests/TypeInfer.operators.test.cpp index 8c289c7be..174bc310e 100644 --- a/tests/TypeInfer.operators.test.cpp +++ b/tests/TypeInfer.operators.test.cpp @@ -860,8 +860,6 @@ TEST_CASE_FIXTURE(Fixture, "operator_eq_operands_are_not_subtypes_of_each_other_ TEST_CASE_FIXTURE(Fixture, "operator_eq_completely_incompatible") { - ScopedFastFlag sff{"LuauIntersectionTestForEquality", true}; - CheckResult result = check(R"( local a: string | number = "hi" local b: {x: string}? = {x = "bye"} @@ -970,8 +968,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "expected_types_through_binary_or") TEST_CASE_FIXTURE(ClassFixture, "unrelated_classes_cannot_be_compared") { - ScopedFastFlag sff{"LuauIntersectionTestForEquality", true}; - CheckResult result = check(R"( local a = BaseClass.New() local b = UnrelatedClass.New() @@ -984,8 +980,6 @@ TEST_CASE_FIXTURE(ClassFixture, "unrelated_classes_cannot_be_compared") TEST_CASE_FIXTURE(Fixture, "unrelated_primitives_cannot_be_compared") { - ScopedFastFlag sff{"LuauIntersectionTestForEquality", true}; - CheckResult result = check(R"( local c = 5 == true )"); diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index 30f77d681..38e7e2f31 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -176,8 +176,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "error_on_eq_metamethod_returning_a_type_othe // We need refine both operands as `never` in the `==` branch. TEST_CASE_FIXTURE(Fixture, "lvalue_equals_another_lvalue_with_no_overlap") { - ScopedFastFlag sff{"LuauIntersectionTestForEquality", true}; - CheckResult result = check(R"( local function f(a: string, b: boolean?) if a == b then diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index 21ac6421b..468adc2c6 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -18,7 +18,6 @@ LUAU_FASTFLAG(LuauLowerBoundsCalculation); LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); LUAU_FASTFLAG(LuauInstantiateInSubtyping) LUAU_FASTFLAG(LuauTypeMismatchInvarianceInError) -LUAU_FASTFLAG(LuauDontExtendUnsealedRValueTables) TEST_SUITE_BEGIN("TableTests"); @@ -913,10 +912,7 @@ TEST_CASE_FIXTURE(Fixture, "disallow_indexing_into_an_unsealed_table_with_no_ind local k1 = getConstant("key1") )"); - if (FFlag::LuauDontExtendUnsealedRValueTables) - CHECK("any" == toString(requireType("k1"))); - else - CHECK("a" == toString(requireType("k1"))); + CHECK("any" == toString(requireType("k1"))); LUAU_REQUIRE_NO_ERRORS(result); } @@ -3542,8 +3538,6 @@ _ = {_,} TEST_CASE_FIXTURE(Fixture, "when_augmenting_an_unsealed_table_with_an_indexer_apply_the_correct_scope_to_the_indexer_type") { - ScopedFastFlag sff{"LuauDontExtendUnsealedRValueTables", true}; - CheckResult result = check(R"( local events = {} local mockObserveEvent = function(_, key, callback) @@ -3572,8 +3566,6 @@ TEST_CASE_FIXTURE(Fixture, "when_augmenting_an_unsealed_table_with_an_indexer_ap TEST_CASE_FIXTURE(Fixture, "dont_extend_unsealed_tables_in_rvalue_position") { - ScopedFastFlag sff{"LuauDontExtendUnsealedRValueTables", true}; - CheckResult result = check(R"( local testDictionary = { FruitName = "Lemon", diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index 7c4bfb2e9..7e317f2ef 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -1194,7 +1194,6 @@ TEST_CASE_FIXTURE(Fixture, "dcr_delays_expansion_of_function_containing_blocked_ { ScopedFastFlag sff[] = { {"DebugLuauDeferredConstraintResolution", true}, - {"LuauTinyUnifyNormalsFix", true}, // If we run this with error-suppression, it triggers an assertion. // FATAL ERROR: Assertion failed: !"Internal error: Trying to normalize a BlockedType" {"LuauTransitiveSubtyping", false}, diff --git a/tools/faillist.txt b/tools/faillist.txt index 76e5972dc..31fc82dae 100644 --- a/tools/faillist.txt +++ b/tools/faillist.txt @@ -25,9 +25,6 @@ BuiltinTests.string_format_correctly_ordered_types BuiltinTests.string_format_report_all_type_errors_at_correct_positions BuiltinTests.string_format_tostring_specifier_type_constraint BuiltinTests.string_format_use_correct_argument2 -BuiltinTests.table_pack -BuiltinTests.table_pack_reduce -BuiltinTests.table_pack_variadic DefinitionTests.class_definition_overload_metamethods DefinitionTests.class_definition_string_props GenericsTests.apply_type_function_nested_generics2 @@ -114,7 +111,6 @@ TableTests.table_subtyping_with_missing_props_dont_report_multiple_errors TableTests.table_unification_4 TableTests.used_colon_instead_of_dot TableTests.used_dot_instead_of_colon -ToString.named_metatable_toStringNamedFunction ToString.toStringDetailed2 ToString.toStringErrorPack ToString.toStringNamedFunction_generic_pack @@ -137,6 +133,7 @@ TypeInfer.check_type_infer_recursion_count TypeInfer.cli_50041_committing_txnlog_in_apollo_client_error TypeInfer.dont_report_type_errors_within_an_AstExprError TypeInfer.dont_report_type_errors_within_an_AstStatError +TypeInfer.follow_on_new_types_in_substitution TypeInfer.fuzz_free_table_type_change_during_index_check TypeInfer.infer_assignment_value_types_mutable_lval TypeInfer.no_stack_overflow_from_isoptional diff --git a/tools/natvis/CodeGen.natvis b/tools/natvis/CodeGen.natvis index 5ff6e1432..84fb3329a 100644 --- a/tools/natvis/CodeGen.natvis +++ b/tools/natvis/CodeGen.natvis @@ -1,45 +1,46 @@ - - noreg - rip + + noreg + rip - al - cl - dl - bl + al + cl + dl + bl - eax - ecx - edx - ebx - esp - ebp - esi - edi - e{(int)index,d}d + eax + ecx + edx + ebx + esp + ebp + esi + edi + e{(int)index,d}d - rax - rcx - rdx - rbx - rsp - rbp - rsi - rdi - r{(int)index,d} + rax + rcx + rdx + rbx + rsp + rbp + rsi + rdi + r{(int)index,d} - xmm{(int)index,d} + xmm{(int)index,d} - ymm{(int)index,d} + ymm{(int)index,d} - + {base} {memSize,en} ptr[{base} + {index}*{(int)scale,d} + {imm}] {memSize,en} ptr[{index}*{(int)scale,d} + {imm}] {memSize,en} ptr[{base} + {imm}] + {memSize,en} ptr[{base} + {imm}] {memSize,en} ptr[{imm}] {imm} From d071e410ce59a2c3deb2deb38ee99836eda49874 Mon Sep 17 00:00:00 2001 From: Vyacheslav Egorov Date: Fri, 31 Mar 2023 16:25:13 +0300 Subject: [PATCH 44/66] g++ build fix --- CodeGen/src/AssemblyBuilderA64.cpp | 2 +- CodeGen/src/IrRegAllocA64.cpp | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/CodeGen/src/AssemblyBuilderA64.cpp b/CodeGen/src/AssemblyBuilderA64.cpp index e7f50b142..a80003e94 100644 --- a/CodeGen/src/AssemblyBuilderA64.cpp +++ b/CodeGen/src/AssemblyBuilderA64.cpp @@ -496,7 +496,7 @@ void AssemblyBuilderA64::fcmpz(RegisterA64 src) { LUAU_ASSERT(src.kind == KindA64::d); - placeFCMP("fcmp", src, {src.kind, 0}, 0b11110'01'1, 0b01); + placeFCMP("fcmp", src, RegisterA64{src.kind, 0}, 0b11110'01'1, 0b01); } void AssemblyBuilderA64::fcsel(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, ConditionA64 cond) diff --git a/CodeGen/src/IrRegAllocA64.cpp b/CodeGen/src/IrRegAllocA64.cpp index 3609c8e25..dc18ab56d 100644 --- a/CodeGen/src/IrRegAllocA64.cpp +++ b/CodeGen/src/IrRegAllocA64.cpp @@ -55,7 +55,7 @@ RegisterA64 IrRegAllocA64::allocReg(KindA64 kind) int index = setBit(set.free); set.free &= ~(1u << index); - return {kind, uint8_t(index)}; + return RegisterA64{kind, uint8_t(index)}; } RegisterA64 IrRegAllocA64::allocTemp(KindA64 kind) @@ -73,7 +73,7 @@ RegisterA64 IrRegAllocA64::allocTemp(KindA64 kind) set.free &= ~(1u << index); set.temp |= 1u << index; - return {kind, uint8_t(index)}; + return RegisterA64{kind, uint8_t(index)}; } RegisterA64 IrRegAllocA64::allocReuse(KindA64 kind, uint32_t index, std::initializer_list oprefs) From 5309401f49a78b01a6df158a0e538eab4cb34e6c Mon Sep 17 00:00:00 2001 From: Andy Friesen Date: Fri, 7 Apr 2023 12:56:27 -0700 Subject: [PATCH 45/66] Sync to upstream/release/571 --- Analysis/include/Luau/BuiltinDefinitions.h | 4 +- Analysis/include/Luau/Frontend.h | 19 +- Analysis/include/Luau/Type.h | 51 +- Analysis/include/Luau/TypePack.h | 38 +- Analysis/include/Luau/Unifiable.h | 40 +- Analysis/include/Luau/VisitType.h | 4 +- Analysis/src/Autocomplete.cpp | 11 +- Analysis/src/BuiltinDefinitions.cpp | 106 +-- Analysis/src/Clone.cpp | 18 +- Analysis/src/Frontend.cpp | 192 ++---- Analysis/src/Normalize.cpp | 31 +- Analysis/src/ToString.cpp | 3 +- Analysis/src/Type.cpp | 67 +- Analysis/src/TypeAttach.cpp | 18 +- Analysis/src/TypeInfer.cpp | 18 +- Analysis/src/TypePack.cpp | 71 +- Analysis/src/Unifiable.cpp | 65 -- Analysis/src/Unifier.cpp | 22 +- CLI/Analyze.cpp | 13 +- CodeGen/include/Luau/AssemblyBuilderA64.h | 7 +- CodeGen/include/Luau/AssemblyBuilderX64.h | 1 + CodeGen/include/Luau/IrAnalysis.h | 2 + CodeGen/include/Luau/IrCallWrapperX64.h | 5 +- CodeGen/include/Luau/IrData.h | 41 +- CodeGen/include/Luau/IrRegAllocX64.h | 69 +- CodeGen/include/Luau/IrUtils.h | 2 + CodeGen/include/Luau/RegisterA64.h | 9 + CodeGen/src/AssemblyBuilderA64.cpp | 2 +- CodeGen/src/AssemblyBuilderX64.cpp | 10 + CodeGen/src/CodeGen.cpp | 15 +- CodeGen/src/EmitBuiltinsX64.cpp | 16 +- CodeGen/src/EmitCommonA64.cpp | 24 + CodeGen/src/EmitCommonA64.h | 1 + CodeGen/src/EmitCommonX64.cpp | 65 +- CodeGen/src/EmitCommonX64.h | 18 +- CodeGen/src/EmitInstructionX64.cpp | 82 +-- CodeGen/src/EmitInstructionX64.h | 6 +- CodeGen/src/IrAnalysis.cpp | 108 +-- CodeGen/src/IrBuilder.cpp | 8 +- CodeGen/src/IrCallWrapperX64.cpp | 54 +- CodeGen/src/IrDump.cpp | 16 +- CodeGen/src/IrLoweringA64.cpp | 741 ++++++++++++++++++++- CodeGen/src/IrLoweringA64.h | 2 + CodeGen/src/IrLoweringX64.cpp | 245 +++---- CodeGen/src/IrLoweringX64.h | 10 +- CodeGen/src/IrRegAllocA64.cpp | 13 +- CodeGen/src/IrRegAllocA64.h | 1 + CodeGen/src/IrRegAllocX64.cpp | 270 +++++++- CodeGen/src/IrTranslateBuiltins.cpp | 30 +- CodeGen/src/IrTranslation.cpp | 68 +- CodeGen/src/IrTranslation.h | 2 + CodeGen/src/IrUtils.cpp | 128 ++++ CodeGen/src/NativeState.cpp | 1 + CodeGen/src/OptimizeConstProp.cpp | 47 +- VM/include/lua.h | 2 +- VM/src/ldo.cpp | 48 +- VM/src/ltablib.cpp | 280 +++----- tests/AssemblyBuilderX64.test.cpp | 2 + tests/Autocomplete.test.cpp | 6 +- tests/Conformance.test.cpp | 109 ++- tests/Fixture.cpp | 11 +- tests/Fixture.h | 1 - tests/Frontend.test.cpp | 2 +- tests/IrCallWrapperX64.test.cpp | 134 ++-- tests/Linter.test.cpp | 2 +- tests/Normalize.test.cpp | 14 + tests/TypeInfer.definitions.test.cpp | 12 +- tests/TypeInfer.generics.test.cpp | 2 +- tests/TypeInfer.provisional.test.cpp | 4 +- tests/TypeInfer.test.cpp | 9 +- tests/TypePack.test.cpp | 2 +- tests/TypeVar.test.cpp | 2 +- tests/conformance/apicalls.lua | 8 + tests/conformance/pcall.lua | 7 + tests/conformance/sort.lua | 48 +- tests/conformance/strings.lua | 7 + tools/natvis/Common.natvis | 27 + tools/test_dcr.py | 22 +- 78 files changed, 2540 insertions(+), 1131 deletions(-) create mode 100644 tools/natvis/Common.natvis diff --git a/Analysis/include/Luau/BuiltinDefinitions.h b/Analysis/include/Luau/BuiltinDefinitions.h index 162139581..d44576385 100644 --- a/Analysis/include/Luau/BuiltinDefinitions.h +++ b/Analysis/include/Luau/BuiltinDefinitions.h @@ -16,9 +16,7 @@ struct TypeArena; void registerBuiltinTypes(GlobalTypes& globals); -void registerBuiltinGlobals(TypeChecker& typeChecker, GlobalTypes& globals); -void registerBuiltinGlobals(Frontend& frontend); - +void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeCheckForAutocomplete = false); TypeId makeUnion(TypeArena& arena, std::vector&& types); TypeId makeIntersection(TypeArena& arena, std::vector&& types); diff --git a/Analysis/include/Luau/Frontend.h b/Analysis/include/Luau/Frontend.h index 68ba8ff5d..82251378e 100644 --- a/Analysis/include/Luau/Frontend.h +++ b/Analysis/include/Luau/Frontend.h @@ -8,7 +8,6 @@ #include "Luau/Scope.h" #include "Luau/TypeInfer.h" #include "Luau/Variant.h" - #include #include #include @@ -36,9 +35,6 @@ struct LoadDefinitionFileResult ModulePtr module; }; -LoadDefinitionFileResult loadDefinitionFile(TypeChecker& typeChecker, GlobalTypes& globals, ScopePtr targetScope, std::string_view definition, - const std::string& packageName, bool captureComments); - std::optional parseMode(const std::vector& hotcomments); std::vector parsePathExpr(const AstExpr& pathExpr); @@ -55,7 +51,9 @@ std::optional pathExprToModuleName(const ModuleName& currentModuleNa * error when we try during typechecking. */ std::optional pathExprToModuleName(const ModuleName& currentModuleName, const AstExpr& expr); - +// TODO: Deprecate this code path when we move away from the old solver +LoadDefinitionFileResult loadDefinitionFileNoDCR(TypeChecker& typeChecker, GlobalTypes& globals, ScopePtr targetScope, std::string_view definition, + const std::string& packageName, bool captureComments); struct SourceNode { bool hasDirtySourceModule() const @@ -140,10 +138,6 @@ struct Frontend CheckResult check(const ModuleName& name, std::optional optionOverride = {}); // new shininess - // Use 'check' with 'runLintChecks' set to true in FrontendOptions (enabledLintWarnings be set there as well) - LintResult lint_DEPRECATED(const ModuleName& name, std::optional enabledLintWarnings = {}); - LintResult lint_DEPRECATED(const SourceModule& module, std::optional enabledLintWarnings = {}); - bool isDirty(const ModuleName& name, bool forAutocomplete = false) const; void markDirty(const ModuleName& name, std::vector* markedDirty = nullptr); @@ -164,10 +158,11 @@ struct Frontend ScopePtr addEnvironment(const std::string& environmentName); ScopePtr getEnvironmentScope(const std::string& environmentName) const; - void registerBuiltinDefinition(const std::string& name, std::function); + void registerBuiltinDefinition(const std::string& name, std::function); void applyBuiltinDefinitionToEnvironment(const std::string& environmentName, const std::string& definitionName); - LoadDefinitionFileResult loadDefinitionFile(std::string_view source, const std::string& packageName, bool captureComments); + LoadDefinitionFileResult loadDefinitionFile(GlobalTypes& globals, ScopePtr targetScope, std::string_view source, const std::string& packageName, + bool captureComments, bool typeCheckForAutocomplete = false); private: ModulePtr check(const SourceModule& sourceModule, Mode mode, std::vector requireCycles, bool forAutocomplete = false, bool recordJsonLog = false); @@ -182,7 +177,7 @@ struct Frontend ScopePtr getModuleEnvironment(const SourceModule& module, const Config& config, bool forAutocomplete) const; std::unordered_map environments; - std::unordered_map> builtinDefinitions; + std::unordered_map> builtinDefinitions; BuiltinTypes builtinTypes_; diff --git a/Analysis/include/Luau/Type.h b/Analysis/include/Luau/Type.h index dba2a8de2..cff86df42 100644 --- a/Analysis/include/Luau/Type.h +++ b/Analysis/include/Luau/Type.h @@ -75,13 +75,44 @@ using TypeId = const Type*; using Name = std::string; // A free type var is one whose exact shape has yet to be fully determined. -using FreeType = Unifiable::Free; +struct FreeType +{ + explicit FreeType(TypeLevel level); + explicit FreeType(Scope* scope); + FreeType(Scope* scope, TypeLevel level); -// When a free type var is unified with any other, it is then "bound" -// to that type var, indicating that the two types are actually the same type. -using BoundType = Unifiable::Bound; + int index; + TypeLevel level; + Scope* scope = nullptr; + + // True if this free type variable is part of a mutually + // recursive type alias whose definitions haven't been + // resolved yet. + bool forwardedTypeAlias = false; +}; + +struct GenericType +{ + // By default, generics are global, with a synthetic name + GenericType(); -using GenericType = Unifiable::Generic; + explicit GenericType(TypeLevel level); + explicit GenericType(const Name& name); + explicit GenericType(Scope* scope); + + GenericType(TypeLevel level, const Name& name); + GenericType(Scope* scope, const Name& name); + + int index; + TypeLevel level; + Scope* scope = nullptr; + Name name; + bool explicitName = false; +}; + +// When an equality constraint is found, it is then "bound" to that type, +// indicating that the two types are actually the same type. +using BoundType = Unifiable::Bound; using Tags = std::vector; @@ -395,9 +426,11 @@ struct TableType // Represents a metatable attached to a table type. Somewhat analogous to a bound type. struct MetatableType { - // Always points to a TableType. + // Should always be a TableType. TypeId table; - // Always points to either a TableType or a MetatableType. + // Should almost always either be a TableType or another MetatableType, + // though it is possible for other types (like AnyType and ErrorType) to + // find their way here sometimes. TypeId metatable; std::optional syntheticName; @@ -536,8 +569,8 @@ struct NegationType using ErrorType = Unifiable::Error; -using TypeVariant = Unifiable::Variant; +using TypeVariant = Unifiable::Variant; struct Type final { diff --git a/Analysis/include/Luau/TypePack.h b/Analysis/include/Luau/TypePack.h index 4831f2338..2ae56e5f0 100644 --- a/Analysis/include/Luau/TypePack.h +++ b/Analysis/include/Luau/TypePack.h @@ -12,20 +12,48 @@ namespace Luau { struct TypeArena; +struct TxnLog; struct TypePack; struct VariadicTypePack; struct BlockedTypePack; struct TypePackVar; +using TypePackId = const TypePackVar*; -struct TxnLog; +struct FreeTypePack +{ + explicit FreeTypePack(TypeLevel level); + explicit FreeTypePack(Scope* scope); + FreeTypePack(Scope* scope, TypeLevel level); + + int index; + TypeLevel level; + Scope* scope = nullptr; +}; + +struct GenericTypePack +{ + // By default, generics are global, with a synthetic name + GenericTypePack(); + explicit GenericTypePack(TypeLevel level); + explicit GenericTypePack(const Name& name); + explicit GenericTypePack(Scope* scope); + GenericTypePack(TypeLevel level, const Name& name); + GenericTypePack(Scope* scope, const Name& name); + + int index; + TypeLevel level; + Scope* scope = nullptr; + Name name; + bool explicitName = false; +}; -using TypePackId = const TypePackVar*; -using FreeTypePack = Unifiable::Free; using BoundTypePack = Unifiable::Bound; -using GenericTypePack = Unifiable::Generic; -using TypePackVariant = Unifiable::Variant; + +using ErrorTypePack = Unifiable::Error; + +using TypePackVariant = Unifiable::Variant; /* A TypePack is a rope-like string of TypeIds. We use this structure to encode * notions like packs of unknown length and packs of any length, as well as more diff --git a/Analysis/include/Luau/Unifiable.h b/Analysis/include/Luau/Unifiable.h index ae55f3734..79b3b7dea 100644 --- a/Analysis/include/Luau/Unifiable.h +++ b/Analysis/include/Luau/Unifiable.h @@ -83,24 +83,6 @@ using Name = std::string; int freshIndex(); -struct Free -{ - explicit Free(TypeLevel level); - explicit Free(Scope* scope); - explicit Free(Scope* scope, TypeLevel level); - - int index; - TypeLevel level; - Scope* scope = nullptr; - // True if this free type variable is part of a mutually - // recursive type alias whose definitions haven't been - // resolved yet. - bool forwardedTypeAlias = false; - -private: - static int DEPRECATED_nextIndex; -}; - template struct Bound { @@ -112,26 +94,6 @@ struct Bound Id boundTo; }; -struct Generic -{ - // By default, generics are global, with a synthetic name - Generic(); - explicit Generic(TypeLevel level); - explicit Generic(const Name& name); - explicit Generic(Scope* scope); - Generic(TypeLevel level, const Name& name); - Generic(Scope* scope, const Name& name); - - int index; - TypeLevel level; - Scope* scope = nullptr; - Name name; - bool explicitName = false; - -private: - static int DEPRECATED_nextIndex; -}; - struct Error { // This constructor has to be public, since it's used in Type and TypePack, @@ -145,6 +107,6 @@ struct Error }; template -using Variant = Luau::Variant, Generic, Error, Value...>; +using Variant = Luau::Variant, Error, Value...>; } // namespace Luau::Unifiable diff --git a/Analysis/include/Luau/VisitType.h b/Analysis/include/Luau/VisitType.h index ff4dfc3c3..95b2b0507 100644 --- a/Analysis/include/Luau/VisitType.h +++ b/Analysis/include/Luau/VisitType.h @@ -341,10 +341,10 @@ struct GenericTypeVisitor traverse(btv->boundTo); } - else if (auto ftv = get(tp)) + else if (auto ftv = get(tp)) visit(tp, *ftv); - else if (auto gtv = get(tp)) + else if (auto gtv = get(tp)) visit(tp, *gtv); else if (auto etv = get(tp)) diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index 3fdd93190..42fc9a717 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -13,8 +13,6 @@ #include #include -LUAU_FASTFLAGVARIABLE(LuauAutocompleteSkipNormalization, false); - static const std::unordered_set kStatementStartingKeywords = { "while", "if", "local", "repeat", "function", "do", "for", "return", "break", "continue", "type", "export"}; @@ -143,12 +141,9 @@ static bool checkTypeMatch(TypeId subTy, TypeId superTy, NotNull scope, T Normalizer normalizer{typeArena, builtinTypes, NotNull{&unifierState}}; Unifier unifier(NotNull{&normalizer}, Mode::Strict, scope, Location(), Variance::Covariant); - if (FFlag::LuauAutocompleteSkipNormalization) - { - // Cost of normalization can be too high for autocomplete response time requirements - unifier.normalize = false; - unifier.checkInhabited = false; - } + // Cost of normalization can be too high for autocomplete response time requirements + unifier.normalize = false; + unifier.checkInhabited = false; return unifier.canUnify(subTy, superTy).empty(); } diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index 2108b160f..7ed92fb41 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -212,7 +212,7 @@ void registerBuiltinTypes(GlobalTypes& globals) globals.globalScope->addBuiltinTypeBinding("never", TypeFun{{}, globals.builtinTypes->neverType}); } -void registerBuiltinGlobals(TypeChecker& typeChecker, GlobalTypes& globals) +void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeCheckForAutocomplete) { LUAU_ASSERT(!globals.globalTypes.types.isFrozen()); LUAU_ASSERT(!globals.globalTypes.typePacks.isFrozen()); @@ -220,8 +220,8 @@ void registerBuiltinGlobals(TypeChecker& typeChecker, GlobalTypes& globals) TypeArena& arena = globals.globalTypes; NotNull builtinTypes = globals.builtinTypes; - LoadDefinitionFileResult loadResult = - Luau::loadDefinitionFile(typeChecker, globals, globals.globalScope, getBuiltinDefinitionSource(), "@luau", /* captureComments */ false); + LoadDefinitionFileResult loadResult = frontend.loadDefinitionFile( + globals, globals.globalScope, getBuiltinDefinitionSource(), "@luau", /* captureComments */ false, typeCheckForAutocomplete); LUAU_ASSERT(loadResult.success); TypeId genericK = arena.addType(GenericType{"K"}); @@ -309,106 +309,6 @@ void registerBuiltinGlobals(TypeChecker& typeChecker, GlobalTypes& globals) attachDcrMagicFunction(getGlobalBinding(globals, "require"), dcrMagicFunctionRequire); } -void registerBuiltinGlobals(Frontend& frontend) -{ - GlobalTypes& globals = frontend.globals; - - LUAU_ASSERT(!globals.globalTypes.types.isFrozen()); - LUAU_ASSERT(!globals.globalTypes.typePacks.isFrozen()); - - registerBuiltinTypes(globals); - - TypeArena& arena = globals.globalTypes; - NotNull builtinTypes = globals.builtinTypes; - - LoadDefinitionFileResult loadResult = frontend.loadDefinitionFile(getBuiltinDefinitionSource(), "@luau", /* captureComments */ false); - LUAU_ASSERT(loadResult.success); - - TypeId genericK = arena.addType(GenericType{"K"}); - TypeId genericV = arena.addType(GenericType{"V"}); - TypeId mapOfKtoV = arena.addType(TableType{{}, TableIndexer(genericK, genericV), globals.globalScope->level, TableState::Generic}); - - std::optional stringMetatableTy = getMetatable(builtinTypes->stringType, builtinTypes); - LUAU_ASSERT(stringMetatableTy); - const TableType* stringMetatableTable = get(follow(*stringMetatableTy)); - LUAU_ASSERT(stringMetatableTable); - - auto it = stringMetatableTable->props.find("__index"); - LUAU_ASSERT(it != stringMetatableTable->props.end()); - - addGlobalBinding(globals, "string", it->second.type, "@luau"); - - // next(t: Table, i: K?) -> (K?, V) - TypePackId nextArgsTypePack = arena.addTypePack(TypePack{{mapOfKtoV, makeOption(builtinTypes, arena, genericK)}}); - TypePackId nextRetsTypePack = arena.addTypePack(TypePack{{makeOption(builtinTypes, arena, genericK), genericV}}); - addGlobalBinding(globals, "next", arena.addType(FunctionType{{genericK, genericV}, {}, nextArgsTypePack, nextRetsTypePack}), "@luau"); - - TypePackId pairsArgsTypePack = arena.addTypePack({mapOfKtoV}); - - TypeId pairsNext = arena.addType(FunctionType{nextArgsTypePack, nextRetsTypePack}); - TypePackId pairsReturnTypePack = arena.addTypePack(TypePack{{pairsNext, mapOfKtoV, builtinTypes->nilType}}); - - // pairs(t: Table) -> ((Table, K?) -> (K?, V), Table, nil) - addGlobalBinding(globals, "pairs", arena.addType(FunctionType{{genericK, genericV}, {}, pairsArgsTypePack, pairsReturnTypePack}), "@luau"); - - TypeId genericMT = arena.addType(GenericType{"MT"}); - - TableType tab{TableState::Generic, globals.globalScope->level}; - TypeId tabTy = arena.addType(tab); - - TypeId tableMetaMT = arena.addType(MetatableType{tabTy, genericMT}); - - addGlobalBinding(globals, "getmetatable", makeFunction(arena, std::nullopt, {genericMT}, {}, {tableMetaMT}, {genericMT}), "@luau"); - - // clang-format off - // setmetatable(T, MT) -> { @metatable MT, T } - addGlobalBinding(globals, "setmetatable", - arena.addType( - FunctionType{ - {genericMT}, - {}, - arena.addTypePack(TypePack{{tabTy, genericMT}}), - arena.addTypePack(TypePack{{tableMetaMT}}) - } - ), "@luau" - ); - // clang-format on - - for (const auto& pair : globals.globalScope->bindings) - { - persist(pair.second.typeId); - - if (TableType* ttv = getMutable(pair.second.typeId)) - { - if (!ttv->name) - ttv->name = "typeof(" + toString(pair.first) + ")"; - } - } - - attachMagicFunction(getGlobalBinding(globals, "assert"), magicFunctionAssert); - attachMagicFunction(getGlobalBinding(globals, "setmetatable"), magicFunctionSetMetaTable); - attachMagicFunction(getGlobalBinding(globals, "select"), magicFunctionSelect); - attachDcrMagicFunction(getGlobalBinding(globals, "select"), dcrMagicFunctionSelect); - - if (TableType* ttv = getMutable(getGlobalBinding(globals, "table"))) - { - // tabTy is a generic table type which we can't express via declaration syntax yet - ttv->props["freeze"] = makeProperty(makeFunction(arena, std::nullopt, {tabTy}, {tabTy}), "@luau/global/table.freeze"); - ttv->props["clone"] = makeProperty(makeFunction(arena, std::nullopt, {tabTy}, {tabTy}), "@luau/global/table.clone"); - - ttv->props["getn"].deprecated = true; - ttv->props["getn"].deprecatedSuggestion = "#"; - ttv->props["foreach"].deprecated = true; - ttv->props["foreachi"].deprecated = true; - - attachMagicFunction(ttv->props["pack"].type, magicFunctionPack); - attachDcrMagicFunction(ttv->props["pack"].type, dcrMagicFunctionPack); - } - - attachMagicFunction(getGlobalBinding(globals, "require"), magicFunctionRequire); - attachDcrMagicFunction(getGlobalBinding(globals, "require"), dcrMagicFunctionRequire); -} - static std::optional> magicFunctionSelect( TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) { diff --git a/Analysis/src/Clone.cpp b/Analysis/src/Clone.cpp index 2645209d5..ac73622d3 100644 --- a/Analysis/src/Clone.cpp +++ b/Analysis/src/Clone.cpp @@ -44,10 +44,10 @@ struct TypeCloner template void defaultClone(const T& t); - void operator()(const Unifiable::Free& t); - void operator()(const Unifiable::Generic& t); - void operator()(const Unifiable::Bound& t); - void operator()(const Unifiable::Error& t); + void operator()(const FreeType& t); + void operator()(const GenericType& t); + void operator()(const BoundType& t); + void operator()(const ErrorType& t); void operator()(const BlockedType& t); void operator()(const PendingExpansionType& t); void operator()(const PrimitiveType& t); @@ -89,15 +89,15 @@ struct TypePackCloner seenTypePacks[typePackId] = cloned; } - void operator()(const Unifiable::Free& t) + void operator()(const FreeTypePack& t) { defaultClone(t); } - void operator()(const Unifiable::Generic& t) + void operator()(const GenericTypePack& t) { defaultClone(t); } - void operator()(const Unifiable::Error& t) + void operator()(const ErrorTypePack& t) { defaultClone(t); } @@ -145,12 +145,12 @@ void TypeCloner::defaultClone(const T& t) seenTypes[typeId] = cloned; } -void TypeCloner::operator()(const Unifiable::Free& t) +void TypeCloner::operator()(const FreeType& t) { defaultClone(t); } -void TypeCloner::operator()(const Unifiable::Generic& t) +void TypeCloner::operator()(const GenericType& t) { defaultClone(t); } diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 191e94f4d..98022d862 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -29,7 +29,6 @@ LUAU_FASTINT(LuauTypeInferRecursionLimit) LUAU_FASTINT(LuauTarjanChildLimit) LUAU_FASTFLAG(LuauInferInNoCheckMode) LUAU_FASTFLAGVARIABLE(LuauKnowsTheDataModel3, false) -LUAU_FASTFLAGVARIABLE(LuauLintInTypecheck, false) LUAU_FASTINTVARIABLE(LuauAutocompleteCheckTimeoutMs, 100) LUAU_FASTFLAGVARIABLE(DebugLuauDeferredConstraintResolution, false) LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJson, false); @@ -84,32 +83,20 @@ static void generateDocumentationSymbols(TypeId ty, const std::string& rootName) } } -LoadDefinitionFileResult Frontend::loadDefinitionFile(std::string_view source, const std::string& packageName, bool captureComments) +static ParseResult parseSourceForModule(std::string_view source, Luau::SourceModule& sourceModule, bool captureComments) { - if (!FFlag::DebugLuauDeferredConstraintResolution) - return Luau::loadDefinitionFile(typeChecker, globals, globals.globalScope, source, packageName, captureComments); - - LUAU_TIMETRACE_SCOPE("loadDefinitionFile", "Frontend"); - - Luau::SourceModule sourceModule; - ParseOptions options; options.allowDeclarationSyntax = true; options.captureComments = captureComments; Luau::ParseResult parseResult = Luau::Parser::parse(source.data(), source.size(), *sourceModule.names, *sourceModule.allocator, options); - - if (parseResult.errors.size() > 0) - return LoadDefinitionFileResult{false, parseResult, sourceModule, nullptr}; - sourceModule.root = parseResult.root; sourceModule.mode = Mode::Definition; + return parseResult; +} - ModulePtr checkedModule = check(sourceModule, Mode::Definition, {}); - - if (checkedModule->errors.size() > 0) - return LoadDefinitionFileResult{false, parseResult, sourceModule, checkedModule}; - +static void persistCheckedTypes(ModulePtr checkedModule, GlobalTypes& globals, ScopePtr targetScope, const std::string& packageName) +{ CloneState cloneState; std::vector typesToPersist; @@ -120,7 +107,7 @@ LoadDefinitionFileResult Frontend::loadDefinitionFile(std::string_view source, c TypeId globalTy = clone(ty, globals.globalTypes, cloneState); std::string documentationSymbol = packageName + "/global/" + name; generateDocumentationSymbols(globalTy, documentationSymbol); - globals.globalScope->bindings[globals.globalNames.names->getOrAdd(name.c_str())] = {globalTy, Location(), false, {}, documentationSymbol}; + targetScope->bindings[globals.globalNames.names->getOrAdd(name.c_str())] = {globalTy, Location(), false, {}, documentationSymbol}; typesToPersist.push_back(globalTy); } @@ -130,7 +117,7 @@ LoadDefinitionFileResult Frontend::loadDefinitionFile(std::string_view source, c TypeFun globalTy = clone(ty, globals.globalTypes, cloneState); std::string documentationSymbol = packageName + "/globaltype/" + name; generateDocumentationSymbols(globalTy.type, documentationSymbol); - globals.globalScope->exportedTypeBindings[name] = globalTy; + targetScope->exportedTypeBindings[name] = globalTy; typesToPersist.push_back(globalTy.type); } @@ -139,63 +126,49 @@ LoadDefinitionFileResult Frontend::loadDefinitionFile(std::string_view source, c { persist(ty); } - - return LoadDefinitionFileResult{true, parseResult, sourceModule, checkedModule}; } -LoadDefinitionFileResult loadDefinitionFile(TypeChecker& typeChecker, GlobalTypes& globals, ScopePtr targetScope, std::string_view source, - const std::string& packageName, bool captureComments) +LoadDefinitionFileResult Frontend::loadDefinitionFile(GlobalTypes& globals, ScopePtr targetScope, std::string_view source, + const std::string& packageName, bool captureComments, bool typeCheckForAutocomplete) { + if (!FFlag::DebugLuauDeferredConstraintResolution) + return Luau::loadDefinitionFileNoDCR(typeCheckForAutocomplete ? typeCheckerForAutocomplete : typeChecker, + typeCheckForAutocomplete ? globalsForAutocomplete : globals, targetScope, source, packageName, captureComments); + LUAU_TIMETRACE_SCOPE("loadDefinitionFile", "Frontend"); Luau::SourceModule sourceModule; - - ParseOptions options; - options.allowDeclarationSyntax = true; - options.captureComments = captureComments; - - Luau::ParseResult parseResult = Luau::Parser::parse(source.data(), source.size(), *sourceModule.names, *sourceModule.allocator, options); - + Luau::ParseResult parseResult = parseSourceForModule(source, sourceModule, captureComments); if (parseResult.errors.size() > 0) return LoadDefinitionFileResult{false, parseResult, sourceModule, nullptr}; - sourceModule.root = parseResult.root; - sourceModule.mode = Mode::Definition; - - ModulePtr checkedModule = typeChecker.check(sourceModule, Mode::Definition); + ModulePtr checkedModule = check(sourceModule, Mode::Definition, {}); if (checkedModule->errors.size() > 0) return LoadDefinitionFileResult{false, parseResult, sourceModule, checkedModule}; - CloneState cloneState; + persistCheckedTypes(checkedModule, globals, targetScope, packageName); - std::vector typesToPersist; - typesToPersist.reserve(checkedModule->declaredGlobals.size() + checkedModule->exportedTypeBindings.size()); + return LoadDefinitionFileResult{true, parseResult, sourceModule, checkedModule}; +} - for (const auto& [name, ty] : checkedModule->declaredGlobals) - { - TypeId globalTy = clone(ty, globals.globalTypes, cloneState); - std::string documentationSymbol = packageName + "/global/" + name; - generateDocumentationSymbols(globalTy, documentationSymbol); - targetScope->bindings[globals.globalNames.names->getOrAdd(name.c_str())] = {globalTy, Location(), false, {}, documentationSymbol}; +LoadDefinitionFileResult loadDefinitionFileNoDCR(TypeChecker& typeChecker, GlobalTypes& globals, ScopePtr targetScope, std::string_view source, + const std::string& packageName, bool captureComments) +{ + LUAU_TIMETRACE_SCOPE("loadDefinitionFile", "Frontend"); - typesToPersist.push_back(globalTy); - } + Luau::SourceModule sourceModule; + Luau::ParseResult parseResult = parseSourceForModule(source, sourceModule, captureComments); - for (const auto& [name, ty] : checkedModule->exportedTypeBindings) - { - TypeFun globalTy = clone(ty, globals.globalTypes, cloneState); - std::string documentationSymbol = packageName + "/globaltype/" + name; - generateDocumentationSymbols(globalTy.type, documentationSymbol); - targetScope->exportedTypeBindings[name] = globalTy; + if (parseResult.errors.size() > 0) + return LoadDefinitionFileResult{false, parseResult, sourceModule, nullptr}; - typesToPersist.push_back(globalTy.type); - } + ModulePtr checkedModule = typeChecker.check(sourceModule, Mode::Definition); - for (TypeId ty : typesToPersist) - { - persist(ty); - } + if (checkedModule->errors.size() > 0) + return LoadDefinitionFileResult{false, parseResult, sourceModule, checkedModule}; + + persistCheckedTypes(checkedModule, globals, targetScope, packageName); return LoadDefinitionFileResult{true, parseResult, sourceModule, checkedModule}; } @@ -316,8 +289,6 @@ static ErrorVec accumulateErrors( static void filterLintOptions(LintOptions& lintOptions, const std::vector& hotcomments, Mode mode) { - LUAU_ASSERT(FFlag::LuauLintInTypecheck); - uint64_t ignoreLints = LintWarning::parseMask(hotcomments); lintOptions.warningMask &= ~ignoreLints; @@ -472,24 +443,16 @@ CheckResult Frontend::check(const ModuleName& name, std::optional& modules = - frontendOptions.forAutocomplete ? moduleResolverForAutocomplete.modules : moduleResolver.modules; + std::unordered_map& modules = + frontendOptions.forAutocomplete ? moduleResolverForAutocomplete.modules : moduleResolver.modules; - checkResult.errors = accumulateErrors(sourceNodes, modules, name); + checkResult.errors = accumulateErrors(sourceNodes, modules, name); - // Get lint result only for top checked module - if (auto it = modules.find(name); it != modules.end()) - checkResult.lintResult = it->second->lintResult; + // Get lint result only for top checked module + if (auto it = modules.find(name); it != modules.end()) + checkResult.lintResult = it->second->lintResult; - return checkResult; - } - else - { - return CheckResult{accumulateErrors( - sourceNodes, frontendOptions.forAutocomplete ? moduleResolverForAutocomplete.modules : moduleResolver.modules, name)}; - } + return checkResult; } std::vector buildQueue; @@ -553,9 +516,10 @@ CheckResult Frontend::check(const ModuleName& name, std::optional& modules = - frontendOptions.forAutocomplete ? moduleResolverForAutocomplete.modules : moduleResolver.modules; + // Get lint result only for top checked module + std::unordered_map& modules = + frontendOptions.forAutocomplete ? moduleResolverForAutocomplete.modules : moduleResolver.modules; - if (auto it = modules.find(name); it != modules.end()) - checkResult.lintResult = it->second->lintResult; - } + if (auto it = modules.find(name); it != modules.end()) + checkResult.lintResult = it->second->lintResult; return checkResult; } @@ -800,59 +759,6 @@ ScopePtr Frontend::getModuleEnvironment(const SourceModule& module, const Config return result; } -LintResult Frontend::lint_DEPRECATED(const ModuleName& name, std::optional enabledLintWarnings) -{ - LUAU_ASSERT(!FFlag::LuauLintInTypecheck); - - LUAU_TIMETRACE_SCOPE("Frontend::lint", "Frontend"); - LUAU_TIMETRACE_ARGUMENT("name", name.c_str()); - - auto [_sourceNode, sourceModule] = getSourceNode(name); - - if (!sourceModule) - return LintResult{}; // FIXME: We really should do something a bit more obvious when a file is too broken to lint. - - return lint_DEPRECATED(*sourceModule, enabledLintWarnings); -} - -LintResult Frontend::lint_DEPRECATED(const SourceModule& module, std::optional enabledLintWarnings) -{ - LUAU_ASSERT(!FFlag::LuauLintInTypecheck); - - LUAU_TIMETRACE_SCOPE("Frontend::lint", "Frontend"); - LUAU_TIMETRACE_ARGUMENT("module", module.name.c_str()); - - const Config& config = configResolver->getConfig(module.name); - - uint64_t ignoreLints = LintWarning::parseMask(module.hotcomments); - - LintOptions options = enabledLintWarnings.value_or(config.enabledLint); - options.warningMask &= ~ignoreLints; - - Mode mode = module.mode.value_or(config.mode); - if (mode != Mode::NoCheck) - { - options.disableWarning(Luau::LintWarning::Code_UnknownGlobal); - } - - if (mode == Mode::Strict) - { - options.disableWarning(Luau::LintWarning::Code_ImplicitReturn); - } - - ScopePtr environmentScope = getModuleEnvironment(module, config, /*forAutocomplete*/ false); - - ModulePtr modulePtr = moduleResolver.getModule(module.name); - - double timestamp = getTimestamp(); - - std::vector warnings = Luau::lint(module.root, *module.names, environmentScope, modulePtr.get(), module.hotcomments, options); - - stats.timeLint += getTimestamp() - timestamp; - - return classifyLints(warnings, config); -} - bool Frontend::isDirty(const ModuleName& name, bool forAutocomplete) const { auto it = sourceNodes.find(name); @@ -1195,7 +1101,7 @@ ScopePtr Frontend::getEnvironmentScope(const std::string& environmentName) const return {}; } -void Frontend::registerBuiltinDefinition(const std::string& name, std::function applicator) +void Frontend::registerBuiltinDefinition(const std::string& name, std::function applicator) { LUAU_ASSERT(builtinDefinitions.count(name) == 0); @@ -1208,7 +1114,7 @@ void Frontend::applyBuiltinDefinitionToEnvironment(const std::string& environmen LUAU_ASSERT(builtinDefinitions.count(definitionName) > 0); if (builtinDefinitions.count(definitionName) > 0) - builtinDefinitions[definitionName](typeChecker, globals, getEnvironmentScope(environmentName)); + builtinDefinitions[definitionName](*this, globals, getEnvironmentScope(environmentName)); } LintResult Frontend::classifyLints(const std::vector& warnings, const Config& config) diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp index 7c56a4b8f..46595b702 100644 --- a/Analysis/src/Normalize.cpp +++ b/Analysis/src/Normalize.cpp @@ -20,6 +20,7 @@ LUAU_FASTINTVARIABLE(LuauNormalizeCacheLimit, 100000); LUAU_FASTFLAGVARIABLE(LuauNegatedClassTypes, false); LUAU_FASTFLAGVARIABLE(LuauNegatedTableTypes, false); LUAU_FASTFLAGVARIABLE(LuauNormalizeBlockedTypes, false); +LUAU_FASTFLAGVARIABLE(LuauNormalizeMetatableFixes, false); LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) LUAU_FASTFLAG(LuauUninhabitedSubAnything2) LUAU_FASTFLAG(LuauTransitiveSubtyping) @@ -2062,6 +2063,18 @@ std::optional Normalizer::intersectionOfTables(TypeId here, TypeId there else if (isPrim(there, PrimitiveType::Table)) return here; + if (FFlag::LuauNormalizeMetatableFixes) + { + if (get(here)) + return there; + else if (get(there)) + return here; + else if (get(here)) + return there; + else if (get(there)) + return here; + } + TypeId htable = here; TypeId hmtable = nullptr; if (const MetatableType* hmtv = get(here)) @@ -2078,9 +2091,23 @@ std::optional Normalizer::intersectionOfTables(TypeId here, TypeId there } const TableType* httv = get(htable); - LUAU_ASSERT(httv); + if (FFlag::LuauNormalizeMetatableFixes) + { + if (!httv) + return std::nullopt; + } + else + LUAU_ASSERT(httv); + const TableType* tttv = get(ttable); - LUAU_ASSERT(tttv); + if (FFlag::LuauNormalizeMetatableFixes) + { + if (!tttv) + return std::nullopt; + } + else + LUAU_ASSERT(tttv); + if (httv->state == TableState::Free || tttv->state == TableState::Free) return std::nullopt; diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index 5c0f48fae..fe09ef11a 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -14,7 +14,6 @@ #include LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) -LUAU_FASTFLAGVARIABLE(LuauFunctionReturnStringificationFixup, false) /* * Prefix generic typenames with gen- @@ -369,7 +368,7 @@ struct TypeStringifier state.emit(">"); } - void operator()(TypeId ty, const Unifiable::Free& ftv) + void operator()(TypeId ty, const FreeType& ftv) { state.result.invalid = true; if (FFlag::DebugLuauVerboseTypeNames) diff --git a/Analysis/src/Type.cpp b/Analysis/src/Type.cpp index 021d95285..d70f17f57 100644 --- a/Analysis/src/Type.cpp +++ b/Analysis/src/Type.cpp @@ -430,6 +430,69 @@ bool hasLength(TypeId ty, DenseHashSet& seen, int* recursionCount) return false; } +FreeType::FreeType(TypeLevel level) + : index(Unifiable::freshIndex()) + , level(level) + , scope(nullptr) +{ +} + +FreeType::FreeType(Scope* scope) + : index(Unifiable::freshIndex()) + , level{} + , scope(scope) +{ +} + +FreeType::FreeType(Scope* scope, TypeLevel level) + : index(Unifiable::freshIndex()) + , level(level) + , scope(scope) +{ +} + +GenericType::GenericType() + : index(Unifiable::freshIndex()) + , name("g" + std::to_string(index)) +{ +} + +GenericType::GenericType(TypeLevel level) + : index(Unifiable::freshIndex()) + , level(level) + , name("g" + std::to_string(index)) +{ +} + +GenericType::GenericType(const Name& name) + : index(Unifiable::freshIndex()) + , name(name) + , explicitName(true) +{ +} + +GenericType::GenericType(Scope* scope) + : index(Unifiable::freshIndex()) + , scope(scope) +{ +} + +GenericType::GenericType(TypeLevel level, const Name& name) + : index(Unifiable::freshIndex()) + , level(level) + , name(name) + , explicitName(true) +{ +} + +GenericType::GenericType(Scope* scope, const Name& name) + : index(Unifiable::freshIndex()) + , scope(scope) + , name(name) + , explicitName(true) +{ +} + BlockedType::BlockedType() : index(FFlag::LuauNormalizeBlockedTypes ? Unifiable::freshIndex() : ++DEPRECATED_nextIndex) { @@ -971,7 +1034,7 @@ const TypeLevel* getLevel(TypeId ty) { ty = follow(ty); - if (auto ftv = get(ty)) + if (auto ftv = get(ty)) return &ftv->level; else if (auto ttv = get(ty)) return &ttv->level; @@ -990,7 +1053,7 @@ std::optional getLevel(TypePackId tp) { tp = follow(tp); - if (auto ftv = get(tp)) + if (auto ftv = get(tp)) return ftv->level; else return std::nullopt; diff --git a/Analysis/src/TypeAttach.cpp b/Analysis/src/TypeAttach.cpp index f9a162056..d6494edfd 100644 --- a/Analysis/src/TypeAttach.cpp +++ b/Analysis/src/TypeAttach.cpp @@ -35,7 +35,21 @@ using SyntheticNames = std::unordered_map; namespace Luau { -static const char* getName(Allocator* allocator, SyntheticNames* syntheticNames, const Unifiable::Generic& gen) +static const char* getName(Allocator* allocator, SyntheticNames* syntheticNames, const GenericType& gen) +{ + size_t s = syntheticNames->size(); + char*& n = (*syntheticNames)[&gen]; + if (!n) + { + std::string str = gen.explicitName ? gen.name : generateName(s); + n = static_cast(allocator->allocate(str.size() + 1)); + strcpy(n, str.c_str()); + } + + return n; +} + +static const char* getName(Allocator* allocator, SyntheticNames* syntheticNames, const GenericTypePack& gen) { size_t s = syntheticNames->size(); char*& n = (*syntheticNames)[&gen]; @@ -237,7 +251,7 @@ class TypeRehydrationVisitor size_t numGenericPacks = 0; for (auto it = ftv.genericPacks.begin(); it != ftv.genericPacks.end(); ++it) { - if (auto gtv = get(*it)) + if (auto gtv = get(*it)) genericPacks.data[numGenericPacks++] = {AstName(gtv->name.c_str()), Location(), nullptr}; } diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index f47815588..acf70fec1 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -1020,7 +1020,7 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatAssign& assig right = errorRecoveryType(scope); else if (auto vtp = get(tailPack)) right = vtp->ty; - else if (get(tailPack)) + else if (get(tailPack)) { *asMutable(tailPack) = TypePack{{left}}; growingPack = getMutable(tailPack); @@ -1281,7 +1281,7 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) callRetPack = checkExprPack(scope, *exprCall).type; callRetPack = follow(callRetPack); - if (get(callRetPack)) + if (get(callRetPack)) { iterTy = freshType(scope); unify(callRetPack, addTypePack({{iterTy}, freshTypePack(scope)}), scope, forin.location); @@ -1951,7 +1951,7 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp return WithPredicate{errorRecoveryType(scope)}; else if (auto vtp = get(varargPack)) return WithPredicate{vtp->ty}; - else if (get(varargPack)) + else if (get(varargPack)) { // TODO: Better error? reportError(expr.location, GenericError{"Trying to get a type from a variadic type parameter"}); @@ -1970,7 +1970,7 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp { return {pack->head.empty() ? nilType : pack->head[0], std::move(result.predicates)}; } - else if (const FreeTypePack* ftp = get(retPack)) + else if (const FreeTypePack* ftp = get(retPack)) { TypeId head = freshType(scope->level); TypePackId pack = addTypePack(TypePackVar{TypePack{{head}, freshTypePack(scope->level)}}); @@ -1981,7 +1981,7 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp return {errorRecoveryType(scope), std::move(result.predicates)}; else if (auto vtp = get(retPack)) return {vtp->ty, std::move(result.predicates)}; - else if (get(retPack)) + else if (get(retPack)) { if (FFlag::LuauReturnAnyInsteadOfICE) return {anyType, std::move(result.predicates)}; @@ -3838,7 +3838,7 @@ void TypeChecker::checkArgumentList(const ScopePtr& scope, const AstExpr& funNam if (argTail) { - if (state.log.getMutable(state.log.follow(*argTail))) + if (state.log.getMutable(state.log.follow(*argTail))) { if (paramTail) state.tryUnify(*paramTail, *argTail); @@ -3853,7 +3853,7 @@ void TypeChecker::checkArgumentList(const ScopePtr& scope, const AstExpr& funNam else if (paramTail) { // argTail is definitely empty - if (state.log.getMutable(state.log.follow(*paramTail))) + if (state.log.getMutable(state.log.follow(*paramTail))) state.log.replace(*paramTail, TypePackVar(TypePack{{}})); } @@ -5570,7 +5570,7 @@ GenericTypeDefinitions TypeChecker::createGenericTypes(const ScopePtr& scope, st } else { - g = addType(Unifiable::Generic{level, n}); + g = addType(GenericType{level, n}); } generics.push_back({g, defaultValue}); @@ -5598,7 +5598,7 @@ GenericTypeDefinitions TypeChecker::createGenericTypes(const ScopePtr& scope, st TypePackId& cached = scope->parent->typeAliasTypePackParameters[n]; if (!cached) - cached = addTypePack(TypePackVar{Unifiable::Generic{level, n}}); + cached = addTypePack(TypePackVar{GenericTypePack{level, n}}); genericPacks.push_back({cached, defaultValue}); scope->privateTypePackBindings[n] = cached; diff --git a/Analysis/src/TypePack.cpp b/Analysis/src/TypePack.cpp index ccea604ff..6873820a7 100644 --- a/Analysis/src/TypePack.cpp +++ b/Analysis/src/TypePack.cpp @@ -9,6 +9,69 @@ namespace Luau { +FreeTypePack::FreeTypePack(TypeLevel level) + : index(Unifiable::freshIndex()) + , level(level) + , scope(nullptr) +{ +} + +FreeTypePack::FreeTypePack(Scope* scope) + : index(Unifiable::freshIndex()) + , level{} + , scope(scope) +{ +} + +FreeTypePack::FreeTypePack(Scope* scope, TypeLevel level) + : index(Unifiable::freshIndex()) + , level(level) + , scope(scope) +{ +} + +GenericTypePack::GenericTypePack() + : index(Unifiable::freshIndex()) + , name("g" + std::to_string(index)) +{ +} + +GenericTypePack::GenericTypePack(TypeLevel level) + : index(Unifiable::freshIndex()) + , level(level) + , name("g" + std::to_string(index)) +{ +} + +GenericTypePack::GenericTypePack(const Name& name) + : index(Unifiable::freshIndex()) + , name(name) + , explicitName(true) +{ +} + +GenericTypePack::GenericTypePack(Scope* scope) + : index(Unifiable::freshIndex()) + , scope(scope) +{ +} + +GenericTypePack::GenericTypePack(TypeLevel level, const Name& name) + : index(Unifiable::freshIndex()) + , level(level) + , name(name) + , explicitName(true) +{ +} + +GenericTypePack::GenericTypePack(Scope* scope, const Name& name) + : index(Unifiable::freshIndex()) + , scope(scope) + , name(name) + , explicitName(true) +{ +} + BlockedTypePack::BlockedTypePack() : index(++nextIndex) { @@ -160,8 +223,8 @@ bool areEqual(SeenSet& seen, const TypePackVar& lhs, const TypePackVar& rhs) TypePackId rhsTail = *rhsIter.tail(); { - const Unifiable::Free* lf = get_if(&lhsTail->ty); - const Unifiable::Free* rf = get_if(&rhsTail->ty); + const FreeTypePack* lf = get_if(&lhsTail->ty); + const FreeTypePack* rf = get_if(&rhsTail->ty); if (lf && rf) return lf->index == rf->index; } @@ -174,8 +237,8 @@ bool areEqual(SeenSet& seen, const TypePackVar& lhs, const TypePackVar& rhs) } { - const Unifiable::Generic* lg = get_if(&lhsTail->ty); - const Unifiable::Generic* rg = get_if(&rhsTail->ty); + const GenericTypePack* lg = get_if(&lhsTail->ty); + const GenericTypePack* rg = get_if(&rhsTail->ty); if (lg && rg) return lg->index == rg->index; } diff --git a/Analysis/src/Unifiable.cpp b/Analysis/src/Unifiable.cpp index abdc6c329..2ceb97aae 100644 --- a/Analysis/src/Unifiable.cpp +++ b/Analysis/src/Unifiable.cpp @@ -13,71 +13,6 @@ int freshIndex() return ++nextIndex; } -Free::Free(TypeLevel level) - : index(++nextIndex) - , level(level) -{ -} - -Free::Free(Scope* scope) - : index(++nextIndex) - , scope(scope) -{ -} - -Free::Free(Scope* scope, TypeLevel level) - : index(++nextIndex) - , level(level) - , scope(scope) -{ -} - -int Free::DEPRECATED_nextIndex = 0; - -Generic::Generic() - : index(++nextIndex) - , name("g" + std::to_string(index)) -{ -} - -Generic::Generic(TypeLevel level) - : index(++nextIndex) - , level(level) - , name("g" + std::to_string(index)) -{ -} - -Generic::Generic(const Name& name) - : index(++nextIndex) - , name(name) - , explicitName(true) -{ -} - -Generic::Generic(Scope* scope) - : index(++nextIndex) - , scope(scope) -{ -} - -Generic::Generic(TypeLevel level, const Name& name) - : index(++nextIndex) - , level(level) - , name(name) - , explicitName(true) -{ -} - -Generic::Generic(Scope* scope, const Name& name) - : index(++nextIndex) - , scope(scope) - , name(name) - , explicitName(true) -{ -} - -int Generic::DEPRECATED_nextIndex = 0; - Error::Error() : index(++nextIndex) { diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index b748d115f..642aa399f 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -1489,7 +1489,7 @@ struct WeirdIter bool canGrow() const { - return nullptr != log.getMutable(packId); + return nullptr != log.getMutable(packId); } void grow(TypePackId newTail) @@ -1497,7 +1497,7 @@ struct WeirdIter LUAU_ASSERT(canGrow()); LUAU_ASSERT(log.getMutable(newTail)); - auto freePack = log.getMutable(packId); + auto freePack = log.getMutable(packId); level = freePack->level; if (FFlag::LuauMaintainScopesInUnifier && freePack->scope != nullptr) @@ -1591,7 +1591,7 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal if (log.haveSeen(superTp, subTp)) return; - if (log.getMutable(superTp)) + if (log.getMutable(superTp)) { if (!occursCheck(superTp, subTp)) { @@ -1599,7 +1599,7 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal log.replace(superTp, Unifiable::Bound(widen(subTp))); } } - else if (log.getMutable(subTp)) + else if (log.getMutable(subTp)) { if (!occursCheck(subTp, superTp)) { @@ -2567,9 +2567,9 @@ static void queueTypePack(std::vector& queue, DenseHashSet& break; seenTypePacks.insert(a); - if (state.log.getMutable(a)) + if (state.log.getMutable(a)) { - state.log.replace(a, Unifiable::Bound{anyTypePack}); + state.log.replace(a, BoundTypePack{anyTypePack}); } else if (auto tp = state.log.getMutable(a)) { @@ -2617,7 +2617,7 @@ void Unifier::tryUnifyVariadics(TypePackId subTp, TypePackId superTp, bool rever { tryUnify_(vtp->ty, superVariadic->ty); } - else if (get(tail)) + else if (get(tail)) { reportError(location, GenericError{"Cannot unify variadic and generic packs"}); } @@ -2777,10 +2777,10 @@ bool Unifier::occursCheck(DenseHashSet& seen, TypeId needle, TypeId hays seen.insert(haystack); - if (log.getMutable(needle)) + if (log.getMutable(needle)) return false; - if (!log.getMutable(needle)) + if (!log.getMutable(needle)) ice("Expected needle to be free"); if (needle == haystack) @@ -2824,10 +2824,10 @@ bool Unifier::occursCheck(DenseHashSet& seen, TypePackId needle, Typ seen.insert(haystack); - if (log.getMutable(needle)) + if (log.getMutable(needle)) return false; - if (!log.getMutable(needle)) + if (!log.getMutable(needle)) ice("Expected needle pack to be free"); RecursionLimiter _ra(&sharedState.counters.recursionCount, sharedState.counters.recursionLimit); diff --git a/CLI/Analyze.cpp b/CLI/Analyze.cpp index 4fdb04439..6d1f54514 100644 --- a/CLI/Analyze.cpp +++ b/CLI/Analyze.cpp @@ -14,7 +14,6 @@ #endif LUAU_FASTFLAG(DebugLuauTimeTracing) -LUAU_FASTFLAG(LuauLintInTypecheck) enum class ReportFormat { @@ -81,12 +80,10 @@ static bool analyzeFile(Luau::Frontend& frontend, const char* name, ReportFormat for (auto& error : cr.errors) reportError(frontend, format, error); - Luau::LintResult lr = FFlag::LuauLintInTypecheck ? cr.lintResult : frontend.lint_DEPRECATED(name); - std::string humanReadableName = frontend.fileResolver->getHumanReadableModuleName(name); - for (auto& error : lr.errors) + for (auto& error : cr.lintResult.errors) reportWarning(format, humanReadableName.c_str(), error); - for (auto& warning : lr.warnings) + for (auto& warning : cr.lintResult.warnings) reportWarning(format, humanReadableName.c_str(), warning); if (annotate) @@ -101,7 +98,7 @@ static bool analyzeFile(Luau::Frontend& frontend, const char* name, ReportFormat printf("%s", annotated.c_str()); } - return cr.errors.empty() && lr.errors.empty(); + return cr.errors.empty() && cr.lintResult.errors.empty(); } static void displayHelp(const char* argv0) @@ -264,13 +261,13 @@ int main(int argc, char** argv) Luau::FrontendOptions frontendOptions; frontendOptions.retainFullTypeGraphs = annotate; - frontendOptions.runLintChecks = FFlag::LuauLintInTypecheck; + frontendOptions.runLintChecks = true; CliFileResolver fileResolver; CliConfigResolver configResolver(mode); Luau::Frontend frontend(&fileResolver, &configResolver, frontendOptions); - Luau::registerBuiltinGlobals(frontend.typeChecker, frontend.globals); + Luau::registerBuiltinGlobals(frontend, frontend.globals); Luau::freeze(frontend.globals.globalTypes); #ifdef CALLGRIND diff --git a/CodeGen/include/Luau/AssemblyBuilderA64.h b/CodeGen/include/Luau/AssemblyBuilderA64.h index 0c7387128..def4d0c0c 100644 --- a/CodeGen/include/Luau/AssemblyBuilderA64.h +++ b/CodeGen/include/Luau/AssemblyBuilderA64.h @@ -37,6 +37,7 @@ class AssemblyBuilderA64 void movk(RegisterA64 dst, uint16_t src, int shift = 0); // Arithmetics + // TODO: support various kinds of shifts void add(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, int shift = 0); void add(RegisterA64 dst, RegisterA64 src1, uint16_t src2); void sub(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, int shift = 0); @@ -50,8 +51,10 @@ class AssemblyBuilderA64 void csel(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, ConditionA64 cond); // Bitwise - // Note: shifted-register support and bitfield operations are omitted for simplicity // TODO: support immediate arguments (they have odd encoding and forbid many values) + // TODO: support bic (andnot) + // TODO: support shifts + // TODO: support bitfield ops void and_(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); void orr(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); void eor(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); @@ -82,7 +85,7 @@ class AssemblyBuilderA64 void stp(RegisterA64 src1, RegisterA64 src2, AddressA64 dst); // Control flow - // Note: tbz/tbnz are currently not supported because they have 15-bit offsets and we don't support branch thunks + // TODO: support tbz/tbnz; they have 15-bit offsets but they can be useful in constrained cases void b(Label& label); void b(ConditionA64 cond, Label& label); void cbz(RegisterA64 src, Label& label); diff --git a/CodeGen/include/Luau/AssemblyBuilderX64.h b/CodeGen/include/Luau/AssemblyBuilderX64.h index 2b2a849c6..467be4664 100644 --- a/CodeGen/include/Luau/AssemblyBuilderX64.h +++ b/CodeGen/include/Luau/AssemblyBuilderX64.h @@ -121,6 +121,7 @@ class AssemblyBuilderX64 void vcvttsd2si(OperandX64 dst, OperandX64 src); void vcvtsi2sd(OperandX64 dst, OperandX64 src1, OperandX64 src2); + void vcvtsd2ss(OperandX64 dst, OperandX64 src1, OperandX64 src2); void vroundsd(OperandX64 dst, OperandX64 src1, OperandX64 src2, RoundingModeX64 roundingMode); // inexact diff --git a/CodeGen/include/Luau/IrAnalysis.h b/CodeGen/include/Luau/IrAnalysis.h index 470690b95..75b4940a6 100644 --- a/CodeGen/include/Luau/IrAnalysis.h +++ b/CodeGen/include/Luau/IrAnalysis.h @@ -19,6 +19,8 @@ void updateUseCounts(IrFunction& function); void updateLastUseLocations(IrFunction& function); +uint32_t getNextInstUse(IrFunction& function, uint32_t targetInstIdx, uint32_t startInstIdx); + // Returns how many values are coming into the block (live in) and how many are coming out of the block (live out) std::pair getLiveInOutValueCount(IrFunction& function, IrBlock& block); uint32_t getLiveInValueCount(IrFunction& function, IrBlock& block); diff --git a/CodeGen/include/Luau/IrCallWrapperX64.h b/CodeGen/include/Luau/IrCallWrapperX64.h index b70c8da62..724d46243 100644 --- a/CodeGen/include/Luau/IrCallWrapperX64.h +++ b/CodeGen/include/Luau/IrCallWrapperX64.h @@ -17,10 +17,6 @@ namespace CodeGen namespace X64 { -// When IrInst operands are used, current instruction index is required to track lifetime -// In all other calls it is ok to omit the argument -constexpr uint32_t kInvalidInstIdx = ~0u; - struct IrRegAllocX64; struct ScopedRegX64; @@ -61,6 +57,7 @@ class IrCallWrapperX64 void renameRegister(RegisterX64& target, RegisterX64 reg, RegisterX64 replacement); void renameSourceRegisters(RegisterX64 reg, RegisterX64 replacement); RegisterX64 findConflictingTarget() const; + void renameConflictingRegister(RegisterX64 conflict); int getRegisterUses(RegisterX64 reg) const; void addRegisterUse(RegisterX64 reg); diff --git a/CodeGen/include/Luau/IrData.h b/CodeGen/include/Luau/IrData.h index 752160817..fcf29adb1 100644 --- a/CodeGen/include/Luau/IrData.h +++ b/CodeGen/include/Luau/IrData.h @@ -62,11 +62,12 @@ enum class IrCmd : uint8_t // Get pointer (LuaNode) to table node element at the active cached slot index // A: pointer (Table) + // B: unsigned int (pcpos) GET_SLOT_NODE_ADDR, // Get pointer (LuaNode) to table node element at the main position of the specified key hash // A: pointer (Table) - // B: unsigned int + // B: unsigned int (hash) GET_HASH_NODE_ADDR, // Store a tag into TValue @@ -89,6 +90,13 @@ enum class IrCmd : uint8_t // B: int STORE_INT, + // Store a vector into TValue + // A: Rn + // B: double (x) + // C: double (y) + // D: double (z) + STORE_VECTOR, + // Store a TValue into memory // A: Rn or pointer (TValue) // B: TValue @@ -438,15 +446,6 @@ enum class IrCmd : uint8_t // C: block (forgloop location) FORGPREP_XNEXT_FALLBACK, - // Perform `and` or `or` operation (selecting lhs or rhs based on whether the lhs is truthy) and put the result into target register - // A: Rn (target) - // B: Rn (lhs) - // C: Rn or Kn (rhs) - AND, - ANDK, - OR, - ORK, - // Increment coverage data (saturating 24 bit add) // A: unsigned int (bytecode instruction index) COVERAGE, @@ -622,6 +621,17 @@ struct IrOp static_assert(sizeof(IrOp) == 4); +enum class IrValueKind : uint8_t +{ + Unknown, // Used by SUBSTITUTE, argument has to be checked to get type + None, + Tag, + Int, + Pointer, + Double, + Tvalue, +}; + struct IrInst { IrCmd cmd; @@ -641,8 +651,12 @@ struct IrInst X64::RegisterX64 regX64 = X64::noreg; A64::RegisterA64 regA64 = A64::noreg; bool reusedReg = false; + bool spilled = false; }; +// When IrInst operands are used, current instruction index is often required to track lifetime +constexpr uint32_t kInvalidInstIdx = ~0u; + enum class IrBlockKind : uint8_t { Bytecode, @@ -821,6 +835,13 @@ struct IrFunction LUAU_ASSERT(&block >= blocks.data() && &block <= blocks.data() + blocks.size()); return uint32_t(&block - blocks.data()); } + + uint32_t getInstIndex(const IrInst& inst) + { + // Can only be called with instructions from our vector + LUAU_ASSERT(&inst >= instructions.data() && &inst <= instructions.data() + instructions.size()); + return uint32_t(&inst - instructions.data()); + } }; inline IrCondition conditionOp(IrOp op) diff --git a/CodeGen/include/Luau/IrRegAllocX64.h b/CodeGen/include/Luau/IrRegAllocX64.h index c2486faf8..dc7b48c6b 100644 --- a/CodeGen/include/Luau/IrRegAllocX64.h +++ b/CodeGen/include/Luau/IrRegAllocX64.h @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include "Luau/AssemblyBuilderX64.h" #include "Luau/IrData.h" #include "Luau/RegisterX64.h" @@ -14,33 +15,66 @@ namespace CodeGen namespace X64 { +constexpr uint8_t kNoStackSlot = 0xff; + +struct IrSpillX64 +{ + uint32_t instIdx = 0; + bool useDoubleSlot = 0; + + // Spill location can be a stack location or be empty + // When it's empty, it means that instruction value can be rematerialized + uint8_t stackSlot = kNoStackSlot; + + RegisterX64 originalLoc = noreg; +}; + struct IrRegAllocX64 { - IrRegAllocX64(IrFunction& function); + IrRegAllocX64(AssemblyBuilderX64& build, IrFunction& function); - RegisterX64 allocGprReg(SizeX64 preferredSize); - RegisterX64 allocXmmReg(); + RegisterX64 allocGprReg(SizeX64 preferredSize, uint32_t instIdx); + RegisterX64 allocXmmReg(uint32_t instIdx); - RegisterX64 allocGprRegOrReuse(SizeX64 preferredSize, uint32_t index, std::initializer_list oprefs); - RegisterX64 allocXmmRegOrReuse(uint32_t index, std::initializer_list oprefs); + RegisterX64 allocGprRegOrReuse(SizeX64 preferredSize, uint32_t instIdx, std::initializer_list oprefs); + RegisterX64 allocXmmRegOrReuse(uint32_t instIdx, std::initializer_list oprefs); - RegisterX64 takeReg(RegisterX64 reg); + RegisterX64 takeReg(RegisterX64 reg, uint32_t instIdx); void freeReg(RegisterX64 reg); - void freeLastUseReg(IrInst& target, uint32_t index); - void freeLastUseRegs(const IrInst& inst, uint32_t index); + void freeLastUseReg(IrInst& target, uint32_t instIdx); + void freeLastUseRegs(const IrInst& inst, uint32_t instIdx); - bool isLastUseReg(const IrInst& target, uint32_t index) const; + bool isLastUseReg(const IrInst& target, uint32_t instIdx) const; bool shouldFreeGpr(RegisterX64 reg) const; + // Register used by instruction is about to be freed, have to find a way to restore value later + void preserve(IrInst& inst); + + void restore(IrInst& inst, bool intoOriginalLocation); + + void preserveAndFreeInstValues(); + + uint32_t findInstructionWithFurthestNextUse(const std::array& regInstUsers) const; + void assertFree(RegisterX64 reg) const; void assertAllFree() const; + void assertNoSpills() const; + AssemblyBuilderX64& build; IrFunction& function; + uint32_t currInstIdx = ~0u; + std::array freeGprMap; + std::array gprInstUsers; std::array freeXmmMap; + std::array xmmInstUsers; + + std::bitset<256> usedSpillSlots; + unsigned maxUsedSlot = 0; + std::vector spills; }; struct ScopedRegX64 @@ -62,6 +96,23 @@ struct ScopedRegX64 RegisterX64 reg; }; +// When IR instruction makes a call under a condition that's not reflected as a real branch in IR, +// spilled values have to be restored to their exact original locations, so that both after a call +// and after the skip, values are found in the same place +struct ScopedSpills +{ + explicit ScopedSpills(IrRegAllocX64& owner); + ~ScopedSpills(); + + ScopedSpills(const ScopedSpills&) = delete; + ScopedSpills& operator=(const ScopedSpills&) = delete; + + bool wasSpilledBefore(const IrSpillX64& spill) const; + + IrRegAllocX64& owner; + std::vector snapshot; +}; + } // namespace X64 } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/include/Luau/IrUtils.h b/CodeGen/include/Luau/IrUtils.h index 6e73e47a6..09c55c799 100644 --- a/CodeGen/include/Luau/IrUtils.h +++ b/CodeGen/include/Luau/IrUtils.h @@ -175,6 +175,8 @@ inline bool isPseudo(IrCmd cmd) return cmd == IrCmd::NOP || cmd == IrCmd::SUBSTITUTE; } +IrValueKind getCmdValueKind(IrCmd cmd); + bool isGCO(uint8_t tag); // Manually add or remove use of an operand diff --git a/CodeGen/include/Luau/RegisterA64.h b/CodeGen/include/Luau/RegisterA64.h index 242e8b793..99e62958d 100644 --- a/CodeGen/include/Luau/RegisterA64.h +++ b/CodeGen/include/Luau/RegisterA64.h @@ -37,6 +37,15 @@ struct RegisterA64 } }; +constexpr RegisterA64 castReg(KindA64 kind, RegisterA64 reg) +{ + LUAU_ASSERT(kind != reg.kind); + LUAU_ASSERT(kind != KindA64::none && reg.kind != KindA64::none); + LUAU_ASSERT((kind == KindA64::w || kind == KindA64::x) == (reg.kind == KindA64::w || reg.kind == KindA64::x)); + + return RegisterA64{kind, reg.index}; +} + constexpr RegisterA64 noreg{KindA64::none, 0}; constexpr RegisterA64 w0{KindA64::w, 0}; diff --git a/CodeGen/src/AssemblyBuilderA64.cpp b/CodeGen/src/AssemblyBuilderA64.cpp index e7f50b142..a80003e94 100644 --- a/CodeGen/src/AssemblyBuilderA64.cpp +++ b/CodeGen/src/AssemblyBuilderA64.cpp @@ -496,7 +496,7 @@ void AssemblyBuilderA64::fcmpz(RegisterA64 src) { LUAU_ASSERT(src.kind == KindA64::d); - placeFCMP("fcmp", src, {src.kind, 0}, 0b11110'01'1, 0b01); + placeFCMP("fcmp", src, RegisterA64{src.kind, 0}, 0b11110'01'1, 0b01); } void AssemblyBuilderA64::fcsel(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, ConditionA64 cond) diff --git a/CodeGen/src/AssemblyBuilderX64.cpp b/CodeGen/src/AssemblyBuilderX64.cpp index 0285c2a16..d86a37c6e 100644 --- a/CodeGen/src/AssemblyBuilderX64.cpp +++ b/CodeGen/src/AssemblyBuilderX64.cpp @@ -676,6 +676,16 @@ void AssemblyBuilderX64::vcvtsi2sd(OperandX64 dst, OperandX64 src1, OperandX64 s placeAvx("vcvtsi2sd", dst, src1, src2, 0x2a, (src2.cat == CategoryX64::reg ? src2.base.size : src2.memSize) == SizeX64::qword, AVX_0F, AVX_F2); } +void AssemblyBuilderX64::vcvtsd2ss(OperandX64 dst, OperandX64 src1, OperandX64 src2) +{ + if (src2.cat == CategoryX64::reg) + LUAU_ASSERT(src2.base.size == SizeX64::xmmword); + else + LUAU_ASSERT(src2.memSize == SizeX64::qword); + + placeAvx("vcvtsd2ss", dst, src1, src2, 0x5a, (src2.cat == CategoryX64::reg ? src2.base.size : src2.memSize) == SizeX64::qword, AVX_0F, AVX_F2); +} + void AssemblyBuilderX64::vroundsd(OperandX64 dst, OperandX64 src1, OperandX64 src2, RoundingModeX64 roundingMode) { placeAvx("vroundsd", dst, src1, src2, uint8_t(roundingMode) | kRoundingPrecisionInexact, 0x0b, false, AVX_0F3A, AVX_66); diff --git a/CodeGen/src/CodeGen.cpp b/CodeGen/src/CodeGen.cpp index b0cc8d9cd..8e6e94933 100644 --- a/CodeGen/src/CodeGen.cpp +++ b/CodeGen/src/CodeGen.cpp @@ -74,7 +74,7 @@ static NativeProto* createNativeProto(Proto* proto, const IrBuilder& ir) } template -static void lowerImpl(AssemblyBuilder& build, IrLowering& lowering, IrFunction& function, int bytecodeid, AssemblyOptions options) +static bool lowerImpl(AssemblyBuilder& build, IrLowering& lowering, IrFunction& function, int bytecodeid, AssemblyOptions options) { // While we will need a better block ordering in the future, right now we want to mostly preserve build order with fallbacks outlined std::vector sortedBlocks; @@ -193,6 +193,9 @@ static void lowerImpl(AssemblyBuilder& build, IrLowering& lowering, IrFunction& IrBlock& next = i + 1 < sortedBlocks.size() ? function.blocks[sortedBlocks[i + 1]] : dummy; lowering.lowerInst(inst, index, next); + + if (lowering.hasError()) + return false; } if (options.includeIr) @@ -213,6 +216,8 @@ static void lowerImpl(AssemblyBuilder& build, IrLowering& lowering, IrFunction& if (irLocation != ~0u) asmLocation = bcLocations[irLocation]; } + + return true; } [[maybe_unused]] static bool lowerIr( @@ -226,9 +231,7 @@ static void lowerImpl(AssemblyBuilder& build, IrLowering& lowering, IrFunction& X64::IrLoweringX64 lowering(build, helpers, data, ir.function); - lowerImpl(build, lowering, ir.function, proto->bytecodeid, options); - - return true; + return lowerImpl(build, lowering, ir.function, proto->bytecodeid, options); } [[maybe_unused]] static bool lowerIr( @@ -239,9 +242,7 @@ static void lowerImpl(AssemblyBuilder& build, IrLowering& lowering, IrFunction& A64::IrLoweringA64 lowering(build, helpers, data, proto, ir.function); - lowerImpl(build, lowering, ir.function, proto->bytecodeid, options); - - return true; + return lowerImpl(build, lowering, ir.function, proto->bytecodeid, options); } template diff --git a/CodeGen/src/EmitBuiltinsX64.cpp b/CodeGen/src/EmitBuiltinsX64.cpp index 2e745cbf2..b010ce627 100644 --- a/CodeGen/src/EmitBuiltinsX64.cpp +++ b/CodeGen/src/EmitBuiltinsX64.cpp @@ -11,8 +11,6 @@ #include "lstate.h" -// TODO: LBF_MATH_FREXP and LBF_MATH_MODF can work for 1 result case if second store is removed - namespace Luau { namespace CodeGen @@ -176,8 +174,11 @@ void emitBuiltinMathFrexp(IrRegAllocX64& regs, AssemblyBuilderX64& build, int np build.vmovsd(luauRegValue(ra), xmm0); - build.vcvtsi2sd(xmm0, xmm0, dword[sTemporarySlot + 0]); - build.vmovsd(luauRegValue(ra + 1), xmm0); + if (nresults > 1) + { + build.vcvtsi2sd(xmm0, xmm0, dword[sTemporarySlot + 0]); + build.vmovsd(luauRegValue(ra + 1), xmm0); + } } void emitBuiltinMathModf(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) @@ -190,7 +191,8 @@ void emitBuiltinMathModf(IrRegAllocX64& regs, AssemblyBuilderX64& build, int npa build.vmovsd(xmm1, qword[sTemporarySlot + 0]); build.vmovsd(luauRegValue(ra), xmm1); - build.vmovsd(luauRegValue(ra + 1), xmm0); + if (nresults > 1) + build.vmovsd(luauRegValue(ra + 1), xmm0); } void emitBuiltinMathSign(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) @@ -248,9 +250,9 @@ void emitBuiltin(IrRegAllocX64& regs, AssemblyBuilderX64& build, int bfid, int r OperandX64 argsOp = 0; if (args.kind == IrOpKind::VmReg) - argsOp = luauRegAddress(args.index); + argsOp = luauRegAddress(vmRegOp(args)); else if (args.kind == IrOpKind::VmConst) - argsOp = luauConstantAddress(args.index); + argsOp = luauConstantAddress(vmConstOp(args)); switch (bfid) { diff --git a/CodeGen/src/EmitCommonA64.cpp b/CodeGen/src/EmitCommonA64.cpp index 2b4bbaba1..1758e4fb1 100644 --- a/CodeGen/src/EmitCommonA64.cpp +++ b/CodeGen/src/EmitCommonA64.cpp @@ -101,6 +101,30 @@ void emitReentry(AssemblyBuilderA64& build, ModuleHelpers& helpers) build.br(x1); } +void emitFallback(AssemblyBuilderA64& build, int op, int pcpos) +{ + // fallback(L, instruction, base, k) + build.mov(x0, rState); + + // TODO: refactor into a common helper + if (pcpos * sizeof(Instruction) <= AssemblyBuilderA64::kMaxImmediate) + { + build.add(x1, rCode, uint16_t(pcpos * sizeof(Instruction))); + } + else + { + build.mov(x1, pcpos * sizeof(Instruction)); + build.add(x1, rCode, x1); + } + + build.mov(x2, rBase); + build.mov(x3, rConstants); + build.ldr(x4, mem(rNativeContext, offsetof(NativeContext, fallback) + op * sizeof(NativeFallback) + offsetof(NativeFallback, fallback))); + build.blr(x4); + + emitUpdateBase(build); +} + } // namespace A64 } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/EmitCommonA64.h b/CodeGen/src/EmitCommonA64.h index 5ca9c5586..2a65afa8f 100644 --- a/CodeGen/src/EmitCommonA64.h +++ b/CodeGen/src/EmitCommonA64.h @@ -46,6 +46,7 @@ void emitUpdateBase(AssemblyBuilderA64& build); void emitExit(AssemblyBuilderA64& build, bool continueInVm); void emitInterrupt(AssemblyBuilderA64& build); void emitReentry(AssemblyBuilderA64& build, ModuleHelpers& helpers); +void emitFallback(AssemblyBuilderA64& build, int op, int pcpos); } // namespace A64 } // namespace CodeGen diff --git a/CodeGen/src/EmitCommonX64.cpp b/CodeGen/src/EmitCommonX64.cpp index 7db4068d0..9136add85 100644 --- a/CodeGen/src/EmitCommonX64.cpp +++ b/CodeGen/src/EmitCommonX64.cpp @@ -196,33 +196,51 @@ void checkObjectBarrierConditions(AssemblyBuilderX64& build, RegisterX64 tmp, Re build.jcc(ConditionX64::Zero, skip); } -void callBarrierObject(IrRegAllocX64& regs, AssemblyBuilderX64& build, RegisterX64 object, IrOp objectOp, int ra, Label& skip) +void callBarrierObject(IrRegAllocX64& regs, AssemblyBuilderX64& build, RegisterX64 object, IrOp objectOp, int ra) { + Label skip; + ScopedRegX64 tmp{regs, SizeX64::qword}; checkObjectBarrierConditions(build, tmp.reg, object, ra, skip); - IrCallWrapperX64 callWrap(regs, build); - callWrap.addArgument(SizeX64::qword, rState); - callWrap.addArgument(SizeX64::qword, object, objectOp); - callWrap.addArgument(SizeX64::qword, tmp); - callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaC_barrierf)]); + { + ScopedSpills spillGuard(regs); + + IrCallWrapperX64 callWrap(regs, build); + callWrap.addArgument(SizeX64::qword, rState); + callWrap.addArgument(SizeX64::qword, object, objectOp); + callWrap.addArgument(SizeX64::qword, tmp); + callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaC_barrierf)]); + } + + build.setLabel(skip); } -void callBarrierTableFast(IrRegAllocX64& regs, AssemblyBuilderX64& build, RegisterX64 table, IrOp tableOp, Label& skip) +void callBarrierTableFast(IrRegAllocX64& regs, AssemblyBuilderX64& build, RegisterX64 table, IrOp tableOp) { + Label skip; + // isblack(obj2gco(t)) build.test(byte[table + offsetof(GCheader, marked)], bitmask(BLACKBIT)); build.jcc(ConditionX64::Zero, skip); - IrCallWrapperX64 callWrap(regs, build); - callWrap.addArgument(SizeX64::qword, rState); - callWrap.addArgument(SizeX64::qword, table, tableOp); - callWrap.addArgument(SizeX64::qword, addr[table + offsetof(Table, gclist)]); - callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaC_barrierback)]); + { + ScopedSpills spillGuard(regs); + + IrCallWrapperX64 callWrap(regs, build); + callWrap.addArgument(SizeX64::qword, rState); + callWrap.addArgument(SizeX64::qword, table, tableOp); + callWrap.addArgument(SizeX64::qword, addr[table + offsetof(Table, gclist)]); + callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaC_barrierback)]); + } + + build.setLabel(skip); } -void callCheckGc(IrRegAllocX64& regs, AssemblyBuilderX64& build, Label& skip) +void callStepGc(IrRegAllocX64& regs, AssemblyBuilderX64& build) { + Label skip; + { ScopedRegX64 tmp1{regs, SizeX64::qword}; ScopedRegX64 tmp2{regs, SizeX64::qword}; @@ -233,11 +251,17 @@ void callCheckGc(IrRegAllocX64& regs, AssemblyBuilderX64& build, Label& skip) build.jcc(ConditionX64::Below, skip); } - IrCallWrapperX64 callWrap(regs, build); - callWrap.addArgument(SizeX64::qword, rState); - callWrap.addArgument(SizeX64::dword, 1); - callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaC_step)]); - emitUpdateBase(build); + { + ScopedSpills spillGuard(regs); + + IrCallWrapperX64 callWrap(regs, build); + callWrap.addArgument(SizeX64::qword, rState); + callWrap.addArgument(SizeX64::dword, 1); + callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaC_step)]); + emitUpdateBase(build); + } + + build.setLabel(skip); } void emitExit(AssemblyBuilderX64& build, bool continueInVm) @@ -256,7 +280,7 @@ void emitUpdateBase(AssemblyBuilderX64& build) } // Note: only uses rax/rdx, the caller may use other registers -void emitSetSavedPc(AssemblyBuilderX64& build, int pcpos) +static void emitSetSavedPc(AssemblyBuilderX64& build, int pcpos) { build.mov(rdx, sCode); build.add(rdx, pcpos * sizeof(Instruction)); @@ -298,9 +322,6 @@ void emitInterrupt(AssemblyBuilderX64& build, int pcpos) void emitFallback(AssemblyBuilderX64& build, NativeState& data, int op, int pcpos) { - if (op == LOP_CAPTURE) - return; - NativeFallback& opinfo = data.context.fallback[op]; LUAU_ASSERT(opinfo.fallback); diff --git a/CodeGen/src/EmitCommonX64.h b/CodeGen/src/EmitCommonX64.h index 85045ad5b..6aac5a1ec 100644 --- a/CodeGen/src/EmitCommonX64.h +++ b/CodeGen/src/EmitCommonX64.h @@ -42,12 +42,14 @@ constexpr RegisterX64 rConstants = r12; // TValue* k // Native code is as stackless as the interpreter, so we can place some data on the stack once and have it accessible at any point // See CodeGenX64.cpp for layout -constexpr unsigned kStackSize = 32 + 16; // 4 home locations for registers, 16 bytes for additional function call arguments -constexpr unsigned kLocalsSize = 24; // 3 extra slots for our custom locals (also aligns the stack to 16 byte boundary) +constexpr unsigned kStackSize = 32 + 16; // 4 home locations for registers, 16 bytes for additional function call arguments +constexpr unsigned kSpillSlots = 4; // locations for register allocator to spill data into +constexpr unsigned kLocalsSize = 24 + 8 * kSpillSlots; // 3 extra slots for our custom locals (also aligns the stack to 16 byte boundary) constexpr OperandX64 sClosure = qword[rsp + kStackSize + 0]; // Closure* cl constexpr OperandX64 sCode = qword[rsp + kStackSize + 8]; // Instruction* code constexpr OperandX64 sTemporarySlot = addr[rsp + kStackSize + 16]; +constexpr OperandX64 sSpillArea = addr[rsp + kStackSize + 24]; // TODO: These should be replaced with a portable call function that checks the ABI at runtime and reorders moves accordingly to avoid conflicts #if defined(_WIN32) @@ -99,6 +101,11 @@ inline OperandX64 luauRegValueInt(int ri) return dword[rBase + ri * sizeof(TValue) + offsetof(TValue, value)]; } +inline OperandX64 luauRegValueVector(int ri, int index) +{ + return dword[rBase + ri * sizeof(TValue) + offsetof(TValue, value) + (sizeof(float) * index)]; +} + inline OperandX64 luauConstant(int ki) { return xmmword[rConstants + ki * sizeof(TValue)]; @@ -247,13 +254,12 @@ void callPrepareForN(IrRegAllocX64& regs, AssemblyBuilderX64& build, int limit, void callGetTable(IrRegAllocX64& regs, AssemblyBuilderX64& build, int rb, OperandX64 c, int ra); void callSetTable(IrRegAllocX64& regs, AssemblyBuilderX64& build, int rb, OperandX64 c, int ra); void checkObjectBarrierConditions(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 object, int ra, Label& skip); -void callBarrierObject(IrRegAllocX64& regs, AssemblyBuilderX64& build, RegisterX64 object, IrOp objectOp, int ra, Label& skip); -void callBarrierTableFast(IrRegAllocX64& regs, AssemblyBuilderX64& build, RegisterX64 table, IrOp tableOp, Label& skip); -void callCheckGc(IrRegAllocX64& regs, AssemblyBuilderX64& build, Label& skip); +void callBarrierObject(IrRegAllocX64& regs, AssemblyBuilderX64& build, RegisterX64 object, IrOp objectOp, int ra); +void callBarrierTableFast(IrRegAllocX64& regs, AssemblyBuilderX64& build, RegisterX64 table, IrOp tableOp); +void callStepGc(IrRegAllocX64& regs, AssemblyBuilderX64& build); void emitExit(AssemblyBuilderX64& build, bool continueInVm); void emitUpdateBase(AssemblyBuilderX64& build); -void emitSetSavedPc(AssemblyBuilderX64& build, int pcpos); // Note: only uses rax/rdx, the caller may use other registers void emitInterrupt(AssemblyBuilderX64& build, int pcpos); void emitFallback(AssemblyBuilderX64& build, NativeState& data, int op, int pcpos); diff --git a/CodeGen/src/EmitInstructionX64.cpp b/CodeGen/src/EmitInstructionX64.cpp index b645f9f7a..c0a64274a 100644 --- a/CodeGen/src/EmitInstructionX64.cpp +++ b/CodeGen/src/EmitInstructionX64.cpp @@ -316,7 +316,7 @@ void emitInstReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers, int ra, i build.jmp(qword[rdx + rax * 2]); } -void emitInstSetList(IrRegAllocX64& regs, AssemblyBuilderX64& build, Label& next, int ra, int rb, int count, uint32_t index) +void emitInstSetList(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int rb, int count, uint32_t index) { OperandX64 last = index + count - 1; @@ -347,7 +347,7 @@ void emitInstSetList(IrRegAllocX64& regs, AssemblyBuilderX64& build, Label& next Label skipResize; - RegisterX64 table = regs.takeReg(rax); + RegisterX64 table = regs.takeReg(rax, kInvalidInstIdx); build.mov(table, luauRegValue(ra)); @@ -412,7 +412,7 @@ void emitInstSetList(IrRegAllocX64& regs, AssemblyBuilderX64& build, Label& next build.setLabel(endLoop); } - callBarrierTableFast(regs, build, table, {}, next); + callBarrierTableFast(regs, build, table, {}); } void emitinstForGLoop(AssemblyBuilderX64& build, int ra, int aux, Label& loopRepeat, Label& loopExit) @@ -504,82 +504,6 @@ void emitInstForGPrepXnextFallback(AssemblyBuilderX64& build, int pcpos, int ra, build.jmp(target); } -static void emitInstAndX(AssemblyBuilderX64& build, int ra, int rb, OperandX64 c) -{ - Label target, fallthrough; - jumpIfFalsy(build, rb, target, fallthrough); - - build.setLabel(fallthrough); - - build.vmovups(xmm0, c); - build.vmovups(luauReg(ra), xmm0); - - if (ra == rb) - { - build.setLabel(target); - } - else - { - Label exit; - build.jmp(exit); - - build.setLabel(target); - - build.vmovups(xmm0, luauReg(rb)); - build.vmovups(luauReg(ra), xmm0); - - build.setLabel(exit); - } -} - -void emitInstAnd(AssemblyBuilderX64& build, int ra, int rb, int rc) -{ - emitInstAndX(build, ra, rb, luauReg(rc)); -} - -void emitInstAndK(AssemblyBuilderX64& build, int ra, int rb, int kc) -{ - emitInstAndX(build, ra, rb, luauConstant(kc)); -} - -static void emitInstOrX(AssemblyBuilderX64& build, int ra, int rb, OperandX64 c) -{ - Label target, fallthrough; - jumpIfTruthy(build, rb, target, fallthrough); - - build.setLabel(fallthrough); - - build.vmovups(xmm0, c); - build.vmovups(luauReg(ra), xmm0); - - if (ra == rb) - { - build.setLabel(target); - } - else - { - Label exit; - build.jmp(exit); - - build.setLabel(target); - - build.vmovups(xmm0, luauReg(rb)); - build.vmovups(luauReg(ra), xmm0); - - build.setLabel(exit); - } -} - -void emitInstOr(AssemblyBuilderX64& build, int ra, int rb, int rc) -{ - emitInstOrX(build, ra, rb, luauReg(rc)); -} - -void emitInstOrK(AssemblyBuilderX64& build, int ra, int rb, int kc) -{ - emitInstOrX(build, ra, rb, luauConstant(kc)); -} - void emitInstGetImportFallback(AssemblyBuilderX64& build, int ra, uint32_t aux) { build.mov(rax, sClosure); diff --git a/CodeGen/src/EmitInstructionX64.h b/CodeGen/src/EmitInstructionX64.h index cc1b86456..d58e13310 100644 --- a/CodeGen/src/EmitInstructionX64.h +++ b/CodeGen/src/EmitInstructionX64.h @@ -19,14 +19,10 @@ struct IrRegAllocX64; void emitInstCall(AssemblyBuilderX64& build, ModuleHelpers& helpers, int ra, int nparams, int nresults); void emitInstReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers, int ra, int actualResults); -void emitInstSetList(IrRegAllocX64& regs, AssemblyBuilderX64& build, Label& next, int ra, int rb, int count, uint32_t index); +void emitInstSetList(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int rb, int count, uint32_t index); void emitinstForGLoop(AssemblyBuilderX64& build, int ra, int aux, Label& loopRepeat, Label& loopExit); void emitinstForGLoopFallback(AssemblyBuilderX64& build, int ra, int aux, Label& loopRepeat); void emitInstForGPrepXnextFallback(AssemblyBuilderX64& build, int pcpos, int ra, Label& target); -void emitInstAnd(AssemblyBuilderX64& build, int ra, int rb, int rc); -void emitInstAndK(AssemblyBuilderX64& build, int ra, int rb, int kc); -void emitInstOr(AssemblyBuilderX64& build, int ra, int rb, int rc); -void emitInstOrK(AssemblyBuilderX64& build, int ra, int rb, int kc); void emitInstGetImportFallback(AssemblyBuilderX64& build, int ra, uint32_t aux); void emitInstCoverage(AssemblyBuilderX64& build, int pcpos); diff --git a/CodeGen/src/IrAnalysis.cpp b/CodeGen/src/IrAnalysis.cpp index b248b97d5..2246e5c5e 100644 --- a/CodeGen/src/IrAnalysis.cpp +++ b/CodeGen/src/IrAnalysis.cpp @@ -69,6 +69,9 @@ void updateLastUseLocations(IrFunction& function) instructions[op.index].lastUse = uint32_t(instIdx); }; + if (isPseudo(inst.cmd)) + continue; + checkOp(inst.a); checkOp(inst.b); checkOp(inst.c); @@ -78,6 +81,42 @@ void updateLastUseLocations(IrFunction& function) } } +uint32_t getNextInstUse(IrFunction& function, uint32_t targetInstIdx, uint32_t startInstIdx) +{ + LUAU_ASSERT(startInstIdx < function.instructions.size()); + IrInst& targetInst = function.instructions[targetInstIdx]; + + for (uint32_t i = startInstIdx; i <= targetInst.lastUse; i++) + { + IrInst& inst = function.instructions[i]; + + if (isPseudo(inst.cmd)) + continue; + + if (inst.a.kind == IrOpKind::Inst && inst.a.index == targetInstIdx) + return i; + + if (inst.b.kind == IrOpKind::Inst && inst.b.index == targetInstIdx) + return i; + + if (inst.c.kind == IrOpKind::Inst && inst.c.index == targetInstIdx) + return i; + + if (inst.d.kind == IrOpKind::Inst && inst.d.index == targetInstIdx) + return i; + + if (inst.e.kind == IrOpKind::Inst && inst.e.index == targetInstIdx) + return i; + + if (inst.f.kind == IrOpKind::Inst && inst.f.index == targetInstIdx) + return i; + } + + // There must be a next use since there is the last use location + LUAU_ASSERT(!"failed to find next use"); + return targetInst.lastUse; +} + std::pair getLiveInOutValueCount(IrFunction& function, IrBlock& block) { uint32_t liveIns = 0; @@ -97,6 +136,9 @@ std::pair getLiveInOutValueCount(IrFunction& function, IrBlo { IrInst& inst = function.instructions[instIdx]; + if (isPseudo(inst.cmd)) + continue; + liveOuts += inst.useCount; checkOp(inst.a); @@ -149,26 +191,24 @@ static RegisterSet computeBlockLiveInRegSet(IrFunction& function, const IrBlock& RegisterSet inRs; auto def = [&](IrOp op, int offset = 0) { - LUAU_ASSERT(op.kind == IrOpKind::VmReg); - defRs.regs.set(op.index + offset, true); + defRs.regs.set(vmRegOp(op) + offset, true); }; auto use = [&](IrOp op, int offset = 0) { - LUAU_ASSERT(op.kind == IrOpKind::VmReg); - if (!defRs.regs.test(op.index + offset)) - inRs.regs.set(op.index + offset, true); + if (!defRs.regs.test(vmRegOp(op) + offset)) + inRs.regs.set(vmRegOp(op) + offset, true); }; auto maybeDef = [&](IrOp op) { if (op.kind == IrOpKind::VmReg) - defRs.regs.set(op.index, true); + defRs.regs.set(vmRegOp(op), true); }; auto maybeUse = [&](IrOp op) { if (op.kind == IrOpKind::VmReg) { - if (!defRs.regs.test(op.index)) - inRs.regs.set(op.index, true); + if (!defRs.regs.test(vmRegOp(op))) + inRs.regs.set(vmRegOp(op), true); } }; @@ -230,6 +270,7 @@ static RegisterSet computeBlockLiveInRegSet(IrFunction& function, const IrBlock& case IrCmd::STORE_POINTER: case IrCmd::STORE_DOUBLE: case IrCmd::STORE_INT: + case IrCmd::STORE_VECTOR: case IrCmd::STORE_TVALUE: maybeDef(inst.a); // Argument can also be a pointer value break; @@ -264,9 +305,9 @@ static RegisterSet computeBlockLiveInRegSet(IrFunction& function, const IrBlock& def(inst.a); break; case IrCmd::CONCAT: - useRange(inst.a.index, function.uintOp(inst.b)); + useRange(vmRegOp(inst.a), function.uintOp(inst.b)); - defRange(inst.a.index, function.uintOp(inst.b)); + defRange(vmRegOp(inst.a), function.uintOp(inst.b)); break; case IrCmd::GET_UPVALUE: def(inst.a); @@ -298,20 +339,20 @@ static RegisterSet computeBlockLiveInRegSet(IrFunction& function, const IrBlock& maybeUse(inst.a); if (function.boolOp(inst.b)) - capturedRegs.set(inst.a.index, true); + capturedRegs.set(vmRegOp(inst.a), true); break; case IrCmd::SETLIST: use(inst.b); - useRange(inst.c.index, function.intOp(inst.d)); + useRange(vmRegOp(inst.c), function.intOp(inst.d)); break; case IrCmd::CALL: use(inst.a); - useRange(inst.a.index + 1, function.intOp(inst.b)); + useRange(vmRegOp(inst.a) + 1, function.intOp(inst.b)); - defRange(inst.a.index, function.intOp(inst.c)); + defRange(vmRegOp(inst.a), function.intOp(inst.c)); break; case IrCmd::RETURN: - useRange(inst.a.index, function.intOp(inst.b)); + useRange(vmRegOp(inst.a), function.intOp(inst.b)); break; case IrCmd::FASTCALL: case IrCmd::INVOKE_FASTCALL: @@ -319,9 +360,9 @@ static RegisterSet computeBlockLiveInRegSet(IrFunction& function, const IrBlock& { if (count >= 3) { - LUAU_ASSERT(inst.d.kind == IrOpKind::VmReg && inst.d.index == inst.c.index + 1); + LUAU_ASSERT(inst.d.kind == IrOpKind::VmReg && vmRegOp(inst.d) == vmRegOp(inst.c) + 1); - useRange(inst.c.index, count); + useRange(vmRegOp(inst.c), count); } else { @@ -334,12 +375,12 @@ static RegisterSet computeBlockLiveInRegSet(IrFunction& function, const IrBlock& } else { - useVarargs(inst.c.index); + useVarargs(vmRegOp(inst.c)); } // Multiple return sequences (count == -1) are defined by ADJUST_STACK_TO_REG if (int count = function.intOp(inst.f); count != -1) - defRange(inst.b.index, count); + defRange(vmRegOp(inst.b), count); break; case IrCmd::FORGLOOP: // First register is not used by instruction, we check that it's still 'nil' with CHECK_TAG @@ -347,32 +388,17 @@ static RegisterSet computeBlockLiveInRegSet(IrFunction& function, const IrBlock& use(inst.a, 2); def(inst.a, 2); - defRange(inst.a.index + 3, function.intOp(inst.b)); + defRange(vmRegOp(inst.a) + 3, function.intOp(inst.b)); break; case IrCmd::FORGLOOP_FALLBACK: - useRange(inst.a.index, 3); + useRange(vmRegOp(inst.a), 3); def(inst.a, 2); - defRange(inst.a.index + 3, uint8_t(function.intOp(inst.b))); // ignore most significant bit + defRange(vmRegOp(inst.a) + 3, uint8_t(function.intOp(inst.b))); // ignore most significant bit break; case IrCmd::FORGPREP_XNEXT_FALLBACK: use(inst.b); break; - // A <- B, C - case IrCmd::AND: - case IrCmd::OR: - use(inst.b); - use(inst.c); - - def(inst.a); - break; - // A <- B - case IrCmd::ANDK: - case IrCmd::ORK: - use(inst.b); - - def(inst.a); - break; case IrCmd::FALLBACK_GETGLOBAL: def(inst.b); break; @@ -391,13 +417,13 @@ static RegisterSet computeBlockLiveInRegSet(IrFunction& function, const IrBlock& case IrCmd::FALLBACK_NAMECALL: use(inst.c); - defRange(inst.b.index, 2); + defRange(vmRegOp(inst.b), 2); break; case IrCmd::FALLBACK_PREPVARARGS: // No effect on explicitly referenced registers break; case IrCmd::FALLBACK_GETVARARGS: - defRange(inst.b.index, function.intOp(inst.c)); + defRange(vmRegOp(inst.b), function.intOp(inst.c)); break; case IrCmd::FALLBACK_NEWCLOSURE: def(inst.b); @@ -408,10 +434,10 @@ static RegisterSet computeBlockLiveInRegSet(IrFunction& function, const IrBlock& case IrCmd::FALLBACK_FORGPREP: use(inst.b); - defRange(inst.b.index, 3); + defRange(vmRegOp(inst.b), 3); break; case IrCmd::ADJUST_STACK_TO_REG: - defRange(inst.a.index, -1); + defRange(vmRegOp(inst.a), -1); break; case IrCmd::ADJUST_STACK_TO_TOP: // While this can be considered to be a vararg consumer, it is already handled in fastcall instructions diff --git a/CodeGen/src/IrBuilder.cpp b/CodeGen/src/IrBuilder.cpp index 4fee080ba..48c0e25c0 100644 --- a/CodeGen/src/IrBuilder.cpp +++ b/CodeGen/src/IrBuilder.cpp @@ -364,16 +364,16 @@ void IrBuilder::translateInst(LuauOpcode op, const Instruction* pc, int i) translateInstForGPrepInext(*this, pc, i); break; case LOP_AND: - inst(IrCmd::AND, vmReg(LUAU_INSN_A(*pc)), vmReg(LUAU_INSN_B(*pc)), vmReg(LUAU_INSN_C(*pc))); + translateInstAndX(*this, pc, i, vmReg(LUAU_INSN_C(*pc))); break; case LOP_ANDK: - inst(IrCmd::ANDK, vmReg(LUAU_INSN_A(*pc)), vmReg(LUAU_INSN_B(*pc)), vmConst(LUAU_INSN_C(*pc))); + translateInstAndX(*this, pc, i, vmConst(LUAU_INSN_C(*pc))); break; case LOP_OR: - inst(IrCmd::OR, vmReg(LUAU_INSN_A(*pc)), vmReg(LUAU_INSN_B(*pc)), vmReg(LUAU_INSN_C(*pc))); + translateInstOrX(*this, pc, i, vmReg(LUAU_INSN_C(*pc))); break; case LOP_ORK: - inst(IrCmd::ORK, vmReg(LUAU_INSN_A(*pc)), vmReg(LUAU_INSN_B(*pc)), vmConst(LUAU_INSN_C(*pc))); + translateInstOrX(*this, pc, i, vmConst(LUAU_INSN_C(*pc))); break; case LOP_COVERAGE: inst(IrCmd::COVERAGE, constUint(i)); diff --git a/CodeGen/src/IrCallWrapperX64.cpp b/CodeGen/src/IrCallWrapperX64.cpp index 4f0c0cf66..8ac5f8bcf 100644 --- a/CodeGen/src/IrCallWrapperX64.cpp +++ b/CodeGen/src/IrCallWrapperX64.cpp @@ -58,14 +58,17 @@ void IrCallWrapperX64::call(const OperandX64& func) { CallArgument& arg = args[i]; - // If source is the last use of IrInst, clear the register - // Source registers are recorded separately in CallArgument if (arg.sourceOp.kind != IrOpKind::None) { if (IrInst* inst = regs.function.asInstOp(arg.sourceOp)) { + // Source registers are recorded separately from source operands in CallArgument + // If source is the last use of IrInst, clear the register from the operand if (regs.isLastUseReg(*inst, instIdx)) inst->regX64 = noreg; + // If it's not the last use and register is volatile, register ownership is taken, which also spills the operand + else if (inst->regX64.size == SizeX64::xmmword || regs.shouldFreeGpr(inst->regX64)) + regs.takeReg(inst->regX64, kInvalidInstIdx); } } @@ -83,7 +86,11 @@ void IrCallWrapperX64::call(const OperandX64& func) freeSourceRegisters(arg); - build.mov(tmp.reg, arg.source); + if (arg.source.memSize == SizeX64::none) + build.lea(tmp.reg, arg.source); + else + build.mov(tmp.reg, arg.source); + build.mov(arg.target, tmp.reg); } else @@ -102,7 +109,7 @@ void IrCallWrapperX64::call(const OperandX64& func) // If target is not used as source in other arguments, prevent register allocator from giving it out if (getRegisterUses(arg.target.base) == 0) - regs.takeReg(arg.target.base); + regs.takeReg(arg.target.base, kInvalidInstIdx); else // Otherwise, make sure we won't free it when last source use is completed addRegisterUse(arg.target.base); @@ -122,7 +129,7 @@ void IrCallWrapperX64::call(const OperandX64& func) freeSourceRegisters(*candidate); LUAU_ASSERT(getRegisterUses(candidate->target.base) == 0); - regs.takeReg(candidate->target.base); + regs.takeReg(candidate->target.base, kInvalidInstIdx); moveToTarget(*candidate); @@ -131,15 +138,7 @@ void IrCallWrapperX64::call(const OperandX64& func) // If all registers cross-interfere (rcx <- rdx, rdx <- rcx), one has to be renamed else if (RegisterX64 conflict = findConflictingTarget(); conflict != noreg) { - // Get a fresh register - RegisterX64 freshReg = conflict.size == SizeX64::xmmword ? regs.allocXmmReg() : regs.allocGprReg(conflict.size); - - if (conflict.size == SizeX64::xmmword) - build.vmovsd(freshReg, conflict, conflict); - else - build.mov(freshReg, conflict); - - renameSourceRegisters(conflict, freshReg); + renameConflictingRegister(conflict); } else { @@ -156,10 +155,18 @@ void IrCallWrapperX64::call(const OperandX64& func) if (arg.source.cat == CategoryX64::imm) { + // There could be a conflict with the function source register, make this argument a candidate to find it + arg.candidate = true; + + if (RegisterX64 conflict = findConflictingTarget(); conflict != noreg) + renameConflictingRegister(conflict); + if (arg.target.cat == CategoryX64::reg) - regs.takeReg(arg.target.base); + regs.takeReg(arg.target.base, kInvalidInstIdx); moveToTarget(arg); + + arg.candidate = false; } } @@ -176,6 +183,10 @@ void IrCallWrapperX64::call(const OperandX64& func) regs.freeReg(arg.target.base); } + regs.preserveAndFreeInstValues(); + + regs.assertAllFree(); + build.call(funcOp); } @@ -362,6 +373,19 @@ RegisterX64 IrCallWrapperX64::findConflictingTarget() const return noreg; } +void IrCallWrapperX64::renameConflictingRegister(RegisterX64 conflict) +{ + // Get a fresh register + RegisterX64 freshReg = conflict.size == SizeX64::xmmword ? regs.allocXmmReg(kInvalidInstIdx) : regs.allocGprReg(conflict.size, kInvalidInstIdx); + + if (conflict.size == SizeX64::xmmword) + build.vmovsd(freshReg, conflict, conflict); + else + build.mov(freshReg, conflict); + + renameSourceRegisters(conflict, freshReg); +} + int IrCallWrapperX64::getRegisterUses(RegisterX64 reg) const { return reg.size == SizeX64::xmmword ? xmmUses[reg.index] : (reg.size != SizeX64::none ? gprUses[reg.index] : 0); diff --git a/CodeGen/src/IrDump.cpp b/CodeGen/src/IrDump.cpp index fb56df8c5..8f299520b 100644 --- a/CodeGen/src/IrDump.cpp +++ b/CodeGen/src/IrDump.cpp @@ -100,6 +100,8 @@ const char* getCmdName(IrCmd cmd) return "STORE_DOUBLE"; case IrCmd::STORE_INT: return "STORE_INT"; + case IrCmd::STORE_VECTOR: + return "STORE_VECTOR"; case IrCmd::STORE_TVALUE: return "STORE_TVALUE"; case IrCmd::STORE_NODE_VALUE_TV: @@ -238,14 +240,6 @@ const char* getCmdName(IrCmd cmd) return "FORGLOOP_FALLBACK"; case IrCmd::FORGPREP_XNEXT_FALLBACK: return "FORGPREP_XNEXT_FALLBACK"; - case IrCmd::AND: - return "AND"; - case IrCmd::ANDK: - return "ANDK"; - case IrCmd::OR: - return "OR"; - case IrCmd::ORK: - return "ORK"; case IrCmd::COVERAGE: return "COVERAGE"; case IrCmd::FALLBACK_GETGLOBAL: @@ -345,13 +339,13 @@ void toString(IrToStringContext& ctx, IrOp op) append(ctx.result, "%s_%u", getBlockKindName(ctx.blocks[op.index].kind), op.index); break; case IrOpKind::VmReg: - append(ctx.result, "R%u", op.index); + append(ctx.result, "R%d", vmRegOp(op)); break; case IrOpKind::VmConst: - append(ctx.result, "K%u", op.index); + append(ctx.result, "K%d", vmConstOp(op)); break; case IrOpKind::VmUpvalue: - append(ctx.result, "U%u", op.index); + append(ctx.result, "U%d", vmUpvalueOp(op)); break; } } diff --git a/CodeGen/src/IrLoweringA64.cpp b/CodeGen/src/IrLoweringA64.cpp index 37f381572..7f0305cc2 100644 --- a/CodeGen/src/IrLoweringA64.cpp +++ b/CodeGen/src/IrLoweringA64.cpp @@ -12,6 +12,7 @@ #include "NativeState.h" #include "lstate.h" +#include "lgc.h" // TODO: Eventually this can go away // #define TRACE @@ -32,7 +33,7 @@ struct LoweringStatsA64 ~LoweringStatsA64() { if (total) - printf("A64 lowering succeded for %.1f%% functions (%d/%d)\n", double(can) / double(total) * 100, int(can), int(total)); + printf("A64 lowering succeeded for %.1f%% functions (%d/%d)\n", double(can) / double(total) * 100, int(can), int(total)); } } gStatsA64; #endif @@ -77,6 +78,34 @@ inline ConditionA64 getConditionFP(IrCondition cond) } } +// TODO: instead of temp1/temp2 we can take a register that we will use for ra->value; that way callers to this function will be able to use it when +// calling luaC_barrier* +static void checkObjectBarrierConditions(AssemblyBuilderA64& build, RegisterA64 object, RegisterA64 temp1, RegisterA64 temp2, int ra, Label& skip) +{ + RegisterA64 temp1w = castReg(KindA64::w, temp1); + RegisterA64 temp2w = castReg(KindA64::w, temp2); + + // iscollectable(ra) + build.ldr(temp1w, mem(rBase, ra * sizeof(TValue) + offsetof(TValue, tt))); + build.cmp(temp1w, LUA_TSTRING); + build.b(ConditionA64::Less, skip); + + // isblack(obj2gco(o)) + // TODO: conditional bit test with BLACKBIT + build.ldrb(temp1w, mem(object, offsetof(GCheader, marked))); + build.mov(temp2w, bitmask(BLACKBIT)); + build.and_(temp1w, temp1w, temp2w); + build.cbz(temp1w, skip); + + // iswhite(gcvalue(ra)) + // TODO: tst with bitmask(WHITE0BIT, WHITE1BIT) + build.ldr(temp1, mem(rBase, ra * sizeof(TValue) + offsetof(TValue, value))); + build.ldrb(temp1w, mem(temp1, offsetof(GCheader, marked))); + build.mov(temp2w, bit2mask(WHITE0BIT, WHITE1BIT)); + build.and_(temp1w, temp1w, temp2w); + build.cbz(temp1w, skip); +} + IrLoweringA64::IrLoweringA64(AssemblyBuilderA64& build, ModuleHelpers& helpers, NativeState& data, Proto* proto, IrFunction& function) : build(build) , helpers(helpers) @@ -108,37 +137,89 @@ bool IrLoweringA64::canLower(const IrFunction& function) case IrCmd::LOAD_TVALUE: case IrCmd::LOAD_NODE_VALUE_TV: case IrCmd::LOAD_ENV: + case IrCmd::GET_ARR_ADDR: + case IrCmd::GET_SLOT_NODE_ADDR: + case IrCmd::GET_HASH_NODE_ADDR: case IrCmd::STORE_TAG: case IrCmd::STORE_POINTER: case IrCmd::STORE_DOUBLE: case IrCmd::STORE_INT: case IrCmd::STORE_TVALUE: case IrCmd::STORE_NODE_VALUE_TV: + case IrCmd::ADD_INT: + case IrCmd::SUB_INT: case IrCmd::ADD_NUM: case IrCmd::SUB_NUM: case IrCmd::MUL_NUM: case IrCmd::DIV_NUM: case IrCmd::MOD_NUM: + case IrCmd::POW_NUM: + case IrCmd::MIN_NUM: + case IrCmd::MAX_NUM: case IrCmd::UNM_NUM: + case IrCmd::FLOOR_NUM: + case IrCmd::CEIL_NUM: + case IrCmd::ROUND_NUM: + case IrCmd::SQRT_NUM: + case IrCmd::ABS_NUM: case IrCmd::JUMP: + case IrCmd::JUMP_IF_TRUTHY: + case IrCmd::JUMP_IF_FALSY: case IrCmd::JUMP_EQ_TAG: + case IrCmd::JUMP_EQ_INT: + case IrCmd::JUMP_EQ_POINTER: case IrCmd::JUMP_CMP_NUM: case IrCmd::JUMP_CMP_ANY: + case IrCmd::TABLE_LEN: + case IrCmd::NEW_TABLE: + case IrCmd::DUP_TABLE: + case IrCmd::TRY_NUM_TO_INDEX: + case IrCmd::INT_TO_NUM: + case IrCmd::ADJUST_STACK_TO_REG: + case IrCmd::ADJUST_STACK_TO_TOP: + case IrCmd::INVOKE_FASTCALL: + case IrCmd::CHECK_FASTCALL_RES: case IrCmd::DO_ARITH: + case IrCmd::DO_LEN: + case IrCmd::GET_TABLE: + case IrCmd::SET_TABLE: case IrCmd::GET_IMPORT: + case IrCmd::CONCAT: case IrCmd::GET_UPVALUE: + case IrCmd::SET_UPVALUE: + case IrCmd::PREPARE_FORN: case IrCmd::CHECK_TAG: case IrCmd::CHECK_READONLY: case IrCmd::CHECK_NO_METATABLE: case IrCmd::CHECK_SAFE_ENV: + case IrCmd::CHECK_ARRAY_SIZE: + case IrCmd::CHECK_SLOT_MATCH: case IrCmd::INTERRUPT: + case IrCmd::CHECK_GC: + case IrCmd::BARRIER_OBJ: + case IrCmd::BARRIER_TABLE_BACK: + case IrCmd::BARRIER_TABLE_FORWARD: case IrCmd::SET_SAVEDPC: + case IrCmd::CLOSE_UPVALS: + case IrCmd::CAPTURE: case IrCmd::CALL: case IrCmd::RETURN: + case IrCmd::FALLBACK_GETGLOBAL: + case IrCmd::FALLBACK_SETGLOBAL: + case IrCmd::FALLBACK_GETTABLEKS: + case IrCmd::FALLBACK_SETTABLEKS: + case IrCmd::FALLBACK_NAMECALL: + case IrCmd::FALLBACK_PREPVARARGS: + case IrCmd::FALLBACK_GETVARARGS: + case IrCmd::FALLBACK_NEWCLOSURE: + case IrCmd::FALLBACK_DUPCLOSURE: case IrCmd::SUBSTITUTE: continue; default: +#ifdef TRACE + printf("A64 lowering missing %s\n", getCmdName(inst.cmd)); +#endif return false; } } @@ -199,6 +280,64 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) inst.regA64 = regs.allocReg(KindA64::x); build.ldr(inst.regA64, mem(rClosure, offsetof(Closure, env))); break; + case IrCmd::GET_ARR_ADDR: + { + inst.regA64 = regs.allocReg(KindA64::x); + build.ldr(inst.regA64, mem(regOp(inst.a), offsetof(Table, array))); + + if (inst.b.kind == IrOpKind::Inst) + { + // TODO: This is a temporary hack that reads wN register as if it was xN. This should use unsigned extension shift once we support it. + build.add(inst.regA64, inst.regA64, castReg(KindA64::x, regOp(inst.b)), kTValueSizeLog2); + } + else if (inst.b.kind == IrOpKind::Constant) + { + LUAU_ASSERT(size_t(intOp(inst.b)) <= AssemblyBuilderA64::kMaxImmediate >> kTValueSizeLog2); // TODO: handle out of range values + build.add(inst.regA64, inst.regA64, uint16_t(intOp(inst.b) << kTValueSizeLog2)); + } + else + LUAU_ASSERT(!"Unsupported instruction form"); + break; + } + case IrCmd::GET_SLOT_NODE_ADDR: + { + inst.regA64 = regs.allocReuse(KindA64::x, index, {inst.a}); + RegisterA64 temp1 = regs.allocTemp(KindA64::x); + RegisterA64 temp1w = castReg(KindA64::w, temp1); + RegisterA64 temp2 = regs.allocTemp(KindA64::w); + + // TODO: this can use a slightly more efficient sequence with a 4b load + and-with-right-shift for pcpos<1024 but we don't support it yet. + build.mov(temp1, uintOp(inst.b) * sizeof(Instruction) + kOffsetOfInstructionC); + build.ldrb(temp1w, mem(rCode, temp1)); + build.ldrb(temp2, mem(regOp(inst.a), offsetof(Table, nodemask8))); + build.and_(temp2, temp2, temp1w); + + // note: this may clobber inst.a, so it's important that we don't use it after this + build.ldr(inst.regA64, mem(regOp(inst.a), offsetof(Table, node))); + // TODO: This is a temporary hack that reads wN register as if it was xN. This should use unsigned extension shift once we support it. + build.add(inst.regA64, inst.regA64, castReg(KindA64::x, temp2), kLuaNodeSizeLog2); + break; + } + case IrCmd::GET_HASH_NODE_ADDR: + { + inst.regA64 = regs.allocReuse(KindA64::x, index, {inst.a}); + RegisterA64 temp1 = regs.allocTemp(KindA64::w); + RegisterA64 temp2 = regs.allocTemp(KindA64::w); + + // TODO: this can use bic (andnot) to do hash & ~(-1 << lsizenode) instead but we don't support it yet + build.mov(temp1, 1); + build.ldrb(temp2, mem(regOp(inst.a), offsetof(Table, lsizenode))); + build.lsl(temp1, temp1, temp2); + build.sub(temp1, temp1, 1); + build.mov(temp2, uintOp(inst.b)); + build.and_(temp2, temp2, temp1); + + // note: this may clobber inst.a, so it's important that we don't use it after this + build.ldr(inst.regA64, mem(regOp(inst.a), offsetof(Table, node))); + // TODO: This is a temporary hack that reads wN register as if it was xN. This should use unsigned extension shift once we support it. + build.add(inst.regA64, inst.regA64, castReg(KindA64::x, temp2), kLuaNodeSizeLog2); + break; + } case IrCmd::STORE_TAG: { RegisterA64 temp = regs.allocTemp(KindA64::w); @@ -236,6 +375,16 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) case IrCmd::STORE_NODE_VALUE_TV: build.str(regOp(inst.b), mem(regOp(inst.a), offsetof(LuaNode, val))); break; + case IrCmd::ADD_INT: + LUAU_ASSERT(unsigned(intOp(inst.b)) <= AssemblyBuilderA64::kMaxImmediate); // TODO: handle out of range values + inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.a}); + build.add(inst.regA64, regOp(inst.a), uint16_t(intOp(inst.b))); + break; + case IrCmd::SUB_INT: + LUAU_ASSERT(unsigned(intOp(inst.b)) <= AssemblyBuilderA64::kMaxImmediate); // TODO: handle out of range values + inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.a}); + build.sub(inst.regA64, regOp(inst.a), uint16_t(intOp(inst.b))); + break; case IrCmd::ADD_NUM: { inst.regA64 = regs.allocReuse(KindA64::d, index, {inst.a, inst.b}); @@ -270,7 +419,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) } case IrCmd::MOD_NUM: { - inst.regA64 = regs.allocReg(KindA64::d); + inst.regA64 = regs.allocReg(KindA64::d); // can't allocReuse because both A and B are used twice RegisterA64 temp1 = tempDouble(inst.a); RegisterA64 temp2 = tempDouble(inst.b); build.fdiv(inst.regA64, temp1, temp2); @@ -279,6 +428,37 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.fsub(inst.regA64, temp1, inst.regA64); break; } + case IrCmd::POW_NUM: + { + // TODO: this instruction clobbers all registers because of a call but it's unclear how to assert that cleanly atm + inst.regA64 = regs.allocReg(KindA64::d); + RegisterA64 temp1 = tempDouble(inst.a); + RegisterA64 temp2 = tempDouble(inst.b); + build.fmov(d0, temp1); // TODO: aliasing hazard + build.fmov(d1, temp2); // TODO: aliasing hazard + build.ldr(x0, mem(rNativeContext, offsetof(NativeContext, libm_pow))); + build.blr(x0); + build.fmov(inst.regA64, d0); + break; + } + case IrCmd::MIN_NUM: + { + inst.regA64 = regs.allocReuse(KindA64::d, index, {inst.a, inst.b}); + RegisterA64 temp1 = tempDouble(inst.a); + RegisterA64 temp2 = tempDouble(inst.b); + build.fcmp(temp1, temp2); + build.fcsel(inst.regA64, temp1, temp2, getConditionFP(IrCondition::Less)); + break; + } + case IrCmd::MAX_NUM: + { + inst.regA64 = regs.allocReuse(KindA64::d, index, {inst.a, inst.b}); + RegisterA64 temp1 = tempDouble(inst.a); + RegisterA64 temp2 = tempDouble(inst.b); + build.fcmp(temp1, temp2); + build.fcsel(inst.regA64, temp1, temp2, getConditionFP(IrCondition::Greater)); + break; + } case IrCmd::UNM_NUM: { inst.regA64 = regs.allocReuse(KindA64::d, index, {inst.a}); @@ -286,9 +466,76 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.fneg(inst.regA64, temp); break; } + case IrCmd::FLOOR_NUM: + { + inst.regA64 = regs.allocReuse(KindA64::d, index, {inst.a}); + RegisterA64 temp = tempDouble(inst.a); + build.frintm(inst.regA64, temp); + break; + } + case IrCmd::CEIL_NUM: + { + inst.regA64 = regs.allocReuse(KindA64::d, index, {inst.a}); + RegisterA64 temp = tempDouble(inst.a); + build.frintp(inst.regA64, temp); + break; + } + case IrCmd::ROUND_NUM: + { + inst.regA64 = regs.allocReuse(KindA64::d, index, {inst.a}); + RegisterA64 temp = tempDouble(inst.a); + build.frinta(inst.regA64, temp); + break; + } + case IrCmd::SQRT_NUM: + { + inst.regA64 = regs.allocReuse(KindA64::d, index, {inst.a}); + RegisterA64 temp = tempDouble(inst.a); + build.fsqrt(inst.regA64, temp); + break; + } + case IrCmd::ABS_NUM: + { + inst.regA64 = regs.allocReuse(KindA64::d, index, {inst.a}); + RegisterA64 temp = tempDouble(inst.a); + build.fabs(inst.regA64, temp); + break; + } case IrCmd::JUMP: jumpOrFallthrough(blockOp(inst.a), next); break; + case IrCmd::JUMP_IF_TRUTHY: + { + RegisterA64 temp = regs.allocTemp(KindA64::w); + build.ldr(temp, mem(rBase, vmRegOp(inst.a) * sizeof(TValue) + offsetof(TValue, tt))); + // nil => falsy + LUAU_ASSERT(LUA_TNIL == 0); + build.cbz(temp, labelOp(inst.c)); + // not boolean => truthy + build.cmp(temp, LUA_TBOOLEAN); + build.b(ConditionA64::NotEqual, labelOp(inst.b)); + // compare boolean value + build.ldr(temp, mem(rBase, vmRegOp(inst.a) * sizeof(TValue) + offsetof(TValue, value))); + build.cbnz(temp, labelOp(inst.b)); + jumpOrFallthrough(blockOp(inst.c), next); + break; + } + case IrCmd::JUMP_IF_FALSY: + { + RegisterA64 temp = regs.allocTemp(KindA64::w); + build.ldr(temp, mem(rBase, vmRegOp(inst.a) * sizeof(TValue) + offsetof(TValue, tt))); + // nil => falsy + LUAU_ASSERT(LUA_TNIL == 0); + build.cbz(temp, labelOp(inst.b)); + // not boolean => truthy + build.cmp(temp, LUA_TBOOLEAN); + build.b(ConditionA64::NotEqual, labelOp(inst.c)); + // compare boolean value + build.ldr(temp, mem(rBase, vmRegOp(inst.a) * sizeof(TValue) + offsetof(TValue, value))); + build.cbz(temp, labelOp(inst.b)); + jumpOrFallthrough(blockOp(inst.c), next); + break; + } case IrCmd::JUMP_EQ_TAG: if (inst.b.kind == IrOpKind::Constant) build.cmp(regOp(inst.a), tagOp(inst.b)); @@ -308,6 +555,17 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) jumpOrFallthrough(blockOp(inst.c), next); } break; + case IrCmd::JUMP_EQ_INT: + LUAU_ASSERT(unsigned(intOp(inst.b)) <= AssemblyBuilderA64::kMaxImmediate); + build.cmp(regOp(inst.a), uint16_t(intOp(inst.b))); + build.b(ConditionA64::Equal, labelOp(inst.c)); + jumpOrFallthrough(blockOp(inst.d), next); + break; + case IrCmd::JUMP_EQ_POINTER: + build.cmp(regOp(inst.a), regOp(inst.b)); + build.b(ConditionA64::Equal, labelOp(inst.c)); + jumpOrFallthrough(blockOp(inst.d), next); + break; case IrCmd::JUMP_CMP_NUM: { IrCondition cond = conditionOp(inst.c); @@ -349,6 +607,150 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) jumpOrFallthrough(blockOp(inst.e), next); break; } + case IrCmd::TABLE_LEN: + { + regs.assertAllFreeExcept(regOp(inst.a)); + build.mov(x0, regOp(inst.a)); // TODO: minor aliasing hazard + build.ldr(x1, mem(rNativeContext, offsetof(NativeContext, luaH_getn))); + build.blr(x1); + + inst.regA64 = regs.allocReg(KindA64::d); + build.scvtf(inst.regA64, x0); + break; + } + case IrCmd::NEW_TABLE: + { + regs.assertAllFree(); + build.mov(x0, rState); + build.mov(x1, uintOp(inst.a)); + build.mov(x2, uintOp(inst.b)); + build.ldr(x3, mem(rNativeContext, offsetof(NativeContext, luaH_new))); + build.blr(x3); + // TODO: we could takeReg x0 but it's unclear if we will be able to keep x0 allocatable due to aliasing concerns + inst.regA64 = regs.allocReg(KindA64::x); + build.mov(inst.regA64, x0); + break; + } + case IrCmd::DUP_TABLE: + { + regs.assertAllFreeExcept(regOp(inst.a)); + build.mov(x0, rState); + build.mov(x1, regOp(inst.a)); // TODO: aliasing hazard + build.ldr(x2, mem(rNativeContext, offsetof(NativeContext, luaH_clone))); + build.blr(x2); + // TODO: we could takeReg x0 but it's unclear if we will be able to keep x0 allocatable due to aliasing concerns + inst.regA64 = regs.allocReg(KindA64::x); + build.mov(inst.regA64, x0); + break; + } + case IrCmd::TRY_NUM_TO_INDEX: + { + inst.regA64 = regs.allocReg(KindA64::w); + RegisterA64 temp1 = tempDouble(inst.a); + + if (build.features & Feature_JSCVT) + { + build.fjcvtzs(inst.regA64, temp1); // fjcvtzs sets PSTATE.Z (equal) iff conversion is exact + build.b(ConditionA64::NotEqual, labelOp(inst.b)); + } + else + { + RegisterA64 temp2 = regs.allocTemp(KindA64::d); + + build.fcvtzs(inst.regA64, temp1); + build.scvtf(temp2, inst.regA64); + build.fcmp(temp1, temp2); + build.b(ConditionA64::NotEqual, labelOp(inst.b)); + } + break; + } + case IrCmd::INT_TO_NUM: + { + inst.regA64 = regs.allocReg(KindA64::d); + RegisterA64 temp = tempInt(inst.a); + build.scvtf(inst.regA64, temp); + break; + } + case IrCmd::ADJUST_STACK_TO_REG: + { + RegisterA64 temp = regs.allocTemp(KindA64::x); + + if (inst.b.kind == IrOpKind::Constant) + { + build.add(temp, rBase, uint16_t((vmRegOp(inst.a) + intOp(inst.b)) * sizeof(TValue))); + build.str(temp, mem(rState, offsetof(lua_State, top))); + } + else if (inst.b.kind == IrOpKind::Inst) + { + build.add(temp, rBase, uint16_t(vmRegOp(inst.a) * sizeof(TValue))); + // TODO: This is a temporary hack that reads wN register as if it was xN. This should use unsigned extension shift once we support it. + build.add(temp, temp, castReg(KindA64::x, regOp(inst.b)), kTValueSizeLog2); + build.str(temp, mem(rState, offsetof(lua_State, top))); + } + else + LUAU_ASSERT(!"Unsupported instruction form"); + break; + } + case IrCmd::ADJUST_STACK_TO_TOP: + { + RegisterA64 temp = regs.allocTemp(KindA64::x); + build.ldr(temp, mem(rState, offsetof(lua_State, ci))); + build.ldr(temp, mem(temp, offsetof(CallInfo, top))); + build.str(temp, mem(rState, offsetof(lua_State, top))); + break; + } + case IrCmd::INVOKE_FASTCALL: + { + regs.assertAllFree(); + build.mov(x0, rState); + build.add(x1, rBase, uint16_t(vmRegOp(inst.b) * sizeof(TValue))); + build.add(x2, rBase, uint16_t(vmRegOp(inst.c) * sizeof(TValue))); + build.mov(w3, intOp(inst.f)); // nresults + + if (inst.d.kind == IrOpKind::VmReg) + build.add(x4, rBase, uint16_t(vmRegOp(inst.d) * sizeof(TValue))); + else if (inst.d.kind == IrOpKind::VmConst) + { + // TODO: refactor into a common helper + if (vmConstOp(inst.d) * sizeof(TValue) <= AssemblyBuilderA64::kMaxImmediate) + { + build.add(x4, rConstants, uint16_t(vmConstOp(inst.d) * sizeof(TValue))); + } + else + { + build.mov(x4, vmConstOp(inst.d) * sizeof(TValue)); + build.add(x4, rConstants, x4); + } + } + else + LUAU_ASSERT(boolOp(inst.d) == false); + + // nparams + if (intOp(inst.e) == LUA_MULTRET) + { + // L->top - (ra + 1) + build.ldr(x5, mem(rState, offsetof(lua_State, top))); + build.sub(x5, x5, rBase); + build.sub(x5, x5, uint16_t((vmRegOp(inst.b) + 1) * sizeof(TValue))); + // TODO: this can use immediate shift right or maybe add/sub with shift right but we don't implement them yet + build.mov(x6, kTValueSizeLog2); + build.lsr(x5, x5, x6); + } + else + build.mov(w5, intOp(inst.e)); + + build.ldr(x6, mem(rNativeContext, offsetof(NativeContext, luauF_table) + uintOp(inst.a) * sizeof(luau_FastFunction))); + build.blr(x6); + + // TODO: we could takeReg w0 but it's unclear if we will be able to keep x0 allocatable due to aliasing concerns + inst.regA64 = regs.allocReg(KindA64::w); + build.mov(inst.regA64, w0); + break; + } + case IrCmd::CHECK_FASTCALL_RES: + build.cmp(regOp(inst.a), 0); + build.b(ConditionA64::Less, labelOp(inst.b)); + break; case IrCmd::DO_ARITH: regs.assertAllFree(); build.mov(x0, rState); @@ -375,12 +777,76 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.ldr(x5, mem(rNativeContext, offsetof(NativeContext, luaV_doarith))); build.blr(x5); + emitUpdateBase(build); + break; + case IrCmd::DO_LEN: + regs.assertAllFree(); + build.mov(x0, rState); + build.add(x1, rBase, uint16_t(vmRegOp(inst.a) * sizeof(TValue))); + build.add(x2, rBase, uint16_t(vmRegOp(inst.b) * sizeof(TValue))); + build.ldr(x3, mem(rNativeContext, offsetof(NativeContext, luaV_dolen))); + build.blr(x3); + + emitUpdateBase(build); + break; + case IrCmd::GET_TABLE: + regs.assertAllFree(); + build.mov(x0, rState); + build.add(x1, rBase, uint16_t(vmRegOp(inst.b) * sizeof(TValue))); + + if (inst.c.kind == IrOpKind::VmReg) + build.add(x2, rBase, uint16_t(vmRegOp(inst.c) * sizeof(TValue))); + else if (inst.c.kind == IrOpKind::Constant) + { + TValue n; + setnvalue(&n, uintOp(inst.c)); + build.adr(x2, &n, sizeof(n)); + } + else + LUAU_ASSERT(!"Unsupported instruction form"); + + build.add(x3, rBase, uint16_t(vmRegOp(inst.a) * sizeof(TValue))); + build.ldr(x4, mem(rNativeContext, offsetof(NativeContext, luaV_gettable))); + build.blr(x4); + + emitUpdateBase(build); + break; + case IrCmd::SET_TABLE: + regs.assertAllFree(); + build.mov(x0, rState); + build.add(x1, rBase, uint16_t(vmRegOp(inst.b) * sizeof(TValue))); + + if (inst.c.kind == IrOpKind::VmReg) + build.add(x2, rBase, uint16_t(vmRegOp(inst.c) * sizeof(TValue))); + else if (inst.c.kind == IrOpKind::Constant) + { + TValue n; + setnvalue(&n, uintOp(inst.c)); + build.adr(x2, &n, sizeof(n)); + } + else + LUAU_ASSERT(!"Unsupported instruction form"); + + build.add(x3, rBase, uint16_t(vmRegOp(inst.a) * sizeof(TValue))); + build.ldr(x4, mem(rNativeContext, offsetof(NativeContext, luaV_settable))); + build.blr(x4); + emitUpdateBase(build); break; case IrCmd::GET_IMPORT: regs.assertAllFree(); emitInstGetImport(build, vmRegOp(inst.a), uintOp(inst.b)); break; + case IrCmd::CONCAT: + regs.assertAllFree(); + build.mov(x0, rState); + build.mov(x1, uintOp(inst.b)); + build.mov(x2, vmRegOp(inst.a) + uintOp(inst.b) - 1); + build.ldr(x3, mem(rNativeContext, offsetof(NativeContext, luaV_concat))); + build.blr(x3); + + emitUpdateBase(build); + break; case IrCmd::GET_UPVALUE: { RegisterA64 temp1 = regs.allocTemp(KindA64::x); @@ -405,6 +871,44 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.str(temp2, mem(rBase, vmRegOp(inst.a) * sizeof(TValue))); break; } + case IrCmd::SET_UPVALUE: + { + regs.assertAllFree(); + RegisterA64 temp1 = regs.allocTemp(KindA64::x); + RegisterA64 temp2 = regs.allocTemp(KindA64::x); + RegisterA64 temp3 = regs.allocTemp(KindA64::q); + RegisterA64 temp4 = regs.allocTemp(KindA64::x); + + // UpVal* + build.ldr(temp1, mem(rClosure, offsetof(Closure, l.uprefs) + sizeof(TValue) * vmUpvalueOp(inst.a) + offsetof(TValue, value.gc))); + + build.ldr(temp2, mem(temp1, offsetof(UpVal, v))); + build.ldr(temp3, mem(rBase, vmRegOp(inst.b) * sizeof(TValue))); + build.str(temp3, temp2); + + Label skip; + checkObjectBarrierConditions(build, temp1, temp2, temp4, vmRegOp(inst.b), skip); + + build.mov(x0, rState); + build.mov(x1, temp1); // TODO: aliasing hazard + build.ldr(x2, mem(rBase, vmRegOp(inst.b) * sizeof(TValue) + offsetof(TValue, value))); + build.ldr(x3, mem(rNativeContext, offsetof(NativeContext, luaC_barrierf))); + build.blr(x3); + + // note: no emitUpdateBase necessary because luaC_ barriers do not reallocate stack + build.setLabel(skip); + break; + } + case IrCmd::PREPARE_FORN: + regs.assertAllFree(); + build.mov(x0, rState); + build.add(x1, rBase, uint16_t(vmRegOp(inst.a) * sizeof(TValue))); + build.add(x2, rBase, uint16_t(vmRegOp(inst.b) * sizeof(TValue))); + build.add(x3, rBase, uint16_t(vmRegOp(inst.c) * sizeof(TValue))); + build.ldr(x4, mem(rNativeContext, offsetof(NativeContext, luaV_prepareFORN))); + build.blr(x4); + // note: no emitUpdateBase necessary because prepareFORN does not reallocate stack + break; case IrCmd::CHECK_TAG: build.cmp(regOp(inst.a), tagOp(inst.b)); build.b(ConditionA64::NotEqual, labelOp(inst.c)); @@ -426,12 +930,55 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) case IrCmd::CHECK_SAFE_ENV: { RegisterA64 temp = regs.allocTemp(KindA64::x); - RegisterA64 tempw{KindA64::w, temp.index}; + RegisterA64 tempw = castReg(KindA64::w, temp); build.ldr(temp, mem(rClosure, offsetof(Closure, env))); build.ldrb(tempw, mem(temp, offsetof(Table, safeenv))); build.cbz(tempw, labelOp(inst.a)); break; } + case IrCmd::CHECK_ARRAY_SIZE: + { + RegisterA64 temp = regs.allocTemp(KindA64::w); + build.ldr(temp, mem(regOp(inst.a), offsetof(Table, sizearray))); + + if (inst.b.kind == IrOpKind::Inst) + build.cmp(temp, regOp(inst.b)); + else if (inst.b.kind == IrOpKind::Constant) + { + LUAU_ASSERT(size_t(intOp(inst.b)) <= AssemblyBuilderA64::kMaxImmediate); // TODO: handle out of range values + build.cmp(temp, uint16_t(intOp(inst.b))); + } + else + LUAU_ASSERT(!"Unsupported instruction form"); + + build.b(ConditionA64::UnsignedLessEqual, labelOp(inst.c)); + break; + } + case IrCmd::CHECK_SLOT_MATCH: + { + RegisterA64 temp1 = regs.allocTemp(KindA64::x); + RegisterA64 temp1w = castReg(KindA64::w, temp1); + RegisterA64 temp2 = regs.allocTemp(KindA64::x); + RegisterA64 temp2w = castReg(KindA64::w, temp2); + + build.ldr(temp1w, mem(regOp(inst.a), kOffsetOfLuaNodeTag)); + // TODO: this needs bitfield extraction, or and-immediate + build.mov(temp2w, kLuaNodeTagMask); + build.and_(temp1w, temp1w, temp2w); + build.cmp(temp1w, LUA_TSTRING); + build.b(ConditionA64::NotEqual, labelOp(inst.c)); + + AddressA64 addr = tempAddr(inst.b, offsetof(TValue, value)); + build.ldr(temp1, mem(regOp(inst.a), offsetof(LuaNode, key.value))); + build.ldr(temp2, addr); + build.cmp(temp1, temp2); + build.b(ConditionA64::NotEqual, labelOp(inst.c)); + + build.ldr(temp1w, mem(regOp(inst.a), offsetof(LuaNode, val.tt))); + LUAU_ASSERT(LUA_TNIL == 0); + build.cbz(temp1w, labelOp(inst.c)); + break; + } case IrCmd::INTERRUPT: { unsigned int pcpos = uintOp(inst.a); @@ -450,6 +997,93 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.setLabel(skip); break; } + case IrCmd::CHECK_GC: + { + regs.assertAllFree(); + RegisterA64 temp1 = regs.allocTemp(KindA64::x); + RegisterA64 temp2 = regs.allocTemp(KindA64::x); + + Label skip; + build.ldr(temp1, mem(rState, offsetof(lua_State, global))); + build.ldr(temp2, mem(temp1, offsetof(global_State, totalbytes))); + build.ldr(temp1, mem(temp1, offsetof(global_State, GCthreshold))); + build.cmp(temp1, temp2); + build.b(ConditionA64::UnsignedGreater, skip); + + build.mov(x0, rState); + build.mov(w1, 1); + build.ldr(x1, mem(rNativeContext, offsetof(NativeContext, luaC_step))); + build.blr(x1); + + emitUpdateBase(build); + build.setLabel(skip); + break; + } + case IrCmd::BARRIER_OBJ: + { + regs.assertAllFreeExcept(regOp(inst.a)); + + Label skip; + RegisterA64 temp1 = regs.allocTemp(KindA64::x); + RegisterA64 temp2 = regs.allocTemp(KindA64::x); + + checkObjectBarrierConditions(build, regOp(inst.a), temp1, temp2, vmRegOp(inst.b), skip); + + build.mov(x0, rState); + build.mov(x1, regOp(inst.a)); // TODO: aliasing hazard + build.ldr(x2, mem(rBase, vmRegOp(inst.b) * sizeof(TValue) + offsetof(TValue, value))); + build.ldr(x3, mem(rNativeContext, offsetof(NativeContext, luaC_barrierf))); + build.blr(x3); + + // note: no emitUpdateBase necessary because luaC_ barriers do not reallocate stack + build.setLabel(skip); + break; + } + case IrCmd::BARRIER_TABLE_BACK: + { + regs.assertAllFreeExcept(regOp(inst.a)); + + Label skip; + RegisterA64 temp1 = regs.allocTemp(KindA64::w); + RegisterA64 temp2 = regs.allocTemp(KindA64::w); + + // isblack(obj2gco(t)) + build.ldrb(temp1, mem(regOp(inst.a), offsetof(GCheader, marked))); + // TODO: conditional bit test with BLACKBIT + build.mov(temp2, bitmask(BLACKBIT)); + build.and_(temp1, temp1, temp2); + build.cbz(temp1, skip); + + build.mov(x0, rState); + build.mov(x1, regOp(inst.a)); // TODO: aliasing hazard here and below + build.add(x2, regOp(inst.a), uint16_t(offsetof(Table, gclist))); + build.ldr(x3, mem(rNativeContext, offsetof(NativeContext, luaC_barrierback))); + build.blr(x3); + + // note: no emitUpdateBase necessary because luaC_ barriers do not reallocate stack + build.setLabel(skip); + break; + } + case IrCmd::BARRIER_TABLE_FORWARD: + { + regs.assertAllFreeExcept(regOp(inst.a)); + + Label skip; + RegisterA64 temp1 = regs.allocTemp(KindA64::x); + RegisterA64 temp2 = regs.allocTemp(KindA64::x); + + checkObjectBarrierConditions(build, regOp(inst.a), temp1, temp2, vmRegOp(inst.b), skip); + + build.mov(x0, rState); + build.mov(x1, regOp(inst.a)); // TODO: aliasing hazard + build.ldr(x2, mem(rBase, vmRegOp(inst.b) * sizeof(TValue) + offsetof(TValue, value))); + build.ldr(x3, mem(rNativeContext, offsetof(NativeContext, luaC_barriertable))); + build.blr(x3); + + // note: no emitUpdateBase necessary because luaC_ barriers do not reallocate stack + build.setLabel(skip); + break; + } case IrCmd::SET_SAVEDPC: { unsigned int pcpos = uintOp(inst.a); @@ -471,6 +1105,34 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.str(temp1, mem(temp2, offsetof(CallInfo, savedpc))); break; } + case IrCmd::CLOSE_UPVALS: + { + regs.assertAllFree(); + Label skip; + RegisterA64 temp1 = regs.allocTemp(KindA64::x); + RegisterA64 temp2 = regs.allocTemp(KindA64::x); + + // L->openupval != 0 + build.ldr(temp1, mem(rState, offsetof(lua_State, openupval))); + build.cbz(temp1, skip); + + // ra <= L->openuval->v + build.ldr(temp1, mem(temp1, offsetof(UpVal, v))); + build.add(temp2, rBase, uint16_t(vmRegOp(inst.a) * sizeof(TValue))); + build.cmp(temp2, temp1); + build.b(ConditionA64::UnsignedGreater, skip); + + build.mov(x0, rState); + build.mov(x1, temp2); // TODO: aliasing hazard + build.ldr(x2, mem(rNativeContext, offsetof(NativeContext, luaF_close))); + build.blr(x2); + + build.setLabel(skip); + break; + } + case IrCmd::CAPTURE: + // no-op + break; case IrCmd::CALL: regs.assertAllFree(); emitInstCall(build, helpers, vmRegOp(inst.a), intOp(inst.b), intOp(inst.c)); @@ -479,6 +1141,74 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) regs.assertAllFree(); emitInstReturn(build, helpers, vmRegOp(inst.a), intOp(inst.b)); break; + + // Full instruction fallbacks + case IrCmd::FALLBACK_GETGLOBAL: + LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); + LUAU_ASSERT(inst.c.kind == IrOpKind::VmConst); + + regs.assertAllFree(); + emitFallback(build, LOP_GETGLOBAL, uintOp(inst.a)); + break; + case IrCmd::FALLBACK_SETGLOBAL: + LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); + LUAU_ASSERT(inst.c.kind == IrOpKind::VmConst); + + regs.assertAllFree(); + emitFallback(build, LOP_SETGLOBAL, uintOp(inst.a)); + break; + case IrCmd::FALLBACK_GETTABLEKS: + LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); + LUAU_ASSERT(inst.c.kind == IrOpKind::VmReg); + LUAU_ASSERT(inst.d.kind == IrOpKind::VmConst); + + regs.assertAllFree(); + emitFallback(build, LOP_GETTABLEKS, uintOp(inst.a)); + break; + case IrCmd::FALLBACK_SETTABLEKS: + LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); + LUAU_ASSERT(inst.c.kind == IrOpKind::VmReg); + LUAU_ASSERT(inst.d.kind == IrOpKind::VmConst); + + regs.assertAllFree(); + emitFallback(build, LOP_SETTABLEKS, uintOp(inst.a)); + break; + case IrCmd::FALLBACK_NAMECALL: + LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); + LUAU_ASSERT(inst.c.kind == IrOpKind::VmReg); + LUAU_ASSERT(inst.d.kind == IrOpKind::VmConst); + + regs.assertAllFree(); + emitFallback(build, LOP_NAMECALL, uintOp(inst.a)); + break; + case IrCmd::FALLBACK_PREPVARARGS: + LUAU_ASSERT(inst.b.kind == IrOpKind::Constant); + + regs.assertAllFree(); + emitFallback(build, LOP_PREPVARARGS, uintOp(inst.a)); + break; + case IrCmd::FALLBACK_GETVARARGS: + LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); + LUAU_ASSERT(inst.c.kind == IrOpKind::Constant); + + regs.assertAllFree(); + emitFallback(build, LOP_GETVARARGS, uintOp(inst.a)); + break; + case IrCmd::FALLBACK_NEWCLOSURE: + LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); + LUAU_ASSERT(inst.c.kind == IrOpKind::Constant); + + regs.assertAllFree(); + emitFallback(build, LOP_NEWCLOSURE, uintOp(inst.a)); + break; + case IrCmd::FALLBACK_DUPCLOSURE: + LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); + LUAU_ASSERT(inst.c.kind == IrOpKind::VmConst); + + regs.assertAllFree(); + emitFallback(build, LOP_DUPCLOSURE, uintOp(inst.a)); + break; + default: LUAU_ASSERT(!"Not supported yet"); break; @@ -488,6 +1218,11 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) regs.freeTempRegs(); } +bool IrLoweringA64::hasError() const +{ + return false; +} + bool IrLoweringA64::isFallthroughBlock(IrBlock target, IrBlock next) { return target.start == next.start; diff --git a/CodeGen/src/IrLoweringA64.h b/CodeGen/src/IrLoweringA64.h index f638432ff..b374a26a0 100644 --- a/CodeGen/src/IrLoweringA64.h +++ b/CodeGen/src/IrLoweringA64.h @@ -30,6 +30,8 @@ struct IrLoweringA64 void lowerInst(IrInst& inst, uint32_t index, IrBlock& next); + bool hasError() const; + bool isFallthroughBlock(IrBlock target, IrBlock next); void jumpOrFallthrough(IrBlock& target, IrBlock& next); diff --git a/CodeGen/src/IrLoweringX64.cpp b/CodeGen/src/IrLoweringX64.cpp index 8c45f36ad..f2dfdb3b1 100644 --- a/CodeGen/src/IrLoweringX64.cpp +++ b/CodeGen/src/IrLoweringX64.cpp @@ -27,18 +27,39 @@ IrLoweringX64::IrLoweringX64(AssemblyBuilderX64& build, ModuleHelpers& helpers, , helpers(helpers) , data(data) , function(function) - , regs(function) + , regs(build, function) { // In order to allocate registers during lowering, we need to know where instruction results are last used updateLastUseLocations(function); } +void IrLoweringX64::storeDoubleAsFloat(OperandX64 dst, IrOp src) +{ + ScopedRegX64 tmp{regs, SizeX64::xmmword}; + + if (src.kind == IrOpKind::Constant) + { + build.vmovss(tmp.reg, build.f32(float(doubleOp(src)))); + } + else if (src.kind == IrOpKind::Inst) + { + build.vcvtsd2ss(tmp.reg, regOp(src), regOp(src)); + } + else + { + LUAU_ASSERT(!"Unsupported instruction form"); + } + build.vmovss(dst, tmp.reg); +} + void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) { + regs.currInstIdx = index; + switch (inst.cmd) { case IrCmd::LOAD_TAG: - inst.regX64 = regs.allocGprReg(SizeX64::dword); + inst.regX64 = regs.allocGprReg(SizeX64::dword, index); if (inst.a.kind == IrOpKind::VmReg) build.mov(inst.regX64, luauRegTag(vmRegOp(inst.a))); @@ -52,7 +73,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) LUAU_ASSERT(!"Unsupported instruction form"); break; case IrCmd::LOAD_POINTER: - inst.regX64 = regs.allocGprReg(SizeX64::qword); + inst.regX64 = regs.allocGprReg(SizeX64::qword, index); if (inst.a.kind == IrOpKind::VmReg) build.mov(inst.regX64, luauRegValue(vmRegOp(inst.a))); @@ -66,7 +87,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) LUAU_ASSERT(!"Unsupported instruction form"); break; case IrCmd::LOAD_DOUBLE: - inst.regX64 = regs.allocXmmReg(); + inst.regX64 = regs.allocXmmReg(index); if (inst.a.kind == IrOpKind::VmReg) build.vmovsd(inst.regX64, luauRegValue(vmRegOp(inst.a))); @@ -76,12 +97,12 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) LUAU_ASSERT(!"Unsupported instruction form"); break; case IrCmd::LOAD_INT: - inst.regX64 = regs.allocGprReg(SizeX64::dword); + inst.regX64 = regs.allocGprReg(SizeX64::dword, index); build.mov(inst.regX64, luauRegValueInt(vmRegOp(inst.a))); break; case IrCmd::LOAD_TVALUE: - inst.regX64 = regs.allocXmmReg(); + inst.regX64 = regs.allocXmmReg(index); if (inst.a.kind == IrOpKind::VmReg) build.vmovups(inst.regX64, luauReg(vmRegOp(inst.a))); @@ -93,12 +114,12 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) LUAU_ASSERT(!"Unsupported instruction form"); break; case IrCmd::LOAD_NODE_VALUE_TV: - inst.regX64 = regs.allocXmmReg(); + inst.regX64 = regs.allocXmmReg(index); build.vmovups(inst.regX64, luauNodeValue(regOp(inst.a))); break; case IrCmd::LOAD_ENV: - inst.regX64 = regs.allocGprReg(SizeX64::qword); + inst.regX64 = regs.allocGprReg(SizeX64::qword, index); build.mov(inst.regX64, sClosure); build.mov(inst.regX64, qword[inst.regX64 + offsetof(Closure, env)]); @@ -130,7 +151,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) break; case IrCmd::GET_SLOT_NODE_ADDR: { - inst.regX64 = regs.allocGprReg(SizeX64::qword); + inst.regX64 = regs.allocGprReg(SizeX64::qword, index); ScopedRegX64 tmp{regs, SizeX64::qword}; @@ -139,10 +160,10 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) } case IrCmd::GET_HASH_NODE_ADDR: { - inst.regX64 = regs.allocGprReg(SizeX64::qword); + inst.regX64 = regs.allocGprReg(SizeX64::qword, index); // Custom bit shift value can only be placed in cl - ScopedRegX64 shiftTmp{regs, regs.takeReg(rcx)}; + ScopedRegX64 shiftTmp{regs, regs.takeReg(rcx, kInvalidInstIdx)}; ScopedRegX64 tmp{regs, SizeX64::qword}; @@ -192,6 +213,13 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) LUAU_ASSERT(!"Unsupported instruction form"); break; } + case IrCmd::STORE_VECTOR: + { + storeDoubleAsFloat(luauRegValueVector(vmRegOp(inst.a), 0), inst.b); + storeDoubleAsFloat(luauRegValueVector(vmRegOp(inst.a), 1), inst.c); + storeDoubleAsFloat(luauRegValueVector(vmRegOp(inst.a), 2), inst.d); + break; + } case IrCmd::STORE_TVALUE: if (inst.a.kind == IrOpKind::VmReg) build.vmovups(luauReg(vmRegOp(inst.a)), regOp(inst.b)); @@ -330,7 +358,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) callWrap.addArgument(SizeX64::xmmword, memRegDoubleOp(inst.a), inst.a); callWrap.addArgument(SizeX64::xmmword, memRegDoubleOp(inst.b), inst.b); callWrap.call(qword[rNativeContext + offsetof(NativeContext, libm_pow)]); - inst.regX64 = regs.takeReg(xmm0); + inst.regX64 = regs.takeReg(xmm0, index); break; } case IrCmd::MIN_NUM: @@ -398,8 +426,10 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) ScopedRegX64 tmp1{regs, SizeX64::xmmword}; ScopedRegX64 tmp2{regs, SizeX64::xmmword}; - if (inst.a.kind != IrOpKind::Inst || regOp(inst.a) != inst.regX64) + if (inst.a.kind != IrOpKind::Inst) build.vmovsd(inst.regX64, memRegDoubleOp(inst.a)); + else if (regOp(inst.a) != inst.regX64) + build.vmovsd(inst.regX64, inst.regX64, regOp(inst.a)); build.vandpd(tmp1.reg, inst.regX64, build.f64x2(-0.0, -0.0)); build.vmovsd(tmp2.reg, build.i64(0x3fdfffffffffffff)); // 0.49999999999999994 @@ -416,8 +446,10 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) case IrCmd::ABS_NUM: inst.regX64 = regs.allocXmmRegOrReuse(index, {inst.a}); - if (inst.a.kind != IrOpKind::Inst || regOp(inst.a) != inst.regX64) + if (inst.a.kind != IrOpKind::Inst) build.vmovsd(inst.regX64, memRegDoubleOp(inst.a)); + else if (regOp(inst.a) != inst.regX64) + build.vmovsd(inst.regX64, inst.regX64, regOp(inst.a)); build.vandpd(inst.regX64, inst.regX64, build.i64(~(1LL << 63))); break; @@ -526,7 +558,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) callWrap.addArgument(SizeX64::qword, regOp(inst.a), inst.a); callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaH_getn)]); - inst.regX64 = regs.allocXmmReg(); + inst.regX64 = regs.allocXmmReg(index); build.vcvtsi2sd(inst.regX64, inst.regX64, eax); break; } @@ -537,7 +569,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) callWrap.addArgument(SizeX64::dword, int32_t(uintOp(inst.a)), inst.a); callWrap.addArgument(SizeX64::dword, int32_t(uintOp(inst.b)), inst.b); callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaH_new)]); - inst.regX64 = regs.takeReg(rax); + inst.regX64 = regs.takeReg(rax, index); break; } case IrCmd::DUP_TABLE: @@ -546,12 +578,12 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) callWrap.addArgument(SizeX64::qword, rState); callWrap.addArgument(SizeX64::qword, regOp(inst.a), inst.a); callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaH_clone)]); - inst.regX64 = regs.takeReg(rax); + inst.regX64 = regs.takeReg(rax, index); break; } case IrCmd::TRY_NUM_TO_INDEX: { - inst.regX64 = regs.allocGprReg(SizeX64::dword); + inst.regX64 = regs.allocGprReg(SizeX64::dword, index); ScopedRegX64 tmp{regs, SizeX64::xmmword}; @@ -574,35 +606,39 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) ScopedRegX64 tmp2{regs, SizeX64::qword}; build.mov(tmp2.reg, qword[rState + offsetof(lua_State, global)]); - IrCallWrapperX64 callWrap(regs, build, index); - callWrap.addArgument(SizeX64::qword, tmp); - callWrap.addArgument(SizeX64::qword, intOp(inst.b)); - callWrap.addArgument(SizeX64::qword, qword[tmp2.release() + offsetof(global_State, tmname) + intOp(inst.b) * sizeof(TString*)]); - callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaT_gettm)]); - inst.regX64 = regs.takeReg(rax); + { + ScopedSpills spillGuard(regs); + + IrCallWrapperX64 callWrap(regs, build, index); + callWrap.addArgument(SizeX64::qword, tmp); + callWrap.addArgument(SizeX64::qword, intOp(inst.b)); + callWrap.addArgument(SizeX64::qword, qword[tmp2.release() + offsetof(global_State, tmname) + intOp(inst.b) * sizeof(TString*)]); + callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaT_gettm)]); + } + + inst.regX64 = regs.takeReg(rax, index); break; } case IrCmd::INT_TO_NUM: - inst.regX64 = regs.allocXmmReg(); + inst.regX64 = regs.allocXmmReg(index); build.vcvtsi2sd(inst.regX64, inst.regX64, regOp(inst.a)); break; case IrCmd::ADJUST_STACK_TO_REG: { + ScopedRegX64 tmp{regs, SizeX64::qword}; + if (inst.b.kind == IrOpKind::Constant) { - ScopedRegX64 tmp{regs, SizeX64::qword}; - build.lea(tmp.reg, addr[rBase + (vmRegOp(inst.a) + intOp(inst.b)) * sizeof(TValue)]); build.mov(qword[rState + offsetof(lua_State, top)], tmp.reg); } else if (inst.b.kind == IrOpKind::Inst) { - ScopedRegX64 tmp(regs, regs.allocGprRegOrReuse(SizeX64::dword, index, {inst.b})); - - build.shl(qwordReg(tmp.reg), kTValueSizeLog2); - build.lea(qwordReg(tmp.reg), addr[rBase + qwordReg(tmp.reg) + vmRegOp(inst.a) * sizeof(TValue)]); - build.mov(qword[rState + offsetof(lua_State, top)], qwordReg(tmp.reg)); + build.mov(dwordReg(tmp.reg), regOp(inst.b)); + build.shl(tmp.reg, kTValueSizeLog2); + build.lea(tmp.reg, addr[rBase + tmp.reg + vmRegOp(inst.a) * sizeof(TValue)]); + build.mov(qword[rState + offsetof(lua_State, top)], tmp.reg); } else { @@ -640,52 +676,37 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) int nparams = intOp(inst.e); int nresults = intOp(inst.f); - regs.assertAllFree(); - - build.mov(rax, qword[rNativeContext + offsetof(NativeContext, luauF_table) + bfid * sizeof(luau_FastFunction)]); + ScopedRegX64 func{regs, SizeX64::qword}; + build.mov(func.reg, qword[rNativeContext + offsetof(NativeContext, luauF_table) + bfid * sizeof(luau_FastFunction)]); - // 5th parameter (args) is left unset for LOP_FASTCALL1 - if (args.cat == CategoryX64::mem) - { - if (build.abi == ABIX64::Windows) - { - build.lea(rcx, args); - build.mov(sArg5, rcx); - } - else - { - build.lea(rArg5, args); - } - } + IrCallWrapperX64 callWrap(regs, build, index); + callWrap.addArgument(SizeX64::qword, rState); + callWrap.addArgument(SizeX64::qword, luauRegAddress(ra)); + callWrap.addArgument(SizeX64::qword, luauRegAddress(arg)); + callWrap.addArgument(SizeX64::dword, nresults); + callWrap.addArgument(SizeX64::qword, args); if (nparams == LUA_MULTRET) { - // L->top - (ra + 1) - RegisterX64 reg = (build.abi == ABIX64::Windows) ? rcx : rArg6; + // Compute 'L->top - (ra + 1)', on SystemV, take r9 register to compute directly into the argument + // TODO: IrCallWrapperX64 should provide a way to 'guess' target argument register correctly + RegisterX64 reg = build.abi == ABIX64::Windows ? regs.allocGprReg(SizeX64::qword, kInvalidInstIdx) : regs.takeReg(rArg6, kInvalidInstIdx); + ScopedRegX64 tmp{regs, SizeX64::qword}; + build.mov(reg, qword[rState + offsetof(lua_State, top)]); - build.lea(rdx, addr[rBase + (ra + 1) * sizeof(TValue)]); - build.sub(reg, rdx); + build.lea(tmp.reg, addr[rBase + (ra + 1) * sizeof(TValue)]); + build.sub(reg, tmp.reg); build.shr(reg, kTValueSizeLog2); - if (build.abi == ABIX64::Windows) - build.mov(sArg6, reg); + callWrap.addArgument(SizeX64::dword, dwordReg(reg)); } else { - if (build.abi == ABIX64::Windows) - build.mov(sArg6, nparams); - else - build.mov(rArg6, nparams); + callWrap.addArgument(SizeX64::dword, nparams); } - build.mov(rArg1, rState); - build.lea(rArg2, luauRegAddress(ra)); - build.lea(rArg3, luauRegAddress(arg)); - build.mov(dwordReg(rArg4), nresults); - - build.call(rax); - - inst.regX64 = regs.takeReg(eax); // Result of a builtin call is returned in eax + callWrap.call(func.release()); + inst.regX64 = regs.takeReg(eax, index); // Result of a builtin call is returned in eax break; } case IrCmd::CHECK_FASTCALL_RES: @@ -738,6 +759,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) } break; case IrCmd::GET_IMPORT: + regs.assertAllFree(); emitInstGetImportFallback(build, vmRegOp(inst.a), uintOp(inst.b)); break; case IrCmd::CONCAT: @@ -777,7 +799,6 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) } case IrCmd::SET_UPVALUE: { - Label next; ScopedRegX64 tmp1{regs, SizeX64::qword}; ScopedRegX64 tmp2{regs, SizeX64::qword}; @@ -794,8 +815,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) tmp1.free(); - callBarrierObject(regs, build, tmp2.release(), {}, vmRegOp(inst.b), next); - build.setLabel(next); + callBarrierObject(regs, build, tmp2.release(), {}, vmRegOp(inst.b)); break; } case IrCmd::PREPARE_FORN: @@ -859,26 +879,14 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) emitInterrupt(build, uintOp(inst.a)); break; case IrCmd::CHECK_GC: - { - Label skip; - callCheckGc(regs, build, skip); - build.setLabel(skip); + callStepGc(regs, build); break; - } case IrCmd::BARRIER_OBJ: - { - Label skip; - callBarrierObject(regs, build, regOp(inst.a), inst.a, vmRegOp(inst.b), skip); - build.setLabel(skip); + callBarrierObject(regs, build, regOp(inst.a), inst.a, vmRegOp(inst.b)); break; - } case IrCmd::BARRIER_TABLE_BACK: - { - Label skip; - callBarrierTableFast(regs, build, regOp(inst.a), inst.a, skip); - build.setLabel(skip); + callBarrierTableFast(regs, build, regOp(inst.a), inst.a); break; - } case IrCmd::BARRIER_TABLE_FORWARD: { Label skip; @@ -886,11 +894,15 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) ScopedRegX64 tmp{regs, SizeX64::qword}; checkObjectBarrierConditions(build, tmp.reg, regOp(inst.a), vmRegOp(inst.b), skip); - IrCallWrapperX64 callWrap(regs, build, index); - callWrap.addArgument(SizeX64::qword, rState); - callWrap.addArgument(SizeX64::qword, regOp(inst.a), inst.a); - callWrap.addArgument(SizeX64::qword, tmp); - callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaC_barriertable)]); + { + ScopedSpills spillGuard(regs); + + IrCallWrapperX64 callWrap(regs, build, index); + callWrap.addArgument(SizeX64::qword, rState); + callWrap.addArgument(SizeX64::qword, regOp(inst.a), inst.a); + callWrap.addArgument(SizeX64::qword, tmp); + callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaC_barriertable)]); + } build.setLabel(skip); break; @@ -925,10 +937,14 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) tmp1.free(); - IrCallWrapperX64 callWrap(regs, build, index); - callWrap.addArgument(SizeX64::qword, rState); - callWrap.addArgument(SizeX64::qword, tmp2); - callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaF_close)]); + { + ScopedSpills spillGuard(regs); + + IrCallWrapperX64 callWrap(regs, build, index); + callWrap.addArgument(SizeX64::qword, rState); + callWrap.addArgument(SizeX64::qword, tmp2); + callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaF_close)]); + } build.setLabel(next); break; @@ -939,19 +955,17 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) // Fallbacks to non-IR instruction implementations case IrCmd::SETLIST: - { - Label next; regs.assertAllFree(); - emitInstSetList(regs, build, next, vmRegOp(inst.b), vmRegOp(inst.c), intOp(inst.d), uintOp(inst.e)); - build.setLabel(next); + emitInstSetList(regs, build, vmRegOp(inst.b), vmRegOp(inst.c), intOp(inst.d), uintOp(inst.e)); break; - } case IrCmd::CALL: regs.assertAllFree(); + regs.assertNoSpills(); emitInstCall(build, helpers, vmRegOp(inst.a), intOp(inst.b), intOp(inst.c)); break; case IrCmd::RETURN: regs.assertAllFree(); + regs.assertNoSpills(); emitInstReturn(build, helpers, vmRegOp(inst.a), intOp(inst.b)); break; case IrCmd::FORGLOOP: @@ -967,22 +981,6 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) regs.assertAllFree(); emitInstForGPrepXnextFallback(build, uintOp(inst.a), vmRegOp(inst.b), labelOp(inst.c)); break; - case IrCmd::AND: - regs.assertAllFree(); - emitInstAnd(build, vmRegOp(inst.a), vmRegOp(inst.b), vmRegOp(inst.c)); - break; - case IrCmd::ANDK: - regs.assertAllFree(); - emitInstAndK(build, vmRegOp(inst.a), vmRegOp(inst.b), vmConstOp(inst.c)); - break; - case IrCmd::OR: - regs.assertAllFree(); - emitInstOr(build, vmRegOp(inst.a), vmRegOp(inst.b), vmRegOp(inst.c)); - break; - case IrCmd::ORK: - regs.assertAllFree(); - emitInstOrK(build, vmRegOp(inst.a), vmRegOp(inst.b), vmConstOp(inst.c)); - break; case IrCmd::COVERAGE: regs.assertAllFree(); emitInstCoverage(build, uintOp(inst.a)); @@ -1066,6 +1064,15 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) regs.freeLastUseRegs(inst, index); } +bool IrLoweringX64::hasError() const +{ + // If register allocator had to use more stack slots than we have available, this function can't run natively + if (regs.maxUsedSlot > kSpillSlots) + return true; + + return false; +} + bool IrLoweringX64::isFallthroughBlock(IrBlock target, IrBlock next) { return target.start == next.start; @@ -1077,7 +1084,7 @@ void IrLoweringX64::jumpOrFallthrough(IrBlock& target, IrBlock& next) build.jmp(target.label); } -OperandX64 IrLoweringX64::memRegDoubleOp(IrOp op) const +OperandX64 IrLoweringX64::memRegDoubleOp(IrOp op) { switch (op.kind) { @@ -1096,7 +1103,7 @@ OperandX64 IrLoweringX64::memRegDoubleOp(IrOp op) const return noreg; } -OperandX64 IrLoweringX64::memRegTagOp(IrOp op) const +OperandX64 IrLoweringX64::memRegTagOp(IrOp op) { switch (op.kind) { @@ -1113,9 +1120,13 @@ OperandX64 IrLoweringX64::memRegTagOp(IrOp op) const return noreg; } -RegisterX64 IrLoweringX64::regOp(IrOp op) const +RegisterX64 IrLoweringX64::regOp(IrOp op) { IrInst& inst = function.instOp(op); + + if (inst.spilled) + regs.restore(inst, false); + LUAU_ASSERT(inst.regX64 != noreg); return inst.regX64; } diff --git a/CodeGen/src/IrLoweringX64.h b/CodeGen/src/IrLoweringX64.h index ecaa6a1d5..42d262775 100644 --- a/CodeGen/src/IrLoweringX64.h +++ b/CodeGen/src/IrLoweringX64.h @@ -27,13 +27,17 @@ struct IrLoweringX64 void lowerInst(IrInst& inst, uint32_t index, IrBlock& next); + bool hasError() const; + bool isFallthroughBlock(IrBlock target, IrBlock next); void jumpOrFallthrough(IrBlock& target, IrBlock& next); + void storeDoubleAsFloat(OperandX64 dst, IrOp src); + // Operand data lookup helpers - OperandX64 memRegDoubleOp(IrOp op) const; - OperandX64 memRegTagOp(IrOp op) const; - RegisterX64 regOp(IrOp op) const; + OperandX64 memRegDoubleOp(IrOp op); + OperandX64 memRegTagOp(IrOp op); + RegisterX64 regOp(IrOp op); IrConst constOp(IrOp op) const; uint8_t tagOp(IrOp op) const; diff --git a/CodeGen/src/IrRegAllocA64.cpp b/CodeGen/src/IrRegAllocA64.cpp index 3609c8e25..c6db9e9e0 100644 --- a/CodeGen/src/IrRegAllocA64.cpp +++ b/CodeGen/src/IrRegAllocA64.cpp @@ -55,7 +55,7 @@ RegisterA64 IrRegAllocA64::allocReg(KindA64 kind) int index = setBit(set.free); set.free &= ~(1u << index); - return {kind, uint8_t(index)}; + return RegisterA64{kind, uint8_t(index)}; } RegisterA64 IrRegAllocA64::allocTemp(KindA64 kind) @@ -73,7 +73,7 @@ RegisterA64 IrRegAllocA64::allocTemp(KindA64 kind) set.free &= ~(1u << index); set.temp |= 1u << index; - return {kind, uint8_t(index)}; + return RegisterA64{kind, uint8_t(index)}; } RegisterA64 IrRegAllocA64::allocReuse(KindA64 kind, uint32_t index, std::initializer_list oprefs) @@ -151,6 +151,15 @@ void IrRegAllocA64::assertAllFree() const LUAU_ASSERT(simd.free == simd.base); } +void IrRegAllocA64::assertAllFreeExcept(RegisterA64 reg) const +{ + const Set& set = const_cast(this)->getSet(reg.kind); + const Set& other = &set == &gpr ? simd : gpr; + + LUAU_ASSERT(set.free == (set.base & ~(1u << reg.index))); + LUAU_ASSERT(other.free == other.base); +} + IrRegAllocA64::Set& IrRegAllocA64::getSet(KindA64 kind) { switch (kind) diff --git a/CodeGen/src/IrRegAllocA64.h b/CodeGen/src/IrRegAllocA64.h index 2ed0787aa..9ff035528 100644 --- a/CodeGen/src/IrRegAllocA64.h +++ b/CodeGen/src/IrRegAllocA64.h @@ -30,6 +30,7 @@ struct IrRegAllocA64 void freeTempRegs(); void assertAllFree() const; + void assertAllFreeExcept(RegisterA64 reg) const; IrFunction& function; diff --git a/CodeGen/src/IrRegAllocX64.cpp b/CodeGen/src/IrRegAllocX64.cpp index eeb6cfe69..dc9e7f908 100644 --- a/CodeGen/src/IrRegAllocX64.cpp +++ b/CodeGen/src/IrRegAllocX64.cpp @@ -1,6 +1,8 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/IrRegAllocX64.h" +#include "EmitCommonX64.h" + namespace Luau { namespace CodeGen @@ -10,14 +12,22 @@ namespace X64 static const RegisterX64 kGprAllocOrder[] = {rax, rdx, rcx, rbx, rsi, rdi, r8, r9, r10, r11}; -IrRegAllocX64::IrRegAllocX64(IrFunction& function) - : function(function) +static bool isFullTvalueOperand(IrCmd cmd) +{ + return cmd == IrCmd::LOAD_TVALUE || cmd == IrCmd::LOAD_NODE_VALUE_TV; +} + +IrRegAllocX64::IrRegAllocX64(AssemblyBuilderX64& build, IrFunction& function) + : build(build) + , function(function) { freeGprMap.fill(true); + gprInstUsers.fill(kInvalidInstIdx); freeXmmMap.fill(true); + xmmInstUsers.fill(kInvalidInstIdx); } -RegisterX64 IrRegAllocX64::allocGprReg(SizeX64 preferredSize) +RegisterX64 IrRegAllocX64::allocGprReg(SizeX64 preferredSize, uint32_t instIdx) { LUAU_ASSERT( preferredSize == SizeX64::byte || preferredSize == SizeX64::word || preferredSize == SizeX64::dword || preferredSize == SizeX64::qword); @@ -27,30 +37,40 @@ RegisterX64 IrRegAllocX64::allocGprReg(SizeX64 preferredSize) if (freeGprMap[reg.index]) { freeGprMap[reg.index] = false; + gprInstUsers[reg.index] = instIdx; return RegisterX64{preferredSize, reg.index}; } } + // If possible, spill the value with the furthest next use + if (uint32_t furthestUseTarget = findInstructionWithFurthestNextUse(gprInstUsers); furthestUseTarget != kInvalidInstIdx) + return takeReg(function.instructions[furthestUseTarget].regX64, instIdx); + LUAU_ASSERT(!"Out of GPR registers to allocate"); return noreg; } -RegisterX64 IrRegAllocX64::allocXmmReg() +RegisterX64 IrRegAllocX64::allocXmmReg(uint32_t instIdx) { for (size_t i = 0; i < freeXmmMap.size(); ++i) { if (freeXmmMap[i]) { freeXmmMap[i] = false; + xmmInstUsers[i] = instIdx; return RegisterX64{SizeX64::xmmword, uint8_t(i)}; } } + // Out of registers, spill the value with the furthest next use + if (uint32_t furthestUseTarget = findInstructionWithFurthestNextUse(xmmInstUsers); furthestUseTarget != kInvalidInstIdx) + return takeReg(function.instructions[furthestUseTarget].regX64, instIdx); + LUAU_ASSERT(!"Out of XMM registers to allocate"); return noreg; } -RegisterX64 IrRegAllocX64::allocGprRegOrReuse(SizeX64 preferredSize, uint32_t index, std::initializer_list oprefs) +RegisterX64 IrRegAllocX64::allocGprRegOrReuse(SizeX64 preferredSize, uint32_t instIdx, std::initializer_list oprefs) { for (IrOp op : oprefs) { @@ -59,20 +79,21 @@ RegisterX64 IrRegAllocX64::allocGprRegOrReuse(SizeX64 preferredSize, uint32_t in IrInst& source = function.instructions[op.index]; - if (source.lastUse == index && !source.reusedReg) + if (source.lastUse == instIdx && !source.reusedReg && !source.spilled) { LUAU_ASSERT(source.regX64.size != SizeX64::xmmword); LUAU_ASSERT(source.regX64 != noreg); source.reusedReg = true; + gprInstUsers[source.regX64.index] = instIdx; return RegisterX64{preferredSize, source.regX64.index}; } } - return allocGprReg(preferredSize); + return allocGprReg(preferredSize, instIdx); } -RegisterX64 IrRegAllocX64::allocXmmRegOrReuse(uint32_t index, std::initializer_list oprefs) +RegisterX64 IrRegAllocX64::allocXmmRegOrReuse(uint32_t instIdx, std::initializer_list oprefs) { for (IrOp op : oprefs) { @@ -81,32 +102,45 @@ RegisterX64 IrRegAllocX64::allocXmmRegOrReuse(uint32_t index, std::initializer_l IrInst& source = function.instructions[op.index]; - if (source.lastUse == index && !source.reusedReg) + if (source.lastUse == instIdx && !source.reusedReg && !source.spilled) { LUAU_ASSERT(source.regX64.size == SizeX64::xmmword); LUAU_ASSERT(source.regX64 != noreg); source.reusedReg = true; + xmmInstUsers[source.regX64.index] = instIdx; return source.regX64; } } - return allocXmmReg(); + return allocXmmReg(instIdx); } -RegisterX64 IrRegAllocX64::takeReg(RegisterX64 reg) +RegisterX64 IrRegAllocX64::takeReg(RegisterX64 reg, uint32_t instIdx) { - // In a more advanced register allocator, this would require a spill for the current register user - // But at the current stage we don't have register live ranges intersecting forced register uses if (reg.size == SizeX64::xmmword) { + if (!freeXmmMap[reg.index]) + { + LUAU_ASSERT(xmmInstUsers[reg.index] != kInvalidInstIdx); + preserve(function.instructions[xmmInstUsers[reg.index]]); + } + LUAU_ASSERT(freeXmmMap[reg.index]); freeXmmMap[reg.index] = false; + xmmInstUsers[reg.index] = instIdx; } else { + if (!freeGprMap[reg.index]) + { + LUAU_ASSERT(gprInstUsers[reg.index] != kInvalidInstIdx); + preserve(function.instructions[gprInstUsers[reg.index]]); + } + LUAU_ASSERT(freeGprMap[reg.index]); freeGprMap[reg.index] = false; + gprInstUsers[reg.index] = instIdx; } return reg; @@ -118,17 +152,19 @@ void IrRegAllocX64::freeReg(RegisterX64 reg) { LUAU_ASSERT(!freeXmmMap[reg.index]); freeXmmMap[reg.index] = true; + xmmInstUsers[reg.index] = kInvalidInstIdx; } else { LUAU_ASSERT(!freeGprMap[reg.index]); freeGprMap[reg.index] = true; + gprInstUsers[reg.index] = kInvalidInstIdx; } } -void IrRegAllocX64::freeLastUseReg(IrInst& target, uint32_t index) +void IrRegAllocX64::freeLastUseReg(IrInst& target, uint32_t instIdx) { - if (isLastUseReg(target, index)) + if (isLastUseReg(target, instIdx)) { // Register might have already been freed if it had multiple uses inside a single instruction if (target.regX64 == noreg) @@ -139,11 +175,11 @@ void IrRegAllocX64::freeLastUseReg(IrInst& target, uint32_t index) } } -void IrRegAllocX64::freeLastUseRegs(const IrInst& inst, uint32_t index) +void IrRegAllocX64::freeLastUseRegs(const IrInst& inst, uint32_t instIdx) { - auto checkOp = [this, index](IrOp op) { + auto checkOp = [this, instIdx](IrOp op) { if (op.kind == IrOpKind::Inst) - freeLastUseReg(function.instructions[op.index], index); + freeLastUseReg(function.instructions[op.index], instIdx); }; checkOp(inst.a); @@ -154,9 +190,132 @@ void IrRegAllocX64::freeLastUseRegs(const IrInst& inst, uint32_t index) checkOp(inst.f); } -bool IrRegAllocX64::isLastUseReg(const IrInst& target, uint32_t index) const +bool IrRegAllocX64::isLastUseReg(const IrInst& target, uint32_t instIdx) const +{ + return target.lastUse == instIdx && !target.reusedReg; +} + +void IrRegAllocX64::preserve(IrInst& inst) +{ + bool doubleSlot = isFullTvalueOperand(inst.cmd); + + // Find a free stack slot. Two consecutive slots might be required for 16 byte TValues, so '- 1' is used + for (unsigned i = 0; i < unsigned(usedSpillSlots.size() - 1); ++i) + { + if (usedSpillSlots.test(i)) + continue; + + if (doubleSlot && usedSpillSlots.test(i + 1)) + { + ++i; // No need to retest this double position + continue; + } + + if (inst.regX64.size == SizeX64::xmmword && doubleSlot) + { + build.vmovups(xmmword[sSpillArea + i * 8], inst.regX64); + } + else if (inst.regX64.size == SizeX64::xmmword) + { + build.vmovsd(qword[sSpillArea + i * 8], inst.regX64); + } + else + { + OperandX64 location = addr[sSpillArea + i * 8]; + location.memSize = inst.regX64.size; // Override memory access size + build.mov(location, inst.regX64); + } + + usedSpillSlots.set(i); + + if (i + 1 > maxUsedSlot) + maxUsedSlot = i + 1; + + if (doubleSlot) + { + usedSpillSlots.set(i + 1); + + if (i + 2 > maxUsedSlot) + maxUsedSlot = i + 2; + } + + IrSpillX64 spill; + spill.instIdx = function.getInstIndex(inst); + spill.useDoubleSlot = doubleSlot; + spill.stackSlot = uint8_t(i); + spill.originalLoc = inst.regX64; + + spills.push_back(spill); + + freeReg(inst.regX64); + + inst.regX64 = noreg; + inst.spilled = true; + return; + } + + LUAU_ASSERT(!"nowhere to spill"); +} + +void IrRegAllocX64::restore(IrInst& inst, bool intoOriginalLocation) +{ + uint32_t instIdx = function.getInstIndex(inst); + + for (size_t i = 0; i < spills.size(); i++) + { + const IrSpillX64& spill = spills[i]; + + if (spill.instIdx == instIdx) + { + LUAU_ASSERT(spill.stackSlot != kNoStackSlot); + RegisterX64 reg; + + if (spill.originalLoc.size == SizeX64::xmmword) + { + reg = intoOriginalLocation ? takeReg(spill.originalLoc, instIdx) : allocXmmReg(instIdx); + + if (spill.useDoubleSlot) + build.vmovups(reg, xmmword[sSpillArea + spill.stackSlot * 8]); + else + build.vmovsd(reg, qword[sSpillArea + spill.stackSlot * 8]); + } + else + { + reg = intoOriginalLocation ? takeReg(spill.originalLoc, instIdx) : allocGprReg(spill.originalLoc.size, instIdx); + + OperandX64 location = addr[sSpillArea + spill.stackSlot * 8]; + location.memSize = reg.size; // Override memory access size + build.mov(reg, location); + } + + inst.regX64 = reg; + inst.spilled = false; + + usedSpillSlots.set(spill.stackSlot, false); + + if (spill.useDoubleSlot) + usedSpillSlots.set(spill.stackSlot + 1, false); + + spills[i] = spills.back(); + spills.pop_back(); + return; + } + } +} + +void IrRegAllocX64::preserveAndFreeInstValues() { - return target.lastUse == index && !target.reusedReg; + for (uint32_t instIdx : gprInstUsers) + { + if (instIdx != kInvalidInstIdx) + preserve(function.instructions[instIdx]); + } + + for (uint32_t instIdx : xmmInstUsers) + { + if (instIdx != kInvalidInstIdx) + preserve(function.instructions[instIdx]); + } } bool IrRegAllocX64::shouldFreeGpr(RegisterX64 reg) const @@ -175,6 +334,33 @@ bool IrRegAllocX64::shouldFreeGpr(RegisterX64 reg) const return false; } +uint32_t IrRegAllocX64::findInstructionWithFurthestNextUse(const std::array& regInstUsers) const +{ + uint32_t furthestUseTarget = kInvalidInstIdx; + uint32_t furthestUseLocation = 0; + + for (uint32_t regInstUser : regInstUsers) + { + // Cannot spill temporary registers or the register of the value that's defined in the current instruction + if (regInstUser == kInvalidInstIdx || regInstUser == currInstIdx) + continue; + + uint32_t nextUse = getNextInstUse(function, regInstUser, currInstIdx); + + // Cannot spill value that is about to be used in the current instruction + if (nextUse == currInstIdx) + continue; + + if (furthestUseTarget == kInvalidInstIdx || nextUse > furthestUseLocation) + { + furthestUseLocation = nextUse; + furthestUseTarget = regInstUser; + } + } + + return furthestUseTarget; +} + void IrRegAllocX64::assertFree(RegisterX64 reg) const { if (reg.size == SizeX64::xmmword) @@ -192,6 +378,11 @@ void IrRegAllocX64::assertAllFree() const LUAU_ASSERT(free); } +void IrRegAllocX64::assertNoSpills() const +{ + LUAU_ASSERT(spills.empty()); +} + ScopedRegX64::ScopedRegX64(IrRegAllocX64& owner) : owner(owner) , reg(noreg) @@ -222,9 +413,9 @@ void ScopedRegX64::alloc(SizeX64 size) LUAU_ASSERT(reg == noreg); if (size == SizeX64::xmmword) - reg = owner.allocXmmReg(); + reg = owner.allocXmmReg(kInvalidInstIdx); else - reg = owner.allocGprReg(size); + reg = owner.allocGprReg(size, kInvalidInstIdx); } void ScopedRegX64::free() @@ -241,6 +432,41 @@ RegisterX64 ScopedRegX64::release() return tmp; } +ScopedSpills::ScopedSpills(IrRegAllocX64& owner) + : owner(owner) +{ + snapshot = owner.spills; +} + +ScopedSpills::~ScopedSpills() +{ + // Taking a copy of current spills because we are going to potentially restore them + std::vector current = owner.spills; + + // Restore registers that were spilled inside scope protected by this object + for (IrSpillX64& curr : current) + { + // If spill existed before current scope, it can be restored outside of it + if (!wasSpilledBefore(curr)) + { + IrInst& inst = owner.function.instructions[curr.instIdx]; + + owner.restore(inst, /*intoOriginalLocation*/ true); + } + } +} + +bool ScopedSpills::wasSpilledBefore(const IrSpillX64& spill) const +{ + for (const IrSpillX64& preexisting : snapshot) + { + if (spill.instIdx == preexisting.instIdx) + return true; + } + + return false; +} + } // namespace X64 } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/IrTranslateBuiltins.cpp b/CodeGen/src/IrTranslateBuiltins.cpp index 2955aaffb..ba4915645 100644 --- a/CodeGen/src/IrTranslateBuiltins.cpp +++ b/CodeGen/src/IrTranslateBuiltins.cpp @@ -61,7 +61,8 @@ BuiltinImplResult translateBuiltinNumberTo2Number( if (ra != arg) build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); - build.inst(IrCmd::STORE_TAG, build.vmReg(ra + 1), build.constTag(LUA_TNUMBER)); + if (nresults > 1) + build.inst(IrCmd::STORE_TAG, build.vmReg(ra + 1), build.constTag(LUA_TNUMBER)); return {BuiltinImplType::UsesFallback, 2}; } @@ -190,10 +191,10 @@ BuiltinImplResult translateBuiltinMathClamp(IrBuilder& build, int nparams, int r build.loadAndCheckTag(build.vmReg(arg), LUA_TNUMBER, fallback); build.loadAndCheckTag(args, LUA_TNUMBER, fallback); - build.loadAndCheckTag(build.vmReg(args.index + 1), LUA_TNUMBER, fallback); + build.loadAndCheckTag(build.vmReg(vmRegOp(args) + 1), LUA_TNUMBER, fallback); IrOp min = build.inst(IrCmd::LOAD_DOUBLE, args); - IrOp max = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(args.index + 1)); + IrOp max = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(vmRegOp(args) + 1)); build.inst(IrCmd::JUMP_CMP_NUM, min, max, build.cond(IrCondition::NotLessEqual), fallback, block); build.beginBlock(block); @@ -274,6 +275,27 @@ BuiltinImplResult translateBuiltinTypeof(IrBuilder& build, int nparams, int ra, return {BuiltinImplType::UsesFallback, 1}; } +BuiltinImplResult translateBuiltinVector(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) +{ + if (nparams < 3 || nresults > 1) + return {BuiltinImplType::None, -1}; + + LUAU_ASSERT(LUA_VECTOR_SIZE == 3); + + build.loadAndCheckTag(build.vmReg(arg), LUA_TNUMBER, fallback); + build.loadAndCheckTag(args, LUA_TNUMBER, fallback); + build.loadAndCheckTag(build.vmReg(vmRegOp(args) + 1), LUA_TNUMBER, fallback); + + IrOp x = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(arg)); + IrOp y = build.inst(IrCmd::LOAD_DOUBLE, args); + IrOp z = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(vmRegOp(args) + 1)); + + build.inst(IrCmd::STORE_VECTOR, build.vmReg(ra), x, y, z); + build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TVECTOR)); + + return {BuiltinImplType::UsesFallback, 1}; +} + BuiltinImplResult translateBuiltin(IrBuilder& build, int bfid, int ra, int arg, IrOp args, int nparams, int nresults, IrOp fallback) { // Builtins are not allowed to handle variadic arguments @@ -332,6 +354,8 @@ BuiltinImplResult translateBuiltin(IrBuilder& build, int bfid, int ra, int arg, return translateBuiltinType(build, nparams, ra, arg, args, nresults, fallback); case LBF_TYPEOF: return translateBuiltinTypeof(build, nparams, ra, arg, args, nresults, fallback); + case LBF_VECTOR: + return translateBuiltinVector(build, nparams, ra, arg, args, nresults, fallback); default: return {BuiltinImplType::None, -1}; } diff --git a/CodeGen/src/IrTranslation.cpp b/CodeGen/src/IrTranslation.cpp index e366888e6..a985318b9 100644 --- a/CodeGen/src/IrTranslation.cpp +++ b/CodeGen/src/IrTranslation.cpp @@ -301,7 +301,7 @@ static void translateInstBinaryNumeric(IrBuilder& build, int ra, int rb, int rc, if (opc.kind == IrOpKind::VmConst) { LUAU_ASSERT(build.function.proto); - TValue protok = build.function.proto->k[opc.index]; + TValue protok = build.function.proto->k[vmConstOp(opc)]; LUAU_ASSERT(protok.tt == LUA_TNUMBER); @@ -1108,5 +1108,71 @@ void translateInstNamecall(IrBuilder& build, const Instruction* pc, int pcpos) build.beginBlock(next); } +void translateInstAndX(IrBuilder& build, const Instruction* pc, int pcpos, IrOp c) +{ + int ra = LUAU_INSN_A(*pc); + int rb = LUAU_INSN_B(*pc); + + IrOp fallthrough = build.block(IrBlockKind::Internal); + IrOp next = build.blockAtInst(pcpos + 1); + + IrOp target = (ra == rb) ? next : build.block(IrBlockKind::Internal); + + build.inst(IrCmd::JUMP_IF_FALSY, build.vmReg(rb), target, fallthrough); + build.beginBlock(fallthrough); + + IrOp load = build.inst(IrCmd::LOAD_TVALUE, c); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(ra), load); + build.inst(IrCmd::JUMP, next); + + if (ra == rb) + { + build.beginBlock(next); + } + else + { + build.beginBlock(target); + + IrOp load1 = build.inst(IrCmd::LOAD_TVALUE, build.vmReg(rb)); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(ra), load1); + build.inst(IrCmd::JUMP, next); + + build.beginBlock(next); + } +} + +void translateInstOrX(IrBuilder& build, const Instruction* pc, int pcpos, IrOp c) +{ + int ra = LUAU_INSN_A(*pc); + int rb = LUAU_INSN_B(*pc); + + IrOp fallthrough = build.block(IrBlockKind::Internal); + IrOp next = build.blockAtInst(pcpos + 1); + + IrOp target = (ra == rb) ? next : build.block(IrBlockKind::Internal); + + build.inst(IrCmd::JUMP_IF_TRUTHY, build.vmReg(rb), target, fallthrough); + build.beginBlock(fallthrough); + + IrOp load = build.inst(IrCmd::LOAD_TVALUE, c); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(ra), load); + build.inst(IrCmd::JUMP, next); + + if (ra == rb) + { + build.beginBlock(next); + } + else + { + build.beginBlock(target); + + IrOp load1 = build.inst(IrCmd::LOAD_TVALUE, build.vmReg(rb)); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(ra), load1); + build.inst(IrCmd::JUMP, next); + + build.beginBlock(next); + } +} + } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/IrTranslation.h b/CodeGen/src/IrTranslation.h index 0be111dca..87a530b50 100644 --- a/CodeGen/src/IrTranslation.h +++ b/CodeGen/src/IrTranslation.h @@ -61,6 +61,8 @@ void translateInstSetGlobal(IrBuilder& build, const Instruction* pc, int pcpos); void translateInstConcat(IrBuilder& build, const Instruction* pc, int pcpos); void translateInstCapture(IrBuilder& build, const Instruction* pc, int pcpos); void translateInstNamecall(IrBuilder& build, const Instruction* pc, int pcpos); +void translateInstAndX(IrBuilder& build, const Instruction* pc, int pcpos, IrOp c); +void translateInstOrX(IrBuilder& build, const Instruction* pc, int pcpos, IrOp c); } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/IrUtils.cpp b/CodeGen/src/IrUtils.cpp index 45e2bae09..c5e7c887a 100644 --- a/CodeGen/src/IrUtils.cpp +++ b/CodeGen/src/IrUtils.cpp @@ -14,6 +14,134 @@ namespace Luau namespace CodeGen { +IrValueKind getCmdValueKind(IrCmd cmd) +{ + switch (cmd) + { + case IrCmd::NOP: + return IrValueKind::None; + case IrCmd::LOAD_TAG: + return IrValueKind::Tag; + case IrCmd::LOAD_POINTER: + return IrValueKind::Pointer; + case IrCmd::LOAD_DOUBLE: + return IrValueKind::Double; + case IrCmd::LOAD_INT: + return IrValueKind::Int; + case IrCmd::LOAD_TVALUE: + case IrCmd::LOAD_NODE_VALUE_TV: + return IrValueKind::Tvalue; + case IrCmd::LOAD_ENV: + case IrCmd::GET_ARR_ADDR: + case IrCmd::GET_SLOT_NODE_ADDR: + case IrCmd::GET_HASH_NODE_ADDR: + return IrValueKind::Pointer; + case IrCmd::STORE_TAG: + case IrCmd::STORE_POINTER: + case IrCmd::STORE_DOUBLE: + case IrCmd::STORE_INT: + case IrCmd::STORE_VECTOR: + case IrCmd::STORE_TVALUE: + case IrCmd::STORE_NODE_VALUE_TV: + return IrValueKind::None; + case IrCmd::ADD_INT: + case IrCmd::SUB_INT: + return IrValueKind::Int; + case IrCmd::ADD_NUM: + case IrCmd::SUB_NUM: + case IrCmd::MUL_NUM: + case IrCmd::DIV_NUM: + case IrCmd::MOD_NUM: + case IrCmd::POW_NUM: + case IrCmd::MIN_NUM: + case IrCmd::MAX_NUM: + case IrCmd::UNM_NUM: + case IrCmd::FLOOR_NUM: + case IrCmd::CEIL_NUM: + case IrCmd::ROUND_NUM: + case IrCmd::SQRT_NUM: + case IrCmd::ABS_NUM: + return IrValueKind::Double; + case IrCmd::NOT_ANY: + return IrValueKind::Int; + case IrCmd::JUMP: + case IrCmd::JUMP_IF_TRUTHY: + case IrCmd::JUMP_IF_FALSY: + case IrCmd::JUMP_EQ_TAG: + case IrCmd::JUMP_EQ_INT: + case IrCmd::JUMP_EQ_POINTER: + case IrCmd::JUMP_CMP_NUM: + case IrCmd::JUMP_CMP_ANY: + case IrCmd::JUMP_SLOT_MATCH: + return IrValueKind::None; + case IrCmd::TABLE_LEN: + return IrValueKind::Double; + case IrCmd::NEW_TABLE: + case IrCmd::DUP_TABLE: + return IrValueKind::Pointer; + case IrCmd::TRY_NUM_TO_INDEX: + return IrValueKind::Int; + case IrCmd::TRY_CALL_FASTGETTM: + return IrValueKind::Pointer; + case IrCmd::INT_TO_NUM: + return IrValueKind::Double; + case IrCmd::ADJUST_STACK_TO_REG: + case IrCmd::ADJUST_STACK_TO_TOP: + return IrValueKind::None; + case IrCmd::FASTCALL: + return IrValueKind::None; + case IrCmd::INVOKE_FASTCALL: + return IrValueKind::Int; + case IrCmd::CHECK_FASTCALL_RES: + case IrCmd::DO_ARITH: + case IrCmd::DO_LEN: + case IrCmd::GET_TABLE: + case IrCmd::SET_TABLE: + case IrCmd::GET_IMPORT: + case IrCmd::CONCAT: + case IrCmd::GET_UPVALUE: + case IrCmd::SET_UPVALUE: + case IrCmd::PREPARE_FORN: + case IrCmd::CHECK_TAG: + case IrCmd::CHECK_READONLY: + case IrCmd::CHECK_NO_METATABLE: + case IrCmd::CHECK_SAFE_ENV: + case IrCmd::CHECK_ARRAY_SIZE: + case IrCmd::CHECK_SLOT_MATCH: + case IrCmd::CHECK_NODE_NO_NEXT: + case IrCmd::INTERRUPT: + case IrCmd::CHECK_GC: + case IrCmd::BARRIER_OBJ: + case IrCmd::BARRIER_TABLE_BACK: + case IrCmd::BARRIER_TABLE_FORWARD: + case IrCmd::SET_SAVEDPC: + case IrCmd::CLOSE_UPVALS: + case IrCmd::CAPTURE: + case IrCmd::SETLIST: + case IrCmd::CALL: + case IrCmd::RETURN: + case IrCmd::FORGLOOP: + case IrCmd::FORGLOOP_FALLBACK: + case IrCmd::FORGPREP_XNEXT_FALLBACK: + case IrCmd::COVERAGE: + case IrCmd::FALLBACK_GETGLOBAL: + case IrCmd::FALLBACK_SETGLOBAL: + case IrCmd::FALLBACK_GETTABLEKS: + case IrCmd::FALLBACK_SETTABLEKS: + case IrCmd::FALLBACK_NAMECALL: + case IrCmd::FALLBACK_PREPVARARGS: + case IrCmd::FALLBACK_GETVARARGS: + case IrCmd::FALLBACK_NEWCLOSURE: + case IrCmd::FALLBACK_DUPCLOSURE: + case IrCmd::FALLBACK_FORGPREP: + return IrValueKind::None; + case IrCmd::SUBSTITUTE: + return IrValueKind::Unknown; + } + + LUAU_UNREACHABLE(); +} + static void removeInstUse(IrFunction& function, uint32_t instIdx) { IrInst& inst = function.instructions[instIdx]; diff --git a/CodeGen/src/NativeState.cpp b/CodeGen/src/NativeState.cpp index ddc9c03d1..524796929 100644 --- a/CodeGen/src/NativeState.cpp +++ b/CodeGen/src/NativeState.cpp @@ -45,6 +45,7 @@ void initFallbackTable(NativeState& data) CODEGEN_SET_FALLBACK(LOP_BREAK, 0); // Fallbacks that are called from partial implementation of an instruction + // TODO: these fallbacks should be replaced with special functions that exclude the (redundantly executed) fast path from the fallback CODEGEN_SET_FALLBACK(LOP_GETGLOBAL, 0); CODEGEN_SET_FALLBACK(LOP_SETGLOBAL, 0); CODEGEN_SET_FALLBACK(LOP_GETTABLEKS, 0); diff --git a/CodeGen/src/OptimizeConstProp.cpp b/CodeGen/src/OptimizeConstProp.cpp index f767f5496..7157a18c4 100644 --- a/CodeGen/src/OptimizeConstProp.cpp +++ b/CodeGen/src/OptimizeConstProp.cpp @@ -96,20 +96,17 @@ struct ConstPropState void invalidateTag(IrOp regOp) { - LUAU_ASSERT(regOp.kind == IrOpKind::VmReg); - invalidate(regs[regOp.index], /* invalidateTag */ true, /* invalidateValue */ false); + invalidate(regs[vmRegOp(regOp)], /* invalidateTag */ true, /* invalidateValue */ false); } void invalidateValue(IrOp regOp) { - LUAU_ASSERT(regOp.kind == IrOpKind::VmReg); - invalidate(regs[regOp.index], /* invalidateTag */ false, /* invalidateValue */ true); + invalidate(regs[vmRegOp(regOp)], /* invalidateTag */ false, /* invalidateValue */ true); } void invalidate(IrOp regOp) { - LUAU_ASSERT(regOp.kind == IrOpKind::VmReg); - invalidate(regs[regOp.index], /* invalidateTag */ true, /* invalidateValue */ true); + invalidate(regs[vmRegOp(regOp)], /* invalidateTag */ true, /* invalidateValue */ true); } void invalidateRegistersFrom(int firstReg) @@ -156,17 +153,16 @@ struct ConstPropState void createRegLink(uint32_t instIdx, IrOp regOp) { - LUAU_ASSERT(regOp.kind == IrOpKind::VmReg); LUAU_ASSERT(!instLink.contains(instIdx)); - instLink[instIdx] = RegisterLink{uint8_t(regOp.index), regs[regOp.index].version}; + instLink[instIdx] = RegisterLink{uint8_t(vmRegOp(regOp)), regs[vmRegOp(regOp)].version}; } RegisterInfo* tryGetRegisterInfo(IrOp op) { if (op.kind == IrOpKind::VmReg) { - maxReg = int(op.index) > maxReg ? int(op.index) : maxReg; - return ®s[op.index]; + maxReg = vmRegOp(op) > maxReg ? vmRegOp(op) : maxReg; + return ®s[vmRegOp(op)]; } if (RegisterLink* link = tryGetRegLink(op)) @@ -368,6 +364,9 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& } } break; + case IrCmd::STORE_VECTOR: + state.invalidateValue(inst.a); + break; case IrCmd::STORE_TVALUE: if (inst.a.kind == IrOpKind::VmReg) { @@ -503,15 +502,9 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& } } break; - case IrCmd::AND: - case IrCmd::ANDK: - case IrCmd::OR: - case IrCmd::ORK: - state.invalidate(inst.a); - break; case IrCmd::FASTCALL: case IrCmd::INVOKE_FASTCALL: - handleBuiltinEffects(state, LuauBuiltinFunction(function.uintOp(inst.a)), inst.b.index, function.intOp(inst.f)); + handleBuiltinEffects(state, LuauBuiltinFunction(function.uintOp(inst.a)), vmRegOp(inst.b), function.intOp(inst.f)); break; // These instructions don't have an effect on register/memory state we are tracking @@ -590,7 +583,7 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& state.invalidateUserCall(); break; case IrCmd::CONCAT: - state.invalidateRegisterRange(inst.a.index, function.uintOp(inst.b)); + state.invalidateRegisterRange(vmRegOp(inst.a), function.uintOp(inst.b)); state.invalidateUserCall(); // TODO: if only strings and numbers are concatenated, there will be no user calls break; case IrCmd::PREPARE_FORN: @@ -605,14 +598,14 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& state.invalidateUserCall(); break; case IrCmd::CALL: - state.invalidateRegistersFrom(inst.a.index); + state.invalidateRegistersFrom(vmRegOp(inst.a)); state.invalidateUserCall(); break; case IrCmd::FORGLOOP: - state.invalidateRegistersFrom(inst.a.index + 2); // Rn and Rn+1 are not modified + state.invalidateRegistersFrom(vmRegOp(inst.a) + 2); // Rn and Rn+1 are not modified break; case IrCmd::FORGLOOP_FALLBACK: - state.invalidateRegistersFrom(inst.a.index + 2); // Rn and Rn+1 are not modified + state.invalidateRegistersFrom(vmRegOp(inst.a) + 2); // Rn and Rn+1 are not modified state.invalidateUserCall(); break; case IrCmd::FORGPREP_XNEXT_FALLBACK: @@ -633,14 +626,14 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& state.invalidateUserCall(); break; case IrCmd::FALLBACK_NAMECALL: - state.invalidate(IrOp{inst.b.kind, inst.b.index + 0u}); - state.invalidate(IrOp{inst.b.kind, inst.b.index + 1u}); + state.invalidate(IrOp{inst.b.kind, vmRegOp(inst.b) + 0u}); + state.invalidate(IrOp{inst.b.kind, vmRegOp(inst.b) + 1u}); state.invalidateUserCall(); break; case IrCmd::FALLBACK_PREPVARARGS: break; case IrCmd::FALLBACK_GETVARARGS: - state.invalidateRegistersFrom(inst.b.index); + state.invalidateRegistersFrom(vmRegOp(inst.b)); break; case IrCmd::FALLBACK_NEWCLOSURE: state.invalidate(inst.b); @@ -649,9 +642,9 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& state.invalidate(inst.b); break; case IrCmd::FALLBACK_FORGPREP: - state.invalidate(IrOp{inst.b.kind, inst.b.index + 0u}); - state.invalidate(IrOp{inst.b.kind, inst.b.index + 1u}); - state.invalidate(IrOp{inst.b.kind, inst.b.index + 2u}); + state.invalidate(IrOp{inst.b.kind, vmRegOp(inst.b) + 0u}); + state.invalidate(IrOp{inst.b.kind, vmRegOp(inst.b) + 1u}); + state.invalidate(IrOp{inst.b.kind, vmRegOp(inst.b) + 2u}); break; } } diff --git a/VM/include/lua.h b/VM/include/lua.h index 649c96c1a..f5f5059fe 100644 --- a/VM/include/lua.h +++ b/VM/include/lua.h @@ -29,7 +29,7 @@ enum lua_Status LUA_OK = 0, LUA_YIELD, LUA_ERRRUN, - LUA_ERRSYNTAX, + LUA_ERRSYNTAX, // legacy error code, preserved for compatibility LUA_ERRMEM, LUA_ERRERR, LUA_BREAK, // yielded for a debug breakpoint diff --git a/VM/src/ldo.cpp b/VM/src/ldo.cpp index ff8105b8c..264388bc9 100644 --- a/VM/src/ldo.cpp +++ b/VM/src/ldo.cpp @@ -17,6 +17,8 @@ #include +LUAU_FASTFLAGVARIABLE(LuauBetterOOMHandling, false) + /* ** {====================================================== ** Error-recovery functions @@ -79,22 +81,17 @@ class lua_exception : public std::exception const char* what() const throw() override { - // LUA_ERRRUN/LUA_ERRSYNTAX pass an object on the stack which is intended to describe the error. - if (status == LUA_ERRRUN || status == LUA_ERRSYNTAX) - { - // Conversion to a string could still fail. For example if a user passes a non-string/non-number argument to `error()`. + // LUA_ERRRUN passes error object on the stack + if (status == LUA_ERRRUN || (status == LUA_ERRSYNTAX && !FFlag::LuauBetterOOMHandling)) if (const char* str = lua_tostring(L, -1)) - { return str; - } - } switch (status) { case LUA_ERRRUN: - return "lua_exception: LUA_ERRRUN (no string/number provided as description)"; + return "lua_exception: runtime error"; case LUA_ERRSYNTAX: - return "lua_exception: LUA_ERRSYNTAX (no string/number provided as description)"; + return "lua_exception: syntax error"; case LUA_ERRMEM: return "lua_exception: " LUA_MEMERRMSG; case LUA_ERRERR: @@ -550,19 +547,42 @@ int luaD_pcall(lua_State* L, Pfunc func, void* u, ptrdiff_t old_top, ptrdiff_t e int status = luaD_rawrunprotected(L, func, u); if (status != 0) { + int errstatus = status; + // call user-defined error function (used in xpcall) if (ef) { - // if errfunc fails, we fail with "error in error handling" - if (luaD_rawrunprotected(L, callerrfunc, restorestack(L, ef)) != 0) - status = LUA_ERRERR; + if (FFlag::LuauBetterOOMHandling) + { + // push error object to stack top if it's not already there + if (status != LUA_ERRRUN) + seterrorobj(L, status, L->top); + + // if errfunc fails, we fail with "error in error handling" or "not enough memory" + int err = luaD_rawrunprotected(L, callerrfunc, restorestack(L, ef)); + + // in general we preserve the status, except for cases when the error handler fails + // out of memory is treated specially because it's common for it to be cascading, in which case we preserve the code + if (err == 0) + errstatus = LUA_ERRRUN; + else if (status == LUA_ERRMEM && err == LUA_ERRMEM) + errstatus = LUA_ERRMEM; + else + errstatus = status = LUA_ERRERR; + } + else + { + // if errfunc fails, we fail with "error in error handling" + if (luaD_rawrunprotected(L, callerrfunc, restorestack(L, ef)) != 0) + status = LUA_ERRERR; + } } // since the call failed with an error, we might have to reset the 'active' thread state if (!oldactive) L->isactive = false; - // Restore nCcalls before calling the debugprotectederror callback which may rely on the proper value to have been restored. + // restore nCcalls before calling the debugprotectederror callback which may rely on the proper value to have been restored. L->nCcalls = oldnCcalls; // an error occurred, check if we have a protected error callback @@ -577,7 +597,7 @@ int luaD_pcall(lua_State* L, Pfunc func, void* u, ptrdiff_t old_top, ptrdiff_t e StkId oldtop = restorestack(L, old_top); luaF_close(L, oldtop); // close eventual pending closures - seterrorobj(L, status, oldtop); + seterrorobj(L, FFlag::LuauBetterOOMHandling ? errstatus : status, oldtop); L->ci = restoreci(L, old_ci); L->base = L->ci->base; restore_stack_limit(L); diff --git a/VM/src/ltablib.cpp b/VM/src/ltablib.cpp index ddee3a71e..4443be34f 100644 --- a/VM/src/ltablib.cpp +++ b/VM/src/ltablib.cpp @@ -10,7 +10,7 @@ #include "ldebug.h" #include "lvm.h" -LUAU_FASTFLAGVARIABLE(LuauOptimizedSort, false) +LUAU_FASTFLAGVARIABLE(LuauIntrosort, false) static int foreachi(lua_State* L) { @@ -298,120 +298,6 @@ static int tunpack(lua_State* L) return (int)n; } -/* -** {====================================================== -** Quicksort -** (based on `Algorithms in MODULA-3', Robert Sedgewick; -** Addison-Wesley, 1993.) -*/ - -static void set2(lua_State* L, int i, int j) -{ - LUAU_ASSERT(!FFlag::LuauOptimizedSort); - lua_rawseti(L, 1, i); - lua_rawseti(L, 1, j); -} - -static int sort_comp(lua_State* L, int a, int b) -{ - LUAU_ASSERT(!FFlag::LuauOptimizedSort); - if (!lua_isnil(L, 2)) - { // function? - int res; - lua_pushvalue(L, 2); - lua_pushvalue(L, a - 1); // -1 to compensate function - lua_pushvalue(L, b - 2); // -2 to compensate function and `a' - lua_call(L, 2, 1); - res = lua_toboolean(L, -1); - lua_pop(L, 1); - return res; - } - else // a < b? - return lua_lessthan(L, a, b); -} - -static void auxsort(lua_State* L, int l, int u) -{ - LUAU_ASSERT(!FFlag::LuauOptimizedSort); - while (l < u) - { // for tail recursion - int i, j; - // sort elements a[l], a[(l+u)/2] and a[u] - lua_rawgeti(L, 1, l); - lua_rawgeti(L, 1, u); - if (sort_comp(L, -1, -2)) // a[u] < a[l]? - set2(L, l, u); // swap a[l] - a[u] - else - lua_pop(L, 2); - if (u - l == 1) - break; // only 2 elements - i = (l + u) / 2; - lua_rawgeti(L, 1, i); - lua_rawgeti(L, 1, l); - if (sort_comp(L, -2, -1)) // a[i]= P - while (lua_rawgeti(L, 1, ++i), sort_comp(L, -1, -2)) - { - if (i >= u) - luaL_error(L, "invalid order function for sorting"); - lua_pop(L, 1); // remove a[i] - } - // repeat --j until a[j] <= P - while (lua_rawgeti(L, 1, --j), sort_comp(L, -3, -1)) - { - if (j <= l) - luaL_error(L, "invalid order function for sorting"); - lua_pop(L, 1); // remove a[j] - } - if (j < i) - { - lua_pop(L, 3); // pop pivot, a[i], a[j] - break; - } - set2(L, i, j); - } - lua_rawgeti(L, 1, u - 1); - lua_rawgeti(L, 1, i); - set2(L, u - 1, i); // swap pivot (a[u-1]) with a[i] - // a[l..i-1] <= a[i] == P <= a[i+1..u] - // adjust so that smaller half is in [j..i] and larger one in [l..u] - if (i - l < u - i) - { - j = l; - i = i - 1; - l = i + 2; - } - else - { - j = i + 1; - i = u; - u = j - 2; - } - auxsort(L, j, i); // call recursively the smaller one - } // repeat the routine for the larger one -} - typedef int (*SortPredicate)(lua_State* L, const TValue* l, const TValue* r); static int sort_func(lua_State* L, const TValue* l, const TValue* r) @@ -456,30 +342,77 @@ inline int sort_less(lua_State* L, Table* t, int i, int j, SortPredicate pred) return res; } -static void sort_rec(lua_State* L, Table* t, int l, int u, SortPredicate pred) +static void sort_siftheap(lua_State* L, Table* t, int l, int u, SortPredicate pred, int root) +{ + LUAU_ASSERT(l <= u); + int count = u - l + 1; + + // process all elements with two children + while (root * 2 + 2 < count) + { + int left = root * 2 + 1, right = root * 2 + 2; + int next = root; + next = sort_less(L, t, l + next, l + left, pred) ? left : next; + next = sort_less(L, t, l + next, l + right, pred) ? right : next; + + if (next == root) + break; + + sort_swap(L, t, l + root, l + next); + root = next; + } + + // process last element if it has just one child + int lastleft = root * 2 + 1; + if (lastleft == count - 1 && sort_less(L, t, l + root, l + lastleft, pred)) + sort_swap(L, t, l + root, l + lastleft); +} + +static void sort_heap(lua_State* L, Table* t, int l, int u, SortPredicate pred) +{ + LUAU_ASSERT(l <= u); + int count = u - l + 1; + + for (int i = count / 2 - 1; i >= 0; --i) + sort_siftheap(L, t, l, u, pred, i); + + for (int i = count - 1; i > 0; --i) + { + sort_swap(L, t, l, l + i); + sort_siftheap(L, t, l, l + i - 1, pred, 0); + } +} + +static void sort_rec(lua_State* L, Table* t, int l, int u, int limit, SortPredicate pred) { // sort range [l..u] (inclusive, 0-based) while (l < u) { - int i, j; + // if the limit has been reached, quick sort is going over the permitted nlogn complexity, so we fall back to heap sort + if (FFlag::LuauIntrosort && limit == 0) + return sort_heap(L, t, l, u, pred); + // sort elements a[l], a[(l+u)/2] and a[u] + // note: this simultaneously acts as a small sort and a median selector if (sort_less(L, t, u, l, pred)) // a[u] < a[l]? sort_swap(L, t, u, l); // swap a[l] - a[u] if (u - l == 1) break; // only 2 elements - i = l + ((u - l) >> 1); // midpoint - if (sort_less(L, t, i, l, pred)) // a[i]> 1); // midpoint + if (sort_less(L, t, m, l, pred)) // a[m]= P @@ -498,62 +431,71 @@ static void sort_rec(lua_State* L, Table* t, int l, int u, SortPredicate pred) break; sort_swap(L, t, i, j); } - // swap pivot (a[u-1]) with a[i], which is the new midpoint - sort_swap(L, t, u - 1, i); - // a[l..i-1] <= a[i] == P <= a[i+1..u] - // adjust so that smaller half is in [j..i] and larger one in [l..u] - if (i - l < u - i) + + // swap pivot a[p] with a[i], which is the new midpoint + sort_swap(L, t, p, i); + + if (FFlag::LuauIntrosort) { - j = l; - i = i - 1; - l = i + 2; + // adjust limit to allow 1.5 log2N recursive steps + limit = (limit >> 1) + (limit >> 2); + + // a[l..i-1] <= a[i] == P <= a[i+1..u] + // sort smaller half recursively; the larger half is sorted in the next loop iteration + if (i - l < u - i) + { + sort_rec(L, t, l, i - 1, limit, pred); + l = i + 1; + } + else + { + sort_rec(L, t, i + 1, u, limit, pred); + u = i - 1; + } } else { - j = i + 1; - i = u; - u = j - 2; + // a[l..i-1] <= a[i] == P <= a[i+1..u] + // adjust so that smaller half is in [j..i] and larger one in [l..u] + if (i - l < u - i) + { + j = l; + i = i - 1; + l = i + 2; + } + else + { + j = i + 1; + i = u; + u = j - 2; + } + + // sort smaller half recursively; the larger half is sorted in the next loop iteration + sort_rec(L, t, j, i, limit, pred); } - sort_rec(L, t, j, i, pred); // call recursively the smaller one - } // repeat the routine for the larger one + } } static int tsort(lua_State* L) { - if (FFlag::LuauOptimizedSort) - { - luaL_checktype(L, 1, LUA_TTABLE); - Table* t = hvalue(L->base); - int n = luaH_getn(t); - if (t->readonly) - luaG_readonlyerror(L); - - SortPredicate pred = luaV_lessthan; - if (!lua_isnoneornil(L, 2)) // is there a 2nd argument? - { - luaL_checktype(L, 2, LUA_TFUNCTION); - pred = sort_func; - } - lua_settop(L, 2); // make sure there are two arguments + luaL_checktype(L, 1, LUA_TTABLE); + Table* t = hvalue(L->base); + int n = luaH_getn(t); + if (t->readonly) + luaG_readonlyerror(L); - if (n > 0) - sort_rec(L, t, 0, n - 1, pred); - return 0; - } - else + SortPredicate pred = luaV_lessthan; + if (!lua_isnoneornil(L, 2)) // is there a 2nd argument? { - luaL_checktype(L, 1, LUA_TTABLE); - int n = lua_objlen(L, 1); - luaL_checkstack(L, 40, ""); // assume array is smaller than 2^40 - if (!lua_isnoneornil(L, 2)) // is there a 2nd argument? - luaL_checktype(L, 2, LUA_TFUNCTION); - lua_settop(L, 2); // make sure there is two arguments - auxsort(L, 1, n); - return 0; + luaL_checktype(L, 2, LUA_TFUNCTION); + pred = sort_func; } -} + lua_settop(L, 2); // make sure there are two arguments -// }====================================================== + if (n > 0) + sort_rec(L, t, 0, n - 1, n, pred); + return 0; +} static int tcreate(lua_State* L) { diff --git a/tests/AssemblyBuilderX64.test.cpp b/tests/AssemblyBuilderX64.test.cpp index 6aa7aa561..054eca7bf 100644 --- a/tests/AssemblyBuilderX64.test.cpp +++ b/tests/AssemblyBuilderX64.test.cpp @@ -507,6 +507,8 @@ TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "AVXConversionInstructionForms") SINGLE_COMPARE(vcvtsi2sd(xmm6, xmm11, dword[rcx + rdx]), 0xc4, 0xe1, 0x23, 0x2a, 0x34, 0x11); SINGLE_COMPARE(vcvtsi2sd(xmm5, xmm10, r13), 0xc4, 0xc1, 0xab, 0x2a, 0xed); SINGLE_COMPARE(vcvtsi2sd(xmm6, xmm11, qword[rcx + rdx]), 0xc4, 0xe1, 0xa3, 0x2a, 0x34, 0x11); + SINGLE_COMPARE(vcvtsd2ss(xmm5, xmm10, xmm11), 0xc4, 0xc1, 0x2b, 0x5a, 0xeb); + SINGLE_COMPARE(vcvtsd2ss(xmm6, xmm11, qword[rcx + rdx]), 0xc4, 0xe1, 0xa3, 0x5a, 0x34, 0x11); } TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "AVXTernaryInstructionForms") diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index 53dc99e15..c79bf35ea 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -85,8 +85,8 @@ struct ACFixtureImpl : BaseType { GlobalTypes& globals = this->frontend.globalsForAutocomplete; unfreeze(globals.globalTypes); - LoadDefinitionFileResult result = - loadDefinitionFile(this->frontend.typeChecker, globals, globals.globalScope, source, "@test", /* captureComments */ false); + LoadDefinitionFileResult result = this->frontend.loadDefinitionFile( + globals, globals.globalScope, source, "@test", /* captureComments */ false, /* typeCheckForAutocomplete */ true); freeze(globals.globalTypes); REQUIRE_MESSAGE(result.success, "loadDefinition: unable to load definition file"); @@ -3448,8 +3448,6 @@ TEST_CASE_FIXTURE(ACFixture, "string_contents_is_available_to_callback") TEST_CASE_FIXTURE(ACFixture, "autocomplete_response_perf1" * doctest::timeout(0.5)) { - ScopedFastFlag luauAutocompleteSkipNormalization{"LuauAutocompleteSkipNormalization", true}; - // Build a function type with a large overload set const int parts = 100; std::string source; diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index 957d32719..2a32bce2d 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -9,6 +9,7 @@ #include "Luau/StringUtils.h" #include "Luau/BytecodeBuilder.h" #include "Luau/CodeGen.h" +#include "Luau/Frontend.h" #include "doctest.h" #include "ScopedFlags.h" @@ -243,6 +244,24 @@ static StateRef runConformance(const char* name, void (*setup)(lua_State* L) = n return globalState; } +static void* limitedRealloc(void* ud, void* ptr, size_t osize, size_t nsize) +{ + if (nsize == 0) + { + free(ptr); + return nullptr; + } + else if (nsize > 8 * 1024 * 1024) + { + // For testing purposes return null for large allocations so we can generate errors related to memory allocation failures + return nullptr; + } + else + { + return realloc(ptr, nsize); + } +} + TEST_SUITE_BEGIN("Conformance"); TEST_CASE("Assert") @@ -381,6 +400,8 @@ static int cxxthrow(lua_State* L) TEST_CASE("PCall") { + ScopedFastFlag sff("LuauBetterOOMHandling", true); + runConformance("pcall.lua", [](lua_State* L) { lua_pushcfunction(L, cxxthrow, "cxxthrow"); lua_setglobal(L, "cxxthrow"); @@ -395,7 +416,7 @@ TEST_CASE("PCall") }, "resumeerror"); lua_setglobal(L, "resumeerror"); - }); + }, nullptr, lua_newstate(limitedRealloc, nullptr)); } TEST_CASE("Pack") @@ -501,17 +522,15 @@ TEST_CASE("Types") { runConformance("types.lua", [](lua_State* L) { Luau::NullModuleResolver moduleResolver; - Luau::InternalErrorReporter iceHandler; - Luau::BuiltinTypes builtinTypes; - Luau::GlobalTypes globals{Luau::NotNull{&builtinTypes}}; - Luau::TypeChecker env(globals.globalScope, &moduleResolver, Luau::NotNull{&builtinTypes}, &iceHandler); - - Luau::registerBuiltinGlobals(env, globals); - Luau::freeze(globals.globalTypes); + Luau::NullFileResolver fileResolver; + Luau::NullConfigResolver configResolver; + Luau::Frontend frontend{&fileResolver, &configResolver}; + Luau::registerBuiltinGlobals(frontend, frontend.globals); + Luau::freeze(frontend.globals.globalTypes); lua_newtable(L); - for (const auto& [name, binding] : globals.globalScope->bindings) + for (const auto& [name, binding] : frontend.globals.globalScope->bindings) { populateRTTI(L, binding.typeId); lua_setfield(L, -2, toString(name).c_str()); @@ -882,7 +901,7 @@ TEST_CASE("ApiIter") TEST_CASE("ApiCalls") { - StateRef globalState = runConformance("apicalls.lua"); + StateRef globalState = runConformance("apicalls.lua", nullptr, nullptr, lua_newstate(limitedRealloc, nullptr)); lua_State* L = globalState.get(); // lua_call @@ -981,6 +1000,55 @@ TEST_CASE("ApiCalls") CHECK(lua_tonumber(L, -1) == 4); lua_pop(L, 1); } + + ScopedFastFlag sff("LuauBetterOOMHandling", true); + + // lua_pcall on OOM + { + lua_getfield(L, LUA_GLOBALSINDEX, "largealloc"); + int res = lua_pcall(L, 0, 0, 0); + CHECK(res == LUA_ERRMEM); + } + + // lua_pcall on OOM with an error handler + { + lua_getfield(L, LUA_GLOBALSINDEX, "oops"); + lua_getfield(L, LUA_GLOBALSINDEX, "largealloc"); + int res = lua_pcall(L, 0, 1, -2); + CHECK(res == LUA_ERRMEM); + CHECK((lua_isstring(L, -1) && strcmp(lua_tostring(L, -1), "oops") == 0)); + lua_pop(L, 1); + } + + // lua_pcall on OOM with an error handler that errors + { + lua_getfield(L, LUA_GLOBALSINDEX, "error"); + lua_getfield(L, LUA_GLOBALSINDEX, "largealloc"); + int res = lua_pcall(L, 0, 1, -2); + CHECK(res == LUA_ERRERR); + CHECK((lua_isstring(L, -1) && strcmp(lua_tostring(L, -1), "error in error handling") == 0)); + lua_pop(L, 1); + } + + // lua_pcall on OOM with an error handler that OOMs + { + lua_getfield(L, LUA_GLOBALSINDEX, "largealloc"); + lua_getfield(L, LUA_GLOBALSINDEX, "largealloc"); + int res = lua_pcall(L, 0, 1, -2); + CHECK(res == LUA_ERRMEM); + CHECK((lua_isstring(L, -1) && strcmp(lua_tostring(L, -1), "not enough memory") == 0)); + lua_pop(L, 1); + } + + // lua_pcall on error with an error handler that OOMs + { + lua_getfield(L, LUA_GLOBALSINDEX, "largealloc"); + lua_getfield(L, LUA_GLOBALSINDEX, "error"); + int res = lua_pcall(L, 0, 1, -2); + CHECK(res == LUA_ERRERR); + CHECK((lua_isstring(L, -1) && strcmp(lua_tostring(L, -1), "error in error handling") == 0)); + lua_pop(L, 1); + } } TEST_CASE("ApiAtoms") @@ -1051,26 +1119,7 @@ TEST_CASE("ExceptionObject") return ExceptionResult{false, ""}; }; - auto reallocFunc = [](void* /*ud*/, void* ptr, size_t /*osize*/, size_t nsize) -> void* { - if (nsize == 0) - { - free(ptr); - return nullptr; - } - else if (nsize > 512 * 1024) - { - // For testing purposes return null for large allocations - // so we can generate exceptions related to memory allocation - // failures. - return nullptr; - } - else - { - return realloc(ptr, nsize); - } - }; - - StateRef globalState = runConformance("exceptions.lua", nullptr, nullptr, lua_newstate(reallocFunc, nullptr)); + StateRef globalState = runConformance("exceptions.lua", nullptr, nullptr, lua_newstate(limitedRealloc, nullptr)); lua_State* L = globalState.get(); { diff --git a/tests/Fixture.cpp b/tests/Fixture.cpp index 4d2e83fc2..aebf177cd 100644 --- a/tests/Fixture.cpp +++ b/tests/Fixture.cpp @@ -506,7 +506,8 @@ void Fixture::validateErrors(const std::vector& errors) LoadDefinitionFileResult Fixture::loadDefinition(const std::string& source) { unfreeze(frontend.globals.globalTypes); - LoadDefinitionFileResult result = frontend.loadDefinitionFile(source, "@test", /* captureComments */ false); + LoadDefinitionFileResult result = + frontend.loadDefinitionFile(frontend.globals, frontend.globals.globalScope, source, "@test", /* captureComments */ false); freeze(frontend.globals.globalTypes); if (result.module) @@ -521,9 +522,9 @@ BuiltinsFixture::BuiltinsFixture(bool freeze, bool prepareAutocomplete) Luau::unfreeze(frontend.globals.globalTypes); Luau::unfreeze(frontend.globalsForAutocomplete.globalTypes); - registerBuiltinGlobals(frontend); + registerBuiltinGlobals(frontend, frontend.globals); if (prepareAutocomplete) - registerBuiltinGlobals(frontend.typeCheckerForAutocomplete, frontend.globalsForAutocomplete); + registerBuiltinGlobals(frontend, frontend.globalsForAutocomplete, /*typeCheckForAutocomplete*/ true); registerTestTypes(); Luau::freeze(frontend.globals.globalTypes); @@ -594,8 +595,12 @@ void registerHiddenTypes(Frontend* frontend) TypeId t = globals.globalTypes.addType(GenericType{"T"}); GenericTypeDefinition genericT{t}; + TypeId u = globals.globalTypes.addType(GenericType{"U"}); + GenericTypeDefinition genericU{u}; + ScopePtr globalScope = globals.globalScope; globalScope->exportedTypeBindings["Not"] = TypeFun{{genericT}, globals.globalTypes.addType(NegationType{t})}; + globalScope->exportedTypeBindings["Mt"] = TypeFun{{genericT, genericU}, globals.globalTypes.addType(MetatableType{t, u})}; globalScope->exportedTypeBindings["fun"] = TypeFun{{}, frontend->builtinTypes->functionType}; globalScope->exportedTypeBindings["cls"] = TypeFun{{}, frontend->builtinTypes->classType}; globalScope->exportedTypeBindings["err"] = TypeFun{{}, frontend->builtinTypes->errorType}; diff --git a/tests/Fixture.h b/tests/Fixture.h index 4c49593cc..8d48ab1dc 100644 --- a/tests/Fixture.h +++ b/tests/Fixture.h @@ -94,7 +94,6 @@ struct Fixture TypeId requireTypeAlias(const std::string& name); ScopedFastFlag sff_DebugLuauFreezeArena; - ScopedFastFlag luauLintInTypecheck{"LuauLintInTypecheck", true}; TestFileResolver fileResolver; TestConfigResolver configResolver; diff --git a/tests/Frontend.test.cpp b/tests/Frontend.test.cpp index 3b1ec4ad1..13fd6e0f8 100644 --- a/tests/Frontend.test.cpp +++ b/tests/Frontend.test.cpp @@ -877,7 +877,7 @@ TEST_CASE_FIXTURE(FrontendFixture, "environments") ScopePtr testScope = frontend.addEnvironment("test"); unfreeze(frontend.globals.globalTypes); - loadDefinitionFile(frontend.typeChecker, frontend.globals, testScope, R"( + frontend.loadDefinitionFile(frontend.globals, testScope, R"( export type Foo = number | string )", "@test", /* captureComments */ false); diff --git a/tests/IrCallWrapperX64.test.cpp b/tests/IrCallWrapperX64.test.cpp index 8c7b1393f..c8918dbde 100644 --- a/tests/IrCallWrapperX64.test.cpp +++ b/tests/IrCallWrapperX64.test.cpp @@ -12,7 +12,7 @@ class IrCallWrapperX64Fixture public: IrCallWrapperX64Fixture() : build(/* logText */ true, ABIX64::Windows) - , regs(function) + , regs(build, function) , callWrap(regs, build, ~0u) { } @@ -46,8 +46,8 @@ TEST_SUITE_BEGIN("IrCallWrapperX64"); TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "SimpleRegs") { - ScopedRegX64 tmp1{regs, regs.takeReg(rax)}; - ScopedRegX64 tmp2{regs, regs.takeReg(rArg2)}; + ScopedRegX64 tmp1{regs, regs.takeReg(rax, kInvalidInstIdx)}; + ScopedRegX64 tmp2{regs, regs.takeReg(rArg2, kInvalidInstIdx)}; callWrap.addArgument(SizeX64::qword, tmp1); callWrap.addArgument(SizeX64::qword, tmp2); // Already in its place callWrap.call(qword[r12]); @@ -60,7 +60,7 @@ TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "SimpleRegs") TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "TrickyUse1") { - ScopedRegX64 tmp1{regs, regs.takeReg(rArg1)}; + ScopedRegX64 tmp1{regs, regs.takeReg(rArg1, kInvalidInstIdx)}; callWrap.addArgument(SizeX64::qword, tmp1.reg); // Already in its place callWrap.addArgument(SizeX64::qword, tmp1.release()); callWrap.call(qword[r12]); @@ -73,7 +73,7 @@ TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "TrickyUse1") TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "TrickyUse2") { - ScopedRegX64 tmp1{regs, regs.takeReg(rArg1)}; + ScopedRegX64 tmp1{regs, regs.takeReg(rArg1, kInvalidInstIdx)}; callWrap.addArgument(SizeX64::qword, qword[tmp1.reg]); callWrap.addArgument(SizeX64::qword, tmp1.release()); callWrap.call(qword[r12]); @@ -87,8 +87,8 @@ TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "TrickyUse2") TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "SimpleMemImm") { - ScopedRegX64 tmp1{regs, regs.takeReg(rax)}; - ScopedRegX64 tmp2{regs, regs.takeReg(rsi)}; + ScopedRegX64 tmp1{regs, regs.takeReg(rax, kInvalidInstIdx)}; + ScopedRegX64 tmp2{regs, regs.takeReg(rsi, kInvalidInstIdx)}; callWrap.addArgument(SizeX64::dword, 32); callWrap.addArgument(SizeX64::dword, -1); callWrap.addArgument(SizeX64::qword, qword[r14 + 32]); @@ -106,7 +106,7 @@ TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "SimpleMemImm") TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "SimpleStackArgs") { - ScopedRegX64 tmp{regs, regs.takeReg(rax)}; + ScopedRegX64 tmp{regs, regs.takeReg(rax, kInvalidInstIdx)}; callWrap.addArgument(SizeX64::qword, tmp); callWrap.addArgument(SizeX64::qword, qword[r14 + 16]); callWrap.addArgument(SizeX64::qword, qword[r14 + 32]); @@ -148,10 +148,10 @@ TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "FixedRegisters") TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "EasyInterference") { - ScopedRegX64 tmp1{regs, regs.takeReg(rdi)}; - ScopedRegX64 tmp2{regs, regs.takeReg(rsi)}; - ScopedRegX64 tmp3{regs, regs.takeReg(rArg2)}; - ScopedRegX64 tmp4{regs, regs.takeReg(rArg1)}; + ScopedRegX64 tmp1{regs, regs.takeReg(rdi, kInvalidInstIdx)}; + ScopedRegX64 tmp2{regs, regs.takeReg(rsi, kInvalidInstIdx)}; + ScopedRegX64 tmp3{regs, regs.takeReg(rArg2, kInvalidInstIdx)}; + ScopedRegX64 tmp4{regs, regs.takeReg(rArg1, kInvalidInstIdx)}; callWrap.addArgument(SizeX64::qword, tmp1); callWrap.addArgument(SizeX64::qword, tmp2); callWrap.addArgument(SizeX64::qword, tmp3); @@ -169,8 +169,8 @@ TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "EasyInterference") TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "FakeInterference") { - ScopedRegX64 tmp1{regs, regs.takeReg(rArg1)}; - ScopedRegX64 tmp2{regs, regs.takeReg(rArg2)}; + ScopedRegX64 tmp1{regs, regs.takeReg(rArg1, kInvalidInstIdx)}; + ScopedRegX64 tmp2{regs, regs.takeReg(rArg2, kInvalidInstIdx)}; callWrap.addArgument(SizeX64::qword, qword[tmp1.release() + 8]); callWrap.addArgument(SizeX64::qword, qword[tmp2.release() + 8]); callWrap.call(qword[r12]); @@ -184,10 +184,10 @@ TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "FakeInterference") TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "HardInterferenceInt") { - ScopedRegX64 tmp1{regs, regs.takeReg(rArg4)}; - ScopedRegX64 tmp2{regs, regs.takeReg(rArg3)}; - ScopedRegX64 tmp3{regs, regs.takeReg(rArg2)}; - ScopedRegX64 tmp4{regs, regs.takeReg(rArg1)}; + ScopedRegX64 tmp1{regs, regs.takeReg(rArg4, kInvalidInstIdx)}; + ScopedRegX64 tmp2{regs, regs.takeReg(rArg3, kInvalidInstIdx)}; + ScopedRegX64 tmp3{regs, regs.takeReg(rArg2, kInvalidInstIdx)}; + ScopedRegX64 tmp4{regs, regs.takeReg(rArg1, kInvalidInstIdx)}; callWrap.addArgument(SizeX64::qword, tmp1); callWrap.addArgument(SizeX64::qword, tmp2); callWrap.addArgument(SizeX64::qword, tmp3); @@ -207,10 +207,10 @@ TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "HardInterferenceInt") TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "HardInterferenceInt2") { - ScopedRegX64 tmp1{regs, regs.takeReg(rArg4d)}; - ScopedRegX64 tmp2{regs, regs.takeReg(rArg3d)}; - ScopedRegX64 tmp3{regs, regs.takeReg(rArg2d)}; - ScopedRegX64 tmp4{regs, regs.takeReg(rArg1d)}; + ScopedRegX64 tmp1{regs, regs.takeReg(rArg4d, kInvalidInstIdx)}; + ScopedRegX64 tmp2{regs, regs.takeReg(rArg3d, kInvalidInstIdx)}; + ScopedRegX64 tmp3{regs, regs.takeReg(rArg2d, kInvalidInstIdx)}; + ScopedRegX64 tmp4{regs, regs.takeReg(rArg1d, kInvalidInstIdx)}; callWrap.addArgument(SizeX64::dword, tmp1); callWrap.addArgument(SizeX64::dword, tmp2); callWrap.addArgument(SizeX64::dword, tmp3); @@ -230,8 +230,8 @@ TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "HardInterferenceInt2") TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "HardInterferenceFp") { - ScopedRegX64 tmp1{regs, regs.takeReg(xmm1)}; - ScopedRegX64 tmp2{regs, regs.takeReg(xmm0)}; + ScopedRegX64 tmp1{regs, regs.takeReg(xmm1, kInvalidInstIdx)}; + ScopedRegX64 tmp2{regs, regs.takeReg(xmm0, kInvalidInstIdx)}; callWrap.addArgument(SizeX64::xmmword, tmp1); callWrap.addArgument(SizeX64::xmmword, tmp2); callWrap.call(qword[r12]); @@ -246,10 +246,10 @@ TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "HardInterferenceFp") TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "HardInterferenceBoth") { - ScopedRegX64 int1{regs, regs.takeReg(rArg2)}; - ScopedRegX64 int2{regs, regs.takeReg(rArg1)}; - ScopedRegX64 fp1{regs, regs.takeReg(xmm3)}; - ScopedRegX64 fp2{regs, regs.takeReg(xmm2)}; + ScopedRegX64 int1{regs, regs.takeReg(rArg2, kInvalidInstIdx)}; + ScopedRegX64 int2{regs, regs.takeReg(rArg1, kInvalidInstIdx)}; + ScopedRegX64 fp1{regs, regs.takeReg(xmm3, kInvalidInstIdx)}; + ScopedRegX64 fp2{regs, regs.takeReg(xmm2, kInvalidInstIdx)}; callWrap.addArgument(SizeX64::qword, int1); callWrap.addArgument(SizeX64::qword, int2); callWrap.addArgument(SizeX64::xmmword, fp1); @@ -269,8 +269,8 @@ TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "HardInterferenceBoth") TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "FakeMultiuseInterferenceMem") { - ScopedRegX64 tmp1{regs, regs.takeReg(rArg1)}; - ScopedRegX64 tmp2{regs, regs.takeReg(rArg2)}; + ScopedRegX64 tmp1{regs, regs.takeReg(rArg1, kInvalidInstIdx)}; + ScopedRegX64 tmp2{regs, regs.takeReg(rArg2, kInvalidInstIdx)}; callWrap.addArgument(SizeX64::qword, qword[tmp1.reg + tmp2.reg + 8]); callWrap.addArgument(SizeX64::qword, qword[tmp2.reg + 16]); tmp1.release(); @@ -286,8 +286,8 @@ TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "FakeMultiuseInterferenceMem") TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "HardMultiuseInterferenceMem1") { - ScopedRegX64 tmp1{regs, regs.takeReg(rArg1)}; - ScopedRegX64 tmp2{regs, regs.takeReg(rArg2)}; + ScopedRegX64 tmp1{regs, regs.takeReg(rArg1, kInvalidInstIdx)}; + ScopedRegX64 tmp2{regs, regs.takeReg(rArg2, kInvalidInstIdx)}; callWrap.addArgument(SizeX64::qword, qword[tmp1.reg + tmp2.reg + 8]); callWrap.addArgument(SizeX64::qword, qword[tmp1.reg + 16]); tmp1.release(); @@ -304,8 +304,8 @@ TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "HardMultiuseInterferenceMem1") TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "HardMultiuseInterferenceMem2") { - ScopedRegX64 tmp1{regs, regs.takeReg(rArg1)}; - ScopedRegX64 tmp2{regs, regs.takeReg(rArg2)}; + ScopedRegX64 tmp1{regs, regs.takeReg(rArg1, kInvalidInstIdx)}; + ScopedRegX64 tmp2{regs, regs.takeReg(rArg2, kInvalidInstIdx)}; callWrap.addArgument(SizeX64::qword, qword[tmp1.reg + tmp2.reg + 8]); callWrap.addArgument(SizeX64::qword, qword[tmp1.reg + tmp2.reg + 16]); tmp1.release(); @@ -322,9 +322,9 @@ TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "HardMultiuseInterferenceMem2") TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "HardMultiuseInterferenceMem3") { - ScopedRegX64 tmp1{regs, regs.takeReg(rArg3)}; - ScopedRegX64 tmp2{regs, regs.takeReg(rArg2)}; - ScopedRegX64 tmp3{regs, regs.takeReg(rArg1)}; + ScopedRegX64 tmp1{regs, regs.takeReg(rArg3, kInvalidInstIdx)}; + ScopedRegX64 tmp2{regs, regs.takeReg(rArg2, kInvalidInstIdx)}; + ScopedRegX64 tmp3{regs, regs.takeReg(rArg1, kInvalidInstIdx)}; callWrap.addArgument(SizeX64::qword, qword[tmp1.reg + tmp2.reg + 8]); callWrap.addArgument(SizeX64::qword, qword[tmp2.reg + tmp3.reg + 16]); callWrap.addArgument(SizeX64::qword, qword[tmp3.reg + tmp1.reg + 16]); @@ -345,7 +345,7 @@ TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "HardMultiuseInterferenceMem3") TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "InterferenceWithCallArg1") { - ScopedRegX64 tmp1{regs, regs.takeReg(rArg1)}; + ScopedRegX64 tmp1{regs, regs.takeReg(rArg1, kInvalidInstIdx)}; callWrap.addArgument(SizeX64::qword, qword[tmp1.reg + 8]); callWrap.call(qword[tmp1.release() + 16]); @@ -358,8 +358,8 @@ TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "InterferenceWithCallArg1") TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "InterferenceWithCallArg2") { - ScopedRegX64 tmp1{regs, regs.takeReg(rArg1)}; - ScopedRegX64 tmp2{regs, regs.takeReg(rArg2)}; + ScopedRegX64 tmp1{regs, regs.takeReg(rArg1, kInvalidInstIdx)}; + ScopedRegX64 tmp2{regs, regs.takeReg(rArg2, kInvalidInstIdx)}; callWrap.addArgument(SizeX64::qword, tmp2); callWrap.call(qword[tmp1.release() + 16]); @@ -372,7 +372,7 @@ TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "InterferenceWithCallArg2") TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "InterferenceWithCallArg3") { - ScopedRegX64 tmp1{regs, regs.takeReg(rArg1)}; + ScopedRegX64 tmp1{regs, regs.takeReg(rArg1, kInvalidInstIdx)}; callWrap.addArgument(SizeX64::qword, tmp1.reg); callWrap.call(qword[tmp1.release() + 16]); @@ -385,7 +385,7 @@ TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "WithLastIrInstUse1") { IrInst irInst1; IrOp irOp1 = {IrOpKind::Inst, 0}; - irInst1.regX64 = regs.takeReg(xmm0); + irInst1.regX64 = regs.takeReg(xmm0, irOp1.index); irInst1.lastUse = 1; function.instructions.push_back(irInst1); callWrap.instIdx = irInst1.lastUse; @@ -404,7 +404,7 @@ TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "WithLastIrInstUse2") { IrInst irInst1; IrOp irOp1 = {IrOpKind::Inst, 0}; - irInst1.regX64 = regs.takeReg(xmm0); + irInst1.regX64 = regs.takeReg(xmm0, irOp1.index); irInst1.lastUse = 1; function.instructions.push_back(irInst1); callWrap.instIdx = irInst1.lastUse; @@ -424,7 +424,7 @@ TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "WithLastIrInstUse3") { IrInst irInst1; IrOp irOp1 = {IrOpKind::Inst, 0}; - irInst1.regX64 = regs.takeReg(xmm0); + irInst1.regX64 = regs.takeReg(xmm0, irOp1.index); irInst1.lastUse = 1; function.instructions.push_back(irInst1); callWrap.instIdx = irInst1.lastUse; @@ -443,12 +443,12 @@ TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "WithLastIrInstUse4") { IrInst irInst1; IrOp irOp1 = {IrOpKind::Inst, 0}; - irInst1.regX64 = regs.takeReg(rax); + irInst1.regX64 = regs.takeReg(rax, irOp1.index); irInst1.lastUse = 1; function.instructions.push_back(irInst1); callWrap.instIdx = irInst1.lastUse; - ScopedRegX64 tmp{regs, regs.takeReg(rdx)}; + ScopedRegX64 tmp{regs, regs.takeReg(rdx, kInvalidInstIdx)}; callWrap.addArgument(SizeX64::qword, r15); callWrap.addArgument(SizeX64::qword, irInst1.regX64, irOp1); callWrap.addArgument(SizeX64::qword, tmp); @@ -464,8 +464,8 @@ TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "WithLastIrInstUse4") TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "ExtraCoverage") { - ScopedRegX64 tmp1{regs, regs.takeReg(rArg1)}; - ScopedRegX64 tmp2{regs, regs.takeReg(rArg2)}; + ScopedRegX64 tmp1{regs, regs.takeReg(rArg1, kInvalidInstIdx)}; + ScopedRegX64 tmp2{regs, regs.takeReg(rArg2, kInvalidInstIdx)}; callWrap.addArgument(SizeX64::qword, addr[r12 + 8]); callWrap.addArgument(SizeX64::qword, addr[r12 + 16]); callWrap.addArgument(SizeX64::xmmword, xmmword[r13]); @@ -481,4 +481,42 @@ TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "ExtraCoverage") )"); } +TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "AddressInStackArguments") +{ + callWrap.addArgument(SizeX64::dword, 1); + callWrap.addArgument(SizeX64::dword, 2); + callWrap.addArgument(SizeX64::dword, 3); + callWrap.addArgument(SizeX64::dword, 4); + callWrap.addArgument(SizeX64::qword, addr[r12 + 16]); + callWrap.call(qword[r14]); + + checkMatch(R"( + lea rax,none ptr [r12+010h] + mov qword ptr [rsp+020h],rax + mov ecx,1 + mov edx,2 + mov r8d,3 + mov r9d,4 + call qword ptr [r14] +)"); +} + +TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "ImmediateConflictWithFunction") +{ + ScopedRegX64 tmp1{regs, regs.takeReg(rArg1, kInvalidInstIdx)}; + ScopedRegX64 tmp2{regs, regs.takeReg(rArg2, kInvalidInstIdx)}; + + callWrap.addArgument(SizeX64::dword, 1); + callWrap.addArgument(SizeX64::dword, 2); + callWrap.call(qword[tmp1.release() + tmp2.release()]); + + checkMatch(R"( + mov rax,rcx + mov ecx,1 + mov rbx,rdx + mov edx,2 + call qword ptr [rax+rbx] +)"); +} + TEST_SUITE_END(); diff --git a/tests/Linter.test.cpp b/tests/Linter.test.cpp index 8bef5922f..54a1f44cb 100644 --- a/tests/Linter.test.cpp +++ b/tests/Linter.test.cpp @@ -1273,7 +1273,7 @@ TEST_CASE_FIXTURE(Fixture, "use_all_parent_scopes_for_globals") { ScopePtr testScope = frontend.addEnvironment("Test"); unfreeze(frontend.globals.globalTypes); - loadDefinitionFile(frontend.typeChecker, frontend.globals, testScope, R"( + frontend.loadDefinitionFile(frontend.globals, testScope, R"( declare Foo: number )", "@test", /* captureComments */ false); diff --git a/tests/Normalize.test.cpp b/tests/Normalize.test.cpp index 4378bab8b..6552a24da 100644 --- a/tests/Normalize.test.cpp +++ b/tests/Normalize.test.cpp @@ -748,6 +748,20 @@ TEST_CASE_FIXTURE(NormalizeFixture, "narrow_union_of_classes_with_intersection") CHECK("Child" == toString(normal("(Child | Unrelated) & Child"))); } +TEST_CASE_FIXTURE(NormalizeFixture, "intersection_of_metatables_where_the_metatable_is_top_or_bottom") +{ + ScopedFastFlag sff{"LuauNormalizeMetatableFixes", true}; + + CHECK("{ @metatable *error-type*, {| |} }" == toString(normal("Mt<{}, any> & Mt<{}, err>"))); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "crazy_metatable") +{ + ScopedFastFlag sff{"LuauNormalizeMetatableFixes", true}; + + CHECK("never" == toString(normal("Mt<{}, number> & Mt<{}, string>"))); +} + TEST_CASE_FIXTURE(NormalizeFixture, "negations_of_classes") { ScopedFastFlag sffs[] = { diff --git a/tests/TypeInfer.definitions.test.cpp b/tests/TypeInfer.definitions.test.cpp index f3f464130..d67997574 100644 --- a/tests/TypeInfer.definitions.test.cpp +++ b/tests/TypeInfer.definitions.test.cpp @@ -78,7 +78,7 @@ TEST_CASE_FIXTURE(Fixture, "definition_file_loading") TEST_CASE_FIXTURE(Fixture, "load_definition_file_errors_do_not_pollute_global_scope") { unfreeze(frontend.globals.globalTypes); - LoadDefinitionFileResult parseFailResult = loadDefinitionFile(frontend.typeChecker, frontend.globals, frontend.globals.globalScope, R"( + LoadDefinitionFileResult parseFailResult = frontend.loadDefinitionFile(frontend.globals, frontend.globals.globalScope, R"( declare foo )", "@test", /* captureComments */ false); @@ -88,7 +88,7 @@ TEST_CASE_FIXTURE(Fixture, "load_definition_file_errors_do_not_pollute_global_sc std::optional fooTy = tryGetGlobalBinding(frontend.globals, "foo"); CHECK(!fooTy.has_value()); - LoadDefinitionFileResult checkFailResult = loadDefinitionFile(frontend.typeChecker, frontend.globals, frontend.globals.globalScope, R"( + LoadDefinitionFileResult checkFailResult = frontend.loadDefinitionFile(frontend.globals, frontend.globals.globalScope, R"( local foo: string = 123 declare bar: typeof(foo) )", @@ -140,7 +140,7 @@ TEST_CASE_FIXTURE(Fixture, "definition_file_classes") TEST_CASE_FIXTURE(Fixture, "class_definitions_cannot_overload_non_function") { unfreeze(frontend.globals.globalTypes); - LoadDefinitionFileResult result = loadDefinitionFile(frontend.typeChecker, frontend.globals, frontend.globals.globalScope, R"( + LoadDefinitionFileResult result = frontend.loadDefinitionFile(frontend.globals, frontend.globals.globalScope, R"( declare class A X: number X: string @@ -161,7 +161,7 @@ TEST_CASE_FIXTURE(Fixture, "class_definitions_cannot_overload_non_function") TEST_CASE_FIXTURE(Fixture, "class_definitions_cannot_extend_non_class") { unfreeze(frontend.globals.globalTypes); - LoadDefinitionFileResult result = loadDefinitionFile(frontend.typeChecker, frontend.globals, frontend.globals.globalScope, R"( + LoadDefinitionFileResult result = frontend.loadDefinitionFile(frontend.globals, frontend.globals.globalScope, R"( type NotAClass = {} declare class Foo extends NotAClass @@ -182,7 +182,7 @@ TEST_CASE_FIXTURE(Fixture, "class_definitions_cannot_extend_non_class") TEST_CASE_FIXTURE(Fixture, "no_cyclic_defined_classes") { unfreeze(frontend.globals.globalTypes); - LoadDefinitionFileResult result = loadDefinitionFile(frontend.typeChecker, frontend.globals, frontend.globals.globalScope, R"( + LoadDefinitionFileResult result = frontend.loadDefinitionFile(frontend.globals, frontend.globals.globalScope, R"( declare class Foo extends Bar end @@ -397,7 +397,7 @@ TEST_CASE_FIXTURE(Fixture, "class_definition_string_props") TEST_CASE_FIXTURE(Fixture, "class_definitions_reference_other_classes") { unfreeze(frontend.globals.globalTypes); - LoadDefinitionFileResult result = loadDefinitionFile(frontend.typeChecker, frontend.globals, frontend.globals.globalScope, R"( + LoadDefinitionFileResult result = frontend.loadDefinitionFile(frontend.globals, frontend.globals.globalScope, R"( declare class Channel Messages: { Message } OnMessage: (message: Message) -> () diff --git a/tests/TypeInfer.generics.test.cpp b/tests/TypeInfer.generics.test.cpp index b3b2e4c94..b97848176 100644 --- a/tests/TypeInfer.generics.test.cpp +++ b/tests/TypeInfer.generics.test.cpp @@ -874,7 +874,7 @@ TEST_CASE_FIXTURE(Fixture, "generic_table_method") std::vector args = flatten(ftv->argTypes).first; TypeId argType = args.at(1); - CHECK_MESSAGE(get(argType), "Should be generic: " << *barType); + CHECK_MESSAGE(get(argType), "Should be generic: " << *barType); } TEST_CASE_FIXTURE(Fixture, "correctly_instantiate_polymorphic_member_functions") diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index 38e7e2f31..87419debb 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -477,10 +477,10 @@ TEST_CASE_FIXTURE(Fixture, "free_options_cannot_be_unified_together") std::unique_ptr scope = std::make_unique(builtinTypes->anyTypePack); - TypeId free1 = arena.addType(FreeTypePack{scope.get()}); + TypeId free1 = arena.addType(FreeType{scope.get()}); TypeId option1 = arena.addType(UnionType{{nilType, free1}}); - TypeId free2 = arena.addType(FreeTypePack{scope.get()}); + TypeId free2 = arena.addType(FreeType{scope.get()}); TypeId option2 = arena.addType(UnionType{{nilType, free2}}); InternalErrorReporter iceHandler; diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index 7e317f2ef..3088235ae 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -490,8 +490,13 @@ struct FindFreeTypes return !foundOne; } - template - bool operator()(ID, Unifiable::Free) + bool operator()(TypeId, FreeType) + { + foundOne = true; + return false; + } + + bool operator()(TypePackId, FreeTypePack) { foundOne = true; return false; diff --git a/tests/TypePack.test.cpp b/tests/TypePack.test.cpp index 20404434a..7d8ed38f7 100644 --- a/tests/TypePack.test.cpp +++ b/tests/TypePack.test.cpp @@ -25,7 +25,7 @@ struct TypePackFixture TypePackId freshTypePack() { - typePacks.emplace_back(new TypePackVar{Unifiable::Free{TypeLevel{}}}); + typePacks.emplace_back(new TypePackVar{FreeTypePack{TypeLevel{}}}); return typePacks.back().get(); } diff --git a/tests/TypeVar.test.cpp b/tests/TypeVar.test.cpp index 64ba63c8d..3f0becc54 100644 --- a/tests/TypeVar.test.cpp +++ b/tests/TypeVar.test.cpp @@ -74,7 +74,7 @@ TEST_CASE_FIXTURE(Fixture, "return_type_of_function_is_parenthesized_if_not_just TEST_CASE_FIXTURE(Fixture, "return_type_of_function_is_parenthesized_if_tail_is_free") { auto emptyArgumentPack = TypePackVar{TypePack{}}; - auto free = Unifiable::Free(TypeLevel()); + auto free = FreeTypePack(TypeLevel()); auto freePack = TypePackVar{TypePackVariant{free}}; auto returnPack = TypePackVar{TypePack{{builtinTypes->numberType}, &freePack}}; auto returnsTwo = Type(FunctionType(frontend.globals.globalScope->level, &emptyArgumentPack, &returnPack)); diff --git a/tests/conformance/apicalls.lua b/tests/conformance/apicalls.lua index 274166237..8db62d96c 100644 --- a/tests/conformance/apicalls.lua +++ b/tests/conformance/apicalls.lua @@ -22,4 +22,12 @@ function getpi() return pi end +function largealloc() + table.create(1000000) +end + +function oops() + return "oops" +end + return('OK') diff --git a/tests/conformance/pcall.lua b/tests/conformance/pcall.lua index 969209fc4..b94f7972e 100644 --- a/tests/conformance/pcall.lua +++ b/tests/conformance/pcall.lua @@ -161,4 +161,11 @@ checkresults({ false, "ok" }, xpcall(recurse, function() return string.reverse(" -- however, if xpcall handler itself runs out of extra stack space, we get "error in error handling" checkresults({ false, "error in error handling" }, xpcall(recurse, function() return recurse(calllimit) end, calllimit - 2)) +-- simulate OOM and make sure we can catch it with pcall or xpcall +checkresults({ false, "not enough memory" }, pcall(function() table.create(1e6) end)) +checkresults({ false, "not enough memory" }, xpcall(function() table.create(1e6) end, function(e) return e end)) +checkresults({ false, "oops" }, xpcall(function() table.create(1e6) end, function(e) return "oops" end)) +checkresults({ false, "error in error handling" }, xpcall(function() error("oops") end, function(e) table.create(1e6) end)) +checkresults({ false, "not enough memory" }, xpcall(function() table.create(1e6) end, function(e) table.create(1e6) end)) + return 'OK' diff --git a/tests/conformance/sort.lua b/tests/conformance/sort.lua index 693a10dc5..3c2c20dd4 100644 --- a/tests/conformance/sort.lua +++ b/tests/conformance/sort.lua @@ -99,12 +99,12 @@ a = {" table.sort(a) check(a) --- TODO: assert that pcall returns false for new sort implementation (table is modified during sorting) -pcall(table.sort, a, function (x, y) +local ok = pcall(table.sort, a, function (x, y) loadstring(string.format("a[%q] = ''", x))() collectgarbage() return x + + + + + count + capacity + + capacity + data + + + + + + + impl + + + + + + impl + + + + diff --git a/tools/test_dcr.py b/tools/test_dcr.py index d30490b30..817d08313 100644 --- a/tools/test_dcr.py +++ b/tools/test_dcr.py @@ -107,6 +107,12 @@ def main(): action="store_true", help="Write a new faillist.txt after running tests.", ) + parser.add_argument( + "--lti", + dest="lti", + action="store_true", + help="Run the tests with local type inference enabled.", + ) parser.add_argument("--randomize", action="store_true", help="Pick a random seed") @@ -120,13 +126,19 @@ def main(): args = parser.parse_args() + if args.write and args.lti: + print_stderr( + "Cannot run test_dcr.py with --write *and* --lti. You don't want to commit local type inference faillist.txt yet." + ) + sys.exit(1) + failList = loadFailList() - commandLine = [ - args.path, - "--reporters=xml", - "--fflags=true,DebugLuauDeferredConstraintResolution=true", - ] + flags = ["true", "DebugLuauDeferredConstraintResolution"] + if args.lti: + flags.append("DebugLuauLocalTypeInference") + + commandLine = [args.path, "--reporters=xml", "--fflags=" + ",".join(flags)] if args.random_seed: commandLine.append("--random-seed=" + str(args.random_seed)) From 5e771b87ae4385c0160c296be56b542fd6fe3c24 Mon Sep 17 00:00:00 2001 From: Vyacheslav Egorov Date: Fri, 14 Apr 2023 15:05:27 +0300 Subject: [PATCH 46/66] Sync to upstream/release/572 --- Analysis/include/Luau/AstQuery.h | 3 + Analysis/include/Luau/Frontend.h | 20 +- Analysis/include/Luau/Module.h | 1 + Analysis/include/Luau/Type.h | 14 + Analysis/include/Luau/Unifier.h | 4 +- Analysis/src/AstQuery.cpp | 31 +- Analysis/src/ConstraintSolver.cpp | 18 +- Analysis/src/Frontend.cpp | 158 +++-- Analysis/src/Module.cpp | 23 +- Analysis/src/Type.cpp | 9 + Analysis/src/TypeChecker2.cpp | 66 +- Analysis/src/TypeInfer.cpp | 64 +- Analysis/src/Unifier.cpp | 86 ++- Ast/src/StringUtils.cpp | 8 +- CodeGen/include/Luau/AddressA64.h | 4 +- CodeGen/include/Luau/AssemblyBuilderA64.h | 17 +- CodeGen/include/Luau/IrCallWrapperX64.h | 7 +- CodeGen/include/Luau/IrData.h | 38 +- CodeGen/include/Luau/IrRegAllocX64.h | 23 +- CodeGen/include/Luau/IrUtils.h | 2 +- CodeGen/include/Luau/RegisterA64.h | 12 + CodeGen/include/Luau/RegisterX64.h | 12 + CodeGen/include/Luau/UnwindBuilder.h | 12 +- CodeGen/include/Luau/UnwindBuilderDwarf2.h | 22 +- CodeGen/include/Luau/UnwindBuilderWin.h | 38 +- CodeGen/src/AssemblyBuilderA64.cpp | 87 ++- CodeGen/src/AssemblyBuilderX64.cpp | 11 +- CodeGen/src/BitUtils.h | 36 + CodeGen/src/CodeBlockUnwind.cpp | 48 +- CodeGen/src/CodeGen.cpp | 26 +- CodeGen/src/CodeGenA64.cpp | 132 +++- CodeGen/src/CodeGenA64.h | 2 +- CodeGen/src/CodeGenUtils.cpp | 50 +- CodeGen/src/CodeGenUtils.h | 1 + CodeGen/src/CodeGenX64.cpp | 50 +- CodeGen/src/CodeGenX64.h | 2 +- CodeGen/src/EmitBuiltinsX64.cpp | 78 +-- CodeGen/src/EmitCommon.h | 4 +- CodeGen/src/EmitCommonA64.cpp | 130 ---- CodeGen/src/EmitCommonA64.h | 19 +- CodeGen/src/EmitCommonX64.cpp | 73 +- CodeGen/src/EmitCommonX64.h | 41 +- CodeGen/src/EmitInstructionA64.cpp | 74 --- CodeGen/src/EmitInstructionA64.h | 24 - CodeGen/src/EmitInstructionX64.cpp | 74 +-- CodeGen/src/EmitInstructionX64.h | 6 +- CodeGen/src/Fallbacks.cpp | 38 ++ CodeGen/src/Fallbacks.h | 1 + CodeGen/src/IrAnalysis.cpp | 2 + CodeGen/src/IrBuilder.cpp | 3 +- CodeGen/src/IrCallWrapperX64.cpp | 79 ++- CodeGen/src/IrLoweringA64.cpp | 733 ++++++++++++++------- CodeGen/src/IrLoweringA64.h | 4 +- CodeGen/src/IrLoweringX64.cpp | 239 ++++--- CodeGen/src/IrRegAllocA64.cpp | 21 +- CodeGen/src/IrRegAllocX64.cpp | 299 +++++---- CodeGen/src/IrTranslateBuiltins.cpp | 70 +- CodeGen/src/IrUtils.cpp | 4 +- CodeGen/src/NativeState.cpp | 26 +- CodeGen/src/NativeState.h | 13 +- CodeGen/src/OptimizeConstProp.cpp | 2 + CodeGen/src/UnwindBuilderDwarf2.cpp | 47 +- CodeGen/src/UnwindBuilderWin.cpp | 112 +++- Sources.cmake | 4 +- VM/src/lapi.cpp | 2 + VM/src/ltable.cpp | 34 +- fuzz/linter.cpp | 2 +- fuzz/proto.cpp | 8 +- fuzz/typeck.cpp | 2 +- tests/AssemblyBuilderA64.test.cpp | 26 +- tests/AssemblyBuilderX64.test.cpp | 10 + tests/Autocomplete.test.cpp | 30 + tests/CodeAllocator.test.cpp | 190 +++++- tests/Conformance.test.cpp | 43 +- tests/Fixture.cpp | 10 +- tests/Module.test.cpp | 39 +- tests/StringUtils.test.cpp | 18 + tests/TypeInfer.annotations.test.cpp | 13 +- tests/TypeInfer.functions.test.cpp | 33 + tests/TypeInfer.operators.test.cpp | 78 ++- tests/TypeInfer.provisional.test.cpp | 36 +- tests/TypeInfer.test.cpp | 15 + tests/TypeInfer.unionTypes.test.cpp | 16 + tests/TypeInfer.unknownnever.test.cpp | 5 - tests/TypeVar.test.cpp | 14 +- tests/conformance/math.lua | 10 + tests/conformance/tables.lua | 7 + tools/lvmexecute_split.py | 2 +- 88 files changed, 2573 insertions(+), 1427 deletions(-) create mode 100644 CodeGen/src/BitUtils.h delete mode 100644 CodeGen/src/EmitCommonA64.cpp delete mode 100644 CodeGen/src/EmitInstructionA64.cpp delete mode 100644 CodeGen/src/EmitInstructionA64.h diff --git a/Analysis/include/Luau/AstQuery.h b/Analysis/include/Luau/AstQuery.h index aa7ef8d3e..e7a018c0a 100644 --- a/Analysis/include/Luau/AstQuery.h +++ b/Analysis/include/Luau/AstQuery.h @@ -64,8 +64,11 @@ struct ExprOrLocal }; std::vector findAncestryAtPositionForAutocomplete(const SourceModule& source, Position pos); +std::vector findAncestryAtPositionForAutocomplete(AstStatBlock* root, Position pos); std::vector findAstAncestryOfPosition(const SourceModule& source, Position pos, bool includeTypes = false); +std::vector findAstAncestryOfPosition(AstStatBlock* root, Position pos, bool includeTypes = false); AstNode* findNodeAtPosition(const SourceModule& source, Position pos); +AstNode* findNodeAtPosition(AstStatBlock* root, Position pos); AstExpr* findExprAtPosition(const SourceModule& source, Position pos); ScopePtr findScopeAtPosition(const Module& module, Position pos); std::optional findBindingAtPosition(const Module& module, const SourceModule& source, Position pos); diff --git a/Analysis/include/Luau/Frontend.h b/Analysis/include/Luau/Frontend.h index 82251378e..3f41c1456 100644 --- a/Analysis/include/Luau/Frontend.h +++ b/Analysis/include/Luau/Frontend.h @@ -165,7 +165,15 @@ struct Frontend bool captureComments, bool typeCheckForAutocomplete = false); private: - ModulePtr check(const SourceModule& sourceModule, Mode mode, std::vector requireCycles, bool forAutocomplete = false, bool recordJsonLog = false); + struct TypeCheckLimits + { + std::optional finishTime; + std::optional instantiationChildLimit; + std::optional unifierIterationLimit; + }; + + ModulePtr check(const SourceModule& sourceModule, Mode mode, std::vector requireCycles, std::optional environmentScope, + bool forAutocomplete, bool recordJsonLog, TypeCheckLimits typeCheckLimits); std::pair getSourceNode(const ModuleName& name); SourceModule parse(const ModuleName& name, std::string_view src, const ParseOptions& parseOptions); @@ -185,15 +193,21 @@ struct Frontend const NotNull builtinTypes; FileResolver* fileResolver; + FrontendModuleResolver moduleResolver; FrontendModuleResolver moduleResolverForAutocomplete; + GlobalTypes globals; GlobalTypes globalsForAutocomplete; - TypeChecker typeChecker; - TypeChecker typeCheckerForAutocomplete; + + // TODO: remove with FFlagLuauOnDemandTypecheckers + TypeChecker typeChecker_DEPRECATED; + TypeChecker typeCheckerForAutocomplete_DEPRECATED; + ConfigResolver* configResolver; FrontendOptions options; InternalErrorReporter iceHandler; + std::function prepareModuleScope; std::unordered_map sourceNodes; std::unordered_map sourceModules; diff --git a/Analysis/include/Luau/Module.h b/Analysis/include/Luau/Module.h index 72f87601d..1bca7636c 100644 --- a/Analysis/include/Luau/Module.h +++ b/Analysis/include/Luau/Module.h @@ -51,6 +51,7 @@ struct SourceModule }; bool isWithinComment(const SourceModule& sourceModule, Position pos); +bool isWithinComment(const ParseResult& result, Position pos); struct RequireCycle { diff --git a/Analysis/include/Luau/Type.h b/Analysis/include/Luau/Type.h index cff86df42..b9544a11d 100644 --- a/Analysis/include/Luau/Type.h +++ b/Analysis/include/Luau/Type.h @@ -738,6 +738,7 @@ const T* get(TypeId tv) return get_if(&tv->ty); } + template T* getMutable(TypeId tv) { @@ -897,6 +898,19 @@ bool hasTag(TypeId ty, const std::string& tagName); bool hasTag(const Property& prop, const std::string& tagName); bool hasTag(const Tags& tags, const std::string& tagName); // Do not use in new work. +template +bool hasTypeInIntersection(TypeId ty) +{ + TypeId tf = follow(ty); + if (get(tf)) + return true; + for (auto t : flattenIntersection(tf)) + if (get(follow(t))) + return true; + return false; +} + +bool hasPrimitiveTypeInIntersection(TypeId ty, PrimitiveType::Type primTy); /* * Use this to change the kind of a particular type. * diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index e7817e57c..e3b0a8782 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -137,9 +137,9 @@ struct Unifier public: // Returns true if the type "needle" already occurs within "haystack" and reports an "infinite type error" - bool occursCheck(TypeId needle, TypeId haystack); + bool occursCheck(TypeId needle, TypeId haystack, bool reversed); bool occursCheck(DenseHashSet& seen, TypeId needle, TypeId haystack); - bool occursCheck(TypePackId needle, TypePackId haystack); + bool occursCheck(TypePackId needle, TypePackId haystack, bool reversed); bool occursCheck(DenseHashSet& seen, TypePackId needle, TypePackId haystack); Unifier makeChildUnifier(); diff --git a/Analysis/src/AstQuery.cpp b/Analysis/src/AstQuery.cpp index dc07a35ca..cb3efe6a6 100644 --- a/Analysis/src/AstQuery.cpp +++ b/Analysis/src/AstQuery.cpp @@ -211,33 +211,48 @@ struct FindFullAncestry final : public AstVisitor std::vector findAncestryAtPositionForAutocomplete(const SourceModule& source, Position pos) { - AutocompleteNodeFinder finder{pos, source.root}; - source.root->visit(&finder); + return findAncestryAtPositionForAutocomplete(source.root, pos); +} + +std::vector findAncestryAtPositionForAutocomplete(AstStatBlock* root, Position pos) +{ + AutocompleteNodeFinder finder{pos, root}; + root->visit(&finder); return finder.ancestry; } std::vector findAstAncestryOfPosition(const SourceModule& source, Position pos, bool includeTypes) { - const Position end = source.root->location.end; + return findAstAncestryOfPosition(source.root, pos, includeTypes); +} + +std::vector findAstAncestryOfPosition(AstStatBlock* root, Position pos, bool includeTypes) +{ + const Position end = root->location.end; if (pos > end) pos = end; FindFullAncestry finder(pos, end, includeTypes); - source.root->visit(&finder); + root->visit(&finder); return finder.nodes; } AstNode* findNodeAtPosition(const SourceModule& source, Position pos) { - const Position end = source.root->location.end; - if (pos < source.root->location.begin) - return source.root; + return findNodeAtPosition(source.root, pos); +} + +AstNode* findNodeAtPosition(AstStatBlock* root, Position pos) +{ + const Position end = root->location.end; + if (pos < root->location.begin) + return root; if (pos > end) pos = end; FindNode findNode{pos, end}; - findNode.visit(source.root); + findNode.visit(root); return findNode.best; } diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index d2bed2da3..0fc32c33d 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -595,6 +595,11 @@ bool ConstraintSolver::tryDispatch(const BinaryConstraint& c, NotNull(leftType) || get(leftType); + bool rightAny = get(rightType) || get(rightType); + bool anyPresent = leftAny || rightAny; + if (isBlocked(leftType) && leftType != resultType) return block(c.leftType, constraint); @@ -604,12 +609,12 @@ bool ConstraintSolver::tryDispatch(const BinaryConstraint& c, NotNull(leftType) && !isLogical) + if (hasTypeInIntersection(leftType) && !isLogical) return block(leftType, constraint); } // Logical expressions may proceed if the LHS is free. - if (isBlocked(leftType) || (get(leftType) && !isLogical)) + if (isBlocked(leftType) || (hasTypeInIntersection(leftType) && !isLogical)) { asMutable(resultType)->ty.emplace(errorRecoveryType()); unblock(resultType); @@ -696,11 +701,6 @@ bool ConstraintSolver::tryDispatch(const BinaryConstraint& c, NotNull(leftType) || get(leftType); - bool rightAny = get(rightType) || get(rightType); - bool anyPresent = leftAny || rightAny; - switch (c.op) { // For arithmetic operators, if the LHS is a number, the RHS must be a @@ -711,6 +711,8 @@ bool ConstraintSolver::tryDispatch(const BinaryConstraint& c, NotNull(leftType) && force) + asMutable(leftType)->ty.emplace(anyPresent ? builtinTypes->anyType : builtinTypes->numberType); if (isNumber(leftType)) { unify(leftType, rightType, constraint->scope); @@ -723,6 +725,8 @@ bool ConstraintSolver::tryDispatch(const BinaryConstraint& c, NotNull(leftType) && force) + asMutable(leftType)->ty.emplace(anyPresent ? builtinTypes->anyType : builtinTypes->stringType); if (isString(leftType)) { unify(leftType, rightType, constraint->scope); diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 98022d862..5beb6c4e1 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -31,7 +31,8 @@ LUAU_FASTFLAG(LuauInferInNoCheckMode) LUAU_FASTFLAGVARIABLE(LuauKnowsTheDataModel3, false) LUAU_FASTINTVARIABLE(LuauAutocompleteCheckTimeoutMs, 100) LUAU_FASTFLAGVARIABLE(DebugLuauDeferredConstraintResolution, false) -LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJson, false); +LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJson, false) +LUAU_FASTFLAGVARIABLE(LuauOnDemandTypecheckers, false) namespace Luau { @@ -131,8 +132,8 @@ static void persistCheckedTypes(ModulePtr checkedModule, GlobalTypes& globals, S LoadDefinitionFileResult Frontend::loadDefinitionFile(GlobalTypes& globals, ScopePtr targetScope, std::string_view source, const std::string& packageName, bool captureComments, bool typeCheckForAutocomplete) { - if (!FFlag::DebugLuauDeferredConstraintResolution) - return Luau::loadDefinitionFileNoDCR(typeCheckForAutocomplete ? typeCheckerForAutocomplete : typeChecker, + if (!FFlag::DebugLuauDeferredConstraintResolution && !FFlag::LuauOnDemandTypecheckers) + return Luau::loadDefinitionFileNoDCR(typeCheckForAutocomplete ? typeCheckerForAutocomplete_DEPRECATED : typeChecker_DEPRECATED, typeCheckForAutocomplete ? globalsForAutocomplete : globals, targetScope, source, packageName, captureComments); LUAU_TIMETRACE_SCOPE("loadDefinitionFile", "Frontend"); @@ -142,7 +143,7 @@ LoadDefinitionFileResult Frontend::loadDefinitionFile(GlobalTypes& globals, Scop if (parseResult.errors.size() > 0) return LoadDefinitionFileResult{false, parseResult, sourceModule, nullptr}; - ModulePtr checkedModule = check(sourceModule, Mode::Definition, {}); + ModulePtr checkedModule = check(sourceModule, Mode::Definition, {}, std::nullopt, /*forAutocomplete*/ false, /*recordJsonLog*/ false, {}); if (checkedModule->errors.size() > 0) return LoadDefinitionFileResult{false, parseResult, sourceModule, checkedModule}; @@ -155,6 +156,7 @@ LoadDefinitionFileResult Frontend::loadDefinitionFile(GlobalTypes& globals, Scop LoadDefinitionFileResult loadDefinitionFileNoDCR(TypeChecker& typeChecker, GlobalTypes& globals, ScopePtr targetScope, std::string_view source, const std::string& packageName, bool captureComments) { + LUAU_ASSERT(!FFlag::LuauOnDemandTypecheckers); LUAU_TIMETRACE_SCOPE("loadDefinitionFile", "Frontend"); Luau::SourceModule sourceModule; @@ -406,8 +408,8 @@ Frontend::Frontend(FileResolver* fileResolver, ConfigResolver* configResolver, c , moduleResolverForAutocomplete(this) , globals(builtinTypes) , globalsForAutocomplete(builtinTypes) - , typeChecker(globals.globalScope, &moduleResolver, builtinTypes, &iceHandler) - , typeCheckerForAutocomplete(globalsForAutocomplete.globalScope, &moduleResolverForAutocomplete, builtinTypes, &iceHandler) + , typeChecker_DEPRECATED(globals.globalScope, &moduleResolver, builtinTypes, &iceHandler) + , typeCheckerForAutocomplete_DEPRECATED(globalsForAutocomplete.globalScope, &moduleResolverForAutocomplete, builtinTypes, &iceHandler) , configResolver(configResolver) , options(options) { @@ -491,35 +493,68 @@ CheckResult Frontend::check(const ModuleName& name, std::optional 0) - typeCheckerForAutocomplete.instantiationChildLimit = std::max(1, int(FInt::LuauTarjanChildLimit * sourceNode.autocompleteLimitsMult)); - else - typeCheckerForAutocomplete.instantiationChildLimit = std::nullopt; - - if (FInt::LuauTypeInferIterationLimit > 0) - typeCheckerForAutocomplete.unifierIterationLimit = - std::max(1, int(FInt::LuauTypeInferIterationLimit * sourceNode.autocompleteLimitsMult)); + if (!FFlag::LuauOnDemandTypecheckers) + { + // The autocomplete typecheck is always in strict mode with DM awareness + // to provide better type information for IDE features + typeCheckerForAutocomplete_DEPRECATED.requireCycles = requireCycles; + + if (autocompleteTimeLimit != 0.0) + typeCheckerForAutocomplete_DEPRECATED.finishTime = TimeTrace::getClock() + autocompleteTimeLimit; + else + typeCheckerForAutocomplete_DEPRECATED.finishTime = std::nullopt; + + // TODO: This is a dirty ad hoc solution for autocomplete timeouts + // We are trying to dynamically adjust our existing limits to lower total typechecking time under the limit + // so that we'll have type information for the whole file at lower quality instead of a full abort in the middle + if (FInt::LuauTarjanChildLimit > 0) + typeCheckerForAutocomplete_DEPRECATED.instantiationChildLimit = + std::max(1, int(FInt::LuauTarjanChildLimit * sourceNode.autocompleteLimitsMult)); + else + typeCheckerForAutocomplete_DEPRECATED.instantiationChildLimit = std::nullopt; + + if (FInt::LuauTypeInferIterationLimit > 0) + typeCheckerForAutocomplete_DEPRECATED.unifierIterationLimit = + std::max(1, int(FInt::LuauTypeInferIterationLimit * sourceNode.autocompleteLimitsMult)); + else + typeCheckerForAutocomplete_DEPRECATED.unifierIterationLimit = std::nullopt; + + moduleForAutocomplete = + FFlag::DebugLuauDeferredConstraintResolution + ? check(sourceModule, Mode::Strict, requireCycles, environmentScope, /*forAutocomplete*/ true, /*recordJsonLog*/ false, {}) + : typeCheckerForAutocomplete_DEPRECATED.check(sourceModule, Mode::Strict, environmentScope); + } else - typeCheckerForAutocomplete.unifierIterationLimit = std::nullopt; - - ModulePtr moduleForAutocomplete = - FFlag::DebugLuauDeferredConstraintResolution - ? check(sourceModule, Mode::Strict, requireCycles, /*forAutocomplete*/ true, /*recordJsonLog*/ false) - : typeCheckerForAutocomplete.check(sourceModule, Mode::Strict, environmentScope); + { + // The autocomplete typecheck is always in strict mode with DM awareness + // to provide better type information for IDE features + TypeCheckLimits typeCheckLimits; + + if (autocompleteTimeLimit != 0.0) + typeCheckLimits.finishTime = TimeTrace::getClock() + autocompleteTimeLimit; + else + typeCheckLimits.finishTime = std::nullopt; + + // TODO: This is a dirty ad hoc solution for autocomplete timeouts + // We are trying to dynamically adjust our existing limits to lower total typechecking time under the limit + // so that we'll have type information for the whole file at lower quality instead of a full abort in the middle + if (FInt::LuauTarjanChildLimit > 0) + typeCheckLimits.instantiationChildLimit = std::max(1, int(FInt::LuauTarjanChildLimit * sourceNode.autocompleteLimitsMult)); + else + typeCheckLimits.instantiationChildLimit = std::nullopt; + + if (FInt::LuauTypeInferIterationLimit > 0) + typeCheckLimits.unifierIterationLimit = std::max(1, int(FInt::LuauTypeInferIterationLimit * sourceNode.autocompleteLimitsMult)); + else + typeCheckLimits.unifierIterationLimit = std::nullopt; + + moduleForAutocomplete = check(sourceModule, Mode::Strict, requireCycles, environmentScope, /*forAutocomplete*/ true, + /*recordJsonLog*/ false, typeCheckLimits); + } moduleResolverForAutocomplete.modules[moduleName] = moduleForAutocomplete; @@ -543,13 +578,22 @@ CheckResult Frontend::check(const ModuleName& name, std::optionalget(global.c_str()); if (name.value) - result->bindings[name].typeId = typeChecker.anyType; + result->bindings[name].typeId = FFlag::LuauOnDemandTypecheckers ? builtinTypes->anyType : typeChecker_DEPRECATED.anyType; } } @@ -829,15 +873,15 @@ const SourceModule* Frontend::getSourceModule(const ModuleName& moduleName) cons ModulePtr check(const SourceModule& sourceModule, const std::vector& requireCycles, NotNull builtinTypes, NotNull iceHandler, NotNull moduleResolver, NotNull fileResolver, - const ScopePtr& globalScope, FrontendOptions options) + const ScopePtr& parentScope, FrontendOptions options) { const bool recordJsonLog = FFlag::DebugLuauLogSolverToJson; - return check(sourceModule, requireCycles, builtinTypes, iceHandler, moduleResolver, fileResolver, globalScope, options, recordJsonLog); + return check(sourceModule, requireCycles, builtinTypes, iceHandler, moduleResolver, fileResolver, parentScope, options, recordJsonLog); } ModulePtr check(const SourceModule& sourceModule, const std::vector& requireCycles, NotNull builtinTypes, NotNull iceHandler, NotNull moduleResolver, NotNull fileResolver, - const ScopePtr& globalScope, FrontendOptions options, bool recordJsonLog) + const ScopePtr& parentScope, FrontendOptions options, bool recordJsonLog) { ModulePtr result = std::make_shared(); result->reduction = std::make_unique(NotNull{&result->internalTypes}, builtinTypes, iceHandler); @@ -868,7 +912,7 @@ ModulePtr check(const SourceModule& sourceModule, const std::vector requireCycles, bool forAutocomplete, bool recordJsonLog) +ModulePtr Frontend::check(const SourceModule& sourceModule, Mode mode, std::vector requireCycles, + std::optional environmentScope, bool forAutocomplete, bool recordJsonLog, TypeCheckLimits typeCheckLimits) { - return Luau::check(sourceModule, requireCycles, builtinTypes, NotNull{&iceHandler}, - NotNull{forAutocomplete ? &moduleResolverForAutocomplete : &moduleResolver}, NotNull{fileResolver}, - forAutocomplete ? globalsForAutocomplete.globalScope : globals.globalScope, options, recordJsonLog); + if (FFlag::DebugLuauDeferredConstraintResolution && mode == Mode::Strict) + { + return Luau::check(sourceModule, requireCycles, builtinTypes, NotNull{&iceHandler}, + NotNull{forAutocomplete ? &moduleResolverForAutocomplete : &moduleResolver}, NotNull{fileResolver}, + environmentScope ? *environmentScope : globals.globalScope, options, recordJsonLog); + } + else + { + LUAU_ASSERT(FFlag::LuauOnDemandTypecheckers); + + TypeChecker typeChecker(globals.globalScope, forAutocomplete ? &moduleResolverForAutocomplete : &moduleResolver, builtinTypes, &iceHandler); + + if (prepareModuleScope) + { + typeChecker.prepareModuleScope = [this, forAutocomplete](const ModuleName& name, const ScopePtr& scope) { + prepareModuleScope(name, scope, forAutocomplete); + }; + } + + typeChecker.requireCycles = requireCycles; + typeChecker.finishTime = typeCheckLimits.finishTime; + typeChecker.instantiationChildLimit = typeCheckLimits.instantiationChildLimit; + typeChecker.unifierIterationLimit = typeCheckLimits.unifierIterationLimit; + + return typeChecker.check(sourceModule, mode, environmentScope); + } } // Read AST into sourceModules if necessary. Trace require()s. Report parse errors. diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index fd9484038..830aaf754 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -20,6 +20,7 @@ LUAU_FASTFLAGVARIABLE(LuauClonePublicInterfaceLess2, false); LUAU_FASTFLAG(LuauSubstitutionReentrant); LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution); LUAU_FASTFLAG(LuauSubstitutionFixMissingFields); +LUAU_FASTFLAGVARIABLE(LuauCopyExportedTypes, false); namespace Luau { @@ -37,14 +38,14 @@ static bool contains(Position pos, Comment comment) return false; } -bool isWithinComment(const SourceModule& sourceModule, Position pos) +static bool isWithinComment(const std::vector& commentLocations, Position pos) { - auto iter = std::lower_bound(sourceModule.commentLocations.begin(), sourceModule.commentLocations.end(), - Comment{Lexeme::Comment, Location{pos, pos}}, [](const Comment& a, const Comment& b) { + auto iter = std::lower_bound( + commentLocations.begin(), commentLocations.end(), Comment{Lexeme::Comment, Location{pos, pos}}, [](const Comment& a, const Comment& b) { return a.location.end < b.location.end; }); - if (iter == sourceModule.commentLocations.end()) + if (iter == commentLocations.end()) return false; if (contains(pos, *iter)) @@ -53,12 +54,22 @@ bool isWithinComment(const SourceModule& sourceModule, Position pos) // Due to the nature of std::lower_bound, it is possible that iter points at a comment that ends // at pos. We'll try the next comment, if it exists. ++iter; - if (iter == sourceModule.commentLocations.end()) + if (iter == commentLocations.end()) return false; return contains(pos, *iter); } +bool isWithinComment(const SourceModule& sourceModule, Position pos) +{ + return isWithinComment(sourceModule.commentLocations, pos); +} + +bool isWithinComment(const ParseResult& result, Position pos) +{ + return isWithinComment(result.commentLocations, pos); +} + struct ClonePublicInterface : Substitution { NotNull builtinTypes; @@ -227,7 +238,7 @@ void Module::clonePublicInterface(NotNull builtinTypes, InternalEr // Copy external stuff over to Module itself this->returnType = moduleScope->returnType; - if (FFlag::DebugLuauDeferredConstraintResolution) + if (FFlag::DebugLuauDeferredConstraintResolution || FFlag::LuauCopyExportedTypes) this->exportedTypeBindings = moduleScope->exportedTypeBindings; else this->exportedTypeBindings = std::move(moduleScope->exportedTypeBindings); diff --git a/Analysis/src/Type.cpp b/Analysis/src/Type.cpp index d70f17f57..528541083 100644 --- a/Analysis/src/Type.cpp +++ b/Analysis/src/Type.cpp @@ -337,7 +337,16 @@ bool isSubset(const UnionType& super, const UnionType& sub) return true; } +bool hasPrimitiveTypeInIntersection(TypeId ty, PrimitiveType::Type primTy) +{ + TypeId tf = follow(ty); + if (isPrim(tf, primTy)) + return true; + for (auto t : flattenIntersection(tf)) + return isPrim(follow(t), primTy); + return false; +} // When typechecking an assignment `x = e`, we typecheck `x:T` and `e:U`, // then instantiate U if `isGeneric(U)` is true, and `maybeGeneric(T)` is false. bool isGeneric(TypeId ty) diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index c7d30f437..6e76af042 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -1160,11 +1160,7 @@ struct TypeChecker2 visit(expr, RValue); TypeId leftType = stripFromNilAndReport(lookupType(expr), location); - const NormalizedType* norm = normalizer.normalize(leftType); - if (!norm) - reportError(NormalizationTooComplex{}, location); - - checkIndexTypeFromType(leftType, *norm, propName, location, context); + checkIndexTypeFromType(leftType, propName, location, context); } void visit(AstExprIndexName* indexName, ValueContext context) @@ -2033,8 +2029,16 @@ struct TypeChecker2 reportError(std::move(e)); } - void checkIndexTypeFromType(TypeId tableTy, const NormalizedType& norm, const std::string& prop, const Location& location, ValueContext context) + // If the provided type does not have the named property, report an error. + void checkIndexTypeFromType(TypeId tableTy, const std::string& prop, const Location& location, ValueContext context) { + const NormalizedType* norm = normalizer.normalize(tableTy); + if (!norm) + { + reportError(NormalizationTooComplex{}, location); + return; + } + bool foundOneProp = false; std::vector typesMissingTheProp; @@ -2042,49 +2046,50 @@ struct TypeChecker2 if (!normalizer.isInhabited(ty)) return; - bool found = hasIndexTypeFromType(ty, prop, location); + std::unordered_set seen; + bool found = hasIndexTypeFromType(ty, prop, location, seen); foundOneProp |= found; if (!found) typesMissingTheProp.push_back(ty); }; - fetch(norm.tops); - fetch(norm.booleans); + fetch(norm->tops); + fetch(norm->booleans); if (FFlag::LuauNegatedClassTypes) { - for (const auto& [ty, _negations] : norm.classes.classes) + for (const auto& [ty, _negations] : norm->classes.classes) { fetch(ty); } } else { - for (TypeId ty : norm.DEPRECATED_classes) + for (TypeId ty : norm->DEPRECATED_classes) fetch(ty); } - fetch(norm.errors); - fetch(norm.nils); - fetch(norm.numbers); - if (!norm.strings.isNever()) + fetch(norm->errors); + fetch(norm->nils); + fetch(norm->numbers); + if (!norm->strings.isNever()) fetch(builtinTypes->stringType); - fetch(norm.threads); - for (TypeId ty : norm.tables) + fetch(norm->threads); + for (TypeId ty : norm->tables) fetch(ty); - if (norm.functions.isTop) + if (norm->functions.isTop) fetch(builtinTypes->functionType); - else if (!norm.functions.isNever()) + else if (!norm->functions.isNever()) { - if (norm.functions.parts.size() == 1) - fetch(norm.functions.parts.front()); + if (norm->functions.parts.size() == 1) + fetch(norm->functions.parts.front()); else { std::vector parts; - parts.insert(parts.end(), norm.functions.parts.begin(), norm.functions.parts.end()); + parts.insert(parts.end(), norm->functions.parts.begin(), norm->functions.parts.end()); fetch(testArena.addType(IntersectionType{std::move(parts)})); } } - for (const auto& [tyvar, intersect] : norm.tyvars) + for (const auto& [tyvar, intersect] : norm->tyvars) { if (get(intersect->tops)) { @@ -2110,8 +2115,15 @@ struct TypeChecker2 } } - bool hasIndexTypeFromType(TypeId ty, const std::string& prop, const Location& location) + bool hasIndexTypeFromType(TypeId ty, const std::string& prop, const Location& location, std::unordered_set& seen) { + // If we have already encountered this type, we must assume that some + // other codepath will do the right thing and signal false if the + // property is not present. + const bool isUnseen = seen.insert(ty).second; + if (!isUnseen) + return true; + if (get(ty) || get(ty) || get(ty)) return true; @@ -2136,10 +2148,12 @@ struct TypeChecker2 else if (const ClassType* cls = get(ty)) return bool(lookupClassProp(cls, prop)); else if (const UnionType* utv = get(ty)) - ice.ice("getIndexTypeFromTypeHelper cannot take a UnionType"); + return std::all_of(begin(utv), end(utv), [&](TypeId part) { + return hasIndexTypeFromType(part, prop, location, seen); + }); else if (const IntersectionType* itv = get(ty)) return std::any_of(begin(itv), end(itv), [&](TypeId part) { - return hasIndexTypeFromType(part, prop, location); + return hasIndexTypeFromType(part, prop, location, seen); }); else return false; diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index acf70fec1..7f366a204 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -35,14 +35,13 @@ LUAU_FASTFLAG(LuauKnowsTheDataModel3) LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) LUAU_FASTFLAGVARIABLE(LuauReturnAnyInsteadOfICE, false) // Eventually removed as false. LUAU_FASTFLAGVARIABLE(DebugLuauSharedSelf, false) -LUAU_FASTFLAGVARIABLE(LuauTryhardAnd, false) LUAU_FASTFLAG(LuauInstantiateInSubtyping) LUAU_FASTFLAG(LuauNegatedClassTypes) LUAU_FASTFLAGVARIABLE(LuauAllowIndexClassParameters, false) LUAU_FASTFLAG(LuauUninhabitedSubAnything2) +LUAU_FASTFLAG(LuauOccursIsntAlwaysFailure) LUAU_FASTFLAGVARIABLE(LuauTypecheckTypeguards, false) LUAU_FASTFLAGVARIABLE(LuauTinyControlFlowAnalysis, false) -LUAU_FASTFLAGVARIABLE(LuauReducingAndOr, false) namespace Luau { @@ -1623,9 +1622,28 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& ty TypeId& bindingType = bindingsMap[name].type; - if (unify(ty, bindingType, aliasScope, typealias.location)) - bindingType = ty; + if (!FFlag::LuauOccursIsntAlwaysFailure) + { + if (unify(ty, bindingType, aliasScope, typealias.location)) + bindingType = ty; + return ControlFlow::None; + } + + unify(ty, bindingType, aliasScope, typealias.location); + + // It is possible for this unification to succeed but for + // `bindingType` still to be free For example, in + // `type T = T|T`, we generate a fresh free type `X`, and then + // unify `X` with `X|X`, which succeeds without binding `X` to + // anything, since `X <: X|X` + if (bindingType->ty.get_if()) + { + ty = errorRecoveryType(aliasScope); + unify(ty, bindingType, aliasScope, typealias.location); + reportError(TypeError{typealias.location, OccursCheckFailed{}}); + } + bindingType = ty; return ControlFlow::None; } @@ -2848,7 +2866,7 @@ TypeId TypeChecker::checkRelationalOperation( { return lhsType; } - else if (FFlag::LuauTryhardAnd) + else { // If lhs is free, we can't tell which 'falsy' components it has, if any if (get(lhsType)) @@ -2860,14 +2878,11 @@ TypeId TypeChecker::checkRelationalOperation( { LUAU_ASSERT(oty); - if (FFlag::LuauReducingAndOr) - { - // Perform a limited form of type reduction for booleans - if (isPrim(*oty, PrimitiveType::Boolean) && get(get(follow(rhsType)))) - return booleanType; - if (isPrim(rhsType, PrimitiveType::Boolean) && get(get(follow(*oty)))) - return booleanType; - } + // Perform a limited form of type reduction for booleans + if (isPrim(*oty, PrimitiveType::Boolean) && get(get(follow(rhsType)))) + return booleanType; + if (isPrim(rhsType, PrimitiveType::Boolean) && get(get(follow(*oty)))) + return booleanType; return unionOfTypes(*oty, rhsType, scope, expr.location, false); } @@ -2876,16 +2891,12 @@ TypeId TypeChecker::checkRelationalOperation( return rhsType; } } - else - { - return unionOfTypes(rhsType, booleanType, scope, expr.location, false); - } case AstExprBinary::Or: if (lhsIsAny) { return lhsType; } - else if (FFlag::LuauTryhardAnd) + else { auto [oty, notNever] = pickTypesFromSense(lhsType, true, neverType); // Filter out truthy types @@ -2893,14 +2904,11 @@ TypeId TypeChecker::checkRelationalOperation( { LUAU_ASSERT(oty); - if (FFlag::LuauReducingAndOr) - { - // Perform a limited form of type reduction for booleans - if (isPrim(*oty, PrimitiveType::Boolean) && get(get(follow(rhsType)))) - return booleanType; - if (isPrim(rhsType, PrimitiveType::Boolean) && get(get(follow(*oty)))) - return booleanType; - } + // Perform a limited form of type reduction for booleans + if (isPrim(*oty, PrimitiveType::Boolean) && get(get(follow(rhsType)))) + return booleanType; + if (isPrim(rhsType, PrimitiveType::Boolean) && get(get(follow(*oty)))) + return booleanType; return unionOfTypes(*oty, rhsType, scope, expr.location); } @@ -2909,10 +2917,6 @@ TypeId TypeChecker::checkRelationalOperation( return rhsType; } } - else - { - return unionOfTypes(lhsType, rhsType, scope, expr.location); - } default: LUAU_ASSERT(0); ice(format("checkRelationalOperation called with incorrect binary expression '%s'", toString(expr.op).c_str()), expr.location); diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 642aa399f..3f4e34f6d 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -19,8 +19,10 @@ LUAU_FASTINT(LuauTypeInferTypePackLoopLimit) LUAU_FASTFLAG(LuauErrorRecoveryType) LUAU_FASTFLAGVARIABLE(LuauInstantiateInSubtyping, false) LUAU_FASTFLAGVARIABLE(LuauUninhabitedSubAnything2, false) +LUAU_FASTFLAGVARIABLE(LuauVariadicAnyCanBeGeneric, false) LUAU_FASTFLAGVARIABLE(LuauMaintainScopesInUnifier, false) LUAU_FASTFLAGVARIABLE(LuauTransitiveSubtyping, false) +LUAU_FASTFLAGVARIABLE(LuauOccursIsntAlwaysFailure, false) LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) LUAU_FASTFLAG(LuauNormalizeBlockedTypes) @@ -431,14 +433,14 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool if (superFree && subFree && subsumes(useScopes, superFree, subFree)) { - if (!occursCheck(subTy, superTy)) + if (!occursCheck(subTy, superTy, /* reversed = */ false)) log.replace(subTy, BoundType(superTy)); return; } else if (superFree && subFree) { - if (!occursCheck(superTy, subTy)) + if (!occursCheck(superTy, subTy, /* reversed = */ true)) { if (subsumes(useScopes, superFree, subFree)) { @@ -461,7 +463,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool return; } - if (!occursCheck(superTy, subTy)) + if (!occursCheck(superTy, subTy, /* reversed = */ true)) { promoteTypeLevels(log, types, superFree->level, superFree->scope, useScopes, subTy); @@ -487,7 +489,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool return; } - if (!occursCheck(subTy, superTy)) + if (!occursCheck(subTy, superTy, /* reversed = */ false)) { promoteTypeLevels(log, types, subFree->level, subFree->scope, useScopes, superTy); log.replace(subTy, BoundType(superTy)); @@ -1593,7 +1595,7 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal if (log.getMutable(superTp)) { - if (!occursCheck(superTp, subTp)) + if (!occursCheck(superTp, subTp, /* reversed = */ true)) { Widen widen{types, builtinTypes}; log.replace(superTp, Unifiable::Bound(widen(subTp))); @@ -1601,7 +1603,7 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal } else if (log.getMutable(subTp)) { - if (!occursCheck(subTp, superTp)) + if (!occursCheck(subTp, superTp, /* reversed = */ false)) { log.replace(subTp, Unifiable::Bound(superTp)); } @@ -2585,13 +2587,14 @@ static void queueTypePack(std::vector& queue, DenseHashSet& void Unifier::tryUnifyVariadics(TypePackId subTp, TypePackId superTp, bool reversed, int subOffset) { const VariadicTypePack* superVariadic = log.getMutable(superTp); + const TypeId variadicTy = follow(superVariadic->ty); if (!superVariadic) ice("passed non-variadic pack to tryUnifyVariadics"); if (const VariadicTypePack* subVariadic = log.get(subTp)) { - tryUnify_(reversed ? superVariadic->ty : subVariadic->ty, reversed ? subVariadic->ty : superVariadic->ty); + tryUnify_(reversed ? variadicTy : subVariadic->ty, reversed ? subVariadic->ty : variadicTy); } else if (log.get(subTp)) { @@ -2602,7 +2605,7 @@ void Unifier::tryUnifyVariadics(TypePackId subTp, TypePackId superTp, bool rever while (subIter != subEnd) { - tryUnify_(reversed ? superVariadic->ty : *subIter, reversed ? *subIter : superVariadic->ty); + tryUnify_(reversed ? variadicTy : *subIter, reversed ? *subIter : variadicTy); ++subIter; } @@ -2615,7 +2618,7 @@ void Unifier::tryUnifyVariadics(TypePackId subTp, TypePackId superTp, bool rever } else if (const VariadicTypePack* vtp = get(tail)) { - tryUnify_(vtp->ty, superVariadic->ty); + tryUnify_(vtp->ty, variadicTy); } else if (get(tail)) { @@ -2631,6 +2634,10 @@ void Unifier::tryUnifyVariadics(TypePackId subTp, TypePackId superTp, bool rever } } } + else if (FFlag::LuauVariadicAnyCanBeGeneric && get(variadicTy) && log.get(subTp)) + { + // Nothing to do. This is ok. + } else { reportError(location, GenericError{"Failed to unify variadic packs"}); @@ -2751,11 +2758,42 @@ TxnLog Unifier::combineLogsIntoUnion(std::vector logs) return result; } -bool Unifier::occursCheck(TypeId needle, TypeId haystack) +bool Unifier::occursCheck(TypeId needle, TypeId haystack, bool reversed) { sharedState.tempSeenTy.clear(); - return occursCheck(sharedState.tempSeenTy, needle, haystack); + bool occurs = occursCheck(sharedState.tempSeenTy, needle, haystack); + + if (occurs && FFlag::LuauOccursIsntAlwaysFailure) + { + Unifier innerState = makeChildUnifier(); + if (const UnionType* ut = get(haystack)) + { + if (reversed) + innerState.tryUnifyUnionWithType(haystack, ut, needle); + else + innerState.tryUnifyTypeWithUnion(needle, haystack, ut, /* cacheEnabled = */ false, /* isFunction = */ false); + } + else if (const IntersectionType* it = get(haystack)) + { + if (reversed) + innerState.tryUnifyIntersectionWithType(haystack, it, needle, /* cacheEnabled = */ false, /* isFunction = */ false); + else + innerState.tryUnifyTypeWithIntersection(needle, haystack, it); + } + else + { + innerState.failure = true; + } + + if (innerState.failure) + { + reportError(location, OccursCheckFailed{}); + log.replace(needle, *builtinTypes->errorRecoveryType()); + } + } + + return occurs; } bool Unifier::occursCheck(DenseHashSet& seen, TypeId needle, TypeId haystack) @@ -2785,8 +2823,11 @@ bool Unifier::occursCheck(DenseHashSet& seen, TypeId needle, TypeId hays if (needle == haystack) { - reportError(location, OccursCheckFailed{}); - log.replace(needle, *builtinTypes->errorRecoveryType()); + if (!FFlag::LuauOccursIsntAlwaysFailure) + { + reportError(location, OccursCheckFailed{}); + log.replace(needle, *builtinTypes->errorRecoveryType()); + } return true; } @@ -2807,11 +2848,19 @@ bool Unifier::occursCheck(DenseHashSet& seen, TypeId needle, TypeId hays return occurrence; } -bool Unifier::occursCheck(TypePackId needle, TypePackId haystack) +bool Unifier::occursCheck(TypePackId needle, TypePackId haystack, bool reversed) { sharedState.tempSeenTp.clear(); - return occursCheck(sharedState.tempSeenTp, needle, haystack); + bool occurs = occursCheck(sharedState.tempSeenTp, needle, haystack); + + if (occurs && FFlag::LuauOccursIsntAlwaysFailure) + { + reportError(location, OccursCheckFailed{}); + log.replace(needle, *builtinTypes->errorRecoveryTypePack()); + } + + return occurs; } bool Unifier::occursCheck(DenseHashSet& seen, TypePackId needle, TypePackId haystack) @@ -2836,8 +2885,11 @@ bool Unifier::occursCheck(DenseHashSet& seen, TypePackId needle, Typ { if (needle == haystack) { - reportError(location, OccursCheckFailed{}); - log.replace(needle, *builtinTypes->errorRecoveryTypePack()); + if (!FFlag::LuauOccursIsntAlwaysFailure) + { + reportError(location, OccursCheckFailed{}); + log.replace(needle, *builtinTypes->errorRecoveryTypePack()); + } return true; } diff --git a/Ast/src/StringUtils.cpp b/Ast/src/StringUtils.cpp index 11e0076a2..343c553c3 100644 --- a/Ast/src/StringUtils.cpp +++ b/Ast/src/StringUtils.cpp @@ -167,7 +167,9 @@ size_t editDistance(std::string_view a, std::string_view b) for (size_t y = 1; y <= b.size(); ++y) { - size_t x1 = seenCharToRow[b[y - 1]]; + // The value of b[N] can be negative with unicode characters + unsigned char bSeenCharIndex = static_cast(b[y - 1]); + size_t x1 = seenCharToRow[bSeenCharIndex]; size_t y1 = lastMatchedY; size_t cost = 1; @@ -187,7 +189,9 @@ size_t editDistance(std::string_view a, std::string_view b) distances[getPos(x + 1, y + 1)] = std::min(std::min(insertion, deletion), std::min(substitution, transposition)); } - seenCharToRow[a[x - 1]] = x; + // The value of a[N] can be negative with unicode characters + unsigned char aSeenCharIndex = static_cast(a[x - 1]); + seenCharToRow[aSeenCharIndex] = x; } return distances[getPos(a.size() + 1, b.size() + 1)]; diff --git a/CodeGen/include/Luau/AddressA64.h b/CodeGen/include/Luau/AddressA64.h index 2796ef708..acb64e390 100644 --- a/CodeGen/include/Luau/AddressA64.h +++ b/CodeGen/include/Luau/AddressA64.h @@ -29,7 +29,7 @@ struct AddressA64 // For example, ldr x0, [reg+imm] is limited to 8 KB offsets assuming imm is divisible by 8, but loading into w0 reduces the range to 4 KB static constexpr size_t kMaxOffset = 1023; - AddressA64(RegisterA64 base, int off = 0) + constexpr AddressA64(RegisterA64 base, int off = 0) : kind(AddressKindA64::imm) , base(base) , offset(xzr) @@ -38,7 +38,7 @@ struct AddressA64 LUAU_ASSERT(base.kind == KindA64::x || base == sp); } - AddressA64(RegisterA64 base, RegisterA64 offset) + constexpr AddressA64(RegisterA64 base, RegisterA64 offset) : kind(AddressKindA64::reg) , base(base) , offset(offset) diff --git a/CodeGen/include/Luau/AssemblyBuilderA64.h b/CodeGen/include/Luau/AssemblyBuilderA64.h index def4d0c0c..42f5f8a68 100644 --- a/CodeGen/include/Luau/AssemblyBuilderA64.h +++ b/CodeGen/include/Luau/AssemblyBuilderA64.h @@ -49,17 +49,25 @@ class AssemblyBuilderA64 void cmp(RegisterA64 src1, RegisterA64 src2); void cmp(RegisterA64 src1, uint16_t src2); void csel(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, ConditionA64 cond); + void cset(RegisterA64 dst, ConditionA64 cond); // Bitwise - // TODO: support immediate arguments (they have odd encoding and forbid many values) - // TODO: support bic (andnot) // TODO: support shifts // TODO: support bitfield ops void and_(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); void orr(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); void eor(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); + void bic(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); + void tst(RegisterA64 src1, RegisterA64 src2); void mvn(RegisterA64 dst, RegisterA64 src); + // Bitwise with immediate + // Note: immediate must have a single contiguous sequence of 1 bits set of length 1..31 + void and_(RegisterA64 dst, RegisterA64 src1, uint32_t src2); + void orr(RegisterA64 dst, RegisterA64 src1, uint32_t src2); + void eor(RegisterA64 dst, RegisterA64 src1, uint32_t src2); + void tst(RegisterA64 src1, uint32_t src2); + // Shifts void lsl(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); void lsr(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); @@ -168,7 +176,7 @@ class AssemblyBuilderA64 private: // Instruction archetypes void place0(const char* name, uint32_t word); - void placeSR3(const char* name, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, uint8_t op, int shift = 0); + void placeSR3(const char* name, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, uint8_t op, int shift = 0, int N = 0); void placeSR2(const char* name, RegisterA64 dst, RegisterA64 src, uint8_t op, uint8_t op2 = 0); void placeR3(const char* name, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, uint8_t op, uint8_t op2); void placeR1(const char* name, RegisterA64 dst, RegisterA64 src, uint32_t op); @@ -181,8 +189,9 @@ class AssemblyBuilderA64 void placeADR(const char* name, RegisterA64 src, uint8_t op); void placeADR(const char* name, RegisterA64 src, uint8_t op, Label& label); void placeP(const char* name, RegisterA64 dst1, RegisterA64 dst2, AddressA64 src, uint8_t op, uint8_t opc, int sizelog); - void placeCS(const char* name, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, ConditionA64 cond, uint8_t op, uint8_t opc); + void placeCS(const char* name, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, ConditionA64 cond, uint8_t op, uint8_t opc, int invert = 0); void placeFCMP(const char* name, RegisterA64 src1, RegisterA64 src2, uint8_t op, uint8_t opc); + void placeBM(const char* name, RegisterA64 dst, RegisterA64 src1, uint32_t src2, uint8_t op); void place(uint32_t word); diff --git a/CodeGen/include/Luau/IrCallWrapperX64.h b/CodeGen/include/Luau/IrCallWrapperX64.h index 724d46243..c403d189b 100644 --- a/CodeGen/include/Luau/IrCallWrapperX64.h +++ b/CodeGen/include/Luau/IrCallWrapperX64.h @@ -41,12 +41,14 @@ class IrCallWrapperX64 void call(const OperandX64& func); + RegisterX64 suggestNextArgumentRegister(SizeX64 size) const; + IrRegAllocX64& regs; AssemblyBuilderX64& build; uint32_t instIdx = ~0u; private: - void assignTargetRegisters(); + OperandX64 getNextArgumentTarget(SizeX64 size) const; void countRegisterUses(); CallArgument* findNonInterferingArgument(); bool interferesWithOperand(const OperandX64& op, RegisterX64 reg) const; @@ -67,6 +69,9 @@ class IrCallWrapperX64 std::array args; int argCount = 0; + int gprPos = 0; + int xmmPos = 0; + OperandX64 funcOp; // Internal counters for remaining register use counts diff --git a/CodeGen/include/Luau/IrData.h b/CodeGen/include/Luau/IrData.h index fcf29adb1..486a0135a 100644 --- a/CodeGen/include/Luau/IrData.h +++ b/CodeGen/include/Luau/IrData.h @@ -155,7 +155,7 @@ enum class IrCmd : uint8_t // Compute Luau 'not' operation on destructured TValue // A: tag - // B: double + // B: int (value) NOT_ANY, // TODO: boolean specialization will be useful // Unconditional jump @@ -233,7 +233,7 @@ enum class IrCmd : uint8_t // Try to get pointer to tag method TValue inside the table's metatable or jump if there is no such value or metatable // A: table - // B: int + // B: int (TMS enum) // C: block TRY_CALL_FASTGETTM, @@ -256,8 +256,8 @@ enum class IrCmd : uint8_t // B: Rn (result start) // C: Rn (argument start) // D: Rn or Kn or a boolean that's false (optional second argument) - // E: int (argument count or -1 to use all arguments up to stack top) - // F: int (result count or -1 to preserve all results and adjust stack top) + // E: int (argument count) + // F: int (result count) FASTCALL, // Call the fastcall builtin function @@ -517,8 +517,10 @@ enum class IrCmd : uint8_t FALLBACK_FORGPREP, // Instruction that passes value through, it is produced by constant folding and users substitute it with the value + // When operand location is set, updates the tracked location of the value in memory SUBSTITUTE, // A: operand of any type + // B: Rn/Kn/none (location of operand in memory; optional) }; enum class IrConstKind : uint8_t @@ -694,6 +696,9 @@ struct IrFunction std::vector bcMapping; + // For each instruction, an operand that can be used to recompute the calue + std::vector valueRestoreOps; + Proto* proto = nullptr; CfgInfo cfg; @@ -829,19 +834,40 @@ struct IrFunction return value.valueDouble; } - uint32_t getBlockIndex(const IrBlock& block) + uint32_t getBlockIndex(const IrBlock& block) const { // Can only be called with blocks from our vector LUAU_ASSERT(&block >= blocks.data() && &block <= blocks.data() + blocks.size()); return uint32_t(&block - blocks.data()); } - uint32_t getInstIndex(const IrInst& inst) + uint32_t getInstIndex(const IrInst& inst) const { // Can only be called with instructions from our vector LUAU_ASSERT(&inst >= instructions.data() && &inst <= instructions.data() + instructions.size()); return uint32_t(&inst - instructions.data()); } + + void recordRestoreOp(uint32_t instIdx, IrOp location) + { + if (instIdx >= valueRestoreOps.size()) + valueRestoreOps.resize(instIdx + 1); + + valueRestoreOps[instIdx] = location; + } + + IrOp findRestoreOp(uint32_t instIdx) const + { + if (instIdx >= valueRestoreOps.size()) + return {}; + + return valueRestoreOps[instIdx]; + } + + IrOp findRestoreOp(const IrInst& inst) const + { + return findRestoreOp(getInstIndex(inst)); + } }; inline IrCondition conditionOp(IrOp op) diff --git a/CodeGen/include/Luau/IrRegAllocX64.h b/CodeGen/include/Luau/IrRegAllocX64.h index dc7b48c6b..f83cc2208 100644 --- a/CodeGen/include/Luau/IrRegAllocX64.h +++ b/CodeGen/include/Luau/IrRegAllocX64.h @@ -20,7 +20,9 @@ constexpr uint8_t kNoStackSlot = 0xff; struct IrSpillX64 { uint32_t instIdx = 0; - bool useDoubleSlot = 0; + IrValueKind valueKind = IrValueKind::Unknown; + + unsigned spillId = 0; // Spill location can be a stack location or be empty // When it's empty, it means that instruction value can be rematerialized @@ -33,12 +35,8 @@ struct IrRegAllocX64 { IrRegAllocX64(AssemblyBuilderX64& build, IrFunction& function); - RegisterX64 allocGprReg(SizeX64 preferredSize, uint32_t instIdx); - RegisterX64 allocXmmReg(uint32_t instIdx); - - RegisterX64 allocGprRegOrReuse(SizeX64 preferredSize, uint32_t instIdx, std::initializer_list oprefs); - RegisterX64 allocXmmRegOrReuse(uint32_t instIdx, std::initializer_list oprefs); - + RegisterX64 allocReg(SizeX64 size, uint32_t instIdx); + RegisterX64 allocRegOrReuse(SizeX64 size, uint32_t instIdx, std::initializer_list oprefs); RegisterX64 takeReg(RegisterX64 reg, uint32_t instIdx); void freeReg(RegisterX64 reg); @@ -49,6 +47,12 @@ struct IrRegAllocX64 bool shouldFreeGpr(RegisterX64 reg) const; + unsigned findSpillStackSlot(IrValueKind valueKind); + + IrOp getRestoreOp(const IrInst& inst) const; + bool hasRestoreOp(const IrInst& inst) const; + OperandX64 getRestoreAddress(const IrInst& inst, IrOp restoreOp); + // Register used by instruction is about to be freed, have to find a way to restore value later void preserve(IrInst& inst); @@ -74,6 +78,7 @@ struct IrRegAllocX64 std::bitset<256> usedSpillSlots; unsigned maxUsedSlot = 0; + unsigned nextSpillId = 1; std::vector spills; }; @@ -107,10 +112,8 @@ struct ScopedSpills ScopedSpills(const ScopedSpills&) = delete; ScopedSpills& operator=(const ScopedSpills&) = delete; - bool wasSpilledBefore(const IrSpillX64& spill) const; - IrRegAllocX64& owner; - std::vector snapshot; + unsigned startSpillId = 0; }; } // namespace X64 diff --git a/CodeGen/include/Luau/IrUtils.h b/CodeGen/include/Luau/IrUtils.h index 09c55c799..136ce3b8b 100644 --- a/CodeGen/include/Luau/IrUtils.h +++ b/CodeGen/include/Luau/IrUtils.h @@ -200,7 +200,7 @@ void replace(IrFunction& function, IrOp& original, IrOp replacement); void replace(IrFunction& function, IrBlock& block, uint32_t instIdx, IrInst replacement); // Replace instruction with a different value (using IrCmd::SUBSTITUTE) -void substitute(IrFunction& function, IrInst& inst, IrOp replacement); +void substitute(IrFunction& function, IrInst& inst, IrOp replacement, IrOp location = {}); // Replace instruction arguments that point to substitutions with target values void applySubstitutions(IrFunction& function, IrOp& op); diff --git a/CodeGen/include/Luau/RegisterA64.h b/CodeGen/include/Luau/RegisterA64.h index 99e62958d..c3a9ae03f 100644 --- a/CodeGen/include/Luau/RegisterA64.h +++ b/CodeGen/include/Luau/RegisterA64.h @@ -46,6 +46,18 @@ constexpr RegisterA64 castReg(KindA64 kind, RegisterA64 reg) return RegisterA64{kind, reg.index}; } +// This is equivalent to castReg(KindA64::x), but is separate because it implies different semantics +// Specifically, there are cases when it's useful to treat a wN register as an xN register *after* it has been assigned a value +// Since all A64 instructions that write to wN implicitly zero the top half, this works when we need zero extension semantics +// Crucially, this is *not* safe on an ABI boundary - an int parameter in wN register may have anything in its top half in certain cases +// However, as long as our codegen doesn't use 32-bit truncation by using castReg x=>w, we can safely rely on this. +constexpr RegisterA64 zextReg(RegisterA64 reg) +{ + LUAU_ASSERT(reg.kind == KindA64::w); + + return RegisterA64{KindA64::x, reg.index}; +} + constexpr RegisterA64 noreg{KindA64::none, 0}; constexpr RegisterA64 w0{KindA64::w, 0}; diff --git a/CodeGen/include/Luau/RegisterX64.h b/CodeGen/include/Luau/RegisterX64.h index 9d76b1169..7fa976077 100644 --- a/CodeGen/include/Luau/RegisterX64.h +++ b/CodeGen/include/Luau/RegisterX64.h @@ -46,6 +46,18 @@ constexpr RegisterX64 al{SizeX64::byte, 0}; constexpr RegisterX64 cl{SizeX64::byte, 1}; constexpr RegisterX64 dl{SizeX64::byte, 2}; constexpr RegisterX64 bl{SizeX64::byte, 3}; +constexpr RegisterX64 spl{SizeX64::byte, 4}; +constexpr RegisterX64 bpl{SizeX64::byte, 5}; +constexpr RegisterX64 sil{SizeX64::byte, 6}; +constexpr RegisterX64 dil{SizeX64::byte, 7}; +constexpr RegisterX64 r8b{SizeX64::byte, 8}; +constexpr RegisterX64 r9b{SizeX64::byte, 9}; +constexpr RegisterX64 r10b{SizeX64::byte, 10}; +constexpr RegisterX64 r11b{SizeX64::byte, 11}; +constexpr RegisterX64 r12b{SizeX64::byte, 12}; +constexpr RegisterX64 r13b{SizeX64::byte, 13}; +constexpr RegisterX64 r14b{SizeX64::byte, 14}; +constexpr RegisterX64 r15b{SizeX64::byte, 15}; constexpr RegisterX64 eax{SizeX64::dword, 0}; constexpr RegisterX64 ecx{SizeX64::dword, 1}; diff --git a/CodeGen/include/Luau/UnwindBuilder.h b/CodeGen/include/Luau/UnwindBuilder.h index 98e604982..8fe55ba61 100644 --- a/CodeGen/include/Luau/UnwindBuilder.h +++ b/CodeGen/include/Luau/UnwindBuilder.h @@ -11,6 +11,9 @@ namespace Luau namespace CodeGen { +// This value is used in 'finishFunction' to mark the function that spans to the end of the whole code block +static uint32_t kFullBlockFuncton = ~0u; + class UnwindBuilder { public: @@ -19,19 +22,22 @@ class UnwindBuilder virtual void setBeginOffset(size_t beginOffset) = 0; virtual size_t getBeginOffset() const = 0; - virtual void start() = 0; + virtual void startInfo() = 0; + virtual void startFunction() = 0; virtual void spill(int espOffset, X64::RegisterX64 reg) = 0; virtual void save(X64::RegisterX64 reg) = 0; virtual void allocStack(int size) = 0; virtual void setupFrameReg(X64::RegisterX64 reg, int espOffset) = 0; + virtual void finishFunction(uint32_t beginOffset, uint32_t endOffset) = 0; - virtual void finish() = 0; + virtual void finishInfo() = 0; virtual size_t getSize() const = 0; + virtual size_t getFunctionCount() const = 0; // This will place the unwinding data at the target address and might update values of some fields - virtual void finalize(char* target, void* funcAddress, size_t funcSize) const = 0; + virtual void finalize(char* target, size_t offset, void* funcAddress, size_t funcSize) const = 0; }; } // namespace CodeGen diff --git a/CodeGen/include/Luau/UnwindBuilderDwarf2.h b/CodeGen/include/Luau/UnwindBuilderDwarf2.h index 972f7423b..9f862d23f 100644 --- a/CodeGen/include/Luau/UnwindBuilderDwarf2.h +++ b/CodeGen/include/Luau/UnwindBuilderDwarf2.h @@ -4,34 +4,48 @@ #include "Luau/RegisterX64.h" #include "UnwindBuilder.h" +#include + namespace Luau { namespace CodeGen { +struct UnwindFunctionDwarf2 +{ + uint32_t beginOffset; + uint32_t endOffset; + uint32_t fdeEntryStartPos; +}; + class UnwindBuilderDwarf2 : public UnwindBuilder { public: void setBeginOffset(size_t beginOffset) override; size_t getBeginOffset() const override; - void start() override; + void startInfo() override; + void startFunction() override; void spill(int espOffset, X64::RegisterX64 reg) override; void save(X64::RegisterX64 reg) override; void allocStack(int size) override; void setupFrameReg(X64::RegisterX64 reg, int espOffset) override; + void finishFunction(uint32_t beginOffset, uint32_t endOffset) override; - void finish() override; + void finishInfo() override; size_t getSize() const override; + size_t getFunctionCount() const override; - void finalize(char* target, void* funcAddress, size_t funcSize) const override; + void finalize(char* target, size_t offset, void* funcAddress, size_t funcSize) const override; private: size_t beginOffset = 0; - static const unsigned kRawDataLimit = 128; + std::vector unwindFunctions; + + static const unsigned kRawDataLimit = 1024; uint8_t rawData[kRawDataLimit]; uint8_t* pos = rawData; diff --git a/CodeGen/include/Luau/UnwindBuilderWin.h b/CodeGen/include/Luau/UnwindBuilderWin.h index 1cd750a1d..ccd7125d7 100644 --- a/CodeGen/include/Luau/UnwindBuilderWin.h +++ b/CodeGen/include/Luau/UnwindBuilderWin.h @@ -11,6 +11,25 @@ namespace Luau namespace CodeGen { +// This struct matches the layout of x64 RUNTIME_FUNCTION from winnt.h +struct UnwindFunctionWin +{ + uint32_t beginOffset; + uint32_t endOffset; + uint32_t unwindInfoOffset; +}; + +// This struct matches the layout of x64 UNWIND_INFO from ehdata.h +struct UnwindInfoWin +{ + uint8_t version : 3; + uint8_t flags : 5; + uint8_t prologsize; + uint8_t unwindcodecount; + uint8_t framereg : 4; + uint8_t frameregoff : 4; +}; + // This struct matches the layout of UNWIND_CODE from ehdata.h struct UnwindCodeWin { @@ -25,31 +44,38 @@ class UnwindBuilderWin : public UnwindBuilder void setBeginOffset(size_t beginOffset) override; size_t getBeginOffset() const override; - void start() override; + void startInfo() override; + void startFunction() override; void spill(int espOffset, X64::RegisterX64 reg) override; void save(X64::RegisterX64 reg) override; void allocStack(int size) override; void setupFrameReg(X64::RegisterX64 reg, int espOffset) override; + void finishFunction(uint32_t beginOffset, uint32_t endOffset) override; - void finish() override; + void finishInfo() override; size_t getSize() const override; + size_t getFunctionCount() const override; - void finalize(char* target, void* funcAddress, size_t funcSize) const override; + void finalize(char* target, size_t offset, void* funcAddress, size_t funcSize) const override; private: size_t beginOffset = 0; + static const unsigned kRawDataLimit = 1024; + uint8_t rawData[kRawDataLimit]; + uint8_t* rawDataPos = rawData; + + std::vector unwindFunctions; + // Windows unwind codes are written in reverse, so we have to collect them all first std::vector unwindCodes; uint8_t prologSize = 0; - X64::RegisterX64 frameReg = X64::rax; // rax means that frame register is not used + X64::RegisterX64 frameReg = X64::noreg; uint8_t frameRegOffset = 0; uint32_t stackOffset = 0; - - size_t infoSize = 0; }; } // namespace CodeGen diff --git a/CodeGen/src/AssemblyBuilderA64.cpp b/CodeGen/src/AssemblyBuilderA64.cpp index a80003e94..d6274256e 100644 --- a/CodeGen/src/AssemblyBuilderA64.cpp +++ b/CodeGen/src/AssemblyBuilderA64.cpp @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/AssemblyBuilderA64.h" +#include "BitUtils.h" #include "ByteUtils.h" #include @@ -126,6 +127,15 @@ void AssemblyBuilderA64::csel(RegisterA64 dst, RegisterA64 src1, RegisterA64 src placeCS("csel", dst, src1, src2, cond, 0b11010'10'0, 0b00); } +void AssemblyBuilderA64::cset(RegisterA64 dst, ConditionA64 cond) +{ + LUAU_ASSERT(dst.kind == KindA64::x || dst.kind == KindA64::w); + + RegisterA64 src = dst.kind == KindA64::x ? xzr : wzr; + + placeCS("cset", dst, src, src, cond, 0b11010'10'0, 0b01, /* invert= */ 1); +} + void AssemblyBuilderA64::and_(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2) { placeSR3("and", dst, src1, src2, 0b00'01010); @@ -141,11 +151,45 @@ void AssemblyBuilderA64::eor(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2 placeSR3("eor", dst, src1, src2, 0b10'01010); } +void AssemblyBuilderA64::bic(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2) +{ + placeSR3("bic", dst, src1, src2, 0b00'01010, /* shift= */ 0, /* N= */ 1); +} + +void AssemblyBuilderA64::tst(RegisterA64 src1, RegisterA64 src2) +{ + RegisterA64 dst = src1.kind == KindA64::x ? xzr : wzr; + + placeSR3("tst", dst, src1, src2, 0b11'01010); +} + void AssemblyBuilderA64::mvn(RegisterA64 dst, RegisterA64 src) { placeSR2("mvn", dst, src, 0b01'01010, 0b1); } +void AssemblyBuilderA64::and_(RegisterA64 dst, RegisterA64 src1, uint32_t src2) +{ + placeBM("and", dst, src1, src2, 0b00'100100); +} + +void AssemblyBuilderA64::orr(RegisterA64 dst, RegisterA64 src1, uint32_t src2) +{ + placeBM("orr", dst, src1, src2, 0b01'100100); +} + +void AssemblyBuilderA64::eor(RegisterA64 dst, RegisterA64 src1, uint32_t src2) +{ + placeBM("eor", dst, src1, src2, 0b10'100100); +} + +void AssemblyBuilderA64::tst(RegisterA64 src1, uint32_t src2) +{ + RegisterA64 dst = src1.kind == KindA64::x ? xzr : wzr; + + placeBM("tst", dst, src1, src2, 0b11'100100); +} + void AssemblyBuilderA64::lsl(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2) { placeR3("lsl", dst, src1, src2, 0b11010110, 0b0010'00); @@ -583,7 +627,7 @@ void AssemblyBuilderA64::place0(const char* name, uint32_t op) commit(); } -void AssemblyBuilderA64::placeSR3(const char* name, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, uint8_t op, int shift) +void AssemblyBuilderA64::placeSR3(const char* name, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, uint8_t op, int shift, int N) { if (logText) log(name, dst, src1, src2, shift); @@ -594,7 +638,7 @@ void AssemblyBuilderA64::placeSR3(const char* name, RegisterA64 dst, RegisterA64 uint32_t sf = (dst.kind == KindA64::x) ? 0x80000000 : 0; - place(dst.index | (src1.index << 5) | (shift << 10) | (src2.index << 16) | (op << 24) | sf); + place(dst.index | (src1.index << 5) | (shift << 10) | (src2.index << 16) | (N << 21) | (op << 24) | sf); commit(); } @@ -764,7 +808,8 @@ void AssemblyBuilderA64::placeP(const char* name, RegisterA64 src1, RegisterA64 commit(); } -void AssemblyBuilderA64::placeCS(const char* name, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, ConditionA64 cond, uint8_t op, uint8_t opc) +void AssemblyBuilderA64::placeCS( + const char* name, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, ConditionA64 cond, uint8_t op, uint8_t opc, int invert) { if (logText) log(name, dst, src1, src2, cond); @@ -773,7 +818,7 @@ void AssemblyBuilderA64::placeCS(const char* name, RegisterA64 dst, RegisterA64 uint32_t sf = (dst.kind == KindA64::x) ? 0x80000000 : 0; - place(dst.index | (src1.index << 5) | (opc << 10) | (codeForCondition[int(cond)] << 12) | (src2.index << 16) | (op << 21) | sf); + place(dst.index | (src1.index << 5) | (opc << 10) | ((codeForCondition[int(cond)] ^ invert) << 12) | (src2.index << 16) | (op << 21) | sf); commit(); } @@ -793,6 +838,29 @@ void AssemblyBuilderA64::placeFCMP(const char* name, RegisterA64 src1, RegisterA commit(); } +void AssemblyBuilderA64::placeBM(const char* name, RegisterA64 dst, RegisterA64 src1, uint32_t src2, uint8_t op) +{ + if (logText) + log(name, dst, src1, src2); + + LUAU_ASSERT(dst.kind == KindA64::w || dst.kind == KindA64::x); + LUAU_ASSERT(dst.kind == src1.kind); + + uint32_t sf = (dst.kind == KindA64::x) ? 0x80000000 : 0; + + int lz = countlz(src2); + int rz = countrz(src2); + + LUAU_ASSERT(lz + rz > 0 && lz + rz < 32); // must have at least one 0 and at least one 1 + LUAU_ASSERT((src2 >> rz) == (1 << (32 - lz - rz)) - 1); // sequence of 1s must be contiguous + + int imms = 31 - lz - rz; // count of 1s minus 1 + int immr = (32 - rz) & 31; // right rotate amount + + place(dst.index | (src1.index << 5) | (imms << 10) | (immr << 16) | (op << 23) | sf); + commit(); +} + void AssemblyBuilderA64::place(uint32_t word) { LUAU_ASSERT(codePos < codeEnd); @@ -965,10 +1033,13 @@ void AssemblyBuilderA64::log(const char* opcode, RegisterA64 dst, RegisterA64 sr { logAppend(" %-12s", opcode); log(dst); - text.append(","); - log(src1); - text.append(","); - log(src2); + if ((src1 != wzr && src1 != xzr) || (src2 != wzr && src2 != xzr)) + { + text.append(","); + log(src1); + text.append(","); + log(src2); + } text.append(","); text.append(textForCondition[int(cond)] + 2); // skip b. text.append("\n"); diff --git a/CodeGen/src/AssemblyBuilderX64.cpp b/CodeGen/src/AssemblyBuilderX64.cpp index d86a37c6e..ed95004fd 100644 --- a/CodeGen/src/AssemblyBuilderX64.cpp +++ b/CodeGen/src/AssemblyBuilderX64.cpp @@ -31,7 +31,8 @@ static_assert(sizeof(setccTextForCondition) / sizeof(setccTextForCondition[0]) = #define OP_PLUS_REG(op, reg) ((op) + (reg & 0x7)) #define OP_PLUS_CC(op, cc) ((op) + uint8_t(cc)) -#define REX_W(value) (value ? 0x8 : 0x0) +#define REX_W_BIT(value) (value ? 0x8 : 0x0) +#define REX_W(reg) REX_W_BIT((reg).size == SizeX64::qword || ((reg).size == SizeX64::byte && (reg).index >= 4)) #define REX_R(reg) (((reg).index & 0x8) >> 1) #define REX_X(reg) (((reg).index & 0x8) >> 2) #define REX_B(reg) (((reg).index & 0x8) >> 3) @@ -1116,7 +1117,7 @@ void AssemblyBuilderX64::placeAvx( void AssemblyBuilderX64::placeRex(RegisterX64 op) { - uint8_t code = REX_W(op.size == SizeX64::qword) | REX_B(op); + uint8_t code = REX_W(op) | REX_B(op); if (code != 0) place(code | 0x40); @@ -1127,9 +1128,9 @@ void AssemblyBuilderX64::placeRex(OperandX64 op) uint8_t code = 0; if (op.cat == CategoryX64::reg) - code = REX_W(op.base.size == SizeX64::qword) | REX_B(op.base); + code = REX_W(op.base) | REX_B(op.base); else if (op.cat == CategoryX64::mem) - code = REX_W(op.memSize == SizeX64::qword) | REX_X(op.index) | REX_B(op.base); + code = REX_W_BIT(op.memSize == SizeX64::qword) | REX_X(op.index) | REX_B(op.base); else LUAU_ASSERT(!"No encoding for left operand of this category"); @@ -1154,7 +1155,7 @@ void AssemblyBuilderX64::placeRexNoW(OperandX64 op) void AssemblyBuilderX64::placeRex(RegisterX64 lhs, OperandX64 rhs) { - uint8_t code = REX_W(lhs.size == SizeX64::qword); + uint8_t code = REX_W(lhs); if (rhs.cat == CategoryX64::imm) code |= REX_B(lhs); diff --git a/CodeGen/src/BitUtils.h b/CodeGen/src/BitUtils.h new file mode 100644 index 000000000..93f7cc8db --- /dev/null +++ b/CodeGen/src/BitUtils.h @@ -0,0 +1,36 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include + +#ifdef _MSC_VER +#include +#endif + +namespace Luau +{ +namespace CodeGen +{ + +inline int countlz(uint32_t n) +{ +#ifdef _MSC_VER + unsigned long rl; + return _BitScanReverse(&rl, n) ? 31 - int(rl) : 32; +#else + return n == 0 ? 32 : __builtin_clz(n); +#endif +} + +inline int countrz(uint32_t n) +{ +#ifdef _MSC_VER + unsigned long rl; + return _BitScanForward(&rl, n) ? int(rl) : 32; +#else + return n == 0 ? 32 : __builtin_ctz(n); +#endif +} + +} // namespace CodeGen +} // namespace Luau diff --git a/CodeGen/src/CodeBlockUnwind.cpp b/CodeGen/src/CodeBlockUnwind.cpp index 72842be7b..ccd15facb 100644 --- a/CodeGen/src/CodeBlockUnwind.cpp +++ b/CodeGen/src/CodeBlockUnwind.cpp @@ -54,70 +54,42 @@ namespace CodeGen void* createBlockUnwindInfo(void* context, uint8_t* block, size_t blockSize, size_t& beginOffset) { -#if defined(_WIN32) && defined(_M_X64) UnwindBuilder* unwind = (UnwindBuilder*)context; // All unwinding related data is placed together at the start of the block - size_t unwindSize = sizeof(RUNTIME_FUNCTION) + unwind->getSize(); + size_t unwindSize = unwind->getSize(); unwindSize = (unwindSize + (kCodeAlignment - 1)) & ~(kCodeAlignment - 1); // Match code allocator alignment LUAU_ASSERT(blockSize >= unwindSize); - RUNTIME_FUNCTION* runtimeFunc = (RUNTIME_FUNCTION*)block; - runtimeFunc->BeginAddress = DWORD(unwindSize); // Code will start after the unwind info - runtimeFunc->EndAddress = DWORD(blockSize); // Whole block is a part of a 'single function' - runtimeFunc->UnwindInfoAddress = DWORD(sizeof(RUNTIME_FUNCTION)); // Unwind info is placed at the start of the block - - char* unwindData = (char*)block + runtimeFunc->UnwindInfoAddress; - unwind->finalize(unwindData, block + unwindSize, blockSize - unwindSize); + char* unwindData = (char*)block; + unwind->finalize(unwindData, unwindSize, block, blockSize); - if (!RtlAddFunctionTable(runtimeFunc, 1, uintptr_t(block))) +#if defined(_WIN32) && defined(_M_X64) + if (!RtlAddFunctionTable((RUNTIME_FUNCTION*)block, uint32_t(unwind->getFunctionCount()), uintptr_t(block))) { LUAU_ASSERT(!"failed to allocate function table"); return nullptr; } - - beginOffset = unwindSize + unwind->getBeginOffset(); - return block; -#elif !defined(_WIN32) - UnwindBuilder* unwind = (UnwindBuilder*)context; - - // All unwinding related data is placed together at the start of the block - size_t unwindSize = unwind->getSize(); - unwindSize = (unwindSize + (kCodeAlignment - 1)) & ~(kCodeAlignment - 1); // Match code allocator alignment - LUAU_ASSERT(blockSize >= unwindSize); - - char* unwindData = (char*)block; - unwind->finalize(unwindData, block, blockSize); - -#if defined(__APPLE__) +#elif defined(__APPLE__) visitFdeEntries(unwindData, __register_frame); -#else +#elif !defined(_WIN32) __register_frame(unwindData); #endif beginOffset = unwindSize + unwind->getBeginOffset(); return block; -#endif - - return nullptr; } void destroyBlockUnwindInfo(void* context, void* unwindData) { #if defined(_WIN32) && defined(_M_X64) - RUNTIME_FUNCTION* runtimeFunc = (RUNTIME_FUNCTION*)unwindData; - - if (!RtlDeleteFunctionTable(runtimeFunc)) + if (!RtlDeleteFunctionTable((RUNTIME_FUNCTION*)unwindData)) LUAU_ASSERT(!"failed to deallocate function table"); -#elif !defined(_WIN32) - -#if defined(__APPLE__) +#elif defined(__APPLE__) visitFdeEntries((char*)unwindData, __deregister_frame); -#else +#elif !defined(_WIN32) __deregister_frame(unwindData); #endif - -#endif } } // namespace CodeGen diff --git a/CodeGen/src/CodeGen.cpp b/CodeGen/src/CodeGen.cpp index 8e6e94933..6cd9ea055 100644 --- a/CodeGen/src/CodeGen.cpp +++ b/CodeGen/src/CodeGen.cpp @@ -176,6 +176,10 @@ static bool lowerImpl(AssemblyBuilder& build, IrLowering& lowering, IrFunction& IrInst& inst = function.instructions[index]; + // Substitutions might have meta information about operand restore location from memory + if (inst.cmd == IrCmd::SUBSTITUTE && inst.b.kind != IrOpKind::None) + function.recordRestoreOp(inst.a.index, inst.b); + // Skip pseudo instructions, but make sure they are not used at this stage // This also prevents them from getting into text output when that's enabled if (isPseudo(inst.cmd)) @@ -195,7 +199,18 @@ static bool lowerImpl(AssemblyBuilder& build, IrLowering& lowering, IrFunction& lowering.lowerInst(inst, index, next); if (lowering.hasError()) + { + // Place labels for all blocks that we're skipping + // This is needed to avoid AssemblyBuilder assertions about jumps in earlier blocks with unplaced labels + for (size_t j = i + 1; j < sortedBlocks.size(); ++j) + { + IrBlock& abandoned = function.blocks[sortedBlocks[j]]; + + build.setLabel(abandoned.label); + } + return false; + } } if (options.includeIr) @@ -223,12 +238,8 @@ static bool lowerImpl(AssemblyBuilder& build, IrLowering& lowering, IrFunction& [[maybe_unused]] static bool lowerIr( X64::AssemblyBuilderX64& build, IrBuilder& ir, NativeState& data, ModuleHelpers& helpers, Proto* proto, AssemblyOptions options) { - constexpr uint32_t kFunctionAlignment = 32; - optimizeMemoryOperandsX64(ir.function); - build.align(kFunctionAlignment, X64::AlignmentDataX64::Ud2); - X64::IrLoweringX64 lowering(build, helpers, data, ir.function); return lowerImpl(build, lowering, ir.function, proto->bytecodeid, options); @@ -237,9 +248,6 @@ static bool lowerImpl(AssemblyBuilder& build, IrLowering& lowering, IrFunction& [[maybe_unused]] static bool lowerIr( A64::AssemblyBuilderA64& build, IrBuilder& ir, NativeState& data, ModuleHelpers& helpers, Proto* proto, AssemblyOptions options) { - if (!A64::IrLoweringA64::canLower(ir.function)) - return false; - A64::IrLoweringA64 lowering(build, helpers, data, proto, ir.function); return lowerImpl(build, lowering, ir.function, proto->bytecodeid, options); @@ -432,13 +440,13 @@ void create(lua_State* L) initHelperFunctions(data); #if defined(__x86_64__) || defined(_M_X64) - if (!X64::initEntryFunction(data)) + if (!X64::initHeaderFunctions(data)) { destroyNativeState(L); return; } #elif defined(__aarch64__) - if (!A64::initEntryFunction(data)) + if (!A64::initHeaderFunctions(data)) { destroyNativeState(L); return; diff --git a/CodeGen/src/CodeGenA64.cpp b/CodeGen/src/CodeGenA64.cpp index e7a1e2e21..7f29beb2b 100644 --- a/CodeGen/src/CodeGenA64.cpp +++ b/CodeGen/src/CodeGenA64.cpp @@ -17,14 +17,107 @@ namespace CodeGen namespace A64 { -bool initEntryFunction(NativeState& data) +struct EntryLocations { - AssemblyBuilderA64 build(/* logText= */ false); - UnwindBuilder& unwind = *data.unwindBuilder.get(); + Label start; + Label prologueEnd; + Label epilogueStart; +}; + +static void emitExit(AssemblyBuilderA64& build, bool continueInVm) +{ + build.mov(x0, continueInVm); + build.ldr(x1, mem(rNativeContext, offsetof(NativeContext, gateExit))); + build.br(x1); +} + +static void emitInterrupt(AssemblyBuilderA64& build) +{ + // x0 = pc offset + // x1 = return address in native code + // x2 = interrupt + + // Stash return address in rBase; we need to reload rBase anyway + build.mov(rBase, x1); + + // Update savedpc; required in case interrupt errors + build.add(x0, rCode, x0); + build.ldr(x1, mem(rState, offsetof(lua_State, ci))); + build.str(x0, mem(x1, offsetof(CallInfo, savedpc))); + + // Call interrupt + build.mov(x0, rState); + build.mov(w1, -1); + build.blr(x2); + + // Check if we need to exit + Label skip; + build.ldrb(w0, mem(rState, offsetof(lua_State, status))); + build.cbz(w0, skip); + + // L->ci->savedpc-- + // note: recomputing this avoids having to stash x0 + build.ldr(x1, mem(rState, offsetof(lua_State, ci))); + build.ldr(x0, mem(x1, offsetof(CallInfo, savedpc))); + build.sub(x0, x0, sizeof(Instruction)); + build.str(x0, mem(x1, offsetof(CallInfo, savedpc))); + + emitExit(build, /* continueInVm */ false); + + build.setLabel(skip); + + // Return back to caller; rBase has stashed return address + build.mov(x0, rBase); + + emitUpdateBase(build); // interrupt may have reallocated stack + + build.br(x0); +} + +static void emitReentry(AssemblyBuilderA64& build, ModuleHelpers& helpers) +{ + // x0 = closure object to reentry (equal to clvalue(L->ci->func)) + + // If the fallback requested an exit, we need to do this right away + build.cbz(x0, helpers.exitNoContinueVm); + + emitUpdateBase(build); + + // Need to update state of the current function before we jump away + build.ldr(x1, mem(x0, offsetof(Closure, l.p))); // cl->l.p aka proto + + build.mov(rClosure, x0); + build.ldr(rConstants, mem(x1, offsetof(Proto, k))); // proto->k + build.ldr(rCode, mem(x1, offsetof(Proto, code))); // proto->code + + // Get instruction index from instruction pointer + // To get instruction index from instruction pointer, we need to divide byte offset by 4 + // But we will actually need to scale instruction index by 8 back to byte offset later so it cancels out + build.ldr(x2, mem(rState, offsetof(lua_State, ci))); // L->ci + build.ldr(x2, mem(x2, offsetof(CallInfo, savedpc))); // L->ci->savedpc + build.sub(x2, x2, rCode); + build.add(x2, x2, x2); // TODO: this would not be necessary if we supported shifted register offsets in loads + + // We need to check if the new function can be executed natively + // TODO: This can be done earlier in the function flow, to reduce the JIT->VM transition penalty + build.ldr(x1, mem(x1, offsetofProtoExecData)); + build.cbz(x1, helpers.exitContinueVm); + + // Get new instruction location and jump to it + build.ldr(x1, mem(x1, offsetof(NativeProto, instTargets))); + build.ldr(x1, mem(x1, x2)); + build.br(x1); +} + +static EntryLocations buildEntryFunction(AssemblyBuilderA64& build, UnwindBuilder& unwind) +{ + EntryLocations locations; // Arguments: x0 = lua_State*, x1 = Proto*, x2 = native code pointer to jump to, x3 = NativeContext* - unwind.start(); + locations.start = build.setLabel(); + unwind.startFunction(); + unwind.allocStack(8); // TODO: this is just a hack to make UnwindBuilder assertions cooperate // prologue @@ -38,9 +131,7 @@ bool initEntryFunction(NativeState& data) build.mov(x29, sp); // this is only necessary if we maintain frame pointers, which we do in the JIT for now - unwind.finish(); - - size_t prologueSize = build.setLabel().location; + locations.prologueEnd = build.setLabel(); // Setup native execution environment build.mov(rState, x0); @@ -58,7 +149,7 @@ bool initEntryFunction(NativeState& data) build.br(x2); // Even though we jumped away, we will return here in the end - Label returnOff = build.setLabel(); + locations.epilogueStart = build.setLabel(); // Cleanup and exit build.ldp(x23, x24, mem(sp, 48)); @@ -69,12 +160,30 @@ bool initEntryFunction(NativeState& data) build.ret(); + // Our entry function is special, it spans the whole remaining code area + unwind.finishFunction(build.getLabelOffset(locations.start), kFullBlockFuncton); + + return locations; +} + +bool initHeaderFunctions(NativeState& data) +{ + AssemblyBuilderA64 build(/* logText= */ false); + UnwindBuilder& unwind = *data.unwindBuilder.get(); + + unwind.startInfo(); + + EntryLocations entryLocations = buildEntryFunction(build, unwind); + build.finalize(); + unwind.finishInfo(); + LUAU_ASSERT(build.data.empty()); + uint8_t* codeStart = nullptr; if (!data.codeAllocator.allocate(build.data.data(), int(build.data.size()), reinterpret_cast(build.code.data()), - int(build.code.size() * sizeof(build.code[0])), data.gateData, data.gateDataSize, data.context.gateEntry)) + int(build.code.size() * sizeof(build.code[0])), data.gateData, data.gateDataSize, codeStart)) { LUAU_ASSERT(!"failed to create entry function"); return false; @@ -82,9 +191,10 @@ bool initEntryFunction(NativeState& data) // Set the offset at the begining so that functions in new blocks will not overlay the locations // specified by the unwind information of the entry function - unwind.setBeginOffset(prologueSize); + unwind.setBeginOffset(build.getLabelOffset(entryLocations.prologueEnd)); - data.context.gateExit = data.context.gateEntry + build.getLabelOffset(returnOff); + data.context.gateEntry = codeStart + build.getLabelOffset(entryLocations.start); + data.context.gateExit = codeStart + build.getLabelOffset(entryLocations.epilogueStart); return true; } diff --git a/CodeGen/src/CodeGenA64.h b/CodeGen/src/CodeGenA64.h index 7b792cc1b..f6fda7262 100644 --- a/CodeGen/src/CodeGenA64.h +++ b/CodeGen/src/CodeGenA64.h @@ -14,7 +14,7 @@ namespace A64 class AssemblyBuilderA64; -bool initEntryFunction(NativeState& data); +bool initHeaderFunctions(NativeState& data); void assembleHelpers(AssemblyBuilderA64& build, ModuleHelpers& helpers); } // namespace A64 diff --git a/CodeGen/src/CodeGenUtils.cpp b/CodeGen/src/CodeGenUtils.cpp index ae3dbd452..7a9192ab0 100644 --- a/CodeGen/src/CodeGenUtils.cpp +++ b/CodeGen/src/CodeGenUtils.cpp @@ -13,12 +13,58 @@ namespace Luau namespace CodeGen { +bool forgLoopTableIter(lua_State* L, Table* h, int index, TValue* ra) +{ + int sizearray = h->sizearray; + + // first we advance index through the array portion + while (unsigned(index) < unsigned(sizearray)) + { + TValue* e = &h->array[index]; + + if (!ttisnil(e)) + { + setpvalue(ra + 2, reinterpret_cast(uintptr_t(index + 1))); + setnvalue(ra + 3, double(index + 1)); + setobj2s(L, ra + 4, e); + + return true; + } + + index++; + } + + int sizenode = 1 << h->lsizenode; + + // then we advance index through the hash portion + while (unsigned(index - h->sizearray) < unsigned(sizenode)) + { + LuaNode* n = &h->node[index - sizearray]; + + if (!ttisnil(gval(n))) + { + setpvalue(ra + 2, reinterpret_cast(uintptr_t(index + 1))); + getnodekey(L, ra + 3, n); + setobj(L, ra + 4, gval(n)); + + return true; + } + + index++; + } + + return false; +} + bool forgLoopNodeIter(lua_State* L, Table* h, int index, TValue* ra) { + int sizearray = h->sizearray; + int sizenode = 1 << h->lsizenode; + // then we advance index through the hash portion - while (unsigned(index - h->sizearray) < unsigned(1 << h->lsizenode)) + while (unsigned(index - sizearray) < unsigned(sizenode)) { - LuaNode* n = &h->node[index - h->sizearray]; + LuaNode* n = &h->node[index - sizearray]; if (!ttisnil(gval(n))) { diff --git a/CodeGen/src/CodeGenUtils.h b/CodeGen/src/CodeGenUtils.h index 6066a691c..10e88c130 100644 --- a/CodeGen/src/CodeGenUtils.h +++ b/CodeGen/src/CodeGenUtils.h @@ -8,6 +8,7 @@ namespace Luau namespace CodeGen { +bool forgLoopTableIter(lua_State* L, Table* h, int index, TValue* ra); bool forgLoopNodeIter(lua_State* L, Table* h, int index, TValue* ra); bool forgLoopNonTableFallback(lua_State* L, int insnA, int aux); diff --git a/CodeGen/src/CodeGenX64.cpp b/CodeGen/src/CodeGenX64.cpp index 7df1a909d..2acb69f96 100644 --- a/CodeGen/src/CodeGenX64.cpp +++ b/CodeGen/src/CodeGenX64.cpp @@ -41,12 +41,21 @@ namespace CodeGen namespace X64 { -bool initEntryFunction(NativeState& data) +struct EntryLocations { - AssemblyBuilderX64 build(/* logText= */ false); - UnwindBuilder& unwind = *data.unwindBuilder.get(); + Label start; + Label prologueEnd; + Label epilogueStart; +}; + +static EntryLocations buildEntryFunction(AssemblyBuilderX64& build, UnwindBuilder& unwind) +{ + EntryLocations locations; - unwind.start(); + build.align(kFunctionAlignment, X64::AlignmentDataX64::Ud2); + + locations.start = build.setLabel(); + unwind.startFunction(); // Save common non-volatile registers build.push(rbp); @@ -84,9 +93,7 @@ bool initEntryFunction(NativeState& data) build.sub(rsp, kStackSize + kLocalsSize); unwind.allocStack(kStackSize + kLocalsSize); - unwind.finish(); - - size_t prologueSize = build.setLabel().location; + locations.prologueEnd = build.setLabel(); // Setup native execution environment build.mov(rState, rArg1); @@ -104,7 +111,7 @@ bool initEntryFunction(NativeState& data) build.jmp(rArg3); // Even though we jumped away, we will return here in the end - Label returnOff = build.setLabel(); + locations.epilogueStart = build.setLabel(); // Cleanup and exit build.add(rsp, kStackSize + kLocalsSize); @@ -123,12 +130,30 @@ bool initEntryFunction(NativeState& data) build.pop(rbp); build.ret(); + // Our entry function is special, it spans the whole remaining code area + unwind.finishFunction(build.getLabelOffset(locations.start), kFullBlockFuncton); + + return locations; +} + +bool initHeaderFunctions(NativeState& data) +{ + AssemblyBuilderX64 build(/* logText= */ false); + UnwindBuilder& unwind = *data.unwindBuilder.get(); + + unwind.startInfo(); + + EntryLocations entryLocations = buildEntryFunction(build, unwind); + build.finalize(); + unwind.finishInfo(); + LUAU_ASSERT(build.data.empty()); - if (!data.codeAllocator.allocate(build.data.data(), int(build.data.size()), build.code.data(), int(build.code.size()), data.gateData, - data.gateDataSize, data.context.gateEntry)) + uint8_t* codeStart = nullptr; + if (!data.codeAllocator.allocate( + build.data.data(), int(build.data.size()), build.code.data(), int(build.code.size()), data.gateData, data.gateDataSize, codeStart)) { LUAU_ASSERT(!"failed to create entry function"); return false; @@ -136,9 +161,10 @@ bool initEntryFunction(NativeState& data) // Set the offset at the begining so that functions in new blocks will not overlay the locations // specified by the unwind information of the entry function - unwind.setBeginOffset(prologueSize); + unwind.setBeginOffset(build.getLabelOffset(entryLocations.prologueEnd)); - data.context.gateExit = data.context.gateEntry + returnOff.location; + data.context.gateEntry = codeStart + build.getLabelOffset(entryLocations.start); + data.context.gateExit = codeStart + build.getLabelOffset(entryLocations.epilogueStart); return true; } diff --git a/CodeGen/src/CodeGenX64.h b/CodeGen/src/CodeGenX64.h index 1f4831138..1f0f27d91 100644 --- a/CodeGen/src/CodeGenX64.h +++ b/CodeGen/src/CodeGenX64.h @@ -14,7 +14,7 @@ namespace X64 class AssemblyBuilderX64; -bool initEntryFunction(NativeState& data); +bool initHeaderFunctions(NativeState& data); void assembleHelpers(AssemblyBuilderX64& build, ModuleHelpers& helpers); } // namespace X64 diff --git a/CodeGen/src/EmitBuiltinsX64.cpp b/CodeGen/src/EmitBuiltinsX64.cpp index b010ce627..4026b955f 100644 --- a/CodeGen/src/EmitBuiltinsX64.cpp +++ b/CodeGen/src/EmitBuiltinsX64.cpp @@ -107,48 +107,12 @@ void emitBuiltinMathLog(IrRegAllocX64& regs, AssemblyBuilderX64& build, int npar regs.assertAllFree(); build.vmovsd(xmm0, luauRegValue(arg)); - if (nparams == 1) - { - build.call(qword[rNativeContext + offsetof(NativeContext, libm_log)]); - } - else - { - Label log10check, logdivlog, exit; - - // Using 'rbx' for non-volatile temporary storage of log(arg1) result - RegisterX64 tmp = rbx; - OperandX64 arg2value = qword[args + offsetof(TValue, value)]; - - build.vmovsd(xmm1, arg2value); - - jumpOnNumberCmp(build, noreg, build.f64(2.0), xmm1, IrCondition::NotEqual, log10check); - + // TODO: IR builtin lowering assumes that the only valid 2-argument call is log2; ideally, we use a less hacky way to indicate that + if (nparams == 2) build.call(qword[rNativeContext + offsetof(NativeContext, libm_log2)]); - build.jmp(exit); - - build.setLabel(log10check); - jumpOnNumberCmp(build, noreg, build.f64(10.0), xmm1, IrCondition::NotEqual, logdivlog); - - build.call(qword[rNativeContext + offsetof(NativeContext, libm_log10)]); - build.jmp(exit); - - build.setLabel(logdivlog); - - // log(arg1) - build.call(qword[rNativeContext + offsetof(NativeContext, libm_log)]); - build.vmovq(tmp, xmm0); - - // log(arg2) - build.vmovsd(xmm0, arg2value); + else build.call(qword[rNativeContext + offsetof(NativeContext, libm_log)]); - // log(arg1) / log(arg2) - build.vmovq(xmm1, tmp); - build.vdivsd(xmm0, xmm1, xmm0); - - build.setLabel(exit); - } - build.vmovsd(luauRegValue(ra), xmm0); } @@ -256,62 +220,68 @@ void emitBuiltin(IrRegAllocX64& regs, AssemblyBuilderX64& build, int bfid, int r switch (bfid) { - case LBF_ASSERT: - case LBF_MATH_DEG: - case LBF_MATH_RAD: - case LBF_MATH_MIN: - case LBF_MATH_MAX: - case LBF_MATH_CLAMP: - case LBF_MATH_FLOOR: - case LBF_MATH_CEIL: - case LBF_MATH_SQRT: - case LBF_MATH_POW: - case LBF_MATH_ABS: - case LBF_MATH_ROUND: - // These instructions are fully translated to IR - break; case LBF_MATH_EXP: + LUAU_ASSERT(nparams == 1 && nresults == 1); return emitBuiltinMathExp(regs, build, nparams, ra, arg, argsOp, nresults); case LBF_MATH_FMOD: + LUAU_ASSERT(nparams == 2 && nresults == 1); return emitBuiltinMathFmod(regs, build, nparams, ra, arg, argsOp, nresults); case LBF_MATH_ASIN: + LUAU_ASSERT(nparams == 1 && nresults == 1); return emitBuiltinMathAsin(regs, build, nparams, ra, arg, argsOp, nresults); case LBF_MATH_SIN: + LUAU_ASSERT(nparams == 1 && nresults == 1); return emitBuiltinMathSin(regs, build, nparams, ra, arg, argsOp, nresults); case LBF_MATH_SINH: + LUAU_ASSERT(nparams == 1 && nresults == 1); return emitBuiltinMathSinh(regs, build, nparams, ra, arg, argsOp, nresults); case LBF_MATH_ACOS: + LUAU_ASSERT(nparams == 1 && nresults == 1); return emitBuiltinMathAcos(regs, build, nparams, ra, arg, argsOp, nresults); case LBF_MATH_COS: + LUAU_ASSERT(nparams == 1 && nresults == 1); return emitBuiltinMathCos(regs, build, nparams, ra, arg, argsOp, nresults); case LBF_MATH_COSH: + LUAU_ASSERT(nparams == 1 && nresults == 1); return emitBuiltinMathCosh(regs, build, nparams, ra, arg, argsOp, nresults); case LBF_MATH_ATAN: + LUAU_ASSERT(nparams == 1 && nresults == 1); return emitBuiltinMathAtan(regs, build, nparams, ra, arg, argsOp, nresults); case LBF_MATH_TAN: + LUAU_ASSERT(nparams == 1 && nresults == 1); return emitBuiltinMathTan(regs, build, nparams, ra, arg, argsOp, nresults); case LBF_MATH_TANH: + LUAU_ASSERT(nparams == 1 && nresults == 1); return emitBuiltinMathTanh(regs, build, nparams, ra, arg, argsOp, nresults); case LBF_MATH_ATAN2: + LUAU_ASSERT(nparams == 2 && nresults == 1); return emitBuiltinMathAtan2(regs, build, nparams, ra, arg, argsOp, nresults); case LBF_MATH_LOG10: + LUAU_ASSERT(nparams == 1 && nresults == 1); return emitBuiltinMathLog10(regs, build, nparams, ra, arg, argsOp, nresults); case LBF_MATH_LOG: + LUAU_ASSERT((nparams == 1 || nparams == 2) && nresults == 1); return emitBuiltinMathLog(regs, build, nparams, ra, arg, argsOp, nresults); case LBF_MATH_LDEXP: + LUAU_ASSERT(nparams == 2 && nresults == 1); return emitBuiltinMathLdexp(regs, build, nparams, ra, arg, argsOp, nresults); case LBF_MATH_FREXP: + LUAU_ASSERT(nparams == 1 && (nresults == 1 || nresults == 2)); return emitBuiltinMathFrexp(regs, build, nparams, ra, arg, argsOp, nresults); case LBF_MATH_MODF: + LUAU_ASSERT(nparams == 1 && (nresults == 1 || nresults == 2)); return emitBuiltinMathModf(regs, build, nparams, ra, arg, argsOp, nresults); case LBF_MATH_SIGN: + LUAU_ASSERT(nparams == 1 && nresults == 1); return emitBuiltinMathSign(regs, build, nparams, ra, arg, argsOp, nresults); case LBF_TYPE: + LUAU_ASSERT(nparams == 1 && nresults == 1); return emitBuiltinType(regs, build, nparams, ra, arg, argsOp, nresults); case LBF_TYPEOF: + LUAU_ASSERT(nparams == 1 && nresults == 1); return emitBuiltinTypeof(regs, build, nparams, ra, arg, argsOp, nresults); default: - LUAU_ASSERT(!"missing x64 lowering"); + LUAU_ASSERT(!"Missing x64 lowering"); break; } } diff --git a/CodeGen/src/EmitCommon.h b/CodeGen/src/EmitCommon.h index a71eafd4c..6a7496694 100644 --- a/CodeGen/src/EmitCommon.h +++ b/CodeGen/src/EmitCommon.h @@ -13,8 +13,8 @@ constexpr unsigned kLuaNodeSizeLog2 = 5; constexpr unsigned kLuaNodeTagMask = 0xf; constexpr unsigned kNextBitOffset = 4; -constexpr unsigned kOffsetOfLuaNodeTag = 12; // offsetof cannot be used on a bit field -constexpr unsigned kOffsetOfLuaNodeNext = 12; // offsetof cannot be used on a bit field +constexpr unsigned kOffsetOfTKeyTag = 12; // offsetof cannot be used on a bit field +constexpr unsigned kOffsetOfTKeyNext = 12; // offsetof cannot be used on a bit field constexpr unsigned kOffsetOfInstructionC = 3; // Leaf functions that are placed in every module to perform common instruction sequences diff --git a/CodeGen/src/EmitCommonA64.cpp b/CodeGen/src/EmitCommonA64.cpp deleted file mode 100644 index 1758e4fb1..000000000 --- a/CodeGen/src/EmitCommonA64.cpp +++ /dev/null @@ -1,130 +0,0 @@ -// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "EmitCommonA64.h" - -#include "NativeState.h" -#include "CustomExecUtils.h" - -namespace Luau -{ -namespace CodeGen -{ -namespace A64 -{ - -void emitUpdateBase(AssemblyBuilderA64& build) -{ - build.ldr(rBase, mem(rState, offsetof(lua_State, base))); -} - -void emitExit(AssemblyBuilderA64& build, bool continueInVm) -{ - build.mov(x0, continueInVm); - build.ldr(x1, mem(rNativeContext, offsetof(NativeContext, gateExit))); - build.br(x1); -} - -void emitInterrupt(AssemblyBuilderA64& build) -{ - // x0 = pc offset - // x1 = return address in native code - // x2 = interrupt - - // Stash return address in rBase; we need to reload rBase anyway - build.mov(rBase, x1); - - // Update savedpc; required in case interrupt errors - build.add(x0, rCode, x0); - build.ldr(x1, mem(rState, offsetof(lua_State, ci))); - build.str(x0, mem(x1, offsetof(CallInfo, savedpc))); - - // Call interrupt - build.mov(x0, rState); - build.mov(w1, -1); - build.blr(x2); - - // Check if we need to exit - Label skip; - build.ldrb(w0, mem(rState, offsetof(lua_State, status))); - build.cbz(w0, skip); - - // L->ci->savedpc-- - // note: recomputing this avoids having to stash x0 - build.ldr(x1, mem(rState, offsetof(lua_State, ci))); - build.ldr(x0, mem(x1, offsetof(CallInfo, savedpc))); - build.sub(x0, x0, sizeof(Instruction)); - build.str(x0, mem(x1, offsetof(CallInfo, savedpc))); - - emitExit(build, /* continueInVm */ false); - - build.setLabel(skip); - - // Return back to caller; rBase has stashed return address - build.mov(x0, rBase); - - emitUpdateBase(build); // interrupt may have reallocated stack - - build.br(x0); -} - -void emitReentry(AssemblyBuilderA64& build, ModuleHelpers& helpers) -{ - // x0 = closure object to reentry (equal to clvalue(L->ci->func)) - - // If the fallback requested an exit, we need to do this right away - build.cbz(x0, helpers.exitNoContinueVm); - - emitUpdateBase(build); - - // Need to update state of the current function before we jump away - build.ldr(x1, mem(x0, offsetof(Closure, l.p))); // cl->l.p aka proto - - build.mov(rClosure, x0); - build.ldr(rConstants, mem(x1, offsetof(Proto, k))); // proto->k - build.ldr(rCode, mem(x1, offsetof(Proto, code))); // proto->code - - // Get instruction index from instruction pointer - // To get instruction index from instruction pointer, we need to divide byte offset by 4 - // But we will actually need to scale instruction index by 8 back to byte offset later so it cancels out - build.ldr(x2, mem(rState, offsetof(lua_State, ci))); // L->ci - build.ldr(x2, mem(x2, offsetof(CallInfo, savedpc))); // L->ci->savedpc - build.sub(x2, x2, rCode); - build.add(x2, x2, x2); // TODO: this would not be necessary if we supported shifted register offsets in loads - - // We need to check if the new function can be executed natively - // TODO: This can be done earlier in the function flow, to reduce the JIT->VM transition penalty - build.ldr(x1, mem(x1, offsetofProtoExecData)); - build.cbz(x1, helpers.exitContinueVm); - - // Get new instruction location and jump to it - build.ldr(x1, mem(x1, offsetof(NativeProto, instTargets))); - build.ldr(x1, mem(x1, x2)); - build.br(x1); -} - -void emitFallback(AssemblyBuilderA64& build, int op, int pcpos) -{ - // fallback(L, instruction, base, k) - build.mov(x0, rState); - - // TODO: refactor into a common helper - if (pcpos * sizeof(Instruction) <= AssemblyBuilderA64::kMaxImmediate) - { - build.add(x1, rCode, uint16_t(pcpos * sizeof(Instruction))); - } - else - { - build.mov(x1, pcpos * sizeof(Instruction)); - build.add(x1, rCode, x1); - } - - build.mov(x2, rBase); - build.mov(x3, rConstants); - build.ldr(x4, mem(rNativeContext, offsetof(NativeContext, fallback) + op * sizeof(NativeFallback) + offsetof(NativeFallback, fallback))); - build.blr(x4); - - emitUpdateBase(build); -} - -} // namespace A64 -} // namespace CodeGen -} // namespace Luau diff --git a/CodeGen/src/EmitCommonA64.h b/CodeGen/src/EmitCommonA64.h index 2a65afa8f..8cb54c1d3 100644 --- a/CodeGen/src/EmitCommonA64.h +++ b/CodeGen/src/EmitCommonA64.h @@ -7,6 +7,7 @@ #include "lobject.h" #include "ltm.h" +#include "lstate.h" // AArch64 ABI reminder: // Arguments: x0-x7, v0-v7 @@ -38,15 +39,19 @@ constexpr RegisterA64 rBase = x24; // StkId base // Native code is as stackless as the interpreter, so we can place some data on the stack once and have it accessible at any point // See CodeGenA64.cpp for layout -constexpr unsigned kStackSize = 64; // 8 stashed registers +constexpr unsigned kStashSlots = 8; // stashed non-volatile registers +constexpr unsigned kSpillSlots = 0; // slots for spilling temporary registers (unused) +constexpr unsigned kTempSlots = 2; // 16 bytes of temporary space, such luxury! -void emitUpdateBase(AssemblyBuilderA64& build); +constexpr unsigned kStackSize = (kStashSlots + kSpillSlots + kTempSlots) * 8; -// TODO: Move these to CodeGenA64 so that they can't be accidentally called during lowering -void emitExit(AssemblyBuilderA64& build, bool continueInVm); -void emitInterrupt(AssemblyBuilderA64& build); -void emitReentry(AssemblyBuilderA64& build, ModuleHelpers& helpers); -void emitFallback(AssemblyBuilderA64& build, int op, int pcpos); +constexpr AddressA64 sSpillArea = mem(sp, kStashSlots * 8); +constexpr AddressA64 sTemporary = mem(sp, (kStashSlots + kSpillSlots) * 8); + +inline void emitUpdateBase(AssemblyBuilderA64& build) +{ + build.ldr(rBase, mem(rState, offsetof(lua_State, base))); +} } // namespace A64 } // namespace CodeGen diff --git a/CodeGen/src/EmitCommonX64.cpp b/CodeGen/src/EmitCommonX64.cpp index 9136add85..b6d8b85ec 100644 --- a/CodeGen/src/EmitCommonX64.cpp +++ b/CodeGen/src/EmitCommonX64.cpp @@ -279,32 +279,37 @@ void emitUpdateBase(AssemblyBuilderX64& build) build.mov(rBase, qword[rState + offsetof(lua_State, base)]); } -// Note: only uses rax/rdx, the caller may use other registers -static void emitSetSavedPc(AssemblyBuilderX64& build, int pcpos) +static void emitSetSavedPc(IrRegAllocX64& regs, AssemblyBuilderX64& build, int pcpos) { - build.mov(rdx, sCode); - build.add(rdx, pcpos * sizeof(Instruction)); - build.mov(rax, qword[rState + offsetof(lua_State, ci)]); - build.mov(qword[rax + offsetof(CallInfo, savedpc)], rdx); + ScopedRegX64 tmp1{regs, SizeX64::qword}; + ScopedRegX64 tmp2{regs, SizeX64::qword}; + + build.mov(tmp1.reg, sCode); + build.add(tmp1.reg, pcpos * sizeof(Instruction)); + build.mov(tmp2.reg, qword[rState + offsetof(lua_State, ci)]); + build.mov(qword[tmp2.reg + offsetof(CallInfo, savedpc)], tmp1.reg); } -void emitInterrupt(AssemblyBuilderX64& build, int pcpos) +void emitInterrupt(IrRegAllocX64& regs, AssemblyBuilderX64& build, int pcpos) { Label skip; + ScopedRegX64 tmp{regs, SizeX64::qword}; + // Skip if there is no interrupt set - build.mov(r8, qword[rState + offsetof(lua_State, global)]); - build.mov(r8, qword[r8 + offsetof(global_State, cb.interrupt)]); - build.test(r8, r8); + build.mov(tmp.reg, qword[rState + offsetof(lua_State, global)]); + build.mov(tmp.reg, qword[tmp.reg + offsetof(global_State, cb.interrupt)]); + build.test(tmp.reg, tmp.reg); build.jcc(ConditionX64::Zero, skip); - emitSetSavedPc(build, pcpos + 1); // uses rax/rdx + emitSetSavedPc(regs, build, pcpos + 1); // Call interrupt // TODO: This code should move to the end of the function, or even be outlined so that it can be shared by multiple interruptible instructions - build.mov(rArg1, rState); - build.mov(dwordReg(rArg2), -1); // function accepts 'int' here and using qword reg would've forced 8 byte constant here - build.call(r8); + IrCallWrapperX64 callWrap(regs, build); + callWrap.addArgument(SizeX64::qword, rState); + callWrap.addArgument(SizeX64::dword, -1); + callWrap.call(tmp.release()); emitUpdateBase(build); // interrupt may have reallocated stack @@ -320,41 +325,23 @@ void emitInterrupt(AssemblyBuilderX64& build, int pcpos) build.setLabel(skip); } -void emitFallback(AssemblyBuilderX64& build, NativeState& data, int op, int pcpos) +void emitFallback(IrRegAllocX64& regs, AssemblyBuilderX64& build, NativeState& data, int op, int pcpos) { - NativeFallback& opinfo = data.context.fallback[op]; - LUAU_ASSERT(opinfo.fallback); - - if (build.logText) - build.logAppend("; fallback\n"); + LUAU_ASSERT(data.context.fallback[op]); // fallback(L, instruction, base, k) - build.mov(rArg1, rState); - build.mov(rArg2, sCode); - build.add(rArg2, pcpos * sizeof(Instruction)); - build.mov(rArg3, rBase); - build.mov(rArg4, rConstants); - build.call(qword[rNativeContext + offsetof(NativeContext, fallback) + op * sizeof(NativeFallback) + offsetof(NativeFallback, fallback)]); - - emitUpdateBase(build); - - // Some instructions may jump to a different instruction or a completely different function - if (opinfo.flags & kFallbackUpdatePc) - { - build.mov(rcx, sClosure); - build.mov(rcx, qword[rcx + offsetof(Closure, l.p)]); + IrCallWrapperX64 callWrap(regs, build); + callWrap.addArgument(SizeX64::qword, rState); - // Get instruction index from returned instruction pointer - // To get instruction index from instruction pointer, we need to divide byte offset by 4 - // But we will actually need to scale instruction index by 8 back to byte offset later so it cancels out - build.sub(rax, sCode); + RegisterX64 reg = callWrap.suggestNextArgumentRegister(SizeX64::qword); + build.mov(reg, sCode); + callWrap.addArgument(SizeX64::qword, addr[reg + pcpos * sizeof(Instruction)]); - build.mov(rdx, qword[rcx + offsetofProtoExecData]); + callWrap.addArgument(SizeX64::qword, rBase); + callWrap.addArgument(SizeX64::qword, rConstants); + callWrap.call(qword[rNativeContext + offsetof(NativeContext, fallback) + op * sizeof(FallbackFn)]); - // Get new instruction location and jump to it - build.mov(rcx, qword[rdx + offsetof(NativeProto, instTargets)]); - build.jmp(qword[rax * 2 + rcx]); - } + emitUpdateBase(build); } void emitContinueCallInVm(AssemblyBuilderX64& build) diff --git a/CodeGen/src/EmitCommonX64.h b/CodeGen/src/EmitCommonX64.h index 6aac5a1ec..d4684fe85 100644 --- a/CodeGen/src/EmitCommonX64.h +++ b/CodeGen/src/EmitCommonX64.h @@ -34,6 +34,8 @@ namespace X64 struct IrRegAllocX64; +constexpr uint32_t kFunctionAlignment = 32; + // Data that is very common to access is placed in non-volatile registers constexpr RegisterX64 rState = r15; // lua_State* L constexpr RegisterX64 rBase = r14; // StkId base @@ -134,7 +136,7 @@ inline OperandX64 luauNodeKeyValue(RegisterX64 node) // Note: tag has dirty upper bits inline OperandX64 luauNodeKeyTag(RegisterX64 node) { - return dword[node + offsetof(LuaNode, key) + kOffsetOfLuaNodeTag]; + return dword[node + offsetof(LuaNode, key) + kOffsetOfTKeyTag]; } inline OperandX64 luauNodeValue(RegisterX64 node) @@ -162,12 +164,6 @@ inline void jumpIfTagIsNot(AssemblyBuilderX64& build, int ri, lua_Type tag, Labe build.jcc(ConditionX64::NotEqual, label); } -inline void jumpIfTagIsNot(AssemblyBuilderX64& build, RegisterX64 reg, lua_Type tag, Label& label) -{ - build.cmp(dword[reg + offsetof(TValue, tt)], tag); - build.jcc(ConditionX64::NotEqual, label); -} - // Note: fallthrough label should be placed after this condition inline void jumpIfFalsy(AssemblyBuilderX64& build, int ri, Label& target, Label& fallthrough) { @@ -188,26 +184,6 @@ inline void jumpIfTruthy(AssemblyBuilderX64& build, int ri, Label& target, Label build.jcc(ConditionX64::NotEqual, target); // true if boolean value is 'true' } -inline void jumpIfMetatablePresent(AssemblyBuilderX64& build, RegisterX64 table, Label& target) -{ - build.cmp(qword[table + offsetof(Table, metatable)], 0); - build.jcc(ConditionX64::NotEqual, target); -} - -inline void jumpIfUnsafeEnv(AssemblyBuilderX64& build, RegisterX64 tmp, Label& label) -{ - build.mov(tmp, sClosure); - build.mov(tmp, qword[tmp + offsetof(Closure, env)]); - build.test(byte[tmp + offsetof(Table, safeenv)], 1); - build.jcc(ConditionX64::Zero, label); // Not a safe environment -} - -inline void jumpIfTableIsReadOnly(AssemblyBuilderX64& build, RegisterX64 table, Label& label) -{ - build.cmp(byte[table + offsetof(Table, readonly)], 0); - build.jcc(ConditionX64::NotEqual, label); -} - inline void jumpIfNodeKeyTagIsNot(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 node, lua_Type tag, Label& label) { tmp.size = SizeX64::dword; @@ -224,13 +200,6 @@ inline void jumpIfNodeValueTagIs(AssemblyBuilderX64& build, RegisterX64 node, lu build.jcc(ConditionX64::Equal, label); } -inline void jumpIfNodeHasNext(AssemblyBuilderX64& build, RegisterX64 node, Label& label) -{ - build.mov(ecx, dword[node + offsetof(LuaNode, key) + kOffsetOfLuaNodeNext]); - build.shr(ecx, kNextBitOffset); - build.jcc(ConditionX64::NotZero, label); -} - inline void jumpIfNodeKeyNotInExpectedSlot(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 node, OperandX64 expectedKey, Label& label) { jumpIfNodeKeyTagIsNot(build, tmp, node, LUA_TSTRING, label); @@ -260,8 +229,8 @@ void callStepGc(IrRegAllocX64& regs, AssemblyBuilderX64& build); void emitExit(AssemblyBuilderX64& build, bool continueInVm); void emitUpdateBase(AssemblyBuilderX64& build); -void emitInterrupt(AssemblyBuilderX64& build, int pcpos); -void emitFallback(AssemblyBuilderX64& build, NativeState& data, int op, int pcpos); +void emitInterrupt(IrRegAllocX64& regs, AssemblyBuilderX64& build, int pcpos); +void emitFallback(IrRegAllocX64& regs, AssemblyBuilderX64& build, NativeState& data, int op, int pcpos); void emitContinueCallInVm(AssemblyBuilderX64& build); diff --git a/CodeGen/src/EmitInstructionA64.cpp b/CodeGen/src/EmitInstructionA64.cpp deleted file mode 100644 index 400ba77e0..000000000 --- a/CodeGen/src/EmitInstructionA64.cpp +++ /dev/null @@ -1,74 +0,0 @@ -// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "EmitInstructionA64.h" - -#include "Luau/AssemblyBuilderA64.h" - -#include "EmitCommonA64.h" -#include "NativeState.h" -#include "CustomExecUtils.h" - -namespace Luau -{ -namespace CodeGen -{ -namespace A64 -{ - -void emitInstReturn(AssemblyBuilderA64& build, ModuleHelpers& helpers, int ra, int n) -{ - // callFallback(L, ra, n) - build.mov(x0, rState); - build.add(x1, rBase, uint16_t(ra * sizeof(TValue))); - build.mov(w2, n); - build.ldr(x3, mem(rNativeContext, offsetof(NativeContext, returnFallback))); - build.blr(x3); - - // reentry with x0=closure (NULL will trigger exit) - build.b(helpers.reentry); -} - -void emitInstCall(AssemblyBuilderA64& build, ModuleHelpers& helpers, int ra, int nparams, int nresults) -{ - // argtop = (nparams == LUA_MULTRET) ? L->top : ra + 1 + nparams; - if (nparams == LUA_MULTRET) - build.ldr(x2, mem(rState, offsetof(lua_State, top))); - else - build.add(x2, rBase, uint16_t((ra + 1 + nparams) * sizeof(TValue))); - - // callFallback(L, ra, argtop, nresults) - build.mov(x0, rState); - build.add(x1, rBase, uint16_t(ra * sizeof(TValue))); - build.mov(w3, nresults); - build.ldr(x4, mem(rNativeContext, offsetof(NativeContext, callFallback))); - build.blr(x4); - - // reentry with x0=closure (NULL will trigger exit) - build.b(helpers.reentry); -} - -void emitInstGetImport(AssemblyBuilderA64& build, int ra, uint32_t aux) -{ - // luaV_getimport(L, cl->env, k, aux, /* propagatenil= */ false) - build.mov(x0, rState); - build.ldr(x1, mem(rClosure, offsetof(Closure, env))); - build.mov(x2, rConstants); - build.mov(w3, aux); - build.mov(w4, 0); - build.ldr(x5, mem(rNativeContext, offsetof(NativeContext, luaV_getimport))); - build.blr(x5); - - emitUpdateBase(build); - - // setobj2s(L, ra, L->top - 1) - build.ldr(x0, mem(rState, offsetof(lua_State, top))); - build.sub(x0, x0, sizeof(TValue)); - build.ldr(q0, x0); - build.str(q0, mem(rBase, ra * sizeof(TValue))); - - // L->top-- - build.str(x0, mem(rState, offsetof(lua_State, top))); -} - -} // namespace A64 -} // namespace CodeGen -} // namespace Luau diff --git a/CodeGen/src/EmitInstructionA64.h b/CodeGen/src/EmitInstructionA64.h deleted file mode 100644 index 278d8e8e3..000000000 --- a/CodeGen/src/EmitInstructionA64.h +++ /dev/null @@ -1,24 +0,0 @@ -// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#pragma once - -#include - -namespace Luau -{ -namespace CodeGen -{ - -struct ModuleHelpers; - -namespace A64 -{ - -class AssemblyBuilderA64; - -void emitInstReturn(AssemblyBuilderA64& build, ModuleHelpers& helpers, int ra, int n); -void emitInstCall(AssemblyBuilderA64& build, ModuleHelpers& helpers, int ra, int nparams, int nresults); -void emitInstGetImport(AssemblyBuilderA64& build, int ra, uint32_t aux); - -} // namespace A64 -} // namespace CodeGen -} // namespace Luau diff --git a/CodeGen/src/EmitInstructionX64.cpp b/CodeGen/src/EmitInstructionX64.cpp index c0a64274a..9a10bfdc1 100644 --- a/CodeGen/src/EmitInstructionX64.cpp +++ b/CodeGen/src/EmitInstructionX64.cpp @@ -415,7 +415,7 @@ void emitInstSetList(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int callBarrierTableFast(regs, build, table, {}); } -void emitinstForGLoop(AssemblyBuilderX64& build, int ra, int aux, Label& loopRepeat, Label& loopExit) +void emitInstForGLoop(AssemblyBuilderX64& build, int ra, int aux, Label& loopRepeat) { // ipairs-style traversal is handled in IR LUAU_ASSERT(aux >= 0); @@ -484,78 +484,6 @@ void emitinstForGLoop(AssemblyBuilderX64& build, int ra, int aux, Label& loopRep build.jcc(ConditionX64::NotZero, loopRepeat); } -void emitinstForGLoopFallback(AssemblyBuilderX64& build, int ra, int aux, Label& loopRepeat) -{ - build.mov(rArg1, rState); - build.mov(dwordReg(rArg2), ra); - build.mov(dwordReg(rArg3), aux); - build.call(qword[rNativeContext + offsetof(NativeContext, forgLoopNonTableFallback)]); - emitUpdateBase(build); - build.test(al, al); - build.jcc(ConditionX64::NotZero, loopRepeat); -} - -void emitInstForGPrepXnextFallback(AssemblyBuilderX64& build, int pcpos, int ra, Label& target) -{ - build.mov(rArg1, rState); - build.lea(rArg2, luauRegAddress(ra)); - build.mov(dwordReg(rArg3), pcpos + 1); - build.call(qword[rNativeContext + offsetof(NativeContext, forgPrepXnextFallback)]); - build.jmp(target); -} - -void emitInstGetImportFallback(AssemblyBuilderX64& build, int ra, uint32_t aux) -{ - build.mov(rax, sClosure); - - // luaV_getimport(L, cl->env, k, aux, /* propagatenil= */ false) - build.mov(rArg1, rState); - build.mov(rArg2, qword[rax + offsetof(Closure, env)]); - build.mov(rArg3, rConstants); - build.mov(dwordReg(rArg4), aux); - - if (build.abi == ABIX64::Windows) - build.mov(sArg5, 0); - else - build.xor_(rArg5, rArg5); - - build.call(qword[rNativeContext + offsetof(NativeContext, luaV_getimport)]); - - emitUpdateBase(build); - - // setobj2s(L, ra, L->top - 1) - build.mov(rax, qword[rState + offsetof(lua_State, top)]); - build.sub(rax, sizeof(TValue)); - build.vmovups(xmm0, xmmword[rax]); - build.vmovups(luauReg(ra), xmm0); - - // L->top-- - build.mov(qword[rState + offsetof(lua_State, top)], rax); -} - -void emitInstCoverage(AssemblyBuilderX64& build, int pcpos) -{ - build.mov(rcx, sCode); - build.add(rcx, pcpos * sizeof(Instruction)); - - // hits = LUAU_INSN_E(*pc) - build.mov(edx, dword[rcx]); - build.sar(edx, 8); - - // hits = (hits < (1 << 23) - 1) ? hits + 1 : hits; - build.xor_(eax, eax); - build.cmp(edx, (1 << 23) - 1); - build.setcc(ConditionX64::NotEqual, al); - build.add(edx, eax); - - - // VM_PATCH_E(pc, hits); - build.sal(edx, 8); - build.movzx(eax, byte[rcx]); - build.or_(eax, edx); - build.mov(dword[rcx], eax); -} - } // namespace X64 } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/EmitInstructionX64.h b/CodeGen/src/EmitInstructionX64.h index d58e13310..84fe11309 100644 --- a/CodeGen/src/EmitInstructionX64.h +++ b/CodeGen/src/EmitInstructionX64.h @@ -20,11 +20,7 @@ struct IrRegAllocX64; void emitInstCall(AssemblyBuilderX64& build, ModuleHelpers& helpers, int ra, int nparams, int nresults); void emitInstReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers, int ra, int actualResults); void emitInstSetList(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int rb, int count, uint32_t index); -void emitinstForGLoop(AssemblyBuilderX64& build, int ra, int aux, Label& loopRepeat, Label& loopExit); -void emitinstForGLoopFallback(AssemblyBuilderX64& build, int ra, int aux, Label& loopRepeat); -void emitInstForGPrepXnextFallback(AssemblyBuilderX64& build, int pcpos, int ra, Label& target); -void emitInstGetImportFallback(AssemblyBuilderX64& build, int ra, uint32_t aux); -void emitInstCoverage(AssemblyBuilderX64& build, int pcpos); +void emitInstForGLoop(AssemblyBuilderX64& build, int ra, int aux, Label& loopRepeat); } // namespace X64 } // namespace CodeGen diff --git a/CodeGen/src/Fallbacks.cpp b/CodeGen/src/Fallbacks.cpp index e84ee2136..1c0dce57a 100644 --- a/CodeGen/src/Fallbacks.cpp +++ b/CodeGen/src/Fallbacks.cpp @@ -416,6 +416,44 @@ const Instruction* execute_LOP_NAMECALL(lua_State* L, const Instruction* pc, Stk return pc; } +const Instruction* execute_LOP_SETLIST(lua_State* L, const Instruction* pc, StkId base, TValue* k) +{ + [[maybe_unused]] Closure* cl = clvalue(L->ci->func); + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + StkId rb = &base[LUAU_INSN_B(insn)]; // note: this can point to L->top if c == LUA_MULTRET making VM_REG unsafe to use + int c = LUAU_INSN_C(insn) - 1; + uint32_t index = *pc++; + + if (c == LUA_MULTRET) + { + c = int(L->top - rb); + L->top = L->ci->top; + } + + Table* h = hvalue(ra); + + // TODO: we really don't need this anymore + if (!ttistable(ra)) + return NULL; // temporary workaround to weaken a rather powerful exploitation primitive in case of a MITM attack on bytecode + + int last = index + c - 1; + if (last > h->sizearray) + { + VM_PROTECT_PC(); // luaH_resizearray may fail due to OOM + + luaH_resizearray(L, h, last); + } + + TValue* array = h->array; + + for (int i = 0; i < c; ++i) + setobj2t(L, &array[index + i - 1], rb + i); + + luaC_barrierfast(L, h); + return pc; +} + const Instruction* execute_LOP_FORGPREP(lua_State* L, const Instruction* pc, StkId base, TValue* k) { [[maybe_unused]] Closure* cl = clvalue(L->ci->func); diff --git a/CodeGen/src/Fallbacks.h b/CodeGen/src/Fallbacks.h index bfc0e2b7c..0d2d218a0 100644 --- a/CodeGen/src/Fallbacks.h +++ b/CodeGen/src/Fallbacks.h @@ -16,6 +16,7 @@ const Instruction* execute_LOP_GETTABLEKS(lua_State* L, const Instruction* pc, S const Instruction* execute_LOP_SETTABLEKS(lua_State* L, const Instruction* pc, StkId base, TValue* k); const Instruction* execute_LOP_NEWCLOSURE(lua_State* L, const Instruction* pc, StkId base, TValue* k); const Instruction* execute_LOP_NAMECALL(lua_State* L, const Instruction* pc, StkId base, TValue* k); +const Instruction* execute_LOP_SETLIST(lua_State* L, const Instruction* pc, StkId base, TValue* k); const Instruction* execute_LOP_FORGPREP(lua_State* L, const Instruction* pc, StkId base, TValue* k); const Instruction* execute_LOP_GETVARARGS(lua_State* L, const Instruction* pc, StkId base, TValue* k); const Instruction* execute_LOP_DUPCLOSURE(lua_State* L, const Instruction* pc, StkId base, TValue* k); diff --git a/CodeGen/src/IrAnalysis.cpp b/CodeGen/src/IrAnalysis.cpp index 2246e5c5e..f3870e96b 100644 --- a/CodeGen/src/IrAnalysis.cpp +++ b/CodeGen/src/IrAnalysis.cpp @@ -354,6 +354,8 @@ static RegisterSet computeBlockLiveInRegSet(IrFunction& function, const IrBlock& case IrCmd::RETURN: useRange(vmRegOp(inst.a), function.intOp(inst.b)); break; + + // TODO: FASTCALL is more restrictive than INVOKE_FASTCALL; we should either determine the exact semantics, or rework it case IrCmd::FASTCALL: case IrCmd::INVOKE_FASTCALL: if (int count = function.intOp(inst.e); count != -1) diff --git a/CodeGen/src/IrBuilder.cpp b/CodeGen/src/IrBuilder.cpp index 48c0e25c0..d86dfe058 100644 --- a/CodeGen/src/IrBuilder.cpp +++ b/CodeGen/src/IrBuilder.cpp @@ -468,7 +468,8 @@ void IrBuilder::clone(const IrBlock& source, bool removeCurrentTerminator) IrInst clone = function.instructions[index]; // Skip pseudo instructions to make clone more compact, but validate that they have no users - if (isPseudo(clone.cmd)) + // But if substitution tracks a location, that tracking has to be preserved + if (isPseudo(clone.cmd) && !(clone.cmd == IrCmd::SUBSTITUTE && clone.b.kind != IrOpKind::None)) { LUAU_ASSERT(clone.useCount == 0); continue; diff --git a/CodeGen/src/IrCallWrapperX64.cpp b/CodeGen/src/IrCallWrapperX64.cpp index 8ac5f8bcf..f466df4a8 100644 --- a/CodeGen/src/IrCallWrapperX64.cpp +++ b/CodeGen/src/IrCallWrapperX64.cpp @@ -13,6 +13,10 @@ namespace CodeGen namespace X64 { +static const std::array kWindowsGprOrder = {rcx, rdx, r8, r9, addr[rsp + 32], addr[rsp + 40]}; +static const std::array kSystemvGprOrder = {rdi, rsi, rdx, rcx, r8, r9}; +static const std::array kXmmOrder = {xmm0, xmm1, xmm2, xmm3}; // Common order for first 4 fp arguments on Windows/SystemV + static bool sameUnderlyingRegister(RegisterX64 a, RegisterX64 b) { SizeX64 underlyingSizeA = a.size == SizeX64::xmmword ? SizeX64::xmmword : SizeX64::qword; @@ -37,21 +41,35 @@ void IrCallWrapperX64::addArgument(SizeX64 targetSize, OperandX64 source, IrOp s LUAU_ASSERT(instIdx != kInvalidInstIdx || sourceOp.kind == IrOpKind::None); LUAU_ASSERT(argCount < kMaxCallArguments); - args[argCount++] = {targetSize, source, sourceOp}; + CallArgument& arg = args[argCount++]; + arg = {targetSize, source, sourceOp}; + + arg.target = getNextArgumentTarget(targetSize); + + if (build.abi == ABIX64::Windows) + { + // On Windows, gpr/xmm register positions move in sync + gprPos++; + xmmPos++; + } + else + { + if (targetSize == SizeX64::xmmword) + xmmPos++; + else + gprPos++; + } } void IrCallWrapperX64::addArgument(SizeX64 targetSize, ScopedRegX64& scopedReg) { - LUAU_ASSERT(argCount < kMaxCallArguments); - args[argCount++] = {targetSize, scopedReg.release(), {}}; + addArgument(targetSize, scopedReg.release(), {}); } void IrCallWrapperX64::call(const OperandX64& func) { funcOp = func; - assignTargetRegisters(); - countRegisterUses(); for (int i = 0; i < argCount; ++i) @@ -190,44 +208,33 @@ void IrCallWrapperX64::call(const OperandX64& func) build.call(funcOp); } -void IrCallWrapperX64::assignTargetRegisters() +RegisterX64 IrCallWrapperX64::suggestNextArgumentRegister(SizeX64 size) const { - static const std::array kWindowsGprOrder = {rcx, rdx, r8, r9, addr[rsp + 32], addr[rsp + 40]}; - static const std::array kSystemvGprOrder = {rdi, rsi, rdx, rcx, r8, r9}; - - const std::array& gprOrder = build.abi == ABIX64::Windows ? kWindowsGprOrder : kSystemvGprOrder; - static const std::array kXmmOrder = {xmm0, xmm1, xmm2, xmm3}; // Common order for first 4 fp arguments on Windows/SystemV + OperandX64 target = getNextArgumentTarget(size); - int gprPos = 0; - int xmmPos = 0; + return target.cat == CategoryX64::reg ? regs.takeReg(target.base, kInvalidInstIdx) : regs.allocReg(size, kInvalidInstIdx); +} - for (int i = 0; i < argCount; i++) +OperandX64 IrCallWrapperX64::getNextArgumentTarget(SizeX64 size) const +{ + if (size == SizeX64::xmmword) { - CallArgument& arg = args[i]; + LUAU_ASSERT(size_t(xmmPos) < kXmmOrder.size()); + return kXmmOrder[xmmPos]; + } - if (arg.targetSize == SizeX64::xmmword) - { - LUAU_ASSERT(size_t(xmmPos) < kXmmOrder.size()); - arg.target = kXmmOrder[xmmPos++]; + const std::array& gprOrder = build.abi == ABIX64::Windows ? kWindowsGprOrder : kSystemvGprOrder; - if (build.abi == ABIX64::Windows) - gprPos++; // On Windows, gpr/xmm register positions move in sync - } - else - { - LUAU_ASSERT(size_t(gprPos) < gprOrder.size()); - arg.target = gprOrder[gprPos++]; + LUAU_ASSERT(size_t(gprPos) < gprOrder.size()); + OperandX64 target = gprOrder[gprPos]; - if (build.abi == ABIX64::Windows) - xmmPos++; // On Windows, gpr/xmm register positions move in sync + // Keep requested argument size + if (target.cat == CategoryX64::reg) + target.base.size = size; + else if (target.cat == CategoryX64::mem) + target.memSize = size; - // Keep requested argument size - if (arg.target.cat == CategoryX64::reg) - arg.target.base.size = arg.targetSize; - else if (arg.target.cat == CategoryX64::mem) - arg.target.memSize = arg.targetSize; - } - } + return target; } void IrCallWrapperX64::countRegisterUses() @@ -376,7 +383,7 @@ RegisterX64 IrCallWrapperX64::findConflictingTarget() const void IrCallWrapperX64::renameConflictingRegister(RegisterX64 conflict) { // Get a fresh register - RegisterX64 freshReg = conflict.size == SizeX64::xmmword ? regs.allocXmmReg(kInvalidInstIdx) : regs.allocGprReg(conflict.size, kInvalidInstIdx); + RegisterX64 freshReg = regs.allocReg(conflict.size, kInvalidInstIdx); if (conflict.size == SizeX64::xmmword) build.vmovsd(freshReg, conflict, conflict); diff --git a/CodeGen/src/IrLoweringA64.cpp b/CodeGen/src/IrLoweringA64.cpp index 7f0305cc2..3f05d537a 100644 --- a/CodeGen/src/IrLoweringA64.cpp +++ b/CodeGen/src/IrLoweringA64.cpp @@ -8,7 +8,6 @@ #include "Luau/IrUtils.h" #include "EmitCommonA64.h" -#include "EmitInstructionA64.h" #include "NativeState.h" #include "lstate.h" @@ -27,13 +26,14 @@ namespace A64 #ifdef TRACE struct LoweringStatsA64 { - size_t can; + size_t missing; size_t total; ~LoweringStatsA64() { if (total) - printf("A64 lowering succeeded for %.1f%% functions (%d/%d)\n", double(can) / double(total) * 100, int(can), int(total)); + printf("A64 lowering succeeded for %.1f%% functions (%d/%d)\n", double(total - missing) / double(total) * 100, int(total - missing), + int(total)); } } gStatsA64; #endif @@ -78,32 +78,230 @@ inline ConditionA64 getConditionFP(IrCondition cond) } } -// TODO: instead of temp1/temp2 we can take a register that we will use for ra->value; that way callers to this function will be able to use it when -// calling luaC_barrier* -static void checkObjectBarrierConditions(AssemblyBuilderA64& build, RegisterA64 object, RegisterA64 temp1, RegisterA64 temp2, int ra, Label& skip) +static void checkObjectBarrierConditions(AssemblyBuilderA64& build, RegisterA64 object, RegisterA64 temp, int ra, Label& skip) { - RegisterA64 temp1w = castReg(KindA64::w, temp1); - RegisterA64 temp2w = castReg(KindA64::w, temp2); + RegisterA64 tempw = castReg(KindA64::w, temp); // iscollectable(ra) - build.ldr(temp1w, mem(rBase, ra * sizeof(TValue) + offsetof(TValue, tt))); - build.cmp(temp1w, LUA_TSTRING); + build.ldr(tempw, mem(rBase, ra * sizeof(TValue) + offsetof(TValue, tt))); + build.cmp(tempw, LUA_TSTRING); build.b(ConditionA64::Less, skip); // isblack(obj2gco(o)) // TODO: conditional bit test with BLACKBIT - build.ldrb(temp1w, mem(object, offsetof(GCheader, marked))); - build.mov(temp2w, bitmask(BLACKBIT)); - build.and_(temp1w, temp1w, temp2w); - build.cbz(temp1w, skip); + build.ldrb(tempw, mem(object, offsetof(GCheader, marked))); + build.tst(tempw, bitmask(BLACKBIT)); + build.b(ConditionA64::Equal, skip); // Equal = Zero after tst // iswhite(gcvalue(ra)) - // TODO: tst with bitmask(WHITE0BIT, WHITE1BIT) - build.ldr(temp1, mem(rBase, ra * sizeof(TValue) + offsetof(TValue, value))); - build.ldrb(temp1w, mem(temp1, offsetof(GCheader, marked))); - build.mov(temp2w, bit2mask(WHITE0BIT, WHITE1BIT)); - build.and_(temp1w, temp1w, temp2w); - build.cbz(temp1w, skip); + build.ldr(temp, mem(rBase, ra * sizeof(TValue) + offsetof(TValue, value))); + build.ldrb(tempw, mem(temp, offsetof(GCheader, marked))); + build.tst(tempw, bit2mask(WHITE0BIT, WHITE1BIT)); + build.b(ConditionA64::Equal, skip); // Equal = Zero after tst +} + +static void emitAddOffset(AssemblyBuilderA64& build, RegisterA64 dst, RegisterA64 src, size_t offset) +{ + LUAU_ASSERT(dst != src); + LUAU_ASSERT(offset <= INT_MAX); + + if (offset <= AssemblyBuilderA64::kMaxImmediate) + { + build.add(dst, src, uint16_t(offset)); + } + else + { + build.mov(dst, int(offset)); + build.add(dst, dst, src); + } +} + +static void emitFallback(AssemblyBuilderA64& build, int op, int pcpos) +{ + // fallback(L, instruction, base, k) + build.mov(x0, rState); + emitAddOffset(build, x1, rCode, pcpos * sizeof(Instruction)); + build.mov(x2, rBase); + build.mov(x3, rConstants); + build.ldr(x4, mem(rNativeContext, offsetof(NativeContext, fallback) + op * sizeof(FallbackFn))); + build.blr(x4); + + emitUpdateBase(build); +} + +static void emitInvokeLibm1(AssemblyBuilderA64& build, size_t func, int res, int arg) +{ + build.ldr(d0, mem(rBase, arg * sizeof(TValue) + offsetof(TValue, value.n))); + build.ldr(x0, mem(rNativeContext, uint32_t(func))); + build.blr(x0); + build.str(d0, mem(rBase, res * sizeof(TValue) + offsetof(TValue, value.n))); +} + +static void emitInvokeLibm2(AssemblyBuilderA64& build, size_t func, int res, int arg, IrOp args, bool argsInt = false) +{ + if (args.kind == IrOpKind::VmReg) + build.ldr(d1, mem(rBase, args.index * sizeof(TValue) + offsetof(TValue, value.n))); + else if (args.kind == IrOpKind::VmConst) + { + size_t constantOffset = args.index * sizeof(TValue) + offsetof(TValue, value.n); + + // Note: cumulative offset is guaranteed to be divisible by 8 (since we're loading a double); we can use that to expand the useful range that + // doesn't require temporaries + if (constantOffset / 8 <= AddressA64::kMaxOffset) + { + build.ldr(d1, mem(rConstants, int(constantOffset))); + } + else + { + emitAddOffset(build, x0, rConstants, constantOffset); + build.ldr(d1, x0); + } + } + else + LUAU_ASSERT(!"Unsupported instruction form"); + + if (argsInt) + build.fcvtzs(w0, d1); + + build.ldr(d0, mem(rBase, arg * sizeof(TValue) + offsetof(TValue, value.n))); + build.ldr(x1, mem(rNativeContext, uint32_t(func))); + build.blr(x1); + build.str(d0, mem(rBase, res * sizeof(TValue) + offsetof(TValue, value.n))); +} + +static void emitInvokeLibm1P(AssemblyBuilderA64& build, size_t func, int arg) +{ + build.ldr(d0, mem(rBase, arg * sizeof(TValue) + offsetof(TValue, value.n))); + build.add(x0, sp, sTemporary.data); // sp-relative offset + build.ldr(x1, mem(rNativeContext, uint32_t(func))); + build.blr(x1); +} + +static bool emitBuiltin(AssemblyBuilderA64& build, IrRegAllocA64& regs, int bfid, int res, int arg, IrOp args, int nparams, int nresults) +{ + switch (bfid) + { + case LBF_MATH_EXP: + LUAU_ASSERT(nparams == 1 && nresults == 1); + emitInvokeLibm1(build, offsetof(NativeContext, libm_exp), res, arg); + return true; + case LBF_MATH_FMOD: + LUAU_ASSERT(nparams == 2 && nresults == 1); + emitInvokeLibm2(build, offsetof(NativeContext, libm_fmod), res, arg, args); + return true; + case LBF_MATH_ASIN: + LUAU_ASSERT(nparams == 1 && nresults == 1); + emitInvokeLibm1(build, offsetof(NativeContext, libm_asin), res, arg); + return true; + case LBF_MATH_SIN: + LUAU_ASSERT(nparams == 1 && nresults == 1); + emitInvokeLibm1(build, offsetof(NativeContext, libm_sin), res, arg); + return true; + case LBF_MATH_SINH: + LUAU_ASSERT(nparams == 1 && nresults == 1); + emitInvokeLibm1(build, offsetof(NativeContext, libm_sinh), res, arg); + return true; + case LBF_MATH_ACOS: + LUAU_ASSERT(nparams == 1 && nresults == 1); + emitInvokeLibm1(build, offsetof(NativeContext, libm_acos), res, arg); + return true; + case LBF_MATH_COS: + LUAU_ASSERT(nparams == 1 && nresults == 1); + emitInvokeLibm1(build, offsetof(NativeContext, libm_cos), res, arg); + return true; + case LBF_MATH_COSH: + LUAU_ASSERT(nparams == 1 && nresults == 1); + emitInvokeLibm1(build, offsetof(NativeContext, libm_cosh), res, arg); + return true; + case LBF_MATH_ATAN: + LUAU_ASSERT(nparams == 1 && nresults == 1); + emitInvokeLibm1(build, offsetof(NativeContext, libm_atan), res, arg); + return true; + case LBF_MATH_TAN: + LUAU_ASSERT(nparams == 1 && nresults == 1); + emitInvokeLibm1(build, offsetof(NativeContext, libm_tan), res, arg); + return true; + case LBF_MATH_TANH: + LUAU_ASSERT(nparams == 1 && nresults == 1); + emitInvokeLibm1(build, offsetof(NativeContext, libm_tanh), res, arg); + return true; + case LBF_MATH_ATAN2: + LUAU_ASSERT(nparams == 2 && nresults == 1); + emitInvokeLibm2(build, offsetof(NativeContext, libm_atan2), res, arg, args); + return true; + case LBF_MATH_LOG10: + LUAU_ASSERT(nparams == 1 && nresults == 1); + emitInvokeLibm1(build, offsetof(NativeContext, libm_log10), res, arg); + return true; + case LBF_MATH_LOG: + LUAU_ASSERT((nparams == 1 || nparams == 2) && nresults == 1); + // TODO: IR builtin lowering assumes that the only valid 2-argument call is log2; ideally, we use a less hacky way to indicate that + if (nparams == 2) + emitInvokeLibm1(build, offsetof(NativeContext, libm_log2), res, arg); + else + emitInvokeLibm1(build, offsetof(NativeContext, libm_log), res, arg); + return true; + case LBF_MATH_LDEXP: + LUAU_ASSERT(nparams == 2 && nresults == 1); + emitInvokeLibm2(build, offsetof(NativeContext, libm_ldexp), res, arg, args, /* argsInt= */ true); + return true; + case LBF_MATH_FREXP: + LUAU_ASSERT(nparams == 1 && (nresults == 1 || nresults == 2)); + emitInvokeLibm1P(build, offsetof(NativeContext, libm_frexp), arg); + build.str(d0, mem(rBase, res * sizeof(TValue) + offsetof(TValue, value.n))); + if (nresults == 2) + { + build.ldr(w0, sTemporary); + build.scvtf(d1, w0); + build.str(d1, mem(rBase, (res + 1) * sizeof(TValue) + offsetof(TValue, value.n))); + } + return true; + case LBF_MATH_MODF: + LUAU_ASSERT(nparams == 1 && (nresults == 1 || nresults == 2)); + emitInvokeLibm1P(build, offsetof(NativeContext, libm_modf), arg); + build.ldr(d1, sTemporary); + build.str(d1, mem(rBase, res * sizeof(TValue) + offsetof(TValue, value.n))); + if (nresults == 2) + build.str(d0, mem(rBase, (res + 1) * sizeof(TValue) + offsetof(TValue, value.n))); + return true; + case LBF_MATH_SIGN: + LUAU_ASSERT(nparams == 1 && nresults == 1); + // TODO: this can be improved with fmov(constant), for now we just load from memory + build.ldr(d0, mem(rBase, arg * sizeof(TValue) + offsetof(TValue, value.n))); + build.fcmpz(d0); + build.adr(x0, 0.0); + build.ldr(d0, x0); + build.adr(x0, 1.0); + build.ldr(d1, x0); + build.fcsel(d0, d1, d0, getConditionFP(IrCondition::Greater)); + build.adr(x0, -1.0); + build.ldr(d1, x0); + build.fcsel(d0, d1, d0, getConditionFP(IrCondition::Less)); + build.str(d0, mem(rBase, res * sizeof(TValue) + offsetof(TValue, value.n))); + return true; + + case LBF_TYPE: + build.ldr(w0, mem(rBase, arg * sizeof(TValue) + offsetof(TValue, tt))); + build.ldr(x1, mem(rState, offsetof(lua_State, global))); + // TODO: this can use load with shifted/extended offset + LUAU_ASSERT(sizeof(TString*) == 8); + build.add(x1, x1, zextReg(w0), 3); + build.ldr(x0, mem(x1, offsetof(global_State, ttname))); + build.str(x0, mem(rBase, res * sizeof(TValue) + offsetof(TValue, value.gc))); + return true; + + case LBF_TYPEOF: + build.mov(x0, rState); + build.add(x1, rBase, uint16_t(arg * sizeof(TValue))); + build.ldr(x2, mem(rNativeContext, offsetof(NativeContext, luaT_objtypenamestr))); + build.blr(x2); + build.str(x0, mem(rBase, res * sizeof(TValue) + offsetof(TValue, value.gc))); + return true; + + default: + LUAU_ASSERT(!"Missing A64 lowering"); + return false; + } } IrLoweringA64::IrLoweringA64(AssemblyBuilderA64& build, ModuleHelpers& helpers, NativeState& data, Proto* proto, IrFunction& function) @@ -116,119 +314,10 @@ IrLoweringA64::IrLoweringA64(AssemblyBuilderA64& build, ModuleHelpers& helpers, { // In order to allocate registers during lowering, we need to know where instruction results are last used updateLastUseLocations(function); -} -// TODO: Eventually this can go away -bool IrLoweringA64::canLower(const IrFunction& function) -{ #ifdef TRACE gStatsA64.total++; #endif - - for (const IrInst& inst : function.instructions) - { - switch (inst.cmd) - { - case IrCmd::NOP: - case IrCmd::LOAD_TAG: - case IrCmd::LOAD_POINTER: - case IrCmd::LOAD_DOUBLE: - case IrCmd::LOAD_INT: - case IrCmd::LOAD_TVALUE: - case IrCmd::LOAD_NODE_VALUE_TV: - case IrCmd::LOAD_ENV: - case IrCmd::GET_ARR_ADDR: - case IrCmd::GET_SLOT_NODE_ADDR: - case IrCmd::GET_HASH_NODE_ADDR: - case IrCmd::STORE_TAG: - case IrCmd::STORE_POINTER: - case IrCmd::STORE_DOUBLE: - case IrCmd::STORE_INT: - case IrCmd::STORE_TVALUE: - case IrCmd::STORE_NODE_VALUE_TV: - case IrCmd::ADD_INT: - case IrCmd::SUB_INT: - case IrCmd::ADD_NUM: - case IrCmd::SUB_NUM: - case IrCmd::MUL_NUM: - case IrCmd::DIV_NUM: - case IrCmd::MOD_NUM: - case IrCmd::POW_NUM: - case IrCmd::MIN_NUM: - case IrCmd::MAX_NUM: - case IrCmd::UNM_NUM: - case IrCmd::FLOOR_NUM: - case IrCmd::CEIL_NUM: - case IrCmd::ROUND_NUM: - case IrCmd::SQRT_NUM: - case IrCmd::ABS_NUM: - case IrCmd::JUMP: - case IrCmd::JUMP_IF_TRUTHY: - case IrCmd::JUMP_IF_FALSY: - case IrCmd::JUMP_EQ_TAG: - case IrCmd::JUMP_EQ_INT: - case IrCmd::JUMP_EQ_POINTER: - case IrCmd::JUMP_CMP_NUM: - case IrCmd::JUMP_CMP_ANY: - case IrCmd::TABLE_LEN: - case IrCmd::NEW_TABLE: - case IrCmd::DUP_TABLE: - case IrCmd::TRY_NUM_TO_INDEX: - case IrCmd::INT_TO_NUM: - case IrCmd::ADJUST_STACK_TO_REG: - case IrCmd::ADJUST_STACK_TO_TOP: - case IrCmd::INVOKE_FASTCALL: - case IrCmd::CHECK_FASTCALL_RES: - case IrCmd::DO_ARITH: - case IrCmd::DO_LEN: - case IrCmd::GET_TABLE: - case IrCmd::SET_TABLE: - case IrCmd::GET_IMPORT: - case IrCmd::CONCAT: - case IrCmd::GET_UPVALUE: - case IrCmd::SET_UPVALUE: - case IrCmd::PREPARE_FORN: - case IrCmd::CHECK_TAG: - case IrCmd::CHECK_READONLY: - case IrCmd::CHECK_NO_METATABLE: - case IrCmd::CHECK_SAFE_ENV: - case IrCmd::CHECK_ARRAY_SIZE: - case IrCmd::CHECK_SLOT_MATCH: - case IrCmd::INTERRUPT: - case IrCmd::CHECK_GC: - case IrCmd::BARRIER_OBJ: - case IrCmd::BARRIER_TABLE_BACK: - case IrCmd::BARRIER_TABLE_FORWARD: - case IrCmd::SET_SAVEDPC: - case IrCmd::CLOSE_UPVALS: - case IrCmd::CAPTURE: - case IrCmd::CALL: - case IrCmd::RETURN: - case IrCmd::FALLBACK_GETGLOBAL: - case IrCmd::FALLBACK_SETGLOBAL: - case IrCmd::FALLBACK_GETTABLEKS: - case IrCmd::FALLBACK_SETTABLEKS: - case IrCmd::FALLBACK_NAMECALL: - case IrCmd::FALLBACK_PREPVARARGS: - case IrCmd::FALLBACK_GETVARARGS: - case IrCmd::FALLBACK_NEWCLOSURE: - case IrCmd::FALLBACK_DUPCLOSURE: - case IrCmd::SUBSTITUTE: - continue; - - default: -#ifdef TRACE - printf("A64 lowering missing %s\n", getCmdName(inst.cmd)); -#endif - return false; - } - } - -#ifdef TRACE - gStatsA64.can++; -#endif - - return true; } void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) @@ -245,14 +334,14 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) case IrCmd::LOAD_POINTER: { inst.regA64 = regs.allocReg(KindA64::x); - AddressA64 addr = tempAddr(inst.a, offsetof(TValue, value)); + AddressA64 addr = tempAddr(inst.a, offsetof(TValue, value.gc)); build.ldr(inst.regA64, addr); break; } case IrCmd::LOAD_DOUBLE: { inst.regA64 = regs.allocReg(KindA64::d); - AddressA64 addr = tempAddr(inst.a, offsetof(TValue, value)); + AddressA64 addr = tempAddr(inst.a, offsetof(TValue, value.n)); build.ldr(inst.regA64, addr); break; } @@ -287,13 +376,21 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) if (inst.b.kind == IrOpKind::Inst) { - // TODO: This is a temporary hack that reads wN register as if it was xN. This should use unsigned extension shift once we support it. - build.add(inst.regA64, inst.regA64, castReg(KindA64::x, regOp(inst.b)), kTValueSizeLog2); + build.add(inst.regA64, inst.regA64, zextReg(regOp(inst.b)), kTValueSizeLog2); } else if (inst.b.kind == IrOpKind::Constant) { - LUAU_ASSERT(size_t(intOp(inst.b)) <= AssemblyBuilderA64::kMaxImmediate >> kTValueSizeLog2); // TODO: handle out of range values - build.add(inst.regA64, inst.regA64, uint16_t(intOp(inst.b) << kTValueSizeLog2)); + // TODO: refactor into a common helper? can't use emitAddOffset because we need a temp register + if (intOp(inst.b) * sizeof(TValue) <= AssemblyBuilderA64::kMaxImmediate) + { + build.add(inst.regA64, inst.regA64, uint16_t(intOp(inst.b) * sizeof(TValue))); + } + else + { + RegisterA64 temp = regs.allocTemp(KindA64::x); + build.mov(temp, intOp(inst.b) * sizeof(TValue)); + build.add(inst.regA64, inst.regA64, temp); + } } else LUAU_ASSERT(!"Unsupported instruction form"); @@ -314,8 +411,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) // note: this may clobber inst.a, so it's important that we don't use it after this build.ldr(inst.regA64, mem(regOp(inst.a), offsetof(Table, node))); - // TODO: This is a temporary hack that reads wN register as if it was xN. This should use unsigned extension shift once we support it. - build.add(inst.regA64, inst.regA64, castReg(KindA64::x, temp2), kLuaNodeSizeLog2); + build.add(inst.regA64, inst.regA64, zextReg(temp2), kLuaNodeSizeLog2); break; } case IrCmd::GET_HASH_NODE_ADDR: @@ -324,18 +420,16 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) RegisterA64 temp1 = regs.allocTemp(KindA64::w); RegisterA64 temp2 = regs.allocTemp(KindA64::w); - // TODO: this can use bic (andnot) to do hash & ~(-1 << lsizenode) instead but we don't support it yet - build.mov(temp1, 1); + // hash & ((1 << lsizenode) - 1) == hash & ~(-1 << lsizenode) + build.mov(temp1, -1); build.ldrb(temp2, mem(regOp(inst.a), offsetof(Table, lsizenode))); build.lsl(temp1, temp1, temp2); - build.sub(temp1, temp1, 1); build.mov(temp2, uintOp(inst.b)); - build.and_(temp2, temp2, temp1); + build.bic(temp2, temp2, temp1); // note: this may clobber inst.a, so it's important that we don't use it after this build.ldr(inst.regA64, mem(regOp(inst.a), offsetof(Table, node))); - // TODO: This is a temporary hack that reads wN register as if it was xN. This should use unsigned extension shift once we support it. - build.add(inst.regA64, inst.regA64, castReg(KindA64::x, temp2), kLuaNodeSizeLog2); + build.add(inst.regA64, inst.regA64, zextReg(temp2), kLuaNodeSizeLog2); break; } case IrCmd::STORE_TAG: @@ -501,6 +595,37 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.fabs(inst.regA64, temp); break; } + case IrCmd::NOT_ANY: + { + inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.a, inst.b}); + + if (inst.a.kind == IrOpKind::Constant) + { + // other cases should've been constant folded + LUAU_ASSERT(tagOp(inst.a) == LUA_TBOOLEAN); + build.eor(inst.regA64, regOp(inst.b), 1); + } + else + { + Label notbool, exit; + + // use the fact that NIL is the only value less than BOOLEAN to do two tag comparisons at once + LUAU_ASSERT(LUA_TNIL == 0 && LUA_TBOOLEAN == 1); + build.cmp(regOp(inst.a), LUA_TBOOLEAN); + build.b(ConditionA64::NotEqual, notbool); + + // boolean => invert value + build.eor(inst.regA64, regOp(inst.b), 1); + build.b(exit); + + // not boolean => result is true iff tag was nil + build.setLabel(notbool); + build.cset(inst.regA64, ConditionA64::Less); + + build.setLabel(exit); + } + break; + } case IrCmd::JUMP: jumpOrFallthrough(blockOp(inst.a), next); break; @@ -537,10 +662,12 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) break; } case IrCmd::JUMP_EQ_TAG: - if (inst.b.kind == IrOpKind::Constant) + if (inst.a.kind == IrOpKind::Inst && inst.b.kind == IrOpKind::Constant) build.cmp(regOp(inst.a), tagOp(inst.b)); - else if (inst.b.kind == IrOpKind::Inst) + else if (inst.a.kind == IrOpKind::Inst && inst.b.kind == IrOpKind::Inst) build.cmp(regOp(inst.a), regOp(inst.b)); + else if (inst.a.kind == IrOpKind::Constant && inst.b.kind == IrOpKind::Inst) + build.cmp(regOp(inst.b), tagOp(inst.a)); else LUAU_ASSERT(!"Unsupported instruction form"); @@ -570,10 +697,20 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) { IrCondition cond = conditionOp(inst.c); - RegisterA64 temp1 = tempDouble(inst.a); - RegisterA64 temp2 = tempDouble(inst.b); + if (inst.b.kind == IrOpKind::Constant && doubleOp(inst.b) == 0.0) + { + RegisterA64 temp = tempDouble(inst.a); + + build.fcmpz(temp); + } + else + { + RegisterA64 temp1 = tempDouble(inst.a); + RegisterA64 temp2 = tempDouble(inst.b); + + build.fcmp(temp1, temp2); + } - build.fcmp(temp1, temp2); build.b(getConditionFP(cond), labelOp(inst.d)); jumpOrFallthrough(blockOp(inst.e), next); break; @@ -607,6 +744,30 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) jumpOrFallthrough(blockOp(inst.e), next); break; } + case IrCmd::JUMP_SLOT_MATCH: + { + // TODO: share code with CHECK_SLOT_MATCH + RegisterA64 temp1 = regs.allocTemp(KindA64::x); + RegisterA64 temp1w = castReg(KindA64::w, temp1); + RegisterA64 temp2 = regs.allocTemp(KindA64::x); + + build.ldr(temp1w, mem(regOp(inst.a), offsetof(LuaNode, key) + kOffsetOfTKeyTag)); + build.and_(temp1w, temp1w, kLuaNodeTagMask); + build.cmp(temp1w, LUA_TSTRING); + build.b(ConditionA64::NotEqual, labelOp(inst.d)); + + AddressA64 addr = tempAddr(inst.b, offsetof(TValue, value)); + build.ldr(temp1, mem(regOp(inst.a), offsetof(LuaNode, key.value))); + build.ldr(temp2, addr); + build.cmp(temp1, temp2); + build.b(ConditionA64::NotEqual, labelOp(inst.d)); + + build.ldr(temp1w, mem(regOp(inst.a), offsetof(LuaNode, val.tt))); + LUAU_ASSERT(LUA_TNIL == 0); + build.cbz(temp1w, labelOp(inst.d)); + jumpOrFallthrough(blockOp(inst.c), next); + break; + } case IrCmd::TABLE_LEN: { regs.assertAllFreeExcept(regOp(inst.a)); @@ -664,6 +825,32 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) } break; } + case IrCmd::TRY_CALL_FASTGETTM: + { + regs.assertAllFreeExcept(regOp(inst.a)); + + RegisterA64 temp1 = regs.allocTemp(KindA64::x); + RegisterA64 temp2 = regs.allocTemp(KindA64::w); + + build.ldr(temp1, mem(regOp(inst.a), offsetof(Table, metatable))); + build.cbz(temp1, labelOp(inst.c)); // no metatable + + build.ldrb(temp2, mem(temp1, offsetof(Table, tmcache))); + build.tst(temp2, 1 << intOp(inst.b)); // can't use tbz/tbnz because their jump offsets are too short + build.b(ConditionA64::NotEqual, labelOp(inst.c)); // Equal = Zero after tst; tmcache caches *absence* of metamethods + + build.mov(x0, temp1); + build.mov(w1, intOp(inst.b)); + build.ldr(x2, mem(rState, offsetof(lua_State, global))); + build.ldr(x2, mem(x2, offsetof(global_State, tmname) + intOp(inst.b) * sizeof(TString*))); + build.ldr(x3, mem(rNativeContext, offsetof(NativeContext, luaT_gettm))); + build.blr(x3); + + // TODO: we could takeReg x0 but it's unclear if we will be able to keep x0 allocatable due to aliasing concerns + inst.regA64 = regs.allocReg(KindA64::x); + build.mov(inst.regA64, x0); + break; + } case IrCmd::INT_TO_NUM: { inst.regA64 = regs.allocReg(KindA64::d); @@ -683,8 +870,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) else if (inst.b.kind == IrOpKind::Inst) { build.add(temp, rBase, uint16_t(vmRegOp(inst.a) * sizeof(TValue))); - // TODO: This is a temporary hack that reads wN register as if it was xN. This should use unsigned extension shift once we support it. - build.add(temp, temp, castReg(KindA64::x, regOp(inst.b)), kTValueSizeLog2); + build.add(temp, temp, zextReg(regOp(inst.b)), kTValueSizeLog2); build.str(temp, mem(rState, offsetof(lua_State, top))); } else @@ -699,6 +885,12 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.str(temp, mem(rState, offsetof(lua_State, top))); break; } + case IrCmd::FASTCALL: + regs.assertAllFree(); + // TODO: emitBuiltin should be exhaustive + if (!emitBuiltin(build, regs, uintOp(inst.a), vmRegOp(inst.b), vmRegOp(inst.c), inst.d, intOp(inst.e), intOp(inst.f))) + error = true; + break; case IrCmd::INVOKE_FASTCALL: { regs.assertAllFree(); @@ -710,18 +902,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) if (inst.d.kind == IrOpKind::VmReg) build.add(x4, rBase, uint16_t(vmRegOp(inst.d) * sizeof(TValue))); else if (inst.d.kind == IrOpKind::VmConst) - { - // TODO: refactor into a common helper - if (vmConstOp(inst.d) * sizeof(TValue) <= AssemblyBuilderA64::kMaxImmediate) - { - build.add(x4, rConstants, uint16_t(vmConstOp(inst.d) * sizeof(TValue))); - } - else - { - build.mov(x4, vmConstOp(inst.d) * sizeof(TValue)); - build.add(x4, rConstants, x4); - } - } + emitAddOffset(build, x4, rConstants, vmConstOp(inst.d) * sizeof(TValue)); else LUAU_ASSERT(boolOp(inst.d) == false); @@ -742,7 +923,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.ldr(x6, mem(rNativeContext, offsetof(NativeContext, luauF_table) + uintOp(inst.a) * sizeof(luau_FastFunction))); build.blr(x6); - // TODO: we could takeReg w0 but it's unclear if we will be able to keep x0 allocatable due to aliasing concerns + // since w0 came from a call, we need to move it so that we don't violate zextReg safety contract inst.regA64 = regs.allocReg(KindA64::w); build.mov(inst.regA64, w0); break; @@ -758,18 +939,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.add(x2, rBase, uint16_t(vmRegOp(inst.b) * sizeof(TValue))); if (inst.c.kind == IrOpKind::VmConst) - { - // TODO: refactor into a common helper - if (vmConstOp(inst.c) * sizeof(TValue) <= AssemblyBuilderA64::kMaxImmediate) - { - build.add(x3, rConstants, uint16_t(vmConstOp(inst.c) * sizeof(TValue))); - } - else - { - build.mov(x3, vmConstOp(inst.c) * sizeof(TValue)); - build.add(x3, rConstants, x3); - } - } + emitAddOffset(build, x3, rConstants, vmConstOp(inst.c) * sizeof(TValue)); else build.add(x3, rBase, uint16_t(vmRegOp(inst.c) * sizeof(TValue))); @@ -835,7 +1005,25 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) break; case IrCmd::GET_IMPORT: regs.assertAllFree(); - emitInstGetImport(build, vmRegOp(inst.a), uintOp(inst.b)); + // luaV_getimport(L, cl->env, k, aux, /* propagatenil= */ false) + build.mov(x0, rState); + build.ldr(x1, mem(rClosure, offsetof(Closure, env))); + build.mov(x2, rConstants); + build.mov(w3, uintOp(inst.b)); + build.mov(w4, 0); + build.ldr(x5, mem(rNativeContext, offsetof(NativeContext, luaV_getimport))); + build.blr(x5); + + emitUpdateBase(build); + + // setobj2s(L, ra, L->top - 1) + build.ldr(x0, mem(rState, offsetof(lua_State, top))); + build.sub(x0, x0, sizeof(TValue)); + build.ldr(q0, x0); + build.str(q0, mem(rBase, vmRegOp(inst.a) * sizeof(TValue))); + + // L->top-- + build.str(x0, mem(rState, offsetof(lua_State, top))); break; case IrCmd::CONCAT: regs.assertAllFree(); @@ -877,7 +1065,6 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) RegisterA64 temp1 = regs.allocTemp(KindA64::x); RegisterA64 temp2 = regs.allocTemp(KindA64::x); RegisterA64 temp3 = regs.allocTemp(KindA64::q); - RegisterA64 temp4 = regs.allocTemp(KindA64::x); // UpVal* build.ldr(temp1, mem(rClosure, offsetof(Closure, l.uprefs) + sizeof(TValue) * vmUpvalueOp(inst.a) + offsetof(TValue, value.gc))); @@ -887,7 +1074,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.str(temp3, temp2); Label skip; - checkObjectBarrierConditions(build, temp1, temp2, temp4, vmRegOp(inst.b), skip); + checkObjectBarrierConditions(build, temp1, temp2, vmRegOp(inst.b), skip); build.mov(x0, rState); build.mov(x1, temp1); // TODO: aliasing hazard @@ -945,8 +1132,17 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.cmp(temp, regOp(inst.b)); else if (inst.b.kind == IrOpKind::Constant) { - LUAU_ASSERT(size_t(intOp(inst.b)) <= AssemblyBuilderA64::kMaxImmediate); // TODO: handle out of range values - build.cmp(temp, uint16_t(intOp(inst.b))); + // TODO: refactor into a common helper? + if (size_t(intOp(inst.b)) <= AssemblyBuilderA64::kMaxImmediate) + { + build.cmp(temp, uint16_t(intOp(inst.b))); + } + else + { + RegisterA64 temp2 = regs.allocTemp(KindA64::w); + build.mov(temp2, intOp(inst.b)); + build.cmp(temp, temp2); + } } else LUAU_ASSERT(!"Unsupported instruction form"); @@ -959,12 +1155,9 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) RegisterA64 temp1 = regs.allocTemp(KindA64::x); RegisterA64 temp1w = castReg(KindA64::w, temp1); RegisterA64 temp2 = regs.allocTemp(KindA64::x); - RegisterA64 temp2w = castReg(KindA64::w, temp2); - build.ldr(temp1w, mem(regOp(inst.a), kOffsetOfLuaNodeTag)); - // TODO: this needs bitfield extraction, or and-immediate - build.mov(temp2w, kLuaNodeTagMask); - build.and_(temp1w, temp1w, temp2w); + build.ldr(temp1w, mem(regOp(inst.a), offsetof(LuaNode, key) + kOffsetOfTKeyTag)); + build.and_(temp1w, temp1w, kLuaNodeTagMask); build.cmp(temp1w, LUA_TSTRING); build.b(ConditionA64::NotEqual, labelOp(inst.c)); @@ -979,6 +1172,15 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.cbz(temp1w, labelOp(inst.c)); break; } + case IrCmd::CHECK_NODE_NO_NEXT: + { + RegisterA64 temp = regs.allocTemp(KindA64::w); + + build.ldr(temp, mem(regOp(inst.a), offsetof(LuaNode, key) + kOffsetOfTKeyNext)); + build.and_(temp, temp, ~((1u << kNextBitOffset) - 1)); // TODO: this would be cleaner with a right shift + build.cbnz(temp, labelOp(inst.b)); + break; + } case IrCmd::INTERRUPT: { unsigned int pcpos = uintOp(inst.a); @@ -1023,11 +1225,10 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) { regs.assertAllFreeExcept(regOp(inst.a)); - Label skip; - RegisterA64 temp1 = regs.allocTemp(KindA64::x); - RegisterA64 temp2 = regs.allocTemp(KindA64::x); + RegisterA64 temp = regs.allocTemp(KindA64::x); - checkObjectBarrierConditions(build, regOp(inst.a), temp1, temp2, vmRegOp(inst.b), skip); + Label skip; + checkObjectBarrierConditions(build, regOp(inst.a), temp, vmRegOp(inst.b), skip); build.mov(x0, rState); build.mov(x1, regOp(inst.a)); // TODO: aliasing hazard @@ -1044,15 +1245,13 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) regs.assertAllFreeExcept(regOp(inst.a)); Label skip; - RegisterA64 temp1 = regs.allocTemp(KindA64::w); - RegisterA64 temp2 = regs.allocTemp(KindA64::w); + RegisterA64 temp = regs.allocTemp(KindA64::w); // isblack(obj2gco(t)) - build.ldrb(temp1, mem(regOp(inst.a), offsetof(GCheader, marked))); + build.ldrb(temp, mem(regOp(inst.a), offsetof(GCheader, marked))); // TODO: conditional bit test with BLACKBIT - build.mov(temp2, bitmask(BLACKBIT)); - build.and_(temp1, temp1, temp2); - build.cbz(temp1, skip); + build.tst(temp, bitmask(BLACKBIT)); + build.b(ConditionA64::Equal, skip); // Equal = Zero after tst build.mov(x0, rState); build.mov(x1, regOp(inst.a)); // TODO: aliasing hazard here and below @@ -1068,11 +1267,10 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) { regs.assertAllFreeExcept(regOp(inst.a)); - Label skip; - RegisterA64 temp1 = regs.allocTemp(KindA64::x); - RegisterA64 temp2 = regs.allocTemp(KindA64::x); + RegisterA64 temp = regs.allocTemp(KindA64::x); - checkObjectBarrierConditions(build, regOp(inst.a), temp1, temp2, vmRegOp(inst.b), skip); + Label skip; + checkObjectBarrierConditions(build, regOp(inst.a), temp, vmRegOp(inst.b), skip); build.mov(x0, rState); build.mov(x1, regOp(inst.a)); // TODO: aliasing hazard @@ -1086,21 +1284,10 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) } case IrCmd::SET_SAVEDPC: { - unsigned int pcpos = uintOp(inst.a); RegisterA64 temp1 = regs.allocTemp(KindA64::x); RegisterA64 temp2 = regs.allocTemp(KindA64::x); - // TODO: refactor into a common helper - if (pcpos * sizeof(Instruction) <= AssemblyBuilderA64::kMaxImmediate) - { - build.add(temp1, rCode, uint16_t(pcpos * sizeof(Instruction))); - } - else - { - build.mov(temp1, pcpos * sizeof(Instruction)); - build.add(temp1, rCode, temp1); - } - + emitAddOffset(build, temp1, rCode, uintOp(inst.a) * sizeof(Instruction)); build.ldr(temp2, mem(rState, offsetof(lua_State, ci))); build.str(temp1, mem(temp2, offsetof(CallInfo, savedpc))); break; @@ -1133,14 +1320,100 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) case IrCmd::CAPTURE: // no-op break; + case IrCmd::SETLIST: + regs.assertAllFree(); + emitFallback(build, LOP_SETLIST, uintOp(inst.a)); + break; case IrCmd::CALL: regs.assertAllFree(); - emitInstCall(build, helpers, vmRegOp(inst.a), intOp(inst.b), intOp(inst.c)); + // argtop = (nparams == LUA_MULTRET) ? L->top : ra + 1 + nparams; + if (intOp(inst.b) == LUA_MULTRET) + build.ldr(x2, mem(rState, offsetof(lua_State, top))); + else + build.add(x2, rBase, uint16_t((vmRegOp(inst.a) + 1 + intOp(inst.b)) * sizeof(TValue))); + + // callFallback(L, ra, argtop, nresults) + build.mov(x0, rState); + build.add(x1, rBase, uint16_t(vmRegOp(inst.a) * sizeof(TValue))); + build.mov(w3, intOp(inst.c)); + build.ldr(x4, mem(rNativeContext, offsetof(NativeContext, callFallback))); + build.blr(x4); + + // reentry with x0=closure (NULL will trigger exit) + build.b(helpers.reentry); break; case IrCmd::RETURN: regs.assertAllFree(); - emitInstReturn(build, helpers, vmRegOp(inst.a), intOp(inst.b)); + // callFallback(L, ra, n) + build.mov(x0, rState); + build.add(x1, rBase, uint16_t(vmRegOp(inst.a) * sizeof(TValue))); + build.mov(w2, intOp(inst.b)); + build.ldr(x3, mem(rNativeContext, offsetof(NativeContext, returnFallback))); + build.blr(x3); + + // reentry with x0=closure (NULL will trigger exit) + build.b(helpers.reentry); + break; + case IrCmd::FORGLOOP: + // register layout: ra + 1 = table, ra + 2 = internal index, ra + 3 .. ra + aux = iteration variables + regs.assertAllFree(); + // clear extra variables since we might have more than two + if (intOp(inst.b) > 2) + { + build.mov(w0, LUA_TNIL); + for (int i = 2; i < intOp(inst.b); ++i) + build.str(w0, mem(rBase, (vmRegOp(inst.a) + 3 + i) * sizeof(TValue) + offsetof(TValue, tt))); + } + // we use full iter fallback for now; in the future it could be worthwhile to accelerate array iteration here + build.mov(x0, rState); + build.ldr(x1, mem(rBase, (vmRegOp(inst.a) + 1) * sizeof(TValue) + offsetof(TValue, value.gc))); + build.ldr(w2, mem(rBase, (vmRegOp(inst.a) + 2) * sizeof(TValue) + offsetof(TValue, value.p))); + build.add(x3, rBase, uint16_t(vmRegOp(inst.a) * sizeof(TValue))); + build.ldr(x4, mem(rNativeContext, offsetof(NativeContext, forgLoopTableIter))); + build.blr(x4); + // note: no emitUpdateBase necessary because forgLoopTableIter does not reallocate stack + build.cbnz(w0, labelOp(inst.c)); + jumpOrFallthrough(blockOp(inst.d), next); + break; + case IrCmd::FORGLOOP_FALLBACK: + regs.assertAllFree(); + build.mov(x0, rState); + build.mov(w1, vmRegOp(inst.a)); + build.mov(w2, intOp(inst.b)); + build.ldr(x3, mem(rNativeContext, offsetof(NativeContext, forgLoopNonTableFallback))); + build.blr(x3); + emitUpdateBase(build); + build.cbnz(w0, labelOp(inst.c)); + jumpOrFallthrough(blockOp(inst.d), next); + break; + case IrCmd::FORGPREP_XNEXT_FALLBACK: + regs.assertAllFree(); + build.mov(x0, rState); + build.add(x1, rBase, uint16_t(vmRegOp(inst.b) * sizeof(TValue))); + build.mov(w2, uintOp(inst.a) + 1); + build.ldr(x3, mem(rNativeContext, offsetof(NativeContext, forgPrepXnextFallback))); + build.blr(x3); + // note: no emitUpdateBase necessary because forgLoopNonTableFallback does not reallocate stack + jumpOrFallthrough(blockOp(inst.c), next); break; + case IrCmd::COVERAGE: + { + RegisterA64 temp1 = regs.allocTemp(KindA64::x); + RegisterA64 temp2 = regs.allocTemp(KindA64::w); + RegisterA64 temp3 = regs.allocTemp(KindA64::w); + + build.mov(temp1, uintOp(inst.a) * sizeof(Instruction)); + build.ldr(temp2, mem(rCode, temp1)); + + // increments E (high 24 bits); if the result overflows a 23-bit counter, high bit becomes 1 + // note: cmp can be eliminated with adds but we aren't concerned with code size for coverage + build.add(temp3, temp2, 256); + build.cmp(temp3, 0); + build.csel(temp2, temp2, temp3, ConditionA64::Less); + + build.str(temp2, mem(rCode, temp1)); + break; + } // Full instruction fallbacks case IrCmd::FALLBACK_GETGLOBAL: @@ -1208,9 +1481,25 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) regs.assertAllFree(); emitFallback(build, LOP_DUPCLOSURE, uintOp(inst.a)); break; + case IrCmd::FALLBACK_FORGPREP: + regs.assertAllFree(); + emitFallback(build, LOP_FORGPREP, uintOp(inst.a)); + jumpOrFallthrough(blockOp(inst.c), next); + break; - default: - LUAU_ASSERT(!"Not supported yet"); + // Pseudo instructions + case IrCmd::NOP: + case IrCmd::SUBSTITUTE: + LUAU_ASSERT(!"Pseudo instructions should not be lowered"); + break; + + // Unsupported instructions + // Note: when adding implementations for these, please move the case: label so that implemented instructions match the order in IrData.h + case IrCmd::STORE_VECTOR: +#ifdef TRACE + gStatsA64.missing++; +#endif + error = true; break; } @@ -1220,7 +1509,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) bool IrLoweringA64::hasError() const { - return false; + return error; } bool IrLoweringA64::isFallthroughBlock(IrBlock target, IrBlock next) @@ -1287,17 +1576,7 @@ AddressA64 IrLoweringA64::tempAddr(IrOp op, int offset) RegisterA64 temp = regs.allocTemp(KindA64::x); - // TODO: refactor into a common helper - if (constantOffset <= AssemblyBuilderA64::kMaxImmediate) - { - build.add(temp, rConstants, uint16_t(constantOffset)); - } - else - { - build.mov(temp, int(constantOffset)); - build.add(temp, rConstants, temp); - } - + emitAddOffset(build, temp, rConstants, constantOffset); return temp; } // If we have a register, we assume it's a pointer to TValue diff --git a/CodeGen/src/IrLoweringA64.h b/CodeGen/src/IrLoweringA64.h index b374a26a0..0c9f87444 100644 --- a/CodeGen/src/IrLoweringA64.h +++ b/CodeGen/src/IrLoweringA64.h @@ -26,8 +26,6 @@ struct IrLoweringA64 { IrLoweringA64(AssemblyBuilderA64& build, ModuleHelpers& helpers, NativeState& data, Proto* proto, IrFunction& function); - static bool canLower(const IrFunction& function); - void lowerInst(IrInst& inst, uint32_t index, IrBlock& next); bool hasError() const; @@ -61,6 +59,8 @@ struct IrLoweringA64 IrFunction& function; IrRegAllocA64 regs; + + bool error = false; }; } // namespace A64 diff --git a/CodeGen/src/IrLoweringX64.cpp b/CodeGen/src/IrLoweringX64.cpp index f2dfdb3b1..51325a37b 100644 --- a/CodeGen/src/IrLoweringX64.cpp +++ b/CodeGen/src/IrLoweringX64.cpp @@ -31,6 +31,8 @@ IrLoweringX64::IrLoweringX64(AssemblyBuilderX64& build, ModuleHelpers& helpers, { // In order to allocate registers during lowering, we need to know where instruction results are last used updateLastUseLocations(function); + + build.align(kFunctionAlignment, X64::AlignmentDataX64::Ud2); } void IrLoweringX64::storeDoubleAsFloat(OperandX64 dst, IrOp src) @@ -59,7 +61,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) switch (inst.cmd) { case IrCmd::LOAD_TAG: - inst.regX64 = regs.allocGprReg(SizeX64::dword, index); + inst.regX64 = regs.allocReg(SizeX64::dword, index); if (inst.a.kind == IrOpKind::VmReg) build.mov(inst.regX64, luauRegTag(vmRegOp(inst.a))); @@ -73,7 +75,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) LUAU_ASSERT(!"Unsupported instruction form"); break; case IrCmd::LOAD_POINTER: - inst.regX64 = regs.allocGprReg(SizeX64::qword, index); + inst.regX64 = regs.allocReg(SizeX64::qword, index); if (inst.a.kind == IrOpKind::VmReg) build.mov(inst.regX64, luauRegValue(vmRegOp(inst.a))); @@ -87,7 +89,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) LUAU_ASSERT(!"Unsupported instruction form"); break; case IrCmd::LOAD_DOUBLE: - inst.regX64 = regs.allocXmmReg(index); + inst.regX64 = regs.allocReg(SizeX64::xmmword, index); if (inst.a.kind == IrOpKind::VmReg) build.vmovsd(inst.regX64, luauRegValue(vmRegOp(inst.a))); @@ -97,12 +99,12 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) LUAU_ASSERT(!"Unsupported instruction form"); break; case IrCmd::LOAD_INT: - inst.regX64 = regs.allocGprReg(SizeX64::dword, index); + inst.regX64 = regs.allocReg(SizeX64::dword, index); build.mov(inst.regX64, luauRegValueInt(vmRegOp(inst.a))); break; case IrCmd::LOAD_TVALUE: - inst.regX64 = regs.allocXmmReg(index); + inst.regX64 = regs.allocReg(SizeX64::xmmword, index); if (inst.a.kind == IrOpKind::VmReg) build.vmovups(inst.regX64, luauReg(vmRegOp(inst.a))); @@ -114,12 +116,12 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) LUAU_ASSERT(!"Unsupported instruction form"); break; case IrCmd::LOAD_NODE_VALUE_TV: - inst.regX64 = regs.allocXmmReg(index); + inst.regX64 = regs.allocReg(SizeX64::xmmword, index); build.vmovups(inst.regX64, luauNodeValue(regOp(inst.a))); break; case IrCmd::LOAD_ENV: - inst.regX64 = regs.allocGprReg(SizeX64::qword, index); + inst.regX64 = regs.allocReg(SizeX64::qword, index); build.mov(inst.regX64, sClosure); build.mov(inst.regX64, qword[inst.regX64 + offsetof(Closure, env)]); @@ -127,7 +129,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) case IrCmd::GET_ARR_ADDR: if (inst.b.kind == IrOpKind::Inst) { - inst.regX64 = regs.allocGprRegOrReuse(SizeX64::qword, index, {inst.b}); + inst.regX64 = regs.allocRegOrReuse(SizeX64::qword, index, {inst.b}); if (dwordReg(inst.regX64) != regOp(inst.b)) build.mov(dwordReg(inst.regX64), regOp(inst.b)); @@ -137,7 +139,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) } else if (inst.b.kind == IrOpKind::Constant) { - inst.regX64 = regs.allocGprRegOrReuse(SizeX64::qword, index, {inst.a}); + inst.regX64 = regs.allocRegOrReuse(SizeX64::qword, index, {inst.a}); build.mov(inst.regX64, qword[regOp(inst.a) + offsetof(Table, array)]); @@ -151,7 +153,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) break; case IrCmd::GET_SLOT_NODE_ADDR: { - inst.regX64 = regs.allocGprReg(SizeX64::qword, index); + inst.regX64 = regs.allocReg(SizeX64::qword, index); ScopedRegX64 tmp{regs, SizeX64::qword}; @@ -160,11 +162,11 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) } case IrCmd::GET_HASH_NODE_ADDR: { - inst.regX64 = regs.allocGprReg(SizeX64::qword, index); - // Custom bit shift value can only be placed in cl ScopedRegX64 shiftTmp{regs, regs.takeReg(rcx, kInvalidInstIdx)}; + inst.regX64 = regs.allocReg(SizeX64::qword, index); + ScopedRegX64 tmp{regs, SizeX64::qword}; build.mov(inst.regX64, qword[regOp(inst.a) + offsetof(Table, node)]); @@ -232,7 +234,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.vmovups(luauNodeValue(regOp(inst.a)), regOp(inst.b)); break; case IrCmd::ADD_INT: - inst.regX64 = regs.allocGprRegOrReuse(SizeX64::dword, index, {inst.a}); + inst.regX64 = regs.allocRegOrReuse(SizeX64::dword, index, {inst.a}); if (inst.regX64 == regOp(inst.a) && intOp(inst.b) == 1) build.inc(inst.regX64); @@ -242,7 +244,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.lea(inst.regX64, addr[regOp(inst.a) + intOp(inst.b)]); break; case IrCmd::SUB_INT: - inst.regX64 = regs.allocGprRegOrReuse(SizeX64::dword, index, {inst.a}); + inst.regX64 = regs.allocRegOrReuse(SizeX64::dword, index, {inst.a}); if (inst.regX64 == regOp(inst.a) && intOp(inst.b) == 1) build.dec(inst.regX64); @@ -252,7 +254,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.lea(inst.regX64, addr[regOp(inst.a) - intOp(inst.b)]); break; case IrCmd::ADD_NUM: - inst.regX64 = regs.allocXmmRegOrReuse(index, {inst.a, inst.b}); + inst.regX64 = regs.allocRegOrReuse(SizeX64::xmmword, index, {inst.a, inst.b}); if (inst.a.kind == IrOpKind::Constant) { @@ -267,7 +269,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) } break; case IrCmd::SUB_NUM: - inst.regX64 = regs.allocXmmRegOrReuse(index, {inst.a, inst.b}); + inst.regX64 = regs.allocRegOrReuse(SizeX64::xmmword, index, {inst.a, inst.b}); if (inst.a.kind == IrOpKind::Constant) { @@ -282,7 +284,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) } break; case IrCmd::MUL_NUM: - inst.regX64 = regs.allocXmmRegOrReuse(index, {inst.a, inst.b}); + inst.regX64 = regs.allocRegOrReuse(SizeX64::xmmword, index, {inst.a, inst.b}); if (inst.a.kind == IrOpKind::Constant) { @@ -297,7 +299,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) } break; case IrCmd::DIV_NUM: - inst.regX64 = regs.allocXmmRegOrReuse(index, {inst.a, inst.b}); + inst.regX64 = regs.allocRegOrReuse(SizeX64::xmmword, index, {inst.a, inst.b}); if (inst.a.kind == IrOpKind::Constant) { @@ -313,7 +315,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) break; case IrCmd::MOD_NUM: { - inst.regX64 = regs.allocXmmRegOrReuse(index, {inst.a, inst.b}); + inst.regX64 = regs.allocRegOrReuse(SizeX64::xmmword, index, {inst.a, inst.b}); ScopedRegX64 optLhsTmp{regs}; RegisterX64 lhs; @@ -362,7 +364,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) break; } case IrCmd::MIN_NUM: - inst.regX64 = regs.allocXmmRegOrReuse(index, {inst.a, inst.b}); + inst.regX64 = regs.allocRegOrReuse(SizeX64::xmmword, index, {inst.a, inst.b}); if (inst.a.kind == IrOpKind::Constant) { @@ -377,7 +379,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) } break; case IrCmd::MAX_NUM: - inst.regX64 = regs.allocXmmRegOrReuse(index, {inst.a, inst.b}); + inst.regX64 = regs.allocRegOrReuse(SizeX64::xmmword, index, {inst.a, inst.b}); if (inst.a.kind == IrOpKind::Constant) { @@ -393,7 +395,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) break; case IrCmd::UNM_NUM: { - inst.regX64 = regs.allocXmmRegOrReuse(index, {inst.a}); + inst.regX64 = regs.allocRegOrReuse(SizeX64::xmmword, index, {inst.a}); RegisterX64 src = regOp(inst.a); @@ -410,18 +412,18 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) break; } case IrCmd::FLOOR_NUM: - inst.regX64 = regs.allocXmmRegOrReuse(index, {inst.a}); + inst.regX64 = regs.allocRegOrReuse(SizeX64::xmmword, index, {inst.a}); build.vroundsd(inst.regX64, inst.regX64, memRegDoubleOp(inst.a), RoundingModeX64::RoundToNegativeInfinity); break; case IrCmd::CEIL_NUM: - inst.regX64 = regs.allocXmmRegOrReuse(index, {inst.a}); + inst.regX64 = regs.allocRegOrReuse(SizeX64::xmmword, index, {inst.a}); build.vroundsd(inst.regX64, inst.regX64, memRegDoubleOp(inst.a), RoundingModeX64::RoundToPositiveInfinity); break; case IrCmd::ROUND_NUM: { - inst.regX64 = regs.allocXmmRegOrReuse(index, {inst.a}); + inst.regX64 = regs.allocRegOrReuse(SizeX64::xmmword, index, {inst.a}); ScopedRegX64 tmp1{regs, SizeX64::xmmword}; ScopedRegX64 tmp2{regs, SizeX64::xmmword}; @@ -439,12 +441,12 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) break; } case IrCmd::SQRT_NUM: - inst.regX64 = regs.allocXmmRegOrReuse(index, {inst.a}); + inst.regX64 = regs.allocRegOrReuse(SizeX64::xmmword, index, {inst.a}); build.vsqrtsd(inst.regX64, inst.regX64, memRegDoubleOp(inst.a)); break; case IrCmd::ABS_NUM: - inst.regX64 = regs.allocXmmRegOrReuse(index, {inst.a}); + inst.regX64 = regs.allocRegOrReuse(SizeX64::xmmword, index, {inst.a}); if (inst.a.kind != IrOpKind::Inst) build.vmovsd(inst.regX64, memRegDoubleOp(inst.a)); @@ -456,7 +458,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) case IrCmd::NOT_ANY: { // TODO: if we have a single user which is a STORE_INT, we are missing the opportunity to write directly to target - inst.regX64 = regs.allocGprRegOrReuse(SizeX64::dword, index, {inst.a, inst.b}); + inst.regX64 = regs.allocRegOrReuse(SizeX64::dword, index, {inst.a, inst.b}); Label saveone, savezero, exit; @@ -558,7 +560,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) callWrap.addArgument(SizeX64::qword, regOp(inst.a), inst.a); callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaH_getn)]); - inst.regX64 = regs.allocXmmReg(index); + inst.regX64 = regs.allocReg(SizeX64::xmmword, index); build.vcvtsi2sd(inst.regX64, inst.regX64, eax); break; } @@ -566,8 +568,8 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) { IrCallWrapperX64 callWrap(regs, build, index); callWrap.addArgument(SizeX64::qword, rState); - callWrap.addArgument(SizeX64::dword, int32_t(uintOp(inst.a)), inst.a); - callWrap.addArgument(SizeX64::dword, int32_t(uintOp(inst.b)), inst.b); + callWrap.addArgument(SizeX64::dword, int32_t(uintOp(inst.a))); + callWrap.addArgument(SizeX64::dword, int32_t(uintOp(inst.b))); callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaH_new)]); inst.regX64 = regs.takeReg(rax, index); break; @@ -583,7 +585,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) } case IrCmd::TRY_NUM_TO_INDEX: { - inst.regX64 = regs.allocGprReg(SizeX64::dword, index); + inst.regX64 = regs.allocReg(SizeX64::dword, index); ScopedRegX64 tmp{regs, SizeX64::xmmword}; @@ -620,7 +622,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) break; } case IrCmd::INT_TO_NUM: - inst.regX64 = regs.allocXmmReg(index); + inst.regX64 = regs.allocReg(SizeX64::xmmword, index); build.vcvtsi2sd(inst.regX64, inst.regX64, regOp(inst.a)); break; @@ -688,11 +690,10 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) if (nparams == LUA_MULTRET) { - // Compute 'L->top - (ra + 1)', on SystemV, take r9 register to compute directly into the argument - // TODO: IrCallWrapperX64 should provide a way to 'guess' target argument register correctly - RegisterX64 reg = build.abi == ABIX64::Windows ? regs.allocGprReg(SizeX64::qword, kInvalidInstIdx) : regs.takeReg(rArg6, kInvalidInstIdx); + RegisterX64 reg = callWrap.suggestNextArgumentRegister(SizeX64::qword); ScopedRegX64 tmp{regs, SizeX64::qword}; + // L->top - (ra + 1) build.mov(reg, qword[rState + offsetof(lua_State, top)]); build.lea(tmp.reg, addr[rBase + (ra + 1) * sizeof(TValue)]); build.sub(reg, tmp.reg); @@ -759,9 +760,35 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) } break; case IrCmd::GET_IMPORT: - regs.assertAllFree(); - emitInstGetImportFallback(build, vmRegOp(inst.a), uintOp(inst.b)); + { + ScopedRegX64 tmp1{regs, SizeX64::qword}; + + build.mov(tmp1.reg, sClosure); + + IrCallWrapperX64 callWrap(regs, build, index); + callWrap.addArgument(SizeX64::qword, rState); + callWrap.addArgument(SizeX64::qword, qword[tmp1.release() + offsetof(Closure, env)]); + callWrap.addArgument(SizeX64::qword, rConstants); + callWrap.addArgument(SizeX64::dword, uintOp(inst.b)); + callWrap.addArgument(SizeX64::dword, 0); + callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaV_getimport)]); + + emitUpdateBase(build); + + ScopedRegX64 tmp2{regs, SizeX64::qword}; + + // setobj2s(L, ra, L->top - 1) + build.mov(tmp2.reg, qword[rState + offsetof(lua_State, top)]); + build.sub(tmp2.reg, sizeof(TValue)); + + ScopedRegX64 tmp3{regs, SizeX64::xmmword}; + build.vmovups(tmp3.reg, xmmword[tmp2.reg]); + build.vmovups(luauReg(vmRegOp(inst.a)), tmp3.reg); + + // L->top-- + build.mov(qword[rState + offsetof(lua_State, top)], tmp2.reg); break; + } case IrCmd::CONCAT: { IrCallWrapperX64 callWrap(regs, build, index); @@ -783,7 +810,6 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) // uprefs[] is either an actual value, or it points to UpVal object which has a pointer to value Label skip; - // TODO: jumpIfTagIsNot can be generalized to take OperandX64 and then we can use it here; let's wait until we see this more though build.cmp(dword[tmp1.reg + offsetof(TValue, tt)], LUA_TUPVAL); build.jcc(ConditionX64::NotEqual, skip); @@ -822,36 +848,25 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) callPrepareForN(regs, build, vmRegOp(inst.a), vmRegOp(inst.b), vmRegOp(inst.c)); break; case IrCmd::CHECK_TAG: - if (inst.a.kind == IrOpKind::Inst) - { - build.cmp(regOp(inst.a), tagOp(inst.b)); - build.jcc(ConditionX64::NotEqual, labelOp(inst.c)); - } - else if (inst.a.kind == IrOpKind::VmReg) - { - jumpIfTagIsNot(build, vmRegOp(inst.a), lua_Type(tagOp(inst.b)), labelOp(inst.c)); - } - else if (inst.a.kind == IrOpKind::VmConst) - { - build.cmp(luauConstantTag(vmConstOp(inst.a)), tagOp(inst.b)); - build.jcc(ConditionX64::NotEqual, labelOp(inst.c)); - } - else - { - LUAU_ASSERT(!"Unsupported instruction form"); - } + build.cmp(memRegTagOp(inst.a), tagOp(inst.b)); + build.jcc(ConditionX64::NotEqual, labelOp(inst.c)); break; case IrCmd::CHECK_READONLY: - jumpIfTableIsReadOnly(build, regOp(inst.a), labelOp(inst.b)); + build.cmp(byte[regOp(inst.a) + offsetof(Table, readonly)], 0); + build.jcc(ConditionX64::NotEqual, labelOp(inst.b)); break; case IrCmd::CHECK_NO_METATABLE: - jumpIfMetatablePresent(build, regOp(inst.a), labelOp(inst.b)); + build.cmp(qword[regOp(inst.a) + offsetof(Table, metatable)], 0); + build.jcc(ConditionX64::NotEqual, labelOp(inst.b)); break; case IrCmd::CHECK_SAFE_ENV: { ScopedRegX64 tmp{regs, SizeX64::qword}; - jumpIfUnsafeEnv(build, tmp.reg, labelOp(inst.a)); + build.mov(tmp.reg, sClosure); + build.mov(tmp.reg, qword[tmp.reg + offsetof(Closure, env)]); + build.cmp(byte[tmp.reg + offsetof(Table, safeenv)], 0); + build.jcc(ConditionX64::Equal, labelOp(inst.a)); break; } case IrCmd::CHECK_ARRAY_SIZE: @@ -872,11 +887,16 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) break; } case IrCmd::CHECK_NODE_NO_NEXT: - jumpIfNodeHasNext(build, regOp(inst.a), labelOp(inst.b)); + { + ScopedRegX64 tmp{regs, SizeX64::dword}; + + build.mov(tmp.reg, dword[regOp(inst.a) + offsetof(LuaNode, key) + kOffsetOfTKeyNext]); + build.shr(tmp.reg, kNextBitOffset); + build.jcc(ConditionX64::NotZero, labelOp(inst.b)); break; + } case IrCmd::INTERRUPT: - regs.assertAllFree(); - emitInterrupt(build, uintOp(inst.a)); + emitInterrupt(regs, build, uintOp(inst.a)); break; case IrCmd::CHECK_GC: callStepGc(regs, build); @@ -970,94 +990,127 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) break; case IrCmd::FORGLOOP: regs.assertAllFree(); - emitinstForGLoop(build, vmRegOp(inst.a), intOp(inst.b), labelOp(inst.c), labelOp(inst.d)); + emitInstForGLoop(build, vmRegOp(inst.a), intOp(inst.b), labelOp(inst.c)); + jumpOrFallthrough(blockOp(inst.d), next); break; case IrCmd::FORGLOOP_FALLBACK: - regs.assertAllFree(); - emitinstForGLoopFallback(build, vmRegOp(inst.a), intOp(inst.b), labelOp(inst.c)); - build.jmp(labelOp(inst.d)); + { + IrCallWrapperX64 callWrap(regs, build, index); + callWrap.addArgument(SizeX64::qword, rState); + callWrap.addArgument(SizeX64::dword, vmRegOp(inst.a)); + callWrap.addArgument(SizeX64::dword, intOp(inst.b)); + callWrap.call(qword[rNativeContext + offsetof(NativeContext, forgLoopNonTableFallback)]); + + emitUpdateBase(build); + + build.test(al, al); + build.jcc(ConditionX64::NotZero, labelOp(inst.c)); + jumpOrFallthrough(blockOp(inst.d), next); break; + } case IrCmd::FORGPREP_XNEXT_FALLBACK: - regs.assertAllFree(); - emitInstForGPrepXnextFallback(build, uintOp(inst.a), vmRegOp(inst.b), labelOp(inst.c)); + { + IrCallWrapperX64 callWrap(regs, build, index); + callWrap.addArgument(SizeX64::qword, rState); + callWrap.addArgument(SizeX64::qword, luauRegAddress(vmRegOp(inst.b))); + callWrap.addArgument(SizeX64::dword, uintOp(inst.a) + 1); + callWrap.call(qword[rNativeContext + offsetof(NativeContext, forgPrepXnextFallback)]); + jumpOrFallthrough(blockOp(inst.c), next); break; + } case IrCmd::COVERAGE: - regs.assertAllFree(); - emitInstCoverage(build, uintOp(inst.a)); + { + ScopedRegX64 tmp1{regs, SizeX64::qword}; + ScopedRegX64 tmp2{regs, SizeX64::dword}; + ScopedRegX64 tmp3{regs, SizeX64::dword}; + + build.mov(tmp1.reg, sCode); + build.add(tmp1.reg, uintOp(inst.a) * sizeof(Instruction)); + + // hits = LUAU_INSN_E(*pc) + build.mov(tmp2.reg, dword[tmp1.reg]); + build.sar(tmp2.reg, 8); + + // hits = (hits < (1 << 23) - 1) ? hits + 1 : hits; + build.xor_(tmp3.reg, tmp3.reg); + build.cmp(tmp2.reg, (1 << 23) - 1); + build.setcc(ConditionX64::NotEqual, byteReg(tmp3.reg)); + build.add(tmp2.reg, tmp3.reg); + + // VM_PATCH_E(pc, hits); + build.sal(tmp2.reg, 8); + build.movzx(tmp3.reg, byte[tmp1.reg]); + build.or_(tmp3.reg, tmp2.reg); + build.mov(dword[tmp1.reg], tmp3.reg); break; + } // Full instruction fallbacks case IrCmd::FALLBACK_GETGLOBAL: LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); LUAU_ASSERT(inst.c.kind == IrOpKind::VmConst); - regs.assertAllFree(); - emitFallback(build, data, LOP_GETGLOBAL, uintOp(inst.a)); + emitFallback(regs, build, data, LOP_GETGLOBAL, uintOp(inst.a)); break; case IrCmd::FALLBACK_SETGLOBAL: LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); LUAU_ASSERT(inst.c.kind == IrOpKind::VmConst); - regs.assertAllFree(); - emitFallback(build, data, LOP_SETGLOBAL, uintOp(inst.a)); + emitFallback(regs, build, data, LOP_SETGLOBAL, uintOp(inst.a)); break; case IrCmd::FALLBACK_GETTABLEKS: LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); LUAU_ASSERT(inst.c.kind == IrOpKind::VmReg); LUAU_ASSERT(inst.d.kind == IrOpKind::VmConst); - regs.assertAllFree(); - emitFallback(build, data, LOP_GETTABLEKS, uintOp(inst.a)); + emitFallback(regs, build, data, LOP_GETTABLEKS, uintOp(inst.a)); break; case IrCmd::FALLBACK_SETTABLEKS: LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); LUAU_ASSERT(inst.c.kind == IrOpKind::VmReg); LUAU_ASSERT(inst.d.kind == IrOpKind::VmConst); - regs.assertAllFree(); - emitFallback(build, data, LOP_SETTABLEKS, uintOp(inst.a)); + emitFallback(regs, build, data, LOP_SETTABLEKS, uintOp(inst.a)); break; case IrCmd::FALLBACK_NAMECALL: LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); LUAU_ASSERT(inst.c.kind == IrOpKind::VmReg); LUAU_ASSERT(inst.d.kind == IrOpKind::VmConst); - regs.assertAllFree(); - emitFallback(build, data, LOP_NAMECALL, uintOp(inst.a)); + emitFallback(regs, build, data, LOP_NAMECALL, uintOp(inst.a)); break; case IrCmd::FALLBACK_PREPVARARGS: LUAU_ASSERT(inst.b.kind == IrOpKind::Constant); - regs.assertAllFree(); - emitFallback(build, data, LOP_PREPVARARGS, uintOp(inst.a)); + emitFallback(regs, build, data, LOP_PREPVARARGS, uintOp(inst.a)); break; case IrCmd::FALLBACK_GETVARARGS: LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); LUAU_ASSERT(inst.c.kind == IrOpKind::Constant); - regs.assertAllFree(); - emitFallback(build, data, LOP_GETVARARGS, uintOp(inst.a)); + emitFallback(regs, build, data, LOP_GETVARARGS, uintOp(inst.a)); break; case IrCmd::FALLBACK_NEWCLOSURE: LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); LUAU_ASSERT(inst.c.kind == IrOpKind::Constant); - regs.assertAllFree(); - emitFallback(build, data, LOP_NEWCLOSURE, uintOp(inst.a)); + emitFallback(regs, build, data, LOP_NEWCLOSURE, uintOp(inst.a)); break; case IrCmd::FALLBACK_DUPCLOSURE: LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); LUAU_ASSERT(inst.c.kind == IrOpKind::VmConst); - regs.assertAllFree(); - emitFallback(build, data, LOP_DUPCLOSURE, uintOp(inst.a)); + emitFallback(regs, build, data, LOP_DUPCLOSURE, uintOp(inst.a)); break; case IrCmd::FALLBACK_FORGPREP: - regs.assertAllFree(); - emitFallback(build, data, LOP_FORGPREP, uintOp(inst.a)); + emitFallback(regs, build, data, LOP_FORGPREP, uintOp(inst.a)); + jumpOrFallthrough(blockOp(inst.c), next); break; - default: - LUAU_ASSERT(!"Not supported yet"); + + // Pseudo instructions + case IrCmd::NOP: + case IrCmd::SUBSTITUTE: + LUAU_ASSERT(!"Pseudo instructions should not be lowered"); break; } diff --git a/CodeGen/src/IrRegAllocA64.cpp b/CodeGen/src/IrRegAllocA64.cpp index c6db9e9e0..9a06cf69e 100644 --- a/CodeGen/src/IrRegAllocA64.cpp +++ b/CodeGen/src/IrRegAllocA64.cpp @@ -1,9 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "IrRegAllocA64.h" -#ifdef _MSC_VER -#include -#endif +#include "BitUtils.h" namespace Luau { @@ -12,19 +10,6 @@ namespace CodeGen namespace A64 { -inline int setBit(uint32_t n) -{ - LUAU_ASSERT(n); - -#ifdef _MSC_VER - unsigned long rl; - _BitScanReverse(&rl, n); - return int(rl); -#else - return 31 - __builtin_clz(n); -#endif -} - IrRegAllocA64::IrRegAllocA64(IrFunction& function, std::initializer_list> regs) : function(function) { @@ -52,7 +37,7 @@ RegisterA64 IrRegAllocA64::allocReg(KindA64 kind) return noreg; } - int index = setBit(set.free); + int index = 31 - countlz(set.free); set.free &= ~(1u << index); return RegisterA64{kind, uint8_t(index)}; @@ -68,7 +53,7 @@ RegisterA64 IrRegAllocA64::allocTemp(KindA64 kind) return noreg; } - int index = setBit(set.free); + int index = 31 - countlz(set.free); set.free &= ~(1u << index); set.temp |= 1u << index; diff --git a/CodeGen/src/IrRegAllocX64.cpp b/CodeGen/src/IrRegAllocX64.cpp index dc9e7f908..24d8f51a3 100644 --- a/CodeGen/src/IrRegAllocX64.cpp +++ b/CodeGen/src/IrRegAllocX64.cpp @@ -1,6 +1,8 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/IrRegAllocX64.h" +#include "Luau/IrUtils.h" + #include "EmitCommonX64.h" namespace Luau @@ -12,11 +14,6 @@ namespace X64 static const RegisterX64 kGprAllocOrder[] = {rax, rdx, rcx, rbx, rsi, rdi, r8, r9, r10, r11}; -static bool isFullTvalueOperand(IrCmd cmd) -{ - return cmd == IrCmd::LOAD_TVALUE || cmd == IrCmd::LOAD_NODE_VALUE_TV; -} - IrRegAllocX64::IrRegAllocX64(AssemblyBuilderX64& build, IrFunction& function) : build(build) , function(function) @@ -27,50 +24,43 @@ IrRegAllocX64::IrRegAllocX64(AssemblyBuilderX64& build, IrFunction& function) xmmInstUsers.fill(kInvalidInstIdx); } -RegisterX64 IrRegAllocX64::allocGprReg(SizeX64 preferredSize, uint32_t instIdx) +RegisterX64 IrRegAllocX64::allocReg(SizeX64 size, uint32_t instIdx) { - LUAU_ASSERT( - preferredSize == SizeX64::byte || preferredSize == SizeX64::word || preferredSize == SizeX64::dword || preferredSize == SizeX64::qword); - - for (RegisterX64 reg : kGprAllocOrder) + if (size == SizeX64::xmmword) { - if (freeGprMap[reg.index]) + for (size_t i = 0; i < freeXmmMap.size(); ++i) { - freeGprMap[reg.index] = false; - gprInstUsers[reg.index] = instIdx; - return RegisterX64{preferredSize, reg.index}; + if (freeXmmMap[i]) + { + freeXmmMap[i] = false; + xmmInstUsers[i] = instIdx; + return RegisterX64{size, uint8_t(i)}; + } } } - - // If possible, spill the value with the furthest next use - if (uint32_t furthestUseTarget = findInstructionWithFurthestNextUse(gprInstUsers); furthestUseTarget != kInvalidInstIdx) - return takeReg(function.instructions[furthestUseTarget].regX64, instIdx); - - LUAU_ASSERT(!"Out of GPR registers to allocate"); - return noreg; -} - -RegisterX64 IrRegAllocX64::allocXmmReg(uint32_t instIdx) -{ - for (size_t i = 0; i < freeXmmMap.size(); ++i) + else { - if (freeXmmMap[i]) + for (RegisterX64 reg : kGprAllocOrder) { - freeXmmMap[i] = false; - xmmInstUsers[i] = instIdx; - return RegisterX64{SizeX64::xmmword, uint8_t(i)}; + if (freeGprMap[reg.index]) + { + freeGprMap[reg.index] = false; + gprInstUsers[reg.index] = instIdx; + return RegisterX64{size, reg.index}; + } } } // Out of registers, spill the value with the furthest next use - if (uint32_t furthestUseTarget = findInstructionWithFurthestNextUse(xmmInstUsers); furthestUseTarget != kInvalidInstIdx) + const std::array& regInstUsers = size == SizeX64::xmmword ? xmmInstUsers : gprInstUsers; + if (uint32_t furthestUseTarget = findInstructionWithFurthestNextUse(regInstUsers); furthestUseTarget != kInvalidInstIdx) return takeReg(function.instructions[furthestUseTarget].regX64, instIdx); - LUAU_ASSERT(!"Out of XMM registers to allocate"); + LUAU_ASSERT(!"Out of registers to allocate"); return noreg; } -RegisterX64 IrRegAllocX64::allocGprRegOrReuse(SizeX64 preferredSize, uint32_t instIdx, std::initializer_list oprefs) +RegisterX64 IrRegAllocX64::allocRegOrReuse(SizeX64 size, uint32_t instIdx, std::initializer_list oprefs) { for (IrOp op : oprefs) { @@ -81,39 +71,24 @@ RegisterX64 IrRegAllocX64::allocGprRegOrReuse(SizeX64 preferredSize, uint32_t in if (source.lastUse == instIdx && !source.reusedReg && !source.spilled) { - LUAU_ASSERT(source.regX64.size != SizeX64::xmmword); + // Not comparing size directly because we only need matching register set + if ((size == SizeX64::xmmword) != (source.regX64.size == SizeX64::xmmword)) + continue; + LUAU_ASSERT(source.regX64 != noreg); source.reusedReg = true; - gprInstUsers[source.regX64.index] = instIdx; - return RegisterX64{preferredSize, source.regX64.index}; - } - } - - return allocGprReg(preferredSize, instIdx); -} - -RegisterX64 IrRegAllocX64::allocXmmRegOrReuse(uint32_t instIdx, std::initializer_list oprefs) -{ - for (IrOp op : oprefs) - { - if (op.kind != IrOpKind::Inst) - continue; - IrInst& source = function.instructions[op.index]; - - if (source.lastUse == instIdx && !source.reusedReg && !source.spilled) - { - LUAU_ASSERT(source.regX64.size == SizeX64::xmmword); - LUAU_ASSERT(source.regX64 != noreg); + if (size == SizeX64::xmmword) + xmmInstUsers[source.regX64.index] = instIdx; + else + gprInstUsers[source.regX64.index] = instIdx; - source.reusedReg = true; - xmmInstUsers[source.regX64.index] = instIdx; - return source.regX64; + return RegisterX64{size, source.regX64.index}; } } - return allocXmmReg(instIdx); + return allocReg(size, instIdx); } RegisterX64 IrRegAllocX64::takeReg(RegisterX64 reg, uint32_t instIdx) @@ -197,41 +172,34 @@ bool IrRegAllocX64::isLastUseReg(const IrInst& target, uint32_t instIdx) const void IrRegAllocX64::preserve(IrInst& inst) { - bool doubleSlot = isFullTvalueOperand(inst.cmd); + IrSpillX64 spill; + spill.instIdx = function.getInstIndex(inst); + spill.valueKind = getCmdValueKind(inst.cmd); + spill.spillId = nextSpillId++; + spill.originalLoc = inst.regX64; - // Find a free stack slot. Two consecutive slots might be required for 16 byte TValues, so '- 1' is used - for (unsigned i = 0; i < unsigned(usedSpillSlots.size() - 1); ++i) + // Loads from VmReg/VmConst don't have to be spilled, they can be restored from a register later + if (!hasRestoreOp(inst)) { - if (usedSpillSlots.test(i)) - continue; + unsigned i = findSpillStackSlot(spill.valueKind); - if (doubleSlot && usedSpillSlots.test(i + 1)) - { - ++i; // No need to retest this double position - continue; - } - - if (inst.regX64.size == SizeX64::xmmword && doubleSlot) - { + if (spill.valueKind == IrValueKind::Tvalue) build.vmovups(xmmword[sSpillArea + i * 8], inst.regX64); - } - else if (inst.regX64.size == SizeX64::xmmword) - { + else if (spill.valueKind == IrValueKind::Double) build.vmovsd(qword[sSpillArea + i * 8], inst.regX64); - } + else if (spill.valueKind == IrValueKind::Pointer) + build.mov(qword[sSpillArea + i * 8], inst.regX64); + else if (spill.valueKind == IrValueKind::Tag || spill.valueKind == IrValueKind::Int) + build.mov(dword[sSpillArea + i * 8], inst.regX64); else - { - OperandX64 location = addr[sSpillArea + i * 8]; - location.memSize = inst.regX64.size; // Override memory access size - build.mov(location, inst.regX64); - } + LUAU_ASSERT(!"unsupported value kind"); usedSpillSlots.set(i); if (i + 1 > maxUsedSlot) maxUsedSlot = i + 1; - if (doubleSlot) + if (spill.valueKind == IrValueKind::Tvalue) { usedSpillSlots.set(i + 1); @@ -239,22 +207,15 @@ void IrRegAllocX64::preserve(IrInst& inst) maxUsedSlot = i + 2; } - IrSpillX64 spill; - spill.instIdx = function.getInstIndex(inst); - spill.useDoubleSlot = doubleSlot; spill.stackSlot = uint8_t(i); - spill.originalLoc = inst.regX64; - - spills.push_back(spill); + } - freeReg(inst.regX64); + spills.push_back(spill); - inst.regX64 = noreg; - inst.spilled = true; - return; - } + freeReg(inst.regX64); - LUAU_ASSERT(!"nowhere to spill"); + inst.regX64 = noreg; + inst.spilled = true; } void IrRegAllocX64::restore(IrInst& inst, bool intoOriginalLocation) @@ -267,35 +228,34 @@ void IrRegAllocX64::restore(IrInst& inst, bool intoOriginalLocation) if (spill.instIdx == instIdx) { - LUAU_ASSERT(spill.stackSlot != kNoStackSlot); - RegisterX64 reg; + RegisterX64 reg = intoOriginalLocation ? takeReg(spill.originalLoc, instIdx) : allocReg(spill.originalLoc.size, instIdx); + OperandX64 restoreLocation = noreg; - if (spill.originalLoc.size == SizeX64::xmmword) + if (spill.stackSlot != kNoStackSlot) { - reg = intoOriginalLocation ? takeReg(spill.originalLoc, instIdx) : allocXmmReg(instIdx); + restoreLocation = addr[sSpillArea + spill.stackSlot * 8]; + restoreLocation.memSize = reg.size; - if (spill.useDoubleSlot) - build.vmovups(reg, xmmword[sSpillArea + spill.stackSlot * 8]); - else - build.vmovsd(reg, qword[sSpillArea + spill.stackSlot * 8]); + usedSpillSlots.set(spill.stackSlot, false); + + if (spill.valueKind == IrValueKind::Tvalue) + usedSpillSlots.set(spill.stackSlot + 1, false); } else { - reg = intoOriginalLocation ? takeReg(spill.originalLoc, instIdx) : allocGprReg(spill.originalLoc.size, instIdx); - - OperandX64 location = addr[sSpillArea + spill.stackSlot * 8]; - location.memSize = reg.size; // Override memory access size - build.mov(reg, location); + restoreLocation = getRestoreAddress(inst, getRestoreOp(inst)); } + if (spill.valueKind == IrValueKind::Tvalue) + build.vmovups(reg, restoreLocation); + else if (spill.valueKind == IrValueKind::Double) + build.vmovsd(reg, restoreLocation); + else + build.mov(reg, restoreLocation); + inst.regX64 = reg; inst.spilled = false; - usedSpillSlots.set(spill.stackSlot, false); - - if (spill.useDoubleSlot) - usedSpillSlots.set(spill.stackSlot + 1, false); - spills[i] = spills.back(); spills.pop_back(); return; @@ -334,6 +294,81 @@ bool IrRegAllocX64::shouldFreeGpr(RegisterX64 reg) const return false; } +unsigned IrRegAllocX64::findSpillStackSlot(IrValueKind valueKind) +{ + // Find a free stack slot. Two consecutive slots might be required for 16 byte TValues, so '- 1' is used + for (unsigned i = 0; i < unsigned(usedSpillSlots.size() - 1); ++i) + { + if (usedSpillSlots.test(i)) + continue; + + if (valueKind == IrValueKind::Tvalue && usedSpillSlots.test(i + 1)) + { + ++i; // No need to retest this double position + continue; + } + + return i; + } + + LUAU_ASSERT(!"nowhere to spill"); + return ~0u; +} + +IrOp IrRegAllocX64::getRestoreOp(const IrInst& inst) const +{ + switch (inst.cmd) + { + case IrCmd::LOAD_TAG: + case IrCmd::LOAD_POINTER: + case IrCmd::LOAD_DOUBLE: + case IrCmd::LOAD_INT: + case IrCmd::LOAD_TVALUE: + { + IrOp location = inst.a; + + // Might have an alternative location + if (IrOp alternative = function.findRestoreOp(inst); alternative.kind != IrOpKind::None) + location = alternative; + + if (location.kind == IrOpKind::VmReg || location.kind == IrOpKind::VmConst) + return location; + + break; + } + default: + break; + } + + return IrOp(); +} + +bool IrRegAllocX64::hasRestoreOp(const IrInst& inst) const +{ + return getRestoreOp(inst).kind != IrOpKind::None; +} + +OperandX64 IrRegAllocX64::getRestoreAddress(const IrInst& inst, IrOp restoreOp) +{ + switch (inst.cmd) + { + case IrCmd::LOAD_TAG: + return restoreOp.kind == IrOpKind::VmReg ? luauRegTag(vmRegOp(restoreOp)) : luauConstantTag(vmConstOp(restoreOp)); + case IrCmd::LOAD_POINTER: + case IrCmd::LOAD_DOUBLE: + return restoreOp.kind == IrOpKind::VmReg ? luauRegValue(vmRegOp(restoreOp)) : luauConstantValue(vmConstOp(restoreOp)); + case IrCmd::LOAD_INT: + LUAU_ASSERT(restoreOp.kind == IrOpKind::VmReg); + return luauRegValueInt(vmRegOp(restoreOp)); + case IrCmd::LOAD_TVALUE: + return restoreOp.kind == IrOpKind::VmReg ? luauReg(vmRegOp(restoreOp)) : luauConstant(vmConstOp(restoreOp)); + default: + break; + } + + return noreg; +} + uint32_t IrRegAllocX64::findInstructionWithFurthestNextUse(const std::array& regInstUsers) const { uint32_t furthestUseTarget = kInvalidInstIdx; @@ -411,11 +446,7 @@ ScopedRegX64::~ScopedRegX64() void ScopedRegX64::alloc(SizeX64 size) { LUAU_ASSERT(reg == noreg); - - if (size == SizeX64::xmmword) - reg = owner.allocXmmReg(kInvalidInstIdx); - else - reg = owner.allocGprReg(size, kInvalidInstIdx); + reg = owner.allocReg(size, kInvalidInstIdx); } void ScopedRegX64::free() @@ -435,36 +466,34 @@ RegisterX64 ScopedRegX64::release() ScopedSpills::ScopedSpills(IrRegAllocX64& owner) : owner(owner) { - snapshot = owner.spills; + startSpillId = owner.nextSpillId; } ScopedSpills::~ScopedSpills() { - // Taking a copy of current spills because we are going to potentially restore them - std::vector current = owner.spills; + unsigned endSpillId = owner.nextSpillId; - // Restore registers that were spilled inside scope protected by this object - for (IrSpillX64& curr : current) + for (size_t i = 0; i < owner.spills.size();) { - // If spill existed before current scope, it can be restored outside of it - if (!wasSpilledBefore(curr)) + IrSpillX64& spill = owner.spills[i]; + + // Restoring spills inside this scope cannot create new spills + LUAU_ASSERT(spill.spillId < endSpillId); + + // If spill was created inside current scope, it has to be restored + if (spill.spillId >= startSpillId) { - IrInst& inst = owner.function.instructions[curr.instIdx]; + IrInst& inst = owner.function.instructions[spill.instIdx]; owner.restore(inst, /*intoOriginalLocation*/ true); - } - } -} -bool ScopedSpills::wasSpilledBefore(const IrSpillX64& spill) const -{ - for (const IrSpillX64& preexisting : snapshot) - { - if (spill.instIdx == preexisting.instIdx) - return true; + // Spill restore removes the spill entry, so loop is repeated at the same 'i' + } + else + { + i++; + } } - - return false; } } // namespace X64 diff --git a/CodeGen/src/IrTranslateBuiltins.cpp b/CodeGen/src/IrTranslateBuiltins.cpp index ba4915645..539fcf770 100644 --- a/CodeGen/src/IrTranslateBuiltins.cpp +++ b/CodeGen/src/IrTranslateBuiltins.cpp @@ -8,6 +8,8 @@ // TODO: when nresults is less than our actual result count, we can skip computing/writing unused results +static const int kMinMaxUnrolledParams = 5; + namespace Luau { namespace CodeGen @@ -23,7 +25,7 @@ BuiltinImplResult translateBuiltinNumberToNumber( return {BuiltinImplType::None, -1}; build.loadAndCheckTag(build.vmReg(arg), LUA_TNUMBER, fallback); - build.inst(IrCmd::FASTCALL, build.constUint(bfid), build.vmReg(ra), build.vmReg(arg), args, build.constInt(nparams), build.constInt(nresults)); + build.inst(IrCmd::FASTCALL, build.constUint(bfid), build.vmReg(ra), build.vmReg(arg), args, build.constInt(1), build.constInt(1)); if (ra != arg) build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); @@ -40,7 +42,7 @@ BuiltinImplResult translateBuiltin2NumberToNumber( build.loadAndCheckTag(build.vmReg(arg), LUA_TNUMBER, fallback); build.loadAndCheckTag(args, LUA_TNUMBER, fallback); - build.inst(IrCmd::FASTCALL, build.constUint(bfid), build.vmReg(ra), build.vmReg(arg), args, build.constInt(nparams), build.constInt(nresults)); + build.inst(IrCmd::FASTCALL, build.constUint(bfid), build.vmReg(ra), build.vmReg(arg), args, build.constInt(2), build.constInt(1)); if (ra != arg) build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); @@ -56,12 +58,13 @@ BuiltinImplResult translateBuiltinNumberTo2Number( return {BuiltinImplType::None, -1}; build.loadAndCheckTag(build.vmReg(arg), LUA_TNUMBER, fallback); - build.inst(IrCmd::FASTCALL, build.constUint(bfid), build.vmReg(ra), build.vmReg(arg), args, build.constInt(nparams), build.constInt(nresults)); + build.inst( + IrCmd::FASTCALL, build.constUint(bfid), build.vmReg(ra), build.vmReg(arg), args, build.constInt(1), build.constInt(nresults == 1 ? 1 : 2)); if (ra != arg) build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); - if (nresults > 1) + if (nresults != 1) build.inst(IrCmd::STORE_TAG, build.vmReg(ra + 1), build.constTag(LUA_TNUMBER)); return {BuiltinImplType::UsesFallback, 2}; @@ -125,12 +128,33 @@ BuiltinImplResult translateBuiltinMathLog( if (nparams < 1 || nresults > 1) return {BuiltinImplType::None, -1}; - build.loadAndCheckTag(build.vmReg(arg), LUA_TNUMBER, fallback); + LuauBuiltinFunction fcId = bfid; + int fcParams = 1; if (nparams != 1) - build.loadAndCheckTag(args, LUA_TNUMBER, fallback); + { + if (args.kind != IrOpKind::VmConst) + return {BuiltinImplType::None, -1}; + + LUAU_ASSERT(build.function.proto); + TValue protok = build.function.proto->k[vmConstOp(args)]; + + if (protok.tt != LUA_TNUMBER) + return {BuiltinImplType::None, -1}; + + // TODO: IR builtin lowering assumes that the only valid 2-argument call is log2; ideally, we use a less hacky way to indicate that + if (protok.value.n == 2.0) + fcParams = 2; + else if (protok.value.n == 10.0) + fcId = LBF_MATH_LOG10; + else + // TODO: We can precompute log(args) and divide by it, but that requires extra LOAD/STORE so for now just fall back as this is rare + return {BuiltinImplType::None, -1}; + } - build.inst(IrCmd::FASTCALL, build.constUint(bfid), build.vmReg(ra), build.vmReg(arg), args, build.constInt(nparams), build.constInt(nresults)); + build.loadAndCheckTag(build.vmReg(arg), LUA_TNUMBER, fallback); + + build.inst(IrCmd::FASTCALL, build.constUint(fcId), build.vmReg(ra), build.vmReg(arg), args, build.constInt(fcParams), build.constInt(1)); if (ra != arg) build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); @@ -140,17 +164,26 @@ BuiltinImplResult translateBuiltinMathLog( BuiltinImplResult translateBuiltinMathMin(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) { - // TODO: this can be extended for other number of arguments - if (nparams != 2 || nresults > 1) + if (nparams < 2 || nparams > kMinMaxUnrolledParams || nresults > 1) return {BuiltinImplType::None, -1}; build.loadAndCheckTag(build.vmReg(arg), LUA_TNUMBER, fallback); build.loadAndCheckTag(args, LUA_TNUMBER, fallback); + for (int i = 3; i <= nparams; ++i) + build.loadAndCheckTag(build.vmReg(vmRegOp(args) + (i - 2)), LUA_TNUMBER, fallback); + IrOp varg1 = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(arg)); IrOp varg2 = build.inst(IrCmd::LOAD_DOUBLE, args); IrOp res = build.inst(IrCmd::MIN_NUM, varg2, varg1); // Swapped arguments are required for consistency with VM builtins + + for (int i = 3; i <= nparams; ++i) + { + IrOp arg = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(vmRegOp(args) + (i - 2))); + res = build.inst(IrCmd::MIN_NUM, arg, res); + } + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), res); if (ra != arg) @@ -161,17 +194,26 @@ BuiltinImplResult translateBuiltinMathMin(IrBuilder& build, int nparams, int ra, BuiltinImplResult translateBuiltinMathMax(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) { - // TODO: this can be extended for other number of arguments - if (nparams != 2 || nresults > 1) + if (nparams < 2 || nparams > kMinMaxUnrolledParams || nresults > 1) return {BuiltinImplType::None, -1}; build.loadAndCheckTag(build.vmReg(arg), LUA_TNUMBER, fallback); build.loadAndCheckTag(args, LUA_TNUMBER, fallback); + for (int i = 3; i <= nparams; ++i) + build.loadAndCheckTag(build.vmReg(vmRegOp(args) + (i - 2)), LUA_TNUMBER, fallback); + IrOp varg1 = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(arg)); IrOp varg2 = build.inst(IrCmd::LOAD_DOUBLE, args); IrOp res = build.inst(IrCmd::MAX_NUM, varg2, varg1); // Swapped arguments are required for consistency with VM builtins + + for (int i = 3; i <= nparams; ++i) + { + IrOp arg = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(vmRegOp(args) + (i - 2))); + res = build.inst(IrCmd::MAX_NUM, arg, res); + } + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), res); if (ra != arg) @@ -254,8 +296,7 @@ BuiltinImplResult translateBuiltinType(IrBuilder& build, int nparams, int ra, in if (nparams < 1 || nresults > 1) return {BuiltinImplType::None, -1}; - build.inst( - IrCmd::FASTCALL, build.constUint(LBF_TYPE), build.vmReg(ra), build.vmReg(arg), args, build.constInt(nparams), build.constInt(nresults)); + build.inst(IrCmd::FASTCALL, build.constUint(LBF_TYPE), build.vmReg(ra), build.vmReg(arg), args, build.constInt(1), build.constInt(1)); build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TSTRING)); @@ -267,8 +308,7 @@ BuiltinImplResult translateBuiltinTypeof(IrBuilder& build, int nparams, int ra, if (nparams < 1 || nresults > 1) return {BuiltinImplType::None, -1}; - build.inst( - IrCmd::FASTCALL, build.constUint(LBF_TYPEOF), build.vmReg(ra), build.vmReg(arg), args, build.constInt(nparams), build.constInt(nresults)); + build.inst(IrCmd::FASTCALL, build.constUint(LBF_TYPEOF), build.vmReg(ra), build.vmReg(arg), args, build.constInt(1), build.constInt(1)); build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TSTRING)); diff --git a/CodeGen/src/IrUtils.cpp b/CodeGen/src/IrUtils.cpp index c5e7c887a..3811ca27b 100644 --- a/CodeGen/src/IrUtils.cpp +++ b/CodeGen/src/IrUtils.cpp @@ -284,7 +284,7 @@ void replace(IrFunction& function, IrBlock& block, uint32_t instIdx, IrInst repl block.useCount--; } -void substitute(IrFunction& function, IrInst& inst, IrOp replacement) +void substitute(IrFunction& function, IrInst& inst, IrOp replacement, IrOp location) { LUAU_ASSERT(!isBlockTerminator(inst.cmd)); @@ -298,7 +298,7 @@ void substitute(IrFunction& function, IrInst& inst, IrOp replacement) removeUse(function, inst.f); inst.a = replacement; - inst.b = {}; + inst.b = location; inst.c = {}; inst.d = {}; inst.e = {}; diff --git a/CodeGen/src/NativeState.cpp b/CodeGen/src/NativeState.cpp index 524796929..cb128de98 100644 --- a/CodeGen/src/NativeState.cpp +++ b/CodeGen/src/NativeState.cpp @@ -16,7 +16,7 @@ #include #include -#define CODEGEN_SET_FALLBACK(op, flags) data.context.fallback[op] = {execute_##op, flags} +#define CODEGEN_SET_FALLBACK(op) data.context.fallback[op] = {execute_##op} namespace Luau { @@ -36,20 +36,21 @@ NativeState::~NativeState() = default; void initFallbackTable(NativeState& data) { // When fallback is completely removed, remove it from includeInsts list in lvmexecute_split.py - CODEGEN_SET_FALLBACK(LOP_NEWCLOSURE, 0); - CODEGEN_SET_FALLBACK(LOP_NAMECALL, 0); - CODEGEN_SET_FALLBACK(LOP_FORGPREP, kFallbackUpdatePc); - CODEGEN_SET_FALLBACK(LOP_GETVARARGS, 0); - CODEGEN_SET_FALLBACK(LOP_DUPCLOSURE, 0); - CODEGEN_SET_FALLBACK(LOP_PREPVARARGS, 0); - CODEGEN_SET_FALLBACK(LOP_BREAK, 0); + CODEGEN_SET_FALLBACK(LOP_NEWCLOSURE); + CODEGEN_SET_FALLBACK(LOP_NAMECALL); + CODEGEN_SET_FALLBACK(LOP_FORGPREP); + CODEGEN_SET_FALLBACK(LOP_GETVARARGS); + CODEGEN_SET_FALLBACK(LOP_DUPCLOSURE); + CODEGEN_SET_FALLBACK(LOP_PREPVARARGS); + CODEGEN_SET_FALLBACK(LOP_BREAK); + CODEGEN_SET_FALLBACK(LOP_SETLIST); // Fallbacks that are called from partial implementation of an instruction // TODO: these fallbacks should be replaced with special functions that exclude the (redundantly executed) fast path from the fallback - CODEGEN_SET_FALLBACK(LOP_GETGLOBAL, 0); - CODEGEN_SET_FALLBACK(LOP_SETGLOBAL, 0); - CODEGEN_SET_FALLBACK(LOP_GETTABLEKS, 0); - CODEGEN_SET_FALLBACK(LOP_SETTABLEKS, 0); + CODEGEN_SET_FALLBACK(LOP_GETGLOBAL); + CODEGEN_SET_FALLBACK(LOP_SETGLOBAL); + CODEGEN_SET_FALLBACK(LOP_GETTABLEKS); + CODEGEN_SET_FALLBACK(LOP_SETTABLEKS); } void initHelperFunctions(NativeState& data) @@ -105,6 +106,7 @@ void initHelperFunctions(NativeState& data) data.context.libm_tan = tan; data.context.libm_tanh = tanh; + data.context.forgLoopTableIter = forgLoopTableIter; data.context.forgLoopNodeIter = forgLoopNodeIter; data.context.forgLoopNonTableFallback = forgLoopNonTableFallback; data.context.forgPrepXnextFallback = forgPrepXnextFallback; diff --git a/CodeGen/src/NativeState.h b/CodeGen/src/NativeState.h index 2d97e63ca..99d408907 100644 --- a/CodeGen/src/NativeState.h +++ b/CodeGen/src/NativeState.h @@ -23,15 +23,7 @@ namespace CodeGen class UnwindBuilder; -using FallbackFn = const Instruction*(lua_State* L, const Instruction* pc, StkId base, TValue* k); - -constexpr uint8_t kFallbackUpdatePc = 1 << 0; - -struct NativeFallback -{ - FallbackFn* fallback; - uint8_t flags; -}; +using FallbackFn = const Instruction* (*)(lua_State* L, const Instruction* pc, StkId base, TValue* k); struct NativeProto { @@ -96,6 +88,7 @@ struct NativeContext double (*libm_modf)(double, double*) = nullptr; // Helper functions + bool (*forgLoopTableIter)(lua_State* L, Table* h, int index, TValue* ra) = nullptr; bool (*forgLoopNodeIter)(lua_State* L, Table* h, int index, TValue* ra) = nullptr; bool (*forgLoopNonTableFallback)(lua_State* L, int insnA, int aux) = nullptr; void (*forgPrepXnextFallback)(lua_State* L, TValue* ra, int pc) = nullptr; @@ -106,7 +99,7 @@ struct NativeContext Closure* (*returnFallback)(lua_State* L, StkId ra, int n) = nullptr; // Opcode fallbacks, implemented in C - NativeFallback fallback[LOP__COUNT] = {}; + FallbackFn fallback[LOP__COUNT] = {}; // Fast call methods, implemented in C luau_FastFunction luauF_table[256] = {}; diff --git a/CodeGen/src/OptimizeConstProp.cpp b/CodeGen/src/OptimizeConstProp.cpp index 7157a18c4..c7d3d8e9a 100644 --- a/CodeGen/src/OptimizeConstProp.cpp +++ b/CodeGen/src/OptimizeConstProp.cpp @@ -502,6 +502,8 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& } } break; + + // TODO: FASTCALL is more restrictive than INVOKE_FASTCALL; we should either determine the exact semantics, or rework it case IrCmd::FASTCALL: case IrCmd::INVOKE_FASTCALL: handleBuiltinEffects(state, LuauBuiltinFunction(function.uintOp(inst.a)), vmRegOp(inst.b), function.intOp(inst.f)); diff --git a/CodeGen/src/UnwindBuilderDwarf2.cpp b/CodeGen/src/UnwindBuilderDwarf2.cpp index 0b3134ba3..b20a6b25a 100644 --- a/CodeGen/src/UnwindBuilderDwarf2.cpp +++ b/CodeGen/src/UnwindBuilderDwarf2.cpp @@ -132,7 +132,7 @@ size_t UnwindBuilderDwarf2::getBeginOffset() const return beginOffset; } -void UnwindBuilderDwarf2::start() +void UnwindBuilderDwarf2::startInfo() { uint8_t* cieLength = pos; pos = writeu32(pos, 0); // Length (to be filled later) @@ -149,13 +149,23 @@ void UnwindBuilderDwarf2::start() // Optional CIE augmentation section (not present) // Call frame instructions (common for all FDEs, of which we have 1) - stackOffset = 8; // Return address was pushed by calling the function - - pos = defineCfaExpression(pos, DW_REG_RSP, stackOffset); // Define CFA to be the rsp + 8 + pos = defineCfaExpression(pos, DW_REG_RSP, 8); // Define CFA to be the rsp + 8 pos = defineSavedRegisterLocation(pos, DW_REG_RA, 8); // Define return address register (RA) to be located at CFA - 8 pos = alignPosition(cieLength, pos); writeu32(cieLength, unsigned(pos - cieLength - 4)); // Length field itself is excluded from length +} + +void UnwindBuilderDwarf2::startFunction() +{ + // End offset is filled in later and everything gets adjusted at the end + UnwindFunctionDwarf2 func; + func.beginOffset = 0; + func.endOffset = 0; + func.fdeEntryStartPos = uint32_t(pos - rawData); + unwindFunctions.push_back(func); + + stackOffset = 8; // Return address was pushed by calling the function fdeEntryStart = pos; // Will be written at the end pos = writeu32(pos, 0); // Length (to be filled later) @@ -198,14 +208,20 @@ void UnwindBuilderDwarf2::setupFrameReg(X64::RegisterX64 reg, int espOffset) // Cfa is based on rsp, so no additonal commands are required } -void UnwindBuilderDwarf2::finish() +void UnwindBuilderDwarf2::finishFunction(uint32_t beginOffset, uint32_t endOffset) { + unwindFunctions.back().beginOffset = beginOffset; + unwindFunctions.back().endOffset = endOffset; + LUAU_ASSERT(stackOffset % 16 == 0 && "stack has to be aligned to 16 bytes after prologue"); LUAU_ASSERT(fdeEntryStart != nullptr); pos = alignPosition(fdeEntryStart, pos); writeu32(fdeEntryStart, unsigned(pos - fdeEntryStart - 4)); // Length field itself is excluded from length +} +void UnwindBuilderDwarf2::finishInfo() +{ // Terminate section pos = writeu32(pos, 0); @@ -217,15 +233,26 @@ size_t UnwindBuilderDwarf2::getSize() const return size_t(pos - rawData); } -void UnwindBuilderDwarf2::finalize(char* target, void* funcAddress, size_t funcSize) const +size_t UnwindBuilderDwarf2::getFunctionCount() const +{ + return unwindFunctions.size(); +} + +void UnwindBuilderDwarf2::finalize(char* target, size_t offset, void* funcAddress, size_t funcSize) const { memcpy(target, rawData, getSize()); - LUAU_ASSERT(fdeEntryStart != nullptr); - unsigned fdeEntryStartPos = unsigned(fdeEntryStart - rawData); + for (const UnwindFunctionDwarf2& func : unwindFunctions) + { + uint8_t* fdeEntryStart = (uint8_t*)target + func.fdeEntryStartPos; - writeu64((uint8_t*)target + fdeEntryStartPos + kFdeInitialLocationOffset, uintptr_t(funcAddress)); - writeu64((uint8_t*)target + fdeEntryStartPos + kFdeAddressRangeOffset, funcSize); + writeu64(fdeEntryStart + kFdeInitialLocationOffset, uintptr_t(funcAddress) + offset + func.beginOffset); + + if (func.endOffset == kFullBlockFuncton) + writeu64(fdeEntryStart + kFdeAddressRangeOffset, funcSize - offset); + else + writeu64(fdeEntryStart + kFdeAddressRangeOffset, func.endOffset - func.beginOffset); + } } } // namespace CodeGen diff --git a/CodeGen/src/UnwindBuilderWin.cpp b/CodeGen/src/UnwindBuilderWin.cpp index 217330013..5f4f16a9a 100644 --- a/CodeGen/src/UnwindBuilderWin.cpp +++ b/CodeGen/src/UnwindBuilderWin.cpp @@ -21,17 +21,6 @@ namespace Luau namespace CodeGen { -// This struct matches the layout of UNWIND_INFO from ehdata.h -struct UnwindInfoWin -{ - uint8_t version : 3; - uint8_t flags : 5; - uint8_t prologsize; - uint8_t unwindcodecount; - uint8_t framereg : 4; - uint8_t frameregoff : 4; -}; - void UnwindBuilderWin::setBeginOffset(size_t beginOffset) { this->beginOffset = beginOffset; @@ -42,11 +31,28 @@ size_t UnwindBuilderWin::getBeginOffset() const return beginOffset; } -void UnwindBuilderWin::start() -{ - stackOffset = 8; // Return address was pushed by calling the function +void UnwindBuilderWin::startInfo() {} +void UnwindBuilderWin::startFunction() +{ + // End offset is filled in later and everything gets adjusted at the end + UnwindFunctionWin func; + func.beginOffset = 0; + func.endOffset = 0; + func.unwindInfoOffset = uint32_t(rawDataPos - rawData); + unwindFunctions.push_back(func); + + unwindCodes.clear(); unwindCodes.reserve(16); + + prologSize = 0; + + // rax has register index 0, which in Windows unwind info means that frame register is not used + frameReg = X64::rax; + frameRegOffset = 0; + + // Return address was pushed by calling the function + stackOffset = 8; } void UnwindBuilderWin::spill(int espOffset, X64::RegisterX64 reg) @@ -85,49 +91,89 @@ void UnwindBuilderWin::setupFrameReg(X64::RegisterX64 reg, int espOffset) unwindCodes.push_back({prologSize, UWOP_SET_FPREG, frameRegOffset}); } -void UnwindBuilderWin::finish() +void UnwindBuilderWin::finishFunction(uint32_t beginOffset, uint32_t endOffset) { + unwindFunctions.back().beginOffset = beginOffset; + unwindFunctions.back().endOffset = endOffset; + // Windows unwind code count is stored in uint8_t, so we can't have more LUAU_ASSERT(unwindCodes.size() < 256); LUAU_ASSERT(stackOffset % 16 == 0 && "stack has to be aligned to 16 bytes after prologue"); - size_t codeArraySize = unwindCodes.size(); - codeArraySize = (codeArraySize + 1) & ~1; // Size has to be even, but unwind code count doesn't have to - - infoSize = sizeof(UnwindInfoWin) + sizeof(UnwindCodeWin) * codeArraySize; -} - -size_t UnwindBuilderWin::getSize() const -{ - return infoSize; -} - -void UnwindBuilderWin::finalize(char* target, void* funcAddress, size_t funcSize) const -{ UnwindInfoWin info; info.version = 1; info.flags = 0; // No EH info.prologsize = prologSize; info.unwindcodecount = uint8_t(unwindCodes.size()); + + LUAU_ASSERT(frameReg.index < 16); info.framereg = frameReg.index; + + LUAU_ASSERT(frameRegOffset < 16); info.frameregoff = frameRegOffset; - memcpy(target, &info, sizeof(info)); - target += sizeof(UnwindInfoWin); + LUAU_ASSERT(rawDataPos + sizeof(info) <= rawData + kRawDataLimit); + memcpy(rawDataPos, &info, sizeof(info)); + rawDataPos += sizeof(info); if (!unwindCodes.empty()) { // Copy unwind codes in reverse order // Some unwind codes take up two array slots, but we don't use those atm - char* pos = target + sizeof(UnwindCodeWin) * (unwindCodes.size() - 1); + uint8_t* unwindCodePos = rawDataPos + sizeof(UnwindCodeWin) * (unwindCodes.size() - 1); + LUAU_ASSERT(unwindCodePos <= rawData + kRawDataLimit); for (size_t i = 0; i < unwindCodes.size(); i++) { - memcpy(pos, &unwindCodes[i], sizeof(UnwindCodeWin)); - pos -= sizeof(UnwindCodeWin); + memcpy(unwindCodePos, &unwindCodes[i], sizeof(UnwindCodeWin)); + unwindCodePos -= sizeof(UnwindCodeWin); } } + + rawDataPos += sizeof(UnwindCodeWin) * unwindCodes.size(); + + // Size has to be even, but unwind code count doesn't have to + if (unwindCodes.size() % 2 != 0) + rawDataPos += sizeof(UnwindCodeWin); + + LUAU_ASSERT(rawDataPos <= rawData + kRawDataLimit); +} + +void UnwindBuilderWin::finishInfo() {} + +size_t UnwindBuilderWin::getSize() const +{ + return sizeof(UnwindFunctionWin) * unwindFunctions.size() + size_t(rawDataPos - rawData); +} + +size_t UnwindBuilderWin::getFunctionCount() const +{ + return unwindFunctions.size(); +} + +void UnwindBuilderWin::finalize(char* target, size_t offset, void* funcAddress, size_t funcSize) const +{ + // Copy adjusted function information + for (UnwindFunctionWin func : unwindFunctions) + { + // Code will start after the unwind info + func.beginOffset += uint32_t(offset); + + // Whole block is a part of a 'single function' + if (func.endOffset == kFullBlockFuncton) + func.endOffset = uint32_t(funcSize); + else + func.endOffset += uint32_t(offset); + + // Unwind data is placed right after the RUNTIME_FUNCTION data + func.unwindInfoOffset += uint32_t(sizeof(UnwindFunctionWin) * unwindFunctions.size()); + memcpy(target, &func, sizeof(func)); + target += sizeof(func); + } + + // Copy unwind codes + memcpy(target, rawData, size_t(rawDataPos - rawData)); } } // namespace CodeGen diff --git a/Sources.cmake b/Sources.cmake index 3508ec39e..9f54b91e3 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -89,9 +89,7 @@ target_sources(Luau.CodeGen PRIVATE CodeGen/src/CodeGenA64.cpp CodeGen/src/CodeGenX64.cpp CodeGen/src/EmitBuiltinsX64.cpp - CodeGen/src/EmitCommonA64.cpp CodeGen/src/EmitCommonX64.cpp - CodeGen/src/EmitInstructionA64.cpp CodeGen/src/EmitInstructionX64.cpp CodeGen/src/Fallbacks.cpp CodeGen/src/IrAnalysis.cpp @@ -111,6 +109,7 @@ target_sources(Luau.CodeGen PRIVATE CodeGen/src/UnwindBuilderDwarf2.cpp CodeGen/src/UnwindBuilderWin.cpp + CodeGen/src/BitUtils.h CodeGen/src/ByteUtils.h CodeGen/src/CustomExecUtils.h CodeGen/src/CodeGenUtils.h @@ -120,7 +119,6 @@ target_sources(Luau.CodeGen PRIVATE CodeGen/src/EmitCommon.h CodeGen/src/EmitCommonA64.h CodeGen/src/EmitCommonX64.h - CodeGen/src/EmitInstructionA64.h CodeGen/src/EmitInstructionX64.h CodeGen/src/Fallbacks.h CodeGen/src/FallbacksProlog.h diff --git a/VM/src/lapi.cpp b/VM/src/lapi.cpp index 1528aa39e..25add42a0 100644 --- a/VM/src/lapi.cpp +++ b/VM/src/lapi.cpp @@ -538,6 +538,8 @@ const void* lua_topointer(lua_State* L, int idx) StkId o = index2addr(L, idx); switch (ttype(o)) { + case LUA_TSTRING: + return tsvalue(o); case LUA_TTABLE: return hvalue(o); case LUA_TFUNCTION: diff --git a/VM/src/ltable.cpp b/VM/src/ltable.cpp index 5eceea746..c963ac8d0 100644 --- a/VM/src/ltable.cpp +++ b/VM/src/ltable.cpp @@ -33,8 +33,6 @@ #include -LUAU_FASTFLAGVARIABLE(LuauArrBoundResizeFix, false) - // max size of both array and hash part is 2^MAXBITS #define MAXBITS 26 #define MAXSIZE (1 << MAXBITS) @@ -466,30 +464,22 @@ static void rehash(lua_State* L, Table* t, const TValue* ek) int na = computesizes(nums, &nasize); int nh = totaluse - na; - if (FFlag::LuauArrBoundResizeFix) - { - // enforce the boundary invariant; for performance, only do hash lookups if we must - int nadjusted = adjustasize(t, nasize, ek); + // enforce the boundary invariant; for performance, only do hash lookups if we must + int nadjusted = adjustasize(t, nasize, ek); - // count how many extra elements belong to array part instead of hash part - int aextra = nadjusted - nasize; + // count how many extra elements belong to array part instead of hash part + int aextra = nadjusted - nasize; - if (aextra != 0) - { - // we no longer need to store those extra array elements in hash part - nh -= aextra; + if (aextra != 0) + { + // we no longer need to store those extra array elements in hash part + nh -= aextra; - // because hash nodes are twice as large as array nodes, the memory we saved for hash parts can be used by array part - // this follows the general sparse array part optimization where array is allocated when 50% occupation is reached - nasize = nadjusted + aextra; + // because hash nodes are twice as large as array nodes, the memory we saved for hash parts can be used by array part + // this follows the general sparse array part optimization where array is allocated when 50% occupation is reached + nasize = nadjusted + aextra; - // since the size was changed, it's again important to enforce the boundary invariant at the new size - nasize = adjustasize(t, nasize, ek); - } - } - else - { - // enforce the boundary invariant; for performance, only do hash lookups if we must + // since the size was changed, it's again important to enforce the boundary invariant at the new size nasize = adjustasize(t, nasize, ek); } diff --git a/fuzz/linter.cpp b/fuzz/linter.cpp index 854c63277..8efd42469 100644 --- a/fuzz/linter.cpp +++ b/fuzz/linter.cpp @@ -21,7 +21,7 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t* Data, size_t Size) static Luau::NullFileResolver fileResolver; static Luau::NullConfigResolver configResolver; static Luau::Frontend frontend{&fileResolver, &configResolver}; - static int once = (Luau::registerBuiltinGlobals(frontend), 1); + static int once = (Luau::registerBuiltinGlobals(frontend, frontend.globals, false), 1); (void)once; static int once2 = (Luau::freeze(frontend.globals.globalTypes), 1); (void)once2; diff --git a/fuzz/proto.cpp b/fuzz/proto.cpp index ffeb49195..9366da5e2 100644 --- a/fuzz/proto.cpp +++ b/fuzz/proto.cpp @@ -97,12 +97,12 @@ lua_State* createGlobalState() return L; } -int registerTypes(Luau::TypeChecker& typeChecker, Luau::GlobalTypes& globals) +int registerTypes(Luau::Frontend& frontend, Luau::GlobalTypes& globals, bool forAutocomplete) { using namespace Luau; using std::nullopt; - Luau::registerBuiltinGlobals(typeChecker, globals); + Luau::registerBuiltinGlobals(frontend, globals, forAutocomplete); TypeArena& arena = globals.globalTypes; BuiltinTypes& builtinTypes = *globals.builtinTypes; @@ -147,10 +147,10 @@ int registerTypes(Luau::TypeChecker& typeChecker, Luau::GlobalTypes& globals) static void setupFrontend(Luau::Frontend& frontend) { - registerTypes(frontend.typeChecker, frontend.globals); + registerTypes(frontend, frontend.globals, false); Luau::freeze(frontend.globals.globalTypes); - registerTypes(frontend.typeCheckerForAutocomplete, frontend.globalsForAutocomplete); + registerTypes(frontend, frontend.globalsForAutocomplete, true); Luau::freeze(frontend.globalsForAutocomplete.globalTypes); frontend.iceHandler.onInternalError = [](const char* error) { diff --git a/fuzz/typeck.cpp b/fuzz/typeck.cpp index 4f8f88575..87a882717 100644 --- a/fuzz/typeck.cpp +++ b/fuzz/typeck.cpp @@ -26,7 +26,7 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t* Data, size_t Size) static Luau::NullFileResolver fileResolver; static Luau::NullConfigResolver configResolver; static Luau::Frontend frontend{&fileResolver, &configResolver}; - static int once = (Luau::registerBuiltinGlobals(frontend), 1); + static int once = (Luau::registerBuiltinGlobals(frontend, frontend.globals, false), 1); (void)once; static int once2 = (Luau::freeze(frontend.globals.globalTypes), 1); (void)once2; diff --git a/tests/AssemblyBuilderA64.test.cpp b/tests/AssemblyBuilderA64.test.cpp index 1690c748c..a0df0f9ba 100644 --- a/tests/AssemblyBuilderA64.test.cpp +++ b/tests/AssemblyBuilderA64.test.cpp @@ -86,6 +86,7 @@ TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "Binary") SINGLE_COMPARE(add(x0, x1, x2, 7), 0x8B021C20); SINGLE_COMPARE(sub(x0, x1, x2), 0xCB020020); SINGLE_COMPARE(and_(x0, x1, x2), 0x8A020020); + SINGLE_COMPARE(bic(x0, x1, x2), 0x8A220020); SINGLE_COMPARE(orr(x0, x1, x2), 0xAA020020); SINGLE_COMPARE(eor(x0, x1, x2), 0xCA020020); SINGLE_COMPARE(lsl(x0, x1, x2), 0x9AC22020); @@ -94,6 +95,7 @@ TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "Binary") SINGLE_COMPARE(asr(x0, x1, x2), 0x9AC22820); SINGLE_COMPARE(ror(x0, x1, x2), 0x9AC22C20); SINGLE_COMPARE(cmp(x0, x1), 0xEB01001F); + SINGLE_COMPARE(tst(x0, x1), 0xEA01001F); // reg, imm SINGLE_COMPARE(add(x3, x7, 78), 0x910138E3); @@ -102,6 +104,24 @@ TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "Binary") SINGLE_COMPARE(cmp(w0, 42), 0x7100A81F); } +TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "BinaryImm") +{ + // instructions + SINGLE_COMPARE(and_(w1, w2, 1), 0x12000041); + SINGLE_COMPARE(orr(w1, w2, 1), 0x32000041); + SINGLE_COMPARE(eor(w1, w2, 1), 0x52000041); + SINGLE_COMPARE(tst(w1, 1), 0x7200003f); + + // various mask forms + SINGLE_COMPARE(and_(w0, w0, 1), 0x12000000); + SINGLE_COMPARE(and_(w0, w0, 3), 0x12000400); + SINGLE_COMPARE(and_(w0, w0, 7), 0x12000800); + SINGLE_COMPARE(and_(w0, w0, 2147483647), 0x12007800); + SINGLE_COMPARE(and_(w0, w0, 6), 0x121F0400); + SINGLE_COMPARE(and_(w0, w0, 12), 0x121E0400); + SINGLE_COMPARE(and_(w0, w0, 2147483648), 0x12010000); +} + TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "Loads") { // address forms @@ -359,11 +379,13 @@ TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "AddressOffsetSize") SINGLE_COMPARE(str(q0, mem(x1, 16)), 0x3D800420); } -TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "ConditionalSelect") +TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "Conditionals") { SINGLE_COMPARE(csel(x0, x1, x2, ConditionA64::Equal), 0x9A820020); SINGLE_COMPARE(csel(w0, w1, w2, ConditionA64::Equal), 0x1A820020); SINGLE_COMPARE(fcsel(d0, d1, d2, ConditionA64::Equal), 0x1E620C20); + + SINGLE_COMPARE(cset(x1, ConditionA64::Less), 0x9A9FA7E1); } TEST_CASE("LogTest") @@ -394,6 +416,7 @@ TEST_CASE("LogTest") build.ldr(q1, x2); build.csel(x0, x1, x2, ConditionA64::Equal); + build.cset(x0, ConditionA64::Equal); build.fcmp(d0, d1); build.fcmpz(d0); @@ -423,6 +446,7 @@ TEST_CASE("LogTest") fabs d1,d2 ldr q1,[x2] csel x0,x1,x2,eq + cset x0,eq fcmp d0,d1 fcmp d0,#0 .L1: diff --git a/tests/AssemblyBuilderX64.test.cpp b/tests/AssemblyBuilderX64.test.cpp index 054eca7bf..bafb68bc3 100644 --- a/tests/AssemblyBuilderX64.test.cpp +++ b/tests/AssemblyBuilderX64.test.cpp @@ -67,6 +67,9 @@ TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "BaseBinaryInstructionForms") SINGLE_COMPARE(add(rax, 0x7f), 0x48, 0x83, 0xc0, 0x7f); SINGLE_COMPARE(add(rax, 0x80), 0x48, 0x81, 0xc0, 0x80, 0x00, 0x00, 0x00); SINGLE_COMPARE(add(r10, 0x7fffffff), 0x49, 0x81, 0xc2, 0xff, 0xff, 0xff, 0x7f); + SINGLE_COMPARE(add(al, 3), 0x80, 0xc0, 0x03); + SINGLE_COMPARE(add(sil, 3), 0x48, 0x80, 0xc6, 0x03); + SINGLE_COMPARE(add(r11b, 3), 0x49, 0x80, 0xc3, 0x03); // reg, [reg] SINGLE_COMPARE(add(rax, qword[rax]), 0x48, 0x03, 0x00); @@ -191,6 +194,8 @@ TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "FormsOfMov") SINGLE_COMPARE(mov64(rcx, 0x1234567812345678ll), 0x48, 0xb9, 0x78, 0x56, 0x34, 0x12, 0x78, 0x56, 0x34, 0x12); SINGLE_COMPARE(mov(ecx, 2), 0xb9, 0x02, 0x00, 0x00, 0x00); SINGLE_COMPARE(mov(cl, 2), 0xb1, 0x02); + SINGLE_COMPARE(mov(sil, 2), 0x48, 0xb6, 0x02); + SINGLE_COMPARE(mov(r9b, 2), 0x49, 0xb1, 0x02); SINGLE_COMPARE(mov(rcx, qword[rdi]), 0x48, 0x8b, 0x0f); SINGLE_COMPARE(mov(dword[rax], 0xabcd), 0xc7, 0x00, 0xcd, 0xab, 0x00, 0x00); SINGLE_COMPARE(mov(r13, 1), 0x49, 0xbd, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00); @@ -201,6 +206,8 @@ TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "FormsOfMov") SINGLE_COMPARE(mov(qword[rdx], r9), 0x4c, 0x89, 0x0a); SINGLE_COMPARE(mov(byte[rsi], 0x3), 0xc6, 0x06, 0x03); SINGLE_COMPARE(mov(byte[rsi], al), 0x88, 0x06); + SINGLE_COMPARE(mov(byte[rsi], dil), 0x48, 0x88, 0x3e); + SINGLE_COMPARE(mov(byte[rsi], r10b), 0x4c, 0x88, 0x16); } TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "FormsOfMovExtended") @@ -229,6 +236,8 @@ TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "FormsOfShift") { SINGLE_COMPARE(shl(al, 1), 0xd0, 0xe0); SINGLE_COMPARE(shl(al, cl), 0xd2, 0xe0); + SINGLE_COMPARE(shl(sil, cl), 0x48, 0xd2, 0xe6); + SINGLE_COMPARE(shl(r10b, cl), 0x49, 0xd2, 0xe2); SINGLE_COMPARE(shr(al, 4), 0xc0, 0xe8, 0x04); SINGLE_COMPARE(shr(eax, 1), 0xd1, 0xe8); SINGLE_COMPARE(sal(eax, cl), 0xd3, 0xe0); @@ -247,6 +256,7 @@ TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "FormsOfLea") TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "FormsOfSetcc") { SINGLE_COMPARE(setcc(ConditionX64::NotEqual, bl), 0x0f, 0x95, 0xc3); + SINGLE_COMPARE(setcc(ConditionX64::NotEqual, dil), 0x48, 0x0f, 0x95, 0xc7); SINGLE_COMPARE(setcc(ConditionX64::BelowEqual, byte[rcx]), 0x0f, 0x96, 0x01); } diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index c79bf35ea..3dc75d627 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -3473,4 +3473,34 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_response_perf1" * doctest::timeout(0. CHECK(ac.entryMap.count("Instance")); } +TEST_CASE_FIXTURE(ACFixture, "strict_mode_force") +{ + check(R"( +--!nonstrict +local a: {x: number} = {x=1} +local b = a +local c = b.@1 + )"); + + auto ac = autocomplete('1'); + + CHECK_EQ(1, ac.entryMap.size()); + CHECK(ac.entryMap.count("x")); +} + +TEST_CASE_FIXTURE(ACFixture, "suggest_exported_types") +{ + ScopedFastFlag luauCopyExportedTypes{"LuauCopyExportedTypes", true}; + + check(R"( +export type Type = {a: number} +local a: T@1 + )"); + + auto ac = autocomplete('1'); + + CHECK(ac.entryMap.count("Type")); + CHECK_EQ(ac.context, AutocompleteContext::Type); +} + TEST_SUITE_END(); diff --git a/tests/CodeAllocator.test.cpp b/tests/CodeAllocator.test.cpp index 359f2ba1c..01deddd3f 100644 --- a/tests/CodeAllocator.test.cpp +++ b/tests/CodeAllocator.test.cpp @@ -135,7 +135,8 @@ TEST_CASE("WindowsUnwindCodesX64") UnwindBuilderWin unwind; - unwind.start(); + unwind.startInfo(); + unwind.startFunction(); unwind.spill(16, rdx); unwind.spill(8, rcx); unwind.save(rdi); @@ -148,14 +149,15 @@ TEST_CASE("WindowsUnwindCodesX64") unwind.save(r15); unwind.allocStack(72); unwind.setupFrameReg(rbp, 48); - unwind.finish(); + unwind.finishFunction(0x11223344, 0x55443322); + unwind.finishInfo(); std::vector data; data.resize(unwind.getSize()); - unwind.finalize(data.data(), nullptr, 0); + unwind.finalize(data.data(), 0, nullptr, 0); - std::vector expected{0x01, 0x23, 0x0a, 0x35, 0x23, 0x33, 0x1e, 0x82, 0x1a, 0xf0, 0x18, 0xe0, 0x16, 0xd0, 0x14, 0xc0, 0x12, 0x50, 0x10, - 0x30, 0x0e, 0x60, 0x0c, 0x70}; + std::vector expected{0x44, 0x33, 0x22, 0x11, 0x22, 0x33, 0x44, 0x55, 0x0c, 0x00, 0x00, 0x00, 0x01, 0x23, 0x0a, 0x35, 0x23, 0x33, 0x1e, + 0x82, 0x1a, 0xf0, 0x18, 0xe0, 0x16, 0xd0, 0x14, 0xc0, 0x12, 0x50, 0x10, 0x30, 0x0e, 0x60, 0x0c, 0x70}; REQUIRE(data.size() == expected.size()); CHECK(memcmp(data.data(), expected.data(), expected.size()) == 0); @@ -168,7 +170,8 @@ TEST_CASE("Dwarf2UnwindCodesX64") UnwindBuilderDwarf2 unwind; - unwind.start(); + unwind.startInfo(); + unwind.startFunction(); unwind.save(rdi); unwind.save(rsi); unwind.save(rbx); @@ -179,11 +182,12 @@ TEST_CASE("Dwarf2UnwindCodesX64") unwind.save(r15); unwind.allocStack(72); unwind.setupFrameReg(rbp, 48); - unwind.finish(); + unwind.finishFunction(0, 0); + unwind.finishInfo(); std::vector data; data.resize(unwind.getSize()); - unwind.finalize(data.data(), nullptr, 0); + unwind.finalize(data.data(), 0, nullptr, 0); std::vector expected{0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x01, 0x78, 0x10, 0x0c, 0x07, 0x08, 0x05, 0x10, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x4c, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, @@ -211,6 +215,8 @@ constexpr X64::RegisterX64 rArg3 = X64::rdx; constexpr X64::RegisterX64 rNonVol1 = X64::r12; constexpr X64::RegisterX64 rNonVol2 = X64::rbx; +constexpr X64::RegisterX64 rNonVol3 = X64::r13; +constexpr X64::RegisterX64 rNonVol4 = X64::r14; TEST_CASE("GeneratedCodeExecutionX64") { @@ -260,7 +266,10 @@ TEST_CASE("GeneratedCodeExecutionWithThrowX64") std::unique_ptr unwind = std::make_unique(); #endif - unwind->start(); + unwind->startInfo(); + + Label functionBegin = build.setLabel(); + unwind->startFunction(); // Prologue build.push(rNonVol1); @@ -279,8 +288,6 @@ TEST_CASE("GeneratedCodeExecutionWithThrowX64") build.lea(rbp, addr[rsp + stackSize]); unwind->setupFrameReg(rbp, stackSize); - unwind->finish(); - // Body build.mov(rNonVol1, rArg1); build.mov(rNonVol2, rArg2); @@ -296,8 +303,12 @@ TEST_CASE("GeneratedCodeExecutionWithThrowX64") build.pop(rNonVol1); build.ret(); + unwind->finishFunction(build.getLabelOffset(functionBegin), ~0u); + build.finalize(); + unwind->finishInfo(); + size_t blockSize = 1024 * 1024; size_t maxTotalSize = 1024 * 1024; CodeAllocator allocator(blockSize, maxTotalSize); @@ -326,6 +337,152 @@ TEST_CASE("GeneratedCodeExecutionWithThrowX64") } } +TEST_CASE("GeneratedCodeExecutionMultipleFunctionsWithThrowX64") +{ + using namespace X64; + + AssemblyBuilderX64 build(/* logText= */ false); + +#if defined(_WIN32) + std::unique_ptr unwind = std::make_unique(); +#else + std::unique_ptr unwind = std::make_unique(); +#endif + + unwind->startInfo(); + + Label start1; + Label start2; + + // First function + { + build.setLabel(start1); + unwind->startFunction(); + + // Prologue + build.push(rNonVol1); + unwind->save(rNonVol1); + build.push(rNonVol2); + unwind->save(rNonVol2); + build.push(rbp); + unwind->save(rbp); + + int stackSize = 32; + int localsSize = 16; + + build.sub(rsp, stackSize + localsSize); + unwind->allocStack(stackSize + localsSize); + + build.lea(rbp, addr[rsp + stackSize]); + unwind->setupFrameReg(rbp, stackSize); + + // Body + build.mov(rNonVol1, rArg1); + build.mov(rNonVol2, rArg2); + + build.add(rNonVol1, 15); + build.mov(rArg1, rNonVol1); + build.call(rNonVol2); + + // Epilogue + build.lea(rsp, addr[rbp + localsSize]); + build.pop(rbp); + build.pop(rNonVol2); + build.pop(rNonVol1); + build.ret(); + + Label end1 = build.setLabel(); + unwind->finishFunction(build.getLabelOffset(start1), build.getLabelOffset(end1)); + } + + // Second function with different layout + { + build.setLabel(start2); + unwind->startFunction(); + + // Prologue + build.push(rNonVol1); + unwind->save(rNonVol1); + build.push(rNonVol2); + unwind->save(rNonVol2); + build.push(rNonVol3); + unwind->save(rNonVol3); + build.push(rNonVol4); + unwind->save(rNonVol4); + build.push(rbp); + unwind->save(rbp); + + int stackSize = 32; + int localsSize = 32; + + build.sub(rsp, stackSize + localsSize); + unwind->allocStack(stackSize + localsSize); + + build.lea(rbp, addr[rsp + stackSize]); + unwind->setupFrameReg(rbp, stackSize); + + // Body + build.mov(rNonVol3, rArg1); + build.mov(rNonVol4, rArg2); + + build.add(rNonVol3, 15); + build.mov(rArg1, rNonVol3); + build.call(rNonVol4); + + // Epilogue + build.lea(rsp, addr[rbp + localsSize]); + build.pop(rbp); + build.pop(rNonVol4); + build.pop(rNonVol3); + build.pop(rNonVol2); + build.pop(rNonVol1); + build.ret(); + + unwind->finishFunction(build.getLabelOffset(start2), ~0u); + } + + build.finalize(); + + unwind->finishInfo(); + + size_t blockSize = 1024 * 1024; + size_t maxTotalSize = 1024 * 1024; + CodeAllocator allocator(blockSize, maxTotalSize); + + allocator.context = unwind.get(); + allocator.createBlockUnwindInfo = createBlockUnwindInfo; + allocator.destroyBlockUnwindInfo = destroyBlockUnwindInfo; + + uint8_t* nativeData; + size_t sizeNativeData; + uint8_t* nativeEntry; + REQUIRE(allocator.allocate(build.data.data(), build.data.size(), build.code.data(), build.code.size(), nativeData, sizeNativeData, nativeEntry)); + REQUIRE(nativeEntry); + + using FunctionType = int64_t(int64_t, void (*)(int64_t)); + FunctionType* f1 = (FunctionType*)(nativeEntry + start1.location); + FunctionType* f2 = (FunctionType*)(nativeEntry + start2.location); + + // To simplify debugging, CHECK_THROWS_WITH_AS is not used here + try + { + f1(10, throwing); + } + catch (const std::runtime_error& error) + { + CHECK(strcmp(error.what(), "testing") == 0); + } + + try + { + f2(10, throwing); + } + catch (const std::runtime_error& error) + { + CHECK(strcmp(error.what(), "testing") == 0); + } +} + TEST_CASE("GeneratedCodeExecutionWithThrowOutsideTheGateX64") { using namespace X64; @@ -338,7 +495,10 @@ TEST_CASE("GeneratedCodeExecutionWithThrowOutsideTheGateX64") std::unique_ptr unwind = std::make_unique(); #endif - unwind->start(); + unwind->startInfo(); + + Label functionBegin = build.setLabel(); + unwind->startFunction(); // Prologue (some of these registers don't have to be saved, but we want to have a big prologue) build.push(r10); @@ -365,8 +525,6 @@ TEST_CASE("GeneratedCodeExecutionWithThrowOutsideTheGateX64") build.lea(rbp, addr[rsp + stackSize]); unwind->setupFrameReg(rbp, stackSize); - unwind->finish(); - size_t prologueSize = build.setLabel().location; // Body @@ -387,8 +545,12 @@ TEST_CASE("GeneratedCodeExecutionWithThrowOutsideTheGateX64") build.pop(r10); build.ret(); + unwind->finishFunction(build.getLabelOffset(functionBegin), ~0u); + build.finalize(); + unwind->finishInfo(); + size_t blockSize = 4096; // Force allocate to create a new block each time size_t maxTotalSize = 1024 * 1024; CodeAllocator allocator(blockSize, maxTotalSize); diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index 2a32bce2d..ec3d8b847 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -285,8 +285,16 @@ TEST_CASE("Tables") lua_pushcfunction( L, [](lua_State* L) { - unsigned v = luaL_checkunsigned(L, 1); - lua_pushlightuserdata(L, reinterpret_cast(uintptr_t(v))); + if (lua_type(L, 1) == LUA_TNUMBER) + { + unsigned v = luaL_checkunsigned(L, 1); + lua_pushlightuserdata(L, reinterpret_cast(uintptr_t(v))); + } + else + { + const void* p = lua_topointer(L, 1); + lua_pushlightuserdata(L, const_cast(p)); + } return 1; }, "makelud"); @@ -402,21 +410,24 @@ TEST_CASE("PCall") { ScopedFastFlag sff("LuauBetterOOMHandling", true); - runConformance("pcall.lua", [](lua_State* L) { - lua_pushcfunction(L, cxxthrow, "cxxthrow"); - lua_setglobal(L, "cxxthrow"); + runConformance( + "pcall.lua", + [](lua_State* L) { + lua_pushcfunction(L, cxxthrow, "cxxthrow"); + lua_setglobal(L, "cxxthrow"); - lua_pushcfunction( - L, - [](lua_State* L) -> int { - lua_State* co = lua_tothread(L, 1); - lua_xmove(L, co, 1); - lua_resumeerror(co, L); - return 0; - }, - "resumeerror"); - lua_setglobal(L, "resumeerror"); - }, nullptr, lua_newstate(limitedRealloc, nullptr)); + lua_pushcfunction( + L, + [](lua_State* L) -> int { + lua_State* co = lua_tothread(L, 1); + lua_xmove(L, co, 1); + lua_resumeerror(co, L); + return 0; + }, + "resumeerror"); + lua_setglobal(L, "resumeerror"); + }, + nullptr, lua_newstate(limitedRealloc, nullptr)); } TEST_CASE("Pack") diff --git a/tests/Fixture.cpp b/tests/Fixture.cpp index aebf177cd..aba2891e2 100644 --- a/tests/Fixture.cpp +++ b/tests/Fixture.cpp @@ -21,6 +21,7 @@ static const char* mainModuleName = "MainModule"; LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); +LUAU_FASTFLAG(LuauOnDemandTypecheckers); extern std::optional randomSeed; // tests/main.cpp @@ -180,9 +181,16 @@ AstStatBlock* Fixture::parse(const std::string& source, const ParseOptions& pars Luau::lint(sourceModule->root, *sourceModule->names, frontend.globals.globalScope, module.get(), sourceModule->hotcomments, {}); } + else if (!FFlag::LuauOnDemandTypecheckers) + { + ModulePtr module = frontend.typeChecker_DEPRECATED.check(*sourceModule, sourceModule->mode.value_or(Luau::Mode::Nonstrict)); + + Luau::lint(sourceModule->root, *sourceModule->names, frontend.globals.globalScope, module.get(), sourceModule->hotcomments, {}); + } else { - ModulePtr module = frontend.typeChecker.check(*sourceModule, sourceModule->mode.value_or(Luau::Mode::Nonstrict)); + TypeChecker typeChecker(frontend.globals.globalScope, &moduleResolver, builtinTypes, &frontend.iceHandler); + ModulePtr module = typeChecker.check(*sourceModule, sourceModule->mode.value_or(Luau::Mode::Nonstrict), std::nullopt); Luau::lint(sourceModule->root, *sourceModule->names, frontend.globals.globalScope, module.get(), sourceModule->hotcomments, {}); } diff --git a/tests/Module.test.cpp b/tests/Module.test.cpp index 7e61235a8..3c613a1f6 100644 --- a/tests/Module.test.cpp +++ b/tests/Module.test.cpp @@ -3,6 +3,7 @@ #include "Luau/Module.h" #include "Luau/Scope.h" #include "Luau/RecursionCounter.h" +#include "Luau/Parser.h" #include "Fixture.h" @@ -42,6 +43,38 @@ TEST_CASE_FIXTURE(Fixture, "is_within_comment") CHECK(!isWithinComment(*sm, Position{7, 11})); } +TEST_CASE_FIXTURE(Fixture, "is_within_comment_parse_result") +{ + std::string src = R"( + --!strict + local foo = {} + function foo:bar() end + + --[[ + foo: + ]] foo:bar() + + --[[]]--[[]] -- Two distinct comments that have zero characters of space between them. + )"; + + Luau::Allocator alloc; + Luau::AstNameTable names{alloc}; + Luau::ParseOptions parseOptions; + parseOptions.captureComments = true; + Luau::ParseResult parseResult = Luau::Parser::parse(src.data(), src.size(), names, alloc, parseOptions); + + CHECK_EQ(5, parseResult.commentLocations.size()); + + CHECK(isWithinComment(parseResult, Position{1, 15})); + CHECK(isWithinComment(parseResult, Position{6, 16})); + CHECK(isWithinComment(parseResult, Position{9, 13})); + CHECK(isWithinComment(parseResult, Position{9, 14})); + + CHECK(!isWithinComment(parseResult, Position{2, 15})); + CHECK(!isWithinComment(parseResult, Position{7, 10})); + CHECK(!isWithinComment(parseResult, Position{7, 11})); +} + TEST_CASE_FIXTURE(Fixture, "dont_clone_persistent_primitive") { TypeArena dest; @@ -319,6 +352,10 @@ TEST_CASE_FIXTURE(Fixture, "clone_recursion_limit") TEST_CASE_FIXTURE(Fixture, "any_persistance_does_not_leak") { + ScopedFastFlag flags[] = { + {"LuauOccursIsntAlwaysFailure", true}, + }; + fileResolver.source["Module/A"] = R"( export type A = B type B = A @@ -332,7 +369,7 @@ type B = A auto mod = frontend.moduleResolver.getModule("Module/A"); auto it = mod->exportedTypeBindings.find("A"); REQUIRE(it != mod->exportedTypeBindings.end()); - CHECK(toString(it->second.type) == "any"); + CHECK(toString(it->second.type) == "*error-type*"); } TEST_CASE_FIXTURE(BuiltinsFixture, "do_not_clone_reexports") diff --git a/tests/StringUtils.test.cpp b/tests/StringUtils.test.cpp index afef3b06e..786f965ea 100644 --- a/tests/StringUtils.test.cpp +++ b/tests/StringUtils.test.cpp @@ -106,4 +106,22 @@ TEST_CASE("AreWeUsingDistanceWithAdjacentTranspositionsAndNotOptimalStringAlignm CHECK_EQ(distance, 2); } +TEST_CASE("EditDistanceSupportsUnicode") +{ + // ASCII character + CHECK_EQ(Luau::editDistance("A block", "X block"), 1); + + // UTF-8 2 byte character + CHECK_EQ(Luau::editDistance("A block", "Ă€ block"), 2); + + // UTF-8 3 byte character + CHECK_EQ(Luau::editDistance("A block", "⪻ block"), 3); + + // UTF-8 4 byte character + CHECK_EQ(Luau::editDistance("A block", "đ’‹„ block"), 4); + + // UTF-8 extreme characters + CHECK_EQ(Luau::editDistance("A block", "R̴̨̢̟̚ŏ̶̳̳͚ĚÍ…b̶̡̻̞Ě̿ͅl̸̼͝ợ̷̜͓̒̏͜͝ẍ̴̝̦̟̰ĚĚ’ĚĚŚ block"), 85); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.annotations.test.cpp b/tests/TypeInfer.annotations.test.cpp index 2c87cb419..3de529998 100644 --- a/tests/TypeInfer.annotations.test.cpp +++ b/tests/TypeInfer.annotations.test.cpp @@ -435,6 +435,10 @@ TEST_CASE_FIXTURE(Fixture, "typeof_expr") TEST_CASE_FIXTURE(Fixture, "corecursive_types_error_on_tight_loop") { + ScopedFastFlag flags[] = { + {"LuauOccursIsntAlwaysFailure", true}, + }; + CheckResult result = check(R"( type A = B type B = A @@ -443,10 +447,10 @@ TEST_CASE_FIXTURE(Fixture, "corecursive_types_error_on_tight_loop") local bb:B )"); - TypeId fType = requireType("aa"); - const AnyType* ftv = get(follow(fType)); - REQUIRE(ftv != nullptr); - REQUIRE(!result.errors.empty()); + LUAU_REQUIRE_ERROR_COUNT(1, result); + + OccursCheckFailed* ocf = get(result.errors[0]); + REQUIRE(ocf); } TEST_CASE_FIXTURE(Fixture, "type_alias_always_resolve_to_a_real_type") @@ -762,6 +766,7 @@ TEST_CASE_FIXTURE(Fixture, "occurs_check_on_cyclic_union_type") { CheckResult result = check(R"( type T = T | T + local x : T )"); LUAU_REQUIRE_ERROR_COUNT(1, result); diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index f1d42c6a4..942ce191f 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -1281,6 +1281,39 @@ f(function(x) return x * 2 end) LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "variadic_any_is_compatible_with_a_generic_TypePack") +{ + ScopedFastFlag sff[] = { + {"LuauVariadicAnyCanBeGeneric", true} + }; + + CheckResult result = check(R"( + --!strict + local function f(...) return ... end + local g = function(...) return f(...) end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +// https://github.com/Roblox/luau/issues/767 +TEST_CASE_FIXTURE(BuiltinsFixture, "variadic_any_is_compatible_with_a_generic_TypePack_2") +{ + ScopedFastFlag sff{"LuauVariadicAnyCanBeGeneric", true}; + + CheckResult result = check(R"( + local function somethingThatsAny(...: any) + print(...) + end + + local function x(...: T...) + somethingThatsAny(...) -- Failed to unify variadic type packs + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_CASE_FIXTURE(Fixture, "infer_anonymous_function_arguments_outside_call") { CheckResult result = check(R"( diff --git a/tests/TypeInfer.operators.test.cpp b/tests/TypeInfer.operators.test.cpp index 174bc310e..d224195c9 100644 --- a/tests/TypeInfer.operators.test.cpp +++ b/tests/TypeInfer.operators.test.cpp @@ -53,10 +53,6 @@ TEST_CASE_FIXTURE(Fixture, "or_joins_types_with_no_superfluous_union") TEST_CASE_FIXTURE(Fixture, "and_does_not_always_add_boolean") { - ScopedFastFlag sff[]{ - {"LuauTryhardAnd", true}, - }; - CheckResult result = check(R"( local s = "a" and 10 local x:boolean|number = s @@ -737,6 +733,8 @@ TEST_CASE_FIXTURE(Fixture, "error_on_invalid_operand_types_to_relational_operato TEST_CASE_FIXTURE(Fixture, "cli_38355_recursive_union") { + ScopedFastFlag sff{"LuauOccursIsntAlwaysFailure", true}; + CheckResult result = check(R"( --!strict local _ @@ -744,7 +742,7 @@ TEST_CASE_FIXTURE(Fixture, "cli_38355_recursive_union") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Type contains a self-recursive construct that cannot be resolved", toString(result.errors[0])); + CHECK_EQ("Unknown type used in + operation; consider adding a type annotation to '_'", toString(result.errors[0])); } TEST_CASE_FIXTURE(BuiltinsFixture, "UnknownGlobalCompoundAssign") @@ -1048,10 +1046,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "mm_comparisons_must_return_a_boolean") TEST_CASE_FIXTURE(BuiltinsFixture, "reworked_and") { - ScopedFastFlag sff[]{ - {"LuauTryhardAnd", true}, - }; - CheckResult result = check(R"( local a: number? = 5 local b: boolean = (a or 1) > 10 @@ -1077,10 +1071,6 @@ local w = c and 1 TEST_CASE_FIXTURE(BuiltinsFixture, "reworked_or") { - ScopedFastFlag sff[]{ - {"LuauTryhardAnd", true}, - }; - CheckResult result = check(R"( local a: number | false = 5 local b: number? = 6 @@ -1115,11 +1105,6 @@ local f1 = f or 'f' TEST_CASE_FIXTURE(BuiltinsFixture, "reducing_and") { - ScopedFastFlag sff[]{ - {"LuauTryhardAnd", true}, - {"LuauReducingAndOr", true}, - }; - CheckResult result = check(R"( type Foo = { name: string?, flag: boolean? } local arr: {Foo} = {} @@ -1137,4 +1122,61 @@ end LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(BuiltinsFixture, "luau_polyfill_is_array_simplified") +{ + CheckResult result = check(R"( + --!strict + return function(value: any) : boolean + if typeof(value) ~= "number" then + return false + end + if value % 1 ~= 0 or value < 1 then + return false + end + return true + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "luau_polyfill_is_array") +{ + CheckResult result = check(R"( +--!strict +return function(value: any): boolean + if typeof(value) ~= "table" then + return false + end + if next(value) == nil then + -- an empty table is an empty array + return true + end + + local length = #value + + if length == 0 then + return false + end + + local count = 0 + local sum = 0 + for key in pairs(value) do + if typeof(key) ~= "number" then + return false + end + if key % 1 ~= 0 or key < 1 then + return false + end + count += 1 + sum += key + end + + return sum == (count * (count + 1) / 2) +end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index 87419debb..e074bc871 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -320,23 +320,6 @@ TEST_CASE_FIXTURE(Fixture, "weird_fail_to_unify_type_pack") LUAU_REQUIRE_ERRORS(result); // Should not have any errors. } -TEST_CASE_FIXTURE(Fixture, "weird_fail_to_unify_variadic_pack") -{ - ScopedFastFlag sff[] = { - // I'm not sure why this is broken without DCR, but it seems to be fixed - // when DCR is enabled. - {"DebugLuauDeferredConstraintResolution", false}, - }; - - CheckResult result = check(R"( - --!strict - local function f(...) return ... end - local g = function(...) return f(...) end - )"); - - LUAU_REQUIRE_ERRORS(result); // Should not have any errors. -} - // Belongs in TypeInfer.builtins.test.cpp. TEST_CASE_FIXTURE(BuiltinsFixture, "pcall_returns_at_least_two_value_but_function_returns_nothing") { @@ -819,4 +802,23 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "table_insert_with_a_singleton_argument") } } +// We really should be warning on this. We have no guarantee that T has any properties. +TEST_CASE_FIXTURE(Fixture, "lookup_prop_of_intersection_containing_unions_of_tables_that_have_the_prop") +{ + CheckResult result = check(R"( + local function mergeOptions(options: T & ({variable: string} | {variable: number})) + return options.variable + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + // LUAU_REQUIRE_ERROR_COUNT(1, result); + + // const UnknownProperty* unknownProp = get(result.errors[0]); + // REQUIRE(unknownProp); + + // CHECK("variable" == unknownProp->key); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index 3088235ae..f540be071 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -1195,6 +1195,21 @@ local b = typeof(foo) ~= 'nil' CHECK(toString(result.errors[1]) == "Unknown global 'foo'"); } +TEST_CASE_FIXTURE(Fixture, "occurs_isnt_always_failure") +{ + ScopedFastFlag sff{"LuauOccursIsntAlwaysFailure", true}; + + CheckResult result = check(R"( +function f(x, c) -- x : X + local y = if c then x else nil -- y : X? + local z = if c then x else nil -- z : X? + y = z +end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_CASE_FIXTURE(Fixture, "dcr_delays_expansion_of_function_containing_blocked_parameter_type") { ScopedFastFlag sff[] = { diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp index 19a19e450..19b221482 100644 --- a/tests/TypeInfer.unionTypes.test.cpp +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -776,4 +776,20 @@ TEST_CASE_FIXTURE(Fixture, "generic_function_with_optional_arg") LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "lookup_prop_of_intersection_containing_unions") +{ + CheckResult result = check(R"( + local function mergeOptions(options: T & ({} | {})) + return options.variables + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + const UnknownProperty* unknownProp = get(result.errors[0]); + REQUIRE(unknownProp); + + CHECK("variables" == unknownProp->key); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.unknownnever.test.cpp b/tests/TypeInfer.unknownnever.test.cpp index 410fd52de..8558670c3 100644 --- a/tests/TypeInfer.unknownnever.test.cpp +++ b/tests/TypeInfer.unknownnever.test.cpp @@ -301,11 +301,6 @@ TEST_CASE_FIXTURE(Fixture, "length_of_never") TEST_CASE_FIXTURE(Fixture, "dont_unify_operands_if_one_of_the_operand_is_never_in_any_ordering_operators") { - ScopedFastFlag sff[]{ - {"LuauTryhardAnd", true}, - {"LuauReducingAndOr", true}, - }; - CheckResult result = check(R"( local function ord(x: nil, y) return x ~= nil and x > y diff --git a/tests/TypeVar.test.cpp b/tests/TypeVar.test.cpp index 3f0becc54..dbf58cc80 100644 --- a/tests/TypeVar.test.cpp +++ b/tests/TypeVar.test.cpp @@ -273,12 +273,14 @@ TEST_CASE_FIXTURE(Fixture, "substitution_skip_failure") TypeId root = &ttvTweenResult; - frontend.typeChecker.currentModule = std::make_shared(); - frontend.typeChecker.currentModule->scopes.emplace_back(Location{}, std::make_shared(builtinTypes->anyTypePack)); - - TypeId result = frontend.typeChecker.anyify(frontend.globals.globalScope, root, Location{}); - - CHECK_EQ("{| f: t1 |} where t1 = () -> {| f: () -> {| f: ({| f: t1 |}) -> (), signal: {| f: (any) -> () |} |} |}", toString(result)); + ModulePtr currentModule = std::make_shared(); + Anyification anyification(¤tModule->internalTypes, frontend.globals.globalScope, builtinTypes, &frontend.iceHandler, builtinTypes->anyType, + builtinTypes->anyTypePack); + std::optional any = anyification.substitute(root); + + REQUIRE(!anyification.normalizationTooComplex); + REQUIRE(any.has_value()); + CHECK_EQ("{| f: t1 |} where t1 = () -> {| f: () -> {| f: ({| f: t1 |}) -> (), signal: {| f: (any) -> () |} |} |}", toString(*any)); } TEST_CASE("tagging_tables") diff --git a/tests/conformance/math.lua b/tests/conformance/math.lua index ea3b5c87a..473427309 100644 --- a/tests/conformance/math.lua +++ b/tests/conformance/math.lua @@ -347,5 +347,15 @@ assert(select('#', math.ceil(1.6)) == 1) assert(select('#', math.sqrt(9)) == 1) assert(select('#', math.deg(9)) == 1) assert(select('#', math.rad(9)) == 1) +assert(select('#', math.sin(1.5)) == 1) +assert(select('#', math.atan2(1.5, 0.5)) == 1) +assert(select('#', math.modf(1.5)) == 2) +assert(select('#', math.frexp(1.5)) == 2) + +-- test that fastcalls that return variadic results return them correctly in variadic position +assert(select(1, math.modf(1.5)) == 1) +assert(select(2, math.modf(1.5)) == 0.5) +assert(select(1, math.frexp(1.5)) == 0.75) +assert(select(2, math.frexp(1.5)) == 1) return('OK') diff --git a/tests/conformance/tables.lua b/tests/conformance/tables.lua index 596eed3db..03b463968 100644 --- a/tests/conformance/tables.lua +++ b/tests/conformance/tables.lua @@ -715,4 +715,11 @@ do end end +-- check that fast path for table lookup can't be tricked into assuming a light user data with string pointer is a string +assert((function () + local t = {} + t[makelud("hi")] = "no" + return t.hi +end)() == nil) + return"OK" diff --git a/tools/lvmexecute_split.py b/tools/lvmexecute_split.py index 16de45dcc..6e64bcd0e 100644 --- a/tools/lvmexecute_split.py +++ b/tools/lvmexecute_split.py @@ -34,7 +34,7 @@ function = "" signature = "" -includeInsts = ["LOP_NEWCLOSURE", "LOP_NAMECALL", "LOP_FORGPREP", "LOP_GETVARARGS", "LOP_DUPCLOSURE", "LOP_PREPVARARGS", "LOP_BREAK", "LOP_GETGLOBAL", "LOP_SETGLOBAL", "LOP_GETTABLEKS", "LOP_SETTABLEKS"] +includeInsts = ["LOP_NEWCLOSURE", "LOP_NAMECALL", "LOP_FORGPREP", "LOP_GETVARARGS", "LOP_DUPCLOSURE", "LOP_PREPVARARGS", "LOP_BREAK", "LOP_GETGLOBAL", "LOP_SETGLOBAL", "LOP_GETTABLEKS", "LOP_SETTABLEKS", "LOP_SETLIST"] state = 0 From 33b95582acbd3c3021f634e8f4d2fe29a363452d Mon Sep 17 00:00:00 2001 From: Vyacheslav Egorov Date: Fri, 14 Apr 2023 16:46:13 +0300 Subject: [PATCH 47/66] Build fix --- CodeGen/src/AssemblyBuilderA64.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/CodeGen/src/AssemblyBuilderA64.cpp b/CodeGen/src/AssemblyBuilderA64.cpp index d6274256e..bb7c94398 100644 --- a/CodeGen/src/AssemblyBuilderA64.cpp +++ b/CodeGen/src/AssemblyBuilderA64.cpp @@ -851,8 +851,8 @@ void AssemblyBuilderA64::placeBM(const char* name, RegisterA64 dst, RegisterA64 int lz = countlz(src2); int rz = countrz(src2); - LUAU_ASSERT(lz + rz > 0 && lz + rz < 32); // must have at least one 0 and at least one 1 - LUAU_ASSERT((src2 >> rz) == (1 << (32 - lz - rz)) - 1); // sequence of 1s must be contiguous + LUAU_ASSERT(lz + rz > 0 && lz + rz < 32); // must have at least one 0 and at least one 1 + LUAU_ASSERT((src2 >> rz) == (1u << (32 - lz - rz)) - 1u); // sequence of 1s must be contiguous int imms = 31 - lz - rz; // count of 1s minus 1 int immr = (32 - rz) & 31; // right rotate amount From d5cdb687e061798b3a7dc508c75bb2dd331db913 Mon Sep 17 00:00:00 2001 From: Andy Friesen Date: Fri, 21 Apr 2023 14:41:03 -0700 Subject: [PATCH 48/66] Sync to upstream/release/573 --- .../include/Luau/ConstraintGraphBuilder.h | 6 +- Analysis/include/Luau/ConstraintSolver.h | 2 +- Analysis/include/Luau/Frontend.h | 9 + Analysis/include/Luau/Module.h | 7 +- Analysis/include/Luau/Type.h | 48 +- Analysis/include/Luau/TypeChecker2.h | 2 +- Analysis/include/Luau/TypeInfer.h | 1 - Analysis/include/Luau/VisitType.h | 10 +- Analysis/src/AstJsonEncoder.cpp | 3 + Analysis/src/BuiltinDefinitions.cpp | 2 +- Analysis/src/Clone.cpp | 9 +- Analysis/src/ConstraintGraphBuilder.cpp | 17 +- Analysis/src/ConstraintSolver.cpp | 13 +- Analysis/src/Error.cpp | 6 +- Analysis/src/Frontend.cpp | 96 ++-- Analysis/src/Substitution.cpp | 2 +- Analysis/src/ToString.cpp | 11 +- Analysis/src/Type.cpp | 53 +- Analysis/src/TypeAttach.cpp | 43 +- Analysis/src/TypeChecker2.cpp | 21 +- Analysis/src/TypeInfer.cpp | 40 +- Analysis/src/TypeReduction.cpp | 2 +- Ast/include/Luau/Ast.h | 6 +- Ast/src/Ast.cpp | 6 +- Ast/src/Parser.cpp | 13 +- CLI/Repl.cpp | 10 +- CodeGen/include/Luau/AssemblyBuilderA64.h | 67 ++- CodeGen/include/Luau/AssemblyBuilderX64.h | 5 + CodeGen/include/Luau/ConditionA64.h | 4 +- CodeGen/include/Luau/IrData.h | 64 ++- CodeGen/include/Luau/IrUtils.h | 21 +- CodeGen/include/Luau/OptimizeConstProp.h | 1 + CodeGen/src/AssemblyBuilderA64.cpp | 285 ++++++++-- CodeGen/src/AssemblyBuilderX64.cpp | 38 ++ CodeGen/src/CodeGen.cpp | 13 +- CodeGen/src/CodeGenUtils.cpp | 4 +- CodeGen/src/CodeGenUtils.h | 2 +- CodeGen/src/EmitBuiltinsX64.cpp | 123 ---- CodeGen/src/EmitCommonA64.h | 6 +- CodeGen/src/IrBuilder.cpp | 3 +- CodeGen/src/IrDump.cpp | 34 ++ CodeGen/src/IrLoweringA64.cpp | 538 ++++++++++++------ CodeGen/src/IrLoweringA64.h | 8 +- CodeGen/src/IrLoweringX64.cpp | 234 +++++++- CodeGen/src/IrLoweringX64.h | 6 + CodeGen/src/IrRegAllocA64.cpp | 221 ++++++- CodeGen/src/IrRegAllocA64.h | 38 +- CodeGen/src/IrRegAllocX64.cpp | 67 +-- CodeGen/src/IrTranslateBuiltins.cpp | 428 +++++++++++++- CodeGen/src/IrTranslation.cpp | 6 + CodeGen/src/IrUtils.cpp | 81 ++- CodeGen/src/IrValueLocationTracking.cpp | 223 ++++++++ CodeGen/src/IrValueLocationTracking.h | 38 ++ CodeGen/src/NativeState.h | 2 +- CodeGen/src/OptimizeConstProp.cpp | 55 +- CodeGen/src/UnwindBuilderDwarf2.cpp | 4 +- Sources.cmake | 3 + VM/include/lua.h | 6 +- VM/src/lapi.cpp | 8 +- VM/src/ludata.cpp | 3 +- tests/AssemblyBuilderA64.test.cpp | 55 +- tests/AssemblyBuilderX64.test.cpp | 6 + tests/AstJsonEncoder.test.cpp | 20 +- tests/Conformance.test.cpp | 14 +- tests/ConstraintGraphBuilderFixture.cpp | 6 +- tests/Frontend.test.cpp | 12 +- tests/IrBuilder.test.cpp | 3 + tests/IrRegAllocX64.test.cpp | 58 ++ tests/Module.test.cpp | 6 +- tests/TypeInfer.modules.test.cpp | 10 +- tests/TypeInfer.provisional.test.cpp | 8 +- tests/TypeInfer.test.cpp | 9 +- tests/TypeVar.test.cpp | 19 + tests/conformance/native.lua | 17 + tools/codegenstat.py | 58 ++ tools/faillist.txt | 2 + tools/natvis/CodeGen.natvis | 9 + 77 files changed, 2715 insertions(+), 674 deletions(-) create mode 100644 CodeGen/src/IrValueLocationTracking.cpp create mode 100644 CodeGen/src/IrValueLocationTracking.h create mode 100644 tests/IrRegAllocX64.test.cpp create mode 100644 tests/conformance/native.lua create mode 100644 tools/codegenstat.py diff --git a/Analysis/include/Luau/ConstraintGraphBuilder.h b/Analysis/include/Luau/ConstraintGraphBuilder.h index 204470489..cbf679cc5 100644 --- a/Analysis/include/Luau/ConstraintGraphBuilder.h +++ b/Analysis/include/Luau/ConstraintGraphBuilder.h @@ -60,7 +60,6 @@ struct ConstraintGraphBuilder // define the scope hierarchy. std::vector> scopes; - ModuleName moduleName; ModulePtr module; NotNull builtinTypes; const NotNull arena; @@ -94,9 +93,8 @@ struct ConstraintGraphBuilder ScopePtr globalScope; DcrLogger* logger; - ConstraintGraphBuilder(const ModuleName& moduleName, ModulePtr module, TypeArena* arena, NotNull moduleResolver, - NotNull builtinTypes, NotNull ice, const ScopePtr& globalScope, DcrLogger* logger, - NotNull dfg); + ConstraintGraphBuilder(ModulePtr module, TypeArena* arena, NotNull moduleResolver, NotNull builtinTypes, + NotNull ice, const ScopePtr& globalScope, DcrLogger* logger, NotNull dfg); /** * Fabricates a new free type belonging to a given scope. diff --git a/Analysis/include/Luau/ConstraintSolver.h b/Analysis/include/Luau/ConstraintSolver.h index 2feee2368..6888e99c2 100644 --- a/Analysis/include/Luau/ConstraintSolver.h +++ b/Analysis/include/Luau/ConstraintSolver.h @@ -49,7 +49,7 @@ struct HashInstantiationSignature struct ConstraintSolver { - TypeArena* arena; + NotNull arena; NotNull builtinTypes; InternalErrorReporter iceReporter; NotNull normalizer; diff --git a/Analysis/include/Luau/Frontend.h b/Analysis/include/Luau/Frontend.h index 3f41c1456..856c5dafa 100644 --- a/Analysis/include/Luau/Frontend.h +++ b/Analysis/include/Luau/Frontend.h @@ -8,6 +8,8 @@ #include "Luau/Scope.h" #include "Luau/TypeInfer.h" #include "Luau/Variant.h" + +#include #include #include #include @@ -67,6 +69,7 @@ struct SourceNode } ModuleName name; + std::string humanReadableName; std::unordered_set requireSet; std::vector> requireLocations; bool dirtySourceModule = true; @@ -114,7 +117,13 @@ struct FrontendModuleResolver : ModuleResolver std::optional resolveModuleInfo(const ModuleName& currentModuleName, const AstExpr& pathExpr) override; std::string getHumanReadableModuleName(const ModuleName& moduleName) const override; + void setModule(const ModuleName& moduleName, ModulePtr module); + void clearModules(); + +private: Frontend* frontend; + + mutable std::mutex moduleMutex; std::unordered_map modules; }; diff --git a/Analysis/include/Luau/Module.h b/Analysis/include/Luau/Module.h index 1bca7636c..b9be8205b 100644 --- a/Analysis/include/Luau/Module.h +++ b/Analysis/include/Luau/Module.h @@ -28,7 +28,9 @@ class AstTypePack; /// Root of the AST of a parsed source file struct SourceModule { - ModuleName name; // DataModel path if possible. Filename if not. + ModuleName name; // Module identifier or a filename + std::string humanReadableName; + SourceCode::Type type = SourceCode::None; std::optional environmentName; bool cyclic = false; @@ -63,6 +65,9 @@ struct Module { ~Module(); + ModuleName name; + std::string humanReadableName; + TypeArena interfaceTypes; TypeArena internalTypes; diff --git a/Analysis/include/Luau/Type.h b/Analysis/include/Luau/Type.h index b9544a11d..24fb7db0f 100644 --- a/Analysis/include/Luau/Type.h +++ b/Analysis/include/Luau/Type.h @@ -10,6 +10,7 @@ #include "Luau/Unifiable.h" #include "Luau/Variant.h" +#include #include #include #include @@ -550,7 +551,50 @@ struct IntersectionType struct LazyType { - std::function thunk; + LazyType() = default; + LazyType(std::function thunk_DEPRECATED, std::function unwrap) + : thunk_DEPRECATED(thunk_DEPRECATED) + , unwrap(unwrap) + { + } + + // std::atomic is sad and requires a manual copy + LazyType(const LazyType& rhs) + : thunk_DEPRECATED(rhs.thunk_DEPRECATED) + , unwrap(rhs.unwrap) + , unwrapped(rhs.unwrapped.load()) + { + } + + LazyType(LazyType&& rhs) noexcept + : thunk_DEPRECATED(std::move(rhs.thunk_DEPRECATED)) + , unwrap(std::move(rhs.unwrap)) + , unwrapped(rhs.unwrapped.load()) + { + } + + LazyType& operator=(const LazyType& rhs) + { + thunk_DEPRECATED = rhs.thunk_DEPRECATED; + unwrap = rhs.unwrap; + unwrapped = rhs.unwrapped.load(); + + return *this; + } + + LazyType& operator=(LazyType&& rhs) noexcept + { + thunk_DEPRECATED = std::move(rhs.thunk_DEPRECATED); + unwrap = std::move(rhs.unwrap); + unwrapped = rhs.unwrapped.load(); + + return *this; + } + + std::function thunk_DEPRECATED; + + std::function unwrap; + std::atomic unwrapped = nullptr; }; struct UnknownType @@ -798,7 +842,7 @@ struct TypeIterator TypeIterator operator++(int) { TypeIterator copy = *this; - ++copy; + ++*this; return copy; } diff --git a/Analysis/include/Luau/TypeChecker2.h b/Analysis/include/Luau/TypeChecker2.h index 6045aecff..def00a440 100644 --- a/Analysis/include/Luau/TypeChecker2.h +++ b/Analysis/include/Luau/TypeChecker2.h @@ -12,6 +12,6 @@ namespace Luau struct DcrLogger; struct BuiltinTypes; -void check(NotNull builtinTypes, DcrLogger* logger, const SourceModule& sourceModule, Module* module); +void check(NotNull builtinTypes, NotNull sharedState, DcrLogger* logger, const SourceModule& sourceModule, Module* module); } // namespace Luau diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index 7dae79c31..b5db3f58d 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -372,7 +372,6 @@ struct TypeChecker ModuleResolver* resolver; ModulePtr currentModule; - ModuleName currentModuleName; std::function prepareModuleScope; NotNull builtinTypes; diff --git a/Analysis/include/Luau/VisitType.h b/Analysis/include/Luau/VisitType.h index 95b2b0507..c7dcdcc1e 100644 --- a/Analysis/include/Luau/VisitType.h +++ b/Analysis/include/Luau/VisitType.h @@ -9,6 +9,7 @@ #include "Luau/Type.h" LUAU_FASTINT(LuauVisitRecursionLimit) +LUAU_FASTFLAG(LuauBoundLazyTypes) namespace Luau { @@ -291,9 +292,14 @@ struct GenericTypeVisitor traverse(partTy); } } - else if (get(ty)) + else if (auto ltv = get(ty)) { - // Visiting into LazyType may necessarily cause infinite expansion, so we don't do that on purpose. + if (FFlag::LuauBoundLazyTypes) + { + if (TypeId unwrapped = ltv->unwrapped) + traverse(unwrapped); + } + // Visiting into LazyType that hasn't been unwrapped may necessarily cause infinite expansion, so we don't do that on purpose. // Asserting also makes no sense, because the type _will_ happen here, most likely as a property of some ClassType // that doesn't need to be expanded. } diff --git a/Analysis/src/AstJsonEncoder.cpp b/Analysis/src/AstJsonEncoder.cpp index 57c8c90b4..a964c785f 100644 --- a/Analysis/src/AstJsonEncoder.cpp +++ b/Analysis/src/AstJsonEncoder.cpp @@ -776,7 +776,10 @@ struct AstJsonEncoder : public AstVisitor writeNode(node, "AstTypeReference", [&]() { if (node->prefix) PROP(prefix); + if (node->prefixLocation) + write("prefixLocation", *node->prefixLocation); PROP(name); + PROP(nameLocation); PROP(parameters); }); } diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index 7ed92fb41..8988b332e 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -606,7 +606,7 @@ static std::optional> magicFunctionRequire( if (!checkRequirePath(typechecker, expr.args.data[0])) return std::nullopt; - if (auto moduleInfo = typechecker.resolver->resolveModuleInfo(typechecker.currentModuleName, expr)) + if (auto moduleInfo = typechecker.resolver->resolveModuleInfo(typechecker.currentModule->name, expr)) return WithPredicate{arena.addTypePack({typechecker.checkRequire(scope, *moduleInfo, expr.location)})}; return std::nullopt; diff --git a/Analysis/src/Clone.cpp b/Analysis/src/Clone.cpp index ac73622d3..f5102654f 100644 --- a/Analysis/src/Clone.cpp +++ b/Analysis/src/Clone.cpp @@ -325,7 +325,14 @@ void TypeCloner::operator()(const IntersectionType& t) void TypeCloner::operator()(const LazyType& t) { - defaultClone(t); + if (TypeId unwrapped = t.unwrapped.load()) + { + seenTypes[typeId] = clone(unwrapped, dest, cloneState); + } + else + { + defaultClone(t); + } } void TypeCloner::operator()(const UnknownType& t) diff --git a/Analysis/src/ConstraintGraphBuilder.cpp b/Analysis/src/ConstraintGraphBuilder.cpp index 474d39235..ad7cff9f7 100644 --- a/Analysis/src/ConstraintGraphBuilder.cpp +++ b/Analysis/src/ConstraintGraphBuilder.cpp @@ -133,11 +133,10 @@ void forEachConstraint(const Checkpoint& start, const Checkpoint& end, const Con } // namespace -ConstraintGraphBuilder::ConstraintGraphBuilder(const ModuleName& moduleName, ModulePtr module, TypeArena* arena, - NotNull moduleResolver, NotNull builtinTypes, NotNull ice, const ScopePtr& globalScope, - DcrLogger* logger, NotNull dfg) - : moduleName(moduleName) - , module(module) +ConstraintGraphBuilder::ConstraintGraphBuilder(ModulePtr module, TypeArena* arena, NotNull moduleResolver, + NotNull builtinTypes, NotNull ice, const ScopePtr& globalScope, DcrLogger* logger, + NotNull dfg) + : module(module) , builtinTypes(builtinTypes) , arena(arena) , rootScope(nullptr) @@ -599,7 +598,7 @@ ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocal* l { AstExpr* require = *maybeRequire; - if (auto moduleInfo = moduleResolver->resolveModuleInfo(moduleName, *require)) + if (auto moduleInfo = moduleResolver->resolveModuleInfo(module->name, *require)) { const Name name{local->vars.data[i]->name.value}; @@ -1043,7 +1042,7 @@ ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareC Name className(declaredClass->name.value); - TypeId classTy = arena->addType(ClassType(className, {}, superTy, std::nullopt, {}, {}, moduleName)); + TypeId classTy = arena->addType(ClassType(className, {}, superTy, std::nullopt, {}, {}, module->name)); ClassType* ctv = getMutable(classTy); TypeId metaTy = arena->addType(TableType{TableState::Sealed, scope->level, scope.get()}); @@ -2609,7 +2608,7 @@ Inference ConstraintGraphBuilder::flattenPack(const ScopePtr& scope, Location lo void ConstraintGraphBuilder::reportError(Location location, TypeErrorData err) { - errors.push_back(TypeError{location, moduleName, std::move(err)}); + errors.push_back(TypeError{location, module->name, std::move(err)}); if (logger) logger->captureGenerationError(errors.back()); @@ -2617,7 +2616,7 @@ void ConstraintGraphBuilder::reportError(Location location, TypeErrorData err) void ConstraintGraphBuilder::reportCodeTooComplex(Location location) { - errors.push_back(TypeError{location, moduleName, CodeTooComplex{}}); + errors.push_back(TypeError{location, module->name, CodeTooComplex{}}); if (logger) logger->captureGenerationError(errors.back()); diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index 0fc32c33d..558ad2d51 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -18,6 +18,7 @@ #include "Luau/VisitType.h" LUAU_FASTFLAGVARIABLE(DebugLuauLogSolver, false); +LUAU_FASTFLAG(LuauRequirePathTrueModuleName) namespace Luau { @@ -1989,7 +1990,7 @@ static TypePackId getErrorType(NotNull builtinTypes, TypePackId) template bool ConstraintSolver::tryUnify(NotNull constraint, TID subTy, TID superTy) { - Unifier u{normalizer, Mode::Strict, constraint->scope, Location{}, Covariant}; + Unifier u{normalizer, Mode::Strict, constraint->scope, constraint->location, Covariant}; u.useScopes = true; u.tryUnify(subTy, superTy); @@ -2257,11 +2258,9 @@ TypeId ConstraintSolver::resolveModule(const ModuleInfo& info, const Location& l return errorRecoveryType(); } - std::string humanReadableName = moduleResolver->getHumanReadableModuleName(info.name); - for (const auto& [location, path] : requireCycles) { - if (!path.empty() && path.front() == humanReadableName) + if (!path.empty() && path.front() == (FFlag::LuauRequirePathTrueModuleName ? info.name : moduleResolver->getHumanReadableModuleName(info.name))) return builtinTypes->anyType; } @@ -2269,14 +2268,14 @@ TypeId ConstraintSolver::resolveModule(const ModuleInfo& info, const Location& l if (!module) { if (!moduleResolver->moduleExists(info.name) && !info.optional) - reportError(UnknownRequire{humanReadableName}, location); + reportError(UnknownRequire{moduleResolver->getHumanReadableModuleName(info.name)}, location); return errorRecoveryType(); } if (module->type != SourceCode::Type::Module) { - reportError(IllegalRequire{humanReadableName, "Module is not a ModuleScript. It cannot be required."}, location); + reportError(IllegalRequire{module->humanReadableName, "Module is not a ModuleScript. It cannot be required."}, location); return errorRecoveryType(); } @@ -2287,7 +2286,7 @@ TypeId ConstraintSolver::resolveModule(const ModuleInfo& info, const Location& l std::optional moduleType = first(modulePack); if (!moduleType) { - reportError(IllegalRequire{humanReadableName, "Module does not return exactly 1 value. It cannot be required."}, location); + reportError(IllegalRequire{module->humanReadableName, "Module does not return exactly 1 value. It cannot be required."}, location); return errorRecoveryType(); } diff --git a/Analysis/src/Error.cpp b/Analysis/src/Error.cpp index 84b9cb37d..1e0379729 100644 --- a/Analysis/src/Error.cpp +++ b/Analysis/src/Error.cpp @@ -11,6 +11,7 @@ #include LUAU_FASTFLAGVARIABLE(LuauTypeMismatchInvarianceInError, false) +LUAU_FASTFLAGVARIABLE(LuauRequirePathTrueModuleName, false) static std::string wrongNumberOfArgsString( size_t expectedCount, std::optional maximumCount, size_t actualCount, const char* argPrefix = nullptr, bool isVariadic = false) @@ -349,7 +350,10 @@ struct ErrorConverter else s += " -> "; - s += name; + if (FFlag::LuauRequirePathTrueModuleName && fileResolver != nullptr) + s += fileResolver->getHumanReadableModuleName(name); + else + s += name; } return s; diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 5beb6c4e1..916dd1d57 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -33,6 +33,7 @@ LUAU_FASTINTVARIABLE(LuauAutocompleteCheckTimeoutMs, 100) LUAU_FASTFLAGVARIABLE(DebugLuauDeferredConstraintResolution, false) LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJson, false) LUAU_FASTFLAGVARIABLE(LuauOnDemandTypecheckers, false) +LUAU_FASTFLAG(LuauRequirePathTrueModuleName) namespace Luau { @@ -245,7 +246,7 @@ namespace { static ErrorVec accumulateErrors( - const std::unordered_map& sourceNodes, const std::unordered_map& modules, const ModuleName& name) + const std::unordered_map& sourceNodes, ModuleResolver& moduleResolver, const ModuleName& name) { std::unordered_set seen; std::vector queue{name}; @@ -271,11 +272,11 @@ static ErrorVec accumulateErrors( // FIXME: If a module has a syntax error, we won't be able to re-report it here. // The solution is probably to move errors from Module to SourceNode - auto it2 = modules.find(next); - if (it2 == modules.end()) + auto modulePtr = moduleResolver.getModule(next); + if (!modulePtr) continue; - Module& module = *it2->second; + Module& module = *modulePtr; std::sort(module.errors.begin(), module.errors.end(), [](const TypeError& e1, const TypeError& e2) -> bool { return e1.location.begin > e2.location.begin; @@ -345,9 +346,9 @@ std::vector getRequireCycles( if (top == start) { for (const SourceNode* node : path) - cycle.push_back(resolver->getHumanReadableModuleName(node->name)); + cycle.push_back(FFlag::LuauRequirePathTrueModuleName ? node->name : node->humanReadableName); - cycle.push_back(resolver->getHumanReadableModuleName(top->name)); + cycle.push_back(FFlag::LuauRequirePathTrueModuleName ? top->name : top->humanReadableName); break; } } @@ -415,11 +416,6 @@ Frontend::Frontend(FileResolver* fileResolver, ConfigResolver* configResolver, c { } -FrontendModuleResolver::FrontendModuleResolver(Frontend* frontend) - : frontend(frontend) -{ -} - CheckResult Frontend::check(const ModuleName& name, std::optional optionOverride) { LUAU_TIMETRACE_SCOPE("Frontend::check", "Frontend"); @@ -428,31 +424,21 @@ CheckResult Frontend::check(const ModuleName& name, std::optionalsecond.hasDirtyModule(frontendOptions.forAutocomplete)) { // No recheck required. - if (frontendOptions.forAutocomplete) - { - auto it2 = moduleResolverForAutocomplete.modules.find(name); - if (it2 == moduleResolverForAutocomplete.modules.end() || it2->second == nullptr) - throw InternalCompilerError("Frontend::modules does not have data for " + name, name); - } - else - { - auto it2 = moduleResolver.modules.find(name); - if (it2 == moduleResolver.modules.end() || it2->second == nullptr) - throw InternalCompilerError("Frontend::modules does not have data for " + name, name); - } + ModulePtr module = resolver.getModule(name); - std::unordered_map& modules = - frontendOptions.forAutocomplete ? moduleResolverForAutocomplete.modules : moduleResolver.modules; + if (!module) + throw InternalCompilerError("Frontend::modules does not have data for " + name, name); - checkResult.errors = accumulateErrors(sourceNodes, modules, name); + checkResult.errors = accumulateErrors(sourceNodes, resolver, name); // Get lint result only for top checked module - if (auto it = modules.find(name); it != modules.end()) - checkResult.lintResult = it->second->lintResult; + checkResult.lintResult = module->lintResult; return checkResult; } @@ -556,7 +542,7 @@ CheckResult Frontend::check(const ModuleName& name, std::optionalerrors.begin(), module->errors.end()); - moduleResolver.modules[moduleName] = std::move(module); + resolver.setModule(moduleName, std::move(module)); sourceNode.dirtyModule = false; } // Get lint result only for top checked module - std::unordered_map& modules = - frontendOptions.forAutocomplete ? moduleResolverForAutocomplete.modules : moduleResolver.modules; - - if (auto it = modules.find(name); it != modules.end()) - checkResult.lintResult = it->second->lintResult; + if (ModulePtr module = resolver.getModule(name)) + checkResult.lintResult = module->lintResult; return checkResult; } @@ -817,7 +800,7 @@ bool Frontend::isDirty(const ModuleName& name, bool forAutocomplete) const */ void Frontend::markDirty(const ModuleName& name, std::vector* markedDirty) { - if (!moduleResolver.modules.count(name) && !moduleResolverForAutocomplete.modules.count(name)) + if (!moduleResolver.getModule(name) && !moduleResolverForAutocomplete.getModule(name)) return; std::unordered_map> reverseDeps; @@ -884,13 +867,15 @@ ModulePtr check(const SourceModule& sourceModule, const std::vector(); + result->name = sourceModule.name; + result->humanReadableName = sourceModule.humanReadableName; result->reduction = std::make_unique(NotNull{&result->internalTypes}, builtinTypes, iceHandler); std::unique_ptr logger; if (recordJsonLog) { logger = std::make_unique(); - std::optional source = fileResolver->readSource(sourceModule.name); + std::optional source = fileResolver->readSource(result->name); if (source) { logger->captureSource(source->source); @@ -906,7 +891,6 @@ ModulePtr check(const SourceModule& sourceModule, const std::vectorinternalTypes, builtinTypes, NotNull{&unifierState}}; ConstraintGraphBuilder cgb{ - sourceModule.name, result, &result->internalTypes, moduleResolver, @@ -920,8 +904,8 @@ ModulePtr check(const SourceModule& sourceModule, const std::vectorerrors = std::move(cgb.errors); - ConstraintSolver cs{NotNull{&normalizer}, NotNull(cgb.rootScope), borrowConstraints(cgb.constraints), sourceModule.name, moduleResolver, - requireCycles, logger.get()}; + ConstraintSolver cs{ + NotNull{&normalizer}, NotNull(cgb.rootScope), borrowConstraints(cgb.constraints), result->name, moduleResolver, requireCycles, logger.get()}; if (options.randomizeConstraintResolutionSeed) cs.randomize(*options.randomizeConstraintResolutionSeed); @@ -936,7 +920,7 @@ ModulePtr check(const SourceModule& sourceModule, const std::vectorclonePublicInterface(builtinTypes, *iceHandler); - Luau::check(builtinTypes, logger.get(), sourceModule, result.get()); + Luau::check(builtinTypes, NotNull{&unifierState}, logger.get(), sourceModule, result.get()); // Ideally we freeze the arenas before the call into Luau::check, but TypeReduction // needs to allocate new types while Luau::check is in progress, so here we are. @@ -1033,7 +1017,8 @@ std::pair Frontend::getSourceNode(const ModuleName& sourceModule = std::move(result); sourceModule.environmentName = environmentName; - sourceNode.name = name; + sourceNode.name = sourceModule.name; + sourceNode.humanReadableName = sourceModule.humanReadableName; sourceNode.requireSet.clear(); sourceNode.requireLocations.clear(); sourceNode.dirtySourceModule = false; @@ -1095,6 +1080,7 @@ SourceModule Frontend::parse(const ModuleName& name, std::string_view src, const } sourceModule.name = name; + sourceModule.humanReadableName = fileResolver->getHumanReadableModuleName(name); if (parseOptions.captureComments) { @@ -1105,6 +1091,12 @@ SourceModule Frontend::parse(const ModuleName& name, std::string_view src, const return sourceModule; } + +FrontendModuleResolver::FrontendModuleResolver(Frontend* frontend) + : frontend(frontend) +{ +} + std::optional FrontendModuleResolver::resolveModuleInfo(const ModuleName& currentModuleName, const AstExpr& pathExpr) { // FIXME I think this can be pushed into the FileResolver. @@ -1129,6 +1121,8 @@ std::optional FrontendModuleResolver::resolveModuleInfo(const Module const ModulePtr FrontendModuleResolver::getModule(const ModuleName& moduleName) const { + std::scoped_lock lock(moduleMutex); + auto it = modules.find(moduleName); if (it != modules.end()) return it->second; @@ -1146,6 +1140,20 @@ std::string FrontendModuleResolver::getHumanReadableModuleName(const ModuleName& return frontend->fileResolver->getHumanReadableModuleName(moduleName); } +void FrontendModuleResolver::setModule(const ModuleName& moduleName, ModulePtr module) +{ + std::scoped_lock lock(moduleMutex); + + modules[moduleName] = std::move(module); +} + +void FrontendModuleResolver::clearModules() +{ + std::scoped_lock lock(moduleMutex); + + modules.clear(); +} + ScopePtr Frontend::addEnvironment(const std::string& environmentName) { LUAU_ASSERT(environments.count(environmentName) == 0); @@ -1208,8 +1216,8 @@ void Frontend::clear() { sourceNodes.clear(); sourceModules.clear(); - moduleResolver.modules.clear(); - moduleResolverForAutocomplete.modules.clear(); + moduleResolver.clearModules(); + moduleResolverForAutocomplete.clearModules(); requireTrace.clear(); } diff --git a/Analysis/src/Substitution.cpp b/Analysis/src/Substitution.cpp index 935d85d71..962172172 100644 --- a/Analysis/src/Substitution.cpp +++ b/Analysis/src/Substitution.cpp @@ -257,7 +257,7 @@ void Tarjan::visitChildren(TypeId ty, int index) } else if (const ClassType* ctv = get(ty); FFlag::LuauClassTypeVarsInSubstitution && ctv) { - for (auto [name, prop] : ctv->props) + for (const auto& [name, prop] : ctv->props) visitChild(prop.type); if (ctv->parent) diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index fe09ef11a..46d2e8f8f 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -834,8 +834,15 @@ struct TypeStringifier void operator()(TypeId, const LazyType& ltv) { - state.result.invalid = true; - state.emit("lazy?"); + if (TypeId unwrapped = ltv.unwrapped.load()) + { + stringify(unwrapped); + } + else + { + state.result.invalid = true; + state.emit("lazy?"); + } } void operator()(TypeId, const UnknownType& ttv) diff --git a/Analysis/src/Type.cpp b/Analysis/src/Type.cpp index 528541083..e4d9ab33f 100644 --- a/Analysis/src/Type.cpp +++ b/Analysis/src/Type.cpp @@ -26,6 +26,7 @@ LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0) LUAU_FASTINT(LuauTypeInferRecursionLimit) LUAU_FASTFLAG(LuauInstantiateInSubtyping) LUAU_FASTFLAG(LuauNormalizeBlockedTypes) +LUAU_FASTFLAGVARIABLE(LuauBoundLazyTypes, false) namespace Luau { @@ -56,18 +57,51 @@ TypeId follow(TypeId t) TypeId follow(TypeId t, std::function mapper) { auto advance = [&mapper](TypeId ty) -> std::optional { - if (auto btv = get>(mapper(ty))) - return btv->boundTo; - else if (auto ttv = get(mapper(ty))) - return ttv->boundTo; - else + if (FFlag::LuauBoundLazyTypes) + { + TypeId mapped = mapper(ty); + + if (auto btv = get>(mapped)) + return btv->boundTo; + + if (auto ttv = get(mapped)) + return ttv->boundTo; + + if (auto ltv = getMutable(mapped)) + { + TypeId unwrapped = ltv->unwrapped.load(); + + if (unwrapped) + return unwrapped; + + unwrapped = ltv->unwrap(*ltv); + + if (!unwrapped) + throw InternalCompilerError("Lazy Type didn't fill in unwrapped type field"); + + if (get(unwrapped)) + throw InternalCompilerError("Lazy Type cannot resolve to another Lazy Type"); + + return unwrapped; + } + return std::nullopt; + } + else + { + if (auto btv = get>(mapper(ty))) + return btv->boundTo; + else if (auto ttv = get(mapper(ty))) + return ttv->boundTo; + else + return std::nullopt; + } }; auto force = [&mapper](TypeId ty) { if (auto ltv = get_if(&mapper(ty)->ty)) { - TypeId res = ltv->thunk(); + TypeId res = ltv->thunk_DEPRECATED(); if (get(res)) throw InternalCompilerError("Lazy Type cannot resolve to another Lazy Type"); @@ -75,7 +109,8 @@ TypeId follow(TypeId t, std::function mapper) } }; - force(t); + if (!FFlag::LuauBoundLazyTypes) + force(t); TypeId cycleTester = t; // Null once we've determined that there is no cycle if (auto a = advance(cycleTester)) @@ -85,7 +120,9 @@ TypeId follow(TypeId t, std::function mapper) while (true) { - force(t); + if (!FFlag::LuauBoundLazyTypes) + force(t); + auto a1 = advance(t); if (a1) t = *a1; diff --git a/Analysis/src/TypeAttach.cpp b/Analysis/src/TypeAttach.cpp index d6494edfd..7ed4eb49b 100644 --- a/Analysis/src/TypeAttach.cpp +++ b/Analysis/src/TypeAttach.cpp @@ -94,15 +94,15 @@ class TypeRehydrationVisitor switch (ptv.type) { case PrimitiveType::NilType: - return allocator->alloc(Location(), std::nullopt, AstName("nil")); + return allocator->alloc(Location(), std::nullopt, AstName("nil"), std::nullopt, Location()); case PrimitiveType::Boolean: - return allocator->alloc(Location(), std::nullopt, AstName("boolean")); + return allocator->alloc(Location(), std::nullopt, AstName("boolean"), std::nullopt, Location()); case PrimitiveType::Number: - return allocator->alloc(Location(), std::nullopt, AstName("number")); + return allocator->alloc(Location(), std::nullopt, AstName("number"), std::nullopt, Location()); case PrimitiveType::String: - return allocator->alloc(Location(), std::nullopt, AstName("string")); + return allocator->alloc(Location(), std::nullopt, AstName("string"), std::nullopt, Location()); case PrimitiveType::Thread: - return allocator->alloc(Location(), std::nullopt, AstName("thread")); + return allocator->alloc(Location(), std::nullopt, AstName("thread"), std::nullopt, Location()); default: return nullptr; } @@ -110,12 +110,12 @@ class TypeRehydrationVisitor AstType* operator()(const BlockedType& btv) { - return allocator->alloc(Location(), std::nullopt, AstName("*blocked*")); + return allocator->alloc(Location(), std::nullopt, AstName("*blocked*"), std::nullopt, Location()); } AstType* operator()(const PendingExpansionType& petv) { - return allocator->alloc(Location(), std::nullopt, AstName("*pending-expansion*")); + return allocator->alloc(Location(), std::nullopt, AstName("*pending-expansion*"), std::nullopt, Location()); } AstType* operator()(const SingletonType& stv) @@ -135,7 +135,7 @@ class TypeRehydrationVisitor AstType* operator()(const AnyType&) { - return allocator->alloc(Location(), std::nullopt, AstName("any")); + return allocator->alloc(Location(), std::nullopt, AstName("any"), std::nullopt, Location()); } AstType* operator()(const TableType& ttv) { @@ -157,15 +157,16 @@ class TypeRehydrationVisitor parameters.data[i] = {{}, rehydrate(ttv.instantiatedTypePackParams[i])}; } - return allocator->alloc(Location(), std::nullopt, AstName(ttv.name->c_str()), parameters.size != 0, parameters); + return allocator->alloc( + Location(), std::nullopt, AstName(ttv.name->c_str()), std::nullopt, Location(), parameters.size != 0, parameters); } if (hasSeen(&ttv)) { if (ttv.name) - return allocator->alloc(Location(), std::nullopt, AstName(ttv.name->c_str())); + return allocator->alloc(Location(), std::nullopt, AstName(ttv.name->c_str()), std::nullopt, Location()); else - return allocator->alloc(Location(), std::nullopt, AstName("")); + return allocator->alloc(Location(), std::nullopt, AstName(""), std::nullopt, Location()); } AstArray props; @@ -208,7 +209,7 @@ class TypeRehydrationVisitor char* name = allocateString(*allocator, ctv.name); if (!options.expandClassProps || hasSeen(&ctv) || count > 1) - return allocator->alloc(Location(), std::nullopt, AstName{name}); + return allocator->alloc(Location(), std::nullopt, AstName{name}, std::nullopt, Location()); AstArray props; props.size = ctv.props.size(); @@ -233,7 +234,7 @@ class TypeRehydrationVisitor RecursionCounter counter(&count); if (hasSeen(&ftv)) - return allocator->alloc(Location(), std::nullopt, AstName("")); + return allocator->alloc(Location(), std::nullopt, AstName(""), std::nullopt, Location()); AstArray generics; generics.size = ftv.generics.size(); @@ -304,11 +305,12 @@ class TypeRehydrationVisitor } AstType* operator()(const Unifiable::Error&) { - return allocator->alloc(Location(), std::nullopt, AstName("Unifiable")); + return allocator->alloc(Location(), std::nullopt, AstName("Unifiable"), std::nullopt, Location()); } AstType* operator()(const GenericType& gtv) { - return allocator->alloc(Location(), std::nullopt, AstName(getName(allocator, syntheticNames, gtv))); + return allocator->alloc( + Location(), std::nullopt, AstName(getName(allocator, syntheticNames, gtv)), std::nullopt, Location()); } AstType* operator()(const Unifiable::Bound& bound) { @@ -316,7 +318,7 @@ class TypeRehydrationVisitor } AstType* operator()(const FreeType& ftv) { - return allocator->alloc(Location(), std::nullopt, AstName("free")); + return allocator->alloc(Location(), std::nullopt, AstName("free"), std::nullopt, Location()); } AstType* operator()(const UnionType& uv) { @@ -342,15 +344,18 @@ class TypeRehydrationVisitor } AstType* operator()(const LazyType& ltv) { - return allocator->alloc(Location(), std::nullopt, AstName("")); + if (TypeId unwrapped = ltv.unwrapped.load()) + return Luau::visit(*this, unwrapped->ty); + + return allocator->alloc(Location(), std::nullopt, AstName(""), std::nullopt, Location()); } AstType* operator()(const UnknownType& ttv) { - return allocator->alloc(Location(), std::nullopt, AstName{"unknown"}); + return allocator->alloc(Location(), std::nullopt, AstName{"unknown"}, std::nullopt, Location()); } AstType* operator()(const NeverType& ttv) { - return allocator->alloc(Location(), std::nullopt, AstName{"never"}); + return allocator->alloc(Location(), std::nullopt, AstName{"never"}, std::nullopt, Location()); } AstType* operator()(const NegationType& ntv) { diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index 6e76af042..893f51d97 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -88,21 +88,22 @@ struct TypeChecker2 { NotNull builtinTypes; DcrLogger* logger; - InternalErrorReporter ice; // FIXME accept a pointer from Frontend + NotNull ice; const SourceModule* sourceModule; Module* module; TypeArena testArena; std::vector> stack; - UnifierSharedState sharedState{&ice}; - Normalizer normalizer{&testArena, builtinTypes, NotNull{&sharedState}}; + Normalizer normalizer; - TypeChecker2(NotNull builtinTypes, DcrLogger* logger, const SourceModule* sourceModule, Module* module) + TypeChecker2(NotNull builtinTypes, NotNull unifierState, DcrLogger* logger, const SourceModule* sourceModule, Module* module) : builtinTypes(builtinTypes) , logger(logger) + , ice(unifierState->iceHandler) , sourceModule(sourceModule) , module(module) + , normalizer{&testArena, builtinTypes, unifierState} { } @@ -996,7 +997,7 @@ struct TypeChecker2 } if (!fst) - ice.ice("UnionType had no elements, so fst is nullopt?"); + ice->ice("UnionType had no elements, so fst is nullopt?"); if (std::optional instantiatedFunctionType = instantiation.substitute(*fst)) { @@ -1018,7 +1019,7 @@ struct TypeChecker2 { AstExprIndexName* indexExpr = call->func->as(); if (!indexExpr) - ice.ice("method call expression has no 'self'"); + ice->ice("method call expression has no 'self'"); args.head.push_back(lookupType(indexExpr->expr)); argLocs.push_back(indexExpr->expr->location); @@ -1646,7 +1647,7 @@ struct TypeChecker2 else if (finite(pack) && size(pack) == 0) return builtinTypes->nilType; // `(f())` where `f()` returns no values is coerced into `nil` else - ice.ice("flattenPack got a weird pack!"); + ice->ice("flattenPack got a weird pack!"); } void visitGenerics(AstArray generics, AstArray genericPacks) @@ -2012,7 +2013,7 @@ struct TypeChecker2 void reportError(TypeErrorData data, const Location& location) { - module->errors.emplace_back(location, sourceModule->name, std::move(data)); + module->errors.emplace_back(location, module->name, std::move(data)); if (logger) logger->captureTypeCheckError(module->errors.back()); @@ -2160,9 +2161,9 @@ struct TypeChecker2 } }; -void check(NotNull builtinTypes, DcrLogger* logger, const SourceModule& sourceModule, Module* module) +void check(NotNull builtinTypes, NotNull unifierState, DcrLogger* logger, const SourceModule& sourceModule, Module* module) { - TypeChecker2 typeChecker{builtinTypes, logger, &sourceModule, module}; + TypeChecker2 typeChecker{builtinTypes, unifierState, logger, &sourceModule, module}; typeChecker.reduceTypes(); typeChecker.visit(sourceModule.root); diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 7f366a204..a8c093a43 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -33,7 +33,6 @@ LUAU_FASTINTVARIABLE(LuauCheckRecursionLimit, 300) LUAU_FASTINTVARIABLE(LuauVisitRecursionLimit, 500) LUAU_FASTFLAG(LuauKnowsTheDataModel3) LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) -LUAU_FASTFLAGVARIABLE(LuauReturnAnyInsteadOfICE, false) // Eventually removed as false. LUAU_FASTFLAGVARIABLE(DebugLuauSharedSelf, false) LUAU_FASTFLAG(LuauInstantiateInSubtyping) LUAU_FASTFLAG(LuauNegatedClassTypes) @@ -42,6 +41,7 @@ LUAU_FASTFLAG(LuauUninhabitedSubAnything2) LUAU_FASTFLAG(LuauOccursIsntAlwaysFailure) LUAU_FASTFLAGVARIABLE(LuauTypecheckTypeguards, false) LUAU_FASTFLAGVARIABLE(LuauTinyControlFlowAnalysis, false) +LUAU_FASTFLAG(LuauRequirePathTrueModuleName) namespace Luau { @@ -264,8 +264,11 @@ ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mo { LUAU_TIMETRACE_SCOPE("TypeChecker::check", "TypeChecker"); LUAU_TIMETRACE_ARGUMENT("module", module.name.c_str()); + LUAU_TIMETRACE_ARGUMENT("name", module.humanReadableName.c_str()); currentModule.reset(new Module); + currentModule->name = module.name; + currentModule->humanReadableName = module.humanReadableName; currentModule->reduction = std::make_unique(NotNull{¤tModule->internalTypes}, builtinTypes, NotNull{iceHandler}); currentModule->type = module.type; currentModule->allocator = module.allocator; @@ -290,10 +293,8 @@ ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mo currentModule->scopes.push_back(std::make_pair(module.root->location, moduleScope)); currentModule->mode = mode; - currentModuleName = module.name; - if (prepareModuleScope) - prepareModuleScope(module.name, currentModule->getModuleScope()); + prepareModuleScope(currentModule->name, currentModule->getModuleScope()); try { @@ -1179,7 +1180,7 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatLocal& local) { AstExpr* require = *maybeRequire; - if (auto moduleInfo = resolver->resolveModuleInfo(currentModuleName, *require)) + if (auto moduleInfo = resolver->resolveModuleInfo(currentModule->name, *require)) { const Name name{local.vars.data[i]->name.value}; @@ -1728,7 +1729,7 @@ void TypeChecker::prototype(const ScopePtr& scope, const AstStatDeclareClass& de Name className(declaredClass.name.value); - TypeId classTy = addType(ClassType(className, {}, superTy, std::nullopt, {}, {}, currentModuleName)); + TypeId classTy = addType(ClassType(className, {}, superTy, std::nullopt, {}, {}, currentModule->name)); ClassType* ctv = getMutable(classTy); TypeId metaTy = addType(TableType{TableState::Sealed, scope->level}); @@ -2000,12 +2001,7 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp else if (auto vtp = get(retPack)) return {vtp->ty, std::move(result.predicates)}; else if (get(retPack)) - { - if (FFlag::LuauReturnAnyInsteadOfICE) - return {anyType, std::move(result.predicates)}; - else - ice("Unexpected abstract type pack!", expr.location); - } + return {anyType, std::move(result.predicates)}; else ice("Unknown TypePack type!", expr.location); } @@ -2336,7 +2332,7 @@ TypeId TypeChecker::checkExprTable( TableState state = TableState::Unsealed; TableType table = TableType{std::move(props), indexer, scope->level, state}; - table.definitionModuleName = currentModuleName; + table.definitionModuleName = currentModule->name; table.definitionLocation = expr.location; return addType(table); } @@ -3663,7 +3659,7 @@ std::pair TypeChecker::checkFunctionSignature(const ScopePtr& TypePackId argPack = addTypePack(TypePackVar(TypePack{argTypes, funScope->varargPack})); FunctionDefinition defn; - defn.definitionModuleName = currentModuleName; + defn.definitionModuleName = currentModule->name; defn.definitionLocation = expr.location; defn.varargLocation = expr.vararg ? std::make_optional(expr.varargLocation) : std::nullopt; defn.originalNameLocation = originalName.value_or(Location(expr.location.begin, 0)); @@ -4606,11 +4602,9 @@ TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& module } // Types of requires that transitively refer to current module have to be replaced with 'any' - std::string humanReadableName = resolver->getHumanReadableModuleName(moduleInfo.name); - for (const auto& [location, path] : requireCycles) { - if (!path.empty() && path.front() == humanReadableName) + if (!path.empty() && path.front() == (FFlag::LuauRequirePathTrueModuleName ? moduleInfo.name : resolver->getHumanReadableModuleName(moduleInfo.name))) return anyType; } @@ -4621,14 +4615,14 @@ TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& module // either the file does not exist or there's a cycle. If there's a cycle // we will already have reported the error. if (!resolver->moduleExists(moduleInfo.name) && !moduleInfo.optional) - reportError(TypeError{location, UnknownRequire{humanReadableName}}); + reportError(TypeError{location, UnknownRequire{resolver->getHumanReadableModuleName(moduleInfo.name)}}); return errorRecoveryType(scope); } if (module->type != SourceCode::Module) { - reportError(location, IllegalRequire{humanReadableName, "Module is not a ModuleScript. It cannot be required."}); + reportError(location, IllegalRequire{module->humanReadableName, "Module is not a ModuleScript. It cannot be required."}); return errorRecoveryType(scope); } @@ -4640,7 +4634,7 @@ TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& module std::optional moduleType = first(modulePack); if (!moduleType) { - reportError(location, IllegalRequire{humanReadableName, "Module does not return exactly 1 value. It cannot be required."}); + reportError(location, IllegalRequire{module->humanReadableName, "Module does not return exactly 1 value. It cannot be required."}); return errorRecoveryType(scope); } @@ -4855,7 +4849,7 @@ void TypeChecker::reportError(const TypeError& error) if (currentModule->mode == Mode::NoCheck) return; currentModule->errors.push_back(error); - currentModule->errors.back().moduleName = currentModuleName; + currentModule->errors.back().moduleName = currentModule->name; } void TypeChecker::reportError(const Location& location, TypeErrorData errorData) @@ -5329,7 +5323,7 @@ TypeId TypeChecker::resolveTypeWorker(const ScopePtr& scope, const AstType& anno tableIndexer = TableIndexer(resolveType(scope, *indexer->indexType), resolveType(scope, *indexer->resultType)); TableType ttv{props, tableIndexer, scope->level, TableState::Sealed}; - ttv.definitionModuleName = currentModuleName; + ttv.definitionModuleName = currentModule->name; ttv.definitionLocation = annotation.location; return addType(std::move(ttv)); } @@ -5531,7 +5525,7 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, { ttv->instantiatedTypeParams = typeParams; ttv->instantiatedTypePackParams = typePackParams; - ttv->definitionModuleName = currentModuleName; + ttv->definitionModuleName = currentModule->name; ttv->definitionLocation = location; } diff --git a/Analysis/src/TypeReduction.cpp b/Analysis/src/TypeReduction.cpp index 310df766f..031c59936 100644 --- a/Analysis/src/TypeReduction.cpp +++ b/Analysis/src/TypeReduction.cpp @@ -10,7 +10,7 @@ #include LUAU_FASTINTVARIABLE(LuauTypeReductionCartesianProductLimit, 100'000) -LUAU_FASTINTVARIABLE(LuauTypeReductionRecursionLimit, 400) +LUAU_FASTINTVARIABLE(LuauTypeReductionRecursionLimit, 300) LUAU_FASTFLAGVARIABLE(DebugLuauDontReduceTypes, false) namespace Luau diff --git a/Ast/include/Luau/Ast.h b/Ast/include/Luau/Ast.h index 9c352f9dc..20158e8eb 100644 --- a/Ast/include/Luau/Ast.h +++ b/Ast/include/Luau/Ast.h @@ -841,14 +841,16 @@ class AstTypeReference : public AstType public: LUAU_RTTI(AstTypeReference) - AstTypeReference(const Location& location, std::optional prefix, AstName name, bool hasParameterList = false, - const AstArray& parameters = {}); + AstTypeReference(const Location& location, std::optional prefix, AstName name, std::optional prefixLocation, + const Location& nameLocation, bool hasParameterList = false, const AstArray& parameters = {}); void visit(AstVisitor* visitor) override; bool hasParameterList; std::optional prefix; + std::optional prefixLocation; AstName name; + Location nameLocation; AstArray parameters; }; diff --git a/Ast/src/Ast.cpp b/Ast/src/Ast.cpp index e01ced049..d2c552a3c 100644 --- a/Ast/src/Ast.cpp +++ b/Ast/src/Ast.cpp @@ -753,12 +753,14 @@ void AstStatError::visit(AstVisitor* visitor) } } -AstTypeReference::AstTypeReference( - const Location& location, std::optional prefix, AstName name, bool hasParameterList, const AstArray& parameters) +AstTypeReference::AstTypeReference(const Location& location, std::optional prefix, AstName name, std::optional prefixLocation, + const Location& nameLocation, bool hasParameterList, const AstArray& parameters) : AstType(ClassIndex(), location) , hasParameterList(hasParameterList) , prefix(prefix) + , prefixLocation(prefixLocation) , name(name) + , nameLocation(nameLocation) , parameters(parameters) { } diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 40fa754e6..6a76eda22 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -1343,7 +1343,7 @@ AstType* Parser::parseTableType() AstType* type = parseType(); // array-like table type: {T} desugars into {[number]: T} - AstType* index = allocator.alloc(type->location, std::nullopt, nameNumber); + AstType* index = allocator.alloc(type->location, std::nullopt, nameNumber, std::nullopt, type->location); indexer = allocator.alloc(AstTableIndexer{index, type, type->location}); break; @@ -1449,7 +1449,7 @@ AstType* Parser::parseFunctionTypeTail(const Lexeme& begin, AstArray' after '()' when parsing function type; did you mean 'nil'?"); - return allocator.alloc(begin.location, std::nullopt, nameNil); + return allocator.alloc(begin.location, std::nullopt, nameNil, std::nullopt, begin.location); } else { @@ -1493,7 +1493,7 @@ AstType* Parser::parseTypeSuffix(AstType* type, const Location& begin) { Location loc = lexer.current().location; nextLexeme(); - parts.push_back(allocator.alloc(loc, std::nullopt, nameNil)); + parts.push_back(allocator.alloc(loc, std::nullopt, nameNil, std::nullopt, loc)); isUnion = true; } else if (c == '&') @@ -1577,7 +1577,7 @@ AstTypeOrPack Parser::parseSimpleType(bool allowPack) if (lexer.current().type == Lexeme::ReservedNil) { nextLexeme(); - return {allocator.alloc(start, std::nullopt, nameNil), {}}; + return {allocator.alloc(start, std::nullopt, nameNil, std::nullopt, start), {}}; } else if (lexer.current().type == Lexeme::ReservedTrue) { @@ -1613,6 +1613,7 @@ AstTypeOrPack Parser::parseSimpleType(bool allowPack) else if (lexer.current().type == Lexeme::Name) { std::optional prefix; + std::optional prefixLocation; Name name = parseName("type name"); if (lexer.current().type == '.') @@ -1621,6 +1622,7 @@ AstTypeOrPack Parser::parseSimpleType(bool allowPack) nextLexeme(); prefix = name.name; + prefixLocation = name.location; name = parseIndexName("field name", pointPosition); } else if (lexer.current().type == Lexeme::Dot3) @@ -1653,7 +1655,8 @@ AstTypeOrPack Parser::parseSimpleType(bool allowPack) Location end = lexer.previousLocation(); - return {allocator.alloc(Location(start, end), prefix, name.name, hasParameters, parameters), {}}; + return { + allocator.alloc(Location(start, end), prefix, name.name, prefixLocation, name.location, hasParameters, parameters), {}}; } else if (lexer.current().type == '{') { diff --git a/CLI/Repl.cpp b/CLI/Repl.cpp index 63baea8b7..bcf70f250 100644 --- a/CLI/Repl.cpp +++ b/CLI/Repl.cpp @@ -8,6 +8,7 @@ #include "Luau/Compiler.h" #include "Luau/BytecodeBuilder.h" #include "Luau/Parser.h" +#include "Luau/TimeTrace.h" #include "Coverage.h" #include "FileUtils.h" @@ -997,15 +998,18 @@ int replMain(int argc, char** argv) CompileStats stats = {}; int failed = 0; + double startTime = Luau::TimeTrace::getClock(); for (const std::string& path : files) failed += !compileFile(path.c_str(), compileFormat, stats); + double duration = Luau::TimeTrace::getClock() - startTime; + if (compileFormat == CompileFormat::Null) - printf("Compiled %d KLOC into %d KB bytecode\n", int(stats.lines / 1000), int(stats.bytecode / 1024)); + printf("Compiled %d KLOC into %d KB bytecode in %.2fs\n", int(stats.lines / 1000), int(stats.bytecode / 1024), duration); else if (compileFormat == CompileFormat::CodegenNull) - printf("Compiled %d KLOC into %d KB bytecode => %d KB native code\n", int(stats.lines / 1000), int(stats.bytecode / 1024), - int(stats.codegen / 1024)); + printf("Compiled %d KLOC into %d KB bytecode => %d KB native code (%.2fx) in %.2fs\n", int(stats.lines / 1000), int(stats.bytecode / 1024), + int(stats.codegen / 1024), stats.bytecode == 0 ? 0.0 : double(stats.codegen) / double(stats.bytecode), duration); return failed ? 1 : 0; } diff --git a/CodeGen/include/Luau/AssemblyBuilderA64.h b/CodeGen/include/Luau/AssemblyBuilderA64.h index 42f5f8a68..1a5f51370 100644 --- a/CodeGen/include/Luau/AssemblyBuilderA64.h +++ b/CodeGen/include/Luau/AssemblyBuilderA64.h @@ -37,7 +37,6 @@ class AssemblyBuilderA64 void movk(RegisterA64 dst, uint16_t src, int shift = 0); // Arithmetics - // TODO: support various kinds of shifts void add(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, int shift = 0); void add(RegisterA64 dst, RegisterA64 src1, uint16_t src2); void sub(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, int shift = 0); @@ -52,13 +51,11 @@ class AssemblyBuilderA64 void cset(RegisterA64 dst, ConditionA64 cond); // Bitwise - // TODO: support shifts - // TODO: support bitfield ops - void and_(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); - void orr(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); - void eor(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); - void bic(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); - void tst(RegisterA64 src1, RegisterA64 src2); + void and_(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, int shift = 0); + void orr(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, int shift = 0); + void eor(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, int shift = 0); + void bic(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, int shift = 0); + void tst(RegisterA64 src1, RegisterA64 src2, int shift = 0); void mvn(RegisterA64 dst, RegisterA64 src); // Bitwise with immediate @@ -76,6 +73,13 @@ class AssemblyBuilderA64 void clz(RegisterA64 dst, RegisterA64 src); void rbit(RegisterA64 dst, RegisterA64 src); + // Shifts with immediates + // Note: immediate value must be in [0, 31] or [0, 63] range based on register type + void lsl(RegisterA64 dst, RegisterA64 src1, uint8_t src2); + void lsr(RegisterA64 dst, RegisterA64 src1, uint8_t src2); + void asr(RegisterA64 dst, RegisterA64 src1, uint8_t src2); + void ror(RegisterA64 dst, RegisterA64 src1, uint8_t src2); + // Load // Note: paired loads are currently omitted for simplicity void ldr(RegisterA64 dst, AddressA64 src); @@ -93,15 +97,19 @@ class AssemblyBuilderA64 void stp(RegisterA64 src1, RegisterA64 src2, AddressA64 dst); // Control flow - // TODO: support tbz/tbnz; they have 15-bit offsets but they can be useful in constrained cases void b(Label& label); - void b(ConditionA64 cond, Label& label); - void cbz(RegisterA64 src, Label& label); - void cbnz(RegisterA64 src, Label& label); + void bl(Label& label); void br(RegisterA64 src); void blr(RegisterA64 src); void ret(); + // Conditional control flow + void b(ConditionA64 cond, Label& label); + void cbz(RegisterA64 src, Label& label); + void cbnz(RegisterA64 src, Label& label); + void tbz(RegisterA64 src, uint8_t bit, Label& label); + void tbnz(RegisterA64 src, uint8_t bit, Label& label); + // Address of embedded data void adr(RegisterA64 dst, const void* ptr, size_t size); void adr(RegisterA64 dst, uint64_t value); @@ -111,7 +119,9 @@ class AssemblyBuilderA64 void adr(RegisterA64 dst, Label& label); // Floating-point scalar moves + // Note: constant must be compatible with immediate floating point moves (see isFmovSupported) void fmov(RegisterA64 dst, RegisterA64 src); + void fmov(RegisterA64 dst, double src); // Floating-point scalar math void fabs(RegisterA64 dst, RegisterA64 src); @@ -173,6 +183,12 @@ class AssemblyBuilderA64 // Maximum immediate argument to functions like add/sub/cmp static constexpr size_t kMaxImmediate = (1 << 12) - 1; + // Check if immediate mode mask is supported for bitwise operations (and/or/xor) + static bool isMaskSupported(uint32_t mask); + + // Check if fmov can be used to synthesize a constant + static bool isFmovSupported(double value); + private: // Instruction archetypes void place0(const char* name, uint32_t word); @@ -183,20 +199,38 @@ class AssemblyBuilderA64 void placeI12(const char* name, RegisterA64 dst, RegisterA64 src1, int src2, uint8_t op); void placeI16(const char* name, RegisterA64 dst, int src, uint8_t op, int shift = 0); void placeA(const char* name, RegisterA64 dst, AddressA64 src, uint8_t op, uint8_t size, int sizelog); + void placeB(const char* name, Label& label, uint8_t op); void placeBC(const char* name, Label& label, uint8_t op, uint8_t cond); void placeBCR(const char* name, Label& label, uint8_t op, RegisterA64 cond); void placeBR(const char* name, RegisterA64 src, uint32_t op); + void placeBTR(const char* name, Label& label, uint8_t op, RegisterA64 cond, uint8_t bit); void placeADR(const char* name, RegisterA64 src, uint8_t op); void placeADR(const char* name, RegisterA64 src, uint8_t op, Label& label); void placeP(const char* name, RegisterA64 dst1, RegisterA64 dst2, AddressA64 src, uint8_t op, uint8_t opc, int sizelog); void placeCS(const char* name, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, ConditionA64 cond, uint8_t op, uint8_t opc, int invert = 0); void placeFCMP(const char* name, RegisterA64 src1, RegisterA64 src2, uint8_t op, uint8_t opc); + void placeFMOV(const char* name, RegisterA64 dst, double src, uint32_t op); void placeBM(const char* name, RegisterA64 dst, RegisterA64 src1, uint32_t src2, uint8_t op); + void placeBFM(const char* name, RegisterA64 dst, RegisterA64 src1, uint8_t src2, uint8_t op, int immr, int imms); void place(uint32_t word); - void patchLabel(Label& label); - void patchImm19(uint32_t location, int value); + struct Patch + { + enum Kind + { + Imm26, + Imm19, + Imm14, + }; + + Kind kind : 2; + uint32_t label : 30; + uint32_t location; + }; + + void patchLabel(Label& label, Patch::Kind kind); + void patchOffset(uint32_t location, int value, Patch::Kind kind); void commit(); LUAU_NOINLINE void extend(); @@ -210,9 +244,10 @@ class AssemblyBuilderA64 LUAU_NOINLINE void log(const char* opcode, RegisterA64 dst, RegisterA64 src1, int src2); LUAU_NOINLINE void log(const char* opcode, RegisterA64 dst, RegisterA64 src); LUAU_NOINLINE void log(const char* opcode, RegisterA64 dst, int src, int shift = 0); + LUAU_NOINLINE void log(const char* opcode, RegisterA64 dst, double src); LUAU_NOINLINE void log(const char* opcode, RegisterA64 dst, AddressA64 src); LUAU_NOINLINE void log(const char* opcode, RegisterA64 dst1, RegisterA64 dst2, AddressA64 src); - LUAU_NOINLINE void log(const char* opcode, RegisterA64 src, Label label); + LUAU_NOINLINE void log(const char* opcode, RegisterA64 src, Label label, int imm = -1); LUAU_NOINLINE void log(const char* opcode, RegisterA64 src); LUAU_NOINLINE void log(const char* opcode, Label label); LUAU_NOINLINE void log(const char* opcode, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, ConditionA64 cond); @@ -221,7 +256,7 @@ class AssemblyBuilderA64 LUAU_NOINLINE void log(AddressA64 addr); uint32_t nextLabel = 1; - std::vector + + none + R{index&0xff}-v{index >> 8} + R{index&0xff} + K{index} + UP{index} + %{index} + + From 1c2ce0d73196d9bab0a44f72da6fa160e5d2c664 Mon Sep 17 00:00:00 2001 From: Vyacheslav Egorov Date: Fri, 28 Apr 2023 14:55:55 +0300 Subject: [PATCH 49/66] Sync to upstream/release/574 --- .../include/Luau/ConstraintGraphBuilder.h | 12 +- Analysis/include/Luau/Frontend.h | 13 +- Analysis/include/Luau/Normalize.h | 3 + Analysis/include/Luau/Type.h | 35 +- Analysis/include/Luau/TypeInfer.h | 7 +- Analysis/include/Luau/TypeUtils.h | 6 + Analysis/include/Luau/VisitType.h | 8 +- Analysis/src/AstQuery.cpp | 4 +- Analysis/src/Autocomplete.cpp | 4 +- Analysis/src/BuiltinDefinitions.cpp | 7 +- Analysis/src/Clone.cpp | 39 +- Analysis/src/ConstraintGraphBuilder.cpp | 75 ++- Analysis/src/ConstraintSolver.cpp | 35 +- Analysis/src/Frontend.cpp | 148 ++--- Analysis/src/Normalize.cpp | 39 +- Analysis/src/Quantify.cpp | 2 +- Analysis/src/Substitution.cpp | 8 +- Analysis/src/ToDot.cpp | 4 +- Analysis/src/ToString.cpp | 2 +- Analysis/src/Type.cpp | 104 +++- Analysis/src/TypeAttach.cpp | 4 +- Analysis/src/TypeChecker2.cpp | 94 +-- Analysis/src/TypeInfer.cpp | 33 +- Analysis/src/TypeReduction.cpp | 10 +- Analysis/src/TypeUtils.cpp | 6 +- Analysis/src/Unifier.cpp | 24 +- CLI/Repl.cpp | 35 +- CodeGen/include/Luau/AssemblyBuilderA64.h | 1 + CodeGen/include/Luau/AssemblyBuilderX64.h | 2 +- CodeGen/include/Luau/IrBuilder.h | 2 + CodeGen/include/Luau/IrData.h | 81 ++- CodeGen/include/Luau/IrUtils.h | 3 +- CodeGen/include/Luau/OptimizeConstProp.h | 4 +- CodeGen/include/Luau/RegisterA64.h | 34 ++ CodeGen/src/AssemblyBuilderA64.cpp | 24 +- CodeGen/src/AssemblyBuilderX64.cpp | 4 +- CodeGen/src/BitUtils.h | 20 + CodeGen/src/CodeGen.cpp | 103 ++-- CodeGen/src/CodeGenA64.cpp | 21 +- CodeGen/src/EmitBuiltinsX64.cpp | 52 +- CodeGen/src/EmitBuiltinsX64.h | 2 +- CodeGen/src/EmitCommonA64.h | 3 +- CodeGen/src/EmitInstructionX64.cpp | 11 +- CodeGen/src/IrAnalysis.cpp | 7 +- CodeGen/src/IrBuilder.cpp | 11 +- CodeGen/src/IrDump.cpp | 10 +- CodeGen/src/IrLoweringA64.cpp | 281 ++++----- CodeGen/src/IrLoweringA64.h | 3 + CodeGen/src/IrLoweringX64.cpp | 106 ++-- CodeGen/src/IrRegAllocA64.cpp | 154 +++-- CodeGen/src/IrRegAllocA64.h | 2 +- CodeGen/src/IrTranslateBuiltins.cpp | 295 +++++---- CodeGen/src/IrTranslation.cpp | 20 +- CodeGen/src/IrUtils.cpp | 176 +++++- CodeGen/src/IrValueLocationTracking.cpp | 1 - CodeGen/src/NativeState.h | 7 +- CodeGen/src/OptimizeConstProp.cpp | 193 +++++- CodeGen/src/OptimizeFinalX64.cpp | 1 - Compiler/include/Luau/BytecodeBuilder.h | 2 + Compiler/src/BytecodeBuilder.cpp | 5 + Compiler/src/Compiler.cpp | 6 + Makefile | 2 +- VM/src/ldo.cpp | 43 +- VM/src/ltablib.cpp | 47 +- tests/AssemblyBuilderA64.test.cpp | 10 + tests/Conformance.test.cpp | 6 +- tests/ConstraintGraphBuilderFixture.cpp | 4 +- tests/Fixture.cpp | 9 +- tests/Frontend.test.cpp | 17 + tests/IrBuilder.test.cpp | 561 ++++++++++++++++-- tests/Module.test.cpp | 16 +- tests/NonstrictMode.test.cpp | 8 +- tests/ToString.test.cpp | 12 +- tests/TypeInfer.annotations.test.cpp | 8 +- tests/TypeInfer.anyerror.test.cpp | 2 +- tests/TypeInfer.builtins.test.cpp | 2 +- tests/TypeInfer.functions.test.cpp | 4 +- tests/TypeInfer.generics.test.cpp | 6 +- tests/TypeInfer.negations.test.cpp | 29 + tests/TypeInfer.refinements.test.cpp | 32 + tests/TypeInfer.tables.test.cpp | 22 +- tests/TypeInfer.tryUnify.test.cpp | 8 +- tools/faillist.txt | 7 +- tools/test_dcr.py | 14 +- 84 files changed, 2237 insertions(+), 1040 deletions(-) diff --git a/Analysis/include/Luau/ConstraintGraphBuilder.h b/Analysis/include/Luau/ConstraintGraphBuilder.h index cbf679cc5..5800d146d 100644 --- a/Analysis/include/Luau/ConstraintGraphBuilder.h +++ b/Analysis/include/Luau/ConstraintGraphBuilder.h @@ -11,6 +11,7 @@ #include "Luau/Refinement.h" #include "Luau/Symbol.h" #include "Luau/Type.h" +#include "Luau/TypeUtils.h" #include "Luau/Variant.h" #include @@ -91,10 +92,14 @@ struct ConstraintGraphBuilder const NotNull ice; ScopePtr globalScope; + + std::function prepareModuleScope; + DcrLogger* logger; ConstraintGraphBuilder(ModulePtr module, TypeArena* arena, NotNull moduleResolver, NotNull builtinTypes, - NotNull ice, const ScopePtr& globalScope, DcrLogger* logger, NotNull dfg); + NotNull ice, const ScopePtr& globalScope, std::function prepareModuleScope, + DcrLogger* logger, NotNull dfg); /** * Fabricates a new free type belonging to a given scope. @@ -174,11 +179,12 @@ struct ConstraintGraphBuilder * surrounding context. Used to implement bidirectional type checking. * @return the type of the expression. */ - Inference check(const ScopePtr& scope, AstExpr* expr, std::optional expectedType = {}, bool forceSingleton = false); + Inference check(const ScopePtr& scope, AstExpr* expr, ValueContext context = ValueContext::RValue, std::optional expectedType = {}, + bool forceSingleton = false); Inference check(const ScopePtr& scope, AstExprConstantString* string, std::optional expectedType, bool forceSingleton); Inference check(const ScopePtr& scope, AstExprConstantBool* bool_, std::optional expectedType, bool forceSingleton); - Inference check(const ScopePtr& scope, AstExprLocal* local); + Inference check(const ScopePtr& scope, AstExprLocal* local, ValueContext context); Inference check(const ScopePtr& scope, AstExprGlobal* global); Inference check(const ScopePtr& scope, AstExprIndexName* indexName); Inference check(const ScopePtr& scope, AstExprIndexExpr* indexExpr); diff --git a/Analysis/include/Luau/Frontend.h b/Analysis/include/Luau/Frontend.h index 856c5dafa..67e840eec 100644 --- a/Analysis/include/Luau/Frontend.h +++ b/Analysis/include/Luau/Frontend.h @@ -53,9 +53,7 @@ std::optional pathExprToModuleName(const ModuleName& currentModuleNa * error when we try during typechecking. */ std::optional pathExprToModuleName(const ModuleName& currentModuleName, const AstExpr& expr); -// TODO: Deprecate this code path when we move away from the old solver -LoadDefinitionFileResult loadDefinitionFileNoDCR(TypeChecker& typeChecker, GlobalTypes& globals, ScopePtr targetScope, std::string_view definition, - const std::string& packageName, bool captureComments); + struct SourceNode { bool hasDirtySourceModule() const @@ -209,10 +207,6 @@ struct Frontend GlobalTypes globals; GlobalTypes globalsForAutocomplete; - // TODO: remove with FFlagLuauOnDemandTypecheckers - TypeChecker typeChecker_DEPRECATED; - TypeChecker typeCheckerForAutocomplete_DEPRECATED; - ConfigResolver* configResolver; FrontendOptions options; InternalErrorReporter iceHandler; @@ -227,10 +221,11 @@ struct Frontend ModulePtr check(const SourceModule& sourceModule, const std::vector& requireCycles, NotNull builtinTypes, NotNull iceHandler, NotNull moduleResolver, NotNull fileResolver, - const ScopePtr& globalScope, FrontendOptions options); + const ScopePtr& globalScope, std::function prepareModuleScope, FrontendOptions options); ModulePtr check(const SourceModule& sourceModule, const std::vector& requireCycles, NotNull builtinTypes, NotNull iceHandler, NotNull moduleResolver, NotNull fileResolver, - const ScopePtr& globalScope, FrontendOptions options, bool recordJsonLog); + const ScopePtr& globalScope, std::function prepareModuleScope, FrontendOptions options, + bool recordJsonLog); } // namespace Luau diff --git a/Analysis/include/Luau/Normalize.h b/Analysis/include/Luau/Normalize.h index efcb51085..6c808286c 100644 --- a/Analysis/include/Luau/Normalize.h +++ b/Analysis/include/Luau/Normalize.h @@ -348,10 +348,13 @@ class Normalizer bool intersectTyvarsWithTy(NormalizedTyvars& here, TypeId there); bool intersectNormals(NormalizedType& here, const NormalizedType& there, int ignoreSmallerTyvars = -1); bool intersectNormalWithTy(NormalizedType& here, TypeId there); + bool normalizeIntersections(const std::vector& intersections, NormalizedType& outType); // Check for inhabitance bool isInhabited(TypeId ty, std::unordered_set seen = {}); bool isInhabited(const NormalizedType* norm, std::unordered_set seen = {}); + // Check for intersections being inhabited + bool isIntersectionInhabited(TypeId left, TypeId right); // -------- Convert back from a normalized type to a type TypeId typeFromNormal(const NormalizedType& norm); diff --git a/Analysis/include/Luau/Type.h b/Analysis/include/Luau/Type.h index 24fb7db0f..5d92cbd0b 100644 --- a/Analysis/include/Luau/Type.h +++ b/Analysis/include/Luau/Type.h @@ -301,7 +301,7 @@ struct MagicFunctionCallContext TypePackId result; }; -using DcrMagicFunction = bool (*)(MagicFunctionCallContext); +using DcrMagicFunction = std::function; struct MagicRefinementContext { @@ -379,12 +379,39 @@ struct TableIndexer struct Property { - TypeId type; + static Property readonly(TypeId ty); + static Property writeonly(TypeId ty); + static Property rw(TypeId ty); // Shared read-write type. + static Property rw(TypeId read, TypeId write); // Separate read-write type. + static std::optional create(std::optional read, std::optional write); + bool deprecated = false; std::string deprecatedSuggestion; std::optional location = std::nullopt; Tags tags; std::optional documentationSymbol; + + // DEPRECATED + // TODO: Kill all constructors in favor of `Property::rw(TypeId read, TypeId write)` and friends. + Property(); + Property(TypeId readTy, bool deprecated = false, const std::string& deprecatedSuggestion = "", std::optional location = std::nullopt, + const Tags& tags = {}, const std::optional& documentationSymbol = std::nullopt); + + // DEPRECATED: Should only be called in non-RWP! We assert that the `readTy` is not nullopt. + // TODO: Kill once we don't have non-RWP. + TypeId type() const; + void setType(TypeId ty); + + // Should only be called in RWP! + // We do not assert that `readTy` nor `writeTy` are nullopt or not. + // The invariant is that at least one of them mustn't be nullopt, which we do assert here. + // TODO: Kill this in favor of exposing `readTy`/`writeTy` directly? If we do, we'll lose the asserts which will be useful while debugging. + std::optional readType() const; + std::optional writeType() const; + +private: + std::optional readTy; + std::optional writeTy; }; struct TableType @@ -552,7 +579,7 @@ struct IntersectionType struct LazyType { LazyType() = default; - LazyType(std::function thunk_DEPRECATED, std::function unwrap) + LazyType(std::function thunk_DEPRECATED, std::function unwrap) : thunk_DEPRECATED(thunk_DEPRECATED) , unwrap(unwrap) { @@ -593,7 +620,7 @@ struct LazyType std::function thunk_DEPRECATED; - std::function unwrap; + std::function unwrap; std::atomic unwrapped = nullptr; }; diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index b5db3f58d..cceff0db1 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -11,6 +11,7 @@ #include "Luau/TxnLog.h" #include "Luau/Type.h" #include "Luau/TypePack.h" +#include "Luau/TypeUtils.h" #include "Luau/Unifier.h" #include "Luau/UnifierSharedState.h" @@ -58,12 +59,6 @@ class TimeLimitError : public InternalCompilerError } }; -enum class ValueContext -{ - LValue, - RValue -}; - struct GlobalTypes { GlobalTypes(NotNull builtinTypes); diff --git a/Analysis/include/Luau/TypeUtils.h b/Analysis/include/Luau/TypeUtils.h index 42ba40522..86f20f387 100644 --- a/Analysis/include/Luau/TypeUtils.h +++ b/Analysis/include/Luau/TypeUtils.h @@ -15,6 +15,12 @@ namespace Luau struct TxnLog; struct TypeArena; +enum class ValueContext +{ + LValue, + RValue +}; + using ScopePtr = std::shared_ptr; std::optional findMetatableEntry( diff --git a/Analysis/include/Luau/VisitType.h b/Analysis/include/Luau/VisitType.h index c7dcdcc1e..663627d5e 100644 --- a/Analysis/include/Luau/VisitType.h +++ b/Analysis/include/Luau/VisitType.h @@ -9,7 +9,7 @@ #include "Luau/Type.h" LUAU_FASTINT(LuauVisitRecursionLimit) -LUAU_FASTFLAG(LuauBoundLazyTypes) +LUAU_FASTFLAG(LuauBoundLazyTypes2) namespace Luau { @@ -242,7 +242,7 @@ struct GenericTypeVisitor else { for (auto& [_name, prop] : ttv->props) - traverse(prop.type); + traverse(prop.type()); if (ttv->indexer) { @@ -265,7 +265,7 @@ struct GenericTypeVisitor if (visit(ty, *ctv)) { for (const auto& [name, prop] : ctv->props) - traverse(prop.type); + traverse(prop.type()); if (ctv->parent) traverse(*ctv->parent); @@ -294,7 +294,7 @@ struct GenericTypeVisitor } else if (auto ltv = get(ty)) { - if (FFlag::LuauBoundLazyTypes) + if (FFlag::LuauBoundLazyTypes2) { if (TypeId unwrapped = ltv->unwrapped) traverse(unwrapped); diff --git a/Analysis/src/AstQuery.cpp b/Analysis/src/AstQuery.cpp index cb3efe6a6..38f3bdf5c 100644 --- a/Analysis/src/AstQuery.cpp +++ b/Analysis/src/AstQuery.cpp @@ -501,12 +501,12 @@ std::optional getDocumentationSymbolAtPosition(const Source if (const TableType* ttv = get(parentTy)) { if (auto propIt = ttv->props.find(indexName->index.value); propIt != ttv->props.end()) - return checkOverloadedDocumentationSymbol(module, propIt->second.type, parentExpr, propIt->second.documentationSymbol); + return checkOverloadedDocumentationSymbol(module, propIt->second.type(), parentExpr, propIt->second.documentationSymbol); } else if (const ClassType* ctv = get(parentTy)) { if (auto propIt = ctv->props.find(indexName->index.value); propIt != ctv->props.end()) - return checkOverloadedDocumentationSymbol(module, propIt->second.type, parentExpr, propIt->second.documentationSymbol); + return checkOverloadedDocumentationSymbol(module, propIt->second.type(), parentExpr, propIt->second.documentationSymbol); } } } diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index 42fc9a717..4b66568b5 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -260,7 +260,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, NotNul // already populated, it takes precedence over the property we found just now. if (result.count(name) == 0 && name != kParseNameError) { - Luau::TypeId type = Luau::follow(prop.type); + Luau::TypeId type = Luau::follow(prop.type()); TypeCorrectKind typeCorrect = indexType == PropIndexType::Key ? TypeCorrectKind::Correct : checkTypeCorrectKind(module, typeArena, builtinTypes, nodes.back(), {{}, {}}, type); @@ -287,7 +287,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, NotNul auto indexIt = mtable->props.find("__index"); if (indexIt != mtable->props.end()) { - TypeId followed = follow(indexIt->second.type); + TypeId followed = follow(indexIt->second.type()); if (get(followed) || get(followed)) { autocompleteProps(module, typeArena, builtinTypes, rootTy, followed, indexType, nodes, result, seen); diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index 8988b332e..c55a88ebf 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -52,6 +52,7 @@ TypeId makeIntersection(TypeArena& arena, std::vector&& types) TypeId makeOption(NotNull builtinTypes, TypeArena& arena, TypeId t) { + LUAU_ASSERT(t); return makeUnion(arena, {builtinTypes->nilType, t}); } @@ -236,7 +237,7 @@ void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeC auto it = stringMetatableTable->props.find("__index"); LUAU_ASSERT(it != stringMetatableTable->props.end()); - addGlobalBinding(globals, "string", it->second.type, "@luau"); + addGlobalBinding(globals, "string", it->second.type(), "@luau"); // next(t: Table, i: K?) -> (K?, V) TypePackId nextArgsTypePack = arena.addTypePack(TypePack{{mapOfKtoV, makeOption(builtinTypes, arena, genericK)}}); @@ -301,8 +302,8 @@ void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeC ttv->props["foreach"].deprecated = true; ttv->props["foreachi"].deprecated = true; - attachMagicFunction(ttv->props["pack"].type, magicFunctionPack); - attachDcrMagicFunction(ttv->props["pack"].type, dcrMagicFunctionPack); + attachMagicFunction(ttv->props["pack"].type(), magicFunctionPack); + attachDcrMagicFunction(ttv->props["pack"].type(), dcrMagicFunctionPack); } attachMagicFunction(getGlobalBinding(globals, "require"), magicFunctionRequire); diff --git a/Analysis/src/Clone.cpp b/Analysis/src/Clone.cpp index f5102654f..450b84af9 100644 --- a/Analysis/src/Clone.cpp +++ b/Analysis/src/Clone.cpp @@ -8,6 +8,7 @@ LUAU_FASTFLAG(DebugLuauCopyBeforeNormalizing) LUAU_FASTFLAG(LuauClonePublicInterfaceLess2) +LUAU_FASTFLAG(DebugLuauReadWriteProperties) LUAU_FASTINTVARIABLE(LuauTypeCloneRecursionLimit, 300) @@ -17,6 +18,40 @@ namespace Luau namespace { +Property clone(const Property& prop, TypeArena& dest, CloneState& cloneState) +{ + if (FFlag::DebugLuauReadWriteProperties) + { + std::optional cloneReadTy; + if (auto ty = prop.readType()) + cloneReadTy = clone(*ty, dest, cloneState); + + std::optional cloneWriteTy; + if (auto ty = prop.writeType()) + cloneWriteTy = clone(*ty, dest, cloneState); + + std::optional cloned = Property::create(cloneReadTy, cloneWriteTy); + LUAU_ASSERT(cloned); + cloned->deprecated = prop.deprecated; + cloned->deprecatedSuggestion = prop.deprecatedSuggestion; + cloned->location = prop.location; + cloned->tags = prop.tags; + cloned->documentationSymbol = prop.documentationSymbol; + return *cloned; + } + else + { + return Property{ + clone(prop.type(), dest, cloneState), + prop.deprecated, + prop.deprecatedSuggestion, + prop.location, + prop.tags, + prop.documentationSymbol, + }; + } +} + struct TypePackCloner; /* @@ -251,7 +286,7 @@ void TypeCloner::operator()(const TableType& t) ttv->boundTo = clone(*t.boundTo, dest, cloneState); for (const auto& [name, prop] : t.props) - ttv->props[name] = {clone(prop.type, dest, cloneState), prop.deprecated, {}, prop.location, prop.tags}; + ttv->props[name] = clone(prop, dest, cloneState); if (t.indexer) ttv->indexer = TableIndexer{clone(t.indexer->indexType, dest, cloneState), clone(t.indexer->indexResultType, dest, cloneState)}; @@ -285,7 +320,7 @@ void TypeCloner::operator()(const ClassType& t) seenTypes[typeId] = result; for (const auto& [name, prop] : t.props) - ctv->props[name] = {clone(prop.type, dest, cloneState), prop.deprecated, {}, prop.location, prop.tags}; + ctv->props[name] = clone(prop, dest, cloneState); if (t.parent) ctv->parent = clone(*t.parent, dest, cloneState); diff --git a/Analysis/src/ConstraintGraphBuilder.cpp b/Analysis/src/ConstraintGraphBuilder.cpp index ad7cff9f7..611f420a9 100644 --- a/Analysis/src/ConstraintGraphBuilder.cpp +++ b/Analysis/src/ConstraintGraphBuilder.cpp @@ -134,8 +134,8 @@ void forEachConstraint(const Checkpoint& start, const Checkpoint& end, const Con } // namespace ConstraintGraphBuilder::ConstraintGraphBuilder(ModulePtr module, TypeArena* arena, NotNull moduleResolver, - NotNull builtinTypes, NotNull ice, const ScopePtr& globalScope, DcrLogger* logger, - NotNull dfg) + NotNull builtinTypes, NotNull ice, const ScopePtr& globalScope, + std::function prepareModuleScope, DcrLogger* logger, NotNull dfg) : module(module) , builtinTypes(builtinTypes) , arena(arena) @@ -144,6 +144,7 @@ ConstraintGraphBuilder::ConstraintGraphBuilder(ModulePtr module, TypeArena* aren , moduleResolver(moduleResolver) , ice(ice) , globalScope(globalScope) + , prepareModuleScope(std::move(prepareModuleScope)) , logger(logger) { LUAU_ASSERT(module); @@ -510,7 +511,7 @@ ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocal* l if (hasAnnotation) expectedType = varTypes.at(i); - TypeId exprType = check(scope, value, expectedType).ty; + TypeId exprType = check(scope, value, ValueContext::RValue, expectedType).ty; if (i < varTypes.size()) { if (varTypes[i]) @@ -898,7 +899,7 @@ ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatCompound ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatIf* ifStatement) { - RefinementId refinement = check(scope, ifStatement->condition, std::nullopt).refinement; + RefinementId refinement = check(scope, ifStatement->condition, ValueContext::RValue, std::nullopt).refinement; ScopePtr thenScope = childScope(ifStatement->thenbody, scope); applyRefinements(thenScope, ifStatement->condition->location, refinement); @@ -1081,7 +1082,7 @@ ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareC } else { - TypeId currentTy = assignToMetatable ? metatable->props[propName].type : ctv->props[propName].type; + TypeId currentTy = assignToMetatable ? metatable->props[propName].type() : ctv->props[propName].type(); // We special-case this logic to keep the intersection flat; otherwise we // would create a ton of nested intersection types. @@ -1182,7 +1183,7 @@ InferencePack ConstraintGraphBuilder::checkPack( std::optional expectedType; if (i < expectedTypes.size()) expectedType = expectedTypes[i]; - head.push_back(check(scope, expr, expectedType).ty); + head.push_back(check(scope, expr, ValueContext::RValue, expectedType).ty); } else { @@ -1225,7 +1226,7 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExpr* std::optional expectedType; if (!expectedTypes.empty()) expectedType = expectedTypes[0]; - TypeId t = check(scope, expr, expectedType).ty; + TypeId t = check(scope, expr, ValueContext::RValue, expectedType).ty; result = InferencePack{arena->addTypePack({t})}; } @@ -1332,7 +1333,7 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCa } else if (i < exprArgs.size() - 1 || !(arg->is() || arg->is())) { - auto [ty, refinement] = check(scope, arg, expectedType); + auto [ty, refinement] = check(scope, arg, ValueContext::RValue, expectedType); args.push_back(ty); argumentRefinements.push_back(refinement); } @@ -1434,7 +1435,8 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCa } } -Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExpr* expr, std::optional expectedType, bool forceSingleton) +Inference ConstraintGraphBuilder::check( + const ScopePtr& scope, AstExpr* expr, ValueContext context, std::optional expectedType, bool forceSingleton) { RecursionCounter counter{&recursionCount}; @@ -1447,7 +1449,7 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExpr* expr, st Inference result; if (auto group = expr->as()) - result = check(scope, group->expr, expectedType, forceSingleton); + result = check(scope, group->expr, ValueContext::RValue, expectedType, forceSingleton); else if (auto stringExpr = expr->as()) result = check(scope, stringExpr, expectedType, forceSingleton); else if (expr->is()) @@ -1457,7 +1459,7 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExpr* expr, st else if (expr->is()) result = Inference{builtinTypes->nilType}; else if (auto local = expr->as()) - result = check(scope, local); + result = check(scope, local, context); else if (auto global = expr->as()) result = check(scope, global); else if (expr->is()) @@ -1566,11 +1568,11 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprConstantBo return Inference{builtinTypes->booleanType}; } -Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprLocal* local) +Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprLocal* local, ValueContext context) { BreadcrumbId bc = dfg->getBreadcrumb(local); - if (auto ty = scope->lookup(bc->def)) + if (auto ty = scope->lookup(bc->def); ty && context == ValueContext::RValue) return Inference{*ty, refinementArena.proposition(bc, builtinTypes->truthyType)}; else if (auto ty = scope->lookup(local->local)) return Inference{*ty, refinementArena.proposition(bc, builtinTypes->truthyType)}; @@ -1676,18 +1678,18 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIfElse* if ScopePtr thenScope = childScope(ifElse->trueExpr, scope); applyRefinements(thenScope, ifElse->trueExpr->location, refinement); - TypeId thenType = check(thenScope, ifElse->trueExpr, expectedType).ty; + TypeId thenType = check(thenScope, ifElse->trueExpr, ValueContext::RValue, expectedType).ty; ScopePtr elseScope = childScope(ifElse->falseExpr, scope); applyRefinements(elseScope, ifElse->falseExpr->location, refinementArena.negation(refinement)); - TypeId elseType = check(elseScope, ifElse->falseExpr, expectedType).ty; + TypeId elseType = check(elseScope, ifElse->falseExpr, ValueContext::RValue, expectedType).ty; return Inference{expectedType ? *expectedType : arena->addType(UnionType{{thenType, elseType}})}; } Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprTypeAssertion* typeAssert) { - check(scope, typeAssert->expr, std::nullopt); + check(scope, typeAssert->expr, ValueContext::RValue, std::nullopt); return Inference{resolveType(scope, typeAssert->annotation, /* inTypeArguments */ false)}; } @@ -1704,21 +1706,31 @@ std::tuple ConstraintGraphBuilder::checkBinary( { if (binary->op == AstExprBinary::And) { - auto [leftType, leftRefinement] = check(scope, binary->left, expectedType); + std::optional relaxedExpectedLhs; + + if (expectedType) + relaxedExpectedLhs = arena->addType(UnionType{{builtinTypes->falsyType, *expectedType}}); + + auto [leftType, leftRefinement] = check(scope, binary->left, ValueContext::RValue, relaxedExpectedLhs); ScopePtr rightScope = childScope(binary->right, scope); applyRefinements(rightScope, binary->right->location, leftRefinement); - auto [rightType, rightRefinement] = check(rightScope, binary->right, expectedType); + auto [rightType, rightRefinement] = check(rightScope, binary->right, ValueContext::RValue, expectedType); return {leftType, rightType, refinementArena.conjunction(leftRefinement, rightRefinement)}; } else if (binary->op == AstExprBinary::Or) { - auto [leftType, leftRefinement] = check(scope, binary->left, expectedType); + std::optional relaxedExpectedLhs; + + if (expectedType) + relaxedExpectedLhs = arena->addType(UnionType{{builtinTypes->falsyType, *expectedType}}); + + auto [leftType, leftRefinement] = check(scope, binary->left, ValueContext::RValue, relaxedExpectedLhs); ScopePtr rightScope = childScope(binary->right, scope); applyRefinements(rightScope, binary->right->location, refinementArena.negation(leftRefinement)); - auto [rightType, rightRefinement] = check(rightScope, binary->right, expectedType); + auto [rightType, rightRefinement] = check(rightScope, binary->right, ValueContext::RValue, expectedType); return {leftType, rightType, refinementArena.disjunction(leftRefinement, rightRefinement)}; } @@ -1774,8 +1786,8 @@ std::tuple ConstraintGraphBuilder::checkBinary( } else if (binary->op == AstExprBinary::CompareEq || binary->op == AstExprBinary::CompareNe) { - TypeId leftType = check(scope, binary->left, expectedType, true).ty; - TypeId rightType = check(scope, binary->right, expectedType, true).ty; + TypeId leftType = check(scope, binary->left, ValueContext::RValue, expectedType, true).ty; + TypeId rightType = check(scope, binary->right, ValueContext::RValue, expectedType, true).ty; RefinementId leftRefinement = nullptr; if (auto bc = dfg->getBreadcrumb(binary->left)) @@ -1795,8 +1807,8 @@ std::tuple ConstraintGraphBuilder::checkBinary( } else { - TypeId leftType = check(scope, binary->left, expectedType).ty; - TypeId rightType = check(scope, binary->right, expectedType).ty; + TypeId leftType = check(scope, binary->left, ValueContext::RValue).ty; + TypeId rightType = check(scope, binary->right, ValueContext::RValue).ty; return {leftType, rightType, nullptr}; } } @@ -1859,7 +1871,7 @@ TypeId ConstraintGraphBuilder::checkLValue(const ScopePtr& scope, AstExpr* expr) return propType; } else if (!isIndexNameEquivalent(expr)) - return check(scope, expr).ty; + return check(scope, expr, ValueContext::LValue).ty; Symbol sym; std::vector segments; @@ -1894,11 +1906,11 @@ TypeId ConstraintGraphBuilder::checkLValue(const ScopePtr& scope, AstExpr* expr) } else { - return check(scope, expr).ty; + return check(scope, expr, ValueContext::LValue).ty; } } else - return check(scope, expr).ty; + return check(scope, expr, ValueContext::LValue).ty; } LUAU_ASSERT(!segments.empty()); @@ -1908,7 +1920,7 @@ TypeId ConstraintGraphBuilder::checkLValue(const ScopePtr& scope, AstExpr* expr) auto lookupResult = scope->lookupEx(sym); if (!lookupResult) - return check(scope, expr).ty; + return check(scope, expr, ValueContext::LValue).ty; const auto [subjectBinding, symbolScope] = std::move(*lookupResult); TypeId subjectType = subjectBinding->typeId; @@ -2029,7 +2041,7 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprTable* exp checkExpectedIndexResultType = pinnedIndexResultType; } - TypeId itemTy = check(scope, item.value, checkExpectedIndexResultType).ty; + TypeId itemTy = check(scope, item.value, ValueContext::RValue, checkExpectedIndexResultType).ty; if (isIndexedResultType && !pinnedIndexResultType) pinnedIndexResultType = itemTy; @@ -2039,7 +2051,7 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprTable* exp // Even though we don't need to use the type of the item's key if // it's a string constant, we still want to check it to populate // astTypes. - TypeId keyTy = check(scope, item.key, annotatedKeyType).ty; + TypeId keyTy = check(scope, item.key, ValueContext::RValue, annotatedKeyType).ty; if (AstExprConstantString* key = item.key->as()) { @@ -2646,6 +2658,9 @@ void ConstraintGraphBuilder::prepopulateGlobalScope(const ScopePtr& globalScope, { GlobalPrepopulator gp{NotNull{globalScope.get()}, arena}; + if (prepareModuleScope) + prepareModuleScope(module->name, globalScope); + program->visit(&gp); } diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index 558ad2d51..ec63b25e6 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -472,16 +472,20 @@ bool ConstraintSolver::tryDispatch(const PackSubtypeConstraint& c, NotNull constraint, bool force) { + TypeId generalizedType = follow(c.generalizedType); + if (isBlocked(c.sourceType)) return block(c.sourceType, constraint); + else if (get(generalizedType)) + return block(generalizedType, constraint); std::optional generalized = quantify(arena, c.sourceType, constraint->scope); if (generalized) { - if (isBlocked(c.generalizedType)) - asMutable(c.generalizedType)->ty.emplace(*generalized); + if (get(generalizedType)) + asMutable(generalizedType)->ty.emplace(*generalized); else - unify(c.generalizedType, *generalized, constraint->scope); + unify(generalizedType, *generalized, constraint->scope); } else { @@ -505,10 +509,8 @@ bool ConstraintSolver::tryDispatch(const InstantiationConstraint& c, NotNull instantiated = inst.substitute(c.superType); LUAU_ASSERT(instantiated); // TODO FIXME HANDLE THIS - if (isBlocked(c.subType)) - asMutable(c.subType)->ty.emplace(*instantiated); - else - unify(c.subType, *instantiated, constraint->scope); + LUAU_ASSERT(get(c.subType)); + asMutable(c.subType)->ty.emplace(*instantiated); unblock(c.subType); @@ -586,6 +588,8 @@ bool ConstraintSolver::tryDispatch(const BinaryConstraint& c, NotNull(resultType)); + bool isLogical = c.op == AstExprBinary::Op::And || c.op == AstExprBinary::Op::Or; /* Compound assignments create constraints of the form @@ -979,6 +983,7 @@ bool ConstraintSolver::tryDispatch(const TypeAliasExpansionConstraint& c, NotNul } auto bindResult = [this, &c](TypeId result) { + LUAU_ASSERT(get(c.target)); asMutable(c.target)->ty.emplace(result); unblock(c.target); }; @@ -1280,6 +1285,8 @@ bool ConstraintSolver::tryDispatch(const PrimitiveTypeConstraint& c, NotNull(expectedType)) return block(expectedType, constraint); + LUAU_ASSERT(get(c.resultType)); + TypeId bindTo = maybeSingleton(expectedType) ? c.singletonType : c.multitonType; asMutable(c.resultType)->ty.emplace(bindTo); unblock(c.resultType); @@ -1291,6 +1298,8 @@ bool ConstraintSolver::tryDispatch(const HasPropConstraint& c, NotNull(c.resultType)); + if (isBlocked(subjectType) || get(subjectType)) return block(subjectType, constraint); @@ -1351,7 +1360,7 @@ static void updateTheTableType( if (it == tbl->props.end()) return; - t = follow(it->second.type); + t = follow(it->second.type()); } // The last path segment should not be a property of the table at all. @@ -1388,7 +1397,7 @@ static void updateTheTableType( if (!tt) return; - tt->props[lastSegment].type = replaceTy; + tt->props[lastSegment].setType(replaceTy); } bool ConstraintSolver::tryDispatch(const SetPropConstraint& c, NotNull constraint, bool force) @@ -1853,7 +1862,7 @@ std::pair, std::optional> ConstraintSolver::lookupTa else if (auto ttv = getMutable(subjectType)) { if (auto prop = ttv->props.find(propName); prop != ttv->props.end()) - return {{}, prop->second.type}; + return {{}, prop->second.type()}; else if (ttv->indexer && maybeString(ttv->indexer->indexType)) return {{}, ttv->indexer->indexResultType}; else if (ttv->state == TableState::Free) @@ -1881,7 +1890,7 @@ std::pair, std::optional> ConstraintSolver::lookupTa // TODO: __index can be an overloaded function. - TypeId indexType = follow(indexProp->second.type); + TypeId indexType = follow(indexProp->second.type()); if (auto ft = get(indexType)) { @@ -1902,7 +1911,7 @@ std::pair, std::optional> ConstraintSolver::lookupTa else if (auto ct = get(subjectType)) { if (auto p = lookupClassProp(ct, propName)) - return {{}, p->type}; + return {{}, p->type()}; } else if (auto pt = get(subjectType); pt && pt->metatable) { @@ -1913,7 +1922,7 @@ std::pair, std::optional> ConstraintSolver::lookupTa if (indexProp == metatable->props.end()) return {{}, std::nullopt}; - return lookupTableProp(indexProp->second.type, propName, seen); + return lookupTableProp(indexProp->second.type(), propName, seen); } else if (auto ft = get(subjectType)) { diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 916dd1d57..486ef6960 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -32,8 +32,8 @@ LUAU_FASTFLAGVARIABLE(LuauKnowsTheDataModel3, false) LUAU_FASTINTVARIABLE(LuauAutocompleteCheckTimeoutMs, 100) LUAU_FASTFLAGVARIABLE(DebugLuauDeferredConstraintResolution, false) LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJson, false) -LUAU_FASTFLAGVARIABLE(LuauOnDemandTypecheckers, false) LUAU_FASTFLAG(LuauRequirePathTrueModuleName) +LUAU_FASTFLAGVARIABLE(DebugLuauReadWriteProperties, false) namespace Luau { @@ -133,10 +133,6 @@ static void persistCheckedTypes(ModulePtr checkedModule, GlobalTypes& globals, S LoadDefinitionFileResult Frontend::loadDefinitionFile(GlobalTypes& globals, ScopePtr targetScope, std::string_view source, const std::string& packageName, bool captureComments, bool typeCheckForAutocomplete) { - if (!FFlag::DebugLuauDeferredConstraintResolution && !FFlag::LuauOnDemandTypecheckers) - return Luau::loadDefinitionFileNoDCR(typeCheckForAutocomplete ? typeCheckerForAutocomplete_DEPRECATED : typeChecker_DEPRECATED, - typeCheckForAutocomplete ? globalsForAutocomplete : globals, targetScope, source, packageName, captureComments); - LUAU_TIMETRACE_SCOPE("loadDefinitionFile", "Frontend"); Luau::SourceModule sourceModule; @@ -154,28 +150,6 @@ LoadDefinitionFileResult Frontend::loadDefinitionFile(GlobalTypes& globals, Scop return LoadDefinitionFileResult{true, parseResult, sourceModule, checkedModule}; } -LoadDefinitionFileResult loadDefinitionFileNoDCR(TypeChecker& typeChecker, GlobalTypes& globals, ScopePtr targetScope, std::string_view source, - const std::string& packageName, bool captureComments) -{ - LUAU_ASSERT(!FFlag::LuauOnDemandTypecheckers); - LUAU_TIMETRACE_SCOPE("loadDefinitionFile", "Frontend"); - - Luau::SourceModule sourceModule; - Luau::ParseResult parseResult = parseSourceForModule(source, sourceModule, captureComments); - - if (parseResult.errors.size() > 0) - return LoadDefinitionFileResult{false, parseResult, sourceModule, nullptr}; - - ModulePtr checkedModule = typeChecker.check(sourceModule, Mode::Definition); - - if (checkedModule->errors.size() > 0) - return LoadDefinitionFileResult{false, parseResult, sourceModule, checkedModule}; - - persistCheckedTypes(checkedModule, globals, targetScope, packageName); - - return LoadDefinitionFileResult{true, parseResult, sourceModule, checkedModule}; -} - std::vector parsePathExpr(const AstExpr& pathExpr) { const AstExprIndexName* indexName = pathExpr.as(); @@ -409,8 +383,6 @@ Frontend::Frontend(FileResolver* fileResolver, ConfigResolver* configResolver, c , moduleResolverForAutocomplete(this) , globals(builtinTypes) , globalsForAutocomplete(builtinTypes) - , typeChecker_DEPRECATED(globals.globalScope, &moduleResolver, builtinTypes, &iceHandler) - , typeCheckerForAutocomplete_DEPRECATED(globalsForAutocomplete.globalScope, &moduleResolverForAutocomplete, builtinTypes, &iceHandler) , configResolver(configResolver) , options(options) { @@ -479,68 +451,32 @@ CheckResult Frontend::check(const ModuleName& name, std::optional 0) - typeCheckerForAutocomplete_DEPRECATED.instantiationChildLimit = - std::max(1, int(FInt::LuauTarjanChildLimit * sourceNode.autocompleteLimitsMult)); - else - typeCheckerForAutocomplete_DEPRECATED.instantiationChildLimit = std::nullopt; - - if (FInt::LuauTypeInferIterationLimit > 0) - typeCheckerForAutocomplete_DEPRECATED.unifierIterationLimit = - std::max(1, int(FInt::LuauTypeInferIterationLimit * sourceNode.autocompleteLimitsMult)); - else - typeCheckerForAutocomplete_DEPRECATED.unifierIterationLimit = std::nullopt; - - moduleForAutocomplete = - FFlag::DebugLuauDeferredConstraintResolution - ? check(sourceModule, Mode::Strict, requireCycles, environmentScope, /*forAutocomplete*/ true, /*recordJsonLog*/ false, {}) - : typeCheckerForAutocomplete_DEPRECATED.check(sourceModule, Mode::Strict, environmentScope); - } + // The autocomplete typecheck is always in strict mode with DM awareness + // to provide better type information for IDE features + TypeCheckLimits typeCheckLimits; + + if (autocompleteTimeLimit != 0.0) + typeCheckLimits.finishTime = TimeTrace::getClock() + autocompleteTimeLimit; else - { - // The autocomplete typecheck is always in strict mode with DM awareness - // to provide better type information for IDE features - TypeCheckLimits typeCheckLimits; - - if (autocompleteTimeLimit != 0.0) - typeCheckLimits.finishTime = TimeTrace::getClock() + autocompleteTimeLimit; - else - typeCheckLimits.finishTime = std::nullopt; - - // TODO: This is a dirty ad hoc solution for autocomplete timeouts - // We are trying to dynamically adjust our existing limits to lower total typechecking time under the limit - // so that we'll have type information for the whole file at lower quality instead of a full abort in the middle - if (FInt::LuauTarjanChildLimit > 0) - typeCheckLimits.instantiationChildLimit = std::max(1, int(FInt::LuauTarjanChildLimit * sourceNode.autocompleteLimitsMult)); - else - typeCheckLimits.instantiationChildLimit = std::nullopt; - - if (FInt::LuauTypeInferIterationLimit > 0) - typeCheckLimits.unifierIterationLimit = std::max(1, int(FInt::LuauTypeInferIterationLimit * sourceNode.autocompleteLimitsMult)); - else - typeCheckLimits.unifierIterationLimit = std::nullopt; - - moduleForAutocomplete = check(sourceModule, Mode::Strict, requireCycles, environmentScope, /*forAutocomplete*/ true, - /*recordJsonLog*/ false, typeCheckLimits); - } + typeCheckLimits.finishTime = std::nullopt; + + // TODO: This is a dirty ad hoc solution for autocomplete timeouts + // We are trying to dynamically adjust our existing limits to lower total typechecking time under the limit + // so that we'll have type information for the whole file at lower quality instead of a full abort in the middle + if (FInt::LuauTarjanChildLimit > 0) + typeCheckLimits.instantiationChildLimit = std::max(1, int(FInt::LuauTarjanChildLimit * sourceNode.autocompleteLimitsMult)); + else + typeCheckLimits.instantiationChildLimit = std::nullopt; + + if (FInt::LuauTypeInferIterationLimit > 0) + typeCheckLimits.unifierIterationLimit = std::max(1, int(FInt::LuauTypeInferIterationLimit * sourceNode.autocompleteLimitsMult)); + else + typeCheckLimits.unifierIterationLimit = std::nullopt; + + ModulePtr moduleForAutocomplete = check(sourceModule, Mode::Strict, requireCycles, environmentScope, /*forAutocomplete*/ true, + /*recordJsonLog*/ false, typeCheckLimits); resolver.setModule(moduleName, moduleForAutocomplete); @@ -565,21 +501,7 @@ CheckResult Frontend::check(const ModuleName& name, std::optionalget(global.c_str()); if (name.value) - result->bindings[name].typeId = FFlag::LuauOnDemandTypecheckers ? builtinTypes->anyType : typeChecker_DEPRECATED.anyType; + result->bindings[name].typeId = builtinTypes->anyType; } } @@ -856,15 +778,17 @@ const SourceModule* Frontend::getSourceModule(const ModuleName& moduleName) cons ModulePtr check(const SourceModule& sourceModule, const std::vector& requireCycles, NotNull builtinTypes, NotNull iceHandler, NotNull moduleResolver, NotNull fileResolver, - const ScopePtr& parentScope, FrontendOptions options) + const ScopePtr& parentScope, std::function prepareModuleScope, FrontendOptions options) { const bool recordJsonLog = FFlag::DebugLuauLogSolverToJson; - return check(sourceModule, requireCycles, builtinTypes, iceHandler, moduleResolver, fileResolver, parentScope, options, recordJsonLog); + return check(sourceModule, requireCycles, builtinTypes, iceHandler, moduleResolver, fileResolver, parentScope, std::move(prepareModuleScope), + options, recordJsonLog); } ModulePtr check(const SourceModule& sourceModule, const std::vector& requireCycles, NotNull builtinTypes, NotNull iceHandler, NotNull moduleResolver, NotNull fileResolver, - const ScopePtr& parentScope, FrontendOptions options, bool recordJsonLog) + const ScopePtr& parentScope, std::function prepareModuleScope, FrontendOptions options, + bool recordJsonLog) { ModulePtr result = std::make_shared(); result->name = sourceModule.name; @@ -897,6 +821,7 @@ ModulePtr check(const SourceModule& sourceModule, const std::vector seen) { for (const auto& [_, prop] : ttv->props) { - if (!isInhabited(prop.type, seen)) + if (!isInhabited(prop.type(), seen)) return false; } return true; @@ -316,6 +316,20 @@ bool Normalizer::isInhabited(TypeId ty, std::unordered_set seen) return isInhabited(norm, seen); } +bool Normalizer::isIntersectionInhabited(TypeId left, TypeId right) +{ + left = follow(left); + right = follow(right); + std::unordered_set seen = {}; + seen.insert(left); + seen.insert(right); + + NormalizedType norm{builtinTypes}; + if (!normalizeIntersections({left, right}, norm)) + return false; + return isInhabited(&norm, seen); +} + static int tyvarIndex(TypeId ty) { if (const GenericType* gtv = get(ty)) @@ -593,6 +607,23 @@ const NormalizedType* Normalizer::normalize(TypeId ty) return result; } +bool Normalizer::normalizeIntersections(const std::vector& intersections, NormalizedType& outType) +{ + if (!arena) + sharedState->iceHandler->ice("Normalizing types outside a module"); + NormalizedType norm{builtinTypes}; + norm.tops = builtinTypes->anyType; + // Now we need to intersect the two types + for (auto ty : intersections) + if (!intersectNormalWithTy(norm, ty)) + return false; + + if (!unionNormals(outType, norm)) + return false; + + return true; +} + void Normalizer::clearNormal(NormalizedType& norm) { norm.tops = builtinTypes->neverType; @@ -2134,9 +2165,9 @@ std::optional Normalizer::intersectionOfTables(TypeId here, TypeId there { const auto& [_name, tprop] = *tfound; // TODO: variance issues here, which can't be fixed until we have read/write property types - prop.type = intersectionType(hprop.type, tprop.type); - hereSubThere &= (prop.type == hprop.type); - thereSubHere &= (prop.type == tprop.type); + prop.setType(intersectionType(hprop.type(), tprop.type())); + hereSubThere &= (prop.type() == hprop.type()); + thereSubHere &= (prop.type() == tprop.type()); } // TODO: string indexers result.props[name] = prop; diff --git a/Analysis/src/Quantify.cpp b/Analysis/src/Quantify.cpp index 0b8f46248..0a7975f4d 100644 --- a/Analysis/src/Quantify.cpp +++ b/Analysis/src/Quantify.cpp @@ -116,7 +116,7 @@ void quantify(TypeId ty, TypeLevel level) for (const auto& [_, prop] : ttv->props) { - auto ftv = getMutable(follow(prop.type)); + auto ftv = getMutable(follow(prop.type())); if (!ftv || !ftv->hasSelf) continue; diff --git a/Analysis/src/Substitution.cpp b/Analysis/src/Substitution.cpp index 962172172..6a600b626 100644 --- a/Analysis/src/Substitution.cpp +++ b/Analysis/src/Substitution.cpp @@ -219,7 +219,7 @@ void Tarjan::visitChildren(TypeId ty, int index) { LUAU_ASSERT(!ttv->boundTo); for (const auto& [name, prop] : ttv->props) - visitChild(prop.type); + visitChild(prop.type()); if (ttv->indexer) { visitChild(ttv->indexer->indexType); @@ -258,7 +258,7 @@ void Tarjan::visitChildren(TypeId ty, int index) else if (const ClassType* ctv = get(ty); FFlag::LuauClassTypeVarsInSubstitution && ctv) { for (const auto& [name, prop] : ctv->props) - visitChild(prop.type); + visitChild(prop.type()); if (ctv->parent) visitChild(*ctv->parent); @@ -750,7 +750,7 @@ void Substitution::replaceChildren(TypeId ty) { LUAU_ASSERT(!ttv->boundTo); for (auto& [name, prop] : ttv->props) - prop.type = replace(prop.type); + prop.setType(replace(prop.type())); if (ttv->indexer) { ttv->indexer->indexType = replace(ttv->indexer->indexType); @@ -789,7 +789,7 @@ void Substitution::replaceChildren(TypeId ty) else if (ClassType* ctv = getMutable(ty); FFlag::LuauClassTypeVarsInSubstitution && ctv) { for (auto& [name, prop] : ctv->props) - prop.type = replace(prop.type); + prop.setType(replace(prop.type())); if (ctv->parent) ctv->parent = replace(*ctv->parent); diff --git a/Analysis/src/ToDot.cpp b/Analysis/src/ToDot.cpp index 117d39d20..8d889cb58 100644 --- a/Analysis/src/ToDot.cpp +++ b/Analysis/src/ToDot.cpp @@ -171,7 +171,7 @@ void StateDot::visitChildren(TypeId ty, int index) return visitChild(*ttv->boundTo, index, "boundTo"); for (const auto& [name, prop] : ttv->props) - visitChild(prop.type, index, name.c_str()); + visitChild(prop.type(), index, name.c_str()); if (ttv->indexer) { visitChild(ttv->indexer->indexType, index, "[index]"); @@ -250,7 +250,7 @@ void StateDot::visitChildren(TypeId ty, int index) finishNode(); for (const auto& [name, prop] : ctv->props) - visitChild(prop.type, index, name.c_str()); + visitChild(prop.type(), index, name.c_str()); if (ctv->parent) visitChild(*ctv->parent, index, "[parent]"); diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index 46d2e8f8f..ea3ab5775 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -660,7 +660,7 @@ struct TypeStringifier state.emit("\"]"); } state.emit(": "); - stringify(prop.type); + stringify(prop.type()); comma = true; ++index; } diff --git a/Analysis/src/Type.cpp b/Analysis/src/Type.cpp index e4d9ab33f..2ca39b41a 100644 --- a/Analysis/src/Type.cpp +++ b/Analysis/src/Type.cpp @@ -26,7 +26,8 @@ LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0) LUAU_FASTINT(LuauTypeInferRecursionLimit) LUAU_FASTFLAG(LuauInstantiateInSubtyping) LUAU_FASTFLAG(LuauNormalizeBlockedTypes) -LUAU_FASTFLAGVARIABLE(LuauBoundLazyTypes, false) +LUAU_FASTFLAG(DebugLuauReadWriteProperties) +LUAU_FASTFLAGVARIABLE(LuauBoundLazyTypes2, false) namespace Luau { @@ -57,7 +58,7 @@ TypeId follow(TypeId t) TypeId follow(TypeId t, std::function mapper) { auto advance = [&mapper](TypeId ty) -> std::optional { - if (FFlag::LuauBoundLazyTypes) + if (FFlag::LuauBoundLazyTypes2) { TypeId mapped = mapper(ty); @@ -74,7 +75,8 @@ TypeId follow(TypeId t, std::function mapper) if (unwrapped) return unwrapped; - unwrapped = ltv->unwrap(*ltv); + ltv->unwrap(*ltv); + unwrapped = ltv->unwrapped.load(); if (!unwrapped) throw InternalCompilerError("Lazy Type didn't fill in unwrapped type field"); @@ -109,7 +111,7 @@ TypeId follow(TypeId t, std::function mapper) } }; - if (!FFlag::LuauBoundLazyTypes) + if (!FFlag::LuauBoundLazyTypes2) force(t); TypeId cycleTester = t; // Null once we've determined that there is no cycle @@ -120,7 +122,7 @@ TypeId follow(TypeId t, std::function mapper) while (true) { - if (!FFlag::LuauBoundLazyTypes) + if (!FFlag::LuauBoundLazyTypes2) force(t); auto a1 = advance(t); @@ -622,6 +624,92 @@ FunctionType::FunctionType(TypeLevel level, Scope* scope, std::vector ge { } +Property::Property() {} + +Property::Property(TypeId readTy, bool deprecated, const std::string& deprecatedSuggestion, std::optional location, const Tags& tags, + const std::optional& documentationSymbol) + : deprecated(deprecated) + , deprecatedSuggestion(deprecatedSuggestion) + , location(location) + , tags(tags) + , documentationSymbol(documentationSymbol) + , readTy(readTy) + , writeTy(readTy) +{ + LUAU_ASSERT(!FFlag::DebugLuauReadWriteProperties); +} + +Property Property::readonly(TypeId ty) +{ + LUAU_ASSERT(FFlag::DebugLuauReadWriteProperties); + + Property p; + p.readTy = ty; + return p; +} + +Property Property::writeonly(TypeId ty) +{ + LUAU_ASSERT(FFlag::DebugLuauReadWriteProperties); + + Property p; + p.writeTy = ty; + return p; +} + +Property Property::rw(TypeId ty) +{ + return Property::rw(ty, ty); +} + +Property Property::rw(TypeId read, TypeId write) +{ + LUAU_ASSERT(FFlag::DebugLuauReadWriteProperties); + + Property p; + p.readTy = read; + p.writeTy = write; + return p; +} + +std::optional Property::create(std::optional read, std::optional write) +{ + if (read && !write) + return Property::readonly(*read); + else if (!read && write) + return Property::writeonly(*write); + else if (read && write) + return Property::rw(*read, *write); + else + return std::nullopt; +} + +TypeId Property::type() const +{ + LUAU_ASSERT(!FFlag::DebugLuauReadWriteProperties); + LUAU_ASSERT(readTy); + return *readTy; +} + +void Property::setType(TypeId ty) +{ + readTy = ty; +} + +std::optional Property::readType() const +{ + LUAU_ASSERT(FFlag::DebugLuauReadWriteProperties); + LUAU_ASSERT(!(bool(readTy) && bool(writeTy))); + return readTy; +} + +std::optional Property::writeType() const +{ + LUAU_ASSERT(FFlag::DebugLuauReadWriteProperties); + LUAU_ASSERT(!(bool(readTy) && bool(writeTy))); + return writeTy; +} + TableType::TableType(TableState state, TypeLevel level, Scope* scope) : state(state) , level(level) @@ -709,7 +797,7 @@ bool areEqual(SeenSet& seen, const TableType& lhs, const TableType& rhs) if (l->first != r->first) return false; - if (!areEqual(seen, *l->second.type, *r->second.type)) + if (!areEqual(seen, *l->second.type(), *r->second.type())) return false; ++l; ++r; @@ -1011,7 +1099,7 @@ void persist(TypeId ty) LUAU_ASSERT(ttv->state != TableState::Free && ttv->state != TableState::Unsealed); for (const auto& [_name, prop] : ttv->props) - queue.push_back(prop.type); + queue.push_back(prop.type()); if (ttv->indexer) { @@ -1022,7 +1110,7 @@ void persist(TypeId ty) else if (auto ctv = get(t)) { for (const auto& [_name, prop] : ctv->props) - queue.push_back(prop.type); + queue.push_back(prop.type()); } else if (auto utv = get(t)) { diff --git a/Analysis/src/TypeAttach.cpp b/Analysis/src/TypeAttach.cpp index 7ed4eb49b..86f781650 100644 --- a/Analysis/src/TypeAttach.cpp +++ b/Analysis/src/TypeAttach.cpp @@ -180,7 +180,7 @@ class TypeRehydrationVisitor char* name = allocateString(*allocator, propName); props.data[idx].name = AstName(name); - props.data[idx].type = Luau::visit(*this, prop.type->ty); + props.data[idx].type = Luau::visit(*this, prop.type()->ty); props.data[idx].location = Location(); idx++; } @@ -221,7 +221,7 @@ class TypeRehydrationVisitor char* name = allocateString(*allocator, propName); props.data[idx].name = AstName{name}; - props.data[idx].type = Luau::visit(*this, prop.type->ty); + props.data[idx].type = Luau::visit(*this, prop.type()->ty); props.data[idx].location = Location(); idx++; } diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index 893f51d97..a103df145 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -171,6 +171,22 @@ struct TypeChecker2 return follow(*tp); } + TypeId lookupExpectedType(AstExpr* expr) + { + if (TypeId* ty = module->astExpectedTypes.find(expr)) + return follow(*ty); + + return builtinTypes->anyType; + } + + TypePackId lookupExpectedPack(AstExpr* expr, TypeArena& arena) + { + if (TypeId* ty = module->astExpectedTypes.find(expr)) + return arena.addTypePack(TypePack{{follow(*ty)}, std::nullopt}); + + return builtinTypes->anyTypePack; + } + TypePackId reconstructPack(AstArray exprs, TypeArena& arena) { if (exprs.size == 0) @@ -208,12 +224,6 @@ struct TypeChecker2 return bestScope; } - enum ValueContext - { - LValue, - RValue - }; - void visit(AstStat* stat) { auto pusher = pushStack(stat); @@ -272,7 +282,7 @@ struct TypeChecker2 void visit(AstStatIf* ifStatement) { - visit(ifStatement->condition, RValue); + visit(ifStatement->condition, ValueContext::RValue); visit(ifStatement->thenbody); if (ifStatement->elsebody) visit(ifStatement->elsebody); @@ -280,14 +290,14 @@ struct TypeChecker2 void visit(AstStatWhile* whileStatement) { - visit(whileStatement->condition, RValue); + visit(whileStatement->condition, ValueContext::RValue); visit(whileStatement->body); } void visit(AstStatRepeat* repeatStatement) { visit(repeatStatement->body); - visit(repeatStatement->condition, RValue); + visit(repeatStatement->condition, ValueContext::RValue); } void visit(AstStatBreak*) {} @@ -314,12 +324,12 @@ struct TypeChecker2 } for (AstExpr* expr : ret->list) - visit(expr, RValue); + visit(expr, ValueContext::RValue); } void visit(AstStatExpr* expr) { - visit(expr->expr, RValue); + visit(expr->expr, ValueContext::RValue); } void visit(AstStatLocal* local) @@ -331,7 +341,7 @@ struct TypeChecker2 const bool isPack = value && (value->is() || value->is()); if (value) - visit(value, RValue); + visit(value, ValueContext::RValue); if (i != local->values.size - 1 || !isPack) { @@ -412,7 +422,7 @@ struct TypeChecker2 if (!expr) return; - visit(expr, RValue); + visit(expr, ValueContext::RValue); reportErrors(tryUnify(scope, expr->location, lookupType(expr), builtinTypes->numberType)); }; @@ -432,7 +442,7 @@ struct TypeChecker2 } for (AstExpr* expr : forInStatement->values) - visit(expr, RValue); + visit(expr, ValueContext::RValue); visit(forInStatement->body); @@ -643,11 +653,11 @@ struct TypeChecker2 for (size_t i = 0; i < count; ++i) { AstExpr* lhs = assign->vars.data[i]; - visit(lhs, LValue); + visit(lhs, ValueContext::LValue); TypeId lhsType = lookupType(lhs); AstExpr* rhs = assign->values.data[i]; - visit(rhs, RValue); + visit(rhs, ValueContext::RValue); TypeId rhsType = lookupType(rhs); if (get(lhsType)) @@ -671,7 +681,7 @@ struct TypeChecker2 void visit(AstStatFunction* stat) { - visit(stat->name, LValue); + visit(stat->name, ValueContext::LValue); visit(stat->func); } @@ -724,7 +734,7 @@ struct TypeChecker2 void visit(AstStatError* stat) { for (AstExpr* expr : stat->expressions) - visit(expr, RValue); + visit(expr, ValueContext::RValue); for (AstStat* s : stat->statements) visit(s); @@ -926,7 +936,7 @@ struct TypeChecker2 TypeArena* arena = &testArena; Instantiation instantiation{TxnLog::empty(), arena, TypeLevel{}, stack.back()}; - TypePackId expectedRetType = lookupPack(call); + TypePackId expectedRetType = lookupExpectedPack(call, *arena); TypeId functionType = lookupType(call->func); TypeId testFunctionType = functionType; TypePack args; @@ -1105,10 +1115,10 @@ struct TypeChecker2 void visit(AstExprCall* call) { - visit(call->func, RValue); + visit(call->func, ValueContext::RValue); for (AstExpr* arg : call->args) - visit(arg, RValue); + visit(arg, ValueContext::RValue); visitCall(call); } @@ -1158,7 +1168,7 @@ struct TypeChecker2 void visitExprName(AstExpr* expr, Location location, const std::string& propName, ValueContext context) { - visit(expr, RValue); + visit(expr, ValueContext::RValue); TypeId leftType = stripFromNilAndReport(lookupType(expr), location); checkIndexTypeFromType(leftType, propName, location, context); @@ -1179,8 +1189,8 @@ struct TypeChecker2 } // TODO! - visit(indexExpr->expr, LValue); - visit(indexExpr->index, RValue); + visit(indexExpr->expr, ValueContext::LValue); + visit(indexExpr->index, ValueContext::RValue); NotNull scope = stack.back(); @@ -1242,14 +1252,14 @@ struct TypeChecker2 for (const AstExprTable::Item& item : expr->items) { if (item.key) - visit(item.key, LValue); - visit(item.value, RValue); + visit(item.key, ValueContext::LValue); + visit(item.value, ValueContext::RValue); } } void visit(AstExprUnary* expr) { - visit(expr->expr, RValue); + visit(expr->expr, ValueContext::RValue); NotNull scope = stack.back(); TypeId operandType = lookupType(expr->expr); @@ -1330,8 +1340,8 @@ struct TypeChecker2 TypeId visit(AstExprBinary* expr, AstNode* overrideKey = nullptr) { - visit(expr->left, LValue); - visit(expr->right, LValue); + visit(expr->left, ValueContext::LValue); + visit(expr->right, ValueContext::LValue); NotNull scope = stack.back(); @@ -1363,11 +1373,14 @@ struct TypeChecker2 return leftType; } + bool typesHaveIntersection = normalizer.isIntersectionInhabited(leftType, rightType); if (auto it = kBinaryOpMetamethods.find(expr->op); it != kBinaryOpMetamethods.end()) { std::optional leftMt = getMetatable(leftType, builtinTypes); std::optional rightMt = getMetatable(rightType, builtinTypes); bool matches = leftMt == rightMt; + + if (isEquality && !matches) { auto testUnion = [&matches, builtinTypes = this->builtinTypes](const UnionType* utv, std::optional otherMt) { @@ -1390,6 +1403,13 @@ struct TypeChecker2 { testUnion(utv, leftMt); } + + // If either left or right has no metatable (or both), we need to consider if + // there are values in common that could possibly inhabit the type (and thus equality could be considered) + if (!leftMt.has_value() || !rightMt.has_value()) + { + matches = matches || typesHaveIntersection; + } } if (!matches && isComparison) @@ -1584,7 +1604,7 @@ struct TypeChecker2 void visit(AstExprTypeAssertion* expr) { - visit(expr->expr, RValue); + visit(expr->expr, ValueContext::RValue); visit(expr->annotation); TypeId annotationType = lookupAnnotation(expr->annotation); @@ -1603,22 +1623,22 @@ struct TypeChecker2 void visit(AstExprIfElse* expr) { // TODO! - visit(expr->condition, RValue); - visit(expr->trueExpr, RValue); - visit(expr->falseExpr, RValue); + visit(expr->condition, ValueContext::RValue); + visit(expr->trueExpr, ValueContext::RValue); + visit(expr->falseExpr, ValueContext::RValue); } void visit(AstExprInterpString* interpString) { for (AstExpr* expr : interpString->expressions) - visit(expr, RValue); + visit(expr, ValueContext::RValue); } void visit(AstExprError* expr) { // TODO! for (AstExpr* e : expr->expressions) - visit(e, RValue); + visit(e, ValueContext::RValue); } /** Extract a TypeId for the first type of the provided pack. @@ -1858,7 +1878,7 @@ struct TypeChecker2 void visit(AstTypeTypeof* ty) { - visit(ty->expr, RValue); + visit(ty->expr, ValueContext::RValue); } void visit(AstTypeUnion* ty) @@ -2109,7 +2129,7 @@ struct TypeChecker2 // because classes come into being with full knowledge of their // shape. We instead want to report the unknown property error of // the `else` branch. - else if (context == LValue && !get(tableTy)) + else if (context == ValueContext::LValue && !get(tableTy)) reportError(CannotExtendTable{tableTy, CannotExtendTable::Property, prop}, location); else reportError(UnknownProperty{tableTy, prop}, location); diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index a8c093a43..8f9e1851b 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -1786,7 +1786,7 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& } else { - TypeId currentTy = assignTo[propName].type; + TypeId currentTy = assignTo[propName].type(); // We special-case this logic to keep the intersection flat; otherwise we // would create a ton of nested intersection types. @@ -2076,7 +2076,7 @@ std::optional TypeChecker::getIndexTypeFromTypeImpl( if (TableType* tableType = getMutableTableType(type)) { if (auto it = tableType->props.find(name); it != tableType->props.end()) - return it->second.type; + return it->second.type(); else if (auto indexer = tableType->indexer) { // TODO: Property lookup should work with string singletons or unions thereof as the indexer key type. @@ -2104,7 +2104,7 @@ std::optional TypeChecker::getIndexTypeFromTypeImpl( { const Property* prop = lookupClassProp(cls, name); if (prop) - return prop->type; + return prop->type(); } else if (const UnionType* utv = get(type)) { @@ -2294,9 +2294,9 @@ TypeId TypeChecker::checkExprTable( if (it != expectedTable->props.end()) { Property expectedProp = it->second; - ErrorVec errors = tryUnify(exprType, expectedProp.type, scope, k->location); + ErrorVec errors = tryUnify(exprType, expectedProp.type(), scope, k->location); if (errors.empty()) - exprType = expectedProp.type; + exprType = expectedProp.type(); } else if (expectedTable->indexer && maybeString(expectedTable->indexer->indexType)) { @@ -2390,7 +2390,7 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp if (expectedTable) { if (auto prop = expectedTable->props.find(key->value.data); prop != expectedTable->props.end()) - expectedResultType = prop->second.type; + expectedResultType = prop->second.type(); else if (expectedIndexType && maybeString(*expectedIndexType)) expectedResultType = expectedIndexResultType; } @@ -2402,7 +2402,7 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp if (const TableType* ttv = get(follow(expectedOption))) { if (auto prop = ttv->props.find(key->value.data); prop != ttv->props.end()) - expectedResultTypes.push_back(prop->second.type); + expectedResultTypes.push_back(prop->second.type()); else if (ttv->indexer && maybeString(ttv->indexer->indexType)) expectedResultTypes.push_back(ttv->indexer->indexResultType); } @@ -3257,13 +3257,13 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex const auto& it = lhsTable->props.find(name); if (it != lhsTable->props.end()) { - return it->second.type; + return it->second.type(); } else if ((ctx == ValueContext::LValue && lhsTable->state == TableState::Unsealed) || lhsTable->state == TableState::Free) { TypeId theType = freshType(scope); Property& property = lhsTable->props[name]; - property.type = theType; + property.setType(theType); property.location = expr.indexLocation; return theType; } @@ -3303,7 +3303,7 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex return errorRecoveryType(scope); } - return prop->type; + return prop->type(); } else if (get(lhs)) { @@ -3351,7 +3351,7 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex reportError(TypeError{expr.location, UnknownProperty{exprType, value->value.data}}); return errorRecoveryType(scope); } - return prop->type; + return prop->type(); } } else if (FFlag::LuauAllowIndexClassParameters) @@ -3378,13 +3378,13 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex const auto& it = exprTable->props.find(value->value.data); if (it != exprTable->props.end()) { - return it->second.type; + return it->second.type(); } else if ((ctx == ValueContext::LValue && exprTable->state == TableState::Unsealed) || exprTable->state == TableState::Free) { TypeId resultType = freshType(scope); Property& property = exprTable->props[value->value.data]; - property.type = resultType; + property.setType(resultType); property.location = expr.index->location; return resultType; } @@ -3467,13 +3467,12 @@ TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName, T Name name = indexName->index.value; if (ttv->props.count(name)) - return ttv->props[name].type; + return ttv->props[name].type(); Property& property = ttv->props[name]; - - property.type = freshTy(); + property.setType(freshTy()); property.location = indexName->indexLocation; - return property.type; + return property.type(); } else if (funName.is()) return errorRecoveryType(scope); diff --git a/Analysis/src/TypeReduction.cpp b/Analysis/src/TypeReduction.cpp index 031c59936..b81cca7ba 100644 --- a/Analysis/src/TypeReduction.cpp +++ b/Analysis/src/TypeReduction.cpp @@ -70,7 +70,7 @@ TypeId TypeReductionMemoization::memoize(TypeId ty, TypeId reducedTy) else if (auto tt = get(reducedTy)) { for (auto& [k, p] : tt->props) - irreducible &= isIrreducible(p.type); + irreducible &= isIrreducible(p.type()); if (tt->indexer) { @@ -539,7 +539,7 @@ std::optional TypeReducer::intersectionType(TypeId left, TypeId right) // even if we have the corresponding property in the other one. if (auto other = t2->props.find(name); other != t2->props.end()) { - TypeId propTy = apply(&TypeReducer::intersectionType, prop.type, other->second.type); + TypeId propTy = apply(&TypeReducer::intersectionType, prop.type(), other->second.type()); if (get(propTy)) return builtinTypes->neverType; // { p : string } & { p : number } ~ { p : string & number } ~ { p : never } ~ never else @@ -554,7 +554,7 @@ std::optional TypeReducer::intersectionType(TypeId left, TypeId right) // TODO: And vice versa, t2 properties against t1 indexer if it exists, // even if we have the corresponding property in the other one. if (!t1->props.count(name)) - table->props[name] = {reduce(prop.type)}; // {} & { p : string & string } ~ { p : string } + table->props[name] = {reduce(prop.type())}; // {} & { p : string & string } ~ { p : string } } if (t1->indexer && t2->indexer) @@ -966,11 +966,11 @@ TypeId TypeReducer::tableType(TypeId ty) for (auto& [name, prop] : copied->props) { - TypeId propTy = reduce(prop.type); + TypeId propTy = reduce(prop.type()); if (get(propTy)) return builtinTypes->neverType; else - prop.type = propTy; + prop.setType(propTy); } if (copied->indexer) diff --git a/Analysis/src/TypeUtils.cpp b/Analysis/src/TypeUtils.cpp index e5029e587..9124e2fc5 100644 --- a/Analysis/src/TypeUtils.cpp +++ b/Analysis/src/TypeUtils.cpp @@ -34,7 +34,7 @@ std::optional findMetatableEntry( auto it = mtt->props.find(entry); if (it != mtt->props.end()) - return it->second.type; + return it->second.type(); else return std::nullopt; } @@ -49,7 +49,7 @@ std::optional findTablePropertyRespectingMeta( { const auto& it = tableType->props.find(name); if (it != tableType->props.end()) - return it->second.type; + return it->second.type(); } std::optional mtIndex = findMetatableEntry(builtinTypes, errors, ty, "__index", location); @@ -67,7 +67,7 @@ std::optional findTablePropertyRespectingMeta( { const auto& fit = itt->props.find(name); if (fit != itt->props.end()) - return fit->second.type; + return fit->second.type(); } else if (const auto& itf = get(index)) { diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 3f4e34f6d..3ca93591a 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -351,7 +351,7 @@ static std::optional> getTableMatchT { for (auto&& [name, prop] : ttv->props) { - if (auto sing = get(follow(prop.type))) + if (auto sing = get(follow(prop.type()))) return {{name, sing}}; } } @@ -2003,7 +2003,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) { auto subIter = subTable->props.find(propName); - if (subIter == subTable->props.end() && subTable->state == TableState::Unsealed && !isOptional(superProp.type)) + if (subIter == subTable->props.end() && subTable->state == TableState::Unsealed && !isOptional(superProp.type())) missingProperties.push_back(propName); } @@ -2044,7 +2044,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) variance = Invariant; Unifier innerState = makeChildUnifier(); - innerState.tryUnify_(r->second.type, prop.type); + innerState.tryUnify_(r->second.type(), prop.type()); checkChildUnifierTypeMismatch(innerState.errors, name, superTy, subTy); @@ -2060,7 +2060,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) variance = Invariant; Unifier innerState = makeChildUnifier(); - innerState.tryUnify_(subTable->indexer->indexResultType, prop.type); + innerState.tryUnify_(subTable->indexer->indexResultType, prop.type()); checkChildUnifierTypeMismatch(innerState.errors, name, superTy, subTy); @@ -2068,7 +2068,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) log.concat(std::move(innerState.log)); failure |= innerState.failure; } - else if (subTable->state == TableState::Unsealed && isOptional(prop.type)) + else if (subTable->state == TableState::Unsealed && isOptional(prop.type())) // This is sound because unsealed table types are precise, so `{ p : T } <: { p : T, q : U? }` // since if `t : { p : T }` then we are guaranteed that `t.q` is `nil`. // TODO: if the supertype is written to, the subtype may no longer be precise (alias analysis?) @@ -2123,7 +2123,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) variance = Invariant; Unifier innerState = makeChildUnifier(); - innerState.tryUnify_(superTable->indexer->indexResultType, prop.type); + innerState.tryUnify_(superTable->indexer->indexResultType, prop.type()); checkChildUnifierTypeMismatch(innerState.errors, name, superTy, subTy); @@ -2137,7 +2137,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) // TODO: file a JIRA // TODO: hopefully readonly/writeonly properties will fix this. Property clone = prop; - clone.type = deeplyOptional(clone.type); + clone.setType(deeplyOptional(clone.type())); PendingType* pendingSuper = log.queue(superTy); TableType* pendingSuperTtv = getMutable(pendingSuper); @@ -2297,7 +2297,7 @@ void Unifier::tryUnifyScalarShape(TypeId subTy, TypeId superTy, bool reversed) if (auto it = mttv->props.find("__index"); it != mttv->props.end()) { - TypeId ty = it->second.type; + TypeId ty = it->second.type(); Unifier child = makeChildUnifier(); child.tryUnify_(ty, superTy); @@ -2349,7 +2349,7 @@ TypeId Unifier::deeplyOptional(TypeId ty, std::unordered_map see result = types->addType(*ttv); TableType* resultTtv = getMutable(result); for (auto& [name, prop] : resultTtv->props) - prop.type = deeplyOptional(prop.type, seen); + prop.setType(deeplyOptional(prop.type(), seen)); return types->addType(UnionType{{builtinTypes->nilType, result}}); } else @@ -2394,7 +2394,7 @@ void Unifier::tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed) { if (std::optional mtPropTy = findTablePropertyRespectingMeta(superTy, propName)) { - innerState.tryUnify(prop.type, *mtPropTy); + innerState.tryUnify(prop.type(), *mtPropTy); } else { @@ -2505,7 +2505,7 @@ void Unifier::tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed) else { Unifier innerState = makeChildUnifier(); - innerState.tryUnify_(classProp->type, prop.type); + innerState.tryUnify_(classProp->type(), prop.type()); checkChildUnifierTypeMismatch(innerState.errors, propName, reversed ? subTy : superTy, reversed ? superTy : subTy); @@ -2674,7 +2674,7 @@ static void tryUnifyWithAny(std::vector& queue, Unifier& state, DenseHas else if (auto table = state.log.getMutable(ty)) { for (const auto& [_name, prop] : table->props) - queue.push_back(prop.type); + queue.push_back(prop.type()); if (table->indexer) { diff --git a/CLI/Repl.cpp b/CLI/Repl.cpp index bcf70f250..4303364cd 100644 --- a/CLI/Repl.cpp +++ b/CLI/Repl.cpp @@ -703,10 +703,26 @@ struct CompileStats size_t lines; size_t bytecode; size_t codegen; + + double readTime; + double miscTime; + double parseTime; + double compileTime; + double codegenTime; }; +static double recordDeltaTime(double& timer) +{ + double now = Luau::TimeTrace::getClock(); + double delta = now - timer; + timer = now; + return delta; +} + static bool compileFile(const char* name, CompileFormat format, CompileStats& stats) { + double currts = Luau::TimeTrace::getClock(); + std::optional source = readFile(name); if (!source) { @@ -714,6 +730,8 @@ static bool compileFile(const char* name, CompileFormat format, CompileStats& st return false; } + stats.readTime += recordDeltaTime(currts); + // NOTE: Normally, you should use Luau::compile or luau_compile (see lua_require as an example) // This function is much more complicated because it supports many output human-readable formats through internal interfaces @@ -753,6 +771,8 @@ static bool compileFile(const char* name, CompileFormat format, CompileStats& st bcb.setDumpSource(*source); } + stats.miscTime += recordDeltaTime(currts); + Luau::Allocator allocator; Luau::AstNameTable names(allocator); Luau::ParseResult result = Luau::Parser::parse(source->c_str(), source->size(), names, allocator); @@ -761,9 +781,11 @@ static bool compileFile(const char* name, CompileFormat format, CompileStats& st throw Luau::ParseErrors(result.errors); stats.lines += result.lines; + stats.parseTime += recordDeltaTime(currts); Luau::compileOrThrow(bcb, result, names, copts()); stats.bytecode += bcb.getBytecode().size(); + stats.compileTime += recordDeltaTime(currts); switch (format) { @@ -784,6 +806,7 @@ static bool compileFile(const char* name, CompileFormat format, CompileStats& st break; case CompileFormat::CodegenNull: stats.codegen += getCodegenAssembly(name, bcb.getBytecode(), options).size(); + stats.codegenTime += recordDeltaTime(currts); break; case CompileFormat::Null: break; @@ -998,18 +1021,18 @@ int replMain(int argc, char** argv) CompileStats stats = {}; int failed = 0; - double startTime = Luau::TimeTrace::getClock(); for (const std::string& path : files) failed += !compileFile(path.c_str(), compileFormat, stats); - double duration = Luau::TimeTrace::getClock() - startTime; - if (compileFormat == CompileFormat::Null) - printf("Compiled %d KLOC into %d KB bytecode in %.2fs\n", int(stats.lines / 1000), int(stats.bytecode / 1024), duration); + printf("Compiled %d KLOC into %d KB bytecode (read %.2fs, parse %.2fs, compile %.2fs)\n", int(stats.lines / 1000), + int(stats.bytecode / 1024), stats.readTime, stats.parseTime, stats.compileTime); else if (compileFormat == CompileFormat::CodegenNull) - printf("Compiled %d KLOC into %d KB bytecode => %d KB native code (%.2fx) in %.2fs\n", int(stats.lines / 1000), int(stats.bytecode / 1024), - int(stats.codegen / 1024), stats.bytecode == 0 ? 0.0 : double(stats.codegen) / double(stats.bytecode), duration); + printf("Compiled %d KLOC into %d KB bytecode => %d KB native code (%.2fx) (read %.2fs, parse %.2fs, compile %.2fs, codegen %.2fs)\n", + int(stats.lines / 1000), int(stats.bytecode / 1024), int(stats.codegen / 1024), + stats.bytecode == 0 ? 0.0 : double(stats.codegen) / double(stats.bytecode), stats.readTime, stats.parseTime, stats.compileTime, + stats.codegenTime); return failed ? 1 : 0; } diff --git a/CodeGen/include/Luau/AssemblyBuilderA64.h b/CodeGen/include/Luau/AssemblyBuilderA64.h index 1a5f51370..26be11c54 100644 --- a/CodeGen/include/Luau/AssemblyBuilderA64.h +++ b/CodeGen/include/Luau/AssemblyBuilderA64.h @@ -136,6 +136,7 @@ class AssemblyBuilderA64 void frinta(RegisterA64 dst, RegisterA64 src); void frintm(RegisterA64 dst, RegisterA64 src); void frintp(RegisterA64 dst, RegisterA64 src); + void fcvt(RegisterA64 dst, RegisterA64 src); void fcvtzs(RegisterA64 dst, RegisterA64 src); void fcvtzu(RegisterA64 dst, RegisterA64 src); void scvtf(RegisterA64 dst, RegisterA64 src); diff --git a/CodeGen/include/Luau/AssemblyBuilderX64.h b/CodeGen/include/Luau/AssemblyBuilderX64.h index bb3ebb287..e162cd3e4 100644 --- a/CodeGen/include/Luau/AssemblyBuilderX64.h +++ b/CodeGen/include/Luau/AssemblyBuilderX64.h @@ -154,7 +154,7 @@ class AssemblyBuilderX64 // Run final checks - void finalize(); + bool finalize(); // Places a label at current location and returns it Label setLabel(); diff --git a/CodeGen/include/Luau/IrBuilder.h b/CodeGen/include/Luau/IrBuilder.h index e6202c777..3b09359ec 100644 --- a/CodeGen/include/Luau/IrBuilder.h +++ b/CodeGen/include/Luau/IrBuilder.h @@ -36,6 +36,8 @@ struct IrBuilder // Source block that is cloned cannot use values coming in from a predecessor void clone(const IrBlock& source, bool removeCurrentTerminator); + IrOp undef(); + IrOp constBool(bool value); IrOp constInt(int value); IrOp constUint(unsigned value); diff --git a/CodeGen/include/Luau/IrData.h b/CodeGen/include/Luau/IrData.h index 47a973334..addd18f6b 100644 --- a/CodeGen/include/Luau/IrData.h +++ b/CodeGen/include/Luau/IrData.h @@ -10,6 +10,7 @@ #include #include +#include struct Proto; @@ -18,6 +19,12 @@ namespace Luau namespace CodeGen { +// IR extensions to LuauBuiltinFunction enum (these only exist inside IR, and start from 256 to avoid collisions) +enum +{ + LBF_IR_MATH_LOG2 = 256, +}; + // IR instruction command. // In the command description, following abbreviations are used: // * Rn - VM stack register slot, n in 0..254 @@ -112,7 +119,7 @@ enum class IrCmd : uint8_t ADD_INT, SUB_INT, - // Add/Sub/Mul/Div/Mod/Pow two double numbers + // Add/Sub/Mul/Div/Mod two double numbers // A, B: double // In final x64 lowering, B can also be Rn or Kn ADD_NUM, @@ -120,7 +127,6 @@ enum class IrCmd : uint8_t MUL_NUM, DIV_NUM, MOD_NUM, - POW_NUM, // Get the minimum/maximum of two numbers // If one of the values is NaN, 'B' is returned as the result @@ -192,8 +198,8 @@ enum class IrCmd : uint8_t // D: block (if false) JUMP_LT_INT, - // Jump if A >= B - // A, B: uint + // Jump if unsigned(A) >= unsigned(B) + // A, B: int // C: condition // D: block (if true) // E: block (if false) @@ -543,17 +549,17 @@ enum class IrCmd : uint8_t // A: operand of any type // Performs bitwise and/xor/or on two unsigned integers - // A, B: uint + // A, B: int BITAND_UINT, BITXOR_UINT, BITOR_UINT, // Performs bitwise not on an unsigned integer - // A: uint + // A: int BITNOT_UINT, // Performs bitwise shift/rotate on an unsigned integer - // A: uint (source) + // A: int (source) // B: int (shift amount) BITLSHIFT_UINT, BITRSHIFT_UINT, @@ -562,7 +568,7 @@ enum class IrCmd : uint8_t BITRROTATE_UINT, // Returns the number of consecutive zero bits in A starting from the left-most (most significant) bit. - // A: uint + // A: int BITCOUNTLZ_UINT, BITCOUNTRZ_UINT, @@ -621,6 +627,8 @@ enum class IrOpKind : uint32_t { None, + Undef, + // To reference a constant value Constant, @@ -710,6 +718,63 @@ struct IrInst // When IrInst operands are used, current instruction index is often required to track lifetime constexpr uint32_t kInvalidInstIdx = ~0u; +struct IrInstHash +{ + static const uint32_t m = 0x5bd1e995; + static const int r = 24; + + static uint32_t mix(uint32_t h, uint32_t k) + { + // MurmurHash2 step + k *= m; + k ^= k >> r; + k *= m; + + h *= m; + h ^= k; + + return h; + } + + static uint32_t mix(uint32_t h, IrOp op) + { + static_assert(sizeof(op) == sizeof(uint32_t)); + uint32_t k; + memcpy(&k, &op, sizeof(op)); + + return mix(h, k); + } + + size_t operator()(const IrInst& key) const + { + // MurmurHash2 unrolled + uint32_t h = 25; + + h = mix(h, uint32_t(key.cmd)); + h = mix(h, key.a); + h = mix(h, key.b); + h = mix(h, key.c); + h = mix(h, key.d); + h = mix(h, key.e); + h = mix(h, key.f); + + // MurmurHash2 tail + h ^= h >> 13; + h *= m; + h ^= h >> 15; + + return h; + } +}; + +struct IrInstEq +{ + bool operator()(const IrInst& a, const IrInst& b) const + { + return a.cmd == b.cmd && a.a == b.a && a.b == b.b && a.c == b.c && a.d == b.d && a.e == b.e && a.f == b.f; + } +}; + enum class IrBlockKind : uint8_t { Bytecode, diff --git a/CodeGen/include/Luau/IrUtils.h b/CodeGen/include/Luau/IrUtils.h index ed9dc91ae..3cf18cd48 100644 --- a/CodeGen/include/Luau/IrUtils.h +++ b/CodeGen/include/Luau/IrUtils.h @@ -135,7 +135,6 @@ inline bool hasResult(IrCmd cmd) case IrCmd::MUL_NUM: case IrCmd::DIV_NUM: case IrCmd::MOD_NUM: - case IrCmd::POW_NUM: case IrCmd::MIN_NUM: case IrCmd::MAX_NUM: case IrCmd::UNM_NUM: @@ -231,7 +230,7 @@ bool compare(double a, double b, IrCondition cond); // But it can also be successful on conditional control-flow, replacing it with an unconditional IrCmd::JUMP void foldConstants(IrBuilder& build, IrFunction& function, IrBlock& block, uint32_t instIdx); -uint32_t getNativeContextOffset(LuauBuiltinFunction bfid); +uint32_t getNativeContextOffset(int bfid); } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/include/Luau/OptimizeConstProp.h b/CodeGen/include/Luau/OptimizeConstProp.h index 619165d07..74ae131aa 100644 --- a/CodeGen/include/Luau/OptimizeConstProp.h +++ b/CodeGen/include/Luau/OptimizeConstProp.h @@ -10,8 +10,8 @@ namespace CodeGen struct IrBuilder; -void constPropInBlockChains(IrBuilder& build); -void createLinearBlocks(IrBuilder& build); +void constPropInBlockChains(IrBuilder& build, bool useValueNumbering); +void createLinearBlocks(IrBuilder& build, bool useValueNumbering); } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/include/Luau/RegisterA64.h b/CodeGen/include/Luau/RegisterA64.h index c3a9ae03f..d50369e37 100644 --- a/CodeGen/include/Luau/RegisterA64.h +++ b/CodeGen/include/Luau/RegisterA64.h @@ -17,6 +17,7 @@ enum class KindA64 : uint8_t none, w, // 32-bit GPR x, // 64-bit GPR + s, // 32-bit SIMD&FP scalar d, // 64-bit SIMD&FP scalar q, // 128-bit SIMD&FP vector }; @@ -128,6 +129,39 @@ constexpr RegisterA64 xzr{KindA64::x, 31}; constexpr RegisterA64 sp{KindA64::none, 31}; +constexpr RegisterA64 s0{KindA64::s, 0}; +constexpr RegisterA64 s1{KindA64::s, 1}; +constexpr RegisterA64 s2{KindA64::s, 2}; +constexpr RegisterA64 s3{KindA64::s, 3}; +constexpr RegisterA64 s4{KindA64::s, 4}; +constexpr RegisterA64 s5{KindA64::s, 5}; +constexpr RegisterA64 s6{KindA64::s, 6}; +constexpr RegisterA64 s7{KindA64::s, 7}; +constexpr RegisterA64 s8{KindA64::s, 8}; +constexpr RegisterA64 s9{KindA64::s, 9}; +constexpr RegisterA64 s10{KindA64::s, 10}; +constexpr RegisterA64 s11{KindA64::s, 11}; +constexpr RegisterA64 s12{KindA64::s, 12}; +constexpr RegisterA64 s13{KindA64::s, 13}; +constexpr RegisterA64 s14{KindA64::s, 14}; +constexpr RegisterA64 s15{KindA64::s, 15}; +constexpr RegisterA64 s16{KindA64::s, 16}; +constexpr RegisterA64 s17{KindA64::s, 17}; +constexpr RegisterA64 s18{KindA64::s, 18}; +constexpr RegisterA64 s19{KindA64::s, 19}; +constexpr RegisterA64 s20{KindA64::s, 20}; +constexpr RegisterA64 s21{KindA64::s, 21}; +constexpr RegisterA64 s22{KindA64::s, 22}; +constexpr RegisterA64 s23{KindA64::s, 23}; +constexpr RegisterA64 s24{KindA64::s, 24}; +constexpr RegisterA64 s25{KindA64::s, 25}; +constexpr RegisterA64 s26{KindA64::s, 26}; +constexpr RegisterA64 s27{KindA64::s, 27}; +constexpr RegisterA64 s28{KindA64::s, 28}; +constexpr RegisterA64 s29{KindA64::s, 29}; +constexpr RegisterA64 s30{KindA64::s, 30}; +constexpr RegisterA64 s31{KindA64::s, 31}; + constexpr RegisterA64 d0{KindA64::d, 0}; constexpr RegisterA64 d1{KindA64::d, 1}; constexpr RegisterA64 d2{KindA64::d, 2}; diff --git a/CodeGen/src/AssemblyBuilderA64.cpp b/CodeGen/src/AssemblyBuilderA64.cpp index ab84bef8d..33e3c9635 100644 --- a/CodeGen/src/AssemblyBuilderA64.cpp +++ b/CodeGen/src/AssemblyBuilderA64.cpp @@ -282,7 +282,7 @@ void AssemblyBuilderA64::ror(RegisterA64 dst, RegisterA64 src1, uint8_t src2) void AssemblyBuilderA64::ldr(RegisterA64 dst, AddressA64 src) { - LUAU_ASSERT(dst.kind == KindA64::x || dst.kind == KindA64::w || dst.kind == KindA64::d || dst.kind == KindA64::q); + LUAU_ASSERT(dst.kind == KindA64::x || dst.kind == KindA64::w || dst.kind == KindA64::s || dst.kind == KindA64::d || dst.kind == KindA64::q); switch (dst.kind) { @@ -292,6 +292,9 @@ void AssemblyBuilderA64::ldr(RegisterA64 dst, AddressA64 src) case KindA64::x: placeA("ldr", dst, src, 0b11100001, 0b11, /* sizelog= */ 3); break; + case KindA64::s: + placeA("ldr", dst, src, 0b11110001, 0b10, /* sizelog= */ 2); + break; case KindA64::d: placeA("ldr", dst, src, 0b11110001, 0b11, /* sizelog= */ 3); break; @@ -348,7 +351,7 @@ void AssemblyBuilderA64::ldp(RegisterA64 dst1, RegisterA64 dst2, AddressA64 src) void AssemblyBuilderA64::str(RegisterA64 src, AddressA64 dst) { - LUAU_ASSERT(src.kind == KindA64::x || src.kind == KindA64::w || src.kind == KindA64::d || src.kind == KindA64::q); + LUAU_ASSERT(src.kind == KindA64::x || src.kind == KindA64::w || src.kind == KindA64::s || src.kind == KindA64::d || src.kind == KindA64::q); switch (src.kind) { @@ -358,6 +361,9 @@ void AssemblyBuilderA64::str(RegisterA64 src, AddressA64 dst) case KindA64::x: placeA("str", src, dst, 0b11100000, 0b11, /* sizelog= */ 3); break; + case KindA64::s: + placeA("str", src, dst, 0b11110000, 0b10, /* sizelog= */ 2); + break; case KindA64::d: placeA("str", src, dst, 0b11110000, 0b11, /* sizelog= */ 3); break; @@ -570,6 +576,16 @@ void AssemblyBuilderA64::frintp(RegisterA64 dst, RegisterA64 src) placeR1("frintp", dst, src, 0b000'11110'01'1'001'001'10000); } +void AssemblyBuilderA64::fcvt(RegisterA64 dst, RegisterA64 src) +{ + if (dst.kind == KindA64::s && src.kind == KindA64::d) + placeR1("fcvt", dst, src, 0b11110'01'1'0001'00'10000); + else if (dst.kind == KindA64::d && src.kind == KindA64::s) + placeR1("fcvt", dst, src, 0b11110'00'1'0001'01'10000); + else + LUAU_ASSERT(!"Unexpected register kind"); +} + void AssemblyBuilderA64::fcvtzs(RegisterA64 dst, RegisterA64 src) { LUAU_ASSERT(dst.kind == KindA64::w || dst.kind == KindA64::x); @@ -1229,6 +1245,10 @@ void AssemblyBuilderA64::log(RegisterA64 reg) logAppend("x%d", reg.index); break; + case KindA64::s: + logAppend("s%d", reg.index); + break; + case KindA64::d: logAppend("d%d", reg.index); break; diff --git a/CodeGen/src/AssemblyBuilderX64.cpp b/CodeGen/src/AssemblyBuilderX64.cpp index b8f3940d0..4c9ad6df4 100644 --- a/CodeGen/src/AssemblyBuilderX64.cpp +++ b/CodeGen/src/AssemblyBuilderX64.cpp @@ -831,7 +831,7 @@ void AssemblyBuilderX64::vblendvpd(RegisterX64 dst, RegisterX64 src1, OperandX64 placeAvx("vblendvpd", dst, src1, mask, src3.index << 4, 0x4b, false, AVX_0F3A, AVX_66); } -void AssemblyBuilderX64::finalize() +bool AssemblyBuilderX64::finalize() { code.resize(codePos - code.data()); @@ -853,6 +853,8 @@ void AssemblyBuilderX64::finalize() data.resize(dataSize); finalized = true; + + return true; } Label AssemblyBuilderX64::setLabel() diff --git a/CodeGen/src/BitUtils.h b/CodeGen/src/BitUtils.h index 93f7cc8db..31fc4bfba 100644 --- a/CodeGen/src/BitUtils.h +++ b/CodeGen/src/BitUtils.h @@ -32,5 +32,25 @@ inline int countrz(uint32_t n) #endif } +inline int lrotate(uint32_t u, int s) +{ + // MSVC doesn't recognize the rotate form that is UB-safe +#ifdef _MSC_VER + return _rotl(u, s); +#else + return (u << (s & 31)) | (u >> ((32 - s) & 31)); +#endif +} + +inline int rrotate(uint32_t u, int s) +{ + // MSVC doesn't recognize the rotate form that is UB-safe +#ifdef _MSC_VER + return _rotr(u, s); +#else + return (u >> (s & 31)) | (u << ((32 - s) & 31)); +#endif +} + } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/CodeGen.cpp b/CodeGen/src/CodeGen.cpp index a86d5a202..f0be5b3d8 100644 --- a/CodeGen/src/CodeGen.cpp +++ b/CodeGen/src/CodeGen.cpp @@ -51,6 +51,7 @@ LUAU_FASTFLAGVARIABLE(DebugCodegenNoOpt, false) LUAU_FASTFLAGVARIABLE(DebugCodegenOptSize, false) +LUAU_FASTFLAGVARIABLE(DebugCodegenSkipNumbering, false) namespace Luau { @@ -59,21 +60,33 @@ namespace CodeGen static NativeProto* createNativeProto(Proto* proto, const IrBuilder& ir) { - NativeProto* result = new NativeProto(); + int sizecode = proto->sizecode; + int sizecodeAlloc = (sizecode + 1) & ~1; // align uint32_t array to 8 bytes so that NativeProto is aligned to 8 bytes + void* memory = ::operator new(sizeof(NativeProto) + sizecodeAlloc * sizeof(uint32_t)); + NativeProto* result = new (static_cast(memory) + sizecodeAlloc * sizeof(uint32_t)) NativeProto; result->proto = proto; - result->instTargets = new uintptr_t[proto->sizecode]; - for (int i = 0; i < proto->sizecode; i++) - { - auto [irLocation, asmLocation] = ir.function.bcMapping[i]; + uint32_t* instOffsets = result->instOffsets; - result->instTargets[i] = irLocation == ~0u ? 0 : asmLocation; + for (int i = 0; i < sizecode; i++) + { + // instOffsets uses negative indexing for optimal codegen for RETURN opcode + instOffsets[-i] = ir.function.bcMapping[i].asmLocation; } return result; } +static void destroyNativeProto(NativeProto* nativeProto) +{ + int sizecode = nativeProto->proto->sizecode; + int sizecodeAlloc = (sizecode + 1) & ~1; // align uint32_t array to 8 bytes so that NativeProto is aligned to 8 bytes + void* memory = reinterpret_cast(nativeProto) - sizecodeAlloc * sizeof(uint32_t); + + ::operator delete(memory); +} + template static bool lowerImpl(AssemblyBuilder& build, IrLowering& lowering, IrFunction& function, int bytecodeid, AssemblyOptions options) { @@ -95,30 +108,19 @@ static bool lowerImpl(AssemblyBuilder& build, IrLowering& lowering, IrFunction& return a.start < b.start; }); - DenseHashMap bcLocations{~0u}; + // For each IR instruction that begins a bytecode instruction, which bytecode instruction is it? + std::vector bcLocations(function.instructions.size() + 1, ~0u); - // Create keys for IR assembly locations that original bytecode instruction are interested in - for (const auto& [irLocation, asmLocation] : function.bcMapping) + for (size_t i = 0; i < function.bcMapping.size(); ++i) { + uint32_t irLocation = function.bcMapping[i].irLocation; + if (irLocation != ~0u) - bcLocations[irLocation] = 0; + bcLocations[irLocation] = uint32_t(i); } - DenseHashMap indexIrToBc{~0u}; bool outputEnabled = options.includeAssembly || options.includeIr; - if (outputEnabled && options.annotator) - { - // Create reverse mapping from IR location to bytecode location - for (size_t i = 0; i < function.bcMapping.size(); ++i) - { - uint32_t irLocation = function.bcMapping[i].irLocation; - - if (irLocation != ~0u) - indexIrToBc[irLocation] = uint32_t(i); - } - } - IrToStringContext ctx{build.text, function.blocks, function.constants, function.cfg}; // We use this to skip outlined fallback blocks from IR/asm text output @@ -164,18 +166,19 @@ static bool lowerImpl(AssemblyBuilder& build, IrLowering& lowering, IrFunction& { LUAU_ASSERT(index < function.instructions.size()); + uint32_t bcLocation = bcLocations[index]; + // If IR instruction is the first one for the original bytecode, we can annotate it with source code text - if (outputEnabled && options.annotator) + if (outputEnabled && options.annotator && bcLocation != ~0u) { - if (uint32_t* bcIndex = indexIrToBc.find(index)) - options.annotator(options.annotatorContext, build.text, bytecodeid, *bcIndex); + options.annotator(options.annotatorContext, build.text, bytecodeid, bcLocation); } // If bytecode needs the location of this instruction for jumps, record it - if (uint32_t* bcLocation = bcLocations.find(index)) + if (bcLocation != ~0u) { Label label = (index == block.start) ? block.label : build.setLabel(); - *bcLocation = build.getLabelOffset(label); + function.bcMapping[bcLocation].asmLocation = build.getLabelOffset(label); } IrInst& inst = function.instructions[index]; @@ -227,13 +230,6 @@ static bool lowerImpl(AssemblyBuilder& build, IrLowering& lowering, IrFunction& build.logAppend("; skipping %u bytes of outlined code\n", unsigned((build.getCodeSize() - codeSize) * sizeof(build.code[0]))); } - // Copy assembly locations of IR instructions that are mapped to bytecode instructions - for (auto& [irLocation, asmLocation] : function.bcMapping) - { - if (irLocation != ~0u) - asmLocation = bcLocations[irLocation]; - } - return true; } @@ -293,10 +289,12 @@ static NativeProto* assembleFunction(AssemblyBuilder& build, NativeState& data, if (!FFlag::DebugCodegenNoOpt) { - constPropInBlockChains(ir); + bool useValueNumbering = !FFlag::DebugCodegenSkipNumbering; + + constPropInBlockChains(ir, useValueNumbering); if (!FFlag::DebugCodegenOptSize) - createLinearBlocks(ir); + createLinearBlocks(ir, useValueNumbering); } if (!lowerIr(build, ir, data, helpers, proto, options)) @@ -313,12 +311,6 @@ static NativeProto* assembleFunction(AssemblyBuilder& build, NativeState& data, return createNativeProto(proto, ir); } -static void destroyNativeProto(NativeProto* nativeProto) -{ - delete[] nativeProto->instTargets; - delete nativeProto; -} - static void onCloseState(lua_State* L) { destroyNativeState(L); @@ -347,7 +339,9 @@ static int onEnter(lua_State* L, Proto* proto) bool (*gate)(lua_State*, Proto*, uintptr_t, NativeContext*) = (bool (*)(lua_State*, Proto*, uintptr_t, NativeContext*))data->context.gateEntry; NativeProto* nativeProto = getProtoExecData(proto); - uintptr_t target = nativeProto->instTargets[L->ci->savedpc - proto->code]; + + // instOffsets uses negative indexing for optimal codegen for RETURN opcode + uintptr_t target = nativeProto->instBase + nativeProto->instOffsets[-(L->ci->savedpc - proto->code)]; // Returns 1 to finish the function in the VM return gate(L, proto, target, &data->context); @@ -517,7 +511,14 @@ void compile(lua_State* L, int idx) if (NativeProto* np = assembleFunction(build, *data, helpers, p, {})) results.push_back(np); - build.finalize(); + // Very large modules might result in overflowing a jump offset; in this case we currently abandon the entire module + if (!build.finalize()) + { + for (NativeProto* result : results) + destroyNativeProto(result); + + return; + } // If no functions were assembled, we don't need to allocate/copy executable pages for helpers if (results.empty()) @@ -535,14 +536,11 @@ void compile(lua_State* L, int idx) return; } - // Relocate instruction offsets + // Record instruction base address; at runtime, instOffsets[] will be used as offsets from instBase for (NativeProto* result : results) { - for (int i = 0; i < result->proto->sizecode; i++) - result->instTargets[i] += uintptr_t(codeStart); - - LUAU_ASSERT(result->proto->sizecode); - result->entryTarget = result->instTargets[0]; + result->instBase = uintptr_t(codeStart); + result->entryTarget = uintptr_t(codeStart) + result->instOffsets[0]; } // Link native proto objects to Proto; the memory is now managed by VM and will be freed via onDestroyFunction @@ -579,7 +577,8 @@ std::string getAssembly(lua_State* L, int idx, AssemblyOptions options) if (NativeProto* np = assembleFunction(build, data, helpers, p, options)) destroyNativeProto(np); - build.finalize(); + if (!build.finalize()) + return std::string(); if (options.outputBinary) return std::string(reinterpret_cast(build.code.data()), reinterpret_cast(build.code.data() + build.code.size())) + diff --git a/CodeGen/src/CodeGenA64.cpp b/CodeGen/src/CodeGenA64.cpp index 7f29beb2b..415cfc926 100644 --- a/CodeGen/src/CodeGenA64.cpp +++ b/CodeGen/src/CodeGenA64.cpp @@ -35,11 +35,17 @@ static void emitInterrupt(AssemblyBuilderA64& build) { // x0 = pc offset // x1 = return address in native code - // x2 = interrupt + + Label skip; // Stash return address in rBase; we need to reload rBase anyway build.mov(rBase, x1); + // Load interrupt handler; it may be nullptr in case the update raced with the check before we got here + build.ldr(x2, mem(rState, offsetof(lua_State, global))); + build.ldr(x2, mem(x2, offsetof(global_State, cb.interrupt))); + build.cbz(x2, skip); + // Update savedpc; required in case interrupt errors build.add(x0, rCode, x0); build.ldr(x1, mem(rState, offsetof(lua_State, ci))); @@ -51,7 +57,6 @@ static void emitInterrupt(AssemblyBuilderA64& build) build.blr(x2); // Check if we need to exit - Label skip; build.ldrb(w0, mem(rState, offsetof(lua_State, status))); build.cbz(w0, skip); @@ -92,11 +97,11 @@ static void emitReentry(AssemblyBuilderA64& build, ModuleHelpers& helpers) // Get instruction index from instruction pointer // To get instruction index from instruction pointer, we need to divide byte offset by 4 - // But we will actually need to scale instruction index by 8 back to byte offset later so it cancels out + // But we will actually need to scale instruction index by 4 back to byte offset later so it cancels out + // Note that we're computing negative offset here (code-savedpc) so that we can add it to NativeProto address, as we use reverse indexing build.ldr(x2, mem(rState, offsetof(lua_State, ci))); // L->ci build.ldr(x2, mem(x2, offsetof(CallInfo, savedpc))); // L->ci->savedpc - build.sub(x2, x2, rCode); - build.add(x2, x2, x2); // TODO: this would not be necessary if we supported shifted register offsets in loads + build.sub(x2, rCode, x2); // We need to check if the new function can be executed natively // TODO: This can be done earlier in the function flow, to reduce the JIT->VM transition penalty @@ -104,8 +109,10 @@ static void emitReentry(AssemblyBuilderA64& build, ModuleHelpers& helpers) build.cbz(x1, helpers.exitContinueVm); // Get new instruction location and jump to it - build.ldr(x1, mem(x1, offsetof(NativeProto, instTargets))); - build.ldr(x1, mem(x1, x2)); + LUAU_ASSERT(offsetof(NativeProto, instOffsets) == 0); + build.ldr(w2, mem(x1, x2)); + build.ldr(x1, mem(x1, offsetof(NativeProto, instBase))); + build.add(x1, x1, x2); build.br(x1); } diff --git a/CodeGen/src/EmitBuiltinsX64.cpp b/CodeGen/src/EmitBuiltinsX64.cpp index 3e6d26b45..af4c529a3 100644 --- a/CodeGen/src/EmitBuiltinsX64.cpp +++ b/CodeGen/src/EmitBuiltinsX64.cpp @@ -18,24 +18,10 @@ namespace CodeGen namespace X64 { -void emitBuiltinMathLog(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) -{ - regs.assertAllFree(); - build.vmovsd(xmm0, luauRegValue(arg)); - - // TODO: IR builtin lowering assumes that the only valid 2-argument call is log2; ideally, we use a less hacky way to indicate that - if (nparams == 2) - build.call(qword[rNativeContext + offsetof(NativeContext, libm_log2)]); - else - build.call(qword[rNativeContext + offsetof(NativeContext, libm_log)]); - - build.vmovsd(luauRegValue(ra), xmm0); -} - -void emitBuiltinMathLdexp(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) +static void emitBuiltinMathLdexp(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int arg, OperandX64 arg2) { ScopedRegX64 tmp{regs, SizeX64::qword}; - build.vcvttsd2si(tmp.reg, qword[args + offsetof(TValue, value)]); + build.vcvttsd2si(tmp.reg, arg2); IrCallWrapperX64 callWrap(regs, build); callWrap.addArgument(SizeX64::xmmword, luauRegValue(arg)); @@ -45,7 +31,7 @@ void emitBuiltinMathLdexp(IrRegAllocX64& regs, AssemblyBuilderX64& build, int np build.vmovsd(luauRegValue(ra), xmm0); } -void emitBuiltinMathFrexp(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) +static void emitBuiltinMathFrexp(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int arg, int nresults) { IrCallWrapperX64 callWrap(regs, build); callWrap.addArgument(SizeX64::xmmword, luauRegValue(arg)); @@ -61,7 +47,7 @@ void emitBuiltinMathFrexp(IrRegAllocX64& regs, AssemblyBuilderX64& build, int np } } -void emitBuiltinMathModf(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) +static void emitBuiltinMathModf(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int arg, int nresults) { IrCallWrapperX64 callWrap(regs, build); callWrap.addArgument(SizeX64::xmmword, luauRegValue(arg)); @@ -75,7 +61,7 @@ void emitBuiltinMathModf(IrRegAllocX64& regs, AssemblyBuilderX64& build, int npa build.vmovsd(luauRegValue(ra + 1), xmm0); } -void emitBuiltinMathSign(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) +static void emitBuiltinMathSign(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int arg) { ScopedRegX64 tmp0{regs, SizeX64::xmmword}; ScopedRegX64 tmp1{regs, SizeX64::xmmword}; @@ -102,7 +88,7 @@ void emitBuiltinMathSign(IrRegAllocX64& regs, AssemblyBuilderX64& build, int npa build.vmovsd(luauRegValue(ra), tmp0.reg); } -void emitBuiltinType(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) +static void emitBuiltinType(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int arg) { ScopedRegX64 tmp0{regs, SizeX64::qword}; ScopedRegX64 tag{regs, SizeX64::dword}; @@ -115,7 +101,7 @@ void emitBuiltinType(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams build.mov(luauRegValue(ra), tmp0.reg); } -void emitBuiltinTypeof(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) +static void emitBuiltinTypeof(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int arg) { IrCallWrapperX64 callWrap(regs, build); callWrap.addArgument(SizeX64::qword, rState); @@ -125,38 +111,28 @@ void emitBuiltinTypeof(IrRegAllocX64& regs, AssemblyBuilderX64& build, int npara build.mov(luauRegValue(ra), rax); } -void emitBuiltin(IrRegAllocX64& regs, AssemblyBuilderX64& build, int bfid, int ra, int arg, IrOp args, int nparams, int nresults) +void emitBuiltin(IrRegAllocX64& regs, AssemblyBuilderX64& build, int bfid, int ra, int arg, OperandX64 arg2, int nparams, int nresults) { - OperandX64 argsOp = 0; - - if (args.kind == IrOpKind::VmReg) - argsOp = luauRegAddress(vmRegOp(args)); - else if (args.kind == IrOpKind::VmConst) - argsOp = luauConstantAddress(vmConstOp(args)); - switch (bfid) { - case LBF_MATH_LOG: - LUAU_ASSERT((nparams == 1 || nparams == 2) && nresults == 1); - return emitBuiltinMathLog(regs, build, nparams, ra, arg, argsOp, nresults); case LBF_MATH_LDEXP: LUAU_ASSERT(nparams == 2 && nresults == 1); - return emitBuiltinMathLdexp(regs, build, nparams, ra, arg, argsOp, nresults); + return emitBuiltinMathLdexp(regs, build, ra, arg, arg2); case LBF_MATH_FREXP: LUAU_ASSERT(nparams == 1 && (nresults == 1 || nresults == 2)); - return emitBuiltinMathFrexp(regs, build, nparams, ra, arg, argsOp, nresults); + return emitBuiltinMathFrexp(regs, build, ra, arg, nresults); case LBF_MATH_MODF: LUAU_ASSERT(nparams == 1 && (nresults == 1 || nresults == 2)); - return emitBuiltinMathModf(regs, build, nparams, ra, arg, argsOp, nresults); + return emitBuiltinMathModf(regs, build, ra, arg, nresults); case LBF_MATH_SIGN: LUAU_ASSERT(nparams == 1 && nresults == 1); - return emitBuiltinMathSign(regs, build, nparams, ra, arg, argsOp, nresults); + return emitBuiltinMathSign(regs, build, ra, arg); case LBF_TYPE: LUAU_ASSERT(nparams == 1 && nresults == 1); - return emitBuiltinType(regs, build, nparams, ra, arg, argsOp, nresults); + return emitBuiltinType(regs, build, ra, arg); case LBF_TYPEOF: LUAU_ASSERT(nparams == 1 && nresults == 1); - return emitBuiltinTypeof(regs, build, nparams, ra, arg, argsOp, nresults); + return emitBuiltinTypeof(regs, build, ra, arg); default: LUAU_ASSERT(!"Missing x64 lowering"); break; diff --git a/CodeGen/src/EmitBuiltinsX64.h b/CodeGen/src/EmitBuiltinsX64.h index 5925a2b3d..cd8b52517 100644 --- a/CodeGen/src/EmitBuiltinsX64.h +++ b/CodeGen/src/EmitBuiltinsX64.h @@ -16,7 +16,7 @@ class AssemblyBuilderX64; struct OperandX64; struct IrRegAllocX64; -void emitBuiltin(IrRegAllocX64& regs, AssemblyBuilderX64& build, int bfid, int ra, int arg, IrOp args, int nparams, int nresults); +void emitBuiltin(IrRegAllocX64& regs, AssemblyBuilderX64& build, int bfid, int ra, int arg, OperandX64 arg2, int nparams, int nresults); } // namespace X64 } // namespace CodeGen diff --git a/CodeGen/src/EmitCommonA64.h b/CodeGen/src/EmitCommonA64.h index f590df91a..9e89b1c09 100644 --- a/CodeGen/src/EmitCommonA64.h +++ b/CodeGen/src/EmitCommonA64.h @@ -13,8 +13,9 @@ // Arguments: x0-x7, v0-v7 // Return: x0, v0 (or x8 that points to the address of the resulting structure) // Volatile: x9-x15, v16-v31 ("caller-saved", any call may change them) +// Intra-procedure-call temporary: x16-x17 (any call or relocated jump may change them, as linker may point branches to veneers to perform long jumps) // Non-volatile: x19-x28, v8-v15 ("callee-saved", preserved after calls, only bottom half of SIMD registers is preserved!) -// Reserved: x16-x18: reserved for linker/platform use; x29: frame pointer (unless omitted); x30: link register; x31: stack pointer +// Reserved: x18: reserved for platform use; x29: frame pointer (unless omitted); x30: link register; x31: stack pointer namespace Luau { diff --git a/CodeGen/src/EmitInstructionX64.cpp b/CodeGen/src/EmitInstructionX64.cpp index 9a10bfdc1..19f0cb86d 100644 --- a/CodeGen/src/EmitInstructionX64.cpp +++ b/CodeGen/src/EmitInstructionX64.cpp @@ -308,12 +308,15 @@ void emitInstReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers, int ra, i build.mov(rax, qword[cip + offsetof(CallInfo, savedpc)]); // To get instruction index from instruction pointer, we need to divide byte offset by 4 - // But we will actually need to scale instruction index by 8 back to byte offset later so it cancels out - build.sub(rax, rdx); + // But we will actually need to scale instruction index by 4 back to byte offset later so it cancels out + // Note that we're computing negative offset here (code-savedpc) so that we can add it to NativeProto address, as we use reverse indexing + build.sub(rdx, rax); // Get new instruction location and jump to it - build.mov(rdx, qword[execdata + offsetof(NativeProto, instTargets)]); - build.jmp(qword[rdx + rax * 2]); + LUAU_ASSERT(offsetof(NativeProto, instOffsets) == 0); + build.mov(edx, dword[execdata + rdx]); + build.add(rdx, qword[execdata + offsetof(NativeProto, instBase)]); + build.jmp(rdx); } void emitInstSetList(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int rb, int count, uint32_t index) diff --git a/CodeGen/src/IrAnalysis.cpp b/CodeGen/src/IrAnalysis.cpp index f3870e96b..efe9fcc06 100644 --- a/CodeGen/src/IrAnalysis.cpp +++ b/CodeGen/src/IrAnalysis.cpp @@ -529,7 +529,8 @@ static void computeCfgLiveInOutRegSets(IrFunction& function) RegisterSet& outRs = info.out[blockIdx]; // Current block has to provide all registers in successor blocks - for (uint32_t succIdx : successors(info, blockIdx)) + BlockIteratorWrapper successorsIt = successors(info, blockIdx); + for (uint32_t succIdx : successorsIt) { IrBlock& succ = function.blocks[succIdx]; @@ -538,7 +539,11 @@ static void computeCfgLiveInOutRegSets(IrFunction& function) // This is because fallback blocks define an alternative implementation of the same operations // This can cause the current block to define more registers that actually were available at fallback entry if (curr.kind != IrBlockKind::Fallback && succ.kind == IrBlockKind::Fallback) + { + // If this is the only successor, this skip will not be valid + LUAU_ASSERT(successorsIt.size() != 1); continue; + } const RegisterSet& succRs = info.in[succIdx]; diff --git a/CodeGen/src/IrBuilder.cpp b/CodeGen/src/IrBuilder.cpp index 48c0e25c0..86986fe92 100644 --- a/CodeGen/src/IrBuilder.cpp +++ b/CodeGen/src/IrBuilder.cpp @@ -30,7 +30,7 @@ void IrBuilder::buildFunctionIr(Proto* proto) // Rebuild original control flow blocks rebuildBytecodeBasicBlocks(proto); - function.bcMapping.resize(proto->sizecode, {~0u, 0}); + function.bcMapping.resize(proto->sizecode, {~0u, ~0u}); // Translate all instructions to IR inside blocks for (int i = 0; i < proto->sizecode;) @@ -41,7 +41,7 @@ void IrBuilder::buildFunctionIr(Proto* proto) int nexti = i + getOpLength(op); LUAU_ASSERT(nexti <= proto->sizecode); - function.bcMapping[i] = {uint32_t(function.instructions.size()), 0}; + function.bcMapping[i] = {uint32_t(function.instructions.size()), ~0u}; // Begin new block at this instruction if it was in the bytecode or requested during translation if (instIndexToBlock[i] != kNoAssociatedBlockIndex) @@ -293,7 +293,7 @@ void IrBuilder::translateInst(LuauOpcode op, const Instruction* pc, int i) int skip = LUAU_INSN_C(*pc); IrOp next = blockAtInst(i + skip + 2); - translateFastCallN(*this, pc, i, true, 1, constBool(false), next); + translateFastCallN(*this, pc, i, true, 1, undef(), next); activeFastcallFallback = true; fastcallFallbackReturn = next; @@ -496,6 +496,11 @@ void IrBuilder::clone(const IrBlock& source, bool removeCurrentTerminator) } } +IrOp IrBuilder::undef() +{ + return {IrOpKind::Undef, 0}; +} + IrOp IrBuilder::constBool(bool value) { IrConst constant; diff --git a/CodeGen/src/IrDump.cpp b/CodeGen/src/IrDump.cpp index 084a72771..50c1848ea 100644 --- a/CodeGen/src/IrDump.cpp +++ b/CodeGen/src/IrDump.cpp @@ -120,8 +120,6 @@ const char* getCmdName(IrCmd cmd) return "DIV_NUM"; case IrCmd::MOD_NUM: return "MOD_NUM"; - case IrCmd::POW_NUM: - return "POW_NUM"; case IrCmd::MIN_NUM: return "MIN_NUM"; case IrCmd::MAX_NUM: @@ -359,6 +357,9 @@ void toString(IrToStringContext& ctx, IrOp op) { case IrOpKind::None: break; + case IrOpKind::Undef: + append(ctx.result, "undef"); + break; case IrOpKind::Constant: toString(ctx.result, ctx.constants[op.index]); break; @@ -398,7 +399,10 @@ void toString(std::string& result, IrConst constant) append(result, "%uu", constant.valueUint); break; case IrConstKind::Double: - append(result, "%.17g", constant.valueDouble); + if (constant.valueDouble != constant.valueDouble) + append(result, "nan"); + else + append(result, "%.17g", constant.valueDouble); break; case IrConstKind::Tag: result.append(getTagName(constant.valueTag)); diff --git a/CodeGen/src/IrLoweringA64.cpp b/CodeGen/src/IrLoweringA64.cpp index 18df26b1f..7fd684b4d 100644 --- a/CodeGen/src/IrLoweringA64.cpp +++ b/CodeGen/src/IrLoweringA64.cpp @@ -109,46 +109,6 @@ static void emitFallback(AssemblyBuilderA64& build, int op, int pcpos) emitUpdateBase(build); } -static void emitInvokeLibm1(AssemblyBuilderA64& build, size_t func, int res, int arg) -{ - build.ldr(d0, mem(rBase, arg * sizeof(TValue) + offsetof(TValue, value.n))); - build.ldr(x0, mem(rNativeContext, uint32_t(func))); - build.blr(x0); - build.str(d0, mem(rBase, res * sizeof(TValue) + offsetof(TValue, value.n))); -} - -static void emitInvokeLibm2(AssemblyBuilderA64& build, size_t func, int res, int arg, IrOp args, bool argsInt = false) -{ - if (args.kind == IrOpKind::VmReg) - build.ldr(d1, mem(rBase, args.index * sizeof(TValue) + offsetof(TValue, value.n))); - else if (args.kind == IrOpKind::VmConst) - { - size_t constantOffset = args.index * sizeof(TValue) + offsetof(TValue, value.n); - - // Note: cumulative offset is guaranteed to be divisible by 8 (since we're loading a double); we can use that to expand the useful range that - // doesn't require temporaries - if (constantOffset / 8 <= AddressA64::kMaxOffset) - { - build.ldr(d1, mem(rConstants, int(constantOffset))); - } - else - { - emitAddOffset(build, x0, rConstants, constantOffset); - build.ldr(d1, x0); - } - } - else - LUAU_ASSERT(!"Unsupported instruction form"); - - if (argsInt) - build.fcvtzs(w0, d1); - - build.ldr(d0, mem(rBase, arg * sizeof(TValue) + offsetof(TValue, value.n))); - build.ldr(x1, mem(rNativeContext, uint32_t(func))); - build.blr(x1); - build.str(d0, mem(rBase, res * sizeof(TValue) + offsetof(TValue, value.n))); -} - static void emitInvokeLibm1P(AssemblyBuilderA64& build, size_t func, int arg) { build.ldr(d0, mem(rBase, arg * sizeof(TValue) + offsetof(TValue, value.n))); @@ -157,21 +117,46 @@ static void emitInvokeLibm1P(AssemblyBuilderA64& build, size_t func, int arg) build.blr(x1); } -static bool emitBuiltin(AssemblyBuilderA64& build, IrRegAllocA64& regs, int bfid, int res, int arg, IrOp args, int nparams, int nresults) +static bool emitBuiltin( + AssemblyBuilderA64& build, IrFunction& function, IrRegAllocA64& regs, int bfid, int res, int arg, IrOp args, int nparams, int nresults) { switch (bfid) { - case LBF_MATH_LOG: - LUAU_ASSERT((nparams == 1 || nparams == 2) && nresults == 1); - // TODO: IR builtin lowering assumes that the only valid 2-argument call is log2; ideally, we use a less hacky way to indicate that - if (nparams == 2) - emitInvokeLibm1(build, offsetof(NativeContext, libm_log2), res, arg); - else - emitInvokeLibm1(build, offsetof(NativeContext, libm_log), res, arg); - return true; case LBF_MATH_LDEXP: LUAU_ASSERT(nparams == 2 && nresults == 1); - emitInvokeLibm2(build, offsetof(NativeContext, libm_ldexp), res, arg, args, /* argsInt= */ true); + + if (args.kind == IrOpKind::VmReg) + { + build.ldr(d1, mem(rBase, args.index * sizeof(TValue) + offsetof(TValue, value.n))); + build.fcvtzs(w0, d1); + } + else if (args.kind == IrOpKind::VmConst) + { + size_t constantOffset = args.index * sizeof(TValue) + offsetof(TValue, value.n); + + // Note: cumulative offset is guaranteed to be divisible by 8 (since we're loading a double); we can use that to expand the useful range + // that doesn't require temporaries + if (constantOffset / 8 <= AddressA64::kMaxOffset) + { + build.ldr(d1, mem(rConstants, int(constantOffset))); + } + else + { + emitAddOffset(build, x0, rConstants, constantOffset); + build.ldr(d1, x0); + } + + build.fcvtzs(w0, d1); + } + else if (args.kind == IrOpKind::Constant) + build.mov(w0, int(function.doubleOp(args))); + else if (args.kind != IrOpKind::Undef) + LUAU_ASSERT(!"Unsupported instruction form"); + + build.ldr(d0, mem(rBase, arg * sizeof(TValue) + offsetof(TValue, value.n))); + build.ldr(x1, mem(rNativeContext, offsetof(NativeContext, libm_ldexp))); + build.blr(x1); + build.str(d0, mem(rBase, res * sizeof(TValue) + offsetof(TValue, value.n))); return true; case LBF_MATH_FREXP: LUAU_ASSERT(nparams == 1 && (nresults == 1 || nresults == 2)); @@ -233,14 +218,22 @@ IrLoweringA64::IrLoweringA64(AssemblyBuilderA64& build, ModuleHelpers& helpers, , data(data) , proto(proto) , function(function) - , regs(function, {{x0, x15}, {q0, q7}, {q16, q31}}) + , regs(function, {{x0, x15}, {x16, x17}, {q0, q7}, {q16, q31}}) + , valueTracker(function) { // In order to allocate registers during lowering, we need to know where instruction results are last used updateLastUseLocations(function); + + valueTracker.setRestoreCallack(this, [](void* context, IrInst& inst) { + IrLoweringA64* self = static_cast(context); + self->regs.restoreReg(self->build, inst); + }); } void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) { + valueTracker.beforeInstLowering(inst); + switch (inst.cmd) { case IrCmd::LOAD_TAG: @@ -299,7 +292,6 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) } else if (inst.b.kind == IrOpKind::Constant) { - // TODO: refactor into a common helper? can't use emitAddOffset because we need a temp register if (intOp(inst.b) * sizeof(TValue) <= AssemblyBuilderA64::kMaxImmediate) { build.add(inst.regA64, inst.regA64, uint16_t(intOp(inst.b) * sizeof(TValue))); @@ -387,6 +379,24 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.str(temp, addr); break; } + case IrCmd::STORE_VECTOR: + { + RegisterA64 temp1 = tempDouble(inst.b); + RegisterA64 temp2 = tempDouble(inst.c); + RegisterA64 temp3 = tempDouble(inst.d); + RegisterA64 temp4 = regs.allocTemp(KindA64::s); + + AddressA64 addr = tempAddr(inst.a, offsetof(TValue, value)); + LUAU_ASSERT(addr.kind == AddressKindA64::imm && addr.data % 4 == 0 && unsigned(addr.data + 8) / 4 <= AddressA64::kMaxOffset); + + build.fcvt(temp4, temp1); + build.str(temp4, AddressA64(addr.base, addr.data + 0)); + build.fcvt(temp4, temp2); + build.str(temp4, AddressA64(addr.base, addr.data + 4)); + build.fcvt(temp4, temp3); + build.str(temp4, AddressA64(addr.base, addr.data + 8)); + break; + } case IrCmd::STORE_TVALUE: { AddressA64 addr = tempAddr(inst.a, 0); @@ -400,6 +410,8 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.a, inst.b}); if (inst.b.kind == IrOpKind::Constant && unsigned(intOp(inst.b)) <= AssemblyBuilderA64::kMaxImmediate) build.add(inst.regA64, regOp(inst.a), uint16_t(intOp(inst.b))); + else if (inst.a.kind == IrOpKind::Constant && unsigned(intOp(inst.a)) <= AssemblyBuilderA64::kMaxImmediate) + build.add(inst.regA64, regOp(inst.b), uint16_t(intOp(inst.a))); else { RegisterA64 temp = tempInt(inst.b); @@ -459,21 +471,6 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.fsub(inst.regA64, temp1, inst.regA64); break; } - case IrCmd::POW_NUM: - { - RegisterA64 temp1 = tempDouble(inst.a); - RegisterA64 temp2 = tempDouble(inst.b); - build.fmov(d0, temp1); // TODO: aliasing hazard - build.fmov(d1, temp2); // TODO: aliasing hazard - regs.spill(build, index, {d0, d1}); - build.ldr(x0, mem(rNativeContext, offsetof(NativeContext, libm_pow))); - build.blr(x0); - - // TODO: we could takeReg d0 but it's unclear if we will be able to keep d0 allocatable due to aliasing concerns - inst.regA64 = regs.allocReuse(KindA64::d, index, {inst.a, inst.b}); - build.fmov(inst.regA64, d0); - break; - } case IrCmd::MIN_NUM: { inst.regA64 = regs.allocReuse(KindA64::d, index, {inst.a, inst.b}); @@ -635,8 +632,8 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) break; case IrCmd::JUMP_GE_UINT: { - LUAU_ASSERT(uintOp(inst.b) <= AssemblyBuilderA64::kMaxImmediate); - build.cmp(regOp(inst.a), uint16_t(uintOp(inst.b))); + LUAU_ASSERT(unsigned(intOp(inst.b)) <= AssemblyBuilderA64::kMaxImmediate); + build.cmp(regOp(inst.a), uint16_t(unsigned(intOp(inst.b)))); build.b(ConditionA64::CarrySet, labelOp(inst.c)); jumpOrFallthrough(blockOp(inst.d), next); break; @@ -723,8 +720,9 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) } case IrCmd::TABLE_LEN: { - build.mov(x0, regOp(inst.a)); // TODO: aliasing hazard - regs.spill(build, index, {x0}); + RegisterA64 reg = regOp(inst.a); // note: we need to call regOp before spill so that we don't do redundant reloads + regs.spill(build, index, {reg}); + build.mov(x0, reg); build.ldr(x1, mem(rNativeContext, offsetof(NativeContext, luaH_getn))); build.blr(x1); inst.regA64 = regs.allocReg(KindA64::d, index); @@ -739,21 +737,18 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.mov(x2, uintOp(inst.b)); build.ldr(x3, mem(rNativeContext, offsetof(NativeContext, luaH_new))); build.blr(x3); - // TODO: we could takeReg x0 but it's unclear if we will be able to keep x0 allocatable due to aliasing concerns - inst.regA64 = regs.allocReg(KindA64::x, index); - build.mov(inst.regA64, x0); + inst.regA64 = regs.takeReg(x0, index); break; } case IrCmd::DUP_TABLE: { - build.mov(x1, regOp(inst.a)); // TODO: aliasing hazard - regs.spill(build, index, {x1}); + RegisterA64 reg = regOp(inst.a); // note: we need to call regOp before spill so that we don't do redundant reloads + regs.spill(build, index, {reg}); + build.mov(x1, reg); build.mov(x0, rState); build.ldr(x2, mem(rNativeContext, offsetof(NativeContext, luaH_clone))); build.blr(x2); - // TODO: we could takeReg x0 but it's unclear if we will be able to keep x0 allocatable due to aliasing concerns - inst.regA64 = regs.allocReuse(KindA64::x, index, {inst.a}); - build.mov(inst.regA64, x0); + inst.regA64 = regs.takeReg(x0, index); break; } case IrCmd::TRY_NUM_TO_INDEX: @@ -789,17 +784,14 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.tst(temp2, 1 << intOp(inst.b)); // can't use tbz/tbnz because their jump offsets are too short build.b(ConditionA64::NotEqual, labelOp(inst.c)); // Equal = Zero after tst; tmcache caches *absence* of metamethods + regs.spill(build, index, {temp1}); build.mov(x0, temp1); - regs.spill(build, index, {x0}); build.mov(w1, intOp(inst.b)); build.ldr(x2, mem(rState, offsetof(lua_State, global))); build.ldr(x2, mem(x2, offsetof(global_State, tmname) + intOp(inst.b) * sizeof(TString*))); build.ldr(x3, mem(rNativeContext, offsetof(NativeContext, luaT_gettm))); build.blr(x3); - - // TODO: we could takeReg x0 but it's unclear if we will be able to keep x0 allocatable due to aliasing concerns - inst.regA64 = regs.allocReuse(KindA64::x, index, {inst.a}); - build.mov(inst.regA64, x0); + inst.regA64 = regs.takeReg(x0, index); break; } case IrCmd::INT_TO_NUM: @@ -861,9 +853,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) } case IrCmd::FASTCALL: regs.spill(build, index); - // TODO: emitBuiltin should be exhaustive - if (!emitBuiltin(build, regs, uintOp(inst.a), vmRegOp(inst.b), vmRegOp(inst.c), inst.d, intOp(inst.e), intOp(inst.f))) - error = true; + error |= emitBuiltin(build, function, regs, uintOp(inst.a), vmRegOp(inst.b), vmRegOp(inst.c), inst.d, intOp(inst.e), intOp(inst.f)); break; case IrCmd::INVOKE_FASTCALL: { @@ -878,7 +868,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) else if (inst.d.kind == IrOpKind::VmConst) emitAddOffset(build, x4, rConstants, vmConstOp(inst.d) * sizeof(TValue)); else - LUAU_ASSERT(boolOp(inst.d) == false); + LUAU_ASSERT(inst.d.kind == IrOpKind::Undef); // nparams if (intOp(inst.e) == LUA_MULTRET) @@ -1047,10 +1037,9 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) Label skip; checkObjectBarrierConditions(build, temp1, temp2, vmRegOp(inst.b), skip); - build.mov(x1, temp1); // TODO: aliasing hazard - - size_t spills = regs.spill(build, index, {x1}); + size_t spills = regs.spill(build, index, {temp1}); + build.mov(x1, temp1); build.mov(x0, rState); build.ldr(x2, mem(rBase, vmRegOp(inst.b) * sizeof(TValue) + offsetof(TValue, value))); build.ldr(x3, mem(rNativeContext, offsetof(NativeContext, luaC_barrierf))); @@ -1108,7 +1097,6 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.cmp(temp, regOp(inst.b)); else if (inst.b.kind == IrOpKind::Constant) { - // TODO: refactor into a common helper? if (size_t(intOp(inst.b)) <= AssemblyBuilderA64::kMaxImmediate) { build.cmp(temp, uint16_t(intOp(inst.b))); @@ -1159,17 +1147,17 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) } case IrCmd::INTERRUPT: { - unsigned int pcpos = uintOp(inst.a); + RegisterA64 temp = regs.allocTemp(KindA64::x); Label skip, next; - build.ldr(x2, mem(rState, offsetof(lua_State, global))); - build.ldr(x2, mem(x2, offsetof(global_State, cb.interrupt))); - build.cbz(x2, skip); + build.ldr(temp, mem(rState, offsetof(lua_State, global))); + build.ldr(temp, mem(temp, offsetof(global_State, cb.interrupt))); + build.cbz(temp, skip); - size_t spills = regs.spill(build, index, {x2}); + size_t spills = regs.spill(build, index); // Jump to outlined interrupt handler, it will give back control to x1 - build.mov(x0, (pcpos + 1) * sizeof(Instruction)); + build.mov(x0, (uintOp(inst.a) + 1) * sizeof(Instruction)); build.adr(x1, next); build.b(helpers.interrupt); @@ -1182,7 +1170,6 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) } case IrCmd::CHECK_GC: { - regs.spill(build, index); RegisterA64 temp1 = regs.allocTemp(KindA64::x); RegisterA64 temp2 = regs.allocTemp(KindA64::x); @@ -1193,12 +1180,17 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.cmp(temp1, temp2); build.b(ConditionA64::UnsignedGreater, skip); + size_t spills = regs.spill(build, index); + build.mov(x0, rState); build.mov(w1, 1); build.ldr(x1, mem(rNativeContext, offsetof(NativeContext, luaC_step))); build.blr(x1); emitUpdateBase(build); + + regs.restore(build, spills); // need to restore before skip so that registers are in a consistent state + build.setLabel(skip); break; } @@ -1209,8 +1201,9 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) Label skip; checkObjectBarrierConditions(build, regOp(inst.a), temp, vmRegOp(inst.b), skip); - build.mov(x1, regOp(inst.a)); // TODO: aliasing hazard - size_t spills = regs.spill(build, index, {x1}); + RegisterA64 reg = regOp(inst.a); // note: we need to call regOp before spill so that we don't do redundant reloads + size_t spills = regs.spill(build, index, {reg}); + build.mov(x1, reg); build.mov(x0, rState); build.ldr(x2, mem(rBase, vmRegOp(inst.b) * sizeof(TValue) + offsetof(TValue, value))); build.ldr(x3, mem(rNativeContext, offsetof(NativeContext, luaC_barrierf))); @@ -1231,8 +1224,9 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.ldrb(temp, mem(regOp(inst.a), offsetof(GCheader, marked))); build.tbz(temp, BLACKBIT, skip); - build.mov(x1, regOp(inst.a)); // TODO: aliasing hazard - size_t spills = regs.spill(build, index, {x1}); + RegisterA64 reg = regOp(inst.a); // note: we need to call regOp before spill so that we don't do redundant reloads + size_t spills = regs.spill(build, index, {reg}); + build.mov(x1, reg); build.mov(x0, rState); build.add(x2, x1, uint16_t(offsetof(Table, gclist))); build.ldr(x3, mem(rNativeContext, offsetof(NativeContext, luaC_barrierback))); @@ -1251,8 +1245,9 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) Label skip; checkObjectBarrierConditions(build, regOp(inst.a), temp, vmRegOp(inst.b), skip); - build.mov(x1, regOp(inst.a)); // TODO: aliasing hazard - size_t spills = regs.spill(build, index, {x1}); + RegisterA64 reg = regOp(inst.a); // note: we need to call regOp before spill so that we don't do redundant reloads + size_t spills = regs.spill(build, index, {reg}); + build.mov(x1, reg); build.mov(x0, rState); build.ldr(x2, mem(rBase, vmRegOp(inst.b) * sizeof(TValue) + offsetof(TValue, value))); build.ldr(x3, mem(rNativeContext, offsetof(NativeContext, luaC_barriertable))); @@ -1290,8 +1285,8 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.cmp(temp2, temp1); build.b(ConditionA64::UnsignedGreater, skip); - build.mov(x1, temp2); // TODO: aliasing hazard - size_t spills = regs.spill(build, index, {x1}); + size_t spills = regs.spill(build, index, {temp2}); + build.mov(x1, temp2); build.mov(x0, rState); build.ldr(x2, mem(rNativeContext, offsetof(NativeContext, luaF_close))); build.blr(x2); @@ -1484,8 +1479,8 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) case IrCmd::BITAND_UINT: { inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.a, inst.b}); - if (inst.b.kind == IrOpKind::Constant && AssemblyBuilderA64::isMaskSupported(uintOp(inst.b))) - build.and_(inst.regA64, regOp(inst.a), uintOp(inst.b)); + if (inst.b.kind == IrOpKind::Constant && AssemblyBuilderA64::isMaskSupported(unsigned(intOp(inst.b)))) + build.and_(inst.regA64, regOp(inst.a), unsigned(intOp(inst.b))); else { RegisterA64 temp1 = tempUint(inst.a); @@ -1497,8 +1492,8 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) case IrCmd::BITXOR_UINT: { inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.a, inst.b}); - if (inst.b.kind == IrOpKind::Constant && AssemblyBuilderA64::isMaskSupported(uintOp(inst.b))) - build.eor(inst.regA64, regOp(inst.a), uintOp(inst.b)); + if (inst.b.kind == IrOpKind::Constant && AssemblyBuilderA64::isMaskSupported(unsigned(intOp(inst.b)))) + build.eor(inst.regA64, regOp(inst.a), unsigned(intOp(inst.b))); else { RegisterA64 temp1 = tempUint(inst.a); @@ -1510,8 +1505,8 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) case IrCmd::BITOR_UINT: { inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.a, inst.b}); - if (inst.b.kind == IrOpKind::Constant && AssemblyBuilderA64::isMaskSupported(uintOp(inst.b))) - build.orr(inst.regA64, regOp(inst.a), uintOp(inst.b)); + if (inst.b.kind == IrOpKind::Constant && AssemblyBuilderA64::isMaskSupported(unsigned(intOp(inst.b)))) + build.orr(inst.regA64, regOp(inst.a), unsigned(intOp(inst.b))); else { RegisterA64 temp1 = tempUint(inst.a); @@ -1531,7 +1526,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) { inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.a, inst.b}); if (inst.b.kind == IrOpKind::Constant) - build.lsl(inst.regA64, regOp(inst.a), uint8_t(uintOp(inst.b) & 31)); + build.lsl(inst.regA64, regOp(inst.a), uint8_t(unsigned(intOp(inst.b)) & 31)); else { RegisterA64 temp1 = tempUint(inst.a); @@ -1544,7 +1539,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) { inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.a, inst.b}); if (inst.b.kind == IrOpKind::Constant) - build.lsr(inst.regA64, regOp(inst.a), uint8_t(uintOp(inst.b) & 31)); + build.lsr(inst.regA64, regOp(inst.a), uint8_t(unsigned(intOp(inst.b)) & 31)); else { RegisterA64 temp1 = tempUint(inst.a); @@ -1557,7 +1552,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) { inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.a, inst.b}); if (inst.b.kind == IrOpKind::Constant) - build.asr(inst.regA64, regOp(inst.a), uint8_t(uintOp(inst.b) & 31)); + build.asr(inst.regA64, regOp(inst.a), uint8_t(unsigned(intOp(inst.b)) & 31)); else { RegisterA64 temp1 = tempUint(inst.a); @@ -1571,7 +1566,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) if (inst.b.kind == IrOpKind::Constant) { inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.a}); - build.ror(inst.regA64, regOp(inst.a), uint8_t((32 - uintOp(inst.b)) & 31)); + build.ror(inst.regA64, regOp(inst.a), uint8_t((32 - unsigned(intOp(inst.b))) & 31)); } else { @@ -1587,7 +1582,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) { inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.a, inst.b}); if (inst.b.kind == IrOpKind::Constant) - build.ror(inst.regA64, regOp(inst.a), uint8_t(uintOp(inst.b) & 31)); + build.ror(inst.regA64, regOp(inst.a), uint8_t(unsigned(intOp(inst.b)) & 31)); else { RegisterA64 temp1 = tempUint(inst.a); @@ -1613,39 +1608,51 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) } case IrCmd::INVOKE_LIBM: { - RegisterA64 temp1 = tempDouble(inst.b); - - build.fmov(d0, temp1); // TODO: aliasing hazard - if (inst.c.kind != IrOpKind::None) { + RegisterA64 temp1 = tempDouble(inst.b); RegisterA64 temp2 = tempDouble(inst.c); - build.fmov(d1, temp2); // TODO: aliasing hazard - regs.spill(build, index, {d0, d1}); + RegisterA64 temp3 = regs.allocTemp(KindA64::d); // note: spill() frees all registers so we need to avoid alloc after spill + regs.spill(build, index, {temp1, temp2}); + + if (d0 != temp2) + { + build.fmov(d0, temp1); + build.fmov(d1, temp2); + } + else + { + build.fmov(temp3, d0); + build.fmov(d0, temp1); + build.fmov(d1, temp3); + } } else - regs.spill(build, index, {d0}); + { + RegisterA64 temp1 = tempDouble(inst.b); + regs.spill(build, index, {temp1}); + build.fmov(d0, temp1); + } - build.ldr(x0, mem(rNativeContext, getNativeContextOffset(LuauBuiltinFunction(uintOp(inst.a))))); + build.ldr(x0, mem(rNativeContext, getNativeContextOffset(uintOp(inst.a)))); build.blr(x0); - // TODO: we could takeReg d0 but it's unclear if we will be able to keep d0 allocatable due to aliasing concerns - inst.regA64 = regs.allocReg(KindA64::d, index); - build.fmov(inst.regA64, d0); + inst.regA64 = regs.takeReg(d0, index); break; } - // Unsupported instructions - // Note: when adding implementations for these, please move the case: label so that implemented instructions match the order in IrData.h - case IrCmd::STORE_VECTOR: - error = true; - break; + // To handle unsupported instructions, add "case IrCmd::OP" and make sure to set error = true! } + valueTracker.afterInstLowering(inst, index); + regs.freeLastUseRegs(inst, index); regs.freeTempRegs(); } -void IrLoweringA64::finishBlock() {} +void IrLoweringA64::finishBlock() +{ + regs.assertNoSpills(); +} bool IrLoweringA64::hasError() const { @@ -1717,7 +1724,7 @@ RegisterA64 IrLoweringA64::tempUint(IrOp op) else if (op.kind == IrOpKind::Constant) { RegisterA64 temp = regs.allocTemp(KindA64::w); - build.mov(temp, uintOp(op)); + build.mov(temp, unsigned(intOp(op))); return temp; } else @@ -1762,7 +1769,7 @@ RegisterA64 IrLoweringA64::regOp(IrOp op) { IrInst& inst = function.instOp(op); - if (inst.spilled) + if (inst.spilled || inst.needsReload) regs.restoreReg(build, inst); LUAU_ASSERT(inst.regA64 != noreg); diff --git a/CodeGen/src/IrLoweringA64.h b/CodeGen/src/IrLoweringA64.h index d5f3a55be..9eda8976c 100644 --- a/CodeGen/src/IrLoweringA64.h +++ b/CodeGen/src/IrLoweringA64.h @@ -5,6 +5,7 @@ #include "Luau/IrData.h" #include "IrRegAllocA64.h" +#include "IrValueLocationTracking.h" #include @@ -64,6 +65,8 @@ struct IrLoweringA64 IrRegAllocA64 regs; + IrValueLocationTracking valueTracker; + bool error = false; }; diff --git a/CodeGen/src/IrLoweringX64.cpp b/CodeGen/src/IrLoweringX64.cpp index 9af3f73dc..bc617571b 100644 --- a/CodeGen/src/IrLoweringX64.cpp +++ b/CodeGen/src/IrLoweringX64.cpp @@ -237,17 +237,38 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.vmovups(luauNodeValue(regOp(inst.a)), regOp(inst.b)); break; case IrCmd::ADD_INT: + { inst.regX64 = regs.allocRegOrReuse(SizeX64::dword, index, {inst.a}); - if (inst.b.kind == IrOpKind::Inst) - build.lea(inst.regX64, addr[regOp(inst.a) + regOp(inst.b)]); - else if (inst.regX64 == regOp(inst.a) && intOp(inst.b) == 1) - build.inc(inst.regX64); - else if (inst.regX64 == regOp(inst.a)) - build.add(inst.regX64, intOp(inst.b)); + if (inst.a.kind == IrOpKind::Constant) + { + build.lea(inst.regX64, addr[regOp(inst.b) + intOp(inst.a)]); + } + else if (inst.a.kind == IrOpKind::Inst) + { + if (inst.regX64 == regOp(inst.a)) + { + if (inst.b.kind == IrOpKind::Inst) + build.add(inst.regX64, regOp(inst.b)); + else if (intOp(inst.b) == 1) + build.inc(inst.regX64); + else + build.add(inst.regX64, intOp(inst.b)); + } + else + { + if (inst.b.kind == IrOpKind::Inst) + build.lea(inst.regX64, addr[regOp(inst.a) + regOp(inst.b)]); + else + build.lea(inst.regX64, addr[regOp(inst.a) + intOp(inst.b)]); + } + } else - build.lea(inst.regX64, addr[regOp(inst.a) + intOp(inst.b)]); + { + LUAU_ASSERT(!"Unsupported instruction form"); + } break; + } case IrCmd::SUB_INT: inst.regX64 = regs.allocRegOrReuse(SizeX64::dword, index, {inst.a}); @@ -359,15 +380,6 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) } break; } - case IrCmd::POW_NUM: - { - IrCallWrapperX64 callWrap(regs, build, index); - callWrap.addArgument(SizeX64::xmmword, memRegDoubleOp(inst.a), inst.a); - callWrap.addArgument(SizeX64::xmmword, memRegDoubleOp(inst.b), inst.b); - callWrap.call(qword[rNativeContext + offsetof(NativeContext, libm_pow)]); - inst.regX64 = regs.takeReg(xmm0, index); - break; - } case IrCmd::MIN_NUM: inst.regX64 = regs.allocRegOrReuse(SizeX64::xmmword, index, {inst.a, inst.b}); @@ -537,7 +549,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) jumpOrFallthrough(blockOp(inst.d), next); break; case IrCmd::JUMP_GE_UINT: - build.cmp(regOp(inst.a), uintOp(inst.b)); + build.cmp(regOp(inst.a), unsigned(intOp(inst.b))); build.jcc(ConditionX64::AboveEqual, labelOp(inst.c)); jumpOrFallthrough(blockOp(inst.d), next); @@ -690,8 +702,12 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) } case IrCmd::FASTCALL: - emitBuiltin(regs, build, uintOp(inst.a), vmRegOp(inst.b), vmRegOp(inst.c), inst.d, intOp(inst.e), intOp(inst.f)); + { + OperandX64 arg2 = inst.d.kind != IrOpKind::Undef ? memRegDoubleOp(inst.d) : OperandX64{0}; + + emitBuiltin(regs, build, uintOp(inst.a), vmRegOp(inst.b), vmRegOp(inst.c), arg2, intOp(inst.e), intOp(inst.f)); break; + } case IrCmd::INVOKE_FASTCALL: { unsigned bfid = uintOp(inst.a); @@ -703,7 +719,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) else if (inst.d.kind == IrOpKind::VmConst) args = luauConstantAddress(vmConstOp(inst.d)); else - LUAU_ASSERT(boolOp(inst.d) == false); + LUAU_ASSERT(inst.d.kind == IrOpKind::Undef); int ra = vmRegOp(inst.b); int arg = vmRegOp(inst.c); @@ -1141,32 +1157,32 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) case IrCmd::BITAND_UINT: inst.regX64 = regs.allocRegOrReuse(SizeX64::dword, index, {inst.a}); - if (inst.regX64 != regOp(inst.a)) - build.mov(inst.regX64, regOp(inst.a)); + if (inst.a.kind != IrOpKind::Inst || inst.regX64 != regOp(inst.a)) + build.mov(inst.regX64, memRegUintOp(inst.a)); build.and_(inst.regX64, memRegUintOp(inst.b)); break; case IrCmd::BITXOR_UINT: inst.regX64 = regs.allocRegOrReuse(SizeX64::dword, index, {inst.a}); - if (inst.regX64 != regOp(inst.a)) - build.mov(inst.regX64, regOp(inst.a)); + if (inst.a.kind != IrOpKind::Inst || inst.regX64 != regOp(inst.a)) + build.mov(inst.regX64, memRegUintOp(inst.a)); build.xor_(inst.regX64, memRegUintOp(inst.b)); break; case IrCmd::BITOR_UINT: inst.regX64 = regs.allocRegOrReuse(SizeX64::dword, index, {inst.a}); - if (inst.regX64 != regOp(inst.a)) - build.mov(inst.regX64, regOp(inst.a)); + if (inst.a.kind != IrOpKind::Inst || inst.regX64 != regOp(inst.a)) + build.mov(inst.regX64, memRegUintOp(inst.a)); build.or_(inst.regX64, memRegUintOp(inst.b)); break; case IrCmd::BITNOT_UINT: inst.regX64 = regs.allocRegOrReuse(SizeX64::dword, index, {inst.a}); - if (inst.regX64 != regOp(inst.a)) - build.mov(inst.regX64, regOp(inst.a)); + if (inst.a.kind != IrOpKind::Inst || inst.regX64 != regOp(inst.a)) + build.mov(inst.regX64, memRegUintOp(inst.a)); build.not_(inst.regX64); break; @@ -1179,10 +1195,8 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.mov(shiftTmp.reg, memRegUintOp(inst.b)); - if (inst.a.kind == IrOpKind::Constant) - build.mov(inst.regX64, uintOp(inst.a)); - else if (inst.regX64 != regOp(inst.a)) - build.mov(inst.regX64, regOp(inst.a)); + if (inst.a.kind != IrOpKind::Inst || inst.regX64 != regOp(inst.a)) + build.mov(inst.regX64, memRegUintOp(inst.a)); build.shl(inst.regX64, byteReg(shiftTmp.reg)); break; @@ -1196,10 +1210,8 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.mov(shiftTmp.reg, memRegUintOp(inst.b)); - if (inst.a.kind == IrOpKind::Constant) - build.mov(inst.regX64, uintOp(inst.a)); - else if (inst.regX64 != regOp(inst.a)) - build.mov(inst.regX64, regOp(inst.a)); + if (inst.a.kind != IrOpKind::Inst || inst.regX64 != regOp(inst.a)) + build.mov(inst.regX64, memRegUintOp(inst.a)); build.shr(inst.regX64, byteReg(shiftTmp.reg)); break; @@ -1213,10 +1225,8 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.mov(shiftTmp.reg, memRegUintOp(inst.b)); - if (inst.a.kind == IrOpKind::Constant) - build.mov(inst.regX64, uintOp(inst.a)); - else if (inst.regX64 != regOp(inst.a)) - build.mov(inst.regX64, regOp(inst.a)); + if (inst.a.kind != IrOpKind::Inst || inst.regX64 != regOp(inst.a)) + build.mov(inst.regX64, memRegUintOp(inst.a)); build.sar(inst.regX64, byteReg(shiftTmp.reg)); break; @@ -1230,10 +1240,8 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.mov(shiftTmp.reg, memRegUintOp(inst.b)); - if (inst.a.kind == IrOpKind::Constant) - build.mov(inst.regX64, uintOp(inst.a)); - else if (inst.regX64 != regOp(inst.a)) - build.mov(inst.regX64, regOp(inst.a)); + if (inst.a.kind != IrOpKind::Inst || inst.regX64 != regOp(inst.a)) + build.mov(inst.regX64, memRegUintOp(inst.a)); build.rol(inst.regX64, byteReg(shiftTmp.reg)); break; @@ -1247,10 +1255,8 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.mov(shiftTmp.reg, memRegUintOp(inst.b)); - if (inst.a.kind == IrOpKind::Constant) - build.mov(inst.regX64, uintOp(inst.a)); - else if (inst.regX64 != regOp(inst.a)) - build.mov(inst.regX64, regOp(inst.a)); + if (inst.a.kind != IrOpKind::Inst || inst.regX64 != regOp(inst.a)) + build.mov(inst.regX64, memRegUintOp(inst.a)); build.ror(inst.regX64, byteReg(shiftTmp.reg)); break; @@ -1294,15 +1300,13 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) } case IrCmd::INVOKE_LIBM: { - LuauBuiltinFunction bfid = LuauBuiltinFunction(uintOp(inst.a)); - IrCallWrapperX64 callWrap(regs, build, index); callWrap.addArgument(SizeX64::xmmword, memRegDoubleOp(inst.b), inst.b); if (inst.c.kind != IrOpKind::None) callWrap.addArgument(SizeX64::xmmword, memRegDoubleOp(inst.c), inst.c); - callWrap.call(qword[rNativeContext + getNativeContextOffset(bfid)]); + callWrap.call(qword[rNativeContext + getNativeContextOffset(uintOp(inst.a))]); inst.regX64 = regs.takeReg(xmm0, index); break; } @@ -1370,7 +1374,7 @@ OperandX64 IrLoweringX64::memRegUintOp(IrOp op) case IrOpKind::Inst: return regOp(op); case IrOpKind::Constant: - return OperandX64(uintOp(op)); + return OperandX64(unsigned(intOp(op))); default: LUAU_ASSERT(!"Unsupported operand kind"); } diff --git a/CodeGen/src/IrRegAllocA64.cpp b/CodeGen/src/IrRegAllocA64.cpp index bf5538454..a4cfeaed4 100644 --- a/CodeGen/src/IrRegAllocA64.cpp +++ b/CodeGen/src/IrRegAllocA64.cpp @@ -2,10 +2,15 @@ #include "IrRegAllocA64.h" #include "Luau/AssemblyBuilderA64.h" +#include "Luau/IrUtils.h" #include "BitUtils.h" #include "EmitCommonA64.h" +#include + +LUAU_FASTFLAGVARIABLE(DebugLuauCodegenChaosA64, false) + namespace Luau { namespace CodeGen @@ -39,6 +44,68 @@ static void freeSpill(uint32_t& free, KindA64 kind, uint8_t slot) free |= mask; } +static int getReloadOffset(IrCmd cmd) +{ + switch (getCmdValueKind(cmd)) + { + case IrValueKind::Unknown: + case IrValueKind::None: + LUAU_ASSERT(!"Invalid operand restore value kind"); + break; + case IrValueKind::Tag: + return offsetof(TValue, tt); + case IrValueKind::Int: + return offsetof(TValue, value); + case IrValueKind::Pointer: + return offsetof(TValue, value.gc); + case IrValueKind::Double: + return offsetof(TValue, value.n); + case IrValueKind::Tvalue: + return 0; + } + + LUAU_ASSERT(!"Invalid operand restore value kind"); + LUAU_UNREACHABLE(); +} + +static AddressA64 getReloadAddress(const IrFunction& function, const IrInst& inst) +{ + IrOp location = function.findRestoreOp(inst); + + if (location.kind == IrOpKind::VmReg) + return mem(rBase, vmRegOp(location) * sizeof(TValue) + getReloadOffset(inst.cmd)); + + // loads are 4/8/16 bytes; we conservatively limit the offset to fit assuming a 4b index + if (location.kind == IrOpKind::VmConst && vmConstOp(location) * sizeof(TValue) <= AddressA64::kMaxOffset * 4) + return mem(rConstants, vmConstOp(location) * sizeof(TValue) + getReloadOffset(inst.cmd)); + + return AddressA64(xzr); // dummy +} + +static void restoreInst(AssemblyBuilderA64& build, uint32_t& freeSpillSlots, IrFunction& function, const IrRegAllocA64::Spill& s, RegisterA64 reg) +{ + IrInst& inst = function.instructions[s.inst]; + LUAU_ASSERT(inst.regA64 == noreg); + + if (s.slot >= 0) + { + build.ldr(reg, mem(sp, sSpillArea.data + s.slot * 8)); + + freeSpill(freeSpillSlots, reg.kind, s.slot); + } + else + { + LUAU_ASSERT(!inst.spilled && inst.needsReload); + AddressA64 addr = getReloadAddress(function, function.instructions[s.inst]); + LUAU_ASSERT(addr.base != xzr); + build.ldr(reg, addr); + } + + inst.spilled = false; + inst.needsReload = false; + inst.regA64 = reg; +} + IrRegAllocA64::IrRegAllocA64(IrFunction& function, std::initializer_list> regs) : function(function) { @@ -68,11 +135,16 @@ RegisterA64 IrRegAllocA64::allocReg(KindA64 kind, uint32_t index) if (set.free == 0) { + // TODO: remember the error and fail lowering LUAU_ASSERT(!"Out of registers to allocate"); return noreg; } int reg = 31 - countlz(set.free); + + if (FFlag::DebugLuauCodegenChaosA64) + reg = countrz(set.free); // allocate from low end; this causes extra conflicts for calls + set.free &= ~(1u << reg); set.defs[reg] = index; @@ -85,12 +157,16 @@ RegisterA64 IrRegAllocA64::allocTemp(KindA64 kind) if (set.free == 0) { + // TODO: remember the error and fail lowering LUAU_ASSERT(!"Out of registers to allocate"); return noreg; } int reg = 31 - countlz(set.free); + if (FFlag::DebugLuauCodegenChaosA64) + reg = countrz(set.free); // allocate from low end; this causes extra conflicts for calls + set.free &= ~(1u << reg); set.temp |= 1u << reg; LUAU_ASSERT(set.defs[reg] == kInvalidInstIdx); @@ -107,8 +183,9 @@ RegisterA64 IrRegAllocA64::allocReuse(KindA64 kind, uint32_t index, std::initial IrInst& source = function.instructions[op.index]; - if (source.lastUse == index && !source.reusedReg && !source.spilled && source.regA64 != noreg) + if (source.lastUse == index && !source.reusedReg && source.regA64 != noreg) { + LUAU_ASSERT(!source.spilled && !source.needsReload); LUAU_ASSERT(source.regA64.kind == kind); Set& set = getSet(kind); @@ -152,7 +229,7 @@ void IrRegAllocA64::freeLastUseReg(IrInst& target, uint32_t index) { if (target.lastUse == index && !target.reusedReg) { - LUAU_ASSERT(!target.spilled); + LUAU_ASSERT(!target.spilled && !target.needsReload); // Register might have already been freed if it had multiple uses inside a single instruction if (target.regA64 == noreg) @@ -195,13 +272,19 @@ size_t IrRegAllocA64::spill(AssemblyBuilderA64& build, uint32_t index, std::init size_t start = spills.size(); - for (RegisterA64 reg : live) + uint32_t poisongpr = 0; + uint32_t poisonsimd = 0; + + if (FFlag::DebugLuauCodegenChaosA64) { - Set& set = getSet(reg.kind); + poisongpr = gpr.base & ~gpr.free; + poisonsimd = simd.base & ~simd.free; - // make sure registers that we expect to survive past spill barrier are not allocated - // TODO: we need to handle this condition somehow in the future; if this fails, this likely means the caller has an aliasing hazard - LUAU_ASSERT(set.free & (1u << reg.index)); + for (RegisterA64 reg : live) + { + Set& set = getSet(reg.kind); + (&set == &simd ? poisonsimd : poisongpr) &= ~(1u << reg.index); + } } for (KindA64 kind : sets) @@ -229,26 +312,38 @@ size_t IrRegAllocA64::spill(AssemblyBuilderA64& build, uint32_t index, std::init IrInst& def = function.instructions[inst]; LUAU_ASSERT(def.regA64.index == reg); - LUAU_ASSERT(!def.spilled); LUAU_ASSERT(!def.reusedReg); + LUAU_ASSERT(!def.spilled); + LUAU_ASSERT(!def.needsReload); if (def.lastUse == index) { // instead of spilling the register to never reload it, we assume the register is not needed anymore - def.regA64 = noreg; + } + else if (getReloadAddress(function, def).base != xzr) + { + // instead of spilling the register to stack, we can reload it from VM stack/constants + // we still need to record the spill for restore(start) to work + Spill s = {inst, def.regA64, -1}; + spills.push_back(s); + + def.needsReload = true; } else { int slot = allocSpill(freeSpillSlots, def.regA64.kind); LUAU_ASSERT(slot >= 0); // TODO: remember the error and fail lowering - Spill s = {inst, def.regA64, uint8_t(slot)}; + build.str(def.regA64, mem(sp, sSpillArea.data + slot * 8)); + + Spill s = {inst, def.regA64, int8_t(slot)}; spills.push_back(s); def.spilled = true; - def.regA64 = noreg; } + def.regA64 = noreg; + regs &= ~(1u << reg); set.free |= 1u << reg; set.defs[reg] = kInvalidInstIdx; @@ -257,11 +352,15 @@ size_t IrRegAllocA64::spill(AssemblyBuilderA64& build, uint32_t index, std::init LUAU_ASSERT(set.free == set.base); } - if (start < spills.size()) + if (FFlag::DebugLuauCodegenChaosA64) { - // TODO: use stp for consecutive slots - for (size_t i = start; i < spills.size(); ++i) - build.str(spills[i].origin, mem(sp, sSpillArea.data + spills[i].slot * 8)); + for (int reg = 0; reg < 32; ++reg) + { + if (poisongpr & (1u << reg)) + build.mov(RegisterA64{KindA64::x, uint8_t(reg)}, 0xdead); + if (poisonsimd & (1u << reg)) + build.fmov(RegisterA64{KindA64::d, uint8_t(reg)}, -0.125); + } } return start; @@ -273,22 +372,12 @@ void IrRegAllocA64::restore(AssemblyBuilderA64& build, size_t start) if (start < spills.size()) { - // TODO: use ldp for consecutive slots - for (size_t i = start; i < spills.size(); ++i) - build.ldr(spills[i].origin, mem(sp, sSpillArea.data + spills[i].slot * 8)); - for (size_t i = start; i < spills.size(); ++i) { Spill s = spills[i]; // copy in case takeReg reallocates spills + RegisterA64 reg = takeReg(s.origin, s.inst); - IrInst& def = function.instructions[s.inst]; - LUAU_ASSERT(def.spilled); - LUAU_ASSERT(def.regA64 == noreg); - - def.spilled = false; - def.regA64 = takeReg(s.origin, s.inst); - - freeSpill(freeSpillSlots, s.origin.kind, s.slot); + restoreInst(build, freeSpillSlots, function, s, reg); } spills.resize(start); @@ -299,9 +388,6 @@ void IrRegAllocA64::restoreReg(AssemblyBuilderA64& build, IrInst& inst) { uint32_t index = function.getInstIndex(inst); - LUAU_ASSERT(inst.spilled); - LUAU_ASSERT(inst.regA64 == noreg); - for (size_t i = 0; i < spills.size(); ++i) { if (spills[i].inst == index) @@ -309,12 +395,7 @@ void IrRegAllocA64::restoreReg(AssemblyBuilderA64& build, IrInst& inst) Spill s = spills[i]; // copy in case allocReg reallocates spills RegisterA64 reg = allocReg(s.origin.kind, index); - build.ldr(reg, mem(sp, sSpillArea.data + s.slot * 8)); - - inst.spilled = false; - inst.regA64 = reg; - - freeSpill(freeSpillSlots, reg.kind, s.slot); + restoreInst(build, freeSpillSlots, function, s, reg); spills[i] = spills.back(); spills.pop_back(); @@ -338,6 +419,7 @@ IrRegAllocA64::Set& IrRegAllocA64::getSet(KindA64 kind) case KindA64::w: return gpr; + case KindA64::s: case KindA64::d: case KindA64::q: return simd; diff --git a/CodeGen/src/IrRegAllocA64.h b/CodeGen/src/IrRegAllocA64.h index 940b511c0..689743789 100644 --- a/CodeGen/src/IrRegAllocA64.h +++ b/CodeGen/src/IrRegAllocA64.h @@ -65,7 +65,7 @@ struct IrRegAllocA64 uint32_t inst; RegisterA64 origin; - uint8_t slot; + int8_t slot; }; Set& getSet(KindA64 kind); diff --git a/CodeGen/src/IrTranslateBuiltins.cpp b/CodeGen/src/IrTranslateBuiltins.cpp index a6da1e40b..e58d0a126 100644 --- a/CodeGen/src/IrTranslateBuiltins.cpp +++ b/CodeGen/src/IrTranslateBuiltins.cpp @@ -6,6 +6,8 @@ #include "lstate.h" +#include + // TODO: when nresults is less than our actual result count, we can skip computing/writing unused results static const int kMinMaxUnrolledParams = 5; @@ -16,16 +18,32 @@ namespace Luau namespace CodeGen { +static void builtinCheckDouble(IrBuilder& build, IrOp arg, IrOp fallback) +{ + if (arg.kind == IrOpKind::Constant) + LUAU_ASSERT(build.function.constOp(arg).kind == IrConstKind::Double); + else + build.loadAndCheckTag(arg, LUA_TNUMBER, fallback); +} + +static IrOp builtinLoadDouble(IrBuilder& build, IrOp arg) +{ + if (arg.kind == IrOpKind::Constant) + return arg; + + return build.inst(IrCmd::LOAD_DOUBLE, arg); +} + // Wrapper code for all builtins with a fixed signature and manual assembly lowering of the body // (number, ...) -> number -BuiltinImplResult translateBuiltinNumberToNumber( +static BuiltinImplResult translateBuiltinNumberToNumber( IrBuilder& build, LuauBuiltinFunction bfid, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) { if (nparams < 1 || nresults > 1) return {BuiltinImplType::None, -1}; - build.loadAndCheckTag(build.vmReg(arg), LUA_TNUMBER, fallback); + builtinCheckDouble(build, build.vmReg(arg), fallback); build.inst(IrCmd::FASTCALL, build.constUint(bfid), build.vmReg(ra), build.vmReg(arg), args, build.constInt(1), build.constInt(1)); if (ra != arg) @@ -34,14 +52,14 @@ BuiltinImplResult translateBuiltinNumberToNumber( return {BuiltinImplType::UsesFallback, 1}; } -BuiltinImplResult translateBuiltinNumberToNumberLibm( +static BuiltinImplResult translateBuiltinNumberToNumberLibm( IrBuilder& build, LuauBuiltinFunction bfid, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) { if (nparams < 1 || nresults > 1) return {BuiltinImplType::None, -1}; - build.loadAndCheckTag(build.vmReg(arg), LUA_TNUMBER, fallback); - IrOp va = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(arg)); + builtinCheckDouble(build, build.vmReg(arg), fallback); + IrOp va = builtinLoadDouble(build, build.vmReg(arg)); IrOp res = build.inst(IrCmd::INVOKE_LIBM, build.constUint(bfid), va); @@ -54,14 +72,14 @@ BuiltinImplResult translateBuiltinNumberToNumberLibm( } // (number, number, ...) -> number -BuiltinImplResult translateBuiltin2NumberToNumber( +static BuiltinImplResult translateBuiltin2NumberToNumber( IrBuilder& build, LuauBuiltinFunction bfid, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) { if (nparams < 2 || nresults > 1) return {BuiltinImplType::None, -1}; - build.loadAndCheckTag(build.vmReg(arg), LUA_TNUMBER, fallback); - build.loadAndCheckTag(args, LUA_TNUMBER, fallback); + builtinCheckDouble(build, build.vmReg(arg), fallback); + builtinCheckDouble(build, args, fallback); build.inst(IrCmd::FASTCALL, build.constUint(bfid), build.vmReg(ra), build.vmReg(arg), args, build.constInt(2), build.constInt(1)); if (ra != arg) @@ -70,17 +88,17 @@ BuiltinImplResult translateBuiltin2NumberToNumber( return {BuiltinImplType::UsesFallback, 1}; } -BuiltinImplResult translateBuiltin2NumberToNumberLibm( +static BuiltinImplResult translateBuiltin2NumberToNumberLibm( IrBuilder& build, LuauBuiltinFunction bfid, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) { if (nparams < 2 || nresults > 1) return {BuiltinImplType::None, -1}; - build.loadAndCheckTag(build.vmReg(arg), LUA_TNUMBER, fallback); - build.loadAndCheckTag(args, LUA_TNUMBER, fallback); + builtinCheckDouble(build, build.vmReg(arg), fallback); + builtinCheckDouble(build, args, fallback); - IrOp va = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(arg)); - IrOp vb = build.inst(IrCmd::LOAD_DOUBLE, args); + IrOp va = builtinLoadDouble(build, build.vmReg(arg)); + IrOp vb = builtinLoadDouble(build, args); IrOp res = build.inst(IrCmd::INVOKE_LIBM, build.constUint(bfid), va, vb); @@ -93,13 +111,13 @@ BuiltinImplResult translateBuiltin2NumberToNumberLibm( } // (number, ...) -> (number, number) -BuiltinImplResult translateBuiltinNumberTo2Number( +static BuiltinImplResult translateBuiltinNumberTo2Number( IrBuilder& build, LuauBuiltinFunction bfid, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) { if (nparams < 1 || nresults > 2) return {BuiltinImplType::None, -1}; - build.loadAndCheckTag(build.vmReg(arg), LUA_TNUMBER, fallback); + builtinCheckDouble(build, build.vmReg(arg), fallback); build.inst( IrCmd::FASTCALL, build.constUint(bfid), build.vmReg(ra), build.vmReg(arg), args, build.constInt(1), build.constInt(nresults == 1 ? 1 : 2)); @@ -112,7 +130,7 @@ BuiltinImplResult translateBuiltinNumberTo2Number( return {BuiltinImplType::UsesFallback, 2}; } -BuiltinImplResult translateBuiltinAssert(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) +static BuiltinImplResult translateBuiltinAssert(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) { if (nparams < 1 || nresults != 0) return {BuiltinImplType::None, -1}; @@ -126,16 +144,16 @@ BuiltinImplResult translateBuiltinAssert(IrBuilder& build, int nparams, int ra, return {BuiltinImplType::UsesFallback, 0}; } -BuiltinImplResult translateBuiltinMathDeg(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) +static BuiltinImplResult translateBuiltinMathDeg(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) { if (nparams < 1 || nresults > 1) return {BuiltinImplType::None, -1}; - build.loadAndCheckTag(build.vmReg(arg), LUA_TNUMBER, fallback); + builtinCheckDouble(build, build.vmReg(arg), fallback); const double rpd = (3.14159265358979323846 / 180.0); - IrOp varg = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(arg)); + IrOp varg = builtinLoadDouble(build, build.vmReg(arg)); IrOp value = build.inst(IrCmd::DIV_NUM, varg, build.constDouble(rpd)); build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), value); @@ -145,16 +163,16 @@ BuiltinImplResult translateBuiltinMathDeg(IrBuilder& build, int nparams, int ra, return {BuiltinImplType::UsesFallback, 1}; } -BuiltinImplResult translateBuiltinMathRad(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) +static BuiltinImplResult translateBuiltinMathRad(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) { if (nparams < 1 || nresults > 1) return {BuiltinImplType::None, -1}; - build.loadAndCheckTag(build.vmReg(arg), LUA_TNUMBER, fallback); + builtinCheckDouble(build, build.vmReg(arg), fallback); const double rpd = (3.14159265358979323846 / 180.0); - IrOp varg = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(arg)); + IrOp varg = builtinLoadDouble(build, build.vmReg(arg)); IrOp value = build.inst(IrCmd::MUL_NUM, varg, build.constDouble(rpd)); build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), value); @@ -164,48 +182,40 @@ BuiltinImplResult translateBuiltinMathRad(IrBuilder& build, int nparams, int ra, return {BuiltinImplType::UsesFallback, 1}; } -BuiltinImplResult translateBuiltinMathLog( +static BuiltinImplResult translateBuiltinMathLog( IrBuilder& build, LuauBuiltinFunction bfid, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) { if (nparams < 1 || nresults > 1) return {BuiltinImplType::None, -1}; - LuauBuiltinFunction fcId = bfid; - int fcParams = 1; + int libmId = bfid; + std::optional denom; if (nparams != 1) { - if (args.kind != IrOpKind::VmConst) - return {BuiltinImplType::None, -1}; - - LUAU_ASSERT(build.function.proto); - TValue protok = build.function.proto->k[vmConstOp(args)]; + std::optional y = build.function.asDoubleOp(args); - if (protok.tt != LUA_TNUMBER) + if (!y) return {BuiltinImplType::None, -1}; - // TODO: IR builtin lowering assumes that the only valid 2-argument call is log2; ideally, we use a less hacky way to indicate that - if (protok.value.n == 2.0) - fcParams = 2; - else if (protok.value.n == 10.0) - fcId = LBF_MATH_LOG10; + if (*y == 2.0) + libmId = LBF_IR_MATH_LOG2; + else if (*y == 10.0) + libmId = LBF_MATH_LOG10; else - // TODO: We can precompute log(args) and divide by it, but that requires extra LOAD/STORE so for now just fall back as this is rare - return {BuiltinImplType::None, -1}; + denom = log(*y); } - build.loadAndCheckTag(build.vmReg(arg), LUA_TNUMBER, fallback); + builtinCheckDouble(build, build.vmReg(arg), fallback); - if (fcId == LBF_MATH_LOG10) - { - IrOp va = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(arg)); + IrOp va = builtinLoadDouble(build, build.vmReg(arg)); - IrOp res = build.inst(IrCmd::INVOKE_LIBM, build.constUint(fcId), va); + IrOp res = build.inst(IrCmd::INVOKE_LIBM, build.constUint(libmId), va); - build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), res); - } - else - build.inst(IrCmd::FASTCALL, build.constUint(fcId), build.vmReg(ra), build.vmReg(arg), args, build.constInt(fcParams), build.constInt(1)); + if (denom) + res = build.inst(IrCmd::DIV_NUM, res, build.constDouble(*denom)); + + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), res); if (ra != arg) build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); @@ -213,25 +223,25 @@ BuiltinImplResult translateBuiltinMathLog( return {BuiltinImplType::UsesFallback, 1}; } -BuiltinImplResult translateBuiltinMathMin(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) +static BuiltinImplResult translateBuiltinMathMin(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) { if (nparams < 2 || nparams > kMinMaxUnrolledParams || nresults > 1) return {BuiltinImplType::None, -1}; - build.loadAndCheckTag(build.vmReg(arg), LUA_TNUMBER, fallback); - build.loadAndCheckTag(args, LUA_TNUMBER, fallback); + builtinCheckDouble(build, build.vmReg(arg), fallback); + builtinCheckDouble(build, args, fallback); for (int i = 3; i <= nparams; ++i) - build.loadAndCheckTag(build.vmReg(vmRegOp(args) + (i - 2)), LUA_TNUMBER, fallback); + builtinCheckDouble(build, build.vmReg(vmRegOp(args) + (i - 2)), fallback); - IrOp varg1 = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(arg)); - IrOp varg2 = build.inst(IrCmd::LOAD_DOUBLE, args); + IrOp varg1 = builtinLoadDouble(build, build.vmReg(arg)); + IrOp varg2 = builtinLoadDouble(build, args); IrOp res = build.inst(IrCmd::MIN_NUM, varg2, varg1); // Swapped arguments are required for consistency with VM builtins for (int i = 3; i <= nparams; ++i) { - IrOp arg = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(vmRegOp(args) + (i - 2))); + IrOp arg = builtinLoadDouble(build, build.vmReg(vmRegOp(args) + (i - 2))); res = build.inst(IrCmd::MIN_NUM, arg, res); } @@ -243,25 +253,25 @@ BuiltinImplResult translateBuiltinMathMin(IrBuilder& build, int nparams, int ra, return {BuiltinImplType::UsesFallback, 1}; } -BuiltinImplResult translateBuiltinMathMax(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) +static BuiltinImplResult translateBuiltinMathMax(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) { if (nparams < 2 || nparams > kMinMaxUnrolledParams || nresults > 1) return {BuiltinImplType::None, -1}; - build.loadAndCheckTag(build.vmReg(arg), LUA_TNUMBER, fallback); - build.loadAndCheckTag(args, LUA_TNUMBER, fallback); + builtinCheckDouble(build, build.vmReg(arg), fallback); + builtinCheckDouble(build, args, fallback); for (int i = 3; i <= nparams; ++i) - build.loadAndCheckTag(build.vmReg(vmRegOp(args) + (i - 2)), LUA_TNUMBER, fallback); + builtinCheckDouble(build, build.vmReg(vmRegOp(args) + (i - 2)), fallback); - IrOp varg1 = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(arg)); - IrOp varg2 = build.inst(IrCmd::LOAD_DOUBLE, args); + IrOp varg1 = builtinLoadDouble(build, build.vmReg(arg)); + IrOp varg2 = builtinLoadDouble(build, args); IrOp res = build.inst(IrCmd::MAX_NUM, varg2, varg1); // Swapped arguments are required for consistency with VM builtins for (int i = 3; i <= nparams; ++i) { - IrOp arg = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(vmRegOp(args) + (i - 2))); + IrOp arg = builtinLoadDouble(build, build.vmReg(vmRegOp(args) + (i - 2))); res = build.inst(IrCmd::MAX_NUM, arg, res); } @@ -273,7 +283,7 @@ BuiltinImplResult translateBuiltinMathMax(IrBuilder& build, int nparams, int ra, return {BuiltinImplType::UsesFallback, 1}; } -BuiltinImplResult translateBuiltinMathClamp(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) +static BuiltinImplResult translateBuiltinMathClamp(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) { if (nparams < 3 || nresults > 1) return {BuiltinImplType::None, -1}; @@ -282,17 +292,17 @@ BuiltinImplResult translateBuiltinMathClamp(IrBuilder& build, int nparams, int r LUAU_ASSERT(args.kind == IrOpKind::VmReg); - build.loadAndCheckTag(build.vmReg(arg), LUA_TNUMBER, fallback); - build.loadAndCheckTag(args, LUA_TNUMBER, fallback); - build.loadAndCheckTag(build.vmReg(vmRegOp(args) + 1), LUA_TNUMBER, fallback); + builtinCheckDouble(build, build.vmReg(arg), fallback); + builtinCheckDouble(build, args, fallback); + builtinCheckDouble(build, build.vmReg(vmRegOp(args) + 1), fallback); - IrOp min = build.inst(IrCmd::LOAD_DOUBLE, args); - IrOp max = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(vmRegOp(args) + 1)); + IrOp min = builtinLoadDouble(build, args); + IrOp max = builtinLoadDouble(build, build.vmReg(vmRegOp(args) + 1)); build.inst(IrCmd::JUMP_CMP_NUM, min, max, build.cond(IrCondition::NotLessEqual), fallback, block); build.beginBlock(block); - IrOp v = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(arg)); + IrOp v = builtinLoadDouble(build, build.vmReg(arg)); IrOp r = build.inst(IrCmd::MAX_NUM, min, v); IrOp clamped = build.inst(IrCmd::MIN_NUM, max, r); @@ -304,14 +314,14 @@ BuiltinImplResult translateBuiltinMathClamp(IrBuilder& build, int nparams, int r return {BuiltinImplType::UsesFallback, 1}; } -BuiltinImplResult translateBuiltinMathUnary(IrBuilder& build, IrCmd cmd, int nparams, int ra, int arg, int nresults, IrOp fallback) +static BuiltinImplResult translateBuiltinMathUnary(IrBuilder& build, IrCmd cmd, int nparams, int ra, int arg, int nresults, IrOp fallback) { if (nparams < 1 || nresults > 1) return {BuiltinImplType::None, -1}; - build.loadAndCheckTag(build.vmReg(arg), LUA_TNUMBER, fallback); + builtinCheckDouble(build, build.vmReg(arg), fallback); - IrOp varg = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(arg)); + IrOp varg = builtinLoadDouble(build, build.vmReg(arg)); IrOp result = build.inst(cmd, varg); build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), result); @@ -322,27 +332,7 @@ BuiltinImplResult translateBuiltinMathUnary(IrBuilder& build, IrCmd cmd, int npa return {BuiltinImplType::UsesFallback, 1}; } -BuiltinImplResult translateBuiltinMathBinary(IrBuilder& build, IrCmd cmd, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) -{ - if (nparams < 2 || nresults > 1) - return {BuiltinImplType::None, -1}; - - build.loadAndCheckTag(build.vmReg(arg), LUA_TNUMBER, fallback); - build.loadAndCheckTag(args, LUA_TNUMBER, fallback); - - IrOp lhs = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(arg)); - IrOp rhs = build.inst(IrCmd::LOAD_DOUBLE, args); - IrOp result = build.inst(cmd, lhs, rhs); - - build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), result); - - if (ra != arg) - build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); - - return {BuiltinImplType::UsesFallback, 1}; -} - -BuiltinImplResult translateBuiltinType(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) +static BuiltinImplResult translateBuiltinType(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) { if (nparams < 1 || nresults > 1) return {BuiltinImplType::None, -1}; @@ -354,7 +344,7 @@ BuiltinImplResult translateBuiltinType(IrBuilder& build, int nparams, int ra, in return {BuiltinImplType::UsesFallback, 1}; } -BuiltinImplResult translateBuiltinTypeof(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) +static BuiltinImplResult translateBuiltinTypeof(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) { if (nparams < 1 || nresults > 1) return {BuiltinImplType::None, -1}; @@ -366,20 +356,20 @@ BuiltinImplResult translateBuiltinTypeof(IrBuilder& build, int nparams, int ra, return {BuiltinImplType::UsesFallback, 1}; } -BuiltinImplResult translateBuiltinBit32BinaryOp( +static BuiltinImplResult translateBuiltinBit32BinaryOp( IrBuilder& build, LuauBuiltinFunction bfid, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) { if (nparams < 2 || nparams > kBit32BinaryOpUnrolledParams || nresults > 1) return {BuiltinImplType::None, -1}; - build.loadAndCheckTag(build.vmReg(arg), LUA_TNUMBER, fallback); - build.loadAndCheckTag(args, LUA_TNUMBER, fallback); + builtinCheckDouble(build, build.vmReg(arg), fallback); + builtinCheckDouble(build, args, fallback); for (int i = 3; i <= nparams; ++i) - build.loadAndCheckTag(build.vmReg(vmRegOp(args) + (i - 2)), LUA_TNUMBER, fallback); + builtinCheckDouble(build, build.vmReg(vmRegOp(args) + (i - 2)), fallback); - IrOp va = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(arg)); - IrOp vb = build.inst(IrCmd::LOAD_DOUBLE, args); + IrOp va = builtinLoadDouble(build, build.vmReg(arg)); + IrOp vb = builtinLoadDouble(build, args); IrOp vaui = build.inst(IrCmd::NUM_TO_UINT, va); IrOp vbui = build.inst(IrCmd::NUM_TO_UINT, vb); @@ -399,7 +389,7 @@ BuiltinImplResult translateBuiltinBit32BinaryOp( for (int i = 3; i <= nparams; ++i) { - IrOp vc = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(vmRegOp(args) + (i - 2))); + IrOp vc = builtinLoadDouble(build, build.vmReg(vmRegOp(args) + (i - 2))); IrOp arg = build.inst(IrCmd::NUM_TO_UINT, vc); res = build.inst(cmd, res, arg); @@ -436,14 +426,14 @@ BuiltinImplResult translateBuiltinBit32BinaryOp( return {BuiltinImplType::UsesFallback, 1}; } -BuiltinImplResult translateBuiltinBit32Bnot( +static BuiltinImplResult translateBuiltinBit32Bnot( IrBuilder& build, LuauBuiltinFunction bfid, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) { if (nparams < 1 || nresults > 1) return {BuiltinImplType::None, -1}; - build.loadAndCheckTag(build.vmReg(arg), LUA_TNUMBER, fallback); - IrOp va = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(arg)); + builtinCheckDouble(build, build.vmReg(arg), fallback); + IrOp va = builtinLoadDouble(build, build.vmReg(arg)); IrOp vaui = build.inst(IrCmd::NUM_TO_UINT, va); IrOp not_ = build.inst(IrCmd::BITNOT_UINT, vaui); @@ -457,7 +447,7 @@ BuiltinImplResult translateBuiltinBit32Bnot( return {BuiltinImplType::UsesFallback, 1}; } -BuiltinImplResult translateBuiltinBit32Shift( +static BuiltinImplResult translateBuiltinBit32Shift( IrBuilder& build, LuauBuiltinFunction bfid, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) { if (nparams < 2 || nresults > 1) @@ -465,16 +455,16 @@ BuiltinImplResult translateBuiltinBit32Shift( IrOp block = build.block(IrBlockKind::Internal); - build.loadAndCheckTag(build.vmReg(arg), LUA_TNUMBER, fallback); - build.loadAndCheckTag(args, LUA_TNUMBER, fallback); + builtinCheckDouble(build, build.vmReg(arg), fallback); + builtinCheckDouble(build, args, fallback); - IrOp va = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(arg)); - IrOp vb = build.inst(IrCmd::LOAD_DOUBLE, args); + IrOp va = builtinLoadDouble(build, build.vmReg(arg)); + IrOp vb = builtinLoadDouble(build, args); IrOp vaui = build.inst(IrCmd::NUM_TO_UINT, va); IrOp vbi = build.inst(IrCmd::NUM_TO_INT, vb); - build.inst(IrCmd::JUMP_GE_UINT, vbi, build.constUint(32), fallback, block); + build.inst(IrCmd::JUMP_GE_UINT, vbi, build.constInt(32), fallback, block); build.beginBlock(block); IrCmd cmd = IrCmd::NOP; @@ -498,17 +488,17 @@ BuiltinImplResult translateBuiltinBit32Shift( return {BuiltinImplType::UsesFallback, 1}; } -BuiltinImplResult translateBuiltinBit32Rotate( +static BuiltinImplResult translateBuiltinBit32Rotate( IrBuilder& build, LuauBuiltinFunction bfid, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) { if (nparams < 2 || nresults > 1) return {BuiltinImplType::None, -1}; - build.loadAndCheckTag(build.vmReg(arg), LUA_TNUMBER, fallback); - build.loadAndCheckTag(args, LUA_TNUMBER, fallback); + builtinCheckDouble(build, build.vmReg(arg), fallback); + builtinCheckDouble(build, args, fallback); - IrOp va = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(arg)); - IrOp vb = build.inst(IrCmd::LOAD_DOUBLE, args); + IrOp va = builtinLoadDouble(build, build.vmReg(arg)); + IrOp vb = builtinLoadDouble(build, args); IrOp vaui = build.inst(IrCmd::NUM_TO_UINT, va); IrOp vbi = build.inst(IrCmd::NUM_TO_INT, vb); @@ -525,17 +515,17 @@ BuiltinImplResult translateBuiltinBit32Rotate( return {BuiltinImplType::UsesFallback, 1}; } -BuiltinImplResult translateBuiltinBit32Extract( +static BuiltinImplResult translateBuiltinBit32Extract( IrBuilder& build, LuauBuiltinFunction bfid, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) { if (nparams < 2 || nresults > 1) return {BuiltinImplType::None, -1}; - build.loadAndCheckTag(build.vmReg(arg), LUA_TNUMBER, fallback); - build.loadAndCheckTag(args, LUA_TNUMBER, fallback); + builtinCheckDouble(build, build.vmReg(arg), fallback); + builtinCheckDouble(build, args, fallback); - IrOp va = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(arg)); - IrOp vb = build.inst(IrCmd::LOAD_DOUBLE, args); + IrOp va = builtinLoadDouble(build, build.vmReg(arg)); + IrOp vb = builtinLoadDouble(build, args); IrOp n = build.inst(IrCmd::NUM_TO_UINT, va); IrOp f = build.inst(IrCmd::NUM_TO_INT, vb); @@ -544,17 +534,17 @@ BuiltinImplResult translateBuiltinBit32Extract( if (nparams == 2) { IrOp block = build.block(IrBlockKind::Internal); - build.inst(IrCmd::JUMP_GE_UINT, f, build.constUint(32), fallback, block); + build.inst(IrCmd::JUMP_GE_UINT, f, build.constInt(32), fallback, block); build.beginBlock(block); // TODO: this can be optimized using a bit-select instruction (bt on x86) IrOp shift = build.inst(IrCmd::BITRSHIFT_UINT, n, f); - value = build.inst(IrCmd::BITAND_UINT, shift, build.constUint(1)); + value = build.inst(IrCmd::BITAND_UINT, shift, build.constInt(1)); } else { - build.loadAndCheckTag(build.vmReg(args.index + 1), LUA_TNUMBER, fallback); - IrOp vc = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(args.index + 1)); + builtinCheckDouble(build, build.vmReg(args.index + 1), fallback); + IrOp vc = builtinLoadDouble(build, build.vmReg(args.index + 1)); IrOp w = build.inst(IrCmd::NUM_TO_INT, vc); IrOp block1 = build.block(IrBlockKind::Internal); @@ -570,7 +560,7 @@ BuiltinImplResult translateBuiltinBit32Extract( build.inst(IrCmd::JUMP_LT_INT, fw, build.constInt(33), block3, fallback); build.beginBlock(block3); - IrOp shift = build.inst(IrCmd::BITLSHIFT_UINT, build.constUint(0xfffffffe), build.inst(IrCmd::SUB_INT, w, build.constInt(1))); + IrOp shift = build.inst(IrCmd::BITLSHIFT_UINT, build.constInt(0xfffffffe), build.inst(IrCmd::SUB_INT, w, build.constInt(1))); IrOp m = build.inst(IrCmd::BITNOT_UINT, shift); IrOp nf = build.inst(IrCmd::BITRSHIFT_UINT, n, f); @@ -585,15 +575,15 @@ BuiltinImplResult translateBuiltinBit32Extract( return {BuiltinImplType::UsesFallback, 1}; } -BuiltinImplResult translateBuiltinBit32ExtractK( +static BuiltinImplResult translateBuiltinBit32ExtractK( IrBuilder& build, LuauBuiltinFunction bfid, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) { if (nparams < 2 || nresults > 1) return {BuiltinImplType::None, -1}; - build.loadAndCheckTag(build.vmReg(arg), LUA_TNUMBER, fallback); + builtinCheckDouble(build, build.vmReg(arg), fallback); - IrOp va = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(arg)); + IrOp va = builtinLoadDouble(build, build.vmReg(arg)); IrOp n = build.inst(IrCmd::NUM_TO_UINT, va); double a2 = build.function.doubleOp(args); @@ -604,8 +594,8 @@ BuiltinImplResult translateBuiltinBit32ExtractK( uint32_t m = ~(0xfffffffeu << w1); - IrOp nf = build.inst(IrCmd::BITRSHIFT_UINT, n, build.constUint(f)); - IrOp and_ = build.inst(IrCmd::BITAND_UINT, nf, build.constUint(m)); + IrOp nf = build.inst(IrCmd::BITRSHIFT_UINT, n, build.constInt(f)); + IrOp and_ = build.inst(IrCmd::BITAND_UINT, nf, build.constInt(m)); IrOp value = build.inst(IrCmd::UINT_TO_NUM, and_); build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), value); @@ -616,14 +606,14 @@ BuiltinImplResult translateBuiltinBit32ExtractK( return {BuiltinImplType::UsesFallback, 1}; } -BuiltinImplResult translateBuiltinBit32Countz( +static BuiltinImplResult translateBuiltinBit32Countz( IrBuilder& build, LuauBuiltinFunction bfid, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) { if (nparams < 1 || nresults > 1) return {BuiltinImplType::None, -1}; - build.loadAndCheckTag(build.vmReg(arg), LUA_TNUMBER, fallback); - IrOp va = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(arg)); + builtinCheckDouble(build, build.vmReg(arg), fallback); + IrOp va = builtinLoadDouble(build, build.vmReg(arg)); IrOp vaui = build.inst(IrCmd::NUM_TO_UINT, va); @@ -640,19 +630,19 @@ BuiltinImplResult translateBuiltinBit32Countz( return {BuiltinImplType::UsesFallback, 1}; } -BuiltinImplResult translateBuiltinBit32Replace( +static BuiltinImplResult translateBuiltinBit32Replace( IrBuilder& build, LuauBuiltinFunction bfid, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) { if (nparams < 3 || nresults > 1) return {BuiltinImplType::None, -1}; - build.loadAndCheckTag(build.vmReg(arg), LUA_TNUMBER, fallback); - build.loadAndCheckTag(args, LUA_TNUMBER, fallback); - build.loadAndCheckTag(build.vmReg(args.index + 1), LUA_TNUMBER, fallback); + builtinCheckDouble(build, build.vmReg(arg), fallback); + builtinCheckDouble(build, args, fallback); + builtinCheckDouble(build, build.vmReg(args.index + 1), fallback); - IrOp va = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(arg)); - IrOp vb = build.inst(IrCmd::LOAD_DOUBLE, args); - IrOp vc = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(args.index + 1)); + IrOp va = builtinLoadDouble(build, build.vmReg(arg)); + IrOp vb = builtinLoadDouble(build, args); + IrOp vc = builtinLoadDouble(build, build.vmReg(args.index + 1)); IrOp n = build.inst(IrCmd::NUM_TO_UINT, va); IrOp v = build.inst(IrCmd::NUM_TO_UINT, vb); @@ -662,11 +652,11 @@ BuiltinImplResult translateBuiltinBit32Replace( if (nparams == 3) { IrOp block = build.block(IrBlockKind::Internal); - build.inst(IrCmd::JUMP_GE_UINT, f, build.constUint(32), fallback, block); + build.inst(IrCmd::JUMP_GE_UINT, f, build.constInt(32), fallback, block); build.beginBlock(block); // TODO: this can be optimized using a bit-select instruction (btr on x86) - IrOp m = build.constUint(1); + IrOp m = build.constInt(1); IrOp shift = build.inst(IrCmd::BITLSHIFT_UINT, m, f); IrOp not_ = build.inst(IrCmd::BITNOT_UINT, shift); IrOp lhs = build.inst(IrCmd::BITAND_UINT, n, not_); @@ -678,8 +668,8 @@ BuiltinImplResult translateBuiltinBit32Replace( } else { - build.loadAndCheckTag(build.vmReg(args.index + 2), LUA_TNUMBER, fallback); - IrOp vd = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(args.index + 2)); + builtinCheckDouble(build, build.vmReg(args.index + 2), fallback); + IrOp vd = builtinLoadDouble(build, build.vmReg(args.index + 2)); IrOp w = build.inst(IrCmd::NUM_TO_INT, vd); IrOp block1 = build.block(IrBlockKind::Internal); @@ -695,7 +685,7 @@ BuiltinImplResult translateBuiltinBit32Replace( build.inst(IrCmd::JUMP_LT_INT, fw, build.constInt(33), block3, fallback); build.beginBlock(block3); - IrOp shift1 = build.inst(IrCmd::BITLSHIFT_UINT, build.constUint(0xfffffffe), build.inst(IrCmd::SUB_INT, w, build.constInt(1))); + IrOp shift1 = build.inst(IrCmd::BITLSHIFT_UINT, build.constInt(0xfffffffe), build.inst(IrCmd::SUB_INT, w, build.constInt(1))); IrOp m = build.inst(IrCmd::BITNOT_UINT, shift1); IrOp shift2 = build.inst(IrCmd::BITLSHIFT_UINT, m, f); @@ -716,20 +706,20 @@ BuiltinImplResult translateBuiltinBit32Replace( return {BuiltinImplType::UsesFallback, 1}; } -BuiltinImplResult translateBuiltinVector(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) +static BuiltinImplResult translateBuiltinVector(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) { if (nparams < 3 || nresults > 1) return {BuiltinImplType::None, -1}; LUAU_ASSERT(LUA_VECTOR_SIZE == 3); - build.loadAndCheckTag(build.vmReg(arg), LUA_TNUMBER, fallback); - build.loadAndCheckTag(args, LUA_TNUMBER, fallback); - build.loadAndCheckTag(build.vmReg(vmRegOp(args) + 1), LUA_TNUMBER, fallback); + builtinCheckDouble(build, build.vmReg(arg), fallback); + builtinCheckDouble(build, args, fallback); + builtinCheckDouble(build, build.vmReg(vmRegOp(args) + 1), fallback); - IrOp x = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(arg)); - IrOp y = build.inst(IrCmd::LOAD_DOUBLE, args); - IrOp z = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(vmRegOp(args) + 1)); + IrOp x = builtinLoadDouble(build, build.vmReg(arg)); + IrOp y = builtinLoadDouble(build, args); + IrOp z = builtinLoadDouble(build, build.vmReg(vmRegOp(args) + 1)); build.inst(IrCmd::STORE_VECTOR, build.vmReg(ra), x, y, z); build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TVECTOR)); @@ -769,8 +759,6 @@ BuiltinImplResult translateBuiltin(IrBuilder& build, int bfid, int ra, int arg, return translateBuiltinMathUnary(build, IrCmd::ABS_NUM, nparams, ra, arg, nresults, fallback); case LBF_MATH_ROUND: return translateBuiltinMathUnary(build, IrCmd::ROUND_NUM, nparams, ra, arg, nresults, fallback); - case LBF_MATH_POW: - return translateBuiltinMathBinary(build, IrCmd::POW_NUM, nparams, ra, arg, args, nresults, fallback); case LBF_MATH_EXP: case LBF_MATH_ASIN: case LBF_MATH_SIN: @@ -785,6 +773,7 @@ BuiltinImplResult translateBuiltin(IrBuilder& build, int bfid, int ra, int arg, return translateBuiltinNumberToNumberLibm(build, LuauBuiltinFunction(bfid), nparams, ra, arg, args, nresults, fallback); case LBF_MATH_SIGN: return translateBuiltinNumberToNumber(build, LuauBuiltinFunction(bfid), nparams, ra, arg, args, nresults, fallback); + case LBF_MATH_POW: case LBF_MATH_FMOD: case LBF_MATH_ATAN2: return translateBuiltin2NumberToNumberLibm(build, LuauBuiltinFunction(bfid), nparams, ra, arg, args, nresults, fallback); diff --git a/CodeGen/src/IrTranslation.cpp b/CodeGen/src/IrTranslation.cpp index a42a72696..ebbcd875c 100644 --- a/CodeGen/src/IrTranslation.cpp +++ b/CodeGen/src/IrTranslation.cpp @@ -342,7 +342,7 @@ static void translateInstBinaryNumeric(IrBuilder& build, int ra, int rb, int rc, result = build.inst(IrCmd::MOD_NUM, vb, vc); break; case TM_POW: - result = build.inst(IrCmd::POW_NUM, vb, vc); + result = build.inst(IrCmd::INVOKE_LIBM, build.constUint(LBF_MATH_POW), vb, vc); break; default: LUAU_ASSERT(!"unsupported binary op"); @@ -498,8 +498,6 @@ void translateFastCallN(IrBuilder& build, const Instruction* pc, int pcpos, bool int bfid = LUAU_INSN_A(*pc); int skip = LUAU_INSN_C(*pc); - IrOp fallback = build.block(IrBlockKind::Fallback); - Instruction call = pc[skip + 1]; LUAU_ASSERT(LUAU_INSN_OP(call) == LOP_CALL); int ra = LUAU_INSN_A(call); @@ -509,15 +507,21 @@ void translateFastCallN(IrBuilder& build, const Instruction* pc, int pcpos, bool int arg = customParams ? LUAU_INSN_B(*pc) : ra + 1; IrOp args = customParams ? customArgs : build.vmReg(ra + 2); - build.inst(IrCmd::CHECK_SAFE_ENV, fallback); + IrOp builtinArgs = args; - if (bfid == LBF_BIT32_EXTRACTK) + if (customArgs.kind == IrOpKind::VmConst) { - TValue protok = build.function.proto->k[pc[1]]; - args = build.constDouble(protok.value.n); + TValue protok = build.function.proto->k[customArgs.index]; + + if (protok.tt == LUA_TNUMBER) + builtinArgs = build.constDouble(protok.value.n); } - BuiltinImplResult br = translateBuiltin(build, LuauBuiltinFunction(bfid), ra, arg, args, nparams, nresults, fallback); + IrOp fallback = build.block(IrBlockKind::Fallback); + + build.inst(IrCmd::CHECK_SAFE_ENV, fallback); + + BuiltinImplResult br = translateBuiltin(build, LuauBuiltinFunction(bfid), ra, arg, builtinArgs, nparams, nresults, fallback); if (br.type == BuiltinImplType::UsesFallback) { diff --git a/CodeGen/src/IrUtils.cpp b/CodeGen/src/IrUtils.cpp index 8be9e1b7a..a3af43449 100644 --- a/CodeGen/src/IrUtils.cpp +++ b/CodeGen/src/IrUtils.cpp @@ -3,6 +3,7 @@ #include "Luau/IrBuilder.h" +#include "BitUtils.h" #include "NativeState.h" #include "lua.h" @@ -54,7 +55,6 @@ IrValueKind getCmdValueKind(IrCmd cmd) case IrCmd::MUL_NUM: case IrCmd::DIV_NUM: case IrCmd::MOD_NUM: - case IrCmd::POW_NUM: case IrCmd::MIN_NUM: case IrCmd::MAX_NUM: case IrCmd::UNM_NUM: @@ -312,6 +312,8 @@ void substitute(IrFunction& function, IrInst& inst, IrOp replacement) inst.cmd = IrCmd::SUBSTITUTE; + addUse(function, replacement); + removeUse(function, inst.a); removeUse(function, inst.b); removeUse(function, inst.c); @@ -349,6 +351,9 @@ void applySubstitutions(IrFunction& function, IrOp& op) LUAU_ASSERT(src.useCount > 0); src.useCount--; + + if (src.useCount == 0) + removeUse(function, src.a); } } } @@ -444,17 +449,13 @@ void foldConstants(IrBuilder& build, IrFunction& function, IrBlock& block, uint3 if (inst.a.kind == IrOpKind::Constant && inst.b.kind == IrOpKind::Constant) substitute(function, inst, build.constDouble(luai_nummod(function.doubleOp(inst.a), function.doubleOp(inst.b)))); break; - case IrCmd::POW_NUM: - if (inst.a.kind == IrOpKind::Constant && inst.b.kind == IrOpKind::Constant) - substitute(function, inst, build.constDouble(pow(function.doubleOp(inst.a), function.doubleOp(inst.b)))); - break; case IrCmd::MIN_NUM: if (inst.a.kind == IrOpKind::Constant && inst.b.kind == IrOpKind::Constant) { double a1 = function.doubleOp(inst.a); double a2 = function.doubleOp(inst.b); - substitute(function, inst, build.constDouble((a2 < a1) ? a2 : a1)); + substitute(function, inst, build.constDouble(a1 < a2 ? a1 : a2)); } break; case IrCmd::MAX_NUM: @@ -463,7 +464,7 @@ void foldConstants(IrBuilder& build, IrFunction& function, IrBlock& block, uint3 double a1 = function.doubleOp(inst.a); double a2 = function.doubleOp(inst.b); - substitute(function, inst, build.constDouble((a2 > a1) ? a2 : a1)); + substitute(function, inst, build.constDouble(a1 > a2 ? a1 : a2)); } break; case IrCmd::UNM_NUM: @@ -533,7 +534,7 @@ void foldConstants(IrBuilder& build, IrFunction& function, IrBlock& block, uint3 case IrCmd::JUMP_GE_UINT: if (inst.a.kind == IrOpKind::Constant && inst.b.kind == IrOpKind::Constant) { - if (function.uintOp(inst.a) >= function.uintOp(inst.b)) + if (unsigned(function.intOp(inst.a)) >= unsigned(function.intOp(inst.b))) replace(function, block, index, {IrCmd::JUMP, inst.c}); else replace(function, block, index, {IrCmd::JUMP, inst.d}); @@ -573,6 +574,30 @@ void foldConstants(IrBuilder& build, IrFunction& function, IrBlock& block, uint3 if (inst.a.kind == IrOpKind::Constant) substitute(function, inst, build.constDouble(double(function.intOp(inst.a)))); break; + case IrCmd::UINT_TO_NUM: + if (inst.a.kind == IrOpKind::Constant) + substitute(function, inst, build.constDouble(double(unsigned(function.intOp(inst.a))))); + break; + case IrCmd::NUM_TO_INT: + if (inst.a.kind == IrOpKind::Constant) + { + double value = function.doubleOp(inst.a); + + // To avoid undefined behavior of casting a value not representable in the target type, we check the range + if (value >= INT_MIN && value <= INT_MAX) + substitute(function, inst, build.constInt(int(value))); + } + break; + case IrCmd::NUM_TO_UINT: + if (inst.a.kind == IrOpKind::Constant) + { + double value = function.doubleOp(inst.a); + + // To avoid undefined behavior of casting a value not representable in the target type, we check the range + if (value >= 0 && value <= UINT_MAX) + substitute(function, inst, build.constInt(unsigned(function.doubleOp(inst.a)))); + } + break; case IrCmd::CHECK_TAG: if (inst.a.kind == IrOpKind::Constant && inst.b.kind == IrOpKind::Constant) { @@ -582,12 +607,139 @@ void foldConstants(IrBuilder& build, IrFunction& function, IrBlock& block, uint3 replace(function, block, index, {IrCmd::JUMP, inst.c}); // Shows a conflict in assumptions on this path } break; + case IrCmd::BITAND_UINT: + if (inst.a.kind == IrOpKind::Constant && inst.b.kind == IrOpKind::Constant) + { + unsigned op1 = unsigned(function.intOp(inst.a)); + unsigned op2 = unsigned(function.intOp(inst.b)); + substitute(function, inst, build.constInt(op1 & op2)); + } + else + { + if (inst.a.kind == IrOpKind::Constant && function.intOp(inst.a) == 0) // (0 & b) -> 0 + substitute(function, inst, build.constInt(0)); + else if (inst.a.kind == IrOpKind::Constant && function.intOp(inst.a) == -1) // (-1 & b) -> b + substitute(function, inst, inst.b); + else if (inst.b.kind == IrOpKind::Constant && function.intOp(inst.b) == 0) // (a & 0) -> 0 + substitute(function, inst, build.constInt(0)); + else if (inst.b.kind == IrOpKind::Constant && function.intOp(inst.b) == -1) // (a & -1) -> a + substitute(function, inst, inst.a); + } + break; + case IrCmd::BITXOR_UINT: + if (inst.a.kind == IrOpKind::Constant && inst.b.kind == IrOpKind::Constant) + { + unsigned op1 = unsigned(function.intOp(inst.a)); + unsigned op2 = unsigned(function.intOp(inst.b)); + substitute(function, inst, build.constInt(op1 ^ op2)); + } + else + { + if (inst.a.kind == IrOpKind::Constant && function.intOp(inst.a) == 0) // (0 ^ b) -> b + substitute(function, inst, inst.b); + else if (inst.a.kind == IrOpKind::Constant && function.intOp(inst.a) == -1) // (-1 ^ b) -> ~b + replace(function, block, index, {IrCmd::BITNOT_UINT, inst.b}); + else if (inst.b.kind == IrOpKind::Constant && function.intOp(inst.b) == 0) // (a ^ 0) -> a + substitute(function, inst, inst.a); + else if (inst.b.kind == IrOpKind::Constant && function.intOp(inst.b) == -1) // (a ^ -1) -> ~a + replace(function, block, index, {IrCmd::BITNOT_UINT, inst.a}); + } + break; + case IrCmd::BITOR_UINT: + if (inst.a.kind == IrOpKind::Constant && inst.b.kind == IrOpKind::Constant) + { + unsigned op1 = unsigned(function.intOp(inst.a)); + unsigned op2 = unsigned(function.intOp(inst.b)); + substitute(function, inst, build.constInt(op1 | op2)); + } + else + { + if (inst.a.kind == IrOpKind::Constant && function.intOp(inst.a) == 0) // (0 | b) -> b + substitute(function, inst, inst.b); + else if (inst.a.kind == IrOpKind::Constant && function.intOp(inst.a) == -1) // (-1 | b) -> -1 + substitute(function, inst, build.constInt(-1)); + else if (inst.b.kind == IrOpKind::Constant && function.intOp(inst.b) == 0) // (a | 0) -> a + substitute(function, inst, inst.a); + else if (inst.b.kind == IrOpKind::Constant && function.intOp(inst.b) == -1) // (a | -1) -> -1 + substitute(function, inst, build.constInt(-1)); + } + break; + case IrCmd::BITNOT_UINT: + if (inst.a.kind == IrOpKind::Constant) + substitute(function, inst, build.constInt(~unsigned(function.intOp(inst.a)))); + break; + case IrCmd::BITLSHIFT_UINT: + if (inst.a.kind == IrOpKind::Constant && inst.b.kind == IrOpKind::Constant) + { + unsigned op1 = unsigned(function.intOp(inst.a)); + int op2 = function.intOp(inst.b); + + if (unsigned(op2) < 32) + substitute(function, inst, build.constInt(op1 << op2)); + } + else if (inst.b.kind == IrOpKind::Constant && function.intOp(inst.b) == 0) + { + substitute(function, inst, inst.a); + } + break; + case IrCmd::BITRSHIFT_UINT: + if (inst.a.kind == IrOpKind::Constant && inst.b.kind == IrOpKind::Constant) + { + unsigned op1 = unsigned(function.intOp(inst.a)); + int op2 = function.intOp(inst.b); + + if (unsigned(op2) < 32) + substitute(function, inst, build.constInt(op1 >> op2)); + } + else if (inst.b.kind == IrOpKind::Constant && function.intOp(inst.b) == 0) + { + substitute(function, inst, inst.a); + } + break; + case IrCmd::BITARSHIFT_UINT: + if (inst.a.kind == IrOpKind::Constant && inst.b.kind == IrOpKind::Constant) + { + int op1 = function.intOp(inst.a); + int op2 = function.intOp(inst.b); + + if (unsigned(op2) < 32) + { + // note: technically right shift of negative values is UB, but this behavior is getting defined in C++20 and all compilers do the + // right (shift) thing. + substitute(function, inst, build.constInt(op1 >> op2)); + } + } + else if (inst.b.kind == IrOpKind::Constant && function.intOp(inst.b) == 0) + { + substitute(function, inst, inst.a); + } + break; + case IrCmd::BITLROTATE_UINT: + if (inst.a.kind == IrOpKind::Constant && inst.b.kind == IrOpKind::Constant) + substitute(function, inst, build.constInt(lrotate(unsigned(function.intOp(inst.a)), function.intOp(inst.b)))); + else if (inst.b.kind == IrOpKind::Constant && function.intOp(inst.b) == 0) + substitute(function, inst, inst.a); + break; + case IrCmd::BITRROTATE_UINT: + if (inst.a.kind == IrOpKind::Constant && inst.b.kind == IrOpKind::Constant) + substitute(function, inst, build.constInt(rrotate(unsigned(function.intOp(inst.a)), function.intOp(inst.b)))); + else if (inst.b.kind == IrOpKind::Constant && function.intOp(inst.b) == 0) + substitute(function, inst, inst.a); + break; + case IrCmd::BITCOUNTLZ_UINT: + if (inst.a.kind == IrOpKind::Constant) + substitute(function, inst, build.constInt(countlz(unsigned(function.intOp(inst.a))))); + break; + case IrCmd::BITCOUNTRZ_UINT: + if (inst.a.kind == IrOpKind::Constant) + substitute(function, inst, build.constInt(countrz(unsigned(function.intOp(inst.a))))); + break; default: break; } } -uint32_t getNativeContextOffset(LuauBuiltinFunction bfid) +uint32_t getNativeContextOffset(int bfid) { switch (bfid) { @@ -607,6 +759,8 @@ uint32_t getNativeContextOffset(LuauBuiltinFunction bfid) return offsetof(NativeContext, libm_exp); case LBF_MATH_LOG10: return offsetof(NativeContext, libm_log10); + case LBF_MATH_LOG: + return offsetof(NativeContext, libm_log); case LBF_MATH_SINH: return offsetof(NativeContext, libm_sinh); case LBF_MATH_SIN: @@ -617,6 +771,10 @@ uint32_t getNativeContextOffset(LuauBuiltinFunction bfid) return offsetof(NativeContext, libm_tan); case LBF_MATH_FMOD: return offsetof(NativeContext, libm_fmod); + case LBF_MATH_POW: + return offsetof(NativeContext, libm_pow); + case LBF_IR_MATH_LOG2: + return offsetof(NativeContext, libm_log2); default: LUAU_ASSERT(!"Unsupported bfid"); } diff --git a/CodeGen/src/IrValueLocationTracking.cpp b/CodeGen/src/IrValueLocationTracking.cpp index b8220f54d..be661a7df 100644 --- a/CodeGen/src/IrValueLocationTracking.cpp +++ b/CodeGen/src/IrValueLocationTracking.cpp @@ -117,7 +117,6 @@ void IrValueLocationTracking::beforeInstLowering(IrInst& inst) case IrCmd::MUL_NUM: case IrCmd::DIV_NUM: case IrCmd::MOD_NUM: - case IrCmd::POW_NUM: case IrCmd::MIN_NUM: case IrCmd::MAX_NUM: case IrCmd::JUMP_EQ_TAG: diff --git a/CodeGen/src/NativeState.h b/CodeGen/src/NativeState.h index 9587a228a..459aeaa6d 100644 --- a/CodeGen/src/NativeState.h +++ b/CodeGen/src/NativeState.h @@ -27,9 +27,12 @@ using FallbackFn = const Instruction* (*)(lua_State* L, const Instruction* pc, S struct NativeProto { - uintptr_t entryTarget = 0; - uintptr_t* instTargets = nullptr; // TODO: NativeProto should be variable-size with all target embedded + // This array is stored before NativeProto in reverse order, so to get offset of instruction i you need to index instOffsets[-i] + // This awkward layout is helpful for maximally efficient address computation on X64/A64 + uint32_t instOffsets[1]; + uintptr_t instBase = 0; + uintptr_t entryTarget = 0; // = instOffsets[0] + instBase Proto* proto = nullptr; }; diff --git a/CodeGen/src/OptimizeConstProp.cpp b/CodeGen/src/OptimizeConstProp.cpp index 37af5a39b..e7663666a 100644 --- a/CodeGen/src/OptimizeConstProp.cpp +++ b/CodeGen/src/OptimizeConstProp.cpp @@ -8,6 +8,7 @@ #include "lua.h" +#include #include LUAU_FASTINTVARIABLE(LuauCodeGenMinLinearBlockPath, 3) @@ -42,8 +43,9 @@ struct RegisterLink // Data we know about the current VM state struct ConstPropState { - ConstPropState(const IrFunction& function) + ConstPropState(IrFunction& function) : function(function) + , valueMap({}) { } @@ -58,7 +60,13 @@ struct ConstPropState void saveTag(IrOp op, uint8_t tag) { if (RegisterInfo* info = tryGetRegisterInfo(op)) - info->tag = tag; + { + if (info->tag != tag) + { + info->tag = tag; + info->version++; + } + } } IrOp tryGetValue(IrOp op) @@ -74,7 +82,15 @@ struct ConstPropState LUAU_ASSERT(value.kind == IrOpKind::Constant); if (RegisterInfo* info = tryGetRegisterInfo(op)) - info->value = value; + { + if (info->value != value) + { + info->value = value; + info->knownNotReadonly = false; + info->knownNoMetatable = false; + info->version++; + } + } } void invalidate(RegisterInfo& reg, bool invalidateTag, bool invalidateValue) @@ -96,16 +112,22 @@ struct ConstPropState void invalidateTag(IrOp regOp) { + // TODO: use maxstacksize from Proto + maxReg = vmRegOp(regOp) > maxReg ? vmRegOp(regOp) : maxReg; invalidate(regs[vmRegOp(regOp)], /* invalidateTag */ true, /* invalidateValue */ false); } void invalidateValue(IrOp regOp) { + // TODO: use maxstacksize from Proto + maxReg = vmRegOp(regOp) > maxReg ? vmRegOp(regOp) : maxReg; invalidate(regs[vmRegOp(regOp)], /* invalidateTag */ false, /* invalidateValue */ true); } void invalidate(IrOp regOp) { + // TODO: use maxstacksize from Proto + maxReg = vmRegOp(regOp) > maxReg ? vmRegOp(regOp) : maxReg; invalidate(regs[vmRegOp(regOp)], /* invalidateTag */ true, /* invalidateValue */ true); } @@ -113,8 +135,6 @@ struct ConstPropState { for (int i = firstReg; i <= maxReg; ++i) invalidate(regs[i], /* invalidateTag */ true, /* invalidateValue */ true); - - maxReg = int(firstReg) - 1; } void invalidateRegisterRange(int firstReg, int count) @@ -191,9 +211,90 @@ struct ConstPropState return nullptr; } - const IrFunction& function; + // Attach register version number to the register operand in a load instruction + // This is used to allow instructions with register references to be compared for equality + IrInst versionedVmRegLoad(IrCmd loadCmd, IrOp op) + { + LUAU_ASSERT(op.kind == IrOpKind::VmReg); + uint32_t version = regs[vmRegOp(op)].version; + LUAU_ASSERT(version <= 0xffffff); + op.index = vmRegOp(op) | (version << 8); + return IrInst{loadCmd, op}; + } + + // Find existing value of the instruction that is exactly the same, or record current on for future lookups + void substituteOrRecord(IrInst& inst, uint32_t instIdx) + { + if (!useValueNumbering) + return; + + if (uint32_t* prevIdx = valueMap.find(inst)) + substitute(function, inst, IrOp{IrOpKind::Inst, *prevIdx}); + else + valueMap[inst] = instIdx; + } + + // Vm register load can be replaced by a previous load of the same version of the register + // If there is no previous load, we record the current one for future lookups + void substituteOrRecordVmRegLoad(IrInst& loadInst) + { + LUAU_ASSERT(loadInst.a.kind == IrOpKind::VmReg); + + if (!useValueNumbering) + return; + + // To avoid captured register invalidation tracking in lowering later, values from loads from captured registers are not propagated + // This prevents the case where load value location is linked to memory in case of a spill and is then cloberred in a user call + if (function.cfg.captured.regs.test(vmRegOp(loadInst.a))) + return; + + IrInst versionedLoad = versionedVmRegLoad(loadInst.cmd, loadInst.a); + + // Check if there is a value that already has this version of the register + if (uint32_t* prevIdx = valueMap.find(versionedLoad)) + { + // Previous value might not be linked to a register yet + // For example, it could be a NEW_TABLE stored into a register and we might need to track guards made with this value + if (!instLink.contains(*prevIdx)) + createRegLink(*prevIdx, loadInst.a); + + // Substitute load instructon with the previous value + substitute(function, loadInst, IrOp{IrOpKind::Inst, *prevIdx}); + } + else + { + uint32_t instIdx = function.getInstIndex(loadInst); + + // Record load of this register version for future substitution + valueMap[versionedLoad] = instIdx; + + createRegLink(instIdx, loadInst.a); + } + } + + // VM register loads can use the value that was stored in the same Vm register earlier + void forwardVmRegStoreToLoad(const IrInst& storeInst, IrCmd loadCmd) + { + LUAU_ASSERT(storeInst.a.kind == IrOpKind::VmReg); + LUAU_ASSERT(storeInst.b.kind == IrOpKind::Inst); + + if (!useValueNumbering) + return; + + // To avoid captured register invalidation tracking in lowering later, values from stores into captured registers are not propagated + // This prevents the case where store creates an alternative value location in case of a spill and is then cloberred in a user call + if (function.cfg.captured.regs.test(vmRegOp(storeInst.a))) + return; + + // Future loads of this register version can use the value we stored + valueMap[versionedVmRegLoad(loadCmd, storeInst.a)] = storeInst.b.index; + } + + IrFunction& function; - RegisterInfo regs[256]; + bool useValueNumbering = false; + + std::array regs; // For range/full invalidations, we only want to visit a limited number of data that we have recorded int maxReg = 0; @@ -202,6 +303,8 @@ struct ConstPropState bool checkedGc = false; DenseHashMap instLink{~0u}; + + DenseHashMap valueMap; }; static void handleBuiltinEffects(ConstPropState& state, LuauBuiltinFunction bfid, uint32_t firstReturnReg, int nresults) @@ -277,6 +380,7 @@ static void handleBuiltinEffects(ConstPropState& state, LuauBuiltinFunction bfid } // TODO: classify further using switch above, some fastcalls only modify the value, not the tag + // TODO: fastcalls are different from calls and it might be possible to not invalidate all register starting from return state.invalidateRegistersFrom(firstReturnReg); } @@ -292,45 +396,65 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& break; case IrCmd::LOAD_POINTER: if (inst.a.kind == IrOpKind::VmReg) - state.createRegLink(index, inst.a); + state.substituteOrRecordVmRegLoad(inst); break; case IrCmd::LOAD_DOUBLE: if (IrOp value = state.tryGetValue(inst.a); value.kind == IrOpKind::Constant) substitute(function, inst, value); else if (inst.a.kind == IrOpKind::VmReg) - state.createRegLink(index, inst.a); + state.substituteOrRecordVmRegLoad(inst); break; case IrCmd::LOAD_INT: if (IrOp value = state.tryGetValue(inst.a); value.kind == IrOpKind::Constant) substitute(function, inst, value); else if (inst.a.kind == IrOpKind::VmReg) - state.createRegLink(index, inst.a); + state.substituteOrRecordVmRegLoad(inst); break; case IrCmd::LOAD_TVALUE: if (inst.a.kind == IrOpKind::VmReg) - state.createRegLink(index, inst.a); + state.substituteOrRecordVmRegLoad(inst); break; case IrCmd::STORE_TAG: if (inst.a.kind == IrOpKind::VmReg) { + const IrOp source = inst.a; + uint32_t activeLoadDoubleValue = kInvalidInstIdx; + if (inst.b.kind == IrOpKind::Constant) { uint8_t value = function.tagOp(inst.b); - if (state.tryGetTag(inst.a) == value) + // STORE_TAG usually follows a store of the value, but it also bumps the version of the whole register + // To be able to propagate STORE_DOUBLE into LOAD_DOUBLE, we find active LOAD_DOUBLE value and recreate it with updated version + // Register in this optimization cannot be captured to avoid complications in lowering (IrValueLocationTracking doesn't model it) + // If stored tag is not a number, we can skip the lookup as there won't be future loads of this register as a number + if (value == LUA_TNUMBER && !function.cfg.captured.regs.test(vmRegOp(source))) + { + if (uint32_t* prevIdx = state.valueMap.find(state.versionedVmRegLoad(IrCmd::LOAD_DOUBLE, source))) + activeLoadDoubleValue = *prevIdx; + } + + if (state.tryGetTag(source) == value) kill(function, inst); else - state.saveTag(inst.a, value); + state.saveTag(source, value); } else { - state.invalidateTag(inst.a); + state.invalidateTag(source); } + + // Future LOAD_DOUBLE instructions can re-use previous register version load + if (activeLoadDoubleValue != kInvalidInstIdx) + state.valueMap[state.versionedVmRegLoad(IrCmd::LOAD_DOUBLE, source)] = activeLoadDoubleValue; } break; case IrCmd::STORE_POINTER: if (inst.a.kind == IrOpKind::VmReg) + { state.invalidateValue(inst.a); + state.forwardVmRegStoreToLoad(inst, IrCmd::LOAD_POINTER); + } break; case IrCmd::STORE_DOUBLE: if (inst.a.kind == IrOpKind::VmReg) @@ -345,6 +469,7 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& else { state.invalidateValue(inst.a); + state.forwardVmRegStoreToLoad(inst, IrCmd::LOAD_DOUBLE); } } break; @@ -361,6 +486,7 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& else { state.invalidateValue(inst.a); + state.forwardVmRegStoreToLoad(inst, IrCmd::LOAD_INT); } } break; @@ -377,6 +503,8 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& if (IrOp value = state.tryGetValue(inst.b); value.kind != IrOpKind::None) state.saveValue(inst.a, value); + + state.forwardVmRegStoreToLoad(inst, IrCmd::LOAD_TVALUE); } break; case IrCmd::JUMP_IF_TRUTHY: @@ -540,11 +668,12 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& // These instructions don't have an effect on register/memory state we are tracking case IrCmd::NOP: case IrCmd::LOAD_NODE_VALUE_TV: + case IrCmd::STORE_NODE_VALUE_TV: case IrCmd::LOAD_ENV: case IrCmd::GET_ARR_ADDR: case IrCmd::GET_SLOT_NODE_ADDR: case IrCmd::GET_HASH_NODE_ADDR: - case IrCmd::STORE_NODE_VALUE_TV: + break; case IrCmd::ADD_INT: case IrCmd::SUB_INT: case IrCmd::ADD_NUM: @@ -552,7 +681,6 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& case IrCmd::MUL_NUM: case IrCmd::DIV_NUM: case IrCmd::MOD_NUM: - case IrCmd::POW_NUM: case IrCmd::MIN_NUM: case IrCmd::MAX_NUM: case IrCmd::UNM_NUM: @@ -562,6 +690,8 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& case IrCmd::SQRT_NUM: case IrCmd::ABS_NUM: case IrCmd::NOT_ANY: + state.substituteOrRecord(inst, index); + break; case IrCmd::JUMP: case IrCmd::JUMP_EQ_POINTER: case IrCmd::JUMP_SLOT_MATCH: @@ -581,7 +711,6 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& case IrCmd::RETURN: case IrCmd::COVERAGE: case IrCmd::SET_UPVALUE: - case IrCmd::SETLIST: // We don't track table state that this can invalidate case IrCmd::SET_SAVEDPC: // TODO: we may be able to remove some updates to PC case IrCmd::CLOSE_UPVALS: // Doesn't change memory that we track case IrCmd::CAPTURE: @@ -642,12 +771,21 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& case IrCmd::INTERRUPT: state.invalidateUserCall(); break; + case IrCmd::SETLIST: + state.valueMap.clear(); // TODO: this can be relaxed when x64 emitInstSetList becomes aware of register allocator + break; case IrCmd::CALL: state.invalidateRegistersFrom(vmRegOp(inst.a)); state.invalidateUserCall(); + + // We cannot guarantee right now that all live values can be remeterialized from non-stack memory locations + // To prevent earlier values from being propagated to after the call, we have to clear the map + // TODO: remove only the values that don't have a guaranteed restore location + state.valueMap.clear(); break; case IrCmd::FORGLOOP: state.invalidateRegistersFrom(vmRegOp(inst.a) + 2); // Rn and Rn+1 are not modified + state.valueMap.clear(); // TODO: this can be relaxed when x64 emitInstForGLoop becomes aware of register allocator break; case IrCmd::FORGLOOP_FALLBACK: state.invalidateRegistersFrom(vmRegOp(inst.a) + 2); // Rn and Rn+1 are not modified @@ -656,6 +794,8 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& case IrCmd::FORGPREP_XNEXT_FALLBACK: // This fallback only conditionally throws an exception break; + + // Full fallback instructions case IrCmd::FALLBACK_GETGLOBAL: state.invalidate(inst.b); state.invalidateUserCall(); @@ -678,7 +818,7 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& case IrCmd::FALLBACK_PREPVARARGS: break; case IrCmd::FALLBACK_GETVARARGS: - state.invalidateRegistersFrom(vmRegOp(inst.b)); + state.invalidateRegisterRange(vmRegOp(inst.b), function.intOp(inst.c)); break; case IrCmd::FALLBACK_NEWCLOSURE: state.invalidate(inst.b); @@ -709,13 +849,17 @@ static void constPropInBlock(IrBuilder& build, IrBlock& block, ConstPropState& s constPropInInst(state, build, function, block, inst, index); } + + // Value numbering and load/store propagation is not performed between blocks + state.valueMap.clear(); } -static void constPropInBlockChain(IrBuilder& build, std::vector& visited, IrBlock* block) +static void constPropInBlockChain(IrBuilder& build, std::vector& visited, IrBlock* block, bool useValueNumbering) { IrFunction& function = build.function; ConstPropState state{function}; + state.useValueNumbering = useValueNumbering; while (block) { @@ -792,7 +936,7 @@ static std::vector collectDirectBlockJumpPath(IrFunction& function, st return path; } -static void tryCreateLinearBlock(IrBuilder& build, std::vector& visited, IrBlock& startingBlock) +static void tryCreateLinearBlock(IrBuilder& build, std::vector& visited, IrBlock& startingBlock, bool useValueNumbering) { IrFunction& function = build.function; @@ -822,6 +966,7 @@ static void tryCreateLinearBlock(IrBuilder& build, std::vector& visited // Initialize state with the knowledge of our current block ConstPropState state{function}; + state.useValueNumbering = useValueNumbering; constPropInBlock(build, startingBlock, state); // Veryfy that target hasn't changed @@ -845,7 +990,7 @@ static void tryCreateLinearBlock(IrBuilder& build, std::vector& visited constPropInBlock(build, linearBlock, state); } -void constPropInBlockChains(IrBuilder& build) +void constPropInBlockChains(IrBuilder& build, bool useValueNumbering) { IrFunction& function = build.function; @@ -859,11 +1004,11 @@ void constPropInBlockChains(IrBuilder& build) if (visited[function.getBlockIndex(block)]) continue; - constPropInBlockChain(build, visited, &block); + constPropInBlockChain(build, visited, &block, useValueNumbering); } } -void createLinearBlocks(IrBuilder& build) +void createLinearBlocks(IrBuilder& build, bool useValueNumbering) { // Go through internal block chains and outline them into a single new block. // Outlining will be able to linearize the execution, even if there was a jump to a block with multiple users, @@ -884,7 +1029,7 @@ void createLinearBlocks(IrBuilder& build) if (visited[function.getBlockIndex(block)]) continue; - tryCreateLinearBlock(build, visited, block); + tryCreateLinearBlock(build, visited, block, useValueNumbering); } } diff --git a/CodeGen/src/OptimizeFinalX64.cpp b/CodeGen/src/OptimizeFinalX64.cpp index dd31fcc4f..5ee626ae4 100644 --- a/CodeGen/src/OptimizeFinalX64.cpp +++ b/CodeGen/src/OptimizeFinalX64.cpp @@ -40,7 +40,6 @@ static void optimizeMemoryOperandsX64(IrFunction& function, IrBlock& block) case IrCmd::MUL_NUM: case IrCmd::DIV_NUM: case IrCmd::MOD_NUM: - case IrCmd::POW_NUM: case IrCmd::MIN_NUM: case IrCmd::MAX_NUM: { diff --git a/Compiler/include/Luau/BytecodeBuilder.h b/Compiler/include/Luau/BytecodeBuilder.h index 047c1b67f..ba4232a01 100644 --- a/Compiler/include/Luau/BytecodeBuilder.h +++ b/Compiler/include/Luau/BytecodeBuilder.h @@ -79,6 +79,8 @@ class BytecodeBuilder void setDebugLine(int line); void pushDebugLocal(StringRef name, uint8_t reg, uint32_t startpc, uint32_t endpc); void pushDebugUpval(StringRef name); + + size_t getInstructionCount() const; uint32_t getDebugPC() const; void addDebugRemark(const char* format, ...) LUAU_PRINTF_ATTR(2, 3); diff --git a/Compiler/src/BytecodeBuilder.cpp b/Compiler/src/BytecodeBuilder.cpp index 8e450f4f8..b5690acb3 100644 --- a/Compiler/src/BytecodeBuilder.cpp +++ b/Compiler/src/BytecodeBuilder.cpp @@ -556,6 +556,11 @@ void BytecodeBuilder::pushDebugUpval(StringRef name) debugUpvals.push_back(upval); } +size_t BytecodeBuilder::getInstructionCount() const +{ + return insns.size(); +} + uint32_t BytecodeBuilder::getDebugPC() const { return uint32_t(insns.size()); diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index 9478404a0..9eda214c3 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -25,6 +25,8 @@ LUAU_FASTINTVARIABLE(LuauCompileInlineThreshold, 25) LUAU_FASTINTVARIABLE(LuauCompileInlineThresholdMaxBoost, 300) LUAU_FASTINTVARIABLE(LuauCompileInlineDepth, 5) +LUAU_FASTFLAGVARIABLE(LuauCompileLimitInsns, false) + namespace Luau { @@ -33,6 +35,7 @@ using namespace Luau::Compile; static const uint32_t kMaxRegisterCount = 255; static const uint32_t kMaxUpvalueCount = 200; static const uint32_t kMaxLocalCount = 200; +static const uint32_t kMaxInstructionCount = 1'000'000'000; static const uint8_t kInvalidReg = 255; @@ -247,6 +250,9 @@ struct Compiler popLocals(0); + if (FFlag::LuauCompileLimitInsns && bytecode.getInstructionCount() > kMaxInstructionCount) + CompileError::raise(func->location, "Exceeded function instruction limit; split the function into parts to compile"); + bytecode.endFunction(uint8_t(stackSize), uint8_t(upvals.size())); Function& f = functions[func]; diff --git a/Makefile b/Makefile index bbc66c2e7..aead3d32d 100644 --- a/Makefile +++ b/Makefile @@ -55,7 +55,7 @@ ifneq ($(opt),) TESTS_ARGS+=-O$(opt) endif -OBJECTS=$(AST_OBJECTS) $(COMPILER_OBJECTS) $(ANALYSIS_OBJECTS) $(CODEGEN_OBJECTS) $(VM_OBJECTS) $(ISOCLINE_OBJECTS) $(TESTS_OBJECTS) $(CLI_OBJECTS) $(FUZZ_OBJECTS) +OBJECTS=$(AST_OBJECTS) $(COMPILER_OBJECTS) $(ANALYSIS_OBJECTS) $(CODEGEN_OBJECTS) $(VM_OBJECTS) $(ISOCLINE_OBJECTS) $(TESTS_OBJECTS) $(REPL_CLI_OBJECTS) $(ANALYZE_CLI_OBJECTS) $(FUZZ_OBJECTS) EXECUTABLE_ALIASES = luau luau-analyze luau-tests # common flags diff --git a/VM/src/ldo.cpp b/VM/src/ldo.cpp index 264388bc9..0f4df6719 100644 --- a/VM/src/ldo.cpp +++ b/VM/src/ldo.cpp @@ -17,8 +17,6 @@ #include -LUAU_FASTFLAGVARIABLE(LuauBetterOOMHandling, false) - /* ** {====================================================== ** Error-recovery functions @@ -82,7 +80,7 @@ class lua_exception : public std::exception const char* what() const throw() override { // LUA_ERRRUN passes error object on the stack - if (status == LUA_ERRRUN || (status == LUA_ERRSYNTAX && !FFlag::LuauBetterOOMHandling)) + if (status == LUA_ERRRUN) if (const char* str = lua_tostring(L, -1)) return str; @@ -552,30 +550,21 @@ int luaD_pcall(lua_State* L, Pfunc func, void* u, ptrdiff_t old_top, ptrdiff_t e // call user-defined error function (used in xpcall) if (ef) { - if (FFlag::LuauBetterOOMHandling) - { - // push error object to stack top if it's not already there - if (status != LUA_ERRRUN) - seterrorobj(L, status, L->top); - - // if errfunc fails, we fail with "error in error handling" or "not enough memory" - int err = luaD_rawrunprotected(L, callerrfunc, restorestack(L, ef)); - - // in general we preserve the status, except for cases when the error handler fails - // out of memory is treated specially because it's common for it to be cascading, in which case we preserve the code - if (err == 0) - errstatus = LUA_ERRRUN; - else if (status == LUA_ERRMEM && err == LUA_ERRMEM) - errstatus = LUA_ERRMEM; - else - errstatus = status = LUA_ERRERR; - } + // push error object to stack top if it's not already there + if (status != LUA_ERRRUN) + seterrorobj(L, status, L->top); + + // if errfunc fails, we fail with "error in error handling" or "not enough memory" + int err = luaD_rawrunprotected(L, callerrfunc, restorestack(L, ef)); + + // in general we preserve the status, except for cases when the error handler fails + // out of memory is treated specially because it's common for it to be cascading, in which case we preserve the code + if (err == 0) + errstatus = LUA_ERRRUN; + else if (status == LUA_ERRMEM && err == LUA_ERRMEM) + errstatus = LUA_ERRMEM; else - { - // if errfunc fails, we fail with "error in error handling" - if (luaD_rawrunprotected(L, callerrfunc, restorestack(L, ef)) != 0) - status = LUA_ERRERR; - } + errstatus = status = LUA_ERRERR; } // since the call failed with an error, we might have to reset the 'active' thread state @@ -597,7 +586,7 @@ int luaD_pcall(lua_State* L, Pfunc func, void* u, ptrdiff_t old_top, ptrdiff_t e StkId oldtop = restorestack(L, old_top); luaF_close(L, oldtop); // close eventual pending closures - seterrorobj(L, FFlag::LuauBetterOOMHandling ? errstatus : status, oldtop); + seterrorobj(L, errstatus, oldtop); L->ci = restoreci(L, old_ci); L->base = L->ci->base; restore_stack_limit(L); diff --git a/VM/src/ltablib.cpp b/VM/src/ltablib.cpp index 4443be34f..9bc624e93 100644 --- a/VM/src/ltablib.cpp +++ b/VM/src/ltablib.cpp @@ -10,8 +10,6 @@ #include "ldebug.h" #include "lvm.h" -LUAU_FASTFLAGVARIABLE(LuauIntrosort, false) - static int foreachi(lua_State* L) { luaL_checktype(L, 1, LUA_TTABLE); @@ -389,7 +387,7 @@ static void sort_rec(lua_State* L, Table* t, int l, int u, int limit, SortPredic while (l < u) { // if the limit has been reached, quick sort is going over the permitted nlogn complexity, so we fall back to heap sort - if (FFlag::LuauIntrosort && limit == 0) + if (limit == 0) return sort_heap(L, t, l, u, pred); // sort elements a[l], a[(l+u)/2] and a[u] @@ -435,43 +433,20 @@ static void sort_rec(lua_State* L, Table* t, int l, int u, int limit, SortPredic // swap pivot a[p] with a[i], which is the new midpoint sort_swap(L, t, p, i); - if (FFlag::LuauIntrosort) - { - // adjust limit to allow 1.5 log2N recursive steps - limit = (limit >> 1) + (limit >> 2); + // adjust limit to allow 1.5 log2N recursive steps + limit = (limit >> 1) + (limit >> 2); - // a[l..i-1] <= a[i] == P <= a[i+1..u] - // sort smaller half recursively; the larger half is sorted in the next loop iteration - if (i - l < u - i) - { - sort_rec(L, t, l, i - 1, limit, pred); - l = i + 1; - } - else - { - sort_rec(L, t, i + 1, u, limit, pred); - u = i - 1; - } + // a[l..i-1] <= a[i] == P <= a[i+1..u] + // sort smaller half recursively; the larger half is sorted in the next loop iteration + if (i - l < u - i) + { + sort_rec(L, t, l, i - 1, limit, pred); + l = i + 1; } else { - // a[l..i-1] <= a[i] == P <= a[i+1..u] - // adjust so that smaller half is in [j..i] and larger one in [l..u] - if (i - l < u - i) - { - j = l; - i = i - 1; - l = i + 2; - } - else - { - j = i + 1; - i = u; - u = j - 2; - } - - // sort smaller half recursively; the larger half is sorted in the next loop iteration - sort_rec(L, t, j, i, limit, pred); + sort_rec(L, t, i + 1, u, limit, pred); + u = i - 1; } } } diff --git a/tests/AssemblyBuilderA64.test.cpp b/tests/AssemblyBuilderA64.test.cpp index 5dafb6b90..082fe7a18 100644 --- a/tests/AssemblyBuilderA64.test.cpp +++ b/tests/AssemblyBuilderA64.test.cpp @@ -354,6 +354,9 @@ TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "FPMath") SINGLE_COMPARE(frintm(d1, d2), 0x1E654041); SINGLE_COMPARE(frintp(d1, d2), 0x1E64C041); + SINGLE_COMPARE(fcvt(s1, d2), 0x1E624041); + SINGLE_COMPARE(fcvt(d1, s2), 0x1E22C041); + SINGLE_COMPARE(fcvtzs(w1, d2), 0x1E780041); SINGLE_COMPARE(fcvtzs(x1, d2), 0x9E780041); SINGLE_COMPARE(fcvtzu(w1, d2), 0x1E790041); @@ -384,16 +387,20 @@ TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "FPLoadStore") SINGLE_COMPARE(str(d0, mem(x1, -7)), 0xFC1F9020); // load/store sizes + SINGLE_COMPARE(ldr(s0, x1), 0xBD400020); SINGLE_COMPARE(ldr(d0, x1), 0xFD400020); SINGLE_COMPARE(ldr(q0, x1), 0x3DC00020); + SINGLE_COMPARE(str(s0, x1), 0xBD000020); SINGLE_COMPARE(str(d0, x1), 0xFD000020); SINGLE_COMPARE(str(q0, x1), 0x3D800020); // load/store sizes x offset scaling SINGLE_COMPARE(ldr(q0, mem(x1, 16)), 0x3DC00420); SINGLE_COMPARE(ldr(d0, mem(x1, 16)), 0xFD400820); + SINGLE_COMPARE(ldr(s0, mem(x1, 16)), 0xBD401020); SINGLE_COMPARE(str(q0, mem(x1, 16)), 0x3D800420); SINGLE_COMPARE(str(d0, mem(x1, 16)), 0xFD000820); + SINGLE_COMPARE(str(s0, mem(x1, 16)), 0xBD001020); } TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "FPCompare") @@ -471,6 +478,8 @@ TEST_CASE("LogTest") build.fmov(d0, 0.25); build.tbz(x0, 5, l); + build.fcvt(s1, d2); + build.setLabel(l); build.ret(); @@ -502,6 +511,7 @@ TEST_CASE("LogTest") fcmp d0,#0 fmov d0,#0.25 tbz x0,#5,.L1 + fcvt s1,d2 .L1: ret )"; diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index 2464d324b..fc802e12c 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -408,8 +408,6 @@ static int cxxthrow(lua_State* L) TEST_CASE("PCall") { - ScopedFastFlag sff("LuauBetterOOMHandling", true); - runConformance( "pcall.lua", [](lua_State* L) { @@ -504,7 +502,7 @@ static void populateRTTI(lua_State* L, Luau::TypeId type) for (const auto& [name, prop] : t->props) { - populateRTTI(L, prop.type); + populateRTTI(L, prop.type()); lua_setfield(L, -2, name.c_str()); } } @@ -1012,8 +1010,6 @@ TEST_CASE("ApiCalls") lua_pop(L, 1); } - ScopedFastFlag sff("LuauBetterOOMHandling", true); - // lua_pcall on OOM { lua_getfield(L, LUA_GLOBALSINDEX, "largealloc"); diff --git a/tests/ConstraintGraphBuilderFixture.cpp b/tests/ConstraintGraphBuilderFixture.cpp index 0823eabed..7b9339889 100644 --- a/tests/ConstraintGraphBuilderFixture.cpp +++ b/tests/ConstraintGraphBuilderFixture.cpp @@ -23,8 +23,8 @@ void ConstraintGraphBuilderFixture::generateConstraints(const std::string& code) { AstStatBlock* root = parse(code); dfg = std::make_unique(DataFlowGraphBuilder::build(root, NotNull{&ice})); - cgb = std::make_unique( - mainModule, &arena, NotNull(&moduleResolver), builtinTypes, NotNull(&ice), frontend.globals.globalScope, &logger, NotNull{dfg.get()}); + cgb = std::make_unique(mainModule, &arena, NotNull(&moduleResolver), builtinTypes, NotNull(&ice), + frontend.globals.globalScope, /*prepareModuleScope*/ nullptr, &logger, NotNull{dfg.get()}); cgb->visit(root); rootScope = cgb->rootScope; constraints = Luau::borrowConstraints(cgb->constraints); diff --git a/tests/Fixture.cpp b/tests/Fixture.cpp index aba2891e2..c6fc475b2 100644 --- a/tests/Fixture.cpp +++ b/tests/Fixture.cpp @@ -21,7 +21,6 @@ static const char* mainModuleName = "MainModule"; LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); -LUAU_FASTFLAG(LuauOnDemandTypecheckers); extern std::optional randomSeed; // tests/main.cpp @@ -177,13 +176,7 @@ AstStatBlock* Fixture::parse(const std::string& source, const ParseOptions& pars if (FFlag::DebugLuauDeferredConstraintResolution) { ModulePtr module = Luau::check(*sourceModule, {}, builtinTypes, NotNull{&ice}, NotNull{&moduleResolver}, NotNull{&fileResolver}, - frontend.globals.globalScope, frontend.options); - - Luau::lint(sourceModule->root, *sourceModule->names, frontend.globals.globalScope, module.get(), sourceModule->hotcomments, {}); - } - else if (!FFlag::LuauOnDemandTypecheckers) - { - ModulePtr module = frontend.typeChecker_DEPRECATED.check(*sourceModule, sourceModule->mode.value_or(Luau::Mode::Nonstrict)); + frontend.globals.globalScope, /*prepareModuleScope*/ nullptr, frontend.options); Luau::lint(sourceModule->root, *sourceModule->names, frontend.globals.globalScope, module.get(), sourceModule->hotcomments, {}); } diff --git a/tests/Frontend.test.cpp b/tests/Frontend.test.cpp index f4b9f627e..0b9c872c2 100644 --- a/tests/Frontend.test.cpp +++ b/tests/Frontend.test.cpp @@ -1129,4 +1129,21 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "reexport_type_alias") LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(BuiltinsFixture, "module_scope_check") +{ + frontend.prepareModuleScope = [this](const ModuleName& name, const ScopePtr& scope, bool forAutocomplete) { + scope->bindings[Luau::AstName{"x"}] = Luau::Binding{frontend.globals.builtinTypes->numberType}; + }; + + fileResolver.source["game/A"] = R"( + local a = x + )"; + + CheckResult result = frontend.check("game/A"); + LUAU_REQUIRE_NO_ERRORS(result); + + auto ty = requireType("game/A", "a"); + CHECK_EQ(toString(ty), "number"); +} + TEST_SUITE_END(); diff --git a/tests/IrBuilder.test.cpp b/tests/IrBuilder.test.cpp index 1419c9512..f09f174a1 100644 --- a/tests/IrBuilder.test.cpp +++ b/tests/IrBuilder.test.cpp @@ -294,28 +294,75 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "Numeric") build.beginBlock(block); build.inst(IrCmd::STORE_INT, build.vmReg(0), build.inst(IrCmd::ADD_INT, build.constInt(10), build.constInt(20))); - build.inst(IrCmd::STORE_INT, build.vmReg(0), build.inst(IrCmd::ADD_INT, build.constInt(INT_MAX), build.constInt(1))); + build.inst(IrCmd::STORE_INT, build.vmReg(1), build.inst(IrCmd::ADD_INT, build.constInt(INT_MAX), build.constInt(1))); + + build.inst(IrCmd::STORE_INT, build.vmReg(2), build.inst(IrCmd::SUB_INT, build.constInt(10), build.constInt(20))); + build.inst(IrCmd::STORE_INT, build.vmReg(3), build.inst(IrCmd::SUB_INT, build.constInt(INT_MIN), build.constInt(1))); + + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(4), build.inst(IrCmd::ADD_NUM, build.constDouble(2), build.constDouble(5))); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(5), build.inst(IrCmd::SUB_NUM, build.constDouble(2), build.constDouble(5))); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(6), build.inst(IrCmd::MUL_NUM, build.constDouble(2), build.constDouble(5))); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(7), build.inst(IrCmd::DIV_NUM, build.constDouble(2), build.constDouble(5))); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(8), build.inst(IrCmd::MOD_NUM, build.constDouble(5), build.constDouble(2))); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(10), build.inst(IrCmd::MIN_NUM, build.constDouble(5), build.constDouble(2))); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(11), build.inst(IrCmd::MAX_NUM, build.constDouble(5), build.constDouble(2))); + + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(12), build.inst(IrCmd::UNM_NUM, build.constDouble(5))); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(13), build.inst(IrCmd::FLOOR_NUM, build.constDouble(2.5))); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(14), build.inst(IrCmd::CEIL_NUM, build.constDouble(2.5))); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(15), build.inst(IrCmd::ROUND_NUM, build.constDouble(2.5))); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(16), build.inst(IrCmd::SQRT_NUM, build.constDouble(16))); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(17), build.inst(IrCmd::ABS_NUM, build.constDouble(-4))); + + build.inst(IrCmd::STORE_INT, build.vmReg(18), build.inst(IrCmd::NOT_ANY, build.constTag(tnil), build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(1)))); + build.inst( + IrCmd::STORE_INT, build.vmReg(19), build.inst(IrCmd::NOT_ANY, build.constTag(tnumber), build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(1)))); + build.inst(IrCmd::STORE_INT, build.vmReg(20), build.inst(IrCmd::NOT_ANY, build.constTag(tboolean), build.constInt(0))); + build.inst(IrCmd::STORE_INT, build.vmReg(21), build.inst(IrCmd::NOT_ANY, build.constTag(tboolean), build.constInt(1))); - build.inst(IrCmd::STORE_INT, build.vmReg(0), build.inst(IrCmd::SUB_INT, build.constInt(10), build.constInt(20))); - build.inst(IrCmd::STORE_INT, build.vmReg(0), build.inst(IrCmd::SUB_INT, build.constInt(INT_MIN), build.constInt(1))); + build.inst(IrCmd::RETURN, build.constUint(0)); + + updateUseCounts(build.function); + constantFold(); + + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( +bb_0: + STORE_INT R0, 30i + STORE_INT R1, -2147483648i + STORE_INT R2, -10i + STORE_INT R3, 2147483647i + STORE_DOUBLE R4, 7 + STORE_DOUBLE R5, -3 + STORE_DOUBLE R6, 10 + STORE_DOUBLE R7, 0.40000000000000002 + STORE_DOUBLE R8, 1 + STORE_DOUBLE R10, 2 + STORE_DOUBLE R11, 5 + STORE_DOUBLE R12, -5 + STORE_DOUBLE R13, 2 + STORE_DOUBLE R14, 3 + STORE_DOUBLE R15, 3 + STORE_DOUBLE R16, 4 + STORE_DOUBLE R17, 4 + STORE_INT R18, 1i + STORE_INT R19, 0i + STORE_INT R20, 1i + STORE_INT R21, 0i + RETURN 0u - build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.inst(IrCmd::ADD_NUM, build.constDouble(2), build.constDouble(5))); - build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.inst(IrCmd::SUB_NUM, build.constDouble(2), build.constDouble(5))); - build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.inst(IrCmd::MUL_NUM, build.constDouble(2), build.constDouble(5))); - build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.inst(IrCmd::DIV_NUM, build.constDouble(2), build.constDouble(5))); - build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.inst(IrCmd::MOD_NUM, build.constDouble(5), build.constDouble(2))); - build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.inst(IrCmd::POW_NUM, build.constDouble(5), build.constDouble(2))); - build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.inst(IrCmd::MIN_NUM, build.constDouble(5), build.constDouble(2))); - build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.inst(IrCmd::MAX_NUM, build.constDouble(5), build.constDouble(2))); +)"); +} - build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.inst(IrCmd::UNM_NUM, build.constDouble(5))); +TEST_CASE_FIXTURE(IrBuilderFixture, "NumericConversions") +{ + IrOp block = build.block(IrBlockKind::Internal); - build.inst(IrCmd::STORE_INT, build.vmReg(0), build.inst(IrCmd::NOT_ANY, build.constTag(tnil), build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(1)))); - build.inst(IrCmd::STORE_INT, build.vmReg(0), build.inst(IrCmd::NOT_ANY, build.constTag(tnumber), build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(1)))); - build.inst(IrCmd::STORE_INT, build.vmReg(0), build.inst(IrCmd::NOT_ANY, build.constTag(tboolean), build.constInt(0))); - build.inst(IrCmd::STORE_INT, build.vmReg(0), build.inst(IrCmd::NOT_ANY, build.constTag(tboolean), build.constInt(1))); + build.beginBlock(block); build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.inst(IrCmd::INT_TO_NUM, build.constInt(8))); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(1), build.inst(IrCmd::UINT_TO_NUM, build.constInt(0xdeee0000u))); + build.inst(IrCmd::STORE_INT, build.vmReg(2), build.inst(IrCmd::NUM_TO_INT, build.constDouble(200.0))); + build.inst(IrCmd::STORE_INT, build.vmReg(3), build.inst(IrCmd::NUM_TO_UINT, build.constDouble(3740139520.0))); build.inst(IrCmd::RETURN, build.constUint(0)); @@ -324,24 +371,190 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "Numeric") CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: - STORE_INT R0, 30i - STORE_INT R0, -2147483648i - STORE_INT R0, -10i - STORE_INT R0, 2147483647i - STORE_DOUBLE R0, 7 - STORE_DOUBLE R0, -3 - STORE_DOUBLE R0, 10 - STORE_DOUBLE R0, 0.40000000000000002 - STORE_DOUBLE R0, 1 - STORE_DOUBLE R0, 25 - STORE_DOUBLE R0, 2 - STORE_DOUBLE R0, 5 - STORE_DOUBLE R0, -5 - STORE_INT R0, 1i - STORE_INT R0, 0i - STORE_INT R0, 1i - STORE_INT R0, 0i STORE_DOUBLE R0, 8 + STORE_DOUBLE R1, 3740139520 + STORE_INT R2, 200i + STORE_INT R3, -554827776i + RETURN 0u + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "NumericConversionsBlocked") +{ + IrOp block = build.block(IrBlockKind::Internal); + + build.beginBlock(block); + + IrOp nan = build.inst(IrCmd::DIV_NUM, build.constDouble(0.0), build.constDouble(0.0)); + build.inst(IrCmd::STORE_INT, build.vmReg(0), build.inst(IrCmd::NUM_TO_INT, build.constDouble(1e20))); + build.inst(IrCmd::STORE_INT, build.vmReg(1), build.inst(IrCmd::NUM_TO_UINT, build.constDouble(-10))); + build.inst(IrCmd::STORE_INT, build.vmReg(2), build.inst(IrCmd::NUM_TO_INT, nan)); + build.inst(IrCmd::STORE_INT, build.vmReg(3), build.inst(IrCmd::NUM_TO_UINT, nan)); + + build.inst(IrCmd::RETURN, build.constUint(0)); + + updateUseCounts(build.function); + constantFold(); + + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( +bb_0: + %1 = NUM_TO_INT 1e+20 + STORE_INT R0, %1 + %3 = NUM_TO_UINT -10 + STORE_INT R1, %3 + %5 = NUM_TO_INT nan + STORE_INT R2, %5 + %7 = NUM_TO_UINT nan + STORE_INT R3, %7 + RETURN 0u + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "Bit32") +{ + IrOp block = build.block(IrBlockKind::Internal); + + build.beginBlock(block); + + IrOp unk = build.inst(IrCmd::LOAD_INT, build.vmReg(0)); + build.inst(IrCmd::STORE_INT, build.vmReg(0), build.inst(IrCmd::BITAND_UINT, build.constInt(0xfe), build.constInt(0xe))); + build.inst(IrCmd::STORE_INT, build.vmReg(1), build.inst(IrCmd::BITAND_UINT, unk, build.constInt(0))); + build.inst(IrCmd::STORE_INT, build.vmReg(2), build.inst(IrCmd::BITAND_UINT, build.constInt(0), unk)); + build.inst(IrCmd::STORE_INT, build.vmReg(3), build.inst(IrCmd::BITAND_UINT, unk, build.constInt(~0u))); + build.inst(IrCmd::STORE_INT, build.vmReg(4), build.inst(IrCmd::BITAND_UINT, build.constInt(~0u), unk)); + build.inst(IrCmd::STORE_INT, build.vmReg(5), build.inst(IrCmd::BITXOR_UINT, build.constInt(0xfe), build.constInt(0xe))); + build.inst(IrCmd::STORE_INT, build.vmReg(6), build.inst(IrCmd::BITXOR_UINT, unk, build.constInt(0))); + build.inst(IrCmd::STORE_INT, build.vmReg(7), build.inst(IrCmd::BITXOR_UINT, build.constInt(0), unk)); + build.inst(IrCmd::STORE_INT, build.vmReg(8), build.inst(IrCmd::BITXOR_UINT, unk, build.constInt(~0u))); + build.inst(IrCmd::STORE_INT, build.vmReg(9), build.inst(IrCmd::BITXOR_UINT, build.constInt(~0u), unk)); + build.inst(IrCmd::STORE_INT, build.vmReg(10), build.inst(IrCmd::BITOR_UINT, build.constInt(0xf0), build.constInt(0xe))); + build.inst(IrCmd::STORE_INT, build.vmReg(11), build.inst(IrCmd::BITOR_UINT, unk, build.constInt(0))); + build.inst(IrCmd::STORE_INT, build.vmReg(12), build.inst(IrCmd::BITOR_UINT, build.constInt(0), unk)); + build.inst(IrCmd::STORE_INT, build.vmReg(13), build.inst(IrCmd::BITOR_UINT, unk, build.constInt(~0u))); + build.inst(IrCmd::STORE_INT, build.vmReg(14), build.inst(IrCmd::BITOR_UINT, build.constInt(~0u), unk)); + build.inst(IrCmd::STORE_INT, build.vmReg(15), build.inst(IrCmd::BITNOT_UINT, build.constInt(0xe))); + build.inst(IrCmd::STORE_INT, build.vmReg(16), build.inst(IrCmd::BITLSHIFT_UINT, build.constInt(0xf0), build.constInt(4))); + build.inst(IrCmd::STORE_INT, build.vmReg(17), build.inst(IrCmd::BITLSHIFT_UINT, unk, build.constInt(0))); + build.inst(IrCmd::STORE_INT, build.vmReg(18), build.inst(IrCmd::BITRSHIFT_UINT, build.constInt(0xdeee0000u), build.constInt(8))); + build.inst(IrCmd::STORE_INT, build.vmReg(19), build.inst(IrCmd::BITRSHIFT_UINT, unk, build.constInt(0))); + build.inst(IrCmd::STORE_INT, build.vmReg(20), build.inst(IrCmd::BITARSHIFT_UINT, build.constInt(0xdeee0000u), build.constInt(8))); + build.inst(IrCmd::STORE_INT, build.vmReg(21), build.inst(IrCmd::BITARSHIFT_UINT, unk, build.constInt(0))); + build.inst(IrCmd::STORE_INT, build.vmReg(22), build.inst(IrCmd::BITLROTATE_UINT, build.constInt(0xdeee0000u), build.constInt(8))); + build.inst(IrCmd::STORE_INT, build.vmReg(23), build.inst(IrCmd::BITLROTATE_UINT, unk, build.constInt(0))); + build.inst(IrCmd::STORE_INT, build.vmReg(24), build.inst(IrCmd::BITRROTATE_UINT, build.constInt(0xdeee0000u), build.constInt(8))); + build.inst(IrCmd::STORE_INT, build.vmReg(25), build.inst(IrCmd::BITRROTATE_UINT, unk, build.constInt(0))); + build.inst(IrCmd::STORE_INT, build.vmReg(26), build.inst(IrCmd::BITCOUNTLZ_UINT, build.constInt(0xff00))); + build.inst(IrCmd::STORE_INT, build.vmReg(27), build.inst(IrCmd::BITCOUNTLZ_UINT, build.constInt(0))); + build.inst(IrCmd::STORE_INT, build.vmReg(28), build.inst(IrCmd::BITCOUNTRZ_UINT, build.constInt(0xff00))); + build.inst(IrCmd::STORE_INT, build.vmReg(29), build.inst(IrCmd::BITCOUNTRZ_UINT, build.constInt(0))); + + build.inst(IrCmd::RETURN, build.constUint(0)); + + updateUseCounts(build.function); + constantFold(); + + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( +bb_0: + %0 = LOAD_INT R0 + STORE_INT R0, 14i + STORE_INT R1, 0i + STORE_INT R2, 0i + STORE_INT R3, %0 + STORE_INT R4, %0 + STORE_INT R5, 240i + STORE_INT R6, %0 + STORE_INT R7, %0 + %17 = BITNOT_UINT %0 + STORE_INT R8, %17 + %19 = BITNOT_UINT %0 + STORE_INT R9, %19 + STORE_INT R10, 254i + STORE_INT R11, %0 + STORE_INT R12, %0 + STORE_INT R13, -1i + STORE_INT R14, -1i + STORE_INT R15, -15i + STORE_INT R16, 3840i + STORE_INT R17, %0 + STORE_INT R18, 14609920i + STORE_INT R19, %0 + STORE_INT R20, -2167296i + STORE_INT R21, %0 + STORE_INT R22, -301989666i + STORE_INT R23, %0 + STORE_INT R24, 14609920i + STORE_INT R25, %0 + STORE_INT R26, 16i + STORE_INT R27, 32i + STORE_INT R28, 8i + STORE_INT R29, 32i + RETURN 0u + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "Bit32Blocked") +{ + IrOp block = build.block(IrBlockKind::Internal); + + build.beginBlock(block); + + build.inst(IrCmd::STORE_INT, build.vmReg(10), build.inst(IrCmd::BITLSHIFT_UINT, build.constInt(0xf), build.constInt(-10))); + build.inst(IrCmd::STORE_INT, build.vmReg(10), build.inst(IrCmd::BITLSHIFT_UINT, build.constInt(0xf), build.constInt(140))); + build.inst(IrCmd::STORE_INT, build.vmReg(10), build.inst(IrCmd::BITRSHIFT_UINT, build.constInt(0xf), build.constInt(-10))); + build.inst(IrCmd::STORE_INT, build.vmReg(10), build.inst(IrCmd::BITRSHIFT_UINT, build.constInt(0xf), build.constInt(140))); + build.inst(IrCmd::STORE_INT, build.vmReg(10), build.inst(IrCmd::BITARSHIFT_UINT, build.constInt(0xf), build.constInt(-10))); + build.inst(IrCmd::STORE_INT, build.vmReg(10), build.inst(IrCmd::BITARSHIFT_UINT, build.constInt(0xf), build.constInt(140))); + + build.inst(IrCmd::RETURN, build.constUint(0)); + + updateUseCounts(build.function); + constantFold(); + + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( +bb_0: + %0 = BITLSHIFT_UINT 15i, -10i + STORE_INT R10, %0 + %2 = BITLSHIFT_UINT 15i, 140i + STORE_INT R10, %2 + %4 = BITRSHIFT_UINT 15i, -10i + STORE_INT R10, %4 + %6 = BITRSHIFT_UINT 15i, 140i + STORE_INT R10, %6 + %8 = BITARSHIFT_UINT 15i, -10i + STORE_INT R10, %8 + %10 = BITARSHIFT_UINT 15i, 140i + STORE_INT R10, %10 + RETURN 0u + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "NumericNan") +{ + IrOp block = build.block(IrBlockKind::Internal); + + build.beginBlock(block); + + IrOp nan = build.inst(IrCmd::DIV_NUM, build.constDouble(0.0), build.constDouble(0.0)); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.inst(IrCmd::MIN_NUM, nan, build.constDouble(2))); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.inst(IrCmd::MIN_NUM, build.constDouble(1), nan)); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.inst(IrCmd::MAX_NUM, nan, build.constDouble(2))); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.inst(IrCmd::MAX_NUM, build.constDouble(1), nan)); + + build.inst(IrCmd::RETURN, build.constUint(0)); + + updateUseCounts(build.function); + constantFold(); + + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( +bb_0: + STORE_DOUBLE R0, 2 + STORE_DOUBLE R0, nan + STORE_DOUBLE R0, 2 + STORE_DOUBLE R0, nan RETURN 0u )"); @@ -571,7 +784,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "RememberTagsAndValues") build.inst(IrCmd::RETURN, build.constUint(0)); updateUseCounts(build.function); - constPropInBlockChains(build); + constPropInBlockChains(build, true); CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: @@ -589,10 +802,8 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "RememberTagsAndValues") STORE_DOUBLE R2, %16 %18 = LOAD_TAG R0 STORE_TAG R9, %18 - %20 = LOAD_INT R1 - STORE_INT R10, %20 - %22 = LOAD_DOUBLE R2 - STORE_DOUBLE R11, %22 + STORE_INT R10, %14 + STORE_DOUBLE R11, %16 RETURN 0u )"); @@ -617,7 +828,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "PropagateThroughTvalue") build.inst(IrCmd::RETURN, build.constUint(0)); updateUseCounts(build.function); - constPropInBlockChains(build); + constPropInBlockChains(build, true); CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: @@ -647,7 +858,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "SkipCheckTag") build.inst(IrCmd::RETURN, build.constUint(1)); updateUseCounts(build.function); - constPropInBlockChains(build); + constPropInBlockChains(build, true); CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: @@ -674,7 +885,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "SkipOncePerBlockChecks") build.inst(IrCmd::RETURN, build.constUint(0)); updateUseCounts(build.function); - constPropInBlockChains(build); + constPropInBlockChains(build, true); CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: @@ -713,7 +924,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "RememberTableState") build.inst(IrCmd::RETURN, build.constUint(1)); updateUseCounts(build.function); - constPropInBlockChains(build); + constPropInBlockChains(build, true); CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: @@ -745,7 +956,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "SkipUselessBarriers") build.inst(IrCmd::RETURN, build.constUint(0)); updateUseCounts(build.function); - constPropInBlockChains(build); + constPropInBlockChains(build, true); CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: @@ -776,7 +987,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "ConcatInvalidation") build.inst(IrCmd::RETURN, build.constUint(0)); updateUseCounts(build.function); - constPropInBlockChains(build); + constPropInBlockChains(build, true); CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: @@ -825,7 +1036,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "BuiltinFastcallsMayInvalidateMemory") build.inst(IrCmd::RETURN, build.constUint(1)); updateUseCounts(build.function); - constPropInBlockChains(build); + constPropInBlockChains(build, true); CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: @@ -858,7 +1069,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "RedundantStoreCheckConstantType") build.inst(IrCmd::RETURN, build.constUint(0)); updateUseCounts(build.function); - constPropInBlockChains(build); + constPropInBlockChains(build, true); CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: @@ -888,7 +1099,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "TagCheckPropagation") build.inst(IrCmd::RETURN, build.constUint(1)); updateUseCounts(build.function); - constPropInBlockChains(build); + constPropInBlockChains(build, true); CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: @@ -920,7 +1131,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "TagCheckPropagationConflicting") build.inst(IrCmd::RETURN, build.constUint(1)); updateUseCounts(build.function); - constPropInBlockChains(build); + constPropInBlockChains(build, true); CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: @@ -956,7 +1167,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "TruthyTestRemoval") build.inst(IrCmd::RETURN, build.constUint(3)); updateUseCounts(build.function); - constPropInBlockChains(build); + constPropInBlockChains(build, true); CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: @@ -995,7 +1206,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "FalsyTestRemoval") build.inst(IrCmd::RETURN, build.constUint(3)); updateUseCounts(build.function); - constPropInBlockChains(build); + constPropInBlockChains(build, true); CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: @@ -1030,7 +1241,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "TagEqRemoval") build.inst(IrCmd::RETURN, build.constUint(2)); updateUseCounts(build.function); - constPropInBlockChains(build); + constPropInBlockChains(build, true); CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: @@ -1062,7 +1273,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "IntEqRemoval") build.inst(IrCmd::RETURN, build.constUint(2)); updateUseCounts(build.function); - constPropInBlockChains(build); + constPropInBlockChains(build, true); CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: @@ -1093,7 +1304,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "NumCmpRemoval") build.inst(IrCmd::RETURN, build.constUint(2)); updateUseCounts(build.function); - constPropInBlockChains(build); + constPropInBlockChains(build, true); CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: @@ -1121,7 +1332,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "DataFlowsThroughDirectJumpToUniqueSuccessor build.inst(IrCmd::RETURN, build.constUint(1)); updateUseCounts(build.function); - constPropInBlockChains(build); + constPropInBlockChains(build, true); CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: @@ -1154,7 +1365,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "DataDoesNotFlowThroughDirectJumpToNonUnique build.inst(IrCmd::JUMP, block2); updateUseCounts(build.function); - constPropInBlockChains(build); + constPropInBlockChains(build, true); CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: @@ -1190,7 +1401,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "EntryBlockUseRemoval") build.inst(IrCmd::JUMP, entry); updateUseCounts(build.function); - constPropInBlockChains(build); + constPropInBlockChains(build, true); CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: @@ -1225,7 +1436,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "RecursiveSccUseRemoval1") build.inst(IrCmd::JUMP, block); updateUseCounts(build.function); - constPropInBlockChains(build); + constPropInBlockChains(build, true); CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: @@ -1267,7 +1478,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "RecursiveSccUseRemoval2") build.inst(IrCmd::JUMP, block); updateUseCounts(build.function); - constPropInBlockChains(build); + constPropInBlockChains(build, true); CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: @@ -1325,8 +1536,8 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "SimplePathExtraction") build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(0)); updateUseCounts(build.function); - constPropInBlockChains(build); - createLinearBlocks(build); + constPropInBlockChains(build, true); + createLinearBlocks(build, true); CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: @@ -1401,8 +1612,8 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "NoPathExtractionForBlocksWithLiveOutValues" build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(0)); updateUseCounts(build.function); - constPropInBlockChains(build); - createLinearBlocks(build); + constPropInBlockChains(build, true); + createLinearBlocks(build, true); CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: @@ -1453,8 +1664,8 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "InfiniteLoopInPathAnalysis") build.inst(IrCmd::JUMP, block2); updateUseCounts(build.function); - constPropInBlockChains(build); - createLinearBlocks(build); + constPropInBlockChains(build, true); + createLinearBlocks(build, true); CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: @@ -1468,6 +1679,38 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "InfiniteLoopInPathAnalysis") )"); } +TEST_CASE_FIXTURE(IrBuilderFixture, "PartialStoreInvalidation") +{ + IrOp block = build.block(IrBlockKind::Internal); + + build.beginBlock(block); + + build.inst(IrCmd::STORE_TVALUE, build.vmReg(1), build.inst(IrCmd::LOAD_TVALUE, build.vmReg(0))); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.constDouble(0.5)); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(1), build.inst(IrCmd::LOAD_TVALUE, build.vmReg(0))); // Should be reloaded + build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tboolean)); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(1), build.inst(IrCmd::LOAD_TVALUE, build.vmReg(0))); // Should be reloaded + + build.inst(IrCmd::RETURN, build.constUint(0)); + + updateUseCounts(build.function); + constPropInBlockChains(build, true); + + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( +bb_0: + %0 = LOAD_TVALUE R0 + STORE_TVALUE R1, %0 + STORE_DOUBLE R0, 0.5 + %3 = LOAD_TVALUE R0 + STORE_TVALUE R1, %3 + STORE_TAG R0, tboolean + %6 = LOAD_TVALUE R0 + STORE_TVALUE R1, %6 + RETURN 0u + +)"); +} + TEST_SUITE_END(); TEST_SUITE_BEGIN("Analysis"); @@ -1777,7 +2020,6 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "BuiltinVariadicStart") )"); } - TEST_CASE_FIXTURE(IrBuilderFixture, "SetTable") { IrOp entry = build.block(IrBlockKind::Internal); @@ -1799,3 +2041,192 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "SetTable") } TEST_SUITE_END(); + +TEST_SUITE_BEGIN("ValueNumbering"); + +TEST_CASE_FIXTURE(IrBuilderFixture, "RemoveDuplicateCalculation") +{ + IrOp entry = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + IrOp op1 = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(0)); + IrOp op2 = build.inst(IrCmd::UNM_NUM, op1); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(1), op2); + IrOp op3 = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(0)); // Load propagation is tested here + IrOp op4 = build.inst(IrCmd::UNM_NUM, op3); // And allows value numbering to trigger here + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(2), op4); + build.inst(IrCmd::RETURN, build.vmReg(1), build.constInt(2)); + + updateUseCounts(build.function); + constPropInBlockChains(build, true); + + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( +bb_0: + %0 = LOAD_DOUBLE R0 + %1 = UNM_NUM %0 + STORE_DOUBLE R1, %1 + STORE_DOUBLE R2, %1 + RETURN R1, 2i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "LateTableStateLink") +{ + IrOp block = build.block(IrBlockKind::Internal); + IrOp fallback = build.block(IrBlockKind::Fallback); + + build.beginBlock(block); + + IrOp tmp = build.inst(IrCmd::DUP_TABLE, build.vmReg(0)); + build.inst(IrCmd::STORE_POINTER, build.vmReg(0), tmp); // Late tmp -> R0 link is tested here + IrOp table = build.inst(IrCmd::LOAD_POINTER, build.vmReg(0)); // Store to load propagation test + + build.inst(IrCmd::CHECK_NO_METATABLE, table, fallback); + build.inst(IrCmd::CHECK_READONLY, table, fallback); + + build.inst(IrCmd::CHECK_NO_METATABLE, table, fallback); + build.inst(IrCmd::CHECK_READONLY, table, fallback); + + build.inst(IrCmd::RETURN, build.constUint(0)); + + build.beginBlock(fallback); + build.inst(IrCmd::RETURN, build.constUint(1)); + + updateUseCounts(build.function); + constPropInBlockChains(build, true); + + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( +bb_0: + %0 = DUP_TABLE R0 + STORE_POINTER R0, %0 + CHECK_NO_METATABLE %0, bb_fallback_1 + CHECK_READONLY %0, bb_fallback_1 + RETURN 0u + +bb_fallback_1: + RETURN 1u + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "RegisterVersioning") +{ + IrOp entry = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + IrOp op1 = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(0)); + IrOp op2 = build.inst(IrCmd::UNM_NUM, op1); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), op2); + build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tnumber)); // Doesn't prevent previous store propagation + IrOp op3 = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(0)); // No longer 'op1' + IrOp op4 = build.inst(IrCmd::UNM_NUM, op3); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(1), op4); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(2)); + + updateUseCounts(build.function); + constPropInBlockChains(build, true); + + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( +bb_0: + %0 = LOAD_DOUBLE R0 + %1 = UNM_NUM %0 + STORE_DOUBLE R0, %1 + STORE_TAG R0, tnumber + %5 = UNM_NUM %1 + STORE_DOUBLE R1, %5 + RETURN R0, 2i + +)"); +} + +// This can be relaxed in the future when SETLIST becomes aware of register allocator +TEST_CASE_FIXTURE(IrBuilderFixture, "SetListIsABlocker") +{ + IrOp entry = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + IrOp op1 = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(0)); + build.inst(IrCmd::SETLIST); + IrOp op2 = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(0)); + IrOp sum = build.inst(IrCmd::ADD_NUM, op1, op2); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), sum); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(1)); + + updateUseCounts(build.function); + constPropInBlockChains(build, true); + + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( +bb_0: + %0 = LOAD_DOUBLE R0 + SETLIST + %2 = LOAD_DOUBLE R0 + %3 = ADD_NUM %0, %2 + STORE_DOUBLE R0, %3 + RETURN R0, 1i + +)"); +} + +// Luau call will reuse the same stack and spills will be lost +// However, in the future we might propagate values that can be rematerialized +TEST_CASE_FIXTURE(IrBuilderFixture, "CallIsABlocker") +{ + IrOp entry = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + IrOp op1 = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(0)); + build.inst(IrCmd::CALL, build.vmReg(1), build.constInt(1), build.vmReg(2), build.constInt(1)); + IrOp op2 = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(0)); + IrOp sum = build.inst(IrCmd::ADD_NUM, op1, op2); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(1), sum); + build.inst(IrCmd::RETURN, build.vmReg(1), build.constInt(2)); + + updateUseCounts(build.function); + constPropInBlockChains(build, true); + + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( +bb_0: + %0 = LOAD_DOUBLE R0 + CALL R1, 1i, R2, 1i + %2 = LOAD_DOUBLE R0 + %3 = ADD_NUM %0, %2 + STORE_DOUBLE R1, %3 + RETURN R1, 2i + +)"); +} + +// While constant propagation correctly versions captured registers, IrValueLocationTracking doesn't (yet) +TEST_CASE_FIXTURE(IrBuilderFixture, "NoPropagationOfCapturedRegs") +{ + IrOp entry = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::CAPTURE, build.vmReg(0), build.constBool(true)); + IrOp op1 = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(0)); + IrOp op2 = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(0)); + IrOp sum = build.inst(IrCmd::ADD_NUM, op1, op2); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(1), sum); + build.inst(IrCmd::RETURN, build.vmReg(1), build.constInt(1)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + constPropInBlockChains(build, true); + + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( +; captured regs: R0 + +bb_0: +; in regs: R0 + CAPTURE R0, true + %1 = LOAD_DOUBLE R0 + %2 = LOAD_DOUBLE R0 + %3 = ADD_NUM %1, %2 + STORE_DOUBLE R1, %3 + RETURN R1, 1i + +)"); +} + +TEST_SUITE_END(); diff --git a/tests/Module.test.cpp b/tests/Module.test.cpp index 00cf5cad1..22530a25e 100644 --- a/tests/Module.test.cpp +++ b/tests/Module.test.cpp @@ -137,7 +137,7 @@ TEST_CASE_FIXTURE(Fixture, "deepClone_cyclic_table") CHECK_EQ(std::optional{"Cyclic"}, ttv->syntheticName); - TypeId methodType = ttv->props["get"].type; + TypeId methodType = ttv->props["get"].type(); REQUIRE(methodType != nullptr); const FunctionType* ftv = get(methodType); @@ -161,7 +161,7 @@ TEST_CASE_FIXTURE(Fixture, "deepClone_cyclic_table_2") TypeId methodTy = src.addType(FunctionType{src.addTypePack({}), src.addTypePack({tableTy})}); - tt->props["get"].type = methodTy; + tt->props["get"].setType(methodTy); TypeArena dest; @@ -170,7 +170,7 @@ TEST_CASE_FIXTURE(Fixture, "deepClone_cyclic_table_2") TableType* ctt = getMutable(cloneTy); REQUIRE(ctt); - TypeId clonedMethodType = ctt->props["get"].type; + TypeId clonedMethodType = ctt->props["get"].type(); REQUIRE(clonedMethodType); const FunctionType* cmf = get(clonedMethodType); @@ -199,7 +199,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "builtin_types_point_into_globalTypes_arena") TableType* exportsTable = getMutable(*exports); REQUIRE(exportsTable != nullptr); - TypeId signType = exportsTable->props["sign"].type; + TypeId signType = exportsTable->props["sign"].type(); REQUIRE(signType != nullptr); CHECK(!isInArena(signType, module->interfaceTypes)); @@ -340,8 +340,8 @@ TEST_CASE_FIXTURE(Fixture, "clone_recursion_limit") { TableType* ttv = getMutable(nested); - ttv->props["a"].type = src.addType(TableType{}); - nested = ttv->props["a"].type; + ttv->props["a"].setType(src.addType(TableType{})); + nested = ttv->props["a"].type(); } TypeArena dest; @@ -411,7 +411,7 @@ return {} TypeId typeB = modBiter->second.type; TableType* tableB = getMutable(typeB); REQUIRE(tableB); - CHECK(typeA == tableB->props["q"].type); + CHECK(typeA == tableB->props["q"].type()); } TEST_CASE_FIXTURE(BuiltinsFixture, "do_not_clone_types_of_reexported_values") @@ -447,7 +447,7 @@ return exports REQUIRE(typeB); TableType* tableA = getMutable(*typeA); TableType* tableB = getMutable(*typeB); - CHECK(tableA->props["a"].type == tableB->props["b"].type); + CHECK(tableA->props["a"].type() == tableB->props["b"].type()); } TEST_SUITE_END(); diff --git a/tests/NonstrictMode.test.cpp b/tests/NonstrictMode.test.cpp index fddab8002..7130a717f 100644 --- a/tests/NonstrictMode.test.cpp +++ b/tests/NonstrictMode.test.cpp @@ -170,7 +170,7 @@ TEST_CASE_FIXTURE(Fixture, "table_props_are_any") REQUIRE(ttv != nullptr); REQUIRE(ttv->props.count("foo")); - TypeId fooProp = ttv->props["foo"].type; + TypeId fooProp = ttv->props["foo"].type(); REQUIRE(fooProp != nullptr); CHECK_EQ(*fooProp, *builtinTypes->anyType); @@ -192,9 +192,9 @@ TEST_CASE_FIXTURE(Fixture, "inline_table_props_are_also_any") TableType* ttv = getMutable(requireType("T")); REQUIRE_MESSAGE(ttv, "Should be a table: " << toString(requireType("T"))); - CHECK_EQ(*builtinTypes->anyType, *ttv->props["one"].type); - CHECK_EQ(*builtinTypes->anyType, *ttv->props["two"].type); - CHECK_MESSAGE(get(follow(ttv->props["three"].type)), "Should be a function: " << *ttv->props["three"].type); + CHECK_EQ(*builtinTypes->anyType, *ttv->props["one"].type()); + CHECK_EQ(*builtinTypes->anyType, *ttv->props["two"].type()); + CHECK_MESSAGE(get(follow(ttv->props["three"].type())), "Should be a function: " << *ttv->props["three"].type()); } TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_iterator_variables_are_any") diff --git a/tests/ToString.test.cpp b/tests/ToString.test.cpp index fd245395b..160757e2d 100644 --- a/tests/ToString.test.cpp +++ b/tests/ToString.test.cpp @@ -514,14 +514,14 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "toStringDetailed2") REQUIRE(tMeta2); REQUIRE(tMeta2->props.count("__index")); - const MetatableType* tMeta3 = get(tMeta2->props["__index"].type); + const MetatableType* tMeta3 = get(tMeta2->props["__index"].type()); REQUIRE(tMeta3); TableType* tMeta4 = getMutable(tMeta3->metatable); REQUIRE(tMeta4); REQUIRE(tMeta4->props.count("__index")); - TableType* tMeta5 = getMutable(tMeta4->props["__index"].type); + TableType* tMeta5 = getMutable(tMeta4->props["__index"].type()); REQUIRE(tMeta5); REQUIRE(tMeta5->props.count("one") > 0); @@ -529,9 +529,9 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "toStringDetailed2") REQUIRE(tMeta6); REQUIRE(tMeta6->props.count("two") > 0); - ToStringResult oneResult = toStringDetailed(tMeta5->props["one"].type, opts); + ToStringResult oneResult = toStringDetailed(tMeta5->props["one"].type(), opts); - std::string twoResult = toString(tMeta6->props["two"].type, opts); + std::string twoResult = toString(tMeta6->props["two"].type(), opts); CHECK_EQ("(a) -> number", oneResult.name); CHECK_EQ("(b) -> number", twoResult); @@ -786,7 +786,7 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_include_self_param") TypeId parentTy = requireType("foo"); auto ttv = get(follow(parentTy)); - auto ftv = get(follow(ttv->props.at("method").type)); + auto ftv = get(follow(ttv->props.at("method").type())); CHECK_EQ("foo:method(self: a, arg: string): ()", toStringNamedFunction("foo:method", *ftv)); } @@ -809,7 +809,7 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_hide_self_param") TypeId parentTy = requireType("foo"); auto ttv = get(follow(parentTy)); REQUIRE_MESSAGE(ttv, "Expected a table but got " << toString(parentTy, opts)); - TypeId methodTy = follow(ttv->props.at("method").type); + TypeId methodTy = follow(ttv->props.at("method").type()); auto ftv = get(methodTy); REQUIRE_MESSAGE(ftv, "Expected a function but got " << toString(methodTy, opts)); diff --git a/tests/TypeInfer.annotations.test.cpp b/tests/TypeInfer.annotations.test.cpp index 3de529998..84b057d5e 100644 --- a/tests/TypeInfer.annotations.test.cpp +++ b/tests/TypeInfer.annotations.test.cpp @@ -330,7 +330,7 @@ TEST_CASE_FIXTURE(Fixture, "self_referential_type_alias") std::optional incr = get(oTable->props, "incr"); REQUIRE(incr); - const FunctionType* incrFunc = get(incr->type); + const FunctionType* incrFunc = get(incr->type()); REQUIRE(incrFunc); std::optional firstArg = first(incrFunc->argTypes); @@ -493,7 +493,7 @@ TEST_CASE_FIXTURE(Fixture, "interface_types_belong_to_interface_arena") TableType* exportsTable = getMutable(*exportsType); REQUIRE(exportsTable != nullptr); - TypeId n = exportsTable->props["n"].type; + TypeId n = exportsTable->props["n"].type(); REQUIRE(n != nullptr); CHECK(isInArena(n, mod.interfaceTypes)); @@ -548,10 +548,10 @@ TEST_CASE_FIXTURE(Fixture, "cloned_interface_maintains_pointers_between_definiti TableType* exportsTable = getMutable(*exportsType); REQUIRE(exportsTable != nullptr); - TypeId aType = exportsTable->props["a"].type; + TypeId aType = exportsTable->props["a"].type(); REQUIRE(aType); - TypeId bType = exportsTable->props["b"].type; + TypeId bType = exportsTable->props["b"].type(); REQUIRE(bType); CHECK(isInArena(recordType, mod.interfaceTypes)); diff --git a/tests/TypeInfer.anyerror.test.cpp b/tests/TypeInfer.anyerror.test.cpp index c6766cada..687bc766d 100644 --- a/tests/TypeInfer.anyerror.test.cpp +++ b/tests/TypeInfer.anyerror.test.cpp @@ -195,7 +195,7 @@ TEST_CASE_FIXTURE(Fixture, "assign_prop_to_table_by_calling_any_yields_any") REQUIRE(ttv); REQUIRE(ttv->props.count("prop")); - REQUIRE_EQ("any", toString(ttv->props["prop"].type)); + REQUIRE_EQ("any", toString(ttv->props["prop"].type())); } TEST_CASE_FIXTURE(Fixture, "quantify_any_does_not_bind_to_itself") diff --git a/tests/TypeInfer.builtins.test.cpp b/tests/TypeInfer.builtins.test.cpp index 79d9108d3..07cf5393a 100644 --- a/tests/TypeInfer.builtins.test.cpp +++ b/tests/TypeInfer.builtins.test.cpp @@ -1031,7 +1031,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "no_persistent_typelevel_change") REQUIRE(mathTy); TableType* ttv = getMutable(mathTy); REQUIRE(ttv); - const FunctionType* ftv = get(ttv->props["frexp"].type); + const FunctionType* ftv = get(ttv->props["frexp"].type()); REQUIRE(ftv); auto original = ftv->level; diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index 942ce191f..9086a6049 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -111,7 +111,7 @@ TEST_CASE_FIXTURE(Fixture, "generalize_table_property") const TableType* tt = get(follow(t)); REQUIRE(tt); - TypeId fooTy = tt->props.at("foo").type; + TypeId fooTy = tt->props.at("foo").type(); CHECK("(a) -> a" == toString(fooTy)); } @@ -156,7 +156,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "vararg_function_is_quantified") REQUIRE(ttv); REQUIRE(ttv->props.count("f")); - TypeId k = ttv->props["f"].type; + TypeId k = ttv->props["f"].type(); REQUIRE(k); } diff --git a/tests/TypeInfer.generics.test.cpp b/tests/TypeInfer.generics.test.cpp index b97848176..5ab27f645 100644 --- a/tests/TypeInfer.generics.test.cpp +++ b/tests/TypeInfer.generics.test.cpp @@ -865,7 +865,7 @@ TEST_CASE_FIXTURE(Fixture, "generic_table_method") REQUIRE(tTable != nullptr); REQUIRE(tTable->props.count("bar")); - TypeId barType = tTable->props["bar"].type; + TypeId barType = tTable->props["bar"].type(); REQUIRE(barType != nullptr); const FunctionType* ftv = get(follow(barType)); @@ -900,7 +900,7 @@ TEST_CASE_FIXTURE(Fixture, "correctly_instantiate_polymorphic_member_functions") std::optional fooProp = get(t->props, "foo"); REQUIRE(bool(fooProp)); - const FunctionType* foo = get(follow(fooProp->type)); + const FunctionType* foo = get(follow(fooProp->type())); REQUIRE(bool(foo)); std::optional ret_ = first(foo->retTypes); @@ -947,7 +947,7 @@ TEST_CASE_FIXTURE(Fixture, "instantiate_cyclic_generic_function") std::optional methodProp = get(argTable->props, "method"); REQUIRE(bool(methodProp)); - const FunctionType* methodFunction = get(methodProp->type); + const FunctionType* methodFunction = get(methodProp->type()); REQUIRE(methodFunction != nullptr); std::optional methodArg = first(methodFunction->argTypes); diff --git a/tests/TypeInfer.negations.test.cpp b/tests/TypeInfer.negations.test.cpp index adf036532..af73607a9 100644 --- a/tests/TypeInfer.negations.test.cpp +++ b/tests/TypeInfer.negations.test.cpp @@ -2,6 +2,7 @@ #include "Fixture.h" +#include "Luau/ToString.h" #include "doctest.h" #include "Luau/Common.h" #include "ScopedFlags.h" @@ -47,4 +48,32 @@ TEST_CASE_FIXTURE(NegationFixture, "string_is_not_a_subtype_of_negated_string") LUAU_REQUIRE_ERROR_COUNT(1, result); } +TEST_CASE_FIXTURE(Fixture, "cofinite_strings_can_be_compared_for_equality") +{ + CheckResult result = check(R"( + function f(e) + if e == 'strictEqual' then + e = 'strictEqualObject' + end + if e == 'deepStrictEqual' or e == 'strictEqual' then + elseif e == 'notDeepStrictEqual' or e == 'notStrictEqual' then + end + return e + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK("(string) -> string" == toString(requireType("f"))); +} + +TEST_CASE_FIXTURE(NegationFixture, "compare_cofinite_strings") +{ + CheckResult result = check(R"( +local u : Not<"a"> +local v : "b" +if u == v then +end +)"); + LUAU_REQUIRE_NO_ERRORS(result); +} TEST_SUITE_END(); diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index 890e9b693..06cbe0cf3 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -1749,4 +1749,36 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "type_annotations_arent_relevant_when_doing_d CHECK_EQ("nil", toString(requireTypeAtPosition({9, 28}))); } +TEST_CASE_FIXTURE(BuiltinsFixture, "function_call_with_colon_after_refining_not_to_be_nil") +{ + ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true}; + + CheckResult result = check(R"( + --!strict + export type Observer = { + complete: ((self: Observer) -> ())?, + } + + local function _f(handler: Observer) + assert(handler.complete ~= nil) + handler:complete() -- incorrectly gives Value of type '((Observer) -> ())?' could be nil + handler.complete(handler) -- works fine, both forms should avoid the error + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "refinements_should_not_affect_assignment") +{ + CheckResult result = check(R"( + local a: unknown = true + if a == true then + a = 'not even remotely similar to a boolean' + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index 468adc2c6..fcf2c8a4a 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -31,15 +31,15 @@ TEST_CASE_FIXTURE(Fixture, "basic") std::optional fooProp = get(tType->props, "foo"); REQUIRE(bool(fooProp)); - CHECK_EQ(PrimitiveType::String, getPrimitiveType(fooProp->type)); + CHECK_EQ(PrimitiveType::String, getPrimitiveType(fooProp->type())); std::optional bazProp = get(tType->props, "baz"); REQUIRE(bool(bazProp)); - CHECK_EQ(PrimitiveType::Number, getPrimitiveType(bazProp->type)); + CHECK_EQ(PrimitiveType::Number, getPrimitiveType(bazProp->type())); std::optional quuxProp = get(tType->props, "quux"); REQUIRE(bool(quuxProp)); - CHECK_EQ(PrimitiveType::NilType, getPrimitiveType(quuxProp->type)); + CHECK_EQ(PrimitiveType::NilType, getPrimitiveType(quuxProp->type())); } TEST_CASE_FIXTURE(Fixture, "augment_table") @@ -65,7 +65,7 @@ TEST_CASE_FIXTURE(Fixture, "augment_nested_table") REQUIRE(tType != nullptr); REQUIRE(tType->props.find("p") != tType->props.end()); - const TableType* pType = get(tType->props["p"].type); + const TableType* pType = get(tType->props["p"].type()); REQUIRE(pType != nullptr); CHECK("{ p: { foo: string } }" == toString(requireType("t"), {true})); @@ -159,7 +159,7 @@ TEST_CASE_FIXTURE(Fixture, "tc_member_function") std::optional fooProp = get(tableType->props, "foo"); REQUIRE(bool(fooProp)); - const FunctionType* methodType = get(follow(fooProp->type)); + const FunctionType* methodType = get(follow(fooProp->type())); REQUIRE(methodType != nullptr); } @@ -173,7 +173,7 @@ TEST_CASE_FIXTURE(Fixture, "tc_member_function_2") std::optional uProp = get(tableType->props, "U"); REQUIRE(bool(uProp)); - TypeId uType = uProp->type; + TypeId uType = uProp->type(); const TableType* uTable = get(uType); REQUIRE(uTable != nullptr); @@ -181,7 +181,7 @@ TEST_CASE_FIXTURE(Fixture, "tc_member_function_2") std::optional fooProp = get(uTable->props, "foo"); REQUIRE(bool(fooProp)); - const FunctionType* methodType = get(follow(fooProp->type)); + const FunctionType* methodType = get(follow(fooProp->type())); REQUIRE(methodType != nullptr); std::vector methodArgs = flatten(methodType->argTypes).first; @@ -935,7 +935,7 @@ TEST_CASE_FIXTURE(Fixture, "assigning_to_an_unsealed_table_with_string_literal_s REQUIRE(tableType->indexer == std::nullopt); REQUIRE(0 != tableType->props.count("a")); - TypeId propertyA = tableType->props["a"].type; + TypeId propertyA = tableType->props["a"].type(); REQUIRE(propertyA != nullptr); CHECK_EQ(*builtinTypes->stringType, *propertyA); } @@ -1925,8 +1925,8 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "quantifying_a_bound_var_works") REQUIRE(ttv); REQUIRE(ttv->props.count("new")); Property& prop = ttv->props["new"]; - REQUIRE(prop.type); - const FunctionType* ftv = get(follow(prop.type)); + REQUIRE(prop.type()); + const FunctionType* ftv = get(follow(prop.type())); REQUIRE(ftv); const TypePack* res = get(follow(ftv->retTypes)); REQUIRE(res); @@ -2647,7 +2647,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "dont_quantify_table_that_belongs_to_outer_sc REQUIRE(counterType); REQUIRE(counterType->props.count("new")); - const FunctionType* newType = get(follow(counterType->props["new"].type)); + const FunctionType* newType = get(follow(counterType->props["new"].type())); REQUIRE(newType); std::optional newRetType = *first(newType->retTypes); diff --git a/tests/TypeInfer.tryUnify.test.cpp b/tests/TypeInfer.tryUnify.test.cpp index 5a9c77d40..fa52a7466 100644 --- a/tests/TypeInfer.tryUnify.test.cpp +++ b/tests/TypeInfer.tryUnify.test.cpp @@ -101,7 +101,7 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "tables_can_be_unified") TableType{{{"foo", {arena.freshType(globalScope->level)}}}, std::nullopt, globalScope->level, TableState::Unsealed}, }}; - CHECK_NE(*getMutable(&tableOne)->props["foo"].type, *getMutable(&tableTwo)->props["foo"].type); + CHECK_NE(*getMutable(&tableOne)->props["foo"].type(), *getMutable(&tableTwo)->props["foo"].type()); state.tryUnify(&tableTwo, &tableOne); @@ -110,7 +110,7 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "tables_can_be_unified") state.log.commit(); - CHECK_EQ(*getMutable(&tableOne)->props["foo"].type, *getMutable(&tableTwo)->props["foo"].type); + CHECK_EQ(*getMutable(&tableOne)->props["foo"].type(), *getMutable(&tableTwo)->props["foo"].type()); } TEST_CASE_FIXTURE(TryUnifyFixture, "incompatible_tables_are_preserved") @@ -129,14 +129,14 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "incompatible_tables_are_preserved") TableState::Unsealed}, }}; - CHECK_NE(*getMutable(&tableOne)->props["foo"].type, *getMutable(&tableTwo)->props["foo"].type); + CHECK_NE(*getMutable(&tableOne)->props["foo"].type(), *getMutable(&tableTwo)->props["foo"].type()); state.tryUnify(&tableTwo, &tableOne); CHECK(state.failure); CHECK_EQ(1, state.errors.size()); - CHECK_NE(*getMutable(&tableOne)->props["foo"].type, *getMutable(&tableTwo)->props["foo"].type); + CHECK_NE(*getMutable(&tableOne)->props["foo"].type(), *getMutable(&tableTwo)->props["foo"].type()); } TEST_CASE_FIXTURE(TryUnifyFixture, "uninhabited_intersection_sub_never") diff --git a/tools/faillist.txt b/tools/faillist.txt index 18bc0c70a..38fa7f5f8 100644 --- a/tools/faillist.txt +++ b/tools/faillist.txt @@ -12,8 +12,6 @@ BuiltinTests.assert_returns_false_and_string_iff_it_knows_the_first_argument_can BuiltinTests.bad_select_should_not_crash BuiltinTests.dont_add_definitions_to_persistent_types BuiltinTests.gmatch_definition -BuiltinTests.match_capture_types -BuiltinTests.match_capture_types2 BuiltinTests.math_max_checks_for_numbers BuiltinTests.select_slightly_out_of_range BuiltinTests.select_way_out_of_range @@ -33,6 +31,7 @@ GenericsTests.bound_tables_do_not_clone_original_fields GenericsTests.check_mutual_generic_functions GenericsTests.correctly_instantiate_polymorphic_member_functions GenericsTests.do_not_infer_generic_functions +GenericsTests.dont_unify_bound_types GenericsTests.generic_argument_count_too_few GenericsTests.generic_argument_count_too_many GenericsTests.generic_functions_should_be_memory_safe @@ -86,7 +85,6 @@ TableTests.give_up_after_one_metatable_index_look_up TableTests.indexer_on_sealed_table_must_unify_with_free_table TableTests.indexing_from_a_table_should_prefer_properties_when_possible TableTests.inequality_operators_imply_exactly_matching_types -TableTests.infer_array_2 TableTests.inferred_return_type_of_free_table TableTests.instantiate_table_cloning_3 TableTests.leaking_bad_metatable_errors @@ -138,7 +136,6 @@ TypeInfer.fuzz_free_table_type_change_during_index_check TypeInfer.infer_assignment_value_types_mutable_lval TypeInfer.no_stack_overflow_from_isoptional TypeInfer.no_stack_overflow_from_isoptional2 -TypeInfer.should_be_able_to_infer_this_without_stack_overflowing TypeInfer.tc_after_error_recovery_no_replacement_name_in_error TypeInfer.type_infer_recursion_limit_no_ice TypeInfer.type_infer_recursion_limit_normalizer @@ -207,11 +204,9 @@ TypePackTests.unify_variadic_tails_in_arguments TypePackTests.variadic_packs TypeSingletons.function_call_with_singletons TypeSingletons.function_call_with_singletons_mismatch -TypeSingletons.indexing_on_union_of_string_singletons TypeSingletons.no_widening_from_callsites TypeSingletons.return_type_of_f_is_not_widened TypeSingletons.table_properties_type_error_escapes -TypeSingletons.taking_the_length_of_union_of_string_singleton TypeSingletons.widen_the_supertype_if_it_is_free_and_subtype_has_singleton TypeSingletons.widening_happens_almost_everywhere UnionTypes.generic_function_with_optional_arg diff --git a/tools/test_dcr.py b/tools/test_dcr.py index 817d08313..208096fa4 100644 --- a/tools/test_dcr.py +++ b/tools/test_dcr.py @@ -108,10 +108,10 @@ def main(): help="Write a new faillist.txt after running tests.", ) parser.add_argument( - "--lti", - dest="lti", + "--rwp", + dest="rwp", action="store_true", - help="Run the tests with local type inference enabled.", + help="Run the tests with read-write properties enabled.", ) parser.add_argument("--randomize", action="store_true", help="Pick a random seed") @@ -126,17 +126,17 @@ def main(): args = parser.parse_args() - if args.write and args.lti: + if args.write and args.rwp: print_stderr( - "Cannot run test_dcr.py with --write *and* --lti. You don't want to commit local type inference faillist.txt yet." + "Cannot run test_dcr.py with --write *and* --rwp. You don't want to commit local type inference faillist.txt yet." ) sys.exit(1) failList = loadFailList() flags = ["true", "DebugLuauDeferredConstraintResolution"] - if args.lti: - flags.append("DebugLuauLocalTypeInference") + if args.rwp: + flags.append("DebugLuauReadWriteProperties") commandLine = [args.path, "--reporters=xml", "--fflags=" + ",".join(flags)] From 716f63321abdfebc0594fdae639537c647758e1e Mon Sep 17 00:00:00 2001 From: Andy Friesen Date: Fri, 5 May 2023 12:57:12 -0700 Subject: [PATCH 50/66] Sync to upstream/release/575 --- Analysis/include/Luau/Constraint.h | 5 + Analysis/include/Luau/Frontend.h | 27 +- Analysis/include/Luau/Normalize.h | 6 - Analysis/include/Luau/Type.h | 2 +- Analysis/include/Luau/TypePack.h | 2 +- Analysis/include/Luau/Unifier.h | 1 + Analysis/src/ConstraintGraphBuilder.cpp | 13 +- Analysis/src/ConstraintSolver.cpp | 58 +- Analysis/src/EmbeddedBuiltinDefinitions.cpp | 2 +- Analysis/src/Frontend.cpp | 598 ++++++++++++++++++-- Analysis/src/Normalize.cpp | 283 +++------ Analysis/src/TxnLog.cpp | 10 +- Analysis/src/Type.cpp | 62 +- Analysis/src/TypeChecker2.cpp | 342 +++++------ Analysis/src/TypeInfer.cpp | 13 +- Analysis/src/TypePack.cpp | 13 +- Analysis/src/Unifier.cpp | 105 ++-- Ast/include/Luau/ParseOptions.h | 2 - Ast/src/Parser.cpp | 40 +- CLI/Analyze.cpp | 117 +++- CLI/Ast.cpp | 2 - CodeGen/include/Luau/IrAnalysis.h | 14 + CodeGen/include/Luau/IrData.h | 2 +- CodeGen/include/Luau/IrDump.h | 2 +- CodeGen/include/Luau/IrUtils.h | 22 + CodeGen/include/Luau/UnwindBuilder.h | 31 +- CodeGen/include/Luau/UnwindBuilderDwarf2.h | 13 +- CodeGen/include/Luau/UnwindBuilderWin.h | 12 +- CodeGen/src/CodeBlockUnwind.cpp | 30 +- CodeGen/src/CodeGen.cpp | 10 +- CodeGen/src/CodeGenA64.cpp | 9 +- CodeGen/src/CodeGenX64.cpp | 33 +- CodeGen/src/EmitBuiltinsX64.cpp | 16 - CodeGen/src/IrAnalysis.cpp | 2 +- CodeGen/src/IrDump.cpp | 104 +++- CodeGen/src/IrLoweringA64.cpp | 54 +- CodeGen/src/IrLoweringX64.cpp | 10 +- CodeGen/src/IrTranslateBuiltins.cpp | 19 +- CodeGen/src/IrUtils.cpp | 5 + CodeGen/src/OptimizeConstProp.cpp | 71 ++- CodeGen/src/UnwindBuilderDwarf2.cpp | 170 +++--- CodeGen/src/UnwindBuilderWin.cpp | 94 +-- Compiler/src/BytecodeBuilder.cpp | 2 - VM/src/lvmexecute.cpp | 4 +- tests/CodeAllocator.test.cpp | 216 ++++--- tests/IrBuilder.test.cpp | 24 + tests/Normalize.test.cpp | 31 - tests/Parser.test.cpp | 26 - tests/TypeInfer.anyerror.test.cpp | 5 +- tests/TypeInfer.functions.test.cpp | 23 +- tests/TypeInfer.loops.test.cpp | 48 ++ tests/TypeInfer.oop.test.cpp | 28 +- tests/TypeInfer.provisional.test.cpp | 37 ++ tests/TypeInfer.refinements.test.cpp | 3 +- tests/TypeInfer.singletons.test.cpp | 12 +- tests/TypeInfer.tables.test.cpp | 19 + tools/faillist.txt | 8 +- 57 files changed, 1886 insertions(+), 1026 deletions(-) diff --git a/Analysis/include/Luau/Constraint.h b/Analysis/include/Luau/Constraint.h index 2223c29e0..c7bc58b5a 100644 --- a/Analysis/include/Luau/Constraint.h +++ b/Analysis/include/Luau/Constraint.h @@ -107,6 +107,11 @@ struct FunctionCallConstraint TypePackId result; class AstExprCall* callSite; std::vector> discriminantTypes; + + // When we dispatch this constraint, we update the key at this map to record + // the overload that we selected. + DenseHashMap* astOriginalCallTypes; + DenseHashMap* astOverloadResolvedTypes; }; // result ~ prim ExpectedType SomeSingletonType MultitonType diff --git a/Analysis/include/Luau/Frontend.h b/Analysis/include/Luau/Frontend.h index 67e840eec..14bf2e2e5 100644 --- a/Analysis/include/Luau/Frontend.h +++ b/Analysis/include/Luau/Frontend.h @@ -28,6 +28,7 @@ struct FileResolver; struct ModuleResolver; struct ParseResult; struct HotComment; +struct BuildQueueItem; struct LoadDefinitionFileResult { @@ -171,7 +172,18 @@ struct Frontend LoadDefinitionFileResult loadDefinitionFile(GlobalTypes& globals, ScopePtr targetScope, std::string_view source, const std::string& packageName, bool captureComments, bool typeCheckForAutocomplete = false); + // Batch module checking. Queue modules and check them together, retrieve results with 'getCheckResult' + // If provided, 'executeTask' function is allowed to call the 'task' function on any thread and return without waiting for 'task' to complete + void queueModuleCheck(const std::vector& names); + void queueModuleCheck(const ModuleName& name); + std::vector checkQueuedModules(std::optional optionOverride = {}, + std::function task)> executeTask = {}, std::function progress = {}); + + std::optional getCheckResult(const ModuleName& name, bool accumulateNested, bool forAutocomplete = false); + private: + CheckResult check_DEPRECATED(const ModuleName& name, std::optional optionOverride = {}); + struct TypeCheckLimits { std::optional finishTime; @@ -185,7 +197,14 @@ struct Frontend std::pair getSourceNode(const ModuleName& name); SourceModule parse(const ModuleName& name, std::string_view src, const ParseOptions& parseOptions); - bool parseGraph(std::vector& buildQueue, const ModuleName& root, bool forAutocomplete); + bool parseGraph( + std::vector& buildQueue, const ModuleName& root, bool forAutocomplete, std::function canSkip = {}); + + void addBuildQueueItems(std::vector& items, std::vector& buildQueue, bool cycleDetected, + std::unordered_set& seen, const FrontendOptions& frontendOptions); + void checkBuildQueueItem(BuildQueueItem& item); + void checkBuildQueueItems(std::vector& items); + void recordItemResult(const BuildQueueItem& item); static LintResult classifyLints(const std::vector& warnings, const Config& config); @@ -212,11 +231,13 @@ struct Frontend InternalErrorReporter iceHandler; std::function prepareModuleScope; - std::unordered_map sourceNodes; - std::unordered_map sourceModules; + std::unordered_map> sourceNodes; + std::unordered_map> sourceModules; std::unordered_map requireTrace; Stats stats = {}; + + std::vector moduleQueue; }; ModulePtr check(const SourceModule& sourceModule, const std::vector& requireCycles, NotNull builtinTypes, diff --git a/Analysis/include/Luau/Normalize.h b/Analysis/include/Luau/Normalize.h index 6c808286c..2ec5406fd 100644 --- a/Analysis/include/Luau/Normalize.h +++ b/Analysis/include/Luau/Normalize.h @@ -226,10 +226,6 @@ struct NormalizedType NormalizedClassType classes; - // The class part of the type. - // Each element of this set is a class, and none of the classes are subclasses of each other. - TypeIds DEPRECATED_classes; - // The error part of the type. // This type is either never or the error type. TypeId errors; @@ -333,8 +329,6 @@ class Normalizer // ------- Normalizing intersections TypeId intersectionOfTops(TypeId here, TypeId there); TypeId intersectionOfBools(TypeId here, TypeId there); - void DEPRECATED_intersectClasses(TypeIds& heres, const TypeIds& theres); - void DEPRECATED_intersectClassesWithClass(TypeIds& heres, TypeId there); void intersectClasses(NormalizedClassType& heres, const NormalizedClassType& theres); void intersectClassesWithClass(NormalizedClassType& heres, TypeId there); void intersectStrings(NormalizedStringType& here, const NormalizedStringType& there); diff --git a/Analysis/include/Luau/Type.h b/Analysis/include/Luau/Type.h index 5d92cbd0b..c615b8f57 100644 --- a/Analysis/include/Luau/Type.h +++ b/Analysis/include/Luau/Type.h @@ -694,7 +694,7 @@ bool areEqual(SeenSet& seen, const Type& lhs, const Type& rhs); // Follow BoundTypes until we get to something real TypeId follow(TypeId t); -TypeId follow(TypeId t, std::function mapper); +TypeId follow(TypeId t, const void* context, TypeId (*mapper)(const void*, TypeId)); std::vector flattenIntersection(TypeId ty); diff --git a/Analysis/include/Luau/TypePack.h b/Analysis/include/Luau/TypePack.h index 2ae56e5f0..e78a66b84 100644 --- a/Analysis/include/Luau/TypePack.h +++ b/Analysis/include/Luau/TypePack.h @@ -169,7 +169,7 @@ using SeenSet = std::set>; bool areEqual(SeenSet& seen, const TypePackVar& lhs, const TypePackVar& rhs); TypePackId follow(TypePackId tp); -TypePackId follow(TypePackId tp, std::function mapper); +TypePackId follow(TypePackId t, const void* context, TypePackId (*mapper)(const void*, TypePackId)); size_t size(TypePackId tp, TxnLog* log = nullptr); bool finite(TypePackId tp, TxnLog* log = nullptr); diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index e3b0a8782..742f029ca 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -163,5 +163,6 @@ struct Unifier void promoteTypeLevels(TxnLog& log, const TypeArena* arena, TypeLevel minLevel, Scope* outerScope, bool useScope, TypePackId tp); std::optional hasUnificationTooComplex(const ErrorVec& errors); +std::optional hasCountMismatch(const ErrorVec& errors); } // namespace Luau diff --git a/Analysis/src/ConstraintGraphBuilder.cpp b/Analysis/src/ConstraintGraphBuilder.cpp index 611f420a9..e07fe701d 100644 --- a/Analysis/src/ConstraintGraphBuilder.cpp +++ b/Analysis/src/ConstraintGraphBuilder.cpp @@ -18,7 +18,6 @@ LUAU_FASTINT(LuauCheckRecursionLimit); LUAU_FASTFLAG(DebugLuauMagicTypes); -LUAU_FASTFLAG(LuauNegatedClassTypes); namespace Luau { @@ -1016,7 +1015,7 @@ static bool isMetamethod(const Name& name) ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareClass* declaredClass) { - std::optional superTy = FFlag::LuauNegatedClassTypes ? std::make_optional(builtinTypes->classType) : std::nullopt; + std::optional superTy = std::make_optional(builtinTypes->classType); if (declaredClass->superName) { Name superName = Name(declaredClass->superName->value); @@ -1420,6 +1419,8 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCa rets, call, std::move(discriminantTypes), + &module->astOriginalCallTypes, + &module->astOverloadResolvedTypes, }); // We force constraints produced by checking function arguments to wait @@ -1772,7 +1773,7 @@ std::tuple ConstraintGraphBuilder::checkBinary( TypeId ty = follow(typeFun->type); // We're only interested in the root class of any classes. - if (auto ctv = get(ty); !ctv || (FFlag::LuauNegatedClassTypes ? (ctv->parent == builtinTypes->classType) : !ctv->parent)) + if (auto ctv = get(ty); !ctv || ctv->parent == builtinTypes->classType) discriminantTy = ty; } @@ -1786,8 +1787,10 @@ std::tuple ConstraintGraphBuilder::checkBinary( } else if (binary->op == AstExprBinary::CompareEq || binary->op == AstExprBinary::CompareNe) { - TypeId leftType = check(scope, binary->left, ValueContext::RValue, expectedType, true).ty; - TypeId rightType = check(scope, binary->right, ValueContext::RValue, expectedType, true).ty; + // We are checking a binary expression of the form a op b + // Just because a op b is epxected to return a bool, doesn't mean a, b are expected to be bools too + TypeId leftType = check(scope, binary->left, ValueContext::RValue, {}, true).ty; + TypeId rightType = check(scope, binary->right, ValueContext::RValue, {}, true).ty; RefinementId leftRefinement = nullptr; if (auto bc = dfg->getBreadcrumb(binary->left)) diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index ec63b25e6..f1f868add 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -1172,6 +1172,9 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNull(fn)) fn = collapse(it).value_or(fn); + if (c.callSite) + (*c.astOriginalCallTypes)[c.callSite] = fn; + // We don't support magic __call metamethods. if (std::optional callMm = findMetatableEntry(builtinTypes, errors, fn, "__call", constraint->location)) { @@ -1219,10 +1222,22 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNulladdType(FunctionType{TypeLevel{}, constraint->scope.get(), argsPack, c.result}); - std::vector overloads = flattenIntersection(fn); + const NormalizedType* normFn = normalizer->normalize(fn); + if (!normFn) + { + reportError(UnificationTooComplex{}, constraint->location); + return true; + } + + // TODO: It would be nice to not need to convert the normalized type back to + // an intersection and flatten it. + TypeId normFnTy = normalizer->typeFromNormal(*normFn); + std::vector overloads = flattenIntersection(normFnTy); Instantiation inst(TxnLog::empty(), arena, TypeLevel{}, constraint->scope); + std::vector arityMatchingOverloads; + for (TypeId overload : overloads) { overload = follow(overload); @@ -1247,8 +1262,17 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNull(*e)->context != CountMismatch::Context::Arg) && get(*instantiated)) + { + arityMatchingOverloads.push_back(*instantiated); + } + if (u.errors.empty()) { + if (c.callSite) + (*c.astOverloadResolvedTypes)[c.callSite] = *instantiated; + // We found a matching overload. const auto [changedTypes, changedPacks] = u.log.getChanges(); u.log.commit(); @@ -1260,6 +1284,15 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNullscope, Location{}, Covariant}; u.useScopes = true; @@ -1267,8 +1300,6 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNullanyType); u.tryUnify(fn, builtinTypes->anyType); - LUAU_ASSERT(u.errors.empty()); // unifying with any should never fail - const auto [changedTypes, changedPacks] = u.log.getChanges(); u.log.commit(); @@ -2166,13 +2197,24 @@ void ConstraintSolver::unblock(NotNull progressed) void ConstraintSolver::unblock(TypeId progressed) { - if (logger) - logger->popBlock(progressed); + DenseHashSet seen{nullptr}; - unblock_(progressed); + while (true) + { + if (seen.find(progressed)) + iceReporter.ice("ConstraintSolver::unblock encountered a self-bound type!"); + seen.insert(progressed); - if (auto bt = get(progressed)) - unblock(bt->boundTo); + if (logger) + logger->popBlock(progressed); + + unblock_(progressed); + + if (auto bt = get(progressed)) + progressed = bt->boundTo; + else + break; + } } void ConstraintSolver::unblock(TypePackId progressed) diff --git a/Analysis/src/EmbeddedBuiltinDefinitions.cpp b/Analysis/src/EmbeddedBuiltinDefinitions.cpp index 364244ad3..dfc6ff07d 100644 --- a/Analysis/src/EmbeddedBuiltinDefinitions.cpp +++ b/Analysis/src/EmbeddedBuiltinDefinitions.cpp @@ -92,7 +92,7 @@ type DateTypeResult = { declare os: { time: (time: DateTypeArg?) -> number, - date: (formatString: string?, time: number?) -> DateTypeResult | string, + date: ((formatString: "*t" | "!*t", time: number?) -> DateTypeResult) & ((formatString: string?, time: number?) -> string), difftime: (t2: DateTypeResult | number, t1: DateTypeResult | number) -> number, clock: () -> number, } diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 486ef6960..b6b315cf1 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -21,6 +21,9 @@ #include #include +#include +#include +#include #include #include @@ -34,10 +37,36 @@ LUAU_FASTFLAGVARIABLE(DebugLuauDeferredConstraintResolution, false) LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJson, false) LUAU_FASTFLAG(LuauRequirePathTrueModuleName) LUAU_FASTFLAGVARIABLE(DebugLuauReadWriteProperties, false) +LUAU_FASTFLAGVARIABLE(LuauSplitFrontendProcessing, false) namespace Luau { +struct BuildQueueItem +{ + ModuleName name; + ModuleName humanReadableName; + + // Parameters + std::shared_ptr sourceNode; + std::shared_ptr sourceModule; + Config config; + ScopePtr environmentScope; + std::vector requireCycles; + FrontendOptions options; + bool recordJsonLog = false; + + // Queue state + std::vector reverseDeps; + int dirtyDependencies = 0; + bool processing = false; + + // Result + std::exception_ptr exception; + ModulePtr module; + Frontend::Stats stats; +}; + std::optional parseMode(const std::vector& hotcomments) { for (const HotComment& hc : hotcomments) @@ -220,7 +249,7 @@ namespace { static ErrorVec accumulateErrors( - const std::unordered_map& sourceNodes, ModuleResolver& moduleResolver, const ModuleName& name) + const std::unordered_map>& sourceNodes, ModuleResolver& moduleResolver, const ModuleName& name) { std::unordered_set seen; std::vector queue{name}; @@ -240,7 +269,7 @@ static ErrorVec accumulateErrors( if (it == sourceNodes.end()) continue; - const SourceNode& sourceNode = it->second; + const SourceNode& sourceNode = *it->second; queue.insert(queue.end(), sourceNode.requireSet.begin(), sourceNode.requireSet.end()); // FIXME: If a module has a syntax error, we won't be able to re-report it here. @@ -285,8 +314,8 @@ static void filterLintOptions(LintOptions& lintOptions, const std::vector getRequireCycles( - const FileResolver* resolver, const std::unordered_map& sourceNodes, const SourceNode* start, bool stopAtFirst = false) +std::vector getRequireCycles(const FileResolver* resolver, + const std::unordered_map>& sourceNodes, const SourceNode* start, bool stopAtFirst = false) { std::vector result; @@ -302,7 +331,7 @@ std::vector getRequireCycles( if (dit == sourceNodes.end()) continue; - stack.push_back(&dit->second); + stack.push_back(dit->second.get()); while (!stack.empty()) { @@ -343,7 +372,7 @@ std::vector getRequireCycles( auto rit = sourceNodes.find(reqName); if (rit != sourceNodes.end()) - stack.push_back(&rit->second); + stack.push_back(rit->second.get()); } } } @@ -389,6 +418,52 @@ Frontend::Frontend(FileResolver* fileResolver, ConfigResolver* configResolver, c } CheckResult Frontend::check(const ModuleName& name, std::optional optionOverride) +{ + if (!FFlag::LuauSplitFrontendProcessing) + return check_DEPRECATED(name, optionOverride); + + LUAU_TIMETRACE_SCOPE("Frontend::check", "Frontend"); + LUAU_TIMETRACE_ARGUMENT("name", name.c_str()); + + FrontendOptions frontendOptions = optionOverride.value_or(options); + + if (std::optional result = getCheckResult(name, true, frontendOptions.forAutocomplete)) + return std::move(*result); + + std::vector buildQueue; + bool cycleDetected = parseGraph(buildQueue, name, frontendOptions.forAutocomplete); + + std::unordered_set seen; + std::vector buildQueueItems; + addBuildQueueItems(buildQueueItems, buildQueue, cycleDetected, seen, frontendOptions); + LUAU_ASSERT(!buildQueueItems.empty()); + + if (FFlag::DebugLuauLogSolverToJson) + { + LUAU_ASSERT(buildQueueItems.back().name == name); + buildQueueItems.back().recordJsonLog = true; + } + + checkBuildQueueItems(buildQueueItems); + + // Collect results only for checked modules, 'getCheckResult' produces a different result + CheckResult checkResult; + + for (const BuildQueueItem& item : buildQueueItems) + { + if (item.module->timeout) + checkResult.timeoutHits.push_back(item.name); + + checkResult.errors.insert(checkResult.errors.end(), item.module->errors.begin(), item.module->errors.end()); + + if (item.name == name) + checkResult.lintResult = item.module->lintResult; + } + + return checkResult; +} + +CheckResult Frontend::check_DEPRECATED(const ModuleName& name, std::optional optionOverride) { LUAU_TIMETRACE_SCOPE("Frontend::check", "Frontend"); LUAU_TIMETRACE_ARGUMENT("name", name.c_str()); @@ -399,7 +474,7 @@ CheckResult Frontend::check(const ModuleName& name, std::optionalsecond.hasDirtyModule(frontendOptions.forAutocomplete)) + if (it != sourceNodes.end() && !it->second->hasDirtyModule(frontendOptions.forAutocomplete)) { // No recheck required. ModulePtr module = resolver.getModule(name); @@ -421,13 +496,13 @@ CheckResult Frontend::check(const ModuleName& name, std::optionalgetConfig(moduleName); @@ -583,7 +658,241 @@ CheckResult Frontend::check(const ModuleName& name, std::optional& buildQueue, const ModuleName& root, bool forAutocomplete) +void Frontend::queueModuleCheck(const std::vector& names) +{ + moduleQueue.insert(moduleQueue.end(), names.begin(), names.end()); +} + +void Frontend::queueModuleCheck(const ModuleName& name) +{ + moduleQueue.push_back(name); +} + +std::vector Frontend::checkQueuedModules(std::optional optionOverride, + std::function task)> executeTask, std::function progress) +{ + FrontendOptions frontendOptions = optionOverride.value_or(options); + + // By taking data into locals, we make sure queue is cleared at the end, even if an ICE or a different exception is thrown + std::vector currModuleQueue; + std::swap(currModuleQueue, moduleQueue); + + std::unordered_set seen; + std::vector buildQueueItems; + + for (const ModuleName& name : currModuleQueue) + { + if (seen.count(name)) + continue; + + if (!isDirty(name, frontendOptions.forAutocomplete)) + { + seen.insert(name); + continue; + } + + std::vector queue; + bool cycleDetected = parseGraph(queue, name, frontendOptions.forAutocomplete, [&seen](const ModuleName& name) { + return seen.count(name); + }); + + addBuildQueueItems(buildQueueItems, queue, cycleDetected, seen, frontendOptions); + } + + if (buildQueueItems.empty()) + return {}; + + // We need a mapping from modules to build queue slots + std::unordered_map moduleNameToQueue; + + for (size_t i = 0; i < buildQueueItems.size(); i++) + { + BuildQueueItem& item = buildQueueItems[i]; + moduleNameToQueue[item.name] = i; + } + + // Default task execution is single-threaded and immediate + if (!executeTask) + { + executeTask = [](std::function task) { + task(); + }; + } + + std::mutex mtx; + std::condition_variable cv; + std::vector readyQueueItems; + + size_t processing = 0; + size_t remaining = buildQueueItems.size(); + + auto itemTask = [&](size_t i) { + BuildQueueItem& item = buildQueueItems[i]; + + try + { + checkBuildQueueItem(item); + } + catch (...) + { + item.exception = std::current_exception(); + } + + { + std::unique_lock guard(mtx); + readyQueueItems.push_back(i); + } + + cv.notify_one(); + }; + + auto sendItemTask = [&](size_t i) { + BuildQueueItem& item = buildQueueItems[i]; + + item.processing = true; + processing++; + + executeTask([&itemTask, i]() { + itemTask(i); + }); + }; + + auto sendCycleItemTask = [&] { + for (size_t i = 0; i < buildQueueItems.size(); i++) + { + BuildQueueItem& item = buildQueueItems[i]; + + if (!item.processing) + { + sendItemTask(i); + break; + } + } + }; + + // In a first pass, check modules that have no dependencies and record info of those modules that wait + for (size_t i = 0; i < buildQueueItems.size(); i++) + { + BuildQueueItem& item = buildQueueItems[i]; + + for (const ModuleName& dep : item.sourceNode->requireSet) + { + if (auto it = sourceNodes.find(dep); it != sourceNodes.end()) + { + if (it->second->hasDirtyModule(frontendOptions.forAutocomplete)) + { + item.dirtyDependencies++; + + buildQueueItems[moduleNameToQueue[dep]].reverseDeps.push_back(i); + } + } + } + + if (item.dirtyDependencies == 0) + sendItemTask(i); + } + + // Not a single item was found, a cycle in the graph was hit + if (processing == 0) + sendCycleItemTask(); + + std::vector nextItems; + + while (remaining != 0) + { + { + std::unique_lock guard(mtx); + + // If nothing is ready yet, wait + if (readyQueueItems.empty()) + { + cv.wait(guard, [&readyQueueItems] { + return !readyQueueItems.empty(); + }); + } + + // Handle checked items + for (size_t i : readyQueueItems) + { + const BuildQueueItem& item = buildQueueItems[i]; + recordItemResult(item); + + // Notify items that were waiting for this dependency + for (size_t reverseDep : item.reverseDeps) + { + BuildQueueItem& reverseDepItem = buildQueueItems[reverseDep]; + + LUAU_ASSERT(reverseDepItem.dirtyDependencies != 0); + reverseDepItem.dirtyDependencies--; + + // In case of a module cycle earlier, check if unlocked an item that was already processed + if (!reverseDepItem.processing && reverseDepItem.dirtyDependencies == 0) + nextItems.push_back(reverseDep); + } + } + + LUAU_ASSERT(processing >= readyQueueItems.size()); + processing -= readyQueueItems.size(); + + LUAU_ASSERT(remaining >= readyQueueItems.size()); + remaining -= readyQueueItems.size(); + readyQueueItems.clear(); + } + + if (progress) + progress(buildQueueItems.size() - remaining, buildQueueItems.size()); + + // Items cannot be submitted while holding the lock + for (size_t i : nextItems) + sendItemTask(i); + nextItems.clear(); + + // If we aren't done, but don't have anything processing, we hit a cycle + if (remaining != 0 && processing == 0) + sendCycleItemTask(); + } + + std::vector checkedModules; + checkedModules.reserve(buildQueueItems.size()); + + for (size_t i = 0; i < buildQueueItems.size(); i++) + checkedModules.push_back(std::move(buildQueueItems[i].name)); + + return checkedModules; +} + +std::optional Frontend::getCheckResult(const ModuleName& name, bool accumulateNested, bool forAutocomplete) +{ + auto it = sourceNodes.find(name); + + if (it == sourceNodes.end() || it->second->hasDirtyModule(forAutocomplete)) + return std::nullopt; + + auto& resolver = forAutocomplete ? moduleResolverForAutocomplete : moduleResolver; + + ModulePtr module = resolver.getModule(name); + + if (module == nullptr) + throw InternalCompilerError("Frontend does not have module: " + name, name); + + CheckResult checkResult; + + if (module->timeout) + checkResult.timeoutHits.push_back(name); + + if (accumulateNested) + checkResult.errors = accumulateErrors(sourceNodes, resolver, name); + else + checkResult.errors.insert(checkResult.errors.end(), module->errors.begin(), module->errors.end()); + + // Get lint result only for top checked module + checkResult.lintResult = module->lintResult; + + return checkResult; +} + +bool Frontend::parseGraph( + std::vector& buildQueue, const ModuleName& root, bool forAutocomplete, std::function canSkip) { LUAU_TIMETRACE_SCOPE("Frontend::parseGraph", "Frontend"); LUAU_TIMETRACE_ARGUMENT("root", root.c_str()); @@ -654,14 +963,18 @@ bool Frontend::parseGraph(std::vector& buildQueue, const ModuleName& // this relies on the fact that markDirty marks reverse-dependencies dirty as well // thus if a node is not dirty, all its transitive deps aren't dirty, which means that they won't ever need // to be built, *and* can't form a cycle with any nodes we did process. - if (!it->second.hasDirtyModule(forAutocomplete)) + if (!it->second->hasDirtyModule(forAutocomplete)) + continue; + + // This module might already be in the outside build queue + if (canSkip && canSkip(dep)) continue; // note: this check is technically redundant *except* that getSourceNode has somewhat broken memoization // calling getSourceNode twice in succession will reparse the file, since getSourceNode leaves dirty flag set - if (seen.contains(&it->second)) + if (seen.contains(it->second.get())) { - stack.push_back(&it->second); + stack.push_back(it->second.get()); continue; } } @@ -681,6 +994,210 @@ bool Frontend::parseGraph(std::vector& buildQueue, const ModuleName& return cyclic; } +void Frontend::addBuildQueueItems(std::vector& items, std::vector& buildQueue, bool cycleDetected, + std::unordered_set& seen, const FrontendOptions& frontendOptions) +{ + LUAU_ASSERT(FFlag::LuauSplitFrontendProcessing); + + for (const ModuleName& moduleName : buildQueue) + { + if (seen.count(moduleName)) + continue; + seen.insert(moduleName); + + LUAU_ASSERT(sourceNodes.count(moduleName)); + std::shared_ptr& sourceNode = sourceNodes[moduleName]; + + if (!sourceNode->hasDirtyModule(frontendOptions.forAutocomplete)) + continue; + + LUAU_ASSERT(sourceModules.count(moduleName)); + std::shared_ptr& sourceModule = sourceModules[moduleName]; + + BuildQueueItem data{moduleName, fileResolver->getHumanReadableModuleName(moduleName), sourceNode, sourceModule}; + + data.config = configResolver->getConfig(moduleName); + data.environmentScope = getModuleEnvironment(*sourceModule, data.config, frontendOptions.forAutocomplete); + + Mode mode = sourceModule->mode.value_or(data.config.mode); + + // in NoCheck mode we only need to compute the value of .cyclic for typeck + // in the future we could replace toposort with an algorithm that can flag cyclic nodes by itself + // however, for now getRequireCycles isn't expensive in practice on the cases we care about, and long term + // all correct programs must be acyclic so this code triggers rarely + if (cycleDetected) + data.requireCycles = getRequireCycles(fileResolver, sourceNodes, sourceNode.get(), mode == Mode::NoCheck); + + data.options = frontendOptions; + + // This is used by the type checker to replace the resulting type of cyclic modules with any + sourceModule->cyclic = !data.requireCycles.empty(); + + items.push_back(std::move(data)); + } +} + +void Frontend::checkBuildQueueItem(BuildQueueItem& item) +{ + LUAU_ASSERT(FFlag::LuauSplitFrontendProcessing); + + SourceNode& sourceNode = *item.sourceNode; + const SourceModule& sourceModule = *item.sourceModule; + const Config& config = item.config; + Mode mode = sourceModule.mode.value_or(config.mode); + ScopePtr environmentScope = item.environmentScope; + double timestamp = getTimestamp(); + const std::vector& requireCycles = item.requireCycles; + + if (item.options.forAutocomplete) + { + double autocompleteTimeLimit = FInt::LuauAutocompleteCheckTimeoutMs / 1000.0; + + // The autocomplete typecheck is always in strict mode with DM awareness + // to provide better type information for IDE features + TypeCheckLimits typeCheckLimits; + + if (autocompleteTimeLimit != 0.0) + typeCheckLimits.finishTime = TimeTrace::getClock() + autocompleteTimeLimit; + else + typeCheckLimits.finishTime = std::nullopt; + + // TODO: This is a dirty ad hoc solution for autocomplete timeouts + // We are trying to dynamically adjust our existing limits to lower total typechecking time under the limit + // so that we'll have type information for the whole file at lower quality instead of a full abort in the middle + if (FInt::LuauTarjanChildLimit > 0) + typeCheckLimits.instantiationChildLimit = std::max(1, int(FInt::LuauTarjanChildLimit * sourceNode.autocompleteLimitsMult)); + else + typeCheckLimits.instantiationChildLimit = std::nullopt; + + if (FInt::LuauTypeInferIterationLimit > 0) + typeCheckLimits.unifierIterationLimit = std::max(1, int(FInt::LuauTypeInferIterationLimit * sourceNode.autocompleteLimitsMult)); + else + typeCheckLimits.unifierIterationLimit = std::nullopt; + + ModulePtr moduleForAutocomplete = check(sourceModule, Mode::Strict, requireCycles, environmentScope, /*forAutocomplete*/ true, + /*recordJsonLog*/ false, typeCheckLimits); + + double duration = getTimestamp() - timestamp; + + if (moduleForAutocomplete->timeout) + sourceNode.autocompleteLimitsMult = sourceNode.autocompleteLimitsMult / 2.0; + else if (duration < autocompleteTimeLimit / 2.0) + sourceNode.autocompleteLimitsMult = std::min(sourceNode.autocompleteLimitsMult * 2.0, 1.0); + + item.stats.timeCheck += duration; + item.stats.filesStrict += 1; + + item.module = moduleForAutocomplete; + return; + } + + ModulePtr module = check(sourceModule, mode, requireCycles, environmentScope, /*forAutocomplete*/ false, item.recordJsonLog, {}); + + item.stats.timeCheck += getTimestamp() - timestamp; + item.stats.filesStrict += mode == Mode::Strict; + item.stats.filesNonstrict += mode == Mode::Nonstrict; + + if (module == nullptr) + throw InternalCompilerError("Frontend::check produced a nullptr module for " + item.name, item.name); + + if (FFlag::DebugLuauDeferredConstraintResolution && mode == Mode::NoCheck) + module->errors.clear(); + + if (item.options.runLintChecks) + { + LUAU_TIMETRACE_SCOPE("lint", "Frontend"); + + LintOptions lintOptions = item.options.enabledLintWarnings.value_or(config.enabledLint); + filterLintOptions(lintOptions, sourceModule.hotcomments, mode); + + double timestamp = getTimestamp(); + + std::vector warnings = + Luau::lint(sourceModule.root, *sourceModule.names, environmentScope, module.get(), sourceModule.hotcomments, lintOptions); + + item.stats.timeLint += getTimestamp() - timestamp; + + module->lintResult = classifyLints(warnings, config); + } + + if (!item.options.retainFullTypeGraphs) + { + // copyErrors needs to allocate into interfaceTypes as it copies + // types out of internalTypes, so we unfreeze it here. + unfreeze(module->interfaceTypes); + copyErrors(module->errors, module->interfaceTypes); + freeze(module->interfaceTypes); + + module->internalTypes.clear(); + + module->astTypes.clear(); + module->astTypePacks.clear(); + module->astExpectedTypes.clear(); + module->astOriginalCallTypes.clear(); + module->astOverloadResolvedTypes.clear(); + module->astResolvedTypes.clear(); + module->astOriginalResolvedTypes.clear(); + module->astResolvedTypePacks.clear(); + module->astScopes.clear(); + + module->scopes.clear(); + } + + if (mode != Mode::NoCheck) + { + for (const RequireCycle& cyc : requireCycles) + { + TypeError te{cyc.location, item.name, ModuleHasCyclicDependency{cyc.path}}; + + module->errors.push_back(te); + } + } + + ErrorVec parseErrors; + + for (const ParseError& pe : sourceModule.parseErrors) + parseErrors.push_back(TypeError{pe.getLocation(), item.name, SyntaxError{pe.what()}}); + + module->errors.insert(module->errors.begin(), parseErrors.begin(), parseErrors.end()); + + item.module = module; +} + +void Frontend::checkBuildQueueItems(std::vector& items) +{ + LUAU_ASSERT(FFlag::LuauSplitFrontendProcessing); + + for (BuildQueueItem& item : items) + { + checkBuildQueueItem(item); + recordItemResult(item); + } +} + +void Frontend::recordItemResult(const BuildQueueItem& item) +{ + if (item.exception) + std::rethrow_exception(item.exception); + + if (item.options.forAutocomplete) + { + moduleResolverForAutocomplete.setModule(item.name, item.module); + item.sourceNode->dirtyModuleForAutocomplete = false; + } + else + { + moduleResolver.setModule(item.name, item.module); + item.sourceNode->dirtyModule = false; + } + + stats.timeCheck += item.stats.timeCheck; + stats.timeLint += item.stats.timeLint; + + stats.filesStrict += item.stats.filesStrict; + stats.filesNonstrict += item.stats.filesNonstrict; +} + ScopePtr Frontend::getModuleEnvironment(const SourceModule& module, const Config& config, bool forAutocomplete) const { ScopePtr result; @@ -711,7 +1228,7 @@ ScopePtr Frontend::getModuleEnvironment(const SourceModule& module, const Config bool Frontend::isDirty(const ModuleName& name, bool forAutocomplete) const { auto it = sourceNodes.find(name); - return it == sourceNodes.end() || it->second.hasDirtyModule(forAutocomplete); + return it == sourceNodes.end() || it->second->hasDirtyModule(forAutocomplete); } /* @@ -728,7 +1245,7 @@ void Frontend::markDirty(const ModuleName& name, std::vector* marked std::unordered_map> reverseDeps; for (const auto& module : sourceNodes) { - for (const auto& dep : module.second.requireSet) + for (const auto& dep : module.second->requireSet) reverseDeps[dep].push_back(module.first); } @@ -740,7 +1257,7 @@ void Frontend::markDirty(const ModuleName& name, std::vector* marked queue.pop_back(); LUAU_ASSERT(sourceNodes.count(next) > 0); - SourceNode& sourceNode = sourceNodes[next]; + SourceNode& sourceNode = *sourceNodes[next]; if (markedDirty) markedDirty->push_back(next); @@ -766,7 +1283,7 @@ SourceModule* Frontend::getSourceModule(const ModuleName& moduleName) { auto it = sourceModules.find(moduleName); if (it != sourceModules.end()) - return &it->second; + return it->second.get(); else return nullptr; } @@ -901,22 +1418,22 @@ ModulePtr Frontend::check(const SourceModule& sourceModule, Mode mode, std::vect // Read AST into sourceModules if necessary. Trace require()s. Report parse errors. std::pair Frontend::getSourceNode(const ModuleName& name) { - LUAU_TIMETRACE_SCOPE("Frontend::getSourceNode", "Frontend"); - LUAU_TIMETRACE_ARGUMENT("name", name.c_str()); - auto it = sourceNodes.find(name); - if (it != sourceNodes.end() && !it->second.hasDirtySourceModule()) + if (it != sourceNodes.end() && !it->second->hasDirtySourceModule()) { auto moduleIt = sourceModules.find(name); if (moduleIt != sourceModules.end()) - return {&it->second, &moduleIt->second}; + return {it->second.get(), moduleIt->second.get()}; else { LUAU_ASSERT(!"Everything in sourceNodes should also be in sourceModules"); - return {&it->second, nullptr}; + return {it->second.get(), nullptr}; } } + LUAU_TIMETRACE_SCOPE("Frontend::getSourceNode", "Frontend"); + LUAU_TIMETRACE_ARGUMENT("name", name.c_str()); + double timestamp = getTimestamp(); std::optional source = fileResolver->readSource(name); @@ -939,30 +1456,37 @@ std::pair Frontend::getSourceNode(const ModuleName& RequireTraceResult& require = requireTrace[name]; require = traceRequires(fileResolver, result.root, name); - SourceNode& sourceNode = sourceNodes[name]; - SourceModule& sourceModule = sourceModules[name]; + std::shared_ptr& sourceNode = sourceNodes[name]; + + if (!sourceNode) + sourceNode = std::make_shared(); + + std::shared_ptr& sourceModule = sourceModules[name]; - sourceModule = std::move(result); - sourceModule.environmentName = environmentName; + if (!sourceModule) + sourceModule = std::make_shared(); - sourceNode.name = sourceModule.name; - sourceNode.humanReadableName = sourceModule.humanReadableName; - sourceNode.requireSet.clear(); - sourceNode.requireLocations.clear(); - sourceNode.dirtySourceModule = false; + *sourceModule = std::move(result); + sourceModule->environmentName = environmentName; + + sourceNode->name = sourceModule->name; + sourceNode->humanReadableName = sourceModule->humanReadableName; + sourceNode->requireSet.clear(); + sourceNode->requireLocations.clear(); + sourceNode->dirtySourceModule = false; if (it == sourceNodes.end()) { - sourceNode.dirtyModule = true; - sourceNode.dirtyModuleForAutocomplete = true; + sourceNode->dirtyModule = true; + sourceNode->dirtyModuleForAutocomplete = true; } for (const auto& [moduleName, location] : require.requireList) - sourceNode.requireSet.insert(moduleName); + sourceNode->requireSet.insert(moduleName); - sourceNode.requireLocations = require.requireList; + sourceNode->requireLocations = require.requireList; - return {&sourceNode, &sourceModule}; + return {sourceNode.get(), sourceModule.get()}; } /** Try to parse a source file into a SourceModule. diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp index 29f8b2e68..cfc0ae137 100644 --- a/Analysis/src/Normalize.cpp +++ b/Analysis/src/Normalize.cpp @@ -17,8 +17,6 @@ LUAU_FASTFLAGVARIABLE(DebugLuauCheckNormalizeInvariant, false) // This could theoretically be 2000 on amd64, but x86 requires this. LUAU_FASTINTVARIABLE(LuauNormalizeIterationLimit, 1200); LUAU_FASTINTVARIABLE(LuauNormalizeCacheLimit, 100000); -LUAU_FASTFLAGVARIABLE(LuauNegatedClassTypes, false); -LUAU_FASTFLAGVARIABLE(LuauNegatedTableTypes, false); LUAU_FASTFLAGVARIABLE(LuauNormalizeBlockedTypes, false); LUAU_FASTFLAGVARIABLE(LuauNormalizeMetatableFixes, false); LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) @@ -232,15 +230,8 @@ NormalizedType::NormalizedType(NotNull builtinTypes) static bool isShallowInhabited(const NormalizedType& norm) { - bool inhabitedClasses; - - if (FFlag::LuauNegatedClassTypes) - inhabitedClasses = !norm.classes.isNever(); - else - inhabitedClasses = !norm.DEPRECATED_classes.empty(); - // This test is just a shallow check, for example it returns `true` for `{ p : never }` - return !get(norm.tops) || !get(norm.booleans) || inhabitedClasses || !get(norm.errors) || + return !get(norm.tops) || !get(norm.booleans) || !norm.classes.isNever() || !get(norm.errors) || !get(norm.nils) || !get(norm.numbers) || !norm.strings.isNever() || !get(norm.threads) || !norm.functions.isNever() || !norm.tables.empty() || !norm.tyvars.empty(); } @@ -257,14 +248,8 @@ bool Normalizer::isInhabited(const NormalizedType* norm, std::unordered_setclasses.isNever(); - else - inhabitedClasses = !norm->DEPRECATED_classes.empty(); - if (!get(norm->tops) || !get(norm->booleans) || !get(norm->errors) || !get(norm->nils) || - !get(norm->numbers) || !get(norm->threads) || inhabitedClasses || !norm->strings.isNever() || + !get(norm->numbers) || !get(norm->threads) || !norm->classes.isNever() || !norm->strings.isNever() || !norm->functions.isNever()) return true; @@ -466,7 +451,7 @@ static bool areNormalizedTables(const TypeIds& tys) if (!pt) return false; - if (pt->type == PrimitiveType::Table && FFlag::LuauNegatedTableTypes) + if (pt->type == PrimitiveType::Table) continue; return false; @@ -475,14 +460,6 @@ static bool areNormalizedTables(const TypeIds& tys) return true; } -static bool areNormalizedClasses(const TypeIds& tys) -{ - for (TypeId ty : tys) - if (!get(ty)) - return false; - return true; -} - static bool areNormalizedClasses(const NormalizedClassType& tys) { for (const auto& [ty, negations] : tys.classes) @@ -567,7 +544,6 @@ static void assertInvariant(const NormalizedType& norm) LUAU_ASSERT(isNormalizedTop(norm.tops)); LUAU_ASSERT(isNormalizedBoolean(norm.booleans)); - LUAU_ASSERT(areNormalizedClasses(norm.DEPRECATED_classes)); LUAU_ASSERT(areNormalizedClasses(norm.classes)); LUAU_ASSERT(isNormalizedError(norm.errors)); LUAU_ASSERT(isNormalizedNil(norm.nils)); @@ -629,7 +605,6 @@ void Normalizer::clearNormal(NormalizedType& norm) norm.tops = builtinTypes->neverType; norm.booleans = builtinTypes->neverType; norm.classes.resetToNever(); - norm.DEPRECATED_classes.clear(); norm.errors = builtinTypes->neverType; norm.nils = builtinTypes->neverType; norm.numbers = builtinTypes->neverType; @@ -1253,18 +1228,11 @@ void Normalizer::unionTables(TypeIds& heres, const TypeIds& theres) { for (TypeId there : theres) { - if (FFlag::LuauNegatedTableTypes) + if (there == builtinTypes->tableType) { - if (there == builtinTypes->tableType) - { - heres.clear(); - heres.insert(there); - return; - } - else - { - unionTablesWithTable(heres, there); - } + heres.clear(); + heres.insert(there); + return; } else { @@ -1320,10 +1288,7 @@ bool Normalizer::unionNormals(NormalizedType& here, const NormalizedType& there, } here.booleans = unionOfBools(here.booleans, there.booleans); - if (FFlag::LuauNegatedClassTypes) - unionClasses(here.classes, there.classes); - else - unionClasses(here.DEPRECATED_classes, there.DEPRECATED_classes); + unionClasses(here.classes, there.classes); here.errors = (get(there.errors) ? here.errors : there.errors); here.nils = (get(there.nils) ? here.nils : there.nils); @@ -1414,16 +1379,7 @@ bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, int ignor else if (get(there) || get(there)) unionTablesWithTable(here.tables, there); else if (get(there)) - { - if (FFlag::LuauNegatedClassTypes) - { - unionClassesWithClass(here.classes, there); - } - else - { - unionClassesWithClass(here.DEPRECATED_classes, there); - } - } + unionClassesWithClass(here.classes, there); else if (get(there)) here.errors = there; else if (const PrimitiveType* ptv = get(there)) @@ -1442,7 +1398,7 @@ bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, int ignor { here.functions.resetToTop(); } - else if (ptv->type == PrimitiveType::Table && FFlag::LuauNegatedTableTypes) + else if (ptv->type == PrimitiveType::Table) { here.tables.clear(); here.tables.insert(there); @@ -1527,36 +1483,29 @@ std::optional Normalizer::negateNormal(const NormalizedType& her result.booleans = builtinTypes->trueType; } - if (FFlag::LuauNegatedClassTypes) + if (here.classes.isNever()) { - if (here.classes.isNever()) - { - resetToTop(builtinTypes, result.classes); - } - else if (isTop(builtinTypes, result.classes)) - { - result.classes.resetToNever(); - } - else - { - TypeIds rootNegations{}; - - for (const auto& [hereParent, hereNegations] : here.classes.classes) - { - if (hereParent != builtinTypes->classType) - rootNegations.insert(hereParent); - - for (TypeId hereNegation : hereNegations) - unionClassesWithClass(result.classes, hereNegation); - } - - if (!rootNegations.empty()) - result.classes.pushPair(builtinTypes->classType, rootNegations); - } + resetToTop(builtinTypes, result.classes); + } + else if (isTop(builtinTypes, result.classes)) + { + result.classes.resetToNever(); } else { - result.DEPRECATED_classes = negateAll(here.DEPRECATED_classes); + TypeIds rootNegations{}; + + for (const auto& [hereParent, hereNegations] : here.classes.classes) + { + if (hereParent != builtinTypes->classType) + rootNegations.insert(hereParent); + + for (TypeId hereNegation : hereNegations) + unionClassesWithClass(result.classes, hereNegation); + } + + if (!rootNegations.empty()) + result.classes.pushPair(builtinTypes->classType, rootNegations); } result.nils = get(here.nils) ? builtinTypes->nilType : builtinTypes->neverType; @@ -1584,15 +1533,12 @@ std::optional Normalizer::negateNormal(const NormalizedType& her * types are not runtime-testable. Thus, we prohibit negation of anything * other than `table` and `never`. */ - if (FFlag::LuauNegatedTableTypes) - { - if (here.tables.empty()) - result.tables.insert(builtinTypes->tableType); - else if (here.tables.size() == 1 && here.tables.front() == builtinTypes->tableType) - result.tables.clear(); - else - return std::nullopt; - } + if (here.tables.empty()) + result.tables.insert(builtinTypes->tableType); + else if (here.tables.size() == 1 && here.tables.front() == builtinTypes->tableType) + result.tables.clear(); + else + return std::nullopt; // TODO: negating tables // TODO: negating tyvars? @@ -1662,7 +1608,6 @@ void Normalizer::subtractPrimitive(NormalizedType& here, TypeId ty) here.functions.resetToNever(); break; case PrimitiveType::Table: - LUAU_ASSERT(FFlag::LuauNegatedTableTypes); here.tables.clear(); break; } @@ -1734,64 +1679,6 @@ TypeId Normalizer::intersectionOfBools(TypeId here, TypeId there) return there; } -void Normalizer::DEPRECATED_intersectClasses(TypeIds& heres, const TypeIds& theres) -{ - TypeIds tmp; - for (auto it = heres.begin(); it != heres.end();) - { - const ClassType* hctv = get(*it); - LUAU_ASSERT(hctv); - bool keep = false; - for (TypeId there : theres) - { - const ClassType* tctv = get(there); - LUAU_ASSERT(tctv); - if (isSubclass(hctv, tctv)) - { - keep = true; - break; - } - else if (isSubclass(tctv, hctv)) - { - keep = false; - tmp.insert(there); - break; - } - } - if (keep) - it++; - else - it = heres.erase(it); - } - heres.insert(tmp.begin(), tmp.end()); -} - -void Normalizer::DEPRECATED_intersectClassesWithClass(TypeIds& heres, TypeId there) -{ - bool foundSuper = false; - const ClassType* tctv = get(there); - LUAU_ASSERT(tctv); - for (auto it = heres.begin(); it != heres.end();) - { - const ClassType* hctv = get(*it); - LUAU_ASSERT(hctv); - if (isSubclass(hctv, tctv)) - it++; - else if (isSubclass(tctv, hctv)) - { - foundSuper = true; - break; - } - else - it = heres.erase(it); - } - if (foundSuper) - { - heres.clear(); - heres.insert(there); - } -} - void Normalizer::intersectClasses(NormalizedClassType& heres, const NormalizedClassType& theres) { if (theres.isNever()) @@ -2504,15 +2391,7 @@ bool Normalizer::intersectNormals(NormalizedType& here, const NormalizedType& th here.booleans = intersectionOfBools(here.booleans, there.booleans); - if (FFlag::LuauNegatedClassTypes) - { - intersectClasses(here.classes, there.classes); - } - else - { - DEPRECATED_intersectClasses(here.DEPRECATED_classes, there.DEPRECATED_classes); - } - + intersectClasses(here.classes, there.classes); here.errors = (get(there.errors) ? there.errors : here.errors); here.nils = (get(there.nils) ? there.nils : here.nils); here.numbers = (get(there.numbers) ? there.numbers : here.numbers); @@ -2619,20 +2498,10 @@ bool Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there) } else if (get(there)) { - if (FFlag::LuauNegatedClassTypes) - { - NormalizedClassType nct = std::move(here.classes); - clearNormal(here); - intersectClassesWithClass(nct, there); - here.classes = std::move(nct); - } - else - { - TypeIds classes = std::move(here.DEPRECATED_classes); - clearNormal(here); - DEPRECATED_intersectClassesWithClass(classes, there); - here.DEPRECATED_classes = std::move(classes); - } + NormalizedClassType nct = std::move(here.classes); + clearNormal(here); + intersectClassesWithClass(nct, there); + here.classes = std::move(nct); } else if (get(there)) { @@ -2665,10 +2534,7 @@ bool Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there) else if (ptv->type == PrimitiveType::Function) here.functions = std::move(functions); else if (ptv->type == PrimitiveType::Table) - { - LUAU_ASSERT(FFlag::LuauNegatedTableTypes); here.tables = std::move(tables); - } else LUAU_ASSERT(!"Unreachable"); } @@ -2696,7 +2562,7 @@ bool Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there) subtractPrimitive(here, ntv->ty); else if (const SingletonType* stv = get(t)) subtractSingleton(here, follow(ntv->ty)); - else if (get(t) && FFlag::LuauNegatedClassTypes) + else if (get(t)) { const NormalizedType* normal = normalize(t); std::optional negated = negateNormal(*normal); @@ -2730,7 +2596,7 @@ bool Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there) LUAU_ASSERT(!"Unimplemented"); } } - else if (get(there) && FFlag::LuauNegatedClassTypes) + else if (get(there)) { here.classes.resetToNever(); } @@ -2756,53 +2622,46 @@ TypeId Normalizer::typeFromNormal(const NormalizedType& norm) if (!get(norm.booleans)) result.push_back(norm.booleans); - if (FFlag::LuauNegatedClassTypes) + if (isTop(builtinTypes, norm.classes)) { - if (isTop(builtinTypes, norm.classes)) - { - result.push_back(builtinTypes->classType); - } - else if (!norm.classes.isNever()) + result.push_back(builtinTypes->classType); + } + else if (!norm.classes.isNever()) + { + std::vector parts; + parts.reserve(norm.classes.classes.size()); + + for (const TypeId normTy : norm.classes.ordering) { - std::vector parts; - parts.reserve(norm.classes.classes.size()); + const TypeIds& normNegations = norm.classes.classes.at(normTy); - for (const TypeId normTy : norm.classes.ordering) + if (normNegations.empty()) { - const TypeIds& normNegations = norm.classes.classes.at(normTy); + parts.push_back(normTy); + } + else + { + std::vector intersection; + intersection.reserve(normNegations.size() + 1); - if (normNegations.empty()) + intersection.push_back(normTy); + for (TypeId negation : normNegations) { - parts.push_back(normTy); + intersection.push_back(arena->addType(NegationType{negation})); } - else - { - std::vector intersection; - intersection.reserve(normNegations.size() + 1); - - intersection.push_back(normTy); - for (TypeId negation : normNegations) - { - intersection.push_back(arena->addType(NegationType{negation})); - } - parts.push_back(arena->addType(IntersectionType{std::move(intersection)})); - } + parts.push_back(arena->addType(IntersectionType{std::move(intersection)})); } + } - if (parts.size() == 1) - { - result.push_back(parts.at(0)); - } - else if (parts.size() > 1) - { - result.push_back(arena->addType(UnionType{std::move(parts)})); - } + if (parts.size() == 1) + { + result.push_back(parts.at(0)); + } + else if (parts.size() > 1) + { + result.push_back(arena->addType(UnionType{std::move(parts)})); } - } - else - { - result.insert(result.end(), norm.DEPRECATED_classes.begin(), norm.DEPRECATED_classes.end()); } if (!get(norm.errors)) diff --git a/Analysis/src/TxnLog.cpp b/Analysis/src/TxnLog.cpp index 26618313b..33554ce90 100644 --- a/Analysis/src/TxnLog.cpp +++ b/Analysis/src/TxnLog.cpp @@ -382,8 +382,9 @@ std::optional TxnLog::getLevel(TypeId ty) const TypeId TxnLog::follow(TypeId ty) const { - return Luau::follow(ty, [this](TypeId ty) { - PendingType* state = this->pending(ty); + return Luau::follow(ty, this, [](const void* ctx, TypeId ty) -> TypeId { + const TxnLog* self = static_cast(ctx); + PendingType* state = self->pending(ty); if (state == nullptr) return ty; @@ -397,8 +398,9 @@ TypeId TxnLog::follow(TypeId ty) const TypePackId TxnLog::follow(TypePackId tp) const { - return Luau::follow(tp, [this](TypePackId tp) { - PendingTypePack* state = this->pending(tp); + return Luau::follow(tp, this, [](const void* ctx, TypePackId tp) -> TypePackId { + const TxnLog* self = static_cast(ctx); + PendingTypePack* state = self->pending(tp); if (state == nullptr) return tp; diff --git a/Analysis/src/Type.cpp b/Analysis/src/Type.cpp index 2ca39b41a..e8a2bc5d4 100644 --- a/Analysis/src/Type.cpp +++ b/Analysis/src/Type.cpp @@ -48,19 +48,39 @@ static std::optional> magicFunctionFind( TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate); static bool dcrMagicFunctionFind(MagicFunctionCallContext context); +// LUAU_NOINLINE prevents unwrapLazy from being inlined into advance below; advance is important to keep inlineable +static LUAU_NOINLINE TypeId unwrapLazy(LazyType* ltv) +{ + TypeId unwrapped = ltv->unwrapped.load(); + + if (unwrapped) + return unwrapped; + + ltv->unwrap(*ltv); + unwrapped = ltv->unwrapped.load(); + + if (!unwrapped) + throw InternalCompilerError("Lazy Type didn't fill in unwrapped type field"); + + if (get(unwrapped)) + throw InternalCompilerError("Lazy Type cannot resolve to another Lazy Type"); + + return unwrapped; +} + TypeId follow(TypeId t) { - return follow(t, [](TypeId t) { + return follow(t, nullptr, [](const void*, TypeId t) -> TypeId { return t; }); } -TypeId follow(TypeId t, std::function mapper) +TypeId follow(TypeId t, const void* context, TypeId (*mapper)(const void*, TypeId)) { - auto advance = [&mapper](TypeId ty) -> std::optional { + auto advance = [context, mapper](TypeId ty) -> std::optional { if (FFlag::LuauBoundLazyTypes2) { - TypeId mapped = mapper(ty); + TypeId mapped = mapper(context, ty); if (auto btv = get>(mapped)) return btv->boundTo; @@ -69,39 +89,25 @@ TypeId follow(TypeId t, std::function mapper) return ttv->boundTo; if (auto ltv = getMutable(mapped)) - { - TypeId unwrapped = ltv->unwrapped.load(); - - if (unwrapped) - return unwrapped; - - ltv->unwrap(*ltv); - unwrapped = ltv->unwrapped.load(); - - if (!unwrapped) - throw InternalCompilerError("Lazy Type didn't fill in unwrapped type field"); - - if (get(unwrapped)) - throw InternalCompilerError("Lazy Type cannot resolve to another Lazy Type"); - - return unwrapped; - } + return unwrapLazy(ltv); return std::nullopt; } else { - if (auto btv = get>(mapper(ty))) + if (auto btv = get>(mapper(context, ty))) return btv->boundTo; - else if (auto ttv = get(mapper(ty))) + else if (auto ttv = get(mapper(context, ty))) return ttv->boundTo; else return std::nullopt; } }; - auto force = [&mapper](TypeId ty) { - if (auto ltv = get_if(&mapper(ty)->ty)) + auto force = [context, mapper](TypeId ty) { + TypeId mapped = mapper(context, ty); + + if (auto ltv = get_if(&mapped->ty)) { TypeId res = ltv->thunk_DEPRECATED(); if (get(res)) @@ -120,6 +126,12 @@ TypeId follow(TypeId t, std::function mapper) else return t; + if (FFlag::LuauBoundLazyTypes2) + { + if (!advance(cycleTester)) // Short circuit traversal for the rather common case when advance(advance(t)) == null + return cycleTester; + } + while (true) { if (!FFlag::LuauBoundLazyTypes2) diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index a103df145..2a2fe69ce 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -22,8 +22,6 @@ LUAU_FASTFLAG(DebugLuauMagicTypes) LUAU_FASTFLAG(DebugLuauDontReduceTypes) -LUAU_FASTFLAG(LuauNegatedClassTypes) - namespace Luau { @@ -519,18 +517,39 @@ struct TypeChecker2 auto [minCount, maxCount] = getParameterExtents(TxnLog::empty(), iterFtv->argTypes, /*includeHiddenVariadics*/ true); if (minCount > 2) - reportError(CountMismatch{2, std::nullopt, minCount, CountMismatch::Arg}, forInStatement->vars.data[0]->location); + { + if (isMm) + reportError(GenericError{"__iter metamethod must return (next[, table[, state]])"}, getLocation(forInStatement->values)); + else + reportError(GenericError{"for..in loops must be passed (next[, table[, state]])"}, getLocation(forInStatement->values)); + } if (maxCount && *maxCount < 2) - reportError(CountMismatch{2, std::nullopt, *maxCount, CountMismatch::Arg}, forInStatement->vars.data[0]->location); + { + if (isMm) + reportError(GenericError{"__iter metamethod must return (next[, table[, state]])"}, getLocation(forInStatement->values)); + else + reportError(GenericError{"for..in loops must be passed (next[, table[, state]])"}, getLocation(forInStatement->values)); + } TypePack flattenedArgTypes = extendTypePack(arena, builtinTypes, iterFtv->argTypes, 2); size_t firstIterationArgCount = iterTys.empty() ? 0 : iterTys.size() - 1; size_t actualArgCount = expectedVariableTypes.head.size(); - if (firstIterationArgCount < minCount) - reportError(CountMismatch{2, std::nullopt, firstIterationArgCount, CountMismatch::Arg}, forInStatement->vars.data[0]->location); + { + if (isMm) + reportError(GenericError{"__iter metamethod must return (next[, table[, state]])"}, getLocation(forInStatement->values)); + else + reportError(CountMismatch{2, std::nullopt, firstIterationArgCount, CountMismatch::Arg}, forInStatement->vars.data[0]->location); + } + else if (actualArgCount < minCount) - reportError(CountMismatch{2, std::nullopt, actualArgCount, CountMismatch::Arg}, forInStatement->vars.data[0]->location); + { + if (isMm) + reportError(GenericError{"__iter metamethod must return (next[, table[, state]])"}, getLocation(forInStatement->values)); + else + reportError(CountMismatch{2, std::nullopt, firstIterationArgCount, CountMismatch::Arg}, forInStatement->vars.data[0]->location); + } + if (iterTys.size() >= 2 && flattenedArgTypes.head.size() > 0) { @@ -841,125 +860,31 @@ struct TypeChecker2 // TODO! } - ErrorVec visitOverload(AstExprCall* call, NotNull overloadFunctionType, const std::vector& argLocs, - TypePackId expectedArgTypes, TypePackId expectedRetType) - { - ErrorVec overloadErrors = - tryUnify(stack.back(), call->location, overloadFunctionType->retTypes, expectedRetType, CountMismatch::FunctionResult); - - size_t argIndex = 0; - auto inferredArgIt = begin(overloadFunctionType->argTypes); - auto expectedArgIt = begin(expectedArgTypes); - while (inferredArgIt != end(overloadFunctionType->argTypes) && expectedArgIt != end(expectedArgTypes)) - { - Location argLoc = (argIndex >= argLocs.size()) ? argLocs.back() : argLocs[argIndex]; - ErrorVec argErrors = tryUnify(stack.back(), argLoc, *expectedArgIt, *inferredArgIt); - for (TypeError e : argErrors) - overloadErrors.emplace_back(e); - - ++argIndex; - ++inferredArgIt; - ++expectedArgIt; - } - - // piggyback on the unifier for arity checking, but we can't do this for checking the actual arguments since the locations would be bad - ErrorVec argumentErrors = tryUnify(stack.back(), call->location, expectedArgTypes, overloadFunctionType->argTypes); - for (TypeError e : argumentErrors) - if (get(e) != nullptr) - overloadErrors.emplace_back(std::move(e)); - - return overloadErrors; - } - - void reportOverloadResolutionErrors(AstExprCall* call, std::vector overloads, TypePackId expectedArgTypes, - const std::vector& overloadsThatMatchArgCount, std::vector> overloadsErrors) - { - if (overloads.size() == 1) - { - reportErrors(std::get<0>(overloadsErrors.front())); - return; - } - - std::vector overloadTypes = overloadsThatMatchArgCount; - if (overloadsThatMatchArgCount.size() == 0) - { - reportError(GenericError{"No overload for function accepts " + std::to_string(size(expectedArgTypes)) + " arguments."}, call->location); - // If no overloads match argument count, just list all overloads. - overloadTypes = overloads; - } - else - { - // Report errors of the first argument-count-matching, but failing overload - TypeId overload = overloadsThatMatchArgCount[0]; - - // Remove the overload we are reporting errors about from the list of alternatives - overloadTypes.erase(std::remove(overloadTypes.begin(), overloadTypes.end(), overload), overloadTypes.end()); - - const FunctionType* ftv = get(overload); - LUAU_ASSERT(ftv); // overload must be a function type here - - auto error = std::find_if(overloadsErrors.begin(), overloadsErrors.end(), [overload](const std::pair& e) { - return overload == e.second; - }); - - LUAU_ASSERT(error != overloadsErrors.end()); - reportErrors(std::get<0>(*error)); - - // If only one overload matched, we don't need this error because we provided the previous errors. - if (overloadsThatMatchArgCount.size() == 1) - return; - } - - std::string s; - for (size_t i = 0; i < overloadTypes.size(); ++i) - { - TypeId overload = follow(overloadTypes[i]); - - if (i > 0) - s += "; "; - - if (i > 0 && i == overloadTypes.size() - 1) - s += "and "; - - s += toString(overload); - } - - if (overloadsThatMatchArgCount.size() == 0) - reportError(ExtraInformation{"Available overloads: " + s}, call->func->location); - else - reportError(ExtraInformation{"Other overloads are also not viable: " + s}, call->func->location); - } - // Note: this is intentionally separated from `visit(AstExprCall*)` for stack allocation purposes. void visitCall(AstExprCall* call) { - TypeArena* arena = &testArena; - Instantiation instantiation{TxnLog::empty(), arena, TypeLevel{}, stack.back()}; - - TypePackId expectedRetType = lookupExpectedPack(call, *arena); - TypeId functionType = lookupType(call->func); - TypeId testFunctionType = functionType; + TypePackId expectedRetType = lookupExpectedPack(call, testArena); TypePack args; std::vector argLocs; argLocs.reserve(call->args.size + 1); - if (get(functionType) || get(functionType) || get(functionType)) + TypeId* maybeOriginalCallTy = module->astOriginalCallTypes.find(call); + TypeId* maybeSelectedOverload = module->astOverloadResolvedTypes.find(call); + + if (!maybeOriginalCallTy) + return; + + TypeId originalCallTy = follow(*maybeOriginalCallTy); + std::vector overloads = flattenIntersection(originalCallTy); + + if (get(originalCallTy) || get(originalCallTy) || get(originalCallTy)) return; - else if (std::optional callMm = findMetatableEntry(builtinTypes, module->errors, functionType, "__call", call->func->location)) + else if (std::optional callMm = findMetatableEntry(builtinTypes, module->errors, originalCallTy, "__call", call->func->location)) { if (get(follow(*callMm))) { - if (std::optional instantiatedCallMm = instantiation.substitute(*callMm)) - { - args.head.push_back(functionType); - argLocs.push_back(call->func->location); - testFunctionType = follow(*instantiatedCallMm); - } - else - { - reportError(UnificationTooComplex{}, call->func->location); - return; - } + args.head.push_back(originalCallTy); + argLocs.push_back(call->func->location); } else { @@ -969,29 +894,16 @@ struct TypeChecker2 return; } } - else if (get(functionType)) - { - if (std::optional instantiatedFunctionType = instantiation.substitute(functionType)) - { - testFunctionType = *instantiatedFunctionType; - } - else - { - reportError(UnificationTooComplex{}, call->func->location); - return; - } - } - else if (auto itv = get(functionType)) + else if (get(originalCallTy) || get(originalCallTy)) { - // We do nothing here because we'll flatten the intersection later, but we don't want to report it as a non-function. } - else if (auto utv = get(functionType)) + else if (auto utv = get(originalCallTy)) { // Sometimes it's okay to call a union of functions, but only if all of the functions are the same. // Another scenario we might run into it is if the union has a nil member. In this case, we want to throw an error - if (isOptional(functionType)) + if (isOptional(originalCallTy)) { - reportError(OptionalValueAccess{functionType}, call->location); + reportError(OptionalValueAccess{originalCallTy}, call->location); return; } std::optional fst; @@ -1001,7 +913,7 @@ struct TypeChecker2 fst = follow(ty); else if (fst != follow(ty)) { - reportError(CannotCallNonFunction{functionType}, call->func->location); + reportError(CannotCallNonFunction{originalCallTy}, call->func->location); return; } } @@ -1009,19 +921,16 @@ struct TypeChecker2 if (!fst) ice->ice("UnionType had no elements, so fst is nullopt?"); - if (std::optional instantiatedFunctionType = instantiation.substitute(*fst)) - { - testFunctionType = *instantiatedFunctionType; - } - else + originalCallTy = follow(*fst); + if (!get(originalCallTy)) { - reportError(UnificationTooComplex{}, call->func->location); + reportError(CannotCallNonFunction{originalCallTy}, call->func->location); return; } } else { - reportError(CannotCallNonFunction{functionType}, call->func->location); + reportError(CannotCallNonFunction{originalCallTy}, call->func->location); return; } @@ -1054,63 +963,134 @@ struct TypeChecker2 args.head.push_back(builtinTypes->anyType); } - TypePackId expectedArgTypes = arena->addTypePack(args); + TypePackId expectedArgTypes = testArena.addTypePack(args); - std::vector overloads = flattenIntersection(testFunctionType); - std::vector> overloadsErrors; - overloadsErrors.reserve(overloads.size()); - - std::vector overloadsThatMatchArgCount; - - for (TypeId overload : overloads) + if (maybeSelectedOverload) { - overload = follow(overload); + // This overload might not work still: the constraint solver will + // pass the type checker an instantiated function type that matches + // in arity, but not in subtyping, in order to allow the type + // checker to report better error messages. - const FunctionType* overloadFn = get(overload); - if (!overloadFn) + TypeId selectedOverload = follow(*maybeSelectedOverload); + const FunctionType* ftv; + + if (get(selectedOverload) || get(selectedOverload) || get(selectedOverload)) { - reportError(CannotCallNonFunction{overload}, call->func->location); return; } + else if (const FunctionType* overloadFtv = get(selectedOverload)) + { + ftv = overloadFtv; + } else { - // We may have to instantiate the overload in order for it to typecheck. - if (std::optional instantiatedFunctionType = instantiation.substitute(overload)) + reportError(CannotCallNonFunction{selectedOverload}, call->func->location); + return; + } + + LUAU_ASSERT(ftv); + reportErrors(tryUnify(stack.back(), call->location, ftv->retTypes, expectedRetType, CountMismatch::Context::Return)); + + auto it = begin(expectedArgTypes); + size_t i = 0; + std::vector slice; + for (TypeId arg : ftv->argTypes) + { + if (it == end(expectedArgTypes)) { - overloadFn = get(*instantiatedFunctionType); + slice.push_back(arg); + continue; } - else + + TypeId expectedArg = *it; + + Location argLoc = argLocs.at(i >= argLocs.size() ? argLocs.size() - 1 : i); + + reportErrors(tryUnify(stack.back(), argLoc, expectedArg, arg)); + + ++it; + ++i; + } + + if (slice.size() > 0 && it == end(expectedArgTypes)) + { + if (auto tail = it.tail()) { - overloadsErrors.emplace_back(std::vector{TypeError{call->func->location, UnificationTooComplex{}}}, overload); - return; + TypePackId remainingArgs = testArena.addTypePack(TypePack{std::move(slice), std::nullopt}); + reportErrors(tryUnify(stack.back(), argLocs.back(), *tail, remainingArgs)); } } - ErrorVec overloadErrors = visitOverload(call, NotNull{overloadFn}, argLocs, expectedArgTypes, expectedRetType); - if (overloadErrors.empty()) - return; + // We do not need to do an arity test because this overload was + // selected based on its arity already matching. + } + else + { + // No overload worked, even when instantiated. We need to filter the + // set of overloads to those that match the arity of the incoming + // argument set, and then report only those as not matching. - bool argMismatch = false; - for (auto error : overloadErrors) + std::vector arityMatchingOverloads; + ErrorVec empty; + for (TypeId overload : overloads) { - CountMismatch* cm = get(error); - if (!cm) - continue; - - if (cm->context == CountMismatch::Arg) + overload = follow(overload); + if (const FunctionType* ftv = get(overload)) { - argMismatch = true; - break; + if (size(ftv->argTypes) == size(expectedArgTypes)) + { + arityMatchingOverloads.push_back(overload); + } + } + else if (const std::optional callMm = findMetatableEntry(builtinTypes, empty, overload, "__call", call->location)) + { + if (const FunctionType* ftv = get(follow(*callMm))) + { + if (size(ftv->argTypes) == size(expectedArgTypes)) + { + arityMatchingOverloads.push_back(overload); + } + } + else + { + reportError(CannotCallNonFunction{}, call->location); + } } } - if (!argMismatch) - overloadsThatMatchArgCount.push_back(overload); + if (arityMatchingOverloads.size() == 0) + { + reportError( + GenericError{"No overload for function accepts " + std::to_string(size(expectedArgTypes)) + " arguments."}, call->location); + } + else + { + // We have handled the case of a singular arity-matching + // overload above, in the case where an overload was selected. + // LUAU_ASSERT(arityMatchingOverloads.size() > 1); + reportError(GenericError{"None of the overloads for function that accept " + std::to_string(size(expectedArgTypes)) + + " arguments are compatible."}, + call->location); + } - overloadsErrors.emplace_back(std::move(overloadErrors), overload); - } + std::string s; + std::vector& stringifyOverloads = arityMatchingOverloads.size() == 0 ? overloads : arityMatchingOverloads; + for (size_t i = 0; i < stringifyOverloads.size(); ++i) + { + TypeId overload = follow(stringifyOverloads[i]); + + if (i > 0) + s += "; "; + + if (i > 0 && i == stringifyOverloads.size() - 1) + s += "and "; + + s += toString(overload); + } - reportOverloadResolutionErrors(call, overloads, expectedArgTypes, overloadsThatMatchArgCount, overloadsErrors); + reportError(ExtraInformation{"Available overloads: " + s}, call->func->location); + } } void visit(AstExprCall* call) @@ -2077,17 +2057,9 @@ struct TypeChecker2 fetch(norm->tops); fetch(norm->booleans); - if (FFlag::LuauNegatedClassTypes) - { - for (const auto& [ty, _negations] : norm->classes.classes) - { - fetch(ty); - } - } - else + for (const auto& [ty, _negations] : norm->classes.classes) { - for (TypeId ty : norm->DEPRECATED_classes) - fetch(ty); + fetch(ty); } fetch(norm->errors); fetch(norm->nils); diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 8f9e1851b..1ccba91e7 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -35,7 +35,6 @@ LUAU_FASTFLAG(LuauKnowsTheDataModel3) LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) LUAU_FASTFLAGVARIABLE(DebugLuauSharedSelf, false) LUAU_FASTFLAG(LuauInstantiateInSubtyping) -LUAU_FASTFLAG(LuauNegatedClassTypes) LUAU_FASTFLAGVARIABLE(LuauAllowIndexClassParameters, false) LUAU_FASTFLAG(LuauUninhabitedSubAnything2) LUAU_FASTFLAG(LuauOccursIsntAlwaysFailure) @@ -1701,7 +1700,7 @@ void TypeChecker::prototype(const ScopePtr& scope, const AstStatTypeAlias& typea void TypeChecker::prototype(const ScopePtr& scope, const AstStatDeclareClass& declaredClass) { - std::optional superTy = FFlag::LuauNegatedClassTypes ? std::make_optional(builtinTypes->classType) : std::nullopt; + std::optional superTy = std::make_optional(builtinTypes->classType); if (declaredClass.superName) { Name superName = Name(declaredClass.superName->value); @@ -5968,17 +5967,13 @@ void TypeChecker::resolve(const TypeGuardPredicate& typeguardP, RefinementMap& r TypeId type = follow(typeFun->type); // You cannot refine to the top class type. - if (FFlag::LuauNegatedClassTypes) + if (type == builtinTypes->classType) { - if (type == builtinTypes->classType) - { - return addRefinement(refis, typeguardP.lvalue, errorRecoveryType(scope)); - } + return addRefinement(refis, typeguardP.lvalue, errorRecoveryType(scope)); } // We're only interested in the root class of any classes. - if (auto ctv = get(type); - !ctv || (FFlag::LuauNegatedClassTypes ? (ctv->parent != builtinTypes->classType) : (ctv->parent != std::nullopt))) + if (auto ctv = get(type); !ctv || ctv->parent != builtinTypes->classType) return addRefinement(refis, typeguardP.lvalue, errorRecoveryType(scope)); // This probably hints at breaking out type filtering functions from the predicate solver so that typeof is not tightly coupled with IsA. diff --git a/Analysis/src/TypePack.cpp b/Analysis/src/TypePack.cpp index 6873820a7..0db0e5a11 100644 --- a/Analysis/src/TypePack.cpp +++ b/Analysis/src/TypePack.cpp @@ -255,15 +255,17 @@ bool areEqual(SeenSet& seen, const TypePackVar& lhs, const TypePackVar& rhs) TypePackId follow(TypePackId tp) { - return follow(tp, [](TypePackId t) { + return follow(tp, nullptr, [](const void*, TypePackId t) { return t; }); } -TypePackId follow(TypePackId tp, std::function mapper) +TypePackId follow(TypePackId tp, const void* context, TypePackId (*mapper)(const void*, TypePackId)) { - auto advance = [&mapper](TypePackId ty) -> std::optional { - if (const Unifiable::Bound* btv = get>(mapper(ty))) + auto advance = [context, mapper](TypePackId ty) -> std::optional { + TypePackId mapped = mapper(context, ty); + + if (const Unifiable::Bound* btv = get>(mapped)) return btv->boundTo; else return std::nullopt; @@ -275,6 +277,9 @@ TypePackId follow(TypePackId tp, std::function mapper) else return tp; + if (!advance(cycleTester)) // Short circuit traversal for the rather common case when advance(advance(t)) == null + return cycleTester; + while (true) { auto a1 = advance(tp); diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 3ca93591a..6047a49b1 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -26,8 +26,6 @@ LUAU_FASTFLAGVARIABLE(LuauOccursIsntAlwaysFailure, false) LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) LUAU_FASTFLAG(LuauNormalizeBlockedTypes) -LUAU_FASTFLAG(LuauNegatedClassTypes) -LUAU_FASTFLAG(LuauNegatedTableTypes) namespace Luau { @@ -344,6 +342,19 @@ std::optional hasUnificationTooComplex(const ErrorVec& errors) return *it; } +std::optional hasCountMismatch(const ErrorVec& errors) +{ + auto isCountMismatch = [](const TypeError& te) { + return nullptr != get(te); + }; + + auto it = std::find_if(errors.begin(), errors.end(), isCountMismatch); + if (it == errors.end()) + return std::nullopt; + else + return *it; +} + // Used for tagged union matching heuristic, returns first singleton type field static std::optional> getTableMatchTag(TypeId type) { @@ -620,7 +631,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool // Ok. Do nothing. forall functions F, F <: function } - else if (FFlag::LuauNegatedTableTypes && isPrim(superTy, PrimitiveType::Table) && (get(subTy) || get(subTy))) + else if (isPrim(superTy, PrimitiveType::Table) && (get(subTy) || get(subTy))) { // Ok, do nothing: forall tables T, T <: table } @@ -1183,81 +1194,59 @@ void Unifier::tryUnifyNormalizedTypes( if (!get(superNorm.errors)) return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); - if (FFlag::LuauNegatedClassTypes) + for (const auto& [subClass, _] : subNorm.classes.classes) { - for (const auto& [subClass, _] : subNorm.classes.classes) + bool found = false; + const ClassType* subCtv = get(subClass); + LUAU_ASSERT(subCtv); + + for (const auto& [superClass, superNegations] : superNorm.classes.classes) { - bool found = false; - const ClassType* subCtv = get(subClass); - LUAU_ASSERT(subCtv); + const ClassType* superCtv = get(superClass); + LUAU_ASSERT(superCtv); - for (const auto& [superClass, superNegations] : superNorm.classes.classes) + if (isSubclass(subCtv, superCtv)) { - const ClassType* superCtv = get(superClass); - LUAU_ASSERT(superCtv); - - if (isSubclass(subCtv, superCtv)) - { - found = true; - - for (TypeId negation : superNegations) - { - const ClassType* negationCtv = get(negation); - LUAU_ASSERT(negationCtv); - - if (isSubclass(subCtv, negationCtv)) - { - found = false; - break; - } - } - - if (found) - break; - } - } + found = true; - if (FFlag::DebugLuauDeferredConstraintResolution) - { - for (TypeId superTable : superNorm.tables) + for (TypeId negation : superNegations) { - Unifier innerState = makeChildUnifier(); - innerState.tryUnify(subClass, superTable); + const ClassType* negationCtv = get(negation); + LUAU_ASSERT(negationCtv); - if (innerState.errors.empty()) + if (isSubclass(subCtv, negationCtv)) { - found = true; - log.concat(std::move(innerState.log)); + found = false; break; } - else if (auto e = hasUnificationTooComplex(innerState.errors)) - return reportError(*e); } - } - if (!found) - { - return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); + if (found) + break; } } - } - else - { - for (TypeId subClass : subNorm.DEPRECATED_classes) + + if (FFlag::DebugLuauDeferredConstraintResolution) { - bool found = false; - const ClassType* subCtv = get(subClass); - for (TypeId superClass : superNorm.DEPRECATED_classes) + for (TypeId superTable : superNorm.tables) { - const ClassType* superCtv = get(superClass); - if (isSubclass(subCtv, superCtv)) + Unifier innerState = makeChildUnifier(); + innerState.tryUnify(subClass, superTable); + + if (innerState.errors.empty()) { found = true; + log.concat(std::move(innerState.log)); break; } + else if (auto e = hasUnificationTooComplex(innerState.errors)) + return reportError(*e); } - if (!found) - return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); + } + + if (!found) + { + return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); } } @@ -1266,7 +1255,7 @@ void Unifier::tryUnifyNormalizedTypes( bool found = false; for (TypeId superTable : superNorm.tables) { - if (FFlag::LuauNegatedTableTypes && isPrim(superTable, PrimitiveType::Table)) + if (isPrim(superTable, PrimitiveType::Table)) { found = true; break; diff --git a/Ast/include/Luau/ParseOptions.h b/Ast/include/Luau/ParseOptions.h index 89e79528b..01f2a74fa 100644 --- a/Ast/include/Luau/ParseOptions.h +++ b/Ast/include/Luau/ParseOptions.h @@ -14,8 +14,6 @@ enum class Mode struct ParseOptions { - bool allowTypeAnnotations = true; - bool supportContinueStatement = true; bool allowDeclarationSyntax = false; bool captureComments = false; }; diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 6a76eda22..7cae609d1 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -14,8 +14,6 @@ LUAU_FASTINTVARIABLE(LuauRecursionLimit, 1000) LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100) -LUAU_FASTFLAGVARIABLE(LuauParserErrorsOnMissingDefaultTypePackArgument, false) - #define ERROR_INVALID_INTERP_DOUBLE_BRACE "Double braces are not permitted within interpolated strings. Did you mean '\\{'?" namespace Luau @@ -327,22 +325,19 @@ AstStat* Parser::parseStat() // we know this isn't a call or an assignment; therefore it must be a context-sensitive keyword such as `type` or `continue` AstName ident = getIdentifier(expr); - if (options.allowTypeAnnotations) - { - if (ident == "type") - return parseTypeAlias(expr->location, /* exported= */ false); + if (ident == "type") + return parseTypeAlias(expr->location, /* exported= */ false); - if (ident == "export" && lexer.current().type == Lexeme::Name && AstName(lexer.current().name) == "type") - { - nextLexeme(); - return parseTypeAlias(expr->location, /* exported= */ true); - } + if (ident == "export" && lexer.current().type == Lexeme::Name && AstName(lexer.current().name) == "type") + { + nextLexeme(); + return parseTypeAlias(expr->location, /* exported= */ true); } - if (options.supportContinueStatement && ident == "continue") + if (ident == "continue") return parseContinue(expr->location); - if (options.allowTypeAnnotations && options.allowDeclarationSyntax) + if (options.allowDeclarationSyntax) { if (ident == "declare") return parseDeclaration(expr->location); @@ -1123,7 +1118,7 @@ std::tuple Parser::parseBindingList(TempVector& result, TempVector Parser::parseOptionalReturnType() { - if (options.allowTypeAnnotations && (lexer.current().type == ':' || lexer.current().type == Lexeme::SkinnyArrow)) + if (lexer.current().type == ':' || lexer.current().type == Lexeme::SkinnyArrow) { if (lexer.current().type == Lexeme::SkinnyArrow) report(lexer.current().location, "Function return type annotations are written after ':' instead of '->'"); @@ -2056,7 +2051,7 @@ AstExpr* Parser::parseAssertionExpr() Location start = lexer.current().location; AstExpr* expr = parseSimpleExpr(); - if (options.allowTypeAnnotations && lexer.current().type == Lexeme::DoubleColon) + if (lexer.current().type == Lexeme::DoubleColon) { nextLexeme(); AstType* annotation = parseType(); @@ -2449,24 +2444,13 @@ std::pair, AstArray> Parser::parseG seenDefault = true; nextLexeme(); - Lexeme packBegin = lexer.current(); - if (shouldParseTypePack(lexer)) { AstTypePack* typePack = parseTypePack(); namePacks.push_back({name, nameLocation, typePack}); } - else if (!FFlag::LuauParserErrorsOnMissingDefaultTypePackArgument && lexer.current().type == '(') - { - auto [type, typePack] = parseTypeOrPack(); - - if (type) - report(Location(packBegin.location.begin, lexer.previousLocation().end), "Expected type pack after '=', got type"); - - namePacks.push_back({name, nameLocation, typePack}); - } - else if (FFlag::LuauParserErrorsOnMissingDefaultTypePackArgument) + else { auto [type, typePack] = parseTypeOrPack(); diff --git a/CLI/Analyze.cpp b/CLI/Analyze.cpp index 6d1f54514..50fef7fc6 100644 --- a/CLI/Analyze.cpp +++ b/CLI/Analyze.cpp @@ -9,6 +9,13 @@ #include "FileUtils.h" #include "Flags.h" +#include +#include +#include +#include +#include +#include + #ifdef CALLGRIND #include #endif @@ -64,26 +71,29 @@ static void reportWarning(ReportFormat format, const char* name, const Luau::Lin report(format, name, warning.location, Luau::LintWarning::getName(warning.code), warning.text.c_str()); } -static bool analyzeFile(Luau::Frontend& frontend, const char* name, ReportFormat format, bool annotate) +static bool reportModuleResult(Luau::Frontend& frontend, const Luau::ModuleName& name, ReportFormat format, bool annotate) { - Luau::CheckResult cr; + std::optional cr = frontend.getCheckResult(name, false); - if (frontend.isDirty(name)) - cr = frontend.check(name); + if (!cr) + { + fprintf(stderr, "Failed to find result for %s\n", name.c_str()); + return false; + } if (!frontend.getSourceModule(name)) { - fprintf(stderr, "Error opening %s\n", name); + fprintf(stderr, "Error opening %s\n", name.c_str()); return false; } - for (auto& error : cr.errors) + for (auto& error : cr->errors) reportError(frontend, format, error); std::string humanReadableName = frontend.fileResolver->getHumanReadableModuleName(name); - for (auto& error : cr.lintResult.errors) + for (auto& error : cr->lintResult.errors) reportWarning(format, humanReadableName.c_str(), error); - for (auto& warning : cr.lintResult.warnings) + for (auto& warning : cr->lintResult.warnings) reportWarning(format, humanReadableName.c_str(), warning); if (annotate) @@ -98,7 +108,7 @@ static bool analyzeFile(Luau::Frontend& frontend, const char* name, ReportFormat printf("%s", annotated.c_str()); } - return cr.errors.empty() && cr.lintResult.errors.empty(); + return cr->errors.empty() && cr->lintResult.errors.empty(); } static void displayHelp(const char* argv0) @@ -216,6 +226,70 @@ struct CliConfigResolver : Luau::ConfigResolver } }; +struct TaskScheduler +{ + TaskScheduler(unsigned threadCount) + : threadCount(threadCount) + { + for (unsigned i = 0; i < threadCount; i++) + { + workers.emplace_back([this] { + workerFunction(); + }); + } + } + + ~TaskScheduler() + { + for (unsigned i = 0; i < threadCount; i++) + push({}); + + for (std::thread& worker : workers) + worker.join(); + } + + std::function pop() + { + std::unique_lock guard(mtx); + + cv.wait(guard, [this] { + return !tasks.empty(); + }); + + std::function task = tasks.front(); + tasks.pop(); + return task; + } + + void push(std::function task) + { + { + std::unique_lock guard(mtx); + tasks.push(std::move(task)); + } + + cv.notify_one(); + } + + static unsigned getThreadCount() + { + return std::max(std::thread::hardware_concurrency(), 1u); + } + +private: + void workerFunction() + { + while (std::function task = pop()) + task(); + } + + unsigned threadCount = 1; + std::mutex mtx; + std::condition_variable cv; + std::vector workers; + std::queue> tasks; +}; + int main(int argc, char** argv) { Luau::assertHandler() = assertionHandler; @@ -231,6 +305,7 @@ int main(int argc, char** argv) ReportFormat format = ReportFormat::Default; Luau::Mode mode = Luau::Mode::Nonstrict; bool annotate = false; + int threadCount = 0; for (int i = 1; i < argc; ++i) { @@ -249,6 +324,8 @@ int main(int argc, char** argv) FFlag::DebugLuauTimeTracing.value = true; else if (strncmp(argv[i], "--fflags=", 9) == 0) setLuauFlags(argv[i] + 9); + else if (strncmp(argv[i], "-j", 2) == 0) + threadCount = strtol(argv[i] + 2, nullptr, 10); } #if !defined(LUAU_ENABLE_TIME_TRACE) @@ -276,10 +353,28 @@ int main(int argc, char** argv) std::vector files = getSourceFiles(argc, argv); + for (const std::string& path : files) + frontend.queueModuleCheck(path); + + std::vector checkedModules; + + // If thread count is not set, try to use HW thread count, but with an upper limit + // When we improve scalability of typechecking, upper limit can be adjusted/removed + if (threadCount <= 0) + threadCount = std::min(TaskScheduler::getThreadCount(), 8u); + + { + TaskScheduler scheduler(threadCount); + + checkedModules = frontend.checkQueuedModules(std::nullopt, [&](std::function f) { + scheduler.push(std::move(f)); + }); + } + int failed = 0; - for (const std::string& path : files) - failed += !analyzeFile(frontend, path.c_str(), format, annotate); + for (const Luau::ModuleName& name : checkedModules) + failed += !reportModuleResult(frontend, name, format, annotate); if (!configResolver.configErrors.empty()) { diff --git a/CLI/Ast.cpp b/CLI/Ast.cpp index 99c583936..b5a922aaa 100644 --- a/CLI/Ast.cpp +++ b/CLI/Ast.cpp @@ -64,8 +64,6 @@ int main(int argc, char** argv) Luau::ParseOptions options; options.captureComments = true; - options.supportContinueStatement = true; - options.allowTypeAnnotations = true; options.allowDeclarationSyntax = true; Luau::ParseResult parseResult = Luau::Parser::parse(source.data(), source.size(), names, allocator, options); diff --git a/CodeGen/include/Luau/IrAnalysis.h b/CodeGen/include/Luau/IrAnalysis.h index 75b4940a6..5418009a8 100644 --- a/CodeGen/include/Luau/IrAnalysis.h +++ b/CodeGen/include/Luau/IrAnalysis.h @@ -35,6 +35,8 @@ struct RegisterSet uint8_t varargStart = 0; }; +void requireVariadicSequence(RegisterSet& sourceRs, const RegisterSet& defRs, uint8_t varargStart); + struct CfgInfo { std::vector predecessors; @@ -43,10 +45,22 @@ struct CfgInfo std::vector successors; std::vector successorsOffsets; + // VM registers that are live when the block is entered + // Additionally, an active variadic sequence can exist at the entry of the block std::vector in; + + // VM registers that are defined inside the block + // It can also contain a variadic sequence definition if that hasn't been consumed inside the block + // Note that this means that checking 'def' set might not be enough to say that register has not been written to std::vector def; + + // VM registers that are coming out from the block + // These might be registers that are defined inside the block or have been defined at the entry of the block + // Additionally, an active variadic sequence can exist at the exit of the block std::vector out; + // VM registers captured by nested closures + // This set can never have an active variadic sequence RegisterSet captured; }; diff --git a/CodeGen/include/Luau/IrData.h b/CodeGen/include/Luau/IrData.h index addd18f6b..4bc9c8237 100644 --- a/CodeGen/include/Luau/IrData.h +++ b/CodeGen/include/Luau/IrData.h @@ -575,7 +575,7 @@ enum class IrCmd : uint8_t // Calls native libm function with 1 or 2 arguments // A: builtin function ID // B: double - // C: double (optional, 2nd argument) + // C: double/int (optional, 2nd argument) INVOKE_LIBM, }; diff --git a/CodeGen/include/Luau/IrDump.h b/CodeGen/include/Luau/IrDump.h index 1bc31d9d7..179edd0de 100644 --- a/CodeGen/include/Luau/IrDump.h +++ b/CodeGen/include/Luau/IrDump.h @@ -30,7 +30,7 @@ void toString(IrToStringContext& ctx, IrOp op); void toString(std::string& result, IrConst constant); -void toStringDetailed(IrToStringContext& ctx, const IrInst& inst, uint32_t index, bool includeUseInfo); +void toStringDetailed(IrToStringContext& ctx, const IrBlock& block, uint32_t blockIdx, const IrInst& inst, uint32_t instIdx, bool includeUseInfo); void toStringDetailed(IrToStringContext& ctx, const IrBlock& block, uint32_t index, bool includeUseInfo); // Block title std::string toString(const IrFunction& function, bool includeUseInfo); diff --git a/CodeGen/include/Luau/IrUtils.h b/CodeGen/include/Luau/IrUtils.h index 3cf18cd48..a1211d46a 100644 --- a/CodeGen/include/Luau/IrUtils.h +++ b/CodeGen/include/Luau/IrUtils.h @@ -114,6 +114,28 @@ inline bool isBlockTerminator(IrCmd cmd) return false; } +inline bool isNonTerminatingJump(IrCmd cmd) +{ + switch (cmd) + { + case IrCmd::TRY_NUM_TO_INDEX: + case IrCmd::TRY_CALL_FASTGETTM: + case IrCmd::CHECK_FASTCALL_RES: + case IrCmd::CHECK_TAG: + case IrCmd::CHECK_READONLY: + case IrCmd::CHECK_NO_METATABLE: + case IrCmd::CHECK_SAFE_ENV: + case IrCmd::CHECK_ARRAY_SIZE: + case IrCmd::CHECK_SLOT_MATCH: + case IrCmd::CHECK_NODE_NO_NEXT: + return true; + default: + break; + } + + return false; +} + inline bool hasResult(IrCmd cmd) { switch (cmd) diff --git a/CodeGen/include/Luau/UnwindBuilder.h b/CodeGen/include/Luau/UnwindBuilder.h index 8fe55ba61..8a44629fe 100644 --- a/CodeGen/include/Luau/UnwindBuilder.h +++ b/CodeGen/include/Luau/UnwindBuilder.h @@ -1,8 +1,11 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include "Luau/RegisterA64.h" #include "Luau/RegisterX64.h" +#include + #include #include @@ -17,22 +20,36 @@ static uint32_t kFullBlockFuncton = ~0u; class UnwindBuilder { public: + enum Arch + { + X64, + A64 + }; + virtual ~UnwindBuilder() = default; virtual void setBeginOffset(size_t beginOffset) = 0; virtual size_t getBeginOffset() const = 0; - virtual void startInfo() = 0; - + virtual void startInfo(Arch arch) = 0; virtual void startFunction() = 0; - virtual void spill(int espOffset, X64::RegisterX64 reg) = 0; - virtual void save(X64::RegisterX64 reg) = 0; - virtual void allocStack(int size) = 0; - virtual void setupFrameReg(X64::RegisterX64 reg, int espOffset) = 0; virtual void finishFunction(uint32_t beginOffset, uint32_t endOffset) = 0; - virtual void finishInfo() = 0; + // A64-specific; prologue must look like this: + // sub sp, sp, stackSize + // store sequence that saves regs to [sp..sp+regs.size*8) in the order specified in regs; regs should start with x29, x30 (fp, lr) + // mov x29, sp + virtual void prologueA64(uint32_t prologueSize, uint32_t stackSize, std::initializer_list regs) = 0; + + // X64-specific; prologue must look like this: + // optional, indicated by setupFrame: + // push rbp + // mov rbp, rsp + // push reg in the order specified in regs + // sub rsp, stackSize + virtual void prologueX64(uint32_t prologueSize, uint32_t stackSize, bool setupFrame, std::initializer_list regs) = 0; + virtual size_t getSize() const = 0; virtual size_t getFunctionCount() const = 0; diff --git a/CodeGen/include/Luau/UnwindBuilderDwarf2.h b/CodeGen/include/Luau/UnwindBuilderDwarf2.h index 9f862d23f..66749bfc0 100644 --- a/CodeGen/include/Luau/UnwindBuilderDwarf2.h +++ b/CodeGen/include/Luau/UnwindBuilderDwarf2.h @@ -24,17 +24,14 @@ class UnwindBuilderDwarf2 : public UnwindBuilder void setBeginOffset(size_t beginOffset) override; size_t getBeginOffset() const override; - void startInfo() override; - + void startInfo(Arch arch) override; void startFunction() override; - void spill(int espOffset, X64::RegisterX64 reg) override; - void save(X64::RegisterX64 reg) override; - void allocStack(int size) override; - void setupFrameReg(X64::RegisterX64 reg, int espOffset) override; void finishFunction(uint32_t beginOffset, uint32_t endOffset) override; - void finishInfo() override; + void prologueA64(uint32_t prologueSize, uint32_t stackSize, std::initializer_list regs) override; + void prologueX64(uint32_t prologueSize, uint32_t stackSize, bool setupFrame, std::initializer_list regs) override; + size_t getSize() const override; size_t getFunctionCount() const override; @@ -49,8 +46,6 @@ class UnwindBuilderDwarf2 : public UnwindBuilder uint8_t rawData[kRawDataLimit]; uint8_t* pos = rawData; - uint32_t stackOffset = 0; - // We will remember the FDE location to write some of the fields like entry length, function start and size later uint8_t* fdeEntryStart = nullptr; }; diff --git a/CodeGen/include/Luau/UnwindBuilderWin.h b/CodeGen/include/Luau/UnwindBuilderWin.h index ccd7125d7..5afed6938 100644 --- a/CodeGen/include/Luau/UnwindBuilderWin.h +++ b/CodeGen/include/Luau/UnwindBuilderWin.h @@ -44,17 +44,14 @@ class UnwindBuilderWin : public UnwindBuilder void setBeginOffset(size_t beginOffset) override; size_t getBeginOffset() const override; - void startInfo() override; - + void startInfo(Arch arch) override; void startFunction() override; - void spill(int espOffset, X64::RegisterX64 reg) override; - void save(X64::RegisterX64 reg) override; - void allocStack(int size) override; - void setupFrameReg(X64::RegisterX64 reg, int espOffset) override; void finishFunction(uint32_t beginOffset, uint32_t endOffset) override; - void finishInfo() override; + void prologueA64(uint32_t prologueSize, uint32_t stackSize, std::initializer_list regs) override; + void prologueX64(uint32_t prologueSize, uint32_t stackSize, bool setupFrame, std::initializer_list regs) override; + size_t getSize() const override; size_t getFunctionCount() const override; @@ -75,7 +72,6 @@ class UnwindBuilderWin : public UnwindBuilder uint8_t prologSize = 0; X64::RegisterX64 frameReg = X64::noreg; uint8_t frameRegOffset = 0; - uint32_t stackOffset = 0; }; } // namespace CodeGen diff --git a/CodeGen/src/CodeBlockUnwind.cpp b/CodeGen/src/CodeBlockUnwind.cpp index ccd15facb..9e338071b 100644 --- a/CodeGen/src/CodeBlockUnwind.cpp +++ b/CodeGen/src/CodeBlockUnwind.cpp @@ -22,12 +22,25 @@ extern "C" void __register_frame(const void*); extern "C" void __deregister_frame(const void*); +extern "C" void __unw_add_dynamic_fde() __attribute__((weak)); + #endif -#if defined(__APPLE__) -// On Mac, each FDE inside eh_frame section has to be handled separately +namespace Luau +{ +namespace CodeGen +{ + +#if !defined(_WIN32) static void visitFdeEntries(char* pos, void (*cb)(const void*)) { + // When using glibc++ unwinder, we need to call __register_frame/__deregister_frame on the entire .eh_frame data + // When using libc++ unwinder (libunwind), each FDE has to be handled separately + // libc++ unwinder is the macOS unwinder, but on Linux the unwinder depends on the library the executable is linked with + // __unw_add_dynamic_fde is specific to libc++ unwinder, as such we determine the library based on its existence + if (__unw_add_dynamic_fde == nullptr) + return cb(pos); + for (;;) { unsigned partLength; @@ -47,11 +60,6 @@ static void visitFdeEntries(char* pos, void (*cb)(const void*)) } #endif -namespace Luau -{ -namespace CodeGen -{ - void* createBlockUnwindInfo(void* context, uint8_t* block, size_t blockSize, size_t& beginOffset) { UnwindBuilder* unwind = (UnwindBuilder*)context; @@ -70,10 +78,8 @@ void* createBlockUnwindInfo(void* context, uint8_t* block, size_t blockSize, siz LUAU_ASSERT(!"failed to allocate function table"); return nullptr; } -#elif defined(__APPLE__) - visitFdeEntries(unwindData, __register_frame); #elif !defined(_WIN32) - __register_frame(unwindData); + visitFdeEntries(unwindData, __register_frame); #endif beginOffset = unwindSize + unwind->getBeginOffset(); @@ -85,10 +91,8 @@ void destroyBlockUnwindInfo(void* context, void* unwindData) #if defined(_WIN32) && defined(_M_X64) if (!RtlDeleteFunctionTable((RUNTIME_FUNCTION*)unwindData)) LUAU_ASSERT(!"failed to deallocate function table"); -#elif defined(__APPLE__) - visitFdeEntries((char*)unwindData, __deregister_frame); #elif !defined(_WIN32) - __deregister_frame(unwindData); + visitFdeEntries((char*)unwindData, __deregister_frame); #endif } diff --git a/CodeGen/src/CodeGen.cpp b/CodeGen/src/CodeGen.cpp index f0be5b3d8..ab092faac 100644 --- a/CodeGen/src/CodeGen.cpp +++ b/CodeGen/src/CodeGen.cpp @@ -134,7 +134,6 @@ static bool lowerImpl(AssemblyBuilder& build, IrLowering& lowering, IrFunction& for (size_t i = 0; i < sortedBlocks.size(); ++i) { uint32_t blockIndex = sortedBlocks[i]; - IrBlock& block = function.blocks[blockIndex]; if (block.kind == IrBlockKind::Dead) @@ -191,10 +190,13 @@ static bool lowerImpl(AssemblyBuilder& build, IrLowering& lowering, IrFunction& continue; } + // Either instruction result value is not referenced or the use count is not zero + LUAU_ASSERT(inst.lastUse == 0 || inst.useCount != 0); + if (options.includeIr) { build.logAppend("# "); - toStringDetailed(ctx, inst, index, /* includeUseInfo */ true); + toStringDetailed(ctx, block, blockIndex, inst, index, /* includeUseInfo */ true); } IrBlock& next = i + 1 < sortedBlocks.size() ? function.blocks[sortedBlocks[i + 1]] : dummy; @@ -409,9 +411,11 @@ bool isSupported() if (sizeof(LuaNode) != 32) return false; - // TODO: A64 codegen does not generate correct unwind info at the moment so it requires longjmp instead of C++ exceptions +#ifdef _WIN32 + // Unwind info is not supported for Windows-on-ARM yet if (!LUA_USE_LONGJMP) return false; +#endif return true; #else diff --git a/CodeGen/src/CodeGenA64.cpp b/CodeGen/src/CodeGenA64.cpp index 415cfc926..fbe44e23e 100644 --- a/CodeGen/src/CodeGenA64.cpp +++ b/CodeGen/src/CodeGenA64.cpp @@ -123,9 +123,6 @@ static EntryLocations buildEntryFunction(AssemblyBuilderA64& build, UnwindBuilde // Arguments: x0 = lua_State*, x1 = Proto*, x2 = native code pointer to jump to, x3 = NativeContext* locations.start = build.setLabel(); - unwind.startFunction(); - - unwind.allocStack(8); // TODO: this is just a hack to make UnwindBuilder assertions cooperate // prologue build.sub(sp, sp, kStackSize); @@ -140,6 +137,8 @@ static EntryLocations buildEntryFunction(AssemblyBuilderA64& build, UnwindBuilde locations.prologueEnd = build.setLabel(); + uint32_t prologueSize = build.getLabelOffset(locations.prologueEnd) - build.getLabelOffset(locations.start); + // Setup native execution environment build.mov(rState, x0); build.mov(rNativeContext, x3); @@ -168,6 +167,8 @@ static EntryLocations buildEntryFunction(AssemblyBuilderA64& build, UnwindBuilde build.ret(); // Our entry function is special, it spans the whole remaining code area + unwind.startFunction(); + unwind.prologueA64(prologueSize, kStackSize, {x29, x30, x19, x20, x21, x22, x23, x24}); unwind.finishFunction(build.getLabelOffset(locations.start), kFullBlockFuncton); return locations; @@ -178,7 +179,7 @@ bool initHeaderFunctions(NativeState& data) AssemblyBuilderA64 build(/* logText= */ false); UnwindBuilder& unwind = *data.unwindBuilder.get(); - unwind.startInfo(); + unwind.startInfo(UnwindBuilder::A64); EntryLocations entryLocations = buildEntryFunction(build, unwind); diff --git a/CodeGen/src/CodeGenX64.cpp b/CodeGen/src/CodeGenX64.cpp index 2acb69f96..5f2cd6147 100644 --- a/CodeGen/src/CodeGenX64.cpp +++ b/CodeGen/src/CodeGenX64.cpp @@ -58,43 +58,44 @@ static EntryLocations buildEntryFunction(AssemblyBuilderX64& build, UnwindBuilde unwind.startFunction(); // Save common non-volatile registers - build.push(rbp); - unwind.save(rbp); - if (build.abi == ABIX64::SystemV) { + // We need to use a standard rbp-based frame setup for debuggers to work with JIT code + build.push(rbp); build.mov(rbp, rsp); - unwind.setupFrameReg(rbp, 0); } build.push(rbx); - unwind.save(rbx); build.push(r12); - unwind.save(r12); build.push(r13); - unwind.save(r13); build.push(r14); - unwind.save(r14); build.push(r15); - unwind.save(r15); if (build.abi == ABIX64::Windows) { // Save non-volatile registers that are specific to Windows x64 ABI build.push(rdi); - unwind.save(rdi); build.push(rsi); - unwind.save(rsi); + + // On Windows, rbp is available as a general-purpose non-volatile register; we currently don't use it, but we need to push an even number + // of registers for stack alignment... + build.push(rbp); // TODO: once we start using non-volatile SIMD registers on Windows, we will save those here } // Allocate stack space (reg home area + local data) build.sub(rsp, kStackSize + kLocalsSize); - unwind.allocStack(kStackSize + kLocalsSize); locations.prologueEnd = build.setLabel(); + uint32_t prologueSize = build.getLabelOffset(locations.prologueEnd) - build.getLabelOffset(locations.start); + + if (build.abi == ABIX64::SystemV) + unwind.prologueX64(prologueSize, kStackSize + kLocalsSize, /* setupFrame= */ true, {rbx, r12, r13, r14, r15}); + else if (build.abi == ABIX64::Windows) + unwind.prologueX64(prologueSize, kStackSize + kLocalsSize, /* setupFrame= */ false, {rbx, r12, r13, r14, r15, rdi, rsi, rbp}); + // Setup native execution environment build.mov(rState, rArg1); build.mov(rNativeContext, rArg4); @@ -118,6 +119,7 @@ static EntryLocations buildEntryFunction(AssemblyBuilderX64& build, UnwindBuilde if (build.abi == ABIX64::Windows) { + build.pop(rbp); build.pop(rsi); build.pop(rdi); } @@ -127,7 +129,10 @@ static EntryLocations buildEntryFunction(AssemblyBuilderX64& build, UnwindBuilde build.pop(r13); build.pop(r12); build.pop(rbx); - build.pop(rbp); + + if (build.abi == ABIX64::SystemV) + build.pop(rbp); + build.ret(); // Our entry function is special, it spans the whole remaining code area @@ -141,7 +146,7 @@ bool initHeaderFunctions(NativeState& data) AssemblyBuilderX64 build(/* logText= */ false); UnwindBuilder& unwind = *data.unwindBuilder.get(); - unwind.startInfo(); + unwind.startInfo(UnwindBuilder::X64); EntryLocations entryLocations = buildEntryFunction(build, unwind); diff --git a/CodeGen/src/EmitBuiltinsX64.cpp b/CodeGen/src/EmitBuiltinsX64.cpp index af4c529a3..474dabf67 100644 --- a/CodeGen/src/EmitBuiltinsX64.cpp +++ b/CodeGen/src/EmitBuiltinsX64.cpp @@ -18,19 +18,6 @@ namespace CodeGen namespace X64 { -static void emitBuiltinMathLdexp(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int arg, OperandX64 arg2) -{ - ScopedRegX64 tmp{regs, SizeX64::qword}; - build.vcvttsd2si(tmp.reg, arg2); - - IrCallWrapperX64 callWrap(regs, build); - callWrap.addArgument(SizeX64::xmmword, luauRegValue(arg)); - callWrap.addArgument(SizeX64::qword, tmp); - callWrap.call(qword[rNativeContext + offsetof(NativeContext, libm_ldexp)]); - - build.vmovsd(luauRegValue(ra), xmm0); -} - static void emitBuiltinMathFrexp(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int arg, int nresults) { IrCallWrapperX64 callWrap(regs, build); @@ -115,9 +102,6 @@ void emitBuiltin(IrRegAllocX64& regs, AssemblyBuilderX64& build, int bfid, int r { switch (bfid) { - case LBF_MATH_LDEXP: - LUAU_ASSERT(nparams == 2 && nresults == 1); - return emitBuiltinMathLdexp(regs, build, ra, arg, arg2); case LBF_MATH_FREXP: LUAU_ASSERT(nparams == 1 && (nresults == 1 || nresults == 2)); return emitBuiltinMathFrexp(regs, build, ra, arg, nresults); diff --git a/CodeGen/src/IrAnalysis.cpp b/CodeGen/src/IrAnalysis.cpp index efe9fcc06..efcacb046 100644 --- a/CodeGen/src/IrAnalysis.cpp +++ b/CodeGen/src/IrAnalysis.cpp @@ -162,7 +162,7 @@ uint32_t getLiveOutValueCount(IrFunction& function, IrBlock& block) return getLiveInOutValueCount(function, block).second; } -static void requireVariadicSequence(RegisterSet& sourceRs, const RegisterSet& defRs, uint8_t varargStart) +void requireVariadicSequence(RegisterSet& sourceRs, const RegisterSet& defRs, uint8_t varargStart) { if (!defRs.varargSeq) { diff --git a/CodeGen/src/IrDump.cpp b/CodeGen/src/IrDump.cpp index 50c1848ea..062321ba6 100644 --- a/CodeGen/src/IrDump.cpp +++ b/CodeGen/src/IrDump.cpp @@ -62,6 +62,7 @@ static const char* getTagName(uint8_t tag) case LUA_TTHREAD: return "tthread"; default: + LUAU_ASSERT(!"Unknown type tag"); LUAU_UNREACHABLE(); } } @@ -410,27 +411,6 @@ void toString(std::string& result, IrConst constant) } } -void toStringDetailed(IrToStringContext& ctx, const IrInst& inst, uint32_t index, bool includeUseInfo) -{ - size_t start = ctx.result.size(); - - toString(ctx, inst, index); - - if (includeUseInfo) - { - padToDetailColumn(ctx.result, start); - - if (inst.useCount == 0 && hasSideEffects(inst.cmd)) - append(ctx.result, "; %%%u, has side-effects\n", index); - else - append(ctx.result, "; useCount: %d, lastUse: %%%u\n", inst.useCount, inst.lastUse); - } - else - { - ctx.result.append("\n"); - } -} - static void appendBlockSet(IrToStringContext& ctx, BlockIteratorWrapper blocks) { bool comma = false; @@ -470,6 +450,86 @@ static void appendRegisterSet(IrToStringContext& ctx, const RegisterSet& rs, con } } +static RegisterSet getJumpTargetExtraLiveIn(IrToStringContext& ctx, const IrBlock& block, uint32_t blockIdx, const IrInst& inst) +{ + RegisterSet extraRs; + + if (blockIdx >= ctx.cfg.in.size()) + return extraRs; + + const RegisterSet& defRs = ctx.cfg.in[blockIdx]; + + // Find first block argument, for guard instructions (isNonTerminatingJump), that's the first and only one + LUAU_ASSERT(isNonTerminatingJump(inst.cmd)); + IrOp op = inst.a; + + if (inst.b.kind == IrOpKind::Block) + op = inst.b; + else if (inst.c.kind == IrOpKind::Block) + op = inst.c; + else if (inst.d.kind == IrOpKind::Block) + op = inst.d; + else if (inst.e.kind == IrOpKind::Block) + op = inst.e; + else if (inst.f.kind == IrOpKind::Block) + op = inst.f; + + if (op.kind == IrOpKind::Block && op.index < ctx.cfg.in.size()) + { + const RegisterSet& inRs = ctx.cfg.in[op.index]; + + extraRs.regs = inRs.regs & ~defRs.regs; + + if (inRs.varargSeq) + requireVariadicSequence(extraRs, defRs, inRs.varargStart); + } + + return extraRs; +} + +void toStringDetailed(IrToStringContext& ctx, const IrBlock& block, uint32_t blockIdx, const IrInst& inst, uint32_t instIdx, bool includeUseInfo) +{ + size_t start = ctx.result.size(); + + toString(ctx, inst, instIdx); + + if (includeUseInfo) + { + padToDetailColumn(ctx.result, start); + + if (inst.useCount == 0 && hasSideEffects(inst.cmd)) + { + if (isNonTerminatingJump(inst.cmd)) + { + RegisterSet extraRs = getJumpTargetExtraLiveIn(ctx, block, blockIdx, inst); + + if (extraRs.regs.any() || extraRs.varargSeq) + { + append(ctx.result, "; %%%u, extra in: ", instIdx); + appendRegisterSet(ctx, extraRs, ", "); + ctx.result.append("\n"); + } + else + { + append(ctx.result, "; %%%u\n", instIdx); + } + } + else + { + append(ctx.result, "; %%%u\n", instIdx); + } + } + else + { + append(ctx.result, "; useCount: %d, lastUse: %%%u\n", inst.useCount, inst.lastUse); + } + } + else + { + ctx.result.append("\n"); + } +} + void toStringDetailed(IrToStringContext& ctx, const IrBlock& block, uint32_t index, bool includeUseInfo) { // Report captured registers for entry block @@ -581,7 +641,7 @@ std::string toString(const IrFunction& function, bool includeUseInfo) continue; append(ctx.result, " "); - toStringDetailed(ctx, inst, index, includeUseInfo); + toStringDetailed(ctx, block, uint32_t(i), inst, index, includeUseInfo); } append(ctx.result, "\n"); diff --git a/CodeGen/src/IrLoweringA64.cpp b/CodeGen/src/IrLoweringA64.cpp index 7fd684b4d..6dec80240 100644 --- a/CodeGen/src/IrLoweringA64.cpp +++ b/CodeGen/src/IrLoweringA64.cpp @@ -122,42 +122,6 @@ static bool emitBuiltin( { switch (bfid) { - case LBF_MATH_LDEXP: - LUAU_ASSERT(nparams == 2 && nresults == 1); - - if (args.kind == IrOpKind::VmReg) - { - build.ldr(d1, mem(rBase, args.index * sizeof(TValue) + offsetof(TValue, value.n))); - build.fcvtzs(w0, d1); - } - else if (args.kind == IrOpKind::VmConst) - { - size_t constantOffset = args.index * sizeof(TValue) + offsetof(TValue, value.n); - - // Note: cumulative offset is guaranteed to be divisible by 8 (since we're loading a double); we can use that to expand the useful range - // that doesn't require temporaries - if (constantOffset / 8 <= AddressA64::kMaxOffset) - { - build.ldr(d1, mem(rConstants, int(constantOffset))); - } - else - { - emitAddOffset(build, x0, rConstants, constantOffset); - build.ldr(d1, x0); - } - - build.fcvtzs(w0, d1); - } - else if (args.kind == IrOpKind::Constant) - build.mov(w0, int(function.doubleOp(args))); - else if (args.kind != IrOpKind::Undef) - LUAU_ASSERT(!"Unsupported instruction form"); - - build.ldr(d0, mem(rBase, arg * sizeof(TValue) + offsetof(TValue, value.n))); - build.ldr(x1, mem(rNativeContext, offsetof(NativeContext, libm_ldexp))); - build.blr(x1); - build.str(d0, mem(rBase, res * sizeof(TValue) + offsetof(TValue, value.n))); - return true; case LBF_MATH_FREXP: LUAU_ASSERT(nparams == 1 && (nresults == 1 || nresults == 2)); emitInvokeLibm1P(build, offsetof(NativeContext, libm_frexp), arg); @@ -1610,12 +1574,20 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) { if (inst.c.kind != IrOpKind::None) { + bool isInt = (inst.c.kind == IrOpKind::Constant) ? constOp(inst.c).kind == IrConstKind::Int + : getCmdValueKind(function.instOp(inst.c).cmd) == IrValueKind::Int; + RegisterA64 temp1 = tempDouble(inst.b); - RegisterA64 temp2 = tempDouble(inst.c); - RegisterA64 temp3 = regs.allocTemp(KindA64::d); // note: spill() frees all registers so we need to avoid alloc after spill + RegisterA64 temp2 = isInt ? tempInt(inst.c) : tempDouble(inst.c); + RegisterA64 temp3 = isInt ? noreg : regs.allocTemp(KindA64::d); // note: spill() frees all registers so we need to avoid alloc after spill regs.spill(build, index, {temp1, temp2}); - if (d0 != temp2) + if (isInt) + { + build.fmov(d0, temp1); + build.mov(w0, temp2); + } + else if (d0 != temp2) { build.fmov(d0, temp1); build.fmov(d1, temp2); @@ -1634,8 +1606,8 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.fmov(d0, temp1); } - build.ldr(x0, mem(rNativeContext, getNativeContextOffset(uintOp(inst.a)))); - build.blr(x0); + build.ldr(x1, mem(rNativeContext, getNativeContextOffset(uintOp(inst.a)))); + build.blr(x1); inst.regA64 = regs.takeReg(d0, index); break; } diff --git a/CodeGen/src/IrLoweringX64.cpp b/CodeGen/src/IrLoweringX64.cpp index bc617571b..8c1f2b044 100644 --- a/CodeGen/src/IrLoweringX64.cpp +++ b/CodeGen/src/IrLoweringX64.cpp @@ -1304,7 +1304,15 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) callWrap.addArgument(SizeX64::xmmword, memRegDoubleOp(inst.b), inst.b); if (inst.c.kind != IrOpKind::None) - callWrap.addArgument(SizeX64::xmmword, memRegDoubleOp(inst.c), inst.c); + { + bool isInt = (inst.c.kind == IrOpKind::Constant) ? constOp(inst.c).kind == IrConstKind::Int + : getCmdValueKind(function.instOp(inst.c).cmd) == IrValueKind::Int; + + if (isInt) + callWrap.addArgument(SizeX64::dword, memRegUintOp(inst.c), inst.c); + else + callWrap.addArgument(SizeX64::xmmword, memRegDoubleOp(inst.c), inst.c); + } callWrap.call(qword[rNativeContext + getNativeContextOffset(uintOp(inst.a))]); inst.regX64 = regs.takeReg(xmm0, index); diff --git a/CodeGen/src/IrTranslateBuiltins.cpp b/CodeGen/src/IrTranslateBuiltins.cpp index e58d0a126..cfa4bc6c1 100644 --- a/CodeGen/src/IrTranslateBuiltins.cpp +++ b/CodeGen/src/IrTranslateBuiltins.cpp @@ -71,8 +71,7 @@ static BuiltinImplResult translateBuiltinNumberToNumberLibm( return {BuiltinImplType::UsesFallback, 1}; } -// (number, number, ...) -> number -static BuiltinImplResult translateBuiltin2NumberToNumber( +static BuiltinImplResult translateBuiltin2NumberToNumberLibm( IrBuilder& build, LuauBuiltinFunction bfid, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) { if (nparams < 2 || nresults > 1) @@ -80,7 +79,13 @@ static BuiltinImplResult translateBuiltin2NumberToNumber( builtinCheckDouble(build, build.vmReg(arg), fallback); builtinCheckDouble(build, args, fallback); - build.inst(IrCmd::FASTCALL, build.constUint(bfid), build.vmReg(ra), build.vmReg(arg), args, build.constInt(2), build.constInt(1)); + + IrOp va = builtinLoadDouble(build, build.vmReg(arg)); + IrOp vb = builtinLoadDouble(build, args); + + IrOp res = build.inst(IrCmd::INVOKE_LIBM, build.constUint(bfid), va, vb); + + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), res); if (ra != arg) build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); @@ -88,7 +93,7 @@ static BuiltinImplResult translateBuiltin2NumberToNumber( return {BuiltinImplType::UsesFallback, 1}; } -static BuiltinImplResult translateBuiltin2NumberToNumberLibm( +static BuiltinImplResult translateBuiltinMathLdexp( IrBuilder& build, LuauBuiltinFunction bfid, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) { if (nparams < 2 || nresults > 1) @@ -100,7 +105,9 @@ static BuiltinImplResult translateBuiltin2NumberToNumberLibm( IrOp va = builtinLoadDouble(build, build.vmReg(arg)); IrOp vb = builtinLoadDouble(build, args); - IrOp res = build.inst(IrCmd::INVOKE_LIBM, build.constUint(bfid), va, vb); + IrOp vbi = build.inst(IrCmd::NUM_TO_INT, vb); + + IrOp res = build.inst(IrCmd::INVOKE_LIBM, build.constUint(bfid), va, vbi); build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), res); @@ -778,7 +785,7 @@ BuiltinImplResult translateBuiltin(IrBuilder& build, int bfid, int ra, int arg, case LBF_MATH_ATAN2: return translateBuiltin2NumberToNumberLibm(build, LuauBuiltinFunction(bfid), nparams, ra, arg, args, nresults, fallback); case LBF_MATH_LDEXP: - return translateBuiltin2NumberToNumber(build, LuauBuiltinFunction(bfid), nparams, ra, arg, args, nresults, fallback); + return translateBuiltinMathLdexp(build, LuauBuiltinFunction(bfid), nparams, ra, arg, args, nresults, fallback); case LBF_MATH_FREXP: case LBF_MATH_MODF: return translateBuiltinNumberTo2Number(build, LuauBuiltinFunction(bfid), nparams, ra, arg, args, nresults, fallback); diff --git a/CodeGen/src/IrUtils.cpp b/CodeGen/src/IrUtils.cpp index a3af43449..03a6c9c43 100644 --- a/CodeGen/src/IrUtils.cpp +++ b/CodeGen/src/IrUtils.cpp @@ -299,6 +299,9 @@ void replace(IrFunction& function, IrBlock& block, uint32_t instIdx, IrInst repl removeUse(function, inst.e); removeUse(function, inst.f); + // Inherit existing use count (last use is skipped as it will be defined later) + replacement.useCount = inst.useCount; + inst = replacement; // Removing the earlier extra reference, this might leave the block without users without marking it as dead @@ -775,6 +778,8 @@ uint32_t getNativeContextOffset(int bfid) return offsetof(NativeContext, libm_pow); case LBF_IR_MATH_LOG2: return offsetof(NativeContext, libm_log2); + case LBF_MATH_LDEXP: + return offsetof(NativeContext, libm_ldexp); default: LUAU_ASSERT(!"Unsupported bfid"); } diff --git a/CodeGen/src/OptimizeConstProp.cpp b/CodeGen/src/OptimizeConstProp.cpp index e7663666a..926ead3d7 100644 --- a/CodeGen/src/OptimizeConstProp.cpp +++ b/CodeGen/src/OptimizeConstProp.cpp @@ -290,6 +290,20 @@ struct ConstPropState valueMap[versionedVmRegLoad(loadCmd, storeInst.a)] = storeInst.b.index; } + void clear() + { + for (int i = 0; i <= maxReg; ++i) + regs[i] = RegisterInfo(); + + maxReg = 0; + + inSafeEnv = false; + checkedGc = false; + + instLink.clear(); + valueMap.clear(); + } + IrFunction& function; bool useValueNumbering = false; @@ -854,12 +868,11 @@ static void constPropInBlock(IrBuilder& build, IrBlock& block, ConstPropState& s state.valueMap.clear(); } -static void constPropInBlockChain(IrBuilder& build, std::vector& visited, IrBlock* block, bool useValueNumbering) +static void constPropInBlockChain(IrBuilder& build, std::vector& visited, IrBlock* block, ConstPropState& state) { IrFunction& function = build.function; - ConstPropState state{function}; - state.useValueNumbering = useValueNumbering; + state.clear(); while (block) { @@ -936,7 +949,7 @@ static std::vector collectDirectBlockJumpPath(IrFunction& function, st return path; } -static void tryCreateLinearBlock(IrBuilder& build, std::vector& visited, IrBlock& startingBlock, bool useValueNumbering) +static void tryCreateLinearBlock(IrBuilder& build, std::vector& visited, IrBlock& startingBlock, ConstPropState& state) { IrFunction& function = build.function; @@ -965,8 +978,9 @@ static void tryCreateLinearBlock(IrBuilder& build, std::vector& visited return; // Initialize state with the knowledge of our current block - ConstPropState state{function}; - state.useValueNumbering = useValueNumbering; + state.clear(); + + // TODO: using values from the first block can cause 'live out' of the linear block predecessor to not have all required registers constPropInBlock(build, startingBlock, state); // Veryfy that target hasn't changed @@ -981,10 +995,43 @@ static void tryCreateLinearBlock(IrBuilder& build, std::vector& visited replace(function, termInst.a, newBlock); - // Clone the collected path int our fresh block + // Clone the collected path into our fresh block for (uint32_t pathBlockIdx : path) build.clone(function.blocks[pathBlockIdx], /* removeCurrentTerminator */ true); + // If all live in/out data is defined aside from the new block, generate it + // Note that liveness information is not strictly correct after optimization passes and may need to be recomputed before next passes + // The information generated here is consistent with current state that could be outdated, but still useful in IR inspection + if (function.cfg.in.size() == newBlock.index) + { + LUAU_ASSERT(function.cfg.in.size() == function.cfg.out.size()); + LUAU_ASSERT(function.cfg.in.size() == function.cfg.def.size()); + + // Live in is the same as the input of the original first block + function.cfg.in.push_back(function.cfg.in[path.front()]); + + // Live out is the same as the result of the original last block + function.cfg.out.push_back(function.cfg.out[path.back()]); + + // Defs are tricky, registers are joined together, but variadic sequences can be consumed inside the block + function.cfg.def.push_back({}); + RegisterSet& def = function.cfg.def.back(); + + for (uint32_t pathBlockIdx : path) + { + const RegisterSet& pathDef = function.cfg.def[pathBlockIdx]; + + def.regs |= pathDef.regs; + + // Taking only the last defined variadic sequence if it's not consumed before before the end + if (pathDef.varargSeq && function.cfg.out.back().varargSeq) + { + def.varargSeq = true; + def.varargStart = pathDef.varargStart; + } + } + } + // Optimize our linear block IrBlock& linearBlock = function.blockOp(newBlock); constPropInBlock(build, linearBlock, state); @@ -994,6 +1041,9 @@ void constPropInBlockChains(IrBuilder& build, bool useValueNumbering) { IrFunction& function = build.function; + ConstPropState state{function}; + state.useValueNumbering = useValueNumbering; + std::vector visited(function.blocks.size(), false); for (IrBlock& block : function.blocks) @@ -1004,7 +1054,7 @@ void constPropInBlockChains(IrBuilder& build, bool useValueNumbering) if (visited[function.getBlockIndex(block)]) continue; - constPropInBlockChain(build, visited, &block, useValueNumbering); + constPropInBlockChain(build, visited, &block, state); } } @@ -1015,6 +1065,9 @@ void createLinearBlocks(IrBuilder& build, bool useValueNumbering) // new 'block' will only be reachable from a single one and all gathered information can be preserved. IrFunction& function = build.function; + ConstPropState state{function}; + state.useValueNumbering = useValueNumbering; + std::vector visited(function.blocks.size(), false); // This loop can create new 'linear' blocks, so index-based loop has to be used (and it intentionally won't reach those new blocks) @@ -1029,7 +1082,7 @@ void createLinearBlocks(IrBuilder& build, bool useValueNumbering) if (visited[function.getBlockIndex(block)]) continue; - tryCreateLinearBlock(build, visited, block, useValueNumbering); + tryCreateLinearBlock(build, visited, block, state); } } diff --git a/CodeGen/src/UnwindBuilderDwarf2.cpp b/CodeGen/src/UnwindBuilderDwarf2.cpp index a4be95ff4..e9df184d0 100644 --- a/CodeGen/src/UnwindBuilderDwarf2.cpp +++ b/CodeGen/src/UnwindBuilderDwarf2.cpp @@ -36,27 +36,25 @@ #define DW_CFA_lo_user 0x1c #define DW_CFA_hi_user 0x3f -// Register numbers for x64 (System V ABI, page 57, ch. 3.7, figure 3.36) -#define DW_REG_RAX 0 -#define DW_REG_RDX 1 -#define DW_REG_RCX 2 -#define DW_REG_RBX 3 -#define DW_REG_RSI 4 -#define DW_REG_RDI 5 -#define DW_REG_RBP 6 -#define DW_REG_RSP 7 -#define DW_REG_R8 8 -#define DW_REG_R9 9 -#define DW_REG_R10 10 -#define DW_REG_R11 11 -#define DW_REG_R12 12 -#define DW_REG_R13 13 -#define DW_REG_R14 14 -#define DW_REG_R15 15 -#define DW_REG_RA 16 - -const int regIndexToDwRegX64[16] = {DW_REG_RAX, DW_REG_RCX, DW_REG_RDX, DW_REG_RBX, DW_REG_RSP, DW_REG_RBP, DW_REG_RSI, DW_REG_RDI, DW_REG_R8, - DW_REG_R9, DW_REG_R10, DW_REG_R11, DW_REG_R12, DW_REG_R13, DW_REG_R14, DW_REG_R15}; +// Register numbers for X64 (System V ABI, page 57, ch. 3.7, figure 3.36) +#define DW_REG_X64_RAX 0 +#define DW_REG_X64_RDX 1 +#define DW_REG_X64_RCX 2 +#define DW_REG_X64_RBX 3 +#define DW_REG_X64_RSI 4 +#define DW_REG_X64_RDI 5 +#define DW_REG_X64_RBP 6 +#define DW_REG_X64_RSP 7 +#define DW_REG_X64_RA 16 + +// Register numbers for A64 (DWARF for the Arm 64-bit Architecture, ch. 4.1) +#define DW_REG_A64_FP 29 +#define DW_REG_A64_LR 30 +#define DW_REG_A64_SP 31 + +// X64 register mapping from real register index to DWARF2 (r8..r15 are mapped 1-1, but named registers aren't) +const int regIndexToDwRegX64[16] = {DW_REG_X64_RAX, DW_REG_X64_RCX, DW_REG_X64_RDX, DW_REG_X64_RBX, DW_REG_X64_RSP, DW_REG_X64_RBP, DW_REG_X64_RSI, + DW_REG_X64_RDI, 8, 9, 10, 11, 12, 13, 14, 15}; const int kCodeAlignFactor = 1; const int kDataAlignFactor = 8; @@ -85,7 +83,7 @@ static uint8_t* defineSavedRegisterLocation(uint8_t* pos, int dwReg, uint32_t st { LUAU_ASSERT(stackOffset % kDataAlignFactor == 0 && "stack offsets have to be measured in kDataAlignFactor units"); - if (dwReg <= 15) + if (dwReg <= 0x3f) { pos = writeu8(pos, DW_CFA_offset + dwReg); } @@ -99,8 +97,9 @@ static uint8_t* defineSavedRegisterLocation(uint8_t* pos, int dwReg, uint32_t st return pos; } -static uint8_t* advanceLocation(uint8_t* pos, uint8_t offset) +static uint8_t* advanceLocation(uint8_t* pos, unsigned int offset) { + LUAU_ASSERT(offset < 256); pos = writeu8(pos, DW_CFA_advance_loc1); pos = writeu8(pos, offset); return pos; @@ -132,8 +131,10 @@ size_t UnwindBuilderDwarf2::getBeginOffset() const return beginOffset; } -void UnwindBuilderDwarf2::startInfo() +void UnwindBuilderDwarf2::startInfo(Arch arch) { + LUAU_ASSERT(arch == A64 || arch == X64); + uint8_t* cieLength = pos; pos = writeu32(pos, 0); // Length (to be filled later) @@ -142,15 +143,24 @@ void UnwindBuilderDwarf2::startInfo() pos = writeu8(pos, 0); // CIE augmentation String "" + int ra = arch == A64 ? DW_REG_A64_LR : DW_REG_X64_RA; + pos = writeuleb128(pos, kCodeAlignFactor); // Code align factor pos = writeuleb128(pos, -kDataAlignFactor & 0x7f); // Data align factor of (as signed LEB128) - pos = writeu8(pos, DW_REG_RA); // Return address register + pos = writeu8(pos, ra); // Return address register // Optional CIE augmentation section (not present) - // Call frame instructions (common for all FDEs, of which we have 1) - pos = defineCfaExpression(pos, DW_REG_RSP, 8); // Define CFA to be the rsp + 8 - pos = defineSavedRegisterLocation(pos, DW_REG_RA, 8); // Define return address register (RA) to be located at CFA - 8 + // Call frame instructions (common for all FDEs) + if (arch == A64) + { + pos = defineCfaExpression(pos, DW_REG_A64_SP, 0); // Define CFA to be the sp + } + else + { + pos = defineCfaExpression(pos, DW_REG_X64_RSP, 8); // Define CFA to be the rsp + 8 + pos = defineSavedRegisterLocation(pos, DW_REG_X64_RA, 8); // Define return address register (RA) to be located at CFA - 8 + } pos = alignPosition(cieLength, pos); writeu32(cieLength, unsigned(pos - cieLength - 4)); // Length field itself is excluded from length @@ -165,8 +175,6 @@ void UnwindBuilderDwarf2::startFunction() func.fdeEntryStartPos = uint32_t(pos - rawData); unwindFunctions.push_back(func); - stackOffset = 8; // Return address was pushed by calling the function - fdeEntryStart = pos; // Will be written at the end pos = writeu32(pos, 0); // Length (to be filled later) pos = writeu32(pos, unsigned(pos - rawData)); // CIE pointer @@ -178,42 +186,11 @@ void UnwindBuilderDwarf2::startFunction() // Function call frame instructions to follow } -void UnwindBuilderDwarf2::spill(int espOffset, X64::RegisterX64 reg) -{ - pos = advanceLocation(pos, 5); // REX.W mov [rsp + imm8], reg -} - -void UnwindBuilderDwarf2::save(X64::RegisterX64 reg) -{ - stackOffset += 8; - pos = advanceLocation(pos, 2); // REX.W push reg - pos = defineCfaExpressionOffset(pos, stackOffset); - pos = defineSavedRegisterLocation(pos, regIndexToDwRegX64[reg.index], stackOffset); -} - -void UnwindBuilderDwarf2::allocStack(int size) -{ - stackOffset += size; - pos = advanceLocation(pos, 4); // REX.W sub rsp, imm8 - pos = defineCfaExpressionOffset(pos, stackOffset); -} - -void UnwindBuilderDwarf2::setupFrameReg(X64::RegisterX64 reg, int espOffset) -{ - if (espOffset != 0) - pos = advanceLocation(pos, 5); // REX.W lea rbp, [rsp + imm8] - else - pos = advanceLocation(pos, 3); // REX.W mov rbp, rsp - - // Cfa is based on rsp, so no additonal commands are required -} - void UnwindBuilderDwarf2::finishFunction(uint32_t beginOffset, uint32_t endOffset) { unwindFunctions.back().beginOffset = beginOffset; unwindFunctions.back().endOffset = endOffset; - LUAU_ASSERT(stackOffset % 16 == 0 && "stack has to be aligned to 16 bytes after prologue"); LUAU_ASSERT(fdeEntryStart != nullptr); pos = alignPosition(fdeEntryStart, pos); @@ -228,6 +205,69 @@ void UnwindBuilderDwarf2::finishInfo() LUAU_ASSERT(getSize() <= kRawDataLimit); } +void UnwindBuilderDwarf2::prologueA64(uint32_t prologueSize, uint32_t stackSize, std::initializer_list regs) +{ + LUAU_ASSERT(stackSize % 16 == 0); + LUAU_ASSERT(regs.size() >= 2 && regs.begin()[0] == A64::x29 && regs.begin()[1] == A64::x30); + LUAU_ASSERT(regs.size() * 8 <= stackSize); + + // sub sp, sp, stackSize + pos = advanceLocation(pos, 4); + pos = defineCfaExpressionOffset(pos, stackSize); + + // stp/str to store each register to stack in order + pos = advanceLocation(pos, prologueSize - 4); + + for (size_t i = 0; i < regs.size(); ++i) + { + LUAU_ASSERT(regs.begin()[i].kind == A64::KindA64::x); + pos = defineSavedRegisterLocation(pos, regs.begin()[i].index, stackSize - unsigned(i * 8)); + } +} + +void UnwindBuilderDwarf2::prologueX64(uint32_t prologueSize, uint32_t stackSize, bool setupFrame, std::initializer_list regs) +{ + LUAU_ASSERT(stackSize > 0 && stackSize <= 128 && stackSize % 8 == 0); + + unsigned int stackOffset = 8; // Return address was pushed by calling the function + unsigned int prologueOffset = 0; + + if (setupFrame) + { + // push rbp + stackOffset += 8; + prologueOffset += 2; + pos = advanceLocation(pos, 2); + pos = defineCfaExpressionOffset(pos, stackOffset); + pos = defineSavedRegisterLocation(pos, DW_REG_X64_RBP, stackOffset); + + // mov rbp, rsp + prologueOffset += 3; + pos = advanceLocation(pos, 3); + } + + // push reg + for (X64::RegisterX64 reg : regs) + { + LUAU_ASSERT(reg.size == X64::SizeX64::qword); + + stackOffset += 8; + prologueOffset += 2; + pos = advanceLocation(pos, 2); + pos = defineCfaExpressionOffset(pos, stackOffset); + pos = defineSavedRegisterLocation(pos, regIndexToDwRegX64[reg.index], stackOffset); + } + + // sub rsp, stackSize + stackOffset += stackSize; + prologueOffset += 4; + pos = advanceLocation(pos, 4); + pos = defineCfaExpressionOffset(pos, stackOffset); + + LUAU_ASSERT(stackOffset % 16 == 0); + LUAU_ASSERT(prologueOffset == prologueSize); +} + size_t UnwindBuilderDwarf2::getSize() const { return size_t(pos - rawData); @@ -244,14 +284,14 @@ void UnwindBuilderDwarf2::finalize(char* target, size_t offset, void* funcAddres for (const UnwindFunctionDwarf2& func : unwindFunctions) { - uint8_t* fdeEntryStart = (uint8_t*)target + func.fdeEntryStartPos; + uint8_t* fdeEntry = (uint8_t*)target + func.fdeEntryStartPos; - writeu64(fdeEntryStart + kFdeInitialLocationOffset, uintptr_t(funcAddress) + offset + func.beginOffset); + writeu64(fdeEntry + kFdeInitialLocationOffset, uintptr_t(funcAddress) + offset + func.beginOffset); if (func.endOffset == kFullBlockFuncton) - writeu64(fdeEntryStart + kFdeAddressRangeOffset, funcSize - offset); + writeu64(fdeEntry + kFdeAddressRangeOffset, funcSize - offset); else - writeu64(fdeEntryStart + kFdeAddressRangeOffset, func.endOffset - func.beginOffset); + writeu64(fdeEntry + kFdeAddressRangeOffset, func.endOffset - func.beginOffset); } } diff --git a/CodeGen/src/UnwindBuilderWin.cpp b/CodeGen/src/UnwindBuilderWin.cpp index 5f4f16a9a..f9b927c51 100644 --- a/CodeGen/src/UnwindBuilderWin.cpp +++ b/CodeGen/src/UnwindBuilderWin.cpp @@ -31,7 +31,10 @@ size_t UnwindBuilderWin::getBeginOffset() const return beginOffset; } -void UnwindBuilderWin::startInfo() {} +void UnwindBuilderWin::startInfo(Arch arch) +{ + LUAU_ASSERT(arch == X64); +} void UnwindBuilderWin::startFunction() { @@ -50,45 +53,6 @@ void UnwindBuilderWin::startFunction() // rax has register index 0, which in Windows unwind info means that frame register is not used frameReg = X64::rax; frameRegOffset = 0; - - // Return address was pushed by calling the function - stackOffset = 8; -} - -void UnwindBuilderWin::spill(int espOffset, X64::RegisterX64 reg) -{ - prologSize += 5; // REX.W mov [rsp + imm8], reg -} - -void UnwindBuilderWin::save(X64::RegisterX64 reg) -{ - prologSize += 2; // REX.W push reg - stackOffset += 8; - unwindCodes.push_back({prologSize, UWOP_PUSH_NONVOL, reg.index}); -} - -void UnwindBuilderWin::allocStack(int size) -{ - LUAU_ASSERT(size >= 8 && size <= 128 && size % 8 == 0); - - prologSize += 4; // REX.W sub rsp, imm8 - stackOffset += size; - unwindCodes.push_back({prologSize, UWOP_ALLOC_SMALL, uint8_t((size - 8) / 8)}); -} - -void UnwindBuilderWin::setupFrameReg(X64::RegisterX64 reg, int espOffset) -{ - LUAU_ASSERT(espOffset < 256 && espOffset % 16 == 0); - - frameReg = reg; - frameRegOffset = uint8_t(espOffset / 16); - - if (espOffset != 0) - prologSize += 5; // REX.W lea rbp, [rsp + imm8] - else - prologSize += 3; // REX.W mov rbp, rsp - - unwindCodes.push_back({prologSize, UWOP_SET_FPREG, frameRegOffset}); } void UnwindBuilderWin::finishFunction(uint32_t beginOffset, uint32_t endOffset) @@ -99,8 +63,6 @@ void UnwindBuilderWin::finishFunction(uint32_t beginOffset, uint32_t endOffset) // Windows unwind code count is stored in uint8_t, so we can't have more LUAU_ASSERT(unwindCodes.size() < 256); - LUAU_ASSERT(stackOffset % 16 == 0 && "stack has to be aligned to 16 bytes after prologue"); - UnwindInfoWin info; info.version = 1; info.flags = 0; // No EH @@ -142,6 +104,54 @@ void UnwindBuilderWin::finishFunction(uint32_t beginOffset, uint32_t endOffset) void UnwindBuilderWin::finishInfo() {} +void UnwindBuilderWin::prologueA64(uint32_t prologueSize, uint32_t stackSize, std::initializer_list regs) +{ + LUAU_ASSERT(!"Not implemented"); +} + +void UnwindBuilderWin::prologueX64(uint32_t prologueSize, uint32_t stackSize, bool setupFrame, std::initializer_list regs) +{ + LUAU_ASSERT(stackSize > 0 && stackSize <= 128 && stackSize % 8 == 0); + LUAU_ASSERT(prologueSize < 256); + + unsigned int stackOffset = 8; // Return address was pushed by calling the function + unsigned int prologueOffset = 0; + + if (setupFrame) + { + // push rbp + stackOffset += 8; + prologueOffset += 2; + unwindCodes.push_back({uint8_t(prologueOffset), UWOP_PUSH_NONVOL, X64::rbp.index}); + + // mov rbp, rsp + prologueOffset += 3; + frameReg = X64::rbp; + frameRegOffset = 0; + unwindCodes.push_back({uint8_t(prologueOffset), UWOP_SET_FPREG, frameRegOffset}); + } + + // push reg + for (X64::RegisterX64 reg : regs) + { + LUAU_ASSERT(reg.size == X64::SizeX64::qword); + + stackOffset += 8; + prologueOffset += 2; + unwindCodes.push_back({uint8_t(prologueOffset), UWOP_PUSH_NONVOL, reg.index}); + } + + // sub rsp, stackSize + stackOffset += stackSize; + prologueOffset += 4; + unwindCodes.push_back({uint8_t(prologueOffset), UWOP_ALLOC_SMALL, uint8_t((stackSize - 8) / 8)}); + + LUAU_ASSERT(stackOffset % 16 == 0); + LUAU_ASSERT(prologueOffset == prologueSize); + + this->prologSize = prologueSize; +} + size_t UnwindBuilderWin::getSize() const { return sizeof(UnwindFunctionWin) * unwindFunctions.size() + size_t(rawDataPos - rawData); diff --git a/Compiler/src/BytecodeBuilder.cpp b/Compiler/src/BytecodeBuilder.cpp index b5690acb3..e2b769ec6 100644 --- a/Compiler/src/BytecodeBuilder.cpp +++ b/Compiler/src/BytecodeBuilder.cpp @@ -1701,8 +1701,6 @@ void BytecodeBuilder::dumpConstant(std::string& result, int k) const formatAppend(result, "'%s'", func.dumpname.c_str()); break; } - default: - LUAU_UNREACHABLE(); } } diff --git a/VM/src/lvmexecute.cpp b/VM/src/lvmexecute.cpp index c8a184a17..b3edf2ba0 100644 --- a/VM/src/lvmexecute.cpp +++ b/VM/src/lvmexecute.cpp @@ -913,7 +913,9 @@ static void luau_execute(lua_State* L) // slow-path: not a function call if (LUAU_UNLIKELY(!ttisfunction(ra))) { - VM_PROTECT(luaV_tryfuncTM(L, ra)); + VM_PROTECT_PC(); // luaV_tryfuncTM may fail + + luaV_tryfuncTM(L, ra); argtop++; // __call adds an extra self } diff --git a/tests/CodeAllocator.test.cpp b/tests/CodeAllocator.test.cpp index 01deddd3f..51f216da6 100644 --- a/tests/CodeAllocator.test.cpp +++ b/tests/CodeAllocator.test.cpp @@ -135,20 +135,9 @@ TEST_CASE("WindowsUnwindCodesX64") UnwindBuilderWin unwind; - unwind.startInfo(); + unwind.startInfo(UnwindBuilder::X64); unwind.startFunction(); - unwind.spill(16, rdx); - unwind.spill(8, rcx); - unwind.save(rdi); - unwind.save(rsi); - unwind.save(rbx); - unwind.save(rbp); - unwind.save(r12); - unwind.save(r13); - unwind.save(r14); - unwind.save(r15); - unwind.allocStack(72); - unwind.setupFrameReg(rbp, 48); + unwind.prologueX64(/* prologueSize= */ 23, /* stackSize= */ 72, /* setupFrame= */ true, {rdi, rsi, rbx, r12, r13, r14, r15}); unwind.finishFunction(0x11223344, 0x55443322); unwind.finishInfo(); @@ -156,8 +145,8 @@ TEST_CASE("WindowsUnwindCodesX64") data.resize(unwind.getSize()); unwind.finalize(data.data(), 0, nullptr, 0); - std::vector expected{0x44, 0x33, 0x22, 0x11, 0x22, 0x33, 0x44, 0x55, 0x0c, 0x00, 0x00, 0x00, 0x01, 0x23, 0x0a, 0x35, 0x23, 0x33, 0x1e, - 0x82, 0x1a, 0xf0, 0x18, 0xe0, 0x16, 0xd0, 0x14, 0xc0, 0x12, 0x50, 0x10, 0x30, 0x0e, 0x60, 0x0c, 0x70}; + std::vector expected{0x44, 0x33, 0x22, 0x11, 0x22, 0x33, 0x44, 0x55, 0x0c, 0x00, 0x00, 0x00, 0x01, 0x17, 0x0a, 0x05, 0x17, 0x82, 0x13, + 0xf0, 0x11, 0xe0, 0x0f, 0xd0, 0x0d, 0xc0, 0x0b, 0x30, 0x09, 0x60, 0x07, 0x70, 0x05, 0x03, 0x02, 0x50}; REQUIRE(data.size() == expected.size()); CHECK(memcmp(data.data(), expected.data(), expected.size()) == 0); @@ -170,18 +159,9 @@ TEST_CASE("Dwarf2UnwindCodesX64") UnwindBuilderDwarf2 unwind; - unwind.startInfo(); + unwind.startInfo(UnwindBuilder::X64); unwind.startFunction(); - unwind.save(rdi); - unwind.save(rsi); - unwind.save(rbx); - unwind.save(rbp); - unwind.save(r12); - unwind.save(r13); - unwind.save(r14); - unwind.save(r15); - unwind.allocStack(72); - unwind.setupFrameReg(rbp, 48); + unwind.prologueX64(/* prologueSize= */ 23, /* stackSize= */ 72, /* setupFrame= */ true, {rdi, rsi, rbx, r12, r13, r14, r15}); unwind.finishFunction(0, 0); unwind.finishInfo(); @@ -189,11 +169,36 @@ TEST_CASE("Dwarf2UnwindCodesX64") data.resize(unwind.getSize()); unwind.finalize(data.data(), 0, nullptr, 0); - std::vector expected{0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x01, 0x78, 0x10, 0x0c, 0x07, 0x08, 0x05, 0x10, 0x01, + std::vector expected{0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x01, 0x78, 0x10, 0x0c, 0x07, 0x08, 0x90, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x4c, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x02, 0x0e, 0x10, 0x85, 0x02, 0x02, 0x02, 0x0e, 0x18, 0x84, 0x03, 0x02, 0x02, 0x0e, 0x20, 0x83, - 0x04, 0x02, 0x02, 0x0e, 0x28, 0x86, 0x05, 0x02, 0x02, 0x0e, 0x30, 0x8c, 0x06, 0x02, 0x02, 0x0e, 0x38, 0x8d, 0x07, 0x02, 0x02, 0x0e, 0x40, - 0x8e, 0x08, 0x02, 0x02, 0x0e, 0x48, 0x8f, 0x09, 0x02, 0x04, 0x0e, 0x90, 0x01, 0x02, 0x05, 0x00, 0x00, 0x00, 0x00, 0x00}; + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x02, 0x0e, 0x10, 0x86, 0x02, 0x02, 0x03, 0x02, 0x02, 0x0e, 0x18, 0x85, 0x03, 0x02, 0x02, 0x0e, + 0x20, 0x84, 0x04, 0x02, 0x02, 0x0e, 0x28, 0x83, 0x05, 0x02, 0x02, 0x0e, 0x30, 0x8c, 0x06, 0x02, 0x02, 0x0e, 0x38, 0x8d, 0x07, 0x02, 0x02, + 0x0e, 0x40, 0x8e, 0x08, 0x02, 0x02, 0x0e, 0x48, 0x8f, 0x09, 0x02, 0x04, 0x0e, 0x90, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00}; + + REQUIRE(data.size() == expected.size()); + CHECK(memcmp(data.data(), expected.data(), expected.size()) == 0); +} + +TEST_CASE("Dwarf2UnwindCodesA64") +{ + using namespace A64; + + UnwindBuilderDwarf2 unwind; + + unwind.startInfo(UnwindBuilder::A64); + unwind.startFunction(); + unwind.prologueA64(/* prologueSize= */ 28, /* stackSize= */ 64, {x29, x30, x19, x20, x21, x22, x23, x24}); + unwind.finishFunction(0, 32); + unwind.finishInfo(); + + std::vector data; + data.resize(unwind.getSize()); + unwind.finalize(data.data(), 0, nullptr, 0); + + std::vector expected{0x0c, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x01, 0x78, 0x1e, 0x0c, 0x1f, 0x00, 0x2c, 0x00, 0x00, + 0x00, 0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x04, + 0x0e, 0x40, 0x02, 0x18, 0x9d, 0x08, 0x9e, 0x07, 0x93, 0x06, 0x94, 0x05, 0x95, 0x04, 0x96, 0x03, 0x97, 0x02, 0x98, 0x01, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00}; REQUIRE(data.size() == expected.size()); CHECK(memcmp(data.data(), expected.data(), expected.size()) == 0); @@ -247,7 +252,7 @@ TEST_CASE("GeneratedCodeExecutionX64") CHECK(result == 210); } -void throwing(int64_t arg) +static void throwing(int64_t arg) { CHECK(arg == 25); @@ -266,27 +271,25 @@ TEST_CASE("GeneratedCodeExecutionWithThrowX64") std::unique_ptr unwind = std::make_unique(); #endif - unwind->startInfo(); + unwind->startInfo(UnwindBuilder::X64); Label functionBegin = build.setLabel(); unwind->startFunction(); // Prologue + build.push(rbp); + build.mov(rbp, rsp); build.push(rNonVol1); - unwind->save(rNonVol1); build.push(rNonVol2); - unwind->save(rNonVol2); - build.push(rbp); - unwind->save(rbp); int stackSize = 32; int localsSize = 16; build.sub(rsp, stackSize + localsSize); - unwind->allocStack(stackSize + localsSize); - build.lea(rbp, addr[rsp + stackSize]); - unwind->setupFrameReg(rbp, stackSize); + uint32_t prologueSize = build.setLabel().location; + + unwind->prologueX64(prologueSize, stackSize + localsSize, /* setupFrame= */ true, {rNonVol1, rNonVol2}); // Body build.mov(rNonVol1, rArg1); @@ -297,10 +300,10 @@ TEST_CASE("GeneratedCodeExecutionWithThrowX64") build.call(rNonVol2); // Epilogue - build.lea(rsp, addr[rbp + localsSize]); - build.pop(rbp); + build.add(rsp, stackSize + localsSize); build.pop(rNonVol2); build.pop(rNonVol1); + build.pop(rbp); build.ret(); unwind->finishFunction(build.getLabelOffset(functionBegin), ~0u); @@ -349,7 +352,7 @@ TEST_CASE("GeneratedCodeExecutionMultipleFunctionsWithThrowX64") std::unique_ptr unwind = std::make_unique(); #endif - unwind->startInfo(); + unwind->startInfo(UnwindBuilder::X64); Label start1; Label start2; @@ -360,21 +363,19 @@ TEST_CASE("GeneratedCodeExecutionMultipleFunctionsWithThrowX64") unwind->startFunction(); // Prologue + build.push(rbp); + build.mov(rbp, rsp); build.push(rNonVol1); - unwind->save(rNonVol1); build.push(rNonVol2); - unwind->save(rNonVol2); - build.push(rbp); - unwind->save(rbp); int stackSize = 32; int localsSize = 16; build.sub(rsp, stackSize + localsSize); - unwind->allocStack(stackSize + localsSize); - build.lea(rbp, addr[rsp + stackSize]); - unwind->setupFrameReg(rbp, stackSize); + uint32_t prologueSize = build.setLabel().location - start1.location; + + unwind->prologueX64(prologueSize, stackSize + localsSize, /* setupFrame= */ true, {rNonVol1, rNonVol2}); // Body build.mov(rNonVol1, rArg1); @@ -385,41 +386,35 @@ TEST_CASE("GeneratedCodeExecutionMultipleFunctionsWithThrowX64") build.call(rNonVol2); // Epilogue - build.lea(rsp, addr[rbp + localsSize]); - build.pop(rbp); + build.add(rsp, stackSize + localsSize); build.pop(rNonVol2); build.pop(rNonVol1); + build.pop(rbp); build.ret(); Label end1 = build.setLabel(); unwind->finishFunction(build.getLabelOffset(start1), build.getLabelOffset(end1)); } - // Second function with different layout + // Second function with different layout and no frame { build.setLabel(start2); unwind->startFunction(); // Prologue build.push(rNonVol1); - unwind->save(rNonVol1); build.push(rNonVol2); - unwind->save(rNonVol2); build.push(rNonVol3); - unwind->save(rNonVol3); build.push(rNonVol4); - unwind->save(rNonVol4); - build.push(rbp); - unwind->save(rbp); int stackSize = 32; - int localsSize = 32; + int localsSize = 24; build.sub(rsp, stackSize + localsSize); - unwind->allocStack(stackSize + localsSize); - build.lea(rbp, addr[rsp + stackSize]); - unwind->setupFrameReg(rbp, stackSize); + uint32_t prologueSize = build.setLabel().location - start2.location; + + unwind->prologueX64(prologueSize, stackSize + localsSize, /* setupFrame= */ false, {rNonVol1, rNonVol2, rNonVol3, rNonVol4}); // Body build.mov(rNonVol3, rArg1); @@ -430,8 +425,7 @@ TEST_CASE("GeneratedCodeExecutionMultipleFunctionsWithThrowX64") build.call(rNonVol4); // Epilogue - build.lea(rsp, addr[rbp + localsSize]); - build.pop(rbp); + build.add(rsp, stackSize + localsSize); build.pop(rNonVol4); build.pop(rNonVol3); build.pop(rNonVol2); @@ -495,37 +489,29 @@ TEST_CASE("GeneratedCodeExecutionWithThrowOutsideTheGateX64") std::unique_ptr unwind = std::make_unique(); #endif - unwind->startInfo(); + unwind->startInfo(UnwindBuilder::X64); Label functionBegin = build.setLabel(); unwind->startFunction(); // Prologue (some of these registers don't have to be saved, but we want to have a big prologue) + build.push(rbp); + build.mov(rbp, rsp); build.push(r10); - unwind->save(r10); build.push(r11); - unwind->save(r11); build.push(r12); - unwind->save(r12); build.push(r13); - unwind->save(r13); build.push(r14); - unwind->save(r14); build.push(r15); - unwind->save(r15); - build.push(rbp); - unwind->save(rbp); int stackSize = 64; int localsSize = 16; build.sub(rsp, stackSize + localsSize); - unwind->allocStack(stackSize + localsSize); - build.lea(rbp, addr[rsp + stackSize]); - unwind->setupFrameReg(rbp, stackSize); + uint32_t prologueSize = build.setLabel().location; - size_t prologueSize = build.setLabel().location; + unwind->prologueX64(prologueSize, stackSize + localsSize, /* setupFrame= */ true, {r10, r11, r12, r13, r14, r15}); // Body build.mov(rax, rArg1); @@ -535,14 +521,14 @@ TEST_CASE("GeneratedCodeExecutionWithThrowOutsideTheGateX64") Label returnOffset = build.setLabel(); // Epilogue - build.lea(rsp, addr[rbp + localsSize]); - build.pop(rbp); + build.add(rsp, stackSize + localsSize); build.pop(r15); build.pop(r14); build.pop(r13); build.pop(r12); build.pop(r11); build.pop(r10); + build.pop(rbp); build.ret(); unwind->finishFunction(build.getLabelOffset(functionBegin), ~0u); @@ -650,6 +636,78 @@ TEST_CASE("GeneratedCodeExecutionA64") CHECK(result == 42); } +static void throwing(int64_t arg) +{ + CHECK(arg == 25); + + throw std::runtime_error("testing"); +} + +TEST_CASE("GeneratedCodeExecutionWithThrowA64") +{ + using namespace A64; + + AssemblyBuilderA64 build(/* logText= */ false); + + std::unique_ptr unwind = std::make_unique(); + + unwind->startInfo(UnwindBuilder::A64); + + build.sub(sp, sp, 32); + build.stp(x29, x30, mem(sp)); + build.str(x28, mem(sp, 16)); + build.mov(x29, sp); + + Label prologueEnd = build.setLabel(); + + build.add(x0, x0, 15); + build.blr(x1); + + build.ldr(x28, mem(sp, 16)); + build.ldp(x29, x30, mem(sp)); + build.add(sp, sp, 32); + + build.ret(); + + Label functionEnd = build.setLabel(); + + unwind->startFunction(); + unwind->prologueA64(build.getLabelOffset(prologueEnd), 32, {x29, x30, x28}); + unwind->finishFunction(0, build.getLabelOffset(functionEnd)); + + build.finalize(); + + unwind->finishInfo(); + + size_t blockSize = 1024 * 1024; + size_t maxTotalSize = 1024 * 1024; + CodeAllocator allocator(blockSize, maxTotalSize); + + allocator.context = unwind.get(); + allocator.createBlockUnwindInfo = createBlockUnwindInfo; + allocator.destroyBlockUnwindInfo = destroyBlockUnwindInfo; + + uint8_t* nativeData; + size_t sizeNativeData; + uint8_t* nativeEntry; + REQUIRE(allocator.allocate(build.data.data(), build.data.size(), reinterpret_cast(build.code.data()), build.code.size() * 4, nativeData, + sizeNativeData, nativeEntry)); + REQUIRE(nativeEntry); + + using FunctionType = int64_t(int64_t, void (*)(int64_t)); + FunctionType* f = (FunctionType*)nativeEntry; + + // To simplify debugging, CHECK_THROWS_WITH_AS is not used here + try + { + f(10, throwing); + } + catch (const std::runtime_error& error) + { + CHECK(strcmp(error.what(), "testing") == 0); + } +} + #endif TEST_SUITE_END(); diff --git a/tests/IrBuilder.test.cpp b/tests/IrBuilder.test.cpp index f09f174a1..e1213b931 100644 --- a/tests/IrBuilder.test.cpp +++ b/tests/IrBuilder.test.cpp @@ -532,6 +532,30 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "Bit32Blocked") )"); } +TEST_CASE_FIXTURE(IrBuilderFixture, "ReplacementPreservesUses") +{ + IrOp block = build.block(IrBlockKind::Internal); + + build.beginBlock(block); + + IrOp unk = build.inst(IrCmd::LOAD_INT, build.vmReg(0)); + build.inst(IrCmd::STORE_INT, build.vmReg(8), build.inst(IrCmd::BITXOR_UINT, unk, build.constInt(~0u))); + + build.inst(IrCmd::RETURN, build.constUint(0)); + + updateUseCounts(build.function); + constantFold(); + + CHECK("\n" + toString(build.function, /* includeUseInfo */ true) == R"( +bb_0: ; useCount: 0 + %0 = LOAD_INT R0 ; useCount: 1, lastUse: %0 + %1 = BITNOT_UINT %0 ; useCount: 1, lastUse: %0 + STORE_INT R8, %1 ; %2 + RETURN 0u ; %3 + +)"); +} + TEST_CASE_FIXTURE(IrBuilderFixture, "NumericNan") { IrOp block = build.block(IrBlockKind::Internal); diff --git a/tests/Normalize.test.cpp b/tests/Normalize.test.cpp index 6552a24da..26b3b00d4 100644 --- a/tests/Normalize.test.cpp +++ b/tests/Normalize.test.cpp @@ -470,8 +470,6 @@ TEST_SUITE_END(); struct NormalizeFixture : Fixture { - ScopedFastFlag sff2{"LuauNegatedClassTypes", true}; - TypeArena arena; InternalErrorReporter iceHandler; UnifierSharedState unifierState{&iceHandler}; @@ -632,11 +630,6 @@ TEST_CASE_FIXTURE(NormalizeFixture, "union_function_and_top_function") TEST_CASE_FIXTURE(NormalizeFixture, "negated_function_is_anything_except_a_function") { - ScopedFastFlag sffs[] = { - {"LuauNegatedTableTypes", true}, - {"LuauNegatedClassTypes", true}, - }; - CHECK("(boolean | class | number | string | table | thread)?" == toString(normal(R"( Not )"))); @@ -649,11 +642,6 @@ TEST_CASE_FIXTURE(NormalizeFixture, "specific_functions_cannot_be_negated") TEST_CASE_FIXTURE(NormalizeFixture, "bare_negated_boolean") { - ScopedFastFlag sffs[] = { - {"LuauNegatedTableTypes", true}, - {"LuauNegatedClassTypes", true}, - }; - // TODO: We don't yet have a way to say number | string | thread | nil | Class | Table | Function CHECK("(class | function | number | string | table | thread)?" == toString(normal(R"( Not @@ -723,8 +711,6 @@ export type t0 = (((any)&({_:l0.t0,n0:t0,_G:any,}))&({_:any,}))&(((any)&({_:l0.t TEST_CASE_FIXTURE(NormalizeFixture, "unions_of_classes") { - ScopedFastFlag sff{"LuauNegatedClassTypes", true}; - createSomeClasses(&frontend); CHECK("Parent | Unrelated" == toString(normal("Parent | Unrelated"))); CHECK("Parent" == toString(normal("Parent | Child"))); @@ -733,8 +719,6 @@ TEST_CASE_FIXTURE(NormalizeFixture, "unions_of_classes") TEST_CASE_FIXTURE(NormalizeFixture, "intersections_of_classes") { - ScopedFastFlag sff{"LuauNegatedClassTypes", true}; - createSomeClasses(&frontend); CHECK("Child" == toString(normal("Parent & Child"))); CHECK("never" == toString(normal("Child & Unrelated"))); @@ -742,8 +726,6 @@ TEST_CASE_FIXTURE(NormalizeFixture, "intersections_of_classes") TEST_CASE_FIXTURE(NormalizeFixture, "narrow_union_of_classes_with_intersection") { - ScopedFastFlag sff{"LuauNegatedClassTypes", true}; - createSomeClasses(&frontend); CHECK("Child" == toString(normal("(Child | Unrelated) & Child"))); } @@ -764,11 +746,6 @@ TEST_CASE_FIXTURE(NormalizeFixture, "crazy_metatable") TEST_CASE_FIXTURE(NormalizeFixture, "negations_of_classes") { - ScopedFastFlag sffs[] = { - {"LuauNegatedTableTypes", true}, - {"LuauNegatedClassTypes", true}, - }; - createSomeClasses(&frontend); CHECK("(Parent & ~Child) | Unrelated" == toString(normal("(Parent & Not) | Unrelated"))); CHECK("((class & ~Child) | boolean | function | number | string | table | thread)?" == toString(normal("Not"))); @@ -781,24 +758,18 @@ TEST_CASE_FIXTURE(NormalizeFixture, "negations_of_classes") TEST_CASE_FIXTURE(NormalizeFixture, "classes_and_unknown") { - ScopedFastFlag sff{"LuauNegatedClassTypes", true}; - createSomeClasses(&frontend); CHECK("Parent" == toString(normal("Parent & unknown"))); } TEST_CASE_FIXTURE(NormalizeFixture, "classes_and_never") { - ScopedFastFlag sff{"LuauNegatedClassTypes", true}; - createSomeClasses(&frontend); CHECK("never" == toString(normal("Parent & never"))); } TEST_CASE_FIXTURE(NormalizeFixture, "top_table_type") { - ScopedFastFlag sff{"LuauNegatedTableTypes", true}; - CHECK("table" == toString(normal("{} | tbl"))); CHECK("{| |}" == toString(normal("{} & tbl"))); CHECK("never" == toString(normal("number & tbl"))); @@ -806,8 +777,6 @@ TEST_CASE_FIXTURE(NormalizeFixture, "top_table_type") TEST_CASE_FIXTURE(NormalizeFixture, "negations_of_tables") { - ScopedFastFlag sff{"LuauNegatedTableTypes", true}; - CHECK(nullptr == toNormalizedType("Not<{}>")); CHECK("(boolean | class | function | number | string | thread)?" == toString(normal("Not"))); CHECK("table" == toString(normal("Not>"))); diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index ef5aabbe3..1335b6f4e 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -112,14 +112,6 @@ TEST_CASE_FIXTURE(Fixture, "can_haz_annotations") REQUIRE(block != nullptr); } -TEST_CASE_FIXTURE(Fixture, "local_cannot_have_annotation_with_extensions_disabled") -{ - Luau::ParseOptions options; - options.allowTypeAnnotations = false; - - CHECK_THROWS_AS(parse("local foo: string = \"Hello Types!\"", options), std::exception); -} - TEST_CASE_FIXTURE(Fixture, "local_with_annotation") { AstStatBlock* block = parse(R"( @@ -150,14 +142,6 @@ TEST_CASE_FIXTURE(Fixture, "type_names_can_contain_dots") REQUIRE(block != nullptr); } -TEST_CASE_FIXTURE(Fixture, "functions_cannot_have_return_annotations_if_extensions_are_disabled") -{ - Luau::ParseOptions options; - options.allowTypeAnnotations = false; - - CHECK_THROWS_AS(parse("function foo(): number return 55 end", options), std::exception); -} - TEST_CASE_FIXTURE(Fixture, "functions_can_have_return_annotations") { AstStatBlock* block = parse(R"( @@ -395,14 +379,6 @@ TEST_CASE_FIXTURE(Fixture, "return_type_is_an_intersection_type_if_led_with_one_ CHECK(returnAnnotation->types.data[1]->as()); } -TEST_CASE_FIXTURE(Fixture, "illegal_type_alias_if_extensions_are_disabled") -{ - Luau::ParseOptions options; - options.allowTypeAnnotations = false; - - CHECK_THROWS_AS(parse("type A = number", options), std::exception); -} - TEST_CASE_FIXTURE(Fixture, "type_alias_to_a_typeof") { AstStatBlock* block = parse(R"( @@ -2837,8 +2813,6 @@ TEST_CASE_FIXTURE(Fixture, "get_a_nice_error_when_there_is_no_comma_after_last_t TEST_CASE_FIXTURE(Fixture, "missing_default_type_pack_argument_after_variadic_type_parameter") { - ScopedFastFlag sff{"LuauParserErrorsOnMissingDefaultTypePackArgument", true}; - ParseResult result = tryParse(R"( type Foo = nil )"); diff --git a/tests/TypeInfer.anyerror.test.cpp b/tests/TypeInfer.anyerror.test.cpp index 687bc766d..0f255f08c 100644 --- a/tests/TypeInfer.anyerror.test.cpp +++ b/tests/TypeInfer.anyerror.test.cpp @@ -108,7 +108,10 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_error2") end )"); - LUAU_REQUIRE_ERROR_COUNT(1, result); + if (FFlag::DebugLuauDeferredConstraintResolution) + LUAU_REQUIRE_ERROR_COUNT(2, result); + else + LUAU_REQUIRE_ERROR_COUNT(1, result); CHECK_EQ("*error-type*", toString(requireType("a"))); } diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index 9086a6049..94cf4b326 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -169,14 +169,27 @@ TEST_CASE_FIXTURE(Fixture, "list_only_alternative_overloads_that_match_argument_ LUAU_REQUIRE_ERROR_COUNT(2, result); - TypeMismatch* tm = get(result.errors[0]); - REQUIRE(tm); - CHECK_EQ(builtinTypes->numberType, tm->wantedType); - CHECK_EQ(builtinTypes->stringType, tm->givenType); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + GenericError* g = get(result.errors[0]); + REQUIRE(g); + CHECK(g->message == "None of the overloads for function that accept 1 arguments are compatible."); + } + else + { + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK_EQ(builtinTypes->numberType, tm->wantedType); + CHECK_EQ(builtinTypes->stringType, tm->givenType); + } ExtraInformation* ei = get(result.errors[1]); REQUIRE(ei); - CHECK_EQ("Other overloads are also not viable: (number) -> string", ei->message); + + if (FFlag::DebugLuauDeferredConstraintResolution) + CHECK("Available overloads: (number) -> number; and (number) -> string" == ei->message); + else + CHECK_EQ("Other overloads are also not viable: (number) -> string", ei->message); } TEST_CASE_FIXTURE(Fixture, "list_all_overloads_if_no_overload_takes_given_argument_count") diff --git a/tests/TypeInfer.loops.test.cpp b/tests/TypeInfer.loops.test.cpp index 7a1343584..c3dbbc7dc 100644 --- a/tests/TypeInfer.loops.test.cpp +++ b/tests/TypeInfer.loops.test.cpp @@ -2,6 +2,7 @@ #include "Luau/AstQuery.h" #include "Luau/BuiltinDefinitions.h" +#include "Luau/Frontend.h" #include "Luau/Scope.h" #include "Luau/TypeInfer.h" #include "Luau/Type.h" @@ -31,6 +32,53 @@ TEST_CASE_FIXTURE(Fixture, "for_loop") CHECK_EQ(*builtinTypes->numberType, *requireType("q")); } +TEST_CASE_FIXTURE(BuiltinsFixture, "iteration_no_table_passed") +{ + ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true}; + CheckResult result = check(R"( + +type Iterable = typeof(setmetatable( + {}, + {}::{ + __iter: (self: Iterable) -> (any, number) -> (number, string) + } +)) + +local t: Iterable + +for a, b in t do end +)"); + + + LUAU_REQUIRE_ERROR_COUNT(1, result); + GenericError* ge = get(result.errors[0]); + REQUIRE(ge); + CHECK_EQ("__iter metamethod must return (next[, table[, state]])", ge->message); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "iteration_regression_issue_69967") +{ + ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true}; + CheckResult result = check(R"( + +type Iterable = typeof(setmetatable( + {}, + {}::{ + __iter: (self: Iterable) -> () -> (number, string) + } +)) + +local t: Iterable + +for a, b in t do end +)"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + GenericError* ge = get(result.errors[0]); + REQUIRE(ge); + CHECK_EQ("__iter metamethod must return (next[, table[, state]])", ge->message); +} + TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_loop") { CheckResult result = check(R"( diff --git a/tests/TypeInfer.oop.test.cpp b/tests/TypeInfer.oop.test.cpp index f2b3d0559..ee7472520 100644 --- a/tests/TypeInfer.oop.test.cpp +++ b/tests/TypeInfer.oop.test.cpp @@ -26,9 +26,17 @@ TEST_CASE_FIXTURE(Fixture, "dont_suggest_using_colon_rather_than_dot_if_not_defi someTable.Function1() -- Argument count mismatch )"); - LUAU_REQUIRE_ERROR_COUNT(1, result); - - REQUIRE(get(result.errors[0])); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK(toString(result.errors[0]) == "No overload for function accepts 0 arguments."); + CHECK(toString(result.errors[1]) == "Available overloads: (a) -> ()"); + } + else + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + REQUIRE(get(result.errors[0])); + } } TEST_CASE_FIXTURE(Fixture, "dont_suggest_using_colon_rather_than_dot_if_it_wont_help_2") @@ -42,9 +50,17 @@ TEST_CASE_FIXTURE(Fixture, "dont_suggest_using_colon_rather_than_dot_if_it_wont_ someTable.Function2() -- Argument count mismatch )"); - LUAU_REQUIRE_ERROR_COUNT(1, result); - - REQUIRE(get(result.errors[0])); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK(toString(result.errors[0]) == "No overload for function accepts 0 arguments."); + CHECK(toString(result.errors[1]) == "Available overloads: (a, b) -> ()"); + } + else + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + REQUIRE(get(result.errors[0])); + } } TEST_CASE_FIXTURE(Fixture, "dont_suggest_using_colon_rather_than_dot_if_another_overload_works") diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index 7f21641d6..58acef222 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -52,6 +52,43 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_inference_incomplete") CHECK_EQ(expected, decorateWithTypes(code)); } +TEST_CASE_FIXTURE(BuiltinsFixture, "luau-polyfill.Array.filter") +{ + // This test exercises the fact that we should reduce sealed/unsealed/free tables + // res is a unsealed table with type {((T & ~nil)?) & any} + // Because we do not reduce it fully, we cannot unify it with `Array = { [number] : T} + // TLDR; reduction needs to reduce the indexer on res so it unifies with Array + CheckResult result = check(R"( +--!strict +-- Implements Javascript's `Array.prototype.filter` as defined below +-- https://developer.cmozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Array/filter +type Array = { [number]: T } +type callbackFn = (element: T, index: number, array: Array) -> boolean +type callbackFnWithThisArg = (thisArg: U, element: T, index: number, array: Array) -> boolean +type Object = { [string]: any } +return function(t: Array, callback: callbackFn | callbackFnWithThisArg, thisArg: U?): Array + + local len = #t + local res = {} + if thisArg == nil then + for i = 1, len do + local kValue = t[i] + if kValue ~= nil then + if (callback :: callbackFn)(kValue, i, t) then + res[i] = kValue + end + end + end + else + end + + return res +end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_CASE_FIXTURE(BuiltinsFixture, "xpcall_returns_what_f_returns") { const std::string code = R"( diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index 06cbe0cf3..c55497ae4 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -8,7 +8,6 @@ #include "doctest.h" LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) -LUAU_FASTFLAG(LuauNegatedClassTypes) using namespace Luau; @@ -64,7 +63,7 @@ struct RefinementClassFixture : BuiltinsFixture TypeArena& arena = frontend.globals.globalTypes; NotNull scope{frontend.globals.globalScope.get()}; - std::optional rootSuper = FFlag::LuauNegatedClassTypes ? std::make_optional(builtinTypes->classType) : std::nullopt; + std::optional rootSuper = std::make_optional(builtinTypes->classType); unfreeze(arena); TypeId vec3 = arena.addType(ClassType{"Vector3", {}, rootSuper, std::nullopt, {}, nullptr, "Test"}); diff --git a/tests/TypeInfer.singletons.test.cpp b/tests/TypeInfer.singletons.test.cpp index 23e49f581..f028e8e0d 100644 --- a/tests/TypeInfer.singletons.test.cpp +++ b/tests/TypeInfer.singletons.test.cpp @@ -131,8 +131,16 @@ TEST_CASE_FIXTURE(Fixture, "overloaded_function_call_with_singletons_mismatch") )"); LUAU_REQUIRE_ERROR_COUNT(2, result); - CHECK_EQ("Type 'number' could not be converted into 'string'", toString(result.errors[0])); - CHECK_EQ("Other overloads are also not viable: (false, number) -> ()", toString(result.errors[1])); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("None of the overloads for function that accept 2 arguments are compatible.", toString(result.errors[0])); + CHECK_EQ("Available overloads: (true, string) -> (); and (false, number) -> ()", toString(result.errors[1])); + } + else + { + CHECK_EQ("Type 'number' could not be converted into 'string'", toString(result.errors[0])); + CHECK_EQ("Other overloads are also not viable: (false, number) -> ()", toString(result.errors[1])); + } } TEST_CASE_FIXTURE(Fixture, "enums_using_singletons") diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index fcf2c8a4a..4b24fb225 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -3625,4 +3625,23 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "top_table_type_is_isomorphic_to_empty_sealed )"); } +TEST_CASE_FIXTURE(BuiltinsFixture, "luau-polyfill.Array.includes") +{ + + CheckResult result = check(R"( +type Array = { [number]: T } + +function indexOf(array: Array, searchElement: any, fromIndex: number?): number + return -1 +end + +return function(array: Array, searchElement: any, fromIndex: number?): boolean + return -1 ~= indexOf(array, searchElement, fromIndex) +end + + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tools/faillist.txt b/tools/faillist.txt index 38fa7f5f8..655d094f6 100644 --- a/tools/faillist.txt +++ b/tools/faillist.txt @@ -1,9 +1,5 @@ AnnotationTests.too_many_type_params AstQuery.last_argument_function_call_type -AstQuery::getDocumentationSymbolAtPosition.overloaded_class_method -AstQuery::getDocumentationSymbolAtPosition.overloaded_fn -AstQuery::getDocumentationSymbolAtPosition.table_overloaded_function_prop -AutocompleteTest.autocomplete_response_perf1 BuiltinTests.aliased_string_format BuiltinTests.assert_removes_falsy_types BuiltinTests.assert_removes_falsy_types2 @@ -54,6 +50,7 @@ ProvisionalTests.error_on_eq_metamethod_returning_a_type_other_than_boolean ProvisionalTests.free_options_cannot_be_unified_together ProvisionalTests.generic_type_leak_to_module_interface_variadic ProvisionalTests.greedy_inference_with_shared_self_triggers_function_with_no_returns +ProvisionalTests.luau-polyfill.Array.filter ProvisionalTests.setmetatable_constrains_free_type_into_free_table ProvisionalTests.specialization_binds_with_prototypes_too_early ProvisionalTests.table_insert_with_a_singleton_argument @@ -146,7 +143,6 @@ TypeInferClasses.index_instance_property TypeInferClasses.table_class_unification_reports_sane_errors_for_missing_properties TypeInferClasses.warn_when_prop_almost_matches TypeInferFunctions.cannot_hoist_interior_defns_into_signature -TypeInferFunctions.dont_give_other_overloads_message_if_only_one_argument_matching_overload_exists TypeInferFunctions.function_cast_error_uses_correct_language TypeInferFunctions.function_decl_non_self_sealed_overwrite_2 TypeInferFunctions.function_decl_non_self_unsealed_overwrite @@ -158,7 +154,6 @@ TypeInferFunctions.infer_that_function_does_not_return_a_table TypeInferFunctions.luau_subtyping_is_np_hard TypeInferFunctions.no_lossy_function_type TypeInferFunctions.occurs_check_failure_in_function_return_type -TypeInferFunctions.record_matching_overload TypeInferFunctions.report_exiting_without_return_strict TypeInferFunctions.return_type_by_overload TypeInferFunctions.too_few_arguments_variadic @@ -205,6 +200,7 @@ TypePackTests.variadic_packs TypeSingletons.function_call_with_singletons TypeSingletons.function_call_with_singletons_mismatch TypeSingletons.no_widening_from_callsites +TypeSingletons.overloaded_function_call_with_singletons_mismatch TypeSingletons.return_type_of_f_is_not_widened TypeSingletons.table_properties_type_error_escapes TypeSingletons.widen_the_supertype_if_it_is_free_and_subtype_has_singleton From 12c1edf6c64be6b05564c01300684f48558a3bc4 Mon Sep 17 00:00:00 2001 From: Andy Friesen Date: Fri, 5 May 2023 13:25:00 -0700 Subject: [PATCH 51/66] This test fails on a64 so disable it for now. --- tests/CodeAllocator.test.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/CodeAllocator.test.cpp b/tests/CodeAllocator.test.cpp index 51f216da6..df2fa36b7 100644 --- a/tests/CodeAllocator.test.cpp +++ b/tests/CodeAllocator.test.cpp @@ -636,6 +636,7 @@ TEST_CASE("GeneratedCodeExecutionA64") CHECK(result == 42); } +#if 0 static void throwing(int64_t arg) { CHECK(arg == 25); @@ -707,6 +708,7 @@ TEST_CASE("GeneratedCodeExecutionWithThrowA64") CHECK(strcmp(error.what(), "testing") == 0); } } +#endif #endif From f7c780164d8683c0d7104b5e07806204bdac3259 Mon Sep 17 00:00:00 2001 From: Andy Friesen Date: Fri, 5 May 2023 15:33:05 -0700 Subject: [PATCH 52/66] Add pthread as a link dependency to Luau.Analyze.CLI for Linux. --- CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index 6e15e5f88..b3b1573ac 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -192,6 +192,7 @@ if(LUAU_BUILD_CLI) find_library(LIBPTHREAD pthread) if (LIBPTHREAD) target_link_libraries(Luau.Repl.CLI PRIVATE pthread) + target_link_libraries(Luau.Analyze.CLI PRIVATE pthread) endif() endif() From 3247aabf75e89588cfd4d1055930520c56a0cac1 Mon Sep 17 00:00:00 2001 From: Vyacheslav Egorov Date: Fri, 12 May 2023 15:15:01 +0300 Subject: [PATCH 53/66] Sync to upstream/release/576 --- Analysis/include/Luau/Constraint.h | 23 +- Analysis/include/Luau/ConstraintSolver.h | 2 + Analysis/include/Luau/Error.h | 27 +- Analysis/include/Luau/Quantify.h | 28 +- Analysis/include/Luau/TxnLog.h | 13 +- Analysis/include/Luau/Type.h | 38 ++- Analysis/include/Luau/TypeFamily.h | 115 ++++++++ Analysis/include/Luau/TypePack.h | 17 +- Analysis/include/Luau/Unifier.h | 11 +- Analysis/include/Luau/VisitType.h | 39 +++ Analysis/src/Clone.cpp | 63 ++++- Analysis/src/ConstraintGraphBuilder.cpp | 3 +- Analysis/src/ConstraintSolver.cpp | 147 +++++++--- Analysis/src/Error.cpp | 24 ++ Analysis/src/Frontend.cpp | 5 +- Analysis/src/IostreamHelpers.cpp | 4 + Analysis/src/Quantify.cpp | 30 +- Analysis/src/Substitution.cpp | 72 ++++- Analysis/src/ToDot.cpp | 9 + Analysis/src/ToString.cpp | 119 ++++++-- Analysis/src/TxnLog.cpp | 118 +++++++- Analysis/src/TypeAttach.cpp | 10 + Analysis/src/TypeChecker2.cpp | 47 +++- Analysis/src/TypeFamily.cpp | 310 +++++++++++++++++++++ Analysis/src/TypeInfer.cpp | 100 ++++++- Analysis/src/Unifier.cpp | 214 +++++++++++++- Ast/include/Luau/Ast.h | 1 + Ast/src/StringUtils.cpp | 1 + CLI/Repl.cpp | 28 ++ CMakeLists.txt | 1 + CodeGen/include/Luau/AssemblyBuilderA64.h | 4 +- CodeGen/include/Luau/AssemblyBuilderX64.h | 3 + CodeGen/include/Luau/CodeBlockUnwind.h | 2 + CodeGen/include/Luau/CodeGen.h | 10 +- CodeGen/src/AssemblyBuilderA64.cpp | 53 ++-- CodeGen/src/AssemblyBuilderX64.cpp | 33 ++- CodeGen/src/CodeBlockUnwind.cpp | 28 +- CodeGen/src/CodeGen.cpp | 88 ++++-- CodeGen/src/EmitCommonX64.cpp | 2 +- CodeGen/src/IrLoweringA64.cpp | 26 +- CodeGen/src/IrTranslation.cpp | 39 ++- CodeGen/src/NativeState.h | 2 + Common/include/Luau/ExperimentalFlags.h | 2 + Sources.cmake | 4 + VM/src/ldebug.cpp | 38 ++- VM/src/lvmexecute.cpp | 10 +- bench/tests/matrixmult.lua | 39 +++ bench/tests/mesh-normal-scalar.lua | 254 +++++++++++++++++ tests/AssemblyBuilderA64.test.cpp | 3 +- tests/AssemblyBuilderX64.test.cpp | 25 +- tests/Autocomplete.test.cpp | 22 ++ tests/ClassFixture.cpp | 13 + tests/CodeAllocator.test.cpp | 4 + tests/Conformance.test.cpp | 13 +- tests/TxnLog.test.cpp | 113 ++++++++ tests/TypeFamily.test.cpp | 205 ++++++++++++++ tests/TypeInfer.classes.test.cpp | 146 ++++++++++ tests/TypeInfer.functions.test.cpp | 36 +++ tests/TypeInfer.intersectionTypes.test.cpp | 18 +- tests/TypeInfer.operators.test.cpp | 28 +- tests/TypeInfer.provisional.test.cpp | 35 --- tests/TypeInfer.singletons.test.cpp | 6 +- tests/TypeInfer.test.cpp | 65 +++++ tests/TypeInfer.tryUnify.test.cpp | 87 ++++++ tests/TypeInfer.unionTypes.test.cpp | 78 ++++++ tests/conformance/debugger.lua | 13 + tools/faillist.txt | 9 +- 67 files changed, 2898 insertions(+), 277 deletions(-) create mode 100644 Analysis/include/Luau/TypeFamily.h create mode 100644 Analysis/src/TypeFamily.cpp create mode 100644 bench/tests/matrixmult.lua create mode 100644 bench/tests/mesh-normal-scalar.lua create mode 100644 tests/TxnLog.test.cpp create mode 100644 tests/TypeFamily.test.cpp diff --git a/Analysis/include/Luau/Constraint.h b/Analysis/include/Luau/Constraint.h index c7bc58b5a..3aa3c865c 100644 --- a/Analysis/include/Luau/Constraint.h +++ b/Analysis/include/Luau/Constraint.h @@ -198,9 +198,26 @@ struct UnpackConstraint TypePackId sourcePack; }; -using ConstraintV = Variant; +// ty ~ reduce ty +// +// Try to reduce ty, if it is a TypeFamilyInstanceType. Otherwise, do nothing. +struct ReduceConstraint +{ + TypeId ty; +}; + +// tp ~ reduce tp +// +// Analogous to ReduceConstraint, but for type packs. +struct ReducePackConstraint +{ + TypePackId tp; +}; + +using ConstraintV = + Variant; struct Constraint { diff --git a/Analysis/include/Luau/ConstraintSolver.h b/Analysis/include/Luau/ConstraintSolver.h index 6888e99c2..f6b1aede8 100644 --- a/Analysis/include/Luau/ConstraintSolver.h +++ b/Analysis/include/Luau/ConstraintSolver.h @@ -121,6 +121,8 @@ struct ConstraintSolver bool tryDispatch(const SetIndexerConstraint& c, NotNull constraint, bool force); bool tryDispatch(const SingletonOrTopTypeConstraint& c, NotNull constraint); bool tryDispatch(const UnpackConstraint& c, NotNull constraint); + bool tryDispatch(const ReduceConstraint& c, NotNull constraint, bool force); + bool tryDispatch(const ReducePackConstraint& c, NotNull constraint, bool force); // for a, ... in some_table do // also handles __iter metamethod diff --git a/Analysis/include/Luau/Error.h b/Analysis/include/Luau/Error.h index 8571430bf..6264a0b53 100644 --- a/Analysis/include/Luau/Error.h +++ b/Analysis/include/Luau/Error.h @@ -329,12 +329,27 @@ struct DynamicPropertyLookupOnClassesUnsafe bool operator==(const DynamicPropertyLookupOnClassesUnsafe& rhs) const; }; -using TypeErrorData = Variant; +struct UninhabitedTypeFamily +{ + TypeId ty; + + bool operator==(const UninhabitedTypeFamily& rhs) const; +}; + +struct UninhabitedTypePackFamily +{ + TypePackId tp; + + bool operator==(const UninhabitedTypePackFamily& rhs) const; +}; + +using TypeErrorData = + Variant; struct TypeErrorSummary { diff --git a/Analysis/include/Luau/Quantify.h b/Analysis/include/Luau/Quantify.h index c86512f1f..b562c54c7 100644 --- a/Analysis/include/Luau/Quantify.h +++ b/Analysis/include/Luau/Quantify.h @@ -2,6 +2,9 @@ #pragma once #include "Luau/Type.h" +#include "Luau/DenseHash.h" + +#include namespace Luau { @@ -10,6 +13,29 @@ struct TypeArena; struct Scope; void quantify(TypeId ty, TypeLevel level); -std::optional quantify(TypeArena* arena, TypeId ty, Scope* scope); + +// TODO: This is eerily similar to the pattern that NormalizedClassType +// implements. We could, and perhaps should, merge them together. +template +struct OrderedMap +{ + std::vector keys; + DenseHashMap pairings{nullptr}; + + void push(K k, V v) + { + keys.push_back(k); + pairings[k] = v; + } +}; + +struct QuantifierResult +{ + TypeId result; + OrderedMap insertedGenerics; + OrderedMap insertedGenericPacks; +}; + +std::optional quantify(TypeArena* arena, TypeId ty, Scope* scope); } // namespace Luau diff --git a/Analysis/include/Luau/TxnLog.h b/Analysis/include/Luau/TxnLog.h index 0ed8a49ad..907908dfe 100644 --- a/Analysis/include/Luau/TxnLog.h +++ b/Analysis/include/Luau/TxnLog.h @@ -19,6 +19,10 @@ struct PendingType // The pending Type state. Type pending; + // On very rare occasions, we need to delete an entry from the TxnLog. + // DenseHashMap does not afford that so we note its deadness here. + bool dead = false; + explicit PendingType(Type state) : pending(std::move(state)) { @@ -61,10 +65,11 @@ T* getMutable(PendingTypePack* pending) // Log of what TypeIds we are rebinding, to be committed later. struct TxnLog { - TxnLog() + explicit TxnLog(bool useScopes = false) : typeVarChanges(nullptr) , typePackChanges(nullptr) , ownedSeen() + , useScopes(useScopes) , sharedSeen(&ownedSeen) { } @@ -297,6 +302,12 @@ struct TxnLog void popSeen(TypeOrPackId lhs, TypeOrPackId rhs); public: + // There is one spot in the code where TxnLog has to reconcile collisions + // between parallel logs. In that codepath, we have to work out which of two + // FreeTypes subsumes the other. If useScopes is false, the TypeLevel is + // used. Else we use the embedded Scope*. + bool useScopes = false; + // Used to avoid infinite recursion when types are cyclic. // Shared with all the descendent TxnLogs. std::vector>* sharedSeen; diff --git a/Analysis/include/Luau/Type.h b/Analysis/include/Luau/Type.h index c615b8f57..80a044cbf 100644 --- a/Analysis/include/Luau/Type.h +++ b/Analysis/include/Luau/Type.h @@ -23,6 +23,7 @@ LUAU_FASTINT(LuauTableTypeMaximumStringifierLength) LUAU_FASTINT(LuauTypeMaximumStringifierLength) +LUAU_FASTFLAG(LuauTypecheckClassTypeIndexers) namespace Luau { @@ -31,6 +32,8 @@ struct TypeArena; struct Scope; using ScopePtr = std::shared_ptr; +struct TypeFamily; + /** * There are three kinds of type variables: * - `Free` variables are metavariables, which stand for unconstrained types. @@ -489,6 +492,7 @@ struct ClassType Tags tags; std::shared_ptr userData; ModuleName definitionModuleName; + std::optional indexer; ClassType(Name name, Props props, std::optional parent, std::optional metatable, Tags tags, std::shared_ptr userData, ModuleName definitionModuleName) @@ -501,6 +505,35 @@ struct ClassType , definitionModuleName(definitionModuleName) { } + + ClassType(Name name, Props props, std::optional parent, std::optional metatable, Tags tags, + std::shared_ptr userData, ModuleName definitionModuleName, std::optional indexer) + : name(name) + , props(props) + , parent(parent) + , metatable(metatable) + , tags(tags) + , userData(userData) + , definitionModuleName(definitionModuleName) + , indexer(indexer) + { + LUAU_ASSERT(FFlag::LuauTypecheckClassTypeIndexers); + } +}; + +/** + * An instance of a type family that has not yet been reduced to a more concrete + * type. The constraint solver receives a constraint to reduce each + * TypeFamilyInstanceType to a concrete type. A design detail is important to + * note here: the parameters for this instantiation of the type family are + * contained within this type, so that they can be substituted. + */ +struct TypeFamilyInstanceType +{ + NotNull family; + + std::vector typeArguments; + std::vector packArguments; }; struct TypeFun @@ -640,8 +673,9 @@ struct NegationType using ErrorType = Unifiable::Error; -using TypeVariant = Unifiable::Variant; +using TypeVariant = + Unifiable::Variant; struct Type final { diff --git a/Analysis/include/Luau/TypeFamily.h b/Analysis/include/Luau/TypeFamily.h new file mode 100644 index 000000000..4c04f52ae --- /dev/null +++ b/Analysis/include/Luau/TypeFamily.h @@ -0,0 +1,115 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Error.h" +#include "Luau/NotNull.h" +#include "Luau/Variant.h" + +#include +#include +#include + +namespace Luau +{ + +struct Type; +using TypeId = const Type*; + +struct TypePackVar; +using TypePackId = const TypePackVar*; + +struct TypeArena; +struct BuiltinTypes; +struct TxnLog; + +/// Represents a reduction result, which may have successfully reduced the type, +/// may have concretely failed to reduce the type, or may simply be stuck +/// without more information. +template +struct TypeFamilyReductionResult +{ + /// The result of the reduction, if any. If this is nullopt, the family + /// could not be reduced. + std::optional result; + /// Whether the result is uninhabited: whether we know, unambiguously and + /// permanently, whether this type family reduction results in an + /// uninhabitable type. This will trigger an error to be reported. + bool uninhabited; + /// Any types that need to be progressed or mutated before the reduction may + /// proceed. + std::vector blockedTypes; + /// Any type packs that need to be progressed or mutated before the + /// reduction may proceed. + std::vector blockedPacks; +}; + +/// Represents a type function that may be applied to map a series of types and +/// type packs to a single output type. +struct TypeFamily +{ + /// The human-readable name of the type family. Used to stringify instance + /// types. + std::string name; + + /// The reducer function for the type family. + std::function( + std::vector, std::vector, NotNull, NotNull, NotNull log)> + reducer; +}; + +/// Represents a type function that may be applied to map a series of types and +/// type packs to a single output type pack. +struct TypePackFamily +{ + /// The human-readable name of the type pack family. Used to stringify + /// instance packs. + std::string name; + + /// The reducer function for the type pack family. + std::function( + std::vector, std::vector, NotNull, NotNull, NotNull log)> + reducer; +}; + +struct FamilyGraphReductionResult +{ + ErrorVec errors; + DenseHashSet blockedTypes{nullptr}; + DenseHashSet blockedPacks{nullptr}; + DenseHashSet reducedTypes{nullptr}; + DenseHashSet reducedPacks{nullptr}; +}; + +/** + * Attempt to reduce all instances of any type or type pack family in the type + * graph provided. + * + * @param entrypoint the entry point to the type graph. + * @param location the location the reduction is occurring at; used to populate + * type errors. + * @param arena an arena to allocate types into. + * @param builtins the built-in types. + * @param log a TxnLog to use. If one is provided, substitution will take place + * against the TxnLog, otherwise substitutions will directly mutate the type + * graph. Do not provide the empty TxnLog, as a result. + */ +FamilyGraphReductionResult reduceFamilies( + TypeId entrypoint, Location location, NotNull arena, NotNull builtins, TxnLog* log = nullptr, bool force = false); + +/** + * Attempt to reduce all instances of any type or type pack family in the type + * graph provided. + * + * @param entrypoint the entry point to the type graph. + * @param location the location the reduction is occurring at; used to populate + * type errors. + * @param arena an arena to allocate types into. + * @param builtins the built-in types. + * @param log a TxnLog to use. If one is provided, substitution will take place + * against the TxnLog, otherwise substitutions will directly mutate the type + * graph. Do not provide the empty TxnLog, as a result. + */ +FamilyGraphReductionResult reduceFamilies( + TypePackId entrypoint, Location location, NotNull arena, NotNull builtins, TxnLog* log = nullptr, bool force = false); + +} // namespace Luau diff --git a/Analysis/include/Luau/TypePack.h b/Analysis/include/Luau/TypePack.h index e78a66b84..d159aa45d 100644 --- a/Analysis/include/Luau/TypePack.h +++ b/Analysis/include/Luau/TypePack.h @@ -12,11 +12,13 @@ namespace Luau { struct TypeArena; +struct TypePackFamily; struct TxnLog; struct TypePack; struct VariadicTypePack; struct BlockedTypePack; +struct TypeFamilyInstanceTypePack; struct TypePackVar; using TypePackId = const TypePackVar*; @@ -50,10 +52,10 @@ struct GenericTypePack }; using BoundTypePack = Unifiable::Bound; - using ErrorTypePack = Unifiable::Error; -using TypePackVariant = Unifiable::Variant; +using TypePackVariant = + Unifiable::Variant; /* A TypePack is a rope-like string of TypeIds. We use this structure to encode * notions like packs of unknown length and packs of any length, as well as more @@ -83,6 +85,17 @@ struct BlockedTypePack static size_t nextIndex; }; +/** + * Analogous to a TypeFamilyInstanceType. + */ +struct TypeFamilyInstanceTypePack +{ + NotNull family; + + std::vector typeArguments; + std::vector packArguments; +}; + struct TypePackVar { explicit TypePackVar(const TypePackVariant& ty); diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index 742f029ca..d5db06c8b 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -64,9 +64,11 @@ struct Unifier Variance variance = Covariant; bool normalize = true; // Normalize unions and intersections if necessary bool checkInhabited = true; // Normalize types to check if they are inhabited - bool useScopes = false; // If true, we use the scope hierarchy rather than TypeLevels CountMismatch::Context ctx = CountMismatch::Arg; + // If true, generics act as free types when unifying. + bool hideousFixMeGenericsAreActuallyFree = false; + UnifierSharedState& sharedState; // When the Unifier is forced to unify two blocked types (or packs), they @@ -78,6 +80,10 @@ struct Unifier Unifier( NotNull normalizer, Mode mode, NotNull scope, const Location& location, Variance variance, TxnLog* parentLog = nullptr); + // Configure the Unifier to test for scope subsumption via embedded Scope + // pointers rather than TypeLevels. + void enableScopeTests(); + // Test whether the two type vars unify. Never commits the result. ErrorVec canUnify(TypeId subTy, TypeId superTy); ErrorVec canUnify(TypePackId subTy, TypePackId superTy, bool isFunctionCall = false); @@ -159,6 +165,9 @@ struct Unifier // Available after regular type pack unification errors std::optional firstPackErrorPos; + + // If true, we use the scope hierarchy rather than TypeLevels + bool useScopes = false; }; void promoteTypeLevels(TxnLog& log, const TypeArena* arena, TypeLevel minLevel, Scope* outerScope, bool useScope, TypePackId tp); diff --git a/Analysis/include/Luau/VisitType.h b/Analysis/include/Luau/VisitType.h index 663627d5e..b6dcf1f1b 100644 --- a/Analysis/include/Luau/VisitType.h +++ b/Analysis/include/Luau/VisitType.h @@ -159,6 +159,10 @@ struct GenericTypeVisitor { return visit(ty); } + virtual bool visit(TypeId ty, const TypeFamilyInstanceType& tfit) + { + return visit(ty); + } virtual bool visit(TypePackId tp) { @@ -192,6 +196,10 @@ struct GenericTypeVisitor { return visit(tp); } + virtual bool visit(TypePackId tp, const TypeFamilyInstanceTypePack& tfitp) + { + return visit(tp); + } void traverse(TypeId ty) { @@ -272,6 +280,15 @@ struct GenericTypeVisitor if (ctv->metatable) traverse(*ctv->metatable); + + if (FFlag::LuauTypecheckClassTypeIndexers) + { + if (ctv->indexer) + { + traverse(ctv->indexer->indexType); + traverse(ctv->indexer->indexResultType); + } + } } } else if (auto atv = get(ty)) @@ -327,6 +344,17 @@ struct GenericTypeVisitor if (visit(ty, *ntv)) traverse(ntv->ty); } + else if (auto tfit = get(ty)) + { + if (visit(ty, *tfit)) + { + for (TypeId p : tfit->typeArguments) + traverse(p); + + for (TypePackId p : tfit->packArguments) + traverse(p); + } + } else LUAU_ASSERT(!"GenericTypeVisitor::traverse(TypeId) is not exhaustive!"); @@ -376,6 +404,17 @@ struct GenericTypeVisitor } else if (auto btp = get(tp)) visit(tp, *btp); + else if (auto tfitp = get(tp)) + { + if (visit(tp, *tfitp)) + { + for (TypeId t : tfitp->typeArguments) + traverse(t); + + for (TypePackId t : tfitp->packArguments) + traverse(t); + } + } else LUAU_ASSERT(!"GenericTypeVisitor::traverse(TypePackId) is not exhaustive!"); diff --git a/Analysis/src/Clone.cpp b/Analysis/src/Clone.cpp index 450b84af9..0c1b24a19 100644 --- a/Analysis/src/Clone.cpp +++ b/Analysis/src/Clone.cpp @@ -52,6 +52,12 @@ Property clone(const Property& prop, TypeArena& dest, CloneState& cloneState) } } +static TableIndexer clone(const TableIndexer& indexer, TypeArena& dest, CloneState& cloneState) +{ + LUAU_ASSERT(FFlag::LuauTypecheckClassTypeIndexers); + return TableIndexer{clone(indexer.indexType, dest, cloneState), clone(indexer.indexResultType, dest, cloneState)}; +} + struct TypePackCloner; /* @@ -98,6 +104,7 @@ struct TypeCloner void operator()(const UnknownType& t); void operator()(const NeverType& t); void operator()(const NegationType& t); + void operator()(const TypeFamilyInstanceType& t); }; struct TypePackCloner @@ -171,6 +178,22 @@ struct TypePackCloner if (t.tail) destTp->tail = clone(*t.tail, dest, cloneState); } + + void operator()(const TypeFamilyInstanceTypePack& t) + { + TypePackId cloned = dest.addTypePack(TypeFamilyInstanceTypePack{t.family, {}, {}}); + TypeFamilyInstanceTypePack* destTp = getMutable(cloned); + LUAU_ASSERT(destTp); + seenTypePacks[typePackId] = cloned; + + destTp->typeArguments.reserve(t.typeArguments.size()); + for (TypeId ty : t.typeArguments) + destTp->typeArguments.push_back(clone(ty, dest, cloneState)); + + destTp->packArguments.reserve(t.packArguments.size()); + for (TypePackId tp : t.packArguments) + destTp->packArguments.push_back(clone(tp, dest, cloneState)); + } }; template @@ -288,8 +311,16 @@ void TypeCloner::operator()(const TableType& t) for (const auto& [name, prop] : t.props) ttv->props[name] = clone(prop, dest, cloneState); - if (t.indexer) - ttv->indexer = TableIndexer{clone(t.indexer->indexType, dest, cloneState), clone(t.indexer->indexResultType, dest, cloneState)}; + if (FFlag::LuauTypecheckClassTypeIndexers) + { + if (t.indexer) + ttv->indexer = clone(*t.indexer, dest, cloneState); + } + else + { + if (t.indexer) + ttv->indexer = TableIndexer{clone(t.indexer->indexType, dest, cloneState), clone(t.indexer->indexResultType, dest, cloneState)}; + } for (TypeId& arg : ttv->instantiatedTypeParams) arg = clone(arg, dest, cloneState); @@ -327,6 +358,12 @@ void TypeCloner::operator()(const ClassType& t) if (t.metatable) ctv->metatable = clone(*t.metatable, dest, cloneState); + + if (FFlag::LuauTypecheckClassTypeIndexers) + { + if (t.indexer) + ctv->indexer = clone(*t.indexer, dest, cloneState); + } } void TypeCloner::operator()(const AnyType& t) @@ -389,6 +426,28 @@ void TypeCloner::operator()(const NegationType& t) asMutable(result)->ty = NegationType{ty}; } +void TypeCloner::operator()(const TypeFamilyInstanceType& t) +{ + TypeId result = dest.addType(TypeFamilyInstanceType{ + t.family, + {}, + {}, + }); + + seenTypes[typeId] = result; + + TypeFamilyInstanceType* tfit = getMutable(result); + LUAU_ASSERT(tfit != nullptr); + + tfit->typeArguments.reserve(t.typeArguments.size()); + for (TypeId p : t.typeArguments) + tfit->typeArguments.push_back(clone(p, dest, cloneState)); + + tfit->packArguments.reserve(t.packArguments.size()); + for (TypePackId p : t.packArguments) + tfit->packArguments.push_back(clone(p, dest, cloneState)); +} + } // anonymous namespace TypePackId clone(TypePackId tp, TypeArena& dest, CloneState& cloneState) diff --git a/Analysis/src/ConstraintGraphBuilder.cpp b/Analysis/src/ConstraintGraphBuilder.cpp index e07fe701d..c8d99adf8 100644 --- a/Analysis/src/ConstraintGraphBuilder.cpp +++ b/Analysis/src/ConstraintGraphBuilder.cpp @@ -728,6 +728,7 @@ ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocalFun }); addConstraint(scope, std::move(c)); + module->astTypes[function->func] = functionType; return ControlFlow::None; } @@ -1475,7 +1476,7 @@ Inference ConstraintGraphBuilder::check( Checkpoint endCheckpoint = checkpoint(this); TypeId generalizedTy = arena->addType(BlockedType{}); - NotNull gc = addConstraint(scope, expr->location, GeneralizationConstraint{generalizedTy, sig.signature}); + NotNull gc = addConstraint(sig.signatureScope, expr->location, GeneralizationConstraint{generalizedTy, sig.signature}); forEachConstraint(startCheckpoint, endCheckpoint, this, [gc](const ConstraintPtr& constraint) { gc->dependencies.emplace_back(constraint.get()); diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index f1f868add..488fd4baa 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -16,6 +16,7 @@ #include "Luau/Type.h" #include "Luau/Unifier.h" #include "Luau/VisitType.h" +#include "Luau/TypeFamily.h" LUAU_FASTFLAGVARIABLE(DebugLuauLogSolver, false); LUAU_FASTFLAG(LuauRequirePathTrueModuleName) @@ -226,6 +227,32 @@ void dump(ConstraintSolver* cs, ToStringOptions& opts) } } +struct InstantiationQueuer : TypeOnceVisitor +{ + ConstraintSolver* solver; + NotNull scope; + Location location; + + explicit InstantiationQueuer(NotNull scope, const Location& location, ConstraintSolver* solver) + : solver(solver) + , scope(scope) + , location(location) + { + } + + bool visit(TypeId ty, const PendingExpansionType& petv) override + { + solver->pushConstraint(scope, location, TypeAliasExpansionConstraint{ty}); + return false; + } + + bool visit(TypeId ty, const TypeFamilyInstanceType& tfit) override + { + solver->pushConstraint(scope, location, ReduceConstraint{ty}); + return true; + } +}; + ConstraintSolver::ConstraintSolver(NotNull normalizer, NotNull rootScope, std::vector> constraints, ModuleName moduleName, NotNull moduleResolver, std::vector requireCycles, DcrLogger* logger) : arena(normalizer->arena) @@ -441,6 +468,10 @@ bool ConstraintSolver::tryDispatch(NotNull constraint, bool fo success = tryDispatch(*sottc, constraint); else if (auto uc = get(*constraint)) success = tryDispatch(*uc, constraint); + else if (auto rc = get(*constraint)) + success = tryDispatch(*rc, constraint, force); + else if (auto rpc = get(*constraint)) + success = tryDispatch(*rpc, constraint, force); else LUAU_ASSERT(false); @@ -479,13 +510,19 @@ bool ConstraintSolver::tryDispatch(const GeneralizationConstraint& c, NotNull(generalizedType)) return block(generalizedType, constraint); - std::optional generalized = quantify(arena, c.sourceType, constraint->scope); + std::optional generalized = quantify(arena, c.sourceType, constraint->scope); if (generalized) { if (get(generalizedType)) - asMutable(generalizedType)->ty.emplace(*generalized); + asMutable(generalizedType)->ty.emplace(generalized->result); else - unify(generalizedType, *generalized, constraint->scope); + unify(generalizedType, generalized->result, constraint->scope); + + for (auto [free, gen] : generalized->insertedGenerics.pairings) + unify(free, gen, constraint->scope); + + for (auto [free, gen] : generalized->insertedGenericPacks.pairings) + unify(free, gen, constraint->scope); } else { @@ -504,6 +541,9 @@ bool ConstraintSolver::tryDispatch(const InstantiationConstraint& c, NotNullscope); std::optional instantiated = inst.substitute(c.superType); @@ -512,6 +552,9 @@ bool ConstraintSolver::tryDispatch(const InstantiationConstraint& c, NotNull(c.subType)); asMutable(c.subType)->ty.emplace(*instantiated); + InstantiationQueuer queuer{constraint->scope, constraint->location, this}; + queuer.traverse(c.subType); + unblock(c.subType); return true; @@ -953,26 +996,6 @@ struct InfiniteTypeFinder : TypeOnceVisitor } }; -struct InstantiationQueuer : TypeOnceVisitor -{ - ConstraintSolver* solver; - NotNull scope; - Location location; - - explicit InstantiationQueuer(NotNull scope, const Location& location, ConstraintSolver* solver) - : solver(solver) - , scope(scope) - , location(location) - { - } - - bool visit(TypeId ty, const PendingExpansionType& petv) override - { - solver->pushConstraint(scope, location, TypeAliasExpansionConstraint{ty}); - return false; - } -}; - bool ConstraintSolver::tryDispatch(const TypeAliasExpansionConstraint& c, NotNull constraint) { const PendingExpansionType* petv = get(follow(c.target)); @@ -1246,7 +1269,7 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNullscope, Location{}, Covariant}; - u.useScopes = true; + u.enableScopeTests(); u.tryUnify(*instantiated, inferredTy, /* isFunctionCall */ true); @@ -1278,8 +1301,12 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNullscope, constraint->location, this}; + queuer.traverse(fn); + queuer.traverse(inferredTy); + return true; } } @@ -1295,7 +1322,7 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNullscope, Location{}, Covariant}; - u.useScopes = true; + u.enableScopeTests(); u.tryUnify(inferredTy, builtinTypes->anyType); u.tryUnify(fn, builtinTypes->anyType); @@ -1305,8 +1332,12 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNullscope, constraint->location, this}; + queuer.traverse(fn); + queuer.traverse(inferredTy); + return true; } @@ -1567,8 +1598,11 @@ bool ConstraintSolver::tryDispatch(const SetIndexerConstraint& c, NotNullstate == TableState::Free || tt->state == TableState::Unsealed) { + TypeId promotedIndexTy = arena->freshType(tt->scope); + unify(c.indexType, promotedIndexTy, constraint->scope); + auto mtt = getMutable(subjectType); - mtt->indexer = TableIndexer{c.indexType, c.propType}; + mtt->indexer = TableIndexer{promotedIndexTy, c.propType}; asMutable(c.propType)->ty.emplace(tt->scope); asMutable(c.resultType)->ty.emplace(subjectType); unblock(c.propType); @@ -1666,6 +1700,52 @@ bool ConstraintSolver::tryDispatch(const UnpackConstraint& c, NotNull constraint, bool force) +{ + TypeId ty = follow(c.ty); + FamilyGraphReductionResult result = reduceFamilies(ty, constraint->location, NotNull{arena}, builtinTypes, nullptr, force); + + for (TypeId r : result.reducedTypes) + unblock(r); + + for (TypePackId r : result.reducedPacks) + unblock(r); + + if (force) + return true; + + for (TypeId b : result.blockedTypes) + block(b, constraint); + + for (TypePackId b : result.blockedPacks) + block(b, constraint); + + return result.blockedTypes.empty() && result.blockedPacks.empty(); +} + +bool ConstraintSolver::tryDispatch(const ReducePackConstraint& c, NotNull constraint, bool force) +{ + TypePackId tp = follow(c.tp); + FamilyGraphReductionResult result = reduceFamilies(tp, constraint->location, NotNull{arena}, builtinTypes, nullptr, force); + + for (TypeId r : result.reducedTypes) + unblock(r); + + for (TypePackId r : result.reducedPacks) + unblock(r); + + if (force) + return true; + + for (TypeId b : result.blockedTypes) + block(b, constraint); + + for (TypePackId b : result.blockedPacks) + block(b, constraint); + + return result.blockedTypes.empty() && result.blockedPacks.empty(); +} + bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const IterableConstraint& c, NotNull constraint, bool force) { auto block_ = [&](auto&& t) { @@ -2031,7 +2111,7 @@ template bool ConstraintSolver::tryUnify(NotNull constraint, TID subTy, TID superTy) { Unifier u{normalizer, Mode::Strict, constraint->scope, constraint->location, Covariant}; - u.useScopes = true; + u.enableScopeTests(); u.tryUnify(subTy, superTy); @@ -2195,10 +2275,11 @@ void ConstraintSolver::unblock(NotNull progressed) return unblock_(progressed.get()); } -void ConstraintSolver::unblock(TypeId progressed) +void ConstraintSolver::unblock(TypeId ty) { DenseHashSet seen{nullptr}; + TypeId progressed = ty; while (true) { if (seen.find(progressed)) @@ -2256,7 +2337,7 @@ bool ConstraintSolver::isBlocked(NotNull constraint) void ConstraintSolver::unify(TypeId subType, TypeId superType, NotNull scope) { Unifier u{normalizer, Mode::Strict, scope, Location{}, Covariant}; - u.useScopes = true; + u.enableScopeTests(); u.tryUnify(subType, superType); @@ -2279,7 +2360,7 @@ void ConstraintSolver::unify(TypePackId subPack, TypePackId superPack, NotNull scope, if (unifyFreeTypes && (get(a) || get(b))) { Unifier u{normalizer, Mode::Strict, scope, Location{}, Covariant}; - u.useScopes = true; + u.enableScopeTests(); u.tryUnify(b, a); if (u.errors.empty()) diff --git a/Analysis/src/Error.cpp b/Analysis/src/Error.cpp index 1e0379729..4f70be33f 100644 --- a/Analysis/src/Error.cpp +++ b/Analysis/src/Error.cpp @@ -484,6 +484,16 @@ struct ErrorConverter { return "Attempting a dynamic property access on type '" + Luau::toString(e.ty) + "' is unsafe and may cause exceptions at runtime"; } + + std::string operator()(const UninhabitedTypeFamily& e) const + { + return "Type family instance " + Luau::toString(e.ty) + " is uninhabited"; + } + + std::string operator()(const UninhabitedTypePackFamily& e) const + { + return "Type pack family instance " + Luau::toString(e.tp) + " is uninhabited"; + } }; struct InvalidNameChecker @@ -786,6 +796,16 @@ bool DynamicPropertyLookupOnClassesUnsafe::operator==(const DynamicPropertyLooku return ty == rhs.ty; } +bool UninhabitedTypeFamily::operator==(const UninhabitedTypeFamily& rhs) const +{ + return ty == rhs.ty; +} + +bool UninhabitedTypePackFamily::operator==(const UninhabitedTypePackFamily& rhs) const +{ + return tp == rhs.tp; +} + std::string toString(const TypeError& error) { return toString(error, TypeErrorToStringOptions{}); @@ -944,6 +964,10 @@ void copyError(T& e, TypeArena& destArena, CloneState cloneState) } else if constexpr (std::is_same_v) e.ty = clone(e.ty); + else if constexpr (std::is_same_v) + e.ty = clone(e.ty); + else if constexpr (std::is_same_v) + e.tp = clone(e.tp); else static_assert(always_false_v, "Non-exhaustive type switch"); } diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index b6b315cf1..b16eda8a9 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -38,6 +38,7 @@ LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJson, false) LUAU_FASTFLAG(LuauRequirePathTrueModuleName) LUAU_FASTFLAGVARIABLE(DebugLuauReadWriteProperties, false) LUAU_FASTFLAGVARIABLE(LuauSplitFrontendProcessing, false) +LUAU_FASTFLAGVARIABLE(LuauTypeCheckerUseCorrectScope, false) namespace Luau { @@ -1397,7 +1398,9 @@ ModulePtr Frontend::check(const SourceModule& sourceModule, Mode mode, std::vect } else { - TypeChecker typeChecker(globals.globalScope, forAutocomplete ? &moduleResolverForAutocomplete : &moduleResolver, builtinTypes, &iceHandler); + TypeChecker typeChecker(FFlag::LuauTypeCheckerUseCorrectScope ? (forAutocomplete ? globalsForAutocomplete.globalScope : globals.globalScope) + : globals.globalScope, + forAutocomplete ? &moduleResolverForAutocomplete : &moduleResolver, builtinTypes, &iceHandler); if (prepareModuleScope) { diff --git a/Analysis/src/IostreamHelpers.cpp b/Analysis/src/IostreamHelpers.cpp index 43580da4d..000bb140a 100644 --- a/Analysis/src/IostreamHelpers.cpp +++ b/Analysis/src/IostreamHelpers.cpp @@ -192,6 +192,10 @@ static void errorToString(std::ostream& stream, const T& err) stream << "TypePackMismatch { wanted = '" + toString(err.wantedTp) + "', given = '" + toString(err.givenTp) + "' }"; else if constexpr (std::is_same_v) stream << "DynamicPropertyLookupOnClassesUnsafe { " << toString(err.ty) << " }"; + else if constexpr (std::is_same_v) + stream << "UninhabitedTypeFamily { " << toString(err.ty) << " }"; + else if constexpr (std::is_same_v) + stream << "UninhabitedTypePackFamily { " << toString(err.tp) << " }"; else static_assert(always_false_v, "Non-exhaustive type switch"); } diff --git a/Analysis/src/Quantify.cpp b/Analysis/src/Quantify.cpp index 0a7975f4d..5a7a05011 100644 --- a/Analysis/src/Quantify.cpp +++ b/Analysis/src/Quantify.cpp @@ -154,8 +154,8 @@ void quantify(TypeId ty, TypeLevel level) struct PureQuantifier : Substitution { Scope* scope; - std::vector insertedGenerics; - std::vector insertedGenericPacks; + OrderedMap insertedGenerics; + OrderedMap insertedGenericPacks; bool seenMutableType = false; bool seenGenericType = false; @@ -203,7 +203,7 @@ struct PureQuantifier : Substitution if (auto ftv = get(ty)) { TypeId result = arena->addType(GenericType{scope}); - insertedGenerics.push_back(result); + insertedGenerics.push(ty, result); return result; } else if (auto ttv = get(ty)) @@ -217,7 +217,10 @@ struct PureQuantifier : Substitution resultTable->scope = scope; if (ttv->state == TableState::Free) + { resultTable->state = TableState::Generic; + insertedGenerics.push(ty, result); + } else if (ttv->state == TableState::Unsealed) resultTable->state = TableState::Sealed; @@ -231,8 +234,8 @@ struct PureQuantifier : Substitution { if (auto ftp = get(tp)) { - TypePackId result = arena->addTypePack(TypePackVar{GenericTypePack{}}); - insertedGenericPacks.push_back(result); + TypePackId result = arena->addTypePack(TypePackVar{GenericTypePack{scope}}); + insertedGenericPacks.push(tp, result); return result; } @@ -252,7 +255,7 @@ struct PureQuantifier : Substitution } }; -std::optional quantify(TypeArena* arena, TypeId ty, Scope* scope) +std::optional quantify(TypeArena* arena, TypeId ty, Scope* scope) { PureQuantifier quantifier{arena, scope}; std::optional result = quantifier.substitute(ty); @@ -262,11 +265,20 @@ std::optional quantify(TypeArena* arena, TypeId ty, Scope* scope) FunctionType* ftv = getMutable(*result); LUAU_ASSERT(ftv); ftv->scope = scope; - ftv->generics.insert(ftv->generics.end(), quantifier.insertedGenerics.begin(), quantifier.insertedGenerics.end()); - ftv->genericPacks.insert(ftv->genericPacks.end(), quantifier.insertedGenericPacks.begin(), quantifier.insertedGenericPacks.end()); + + for (auto k : quantifier.insertedGenerics.keys) + { + TypeId g = quantifier.insertedGenerics.pairings[k]; + if (get(g)) + ftv->generics.push_back(g); + } + + for (auto k : quantifier.insertedGenericPacks.keys) + ftv->genericPacks.push_back(quantifier.insertedGenericPacks.pairings[k]); + ftv->hasNoGenerics = ftv->generics.empty() && ftv->genericPacks.empty() && !quantifier.seenGenericType && !quantifier.seenMutableType; - return *result; + return std::optional({*result, std::move(quantifier.insertedGenerics), std::move(quantifier.insertedGenericPacks)}); } } // namespace Luau diff --git a/Analysis/src/Substitution.cpp b/Analysis/src/Substitution.cpp index 6a600b626..40a495935 100644 --- a/Analysis/src/Substitution.cpp +++ b/Analysis/src/Substitution.cpp @@ -78,6 +78,11 @@ static TypeId DEPRECATED_shallowClone(TypeId ty, TypeArena& dest, const TxnLog* { result = dest.addType(NegationType{ntv->ty}); } + else if (const TypeFamilyInstanceType* tfit = get(ty)) + { + TypeFamilyInstanceType clone{tfit->family, tfit->typeArguments, tfit->packArguments}; + result = dest.addType(std::move(clone)); + } else return result; @@ -168,14 +173,27 @@ static TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool a { if (alwaysClone) { - ClassType clone{a.name, a.props, a.parent, a.metatable, a.tags, a.userData, a.definitionModuleName}; - return dest.addType(std::move(clone)); + if (FFlag::LuauTypecheckClassTypeIndexers) + { + ClassType clone{a.name, a.props, a.parent, a.metatable, a.tags, a.userData, a.definitionModuleName, a.indexer}; + return dest.addType(std::move(clone)); + } + else + { + ClassType clone{a.name, a.props, a.parent, a.metatable, a.tags, a.userData, a.definitionModuleName}; + return dest.addType(std::move(clone)); + } } else return ty; } else if constexpr (std::is_same_v) return dest.addType(NegationType{a.ty}); + else if constexpr (std::is_same_v) + { + TypeFamilyInstanceType clone{a.family, a.typeArguments, a.packArguments}; + return dest.addType(std::move(clone)); + } else static_assert(always_false_v, "Non-exhaustive shallowClone switch"); }; @@ -255,6 +273,14 @@ void Tarjan::visitChildren(TypeId ty, int index) for (TypePackId a : petv->packArguments) visitChild(a); } + else if (const TypeFamilyInstanceType* tfit = get(ty)) + { + for (TypeId a : tfit->typeArguments) + visitChild(a); + + for (TypePackId a : tfit->packArguments) + visitChild(a); + } else if (const ClassType* ctv = get(ty); FFlag::LuauClassTypeVarsInSubstitution && ctv) { for (const auto& [name, prop] : ctv->props) @@ -265,6 +291,15 @@ void Tarjan::visitChildren(TypeId ty, int index) if (ctv->metatable) visitChild(*ctv->metatable); + + if (FFlag::LuauTypecheckClassTypeIndexers) + { + if (ctv->indexer) + { + visitChild(ctv->indexer->indexType); + visitChild(ctv->indexer->indexResultType); + } + } } else if (const NegationType* ntv = get(ty)) { @@ -669,6 +704,14 @@ TypePackId Substitution::clone(TypePackId tp) clone.hidden = vtp->hidden; return addTypePack(std::move(clone)); } + else if (const TypeFamilyInstanceTypePack* tfitp = get(tp)) + { + TypeFamilyInstanceTypePack clone{ + tfitp->family, std::vector(tfitp->typeArguments.size()), std::vector(tfitp->packArguments.size())}; + clone.typeArguments.assign(tfitp->typeArguments.begin(), tfitp->typeArguments.end()); + clone.packArguments.assign(tfitp->packArguments.begin(), tfitp->packArguments.end()); + return addTypePack(std::move(clone)); + } else if (FFlag::LuauClonePublicInterfaceLess2) { return addTypePack(*tp); @@ -786,6 +829,14 @@ void Substitution::replaceChildren(TypeId ty) for (TypePackId& a : petv->packArguments) a = replace(a); } + else if (TypeFamilyInstanceType* tfit = getMutable(ty)) + { + for (TypeId& a : tfit->typeArguments) + a = replace(a); + + for (TypePackId& a : tfit->packArguments) + a = replace(a); + } else if (ClassType* ctv = getMutable(ty); FFlag::LuauClassTypeVarsInSubstitution && ctv) { for (auto& [name, prop] : ctv->props) @@ -796,6 +847,15 @@ void Substitution::replaceChildren(TypeId ty) if (ctv->metatable) ctv->metatable = replace(*ctv->metatable); + + if (FFlag::LuauTypecheckClassTypeIndexers) + { + if (ctv->indexer) + { + ctv->indexer->indexType = replace(ctv->indexer->indexType); + ctv->indexer->indexResultType = replace(ctv->indexer->indexResultType); + } + } } else if (NegationType* ntv = getMutable(ty)) { @@ -824,6 +884,14 @@ void Substitution::replaceChildren(TypePackId tp) { vtp->ty = replace(vtp->ty); } + else if (TypeFamilyInstanceTypePack* tfitp = getMutable(tp)) + { + for (TypeId& t : tfitp->typeArguments) + t = replace(t); + + for (TypePackId& t : tfitp->packArguments) + t = replace(t); + } } } // namespace Luau diff --git a/Analysis/src/ToDot.cpp b/Analysis/src/ToDot.cpp index 8d889cb58..f2f15e85e 100644 --- a/Analysis/src/ToDot.cpp +++ b/Analysis/src/ToDot.cpp @@ -257,6 +257,15 @@ void StateDot::visitChildren(TypeId ty, int index) if (ctv->metatable) visitChild(*ctv->metatable, index, "[metatable]"); + + if (FFlag::LuauTypecheckClassTypeIndexers) + { + if (ctv->indexer) + { + visitChild(ctv->indexer->indexType, index, "[index]"); + visitChild(ctv->indexer->indexResultType, index, "[value]"); + } + } } else if (const SingletonType* stv = get(ty)) { diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index ea3ab5775..f5b908e36 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -8,6 +8,7 @@ #include "Luau/TypeInfer.h" #include "Luau/TypePack.h" #include "Luau/Type.h" +#include "Luau/TypeFamily.h" #include "Luau/VisitType.h" #include @@ -16,11 +17,22 @@ LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) /* - * Prefix generic typenames with gen- - * Additionally, free types will be prefixed with free- and suffixed with their level. eg free-a-4 - * Fair warning: Setting this will break a lot of Luau unit tests. + * Enables increasing levels of verbosity for Luau type names when stringifying. + * After level 2, test cases will break unpredictably because a pointer to their + * scope will be included in the stringification of generic and free types. + * + * Supported values: + * + * 0: Disabled, no changes. + * + * 1: Prefix free/generic types with free- and gen-, respectively. Also reveal + * hidden variadic tails. + * + * 2: Suffix free/generic types with their scope depth. + * + * 3: Suffix free/generic types with their scope pointer, if present. */ -LUAU_FASTFLAGVARIABLE(DebugLuauVerboseTypeNames, false) +LUAU_FASTINTVARIABLE(DebugLuauVerboseTypeNames, 0) LUAU_FASTFLAGVARIABLE(DebugLuauToStringNoLexicalSort, false) namespace Luau @@ -223,11 +235,15 @@ struct StringifierState ++count; emit(count); - emit("-"); - char buffer[16]; - uint32_t s = uint32_t(intptr_t(scope) & 0xFFFFFF); - snprintf(buffer, sizeof(buffer), "0x%x", s); - emit(buffer); + + if (FInt::DebugLuauVerboseTypeNames >= 3) + { + emit("-"); + char buffer[16]; + uint32_t s = uint32_t(intptr_t(scope) & 0xFFFFFF); + snprintf(buffer, sizeof(buffer), "0x%x", s); + emit(buffer); + } } void emit(TypeLevel level) @@ -371,11 +387,13 @@ struct TypeStringifier void operator()(TypeId ty, const FreeType& ftv) { state.result.invalid = true; - if (FFlag::DebugLuauVerboseTypeNames) + + if (FInt::DebugLuauVerboseTypeNames >= 1) state.emit("free-"); + state.emit(state.getName(ty)); - if (FFlag::DebugLuauVerboseTypeNames) + if (FInt::DebugLuauVerboseTypeNames >= 2) { state.emit("-"); if (FFlag::DebugLuauDeferredConstraintResolution) @@ -392,6 +410,9 @@ struct TypeStringifier void operator()(TypeId ty, const GenericType& gtv) { + if (FInt::DebugLuauVerboseTypeNames >= 1) + state.emit("gen-"); + if (gtv.explicitName) { state.usedNames.insert(gtv.name); @@ -401,7 +422,7 @@ struct TypeStringifier else state.emit(state.getName(ty)); - if (FFlag::DebugLuauVerboseTypeNames) + if (FInt::DebugLuauVerboseTypeNames >= 2) { state.emit("-"); if (FFlag::DebugLuauDeferredConstraintResolution) @@ -871,6 +892,33 @@ struct TypeStringifier if (parens) state.emit(")"); } + + void operator()(TypeId, const TypeFamilyInstanceType& tfitv) + { + state.emit(tfitv.family->name); + state.emit("<"); + + bool comma = false; + for (TypeId ty : tfitv.typeArguments) + { + if (comma) + state.emit(", "); + + comma = true; + stringify(ty); + } + + for (TypePackId tp : tfitv.packArguments) + { + if (comma) + state.emit(", "); + + comma = true; + stringify(tp); + } + + state.emit(">"); + } }; struct TypePackStringifier @@ -958,7 +1006,7 @@ struct TypePackStringifier if (tp.tail && !isEmpty(*tp.tail)) { TypePackId tail = follow(*tp.tail); - if (auto vtp = get(tail); !vtp || (!FFlag::DebugLuauVerboseTypeNames && !vtp->hidden)) + if (auto vtp = get(tail); !vtp || (FInt::DebugLuauVerboseTypeNames < 1 && !vtp->hidden)) { if (first) first = false; @@ -981,7 +1029,7 @@ struct TypePackStringifier void operator()(TypePackId, const VariadicTypePack& pack) { state.emit("..."); - if (FFlag::DebugLuauVerboseTypeNames && pack.hidden) + if (FInt::DebugLuauVerboseTypeNames >= 1 && pack.hidden) { state.emit("*hidden*"); } @@ -990,6 +1038,9 @@ struct TypePackStringifier void operator()(TypePackId tp, const GenericTypePack& pack) { + if (FInt::DebugLuauVerboseTypeNames >= 1) + state.emit("gen-"); + if (pack.explicitName) { state.usedNames.insert(pack.name); @@ -1001,7 +1052,7 @@ struct TypePackStringifier state.emit(state.getName(tp)); } - if (FFlag::DebugLuauVerboseTypeNames) + if (FInt::DebugLuauVerboseTypeNames >= 2) { state.emit("-"); if (FFlag::DebugLuauDeferredConstraintResolution) @@ -1009,17 +1060,18 @@ struct TypePackStringifier else state.emit(pack.level); } + state.emit("..."); } void operator()(TypePackId tp, const FreeTypePack& pack) { state.result.invalid = true; - if (FFlag::DebugLuauVerboseTypeNames) + if (FInt::DebugLuauVerboseTypeNames >= 1) state.emit("free-"); state.emit(state.getName(tp)); - if (FFlag::DebugLuauVerboseTypeNames) + if (FInt::DebugLuauVerboseTypeNames >= 2) { state.emit("-"); if (FFlag::DebugLuauDeferredConstraintResolution) @@ -1042,6 +1094,33 @@ struct TypePackStringifier state.emit(btp.index); state.emit("*"); } + + void operator()(TypePackId, const TypeFamilyInstanceTypePack& tfitp) + { + state.emit(tfitp.family->name); + state.emit("<"); + + bool comma = false; + for (TypeId p : tfitp.typeArguments) + { + if (comma) + state.emit(", "); + + comma = true; + stringify(p); + } + + for (TypePackId p : tfitp.packArguments) + { + if (comma) + state.emit(", "); + + comma = true; + stringify(p); + } + + state.emit(">"); + } }; void TypeStringifier::stringify(TypePackId tp) @@ -1560,6 +1639,12 @@ std::string toString(const Constraint& constraint, ToStringOptions& opts) } else if constexpr (std::is_same_v) return tos(c.resultPack) + " ~ unpack " + tos(c.sourcePack); + else if constexpr (std::is_same_v) + return "reduce " + tos(c.ty); + else if constexpr (std::is_same_v) + { + return "reduce " + tos(c.tp); + } else static_assert(always_false_v, "Non-exhaustive constraint switch"); }; diff --git a/Analysis/src/TxnLog.cpp b/Analysis/src/TxnLog.cpp index 33554ce90..53dd3b445 100644 --- a/Analysis/src/TxnLog.cpp +++ b/Analysis/src/TxnLog.cpp @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/TxnLog.h" +#include "Luau/Scope.h" #include "Luau/ToString.h" #include "Luau/TypeArena.h" #include "Luau/TypePack.h" @@ -8,6 +9,8 @@ #include #include +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) + namespace Luau { @@ -71,7 +74,11 @@ const TxnLog* TxnLog::empty() void TxnLog::concat(TxnLog rhs) { for (auto& [ty, rep] : rhs.typeVarChanges) + { + if (rep->dead) + continue; typeVarChanges[ty] = std::move(rep); + } for (auto& [tp, rep] : rhs.typePackChanges) typePackChanges[tp] = std::move(rep); @@ -81,7 +88,10 @@ void TxnLog::concatAsIntersections(TxnLog rhs, NotNull arena) { for (auto& [ty, rightRep] : rhs.typeVarChanges) { - if (auto leftRep = typeVarChanges.find(ty)) + if (rightRep->dead) + continue; + + if (auto leftRep = typeVarChanges.find(ty); leftRep && !(*leftRep)->dead) { TypeId leftTy = arena->addType((*leftRep)->pending); TypeId rightTy = arena->addType(rightRep->pending); @@ -97,16 +107,94 @@ void TxnLog::concatAsIntersections(TxnLog rhs, NotNull arena) void TxnLog::concatAsUnion(TxnLog rhs, NotNull arena) { - for (auto& [ty, rightRep] : rhs.typeVarChanges) + if (FFlag::DebugLuauDeferredConstraintResolution) { - if (auto leftRep = typeVarChanges.find(ty)) + /* + * Check for cycles. + * + * We must not combine a log entry that binds 'a to 'b with a log that + * binds 'b to 'a. + * + * Of the two, identify the one with the 'bigger' scope and eliminate the + * entry that rebinds it. + */ + for (const auto& [rightTy, rightRep] : rhs.typeVarChanges) { - TypeId leftTy = arena->addType((*leftRep)->pending); - TypeId rightTy = arena->addType(rightRep->pending); - typeVarChanges[ty]->pending.ty = UnionType{{leftTy, rightTy}}; + if (rightRep->dead) + continue; + + // We explicitly use get_if here because we do not wish to do anything + // if the uncommitted type is already bound to something else. + const FreeType* rf = get_if(&rightTy->ty); + if (!rf) + continue; + + const BoundType* rb = Luau::get(&rightRep->pending); + if (!rb) + continue; + + const TypeId leftTy = rb->boundTo; + const FreeType* lf = get_if(&leftTy->ty); + if (!lf) + continue; + + auto leftRep = typeVarChanges.find(leftTy); + if (!leftRep) + continue; + + if ((*leftRep)->dead) + continue; + + const BoundType* lb = Luau::get(&(*leftRep)->pending); + if (!lb) + continue; + + if (lb->boundTo == rightTy) + { + // leftTy has been bound to rightTy, but rightTy has also been bound + // to leftTy. We find the one that belongs to the more deeply nested + // scope and remove it from the log. + const bool discardLeft = useScopes ? subsumes(lf->scope, rf->scope) : lf->level.subsumes(rf->level); + + if (discardLeft) + (*leftRep)->dead = true; + else + rightRep->dead = true; + } + } + + for (auto& [ty, rightRep] : rhs.typeVarChanges) + { + if (rightRep->dead) + continue; + + if (auto leftRep = typeVarChanges.find(ty); leftRep && !(*leftRep)->dead) + { + TypeId leftTy = arena->addType((*leftRep)->pending); + TypeId rightTy = arena->addType(rightRep->pending); + + if (follow(leftTy) == follow(rightTy)) + typeVarChanges[ty] = std::move(rightRep); + else + typeVarChanges[ty]->pending.ty = UnionType{{leftTy, rightTy}}; + } + else + typeVarChanges[ty] = std::move(rightRep); + } + } + else + { + for (auto& [ty, rightRep] : rhs.typeVarChanges) + { + if (auto leftRep = typeVarChanges.find(ty)) + { + TypeId leftTy = arena->addType((*leftRep)->pending); + TypeId rightTy = arena->addType(rightRep->pending); + typeVarChanges[ty]->pending.ty = UnionType{{leftTy, rightTy}}; + } + else + typeVarChanges[ty] = std::move(rightRep); } - else - typeVarChanges[ty] = std::move(rightRep); } for (auto& [tp, rep] : rhs.typePackChanges) @@ -116,7 +204,10 @@ void TxnLog::concatAsUnion(TxnLog rhs, NotNull arena) void TxnLog::commit() { for (auto& [ty, rep] : typeVarChanges) - asMutable(ty)->reassign(rep.get()->pending); + { + if (!rep->dead) + asMutable(ty)->reassign(rep.get()->pending); + } for (auto& [tp, rep] : typePackChanges) asMutable(tp)->reassign(rep.get()->pending); @@ -135,7 +226,10 @@ TxnLog TxnLog::inverse() TxnLog inversed(sharedSeen); for (auto& [ty, _rep] : typeVarChanges) - inversed.typeVarChanges[ty] = std::make_unique(*ty); + { + if (!_rep->dead) + inversed.typeVarChanges[ty] = std::make_unique(*ty); + } for (auto& [tp, _rep] : typePackChanges) inversed.typePackChanges[tp] = std::make_unique(*tp); @@ -204,7 +298,7 @@ PendingType* TxnLog::queue(TypeId ty) // Explicitly don't look in ancestors. If we have discovered something new // about this type, we don't want to mutate the parent's state. auto& pending = typeVarChanges[ty]; - if (!pending) + if (!pending || (*pending).dead) { pending = std::make_unique(*ty); pending->pending.owningArena = nullptr; @@ -237,7 +331,7 @@ PendingType* TxnLog::pending(TypeId ty) const for (const TxnLog* current = this; current; current = current->parent) { - if (auto it = current->typeVarChanges.find(ty)) + if (auto it = current->typeVarChanges.find(ty); it && !(*it)->dead) return it->get(); } diff --git a/Analysis/src/TypeAttach.cpp b/Analysis/src/TypeAttach.cpp index 86f781650..dba95479c 100644 --- a/Analysis/src/TypeAttach.cpp +++ b/Analysis/src/TypeAttach.cpp @@ -9,6 +9,7 @@ #include "Luau/TypeInfer.h" #include "Luau/TypePack.h" #include "Luau/Type.h" +#include "Luau/TypeFamily.h" #include @@ -362,6 +363,10 @@ class TypeRehydrationVisitor // FIXME: do the same thing we do with ErrorType throw InternalCompilerError("Cannot convert NegationType into AstNode"); } + AstType* operator()(const TypeFamilyInstanceType& tfit) + { + return allocator->alloc(Location(), std::nullopt, AstName{tfit.family->name.c_str()}, std::nullopt, Location()); + } private: Allocator* allocator; @@ -432,6 +437,11 @@ class TypePackRehydrationVisitor return allocator->alloc(Location(), AstName("Unifiable")); } + AstTypePack* operator()(const TypeFamilyInstanceTypePack& tfitp) const + { + return allocator->alloc(Location(), AstName(tfitp.family->name.c_str())); + } + private: Allocator* allocator; SyntheticNames* syntheticNames; diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index 2a2fe69ce..a1f764a4d 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -16,6 +16,7 @@ #include "Luau/TypeReduction.h" #include "Luau/TypeUtils.h" #include "Luau/Unifier.h" +#include "Luau/TypeFamily.h" #include @@ -113,6 +114,13 @@ struct TypeChecker2 return std::nullopt; } + TypeId checkForFamilyInhabitance(TypeId instance, Location location) + { + TxnLog fake{}; + reportErrors(reduceFamilies(instance, location, NotNull{&testArena}, builtinTypes, &fake, true).errors); + return instance; + } + TypePackId lookupPack(AstExpr* expr) { // If a type isn't in the type graph, it probably means that a recursion limit was exceeded. @@ -132,11 +140,11 @@ struct TypeChecker2 // allows us not to think about this very much in the actual typechecking logic. TypeId* ty = module->astTypes.find(expr); if (ty) - return follow(*ty); + return checkForFamilyInhabitance(follow(*ty), expr->location); TypePackId* tp = module->astTypePacks.find(expr); if (tp) - return flattenPack(*tp); + return checkForFamilyInhabitance(flattenPack(*tp), expr->location); return builtinTypes->anyType; } @@ -159,7 +167,7 @@ struct TypeChecker2 TypeId* ty = module->astResolvedTypes.find(annotation); LUAU_ASSERT(ty); - return follow(*ty); + return checkForFamilyInhabitance(follow(*ty), annotation->location); } TypePackId lookupPackAnnotation(AstTypePack* annotation) @@ -311,6 +319,7 @@ struct TypeChecker2 TypePackId actualRetType = reconstructPack(ret->list, *arena); Unifier u{NotNull{&normalizer}, Mode::Strict, stack.back(), ret->location, Covariant}; + u.hideousFixMeGenericsAreActuallyFree = true; u.tryUnify(actualRetType, expectedRetType); const bool ok = u.errors.empty() && u.log.empty(); @@ -989,8 +998,11 @@ struct TypeChecker2 return; } + TxnLog fake{}; + LUAU_ASSERT(ftv); - reportErrors(tryUnify(stack.back(), call->location, ftv->retTypes, expectedRetType, CountMismatch::Context::Return)); + reportErrors(tryUnify(stack.back(), call->location, ftv->retTypes, expectedRetType, CountMismatch::Context::Return, /* genericsOkay */ true)); + reportErrors(reduceFamilies(ftv->retTypes, call->location, NotNull{&testArena}, builtinTypes, &fake, true).errors); auto it = begin(expectedArgTypes); size_t i = 0; @@ -1007,7 +1019,8 @@ struct TypeChecker2 Location argLoc = argLocs.at(i >= argLocs.size() ? argLocs.size() - 1 : i); - reportErrors(tryUnify(stack.back(), argLoc, expectedArg, arg)); + reportErrors(tryUnify(stack.back(), argLoc, expectedArg, arg, CountMismatch::Context::Arg, /* genericsOkay */ true)); + reportErrors(reduceFamilies(arg, argLoc, NotNull{&testArena}, builtinTypes, &fake, true).errors); ++it; ++i; @@ -1018,7 +1031,8 @@ struct TypeChecker2 if (auto tail = it.tail()) { TypePackId remainingArgs = testArena.addTypePack(TypePack{std::move(slice), std::nullopt}); - reportErrors(tryUnify(stack.back(), argLocs.back(), *tail, remainingArgs)); + reportErrors(tryUnify(stack.back(), argLocs.back(), *tail, remainingArgs, CountMismatch::Context::Arg, /* genericsOkay */ true)); + reportErrors(reduceFamilies(remainingArgs, argLocs.back(), NotNull{&testArena}, builtinTypes, &fake, true).errors); } } @@ -1344,7 +1358,7 @@ struct TypeChecker2 else if (get(rightType) || get(rightType)) return rightType; - if ((get(leftType) || get(leftType)) && !isEquality && !isLogical) + if ((get(leftType) || get(leftType) || get(leftType)) && !isEquality && !isLogical) { auto name = getIdentifierOfBaseVar(expr->left); reportError(CannotInferBinaryOperation{expr->op, name, @@ -1591,10 +1605,10 @@ struct TypeChecker2 TypeId computedType = lookupType(expr->expr); // Note: As an optimization, we try 'number <: number | string' first, as that is the more likely case. - if (isSubtype(annotationType, computedType, stack.back())) + if (isSubtype(annotationType, computedType, stack.back(), true)) return; - if (isSubtype(computedType, annotationType, stack.back())) + if (isSubtype(computedType, annotationType, stack.back(), true)) return; reportError(TypesAreUnrelated{computedType, annotationType}, expr->location); @@ -1679,6 +1693,10 @@ struct TypeChecker2 void visit(AstType* ty) { + TypeId* resolvedTy = module->astResolvedTypes.find(ty); + if (resolvedTy) + checkForFamilyInhabitance(follow(*resolvedTy), ty->location); + if (auto t = ty->as()) return visit(t); else if (auto t = ty->as()) @@ -1989,11 +2007,12 @@ struct TypeChecker2 } template - bool isSubtype(TID subTy, TID superTy, NotNull scope) + bool isSubtype(TID subTy, TID superTy, NotNull scope, bool genericsOkay = false) { TypeArena arena; Unifier u{NotNull{&normalizer}, Mode::Strict, scope, Location{}, Covariant}; - u.useScopes = true; + u.hideousFixMeGenericsAreActuallyFree = genericsOkay; + u.enableScopeTests(); u.tryUnify(subTy, superTy); const bool ok = u.errors.empty() && u.log.empty(); @@ -2001,11 +2020,13 @@ struct TypeChecker2 } template - ErrorVec tryUnify(NotNull scope, const Location& location, TID subTy, TID superTy, CountMismatch::Context context = CountMismatch::Arg) + ErrorVec tryUnify(NotNull scope, const Location& location, TID subTy, TID superTy, CountMismatch::Context context = CountMismatch::Arg, + bool genericsOkay = false) { Unifier u{NotNull{&normalizer}, Mode::Strict, scope, location, Covariant}; u.ctx = context; - u.useScopes = true; + u.hideousFixMeGenericsAreActuallyFree = genericsOkay; + u.enableScopeTests(); u.tryUnify(subTy, superTy); return std::move(u.errors); diff --git a/Analysis/src/TypeFamily.cpp b/Analysis/src/TypeFamily.cpp new file mode 100644 index 000000000..1941573b6 --- /dev/null +++ b/Analysis/src/TypeFamily.cpp @@ -0,0 +1,310 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/TypeFamily.h" + +#include "Luau/DenseHash.h" +#include "Luau/VisitType.h" +#include "Luau/TxnLog.h" +#include "Luau/Substitution.h" +#include "Luau/ToString.h" + +LUAU_DYNAMIC_FASTINTVARIABLE(LuauTypeFamilyGraphReductionMaximumSteps, 1'000'000); + +namespace Luau +{ + +struct InstanceCollector : TypeOnceVisitor +{ + std::deque tys; + std::deque tps; + + bool visit(TypeId ty, const TypeFamilyInstanceType&) override + { + // TypeOnceVisitor performs a depth-first traversal in the absence of + // cycles. This means that by pushing to the front of the queue, we will + // try to reduce deeper instances first if we start with the first thing + // in the queue. Consider Add, number>, number>: + // we want to reduce the innermost Add instantiation + // first. + tys.push_front(ty); + return true; + } + + bool visit(TypePackId tp, const TypeFamilyInstanceTypePack&) override + { + // TypeOnceVisitor performs a depth-first traversal in the absence of + // cycles. This means that by pushing to the front of the queue, we will + // try to reduce deeper instances first if we start with the first thing + // in the queue. Consider Add, number>, number>: + // we want to reduce the innermost Add instantiation + // first. + tps.push_front(tp); + return true; + } +}; + +struct FamilyReducer +{ + std::deque queuedTys; + std::deque queuedTps; + DenseHashSet irreducible{nullptr}; + FamilyGraphReductionResult result; + Location location; + NotNull arena; + NotNull builtins; + TxnLog* log = nullptr; + NotNull reducerLog; + bool force = false; + + FamilyReducer(std::deque queuedTys, std::deque queuedTps, Location location, NotNull arena, + NotNull builtins, TxnLog* log = nullptr, bool force = false) + : queuedTys(std::move(queuedTys)) + , queuedTps(std::move(queuedTps)) + , location(location) + , arena(arena) + , builtins(builtins) + , log(log) + , reducerLog(NotNull{log ? log : TxnLog::empty()}) + , force(force) + { + } + + enum class SkipTestResult + { + Irreducible, + Defer, + Okay, + }; + + SkipTestResult testForSkippability(TypeId ty) + { + ty = reducerLog->follow(ty); + + if (reducerLog->is(ty)) + { + if (!irreducible.contains(ty)) + return SkipTestResult::Defer; + else + return SkipTestResult::Irreducible; + } + else if (reducerLog->is(ty)) + { + return SkipTestResult::Irreducible; + } + + return SkipTestResult::Okay; + } + + SkipTestResult testForSkippability(TypePackId ty) + { + ty = reducerLog->follow(ty); + + if (reducerLog->is(ty)) + { + if (!irreducible.contains(ty)) + return SkipTestResult::Defer; + else + return SkipTestResult::Irreducible; + } + else if (reducerLog->is(ty)) + { + return SkipTestResult::Irreducible; + } + + return SkipTestResult::Okay; + } + + template + void replace(T subject, T replacement) + { + if (log) + log->replace(subject, Unifiable::Bound{replacement}); + else + asMutable(subject)->ty.template emplace>(replacement); + + if constexpr (std::is_same_v) + result.reducedTypes.insert(subject); + else if constexpr (std::is_same_v) + result.reducedPacks.insert(subject); + } + + template + void handleFamilyReduction(T subject, TypeFamilyReductionResult reduction) + { + if (reduction.result) + replace(subject, *reduction.result); + else + { + irreducible.insert(subject); + + if (reduction.uninhabited || force) + { + if constexpr (std::is_same_v) + result.errors.push_back(TypeError{location, UninhabitedTypeFamily{subject}}); + else if constexpr (std::is_same_v) + result.errors.push_back(TypeError{location, UninhabitedTypePackFamily{subject}}); + } + else if (!reduction.uninhabited && !force) + { + for (TypeId b : reduction.blockedTypes) + result.blockedTypes.insert(b); + + for (TypePackId b : reduction.blockedPacks) + result.blockedPacks.insert(b); + } + } + } + + bool done() + { + return queuedTys.empty() && queuedTps.empty(); + } + + template + bool testParameters(T subject, const I* tfit) + { + for (TypeId p : tfit->typeArguments) + { + SkipTestResult skip = testForSkippability(p); + + if (skip == SkipTestResult::Irreducible) + { + irreducible.insert(subject); + return false; + } + else if (skip == SkipTestResult::Defer) + { + if constexpr (std::is_same_v) + queuedTys.push_back(subject); + else if constexpr (std::is_same_v) + queuedTps.push_back(subject); + + return false; + } + } + + for (TypePackId p : tfit->packArguments) + { + SkipTestResult skip = testForSkippability(p); + + if (skip == SkipTestResult::Irreducible) + { + irreducible.insert(subject); + return false; + } + else if (skip == SkipTestResult::Defer) + { + if constexpr (std::is_same_v) + queuedTys.push_back(subject); + else if constexpr (std::is_same_v) + queuedTps.push_back(subject); + + return false; + } + } + + return true; + } + + void stepType() + { + TypeId subject = reducerLog->follow(queuedTys.front()); + queuedTys.pop_front(); + + if (irreducible.contains(subject)) + return; + + if (const TypeFamilyInstanceType* tfit = reducerLog->get(subject)) + { + if (!testParameters(subject, tfit)) + return; + + TypeFamilyReductionResult result = tfit->family->reducer(tfit->typeArguments, tfit->packArguments, arena, builtins, reducerLog); + handleFamilyReduction(subject, result); + } + } + + void stepPack() + { + TypePackId subject = reducerLog->follow(queuedTps.front()); + queuedTps.pop_front(); + + if (irreducible.contains(subject)) + return; + + if (const TypeFamilyInstanceTypePack* tfit = reducerLog->get(subject)) + { + if (!testParameters(subject, tfit)) + return; + + TypeFamilyReductionResult result = + tfit->family->reducer(tfit->typeArguments, tfit->packArguments, arena, builtins, reducerLog); + handleFamilyReduction(subject, result); + } + } + + void step() + { + if (!queuedTys.empty()) + stepType(); + else if (!queuedTps.empty()) + stepPack(); + } +}; + +static FamilyGraphReductionResult reduceFamiliesInternal(std::deque queuedTys, std::deque queuedTps, Location location, + NotNull arena, NotNull builtins, TxnLog* log, bool force) +{ + FamilyReducer reducer{std::move(queuedTys), std::move(queuedTps), location, arena, builtins, log, force}; + int iterationCount = 0; + + while (!reducer.done()) + { + reducer.step(); + + ++iterationCount; + if (iterationCount > DFInt::LuauTypeFamilyGraphReductionMaximumSteps) + { + reducer.result.errors.push_back(TypeError{location, CodeTooComplex{}}); + break; + } + } + + return std::move(reducer.result); +} + +FamilyGraphReductionResult reduceFamilies( + TypeId entrypoint, Location location, NotNull arena, NotNull builtins, TxnLog* log, bool force) +{ + InstanceCollector collector; + + try + { + collector.traverse(entrypoint); + } + catch (RecursionLimitException&) + { + return FamilyGraphReductionResult{}; + } + + return reduceFamiliesInternal(std::move(collector.tys), std::move(collector.tps), location, arena, builtins, log, force); +} + +FamilyGraphReductionResult reduceFamilies( + TypePackId entrypoint, Location location, NotNull arena, NotNull builtins, TxnLog* log, bool force) +{ + InstanceCollector collector; + + try + { + collector.traverse(entrypoint); + } + catch (RecursionLimitException&) + { + return FamilyGraphReductionResult{}; + } + + return reduceFamiliesInternal(std::move(collector.tys), std::move(collector.tps), location, arena, builtins, log, force); +} + +} // namespace Luau diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 1ccba91e7..94c64ee25 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -41,6 +41,7 @@ LUAU_FASTFLAG(LuauOccursIsntAlwaysFailure) LUAU_FASTFLAGVARIABLE(LuauTypecheckTypeguards, false) LUAU_FASTFLAGVARIABLE(LuauTinyControlFlowAnalysis, false) LUAU_FASTFLAG(LuauRequirePathTrueModuleName) +LUAU_FASTFLAGVARIABLE(LuauTypecheckClassTypeIndexers, false) namespace Luau { @@ -2104,6 +2105,23 @@ std::optional TypeChecker::getIndexTypeFromTypeImpl( const Property* prop = lookupClassProp(cls, name); if (prop) return prop->type(); + + if (FFlag::LuauTypecheckClassTypeIndexers) + { + if (auto indexer = cls->indexer) + { + // TODO: Property lookup should work with string singletons or unions thereof as the indexer key type. + ErrorVec errors = tryUnify(stringType, indexer->indexType, scope, location); + + if (errors.empty()) + return indexer->indexResultType; + + if (addErrors) + reportError(location, UnknownProperty{type, name}); + + return std::nullopt; + } + } } else if (const UnionType* utv = get(type)) { @@ -3295,14 +3313,38 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex } else if (const ClassType* lhsClass = get(lhs)) { - const Property* prop = lookupClassProp(lhsClass, name); - if (!prop) + if (FFlag::LuauTypecheckClassTypeIndexers) { + if (const Property* prop = lookupClassProp(lhsClass, name)) + { + return prop->type(); + } + + if (auto indexer = lhsClass->indexer) + { + Unifier state = mkUnifier(scope, expr.location); + state.tryUnify(stringType, indexer->indexType); + if (state.errors.empty()) + { + state.log.commit(); + return indexer->indexResultType; + } + } + reportError(TypeError{expr.location, UnknownProperty{lhs, name}}); return errorRecoveryType(scope); } + else + { + const Property* prop = lookupClassProp(lhsClass, name); + if (!prop) + { + reportError(TypeError{expr.location, UnknownProperty{lhs, name}}); + return errorRecoveryType(scope); + } - return prop->type(); + return prop->type(); + } } else if (get(lhs)) { @@ -3344,23 +3386,57 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex { if (const ClassType* exprClass = get(exprType)) { - const Property* prop = lookupClassProp(exprClass, value->value.data); - if (!prop) + if (FFlag::LuauTypecheckClassTypeIndexers) { + if (const Property* prop = lookupClassProp(exprClass, value->value.data)) + { + return prop->type(); + } + + if (auto indexer = exprClass->indexer) + { + unify(stringType, indexer->indexType, scope, expr.index->location); + return indexer->indexResultType; + } + reportError(TypeError{expr.location, UnknownProperty{exprType, value->value.data}}); return errorRecoveryType(scope); } - return prop->type(); + else + { + const Property* prop = lookupClassProp(exprClass, value->value.data); + if (!prop) + { + reportError(TypeError{expr.location, UnknownProperty{exprType, value->value.data}}); + return errorRecoveryType(scope); + } + return prop->type(); + } } } - else if (FFlag::LuauAllowIndexClassParameters) + else { - if (const ClassType* exprClass = get(exprType)) + if (FFlag::LuauTypecheckClassTypeIndexers) { - if (isNonstrictMode()) - return unknownType; - reportError(TypeError{expr.location, DynamicPropertyLookupOnClassesUnsafe{exprType}}); - return errorRecoveryType(scope); + if (const ClassType* exprClass = get(exprType)) + { + if (auto indexer = exprClass->indexer) + { + unify(indexType, indexer->indexType, scope, expr.index->location); + return indexer->indexResultType; + } + } + } + + if (FFlag::LuauAllowIndexClassParameters) + { + if (const ClassType* exprClass = get(exprType)) + { + if (isNonstrictMode()) + return unknownType; + reportError(TypeError{expr.location, DynamicPropertyLookupOnClassesUnsafe{exprType}}); + return errorRecoveryType(scope); + } } } diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 6047a49b1..56be40471 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -12,6 +12,7 @@ #include "Luau/TypeUtils.h" #include "Luau/Type.h" #include "Luau/VisitType.h" +#include "Luau/TypeFamily.h" #include @@ -20,6 +21,7 @@ LUAU_FASTFLAG(LuauErrorRecoveryType) LUAU_FASTFLAGVARIABLE(LuauInstantiateInSubtyping, false) LUAU_FASTFLAGVARIABLE(LuauUninhabitedSubAnything2, false) LUAU_FASTFLAGVARIABLE(LuauVariadicAnyCanBeGeneric, false) +LUAU_FASTFLAGVARIABLE(LuauUnifyTwoOptions, false) LUAU_FASTFLAGVARIABLE(LuauMaintainScopesInUnifier, false) LUAU_FASTFLAGVARIABLE(LuauTransitiveSubtyping, false) LUAU_FASTFLAGVARIABLE(LuauOccursIsntAlwaysFailure, false) @@ -439,6 +441,30 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool if (superTy == subTy) return; + if (log.get(superTy)) + { + // We do not report errors from reducing here. This is because we will + // "double-report" errors in some cases, like when trying to unify + // identical type family instantiations like Add with + // Add. + reduceFamilies(superTy, location, NotNull(types), builtinTypes, &log); + superTy = log.follow(superTy); + } + + if (log.get(subTy)) + { + reduceFamilies(subTy, location, NotNull(types), builtinTypes, &log); + subTy = log.follow(subTy); + } + + // If we can't reduce the families down and we still have type family types + // here, we are stuck. Nothing meaningful can be done here. We don't wish to + // report an error, either. + if (log.get(superTy) || log.get(subTy)) + { + return; + } + auto superFree = log.getMutable(superTy); auto subFree = log.getMutable(subTy); @@ -509,6 +535,49 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool return; } + if (hideousFixMeGenericsAreActuallyFree) + { + auto superGeneric = log.getMutable(superTy); + auto subGeneric = log.getMutable(subTy); + + if (superGeneric && subGeneric && subsumes(useScopes, superGeneric, subGeneric)) + { + if (!occursCheck(subTy, superTy, /* reversed = */ false)) + log.replace(subTy, BoundType(superTy)); + + return; + } + else if (superGeneric && subGeneric) + { + if (!occursCheck(superTy, subTy, /* reversed = */ true)) + log.replace(superTy, BoundType(subTy)); + + return; + } + else if (superGeneric) + { + if (!occursCheck(superTy, subTy, /* reversed = */ true)) + { + Widen widen{types, builtinTypes}; + log.replace(superTy, BoundType(widen(subTy))); + } + + return; + } + else if (subGeneric) + { + // Normally, if the subtype is free, it should not be bound to any, unknown, or error types. + // But for bug compatibility, we'll only apply this rule to unknown. Doing this will silence cascading type errors. + if (log.get(superTy)) + return; + + if (!occursCheck(subTy, superTy, /* reversed = */ false)) + log.replace(subTy, BoundType(superTy)); + + return; + } + } + if (log.get(superTy)) return tryUnifyWithAny(subTy, builtinTypes->anyType); @@ -687,8 +756,93 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool log.popSeen(superTy, subTy); } +/* + * If the passed type is an option, strip nil out. + * + * There is an important subtlety to be observed here: + * + * We want to do a peephole fix to unify the subtype relation A? <: B? where we + * instead peel off the options and relate A <: B instead, but only works if we + * are certain that neither A nor B are themselves optional. + * + * For instance, if we want to test that (boolean?)? <: boolean?, we must peel + * off both layers of optionality from the subTy. + * + * We must also handle unions that have more than two choices. + * + * eg (string | nil)? <: boolean? + */ +static std::optional unwrapOption(NotNull builtinTypes, NotNull arena, const TxnLog& log, TypeId ty, DenseHashSet& seen) +{ + if (seen.find(ty)) + return std::nullopt; + seen.insert(ty); + + const UnionType* ut = get(follow(ty)); + if (!ut) + return std::nullopt; + + if (2 == ut->options.size()) + { + if (isNil(follow(ut->options[0]))) + { + std::optional doubleUnwrapped = unwrapOption(builtinTypes, arena, log, ut->options[1], seen); + return doubleUnwrapped.value_or(ut->options[1]); + } + if (isNil(follow(ut->options[1]))) + { + std::optional doubleUnwrapped = unwrapOption(builtinTypes, arena, log, ut->options[0], seen); + return doubleUnwrapped.value_or(ut->options[0]); + } + } + + std::set newOptions; + bool found = false; + for (TypeId t : ut) + { + t = log.follow(t); + if (isNil(t)) + { + found = true; + continue; + } + else + newOptions.insert(t); + } + + if (!found) + return std::nullopt; + else if (newOptions.empty()) + return builtinTypes->neverType; + else if (1 == newOptions.size()) + return *begin(newOptions); + else + return arena->addType(UnionType{std::vector(begin(newOptions), end(newOptions))}); +} + +static std::optional unwrapOption(NotNull builtinTypes, NotNull arena, const TxnLog& log, TypeId ty) +{ + DenseHashSet seen{nullptr}; + + return unwrapOption(builtinTypes, arena, log, ty, seen); +} + + void Unifier::tryUnifyUnionWithType(TypeId subTy, const UnionType* subUnion, TypeId superTy) { + // Peephole fix: A? <: B? if A <: B + // + // This works around issues that can arise if A or B is free. We do not + // want either of those types to be bound to nil. + if (FFlag::LuauUnifyTwoOptions) + { + if (auto subOption = unwrapOption(builtinTypes, NotNull{types}, log, subTy)) + { + if (auto superOption = unwrapOption(builtinTypes, NotNull{types}, log, superTy)) + return tryUnify_(*subOption, *superOption); + } + } + // A | B <: T if and only if A <: T and B <: T bool failed = false; bool errorsSuppressed = true; @@ -1205,6 +1359,25 @@ void Unifier::tryUnifyNormalizedTypes( const ClassType* superCtv = get(superClass); LUAU_ASSERT(superCtv); + if (FFlag::LuauUnifyTwoOptions) + { + if (variance == Invariant) + { + if (subCtv == superCtv) + { + found = true; + + /* + * The only way we could care about superNegations is if + * one of them was equal to superCtv. However, + * normalization ensures that this is impossible. + */ + } + else + continue; + } + } + if (isSubclass(subCtv, superCtv)) { found = true; @@ -1518,6 +1691,12 @@ struct WeirdIter } }; +void Unifier::enableScopeTests() +{ + useScopes = true; + log.useScopes = true; +} + ErrorVec Unifier::canUnify(TypeId subTy, TypeId superTy) { Unifier s = makeChildUnifier(); @@ -1597,6 +1776,21 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal log.replace(subTp, Unifiable::Bound(superTp)); } } + else if (hideousFixMeGenericsAreActuallyFree && log.getMutable(superTp)) + { + if (!occursCheck(superTp, subTp, /* reversed = */ true)) + { + Widen widen{types, builtinTypes}; + log.replace(superTp, Unifiable::Bound(widen(subTp))); + } + } + else if (hideousFixMeGenericsAreActuallyFree && log.getMutable(subTp)) + { + if (!occursCheck(subTp, superTp, /* reversed = */ false)) + { + log.replace(subTp, Unifiable::Bound(superTp)); + } + } else if (log.getMutable(superTp)) tryUnifyWithAny(subTp, superTp); else if (log.getMutable(subTp)) @@ -2611,7 +2805,10 @@ void Unifier::tryUnifyVariadics(TypePackId subTp, TypePackId superTp, bool rever } else if (get(tail)) { - reportError(location, GenericError{"Cannot unify variadic and generic packs"}); + if (!hideousFixMeGenericsAreActuallyFree) + reportError(location, GenericError{"Cannot unify variadic and generic packs"}); + else + log.replace(tail, BoundTypePack{superTp}); } else if (get(tail)) { @@ -2732,7 +2929,7 @@ std::optional Unifier::findTablePropertyRespectingMeta(TypeId lhsType, N TxnLog Unifier::combineLogsIntoIntersection(std::vector logs) { LUAU_ASSERT(FFlag::DebugLuauDeferredConstraintResolution); - TxnLog result; + TxnLog result(useScopes); for (TxnLog& log : logs) result.concatAsIntersections(std::move(log), NotNull{types}); return result; @@ -2741,7 +2938,7 @@ TxnLog Unifier::combineLogsIntoIntersection(std::vector logs) TxnLog Unifier::combineLogsIntoUnion(std::vector logs) { LUAU_ASSERT(FFlag::DebugLuauDeferredConstraintResolution); - TxnLog result; + TxnLog result(useScopes); for (TxnLog& log : logs) result.concatAsUnion(std::move(log), NotNull{types}); return result; @@ -2807,7 +3004,7 @@ bool Unifier::occursCheck(DenseHashSet& seen, TypeId needle, TypeId hays if (log.getMutable(needle)) return false; - if (!log.getMutable(needle)) + if (!log.getMutable(needle) && !(hideousFixMeGenericsAreActuallyFree && log.is(needle))) ice("Expected needle to be free"); if (needle == haystack) @@ -2821,7 +3018,7 @@ bool Unifier::occursCheck(DenseHashSet& seen, TypeId needle, TypeId hays return true; } - if (log.getMutable(haystack)) + if (log.getMutable(haystack) || (hideousFixMeGenericsAreActuallyFree && log.is(haystack))) return false; else if (auto a = log.getMutable(haystack)) { @@ -2865,7 +3062,7 @@ bool Unifier::occursCheck(DenseHashSet& seen, TypePackId needle, Typ if (log.getMutable(needle)) return false; - if (!log.getMutable(needle)) + if (!log.getMutable(needle) && !(hideousFixMeGenericsAreActuallyFree && log.is(needle))) ice("Expected needle pack to be free"); RecursionLimiter _ra(&sharedState.counters.recursionCount, sharedState.counters.recursionLimit); @@ -2900,7 +3097,10 @@ Unifier Unifier::makeChildUnifier() Unifier u = Unifier{normalizer, mode, scope, location, variance, &log}; u.normalize = normalize; u.checkInhabited = checkInhabited; - u.useScopes = useScopes; + + if (useScopes) + u.enableScopeTests(); + return u; } diff --git a/Ast/include/Luau/Ast.h b/Ast/include/Luau/Ast.h index 20158e8eb..a486ad0f9 100644 --- a/Ast/include/Luau/Ast.h +++ b/Ast/include/Luau/Ast.h @@ -8,6 +8,7 @@ #include #include +#include namespace Luau { diff --git a/Ast/src/StringUtils.cpp b/Ast/src/StringUtils.cpp index 343c553c3..d7099a314 100644 --- a/Ast/src/StringUtils.cpp +++ b/Ast/src/StringUtils.cpp @@ -7,6 +7,7 @@ #include #include #include +#include namespace Luau { diff --git a/CLI/Repl.cpp b/CLI/Repl.cpp index 4303364cd..a585a73a2 100644 --- a/CLI/Repl.cpp +++ b/CLI/Repl.cpp @@ -27,6 +27,10 @@ #include #endif +#ifdef __linux__ +#include +#endif + #ifdef CALLGRIND #include #endif @@ -865,6 +869,7 @@ int replMain(int argc, char** argv) int profile = 0; bool coverage = false; bool interactive = false; + bool codegenPerf = false; // Set the mode if the user has explicitly specified one. int argStart = 1; @@ -962,6 +967,11 @@ int replMain(int argc, char** argv) { codegen = true; } + else if (strcmp(argv[i], "--codegen-perf") == 0) + { + codegen = true; + codegenPerf = true; + } else if (strcmp(argv[i], "--coverage") == 0) { coverage = true; @@ -998,6 +1008,24 @@ int replMain(int argc, char** argv) } #endif + if (codegenPerf) + { +#if __linux__ + char path[128]; + snprintf(path, sizeof(path), "/tmp/perf-%d.map", getpid()); + + // note, there's no need to close the log explicitly as it will be closed when the process exits + FILE* codegenPerfLog = fopen(path, "w"); + + Luau::CodeGen::setPerfLog(codegenPerfLog, [](void* context, uintptr_t addr, unsigned size, const char* symbol) { + fprintf(static_cast(context), "%016lx %08x %s\n", long(addr), size, symbol); + }); +#else + fprintf(stderr, "--codegen-perf option is only supported on Linux\n"); + return 1; +#endif + } + const std::vector files = getSourceFiles(argc, argv); if (mode == CliMode::Unknown) { diff --git a/CMakeLists.txt b/CMakeLists.txt index 6e15e5f88..b3b1573ac 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -192,6 +192,7 @@ if(LUAU_BUILD_CLI) find_library(LIBPTHREAD pthread) if (LIBPTHREAD) target_link_libraries(Luau.Repl.CLI PRIVATE pthread) + target_link_libraries(Luau.Analyze.CLI PRIVATE pthread) endif() endif() diff --git a/CodeGen/include/Luau/AssemblyBuilderA64.h b/CodeGen/include/Luau/AssemblyBuilderA64.h index 26be11c54..e7733cd2c 100644 --- a/CodeGen/include/Luau/AssemblyBuilderA64.h +++ b/CodeGen/include/Luau/AssemblyBuilderA64.h @@ -56,7 +56,7 @@ class AssemblyBuilderA64 void eor(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, int shift = 0); void bic(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, int shift = 0); void tst(RegisterA64 src1, RegisterA64 src2, int shift = 0); - void mvn(RegisterA64 dst, RegisterA64 src); + void mvn_(RegisterA64 dst, RegisterA64 src); // Bitwise with immediate // Note: immediate must have a single contiguous sequence of 1 bits set of length 1..31 @@ -199,7 +199,7 @@ class AssemblyBuilderA64 void placeR1(const char* name, RegisterA64 dst, RegisterA64 src, uint32_t op); void placeI12(const char* name, RegisterA64 dst, RegisterA64 src1, int src2, uint8_t op); void placeI16(const char* name, RegisterA64 dst, int src, uint8_t op, int shift = 0); - void placeA(const char* name, RegisterA64 dst, AddressA64 src, uint8_t op, uint8_t size, int sizelog); + void placeA(const char* name, RegisterA64 dst, AddressA64 src, uint16_t opsize, int sizelog); void placeB(const char* name, Label& label, uint8_t op); void placeBC(const char* name, Label& label, uint8_t op, uint8_t cond); void placeBCR(const char* name, Label& label, uint8_t op, RegisterA64 cond); diff --git a/CodeGen/include/Luau/AssemblyBuilderX64.h b/CodeGen/include/Luau/AssemblyBuilderX64.h index e162cd3e4..a372bf911 100644 --- a/CodeGen/include/Luau/AssemblyBuilderX64.h +++ b/CodeGen/include/Luau/AssemblyBuilderX64.h @@ -2,6 +2,7 @@ #pragma once #include "Luau/Common.h" +#include "Luau/DenseHash.h" #include "Luau/Label.h" #include "Luau/ConditionX64.h" #include "Luau/OperandX64.h" @@ -250,6 +251,8 @@ class AssemblyBuilderX64 std::vector is uninhabited" == toString(result.errors[0])); + CHECK("Type family instance Swap is uninhabited" == toString(result.errors[1])); +} + +TEST_CASE_FIXTURE(FamilyFixture, "resolve_deep_families") +{ + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + + CheckResult result = check(R"( + local x: Swap>> + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK("number" == toString(requireType("x"))); +} + +TEST_CASE_FIXTURE(FamilyFixture, "unsolvable_family") +{ + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + + CheckResult result = check(R"( + local impossible: (Swap) -> Swap> + local a = impossible(123) + local b = impossible(true) + )"); + + LUAU_REQUIRE_ERROR_COUNT(4, result); + for (size_t i = 0; i < 4; ++i) + { + CHECK(toString(result.errors[i]) == "Type family instance Swap is uninhabited"); + } +} + +TEST_CASE_FIXTURE(FamilyFixture, "table_internal_families") +{ + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + + CheckResult result = check(R"( + local t: ({T}) -> {Swap} + local a = t({1, 2, 3}) + local b = t({"a", "b", "c"}) + local c = t({true, false, true}) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK(toString(requireType("a")) == "{string}"); + CHECK(toString(requireType("b")) == "{number}"); + CHECK(toString(requireType("c")) == "{Swap}"); + CHECK(toString(result.errors[0]) == "Type family instance Swap is uninhabited"); +} + +TEST_CASE_FIXTURE(FamilyFixture, "function_internal_families") +{ + // This test is broken right now, but it's not because of type families. See + // CLI-71143. + + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + + CheckResult result = check(R"( + local f0: (T) -> (() -> T) + local f: (T) -> (() -> Swap) + local a = f(1) + local b = f("a") + local c = f(true) + local d = f0(1) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK(toString(requireType("a")) == "() -> string"); + CHECK(toString(requireType("b")) == "() -> number"); + CHECK(toString(requireType("c")) == "() -> Swap"); + CHECK(toString(result.errors[0]) == "Type family instance Swap is uninhabited"); +} + +TEST_SUITE_END(); diff --git a/tests/TypeInfer.classes.test.cpp b/tests/TypeInfer.classes.test.cpp index becc88aa6..607fc40aa 100644 --- a/tests/TypeInfer.classes.test.cpp +++ b/tests/TypeInfer.classes.test.cpp @@ -481,4 +481,150 @@ TEST_CASE_FIXTURE(ClassFixture, "callable_classes") CHECK_EQ("number", toString(requireType("y"))); } +TEST_CASE_FIXTURE(ClassFixture, "indexable_classes") +{ + // Test reading from an index + ScopedFastFlag LuauTypecheckClassTypeIndexers("LuauTypecheckClassTypeIndexers", true); + { + CheckResult result = check(R"( + local x : IndexableClass + local y = x.stringKey + )"); + LUAU_REQUIRE_NO_ERRORS(result); + } + { + CheckResult result = check(R"( + local x : IndexableClass + local y = x["stringKey"] + )"); + LUAU_REQUIRE_NO_ERRORS(result); + } + { + CheckResult result = check(R"( + local x : IndexableClass + local str : string + local y = x[str] -- Index with a non-const string + )"); + LUAU_REQUIRE_NO_ERRORS(result); + } + { + CheckResult result = check(R"( + local x : IndexableClass + local y = x[7] -- Index with a numeric key + )"); + LUAU_REQUIRE_NO_ERRORS(result); + } + + // Test writing to an index + { + CheckResult result = check(R"( + local x : IndexableClass + x.stringKey = 42 + )"); + LUAU_REQUIRE_NO_ERRORS(result); + } + { + CheckResult result = check(R"( + local x : IndexableClass + x["stringKey"] = 42 + )"); + LUAU_REQUIRE_NO_ERRORS(result); + } + { + CheckResult result = check(R"( + local x : IndexableClass + local str : string + x[str] = 42 -- Index with a non-const string + )"); + LUAU_REQUIRE_NO_ERRORS(result); + } + { + CheckResult result = check(R"( + local x : IndexableClass + x[1] = 42 -- Index with a numeric key + )"); + LUAU_REQUIRE_NO_ERRORS(result); + } + + // Try to index the class using an invalid type for the key (key type is 'number | string'.) + { + CheckResult result = check(R"( + local x : IndexableClass + local y = x[true] + )"); + CHECK_EQ( + toString(result.errors[0]), "Type 'boolean' could not be converted into 'number | string'; none of the union options are compatible"); + } + { + CheckResult result = check(R"( + local x : IndexableClass + x[true] = 42 + )"); + CHECK_EQ( + toString(result.errors[0]), "Type 'boolean' could not be converted into 'number | string'; none of the union options are compatible"); + } + + // Test type checking for the return type of the indexer (i.e. a number) + { + CheckResult result = check(R"( + local x : IndexableClass + x.key = "string value" + )"); + CHECK_EQ(toString(result.errors[0]), "Type 'string' could not be converted into 'number'"); + } + { + CheckResult result = check(R"( + local x : IndexableClass + local str : string = x.key + )"); + CHECK_EQ(toString(result.errors[0]), "Type 'number' could not be converted into 'string'"); + } + + // Check that we string key are rejected if the indexer's key type is not compatible with string + { + CheckResult result = check(R"( + local x : IndexableNumericKeyClass + x.key = 1 + )"); + CHECK_EQ(toString(result.errors.at(0)), "Key 'key' not found in class 'IndexableNumericKeyClass'"); + } + { + CheckResult result = check(R"( + local x : IndexableNumericKeyClass + x["key"] = 1 + )"); + CHECK_EQ(toString(result.errors[0]), "Type 'string' could not be converted into 'number'"); + } + { + CheckResult result = check(R"( + local x : IndexableNumericKeyClass + local str : string + x[str] = 1 -- Index with a non-const string + )"); + CHECK_EQ(toString(result.errors[0]), "Type 'string' could not be converted into 'number'"); + } + { + CheckResult result = check(R"( + local x : IndexableNumericKeyClass + local y = x.key + )"); + CHECK_EQ(toString(result.errors[0]), "Key 'key' not found in class 'IndexableNumericKeyClass'"); + } + { + CheckResult result = check(R"( + local x : IndexableNumericKeyClass + local y = x["key"] + )"); + CHECK_EQ(toString(result.errors[0]), "Type 'string' could not be converted into 'number'"); + } + { + CheckResult result = check(R"( + local x : IndexableNumericKeyClass + local str : string + local y = x[str] -- Index with a non-const string + )"); + CHECK_EQ(toString(result.errors[0]), "Type 'string' could not be converted into 'number'"); + } +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index 94cf4b326..9712f0975 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -1952,4 +1952,40 @@ TEST_CASE_FIXTURE(Fixture, "instantiated_type_packs_must_have_a_non_null_scope") LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "inner_frees_become_generic_in_dcr") +{ + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + + CheckResult result = check(R"( + function f(x) + local z = x + return x + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + std::optional ty = findTypeAtPosition(Position{3, 19}); + REQUIRE(ty); + CHECK(get(*ty)); +} + +TEST_CASE_FIXTURE(Fixture, "function_exprs_are_generalized_at_signature_scope_not_enclosing") +{ + CheckResult result = check(R"( + local foo + local bar + + -- foo being a function expression is deliberate: the bug we're testing + -- only existed for function expressions, not for function statements. + foo = function(a) + return bar + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + // note that b is not in the generic list; it is free, the unconstrained type of `bar`. + CHECK(toString(requireType("foo")) == "(a) -> b"); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.intersectionTypes.test.cpp b/tests/TypeInfer.intersectionTypes.test.cpp index b682e5f6c..99abf711e 100644 --- a/tests/TypeInfer.intersectionTypes.test.cpp +++ b/tests/TypeInfer.intersectionTypes.test.cpp @@ -550,6 +550,8 @@ TEST_CASE_FIXTURE(Fixture, "intersection_of_tables") TEST_CASE_FIXTURE(Fixture, "intersection_of_tables_with_top_properties") { + ScopedFastFlag sff{"LuauUnifyTwoOptions", true}; + CheckResult result = check(R"( local x : { p : number?, q : any } & { p : unknown, q : string? } local y : { p : number?, q : string? } = x -- OK @@ -563,27 +565,19 @@ TEST_CASE_FIXTURE(Fixture, "intersection_of_tables_with_top_properties") CHECK_EQ(toString(result.errors[0]), "Type '{| p: number?, q: string? |}' could not be converted into '{| p: string?, q: number? |}'\n" "caused by:\n" - " Property 'p' is not compatible. Type 'number?' could not be converted into 'string?'\n" - "caused by:\n" - " Not all union options are compatible. Type 'number' could not be converted into 'string?'\n" - "caused by:\n" - " None of the union options are compatible. For example: Type 'number' could not be converted into 'string' in an invariant context"); + " Property 'p' is not compatible. Type 'number' could not be converted into 'string' in an invariant context"); CHECK_EQ(toString(result.errors[1]), "Type '{| p: number?, q: string? |}' could not be converted into '{| p: string?, q: number? |}'\n" "caused by:\n" - " Property 'q' is not compatible. Type 'string?' could not be converted into 'number?'\n" - "caused by:\n" - " Not all union options are compatible. Type 'string' could not be converted into 'number?'\n" - "caused by:\n" - " None of the union options are compatible. For example: Type 'string' could not be converted into 'number' in an invariant context"); + " Property 'q' is not compatible. Type 'string' could not be converted into 'number' in an invariant context"); } else { LUAU_REQUIRE_ERROR_COUNT(1, result); CHECK_EQ(toString(result.errors[0]), - "Type '{| p: number?, q: any |} & {| p: unknown, q: string? |}' could not be converted into '{| p: string?, " - "q: number? |}'; none of the intersection parts are compatible"); + "Type '{| p: number?, q: any |} & {| p: unknown, q: string? |}' could not be converted into " + "'{| p: string?, q: number? |}'; none of the intersection parts are compatible"); } } diff --git a/tests/TypeInfer.operators.test.cpp b/tests/TypeInfer.operators.test.cpp index d224195c9..90436ce7d 100644 --- a/tests/TypeInfer.operators.test.cpp +++ b/tests/TypeInfer.operators.test.cpp @@ -1134,7 +1134,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "luau_polyfill_is_array_simplified") return false end return true - end + end )"); LUAU_REQUIRE_NO_ERRORS(result); @@ -1179,4 +1179,30 @@ end LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(BuiltinsFixture, "luau-polyfill.Array.startswith") +{ + // This test also exercises whether the binary operator == passes the correct expected type + // to it's l,r operands + CheckResult result = check(R"( +--!strict +local function startsWith(value: string, substring: string, position: number?): boolean + -- Luau FIXME: we have to use a tmp variable, as Luau doesn't understand the logic below narrow position to `number` + local position_ + if position == nil or position < 1 then + position_ = 1 + else + position_ = position + end + + return value:find(substring, position_, true) == position_ +end + +return startsWith + + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + + TEST_SUITE_END(); diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index 58acef222..606a4f4af 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -482,41 +482,6 @@ TEST_CASE_FIXTURE(Fixture, "dcr_can_partially_dispatch_a_constraint") CHECK("(a, number) -> ()" == toString(requireType("prime_iter"))); } -TEST_CASE_FIXTURE(Fixture, "free_options_cannot_be_unified_together") -{ - ScopedFastFlag sff[] = { - {"LuauTransitiveSubtyping", true}, - }; - - TypeArena arena; - TypeId nilType = builtinTypes->nilType; - - std::unique_ptr scope = std::make_unique(builtinTypes->anyTypePack); - - TypeId free1 = arena.addType(FreeType{scope.get()}); - TypeId option1 = arena.addType(UnionType{{nilType, free1}}); - - TypeId free2 = arena.addType(FreeType{scope.get()}); - TypeId option2 = arena.addType(UnionType{{nilType, free2}}); - - InternalErrorReporter iceHandler; - UnifierSharedState sharedState{&iceHandler}; - Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}}; - Unifier u{NotNull{&normalizer}, Mode::Strict, NotNull{scope.get()}, Location{}, Variance::Covariant}; - - u.tryUnify(option1, option2); - - CHECK(!u.failure); - - u.log.commit(); - - ToStringOptions opts; - CHECK("a?" == toString(option1, opts)); - - // CHECK("a?" == toString(option2, opts)); // This should hold, but does not. - CHECK("b?" == toString(option2, opts)); // This should not hold. -} - TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_loop_with_zero_iterators") { ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", false}; diff --git a/tests/TypeInfer.singletons.test.cpp b/tests/TypeInfer.singletons.test.cpp index f028e8e0d..d068ae53d 100644 --- a/tests/TypeInfer.singletons.test.cpp +++ b/tests/TypeInfer.singletons.test.cpp @@ -390,6 +390,8 @@ TEST_CASE_FIXTURE(Fixture, "widen_the_supertype_if_it_is_free_and_subtype_has_si TEST_CASE_FIXTURE(Fixture, "return_type_of_f_is_not_widened") { + ScopedFastFlag sff{"LuauUnifyTwoOptions", true}; + CheckResult result = check(R"( local function foo(f, x): "hello"? -- anyone there? return if x == "hi" @@ -401,7 +403,9 @@ TEST_CASE_FIXTURE(Fixture, "return_type_of_f_is_not_widened") LUAU_REQUIRE_NO_ERRORS(result); CHECK_EQ(R"("hi")", toString(requireTypeAtPosition({3, 23}))); - CHECK_EQ(R"(((string) -> (a, c...), b) -> "hello"?)", toString(requireType("foo"))); + CHECK_EQ(R"(((string) -> ("hello", b...), a) -> "hello"?)", toString(requireType("foo"))); + + // This is more accurate but we're not there yet: // CHECK_EQ(R"(((string) -> ("hello"?, b...), a) -> "hello"?)", toString(requireType("foo"))); } diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index 3f685f1c3..cbb04cba9 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -1233,4 +1233,69 @@ TEST_CASE_FIXTURE(Fixture, "dcr_delays_expansion_of_function_containing_blocked_ )"); } +TEST_CASE_FIXTURE(BuiltinsFixture, "recursive_function_that_invokes_itself_with_a_refinement_of_its_parameter") +{ + CheckResult result = check(R"( + local TRUE: true = true + + local function matches(value, t: true) + if value then + return true + end + end + + local function readValue(breakpoint) + if matches(breakpoint, TRUE) then + readValue(breakpoint) + end + end + )"); + + CHECK("(a) -> ()" == toString(requireType("readValue"))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "recursive_function_that_invokes_itself_with_a_refinement_of_its_parameter_2") +{ + CheckResult result = check(R"( + local function readValue(breakpoint) + if type(breakpoint) == 'number' then + readValue(breakpoint) + end + end + )"); + + CHECK("(number) -> ()" == toString(requireType("readValue"))); +} + +/* + * We got into a case where, as we unified two nearly identical unions with one + * another, where we had a concatenated TxnLog that created a cycle between two + * free types. + * + * This code used to crash the type checker. See CLI-71190 + */ +TEST_CASE_FIXTURE(BuiltinsFixture, "convoluted_case_where_two_TypeVars_were_bound_to_each_other") +{ + check(R"( + type React_Ref = { current: ElementType } | ((ElementType) -> ()) + + type React_AbstractComponent = { + render: ((ref: React_Ref) -> nil) + } + + local createElement : (React_AbstractComponent) -> () + + function ScrollView:render() + local one = table.unpack( + if true then a else b + ) + + createElement(one) + createElement(one) + end + )"); + + // If this code does not crash, we are in good shape. +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.tryUnify.test.cpp b/tests/TypeInfer.tryUnify.test.cpp index fa52a7466..225b4ff1b 100644 --- a/tests/TypeInfer.tryUnify.test.cpp +++ b/tests/TypeInfer.tryUnify.test.cpp @@ -409,4 +409,91 @@ local l0:(any)&(typeof(_)),l0:(any)|(any) = _,_ LUAU_REQUIRE_ERRORS(result); } +static TypeId createTheType(TypeArena& arena, NotNull builtinTypes, Scope* scope, TypeId freeTy) +{ + /* + ({| + render: ( + (('a) -> ()) | {| current: 'a |} + ) -> nil + |}) -> () + */ + TypePackId emptyPack = arena.addTypePack({}); + + return arena.addType(FunctionType{ + arena.addTypePack({arena.addType(TableType{ + TableType::Props{{{"render", + Property(arena.addType(FunctionType{ + arena.addTypePack({arena.addType(UnionType{{arena.addType(FunctionType{arena.addTypePack({freeTy}), emptyPack}), + arena.addType(TableType{TableType::Props{{"current", {freeTy}}}, std::nullopt, TypeLevel{}, scope, TableState::Sealed})}})}), + arena.addTypePack({builtinTypes->nilType})}))}}}, + std::nullopt, TypeLevel{}, scope, TableState::Sealed})}), + emptyPack}); +}; + +// See CLI-71190 +TEST_CASE_FIXTURE(TryUnifyFixture, "unifying_two_unions_under_dcr_does_not_create_a_BoundType_cycle") +{ + const std::shared_ptr scope = globalScope; + const std::shared_ptr nestedScope = std::make_shared(scope); + + const TypeId outerType = arena.freshType(scope.get()); + const TypeId outerType2 = arena.freshType(scope.get()); + + const TypeId innerType = arena.freshType(nestedScope.get()); + + ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true}; + + state.enableScopeTests(); + + SUBCASE("equal_scopes") + { + TypeId one = createTheType(arena, builtinTypes, scope.get(), outerType); + TypeId two = createTheType(arena, builtinTypes, scope.get(), outerType2); + + state.tryUnify(one, two); + state.log.commit(); + + ToStringOptions opts; + + CHECK(follow(outerType) == follow(outerType2)); + } + + SUBCASE("outer_scope_is_subtype") + { + TypeId one = createTheType(arena, builtinTypes, scope.get(), outerType); + TypeId two = createTheType(arena, builtinTypes, scope.get(), innerType); + + state.tryUnify(one, two); + state.log.commit(); + + ToStringOptions opts; + + CHECK(follow(outerType) == follow(innerType)); + + // The scope of outerType exceeds that of innerType. The latter should be bound to the former. + const BoundType* bt = get_if(&innerType->ty); + REQUIRE(bt); + CHECK(bt->boundTo == outerType); + } + + SUBCASE("outer_scope_is_supertype") + { + TypeId one = createTheType(arena, builtinTypes, scope.get(), innerType); + TypeId two = createTheType(arena, builtinTypes, scope.get(), outerType); + + state.tryUnify(one, two); + state.log.commit(); + + ToStringOptions opts; + + CHECK(follow(outerType) == follow(innerType)); + + // The scope of outerType exceeds that of innerType. The latter should be bound to the former. + const BoundType* bt = get_if(&innerType->ty); + REQUIRE(bt); + CHECK(bt->boundTo == outerType); + } +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp index 19b221482..960d6f15b 100644 --- a/tests/TypeInfer.unionTypes.test.cpp +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -792,4 +792,82 @@ TEST_CASE_FIXTURE(Fixture, "lookup_prop_of_intersection_containing_unions") CHECK("variables" == unknownProp->key); } +TEST_CASE_FIXTURE(Fixture, "free_options_can_be_unified_together") +{ + ScopedFastFlag sff[] = { + {"LuauTransitiveSubtyping", true}, + {"LuauUnifyTwoOptions", true} + }; + + TypeArena arena; + TypeId nilType = builtinTypes->nilType; + + std::unique_ptr scope = std::make_unique(builtinTypes->anyTypePack); + + TypeId free1 = arena.addType(FreeType{scope.get()}); + TypeId option1 = arena.addType(UnionType{{nilType, free1}}); + + TypeId free2 = arena.addType(FreeType{scope.get()}); + TypeId option2 = arena.addType(UnionType{{nilType, free2}}); + + InternalErrorReporter iceHandler; + UnifierSharedState sharedState{&iceHandler}; + Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}}; + Unifier u{NotNull{&normalizer}, Mode::Strict, NotNull{scope.get()}, Location{}, Variance::Covariant}; + + u.tryUnify(option1, option2); + + CHECK(!u.failure); + + u.log.commit(); + + ToStringOptions opts; + CHECK("a?" == toString(option1, opts)); + CHECK("a?" == toString(option2, opts)); +} + +TEST_CASE_FIXTURE(Fixture, "unify_more_complex_unions_that_include_nil") +{ + CheckResult result = check(R"( + type Record = {prop: (string | boolean)?} + + function concatPagination(prop: (string | boolean | nil)?): Record + return {prop = prop} + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "optional_class_instances_are_invariant") +{ + ScopedFastFlag sff[] = { + {"LuauUnifyTwoOptions", true}, + {"LuauTypeMismatchInvarianceInError", true} + }; + + createSomeClasses(&frontend); + + CheckResult result = check(R"( + function foo(ref: {current: Parent?}) + end + + function bar(ref: {current: Child?}) + foo(ref) + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + // The last line of this error is the most important part. We need to + // communicate that this is an invariant context. + std::string expectedError = + "Type '{| current: Child? |}' could not be converted into '{| current: Parent? |}'\n" + "caused by:\n" + " Property 'current' is not compatible. Type 'Child' could not be converted into 'Parent' in an invariant context" + ; + + CHECK(expectedError == toString(result.errors[0])); +} + TEST_SUITE_END(); diff --git a/tests/conformance/debugger.lua b/tests/conformance/debugger.lua index c773013b7..77b02fd1e 100644 --- a/tests/conformance/debugger.lua +++ b/tests/conformance/debugger.lua @@ -69,4 +69,17 @@ end breakpointSetFromMetamethod() +-- break inside function with non-monotonic line info +local function cond(a) + if a then + print('a') + else + print('not a') + end +end + +breakpoint(77) + +pcall(cond, nil) -- prevent inlining + return 'OK' diff --git a/tools/faillist.txt b/tools/faillist.txt index 655d094f6..a26e5c9f9 100644 --- a/tools/faillist.txt +++ b/tools/faillist.txt @@ -1,5 +1,6 @@ AnnotationTests.too_many_type_params AstQuery.last_argument_function_call_type +AutocompleteTest.autocomplete_response_perf1 BuiltinTests.aliased_string_format BuiltinTests.assert_removes_falsy_types BuiltinTests.assert_removes_falsy_types2 @@ -34,7 +35,6 @@ GenericsTests.generic_functions_should_be_memory_safe GenericsTests.generic_type_pack_parentheses GenericsTests.higher_rank_polymorphism_should_not_accept_instantiated_arguments GenericsTests.infer_generic_function_function_argument_2 -GenericsTests.infer_generic_function_function_argument_3 GenericsTests.infer_generic_function_function_argument_overloaded GenericsTests.infer_generic_lib_function_function_argument GenericsTests.instantiated_function_argument_names @@ -47,7 +47,6 @@ ProvisionalTests.assign_table_with_refined_property_with_a_similar_type_is_illeg ProvisionalTests.bail_early_if_unification_is_too_complicated ProvisionalTests.do_not_ice_when_trying_to_pick_first_of_generic_type_pack ProvisionalTests.error_on_eq_metamethod_returning_a_type_other_than_boolean -ProvisionalTests.free_options_cannot_be_unified_together ProvisionalTests.generic_type_leak_to_module_interface_variadic ProvisionalTests.greedy_inference_with_shared_self_triggers_function_with_no_returns ProvisionalTests.luau-polyfill.Array.filter @@ -60,7 +59,6 @@ RefinementTest.type_narrow_to_vector RefinementTest.typeguard_cast_free_table_to_vector RefinementTest.typeguard_in_assert_position RefinementTest.x_as_any_if_x_is_instance_elseif_x_is_table -RuntimeLimits.typescript_port_of_Result_type TableTests.a_free_shape_cannot_turn_into_a_scalar_if_it_is_not_compatible TableTests.checked_prop_too_early TableTests.disallow_indexing_into_an_unsealed_table_with_no_indexer_in_strict_mode @@ -124,6 +122,7 @@ TypeAliases.type_alias_local_mutation TypeAliases.type_alias_local_rename TypeAliases.type_alias_locations TypeAliases.type_alias_of_an_imported_recursive_generic_type +TypeFamilyTests.function_internal_families TypeInfer.check_type_infer_recursion_count TypeInfer.cli_50041_committing_txnlog_in_apollo_client_error TypeInfer.dont_report_type_errors_within_an_AstExprError @@ -133,6 +132,7 @@ TypeInfer.fuzz_free_table_type_change_during_index_check TypeInfer.infer_assignment_value_types_mutable_lval TypeInfer.no_stack_overflow_from_isoptional TypeInfer.no_stack_overflow_from_isoptional2 +TypeInfer.recursive_function_that_invokes_itself_with_a_refinement_of_its_parameter_2 TypeInfer.tc_after_error_recovery_no_replacement_name_in_error TypeInfer.type_infer_recursion_limit_no_ice TypeInfer.type_infer_recursion_limit_normalizer @@ -165,9 +165,7 @@ TypeInferFunctions.too_many_return_values_no_function TypeInferLoops.for_in_loop_error_on_factory_not_returning_the_right_amount_of_values TypeInferLoops.for_in_loop_with_next TypeInferLoops.for_in_with_generic_next -TypeInferLoops.loop_iter_metamethod_ok_with_inference TypeInferLoops.loop_iter_trailing_nil -TypeInferLoops.properly_infer_iteratee_is_a_free_table TypeInferLoops.unreachable_code_after_infinite_loop TypeInferModules.do_not_modify_imported_types_5 TypeInferModules.module_type_conflict @@ -177,7 +175,6 @@ TypeInferOOP.methods_are_topologically_sorted TypeInferOperators.CallAndOrOfFunctions TypeInferOperators.CallOrOfFunctions TypeInferOperators.cli_38355_recursive_union -TypeInferOperators.compound_assign_metatable TypeInferOperators.compound_assign_mismatch_metatable TypeInferOperators.disallow_string_and_types_without_metatables_from_arithmetic_binary_ops TypeInferOperators.operator_eq_completely_incompatible From eb7106016e25c078b577427254c910a062415f78 Mon Sep 17 00:00:00 2001 From: Andy Friesen Date: Fri, 19 May 2023 11:59:59 -0700 Subject: [PATCH 54/66] Sync to upstream/release/577 --- Analysis/include/Luau/Constraint.h | 44 +- .../include/Luau/ConstraintGraphBuilder.h | 4 +- Analysis/include/Luau/ConstraintSolver.h | 28 +- Analysis/include/Luau/Module.h | 3 - Analysis/include/Luau/Normalize.h | 10 + Analysis/include/Luau/Simplify.h | 36 + Analysis/include/Luau/ToString.h | 5 +- Analysis/include/Luau/TxnLog.h | 6 + Analysis/include/Luau/Type.h | 6 +- Analysis/include/Luau/TypeFamily.h | 26 +- Analysis/include/Luau/TypeReduction.h | 85 - Analysis/src/Autocomplete.cpp | 1 - Analysis/src/Clone.cpp | 31 +- Analysis/src/ConstraintGraphBuilder.cpp | 212 ++- Analysis/src/ConstraintSolver.cpp | 256 ++- Analysis/src/Frontend.cpp | 18 +- Analysis/src/Instantiation.cpp | 4 +- Analysis/src/Module.cpp | 7 +- Analysis/src/Normalize.cpp | 57 +- Analysis/src/Quantify.cpp | 4 +- Analysis/src/Simplify.cpp | 1270 ++++++++++++++ Analysis/src/ToString.cpp | 10 + Analysis/src/TxnLog.cpp | 16 +- Analysis/src/TypeChecker2.cpp | 245 +-- Analysis/src/TypeFamily.cpp | 174 +- Analysis/src/TypeInfer.cpp | 4 +- Analysis/src/TypeReduction.cpp | 1200 ------------- Analysis/src/Unifier.cpp | 4 +- CLI/Reduce.cpp | 53 +- CMakeLists.txt | 3 + CodeGen/include/Luau/AssemblyBuilderA64.h | 8 +- CodeGen/include/luacodegen.h | 18 + CodeGen/src/AssemblyBuilderA64.cpp | 38 +- CodeGen/src/CodeGen.cpp | 101 +- CodeGen/src/CodeGenA64.cpp | 27 +- CodeGen/src/CodeGenUtils.cpp | 671 +++++++- CodeGen/src/CodeGenUtils.h | 12 + CodeGen/src/CustomExecUtils.h | 24 - CodeGen/src/EmitCommon.h | 9 +- CodeGen/src/EmitCommonX64.cpp | 6 +- CodeGen/src/EmitCommonX64.h | 6 +- CodeGen/src/EmitInstructionX64.cpp | 23 +- CodeGen/src/Fallbacks.cpp | 639 ------- CodeGen/src/Fallbacks.h | 24 - CodeGen/src/FallbacksProlog.h | 56 - CodeGen/src/IrLoweringA64.cpp | 76 +- CodeGen/src/IrLoweringX64.cpp | 24 +- CodeGen/src/NativeState.cpp | 38 +- CodeGen/src/NativeState.h | 28 +- CodeGen/src/OptimizeConstProp.cpp | 13 + CodeGen/src/lcodegen.cpp | 21 + Makefile | 1 + Sources.cmake | 11 +- VM/include/luaconf.h | 2 + VM/src/ldo.cpp | 11 +- VM/src/ldo.h | 2 +- VM/src/lfunc.cpp | 3 +- VM/src/lobject.h | 3 +- VM/src/lstate.h | 1 + VM/src/lvmexecute.cpp | 58 +- tests/AssemblyBuilderA64.test.cpp | 16 + tests/Autocomplete.test.cpp | 34 - tests/ClassFixture.cpp | 11 +- tests/Conformance.test.cpp | 18 +- tests/ConstraintGraphBuilderFixture.cpp | 3 - tests/IrBuilder.test.cpp | 30 + tests/Module.test.cpp | 29 + tests/Normalize.test.cpp | 6 +- tests/Simplify.test.cpp | 508 ++++++ tests/ToString.test.cpp | 24 +- tests/TxnLog.test.cpp | 11 + tests/TypeFamily.test.cpp | 37 +- tests/TypeInfer.annotations.test.cpp | 12 + tests/TypeInfer.cfa.test.cpp | 5 +- tests/TypeInfer.classes.test.cpp | 13 +- tests/TypeInfer.functions.test.cpp | 20 +- tests/TypeInfer.intersectionTypes.test.cpp | 138 +- tests/TypeInfer.operators.test.cpp | 102 +- tests/TypeInfer.provisional.test.cpp | 59 +- tests/TypeInfer.refinements.test.cpp | 37 +- tests/TypeInfer.tables.test.cpp | 4 +- tests/TypeInfer.test.cpp | 5 +- tests/TypeInfer.typePacks.cpp | 10 + tests/TypeInfer.unionTypes.test.cpp | 51 +- tests/TypeReduction.test.cpp | 1509 ----------------- tests/TypeVar.test.cpp | 1 - tools/faillist.txt | 30 +- tools/lvmexecute_split.py | 112 -- 88 files changed, 4110 insertions(+), 4501 deletions(-) create mode 100644 Analysis/include/Luau/Simplify.h delete mode 100644 Analysis/include/Luau/TypeReduction.h create mode 100644 Analysis/src/Simplify.cpp delete mode 100644 Analysis/src/TypeReduction.cpp create mode 100644 CodeGen/include/luacodegen.h delete mode 100644 CodeGen/src/Fallbacks.cpp delete mode 100644 CodeGen/src/Fallbacks.h delete mode 100644 CodeGen/src/FallbacksProlog.h create mode 100644 CodeGen/src/lcodegen.cpp create mode 100644 tests/Simplify.test.cpp delete mode 100644 tests/TypeReduction.test.cpp delete mode 100644 tools/lvmexecute_split.py diff --git a/Analysis/include/Luau/Constraint.h b/Analysis/include/Luau/Constraint.h index 3aa3c865c..c815bef01 100644 --- a/Analysis/include/Luau/Constraint.h +++ b/Analysis/include/Luau/Constraint.h @@ -144,6 +144,24 @@ struct HasPropConstraint TypeId resultType; TypeId subjectType; std::string prop; + + // HACK: We presently need types like true|false or string|"hello" when + // deciding whether a particular literal expression should have a singleton + // type. This boolean is set to true when extracting the property type of a + // value that may be a union of tables. + // + // For example, in the following code fragment, we want the lookup of the + // success property to yield true|false when extracting an expectedType in + // this expression: + // + // type Result = {success:true, result: T} | {success:false, error: E} + // + // local r: Result = {success=true, result=9} + // + // If we naively simplify the expectedType to boolean, we will erroneously + // compute the type boolean for the success property of the table literal. + // This causes type checking to fail. + bool suppressSimplification = false; }; // result ~ setProp subjectType ["prop", "prop2", ...] propType @@ -198,6 +216,24 @@ struct UnpackConstraint TypePackId sourcePack; }; +// resultType ~ refine type mode discriminant +// +// Compute type & discriminant (or type | discriminant) as soon as possible (but +// no sooner), simplify, and bind resultType to that type. +struct RefineConstraint +{ + enum + { + Intersection, + Union + } mode; + + TypeId resultType; + + TypeId type; + TypeId discriminant; +}; + // ty ~ reduce ty // // Try to reduce ty, if it is a TypeFamilyInstanceType. Otherwise, do nothing. @@ -214,10 +250,10 @@ struct ReducePackConstraint TypePackId tp; }; -using ConstraintV = - Variant; +using ConstraintV = Variant; struct Constraint { diff --git a/Analysis/include/Luau/ConstraintGraphBuilder.h b/Analysis/include/Luau/ConstraintGraphBuilder.h index 5800d146d..ababe0a36 100644 --- a/Analysis/include/Luau/ConstraintGraphBuilder.h +++ b/Analysis/include/Luau/ConstraintGraphBuilder.h @@ -188,6 +188,7 @@ struct ConstraintGraphBuilder Inference check(const ScopePtr& scope, AstExprGlobal* global); Inference check(const ScopePtr& scope, AstExprIndexName* indexName); Inference check(const ScopePtr& scope, AstExprIndexExpr* indexExpr); + Inference check(const ScopePtr& scope, AstExprFunction* func, std::optional expectedType); Inference check(const ScopePtr& scope, AstExprUnary* unary); Inference check(const ScopePtr& scope, AstExprBinary* binary, std::optional expectedType); Inference check(const ScopePtr& scope, AstExprIfElse* ifElse, std::optional expectedType); @@ -213,7 +214,8 @@ struct ConstraintGraphBuilder ScopePtr bodyScope; }; - FunctionSignature checkFunctionSignature(const ScopePtr& parent, AstExprFunction* fn, std::optional expectedType = {}); + FunctionSignature checkFunctionSignature( + const ScopePtr& parent, AstExprFunction* fn, std::optional expectedType = {}, std::optional originalName = {}); /** * Checks the body of a function expression. diff --git a/Analysis/include/Luau/ConstraintSolver.h b/Analysis/include/Luau/ConstraintSolver.h index f6b1aede8..1a43a252e 100644 --- a/Analysis/include/Luau/ConstraintSolver.h +++ b/Analysis/include/Luau/ConstraintSolver.h @@ -8,7 +8,6 @@ #include "Luau/Normalize.h" #include "Luau/ToString.h" #include "Luau/Type.h" -#include "Luau/TypeReduction.h" #include "Luau/Variant.h" #include @@ -121,6 +120,7 @@ struct ConstraintSolver bool tryDispatch(const SetIndexerConstraint& c, NotNull constraint, bool force); bool tryDispatch(const SingletonOrTopTypeConstraint& c, NotNull constraint); bool tryDispatch(const UnpackConstraint& c, NotNull constraint); + bool tryDispatch(const RefineConstraint& c, NotNull constraint, bool force); bool tryDispatch(const ReduceConstraint& c, NotNull constraint, bool force); bool tryDispatch(const ReducePackConstraint& c, NotNull constraint, bool force); @@ -132,8 +132,10 @@ struct ConstraintSolver bool tryDispatchIterableFunction( TypeId nextTy, TypeId tableTy, TypeId firstIndexTy, const IterableConstraint& c, NotNull constraint, bool force); - std::pair, std::optional> lookupTableProp(TypeId subjectType, const std::string& propName); - std::pair, std::optional> lookupTableProp(TypeId subjectType, const std::string& propName, std::unordered_set& seen); + std::pair, std::optional> lookupTableProp( + TypeId subjectType, const std::string& propName, bool suppressSimplification = false); + std::pair, std::optional> lookupTableProp( + TypeId subjectType, const std::string& propName, bool suppressSimplification, std::unordered_set& seen); void block(NotNull target, NotNull constraint); /** @@ -143,6 +145,16 @@ struct ConstraintSolver bool block(TypeId target, NotNull constraint); bool block(TypePackId target, NotNull constraint); + // Block on every target + template + bool block(const T& targets, NotNull constraint) + { + for (TypeId target : targets) + block(target, constraint); + + return false; + } + /** * For all constraints that are blocked on one constraint, make them block * on a new constraint. @@ -151,15 +163,15 @@ struct ConstraintSolver */ void inheritBlocks(NotNull source, NotNull addition); - // Traverse the type. If any blocked or pending types are found, block - // the constraint on them. + // Traverse the type. If any pending types are found, block the constraint + // on them. // // Returns false if a type blocks the constraint. // // FIXME: This use of a boolean for the return result is an appalling // interface. - bool recursiveBlock(TypeId target, NotNull constraint); - bool recursiveBlock(TypePackId target, NotNull constraint); + bool blockOnPendingTypes(TypeId target, NotNull constraint); + bool blockOnPendingTypes(TypePackId target, NotNull constraint); void unblock(NotNull progressed); void unblock(TypeId progressed); @@ -255,6 +267,8 @@ struct ConstraintSolver TypeId unionOfTypes(TypeId a, TypeId b, NotNull scope, bool unifyFreeTypes); + TypePackId anyifyModuleReturnTypePackGenerics(TypePackId tp); + ToStringOptions opts; }; diff --git a/Analysis/include/Luau/Module.h b/Analysis/include/Luau/Module.h index b9be8205b..1fa2e03c7 100644 --- a/Analysis/include/Luau/Module.h +++ b/Analysis/include/Luau/Module.h @@ -85,14 +85,11 @@ struct Module DenseHashMap astOverloadResolvedTypes{nullptr}; DenseHashMap astResolvedTypes{nullptr}; - DenseHashMap astOriginalResolvedTypes{nullptr}; DenseHashMap astResolvedTypePacks{nullptr}; // Map AST nodes to the scope they create. Cannot be NotNull because we need a sentinel value for the map. DenseHashMap astScopes{nullptr}; - std::unique_ptr reduction; - std::unordered_map declaredGlobals; ErrorVec errors; LintResult lintResult; diff --git a/Analysis/include/Luau/Normalize.h b/Analysis/include/Luau/Normalize.h index 2ec5406fd..978ddb480 100644 --- a/Analysis/include/Luau/Normalize.h +++ b/Analysis/include/Luau/Normalize.h @@ -267,8 +267,18 @@ struct NormalizedType NormalizedType(NormalizedType&&) = default; NormalizedType& operator=(NormalizedType&&) = default; + + // IsType functions + + /// Returns true if the type is a subtype of function. This includes any and unknown. + bool isFunction() const; + + /// Returns true if the type is a subtype of number. This includes any and unknown. + bool isNumber() const; }; + + class Normalizer { std::unordered_map> cachedNormals; diff --git a/Analysis/include/Luau/Simplify.h b/Analysis/include/Luau/Simplify.h new file mode 100644 index 000000000..27ed44f8f --- /dev/null +++ b/Analysis/include/Luau/Simplify.h @@ -0,0 +1,36 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#pragma once + +#include "Luau/Type.h" + +#include + +namespace Luau +{ + +struct TypeArena; +struct BuiltinTypes; + +struct SimplifyResult +{ + TypeId result; + + std::set blockedTypes; +}; + +SimplifyResult simplifyIntersection(NotNull builtinTypes, NotNull arena, TypeId ty, TypeId discriminant); +SimplifyResult simplifyUnion(NotNull builtinTypes, NotNull arena, TypeId ty, TypeId discriminant); + +enum class Relation +{ + Disjoint, // No A is a B or vice versa + Coincident, // Every A is in B and vice versa + Intersects, // Some As are in B and some Bs are in A. ex (number | string) <-> (string | boolean) + Subset, // Every A is in B + Superset, // Every B is in A +}; + +Relation relate(TypeId left, TypeId right); + +} // namespace Luau diff --git a/Analysis/include/Luau/ToString.h b/Analysis/include/Luau/ToString.h index 7758e8f99..dec2c1fc5 100644 --- a/Analysis/include/Luau/ToString.h +++ b/Analysis/include/Luau/ToString.h @@ -99,10 +99,7 @@ inline std::string toString(const Constraint& c, ToStringOptions&& opts) return toString(c, opts); } -inline std::string toString(const Constraint& c) -{ - return toString(c, ToStringOptions{}); -} +std::string toString(const Constraint& c); std::string toString(const Type& tv, ToStringOptions& opts); std::string toString(const TypePackVar& tp, ToStringOptions& opts); diff --git a/Analysis/include/Luau/TxnLog.h b/Analysis/include/Luau/TxnLog.h index 907908dfe..951f89ee5 100644 --- a/Analysis/include/Luau/TxnLog.h +++ b/Analysis/include/Luau/TxnLog.h @@ -308,6 +308,12 @@ struct TxnLog // used. Else we use the embedded Scope*. bool useScopes = false; + // It is sometimes the case under DCR that we speculatively rebind + // GenericTypes to other types as though they were free. We mark logs that + // contain these kinds of substitutions as radioactive so that we know that + // we must never commit one. + bool radioactive = false; + // Used to avoid infinite recursion when types are cyclic. // Shared with all the descendent TxnLogs. std::vector>* sharedSeen; diff --git a/Analysis/include/Luau/Type.h b/Analysis/include/Luau/Type.h index 80a044cbf..d42f58b4b 100644 --- a/Analysis/include/Luau/Type.h +++ b/Analysis/include/Luau/Type.h @@ -349,7 +349,9 @@ struct FunctionType DcrMagicFunction dcrMagicFunction = nullptr; DcrMagicRefinement dcrMagicRefinement = nullptr; bool hasSelf; - bool hasNoGenerics = false; + // `hasNoFreeOrGenericTypes` should be true if and only if the type does not have any free or generic types present inside it. + // this flag is used as an optimization to exit early from procedures that manipulate free or generic types. + bool hasNoFreeOrGenericTypes = false; }; enum class TableState @@ -530,7 +532,7 @@ struct ClassType */ struct TypeFamilyInstanceType { - NotNull family; + NotNull family; std::vector typeArguments; std::vector packArguments; diff --git a/Analysis/include/Luau/TypeFamily.h b/Analysis/include/Luau/TypeFamily.h index 4c04f52ae..bf47de30a 100644 --- a/Analysis/include/Luau/TypeFamily.h +++ b/Analysis/include/Luau/TypeFamily.h @@ -21,6 +21,7 @@ using TypePackId = const TypePackVar*; struct TypeArena; struct BuiltinTypes; struct TxnLog; +class Normalizer; /// Represents a reduction result, which may have successfully reduced the type, /// may have concretely failed to reduce the type, or may simply be stuck @@ -52,8 +53,8 @@ struct TypeFamily std::string name; /// The reducer function for the type family. - std::function( - std::vector, std::vector, NotNull, NotNull, NotNull log)> + std::function(std::vector, std::vector, NotNull, NotNull, + NotNull, NotNull, NotNull)> reducer; }; @@ -66,8 +67,8 @@ struct TypePackFamily std::string name; /// The reducer function for the type pack family. - std::function( - std::vector, std::vector, NotNull, NotNull, NotNull log)> + std::function(std::vector, std::vector, NotNull, NotNull, + NotNull, NotNull, NotNull)> reducer; }; @@ -93,8 +94,8 @@ struct FamilyGraphReductionResult * against the TxnLog, otherwise substitutions will directly mutate the type * graph. Do not provide the empty TxnLog, as a result. */ -FamilyGraphReductionResult reduceFamilies( - TypeId entrypoint, Location location, NotNull arena, NotNull builtins, TxnLog* log = nullptr, bool force = false); +FamilyGraphReductionResult reduceFamilies(TypeId entrypoint, Location location, NotNull arena, NotNull builtins, + NotNull scope, NotNull normalizer, TxnLog* log = nullptr, bool force = false); /** * Attempt to reduce all instances of any type or type pack family in the type @@ -109,7 +110,16 @@ FamilyGraphReductionResult reduceFamilies( * against the TxnLog, otherwise substitutions will directly mutate the type * graph. Do not provide the empty TxnLog, as a result. */ -FamilyGraphReductionResult reduceFamilies( - TypePackId entrypoint, Location location, NotNull arena, NotNull builtins, TxnLog* log = nullptr, bool force = false); +FamilyGraphReductionResult reduceFamilies(TypePackId entrypoint, Location location, NotNull arena, NotNull builtins, + NotNull scope, NotNull normalizer, TxnLog* log = nullptr, bool force = false); + +struct BuiltinTypeFamilies +{ + BuiltinTypeFamilies(); + + TypeFamily addFamily; +}; + +const BuiltinTypeFamilies kBuiltinTypeFamilies{}; } // namespace Luau diff --git a/Analysis/include/Luau/TypeReduction.h b/Analysis/include/Luau/TypeReduction.h deleted file mode 100644 index 3f64870ab..000000000 --- a/Analysis/include/Luau/TypeReduction.h +++ /dev/null @@ -1,85 +0,0 @@ -// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#pragma once - -#include "Luau/Type.h" -#include "Luau/TypeArena.h" -#include "Luau/TypePack.h" -#include "Luau/Variant.h" - -namespace Luau -{ - -namespace detail -{ -template -struct ReductionEdge -{ - T type = nullptr; - bool irreducible = false; -}; - -struct TypeReductionMemoization -{ - TypeReductionMemoization() = default; - - TypeReductionMemoization(const TypeReductionMemoization&) = delete; - TypeReductionMemoization& operator=(const TypeReductionMemoization&) = delete; - - TypeReductionMemoization(TypeReductionMemoization&&) = default; - TypeReductionMemoization& operator=(TypeReductionMemoization&&) = default; - - DenseHashMap> types{nullptr}; - DenseHashMap> typePacks{nullptr}; - - bool isIrreducible(TypeId ty); - bool isIrreducible(TypePackId tp); - - TypeId memoize(TypeId ty, TypeId reducedTy); - TypePackId memoize(TypePackId tp, TypePackId reducedTp); - - // Reducing A into B may have a non-irreducible edge A to B for which B is not irreducible, which means B could be reduced into C. - // Because reduction should always be transitive, A should point to C if A points to B and B points to C. - std::optional> memoizedof(TypeId ty) const; - std::optional> memoizedof(TypePackId tp) const; -}; -} // namespace detail - -struct TypeReductionOptions -{ - /// If it's desirable for type reduction to allocate into a different arena than the TypeReduction instance you have, you will need - /// to create a temporary TypeReduction in that case, and set [`TypeReductionOptions::allowTypeReductionsFromOtherArenas`] to true. - /// This is because TypeReduction caches the reduced type. - bool allowTypeReductionsFromOtherArenas = false; -}; - -struct TypeReduction -{ - explicit TypeReduction(NotNull arena, NotNull builtinTypes, NotNull handle, - const TypeReductionOptions& opts = {}); - - TypeReduction(const TypeReduction&) = delete; - TypeReduction& operator=(const TypeReduction&) = delete; - - TypeReduction(TypeReduction&&) = default; - TypeReduction& operator=(TypeReduction&&) = default; - - std::optional reduce(TypeId ty); - std::optional reduce(TypePackId tp); - std::optional reduce(const TypeFun& fun); - -private: - NotNull arena; - NotNull builtinTypes; - NotNull handle; - - TypeReductionOptions options; - detail::TypeReductionMemoization memoization; - - // Computes an *estimated length* of the cartesian product of the given type. - size_t cartesianProductSize(TypeId ty) const; - - bool hasExceededCartesianProductLimit(TypeId ty) const; - bool hasExceededCartesianProductLimit(TypePackId tp) const; -}; - -} // namespace Luau diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index 4b66568b5..8dd747390 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -7,7 +7,6 @@ #include "Luau/ToString.h" #include "Luau/TypeInfer.h" #include "Luau/TypePack.h" -#include "Luau/TypeReduction.h" #include #include diff --git a/Analysis/src/Clone.cpp b/Analysis/src/Clone.cpp index 0c1b24a19..1eb78540a 100644 --- a/Analysis/src/Clone.cpp +++ b/Analysis/src/Clone.cpp @@ -11,6 +11,7 @@ LUAU_FASTFLAG(LuauClonePublicInterfaceLess2) LUAU_FASTFLAG(DebugLuauReadWriteProperties) LUAU_FASTINTVARIABLE(LuauTypeCloneRecursionLimit, 300) +LUAU_FASTFLAGVARIABLE(LuauCloneCyclicUnions, false) namespace Luau { @@ -282,7 +283,7 @@ void TypeCloner::operator()(const FunctionType& t) ftv->argTypes = clone(t.argTypes, dest, cloneState); ftv->argNames = t.argNames; ftv->retTypes = clone(t.retTypes, dest, cloneState); - ftv->hasNoGenerics = t.hasNoGenerics; + ftv->hasNoFreeOrGenericTypes = t.hasNoFreeOrGenericTypes; } void TypeCloner::operator()(const TableType& t) @@ -373,14 +374,30 @@ void TypeCloner::operator()(const AnyType& t) void TypeCloner::operator()(const UnionType& t) { - std::vector options; - options.reserve(t.options.size()); + if (FFlag::LuauCloneCyclicUnions) + { + TypeId result = dest.addType(FreeType{nullptr}); + seenTypes[typeId] = result; - for (TypeId ty : t.options) - options.push_back(clone(ty, dest, cloneState)); + std::vector options; + options.reserve(t.options.size()); - TypeId result = dest.addType(UnionType{std::move(options)}); - seenTypes[typeId] = result; + for (TypeId ty : t.options) + options.push_back(clone(ty, dest, cloneState)); + + asMutable(result)->ty.emplace(std::move(options)); + } + else + { + std::vector options; + options.reserve(t.options.size()); + + for (TypeId ty : t.options) + options.push_back(clone(ty, dest, cloneState)); + + TypeId result = dest.addType(UnionType{std::move(options)}); + seenTypes[typeId] = result; + } } void TypeCloner::operator()(const IntersectionType& t) diff --git a/Analysis/src/ConstraintGraphBuilder.cpp b/Analysis/src/ConstraintGraphBuilder.cpp index c8d99adf8..b190f4aba 100644 --- a/Analysis/src/ConstraintGraphBuilder.cpp +++ b/Analysis/src/ConstraintGraphBuilder.cpp @@ -13,6 +13,9 @@ #include "Luau/Scope.h" #include "Luau/TypeUtils.h" #include "Luau/Type.h" +#include "Luau/TypeFamily.h" +#include "Luau/Simplify.h" +#include "Luau/VisitType.h" #include @@ -195,8 +198,23 @@ struct RefinementPartition using RefinementContext = std::unordered_map; -static void unionRefinements(const RefinementContext& lhs, const RefinementContext& rhs, RefinementContext& dest, NotNull arena) +static void unionRefinements(NotNull builtinTypes, NotNull arena, const RefinementContext& lhs, const RefinementContext& rhs, + RefinementContext& dest, std::vector* constraints) { + const auto intersect = [&](const std::vector& types) { + if (1 == types.size()) + return types[0]; + else if (2 == types.size()) + { + // TODO: It may be advantageous to create a RefineConstraint here when there are blockedTypes. + SimplifyResult sr = simplifyIntersection(builtinTypes, arena, types[0], types[1]); + if (sr.blockedTypes.empty()) + return sr.result; + } + + return arena->addType(IntersectionType{types}); + }; + for (auto& [def, partition] : lhs) { auto rhsIt = rhs.find(def); @@ -206,55 +224,54 @@ static void unionRefinements(const RefinementContext& lhs, const RefinementConte LUAU_ASSERT(!partition.discriminantTypes.empty()); LUAU_ASSERT(!rhsIt->second.discriminantTypes.empty()); - TypeId leftDiscriminantTy = - partition.discriminantTypes.size() == 1 ? partition.discriminantTypes[0] : arena->addType(IntersectionType{partition.discriminantTypes}); + TypeId leftDiscriminantTy = partition.discriminantTypes.size() == 1 ? partition.discriminantTypes[0] : intersect(partition.discriminantTypes); - TypeId rightDiscriminantTy = rhsIt->second.discriminantTypes.size() == 1 ? rhsIt->second.discriminantTypes[0] - : arena->addType(IntersectionType{rhsIt->second.discriminantTypes}); + TypeId rightDiscriminantTy = + rhsIt->second.discriminantTypes.size() == 1 ? rhsIt->second.discriminantTypes[0] : intersect(rhsIt->second.discriminantTypes); - dest[def].discriminantTypes.push_back(arena->addType(UnionType{{leftDiscriminantTy, rightDiscriminantTy}})); + dest[def].discriminantTypes.push_back(simplifyUnion(builtinTypes, arena, leftDiscriminantTy, rightDiscriminantTy).result); dest[def].shouldAppendNilType |= partition.shouldAppendNilType || rhsIt->second.shouldAppendNilType; } } -static void computeRefinement(const ScopePtr& scope, RefinementId refinement, RefinementContext* refis, bool sense, NotNull arena, bool eq, - std::vector* constraints) +static void computeRefinement(NotNull builtinTypes, NotNull arena, const ScopePtr& scope, RefinementId refinement, + RefinementContext* refis, bool sense, bool eq, std::vector* constraints) { if (!refinement) return; else if (auto variadic = get(refinement)) { for (RefinementId refi : variadic->refinements) - computeRefinement(scope, refi, refis, sense, arena, eq, constraints); + computeRefinement(builtinTypes, arena, scope, refi, refis, sense, eq, constraints); } else if (auto negation = get(refinement)) - return computeRefinement(scope, negation->refinement, refis, !sense, arena, eq, constraints); + return computeRefinement(builtinTypes, arena, scope, negation->refinement, refis, !sense, eq, constraints); else if (auto conjunction = get(refinement)) { RefinementContext lhsRefis; RefinementContext rhsRefis; - computeRefinement(scope, conjunction->lhs, sense ? refis : &lhsRefis, sense, arena, eq, constraints); - computeRefinement(scope, conjunction->rhs, sense ? refis : &rhsRefis, sense, arena, eq, constraints); + computeRefinement(builtinTypes, arena, scope, conjunction->lhs, sense ? refis : &lhsRefis, sense, eq, constraints); + computeRefinement(builtinTypes, arena, scope, conjunction->rhs, sense ? refis : &rhsRefis, sense, eq, constraints); if (!sense) - unionRefinements(lhsRefis, rhsRefis, *refis, arena); + unionRefinements(builtinTypes, arena, lhsRefis, rhsRefis, *refis, constraints); } else if (auto disjunction = get(refinement)) { RefinementContext lhsRefis; RefinementContext rhsRefis; - computeRefinement(scope, disjunction->lhs, sense ? &lhsRefis : refis, sense, arena, eq, constraints); - computeRefinement(scope, disjunction->rhs, sense ? &rhsRefis : refis, sense, arena, eq, constraints); + computeRefinement(builtinTypes, arena, scope, disjunction->lhs, sense ? &lhsRefis : refis, sense, eq, constraints); + computeRefinement(builtinTypes, arena, scope, disjunction->rhs, sense ? &rhsRefis : refis, sense, eq, constraints); if (sense) - unionRefinements(lhsRefis, rhsRefis, *refis, arena); + unionRefinements(builtinTypes, arena, lhsRefis, rhsRefis, *refis, constraints); } else if (auto equivalence = get(refinement)) { - computeRefinement(scope, equivalence->lhs, refis, sense, arena, true, constraints); - computeRefinement(scope, equivalence->rhs, refis, sense, arena, true, constraints); + computeRefinement(builtinTypes, arena, scope, equivalence->lhs, refis, sense, true, constraints); + computeRefinement(builtinTypes, arena, scope, equivalence->rhs, refis, sense, true, constraints); } else if (auto proposition = get(refinement)) { @@ -300,6 +317,63 @@ static void computeRefinement(const ScopePtr& scope, RefinementId refinement, Re } } +namespace +{ + +/* + * Constraint generation may be called upon to simplify an intersection or union + * of types that are not sufficiently solved yet. We use + * FindSimplificationBlockers to recognize these types and defer the + * simplification until constraint solution. + */ +struct FindSimplificationBlockers : TypeOnceVisitor +{ + bool found = false; + + bool visit(TypeId) override + { + return !found; + } + + bool visit(TypeId, const BlockedType&) override + { + found = true; + return false; + } + + bool visit(TypeId, const FreeType&) override + { + found = true; + return false; + } + + bool visit(TypeId, const PendingExpansionType&) override + { + found = true; + return false; + } + + // We do not need to know anything at all about a function's argument or + // return types in order to simplify it in an intersection or union. + bool visit(TypeId, const FunctionType&) override + { + return false; + } + + bool visit(TypeId, const ClassType&) override + { + return false; + } +}; + +bool mustDeferIntersection(TypeId ty) +{ + FindSimplificationBlockers bts; + bts.traverse(ty); + return bts.found; +} +} // namespace + void ConstraintGraphBuilder::applyRefinements(const ScopePtr& scope, Location location, RefinementId refinement) { if (!refinement) @@ -307,7 +381,7 @@ void ConstraintGraphBuilder::applyRefinements(const ScopePtr& scope, Location lo RefinementContext refinements; std::vector constraints; - computeRefinement(scope, refinement, &refinements, /*sense*/ true, arena, /*eq*/ false, &constraints); + computeRefinement(builtinTypes, arena, scope, refinement, &refinements, /*sense*/ true, /*eq*/ false, &constraints); for (auto& [def, partition] : refinements) { @@ -317,8 +391,24 @@ void ConstraintGraphBuilder::applyRefinements(const ScopePtr& scope, Location lo if (partition.shouldAppendNilType) ty = arena->addType(UnionType{{ty, builtinTypes->nilType}}); - partition.discriminantTypes.push_back(ty); - scope->dcrRefinements[def] = arena->addType(IntersectionType{std::move(partition.discriminantTypes)}); + // Intersect ty with every discriminant type. If either type is not + // sufficiently solved, we queue the intersection up via an + // IntersectConstraint. + + for (TypeId dt : partition.discriminantTypes) + { + if (mustDeferIntersection(ty) || mustDeferIntersection(dt)) + { + TypeId r = arena->addType(BlockedType{}); + addConstraint(scope, location, RefineConstraint{RefineConstraint::Intersection, r, ty, dt}); + + ty = r; + } + else + ty = simplifyIntersection(builtinTypes, arena, ty, dt).result; + } + + scope->dcrRefinements[def] = ty; } } @@ -708,7 +798,7 @@ ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocalFun functionType = arena->addType(BlockedType{}); scope->bindings[function->name] = Binding{functionType, function->name->location}; - FunctionSignature sig = checkFunctionSignature(scope, function->func); + FunctionSignature sig = checkFunctionSignature(scope, function->func, /* expectedType */ std::nullopt, function->name->location); sig.bodyScope->bindings[function->name] = Binding{sig.signature, function->func->location}; BreadcrumbId bc = dfg->getBreadcrumb(function->name); @@ -741,10 +831,12 @@ ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFunction TypeId generalizedType = arena->addType(BlockedType{}); Checkpoint start = checkpoint(this); - FunctionSignature sig = checkFunctionSignature(scope, function->func); + FunctionSignature sig = checkFunctionSignature(scope, function->func, /* expectedType */ std::nullopt, function->name->location); std::unordered_set excludeList; + const NullableBreadcrumbId functionBreadcrumb = dfg->getBreadcrumb(function->name); + if (AstExprLocal* localName = function->name->as()) { std::optional existingFunctionTy = scope->lookup(localName->local); @@ -759,6 +851,9 @@ ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFunction scope->bindings[localName->local] = Binding{generalizedType, localName->location}; sig.bodyScope->bindings[localName->local] = Binding{sig.signature, localName->location}; + + if (functionBreadcrumb) + sig.bodyScope->dcrRefinements[functionBreadcrumb->def] = sig.signature; } else if (AstExprGlobal* globalName = function->name->as()) { @@ -769,6 +864,9 @@ ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFunction generalizedType = *existingFunctionTy; sig.bodyScope->bindings[globalName->name] = Binding{sig.signature, globalName->location}; + + if (functionBreadcrumb) + sig.bodyScope->dcrRefinements[functionBreadcrumb->def] = sig.signature; } else if (AstExprIndexName* indexName = function->name->as()) { @@ -795,8 +893,8 @@ ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFunction if (generalizedType == nullptr) ice->ice("generalizedType == nullptr", function->location); - if (NullableBreadcrumbId bc = dfg->getBreadcrumb(function->name)) - scope->dcrRefinements[bc->def] = generalizedType; + if (functionBreadcrumb) + scope->dcrRefinements[functionBreadcrumb->def] = generalizedType; checkFunctionBody(sig.bodyScope, function->func); Checkpoint end = checkpoint(this); @@ -1469,21 +1567,7 @@ Inference ConstraintGraphBuilder::check( else if (auto call = expr->as()) result = flattenPack(scope, expr->location, checkPack(scope, call)); // TODO: needs predicates too else if (auto a = expr->as()) - { - Checkpoint startCheckpoint = checkpoint(this); - FunctionSignature sig = checkFunctionSignature(scope, a, expectedType); - checkFunctionBody(sig.bodyScope, a); - Checkpoint endCheckpoint = checkpoint(this); - - TypeId generalizedTy = arena->addType(BlockedType{}); - NotNull gc = addConstraint(sig.signatureScope, expr->location, GeneralizationConstraint{generalizedTy, sig.signature}); - - forEachConstraint(startCheckpoint, endCheckpoint, this, [gc](const ConstraintPtr& constraint) { - gc->dependencies.emplace_back(constraint.get()); - }); - - result = Inference{generalizedTy}; - } + result = check(scope, a, expectedType); else if (auto indexName = expr->as()) result = check(scope, indexName); else if (auto indexExpr = expr->as()) @@ -1651,6 +1735,23 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIndexExpr* return Inference{result}; } +Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprFunction* func, std::optional expectedType) +{ + Checkpoint startCheckpoint = checkpoint(this); + FunctionSignature sig = checkFunctionSignature(scope, func, expectedType); + checkFunctionBody(sig.bodyScope, func); + Checkpoint endCheckpoint = checkpoint(this); + + TypeId generalizedTy = arena->addType(BlockedType{}); + NotNull gc = addConstraint(sig.signatureScope, func->location, GeneralizationConstraint{generalizedTy, sig.signature}); + + forEachConstraint(startCheckpoint, endCheckpoint, this, [gc](const ConstraintPtr& constraint) { + gc->dependencies.emplace_back(constraint.get()); + }); + + return Inference{generalizedTy}; +} + Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprUnary* unary) { auto [operandType, refinement] = check(scope, unary->expr); @@ -1667,6 +1768,17 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprBinary* bi { auto [leftType, rightType, refinement] = checkBinary(scope, binary, expectedType); + if (binary->op == AstExprBinary::Op::Add) + { + TypeId resultType = arena->addType(TypeFamilyInstanceType{ + NotNull{&kBuiltinTypeFamilies.addFamily}, + {leftType, rightType}, + {}, + }); + addConstraint(scope, binary->location, ReduceConstraint{resultType}); + return Inference{resultType, std::move(refinement)}; + } + TypeId resultType = arena->addType(BlockedType{}); addConstraint(scope, binary->location, BinaryConstraint{binary->op, leftType, rightType, resultType, binary, &module->astOriginalCallTypes, &module->astOverloadResolvedTypes}); @@ -1686,7 +1798,7 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIfElse* if applyRefinements(elseScope, ifElse->falseExpr->location, refinementArena.negation(refinement)); TypeId elseType = check(elseScope, ifElse->falseExpr, ValueContext::RValue, expectedType).ty; - return Inference{expectedType ? *expectedType : arena->addType(UnionType{{thenType, elseType}})}; + return Inference{expectedType ? *expectedType : simplifyUnion(builtinTypes, arena, thenType, elseType).result}; } Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprTypeAssertion* typeAssert) @@ -1902,6 +2014,8 @@ TypeId ConstraintGraphBuilder::checkLValue(const ScopePtr& scope, AstExpr* expr) } else if (auto indexExpr = e->as()) { + // We need to populate the type for the index value + check(scope, indexExpr->index, ValueContext::RValue); if (auto strIndex = indexExpr->index->as()) { segments.push_back(std::string(strIndex->value.data, strIndex->value.size)); @@ -2018,12 +2132,12 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprTable* exp else { expectedValueType = arena->addType(BlockedType{}); - addConstraint(scope, item.value->location, HasPropConstraint{*expectedValueType, *expectedType, stringKey->value.data}); + addConstraint(scope, item.value->location, + HasPropConstraint{*expectedValueType, *expectedType, stringKey->value.data, /*suppressSimplification*/ true}); } } } - // We'll resolve the expected index result type here with the following priority: // 1. Record table types - in which key, value pairs must be handled on a k,v pair basis. // In this case, the above if-statement will populate expectedValueType @@ -2079,7 +2193,7 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprTable* exp } ConstraintGraphBuilder::FunctionSignature ConstraintGraphBuilder::checkFunctionSignature( - const ScopePtr& parent, AstExprFunction* fn, std::optional expectedType) + const ScopePtr& parent, AstExprFunction* fn, std::optional expectedType, std::optional originalName) { ScopePtr signatureScope = nullptr; ScopePtr bodyScope = nullptr; @@ -2235,12 +2349,18 @@ ConstraintGraphBuilder::FunctionSignature ConstraintGraphBuilder::checkFunctionS // TODO: Preserve argument names in the function's type. FunctionType actualFunction{TypeLevel{}, parent.get(), arena->addTypePack(argTypes, varargPack), returnType}; - actualFunction.hasNoGenerics = !hasGenerics; actualFunction.generics = std::move(genericTypes); actualFunction.genericPacks = std::move(genericTypePacks); actualFunction.argNames = std::move(argNames); actualFunction.hasSelf = fn->self != nullptr; + FunctionDefinition defn; + defn.definitionModuleName = module->name; + defn.definitionLocation = fn->location; + defn.varargLocation = fn->vararg ? std::make_optional(fn->varargLocation) : std::nullopt; + defn.originalNameLocation = originalName.value_or(Location(fn->location.begin, 0)); + actualFunction.definition = defn; + TypeId actualFunctionType = arena->addType(std::move(actualFunction)); LUAU_ASSERT(actualFunctionType); module->astTypes[fn] = actualFunctionType; @@ -2283,6 +2403,7 @@ TypeId ConstraintGraphBuilder::resolveType(const ScopePtr& scope, AstType* ty, b if (ref->parameters.size != 1 || !ref->parameters.data[0].type) { reportError(ty->location, GenericError{"_luau_print requires one generic parameter"}); + module->astResolvedTypes[ty] = builtinTypes->errorRecoveryType(); return builtinTypes->errorRecoveryType(); } else @@ -2420,7 +2541,6 @@ TypeId ConstraintGraphBuilder::resolveType(const ScopePtr& scope, AstType* ty, b // This replicates the behavior of the appropriate FunctionType // constructors. - ftv.hasNoGenerics = !hasGenerics; ftv.generics = std::move(genericTypes); ftv.genericPacks = std::move(genericTypePacks); diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index 488fd4baa..9c688f427 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -11,12 +11,13 @@ #include "Luau/Metamethods.h" #include "Luau/ModuleResolver.h" #include "Luau/Quantify.h" +#include "Luau/Simplify.h" #include "Luau/ToString.h" -#include "Luau/TypeUtils.h" #include "Luau/Type.h" +#include "Luau/TypeFamily.h" +#include "Luau/TypeUtils.h" #include "Luau/Unifier.h" #include "Luau/VisitType.h" -#include "Luau/TypeFamily.h" LUAU_FASTFLAGVARIABLE(DebugLuauLogSolver, false); LUAU_FASTFLAG(LuauRequirePathTrueModuleName) @@ -73,7 +74,7 @@ static std::pair, std::vector> saturateArguments // mutually exclusive with the type pack -> type conversion we do below: // extraTypes will only have elements in it if we have more types than we // have parameter slots for them to go into. - if (!extraTypes.empty()) + if (!extraTypes.empty() && !fn.typePackParams.empty()) { saturatedPackArguments.push_back(arena->addTypePack(extraTypes)); } @@ -89,7 +90,7 @@ static std::pair, std::vector> saturateArguments { saturatedTypeArguments.push_back(*first(tp)); } - else + else if (saturatedPackArguments.size() < fn.typePackParams.size()) { saturatedPackArguments.push_back(tp); } @@ -426,7 +427,9 @@ void ConstraintSolver::finalizeModule() rootScope->returnType = builtinTypes->errorTypePack; } else - rootScope->returnType = *returnType; + { + rootScope->returnType = anyifyModuleReturnTypePackGenerics(*returnType); + } } bool ConstraintSolver::tryDispatch(NotNull constraint, bool force) @@ -468,6 +471,8 @@ bool ConstraintSolver::tryDispatch(NotNull constraint, bool fo success = tryDispatch(*sottc, constraint); else if (auto uc = get(*constraint)) success = tryDispatch(*uc, constraint); + else if (auto rc = get(*constraint)) + success = tryDispatch(*rc, constraint, force); else if (auto rc = get(*constraint)) success = tryDispatch(*rc, constraint, force); else if (auto rpc = get(*constraint)) @@ -541,15 +546,25 @@ bool ConstraintSolver::tryDispatch(const InstantiationConstraint& c, NotNullscope); std::optional instantiated = inst.substitute(c.superType); - LUAU_ASSERT(instantiated); // TODO FIXME HANDLE THIS LUAU_ASSERT(get(c.subType)); + + if (!instantiated.has_value()) + { + reportError(UnificationTooComplex{}, constraint->location); + + asMutable(c.subType)->ty.emplace(errorRecoveryType()); + unblock(c.subType); + + return true; + } + asMutable(c.subType)->ty.emplace(*instantiated); InstantiationQueuer queuer{constraint->scope, constraint->location, this}; @@ -759,9 +774,11 @@ bool ConstraintSolver::tryDispatch(const BinaryConstraint& c, NotNullnormalize(leftType); if (hasTypeInIntersection(leftType) && force) asMutable(leftType)->ty.emplace(anyPresent ? builtinTypes->anyType : builtinTypes->numberType); - if (isNumber(leftType)) + if (normLeftTy && normLeftTy->isNumber()) { unify(leftType, rightType, constraint->scope); asMutable(resultType)->ty.emplace(anyPresent ? builtinTypes->anyType : leftType); @@ -770,6 +787,7 @@ bool ConstraintSolver::tryDispatch(const BinaryConstraint& c, NotNulladdType(IntersectionType{{builtinTypes->falsyType, leftType}}); + TypeId leftFilteredTy = simplifyIntersection(builtinTypes, arena, leftType, builtinTypes->falsyType).result; - asMutable(resultType)->ty.emplace(arena->addType(UnionType{{leftFilteredTy, rightType}})); + asMutable(resultType)->ty.emplace(simplifyUnion(builtinTypes, arena, rightType, leftFilteredTy).result); unblock(resultType); return true; } @@ -819,9 +837,9 @@ bool ConstraintSolver::tryDispatch(const BinaryConstraint& c, NotNulladdType(IntersectionType{{builtinTypes->truthyType, leftType}}); + TypeId leftFilteredTy = simplifyIntersection(builtinTypes, arena, leftType, builtinTypes->truthyType).result; - asMutable(resultType)->ty.emplace(arena->addType(UnionType{{leftFilteredTy, rightType}})); + asMutable(resultType)->ty.emplace(simplifyUnion(builtinTypes, arena, rightType, leftFilteredTy).result); unblock(resultType); return true; } @@ -1266,7 +1284,12 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNull instantiated = inst.substitute(overload); - LUAU_ASSERT(instantiated); // TODO FIXME HANDLE THIS + + if (!instantiated.has_value()) + { + reportError(UnificationTooComplex{}, constraint->location); + return true; + } Unifier u{normalizer, Mode::Strict, constraint->scope, Location{}, Covariant}; u.enableScopeTests(); @@ -1374,7 +1397,7 @@ bool ConstraintSolver::tryDispatch(const HasPropConstraint& c, NotNull(followed)) *asMutable(c.resultType) = BoundType{c.discriminantType}; else - *asMutable(c.resultType) = BoundType{builtinTypes->unknownType}; + *asMutable(c.resultType) = BoundType{builtinTypes->anyType}; + + unblock(c.resultType); return true; } @@ -1700,10 +1725,131 @@ bool ConstraintSolver::tryDispatch(const UnpackConstraint& c, NotNull found; + bool visit(TypeId ty, const BlockedType&) override + { + found.insert(ty); + return false; + } + + bool visit(TypeId ty, const PendingExpansionType&) override + { + found.insert(ty); + return false; + } +}; + +} + +static bool isNegatedAny(TypeId ty) +{ + ty = follow(ty); + const NegationType* nt = get(ty); + if (!nt) + return false; + TypeId negatedTy = follow(nt->ty); + return bool(get(negatedTy)); +} + +bool ConstraintSolver::tryDispatch(const RefineConstraint& c, NotNull constraint, bool force) +{ + if (isBlocked(c.discriminant)) + return block(c.discriminant, constraint); + + FindRefineConstraintBlockers fbt; + fbt.traverse(c.discriminant); + + if (!fbt.found.empty()) + { + bool foundOne = false; + + for (TypeId blocked : fbt.found) + { + if (blocked == c.type) + continue; + + block(blocked, constraint); + foundOne = true; + } + + if (foundOne) + return false; + } + + /* HACK: Refinements sometimes produce a type T & ~any under the assumption + * that ~any is the same as any. This is so so weird, but refinements needs + * some way to say "I may refine this, but I'm not sure." + * + * It does this by refining on a blocked type and deferring the decision + * until it is unblocked. + * + * Refinements also get negated, so we wind up with types like T & ~*blocked* + * + * We need to treat T & ~any as T in this case. + */ + + if (c.mode == RefineConstraint::Intersection && isNegatedAny(c.discriminant)) + { + asMutable(c.resultType)->ty.emplace(c.type); + unblock(c.resultType); + return true; + } + + const TypeId type = follow(c.type); + + LUAU_ASSERT(get(c.resultType)); + + if (type == c.resultType) + { + /* + * Sometimes, we get a constraint of the form + * + * *blocked-N* ~ refine *blocked-N* & U + * + * The constraint essentially states that a particular type is a + * refinement of itself. This is weird and I think vacuous. + * + * I *believe* it is safe to replace the result with a fresh type that + * is constrained by U. We effect this by minting a fresh type for the + * result when U = any, else we bind the result to whatever discriminant + * was offered. + */ + if (get(follow(c.discriminant))) + asMutable(c.resultType)->ty.emplace(constraint->scope); + else + asMutable(c.resultType)->ty.emplace(c.discriminant); + + unblock(c.resultType); + return true; + } + + auto [result, blockedTypes] = c.mode == RefineConstraint::Intersection ? simplifyIntersection(builtinTypes, NotNull{arena}, type, c.discriminant) + : simplifyUnion(builtinTypes, NotNull{arena}, type, c.discriminant); + + if (!force && !blockedTypes.empty()) + return block(blockedTypes, constraint); + + asMutable(c.resultType)->ty.emplace(result); + + unblock(c.resultType); + + return true; +} + bool ConstraintSolver::tryDispatch(const ReduceConstraint& c, NotNull constraint, bool force) { TypeId ty = follow(c.ty); - FamilyGraphReductionResult result = reduceFamilies(ty, constraint->location, NotNull{arena}, builtinTypes, nullptr, force); + FamilyGraphReductionResult result = + reduceFamilies(ty, constraint->location, NotNull{arena}, builtinTypes, constraint->scope, normalizer, nullptr, force); for (TypeId r : result.reducedTypes) unblock(r); @@ -1726,7 +1872,8 @@ bool ConstraintSolver::tryDispatch(const ReduceConstraint& c, NotNull constraint, bool force) { TypePackId tp = follow(c.tp); - FamilyGraphReductionResult result = reduceFamilies(tp, constraint->location, NotNull{arena}, builtinTypes, nullptr, force); + FamilyGraphReductionResult result = + reduceFamilies(tp, constraint->location, NotNull{arena}, builtinTypes, constraint->scope, normalizer, nullptr, force); for (TypeId r : result.reducedTypes) unblock(r); @@ -1951,13 +2098,15 @@ bool ConstraintSolver::tryDispatchIterableFunction( return true; } -std::pair, std::optional> ConstraintSolver::lookupTableProp(TypeId subjectType, const std::string& propName) +std::pair, std::optional> ConstraintSolver::lookupTableProp( + TypeId subjectType, const std::string& propName, bool suppressSimplification) { std::unordered_set seen; - return lookupTableProp(subjectType, propName, seen); + return lookupTableProp(subjectType, propName, suppressSimplification, seen); } -std::pair, std::optional> ConstraintSolver::lookupTableProp(TypeId subjectType, const std::string& propName, std::unordered_set& seen) +std::pair, std::optional> ConstraintSolver::lookupTableProp( + TypeId subjectType, const std::string& propName, bool suppressSimplification, std::unordered_set& seen) { if (!seen.insert(subjectType).second) return {}; @@ -1985,7 +2134,7 @@ std::pair, std::optional> ConstraintSolver::lookupTa } else if (auto mt = get(subjectType)) { - auto [blocked, result] = lookupTableProp(mt->table, propName, seen); + auto [blocked, result] = lookupTableProp(mt->table, propName, suppressSimplification, seen); if (!blocked.empty() || result) return {blocked, result}; @@ -2016,13 +2165,17 @@ std::pair, std::optional> ConstraintSolver::lookupTa } } else - return lookupTableProp(indexType, propName, seen); + return lookupTableProp(indexType, propName, suppressSimplification, seen); } } else if (auto ct = get(subjectType)) { if (auto p = lookupClassProp(ct, propName)) return {{}, p->type()}; + if (ct->indexer) + { + return {{}, ct->indexer->indexResultType}; + } } else if (auto pt = get(subjectType); pt && pt->metatable) { @@ -2033,7 +2186,7 @@ std::pair, std::optional> ConstraintSolver::lookupTa if (indexProp == metatable->props.end()) return {{}, std::nullopt}; - return lookupTableProp(indexProp->second.type(), propName, seen); + return lookupTableProp(indexProp->second.type(), propName, suppressSimplification, seen); } else if (auto ft = get(subjectType)) { @@ -2054,7 +2207,7 @@ std::pair, std::optional> ConstraintSolver::lookupTa for (TypeId ty : utv) { - auto [innerBlocked, innerResult] = lookupTableProp(ty, propName, seen); + auto [innerBlocked, innerResult] = lookupTableProp(ty, propName, suppressSimplification, seen); blocked.insert(blocked.end(), innerBlocked.begin(), innerBlocked.end()); if (innerResult) options.insert(*innerResult); @@ -2067,6 +2220,12 @@ std::pair, std::optional> ConstraintSolver::lookupTa return {{}, std::nullopt}; else if (options.size() == 1) return {{}, *begin(options)}; + else if (options.size() == 2 && !suppressSimplification) + { + TypeId one = *begin(options); + TypeId two = *(++begin(options)); + return {{}, simplifyUnion(builtinTypes, arena, one, two).result}; + } else return {{}, arena->addType(UnionType{std::vector(begin(options), end(options))})}; } @@ -2077,7 +2236,7 @@ std::pair, std::optional> ConstraintSolver::lookupTa for (TypeId ty : itv) { - auto [innerBlocked, innerResult] = lookupTableProp(ty, propName, seen); + auto [innerBlocked, innerResult] = lookupTableProp(ty, propName, suppressSimplification, seen); blocked.insert(blocked.end(), innerBlocked.begin(), innerBlocked.end()); if (innerResult) options.insert(*innerResult); @@ -2090,6 +2249,12 @@ std::pair, std::optional> ConstraintSolver::lookupTa return {{}, std::nullopt}; else if (options.size() == 1) return {{}, *begin(options)}; + else if (options.size() == 2 && !suppressSimplification) + { + TypeId one = *begin(options); + TypeId two = *(++begin(options)); + return {{}, simplifyIntersection(builtinTypes, arena, one, two).result}; + } else return {{}, arena->addType(IntersectionType{std::vector(begin(options), end(options))})}; } @@ -2214,13 +2379,6 @@ struct Blocker : TypeOnceVisitor { } - bool visit(TypeId ty, const BlockedType&) - { - blocked = true; - solver->block(ty, constraint); - return false; - } - bool visit(TypeId ty, const PendingExpansionType&) { blocked = true; @@ -2229,14 +2387,14 @@ struct Blocker : TypeOnceVisitor } }; -bool ConstraintSolver::recursiveBlock(TypeId target, NotNull constraint) +bool ConstraintSolver::blockOnPendingTypes(TypeId target, NotNull constraint) { Blocker blocker{NotNull{this}, constraint}; blocker.traverse(target); return !blocker.blocked; } -bool ConstraintSolver::recursiveBlock(TypePackId pack, NotNull constraint) +bool ConstraintSolver::blockOnPendingTypes(TypePackId pack, NotNull constraint) { Blocker blocker{NotNull{this}, constraint}; blocker.traverse(pack); @@ -2482,4 +2640,34 @@ TypeId ConstraintSolver::unionOfTypes(TypeId a, TypeId b, NotNull scope, return arena->addType(UnionType{types}); } +TypePackId ConstraintSolver::anyifyModuleReturnTypePackGenerics(TypePackId tp) +{ + tp = follow(tp); + + if (const VariadicTypePack* vtp = get(tp)) + { + TypeId ty = follow(vtp->ty); + return get(ty) ? builtinTypes->anyTypePack : tp; + } + + if (!get(follow(tp))) + return tp; + + std::vector resultTypes; + std::optional resultTail; + + TypePackIterator it = begin(tp); + + for (TypePackIterator e = end(tp); it != e; ++it) + { + TypeId ty = follow(*it); + resultTypes.push_back(get(ty) ? builtinTypes->anyType : ty); + } + + if (std::optional tail = it.tail()) + resultTail = anyifyModuleReturnTypePackGenerics(*tail); + + return arena->addTypePack(resultTypes, resultTail); +} + } // namespace Luau diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index b16eda8a9..07393eb12 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -16,7 +16,6 @@ #include "Luau/TimeTrace.h" #include "Luau/TypeChecker2.h" #include "Luau/TypeInfer.h" -#include "Luau/TypeReduction.h" #include "Luau/Variant.h" #include @@ -622,7 +621,6 @@ CheckResult Frontend::check_DEPRECATED(const ModuleName& name, std::optionalastOriginalCallTypes.clear(); module->astOverloadResolvedTypes.clear(); module->astResolvedTypes.clear(); - module->astOriginalResolvedTypes.clear(); module->astResolvedTypePacks.clear(); module->astScopes.clear(); @@ -1138,7 +1136,6 @@ void Frontend::checkBuildQueueItem(BuildQueueItem& item) module->astOriginalCallTypes.clear(); module->astOverloadResolvedTypes.clear(); module->astResolvedTypes.clear(); - module->astOriginalResolvedTypes.clear(); module->astResolvedTypePacks.clear(); module->astScopes.clear(); @@ -1311,7 +1308,6 @@ ModulePtr check(const SourceModule& sourceModule, const std::vector(); result->name = sourceModule.name; result->humanReadableName = sourceModule.humanReadableName; - result->reduction = std::make_unique(NotNull{&result->internalTypes}, builtinTypes, iceHandler); std::unique_ptr logger; if (recordJsonLog) @@ -1365,11 +1361,17 @@ ModulePtr check(const SourceModule& sourceModule, const std::vectorinternalTypes); freeze(result->interfaceTypes); diff --git a/Analysis/src/Instantiation.cpp b/Analysis/src/Instantiation.cpp index 7d0f0f72f..1d6092f87 100644 --- a/Analysis/src/Instantiation.cpp +++ b/Analysis/src/Instantiation.cpp @@ -13,7 +13,7 @@ bool Instantiation::isDirty(TypeId ty) { if (const FunctionType* ftv = log->getMutable(ty)) { - if (ftv->hasNoGenerics) + if (ftv->hasNoFreeOrGenericTypes) return false; return true; @@ -74,7 +74,7 @@ bool ReplaceGenerics::ignoreChildren(TypeId ty) { if (const FunctionType* ftv = log->getMutable(ty)) { - if (ftv->hasNoGenerics) + if (ftv->hasNoFreeOrGenericTypes) return true; // We aren't recursing in the case of a generic function which diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index 830aaf754..0addaa360 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -10,7 +10,6 @@ #include "Luau/Type.h" #include "Luau/TypeInfer.h" #include "Luau/TypePack.h" -#include "Luau/TypeReduction.h" #include "Luau/VisitType.h" #include @@ -20,7 +19,6 @@ LUAU_FASTFLAGVARIABLE(LuauClonePublicInterfaceLess2, false); LUAU_FASTFLAG(LuauSubstitutionReentrant); LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution); LUAU_FASTFLAG(LuauSubstitutionFixMissingFields); -LUAU_FASTFLAGVARIABLE(LuauCopyExportedTypes, false); namespace Luau { @@ -238,10 +236,7 @@ void Module::clonePublicInterface(NotNull builtinTypes, InternalEr // Copy external stuff over to Module itself this->returnType = moduleScope->returnType; - if (FFlag::DebugLuauDeferredConstraintResolution || FFlag::LuauCopyExportedTypes) - this->exportedTypeBindings = moduleScope->exportedTypeBindings; - else - this->exportedTypeBindings = std::move(moduleScope->exportedTypeBindings); + this->exportedTypeBindings = moduleScope->exportedTypeBindings; } bool Module::hasModuleScope() const diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp index cfc0ae137..24c31f7ec 100644 --- a/Analysis/src/Normalize.cpp +++ b/Analysis/src/Normalize.cpp @@ -18,7 +18,6 @@ LUAU_FASTFLAGVARIABLE(DebugLuauCheckNormalizeInvariant, false) LUAU_FASTINTVARIABLE(LuauNormalizeIterationLimit, 1200); LUAU_FASTINTVARIABLE(LuauNormalizeCacheLimit, 100000); LUAU_FASTFLAGVARIABLE(LuauNormalizeBlockedTypes, false); -LUAU_FASTFLAGVARIABLE(LuauNormalizeMetatableFixes, false); LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) LUAU_FASTFLAG(LuauUninhabitedSubAnything2) LUAU_FASTFLAG(LuauTransitiveSubtyping) @@ -228,6 +227,16 @@ NormalizedType::NormalizedType(NotNull builtinTypes) { } +bool NormalizedType::isFunction() const +{ + return !get(tops) || !functions.parts.empty(); +} + +bool NormalizedType::isNumber() const +{ + return !get(tops) || !get(numbers); +} + static bool isShallowInhabited(const NormalizedType& norm) { // This test is just a shallow check, for example it returns `true` for `{ p : never }` @@ -516,7 +525,8 @@ static bool areNormalizedClasses(const NormalizedClassType& tys) static bool isPlainTyvar(TypeId ty) { - return (get(ty) || get(ty) || (FFlag::LuauNormalizeBlockedTypes && get(ty)) || get(ty)); + return (get(ty) || get(ty) || (FFlag::LuauNormalizeBlockedTypes && get(ty)) || + get(ty) || get(ty)); } static bool isNormalizedTyvar(const NormalizedTyvars& tyvars) @@ -1366,7 +1376,7 @@ bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, int ignor else if (FFlag::LuauTransitiveSubtyping && get(here.tops)) return true; else if (get(there) || get(there) || (FFlag::LuauNormalizeBlockedTypes && get(there)) || - get(there)) + get(there) || get(there)) { if (tyvarIndex(there) <= ignoreSmallerTyvars) return true; @@ -1436,7 +1446,7 @@ bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, int ignor } else if (!FFlag::LuauNormalizeBlockedTypes && get(there)) LUAU_ASSERT(!"Internal error: Trying to normalize a BlockedType"); - else if (get(there)) + else if (get(there) || get(there)) { // nothing } @@ -1981,17 +1991,14 @@ std::optional Normalizer::intersectionOfTables(TypeId here, TypeId there else if (isPrim(there, PrimitiveType::Table)) return here; - if (FFlag::LuauNormalizeMetatableFixes) - { - if (get(here)) - return there; - else if (get(there)) - return here; - else if (get(here)) - return there; - else if (get(there)) - return here; - } + if (get(here)) + return there; + else if (get(there)) + return here; + else if (get(here)) + return there; + else if (get(there)) + return here; TypeId htable = here; TypeId hmtable = nullptr; @@ -2009,22 +2016,12 @@ std::optional Normalizer::intersectionOfTables(TypeId here, TypeId there } const TableType* httv = get(htable); - if (FFlag::LuauNormalizeMetatableFixes) - { - if (!httv) - return std::nullopt; - } - else - LUAU_ASSERT(httv); + if (!httv) + return std::nullopt; const TableType* tttv = get(ttable); - if (FFlag::LuauNormalizeMetatableFixes) - { - if (!tttv) - return std::nullopt; - } - else - LUAU_ASSERT(tttv); + if (!tttv) + return std::nullopt; if (httv->state == TableState::Free || tttv->state == TableState::Free) @@ -2471,7 +2468,7 @@ bool Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there) return true; } else if (get(there) || get(there) || (FFlag::LuauNormalizeBlockedTypes && get(there)) || - get(there)) + get(there) || get(there)) { NormalizedType thereNorm{builtinTypes}; NormalizedType topNorm{builtinTypes}; diff --git a/Analysis/src/Quantify.cpp b/Analysis/src/Quantify.cpp index 5a7a05011..3528d5345 100644 --- a/Analysis/src/Quantify.cpp +++ b/Analysis/src/Quantify.cpp @@ -136,7 +136,7 @@ void quantify(TypeId ty, TypeLevel level) ftv->genericPacks.insert(ftv->genericPacks.end(), q.genericPacks.begin(), q.genericPacks.end()); if (ftv->generics.empty() && ftv->genericPacks.empty() && !q.seenMutableType && !q.seenGenericType) - ftv->hasNoGenerics = true; + ftv->hasNoFreeOrGenericTypes = true; } } else @@ -276,7 +276,7 @@ std::optional quantify(TypeArena* arena, TypeId ty, Scope* sco for (auto k : quantifier.insertedGenericPacks.keys) ftv->genericPacks.push_back(quantifier.insertedGenericPacks.pairings[k]); - ftv->hasNoGenerics = ftv->generics.empty() && ftv->genericPacks.empty() && !quantifier.seenGenericType && !quantifier.seenMutableType; + ftv->hasNoFreeOrGenericTypes = ftv->generics.empty() && ftv->genericPacks.empty() && !quantifier.seenGenericType && !quantifier.seenMutableType; return std::optional({*result, std::move(quantifier.insertedGenerics), std::move(quantifier.insertedGenericPacks)}); } diff --git a/Analysis/src/Simplify.cpp b/Analysis/src/Simplify.cpp new file mode 100644 index 000000000..8e9424ae0 --- /dev/null +++ b/Analysis/src/Simplify.cpp @@ -0,0 +1,1270 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/Simplify.h" + +#include "Luau/RecursionCounter.h" +#include "Luau/ToString.h" +#include "Luau/TypeArena.h" +#include "Luau/Normalize.h" // TypeIds + +LUAU_FASTINT(LuauTypeReductionRecursionLimit) + +namespace Luau +{ + +struct TypeSimplifier +{ + NotNull builtinTypes; + NotNull arena; + + std::set blockedTypes; + + int recursionDepth = 0; + + TypeId mkNegation(TypeId ty); + + TypeId intersectFromParts(std::set parts); + + TypeId intersectUnionWithType(TypeId unionTy, TypeId right); + TypeId intersectUnions(TypeId left, TypeId right); + TypeId intersectNegatedUnion(TypeId unionTy, TypeId right); + + TypeId intersectTypeWithNegation(TypeId a, TypeId b); + TypeId intersectNegations(TypeId a, TypeId b); + + TypeId intersectIntersectionWithType(TypeId left, TypeId right); + + // Attempt to intersect the two types. Does not recurse. Does not handle + // unions, intersections, or negations. + std::optional basicIntersect(TypeId left, TypeId right); + + TypeId intersect(TypeId ty, TypeId discriminant); + TypeId union_(TypeId ty, TypeId discriminant); + + TypeId simplify(TypeId ty); + TypeId simplify(TypeId ty, DenseHashSet& seen); +}; + +template +static std::pair get2(TID one, TID two) +{ + const A* a = get(one); + const B* b = get(two); + return a && b ? std::make_pair(a, b) : std::make_pair(nullptr, nullptr); +} + +// Match the exact type false|nil +static bool isFalsyType(TypeId ty) +{ + ty = follow(ty); + const UnionType* ut = get(ty); + if (!ut) + return false; + + bool hasFalse = false; + bool hasNil = false; + + auto it = begin(ut); + if (it == end(ut)) + return false; + + TypeId t = follow(*it); + + if (auto pt = get(t); pt && pt->type == PrimitiveType::NilType) + hasNil = true; + else if (auto st = get(t); st && st->variant == BooleanSingleton{false}) + hasFalse = true; + else + return false; + + ++it; + if (it == end(ut)) + return false; + + t = follow(*it); + + if (auto pt = get(t); pt && pt->type == PrimitiveType::NilType) + hasNil = true; + else if (auto st = get(t); st && st->variant == BooleanSingleton{false}) + hasFalse = true; + else + return false; + + ++it; + if (it != end(ut)) + return false; + + return hasFalse && hasNil; +} + +// Match the exact type ~(false|nil) +bool isTruthyType(TypeId ty) +{ + ty = follow(ty); + + const NegationType* nt = get(ty); + if (!nt) + return false; + + return isFalsyType(nt->ty); +} + +Relation flip(Relation rel) +{ + switch (rel) + { + case Relation::Subset: + return Relation::Superset; + case Relation::Superset: + return Relation::Subset; + default: + return rel; + } +} + +// FIXME: I'm not completely certain that this function is theoretically reasonable. +Relation combine(Relation a, Relation b) +{ + switch (a) + { + case Relation::Disjoint: + switch (b) + { + case Relation::Disjoint: + return Relation::Disjoint; + case Relation::Coincident: + return Relation::Superset; + case Relation::Intersects: + return Relation::Intersects; + case Relation::Subset: + return Relation::Intersects; + case Relation::Superset: + return Relation::Intersects; + } + case Relation::Coincident: + switch (b) + { + case Relation::Disjoint: + return Relation::Coincident; + case Relation::Coincident: + return Relation::Coincident; + case Relation::Intersects: + return Relation::Superset; + case Relation::Subset: + return Relation::Coincident; + case Relation::Superset: + return Relation::Intersects; + } + case Relation::Superset: + switch (b) + { + case Relation::Disjoint: + return Relation::Superset; + case Relation::Coincident: + return Relation::Superset; + case Relation::Intersects: + return Relation::Intersects; + case Relation::Subset: + return Relation::Intersects; + case Relation::Superset: + return Relation::Superset; + } + case Relation::Subset: + switch (b) + { + case Relation::Disjoint: + return Relation::Subset; + case Relation::Coincident: + return Relation::Coincident; + case Relation::Intersects: + return Relation::Intersects; + case Relation::Subset: + return Relation::Subset; + case Relation::Superset: + return Relation::Intersects; + } + case Relation::Intersects: + switch (b) + { + case Relation::Disjoint: + return Relation::Intersects; + case Relation::Coincident: + return Relation::Superset; + case Relation::Intersects: + return Relation::Intersects; + case Relation::Subset: + return Relation::Intersects; + case Relation::Superset: + return Relation::Intersects; + } + } + + LUAU_UNREACHABLE(); + return Relation::Intersects; +} + +// Given A & B, what is A & ~B? +Relation invert(Relation r) +{ + switch (r) + { + case Relation::Disjoint: + return Relation::Subset; + case Relation::Coincident: + return Relation::Disjoint; + case Relation::Intersects: + return Relation::Intersects; + case Relation::Subset: + return Relation::Disjoint; + case Relation::Superset: + return Relation::Intersects; + } + + LUAU_UNREACHABLE(); + return Relation::Intersects; +} + +static bool isTypeVariable(TypeId ty) +{ + return get(ty) || get(ty) || get(ty) || get(ty); +} + +Relation relate(TypeId left, TypeId right); + +Relation relateTables(TypeId left, TypeId right) +{ + NotNull leftTable{get(left)}; + NotNull rightTable{get(right)}; + LUAU_ASSERT(1 == rightTable->props.size()); + + const auto [propName, rightProp] = *begin(rightTable->props); + + auto it = leftTable->props.find(propName); + if (it == leftTable->props.end()) + { + // Every table lacking a property is a supertype of a table having that + // property but the reverse is not true. + return Relation::Superset; + } + + const Property leftProp = it->second; + + Relation r = relate(leftProp.type(), rightProp.type()); + if (r == Relation::Coincident && 1 != leftTable->props.size()) + { + // eg {tag: "cat", prop: string} & {tag: "cat"} + return Relation::Subset; + } + else + return r; +} + +// A cheap and approximate subtype test +Relation relate(TypeId left, TypeId right) +{ + // TODO nice to have: Relate functions of equal argument and return arity + + left = follow(left); + right = follow(right); + + if (left == right) + return Relation::Coincident; + + if (get(left)) + { + if (get(right)) + return Relation::Subset; + else if (get(right)) + return Relation::Coincident; + else if (get(right)) + return Relation::Disjoint; + else + return Relation::Superset; + } + + if (get(right)) + return flip(relate(right, left)); + + if (get(left)) + { + if (get(right)) + return Relation::Coincident; + else + return Relation::Superset; + } + + if (get(right)) + return flip(relate(right, left)); + + // Type variables + // * FreeType + // * GenericType + // * BlockedType + // * PendingExpansionType + + // Tops and bottoms + // * ErrorType + // * AnyType + // * NeverType + // * UnknownType + + // Concrete + // * PrimitiveType + // * SingletonType + // * FunctionType + // * TableType + // * MetatableType + // * ClassType + // * UnionType + // * IntersectionType + // * NegationType + + if (isTypeVariable(left) || isTypeVariable(right)) + return Relation::Intersects; + + if (get(left)) + { + if (get(right)) + return Relation::Coincident; + else if (get(right)) + return Relation::Subset; + else + return Relation::Disjoint; + } + if (get(right)) + return flip(relate(right, left)); + + if (get(left)) + { + if (get(right)) + return Relation::Coincident; + else + return Relation::Subset; + } + if (get(right)) + return flip(relate(right, left)); + + if (auto ut = get(left)) + return Relation::Intersects; + else if (auto ut = get(right)) + return Relation::Intersects; + + if (auto ut = get(left)) + return Relation::Intersects; + else if (auto ut = get(right)) + return Relation::Intersects; + + if (auto rnt = get(right)) + { + Relation a = relate(left, rnt->ty); + switch (a) + { + case Relation::Coincident: + // number & ~number + return Relation::Disjoint; + case Relation::Disjoint: + if (get(left)) + { + // ~number & ~string + return Relation::Intersects; + } + else + { + // number & ~string + return Relation::Subset; + } + case Relation::Intersects: + // ~(false?) & ~boolean + return Relation::Intersects; + case Relation::Subset: + // "hello" & ~string + return Relation::Disjoint; + case Relation::Superset: + // ~function & ~(false?) -> ~function + // boolean & ~(false?) -> true + // string & ~"hello" -> string & ~"hello" + return Relation::Intersects; + } + } + else if (get(left)) + return flip(relate(right, left)); + + if (auto lp = get(left)) + { + if (auto rp = get(right)) + { + if (lp->type == rp->type) + return Relation::Coincident; + else + return Relation::Disjoint; + } + + if (auto rs = get(right)) + { + if (lp->type == PrimitiveType::String && rs->variant.get_if()) + return Relation::Superset; + else if (lp->type == PrimitiveType::Boolean && rs->variant.get_if()) + return Relation::Superset; + else + return Relation::Disjoint; + } + + if (lp->type == PrimitiveType::Function) + { + if (get(right)) + return Relation::Superset; + else + return Relation::Disjoint; + } + if (lp->type == PrimitiveType::Table) + { + if (get(right)) + return Relation::Superset; + else + return Relation::Disjoint; + } + + if (get(right) || get(right) || get(right) || get(right)) + return Relation::Disjoint; + } + + if (auto ls = get(left)) + { + if (get(right) || get(right) || get(right) || get(right)) + return Relation::Disjoint; + + if (get(right)) + return flip(relate(right, left)); + if (auto rs = get(right)) + { + if (ls->variant == rs->variant) + return Relation::Coincident; + else + return Relation::Disjoint; + } + } + + if (get(left)) + { + if (auto rp = get(right)) + { + if (rp->type == PrimitiveType::Function) + return Relation::Subset; + else + return Relation::Disjoint; + } + else + return Relation::Intersects; + } + + if (auto lt = get(left)) + { + if (auto rp = get(right)) + { + if (rp->type == PrimitiveType::Table) + return Relation::Subset; + else + return Relation::Disjoint; + } + else if (auto rt = get(right)) + { + // TODO PROBABLY indexers and metatables. + if (1 == rt->props.size()) + { + Relation r = relateTables(left, right); + /* + * A reduction of these intersections is certainly possible, but + * it would require minting new table types. Also, I don't think + * it's super likely for this to arise from a refinement. + * + * Time will tell! + * + * ex we simplify this + * {tag: string} & {tag: "cat"} + * but not this + * {tag: string, prop: number} & {tag: "cat"} + */ + if (lt->props.size() > 1 && r == Relation::Superset) + return Relation::Intersects; + else + return r; + } + else if (1 == lt->props.size()) + return flip(relate(right, left)); + else + return Relation::Intersects; + } + // TODO metatables + + return Relation::Disjoint; + } + + if (auto ct = get(left)) + { + if (auto rct = get(right)) + { + if (isSubclass(ct, rct)) + return Relation::Subset; + else if (isSubclass(rct, ct)) + return Relation::Superset; + else + return Relation::Disjoint; + } + + return Relation::Disjoint; + } + + return Relation::Intersects; +} + +TypeId TypeSimplifier::mkNegation(TypeId ty) +{ + TypeId result = nullptr; + + if (ty == builtinTypes->truthyType) + result = builtinTypes->falsyType; + else if (ty == builtinTypes->falsyType) + result = builtinTypes->truthyType; + else if (auto ntv = get(ty)) + result = follow(ntv->ty); + else + result = arena->addType(NegationType{ty}); + + return result; +} + +TypeId TypeSimplifier::intersectFromParts(std::set parts) +{ + if (0 == parts.size()) + return builtinTypes->neverType; + else if (1 == parts.size()) + return *begin(parts); + + { + auto it = begin(parts); + while (it != end(parts)) + { + TypeId t = follow(*it); + + auto copy = it; + ++it; + + if (auto ut = get(t)) + { + for (TypeId part : ut) + parts.insert(part); + parts.erase(copy); + } + } + } + + std::set newParts; + + /* + * It is possible that the parts of the passed intersection are themselves + * reducable. + * + * eg false & boolean + * + * We do a comparison between each pair of types and look for things that we + * can elide. + */ + for (TypeId part : parts) + { + if (newParts.empty()) + { + newParts.insert(part); + continue; + } + + auto it = begin(newParts); + while (it != end(newParts)) + { + TypeId p = *it; + + switch (relate(part, p)) + { + case Relation::Disjoint: + // eg boolean & string + return builtinTypes->neverType; + case Relation::Subset: + { + /* part is a subset of p. Remove p from the set and replace it + * with part. + * + * eg boolean & true + */ + auto saveIt = it; + ++it; + newParts.erase(saveIt); + continue; + } + case Relation::Coincident: + case Relation::Superset: + { + /* part is coincident or a superset of p. We do not need to + * include part in the final intersection. + * + * ex true & boolean + */ + ++it; + continue; + } + case Relation::Intersects: + { + /* It's complicated! A simplification may still be possible, + * but we have to pull the types apart to figure it out. + * + * ex boolean & ~false + */ + std::optional simplified = basicIntersect(part, p); + + auto saveIt = it; + ++it; + + if (simplified) + { + newParts.erase(saveIt); + newParts.insert(*simplified); + } + else + newParts.insert(part); + continue; + } + } + } + } + + if (0 == newParts.size()) + return builtinTypes->neverType; + else if (1 == newParts.size()) + return *begin(newParts); + else + return arena->addType(IntersectionType{std::vector{begin(newParts), end(newParts)}}); +} + +TypeId TypeSimplifier::intersectUnionWithType(TypeId left, TypeId right) +{ + const UnionType* leftUnion = get(left); + LUAU_ASSERT(leftUnion); + + bool changed = false; + std::set newParts; + + for (TypeId part : leftUnion) + { + TypeId simplified = intersect(right, part); + changed |= simplified != part; + + if (get(simplified)) + { + changed = true; + continue; + } + + newParts.insert(simplified); + } + + if (!changed) + return left; + else if (newParts.empty()) + return builtinTypes->neverType; + else if (newParts.size() == 1) + return *begin(newParts); + else + return arena->addType(UnionType{std::vector(begin(newParts), end(newParts))}); +} + +TypeId TypeSimplifier::intersectUnions(TypeId left, TypeId right) +{ + const UnionType* leftUnion = get(left); + LUAU_ASSERT(leftUnion); + + const UnionType* rightUnion = get(right); + LUAU_ASSERT(rightUnion); + + std::set newParts; + + for (TypeId leftPart : leftUnion) + { + for (TypeId rightPart : rightUnion) + { + TypeId simplified = intersect(leftPart, rightPart); + if (get(simplified)) + continue; + + newParts.insert(simplified); + } + } + + if (newParts.empty()) + return builtinTypes->neverType; + else if (newParts.size() == 1) + return *begin(newParts); + else + return arena->addType(UnionType{std::vector(begin(newParts), end(newParts))}); +} + +TypeId TypeSimplifier::intersectNegatedUnion(TypeId left, TypeId right) +{ + // ~(A | B) & C + // (~A & C) & (~B & C) + + const NegationType* leftNegation = get(left); + LUAU_ASSERT(leftNegation); + + TypeId negatedTy = follow(leftNegation->ty); + + const UnionType* negatedUnion = get(negatedTy); + LUAU_ASSERT(negatedUnion); + + bool changed = false; + std::set newParts; + + for (TypeId part : negatedUnion) + { + Relation r = relate(part, right); + switch (r) + { + case Relation::Disjoint: + // If A is disjoint from B, then ~A & B is just B. + // + // ~(false?) & true + // (~false & true) & (~nil & true) + // true & true + newParts.insert(right); + break; + case Relation::Coincident: + // If A is coincident with or a superset of B, then ~A & B is never. + // + // ~(false?) & false + // (~false & false) & (~nil & false) + // never & false + // + // fallthrough + case Relation::Superset: + // If A is a superset of B, then ~A & B is never. + // + // ~(boolean | nil) & true + // (~boolean & true) & (~boolean & nil) + // never & nil + return builtinTypes->neverType; + case Relation::Subset: + case Relation::Intersects: + // If A is a subset of B, then ~A & B is a bit more complicated. We need to think harder. + // + // ~(false?) & boolean + // (~false & boolean) & (~nil & boolean) + // true & boolean + TypeId simplified = intersectTypeWithNegation(mkNegation(part), right); + changed |= simplified != right; + if (get(simplified)) + changed = true; + else + newParts.insert(simplified); + break; + } + } + + if (!changed) + return right; + else + return intersectFromParts(std::move(newParts)); +} + +TypeId TypeSimplifier::intersectTypeWithNegation(TypeId left, TypeId right) +{ + const NegationType* leftNegation = get(left); + LUAU_ASSERT(leftNegation); + + TypeId negatedTy = follow(leftNegation->ty); + + if (negatedTy == right) + return builtinTypes->neverType; + + if (auto ut = get(negatedTy)) + { + // ~(A | B) & C + // (~A & C) & (~B & C) + + bool changed = false; + std::set newParts; + + for (TypeId part : ut) + { + Relation r = relate(part, right); + switch (r) + { + case Relation::Coincident: + // ~(false?) & nil + // (~false & nil) & (~nil & nil) + // nil & never + // + // fallthrough + case Relation::Superset: + // ~(boolean | string) & true + // (~boolean & true) & (~boolean & string) + // never & string + + return builtinTypes->neverType; + + case Relation::Disjoint: + // ~nil & boolean + newParts.insert(right); + break; + + case Relation::Subset: + // ~false & boolean + // fallthrough + case Relation::Intersects: + // FIXME: The mkNegation here is pretty unfortunate. + // Memoizing this will probably be important. + changed = true; + newParts.insert(right); + newParts.insert(mkNegation(part)); + } + } + + if (!changed) + return right; + else + return intersectFromParts(std::move(newParts)); + } + + if (auto rightUnion = get(right)) + { + // ~A & (B | C) + bool changed = false; + std::set newParts; + + for (TypeId part : rightUnion) + { + Relation r = relate(negatedTy, part); + switch (r) + { + case Relation::Coincident: + changed = true; + continue; + case Relation::Disjoint: + newParts.insert(part); + break; + case Relation::Superset: + changed = true; + continue; + case Relation::Subset: + // fallthrough + case Relation::Intersects: + changed = true; + newParts.insert(arena->addType(IntersectionType{{left, part}})); + } + } + + if (!changed) + return right; + else if (0 == newParts.size()) + return builtinTypes->neverType; + else if (1 == newParts.size()) + return *begin(newParts); + else + return arena->addType(UnionType{std::vector{begin(newParts), end(newParts)}}); + } + + if (auto pt = get(right); pt && pt->type == PrimitiveType::Boolean) + { + if (auto st = get(negatedTy)) + { + if (st->variant == BooleanSingleton{true}) + return builtinTypes->falseType; + else if (st->variant == BooleanSingleton{false}) + return builtinTypes->trueType; + else + // boolean & ~"hello" + return builtinTypes->booleanType; + } + } + + Relation r = relate(negatedTy, right); + + switch (r) + { + case Relation::Disjoint: + // ~boolean & string + return right; + case Relation::Coincident: + // ~string & string + // fallthrough + case Relation::Superset: + // ~string & "hello" + return builtinTypes->neverType; + case Relation::Subset: + // ~string & unknown + // ~"hello" & string + // fallthrough + case Relation::Intersects: + // ~("hello" | boolean) & string + // fallthrough + default: + return arena->addType(IntersectionType{{left, right}}); + } +} + +TypeId TypeSimplifier::intersectNegations(TypeId left, TypeId right) +{ + const NegationType* leftNegation = get(left); + LUAU_ASSERT(leftNegation); + + if (get(follow(leftNegation->ty))) + return intersectNegatedUnion(left, right); + + const NegationType* rightNegation = get(right); + LUAU_ASSERT(rightNegation); + + if (get(follow(rightNegation->ty))) + return intersectNegatedUnion(right, left); + + Relation r = relate(leftNegation->ty, rightNegation->ty); + + switch (r) + { + case Relation::Coincident: + // ~true & ~true + return left; + case Relation::Subset: + // ~true & ~boolean + return right; + case Relation::Superset: + // ~boolean & ~true + return left; + case Relation::Intersects: + case Relation::Disjoint: + default: + // ~boolean & ~string + return arena->addType(IntersectionType{{left, right}}); + } +} + +TypeId TypeSimplifier::intersectIntersectionWithType(TypeId left, TypeId right) +{ + const IntersectionType* leftIntersection = get(left); + LUAU_ASSERT(leftIntersection); + + bool changed = false; + std::set newParts; + + for (TypeId part : leftIntersection) + { + Relation r = relate(part, right); + switch (r) + { + case Relation::Disjoint: + return builtinTypes->neverType; + case Relation::Coincident: + newParts.insert(part); + continue; + case Relation::Subset: + newParts.insert(part); + continue; + case Relation::Superset: + newParts.insert(right); + changed = true; + continue; + default: + newParts.insert(part); + newParts.insert(right); + changed = true; + continue; + } + } + + // It is sometimes the case that an intersection operation will result in + // clipping a free type from the result. + // + // eg (number & 'a) & string --> never + // + // We want to only report the free types that are part of the result. + for (TypeId part : newParts) + { + if (isTypeVariable(part)) + blockedTypes.insert(part); + } + + if (!changed) + return left; + return intersectFromParts(std::move(newParts)); +} + +std::optional TypeSimplifier::basicIntersect(TypeId left, TypeId right) +{ + if (get(left)) + return right; + if (get(right)) + return left; + if (get(left)) + return left; + if (get(right)) + return right; + + if (auto pt = get(left); pt && pt->type == PrimitiveType::Boolean) + { + if (auto st = get(right); st && st->variant.get_if()) + return right; + if (auto nt = get(right)) + { + if (auto st = get(follow(nt->ty)); st && st->variant.get_if()) + { + if (st->variant == BooleanSingleton{true}) + return builtinTypes->falseType; + else + return builtinTypes->trueType; + } + } + } + else if (auto pt = get(right); pt && pt->type == PrimitiveType::Boolean) + { + if (auto st = get(left); st && st->variant.get_if()) + return left; + if (auto nt = get(left)) + { + if (auto st = get(follow(nt->ty)); st && st->variant.get_if()) + { + if (st->variant == BooleanSingleton{true}) + return builtinTypes->falseType; + else + return builtinTypes->trueType; + } + } + } + + if (const auto [lt, rt] = get2(left, right); lt && rt) + { + if (1 == lt->props.size()) + { + const auto [propName, leftProp] = *begin(lt->props); + + auto it = rt->props.find(propName); + if (it != rt->props.end()) + { + Relation r = relate(leftProp.type(), it->second.type()); + + switch (r) + { + case Relation::Disjoint: + return builtinTypes->neverType; + case Relation::Coincident: + return right; + default: + break; + } + } + } + else if (1 == rt->props.size()) + return basicIntersect(right, left); + } + + Relation relation = relate(left, right); + if (left == right || Relation::Coincident == relation) + return left; + + if (relation == Relation::Disjoint) + return builtinTypes->neverType; + else if (relation == Relation::Subset) + return left; + else if (relation == Relation::Superset) + return right; + + return std::nullopt; +} + +TypeId TypeSimplifier::intersect(TypeId left, TypeId right) +{ + RecursionLimiter rl(&recursionDepth, 15); + + left = simplify(left); + right = simplify(right); + + if (get(left)) + return right; + if (get(right)) + return left; + if (get(left)) + return left; + if (get(right)) + return right; + + if (isTypeVariable(left)) + { + blockedTypes.insert(left); + return arena->addType(IntersectionType{{left, right}}); + } + + if (isTypeVariable(right)) + { + blockedTypes.insert(right); + return arena->addType(IntersectionType{{left, right}}); + } + + if (auto ut = get(left)) + { + if (get(right)) + return intersectUnions(left, right); + else + return intersectUnionWithType(left, right); + } + else if (auto ut = get(right)) + return intersectUnionWithType(right, left); + + if (auto it = get(left)) + return intersectIntersectionWithType(left, right); + else if (auto it = get(right)) + return intersectIntersectionWithType(right, left); + + if (get(left)) + { + if (get(right)) + return intersectNegations(left, right); + else + return intersectTypeWithNegation(left, right); + } + else if (get(right)) + return intersectTypeWithNegation(right, left); + + std::optional res = basicIntersect(left, right); + if (res) + return *res; + else + return arena->addType(IntersectionType{{left, right}}); +} + +TypeId TypeSimplifier::union_(TypeId left, TypeId right) +{ + RecursionLimiter rl(&recursionDepth, 15); + + left = simplify(left); + right = simplify(right); + + if (auto leftUnion = get(left)) + { + bool changed = false; + std::set newParts; + for (TypeId part : leftUnion) + { + if (get(part)) + { + changed = true; + continue; + } + + Relation r = relate(part, right); + switch (r) + { + case Relation::Coincident: + case Relation::Superset: + return left; + default: + newParts.insert(part); + newParts.insert(right); + changed = true; + break; + } + } + + if (!changed) + return left; + if (1 == newParts.size()) + return *begin(newParts); + return arena->addType(UnionType{std::vector{begin(newParts), end(newParts)}}); + } + else if (get(right)) + return union_(right, left); + + Relation r = relate(left, right); + if (left == right || r == Relation::Coincident || r == Relation::Superset) + return left; + + if (r == Relation::Subset) + return right; + + if (auto as = get(left)) + { + if (auto abs = as->variant.get_if()) + { + if (auto bs = get(right)) + { + if (auto bbs = bs->variant.get_if()) + { + if (abs->value != bbs->value) + return builtinTypes->booleanType; + } + } + } + } + + return arena->addType(UnionType{{left, right}}); +} + +TypeId TypeSimplifier::simplify(TypeId ty) +{ + DenseHashSet seen{nullptr}; + return simplify(ty, seen); +} + +TypeId TypeSimplifier::simplify(TypeId ty, DenseHashSet& seen) +{ + RecursionLimiter limiter(&recursionDepth, 60); + + ty = follow(ty); + + if (seen.find(ty)) + return ty; + seen.insert(ty); + + if (auto nt = get(ty)) + { + TypeId negatedTy = follow(nt->ty); + if (get(negatedTy)) + return builtinTypes->neverType; + else if (get(negatedTy)) + return builtinTypes->anyType; + if (auto nnt = get(negatedTy)) + return simplify(nnt->ty, seen); + } + + // Promote {x: never} to never + if (auto tt = get(ty)) + { + if (1 == tt->props.size()) + { + TypeId propTy = simplify(begin(tt->props)->second.type(), seen); + if (get(propTy)) + return builtinTypes->neverType; + } + } + + return ty; +} + +SimplifyResult simplifyIntersection(NotNull builtinTypes, NotNull arena, TypeId left, TypeId right) +{ + TypeSimplifier s{builtinTypes, arena}; + + // fprintf(stderr, "Intersect %s and %s ...\n", toString(left).c_str(), toString(right).c_str()); + + TypeId res = s.intersect(left, right); + + // fprintf(stderr, "Intersect %s and %s -> %s\n", toString(left).c_str(), toString(right).c_str(), toString(res).c_str()); + + return SimplifyResult{res, std::move(s.blockedTypes)}; +} + +SimplifyResult simplifyUnion(NotNull builtinTypes, NotNull arena, TypeId left, TypeId right) +{ + TypeSimplifier s{builtinTypes, arena}; + + TypeId res = s.union_(left, right); + + // fprintf(stderr, "Union %s and %s -> %s\n", toString(a).c_str(), toString(b).c_str(), toString(res).c_str()); + + return SimplifyResult{res, std::move(s.blockedTypes)}; +} + +} // namespace Luau diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index f5b908e36..347380cd9 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -1639,6 +1639,11 @@ std::string toString(const Constraint& constraint, ToStringOptions& opts) } else if constexpr (std::is_same_v) return tos(c.resultPack) + " ~ unpack " + tos(c.sourcePack); + else if constexpr (std::is_same_v) + { + const char* op = c.mode == RefineConstraint::Union ? "union" : "intersect"; + return tos(c.resultType) + " ~ refine " + tos(c.type) + " " + op + " " + tos(c.discriminant); + } else if constexpr (std::is_same_v) return "reduce " + tos(c.ty); else if constexpr (std::is_same_v) @@ -1652,6 +1657,11 @@ std::string toString(const Constraint& constraint, ToStringOptions& opts) return visit(go, constraint.c); } +std::string toString(const Constraint& constraint) +{ + return toString(constraint, ToStringOptions{}); +} + std::string dump(const Constraint& c) { ToStringOptions opts; diff --git a/Analysis/src/TxnLog.cpp b/Analysis/src/TxnLog.cpp index 53dd3b445..5d38f28e7 100644 --- a/Analysis/src/TxnLog.cpp +++ b/Analysis/src/TxnLog.cpp @@ -82,6 +82,8 @@ void TxnLog::concat(TxnLog rhs) for (auto& [tp, rep] : rhs.typePackChanges) typePackChanges[tp] = std::move(rep); + + radioactive |= rhs.radioactive; } void TxnLog::concatAsIntersections(TxnLog rhs, NotNull arena) @@ -103,6 +105,8 @@ void TxnLog::concatAsIntersections(TxnLog rhs, NotNull arena) for (auto& [tp, rep] : rhs.typePackChanges) typePackChanges[tp] = std::move(rep); + + radioactive |= rhs.radioactive; } void TxnLog::concatAsUnion(TxnLog rhs, NotNull arena) @@ -199,10 +203,14 @@ void TxnLog::concatAsUnion(TxnLog rhs, NotNull arena) for (auto& [tp, rep] : rhs.typePackChanges) typePackChanges[tp] = std::move(rep); + + radioactive |= rhs.radioactive; } void TxnLog::commit() { + LUAU_ASSERT(!radioactive); + for (auto& [ty, rep] : typeVarChanges) { if (!rep->dead) @@ -234,6 +242,8 @@ TxnLog TxnLog::inverse() for (auto& [tp, _rep] : typePackChanges) inversed.typePackChanges[tp] = std::make_unique(*tp); + inversed.radioactive = radioactive; + return inversed; } @@ -293,7 +303,8 @@ void TxnLog::popSeen(TypeOrPackId lhs, TypeOrPackId rhs) PendingType* TxnLog::queue(TypeId ty) { - LUAU_ASSERT(!ty->persistent); + if (ty->persistent) + radioactive = true; // Explicitly don't look in ancestors. If we have discovered something new // about this type, we don't want to mutate the parent's state. @@ -309,7 +320,8 @@ PendingType* TxnLog::queue(TypeId ty) PendingTypePack* TxnLog::queue(TypePackId tp) { - LUAU_ASSERT(!tp->persistent); + if (tp->persistent) + radioactive = true; // Explicitly don't look in ancestors. If we have discovered something new // about this type, we don't want to mutate the parent's state. diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index a1f764a4d..40376e32a 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -13,7 +13,6 @@ #include "Luau/ToString.h" #include "Luau/TxnLog.h" #include "Luau/Type.h" -#include "Luau/TypeReduction.h" #include "Luau/TypeUtils.h" #include "Luau/Unifier.h" #include "Luau/TypeFamily.h" @@ -21,7 +20,6 @@ #include LUAU_FASTFLAG(DebugLuauMagicTypes) -LUAU_FASTFLAG(DebugLuauDontReduceTypes) namespace Luau { @@ -117,7 +115,7 @@ struct TypeChecker2 TypeId checkForFamilyInhabitance(TypeId instance, Location location) { TxnLog fake{}; - reportErrors(reduceFamilies(instance, location, NotNull{&testArena}, builtinTypes, &fake, true).errors); + reportErrors(reduceFamilies(instance, location, NotNull{&testArena}, builtinTypes, stack.back(), NotNull{&normalizer}, &fake, true).errors); return instance; } @@ -1002,7 +1000,9 @@ struct TypeChecker2 LUAU_ASSERT(ftv); reportErrors(tryUnify(stack.back(), call->location, ftv->retTypes, expectedRetType, CountMismatch::Context::Return, /* genericsOkay */ true)); - reportErrors(reduceFamilies(ftv->retTypes, call->location, NotNull{&testArena}, builtinTypes, &fake, true).errors); + reportErrors( + reduceFamilies(ftv->retTypes, call->location, NotNull{&testArena}, builtinTypes, stack.back(), NotNull{&normalizer}, &fake, true) + .errors); auto it = begin(expectedArgTypes); size_t i = 0; @@ -1020,7 +1020,7 @@ struct TypeChecker2 Location argLoc = argLocs.at(i >= argLocs.size() ? argLocs.size() - 1 : i); reportErrors(tryUnify(stack.back(), argLoc, expectedArg, arg, CountMismatch::Context::Arg, /* genericsOkay */ true)); - reportErrors(reduceFamilies(arg, argLoc, NotNull{&testArena}, builtinTypes, &fake, true).errors); + reportErrors(reduceFamilies(arg, argLoc, NotNull{&testArena}, builtinTypes, stack.back(), NotNull{&normalizer}, &fake, true).errors); ++it; ++i; @@ -1032,12 +1032,11 @@ struct TypeChecker2 { TypePackId remainingArgs = testArena.addTypePack(TypePack{std::move(slice), std::nullopt}); reportErrors(tryUnify(stack.back(), argLocs.back(), *tail, remainingArgs, CountMismatch::Context::Arg, /* genericsOkay */ true)); - reportErrors(reduceFamilies(remainingArgs, argLocs.back(), NotNull{&testArena}, builtinTypes, &fake, true).errors); + reportErrors(reduceFamilies( + remainingArgs, argLocs.back(), NotNull{&testArena}, builtinTypes, stack.back(), NotNull{&normalizer}, &fake, true) + .errors); } } - - // We do not need to do an arity test because this overload was - // selected based on its arity already matching. } else { @@ -1160,25 +1159,26 @@ struct TypeChecker2 return ty; } - void visitExprName(AstExpr* expr, Location location, const std::string& propName, ValueContext context) + void visitExprName(AstExpr* expr, Location location, const std::string& propName, ValueContext context, TypeId astIndexExprTy) { visit(expr, ValueContext::RValue); - TypeId leftType = stripFromNilAndReport(lookupType(expr), location); - checkIndexTypeFromType(leftType, propName, location, context); + checkIndexTypeFromType(leftType, propName, location, context, astIndexExprTy); } void visit(AstExprIndexName* indexName, ValueContext context) { - visitExprName(indexName->expr, indexName->location, indexName->index.value, context); + // If we're indexing like _.foo - foo could either be a prop or a string. + visitExprName(indexName->expr, indexName->location, indexName->index.value, context, builtinTypes->stringType); } void visit(AstExprIndexExpr* indexExpr, ValueContext context) { if (auto str = indexExpr->index->as()) { + TypeId astIndexExprType = lookupType(indexExpr->index); const std::string stringValue(str->value.data, str->value.size); - visitExprName(indexExpr->expr, indexExpr->location, stringValue, context); + visitExprName(indexExpr->expr, indexExpr->location, stringValue, context, astIndexExprType); return; } @@ -1198,6 +1198,8 @@ struct TypeChecker2 else reportError(CannotExtendTable{exprType, CannotExtendTable::Indexer, "indexer??"}, indexExpr->location); } + else if (auto cls = get(exprType); cls && cls->indexer) + reportErrors(tryUnify(scope, indexExpr->index->location, indexType, cls->indexer->indexType)); else if (get(exprType) && isOptional(exprType)) reportError(OptionalValueAccess{exprType}, indexExpr->location); } @@ -1209,32 +1211,52 @@ struct TypeChecker2 visitGenerics(fn->generics, fn->genericPacks); TypeId inferredFnTy = lookupType(fn); - const FunctionType* inferredFtv = get(inferredFnTy); - LUAU_ASSERT(inferredFtv); - - // There is no way to write an annotation for the self argument, so we - // cannot do anything to check it. - auto argIt = begin(inferredFtv->argTypes); - if (fn->self) - ++argIt; - for (const auto& arg : fn->args) + const NormalizedType* normalizedFnTy = normalizer.normalize(inferredFnTy); + if (!normalizedFnTy) + { + reportError(CodeTooComplex{}, fn->location); + } + else if (get(normalizedFnTy->errors)) + { + // Nothing + } + else if (!normalizedFnTy->isFunction()) + { + ice->ice("Internal error: Lambda has non-function type " + toString(inferredFnTy), fn->location); + } + else { - if (argIt == end(inferredFtv->argTypes)) - break; + if (1 != normalizedFnTy->functions.parts.size()) + ice->ice("Unexpected: Lambda has unexpected type " + toString(inferredFnTy), fn->location); - if (arg->annotation) + const FunctionType* inferredFtv = get(normalizedFnTy->functions.parts.front()); + LUAU_ASSERT(inferredFtv); + + // There is no way to write an annotation for the self argument, so we + // cannot do anything to check it. + auto argIt = begin(inferredFtv->argTypes); + if (fn->self) + ++argIt; + + for (const auto& arg : fn->args) { - TypeId inferredArgTy = *argIt; - TypeId annotatedArgTy = lookupAnnotation(arg->annotation); + if (argIt == end(inferredFtv->argTypes)) + break; - if (!isSubtype(inferredArgTy, annotatedArgTy, stack.back())) + if (arg->annotation) { - reportError(TypeMismatch{inferredArgTy, annotatedArgTy}, arg->location); + TypeId inferredArgTy = *argIt; + TypeId annotatedArgTy = lookupAnnotation(arg->annotation); + + if (!isSubtype(inferredArgTy, annotatedArgTy, stack.back())) + { + reportError(TypeMismatch{inferredArgTy, annotatedArgTy}, arg->location); + } } - } - ++argIt; + ++argIt; + } } visit(fn->body); @@ -1345,6 +1367,10 @@ struct TypeChecker2 TypeId leftType = lookupType(expr->left); TypeId rightType = lookupType(expr->right); + TypeId expectedResult = lookupType(expr); + + if (get(expectedResult)) + return expectedResult; if (expr->op == AstExprBinary::Op::Or) { @@ -1432,7 +1458,11 @@ struct TypeChecker2 TypeId instantiatedMm = module->astOverloadResolvedTypes[key]; if (!instantiatedMm) - reportError(CodeTooComplex{}, expr->location); + { + // reportError(CodeTooComplex{}, expr->location); + // was handled by a type family + return expectedResult; + } else if (const FunctionType* ftv = get(follow(instantiatedMm))) { @@ -1715,7 +1745,7 @@ struct TypeChecker2 { // No further validation is necessary in this case. The main logic for // _luau_print is contained in lookupAnnotation. - if (FFlag::DebugLuauMagicTypes && ty->name == "_luau_print" && ty->parameters.size > 0) + if (FFlag::DebugLuauMagicTypes && ty->name == "_luau_print") return; for (const AstTypeOrPack& param : ty->parameters) @@ -1764,6 +1794,7 @@ struct TypeChecker2 if (packsProvided != 0) { reportError(GenericError{"Type parameters must come before type pack parameters"}, ty->location); + continue; } if (typesProvided < typesRequired) @@ -1792,7 +1823,11 @@ struct TypeChecker2 if (extraTypes != 0 && packsProvided == 0) { - packsProvided += 1; + // Extra types are only collected into a pack if a pack is expected + if (packsRequired != 0) + packsProvided += 1; + else + typesProvided += extraTypes; } for (size_t i = typesProvided; i < typesRequired; ++i) @@ -1943,69 +1978,6 @@ struct TypeChecker2 } } - void reduceTypes() - { - if (FFlag::DebugLuauDontReduceTypes) - return; - - for (auto [_, scope] : module->scopes) - { - for (auto& [_, b] : scope->bindings) - { - if (auto reduced = module->reduction->reduce(b.typeId)) - b.typeId = *reduced; - } - - if (auto reduced = module->reduction->reduce(scope->returnType)) - scope->returnType = *reduced; - - if (scope->varargPack) - { - if (auto reduced = module->reduction->reduce(*scope->varargPack)) - scope->varargPack = *reduced; - } - - auto reduceMap = [this](auto& map) { - for (auto& [_, tf] : map) - { - if (auto reduced = module->reduction->reduce(tf)) - tf = *reduced; - } - }; - - reduceMap(scope->exportedTypeBindings); - reduceMap(scope->privateTypeBindings); - reduceMap(scope->privateTypePackBindings); - for (auto& [_, space] : scope->importedTypeBindings) - reduceMap(space); - } - - auto reduceOrError = [this](auto& map) { - for (auto [ast, t] : map) - { - if (!t) - continue; // Reminder: this implies that the recursion limit was exceeded. - else if (auto reduced = module->reduction->reduce(t)) - map[ast] = *reduced; - else - reportError(NormalizationTooComplex{}, ast->location); - } - }; - - module->astOriginalResolvedTypes = module->astResolvedTypes; - - // Both [`Module::returnType`] and [`Module::exportedTypeBindings`] are empty here, and - // is populated by [`Module::clonePublicInterface`] in the future, so by that point these - // two aforementioned fields will only contain types that are irreducible. - reduceOrError(module->astTypes); - reduceOrError(module->astTypePacks); - reduceOrError(module->astExpectedTypes); - reduceOrError(module->astOriginalCallTypes); - reduceOrError(module->astOverloadResolvedTypes); - reduceOrError(module->astResolvedTypes); - reduceOrError(module->astResolvedTypePacks); - } - template bool isSubtype(TID subTy, TID superTy, NotNull scope, bool genericsOkay = false) { @@ -2034,6 +2006,9 @@ struct TypeChecker2 void reportError(TypeErrorData data, const Location& location) { + if (auto utk = get_if(&data)) + diagnoseMissingTableKey(utk, data); + module->errors.emplace_back(location, module->name, std::move(data)); if (logger) @@ -2052,7 +2027,7 @@ struct TypeChecker2 } // If the provided type does not have the named property, report an error. - void checkIndexTypeFromType(TypeId tableTy, const std::string& prop, const Location& location, ValueContext context) + void checkIndexTypeFromType(TypeId tableTy, const std::string& prop, const Location& location, ValueContext context, TypeId astIndexExprType) { const NormalizedType* norm = normalizer.normalize(tableTy); if (!norm) @@ -2069,7 +2044,7 @@ struct TypeChecker2 return; std::unordered_set seen; - bool found = hasIndexTypeFromType(ty, prop, location, seen); + bool found = hasIndexTypeFromType(ty, prop, location, seen, astIndexExprType); foundOneProp |= found; if (!found) typesMissingTheProp.push_back(ty); @@ -2129,7 +2104,7 @@ struct TypeChecker2 } } - bool hasIndexTypeFromType(TypeId ty, const std::string& prop, const Location& location, std::unordered_set& seen) + bool hasIndexTypeFromType(TypeId ty, const std::string& prop, const Location& location, std::unordered_set& seen, TypeId astIndexExprType) { // If we have already encountered this type, we must assume that some // other codepath will do the right thing and signal false if the @@ -2153,31 +2128,83 @@ struct TypeChecker2 if (findTablePropertyRespectingMeta(builtinTypes, module->errors, ty, prop, location)) return true; - else if (tt->indexer && isPrim(tt->indexer->indexType, PrimitiveType::String)) - return true; + if (tt->indexer) + { + TypeId indexType = follow(tt->indexer->indexType); + if (isPrim(indexType, PrimitiveType::String)) + return true; + // If the indexer looks like { [any] : _} - the prop lookup should be allowed! + else if (get(indexType) || get(indexType)) + return true; + } - else - return false; + return false; } else if (const ClassType* cls = get(ty)) - return bool(lookupClassProp(cls, prop)); + { + // If the property doesn't exist on the class, we consult the indexer + // We need to check if the type of the index expression foo (x[foo]) + // is compatible with the indexer's indexType + // Construct the intersection and test inhabitedness! + if (auto property = lookupClassProp(cls, prop)) + return true; + if (cls->indexer) + { + TypeId inhabitatedTestType = testArena.addType(IntersectionType{{cls->indexer->indexType, astIndexExprType}}); + return normalizer.isInhabited(inhabitatedTestType); + } + return false; + } else if (const UnionType* utv = get(ty)) return std::all_of(begin(utv), end(utv), [&](TypeId part) { - return hasIndexTypeFromType(part, prop, location, seen); + return hasIndexTypeFromType(part, prop, location, seen, astIndexExprType); }); else if (const IntersectionType* itv = get(ty)) return std::any_of(begin(itv), end(itv), [&](TypeId part) { - return hasIndexTypeFromType(part, prop, location, seen); + return hasIndexTypeFromType(part, prop, location, seen, astIndexExprType); }); else return false; } + + void diagnoseMissingTableKey(UnknownProperty* utk, TypeErrorData& data) const + { + std::string_view sv(utk->key); + std::set candidates; + + auto accumulate = [&](const TableType::Props& props) { + for (const auto& [name, ty] : props) + { + if (sv != name && equalsLower(sv, name)) + candidates.insert(name); + } + }; + + if (auto ttv = getTableType(utk->table)) + accumulate(ttv->props); + else if (auto ctv = get(follow(utk->table))) + { + while (ctv) + { + accumulate(ctv->props); + + if (!ctv->parent) + break; + + ctv = get(*ctv->parent); + LUAU_ASSERT(ctv); + } + } + + if (!candidates.empty()) + data = TypeErrorData(UnknownPropButFoundLikeProp{utk->table, utk->key, candidates}); + } }; void check(NotNull builtinTypes, NotNull unifierState, DcrLogger* logger, const SourceModule& sourceModule, Module* module) { TypeChecker2 typeChecker{builtinTypes, unifierState, logger, &sourceModule, module}; - typeChecker.reduceTypes(); + typeChecker.visit(sourceModule.root); unfreeze(module->interfaceTypes); diff --git a/Analysis/src/TypeFamily.cpp b/Analysis/src/TypeFamily.cpp index 1941573b6..e5a06c0a4 100644 --- a/Analysis/src/TypeFamily.cpp +++ b/Analysis/src/TypeFamily.cpp @@ -7,6 +7,10 @@ #include "Luau/TxnLog.h" #include "Luau/Substitution.h" #include "Luau/ToString.h" +#include "Luau/TypeUtils.h" +#include "Luau/Unifier.h" +#include "Luau/Instantiation.h" +#include "Luau/Normalize.h" LUAU_DYNAMIC_FASTINTVARIABLE(LuauTypeFamilyGraphReductionMaximumSteps, 1'000'000); @@ -30,6 +34,11 @@ struct InstanceCollector : TypeOnceVisitor return true; } + bool visit(TypeId ty, const ClassType&) override + { + return false; + } + bool visit(TypePackId tp, const TypeFamilyInstanceTypePack&) override { // TypeOnceVisitor performs a depth-first traversal in the absence of @@ -52,20 +61,24 @@ struct FamilyReducer Location location; NotNull arena; NotNull builtins; - TxnLog* log = nullptr; - NotNull reducerLog; + TxnLog* parentLog = nullptr; + TxnLog log; bool force = false; + NotNull scope; + NotNull normalizer; FamilyReducer(std::deque queuedTys, std::deque queuedTps, Location location, NotNull arena, - NotNull builtins, TxnLog* log = nullptr, bool force = false) + NotNull builtins, NotNull scope, NotNull normalizer, TxnLog* parentLog = nullptr, bool force = false) : queuedTys(std::move(queuedTys)) , queuedTps(std::move(queuedTps)) , location(location) , arena(arena) , builtins(builtins) - , log(log) - , reducerLog(NotNull{log ? log : TxnLog::empty()}) + , parentLog(parentLog) + , log(parentLog) , force(force) + , scope(scope) + , normalizer(normalizer) { } @@ -78,16 +91,16 @@ struct FamilyReducer SkipTestResult testForSkippability(TypeId ty) { - ty = reducerLog->follow(ty); + ty = log.follow(ty); - if (reducerLog->is(ty)) + if (log.is(ty)) { if (!irreducible.contains(ty)) return SkipTestResult::Defer; else return SkipTestResult::Irreducible; } - else if (reducerLog->is(ty)) + else if (log.is(ty)) { return SkipTestResult::Irreducible; } @@ -97,16 +110,16 @@ struct FamilyReducer SkipTestResult testForSkippability(TypePackId ty) { - ty = reducerLog->follow(ty); + ty = log.follow(ty); - if (reducerLog->is(ty)) + if (log.is(ty)) { if (!irreducible.contains(ty)) return SkipTestResult::Defer; else return SkipTestResult::Irreducible; } - else if (reducerLog->is(ty)) + else if (log.is(ty)) { return SkipTestResult::Irreducible; } @@ -117,8 +130,8 @@ struct FamilyReducer template void replace(T subject, T replacement) { - if (log) - log->replace(subject, Unifiable::Bound{replacement}); + if (parentLog) + parentLog->replace(subject, Unifiable::Bound{replacement}); else asMutable(subject)->ty.template emplace>(replacement); @@ -208,37 +221,38 @@ struct FamilyReducer void stepType() { - TypeId subject = reducerLog->follow(queuedTys.front()); + TypeId subject = log.follow(queuedTys.front()); queuedTys.pop_front(); if (irreducible.contains(subject)) return; - if (const TypeFamilyInstanceType* tfit = reducerLog->get(subject)) + if (const TypeFamilyInstanceType* tfit = log.get(subject)) { if (!testParameters(subject, tfit)) return; - TypeFamilyReductionResult result = tfit->family->reducer(tfit->typeArguments, tfit->packArguments, arena, builtins, reducerLog); + TypeFamilyReductionResult result = + tfit->family->reducer(tfit->typeArguments, tfit->packArguments, arena, builtins, NotNull{&log}, scope, normalizer); handleFamilyReduction(subject, result); } } void stepPack() { - TypePackId subject = reducerLog->follow(queuedTps.front()); + TypePackId subject = log.follow(queuedTps.front()); queuedTps.pop_front(); if (irreducible.contains(subject)) return; - if (const TypeFamilyInstanceTypePack* tfit = reducerLog->get(subject)) + if (const TypeFamilyInstanceTypePack* tfit = log.get(subject)) { if (!testParameters(subject, tfit)) return; TypeFamilyReductionResult result = - tfit->family->reducer(tfit->typeArguments, tfit->packArguments, arena, builtins, reducerLog); + tfit->family->reducer(tfit->typeArguments, tfit->packArguments, arena, builtins, NotNull{&log}, scope, normalizer); handleFamilyReduction(subject, result); } } @@ -253,9 +267,9 @@ struct FamilyReducer }; static FamilyGraphReductionResult reduceFamiliesInternal(std::deque queuedTys, std::deque queuedTps, Location location, - NotNull arena, NotNull builtins, TxnLog* log, bool force) + NotNull arena, NotNull builtins, NotNull scope, NotNull normalizer, TxnLog* log, bool force) { - FamilyReducer reducer{std::move(queuedTys), std::move(queuedTps), location, arena, builtins, log, force}; + FamilyReducer reducer{std::move(queuedTys), std::move(queuedTps), location, arena, builtins, scope, normalizer, log, force}; int iterationCount = 0; while (!reducer.done()) @@ -273,8 +287,8 @@ static FamilyGraphReductionResult reduceFamiliesInternal(std::deque queu return std::move(reducer.result); } -FamilyGraphReductionResult reduceFamilies( - TypeId entrypoint, Location location, NotNull arena, NotNull builtins, TxnLog* log, bool force) +FamilyGraphReductionResult reduceFamilies(TypeId entrypoint, Location location, NotNull arena, NotNull builtins, + NotNull scope, NotNull normalizer, TxnLog* log, bool force) { InstanceCollector collector; @@ -287,11 +301,11 @@ FamilyGraphReductionResult reduceFamilies( return FamilyGraphReductionResult{}; } - return reduceFamiliesInternal(std::move(collector.tys), std::move(collector.tps), location, arena, builtins, log, force); + return reduceFamiliesInternal(std::move(collector.tys), std::move(collector.tps), location, arena, builtins, scope, normalizer, log, force); } -FamilyGraphReductionResult reduceFamilies( - TypePackId entrypoint, Location location, NotNull arena, NotNull builtins, TxnLog* log, bool force) +FamilyGraphReductionResult reduceFamilies(TypePackId entrypoint, Location location, NotNull arena, NotNull builtins, + NotNull scope, NotNull normalizer, TxnLog* log, bool force) { InstanceCollector collector; @@ -304,7 +318,113 @@ FamilyGraphReductionResult reduceFamilies( return FamilyGraphReductionResult{}; } - return reduceFamiliesInternal(std::move(collector.tys), std::move(collector.tps), location, arena, builtins, log, force); + return reduceFamiliesInternal(std::move(collector.tys), std::move(collector.tps), location, arena, builtins, scope, normalizer, log, force); +} + +bool isPending(TypeId ty, NotNull log) +{ + return log->is(ty) || log->is(ty) || log->is(ty) || log->is(ty); +} + +TypeFamilyReductionResult addFamilyFn(std::vector typeParams, std::vector packParams, NotNull arena, + NotNull builtins, NotNull log, NotNull scope, NotNull normalizer) +{ + if (typeParams.size() != 2 || !packParams.empty()) + { + // TODO: ICE? + LUAU_ASSERT(false); + return {std::nullopt, true, {}, {}}; + } + + TypeId lhsTy = log->follow(typeParams.at(0)); + TypeId rhsTy = log->follow(typeParams.at(1)); + + if (isNumber(lhsTy) && isNumber(rhsTy)) + { + return {builtins->numberType, false, {}, {}}; + } + else if (log->is(lhsTy) || log->is(rhsTy)) + { + return {builtins->anyType, false, {}, {}}; + } + else if (log->is(lhsTy) || log->is(rhsTy)) + { + return {builtins->errorRecoveryType(), false, {}, {}}; + } + else if (log->is(lhsTy) || log->is(rhsTy)) + { + return {builtins->neverType, false, {}, {}}; + } + else if (isPending(lhsTy, log)) + { + return {std::nullopt, false, {lhsTy}, {}}; + } + else if (isPending(rhsTy, log)) + { + return {std::nullopt, false, {rhsTy}, {}}; + } + + // findMetatableEntry demands the ability to emit errors, so we must give it + // the necessary state to do that, even if we intend to just eat the errors. + ErrorVec dummy; + + std::optional addMm = findMetatableEntry(builtins, dummy, lhsTy, "__add", Location{}); + bool reversed = false; + if (!addMm) + { + addMm = findMetatableEntry(builtins, dummy, rhsTy, "__add", Location{}); + reversed = true; + } + + if (!addMm) + return {std::nullopt, true, {}, {}}; + + if (isPending(log->follow(*addMm), log)) + return {std::nullopt, false, {log->follow(*addMm)}, {}}; + + const FunctionType* mmFtv = log->get(log->follow(*addMm)); + if (!mmFtv) + return {std::nullopt, true, {}, {}}; + + Instantiation instantiation{log.get(), arena.get(), TypeLevel{}, scope.get()}; + if (std::optional instantiatedAddMm = instantiation.substitute(log->follow(*addMm))) + { + if (const FunctionType* instantiatedMmFtv = get(*instantiatedAddMm)) + { + std::vector inferredArgs; + if (!reversed) + inferredArgs = {lhsTy, rhsTy}; + else + inferredArgs = {rhsTy, lhsTy}; + + TypePackId inferredArgPack = arena->addTypePack(std::move(inferredArgs)); + Unifier u{normalizer, Mode::Strict, scope, Location{}, Variance::Covariant, log.get()}; + u.tryUnify(inferredArgPack, instantiatedMmFtv->argTypes); + + if (std::optional ret = first(instantiatedMmFtv->retTypes); ret && u.errors.empty()) + { + return {u.log.follow(*ret), false, {}, {}}; + } + else + { + return {std::nullopt, true, {}, {}}; + } + } + else + { + return {builtins->errorRecoveryType(), false, {}, {}}; + } + } + else + { + // TODO: Not the nicest logic here. + return {std::nullopt, true, {}, {}}; + } +} + +BuiltinTypeFamilies::BuiltinTypeFamilies() + : addFamily{"Add", addFamilyFn} +{ } } // namespace Luau diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 94c64ee25..7e6803990 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -18,7 +18,6 @@ #include "Luau/ToString.h" #include "Luau/Type.h" #include "Luau/TypePack.h" -#include "Luau/TypeReduction.h" #include "Luau/TypeUtils.h" #include "Luau/VisitType.h" @@ -269,7 +268,6 @@ ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mo currentModule.reset(new Module); currentModule->name = module.name; currentModule->humanReadableName = module.humanReadableName; - currentModule->reduction = std::make_unique(NotNull{¤tModule->internalTypes}, builtinTypes, NotNull{iceHandler}); currentModule->type = module.type; currentModule->allocator = module.allocator; currentModule->names = module.names; @@ -4842,7 +4840,7 @@ TypeId TypeChecker::instantiate(const ScopePtr& scope, TypeId ty, Location locat ty = follow(ty); const FunctionType* ftv = get(ty); - if (ftv && ftv->hasNoGenerics) + if (ftv && ftv->hasNoFreeOrGenericTypes) return ty; Instantiation instantiation{log, ¤tModule->internalTypes, scope->level, /*scope*/ nullptr}; diff --git a/Analysis/src/TypeReduction.cpp b/Analysis/src/TypeReduction.cpp deleted file mode 100644 index b81cca7ba..000000000 --- a/Analysis/src/TypeReduction.cpp +++ /dev/null @@ -1,1200 +0,0 @@ -// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "Luau/TypeReduction.h" - -#include "Luau/Common.h" -#include "Luau/Error.h" -#include "Luau/RecursionCounter.h" -#include "Luau/VisitType.h" - -#include -#include - -LUAU_FASTINTVARIABLE(LuauTypeReductionCartesianProductLimit, 100'000) -LUAU_FASTINTVARIABLE(LuauTypeReductionRecursionLimit, 300) -LUAU_FASTFLAGVARIABLE(DebugLuauDontReduceTypes, false) - -namespace Luau -{ - -namespace detail -{ -bool TypeReductionMemoization::isIrreducible(TypeId ty) -{ - ty = follow(ty); - - // Only does shallow check, the TypeReducer itself already does deep traversal. - if (auto edge = types.find(ty); edge && edge->irreducible) - return true; - else if (get(ty) || get(ty) || get(ty)) - return false; - else if (auto tt = get(ty); tt && (tt->state == TableState::Free || tt->state == TableState::Unsealed)) - return false; - else - return true; -} - -bool TypeReductionMemoization::isIrreducible(TypePackId tp) -{ - tp = follow(tp); - - // Only does shallow check, the TypeReducer itself already does deep traversal. - if (auto edge = typePacks.find(tp); edge && edge->irreducible) - return true; - else if (get(tp) || get(tp)) - return false; - else if (auto vtp = get(tp)) - return isIrreducible(vtp->ty); - else - return true; -} - -TypeId TypeReductionMemoization::memoize(TypeId ty, TypeId reducedTy) -{ - ty = follow(ty); - reducedTy = follow(reducedTy); - - // The irreducibility of this [`reducedTy`] depends on whether its contents are themselves irreducible. - // We don't need to recurse much further than that, because we already record the irreducibility from - // the bottom up. - bool irreducible = isIrreducible(reducedTy); - if (auto it = get(reducedTy)) - { - for (TypeId part : it) - irreducible &= isIrreducible(part); - } - else if (auto ut = get(reducedTy)) - { - for (TypeId option : ut) - irreducible &= isIrreducible(option); - } - else if (auto tt = get(reducedTy)) - { - for (auto& [k, p] : tt->props) - irreducible &= isIrreducible(p.type()); - - if (tt->indexer) - { - irreducible &= isIrreducible(tt->indexer->indexType); - irreducible &= isIrreducible(tt->indexer->indexResultType); - } - - for (auto ta : tt->instantiatedTypeParams) - irreducible &= isIrreducible(ta); - - for (auto tpa : tt->instantiatedTypePackParams) - irreducible &= isIrreducible(tpa); - } - else if (auto mt = get(reducedTy)) - { - irreducible &= isIrreducible(mt->table); - irreducible &= isIrreducible(mt->metatable); - } - else if (auto ft = get(reducedTy)) - { - irreducible &= isIrreducible(ft->argTypes); - irreducible &= isIrreducible(ft->retTypes); - } - else if (auto nt = get(reducedTy)) - irreducible &= isIrreducible(nt->ty); - - types[ty] = {reducedTy, irreducible}; - types[reducedTy] = {reducedTy, irreducible}; - return reducedTy; -} - -TypePackId TypeReductionMemoization::memoize(TypePackId tp, TypePackId reducedTp) -{ - tp = follow(tp); - reducedTp = follow(reducedTp); - - bool irreducible = isIrreducible(reducedTp); - TypePackIterator it = begin(tp); - while (it != end(tp)) - { - irreducible &= isIrreducible(*it); - ++it; - } - - if (it.tail()) - irreducible &= isIrreducible(*it.tail()); - - typePacks[tp] = {reducedTp, irreducible}; - typePacks[reducedTp] = {reducedTp, irreducible}; - return reducedTp; -} - -std::optional> TypeReductionMemoization::memoizedof(TypeId ty) const -{ - auto fetchContext = [this](TypeId ty) -> std::optional> { - if (auto edge = types.find(ty)) - return *edge; - else - return std::nullopt; - }; - - TypeId currentTy = ty; - std::optional> lastEdge; - while (auto edge = fetchContext(currentTy)) - { - lastEdge = edge; - if (edge->irreducible) - return edge; - else if (edge->type == currentTy) - return edge; - else - currentTy = edge->type; - } - - return lastEdge; -} - -std::optional> TypeReductionMemoization::memoizedof(TypePackId tp) const -{ - auto fetchContext = [this](TypePackId tp) -> std::optional> { - if (auto edge = typePacks.find(tp)) - return *edge; - else - return std::nullopt; - }; - - TypePackId currentTp = tp; - std::optional> lastEdge; - while (auto edge = fetchContext(currentTp)) - { - lastEdge = edge; - if (edge->irreducible) - return edge; - else if (edge->type == currentTp) - return edge; - else - currentTp = edge->type; - } - - return lastEdge; -} -} // namespace detail - -namespace -{ - -template -std::pair get2(const Thing& one, const Thing& two) -{ - const A* a = get(one); - const B* b = get(two); - return a && b ? std::make_pair(a, b) : std::make_pair(nullptr, nullptr); -} - -struct TypeReducer -{ - NotNull arena; - NotNull builtinTypes; - NotNull handle; - NotNull memoization; - DenseHashSet* cyclics; - - int depth = 0; - - TypeId reduce(TypeId ty); - TypePackId reduce(TypePackId tp); - - std::optional intersectionType(TypeId left, TypeId right); - std::optional unionType(TypeId left, TypeId right); - TypeId tableType(TypeId ty); - TypeId functionType(TypeId ty); - TypeId negationType(TypeId ty); - - using BinaryFold = std::optional (TypeReducer::*)(TypeId, TypeId); - using UnaryFold = TypeId (TypeReducer::*)(TypeId); - - template - LUAU_NOINLINE std::pair copy(TypeId ty, const T* t) - { - ty = follow(ty); - - if (auto edge = memoization->memoizedof(ty)) - return {edge->type, getMutable(edge->type)}; - - // We specifically do not want to use [`detail::TypeReductionMemoization::memoize`] because that will - // potentially consider these copiedTy to be reducible, but we need this to resolve cyclic references - // without attempting to recursively reduce it, causing copies of copies of copies of... - TypeId copiedTy = arena->addType(*t); - memoization->types[ty] = {copiedTy, true}; - memoization->types[copiedTy] = {copiedTy, true}; - return {copiedTy, getMutable(copiedTy)}; - } - - template - void foldl_impl(Iter it, Iter endIt, BinaryFold f, std::vector* result, bool* didReduce) - { - RecursionLimiter rl{&depth, FInt::LuauTypeReductionRecursionLimit}; - - while (it != endIt) - { - TypeId right = reduce(*it); - *didReduce |= right != follow(*it); - - // We're hitting a case where the `currentTy` returned a type that's the same as `T`. - // e.g. `(string?) & ~(false | nil)` became `(string?) & (~false & ~nil)` but the current iterator we're consuming doesn't know this. - // We will need to recurse and traverse that first. - if (auto t = get(right)) - { - foldl_impl(begin(t), end(t), f, result, didReduce); - ++it; - continue; - } - - bool replaced = false; - auto resultIt = result->begin(); - while (resultIt != result->end()) - { - TypeId left = *resultIt; - if (left == right) - { - replaced = true; - ++resultIt; - continue; - } - - std::optional reduced = (this->*f)(left, right); - if (reduced) - { - *resultIt = *reduced; - ++resultIt; - replaced = true; - } - else - { - ++resultIt; - continue; - } - } - - if (!replaced) - result->push_back(right); - - *didReduce |= replaced; - ++it; - } - } - - template - TypeId flatten(std::vector&& types) - { - if (types.size() == 1) - return types[0]; - else - return arena->addType(T{std::move(types)}); - } - - template - TypeId foldl(Iter it, Iter endIt, std::optional ty, BinaryFold f) - { - std::vector result; - bool didReduce = false; - foldl_impl(it, endIt, f, &result, &didReduce); - - // If we've done any reduction, then we'll need to reduce it again, e.g. - // `"a" | "b" | string` is reduced into `string | string`, which is then reduced into `string`. - if (!didReduce) - return ty ? *ty : flatten(std::move(result)); - else - return reduce(flatten(std::move(result))); - } - - template - TypeId apply(BinaryFold f, TypeId left, TypeId right) - { - std::vector types{left, right}; - return foldl(begin(types), end(types), std::nullopt, f); - } - - template - TypeId distribute(TypeIterator it, TypeIterator endIt, BinaryFold f, TypeId ty) - { - std::vector result; - while (it != endIt) - { - result.push_back(apply(f, *it, ty)); - ++it; - } - return flatten(std::move(result)); - } -}; - -TypeId TypeReducer::reduce(TypeId ty) -{ - ty = follow(ty); - - if (auto edge = memoization->memoizedof(ty)) - { - if (edge->irreducible) - return edge->type; - else - ty = follow(edge->type); - } - else if (cyclics->contains(ty)) - return ty; - - RecursionLimiter rl{&depth, FInt::LuauTypeReductionRecursionLimit}; - - TypeId result = nullptr; - if (auto i = get(ty)) - result = foldl(begin(i), end(i), ty, &TypeReducer::intersectionType); - else if (auto u = get(ty)) - result = foldl(begin(u), end(u), ty, &TypeReducer::unionType); - else if (get(ty) || get(ty)) - result = tableType(ty); - else if (get(ty)) - result = functionType(ty); - else if (get(ty)) - result = negationType(ty); - else - result = ty; - - return memoization->memoize(ty, result); -} - -TypePackId TypeReducer::reduce(TypePackId tp) -{ - tp = follow(tp); - - if (auto edge = memoization->memoizedof(tp)) - { - if (edge->irreducible) - return edge->type; - else - tp = edge->type; - } - else if (cyclics->contains(tp)) - return tp; - - RecursionLimiter rl{&depth, FInt::LuauTypeReductionRecursionLimit}; - - bool didReduce = false; - TypePackIterator it = begin(tp); - - std::vector head; - while (it != end(tp)) - { - TypeId reducedTy = reduce(*it); - head.push_back(reducedTy); - didReduce |= follow(*it) != follow(reducedTy); - ++it; - } - - std::optional tail = it.tail(); - if (tail) - { - if (auto vtp = get(follow(*it.tail()))) - { - TypeId reducedTy = reduce(vtp->ty); - if (follow(vtp->ty) != follow(reducedTy)) - { - tail = arena->addTypePack(VariadicTypePack{reducedTy, vtp->hidden}); - didReduce = true; - } - } - } - - if (!didReduce) - return memoization->memoize(tp, tp); - else if (head.empty() && tail) - return memoization->memoize(tp, *tail); - else - return memoization->memoize(tp, arena->addTypePack(TypePack{std::move(head), tail})); -} - -std::optional TypeReducer::intersectionType(TypeId left, TypeId right) -{ - if (get(left)) - return left; // never & T ~ never - else if (get(right)) - return right; // T & never ~ never - else if (get(left)) - return right; // unknown & T ~ T - else if (get(right)) - return left; // T & unknown ~ T - else if (get(left)) - return right; // any & T ~ T - else if (get(right)) - return left; // T & any ~ T - else if (get(left)) - return std::nullopt; // 'a & T ~ 'a & T - else if (get(right)) - return std::nullopt; // T & 'a ~ T & 'a - else if (get(left)) - return std::nullopt; // G & T ~ G & T - else if (get(right)) - return std::nullopt; // T & G ~ T & G - else if (get(left)) - return std::nullopt; // error & T ~ error & T - else if (get(right)) - return std::nullopt; // T & error ~ T & error - else if (get(left)) - return std::nullopt; // *blocked* & T ~ *blocked* & T - else if (get(right)) - return std::nullopt; // T & *blocked* ~ T & *blocked* - else if (get(left)) - return std::nullopt; // *pending* & T ~ *pending* & T - else if (get(right)) - return std::nullopt; // T & *pending* ~ T & *pending* - else if (auto [utl, utr] = get2(left, right); utl && utr) - { - std::vector parts; - for (TypeId optionl : utl) - { - for (TypeId optionr : utr) - parts.push_back(apply(&TypeReducer::intersectionType, optionl, optionr)); - } - - return reduce(flatten(std::move(parts))); // (T | U) & (A | B) ~ (T & A) | (T & B) | (U & A) | (U & B) - } - else if (auto ut = get(left)) - return reduce(distribute(begin(ut), end(ut), &TypeReducer::intersectionType, right)); // (A | B) & T ~ (A & T) | (B & T) - else if (get(right)) - return intersectionType(right, left); // T & (A | B) ~ (A | B) & T - else if (auto [p1, p2] = get2(left, right); p1 && p2) - { - if (p1->type == p2->type) - return left; // P1 & P2 ~ P1 iff P1 == P2 - else - return builtinTypes->neverType; // P1 & P2 ~ never iff P1 != P2 - } - else if (auto [p, s] = get2(left, right); p && s) - { - if (p->type == PrimitiveType::String && get(s)) - return right; // string & "A" ~ "A" - else if (p->type == PrimitiveType::Boolean && get(s)) - return right; // boolean & true ~ true - else - return builtinTypes->neverType; // string & true ~ never - } - else if (auto [s, p] = get2(left, right); s && p) - return intersectionType(right, left); // S & P ~ P & S - else if (auto [p, f] = get2(left, right); p && f) - { - if (p->type == PrimitiveType::Function) - return right; // function & () -> () ~ () -> () - else - return builtinTypes->neverType; // string & () -> () ~ never - } - else if (auto [f, p] = get2(left, right); f && p) - return intersectionType(right, left); // () -> () & P ~ P & () -> () - else if (auto [p, t] = get2(left, right); p && t) - { - if (p->type == PrimitiveType::Table) - return right; // table & {} ~ {} - else - return builtinTypes->neverType; // string & {} ~ never - } - else if (auto [p, t] = get2(left, right); p && t) - { - if (p->type == PrimitiveType::Table) - return right; // table & {} ~ {} - else - return builtinTypes->neverType; // string & {} ~ never - } - else if (auto [t, p] = get2(left, right); t && p) - return intersectionType(right, left); // {} & P ~ P & {} - else if (auto [t, p] = get2(left, right); t && p) - return intersectionType(right, left); // M & P ~ P & M - else if (auto [s1, s2] = get2(left, right); s1 && s2) - { - if (*s1 == *s2) - return left; // "a" & "a" ~ "a" - else - return builtinTypes->neverType; // "a" & "b" ~ never - } - else if (auto [c1, c2] = get2(left, right); c1 && c2) - { - if (isSubclass(c1, c2)) - return left; // Derived & Base ~ Derived - else if (isSubclass(c2, c1)) - return right; // Base & Derived ~ Derived - else - return builtinTypes->neverType; // Base & Unrelated ~ never - } - else if (auto [f1, f2] = get2(left, right); f1 && f2) - return std::nullopt; // TODO - else if (auto [t1, t2] = get2(left, right); t1 && t2) - { - if (t1->state == TableState::Free || t2->state == TableState::Free) - return std::nullopt; // '{ x: T } & { x: U } ~ '{ x: T } & { x: U } - else if (t1->state == TableState::Generic || t2->state == TableState::Generic) - return std::nullopt; // '{ x: T } & { x: U } ~ '{ x: T } & { x: U } - - if (cyclics->contains(left)) - return std::nullopt; // (t1 where t1 = { p: t1 }) & {} ~ t1 & {} - else if (cyclics->contains(right)) - return std::nullopt; // {} & (t1 where t1 = { p: t1 }) ~ {} & t1 - - TypeId resultTy = arena->addType(TableType{}); - TableType* table = getMutable(resultTy); - table->state = t1->state == TableState::Sealed || t2->state == TableState::Sealed ? TableState::Sealed : TableState::Unsealed; - - for (const auto& [name, prop] : t1->props) - { - // TODO: when t1 has properties, we should also intersect that with the indexer in t2 if it exists, - // even if we have the corresponding property in the other one. - if (auto other = t2->props.find(name); other != t2->props.end()) - { - TypeId propTy = apply(&TypeReducer::intersectionType, prop.type(), other->second.type()); - if (get(propTy)) - return builtinTypes->neverType; // { p : string } & { p : number } ~ { p : string & number } ~ { p : never } ~ never - else - table->props[name] = {propTy}; // { p : string } & { p : ~"a" } ~ { p : string & ~"a" } - } - else - table->props[name] = prop; // { p : string } & {} ~ { p : string } - } - - for (const auto& [name, prop] : t2->props) - { - // TODO: And vice versa, t2 properties against t1 indexer if it exists, - // even if we have the corresponding property in the other one. - if (!t1->props.count(name)) - table->props[name] = {reduce(prop.type())}; // {} & { p : string & string } ~ { p : string } - } - - if (t1->indexer && t2->indexer) - { - TypeId keyTy = apply(&TypeReducer::intersectionType, t1->indexer->indexType, t2->indexer->indexType); - if (get(keyTy)) - return std::nullopt; // { [string]: _ } & { [number]: _ } ~ { [string]: _ } & { [number]: _ } - - TypeId valueTy = apply(&TypeReducer::intersectionType, t1->indexer->indexResultType, t2->indexer->indexResultType); - table->indexer = TableIndexer{keyTy, valueTy}; // { [string]: number } & { [string]: string } ~ { [string]: never } - } - else if (t1->indexer) - { - TypeId keyTy = reduce(t1->indexer->indexType); - TypeId valueTy = reduce(t1->indexer->indexResultType); - table->indexer = TableIndexer{keyTy, valueTy}; // { [number]: boolean } & { p : string } ~ { p : string, [number]: boolean } - } - else if (t2->indexer) - { - TypeId keyTy = reduce(t2->indexer->indexType); - TypeId valueTy = reduce(t2->indexer->indexResultType); - table->indexer = TableIndexer{keyTy, valueTy}; // { p : string } & { [number]: boolean } ~ { p : string, [number]: boolean } - } - - return resultTy; - } - else if (auto [mt, tt] = get2(left, right); mt && tt) - return std::nullopt; // TODO - else if (auto [tt, mt] = get2(left, right); tt && mt) - return intersectionType(right, left); // T & M ~ M & T - else if (auto [m1, m2] = get2(left, right); m1 && m2) - return std::nullopt; // TODO - else if (auto [nl, nr] = get2(left, right); nl && nr) - { - // These should've been reduced already. - TypeId nlTy = follow(nl->ty); - TypeId nrTy = follow(nr->ty); - LUAU_ASSERT(!get(nlTy) && !get(nrTy)); - LUAU_ASSERT(!get(nlTy) && !get(nrTy)); - LUAU_ASSERT(!get(nlTy) && !get(nrTy)); - LUAU_ASSERT(!get(nlTy) && !get(nrTy)); - LUAU_ASSERT(!get(nlTy) && !get(nrTy)); - - if (auto [npl, npr] = get2(nlTy, nrTy); npl && npr) - { - if (npl->type == npr->type) - return left; // ~P1 & ~P2 ~ ~P1 iff P1 == P2 - else - return std::nullopt; // ~P1 & ~P2 ~ ~P1 & ~P2 iff P1 != P2 - } - else if (auto [nsl, nsr] = get2(nlTy, nrTy); nsl && nsr) - { - if (*nsl == *nsr) - return left; // ~"A" & ~"A" ~ ~"A" - else - return std::nullopt; // ~"A" & ~"B" ~ ~"A" & ~"B" - } - else if (auto [ns, np] = get2(nlTy, nrTy); ns && np) - { - if (get(ns) && np->type == PrimitiveType::String) - return right; // ~"A" & ~string ~ ~string - else if (get(ns) && np->type == PrimitiveType::Boolean) - return right; // ~false & ~boolean ~ ~boolean - else - return std::nullopt; // ~"A" | ~P ~ ~"A" & ~P - } - else if (auto [np, ns] = get2(nlTy, nrTy); np && ns) - return intersectionType(right, left); // ~P & ~S ~ ~S & ~P - else - return std::nullopt; // ~T & ~U ~ ~T & ~U - } - else if (auto nl = get(left)) - { - // These should've been reduced already. - TypeId nlTy = follow(nl->ty); - LUAU_ASSERT(!get(nlTy)); - LUAU_ASSERT(!get(nlTy)); - LUAU_ASSERT(!get(nlTy)); - LUAU_ASSERT(!get(nlTy)); - LUAU_ASSERT(!get(nlTy)); - - if (auto [np, p] = get2(nlTy, right); np && p) - { - if (np->type == p->type) - return builtinTypes->neverType; // ~P1 & P2 ~ never iff P1 == P2 - else - return right; // ~P1 & P2 ~ P2 iff P1 != P2 - } - else if (auto [ns, s] = get2(nlTy, right); ns && s) - { - if (*ns == *s) - return builtinTypes->neverType; // ~"A" & "A" ~ never - else - return right; // ~"A" & "B" ~ "B" - } - else if (auto [ns, p] = get2(nlTy, right); ns && p) - { - if (get(ns) && p->type == PrimitiveType::String) - return std::nullopt; // ~"A" & string ~ ~"A" & string - else if (get(ns) && p->type == PrimitiveType::Boolean) - { - // Because booleans contain a fixed amount of values (2), we can do something cooler with this one. - const BooleanSingleton* b = get(ns); - return arena->addType(SingletonType{BooleanSingleton{!b->value}}); // ~false & boolean ~ true - } - else - return right; // ~"A" & number ~ number - } - else if (auto [np, s] = get2(nlTy, right); np && s) - { - if (np->type == PrimitiveType::String && get(s)) - return builtinTypes->neverType; // ~string & "A" ~ never - else if (np->type == PrimitiveType::Boolean && get(s)) - return builtinTypes->neverType; // ~boolean & true ~ never - else - return right; // ~P & "A" ~ "A" - } - else if (auto [np, f] = get2(nlTy, right); np && f) - { - if (np->type == PrimitiveType::Function) - return builtinTypes->neverType; // ~function & () -> () ~ never - else - return right; // ~string & () -> () ~ () -> () - } - else if (auto [nc, c] = get2(nlTy, right); nc && c) - { - if (isSubclass(c, nc)) - return builtinTypes->neverType; // ~Base & Derived ~ never - else if (isSubclass(nc, c)) - return std::nullopt; // ~Derived & Base ~ ~Derived & Base - else - return right; // ~Base & Unrelated ~ Unrelated - } - else if (auto [np, t] = get2(nlTy, right); np && t) - { - if (np->type == PrimitiveType::Table) - return builtinTypes->neverType; // ~table & {} ~ never - else - return right; // ~string & {} ~ {} - } - else if (auto [np, t] = get2(nlTy, right); np && t) - { - if (np->type == PrimitiveType::Table) - return builtinTypes->neverType; // ~table & {} ~ never - else - return right; // ~string & {} ~ {} - } - else - return right; // ~T & U ~ U - } - else if (get(right)) - return intersectionType(right, left); // T & ~U ~ ~U & T - else - return builtinTypes->neverType; // for all T and U except the ones handled above, T & U ~ never -} - -std::optional TypeReducer::unionType(TypeId left, TypeId right) -{ - LUAU_ASSERT(!get(left)); - LUAU_ASSERT(!get(right)); - - if (get(left)) - return right; // never | T ~ T - else if (get(right)) - return left; // T | never ~ T - else if (get(left)) - return left; // unknown | T ~ unknown - else if (get(right)) - return right; // T | unknown ~ unknown - else if (get(left)) - return left; // any | T ~ any - else if (get(right)) - return right; // T | any ~ any - else if (get(left)) - return std::nullopt; // error | T ~ error | T - else if (get(right)) - return std::nullopt; // T | error ~ T | error - else if (auto [p1, p2] = get2(left, right); p1 && p2) - { - if (p1->type == p2->type) - return left; // P1 | P2 ~ P1 iff P1 == P2 - else - return std::nullopt; // P1 | P2 ~ P1 | P2 iff P1 != P2 - } - else if (auto [p, s] = get2(left, right); p && s) - { - if (p->type == PrimitiveType::String && get(s)) - return left; // string | "A" ~ string - else if (p->type == PrimitiveType::Boolean && get(s)) - return left; // boolean | true ~ boolean - else - return std::nullopt; // string | true ~ string | true - } - else if (auto [s, p] = get2(left, right); s && p) - return unionType(right, left); // S | P ~ P | S - else if (auto [p, f] = get2(left, right); p && f) - { - if (p->type == PrimitiveType::Function) - return left; // function | () -> () ~ function - else - return std::nullopt; // P | () -> () ~ P | () -> () - } - else if (auto [f, p] = get2(left, right); f && p) - return unionType(right, left); // () -> () | P ~ P | () -> () - else if (auto [p, t] = get2(left, right); p && t) - { - if (p->type == PrimitiveType::Table) - return left; // table | {} ~ table - else - return std::nullopt; // P | {} ~ P | {} - } - else if (auto [p, t] = get2(left, right); p && t) - { - if (p->type == PrimitiveType::Table) - return left; // table | {} ~ table - else - return std::nullopt; // P | {} ~ P | {} - } - else if (auto [t, p] = get2(left, right); t && p) - return unionType(right, left); // {} | P ~ P | {} - else if (auto [t, p] = get2(left, right); t && p) - return unionType(right, left); // M | P ~ P | M - else if (auto [s1, s2] = get2(left, right); s1 && s2) - { - if (*s1 == *s2) - return left; // "a" | "a" ~ "a" - else - return std::nullopt; // "a" | "b" ~ "a" | "b" - } - else if (auto [c1, c2] = get2(left, right); c1 && c2) - { - if (isSubclass(c1, c2)) - return right; // Derived | Base ~ Base - else if (isSubclass(c2, c1)) - return left; // Base | Derived ~ Base - else - return std::nullopt; // Base | Unrelated ~ Base | Unrelated - } - else if (auto [nt, it] = get2(left, right); nt && it) - return reduce(distribute(begin(it), end(it), &TypeReducer::unionType, left)); // ~T | (A & B) ~ (~T | A) & (~T | B) - else if (auto [it, nt] = get2(left, right); it && nt) - return unionType(right, left); // (A & B) | ~T ~ ~T | (A & B) - else if (auto it = get(left)) - { - bool didReduce = false; - std::vector parts; - for (TypeId part : it) - { - auto nt = get(part); - if (!nt) - { - parts.push_back(part); - continue; - } - - auto redex = unionType(part, right); - if (redex && get(*redex)) - { - didReduce = true; - continue; - } - - parts.push_back(part); - } - - if (didReduce) - return flatten(std::move(parts)); // (T & ~nil) | nil ~ T - else - return std::nullopt; // (T & ~nil) | U - } - else if (get(right)) - return unionType(right, left); // A | (T & U) ~ (T & U) | A - else if (auto [nl, nr] = get2(left, right); nl && nr) - { - // These should've been reduced already. - TypeId nlTy = follow(nl->ty); - TypeId nrTy = follow(nr->ty); - LUAU_ASSERT(!get(nlTy) && !get(nrTy)); - LUAU_ASSERT(!get(nlTy) && !get(nrTy)); - LUAU_ASSERT(!get(nlTy) && !get(nrTy)); - LUAU_ASSERT(!get(nlTy) && !get(nrTy)); - LUAU_ASSERT(!get(nlTy) && !get(nrTy)); - - if (auto [npl, npr] = get2(nlTy, nrTy); npl && npr) - { - if (npl->type == npr->type) - return left; // ~P1 | ~P2 ~ ~P1 iff P1 == P2 - else - return builtinTypes->unknownType; // ~P1 | ~P2 ~ ~P1 iff P1 != P2 - } - else if (auto [nsl, nsr] = get2(nlTy, nrTy); nsl && nsr) - { - if (*nsl == *nsr) - return left; // ~"A" | ~"A" ~ ~"A" - else - return builtinTypes->unknownType; // ~"A" | ~"B" ~ unknown - } - else if (auto [ns, np] = get2(nlTy, nrTy); ns && np) - { - if (get(ns) && np->type == PrimitiveType::String) - return left; // ~"A" | ~string ~ ~"A" - else if (get(ns) && np->type == PrimitiveType::Boolean) - return left; // ~false | ~boolean ~ ~false - else - return builtinTypes->unknownType; // ~"A" | ~P ~ unknown - } - else if (auto [np, ns] = get2(nlTy, nrTy); np && ns) - return unionType(right, left); // ~P | ~S ~ ~S | ~P - else - return std::nullopt; // TODO! - } - else if (auto nl = get(left)) - { - // These should've been reduced already. - TypeId nlTy = follow(nl->ty); - LUAU_ASSERT(!get(nlTy)); - LUAU_ASSERT(!get(nlTy)); - LUAU_ASSERT(!get(nlTy)); - LUAU_ASSERT(!get(nlTy)); - LUAU_ASSERT(!get(nlTy)); - - if (auto [np, p] = get2(nlTy, right); np && p) - { - if (np->type == p->type) - return builtinTypes->unknownType; // ~P1 | P2 ~ unknown iff P1 == P2 - else - return left; // ~P1 | P2 ~ ~P1 iff P1 != P2 - } - else if (auto [ns, s] = get2(nlTy, right); ns && s) - { - if (*ns == *s) - return builtinTypes->unknownType; // ~"A" | "A" ~ unknown - else - return left; // ~"A" | "B" ~ ~"A" - } - else if (auto [ns, p] = get2(nlTy, right); ns && p) - { - if (get(ns) && p->type == PrimitiveType::String) - return builtinTypes->unknownType; // ~"A" | string ~ unknown - else if (get(ns) && p->type == PrimitiveType::Boolean) - return builtinTypes->unknownType; // ~false | boolean ~ unknown - else - return left; // ~"A" | T ~ ~"A" - } - else if (auto [np, s] = get2(nlTy, right); np && s) - { - if (np->type == PrimitiveType::String && get(s)) - return std::nullopt; // ~string | "A" ~ ~string | "A" - else if (np->type == PrimitiveType::Boolean && get(s)) - { - const BooleanSingleton* b = get(s); - return negationType(arena->addType(SingletonType{BooleanSingleton{!b->value}})); // ~boolean | false ~ ~true - } - else - return left; // ~P | "A" ~ ~P - } - else if (auto [nc, c] = get2(nlTy, right); nc && c) - { - if (isSubclass(c, nc)) - return std::nullopt; // ~Base | Derived ~ ~Base | Derived - else if (isSubclass(nc, c)) - return builtinTypes->unknownType; // ~Derived | Base ~ unknown - else - return left; // ~Base | Unrelated ~ ~Base - } - else if (auto [np, t] = get2(nlTy, right); np && t) - { - if (np->type == PrimitiveType::Table) - return std::nullopt; // ~table | {} ~ ~table | {} - else - return right; // ~P | {} ~ ~P | {} - } - else if (auto [np, t] = get2(nlTy, right); np && t) - { - if (np->type == PrimitiveType::Table) - return std::nullopt; // ~table | {} ~ ~table | {} - else - return right; // ~P | M ~ ~P | M - } - else - return std::nullopt; // TODO - } - else if (get(right)) - return unionType(right, left); // T | ~U ~ ~U | T - else - return std::nullopt; // for all T and U except the ones handled above, T | U ~ T | U -} - -TypeId TypeReducer::tableType(TypeId ty) -{ - if (auto mt = get(ty)) - { - auto [copiedTy, copied] = copy(ty, mt); - copied->table = reduce(mt->table); - copied->metatable = reduce(mt->metatable); - return copiedTy; - } - else if (auto tt = get(ty)) - { - // Because of `typeof()`, we need to preserve pointer identity of free/unsealed tables so that - // all mutations that occurs on this will be applied without leaking the implementation details. - // As a result, we'll just use the type instead of cloning it if it's free/unsealed. - // - // We could choose to do in-place reductions here, but to be on the safer side, I propose that we do not. - if (tt->state == TableState::Free || tt->state == TableState::Unsealed) - return ty; - - auto [copiedTy, copied] = copy(ty, tt); - - for (auto& [name, prop] : copied->props) - { - TypeId propTy = reduce(prop.type()); - if (get(propTy)) - return builtinTypes->neverType; - else - prop.setType(propTy); - } - - if (copied->indexer) - { - TypeId keyTy = reduce(copied->indexer->indexType); - TypeId valueTy = reduce(copied->indexer->indexResultType); - copied->indexer = TableIndexer{keyTy, valueTy}; - } - - for (TypeId& ty : copied->instantiatedTypeParams) - ty = reduce(ty); - - for (TypePackId& tp : copied->instantiatedTypePackParams) - tp = reduce(tp); - - return copiedTy; - } - else - handle->ice("TypeReducer::tableType expects a TableType or MetatableType"); -} - -TypeId TypeReducer::functionType(TypeId ty) -{ - const FunctionType* f = get(ty); - if (!f) - handle->ice("TypeReducer::functionType expects a FunctionType"); - - // TODO: once we have bounded quantification, we need to be able to reduce the generic bounds. - auto [copiedTy, copied] = copy(ty, f); - copied->argTypes = reduce(f->argTypes); - copied->retTypes = reduce(f->retTypes); - return copiedTy; -} - -TypeId TypeReducer::negationType(TypeId ty) -{ - const NegationType* n = get(ty); - if (!n) - return arena->addType(NegationType{ty}); - - TypeId negatedTy = follow(n->ty); - - if (auto nn = get(negatedTy)) - return nn->ty; // ~~T ~ T - else if (get(negatedTy)) - return builtinTypes->unknownType; // ~never ~ unknown - else if (get(negatedTy)) - return builtinTypes->neverType; // ~unknown ~ never - else if (get(negatedTy)) - return builtinTypes->anyType; // ~any ~ any - else if (auto ni = get(negatedTy)) - { - std::vector options; - for (TypeId part : ni) - options.push_back(negationType(arena->addType(NegationType{part}))); - return reduce(flatten(std::move(options))); // ~(T & U) ~ (~T | ~U) - } - else if (auto nu = get(negatedTy)) - { - std::vector parts; - for (TypeId option : nu) - parts.push_back(negationType(arena->addType(NegationType{option}))); - return reduce(flatten(std::move(parts))); // ~(T | U) ~ (~T & ~U) - } - else - return ty; // for all T except the ones handled above, ~T ~ ~T -} - -struct MarkCycles : TypeVisitor -{ - DenseHashSet cyclics{nullptr}; - - void cycle(TypeId ty) override - { - cyclics.insert(follow(ty)); - } - - void cycle(TypePackId tp) override - { - cyclics.insert(follow(tp)); - } - - bool visit(TypeId ty) override - { - return !cyclics.find(follow(ty)); - } - - bool visit(TypePackId tp) override - { - return !cyclics.find(follow(tp)); - } -}; -} // namespace - -TypeReduction::TypeReduction( - NotNull arena, NotNull builtinTypes, NotNull handle, const TypeReductionOptions& opts) - : arena(arena) - , builtinTypes(builtinTypes) - , handle(handle) - , options(opts) -{ -} - -std::optional TypeReduction::reduce(TypeId ty) -{ - ty = follow(ty); - - if (FFlag::DebugLuauDontReduceTypes) - return ty; - else if (!options.allowTypeReductionsFromOtherArenas && ty->owningArena != arena) - return ty; - else if (auto edge = memoization.memoizedof(ty)) - { - if (edge->irreducible) - return edge->type; - else - ty = edge->type; - } - else if (hasExceededCartesianProductLimit(ty)) - return std::nullopt; - - try - { - MarkCycles finder; - finder.traverse(ty); - - TypeReducer reducer{arena, builtinTypes, handle, NotNull{&memoization}, &finder.cyclics}; - return reducer.reduce(ty); - } - catch (const RecursionLimitException&) - { - return std::nullopt; - } -} - -std::optional TypeReduction::reduce(TypePackId tp) -{ - tp = follow(tp); - - if (FFlag::DebugLuauDontReduceTypes) - return tp; - else if (!options.allowTypeReductionsFromOtherArenas && tp->owningArena != arena) - return tp; - else if (auto edge = memoization.memoizedof(tp)) - { - if (edge->irreducible) - return edge->type; - else - tp = edge->type; - } - else if (hasExceededCartesianProductLimit(tp)) - return std::nullopt; - - try - { - MarkCycles finder; - finder.traverse(tp); - - TypeReducer reducer{arena, builtinTypes, handle, NotNull{&memoization}, &finder.cyclics}; - return reducer.reduce(tp); - } - catch (const RecursionLimitException&) - { - return std::nullopt; - } -} - -std::optional TypeReduction::reduce(const TypeFun& fun) -{ - if (FFlag::DebugLuauDontReduceTypes) - return fun; - - // TODO: once we have bounded quantification, we need to be able to reduce the generic bounds. - if (auto reducedTy = reduce(fun.type)) - return TypeFun{fun.typeParams, fun.typePackParams, *reducedTy}; - - return std::nullopt; -} - -size_t TypeReduction::cartesianProductSize(TypeId ty) const -{ - ty = follow(ty); - - auto it = get(follow(ty)); - if (!it) - return 1; - - return std::accumulate(begin(it), end(it), size_t(1), [](size_t acc, TypeId ty) { - if (auto ut = get(ty)) - return acc * std::distance(begin(ut), end(ut)); - else if (get(ty)) - return acc * 0; - else - return acc * 1; - }); -} - -bool TypeReduction::hasExceededCartesianProductLimit(TypeId ty) const -{ - return cartesianProductSize(ty) >= size_t(FInt::LuauTypeReductionCartesianProductLimit); -} - -bool TypeReduction::hasExceededCartesianProductLimit(TypePackId tp) const -{ - TypePackIterator it = begin(tp); - - while (it != end(tp)) - { - if (hasExceededCartesianProductLimit(*it)) - return true; - - ++it; - } - - if (auto tail = it.tail()) - { - if (auto vtp = get(follow(*tail))) - { - if (hasExceededCartesianProductLimit(vtp->ty)) - return true; - } - } - - return false; -} - -} // namespace Luau diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 56be40471..76428cf97 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -447,13 +447,13 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool // "double-report" errors in some cases, like when trying to unify // identical type family instantiations like Add with // Add. - reduceFamilies(superTy, location, NotNull(types), builtinTypes, &log); + reduceFamilies(superTy, location, NotNull(types), builtinTypes, scope, normalizer, &log); superTy = log.follow(superTy); } if (log.get(subTy)) { - reduceFamilies(subTy, location, NotNull(types), builtinTypes, &log); + reduceFamilies(subTy, location, NotNull(types), builtinTypes, scope, normalizer, &log); subTy = log.follow(subTy); } diff --git a/CLI/Reduce.cpp b/CLI/Reduce.cpp index b7c780128..ffe670b8a 100644 --- a/CLI/Reduce.cpp +++ b/CLI/Reduce.cpp @@ -56,10 +56,9 @@ struct Reducer ParseResult parseResult; AstStatBlock* root; - std::string tempScriptName; + std::string scriptName; - std::string appName; - std::vector appArgs; + std::string command; std::string_view searchText; Reducer() @@ -99,10 +98,10 @@ struct Reducer } while (true); } - FILE* f = fopen(tempScriptName.c_str(), "w"); + FILE* f = fopen(scriptName.c_str(), "w"); if (!f) { - printf("Unable to open temp script to %s\n", tempScriptName.c_str()); + printf("Unable to open temp script to %s\n", scriptName.c_str()); exit(2); } @@ -113,7 +112,7 @@ struct Reducer if (written != source.size()) { printf("??? %zu %zu\n", written, source.size()); - printf("Unable to write to temp script %s\n", tempScriptName.c_str()); + printf("Unable to write to temp script %s\n", scriptName.c_str()); exit(3); } @@ -142,9 +141,15 @@ struct Reducer { writeTempScript(); - std::string command = appName + " " + escape(tempScriptName); - for (const auto& arg : appArgs) - command += " " + escape(arg); + std::string cmd = command; + while (true) + { + auto pos = cmd.find("{}"); + if (std::string::npos == pos) + break; + + cmd = cmd.substr(0, pos) + escape(scriptName) + cmd.substr(pos + 2); + } #if VERBOSE >= 1 printf("running %s\n", command.c_str()); @@ -424,30 +429,20 @@ struct Reducer } } - void run(const std::string scriptName, const std::string appName, const std::vector& appArgs, std::string_view source, + void run(const std::string scriptName, const std::string command, std::string_view source, std::string_view searchText) { - tempScriptName = scriptName; - if (tempScriptName.substr(tempScriptName.size() - 4) == ".lua") - { - tempScriptName.erase(tempScriptName.size() - 4); - tempScriptName += "-reduced.lua"; - } - else - { - this->tempScriptName = scriptName + "-reduced"; - } + this->scriptName = scriptName; #if 0 // Handy debugging trick: VS Code will update its view of the file in realtime as it is edited. - std::string wheee = "code " + tempScriptName; + std::string wheee = "code " + scriptName; system(wheee.c_str()); #endif - printf("Temp script: %s\n", tempScriptName.c_str()); + printf("Script: %s\n", scriptName.c_str()); - this->appName = appName; - this->appArgs = appArgs; + this->command = command; this->searchText = searchText; parseResult = Parser::parse(source.data(), source.size(), nameTable, allocator, parseOptions); @@ -470,13 +465,14 @@ struct Reducer writeTempScript(/* minify */ true); - printf("Done! Check %s\n", tempScriptName.c_str()); + printf("Done! Check %s\n", scriptName.c_str()); } }; [[noreturn]] void help(const std::vector& args) { - printf("Syntax: %s script application \"search text\" [arguments]\n", args[0].data()); + printf("Syntax: %s script command \"search text\"\n", args[0].data()); + printf(" Within command, use {} as a stand-in for the script being reduced\n"); exit(1); } @@ -484,7 +480,7 @@ int main(int argc, char** argv) { const std::vector args(argv, argv + argc); - if (args.size() < 4) + if (args.size() != 4) help(args); for (size_t i = 1; i < args.size(); ++i) @@ -496,7 +492,6 @@ int main(int argc, char** argv) const std::string scriptName = argv[1]; const std::string appName = argv[2]; const std::string searchText = argv[3]; - const std::vector appArgs(begin(args) + 4, end(args)); std::optional source = readFile(scriptName); @@ -507,5 +502,5 @@ int main(int argc, char** argv) } Reducer reducer; - reducer.run(scriptName, appName, appArgs, *source, searchText); + reducer.run(scriptName, appName, *source, searchText); } diff --git a/CMakeLists.txt b/CMakeLists.txt index b3b1573ac..b6e8b5913 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -143,6 +143,9 @@ endif() if(LUAU_NATIVE) target_compile_definitions(Luau.VM PUBLIC LUA_CUSTOM_EXECUTION=1) + if(LUAU_EXTERN_C) + target_compile_definitions(Luau.CodeGen PUBLIC LUACODEGEN_API=extern\"C\") + endif() endif() if (CMAKE_CXX_COMPILER_ID STREQUAL "MSVC" AND MSVC_VERSION GREATER_EQUAL 1924) diff --git a/CodeGen/include/Luau/AssemblyBuilderA64.h b/CodeGen/include/Luau/AssemblyBuilderA64.h index e7733cd2c..09acfb4a9 100644 --- a/CodeGen/include/Luau/AssemblyBuilderA64.h +++ b/CodeGen/include/Luau/AssemblyBuilderA64.h @@ -80,6 +80,12 @@ class AssemblyBuilderA64 void asr(RegisterA64 dst, RegisterA64 src1, uint8_t src2); void ror(RegisterA64 dst, RegisterA64 src1, uint8_t src2); + // Bitfields + void ubfiz(RegisterA64 dst, RegisterA64 src, uint8_t f, uint8_t w); + void ubfx(RegisterA64 dst, RegisterA64 src, uint8_t f, uint8_t w); + void sbfiz(RegisterA64 dst, RegisterA64 src, uint8_t f, uint8_t w); + void sbfx(RegisterA64 dst, RegisterA64 src, uint8_t f, uint8_t w); + // Load // Note: paired loads are currently omitted for simplicity void ldr(RegisterA64 dst, AddressA64 src); @@ -212,7 +218,7 @@ class AssemblyBuilderA64 void placeFCMP(const char* name, RegisterA64 src1, RegisterA64 src2, uint8_t op, uint8_t opc); void placeFMOV(const char* name, RegisterA64 dst, double src, uint32_t op); void placeBM(const char* name, RegisterA64 dst, RegisterA64 src1, uint32_t src2, uint8_t op); - void placeBFM(const char* name, RegisterA64 dst, RegisterA64 src1, uint8_t src2, uint8_t op, int immr, int imms); + void placeBFM(const char* name, RegisterA64 dst, RegisterA64 src1, int src2, uint8_t op, int immr, int imms); void place(uint32_t word); diff --git a/CodeGen/include/luacodegen.h b/CodeGen/include/luacodegen.h new file mode 100644 index 000000000..654fc2c90 --- /dev/null +++ b/CodeGen/include/luacodegen.h @@ -0,0 +1,18 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +// Can be used to reconfigure visibility/exports for public APIs +#ifndef LUACODEGEN_API +#define LUACODEGEN_API extern +#endif + +struct lua_State; + +// returns 1 if Luau code generator is supported, 0 otherwise +LUACODEGEN_API int luau_codegen_supported(void); + +// create an instance of Luau code generator. you must check that this feature is supported using luau_codegen_supported(). +LUACODEGEN_API void luau_codegen_create(lua_State* L); + +// build target function and all inner functions +LUACODEGEN_API void luau_codegen_compile(lua_State* L, int idx); diff --git a/CodeGen/src/AssemblyBuilderA64.cpp b/CodeGen/src/AssemblyBuilderA64.cpp index 23b5b9f36..000dc85fd 100644 --- a/CodeGen/src/AssemblyBuilderA64.cpp +++ b/CodeGen/src/AssemblyBuilderA64.cpp @@ -280,6 +280,42 @@ void AssemblyBuilderA64::ror(RegisterA64 dst, RegisterA64 src1, uint8_t src2) placeBFM("ror", dst, src1, src2, 0b00'100111, src1.index, src2); } +void AssemblyBuilderA64::ubfiz(RegisterA64 dst, RegisterA64 src, uint8_t f, uint8_t w) +{ + int size = dst.kind == KindA64::x ? 64 : 32; + LUAU_ASSERT(w > 0 && f + w <= size); + + // f * 100 + w is only used for disassembly printout; in the future we might replace it with two separate fields for readability + placeBFM("ubfiz", dst, src, f * 100 + w, 0b10'100110, (-f) & (size - 1), w - 1); +} + +void AssemblyBuilderA64::ubfx(RegisterA64 dst, RegisterA64 src, uint8_t f, uint8_t w) +{ + int size = dst.kind == KindA64::x ? 64 : 32; + LUAU_ASSERT(w > 0 && f + w <= size); + + // f * 100 + w is only used for disassembly printout; in the future we might replace it with two separate fields for readability + placeBFM("ubfx", dst, src, f * 100 + w, 0b10'100110, f, f + w - 1); +} + +void AssemblyBuilderA64::sbfiz(RegisterA64 dst, RegisterA64 src, uint8_t f, uint8_t w) +{ + int size = dst.kind == KindA64::x ? 64 : 32; + LUAU_ASSERT(w > 0 && f + w <= size); + + // f * 100 + w is only used for disassembly printout; in the future we might replace it with two separate fields for readability + placeBFM("sbfiz", dst, src, f * 100 + w, 0b00'100110, (-f) & (size - 1), w - 1); +} + +void AssemblyBuilderA64::sbfx(RegisterA64 dst, RegisterA64 src, uint8_t f, uint8_t w) +{ + int size = dst.kind == KindA64::x ? 64 : 32; + LUAU_ASSERT(w > 0 && f + w <= size); + + // f * 100 + w is only used for disassembly printout; in the future we might replace it with two separate fields for readability + placeBFM("sbfx", dst, src, f * 100 + w, 0b00'100110, f, f + w - 1); +} + void AssemblyBuilderA64::ldr(RegisterA64 dst, AddressA64 src) { LUAU_ASSERT(dst.kind == KindA64::x || dst.kind == KindA64::w || dst.kind == KindA64::s || dst.kind == KindA64::d || dst.kind == KindA64::q); @@ -1010,7 +1046,7 @@ void AssemblyBuilderA64::placeBM(const char* name, RegisterA64 dst, RegisterA64 commit(); } -void AssemblyBuilderA64::placeBFM(const char* name, RegisterA64 dst, RegisterA64 src1, uint8_t src2, uint8_t op, int immr, int imms) +void AssemblyBuilderA64::placeBFM(const char* name, RegisterA64 dst, RegisterA64 src1, int src2, uint8_t op, int immr, int imms) { if (logText) log(name, dst, src1, src2); diff --git a/CodeGen/src/CodeGen.cpp b/CodeGen/src/CodeGen.cpp index 714ddadd8..646038347 100644 --- a/CodeGen/src/CodeGen.cpp +++ b/CodeGen/src/CodeGen.cpp @@ -34,6 +34,7 @@ #include #include +#include #if defined(__x86_64__) || defined(_M_X64) #ifdef _MSC_VER @@ -61,33 +62,34 @@ namespace CodeGen static void* gPerfLogContext = nullptr; static PerfLogFn gPerfLogFn = nullptr; -static NativeProto* createNativeProto(Proto* proto, const IrBuilder& ir) +struct NativeProto { - int sizecode = proto->sizecode; - int sizecodeAlloc = (sizecode + 1) & ~1; // align uint32_t array to 8 bytes so that NativeProto is aligned to 8 bytes + Proto* p; + void* execdata; + uintptr_t exectarget; +}; - void* memory = ::operator new(sizeof(NativeProto) + sizecodeAlloc * sizeof(uint32_t)); - NativeProto* result = new (static_cast(memory) + sizecodeAlloc * sizeof(uint32_t)) NativeProto; - result->proto = proto; +static NativeProto createNativeProto(Proto* proto, const IrBuilder& ir) +{ + int sizecode = proto->sizecode; - uint32_t* instOffsets = result->instOffsets; + uint32_t* instOffsets = new uint32_t[sizecode]; + uint32_t instTarget = ir.function.bcMapping[0].asmLocation; for (int i = 0; i < sizecode; i++) { - // instOffsets uses negative indexing for optimal codegen for RETURN opcode - instOffsets[-i] = ir.function.bcMapping[i].asmLocation; + LUAU_ASSERT(ir.function.bcMapping[i].asmLocation >= instTarget); + + instOffsets[i] = ir.function.bcMapping[i].asmLocation - instTarget; } - return result; + // entry target will be relocated when assembly is finalized + return {proto, instOffsets, instTarget}; } -static void destroyNativeProto(NativeProto* nativeProto) +static void destroyExecData(void* execdata) { - int sizecode = nativeProto->proto->sizecode; - int sizecodeAlloc = (sizecode + 1) & ~1; // align uint32_t array to 8 bytes so that NativeProto is aligned to 8 bytes - void* memory = reinterpret_cast(nativeProto) - sizecodeAlloc * sizeof(uint32_t); - - ::operator delete(memory); + delete[] static_cast(execdata); } static void logPerfFunction(Proto* p, uintptr_t addr, unsigned size) @@ -271,7 +273,7 @@ static bool lowerImpl(AssemblyBuilder& build, IrLowering& lowering, IrFunction& } template -static NativeProto* assembleFunction(AssemblyBuilder& build, NativeState& data, ModuleHelpers& helpers, Proto* proto, AssemblyOptions options) +static std::optional assembleFunction(AssemblyBuilder& build, NativeState& data, ModuleHelpers& helpers, Proto* proto, AssemblyOptions options) { if (options.includeAssembly || options.includeIr) { @@ -321,7 +323,7 @@ static NativeProto* assembleFunction(AssemblyBuilder& build, NativeState& data, if (build.logText) build.logAppend("; skipping (can't lower)\n\n"); - return nullptr; + return std::nullopt; } if (build.logText) @@ -337,23 +339,19 @@ static void onCloseState(lua_State* L) static void onDestroyFunction(lua_State* L, Proto* proto) { - NativeProto* nativeProto = getProtoExecData(proto); - LUAU_ASSERT(nativeProto->proto == proto); - - setProtoExecData(proto, nullptr); - destroyNativeProto(nativeProto); + destroyExecData(proto->execdata); + proto->execdata = nullptr; + proto->exectarget = 0; } static int onEnter(lua_State* L, Proto* proto) { NativeState* data = getNativeState(L); - NativeProto* nativeProto = getProtoExecData(proto); - LUAU_ASSERT(nativeProto); - LUAU_ASSERT(L->ci->savedpc); + LUAU_ASSERT(proto->execdata); + LUAU_ASSERT(L->ci->savedpc >= proto->code && L->ci->savedpc < proto->code + proto->sizecode); - // instOffsets uses negative indexing for optimal codegen for RETURN opcode - uintptr_t target = nativeProto->instBase + nativeProto->instOffsets[proto->code - L->ci->savedpc]; + uintptr_t target = proto->exectarget + static_cast(proto->execdata)[L->ci->savedpc - proto->code]; // Returns 1 to finish the function in the VM return GateFn(data->context.gateEntry)(L, proto, target, &data->context); @@ -361,7 +359,7 @@ static int onEnter(lua_State* L, Proto* proto) static void onSetBreakpoint(lua_State* L, Proto* proto, int instruction) { - if (!getProtoExecData(proto)) + if (!proto->execdata) return; LUAU_ASSERT(!"native breakpoints are not implemented"); @@ -444,8 +442,7 @@ void create(lua_State* L) data.codeAllocator.createBlockUnwindInfo = createBlockUnwindInfo; data.codeAllocator.destroyBlockUnwindInfo = destroyBlockUnwindInfo; - initFallbackTable(data); - initHelperFunctions(data); + initFunctions(data); #if defined(__x86_64__) || defined(_M_X64) if (!X64::initHeaderFunctions(data)) @@ -514,20 +511,20 @@ void compile(lua_State* L, int idx) X64::assembleHelpers(build, helpers); #endif - std::vector results; + std::vector results; results.reserve(protos.size()); // Skip protos that have been compiled during previous invocations of CodeGen::compile for (Proto* p : protos) - if (p && getProtoExecData(p) == nullptr) - if (NativeProto* np = assembleFunction(build, *data, helpers, p, {})) - results.push_back(np); + if (p && p->execdata == nullptr) + if (std::optional np = assembleFunction(build, *data, helpers, p, {})) + results.push_back(*np); // Very large modules might result in overflowing a jump offset; in this case we currently abandon the entire module if (!build.finalize()) { - for (NativeProto* result : results) - destroyNativeProto(result); + for (NativeProto result : results) + destroyExecData(result.execdata); return; } @@ -542,36 +539,32 @@ void compile(lua_State* L, int idx) if (!data->codeAllocator.allocate(build.data.data(), int(build.data.size()), reinterpret_cast(build.code.data()), int(build.code.size() * sizeof(build.code[0])), nativeData, sizeNativeData, codeStart)) { - for (NativeProto* result : results) - destroyNativeProto(result); + for (NativeProto result : results) + destroyExecData(result.execdata); return; } if (gPerfLogFn && results.size() > 0) { - gPerfLogFn(gPerfLogContext, uintptr_t(codeStart), results[0]->instOffsets[0], ""); + gPerfLogFn(gPerfLogContext, uintptr_t(codeStart), uint32_t(results[0].exectarget), ""); for (size_t i = 0; i < results.size(); ++i) { - uint32_t begin = results[i]->instOffsets[0]; - uint32_t end = i + 1 < results.size() ? results[i + 1]->instOffsets[0] : uint32_t(build.code.size() * sizeof(build.code[0])); + uint32_t begin = uint32_t(results[i].exectarget); + uint32_t end = i + 1 < results.size() ? uint32_t(results[i + 1].exectarget) : uint32_t(build.code.size() * sizeof(build.code[0])); LUAU_ASSERT(begin < end); - logPerfFunction(results[i]->proto, uintptr_t(codeStart) + begin, end - begin); + logPerfFunction(results[i].p, uintptr_t(codeStart) + begin, end - begin); } } - // Record instruction base address; at runtime, instOffsets[] will be used as offsets from instBase - for (NativeProto* result : results) + for (NativeProto result : results) { - result->instBase = uintptr_t(codeStart); - result->entryTarget = uintptr_t(codeStart) + result->instOffsets[0]; + // the memory is now managed by VM and will be freed via onDestroyFunction + result.p->execdata = result.execdata; + result.p->exectarget = uintptr_t(codeStart) + result.exectarget; } - - // Link native proto objects to Proto; the memory is now managed by VM and will be freed via onDestroyFunction - for (NativeProto* result : results) - setProtoExecData(result->proto, result); } std::string getAssembly(lua_State* L, int idx, AssemblyOptions options) @@ -586,7 +579,7 @@ std::string getAssembly(lua_State* L, int idx, AssemblyOptions options) #endif NativeState data; - initFallbackTable(data); + initFunctions(data); std::vector protos; gatherFunctions(protos, clvalue(func)->l.p); @@ -600,8 +593,8 @@ std::string getAssembly(lua_State* L, int idx, AssemblyOptions options) for (Proto* p : protos) if (p) - if (NativeProto* np = assembleFunction(build, data, helpers, p, options)) - destroyNativeProto(np); + if (std::optional np = assembleFunction(build, data, helpers, p, options)) + destroyExecData(np->execdata); if (!build.finalize()) return std::string(); diff --git a/CodeGen/src/CodeGenA64.cpp b/CodeGen/src/CodeGenA64.cpp index fbe44e23e..f6e9152c3 100644 --- a/CodeGen/src/CodeGenA64.cpp +++ b/CodeGen/src/CodeGenA64.cpp @@ -4,6 +4,7 @@ #include "Luau/AssemblyBuilderA64.h" #include "Luau/UnwindBuilder.h" +#include "BitUtils.h" #include "CustomExecUtils.h" #include "NativeState.h" #include "EmitCommonA64.h" @@ -91,6 +92,13 @@ static void emitReentry(AssemblyBuilderA64& build, ModuleHelpers& helpers) // Need to update state of the current function before we jump away build.ldr(x1, mem(x0, offsetof(Closure, l.p))); // cl->l.p aka proto + build.ldr(x2, mem(rState, offsetof(lua_State, ci))); // L->ci + + // We need to check if the new frame can be executed natively + // TOOD: .flags and .savedpc load below can be fused with ldp + build.ldr(w3, mem(x2, offsetof(CallInfo, flags))); + build.tbz(x3, countrz(LUA_CALLINFO_CUSTOM), helpers.exitContinueVm); + build.mov(rClosure, x0); build.ldr(rConstants, mem(x1, offsetof(Proto, k))); // proto->k build.ldr(rCode, mem(x1, offsetof(Proto, code))); // proto->code @@ -98,22 +106,15 @@ static void emitReentry(AssemblyBuilderA64& build, ModuleHelpers& helpers) // Get instruction index from instruction pointer // To get instruction index from instruction pointer, we need to divide byte offset by 4 // But we will actually need to scale instruction index by 4 back to byte offset later so it cancels out - // Note that we're computing negative offset here (code-savedpc) so that we can add it to NativeProto address, as we use reverse indexing - build.ldr(x2, mem(rState, offsetof(lua_State, ci))); // L->ci build.ldr(x2, mem(x2, offsetof(CallInfo, savedpc))); // L->ci->savedpc - build.sub(x2, rCode, x2); - - // We need to check if the new function can be executed natively - // TODO: This can be done earlier in the function flow, to reduce the JIT->VM transition penalty - build.ldr(x1, mem(x1, offsetofProtoExecData)); - build.cbz(x1, helpers.exitContinueVm); + build.sub(x2, x2, rCode); // Get new instruction location and jump to it - LUAU_ASSERT(offsetof(NativeProto, instOffsets) == 0); - build.ldr(w2, mem(x1, x2)); - build.ldr(x1, mem(x1, offsetof(NativeProto, instBase))); - build.add(x1, x1, x2); - build.br(x1); + LUAU_ASSERT(offsetof(Proto, exectarget) == offsetof(Proto, execdata) + 8); + build.ldp(x3, x4, mem(x1, offsetof(Proto, execdata))); + build.ldr(w2, mem(x3, x2)); + build.add(x4, x4, x2); + build.br(x4); } static EntryLocations buildEntryFunction(AssemblyBuilderA64& build, UnwindBuilder& unwind) diff --git a/CodeGen/src/CodeGenUtils.cpp b/CodeGen/src/CodeGenUtils.cpp index 37dfa116d..4ad67d83d 100644 --- a/CodeGen/src/CodeGenUtils.cpp +++ b/CodeGen/src/CodeGenUtils.cpp @@ -1,13 +1,64 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "CodeGenUtils.h" +#include "CustomExecUtils.h" + +#include "lvm.h" + +#include "lbuiltins.h" +#include "lbytecode.h" +#include "ldebug.h" #include "ldo.h" +#include "lfunc.h" +#include "lgc.h" +#include "lmem.h" +#include "lnumutils.h" +#include "lstate.h" +#include "lstring.h" #include "ltable.h" -#include "FallbacksProlog.h" - #include +LUAU_FASTFLAG(LuauUniformTopHandling) + +// All external function calls that can cause stack realloc or Lua calls have to be wrapped in VM_PROTECT +// This makes sure that we save the pc (in case the Lua call needs to generate a backtrace) before the call, +// and restores the stack pointer after in case stack gets reallocated +// Should only be used on the slow paths. +#define VM_PROTECT(x) \ + { \ + L->ci->savedpc = pc; \ + { \ + x; \ + }; \ + base = L->base; \ + } + +// Some external functions can cause an error, but never reallocate the stack; for these, VM_PROTECT_PC() is +// a cheaper version of VM_PROTECT that can be called before the external call. +#define VM_PROTECT_PC() L->ci->savedpc = pc + +#define VM_REG(i) (LUAU_ASSERT(unsigned(i) < unsigned(L->top - base)), &base[i]) +#define VM_KV(i) (LUAU_ASSERT(unsigned(i) < unsigned(cl->l.p->sizek)), &k[i]) +#define VM_UV(i) (LUAU_ASSERT(unsigned(i) < unsigned(cl->nupvalues)), &cl->l.uprefs[i]) + +#define VM_PATCH_C(pc, slot) *const_cast(pc) = ((uint8_t(slot) << 24) | (0x00ffffffu & *(pc))) +#define VM_PATCH_E(pc, slot) *const_cast(pc) = ((uint32_t(slot) << 8) | (0x000000ffu & *(pc))) + +#define VM_INTERRUPT() \ + { \ + void (*interrupt)(lua_State*, int) = L->global->cb.interrupt; \ + if (LUAU_UNLIKELY(!!interrupt)) \ + { /* the interrupt hook is called right before we advance pc */ \ + VM_PROTECT(L->ci->savedpc++; interrupt(L, -1)); \ + if (L->status != 0) \ + { \ + L->ci->savedpc--; \ + return NULL; \ + } \ + } \ + } + namespace Luau { namespace CodeGen @@ -215,6 +266,10 @@ Closure* callFallback(lua_State* L, StkId ra, StkId argtop, int nresults) // keep executing new function ci->savedpc = p->code; + + if (LUAU_LIKELY(p->execdata != NULL)) + ci->flags = LUA_CALLINFO_CUSTOM; + return ccl; } else @@ -281,7 +336,8 @@ Closure* returnFallback(lua_State* L, StkId ra, StkId valend) // we're done! if (LUAU_UNLIKELY(ci->flags & LUA_CALLINFO_RETURN)) { - L->top = res; + if (!FFlag::LuauUniformTopHandling) + L->top = res; return NULL; } @@ -290,5 +346,614 @@ Closure* returnFallback(lua_State* L, StkId ra, StkId valend) return clvalue(cip->func); } +const Instruction* executeGETGLOBAL(lua_State* L, const Instruction* pc, StkId base, TValue* k) +{ + [[maybe_unused]] Closure* cl = clvalue(L->ci->func); + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + uint32_t aux = *pc++; + TValue* kv = VM_KV(aux); + LUAU_ASSERT(ttisstring(kv)); + + // fast-path should already have been checked, so we skip checking for it here + Table* h = cl->env; + int slot = LUAU_INSN_C(insn) & h->nodemask8; + + // slow-path, may invoke Lua calls via __index metamethod + TValue g; + sethvalue(L, &g, h); + L->cachedslot = slot; + VM_PROTECT(luaV_gettable(L, &g, kv, ra)); + // save cachedslot to accelerate future lookups; patches currently executing instruction since pc-2 rolls back two pc++ + VM_PATCH_C(pc - 2, L->cachedslot); + return pc; +} + +const Instruction* executeSETGLOBAL(lua_State* L, const Instruction* pc, StkId base, TValue* k) +{ + [[maybe_unused]] Closure* cl = clvalue(L->ci->func); + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + uint32_t aux = *pc++; + TValue* kv = VM_KV(aux); + LUAU_ASSERT(ttisstring(kv)); + + // fast-path should already have been checked, so we skip checking for it here + Table* h = cl->env; + int slot = LUAU_INSN_C(insn) & h->nodemask8; + + // slow-path, may invoke Lua calls via __newindex metamethod + TValue g; + sethvalue(L, &g, h); + L->cachedslot = slot; + VM_PROTECT(luaV_settable(L, &g, kv, ra)); + // save cachedslot to accelerate future lookups; patches currently executing instruction since pc-2 rolls back two pc++ + VM_PATCH_C(pc - 2, L->cachedslot); + return pc; +} + +const Instruction* executeGETTABLEKS(lua_State* L, const Instruction* pc, StkId base, TValue* k) +{ + [[maybe_unused]] Closure* cl = clvalue(L->ci->func); + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + StkId rb = VM_REG(LUAU_INSN_B(insn)); + uint32_t aux = *pc++; + TValue* kv = VM_KV(aux); + LUAU_ASSERT(ttisstring(kv)); + + // fast-path: built-in table + if (ttistable(rb)) + { + Table* h = hvalue(rb); + + int slot = LUAU_INSN_C(insn) & h->nodemask8; + LuaNode* n = &h->node[slot]; + + // fast-path: value is in expected slot + if (LUAU_LIKELY(ttisstring(gkey(n)) && tsvalue(gkey(n)) == tsvalue(kv) && !ttisnil(gval(n)))) + { + setobj2s(L, ra, gval(n)); + return pc; + } + else if (!h->metatable) + { + // fast-path: value is not in expected slot, but the table lookup doesn't involve metatable + const TValue* res = luaH_getstr(h, tsvalue(kv)); + + if (res != luaO_nilobject) + { + int cachedslot = gval2slot(h, res); + // save cachedslot to accelerate future lookups; patches currently executing instruction since pc-2 rolls back two pc++ + VM_PATCH_C(pc - 2, cachedslot); + } + + setobj2s(L, ra, res); + return pc; + } + else + { + // slow-path, may invoke Lua calls via __index metamethod + L->cachedslot = slot; + VM_PROTECT(luaV_gettable(L, rb, kv, ra)); + // save cachedslot to accelerate future lookups; patches currently executing instruction since pc-2 rolls back two pc++ + VM_PATCH_C(pc - 2, L->cachedslot); + return pc; + } + } + else + { + // fast-path: user data with C __index TM + const TValue* fn = 0; + if (ttisuserdata(rb) && (fn = fasttm(L, uvalue(rb)->metatable, TM_INDEX)) && ttisfunction(fn) && clvalue(fn)->isC) + { + // note: it's safe to push arguments past top for complicated reasons (see top of the file) + LUAU_ASSERT(L->top + 3 < L->stack + L->stacksize); + StkId top = L->top; + setobj2s(L, top + 0, fn); + setobj2s(L, top + 1, rb); + setobj2s(L, top + 2, kv); + L->top = top + 3; + + L->cachedslot = LUAU_INSN_C(insn); + VM_PROTECT(luaV_callTM(L, 2, LUAU_INSN_A(insn))); + // save cachedslot to accelerate future lookups; patches currently executing instruction since pc-2 rolls back two pc++ + VM_PATCH_C(pc - 2, L->cachedslot); + return pc; + } + else if (ttisvector(rb)) + { + // fast-path: quick case-insensitive comparison with "X"/"Y"/"Z" + const char* name = getstr(tsvalue(kv)); + int ic = (name[0] | ' ') - 'x'; + +#if LUA_VECTOR_SIZE == 4 + // 'w' is before 'x' in ascii, so ic is -1 when indexing with 'w' + if (ic == -1) + ic = 3; +#endif + + if (unsigned(ic) < LUA_VECTOR_SIZE && name[1] == '\0') + { + const float* v = rb->value.v; // silences ubsan when indexing v[] + setnvalue(ra, v[ic]); + return pc; + } + + fn = fasttm(L, L->global->mt[LUA_TVECTOR], TM_INDEX); + + if (fn && ttisfunction(fn) && clvalue(fn)->isC) + { + // note: it's safe to push arguments past top for complicated reasons (see top of the file) + LUAU_ASSERT(L->top + 3 < L->stack + L->stacksize); + StkId top = L->top; + setobj2s(L, top + 0, fn); + setobj2s(L, top + 1, rb); + setobj2s(L, top + 2, kv); + L->top = top + 3; + + L->cachedslot = LUAU_INSN_C(insn); + VM_PROTECT(luaV_callTM(L, 2, LUAU_INSN_A(insn))); + // save cachedslot to accelerate future lookups; patches currently executing instruction since pc-2 rolls back two pc++ + VM_PATCH_C(pc - 2, L->cachedslot); + return pc; + } + + // fall through to slow path + } + + // fall through to slow path + } + + // slow-path, may invoke Lua calls via __index metamethod + VM_PROTECT(luaV_gettable(L, rb, kv, ra)); + return pc; +} + +const Instruction* executeSETTABLEKS(lua_State* L, const Instruction* pc, StkId base, TValue* k) +{ + [[maybe_unused]] Closure* cl = clvalue(L->ci->func); + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + StkId rb = VM_REG(LUAU_INSN_B(insn)); + uint32_t aux = *pc++; + TValue* kv = VM_KV(aux); + LUAU_ASSERT(ttisstring(kv)); + + // fast-path: built-in table + if (ttistable(rb)) + { + Table* h = hvalue(rb); + + int slot = LUAU_INSN_C(insn) & h->nodemask8; + LuaNode* n = &h->node[slot]; + + // fast-path: value is in expected slot + if (LUAU_LIKELY(ttisstring(gkey(n)) && tsvalue(gkey(n)) == tsvalue(kv) && !ttisnil(gval(n)) && !h->readonly)) + { + setobj2t(L, gval(n), ra); + luaC_barriert(L, h, ra); + return pc; + } + else if (fastnotm(h->metatable, TM_NEWINDEX) && !h->readonly) + { + VM_PROTECT_PC(); // set may fail + + TValue* res = luaH_setstr(L, h, tsvalue(kv)); + int cachedslot = gval2slot(h, res); + // save cachedslot to accelerate future lookups; patches currently executing instruction since pc-2 rolls back two pc++ + VM_PATCH_C(pc - 2, cachedslot); + setobj2t(L, res, ra); + luaC_barriert(L, h, ra); + return pc; + } + else + { + // slow-path, may invoke Lua calls via __newindex metamethod + L->cachedslot = slot; + VM_PROTECT(luaV_settable(L, rb, kv, ra)); + // save cachedslot to accelerate future lookups; patches currently executing instruction since pc-2 rolls back two pc++ + VM_PATCH_C(pc - 2, L->cachedslot); + return pc; + } + } + else + { + // fast-path: user data with C __newindex TM + const TValue* fn = 0; + if (ttisuserdata(rb) && (fn = fasttm(L, uvalue(rb)->metatable, TM_NEWINDEX)) && ttisfunction(fn) && clvalue(fn)->isC) + { + // note: it's safe to push arguments past top for complicated reasons (see top of the file) + LUAU_ASSERT(L->top + 4 < L->stack + L->stacksize); + StkId top = L->top; + setobj2s(L, top + 0, fn); + setobj2s(L, top + 1, rb); + setobj2s(L, top + 2, kv); + setobj2s(L, top + 3, ra); + L->top = top + 4; + + L->cachedslot = LUAU_INSN_C(insn); + VM_PROTECT(luaV_callTM(L, 3, -1)); + // save cachedslot to accelerate future lookups; patches currently executing instruction since pc-2 rolls back two pc++ + VM_PATCH_C(pc - 2, L->cachedslot); + return pc; + } + else + { + // slow-path, may invoke Lua calls via __newindex metamethod + VM_PROTECT(luaV_settable(L, rb, kv, ra)); + return pc; + } + } +} + +const Instruction* executeNEWCLOSURE(lua_State* L, const Instruction* pc, StkId base, TValue* k) +{ + [[maybe_unused]] Closure* cl = clvalue(L->ci->func); + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + + Proto* pv = cl->l.p->p[LUAU_INSN_D(insn)]; + LUAU_ASSERT(unsigned(LUAU_INSN_D(insn)) < unsigned(cl->l.p->sizep)); + + VM_PROTECT_PC(); // luaF_newLclosure may fail due to OOM + + // note: we save closure to stack early in case the code below wants to capture it by value + Closure* ncl = luaF_newLclosure(L, pv->nups, cl->env, pv); + setclvalue(L, ra, ncl); + + for (int ui = 0; ui < pv->nups; ++ui) + { + Instruction uinsn = *pc++; + LUAU_ASSERT(LUAU_INSN_OP(uinsn) == LOP_CAPTURE); + + switch (LUAU_INSN_A(uinsn)) + { + case LCT_VAL: + setobj(L, &ncl->l.uprefs[ui], VM_REG(LUAU_INSN_B(uinsn))); + break; + + case LCT_REF: + setupvalue(L, &ncl->l.uprefs[ui], luaF_findupval(L, VM_REG(LUAU_INSN_B(uinsn)))); + break; + + case LCT_UPVAL: + setobj(L, &ncl->l.uprefs[ui], VM_UV(LUAU_INSN_B(uinsn))); + break; + + default: + LUAU_ASSERT(!"Unknown upvalue capture type"); + LUAU_UNREACHABLE(); // improves switch() codegen by eliding opcode bounds checks + } + } + + VM_PROTECT(luaC_checkGC(L)); + return pc; +} + +const Instruction* executeNAMECALL(lua_State* L, const Instruction* pc, StkId base, TValue* k) +{ + [[maybe_unused]] Closure* cl = clvalue(L->ci->func); + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + StkId rb = VM_REG(LUAU_INSN_B(insn)); + uint32_t aux = *pc++; + TValue* kv = VM_KV(aux); + LUAU_ASSERT(ttisstring(kv)); + + if (ttistable(rb)) + { + Table* h = hvalue(rb); + // note: we can't use nodemask8 here because we need to query the main position of the table, and 8-bit nodemask8 only works + // for predictive lookups + LuaNode* n = &h->node[tsvalue(kv)->hash & (sizenode(h) - 1)]; + + const TValue* mt = 0; + const LuaNode* mtn = 0; + + // fast-path: key is in the table in expected slot + if (ttisstring(gkey(n)) && tsvalue(gkey(n)) == tsvalue(kv) && !ttisnil(gval(n))) + { + // note: order of copies allows rb to alias ra+1 or ra + setobj2s(L, ra + 1, rb); + setobj2s(L, ra, gval(n)); + } + // fast-path: key is absent from the base, table has an __index table, and it has the result in the expected slot + else if (gnext(n) == 0 && (mt = fasttm(L, hvalue(rb)->metatable, TM_INDEX)) && ttistable(mt) && + (mtn = &hvalue(mt)->node[LUAU_INSN_C(insn) & hvalue(mt)->nodemask8]) && ttisstring(gkey(mtn)) && tsvalue(gkey(mtn)) == tsvalue(kv) && + !ttisnil(gval(mtn))) + { + // note: order of copies allows rb to alias ra+1 or ra + setobj2s(L, ra + 1, rb); + setobj2s(L, ra, gval(mtn)); + } + else + { + // slow-path: handles full table lookup + setobj2s(L, ra + 1, rb); + L->cachedslot = LUAU_INSN_C(insn); + VM_PROTECT(luaV_gettable(L, rb, kv, ra)); + // save cachedslot to accelerate future lookups; patches currently executing instruction since pc-2 rolls back two pc++ + VM_PATCH_C(pc - 2, L->cachedslot); + // recompute ra since stack might have been reallocated + ra = VM_REG(LUAU_INSN_A(insn)); + if (ttisnil(ra)) + luaG_methoderror(L, ra + 1, tsvalue(kv)); + } + } + else + { + Table* mt = ttisuserdata(rb) ? uvalue(rb)->metatable : L->global->mt[ttype(rb)]; + const TValue* tmi = 0; + + // fast-path: metatable with __namecall + if (const TValue* fn = fasttm(L, mt, TM_NAMECALL)) + { + // note: order of copies allows rb to alias ra+1 or ra + setobj2s(L, ra + 1, rb); + setobj2s(L, ra, fn); + + L->namecall = tsvalue(kv); + } + else if ((tmi = fasttm(L, mt, TM_INDEX)) && ttistable(tmi)) + { + Table* h = hvalue(tmi); + int slot = LUAU_INSN_C(insn) & h->nodemask8; + LuaNode* n = &h->node[slot]; + + // fast-path: metatable with __index that has method in expected slot + if (LUAU_LIKELY(ttisstring(gkey(n)) && tsvalue(gkey(n)) == tsvalue(kv) && !ttisnil(gval(n)))) + { + // note: order of copies allows rb to alias ra+1 or ra + setobj2s(L, ra + 1, rb); + setobj2s(L, ra, gval(n)); + } + else + { + // slow-path: handles slot mismatch + setobj2s(L, ra + 1, rb); + L->cachedslot = slot; + VM_PROTECT(luaV_gettable(L, rb, kv, ra)); + // save cachedslot to accelerate future lookups; patches currently executing instruction since pc-2 rolls back two pc++ + VM_PATCH_C(pc - 2, L->cachedslot); + // recompute ra since stack might have been reallocated + ra = VM_REG(LUAU_INSN_A(insn)); + if (ttisnil(ra)) + luaG_methoderror(L, ra + 1, tsvalue(kv)); + } + } + else + { + // slow-path: handles non-table __index + setobj2s(L, ra + 1, rb); + VM_PROTECT(luaV_gettable(L, rb, kv, ra)); + // recompute ra since stack might have been reallocated + ra = VM_REG(LUAU_INSN_A(insn)); + if (ttisnil(ra)) + luaG_methoderror(L, ra + 1, tsvalue(kv)); + } + } + + // intentional fallthrough to CALL + LUAU_ASSERT(LUAU_INSN_OP(*pc) == LOP_CALL); + return pc; +} + +const Instruction* executeSETLIST(lua_State* L, const Instruction* pc, StkId base, TValue* k) +{ + [[maybe_unused]] Closure* cl = clvalue(L->ci->func); + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + StkId rb = &base[LUAU_INSN_B(insn)]; // note: this can point to L->top if c == LUA_MULTRET making VM_REG unsafe to use + int c = LUAU_INSN_C(insn) - 1; + uint32_t index = *pc++; + + if (c == LUA_MULTRET) + { + c = int(L->top - rb); + L->top = L->ci->top; + } + + Table* h = hvalue(ra); + + // TODO: we really don't need this anymore + if (!ttistable(ra)) + return NULL; // temporary workaround to weaken a rather powerful exploitation primitive in case of a MITM attack on bytecode + + int last = index + c - 1; + if (last > h->sizearray) + { + VM_PROTECT_PC(); // luaH_resizearray may fail due to OOM + + luaH_resizearray(L, h, last); + } + + TValue* array = h->array; + + for (int i = 0; i < c; ++i) + setobj2t(L, &array[index + i - 1], rb + i); + + luaC_barrierfast(L, h); + return pc; +} + +const Instruction* executeFORGPREP(lua_State* L, const Instruction* pc, StkId base, TValue* k) +{ + [[maybe_unused]] Closure* cl = clvalue(L->ci->func); + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + + if (ttisfunction(ra)) + { + // will be called during FORGLOOP + } + else + { + Table* mt = ttistable(ra) ? hvalue(ra)->metatable : ttisuserdata(ra) ? uvalue(ra)->metatable : cast_to(Table*, NULL); + + if (const TValue* fn = fasttm(L, mt, TM_ITER)) + { + setobj2s(L, ra + 1, ra); + setobj2s(L, ra, fn); + + L->top = ra + 2; // func + self arg + LUAU_ASSERT(L->top <= L->stack_last); + + VM_PROTECT(luaD_call(L, ra, 3)); + L->top = L->ci->top; + + // recompute ra since stack might have been reallocated + ra = VM_REG(LUAU_INSN_A(insn)); + + // protect against __iter returning nil, since nil is used as a marker for builtin iteration in FORGLOOP + if (ttisnil(ra)) + { + VM_PROTECT_PC(); // next call always errors + luaG_typeerror(L, ra, "call"); + } + } + else if (fasttm(L, mt, TM_CALL)) + { + // table or userdata with __call, will be called during FORGLOOP + // TODO: we might be able to stop supporting this depending on whether it's used in practice + } + else if (ttistable(ra)) + { + // set up registers for builtin iteration + setobj2s(L, ra + 1, ra); + setpvalue(ra + 2, reinterpret_cast(uintptr_t(0))); + setnilvalue(ra); + } + else + { + VM_PROTECT_PC(); // next call always errors + luaG_typeerror(L, ra, "iterate over"); + } + } + + pc += LUAU_INSN_D(insn); + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + return pc; +} + +const Instruction* executeGETVARARGS(lua_State* L, const Instruction* pc, StkId base, TValue* k) +{ + [[maybe_unused]] Closure* cl = clvalue(L->ci->func); + Instruction insn = *pc++; + int b = LUAU_INSN_B(insn) - 1; + int n = cast_int(base - L->ci->func) - cl->l.p->numparams - 1; + + if (b == LUA_MULTRET) + { + VM_PROTECT(luaD_checkstack(L, n)); + StkId ra = VM_REG(LUAU_INSN_A(insn)); // previous call may change the stack + + for (int j = 0; j < n; j++) + setobj2s(L, ra + j, base - n + j); + + L->top = ra + n; + return pc; + } + else + { + StkId ra = VM_REG(LUAU_INSN_A(insn)); + + for (int j = 0; j < b && j < n; j++) + setobj2s(L, ra + j, base - n + j); + for (int j = n; j < b; j++) + setnilvalue(ra + j); + return pc; + } +} + +const Instruction* executeDUPCLOSURE(lua_State* L, const Instruction* pc, StkId base, TValue* k) +{ + [[maybe_unused]] Closure* cl = clvalue(L->ci->func); + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + TValue* kv = VM_KV(LUAU_INSN_D(insn)); + + Closure* kcl = clvalue(kv); + + VM_PROTECT_PC(); // luaF_newLclosure may fail due to OOM + + // clone closure if the environment is not shared + // note: we save closure to stack early in case the code below wants to capture it by value + Closure* ncl = (kcl->env == cl->env) ? kcl : luaF_newLclosure(L, kcl->nupvalues, cl->env, kcl->l.p); + setclvalue(L, ra, ncl); + + // this loop does three things: + // - if the closure was created anew, it just fills it with upvalues + // - if the closure from the constant table is used, it fills it with upvalues so that it can be shared in the future + // - if the closure is reused, it checks if the reuse is safe via rawequal, and falls back to duplicating the closure + // normally this would use two separate loops, for reuse check and upvalue setup, but MSVC codegen goes crazy if you do that + for (int ui = 0; ui < kcl->nupvalues; ++ui) + { + Instruction uinsn = pc[ui]; + LUAU_ASSERT(LUAU_INSN_OP(uinsn) == LOP_CAPTURE); + LUAU_ASSERT(LUAU_INSN_A(uinsn) == LCT_VAL || LUAU_INSN_A(uinsn) == LCT_UPVAL); + + TValue* uv = (LUAU_INSN_A(uinsn) == LCT_VAL) ? VM_REG(LUAU_INSN_B(uinsn)) : VM_UV(LUAU_INSN_B(uinsn)); + + // check if the existing closure is safe to reuse + if (ncl == kcl && luaO_rawequalObj(&ncl->l.uprefs[ui], uv)) + continue; + + // lazily clone the closure and update the upvalues + if (ncl == kcl && kcl->preload == 0) + { + ncl = luaF_newLclosure(L, kcl->nupvalues, cl->env, kcl->l.p); + setclvalue(L, ra, ncl); + + ui = -1; // restart the loop to fill all upvalues + continue; + } + + // this updates a newly created closure, or an existing closure created during preload, in which case we need a barrier + setobj(L, &ncl->l.uprefs[ui], uv); + luaC_barrier(L, ncl, uv); + } + + // this is a noop if ncl is newly created or shared successfully, but it has to run after the closure is preloaded for the first time + ncl->preload = 0; + + if (kcl != ncl) + VM_PROTECT(luaC_checkGC(L)); + + pc += kcl->nupvalues; + return pc; +} + +const Instruction* executePREPVARARGS(lua_State* L, const Instruction* pc, StkId base, TValue* k) +{ + [[maybe_unused]] Closure* cl = clvalue(L->ci->func); + Instruction insn = *pc++; + int numparams = LUAU_INSN_A(insn); + + // all fixed parameters are copied after the top so we need more stack space + VM_PROTECT(luaD_checkstack(L, cl->stacksize + numparams)); + + // the caller must have filled extra fixed arguments with nil + LUAU_ASSERT(cast_int(L->top - base) >= numparams); + + // move fixed parameters to final position + StkId fixed = base; // first fixed argument + base = L->top; // final position of first argument + + for (int i = 0; i < numparams; ++i) + { + setobj2s(L, base + i, fixed + i); + setnilvalue(fixed + i); + } + + // rewire our stack frame to point to the new base + L->ci->base = base; + L->ci->top = base + cl->stacksize; + + L->base = base; + L->top = L->ci->top; + return pc; +} + } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/CodeGenUtils.h b/CodeGen/src/CodeGenUtils.h index 4ce35663e..87b6ec449 100644 --- a/CodeGen/src/CodeGenUtils.h +++ b/CodeGen/src/CodeGenUtils.h @@ -20,5 +20,17 @@ void callEpilogC(lua_State* L, int nresults, int n); Closure* callFallback(lua_State* L, StkId ra, StkId argtop, int nresults); Closure* returnFallback(lua_State* L, StkId ra, StkId valend); +const Instruction* executeGETGLOBAL(lua_State* L, const Instruction* pc, StkId base, TValue* k); +const Instruction* executeSETGLOBAL(lua_State* L, const Instruction* pc, StkId base, TValue* k); +const Instruction* executeGETTABLEKS(lua_State* L, const Instruction* pc, StkId base, TValue* k); +const Instruction* executeSETTABLEKS(lua_State* L, const Instruction* pc, StkId base, TValue* k); +const Instruction* executeNEWCLOSURE(lua_State* L, const Instruction* pc, StkId base, TValue* k); +const Instruction* executeNAMECALL(lua_State* L, const Instruction* pc, StkId base, TValue* k); +const Instruction* executeSETLIST(lua_State* L, const Instruction* pc, StkId base, TValue* k); +const Instruction* executeFORGPREP(lua_State* L, const Instruction* pc, StkId base, TValue* k); +const Instruction* executeGETVARARGS(lua_State* L, const Instruction* pc, StkId base, TValue* k); +const Instruction* executeDUPCLOSURE(lua_State* L, const Instruction* pc, StkId base, TValue* k); +const Instruction* executePREPVARARGS(lua_State* L, const Instruction* pc, StkId base, TValue* k); + } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/CustomExecUtils.h b/CodeGen/src/CustomExecUtils.h index 9526d6d71..9c9996611 100644 --- a/CodeGen/src/CustomExecUtils.h +++ b/CodeGen/src/CustomExecUtils.h @@ -46,21 +46,6 @@ inline void destroyNativeState(lua_State* L) delete state; } -inline NativeProto* getProtoExecData(Proto* proto) -{ - return (NativeProto*)proto->execdata; -} - -inline void setProtoExecData(Proto* proto, NativeProto* nativeProto) -{ - if (nativeProto) - LUAU_ASSERT(proto->execdata == nullptr); - - proto->execdata = nativeProto; -} - -#define offsetofProtoExecData offsetof(Proto, execdata) - #else inline lua_ExecutionCallbacks* getExecutionCallbacks(lua_State* L) @@ -82,15 +67,6 @@ inline NativeState* createNativeState(lua_State* L) inline void destroyNativeState(lua_State* L) {} -inline NativeProto* getProtoExecData(Proto* proto) -{ - return nullptr; -} - -inline void setProtoExecData(Proto* proto, NativeProto* nativeProto) {} - -#define offsetofProtoExecData 0 - #endif inline int getOpLength(LuauOpcode op) diff --git a/CodeGen/src/EmitCommon.h b/CodeGen/src/EmitCommon.h index 6a7496694..6b19912bf 100644 --- a/CodeGen/src/EmitCommon.h +++ b/CodeGen/src/EmitCommon.h @@ -10,11 +10,12 @@ namespace CodeGen constexpr unsigned kTValueSizeLog2 = 4; constexpr unsigned kLuaNodeSizeLog2 = 5; -constexpr unsigned kLuaNodeTagMask = 0xf; -constexpr unsigned kNextBitOffset = 4; -constexpr unsigned kOffsetOfTKeyTag = 12; // offsetof cannot be used on a bit field -constexpr unsigned kOffsetOfTKeyNext = 12; // offsetof cannot be used on a bit field +// TKey.tt and TKey.next are packed together in a bitfield +constexpr unsigned kOffsetOfTKeyTagNext = 12; // offsetof cannot be used on a bit field +constexpr unsigned kTKeyTagBits = 4; +constexpr unsigned kTKeyTagMask = (1 << kTKeyTagBits) - 1; + constexpr unsigned kOffsetOfInstructionC = 3; // Leaf functions that are placed in every module to perform common instruction sequences diff --git a/CodeGen/src/EmitCommonX64.cpp b/CodeGen/src/EmitCommonX64.cpp index b6ef957b3..ce95e7410 100644 --- a/CodeGen/src/EmitCommonX64.cpp +++ b/CodeGen/src/EmitCommonX64.cpp @@ -325,10 +325,8 @@ void emitInterrupt(IrRegAllocX64& regs, AssemblyBuilderX64& build, int pcpos) build.setLabel(skip); } -void emitFallback(IrRegAllocX64& regs, AssemblyBuilderX64& build, NativeState& data, int op, int pcpos) +void emitFallback(IrRegAllocX64& regs, AssemblyBuilderX64& build, int offset, int pcpos) { - LUAU_ASSERT(data.context.fallback[op]); - // fallback(L, instruction, base, k) IrCallWrapperX64 callWrap(regs, build); callWrap.addArgument(SizeX64::qword, rState); @@ -339,7 +337,7 @@ void emitFallback(IrRegAllocX64& regs, AssemblyBuilderX64& build, NativeState& d callWrap.addArgument(SizeX64::qword, rBase); callWrap.addArgument(SizeX64::qword, rConstants); - callWrap.call(qword[rNativeContext + offsetof(NativeContext, fallback) + op * sizeof(FallbackFn)]); + callWrap.call(qword[rNativeContext + offset]); emitUpdateBase(build); } diff --git a/CodeGen/src/EmitCommonX64.h b/CodeGen/src/EmitCommonX64.h index d4684fe85..ddc4048f4 100644 --- a/CodeGen/src/EmitCommonX64.h +++ b/CodeGen/src/EmitCommonX64.h @@ -136,7 +136,7 @@ inline OperandX64 luauNodeKeyValue(RegisterX64 node) // Note: tag has dirty upper bits inline OperandX64 luauNodeKeyTag(RegisterX64 node) { - return dword[node + offsetof(LuaNode, key) + kOffsetOfTKeyTag]; + return dword[node + offsetof(LuaNode, key) + kOffsetOfTKeyTagNext]; } inline OperandX64 luauNodeValue(RegisterX64 node) @@ -189,7 +189,7 @@ inline void jumpIfNodeKeyTagIsNot(AssemblyBuilderX64& build, RegisterX64 tmp, Re tmp.size = SizeX64::dword; build.mov(tmp, luauNodeKeyTag(node)); - build.and_(tmp, kLuaNodeTagMask); + build.and_(tmp, kTKeyTagMask); build.cmp(tmp, tag); build.jcc(ConditionX64::NotEqual, label); } @@ -230,7 +230,7 @@ void callStepGc(IrRegAllocX64& regs, AssemblyBuilderX64& build); void emitExit(AssemblyBuilderX64& build, bool continueInVm); void emitUpdateBase(AssemblyBuilderX64& build); void emitInterrupt(IrRegAllocX64& regs, AssemblyBuilderX64& build, int pcpos); -void emitFallback(IrRegAllocX64& regs, AssemblyBuilderX64& build, NativeState& data, int op, int pcpos); +void emitFallback(IrRegAllocX64& regs, AssemblyBuilderX64& build, int offset, int pcpos); void emitContinueCallInVm(AssemblyBuilderX64& build); diff --git a/CodeGen/src/EmitInstructionX64.cpp b/CodeGen/src/EmitInstructionX64.cpp index 19f0cb86d..b2db7d187 100644 --- a/CodeGen/src/EmitInstructionX64.cpp +++ b/CodeGen/src/EmitInstructionX64.cpp @@ -73,8 +73,6 @@ void emitInstCall(AssemblyBuilderX64& build, ModuleHelpers& helpers, int ra, int build.mov(rax, qword[ci + offsetof(CallInfo, top)]); build.mov(qword[rState + offsetof(lua_State, top)], rax); - build.mov(rax, qword[proto + offsetofProtoExecData]); // We'll need this value later - // But if it is vararg, update it to 'argi' Label skipVararg; @@ -84,10 +82,14 @@ void emitInstCall(AssemblyBuilderX64& build, ModuleHelpers& helpers, int ra, int build.mov(qword[rState + offsetof(lua_State, top)], argi); build.setLabel(skipVararg); - // Check native function data + // Get native function entry + build.mov(rax, qword[proto + offsetof(Proto, exectarget)]); build.test(rax, rax); build.jcc(ConditionX64::Zero, helpers.continueCallInVm); + // Mark call frame as custom + build.mov(dword[ci + offsetof(CallInfo, flags)], LUA_CALLINFO_CUSTOM); + // Switch current constants build.mov(rConstants, qword[proto + offsetof(Proto, k)]); @@ -95,7 +97,7 @@ void emitInstCall(AssemblyBuilderX64& build, ModuleHelpers& helpers, int ra, int build.mov(rdx, qword[proto + offsetof(Proto, code)]); build.mov(sCode, rdx); - build.jmp(qword[rax + offsetof(NativeProto, entryTarget)]); + build.jmp(rax); } build.setLabel(cFuncCall); @@ -294,8 +296,9 @@ void emitInstReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers, int ra, i build.mov(proto, qword[rax + offsetof(Closure, l.p)]); - build.mov(execdata, qword[proto + offsetofProtoExecData]); - build.test(execdata, execdata); + build.mov(execdata, qword[proto + offsetof(Proto, execdata)]); + + build.test(byte[cip + offsetof(CallInfo, flags)], LUA_CALLINFO_CUSTOM); build.jcc(ConditionX64::Zero, helpers.exitContinueVm); // Continue in interpreter if function has no native data // Change constants @@ -309,13 +312,11 @@ void emitInstReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers, int ra, i // To get instruction index from instruction pointer, we need to divide byte offset by 4 // But we will actually need to scale instruction index by 4 back to byte offset later so it cancels out - // Note that we're computing negative offset here (code-savedpc) so that we can add it to NativeProto address, as we use reverse indexing - build.sub(rdx, rax); + build.sub(rax, rdx); // Get new instruction location and jump to it - LUAU_ASSERT(offsetof(NativeProto, instOffsets) == 0); - build.mov(edx, dword[execdata + rdx]); - build.add(rdx, qword[execdata + offsetof(NativeProto, instBase)]); + build.mov(edx, dword[execdata + rax]); + build.add(rdx, qword[proto + offsetof(Proto, exectarget)]); build.jmp(rdx); } diff --git a/CodeGen/src/Fallbacks.cpp b/CodeGen/src/Fallbacks.cpp deleted file mode 100644 index 1c0dce57a..000000000 --- a/CodeGen/src/Fallbacks.cpp +++ /dev/null @@ -1,639 +0,0 @@ -// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -// This code is based on Lua 5.x implementation licensed under MIT License; see lua_LICENSE.txt for details -// This file was generated by 'tools/lvmexecute_split.py' script, do not modify it by hand -#include "Fallbacks.h" -#include "FallbacksProlog.h" - -const Instruction* execute_LOP_GETGLOBAL(lua_State* L, const Instruction* pc, StkId base, TValue* k) -{ - [[maybe_unused]] Closure* cl = clvalue(L->ci->func); - Instruction insn = *pc++; - StkId ra = VM_REG(LUAU_INSN_A(insn)); - uint32_t aux = *pc++; - TValue* kv = VM_KV(aux); - LUAU_ASSERT(ttisstring(kv)); - - // fast-path: value is in expected slot - Table* h = cl->env; - int slot = LUAU_INSN_C(insn) & h->nodemask8; - LuaNode* n = &h->node[slot]; - - if (LUAU_LIKELY(ttisstring(gkey(n)) && tsvalue(gkey(n)) == tsvalue(kv)) && !ttisnil(gval(n))) - { - setobj2s(L, ra, gval(n)); - return pc; - } - else - { - // slow-path, may invoke Lua calls via __index metamethod - TValue g; - sethvalue(L, &g, h); - L->cachedslot = slot; - VM_PROTECT(luaV_gettable(L, &g, kv, ra)); - // save cachedslot to accelerate future lookups; patches currently executing instruction since pc-2 rolls back two pc++ - VM_PATCH_C(pc - 2, L->cachedslot); - return pc; - } -} - -const Instruction* execute_LOP_SETGLOBAL(lua_State* L, const Instruction* pc, StkId base, TValue* k) -{ - [[maybe_unused]] Closure* cl = clvalue(L->ci->func); - Instruction insn = *pc++; - StkId ra = VM_REG(LUAU_INSN_A(insn)); - uint32_t aux = *pc++; - TValue* kv = VM_KV(aux); - LUAU_ASSERT(ttisstring(kv)); - - // fast-path: value is in expected slot - Table* h = cl->env; - int slot = LUAU_INSN_C(insn) & h->nodemask8; - LuaNode* n = &h->node[slot]; - - if (LUAU_LIKELY(ttisstring(gkey(n)) && tsvalue(gkey(n)) == tsvalue(kv) && !ttisnil(gval(n)) && !h->readonly)) - { - setobj2t(L, gval(n), ra); - luaC_barriert(L, h, ra); - return pc; - } - else - { - // slow-path, may invoke Lua calls via __newindex metamethod - TValue g; - sethvalue(L, &g, h); - L->cachedslot = slot; - VM_PROTECT(luaV_settable(L, &g, kv, ra)); - // save cachedslot to accelerate future lookups; patches currently executing instruction since pc-2 rolls back two pc++ - VM_PATCH_C(pc - 2, L->cachedslot); - return pc; - } -} - -const Instruction* execute_LOP_GETTABLEKS(lua_State* L, const Instruction* pc, StkId base, TValue* k) -{ - [[maybe_unused]] Closure* cl = clvalue(L->ci->func); - Instruction insn = *pc++; - StkId ra = VM_REG(LUAU_INSN_A(insn)); - StkId rb = VM_REG(LUAU_INSN_B(insn)); - uint32_t aux = *pc++; - TValue* kv = VM_KV(aux); - LUAU_ASSERT(ttisstring(kv)); - - // fast-path: built-in table - if (ttistable(rb)) - { - Table* h = hvalue(rb); - - int slot = LUAU_INSN_C(insn) & h->nodemask8; - LuaNode* n = &h->node[slot]; - - // fast-path: value is in expected slot - if (LUAU_LIKELY(ttisstring(gkey(n)) && tsvalue(gkey(n)) == tsvalue(kv) && !ttisnil(gval(n)))) - { - setobj2s(L, ra, gval(n)); - return pc; - } - else if (!h->metatable) - { - // fast-path: value is not in expected slot, but the table lookup doesn't involve metatable - const TValue* res = luaH_getstr(h, tsvalue(kv)); - - if (res != luaO_nilobject) - { - int cachedslot = gval2slot(h, res); - // save cachedslot to accelerate future lookups; patches currently executing instruction since pc-2 rolls back two pc++ - VM_PATCH_C(pc - 2, cachedslot); - } - - setobj2s(L, ra, res); - return pc; - } - else - { - // slow-path, may invoke Lua calls via __index metamethod - L->cachedslot = slot; - VM_PROTECT(luaV_gettable(L, rb, kv, ra)); - // save cachedslot to accelerate future lookups; patches currently executing instruction since pc-2 rolls back two pc++ - VM_PATCH_C(pc - 2, L->cachedslot); - return pc; - } - } - else - { - // fast-path: user data with C __index TM - const TValue* fn = 0; - if (ttisuserdata(rb) && (fn = fasttm(L, uvalue(rb)->metatable, TM_INDEX)) && ttisfunction(fn) && clvalue(fn)->isC) - { - // note: it's safe to push arguments past top for complicated reasons (see top of the file) - LUAU_ASSERT(L->top + 3 < L->stack + L->stacksize); - StkId top = L->top; - setobj2s(L, top + 0, fn); - setobj2s(L, top + 1, rb); - setobj2s(L, top + 2, kv); - L->top = top + 3; - - L->cachedslot = LUAU_INSN_C(insn); - VM_PROTECT(luaV_callTM(L, 2, LUAU_INSN_A(insn))); - // save cachedslot to accelerate future lookups; patches currently executing instruction since pc-2 rolls back two pc++ - VM_PATCH_C(pc - 2, L->cachedslot); - return pc; - } - else if (ttisvector(rb)) - { - // fast-path: quick case-insensitive comparison with "X"/"Y"/"Z" - const char* name = getstr(tsvalue(kv)); - int ic = (name[0] | ' ') - 'x'; - -#if LUA_VECTOR_SIZE == 4 - // 'w' is before 'x' in ascii, so ic is -1 when indexing with 'w' - if (ic == -1) - ic = 3; -#endif - - if (unsigned(ic) < LUA_VECTOR_SIZE && name[1] == '\0') - { - const float* v = rb->value.v; // silences ubsan when indexing v[] - setnvalue(ra, v[ic]); - return pc; - } - - fn = fasttm(L, L->global->mt[LUA_TVECTOR], TM_INDEX); - - if (fn && ttisfunction(fn) && clvalue(fn)->isC) - { - // note: it's safe to push arguments past top for complicated reasons (see top of the file) - LUAU_ASSERT(L->top + 3 < L->stack + L->stacksize); - StkId top = L->top; - setobj2s(L, top + 0, fn); - setobj2s(L, top + 1, rb); - setobj2s(L, top + 2, kv); - L->top = top + 3; - - L->cachedslot = LUAU_INSN_C(insn); - VM_PROTECT(luaV_callTM(L, 2, LUAU_INSN_A(insn))); - // save cachedslot to accelerate future lookups; patches currently executing instruction since pc-2 rolls back two pc++ - VM_PATCH_C(pc - 2, L->cachedslot); - return pc; - } - - // fall through to slow path - } - - // fall through to slow path - } - - // slow-path, may invoke Lua calls via __index metamethod - VM_PROTECT(luaV_gettable(L, rb, kv, ra)); - return pc; -} - -const Instruction* execute_LOP_SETTABLEKS(lua_State* L, const Instruction* pc, StkId base, TValue* k) -{ - [[maybe_unused]] Closure* cl = clvalue(L->ci->func); - Instruction insn = *pc++; - StkId ra = VM_REG(LUAU_INSN_A(insn)); - StkId rb = VM_REG(LUAU_INSN_B(insn)); - uint32_t aux = *pc++; - TValue* kv = VM_KV(aux); - LUAU_ASSERT(ttisstring(kv)); - - // fast-path: built-in table - if (ttistable(rb)) - { - Table* h = hvalue(rb); - - int slot = LUAU_INSN_C(insn) & h->nodemask8; - LuaNode* n = &h->node[slot]; - - // fast-path: value is in expected slot - if (LUAU_LIKELY(ttisstring(gkey(n)) && tsvalue(gkey(n)) == tsvalue(kv) && !ttisnil(gval(n)) && !h->readonly)) - { - setobj2t(L, gval(n), ra); - luaC_barriert(L, h, ra); - return pc; - } - else if (fastnotm(h->metatable, TM_NEWINDEX) && !h->readonly) - { - VM_PROTECT_PC(); // set may fail - - TValue* res = luaH_setstr(L, h, tsvalue(kv)); - int cachedslot = gval2slot(h, res); - // save cachedslot to accelerate future lookups; patches currently executing instruction since pc-2 rolls back two pc++ - VM_PATCH_C(pc - 2, cachedslot); - setobj2t(L, res, ra); - luaC_barriert(L, h, ra); - return pc; - } - else - { - // slow-path, may invoke Lua calls via __newindex metamethod - L->cachedslot = slot; - VM_PROTECT(luaV_settable(L, rb, kv, ra)); - // save cachedslot to accelerate future lookups; patches currently executing instruction since pc-2 rolls back two pc++ - VM_PATCH_C(pc - 2, L->cachedslot); - return pc; - } - } - else - { - // fast-path: user data with C __newindex TM - const TValue* fn = 0; - if (ttisuserdata(rb) && (fn = fasttm(L, uvalue(rb)->metatable, TM_NEWINDEX)) && ttisfunction(fn) && clvalue(fn)->isC) - { - // note: it's safe to push arguments past top for complicated reasons (see top of the file) - LUAU_ASSERT(L->top + 4 < L->stack + L->stacksize); - StkId top = L->top; - setobj2s(L, top + 0, fn); - setobj2s(L, top + 1, rb); - setobj2s(L, top + 2, kv); - setobj2s(L, top + 3, ra); - L->top = top + 4; - - L->cachedslot = LUAU_INSN_C(insn); - VM_PROTECT(luaV_callTM(L, 3, -1)); - // save cachedslot to accelerate future lookups; patches currently executing instruction since pc-2 rolls back two pc++ - VM_PATCH_C(pc - 2, L->cachedslot); - return pc; - } - else - { - // slow-path, may invoke Lua calls via __newindex metamethod - VM_PROTECT(luaV_settable(L, rb, kv, ra)); - return pc; - } - } -} - -const Instruction* execute_LOP_NEWCLOSURE(lua_State* L, const Instruction* pc, StkId base, TValue* k) -{ - [[maybe_unused]] Closure* cl = clvalue(L->ci->func); - Instruction insn = *pc++; - StkId ra = VM_REG(LUAU_INSN_A(insn)); - - Proto* pv = cl->l.p->p[LUAU_INSN_D(insn)]; - LUAU_ASSERT(unsigned(LUAU_INSN_D(insn)) < unsigned(cl->l.p->sizep)); - - VM_PROTECT_PC(); // luaF_newLclosure may fail due to OOM - - // note: we save closure to stack early in case the code below wants to capture it by value - Closure* ncl = luaF_newLclosure(L, pv->nups, cl->env, pv); - setclvalue(L, ra, ncl); - - for (int ui = 0; ui < pv->nups; ++ui) - { - Instruction uinsn = *pc++; - LUAU_ASSERT(LUAU_INSN_OP(uinsn) == LOP_CAPTURE); - - switch (LUAU_INSN_A(uinsn)) - { - case LCT_VAL: - setobj(L, &ncl->l.uprefs[ui], VM_REG(LUAU_INSN_B(uinsn))); - break; - - case LCT_REF: - setupvalue(L, &ncl->l.uprefs[ui], luaF_findupval(L, VM_REG(LUAU_INSN_B(uinsn)))); - break; - - case LCT_UPVAL: - setobj(L, &ncl->l.uprefs[ui], VM_UV(LUAU_INSN_B(uinsn))); - break; - - default: - LUAU_ASSERT(!"Unknown upvalue capture type"); - LUAU_UNREACHABLE(); // improves switch() codegen by eliding opcode bounds checks - } - } - - VM_PROTECT(luaC_checkGC(L)); - return pc; -} - -const Instruction* execute_LOP_NAMECALL(lua_State* L, const Instruction* pc, StkId base, TValue* k) -{ - [[maybe_unused]] Closure* cl = clvalue(L->ci->func); - Instruction insn = *pc++; - StkId ra = VM_REG(LUAU_INSN_A(insn)); - StkId rb = VM_REG(LUAU_INSN_B(insn)); - uint32_t aux = *pc++; - TValue* kv = VM_KV(aux); - LUAU_ASSERT(ttisstring(kv)); - - if (ttistable(rb)) - { - Table* h = hvalue(rb); - // note: we can't use nodemask8 here because we need to query the main position of the table, and 8-bit nodemask8 only works - // for predictive lookups - LuaNode* n = &h->node[tsvalue(kv)->hash & (sizenode(h) - 1)]; - - const TValue* mt = 0; - const LuaNode* mtn = 0; - - // fast-path: key is in the table in expected slot - if (ttisstring(gkey(n)) && tsvalue(gkey(n)) == tsvalue(kv) && !ttisnil(gval(n))) - { - // note: order of copies allows rb to alias ra+1 or ra - setobj2s(L, ra + 1, rb); - setobj2s(L, ra, gval(n)); - } - // fast-path: key is absent from the base, table has an __index table, and it has the result in the expected slot - else if (gnext(n) == 0 && (mt = fasttm(L, hvalue(rb)->metatable, TM_INDEX)) && ttistable(mt) && - (mtn = &hvalue(mt)->node[LUAU_INSN_C(insn) & hvalue(mt)->nodemask8]) && ttisstring(gkey(mtn)) && tsvalue(gkey(mtn)) == tsvalue(kv) && - !ttisnil(gval(mtn))) - { - // note: order of copies allows rb to alias ra+1 or ra - setobj2s(L, ra + 1, rb); - setobj2s(L, ra, gval(mtn)); - } - else - { - // slow-path: handles full table lookup - setobj2s(L, ra + 1, rb); - L->cachedslot = LUAU_INSN_C(insn); - VM_PROTECT(luaV_gettable(L, rb, kv, ra)); - // save cachedslot to accelerate future lookups; patches currently executing instruction since pc-2 rolls back two pc++ - VM_PATCH_C(pc - 2, L->cachedslot); - // recompute ra since stack might have been reallocated - ra = VM_REG(LUAU_INSN_A(insn)); - if (ttisnil(ra)) - luaG_methoderror(L, ra + 1, tsvalue(kv)); - } - } - else - { - Table* mt = ttisuserdata(rb) ? uvalue(rb)->metatable : L->global->mt[ttype(rb)]; - const TValue* tmi = 0; - - // fast-path: metatable with __namecall - if (const TValue* fn = fasttm(L, mt, TM_NAMECALL)) - { - // note: order of copies allows rb to alias ra+1 or ra - setobj2s(L, ra + 1, rb); - setobj2s(L, ra, fn); - - L->namecall = tsvalue(kv); - } - else if ((tmi = fasttm(L, mt, TM_INDEX)) && ttistable(tmi)) - { - Table* h = hvalue(tmi); - int slot = LUAU_INSN_C(insn) & h->nodemask8; - LuaNode* n = &h->node[slot]; - - // fast-path: metatable with __index that has method in expected slot - if (LUAU_LIKELY(ttisstring(gkey(n)) && tsvalue(gkey(n)) == tsvalue(kv) && !ttisnil(gval(n)))) - { - // note: order of copies allows rb to alias ra+1 or ra - setobj2s(L, ra + 1, rb); - setobj2s(L, ra, gval(n)); - } - else - { - // slow-path: handles slot mismatch - setobj2s(L, ra + 1, rb); - L->cachedslot = slot; - VM_PROTECT(luaV_gettable(L, rb, kv, ra)); - // save cachedslot to accelerate future lookups; patches currently executing instruction since pc-2 rolls back two pc++ - VM_PATCH_C(pc - 2, L->cachedslot); - // recompute ra since stack might have been reallocated - ra = VM_REG(LUAU_INSN_A(insn)); - if (ttisnil(ra)) - luaG_methoderror(L, ra + 1, tsvalue(kv)); - } - } - else - { - // slow-path: handles non-table __index - setobj2s(L, ra + 1, rb); - VM_PROTECT(luaV_gettable(L, rb, kv, ra)); - // recompute ra since stack might have been reallocated - ra = VM_REG(LUAU_INSN_A(insn)); - if (ttisnil(ra)) - luaG_methoderror(L, ra + 1, tsvalue(kv)); - } - } - - // intentional fallthrough to CALL - LUAU_ASSERT(LUAU_INSN_OP(*pc) == LOP_CALL); - return pc; -} - -const Instruction* execute_LOP_SETLIST(lua_State* L, const Instruction* pc, StkId base, TValue* k) -{ - [[maybe_unused]] Closure* cl = clvalue(L->ci->func); - Instruction insn = *pc++; - StkId ra = VM_REG(LUAU_INSN_A(insn)); - StkId rb = &base[LUAU_INSN_B(insn)]; // note: this can point to L->top if c == LUA_MULTRET making VM_REG unsafe to use - int c = LUAU_INSN_C(insn) - 1; - uint32_t index = *pc++; - - if (c == LUA_MULTRET) - { - c = int(L->top - rb); - L->top = L->ci->top; - } - - Table* h = hvalue(ra); - - // TODO: we really don't need this anymore - if (!ttistable(ra)) - return NULL; // temporary workaround to weaken a rather powerful exploitation primitive in case of a MITM attack on bytecode - - int last = index + c - 1; - if (last > h->sizearray) - { - VM_PROTECT_PC(); // luaH_resizearray may fail due to OOM - - luaH_resizearray(L, h, last); - } - - TValue* array = h->array; - - for (int i = 0; i < c; ++i) - setobj2t(L, &array[index + i - 1], rb + i); - - luaC_barrierfast(L, h); - return pc; -} - -const Instruction* execute_LOP_FORGPREP(lua_State* L, const Instruction* pc, StkId base, TValue* k) -{ - [[maybe_unused]] Closure* cl = clvalue(L->ci->func); - Instruction insn = *pc++; - StkId ra = VM_REG(LUAU_INSN_A(insn)); - - if (ttisfunction(ra)) - { - // will be called during FORGLOOP - } - else - { - Table* mt = ttistable(ra) ? hvalue(ra)->metatable : ttisuserdata(ra) ? uvalue(ra)->metatable : cast_to(Table*, NULL); - - if (const TValue* fn = fasttm(L, mt, TM_ITER)) - { - setobj2s(L, ra + 1, ra); - setobj2s(L, ra, fn); - - L->top = ra + 2; // func + self arg - LUAU_ASSERT(L->top <= L->stack_last); - - VM_PROTECT(luaD_call(L, ra, 3)); - L->top = L->ci->top; - - // recompute ra since stack might have been reallocated - ra = VM_REG(LUAU_INSN_A(insn)); - - // protect against __iter returning nil, since nil is used as a marker for builtin iteration in FORGLOOP - if (ttisnil(ra)) - { - VM_PROTECT_PC(); // next call always errors - luaG_typeerror(L, ra, "call"); - } - } - else if (fasttm(L, mt, TM_CALL)) - { - // table or userdata with __call, will be called during FORGLOOP - // TODO: we might be able to stop supporting this depending on whether it's used in practice - } - else if (ttistable(ra)) - { - // set up registers for builtin iteration - setobj2s(L, ra + 1, ra); - setpvalue(ra + 2, reinterpret_cast(uintptr_t(0))); - setnilvalue(ra); - } - else - { - VM_PROTECT_PC(); // next call always errors - luaG_typeerror(L, ra, "iterate over"); - } - } - - pc += LUAU_INSN_D(insn); - LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); - return pc; -} - -const Instruction* execute_LOP_GETVARARGS(lua_State* L, const Instruction* pc, StkId base, TValue* k) -{ - [[maybe_unused]] Closure* cl = clvalue(L->ci->func); - Instruction insn = *pc++; - int b = LUAU_INSN_B(insn) - 1; - int n = cast_int(base - L->ci->func) - cl->l.p->numparams - 1; - - if (b == LUA_MULTRET) - { - VM_PROTECT(luaD_checkstack(L, n)); - StkId ra = VM_REG(LUAU_INSN_A(insn)); // previous call may change the stack - - for (int j = 0; j < n; j++) - setobj2s(L, ra + j, base - n + j); - - L->top = ra + n; - return pc; - } - else - { - StkId ra = VM_REG(LUAU_INSN_A(insn)); - - for (int j = 0; j < b && j < n; j++) - setobj2s(L, ra + j, base - n + j); - for (int j = n; j < b; j++) - setnilvalue(ra + j); - return pc; - } -} - -const Instruction* execute_LOP_DUPCLOSURE(lua_State* L, const Instruction* pc, StkId base, TValue* k) -{ - [[maybe_unused]] Closure* cl = clvalue(L->ci->func); - Instruction insn = *pc++; - StkId ra = VM_REG(LUAU_INSN_A(insn)); - TValue* kv = VM_KV(LUAU_INSN_D(insn)); - - Closure* kcl = clvalue(kv); - - VM_PROTECT_PC(); // luaF_newLclosure may fail due to OOM - - // clone closure if the environment is not shared - // note: we save closure to stack early in case the code below wants to capture it by value - Closure* ncl = (kcl->env == cl->env) ? kcl : luaF_newLclosure(L, kcl->nupvalues, cl->env, kcl->l.p); - setclvalue(L, ra, ncl); - - // this loop does three things: - // - if the closure was created anew, it just fills it with upvalues - // - if the closure from the constant table is used, it fills it with upvalues so that it can be shared in the future - // - if the closure is reused, it checks if the reuse is safe via rawequal, and falls back to duplicating the closure - // normally this would use two separate loops, for reuse check and upvalue setup, but MSVC codegen goes crazy if you do that - for (int ui = 0; ui < kcl->nupvalues; ++ui) - { - Instruction uinsn = pc[ui]; - LUAU_ASSERT(LUAU_INSN_OP(uinsn) == LOP_CAPTURE); - LUAU_ASSERT(LUAU_INSN_A(uinsn) == LCT_VAL || LUAU_INSN_A(uinsn) == LCT_UPVAL); - - TValue* uv = (LUAU_INSN_A(uinsn) == LCT_VAL) ? VM_REG(LUAU_INSN_B(uinsn)) : VM_UV(LUAU_INSN_B(uinsn)); - - // check if the existing closure is safe to reuse - if (ncl == kcl && luaO_rawequalObj(&ncl->l.uprefs[ui], uv)) - continue; - - // lazily clone the closure and update the upvalues - if (ncl == kcl && kcl->preload == 0) - { - ncl = luaF_newLclosure(L, kcl->nupvalues, cl->env, kcl->l.p); - setclvalue(L, ra, ncl); - - ui = -1; // restart the loop to fill all upvalues - continue; - } - - // this updates a newly created closure, or an existing closure created during preload, in which case we need a barrier - setobj(L, &ncl->l.uprefs[ui], uv); - luaC_barrier(L, ncl, uv); - } - - // this is a noop if ncl is newly created or shared successfully, but it has to run after the closure is preloaded for the first time - ncl->preload = 0; - - if (kcl != ncl) - VM_PROTECT(luaC_checkGC(L)); - - pc += kcl->nupvalues; - return pc; -} - -const Instruction* execute_LOP_PREPVARARGS(lua_State* L, const Instruction* pc, StkId base, TValue* k) -{ - [[maybe_unused]] Closure* cl = clvalue(L->ci->func); - Instruction insn = *pc++; - int numparams = LUAU_INSN_A(insn); - - // all fixed parameters are copied after the top so we need more stack space - VM_PROTECT(luaD_checkstack(L, cl->stacksize + numparams)); - - // the caller must have filled extra fixed arguments with nil - LUAU_ASSERT(cast_int(L->top - base) >= numparams); - - // move fixed parameters to final position - StkId fixed = base; // first fixed argument - base = L->top; // final position of first argument - - for (int i = 0; i < numparams; ++i) - { - setobj2s(L, base + i, fixed + i); - setnilvalue(fixed + i); - } - - // rewire our stack frame to point to the new base - L->ci->base = base; - L->ci->top = base + cl->stacksize; - - L->base = base; - L->top = L->ci->top; - return pc; -} - -const Instruction* execute_LOP_BREAK(lua_State* L, const Instruction* pc, StkId base, TValue* k) -{ - LUAU_ASSERT(!"Unsupported deprecated opcode"); - LUAU_UNREACHABLE(); -} diff --git a/CodeGen/src/Fallbacks.h b/CodeGen/src/Fallbacks.h deleted file mode 100644 index 0d2d218a0..000000000 --- a/CodeGen/src/Fallbacks.h +++ /dev/null @@ -1,24 +0,0 @@ -// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -// This file was generated by 'tools/lvmexecute_split.py' script, do not modify it by hand -#pragma once - -#include - -struct lua_State; -struct Closure; -typedef uint32_t Instruction; -typedef struct lua_TValue TValue; -typedef TValue* StkId; - -const Instruction* execute_LOP_GETGLOBAL(lua_State* L, const Instruction* pc, StkId base, TValue* k); -const Instruction* execute_LOP_SETGLOBAL(lua_State* L, const Instruction* pc, StkId base, TValue* k); -const Instruction* execute_LOP_GETTABLEKS(lua_State* L, const Instruction* pc, StkId base, TValue* k); -const Instruction* execute_LOP_SETTABLEKS(lua_State* L, const Instruction* pc, StkId base, TValue* k); -const Instruction* execute_LOP_NEWCLOSURE(lua_State* L, const Instruction* pc, StkId base, TValue* k); -const Instruction* execute_LOP_NAMECALL(lua_State* L, const Instruction* pc, StkId base, TValue* k); -const Instruction* execute_LOP_SETLIST(lua_State* L, const Instruction* pc, StkId base, TValue* k); -const Instruction* execute_LOP_FORGPREP(lua_State* L, const Instruction* pc, StkId base, TValue* k); -const Instruction* execute_LOP_GETVARARGS(lua_State* L, const Instruction* pc, StkId base, TValue* k); -const Instruction* execute_LOP_DUPCLOSURE(lua_State* L, const Instruction* pc, StkId base, TValue* k); -const Instruction* execute_LOP_PREPVARARGS(lua_State* L, const Instruction* pc, StkId base, TValue* k); -const Instruction* execute_LOP_BREAK(lua_State* L, const Instruction* pc, StkId base, TValue* k); diff --git a/CodeGen/src/FallbacksProlog.h b/CodeGen/src/FallbacksProlog.h deleted file mode 100644 index bbb06b84b..000000000 --- a/CodeGen/src/FallbacksProlog.h +++ /dev/null @@ -1,56 +0,0 @@ -// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#pragma once - -#include "lvm.h" - -#include "lbuiltins.h" -#include "lbytecode.h" -#include "ldebug.h" -#include "ldo.h" -#include "lfunc.h" -#include "lgc.h" -#include "lmem.h" -#include "lnumutils.h" -#include "lstate.h" -#include "lstring.h" -#include "ltable.h" - -#include - -// All external function calls that can cause stack realloc or Lua calls have to be wrapped in VM_PROTECT -// This makes sure that we save the pc (in case the Lua call needs to generate a backtrace) before the call, -// and restores the stack pointer after in case stack gets reallocated -// Should only be used on the slow paths. -#define VM_PROTECT(x) \ - { \ - L->ci->savedpc = pc; \ - { \ - x; \ - }; \ - base = L->base; \ - } - -// Some external functions can cause an error, but never reallocate the stack; for these, VM_PROTECT_PC() is -// a cheaper version of VM_PROTECT that can be called before the external call. -#define VM_PROTECT_PC() L->ci->savedpc = pc - -#define VM_REG(i) (LUAU_ASSERT(unsigned(i) < unsigned(L->top - base)), &base[i]) -#define VM_KV(i) (LUAU_ASSERT(unsigned(i) < unsigned(cl->l.p->sizek)), &k[i]) -#define VM_UV(i) (LUAU_ASSERT(unsigned(i) < unsigned(cl->nupvalues)), &cl->l.uprefs[i]) - -#define VM_PATCH_C(pc, slot) *const_cast(pc) = ((uint8_t(slot) << 24) | (0x00ffffffu & *(pc))) -#define VM_PATCH_E(pc, slot) *const_cast(pc) = ((uint32_t(slot) << 8) | (0x000000ffu & *(pc))) - -#define VM_INTERRUPT() \ - { \ - void (*interrupt)(lua_State*, int) = L->global->cb.interrupt; \ - if (LUAU_UNLIKELY(!!interrupt)) \ - { /* the interrupt hook is called right before we advance pc */ \ - VM_PROTECT(L->ci->savedpc++; interrupt(L, -1)); \ - if (L->status != 0) \ - { \ - L->ci->savedpc--; \ - return NULL; \ - } \ - } \ - } diff --git a/CodeGen/src/IrLoweringA64.cpp b/CodeGen/src/IrLoweringA64.cpp index 3ac37efcf..711baba68 100644 --- a/CodeGen/src/IrLoweringA64.cpp +++ b/CodeGen/src/IrLoweringA64.cpp @@ -96,14 +96,14 @@ static void emitAddOffset(AssemblyBuilderA64& build, RegisterA64 dst, RegisterA6 } } -static void emitFallback(AssemblyBuilderA64& build, int op, int pcpos) +static void emitFallback(AssemblyBuilderA64& build, int offset, int pcpos) { // fallback(L, instruction, base, k) build.mov(x0, rState); emitAddOffset(build, x1, rCode, pcpos * sizeof(Instruction)); build.mov(x2, rBase); build.mov(x3, rConstants); - build.ldr(x4, mem(rNativeContext, offsetof(NativeContext, fallback) + op * sizeof(FallbackFn))); + build.ldr(x4, mem(rNativeContext, offset)); build.blr(x4); emitUpdateBase(build); @@ -658,30 +658,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) jumpOrFallthrough(blockOp(inst.e), next); break; } - case IrCmd::JUMP_SLOT_MATCH: - { - // TODO: share code with CHECK_SLOT_MATCH - RegisterA64 temp1 = regs.allocTemp(KindA64::x); - RegisterA64 temp1w = castReg(KindA64::w, temp1); - RegisterA64 temp2 = regs.allocTemp(KindA64::x); - - build.ldr(temp1w, mem(regOp(inst.a), offsetof(LuaNode, key) + kOffsetOfTKeyTag)); - build.and_(temp1w, temp1w, kLuaNodeTagMask); - build.cmp(temp1w, LUA_TSTRING); - build.b(ConditionA64::NotEqual, labelOp(inst.d)); - - AddressA64 addr = tempAddr(inst.b, offsetof(TValue, value)); - build.ldr(temp1, mem(regOp(inst.a), offsetof(LuaNode, key.value))); - build.ldr(temp2, addr); - build.cmp(temp1, temp2); - build.b(ConditionA64::NotEqual, labelOp(inst.d)); - - build.ldr(temp1w, mem(regOp(inst.a), offsetof(LuaNode, val.tt))); - LUAU_ASSERT(LUA_TNIL == 0); - build.cbz(temp1w, labelOp(inst.d)); - jumpOrFallthrough(blockOp(inst.c), next); - break; - } + // IrCmd::JUMP_SLOT_MATCH implemented below case IrCmd::TABLE_LEN: { RegisterA64 reg = regOp(inst.a); // note: we need to call regOp before spill so that we don't do redundant reloads @@ -1078,34 +1055,40 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.b(ConditionA64::UnsignedLessEqual, labelOp(inst.c)); break; } + case IrCmd::JUMP_SLOT_MATCH: case IrCmd::CHECK_SLOT_MATCH: { + Label& mismatch = inst.cmd == IrCmd::JUMP_SLOT_MATCH ? labelOp(inst.d) : labelOp(inst.c); + RegisterA64 temp1 = regs.allocTemp(KindA64::x); RegisterA64 temp1w = castReg(KindA64::w, temp1); RegisterA64 temp2 = regs.allocTemp(KindA64::x); - build.ldr(temp1w, mem(regOp(inst.a), offsetof(LuaNode, key) + kOffsetOfTKeyTag)); - build.and_(temp1w, temp1w, kLuaNodeTagMask); - build.cmp(temp1w, LUA_TSTRING); - build.b(ConditionA64::NotEqual, labelOp(inst.c)); + LUAU_ASSERT(offsetof(LuaNode, key.value) == offsetof(LuaNode, key) && kOffsetOfTKeyTagNext >= 8 && kOffsetOfTKeyTagNext < 16); + build.ldp(temp1, temp2, mem(regOp(inst.a), offsetof(LuaNode, key))); // load key.value into temp1 and key.tt (alongside other bits) into temp2 + build.ubfx(temp2, temp2, (kOffsetOfTKeyTagNext - 8) * 8, kTKeyTagBits); // .tt is right before .next, and 8 bytes are skipped by ldp + build.cmp(temp2, LUA_TSTRING); + build.b(ConditionA64::NotEqual, mismatch); AddressA64 addr = tempAddr(inst.b, offsetof(TValue, value)); - build.ldr(temp1, mem(regOp(inst.a), offsetof(LuaNode, key.value))); build.ldr(temp2, addr); build.cmp(temp1, temp2); - build.b(ConditionA64::NotEqual, labelOp(inst.c)); + build.b(ConditionA64::NotEqual, mismatch); build.ldr(temp1w, mem(regOp(inst.a), offsetof(LuaNode, val.tt))); LUAU_ASSERT(LUA_TNIL == 0); - build.cbz(temp1w, labelOp(inst.c)); + build.cbz(temp1w, mismatch); + + if (inst.cmd == IrCmd::JUMP_SLOT_MATCH) + jumpOrFallthrough(blockOp(inst.c), next); break; } case IrCmd::CHECK_NODE_NO_NEXT: { RegisterA64 temp = regs.allocTemp(KindA64::w); - build.ldr(temp, mem(regOp(inst.a), offsetof(LuaNode, key) + kOffsetOfTKeyNext)); - build.lsr(temp, temp, kNextBitOffset); + build.ldr(temp, mem(regOp(inst.a), offsetof(LuaNode, key) + kOffsetOfTKeyTagNext)); + build.lsr(temp, temp, kTKeyTagBits); build.cbnz(temp, labelOp(inst.b)); break; } @@ -1139,6 +1122,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) Label skip; build.ldr(temp1, mem(rState, offsetof(lua_State, global))); + // TODO: totalbytes and GCthreshold loads can be fused with ldp build.ldr(temp2, mem(temp1, offsetof(global_State, totalbytes))); build.ldr(temp1, mem(temp1, offsetof(global_State, GCthreshold))); build.cmp(temp1, temp2); @@ -1265,7 +1249,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) break; case IrCmd::SETLIST: regs.spill(build, index); - emitFallback(build, LOP_SETLIST, uintOp(inst.a)); + emitFallback(build, offsetof(NativeContext, executeSETLIST), uintOp(inst.a)); break; case IrCmd::CALL: regs.spill(build, index); @@ -1368,14 +1352,14 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) LUAU_ASSERT(inst.c.kind == IrOpKind::VmConst); regs.spill(build, index); - emitFallback(build, LOP_GETGLOBAL, uintOp(inst.a)); + emitFallback(build, offsetof(NativeContext, executeGETGLOBAL), uintOp(inst.a)); break; case IrCmd::FALLBACK_SETGLOBAL: LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); LUAU_ASSERT(inst.c.kind == IrOpKind::VmConst); regs.spill(build, index); - emitFallback(build, LOP_SETGLOBAL, uintOp(inst.a)); + emitFallback(build, offsetof(NativeContext, executeSETGLOBAL), uintOp(inst.a)); break; case IrCmd::FALLBACK_GETTABLEKS: LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); @@ -1383,7 +1367,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) LUAU_ASSERT(inst.d.kind == IrOpKind::VmConst); regs.spill(build, index); - emitFallback(build, LOP_GETTABLEKS, uintOp(inst.a)); + emitFallback(build, offsetof(NativeContext, executeGETTABLEKS), uintOp(inst.a)); break; case IrCmd::FALLBACK_SETTABLEKS: LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); @@ -1391,7 +1375,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) LUAU_ASSERT(inst.d.kind == IrOpKind::VmConst); regs.spill(build, index); - emitFallback(build, LOP_SETTABLEKS, uintOp(inst.a)); + emitFallback(build, offsetof(NativeContext, executeSETTABLEKS), uintOp(inst.a)); break; case IrCmd::FALLBACK_NAMECALL: LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); @@ -1399,38 +1383,38 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) LUAU_ASSERT(inst.d.kind == IrOpKind::VmConst); regs.spill(build, index); - emitFallback(build, LOP_NAMECALL, uintOp(inst.a)); + emitFallback(build, offsetof(NativeContext, executeNAMECALL), uintOp(inst.a)); break; case IrCmd::FALLBACK_PREPVARARGS: LUAU_ASSERT(inst.b.kind == IrOpKind::Constant); regs.spill(build, index); - emitFallback(build, LOP_PREPVARARGS, uintOp(inst.a)); + emitFallback(build, offsetof(NativeContext, executePREPVARARGS), uintOp(inst.a)); break; case IrCmd::FALLBACK_GETVARARGS: LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); LUAU_ASSERT(inst.c.kind == IrOpKind::Constant); regs.spill(build, index); - emitFallback(build, LOP_GETVARARGS, uintOp(inst.a)); + emitFallback(build, offsetof(NativeContext, executeGETVARARGS), uintOp(inst.a)); break; case IrCmd::FALLBACK_NEWCLOSURE: LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); LUAU_ASSERT(inst.c.kind == IrOpKind::Constant); regs.spill(build, index); - emitFallback(build, LOP_NEWCLOSURE, uintOp(inst.a)); + emitFallback(build, offsetof(NativeContext, executeNEWCLOSURE), uintOp(inst.a)); break; case IrCmd::FALLBACK_DUPCLOSURE: LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); LUAU_ASSERT(inst.c.kind == IrOpKind::VmConst); regs.spill(build, index); - emitFallback(build, LOP_DUPCLOSURE, uintOp(inst.a)); + emitFallback(build, offsetof(NativeContext, executeDUPCLOSURE), uintOp(inst.a)); break; case IrCmd::FALLBACK_FORGPREP: regs.spill(build, index); - emitFallback(build, LOP_FORGPREP, uintOp(inst.a)); + emitFallback(build, offsetof(NativeContext, executeFORGPREP), uintOp(inst.a)); jumpOrFallthrough(blockOp(inst.c), next); break; diff --git a/CodeGen/src/IrLoweringX64.cpp b/CodeGen/src/IrLoweringX64.cpp index 8c1f2b044..035cc05c6 100644 --- a/CodeGen/src/IrLoweringX64.cpp +++ b/CodeGen/src/IrLoweringX64.cpp @@ -938,8 +938,8 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) { ScopedRegX64 tmp{regs, SizeX64::dword}; - build.mov(tmp.reg, dword[regOp(inst.a) + offsetof(LuaNode, key) + kOffsetOfTKeyNext]); - build.shr(tmp.reg, kNextBitOffset); + build.mov(tmp.reg, dword[regOp(inst.a) + offsetof(LuaNode, key) + kOffsetOfTKeyTagNext]); + build.shr(tmp.reg, kTKeyTagBits); build.jcc(ConditionX64::NotZero, labelOp(inst.b)); break; } @@ -1098,60 +1098,60 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); LUAU_ASSERT(inst.c.kind == IrOpKind::VmConst); - emitFallback(regs, build, data, LOP_GETGLOBAL, uintOp(inst.a)); + emitFallback(regs, build, offsetof(NativeContext, executeGETGLOBAL), uintOp(inst.a)); break; case IrCmd::FALLBACK_SETGLOBAL: LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); LUAU_ASSERT(inst.c.kind == IrOpKind::VmConst); - emitFallback(regs, build, data, LOP_SETGLOBAL, uintOp(inst.a)); + emitFallback(regs, build, offsetof(NativeContext, executeSETGLOBAL), uintOp(inst.a)); break; case IrCmd::FALLBACK_GETTABLEKS: LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); LUAU_ASSERT(inst.c.kind == IrOpKind::VmReg); LUAU_ASSERT(inst.d.kind == IrOpKind::VmConst); - emitFallback(regs, build, data, LOP_GETTABLEKS, uintOp(inst.a)); + emitFallback(regs, build, offsetof(NativeContext, executeGETTABLEKS), uintOp(inst.a)); break; case IrCmd::FALLBACK_SETTABLEKS: LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); LUAU_ASSERT(inst.c.kind == IrOpKind::VmReg); LUAU_ASSERT(inst.d.kind == IrOpKind::VmConst); - emitFallback(regs, build, data, LOP_SETTABLEKS, uintOp(inst.a)); + emitFallback(regs, build, offsetof(NativeContext, executeSETTABLEKS), uintOp(inst.a)); break; case IrCmd::FALLBACK_NAMECALL: LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); LUAU_ASSERT(inst.c.kind == IrOpKind::VmReg); LUAU_ASSERT(inst.d.kind == IrOpKind::VmConst); - emitFallback(regs, build, data, LOP_NAMECALL, uintOp(inst.a)); + emitFallback(regs, build, offsetof(NativeContext, executeNAMECALL), uintOp(inst.a)); break; case IrCmd::FALLBACK_PREPVARARGS: LUAU_ASSERT(inst.b.kind == IrOpKind::Constant); - emitFallback(regs, build, data, LOP_PREPVARARGS, uintOp(inst.a)); + emitFallback(regs, build, offsetof(NativeContext, executePREPVARARGS), uintOp(inst.a)); break; case IrCmd::FALLBACK_GETVARARGS: LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); LUAU_ASSERT(inst.c.kind == IrOpKind::Constant); - emitFallback(regs, build, data, LOP_GETVARARGS, uintOp(inst.a)); + emitFallback(regs, build, offsetof(NativeContext, executeGETVARARGS), uintOp(inst.a)); break; case IrCmd::FALLBACK_NEWCLOSURE: LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); LUAU_ASSERT(inst.c.kind == IrOpKind::Constant); - emitFallback(regs, build, data, LOP_NEWCLOSURE, uintOp(inst.a)); + emitFallback(regs, build, offsetof(NativeContext, executeNEWCLOSURE), uintOp(inst.a)); break; case IrCmd::FALLBACK_DUPCLOSURE: LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); LUAU_ASSERT(inst.c.kind == IrOpKind::VmConst); - emitFallback(regs, build, data, LOP_DUPCLOSURE, uintOp(inst.a)); + emitFallback(regs, build, offsetof(NativeContext, executeDUPCLOSURE), uintOp(inst.a)); break; case IrCmd::FALLBACK_FORGPREP: - emitFallback(regs, build, data, LOP_FORGPREP, uintOp(inst.a)); + emitFallback(regs, build, offsetof(NativeContext, executeFORGPREP), uintOp(inst.a)); jumpOrFallthrough(blockOp(inst.c), next); break; case IrCmd::BITAND_UINT: diff --git a/CodeGen/src/NativeState.cpp b/CodeGen/src/NativeState.cpp index cb128de98..bda468897 100644 --- a/CodeGen/src/NativeState.cpp +++ b/CodeGen/src/NativeState.cpp @@ -5,7 +5,6 @@ #include "CodeGenUtils.h" #include "CustomExecUtils.h" -#include "Fallbacks.h" #include "lbuiltins.h" #include "lgc.h" @@ -16,8 +15,6 @@ #include #include -#define CODEGEN_SET_FALLBACK(op) data.context.fallback[op] = {execute_##op} - namespace Luau { namespace CodeGen @@ -33,27 +30,7 @@ NativeState::NativeState() NativeState::~NativeState() = default; -void initFallbackTable(NativeState& data) -{ - // When fallback is completely removed, remove it from includeInsts list in lvmexecute_split.py - CODEGEN_SET_FALLBACK(LOP_NEWCLOSURE); - CODEGEN_SET_FALLBACK(LOP_NAMECALL); - CODEGEN_SET_FALLBACK(LOP_FORGPREP); - CODEGEN_SET_FALLBACK(LOP_GETVARARGS); - CODEGEN_SET_FALLBACK(LOP_DUPCLOSURE); - CODEGEN_SET_FALLBACK(LOP_PREPVARARGS); - CODEGEN_SET_FALLBACK(LOP_BREAK); - CODEGEN_SET_FALLBACK(LOP_SETLIST); - - // Fallbacks that are called from partial implementation of an instruction - // TODO: these fallbacks should be replaced with special functions that exclude the (redundantly executed) fast path from the fallback - CODEGEN_SET_FALLBACK(LOP_GETGLOBAL); - CODEGEN_SET_FALLBACK(LOP_SETGLOBAL); - CODEGEN_SET_FALLBACK(LOP_GETTABLEKS); - CODEGEN_SET_FALLBACK(LOP_SETTABLEKS); -} - -void initHelperFunctions(NativeState& data) +void initFunctions(NativeState& data) { static_assert(sizeof(data.context.luauF_table) == sizeof(luauF_table), "fastcall tables are not of the same length"); memcpy(data.context.luauF_table, luauF_table, sizeof(luauF_table)); @@ -115,6 +92,19 @@ void initHelperFunctions(NativeState& data) data.context.callFallback = callFallback; data.context.returnFallback = returnFallback; + + data.context.executeGETGLOBAL = executeGETGLOBAL; + data.context.executeSETGLOBAL = executeSETGLOBAL; + data.context.executeGETTABLEKS = executeGETTABLEKS; + data.context.executeSETTABLEKS = executeSETTABLEKS; + + data.context.executeNEWCLOSURE = executeNEWCLOSURE; + data.context.executeNAMECALL = executeNAMECALL; + data.context.executeFORGPREP = executeFORGPREP; + data.context.executeGETVARARGS = executeGETVARARGS; + data.context.executeDUPCLOSURE = executeDUPCLOSURE; + data.context.executePREPVARARGS = executePREPVARARGS; + data.context.executeSETLIST = executeSETLIST; } } // namespace CodeGen diff --git a/CodeGen/src/NativeState.h b/CodeGen/src/NativeState.h index eb1d97a50..40017e359 100644 --- a/CodeGen/src/NativeState.h +++ b/CodeGen/src/NativeState.h @@ -23,19 +23,6 @@ namespace CodeGen class UnwindBuilder; -using FallbackFn = const Instruction* (*)(lua_State* L, const Instruction* pc, StkId base, TValue* k); - -struct NativeProto -{ - // This array is stored before NativeProto in reverse order, so to get offset of instruction i you need to index instOffsets[-i] - // This awkward layout is helpful for maximally efficient address computation on X64/A64 - uint32_t instOffsets[1]; - - uintptr_t instBase = 0; - uintptr_t entryTarget = 0; // = instOffsets[0] + instBase - Proto* proto = nullptr; -}; - struct NativeContext { // Gateway (C => native transition) entry & exit, compiled at runtime @@ -102,7 +89,17 @@ struct NativeContext Closure* (*returnFallback)(lua_State* L, StkId ra, StkId valend) = nullptr; // Opcode fallbacks, implemented in C - FallbackFn fallback[LOP__COUNT] = {}; + const Instruction* (*executeGETGLOBAL)(lua_State* L, const Instruction* pc, StkId base, TValue* k) = nullptr; + const Instruction* (*executeSETGLOBAL)(lua_State* L, const Instruction* pc, StkId base, TValue* k) = nullptr; + const Instruction* (*executeGETTABLEKS)(lua_State* L, const Instruction* pc, StkId base, TValue* k) = nullptr; + const Instruction* (*executeSETTABLEKS)(lua_State* L, const Instruction* pc, StkId base, TValue* k) = nullptr; + const Instruction* (*executeNEWCLOSURE)(lua_State* L, const Instruction* pc, StkId base, TValue* k) = nullptr; + const Instruction* (*executeNAMECALL)(lua_State* L, const Instruction* pc, StkId base, TValue* k) = nullptr; + const Instruction* (*executeSETLIST)(lua_State* L, const Instruction* pc, StkId base, TValue* k) = nullptr; + const Instruction* (*executeFORGPREP)(lua_State* L, const Instruction* pc, StkId base, TValue* k) = nullptr; + const Instruction* (*executeGETVARARGS)(lua_State* L, const Instruction* pc, StkId base, TValue* k) = nullptr; + const Instruction* (*executeDUPCLOSURE)(lua_State* L, const Instruction* pc, StkId base, TValue* k) = nullptr; + const Instruction* (*executePREPVARARGS)(lua_State* L, const Instruction* pc, StkId base, TValue* k) = nullptr; // Fast call methods, implemented in C luau_FastFunction luauF_table[256] = {}; @@ -124,8 +121,7 @@ struct NativeState NativeContext context; }; -void initFallbackTable(NativeState& data); -void initHelperFunctions(NativeState& data); +void initFunctions(NativeState& data); } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/OptimizeConstProp.cpp b/CodeGen/src/OptimizeConstProp.cpp index 926ead3d7..8bb3cd7b7 100644 --- a/CodeGen/src/OptimizeConstProp.cpp +++ b/CodeGen/src/OptimizeConstProp.cpp @@ -714,10 +714,23 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& case IrCmd::DUP_TABLE: case IrCmd::TRY_NUM_TO_INDEX: case IrCmd::TRY_CALL_FASTGETTM: + break; case IrCmd::INT_TO_NUM: case IrCmd::UINT_TO_NUM: + state.substituteOrRecord(inst, index); + break; case IrCmd::NUM_TO_INT: + if (IrInst* src = function.asInstOp(inst.a); src && src->cmd == IrCmd::INT_TO_NUM) + substitute(function, inst, src->a); + else + state.substituteOrRecord(inst, index); + break; case IrCmd::NUM_TO_UINT: + if (IrInst* src = function.asInstOp(inst.a); src && src->cmd == IrCmd::UINT_TO_NUM) + substitute(function, inst, src->a); + else + state.substituteOrRecord(inst, index); + break; case IrCmd::CHECK_ARRAY_SIZE: case IrCmd::CHECK_SLOT_MATCH: case IrCmd::CHECK_NODE_NO_NEXT: diff --git a/CodeGen/src/lcodegen.cpp b/CodeGen/src/lcodegen.cpp new file mode 100644 index 000000000..0795cd48d --- /dev/null +++ b/CodeGen/src/lcodegen.cpp @@ -0,0 +1,21 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "luacodegen.h" + +#include "Luau/CodeGen.h" + +#include "lapi.h" + +int luau_codegen_supported() +{ + return Luau::CodeGen::isSupported(); +} + +void luau_codegen_create(lua_State* L) +{ + Luau::CodeGen::create(L); +} + +void luau_codegen_compile(lua_State* L, int idx) +{ + Luau::CodeGen::compile(L, idx); +} diff --git a/Makefile b/Makefile index aead3d32d..99eb93e6c 100644 --- a/Makefile +++ b/Makefile @@ -136,6 +136,7 @@ $(FUZZ_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -ICompiler/ $(TESTS_TARGET): LDFLAGS+=-lpthread $(REPL_CLI_TARGET): LDFLAGS+=-lpthread +$(ANALYZE_CLI_TARGET): LDFLAGS+=-lpthread fuzz-proto fuzz-prototest: LDFLAGS+=build/libprotobuf-mutator/src/libfuzzer/libprotobuf-mutator-libfuzzer.a build/libprotobuf-mutator/src/libprotobuf-mutator.a $(LPROTOBUF) # pseudo targets diff --git a/Sources.cmake b/Sources.cmake index 6beb02cb1..892b889bb 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -79,6 +79,7 @@ target_sources(Luau.CodeGen PRIVATE CodeGen/include/Luau/UnwindBuilder.h CodeGen/include/Luau/UnwindBuilderDwarf2.h CodeGen/include/Luau/UnwindBuilderWin.h + CodeGen/include/luacodegen.h CodeGen/src/AssemblyBuilderA64.cpp CodeGen/src/AssemblyBuilderX64.cpp @@ -91,7 +92,6 @@ target_sources(Luau.CodeGen PRIVATE CodeGen/src/EmitBuiltinsX64.cpp CodeGen/src/EmitCommonX64.cpp CodeGen/src/EmitInstructionX64.cpp - CodeGen/src/Fallbacks.cpp CodeGen/src/IrAnalysis.cpp CodeGen/src/IrBuilder.cpp CodeGen/src/IrCallWrapperX64.cpp @@ -104,6 +104,7 @@ target_sources(Luau.CodeGen PRIVATE CodeGen/src/IrTranslation.cpp CodeGen/src/IrUtils.cpp CodeGen/src/IrValueLocationTracking.cpp + CodeGen/src/lcodegen.cpp CodeGen/src/NativeState.cpp CodeGen/src/OptimizeConstProp.cpp CodeGen/src/OptimizeFinalX64.cpp @@ -121,8 +122,6 @@ target_sources(Luau.CodeGen PRIVATE CodeGen/src/EmitCommonA64.h CodeGen/src/EmitCommonX64.h CodeGen/src/EmitInstructionX64.h - CodeGen/src/Fallbacks.h - CodeGen/src/FallbacksProlog.h CodeGen/src/IrLoweringA64.h CodeGen/src/IrLoweringX64.h CodeGen/src/IrRegAllocA64.h @@ -169,6 +168,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/include/Luau/RecursionCounter.h Analysis/include/Luau/RequireTracer.h Analysis/include/Luau/Scope.h + Analysis/include/Luau/Simplify.h Analysis/include/Luau/Substitution.h Analysis/include/Luau/Symbol.h Analysis/include/Luau/ToDot.h @@ -183,7 +183,6 @@ target_sources(Luau.Analysis PRIVATE Analysis/include/Luau/TypeFamily.h Analysis/include/Luau/TypeInfer.h Analysis/include/Luau/TypePack.h - Analysis/include/Luau/TypeReduction.h Analysis/include/Luau/TypeUtils.h Analysis/include/Luau/Type.h Analysis/include/Luau/Unifiable.h @@ -220,6 +219,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/src/Quantify.cpp Analysis/src/RequireTracer.cpp Analysis/src/Scope.cpp + Analysis/src/Simplify.cpp Analysis/src/Substitution.cpp Analysis/src/Symbol.cpp Analysis/src/ToDot.cpp @@ -234,7 +234,6 @@ target_sources(Luau.Analysis PRIVATE Analysis/src/TypeFamily.cpp Analysis/src/TypeInfer.cpp Analysis/src/TypePack.cpp - Analysis/src/TypeReduction.cpp Analysis/src/TypeUtils.cpp Analysis/src/Type.cpp Analysis/src/Unifiable.cpp @@ -378,6 +377,7 @@ if(TARGET Luau.UnitTest) tests/Parser.test.cpp tests/RequireTracer.test.cpp tests/RuntimeLimits.test.cpp + tests/Simplify.test.cpp tests/StringUtils.test.cpp tests/Symbol.test.cpp tests/ToDot.test.cpp @@ -412,7 +412,6 @@ if(TARGET Luau.UnitTest) tests/TypeInfer.unionTypes.test.cpp tests/TypeInfer.unknownnever.test.cpp tests/TypePack.test.cpp - tests/TypeReduction.test.cpp tests/TypeVar.test.cpp tests/Variant.test.cpp tests/VisitType.test.cpp diff --git a/VM/include/luaconf.h b/VM/include/luaconf.h index 5d6b760eb..2045768a3 100644 --- a/VM/include/luaconf.h +++ b/VM/include/luaconf.h @@ -143,6 +143,8 @@ long l; \ } +#ifndef LUA_VECTOR_SIZE #define LUA_VECTOR_SIZE 3 // must be 3 or 4 +#endif #define LUA_EXTRA_SIZE (LUA_VECTOR_SIZE - 2) diff --git a/VM/src/ldo.cpp b/VM/src/ldo.cpp index 0f4df6719..7f58d9635 100644 --- a/VM/src/ldo.cpp +++ b/VM/src/ldo.cpp @@ -17,6 +17,8 @@ #include +LUAU_FASTFLAGVARIABLE(LuauUniformTopHandling, false) + /* ** {====================================================== ** Error-recovery functions @@ -229,12 +231,14 @@ void luaD_checkCstack(lua_State* L) ** When returns, all the results are on the stack, starting at the original ** function position. */ -void luaD_call(lua_State* L, StkId func, int nResults) +void luaD_call(lua_State* L, StkId func, int nresults) { if (++L->nCcalls >= LUAI_MAXCCALLS) luaD_checkCstack(L); - if (luau_precall(L, func, nResults) == PCRLUA) + ptrdiff_t old_func = savestack(L, func); + + if (luau_precall(L, func, nresults) == PCRLUA) { // is a Lua function? L->ci->flags |= LUA_CALLINFO_RETURN; // luau_execute will stop after returning from the stack frame @@ -248,6 +252,9 @@ void luaD_call(lua_State* L, StkId func, int nResults) L->isactive = false; } + if (FFlag::LuauUniformTopHandling && nresults != LUA_MULTRET) + L->top = restorestack(L, old_func) + nresults; + L->nCcalls--; luaC_checkGC(L); } diff --git a/VM/src/ldo.h b/VM/src/ldo.h index eac9927c2..0f7b42ad4 100644 --- a/VM/src/ldo.h +++ b/VM/src/ldo.h @@ -44,7 +44,7 @@ typedef void (*Pfunc)(lua_State* L, void* ud); LUAI_FUNC CallInfo* luaD_growCI(lua_State* L); -LUAI_FUNC void luaD_call(lua_State* L, StkId func, int nResults); +LUAI_FUNC void luaD_call(lua_State* L, StkId func, int nresults); LUAI_FUNC int luaD_pcall(lua_State* L, Pfunc func, void* u, ptrdiff_t oldtop, ptrdiff_t ef); LUAI_FUNC void luaD_reallocCI(lua_State* L, int newsize); LUAI_FUNC void luaD_reallocstack(lua_State* L, int newsize); diff --git a/VM/src/lfunc.cpp b/VM/src/lfunc.cpp index 2230a748f..569c1b4e5 100644 --- a/VM/src/lfunc.cpp +++ b/VM/src/lfunc.cpp @@ -32,9 +32,8 @@ Proto* luaF_newproto(lua_State* L) f->debugname = NULL; f->debuginsn = NULL; -#if LUA_CUSTOM_EXECUTION f->execdata = NULL; -#endif + f->exectarget = 0; return f; } diff --git a/VM/src/lobject.h b/VM/src/lobject.h index f0471c255..21b8de018 100644 --- a/VM/src/lobject.h +++ b/VM/src/lobject.h @@ -275,9 +275,8 @@ typedef struct Proto TString* debugname; uint8_t* debuginsn; // a copy of code[] array with just opcodes -#if LUA_CUSTOM_EXECUTION void* execdata; -#endif + uintptr_t exectarget; GCObject* gclist; diff --git a/VM/src/lstate.h b/VM/src/lstate.h index 32a240bfb..ae1e18664 100644 --- a/VM/src/lstate.h +++ b/VM/src/lstate.h @@ -69,6 +69,7 @@ typedef struct CallInfo #define LUA_CALLINFO_RETURN (1 << 0) // should the interpreter return after returning from this callinfo? first frame must have this set #define LUA_CALLINFO_HANDLE (1 << 1) // should the error thrown during execution get handled by continuation from this callinfo? func must be C +#define LUA_CALLINFO_CUSTOM (1 << 2) // should this function be executed using custom execution callback #define curr_func(L) (clvalue(L->ci->func)) #define ci_func(ci) (clvalue((ci)->func)) diff --git a/VM/src/lvmexecute.cpp b/VM/src/lvmexecute.cpp index 5565bfefc..454a4e178 100644 --- a/VM/src/lvmexecute.cpp +++ b/VM/src/lvmexecute.cpp @@ -16,6 +16,8 @@ #include +LUAU_FASTFLAG(LuauUniformTopHandling) + // Disable c99-designator to avoid the warning in CGOTO dispatch table #ifdef __clang__ #if __has_warning("-Wc99-designator") @@ -208,10 +210,11 @@ static void luau_execute(lua_State* L) LUAU_ASSERT(!isblack(obj2gco(L))); // we don't use luaC_threadbarrier because active threads never turn black #if LUA_CUSTOM_EXECUTION - Proto* p = clvalue(L->ci->func)->l.p; - - if (p->execdata && !SingleStep) + if ((L->ci->flags & LUA_CALLINFO_CUSTOM) && !SingleStep) { + Proto* p = clvalue(L->ci->func)->l.p; + LUAU_ASSERT(p->execdata); + if (L->global->ecb.enter(L, p) == 0) return; } @@ -448,7 +451,7 @@ static void luau_execute(lua_State* L) LUAU_ASSERT(ttisstring(kv)); // fast-path: built-in table - if (ttistable(rb)) + if (LUAU_LIKELY(ttistable(rb))) { Table* h = hvalue(rb); @@ -565,7 +568,7 @@ static void luau_execute(lua_State* L) LUAU_ASSERT(ttisstring(kv)); // fast-path: built-in table - if (ttistable(rb)) + if (LUAU_LIKELY(ttistable(rb))) { Table* h = hvalue(rb); @@ -801,7 +804,7 @@ static void luau_execute(lua_State* L) TValue* kv = VM_KV(aux); LUAU_ASSERT(ttisstring(kv)); - if (ttistable(rb)) + if (LUAU_LIKELY(ttistable(rb))) { Table* h = hvalue(rb); // note: we can't use nodemask8 here because we need to query the main position of the table, and 8-bit nodemask8 only works @@ -954,6 +957,7 @@ static void luau_execute(lua_State* L) #if LUA_CUSTOM_EXECUTION if (LUAU_UNLIKELY(p->execdata && !SingleStep)) { + ci->flags = LUA_CALLINFO_CUSTOM; ci->savedpc = p->code; if (L->global->ecb.enter(L, p) == 1) @@ -1040,7 +1044,8 @@ static void luau_execute(lua_State* L) // we're done! if (LUAU_UNLIKELY(ci->flags & LUA_CALLINFO_RETURN)) { - L->top = res; + if (!FFlag::LuauUniformTopHandling) + L->top = res; goto exit; } @@ -1050,7 +1055,7 @@ static void luau_execute(lua_State* L) Proto* nextproto = nextcl->l.p; #if LUA_CUSTOM_EXECUTION - if (LUAU_UNLIKELY(nextproto->execdata && !SingleStep)) + if (LUAU_UNLIKELY((cip->flags & LUA_CALLINFO_CUSTOM) && !SingleStep)) { if (L->global->ecb.enter(L, nextproto) == 1) goto reentry; @@ -1333,7 +1338,7 @@ static void luau_execute(lua_State* L) // fast-path: number // Note that all jumps below jump by 1 in the "false" case to skip over aux - if (ttisnumber(ra) && ttisnumber(rb)) + if (LUAU_LIKELY(ttisnumber(ra) && ttisnumber(rb))) { pc += nvalue(ra) <= nvalue(rb) ? LUAU_INSN_D(insn) : 1; LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); @@ -1366,7 +1371,7 @@ static void luau_execute(lua_State* L) // fast-path: number // Note that all jumps below jump by 1 in the "true" case to skip over aux - if (ttisnumber(ra) && ttisnumber(rb)) + if (LUAU_LIKELY(ttisnumber(ra) && ttisnumber(rb))) { pc += !(nvalue(ra) <= nvalue(rb)) ? LUAU_INSN_D(insn) : 1; LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); @@ -1399,7 +1404,7 @@ static void luau_execute(lua_State* L) // fast-path: number // Note that all jumps below jump by 1 in the "false" case to skip over aux - if (ttisnumber(ra) && ttisnumber(rb)) + if (LUAU_LIKELY(ttisnumber(ra) && ttisnumber(rb))) { pc += nvalue(ra) < nvalue(rb) ? LUAU_INSN_D(insn) : 1; LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); @@ -1432,7 +1437,7 @@ static void luau_execute(lua_State* L) // fast-path: number // Note that all jumps below jump by 1 in the "true" case to skip over aux - if (ttisnumber(ra) && ttisnumber(rb)) + if (LUAU_LIKELY(ttisnumber(ra) && ttisnumber(rb))) { pc += !(nvalue(ra) < nvalue(rb)) ? LUAU_INSN_D(insn) : 1; LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); @@ -1464,7 +1469,7 @@ static void luau_execute(lua_State* L) StkId rc = VM_REG(LUAU_INSN_C(insn)); // fast-path - if (ttisnumber(rb) && ttisnumber(rc)) + if (LUAU_LIKELY(ttisnumber(rb) && ttisnumber(rc))) { setnvalue(ra, nvalue(rb) + nvalue(rc)); VM_NEXT(); @@ -1510,7 +1515,7 @@ static void luau_execute(lua_State* L) StkId rc = VM_REG(LUAU_INSN_C(insn)); // fast-path - if (ttisnumber(rb) && ttisnumber(rc)) + if (LUAU_LIKELY(ttisnumber(rb) && ttisnumber(rc))) { setnvalue(ra, nvalue(rb) - nvalue(rc)); VM_NEXT(); @@ -1556,7 +1561,7 @@ static void luau_execute(lua_State* L) StkId rc = VM_REG(LUAU_INSN_C(insn)); // fast-path - if (ttisnumber(rb) && ttisnumber(rc)) + if (LUAU_LIKELY(ttisnumber(rb) && ttisnumber(rc))) { setnvalue(ra, nvalue(rb) * nvalue(rc)); VM_NEXT(); @@ -1617,7 +1622,7 @@ static void luau_execute(lua_State* L) StkId rc = VM_REG(LUAU_INSN_C(insn)); // fast-path - if (ttisnumber(rb) && ttisnumber(rc)) + if (LUAU_LIKELY(ttisnumber(rb) && ttisnumber(rc))) { setnvalue(ra, nvalue(rb) / nvalue(rc)); VM_NEXT(); @@ -1764,7 +1769,7 @@ static void luau_execute(lua_State* L) TValue* kv = VM_KV(LUAU_INSN_C(insn)); // fast-path - if (ttisnumber(rb)) + if (LUAU_LIKELY(ttisnumber(rb))) { setnvalue(ra, nvalue(rb) * nvalue(kv)); VM_NEXT(); @@ -1810,7 +1815,7 @@ static void luau_execute(lua_State* L) TValue* kv = VM_KV(LUAU_INSN_C(insn)); // fast-path - if (ttisnumber(rb)) + if (LUAU_LIKELY(ttisnumber(rb))) { setnvalue(ra, nvalue(rb) / nvalue(kv)); VM_NEXT(); @@ -1976,7 +1981,7 @@ static void luau_execute(lua_State* L) StkId rb = VM_REG(LUAU_INSN_B(insn)); // fast-path - if (ttisnumber(rb)) + if (LUAU_LIKELY(ttisnumber(rb))) { setnvalue(ra, -nvalue(rb)); VM_NEXT(); @@ -2019,7 +2024,7 @@ static void luau_execute(lua_State* L) StkId rb = VM_REG(LUAU_INSN_B(insn)); // fast-path #1: tables - if (ttistable(rb)) + if (LUAU_LIKELY(ttistable(rb))) { Table* h = hvalue(rb); @@ -2878,14 +2883,21 @@ int luau_precall(lua_State* L, StkId func, int nresults) if (!ccl->isC) { + Proto* p = ccl->l.p; + // fill unused parameters with nil StkId argi = L->top; - StkId argend = L->base + ccl->l.p->numparams; + StkId argend = L->base + p->numparams; while (argi < argend) setnilvalue(argi++); // complete missing arguments - L->top = ccl->l.p->is_vararg ? argi : ci->top; + L->top = p->is_vararg ? argi : ci->top; + + ci->savedpc = p->code; - L->ci->savedpc = ccl->l.p->code; +#if LUA_CUSTOM_EXECUTION + if (p->execdata) + ci->flags = LUA_CALLINFO_CUSTOM; +#endif return PCRLUA; } diff --git a/tests/AssemblyBuilderA64.test.cpp b/tests/AssemblyBuilderA64.test.cpp index 3827681de..cdadfd76b 100644 --- a/tests/AssemblyBuilderA64.test.cpp +++ b/tests/AssemblyBuilderA64.test.cpp @@ -135,6 +135,19 @@ TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "BinaryImm") SINGLE_COMPARE(ror(x1, x2, 1), 0x93C20441); } +TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "Bitfield") +{ + SINGLE_COMPARE(ubfiz(x1, x2, 37, 5), 0xD35B1041); + SINGLE_COMPARE(ubfx(x1, x2, 37, 5), 0xD365A441); + SINGLE_COMPARE(sbfiz(x1, x2, 37, 5), 0x935B1041); + SINGLE_COMPARE(sbfx(x1, x2, 37, 5), 0x9365A441); + + SINGLE_COMPARE(ubfiz(w1, w2, 17, 5), 0x530F1041); + SINGLE_COMPARE(ubfx(w1, w2, 17, 5), 0x53115441); + SINGLE_COMPARE(sbfiz(w1, w2, 17, 5), 0x130F1041); + SINGLE_COMPARE(sbfx(w1, w2, 17, 5), 0x13115441); +} + TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "Loads") { // address forms @@ -481,6 +494,8 @@ TEST_CASE("LogTest") build.fcvt(s1, d2); + build.ubfx(x1, x2, 37, 5); + build.setLabel(l); build.ret(); @@ -513,6 +528,7 @@ TEST_CASE("LogTest") fmov d0,#0.25 tbz x0,#5,.L1 fcvt s1,d2 + ubfx x1,x2,#3705 .L1: ret )"; diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index cf92843de..d66eb18e8 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -3388,38 +3388,6 @@ TEST_CASE_FIXTURE(ACFixture, "globals_are_order_independent") CHECK(ac.entryMap.count("abc1")); } -TEST_CASE_FIXTURE(ACFixture, "type_reduction_is_hooked_up_to_autocomplete") -{ - ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true}; - - check(R"( - type T = { x: (number & string)? } - - function f(thingamabob: T) - thingamabob.@1 - end - - function g(thingamabob: T) - thingama@2 - end - )"); - - ToStringOptions opts; - opts.exhaustive = true; - - auto ac1 = autocomplete('1'); - REQUIRE(ac1.entryMap.count("x")); - std::optional ty1 = ac1.entryMap.at("x").type; - REQUIRE(ty1); - CHECK("nil" == toString(*ty1, opts)); - - auto ac2 = autocomplete('2'); - REQUIRE(ac2.entryMap.count("thingamabob")); - std::optional ty2 = ac2.entryMap.at("thingamabob").type; - REQUIRE(ty2); - CHECK("{| x: nil |}" == toString(*ty2, opts)); -} - TEST_CASE_FIXTURE(ACFixture, "string_contents_is_available_to_callback") { loadDefinition(R"( @@ -3490,8 +3458,6 @@ local c = b.@1 TEST_CASE_FIXTURE(ACFixture, "suggest_exported_types") { - ScopedFastFlag luauCopyExportedTypes{"LuauCopyExportedTypes", true}; - check(R"( export type Type = {a: number} local a: T@1 diff --git a/tests/ClassFixture.cpp b/tests/ClassFixture.cpp index 9174051cf..5e28e8d90 100644 --- a/tests/ClassFixture.cpp +++ b/tests/ClassFixture.cpp @@ -14,6 +14,7 @@ ClassFixture::ClassFixture() GlobalTypes& globals = frontend.globals; TypeArena& arena = globals.globalTypes; TypeId numberType = builtinTypes->numberType; + TypeId stringType = builtinTypes->stringType; unfreeze(arena); @@ -35,7 +36,7 @@ ClassFixture::ClassFixture() TypeId childClassInstanceType = arena.addType(ClassType{"ChildClass", {}, baseClassInstanceType, nullopt, {}, {}, "Test"}); getMutable(childClassInstanceType)->props = { - {"Method", {makeFunction(arena, childClassInstanceType, {}, {builtinTypes->stringType})}}, + {"Method", {makeFunction(arena, childClassInstanceType, {}, {stringType})}}, }; TypeId childClassType = arena.addType(ClassType{"ChildClass", {}, baseClassType, nullopt, {}, {}, "Test"}); @@ -48,7 +49,7 @@ ClassFixture::ClassFixture() TypeId grandChildInstanceType = arena.addType(ClassType{"GrandChild", {}, childClassInstanceType, nullopt, {}, {}, "Test"}); getMutable(grandChildInstanceType)->props = { - {"Method", {makeFunction(arena, grandChildInstanceType, {}, {builtinTypes->stringType})}}, + {"Method", {makeFunction(arena, grandChildInstanceType, {}, {stringType})}}, }; TypeId grandChildType = arena.addType(ClassType{"GrandChild", {}, baseClassType, nullopt, {}, {}, "Test"}); @@ -61,7 +62,7 @@ ClassFixture::ClassFixture() TypeId anotherChildInstanceType = arena.addType(ClassType{"AnotherChild", {}, baseClassInstanceType, nullopt, {}, {}, "Test"}); getMutable(anotherChildInstanceType)->props = { - {"Method", {makeFunction(arena, anotherChildInstanceType, {}, {builtinTypes->stringType})}}, + {"Method", {makeFunction(arena, anotherChildInstanceType, {}, {stringType})}}, }; TypeId anotherChildType = arena.addType(ClassType{"AnotherChild", {}, baseClassType, nullopt, {}, {}, "Test"}); @@ -101,7 +102,7 @@ ClassFixture::ClassFixture() TypeId callableClassMetaType = arena.addType(TableType{}); TypeId callableClassType = arena.addType(ClassType{"CallableClass", {}, nullopt, callableClassMetaType, {}, {}, "Test"}); getMutable(callableClassMetaType)->props = { - {"__call", {makeFunction(arena, nullopt, {callableClassType, builtinTypes->stringType}, {builtinTypes->numberType})}}, + {"__call", {makeFunction(arena, nullopt, {callableClassType, stringType}, {numberType})}}, }; globals.globalScope->exportedTypeBindings["CallableClass"] = TypeFun{{}, callableClassType}; @@ -114,7 +115,7 @@ ClassFixture::ClassFixture() }; // IndexableClass has a table indexer with a key type of 'number | string' and a return type of 'number' - addIndexableClass("IndexableClass", arena.addType(Luau::UnionType{{builtinTypes->stringType, numberType}}), numberType); + addIndexableClass("IndexableClass", arena.addType(Luau::UnionType{{stringType, numberType}}), numberType); // IndexableNumericKeyClass has a table indexer with a key type of 'number' and a return type of 'number' addIndexableClass("IndexableNumericKeyClass", numberType, numberType); diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index 6a2c528d6..9e5ae30e9 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -2,13 +2,13 @@ #include "lua.h" #include "lualib.h" #include "luacode.h" +#include "luacodegen.h" #include "Luau/BuiltinDefinitions.h" #include "Luau/ModuleResolver.h" #include "Luau/TypeInfer.h" #include "Luau/StringUtils.h" #include "Luau/BytecodeBuilder.h" -#include "Luau/CodeGen.h" #include "Luau/Frontend.h" #include "doctest.h" @@ -159,8 +159,8 @@ static StateRef runConformance(const char* name, void (*setup)(lua_State* L) = n StateRef globalState(initialLuaState, lua_close); lua_State* L = globalState.get(); - if (codegen && !skipCodegen && Luau::CodeGen::isSupported()) - Luau::CodeGen::create(L); + if (codegen && !skipCodegen && luau_codegen_supported()) + luau_codegen_create(L); luaL_openlibs(L); @@ -213,8 +213,8 @@ static StateRef runConformance(const char* name, void (*setup)(lua_State* L) = n int result = luau_load(L, chunkname.c_str(), bytecode, bytecodeSize, 0); free(bytecode); - if (result == 0 && codegen && !skipCodegen && Luau::CodeGen::isSupported()) - Luau::CodeGen::compile(L, -1); + if (result == 0 && codegen && !skipCodegen && luau_codegen_supported()) + luau_codegen_compile(L, -1); int status = (result == 0) ? lua_resume(L, nullptr, 0) : LUA_ERRSYNTAX; @@ -1679,8 +1679,8 @@ TEST_CASE("HugeFunction") StateRef globalState(luaL_newstate(), lua_close); lua_State* L = globalState.get(); - if (codegen && Luau::CodeGen::isSupported()) - Luau::CodeGen::create(L); + if (codegen && luau_codegen_supported()) + luau_codegen_create(L); luaL_openlibs(L); luaL_sandbox(L); @@ -1693,8 +1693,8 @@ TEST_CASE("HugeFunction") REQUIRE(result == 0); - if (codegen && Luau::CodeGen::isSupported()) - Luau::CodeGen::compile(L, -1); + if (codegen && luau_codegen_supported()) + luau_codegen_compile(L, -1); int status = lua_resume(L, nullptr, 0); REQUIRE(status == 0); diff --git a/tests/ConstraintGraphBuilderFixture.cpp b/tests/ConstraintGraphBuilderFixture.cpp index 7b9339889..6bfb15901 100644 --- a/tests/ConstraintGraphBuilderFixture.cpp +++ b/tests/ConstraintGraphBuilderFixture.cpp @@ -1,8 +1,6 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "ConstraintGraphBuilderFixture.h" -#include "Luau/TypeReduction.h" - namespace Luau { @@ -13,7 +11,6 @@ ConstraintGraphBuilderFixture::ConstraintGraphBuilderFixture() { mainModule->name = "MainModule"; mainModule->humanReadableName = "MainModule"; - mainModule->reduction = std::make_unique(NotNull{&mainModule->internalTypes}, builtinTypes, NotNull{&ice}); BlockedType::DEPRECATED_nextIndex = 0; BlockedTypePack::nextIndex = 0; diff --git a/tests/IrBuilder.test.cpp b/tests/IrBuilder.test.cpp index e1213b931..2f5fbf1c9 100644 --- a/tests/IrBuilder.test.cpp +++ b/tests/IrBuilder.test.cpp @@ -1521,6 +1521,36 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "RecursiveSccUseRemoval2") )"); } +TEST_CASE_FIXTURE(IrBuilderFixture, "IntNumIntPeepholes") +{ + IrOp block = build.block(IrBlockKind::Internal); + + build.beginBlock(block); + + IrOp i1 = build.inst(IrCmd::LOAD_INT, build.vmReg(0)); + IrOp u1 = build.inst(IrCmd::LOAD_INT, build.vmReg(1)); + IrOp ni1 = build.inst(IrCmd::INT_TO_NUM, i1); + IrOp nu1 = build.inst(IrCmd::UINT_TO_NUM, u1); + IrOp i2 = build.inst(IrCmd::NUM_TO_INT, ni1); + IrOp u2 = build.inst(IrCmd::NUM_TO_UINT, nu1); + build.inst(IrCmd::STORE_INT, build.vmReg(0), i2); + build.inst(IrCmd::STORE_INT, build.vmReg(1), u2); + build.inst(IrCmd::RETURN, build.constUint(2)); + + updateUseCounts(build.function); + constPropInBlockChains(build, true); + + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( +bb_0: + %0 = LOAD_INT R0 + %1 = LOAD_INT R1 + STORE_INT R0, %0 + STORE_INT R1, %1 + RETURN 2u + +)"); +} + TEST_SUITE_END(); TEST_SUITE_BEGIN("LinearExecutionFlowExtraction"); diff --git a/tests/Module.test.cpp b/tests/Module.test.cpp index 22530a25e..abdfea77c 100644 --- a/tests/Module.test.cpp +++ b/tests/Module.test.cpp @@ -350,6 +350,35 @@ TEST_CASE_FIXTURE(Fixture, "clone_recursion_limit") CHECK_THROWS_AS(clone(table, dest, cloneState), RecursionLimitException); } +// Unions should never be cyclic, but we should clone them correctly even if +// they are. +TEST_CASE_FIXTURE(Fixture, "clone_cyclic_union") +{ + ScopedFastFlag sff{"LuauCloneCyclicUnions", true}; + + TypeArena src; + + TypeId u = src.addType(UnionType{{builtinTypes->numberType, builtinTypes->stringType}}); + UnionType* uu = getMutable(u); + REQUIRE(uu); + + uu->options.push_back(u); + + TypeArena dest; + CloneState cloneState; + + TypeId cloned = clone(u, dest, cloneState); + REQUIRE(cloned); + + const UnionType* clonedUnion = get(cloned); + REQUIRE(clonedUnion); + REQUIRE(3 == clonedUnion->options.size()); + + CHECK(builtinTypes->numberType == clonedUnion->options[0]); + CHECK(builtinTypes->stringType == clonedUnion->options[1]); + CHECK(cloned == clonedUnion->options[2]); +} + TEST_CASE_FIXTURE(Fixture, "any_persistance_does_not_leak") { ScopedFastFlag flags[] = { diff --git a/tests/Normalize.test.cpp b/tests/Normalize.test.cpp index 26b3b00d4..93ea75103 100644 --- a/tests/Normalize.test.cpp +++ b/tests/Normalize.test.cpp @@ -494,7 +494,7 @@ struct NormalizeFixture : Fixture REQUIRE(node); AstStatTypeAlias* alias = node->as(); REQUIRE(alias); - TypeId* originalTy = getMainModule()->astOriginalResolvedTypes.find(alias->type); + TypeId* originalTy = getMainModule()->astResolvedTypes.find(alias->type); REQUIRE(originalTy); return normalizer.normalize(*originalTy); } @@ -732,15 +732,11 @@ TEST_CASE_FIXTURE(NormalizeFixture, "narrow_union_of_classes_with_intersection") TEST_CASE_FIXTURE(NormalizeFixture, "intersection_of_metatables_where_the_metatable_is_top_or_bottom") { - ScopedFastFlag sff{"LuauNormalizeMetatableFixes", true}; - CHECK("{ @metatable *error-type*, {| |} }" == toString(normal("Mt<{}, any> & Mt<{}, err>"))); } TEST_CASE_FIXTURE(NormalizeFixture, "crazy_metatable") { - ScopedFastFlag sff{"LuauNormalizeMetatableFixes", true}; - CHECK("never" == toString(normal("Mt<{}, number> & Mt<{}, string>"))); } diff --git a/tests/Simplify.test.cpp b/tests/Simplify.test.cpp new file mode 100644 index 000000000..2052019ec --- /dev/null +++ b/tests/Simplify.test.cpp @@ -0,0 +1,508 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Fixture.h" + +#include "doctest.h" + +#include "Luau/Simplify.h" + +using namespace Luau; + +namespace +{ + +struct SimplifyFixture : Fixture +{ + TypeArena _arena; + const NotNull arena{&_arena}; + + ToStringOptions opts; + + Scope scope{builtinTypes->anyTypePack}; + + const TypeId anyTy = builtinTypes->anyType; + const TypeId unknownTy = builtinTypes->unknownType; + const TypeId neverTy = builtinTypes->neverType; + const TypeId errorTy = builtinTypes->errorType; + + const TypeId functionTy = builtinTypes->functionType; + const TypeId tableTy = builtinTypes->tableType; + + const TypeId numberTy = builtinTypes->numberType; + const TypeId stringTy = builtinTypes->stringType; + const TypeId booleanTy = builtinTypes->booleanType; + const TypeId nilTy = builtinTypes->nilType; + const TypeId threadTy = builtinTypes->threadType; + + const TypeId classTy = builtinTypes->classType; + + const TypeId trueTy = builtinTypes->trueType; + const TypeId falseTy = builtinTypes->falseType; + + const TypeId truthyTy = builtinTypes->truthyType; + const TypeId falsyTy = builtinTypes->falsyType; + + const TypeId freeTy = arena->addType(FreeType{&scope}); + const TypeId genericTy = arena->addType(GenericType{}); + const TypeId blockedTy = arena->addType(BlockedType{}); + const TypeId pendingTy = arena->addType(PendingExpansionType{{}, {}, {}, {}}); + + const TypeId helloTy = arena->addType(SingletonType{StringSingleton{"hello"}}); + const TypeId worldTy = arena->addType(SingletonType{StringSingleton{"world"}}); + + const TypePackId emptyTypePack = arena->addTypePack({}); + + const TypeId fn1Ty = arena->addType(FunctionType{emptyTypePack, emptyTypePack}); + const TypeId fn2Ty = arena->addType(FunctionType{builtinTypes->anyTypePack, emptyTypePack}); + + TypeId parentClassTy = nullptr; + TypeId childClassTy = nullptr; + TypeId anotherChildClassTy = nullptr; + TypeId unrelatedClassTy = nullptr; + + SimplifyFixture() + { + createSomeClasses(&frontend); + + parentClassTy = frontend.globals.globalScope->linearSearchForBinding("Parent")->typeId; + childClassTy = frontend.globals.globalScope->linearSearchForBinding("Child")->typeId; + anotherChildClassTy = frontend.globals.globalScope->linearSearchForBinding("AnotherChild")->typeId; + unrelatedClassTy = frontend.globals.globalScope->linearSearchForBinding("Unrelated")->typeId; + } + + TypeId intersect(TypeId a, TypeId b) + { + return simplifyIntersection(builtinTypes, arena, a, b).result; + } + + std::string intersectStr(TypeId a, TypeId b) + { + return toString(intersect(a, b), opts); + } + + bool isIntersection(TypeId a) + { + return bool(get(follow(a))); + } + + TypeId mkTable(std::map propTypes) + { + TableType::Props props; + for (const auto& [name, ty] : propTypes) + props[name] = Property{ty}; + + return arena->addType(TableType{props, {}, TypeLevel{}, TableState::Sealed}); + } + + TypeId mkNegation(TypeId ty) + { + return arena->addType(NegationType{ty}); + } + + TypeId mkFunction(TypeId arg, TypeId ret) + { + return arena->addType(FunctionType{arena->addTypePack({arg}), arena->addTypePack({ret})}); + } + + TypeId union_(TypeId a, TypeId b) + { + return simplifyUnion(builtinTypes, arena, a, b).result; + } +}; + +} // namespace + +TEST_SUITE_BEGIN("Simplify"); + +TEST_CASE_FIXTURE(SimplifyFixture, "unknown_and_other_tops_and_bottom_types") +{ + CHECK(unknownTy == intersect(unknownTy, unknownTy)); + + CHECK(unknownTy == intersect(unknownTy, anyTy)); + CHECK(unknownTy == intersect(anyTy, unknownTy)); + + CHECK(neverTy == intersect(unknownTy, neverTy)); + CHECK(neverTy == intersect(neverTy, unknownTy)); + + CHECK(neverTy == intersect(unknownTy, errorTy)); + CHECK(neverTy == intersect(errorTy, unknownTy)); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "nil") +{ + CHECK(nilTy == intersect(nilTy, nilTy)); + CHECK(neverTy == intersect(nilTy, numberTy)); + CHECK(neverTy == intersect(nilTy, trueTy)); + CHECK(neverTy == intersect(nilTy, tableTy)); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "boolean_singletons") +{ + CHECK(trueTy == intersect(trueTy, booleanTy)); + CHECK(trueTy == intersect(booleanTy, trueTy)); + + CHECK(falseTy == intersect(falseTy, booleanTy)); + CHECK(falseTy == intersect(booleanTy, falseTy)); + + CHECK(neverTy == intersect(falseTy, trueTy)); + CHECK(neverTy == intersect(trueTy, falseTy)); + + CHECK(booleanTy == union_(trueTy, booleanTy)); + CHECK(booleanTy == union_(booleanTy, trueTy)); + CHECK(booleanTy == union_(falseTy, booleanTy)); + CHECK(booleanTy == union_(booleanTy, falseTy)); + CHECK(booleanTy == union_(falseTy, trueTy)); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "boolean_and_truthy_and_falsy") +{ + TypeId optionalBooleanTy = arena->addType(UnionType{{booleanTy, nilTy}}); + + CHECK(trueTy == intersect(booleanTy, truthyTy)); + + CHECK(trueTy == intersect(optionalBooleanTy, truthyTy)); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "any_and_indeterminate_types") +{ + CHECK("a" == intersectStr(anyTy, freeTy)); + CHECK("a" == intersectStr(freeTy, anyTy)); + + CHECK("b" == intersectStr(anyTy, genericTy)); + CHECK("b" == intersectStr(genericTy, anyTy)); + + CHECK(blockedTy == intersect(anyTy, blockedTy)); + CHECK(blockedTy == intersect(blockedTy, anyTy)); + + CHECK(pendingTy == intersect(anyTy, pendingTy)); + CHECK(pendingTy == intersect(pendingTy, anyTy)); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "unknown_and_indeterminate_types") +{ + CHECK(isIntersection(intersect(unknownTy, freeTy))); + CHECK(isIntersection(intersect(freeTy, unknownTy))); + + CHECK(isIntersection(intersect(unknownTy, genericTy))); + CHECK(isIntersection(intersect(genericTy, unknownTy))); + + CHECK(isIntersection(intersect(unknownTy, blockedTy))); + CHECK(isIntersection(intersect(blockedTy, unknownTy))); + + CHECK(isIntersection(intersect(unknownTy, pendingTy))); + CHECK(isIntersection(intersect(pendingTy, unknownTy))); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "unknown_and_concrete") +{ + CHECK(numberTy == intersect(numberTy, unknownTy)); + CHECK(numberTy == intersect(unknownTy, numberTy)); + CHECK(trueTy == intersect(trueTy, unknownTy)); + CHECK(trueTy == intersect(unknownTy, trueTy)); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "error_and_other_tops_and_bottom_types") +{ + CHECK(errorTy == intersect(errorTy, errorTy)); + + CHECK(errorTy == intersect(errorTy, anyTy)); + CHECK(errorTy == intersect(anyTy, errorTy)); + + CHECK(neverTy == intersect(errorTy, neverTy)); + CHECK(neverTy == intersect(neverTy, errorTy)); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "error_and_indeterminate_types") +{ + CHECK("*error-type* & a" == intersectStr(errorTy, freeTy)); + CHECK("*error-type* & a" == intersectStr(freeTy, errorTy)); + + CHECK("*error-type* & b" == intersectStr(errorTy, genericTy)); + CHECK("*error-type* & b" == intersectStr(genericTy, errorTy)); + + CHECK(isIntersection(intersect(errorTy, blockedTy))); + CHECK(isIntersection(intersect(blockedTy, errorTy))); + + CHECK(isIntersection(intersect(errorTy, pendingTy))); + CHECK(isIntersection(intersect(pendingTy, errorTy))); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "unknown_and_concrete") +{ + CHECK(neverTy == intersect(numberTy, errorTy)); + CHECK(neverTy == intersect(errorTy, numberTy)); + CHECK(neverTy == intersect(trueTy, errorTy)); + CHECK(neverTy == intersect(errorTy, trueTy)); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "primitives") +{ + // This shouldn't be possible, but we'll make it work even if it is. + TypeId numberTyDuplicate = arena->addType(PrimitiveType{PrimitiveType::Number}); + + CHECK(numberTy == intersect(numberTy, numberTyDuplicate)); + CHECK(neverTy == intersect(numberTy, stringTy)); + + CHECK(neverTy == intersect(neverTy, numberTy)); + CHECK(neverTy == intersect(numberTy, neverTy)); + + CHECK(neverTy == intersect(neverTy, functionTy)); + CHECK(neverTy == intersect(functionTy, neverTy)); + + CHECK(neverTy == intersect(neverTy, tableTy)); + CHECK(neverTy == intersect(tableTy, neverTy)); + + CHECK(numberTy == intersect(anyTy, numberTy)); + CHECK(numberTy == intersect(numberTy, anyTy)); + + CHECK(neverTy == intersect(stringTy, nilTy)); + CHECK(neverTy == intersect(nilTy, stringTy)); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "primitives_and_falsy") +{ + CHECK(neverTy == intersect(numberTy, falsyTy)); + CHECK(neverTy == intersect(falsyTy, numberTy)); + + CHECK(nilTy == intersect(nilTy, falsyTy)); + CHECK(nilTy == intersect(falsyTy, nilTy)); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "primitives_and_singletons") +{ + CHECK(helloTy == intersect(helloTy, stringTy)); + CHECK(helloTy == intersect(stringTy, helloTy)); + + CHECK(neverTy == intersect(worldTy, helloTy)); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "functions") +{ + CHECK(fn1Ty == intersect(fn1Ty, functionTy)); + CHECK(fn1Ty == intersect(functionTy, fn1Ty)); + + // Intersections of functions are super weird if you think about it. + CHECK("(() -> ()) & ((...any) -> ())" == intersectStr(fn1Ty, fn2Ty)); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "negated_top_function_type") +{ + TypeId negatedFunctionTy = mkNegation(functionTy); + + CHECK(numberTy == intersect(numberTy, negatedFunctionTy)); + CHECK(numberTy == intersect(negatedFunctionTy, numberTy)); + + CHECK(falsyTy == intersect(falsyTy, negatedFunctionTy)); + CHECK(falsyTy == intersect(negatedFunctionTy, falsyTy)); + + TypeId f = mkFunction(stringTy, numberTy); + + CHECK(neverTy == intersect(f, negatedFunctionTy)); + CHECK(neverTy == intersect(negatedFunctionTy, f)); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "optional_overloaded_function_and_top_function") +{ + // (((number) -> string) & ((string) -> number))? & ~function + + TypeId f1 = mkFunction(numberTy, stringTy); + TypeId f2 = mkFunction(stringTy, numberTy); + + TypeId f12 = arena->addType(IntersectionType{{f1, f2}}); + + TypeId t = arena->addType(UnionType{{f12, nilTy}}); + + TypeId notFunctionTy = mkNegation(functionTy); + + CHECK(nilTy == intersect(t, notFunctionTy)); + CHECK(nilTy == intersect(notFunctionTy, t)); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "negated_function_does_not_intersect_cleanly_with_truthy") +{ + // ~function & ~(false?) + // ~function & ~(false | nil) + // ~function & ~false & ~nil + + TypeId negatedFunctionTy = mkNegation(functionTy); + CHECK(isIntersection(intersect(negatedFunctionTy, truthyTy))); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "tables") +{ + TypeId t1 = mkTable({{"tag", stringTy}}); + + CHECK(t1 == intersect(t1, tableTy)); + CHECK(neverTy == intersect(t1, functionTy)); + + TypeId t2 = mkTable({{"tag", helloTy}}); + + CHECK(t2 == intersect(t1, t2)); + CHECK(t2 == intersect(t2, t1)); + + TypeId t3 = mkTable({}); + + CHECK(t1 == intersect(t1, t3)); + CHECK(t1 == intersect(t3, t1)); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "tables_and_top_table") +{ + TypeId notTableType = mkNegation(tableTy); + TypeId t1 = mkTable({{"prop", stringTy}, {"another", numberTy}}); + + CHECK(t1 == intersect(t1, tableTy)); + CHECK(t1 == intersect(tableTy, t1)); + + CHECK(neverTy == intersect(t1, notTableType)); + CHECK(neverTy == intersect(notTableType, t1)); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "tables_and_truthy") +{ + TypeId t1 = mkTable({{"prop", stringTy}, {"another", numberTy}}); + + CHECK(t1 == intersect(t1, truthyTy)); + CHECK(t1 == intersect(truthyTy, t1)); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "table_with_a_tag") +{ + // {tag: string, prop: number} & {tag: "hello"} + // I think we can decline to simplify this: + TypeId t1 = mkTable({{"tag", stringTy}, {"prop", numberTy}}); + TypeId t2 = mkTable({{"tag", helloTy}}); + + CHECK("{| prop: number, tag: string |} & {| tag: \"hello\" |}" == intersectStr(t1, t2)); + CHECK("{| prop: number, tag: string |} & {| tag: \"hello\" |}" == intersectStr(t2, t1)); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "nested_table_tag_test") +{ + TypeId t1 = mkTable({ + {"subtable", mkTable({ + {"tag", helloTy}, + {"subprop", numberTy}, + })}, + {"prop", stringTy}, + }); + TypeId t2 = mkTable({ + {"subtable", mkTable({ + {"tag", helloTy}, + })}, + }); + + CHECK(t1 == intersect(t1, t2)); + CHECK(t1 == intersect(t2, t1)); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "union") +{ + TypeId t1 = arena->addType(UnionType{{numberTy, stringTy, nilTy, tableTy}}); + + CHECK(nilTy == intersect(t1, nilTy)); + // CHECK(nilTy == intersect(nilTy, t1)); // TODO? + + CHECK(builtinTypes->stringType == intersect(builtinTypes->optionalStringType, truthyTy)); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "two_unions") +{ + TypeId t1 = arena->addType(UnionType{{numberTy, booleanTy, stringTy, nilTy, tableTy}}); + + CHECK("false?" == intersectStr(t1, falsyTy)); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "curious_union") +{ + // (a & false) | (a & nil) + TypeId curious = + arena->addType(UnionType{{arena->addType(IntersectionType{{freeTy, falseTy}}), arena->addType(IntersectionType{{freeTy, nilTy}})}}); + + CHECK("(a & false) | (a & nil) | number" == toString(union_(curious, numberTy))); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "negations") +{ + TypeId notNumberTy = mkNegation(numberTy); + TypeId notStringTy = mkNegation(stringTy); + + CHECK(neverTy == intersect(numberTy, notNumberTy)); + + CHECK(numberTy == intersect(numberTy, notStringTy)); + CHECK(numberTy == intersect(notStringTy, numberTy)); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "top_class_type") +{ + CHECK(neverTy == intersect(classTy, stringTy)); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "classes") +{ + CHECK(childClassTy == intersect(childClassTy, parentClassTy)); + CHECK(childClassTy == intersect(parentClassTy, childClassTy)); + + CHECK(parentClassTy == union_(childClassTy, parentClassTy)); + CHECK(parentClassTy == union_(parentClassTy, childClassTy)); + + CHECK(neverTy == intersect(childClassTy, unrelatedClassTy)); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "negations_of_classes") +{ + TypeId notChildClassTy = mkNegation(childClassTy); + TypeId notParentClassTy = mkNegation(parentClassTy); + + CHECK(neverTy == intersect(childClassTy, notParentClassTy)); + CHECK(neverTy == intersect(notParentClassTy, childClassTy)); + + CHECK("Parent & ~Child" == intersectStr(notChildClassTy, parentClassTy)); + CHECK("Parent & ~Child" == intersectStr(parentClassTy, notChildClassTy)); + + CHECK(notParentClassTy == intersect(notChildClassTy, notParentClassTy)); + CHECK(notParentClassTy == intersect(notParentClassTy, notChildClassTy)); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "intersection_of_intersection_of_a_free_type_can_result_in_removal_of_that_free_type") +{ + // a & string and number + // (a & number) & (string & number) + + TypeId t1 = arena->addType(IntersectionType{{freeTy, stringTy}}); + + CHECK(neverTy == intersect(t1, numberTy)); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "some_tables_are_really_never") +{ + TypeId notAnyTy = mkNegation(anyTy); + + TypeId t1 = mkTable({{"someKey", notAnyTy}}); + + CHECK(neverTy == intersect(t1, numberTy)); + CHECK(neverTy == intersect(numberTy, t1)); + CHECK(neverTy == intersect(t1, t1)); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "simplify_stops_at_cycles") +{ + TypeId t = mkTable({}); + TableType* tt = getMutable(t); + REQUIRE(tt); + + TypeId t2 = mkTable({}); + TableType* t2t = getMutable(t2); + REQUIRE(t2t); + + tt->props["cyclic"] = Property{t2}; + t2t->props["cyclic"] = Property{t}; + + CHECK(t == intersect(t, anyTy)); + CHECK(t == intersect(anyTy, t)); + + CHECK(t2 == intersect(t2, anyTy)); + CHECK(t2 == intersect(anyTy, t2)); +} + +TEST_SUITE_END(); diff --git a/tests/ToString.test.cpp b/tests/ToString.test.cpp index 160757e2d..39759c716 100644 --- a/tests/ToString.test.cpp +++ b/tests/ToString.test.cpp @@ -291,9 +291,9 @@ TEST_CASE_FIXTURE(Fixture, "quit_stringifying_type_when_length_is_exceeded") { o.maxTypeLength = 30; CHECK_EQ(toString(requireType("f0"), o), "() -> ()"); - CHECK_EQ(toString(requireType("f1"), o), "(a) -> (() -> ()) | (a & ~false & ~nil)... *TRUNCATED*"); - CHECK_EQ(toString(requireType("f2"), o), "(b) -> ((a) -> (() -> ()) | (a & ~false & ~nil)... *TRUNCATED*"); - CHECK_EQ(toString(requireType("f3"), o), "(c) -> ((b) -> ((a) -> (() -> ()) | (a & ~false & ~nil)... *TRUNCATED*"); + CHECK_EQ(toString(requireType("f1"), o), "(a) -> (() -> ()) | (a & ~(false?))... *TRUNCATED*"); + CHECK_EQ(toString(requireType("f2"), o), "(b) -> ((a) -> (() -> ()) | (a & ~(false?))... *TRUNCATED*"); + CHECK_EQ(toString(requireType("f3"), o), "(c) -> ((b) -> ((a) -> (() -> ()) | (a & ~(false?))... *TRUNCATED*"); } else { @@ -321,9 +321,9 @@ TEST_CASE_FIXTURE(Fixture, "stringifying_type_is_still_capped_when_exhaustive") { o.maxTypeLength = 30; CHECK_EQ(toString(requireType("f0"), o), "() -> ()"); - CHECK_EQ(toString(requireType("f1"), o), "(a) -> (() -> ()) | (a & ~false & ~nil)... *TRUNCATED*"); - CHECK_EQ(toString(requireType("f2"), o), "(b) -> ((a) -> (() -> ()) | (a & ~false & ~nil)... *TRUNCATED*"); - CHECK_EQ(toString(requireType("f3"), o), "(c) -> ((b) -> ((a) -> (() -> ()) | (a & ~false & ~nil)... *TRUNCATED*"); + CHECK_EQ(toString(requireType("f1"), o), "(a) -> (() -> ()) | (a & ~(false?))... *TRUNCATED*"); + CHECK_EQ(toString(requireType("f2"), o), "(b) -> ((a) -> (() -> ()) | (a & ~(false?))... *TRUNCATED*"); + CHECK_EQ(toString(requireType("f3"), o), "(c) -> ((b) -> ((a) -> (() -> ()) | (a & ~(false?))... *TRUNCATED*"); } else { @@ -507,25 +507,25 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "toStringDetailed2") CHECK_EQ("{ @metatable { __index: { @metatable {| __index: base |}, child } }, inst }", r.name); CHECK(0 == opts.nameMap.types.size()); - const MetatableType* tMeta = get(tType); + const MetatableType* tMeta = get(follow(tType)); REQUIRE(tMeta); - TableType* tMeta2 = getMutable(tMeta->metatable); + TableType* tMeta2 = getMutable(follow(tMeta->metatable)); REQUIRE(tMeta2); REQUIRE(tMeta2->props.count("__index")); - const MetatableType* tMeta3 = get(tMeta2->props["__index"].type()); + const MetatableType* tMeta3 = get(follow(tMeta2->props["__index"].type())); REQUIRE(tMeta3); - TableType* tMeta4 = getMutable(tMeta3->metatable); + TableType* tMeta4 = getMutable(follow(tMeta3->metatable)); REQUIRE(tMeta4); REQUIRE(tMeta4->props.count("__index")); - TableType* tMeta5 = getMutable(tMeta4->props["__index"].type()); + TableType* tMeta5 = getMutable(follow(tMeta4->props["__index"].type())); REQUIRE(tMeta5); REQUIRE(tMeta5->props.count("one") > 0); - TableType* tMeta6 = getMutable(tMeta3->table); + TableType* tMeta6 = getMutable(follow(tMeta3->table)); REQUIRE(tMeta6); REQUIRE(tMeta6->props.count("two") > 0); diff --git a/tests/TxnLog.test.cpp b/tests/TxnLog.test.cpp index 78ab06445..bfd297657 100644 --- a/tests/TxnLog.test.cpp +++ b/tests/TxnLog.test.cpp @@ -25,6 +25,8 @@ struct TxnLogFixture TypeId a = arena.freshType(globalScope.get()); TypeId b = arena.freshType(globalScope.get()); TypeId c = arena.freshType(childScope.get()); + + TypeId g = arena.addType(GenericType{"G"}); }; TEST_SUITE_BEGIN("TxnLog"); @@ -110,4 +112,13 @@ TEST_CASE_FIXTURE(TxnLogFixture, "colliding_coincident_logs_do_not_create_degene CHECK("a" == toString(b)); } +TEST_CASE_FIXTURE(TxnLogFixture, "replacing_persistent_types_is_allowed_but_makes_the_log_radioactive") +{ + persist(g); + + log.replace(g, BoundType{a}); + + CHECK(log.radioactive); +} + TEST_SUITE_END(); diff --git a/tests/TypeFamily.test.cpp b/tests/TypeFamily.test.cpp index 9a101d992..b11b05d7f 100644 --- a/tests/TypeFamily.test.cpp +++ b/tests/TypeFamily.test.cpp @@ -20,7 +20,7 @@ struct FamilyFixture : Fixture swapFamily = TypeFamily{/* name */ "Swap", /* reducer */ [](std::vector tys, std::vector tps, NotNull arena, NotNull builtins, - NotNull log) -> TypeFamilyReductionResult { + NotNull log, NotNull scope, NotNull normalizer) -> TypeFamilyReductionResult { LUAU_ASSERT(tys.size() == 1); TypeId param = log->follow(tys.at(0)); @@ -78,18 +78,6 @@ TEST_CASE_FIXTURE(FamilyFixture, "basic_type_family") CHECK("Type family instance Swap is uninhabited" == toString(result.errors[0])); }; -TEST_CASE_FIXTURE(FamilyFixture, "type_reduction_reduces_families") -{ - if (!FFlag::DebugLuauDeferredConstraintResolution) - return; - - CheckResult result = check(R"( - local x: Swap & nil - )"); - - CHECK("never" == toString(requireType("x"))); -} - TEST_CASE_FIXTURE(FamilyFixture, "family_as_fn_ret") { if (!FFlag::DebugLuauDeferredConstraintResolution) @@ -202,4 +190,27 @@ TEST_CASE_FIXTURE(FamilyFixture, "function_internal_families") CHECK(toString(result.errors[0]) == "Type family instance Swap is uninhabited"); } +TEST_CASE_FIXTURE(Fixture, "add_family_at_work") +{ + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + + CheckResult result = check(R"( + local function add(a, b) + return a + b + end + + local a = add(1, 2) + local b = add(1, "foo") + local c = add("foo", 1) + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK(toString(requireType("a")) == "number"); + CHECK(toString(requireType("b")) == "Add"); + CHECK(toString(requireType("c")) == "Add"); + CHECK(toString(result.errors[0]) == "Type family instance Add is uninhabited"); + CHECK(toString(result.errors[1]) == "Type family instance Add is uninhabited"); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.annotations.test.cpp b/tests/TypeInfer.annotations.test.cpp index 84b057d5e..4feb3a6d0 100644 --- a/tests/TypeInfer.annotations.test.cpp +++ b/tests/TypeInfer.annotations.test.cpp @@ -736,6 +736,18 @@ TEST_CASE_FIXTURE(Fixture, "luau_print_is_not_special_without_the_flag") LUAU_REQUIRE_ERROR_COUNT(1, result); } +TEST_CASE_FIXTURE(Fixture, "luau_print_incomplete") +{ + ScopedFastFlag sffs{"DebugLuauMagicTypes", true}; + + CheckResult result = check(R"( + local a: _luau_print + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("_luau_print requires one generic parameter", toString(result.errors[0])); +} + TEST_CASE_FIXTURE(Fixture, "instantiate_type_fun_should_not_trip_rbxassert") { CheckResult result = check(R"( diff --git a/tests/TypeInfer.cfa.test.cpp b/tests/TypeInfer.cfa.test.cpp index 737429583..04aeb54b6 100644 --- a/tests/TypeInfer.cfa.test.cpp +++ b/tests/TypeInfer.cfa.test.cpp @@ -352,10 +352,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "tagged_unions") CHECK_EQ("\"err\"", toString(requireTypeAtPosition({13, 31}))); CHECK_EQ("E", toString(requireTypeAtPosition({14, 31}))); - if (FFlag::DebugLuauDeferredConstraintResolution) - CHECK_EQ("{| error: E, tag: \"err\" |}", toString(requireTypeAtPosition({16, 19}))); - else - CHECK_EQ("Err", toString(requireTypeAtPosition({16, 19}))); + CHECK_EQ("Err", toString(requireTypeAtPosition({16, 19}))); } TEST_CASE_FIXTURE(BuiltinsFixture, "do_assert_x") diff --git a/tests/TypeInfer.classes.test.cpp b/tests/TypeInfer.classes.test.cpp index 607fc40aa..d9e4bbada 100644 --- a/tests/TypeInfer.classes.test.cpp +++ b/tests/TypeInfer.classes.test.cpp @@ -552,6 +552,8 @@ TEST_CASE_FIXTURE(ClassFixture, "indexable_classes") local x : IndexableClass local y = x[true] )"); + + CHECK_EQ( toString(result.errors[0]), "Type 'boolean' could not be converted into 'number | string'; none of the union options are compatible"); } @@ -560,6 +562,7 @@ TEST_CASE_FIXTURE(ClassFixture, "indexable_classes") local x : IndexableClass x[true] = 42 )"); + CHECK_EQ( toString(result.errors[0]), "Type 'boolean' could not be converted into 'number | string'; none of the union options are compatible"); } @@ -593,7 +596,10 @@ TEST_CASE_FIXTURE(ClassFixture, "indexable_classes") local x : IndexableNumericKeyClass x["key"] = 1 )"); - CHECK_EQ(toString(result.errors[0]), "Type 'string' could not be converted into 'number'"); + if (FFlag::DebugLuauDeferredConstraintResolution) + CHECK_EQ(toString(result.errors[0]), "Key 'key' not found in class 'IndexableNumericKeyClass'"); + else + CHECK_EQ(toString(result.errors[0]), "Type 'string' could not be converted into 'number'"); } { CheckResult result = check(R"( @@ -615,7 +621,10 @@ TEST_CASE_FIXTURE(ClassFixture, "indexable_classes") local x : IndexableNumericKeyClass local y = x["key"] )"); - CHECK_EQ(toString(result.errors[0]), "Type 'string' could not be converted into 'number'"); + if (FFlag::DebugLuauDeferredConstraintResolution) + CHECK_EQ(toString(result.errors[0]), "Key 'key' not found in class 'IndexableNumericKeyClass'"); + else + CHECK_EQ(toString(result.errors[0]), "Type 'string' could not be converted into 'number'"); } { CheckResult result = check(R"( diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index 9712f0975..78f755874 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -358,6 +358,22 @@ TEST_CASE_FIXTURE(Fixture, "another_recursive_local_function") LUAU_REQUIRE_NO_ERRORS(result); } +// We had a bug where we'd look up the type of a recursive call using the DFG, +// not the bindings tables. As a result, we would erroneously use the +// generalized type of foo() in this recursive fragment. This creates a +// constraint cycle that doesn't always work itself out. +// +// The fix is for the DFG node within the scope of foo() to retain the +// ungeneralized type of foo. +TEST_CASE_FIXTURE(BuiltinsFixture, "recursive_calls_must_refer_to_the_ungeneralized_type") +{ + CheckResult result = check(R"( + function foo() + string.format('%s: %s', "51", foo()) + end + )"); +} + TEST_CASE_FIXTURE(Fixture, "cyclic_function_type_in_rets") { CheckResult result = check(R"( @@ -1029,7 +1045,7 @@ TEST_CASE_FIXTURE(Fixture, "no_lossy_function_type") LUAU_REQUIRE_NO_ERRORS(result); TypeId type = requireTypeAtPosition(Position(6, 14)); CHECK_EQ("(tbl, number, number) -> number", toString(type)); - auto ftv = get(type); + auto ftv = get(follow(type)); REQUIRE(ftv); CHECK(ftv->hasSelf); } @@ -1967,7 +1983,7 @@ TEST_CASE_FIXTURE(Fixture, "inner_frees_become_generic_in_dcr") LUAU_REQUIRE_NO_ERRORS(result); std::optional ty = findTypeAtPosition(Position{3, 19}); REQUIRE(ty); - CHECK(get(*ty)); + CHECK(get(follow(*ty))); } TEST_CASE_FIXTURE(Fixture, "function_exprs_are_generalized_at_signature_scope_not_enclosing") diff --git a/tests/TypeInfer.intersectionTypes.test.cpp b/tests/TypeInfer.intersectionTypes.test.cpp index 99abf711e..738d3cd2b 100644 --- a/tests/TypeInfer.intersectionTypes.test.cpp +++ b/tests/TypeInfer.intersectionTypes.test.cpp @@ -132,40 +132,23 @@ TEST_CASE_FIXTURE(Fixture, "should_still_pick_an_overload_whose_arguments_are_un TEST_CASE_FIXTURE(Fixture, "propagates_name") { - if (FFlag::DebugLuauDeferredConstraintResolution) - { - CheckResult result = check(R"( - type A={a:number} - type B={b:string} - - local c:A&B - local b = c - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK("{| a: number, b: string |}" == toString(requireType("b"))); - } - else - { - const std::string code = R"( - type A={a:number} - type B={b:string} + const std::string code = R"( + type A={a:number} + type B={b:string} - local c:A&B - local b = c - )"; + local c:A&B + local b = c + )"; - const std::string expected = R"( - type A={a:number} - type B={b:string} + const std::string expected = R"( + type A={a:number} + type B={b:string} - local c:A&B - local b:A&B=c - )"; + local c:A&B + local b:A&B=c + )"; - CHECK_EQ(expected, decorateWithTypes(code)); - } + CHECK_EQ(expected, decorateWithTypes(code)); } TEST_CASE_FIXTURE(Fixture, "index_on_an_intersection_type_with_property_guaranteed_to_exist") @@ -328,11 +311,7 @@ TEST_CASE_FIXTURE(Fixture, "table_intersection_write_sealed") LUAU_REQUIRE_ERROR_COUNT(1, result); auto e = toString(result.errors[0]); - // In DCR, because of type normalization, we print a different error message - if (FFlag::DebugLuauDeferredConstraintResolution) - CHECK_EQ("Cannot add property 'z' to table '{| x: number, y: number |}'", e); - else - CHECK_EQ("Cannot add property 'z' to table 'X & Y'", e); + CHECK_EQ("Cannot add property 'z' to table 'X & Y'", e); } TEST_CASE_FIXTURE(Fixture, "table_intersection_write_sealed_indirect") @@ -406,10 +385,7 @@ local a: XYZ = 3 )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - if (FFlag::DebugLuauDeferredConstraintResolution) - CHECK_EQ(toString(result.errors[0]), R"(Type 'number' could not be converted into '{| x: number, y: number, z: number |}')"); - else - CHECK_EQ(toString(result.errors[0]), R"(Type 'number' could not be converted into 'X & Y & Z' + CHECK_EQ(toString(result.errors[0]), R"(Type 'number' could not be converted into 'X & Y & Z' caused by: Not all intersection parts are compatible. Type 'number' could not be converted into 'X')"); } @@ -426,11 +402,7 @@ local b: number = a )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - if (FFlag::DebugLuauDeferredConstraintResolution) - CHECK_EQ(toString(result.errors[0]), R"(Type '{| x: number, y: number, z: number |}' could not be converted into 'number')"); - else - CHECK_EQ( - toString(result.errors[0]), R"(Type 'X & Y & Z' could not be converted into 'number'; none of the intersection parts are compatible)"); + CHECK_EQ(toString(result.errors[0]), R"(Type 'X & Y & Z' could not be converted into 'number'; none of the intersection parts are compatible)"); } TEST_CASE_FIXTURE(Fixture, "overload_is_not_a_function") @@ -470,11 +442,7 @@ TEST_CASE_FIXTURE(Fixture, "intersect_bool_and_false") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - if (FFlag::DebugLuauDeferredConstraintResolution) - CHECK_EQ(toString(result.errors[0]), "Type 'false' could not be converted into 'true'"); - else - CHECK_EQ( - toString(result.errors[0]), "Type 'boolean & false' could not be converted into 'true'; none of the intersection parts are compatible"); + CHECK_EQ(toString(result.errors[0]), "Type 'boolean & false' could not be converted into 'true'; none of the intersection parts are compatible"); } TEST_CASE_FIXTURE(Fixture, "intersect_false_and_bool_and_false") @@ -486,14 +454,9 @@ TEST_CASE_FIXTURE(Fixture, "intersect_false_and_bool_and_false") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - if (FFlag::DebugLuauDeferredConstraintResolution) - CHECK_EQ(toString(result.errors[0]), "Type 'false' could not be converted into 'true'"); - else - { - // TODO: odd stringification of `false & (boolean & false)`.) - CHECK_EQ(toString(result.errors[0]), - "Type 'boolean & false & false' could not be converted into 'true'; none of the intersection parts are compatible"); - } + // TODO: odd stringification of `false & (boolean & false)`.) + CHECK_EQ(toString(result.errors[0]), + "Type 'boolean & false & false' could not be converted into 'true'; none of the intersection parts are compatible"); } TEST_CASE_FIXTURE(Fixture, "intersect_saturate_overloaded_functions") @@ -531,21 +494,8 @@ TEST_CASE_FIXTURE(Fixture, "intersection_of_tables") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - if (FFlag::DebugLuauDeferredConstraintResolution) - { - CHECK_EQ(toString(result.errors[0]), - "Type '{| p: number?, q: nil, r: number? |}' could not be converted into '{| p: nil |}'\n" - "caused by:\n" - " Property 'p' is not compatible. Type 'number?' could not be converted into 'nil'\n" - "caused by:\n" - " Not all union options are compatible. Type 'number' could not be converted into 'nil' in an invariant context"); - } - else - { - CHECK_EQ(toString(result.errors[0]), - "Type '{| p: number?, q: number?, r: number? |} & {| p: number?, q: string? |}' could not be converted into " - "'{| p: nil |}'; none of the intersection parts are compatible"); - } + CHECK_EQ(toString(result.errors[0]), "Type '{| p: number?, q: number?, r: number? |} & {| p: number?, q: string? |}' could not be converted into " + "'{| p: nil |}'; none of the intersection parts are compatible"); } TEST_CASE_FIXTURE(Fixture, "intersection_of_tables_with_top_properties") @@ -558,27 +508,9 @@ TEST_CASE_FIXTURE(Fixture, "intersection_of_tables_with_top_properties") local z : { p : string?, q : number? } = x -- Not OK )"); - if (FFlag::DebugLuauDeferredConstraintResolution) - { - LUAU_REQUIRE_ERROR_COUNT(2, result); - - CHECK_EQ(toString(result.errors[0]), - "Type '{| p: number?, q: string? |}' could not be converted into '{| p: string?, q: number? |}'\n" - "caused by:\n" - " Property 'p' is not compatible. Type 'number' could not be converted into 'string' in an invariant context"); - - CHECK_EQ(toString(result.errors[1]), - "Type '{| p: number?, q: string? |}' could not be converted into '{| p: string?, q: number? |}'\n" - "caused by:\n" - " Property 'q' is not compatible. Type 'string' could not be converted into 'number' in an invariant context"); - } - else - { - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), - "Type '{| p: number?, q: any |} & {| p: unknown, q: string? |}' could not be converted into " - "'{| p: string?, q: number? |}'; none of the intersection parts are compatible"); - } + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type '{| p: number?, q: any |} & {| p: unknown, q: string? |}' could not be converted into " + "'{| p: string?, q: number? |}'; none of the intersection parts are compatible"); } TEST_CASE_FIXTURE(Fixture, "intersection_of_tables_with_never_properties") @@ -605,18 +537,9 @@ TEST_CASE_FIXTURE(Fixture, "overloaded_functions_returning_intersections") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - if (FFlag::DebugLuauDeferredConstraintResolution) - { - CHECK_EQ(toString(result.errors[0]), - "Type '((number?) -> {| p: number, q: number |}) & ((string?) -> {| p: number, r: number |})' could not be converted into " - "'(number?) -> {| p: number, q: number, r: number |}'; none of the intersection parts are compatible"); - } - else - { - CHECK_EQ(toString(result.errors[0]), - "Type '((number?) -> {| p: number |} & {| q: number |}) & ((string?) -> {| p: number |} & {| r: number |})' could not be converted into " - "'(number?) -> {| p: number, q: number, r: number |}'; none of the intersection parts are compatible"); - } + CHECK_EQ(toString(result.errors[0]), + "Type '((number?) -> {| p: number |} & {| q: number |}) & ((string?) -> {| p: number |} & {| r: number |})' could not be converted into " + "'(number?) -> {| p: number, q: number, r: number |}'; none of the intersection parts are compatible"); } TEST_CASE_FIXTURE(Fixture, "overloaded_functions_mentioning_generic") @@ -917,7 +840,8 @@ TEST_CASE_FIXTURE(Fixture, "less_greedy_unification_with_intersection_types") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("(never) -> never", toString(requireType("f"))); + // TODO? We do not simplify types from explicit annotations. + CHECK_EQ("({| x: number |} & {| x: string |}) -> {| x: number |} & {| x: string |}", toString(requireType("f"))); } TEST_CASE_FIXTURE(Fixture, "less_greedy_unification_with_intersection_types_2") @@ -933,7 +857,7 @@ TEST_CASE_FIXTURE(Fixture, "less_greedy_unification_with_intersection_types_2") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("(never) -> never", toString(requireType("f"))); + CHECK_EQ("({| x: number |} & {| x: string |}) -> never", toString(requireType("f"))); } TEST_SUITE_END(); diff --git a/tests/TypeInfer.operators.test.cpp b/tests/TypeInfer.operators.test.cpp index 90436ce7d..dd26cc86b 100644 --- a/tests/TypeInfer.operators.test.cpp +++ b/tests/TypeInfer.operators.test.cpp @@ -676,9 +676,19 @@ TEST_CASE_FIXTURE(Fixture, "strict_binary_op_where_lhs_unknown") src += "end"; CheckResult result = check(src); - LUAU_REQUIRE_ERROR_COUNT(ops.size(), result); - CHECK_EQ("Unknown type used in + operation; consider adding a type annotation to 'a'", toString(result.errors[0])); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + // TODO: This will eventually entirely go away, but for now the Add + // family will ensure there's one less error. + LUAU_REQUIRE_ERROR_COUNT(ops.size() - 1, result); + CHECK_EQ("Unknown type used in - operation; consider adding a type annotation to 'a'", toString(result.errors[0])); + } + else + { + LUAU_REQUIRE_ERROR_COUNT(ops.size(), result); + CHECK_EQ("Unknown type used in + operation; consider adding a type annotation to 'a'", toString(result.errors[0])); + } } TEST_CASE_FIXTURE(BuiltinsFixture, "and_binexps_dont_unify") @@ -889,8 +899,16 @@ TEST_CASE_FIXTURE(Fixture, "infer_any_in_all_modes_when_lhs_is_unknown") end )"); - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), "Unknown type used in + operation; consider adding a type annotation to 'x'"); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + LUAU_REQUIRE_NO_ERRORS(result); + CHECK(toString(requireType("f")) == "(a, b) -> Add"); + } + else + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Unknown type used in + operation; consider adding a type annotation to 'x'"); + } result = check(Mode::Nonstrict, R"( local function f(x, y) @@ -985,31 +1003,6 @@ TEST_CASE_FIXTURE(Fixture, "unrelated_primitives_cannot_be_compared") LUAU_REQUIRE_ERROR_COUNT(1, result); } -TEST_CASE_FIXTURE(BuiltinsFixture, "mm_ops_must_return_a_value") -{ - if (!FFlag::DebugLuauDeferredConstraintResolution) - return; - - CheckResult result = check(R"( - local mm = { - __add = function(self, other) - return - end, - } - - local x = setmetatable({}, mm) - local y = x + 123 - )"); - - LUAU_REQUIRE_ERROR_COUNT(2, result); - - CHECK(requireType("y") == builtinTypes->errorRecoveryType()); - - const GenericError* ge = get(result.errors[1]); - REQUIRE(ge); - CHECK(ge->message == "Metamethod '__add' must return a value"); -} - TEST_CASE_FIXTURE(BuiltinsFixture, "mm_comparisons_must_return_a_boolean") { if (!FFlag::DebugLuauDeferredConstraintResolution) @@ -1179,6 +1172,38 @@ end LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(BuiltinsFixture, "luau-polyfill.String.slice") +{ + + CheckResult result = check(R"( +--!strict +local function slice(str: string, startIndexStr: string | number, lastIndexStr: (string | number)?): string + local strLen, invalidBytePosition = utf8.len(str) + assert(strLen ~= nil, ("string `%s` has an invalid byte at position %s"):format(str, tostring(invalidBytePosition))) + local startIndex = tonumber(startIndexStr) + + + -- if no last index length set, go to str length + 1 + local lastIndex = strLen + 1 + + assert(typeof(lastIndex) == "number", "lastIndexStr should convert to number") + + if lastIndex > strLen then + lastIndex = strLen + 1 + end + + local startIndexByte = utf8.offset(str, startIndex) + + return string.sub(str, startIndexByte, startIndexByte) +end + +return slice + + + )"); + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_CASE_FIXTURE(BuiltinsFixture, "luau-polyfill.Array.startswith") { // This test also exercises whether the binary operator == passes the correct expected type @@ -1204,5 +1229,24 @@ return startsWith LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "add_type_family_works") +{ + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + + CheckResult result = check(R"( + local function add(x, y) + return x + y + end + + local a = add(1, 2) + local b = add("foo", "bar") + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK(toString(requireType("a")) == "number"); + CHECK(toString(requireType("b")) == "Add"); + CHECK(toString(result.errors[0]) == "Type family instance Add is uninhabited"); +} TEST_SUITE_END(); diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index 606a4f4af..885a9781c 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -532,7 +532,7 @@ return wrapStrictTable(Constants, "Constants") std::optional result = first(m->returnType); REQUIRE(result); if (FFlag::DebugLuauDeferredConstraintResolution) - CHECK_EQ("(any?) & ~table", toString(*result)); + CHECK_EQ("(any & ~table)?", toString(*result)); else CHECK_MESSAGE(get(*result), *result); } @@ -819,4 +819,61 @@ TEST_CASE_FIXTURE(Fixture, "lookup_prop_of_intersection_containing_unions_of_tab // CHECK("variable" == unknownProp->key); } +TEST_CASE_FIXTURE(Fixture, "expected_type_should_be_a_helpful_deduction_guide_for_function_calls") +{ + ScopedFastFlag sffs[]{ + {"LuauUnifyTwoOptions", true}, + {"LuauTypeMismatchInvarianceInError", true}, + }; + + CheckResult result = check(R"( + type Ref = { val: T } + + local function useRef(x: T): Ref + return { val = x } + end + + local x: Ref = useRef(nil) + )"); + + if (FFlag::DebugLuauDeferredConstraintResolution) + { + // This is actually wrong! Sort of. It's doing the wrong thing, it's actually asking whether + // `{| val: number? |} <: {| val: nil |}` + // instead of the correct way, which is + // `{| val: nil |} <: {| val: number? |}` + LUAU_REQUIRE_NO_ERRORS(result); + } + else + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ(toString(result.errors[0]), R"(Type 'Ref' could not be converted into 'Ref' +caused by: + Property 'val' is not compatible. Type 'nil' could not be converted into 'number' in an invariant context)"); + } +} + +TEST_CASE_FIXTURE(Fixture, "floating_generics_should_not_be_allowed") +{ + CheckResult result = check(R"( + local assign : (target: T, source0: U?, source1: V?, source2: W?, ...any) -> T & U & V & W = (nil :: any) + + -- We have a big problem here: The generics U, V, and W are not bound to anything! + -- Things get strange because of this. + local benchmark = assign({}) + local options = benchmark.options + do + local resolve2: any = nil + options.fn({ + resolve = function(...) + resolve2(...) + end, + }) + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index c55497ae4..3b0654a04 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -1020,16 +1020,8 @@ TEST_CASE_FIXTURE(Fixture, "discriminate_tag") LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::DebugLuauDeferredConstraintResolution) - { - CHECK_EQ(R"({| catfood: string, name: string, tag: "Cat" |})", toString(requireTypeAtPosition({7, 33}))); - CHECK_EQ(R"({| dogfood: string, name: string, tag: "Dog" |})", toString(requireTypeAtPosition({9, 33}))); - } - else - { - CHECK_EQ("Cat", toString(requireTypeAtPosition({7, 33}))); - CHECK_EQ("Dog", toString(requireTypeAtPosition({9, 33}))); - } + CHECK_EQ("Cat", toString(requireTypeAtPosition({7, 33}))); + CHECK_EQ("Dog", toString(requireTypeAtPosition({9, 33}))); } TEST_CASE_FIXTURE(Fixture, "discriminate_tag_with_implicit_else") @@ -1050,16 +1042,8 @@ TEST_CASE_FIXTURE(Fixture, "discriminate_tag_with_implicit_else") LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::DebugLuauDeferredConstraintResolution) - { - CHECK_EQ(R"({| catfood: string, name: string, tag: "Cat" |})", toString(requireTypeAtPosition({7, 33}))); - CHECK_EQ(R"({| dogfood: string, name: string, tag: "Dog" |})", toString(requireTypeAtPosition({9, 33}))); - } - else - { - CHECK_EQ("Cat", toString(requireTypeAtPosition({7, 33}))); - CHECK_EQ("Dog", toString(requireTypeAtPosition({9, 33}))); - } + CHECK_EQ("Cat", toString(requireTypeAtPosition({7, 33}))); + CHECK_EQ("Dog", toString(requireTypeAtPosition({9, 33}))); } TEST_CASE_FIXTURE(Fixture, "and_or_peephole_refinement") @@ -1403,7 +1387,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "refine_unknowns") if (FFlag::DebugLuauDeferredConstraintResolution) { CHECK_EQ("string", toString(requireTypeAtPosition({3, 28}))); - CHECK_EQ("~string", toString(requireTypeAtPosition({5, 28}))); + CHECK_EQ("unknown & ~string", toString(requireTypeAtPosition({5, 28}))); } else { @@ -1508,14 +1492,7 @@ local _ = _ ~= _ or _ or _ end )"); - if (FFlag::DebugLuauDeferredConstraintResolution) - { - // Without a realistic motivating case, it's hard to tell if it's important for this to work without errors. - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK(get(result.errors[0])); - } - else - LUAU_REQUIRE_NO_ERRORS(result); + LUAU_REQUIRE_NO_ERRORS(result); } TEST_CASE_FIXTURE(BuiltinsFixture, "refine_unknown_to_table_then_take_the_length") @@ -1615,7 +1592,7 @@ TEST_CASE_FIXTURE(Fixture, "refine_a_property_of_some_global") LUAU_REQUIRE_ERROR_COUNT(3, result); - CHECK_EQ("~false & ~nil", toString(requireTypeAtPosition({4, 30}))); + CHECK_EQ("~(false?)", toString(requireTypeAtPosition({4, 30}))); } TEST_CASE_FIXTURE(BuiltinsFixture, "dataflow_analysis_can_tell_refinements_when_its_appropriate_to_refine_into_nil_or_never") diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index 4b24fb225..82a20bc1a 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -1059,11 +1059,11 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "unification_of_unions_in_a_self_referential_ const MetatableType* amtv = get(requireType("a")); REQUIRE(amtv); - CHECK_EQ(amtv->metatable, requireType("amt")); + CHECK_EQ(follow(amtv->metatable), follow(requireType("amt"))); const MetatableType* bmtv = get(requireType("b")); REQUIRE(bmtv); - CHECK_EQ(bmtv->metatable, requireType("bmt")); + CHECK_EQ(follow(bmtv->metatable), follow(requireType("bmt"))); } TEST_CASE_FIXTURE(BuiltinsFixture, "oop_polymorphic") diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index cbb04cba9..829f993a4 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -267,10 +267,7 @@ TEST_CASE_FIXTURE(Fixture, "should_be_able_to_infer_this_without_stack_overflowi end )"); - if (FFlag::DebugLuauDeferredConstraintResolution) - LUAU_REQUIRE_ERROR_COUNT(1, result); - else - LUAU_REQUIRE_NO_ERRORS(result); + LUAU_REQUIRE_NO_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "exponential_blowup_from_copying_types") diff --git a/tests/TypeInfer.typePacks.cpp b/tests/TypeInfer.typePacks.cpp index 441191664..afe0552cc 100644 --- a/tests/TypeInfer.typePacks.cpp +++ b/tests/TypeInfer.typePacks.cpp @@ -1060,4 +1060,14 @@ end LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "type_param_overflow") +{ + CheckResult result = check(R"( + type Two = { a: T, b: U } + local x: Two = { a = 1, b = 'c' } + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp index 960d6f15b..100abfb7f 100644 --- a/tests/TypeInfer.unionTypes.test.cpp +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -354,10 +354,7 @@ a.x = 2 LUAU_REQUIRE_ERROR_COUNT(1, result); auto s = toString(result.errors[0]); - if (FFlag::DebugLuauDeferredConstraintResolution) - CHECK_EQ("Value of type '{| x: number, y: number |}?' could be nil", s); - else - CHECK_EQ("Value of type '({| x: number |} & {| y: number |})?' could be nil", s); + CHECK_EQ("Value of type '({| x: number |} & {| y: number |})?' could be nil", s); } TEST_CASE_FIXTURE(Fixture, "optional_length_error") @@ -870,4 +867,50 @@ TEST_CASE_FIXTURE(Fixture, "optional_class_instances_are_invariant") CHECK(expectedError == toString(result.errors[0])); } +TEST_CASE_FIXTURE(BuiltinsFixture, "luau-polyfill.Map.entries") +{ + + fileResolver.source["Module/Map"] = R"( +--!strict + +type Object = { [any]: any } +type Array = { [number]: T } +type Table = { [T]: V } +type Tuple = Array + +local Map = {} + +export type Map = { + size: number, + -- method definitions + set: (self: Map, K, V) -> Map, + get: (self: Map, K) -> V | nil, + clear: (self: Map) -> (), + delete: (self: Map, K) -> boolean, + has: (self: Map, K) -> boolean, + keys: (self: Map) -> Array, + values: (self: Map) -> Array, + entries: (self: Map) -> Array>, + ipairs: (self: Map) -> any, + [K]: V, + _map: { [K]: V }, + _array: { [number]: K }, +} + +function Map:entries() + return {} +end + +local function coerceToTable(mapLike: Map | Table): Array> + local e = mapLike:entries(); + return e +end + + )"; + + CheckResult result = frontend.check("Module/Map"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeReduction.test.cpp b/tests/TypeReduction.test.cpp deleted file mode 100644 index 5f11a71b7..000000000 --- a/tests/TypeReduction.test.cpp +++ /dev/null @@ -1,1509 +0,0 @@ -// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "Luau/TypeReduction.h" - -#include "Fixture.h" -#include "doctest.h" - -using namespace Luau; - -namespace -{ -struct ReductionFixture : Fixture -{ - TypeReductionOptions typeReductionOpts{/* allowTypeReductionsFromOtherArenas */ true}; - ToStringOptions toStringOpts{true}; - - TypeArena arena; - InternalErrorReporter iceHandler; - UnifierSharedState unifierState{&iceHandler}; - TypeReduction reduction{NotNull{&arena}, builtinTypes, NotNull{&iceHandler}, typeReductionOpts}; - - ReductionFixture() - { - registerHiddenTypes(&frontend); - createSomeClasses(&frontend); - } - - TypeId reductionof(TypeId ty) - { - std::optional reducedTy = reduction.reduce(ty); - REQUIRE(reducedTy); - return *reducedTy; - } - - std::optional tryReduce(const std::string& annotation) - { - check("type _Res = " + annotation); - return reduction.reduce(requireTypeAlias("_Res")); - } - - TypeId reductionof(const std::string& annotation) - { - check("type _Res = " + annotation); - return reductionof(requireTypeAlias("_Res")); - } - - std::string toStringFull(TypeId ty) - { - return toString(ty, toStringOpts); - } -}; -} // namespace - -TEST_SUITE_BEGIN("TypeReductionTests"); - -TEST_CASE_FIXTURE(ReductionFixture, "cartesian_product_exceeded") -{ - ScopedFastInt sfi{"LuauTypeReductionCartesianProductLimit", 5}; - - CheckResult result = check(R"( - type T - = string - & (number | string | boolean) - & (number | string | boolean) - )"); - - CHECK(!reduction.reduce(requireTypeAlias("T"))); - // LUAU_REQUIRE_ERROR_COUNT(1, result); - // CHECK("Code is too complex to typecheck! Consider simplifying the code around this area" == toString(result.errors[0])); -} - -TEST_CASE_FIXTURE(ReductionFixture, "cartesian_product_exceeded_with_normal_limit") -{ - CheckResult result = check(R"( - type T - = string -- 1 = 1 - & (number | string | boolean) -- 1 * 3 = 3 - & (number | string | boolean) -- 3 * 3 = 9 - & (number | string | boolean) -- 9 * 3 = 27 - & (number | string | boolean) -- 27 * 3 = 81 - & (number | string | boolean) -- 81 * 3 = 243 - & (number | string | boolean) -- 243 * 3 = 729 - & (number | string | boolean) -- 729 * 3 = 2187 - & (number | string | boolean) -- 2187 * 3 = 6561 - & (number | string | boolean) -- 6561 * 3 = 19683 - & (number | string | boolean) -- 19683 * 3 = 59049 - & (number | string) -- 59049 * 2 = 118098 - )"); - - CHECK(!reduction.reduce(requireTypeAlias("T"))); - // LUAU_REQUIRE_ERROR_COUNT(1, result); - // CHECK("Code is too complex to typecheck! Consider simplifying the code around this area" == toString(result.errors[0])); -} - -TEST_CASE_FIXTURE(ReductionFixture, "cartesian_product_is_zero") -{ - ScopedFastInt sfi{"LuauTypeReductionCartesianProductLimit", 5}; - - CheckResult result = check(R"( - type T - = string - & (number | string | boolean) - & (number | string | boolean) - & never - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(ReductionFixture, "stress_test_recursion_limits") -{ - TypeId ty = arena.addType(IntersectionType{{builtinTypes->numberType, builtinTypes->stringType}}); - for (size_t i = 0; i < 20'000; ++i) - { - TableType table; - table.state = TableState::Sealed; - table.props["x"] = {ty}; - ty = arena.addType(IntersectionType{{arena.addType(table), arena.addType(table)}}); - } - - CHECK(!reduction.reduce(ty)); -} - -TEST_CASE_FIXTURE(ReductionFixture, "caching") -{ - SUBCASE("free_tables") - { - TypeId ty1 = arena.addType(TableType{}); - getMutable(ty1)->state = TableState::Free; - getMutable(ty1)->props["x"] = {builtinTypes->stringType}; - - TypeId ty2 = arena.addType(TableType{}); - getMutable(ty2)->state = TableState::Sealed; - - TypeId intersectionTy = arena.addType(IntersectionType{{ty1, ty2}}); - - CHECK("{- x: string -} & {| |}" == toStringFull(reductionof(intersectionTy))); - - getMutable(ty1)->state = TableState::Sealed; - CHECK("{| x: string |}" == toStringFull(reductionof(intersectionTy))); - } - - SUBCASE("unsealed_tables") - { - TypeId ty1 = arena.addType(TableType{}); - getMutable(ty1)->state = TableState::Unsealed; - getMutable(ty1)->props["x"] = {builtinTypes->stringType}; - - TypeId ty2 = arena.addType(TableType{}); - getMutable(ty2)->state = TableState::Sealed; - - TypeId intersectionTy = arena.addType(IntersectionType{{ty1, ty2}}); - - CHECK("{| x: string |}" == toStringFull(reductionof(intersectionTy))); - - getMutable(ty1)->state = TableState::Sealed; - CHECK("{| x: string |}" == toStringFull(reductionof(intersectionTy))); - } - - SUBCASE("free_types") - { - TypeId ty1 = arena.freshType(nullptr); - TypeId ty2 = arena.addType(TableType{}); - getMutable(ty2)->state = TableState::Sealed; - - TypeId intersectionTy = arena.addType(IntersectionType{{ty1, ty2}}); - - CHECK("a & {| |}" == toStringFull(reductionof(intersectionTy))); - - *asMutable(ty1) = BoundType{ty2}; - CHECK("{| |}" == toStringFull(reductionof(intersectionTy))); - } - - SUBCASE("we_can_see_that_the_cache_works_if_we_mutate_a_normally_not_mutated_type") - { - TypeId ty1 = arena.addType(BoundType{builtinTypes->stringType}); - TypeId ty2 = builtinTypes->numberType; - - TypeId intersectionTy = arena.addType(IntersectionType{{ty1, ty2}}); - - CHECK("never" == toStringFull(reductionof(intersectionTy))); // Bound & number ~ never - - *asMutable(ty1) = BoundType{ty2}; - CHECK("never" == toStringFull(reductionof(intersectionTy))); // Bound & number ~ number, but the cache is `never`. - } - - SUBCASE("ptr_eq_irreducible_unions") - { - TypeId unionTy = arena.addType(UnionType{{builtinTypes->stringType, builtinTypes->numberType}}); - TypeId reducedTy = reductionof(unionTy); - REQUIRE(unionTy == reducedTy); - } - - SUBCASE("ptr_eq_irreducible_intersections") - { - TypeId intersectionTy = arena.addType(IntersectionType{{builtinTypes->stringType, arena.addType(GenericType{"G"})}}); - TypeId reducedTy = reductionof(intersectionTy); - REQUIRE(intersectionTy == reducedTy); - } - - SUBCASE("ptr_eq_free_table") - { - TypeId tableTy = arena.addType(TableType{}); - getMutable(tableTy)->state = TableState::Free; - - TypeId reducedTy = reductionof(tableTy); - REQUIRE(tableTy == reducedTy); - } - - SUBCASE("ptr_eq_unsealed_table") - { - TypeId tableTy = arena.addType(TableType{}); - getMutable(tableTy)->state = TableState::Unsealed; - - TypeId reducedTy = reductionof(tableTy); - REQUIRE(tableTy == reducedTy); - } -} // caching - -TEST_CASE_FIXTURE(ReductionFixture, "intersections_without_negations") -{ - SUBCASE("string_and_string") - { - TypeId ty = reductionof("string & string"); - CHECK("string" == toStringFull(ty)); - } - - SUBCASE("never_and_string") - { - TypeId ty = reductionof("never & string"); - CHECK("never" == toStringFull(ty)); - } - - SUBCASE("string_and_never") - { - TypeId ty = reductionof("string & never"); - CHECK("never" == toStringFull(ty)); - } - - SUBCASE("unknown_and_string") - { - TypeId ty = reductionof("unknown & string"); - CHECK("string" == toStringFull(ty)); - } - - SUBCASE("string_and_unknown") - { - TypeId ty = reductionof("string & unknown"); - CHECK("string" == toStringFull(ty)); - } - - SUBCASE("any_and_string") - { - TypeId ty = reductionof("any & string"); - CHECK("string" == toStringFull(ty)); - } - - SUBCASE("string_and_any") - { - TypeId ty = reductionof("string & any"); - CHECK("string" == toStringFull(ty)); - } - - SUBCASE("string_or_number_and_string") - { - TypeId ty = reductionof("(string | number) & string"); - CHECK("string" == toStringFull(ty)); - } - - SUBCASE("string_and_string_or_number") - { - TypeId ty = reductionof("string & (string | number)"); - CHECK("string" == toStringFull(ty)); - } - - SUBCASE("string_and_a") - { - TypeId ty = reductionof(R"(string & "a")"); - CHECK(R"("a")" == toStringFull(ty)); - } - - SUBCASE("boolean_and_true") - { - TypeId ty = reductionof("boolean & true"); - CHECK("true" == toStringFull(ty)); - } - - SUBCASE("boolean_and_a") - { - TypeId ty = reductionof(R"(boolean & "a")"); - CHECK("never" == toStringFull(ty)); - } - - SUBCASE("a_and_a") - { - TypeId ty = reductionof(R"("a" & "a")"); - CHECK(R"("a")" == toStringFull(ty)); - } - - SUBCASE("a_and_b") - { - TypeId ty = reductionof(R"("a" & "b")"); - CHECK("never" == toStringFull(ty)); - } - - SUBCASE("a_and_true") - { - TypeId ty = reductionof(R"("a" & true)"); - CHECK("never" == toStringFull(ty)); - } - - SUBCASE("a_and_true") - { - TypeId ty = reductionof(R"(true & false)"); - CHECK("never" == toStringFull(ty)); - } - - SUBCASE("function_type_and_function") - { - TypeId ty = reductionof("() -> () & fun"); - CHECK("() -> ()" == toStringFull(ty)); - } - - SUBCASE("function_type_and_string") - { - TypeId ty = reductionof("() -> () & string"); - CHECK("never" == toStringFull(ty)); - } - - SUBCASE("parent_and_child") - { - TypeId ty = reductionof("Parent & Child"); - CHECK("Child" == toStringFull(ty)); - } - - SUBCASE("child_and_parent") - { - TypeId ty = reductionof("Child & Parent"); - CHECK("Child" == toStringFull(ty)); - } - - SUBCASE("child_and_unrelated") - { - TypeId ty = reductionof("Child & Unrelated"); - CHECK("never" == toStringFull(ty)); - } - - SUBCASE("string_and_table") - { - TypeId ty = reductionof("string & {}"); - CHECK("never" == toStringFull(ty)); - } - - SUBCASE("string_and_child") - { - TypeId ty = reductionof("string & Child"); - CHECK("never" == toStringFull(ty)); - } - - SUBCASE("string_and_function") - { - TypeId ty = reductionof("string & () -> ()"); - CHECK("never" == toStringFull(ty)); - } - - SUBCASE("function_and_table") - { - TypeId ty = reductionof("() -> () & {}"); - CHECK("never" == toStringFull(ty)); - } - - SUBCASE("function_and_class") - { - TypeId ty = reductionof("() -> () & Child"); - CHECK("never" == toStringFull(ty)); - } - - SUBCASE("function_and_function") - { - TypeId ty = reductionof("() -> () & () -> ()"); - CHECK("(() -> ()) & (() -> ())" == toStringFull(ty)); - } - - SUBCASE("table_and_table") - { - TypeId ty = reductionof("{} & {}"); - CHECK("{| |}" == toStringFull(ty)); - } - - SUBCASE("table_and_metatable") - { - // No setmetatable in ReductionFixture, so we mix and match. - BuiltinsFixture fixture; - fixture.check(R"( - type Ty = {} & typeof(setmetatable({}, {})) - )"); - - TypeId ty = reductionof(fixture.requireTypeAlias("Ty")); - CHECK("{ @metatable { }, { } } & {| |}" == toStringFull(ty)); - } - - SUBCASE("a_and_string") - { - TypeId ty = reductionof(R"("a" & string)"); - CHECK(R"("a")" == toStringFull(ty)); - } - - SUBCASE("reducible_function_and_function") - { - TypeId ty = reductionof("((string | string) -> (number | number)) & fun"); - CHECK("(string) -> number" == toStringFull(ty)); - } - - SUBCASE("string_and_error") - { - TypeId ty = reductionof("string & err"); - CHECK("*error-type* & string" == toStringFull(ty)); - } - - SUBCASE("table_p_string_and_table_p_number") - { - TypeId ty = reductionof("{ p: string } & { p: number }"); - CHECK("never" == toStringFull(ty)); - } - - SUBCASE("table_p_string_and_table_p_string") - { - TypeId ty = reductionof("{ p: string } & { p: string }"); - CHECK("{| p: string |}" == toStringFull(ty)); - } - - SUBCASE("table_x_table_p_string_and_table_x_table_p_number") - { - TypeId ty = reductionof("{ x: { p: string } } & { x: { p: number } }"); - CHECK("never" == toStringFull(ty)); - } - - SUBCASE("table_p_and_table_q") - { - TypeId ty = reductionof("{ p: string } & { q: number }"); - CHECK("{| p: string, q: number |}" == toStringFull(ty)); - } - - SUBCASE("table_tag_a_or_table_tag_b_and_table_b") - { - TypeId ty = reductionof("({ tag: string, a: number } | { tag: number, b: string }) & { b: string }"); - CHECK("{| a: number, b: string, tag: string |} | {| b: string, tag: number |}" == toStringFull(ty)); - } - - SUBCASE("table_string_number_indexer_and_table_string_number_indexer") - { - TypeId ty = reductionof("{ [string]: number } & { [string]: number }"); - CHECK("{| [string]: number |}" == toStringFull(ty)); - } - - SUBCASE("table_string_number_indexer_and_empty_table") - { - TypeId ty = reductionof("{ [string]: number } & {}"); - CHECK("{| [string]: number |}" == toStringFull(ty)); - } - - SUBCASE("empty_table_table_string_number_indexer") - { - TypeId ty = reductionof("{} & { [string]: number }"); - CHECK("{| [string]: number |}" == toStringFull(ty)); - } - - SUBCASE("string_number_indexer_and_number_number_indexer") - { - TypeId ty = reductionof("{ [string]: number } & { [number]: number }"); - CHECK("{number} & {| [string]: number |}" == toStringFull(ty)); - } - - SUBCASE("table_p_string_and_indexer_number_number") - { - TypeId ty = reductionof("{ p: string } & { [number]: number }"); - CHECK("{| [number]: number, p: string |}" == toStringFull(ty)); - } - - SUBCASE("table_p_string_and_indexer_string_number") - { - TypeId ty = reductionof("{ p: string } & { [string]: number }"); - CHECK("{| [string]: number, p: string |}" == toStringFull(ty)); - } - - SUBCASE("table_p_string_and_table_p_string_plus_indexer_string_number") - { - TypeId ty = reductionof("{ p: string } & { p: string, [string]: number }"); - CHECK("{| [string]: number, p: string |}" == toStringFull(ty)); - } - - SUBCASE("array_number_and_array_string") - { - TypeId ty = reductionof("{number} & {string}"); - CHECK("{never}" == toStringFull(ty)); - } - - SUBCASE("array_string_and_array_string") - { - TypeId ty = reductionof("{string} & {string}"); - CHECK("{string}" == toStringFull(ty)); - } - - SUBCASE("array_string_or_number_and_array_string") - { - TypeId ty = reductionof("{string | number} & {string}"); - CHECK("{string}" == toStringFull(ty)); - } - - SUBCASE("fresh_type_and_string") - { - TypeId freshTy = arena.freshType(nullptr); - TypeId ty = reductionof(arena.addType(IntersectionType{{freshTy, builtinTypes->stringType}})); - CHECK("a & string" == toStringFull(ty)); - } - - SUBCASE("string_and_fresh_type") - { - TypeId freshTy = arena.freshType(nullptr); - TypeId ty = reductionof(arena.addType(IntersectionType{{builtinTypes->stringType, freshTy}})); - CHECK("a & string" == toStringFull(ty)); - } - - SUBCASE("generic_and_string") - { - TypeId genericTy = arena.addType(GenericType{"G"}); - TypeId ty = reductionof(arena.addType(IntersectionType{{genericTy, builtinTypes->stringType}})); - CHECK("G & string" == toStringFull(ty)); - } - - SUBCASE("string_and_generic") - { - TypeId genericTy = arena.addType(GenericType{"G"}); - TypeId ty = reductionof(arena.addType(IntersectionType{{builtinTypes->stringType, genericTy}})); - CHECK("G & string" == toStringFull(ty)); - } - - SUBCASE("parent_and_child_or_parent_and_anotherchild_or_parent_and_unrelated") - { - TypeId ty = reductionof("Parent & (Child | AnotherChild | Unrelated)"); - CHECK("AnotherChild | Child" == toString(ty)); - } - - SUBCASE("parent_and_child_or_parent_and_anotherchild_or_parent_and_unrelated_2") - { - TypeId ty = reductionof("(Parent & Child) | (Parent & AnotherChild) | (Parent & Unrelated)"); - CHECK("AnotherChild | Child" == toString(ty)); - } - - SUBCASE("top_table_and_table") - { - TypeId ty = reductionof("tbl & {}"); - CHECK("{| |}" == toString(ty)); - } - - SUBCASE("top_table_and_non_table") - { - TypeId ty = reductionof("tbl & \"foo\""); - CHECK("never" == toString(ty)); - } - - SUBCASE("top_table_and_metatable") - { - BuiltinsFixture fixture; - registerHiddenTypes(&fixture.frontend); - fixture.check(R"( - type Ty = tbl & typeof(setmetatable({}, {})) - )"); - - TypeId ty = reductionof(fixture.requireTypeAlias("Ty")); - CHECK("{ @metatable { }, { } }" == toString(ty)); - } -} // intersections_without_negations - -TEST_CASE_FIXTURE(ReductionFixture, "intersections_with_negations") -{ - SUBCASE("nil_and_not_nil") - { - TypeId ty = reductionof("nil & Not"); - CHECK("never" == toStringFull(ty)); - } - - SUBCASE("nil_and_not_false") - { - TypeId ty = reductionof("nil & Not"); - CHECK("nil" == toStringFull(ty)); - } - - SUBCASE("string_or_nil_and_not_nil") - { - TypeId ty = reductionof("(string?) & Not"); - CHECK("string" == toStringFull(ty)); - } - - SUBCASE("string_or_nil_and_not_false_or_nil") - { - TypeId ty = reductionof("(string?) & Not"); - CHECK("string" == toStringFull(ty)); - } - - SUBCASE("string_or_nil_and_not_false_and_not_nil") - { - TypeId ty = reductionof("(string?) & Not & Not"); - CHECK("string" == toStringFull(ty)); - } - - SUBCASE("not_false_and_bool") - { - TypeId ty = reductionof("Not & boolean"); - CHECK("true" == toStringFull(ty)); - } - - SUBCASE("function_type_and_not_function") - { - TypeId ty = reductionof("() -> () & Not"); - CHECK("never" == toStringFull(ty)); - } - - SUBCASE("function_type_and_not_string") - { - TypeId ty = reductionof("() -> () & Not"); - CHECK("() -> ()" == toStringFull(ty)); - } - - SUBCASE("not_a_and_string_or_nil") - { - TypeId ty = reductionof(R"(Not<"a"> & (string | nil))"); - CHECK(R"((string & ~"a")?)" == toStringFull(ty)); - } - - SUBCASE("not_a_and_a") - { - TypeId ty = reductionof(R"(Not<"a"> & "a")"); - CHECK("never" == toStringFull(ty)); - } - - SUBCASE("not_a_and_b") - { - TypeId ty = reductionof(R"(Not<"a"> & "b")"); - CHECK(R"("b")" == toStringFull(ty)); - } - - SUBCASE("not_string_and_a") - { - TypeId ty = reductionof(R"(Not & "a")"); - CHECK("never" == toStringFull(ty)); - } - - SUBCASE("not_bool_and_true") - { - TypeId ty = reductionof("Not & true"); - CHECK("never" == toStringFull(ty)); - } - - SUBCASE("not_string_and_true") - { - TypeId ty = reductionof("Not & true"); - CHECK("true" == toStringFull(ty)); - } - - SUBCASE("parent_and_not_child") - { - TypeId ty = reductionof("Parent & Not"); - CHECK("Parent & ~Child" == toStringFull(ty)); - } - - SUBCASE("not_child_and_parent") - { - TypeId ty = reductionof("Not & Parent"); - CHECK("Parent & ~Child" == toStringFull(ty)); - } - - SUBCASE("child_and_not_parent") - { - TypeId ty = reductionof("Child & Not"); - CHECK("never" == toStringFull(ty)); - } - - SUBCASE("not_parent_and_child") - { - TypeId ty = reductionof("Not & Child"); - CHECK("never" == toStringFull(ty)); - } - - SUBCASE("not_parent_and_unrelated") - { - TypeId ty = reductionof("Not & Unrelated"); - CHECK("Unrelated" == toStringFull(ty)); - } - - SUBCASE("unrelated_and_not_parent") - { - TypeId ty = reductionof("Unrelated & Not"); - CHECK("Unrelated" == toStringFull(ty)); - } - - SUBCASE("not_unrelated_and_parent") - { - TypeId ty = reductionof("Not & Parent"); - CHECK("Parent" == toStringFull(ty)); - } - - SUBCASE("parent_and_not_unrelated") - { - TypeId ty = reductionof("Parent & Not"); - CHECK("Parent" == toStringFull(ty)); - } - - SUBCASE("reducible_function_and_not_function") - { - TypeId ty = reductionof("((string | string) -> (number | number)) & Not"); - CHECK("never" == toStringFull(ty)); - } - - SUBCASE("string_and_not_error") - { - TypeId ty = reductionof("string & Not"); - CHECK("string" == toStringFull(ty)); - } - - SUBCASE("table_p_string_and_table_p_not_number") - { - TypeId ty = reductionof("{ p: string } & { p: Not }"); - CHECK("{| p: string |}" == toStringFull(ty)); - } - - SUBCASE("table_p_string_and_table_p_not_string") - { - TypeId ty = reductionof("{ p: string } & { p: Not }"); - CHECK("never" == toStringFull(ty)); - } - - SUBCASE("table_x_table_p_string_and_table_x_table_p_not_number") - { - TypeId ty = reductionof("{ x: { p: string } } & { x: { p: Not } }"); - CHECK("{| x: {| p: string |} |}" == toStringFull(ty)); - } - - SUBCASE("table_or_nil_and_truthy") - { - TypeId ty = reductionof("({ x: number | string }?) & Not"); - CHECK("{| x: number | string |}" == toString(ty)); - } - - SUBCASE("not_top_table_and_table") - { - TypeId ty = reductionof("Not & {}"); - CHECK("never" == toString(ty)); - } - - SUBCASE("not_top_table_and_metatable") - { - BuiltinsFixture fixture; - registerHiddenTypes(&fixture.frontend); - fixture.check(R"( - type Ty = Not & typeof(setmetatable({}, {})) - )"); - - TypeId ty = reductionof(fixture.requireTypeAlias("Ty")); - CHECK("never" == toString(ty)); - } -} // intersections_with_negations - -TEST_CASE_FIXTURE(ReductionFixture, "unions_without_negations") -{ - SUBCASE("never_or_string") - { - TypeId ty = reductionof("never | string"); - CHECK("string" == toStringFull(ty)); - } - - SUBCASE("string_or_never") - { - TypeId ty = reductionof("string | never"); - CHECK("string" == toStringFull(ty)); - } - - SUBCASE("unknown_or_string") - { - TypeId ty = reductionof("unknown | string"); - CHECK("unknown" == toStringFull(ty)); - } - - SUBCASE("string_or_unknown") - { - TypeId ty = reductionof("string | unknown"); - CHECK("unknown" == toStringFull(ty)); - } - - SUBCASE("any_or_string") - { - TypeId ty = reductionof("any | string"); - CHECK("any" == toStringFull(ty)); - } - - SUBCASE("string_or_any") - { - TypeId ty = reductionof("string | any"); - CHECK("any" == toStringFull(ty)); - } - - SUBCASE("string_or_string_and_number") - { - TypeId ty = reductionof("string | (string & number)"); - CHECK("string" == toStringFull(ty)); - } - - SUBCASE("string_or_string") - { - TypeId ty = reductionof("string | string"); - CHECK("string" == toStringFull(ty)); - } - - SUBCASE("string_or_number") - { - TypeId ty = reductionof("string | number"); - CHECK("number | string" == toStringFull(ty)); - } - - SUBCASE("number_or_string") - { - TypeId ty = reductionof("number | string"); - CHECK("number | string" == toStringFull(ty)); - } - - SUBCASE("string_or_number_or_string") - { - TypeId ty = reductionof("(string | number) | string"); - CHECK("number | string" == toStringFull(ty)); - } - - SUBCASE("string_or_number_or_string_2") - { - TypeId ty = reductionof("string | (number | string)"); - CHECK("number | string" == toStringFull(ty)); - } - - SUBCASE("string_or_string_or_number") - { - TypeId ty = reductionof("string | (string | number)"); - CHECK("number | string" == toStringFull(ty)); - } - - SUBCASE("string_or_string_or_number_or_boolean") - { - TypeId ty = reductionof("string | (string | number | boolean)"); - CHECK("boolean | number | string" == toStringFull(ty)); - } - - SUBCASE("string_or_string_or_boolean_or_number") - { - TypeId ty = reductionof("string | (string | boolean | number)"); - CHECK("boolean | number | string" == toStringFull(ty)); - } - - SUBCASE("string_or_boolean_or_string_or_number") - { - TypeId ty = reductionof("string | (boolean | string | number)"); - CHECK("boolean | number | string" == toStringFull(ty)); - } - - SUBCASE("boolean_or_string_or_number_or_string") - { - TypeId ty = reductionof("(boolean | string | number) | string"); - CHECK("boolean | number | string" == toStringFull(ty)); - } - - SUBCASE("boolean_or_true") - { - TypeId ty = reductionof("boolean | true"); - CHECK("boolean" == toStringFull(ty)); - } - - SUBCASE("boolean_or_false") - { - TypeId ty = reductionof("boolean | false"); - CHECK("boolean" == toStringFull(ty)); - } - - SUBCASE("boolean_or_true_or_false") - { - TypeId ty = reductionof("boolean | true | false"); - CHECK("boolean" == toStringFull(ty)); - } - - SUBCASE("string_or_a") - { - TypeId ty = reductionof(R"(string | "a")"); - CHECK("string" == toStringFull(ty)); - } - - SUBCASE("a_or_a") - { - TypeId ty = reductionof(R"("a" | "a")"); - CHECK(R"("a")" == toStringFull(ty)); - } - - SUBCASE("a_or_b") - { - TypeId ty = reductionof(R"("a" | "b")"); - CHECK(R"("a" | "b")" == toStringFull(ty)); - } - - SUBCASE("a_or_b_or_string") - { - TypeId ty = reductionof(R"("a" | "b" | string)"); - CHECK("string" == toStringFull(ty)); - } - - SUBCASE("unknown_or_any") - { - TypeId ty = reductionof("unknown | any"); - CHECK("unknown" == toStringFull(ty)); - } - - SUBCASE("any_or_unknown") - { - TypeId ty = reductionof("any | unknown"); - CHECK("unknown" == toStringFull(ty)); - } - - SUBCASE("function_type_or_function") - { - TypeId ty = reductionof("() -> () | fun"); - CHECK("function" == toStringFull(ty)); - } - - SUBCASE("function_or_string") - { - TypeId ty = reductionof("fun | string"); - CHECK("function | string" == toStringFull(ty)); - } - - SUBCASE("parent_or_child") - { - TypeId ty = reductionof("Parent | Child"); - CHECK("Parent" == toStringFull(ty)); - } - - SUBCASE("child_or_parent") - { - TypeId ty = reductionof("Child | Parent"); - CHECK("Parent" == toStringFull(ty)); - } - - SUBCASE("parent_or_unrelated") - { - TypeId ty = reductionof("Parent | Unrelated"); - CHECK("Parent | Unrelated" == toStringFull(ty)); - } - - SUBCASE("parent_or_child_or_unrelated") - { - TypeId ty = reductionof("Parent | Child | Unrelated"); - CHECK("Parent | Unrelated" == toStringFull(ty)); - } - - SUBCASE("parent_or_unrelated_or_child") - { - TypeId ty = reductionof("Parent | Unrelated | Child"); - CHECK("Parent | Unrelated" == toStringFull(ty)); - } - - SUBCASE("parent_or_child_or_unrelated_or_child") - { - TypeId ty = reductionof("Parent | Child | Unrelated | Child"); - CHECK("Parent | Unrelated" == toStringFull(ty)); - } - - SUBCASE("string_or_true") - { - TypeId ty = reductionof("string | true"); - CHECK("string | true" == toStringFull(ty)); - } - - SUBCASE("string_or_function") - { - TypeId ty = reductionof("string | () -> ()"); - CHECK("(() -> ()) | string" == toStringFull(ty)); - } - - SUBCASE("string_or_err") - { - TypeId ty = reductionof("string | err"); - CHECK("*error-type* | string" == toStringFull(ty)); - } - - SUBCASE("top_table_or_table") - { - TypeId ty = reductionof("tbl | {}"); - CHECK("table" == toString(ty)); - } - - SUBCASE("top_table_or_metatable") - { - BuiltinsFixture fixture; - registerHiddenTypes(&fixture.frontend); - fixture.check(R"( - type Ty = tbl | typeof(setmetatable({}, {})) - )"); - - TypeId ty = reductionof(fixture.requireTypeAlias("Ty")); - CHECK("table" == toString(ty)); - } - - SUBCASE("top_table_or_non_table") - { - TypeId ty = reductionof("tbl | number"); - CHECK("number | table" == toString(ty)); - } -} // unions_without_negations - -TEST_CASE_FIXTURE(ReductionFixture, "unions_with_negations") -{ - SUBCASE("string_or_not_string") - { - TypeId ty = reductionof("string | Not"); - CHECK("unknown" == toStringFull(ty)); - } - - SUBCASE("not_string_or_string") - { - TypeId ty = reductionof("Not | string"); - CHECK("unknown" == toStringFull(ty)); - } - - SUBCASE("not_number_or_string") - { - TypeId ty = reductionof("Not | string"); - CHECK("~number" == toStringFull(ty)); - } - - SUBCASE("string_or_not_number") - { - TypeId ty = reductionof("string | Not"); - CHECK("~number" == toStringFull(ty)); - } - - SUBCASE("not_hi_or_string_and_not_hi") - { - TypeId ty = reductionof(R"(Not<"hi"> | (string & Not<"hi">))"); - CHECK(R"(~"hi")" == toStringFull(ty)); - } - - SUBCASE("string_and_not_hi_or_not_hi") - { - TypeId ty = reductionof(R"((string & Not<"hi">) | Not<"hi">)"); - CHECK(R"(~"hi")" == toStringFull(ty)); - } - - SUBCASE("string_or_not_never") - { - TypeId ty = reductionof("string | Not"); - CHECK("unknown" == toStringFull(ty)); - } - - SUBCASE("not_a_or_not_a") - { - TypeId ty = reductionof(R"(Not<"a"> | Not<"a">)"); - CHECK(R"(~"a")" == toStringFull(ty)); - } - - SUBCASE("not_a_or_a") - { - TypeId ty = reductionof(R"(Not<"a"> | "a")"); - CHECK("unknown" == toStringFull(ty)); - } - - SUBCASE("a_or_not_a") - { - TypeId ty = reductionof(R"("a" | Not<"a">)"); - CHECK("unknown" == toStringFull(ty)); - } - - SUBCASE("not_a_or_string") - { - TypeId ty = reductionof(R"(Not<"a"> | string)"); - CHECK("unknown" == toStringFull(ty)); - } - - SUBCASE("string_or_not_a") - { - TypeId ty = reductionof(R"(string | Not<"a">)"); - CHECK("unknown" == toStringFull(ty)); - } - - SUBCASE("not_string_or_a") - { - TypeId ty = reductionof(R"(Not | "a")"); - CHECK(R"("a" | ~string)" == toStringFull(ty)); - } - - SUBCASE("a_or_not_string") - { - TypeId ty = reductionof(R"("a" | Not)"); - CHECK(R"("a" | ~string)" == toStringFull(ty)); - } - - SUBCASE("not_number_or_a") - { - TypeId ty = reductionof(R"(Not | "a")"); - CHECK("~number" == toStringFull(ty)); - } - - SUBCASE("a_or_not_number") - { - TypeId ty = reductionof(R"("a" | Not)"); - CHECK("~number" == toStringFull(ty)); - } - - SUBCASE("not_a_or_not_b") - { - TypeId ty = reductionof(R"(Not<"a"> | Not<"b">)"); - CHECK("unknown" == toStringFull(ty)); - } - - SUBCASE("boolean_or_not_false") - { - TypeId ty = reductionof("boolean | Not"); - CHECK("unknown" == toStringFull(ty)); - } - - SUBCASE("boolean_or_not_true") - { - TypeId ty = reductionof("boolean | Not"); - CHECK("unknown" == toStringFull(ty)); - } - - SUBCASE("false_or_not_false") - { - TypeId ty = reductionof("false | Not"); - CHECK("unknown" == toStringFull(ty)); - } - - SUBCASE("true_or_not_false") - { - TypeId ty = reductionof("true | Not"); - CHECK("~false" == toStringFull(ty)); - } - - SUBCASE("not_boolean_or_true") - { - TypeId ty = reductionof("Not | true"); - CHECK("~false" == toStringFull(ty)); - } - - SUBCASE("not_false_or_not_boolean") - { - TypeId ty = reductionof("Not | Not"); - CHECK("~false" == toStringFull(ty)); - } - - SUBCASE("function_type_or_not_function") - { - TypeId ty = reductionof("() -> () | Not"); - CHECK("(() -> ()) | ~function" == toStringFull(ty)); - } - - SUBCASE("not_parent_or_child") - { - TypeId ty = reductionof("Not | Child"); - CHECK("Child | ~Parent" == toStringFull(ty)); - } - - SUBCASE("child_or_not_parent") - { - TypeId ty = reductionof("Child | Not"); - CHECK("Child | ~Parent" == toStringFull(ty)); - } - - SUBCASE("parent_or_not_child") - { - TypeId ty = reductionof("Parent | Not"); - CHECK("unknown" == toStringFull(ty)); - } - - SUBCASE("not_child_or_parent") - { - TypeId ty = reductionof("Not | Parent"); - CHECK("unknown" == toStringFull(ty)); - } - - SUBCASE("parent_or_not_unrelated") - { - TypeId ty = reductionof("Parent | Not"); - CHECK("~Unrelated" == toStringFull(ty)); - } - - SUBCASE("not_string_or_string_and_not_a") - { - TypeId ty = reductionof(R"(Not | (string & Not<"a">))"); - CHECK(R"(~"a")" == toStringFull(ty)); - } - - SUBCASE("not_string_or_not_string") - { - TypeId ty = reductionof("Not | Not"); - CHECK("~string" == toStringFull(ty)); - } - - SUBCASE("not_string_or_not_number") - { - TypeId ty = reductionof("Not | Not"); - CHECK("unknown" == toStringFull(ty)); - } - - SUBCASE("not_a_or_not_boolean") - { - TypeId ty = reductionof(R"(Not<"a"> | Not)"); - CHECK("unknown" == toStringFull(ty)); - } - - SUBCASE("not_a_or_boolean") - { - TypeId ty = reductionof(R"(Not<"a"> | boolean)"); - CHECK(R"(~"a")" == toStringFull(ty)); - } - - SUBCASE("string_or_err") - { - TypeId ty = reductionof("string | Not"); - CHECK("string | ~*error-type*" == toStringFull(ty)); - } - - SUBCASE("not_top_table_or_table") - { - TypeId ty = reductionof("Not | {}"); - CHECK("{| |} | ~table" == toString(ty)); - } - - SUBCASE("not_top_table_or_metatable") - { - BuiltinsFixture fixture; - registerHiddenTypes(&fixture.frontend); - fixture.check(R"( - type Ty = Not | typeof(setmetatable({}, {})) - )"); - - TypeId ty = reductionof(fixture.requireTypeAlias("Ty")); - CHECK("{ @metatable { }, { } } | ~table" == toString(ty)); - } -} // unions_with_negations - -TEST_CASE_FIXTURE(ReductionFixture, "tables") -{ - SUBCASE("reduce_props") - { - TypeId ty = reductionof("{ x: string | string, y: number | number }"); - CHECK("{| x: string, y: number |}" == toStringFull(ty)); - } - - SUBCASE("reduce_indexers") - { - TypeId ty = reductionof("{ [string | string]: number | number }"); - CHECK("{| [string]: number |}" == toStringFull(ty)); - } - - SUBCASE("reduce_instantiated_type_parameters") - { - check(R"( - type Foo = { x: T } - local foo: Foo = { x = "hello" } - )"); - - TypeId ty = reductionof(requireType("foo")); - CHECK("Foo" == toString(ty)); - } - - SUBCASE("reduce_instantiated_type_pack_parameters") - { - check(R"( - type Foo = { x: () -> T... } - local foo: Foo = { x = function() return "hi", 5 end } - )"); - - TypeId ty = reductionof(requireType("foo")); - CHECK("Foo" == toString(ty)); - } - - SUBCASE("reduce_tables_within_tables") - { - TypeId ty = reductionof("{ x: { y: string & number } }"); - CHECK("never" == toStringFull(ty)); - } - - SUBCASE("array_of_never") - { - TypeId ty = reductionof("{never}"); - CHECK("{never}" == toStringFull(ty)); - } -} - -TEST_CASE_FIXTURE(ReductionFixture, "metatables") -{ - SUBCASE("reduce_table_part") - { - TableType table; - table.state = TableState::Sealed; - table.props["x"] = {arena.addType(UnionType{{builtinTypes->stringType, builtinTypes->stringType}})}; - TypeId tableTy = arena.addType(std::move(table)); - - TypeId ty = reductionof(arena.addType(MetatableType{tableTy, arena.addType(TableType{})})); - CHECK("{ @metatable { }, {| x: string |} }" == toStringFull(ty)); - } - - SUBCASE("reduce_metatable_part") - { - TableType table; - table.state = TableState::Sealed; - table.props["x"] = {arena.addType(UnionType{{builtinTypes->stringType, builtinTypes->stringType}})}; - TypeId tableTy = arena.addType(std::move(table)); - - TypeId ty = reductionof(arena.addType(MetatableType{arena.addType(TableType{}), tableTy})); - CHECK("{ @metatable {| x: string |}, { } }" == toStringFull(ty)); - } -} - -TEST_CASE_FIXTURE(ReductionFixture, "functions") -{ - SUBCASE("reduce_parameters") - { - TypeId ty = reductionof("(string | string) -> ()"); - CHECK("(string) -> ()" == toStringFull(ty)); - } - - SUBCASE("reduce_returns") - { - TypeId ty = reductionof("() -> (string | string)"); - CHECK("() -> string" == toStringFull(ty)); - } - - SUBCASE("reduce_parameters_and_returns") - { - TypeId ty = reductionof("(string | string) -> (number | number)"); - CHECK("(string) -> number" == toStringFull(ty)); - } - - SUBCASE("reduce_tail") - { - TypeId ty = reductionof("() -> ...(string | string)"); - CHECK("() -> (...string)" == toStringFull(ty)); - } - - SUBCASE("reduce_head_and_tail") - { - TypeId ty = reductionof("() -> (string | string, number | number, ...(boolean | boolean))"); - CHECK("() -> (string, number, ...boolean)" == toStringFull(ty)); - } - - SUBCASE("reduce_overloaded_functions") - { - TypeId ty = reductionof("((number | number) -> ()) & ((string | string) -> ())"); - CHECK("((number) -> ()) & ((string) -> ())" == toStringFull(ty)); - } -} // functions - -TEST_CASE_FIXTURE(ReductionFixture, "negations") -{ - SUBCASE("not_unknown") - { - TypeId ty = reductionof("Not"); - CHECK("never" == toStringFull(ty)); - } - - SUBCASE("not_never") - { - TypeId ty = reductionof("Not"); - CHECK("unknown" == toStringFull(ty)); - } - - SUBCASE("not_any") - { - TypeId ty = reductionof("Not"); - CHECK("any" == toStringFull(ty)); - } - - SUBCASE("not_not_reduction") - { - TypeId ty = reductionof("Not>"); - CHECK("never" == toStringFull(ty)); - } - - SUBCASE("not_string") - { - TypeId ty = reductionof("Not"); - CHECK("~string" == toStringFull(ty)); - } - - SUBCASE("not_string_or_number") - { - TypeId ty = reductionof("Not"); - CHECK("~number & ~string" == toStringFull(ty)); - } - - SUBCASE("not_string_and_number") - { - TypeId ty = reductionof("Not"); - CHECK("unknown" == toStringFull(ty)); - } - - SUBCASE("not_error") - { - TypeId ty = reductionof("Not"); - CHECK("~*error-type*" == toStringFull(ty)); - } -} // negations - -TEST_CASE_FIXTURE(ReductionFixture, "discriminable_unions") -{ - SUBCASE("cat_or_dog_and_dog") - { - TypeId ty = reductionof(R"(({ tag: "cat", catfood: string } | { tag: "dog", dogfood: string }) & { tag: "dog" })"); - CHECK(R"({| dogfood: string, tag: "dog" |})" == toStringFull(ty)); - } - - SUBCASE("cat_or_dog_and_not_dog") - { - TypeId ty = reductionof(R"(({ tag: "cat", catfood: string } | { tag: "dog", dogfood: string }) & { tag: Not<"dog"> })"); - CHECK(R"({| catfood: string, tag: "cat" |})" == toStringFull(ty)); - } - - SUBCASE("string_or_number_and_number") - { - TypeId ty = reductionof("({ tag: string, a: number } | { tag: number, b: string }) & { tag: string }"); - CHECK("{| a: number, tag: string |}" == toStringFull(ty)); - } - - SUBCASE("string_or_number_and_number") - { - TypeId ty = reductionof("({ tag: string, a: number } | { tag: number, b: string }) & { tag: number }"); - CHECK("{| b: string, tag: number |}" == toStringFull(ty)); - } - - SUBCASE("child_or_unrelated_and_parent") - { - TypeId ty = reductionof("({ tag: Child, x: number } | { tag: Unrelated, y: string }) & { tag: Parent }"); - CHECK("{| tag: Child, x: number |}" == toStringFull(ty)); - } - - SUBCASE("child_or_unrelated_and_not_parent") - { - TypeId ty = reductionof("({ tag: Child, x: number } | { tag: Unrelated, y: string }) & { tag: Not }"); - CHECK("{| tag: Unrelated, y: string |}" == toStringFull(ty)); - } -} - -TEST_CASE_FIXTURE(ReductionFixture, "cycles") -{ - SUBCASE("recursively_defined_function") - { - check("type F = (f: F) -> ()"); - - TypeId ty = reductionof(requireTypeAlias("F")); - CHECK("t1 where t1 = (t1) -> ()" == toStringFull(ty)); - } - - SUBCASE("recursively_defined_function_and_function") - { - check("type F = (f: F & fun) -> ()"); - - TypeId ty = reductionof(requireTypeAlias("F")); - CHECK("t1 where t1 = (function & t1) -> ()" == toStringFull(ty)); - } - - SUBCASE("recursively_defined_table") - { - check("type T = { x: T }"); - - TypeId ty = reductionof(requireTypeAlias("T")); - CHECK("t1 where t1 = {| x: t1 |}" == toStringFull(ty)); - } - - SUBCASE("recursively_defined_table_and_table") - { - check("type T = { x: T & {} }"); - - TypeId ty = reductionof(requireTypeAlias("T")); - CHECK("t1 where t1 = {| x: t1 & {| |} |}" == toStringFull(ty)); - } - - SUBCASE("recursively_defined_table_and_table_2") - { - check("type T = { x: T } & { x: number }"); - - TypeId ty = reductionof(requireTypeAlias("T")); - CHECK("t1 where t1 = {| x: number |} & {| x: t1 |}" == toStringFull(ty)); - } - - SUBCASE("recursively_defined_table_and_table_3") - { - check("type T = { x: T } & { x: T }"); - - TypeId ty = reductionof(requireTypeAlias("T")); - CHECK("t1 where t1 = {| x: t1 |} & {| x: t1 |}" == toStringFull(ty)); - } -} - -TEST_CASE_FIXTURE(ReductionFixture, "string_singletons") -{ - TypeId ty = reductionof("(string & Not<\"A\">)?"); - CHECK("(string & ~\"A\")?" == toStringFull(ty)); -} - -TEST_CASE_FIXTURE(ReductionFixture, "string_singletons_2") -{ - TypeId ty = reductionof("Not<\"A\"> & Not<\"B\"> & (string?)"); - CHECK("(string & ~\"A\" & ~\"B\")?" == toStringFull(ty)); -} - -TEST_SUITE_END(); diff --git a/tests/TypeVar.test.cpp b/tests/TypeVar.test.cpp index b7511267f..8cdd36ea3 100644 --- a/tests/TypeVar.test.cpp +++ b/tests/TypeVar.test.cpp @@ -2,7 +2,6 @@ #include "Luau/Scope.h" #include "Luau/Type.h" #include "Luau/TypeInfer.h" -#include "Luau/TypeReduction.h" #include "Luau/VisitType.h" #include "Fixture.h" diff --git a/tools/faillist.txt b/tools/faillist.txt index a26e5c9f9..fe3353a8b 100644 --- a/tools/faillist.txt +++ b/tools/faillist.txt @@ -1,13 +1,10 @@ -AnnotationTests.too_many_type_params AstQuery.last_argument_function_call_type -AutocompleteTest.autocomplete_response_perf1 BuiltinTests.aliased_string_format BuiltinTests.assert_removes_falsy_types BuiltinTests.assert_removes_falsy_types2 BuiltinTests.assert_removes_falsy_types_even_from_type_pack_tail_but_only_for_the_first_type BuiltinTests.assert_returns_false_and_string_iff_it_knows_the_first_argument_cannot_be_truthy BuiltinTests.bad_select_should_not_crash -BuiltinTests.dont_add_definitions_to_persistent_types BuiltinTests.gmatch_definition BuiltinTests.math_max_checks_for_numbers BuiltinTests.select_slightly_out_of_range @@ -22,7 +19,6 @@ BuiltinTests.string_format_tostring_specifier_type_constraint BuiltinTests.string_format_use_correct_argument2 DefinitionTests.class_definition_overload_metamethods DefinitionTests.class_definition_string_props -GenericsTests.apply_type_function_nested_generics2 GenericsTests.better_mismatch_error_messages GenericsTests.bound_tables_do_not_clone_original_fields GenericsTests.check_mutual_generic_functions @@ -35,6 +31,7 @@ GenericsTests.generic_functions_should_be_memory_safe GenericsTests.generic_type_pack_parentheses GenericsTests.higher_rank_polymorphism_should_not_accept_instantiated_arguments GenericsTests.infer_generic_function_function_argument_2 +GenericsTests.infer_generic_function_function_argument_3 GenericsTests.infer_generic_function_function_argument_overloaded GenericsTests.infer_generic_lib_function_function_argument GenericsTests.instantiated_function_argument_names @@ -42,23 +39,24 @@ GenericsTests.no_stack_overflow_from_quantifying GenericsTests.self_recursive_instantiated_param IntersectionTypes.table_intersection_write_sealed_indirect IntersectionTypes.table_write_sealed_indirect -isSubtype.any_is_unknown_union_error ProvisionalTests.assign_table_with_refined_property_with_a_similar_type_is_illegal -ProvisionalTests.bail_early_if_unification_is_too_complicated ProvisionalTests.do_not_ice_when_trying_to_pick_first_of_generic_type_pack ProvisionalTests.error_on_eq_metamethod_returning_a_type_other_than_boolean -ProvisionalTests.generic_type_leak_to_module_interface_variadic +ProvisionalTests.expected_type_should_be_a_helpful_deduction_guide_for_function_calls ProvisionalTests.greedy_inference_with_shared_self_triggers_function_with_no_returns ProvisionalTests.luau-polyfill.Array.filter ProvisionalTests.setmetatable_constrains_free_type_into_free_table ProvisionalTests.specialization_binds_with_prototypes_too_early ProvisionalTests.table_insert_with_a_singleton_argument ProvisionalTests.typeguard_inference_incomplete -RefinementTest.type_guard_can_filter_for_intersection_of_tables +RefinementTest.discriminate_from_truthiness_of_x +RefinementTest.not_t_or_some_prop_of_t +RefinementTest.truthy_constraint_on_properties RefinementTest.type_narrow_to_vector RefinementTest.typeguard_cast_free_table_to_vector RefinementTest.typeguard_in_assert_position RefinementTest.x_as_any_if_x_is_instance_elseif_x_is_table +RuntimeLimits.typescript_port_of_Result_type TableTests.a_free_shape_cannot_turn_into_a_scalar_if_it_is_not_compatible TableTests.checked_prop_too_early TableTests.disallow_indexing_into_an_unsealed_table_with_no_indexer_in_strict_mode @@ -71,9 +69,6 @@ TableTests.expected_indexer_value_type_extra TableTests.expected_indexer_value_type_extra_2 TableTests.explicitly_typed_table TableTests.explicitly_typed_table_with_indexer -TableTests.found_like_key_in_table_function_call -TableTests.found_like_key_in_table_property_access -TableTests.found_multiple_like_keys TableTests.fuzz_table_unify_instantiated_table TableTests.generic_table_instantiation_potential_regression TableTests.give_up_after_one_metatable_index_look_up @@ -92,7 +87,6 @@ TableTests.oop_polymorphic TableTests.quantify_even_that_table_was_never_exported_at_all TableTests.quantify_metatables_of_metatables_of_table TableTests.reasonable_error_when_adding_a_nonexistent_property_to_an_array_like_table -TableTests.result_is_always_any_if_lhs_is_any TableTests.result_is_bool_for_equality_operators_if_lhs_is_any TableTests.right_table_missing_key2 TableTests.shared_selfs @@ -101,7 +95,6 @@ TableTests.shared_selfs_through_metatables TableTests.table_call_metamethod_basic TableTests.table_simple_call TableTests.table_subtyping_with_missing_props_dont_report_multiple_errors -TableTests.table_unification_4 TableTests.used_colon_instead_of_dot TableTests.used_dot_instead_of_colon ToString.toStringDetailed2 @@ -122,7 +115,6 @@ TypeAliases.type_alias_local_mutation TypeAliases.type_alias_local_rename TypeAliases.type_alias_locations TypeAliases.type_alias_of_an_imported_recursive_generic_type -TypeFamilyTests.function_internal_families TypeInfer.check_type_infer_recursion_count TypeInfer.cli_50041_committing_txnlog_in_apollo_client_error TypeInfer.dont_report_type_errors_within_an_AstExprError @@ -131,18 +123,14 @@ TypeInfer.follow_on_new_types_in_substitution TypeInfer.fuzz_free_table_type_change_during_index_check TypeInfer.infer_assignment_value_types_mutable_lval TypeInfer.no_stack_overflow_from_isoptional -TypeInfer.no_stack_overflow_from_isoptional2 -TypeInfer.recursive_function_that_invokes_itself_with_a_refinement_of_its_parameter_2 TypeInfer.tc_after_error_recovery_no_replacement_name_in_error TypeInfer.type_infer_recursion_limit_no_ice TypeInfer.type_infer_recursion_limit_normalizer TypeInferAnyError.for_in_loop_iterator_is_any2 TypeInferClasses.class_type_mismatch_with_name_conflict -TypeInferClasses.classes_without_overloaded_operators_cannot_be_added TypeInferClasses.index_instance_property -TypeInferClasses.table_class_unification_reports_sane_errors_for_missing_properties -TypeInferClasses.warn_when_prop_almost_matches TypeInferFunctions.cannot_hoist_interior_defns_into_signature +TypeInferFunctions.dont_infer_parameter_types_for_functions_from_their_call_site TypeInferFunctions.function_cast_error_uses_correct_language TypeInferFunctions.function_decl_non_self_sealed_overwrite_2 TypeInferFunctions.function_decl_non_self_unsealed_overwrite @@ -177,6 +165,8 @@ TypeInferOperators.CallOrOfFunctions TypeInferOperators.cli_38355_recursive_union TypeInferOperators.compound_assign_mismatch_metatable TypeInferOperators.disallow_string_and_types_without_metatables_from_arithmetic_binary_ops +TypeInferOperators.luau-polyfill.String.slice +TypeInferOperators.luau_polyfill_is_array TypeInferOperators.operator_eq_completely_incompatible TypeInferOperators.typecheck_overloaded_multiply_that_is_an_intersection TypeInferOperators.typecheck_overloaded_multiply_that_is_an_intersection_on_rhs @@ -191,7 +181,6 @@ TypePackTests.detect_cyclic_typepacks2 TypePackTests.pack_tail_unification_check TypePackTests.type_alias_backwards_compatible TypePackTests.type_alias_default_type_errors -TypePackTests.type_alias_type_packs_errors TypePackTests.unify_variadic_tails_in_arguments TypePackTests.variadic_packs TypeSingletons.function_call_with_singletons @@ -202,6 +191,7 @@ TypeSingletons.return_type_of_f_is_not_widened TypeSingletons.table_properties_type_error_escapes TypeSingletons.widen_the_supertype_if_it_is_free_and_subtype_has_singleton TypeSingletons.widening_happens_almost_everywhere +UnionTypes.dont_allow_cyclic_unions_to_be_inferred UnionTypes.generic_function_with_optional_arg UnionTypes.index_on_a_union_type_with_missing_property UnionTypes.optional_union_follow diff --git a/tools/lvmexecute_split.py b/tools/lvmexecute_split.py deleted file mode 100644 index 6e64bcd0e..000000000 --- a/tools/lvmexecute_split.py +++ /dev/null @@ -1,112 +0,0 @@ -#!/usr/bin/python3 -# This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details - -# This code can be used to split lvmexecute.cpp VM switch into separate functions for use as native code generation fallbacks -import sys -import re - -input = sys.stdin.readlines() - -inst = "" - -header = """// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -// This file was generated by 'tools/lvmexecute_split.py' script, do not modify it by hand -#pragma once - -#include - -struct lua_State; -struct Closure; -typedef uint32_t Instruction; -typedef struct lua_TValue TValue; -typedef TValue* StkId; - -""" - -source = """// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -// This code is based on Lua 5.x implementation licensed under MIT License; see lua_LICENSE.txt for details -// This file was generated by 'tools/lvmexecute_split.py' script, do not modify it by hand -#include "Fallbacks.h" -#include "FallbacksProlog.h" - -""" - -function = "" -signature = "" - -includeInsts = ["LOP_NEWCLOSURE", "LOP_NAMECALL", "LOP_FORGPREP", "LOP_GETVARARGS", "LOP_DUPCLOSURE", "LOP_PREPVARARGS", "LOP_BREAK", "LOP_GETGLOBAL", "LOP_SETGLOBAL", "LOP_GETTABLEKS", "LOP_SETTABLEKS", "LOP_SETLIST"] - -state = 0 - -# parse with the state machine -for line in input: - # find the start of an instruction - if state == 0: - match = re.match("\s+VM_CASE\((LOP_[A-Z_0-9]+)\)", line) - - if match: - inst = match[1] - signature = "const Instruction* execute_" + inst + "(lua_State* L, const Instruction* pc, StkId base, TValue* k)" - function = signature + "\n" - function += "{\n" - function += " [[maybe_unused]] Closure* cl = clvalue(L->ci->func);\n" - state = 1 - - # first line of the instruction which is "{" - elif state == 1: - assert(line == " {\n") - state = 2 - - # find the end of an instruction - elif state == 2: - # remove jumps back into the native code - if line == "#if LUA_CUSTOM_EXECUTION\n": - state = 3 - continue - - if line[0] == ' ': - finalline = line[12:-1] + "\n" - else: - finalline = line - - finalline = finalline.replace("VM_NEXT();", "return pc;"); - finalline = finalline.replace("goto exit;", "return NULL;"); - finalline = finalline.replace("return;", "return NULL;"); - - function += finalline - match = re.match(" }", line) - - if match: - # break is not supported - if inst == "LOP_BREAK": - function = "const Instruction* execute_" + inst + "(lua_State* L, const Instruction* pc, StkId base, TValue* k)\n" - function += "{\n LUAU_ASSERT(!\"Unsupported deprecated opcode\");\n LUAU_UNREACHABLE();\n}\n" - # handle fallthrough - elif inst == "LOP_NAMECALL": - function = function[:-len(finalline)] - function += " return pc;\n}\n" - - if inst in includeInsts: - header += signature + ";\n" - source += function + "\n" - - state = 0 - - # skip LUA_CUSTOM_EXECUTION code blocks - elif state == 3: - if line == "#endif\n": - state = 4 - continue - - # skip extra line - elif state == 4: - state = 2 - -# make sure we found the ending -assert(state == 0) - -with open("Fallbacks.h", "w") as fp: - fp.writelines(header) - -with open("Fallbacks.cpp", "w") as fp: - fp.writelines(source) From 123496b29cdba68ef952dc196e6a2517b647749a Mon Sep 17 00:00:00 2001 From: Andy Friesen Date: Fri, 19 May 2023 12:11:10 -0700 Subject: [PATCH 55/66] gcc fix. --- CodeGen/src/CodeBlockUnwind.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/CodeGen/src/CodeBlockUnwind.cpp b/CodeGen/src/CodeBlockUnwind.cpp index 59ee6f138..3c2a3f842 100644 --- a/CodeGen/src/CodeBlockUnwind.cpp +++ b/CodeGen/src/CodeBlockUnwind.cpp @@ -5,6 +5,7 @@ #include "Luau/UnwindBuilder.h" #include +#include #if defined(_WIN32) && defined(_M_X64) From b8e9d07b20142461212ee9f29759ec6f5c9ba073 Mon Sep 17 00:00:00 2001 From: Vyacheslav Egorov Date: Thu, 25 May 2023 23:46:51 +0300 Subject: [PATCH 56/66] Sync to upstream/release/578 --- Analysis/include/Luau/Constraint.h | 3 + Analysis/include/Luau/ConstraintSolver.h | 4 +- Analysis/include/Luau/Error.h | 16 +- Analysis/include/Luau/Frontend.h | 2 - Analysis/include/Luau/Normalize.h | 11 +- Analysis/include/Luau/Substitution.h | 31 +++ Analysis/include/Luau/Type.h | 25 ++- Analysis/include/Luau/TypeUtils.h | 9 + Analysis/include/Luau/Unifier.h | 4 +- Analysis/include/Luau/VisitType.h | 35 ++- Analysis/src/AstQuery.cpp | 22 +- Analysis/src/Autocomplete.cpp | 18 +- Analysis/src/ConstraintGraphBuilder.cpp | 4 +- Analysis/src/ConstraintSolver.cpp | 85 +++++-- Analysis/src/Error.cpp | 27 ++- Analysis/src/Frontend.cpp | 212 +----------------- Analysis/src/IostreamHelpers.cpp | 4 + Analysis/src/Module.cpp | 17 ++ Analysis/src/Normalize.cpp | 65 +++++- Analysis/src/Substitution.cpp | 25 ++- Analysis/src/ToString.cpp | 50 ++++- Analysis/src/Type.cpp | 70 ++---- Analysis/src/TypeChecker2.cpp | 256 +++++++++++++++++++--- Analysis/src/TypeFamily.cpp | 12 +- Analysis/src/TypeInfer.cpp | 10 +- Analysis/src/Unifier.cpp | 37 +++- CodeGen/include/Luau/AssemblyBuilderA64.h | 2 + CodeGen/include/Luau/AssemblyBuilderX64.h | 1 + CodeGen/include/Luau/IrBuilder.h | 1 - CodeGen/include/Luau/IrData.h | 49 ++--- CodeGen/src/AssemblyBuilderA64.cpp | 5 + CodeGen/src/AssemblyBuilderX64.cpp | 9 + CodeGen/src/CodeAllocator.cpp | 8 +- CodeGen/src/CodeBlockUnwind.cpp | 4 +- CodeGen/src/CodeGen.cpp | 49 +++-- CodeGen/src/CodeGenA64.cpp | 17 +- CodeGen/src/CodeGenUtils.cpp | 4 +- CodeGen/src/CodeGenX64.cpp | 3 +- CodeGen/src/CustomExecUtils.h | 106 --------- CodeGen/src/EmitBuiltinsX64.cpp | 1 - CodeGen/src/EmitCommonX64.cpp | 1 - CodeGen/src/EmitCommonX64.h | 27 --- CodeGen/src/EmitInstructionX64.cpp | 10 +- CodeGen/src/IrAnalysis.cpp | 4 +- CodeGen/src/IrBuilder.cpp | 16 +- CodeGen/src/IrDump.cpp | 3 - CodeGen/src/IrLoweringA64.cpp | 138 ++++++++---- CodeGen/src/IrLoweringA64.h | 1 - CodeGen/src/IrLoweringX64.cpp | 91 +++++--- CodeGen/src/IrLoweringX64.h | 2 +- CodeGen/src/IrRegAllocX64.cpp | 4 +- CodeGen/src/IrTranslation.cpp | 10 +- CodeGen/src/IrTranslation.h | 35 +++ CodeGen/src/IrUtils.cpp | 6 +- CodeGen/src/NativeState.cpp | 1 - CodeGen/src/NativeState.h | 2 +- CodeGen/src/OptimizeConstProp.cpp | 47 +++- Common/include/Luau/Bytecode.h | 5 +- Compiler/src/Builtins.cpp | 3 +- Compiler/src/Compiler.cpp | 77 ++++++- Sources.cmake | 2 +- VM/src/lfunc.cpp | 2 +- VM/src/lobject.h | 1 + VM/src/lstate.cpp | 2 - VM/src/lstate.h | 4 +- VM/src/lvm.h | 3 +- VM/src/lvmexecute.cpp | 63 +++--- VM/src/lvmload.cpp | 55 ++++- tests/AssemblyBuilderA64.test.cpp | 5 + tests/AssemblyBuilderX64.test.cpp | 1 + tests/Compiler.test.cpp | 85 ++++++- tests/IrBuilder.test.cpp | 54 ++++- tests/TypeFamily.test.cpp | 32 +++ tests/TypeInfer.functions.test.cpp | 6 +- tests/TypeInfer.loops.test.cpp | 173 +++++++++++++-- tests/TypeInfer.operators.test.cpp | 7 +- tests/TypeInfer.rwprops.test.cpp | 70 ++++++ tests/TypeInfer.tryUnify.test.cpp | 2 +- tests/TypeInfer.unionTypes.test.cpp | 2 +- tests/TypeInfer.unknownnever.test.cpp | 12 + tests/conformance/native.lua | 23 ++ tools/faillist.txt | 3 - 82 files changed, 1615 insertions(+), 788 deletions(-) delete mode 100644 CodeGen/src/CustomExecUtils.h create mode 100644 tests/TypeInfer.rwprops.test.cpp diff --git a/Analysis/include/Luau/Constraint.h b/Analysis/include/Luau/Constraint.h index c815bef01..d73ba46df 100644 --- a/Analysis/include/Luau/Constraint.h +++ b/Analysis/include/Luau/Constraint.h @@ -81,6 +81,9 @@ struct IterableConstraint { TypePackId iterator; TypePackId variables; + + const AstNode* nextAstFragment; + DenseHashMap* astOverloadResolvedTypes; }; // name(namedType) = name diff --git a/Analysis/include/Luau/ConstraintSolver.h b/Analysis/include/Luau/ConstraintSolver.h index 1a43a252e..ef87175ef 100644 --- a/Analysis/include/Luau/ConstraintSolver.h +++ b/Analysis/include/Luau/ConstraintSolver.h @@ -201,7 +201,7 @@ struct ConstraintSolver * @param subType the sub-type to unify. * @param superType the super-type to unify. */ - void unify(TypeId subType, TypeId superType, NotNull scope); + ErrorVec unify(TypeId subType, TypeId superType, NotNull scope); /** * Creates a new Unifier and performs a single unification operation. Commits @@ -209,7 +209,7 @@ struct ConstraintSolver * @param subPack the sub-type pack to unify. * @param superPack the super-type pack to unify. */ - void unify(TypePackId subPack, TypePackId superPack, NotNull scope); + ErrorVec unify(TypePackId subPack, TypePackId superPack, NotNull scope); /** Pushes a new solver constraint to the solver. * @param cv the body of the constraint. diff --git a/Analysis/include/Luau/Error.h b/Analysis/include/Luau/Error.h index 6264a0b53..858d1b499 100644 --- a/Analysis/include/Luau/Error.h +++ b/Analysis/include/Luau/Error.h @@ -343,13 +343,27 @@ struct UninhabitedTypePackFamily bool operator==(const UninhabitedTypePackFamily& rhs) const; }; +struct WhereClauseNeeded +{ + TypeId ty; + + bool operator==(const WhereClauseNeeded& rhs) const; +}; + +struct PackWhereClauseNeeded +{ + TypePackId tp; + + bool operator==(const PackWhereClauseNeeded& rhs) const; +}; + using TypeErrorData = Variant; + NormalizationTooComplex, TypePackMismatch, DynamicPropertyLookupOnClassesUnsafe, UninhabitedTypeFamily, UninhabitedTypePackFamily, WhereClauseNeeded, PackWhereClauseNeeded>; struct TypeErrorSummary { diff --git a/Analysis/include/Luau/Frontend.h b/Analysis/include/Luau/Frontend.h index 14bf2e2e5..1306ad2c7 100644 --- a/Analysis/include/Luau/Frontend.h +++ b/Analysis/include/Luau/Frontend.h @@ -182,8 +182,6 @@ struct Frontend std::optional getCheckResult(const ModuleName& name, bool accumulateNested, bool forAutocomplete = false); private: - CheckResult check_DEPRECATED(const ModuleName& name, std::optional optionOverride = {}); - struct TypeCheckLimits { std::optional finishTime; diff --git a/Analysis/include/Luau/Normalize.h b/Analysis/include/Luau/Normalize.h index 978ddb480..7d415e92f 100644 --- a/Analysis/include/Luau/Normalize.h +++ b/Analysis/include/Luau/Normalize.h @@ -285,14 +285,19 @@ class Normalizer std::unordered_map cachedIntersections; std::unordered_map cachedUnions; std::unordered_map> cachedTypeIds; + + DenseHashMap cachedIsInhabited{nullptr}; + DenseHashMap, bool, TypeIdPairHash> cachedIsInhabitedIntersection{{nullptr, nullptr}}; + bool withinResourceLimits(); public: TypeArena* arena; NotNull builtinTypes; NotNull sharedState; + bool cacheInhabitance = false; - Normalizer(TypeArena* arena, NotNull builtinTypes, NotNull sharedState); + Normalizer(TypeArena* arena, NotNull builtinTypes, NotNull sharedState, bool cacheInhabitance = false); Normalizer(const Normalizer&) = delete; Normalizer(Normalizer&&) = delete; Normalizer() = delete; @@ -355,8 +360,10 @@ class Normalizer bool normalizeIntersections(const std::vector& intersections, NormalizedType& outType); // Check for inhabitance - bool isInhabited(TypeId ty, std::unordered_set seen = {}); + bool isInhabited(TypeId ty); + bool isInhabited(TypeId ty, std::unordered_set seen); bool isInhabited(const NormalizedType* norm, std::unordered_set seen = {}); + // Check for intersections being inhabited bool isIntersectionInhabited(TypeId left, TypeId right); diff --git a/Analysis/include/Luau/Substitution.h b/Analysis/include/Luau/Substitution.h index 2efca2df5..626c93ad1 100644 --- a/Analysis/include/Luau/Substitution.h +++ b/Analysis/include/Luau/Substitution.h @@ -113,6 +113,13 @@ struct Tarjan void visitChild(TypeId ty); void visitChild(TypePackId ty); + template + void visitChild(std::optional ty) + { + if (ty) + visitChild(*ty); + } + // Visit the root vertex. TarjanResult visitRoot(TypeId ty); TarjanResult visitRoot(TypePackId ty); @@ -127,10 +134,22 @@ struct Tarjan { return false; } + virtual bool ignoreChildren(TypePackId ty) { return false; } + + // Some subclasses might ignore children visit, but not other actions like replacing the children + virtual bool ignoreChildrenVisit(TypeId ty) + { + return ignoreChildren(ty); + } + + virtual bool ignoreChildrenVisit(TypePackId ty) + { + return ignoreChildren(ty); + } }; // We use Tarjan to calculate dirty bits. We set `dirty[i]` true @@ -186,8 +205,10 @@ struct Substitution : FindDirty TypeId replace(TypeId ty); TypePackId replace(TypePackId tp); + void replaceChildren(TypeId ty); void replaceChildren(TypePackId tp); + TypeId clone(TypeId ty); TypePackId clone(TypePackId tp); @@ -211,6 +232,16 @@ struct Substitution : FindDirty { return arena->addTypePack(TypePackVar{tp}); } + +private: + template + std::optional replace(std::optional ty) + { + if (ty) + return replace(*ty); + else + return std::nullopt; + } }; } // namespace Luau diff --git a/Analysis/include/Luau/Type.h b/Analysis/include/Luau/Type.h index d42f58b4b..9a35a1d6f 100644 --- a/Analysis/include/Luau/Type.h +++ b/Analysis/include/Luau/Type.h @@ -388,7 +388,13 @@ struct Property static Property writeonly(TypeId ty); static Property rw(TypeId ty); // Shared read-write type. static Property rw(TypeId read, TypeId write); // Separate read-write type. - static std::optional create(std::optional read, std::optional write); + + // Invariant: at least one of the two optionals are not nullopt! + // If the read type is not nullopt, but the write type is, then the property is readonly. + // If the read type is nullopt, but the write type is not, then the property is writeonly. + // If the read and write types are not nullopt, then the property is read and write. + // Otherwise, an assertion where read and write types are both nullopt will be tripped. + static Property create(std::optional read, std::optional write); bool deprecated = false; std::string deprecatedSuggestion; @@ -414,6 +420,8 @@ struct Property std::optional readType() const; std::optional writeType() const; + bool isShared() const; + private: std::optional readTy; std::optional writeTy; @@ -614,30 +622,26 @@ struct IntersectionType struct LazyType { LazyType() = default; - LazyType(std::function thunk_DEPRECATED, std::function unwrap) - : thunk_DEPRECATED(thunk_DEPRECATED) - , unwrap(unwrap) + LazyType(std::function unwrap) + : unwrap(unwrap) { } // std::atomic is sad and requires a manual copy LazyType(const LazyType& rhs) - : thunk_DEPRECATED(rhs.thunk_DEPRECATED) - , unwrap(rhs.unwrap) + : unwrap(rhs.unwrap) , unwrapped(rhs.unwrapped.load()) { } LazyType(LazyType&& rhs) noexcept - : thunk_DEPRECATED(std::move(rhs.thunk_DEPRECATED)) - , unwrap(std::move(rhs.unwrap)) + : unwrap(std::move(rhs.unwrap)) , unwrapped(rhs.unwrapped.load()) { } LazyType& operator=(const LazyType& rhs) { - thunk_DEPRECATED = rhs.thunk_DEPRECATED; unwrap = rhs.unwrap; unwrapped = rhs.unwrapped.load(); @@ -646,15 +650,12 @@ struct LazyType LazyType& operator=(LazyType&& rhs) noexcept { - thunk_DEPRECATED = std::move(rhs.thunk_DEPRECATED); unwrap = std::move(rhs.unwrap); unwrapped = rhs.unwrapped.load(); return *this; } - std::function thunk_DEPRECATED; - std::function unwrap; std::atomic unwrapped = nullptr; }; diff --git a/Analysis/include/Luau/TypeUtils.h b/Analysis/include/Luau/TypeUtils.h index 86f20f387..5ead2fa4a 100644 --- a/Analysis/include/Luau/TypeUtils.h +++ b/Analysis/include/Luau/TypeUtils.h @@ -55,4 +55,13 @@ std::vector reduceUnion(const std::vector& types); */ TypeId stripNil(NotNull builtinTypes, TypeArena& arena, TypeId ty); +template +const T* get(std::optional ty) +{ + if (ty) + return get(*ty); + else + return nullptr; +} + } // namespace Luau diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index d5db06c8b..99da33f62 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -54,7 +54,6 @@ struct Unifier TypeArena* const types; NotNull builtinTypes; NotNull normalizer; - Mode mode; NotNull scope; // const Scope maybe TxnLog log; @@ -78,7 +77,7 @@ struct Unifier std::vector blockedTypePacks; Unifier( - NotNull normalizer, Mode mode, NotNull scope, const Location& location, Variance variance, TxnLog* parentLog = nullptr); + NotNull normalizer, NotNull scope, const Location& location, Variance variance, TxnLog* parentLog = nullptr); // Configure the Unifier to test for scope subsumption via embedded Scope // pointers rather than TypeLevels. @@ -154,7 +153,6 @@ struct Unifier LUAU_NOINLINE void reportError(Location location, TypeErrorData data); private: - bool isNonstrictMode() const; TypeMismatch::Context mismatchContext(); void checkChildUnifierTypeMismatch(const ErrorVec& innerErrors, TypeId wantedType, TypeId givenType); diff --git a/Analysis/include/Luau/VisitType.h b/Analysis/include/Luau/VisitType.h index b6dcf1f1b..1464aa1b5 100644 --- a/Analysis/include/Luau/VisitType.h +++ b/Analysis/include/Luau/VisitType.h @@ -10,6 +10,7 @@ LUAU_FASTINT(LuauVisitRecursionLimit) LUAU_FASTFLAG(LuauBoundLazyTypes2) +LUAU_FASTFLAG(DebugLuauReadWriteProperties) namespace Luau { @@ -250,7 +251,18 @@ struct GenericTypeVisitor else { for (auto& [_name, prop] : ttv->props) - traverse(prop.type()); + { + if (FFlag::DebugLuauReadWriteProperties) + { + if (auto ty = prop.readType()) + traverse(*ty); + + if (auto ty = prop.writeType()) + traverse(*ty); + } + else + traverse(prop.type()); + } if (ttv->indexer) { @@ -273,7 +285,18 @@ struct GenericTypeVisitor if (visit(ty, *ctv)) { for (const auto& [name, prop] : ctv->props) - traverse(prop.type()); + { + if (FFlag::DebugLuauReadWriteProperties) + { + if (auto ty = prop.readType()) + traverse(*ty); + + if (auto ty = prop.writeType()) + traverse(*ty); + } + else + traverse(prop.type()); + } if (ctv->parent) traverse(*ctv->parent); @@ -311,11 +334,9 @@ struct GenericTypeVisitor } else if (auto ltv = get(ty)) { - if (FFlag::LuauBoundLazyTypes2) - { - if (TypeId unwrapped = ltv->unwrapped) - traverse(unwrapped); - } + if (TypeId unwrapped = ltv->unwrapped) + traverse(unwrapped); + // Visiting into LazyType that hasn't been unwrapped may necessarily cause infinite expansion, so we don't do that on purpose. // Asserting also makes no sense, because the type _will_ happen here, most likely as a property of some ClassType // that doesn't need to be expanded. diff --git a/Analysis/src/AstQuery.cpp b/Analysis/src/AstQuery.cpp index 38f3bdf5c..6a6f10e84 100644 --- a/Analysis/src/AstQuery.cpp +++ b/Analysis/src/AstQuery.cpp @@ -11,6 +11,8 @@ #include +LUAU_FASTFLAG(DebugLuauReadWriteProperties) + namespace Luau { @@ -501,12 +503,28 @@ std::optional getDocumentationSymbolAtPosition(const Source if (const TableType* ttv = get(parentTy)) { if (auto propIt = ttv->props.find(indexName->index.value); propIt != ttv->props.end()) - return checkOverloadedDocumentationSymbol(module, propIt->second.type(), parentExpr, propIt->second.documentationSymbol); + { + if (FFlag::DebugLuauReadWriteProperties) + { + if (auto ty = propIt->second.readType()) + return checkOverloadedDocumentationSymbol(module, *ty, parentExpr, propIt->second.documentationSymbol); + } + else + return checkOverloadedDocumentationSymbol(module, propIt->second.type(), parentExpr, propIt->second.documentationSymbol); + } } else if (const ClassType* ctv = get(parentTy)) { if (auto propIt = ctv->props.find(indexName->index.value); propIt != ctv->props.end()) - return checkOverloadedDocumentationSymbol(module, propIt->second.type(), parentExpr, propIt->second.documentationSymbol); + { + if (FFlag::DebugLuauReadWriteProperties) + { + if (auto ty = propIt->second.readType()) + return checkOverloadedDocumentationSymbol(module, *ty, parentExpr, propIt->second.documentationSymbol); + } + else + return checkOverloadedDocumentationSymbol(module, propIt->second.type(), parentExpr, propIt->second.documentationSymbol); + } } } } diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index 8dd747390..d67eda8d5 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -12,6 +12,8 @@ #include #include +LUAU_FASTFLAG(DebugLuauReadWriteProperties) + static const std::unordered_set kStatementStartingKeywords = { "while", "if", "local", "repeat", "function", "do", "for", "return", "break", "continue", "type", "export"}; @@ -138,7 +140,7 @@ static bool checkTypeMatch(TypeId subTy, TypeId superTy, NotNull scope, T InternalErrorReporter iceReporter; UnifierSharedState unifierState(&iceReporter); Normalizer normalizer{typeArena, builtinTypes, NotNull{&unifierState}}; - Unifier unifier(NotNull{&normalizer}, Mode::Strict, scope, Location(), Variance::Covariant); + Unifier unifier(NotNull{&normalizer}, scope, Location(), Variance::Covariant); // Cost of normalization can be too high for autocomplete response time requirements unifier.normalize = false; @@ -259,10 +261,22 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, NotNul // already populated, it takes precedence over the property we found just now. if (result.count(name) == 0 && name != kParseNameError) { - Luau::TypeId type = Luau::follow(prop.type()); + Luau::TypeId type; + + if (FFlag::DebugLuauReadWriteProperties) + { + if (auto ty = prop.readType()) + type = follow(*ty); + else + continue; + } + else + type = follow(prop.type()); + TypeCorrectKind typeCorrect = indexType == PropIndexType::Key ? TypeCorrectKind::Correct : checkTypeCorrectKind(module, typeArena, builtinTypes, nodes.back(), {{}, {}}, type); + ParenthesesRecommendation parens = indexType == PropIndexType::Key ? ParenthesesRecommendation::None : getParenRecommendation(type, nodes, typeCorrect); diff --git a/Analysis/src/ConstraintGraphBuilder.cpp b/Analysis/src/ConstraintGraphBuilder.cpp index b190f4aba..c14f10e5a 100644 --- a/Analysis/src/ConstraintGraphBuilder.cpp +++ b/Analysis/src/ConstraintGraphBuilder.cpp @@ -755,8 +755,8 @@ ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatForIn* f // It is always ok to provide too few variables, so we give this pack a free tail. TypePackId variablePack = arena->addTypePack(std::move(variableTypes), arena->addTypePack(FreeTypePack{loopScope.get()})); - addConstraint(loopScope, getLocation(forIn->values), IterableConstraint{iterator, variablePack}); - + addConstraint( + loopScope, getLocation(forIn->values), IterableConstraint{iterator, variablePack, forIn->values.data[0], &module->astOverloadResolvedTypes}); visit(loopScope, forIn->body); return ControlFlow::None; diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index 9c688f427..14d0df662 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -20,7 +20,6 @@ #include "Luau/VisitType.h" LUAU_FASTFLAGVARIABLE(DebugLuauLogSolver, false); -LUAU_FASTFLAG(LuauRequirePathTrueModuleName) namespace Luau { @@ -252,6 +251,11 @@ struct InstantiationQueuer : TypeOnceVisitor solver->pushConstraint(scope, location, ReduceConstraint{ty}); return true; } + + bool visit(TypeId ty, const ClassType& ctv) override + { + return false; + } }; ConstraintSolver::ConstraintSolver(NotNull normalizer, NotNull rootScope, std::vector> constraints, @@ -749,7 +753,10 @@ bool ConstraintSolver::tryDispatch(const BinaryConstraint& c, NotNullbooleanType; break; default: - mmResult = first(ftv->retTypes).value_or(errorRecoveryType()); + if (get(leftType) || get(rightType)) + mmResult = builtinTypes->neverType; + else + mmResult = first(ftv->retTypes).value_or(errorRecoveryType()); } asMutable(resultType)->ty.emplace(mmResult); @@ -785,6 +792,13 @@ bool ConstraintSolver::tryDispatch(const BinaryConstraint& c, NotNull(leftType) || get(rightType)) + { + unify(leftType, rightType, constraint->scope); + asMutable(resultType)->ty.emplace(builtinTypes->neverType); + unblock(resultType); + return true; + } break; } @@ -800,6 +814,13 @@ bool ConstraintSolver::tryDispatch(const BinaryConstraint& c, NotNull(leftType) || get(rightType)) + { + unify(leftType, rightType, constraint->scope); + asMutable(resultType)->ty.emplace(builtinTypes->neverType); + unblock(resultType); + return true; + } break; // Inexact comparisons require that the types be both numbers or both @@ -808,7 +829,8 @@ bool ConstraintSolver::tryDispatch(const BinaryConstraint& c, NotNull(leftType) || + get(rightType)) { asMutable(resultType)->ty.emplace(builtinTypes->booleanType); unblock(resultType); @@ -1291,7 +1313,7 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNullscope, Location{}, Covariant}; + Unifier u{normalizer, constraint->scope, Location{}, Covariant}; u.enableScopeTests(); u.tryUnify(*instantiated, inferredTy, /* isFunctionCall */ true); @@ -1344,7 +1366,7 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNullscope, Location{}, Covariant}; + Unifier u{normalizer, constraint->scope, Location{}, Covariant}; u.enableScopeTests(); u.tryUnify(inferredTy, builtinTypes->anyType); @@ -1746,6 +1768,11 @@ struct FindRefineConstraintBlockers : TypeOnceVisitor found.insert(ty); return false; } + + bool visit(TypeId ty, const ClassType&) override + { + return false; + } }; } @@ -1932,6 +1959,15 @@ bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const Iterabl unify(*errorified, ty, constraint->scope); }; + auto neverify = [&](auto ty) { + Anyification anyify{arena, constraint->scope, builtinTypes, &iceReporter, builtinTypes->neverType, builtinTypes->neverTypePack}; + std::optional neverified = anyify.substitute(ty); + if (!neverified) + reportError(CodeTooComplex{}, constraint->location); + else + unify(*neverified, ty, constraint->scope); + }; + if (get(iteratorTy)) { anyify(c.variables); @@ -1944,6 +1980,12 @@ bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const Iterabl return true; } + if (get(iteratorTy)) + { + neverify(c.variables); + return true; + } + // Irksome: I don't think we have any way to guarantee that this table // type never has a metatable. @@ -2072,7 +2114,11 @@ bool ConstraintSolver::tryDispatchIterableFunction( const TypePackId nextRetPack = arena->addTypePack(TypePack{{retIndex}, valueTailTy}); const TypeId expectedNextTy = arena->addType(FunctionType{TypeLevel{}, constraint->scope, nextArgPack, nextRetPack}); - unify(nextTy, expectedNextTy, constraint->scope); + ErrorVec errors = unify(nextTy, expectedNextTy, constraint->scope); + + // if there are no errors from unifying the two, we can pass forward the expected type as our selected resolution. + if (errors.empty()) + (*c.astOverloadResolvedTypes)[c.nextAstFragment] = expectedNextTy; auto it = begin(nextRetPack); std::vector modifiedNextRetHead; @@ -2122,7 +2168,7 @@ std::pair, std::optional> ConstraintSolver::lookupTa else if (auto ttv = getMutable(subjectType)) { if (auto prop = ttv->props.find(propName); prop != ttv->props.end()) - return {{}, prop->second.type()}; + return {{}, FFlag::DebugLuauReadWriteProperties ? prop->second.readType() : prop->second.type()}; else if (ttv->indexer && maybeString(ttv->indexer->indexType)) return {{}, ttv->indexer->indexResultType}; else if (ttv->state == TableState::Free) @@ -2275,7 +2321,7 @@ static TypePackId getErrorType(NotNull builtinTypes, TypePackId) template bool ConstraintSolver::tryUnify(NotNull constraint, TID subTy, TID superTy) { - Unifier u{normalizer, Mode::Strict, constraint->scope, constraint->location, Covariant}; + Unifier u{normalizer, constraint->scope, constraint->location, Covariant}; u.enableScopeTests(); u.tryUnify(subTy, superTy); @@ -2379,12 +2425,17 @@ struct Blocker : TypeOnceVisitor { } - bool visit(TypeId ty, const PendingExpansionType&) + bool visit(TypeId ty, const PendingExpansionType&) override { blocked = true; solver->block(ty, constraint); return false; } + + bool visit(TypeId ty, const ClassType&) override + { + return false; + } }; bool ConstraintSolver::blockOnPendingTypes(TypeId target, NotNull constraint) @@ -2492,9 +2543,9 @@ bool ConstraintSolver::isBlocked(NotNull constraint) return blockedIt != blockedConstraints.end() && blockedIt->second > 0; } -void ConstraintSolver::unify(TypeId subType, TypeId superType, NotNull scope) +ErrorVec ConstraintSolver::unify(TypeId subType, TypeId superType, NotNull scope) { - Unifier u{normalizer, Mode::Strict, scope, Location{}, Covariant}; + Unifier u{normalizer, scope, Location{}, Covariant}; u.enableScopeTests(); u.tryUnify(subType, superType); @@ -2512,12 +2563,14 @@ void ConstraintSolver::unify(TypeId subType, TypeId superType, NotNull sc unblock(changedTypes); unblock(changedPacks); + + return std::move(u.errors); } -void ConstraintSolver::unify(TypePackId subPack, TypePackId superPack, NotNull scope) +ErrorVec ConstraintSolver::unify(TypePackId subPack, TypePackId superPack, NotNull scope) { UnifierSharedState sharedState{&iceReporter}; - Unifier u{normalizer, Mode::Strict, scope, Location{}, Covariant}; + Unifier u{normalizer, scope, Location{}, Covariant}; u.enableScopeTests(); u.tryUnify(subPack, superPack); @@ -2528,6 +2581,8 @@ void ConstraintSolver::unify(TypePackId subPack, TypePackId superPack, NotNull ConstraintSolver::pushConstraint(NotNull scope, const Location& location, ConstraintV cv) @@ -2550,7 +2605,7 @@ TypeId ConstraintSolver::resolveModule(const ModuleInfo& info, const Location& l for (const auto& [location, path] : requireCycles) { - if (!path.empty() && path.front() == (FFlag::LuauRequirePathTrueModuleName ? info.name : moduleResolver->getHumanReadableModuleName(info.name))) + if (!path.empty() && path.front() == info.name) return builtinTypes->anyType; } @@ -2612,7 +2667,7 @@ TypeId ConstraintSolver::unionOfTypes(TypeId a, TypeId b, NotNull scope, if (unifyFreeTypes && (get(a) || get(b))) { - Unifier u{normalizer, Mode::Strict, scope, Location{}, Covariant}; + Unifier u{normalizer, scope, Location{}, Covariant}; u.enableScopeTests(); u.tryUnify(b, a); diff --git a/Analysis/src/Error.cpp b/Analysis/src/Error.cpp index 4f70be33f..fdc19d0c4 100644 --- a/Analysis/src/Error.cpp +++ b/Analysis/src/Error.cpp @@ -11,7 +11,6 @@ #include LUAU_FASTFLAGVARIABLE(LuauTypeMismatchInvarianceInError, false) -LUAU_FASTFLAGVARIABLE(LuauRequirePathTrueModuleName, false) static std::string wrongNumberOfArgsString( size_t expectedCount, std::optional maximumCount, size_t actualCount, const char* argPrefix = nullptr, bool isVariadic = false) @@ -350,7 +349,7 @@ struct ErrorConverter else s += " -> "; - if (FFlag::LuauRequirePathTrueModuleName && fileResolver != nullptr) + if (fileResolver != nullptr) s += fileResolver->getHumanReadableModuleName(name); else s += name; @@ -494,6 +493,16 @@ struct ErrorConverter { return "Type pack family instance " + Luau::toString(e.tp) + " is uninhabited"; } + + std::string operator()(const WhereClauseNeeded& e) const + { + return "Type family instance " + Luau::toString(e.ty) + " depends on generic function parameters but does not appear in the function signature; this construct cannot be type-checked at this time"; + } + + std::string operator()(const PackWhereClauseNeeded& e) const + { + return "Type pack family instance " + Luau::toString(e.tp) + " depends on generic function parameters but does not appear in the function signature; this construct cannot be type-checked at this time"; + } }; struct InvalidNameChecker @@ -806,6 +815,16 @@ bool UninhabitedTypePackFamily::operator==(const UninhabitedTypePackFamily& rhs) return tp == rhs.tp; } +bool WhereClauseNeeded::operator==(const WhereClauseNeeded& rhs) const +{ + return ty == rhs.ty; +} + +bool PackWhereClauseNeeded::operator==(const PackWhereClauseNeeded& rhs) const +{ + return tp == rhs.tp; +} + std::string toString(const TypeError& error) { return toString(error, TypeErrorToStringOptions{}); @@ -968,6 +987,10 @@ void copyError(T& e, TypeArena& destArena, CloneState cloneState) e.ty = clone(e.ty); else if constexpr (std::is_same_v) e.tp = clone(e.tp); + else if constexpr (std::is_same_v) + e.ty = clone(e.ty); + else if constexpr (std::is_same_v) + e.tp = clone(e.tp); else static_assert(always_false_v, "Non-exhaustive type switch"); } diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 07393eb12..062050fa6 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -34,9 +34,7 @@ LUAU_FASTFLAGVARIABLE(LuauKnowsTheDataModel3, false) LUAU_FASTINTVARIABLE(LuauAutocompleteCheckTimeoutMs, 100) LUAU_FASTFLAGVARIABLE(DebugLuauDeferredConstraintResolution, false) LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJson, false) -LUAU_FASTFLAG(LuauRequirePathTrueModuleName) LUAU_FASTFLAGVARIABLE(DebugLuauReadWriteProperties, false) -LUAU_FASTFLAGVARIABLE(LuauSplitFrontendProcessing, false) LUAU_FASTFLAGVARIABLE(LuauTypeCheckerUseCorrectScope, false) namespace Luau @@ -349,9 +347,9 @@ std::vector getRequireCycles(const FileResolver* resolver, if (top == start) { for (const SourceNode* node : path) - cycle.push_back(FFlag::LuauRequirePathTrueModuleName ? node->name : node->humanReadableName); + cycle.push_back(node->name); - cycle.push_back(FFlag::LuauRequirePathTrueModuleName ? top->name : top->humanReadableName); + cycle.push_back(top->name); break; } } @@ -419,9 +417,6 @@ Frontend::Frontend(FileResolver* fileResolver, ConfigResolver* configResolver, c CheckResult Frontend::check(const ModuleName& name, std::optional optionOverride) { - if (!FFlag::LuauSplitFrontendProcessing) - return check_DEPRECATED(name, optionOverride); - LUAU_TIMETRACE_SCOPE("Frontend::check", "Frontend"); LUAU_TIMETRACE_ARGUMENT("name", name.c_str()); @@ -463,200 +458,6 @@ CheckResult Frontend::check(const ModuleName& name, std::optional optionOverride) -{ - LUAU_TIMETRACE_SCOPE("Frontend::check", "Frontend"); - LUAU_TIMETRACE_ARGUMENT("name", name.c_str()); - - FrontendOptions frontendOptions = optionOverride.value_or(options); - CheckResult checkResult; - - FrontendModuleResolver& resolver = frontendOptions.forAutocomplete ? moduleResolverForAutocomplete : moduleResolver; - - auto it = sourceNodes.find(name); - if (it != sourceNodes.end() && !it->second->hasDirtyModule(frontendOptions.forAutocomplete)) - { - // No recheck required. - ModulePtr module = resolver.getModule(name); - - if (!module) - throw InternalCompilerError("Frontend::modules does not have data for " + name, name); - - checkResult.errors = accumulateErrors(sourceNodes, resolver, name); - - // Get lint result only for top checked module - checkResult.lintResult = module->lintResult; - - return checkResult; - } - - std::vector buildQueue; - bool cycleDetected = parseGraph(buildQueue, name, frontendOptions.forAutocomplete); - - for (const ModuleName& moduleName : buildQueue) - { - LUAU_ASSERT(sourceNodes.count(moduleName)); - SourceNode& sourceNode = *sourceNodes[moduleName]; - - if (!sourceNode.hasDirtyModule(frontendOptions.forAutocomplete)) - continue; - - LUAU_ASSERT(sourceModules.count(moduleName)); - SourceModule& sourceModule = *sourceModules[moduleName]; - - const Config& config = configResolver->getConfig(moduleName); - - Mode mode = sourceModule.mode.value_or(config.mode); - - ScopePtr environmentScope = getModuleEnvironment(sourceModule, config, frontendOptions.forAutocomplete); - - double timestamp = getTimestamp(); - - std::vector requireCycles; - - // in NoCheck mode we only need to compute the value of .cyclic for typeck - // in the future we could replace toposort with an algorithm that can flag cyclic nodes by itself - // however, for now getRequireCycles isn't expensive in practice on the cases we care about, and long term - // all correct programs must be acyclic so this code triggers rarely - if (cycleDetected) - requireCycles = getRequireCycles(fileResolver, sourceNodes, &sourceNode, mode == Mode::NoCheck); - - // This is used by the type checker to replace the resulting type of cyclic modules with any - sourceModule.cyclic = !requireCycles.empty(); - - if (frontendOptions.forAutocomplete) - { - double autocompleteTimeLimit = FInt::LuauAutocompleteCheckTimeoutMs / 1000.0; - - // The autocomplete typecheck is always in strict mode with DM awareness - // to provide better type information for IDE features - TypeCheckLimits typeCheckLimits; - - if (autocompleteTimeLimit != 0.0) - typeCheckLimits.finishTime = TimeTrace::getClock() + autocompleteTimeLimit; - else - typeCheckLimits.finishTime = std::nullopt; - - // TODO: This is a dirty ad hoc solution for autocomplete timeouts - // We are trying to dynamically adjust our existing limits to lower total typechecking time under the limit - // so that we'll have type information for the whole file at lower quality instead of a full abort in the middle - if (FInt::LuauTarjanChildLimit > 0) - typeCheckLimits.instantiationChildLimit = std::max(1, int(FInt::LuauTarjanChildLimit * sourceNode.autocompleteLimitsMult)); - else - typeCheckLimits.instantiationChildLimit = std::nullopt; - - if (FInt::LuauTypeInferIterationLimit > 0) - typeCheckLimits.unifierIterationLimit = std::max(1, int(FInt::LuauTypeInferIterationLimit * sourceNode.autocompleteLimitsMult)); - else - typeCheckLimits.unifierIterationLimit = std::nullopt; - - ModulePtr moduleForAutocomplete = check(sourceModule, Mode::Strict, requireCycles, environmentScope, /*forAutocomplete*/ true, - /*recordJsonLog*/ false, typeCheckLimits); - - resolver.setModule(moduleName, moduleForAutocomplete); - - double duration = getTimestamp() - timestamp; - - if (moduleForAutocomplete->timeout) - { - checkResult.timeoutHits.push_back(moduleName); - - sourceNode.autocompleteLimitsMult = sourceNode.autocompleteLimitsMult / 2.0; - } - else if (duration < autocompleteTimeLimit / 2.0) - { - sourceNode.autocompleteLimitsMult = std::min(sourceNode.autocompleteLimitsMult * 2.0, 1.0); - } - - stats.timeCheck += duration; - stats.filesStrict += 1; - - sourceNode.dirtyModuleForAutocomplete = false; - continue; - } - - const bool recordJsonLog = FFlag::DebugLuauLogSolverToJson && moduleName == name; - ModulePtr module = check(sourceModule, mode, requireCycles, environmentScope, /*forAutocomplete*/ false, recordJsonLog, {}); - - stats.timeCheck += getTimestamp() - timestamp; - stats.filesStrict += mode == Mode::Strict; - stats.filesNonstrict += mode == Mode::Nonstrict; - - if (module == nullptr) - throw InternalCompilerError("Frontend::check produced a nullptr module for " + moduleName, moduleName); - - if (FFlag::DebugLuauDeferredConstraintResolution && mode == Mode::NoCheck) - module->errors.clear(); - - if (frontendOptions.runLintChecks) - { - LUAU_TIMETRACE_SCOPE("lint", "Frontend"); - - LintOptions lintOptions = frontendOptions.enabledLintWarnings.value_or(config.enabledLint); - filterLintOptions(lintOptions, sourceModule.hotcomments, mode); - - double timestamp = getTimestamp(); - - std::vector warnings = - Luau::lint(sourceModule.root, *sourceModule.names, environmentScope, module.get(), sourceModule.hotcomments, lintOptions); - - stats.timeLint += getTimestamp() - timestamp; - - module->lintResult = classifyLints(warnings, config); - } - - if (!frontendOptions.retainFullTypeGraphs) - { - // copyErrors needs to allocate into interfaceTypes as it copies - // types out of internalTypes, so we unfreeze it here. - unfreeze(module->interfaceTypes); - copyErrors(module->errors, module->interfaceTypes); - freeze(module->interfaceTypes); - - module->internalTypes.clear(); - - module->astTypes.clear(); - module->astTypePacks.clear(); - module->astExpectedTypes.clear(); - module->astOriginalCallTypes.clear(); - module->astOverloadResolvedTypes.clear(); - module->astResolvedTypes.clear(); - module->astResolvedTypePacks.clear(); - module->astScopes.clear(); - - module->scopes.clear(); - } - - if (mode != Mode::NoCheck) - { - for (const RequireCycle& cyc : requireCycles) - { - TypeError te{cyc.location, moduleName, ModuleHasCyclicDependency{cyc.path}}; - - module->errors.push_back(te); - } - } - - ErrorVec parseErrors; - - for (const ParseError& pe : sourceModule.parseErrors) - parseErrors.push_back(TypeError{pe.getLocation(), moduleName, SyntaxError{pe.what()}}); - - module->errors.insert(module->errors.begin(), parseErrors.begin(), parseErrors.end()); - - checkResult.errors.insert(checkResult.errors.end(), module->errors.begin(), module->errors.end()); - - resolver.setModule(moduleName, std::move(module)); - sourceNode.dirtyModule = false; - } - - // Get lint result only for top checked module - if (ModulePtr module = resolver.getModule(name)) - checkResult.lintResult = module->lintResult; - - return checkResult; -} - void Frontend::queueModuleCheck(const std::vector& names) { moduleQueue.insert(moduleQueue.end(), names.begin(), names.end()); @@ -996,8 +797,6 @@ bool Frontend::parseGraph( void Frontend::addBuildQueueItems(std::vector& items, std::vector& buildQueue, bool cycleDetected, std::unordered_set& seen, const FrontendOptions& frontendOptions) { - LUAU_ASSERT(FFlag::LuauSplitFrontendProcessing); - for (const ModuleName& moduleName : buildQueue) { if (seen.count(moduleName)) @@ -1038,8 +837,6 @@ void Frontend::addBuildQueueItems(std::vector& items, std::vecto void Frontend::checkBuildQueueItem(BuildQueueItem& item) { - LUAU_ASSERT(FFlag::LuauSplitFrontendProcessing); - SourceNode& sourceNode = *item.sourceNode; const SourceModule& sourceModule = *item.sourceModule; const Config& config = item.config; @@ -1139,7 +936,8 @@ void Frontend::checkBuildQueueItem(BuildQueueItem& item) module->astResolvedTypePacks.clear(); module->astScopes.clear(); - module->scopes.clear(); + if (!FFlag::DebugLuauDeferredConstraintResolution) + module->scopes.clear(); } if (mode != Mode::NoCheck) @@ -1164,8 +962,6 @@ void Frontend::checkBuildQueueItem(BuildQueueItem& item) void Frontend::checkBuildQueueItems(std::vector& items) { - LUAU_ASSERT(FFlag::LuauSplitFrontendProcessing); - for (BuildQueueItem& item : items) { checkBuildQueueItem(item); diff --git a/Analysis/src/IostreamHelpers.cpp b/Analysis/src/IostreamHelpers.cpp index 000bb140a..54a7dbff4 100644 --- a/Analysis/src/IostreamHelpers.cpp +++ b/Analysis/src/IostreamHelpers.cpp @@ -196,6 +196,10 @@ static void errorToString(std::ostream& stream, const T& err) stream << "UninhabitedTypeFamily { " << toString(err.ty) << " }"; else if constexpr (std::is_same_v) stream << "UninhabitedTypePackFamily { " << toString(err.tp) << " }"; + else if constexpr (std::is_same_v) + stream << "WhereClauseNeeded { " << toString(err.ty) << " }"; + else if constexpr (std::is_same_v) + stream << "PackWhereClauseNeeded { " << toString(err.tp) << " }"; else static_assert(always_false_v, "Non-exhaustive type switch"); } diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index 0addaa360..37af00401 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -19,6 +19,7 @@ LUAU_FASTFLAGVARIABLE(LuauClonePublicInterfaceLess2, false); LUAU_FASTFLAG(LuauSubstitutionReentrant); LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution); LUAU_FASTFLAG(LuauSubstitutionFixMissingFields); +LUAU_FASTFLAGVARIABLE(LuauCloneSkipNonInternalVisit, false); namespace Luau { @@ -98,6 +99,22 @@ struct ClonePublicInterface : Substitution return tp->owningArena == &module->internalTypes; } + bool ignoreChildrenVisit(TypeId ty) override + { + if (FFlag::LuauCloneSkipNonInternalVisit && ty->owningArena != &module->internalTypes) + return true; + + return false; + } + + bool ignoreChildrenVisit(TypePackId tp) override + { + if (FFlag::LuauCloneSkipNonInternalVisit && tp->owningArena != &module->internalTypes) + return true; + + return false; + } + TypeId clean(TypeId ty) override { TypeId result = clone(ty); diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp index 24c31f7ec..6a78bc667 100644 --- a/Analysis/src/Normalize.cpp +++ b/Analysis/src/Normalize.cpp @@ -21,6 +21,7 @@ LUAU_FASTFLAGVARIABLE(LuauNormalizeBlockedTypes, false); LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) LUAU_FASTFLAG(LuauUninhabitedSubAnything2) LUAU_FASTFLAG(LuauTransitiveSubtyping) +LUAU_FASTFLAG(DebugLuauReadWriteProperties) namespace Luau { @@ -277,6 +278,22 @@ bool Normalizer::isInhabited(const NormalizedType* norm, std::unordered_set seen) { // TODO: use log.follow(ty), CLI-64291 @@ -297,8 +314,18 @@ bool Normalizer::isInhabited(TypeId ty, std::unordered_set seen) { for (const auto& [_, prop] : ttv->props) { - if (!isInhabited(prop.type(), seen)) - return false; + if (FFlag::DebugLuauReadWriteProperties) + { + // A table enclosing a read property whose type is uninhabitable is also itself uninhabitable, + // but not its write property. That just means the write property doesn't exist, and so is readonly. + if (auto ty = prop.readType(); ty && !isInhabited(*ty, seen)) + return false; + } + else + { + if (!isInhabited(prop.type(), seen)) + return false; + } } return true; } @@ -314,14 +341,32 @@ bool Normalizer::isIntersectionInhabited(TypeId left, TypeId right) { left = follow(left); right = follow(right); + + if (cacheInhabitance) + { + if (bool* result = cachedIsInhabitedIntersection.find({left, right})) + return *result; + } + std::unordered_set seen = {}; seen.insert(left); seen.insert(right); NormalizedType norm{builtinTypes}; if (!normalizeIntersections({left, right}, norm)) + { + if (cacheInhabitance) + cachedIsInhabitedIntersection[{left, right}] = false; + return false; - return isInhabited(&norm, seen); + } + + bool result = isInhabited(&norm, seen); + + if (cacheInhabitance) + cachedIsInhabitedIntersection[{left, right}] = result; + + return result; } static int tyvarIndex(TypeId ty) @@ -568,10 +613,11 @@ static void assertInvariant(const NormalizedType& norm) #endif } -Normalizer::Normalizer(TypeArena* arena, NotNull builtinTypes, NotNull sharedState) +Normalizer::Normalizer(TypeArena* arena, NotNull builtinTypes, NotNull sharedState, bool cacheInhabitance) : arena(arena) , builtinTypes(builtinTypes) , sharedState(sharedState) + , cacheInhabitance(cacheInhabitance) { } @@ -1315,7 +1361,8 @@ bool Normalizer::withinResourceLimits() // If cache is too large, clear it if (FInt::LuauNormalizeCacheLimit > 0) { - size_t cacheUsage = cachedNormals.size() + cachedIntersections.size() + cachedUnions.size() + cachedTypeIds.size(); + size_t cacheUsage = cachedNormals.size() + cachedIntersections.size() + cachedUnions.size() + cachedTypeIds.size() + + cachedIsInhabited.size() + cachedIsInhabitedIntersection.size(); if (cacheUsage > size_t(FInt::LuauNormalizeCacheLimit)) { clearCaches(); @@ -2726,7 +2773,7 @@ bool isSubtype(TypeId subTy, TypeId superTy, NotNull scope, NotNull scope, N UnifierSharedState sharedState{&ice}; TypeArena arena; Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}}; - Unifier u{NotNull{&normalizer}, Mode::Strict, scope, Location{}, Covariant}; + Unifier u{NotNull{&normalizer}, scope, Location{}, Covariant}; u.tryUnify(subPack, superPack); return !u.failure; @@ -2750,7 +2797,7 @@ bool isConsistentSubtype(TypeId subTy, TypeId superTy, NotNull scope, Not UnifierSharedState sharedState{&ice}; TypeArena arena; Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}}; - Unifier u{NotNull{&normalizer}, Mode::Strict, scope, Location{}, Covariant}; + Unifier u{NotNull{&normalizer}, scope, Location{}, Covariant}; u.tryUnify(subTy, superTy); const bool ok = u.errors.empty() && u.log.empty(); @@ -2763,7 +2810,7 @@ bool isConsistentSubtype( UnifierSharedState sharedState{&ice}; TypeArena arena; Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}}; - Unifier u{NotNull{&normalizer}, Mode::Strict, scope, Location{}, Covariant}; + Unifier u{NotNull{&normalizer}, scope, Location{}, Covariant}; u.tryUnify(subPack, superPack); const bool ok = u.errors.empty() && u.log.empty(); diff --git a/Analysis/src/Substitution.cpp b/Analysis/src/Substitution.cpp index 40a495935..26cbdc683 100644 --- a/Analysis/src/Substitution.cpp +++ b/Analysis/src/Substitution.cpp @@ -13,6 +13,8 @@ LUAU_FASTFLAG(LuauClonePublicInterfaceLess2) LUAU_FASTINTVARIABLE(LuauTarjanChildLimit, 10000) LUAU_FASTFLAGVARIABLE(LuauClassTypeVarsInSubstitution, false) LUAU_FASTFLAGVARIABLE(LuauSubstitutionReentrant, false) +LUAU_FASTFLAG(DebugLuauReadWriteProperties) +LUAU_FASTFLAG(LuauCloneSkipNonInternalVisit) namespace Luau { @@ -214,7 +216,7 @@ void Tarjan::visitChildren(TypeId ty, int index) { LUAU_ASSERT(ty == log->follow(ty)); - if (ignoreChildren(ty)) + if (FFlag::LuauCloneSkipNonInternalVisit ? ignoreChildrenVisit(ty) : ignoreChildren(ty)) return; if (auto pty = log->pending(ty)) @@ -237,7 +239,16 @@ void Tarjan::visitChildren(TypeId ty, int index) { LUAU_ASSERT(!ttv->boundTo); for (const auto& [name, prop] : ttv->props) - visitChild(prop.type()); + { + if (FFlag::DebugLuauReadWriteProperties) + { + visitChild(prop.readType()); + visitChild(prop.writeType()); + } + else + visitChild(prop.type()); + } + if (ttv->indexer) { visitChild(ttv->indexer->indexType); @@ -311,7 +322,7 @@ void Tarjan::visitChildren(TypePackId tp, int index) { LUAU_ASSERT(tp == log->follow(tp)); - if (ignoreChildren(tp)) + if (FFlag::LuauCloneSkipNonInternalVisit ? ignoreChildrenVisit(tp) : ignoreChildren(tp)) return; if (auto ptp = log->pending(tp)) @@ -793,7 +804,13 @@ void Substitution::replaceChildren(TypeId ty) { LUAU_ASSERT(!ttv->boundTo); for (auto& [name, prop] : ttv->props) - prop.setType(replace(prop.type())); + { + if (FFlag::DebugLuauReadWriteProperties) + prop = Property::create(replace(prop.readType()), replace(prop.writeType())); + else + prop.setType(replace(prop.type())); + } + if (ttv->indexer) { ttv->indexer->indexType = replace(ttv->indexer->indexType); diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index 347380cd9..f4375b5ae 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -335,6 +335,44 @@ struct TypeStringifier tv->ty); } + void stringify(const std::string& name, const Property& prop) + { + if (isIdentifier(name)) + state.emit(name); + else + { + state.emit("[\""); + state.emit(escape(name)); + state.emit("\"]"); + } + state.emit(": "); + + if (FFlag::DebugLuauReadWriteProperties) + { + // We special case the stringification if the property's read and write types are shared. + if (prop.isShared()) + return stringify(*prop.readType()); + + // Otherwise emit them separately. + if (auto ty = prop.readType()) + { + state.emit("read "); + stringify(*ty); + } + + if (prop.readType() && prop.writeType()) + state.emit(" + "); + + if (auto ty = prop.writeType()) + { + state.emit("write "); + stringify(*ty); + } + } + else + stringify(prop.type()); + } + void stringify(TypePackId tp); void stringify(TypePackId tpid, const std::vector>& names); @@ -672,16 +710,8 @@ struct TypeStringifier break; } - if (isIdentifier(name)) - state.emit(name); - else - { - state.emit("[\""); - state.emit(escape(name)); - state.emit("\"]"); - } - state.emit(": "); - stringify(prop.type()); + stringify(name, prop); + comma = true; ++index; } diff --git a/Analysis/src/Type.cpp b/Analysis/src/Type.cpp index e8a2bc5d4..2aa13bc92 100644 --- a/Analysis/src/Type.cpp +++ b/Analysis/src/Type.cpp @@ -27,7 +27,6 @@ LUAU_FASTINT(LuauTypeInferRecursionLimit) LUAU_FASTFLAG(LuauInstantiateInSubtyping) LUAU_FASTFLAG(LuauNormalizeBlockedTypes) LUAU_FASTFLAG(DebugLuauReadWriteProperties) -LUAU_FASTFLAGVARIABLE(LuauBoundLazyTypes2, false) namespace Luau { @@ -78,65 +77,31 @@ TypeId follow(TypeId t) TypeId follow(TypeId t, const void* context, TypeId (*mapper)(const void*, TypeId)) { auto advance = [context, mapper](TypeId ty) -> std::optional { - if (FFlag::LuauBoundLazyTypes2) - { - TypeId mapped = mapper(context, ty); + TypeId mapped = mapper(context, ty); - if (auto btv = get>(mapped)) - return btv->boundTo; + if (auto btv = get>(mapped)) + return btv->boundTo; - if (auto ttv = get(mapped)) - return ttv->boundTo; + if (auto ttv = get(mapped)) + return ttv->boundTo; - if (auto ltv = getMutable(mapped)) - return unwrapLazy(ltv); + if (auto ltv = getMutable(mapped)) + return unwrapLazy(ltv); - return std::nullopt; - } - else - { - if (auto btv = get>(mapper(context, ty))) - return btv->boundTo; - else if (auto ttv = get(mapper(context, ty))) - return ttv->boundTo; - else - return std::nullopt; - } - }; - - auto force = [context, mapper](TypeId ty) { - TypeId mapped = mapper(context, ty); - - if (auto ltv = get_if(&mapped->ty)) - { - TypeId res = ltv->thunk_DEPRECATED(); - if (get(res)) - throw InternalCompilerError("Lazy Type cannot resolve to another Lazy Type"); - - *asMutable(ty) = BoundType(res); - } + return std::nullopt; }; - if (!FFlag::LuauBoundLazyTypes2) - force(t); - TypeId cycleTester = t; // Null once we've determined that there is no cycle if (auto a = advance(cycleTester)) cycleTester = *a; else return t; - if (FFlag::LuauBoundLazyTypes2) - { - if (!advance(cycleTester)) // Short circuit traversal for the rather common case when advance(advance(t)) == null - return cycleTester; - } + if (!advance(cycleTester)) // Short circuit traversal for the rather common case when advance(advance(t)) == null + return cycleTester; while (true) { - if (!FFlag::LuauBoundLazyTypes2) - force(t); - auto a1 = advance(t); if (a1) t = *a1; @@ -684,16 +649,17 @@ Property Property::rw(TypeId read, TypeId write) return p; } -std::optional Property::create(std::optional read, std::optional write) +Property Property::create(std::optional read, std::optional write) { if (read && !write) return Property::readonly(*read); else if (!read && write) return Property::writeonly(*write); - else if (read && write) - return Property::rw(*read, *write); else - return std::nullopt; + { + LUAU_ASSERT(read && write); + return Property::rw(*read, *write); + } } TypeId Property::type() const @@ -705,6 +671,7 @@ TypeId Property::type() const void Property::setType(TypeId ty) { + LUAU_ASSERT(!FFlag::DebugLuauReadWriteProperties); readTy = ty; } @@ -722,6 +689,11 @@ std::optional Property::writeType() const return writeTy; } +bool Property::isShared() const +{ + return readTy && writeTy && readTy == writeTy; +} + TableType::TableType(TableState state, TypeLevel level, Scope* scope) : state(state) , level(level) diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index 40376e32a..c1146b5c8 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -13,9 +13,11 @@ #include "Luau/ToString.h" #include "Luau/TxnLog.h" #include "Luau/Type.h" +#include "Luau/TypePack.h" #include "Luau/TypeUtils.h" #include "Luau/Unifier.h" #include "Luau/TypeFamily.h" +#include "Luau/VisitType.h" #include @@ -81,6 +83,146 @@ static std::optional getIdentifierOfBaseVar(AstExpr* node) return std::nullopt; } +template +bool areEquivalent(const T& a, const T& b) +{ + if (a.family != b.family) + return false; + + if (a.typeArguments.size() != b.typeArguments.size() || a.packArguments.size() != b.packArguments.size()) + return false; + + for (size_t i = 0; i < a.typeArguments.size(); ++i) + { + if (follow(a.typeArguments[i]) != follow(b.typeArguments[i])) + return false; + } + + for (size_t i = 0; i < a.packArguments.size(); ++i) + { + if (follow(a.packArguments[i]) != follow(b.packArguments[i])) + return false; + } + + return true; +} + +struct FamilyFinder : TypeOnceVisitor +{ + DenseHashSet mentionedFamilies{nullptr}; + DenseHashSet mentionedFamilyPacks{nullptr}; + + bool visit(TypeId ty, const TypeFamilyInstanceType&) override + { + mentionedFamilies.insert(ty); + return true; + } + + bool visit(TypePackId tp, const TypeFamilyInstanceTypePack&) override + { + mentionedFamilyPacks.insert(tp); + return true; + } +}; + +struct InternalFamilyFinder : TypeOnceVisitor +{ + DenseHashSet internalFamilies{nullptr}; + DenseHashSet internalPackFamilies{nullptr}; + DenseHashSet mentionedFamilies{nullptr}; + DenseHashSet mentionedFamilyPacks{nullptr}; + + InternalFamilyFinder(std::vector& declStack) + { + FamilyFinder f; + for (TypeId fn : declStack) + f.traverse(fn); + + mentionedFamilies = std::move(f.mentionedFamilies); + mentionedFamilyPacks = std::move(f.mentionedFamilyPacks); + } + + bool visit(TypeId ty, const TypeFamilyInstanceType& tfit) override + { + bool hasGeneric = false; + + for (TypeId p : tfit.typeArguments) + { + if (get(follow(p))) + { + hasGeneric = true; + break; + } + } + + for (TypePackId p : tfit.packArguments) + { + if (get(follow(p))) + { + hasGeneric = true; + break; + } + } + + if (hasGeneric) + { + for (TypeId mentioned : mentionedFamilies) + { + const TypeFamilyInstanceType* mentionedTfit = get(mentioned); + LUAU_ASSERT(mentionedTfit); + if (areEquivalent(tfit, *mentionedTfit)) + { + return true; + } + } + + internalFamilies.insert(ty); + } + + return true; + } + + bool visit(TypePackId tp, const TypeFamilyInstanceTypePack& tfitp) override + { + bool hasGeneric = false; + + for (TypeId p : tfitp.typeArguments) + { + if (get(follow(p))) + { + hasGeneric = true; + break; + } + } + + for (TypePackId p : tfitp.packArguments) + { + if (get(follow(p))) + { + hasGeneric = true; + break; + } + } + + if (hasGeneric) + { + for (TypePackId mentioned : mentionedFamilyPacks) + { + const TypeFamilyInstanceTypePack* mentionedTfitp = get(mentioned); + LUAU_ASSERT(mentionedTfitp); + if (areEquivalent(tfitp, *mentionedTfitp)) + { + return true; + } + } + + internalPackFamilies.insert(tp); + } + + return true; + } +}; + struct TypeChecker2 { NotNull builtinTypes; @@ -91,16 +233,20 @@ struct TypeChecker2 TypeArena testArena; std::vector> stack; + std::vector functionDeclStack; + + DenseHashSet noTypeFamilyErrors{nullptr}; Normalizer normalizer; - TypeChecker2(NotNull builtinTypes, NotNull unifierState, DcrLogger* logger, const SourceModule* sourceModule, Module* module) + TypeChecker2(NotNull builtinTypes, NotNull unifierState, DcrLogger* logger, const SourceModule* sourceModule, + Module* module) : builtinTypes(builtinTypes) , logger(logger) , ice(unifierState->iceHandler) , sourceModule(sourceModule) , module(module) - , normalizer{&testArena, builtinTypes, unifierState} + , normalizer{&testArena, builtinTypes, unifierState, /* cacheInhabitance */ true} { } @@ -112,10 +258,31 @@ struct TypeChecker2 return std::nullopt; } + void checkForInternalFamily(TypeId ty, Location location) + { + InternalFamilyFinder finder(functionDeclStack); + finder.traverse(ty); + + for (TypeId internal : finder.internalFamilies) + reportError(WhereClauseNeeded{internal}, location); + + for (TypePackId internal : finder.internalPackFamilies) + reportError(PackWhereClauseNeeded{internal}, location); + } + TypeId checkForFamilyInhabitance(TypeId instance, Location location) { + if (noTypeFamilyErrors.find(instance)) + return instance; + TxnLog fake{}; - reportErrors(reduceFamilies(instance, location, NotNull{&testArena}, builtinTypes, stack.back(), NotNull{&normalizer}, &fake, true).errors); + ErrorVec errors = + reduceFamilies(instance, location, NotNull{&testArena}, builtinTypes, stack.back(), NotNull{&normalizer}, &fake, true).errors; + + if (errors.empty()) + noTypeFamilyErrors.insert(instance); + + reportErrors(std::move(errors)); return instance; } @@ -316,7 +483,7 @@ struct TypeChecker2 TypeArena* arena = &testArena; TypePackId actualRetType = reconstructPack(ret->list, *arena); - Unifier u{NotNull{&normalizer}, Mode::Strict, stack.back(), ret->location, Covariant}; + Unifier u{NotNull{&normalizer}, stack.back(), ret->location, Covariant}; u.hideousFixMeGenericsAreActuallyFree = true; u.tryUnify(actualRetType, expectedRetType); @@ -466,12 +633,47 @@ struct TypeChecker2 variableTypes.emplace_back(*ty); } - // ugh. There's nothing in the AST to hang a whole type pack on for the - // set of iteratees, so we have to piece it back together by hand. + AstExpr* firstValue = forInStatement->values.data[0]; + + // we need to build up a typepack for the iterators/values portion of the for-in statement. std::vector valueTypes; - for (size_t i = 0; i < forInStatement->values.size - 1; ++i) + std::optional iteratorTail; + + // since the first value may be the only iterator (e.g. if it is a call), we want to + // look to see if it has a resulting typepack as our iterators. + TypePackId* retPack = module->astTypePacks.find(firstValue); + if (retPack) + { + auto [head, tail] = flatten(*retPack); + valueTypes = head; + iteratorTail = tail; + } + else + { + valueTypes.emplace_back(lookupType(firstValue)); + } + + // if the initial and expected types from the iterator unified during constraint solving, + // we'll have a resolved type to use here, but we'll only use it if either the iterator is + // directly present in the for-in statement or if we have an iterator state constraining us + TypeId* resolvedTy = module->astOverloadResolvedTypes.find(firstValue); + if (resolvedTy && (!retPack || valueTypes.size() > 1)) + valueTypes[0] = *resolvedTy; + + for (size_t i = 1; i < forInStatement->values.size - 1; ++i) + { valueTypes.emplace_back(lookupType(forInStatement->values.data[i])); - TypePackId iteratorTail = lookupPack(forInStatement->values.data[forInStatement->values.size - 1]); + } + + // if we had more than one value, the tail from the first value is no longer appropriate to use. + if (forInStatement->values.size > 1) + { + auto [head, tail] = flatten(lookupPack(forInStatement->values.data[forInStatement->values.size - 1])); + valueTypes.insert(valueTypes.end(), head.begin(), head.end()); + iteratorTail = tail; + } + + // and now we can put everything together to get the actual typepack of the iterators. TypePackId iteratorPack = arena.addTypePack(valueTypes, iteratorTail); // ... and then expand it out to 3 values (if possible) @@ -518,26 +720,16 @@ struct TypeChecker2 // This depends on the types in iterateePack and therefore // iteratorTypes. + // If the iteratee is an error type, then we can't really say anything else about iteration over it. + // After all, it _could've_ been a table. + if (get(follow(flattenPack(iterFtv->argTypes)))) + return; + // If iteratorTypes is too short to be a valid call to nextFn, we have to report a count mismatch error. // If 2 is too short to be a valid call to nextFn, we have to report a count mismatch error. // If 2 is too long to be a valid call to nextFn, we have to report a count mismatch error. auto [minCount, maxCount] = getParameterExtents(TxnLog::empty(), iterFtv->argTypes, /*includeHiddenVariadics*/ true); - if (minCount > 2) - { - if (isMm) - reportError(GenericError{"__iter metamethod must return (next[, table[, state]])"}, getLocation(forInStatement->values)); - else - reportError(GenericError{"for..in loops must be passed (next[, table[, state]])"}, getLocation(forInStatement->values)); - } - if (maxCount && *maxCount < 2) - { - if (isMm) - reportError(GenericError{"__iter metamethod must return (next[, table[, state]])"}, getLocation(forInStatement->values)); - else - reportError(GenericError{"for..in loops must be passed (next[, table[, state]])"}, getLocation(forInStatement->values)); - } - TypePack flattenedArgTypes = extendTypePack(arena, builtinTypes, iterFtv->argTypes, 2); size_t firstIterationArgCount = iterTys.empty() ? 0 : iterTys.size() - 1; size_t actualArgCount = expectedVariableTypes.head.size(); @@ -546,7 +738,7 @@ struct TypeChecker2 if (isMm) reportError(GenericError{"__iter metamethod must return (next[, table[, state]])"}, getLocation(forInStatement->values)); else - reportError(CountMismatch{2, std::nullopt, firstIterationArgCount, CountMismatch::Arg}, forInStatement->vars.data[0]->location); + reportError(CountMismatch{2, std::nullopt, firstIterationArgCount, CountMismatch::Arg}, forInStatement->values.data[0]->location); } else if (actualArgCount < minCount) @@ -554,7 +746,7 @@ struct TypeChecker2 if (isMm) reportError(GenericError{"__iter metamethod must return (next[, table[, state]])"}, getLocation(forInStatement->values)); else - reportError(CountMismatch{2, std::nullopt, firstIterationArgCount, CountMismatch::Arg}, forInStatement->vars.data[0]->location); + reportError(CountMismatch{2, std::nullopt, firstIterationArgCount, CountMismatch::Arg}, forInStatement->values.data[0]->location); } @@ -1211,6 +1403,7 @@ struct TypeChecker2 visitGenerics(fn->generics, fn->genericPacks); TypeId inferredFnTy = lookupType(fn); + functionDeclStack.push_back(inferredFnTy); const NormalizedType* normalizedFnTy = normalizer.normalize(inferredFnTy); if (!normalizedFnTy) @@ -1260,6 +1453,8 @@ struct TypeChecker2 } visit(fn->body); + + functionDeclStack.pop_back(); } void visit(AstExprTable* expr) @@ -1370,7 +1565,10 @@ struct TypeChecker2 TypeId expectedResult = lookupType(expr); if (get(expectedResult)) + { + checkForInternalFamily(expectedResult, expr->location); return expectedResult; + } if (expr->op == AstExprBinary::Op::Or) { @@ -1379,9 +1577,9 @@ struct TypeChecker2 bool isStringOperation = isString(leftType) && isString(rightType); - if (get(leftType) || get(leftType)) + if (get(leftType) || get(leftType) || get(leftType)) return leftType; - else if (get(rightType) || get(rightType)) + else if (get(rightType) || get(rightType) || get(rightType)) return rightType; if ((get(leftType) || get(leftType) || get(leftType)) && !isEquality && !isLogical) @@ -1982,7 +2180,7 @@ struct TypeChecker2 bool isSubtype(TID subTy, TID superTy, NotNull scope, bool genericsOkay = false) { TypeArena arena; - Unifier u{NotNull{&normalizer}, Mode::Strict, scope, Location{}, Covariant}; + Unifier u{NotNull{&normalizer}, scope, Location{}, Covariant}; u.hideousFixMeGenericsAreActuallyFree = genericsOkay; u.enableScopeTests(); @@ -1995,7 +2193,7 @@ struct TypeChecker2 ErrorVec tryUnify(NotNull scope, const Location& location, TID subTy, TID superTy, CountMismatch::Context context = CountMismatch::Arg, bool genericsOkay = false) { - Unifier u{NotNull{&normalizer}, Mode::Strict, scope, location, Covariant}; + Unifier u{NotNull{&normalizer}, scope, location, Covariant}; u.ctx = context; u.hideousFixMeGenericsAreActuallyFree = genericsOkay; u.enableScopeTests(); diff --git a/Analysis/src/TypeFamily.cpp b/Analysis/src/TypeFamily.cpp index e5a06c0a4..98a9f97ed 100644 --- a/Analysis/src/TypeFamily.cpp +++ b/Analysis/src/TypeFamily.cpp @@ -301,6 +301,9 @@ FamilyGraphReductionResult reduceFamilies(TypeId entrypoint, Location location, return FamilyGraphReductionResult{}; } + if (collector.tys.empty() && collector.tps.empty()) + return {}; + return reduceFamiliesInternal(std::move(collector.tys), std::move(collector.tps), location, arena, builtins, scope, normalizer, log, force); } @@ -318,6 +321,9 @@ FamilyGraphReductionResult reduceFamilies(TypePackId entrypoint, Location locati return FamilyGraphReductionResult{}; } + if (collector.tys.empty() && collector.tps.empty()) + return {}; + return reduceFamiliesInternal(std::move(collector.tys), std::move(collector.tps), location, arena, builtins, scope, normalizer, log, force); } @@ -338,8 +344,10 @@ TypeFamilyReductionResult addFamilyFn(std::vector typeParams, st TypeId lhsTy = log->follow(typeParams.at(0)); TypeId rhsTy = log->follow(typeParams.at(1)); + const NormalizedType* normLhsTy = normalizer->normalize(lhsTy); + const NormalizedType* normRhsTy = normalizer->normalize(rhsTy); - if (isNumber(lhsTy) && isNumber(rhsTy)) + if (normLhsTy && normRhsTy && normLhsTy->isNumber() && normRhsTy->isNumber()) { return {builtins->numberType, false, {}, {}}; } @@ -398,7 +406,7 @@ TypeFamilyReductionResult addFamilyFn(std::vector typeParams, st inferredArgs = {rhsTy, lhsTy}; TypePackId inferredArgPack = arena->addTypePack(std::move(inferredArgs)); - Unifier u{normalizer, Mode::Strict, scope, Location{}, Variance::Covariant, log.get()}; + Unifier u{normalizer, scope, Location{}, Variance::Covariant, log.get()}; u.tryUnify(inferredArgPack, instantiatedMmFtv->argTypes); if (std::optional ret = first(instantiatedMmFtv->retTypes); ret && u.errors.empty()) diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 7e6803990..ecf222a84 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -39,7 +39,6 @@ LUAU_FASTFLAG(LuauUninhabitedSubAnything2) LUAU_FASTFLAG(LuauOccursIsntAlwaysFailure) LUAU_FASTFLAGVARIABLE(LuauTypecheckTypeguards, false) LUAU_FASTFLAGVARIABLE(LuauTinyControlFlowAnalysis, false) -LUAU_FASTFLAG(LuauRequirePathTrueModuleName) LUAU_FASTFLAGVARIABLE(LuauTypecheckClassTypeIndexers, false) namespace Luau @@ -2769,8 +2768,9 @@ TypeId TypeChecker::checkRelationalOperation( std::string metamethodName = opToMetaTableEntry(expr.op); - std::optional leftMetatable = isString(lhsType) ? std::nullopt : getMetatable(follow(lhsType), builtinTypes); - std::optional rightMetatable = isString(rhsType) ? std::nullopt : getMetatable(follow(rhsType), builtinTypes); + std::optional stringNoMT = std::nullopt; // works around gcc false positive "maybe uninitialized" warnings + std::optional leftMetatable = isString(lhsType) ? stringNoMT : getMetatable(follow(lhsType), builtinTypes); + std::optional rightMetatable = isString(rhsType) ? stringNoMT : getMetatable(follow(rhsType), builtinTypes); if (leftMetatable != rightMetatable) { @@ -4676,7 +4676,7 @@ TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& module // Types of requires that transitively refer to current module have to be replaced with 'any' for (const auto& [location, path] : requireCycles) { - if (!path.empty() && path.front() == (FFlag::LuauRequirePathTrueModuleName ? moduleInfo.name : resolver->getHumanReadableModuleName(moduleInfo.name))) + if (!path.empty() && path.front() == moduleInfo.name) return anyType; } @@ -5043,7 +5043,7 @@ void TypeChecker::merge(RefinementMap& l, const RefinementMap& r) Unifier TypeChecker::mkUnifier(const ScopePtr& scope, const Location& location) { - return Unifier{NotNull{&normalizer}, currentModule->mode, NotNull{scope.get()}, location, Variance::Covariant}; + return Unifier{NotNull{&normalizer}, NotNull{scope.get()}, location, Variance::Covariant}; } TypeId TypeChecker::freshType(const ScopePtr& scope) diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 76428cf97..9a12234bd 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -396,11 +396,10 @@ TypeMismatch::Context Unifier::mismatchContext() } } -Unifier::Unifier(NotNull normalizer, Mode mode, NotNull scope, const Location& location, Variance variance, TxnLog* parentLog) +Unifier::Unifier(NotNull normalizer, NotNull scope, const Location& location, Variance variance, TxnLog* parentLog) : types(normalizer->arena) , builtinTypes(normalizer->builtinTypes) , normalizer(normalizer) - , mode(mode) , scope(scope) , log(parentLog) , location(location) @@ -423,6 +422,12 @@ static bool isBlocked(const TxnLog& log, TypeId ty) return get(ty) || get(ty); } +static bool isBlocked(const TxnLog& log, TypePackId tp) +{ + tp = log.follow(tp); + return get(tp); +} + void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool isIntersection) { RecursionLimiter _ra(&sharedState.counters.recursionCount, sharedState.counters.recursionLimit); @@ -1761,6 +1766,19 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal if (log.haveSeen(superTp, subTp)) return; + if (isBlocked(log, subTp) && isBlocked(log, superTp)) + { + blockedTypePacks.push_back(subTp); + blockedTypePacks.push_back(superTp); + } + else if (isBlocked(log, subTp)) + { + blockedTypePacks.push_back(subTp); + } + else if (isBlocked(log, superTp)) + { + blockedTypePacks.push_back(superTp); + } if (log.getMutable(superTp)) { if (!occursCheck(superTp, subTp, /* reversed = */ true)) @@ -2795,7 +2813,12 @@ void Unifier::tryUnifyVariadics(TypePackId subTp, TypePackId superTp, bool rever if (std::optional maybeTail = subIter.tail()) { TypePackId tail = follow(*maybeTail); - if (get(tail)) + + if (isBlocked(log, tail)) + { + blockedTypePacks.push_back(tail); + } + else if (get(tail)) { log.replace(tail, BoundTypePack(superTp)); } @@ -3094,7 +3117,7 @@ bool Unifier::occursCheck(DenseHashSet& seen, TypePackId needle, Typ Unifier Unifier::makeChildUnifier() { - Unifier u = Unifier{normalizer, mode, scope, location, variance, &log}; + Unifier u = Unifier{normalizer, scope, location, variance, &log}; u.normalize = normalize; u.checkInhabited = checkInhabited; @@ -3125,12 +3148,6 @@ void Unifier::reportError(TypeError err) failure = true; } - -bool Unifier::isNonstrictMode() const -{ - return (mode == Mode::Nonstrict) || (mode == Mode::NoCheck); -} - void Unifier::checkChildUnifierTypeMismatch(const ErrorVec& innerErrors, TypeId wantedType, TypeId givenType) { if (auto e = hasUnificationTooComplex(innerErrors)) diff --git a/CodeGen/include/Luau/AssemblyBuilderA64.h b/CodeGen/include/Luau/AssemblyBuilderA64.h index 09acfb4a9..3fc37d1d7 100644 --- a/CodeGen/include/Luau/AssemblyBuilderA64.h +++ b/CodeGen/include/Luau/AssemblyBuilderA64.h @@ -157,6 +157,8 @@ class AssemblyBuilderA64 void fcmpz(RegisterA64 src); void fcsel(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, ConditionA64 cond); + void udf(); + // Run final checks bool finalize(); diff --git a/CodeGen/include/Luau/AssemblyBuilderX64.h b/CodeGen/include/Luau/AssemblyBuilderX64.h index a372bf911..9e7d50116 100644 --- a/CodeGen/include/Luau/AssemblyBuilderX64.h +++ b/CodeGen/include/Luau/AssemblyBuilderX64.h @@ -99,6 +99,7 @@ class AssemblyBuilderX64 void call(OperandX64 op); void int3(); + void ud2(); void bsr(RegisterX64 dst, OperandX64 src); void bsf(RegisterX64 dst, OperandX64 src); diff --git a/CodeGen/include/Luau/IrBuilder.h b/CodeGen/include/Luau/IrBuilder.h index 3b09359ec..60106d1b5 100644 --- a/CodeGen/include/Luau/IrBuilder.h +++ b/CodeGen/include/Luau/IrBuilder.h @@ -38,7 +38,6 @@ struct IrBuilder IrOp undef(); - IrOp constBool(bool value); IrOp constInt(int value); IrOp constUint(unsigned value); IrOp constDouble(double value); diff --git a/CodeGen/include/Luau/IrData.h b/CodeGen/include/Luau/IrData.h index 4bc9c8237..0e17cba9e 100644 --- a/CodeGen/include/Luau/IrData.h +++ b/CodeGen/include/Luau/IrData.h @@ -283,7 +283,7 @@ enum class IrCmd : uint8_t // A: builtin // B: Rn (result start) // C: Rn (argument start) - // D: Rn or Kn or a boolean that's false (optional second argument) + // D: Rn or Kn or undef (optional second argument) // E: int (argument count) // F: int (result count) FASTCALL, @@ -292,7 +292,7 @@ enum class IrCmd : uint8_t // A: builtin // B: Rn (result start) // C: Rn (argument start) - // D: Rn or Kn or a boolean that's false (optional second argument) + // D: Rn or Kn or undef (optional second argument) // E: int (argument count or -1 to use all arguments up to stack top) // F: int (result count or -1 to preserve all results and adjust stack top) INVOKE_FASTCALL, @@ -360,39 +360,46 @@ enum class IrCmd : uint8_t // Guard against tag mismatch // A, B: tag - // C: block + // C: block/undef // In final x64 lowering, A can also be Rn + // When undef is specified instead of a block, execution is aborted on check failure CHECK_TAG, // Guard against readonly table // A: pointer (Table) - // B: block + // B: block/undef + // When undef is specified instead of a block, execution is aborted on check failure CHECK_READONLY, // Guard against table having a metatable // A: pointer (Table) - // B: block + // B: block/undef + // When undef is specified instead of a block, execution is aborted on check failure CHECK_NO_METATABLE, // Guard against executing in unsafe environment - // A: block + // A: block/undef + // When undef is specified instead of a block, execution is aborted on check failure CHECK_SAFE_ENV, // Guard against index overflowing the table array size // A: pointer (Table) // B: int (index) - // C: block + // C: block/undef + // When undef is specified instead of a block, execution is aborted on check failure CHECK_ARRAY_SIZE, // Guard against cached table node slot not matching the actual table node slot for a key // A: pointer (LuaNode) // B: Kn - // C: block + // C: block/undef + // When undef is specified instead of a block, execution is aborted on check failure CHECK_SLOT_MATCH, // Guard against table node with a linked next node to ensure that our lookup hits the main position of the key // A: pointer (LuaNode) - // B: block + // B: block/undef + // When undef is specified instead of a block, execution is aborted on check failure CHECK_NODE_NO_NEXT, // Special operations @@ -428,7 +435,7 @@ enum class IrCmd : uint8_t // While capture is a no-op right now, it might be useful to track register/upvalue lifetimes // A: Rn or UPn - // B: boolean (true for reference capture, false for value capture) + // B: unsigned int (1 for reference capture, 0 for value capture) CAPTURE, // Operations that don't have an IR representation yet @@ -581,7 +588,6 @@ enum class IrCmd : uint8_t enum class IrConstKind : uint8_t { - Bool, Int, Uint, Double, @@ -867,27 +873,6 @@ struct IrFunction return value.valueTag; } - bool boolOp(IrOp op) - { - IrConst& value = constOp(op); - - LUAU_ASSERT(value.kind == IrConstKind::Bool); - return value.valueBool; - } - - std::optional asBoolOp(IrOp op) - { - if (op.kind != IrOpKind::Constant) - return std::nullopt; - - IrConst& value = constOp(op); - - if (value.kind != IrConstKind::Bool) - return std::nullopt; - - return value.valueBool; - } - int intOp(IrOp op) { IrConst& value = constOp(op); diff --git a/CodeGen/src/AssemblyBuilderA64.cpp b/CodeGen/src/AssemblyBuilderA64.cpp index 000dc85fd..99a68481e 100644 --- a/CodeGen/src/AssemblyBuilderA64.cpp +++ b/CodeGen/src/AssemblyBuilderA64.cpp @@ -687,6 +687,11 @@ void AssemblyBuilderA64::fcsel(RegisterA64 dst, RegisterA64 src1, RegisterA64 sr placeCS("fcsel", dst, src1, src2, cond, 0b11110'01'1, 0b11); } +void AssemblyBuilderA64::udf() +{ + place0("udf", 0); +} + bool AssemblyBuilderA64::finalize() { code.resize(codePos - code.data()); diff --git a/CodeGen/src/AssemblyBuilderX64.cpp b/CodeGen/src/AssemblyBuilderX64.cpp index f0ee500cb..c7644a86c 100644 --- a/CodeGen/src/AssemblyBuilderX64.cpp +++ b/CodeGen/src/AssemblyBuilderX64.cpp @@ -472,6 +472,15 @@ void AssemblyBuilderX64::int3() commit(); } +void AssemblyBuilderX64::ud2() +{ + if (logText) + log("ud2"); + + place(0x0f); + place(0x0b); +} + void AssemblyBuilderX64::bsr(RegisterX64 dst, OperandX64 src) { if (logText) diff --git a/CodeGen/src/CodeAllocator.cpp b/CodeGen/src/CodeAllocator.cpp index 4d04a249f..09e1bb712 100644 --- a/CodeGen/src/CodeAllocator.cpp +++ b/CodeGen/src/CodeAllocator.cpp @@ -51,13 +51,13 @@ static void makePagesExecutable(uint8_t* mem, size_t size) DWORD oldProtect; if (VirtualProtect(mem, size, PAGE_EXECUTE_READ, &oldProtect) == 0) - LUAU_ASSERT(!"failed to change page protection"); + LUAU_ASSERT(!"Failed to change page protection"); } static void flushInstructionCache(uint8_t* mem, size_t size) { if (FlushInstructionCache(GetCurrentProcess(), mem, size) == 0) - LUAU_ASSERT(!"failed to flush instruction cache"); + LUAU_ASSERT(!"Failed to flush instruction cache"); } #else static uint8_t* allocatePages(size_t size) @@ -68,7 +68,7 @@ static uint8_t* allocatePages(size_t size) static void freePages(uint8_t* mem, size_t size) { if (munmap(mem, alignToPageSize(size)) != 0) - LUAU_ASSERT(!"failed to deallocate block memory"); + LUAU_ASSERT(!"Failed to deallocate block memory"); } static void makePagesExecutable(uint8_t* mem, size_t size) @@ -77,7 +77,7 @@ static void makePagesExecutable(uint8_t* mem, size_t size) LUAU_ASSERT(size == alignToPageSize(size)); if (mprotect(mem, size, PROT_READ | PROT_EXEC) != 0) - LUAU_ASSERT(!"failed to change page protection"); + LUAU_ASSERT(!"Failed to change page protection"); } static void flushInstructionCache(uint8_t* mem, size_t size) diff --git a/CodeGen/src/CodeBlockUnwind.cpp b/CodeGen/src/CodeBlockUnwind.cpp index 59ee6f138..e9ce86747 100644 --- a/CodeGen/src/CodeBlockUnwind.cpp +++ b/CodeGen/src/CodeBlockUnwind.cpp @@ -79,7 +79,7 @@ void* createBlockUnwindInfo(void* context, uint8_t* block, size_t blockSize, siz #if defined(_WIN32) && defined(_M_X64) if (!RtlAddFunctionTable((RUNTIME_FUNCTION*)block, uint32_t(unwind->getFunctionCount()), uintptr_t(block))) { - LUAU_ASSERT(!"failed to allocate function table"); + LUAU_ASSERT(!"Failed to allocate function table"); return nullptr; } #elif defined(__linux__) || defined(__APPLE__) @@ -94,7 +94,7 @@ void destroyBlockUnwindInfo(void* context, void* unwindData) { #if defined(_WIN32) && defined(_M_X64) if (!RtlDeleteFunctionTable((RUNTIME_FUNCTION*)unwindData)) - LUAU_ASSERT(!"failed to deallocate function table"); + LUAU_ASSERT(!"Failed to deallocate function table"); #elif defined(__linux__) || defined(__APPLE__) visitFdeEntries((char*)unwindData, __deregister_frame); #endif diff --git a/CodeGen/src/CodeGen.cpp b/CodeGen/src/CodeGen.cpp index 646038347..4ee8e4440 100644 --- a/CodeGen/src/CodeGen.cpp +++ b/CodeGen/src/CodeGen.cpp @@ -18,7 +18,6 @@ #include "Luau/AssemblyBuilderA64.h" #include "Luau/AssemblyBuilderX64.h" -#include "CustomExecUtils.h" #include "NativeState.h" #include "CodeGenA64.h" @@ -59,6 +58,8 @@ namespace Luau namespace CodeGen { +static const Instruction kCodeEntryInsn = LOP_NATIVECALL; + static void* gPerfLogContext = nullptr; static PerfLogFn gPerfLogFn = nullptr; @@ -332,9 +333,15 @@ static std::optional assembleFunction(AssemblyBuilder& build, Nativ return createNativeProto(proto, ir); } +static NativeState* getNativeState(lua_State* L) +{ + return static_cast(L->global->ecb.context); +} + static void onCloseState(lua_State* L) { - destroyNativeState(L); + delete getNativeState(L); + L->global->ecb = lua_ExecutionCallbacks(); } static void onDestroyFunction(lua_State* L, Proto* proto) @@ -342,6 +349,7 @@ static void onDestroyFunction(lua_State* L, Proto* proto) destroyExecData(proto->execdata); proto->execdata = nullptr; proto->exectarget = 0; + proto->codeentry = proto->code; } static int onEnter(lua_State* L, Proto* proto) @@ -362,7 +370,7 @@ static void onSetBreakpoint(lua_State* L, Proto* proto, int instruction) if (!proto->execdata) return; - LUAU_ASSERT(!"native breakpoints are not implemented"); + LUAU_ASSERT(!"Native breakpoints are not implemented"); } #if defined(__aarch64__) @@ -430,39 +438,34 @@ void create(lua_State* L) { LUAU_ASSERT(isSupported()); - NativeState& data = *createNativeState(L); + std::unique_ptr data = std::make_unique(); #if defined(_WIN32) - data.unwindBuilder = std::make_unique(); + data->unwindBuilder = std::make_unique(); #else - data.unwindBuilder = std::make_unique(); + data->unwindBuilder = std::make_unique(); #endif - data.codeAllocator.context = data.unwindBuilder.get(); - data.codeAllocator.createBlockUnwindInfo = createBlockUnwindInfo; - data.codeAllocator.destroyBlockUnwindInfo = destroyBlockUnwindInfo; + data->codeAllocator.context = data->unwindBuilder.get(); + data->codeAllocator.createBlockUnwindInfo = createBlockUnwindInfo; + data->codeAllocator.destroyBlockUnwindInfo = destroyBlockUnwindInfo; - initFunctions(data); + initFunctions(*data); #if defined(__x86_64__) || defined(_M_X64) - if (!X64::initHeaderFunctions(data)) - { - destroyNativeState(L); + if (!X64::initHeaderFunctions(*data)) return; - } #elif defined(__aarch64__) - if (!A64::initHeaderFunctions(data)) - { - destroyNativeState(L); + if (!A64::initHeaderFunctions(*data)) return; - } #endif if (gPerfLogFn) - gPerfLogFn(gPerfLogContext, uintptr_t(data.context.gateEntry), 4096, ""); + gPerfLogFn(gPerfLogContext, uintptr_t(data->context.gateEntry), 4096, ""); - lua_ExecutionCallbacks* ecb = getExecutionCallbacks(L); + lua_ExecutionCallbacks* ecb = &L->global->ecb; + ecb->context = data.release(); ecb->close = onCloseState; ecb->destroy = onDestroyFunction; ecb->enter = onEnter; @@ -490,7 +493,8 @@ void compile(lua_State* L, int idx) const TValue* func = luaA_toobject(L, idx); // If initialization has failed, do not compile any functions - if (!getNativeState(L)) + NativeState* data = getNativeState(L); + if (!data) return; #if defined(__aarch64__) @@ -499,8 +503,6 @@ void compile(lua_State* L, int idx) X64::AssemblyBuilderX64 build(/* logText= */ false); #endif - NativeState* data = getNativeState(L); - std::vector protos; gatherFunctions(protos, clvalue(func)->l.p); @@ -564,6 +566,7 @@ void compile(lua_State* L, int idx) // the memory is now managed by VM and will be freed via onDestroyFunction result.p->execdata = result.execdata; result.p->exectarget = uintptr_t(codeStart) + result.exectarget; + result.p->codeentry = &kCodeEntryInsn; } } diff --git a/CodeGen/src/CodeGenA64.cpp b/CodeGen/src/CodeGenA64.cpp index f6e9152c3..c5042fc32 100644 --- a/CodeGen/src/CodeGenA64.cpp +++ b/CodeGen/src/CodeGenA64.cpp @@ -5,7 +5,6 @@ #include "Luau/UnwindBuilder.h" #include "BitUtils.h" -#include "CustomExecUtils.h" #include "NativeState.h" #include "EmitCommonA64.h" @@ -95,13 +94,14 @@ static void emitReentry(AssemblyBuilderA64& build, ModuleHelpers& helpers) build.ldr(x2, mem(rState, offsetof(lua_State, ci))); // L->ci // We need to check if the new frame can be executed natively - // TOOD: .flags and .savedpc load below can be fused with ldp + // TODO: .flags and .savedpc load below can be fused with ldp build.ldr(w3, mem(x2, offsetof(CallInfo, flags))); - build.tbz(x3, countrz(LUA_CALLINFO_CUSTOM), helpers.exitContinueVm); + build.tbz(x3, countrz(LUA_CALLINFO_NATIVE), helpers.exitContinueVm); build.mov(rClosure, x0); - build.ldr(rConstants, mem(x1, offsetof(Proto, k))); // proto->k - build.ldr(rCode, mem(x1, offsetof(Proto, code))); // proto->code + + LUAU_ASSERT(offsetof(Proto, code) == offsetof(Proto, k) + 8); + build.ldp(rConstants, rCode, mem(x1, offsetof(Proto, k))); // proto->k, proto->code // Get instruction index from instruction pointer // To get instruction index from instruction pointer, we need to divide byte offset by 4 @@ -145,8 +145,9 @@ static EntryLocations buildEntryFunction(AssemblyBuilderA64& build, UnwindBuilde build.mov(rNativeContext, x3); build.ldr(rBase, mem(x0, offsetof(lua_State, base))); // L->base - build.ldr(rConstants, mem(x1, offsetof(Proto, k))); // proto->k - build.ldr(rCode, mem(x1, offsetof(Proto, code))); // proto->code + + LUAU_ASSERT(offsetof(Proto, code) == offsetof(Proto, k) + 8); + build.ldp(rConstants, rCode, mem(x1, offsetof(Proto, k))); // proto->k, proto->code build.ldr(x9, mem(x0, offsetof(lua_State, ci))); // L->ci build.ldr(x9, mem(x9, offsetof(CallInfo, func))); // L->ci->func @@ -194,7 +195,7 @@ bool initHeaderFunctions(NativeState& data) if (!data.codeAllocator.allocate(build.data.data(), int(build.data.size()), reinterpret_cast(build.code.data()), int(build.code.size() * sizeof(build.code[0])), data.gateData, data.gateDataSize, codeStart)) { - LUAU_ASSERT(!"failed to create entry function"); + LUAU_ASSERT(!"Failed to create entry function"); return false; } diff --git a/CodeGen/src/CodeGenUtils.cpp b/CodeGen/src/CodeGenUtils.cpp index 4ad67d83d..a7131e113 100644 --- a/CodeGen/src/CodeGenUtils.cpp +++ b/CodeGen/src/CodeGenUtils.cpp @@ -1,8 +1,6 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "CodeGenUtils.h" -#include "CustomExecUtils.h" - #include "lvm.h" #include "lbuiltins.h" @@ -268,7 +266,7 @@ Closure* callFallback(lua_State* L, StkId ra, StkId argtop, int nresults) ci->savedpc = p->code; if (LUAU_LIKELY(p->execdata != NULL)) - ci->flags = LUA_CALLINFO_CUSTOM; + ci->flags = LUA_CALLINFO_NATIVE; return ccl; } diff --git a/CodeGen/src/CodeGenX64.cpp b/CodeGen/src/CodeGenX64.cpp index 5f2cd6147..ec032c02b 100644 --- a/CodeGen/src/CodeGenX64.cpp +++ b/CodeGen/src/CodeGenX64.cpp @@ -4,7 +4,6 @@ #include "Luau/AssemblyBuilderX64.h" #include "Luau/UnwindBuilder.h" -#include "CustomExecUtils.h" #include "NativeState.h" #include "EmitCommonX64.h" @@ -160,7 +159,7 @@ bool initHeaderFunctions(NativeState& data) if (!data.codeAllocator.allocate( build.data.data(), int(build.data.size()), build.code.data(), int(build.code.size()), data.gateData, data.gateDataSize, codeStart)) { - LUAU_ASSERT(!"failed to create entry function"); + LUAU_ASSERT(!"Failed to create entry function"); return false; } diff --git a/CodeGen/src/CustomExecUtils.h b/CodeGen/src/CustomExecUtils.h deleted file mode 100644 index 9c9996611..000000000 --- a/CodeGen/src/CustomExecUtils.h +++ /dev/null @@ -1,106 +0,0 @@ -// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#pragma once - -#include "NativeState.h" - -#include "lobject.h" -#include "lstate.h" - -namespace Luau -{ -namespace CodeGen -{ - -// Here we define helper functions to wrap interaction with Luau custom execution API so that it works with or without LUA_CUSTOM_EXECUTION - -#if LUA_CUSTOM_EXECUTION - -inline lua_ExecutionCallbacks* getExecutionCallbacks(lua_State* L) -{ - return &L->global->ecb; -} - -inline NativeState* getNativeState(lua_State* L) -{ - lua_ExecutionCallbacks* ecb = getExecutionCallbacks(L); - return (NativeState*)ecb->context; -} - -inline void setNativeState(lua_State* L, NativeState* nativeState) -{ - lua_ExecutionCallbacks* ecb = getExecutionCallbacks(L); - ecb->context = nativeState; -} - -inline NativeState* createNativeState(lua_State* L) -{ - NativeState* state = new NativeState(); - setNativeState(L, state); - return state; -} - -inline void destroyNativeState(lua_State* L) -{ - NativeState* state = getNativeState(L); - setNativeState(L, nullptr); - delete state; -} - -#else - -inline lua_ExecutionCallbacks* getExecutionCallbacks(lua_State* L) -{ - return nullptr; -} - -inline NativeState* getNativeState(lua_State* L) -{ - return nullptr; -} - -inline void setNativeState(lua_State* L, NativeState* nativeState) {} - -inline NativeState* createNativeState(lua_State* L) -{ - return nullptr; -} - -inline void destroyNativeState(lua_State* L) {} - -#endif - -inline int getOpLength(LuauOpcode op) -{ - switch (op) - { - case LOP_GETGLOBAL: - case LOP_SETGLOBAL: - case LOP_GETIMPORT: - case LOP_GETTABLEKS: - case LOP_SETTABLEKS: - case LOP_NAMECALL: - case LOP_JUMPIFEQ: - case LOP_JUMPIFLE: - case LOP_JUMPIFLT: - case LOP_JUMPIFNOTEQ: - case LOP_JUMPIFNOTLE: - case LOP_JUMPIFNOTLT: - case LOP_NEWTABLE: - case LOP_SETLIST: - case LOP_FORGLOOP: - case LOP_LOADKX: - case LOP_FASTCALL2: - case LOP_FASTCALL2K: - case LOP_JUMPXEQKNIL: - case LOP_JUMPXEQKB: - case LOP_JUMPXEQKN: - case LOP_JUMPXEQKS: - return 2; - - default: - return 1; - } -} - -} // namespace CodeGen -} // namespace Luau diff --git a/CodeGen/src/EmitBuiltinsX64.cpp b/CodeGen/src/EmitBuiltinsX64.cpp index 474dabf67..96599c2e5 100644 --- a/CodeGen/src/EmitBuiltinsX64.cpp +++ b/CodeGen/src/EmitBuiltinsX64.cpp @@ -119,7 +119,6 @@ void emitBuiltin(IrRegAllocX64& regs, AssemblyBuilderX64& build, int bfid, int r return emitBuiltinTypeof(regs, build, ra, arg); default: LUAU_ASSERT(!"Missing x64 lowering"); - break; } } diff --git a/CodeGen/src/EmitCommonX64.cpp b/CodeGen/src/EmitCommonX64.cpp index ce95e7410..0095f288a 100644 --- a/CodeGen/src/EmitCommonX64.cpp +++ b/CodeGen/src/EmitCommonX64.cpp @@ -6,7 +6,6 @@ #include "Luau/IrData.h" #include "Luau/IrRegAllocX64.h" -#include "CustomExecUtils.h" #include "NativeState.h" #include "lgc.h" diff --git a/CodeGen/src/EmitCommonX64.h b/CodeGen/src/EmitCommonX64.h index ddc4048f4..3f723f456 100644 --- a/CodeGen/src/EmitCommonX64.h +++ b/CodeGen/src/EmitCommonX64.h @@ -184,33 +184,6 @@ inline void jumpIfTruthy(AssemblyBuilderX64& build, int ri, Label& target, Label build.jcc(ConditionX64::NotEqual, target); // true if boolean value is 'true' } -inline void jumpIfNodeKeyTagIsNot(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 node, lua_Type tag, Label& label) -{ - tmp.size = SizeX64::dword; - - build.mov(tmp, luauNodeKeyTag(node)); - build.and_(tmp, kTKeyTagMask); - build.cmp(tmp, tag); - build.jcc(ConditionX64::NotEqual, label); -} - -inline void jumpIfNodeValueTagIs(AssemblyBuilderX64& build, RegisterX64 node, lua_Type tag, Label& label) -{ - build.cmp(dword[node + offsetof(LuaNode, val) + offsetof(TValue, tt)], tag); - build.jcc(ConditionX64::Equal, label); -} - -inline void jumpIfNodeKeyNotInExpectedSlot(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 node, OperandX64 expectedKey, Label& label) -{ - jumpIfNodeKeyTagIsNot(build, tmp, node, LUA_TSTRING, label); - - build.mov(tmp, expectedKey); - build.cmp(tmp, luauNodeKeyValue(node)); - build.jcc(ConditionX64::NotEqual, label); - - jumpIfNodeValueTagIs(build, node, LUA_TNIL, label); -} - void jumpOnNumberCmp(AssemblyBuilderX64& build, RegisterX64 tmp, OperandX64 lhs, OperandX64 rhs, IrCondition cond, Label& label); void jumpOnAnyCmpFallback(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int rb, IrCondition cond, Label& label); diff --git a/CodeGen/src/EmitInstructionX64.cpp b/CodeGen/src/EmitInstructionX64.cpp index b2db7d187..f2012ca9d 100644 --- a/CodeGen/src/EmitInstructionX64.cpp +++ b/CodeGen/src/EmitInstructionX64.cpp @@ -4,8 +4,10 @@ #include "Luau/AssemblyBuilderX64.h" #include "Luau/IrRegAllocX64.h" -#include "CustomExecUtils.h" #include "EmitCommonX64.h" +#include "NativeState.h" + +#include "lstate.h" namespace Luau { @@ -87,8 +89,8 @@ void emitInstCall(AssemblyBuilderX64& build, ModuleHelpers& helpers, int ra, int build.test(rax, rax); build.jcc(ConditionX64::Zero, helpers.continueCallInVm); - // Mark call frame as custom - build.mov(dword[ci + offsetof(CallInfo, flags)], LUA_CALLINFO_CUSTOM); + // Mark call frame as native + build.mov(dword[ci + offsetof(CallInfo, flags)], LUA_CALLINFO_NATIVE); // Switch current constants build.mov(rConstants, qword[proto + offsetof(Proto, k)]); @@ -298,7 +300,7 @@ void emitInstReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers, int ra, i build.mov(execdata, qword[proto + offsetof(Proto, execdata)]); - build.test(byte[cip + offsetof(CallInfo, flags)], LUA_CALLINFO_CUSTOM); + build.test(byte[cip + offsetof(CallInfo, flags)], LUA_CALLINFO_NATIVE); build.jcc(ConditionX64::Zero, helpers.exitContinueVm); // Continue in interpreter if function has no native data // Change constants diff --git a/CodeGen/src/IrAnalysis.cpp b/CodeGen/src/IrAnalysis.cpp index efcacb046..85811f057 100644 --- a/CodeGen/src/IrAnalysis.cpp +++ b/CodeGen/src/IrAnalysis.cpp @@ -113,7 +113,7 @@ uint32_t getNextInstUse(IrFunction& function, uint32_t targetInstIdx, uint32_t s } // There must be a next use since there is the last use location - LUAU_ASSERT(!"failed to find next use"); + LUAU_ASSERT(!"Failed to find next use"); return targetInst.lastUse; } @@ -338,7 +338,7 @@ static RegisterSet computeBlockLiveInRegSet(IrFunction& function, const IrBlock& case IrCmd::CAPTURE: maybeUse(inst.a); - if (function.boolOp(inst.b)) + if (function.uintOp(inst.b) == 1) capturedRegs.set(vmRegOp(inst.a), true); break; case IrCmd::SETLIST: diff --git a/CodeGen/src/IrBuilder.cpp b/CodeGen/src/IrBuilder.cpp index 86986fe92..a12eca348 100644 --- a/CodeGen/src/IrBuilder.cpp +++ b/CodeGen/src/IrBuilder.cpp @@ -4,7 +4,6 @@ #include "Luau/IrAnalysis.h" #include "Luau/IrUtils.h" -#include "CustomExecUtils.h" #include "IrTranslation.h" #include "lapi.h" @@ -19,7 +18,7 @@ namespace CodeGen constexpr unsigned kNoAssociatedBlockIndex = ~0u; IrBuilder::IrBuilder() - : constantMap({IrConstKind::Bool, ~0ull}) + : constantMap({IrConstKind::Tag, ~0ull}) { } @@ -410,8 +409,7 @@ void IrBuilder::translateInst(LuauOpcode op, const Instruction* pc, int i) break; } default: - LUAU_ASSERT(!"unknown instruction"); - break; + LUAU_ASSERT(!"Unknown instruction"); } } @@ -449,7 +447,7 @@ void IrBuilder::clone(const IrBlock& source, bool removeCurrentTerminator) if (const uint32_t* newIndex = instRedir.find(op.index)) op.index = *newIndex; else - LUAU_ASSERT(!"values can only be used if they are defined in the same block"); + LUAU_ASSERT(!"Values can only be used if they are defined in the same block"); } }; @@ -501,14 +499,6 @@ IrOp IrBuilder::undef() return {IrOpKind::Undef, 0}; } -IrOp IrBuilder::constBool(bool value) -{ - IrConst constant; - constant.kind = IrConstKind::Bool; - constant.valueBool = value; - return constAny(constant, uint64_t(value)); -} - IrOp IrBuilder::constInt(int value) { IrConst constant; diff --git a/CodeGen/src/IrDump.cpp b/CodeGen/src/IrDump.cpp index 062321ba6..7ea9b7904 100644 --- a/CodeGen/src/IrDump.cpp +++ b/CodeGen/src/IrDump.cpp @@ -390,9 +390,6 @@ void toString(std::string& result, IrConst constant) { switch (constant.kind) { - case IrConstKind::Bool: - append(result, constant.valueBool ? "true" : "false"); - break; case IrConstKind::Int: append(result, "%di", constant.valueInt); break; diff --git a/CodeGen/src/IrLoweringA64.cpp b/CodeGen/src/IrLoweringA64.cpp index 711baba68..fb5d86878 100644 --- a/CodeGen/src/IrLoweringA64.cpp +++ b/CodeGen/src/IrLoweringA64.cpp @@ -96,6 +96,15 @@ static void emitAddOffset(AssemblyBuilderA64& build, RegisterA64 dst, RegisterA6 } } +static void emitAbort(AssemblyBuilderA64& build, Label& abort) +{ + Label skip; + build.b(skip); + build.setLabel(abort); + build.udf(); + build.setLabel(skip); +} + static void emitFallback(AssemblyBuilderA64& build, int offset, int pcpos) { // fallback(L, instruction, base, k) @@ -256,7 +265,11 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) } else if (inst.b.kind == IrOpKind::Constant) { - if (intOp(inst.b) * sizeof(TValue) <= AssemblyBuilderA64::kMaxImmediate) + if (intOp(inst.b) == 0) + { + // no offset required + } + else if (intOp(inst.b) * sizeof(TValue) <= AssemblyBuilderA64::kMaxImmediate) { build.add(inst.regA64, inst.regA64, uint16_t(intOp(inst.b) * sizeof(TValue))); } @@ -562,7 +575,14 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) break; } case IrCmd::JUMP_EQ_TAG: - if (inst.a.kind == IrOpKind::Inst && inst.b.kind == IrOpKind::Constant) + { + RegisterA64 zr = noreg; + + if (inst.a.kind == IrOpKind::Constant && tagOp(inst.a) == 0) + zr = regOp(inst.b); + else if (inst.b.kind == IrOpKind::Constant && tagOp(inst.b) == 0) + zr = regOp(inst.a); + else if (inst.a.kind == IrOpKind::Inst && inst.b.kind == IrOpKind::Constant) build.cmp(regOp(inst.a), tagOp(inst.b)); else if (inst.a.kind == IrOpKind::Inst && inst.b.kind == IrOpKind::Inst) build.cmp(regOp(inst.a), regOp(inst.b)); @@ -573,19 +593,33 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) if (isFallthroughBlock(blockOp(inst.d), next)) { - build.b(ConditionA64::Equal, labelOp(inst.c)); + if (zr != noreg) + build.cbz(zr, labelOp(inst.c)); + else + build.b(ConditionA64::Equal, labelOp(inst.c)); jumpOrFallthrough(blockOp(inst.d), next); } else { - build.b(ConditionA64::NotEqual, labelOp(inst.d)); + if (zr != noreg) + build.cbnz(zr, labelOp(inst.d)); + else + build.b(ConditionA64::NotEqual, labelOp(inst.d)); jumpOrFallthrough(blockOp(inst.c), next); } break; + } case IrCmd::JUMP_EQ_INT: - LUAU_ASSERT(unsigned(intOp(inst.b)) <= AssemblyBuilderA64::kMaxImmediate); - build.cmp(regOp(inst.a), uint16_t(intOp(inst.b))); - build.b(ConditionA64::Equal, labelOp(inst.c)); + if (intOp(inst.b) == 0) + { + build.cbz(regOp(inst.a), labelOp(inst.c)); + } + else + { + LUAU_ASSERT(unsigned(intOp(inst.b)) <= AssemblyBuilderA64::kMaxImmediate); + build.cmp(regOp(inst.a), uint16_t(intOp(inst.b))); + build.b(ConditionA64::Equal, labelOp(inst.c)); + } jumpOrFallthrough(blockOp(inst.d), next); break; case IrCmd::JUMP_LT_INT: @@ -871,7 +905,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.add(x2, rBase, uint16_t(vmRegOp(inst.c) * sizeof(TValue))); else if (inst.c.kind == IrOpKind::Constant) { - TValue n; + TValue n = {}; setnvalue(&n, uintOp(inst.c)); build.adr(x2, &n, sizeof(n)); } @@ -893,7 +927,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.add(x2, rBase, uint16_t(vmRegOp(inst.c) * sizeof(TValue))); else if (inst.c.kind == IrOpKind::Constant) { - TValue n; + TValue n = {}; setnvalue(&n, uintOp(inst.c)); build.adr(x2, &n, sizeof(n)); } @@ -908,25 +942,17 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) break; case IrCmd::GET_IMPORT: regs.spill(build, index); - // luaV_getimport(L, cl->env, k, aux, /* propagatenil= */ false) + // luaV_getimport(L, cl->env, k, ra, aux, /* propagatenil= */ false) build.mov(x0, rState); build.ldr(x1, mem(rClosure, offsetof(Closure, env))); build.mov(x2, rConstants); - build.mov(w3, uintOp(inst.b)); - build.mov(w4, 0); - build.ldr(x5, mem(rNativeContext, offsetof(NativeContext, luaV_getimport))); - build.blr(x5); + build.add(x3, rBase, uint16_t(vmRegOp(inst.a) * sizeof(TValue))); + build.mov(w4, uintOp(inst.b)); + build.mov(w5, 0); + build.ldr(x6, mem(rNativeContext, offsetof(NativeContext, luaV_getimport))); + build.blr(x6); emitUpdateBase(build); - - // setobj2s(L, ra, L->top - 1) - build.ldr(x0, mem(rState, offsetof(lua_State, top))); - build.sub(x0, x0, sizeof(TValue)); - build.ldr(q0, x0); - build.str(q0, mem(rBase, vmRegOp(inst.a) * sizeof(TValue))); - - // L->top-- - build.str(x0, mem(rState, offsetof(lua_State, top))); break; case IrCmd::CONCAT: regs.spill(build, index); @@ -1003,62 +1029,99 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) // note: no emitUpdateBase necessary because prepareFORN does not reallocate stack break; case IrCmd::CHECK_TAG: - build.cmp(regOp(inst.a), tagOp(inst.b)); - build.b(ConditionA64::NotEqual, labelOp(inst.c)); + { + Label abort; // used when guard aborts execution + Label& fail = inst.c.kind == IrOpKind::Undef ? abort : labelOp(inst.c); + if (tagOp(inst.b) == 0) + { + build.cbnz(regOp(inst.a), fail); + } + else + { + build.cmp(regOp(inst.a), tagOp(inst.b)); + build.b(ConditionA64::NotEqual, fail); + } + if (abort.id) + emitAbort(build, abort); break; + } case IrCmd::CHECK_READONLY: { + Label abort; // used when guard aborts execution RegisterA64 temp = regs.allocTemp(KindA64::w); build.ldrb(temp, mem(regOp(inst.a), offsetof(Table, readonly))); - build.cbnz(temp, labelOp(inst.b)); + build.cbnz(temp, inst.b.kind == IrOpKind::Undef ? abort : labelOp(inst.b)); + if (abort.id) + emitAbort(build, abort); break; } case IrCmd::CHECK_NO_METATABLE: { + Label abort; // used when guard aborts execution RegisterA64 temp = regs.allocTemp(KindA64::x); build.ldr(temp, mem(regOp(inst.a), offsetof(Table, metatable))); - build.cbnz(temp, labelOp(inst.b)); + build.cbnz(temp, inst.b.kind == IrOpKind::Undef ? abort : labelOp(inst.b)); + if (abort.id) + emitAbort(build, abort); break; } case IrCmd::CHECK_SAFE_ENV: { + Label abort; // used when guard aborts execution RegisterA64 temp = regs.allocTemp(KindA64::x); RegisterA64 tempw = castReg(KindA64::w, temp); build.ldr(temp, mem(rClosure, offsetof(Closure, env))); build.ldrb(tempw, mem(temp, offsetof(Table, safeenv))); - build.cbz(tempw, labelOp(inst.a)); + build.cbz(tempw, inst.a.kind == IrOpKind::Undef ? abort : labelOp(inst.a)); + if (abort.id) + emitAbort(build, abort); break; } case IrCmd::CHECK_ARRAY_SIZE: { + Label abort; // used when guard aborts execution + Label& fail = inst.c.kind == IrOpKind::Undef ? abort : labelOp(inst.c); + RegisterA64 temp = regs.allocTemp(KindA64::w); build.ldr(temp, mem(regOp(inst.a), offsetof(Table, sizearray))); if (inst.b.kind == IrOpKind::Inst) + { build.cmp(temp, regOp(inst.b)); + build.b(ConditionA64::UnsignedLessEqual, fail); + } else if (inst.b.kind == IrOpKind::Constant) { - if (size_t(intOp(inst.b)) <= AssemblyBuilderA64::kMaxImmediate) + if (intOp(inst.b) == 0) + { + build.cbz(temp, fail); + } + else if (size_t(intOp(inst.b)) <= AssemblyBuilderA64::kMaxImmediate) { build.cmp(temp, uint16_t(intOp(inst.b))); + build.b(ConditionA64::UnsignedLessEqual, fail); } else { RegisterA64 temp2 = regs.allocTemp(KindA64::w); build.mov(temp2, intOp(inst.b)); build.cmp(temp, temp2); + build.b(ConditionA64::UnsignedLessEqual, fail); } } else LUAU_ASSERT(!"Unsupported instruction form"); - build.b(ConditionA64::UnsignedLessEqual, labelOp(inst.c)); + if (abort.id) + emitAbort(build, abort); break; } case IrCmd::JUMP_SLOT_MATCH: case IrCmd::CHECK_SLOT_MATCH: { - Label& mismatch = inst.cmd == IrCmd::JUMP_SLOT_MATCH ? labelOp(inst.d) : labelOp(inst.c); + Label abort; // used when guard aborts execution + const IrOp& mismatchOp = inst.cmd == IrCmd::JUMP_SLOT_MATCH ? inst.d : inst.c; + Label& mismatch = mismatchOp.kind == IrOpKind::Undef ? abort : labelOp(mismatchOp); RegisterA64 temp1 = regs.allocTemp(KindA64::x); RegisterA64 temp1w = castReg(KindA64::w, temp1); @@ -1081,15 +1144,21 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) if (inst.cmd == IrCmd::JUMP_SLOT_MATCH) jumpOrFallthrough(blockOp(inst.c), next); + else if (abort.id) + emitAbort(build, abort); break; } case IrCmd::CHECK_NODE_NO_NEXT: { + Label abort; // used when guard aborts execution RegisterA64 temp = regs.allocTemp(KindA64::w); build.ldr(temp, mem(regOp(inst.a), offsetof(LuaNode, key) + kOffsetOfTKeyTagNext)); build.lsr(temp, temp, kTKeyTagBits); - build.cbnz(temp, labelOp(inst.b)); + build.cbnz(temp, inst.b.kind == IrOpKind::Undef ? abort : labelOp(inst.b)); + + if (abort.id) + emitAbort(build, abort); break; } case IrCmd::INTERRUPT: @@ -1762,11 +1831,6 @@ uint8_t IrLoweringA64::tagOp(IrOp op) const return function.tagOp(op); } -bool IrLoweringA64::boolOp(IrOp op) const -{ - return function.boolOp(op); -} - int IrLoweringA64::intOp(IrOp op) const { return function.intOp(op); diff --git a/CodeGen/src/IrLoweringA64.h b/CodeGen/src/IrLoweringA64.h index 9eda8976c..264789044 100644 --- a/CodeGen/src/IrLoweringA64.h +++ b/CodeGen/src/IrLoweringA64.h @@ -48,7 +48,6 @@ struct IrLoweringA64 // Operand data lookup helpers IrConst constOp(IrOp op) const; uint8_t tagOp(IrOp op) const; - bool boolOp(IrOp op) const; int intOp(IrOp op) const; unsigned uintOp(IrOp op) const; double doubleOp(IrOp op) const; diff --git a/CodeGen/src/IrLoweringX64.cpp b/CodeGen/src/IrLoweringX64.cpp index 035cc05c6..2efd73ea7 100644 --- a/CodeGen/src/IrLoweringX64.cpp +++ b/CodeGen/src/IrLoweringX64.cpp @@ -575,14 +575,6 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) jumpOnAnyCmpFallback(regs, build, vmRegOp(inst.a), vmRegOp(inst.b), conditionOp(inst.c), labelOp(inst.d)); jumpOrFallthrough(blockOp(inst.e), next); break; - case IrCmd::JUMP_SLOT_MATCH: - { - ScopedRegX64 tmp{regs, SizeX64::qword}; - - jumpIfNodeKeyNotInExpectedSlot(build, tmp.reg, regOp(inst.a), luauConstantValue(vmConstOp(inst.b)), labelOp(inst.d)); - jumpOrFallthrough(blockOp(inst.c), next); - break; - } case IrCmd::TABLE_LEN: { IrCallWrapperX64 callWrap(regs, build, index); @@ -782,7 +774,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) } else if (inst.c.kind == IrOpKind::Constant) { - TValue n; + TValue n = {}; setnvalue(&n, uintOp(inst.c)); callGetTable(regs, build, vmRegOp(inst.b), build.bytes(&n, sizeof(n)), vmRegOp(inst.a)); } @@ -798,7 +790,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) } else if (inst.c.kind == IrOpKind::Constant) { - TValue n; + TValue n = {}; setnvalue(&n, uintOp(inst.c)); callSetTable(regs, build, vmRegOp(inst.b), build.bytes(&n, sizeof(n)), vmRegOp(inst.a)); } @@ -817,24 +809,12 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) callWrap.addArgument(SizeX64::qword, rState); callWrap.addArgument(SizeX64::qword, qword[tmp1.release() + offsetof(Closure, env)]); callWrap.addArgument(SizeX64::qword, rConstants); + callWrap.addArgument(SizeX64::qword, luauRegAddress(vmRegOp(inst.a))); callWrap.addArgument(SizeX64::dword, uintOp(inst.b)); callWrap.addArgument(SizeX64::dword, 0); callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaV_getimport)]); emitUpdateBase(build); - - ScopedRegX64 tmp2{regs, SizeX64::qword}; - - // setobj2s(L, ra, L->top - 1) - build.mov(tmp2.reg, qword[rState + offsetof(lua_State, top)]); - build.sub(tmp2.reg, sizeof(TValue)); - - ScopedRegX64 tmp3{regs, SizeX64::xmmword}; - build.vmovups(tmp3.reg, xmmword[tmp2.reg]); - build.vmovups(luauReg(vmRegOp(inst.a)), tmp3.reg); - - // L->top-- - build.mov(qword[rState + offsetof(lua_State, top)], tmp2.reg); break; } case IrCmd::CONCAT: @@ -897,15 +877,15 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) break; case IrCmd::CHECK_TAG: build.cmp(memRegTagOp(inst.a), tagOp(inst.b)); - build.jcc(ConditionX64::NotEqual, labelOp(inst.c)); + jumpOrAbortOnUndef(ConditionX64::NotEqual, ConditionX64::Equal, inst.c); break; case IrCmd::CHECK_READONLY: build.cmp(byte[regOp(inst.a) + offsetof(Table, readonly)], 0); - build.jcc(ConditionX64::NotEqual, labelOp(inst.b)); + jumpOrAbortOnUndef(ConditionX64::NotEqual, ConditionX64::Equal, inst.b); break; case IrCmd::CHECK_NO_METATABLE: build.cmp(qword[regOp(inst.a) + offsetof(Table, metatable)], 0); - build.jcc(ConditionX64::NotEqual, labelOp(inst.b)); + jumpOrAbortOnUndef(ConditionX64::NotEqual, ConditionX64::Equal, inst.b); break; case IrCmd::CHECK_SAFE_ENV: { @@ -914,7 +894,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.mov(tmp.reg, sClosure); build.mov(tmp.reg, qword[tmp.reg + offsetof(Closure, env)]); build.cmp(byte[tmp.reg + offsetof(Table, safeenv)], 0); - build.jcc(ConditionX64::Equal, labelOp(inst.a)); + jumpOrAbortOnUndef(ConditionX64::Equal, ConditionX64::NotEqual, inst.a); break; } case IrCmd::CHECK_ARRAY_SIZE: @@ -925,13 +905,44 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) else LUAU_ASSERT(!"Unsupported instruction form"); - build.jcc(ConditionX64::BelowEqual, labelOp(inst.c)); + jumpOrAbortOnUndef(ConditionX64::BelowEqual, ConditionX64::NotBelowEqual, inst.c); break; + case IrCmd::JUMP_SLOT_MATCH: case IrCmd::CHECK_SLOT_MATCH: { + Label abort; // Used when guard aborts execution + const IrOp& mismatchOp = inst.cmd == IrCmd::JUMP_SLOT_MATCH ? inst.d : inst.c; + Label& mismatch = mismatchOp.kind == IrOpKind::Undef ? abort : labelOp(mismatchOp); + ScopedRegX64 tmp{regs, SizeX64::qword}; - jumpIfNodeKeyNotInExpectedSlot(build, tmp.reg, regOp(inst.a), luauConstantValue(vmConstOp(inst.b)), labelOp(inst.c)); + // Check if node key tag is a string + build.mov(dwordReg(tmp.reg), luauNodeKeyTag(regOp(inst.a))); + build.and_(dwordReg(tmp.reg), kTKeyTagMask); + build.cmp(dwordReg(tmp.reg), LUA_TSTRING); + build.jcc(ConditionX64::NotEqual, mismatch); + + // Check that node key value matches the expected one + build.mov(tmp.reg, luauConstantValue(vmConstOp(inst.b))); + build.cmp(tmp.reg, luauNodeKeyValue(regOp(inst.a))); + build.jcc(ConditionX64::NotEqual, mismatch); + + // Check that node value is not nil + build.cmp(dword[regOp(inst.a) + offsetof(LuaNode, val) + offsetof(TValue, tt)], LUA_TNIL); + build.jcc(ConditionX64::Equal, mismatch); + + if (inst.cmd == IrCmd::JUMP_SLOT_MATCH) + { + jumpOrFallthrough(blockOp(inst.c), next); + } + else if (mismatchOp.kind == IrOpKind::Undef) + { + Label skip; + build.jmp(skip); + build.setLabel(abort); + build.ud2(); + build.setLabel(skip); + } break; } case IrCmd::CHECK_NODE_NO_NEXT: @@ -940,7 +951,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.mov(tmp.reg, dword[regOp(inst.a) + offsetof(LuaNode, key) + kOffsetOfTKeyTagNext]); build.shr(tmp.reg, kTKeyTagBits); - build.jcc(ConditionX64::NotZero, labelOp(inst.b)); + jumpOrAbortOnUndef(ConditionX64::NotZero, ConditionX64::Zero, inst.b); break; } case IrCmd::INTERRUPT: @@ -1356,6 +1367,21 @@ void IrLoweringX64::jumpOrFallthrough(IrBlock& target, IrBlock& next) build.jmp(target.label); } +void IrLoweringX64::jumpOrAbortOnUndef(ConditionX64 cond, ConditionX64 condInverse, IrOp targetOrUndef) +{ + if (targetOrUndef.kind == IrOpKind::Undef) + { + Label skip; + build.jcc(condInverse, skip); + build.ud2(); + build.setLabel(skip); + } + else + { + build.jcc(cond, labelOp(targetOrUndef)); + } +} + OperandX64 IrLoweringX64::memRegDoubleOp(IrOp op) { switch (op.kind) @@ -1428,11 +1454,6 @@ uint8_t IrLoweringX64::tagOp(IrOp op) const return function.tagOp(op); } -bool IrLoweringX64::boolOp(IrOp op) const -{ - return function.boolOp(op); -} - int IrLoweringX64::intOp(IrOp op) const { return function.intOp(op); diff --git a/CodeGen/src/IrLoweringX64.h b/CodeGen/src/IrLoweringX64.h index 083232cf5..cab4a85f5 100644 --- a/CodeGen/src/IrLoweringX64.h +++ b/CodeGen/src/IrLoweringX64.h @@ -34,6 +34,7 @@ struct IrLoweringX64 bool isFallthroughBlock(IrBlock target, IrBlock next); void jumpOrFallthrough(IrBlock& target, IrBlock& next); + void jumpOrAbortOnUndef(ConditionX64 cond, ConditionX64 condInverse, IrOp targetOrUndef); void storeDoubleAsFloat(OperandX64 dst, IrOp src); @@ -45,7 +46,6 @@ struct IrLoweringX64 IrConst constOp(IrOp op) const; uint8_t tagOp(IrOp op) const; - bool boolOp(IrOp op) const; int intOp(IrOp op) const; unsigned uintOp(IrOp op) const; double doubleOp(IrOp op) const; diff --git a/CodeGen/src/IrRegAllocX64.cpp b/CodeGen/src/IrRegAllocX64.cpp index 270691b2f..273740778 100644 --- a/CodeGen/src/IrRegAllocX64.cpp +++ b/CodeGen/src/IrRegAllocX64.cpp @@ -194,7 +194,7 @@ void IrRegAllocX64::preserve(IrInst& inst) else if (spill.valueKind == IrValueKind::Tag || spill.valueKind == IrValueKind::Int) build.mov(dword[sSpillArea + i * 8], inst.regX64); else - LUAU_ASSERT(!"unsupported value kind"); + LUAU_ASSERT(!"Unsupported value kind"); usedSpillSlots.set(i); @@ -318,7 +318,7 @@ unsigned IrRegAllocX64::findSpillStackSlot(IrValueKind valueKind) return i; } - LUAU_ASSERT(!"nowhere to spill"); + LUAU_ASSERT(!"Nowhere to spill"); return ~0u; } diff --git a/CodeGen/src/IrTranslation.cpp b/CodeGen/src/IrTranslation.cpp index a8bad5289..8e135dfe0 100644 --- a/CodeGen/src/IrTranslation.cpp +++ b/CodeGen/src/IrTranslation.cpp @@ -5,10 +5,10 @@ #include "Luau/IrBuilder.h" #include "Luau/IrUtils.h" -#include "CustomExecUtils.h" #include "IrTranslateBuiltins.h" #include "lobject.h" +#include "lstate.h" #include "ltm.h" namespace Luau @@ -366,7 +366,7 @@ static void translateInstBinaryNumeric(IrBuilder& build, int ra, int rb, int rc, result = build.inst(IrCmd::INVOKE_LIBM, build.constUint(LBF_MATH_POW), vb, vc); break; default: - LUAU_ASSERT(!"unsupported binary op"); + LUAU_ASSERT(!"Unsupported binary op"); } } @@ -1068,13 +1068,13 @@ void translateInstCapture(IrBuilder& build, const Instruction* pc, int pcpos) switch (type) { case LCT_VAL: - build.inst(IrCmd::CAPTURE, build.vmReg(index), build.constBool(false)); + build.inst(IrCmd::CAPTURE, build.vmReg(index), build.constUint(0)); break; case LCT_REF: - build.inst(IrCmd::CAPTURE, build.vmReg(index), build.constBool(true)); + build.inst(IrCmd::CAPTURE, build.vmReg(index), build.constUint(1)); break; case LCT_UPVAL: - build.inst(IrCmd::CAPTURE, build.vmUpvalue(index), build.constBool(false)); + build.inst(IrCmd::CAPTURE, build.vmUpvalue(index), build.constUint(0)); break; default: LUAU_ASSERT(!"Unknown upvalue capture type"); diff --git a/CodeGen/src/IrTranslation.h b/CodeGen/src/IrTranslation.h index 87a530b50..38dcdd40f 100644 --- a/CodeGen/src/IrTranslation.h +++ b/CodeGen/src/IrTranslation.h @@ -1,6 +1,8 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include "Luau/Bytecode.h" + #include #include "ltm.h" @@ -64,5 +66,38 @@ void translateInstNamecall(IrBuilder& build, const Instruction* pc, int pcpos); void translateInstAndX(IrBuilder& build, const Instruction* pc, int pcpos, IrOp c); void translateInstOrX(IrBuilder& build, const Instruction* pc, int pcpos, IrOp c); +inline int getOpLength(LuauOpcode op) +{ + switch (op) + { + case LOP_GETGLOBAL: + case LOP_SETGLOBAL: + case LOP_GETIMPORT: + case LOP_GETTABLEKS: + case LOP_SETTABLEKS: + case LOP_NAMECALL: + case LOP_JUMPIFEQ: + case LOP_JUMPIFLE: + case LOP_JUMPIFLT: + case LOP_JUMPIFNOTEQ: + case LOP_JUMPIFNOTLE: + case LOP_JUMPIFNOTLT: + case LOP_NEWTABLE: + case LOP_SETLIST: + case LOP_FORGLOOP: + case LOP_LOADKX: + case LOP_FASTCALL2: + case LOP_FASTCALL2K: + case LOP_JUMPXEQKNIL: + case LOP_JUMPXEQKB: + case LOP_JUMPXEQKN: + case LOP_JUMPXEQKS: + return 2; + + default: + return 1; + } +} + } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/IrUtils.cpp b/CodeGen/src/IrUtils.cpp index 03a6c9c43..70ad1438c 100644 --- a/CodeGen/src/IrUtils.cpp +++ b/CodeGen/src/IrUtils.cpp @@ -356,7 +356,11 @@ void applySubstitutions(IrFunction& function, IrOp& op) src.useCount--; if (src.useCount == 0) + { + src.cmd = IrCmd::NOP; removeUse(function, src.a); + src.a = {}; + } } } } @@ -396,7 +400,7 @@ bool compare(double a, double b, IrCondition cond) case IrCondition::NotGreaterEqual: return !(a >= b); default: - LUAU_ASSERT(!"unsupported conidtion"); + LUAU_ASSERT(!"Unsupported condition"); } return false; diff --git a/CodeGen/src/NativeState.cpp b/CodeGen/src/NativeState.cpp index bda468897..17977c3c2 100644 --- a/CodeGen/src/NativeState.cpp +++ b/CodeGen/src/NativeState.cpp @@ -4,7 +4,6 @@ #include "Luau/UnwindBuilder.h" #include "CodeGenUtils.h" -#include "CustomExecUtils.h" #include "lbuiltins.h" #include "lgc.h" diff --git a/CodeGen/src/NativeState.h b/CodeGen/src/NativeState.h index 40017e359..0140448fd 100644 --- a/CodeGen/src/NativeState.h +++ b/CodeGen/src/NativeState.h @@ -38,7 +38,7 @@ struct NativeContext void (*luaV_prepareFORN)(lua_State* L, StkId plimit, StkId pstep, StkId pinit) = nullptr; void (*luaV_gettable)(lua_State* L, const TValue* t, TValue* key, StkId val) = nullptr; void (*luaV_settable)(lua_State* L, const TValue* t, TValue* key, StkId val) = nullptr; - void (*luaV_getimport)(lua_State* L, Table* env, TValue* k, uint32_t id, bool propagatenil) = nullptr; + void (*luaV_getimport)(lua_State* L, Table* env, TValue* k, StkId res, uint32_t id, bool propagatenil) = nullptr; void (*luaV_concat)(lua_State* L, int total, int last) = nullptr; int (*luaH_getn)(Table* t) = nullptr; diff --git a/CodeGen/src/OptimizeConstProp.cpp b/CodeGen/src/OptimizeConstProp.cpp index 8bb3cd7b7..0f5eb4ebb 100644 --- a/CodeGen/src/OptimizeConstProp.cpp +++ b/CodeGen/src/OptimizeConstProp.cpp @@ -12,6 +12,7 @@ #include LUAU_FASTINTVARIABLE(LuauCodeGenMinLinearBlockPath, 3) +LUAU_FASTFLAGVARIABLE(DebugLuauAbortingChecks, false) namespace Luau { @@ -57,6 +58,12 @@ struct ConstPropState return 0xff; } + void updateTag(IrOp op, uint8_t tag) + { + if (RegisterInfo* info = tryGetRegisterInfo(op)) + info->tag = tag; + } + void saveTag(IrOp op, uint8_t tag) { if (RegisterInfo* info = tryGetRegisterInfo(op)) @@ -202,7 +209,7 @@ struct ConstPropState if (RegisterLink* link = instLink.find(instOp.index)) { // Check that the target register hasn't changed the value - if (link->version > regs[link->reg].version) + if (link->version < regs[link->reg].version) return nullptr; return link; @@ -619,13 +626,20 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& if (uint8_t tag = state.tryGetTag(inst.a); tag != 0xff) { if (tag == b) - kill(function, inst); + { + if (FFlag::DebugLuauAbortingChecks) + replace(function, inst.c, build.undef()); + else + kill(function, inst); + } else + { replace(function, block, index, {IrCmd::JUMP, inst.c}); // Shows a conflict in assumptions on this path + } } else { - state.saveTag(inst.a, b); // We can assume the tag value going forward + state.updateTag(inst.a, b); // We can assume the tag value going forward } break; } @@ -633,25 +647,46 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& if (RegisterInfo* info = state.tryGetRegisterInfo(inst.a)) { if (info->knownNotReadonly) - kill(function, inst); + { + if (FFlag::DebugLuauAbortingChecks) + replace(function, inst.b, build.undef()); + else + kill(function, inst); + } else + { info->knownNotReadonly = true; + } } break; case IrCmd::CHECK_NO_METATABLE: if (RegisterInfo* info = state.tryGetRegisterInfo(inst.a)) { if (info->knownNoMetatable) - kill(function, inst); + { + if (FFlag::DebugLuauAbortingChecks) + replace(function, inst.b, build.undef()); + else + kill(function, inst); + } else + { info->knownNoMetatable = true; + } } break; case IrCmd::CHECK_SAFE_ENV: if (state.inSafeEnv) - kill(function, inst); + { + if (FFlag::DebugLuauAbortingChecks) + replace(function, inst.a, build.undef()); + else + kill(function, inst); + } else + { state.inSafeEnv = true; + } break; case IrCmd::CHECK_GC: // It is enough to perform a GC check once in a block diff --git a/Common/include/Luau/Bytecode.h b/Common/include/Luau/Bytecode.h index 82bf6e5a3..54086d53f 100644 --- a/Common/include/Luau/Bytecode.h +++ b/Common/include/Luau/Bytecode.h @@ -300,8 +300,9 @@ enum LuauOpcode // A: target register (see FORGLOOP for register layout) LOP_FORGPREP_NEXT, - // removed in v3 - LOP_DEP_FORGLOOP_NEXT, + // NATIVECALL: start executing new function in native code + // this is a pseudo-instruction that is never emitted by bytecode compiler, but can be constructed at runtime to accelerate native code dispatch + LOP_NATIVECALL, // GETVARARGS: copy variables into the target register from vararg storage for current function // A: target register diff --git a/Compiler/src/Builtins.cpp b/Compiler/src/Builtins.cpp index 073bb1c79..80fe0b6dd 100644 --- a/Compiler/src/Builtins.cpp +++ b/Compiler/src/Builtins.cpp @@ -252,8 +252,7 @@ BuiltinInfo getBuiltinInfo(int bfid) return {-1, -1}; case LBF_ASSERT: - return {-1, -1}; - ; // assert() returns all values when first value is truthy + return {-1, -1}; // assert() returns all values when first value is truthy case LBF_MATH_ABS: case LBF_MATH_ACOS: diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index 9eda214c3..64667221f 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -25,7 +25,7 @@ LUAU_FASTINTVARIABLE(LuauCompileInlineThreshold, 25) LUAU_FASTINTVARIABLE(LuauCompileInlineThresholdMaxBoost, 300) LUAU_FASTINTVARIABLE(LuauCompileInlineDepth, 5) -LUAU_FASTFLAGVARIABLE(LuauCompileLimitInsns, false) +LUAU_FASTFLAGVARIABLE(LuauCompileInlineDefer, false) namespace Luau { @@ -250,7 +250,7 @@ struct Compiler popLocals(0); - if (FFlag::LuauCompileLimitInsns && bytecode.getInstructionCount() > kMaxInstructionCount) + if (bytecode.getInstructionCount() > kMaxInstructionCount) CompileError::raise(func->location, "Exceeded function instruction limit; split the function into parts to compile"); bytecode.endFunction(uint8_t(stackSize), uint8_t(upvals.size())); @@ -559,10 +559,19 @@ struct Compiler size_t oldLocals = localStack.size(); - // note that we push the frame early; this is needed to block recursive inline attempts - inlineFrames.push_back({func, oldLocals, target, targetCount}); + std::vector args; + if (FFlag::LuauCompileInlineDefer) + { + args.reserve(func->args.size); + } + else + { + // note that we push the frame early; this is needed to block recursive inline attempts + inlineFrames.push_back({func, oldLocals, target, targetCount}); + } // evaluate all arguments; note that we don't emit code for constant arguments (relying on constant folding) + // note that compiler state (variable registers/values) does not change here - we defer that to a separate loop below to handle nested calls for (size_t i = 0; i < func->args.size; ++i) { AstLocal* var = func->args.data[i]; @@ -581,8 +590,16 @@ struct Compiler else LUAU_ASSERT(!"Unexpected expression type"); - for (size_t j = i; j < func->args.size; ++j) - pushLocal(func->args.data[j], uint8_t(reg + (j - i))); + if (FFlag::LuauCompileInlineDefer) + { + for (size_t j = i; j < func->args.size; ++j) + args.push_back({func->args.data[j], uint8_t(reg + (j - i))}); + } + else + { + for (size_t j = i; j < func->args.size; ++j) + pushLocal(func->args.data[j], uint8_t(reg + (j - i))); + } // all remaining function arguments have been allocated and assigned to break; @@ -597,17 +614,26 @@ struct Compiler else bytecode.emitABC(LOP_LOADNIL, reg, 0, 0); - pushLocal(var, reg); + if (FFlag::LuauCompileInlineDefer) + args.push_back({var, reg}); + else + pushLocal(var, reg); } else if (arg == nullptr) { // since the argument is not mutated, we can simply fold the value into the expressions that need it - locstants[var] = {Constant::Type_Nil}; + if (FFlag::LuauCompileInlineDefer) + args.push_back({var, kInvalidReg, {Constant::Type_Nil}}); + else + locstants[var] = {Constant::Type_Nil}; } else if (const Constant* cv = constants.find(arg); cv && cv->type != Constant::Type_Unknown) { // since the argument is not mutated, we can simply fold the value into the expressions that need it - locstants[var] = *cv; + if (FFlag::LuauCompileInlineDefer) + args.push_back({var, kInvalidReg, *cv}); + else + locstants[var] = *cv; } else { @@ -617,13 +643,20 @@ struct Compiler // if the argument is a local that isn't mutated, we will simply reuse the existing register if (int reg = le ? getExprLocalReg(le) : -1; reg >= 0 && (!lv || !lv->written)) { - pushLocal(var, uint8_t(reg)); + if (FFlag::LuauCompileInlineDefer) + args.push_back({var, uint8_t(reg)}); + else + pushLocal(var, uint8_t(reg)); } else { uint8_t temp = allocReg(arg, 1); compileExprTemp(arg, temp); - pushLocal(var, temp); + + if (FFlag::LuauCompileInlineDefer) + args.push_back({var, temp}); + else + pushLocal(var, temp); } } } @@ -635,6 +668,20 @@ struct Compiler compileExprAuto(expr->args.data[i], rsi); } + if (FFlag::LuauCompileInlineDefer) + { + // apply all evaluated arguments to the compiler state + // note: locals use current startpc for debug info, although some of them have been computed earlier; this is similar to compileStatLocal + for (InlineArg& arg : args) + if (arg.value.type == Constant::Type_Unknown) + pushLocal(arg.local, arg.reg); + else + locstants[arg.local] = arg.value; + + // the inline frame will be used to compile return statements as well as to reject recursive inlining attempts + inlineFrames.push_back({func, oldLocals, target, targetCount}); + } + // fold constant values updated above into expressions in the function body foldConstants(constants, variables, locstants, builtinsFold, func->body); @@ -3747,6 +3794,14 @@ struct Compiler AstExpr* untilCondition; }; + struct InlineArg + { + AstLocal* local; + + uint8_t reg; + Constant value; + }; + struct InlineFrame { AstExprFunction* func; diff --git a/Sources.cmake b/Sources.cmake index 892b889bb..853b1b866 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -113,7 +113,6 @@ target_sources(Luau.CodeGen PRIVATE CodeGen/src/BitUtils.h CodeGen/src/ByteUtils.h - CodeGen/src/CustomExecUtils.h CodeGen/src/CodeGenUtils.h CodeGen/src/CodeGenA64.h CodeGen/src/CodeGenX64.h @@ -404,6 +403,7 @@ if(TARGET Luau.UnitTest) tests/TypeInfer.primitives.test.cpp tests/TypeInfer.provisional.test.cpp tests/TypeInfer.refinements.test.cpp + tests/TypeInfer.rwprops.test.cpp tests/TypeInfer.singletons.test.cpp tests/TypeInfer.tables.test.cpp tests/TypeInfer.test.cpp diff --git a/VM/src/lfunc.cpp b/VM/src/lfunc.cpp index 569c1b4e5..6330b4c36 100644 --- a/VM/src/lfunc.cpp +++ b/VM/src/lfunc.cpp @@ -31,7 +31,7 @@ Proto* luaF_newproto(lua_State* L) f->source = NULL; f->debugname = NULL; f->debuginsn = NULL; - + f->codeentry = NULL; f->execdata = NULL; f->exectarget = 0; diff --git a/VM/src/lobject.h b/VM/src/lobject.h index 21b8de018..a42616332 100644 --- a/VM/src/lobject.h +++ b/VM/src/lobject.h @@ -275,6 +275,7 @@ typedef struct Proto TString* debugname; uint8_t* debuginsn; // a copy of code[] array with just opcodes + const Instruction* codeentry; void* execdata; uintptr_t exectarget; diff --git a/VM/src/lstate.cpp b/VM/src/lstate.cpp index b320a252b..e0727c163 100644 --- a/VM/src/lstate.cpp +++ b/VM/src/lstate.cpp @@ -219,9 +219,7 @@ lua_State* lua_newstate(lua_Alloc f, void* ud) g->cb = lua_Callbacks(); -#if LUA_CUSTOM_EXECUTION g->ecb = lua_ExecutionCallbacks(); -#endif g->gcstats = GCStats(); diff --git a/VM/src/lstate.h b/VM/src/lstate.h index ae1e18664..ca8bc1b31 100644 --- a/VM/src/lstate.h +++ b/VM/src/lstate.h @@ -69,7 +69,7 @@ typedef struct CallInfo #define LUA_CALLINFO_RETURN (1 << 0) // should the interpreter return after returning from this callinfo? first frame must have this set #define LUA_CALLINFO_HANDLE (1 << 1) // should the error thrown during execution get handled by continuation from this callinfo? func must be C -#define LUA_CALLINFO_CUSTOM (1 << 2) // should this function be executed using custom execution callback +#define LUA_CALLINFO_NATIVE (1 << 2) // should this function be executed using execution callback for native code #define curr_func(L) (clvalue(L->ci->func)) #define ci_func(ci) (clvalue((ci)->func)) @@ -211,9 +211,7 @@ typedef struct global_State lua_Callbacks cb; -#if LUA_CUSTOM_EXECUTION lua_ExecutionCallbacks ecb; -#endif void (*udatagc[LUA_UTAG_LIMIT])(lua_State*, void*); // for each userdata tag, a gc callback to be called immediately before freeing memory diff --git a/VM/src/lvm.h b/VM/src/lvm.h index c4b1c18b5..cfb6456b5 100644 --- a/VM/src/lvm.h +++ b/VM/src/lvm.h @@ -23,7 +23,8 @@ LUAI_FUNC int luaV_tostring(lua_State* L, StkId obj); LUAI_FUNC void luaV_gettable(lua_State* L, const TValue* t, TValue* key, StkId val); LUAI_FUNC void luaV_settable(lua_State* L, const TValue* t, TValue* key, StkId val); LUAI_FUNC void luaV_concat(lua_State* L, int total, int last); -LUAI_FUNC void luaV_getimport(lua_State* L, Table* env, TValue* k, uint32_t id, bool propagatenil); +LUAI_FUNC void luaV_getimport(lua_State* L, Table* env, TValue* k, StkId res, uint32_t id, bool propagatenil); +LUAI_FUNC void luaV_getimport_dep(lua_State* L, Table* env, TValue* k, uint32_t id, bool propagatenil); LUAI_FUNC void luaV_prepareFORN(lua_State* L, StkId plimit, StkId pstep, StkId pinit); LUAI_FUNC void luaV_callTM(lua_State* L, int nparams, int res); LUAI_FUNC void luaV_tryfuncTM(lua_State* L, StkId func); diff --git a/VM/src/lvmexecute.cpp b/VM/src/lvmexecute.cpp index 454a4e178..280c47927 100644 --- a/VM/src/lvmexecute.cpp +++ b/VM/src/lvmexecute.cpp @@ -17,6 +17,7 @@ #include LUAU_FASTFLAG(LuauUniformTopHandling) +LUAU_FASTFLAG(LuauGetImportDirect) // Disable c99-designator to avoid the warning in CGOTO dispatch table #ifdef __clang__ @@ -101,7 +102,7 @@ LUAU_FASTFLAG(LuauUniformTopHandling) VM_DISPATCH_OP(LOP_CONCAT), VM_DISPATCH_OP(LOP_NOT), VM_DISPATCH_OP(LOP_MINUS), VM_DISPATCH_OP(LOP_LENGTH), VM_DISPATCH_OP(LOP_NEWTABLE), \ VM_DISPATCH_OP(LOP_DUPTABLE), VM_DISPATCH_OP(LOP_SETLIST), VM_DISPATCH_OP(LOP_FORNPREP), VM_DISPATCH_OP(LOP_FORNLOOP), \ VM_DISPATCH_OP(LOP_FORGLOOP), VM_DISPATCH_OP(LOP_FORGPREP_INEXT), VM_DISPATCH_OP(LOP_DEP_FORGLOOP_INEXT), VM_DISPATCH_OP(LOP_FORGPREP_NEXT), \ - VM_DISPATCH_OP(LOP_DEP_FORGLOOP_NEXT), VM_DISPATCH_OP(LOP_GETVARARGS), VM_DISPATCH_OP(LOP_DUPCLOSURE), VM_DISPATCH_OP(LOP_PREPVARARGS), \ + VM_DISPATCH_OP(LOP_NATIVECALL), VM_DISPATCH_OP(LOP_GETVARARGS), VM_DISPATCH_OP(LOP_DUPCLOSURE), VM_DISPATCH_OP(LOP_PREPVARARGS), \ VM_DISPATCH_OP(LOP_LOADKX), VM_DISPATCH_OP(LOP_JUMPX), VM_DISPATCH_OP(LOP_FASTCALL), VM_DISPATCH_OP(LOP_COVERAGE), \ VM_DISPATCH_OP(LOP_CAPTURE), VM_DISPATCH_OP(LOP_DEP_JUMPIFEQK), VM_DISPATCH_OP(LOP_DEP_JUMPIFNOTEQK), VM_DISPATCH_OP(LOP_FASTCALL1), \ VM_DISPATCH_OP(LOP_FASTCALL2), VM_DISPATCH_OP(LOP_FASTCALL2K), VM_DISPATCH_OP(LOP_FORGPREP), VM_DISPATCH_OP(LOP_JUMPXEQKNIL), \ @@ -210,7 +211,7 @@ static void luau_execute(lua_State* L) LUAU_ASSERT(!isblack(obj2gco(L))); // we don't use luaC_threadbarrier because active threads never turn black #if LUA_CUSTOM_EXECUTION - if ((L->ci->flags & LUA_CALLINFO_CUSTOM) && !SingleStep) + if ((L->ci->flags & LUA_CALLINFO_NATIVE) && !SingleStep) { Proto* p = clvalue(L->ci->func)->l.p; LUAU_ASSERT(p->execdata); @@ -432,12 +433,20 @@ static void luau_execute(lua_State* L) { uint32_t aux = *pc++; - VM_PROTECT(luaV_getimport(L, cl->env, k, aux, /* propagatenil= */ false)); - ra = VM_REG(LUAU_INSN_A(insn)); // previous call may change the stack + if (FFlag::LuauGetImportDirect) + { + VM_PROTECT(luaV_getimport(L, cl->env, k, ra, aux, /* propagatenil= */ false)); + VM_NEXT(); + } + else + { + VM_PROTECT(luaV_getimport_dep(L, cl->env, k, aux, /* propagatenil= */ false)); + ra = VM_REG(LUAU_INSN_A(insn)); // previous call may change the stack - setobj2s(L, ra, L->top - 1); - L->top--; - VM_NEXT(); + setobj2s(L, ra, L->top - 1); + L->top--; + VM_NEXT(); + } } } @@ -954,21 +963,11 @@ static void luau_execute(lua_State* L) setnilvalue(argi++); // complete missing arguments L->top = p->is_vararg ? argi : ci->top; -#if LUA_CUSTOM_EXECUTION - if (LUAU_UNLIKELY(p->execdata && !SingleStep)) - { - ci->flags = LUA_CALLINFO_CUSTOM; - ci->savedpc = p->code; - - if (L->global->ecb.enter(L, p) == 1) - goto reentry; - else - goto exit; - } -#endif - // reentry - pc = p->code; + // codeentry may point to NATIVECALL instruction when proto is compiled to native code + // this will result in execution continuing in native code, and is equivalent to if (p->execdata) but has no additional overhead + // note that p->codeentry may point *outside* of p->code..p->code+p->sizecode, but that pointer never gets saved to savedpc. + pc = SingleStep ? p->code : p->codeentry; cl = ccl; base = L->base; k = p->k; @@ -1055,7 +1054,7 @@ static void luau_execute(lua_State* L) Proto* nextproto = nextcl->l.p; #if LUA_CUSTOM_EXECUTION - if (LUAU_UNLIKELY((cip->flags & LUA_CALLINFO_CUSTOM) && !SingleStep)) + if (LUAU_UNLIKELY((cip->flags & LUA_CALLINFO_NATIVE) && !SingleStep)) { if (L->global->ecb.enter(L, nextproto) == 1) goto reentry; @@ -2380,10 +2379,24 @@ static void luau_execute(lua_State* L) VM_NEXT(); } - VM_CASE(LOP_DEP_FORGLOOP_NEXT) + VM_CASE(LOP_NATIVECALL) { - LUAU_ASSERT(!"Unsupported deprecated opcode"); + Proto* p = cl->l.p; + LUAU_ASSERT(p->execdata); + + CallInfo* ci = L->ci; + ci->flags = LUA_CALLINFO_NATIVE; + ci->savedpc = p->code; + +#if LUA_CUSTOM_EXECUTION + if (L->global->ecb.enter(L, p) == 1) + goto reentry; + else + goto exit; +#else + LUAU_ASSERT(!"Opcode is only valid when LUA_CUSTOM_EXECUTION is defined"); LUAU_UNREACHABLE(); +#endif } VM_CASE(LOP_GETVARARGS) @@ -2896,7 +2909,7 @@ int luau_precall(lua_State* L, StkId func, int nresults) #if LUA_CUSTOM_EXECUTION if (p->execdata) - ci->flags = LUA_CALLINFO_CUSTOM; + ci->flags = LUA_CALLINFO_NATIVE; #endif return PCRLUA; diff --git a/VM/src/lvmload.cpp b/VM/src/lvmload.cpp index 305e540c6..f26cc05d7 100644 --- a/VM/src/lvmload.cpp +++ b/VM/src/lvmload.cpp @@ -13,6 +13,8 @@ #include +LUAU_FASTFLAGVARIABLE(LuauGetImportDirect, false) + // TODO: RAII deallocation doesn't work for longjmp builds if a memory error happens template struct TempBuffer @@ -40,8 +42,45 @@ struct TempBuffer } }; -void luaV_getimport(lua_State* L, Table* env, TValue* k, uint32_t id, bool propagatenil) +void luaV_getimport(lua_State* L, Table* env, TValue* k, StkId res, uint32_t id, bool propagatenil) { + int count = id >> 30; + LUAU_ASSERT(count > 0); + + int id0 = int(id >> 20) & 1023; + int id1 = int(id >> 10) & 1023; + int id2 = int(id) & 1023; + + // after the first call to luaV_gettable, res may be invalid, and env may (sometimes) be garbage collected + // we take care to not use env again and to restore res before every consecutive use + ptrdiff_t resp = savestack(L, res); + + // global lookup for id0 + TValue g; + sethvalue(L, &g, env); + luaV_gettable(L, &g, &k[id0], res); + + // table lookup for id1 + if (count < 2) + return; + + res = restorestack(L, resp); + if (!propagatenil || !ttisnil(res)) + luaV_gettable(L, res, &k[id1], res); + + // table lookup for id2 + if (count < 3) + return; + + res = restorestack(L, resp); + if (!propagatenil || !ttisnil(res)) + luaV_gettable(L, res, &k[id2], res); +} + +void luaV_getimport_dep(lua_State* L, Table* env, TValue* k, uint32_t id, bool propagatenil) +{ + LUAU_ASSERT(!FFlag::LuauGetImportDirect); + int count = id >> 30; int id0 = count > 0 ? int(id >> 20) & 1023 : -1; int id1 = count > 1 ? int(id >> 10) & 1023 : -1; @@ -114,7 +153,17 @@ static void resolveImportSafe(lua_State* L, Table* env, TValue* k, uint32_t id) // note: we call getimport with nil propagation which means that accesses to table chains like A.B.C will resolve in nil // this is technically not necessary but it reduces the number of exceptions when loading scripts that rely on getfenv/setfenv for global // injection - luaV_getimport(L, L->gt, self->k, self->id, /* propagatenil= */ true); + if (FFlag::LuauGetImportDirect) + { + // allocate a stack slot so that we can do table lookups + luaD_checkstack(L, 1); + setnilvalue(L->top); + L->top++; + + luaV_getimport(L, L->gt, self->k, L->top - 1, self->id, /* propagatenil= */ true); + } + else + luaV_getimport_dep(L, L->gt, self->k, self->id, /* propagatenil= */ true); } }; @@ -204,6 +253,8 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size for (int j = 0; j < p->sizecode; ++j) p->code[j] = read(data, size, offset); + p->codeentry = p->code; + p->sizek = readVarInt(data, size, offset); p->k = luaM_newarray(L, p->sizek, TValue, p->memcat); diff --git a/tests/AssemblyBuilderA64.test.cpp b/tests/AssemblyBuilderA64.test.cpp index cdadfd76b..c917a7bb3 100644 --- a/tests/AssemblyBuilderA64.test.cpp +++ b/tests/AssemblyBuilderA64.test.cpp @@ -455,6 +455,11 @@ TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "Conditionals") SINGLE_COMPARE(cset(x1, ConditionA64::Less), 0x9A9FA7E1); } +TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "Undefined") +{ + SINGLE_COMPARE(udf(), 0x00000000); +} + TEST_CASE("LogTest") { AssemblyBuilderA64 build(/* logText= */ true); diff --git a/tests/AssemblyBuilderX64.test.cpp b/tests/AssemblyBuilderX64.test.cpp index 28e9dd1f8..63e92f8fa 100644 --- a/tests/AssemblyBuilderX64.test.cpp +++ b/tests/AssemblyBuilderX64.test.cpp @@ -537,6 +537,7 @@ TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "AVXTernaryInstructionForms") TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "MiscInstructions") { SINGLE_COMPARE(int3(), 0xcc); + SINGLE_COMPARE(ud2(), 0x0f, 0x0b); SINGLE_COMPARE(bsr(eax, edx), 0x0f, 0xbd, 0xc2); SINGLE_COMPARE(bsf(eax, edx), 0x0f, 0xbc, 0xc2); } diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index cabf1ccea..4885b3174 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -5796,7 +5796,9 @@ RETURN R3 1 TEST_CASE("InlineRecurseArguments") { - // we can't inline a function if it's used to compute its own arguments + ScopedFastFlag sff("LuauCompileInlineDefer", true); + + // the example looks silly but we preserve it verbatim as it was found by fuzzer for a previous version of the compiler CHECK_EQ("\n" + compileFunction(R"( local function foo(a, b) end @@ -5805,15 +5807,82 @@ foo(foo(foo,foo(foo,foo))[foo]) 1, 2), R"( DUPCLOSURE R0 K0 ['foo'] -MOVE R2 R0 -MOVE R3 R0 -MOVE R4 R0 -MOVE R5 R0 -MOVE R6 R0 -CALL R4 2 -1 -CALL R2 -1 1 +LOADNIL R3 +LOADNIL R2 GETTABLE R1 R2 R0 RETURN R0 0 +)"); + + // verify that invocations of the inlined function in any position for computing the arguments to itself compile + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a, b) + return a + b +end + +local x, y, z = ... + +return foo(foo(x, y), foo(z, 1)) +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 ['foo'] +GETVARARGS R1 3 +ADD R5 R1 R2 +ADDK R6 R3 K1 [1] +ADD R4 R5 R6 +RETURN R4 1 +)"); + + // verify that invocations of the inlined function in any position for computing the arguments to itself compile, including constants and locals + // note that foo(k1, k2) doesn't get constant folded, so there's still actual math emitted for some of the calls below + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a, b) + return a + b +end + +local x, y, z = ... + +return + foo(foo(1, 2), 3), + foo(1, foo(2, 3)), + foo(x, foo(2, 3)), + foo(x, foo(y, 3)), + foo(x, foo(y, z)), + foo(x+0, foo(y, z)), + foo(x+0, foo(y+0, z)), + foo(x+0, foo(y, z+0)), + foo(1, foo(x, y)) +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 ['foo'] +GETVARARGS R1 3 +LOADN R5 3 +ADDK R4 R5 K1 [3] +LOADN R6 5 +LOADN R7 1 +ADD R5 R7 R6 +LOADN R7 5 +ADD R6 R1 R7 +ADDK R8 R2 K1 [3] +ADD R7 R1 R8 +ADD R9 R2 R3 +ADD R8 R1 R9 +ADDK R10 R1 K2 [0] +ADD R11 R2 R3 +ADD R9 R10 R11 +ADDK R11 R1 K2 [0] +ADDK R13 R2 K2 [0] +ADD R12 R13 R3 +ADD R10 R11 R12 +ADDK R12 R1 K2 [0] +ADDK R14 R3 K2 [0] +ADD R13 R2 R14 +ADD R11 R12 R13 +ADD R13 R1 R2 +LOADN R14 1 +ADD R12 R14 R13 +RETURN R4 9 )"); } diff --git a/tests/IrBuilder.test.cpp b/tests/IrBuilder.test.cpp index 2f5fbf1c9..8b3993081 100644 --- a/tests/IrBuilder.test.cpp +++ b/tests/IrBuilder.test.cpp @@ -80,6 +80,8 @@ class IrBuilderFixture static const int tnil = 0; static const int tboolean = 1; static const int tnumber = 3; + static const int tstring = 5; + static const int ttable = 6; }; TEST_SUITE_BEGIN("Optimization"); @@ -1286,8 +1288,8 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "IntEqRemoval") IrOp falseBlock = build.block(IrBlockKind::Internal); build.beginBlock(block); - IrOp value = build.inst(IrCmd::LOAD_INT, build.vmReg(1)); build.inst(IrCmd::STORE_INT, build.vmReg(1), build.constInt(5)); + IrOp value = build.inst(IrCmd::LOAD_INT, build.vmReg(1)); build.inst(IrCmd::JUMP_EQ_INT, value, build.constInt(5), trueBlock, falseBlock); build.beginBlock(trueBlock); @@ -1317,8 +1319,8 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "NumCmpRemoval") IrOp falseBlock = build.block(IrBlockKind::Internal); build.beginBlock(block); - IrOp value = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(1)); build.inst(IrCmd::STORE_DOUBLE, build.vmReg(1), build.constDouble(4.0)); + IrOp value = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(1)); build.inst(IrCmd::JUMP_CMP_NUM, value, build.constDouble(8.0), build.cond(IrCondition::Greater), trueBlock, falseBlock); build.beginBlock(trueBlock); @@ -1551,6 +1553,50 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "IntNumIntPeepholes") )"); } +TEST_CASE_FIXTURE(IrBuilderFixture, "InvalidateReglinkVersion") +{ + IrOp block = build.block(IrBlockKind::Internal); + IrOp fallback = build.block(IrBlockKind::Fallback); + + build.beginBlock(block); + + build.inst(IrCmd::STORE_TAG, build.vmReg(2), build.constTag(tstring)); + IrOp tv2 = build.inst(IrCmd::LOAD_TVALUE, build.vmReg(2)); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(1), tv2); + IrOp ft = build.inst(IrCmd::NEW_TABLE); + build.inst(IrCmd::STORE_POINTER, build.vmReg(2), ft); + build.inst(IrCmd::STORE_TAG, build.vmReg(2), build.constTag(ttable)); + IrOp tv1 = build.inst(IrCmd::LOAD_TVALUE, build.vmReg(1)); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(0), tv1); + IrOp tag = build.inst(IrCmd::LOAD_TAG, build.vmReg(0)); + build.inst(IrCmd::CHECK_TAG, tag, build.constTag(ttable), fallback); + build.inst(IrCmd::RETURN, build.constUint(0)); + + build.beginBlock(fallback); + build.inst(IrCmd::RETURN, build.constUint(1)); + + updateUseCounts(build.function); + constPropInBlockChains(build, true); + + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( +bb_0: + STORE_TAG R2, tstring + %1 = LOAD_TVALUE R2 + STORE_TVALUE R1, %1 + %3 = NEW_TABLE + STORE_POINTER R2, %3 + STORE_TAG R2, ttable + STORE_TVALUE R0, %1 + %8 = LOAD_TAG R0 + CHECK_TAG %8, ttable, bb_fallback_1 + RETURN 0u + +bb_fallback_1: + RETURN 1u + +)"); +} + TEST_SUITE_END(); TEST_SUITE_BEGIN("LinearExecutionFlowExtraction"); @@ -2257,7 +2303,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "NoPropagationOfCapturedRegs") IrOp entry = build.block(IrBlockKind::Internal); build.beginBlock(entry); - build.inst(IrCmd::CAPTURE, build.vmReg(0), build.constBool(true)); + build.inst(IrCmd::CAPTURE, build.vmReg(0), build.constUint(1)); IrOp op1 = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(0)); IrOp op2 = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(0)); IrOp sum = build.inst(IrCmd::ADD_NUM, op1, op2); @@ -2273,7 +2319,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "NoPropagationOfCapturedRegs") bb_0: ; in regs: R0 - CAPTURE R0, true + CAPTURE R0, 1u %1 = LOAD_DOUBLE R0 %2 = LOAD_DOUBLE R0 %3 = ADD_NUM %1, %2 diff --git a/tests/TypeFamily.test.cpp b/tests/TypeFamily.test.cpp index b11b05d7f..c10131baa 100644 --- a/tests/TypeFamily.test.cpp +++ b/tests/TypeFamily.test.cpp @@ -213,4 +213,36 @@ TEST_CASE_FIXTURE(Fixture, "add_family_at_work") CHECK(toString(result.errors[1]) == "Type family instance Add is uninhabited"); } +TEST_CASE_FIXTURE(Fixture, "internal_families_raise_errors") +{ + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + + CheckResult result = check(R"( + local function innerSum(a, b) + local _ = a + b + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK(toString(result.errors[0]) == "Type family instance Add depends on generic function parameters but does not appear in the function signature; this construct cannot be type-checked at this time"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "type_families_inhabited_with_normalization") +{ + ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true}; + + CheckResult result = check(R"( + local useGridConfig : any + local columns = useGridConfig("columns", {}) or 1 + local gutter = useGridConfig('gutter', {}) or 0 + local margin = useGridConfig('margin', {}) or 0 + return function(frameAbsoluteWidth: number) + local cellAbsoluteWidth = (frameAbsoluteWidth - 2 * margin + gutter) / columns - gutter + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index 78f755874..f53d6e04d 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -1922,6 +1922,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "dont_assert_when_the_tarjan_limit_is_exceede {"LuauClonePublicInterfaceLess2", true}, {"LuauSubstitutionReentrant", true}, {"LuauSubstitutionFixMissingFields", true}, + {"LuauCloneSkipNonInternalVisit", true}, }; CheckResult result = check(R"( @@ -1930,13 +1931,10 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "dont_assert_when_the_tarjan_limit_is_exceede end )"); - LUAU_REQUIRE_ERROR_COUNT(2, result); + LUAU_REQUIRE_ERROR_COUNT(1, result); CHECK_MESSAGE(get(result.errors[0]), "Expected CodeTooComplex but got: " << toString(result.errors[0])); CHECK(Location({1, 17}, {1, 18}) == result.errors[0].location); - - CHECK_MESSAGE(get(result.errors[1]), "Expected UnificationTooComplex but got: " << toString(result.errors[1])); - CHECK(Location({0, 0}, {4, 4}) == result.errors[1].location); } /* We had a bug under DCR where instantiated type packs had a nullptr scope. diff --git a/tests/TypeInfer.loops.test.cpp b/tests/TypeInfer.loops.test.cpp index c3dbbc7dc..fef000e40 100644 --- a/tests/TypeInfer.loops.test.cpp +++ b/tests/TypeInfer.loops.test.cpp @@ -60,23 +60,44 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "iteration_regression_issue_69967") { ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true}; CheckResult result = check(R"( + type Iterable = typeof(setmetatable( + {}, + {}::{ + __iter: (self: Iterable) -> () -> (number, string) + } + )) -type Iterable = typeof(setmetatable( - {}, - {}::{ - __iter: (self: Iterable) -> () -> (number, string) - } -)) + local t: Iterable -local t: Iterable + for a, b in t do end + )"); -for a, b in t do end -)"); + LUAU_REQUIRE_NO_ERRORS(result); +} - LUAU_REQUIRE_ERROR_COUNT(1, result); - GenericError* ge = get(result.errors[0]); - REQUIRE(ge); - CHECK_EQ("__iter metamethod must return (next[, table[, state]])", ge->message); +TEST_CASE_FIXTURE(BuiltinsFixture, "iteration_regression_issue_69967_alt") +{ + ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true}; + CheckResult result = check(R"( + type Iterable = typeof(setmetatable( + {}, + {}::{ + __iter: (self: Iterable) -> () -> (number, string) + } + )) + + local t: Iterable + local x, y + + for a, b in t do + x = a + y = b + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("number", toString(requireType("x"))); + CHECK_EQ("string", toString(requireType("y"))); } TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_loop") @@ -777,4 +798,130 @@ TEST_CASE_FIXTURE(Fixture, "iterate_over_free_table") CHECK("Cannot iterate over a table without indexer" == ge->message); } +TEST_CASE_FIXTURE(BuiltinsFixture, "dcr_iteration_explore_raycast_minimization") +{ + CheckResult result = check(R"( + local testResults = {} + for _, testData in pairs(testResults) do + end + + table.insert(testResults, {}) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "dcr_iteration_minimized_fragmented_keys_1") +{ + CheckResult result = check(R"( + local function rawpairs(t) + return next, t, nil + end + + local function getFragmentedKeys(tbl) + local _ = rawget(tbl, 0) + for _ in rawpairs(tbl) do + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "dcr_iteration_minimized_fragmented_keys_2") +{ + CheckResult result = check(R"( + local function getFragmentedKeys(tbl) + local _ = rawget(tbl, 0) + for _ in next, tbl, nil do + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "dcr_iteration_minimized_fragmented_keys_3") +{ + CheckResult result = check(R"( + local function getFragmentedKeys(tbl) + local _ = rawget(tbl, 0) + for _ in pairs(tbl) do + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "dcr_iteration_fragmented_keys") +{ + CheckResult result = check(R"( + local function isIndexKey(k, contiguousLength) + return true + end + + local function getTableLength(tbl) + local length = 1 + local value = rawget(tbl, length) + while value ~= nil do + length += 1 + value = rawget(tbl, length) + end + return length - 1 + end + + local function rawpairs(t) + return next, t, nil + end + + local function getFragmentedKeys(tbl) + local keys = {} + local keysLength = 0 + local tableLength = getTableLength(tbl) + for key, _ in rawpairs(tbl) do + if not isIndexKey(key, tableLength) then + keysLength = keysLength + 1 + keys[keysLength] = key + end + end + return keys, keysLength, tableLength + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "dcr_xpath_candidates") +{ + CheckResult result = check(R"( + type Instance = {} + local function findCandidates(instances: { Instance }, path: { string }) + for _, name in ipairs(path) do + end + return {} + end + + local canditates = findCandidates({}, {}) + for _, canditate in ipairs(canditates) do end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "dcr_iteration_on_never_gives_never") +{ + ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true}; + CheckResult result = check(R"( + local iter: never + local ans + for xs in iter do + ans = xs + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK(toString(requireType("ans")) == "never"); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.operators.test.cpp b/tests/TypeInfer.operators.test.cpp index dd26cc86b..26f0448b9 100644 --- a/tests/TypeInfer.operators.test.cpp +++ b/tests/TypeInfer.operators.test.cpp @@ -679,10 +679,9 @@ TEST_CASE_FIXTURE(Fixture, "strict_binary_op_where_lhs_unknown") if (FFlag::DebugLuauDeferredConstraintResolution) { - // TODO: This will eventually entirely go away, but for now the Add - // family will ensure there's one less error. - LUAU_REQUIRE_ERROR_COUNT(ops.size() - 1, result); - CHECK_EQ("Unknown type used in - operation; consider adding a type annotation to 'a'", toString(result.errors[0])); + LUAU_REQUIRE_ERROR_COUNT(ops.size(), result); + CHECK_EQ("Type family instance Add depends on generic function parameters but does not appear in the function signature; this construct cannot be type-checked at this time", toString(result.errors[0])); + CHECK_EQ("Unknown type used in - operation; consider adding a type annotation to 'a'", toString(result.errors[1])); } else { diff --git a/tests/TypeInfer.rwprops.test.cpp b/tests/TypeInfer.rwprops.test.cpp new file mode 100644 index 000000000..4a748715a --- /dev/null +++ b/tests/TypeInfer.rwprops.test.cpp @@ -0,0 +1,70 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Fixture.h" + +#include "doctest.h" + +LUAU_FASTFLAG(DebugLuauReadWriteProperties) + +using namespace Luau; + +namespace +{ + +struct ReadWriteFixture : Fixture +{ + ScopedFastFlag dcr{"DebugLuauDeferredConstraintResolution", true}; + + ReadWriteFixture() + : Fixture() + { + if (!FFlag::DebugLuauReadWriteProperties) + return; + + TypeArena* arena = &frontend.globals.globalTypes; + NotNull globalScope{frontend.globals.globalScope.get()}; + + unfreeze(*arena); + + TypeId genericT = arena->addType(GenericType{"T"}); + + TypeId readonlyX = arena->addType(TableType{TableState::Sealed, TypeLevel{}, globalScope}); + getMutable(readonlyX)->props["x"] = Property::readonly(genericT); + globalScope->addBuiltinTypeBinding("ReadonlyX", TypeFun{{{genericT}}, readonlyX}); + + freeze(*arena); + } +}; + +} // namespace + +TEST_SUITE_BEGIN("ReadWriteProperties"); + +TEST_CASE_FIXTURE(ReadWriteFixture, "read_from_a_readonly_prop") +{ + if (!FFlag::DebugLuauReadWriteProperties) + return; + + CheckResult result = check(R"( + function f(rx: ReadonlyX) + local x = rx.x + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(ReadWriteFixture, "write_to_a_readonly_prop") +{ + if (!FFlag::DebugLuauReadWriteProperties) + return; + + CheckResult result = check(R"( + function f(rx: ReadonlyX) + rx.x = "hello!" -- error + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_SUITE_END(); diff --git a/tests/TypeInfer.tryUnify.test.cpp b/tests/TypeInfer.tryUnify.test.cpp index 225b4ff1b..6b451e116 100644 --- a/tests/TypeInfer.tryUnify.test.cpp +++ b/tests/TypeInfer.tryUnify.test.cpp @@ -20,7 +20,7 @@ struct TryUnifyFixture : Fixture InternalErrorReporter iceHandler; UnifierSharedState unifierState{&iceHandler}; Normalizer normalizer{&arena, builtinTypes, NotNull{&unifierState}}; - Unifier state{NotNull{&normalizer}, Mode::Strict, NotNull{globalScope.get()}, Location{}, Variance::Covariant}; + Unifier state{NotNull{&normalizer}, NotNull{globalScope.get()}, Location{}, Variance::Covariant}; }; TEST_SUITE_BEGIN("TryUnifyTests"); diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp index 100abfb7f..570c72d06 100644 --- a/tests/TypeInfer.unionTypes.test.cpp +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -810,7 +810,7 @@ TEST_CASE_FIXTURE(Fixture, "free_options_can_be_unified_together") InternalErrorReporter iceHandler; UnifierSharedState sharedState{&iceHandler}; Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}}; - Unifier u{NotNull{&normalizer}, Mode::Strict, NotNull{scope.get()}, Location{}, Variance::Covariant}; + Unifier u{NotNull{&normalizer}, NotNull{scope.get()}, Location{}, Variance::Covariant}; u.tryUnify(option1, option2); diff --git a/tests/TypeInfer.unknownnever.test.cpp b/tests/TypeInfer.unknownnever.test.cpp index 8558670c3..e78c3d06d 100644 --- a/tests/TypeInfer.unknownnever.test.cpp +++ b/tests/TypeInfer.unknownnever.test.cpp @@ -324,4 +324,16 @@ TEST_CASE_FIXTURE(Fixture, "math_operators_and_never") CHECK_EQ("(nil, a) -> boolean", toString(requireType("mul"))); } +TEST_CASE_FIXTURE(Fixture, "compare_never") +{ + CheckResult result = check(R"( + local function cmp(x: nil, y: number) + return x ~= nil and x > y and x < y -- infers boolean | never, which is normalized into boolean + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("(nil, number) -> boolean", toString(requireType("cmp"))); +} + TEST_SUITE_END(); diff --git a/tests/conformance/native.lua b/tests/conformance/native.lua index 498909c9b..085085c1b 100644 --- a/tests/conformance/native.lua +++ b/tests/conformance/native.lua @@ -14,4 +14,27 @@ assert((function(x, y) return c, b, t, t1, t2 end)(5, 10) == 50) +local function fuzzfail1(...) + repeat + _ = nil + until not {} + for _ in ... do + for l0=_,_ do + end + return + end +end + +local function fuzzFail2() + local _ + do + repeat + _ = typeof(_),{_=_,} + _ = _(_._) + until _ + end +end + +assert(pcall(fuzzFail2) == false) + return('OK') diff --git a/tools/faillist.txt b/tools/faillist.txt index fe3353a8b..5c62e3da9 100644 --- a/tools/faillist.txt +++ b/tools/faillist.txt @@ -151,8 +151,6 @@ TypeInferFunctions.too_many_arguments_error_location TypeInferFunctions.too_many_return_values_in_parentheses TypeInferFunctions.too_many_return_values_no_function TypeInferLoops.for_in_loop_error_on_factory_not_returning_the_right_amount_of_values -TypeInferLoops.for_in_loop_with_next -TypeInferLoops.for_in_with_generic_next TypeInferLoops.loop_iter_trailing_nil TypeInferLoops.unreachable_code_after_infinite_loop TypeInferModules.do_not_modify_imported_types_5 @@ -175,7 +173,6 @@ TypeInferOperators.unrelated_classes_cannot_be_compared TypeInferOperators.unrelated_primitives_cannot_be_compared TypeInferPrimitives.CheckMethodsOfNumber TypeInferPrimitives.string_index -TypeInferUnknownNever.dont_unify_operands_if_one_of_the_operand_is_never_in_any_ordering_operators TypeInferUnknownNever.math_operators_and_never TypePackTests.detect_cyclic_typepacks2 TypePackTests.pack_tail_unification_check From f4357400c51d0fb7833cb9a4b96451c0ca891bc0 Mon Sep 17 00:00:00 2001 From: Andy Friesen Date: Fri, 2 Jun 2023 11:17:31 -0700 Subject: [PATCH 57/66] Sync to upstream/release/579 --- Analysis/include/Luau/TypeInfer.h | 11 +- Analysis/src/ConstraintSolver.cpp | 47 +++--- Analysis/src/TxnLog.cpp | 139 ++++++++--------- Analysis/src/TypeFamily.cpp | 10 +- Analysis/src/TypeInfer.cpp | 31 ++-- Analysis/src/Unifier.cpp | 109 +------------ CodeGen/src/AssemblyBuilderX64.cpp | 2 +- CodeGen/src/CodeAllocator.cpp | 2 + CodeGen/src/IrLoweringA64.cpp | 3 + CodeGen/src/IrLoweringX64.cpp | 3 + CodeGen/src/OptimizeConstProp.cpp | 66 +++++--- Common/include/Luau/ExperimentalFlags.h | 2 - VM/src/ldebug.cpp | 39 +---- tests/Conformance.test.cpp | 2 - tests/IrBuilder.test.cpp | 80 ++++++++++ tests/TypeInfer.builtins.test.cpp | 6 +- tests/TypeInfer.classes.test.cpp | 8 +- tests/TypeInfer.functions.test.cpp | 71 +++++++++ tests/TypeInfer.intersectionTypes.test.cpp | 34 ++++- tests/TypeInfer.operators.test.cpp | 12 +- tests/TypeInfer.provisional.test.cpp | 169 ++++++++++++++++++--- tests/TypeInfer.singletons.test.cpp | 6 +- tests/TypeInfer.tables.test.cpp | 24 ++- tests/TypeInfer.test.cpp | 8 +- tests/TypeInfer.tryUnify.test.cpp | 13 +- tests/TypeInfer.unionTypes.test.cpp | 124 --------------- tests/conformance/native.lua | 28 +++- tools/faillist.txt | 9 +- 28 files changed, 597 insertions(+), 461 deletions(-) diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index cceff0db1..1a721c743 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -30,7 +30,14 @@ struct ModuleResolver; using Name = std::string; using ScopePtr = std::shared_ptr; -using OverloadErrorEntry = std::tuple, std::vector, const FunctionType*>; + +struct OverloadErrorEntry +{ + TxnLog log; + ErrorVec errors; + std::vector arguments; + const FunctionType* fnTy; +}; bool doesCallError(const AstExprCall* call); bool hasBreak(AstStat* node); @@ -166,7 +173,7 @@ struct TypeChecker const std::vector& errors); void reportOverloadResolutionError(const ScopePtr& scope, const AstExprCall& expr, TypePackId retPack, TypePackId argPack, const std::vector& argLocations, const std::vector& overloads, const std::vector& overloadsThatMatchArgCount, - const std::vector& errors); + std::vector& errors); WithPredicate checkExprList(const ScopePtr& scope, const Location& location, const AstArray& exprs, bool substituteFreeForNil = false, const std::vector& lhsAnnotations = {}, diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index 14d0df662..f96f54b6d 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -1300,6 +1300,7 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNullscope); std::vector arityMatchingOverloads; + std::optional bestOverloadLog; for (TypeId overload : overloads) { @@ -1330,29 +1331,24 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNull(*e)->context != CountMismatch::Context::Arg) && get(*instantiated)) - { + const auto& e = hasCountMismatch(u.errors); + bool areArgumentsCompatible = (!e || get(*e)->context != CountMismatch::Context::Arg) && get(*instantiated); + if (areArgumentsCompatible) arityMatchingOverloads.push_back(*instantiated); - } if (u.errors.empty()) { if (c.callSite) (*c.astOverloadResolvedTypes)[c.callSite] = *instantiated; - // We found a matching overload. - const auto [changedTypes, changedPacks] = u.log.getChanges(); - u.log.commit(); - unblock(changedTypes); - unblock(changedPacks); - unblock(c.result); - - InstantiationQueuer queuer{constraint->scope, constraint->location, this}; - queuer.traverse(fn); - queuer.traverse(inferredTy); - - return true; + // This overload has no errors, so override the bestOverloadLog and use this one. + bestOverloadLog = std::move(u.log); + break; + } + else if (areArgumentsCompatible && !bestOverloadLog) + { + // This overload is erroneous. Replace its inferences with `any` iff there isn't already a TxnLog. + bestOverloadLog = std::move(u.log); } } @@ -1365,15 +1361,20 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNullscope, Location{}, Covariant}; - u.enableScopeTests(); + // We didn't find any overload that were a viable candidate, so replace the inferences with `any`. + if (!bestOverloadLog) + { + Unifier u{normalizer, constraint->scope, Location{}, Covariant}; + u.enableScopeTests(); - u.tryUnify(inferredTy, builtinTypes->anyType); - u.tryUnify(fn, builtinTypes->anyType); + u.tryUnify(inferredTy, builtinTypes->anyType); + u.tryUnify(fn, builtinTypes->anyType); - const auto [changedTypes, changedPacks] = u.log.getChanges(); - u.log.commit(); + bestOverloadLog = std::move(u.log); + } + + const auto [changedTypes, changedPacks] = bestOverloadLog->getChanges(); + bestOverloadLog->commit(); unblock(changedTypes); unblock(changedPacks); diff --git a/Analysis/src/TxnLog.cpp b/Analysis/src/TxnLog.cpp index 5d38f28e7..8a9b35684 100644 --- a/Analysis/src/TxnLog.cpp +++ b/Analysis/src/TxnLog.cpp @@ -111,94 +111,77 @@ void TxnLog::concatAsIntersections(TxnLog rhs, NotNull arena) void TxnLog::concatAsUnion(TxnLog rhs, NotNull arena) { - if (FFlag::DebugLuauDeferredConstraintResolution) + /* + * Check for cycles. + * + * We must not combine a log entry that binds 'a to 'b with a log that + * binds 'b to 'a. + * + * Of the two, identify the one with the 'bigger' scope and eliminate the + * entry that rebinds it. + */ + for (const auto& [rightTy, rightRep] : rhs.typeVarChanges) { - /* - * Check for cycles. - * - * We must not combine a log entry that binds 'a to 'b with a log that - * binds 'b to 'a. - * - * Of the two, identify the one with the 'bigger' scope and eliminate the - * entry that rebinds it. - */ - for (const auto& [rightTy, rightRep] : rhs.typeVarChanges) - { - if (rightRep->dead) - continue; - - // We explicitly use get_if here because we do not wish to do anything - // if the uncommitted type is already bound to something else. - const FreeType* rf = get_if(&rightTy->ty); - if (!rf) - continue; - - const BoundType* rb = Luau::get(&rightRep->pending); - if (!rb) - continue; - - const TypeId leftTy = rb->boundTo; - const FreeType* lf = get_if(&leftTy->ty); - if (!lf) - continue; - - auto leftRep = typeVarChanges.find(leftTy); - if (!leftRep) - continue; - - if ((*leftRep)->dead) - continue; - - const BoundType* lb = Luau::get(&(*leftRep)->pending); - if (!lb) - continue; - - if (lb->boundTo == rightTy) - { - // leftTy has been bound to rightTy, but rightTy has also been bound - // to leftTy. We find the one that belongs to the more deeply nested - // scope and remove it from the log. - const bool discardLeft = useScopes ? subsumes(lf->scope, rf->scope) : lf->level.subsumes(rf->level); - - if (discardLeft) - (*leftRep)->dead = true; - else - rightRep->dead = true; - } - } + if (rightRep->dead) + continue; - for (auto& [ty, rightRep] : rhs.typeVarChanges) + // We explicitly use get_if here because we do not wish to do anything + // if the uncommitted type is already bound to something else. + const FreeType* rf = get_if(&rightTy->ty); + if (!rf) + continue; + + const BoundType* rb = Luau::get(&rightRep->pending); + if (!rb) + continue; + + const TypeId leftTy = rb->boundTo; + const FreeType* lf = get_if(&leftTy->ty); + if (!lf) + continue; + + auto leftRep = typeVarChanges.find(leftTy); + if (!leftRep) + continue; + + if ((*leftRep)->dead) + continue; + + const BoundType* lb = Luau::get(&(*leftRep)->pending); + if (!lb) + continue; + + if (lb->boundTo == rightTy) { - if (rightRep->dead) - continue; - - if (auto leftRep = typeVarChanges.find(ty); leftRep && !(*leftRep)->dead) - { - TypeId leftTy = arena->addType((*leftRep)->pending); - TypeId rightTy = arena->addType(rightRep->pending); - - if (follow(leftTy) == follow(rightTy)) - typeVarChanges[ty] = std::move(rightRep); - else - typeVarChanges[ty]->pending.ty = UnionType{{leftTy, rightTy}}; - } + // leftTy has been bound to rightTy, but rightTy has also been bound + // to leftTy. We find the one that belongs to the more deeply nested + // scope and remove it from the log. + const bool discardLeft = useScopes ? subsumes(lf->scope, rf->scope) : lf->level.subsumes(rf->level); + + if (discardLeft) + (*leftRep)->dead = true; else - typeVarChanges[ty] = std::move(rightRep); + rightRep->dead = true; } } - else + + for (auto& [ty, rightRep] : rhs.typeVarChanges) { - for (auto& [ty, rightRep] : rhs.typeVarChanges) + if (rightRep->dead) + continue; + + if (auto leftRep = typeVarChanges.find(ty); leftRep && !(*leftRep)->dead) { - if (auto leftRep = typeVarChanges.find(ty)) - { - TypeId leftTy = arena->addType((*leftRep)->pending); - TypeId rightTy = arena->addType(rightRep->pending); - typeVarChanges[ty]->pending.ty = UnionType{{leftTy, rightTy}}; - } - else + TypeId leftTy = arena->addType((*leftRep)->pending); + TypeId rightTy = arena->addType(rightRep->pending); + + if (follow(leftTy) == follow(rightTy)) typeVarChanges[ty] = std::move(rightRep); + else + typeVarChanges[ty]->pending.ty = UnionType{{leftTy, rightTy}}; } + else + typeVarChanges[ty] = std::move(rightRep); } for (auto& [tp, rep] : rhs.typePackChanges) diff --git a/Analysis/src/TypeFamily.cpp b/Analysis/src/TypeFamily.cpp index 98a9f97ed..e68187fd1 100644 --- a/Analysis/src/TypeFamily.cpp +++ b/Analysis/src/TypeFamily.cpp @@ -347,14 +347,18 @@ TypeFamilyReductionResult addFamilyFn(std::vector typeParams, st const NormalizedType* normLhsTy = normalizer->normalize(lhsTy); const NormalizedType* normRhsTy = normalizer->normalize(rhsTy); - if (normLhsTy && normRhsTy && normLhsTy->isNumber() && normRhsTy->isNumber()) + if (!normLhsTy || !normRhsTy) { - return {builtins->numberType, false, {}, {}}; + return {std::nullopt, false, {}, {}}; } - else if (log->is(lhsTy) || log->is(rhsTy)) + else if (log->is(normLhsTy->tops) || log->is(normRhsTy->tops)) { return {builtins->anyType, false, {}, {}}; } + else if (normLhsTy->isNumber() && normRhsTy->isNumber()) + { + return {builtins->numberType, false, {}, {}}; + } else if (log->is(lhsTy) || log->is(rhsTy)) { return {builtins->errorRecoveryType(), false, {}, {}}; diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index ecf222a84..5127febee 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -40,6 +40,7 @@ LUAU_FASTFLAG(LuauOccursIsntAlwaysFailure) LUAU_FASTFLAGVARIABLE(LuauTypecheckTypeguards, false) LUAU_FASTFLAGVARIABLE(LuauTinyControlFlowAnalysis, false) LUAU_FASTFLAGVARIABLE(LuauTypecheckClassTypeIndexers, false) +LUAU_FASTFLAGVARIABLE(LuauAlwaysCommitInferencesOfFunctionCalls, false) namespace Luau { @@ -4387,7 +4388,12 @@ std::unique_ptr> TypeChecker::checkCallOverload(const else overloadsThatDont.push_back(fn); - errors.emplace_back(std::move(state.errors), args->head, ftv); + errors.push_back(OverloadErrorEntry{ + std::move(state.log), + std::move(state.errors), + args->head, + ftv, + }); } else { @@ -4407,7 +4413,7 @@ bool TypeChecker::handleSelfCallMismatch(const ScopePtr& scope, const AstExprCal { // No overloads succeeded: Scan for one that would have worked had the user // used a.b() rather than a:b() or vice versa. - for (const auto& [_, argVec, ftv] : errors) + for (const auto& e : errors) { // Did you write foo:bar() when you should have written foo.bar()? if (expr.self) @@ -4418,7 +4424,7 @@ bool TypeChecker::handleSelfCallMismatch(const ScopePtr& scope, const AstExprCal TypePackId editedArgPack = addTypePack(TypePack{editedParamList}); Unifier editedState = mkUnifier(scope, expr.location); - checkArgumentList(scope, *expr.func, editedState, editedArgPack, ftv->argTypes, editedArgLocations); + checkArgumentList(scope, *expr.func, editedState, editedArgPack, e.fnTy->argTypes, editedArgLocations); if (editedState.errors.empty()) { @@ -4433,7 +4439,7 @@ bool TypeChecker::handleSelfCallMismatch(const ScopePtr& scope, const AstExprCal return true; } } - else if (ftv->hasSelf) + else if (e.fnTy->hasSelf) { // Did you write foo.bar() when you should have written foo:bar()? if (AstExprIndexName* indexName = expr.func->as()) @@ -4449,7 +4455,7 @@ bool TypeChecker::handleSelfCallMismatch(const ScopePtr& scope, const AstExprCal Unifier editedState = mkUnifier(scope, expr.location); - checkArgumentList(scope, *expr.func, editedState, editedArgPack, ftv->argTypes, editedArgLocations); + checkArgumentList(scope, *expr.func, editedState, editedArgPack, e.fnTy->argTypes, editedArgLocations); if (editedState.errors.empty()) { @@ -4472,11 +4478,14 @@ bool TypeChecker::handleSelfCallMismatch(const ScopePtr& scope, const AstExprCal void TypeChecker::reportOverloadResolutionError(const ScopePtr& scope, const AstExprCall& expr, TypePackId retPack, TypePackId argPack, const std::vector& argLocations, const std::vector& overloads, const std::vector& overloadsThatMatchArgCount, - const std::vector& errors) + std::vector& errors) { if (overloads.size() == 1) { - reportErrors(std::get<0>(errors.front())); + if (FFlag::LuauAlwaysCommitInferencesOfFunctionCalls) + errors.front().log.commit(); + + reportErrors(errors.front().errors); return; } @@ -4498,11 +4507,15 @@ void TypeChecker::reportOverloadResolutionError(const ScopePtr& scope, const Ast const FunctionType* ftv = get(overload); auto error = std::find_if(errors.begin(), errors.end(), [ftv](const OverloadErrorEntry& e) { - return ftv == std::get<2>(e); + return ftv == e.fnTy; }); LUAU_ASSERT(error != errors.end()); - reportErrors(std::get<0>(*error)); + + if (FFlag::LuauAlwaysCommitInferencesOfFunctionCalls) + error->log.commit(); + + reportErrors(error->errors); // If only one overload matched, we don't need this error because we provided the previous errors. if (overloadsThatMatchArgCount.size() == 1) diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 9a12234bd..91b89136a 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -21,13 +21,13 @@ LUAU_FASTFLAG(LuauErrorRecoveryType) LUAU_FASTFLAGVARIABLE(LuauInstantiateInSubtyping, false) LUAU_FASTFLAGVARIABLE(LuauUninhabitedSubAnything2, false) LUAU_FASTFLAGVARIABLE(LuauVariadicAnyCanBeGeneric, false) -LUAU_FASTFLAGVARIABLE(LuauUnifyTwoOptions, false) LUAU_FASTFLAGVARIABLE(LuauMaintainScopesInUnifier, false) LUAU_FASTFLAGVARIABLE(LuauTransitiveSubtyping, false) LUAU_FASTFLAGVARIABLE(LuauOccursIsntAlwaysFailure, false) LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) LUAU_FASTFLAG(LuauNormalizeBlockedTypes) +LUAU_FASTFLAG(LuauAlwaysCommitInferencesOfFunctionCalls) namespace Luau { @@ -761,93 +761,8 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool log.popSeen(superTy, subTy); } -/* - * If the passed type is an option, strip nil out. - * - * There is an important subtlety to be observed here: - * - * We want to do a peephole fix to unify the subtype relation A? <: B? where we - * instead peel off the options and relate A <: B instead, but only works if we - * are certain that neither A nor B are themselves optional. - * - * For instance, if we want to test that (boolean?)? <: boolean?, we must peel - * off both layers of optionality from the subTy. - * - * We must also handle unions that have more than two choices. - * - * eg (string | nil)? <: boolean? - */ -static std::optional unwrapOption(NotNull builtinTypes, NotNull arena, const TxnLog& log, TypeId ty, DenseHashSet& seen) -{ - if (seen.find(ty)) - return std::nullopt; - seen.insert(ty); - - const UnionType* ut = get(follow(ty)); - if (!ut) - return std::nullopt; - - if (2 == ut->options.size()) - { - if (isNil(follow(ut->options[0]))) - { - std::optional doubleUnwrapped = unwrapOption(builtinTypes, arena, log, ut->options[1], seen); - return doubleUnwrapped.value_or(ut->options[1]); - } - if (isNil(follow(ut->options[1]))) - { - std::optional doubleUnwrapped = unwrapOption(builtinTypes, arena, log, ut->options[0], seen); - return doubleUnwrapped.value_or(ut->options[0]); - } - } - - std::set newOptions; - bool found = false; - for (TypeId t : ut) - { - t = log.follow(t); - if (isNil(t)) - { - found = true; - continue; - } - else - newOptions.insert(t); - } - - if (!found) - return std::nullopt; - else if (newOptions.empty()) - return builtinTypes->neverType; - else if (1 == newOptions.size()) - return *begin(newOptions); - else - return arena->addType(UnionType{std::vector(begin(newOptions), end(newOptions))}); -} - -static std::optional unwrapOption(NotNull builtinTypes, NotNull arena, const TxnLog& log, TypeId ty) -{ - DenseHashSet seen{nullptr}; - - return unwrapOption(builtinTypes, arena, log, ty, seen); -} - - void Unifier::tryUnifyUnionWithType(TypeId subTy, const UnionType* subUnion, TypeId superTy) { - // Peephole fix: A? <: B? if A <: B - // - // This works around issues that can arise if A or B is free. We do not - // want either of those types to be bound to nil. - if (FFlag::LuauUnifyTwoOptions) - { - if (auto subOption = unwrapOption(builtinTypes, NotNull{types}, log, subTy)) - { - if (auto superOption = unwrapOption(builtinTypes, NotNull{types}, log, superTy)) - return tryUnify_(*subOption, *superOption); - } - } - // A | B <: T if and only if A <: T and B <: T bool failed = false; bool errorsSuppressed = true; @@ -880,7 +795,7 @@ void Unifier::tryUnifyUnionWithType(TypeId subTy, const UnionType* subUnion, Typ } } - if (FFlag::DebugLuauDeferredConstraintResolution) + if (FFlag::LuauAlwaysCommitInferencesOfFunctionCalls) log.concatAsUnion(combineLogsIntoUnion(std::move(logs)), NotNull{types}); else { @@ -1364,25 +1279,6 @@ void Unifier::tryUnifyNormalizedTypes( const ClassType* superCtv = get(superClass); LUAU_ASSERT(superCtv); - if (FFlag::LuauUnifyTwoOptions) - { - if (variance == Invariant) - { - if (subCtv == superCtv) - { - found = true; - - /* - * The only way we could care about superNegations is if - * one of them was equal to superCtv. However, - * normalization ensures that this is impossible. - */ - } - else - continue; - } - } - if (isSubclass(subCtv, superCtv)) { found = true; @@ -2960,7 +2856,6 @@ TxnLog Unifier::combineLogsIntoIntersection(std::vector logs) TxnLog Unifier::combineLogsIntoUnion(std::vector logs) { - LUAU_ASSERT(FFlag::DebugLuauDeferredConstraintResolution); TxnLog result(useScopes); for (TxnLog& log : logs) result.concatAsUnion(std::move(log), NotNull{types}); diff --git a/CodeGen/src/AssemblyBuilderX64.cpp b/CodeGen/src/AssemblyBuilderX64.cpp index c7644a86c..426a0259d 100644 --- a/CodeGen/src/AssemblyBuilderX64.cpp +++ b/CodeGen/src/AssemblyBuilderX64.cpp @@ -1415,7 +1415,7 @@ void AssemblyBuilderX64::commit() { LUAU_ASSERT(codePos <= codeEnd); - if (codeEnd - codePos < kMaxInstructionLength) + if (unsigned(codeEnd - codePos) < kMaxInstructionLength) extend(); } diff --git a/CodeGen/src/CodeAllocator.cpp b/CodeGen/src/CodeAllocator.cpp index 09e1bb712..880a32446 100644 --- a/CodeGen/src/CodeAllocator.cpp +++ b/CodeGen/src/CodeAllocator.cpp @@ -56,8 +56,10 @@ static void makePagesExecutable(uint8_t* mem, size_t size) static void flushInstructionCache(uint8_t* mem, size_t size) { +#if WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_APP | WINAPI_PARTITION_SYSTEM) if (FlushInstructionCache(GetCurrentProcess(), mem, size) == 0) LUAU_ASSERT(!"Failed to flush instruction cache"); +#endif } #else static uint8_t* allocatePages(size_t size) diff --git a/CodeGen/src/IrLoweringA64.cpp b/CodeGen/src/IrLoweringA64.cpp index fb5d86878..5f6249000 100644 --- a/CodeGen/src/IrLoweringA64.cpp +++ b/CodeGen/src/IrLoweringA64.cpp @@ -766,6 +766,9 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.ldr(x2, mem(x2, offsetof(global_State, tmname) + intOp(inst.b) * sizeof(TString*))); build.ldr(x3, mem(rNativeContext, offsetof(NativeContext, luaT_gettm))); build.blr(x3); + + build.cbz(x0, labelOp(inst.c)); // no tag method + inst.regA64 = regs.takeReg(x0, index); break; } diff --git a/CodeGen/src/IrLoweringX64.cpp b/CodeGen/src/IrLoweringX64.cpp index 2efd73ea7..b9c35df04 100644 --- a/CodeGen/src/IrLoweringX64.cpp +++ b/CodeGen/src/IrLoweringX64.cpp @@ -639,6 +639,9 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaT_gettm)]); } + build.test(rax, rax); + build.jcc(ConditionX64::Zero, labelOp(inst.c)); // No tag method + inst.regX64 = regs.takeReg(rax, index); break; } diff --git a/CodeGen/src/OptimizeConstProp.cpp b/CodeGen/src/OptimizeConstProp.cpp index 0f5eb4ebb..338bb49f9 100644 --- a/CodeGen/src/OptimizeConstProp.cpp +++ b/CodeGen/src/OptimizeConstProp.cpp @@ -146,8 +146,15 @@ struct ConstPropState void invalidateRegisterRange(int firstReg, int count) { - for (int i = firstReg; i < firstReg + count && i <= maxReg; ++i) - invalidate(regs[i], /* invalidateTag */ true, /* invalidateValue */ true); + if (count == -1) + { + invalidateRegistersFrom(firstReg); + } + else + { + for (int i = firstReg; i < firstReg + count && i <= maxReg; ++i) + invalidate(regs[i], /* invalidateTag */ true, /* invalidateValue */ true); + } } void invalidateCapturedRegisters() @@ -236,9 +243,18 @@ struct ConstPropState return; if (uint32_t* prevIdx = valueMap.find(inst)) - substitute(function, inst, IrOp{IrOpKind::Inst, *prevIdx}); - else - valueMap[inst] = instIdx; + { + const IrInst& prev = function.instructions[*prevIdx]; + + // Previous load might have been removed as unused + if (prev.useCount != 0) + { + substitute(function, inst, IrOp{IrOpKind::Inst, *prevIdx}); + return; + } + } + + valueMap[inst] = instIdx; } // Vm register load can be replaced by a previous load of the same version of the register @@ -260,23 +276,28 @@ struct ConstPropState // Check if there is a value that already has this version of the register if (uint32_t* prevIdx = valueMap.find(versionedLoad)) { - // Previous value might not be linked to a register yet - // For example, it could be a NEW_TABLE stored into a register and we might need to track guards made with this value - if (!instLink.contains(*prevIdx)) - createRegLink(*prevIdx, loadInst.a); + const IrInst& prev = function.instructions[*prevIdx]; - // Substitute load instructon with the previous value - substitute(function, loadInst, IrOp{IrOpKind::Inst, *prevIdx}); + // Previous load might have been removed as unused + if (prev.useCount != 0) + { + // Previous value might not be linked to a register yet + // For example, it could be a NEW_TABLE stored into a register and we might need to track guards made with this value + if (!instLink.contains(*prevIdx)) + createRegLink(*prevIdx, loadInst.a); + + // Substitute load instructon with the previous value + substitute(function, loadInst, IrOp{IrOpKind::Inst, *prevIdx}); + return; + } } - else - { - uint32_t instIdx = function.getInstIndex(loadInst); - // Record load of this register version for future substitution - valueMap[versionedLoad] = instIdx; + uint32_t instIdx = function.getInstIndex(loadInst); - createRegLink(instIdx, loadInst.a); - } + // Record load of this register version for future substitution + valueMap[versionedLoad] = instIdx; + + createRegLink(instIdx, loadInst.a); } // VM register loads can use the value that was stored in the same Vm register earlier @@ -456,9 +477,16 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& } if (state.tryGetTag(source) == value) - kill(function, inst); + { + if (FFlag::DebugLuauAbortingChecks) + replace(function, block, index, {IrCmd::CHECK_TAG, inst.a, inst.b, build.undef()}); + else + kill(function, inst); + } else + { state.saveTag(source, value); + } } else { diff --git a/Common/include/Luau/ExperimentalFlags.h b/Common/include/Luau/ExperimentalFlags.h index df51e7b95..8eca1050a 100644 --- a/Common/include/Luau/ExperimentalFlags.h +++ b/Common/include/Luau/ExperimentalFlags.h @@ -14,8 +14,6 @@ inline bool isFlagExperimental(const char* flag) "LuauInstantiateInSubtyping", // requires some fixes to lua-apps code "LuauTypecheckTypeguards", // requires some fixes to lua-apps code (CLI-67030) "LuauTinyControlFlowAnalysis", // waiting for updates to packages depended by internal builtin plugins - "LuauUnifyTwoOptions", // requires some fixes to lua-apps code - // makes sure we always have at least one entry nullptr, }; diff --git a/VM/src/ldebug.cpp b/VM/src/ldebug.cpp index 5ea08b53b..d3e21f5de 100644 --- a/VM/src/ldebug.cpp +++ b/VM/src/ldebug.cpp @@ -12,8 +12,6 @@ #include #include -LUAU_FASTFLAGVARIABLE(LuauFixBreakpointLineSearch, false) - static const char* getfuncname(Closure* f); static int currentpc(lua_State* L, CallInfo* ci) @@ -427,22 +425,11 @@ static int getnextline(Proto* p, int line) int candidate = luaG_getline(p, i); - if (FFlag::LuauFixBreakpointLineSearch) - { - if (candidate == line) - return line; + if (candidate == line) + return line; - if (candidate > line && (closest == -1 || candidate < closest)) - closest = candidate; - } - else - { - if (candidate >= line) - { - closest = candidate; - break; - } - } + if (candidate > line && (closest == -1 || candidate < closest)) + closest = candidate; } } @@ -451,21 +438,11 @@ static int getnextline(Proto* p, int line) // Find the closest line number to the intended one. int candidate = getnextline(p->p[i], line); - if (FFlag::LuauFixBreakpointLineSearch) - { - if (candidate == line) - return line; + if (candidate == line) + return line; - if (candidate > line && (closest == -1 || candidate < closest)) - closest = candidate; - } - else - { - if (closest == -1 || (candidate >= line && candidate < closest)) - { - closest = candidate; - } - } + if (candidate > line && (closest == -1 || candidate < closest)) + closest = candidate; } return closest; diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index 9e5ae30e9..9b47b6f5d 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -561,8 +561,6 @@ TEST_CASE("Debug") TEST_CASE("Debugger") { - ScopedFastFlag luauFixBreakpointLineSearch{"LuauFixBreakpointLineSearch", true}; - static int breakhits = 0; static lua_State* interruptedthread = nullptr; static bool singlestep = false; diff --git a/tests/IrBuilder.test.cpp b/tests/IrBuilder.test.cpp index 8b3993081..32634225a 100644 --- a/tests/IrBuilder.test.cpp +++ b/tests/IrBuilder.test.cpp @@ -1811,6 +1811,30 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "PartialStoreInvalidation") )"); } +TEST_CASE_FIXTURE(IrBuilderFixture, "VaridicRegisterRangeInvalidation") +{ + IrOp block = build.block(IrBlockKind::Internal); + + build.beginBlock(block); + + build.inst(IrCmd::STORE_TAG, build.vmReg(2), build.constTag(tnumber)); + build.inst(IrCmd::FALLBACK_GETVARARGS, build.constUint(0), build.vmReg(1), build.constInt(-1)); + build.inst(IrCmd::STORE_TAG, build.vmReg(2), build.constTag(tnumber)); + build.inst(IrCmd::RETURN, build.constUint(0)); + + updateUseCounts(build.function); + constPropInBlockChains(build, true); + + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( +bb_0: + STORE_TAG R2, tnumber + FALLBACK_GETVARARGS 0u, R1, -1i + STORE_TAG R2, tnumber + RETURN 0u + +)"); +} + TEST_SUITE_END(); TEST_SUITE_BEGIN("Analysis"); @@ -2329,4 +2353,60 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "NoPropagationOfCapturedRegs") )"); } +TEST_CASE_FIXTURE(IrBuilderFixture, "NoDeadLoadReuse") +{ + IrOp entry = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + IrOp op1 = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(0)); + IrOp op1i = build.inst(IrCmd::NUM_TO_INT, op1); + IrOp res = build.inst(IrCmd::BITAND_UINT, op1i, build.constInt(0)); + IrOp resd = build.inst(IrCmd::INT_TO_NUM, res); + IrOp op2 = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(0)); + IrOp sum = build.inst(IrCmd::ADD_NUM, resd, op2); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(1), sum); + build.inst(IrCmd::RETURN, build.vmReg(1), build.constInt(1)); + + updateUseCounts(build.function); + constPropInBlockChains(build, true); + + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( +bb_0: + %4 = LOAD_DOUBLE R0 + %5 = ADD_NUM 0, %4 + STORE_DOUBLE R1, %5 + RETURN R1, 1i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "NoDeadValueReuse") +{ + IrOp entry = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + IrOp op1 = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(0)); + IrOp op1i = build.inst(IrCmd::NUM_TO_INT, op1); + IrOp res = build.inst(IrCmd::BITAND_UINT, op1i, build.constInt(0)); + IrOp op2i = build.inst(IrCmd::NUM_TO_INT, op1); + IrOp sum = build.inst(IrCmd::ADD_INT, res, op2i); + IrOp resd = build.inst(IrCmd::INT_TO_NUM, sum); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(1), resd); + build.inst(IrCmd::RETURN, build.vmReg(1), build.constInt(1)); + + updateUseCounts(build.function); + constPropInBlockChains(build, true); + + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( +bb_0: + %0 = LOAD_DOUBLE R0 + %3 = NUM_TO_INT %0 + %4 = ADD_INT 0i, %3 + %5 = INT_TO_NUM %4 + STORE_DOUBLE R1, %5 + RETURN R1, 1i + +)"); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.builtins.test.cpp b/tests/TypeInfer.builtins.test.cpp index 07cf5393a..4e0b7a7ea 100644 --- a/tests/TypeInfer.builtins.test.cpp +++ b/tests/TypeInfer.builtins.test.cpp @@ -132,6 +132,8 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "sort_with_predicate") TEST_CASE_FIXTURE(BuiltinsFixture, "sort_with_bad_predicate") { + ScopedFastFlag sff{"LuauAlwaysCommitInferencesOfFunctionCalls", true}; + CheckResult result = check(R"( --!strict local t = {'one', 'two', 'three'} @@ -140,9 +142,9 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "sort_with_bad_predicate") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(R"(Type '(number, number) -> boolean' could not be converted into '((a, a) -> boolean)?' + CHECK_EQ(R"(Type '(number, number) -> boolean' could not be converted into '((string, string) -> boolean)?' caused by: - None of the union options are compatible. For example: Type '(number, number) -> boolean' could not be converted into '(a, a) -> boolean' + None of the union options are compatible. For example: Type '(number, number) -> boolean' could not be converted into '(string, string) -> boolean' caused by: Argument #1 type is not compatible. Type 'string' could not be converted into 'number')", toString(result.errors[0])); diff --git a/tests/TypeInfer.classes.test.cpp b/tests/TypeInfer.classes.test.cpp index d9e4bbada..37ecab2ee 100644 --- a/tests/TypeInfer.classes.test.cpp +++ b/tests/TypeInfer.classes.test.cpp @@ -367,6 +367,8 @@ b.X = 2 -- real Vector2.X is also read-only TEST_CASE_FIXTURE(ClassFixture, "detailed_class_unification_error") { + ScopedFastFlag sff{"LuauAlwaysCommitInferencesOfFunctionCalls", true}; + CheckResult result = check(R"( local function foo(v) return v.X :: number + string.len(v.Y) @@ -378,10 +380,10 @@ b(a) )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(R"(Type 'Vector2' could not be converted into '{- X: a, Y: string -}' + + CHECK_EQ(toString(result.errors[0]), R"(Type 'Vector2' could not be converted into '{- X: number, Y: string -}' caused by: - Property 'Y' is not compatible. Type 'number' could not be converted into 'string')", - toString(result.errors[0])); + Property 'Y' is not compatible. Type 'number' could not be converted into 'string')"); } TEST_CASE_FIXTURE(ClassFixture, "class_type_mismatch_with_name_conflict") diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index f53d6e04d..e5bcfa304 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -2002,4 +2002,75 @@ TEST_CASE_FIXTURE(Fixture, "function_exprs_are_generalized_at_signature_scope_no CHECK(toString(requireType("foo")) == "(a) -> b"); } +TEST_CASE_FIXTURE(BuiltinsFixture, "param_1_and_2_both_takes_the_same_generic_but_their_arguments_are_incompatible") +{ + ScopedFastFlag sff{"LuauAlwaysCommitInferencesOfFunctionCalls", true}; + + CheckResult result = check(R"( + local function foo(x: a, y: a?) + return x + end + local vec2 = { x = 5, y = 7 } + local ret: number = foo(vec2, { x = 5 }) + )"); + + // In the old solver, this produces a very strange result: + // + // Here, we instantiate `(x: a, y: a?) -> a` with a fresh type `'a` for `a`. + // In argument #1, we unify `vec2` with `'a`. + // This is ok, so we record an equality constraint `'a` with `vec2`. + // In argument #2, we unify `{ x: number }` with `'a?`. + // This fails because `'a` has equality constraint with `vec2`, + // so `{ x: number } <: vec2?`, which is false. + // + // If the unifications were to be committed, then it'd result in the following type error: + // + // Type '{ x: number }' could not be converted into 'vec2?' + // caused by: + // [...] Table type '{ x: number }' not compatible with type 'vec2' because the former is missing field 'y' + // + // However, whenever we check the argument list, if there's an error, we don't commit the unifications, so it actually looks like this: + // + // Type '{ x: number }' could not be converted into 'a?' + // caused by: + // [...] Table type '{ x: number }' not compatible with type 'vec2' because the former is missing field 'y' + // + // Then finally, that generic is left floating free, and since the function returns that generic, + // that free type is then later bound to `number`, which succeeds and mutates the type graph. + // This again changes the type error where `a` becomes bound to `number`. + // + // Type '{ x: number }' could not be converted into 'number?' + // caused by: + // [...] Table type '{ x: number }' not compatible with type 'vec2' because the former is missing field 'y' + // + // Uh oh, that type error is extremely confusing for people who doesn't know how that went down. + // Really, what should happen is we roll each argument incompatibility into a union type, but that needs local type inference. + + LUAU_REQUIRE_ERROR_COUNT(2, result); + + CHECK_EQ(toString(result.errors[0]), R"(Type '{ x: number }' could not be converted into 'vec2?' +caused by: + None of the union options are compatible. For example: Table type '{ x: number }' not compatible with type 'vec2' because the former is missing field 'y')"); + + CHECK_EQ(toString(result.errors[1]), "Type 'vec2' could not be converted into 'number'"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "param_1_and_2_both_takes_the_same_generic_but_their_arguments_are_incompatible_2") +{ + ScopedFastFlag sff{"LuauAlwaysCommitInferencesOfFunctionCalls", true}; + + CheckResult result = check(R"( + local function f(x: a, y: a): a + return if math.random() > 0.5 then x else y + end + + local z: boolean = f(5, "five") + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + + CHECK_EQ(toString(result.errors[0]), "Type 'string' could not be converted into 'number'"); + CHECK_EQ(toString(result.errors[1]), "Type 'number' could not be converted into 'boolean'"); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.intersectionTypes.test.cpp b/tests/TypeInfer.intersectionTypes.test.cpp index 738d3cd2b..3e813b7fa 100644 --- a/tests/TypeInfer.intersectionTypes.test.cpp +++ b/tests/TypeInfer.intersectionTypes.test.cpp @@ -500,17 +500,41 @@ TEST_CASE_FIXTURE(Fixture, "intersection_of_tables") TEST_CASE_FIXTURE(Fixture, "intersection_of_tables_with_top_properties") { - ScopedFastFlag sff{"LuauUnifyTwoOptions", true}; - CheckResult result = check(R"( local x : { p : number?, q : any } & { p : unknown, q : string? } local y : { p : number?, q : string? } = x -- OK local z : { p : string?, q : number? } = x -- Not OK )"); - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), "Type '{| p: number?, q: any |} & {| p: unknown, q: string? |}' could not be converted into " - "'{| p: string?, q: number? |}'; none of the intersection parts are compatible"); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + LUAU_REQUIRE_ERROR_COUNT(2, result); + + CHECK_EQ(toString(result.errors[0]), + "Type '{| p: number?, q: string? |}' could not be converted into '{| p: string?, q: number? |}'\n" + "caused by:\n" + " Property 'p' is not compatible. Type 'number?' could not be converted into 'string?'\n" + "caused by:\n" + " Not all union options are compatible. Type 'number' could not be converted into 'string?'\n" + "caused by:\n" + " None of the union options are compatible. For example: Type 'number' could not be converted into 'string' in an invariant context"); + + CHECK_EQ(toString(result.errors[1]), + "Type '{| p: number?, q: string? |}' could not be converted into '{| p: string?, q: number? |}'\n" + "caused by:\n" + " Property 'q' is not compatible. Type 'string?' could not be converted into 'number?'\n" + "caused by:\n" + " Not all union options are compatible. Type 'string' could not be converted into 'number?'\n" + "caused by:\n" + " None of the union options are compatible. For example: Type 'string' could not be converted into 'number' in an invariant context"); + } + else + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), + "Type '{| p: number?, q: any |} & {| p: unknown, q: string? |}' could not be converted into '{| p: string?, " + "q: number? |}'; none of the intersection parts are compatible"); + } } TEST_CASE_FIXTURE(Fixture, "intersection_of_tables_with_never_properties") diff --git a/tests/TypeInfer.operators.test.cpp b/tests/TypeInfer.operators.test.cpp index 26f0448b9..c905e1cca 100644 --- a/tests/TypeInfer.operators.test.cpp +++ b/tests/TypeInfer.operators.test.cpp @@ -522,17 +522,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "typecheck_unary_minus_error") LUAU_REQUIRE_ERROR_COUNT(1, result); - if (FFlag::DebugLuauDeferredConstraintResolution) - { - // Under DCR, this currently functions as a failed overload resolution, and so we can't say - // anything about the result type of the unary minus. - CHECK_EQ("any", toString(requireType("a"))); - } - else - { - - CHECK_EQ("string", toString(requireType("a"))); - } + CHECK_EQ("string", toString(requireType("a"))); TypeMismatch* tm = get(result.errors[0]); REQUIRE_EQ(*tm->wantedType, *builtinTypes->booleanType); diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index 885a9781c..b5a06a746 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -482,6 +482,41 @@ TEST_CASE_FIXTURE(Fixture, "dcr_can_partially_dispatch_a_constraint") CHECK("(a, number) -> ()" == toString(requireType("prime_iter"))); } +TEST_CASE_FIXTURE(Fixture, "free_options_cannot_be_unified_together") +{ + ScopedFastFlag sff[] = { + {"LuauTransitiveSubtyping", true}, + }; + + TypeArena arena; + TypeId nilType = builtinTypes->nilType; + + std::unique_ptr scope = std::make_unique(builtinTypes->anyTypePack); + + TypeId free1 = arena.addType(FreeType{scope.get()}); + TypeId option1 = arena.addType(UnionType{{nilType, free1}}); + + TypeId free2 = arena.addType(FreeType{scope.get()}); + TypeId option2 = arena.addType(UnionType{{nilType, free2}}); + + InternalErrorReporter iceHandler; + UnifierSharedState sharedState{&iceHandler}; + Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}}; + Unifier u{NotNull{&normalizer}, NotNull{scope.get()}, Location{}, Variance::Covariant}; + + u.tryUnify(option1, option2); + + CHECK(!u.failure); + + u.log.commit(); + + ToStringOptions opts; + CHECK("a?" == toString(option1, opts)); + + // CHECK("a?" == toString(option2, opts)); // This should hold, but does not. + CHECK("b?" == toString(option2, opts)); // This should not hold. +} + TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_loop_with_zero_iterators") { ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", false}; @@ -822,7 +857,6 @@ TEST_CASE_FIXTURE(Fixture, "lookup_prop_of_intersection_containing_unions_of_tab TEST_CASE_FIXTURE(Fixture, "expected_type_should_be_a_helpful_deduction_guide_for_function_calls") { ScopedFastFlag sffs[]{ - {"LuauUnifyTwoOptions", true}, {"LuauTypeMismatchInvarianceInError", true}, }; @@ -836,22 +870,11 @@ TEST_CASE_FIXTURE(Fixture, "expected_type_should_be_a_helpful_deduction_guide_fo local x: Ref = useRef(nil) )"); - if (FFlag::DebugLuauDeferredConstraintResolution) - { - // This is actually wrong! Sort of. It's doing the wrong thing, it's actually asking whether - // `{| val: number? |} <: {| val: nil |}` - // instead of the correct way, which is - // `{| val: nil |} <: {| val: number? |}` - LUAU_REQUIRE_NO_ERRORS(result); - } - else - { - LUAU_REQUIRE_ERROR_COUNT(1, result); - - CHECK_EQ(toString(result.errors[0]), R"(Type 'Ref' could not be converted into 'Ref' -caused by: - Property 'val' is not compatible. Type 'nil' could not be converted into 'number' in an invariant context)"); - } + // This is actually wrong! Sort of. It's doing the wrong thing, it's actually asking whether + // `{| val: number? |} <: {| val: nil |}` + // instead of the correct way, which is + // `{| val: nil |} <: {| val: number? |}` + LUAU_REQUIRE_NO_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "floating_generics_should_not_be_allowed") @@ -876,4 +899,116 @@ TEST_CASE_FIXTURE(Fixture, "floating_generics_should_not_be_allowed") LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "free_options_can_be_unified_together") +{ + ScopedFastFlag sff[] = { + {"LuauTransitiveSubtyping", true}, + }; + + TypeArena arena; + TypeId nilType = builtinTypes->nilType; + + std::unique_ptr scope = std::make_unique(builtinTypes->anyTypePack); + + TypeId free1 = arena.addType(FreeType{scope.get()}); + TypeId option1 = arena.addType(UnionType{{nilType, free1}}); + + TypeId free2 = arena.addType(FreeType{scope.get()}); + TypeId option2 = arena.addType(UnionType{{nilType, free2}}); + + InternalErrorReporter iceHandler; + UnifierSharedState sharedState{&iceHandler}; + Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}}; + Unifier u{NotNull{&normalizer}, NotNull{scope.get()}, Location{}, Variance::Covariant}; + + u.tryUnify(option1, option2); + + CHECK(!u.failure); + + u.log.commit(); + + ToStringOptions opts; + CHECK("a?" == toString(option1, opts)); + CHECK("b?" == toString(option2, opts)); // should be `a?`. +} + +TEST_CASE_FIXTURE(Fixture, "unify_more_complex_unions_that_include_nil") +{ + CheckResult result = check(R"( + type Record = {prop: (string | boolean)?} + + function concatPagination(prop: (string | boolean | nil)?): Record + return {prop = prop} + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "optional_class_instances_are_invariant") +{ + ScopedFastFlag sff[] = { + {"LuauTypeMismatchInvarianceInError", true} + }; + + createSomeClasses(&frontend); + + CheckResult result = check(R"( + function foo(ref: {current: Parent?}) + end + + function bar(ref: {current: Child?}) + foo(ref) + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "luau-polyfill.Map.entries") +{ + + fileResolver.source["Module/Map"] = R"( +--!strict + +type Object = { [any]: any } +type Array = { [number]: T } +type Table = { [T]: V } +type Tuple = Array + +local Map = {} + +export type Map = { + size: number, + -- method definitions + set: (self: Map, K, V) -> Map, + get: (self: Map, K) -> V | nil, + clear: (self: Map) -> (), + delete: (self: Map, K) -> boolean, + has: (self: Map, K) -> boolean, + keys: (self: Map) -> Array, + values: (self: Map) -> Array, + entries: (self: Map) -> Array>, + ipairs: (self: Map) -> any, + [K]: V, + _map: { [K]: V }, + _array: { [number]: K }, +} + +function Map:entries() + return {} +end + +local function coerceToTable(mapLike: Map | Table): Array> + local e = mapLike:entries(); + return e +end + + )"; + + CheckResult result = frontend.check("Module/Map"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.singletons.test.cpp b/tests/TypeInfer.singletons.test.cpp index d068ae53d..f028e8e0d 100644 --- a/tests/TypeInfer.singletons.test.cpp +++ b/tests/TypeInfer.singletons.test.cpp @@ -390,8 +390,6 @@ TEST_CASE_FIXTURE(Fixture, "widen_the_supertype_if_it_is_free_and_subtype_has_si TEST_CASE_FIXTURE(Fixture, "return_type_of_f_is_not_widened") { - ScopedFastFlag sff{"LuauUnifyTwoOptions", true}; - CheckResult result = check(R"( local function foo(f, x): "hello"? -- anyone there? return if x == "hi" @@ -403,9 +401,7 @@ TEST_CASE_FIXTURE(Fixture, "return_type_of_f_is_not_widened") LUAU_REQUIRE_NO_ERRORS(result); CHECK_EQ(R"("hi")", toString(requireTypeAtPosition({3, 23}))); - CHECK_EQ(R"(((string) -> ("hello", b...), a) -> "hello"?)", toString(requireType("foo"))); - - // This is more accurate but we're not there yet: + CHECK_EQ(R"(((string) -> (a, c...), b) -> "hello"?)", toString(requireType("foo"))); // CHECK_EQ(R"(((string) -> ("hello"?, b...), a) -> "hello"?)", toString(requireType("foo"))); } diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index 82a20bc1a..694b62708 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -1518,6 +1518,8 @@ TEST_CASE_FIXTURE(Fixture, "right_table_missing_key2") TEST_CASE_FIXTURE(Fixture, "casting_unsealed_tables_with_props_into_table_with_indexer") { + ScopedFastFlag sff{"LuauAlwaysCommitInferencesOfFunctionCalls", true}; + CheckResult result = check(R"( type StringToStringMap = { [string]: string } local rt: StringToStringMap = { ["foo"] = 1 } @@ -1563,6 +1565,8 @@ TEST_CASE_FIXTURE(Fixture, "casting_tables_with_props_into_table_with_indexer2") TEST_CASE_FIXTURE(Fixture, "casting_tables_with_props_into_table_with_indexer3") { + ScopedFastFlag sff{"LuauAlwaysCommitInferencesOfFunctionCalls", true}; + CheckResult result = check(R"( local function foo(a: {[string]: number, a: string}) end foo({ a = 1 }) @@ -1574,7 +1578,7 @@ TEST_CASE_FIXTURE(Fixture, "casting_tables_with_props_into_table_with_indexer3") TypeMismatch* tm = get(result.errors[0]); REQUIRE(tm); CHECK_EQ("{| [string]: number, a: string |}", toString(tm->wantedType, o)); - CHECK_EQ("{ a: number }", toString(tm->givenType, o)); + CHECK_EQ("{ [string]: number, a: number }", toString(tm->givenType, o)); } TEST_CASE_FIXTURE(Fixture, "casting_tables_with_props_into_table_with_indexer4") @@ -2383,6 +2387,8 @@ TEST_CASE_FIXTURE(Fixture, "confusing_indexing") TEST_CASE_FIXTURE(Fixture, "pass_a_union_of_tables_to_a_function_that_requires_a_table") { + ScopedFastFlag sff{"LuauAlwaysCommitInferencesOfFunctionCalls", true}; + CheckResult result = check(R"( local a: {x: number, y: number, [any]: any} | {y: number} @@ -2396,11 +2402,16 @@ TEST_CASE_FIXTURE(Fixture, "pass_a_union_of_tables_to_a_function_that_requires_a LUAU_REQUIRE_NO_ERRORS(result); - REQUIRE_EQ("{| [any]: any, x: number, y: number |} | {| y: number |}", toString(requireType("b"))); + if (FFlag::DebugLuauDeferredConstraintResolution) + REQUIRE_EQ("{| [any]: any, x: number, y: number |} | {| y: number |}", toString(requireType("b"))); + else + REQUIRE_EQ("{- y: number -}", toString(requireType("b"))); } TEST_CASE_FIXTURE(Fixture, "pass_a_union_of_tables_to_a_function_that_requires_a_table_2") { + ScopedFastFlag sff{"LuauAlwaysCommitInferencesOfFunctionCalls", true}; + CheckResult result = check(R"( local a: {y: number} | {x: number, y: number, [any]: any} @@ -2414,7 +2425,10 @@ TEST_CASE_FIXTURE(Fixture, "pass_a_union_of_tables_to_a_function_that_requires_a LUAU_REQUIRE_NO_ERRORS(result); - REQUIRE_EQ("{| [any]: any, x: number, y: number |} | {| y: number |}", toString(requireType("b"))); + if (FFlag::DebugLuauDeferredConstraintResolution) + REQUIRE_EQ("{| [any]: any, x: number, y: number |} | {| y: number |}", toString(requireType("b"))); + else + REQUIRE_EQ("{- y: number -}", toString(requireType("b"))); } TEST_CASE_FIXTURE(Fixture, "unifying_tables_shouldnt_uaf1") @@ -3292,6 +3306,8 @@ TEST_CASE_FIXTURE(Fixture, "scalar_is_a_subtype_of_a_compatible_polymorphic_shap TEST_CASE_FIXTURE(Fixture, "scalar_is_not_a_subtype_of_a_compatible_polymorphic_shape_type") { + ScopedFastFlag sff{"LuauAlwaysCommitInferencesOfFunctionCalls", true}; + CheckResult result = check(R"( local function f(s) return s:absolutely_no_scalar_has_this_method() @@ -3308,10 +3324,12 @@ TEST_CASE_FIXTURE(Fixture, "scalar_is_not_a_subtype_of_a_compatible_polymorphic_ caused by: The former's metatable does not satisfy the requirements. Table type 'typeof(string)' not compatible with type 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' because the former is missing field 'absolutely_no_scalar_has_this_method')", toString(result.errors[0])); + CHECK_EQ(R"(Type '"bar"' could not be converted into 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' caused by: The former's metatable does not satisfy the requirements. Table type 'typeof(string)' not compatible with type 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' because the former is missing field 'absolutely_no_scalar_has_this_method')", toString(result.errors[1])); + CHECK_EQ(R"(Type '"bar" | "baz"' could not be converted into 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' caused by: Not all union options are compatible. Type '"bar"' could not be converted into 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index 829f993a4..efe7fed38 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -989,10 +989,6 @@ TEST_CASE_FIXTURE(Fixture, "cli_50041_committing_txnlog_in_apollo_client_error") function Policies:readField(options: ReadFieldOptions) local _ = self:getStoreFieldName(options) - --[[ - Type error: - TypeError { "MainModule", Location { { line = 25, col = 16 }, { line = 25, col = 20 } }, TypeMismatch { Policies, {- getStoreFieldName: (tp1) -> (a, b...) -} } } - ]] foo(self) end )"); @@ -1006,9 +1002,9 @@ TEST_CASE_FIXTURE(Fixture, "cli_50041_committing_txnlog_in_apollo_client_error") LUAU_REQUIRE_ERROR_COUNT(1, result); CHECK_EQ( - R"(Type 't1 where t1 = {+ getStoreFieldName: (t1, {| fieldName: string |} & {| from: number? |}) -> (a, b...) +}' could not be converted into 'Policies' + R"(Type 'Policies' from 'MainModule' could not be converted into 'Policies' from 'MainModule' caused by: - Property 'getStoreFieldName' is not compatible. Type 't1 where t1 = ({+ getStoreFieldName: t1 +}, {| fieldName: string |} & {| from: number? |}) -> (a, b...)' could not be converted into '(Policies, FieldSpecifier) -> string' + Property 'getStoreFieldName' is not compatible. Type '(Policies, FieldSpecifier & {| from: number? |}) -> (a, b...)' could not be converted into '(Policies, FieldSpecifier) -> string' caused by: Argument #2 type is not compatible. Type 'FieldSpecifier' could not be converted into 'FieldSpecifier & {| from: number? |}' caused by: diff --git a/tests/TypeInfer.tryUnify.test.cpp b/tests/TypeInfer.tryUnify.test.cpp index 6b451e116..7475d04bd 100644 --- a/tests/TypeInfer.tryUnify.test.cpp +++ b/tests/TypeInfer.tryUnify.test.cpp @@ -189,6 +189,8 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "uninhabited_table_sub_anything") TEST_CASE_FIXTURE(TryUnifyFixture, "members_of_failed_typepack_unification_are_unified_with_errorType") { + ScopedFastFlag sff{"LuauAlwaysCommitInferencesOfFunctionCalls", true}; + CheckResult result = check(R"( function f(arg: number) end local a @@ -198,12 +200,14 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "members_of_failed_typepack_unification_are_u LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("a", toString(requireType("a"))); + CHECK_EQ("number", toString(requireType("a"))); CHECK_EQ("*error-type*", toString(requireType("b"))); } TEST_CASE_FIXTURE(TryUnifyFixture, "result_of_failed_typepack_unification_is_constrained") { + ScopedFastFlag sff{"LuauAlwaysCommitInferencesOfFunctionCalls", true}; + CheckResult result = check(R"( function f(arg: number) return arg end local a @@ -213,7 +217,7 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "result_of_failed_typepack_unification_is_con LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("a", toString(requireType("a"))); + CHECK_EQ("number", toString(requireType("a"))); CHECK_EQ("*error-type*", toString(requireType("b"))); CHECK_EQ("number", toString(requireType("c"))); } @@ -442,7 +446,10 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "unifying_two_unions_under_dcr_does_not_creat const TypeId innerType = arena.freshType(nestedScope.get()); - ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true}; + ScopedFastFlag sffs[]{ + {"DebugLuauDeferredConstraintResolution", true}, + {"LuauAlwaysCommitInferencesOfFunctionCalls", true}, + }; state.enableScopeTests(); diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp index 570c72d06..d6ae5acc9 100644 --- a/tests/TypeInfer.unionTypes.test.cpp +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -789,128 +789,4 @@ TEST_CASE_FIXTURE(Fixture, "lookup_prop_of_intersection_containing_unions") CHECK("variables" == unknownProp->key); } -TEST_CASE_FIXTURE(Fixture, "free_options_can_be_unified_together") -{ - ScopedFastFlag sff[] = { - {"LuauTransitiveSubtyping", true}, - {"LuauUnifyTwoOptions", true} - }; - - TypeArena arena; - TypeId nilType = builtinTypes->nilType; - - std::unique_ptr scope = std::make_unique(builtinTypes->anyTypePack); - - TypeId free1 = arena.addType(FreeType{scope.get()}); - TypeId option1 = arena.addType(UnionType{{nilType, free1}}); - - TypeId free2 = arena.addType(FreeType{scope.get()}); - TypeId option2 = arena.addType(UnionType{{nilType, free2}}); - - InternalErrorReporter iceHandler; - UnifierSharedState sharedState{&iceHandler}; - Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}}; - Unifier u{NotNull{&normalizer}, NotNull{scope.get()}, Location{}, Variance::Covariant}; - - u.tryUnify(option1, option2); - - CHECK(!u.failure); - - u.log.commit(); - - ToStringOptions opts; - CHECK("a?" == toString(option1, opts)); - CHECK("a?" == toString(option2, opts)); -} - -TEST_CASE_FIXTURE(Fixture, "unify_more_complex_unions_that_include_nil") -{ - CheckResult result = check(R"( - type Record = {prop: (string | boolean)?} - - function concatPagination(prop: (string | boolean | nil)?): Record - return {prop = prop} - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "optional_class_instances_are_invariant") -{ - ScopedFastFlag sff[] = { - {"LuauUnifyTwoOptions", true}, - {"LuauTypeMismatchInvarianceInError", true} - }; - - createSomeClasses(&frontend); - - CheckResult result = check(R"( - function foo(ref: {current: Parent?}) - end - - function bar(ref: {current: Child?}) - foo(ref) - end - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - // The last line of this error is the most important part. We need to - // communicate that this is an invariant context. - std::string expectedError = - "Type '{| current: Child? |}' could not be converted into '{| current: Parent? |}'\n" - "caused by:\n" - " Property 'current' is not compatible. Type 'Child' could not be converted into 'Parent' in an invariant context" - ; - - CHECK(expectedError == toString(result.errors[0])); -} - -TEST_CASE_FIXTURE(BuiltinsFixture, "luau-polyfill.Map.entries") -{ - - fileResolver.source["Module/Map"] = R"( ---!strict - -type Object = { [any]: any } -type Array = { [number]: T } -type Table = { [T]: V } -type Tuple = Array - -local Map = {} - -export type Map = { - size: number, - -- method definitions - set: (self: Map, K, V) -> Map, - get: (self: Map, K) -> V | nil, - clear: (self: Map) -> (), - delete: (self: Map, K) -> boolean, - has: (self: Map, K) -> boolean, - keys: (self: Map) -> Array, - values: (self: Map) -> Array, - entries: (self: Map) -> Array>, - ipairs: (self: Map) -> any, - [K]: V, - _map: { [K]: V }, - _array: { [number]: K }, -} - -function Map:entries() - return {} -end - -local function coerceToTable(mapLike: Map | Table): Array> - local e = mapLike:entries(); - return e -end - - )"; - - CheckResult result = frontend.check("Module/Map"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - TEST_SUITE_END(); diff --git a/tests/conformance/native.lua b/tests/conformance/native.lua index 085085c1b..85cc06bf0 100644 --- a/tests/conformance/native.lua +++ b/tests/conformance/native.lua @@ -25,7 +25,7 @@ local function fuzzfail1(...) end end -local function fuzzFail2() +local function fuzzfail2() local _ do repeat @@ -35,6 +35,30 @@ local function fuzzFail2() end end -assert(pcall(fuzzFail2) == false) +assert(pcall(fuzzfail2) == false) + +local function fuzzfail3() + function _(...) + _({_,_,true,},{...,},_,not _) + end + _() +end + +assert(pcall(fuzzfail3) == false) + +local function fuzzfail4() + local _ = setmetatable({},setmetatable({_=_,},_)) + return _(_:_()) +end + +assert(pcall(fuzzfail4) == false) + +local function fuzzfail5() + local _ = bit32.band + _(_(_,0),_) + _(_,_) +end + +assert(pcall(fuzzfail5) == false) return('OK') diff --git a/tools/faillist.txt b/tools/faillist.txt index 5c62e3da9..e7d1f5f40 100644 --- a/tools/faillist.txt +++ b/tools/faillist.txt @@ -11,7 +11,6 @@ BuiltinTests.select_slightly_out_of_range BuiltinTests.select_way_out_of_range BuiltinTests.set_metatable_needs_arguments BuiltinTests.setmetatable_should_not_mutate_persisted_types -BuiltinTests.sort_with_bad_predicate BuiltinTests.string_format_as_method BuiltinTests.string_format_correctly_ordered_types BuiltinTests.string_format_report_all_type_errors_at_correct_positions @@ -37,12 +36,14 @@ GenericsTests.infer_generic_lib_function_function_argument GenericsTests.instantiated_function_argument_names GenericsTests.no_stack_overflow_from_quantifying GenericsTests.self_recursive_instantiated_param +IntersectionTypes.intersection_of_tables_with_top_properties IntersectionTypes.table_intersection_write_sealed_indirect IntersectionTypes.table_write_sealed_indirect ProvisionalTests.assign_table_with_refined_property_with_a_similar_type_is_illegal ProvisionalTests.do_not_ice_when_trying_to_pick_first_of_generic_type_pack ProvisionalTests.error_on_eq_metamethod_returning_a_type_other_than_boolean -ProvisionalTests.expected_type_should_be_a_helpful_deduction_guide_for_function_calls +ProvisionalTests.free_options_can_be_unified_together +ProvisionalTests.free_options_cannot_be_unified_together ProvisionalTests.greedy_inference_with_shared_self_triggers_function_with_no_returns ProvisionalTests.luau-polyfill.Array.filter ProvisionalTests.setmetatable_constrains_free_type_into_free_table @@ -150,8 +151,11 @@ TypeInferFunctions.too_few_arguments_variadic_generic2 TypeInferFunctions.too_many_arguments_error_location TypeInferFunctions.too_many_return_values_in_parentheses TypeInferFunctions.too_many_return_values_no_function +TypeInferLoops.dcr_iteration_explore_raycast_minimization TypeInferLoops.for_in_loop_error_on_factory_not_returning_the_right_amount_of_values +TypeInferLoops.loop_iter_metamethod_ok_with_inference TypeInferLoops.loop_iter_trailing_nil +TypeInferLoops.properly_infer_iteratee_is_a_free_table TypeInferLoops.unreachable_code_after_infinite_loop TypeInferModules.do_not_modify_imported_types_5 TypeInferModules.module_type_conflict @@ -164,7 +168,6 @@ TypeInferOperators.cli_38355_recursive_union TypeInferOperators.compound_assign_mismatch_metatable TypeInferOperators.disallow_string_and_types_without_metatables_from_arithmetic_binary_ops TypeInferOperators.luau-polyfill.String.slice -TypeInferOperators.luau_polyfill_is_array TypeInferOperators.operator_eq_completely_incompatible TypeInferOperators.typecheck_overloaded_multiply_that_is_an_intersection TypeInferOperators.typecheck_overloaded_multiply_that_is_an_intersection_on_rhs From 88cd3dda87f3a5be4b8a3c5c43b9ebc68f8b0e6a Mon Sep 17 00:00:00 2001 From: Vyacheslav Egorov Date: Fri, 9 Jun 2023 15:20:36 +0300 Subject: [PATCH 58/66] Sync to upstream/release/580 --- Analysis/include/Luau/InsertionOrderedMap.h | 134 ++++++++ Analysis/include/Luau/Normalize.h | 25 +- Analysis/src/ConstraintGraphBuilder.cpp | 22 +- Analysis/src/ConstraintSolver.cpp | 36 +- Analysis/src/Frontend.cpp | 4 +- Analysis/src/Normalize.cpp | 74 ++++- Analysis/src/TypeChecker2.cpp | 29 +- Analysis/src/TypeFamily.cpp | 3 +- CLI/Compile.cpp | 346 ++++++++++++++++++++ CMakeLists.txt | 5 + CodeGen/include/Luau/AddressA64.h | 16 +- CodeGen/include/Luau/IrAnalysis.h | 40 +++ CodeGen/include/Luau/IrData.h | 1 + CodeGen/src/AssemblyBuilderA64.cpp | 38 ++- CodeGen/src/AssemblyBuilderX64.cpp | 2 +- CodeGen/src/CodeAllocator.cpp | 2 - CodeGen/src/CodeGen.cpp | 2 +- CodeGen/src/CodeGenA64.cpp | 80 +++++ CodeGen/src/CodeGenUtils.cpp | 40 --- CodeGen/src/CodeGenUtils.h | 1 - CodeGen/src/CodeGenX64.cpp | 5 + CodeGen/src/EmitCommon.h | 1 + CodeGen/src/EmitCommonX64.cpp | 83 +++++ CodeGen/src/EmitCommonX64.h | 2 + CodeGen/src/EmitInstructionX64.cpp | 175 +++------- CodeGen/src/IrAnalysis.cpp | 213 ++++++++++++ CodeGen/src/IrBuilder.cpp | 1 + CodeGen/src/IrLoweringA64.cpp | 77 ++++- CodeGen/src/IrLoweringA64.h | 5 +- CodeGen/src/NativeState.cpp | 1 - CodeGen/src/NativeState.h | 1 - Makefile | 15 +- Sources.cmake | 22 +- VM/src/ldo.cpp | 4 +- VM/src/lvmexecute.cpp | 3 - tests/AssemblyBuilderA64.test.cpp | 26 ++ tests/Autocomplete.test.cpp | 2 - tests/InsertionOrderedMap.test.cpp | 140 ++++++++ tests/IrBuilder.test.cpp | 53 +++ tests/Normalize.test.cpp | 17 +- tests/TypeInfer.functions.test.cpp | 18 +- tools/faillist.txt | 2 - 42 files changed, 1488 insertions(+), 278 deletions(-) create mode 100644 Analysis/include/Luau/InsertionOrderedMap.h create mode 100644 CLI/Compile.cpp create mode 100644 tests/InsertionOrderedMap.test.cpp diff --git a/Analysis/include/Luau/InsertionOrderedMap.h b/Analysis/include/Luau/InsertionOrderedMap.h new file mode 100644 index 000000000..66d6b2ab8 --- /dev/null +++ b/Analysis/include/Luau/InsertionOrderedMap.h @@ -0,0 +1,134 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Common.h" + +#include +#include +#include +#include + +namespace Luau +{ + +template +struct InsertionOrderedMap +{ + static_assert(std::is_trivially_copyable_v, "key must be trivially copyable"); + + private: + using vec = std::vector>; + + public: + using iterator = typename vec::iterator; + using const_iterator = typename vec::const_iterator; + + void insert(K k, V v) + { + if (indices.count(k) != 0) + return; + + pairs.push_back(std::make_pair(k, std::move(v))); + indices[k] = pairs.size() - 1; + } + + void clear() + { + pairs.clear(); + indices.clear(); + } + + size_t size() const + { + LUAU_ASSERT(pairs.size() == indices.size()); + return pairs.size(); + } + + bool contains(const K& k) const + { + return indices.count(k) > 0; + } + + const V* get(const K& k) const + { + auto it = indices.find(k); + if (it == indices.end()) + return nullptr; + else + return &pairs.at(it->second).second; + } + + V* get(const K& k) + { + auto it = indices.find(k); + if (it == indices.end()) + return nullptr; + else + return &pairs.at(it->second).second; + } + + const_iterator begin() const + { + return pairs.begin(); + } + + const_iterator end() const + { + return pairs.end(); + } + + iterator begin() + { + return pairs.begin(); + } + + iterator end() + { + return pairs.end(); + } + + const_iterator find(K k) const + { + auto indicesIt = indices.find(k); + if (indicesIt == indices.end()) + return end(); + else + return begin() + indicesIt->second; + } + + iterator find(K k) + { + auto indicesIt = indices.find(k); + if (indicesIt == indices.end()) + return end(); + else + return begin() + indicesIt->second; + } + + void erase(iterator it) + { + if (it == pairs.end()) + return; + + K k = it->first; + auto indexIt = indices.find(k); + if (indexIt == indices.end()) + return; + + size_t removed = indexIt->second; + indices.erase(indexIt); + pairs.erase(it); + + for (auto& [_, index] : indices) + { + if (index > removed) + --index; + } + } + +private: + vec pairs; + std::unordered_map indices; +}; + +} diff --git a/Analysis/include/Luau/Normalize.h b/Analysis/include/Luau/Normalize.h index 7d415e92f..72be0832b 100644 --- a/Analysis/include/Luau/Normalize.h +++ b/Analysis/include/Luau/Normalize.h @@ -64,6 +64,7 @@ class TypeIds bool operator==(const TypeIds& there) const; size_t getHash() const; + bool isNever() const; }; } // namespace Luau @@ -269,12 +270,24 @@ struct NormalizedType NormalizedType& operator=(NormalizedType&&) = default; // IsType functions - - /// Returns true if the type is a subtype of function. This includes any and unknown. - bool isFunction() const; - - /// Returns true if the type is a subtype of number. This includes any and unknown. - bool isNumber() const; + /// Returns true if the type is exactly a number. Behaves like Type::isNumber() + bool isExactlyNumber() const; + + /// Returns true if the type is a subtype of string(it could be a singleton). Behaves like Type::isString() + bool isSubtypeOfString() const; + + // Helpers that improve readability of the above (they just say if the component is present) + bool hasTops() const; + bool hasBooleans() const; + bool hasClasses() const; + bool hasErrors() const; + bool hasNils() const; + bool hasNumbers() const; + bool hasStrings() const; + bool hasThreads() const; + bool hasTables() const; + bool hasFunctions() const; + bool hasTyvars() const; }; diff --git a/Analysis/src/ConstraintGraphBuilder.cpp b/Analysis/src/ConstraintGraphBuilder.cpp index c14f10e5a..821f6c260 100644 --- a/Analysis/src/ConstraintGraphBuilder.cpp +++ b/Analysis/src/ConstraintGraphBuilder.cpp @@ -16,6 +16,7 @@ #include "Luau/TypeFamily.h" #include "Luau/Simplify.h" #include "Luau/VisitType.h" +#include "Luau/InsertionOrderedMap.h" #include @@ -196,7 +197,7 @@ struct RefinementPartition bool shouldAppendNilType = false; }; -using RefinementContext = std::unordered_map; +using RefinementContext = InsertionOrderedMap; static void unionRefinements(NotNull builtinTypes, NotNull arena, const RefinementContext& lhs, const RefinementContext& rhs, RefinementContext& dest, std::vector* constraints) @@ -229,8 +230,9 @@ static void unionRefinements(NotNull builtinTypes, NotNullsecond.discriminantTypes.size() == 1 ? rhsIt->second.discriminantTypes[0] : intersect(rhsIt->second.discriminantTypes); - dest[def].discriminantTypes.push_back(simplifyUnion(builtinTypes, arena, leftDiscriminantTy, rightDiscriminantTy).result); - dest[def].shouldAppendNilType |= partition.shouldAppendNilType || rhsIt->second.shouldAppendNilType; + dest.insert(def, {}); + dest.get(def)->discriminantTypes.push_back(simplifyUnion(builtinTypes, arena, leftDiscriminantTy, rightDiscriminantTy).result); + dest.get(def)->shouldAppendNilType |= partition.shouldAppendNilType || rhsIt->second.shouldAppendNilType; } } @@ -285,11 +287,12 @@ static void computeRefinement(NotNull builtinTypes, NotNullbreadcrumb->def].discriminantTypes.push_back(discriminantTy); + uncommittedRefis.insert(proposition->breadcrumb->def, {}); + uncommittedRefis.get(proposition->breadcrumb->def)->discriminantTypes.push_back(discriminantTy); // When the top-level expression is `t[x]`, we want to refine it into `nil`, not `never`. if ((sense || !eq) && getMetadata(proposition->breadcrumb)) - uncommittedRefis[proposition->breadcrumb->def].shouldAppendNilType = true; + uncommittedRefis.get(proposition->breadcrumb->def)->shouldAppendNilType = true; for (NullableBreadcrumbId current = proposition->breadcrumb; current && current->previous; current = current->previous) { @@ -302,17 +305,20 @@ static void computeRefinement(NotNull builtinTypes, NotNullprop, Property{discriminantTy}}}; discriminantTy = arena->addType(TableType{std::move(props), std::nullopt, TypeLevel{}, scope.get(), TableState::Sealed}); - uncommittedRefis[current->previous->def].discriminantTypes.push_back(discriminantTy); + uncommittedRefis.insert(current->previous->def, {}); + uncommittedRefis.get(current->previous->def)->discriminantTypes.push_back(discriminantTy); } } // And now it's time to commit it. for (auto& [def, partition] : uncommittedRefis) { + (*refis).insert(def, {}); + for (TypeId discriminantTy : partition.discriminantTypes) - (*refis)[def].discriminantTypes.push_back(discriminantTy); + (*refis).get(def)->discriminantTypes.push_back(discriminantTy); - (*refis)[def].shouldAppendNilType |= partition.shouldAppendNilType; + (*refis).get(def)->shouldAppendNilType |= partition.shouldAppendNilType; } } } diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index f96f54b6d..c9ac8cc9d 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -785,7 +785,8 @@ bool ConstraintSolver::tryDispatch(const BinaryConstraint& c, NotNullnormalize(leftType); if (hasTypeInIntersection(leftType) && force) asMutable(leftType)->ty.emplace(anyPresent ? builtinTypes->anyType : builtinTypes->numberType); - if (normLeftTy && normLeftTy->isNumber()) + // We want to check if the left type has tops because `any` is a valid type for the lhs + if (normLeftTy && (normLeftTy->isExactlyNumber() || get(normLeftTy->tops))) { unify(leftType, rightType, constraint->scope); asMutable(resultType)->ty.emplace(anyPresent ? builtinTypes->anyType : leftType); @@ -805,9 +806,11 @@ bool ConstraintSolver::tryDispatch(const BinaryConstraint& c, NotNull(leftType) && force) asMutable(leftType)->ty.emplace(anyPresent ? builtinTypes->anyType : builtinTypes->stringType); - if (isString(leftType)) + const NormalizedType* leftNormTy = normalizer->normalize(leftType); + if (leftNormTy && leftNormTy->isSubtypeOfString()) { unify(leftType, rightType, constraint->scope); asMutable(resultType)->ty.emplace(anyPresent ? builtinTypes->anyType : leftType); @@ -823,14 +826,33 @@ bool ConstraintSolver::tryDispatch(const BinaryConstraint& c, NotNull(leftType) || - get(rightType)) + { + const NormalizedType* lt = normalizer->normalize(leftType); + const NormalizedType* rt = normalizer->normalize(rightType); + // If the lhs is any, comparisons should be valid. + if (lt && rt && (lt->isExactlyNumber() || get(lt->tops)) && rt->isExactlyNumber()) + { + asMutable(resultType)->ty.emplace(builtinTypes->booleanType); + unblock(resultType); + return true; + } + + if (lt && rt && (lt->isSubtypeOfString() || get(lt->tops)) && rt->isSubtypeOfString()) + { + asMutable(resultType)->ty.emplace(builtinTypes->booleanType); + unblock(resultType); + return true; + } + + + if (get(leftType) || get(rightType)) { asMutable(resultType)->ty.emplace(builtinTypes->booleanType); unblock(resultType); @@ -838,6 +860,8 @@ bool ConstraintSolver::tryDispatch(const BinaryConstraint& c, NotNull builtinTypes, TypePackId) return builtinTypes->errorRecoveryTypePack(); } -template +template bool ConstraintSolver::tryUnify(NotNull constraint, TID subTy, TID superTy) { Unifier u{normalizer, constraint->scope, constraint->location, Covariant}; diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 062050fa6..409e2eb3d 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -35,7 +35,6 @@ LUAU_FASTINTVARIABLE(LuauAutocompleteCheckTimeoutMs, 100) LUAU_FASTFLAGVARIABLE(DebugLuauDeferredConstraintResolution, false) LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJson, false) LUAU_FASTFLAGVARIABLE(DebugLuauReadWriteProperties, false) -LUAU_FASTFLAGVARIABLE(LuauTypeCheckerUseCorrectScope, false) namespace Luau { @@ -1196,8 +1195,7 @@ ModulePtr Frontend::check(const SourceModule& sourceModule, Mode mode, std::vect } else { - TypeChecker typeChecker(FFlag::LuauTypeCheckerUseCorrectScope ? (forAutocomplete ? globalsForAutocomplete.globalScope : globals.globalScope) - : globals.globalScope, + TypeChecker typeChecker(forAutocomplete ? globalsForAutocomplete.globalScope : globals.globalScope, forAutocomplete ? &moduleResolverForAutocomplete : &moduleResolver, builtinTypes, &iceHandler); if (prepareModuleScope) diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp index 6a78bc667..3af7e8574 100644 --- a/Analysis/src/Normalize.cpp +++ b/Analysis/src/Normalize.cpp @@ -108,6 +108,14 @@ size_t TypeIds::getHash() const return hash; } +bool TypeIds::isNever() const +{ + return std::all_of(begin(), end(), [&](TypeId i) { + // If each typeid is never, then I guess typeid's is also never? + return get(i) != nullptr; + }); +} + bool TypeIds::operator==(const TypeIds& there) const { return hash == there.hash && types == there.types; @@ -228,14 +236,72 @@ NormalizedType::NormalizedType(NotNull builtinTypes) { } -bool NormalizedType::isFunction() const +bool NormalizedType::isExactlyNumber() const +{ + return hasNumbers() && !hasTops() && !hasBooleans() && !hasClasses() && !hasErrors() && !hasNils() && !hasStrings() && !hasThreads() && + !hasTables() && !hasFunctions() && !hasTyvars(); +} + +bool NormalizedType::isSubtypeOfString() const +{ + return hasStrings() && !hasTops() && !hasBooleans() && !hasClasses() && !hasErrors() && !hasNils() && !hasNumbers() && !hasThreads() && + !hasTables() && !hasFunctions() && !hasTyvars(); +} + +bool NormalizedType::hasTops() const +{ + return !get(tops); +} + + +bool NormalizedType::hasBooleans() const +{ + return !get(booleans); +} + +bool NormalizedType::hasClasses() const +{ + return !classes.isNever(); +} + +bool NormalizedType::hasErrors() const +{ + return !get(errors); +} + +bool NormalizedType::hasNils() const +{ + return !get(nils); +} + +bool NormalizedType::hasNumbers() const +{ + return !get(numbers); +} + +bool NormalizedType::hasStrings() const +{ + return !strings.isNever(); +} + +bool NormalizedType::hasThreads() const +{ + return !get(threads); +} + +bool NormalizedType::hasTables() const +{ + return !tables.isNever(); +} + +bool NormalizedType::hasFunctions() const { - return !get(tops) || !functions.parts.empty(); + return !functions.isNever(); } -bool NormalizedType::isNumber() const +bool NormalizedType::hasTyvars() const { - return !get(tops) || !get(numbers); + return !tyvars.empty(); } static bool isShallowInhabited(const NormalizedType& norm) diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index c1146b5c8..0a9e9b648 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -1067,9 +1067,7 @@ struct TypeChecker2 std::vector argLocs; argLocs.reserve(call->args.size + 1); - TypeId* maybeOriginalCallTy = module->astOriginalCallTypes.find(call); - TypeId* maybeSelectedOverload = module->astOverloadResolvedTypes.find(call); - + auto maybeOriginalCallTy = module->astOriginalCallTypes.find(call); if (!maybeOriginalCallTy) return; @@ -1093,8 +1091,19 @@ struct TypeChecker2 return; } } - else if (get(originalCallTy) || get(originalCallTy)) + else if (get(originalCallTy)) + { + // ok. + } + else if (get(originalCallTy)) { + auto norm = normalizer.normalize(originalCallTy); + if (!norm) + return reportError(CodeTooComplex{}, call->location); + + // NormalizedType::hasFunction returns true if its' tops component is `unknown`, but for soundness we want the reverse. + if (get(norm->tops) || !norm->hasFunctions()) + return reportError(CannotCallNonFunction{originalCallTy}, call->func->location); } else if (auto utv = get(originalCallTy)) { @@ -1164,7 +1173,7 @@ struct TypeChecker2 TypePackId expectedArgTypes = testArena.addTypePack(args); - if (maybeSelectedOverload) + if (auto maybeSelectedOverload = module->astOverloadResolvedTypes.find(call)) { // This overload might not work still: the constraint solver will // pass the type checker an instantiated function type that matches @@ -1414,7 +1423,7 @@ struct TypeChecker2 { // Nothing } - else if (!normalizedFnTy->isFunction()) + else if (!normalizedFnTy->hasFunctions()) { ice->ice("Internal error: Lambda has non-function type " + toString(inferredFnTy), fn->location); } @@ -1793,12 +1802,14 @@ struct TypeChecker2 case AstExprBinary::Op::CompareGt: case AstExprBinary::Op::CompareLe: case AstExprBinary::Op::CompareLt: - if (isNumber(leftType)) + { + const NormalizedType* leftTyNorm = normalizer.normalize(leftType); + if (leftTyNorm && leftTyNorm->isExactlyNumber()) { reportErrors(tryUnify(scope, expr->right->location, rightType, builtinTypes->numberType)); return builtinTypes->numberType; } - else if (isString(leftType)) + else if (leftTyNorm && leftTyNorm->isSubtypeOfString()) { reportErrors(tryUnify(scope, expr->right->location, rightType, builtinTypes->stringType)); return builtinTypes->stringType; @@ -1810,6 +1821,8 @@ struct TypeChecker2 expr->location); return builtinTypes->errorRecoveryType(); } + } + case AstExprBinary::Op::And: case AstExprBinary::Op::Or: case AstExprBinary::Op::CompareEq: diff --git a/Analysis/src/TypeFamily.cpp b/Analysis/src/TypeFamily.cpp index e68187fd1..4adf0f8a7 100644 --- a/Analysis/src/TypeFamily.cpp +++ b/Analysis/src/TypeFamily.cpp @@ -346,7 +346,6 @@ TypeFamilyReductionResult addFamilyFn(std::vector typeParams, st TypeId rhsTy = log->follow(typeParams.at(1)); const NormalizedType* normLhsTy = normalizer->normalize(lhsTy); const NormalizedType* normRhsTy = normalizer->normalize(rhsTy); - if (!normLhsTy || !normRhsTy) { return {std::nullopt, false, {}, {}}; @@ -355,7 +354,7 @@ TypeFamilyReductionResult addFamilyFn(std::vector typeParams, st { return {builtins->anyType, false, {}, {}}; } - else if (normLhsTy->isNumber() && normRhsTy->isNumber()) + else if ((normLhsTy->hasNumbers() || normLhsTy->hasTops()) && (normRhsTy->hasNumbers() || normRhsTy->hasTops())) { return {builtins->numberType, false, {}, {}}; } diff --git a/CLI/Compile.cpp b/CLI/Compile.cpp new file mode 100644 index 000000000..293809d0e --- /dev/null +++ b/CLI/Compile.cpp @@ -0,0 +1,346 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "lua.h" +#include "lualib.h" + +#include "Luau/CodeGen.h" +#include "Luau/Compiler.h" +#include "Luau/BytecodeBuilder.h" +#include "Luau/Parser.h" +#include "Luau/TimeTrace.h" + +#include "FileUtils.h" +#include "Flags.h" + +#include + +#ifdef _WIN32 +#include +#include +#endif + +LUAU_FASTFLAG(DebugLuauTimeTracing) + +enum class CompileFormat +{ + Text, + Binary, + Remarks, + Codegen, // Prints annotated native code including IR and assembly + CodegenAsm, // Prints annotated native code assembly + CodegenIr, // Prints annotated native code IR + CodegenVerbose, // Prints annotated native code including IR, assembly and outlined code + CodegenNull, + Null +}; + +struct GlobalOptions +{ + int optimizationLevel = 1; + int debugLevel = 1; +} globalOptions; + +static Luau::CompileOptions copts() +{ + Luau::CompileOptions result = {}; + result.optimizationLevel = globalOptions.optimizationLevel; + result.debugLevel = globalOptions.debugLevel; + + return result; +} + +static std::optional getCompileFormat(const char* name) +{ + if (strcmp(name, "text") == 0) + return CompileFormat::Text; + else if (strcmp(name, "binary") == 0) + return CompileFormat::Binary; + else if (strcmp(name, "text") == 0) + return CompileFormat::Text; + else if (strcmp(name, "remarks") == 0) + return CompileFormat::Remarks; + else if (strcmp(name, "codegen") == 0) + return CompileFormat::Codegen; + else if (strcmp(name, "codegenasm") == 0) + return CompileFormat::CodegenAsm; + else if (strcmp(name, "codegenir") == 0) + return CompileFormat::CodegenIr; + else if (strcmp(name, "codegenverbose") == 0) + return CompileFormat::CodegenVerbose; + else if (strcmp(name, "codegennull") == 0) + return CompileFormat::CodegenNull; + else if (strcmp(name, "null") == 0) + return CompileFormat::Null; + else + return std::nullopt; +} + +static void report(const char* name, const Luau::Location& location, const char* type, const char* message) +{ + fprintf(stderr, "%s(%d,%d): %s: %s\n", name, location.begin.line + 1, location.begin.column + 1, type, message); +} + +static void reportError(const char* name, const Luau::ParseError& error) +{ + report(name, error.getLocation(), "SyntaxError", error.what()); +} + +static void reportError(const char* name, const Luau::CompileError& error) +{ + report(name, error.getLocation(), "CompileError", error.what()); +} + +static std::string getCodegenAssembly(const char* name, const std::string& bytecode, Luau::CodeGen::AssemblyOptions options) +{ + std::unique_ptr globalState(luaL_newstate(), lua_close); + lua_State* L = globalState.get(); + + if (luau_load(L, name, bytecode.data(), bytecode.size(), 0) == 0) + return Luau::CodeGen::getAssembly(L, -1, options); + + fprintf(stderr, "Error loading bytecode %s\n", name); + return ""; +} + +static void annotateInstruction(void* context, std::string& text, int fid, int instpos) +{ + Luau::BytecodeBuilder& bcb = *(Luau::BytecodeBuilder*)context; + + bcb.annotateInstruction(text, fid, instpos); +} + +struct CompileStats +{ + size_t lines; + size_t bytecode; + size_t codegen; + + double readTime; + double miscTime; + double parseTime; + double compileTime; + double codegenTime; +}; + +static double recordDeltaTime(double& timer) +{ + double now = Luau::TimeTrace::getClock(); + double delta = now - timer; + timer = now; + return delta; +} + +static bool compileFile(const char* name, CompileFormat format, CompileStats& stats) +{ + double currts = Luau::TimeTrace::getClock(); + + std::optional source = readFile(name); + if (!source) + { + fprintf(stderr, "Error opening %s\n", name); + return false; + } + + stats.readTime += recordDeltaTime(currts); + + // NOTE: Normally, you should use Luau::compile or luau_compile (see lua_require as an example) + // This function is much more complicated because it supports many output human-readable formats through internal interfaces + + try + { + Luau::BytecodeBuilder bcb; + + Luau::CodeGen::AssemblyOptions options; + options.outputBinary = format == CompileFormat::CodegenNull; + + if (!options.outputBinary) + { + options.includeAssembly = format != CompileFormat::CodegenIr; + options.includeIr = format != CompileFormat::CodegenAsm; + options.includeOutlinedCode = format == CompileFormat::CodegenVerbose; + } + + options.annotator = annotateInstruction; + options.annotatorContext = &bcb; + + if (format == CompileFormat::Text) + { + bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Source | Luau::BytecodeBuilder::Dump_Locals | + Luau::BytecodeBuilder::Dump_Remarks); + bcb.setDumpSource(*source); + } + else if (format == CompileFormat::Remarks) + { + bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Source | Luau::BytecodeBuilder::Dump_Remarks); + bcb.setDumpSource(*source); + } + else if (format == CompileFormat::Codegen || format == CompileFormat::CodegenAsm || format == CompileFormat::CodegenIr || + format == CompileFormat::CodegenVerbose) + { + bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Source | Luau::BytecodeBuilder::Dump_Locals | + Luau::BytecodeBuilder::Dump_Remarks); + bcb.setDumpSource(*source); + } + + stats.miscTime += recordDeltaTime(currts); + + Luau::Allocator allocator; + Luau::AstNameTable names(allocator); + Luau::ParseResult result = Luau::Parser::parse(source->c_str(), source->size(), names, allocator); + + if (!result.errors.empty()) + throw Luau::ParseErrors(result.errors); + + stats.lines += result.lines; + stats.parseTime += recordDeltaTime(currts); + + Luau::compileOrThrow(bcb, result, names, copts()); + stats.bytecode += bcb.getBytecode().size(); + stats.compileTime += recordDeltaTime(currts); + + switch (format) + { + case CompileFormat::Text: + printf("%s", bcb.dumpEverything().c_str()); + break; + case CompileFormat::Remarks: + printf("%s", bcb.dumpSourceRemarks().c_str()); + break; + case CompileFormat::Binary: + fwrite(bcb.getBytecode().data(), 1, bcb.getBytecode().size(), stdout); + break; + case CompileFormat::Codegen: + case CompileFormat::CodegenAsm: + case CompileFormat::CodegenIr: + case CompileFormat::CodegenVerbose: + printf("%s", getCodegenAssembly(name, bcb.getBytecode(), options).c_str()); + break; + case CompileFormat::CodegenNull: + stats.codegen += getCodegenAssembly(name, bcb.getBytecode(), options).size(); + stats.codegenTime += recordDeltaTime(currts); + break; + case CompileFormat::Null: + break; + } + + return true; + } + catch (Luau::ParseErrors& e) + { + for (auto& error : e.getErrors()) + reportError(name, error); + return false; + } + catch (Luau::CompileError& e) + { + reportError(name, e); + return false; + } +} + +static void displayHelp(const char* argv0) +{ + printf("Usage: %s [--mode] [options] [file list]\n", argv0); + printf("\n"); + printf("Available modes:\n"); + printf(" binary, text, remarks, codegen\n"); + printf("\n"); + printf("Available options:\n"); + printf(" -h, --help: Display this usage message.\n"); + printf(" -O: compile with optimization level n (default 1, n should be between 0 and 2).\n"); + printf(" -g: compile with debug level n (default 1, n should be between 0 and 2).\n"); + printf(" --timetrace: record compiler time tracing information into trace.json\n"); +} + +static int assertionHandler(const char* expr, const char* file, int line, const char* function) +{ + printf("%s(%d): ASSERTION FAILED: %s\n", file, line, expr); + return 1; +} + +int main(int argc, char** argv) +{ + Luau::assertHandler() = assertionHandler; + + setLuauFlagsDefault(); + + CompileFormat compileFormat = CompileFormat::Text; + + for (int i = 1; i < argc; i++) + { + if (strcmp(argv[i], "-h") == 0 || strcmp(argv[i], "--help") == 0) + { + displayHelp(argv[0]); + return 0; + } + else if (strncmp(argv[i], "-O", 2) == 0) + { + int level = atoi(argv[i] + 2); + if (level < 0 || level > 2) + { + fprintf(stderr, "Error: Optimization level must be between 0 and 2 inclusive.\n"); + return 1; + } + globalOptions.optimizationLevel = level; + } + else if (strncmp(argv[i], "-g", 2) == 0) + { + int level = atoi(argv[i] + 2); + if (level < 0 || level > 2) + { + fprintf(stderr, "Error: Debug level must be between 0 and 2 inclusive.\n"); + return 1; + } + globalOptions.debugLevel = level; + } + else if (strcmp(argv[i], "--timetrace") == 0) + { + FFlag::DebugLuauTimeTracing.value = true; + } + else if (strncmp(argv[i], "--fflags=", 9) == 0) + { + setLuauFlags(argv[i] + 9); + } + else if (argv[i][0] == '-' && argv[i][1] == '-' && getCompileFormat(argv[i] + 2)) + { + compileFormat = *getCompileFormat(argv[i] + 2); + } + else if (argv[i][0] == '-') + { + fprintf(stderr, "Error: Unrecognized option '%s'.\n\n", argv[i]); + displayHelp(argv[0]); + return 1; + } + } + +#if !defined(LUAU_ENABLE_TIME_TRACE) + if (FFlag::DebugLuauTimeTracing) + { + fprintf(stderr, "To run with --timetrace, Luau has to be built with LUAU_ENABLE_TIME_TRACE enabled\n"); + return 1; + } +#endif + + const std::vector files = getSourceFiles(argc, argv); + +#ifdef _WIN32 + if (compileFormat == CompileFormat::Binary) + _setmode(_fileno(stdout), _O_BINARY); +#endif + + CompileStats stats = {}; + int failed = 0; + + for (const std::string& path : files) + failed += !compileFile(path.c_str(), compileFormat, stats); + + if (compileFormat == CompileFormat::Null) + printf("Compiled %d KLOC into %d KB bytecode (read %.2fs, parse %.2fs, compile %.2fs)\n", int(stats.lines / 1000), int(stats.bytecode / 1024), + stats.readTime, stats.parseTime, stats.compileTime); + else if (compileFormat == CompileFormat::CodegenNull) + printf("Compiled %d KLOC into %d KB bytecode => %d KB native code (%.2fx) (read %.2fs, parse %.2fs, compile %.2fs, codegen %.2fs)\n", + int(stats.lines / 1000), int(stats.bytecode / 1024), int(stats.codegen / 1024), + stats.bytecode == 0 ? 0.0 : double(stats.codegen) / double(stats.bytecode), stats.readTime, stats.parseTime, stats.compileTime, + stats.codegenTime); + + return failed ? 1 : 0; +} diff --git a/CMakeLists.txt b/CMakeLists.txt index b6e8b5913..bc66a83dc 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -36,12 +36,14 @@ if(LUAU_BUILD_CLI) add_executable(Luau.Analyze.CLI) add_executable(Luau.Ast.CLI) add_executable(Luau.Reduce.CLI) + add_executable(Luau.Compile.CLI) # This also adds target `name` on Linux/macOS and `name.exe` on Windows set_target_properties(Luau.Repl.CLI PROPERTIES OUTPUT_NAME luau) set_target_properties(Luau.Analyze.CLI PROPERTIES OUTPUT_NAME luau-analyze) set_target_properties(Luau.Ast.CLI PROPERTIES OUTPUT_NAME luau-ast) set_target_properties(Luau.Reduce.CLI PROPERTIES OUTPUT_NAME luau-reduce) + set_target_properties(Luau.Compile.CLI PROPERTIES OUTPUT_NAME luau-compile) endif() if(LUAU_BUILD_TESTS) @@ -186,6 +188,7 @@ if(LUAU_BUILD_CLI) target_compile_options(Luau.Reduce.CLI PRIVATE ${LUAU_OPTIONS}) target_compile_options(Luau.Analyze.CLI PRIVATE ${LUAU_OPTIONS}) target_compile_options(Luau.Ast.CLI PRIVATE ${LUAU_OPTIONS}) + target_compile_options(Luau.Compile.CLI PRIVATE ${LUAU_OPTIONS}) target_include_directories(Luau.Repl.CLI PRIVATE extern extern/isocline/include) @@ -206,6 +209,8 @@ if(LUAU_BUILD_CLI) target_compile_features(Luau.Reduce.CLI PRIVATE cxx_std_17) target_include_directories(Luau.Reduce.CLI PUBLIC Reduce/include) target_link_libraries(Luau.Reduce.CLI PRIVATE Luau.Common Luau.Ast Luau.Analysis) + + target_link_libraries(Luau.Compile.CLI PRIVATE Luau.Compiler Luau.VM Luau.CodeGen) endif() if(LUAU_BUILD_TESTS) diff --git a/CodeGen/include/Luau/AddressA64.h b/CodeGen/include/Luau/AddressA64.h index acb64e390..097cc1360 100644 --- a/CodeGen/include/Luau/AddressA64.h +++ b/CodeGen/include/Luau/AddressA64.h @@ -14,13 +14,10 @@ namespace A64 enum class AddressKindA64 : uint8_t { - imm, // reg + imm - reg, // reg + reg - - // TODO: - // reg + reg << shift - // reg + sext(reg) << shift - // reg + uext(reg) << shift + reg, // reg + reg + imm, // reg + imm + pre, // reg + imm, reg += imm + post, // reg, reg += imm }; struct AddressA64 @@ -29,13 +26,14 @@ struct AddressA64 // For example, ldr x0, [reg+imm] is limited to 8 KB offsets assuming imm is divisible by 8, but loading into w0 reduces the range to 4 KB static constexpr size_t kMaxOffset = 1023; - constexpr AddressA64(RegisterA64 base, int off = 0) - : kind(AddressKindA64::imm) + constexpr AddressA64(RegisterA64 base, int off = 0, AddressKindA64 kind = AddressKindA64::imm) + : kind(kind) , base(base) , offset(xzr) , data(off) { LUAU_ASSERT(base.kind == KindA64::x || base == sp); + LUAU_ASSERT(kind != AddressKindA64::reg); } constexpr AddressA64(RegisterA64 base, RegisterA64 offset) diff --git a/CodeGen/include/Luau/IrAnalysis.h b/CodeGen/include/Luau/IrAnalysis.h index 5418009a8..ca1eba622 100644 --- a/CodeGen/include/Luau/IrAnalysis.h +++ b/CodeGen/include/Luau/IrAnalysis.h @@ -1,6 +1,8 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include "Luau/Common.h" + #include #include #include @@ -37,6 +39,16 @@ struct RegisterSet void requireVariadicSequence(RegisterSet& sourceRs, const RegisterSet& defRs, uint8_t varargStart); +struct BlockOrdering +{ + uint32_t depth = 0; + + uint32_t preOrder = ~0u; + uint32_t postOrder = ~0u; + + bool visited = false; +}; + struct CfgInfo { std::vector predecessors; @@ -45,6 +57,15 @@ struct CfgInfo std::vector successors; std::vector successorsOffsets; + // Immediate dominators (unique parent in the dominator tree) + std::vector idoms; + + // Children in the dominator tree + std::vector domChildren; + std::vector domChildrenOffsets; + + std::vector domOrdering; + // VM registers that are live when the block is entered // Additionally, an active variadic sequence can exist at the entry of the block std::vector in; @@ -64,6 +85,18 @@ struct CfgInfo RegisterSet captured; }; +// A quick refresher on dominance and dominator trees: +// * If A is a dominator of B (A dom B), you can never execute B without executing A first +// * A is a strict dominator of B (A sdom B) is similar to previous one but A != B +// * Immediate dominator node N (idom N) is a unique node T so that T sdom N, +// but T does not strictly dominate any other node that dominates N. +// * Dominance frontier is a set of nodes where dominance of a node X ends. +// In practice this is where values established by node X might no longer hold because of join edges from other nodes coming in. +// This is also where PHI instructions in SSA are placed. +void computeCfgImmediateDominators(IrFunction& function); +void computeCfgDominanceTreeChildren(IrFunction& function); + +// Function used to update all CFG data void computeCfgInfo(IrFunction& function); struct BlockIteratorWrapper @@ -90,10 +123,17 @@ struct BlockIteratorWrapper { return itEnd; } + + uint32_t operator[](size_t pos) const + { + LUAU_ASSERT(pos < size_t(itEnd - itBegin)); + return itBegin[pos]; + } }; BlockIteratorWrapper predecessors(const CfgInfo& cfg, uint32_t blockIdx); BlockIteratorWrapper successors(const CfgInfo& cfg, uint32_t blockIdx); +BlockIteratorWrapper domChildren(const CfgInfo& cfg, uint32_t blockIdx); } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/include/Luau/IrData.h b/CodeGen/include/Luau/IrData.h index 0e17cba9e..4a3fa4242 100644 --- a/CodeGen/include/Luau/IrData.h +++ b/CodeGen/include/Luau/IrData.h @@ -823,6 +823,7 @@ struct IrFunction uint32_t validRestoreOpBlockIdx = 0; Proto* proto = nullptr; + bool variadic = false; CfgInfo cfg; diff --git a/CodeGen/src/AssemblyBuilderA64.cpp b/CodeGen/src/AssemblyBuilderA64.cpp index 99a68481e..c62d797a9 100644 --- a/CodeGen/src/AssemblyBuilderA64.cpp +++ b/CodeGen/src/AssemblyBuilderA64.cpp @@ -876,6 +876,9 @@ void AssemblyBuilderA64::placeA(const char* name, RegisterA64 dst, AddressA64 sr switch (src.kind) { + case AddressKindA64::reg: + place(dst.index | (src.base.index << 5) | (0b011'0'10 << 10) | (src.offset.index << 16) | (1 << 21) | (opsize << 22)); + break; case AddressKindA64::imm: if (unsigned(src.data >> sizelog) < 1024 && (src.data & ((1 << sizelog) - 1)) == 0) place(dst.index | (src.base.index << 5) | ((src.data >> sizelog) << 10) | (opsize << 22) | (1 << 24)); @@ -884,8 +887,13 @@ void AssemblyBuilderA64::placeA(const char* name, RegisterA64 dst, AddressA64 sr else LUAU_ASSERT(!"Unable to encode large immediate offset"); break; - case AddressKindA64::reg: - place(dst.index | (src.base.index << 5) | (0b011'0'10 << 10) | (src.offset.index << 16) | (1 << 21) | (opsize << 22)); + case AddressKindA64::pre: + LUAU_ASSERT(src.data >= -256 && src.data <= 255); + place(dst.index | (src.base.index << 5) | (0b11 << 10) | ((src.data & ((1 << 9) - 1)) << 12) | (opsize << 22)); + break; + case AddressKindA64::post: + LUAU_ASSERT(src.data >= -256 && src.data <= 255); + place(dst.index | (src.base.index << 5) | (0b01 << 10) | ((src.data & ((1 << 9) - 1)) << 12) | (opsize << 22)); break; } @@ -1312,23 +1320,37 @@ void AssemblyBuilderA64::log(RegisterA64 reg) void AssemblyBuilderA64::log(AddressA64 addr) { - text.append("["); switch (addr.kind) { + case AddressKindA64::reg: + text.append("["); + log(addr.base); + text.append(","); + log(addr.offset); + text.append("]"); + break; case AddressKindA64::imm: + text.append("["); log(addr.base); if (addr.data != 0) logAppend(",#%d", addr.data); + text.append("]"); break; - case AddressKindA64::reg: + case AddressKindA64::pre: + text.append("["); log(addr.base); - text.append(","); - log(addr.offset); if (addr.data != 0) - logAppend(" LSL #%d", addr.data); + logAppend(",#%d", addr.data); + text.append("]!"); + break; + case AddressKindA64::post: + text.append("["); + log(addr.base); + text.append("]!"); + if (addr.data != 0) + logAppend(",#%d", addr.data); break; } - text.append("]"); } } // namespace A64 diff --git a/CodeGen/src/AssemblyBuilderX64.cpp b/CodeGen/src/AssemblyBuilderX64.cpp index 426a0259d..c7644a86c 100644 --- a/CodeGen/src/AssemblyBuilderX64.cpp +++ b/CodeGen/src/AssemblyBuilderX64.cpp @@ -1415,7 +1415,7 @@ void AssemblyBuilderX64::commit() { LUAU_ASSERT(codePos <= codeEnd); - if (unsigned(codeEnd - codePos) < kMaxInstructionLength) + if (codeEnd - codePos < kMaxInstructionLength) extend(); } diff --git a/CodeGen/src/CodeAllocator.cpp b/CodeGen/src/CodeAllocator.cpp index 880a32446..09e1bb712 100644 --- a/CodeGen/src/CodeAllocator.cpp +++ b/CodeGen/src/CodeAllocator.cpp @@ -56,10 +56,8 @@ static void makePagesExecutable(uint8_t* mem, size_t size) static void flushInstructionCache(uint8_t* mem, size_t size) { -#if WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_APP | WINAPI_PARTITION_SYSTEM) if (FlushInstructionCache(GetCurrentProcess(), mem, size) == 0) LUAU_ASSERT(!"Failed to flush instruction cache"); -#endif } #else static uint8_t* allocatePages(size_t size) diff --git a/CodeGen/src/CodeGen.cpp b/CodeGen/src/CodeGen.cpp index 4ee8e4440..89399cbc9 100644 --- a/CodeGen/src/CodeGen.cpp +++ b/CodeGen/src/CodeGen.cpp @@ -268,7 +268,7 @@ static bool lowerImpl(AssemblyBuilder& build, IrLowering& lowering, IrFunction& [[maybe_unused]] static bool lowerIr( A64::AssemblyBuilderA64& build, IrBuilder& ir, NativeState& data, ModuleHelpers& helpers, Proto* proto, AssemblyOptions options) { - A64::IrLoweringA64 lowering(build, helpers, data, proto, ir.function); + A64::IrLoweringA64 lowering(build, helpers, data, ir.function); return lowerImpl(build, lowering, ir.function, proto->bytecodeid, options); } diff --git a/CodeGen/src/CodeGenA64.cpp b/CodeGen/src/CodeGenA64.cpp index c5042fc32..355e29ca6 100644 --- a/CodeGen/src/CodeGenA64.cpp +++ b/CodeGen/src/CodeGenA64.cpp @@ -117,6 +117,81 @@ static void emitReentry(AssemblyBuilderA64& build, ModuleHelpers& helpers) build.br(x4); } +void emitReturn(AssemblyBuilderA64& build, ModuleHelpers& helpers) +{ + // x1 = res + // w2 = number of written values + + // x0 = ci + build.ldr(x0, mem(rState, offsetof(lua_State, ci))); + // w3 = ci->nresults + build.ldr(w3, mem(x0, offsetof(CallInfo, nresults))); + + Label skipResultCopy; + + // Fill the rest of the expected results (nresults - written) with 'nil' + build.cmp(w2, w3); + build.b(ConditionA64::GreaterEqual, skipResultCopy); + + // TODO: cmp above could compute this and flags using subs + build.sub(w2, w3, w2); // counter = nresults - written + build.mov(w4, LUA_TNIL); + + Label repeatNilLoop = build.setLabel(); + build.str(w4, mem(x1, offsetof(TValue, tt))); + build.add(x1, x1, sizeof(TValue)); + build.sub(w2, w2, 1); + build.cbnz(w2, repeatNilLoop); + + build.setLabel(skipResultCopy); + + // x2 = cip = ci - 1 + build.sub(x2, x0, sizeof(CallInfo)); + + // res = cip->top when nresults >= 0 + Label skipFixedRetTop; + build.tbnz(w3, 31, skipFixedRetTop); + build.ldr(x1, mem(x2, offsetof(CallInfo, top))); // res = cip->top + build.setLabel(skipFixedRetTop); + + // Update VM state (ci, base, top) + build.str(x2, mem(rState, offsetof(lua_State, ci))); // L->ci = cip + build.ldr(rBase, mem(x2, offsetof(CallInfo, base))); // sync base = L->base while we have a chance + build.str(rBase, mem(rState, offsetof(lua_State, base))); // L->base = cip->base + + build.str(x1, mem(rState, offsetof(lua_State, top))); // L->top = res + + // Unlikely, but this might be the last return from VM + build.ldr(w4, mem(x0, offsetof(CallInfo, flags))); + build.tbnz(w4, countrz(LUA_CALLINFO_RETURN), helpers.exitNoContinueVm); + + // Continue in interpreter if function has no native data + build.ldr(w4, mem(x2, offsetof(CallInfo, flags))); + build.tbz(w4, countrz(LUA_CALLINFO_NATIVE), helpers.exitContinueVm); + + // Need to update state of the current function before we jump away + build.ldr(rClosure, mem(x2, offsetof(CallInfo, func))); + build.ldr(rClosure, mem(rClosure, offsetof(TValue, value.gc))); + + build.ldr(x1, mem(rClosure, offsetof(Closure, l.p))); // cl->l.p aka proto + + LUAU_ASSERT(offsetof(Proto, code) == offsetof(Proto, k) + 8); + build.ldp(rConstants, rCode, mem(x1, offsetof(Proto, k))); // proto->k, proto->code + + // Get instruction index from instruction pointer + // To get instruction index from instruction pointer, we need to divide byte offset by 4 + // But we will actually need to scale instruction index by 4 back to byte offset later so it cancels out + build.ldr(x2, mem(x2, offsetof(CallInfo, savedpc))); // cip->savedpc + build.sub(x2, x2, rCode); + + // Get new instruction location and jump to it + LUAU_ASSERT(offsetof(Proto, exectarget) == offsetof(Proto, execdata) + 8); + build.ldp(x3, x4, mem(x1, offsetof(Proto, execdata))); + build.ldr(w2, mem(x3, x2)); + build.add(x4, x4, x2); + build.br(x4); +} + static EntryLocations buildEntryFunction(AssemblyBuilderA64& build, UnwindBuilder& unwind) { EntryLocations locations; @@ -230,6 +305,11 @@ void assembleHelpers(AssemblyBuilderA64& build, ModuleHelpers& helpers) build.logAppend("; interrupt\n"); helpers.interrupt = build.setLabel(); emitInterrupt(build); + + if (build.logText) + build.logAppend("; return\n"); + helpers.return_ = build.setLabel(); + emitReturn(build, helpers); } } // namespace A64 diff --git a/CodeGen/src/CodeGenUtils.cpp b/CodeGen/src/CodeGenUtils.cpp index a7131e113..20269cfd7 100644 --- a/CodeGen/src/CodeGenUtils.cpp +++ b/CodeGen/src/CodeGenUtils.cpp @@ -17,8 +17,6 @@ #include -LUAU_FASTFLAG(LuauUniformTopHandling) - // All external function calls that can cause stack realloc or Lua calls have to be wrapped in VM_PROTECT // This makes sure that we save the pc (in case the Lua call needs to generate a backtrace) before the call, // and restores the stack pointer after in case stack gets reallocated @@ -306,44 +304,6 @@ Closure* callFallback(lua_State* L, StkId ra, StkId argtop, int nresults) } } -// Extracted as-is from lvmexecute.cpp with the exception of control flow (reentry) and removed interrupts -Closure* returnFallback(lua_State* L, StkId ra, StkId valend) -{ - // ci is our callinfo, cip is our parent - CallInfo* ci = L->ci; - CallInfo* cip = ci - 1; - - StkId res = ci->func; // note: we assume CALL always puts func+args and expects results to start at func - StkId vali = ra; - - int nresults = ci->nresults; - - // copy return values into parent stack (but only up to nresults!), fill the rest with nil - // note: in MULTRET context nresults starts as -1 so i != 0 condition never activates intentionally - int i; - for (i = nresults; i != 0 && vali < valend; i--) - setobj2s(L, res++, vali++); - while (i-- > 0) - setnilvalue(res++); - - // pop the stack frame - L->ci = cip; - L->base = cip->base; - L->top = (nresults == LUA_MULTRET) ? res : cip->top; - - // we're done! - if (LUAU_UNLIKELY(ci->flags & LUA_CALLINFO_RETURN)) - { - if (!FFlag::LuauUniformTopHandling) - L->top = res; - return NULL; - } - - // keep executing new function - LUAU_ASSERT(isLua(cip)); - return clvalue(cip->func); -} - const Instruction* executeGETGLOBAL(lua_State* L, const Instruction* pc, StkId base, TValue* k) { [[maybe_unused]] Closure* cl = clvalue(L->ci->func); diff --git a/CodeGen/src/CodeGenUtils.h b/CodeGen/src/CodeGenUtils.h index 87b6ec449..a30d7e98b 100644 --- a/CodeGen/src/CodeGenUtils.h +++ b/CodeGen/src/CodeGenUtils.h @@ -18,7 +18,6 @@ Closure* callProlog(lua_State* L, TValue* ra, StkId argtop, int nresults); void callEpilogC(lua_State* L, int nresults, int n); Closure* callFallback(lua_State* L, StkId ra, StkId argtop, int nresults); -Closure* returnFallback(lua_State* L, StkId ra, StkId valend); const Instruction* executeGETGLOBAL(lua_State* L, const Instruction* pc, StkId base, TValue* k); const Instruction* executeSETGLOBAL(lua_State* L, const Instruction* pc, StkId base, TValue* k); diff --git a/CodeGen/src/CodeGenX64.cpp b/CodeGen/src/CodeGenX64.cpp index ec032c02b..4100e667c 100644 --- a/CodeGen/src/CodeGenX64.cpp +++ b/CodeGen/src/CodeGenX64.cpp @@ -189,6 +189,11 @@ void assembleHelpers(X64::AssemblyBuilderX64& build, ModuleHelpers& helpers) build.logAppend("; continueCallInVm\n"); helpers.continueCallInVm = build.setLabel(); emitContinueCallInVm(build); + + if (build.logText) + build.logAppend("; return\n"); + helpers.return_ = build.setLabel(); + emitReturn(build, helpers); } } // namespace X64 diff --git a/CodeGen/src/EmitCommon.h b/CodeGen/src/EmitCommon.h index 6b19912bf..bfdde1690 100644 --- a/CodeGen/src/EmitCommon.h +++ b/CodeGen/src/EmitCommon.h @@ -24,6 +24,7 @@ struct ModuleHelpers // A64/X64 Label exitContinueVm; Label exitNoContinueVm; + Label return_; // X64 Label continueCallInVm; diff --git a/CodeGen/src/EmitCommonX64.cpp b/CodeGen/src/EmitCommonX64.cpp index 0095f288a..4ad4efe7b 100644 --- a/CodeGen/src/EmitCommonX64.cpp +++ b/CodeGen/src/EmitCommonX64.cpp @@ -352,6 +352,89 @@ void emitContinueCallInVm(AssemblyBuilderX64& build) emitExit(build, /* continueInVm */ true); } +void emitReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers) +{ + // input: ci in r8, res in rdi, number of written values in ecx + RegisterX64 ci = r8; + RegisterX64 res = rdi; + RegisterX64 written = ecx; + + RegisterX64 cip = r9; + RegisterX64 nresults = esi; + + build.lea(cip, addr[ci - sizeof(CallInfo)]); + + // nresults = ci->nresults + build.mov(nresults, dword[ci + offsetof(CallInfo, nresults)]); + + Label skipResultCopy; + + // Fill the rest of the expected results (nresults - written) with 'nil' + RegisterX64 counter = written; + build.sub(counter, nresults); // counter = -(nresults - written) + build.jcc(ConditionX64::GreaterEqual, skipResultCopy); + + Label repeatNilLoop = build.setLabel(); + build.mov(dword[res + offsetof(TValue, tt)], LUA_TNIL); + build.add(res, sizeof(TValue)); + build.inc(counter); + build.jcc(ConditionX64::NotZero, repeatNilLoop); + + build.setLabel(skipResultCopy); + + build.mov(qword[rState + offsetof(lua_State, ci)], cip); // L->ci = cip + build.mov(rBase, qword[cip + offsetof(CallInfo, base)]); // sync base = L->base while we have a chance + build.mov(qword[rState + offsetof(lua_State, base)], rBase); // L->base = cip->base + + Label skipFixedRetTop; + build.test(nresults, nresults); // test here will set SF=1 for a negative number and it always sets OF to 0 + build.jcc(ConditionX64::Less, skipFixedRetTop); // jl jumps if SF != OF + build.mov(res, qword[cip + offsetof(CallInfo, top)]); // res = cip->top + build.setLabel(skipFixedRetTop); + + build.mov(qword[rState + offsetof(lua_State, top)], res); // L->top = res + + // Unlikely, but this might be the last return from VM + build.test(byte[ci + offsetof(CallInfo, flags)], LUA_CALLINFO_RETURN); + build.jcc(ConditionX64::NotZero, helpers.exitNoContinueVm); + + // Returning back to the previous function is a bit tricky + // Registers alive: r9 (cip) + RegisterX64 proto = rcx; + RegisterX64 execdata = rbx; + + // Change closure + build.mov(rax, qword[cip + offsetof(CallInfo, func)]); + build.mov(rax, qword[rax + offsetof(TValue, value.gc)]); + build.mov(sClosure, rax); + + build.mov(proto, qword[rax + offsetof(Closure, l.p)]); + + build.mov(execdata, qword[proto + offsetof(Proto, execdata)]); + + build.test(byte[cip + offsetof(CallInfo, flags)], LUA_CALLINFO_NATIVE); + build.jcc(ConditionX64::Zero, helpers.exitContinueVm); // Continue in interpreter if function has no native data + + // Change constants + build.mov(rConstants, qword[proto + offsetof(Proto, k)]); + + // Change code + build.mov(rdx, qword[proto + offsetof(Proto, code)]); + build.mov(sCode, rdx); + + build.mov(rax, qword[cip + offsetof(CallInfo, savedpc)]); + + // To get instruction index from instruction pointer, we need to divide byte offset by 4 + // But we will actually need to scale instruction index by 4 back to byte offset later so it cancels out + build.sub(rax, rdx); + + // Get new instruction location and jump to it + build.mov(edx, dword[execdata + rax]); + build.add(rdx, qword[proto + offsetof(Proto, exectarget)]); + build.jmp(rdx); +} + + } // namespace X64 } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/EmitCommonX64.h b/CodeGen/src/EmitCommonX64.h index 3f723f456..eb4532a0d 100644 --- a/CodeGen/src/EmitCommonX64.h +++ b/CodeGen/src/EmitCommonX64.h @@ -207,6 +207,8 @@ void emitFallback(IrRegAllocX64& regs, AssemblyBuilderX64& build, int offset, in void emitContinueCallInVm(AssemblyBuilderX64& build); +void emitReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers); + } // namespace X64 } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/EmitInstructionX64.cpp b/CodeGen/src/EmitInstructionX64.cpp index f2012ca9d..5d1c642fe 100644 --- a/CodeGen/src/EmitInstructionX64.cpp +++ b/CodeGen/src/EmitInstructionX64.cpp @@ -166,160 +166,61 @@ void emitInstCall(AssemblyBuilderX64& build, ModuleHelpers& helpers, int ra, int void emitInstReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers, int ra, int actualResults) { RegisterX64 ci = r8; - RegisterX64 cip = r9; RegisterX64 res = rdi; - RegisterX64 nresults = esi; + RegisterX64 written = ecx; build.mov(ci, qword[rState + offsetof(lua_State, ci)]); - build.lea(cip, addr[ci - sizeof(CallInfo)]); - - // res = ci->func; note: we assume CALL always puts func+args and expects results to start at func build.mov(res, qword[ci + offsetof(CallInfo, func)]); - // nresults = ci->nresults - build.mov(nresults, dword[ci + offsetof(CallInfo, nresults)]); + if (actualResults == 0) { - Label skipResultCopy; - - RegisterX64 counter = ecx; - - if (actualResults == 0) + build.xor_(written, written); + build.jmp(helpers.return_); + } + else if (actualResults >= 1 && actualResults <= 3) + { + for (int r = 0; r < actualResults; ++r) { - // Our instruction doesn't have any results, so just fill results expected in parent with 'nil' - build.test(nresults, nresults); // test here will set SF=1 for a negative number, ZF=1 for zero and OF=0 - build.jcc(ConditionX64::LessEqual, skipResultCopy); // jle jumps if SF != OF or ZF == 1 - - build.mov(counter, nresults); - - Label repeatNilLoop = build.setLabel(); - build.mov(dword[res + offsetof(TValue, tt)], LUA_TNIL); - build.add(res, sizeof(TValue)); - build.dec(counter); - build.jcc(ConditionX64::NotZero, repeatNilLoop); + build.vmovups(xmm0, luauReg(ra + r)); + build.vmovups(xmmword[res + r * sizeof(TValue)], xmm0); } - else if (actualResults == 1) - { - // Try setting our 1 result - build.test(nresults, nresults); - build.jcc(ConditionX64::Zero, skipResultCopy); - - build.lea(counter, addr[nresults - 1]); - - build.vmovups(xmm0, luauReg(ra)); - build.vmovups(xmmword[res], xmm0); - build.add(res, sizeof(TValue)); + build.add(res, actualResults * sizeof(TValue)); + build.mov(written, actualResults); + build.jmp(helpers.return_); + } + else + { + RegisterX64 vali = rax; + RegisterX64 valend = rdx; - // Fill the rest of the expected results with 'nil' - build.test(counter, counter); // test here will set SF=1 for a negative number, ZF=1 for zero and OF=0 - build.jcc(ConditionX64::LessEqual, skipResultCopy); // jle jumps if SF != OF or ZF == 1 + // vali = ra + build.lea(vali, luauRegAddress(ra)); - Label repeatNilLoop = build.setLabel(); - build.mov(dword[res + offsetof(TValue, tt)], LUA_TNIL); - build.add(res, sizeof(TValue)); - build.dec(counter); - build.jcc(ConditionX64::NotZero, repeatNilLoop); - } + // Copy as much as possible for MULTRET calls, and only as much as needed otherwise + if (actualResults == LUA_MULTRET) + build.mov(valend, qword[rState + offsetof(lua_State, top)]); // valend = L->top else - { - RegisterX64 vali = rax; - RegisterX64 valend = rdx; - - // Copy return values into parent stack (but only up to nresults!) - build.test(nresults, nresults); - build.jcc(ConditionX64::Zero, skipResultCopy); - - // vali = ra - build.lea(vali, luauRegAddress(ra)); + build.lea(valend, luauRegAddress(ra + actualResults)); // valend = ra + actualResults - // Copy as much as possible for MULTRET calls, and only as much as needed otherwise - if (actualResults == LUA_MULTRET) - build.mov(valend, qword[rState + offsetof(lua_State, top)]); // valend = L->top - else - build.lea(valend, luauRegAddress(ra + actualResults)); // valend = ra + actualResults + build.xor_(written, written); - build.mov(counter, nresults); + Label repeatValueLoop, exitValueLoop; - Label repeatValueLoop, exitValueLoop; + build.cmp(vali, valend); + build.jcc(ConditionX64::NotBelow, exitValueLoop); - build.setLabel(repeatValueLoop); - build.cmp(vali, valend); - build.jcc(ConditionX64::NotBelow, exitValueLoop); + build.setLabel(repeatValueLoop); + build.vmovups(xmm0, xmmword[vali]); + build.vmovups(xmmword[res], xmm0); + build.add(vali, sizeof(TValue)); + build.add(res, sizeof(TValue)); + build.inc(written); + build.cmp(vali, valend); + build.jcc(ConditionX64::Below, repeatValueLoop); - build.vmovups(xmm0, xmmword[vali]); - build.vmovups(xmmword[res], xmm0); - build.add(vali, sizeof(TValue)); - build.add(res, sizeof(TValue)); - build.dec(counter); - build.jcc(ConditionX64::NotZero, repeatValueLoop); - - build.setLabel(exitValueLoop); - - // Fill the rest of the expected results with 'nil' - build.test(counter, counter); // test here will set SF=1 for a negative number, ZF=1 for zero and OF=0 - build.jcc(ConditionX64::LessEqual, skipResultCopy); // jle jumps if SF != OF or ZF == 1 - - Label repeatNilLoop = build.setLabel(); - build.mov(dword[res + offsetof(TValue, tt)], LUA_TNIL); - build.add(res, sizeof(TValue)); - build.dec(counter); - build.jcc(ConditionX64::NotZero, repeatNilLoop); - } - - build.setLabel(skipResultCopy); + build.setLabel(exitValueLoop); + build.jmp(helpers.return_); } - - build.mov(qword[rState + offsetof(lua_State, ci)], cip); // L->ci = cip - build.mov(rBase, qword[cip + offsetof(CallInfo, base)]); // sync base = L->base while we have a chance - build.mov(qword[rState + offsetof(lua_State, base)], rBase); // L->base = cip->base - - // Start with result for LUA_MULTRET/exit value - build.mov(qword[rState + offsetof(lua_State, top)], res); // L->top = res - - // Unlikely, but this might be the last return from VM - build.test(byte[ci + offsetof(CallInfo, flags)], LUA_CALLINFO_RETURN); - build.jcc(ConditionX64::NotZero, helpers.exitNoContinueVm); - - Label skipFixedRetTop; - build.test(nresults, nresults); // test here will set SF=1 for a negative number and it always sets OF to 0 - build.jcc(ConditionX64::Less, skipFixedRetTop); // jl jumps if SF != OF - build.mov(rax, qword[cip + offsetof(CallInfo, top)]); - build.mov(qword[rState + offsetof(lua_State, top)], rax); // L->top = cip->top - build.setLabel(skipFixedRetTop); - - // Returning back to the previous function is a bit tricky - // Registers alive: r9 (cip) - RegisterX64 proto = rcx; - RegisterX64 execdata = rbx; - - // Change closure - build.mov(rax, qword[cip + offsetof(CallInfo, func)]); - build.mov(rax, qword[rax + offsetof(TValue, value.gc)]); - build.mov(sClosure, rax); - - build.mov(proto, qword[rax + offsetof(Closure, l.p)]); - - build.mov(execdata, qword[proto + offsetof(Proto, execdata)]); - - build.test(byte[cip + offsetof(CallInfo, flags)], LUA_CALLINFO_NATIVE); - build.jcc(ConditionX64::Zero, helpers.exitContinueVm); // Continue in interpreter if function has no native data - - // Change constants - build.mov(rConstants, qword[proto + offsetof(Proto, k)]); - - // Change code - build.mov(rdx, qword[proto + offsetof(Proto, code)]); - build.mov(sCode, rdx); - - build.mov(rax, qword[cip + offsetof(CallInfo, savedpc)]); - - // To get instruction index from instruction pointer, we need to divide byte offset by 4 - // But we will actually need to scale instruction index by 4 back to byte offset later so it cancels out - build.sub(rax, rdx); - - // Get new instruction location and jump to it - build.mov(edx, dword[execdata + rax]); - build.add(rdx, qword[proto + offsetof(Proto, exectarget)]); - build.jmp(rdx); } void emitInstSetList(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int rb, int count, uint32_t index) diff --git a/CodeGen/src/IrAnalysis.cpp b/CodeGen/src/IrAnalysis.cpp index 85811f057..14fc9b467 100644 --- a/CodeGen/src/IrAnalysis.cpp +++ b/CodeGen/src/IrAnalysis.cpp @@ -661,9 +661,212 @@ static void computeCfgBlockEdges(IrFunction& function) } } +// Assign tree depth and pre- and post- DFS visit order of the tree/graph nodes +// Optionally, collect required node order into a vector +template +void computeBlockOrdering( + IrFunction& function, std::vector& ordering, std::vector* preOrder, std::vector* postOrder) +{ + CfgInfo& info = function.cfg; + + LUAU_ASSERT(info.idoms.size() == function.blocks.size()); + + ordering.clear(); + ordering.resize(function.blocks.size()); + + // Get depth-first post-order using manual stack instead of recursion + struct StackItem + { + uint32_t blockIdx; + uint32_t itPos; + }; + std::vector stack; + + if (preOrder) + preOrder->reserve(function.blocks.size()); + if (postOrder) + postOrder->reserve(function.blocks.size()); + + uint32_t nextPreOrder = 0; + uint32_t nextPostOrder = 0; + + stack.push_back({0, 0}); + ordering[0].visited = true; + ordering[0].preOrder = nextPreOrder++; + + while (!stack.empty()) + { + StackItem& item = stack.back(); + BlockIteratorWrapper children = childIt(info, item.blockIdx); + + if (item.itPos < children.size()) + { + uint32_t childIdx = children[item.itPos++]; + + BlockOrdering& childOrdering = ordering[childIdx]; + + if (!childOrdering.visited) + { + childOrdering.visited = true; + childOrdering.depth = uint32_t(stack.size()); + childOrdering.preOrder = nextPreOrder++; + + if (preOrder) + preOrder->push_back(item.blockIdx); + + stack.push_back({childIdx, 0}); + } + } + else + { + ordering[item.blockIdx].postOrder = nextPostOrder++; + + if (postOrder) + postOrder->push_back(item.blockIdx); + + stack.pop_back(); + } + } +} + +// Dominance tree construction based on 'A Simple, Fast Dominance Algorithm' [Keith D. Cooper, et al] +// This solution has quadratic complexity in the worst case. +// It is possible to switch to SEMI-NCA algorithm (also quadratic) mentioned in 'Linear-Time Algorithms for Dominators and Related Problems' [Loukas +// Georgiadis] + +// Find block that is common between blocks 'a' and 'b' on the path towards the entry +static uint32_t findCommonDominator(const std::vector& idoms, const std::vector& data, uint32_t a, uint32_t b) +{ + while (a != b) + { + while (data[a].postOrder < data[b].postOrder) + { + a = idoms[a]; + LUAU_ASSERT(a != ~0u); + } + + while (data[b].postOrder < data[a].postOrder) + { + b = idoms[b]; + LUAU_ASSERT(b != ~0u); + } + } + + return a; +} + +void computeCfgImmediateDominators(IrFunction& function) +{ + CfgInfo& info = function.cfg; + + // Clear existing data + info.idoms.clear(); + info.idoms.resize(function.blocks.size(), ~0u); + + std::vector ordering; + std::vector blocksInPostOrder; + computeBlockOrdering(function, ordering, /* preOrder */ nullptr, &blocksInPostOrder); + + // Entry node is temporarily marked to be an idom of itself to make algorithm work + info.idoms[0] = 0; + + // Iteratively compute immediate dominators + bool updated = true; + + while (updated) + { + updated = false; + + // Go over blocks in reverse post-order of CFG + // '- 2' skips the root node which is last in post-order traversal + for (int i = int(blocksInPostOrder.size() - 2); i >= 0; i--) + { + uint32_t blockIdx = blocksInPostOrder[i]; + uint32_t newIdom = ~0u; + + for (uint32_t predIdx : predecessors(info, blockIdx)) + { + if (uint32_t predIdom = info.idoms[predIdx]; predIdom != ~0u) + { + if (newIdom == ~0u) + newIdom = predIdx; + else + newIdom = findCommonDominator(info.idoms, ordering, newIdom, predIdx); + } + } + + if (newIdom != info.idoms[blockIdx]) + { + info.idoms[blockIdx] = newIdom; + + // Run until a fixed point is reached + updated = true; + } + } + } + + // Entry node doesn't have an immediate dominator + info.idoms[0] = ~0u; +} + +void computeCfgDominanceTreeChildren(IrFunction& function) +{ + CfgInfo& info = function.cfg; + + // Clear existing data + info.domChildren.clear(); + + info.domChildrenOffsets.clear(); + info.domChildrenOffsets.resize(function.blocks.size()); + + // First we need to know children count of each node in the dominance tree + // We use offset array for to hold this data, counts will be readjusted to offsets later + for (size_t blockIdx = 0; blockIdx < function.blocks.size(); blockIdx++) + { + uint32_t domParent = info.idoms[blockIdx]; + + if (domParent != ~0u) + info.domChildrenOffsets[domParent]++; + } + + // Convert counds to offsets using prefix sum + uint32_t total = 0; + + for (size_t blockIdx = 0; blockIdx < function.blocks.size(); blockIdx++) + { + uint32_t& offset = info.domChildrenOffsets[blockIdx]; + uint32_t count = offset; + offset = total; + total += count; + } + + info.domChildren.resize(total); + + for (size_t blockIdx = 0; blockIdx < function.blocks.size(); blockIdx++) + { + // We use a trick here, where we use the starting offset of the dominance children list as the position where to write next child + // The values will be adjusted back in a separate loop later + uint32_t domParent = info.idoms[blockIdx]; + + if (domParent != ~0u) + info.domChildren[info.domChildrenOffsets[domParent]++] = uint32_t(blockIdx); + } + + // Offsets into the dominance children list were used as iterators in the previous loop + // That process basically moved the values in the array 1 step towards the start + // Here we move them one step towards the end and restore 0 for first offset + for (int blockIdx = int(function.blocks.size() - 1); blockIdx > 0; blockIdx--) + info.domChildrenOffsets[blockIdx] = info.domChildrenOffsets[blockIdx - 1]; + info.domChildrenOffsets[0] = 0; + + computeBlockOrdering(function, info.domOrdering, /* preOrder */ nullptr, /* postOrder */ nullptr); +} + void computeCfgInfo(IrFunction& function) { computeCfgBlockEdges(function); + computeCfgImmediateDominators(function); + computeCfgDominanceTreeChildren(function); computeCfgLiveInOutRegSets(function); } @@ -687,5 +890,15 @@ BlockIteratorWrapper successors(const CfgInfo& cfg, uint32_t blockIdx) return BlockIteratorWrapper{cfg.successors.data() + start, cfg.successors.data() + end}; } +BlockIteratorWrapper domChildren(const CfgInfo& cfg, uint32_t blockIdx) +{ + LUAU_ASSERT(blockIdx < cfg.domChildrenOffsets.size()); + + uint32_t start = cfg.domChildrenOffsets[blockIdx]; + uint32_t end = blockIdx + 1 < cfg.domChildrenOffsets.size() ? cfg.domChildrenOffsets[blockIdx + 1] : uint32_t(cfg.domChildren.size()); + + return BlockIteratorWrapper{cfg.domChildren.data() + start, cfg.domChildren.data() + end}; +} + } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/IrBuilder.cpp b/CodeGen/src/IrBuilder.cpp index a12eca348..6ab5e249f 100644 --- a/CodeGen/src/IrBuilder.cpp +++ b/CodeGen/src/IrBuilder.cpp @@ -25,6 +25,7 @@ IrBuilder::IrBuilder() void IrBuilder::buildFunctionIr(Proto* proto) { function.proto = proto; + function.variadic = proto->is_vararg != 0; // Rebuild original control flow blocks rebuildBytecodeBasicBlocks(proto); diff --git a/CodeGen/src/IrLoweringA64.cpp b/CodeGen/src/IrLoweringA64.cpp index 5f6249000..5c29ad413 100644 --- a/CodeGen/src/IrLoweringA64.cpp +++ b/CodeGen/src/IrLoweringA64.cpp @@ -185,11 +185,10 @@ static bool emitBuiltin( } } -IrLoweringA64::IrLoweringA64(AssemblyBuilderA64& build, ModuleHelpers& helpers, NativeState& data, Proto* proto, IrFunction& function) +IrLoweringA64::IrLoweringA64(AssemblyBuilderA64& build, ModuleHelpers& helpers, NativeState& data, IrFunction& function) : build(build) , helpers(helpers) , data(data) - , proto(proto) , function(function) , regs(function, {{x0, x15}, {x16, x17}, {q0, q7}, {q16, q31}}) , valueTracker(function) @@ -1343,19 +1342,71 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) break; case IrCmd::RETURN: regs.spill(build, index); - // valend = (n == LUA_MULTRET) ? L->top : ra + n - if (intOp(inst.b) == LUA_MULTRET) - build.ldr(x2, mem(rState, offsetof(lua_State, top))); + + if (function.variadic) + { + build.ldr(x1, mem(rState, offsetof(lua_State, ci))); + build.ldr(x1, mem(x1, offsetof(CallInfo, func))); + } + else if (intOp(inst.b) != 1) + build.sub(x1, rBase, sizeof(TValue)); // invariant: ci->func + 1 == ci->base for non-variadic frames + + if (intOp(inst.b) == 0) + { + build.mov(w2, 0); + build.b(helpers.return_); + } + else if (intOp(inst.b) == 1 && !function.variadic) + { + // fast path: minimizes x1 adjustments + // note that we skipped x1 computation for this specific case above + build.ldr(q0, mem(rBase, vmRegOp(inst.a) * sizeof(TValue))); + build.str(q0, mem(rBase, -int(sizeof(TValue)))); + build.mov(x1, rBase); + build.mov(w2, 1); + build.b(helpers.return_); + } + else if (intOp(inst.b) >= 1 && intOp(inst.b) <= 3) + { + for (int r = 0; r < intOp(inst.b); ++r) + { + build.ldr(q0, mem(rBase, (vmRegOp(inst.a) + r) * sizeof(TValue))); + build.str(q0, mem(x1, sizeof(TValue), AddressKindA64::post)); + } + build.mov(w2, intOp(inst.b)); + build.b(helpers.return_); + } else - build.add(x2, rBase, uint16_t((vmRegOp(inst.a) + intOp(inst.b)) * sizeof(TValue))); - // returnFallback(L, ra, valend) - build.mov(x0, rState); - build.add(x1, rBase, uint16_t(vmRegOp(inst.a) * sizeof(TValue))); - build.ldr(x3, mem(rNativeContext, offsetof(NativeContext, returnFallback))); - build.blr(x3); + { + build.mov(w2, 0); - // reentry with x0=closure (NULL will trigger exit) - build.b(helpers.reentry); + // vali = ra + build.add(x3, rBase, uint16_t(vmRegOp(inst.a) * sizeof(TValue))); + + // valend = (n == LUA_MULTRET) ? L->top : ra + n + if (intOp(inst.b) == LUA_MULTRET) + build.ldr(x4, mem(rState, offsetof(lua_State, top))); + else + build.add(x4, rBase, uint16_t((vmRegOp(inst.a) + intOp(inst.b)) * sizeof(TValue))); + + Label repeatValueLoop, exitValueLoop; + + if (intOp(inst.b) == LUA_MULTRET) + { + build.cmp(x3, x4); + build.b(ConditionA64::CarrySet, exitValueLoop); // CarrySet == UnsignedGreaterEqual + } + + build.setLabel(repeatValueLoop); + build.ldr(q0, mem(x3, sizeof(TValue), AddressKindA64::post)); + build.str(q0, mem(x1, sizeof(TValue), AddressKindA64::post)); + build.add(w2, w2, 1); + build.cmp(x3, x4); + build.b(ConditionA64::CarryClear, repeatValueLoop); // CarryClear == UnsignedLess + + build.setLabel(exitValueLoop); + build.b(helpers.return_); + } break; case IrCmd::FORGLOOP: // register layout: ra + 1 = table, ra + 2 = internal index, ra + 3 .. ra + aux = iteration variables diff --git a/CodeGen/src/IrLoweringA64.h b/CodeGen/src/IrLoweringA64.h index 264789044..1df09bd37 100644 --- a/CodeGen/src/IrLoweringA64.h +++ b/CodeGen/src/IrLoweringA64.h @@ -9,8 +9,6 @@ #include -struct Proto; - namespace Luau { namespace CodeGen @@ -25,7 +23,7 @@ namespace A64 struct IrLoweringA64 { - IrLoweringA64(AssemblyBuilderA64& build, ModuleHelpers& helpers, NativeState& data, Proto* proto, IrFunction& function); + IrLoweringA64(AssemblyBuilderA64& build, ModuleHelpers& helpers, NativeState& data, IrFunction& function); void lowerInst(IrInst& inst, uint32_t index, IrBlock& next); void finishBlock(); @@ -58,7 +56,6 @@ struct IrLoweringA64 AssemblyBuilderA64& build; ModuleHelpers& helpers; NativeState& data; - Proto* proto = nullptr; // Temporarily required to provide 'Instruction* pc' to old emitInst* methods IrFunction& function; diff --git a/CodeGen/src/NativeState.cpp b/CodeGen/src/NativeState.cpp index 17977c3c2..14c1acd99 100644 --- a/CodeGen/src/NativeState.cpp +++ b/CodeGen/src/NativeState.cpp @@ -90,7 +90,6 @@ void initFunctions(NativeState& data) data.context.callEpilogC = callEpilogC; data.context.callFallback = callFallback; - data.context.returnFallback = returnFallback; data.context.executeGETGLOBAL = executeGETGLOBAL; data.context.executeSETGLOBAL = executeSETGLOBAL; diff --git a/CodeGen/src/NativeState.h b/CodeGen/src/NativeState.h index 0140448fd..a2393bbfe 100644 --- a/CodeGen/src/NativeState.h +++ b/CodeGen/src/NativeState.h @@ -86,7 +86,6 @@ struct NativeContext void (*callEpilogC)(lua_State* L, int nresults, int n) = nullptr; Closure* (*callFallback)(lua_State* L, StkId ra, StkId argtop, int nresults) = nullptr; - Closure* (*returnFallback)(lua_State* L, StkId ra, StkId valend) = nullptr; // Opcode fallbacks, implemented in C const Instruction* (*executeGETGLOBAL)(lua_State* L, const Instruction* pc, StkId base, TValue* k) = nullptr; diff --git a/Makefile b/Makefile index 99eb93e6c..d3bf31d2e 100644 --- a/Makefile +++ b/Makefile @@ -44,6 +44,10 @@ ANALYZE_CLI_SOURCES=CLI/FileUtils.cpp CLI/Flags.cpp CLI/Analyze.cpp ANALYZE_CLI_OBJECTS=$(ANALYZE_CLI_SOURCES:%=$(BUILD)/%.o) ANALYZE_CLI_TARGET=$(BUILD)/luau-analyze +COMPILE_CLI_SOURCES=CLI/FileUtils.cpp CLI/Flags.cpp CLI/Compile.cpp +COMPILE_CLI_OBJECTS=$(COMPILE_CLI_SOURCES:%=$(BUILD)/%.o) +COMPILE_CLI_TARGET=$(BUILD)/luau-compile + FUZZ_SOURCES=$(wildcard fuzz/*.cpp) fuzz/luau.pb.cpp FUZZ_OBJECTS=$(FUZZ_SOURCES:%=$(BUILD)/%.o) @@ -55,8 +59,8 @@ ifneq ($(opt),) TESTS_ARGS+=-O$(opt) endif -OBJECTS=$(AST_OBJECTS) $(COMPILER_OBJECTS) $(ANALYSIS_OBJECTS) $(CODEGEN_OBJECTS) $(VM_OBJECTS) $(ISOCLINE_OBJECTS) $(TESTS_OBJECTS) $(REPL_CLI_OBJECTS) $(ANALYZE_CLI_OBJECTS) $(FUZZ_OBJECTS) -EXECUTABLE_ALIASES = luau luau-analyze luau-tests +OBJECTS=$(AST_OBJECTS) $(COMPILER_OBJECTS) $(ANALYSIS_OBJECTS) $(CODEGEN_OBJECTS) $(VM_OBJECTS) $(ISOCLINE_OBJECTS) $(TESTS_OBJECTS) $(REPL_CLI_OBJECTS) $(ANALYZE_CLI_OBJECTS) $(COMPILE_CLI_OBJECTS) $(FUZZ_OBJECTS) +EXECUTABLE_ALIASES = luau luau-analyze luau-compile luau-tests # common flags CXXFLAGS=-g -Wall @@ -132,6 +136,7 @@ $(ISOCLINE_OBJECTS): CXXFLAGS+=-Wno-unused-function -Iextern/isocline/include $(TESTS_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -ICompiler/include -IAnalysis/include -ICodeGen/include -IVM/include -ICLI -Iextern -DDOCTEST_CONFIG_DOUBLE_STRINGIFY $(REPL_CLI_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -ICompiler/include -IVM/include -ICodeGen/include -Iextern -Iextern/isocline/include $(ANALYZE_CLI_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -IAnalysis/include -Iextern +$(COMPILE_CLI_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -ICompiler/include -IVM/include -ICodeGen/include $(FUZZ_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -ICompiler/include -IAnalysis/include -IVM/include -ICodeGen/include $(TESTS_TARGET): LDFLAGS+=-lpthread @@ -189,6 +194,9 @@ luau: $(REPL_CLI_TARGET) luau-analyze: $(ANALYZE_CLI_TARGET) ln -fs $^ $@ +luau-compile: $(COMPILE_CLI_TARGET) + ln -fs $^ $@ + luau-tests: $(TESTS_TARGET) ln -fs $^ $@ @@ -196,8 +204,9 @@ luau-tests: $(TESTS_TARGET) $(TESTS_TARGET): $(TESTS_OBJECTS) $(ANALYSIS_TARGET) $(COMPILER_TARGET) $(AST_TARGET) $(CODEGEN_TARGET) $(VM_TARGET) $(ISOCLINE_TARGET) $(REPL_CLI_TARGET): $(REPL_CLI_OBJECTS) $(COMPILER_TARGET) $(AST_TARGET) $(CODEGEN_TARGET) $(VM_TARGET) $(ISOCLINE_TARGET) $(ANALYZE_CLI_TARGET): $(ANALYZE_CLI_OBJECTS) $(ANALYSIS_TARGET) $(AST_TARGET) +$(COMPILE_CLI_TARGET): $(COMPILE_CLI_OBJECTS) $(COMPILER_TARGET) $(AST_TARGET) $(CODEGEN_TARGET) $(VM_TARGET) -$(TESTS_TARGET) $(REPL_CLI_TARGET) $(ANALYZE_CLI_TARGET): +$(TESTS_TARGET) $(REPL_CLI_TARGET) $(ANALYZE_CLI_TARGET) $(COMPILE_CLI_TARGET): $(CXX) $^ $(LDFLAGS) -o $@ # executable targets for fuzzing diff --git a/Sources.cmake b/Sources.cmake index 853b1b866..b1693c36d 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -141,7 +141,6 @@ target_sources(Luau.Analysis PRIVATE Analysis/include/Luau/BuiltinDefinitions.h Analysis/include/Luau/Clone.h Analysis/include/Luau/Config.h - Analysis/include/Luau/Refinement.h Analysis/include/Luau/Constraint.h Analysis/include/Luau/ConstraintGraphBuilder.h Analysis/include/Luau/ConstraintSolver.h @@ -153,6 +152,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/include/Luau/Error.h Analysis/include/Luau/FileResolver.h Analysis/include/Luau/Frontend.h + Analysis/include/Luau/InsertionOrderedMap.h Analysis/include/Luau/Instantiation.h Analysis/include/Luau/IostreamHelpers.h Analysis/include/Luau/JsonEmitter.h @@ -165,6 +165,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/include/Luau/Predicate.h Analysis/include/Luau/Quantify.h Analysis/include/Luau/RecursionCounter.h + Analysis/include/Luau/Refinement.h Analysis/include/Luau/RequireTracer.h Analysis/include/Luau/Scope.h Analysis/include/Luau/Simplify.h @@ -175,6 +176,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/include/Luau/ToString.h Analysis/include/Luau/Transpiler.h Analysis/include/Luau/TxnLog.h + Analysis/include/Luau/Type.h Analysis/include/Luau/TypeArena.h Analysis/include/Luau/TypeAttach.h Analysis/include/Luau/TypeChecker2.h @@ -183,7 +185,6 @@ target_sources(Luau.Analysis PRIVATE Analysis/include/Luau/TypeInfer.h Analysis/include/Luau/TypePack.h Analysis/include/Luau/TypeUtils.h - Analysis/include/Luau/Type.h Analysis/include/Luau/Unifiable.h Analysis/include/Luau/Unifier.h Analysis/include/Luau/UnifierSharedState.h @@ -198,7 +199,6 @@ target_sources(Luau.Analysis PRIVATE Analysis/src/BuiltinDefinitions.cpp Analysis/src/Clone.cpp Analysis/src/Config.cpp - Analysis/src/Refinement.cpp Analysis/src/Constraint.cpp Analysis/src/ConstraintGraphBuilder.cpp Analysis/src/ConstraintSolver.cpp @@ -216,6 +216,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/src/Module.cpp Analysis/src/Normalize.cpp Analysis/src/Quantify.cpp + Analysis/src/Refinement.cpp Analysis/src/RequireTracer.cpp Analysis/src/Scope.cpp Analysis/src/Simplify.cpp @@ -226,6 +227,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/src/ToString.cpp Analysis/src/Transpiler.cpp Analysis/src/TxnLog.cpp + Analysis/src/Type.cpp Analysis/src/TypeArena.cpp Analysis/src/TypeAttach.cpp Analysis/src/TypeChecker2.cpp @@ -234,7 +236,6 @@ target_sources(Luau.Analysis PRIVATE Analysis/src/TypeInfer.cpp Analysis/src/TypePack.cpp Analysis/src/TypeUtils.cpp - Analysis/src/Type.cpp Analysis/src/Unifiable.cpp Analysis/src/Unifier.cpp ) @@ -326,6 +327,7 @@ if(TARGET Luau.Analyze.CLI) endif() if(TARGET Luau.Ast.CLI) + # Luau.Ast.CLI Sources target_sources(Luau.Ast.CLI PRIVATE CLI/Ast.cpp CLI/FileUtils.h @@ -415,6 +417,7 @@ if(TARGET Luau.UnitTest) tests/TypeVar.test.cpp tests/Variant.test.cpp tests/VisitType.test.cpp + tests/InsertionOrderedMap.test.cpp tests/main.cpp) endif() @@ -449,9 +452,20 @@ if(TARGET Luau.Web) endif() if(TARGET Luau.Reduce.CLI) + # Luau.Reduce.CLI Sources target_sources(Luau.Reduce.CLI PRIVATE CLI/Reduce.cpp CLI/FileUtils.cpp CLI/FileUtils.h ) endif() + +if(TARGET Luau.Compile.CLI) + # Luau.Compile.CLI Sources + target_sources(Luau.Compile.CLI PRIVATE + CLI/FileUtils.h + CLI/FileUtils.cpp + CLI/Flags.h + CLI/Flags.cpp + CLI/Compile.cpp) +endif() diff --git a/VM/src/ldo.cpp b/VM/src/ldo.cpp index 7f58d9635..e5fde4d46 100644 --- a/VM/src/ldo.cpp +++ b/VM/src/ldo.cpp @@ -17,8 +17,6 @@ #include -LUAU_FASTFLAGVARIABLE(LuauUniformTopHandling, false) - /* ** {====================================================== ** Error-recovery functions @@ -252,7 +250,7 @@ void luaD_call(lua_State* L, StkId func, int nresults) L->isactive = false; } - if (FFlag::LuauUniformTopHandling && nresults != LUA_MULTRET) + if (nresults != LUA_MULTRET) L->top = restorestack(L, old_func) + nresults; L->nCcalls--; diff --git a/VM/src/lvmexecute.cpp b/VM/src/lvmexecute.cpp index 280c47927..79bf807b6 100644 --- a/VM/src/lvmexecute.cpp +++ b/VM/src/lvmexecute.cpp @@ -16,7 +16,6 @@ #include -LUAU_FASTFLAG(LuauUniformTopHandling) LUAU_FASTFLAG(LuauGetImportDirect) // Disable c99-designator to avoid the warning in CGOTO dispatch table @@ -1043,8 +1042,6 @@ static void luau_execute(lua_State* L) // we're done! if (LUAU_UNLIKELY(ci->flags & LUA_CALLINFO_RETURN)) { - if (!FFlag::LuauUniformTopHandling) - L->top = res; goto exit; } diff --git a/tests/AssemblyBuilderA64.test.cpp b/tests/AssemblyBuilderA64.test.cpp index c917a7bb3..ba8d40c2b 100644 --- a/tests/AssemblyBuilderA64.test.cpp +++ b/tests/AssemblyBuilderA64.test.cpp @@ -460,6 +460,25 @@ TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "Undefined") SINGLE_COMPARE(udf(), 0x00000000); } +TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "PrePostIndexing") +{ + SINGLE_COMPARE(ldr(x0, mem(x1, 1)), 0xF8401020); + SINGLE_COMPARE(ldr(x0, mem(x1, 1, AddressKindA64::pre)), 0xF8401C20); + SINGLE_COMPARE(ldr(x0, mem(x1, 1, AddressKindA64::post)), 0xF8401420); + + SINGLE_COMPARE(ldr(q0, mem(x1, 1)), 0x3CC01020); + SINGLE_COMPARE(ldr(q0, mem(x1, 1, AddressKindA64::pre)), 0x3CC01C20); + SINGLE_COMPARE(ldr(q0, mem(x1, 1, AddressKindA64::post)), 0x3CC01420); + + SINGLE_COMPARE(str(x0, mem(x1, 1)), 0xF8001020); + SINGLE_COMPARE(str(x0, mem(x1, 1, AddressKindA64::pre)), 0xF8001C20); + SINGLE_COMPARE(str(x0, mem(x1, 1, AddressKindA64::post)), 0xF8001420); + + SINGLE_COMPARE(str(q0, mem(x1, 1)), 0x3C801020); + SINGLE_COMPARE(str(q0, mem(x1, 1, AddressKindA64::pre)), 0x3C801C20); + SINGLE_COMPARE(str(q0, mem(x1, 1, AddressKindA64::post)), 0x3C801420); +} + TEST_CASE("LogTest") { AssemblyBuilderA64 build(/* logText= */ true); @@ -501,6 +520,10 @@ TEST_CASE("LogTest") build.ubfx(x1, x2, 37, 5); + build.ldr(x0, mem(x1, 1)); + build.ldr(x0, mem(x1, 1, AddressKindA64::pre)); + build.ldr(x0, mem(x1, 1, AddressKindA64::post)); + build.setLabel(l); build.ret(); @@ -534,6 +557,9 @@ TEST_CASE("LogTest") tbz x0,#5,.L1 fcvt s1,d2 ubfx x1,x2,#3705 + ldr x0,[x1,#1] + ldr x0,[x1,#1]! + ldr x0,[x1]!,#1 .L1: ret )"; diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index d66eb18e8..b8dee9976 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -3471,8 +3471,6 @@ local a: T@1 TEST_CASE_FIXTURE(ACFixture, "frontend_use_correct_global_scope") { - ScopedFastFlag sff("LuauTypeCheckerUseCorrectScope", true); - loadDefinition(R"( declare class Instance Name: string diff --git a/tests/InsertionOrderedMap.test.cpp b/tests/InsertionOrderedMap.test.cpp new file mode 100644 index 000000000..ca6f14994 --- /dev/null +++ b/tests/InsertionOrderedMap.test.cpp @@ -0,0 +1,140 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/InsertionOrderedMap.h" + +#include + +#include "doctest.h" + +using namespace Luau; + +struct MapFixture +{ + std::vector> ptrs; + + int* makePtr() + { + ptrs.push_back(std::make_unique(int{})); + return ptrs.back().get(); + } +}; + +TEST_SUITE_BEGIN("InsertionOrderedMap"); + +TEST_CASE_FIXTURE(MapFixture, "map_insertion") +{ + InsertionOrderedMap map; + + int* a = makePtr(); + int* b = makePtr(); + + map.insert(a, 1); + map.insert(b, 2); +} + +TEST_CASE_FIXTURE(MapFixture, "map_lookup") +{ + InsertionOrderedMap map; + + int* a = makePtr(); + map.insert(a, 1); + + int* r = map.get(a); + REQUIRE(r != nullptr); + CHECK(*r == 1); + + r = map.get(makePtr()); + CHECK(r == nullptr); +} + +TEST_CASE_FIXTURE(MapFixture, "insert_does_not_update") +{ + InsertionOrderedMap map; + + int* k = makePtr(); + map.insert(k, 1); + map.insert(k, 2); + + int* v = map.get(k); + REQUIRE(v != nullptr); + CHECK(*v == 1); +} + +TEST_CASE_FIXTURE(MapFixture, "insertion_order_is_iteration_order") +{ + // This one is a little hard to prove, in that if the ordering guarantees + // fail this test isn't guaranteed to fail, but it is strictly better than + // nothing. + + InsertionOrderedMap map; + int* a = makePtr(); + int* b = makePtr(); + int* c = makePtr(); + map.insert(a, 1); + map.insert(b, 1); + map.insert(c, 1); + + auto it = map.begin(); + REQUIRE(it != map.end()); + CHECK(it->first == a); + CHECK(it->second == 1); + + ++it; + REQUIRE(it != map.end()); + CHECK(it->first == b); + CHECK(it->second == 1); + + ++it; + REQUIRE(it != map.end()); + CHECK(it->first == c); + CHECK(it->second == 1); + + ++it; + CHECK(it == map.end()); +} + +TEST_CASE_FIXTURE(MapFixture, "destructuring_iterator_compiles") +{ + // This test's only purpose is to successfully compile. + InsertionOrderedMap map; + + for (auto [k, v] : map) + { + // Checks here solely to silence unused variable warnings. + CHECK(k); + CHECK(v > 0); + } +} + +TEST_CASE_FIXTURE(MapFixture, "map_erasure") +{ + InsertionOrderedMap map; + + int* a = makePtr(); + int* b = makePtr(); + + map.insert(a, 1); + map.insert(b, 2); + + map.erase(map.find(a)); + CHECK(map.size() == 1); + CHECK(!map.contains(a)); + CHECK(map.get(a) == nullptr); + + int* v = map.get(b); + REQUIRE(v); +} + +TEST_CASE_FIXTURE(MapFixture, "map_clear") +{ + InsertionOrderedMap map; + int* a = makePtr(); + + map.insert(a, 1); + + map.clear(); + CHECK(map.size() == 0); + CHECK(!map.contains(a)); + CHECK(map.get(a) == nullptr); +} + +TEST_SUITE_END(); diff --git a/tests/IrBuilder.test.cpp b/tests/IrBuilder.test.cpp index 32634225a..4bfb63f3a 100644 --- a/tests/IrBuilder.test.cpp +++ b/tests/IrBuilder.test.cpp @@ -74,6 +74,35 @@ class IrBuilderFixture CHECK(target.f == inst.f); } + void defineCfgTree(const std::vector>& successorSets) + { + for (const std::vector& successorSet : successorSets) + { + build.beginBlock(build.block(IrBlockKind::Internal)); + + build.function.cfg.successorsOffsets.push_back(uint32_t(build.function.cfg.successors.size())); + build.function.cfg.successors.insert(build.function.cfg.successors.end(), successorSet.begin(), successorSet.end()); + } + + // Brute-force the predecessor list + for (int i = 0; i < int(build.function.blocks.size()); i++) + { + build.function.cfg.predecessorsOffsets.push_back(uint32_t(build.function.cfg.predecessors.size())); + + for (int k = 0; k < int(build.function.blocks.size()); k++) + { + for (uint32_t succIdx : successors(build.function.cfg, k)) + { + if (succIdx == i) + build.function.cfg.predecessors.push_back(k); + } + } + } + + computeCfgImmediateDominators(build.function); + computeCfgDominanceTreeChildren(build.function); + } + IrBuilder build; // Luau.VM headers are not accessible @@ -2164,6 +2193,30 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "SetTable") )"); } +// 'A Simple, Fast Dominance Algorithm' [Keith D. Cooper, et al]. Figure 2. +TEST_CASE_FIXTURE(IrBuilderFixture, "DominanceVerification1") +{ + defineCfgTree({{1, 2}, {3}, {4}, {4}, {3}}); + + CHECK(build.function.cfg.idoms == std::vector{{~0u, 0, 0, 0, 0}}); +} + +// 'A Linear Time Algorithm for Placing Phi-Nodes' [Vugranam C.Sreedhar]. Figure 1. +TEST_CASE_FIXTURE(IrBuilderFixture, "DominanceVerification2") +{ + defineCfgTree({{1, 16}, {2, 3, 4}, {4, 7}, {9}, {5}, {6}, {2, 8}, {8}, {7, 15}, {10, 11}, {12}, {12}, {13}, {3, 14, 15}, {12}, {16}, {}}); + + CHECK(build.function.cfg.idoms == std::vector{~0u, 0, 1, 1, 1, 4, 5, 1, 1, 3, 9, 9, 9, 12, 13, 1, 0}); +} + +// 'A Linear Time Algorithm for Placing Phi-Nodes' [Vugranam C.Sreedhar]. Figure 4. +TEST_CASE_FIXTURE(IrBuilderFixture, "DominanceVerification3") +{ + defineCfgTree({{1, 2}, {3}, {3, 4}, {5}, {5, 6}, {7}, {7}, {}}); + + CHECK(build.function.cfg.idoms == std::vector{~0u, 0, 0, 0, 2, 0, 4, 0}); +} + TEST_SUITE_END(); TEST_SUITE_BEGIN("ValueNumbering"); diff --git a/tests/Normalize.test.cpp b/tests/Normalize.test.cpp index 93ea75103..8fe86655c 100644 --- a/tests/Normalize.test.cpp +++ b/tests/Normalize.test.cpp @@ -791,14 +791,21 @@ TEST_CASE_FIXTURE(NormalizeFixture, "normalize_blocked_types") CHECK_EQ(normalizer.typeFromNormal(*norm), &blocked); } -TEST_CASE_FIXTURE(NormalizeFixture, "normalize_pending_expansion_types") +TEST_CASE_FIXTURE(NormalizeFixture, "normalize_is_exactly_number") { - AstName name; - Type pending{PendingExpansionType{std::nullopt, name, {}, {}}}; + const NormalizedType* number = normalizer.normalize(builtinTypes->numberType); + // 1. all types for which Types::number say true for, NormalizedType::isExactlyNumber should say true as well + CHECK(Luau::isNumber(builtinTypes->numberType) == number->isExactlyNumber()); + // 2. isExactlyNumber should handle cases like `number & number` + TypeId intersection = arena.addType(IntersectionType{{builtinTypes->numberType, builtinTypes->numberType}}); + const NormalizedType* normIntersection = normalizer.normalize(intersection); + CHECK(normIntersection->isExactlyNumber()); - const NormalizedType* norm = normalizer.normalize(&pending); + // 3. isExactlyNumber should reject things that are definitely not precisely numbers `number | any` - CHECK_EQ(normalizer.typeFromNormal(*norm), &pending); + TypeId yoonion = arena.addType(UnionType{{builtinTypes->anyType, builtinTypes->numberType}}); + const NormalizedType* unionIntersection = normalizer.normalize(yoonion); + CHECK(!unionIntersection->isExactlyNumber()); } TEST_SUITE_END(); diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index e5bcfa304..5aabb240b 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -14,7 +14,7 @@ using namespace Luau; -LUAU_FASTFLAG(LuauInstantiateInSubtyping); +LUAU_FASTFLAG(LuauInstantiateInSubtyping) TEST_SUITE_BEGIN("TypeInferFunctions"); @@ -2073,4 +2073,20 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "param_1_and_2_both_takes_the_same_generic_bu CHECK_EQ(toString(result.errors[1]), "Type 'number' could not be converted into 'boolean'"); } +TEST_CASE_FIXTURE(Fixture, "attempt_to_call_an_intersection_of_tables") +{ + CheckResult result = check(R"( + local function f(t: { x: number } & { y: string }) + t() + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + if (FFlag::DebugLuauDeferredConstraintResolution) + CHECK_EQ(toString(result.errors[0]), "Cannot call non-function {| x: number |} & {| y: string |}"); + else + CHECK_EQ(toString(result.errors[0]), "Cannot call non-function {| x: number |}"); +} + TEST_SUITE_END(); diff --git a/tools/faillist.txt b/tools/faillist.txt index e7d1f5f40..f049a0ee9 100644 --- a/tools/faillist.txt +++ b/tools/faillist.txt @@ -88,7 +88,6 @@ TableTests.oop_polymorphic TableTests.quantify_even_that_table_was_never_exported_at_all TableTests.quantify_metatables_of_metatables_of_table TableTests.reasonable_error_when_adding_a_nonexistent_property_to_an_array_like_table -TableTests.result_is_bool_for_equality_operators_if_lhs_is_any TableTests.right_table_missing_key2 TableTests.shared_selfs TableTests.shared_selfs_from_free_param @@ -167,7 +166,6 @@ TypeInferOperators.CallOrOfFunctions TypeInferOperators.cli_38355_recursive_union TypeInferOperators.compound_assign_mismatch_metatable TypeInferOperators.disallow_string_and_types_without_metatables_from_arithmetic_binary_ops -TypeInferOperators.luau-polyfill.String.slice TypeInferOperators.operator_eq_completely_incompatible TypeInferOperators.typecheck_overloaded_multiply_that_is_an_intersection TypeInferOperators.typecheck_overloaded_multiply_that_is_an_intersection_on_rhs From 212888c36194da897a757a13a8d6fce26fe6874d Mon Sep 17 00:00:00 2001 From: Vyacheslav Egorov Date: Fri, 9 Jun 2023 15:46:43 +0300 Subject: [PATCH 59/66] Fix build warning --- tests/IrBuilder.test.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/IrBuilder.test.cpp b/tests/IrBuilder.test.cpp index 4bfb63f3a..5b0c44d0c 100644 --- a/tests/IrBuilder.test.cpp +++ b/tests/IrBuilder.test.cpp @@ -93,7 +93,7 @@ class IrBuilderFixture { for (uint32_t succIdx : successors(build.function.cfg, k)) { - if (succIdx == i) + if (succIdx == uint32_t(i)) build.function.cfg.predecessors.push_back(k); } } From 6ee4f190abc79a682bb173f5a3fdb664549384b1 Mon Sep 17 00:00:00 2001 From: Andy Friesen Date: Fri, 16 Jun 2023 10:01:18 -0700 Subject: [PATCH 60/66] Sync to upstream/release/581 --- Analysis/include/Luau/ConstraintSolver.h | 8 +- Analysis/src/AstJsonEncoder.cpp | 1 + Analysis/src/ConstraintGraphBuilder.cpp | 18 ++ Analysis/src/ConstraintSolver.cpp | 140 +++++------ Analysis/src/Frontend.cpp | 49 +++- Analysis/src/Normalize.cpp | 8 +- Analysis/src/Simplify.cpp | 12 + Analysis/src/Substitution.cpp | 7 +- Analysis/src/TypeAttach.cpp | 14 +- Analysis/src/TypeInfer.cpp | 4 + Ast/include/Luau/Ast.h | 18 +- Ast/src/Ast.cpp | 5 +- Ast/src/Parser.cpp | 23 +- CLI/Repl.cpp | 293 +--------------------- CodeGen/include/Luau/AssemblyBuilderX64.h | 3 + CodeGen/include/Luau/IrData.h | 2 + CodeGen/include/Luau/IrDump.h | 2 + CodeGen/src/AssemblyBuilderX64.cpp | 24 +- CodeGen/src/CodeAllocator.cpp | 2 + CodeGen/src/CodeGen.cpp | 20 +- CodeGen/src/CodeGenA64.cpp | 10 +- CodeGen/src/CodeGenX64.cpp | 18 +- CodeGen/src/EmitCommon.h | 2 +- CodeGen/src/EmitCommonX64.cpp | 52 ++-- CodeGen/src/EmitCommonX64.h | 27 +- CodeGen/src/EmitInstructionX64.cpp | 46 +++- CodeGen/src/EmitInstructionX64.h | 2 +- CodeGen/src/IrBuilder.cpp | 1 + CodeGen/src/IrDump.cpp | 151 +++++++++-- CodeGen/src/IrLoweringA64.cpp | 36 +-- CodeGen/src/IrLoweringA64.h | 10 + CodeGen/src/IrLoweringX64.cpp | 38 ++- CodeGen/src/IrLoweringX64.h | 10 + CodeGen/src/OptimizeConstProp.cpp | 9 +- Common/include/Luau/Bytecode.h | 23 +- Compiler/include/Luau/BytecodeBuilder.h | 5 + Compiler/src/BytecodeBuilder.cpp | 100 ++++++++ Compiler/src/Compiler.cpp | 79 ++---- Compiler/src/Types.cpp | 106 ++++++++ Compiler/src/Types.h | 9 + Sources.cmake | 2 + VM/src/lvm.h | 1 - VM/src/lvmexecute.cpp | 18 +- VM/src/lvmload.cpp | 73 +++--- bench/bench.py | 59 ++++- bench/bench_support.lua | 42 ++++ tests/AssemblyBuilderX64.test.cpp | 16 ++ tests/AstJsonEncoder.test.cpp | 4 +- tests/Compiler.test.cpp | 63 ++++- tests/IrBuilder.test.cpp | 2 +- tests/Parser.test.cpp | 50 +++- tests/Simplify.test.cpp | 2 +- tests/TypeInfer.definitions.test.cpp | 30 +++ 53 files changed, 1117 insertions(+), 632 deletions(-) create mode 100644 Compiler/src/Types.cpp create mode 100644 Compiler/src/Types.h diff --git a/Analysis/include/Luau/ConstraintSolver.h b/Analysis/include/Luau/ConstraintSolver.h index ef87175ef..b13bb21bd 100644 --- a/Analysis/include/Luau/ConstraintSolver.h +++ b/Analysis/include/Luau/ConstraintSolver.h @@ -174,10 +174,10 @@ struct ConstraintSolver bool blockOnPendingTypes(TypePackId target, NotNull constraint); void unblock(NotNull progressed); - void unblock(TypeId progressed); - void unblock(TypePackId progressed); - void unblock(const std::vector& types); - void unblock(const std::vector& packs); + void unblock(TypeId progressed, Location location); + void unblock(TypePackId progressed, Location location); + void unblock(const std::vector& types, Location location); + void unblock(const std::vector& packs, Location location); /** * @returns true if the TypeId is in a blocked state. diff --git a/Analysis/src/AstJsonEncoder.cpp b/Analysis/src/AstJsonEncoder.cpp index a964c785f..f2943c4df 100644 --- a/Analysis/src/AstJsonEncoder.cpp +++ b/Analysis/src/AstJsonEncoder.cpp @@ -752,6 +752,7 @@ struct AstJsonEncoder : public AstVisitor if (node->superName) write("superName", *node->superName); PROP(props); + PROP(indexer); }); } diff --git a/Analysis/src/ConstraintGraphBuilder.cpp b/Analysis/src/ConstraintGraphBuilder.cpp index 821f6c260..07dba9219 100644 --- a/Analysis/src/ConstraintGraphBuilder.cpp +++ b/Analysis/src/ConstraintGraphBuilder.cpp @@ -22,6 +22,7 @@ LUAU_FASTINT(LuauCheckRecursionLimit); LUAU_FASTFLAG(DebugLuauMagicTypes); +LUAU_FASTFLAG(LuauParseDeclareClassIndexer); namespace Luau { @@ -1157,6 +1158,23 @@ ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareC scope->exportedTypeBindings[className] = TypeFun{{}, classTy}; + if (FFlag::LuauParseDeclareClassIndexer && declaredClass->indexer) + { + RecursionCounter counter{&recursionCount}; + + if (recursionCount >= FInt::LuauCheckRecursionLimit) + { + reportCodeTooComplex(declaredClass->indexer->location); + } + else + { + ctv->indexer = TableIndexer{ + resolveType(scope, declaredClass->indexer->indexType, /* inTypeArguments */ false), + resolveType(scope, declaredClass->indexer->resultType, /* inTypeArguments */ false), + }; + } + } + for (const AstDeclaredClassProp& prop : declaredClass->props) { Name propName(prop.name.value); diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index c9ac8cc9d..b85d2c59c 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -539,8 +539,8 @@ bool ConstraintSolver::tryDispatch(const GeneralizationConstraint& c, NotNullty.emplace(builtinTypes->errorRecoveryType()); } - unblock(c.generalizedType); - unblock(c.sourceType); + unblock(c.generalizedType, constraint->location); + unblock(c.sourceType, constraint->location); return true; } @@ -564,7 +564,7 @@ bool ConstraintSolver::tryDispatch(const InstantiationConstraint& c, NotNulllocation); asMutable(c.subType)->ty.emplace(errorRecoveryType()); - unblock(c.subType); + unblock(c.subType, constraint->location); return true; } @@ -574,7 +574,7 @@ bool ConstraintSolver::tryDispatch(const InstantiationConstraint& c, NotNullscope, constraint->location, this}; queuer.traverse(c.subType); - unblock(c.subType); + unblock(c.subType, constraint->location); return true; } @@ -597,7 +597,7 @@ bool ConstraintSolver::tryDispatch(const UnaryConstraint& c, NotNullty.emplace(builtinTypes->booleanType); - unblock(c.resultType); + unblock(c.resultType, constraint->location); return true; } case AstExprUnary::Len: @@ -605,7 +605,7 @@ bool ConstraintSolver::tryDispatch(const UnaryConstraint& c, NotNullty.emplace(builtinTypes->numberType); - unblock(c.resultType); + unblock(c.resultType, constraint->location); return true; } case AstExprUnary::Minus: @@ -635,7 +635,7 @@ bool ConstraintSolver::tryDispatch(const UnaryConstraint& c, NotNullty.emplace(builtinTypes->errorRecoveryType()); } - unblock(c.resultType); + unblock(c.resultType, constraint->location); return true; } } @@ -684,7 +684,7 @@ bool ConstraintSolver::tryDispatch(const BinaryConstraint& c, NotNull(leftType) && !isLogical)) { asMutable(resultType)->ty.emplace(errorRecoveryType()); - unblock(resultType); + unblock(resultType, constraint->location); return true; } @@ -697,7 +697,7 @@ bool ConstraintSolver::tryDispatch(const BinaryConstraint& c, NotNullty.emplace(builtinTypes->booleanType); - unblock(resultType); + unblock(resultType, constraint->location); return true; } @@ -760,7 +760,7 @@ bool ConstraintSolver::tryDispatch(const BinaryConstraint& c, NotNullty.emplace(mmResult); - unblock(resultType); + unblock(resultType, constraint->location); (*c.astOriginalCallTypes)[c.astFragment] = *mm; (*c.astOverloadResolvedTypes)[c.astFragment] = *instantiatedMm; @@ -790,14 +790,14 @@ bool ConstraintSolver::tryDispatch(const BinaryConstraint& c, NotNullscope); asMutable(resultType)->ty.emplace(anyPresent ? builtinTypes->anyType : leftType); - unblock(resultType); + unblock(resultType, constraint->location); return true; } else if (get(leftType) || get(rightType)) { unify(leftType, rightType, constraint->scope); asMutable(resultType)->ty.emplace(builtinTypes->neverType); - unblock(resultType); + unblock(resultType, constraint->location); return true; } @@ -814,14 +814,14 @@ bool ConstraintSolver::tryDispatch(const BinaryConstraint& c, NotNullscope); asMutable(resultType)->ty.emplace(anyPresent ? builtinTypes->anyType : leftType); - unblock(resultType); + unblock(resultType, constraint->location); return true; } else if (get(leftType) || get(rightType)) { unify(leftType, rightType, constraint->scope); asMutable(resultType)->ty.emplace(builtinTypes->neverType); - unblock(resultType); + unblock(resultType, constraint->location); return true; } @@ -840,14 +840,14 @@ bool ConstraintSolver::tryDispatch(const BinaryConstraint& c, NotNullisExactlyNumber() || get(lt->tops)) && rt->isExactlyNumber()) { asMutable(resultType)->ty.emplace(builtinTypes->booleanType); - unblock(resultType); + unblock(resultType, constraint->location); return true; } if (lt && rt && (lt->isSubtypeOfString() || get(lt->tops)) && rt->isSubtypeOfString()) { asMutable(resultType)->ty.emplace(builtinTypes->booleanType); - unblock(resultType); + unblock(resultType, constraint->location); return true; } @@ -855,7 +855,7 @@ bool ConstraintSolver::tryDispatch(const BinaryConstraint& c, NotNull(leftType) || get(rightType)) { asMutable(resultType)->ty.emplace(builtinTypes->booleanType); - unblock(resultType); + unblock(resultType, constraint->location); return true; } @@ -867,7 +867,7 @@ bool ConstraintSolver::tryDispatch(const BinaryConstraint& c, NotNullty.emplace(builtinTypes->booleanType); - unblock(resultType); + unblock(resultType, constraint->location); return true; // And evalutes to a boolean if the LHS is falsey, and the RHS type if LHS is // truthy. @@ -876,7 +876,7 @@ bool ConstraintSolver::tryDispatch(const BinaryConstraint& c, NotNullfalsyType).result; asMutable(resultType)->ty.emplace(simplifyUnion(builtinTypes, arena, rightType, leftFilteredTy).result); - unblock(resultType); + unblock(resultType, constraint->location); return true; } // Or evaluates to the LHS type if the LHS is truthy, and the RHS type if @@ -886,7 +886,7 @@ bool ConstraintSolver::tryDispatch(const BinaryConstraint& c, NotNulltruthyType).result; asMutable(resultType)->ty.emplace(simplifyUnion(builtinTypes, arena, rightType, leftFilteredTy).result); - unblock(resultType); + unblock(resultType, constraint->location); return true; } default: @@ -898,7 +898,7 @@ bool ConstraintSolver::tryDispatch(const BinaryConstraint& c, NotNullscope); unify(rightType, errorRecoveryType(), constraint->scope); asMutable(resultType)->ty.emplace(errorRecoveryType()); - unblock(resultType); + unblock(resultType, constraint->location); return true; } @@ -1065,14 +1065,14 @@ bool ConstraintSolver::tryDispatch(const TypeAliasExpansionConstraint& c, NotNul const PendingExpansionType* petv = get(follow(c.target)); if (!petv) { - unblock(c.target); + unblock(c.target, constraint->location); return true; } - auto bindResult = [this, &c](TypeId result) { + auto bindResult = [this, &c, constraint](TypeId result) { LUAU_ASSERT(get(c.target)); asMutable(c.target)->ty.emplace(result); - unblock(c.target); + unblock(c.target, constraint->location); }; std::optional tf = (petv->prefix) ? constraint->scope->lookupImportedType(petv->prefix->value, petv->name.value) @@ -1400,9 +1400,9 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNullgetChanges(); bestOverloadLog->commit(); - unblock(changedTypes); - unblock(changedPacks); - unblock(c.result); + unblock(changedTypes, constraint->location); + unblock(changedPacks, constraint->location); + unblock(c.result, constraint->location); InstantiationQueuer queuer{constraint->scope, constraint->location, this}; queuer.traverse(fn); @@ -1421,7 +1421,7 @@ bool ConstraintSolver::tryDispatch(const PrimitiveTypeConstraint& c, NotNullty.emplace(bindTo); - unblock(c.resultType); + unblock(c.resultType, constraint->location); return true; } @@ -1440,7 +1440,7 @@ bool ConstraintSolver::tryDispatch(const HasPropConstraint& c, NotNullty.emplace(TableState::Free, TypeLevel{}, constraint->scope); ttv.props[c.prop] = Property{c.resultType}; asMutable(c.resultType)->ty.emplace(constraint->scope); - unblock(c.resultType); + unblock(c.resultType, constraint->location); return true; } @@ -1454,7 +1454,7 @@ bool ConstraintSolver::tryDispatch(const HasPropConstraint& c, NotNullty.emplace(result.value_or(builtinTypes->anyType)); - unblock(c.resultType); + unblock(c.resultType, constraint->location); return true; } @@ -1568,7 +1568,7 @@ bool ConstraintSolver::tryDispatch(const SetPropConstraint& c, NotNullscope); bind(c.resultType, c.subjectType); - unblock(c.resultType); + unblock(c.resultType, constraint->location); return true; } @@ -1593,8 +1593,8 @@ bool ConstraintSolver::tryDispatch(const SetPropConstraint& c, NotNulllocation); + unblock(c.resultType, constraint->location); return true; } else if (auto ttv = getMutable(subjectType)) @@ -1605,7 +1605,7 @@ bool ConstraintSolver::tryDispatch(const SetPropConstraint& c, NotNullprops[c.path[0]] = Property{c.propType}; bind(c.resultType, c.subjectType); - unblock(c.resultType); + unblock(c.resultType, constraint->location); return true; } else if (ttv->state == TableState::Unsealed) @@ -1614,14 +1614,14 @@ bool ConstraintSolver::tryDispatch(const SetPropConstraint& c, NotNulllocation); + unblock(c.resultType, constraint->location); return true; } else { bind(c.resultType, subjectType); - unblock(c.resultType); + unblock(c.resultType, constraint->location); return true; } } @@ -1630,7 +1630,7 @@ bool ConstraintSolver::tryDispatch(const SetPropConstraint& c, NotNulllocation); return true; } } @@ -1649,8 +1649,8 @@ bool ConstraintSolver::tryDispatch(const SetIndexerConstraint& c, NotNullty.emplace(subjectType); asMutable(c.propType)->ty.emplace(scope); - unblock(c.propType); - unblock(c.resultType); + unblock(c.propType, constraint->location); + unblock(c.resultType, constraint->location); return true; } @@ -1662,8 +1662,8 @@ bool ConstraintSolver::tryDispatch(const SetIndexerConstraint& c, NotNullindexer->indexType, constraint->scope); asMutable(c.propType)->ty.emplace(tt->indexer->indexResultType); asMutable(c.resultType)->ty.emplace(subjectType); - unblock(c.propType); - unblock(c.resultType); + unblock(c.propType, constraint->location); + unblock(c.resultType, constraint->location); return true; } else if (tt->state == TableState::Free || tt->state == TableState::Unsealed) @@ -1675,8 +1675,8 @@ bool ConstraintSolver::tryDispatch(const SetIndexerConstraint& c, NotNullindexer = TableIndexer{promotedIndexTy, c.propType}; asMutable(c.propType)->ty.emplace(tt->scope); asMutable(c.resultType)->ty.emplace(subjectType); - unblock(c.propType); - unblock(c.resultType); + unblock(c.propType, constraint->location); + unblock(c.resultType, constraint->location); return true; } // Do not augment sealed or generic tables that lack indexers @@ -1684,8 +1684,8 @@ bool ConstraintSolver::tryDispatch(const SetIndexerConstraint& c, NotNullty.emplace(builtinTypes->errorRecoveryType()); asMutable(c.resultType)->ty.emplace(builtinTypes->errorRecoveryType()); - unblock(c.propType); - unblock(c.resultType); + unblock(c.propType, constraint->location); + unblock(c.resultType, constraint->location); return true; } @@ -1704,7 +1704,7 @@ bool ConstraintSolver::tryDispatch(const SingletonOrTopTypeConstraint& c, NotNul else *asMutable(c.resultType) = BoundType{builtinTypes->anyType}; - unblock(c.resultType); + unblock(c.resultType, constraint->location); return true; } @@ -1720,7 +1720,7 @@ bool ConstraintSolver::tryDispatch(const UnpackConstraint& c, NotNullty.emplace(sourcePack); - unblock(resultPack); + unblock(resultPack, constraint->location); return true; } @@ -1745,7 +1745,7 @@ bool ConstraintSolver::tryDispatch(const UnpackConstraint& c, NotNullty.emplace(srcTy); - unblock(*destIter); + unblock(*destIter, constraint->location); } else unify(*destIter, srcTy, constraint->scope); @@ -1763,7 +1763,7 @@ bool ConstraintSolver::tryDispatch(const UnpackConstraint& c, NotNullty.emplace(builtinTypes->errorRecoveryType()); - unblock(*destIter); + unblock(*destIter, constraint->location); } ++destIter; @@ -1852,7 +1852,7 @@ bool ConstraintSolver::tryDispatch(const RefineConstraint& c, NotNullty.emplace(c.type); - unblock(c.resultType); + unblock(c.resultType, constraint->location); return true; } @@ -1880,7 +1880,7 @@ bool ConstraintSolver::tryDispatch(const RefineConstraint& c, NotNullty.emplace(c.discriminant); - unblock(c.resultType); + unblock(c.resultType, constraint->location); return true; } @@ -1892,7 +1892,7 @@ bool ConstraintSolver::tryDispatch(const RefineConstraint& c, NotNullty.emplace(result); - unblock(c.resultType); + unblock(c.resultType, constraint->location); return true; } @@ -1904,10 +1904,10 @@ bool ConstraintSolver::tryDispatch(const ReduceConstraint& c, NotNulllocation, NotNull{arena}, builtinTypes, constraint->scope, normalizer, nullptr, force); for (TypeId r : result.reducedTypes) - unblock(r); + unblock(r, constraint->location); for (TypePackId r : result.reducedPacks) - unblock(r); + unblock(r, constraint->location); if (force) return true; @@ -1928,10 +1928,10 @@ bool ConstraintSolver::tryDispatch(const ReducePackConstraint& c, NotNulllocation, NotNull{arena}, builtinTypes, constraint->scope, normalizer, nullptr, force); for (TypeId r : result.reducedTypes) - unblock(r); + unblock(r, constraint->location); for (TypePackId r : result.reducedPacks) - unblock(r); + unblock(r, constraint->location); if (force) return true; @@ -2374,8 +2374,8 @@ bool ConstraintSolver::tryUnify(NotNull constraint, TID subTy, u.log.commit(); - unblock(changedTypes); - unblock(changedPacks); + unblock(changedTypes, constraint->location); + unblock(changedPacks, constraint->location); return true; } @@ -2509,7 +2509,7 @@ void ConstraintSolver::unblock(NotNull progressed) return unblock_(progressed.get()); } -void ConstraintSolver::unblock(TypeId ty) +void ConstraintSolver::unblock(TypeId ty, Location location) { DenseHashSet seen{nullptr}; @@ -2517,7 +2517,7 @@ void ConstraintSolver::unblock(TypeId ty) while (true) { if (seen.find(progressed)) - iceReporter.ice("ConstraintSolver::unblock encountered a self-bound type!"); + iceReporter.ice("ConstraintSolver::unblock encountered a self-bound type!", location); seen.insert(progressed); if (logger) @@ -2532,7 +2532,7 @@ void ConstraintSolver::unblock(TypeId ty) } } -void ConstraintSolver::unblock(TypePackId progressed) +void ConstraintSolver::unblock(TypePackId progressed, Location) { if (logger) logger->popBlock(progressed); @@ -2540,16 +2540,16 @@ void ConstraintSolver::unblock(TypePackId progressed) return unblock_(progressed); } -void ConstraintSolver::unblock(const std::vector& types) +void ConstraintSolver::unblock(const std::vector& types, Location location) { for (TypeId t : types) - unblock(t); + unblock(t, location); } -void ConstraintSolver::unblock(const std::vector& packs) +void ConstraintSolver::unblock(const std::vector& packs, Location location) { for (TypePackId t : packs) - unblock(t); + unblock(t, location); } bool ConstraintSolver::isBlocked(TypeId ty) @@ -2586,8 +2586,8 @@ ErrorVec ConstraintSolver::unify(TypeId subType, TypeId superType, NotNull Frontend::checkQueuedModules(std::optional nextItems; + std::optional itemWithException; while (remaining != 0) { @@ -603,17 +605,25 @@ std::vector Frontend::checkQueuedModules(std::optional Frontend::checkQueuedModules(std::optional checkedModules; @@ -1104,6 +1123,8 @@ ModulePtr check(const SourceModule& sourceModule, const std::vectorname = sourceModule.name; result->humanReadableName = sourceModule.humanReadableName; + iceHandler->moduleName = sourceModule.name; + std::unique_ptr logger; if (recordJsonLog) { @@ -1189,9 +1210,19 @@ ModulePtr Frontend::check(const SourceModule& sourceModule, Mode mode, std::vect prepareModuleScope(name, scope, forAutocomplete); }; - return Luau::check(sourceModule, requireCycles, builtinTypes, NotNull{&iceHandler}, - NotNull{forAutocomplete ? &moduleResolverForAutocomplete : &moduleResolver}, NotNull{fileResolver}, - environmentScope ? *environmentScope : globals.globalScope, prepareModuleScopeWrap, options, recordJsonLog); + try + { + return Luau::check(sourceModule, requireCycles, builtinTypes, NotNull{&iceHandler}, + NotNull{forAutocomplete ? &moduleResolverForAutocomplete : &moduleResolver}, NotNull{fileResolver}, + environmentScope ? *environmentScope : globals.globalScope, prepareModuleScopeWrap, options, recordJsonLog); + } + catch (const InternalCompilerError& err) + { + InternalCompilerError augmented = err.location.has_value() + ? InternalCompilerError{err.message, sourceModule.humanReadableName, *err.location} + : InternalCompilerError{err.message, sourceModule.humanReadableName}; + throw augmented; + } } else { diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp index 3af7e8574..e4f22f331 100644 --- a/Analysis/src/Normalize.cpp +++ b/Analysis/src/Normalize.cpp @@ -2117,15 +2117,15 @@ std::optional Normalizer::intersectionOfTables(TypeId here, TypeId there TypeId hmtable = nullptr; if (const MetatableType* hmtv = get(here)) { - htable = hmtv->table; - hmtable = hmtv->metatable; + htable = follow(hmtv->table); + hmtable = follow(hmtv->metatable); } TypeId ttable = there; TypeId tmtable = nullptr; if (const MetatableType* tmtv = get(there)) { - ttable = tmtv->table; - tmtable = tmtv->metatable; + ttable = follow(tmtv->table); + tmtable = follow(tmtv->metatable); } const TableType* httv = get(htable); diff --git a/Analysis/src/Simplify.cpp b/Analysis/src/Simplify.cpp index 8e9424ae0..e17df3870 100644 --- a/Analysis/src/Simplify.cpp +++ b/Analysis/src/Simplify.cpp @@ -6,6 +6,7 @@ #include "Luau/ToString.h" #include "Luau/TypeArena.h" #include "Luau/Normalize.h" // TypeIds +#include LUAU_FASTINT(LuauTypeReductionRecursionLimit) @@ -236,6 +237,17 @@ Relation relateTables(TypeId left, TypeId right) NotNull leftTable{get(left)}; NotNull rightTable{get(right)}; LUAU_ASSERT(1 == rightTable->props.size()); + // Disjoint props have nothing in common + // t1 with props p1's cannot appear in t2 and t2 with props p2's cannot appear in t1 + bool foundPropFromLeftInRight = std::any_of(begin(leftTable->props), end(leftTable->props), [&](auto prop) { + return rightTable->props.find(prop.first) != end(rightTable->props); + }); + bool foundPropFromRightInLeft = std::any_of(begin(rightTable->props), end(rightTable->props), [&](auto prop) { + return leftTable->props.find(prop.first) != end(leftTable->props); + }); + + if (!(foundPropFromLeftInRight || foundPropFromRightInLeft) && leftTable->props.size() >= 1 && rightTable->props.size() >= 1) + return Relation::Disjoint; const auto [propName, rightProp] = *begin(rightTable->props); diff --git a/Analysis/src/Substitution.cpp b/Analysis/src/Substitution.cpp index 26cbdc683..655881a3f 100644 --- a/Analysis/src/Substitution.cpp +++ b/Analysis/src/Substitution.cpp @@ -111,11 +111,14 @@ static TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool a else if constexpr (std::is_same_v) return dest.addType(a); else if constexpr (std::is_same_v) - return ty; + return dest.addType(a); else if constexpr (std::is_same_v) return ty; else if constexpr (std::is_same_v) - return ty; + { + PendingExpansionType clone = PendingExpansionType{a.prefix, a.name, a.typeArguments, a.packArguments}; + return dest.addType(std::move(clone)); + } else if constexpr (std::is_same_v) return ty; else if constexpr (std::is_same_v) diff --git a/Analysis/src/TypeAttach.cpp b/Analysis/src/TypeAttach.cpp index dba95479c..3a1217bfc 100644 --- a/Analysis/src/TypeAttach.cpp +++ b/Analysis/src/TypeAttach.cpp @@ -13,6 +13,8 @@ #include +LUAU_FASTFLAG(LuauParseDeclareClassIndexer); + static char* allocateString(Luau::Allocator& allocator, std::string_view contents) { char* result = (char*)allocator.allocate(contents.size() + 1); @@ -227,7 +229,17 @@ class TypeRehydrationVisitor idx++; } - return allocator->alloc(Location(), props); + AstTableIndexer* indexer = nullptr; + if (FFlag::LuauParseDeclareClassIndexer && ctv.indexer) + { + RecursionCounter counter(&count); + + indexer = allocator->alloc(); + indexer->indexType = Luau::visit(*this, ctv.indexer->indexType->ty); + indexer->resultType = Luau::visit(*this, ctv.indexer->indexResultType->ty); + } + + return allocator->alloc(Location(), props, indexer); } AstType* operator()(const FunctionType& ftv) diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 5127febee..a3d917045 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -41,6 +41,7 @@ LUAU_FASTFLAGVARIABLE(LuauTypecheckTypeguards, false) LUAU_FASTFLAGVARIABLE(LuauTinyControlFlowAnalysis, false) LUAU_FASTFLAGVARIABLE(LuauTypecheckClassTypeIndexers, false) LUAU_FASTFLAGVARIABLE(LuauAlwaysCommitInferencesOfFunctionCalls, false) +LUAU_FASTFLAG(LuauParseDeclareClassIndexer) namespace Luau { @@ -1757,6 +1758,9 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& if (!ctv->metatable) ice("No metatable for declared class"); + if (const auto& indexer = declaredClass.indexer; FFlag::LuauParseDeclareClassIndexer && indexer) + ctv->indexer = TableIndexer(resolveType(scope, *indexer->indexType), resolveType(scope, *indexer->resultType)); + TableType* metatable = getMutable(*ctv->metatable); for (const AstDeclaredClassProp& prop : declaredClass.props) { diff --git a/Ast/include/Luau/Ast.h b/Ast/include/Luau/Ast.h index a486ad0f9..f9f9ab416 100644 --- a/Ast/include/Luau/Ast.h +++ b/Ast/include/Luau/Ast.h @@ -801,12 +801,20 @@ struct AstDeclaredClassProp bool isMethod = false; }; +struct AstTableIndexer +{ + AstType* indexType; + AstType* resultType; + Location location; +}; + class AstStatDeclareClass : public AstStat { public: LUAU_RTTI(AstStatDeclareClass) - AstStatDeclareClass(const Location& location, const AstName& name, std::optional superName, const AstArray& props); + AstStatDeclareClass(const Location& location, const AstName& name, std::optional superName, const AstArray& props, + AstTableIndexer* indexer = nullptr); void visit(AstVisitor* visitor) override; @@ -814,6 +822,7 @@ class AstStatDeclareClass : public AstStat std::optional superName; AstArray props; + AstTableIndexer* indexer; }; class AstType : public AstNode @@ -862,13 +871,6 @@ struct AstTableProp AstType* type; }; -struct AstTableIndexer -{ - AstType* indexType; - AstType* resultType; - Location location; -}; - class AstTypeTable : public AstType { public: diff --git a/Ast/src/Ast.cpp b/Ast/src/Ast.cpp index d2c552a3c..3c87e36cb 100644 --- a/Ast/src/Ast.cpp +++ b/Ast/src/Ast.cpp @@ -714,12 +714,13 @@ void AstStatDeclareFunction::visit(AstVisitor* visitor) } } -AstStatDeclareClass::AstStatDeclareClass( - const Location& location, const AstName& name, std::optional superName, const AstArray& props) +AstStatDeclareClass::AstStatDeclareClass(const Location& location, const AstName& name, std::optional superName, + const AstArray& props, AstTableIndexer* indexer) : AstStat(ClassIndex(), location) , name(name) , superName(superName) , props(props) + , indexer(indexer) { } diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 7cae609d1..cc5d7b381 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -13,6 +13,7 @@ // See docs/SyntaxChanges.md for an explanation. LUAU_FASTINTVARIABLE(LuauRecursionLimit, 1000) LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100) +LUAU_FASTFLAGVARIABLE(LuauParseDeclareClassIndexer, false) #define ERROR_INVALID_INTERP_DOUBLE_BRACE "Double braces are not permitted within interpolated strings. Did you mean '\\{'?" @@ -877,6 +878,7 @@ AstStat* Parser::parseDeclaration(const Location& start) } TempVector props(scratchDeclaredClassProps); + AstTableIndexer* indexer = nullptr; while (lexer.current().type != Lexeme::ReservedEnd) { @@ -885,7 +887,8 @@ AstStat* Parser::parseDeclaration(const Location& start) { props.push_back(parseDeclaredClassMethod()); } - else if (lexer.current().type == '[') + else if (lexer.current().type == '[' && (!FFlag::LuauParseDeclareClassIndexer || lexer.lookahead().type == Lexeme::RawString || + lexer.lookahead().type == Lexeme::QuotedString)) { const Lexeme begin = lexer.current(); nextLexeme(); // [ @@ -904,6 +907,22 @@ AstStat* Parser::parseDeclaration(const Location& start) else report(begin.location, "String literal contains malformed escape sequence"); } + else if (lexer.current().type == '[' && FFlag::LuauParseDeclareClassIndexer) + { + if (indexer) + { + // maybe we don't need to parse the entire badIndexer... + // however, we either have { or [ to lint, not the entire table type or the bad indexer. + AstTableIndexer* badIndexer = parseTableIndexer(); + + // we lose all additional indexer expressions from the AST after error recovery here + report(badIndexer->location, "Cannot have more than one class indexer"); + } + else + { + indexer = parseTableIndexer(); + } + } else { Name propName = parseName("property name"); @@ -916,7 +935,7 @@ AstStat* Parser::parseDeclaration(const Location& start) Location classEnd = lexer.current().location; nextLexeme(); // skip past `end` - return allocator.alloc(Location(classStart, classEnd), className.name, superName, copy(props)); + return allocator.alloc(Location(classStart, classEnd), className.name, superName, copy(props), indexer); } else if (std::optional globalName = parseNameOpt("global variable name")) { diff --git a/CLI/Repl.cpp b/CLI/Repl.cpp index a585a73a2..87ce27177 100644 --- a/CLI/Repl.cpp +++ b/CLI/Repl.cpp @@ -6,7 +6,6 @@ #include "Luau/CodeGen.h" #include "Luau/Compiler.h" -#include "Luau/BytecodeBuilder.h" #include "Luau/Parser.h" #include "Luau/TimeTrace.h" @@ -40,27 +39,6 @@ LUAU_FASTFLAG(DebugLuauTimeTracing) -enum class CliMode -{ - Unknown, - Repl, - Compile, - RunSourceFiles -}; - -enum class CompileFormat -{ - Text, - Binary, - Remarks, - Codegen, // Prints annotated native code including IR and assembly - CodegenAsm, // Prints annotated native code assembly - CodegenIr, // Prints annotated native code IR - CodegenVerbose, // Prints annotated native code including IR, assembly and outlined code - CodegenNull, - Null -}; - constexpr int MaxTraversalLimit = 50; static bool codegen = false; @@ -668,178 +646,11 @@ static bool runFile(const char* name, lua_State* GL, bool repl) return status == 0; } -static void report(const char* name, const Luau::Location& location, const char* type, const char* message) -{ - fprintf(stderr, "%s(%d,%d): %s: %s\n", name, location.begin.line + 1, location.begin.column + 1, type, message); -} - -static void reportError(const char* name, const Luau::ParseError& error) -{ - report(name, error.getLocation(), "SyntaxError", error.what()); -} - -static void reportError(const char* name, const Luau::CompileError& error) -{ - report(name, error.getLocation(), "CompileError", error.what()); -} - -static std::string getCodegenAssembly(const char* name, const std::string& bytecode, Luau::CodeGen::AssemblyOptions options) -{ - std::unique_ptr globalState(luaL_newstate(), lua_close); - lua_State* L = globalState.get(); - - if (luau_load(L, name, bytecode.data(), bytecode.size(), 0) == 0) - return Luau::CodeGen::getAssembly(L, -1, options); - - fprintf(stderr, "Error loading bytecode %s\n", name); - return ""; -} - -static void annotateInstruction(void* context, std::string& text, int fid, int instpos) -{ - Luau::BytecodeBuilder& bcb = *(Luau::BytecodeBuilder*)context; - - bcb.annotateInstruction(text, fid, instpos); -} - -struct CompileStats -{ - size_t lines; - size_t bytecode; - size_t codegen; - - double readTime; - double miscTime; - double parseTime; - double compileTime; - double codegenTime; -}; - -static double recordDeltaTime(double& timer) -{ - double now = Luau::TimeTrace::getClock(); - double delta = now - timer; - timer = now; - return delta; -} - -static bool compileFile(const char* name, CompileFormat format, CompileStats& stats) -{ - double currts = Luau::TimeTrace::getClock(); - - std::optional source = readFile(name); - if (!source) - { - fprintf(stderr, "Error opening %s\n", name); - return false; - } - - stats.readTime += recordDeltaTime(currts); - - // NOTE: Normally, you should use Luau::compile or luau_compile (see lua_require as an example) - // This function is much more complicated because it supports many output human-readable formats through internal interfaces - - try - { - Luau::BytecodeBuilder bcb; - - Luau::CodeGen::AssemblyOptions options; - options.outputBinary = format == CompileFormat::CodegenNull; - - if (!options.outputBinary) - { - options.includeAssembly = format != CompileFormat::CodegenIr; - options.includeIr = format != CompileFormat::CodegenAsm; - options.includeOutlinedCode = format == CompileFormat::CodegenVerbose; - } - - options.annotator = annotateInstruction; - options.annotatorContext = &bcb; - - if (format == CompileFormat::Text) - { - bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Source | Luau::BytecodeBuilder::Dump_Locals | - Luau::BytecodeBuilder::Dump_Remarks); - bcb.setDumpSource(*source); - } - else if (format == CompileFormat::Remarks) - { - bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Source | Luau::BytecodeBuilder::Dump_Remarks); - bcb.setDumpSource(*source); - } - else if (format == CompileFormat::Codegen || format == CompileFormat::CodegenAsm || format == CompileFormat::CodegenIr || - format == CompileFormat::CodegenVerbose) - { - bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Source | Luau::BytecodeBuilder::Dump_Locals | - Luau::BytecodeBuilder::Dump_Remarks); - bcb.setDumpSource(*source); - } - - stats.miscTime += recordDeltaTime(currts); - - Luau::Allocator allocator; - Luau::AstNameTable names(allocator); - Luau::ParseResult result = Luau::Parser::parse(source->c_str(), source->size(), names, allocator); - - if (!result.errors.empty()) - throw Luau::ParseErrors(result.errors); - - stats.lines += result.lines; - stats.parseTime += recordDeltaTime(currts); - - Luau::compileOrThrow(bcb, result, names, copts()); - stats.bytecode += bcb.getBytecode().size(); - stats.compileTime += recordDeltaTime(currts); - - switch (format) - { - case CompileFormat::Text: - printf("%s", bcb.dumpEverything().c_str()); - break; - case CompileFormat::Remarks: - printf("%s", bcb.dumpSourceRemarks().c_str()); - break; - case CompileFormat::Binary: - fwrite(bcb.getBytecode().data(), 1, bcb.getBytecode().size(), stdout); - break; - case CompileFormat::Codegen: - case CompileFormat::CodegenAsm: - case CompileFormat::CodegenIr: - case CompileFormat::CodegenVerbose: - printf("%s", getCodegenAssembly(name, bcb.getBytecode(), options).c_str()); - break; - case CompileFormat::CodegenNull: - stats.codegen += getCodegenAssembly(name, bcb.getBytecode(), options).size(); - stats.codegenTime += recordDeltaTime(currts); - break; - case CompileFormat::Null: - break; - } - - return true; - } - catch (Luau::ParseErrors& e) - { - for (auto& error : e.getErrors()) - reportError(name, error); - return false; - } - catch (Luau::CompileError& e) - { - reportError(name, e); - return false; - } -} - static void displayHelp(const char* argv0) { - printf("Usage: %s [--mode] [options] [file list]\n", argv0); - printf("\n"); - printf("When mode and file list are omitted, an interactive REPL is started instead.\n"); + printf("Usage: %s [options] [file list]\n", argv0); printf("\n"); - printf("Available modes:\n"); - printf(" omitted: compile and run input files one by one\n"); - printf(" --compile[=format]: compile input files and output resulting bytecode/assembly (binary, text, remarks, codegen)\n"); + printf("When file list is omitted, an interactive REPL is started instead.\n"); printf("\n"); printf("Available options:\n"); printf(" --coverage: collect code coverage while running the code and output results to coverage.out\n"); @@ -864,67 +675,12 @@ int replMain(int argc, char** argv) setLuauFlagsDefault(); - CliMode mode = CliMode::Unknown; - CompileFormat compileFormat{}; int profile = 0; bool coverage = false; bool interactive = false; bool codegenPerf = false; - // Set the mode if the user has explicitly specified one. - int argStart = 1; - if (argc >= 2 && strncmp(argv[1], "--compile", strlen("--compile")) == 0) - { - argStart++; - mode = CliMode::Compile; - if (strcmp(argv[1], "--compile") == 0) - { - compileFormat = CompileFormat::Text; - } - else if (strcmp(argv[1], "--compile=binary") == 0) - { - compileFormat = CompileFormat::Binary; - } - else if (strcmp(argv[1], "--compile=text") == 0) - { - compileFormat = CompileFormat::Text; - } - else if (strcmp(argv[1], "--compile=remarks") == 0) - { - compileFormat = CompileFormat::Remarks; - } - else if (strcmp(argv[1], "--compile=codegen") == 0) - { - compileFormat = CompileFormat::Codegen; - } - else if (strcmp(argv[1], "--compile=codegenasm") == 0) - { - compileFormat = CompileFormat::CodegenAsm; - } - else if (strcmp(argv[1], "--compile=codegenir") == 0) - { - compileFormat = CompileFormat::CodegenIr; - } - else if (strcmp(argv[1], "--compile=codegenverbose") == 0) - { - compileFormat = CompileFormat::CodegenVerbose; - } - else if (strcmp(argv[1], "--compile=codegennull") == 0) - { - compileFormat = CompileFormat::CodegenNull; - } - else if (strcmp(argv[1], "--compile=null") == 0) - { - compileFormat = CompileFormat::Null; - } - else - { - fprintf(stderr, "Error: Unrecognized value for '--compile' specified.\n"); - return 1; - } - } - - for (int i = argStart; i < argc; i++) + for (int i = 1; i < argc; i++) { if (strcmp(argv[i], "-h") == 0 || strcmp(argv[i], "--help") == 0) { @@ -1026,50 +782,20 @@ int replMain(int argc, char** argv) #endif } - const std::vector files = getSourceFiles(argc, argv); - if (mode == CliMode::Unknown) - { - mode = files.empty() ? CliMode::Repl : CliMode::RunSourceFiles; - } - - if (mode != CliMode::Compile && codegen && !Luau::CodeGen::isSupported()) + if (codegen && !Luau::CodeGen::isSupported()) { fprintf(stderr, "Cannot enable --codegen, native code generation is not supported in current configuration\n"); return 1; } - switch (mode) - { - case CliMode::Compile: - { -#ifdef _WIN32 - if (compileFormat == CompileFormat::Binary) - _setmode(_fileno(stdout), _O_BINARY); -#endif - - CompileStats stats = {}; - int failed = 0; - - for (const std::string& path : files) - failed += !compileFile(path.c_str(), compileFormat, stats); - - if (compileFormat == CompileFormat::Null) - printf("Compiled %d KLOC into %d KB bytecode (read %.2fs, parse %.2fs, compile %.2fs)\n", int(stats.lines / 1000), - int(stats.bytecode / 1024), stats.readTime, stats.parseTime, stats.compileTime); - else if (compileFormat == CompileFormat::CodegenNull) - printf("Compiled %d KLOC into %d KB bytecode => %d KB native code (%.2fx) (read %.2fs, parse %.2fs, compile %.2fs, codegen %.2fs)\n", - int(stats.lines / 1000), int(stats.bytecode / 1024), int(stats.codegen / 1024), - stats.bytecode == 0 ? 0.0 : double(stats.codegen) / double(stats.bytecode), stats.readTime, stats.parseTime, stats.compileTime, - stats.codegenTime); + const std::vector files = getSourceFiles(argc, argv); - return failed ? 1 : 0; - } - case CliMode::Repl: + if (files.empty()) { runRepl(); return 0; } - case CliMode::RunSourceFiles: + else { std::unique_ptr globalState(luaL_newstate(), lua_close); lua_State* L = globalState.get(); @@ -1101,9 +827,4 @@ int replMain(int argc, char** argv) return failed ? 1 : 0; } - case CliMode::Unknown: - default: - LUAU_ASSERT(!"Unhandled cli mode."); - return 1; - } } diff --git a/CodeGen/include/Luau/AssemblyBuilderX64.h b/CodeGen/include/Luau/AssemblyBuilderX64.h index 9e7d50116..aea01eec7 100644 --- a/CodeGen/include/Luau/AssemblyBuilderX64.h +++ b/CodeGen/include/Luau/AssemblyBuilderX64.h @@ -98,6 +98,8 @@ class AssemblyBuilderX64 void call(Label& label); void call(OperandX64 op); + void lea(RegisterX64 lhs, Label& label); + void int3(); void ud2(); @@ -243,6 +245,7 @@ class AssemblyBuilderX64 LUAU_NOINLINE void log(const char* opcode, OperandX64 op1, OperandX64 op2, OperandX64 op3, OperandX64 op4); LUAU_NOINLINE void log(Label label); LUAU_NOINLINE void log(const char* opcode, Label label); + LUAU_NOINLINE void log(const char* opcode, RegisterX64 reg, Label label); void log(OperandX64 op); const char* getSizeName(SizeX64 size) const; diff --git a/CodeGen/include/Luau/IrData.h b/CodeGen/include/Luau/IrData.h index 4a3fa4242..1c79ccb47 100644 --- a/CodeGen/include/Luau/IrData.h +++ b/CodeGen/include/Luau/IrData.h @@ -801,6 +801,8 @@ struct IrBlock uint32_t start = ~0u; uint32_t finish = ~0u; + uint32_t sortkey = ~0u; + Label label; }; diff --git a/CodeGen/include/Luau/IrDump.h b/CodeGen/include/Luau/IrDump.h index 179edd0de..2f86ebf08 100644 --- a/CodeGen/include/Luau/IrDump.h +++ b/CodeGen/include/Luau/IrDump.h @@ -38,6 +38,8 @@ std::string toString(const IrFunction& function, bool includeUseInfo); std::string dump(const IrFunction& function); std::string toDot(const IrFunction& function, bool includeInst); +std::string toDotCfg(const IrFunction& function); +std::string toDotDjGraph(const IrFunction& function); std::string dumpDot(const IrFunction& function, bool includeInst); diff --git a/CodeGen/src/AssemblyBuilderX64.cpp b/CodeGen/src/AssemblyBuilderX64.cpp index c7644a86c..2a8bc92ee 100644 --- a/CodeGen/src/AssemblyBuilderX64.cpp +++ b/CodeGen/src/AssemblyBuilderX64.cpp @@ -463,6 +463,20 @@ void AssemblyBuilderX64::call(OperandX64 op) commit(); } +void AssemblyBuilderX64::lea(RegisterX64 lhs, Label& label) +{ + LUAU_ASSERT(lhs.size == SizeX64::qword); + + placeBinaryRegAndRegMem(lhs, OperandX64(SizeX64::qword, noreg, 1, rip, 0), 0x8d, 0x8d); + + codePos -= 4; + placeLabel(label); + commit(); + + if (logText) + log("lea", lhs, label); +} + void AssemblyBuilderX64::int3() { if (logText) @@ -1415,7 +1429,7 @@ void AssemblyBuilderX64::commit() { LUAU_ASSERT(codePos <= codeEnd); - if (codeEnd - codePos < kMaxInstructionLength) + if (unsigned(codeEnd - codePos) < kMaxInstructionLength) extend(); } @@ -1501,6 +1515,14 @@ void AssemblyBuilderX64::log(const char* opcode, Label label) logAppend(" %-12s.L%d\n", opcode, label.id); } +void AssemblyBuilderX64::log(const char* opcode, RegisterX64 reg, Label label) +{ + logAppend(" %-12s", opcode); + log(reg); + text.append(","); + logAppend(".L%d\n", label.id); +} + void AssemblyBuilderX64::log(OperandX64 op) { switch (op.cat) diff --git a/CodeGen/src/CodeAllocator.cpp b/CodeGen/src/CodeAllocator.cpp index 09e1bb712..880a32446 100644 --- a/CodeGen/src/CodeAllocator.cpp +++ b/CodeGen/src/CodeAllocator.cpp @@ -56,8 +56,10 @@ static void makePagesExecutable(uint8_t* mem, size_t size) static void flushInstructionCache(uint8_t* mem, size_t size) { +#if WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_APP | WINAPI_PARTITION_SYSTEM) if (FlushInstructionCache(GetCurrentProcess(), mem, size) == 0) LUAU_ASSERT(!"Failed to flush instruction cache"); +#endif } #else static uint8_t* allocatePages(size_t size) diff --git a/CodeGen/src/CodeGen.cpp b/CodeGen/src/CodeGen.cpp index 89399cbc9..d7283b4a6 100644 --- a/CodeGen/src/CodeGen.cpp +++ b/CodeGen/src/CodeGen.cpp @@ -125,7 +125,7 @@ static bool lowerImpl(AssemblyBuilder& build, IrLowering& lowering, IrFunction& return (a.kind == IrBlockKind::Fallback) < (b.kind == IrBlockKind::Fallback); // Try to order by instruction order - return a.start < b.start; + return a.sortkey < b.sortkey; }); // For each IR instruction that begins a bytecode instruction, which bytecode instruction is it? @@ -234,6 +234,8 @@ static bool lowerImpl(AssemblyBuilder& build, IrLowering& lowering, IrFunction& build.setLabel(abandoned.label); } + lowering.finishFunction(); + return false; } } @@ -244,7 +246,15 @@ static bool lowerImpl(AssemblyBuilder& build, IrLowering& lowering, IrFunction& build.logAppend("#\n"); } - if (outputEnabled && !options.includeOutlinedCode && seenFallback) + if (!seenFallback) + { + textSize = build.text.length(); + codeSize = build.getCodeSize(); + } + + lowering.finishFunction(); + + if (outputEnabled && !options.includeOutlinedCode && textSize < build.text.size()) { build.text.resize(textSize); @@ -594,6 +604,12 @@ std::string getAssembly(lua_State* L, int idx, AssemblyOptions options) X64::assembleHelpers(build, helpers); #endif + if (!options.includeOutlinedCode && options.includeAssembly) + { + build.text.clear(); + build.logAppend("; skipping %u bytes of outlined helpers\n", unsigned(build.getCodeSize() * sizeof(build.code[0]))); + } + for (Proto* p : protos) if (p) if (std::optional np = assembleFunction(build, data, helpers, p, options)) diff --git a/CodeGen/src/CodeGenA64.cpp b/CodeGen/src/CodeGenA64.cpp index 355e29ca6..cc0131820 100644 --- a/CodeGen/src/CodeGenA64.cpp +++ b/CodeGen/src/CodeGenA64.cpp @@ -288,27 +288,27 @@ void assembleHelpers(AssemblyBuilderA64& build, ModuleHelpers& helpers) { if (build.logText) build.logAppend("; exitContinueVm\n"); - helpers.exitContinueVm = build.setLabel(); + build.setLabel(helpers.exitContinueVm); emitExit(build, /* continueInVm */ true); if (build.logText) build.logAppend("; exitNoContinueVm\n"); - helpers.exitNoContinueVm = build.setLabel(); + build.setLabel(helpers.exitNoContinueVm); emitExit(build, /* continueInVm */ false); if (build.logText) build.logAppend("; reentry\n"); - helpers.reentry = build.setLabel(); + build.setLabel(helpers.reentry); emitReentry(build, helpers); if (build.logText) build.logAppend("; interrupt\n"); - helpers.interrupt = build.setLabel(); + build.setLabel(helpers.interrupt); emitInterrupt(build); if (build.logText) build.logAppend("; return\n"); - helpers.return_ = build.setLabel(); + build.setLabel(helpers.return_); emitReturn(build, helpers); } diff --git a/CodeGen/src/CodeGenX64.cpp b/CodeGen/src/CodeGenX64.cpp index 4100e667c..41c3dbd05 100644 --- a/CodeGen/src/CodeGenX64.cpp +++ b/CodeGen/src/CodeGenX64.cpp @@ -56,6 +56,11 @@ static EntryLocations buildEntryFunction(AssemblyBuilderX64& build, UnwindBuilde locations.start = build.setLabel(); unwind.startFunction(); + RegisterX64 rArg1 = (build.abi == ABIX64::Windows) ? rcx : rdi; + RegisterX64 rArg2 = (build.abi == ABIX64::Windows) ? rdx : rsi; + RegisterX64 rArg3 = (build.abi == ABIX64::Windows) ? r8 : rdx; + RegisterX64 rArg4 = (build.abi == ABIX64::Windows) ? r9 : rcx; + // Save common non-volatile registers if (build.abi == ABIX64::SystemV) { @@ -177,22 +182,27 @@ void assembleHelpers(X64::AssemblyBuilderX64& build, ModuleHelpers& helpers) { if (build.logText) build.logAppend("; exitContinueVm\n"); - helpers.exitContinueVm = build.setLabel(); + build.setLabel(helpers.exitContinueVm); emitExit(build, /* continueInVm */ true); if (build.logText) build.logAppend("; exitNoContinueVm\n"); - helpers.exitNoContinueVm = build.setLabel(); + build.setLabel(helpers.exitNoContinueVm); emitExit(build, /* continueInVm */ false); if (build.logText) build.logAppend("; continueCallInVm\n"); - helpers.continueCallInVm = build.setLabel(); + build.setLabel(helpers.continueCallInVm); emitContinueCallInVm(build); + if (build.logText) + build.logAppend("; interrupt\n"); + build.setLabel(helpers.interrupt); + emitInterrupt(build); + if (build.logText) build.logAppend("; return\n"); - helpers.return_ = build.setLabel(); + build.setLabel(helpers.return_); emitReturn(build, helpers); } diff --git a/CodeGen/src/EmitCommon.h b/CodeGen/src/EmitCommon.h index bfdde1690..f912ffba7 100644 --- a/CodeGen/src/EmitCommon.h +++ b/CodeGen/src/EmitCommon.h @@ -25,13 +25,13 @@ struct ModuleHelpers Label exitContinueVm; Label exitNoContinueVm; Label return_; + Label interrupt; // X64 Label continueCallInVm; // A64 Label reentry; // x0: closure - Label interrupt; // x0: pc offset, x1: return address, x2: interrupt }; } // namespace CodeGen diff --git a/CodeGen/src/EmitCommonX64.cpp b/CodeGen/src/EmitCommonX64.cpp index 4ad4efe7b..f240d26f7 100644 --- a/CodeGen/src/EmitCommonX64.cpp +++ b/CodeGen/src/EmitCommonX64.cpp @@ -278,39 +278,34 @@ void emitUpdateBase(AssemblyBuilderX64& build) build.mov(rBase, qword[rState + offsetof(lua_State, base)]); } -static void emitSetSavedPc(IrRegAllocX64& regs, AssemblyBuilderX64& build, int pcpos) +void emitInterrupt(AssemblyBuilderX64& build) { - ScopedRegX64 tmp1{regs, SizeX64::qword}; - ScopedRegX64 tmp2{regs, SizeX64::qword}; + // rax = pcpos + 1 + // rbx = return address in native code - build.mov(tmp1.reg, sCode); - build.add(tmp1.reg, pcpos * sizeof(Instruction)); - build.mov(tmp2.reg, qword[rState + offsetof(lua_State, ci)]); - build.mov(qword[tmp2.reg + offsetof(CallInfo, savedpc)], tmp1.reg); -} + // note: rbx is non-volatile so it will be saved across interrupt call automatically + + RegisterX64 rArg1 = (build.abi == ABIX64::Windows) ? rcx : rdi; + RegisterX64 rArg2 = (build.abi == ABIX64::Windows) ? rdx : rsi; -void emitInterrupt(IrRegAllocX64& regs, AssemblyBuilderX64& build, int pcpos) -{ Label skip; - ScopedRegX64 tmp{regs, SizeX64::qword}; + // Update L->ci->savedpc; required in case interrupt errors + build.mov(rcx, sCode); + build.lea(rcx, addr[rcx + rax * sizeof(Instruction)]); + build.mov(rax, qword[rState + offsetof(lua_State, ci)]); + build.mov(qword[rax + offsetof(CallInfo, savedpc)], rcx); - // Skip if there is no interrupt set - build.mov(tmp.reg, qword[rState + offsetof(lua_State, global)]); - build.mov(tmp.reg, qword[tmp.reg + offsetof(global_State, cb.interrupt)]); - build.test(tmp.reg, tmp.reg); + // Load interrupt handler; it may be nullptr in case the update raced with the check before we got here + build.mov(rax, qword[rState + offsetof(lua_State, global)]); + build.mov(rax, qword[rax + offsetof(global_State, cb.interrupt)]); + build.test(rax, rax); build.jcc(ConditionX64::Zero, skip); - emitSetSavedPc(regs, build, pcpos + 1); - // Call interrupt - // TODO: This code should move to the end of the function, or even be outlined so that it can be shared by multiple interruptible instructions - IrCallWrapperX64 callWrap(regs, build); - callWrap.addArgument(SizeX64::qword, rState); - callWrap.addArgument(SizeX64::dword, -1); - callWrap.call(tmp.release()); - - emitUpdateBase(build); // interrupt may have reallocated stack + build.mov(rArg1, rState); + build.mov(dwordReg(rArg2), -1); + build.call(rax); // Check if we need to exit build.mov(al, byte[rState + offsetof(lua_State, status)]); @@ -322,6 +317,10 @@ void emitInterrupt(IrRegAllocX64& regs, AssemblyBuilderX64& build, int pcpos) emitExit(build, /* continueInVm */ false); build.setLabel(skip); + + emitUpdateBase(build); // interrupt may have reallocated stack + + build.jmp(rbx); } void emitFallback(IrRegAllocX64& regs, AssemblyBuilderX64& build, int offset, int pcpos) @@ -354,14 +353,15 @@ void emitContinueCallInVm(AssemblyBuilderX64& build) void emitReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers) { - // input: ci in r8, res in rdi, number of written values in ecx - RegisterX64 ci = r8; + // input: res in rdi, number of written values in ecx RegisterX64 res = rdi; RegisterX64 written = ecx; + RegisterX64 ci = r8; RegisterX64 cip = r9; RegisterX64 nresults = esi; + build.mov(ci, qword[rState + offsetof(lua_State, ci)]); build.lea(cip, addr[ci - sizeof(CallInfo)]); // nresults = ci->nresults diff --git a/CodeGen/src/EmitCommonX64.h b/CodeGen/src/EmitCommonX64.h index eb4532a0d..37be73fda 100644 --- a/CodeGen/src/EmitCommonX64.h +++ b/CodeGen/src/EmitCommonX64.h @@ -53,31 +53,6 @@ constexpr OperandX64 sCode = qword[rsp + kStackSize + 8]; // Instruction* cod constexpr OperandX64 sTemporarySlot = addr[rsp + kStackSize + 16]; constexpr OperandX64 sSpillArea = addr[rsp + kStackSize + 24]; -// TODO: These should be replaced with a portable call function that checks the ABI at runtime and reorders moves accordingly to avoid conflicts -#if defined(_WIN32) - -constexpr RegisterX64 rArg1 = rcx; -constexpr RegisterX64 rArg2 = rdx; -constexpr RegisterX64 rArg3 = r8; -constexpr RegisterX64 rArg4 = r9; -constexpr RegisterX64 rArg5 = noreg; -constexpr RegisterX64 rArg6 = noreg; -constexpr OperandX64 sArg5 = qword[rsp + 32]; -constexpr OperandX64 sArg6 = qword[rsp + 40]; - -#else - -constexpr RegisterX64 rArg1 = rdi; -constexpr RegisterX64 rArg2 = rsi; -constexpr RegisterX64 rArg3 = rdx; -constexpr RegisterX64 rArg4 = rcx; -constexpr RegisterX64 rArg5 = r8; -constexpr RegisterX64 rArg6 = r9; -constexpr OperandX64 sArg5 = noreg; -constexpr OperandX64 sArg6 = noreg; - -#endif - inline OperandX64 luauReg(int ri) { return xmmword[rBase + ri * sizeof(TValue)]; @@ -202,7 +177,7 @@ void callStepGc(IrRegAllocX64& regs, AssemblyBuilderX64& build); void emitExit(AssemblyBuilderX64& build, bool continueInVm); void emitUpdateBase(AssemblyBuilderX64& build); -void emitInterrupt(IrRegAllocX64& regs, AssemblyBuilderX64& build, int pcpos); +void emitInterrupt(AssemblyBuilderX64& build); void emitFallback(IrRegAllocX64& regs, AssemblyBuilderX64& build, int offset, int pcpos); void emitContinueCallInVm(AssemblyBuilderX64& build); diff --git a/CodeGen/src/EmitInstructionX64.cpp b/CodeGen/src/EmitInstructionX64.cpp index 5d1c642fe..61d5ac63e 100644 --- a/CodeGen/src/EmitInstructionX64.cpp +++ b/CodeGen/src/EmitInstructionX64.cpp @@ -18,6 +18,12 @@ namespace X64 void emitInstCall(AssemblyBuilderX64& build, ModuleHelpers& helpers, int ra, int nparams, int nresults) { + // TODO: This should use IrCallWrapperX64 + RegisterX64 rArg1 = (build.abi == ABIX64::Windows) ? rcx : rdi; + RegisterX64 rArg2 = (build.abi == ABIX64::Windows) ? rdx : rsi; + RegisterX64 rArg3 = (build.abi == ABIX64::Windows) ? r8 : rdx; + RegisterX64 rArg4 = (build.abi == ABIX64::Windows) ? r9 : rcx; + build.mov(rArg1, rState); build.lea(rArg2, luauRegAddress(ra)); @@ -163,20 +169,34 @@ void emitInstCall(AssemblyBuilderX64& build, ModuleHelpers& helpers, int ra, int } } -void emitInstReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers, int ra, int actualResults) +void emitInstReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers, int ra, int actualResults, bool functionVariadic) { - RegisterX64 ci = r8; RegisterX64 res = rdi; RegisterX64 written = ecx; - build.mov(ci, qword[rState + offsetof(lua_State, ci)]); - build.mov(res, qword[ci + offsetof(CallInfo, func)]); + if (functionVariadic) + { + build.mov(res, qword[rState + offsetof(lua_State, ci)]); + build.mov(res, qword[res + offsetof(CallInfo, func)]); + } + else if (actualResults != 1) + build.lea(res, addr[rBase - sizeof(TValue)]); // invariant: ci->func + 1 == ci->base for non-variadic frames if (actualResults == 0) { build.xor_(written, written); build.jmp(helpers.return_); } + else if (actualResults == 1 && !functionVariadic) + { + // fast path: minimizes res adjustments + // note that we skipped res computation for this specific case above + build.vmovups(xmm0, luauReg(ra)); + build.vmovups(xmmword[rBase - sizeof(TValue)], xmm0); + build.mov(res, rBase); + build.mov(written, 1); + build.jmp(helpers.return_); + } else if (actualResults >= 1 && actualResults <= 3) { for (int r = 0; r < actualResults; ++r) @@ -206,8 +226,11 @@ void emitInstReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers, int ra, i Label repeatValueLoop, exitValueLoop; - build.cmp(vali, valend); - build.jcc(ConditionX64::NotBelow, exitValueLoop); + if (actualResults == LUA_MULTRET) + { + build.cmp(vali, valend); + build.jcc(ConditionX64::NotBelow, exitValueLoop); + } build.setLabel(repeatValueLoop); build.vmovups(xmm0, xmmword[vali]); @@ -225,6 +248,11 @@ void emitInstReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers, int ra, i void emitInstSetList(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int rb, int count, uint32_t index) { + // TODO: This should use IrCallWrapperX64 + RegisterX64 rArg1 = (build.abi == ABIX64::Windows) ? rcx : rdi; + RegisterX64 rArg2 = (build.abi == ABIX64::Windows) ? rdx : rsi; + RegisterX64 rArg3 = (build.abi == ABIX64::Windows) ? r8 : rdx; + OperandX64 last = index + count - 1; // Using non-volatile 'rbx' for dynamic 'count' value (for LUA_MULTRET) to skip later recomputation @@ -327,6 +355,12 @@ void emitInstForGLoop(AssemblyBuilderX64& build, int ra, int aux, Label& loopRep // ipairs-style traversal is handled in IR LUAU_ASSERT(aux >= 0); + // TODO: This should use IrCallWrapperX64 + RegisterX64 rArg1 = (build.abi == ABIX64::Windows) ? rcx : rdi; + RegisterX64 rArg2 = (build.abi == ABIX64::Windows) ? rdx : rsi; + RegisterX64 rArg3 = (build.abi == ABIX64::Windows) ? r8 : rdx; + RegisterX64 rArg4 = (build.abi == ABIX64::Windows) ? r9 : rcx; + // This is a fast-path for builtin table iteration, tag check for 'ra' has to be performed before emitting this instruction // Registers are chosen in this way to simplify fallback code for the node part diff --git a/CodeGen/src/EmitInstructionX64.h b/CodeGen/src/EmitInstructionX64.h index 84fe11309..b248b7e81 100644 --- a/CodeGen/src/EmitInstructionX64.h +++ b/CodeGen/src/EmitInstructionX64.h @@ -18,7 +18,7 @@ class AssemblyBuilderX64; struct IrRegAllocX64; void emitInstCall(AssemblyBuilderX64& build, ModuleHelpers& helpers, int ra, int nparams, int nresults); -void emitInstReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers, int ra, int actualResults); +void emitInstReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers, int ra, int actualResults, bool functionVariadic); void emitInstSetList(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int rb, int count, uint32_t index); void emitInstForGLoop(AssemblyBuilderX64& build, int ra, int aux, Label& loopRepeat); diff --git a/CodeGen/src/IrBuilder.cpp b/CodeGen/src/IrBuilder.cpp index 6ab5e249f..98db2977e 100644 --- a/CodeGen/src/IrBuilder.cpp +++ b/CodeGen/src/IrBuilder.cpp @@ -429,6 +429,7 @@ void IrBuilder::beginBlock(IrOp block) LUAU_ASSERT(target.start == ~0u || target.start == uint32_t(function.instructions.size())); target.start = uint32_t(function.instructions.size()); + target.sortkey = target.start; inTerminatedBlock = false; } diff --git a/CodeGen/src/IrDump.cpp b/CodeGen/src/IrDump.cpp index 7ea9b7904..09cafbaa8 100644 --- a/CodeGen/src/IrDump.cpp +++ b/CodeGen/src/IrDump.cpp @@ -656,28 +656,23 @@ std::string dump(const IrFunction& function) return result; } -std::string toDot(const IrFunction& function, bool includeInst) +static void appendLabelRegset(IrToStringContext& ctx, const std::vector& regSets, size_t blockIdx, const char* name) { - std::string result; - IrToStringContext ctx{result, function.blocks, function.constants, function.cfg}; + if (blockIdx < regSets.size()) + { + const RegisterSet& rs = regSets[blockIdx]; - auto appendLabelRegset = [&ctx](const std::vector& regSets, size_t blockIdx, const char* name) { - if (blockIdx < regSets.size()) + if (rs.regs.any() || rs.varargSeq) { - const RegisterSet& rs = regSets[blockIdx]; - - if (rs.regs.any() || rs.varargSeq) - { - append(ctx.result, "|{%s|", name); - appendRegisterSet(ctx, rs, "|"); - append(ctx.result, "}"); - } + append(ctx.result, "|{%s|", name); + appendRegisterSet(ctx, rs, "|"); + append(ctx.result, "}"); } - }; - - append(ctx.result, "digraph CFG {\n"); - append(ctx.result, "node[shape=record]\n"); + } +} +static void appendBlocks(IrToStringContext& ctx, const IrFunction& function, bool includeInst, bool includeIn, bool includeOut, bool includeDef) +{ for (size_t i = 0; i < function.blocks.size(); i++) { const IrBlock& block = function.blocks[i]; @@ -692,7 +687,8 @@ std::string toDot(const IrFunction& function, bool includeInst) append(ctx.result, "label=\"{"); toString(ctx, block, uint32_t(i)); - appendLabelRegset(ctx.cfg.in, i, "in"); + if (includeIn) + appendLabelRegset(ctx, ctx.cfg.in, i, "in"); if (includeInst && block.start != ~0u) { @@ -709,11 +705,25 @@ std::string toDot(const IrFunction& function, bool includeInst) } } - appendLabelRegset(ctx.cfg.def, i, "def"); - appendLabelRegset(ctx.cfg.out, i, "out"); + if (includeDef) + appendLabelRegset(ctx, ctx.cfg.def, i, "def"); + + if (includeOut) + appendLabelRegset(ctx, ctx.cfg.out, i, "out"); append(ctx.result, "}\"];\n"); } +} + +std::string toDot(const IrFunction& function, bool includeInst) +{ + std::string result; + IrToStringContext ctx{result, function.blocks, function.constants, function.cfg}; + + append(ctx.result, "digraph CFG {\n"); + append(ctx.result, "node[shape=record]\n"); + + appendBlocks(ctx, function, includeInst, /* includeIn */ true, /* includeOut */ true, /* includeDef */ true); for (size_t i = 0; i < function.blocks.size(); i++) { @@ -750,6 +760,107 @@ std::string toDot(const IrFunction& function, bool includeInst) return result; } +std::string toDotCfg(const IrFunction& function) +{ + std::string result; + IrToStringContext ctx{result, function.blocks, function.constants, function.cfg}; + + append(ctx.result, "digraph CFG {\n"); + append(ctx.result, "node[shape=record]\n"); + + appendBlocks(ctx, function, /* includeInst */ false, /* includeIn */ false, /* includeOut */ false, /* includeDef */ true); + + for (size_t i = 0; i < function.blocks.size() && i < ctx.cfg.successorsOffsets.size(); i++) + { + BlockIteratorWrapper succ = successors(ctx.cfg, unsigned(i)); + + for (uint32_t target : succ) + append(ctx.result, "b%u -> b%u;\n", unsigned(i), target); + } + + append(ctx.result, "}\n"); + + return result; +} + +std::string toDotDjGraph(const IrFunction& function) +{ + std::string result; + IrToStringContext ctx{result, function.blocks, function.constants, function.cfg}; + + append(ctx.result, "digraph CFG {\n"); + + for (size_t i = 0; i < ctx.blocks.size(); i++) + { + const IrBlock& block = ctx.blocks[i]; + + append(ctx.result, "b%u [", unsigned(i)); + + if (block.kind == IrBlockKind::Fallback) + append(ctx.result, "style=filled;fillcolor=salmon;"); + else if (block.kind == IrBlockKind::Bytecode) + append(ctx.result, "style=filled;fillcolor=palegreen;"); + + append(ctx.result, "label=\""); + toString(ctx, block, uint32_t(i)); + append(ctx.result, "\"];\n"); + } + + // Layer by depth in tree + uint32_t depth = 0; + bool found = true; + + while (found) + { + found = false; + + append(ctx.result, "{rank = same;"); + for (size_t i = 0; i < ctx.cfg.domOrdering.size(); i++) + { + if (ctx.cfg.domOrdering[i].depth == depth) + { + append(ctx.result, "b%u;", unsigned(i)); + found = true; + } + } + append(ctx.result, "}\n"); + + depth++; + } + + for (size_t i = 0; i < ctx.cfg.domChildrenOffsets.size(); i++) + { + BlockIteratorWrapper dom = domChildren(ctx.cfg, unsigned(i)); + + for (uint32_t target : dom) + append(ctx.result, "b%u -> b%u;\n", unsigned(i), target); + + // Join edges are all successor edges that do not strongly dominate + BlockIteratorWrapper succ = successors(ctx.cfg, unsigned(i)); + + for (uint32_t successor : succ) + { + bool found = false; + + for (uint32_t target : dom) + { + if (target == successor) + { + found = true; + break; + } + } + + if (!found) + append(ctx.result, "b%u -> b%u [style=dotted];\n", unsigned(i), successor); + } + } + + append(ctx.result, "}\n"); + + return result; +} + std::string dumpDot(const IrFunction& function, bool includeInst) { std::string result = toDot(function, includeInst); diff --git a/CodeGen/src/IrLoweringA64.cpp b/CodeGen/src/IrLoweringA64.cpp index 5c29ad413..94c46dbfd 100644 --- a/CodeGen/src/IrLoweringA64.cpp +++ b/CodeGen/src/IrLoweringA64.cpp @@ -1165,25 +1165,17 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) } case IrCmd::INTERRUPT: { - RegisterA64 temp = regs.allocTemp(KindA64::x); - - Label skip, next; - build.ldr(temp, mem(rState, offsetof(lua_State, global))); - build.ldr(temp, mem(temp, offsetof(global_State, cb.interrupt))); - build.cbz(temp, skip); - - size_t spills = regs.spill(build, index); + regs.spill(build, index); - // Jump to outlined interrupt handler, it will give back control to x1 - build.mov(x0, (uintOp(inst.a) + 1) * sizeof(Instruction)); - build.adr(x1, next); - build.b(helpers.interrupt); + Label self; - build.setLabel(next); + build.ldr(x0, mem(rState, offsetof(lua_State, global))); + build.ldr(x0, mem(x0, offsetof(global_State, cb.interrupt))); + build.cbnz(x0, self); - regs.restore(build, spills); // need to restore before skip so that registers are in a consistent state + Label next = build.setLabel(); - build.setLabel(skip); + interruptHandlers.push_back({self, uintOp(inst.a), next}); break; } case IrCmd::CHECK_GC: @@ -1733,6 +1725,20 @@ void IrLoweringA64::finishBlock() regs.assertNoSpills(); } +void IrLoweringA64::finishFunction() +{ + if (build.logText) + build.logAppend("; interrupt handlers\n"); + + for (InterruptHandler& handler : interruptHandlers) + { + build.setLabel(handler.self); + build.mov(x0, (handler.pcpos + 1) * sizeof(Instruction)); + build.adr(x1, handler.next); + build.b(helpers.interrupt); + } +} + bool IrLoweringA64::hasError() const { return error; diff --git a/CodeGen/src/IrLoweringA64.h b/CodeGen/src/IrLoweringA64.h index 1df09bd37..fc228cf16 100644 --- a/CodeGen/src/IrLoweringA64.h +++ b/CodeGen/src/IrLoweringA64.h @@ -27,6 +27,7 @@ struct IrLoweringA64 void lowerInst(IrInst& inst, uint32_t index, IrBlock& next); void finishBlock(); + void finishFunction(); bool hasError() const; @@ -53,6 +54,13 @@ struct IrLoweringA64 IrBlock& blockOp(IrOp op) const; Label& labelOp(IrOp op) const; + struct InterruptHandler + { + Label self; + unsigned int pcpos; + Label next; + }; + AssemblyBuilderA64& build; ModuleHelpers& helpers; NativeState& data; @@ -63,6 +71,8 @@ struct IrLoweringA64 IrValueLocationTracking valueTracker; + std::vector interruptHandlers; + bool error = false; }; diff --git a/CodeGen/src/IrLoweringX64.cpp b/CodeGen/src/IrLoweringX64.cpp index b9c35df04..320cb0791 100644 --- a/CodeGen/src/IrLoweringX64.cpp +++ b/CodeGen/src/IrLoweringX64.cpp @@ -958,8 +958,27 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) break; } case IrCmd::INTERRUPT: - emitInterrupt(regs, build, uintOp(inst.a)); + { + unsigned pcpos = uintOp(inst.a); + + // We unconditionally spill values here because that allows us to ignore register state when we synthesize interrupt handler + // This can be changed in the future if we can somehow record interrupt handler code separately + // Since interrupts are loop edges or call/ret, we don't have a significant opportunity for register reuse here anyway + regs.preserveAndFreeInstValues(); + + ScopedRegX64 tmp{regs, SizeX64::qword}; + + Label self; + + build.mov(tmp.reg, qword[rState + offsetof(lua_State, global)]); + build.cmp(qword[tmp.reg + offsetof(global_State, cb.interrupt)], 0); + build.jcc(ConditionX64::NotEqual, self); + + Label next = build.setLabel(); + + interruptHandlers.push_back({self, pcpos, next}); break; + } case IrCmd::CHECK_GC: callStepGc(regs, build); break; @@ -991,7 +1010,6 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) } case IrCmd::SET_SAVEDPC: { - // This is like emitSetSavedPc, but using register allocation instead of relying on rax/rdx ScopedRegX64 tmp1{regs, SizeX64::qword}; ScopedRegX64 tmp2{regs, SizeX64::qword}; @@ -1048,7 +1066,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) case IrCmd::RETURN: regs.assertAllFree(); regs.assertNoSpills(); - emitInstReturn(build, helpers, vmRegOp(inst.a), intOp(inst.b)); + emitInstReturn(build, helpers, vmRegOp(inst.a), intOp(inst.b), function.variadic); break; case IrCmd::FORGLOOP: regs.assertAllFree(); @@ -1350,6 +1368,20 @@ void IrLoweringX64::finishBlock() regs.assertNoSpills(); } +void IrLoweringX64::finishFunction() +{ + if (build.logText) + build.logAppend("; interrupt handlers\n"); + + for (InterruptHandler& handler : interruptHandlers) + { + build.setLabel(handler.self); + build.mov(rax, handler.pcpos + 1); + build.lea(rbx, handler.next); + build.jmp(helpers.interrupt); + } +} + bool IrLoweringX64::hasError() const { // If register allocator had to use more stack slots than we have available, this function can't run natively diff --git a/CodeGen/src/IrLoweringX64.h b/CodeGen/src/IrLoweringX64.h index cab4a85f5..a375a334c 100644 --- a/CodeGen/src/IrLoweringX64.h +++ b/CodeGen/src/IrLoweringX64.h @@ -29,6 +29,7 @@ struct IrLoweringX64 void lowerInst(IrInst& inst, uint32_t index, IrBlock& next); void finishBlock(); + void finishFunction(); bool hasError() const; @@ -53,6 +54,13 @@ struct IrLoweringX64 IrBlock& blockOp(IrOp op) const; Label& labelOp(IrOp op) const; + struct InterruptHandler + { + Label self; + unsigned int pcpos; + Label next; + }; + AssemblyBuilderX64& build; ModuleHelpers& helpers; NativeState& data; @@ -62,6 +70,8 @@ struct IrLoweringX64 IrRegAllocX64 regs; IrValueLocationTracking valueTracker; + + std::vector interruptHandlers; }; } // namespace X64 diff --git a/CodeGen/src/OptimizeConstProp.cpp b/CodeGen/src/OptimizeConstProp.cpp index 338bb49f9..b779fb4b7 100644 --- a/CodeGen/src/OptimizeConstProp.cpp +++ b/CodeGen/src/OptimizeConstProp.cpp @@ -1059,16 +1059,21 @@ static void tryCreateLinearBlock(IrBuilder& build, std::vector& visited // TODO: using values from the first block can cause 'live out' of the linear block predecessor to not have all required registers constPropInBlock(build, startingBlock, state); - // Veryfy that target hasn't changed + // Verify that target hasn't changed LUAU_ASSERT(function.instructions[startingBlock.finish].a.index == targetBlockIdx); + // Note: using startingBlock after this line is unsafe as the reference may be reallocated by build.block() below + uint32_t startingInsn = startingBlock.start; + // Create new linearized block into which we are going to redirect starting block jump IrOp newBlock = build.block(IrBlockKind::Linearized); visited.push_back(false); - // TODO: placement of linear blocks in final lowering is sub-optimal, it should follow our predecessor build.beginBlock(newBlock); + // By default, blocks are ordered according to start instruction; we alter sort order to make sure linearized block is placed right after the starting block + function.blocks[newBlock.index].sortkey = startingInsn + 1; + replace(function, termInst.a, newBlock); // Clone the collected path into our fresh block diff --git a/Common/include/Luau/Bytecode.h b/Common/include/Luau/Bytecode.h index 54086d53f..eab57b17a 100644 --- a/Common/include/Luau/Bytecode.h +++ b/Common/include/Luau/Bytecode.h @@ -413,8 +413,10 @@ enum LuauBytecodeTag { // Bytecode version; runtime supports [MIN, MAX], compiler emits TARGET by default but may emit a higher version when flags are enabled LBC_VERSION_MIN = 3, - LBC_VERSION_MAX = 3, + LBC_VERSION_MAX = 4, LBC_VERSION_TARGET = 3, + // Type encoding version + LBC_TYPE_VERSION = 1, // Types of constant table entries LBC_CONSTANT_NIL = 0, LBC_CONSTANT_BOOLEAN, @@ -425,6 +427,25 @@ enum LuauBytecodeTag LBC_CONSTANT_CLOSURE, }; +// Type table tags +enum LuauBytecodeEncodedType +{ + LBC_TYPE_NIL = 0, + LBC_TYPE_BOOLEAN, + LBC_TYPE_NUMBER, + LBC_TYPE_STRING, + LBC_TYPE_TABLE, + LBC_TYPE_FUNCTION, + LBC_TYPE_THREAD, + LBC_TYPE_USERDATA, + LBC_TYPE_VECTOR, + + LBC_TYPE_ANY = 15, + LBC_TYPE_OPTIONAL_BIT = 1 << 7, + + LBC_TYPE_INVALID = 256, +}; + // Builtin function ids, used in LOP_FASTCALL enum LuauBuiltinFunction { diff --git a/Compiler/include/Luau/BytecodeBuilder.h b/Compiler/include/Luau/BytecodeBuilder.h index ba4232a01..f3c2f47d7 100644 --- a/Compiler/include/Luau/BytecodeBuilder.h +++ b/Compiler/include/Luau/BytecodeBuilder.h @@ -74,6 +74,8 @@ class BytecodeBuilder void foldJumps(); void expandJumps(); + void setFunctionTypeInfo(std::string value); + void setDebugFunctionName(StringRef name); void setDebugFunctionLineDefined(int line); void setDebugLine(int line); @@ -118,6 +120,7 @@ class BytecodeBuilder std::string dumpFunction(uint32_t id) const; std::string dumpEverything() const; std::string dumpSourceRemarks() const; + std::string dumpTypeInfo() const; void annotateInstruction(std::string& result, uint32_t fid, uint32_t instpos) const; @@ -132,6 +135,7 @@ class BytecodeBuilder static std::string getError(const std::string& message); static uint8_t getVersion(); + static uint8_t getTypeEncodingVersion(); private: struct Constant @@ -186,6 +190,7 @@ class BytecodeBuilder std::string dump; std::string dumpname; std::vector dumpinstoffs; + std::string typeinfo; }; struct DebugLocal diff --git a/Compiler/src/BytecodeBuilder.cpp b/Compiler/src/BytecodeBuilder.cpp index e2b769ec6..9296519ba 100644 --- a/Compiler/src/BytecodeBuilder.cpp +++ b/Compiler/src/BytecodeBuilder.cpp @@ -6,6 +6,8 @@ #include #include +LUAU_FASTFLAGVARIABLE(BytecodeVersion4, false) + namespace Luau { @@ -513,6 +515,11 @@ bool BytecodeBuilder::patchSkipC(size_t jumpLabel, size_t targetLabel) return true; } +void BytecodeBuilder::setFunctionTypeInfo(std::string value) +{ + functions[currentFunction].typeinfo = std::move(value); +} + void BytecodeBuilder::setDebugFunctionName(StringRef name) { unsigned int index = addStringTableEntry(name); @@ -606,6 +613,13 @@ void BytecodeBuilder::finalize() bytecode = char(version); + if (FFlag::BytecodeVersion4) + { + uint8_t typesversion = getTypeEncodingVersion(); + LUAU_ASSERT(typesversion == 1); + writeByte(bytecode, typesversion); + } + writeStringTable(bytecode); writeVarInt(bytecode, uint32_t(functions.size())); @@ -628,6 +642,14 @@ void BytecodeBuilder::writeFunction(std::string& ss, uint32_t id) const writeByte(ss, func.numupvalues); writeByte(ss, func.isvararg); + if (FFlag::BytecodeVersion4) + { + writeByte(ss, 0); // Reserved for cgflags + + writeVarInt(ss, uint32_t(func.typeinfo.size())); + ss.append(func.typeinfo); + } + // instructions writeVarInt(ss, uint32_t(insns.size())); @@ -1092,9 +1114,18 @@ std::string BytecodeBuilder::getError(const std::string& message) uint8_t BytecodeBuilder::getVersion() { // This function usually returns LBC_VERSION_TARGET but may sometimes return a higher number (within LBC_VERSION_MIN/MAX) under fast flags + + if (FFlag::BytecodeVersion4) + return 4; + return LBC_VERSION_TARGET; } +uint8_t BytecodeBuilder::getTypeEncodingVersion() +{ + return LBC_TYPE_VERSION; +} + #ifdef LUAU_ASSERTENABLED void BytecodeBuilder::validate() const { @@ -2269,6 +2300,75 @@ std::string BytecodeBuilder::dumpSourceRemarks() const return result; } +static const char* getBaseTypeString(uint8_t type) +{ + uint8_t tag = type & ~LBC_TYPE_OPTIONAL_BIT; + switch (tag) + { + case LBC_TYPE_NIL: + return "nil"; + case LBC_TYPE_BOOLEAN: + return "boolean"; + case LBC_TYPE_NUMBER: + return "number"; + case LBC_TYPE_STRING: + return "string"; + case LBC_TYPE_TABLE: + return "{ }"; + case LBC_TYPE_FUNCTION: + return "function( )"; + case LBC_TYPE_THREAD: + return "thread"; + case LBC_TYPE_USERDATA: + return "userdata"; + case LBC_TYPE_VECTOR: + return "vector"; + case LBC_TYPE_ANY: + return "any"; + } + + LUAU_ASSERT(!"Unhandled type in getBaseTypeString"); + return nullptr; +} + +std::string BytecodeBuilder::dumpTypeInfo() const +{ + std::string result; + + for (size_t i = 0; i < functions.size(); ++i) + { + const std::string& typeinfo = functions[i].typeinfo; + if (typeinfo.empty()) + continue; + + uint8_t encodedType = typeinfo[0]; + + LUAU_ASSERT(encodedType == LBC_TYPE_FUNCTION); + + formatAppend(result, "%zu: function(", i); + + LUAU_ASSERT(typeinfo.size() >= 2); + + uint8_t numparams = typeinfo[1]; + + LUAU_ASSERT(size_t(1 + numparams - 1) < typeinfo.size()); + + for (uint8_t i = 0; i < numparams; ++i) + { + uint8_t et = typeinfo[2 + i]; + const char* optional = (et & LBC_TYPE_OPTIONAL_BIT) ? "?" : ""; + formatAppend(result, "%s%s", getBaseTypeString(et), optional); + + if (i + 1 != numparams) + formatAppend(result, ", "); + } + + formatAppend(result, ")\n"); + } + + return result; +} + void BytecodeBuilder::annotateInstruction(std::string& result, uint32_t fid, uint32_t instpos) const { if ((dumpFlags & Dump_Code) == 0) diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index 64667221f..8dd9876ca 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -10,6 +10,7 @@ #include "ConstantFolding.h" #include "CostModel.h" #include "TableShape.h" +#include "Types.h" #include "ValueTracking.h" #include @@ -25,7 +26,8 @@ LUAU_FASTINTVARIABLE(LuauCompileInlineThreshold, 25) LUAU_FASTINTVARIABLE(LuauCompileInlineThresholdMaxBoost, 300) LUAU_FASTINTVARIABLE(LuauCompileInlineDepth, 5) -LUAU_FASTFLAGVARIABLE(LuauCompileInlineDefer, false) +LUAU_FASTFLAGVARIABLE(CompileFunctionType, false) +LUAU_FASTFLAG(BytecodeVersion4) namespace Luau { @@ -202,6 +204,13 @@ struct Compiler setDebugLine(func); + if (FFlag::BytecodeVersion4 && FFlag::CompileFunctionType) + { + std::string funcType = getFunctionType(func); + if (!funcType.empty()) + bytecode.setFunctionTypeInfo(std::move(funcType)); + } + if (func->vararg) bytecode.emitABC(LOP_PREPVARARGS, uint8_t(self + func->args.size), 0, 0); @@ -560,15 +569,7 @@ struct Compiler size_t oldLocals = localStack.size(); std::vector args; - if (FFlag::LuauCompileInlineDefer) - { - args.reserve(func->args.size); - } - else - { - // note that we push the frame early; this is needed to block recursive inline attempts - inlineFrames.push_back({func, oldLocals, target, targetCount}); - } + args.reserve(func->args.size); // evaluate all arguments; note that we don't emit code for constant arguments (relying on constant folding) // note that compiler state (variable registers/values) does not change here - we defer that to a separate loop below to handle nested calls @@ -590,16 +591,8 @@ struct Compiler else LUAU_ASSERT(!"Unexpected expression type"); - if (FFlag::LuauCompileInlineDefer) - { - for (size_t j = i; j < func->args.size; ++j) - args.push_back({func->args.data[j], uint8_t(reg + (j - i))}); - } - else - { - for (size_t j = i; j < func->args.size; ++j) - pushLocal(func->args.data[j], uint8_t(reg + (j - i))); - } + for (size_t j = i; j < func->args.size; ++j) + args.push_back({func->args.data[j], uint8_t(reg + (j - i))}); // all remaining function arguments have been allocated and assigned to break; @@ -614,26 +607,17 @@ struct Compiler else bytecode.emitABC(LOP_LOADNIL, reg, 0, 0); - if (FFlag::LuauCompileInlineDefer) - args.push_back({var, reg}); - else - pushLocal(var, reg); + args.push_back({var, reg}); } else if (arg == nullptr) { // since the argument is not mutated, we can simply fold the value into the expressions that need it - if (FFlag::LuauCompileInlineDefer) - args.push_back({var, kInvalidReg, {Constant::Type_Nil}}); - else - locstants[var] = {Constant::Type_Nil}; + args.push_back({var, kInvalidReg, {Constant::Type_Nil}}); } else if (const Constant* cv = constants.find(arg); cv && cv->type != Constant::Type_Unknown) { // since the argument is not mutated, we can simply fold the value into the expressions that need it - if (FFlag::LuauCompileInlineDefer) - args.push_back({var, kInvalidReg, *cv}); - else - locstants[var] = *cv; + args.push_back({var, kInvalidReg, *cv}); } else { @@ -643,20 +627,14 @@ struct Compiler // if the argument is a local that isn't mutated, we will simply reuse the existing register if (int reg = le ? getExprLocalReg(le) : -1; reg >= 0 && (!lv || !lv->written)) { - if (FFlag::LuauCompileInlineDefer) - args.push_back({var, uint8_t(reg)}); - else - pushLocal(var, uint8_t(reg)); + args.push_back({var, uint8_t(reg)}); } else { uint8_t temp = allocReg(arg, 1); compileExprTemp(arg, temp); - if (FFlag::LuauCompileInlineDefer) - args.push_back({var, temp}); - else - pushLocal(var, temp); + args.push_back({var, temp}); } } } @@ -668,19 +646,16 @@ struct Compiler compileExprAuto(expr->args.data[i], rsi); } - if (FFlag::LuauCompileInlineDefer) - { - // apply all evaluated arguments to the compiler state - // note: locals use current startpc for debug info, although some of them have been computed earlier; this is similar to compileStatLocal - for (InlineArg& arg : args) - if (arg.value.type == Constant::Type_Unknown) - pushLocal(arg.local, arg.reg); - else - locstants[arg.local] = arg.value; + // apply all evaluated arguments to the compiler state + // note: locals use current startpc for debug info, although some of them have been computed earlier; this is similar to compileStatLocal + for (InlineArg& arg : args) + if (arg.value.type == Constant::Type_Unknown) + pushLocal(arg.local, arg.reg); + else + locstants[arg.local] = arg.value; - // the inline frame will be used to compile return statements as well as to reject recursive inlining attempts - inlineFrames.push_back({func, oldLocals, target, targetCount}); - } + // the inline frame will be used to compile return statements as well as to reject recursive inlining attempts + inlineFrames.push_back({func, oldLocals, target, targetCount}); // fold constant values updated above into expressions in the function body foldConstants(constants, variables, locstants, builtinsFold, func->body); diff --git a/Compiler/src/Types.cpp b/Compiler/src/Types.cpp new file mode 100644 index 000000000..02041986f --- /dev/null +++ b/Compiler/src/Types.cpp @@ -0,0 +1,106 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/BytecodeBuilder.h" + +#include "Types.h" + +namespace Luau +{ + +static LuauBytecodeEncodedType getType(AstType* ty) +{ + if (AstTypeReference* ref = ty->as()) + { + if (ref->name == "nil") + return LBC_TYPE_NIL; + else if (ref->name == "boolean") + return LBC_TYPE_BOOLEAN; + else if (ref->name == "number") + return LBC_TYPE_NUMBER; + else if (ref->name == "string") + return LBC_TYPE_STRING; + else if (ref->name == "thread") + return LBC_TYPE_THREAD; + else if (ref->name == "any" || ref->name == "unknown") + return LBC_TYPE_ANY; + } + else if (AstTypeTable* table = ty->as()) + { + return LBC_TYPE_TABLE; + } + else if (AstTypeFunction* func = ty->as()) + { + return LBC_TYPE_FUNCTION; + } + else if (AstTypeUnion* un = ty->as()) + { + bool optional = false; + LuauBytecodeEncodedType type = LBC_TYPE_INVALID; + + for (AstType* ty : un->types) + { + LuauBytecodeEncodedType et = getType(ty); + + if (et == LBC_TYPE_NIL) + { + optional = true; + continue; + } + + if (type == LBC_TYPE_INVALID) + { + type = et; + continue; + } + + if (type != et) + return LBC_TYPE_ANY; + } + + if (type == LBC_TYPE_INVALID) + return LBC_TYPE_ANY; + + return LuauBytecodeEncodedType(type | (optional && (type != LBC_TYPE_ANY) ? LBC_TYPE_OPTIONAL_BIT : 0)); + } + else if (AstTypeIntersection* inter = ty->as()) + { + return LBC_TYPE_ANY; + } + + return LBC_TYPE_ANY; +} + +std::string getFunctionType(const AstExprFunction* func) +{ + if (func->vararg || func->generics.size || func->genericPacks.size) + return {}; + + bool self = func->self != 0; + + std::string typeInfo; + typeInfo.reserve(func->args.size + self + 2); + + typeInfo.push_back(LBC_TYPE_FUNCTION); + typeInfo.push_back(uint8_t(self + func->args.size)); + + if (self) + typeInfo.push_back(LBC_TYPE_TABLE); + + bool haveNonAnyParam = false; + for (AstLocal* arg : func->args) + { + LuauBytecodeEncodedType ty = arg->annotation ? getType(arg->annotation) : LBC_TYPE_ANY; + + if (ty != LBC_TYPE_ANY) + haveNonAnyParam = true; + + typeInfo.push_back(ty); + } + + // If all parameters simplify to any, we can just omit type info for this function + if (!haveNonAnyParam) + return {}; + + return typeInfo; +} + +} // namespace Luau \ No newline at end of file diff --git a/Compiler/src/Types.h b/Compiler/src/Types.h new file mode 100644 index 000000000..1be9155f8 --- /dev/null +++ b/Compiler/src/Types.h @@ -0,0 +1,9 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Ast.h" + +namespace Luau +{ +std::string getFunctionType(const AstExprFunction* func); +} // namespace Luau diff --git a/Sources.cmake b/Sources.cmake index b1693c36d..5b9bd61eb 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -43,6 +43,7 @@ target_sources(Luau.Compiler PRIVATE Compiler/src/ConstantFolding.cpp Compiler/src/CostModel.cpp Compiler/src/TableShape.cpp + Compiler/src/Types.cpp Compiler/src/ValueTracking.cpp Compiler/src/lcode.cpp Compiler/src/Builtins.h @@ -50,6 +51,7 @@ target_sources(Luau.Compiler PRIVATE Compiler/src/ConstantFolding.h Compiler/src/CostModel.h Compiler/src/TableShape.h + Compiler/src/Types.h Compiler/src/ValueTracking.h ) diff --git a/VM/src/lvm.h b/VM/src/lvm.h index cfb6456b5..5ec7bc165 100644 --- a/VM/src/lvm.h +++ b/VM/src/lvm.h @@ -24,7 +24,6 @@ LUAI_FUNC void luaV_gettable(lua_State* L, const TValue* t, TValue* key, StkId v LUAI_FUNC void luaV_settable(lua_State* L, const TValue* t, TValue* key, StkId val); LUAI_FUNC void luaV_concat(lua_State* L, int total, int last); LUAI_FUNC void luaV_getimport(lua_State* L, Table* env, TValue* k, StkId res, uint32_t id, bool propagatenil); -LUAI_FUNC void luaV_getimport_dep(lua_State* L, Table* env, TValue* k, uint32_t id, bool propagatenil); LUAI_FUNC void luaV_prepareFORN(lua_State* L, StkId plimit, StkId pstep, StkId pinit); LUAI_FUNC void luaV_callTM(lua_State* L, int nparams, int res); LUAI_FUNC void luaV_tryfuncTM(lua_State* L, StkId func); diff --git a/VM/src/lvmexecute.cpp b/VM/src/lvmexecute.cpp index 79bf807b6..90c5a7e86 100644 --- a/VM/src/lvmexecute.cpp +++ b/VM/src/lvmexecute.cpp @@ -16,8 +16,6 @@ #include -LUAU_FASTFLAG(LuauGetImportDirect) - // Disable c99-designator to avoid the warning in CGOTO dispatch table #ifdef __clang__ #if __has_warning("-Wc99-designator") @@ -432,20 +430,8 @@ static void luau_execute(lua_State* L) { uint32_t aux = *pc++; - if (FFlag::LuauGetImportDirect) - { - VM_PROTECT(luaV_getimport(L, cl->env, k, ra, aux, /* propagatenil= */ false)); - VM_NEXT(); - } - else - { - VM_PROTECT(luaV_getimport_dep(L, cl->env, k, aux, /* propagatenil= */ false)); - ra = VM_REG(LUAU_INSN_A(insn)); // previous call may change the stack - - setobj2s(L, ra, L->top - 1); - L->top--; - VM_NEXT(); - } + VM_PROTECT(luaV_getimport(L, cl->env, k, ra, aux, /* propagatenil= */ false)); + VM_NEXT(); } } diff --git a/VM/src/lvmload.cpp b/VM/src/lvmload.cpp index f26cc05d7..edbe5035d 100644 --- a/VM/src/lvmload.cpp +++ b/VM/src/lvmload.cpp @@ -13,8 +13,6 @@ #include -LUAU_FASTFLAGVARIABLE(LuauGetImportDirect, false) - // TODO: RAII deallocation doesn't work for longjmp builds if a memory error happens template struct TempBuffer @@ -77,34 +75,6 @@ void luaV_getimport(lua_State* L, Table* env, TValue* k, StkId res, uint32_t id, luaV_gettable(L, res, &k[id2], res); } -void luaV_getimport_dep(lua_State* L, Table* env, TValue* k, uint32_t id, bool propagatenil) -{ - LUAU_ASSERT(!FFlag::LuauGetImportDirect); - - int count = id >> 30; - int id0 = count > 0 ? int(id >> 20) & 1023 : -1; - int id1 = count > 1 ? int(id >> 10) & 1023 : -1; - int id2 = count > 2 ? int(id) & 1023 : -1; - - // allocate a stack slot so that we can do table lookups - luaD_checkstack(L, 1); - setnilvalue(L->top); - L->top++; - - // global lookup into L->top-1 - TValue g; - sethvalue(L, &g, env); - luaV_gettable(L, &g, &k[id0], L->top - 1); - - // table lookup for id1 - if (id1 >= 0 && (!propagatenil || !ttisnil(L->top - 1))) - luaV_gettable(L, L->top - 1, &k[id1], L->top - 1); - - // table lookup for id2 - if (id2 >= 0 && (!propagatenil || !ttisnil(L->top - 1))) - luaV_gettable(L, L->top - 1, &k[id2], L->top - 1); -} - template static T read(const char* data, size_t size, size_t& offset) { @@ -153,17 +123,12 @@ static void resolveImportSafe(lua_State* L, Table* env, TValue* k, uint32_t id) // note: we call getimport with nil propagation which means that accesses to table chains like A.B.C will resolve in nil // this is technically not necessary but it reduces the number of exceptions when loading scripts that rely on getfenv/setfenv for global // injection - if (FFlag::LuauGetImportDirect) - { - // allocate a stack slot so that we can do table lookups - luaD_checkstack(L, 1); - setnilvalue(L->top); - L->top++; + // allocate a stack slot so that we can do table lookups + luaD_checkstack(L, 1); + setnilvalue(L->top); + L->top++; - luaV_getimport(L, L->gt, self->k, L->top - 1, self->id, /* propagatenil= */ true); - } - else - luaV_getimport_dep(L, L->gt, self->k, self->id, /* propagatenil= */ true); + luaV_getimport(L, L->gt, self->k, L->top - 1, self->id, /* propagatenil= */ true); } }; @@ -194,6 +159,8 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size uint8_t version = read(data, size, offset); + + // 0 means the rest of the bytecode is the error message if (version == 0) { @@ -221,6 +188,13 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size TString* source = luaS_new(L, chunkname); + + if (version >= 4) + { + uint8_t typesversion = read(data, size, offset); + LUAU_ASSERT(typesversion == 1); + } + // string table unsigned int stringCount = readVarInt(data, size, offset); TempBuffer strings(L, stringCount); @@ -248,6 +222,25 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size p->nups = read(data, size, offset); p->is_vararg = read(data, size, offset); + if (version >= 4) + { + uint8_t cgflags = read(data, size, offset); + LUAU_ASSERT(cgflags == 0); + + uint32_t typesize = readVarInt(data, size, offset); + + if (typesize) + { + uint8_t* types = (uint8_t*)data + offset; + + LUAU_ASSERT(typesize == unsigned(2 + p->numparams)); + LUAU_ASSERT(types[0] == LBC_TYPE_FUNCTION); + LUAU_ASSERT(types[1] == p->numparams); + + offset += typesize; + } + } + p->sizecode = readVarInt(data, size, offset); p->code = luaM_newarray(L, p->sizecode, Instruction, p->memcat); for (int j = 0; j < p->sizecode; ++j) diff --git a/bench/bench.py b/bench/bench.py index 547e0d38d..002dfadb5 100644 --- a/bench/bench.py +++ b/bench/bench.py @@ -73,7 +73,7 @@ def arrayRangeOffset(count, offset): return result -def getCallgrindOutput(lines): +def getCallgrindOutput(stdout, lines): result = [] name = None @@ -86,12 +86,36 @@ def getCallgrindOutput(lines): result += "|><|" + name + "|><|" + str(insn / CALLGRIND_INSN_PER_SEC * 1000.0) + "||_||" name = None + # If no results were found above, this may indicate the native executable running + # the benchmark doesn't have support for callgrind builtin. In that case just + # report the "totals" from the output file. + if len(result) == 0: + elements = stdout.decode('utf8').split("|><|") + if len(elements) >= 2: + name = elements[1] + + for l in lines: + if l.startswith("totals: "): + insn = int(l[8:]) + # Note: we only run each bench once under callgrind so we only report a single time per run; callgrind instruction count variance is ~0.01% so it might as well be zero + result += "|><|" + name + "|><|" + str(insn / CALLGRIND_INSN_PER_SEC * 1000.0) + "||_||" + return "".join(result) def conditionallyShowCommand(cmd): if arguments.show_commands: print(f'{colored(Color.BLUE, "EXECUTING")}: {cmd}') +def checkValgrindExecutable(): + """Return true if valgrind can be successfully spawned""" + try: + subprocess.check_call("valgrind --version", shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) + except: + print(f"{colored(Color.YELLOW, 'WARNING')}: Unable to spawn 'valgrind'. Please ensure valgrind is installed when using '--callgrind'.") + return False + + return True + def getVmOutput(cmd): if os.name == "nt": try: @@ -103,17 +127,24 @@ def getVmOutput(cmd): except: return "" elif arguments.callgrind: + if not checkValgrindExecutable(): + return "" + output_path = os.path.join(scriptdir, "callgrind.out") try: - fullCmd = "valgrind --tool=callgrind --callgrind-out-file=callgrind.out --combine-dumps=yes --dump-line=no " + cmd - conditionallyShowCommand(fullCmd) - subprocess.check_call(fullCmd, shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, cwd=scriptdir) - path = os.path.join(scriptdir, "callgrind.out") - with open(path, "r") as file: - lines = file.readlines() - os.unlink(path) - return getCallgrindOutput(lines) + os.unlink(output_path) # Remove stale output except: - return "" + pass + fullCmd = "valgrind --tool=callgrind --callgrind-out-file=callgrind.out --combine-dumps=yes --dump-line=no " + cmd + conditionallyShowCommand(fullCmd) + try: + output = subprocess.check_output(fullCmd, shell=True, stderr=subprocess.DEVNULL, cwd=scriptdir) + except subprocess.CalledProcessError as e: + print(f"{colored(Color.YELLOW, 'WARNING')}: Valgrind returned error code {e.returncode}") + output = e.output + with open(output_path, "r") as file: + lines = file.readlines() + os.unlink(output_path) + return getCallgrindOutput(output, lines) else: conditionallyShowCommand(cmd) with subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, cwd=scriptdir) as p: @@ -352,7 +383,7 @@ def analyzeResult(subdir, main, comparisons): if influxReporter != None: influxReporter.report_result(subdir, main.name, main.filename, "SUCCESS", main.min, main.avg, main.max, main.sampleConfidenceInterval, main.shortVm, main.vm) - print(colored(Color.YELLOW, 'SUCCESS') + ': {:<40}'.format(main.name) + ": " + '{:8.3f}'.format(main.avg) + "ms +/- " + + print(colored(Color.GREEN, 'SUCCESS') + ': {:<40}'.format(main.name) + ": " + '{:8.3f}'.format(main.avg) + "ms +/- " + '{:6.3f}'.format(main.sampleConfidenceInterval / main.avg * 100) + "% on " + main.shortVm) plotLabels.append(main.name) @@ -449,7 +480,7 @@ def analyzeResult(subdir, main, comparisons): 'P(T<=t)': '---' if pValue < 0 else '{:.0f}%'.format(pValue * 100) }) - print(colored(Color.YELLOW, 'SUCCESS') + ': {:<40}'.format(main.name) + ": " + '{:8.3f}'.format(compare.avg) + "ms +/- " + + print(colored(Color.GREEN, 'SUCCESS') + ': {:<40}'.format(main.name) + ": " + '{:8.3f}'.format(compare.avg) + "ms +/- " + '{:6.3f}'.format(compare.sampleConfidenceInterval / compare.avg * 100) + "% on " + compare.shortVm + ' ({:+7.3f}%, '.format(speedup * 100) + verdict + ")") @@ -727,6 +758,10 @@ def run(args, argsubcb): arguments = args argumentSubstituionCallback = argsubcb + if os.name == "nt" and arguments.callgrind: + print(f"{colored(Color.RED, 'ERROR')}: --callgrind is not supported on Windows. Please consider using this option on another OS, or Linux using WSL.") + sys.exit(1) + if arguments.report_metrics or arguments.print_influx_debugging: import influxbench influxReporter = influxbench.InfluxReporter(arguments) diff --git a/bench/bench_support.lua b/bench/bench_support.lua index a9608ecc2..9e415fc1c 100644 --- a/bench/bench_support.lua +++ b/bench/bench_support.lua @@ -57,4 +57,46 @@ function bench.runCode(f, description) print(report) end +-- This function acts a bit like a Unix "fork" operation +-- When it is first called it clones `scriptInstance` and starts executing +-- the cloned script parented to an Actor. When the cloned script calls "runScriptCodeUnderActor" +-- it will run 'f' and print out the provided 'description'. +-- +-- The function returns 'true' if it was invoked from a script running under an Actor +-- and 'false' otherwise. +-- +-- Example usage: +-- local bench = script and require(script.Parent.bench_support) or require("bench_support") +-- function testFunc() +-- ... +-- end +-- bench.runScriptCodeUnderActor(script, testFunc, "test function") +function bench.runScriptCodeUnderActor(scriptInstance, f, description) + if scriptInstance:GetActor() then + -- If this function was called from an Actor script, just run the function provided using runCode + bench.runCode(f, description) + return true + else + -- If this function was not called from an Actor script, clone the script and place it under + -- Actor instance. + + -- Create an Actor to run the script under + local actor = Instance.new("Actor") + -- Clone this script (i.e. the bench_support module) and place it under the Actor where + -- the script script would expect it to be when using 'require'. + local benchModule = script:Clone() + benchModule.Parent = actor + -- Clone the scriptInstance + local actorScript = scriptInstance:Clone() + -- Enable the script since `scriptInstance` may be started by roblox-cli without ever being enabled. + actorScript.Disabled = false + actorScript.Parent = actor + -- Add the actor to the workspace which will start executing the cloned script. + -- Note: the script needs to be placed under a instance that implements 'IScriptFilter' + -- (which workspace does) or it will never start executing. + actor.Parent = workspace + return false + end +end + return bench diff --git a/tests/AssemblyBuilderX64.test.cpp b/tests/AssemblyBuilderX64.test.cpp index 63e92f8fa..177e24935 100644 --- a/tests/AssemblyBuilderX64.test.cpp +++ b/tests/AssemblyBuilderX64.test.cpp @@ -542,6 +542,20 @@ TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "MiscInstructions") SINGLE_COMPARE(bsf(eax, edx), 0x0f, 0xbc, 0xc2); } +TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "LabelLea") +{ + CHECK(check( + [](AssemblyBuilderX64& build) { + Label fn; + build.lea(rax, fn); + build.ret(); + + build.setLabel(fn); + build.ret(); + }, + {0x48, 0x8d, 0x05, 0x01, 0x00, 0x00, 0x00, 0xc3, 0xc3})); +} + TEST_CASE("LogTest") { AssemblyBuilderX64 build(/* logText= */ true); @@ -561,6 +575,7 @@ TEST_CASE("LogTest") Label start = build.setLabel(); build.cmp(rsi, rdi); build.jcc(ConditionX64::Equal, start); + build.lea(rcx, start); build.jmp(qword[rdx]); build.vaddps(ymm9, ymm12, ymmword[rbp + 0xc]); @@ -605,6 +620,7 @@ TEST_CASE("LogTest") .L1: cmp rsi,rdi je .L1 + lea rcx,.L1 jmp qword ptr [rdx] vaddps ymm9,ymm12,ymmword ptr [rbp+0Ch] vaddpd ymm2,ymm7,qword ptr [.start-8] diff --git a/tests/AstJsonEncoder.test.cpp b/tests/AstJsonEncoder.test.cpp index 82577bed1..a264d0e7f 100644 --- a/tests/AstJsonEncoder.test.cpp +++ b/tests/AstJsonEncoder.test.cpp @@ -432,11 +432,11 @@ TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstStatDeclareClass") REQUIRE(2 == root->body.size); std::string_view expected1 = - R"({"type":"AstStatDeclareClass","location":"1,22 - 4,11","name":"Foo","props":[{"name":"prop","type":"AstDeclaredClassProp","luauType":{"type":"AstTypeReference","location":"2,18 - 2,24","name":"number","nameLocation":"2,18 - 2,24","parameters":[]}},{"name":"method","type":"AstDeclaredClassProp","luauType":{"type":"AstTypeFunction","location":"3,21 - 4,11","generics":[],"genericPacks":[],"argTypes":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"3,39 - 3,45","name":"number","nameLocation":"3,39 - 3,45","parameters":[]}]},"returnTypes":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"3,48 - 3,54","name":"string","nameLocation":"3,48 - 3,54","parameters":[]}]}}}]})"; + R"({"type":"AstStatDeclareClass","location":"1,22 - 4,11","name":"Foo","props":[{"name":"prop","type":"AstDeclaredClassProp","luauType":{"type":"AstTypeReference","location":"2,18 - 2,24","name":"number","nameLocation":"2,18 - 2,24","parameters":[]}},{"name":"method","type":"AstDeclaredClassProp","luauType":{"type":"AstTypeFunction","location":"3,21 - 4,11","generics":[],"genericPacks":[],"argTypes":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"3,39 - 3,45","name":"number","nameLocation":"3,39 - 3,45","parameters":[]}]},"returnTypes":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"3,48 - 3,54","name":"string","nameLocation":"3,48 - 3,54","parameters":[]}]}}}],"indexer":null})"; CHECK(toJson(root->body.data[0]) == expected1); std::string_view expected2 = - R"({"type":"AstStatDeclareClass","location":"6,22 - 8,11","name":"Bar","superName":"Foo","props":[{"name":"prop2","type":"AstDeclaredClassProp","luauType":{"type":"AstTypeReference","location":"7,19 - 7,25","name":"string","nameLocation":"7,19 - 7,25","parameters":[]}}]})"; + R"({"type":"AstStatDeclareClass","location":"6,22 - 8,11","name":"Bar","superName":"Foo","props":[{"name":"prop2","type":"AstDeclaredClassProp","luauType":{"type":"AstTypeReference","location":"7,19 - 7,25","name":"string","nameLocation":"7,19 - 7,25","parameters":[]}}],"indexer":null})"; CHECK(toJson(root->body.data[1]) == expected2); } diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index 4885b3174..97cc32635 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -49,6 +49,15 @@ static std::string compileFunction0Coverage(const char* source, int level) return bcb.dumpFunction(0); } +static std::string compileFunction0TypeTable(const char* source) +{ + Luau::BytecodeBuilder bcb; + bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code); + Luau::compileOrThrow(bcb, source); + + return bcb.dumpTypeInfo(); +} + TEST_SUITE_BEGIN("Compiler"); TEST_CASE("CompileToBytecode") @@ -5796,8 +5805,6 @@ RETURN R3 1 TEST_CASE("InlineRecurseArguments") { - ScopedFastFlag sff("LuauCompileInlineDefer", true); - // the example looks silly but we preserve it verbatim as it was found by fuzzer for a previous version of the compiler CHECK_EQ("\n" + compileFunction(R"( local function foo(a, b) @@ -7071,4 +7078,56 @@ L1: RETURN R3 1 )"); } +TEST_CASE("EncodedTypeTable") +{ + ScopedFastFlag sffs[] = { + {"BytecodeVersion4", true}, + {"CompileFunctionType", true}, + }; + + CHECK_EQ("\n" + compileFunction0TypeTable(R"( +function myfunc(test: string, num: number) + print(test) +end + +function myfunc2(test: number?) +end + +function myfunc3(test: string, n: number) +end + +function myfunc4(test: string | number, n: number) +end + +-- Promoted to function(any, any) since general unions are not supported. +-- Functions with all `any` parameters will have omitted type info. +function myfunc5(test: string | number, n: number | boolean) +end + +myfunc('test') +)"), + R"( +0: function(string, number) +1: function(number?) +2: function(string, number) +3: function(any, number) +)"); + + CHECK_EQ("\n" + compileFunction0TypeTable(R"( +local Str = { + a = 1 +} + +-- Implicit `self` parameter is automatically assumed to be table type. +function Str:test(n: number) + print(self.a, n) +end + +Str:test(234) +)"), + R"( +0: function({ }, number) +)"); +} + TEST_SUITE_END(); diff --git a/tests/IrBuilder.test.cpp b/tests/IrBuilder.test.cpp index 4bfb63f3a..5b0c44d0c 100644 --- a/tests/IrBuilder.test.cpp +++ b/tests/IrBuilder.test.cpp @@ -93,7 +93,7 @@ class IrBuilderFixture { for (uint32_t succIdx : successors(build.function.cfg, k)) { - if (succIdx == i) + if (succIdx == uint32_t(i)) build.function.cfg.predecessors.push_back(k); } } diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index 1335b6f4e..a8738ac8e 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -54,7 +54,8 @@ TEST_SUITE_BEGIN("AllocatorTests"); TEST_CASE("allocator_can_be_moved") { Counter* c = nullptr; - auto inner = [&]() { + auto inner = [&]() + { Luau::Allocator allocator; c = allocator.alloc(); Luau::Allocator moved{std::move(allocator)}; @@ -921,7 +922,8 @@ TEST_CASE_FIXTURE(Fixture, "parse_interpolated_string_double_brace_mid") TEST_CASE_FIXTURE(Fixture, "parse_interpolated_string_without_end_brace") { - auto columnOfEndBraceError = [this](const char* code) { + auto columnOfEndBraceError = [this](const char* code) + { try { parse(code); @@ -1882,6 +1884,44 @@ TEST_CASE_FIXTURE(Fixture, "class_method_properties") CHECK_EQ(2, klass2->props.size); } +TEST_CASE_FIXTURE(Fixture, "class_indexer") +{ + ScopedFastFlag LuauParseDeclareClassIndexer("LuauParseDeclareClassIndexer", true); + + AstStatBlock* stat = parseEx(R"( + declare class Foo + prop: boolean + [string]: number + end + )") + .root; + + REQUIRE_EQ(stat->body.size, 1); + + AstStatDeclareClass* declaredClass = stat->body.data[0]->as(); + REQUIRE(declaredClass); + REQUIRE(declaredClass->indexer); + REQUIRE(declaredClass->indexer->indexType->is()); + CHECK(declaredClass->indexer->indexType->as()->name == "string"); + REQUIRE(declaredClass->indexer->resultType->is()); + CHECK(declaredClass->indexer->resultType->as()->name == "number"); + + const ParseResult p1 = matchParseError(R"( + declare class Foo + [string]: number + -- can only have one indexer + [number]: number + end + )", + "Cannot have more than one class indexer"); + + REQUIRE_EQ(1, p1.root->body.size); + + AstStatDeclareClass* klass = p1.root->body.data[0]->as(); + REQUIRE(klass != nullptr); + CHECK(klass->indexer); +} + TEST_CASE_FIXTURE(Fixture, "parse_variadics") { //clang-format off @@ -2347,7 +2387,8 @@ class CountAstNodes : public AstVisitor TEST_CASE_FIXTURE(Fixture, "recovery_of_parenthesized_expressions") { - auto checkAstEquivalence = [this](const char* codeWithErrors, const char* code) { + auto checkAstEquivalence = [this](const char* codeWithErrors, const char* code) + { try { parse(codeWithErrors); @@ -2367,7 +2408,8 @@ TEST_CASE_FIXTURE(Fixture, "recovery_of_parenthesized_expressions") CHECK_EQ(counterWithErrors.count, counter.count); }; - auto checkRecovery = [this, checkAstEquivalence](const char* codeWithErrors, const char* code, unsigned expectedErrorCount) { + auto checkRecovery = [this, checkAstEquivalence](const char* codeWithErrors, const char* code, unsigned expectedErrorCount) + { try { parse(codeWithErrors); diff --git a/tests/Simplify.test.cpp b/tests/Simplify.test.cpp index 2052019ec..1223152ba 100644 --- a/tests/Simplify.test.cpp +++ b/tests/Simplify.test.cpp @@ -341,7 +341,7 @@ TEST_CASE_FIXTURE(SimplifyFixture, "tables") CHECK(t2 == intersect(t2, t1)); TypeId t3 = mkTable({}); - + // {tag : string} intersect {{}} CHECK(t1 == intersect(t1, t3)); CHECK(t1 == intersect(t3, t1)); } diff --git a/tests/TypeInfer.definitions.test.cpp b/tests/TypeInfer.definitions.test.cpp index d67997574..0ca9bd735 100644 --- a/tests/TypeInfer.definitions.test.cpp +++ b/tests/TypeInfer.definitions.test.cpp @@ -394,6 +394,36 @@ TEST_CASE_FIXTURE(Fixture, "class_definition_string_props") CHECK_EQ(toString(requireType("y")), "string"); } + +TEST_CASE_FIXTURE(Fixture, "class_definition_indexer") +{ + ScopedFastFlag LuauParseDeclareClassIndexer("LuauParseDeclareClassIndexer", true); + ScopedFastFlag LuauTypecheckClassTypeIndexers("LuauTypecheckClassTypeIndexers", true); + + loadDefinition(R"( + declare class Foo + [number]: string + end + )"); + + CheckResult result = check(R"( + local x: Foo + local y = x[1] + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + const ClassType* ctv = get(requireType("x")); + REQUIRE(ctv != nullptr); + + REQUIRE(bool(ctv->indexer)); + + CHECK_EQ(*ctv->indexer->indexType, *builtinTypes->numberType); + CHECK_EQ(*ctv->indexer->indexResultType, *builtinTypes->stringType); + + CHECK_EQ(toString(requireType("y")), "string"); +} + TEST_CASE_FIXTURE(Fixture, "class_definitions_reference_other_classes") { unfreeze(frontend.globals.globalTypes); From 8bc2f51d89c8c8a0d38dba32be995d9731322060 Mon Sep 17 00:00:00 2001 From: Vyacheslav Egorov Date: Sat, 24 Jun 2023 08:33:44 +0300 Subject: [PATCH 61/66] Sync to upstream/release/582 --- Analysis/include/Luau/Constraint.h | 2 +- Analysis/include/Luau/ConstraintSolver.h | 11 + Analysis/include/Luau/Module.h | 23 +- Analysis/include/Luau/Normalize.h | 2 - Analysis/include/Luau/Substitution.h | 73 ++-- Analysis/include/Luau/TypeInfer.h | 2 - Analysis/include/Luau/TypeUtils.h | 9 + Analysis/src/Anyification.cpp | 4 +- Analysis/src/ApplyTypeFunction.cpp | 4 +- Analysis/src/ConstraintGraphBuilder.cpp | 9 +- Analysis/src/ConstraintSolver.cpp | 35 +- Analysis/src/Error.cpp | 3 +- Analysis/src/Frontend.cpp | 1 + Analysis/src/Instantiation.cpp | 6 +- Analysis/src/Module.cpp | 9 - Analysis/src/Normalize.cpp | 7 - Analysis/src/Quantify.cpp | 3 +- Analysis/src/Substitution.cpp | 414 ++++++++++++------ Analysis/src/TypeChecker2.cpp | 476 +++++++++++++-------- Analysis/src/TypeInfer.cpp | 8 +- Analysis/src/Unifier.cpp | 13 +- CLI/Compile.cpp | 25 +- CodeGen/include/Luau/CodeGen.h | 11 + CodeGen/include/Luau/IrData.h | 10 + CodeGen/include/Luau/IrUtils.h | 2 + CodeGen/src/CodeGen.cpp | 306 +------------ CodeGen/src/CodeGenAssembly.cpp | 146 +++++++ CodeGen/src/CodeGenLower.h | 240 +++++++++++ CodeGen/src/EmitBuiltinsX64.cpp | 29 -- CodeGen/src/EmitCommonX64.cpp | 17 +- CodeGen/src/EmitCommonX64.h | 4 +- CodeGen/src/IrAnalysis.cpp | 3 + CodeGen/src/IrDump.cpp | 4 + CodeGen/src/IrLoweringA64.cpp | 68 +-- CodeGen/src/IrLoweringA64.h | 4 +- CodeGen/src/IrLoweringX64.cpp | 35 +- CodeGen/src/IrLoweringX64.h | 4 +- CodeGen/src/IrTranslateBuiltins.cpp | 7 +- CodeGen/src/IrTranslation.cpp | 8 +- CodeGen/src/IrUtils.cpp | 3 + CodeGen/src/IrValueLocationTracking.cpp | 1 + CodeGen/src/OptimizeConstProp.cpp | 4 + Common/include/Luau/DenseHash.h | 22 +- Makefile | 8 +- Sources.cmake | 2 + tests/CostModel.test.cpp | 2 +- tests/IrBuilder.test.cpp | 4 +- tests/Module.test.cpp | 6 - tests/TypeFamily.test.cpp | 4 +- tests/TypeInfer.aliases.test.cpp | 31 +- tests/TypeInfer.anyerror.test.cpp | 5 +- tests/TypeInfer.classes.test.cpp | 9 +- tests/TypeInfer.functions.test.cpp | 23 +- tests/TypeInfer.generics.test.cpp | 17 +- tests/TypeInfer.intersectionTypes.test.cpp | 4 - tests/TypeInfer.modules.test.cpp | 15 +- tests/TypeInfer.oop.test.cpp | 26 +- tests/TypeInfer.operators.test.cpp | 17 + tests/TypeInfer.provisional.test.cpp | 24 +- tests/TypeInfer.tables.test.cpp | 58 +-- tests/TypeInfer.test.cpp | 41 ++ tests/TypeInfer.tryUnify.test.cpp | 8 - tools/faillist.txt | 14 +- 63 files changed, 1422 insertions(+), 963 deletions(-) create mode 100644 CodeGen/src/CodeGenAssembly.cpp create mode 100644 CodeGen/src/CodeGenLower.h diff --git a/Analysis/include/Luau/Constraint.h b/Analysis/include/Luau/Constraint.h index d73ba46df..aa1d1c0ec 100644 --- a/Analysis/include/Luau/Constraint.h +++ b/Analysis/include/Luau/Constraint.h @@ -83,7 +83,7 @@ struct IterableConstraint TypePackId variables; const AstNode* nextAstFragment; - DenseHashMap* astOverloadResolvedTypes; + DenseHashMap* astForInNextTypes; }; // name(namedType) = name diff --git a/Analysis/include/Luau/ConstraintSolver.h b/Analysis/include/Luau/ConstraintSolver.h index b13bb21bd..b26d88c33 100644 --- a/Analysis/include/Luau/ConstraintSolver.h +++ b/Analysis/include/Luau/ConstraintSolver.h @@ -245,6 +245,17 @@ struct ConstraintSolver template bool tryUnify(NotNull constraint, TID subTy, TID superTy); + /** + * Bind a BlockedType to another type while taking care not to bind it to + * itself in the case that resultTy == blockedTy. This can happen if we + * have a tautological constraint. When it does, we must instead bind + * blockedTy to a fresh type belonging to an appropriate scope. + * + * To determine which scope is appropriate, we also accept rootTy, which is + * to be the type that contains blockedTy. + */ + void bindBlockedType(TypeId blockedTy, TypeId resultTy, TypeId rootTy, Location location); + /** * Marks a constraint as being blocked on a type or type pack. The constraint * solver will not attempt to dispatch blocked constraints until their diff --git a/Analysis/include/Luau/Module.h b/Analysis/include/Luau/Module.h index 1fa2e03c7..a3b9c4172 100644 --- a/Analysis/include/Luau/Module.h +++ b/Analysis/include/Luau/Module.h @@ -81,13 +81,29 @@ struct Module DenseHashMap astTypePacks{nullptr}; DenseHashMap astExpectedTypes{nullptr}; + // For AST nodes that are function calls, this map provides the + // unspecialized type of the function that was called. If a function call + // resolves to a __call metamethod application, this map will point at that + // metamethod. + // + // This is useful for type checking and Signature Help. DenseHashMap astOriginalCallTypes{nullptr}; + + // The specialization of a function that was selected. If the function is + // generic, those generic type parameters will be replaced with the actual + // types that were passed. If the function is an overload, this map will + // point at the specific overloads that were selected. DenseHashMap astOverloadResolvedTypes{nullptr}; + // Only used with for...in loops. The computed type of the next() function + // is kept here for type checking. + DenseHashMap astForInNextTypes{nullptr}; + DenseHashMap astResolvedTypes{nullptr}; DenseHashMap astResolvedTypePacks{nullptr}; - // Map AST nodes to the scope they create. Cannot be NotNull because we need a sentinel value for the map. + // Map AST nodes to the scope they create. Cannot be NotNull because + // we need a sentinel value for the map. DenseHashMap astScopes{nullptr}; std::unordered_map declaredGlobals; @@ -103,8 +119,9 @@ struct Module bool hasModuleScope() const; ScopePtr getModuleScope() const; - // Once a module has been typechecked, we clone its public interface into a separate arena. - // This helps us to force Type ownership into a DAG rather than a DCG. + // Once a module has been typechecked, we clone its public interface into a + // separate arena. This helps us to force Type ownership into a DAG rather + // than a DCG. void clonePublicInterface(NotNull builtinTypes, InternalErrorReporter& ice); }; diff --git a/Analysis/include/Luau/Normalize.h b/Analysis/include/Luau/Normalize.h index 72be0832b..1a252a88e 100644 --- a/Analysis/include/Luau/Normalize.h +++ b/Analysis/include/Luau/Normalize.h @@ -207,8 +207,6 @@ struct NormalizedFunctionType struct NormalizedType; using NormalizedTyvars = std::unordered_map>; -bool isInhabited_DEPRECATED(const NormalizedType& norm); - // A normalized type is either any, unknown, or one of the form P | T | F | G where // * P is a union of primitive types (including singletons, classes and the error type) // * T is a union of table types diff --git a/Analysis/include/Luau/Substitution.h b/Analysis/include/Luau/Substitution.h index 626c93ad1..398d0ab68 100644 --- a/Analysis/include/Luau/Substitution.h +++ b/Analysis/include/Luau/Substitution.h @@ -69,6 +69,19 @@ struct TarjanWorklistVertex int lastEdge; }; +struct TarjanNode +{ + TypeId ty; + TypePackId tp; + + bool onStack; + bool dirty; + + // Tarjan calculates the lowlink for each vertex, + // which is the lowest ancestor index reachable from the vertex. + int lowlink; +}; + // Tarjan's algorithm for finding the SCCs in a cyclic structure. // https://en.wikipedia.org/wiki/Tarjan%27s_strongly_connected_components_algorithm struct Tarjan @@ -76,17 +89,12 @@ struct Tarjan // Vertices (types and type packs) are indexed, using pre-order traversal. DenseHashMap typeToIndex{nullptr}; DenseHashMap packToIndex{nullptr}; - std::vector indexToType; - std::vector indexToPack; + + std::vector nodes; // Tarjan keeps a stack of vertices where we're still in the process // of finding their SCC. std::vector stack; - std::vector onStack; - - // Tarjan calculates the lowlink for each vertex, - // which is the lowest ancestor index reachable from the vertex. - std::vector lowlink; int childCount = 0; int childLimit = 0; @@ -98,6 +106,7 @@ struct Tarjan std::vector edgesTy; std::vector edgesTp; std::vector worklist; + // This is hot code, so we optimize recursion to a stack. TarjanResult loop(); @@ -124,10 +133,22 @@ struct Tarjan TarjanResult visitRoot(TypeId ty); TarjanResult visitRoot(TypePackId ty); - // Each subclass gets called back once for each edge, - // and once for each SCC. - virtual void visitEdge(int index, int parentIndex) {} - virtual void visitSCC(int index) {} + void clearTarjan(); + + // Get/set the dirty bit for an index (grows the vector if needed) + bool getDirty(int index); + void setDirty(int index, bool d); + + // Find all the dirty vertices reachable from `t`. + TarjanResult findDirty(TypeId t); + TarjanResult findDirty(TypePackId t); + + // We find dirty vertices using Tarjan + void visitEdge(int index, int parentIndex); + void visitSCC(int index); + + TarjanResult loop_DEPRECATED(); + void visitSCC_DEPRECATED(int index); // Each subclass can decide to ignore some nodes. virtual bool ignoreChildren(TypeId ty) @@ -150,27 +171,6 @@ struct Tarjan { return ignoreChildren(ty); } -}; - -// We use Tarjan to calculate dirty bits. We set `dirty[i]` true -// if the vertex with index `i` can reach a dirty vertex. -struct FindDirty : Tarjan -{ - std::vector dirty; - - void clearTarjan(); - - // Get/set the dirty bit for an index (grows the vector if needed) - bool getDirty(int index); - void setDirty(int index, bool d); - - // Find all the dirty vertices reachable from `t`. - TarjanResult findDirty(TypeId t); - TarjanResult findDirty(TypePackId t); - - // We find dirty vertices using Tarjan - void visitEdge(int index, int parentIndex) override; - void visitSCC(int index) override; // Subclasses should say which vertices are dirty, // and what to do with dirty vertices. @@ -178,11 +178,18 @@ struct FindDirty : Tarjan virtual bool isDirty(TypePackId tp) = 0; virtual void foundDirty(TypeId ty) = 0; virtual void foundDirty(TypePackId tp) = 0; + + // TODO: remove with FFlagLuauTarjanSingleArr + std::vector indexToType; + std::vector indexToPack; + std::vector onStack; + std::vector lowlink; + std::vector dirty; }; // And finally substitution, which finds all the reachable dirty vertices // and replaces them with clean ones. -struct Substitution : FindDirty +struct Substitution : Tarjan { protected: Substitution(const TxnLog* log_, TypeArena* arena) diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index 1a721c743..9902e5a1e 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -19,8 +19,6 @@ #include #include -LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution) - namespace Luau { diff --git a/Analysis/include/Luau/TypeUtils.h b/Analysis/include/Luau/TypeUtils.h index 5ead2fa4a..84916cd24 100644 --- a/Analysis/include/Luau/TypeUtils.h +++ b/Analysis/include/Luau/TypeUtils.h @@ -64,4 +64,13 @@ const T* get(std::optional ty) return nullptr; } +template +std::optional follow(std::optional ty) +{ + if (ty) + return follow(*ty); + else + return std::nullopt; +} + } // namespace Luau diff --git a/Analysis/src/Anyification.cpp b/Analysis/src/Anyification.cpp index 15dd25cc5..741d2141d 100644 --- a/Analysis/src/Anyification.cpp +++ b/Analysis/src/Anyification.cpp @@ -6,8 +6,6 @@ #include "Luau/Normalize.h" #include "Luau/TxnLog.h" -LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution) - namespace Luau { @@ -78,7 +76,7 @@ TypePackId Anyification::clean(TypePackId tp) bool Anyification::ignoreChildren(TypeId ty) { - if (FFlag::LuauClassTypeVarsInSubstitution && get(ty)) + if (get(ty)) return true; return ty->persistent; diff --git a/Analysis/src/ApplyTypeFunction.cpp b/Analysis/src/ApplyTypeFunction.cpp index fe8cc8ac3..025e8f6db 100644 --- a/Analysis/src/ApplyTypeFunction.cpp +++ b/Analysis/src/ApplyTypeFunction.cpp @@ -2,8 +2,6 @@ #include "Luau/ApplyTypeFunction.h" -LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution) - namespace Luau { @@ -33,7 +31,7 @@ bool ApplyTypeFunction::ignoreChildren(TypeId ty) { if (get(ty)) return true; - else if (FFlag::LuauClassTypeVarsInSubstitution && get(ty)) + else if (get(ty)) return true; else return false; diff --git a/Analysis/src/ConstraintGraphBuilder.cpp b/Analysis/src/ConstraintGraphBuilder.cpp index 07dba9219..429f1a4db 100644 --- a/Analysis/src/ConstraintGraphBuilder.cpp +++ b/Analysis/src/ConstraintGraphBuilder.cpp @@ -751,7 +751,12 @@ ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatForIn* f variableTypes.reserve(forIn->vars.size); for (AstLocal* var : forIn->vars) { - TypeId ty = freshType(loopScope); + TypeId ty = nullptr; + if (var->annotation) + ty = resolveType(loopScope, var->annotation, /*inTypeArguments*/ false); + else + ty = freshType(loopScope); + loopScope->bindings[var] = Binding{ty, var->location}; variableTypes.push_back(ty); @@ -763,7 +768,7 @@ ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatForIn* f TypePackId variablePack = arena->addTypePack(std::move(variableTypes), arena->addTypePack(FreeTypePack{loopScope.get()})); addConstraint( - loopScope, getLocation(forIn->values), IterableConstraint{iterator, variablePack, forIn->values.data[0], &module->astOverloadResolvedTypes}); + loopScope, getLocation(forIn->values), IterableConstraint{iterator, variablePack, forIn->values.data[0], &module->astForInNextTypes}); visit(loopScope, forIn->body); return ControlFlow::None; diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index b85d2c59c..81e6574ad 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -1453,7 +1453,7 @@ bool ConstraintSolver::tryDispatch(const HasPropConstraint& c, NotNullty.emplace(result.value_or(builtinTypes->anyType)); + bindBlockedType(c.resultType, result.value_or(builtinTypes->anyType), c.subjectType, constraint->location); unblock(c.resultType, constraint->location); return true; } @@ -1559,8 +1559,8 @@ bool ConstraintSolver::tryDispatch(const SetPropConstraint& c, NotNullty.emplace(b); + auto bind = [&](TypeId a, TypeId b) { + bindBlockedType(a, b, c.subjectType, constraint->location); }; if (existingPropType) @@ -2143,7 +2143,9 @@ bool ConstraintSolver::tryDispatchIterableFunction( // if there are no errors from unifying the two, we can pass forward the expected type as our selected resolution. if (errors.empty()) - (*c.astOverloadResolvedTypes)[c.nextAstFragment] = expectedNextTy; + { + (*c.astForInNextTypes)[c.nextAstFragment] = expectedNextTy; + } auto it = begin(nextRetPack); std::vector modifiedNextRetHead; @@ -2380,6 +2382,31 @@ bool ConstraintSolver::tryUnify(NotNull constraint, TID subTy, return true; } +void ConstraintSolver::bindBlockedType(TypeId blockedTy, TypeId resultTy, TypeId rootTy, Location location) +{ + resultTy = follow(resultTy); + + LUAU_ASSERT(get(blockedTy)); + + if (blockedTy == resultTy) + { + rootTy = follow(rootTy); + Scope* freeScope = nullptr; + if (auto ft = get(rootTy)) + freeScope = ft->scope; + else if (auto tt = get(rootTy); tt && tt->state == TableState::Free) + freeScope = tt->scope; + else + iceReporter.ice("bindBlockedType couldn't find an appropriate scope for a fresh type!", location); + + LUAU_ASSERT(freeScope); + + asMutable(blockedTy)->ty.emplace(arena->freshType(freeScope)); + } + else + asMutable(blockedTy)->ty.emplace(resultTy); +} + void ConstraintSolver::block_(BlockedConstraintId target, NotNull constraint) { blocked[target].push_back(constraint); diff --git a/Analysis/src/Error.cpp b/Analysis/src/Error.cpp index fdc19d0c4..fba3c88a3 100644 --- a/Analysis/src/Error.cpp +++ b/Analysis/src/Error.cpp @@ -10,7 +10,6 @@ #include #include -LUAU_FASTFLAGVARIABLE(LuauTypeMismatchInvarianceInError, false) static std::string wrongNumberOfArgsString( size_t expectedCount, std::optional maximumCount, size_t actualCount, const char* argPrefix = nullptr, bool isVariadic = false) @@ -106,7 +105,7 @@ struct ErrorConverter { result += "; " + tm.reason; } - else if (FFlag::LuauTypeMismatchInvarianceInError && tm.context == TypeMismatch::InvariantContext) + else if (tm.context == TypeMismatch::InvariantContext) { result += " in an invariant context"; } diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index d7077ba66..f88425b55 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -950,6 +950,7 @@ void Frontend::checkBuildQueueItem(BuildQueueItem& item) module->astExpectedTypes.clear(); module->astOriginalCallTypes.clear(); module->astOverloadResolvedTypes.clear(); + module->astForInNextTypes.clear(); module->astResolvedTypes.clear(); module->astResolvedTypePacks.clear(); module->astScopes.clear(); diff --git a/Analysis/src/Instantiation.cpp b/Analysis/src/Instantiation.cpp index 1d6092f87..52b4aa8cc 100644 --- a/Analysis/src/Instantiation.cpp +++ b/Analysis/src/Instantiation.cpp @@ -4,8 +4,6 @@ #include "Luau/TxnLog.h" #include "Luau/TypeArena.h" -LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution) - namespace Luau { @@ -33,7 +31,7 @@ bool Instantiation::ignoreChildren(TypeId ty) { if (log->getMutable(ty)) return true; - else if (FFlag::LuauClassTypeVarsInSubstitution && get(ty)) + else if (get(ty)) return true; else return false; @@ -84,7 +82,7 @@ bool ReplaceGenerics::ignoreChildren(TypeId ty) // whenever we quantify, so the vectors overlap if and only if they are equal. return (!generics.empty() || !genericPacks.empty()) && (ftv->generics == generics) && (ftv->genericPacks == genericPacks); } - else if (FFlag::LuauClassTypeVarsInSubstitution && get(ty)) + else if (get(ty)) return true; else { diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index 37af00401..473b8acc4 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -16,9 +16,6 @@ LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); LUAU_FASTFLAGVARIABLE(LuauClonePublicInterfaceLess2, false); -LUAU_FASTFLAG(LuauSubstitutionReentrant); -LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution); -LUAU_FASTFLAG(LuauSubstitutionFixMissingFields); LUAU_FASTFLAGVARIABLE(LuauCloneSkipNonInternalVisit, false); namespace Luau @@ -134,8 +131,6 @@ struct ClonePublicInterface : Substitution TypeId cloneType(TypeId ty) { - LUAU_ASSERT(FFlag::LuauSubstitutionReentrant && FFlag::LuauSubstitutionFixMissingFields); - std::optional result = substitute(ty); if (result) { @@ -150,8 +145,6 @@ struct ClonePublicInterface : Substitution TypePackId cloneTypePack(TypePackId tp) { - LUAU_ASSERT(FFlag::LuauSubstitutionReentrant && FFlag::LuauSubstitutionFixMissingFields); - std::optional result = substitute(tp); if (result) { @@ -166,8 +159,6 @@ struct ClonePublicInterface : Substitution TypeFun cloneTypeFun(const TypeFun& tf) { - LUAU_ASSERT(FFlag::LuauSubstitutionReentrant && FFlag::LuauSubstitutionFixMissingFields); - std::vector typeParams; std::vector typePackParams; diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp index e4f22f331..a7e3bb6e7 100644 --- a/Analysis/src/Normalize.cpp +++ b/Analysis/src/Normalize.cpp @@ -19,7 +19,6 @@ LUAU_FASTINTVARIABLE(LuauNormalizeIterationLimit, 1200); LUAU_FASTINTVARIABLE(LuauNormalizeCacheLimit, 100000); LUAU_FASTFLAGVARIABLE(LuauNormalizeBlockedTypes, false); LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) -LUAU_FASTFLAG(LuauUninhabitedSubAnything2) LUAU_FASTFLAG(LuauTransitiveSubtyping) LUAU_FASTFLAG(DebugLuauReadWriteProperties) @@ -312,12 +311,6 @@ static bool isShallowInhabited(const NormalizedType& norm) !norm.functions.isNever() || !norm.tables.empty() || !norm.tyvars.empty(); } -bool isInhabited_DEPRECATED(const NormalizedType& norm) -{ - LUAU_ASSERT(!FFlag::LuauUninhabitedSubAnything2); - return isShallowInhabited(norm); -} - bool Normalizer::isInhabited(const NormalizedType* norm, std::unordered_set seen) { // If normalization failed, the type is complex, and so is more likely than not to be inhabited. diff --git a/Analysis/src/Quantify.cpp b/Analysis/src/Quantify.cpp index 3528d5345..f7ed7619a 100644 --- a/Analysis/src/Quantify.cpp +++ b/Analysis/src/Quantify.cpp @@ -10,7 +10,6 @@ LUAU_FASTFLAG(DebugLuauSharedSelf) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); -LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution) namespace Luau { @@ -244,7 +243,7 @@ struct PureQuantifier : Substitution bool ignoreChildren(TypeId ty) override { - if (FFlag::LuauClassTypeVarsInSubstitution && get(ty)) + if (get(ty)) return true; return ty->persistent; diff --git a/Analysis/src/Substitution.cpp b/Analysis/src/Substitution.cpp index 655881a3f..6c1908bf6 100644 --- a/Analysis/src/Substitution.cpp +++ b/Analysis/src/Substitution.cpp @@ -8,13 +8,11 @@ #include #include -LUAU_FASTFLAGVARIABLE(LuauSubstitutionFixMissingFields, false) LUAU_FASTFLAG(LuauClonePublicInterfaceLess2) LUAU_FASTINTVARIABLE(LuauTarjanChildLimit, 10000) -LUAU_FASTFLAGVARIABLE(LuauClassTypeVarsInSubstitution, false) -LUAU_FASTFLAGVARIABLE(LuauSubstitutionReentrant, false) LUAU_FASTFLAG(DebugLuauReadWriteProperties) LUAU_FASTFLAG(LuauCloneSkipNonInternalVisit) +LUAU_FASTFLAGVARIABLE(LuauTarjanSingleArr, false) namespace Luau { @@ -113,20 +111,35 @@ static TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool a else if constexpr (std::is_same_v) return dest.addType(a); else if constexpr (std::is_same_v) + { + LUAU_ASSERT(ty->persistent); return ty; + } else if constexpr (std::is_same_v) { PendingExpansionType clone = PendingExpansionType{a.prefix, a.name, a.typeArguments, a.packArguments}; return dest.addType(std::move(clone)); } else if constexpr (std::is_same_v) + { + LUAU_ASSERT(ty->persistent); return ty; + } else if constexpr (std::is_same_v) + { + LUAU_ASSERT(ty->persistent); return ty; + } else if constexpr (std::is_same_v) + { + LUAU_ASSERT(ty->persistent); return ty; + } else if constexpr (std::is_same_v) + { + LUAU_ASSERT(ty->persistent); return ty; + } else if constexpr (std::is_same_v) return ty; else if constexpr (std::is_same_v) @@ -227,13 +240,10 @@ void Tarjan::visitChildren(TypeId ty, int index) if (const FunctionType* ftv = get(ty)) { - if (FFlag::LuauSubstitutionFixMissingFields) - { - for (TypeId generic : ftv->generics) - visitChild(generic); - for (TypePackId genericPack : ftv->genericPacks) - visitChild(genericPack); - } + for (TypeId generic : ftv->generics) + visitChild(generic); + for (TypePackId genericPack : ftv->genericPacks) + visitChild(genericPack); visitChild(ftv->argTypes); visitChild(ftv->retTypes); @@ -295,7 +305,7 @@ void Tarjan::visitChildren(TypeId ty, int index) for (TypePackId a : tfit->packArguments) visitChild(a); } - else if (const ClassType* ctv = get(ty); FFlag::LuauClassTypeVarsInSubstitution && ctv) + else if (const ClassType* ctv = get(ty)) { for (const auto& [name, prop] : ctv->props) visitChild(prop.type()); @@ -348,36 +358,67 @@ std::pair Tarjan::indexify(TypeId ty) { ty = log->follow(ty); - bool fresh = !typeToIndex.contains(ty); - int& index = typeToIndex[ty]; + if (FFlag::LuauTarjanSingleArr) + { + auto [index, fresh] = typeToIndex.try_insert(ty, false); - if (fresh) + if (fresh) + { + index = int(nodes.size()); + nodes.push_back({ty, nullptr, false, false, index}); + } + + return {index, fresh}; + } + else { - index = int(indexToType.size()); - indexToType.push_back(ty); - indexToPack.push_back(nullptr); - onStack.push_back(false); - lowlink.push_back(index); + bool fresh = !typeToIndex.contains(ty); + int& index = typeToIndex[ty]; + + if (fresh) + { + index = int(indexToType.size()); + indexToType.push_back(ty); + indexToPack.push_back(nullptr); + onStack.push_back(false); + lowlink.push_back(index); + } + return {index, fresh}; } - return {index, fresh}; } std::pair Tarjan::indexify(TypePackId tp) { tp = log->follow(tp); - bool fresh = !packToIndex.contains(tp); - int& index = packToIndex[tp]; + if (FFlag::LuauTarjanSingleArr) + { + auto [index, fresh] = packToIndex.try_insert(tp, false); + + if (fresh) + { + index = int(nodes.size()); + nodes.push_back({nullptr, tp, false, false, index}); + } - if (fresh) + return {index, fresh}; + } + else { - index = int(indexToPack.size()); - indexToType.push_back(nullptr); - indexToPack.push_back(tp); - onStack.push_back(false); - lowlink.push_back(index); + + bool fresh = !packToIndex.contains(tp); + int& index = packToIndex[tp]; + + if (fresh) + { + index = int(indexToPack.size()); + indexToType.push_back(nullptr); + indexToPack.push_back(tp); + onStack.push_back(false); + lowlink.push_back(index); + } + return {index, fresh}; } - return {index, fresh}; } void Tarjan::visitChild(TypeId ty) @@ -398,6 +439,9 @@ void Tarjan::visitChild(TypePackId tp) TarjanResult Tarjan::loop() { + if (!FFlag::LuauTarjanSingleArr) + return loop_DEPRECATED(); + // Normally Tarjan is presented recursively, but this is a hot loop, so worth optimizing while (!worklist.empty()) { @@ -411,14 +455,15 @@ TarjanResult Tarjan::loop() return TarjanResult::TooManyChildren; stack.push_back(index); - onStack[index] = true; + + nodes[index].onStack = true; currEdge = int(edgesTy.size()); // Fill in edge list of this vertex - if (TypeId ty = indexToType[index]) + if (TypeId ty = nodes[index].ty) visitChildren(ty, index); - else if (TypePackId tp = indexToPack[index]) + else if (TypePackId tp = nodes[index].tp) visitChildren(tp, index); lastEdge = int(edgesTy.size()); @@ -449,9 +494,9 @@ TarjanResult Tarjan::loop() foundFresh = true; break; } - else if (onStack[childIndex]) + else if (nodes[childIndex].onStack) { - lowlink[index] = std::min(lowlink[index], childIndex); + nodes[index].lowlink = std::min(nodes[index].lowlink, childIndex); } visitEdge(childIndex, index); @@ -460,14 +505,14 @@ TarjanResult Tarjan::loop() if (foundFresh) continue; - if (lowlink[index] == index) + if (nodes[index].lowlink == index) { visitSCC(index); while (!stack.empty()) { int popped = stack.back(); stack.pop_back(); - onStack[popped] = false; + nodes[popped].onStack = false; if (popped == index) break; } @@ -484,7 +529,7 @@ TarjanResult Tarjan::loop() edgesTy.resize(parentEndEdge); edgesTp.resize(parentEndEdge); - lowlink[parentIndex] = std::min(lowlink[parentIndex], lowlink[index]); + nodes[parentIndex].lowlink = std::min(nodes[parentIndex].lowlink, nodes[index].lowlink); visitEdge(index, parentIndex); } } @@ -518,54 +563,87 @@ TarjanResult Tarjan::visitRoot(TypePackId tp) return loop(); } -void FindDirty::clearTarjan() +void Tarjan::clearTarjan() { - dirty.clear(); + if (FFlag::LuauTarjanSingleArr) + { + typeToIndex.clear(); + packToIndex.clear(); + nodes.clear(); - typeToIndex.clear(); - packToIndex.clear(); - indexToType.clear(); - indexToPack.clear(); + stack.clear(); + } + else + { + dirty.clear(); - stack.clear(); - onStack.clear(); - lowlink.clear(); + typeToIndex.clear(); + packToIndex.clear(); + indexToType.clear(); + indexToPack.clear(); + + stack.clear(); + onStack.clear(); + lowlink.clear(); + } edgesTy.clear(); edgesTp.clear(); worklist.clear(); } -bool FindDirty::getDirty(int index) +bool Tarjan::getDirty(int index) { - if (dirty.size() <= size_t(index)) - dirty.resize(index + 1, false); - return dirty[index]; + if (FFlag::LuauTarjanSingleArr) + { + LUAU_ASSERT(size_t(index) < nodes.size()); + return nodes[index].dirty; + } + else + { + if (dirty.size() <= size_t(index)) + dirty.resize(index + 1, false); + return dirty[index]; + } } -void FindDirty::setDirty(int index, bool d) +void Tarjan::setDirty(int index, bool d) { - if (dirty.size() <= size_t(index)) - dirty.resize(index + 1, false); - dirty[index] = d; + if (FFlag::LuauTarjanSingleArr) + { + LUAU_ASSERT(size_t(index) < nodes.size()); + nodes[index].dirty = d; + } + else + { + if (dirty.size() <= size_t(index)) + dirty.resize(index + 1, false); + dirty[index] = d; + } } -void FindDirty::visitEdge(int index, int parentIndex) +void Tarjan::visitEdge(int index, int parentIndex) { if (getDirty(index)) setDirty(parentIndex, true); } -void FindDirty::visitSCC(int index) +void Tarjan::visitSCC(int index) { + if (!FFlag::LuauTarjanSingleArr) + return visitSCC_DEPRECATED(index); + bool d = getDirty(index); for (auto it = stack.rbegin(); !d && it != stack.rend(); it++) { - if (TypeId ty = indexToType[*it]) + TarjanNode& node = nodes[*it]; + + if (TypeId ty = node.ty) d = isDirty(ty); - else if (TypePackId tp = indexToPack[*it]) + else if (TypePackId tp = node.tp) d = isDirty(tp); + if (*it == index) break; } @@ -576,32 +654,161 @@ void FindDirty::visitSCC(int index) for (auto it = stack.rbegin(); it != stack.rend(); it++) { setDirty(*it, true); - if (TypeId ty = indexToType[*it]) + + TarjanNode& node = nodes[*it]; + + if (TypeId ty = node.ty) foundDirty(ty); - else if (TypePackId tp = indexToPack[*it]) + else if (TypePackId tp = node.tp) foundDirty(tp); + if (*it == index) return; } } -TarjanResult FindDirty::findDirty(TypeId ty) +TarjanResult Tarjan::findDirty(TypeId ty) { return visitRoot(ty); } -TarjanResult FindDirty::findDirty(TypePackId tp) +TarjanResult Tarjan::findDirty(TypePackId tp) { return visitRoot(tp); } +TarjanResult Tarjan::loop_DEPRECATED() +{ + // Normally Tarjan is presented recursively, but this is a hot loop, so worth optimizing + while (!worklist.empty()) + { + auto [index, currEdge, lastEdge] = worklist.back(); + + // First visit + if (currEdge == -1) + { + ++childCount; + if (childLimit > 0 && childLimit <= childCount) + return TarjanResult::TooManyChildren; + + stack.push_back(index); + onStack[index] = true; + + currEdge = int(edgesTy.size()); + + // Fill in edge list of this vertex + if (TypeId ty = indexToType[index]) + visitChildren(ty, index); + else if (TypePackId tp = indexToPack[index]) + visitChildren(tp, index); + + lastEdge = int(edgesTy.size()); + } + + // Visit children + bool foundFresh = false; + + for (; currEdge < lastEdge; currEdge++) + { + int childIndex = -1; + bool fresh = false; + + if (auto ty = edgesTy[currEdge]) + std::tie(childIndex, fresh) = indexify(ty); + else if (auto tp = edgesTp[currEdge]) + std::tie(childIndex, fresh) = indexify(tp); + else + LUAU_ASSERT(false); + + if (fresh) + { + // Original recursion point, update the parent continuation point and start the new element + worklist.back() = {index, currEdge + 1, lastEdge}; + worklist.push_back({childIndex, -1, -1}); + + // We need to continue the top-level loop from the start with the new worklist element + foundFresh = true; + break; + } + else if (onStack[childIndex]) + { + lowlink[index] = std::min(lowlink[index], childIndex); + } + + visitEdge(childIndex, index); + } + + if (foundFresh) + continue; + + if (lowlink[index] == index) + { + visitSCC(index); + while (!stack.empty()) + { + int popped = stack.back(); + stack.pop_back(); + onStack[popped] = false; + if (popped == index) + break; + } + } + + worklist.pop_back(); + + // Original return from recursion into a child + if (!worklist.empty()) + { + auto [parentIndex, _, parentEndEdge] = worklist.back(); + + // No need to keep child edges around + edgesTy.resize(parentEndEdge); + edgesTp.resize(parentEndEdge); + + lowlink[parentIndex] = std::min(lowlink[parentIndex], lowlink[index]); + visitEdge(index, parentIndex); + } + } + + return TarjanResult::Ok; +} + + +void Tarjan::visitSCC_DEPRECATED(int index) +{ + bool d = getDirty(index); + + for (auto it = stack.rbegin(); !d && it != stack.rend(); it++) + { + if (TypeId ty = indexToType[*it]) + d = isDirty(ty); + else if (TypePackId tp = indexToPack[*it]) + d = isDirty(tp); + if (*it == index) + break; + } + + if (!d) + return; + + for (auto it = stack.rbegin(); it != stack.rend(); it++) + { + setDirty(*it, true); + if (TypeId ty = indexToType[*it]) + foundDirty(ty); + else if (TypePackId tp = indexToPack[*it]) + foundDirty(tp); + if (*it == index) + return; + } +} + std::optional Substitution::substitute(TypeId ty) { ty = log->follow(ty); // clear algorithm state for reentrancy - if (FFlag::LuauSubstitutionReentrant) - clearTarjan(); + clearTarjan(); auto result = findDirty(ty); if (result != TarjanResult::Ok) @@ -609,34 +816,18 @@ std::optional Substitution::substitute(TypeId ty) for (auto [oldTy, newTy] : newTypes) { - if (FFlag::LuauSubstitutionReentrant) + if (!ignoreChildren(oldTy) && !replacedTypes.contains(newTy)) { - if (!ignoreChildren(oldTy) && !replacedTypes.contains(newTy)) - { - replaceChildren(newTy); - replacedTypes.insert(newTy); - } - } - else - { - if (!ignoreChildren(oldTy)) - replaceChildren(newTy); + replaceChildren(newTy); + replacedTypes.insert(newTy); } } for (auto [oldTp, newTp] : newPacks) { - if (FFlag::LuauSubstitutionReentrant) - { - if (!ignoreChildren(oldTp) && !replacedTypePacks.contains(newTp)) - { - replaceChildren(newTp); - replacedTypePacks.insert(newTp); - } - } - else + if (!ignoreChildren(oldTp) && !replacedTypePacks.contains(newTp)) { - if (!ignoreChildren(oldTp)) - replaceChildren(newTp); + replaceChildren(newTp); + replacedTypePacks.insert(newTp); } } TypeId newTy = replace(ty); @@ -648,8 +839,7 @@ std::optional Substitution::substitute(TypePackId tp) tp = log->follow(tp); // clear algorithm state for reentrancy - if (FFlag::LuauSubstitutionReentrant) - clearTarjan(); + clearTarjan(); auto result = findDirty(tp); if (result != TarjanResult::Ok) @@ -657,34 +847,18 @@ std::optional Substitution::substitute(TypePackId tp) for (auto [oldTy, newTy] : newTypes) { - if (FFlag::LuauSubstitutionReentrant) + if (!ignoreChildren(oldTy) && !replacedTypes.contains(newTy)) { - if (!ignoreChildren(oldTy) && !replacedTypes.contains(newTy)) - { - replaceChildren(newTy); - replacedTypes.insert(newTy); - } - } - else - { - if (!ignoreChildren(oldTy)) - replaceChildren(newTy); + replaceChildren(newTy); + replacedTypes.insert(newTy); } } for (auto [oldTp, newTp] : newPacks) { - if (FFlag::LuauSubstitutionReentrant) + if (!ignoreChildren(oldTp) && !replacedTypePacks.contains(newTp)) { - if (!ignoreChildren(oldTp) && !replacedTypePacks.contains(newTp)) - { - replaceChildren(newTp); - replacedTypePacks.insert(newTp); - } - } - else - { - if (!ignoreChildren(oldTp)) - replaceChildren(newTp); + replaceChildren(newTp); + replacedTypePacks.insert(newTp); } } TypePackId newTp = replace(tp); @@ -714,8 +888,7 @@ TypePackId Substitution::clone(TypePackId tp) { VariadicTypePack clone; clone.ty = vtp->ty; - if (FFlag::LuauSubstitutionFixMissingFields) - clone.hidden = vtp->hidden; + clone.hidden = vtp->hidden; return addTypePack(std::move(clone)); } else if (const TypeFamilyInstanceTypePack* tfitp = get(tp)) @@ -738,7 +911,7 @@ void Substitution::foundDirty(TypeId ty) { ty = log->follow(ty); - if (FFlag::LuauSubstitutionReentrant && newTypes.contains(ty)) + if (newTypes.contains(ty)) return; if (isDirty(ty)) @@ -751,7 +924,7 @@ void Substitution::foundDirty(TypePackId tp) { tp = log->follow(tp); - if (FFlag::LuauSubstitutionReentrant && newPacks.contains(tp)) + if (newPacks.contains(tp)) return; if (isDirty(tp)) @@ -792,13 +965,10 @@ void Substitution::replaceChildren(TypeId ty) if (FunctionType* ftv = getMutable(ty)) { - if (FFlag::LuauSubstitutionFixMissingFields) - { - for (TypeId& generic : ftv->generics) - generic = replace(generic); - for (TypePackId& genericPack : ftv->genericPacks) - genericPack = replace(genericPack); - } + for (TypeId& generic : ftv->generics) + generic = replace(generic); + for (TypePackId& genericPack : ftv->genericPacks) + genericPack = replace(genericPack); ftv->argTypes = replace(ftv->argTypes); ftv->retTypes = replace(ftv->retTypes); @@ -857,7 +1027,7 @@ void Substitution::replaceChildren(TypeId ty) for (TypePackId& a : tfit->packArguments) a = replace(a); } - else if (ClassType* ctv = getMutable(ty); FFlag::LuauClassTypeVarsInSubstitution && ctv) + else if (ClassType* ctv = getMutable(ty)) { for (auto& [name, prop] : ctv->props) prop.setType(replace(prop.type())); diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index 0a9e9b648..7a46bf969 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -7,6 +7,7 @@ #include "Luau/Common.h" #include "Luau/DcrLogger.h" #include "Luau/Error.h" +#include "Luau/InsertionOrderedMap.h" #include "Luau/Instantiation.h" #include "Luau/Metamethods.h" #include "Luau/Normalize.h" @@ -656,7 +657,7 @@ struct TypeChecker2 // if the initial and expected types from the iterator unified during constraint solving, // we'll have a resolved type to use here, but we'll only use it if either the iterator is // directly present in the for-in statement or if we have an iterator state constraining us - TypeId* resolvedTy = module->astOverloadResolvedTypes.find(firstValue); + TypeId* resolvedTy = module->astForInNextTypes.find(firstValue); if (resolvedTy && (!retPack || valueTypes.size() > 1)) valueTypes[0] = *resolvedTy; @@ -1062,83 +1063,21 @@ struct TypeChecker2 // Note: this is intentionally separated from `visit(AstExprCall*)` for stack allocation purposes. void visitCall(AstExprCall* call) { - TypePackId expectedRetType = lookupExpectedPack(call, testArena); TypePack args; std::vector argLocs; argLocs.reserve(call->args.size + 1); - auto maybeOriginalCallTy = module->astOriginalCallTypes.find(call); - if (!maybeOriginalCallTy) + TypeId* originalCallTy = module->astOriginalCallTypes.find(call); + TypeId* selectedOverloadTy = module->astOverloadResolvedTypes.find(call); + if (!originalCallTy && !selectedOverloadTy) return; - TypeId originalCallTy = follow(*maybeOriginalCallTy); - std::vector overloads = flattenIntersection(originalCallTy); - - if (get(originalCallTy) || get(originalCallTy) || get(originalCallTy)) + TypeId fnTy = follow(selectedOverloadTy ? *selectedOverloadTy : *originalCallTy); + if (get(fnTy) || get(fnTy) || get(fnTy)) return; - else if (std::optional callMm = findMetatableEntry(builtinTypes, module->errors, originalCallTy, "__call", call->func->location)) - { - if (get(follow(*callMm))) - { - args.head.push_back(originalCallTy); - argLocs.push_back(call->func->location); - } - else - { - // TODO: This doesn't flag the __call metamethod as the problem - // very clearly. - reportError(CannotCallNonFunction{*callMm}, call->func->location); - return; - } - } - else if (get(originalCallTy)) - { - // ok. - } - else if (get(originalCallTy)) - { - auto norm = normalizer.normalize(originalCallTy); - if (!norm) - return reportError(CodeTooComplex{}, call->location); - - // NormalizedType::hasFunction returns true if its' tops component is `unknown`, but for soundness we want the reverse. - if (get(norm->tops) || !norm->hasFunctions()) - return reportError(CannotCallNonFunction{originalCallTy}, call->func->location); - } - else if (auto utv = get(originalCallTy)) - { - // Sometimes it's okay to call a union of functions, but only if all of the functions are the same. - // Another scenario we might run into it is if the union has a nil member. In this case, we want to throw an error - if (isOptional(originalCallTy)) - { - reportError(OptionalValueAccess{originalCallTy}, call->location); - return; - } - std::optional fst; - for (TypeId ty : utv) - { - if (!fst) - fst = follow(ty); - else if (fst != follow(ty)) - { - reportError(CannotCallNonFunction{originalCallTy}, call->func->location); - return; - } - } - - if (!fst) - ice->ice("UnionType had no elements, so fst is nullopt?"); - - originalCallTy = follow(*fst); - if (!get(originalCallTy)) - { - reportError(CannotCallNonFunction{originalCallTy}, call->func->location); - return; - } - } - else + else if (isOptional(fnTy)) { - reportError(CannotCallNonFunction{originalCallTy}, call->func->location); + reportError(OptionalValueAccess{fnTy}, call->func->location); return; } @@ -1161,9 +1100,12 @@ struct TypeChecker2 args.head.push_back(*argTy); else if (i == call->args.size - 1) { - TypePackId* argTail = module->astTypePacks.find(arg); - if (argTail) - args.tail = *argTail; + if (auto argTail = module->astTypePacks.find(arg)) + { + auto [head, tail] = flatten(*argTail); + args.head.insert(args.head.end(), head.begin(), head.end()); + args.tail = tail; + } else args.tail = builtinTypes->anyTypePack; } @@ -1171,141 +1113,317 @@ struct TypeChecker2 args.head.push_back(builtinTypes->anyType); } - TypePackId expectedArgTypes = testArena.addTypePack(args); + FunctionCallResolver resolver{ + builtinTypes, + NotNull{&testArena}, + NotNull{&normalizer}, + NotNull{stack.back()}, + ice, + call->location, + }; + + resolver.resolve(fnTy, &args, call->func->location, &argLocs); - if (auto maybeSelectedOverload = module->astOverloadResolvedTypes.find(call)) + if (!resolver.ok.empty()) + return; // We found a call that works, so this is ok. + else if (auto norm = normalizer.normalize(fnTy); !norm || !normalizer.isInhabited(norm)) { - // This overload might not work still: the constraint solver will - // pass the type checker an instantiated function type that matches - // in arity, but not in subtyping, in order to allow the type - // checker to report better error messages. + if (!norm) + reportError(NormalizationTooComplex{}, call->func->location); + else + return; // Ok. Calling an uninhabited type is no-op. + } + else if (!resolver.nonviableOverloads.empty()) + { + if (resolver.nonviableOverloads.size() == 1) + reportErrors(resolver.nonviableOverloads.front().second); + else + { + std::string s = "None of the overloads for function that accept "; + s += std::to_string(args.head.size()); + s += " arguments are compatible."; + reportError(GenericError{std::move(s)}, call->location); + } + } + else if (!resolver.arityMismatches.empty()) + { + if (resolver.arityMismatches.size() == 1) + reportErrors(resolver.arityMismatches.front().second); + else + { + std::string s = "No overload for function accepts "; + s += std::to_string(args.head.size()); + s += " arguments."; + reportError(GenericError{std::move(s)}, call->location); + } + } + else if (!resolver.nonFunctions.empty()) + reportError(CannotCallNonFunction{fnTy}, call->func->location); + else + LUAU_ASSERT(!"Generating the best possible error from this function call resolution was inexhaustive?"); - TypeId selectedOverload = follow(*maybeSelectedOverload); - const FunctionType* ftv; + if (resolver.arityMismatches.size() > 1 || resolver.nonviableOverloads.size() > 1) + { + std::string s = "Available overloads: "; - if (get(selectedOverload) || get(selectedOverload) || get(selectedOverload)) + std::vector overloads; + if (resolver.nonviableOverloads.empty()) { - return; + for (const auto& [ty, p] : resolver.resolution) + { + if (p.first == FunctionCallResolver::TypeIsNotAFunction) + continue; + + overloads.push_back(ty); + } } - else if (const FunctionType* overloadFtv = get(selectedOverload)) + else { - ftv = overloadFtv; + for (const auto& [ty, _] : resolver.nonviableOverloads) + overloads.push_back(ty); } - else + + for (size_t i = 0; i < overloads.size(); ++i) { - reportError(CannotCallNonFunction{selectedOverload}, call->func->location); - return; + if (i > 0) + s += (i == overloads.size() - 1) ? "; and " : "; "; + + s += toString(overloads[i]); } - TxnLog fake{}; + reportError(ExtraInformation{std::move(s)}, call->func->location); + } + } - LUAU_ASSERT(ftv); - reportErrors(tryUnify(stack.back(), call->location, ftv->retTypes, expectedRetType, CountMismatch::Context::Return, /* genericsOkay */ true)); - reportErrors( - reduceFamilies(ftv->retTypes, call->location, NotNull{&testArena}, builtinTypes, stack.back(), NotNull{&normalizer}, &fake, true) - .errors); + struct FunctionCallResolver + { + enum Analysis + { + Ok, + TypeIsNotAFunction, + ArityMismatch, + OverloadIsNonviable, // Arguments were incompatible with the overload's parameters, but were otherwise compatible by arity. + }; - auto it = begin(expectedArgTypes); - size_t i = 0; - std::vector slice; - for (TypeId arg : ftv->argTypes) - { - if (it == end(expectedArgTypes)) - { - slice.push_back(arg); - continue; - } + NotNull builtinTypes; + NotNull arena; + NotNull normalizer; + NotNull scope; + NotNull ice; + Location callLoc; + + std::vector ok; + std::vector nonFunctions; + std::vector> arityMismatches; + std::vector> nonviableOverloads; + InsertionOrderedMap> resolution; + + private: + template + std::optional tryUnify(const Location& location, Ty subTy, Ty superTy) + { + Unifier u{normalizer, scope, location, Covariant}; + u.ctx = CountMismatch::Arg; + u.hideousFixMeGenericsAreActuallyFree = true; + u.enableScopeTests(); + u.tryUnify(subTy, superTy); + + if (u.errors.empty()) + return std::nullopt; + + return std::move(u.errors); + } - TypeId expectedArg = *it; + std::pair checkOverload(TypeId fnTy, const TypePack* args, Location fnLoc, const std::vector* argLocs, bool callMetamethodOk = true) + { + fnTy = follow(fnTy); - Location argLoc = argLocs.at(i >= argLocs.size() ? argLocs.size() - 1 : i); + ErrorVec discard; + if (get(fnTy) || get(fnTy) || get(fnTy)) + return {Ok, {}}; + else if (auto fn = get(fnTy)) + return checkOverload_(fnTy, fn, args, fnLoc, argLocs); // Intentionally split to reduce the stack pressure of this function. + else if (auto callMm = findMetatableEntry(builtinTypes, discard, fnTy, "__call", callLoc); callMm && callMetamethodOk) + { + // Calling a metamethod forwards the `fnTy` as self. + TypePack withSelf = *args; + withSelf.head.insert(withSelf.head.begin(), fnTy); - reportErrors(tryUnify(stack.back(), argLoc, expectedArg, arg, CountMismatch::Context::Arg, /* genericsOkay */ true)); - reportErrors(reduceFamilies(arg, argLoc, NotNull{&testArena}, builtinTypes, stack.back(), NotNull{&normalizer}, &fake, true).errors); + std::vector withSelfLocs = *argLocs; + withSelfLocs.insert(withSelfLocs.begin(), fnLoc); - ++it; - ++i; + return checkOverload(*callMm, &withSelf, fnLoc, &withSelfLocs, /*callMetamethodOk=*/ false); } + else + return {TypeIsNotAFunction, {}}; // Intentionally empty. We can just fabricate the type error later on. + } + + LUAU_NOINLINE + std::pair checkOverload_(TypeId fnTy, const FunctionType* fn, const TypePack* args, Location fnLoc, const std::vector* argLocs) + { + TxnLog fake; + FamilyGraphReductionResult result = reduceFamilies(fnTy, callLoc, arena, builtinTypes, scope, normalizer, &fake, /*force=*/ true); + if (!result.errors.empty()) + return {OverloadIsNonviable, result.errors}; + + ErrorVec argumentErrors; + + // Reminder: Functions have parameters. You provide arguments. + auto paramIter = begin(fn->argTypes); + size_t argOffset = 0; - if (slice.size() > 0 && it == end(expectedArgTypes)) + while (paramIter != end(fn->argTypes)) { - if (auto tail = it.tail()) + if (argOffset >= args->head.size()) + break; + + TypeId paramTy = *paramIter; + TypeId argTy = args->head[argOffset]; + Location argLoc = argLocs->at(argOffset >= argLocs->size() ? argLocs->size() - 1 : argOffset); + + if (auto errors = tryUnify(argLoc, argTy, paramTy)) { - TypePackId remainingArgs = testArena.addTypePack(TypePack{std::move(slice), std::nullopt}); - reportErrors(tryUnify(stack.back(), argLocs.back(), *tail, remainingArgs, CountMismatch::Context::Arg, /* genericsOkay */ true)); - reportErrors(reduceFamilies( - remainingArgs, argLocs.back(), NotNull{&testArena}, builtinTypes, stack.back(), NotNull{&normalizer}, &fake, true) - .errors); + // Since we're stopping right here, we need to decide if this is a nonviable overload or if there is an arity mismatch. + // If it's a nonviable overload, then we need to keep going to get all type errors. + auto [minParams, optMaxParams] = getParameterExtents(TxnLog::empty(), fn->argTypes); + if (args->head.size() < minParams) + return {ArityMismatch, *errors}; + else + argumentErrors.insert(argumentErrors.end(), errors->begin(), errors->end()); } + + ++paramIter; + ++argOffset; } - } - else - { - // No overload worked, even when instantiated. We need to filter the - // set of overloads to those that match the arity of the incoming - // argument set, and then report only those as not matching. - std::vector arityMatchingOverloads; - ErrorVec empty; - for (TypeId overload : overloads) + while (argOffset < args->head.size()) { - overload = follow(overload); - if (const FunctionType* ftv = get(overload)) + // If we can iterate over the head of arguments, then we have exhausted the head of the parameters. + LUAU_ASSERT(paramIter == end(fn->argTypes)); + + Location argLoc = argLocs->at(argOffset >= argLocs->size() ? argLocs->size() - 1 : argOffset); + + if (!paramIter.tail()) { - if (size(ftv->argTypes) == size(expectedArgTypes)) - { - arityMatchingOverloads.push_back(overload); - } + auto [minParams, optMaxParams] = getParameterExtents(TxnLog::empty(), fn->argTypes); + TypeError error{argLoc, CountMismatch{minParams, optMaxParams, args->head.size(), CountMismatch::Arg, false}}; + return {ArityMismatch, {error}}; } - else if (const std::optional callMm = findMetatableEntry(builtinTypes, empty, overload, "__call", call->location)) + else if (auto vtp = get(follow(paramIter.tail()))) { - if (const FunctionType* ftv = get(follow(*callMm))) - { - if (size(ftv->argTypes) == size(expectedArgTypes)) - { - arityMatchingOverloads.push_back(overload); - } - } - else - { - reportError(CannotCallNonFunction{}, call->location); - } + if (auto errors = tryUnify(argLoc, args->head[argOffset], vtp->ty)) + argumentErrors.insert(argumentErrors.end(), errors->begin(), errors->end()); } + + ++argOffset; } - if (arityMatchingOverloads.size() == 0) + while (paramIter != end(fn->argTypes)) { - reportError( - GenericError{"No overload for function accepts " + std::to_string(size(expectedArgTypes)) + " arguments."}, call->location); + // If we can iterate over parameters, then we have exhausted the head of the arguments. + LUAU_ASSERT(argOffset == args->head.size()); + + // It may have a tail, however, so check that. + if (auto vtp = get(follow(args->tail))) + { + Location argLoc = argLocs->at(argLocs->size() - 1); + + if (auto errors = tryUnify(argLoc, vtp->ty, *paramIter)) + argumentErrors.insert(argumentErrors.end(), errors->begin(), errors->end()); + } + else if (!isOptional(*paramIter)) + { + Location argLoc = argLocs->empty() ? fnLoc : argLocs->at(argLocs->size() - 1); + + // It is ok to have excess parameters as long as they are all optional. + auto [minParams, optMaxParams] = getParameterExtents(TxnLog::empty(), fn->argTypes); + TypeError error{argLoc, CountMismatch{minParams, optMaxParams, args->head.size(), CountMismatch::Arg, false}}; + return {ArityMismatch, {error}}; + } + + ++paramIter; } - else + + // We hit the end of the heads for both parameters and arguments, so check their tails. + LUAU_ASSERT(paramIter == end(fn->argTypes)); + LUAU_ASSERT(argOffset == args->head.size()); + + if (paramIter.tail() && args->tail) { - // We have handled the case of a singular arity-matching - // overload above, in the case where an overload was selected. - // LUAU_ASSERT(arityMatchingOverloads.size() > 1); - reportError(GenericError{"None of the overloads for function that accept " + std::to_string(size(expectedArgTypes)) + - " arguments are compatible."}, - call->location); + Location argLoc = argLocs->at(argLocs->size() - 1); + + if (auto errors = tryUnify(argLoc, *args->tail, *paramIter.tail())) + argumentErrors.insert(argumentErrors.end(), errors->begin(), errors->end()); } - std::string s; - std::vector& stringifyOverloads = arityMatchingOverloads.size() == 0 ? overloads : arityMatchingOverloads; - for (size_t i = 0; i < stringifyOverloads.size(); ++i) + return {argumentErrors.empty() ? Ok : OverloadIsNonviable, argumentErrors}; + } + + size_t indexof(Analysis analysis) + { + switch (analysis) { - TypeId overload = follow(stringifyOverloads[i]); + case Ok: + return ok.size(); + case TypeIsNotAFunction: + return nonFunctions.size(); + case ArityMismatch: + return arityMismatches.size(); + case OverloadIsNonviable: + return nonviableOverloads.size(); + } - if (i > 0) - s += "; "; + ice->ice("Inexhaustive switch in FunctionCallResolver::indexof"); + } - if (i > 0 && i == stringifyOverloads.size() - 1) - s += "and "; + void add(Analysis analysis, TypeId ty, ErrorVec&& errors) + { + resolution.insert(ty, {analysis, indexof(analysis)}); - s += toString(overload); + switch (analysis) + { + case Ok: + LUAU_ASSERT(errors.empty()); + ok.push_back(ty); + break; + case TypeIsNotAFunction: + LUAU_ASSERT(errors.empty()); + nonFunctions.push_back(ty); + break; + case ArityMismatch: + LUAU_ASSERT(!errors.empty()); + arityMismatches.emplace_back(ty, std::move(errors)); + break; + case OverloadIsNonviable: + LUAU_ASSERT(!errors.empty()); + nonviableOverloads.emplace_back(ty, std::move(errors)); + break; } + } + + public: + void resolve(TypeId fnTy, const TypePack* args, Location selfLoc, const std::vector* argLocs) + { + fnTy = follow(fnTy); - reportError(ExtraInformation{"Available overloads: " + s}, call->func->location); + auto it = get(fnTy); + if (!it) + { + auto [analysis, errors] = checkOverload(fnTy, args, selfLoc, argLocs); + add(analysis, fnTy, std::move(errors)); + return; + } + + for (TypeId ty : it) + { + if (resolution.find(ty) != resolution.end()) + continue; + + auto [analysis, errors] = checkOverload(ty, args, selfLoc, argLocs); + add(analysis, ty, std::move(errors)); + } } - } + }; void visit(AstExprCall* call) { @@ -1584,7 +1702,11 @@ struct TypeChecker2 leftType = stripNil(builtinTypes, testArena, leftType); } - bool isStringOperation = isString(leftType) && isString(rightType); + const NormalizedType* normLeft = normalizer.normalize(leftType); + const NormalizedType* normRight = normalizer.normalize(rightType); + + bool isStringOperation = + (normLeft ? normLeft->isSubtypeOfString() : isString(leftType)) && (normRight ? normRight->isSubtypeOfString() : isString(rightType)); if (get(leftType) || get(leftType) || get(leftType)) return leftType; @@ -1630,14 +1752,15 @@ struct TypeChecker2 { testUnion(utv, leftMt); } + } - // If either left or right has no metatable (or both), we need to consider if - // there are values in common that could possibly inhabit the type (and thus equality could be considered) + // If we're working with things that are not tables, the metatable comparisons above are a little excessive + // It's ok for one type to have a meta table and the other to not. In that case, we should fall back on + // checking if the intersection of the types is inhabited. + // TODO: Maybe add more checks here (e.g. for functions, classes, etc) + if (!(get(leftType) || get(rightType))) if (!leftMt.has_value() || !rightMt.has_value()) - { matches = matches || typesHaveIntersection; - } - } if (!matches && isComparison) { @@ -1663,15 +1786,15 @@ struct TypeChecker2 if (overrideKey != nullptr) key = overrideKey; - TypeId instantiatedMm = module->astOverloadResolvedTypes[key]; - if (!instantiatedMm) + TypeId* selectedOverloadTy = module->astOverloadResolvedTypes.find(key); + if (!selectedOverloadTy) { // reportError(CodeTooComplex{}, expr->location); // was handled by a type family return expectedResult; } - else if (const FunctionType* ftv = get(follow(instantiatedMm))) + else if (const FunctionType* ftv = get(follow(*selectedOverloadTy))) { TypePackId expectedArgs; // For >= and > we invoke __lt and __le respectively with @@ -1803,13 +1926,12 @@ struct TypeChecker2 case AstExprBinary::Op::CompareLe: case AstExprBinary::Op::CompareLt: { - const NormalizedType* leftTyNorm = normalizer.normalize(leftType); - if (leftTyNorm && leftTyNorm->isExactlyNumber()) + if (normLeft && normLeft->isExactlyNumber()) { reportErrors(tryUnify(scope, expr->right->location, rightType, builtinTypes->numberType)); return builtinTypes->numberType; } - else if (leftTyNorm && leftTyNorm->isSubtypeOfString()) + else if (normLeft && normLeft->isSubtypeOfString()) { reportErrors(tryUnify(scope, expr->right->location, rightType, builtinTypes->stringType)); return builtinTypes->stringType; diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index a3d917045..c9da34f4c 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -35,7 +35,6 @@ LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) LUAU_FASTFLAGVARIABLE(DebugLuauSharedSelf, false) LUAU_FASTFLAG(LuauInstantiateInSubtyping) LUAU_FASTFLAGVARIABLE(LuauAllowIndexClassParameters, false) -LUAU_FASTFLAG(LuauUninhabitedSubAnything2) LUAU_FASTFLAG(LuauOccursIsntAlwaysFailure) LUAU_FASTFLAGVARIABLE(LuauTypecheckTypeguards, false) LUAU_FASTFLAGVARIABLE(LuauTinyControlFlowAnalysis, false) @@ -841,7 +840,7 @@ struct Demoter : Substitution bool ignoreChildren(TypeId ty) override { - if (FFlag::LuauClassTypeVarsInSubstitution && get(ty)) + if (get(ty)) return true; return false; @@ -2648,10 +2647,7 @@ static std::optional areEqComparable(NotNull arena, NotNullisInhabited(n); - else - return isInhabited_DEPRECATED(*n); + return normalizer->isInhabited(n); } TypeId TypeChecker::checkRelationalOperation( diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 91b89136a..eae007885 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -19,12 +19,10 @@ LUAU_FASTINT(LuauTypeInferTypePackLoopLimit) LUAU_FASTFLAG(LuauErrorRecoveryType) LUAU_FASTFLAGVARIABLE(LuauInstantiateInSubtyping, false) -LUAU_FASTFLAGVARIABLE(LuauUninhabitedSubAnything2, false) LUAU_FASTFLAGVARIABLE(LuauVariadicAnyCanBeGeneric, false) LUAU_FASTFLAGVARIABLE(LuauMaintainScopesInUnifier, false) LUAU_FASTFLAGVARIABLE(LuauTransitiveSubtyping, false) LUAU_FASTFLAGVARIABLE(LuauOccursIsntAlwaysFailure, false) -LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) LUAU_FASTFLAG(LuauNormalizeBlockedTypes) LUAU_FASTFLAG(LuauAlwaysCommitInferencesOfFunctionCalls) @@ -315,7 +313,7 @@ TypePackId Widen::clean(TypePackId) bool Widen::ignoreChildren(TypeId ty) { - if (FFlag::LuauClassTypeVarsInSubstitution && get(ty)) + if (get(ty)) return true; return !log->is(ty); @@ -748,10 +746,9 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool else if (log.get(superTy) || log.get(subTy)) tryUnifyNegations(subTy, superTy); - else if (FFlag::LuauUninhabitedSubAnything2 && checkInhabited && !normalizer->isInhabited(subTy)) + else if (checkInhabited && !normalizer->isInhabited(subTy)) { } - else reportError(location, TypeMismatch{superTy, subTy, mismatchContext()}); @@ -2365,7 +2362,7 @@ void Unifier::tryUnifyScalarShape(TypeId subTy, TypeId superTy, bool reversed) TypeId osubTy = subTy; TypeId osuperTy = superTy; - if (FFlag::LuauUninhabitedSubAnything2 && checkInhabited && !normalizer->isInhabited(subTy)) + if (checkInhabited && !normalizer->isInhabited(subTy)) return; if (reversed) @@ -2739,7 +2736,7 @@ void Unifier::tryUnifyVariadics(TypePackId subTp, TypePackId superTp, bool rever } } } - else if (FFlag::LuauVariadicAnyCanBeGeneric && get(variadicTy) && log.get(subTp)) + else if (get(variadicTy) && log.get(subTp)) { // Nothing to do. This is ok. } @@ -2893,7 +2890,7 @@ bool Unifier::occursCheck(TypeId needle, TypeId haystack, bool reversed) if (innerState.failure) { reportError(location, OccursCheckFailed{}); - log.replace(needle, *builtinTypes->errorRecoveryType()); + log.replace(needle, BoundType{builtinTypes->errorRecoveryType()}); } } diff --git a/CLI/Compile.cpp b/CLI/Compile.cpp index 293809d0e..6197f03ef 100644 --- a/CLI/Compile.cpp +++ b/CLI/Compile.cpp @@ -129,7 +129,7 @@ static double recordDeltaTime(double& timer) return delta; } -static bool compileFile(const char* name, CompileFormat format, CompileStats& stats) +static bool compileFile(const char* name, CompileFormat format, Luau::CodeGen::AssemblyOptions::Target assemblyTarget, CompileStats& stats) { double currts = Luau::TimeTrace::getClock(); @@ -150,6 +150,7 @@ static bool compileFile(const char* name, CompileFormat format, CompileStats& st Luau::BytecodeBuilder bcb; Luau::CodeGen::AssemblyOptions options; + options.target = assemblyTarget; options.outputBinary = format == CompileFormat::CodegenNull; if (!options.outputBinary) @@ -248,6 +249,7 @@ static void displayHelp(const char* argv0) printf(" -h, --help: Display this usage message.\n"); printf(" -O: compile with optimization level n (default 1, n should be between 0 and 2).\n"); printf(" -g: compile with debug level n (default 1, n should be between 0 and 2).\n"); + printf(" --target=: compile code for specific architecture (a64, x64, a64_nf, x64_ms).\n"); printf(" --timetrace: record compiler time tracing information into trace.json\n"); } @@ -264,6 +266,7 @@ int main(int argc, char** argv) setLuauFlagsDefault(); CompileFormat compileFormat = CompileFormat::Text; + Luau::CodeGen::AssemblyOptions::Target assemblyTarget = Luau::CodeGen::AssemblyOptions::Host; for (int i = 1; i < argc; i++) { @@ -292,6 +295,24 @@ int main(int argc, char** argv) } globalOptions.debugLevel = level; } + else if (strncmp(argv[i], "--target=", 9) == 0) + { + const char* value = argv[i] + 9; + + if (strcmp(value, "a64") == 0) + assemblyTarget = Luau::CodeGen::AssemblyOptions::A64; + else if (strcmp(value, "a64_nf") == 0) + assemblyTarget = Luau::CodeGen::AssemblyOptions::A64_NoFeatures; + else if (strcmp(value, "x64") == 0) + assemblyTarget = Luau::CodeGen::AssemblyOptions::X64_SystemV; + else if (strcmp(value, "x64_ms") == 0) + assemblyTarget = Luau::CodeGen::AssemblyOptions::X64_Windows; + else + { + fprintf(stderr, "Error: unknown target\n"); + return 1; + } + } else if (strcmp(argv[i], "--timetrace") == 0) { FFlag::DebugLuauTimeTracing.value = true; @@ -331,7 +352,7 @@ int main(int argc, char** argv) int failed = 0; for (const std::string& path : files) - failed += !compileFile(path.c_str(), compileFormat, stats); + failed += !compileFile(path.c_str(), compileFormat, assemblyTarget, stats); if (compileFormat == CompileFormat::Null) printf("Compiled %d KLOC into %d KB bytecode (read %.2fs, parse %.2fs, compile %.2fs)\n", int(stats.lines / 1000), int(stats.bytecode / 1024), diff --git a/CodeGen/include/Luau/CodeGen.h b/CodeGen/include/Luau/CodeGen.h index 30c26b24c..febd021cd 100644 --- a/CodeGen/include/Luau/CodeGen.h +++ b/CodeGen/include/Luau/CodeGen.h @@ -23,6 +23,17 @@ using AnnotatorFn = void (*)(void* context, std::string& result, int fid, int in struct AssemblyOptions { + enum Target + { + Host, + A64, + A64_NoFeatures, + X64_Windows, + X64_SystemV, + }; + + Target target = Host; + bool outputBinary = false; bool includeAssembly = false; diff --git a/CodeGen/include/Luau/IrData.h b/CodeGen/include/Luau/IrData.h index 1c79ccb47..8cbe7e8b1 100644 --- a/CodeGen/include/Luau/IrData.h +++ b/CodeGen/include/Luau/IrData.h @@ -414,6 +414,7 @@ enum class IrCmd : uint8_t // Handle GC write barrier (forward) // A: pointer (GCObject) // B: Rn (TValue that was written to the object) + // C: tag/undef (tag of the value that was written) BARRIER_OBJ, // Handle GC write barrier (backwards) for a write into a table @@ -423,6 +424,7 @@ enum class IrCmd : uint8_t // Handle GC write barrier (forward) for a write into a table // A: pointer (Table) // B: Rn (TValue that was written to the object) + // C: tag/undef (tag of the value that was written) BARRIER_TABLE_FORWARD, // Update savedpc value @@ -584,6 +586,14 @@ enum class IrCmd : uint8_t // B: double // C: double/int (optional, 2nd argument) INVOKE_LIBM, + + // Returns the string name of a type based on tag, alternative for type(x) + // A: tag + GET_TYPE, + + // Returns the string name of a type either from a __type metatable field or just based on the tag, alternative for typeof(x) + // A: Rn + GET_TYPEOF, }; enum class IrConstKind : uint8_t diff --git a/CodeGen/include/Luau/IrUtils.h b/CodeGen/include/Luau/IrUtils.h index a1211d46a..a3e97894c 100644 --- a/CodeGen/include/Luau/IrUtils.h +++ b/CodeGen/include/Luau/IrUtils.h @@ -189,6 +189,8 @@ inline bool hasResult(IrCmd cmd) case IrCmd::BITCOUNTLZ_UINT: case IrCmd::BITCOUNTRZ_UINT: case IrCmd::INVOKE_LIBM: + case IrCmd::GET_TYPE: + case IrCmd::GET_TYPEOF: return true; default: break; diff --git a/CodeGen/src/CodeGen.cpp b/CodeGen/src/CodeGen.cpp index d7283b4a6..63dd9a4d6 100644 --- a/CodeGen/src/CodeGen.cpp +++ b/CodeGen/src/CodeGen.cpp @@ -1,15 +1,12 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/CodeGen.h" +#include "CodeGenLower.h" + #include "Luau/Common.h" #include "Luau/CodeAllocator.h" #include "Luau/CodeBlockUnwind.h" -#include "Luau/IrAnalysis.h" #include "Luau/IrBuilder.h" -#include "Luau/IrDump.h" -#include "Luau/IrUtils.h" -#include "Luau/OptimizeConstProp.h" -#include "Luau/OptimizeFinalX64.h" #include "Luau/UnwindBuilder.h" #include "Luau/UnwindBuilderDwarf2.h" @@ -21,17 +18,10 @@ #include "NativeState.h" #include "CodeGenA64.h" -#include "EmitCommonA64.h" -#include "IrLoweringA64.h" - #include "CodeGenX64.h" -#include "EmitCommonX64.h" -#include "EmitInstructionX64.h" -#include "IrLoweringX64.h" #include "lapi.h" -#include #include #include @@ -107,238 +97,14 @@ static void logPerfFunction(Proto* p, uintptr_t addr, unsigned size) gPerfLogFn(gPerfLogContext, addr, size, name); } -template -static bool lowerImpl(AssemblyBuilder& build, IrLowering& lowering, IrFunction& function, int bytecodeid, AssemblyOptions options) -{ - // While we will need a better block ordering in the future, right now we want to mostly preserve build order with fallbacks outlined - std::vector sortedBlocks; - sortedBlocks.reserve(function.blocks.size()); - for (uint32_t i = 0; i < function.blocks.size(); i++) - sortedBlocks.push_back(i); - - std::sort(sortedBlocks.begin(), sortedBlocks.end(), [&](uint32_t idxA, uint32_t idxB) { - const IrBlock& a = function.blocks[idxA]; - const IrBlock& b = function.blocks[idxB]; - - // Place fallback blocks at the end - if ((a.kind == IrBlockKind::Fallback) != (b.kind == IrBlockKind::Fallback)) - return (a.kind == IrBlockKind::Fallback) < (b.kind == IrBlockKind::Fallback); - - // Try to order by instruction order - return a.sortkey < b.sortkey; - }); - - // For each IR instruction that begins a bytecode instruction, which bytecode instruction is it? - std::vector bcLocations(function.instructions.size() + 1, ~0u); - - for (size_t i = 0; i < function.bcMapping.size(); ++i) - { - uint32_t irLocation = function.bcMapping[i].irLocation; - - if (irLocation != ~0u) - bcLocations[irLocation] = uint32_t(i); - } - - bool outputEnabled = options.includeAssembly || options.includeIr; - - IrToStringContext ctx{build.text, function.blocks, function.constants, function.cfg}; - - // We use this to skip outlined fallback blocks from IR/asm text output - size_t textSize = build.text.length(); - uint32_t codeSize = build.getCodeSize(); - bool seenFallback = false; - - IrBlock dummy; - dummy.start = ~0u; - - for (size_t i = 0; i < sortedBlocks.size(); ++i) - { - uint32_t blockIndex = sortedBlocks[i]; - IrBlock& block = function.blocks[blockIndex]; - - if (block.kind == IrBlockKind::Dead) - continue; - - LUAU_ASSERT(block.start != ~0u); - LUAU_ASSERT(block.finish != ~0u); - - // If we want to skip fallback code IR/asm, we'll record when those blocks start once we see them - if (block.kind == IrBlockKind::Fallback && !seenFallback) - { - textSize = build.text.length(); - codeSize = build.getCodeSize(); - seenFallback = true; - } - - if (options.includeIr) - { - build.logAppend("# "); - toStringDetailed(ctx, block, blockIndex, /* includeUseInfo */ true); - } - - // Values can only reference restore operands in the current block - function.validRestoreOpBlockIdx = blockIndex; - - build.setLabel(block.label); - - for (uint32_t index = block.start; index <= block.finish; index++) - { - LUAU_ASSERT(index < function.instructions.size()); - - uint32_t bcLocation = bcLocations[index]; - - // If IR instruction is the first one for the original bytecode, we can annotate it with source code text - if (outputEnabled && options.annotator && bcLocation != ~0u) - { - options.annotator(options.annotatorContext, build.text, bytecodeid, bcLocation); - } - - // If bytecode needs the location of this instruction for jumps, record it - if (bcLocation != ~0u) - { - Label label = (index == block.start) ? block.label : build.setLabel(); - function.bcMapping[bcLocation].asmLocation = build.getLabelOffset(label); - } - - IrInst& inst = function.instructions[index]; - - // Skip pseudo instructions, but make sure they are not used at this stage - // This also prevents them from getting into text output when that's enabled - if (isPseudo(inst.cmd)) - { - LUAU_ASSERT(inst.useCount == 0); - continue; - } - - // Either instruction result value is not referenced or the use count is not zero - LUAU_ASSERT(inst.lastUse == 0 || inst.useCount != 0); - - if (options.includeIr) - { - build.logAppend("# "); - toStringDetailed(ctx, block, blockIndex, inst, index, /* includeUseInfo */ true); - } - - IrBlock& next = i + 1 < sortedBlocks.size() ? function.blocks[sortedBlocks[i + 1]] : dummy; - - lowering.lowerInst(inst, index, next); - - if (lowering.hasError()) - { - // Place labels for all blocks that we're skipping - // This is needed to avoid AssemblyBuilder assertions about jumps in earlier blocks with unplaced labels - for (size_t j = i + 1; j < sortedBlocks.size(); ++j) - { - IrBlock& abandoned = function.blocks[sortedBlocks[j]]; - - build.setLabel(abandoned.label); - } - - lowering.finishFunction(); - - return false; - } - } - - lowering.finishBlock(); - - if (options.includeIr) - build.logAppend("#\n"); - } - - if (!seenFallback) - { - textSize = build.text.length(); - codeSize = build.getCodeSize(); - } - - lowering.finishFunction(); - - if (outputEnabled && !options.includeOutlinedCode && textSize < build.text.size()) - { - build.text.resize(textSize); - - if (options.includeAssembly) - build.logAppend("; skipping %u bytes of outlined code\n", unsigned((build.getCodeSize() - codeSize) * sizeof(build.code[0]))); - } - - return true; -} - -[[maybe_unused]] static bool lowerIr( - X64::AssemblyBuilderX64& build, IrBuilder& ir, NativeState& data, ModuleHelpers& helpers, Proto* proto, AssemblyOptions options) -{ - optimizeMemoryOperandsX64(ir.function); - - X64::IrLoweringX64 lowering(build, helpers, data, ir.function); - - return lowerImpl(build, lowering, ir.function, proto->bytecodeid, options); -} - -[[maybe_unused]] static bool lowerIr( - A64::AssemblyBuilderA64& build, IrBuilder& ir, NativeState& data, ModuleHelpers& helpers, Proto* proto, AssemblyOptions options) -{ - A64::IrLoweringA64 lowering(build, helpers, data, ir.function); - - return lowerImpl(build, lowering, ir.function, proto->bytecodeid, options); -} - template -static std::optional assembleFunction(AssemblyBuilder& build, NativeState& data, ModuleHelpers& helpers, Proto* proto, AssemblyOptions options) +static std::optional createNativeFunction(AssemblyBuilder& build, ModuleHelpers& helpers, Proto* proto) { - if (options.includeAssembly || options.includeIr) - { - if (proto->debugname) - build.logAppend("; function %s(", getstr(proto->debugname)); - else - build.logAppend("; function("); - - for (int i = 0; i < proto->numparams; i++) - { - LocVar* var = proto->locvars ? &proto->locvars[proto->sizelocvars - proto->numparams + i] : nullptr; - - if (var && var->varname) - build.logAppend("%s%s", i == 0 ? "" : ", ", getstr(var->varname)); - else - build.logAppend("%s$arg%d", i == 0 ? "" : ", ", i); - } - - if (proto->numparams != 0 && proto->is_vararg) - build.logAppend(", ...)"); - else - build.logAppend(")"); - - if (proto->linedefined >= 0) - build.logAppend(" line %d\n", proto->linedefined); - else - build.logAppend("\n"); - } - IrBuilder ir; ir.buildFunctionIr(proto); - computeCfgInfo(ir.function); - - if (!FFlag::DebugCodegenNoOpt) - { - bool useValueNumbering = !FFlag::DebugCodegenSkipNumbering; - - constPropInBlockChains(ir, useValueNumbering); - - if (!FFlag::DebugCodegenOptSize) - createLinearBlocks(ir, useValueNumbering); - } - - if (!lowerIr(build, ir, data, helpers, proto, options)) - { - if (build.logText) - build.logAppend("; skipping (can't lower)\n\n"); - + if (!lowerFunction(ir, build, helpers, proto, {})) return std::nullopt; - } - - if (build.logText) - build.logAppend("\n"); return createNativeProto(proto, ir); } @@ -384,7 +150,7 @@ static void onSetBreakpoint(lua_State* L, Proto* proto, int instruction) } #if defined(__aarch64__) -static unsigned int getCpuFeaturesA64() +unsigned int getCpuFeaturesA64() { unsigned int result = 0; @@ -482,21 +248,6 @@ void create(lua_State* L) ecb->setbreakpoint = onSetBreakpoint; } -static void gatherFunctions(std::vector& results, Proto* proto) -{ - if (results.size() <= size_t(proto->bytecodeid)) - results.resize(proto->bytecodeid + 1); - - // Skip protos that we've already compiled in this run: this happens because at -O2, inlined functions get their protos reused - if (results[proto->bytecodeid]) - return; - - results[proto->bytecodeid] = proto; - - for (int i = 0; i < proto->sizep; i++) - gatherFunctions(results, proto->p[i]); -} - void compile(lua_State* L, int idx) { LUAU_ASSERT(lua_isLfunction(L, idx)); @@ -529,7 +280,7 @@ void compile(lua_State* L, int idx) // Skip protos that have been compiled during previous invocations of CodeGen::compile for (Proto* p : protos) if (p && p->execdata == nullptr) - if (std::optional np = assembleFunction(build, *data, helpers, p, {})) + if (std::optional np = createNativeFunction(build, helpers, p)) results.push_back(*np); // Very large modules might result in overflowing a jump offset; in this case we currently abandon the entire module @@ -580,51 +331,6 @@ void compile(lua_State* L, int idx) } } -std::string getAssembly(lua_State* L, int idx, AssemblyOptions options) -{ - LUAU_ASSERT(lua_isLfunction(L, idx)); - const TValue* func = luaA_toobject(L, idx); - -#if defined(__aarch64__) - A64::AssemblyBuilderA64 build(/* logText= */ options.includeAssembly, getCpuFeaturesA64()); -#else - X64::AssemblyBuilderX64 build(/* logText= */ options.includeAssembly); -#endif - - NativeState data; - initFunctions(data); - - std::vector protos; - gatherFunctions(protos, clvalue(func)->l.p); - - ModuleHelpers helpers; -#if defined(__aarch64__) - A64::assembleHelpers(build, helpers); -#else - X64::assembleHelpers(build, helpers); -#endif - - if (!options.includeOutlinedCode && options.includeAssembly) - { - build.text.clear(); - build.logAppend("; skipping %u bytes of outlined helpers\n", unsigned(build.getCodeSize() * sizeof(build.code[0]))); - } - - for (Proto* p : protos) - if (p) - if (std::optional np = assembleFunction(build, data, helpers, p, options)) - destroyExecData(np->execdata); - - if (!build.finalize()) - return std::string(); - - if (options.outputBinary) - return std::string(reinterpret_cast(build.code.data()), reinterpret_cast(build.code.data() + build.code.size())) + - std::string(build.data.begin(), build.data.end()); - else - return build.text; -} - void setPerfLog(void* context, PerfLogFn logFn) { gPerfLogContext = context; diff --git a/CodeGen/src/CodeGenAssembly.cpp b/CodeGen/src/CodeGenAssembly.cpp new file mode 100644 index 000000000..36d8b274f --- /dev/null +++ b/CodeGen/src/CodeGenAssembly.cpp @@ -0,0 +1,146 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/CodeGen.h" + +#include "CodeGenLower.h" + +#include "CodeGenA64.h" +#include "CodeGenX64.h" + +#include "lapi.h" + +namespace Luau +{ +namespace CodeGen +{ + +template +static void logFunctionHeader(AssemblyBuilder& build, Proto* proto) +{ + if (proto->debugname) + build.logAppend("; function %s(", getstr(proto->debugname)); + else + build.logAppend("; function("); + + for (int i = 0; i < proto->numparams; i++) + { + LocVar* var = proto->locvars ? &proto->locvars[proto->sizelocvars - proto->numparams + i] : nullptr; + + if (var && var->varname) + build.logAppend("%s%s", i == 0 ? "" : ", ", getstr(var->varname)); + else + build.logAppend("%s$arg%d", i == 0 ? "" : ", ", i); + } + + if (proto->numparams != 0 && proto->is_vararg) + build.logAppend(", ...)"); + else + build.logAppend(")"); + + if (proto->linedefined >= 0) + build.logAppend(" line %d\n", proto->linedefined); + else + build.logAppend("\n"); +} + +template +static std::string getAssemblyImpl(AssemblyBuilder& build, const TValue* func, AssemblyOptions options) +{ + std::vector protos; + gatherFunctions(protos, clvalue(func)->l.p); + + ModuleHelpers helpers; + assembleHelpers(build, helpers); + + if (!options.includeOutlinedCode && options.includeAssembly) + { + build.text.clear(); + build.logAppend("; skipping %u bytes of outlined helpers\n", unsigned(build.getCodeSize() * sizeof(build.code[0]))); + } + + for (Proto* p : protos) + if (p) + { + IrBuilder ir; + ir.buildFunctionIr(p); + + if (options.includeAssembly || options.includeIr) + logFunctionHeader(build, p); + + if (!lowerFunction(ir, build, helpers, p, options)) + { + if (build.logText) + build.logAppend("; skipping (can't lower)\n"); + } + + if (build.logText) + build.logAppend("\n"); + } + + if (!build.finalize()) + return std::string(); + + if (options.outputBinary) + return std::string(reinterpret_cast(build.code.data()), reinterpret_cast(build.code.data() + build.code.size())) + + std::string(build.data.begin(), build.data.end()); + else + return build.text; +} + +#if defined(__aarch64__) +unsigned int getCpuFeaturesA64(); +#endif + +std::string getAssembly(lua_State* L, int idx, AssemblyOptions options) +{ + LUAU_ASSERT(lua_isLfunction(L, idx)); + const TValue* func = luaA_toobject(L, idx); + + switch (options.target) + { + case AssemblyOptions::Host: + { +#if defined(__aarch64__) + A64::AssemblyBuilderA64 build(/* logText= */ options.includeAssembly, getCpuFeaturesA64()); +#else + X64::AssemblyBuilderX64 build(/* logText= */ options.includeAssembly); +#endif + + return getAssemblyImpl(build, func, options); + } + + case AssemblyOptions::A64: + { + A64::AssemblyBuilderA64 build(/* logText= */ options.includeAssembly, /* features= */ A64::Feature_JSCVT); + + return getAssemblyImpl(build, func, options); + } + + case AssemblyOptions::A64_NoFeatures: + { + A64::AssemblyBuilderA64 build(/* logText= */ options.includeAssembly, /* features= */ 0); + + return getAssemblyImpl(build, func, options); + } + + case AssemblyOptions::X64_Windows: + { + X64::AssemblyBuilderX64 build(/* logText= */ options.includeAssembly, X64::ABIX64::Windows); + + return getAssemblyImpl(build, func, options); + } + + case AssemblyOptions::X64_SystemV: + { + X64::AssemblyBuilderX64 build(/* logText= */ options.includeAssembly, X64::ABIX64::SystemV); + + return getAssemblyImpl(build, func, options); + } + + default: + LUAU_ASSERT(!"Unknown target"); + return std::string(); + } +} + +} // namespace CodeGen +} // namespace Luau diff --git a/CodeGen/src/CodeGenLower.h b/CodeGen/src/CodeGenLower.h new file mode 100644 index 000000000..5b6c4ffc4 --- /dev/null +++ b/CodeGen/src/CodeGenLower.h @@ -0,0 +1,240 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/AssemblyBuilderA64.h" +#include "Luau/AssemblyBuilderX64.h" +#include "Luau/CodeGen.h" +#include "Luau/IrBuilder.h" +#include "Luau/IrDump.h" +#include "Luau/IrUtils.h" +#include "Luau/OptimizeConstProp.h" +#include "Luau/OptimizeFinalX64.h" + +#include "EmitCommon.h" +#include "IrLoweringA64.h" +#include "IrLoweringX64.h" + +#include "lobject.h" +#include "lstate.h" + +#include +#include + +LUAU_FASTFLAG(DebugCodegenNoOpt) +LUAU_FASTFLAG(DebugCodegenOptSize) +LUAU_FASTFLAG(DebugCodegenSkipNumbering) + +namespace Luau +{ +namespace CodeGen +{ + +inline void gatherFunctions(std::vector& results, Proto* proto) +{ + if (results.size() <= size_t(proto->bytecodeid)) + results.resize(proto->bytecodeid + 1); + + // Skip protos that we've already compiled in this run: this happens because at -O2, inlined functions get their protos reused + if (results[proto->bytecodeid]) + return; + + results[proto->bytecodeid] = proto; + + for (int i = 0; i < proto->sizep; i++) + gatherFunctions(results, proto->p[i]); +} + +template +inline bool lowerImpl(AssemblyBuilder& build, IrLowering& lowering, IrFunction& function, int bytecodeid, AssemblyOptions options) +{ + // While we will need a better block ordering in the future, right now we want to mostly preserve build order with fallbacks outlined + std::vector sortedBlocks; + sortedBlocks.reserve(function.blocks.size()); + for (uint32_t i = 0; i < function.blocks.size(); i++) + sortedBlocks.push_back(i); + + std::sort(sortedBlocks.begin(), sortedBlocks.end(), [&](uint32_t idxA, uint32_t idxB) { + const IrBlock& a = function.blocks[idxA]; + const IrBlock& b = function.blocks[idxB]; + + // Place fallback blocks at the end + if ((a.kind == IrBlockKind::Fallback) != (b.kind == IrBlockKind::Fallback)) + return (a.kind == IrBlockKind::Fallback) < (b.kind == IrBlockKind::Fallback); + + // Try to order by instruction order + return a.sortkey < b.sortkey; + }); + + // For each IR instruction that begins a bytecode instruction, which bytecode instruction is it? + std::vector bcLocations(function.instructions.size() + 1, ~0u); + + for (size_t i = 0; i < function.bcMapping.size(); ++i) + { + uint32_t irLocation = function.bcMapping[i].irLocation; + + if (irLocation != ~0u) + bcLocations[irLocation] = uint32_t(i); + } + + bool outputEnabled = options.includeAssembly || options.includeIr; + + IrToStringContext ctx{build.text, function.blocks, function.constants, function.cfg}; + + // We use this to skip outlined fallback blocks from IR/asm text output + size_t textSize = build.text.length(); + uint32_t codeSize = build.getCodeSize(); + bool seenFallback = false; + + IrBlock dummy; + dummy.start = ~0u; + + for (size_t i = 0; i < sortedBlocks.size(); ++i) + { + uint32_t blockIndex = sortedBlocks[i]; + IrBlock& block = function.blocks[blockIndex]; + + if (block.kind == IrBlockKind::Dead) + continue; + + LUAU_ASSERT(block.start != ~0u); + LUAU_ASSERT(block.finish != ~0u); + + // If we want to skip fallback code IR/asm, we'll record when those blocks start once we see them + if (block.kind == IrBlockKind::Fallback && !seenFallback) + { + textSize = build.text.length(); + codeSize = build.getCodeSize(); + seenFallback = true; + } + + if (options.includeIr) + { + build.logAppend("# "); + toStringDetailed(ctx, block, blockIndex, /* includeUseInfo */ true); + } + + // Values can only reference restore operands in the current block + function.validRestoreOpBlockIdx = blockIndex; + + build.setLabel(block.label); + + for (uint32_t index = block.start; index <= block.finish; index++) + { + LUAU_ASSERT(index < function.instructions.size()); + + uint32_t bcLocation = bcLocations[index]; + + // If IR instruction is the first one for the original bytecode, we can annotate it with source code text + if (outputEnabled && options.annotator && bcLocation != ~0u) + { + options.annotator(options.annotatorContext, build.text, bytecodeid, bcLocation); + } + + // If bytecode needs the location of this instruction for jumps, record it + if (bcLocation != ~0u) + { + Label label = (index == block.start) ? block.label : build.setLabel(); + function.bcMapping[bcLocation].asmLocation = build.getLabelOffset(label); + } + + IrInst& inst = function.instructions[index]; + + // Skip pseudo instructions, but make sure they are not used at this stage + // This also prevents them from getting into text output when that's enabled + if (isPseudo(inst.cmd)) + { + LUAU_ASSERT(inst.useCount == 0); + continue; + } + + // Either instruction result value is not referenced or the use count is not zero + LUAU_ASSERT(inst.lastUse == 0 || inst.useCount != 0); + + if (options.includeIr) + { + build.logAppend("# "); + toStringDetailed(ctx, block, blockIndex, inst, index, /* includeUseInfo */ true); + } + + IrBlock& next = i + 1 < sortedBlocks.size() ? function.blocks[sortedBlocks[i + 1]] : dummy; + + lowering.lowerInst(inst, index, next); + + if (lowering.hasError()) + { + // Place labels for all blocks that we're skipping + // This is needed to avoid AssemblyBuilder assertions about jumps in earlier blocks with unplaced labels + for (size_t j = i + 1; j < sortedBlocks.size(); ++j) + { + IrBlock& abandoned = function.blocks[sortedBlocks[j]]; + + build.setLabel(abandoned.label); + } + + lowering.finishFunction(); + + return false; + } + } + + lowering.finishBlock(); + + if (options.includeIr) + build.logAppend("#\n"); + } + + if (!seenFallback) + { + textSize = build.text.length(); + codeSize = build.getCodeSize(); + } + + lowering.finishFunction(); + + if (outputEnabled && !options.includeOutlinedCode && textSize < build.text.size()) + { + build.text.resize(textSize); + + if (options.includeAssembly) + build.logAppend("; skipping %u bytes of outlined code\n", unsigned((build.getCodeSize() - codeSize) * sizeof(build.code[0]))); + } + + return true; +} + +inline bool lowerIr(X64::AssemblyBuilderX64& build, IrBuilder& ir, ModuleHelpers& helpers, Proto* proto, AssemblyOptions options) +{ + optimizeMemoryOperandsX64(ir.function); + + X64::IrLoweringX64 lowering(build, helpers, ir.function); + + return lowerImpl(build, lowering, ir.function, proto->bytecodeid, options); +} + +inline bool lowerIr(A64::AssemblyBuilderA64& build, IrBuilder& ir, ModuleHelpers& helpers, Proto* proto, AssemblyOptions options) +{ + A64::IrLoweringA64 lowering(build, helpers, ir.function); + + return lowerImpl(build, lowering, ir.function, proto->bytecodeid, options); +} + +template +inline bool lowerFunction(IrBuilder& ir, AssemblyBuilder& build, ModuleHelpers& helpers, Proto* proto, AssemblyOptions options) +{ + computeCfgInfo(ir.function); + + if (!FFlag::DebugCodegenNoOpt) + { + bool useValueNumbering = !FFlag::DebugCodegenSkipNumbering; + + constPropInBlockChains(ir, useValueNumbering); + + if (!FFlag::DebugCodegenOptSize) + createLinearBlocks(ir, useValueNumbering); + } + + return lowerIr(build, ir, helpers, proto, options); +} + +} // namespace CodeGen +} // namespace Luau diff --git a/CodeGen/src/EmitBuiltinsX64.cpp b/CodeGen/src/EmitBuiltinsX64.cpp index 96599c2e5..efc480e02 100644 --- a/CodeGen/src/EmitBuiltinsX64.cpp +++ b/CodeGen/src/EmitBuiltinsX64.cpp @@ -75,29 +75,6 @@ static void emitBuiltinMathSign(IrRegAllocX64& regs, AssemblyBuilderX64& build, build.vmovsd(luauRegValue(ra), tmp0.reg); } -static void emitBuiltinType(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int arg) -{ - ScopedRegX64 tmp0{regs, SizeX64::qword}; - ScopedRegX64 tag{regs, SizeX64::dword}; - - build.mov(tag.reg, luauRegTag(arg)); - - build.mov(tmp0.reg, qword[rState + offsetof(lua_State, global)]); - build.mov(tmp0.reg, qword[tmp0.reg + qwordReg(tag.reg) * sizeof(TString*) + offsetof(global_State, ttname)]); - - build.mov(luauRegValue(ra), tmp0.reg); -} - -static void emitBuiltinTypeof(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int arg) -{ - IrCallWrapperX64 callWrap(regs, build); - callWrap.addArgument(SizeX64::qword, rState); - callWrap.addArgument(SizeX64::qword, luauRegAddress(arg)); - callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaT_objtypenamestr)]); - - build.mov(luauRegValue(ra), rax); -} - void emitBuiltin(IrRegAllocX64& regs, AssemblyBuilderX64& build, int bfid, int ra, int arg, OperandX64 arg2, int nparams, int nresults) { switch (bfid) @@ -111,12 +88,6 @@ void emitBuiltin(IrRegAllocX64& regs, AssemblyBuilderX64& build, int bfid, int r case LBF_MATH_SIGN: LUAU_ASSERT(nparams == 1 && nresults == 1); return emitBuiltinMathSign(regs, build, ra, arg); - case LBF_TYPE: - LUAU_ASSERT(nparams == 1 && nresults == 1); - return emitBuiltinType(regs, build, ra, arg); - case LBF_TYPEOF: - LUAU_ASSERT(nparams == 1 && nresults == 1); - return emitBuiltinTypeof(regs, build, ra, arg); default: LUAU_ASSERT(!"Missing x64 lowering"); } diff --git a/CodeGen/src/EmitCommonX64.cpp b/CodeGen/src/EmitCommonX64.cpp index f240d26f7..4d70bb7a7 100644 --- a/CodeGen/src/EmitCommonX64.cpp +++ b/CodeGen/src/EmitCommonX64.cpp @@ -5,6 +5,7 @@ #include "Luau/IrCallWrapperX64.h" #include "Luau/IrData.h" #include "Luau/IrRegAllocX64.h" +#include "Luau/IrUtils.h" #include "NativeState.h" @@ -179,11 +180,15 @@ void callSetTable(IrRegAllocX64& regs, AssemblyBuilderX64& build, int rb, Operan emitUpdateBase(build); } -void checkObjectBarrierConditions(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 object, int ra, Label& skip) +void checkObjectBarrierConditions(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 object, int ra, int ratag, Label& skip) { - // iscollectable(ra) - build.cmp(luauRegTag(ra), LUA_TSTRING); - build.jcc(ConditionX64::Less, skip); + // Barrier should've been optimized away if we know that it's not collectable, checking for correctness + if (ratag == -1 || !isGCO(ratag)) + { + // iscollectable(ra) + build.cmp(luauRegTag(ra), LUA_TSTRING); + build.jcc(ConditionX64::Less, skip); + } // isblack(obj2gco(o)) build.test(byte[object + offsetof(GCheader, marked)], bitmask(BLACKBIT)); @@ -195,12 +200,12 @@ void checkObjectBarrierConditions(AssemblyBuilderX64& build, RegisterX64 tmp, Re build.jcc(ConditionX64::Zero, skip); } -void callBarrierObject(IrRegAllocX64& regs, AssemblyBuilderX64& build, RegisterX64 object, IrOp objectOp, int ra) +void callBarrierObject(IrRegAllocX64& regs, AssemblyBuilderX64& build, RegisterX64 object, IrOp objectOp, int ra, int ratag) { Label skip; ScopedRegX64 tmp{regs, SizeX64::qword}; - checkObjectBarrierConditions(build, tmp.reg, object, ra, skip); + checkObjectBarrierConditions(build, tmp.reg, object, ra, ratag, skip); { ScopedSpills spillGuard(regs); diff --git a/CodeGen/src/EmitCommonX64.h b/CodeGen/src/EmitCommonX64.h index 37be73fda..5a3548f6c 100644 --- a/CodeGen/src/EmitCommonX64.h +++ b/CodeGen/src/EmitCommonX64.h @@ -170,8 +170,8 @@ void callLengthHelper(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, in void callPrepareForN(IrRegAllocX64& regs, AssemblyBuilderX64& build, int limit, int step, int init); void callGetTable(IrRegAllocX64& regs, AssemblyBuilderX64& build, int rb, OperandX64 c, int ra); void callSetTable(IrRegAllocX64& regs, AssemblyBuilderX64& build, int rb, OperandX64 c, int ra); -void checkObjectBarrierConditions(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 object, int ra, Label& skip); -void callBarrierObject(IrRegAllocX64& regs, AssemblyBuilderX64& build, RegisterX64 object, IrOp objectOp, int ra); +void checkObjectBarrierConditions(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 object, int ra, int ratag, Label& skip); +void callBarrierObject(IrRegAllocX64& regs, AssemblyBuilderX64& build, RegisterX64 object, IrOp objectOp, int ra, int ratag); void callBarrierTableFast(IrRegAllocX64& regs, AssemblyBuilderX64& build, RegisterX64 table, IrOp tableOp); void callStepGc(IrRegAllocX64& regs, AssemblyBuilderX64& build); diff --git a/CodeGen/src/IrAnalysis.cpp b/CodeGen/src/IrAnalysis.cpp index 14fc9b467..cf7161ef0 100644 --- a/CodeGen/src/IrAnalysis.cpp +++ b/CodeGen/src/IrAnalysis.cpp @@ -444,6 +444,9 @@ static RegisterSet computeBlockLiveInRegSet(IrFunction& function, const IrBlock& case IrCmd::ADJUST_STACK_TO_TOP: // While this can be considered to be a vararg consumer, it is already handled in fastcall instructions break; + case IrCmd::GET_TYPEOF: + use(inst.a); + break; default: // All instructions which reference registers have to be handled explicitly diff --git a/CodeGen/src/IrDump.cpp b/CodeGen/src/IrDump.cpp index 09cafbaa8..e699229c8 100644 --- a/CodeGen/src/IrDump.cpp +++ b/CodeGen/src/IrDump.cpp @@ -297,6 +297,10 @@ const char* getCmdName(IrCmd cmd) return "BITCOUNTRZ_UINT"; case IrCmd::INVOKE_LIBM: return "INVOKE_LIBM"; + case IrCmd::GET_TYPE: + return "GET_TYPE"; + case IrCmd::GET_TYPEOF: + return "GET_TYPEOF"; } LUAU_UNREACHABLE(); diff --git a/CodeGen/src/IrLoweringA64.cpp b/CodeGen/src/IrLoweringA64.cpp index 94c46dbfd..3cf921730 100644 --- a/CodeGen/src/IrLoweringA64.cpp +++ b/CodeGen/src/IrLoweringA64.cpp @@ -60,14 +60,18 @@ inline ConditionA64 getConditionFP(IrCondition cond) } } -static void checkObjectBarrierConditions(AssemblyBuilderA64& build, RegisterA64 object, RegisterA64 temp, int ra, Label& skip) +static void checkObjectBarrierConditions(AssemblyBuilderA64& build, RegisterA64 object, RegisterA64 temp, int ra, int ratag, Label& skip) { RegisterA64 tempw = castReg(KindA64::w, temp); - // iscollectable(ra) - build.ldr(tempw, mem(rBase, ra * sizeof(TValue) + offsetof(TValue, tt))); - build.cmp(tempw, LUA_TSTRING); - build.b(ConditionA64::Less, skip); + // Barrier should've been optimized away if we know that it's not collectable, checking for correctness + if (ratag == -1 || !isGCO(ratag)) + { + // iscollectable(ra) + build.ldr(tempw, mem(rBase, ra * sizeof(TValue) + offsetof(TValue, tt))); + build.cmp(tempw, LUA_TSTRING); + build.b(ConditionA64::Less, skip); + } // isblack(obj2gco(o)) build.ldrb(tempw, mem(object, offsetof(GCheader, marked))); @@ -162,33 +166,15 @@ static bool emitBuiltin( build.str(d0, mem(rBase, res * sizeof(TValue) + offsetof(TValue, value.n))); return true; - case LBF_TYPE: - build.ldr(w0, mem(rBase, arg * sizeof(TValue) + offsetof(TValue, tt))); - build.ldr(x1, mem(rState, offsetof(lua_State, global))); - LUAU_ASSERT(sizeof(TString*) == 8); - build.add(x1, x1, zextReg(w0), 3); - build.ldr(x0, mem(x1, offsetof(global_State, ttname))); - build.str(x0, mem(rBase, res * sizeof(TValue) + offsetof(TValue, value.gc))); - return true; - - case LBF_TYPEOF: - build.mov(x0, rState); - build.add(x1, rBase, uint16_t(arg * sizeof(TValue))); - build.ldr(x2, mem(rNativeContext, offsetof(NativeContext, luaT_objtypenamestr))); - build.blr(x2); - build.str(x0, mem(rBase, res * sizeof(TValue) + offsetof(TValue, value.gc))); - return true; - default: LUAU_ASSERT(!"Missing A64 lowering"); return false; } } -IrLoweringA64::IrLoweringA64(AssemblyBuilderA64& build, ModuleHelpers& helpers, NativeState& data, IrFunction& function) +IrLoweringA64::IrLoweringA64(AssemblyBuilderA64& build, ModuleHelpers& helpers, IrFunction& function) : build(build) , helpers(helpers) - , data(data) , function(function) , regs(function, {{x0, x15}, {x16, x17}, {q0, q7}, {q16, q31}}) , valueTracker(function) @@ -1004,7 +990,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.str(temp3, temp2); Label skip; - checkObjectBarrierConditions(build, temp1, temp2, vmRegOp(inst.b), skip); + checkObjectBarrierConditions(build, temp1, temp2, vmRegOp(inst.b), /* ratag */ -1, skip); size_t spills = regs.spill(build, index, {temp1}); @@ -1210,7 +1196,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) RegisterA64 temp = regs.allocTemp(KindA64::x); Label skip; - checkObjectBarrierConditions(build, regOp(inst.a), temp, vmRegOp(inst.b), skip); + checkObjectBarrierConditions(build, regOp(inst.a), temp, vmRegOp(inst.b), inst.c.kind == IrOpKind::Undef ? -1 : tagOp(inst.c), skip); RegisterA64 reg = regOp(inst.a); // note: we need to call regOp before spill so that we don't do redundant reloads size_t spills = regs.spill(build, index, {reg}); @@ -1254,7 +1240,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) RegisterA64 temp = regs.allocTemp(KindA64::x); Label skip; - checkObjectBarrierConditions(build, regOp(inst.a), temp, vmRegOp(inst.b), skip); + checkObjectBarrierConditions(build, regOp(inst.a), temp, vmRegOp(inst.b), inst.c.kind == IrOpKind::Undef ? -1 : tagOp(inst.c), skip); RegisterA64 reg = regOp(inst.a); // note: we need to call regOp before spill so that we don't do redundant reloads size_t spills = regs.spill(build, index, {reg}); @@ -1710,6 +1696,34 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) inst.regA64 = regs.takeReg(d0, index); break; } + case IrCmd::GET_TYPE: + { + inst.regA64 = regs.allocReg(KindA64::x, index); + + build.ldr(inst.regA64, mem(rState, offsetof(lua_State, global))); + LUAU_ASSERT(sizeof(TString*) == 8); + + if (inst.a.kind == IrOpKind::Inst) + build.add(inst.regA64, inst.regA64, zextReg(regOp(inst.a)), 3); + else if (inst.a.kind == IrOpKind::Constant) + build.add(inst.regA64, inst.regA64, uint16_t(tagOp(inst.a)) * 8); + else + LUAU_ASSERT(!"Unsupported instruction form"); + + build.ldr(inst.regA64, mem(inst.regA64, offsetof(global_State, ttname))); + break; + } + case IrCmd::GET_TYPEOF: + { + regs.spill(build, index); + build.mov(x0, rState); + build.add(x1, rBase, uint16_t(vmRegOp(inst.a) * sizeof(TValue))); + build.ldr(x2, mem(rNativeContext, offsetof(NativeContext, luaT_objtypenamestr))); + build.blr(x2); + + inst.regA64 = regs.takeReg(x0, index); + break; + } // To handle unsupported instructions, add "case IrCmd::OP" and make sure to set error = true! } diff --git a/CodeGen/src/IrLoweringA64.h b/CodeGen/src/IrLoweringA64.h index fc228cf16..57c18b2ed 100644 --- a/CodeGen/src/IrLoweringA64.h +++ b/CodeGen/src/IrLoweringA64.h @@ -15,7 +15,6 @@ namespace CodeGen { struct ModuleHelpers; -struct NativeState; struct AssemblyOptions; namespace A64 @@ -23,7 +22,7 @@ namespace A64 struct IrLoweringA64 { - IrLoweringA64(AssemblyBuilderA64& build, ModuleHelpers& helpers, NativeState& data, IrFunction& function); + IrLoweringA64(AssemblyBuilderA64& build, ModuleHelpers& helpers, IrFunction& function); void lowerInst(IrInst& inst, uint32_t index, IrBlock& next); void finishBlock(); @@ -63,7 +62,6 @@ struct IrLoweringA64 AssemblyBuilderA64& build; ModuleHelpers& helpers; - NativeState& data; IrFunction& function; diff --git a/CodeGen/src/IrLoweringX64.cpp b/CodeGen/src/IrLoweringX64.cpp index 320cb0791..abe02eedb 100644 --- a/CodeGen/src/IrLoweringX64.cpp +++ b/CodeGen/src/IrLoweringX64.cpp @@ -22,10 +22,9 @@ namespace CodeGen namespace X64 { -IrLoweringX64::IrLoweringX64(AssemblyBuilderX64& build, ModuleHelpers& helpers, NativeState& data, IrFunction& function) +IrLoweringX64::IrLoweringX64(AssemblyBuilderX64& build, ModuleHelpers& helpers, IrFunction& function) : build(build) , helpers(helpers) - , data(data) , function(function) , regs(build, function) , valueTracker(function) @@ -872,7 +871,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) tmp1.free(); - callBarrierObject(regs, build, tmp2.release(), {}, vmRegOp(inst.b)); + callBarrierObject(regs, build, tmp2.release(), {}, vmRegOp(inst.b), /* ratag */ -1); break; } case IrCmd::PREPARE_FORN: @@ -983,7 +982,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) callStepGc(regs, build); break; case IrCmd::BARRIER_OBJ: - callBarrierObject(regs, build, regOp(inst.a), inst.a, vmRegOp(inst.b)); + callBarrierObject(regs, build, regOp(inst.a), inst.a, vmRegOp(inst.b), inst.c.kind == IrOpKind::Undef ? -1 : tagOp(inst.c)); break; case IrCmd::BARRIER_TABLE_BACK: callBarrierTableFast(regs, build, regOp(inst.a), inst.a); @@ -993,7 +992,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) Label skip; ScopedRegX64 tmp{regs, SizeX64::qword}; - checkObjectBarrierConditions(build, tmp.reg, regOp(inst.a), vmRegOp(inst.b), skip); + checkObjectBarrierConditions(build, tmp.reg, regOp(inst.a), vmRegOp(inst.b), inst.c.kind == IrOpKind::Undef ? -1 : tagOp(inst.c), skip); { ScopedSpills spillGuard(regs); @@ -1350,6 +1349,30 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) inst.regX64 = regs.takeReg(xmm0, index); break; } + case IrCmd::GET_TYPE: + { + inst.regX64 = regs.allocReg(SizeX64::qword, index); + + build.mov(inst.regX64, qword[rState + offsetof(lua_State, global)]); + + if (inst.a.kind == IrOpKind::Inst) + build.mov(inst.regX64, qword[inst.regX64 + qwordReg(regOp(inst.a)) * sizeof(TString*) + offsetof(global_State, ttname)]); + else if (inst.a.kind == IrOpKind::Constant) + build.mov(inst.regX64, qword[inst.regX64 + tagOp(inst.a) * sizeof(TString*) + offsetof(global_State, ttname)]); + else + LUAU_ASSERT(!"Unsupported instruction form"); + break; + } + case IrCmd::GET_TYPEOF: + { + IrCallWrapperX64 callWrap(regs, build); + callWrap.addArgument(SizeX64::qword, rState); + callWrap.addArgument(SizeX64::qword, luauRegAddress(vmRegOp(inst.a))); + callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaT_objtypenamestr)]); + + inst.regX64 = regs.takeReg(rax, index); + break; + } // Pseudo instructions case IrCmd::NOP: @@ -1376,7 +1399,7 @@ void IrLoweringX64::finishFunction() for (InterruptHandler& handler : interruptHandlers) { build.setLabel(handler.self); - build.mov(rax, handler.pcpos + 1); + build.mov(eax, handler.pcpos + 1); build.lea(rbx, handler.next); build.jmp(helpers.interrupt); } diff --git a/CodeGen/src/IrLoweringX64.h b/CodeGen/src/IrLoweringX64.h index a375a334c..f50812e42 100644 --- a/CodeGen/src/IrLoweringX64.h +++ b/CodeGen/src/IrLoweringX64.h @@ -17,7 +17,6 @@ namespace CodeGen { struct ModuleHelpers; -struct NativeState; struct AssemblyOptions; namespace X64 @@ -25,7 +24,7 @@ namespace X64 struct IrLoweringX64 { - IrLoweringX64(AssemblyBuilderX64& build, ModuleHelpers& helpers, NativeState& data, IrFunction& function); + IrLoweringX64(AssemblyBuilderX64& build, ModuleHelpers& helpers, IrFunction& function); void lowerInst(IrInst& inst, uint32_t index, IrBlock& next); void finishBlock(); @@ -63,7 +62,6 @@ struct IrLoweringX64 AssemblyBuilderX64& build; ModuleHelpers& helpers; - NativeState& data; IrFunction& function; diff --git a/CodeGen/src/IrTranslateBuiltins.cpp b/CodeGen/src/IrTranslateBuiltins.cpp index cfa4bc6c1..e99a991a7 100644 --- a/CodeGen/src/IrTranslateBuiltins.cpp +++ b/CodeGen/src/IrTranslateBuiltins.cpp @@ -344,8 +344,10 @@ static BuiltinImplResult translateBuiltinType(IrBuilder& build, int nparams, int if (nparams < 1 || nresults > 1) return {BuiltinImplType::None, -1}; - build.inst(IrCmd::FASTCALL, build.constUint(LBF_TYPE), build.vmReg(ra), build.vmReg(arg), args, build.constInt(1), build.constInt(1)); + IrOp tag = build.inst(IrCmd::LOAD_TAG, build.vmReg(arg)); + IrOp name = build.inst(IrCmd::GET_TYPE, tag); + build.inst(IrCmd::STORE_POINTER, build.vmReg(ra), name); build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TSTRING)); return {BuiltinImplType::UsesFallback, 1}; @@ -356,8 +358,9 @@ static BuiltinImplResult translateBuiltinTypeof(IrBuilder& build, int nparams, i if (nparams < 1 || nresults > 1) return {BuiltinImplType::None, -1}; - build.inst(IrCmd::FASTCALL, build.constUint(LBF_TYPEOF), build.vmReg(ra), build.vmReg(arg), args, build.constInt(1), build.constInt(1)); + IrOp name = build.inst(IrCmd::GET_TYPEOF, build.vmReg(arg)); + build.inst(IrCmd::STORE_POINTER, build.vmReg(ra), name); build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TSTRING)); return {BuiltinImplType::UsesFallback, 1}; diff --git a/CodeGen/src/IrTranslation.cpp b/CodeGen/src/IrTranslation.cpp index 8e135dfe0..8f18827bc 100644 --- a/CodeGen/src/IrTranslation.cpp +++ b/CodeGen/src/IrTranslation.cpp @@ -825,7 +825,7 @@ void translateInstSetTableN(IrBuilder& build, const Instruction* pc, int pcpos) IrOp tva = build.inst(IrCmd::LOAD_TVALUE, build.vmReg(ra)); build.inst(IrCmd::STORE_TVALUE, arrEl, tva); - build.inst(IrCmd::BARRIER_TABLE_FORWARD, vb, build.vmReg(ra)); + build.inst(IrCmd::BARRIER_TABLE_FORWARD, vb, build.vmReg(ra), build.undef()); IrOp next = build.blockAtInst(pcpos + 1); FallbackStreamScope scope(build, fallback, next); @@ -902,7 +902,7 @@ void translateInstSetTable(IrBuilder& build, const Instruction* pc, int pcpos) IrOp tva = build.inst(IrCmd::LOAD_TVALUE, build.vmReg(ra)); build.inst(IrCmd::STORE_TVALUE, arrEl, tva); - build.inst(IrCmd::BARRIER_TABLE_FORWARD, vb, build.vmReg(ra)); + build.inst(IrCmd::BARRIER_TABLE_FORWARD, vb, build.vmReg(ra), build.undef()); IrOp next = build.blockAtInst(pcpos + 1); FallbackStreamScope scope(build, fallback, next); @@ -989,7 +989,7 @@ void translateInstSetTableKS(IrBuilder& build, const Instruction* pc, int pcpos) IrOp tva = build.inst(IrCmd::LOAD_TVALUE, build.vmReg(ra)); build.inst(IrCmd::STORE_NODE_VALUE_TV, addrSlotEl, tva); - build.inst(IrCmd::BARRIER_TABLE_FORWARD, vb, build.vmReg(ra)); + build.inst(IrCmd::BARRIER_TABLE_FORWARD, vb, build.vmReg(ra), build.undef()); IrOp next = build.blockAtInst(pcpos + 2); FallbackStreamScope scope(build, fallback, next); @@ -1036,7 +1036,7 @@ void translateInstSetGlobal(IrBuilder& build, const Instruction* pc, int pcpos) IrOp tva = build.inst(IrCmd::LOAD_TVALUE, build.vmReg(ra)); build.inst(IrCmd::STORE_NODE_VALUE_TV, addrSlotEl, tva); - build.inst(IrCmd::BARRIER_TABLE_FORWARD, env, build.vmReg(ra)); + build.inst(IrCmd::BARRIER_TABLE_FORWARD, env, build.vmReg(ra), build.undef()); IrOp next = build.blockAtInst(pcpos + 2); FallbackStreamScope scope(build, fallback, next); diff --git a/CodeGen/src/IrUtils.cpp b/CodeGen/src/IrUtils.cpp index 70ad1438c..833d1cdd7 100644 --- a/CodeGen/src/IrUtils.cpp +++ b/CodeGen/src/IrUtils.cpp @@ -159,6 +159,9 @@ IrValueKind getCmdValueKind(IrCmd cmd) return IrValueKind::Int; case IrCmd::INVOKE_LIBM: return IrValueKind::Double; + case IrCmd::GET_TYPE: + case IrCmd::GET_TYPEOF: + return IrValueKind::Pointer; } LUAU_UNREACHABLE(); diff --git a/CodeGen/src/IrValueLocationTracking.cpp b/CodeGen/src/IrValueLocationTracking.cpp index be661a7df..e94a43476 100644 --- a/CodeGen/src/IrValueLocationTracking.cpp +++ b/CodeGen/src/IrValueLocationTracking.cpp @@ -108,6 +108,7 @@ void IrValueLocationTracking::beforeInstLowering(IrInst& inst) case IrCmd::FALLBACK_SETTABLEKS: case IrCmd::FALLBACK_PREPVARARGS: case IrCmd::ADJUST_STACK_TO_TOP: + case IrCmd::GET_TYPEOF: break; // These instrucitons read VmReg only after optimizeMemoryOperandsX64 diff --git a/CodeGen/src/OptimizeConstProp.cpp b/CodeGen/src/OptimizeConstProp.cpp index b779fb4b7..e3cbef415 100644 --- a/CodeGen/src/OptimizeConstProp.cpp +++ b/CodeGen/src/OptimizeConstProp.cpp @@ -732,6 +732,8 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& // If the written object is not collectable, barrier is not required if (!isGCO(tag)) kill(function, inst); + else + replace(function, inst.c, build.constTag(tag)); } } break; @@ -820,6 +822,8 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& case IrCmd::BITCOUNTLZ_UINT: case IrCmd::BITCOUNTRZ_UINT: case IrCmd::INVOKE_LIBM: + case IrCmd::GET_TYPE: + case IrCmd::GET_TYPEOF: break; case IrCmd::JUMP_CMP_ANY: diff --git a/Common/include/Luau/DenseHash.h b/Common/include/Luau/DenseHash.h index ce0dee6f5..997e090f6 100644 --- a/Common/include/Luau/DenseHash.h +++ b/Common/include/Luau/DenseHash.h @@ -33,7 +33,7 @@ class DenseHashTable class const_iterator; class iterator; - DenseHashTable(const Key& empty_key, size_t buckets = 0) + explicit DenseHashTable(const Key& empty_key, size_t buckets = 0) : data(nullptr) , capacity(0) , count(0) @@ -477,7 +477,7 @@ class DenseHashSet typedef typename Impl::const_iterator const_iterator; typedef typename Impl::iterator iterator; - DenseHashSet(const Key& empty_key, size_t buckets = 0) + explicit DenseHashSet(const Key& empty_key, size_t buckets = 0) : impl(empty_key, buckets) { } @@ -546,7 +546,7 @@ class DenseHashMap typedef typename Impl::const_iterator const_iterator; typedef typename Impl::iterator iterator; - DenseHashMap(const Key& empty_key, size_t buckets = 0) + explicit DenseHashMap(const Key& empty_key, size_t buckets = 0) : impl(empty_key, buckets) { } @@ -584,6 +584,22 @@ class DenseHashMap return impl.find(key) != 0; } + std::pair try_insert(const Key& key, const Value& value) + { + impl.rehash_if_full(key); + + size_t before = impl.size(); + std::pair* slot = impl.insert_unsafe(key); + + // Value is fresh if container count has increased + bool fresh = impl.size() > before; + + if (fresh) + slot->second = value; + + return std::make_pair(std::ref(slot->second), fresh); + } + size_t size() const { return impl.size(); diff --git a/Makefile b/Makefile index d3bf31d2e..852b14f83 100644 --- a/Makefile +++ b/Makefile @@ -161,7 +161,7 @@ clean: rm -rf $(BUILD) rm -rf $(EXECUTABLE_ALIASES) -coverage: $(TESTS_TARGET) +coverage: $(TESTS_TARGET) $(COMPILE_CLI_TARGET) $(TESTS_TARGET) mv default.profraw tests.profraw $(TESTS_TARGET) --fflags=true @@ -170,7 +170,11 @@ coverage: $(TESTS_TARGET) mv default.profraw codegen.profraw $(TESTS_TARGET) -ts=Conformance --codegen --fflags=true mv default.profraw codegen-flags.profraw - llvm-profdata merge tests.profraw tests-flags.profraw codegen.profraw codegen-flags.profraw -o default.profdata + $(COMPILE_CLI_TARGET) --codegennull --target=a64 tests/conformance + mv default.profraw codegen-a64.profraw + $(COMPILE_CLI_TARGET) --codegennull --target=x64 tests/conformance + mv default.profraw codegen-x64.profraw + llvm-profdata merge *.profraw -o default.profdata rm *.profraw llvm-cov show -format=html -show-instantiations=false -show-line-counts=true -show-region-summary=false -ignore-filename-regex=\(tests\|extern\|CLI\)/.* -output-dir=coverage --instr-profile default.profdata build/coverage/luau-tests llvm-cov report -ignore-filename-regex=\(tests\|extern\|CLI\)/.* -show-region-summary=false --instr-profile default.profdata build/coverage/luau-tests diff --git a/Sources.cmake b/Sources.cmake index 5b9bd61eb..74709b4bf 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -88,6 +88,7 @@ target_sources(Luau.CodeGen PRIVATE CodeGen/src/CodeAllocator.cpp CodeGen/src/CodeBlockUnwind.cpp CodeGen/src/CodeGen.cpp + CodeGen/src/CodeGenAssembly.cpp CodeGen/src/CodeGenUtils.cpp CodeGen/src/CodeGenA64.cpp CodeGen/src/CodeGenX64.cpp @@ -115,6 +116,7 @@ target_sources(Luau.CodeGen PRIVATE CodeGen/src/BitUtils.h CodeGen/src/ByteUtils.h + CodeGen/src/CodeGenLower.h CodeGen/src/CodeGenUtils.h CodeGen/src/CodeGenA64.h CodeGen/src/CodeGenX64.h diff --git a/tests/CostModel.test.cpp b/tests/CostModel.test.cpp index 018fa87cf..686a99d17 100644 --- a/tests/CostModel.test.cpp +++ b/tests/CostModel.test.cpp @@ -31,7 +31,7 @@ static uint64_t modelFunction(const char* source) AstStatFunction* func = result.root->body.data[0]->as(); REQUIRE(func); - return Luau::Compile::modelCost(func->func->body, func->func->args.data, func->func->args.size, {nullptr}); + return Luau::Compile::modelCost(func->func->body, func->func->args.data, func->func->args.size, DenseHashMap{nullptr}); } TEST_CASE("Expression") diff --git a/tests/IrBuilder.test.cpp b/tests/IrBuilder.test.cpp index 5b0c44d0c..f1399a590 100644 --- a/tests/IrBuilder.test.cpp +++ b/tests/IrBuilder.test.cpp @@ -1005,9 +1005,9 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "SkipUselessBarriers") build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tnumber)); IrOp table = build.inst(IrCmd::LOAD_POINTER, build.vmReg(1)); - build.inst(IrCmd::BARRIER_TABLE_FORWARD, table, build.vmReg(0)); + build.inst(IrCmd::BARRIER_TABLE_FORWARD, table, build.vmReg(0), build.undef()); IrOp something = build.inst(IrCmd::LOAD_POINTER, build.vmReg(2)); - build.inst(IrCmd::BARRIER_OBJ, something, build.vmReg(0)); + build.inst(IrCmd::BARRIER_OBJ, something, build.vmReg(0), build.undef()); build.inst(IrCmd::RETURN, build.constUint(0)); updateUseCounts(build.function); diff --git a/tests/Module.test.cpp b/tests/Module.test.cpp index abdfea77c..74e8a959e 100644 --- a/tests/Module.test.cpp +++ b/tests/Module.test.cpp @@ -409,9 +409,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "do_not_clone_reexports") { ScopedFastFlag flags[] = { {"LuauClonePublicInterfaceLess2", true}, - {"LuauSubstitutionReentrant", true}, - {"LuauClassTypeVarsInSubstitution", true}, - {"LuauSubstitutionFixMissingFields", true}, }; fileResolver.source["Module/A"] = R"( @@ -447,9 +444,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "do_not_clone_types_of_reexported_values") { ScopedFastFlag flags[] = { {"LuauClonePublicInterfaceLess2", true}, - {"LuauSubstitutionReentrant", true}, - {"LuauClassTypeVarsInSubstitution", true}, - {"LuauSubstitutionFixMissingFields", true}, }; fileResolver.source["Module/A"] = R"( diff --git a/tests/TypeFamily.test.cpp b/tests/TypeFamily.test.cpp index c10131baa..613aec801 100644 --- a/tests/TypeFamily.test.cpp +++ b/tests/TypeFamily.test.cpp @@ -140,8 +140,8 @@ TEST_CASE_FIXTURE(FamilyFixture, "unsolvable_family") local b = impossible(true) )"); - LUAU_REQUIRE_ERROR_COUNT(4, result); - for (size_t i = 0; i < 4; ++i) + LUAU_REQUIRE_ERROR_COUNT(2, result); + for (size_t i = 0; i < 2; ++i) { CHECK(toString(result.errors[i]) == "Type family instance Swap is uninhabited"); } diff --git a/tests/TypeInfer.aliases.test.cpp b/tests/TypeInfer.aliases.test.cpp index 52de15c75..e4577df67 100644 --- a/tests/TypeInfer.aliases.test.cpp +++ b/tests/TypeInfer.aliases.test.cpp @@ -8,7 +8,6 @@ using namespace Luau; LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) -LUAU_FASTFLAG(LuauTypeMismatchInvarianceInError) TEST_SUITE_BEGIN("TypeAliases"); @@ -199,15 +198,9 @@ TEST_CASE_FIXTURE(Fixture, "generic_aliases") LUAU_REQUIRE_ERROR_COUNT(1, result); - const char* expectedError; - if (FFlag::LuauTypeMismatchInvarianceInError) - expectedError = "Type 'bad' could not be converted into 'T'\n" - "caused by:\n" - " Property 'v' is not compatible. Type 'string' could not be converted into 'number' in an invariant context"; - else - expectedError = "Type 'bad' could not be converted into 'T'\n" - "caused by:\n" - " Property 'v' is not compatible. Type 'string' could not be converted into 'number'"; + const char* expectedError = "Type 'bad' could not be converted into 'T'\n" + "caused by:\n" + " Property 'v' is not compatible. Type 'string' could not be converted into 'number' in an invariant context"; CHECK(result.errors[0].location == Location{{4, 31}, {4, 44}}); CHECK(toString(result.errors[0]) == expectedError); @@ -226,19 +219,11 @@ TEST_CASE_FIXTURE(Fixture, "dependent_generic_aliases") LUAU_REQUIRE_ERROR_COUNT(1, result); - std::string expectedError; - if (FFlag::LuauTypeMismatchInvarianceInError) - expectedError = "Type 'bad' could not be converted into 'U'\n" - "caused by:\n" - " Property 't' is not compatible. Type '{ v: string }' could not be converted into 'T'\n" - "caused by:\n" - " Property 'v' is not compatible. Type 'string' could not be converted into 'number' in an invariant context"; - else - expectedError = "Type 'bad' could not be converted into 'U'\n" - "caused by:\n" - " Property 't' is not compatible. Type '{ v: string }' could not be converted into 'T'\n" - "caused by:\n" - " Property 'v' is not compatible. Type 'string' could not be converted into 'number'"; + std::string expectedError = "Type 'bad' could not be converted into 'U'\n" + "caused by:\n" + " Property 't' is not compatible. Type '{ v: string }' could not be converted into 'T'\n" + "caused by:\n" + " Property 'v' is not compatible. Type 'string' could not be converted into 'number' in an invariant context"; CHECK(result.errors[0].location == Location{{4, 31}, {4, 52}}); CHECK(toString(result.errors[0]) == expectedError); diff --git a/tests/TypeInfer.anyerror.test.cpp b/tests/TypeInfer.anyerror.test.cpp index 0f255f08c..687bc766d 100644 --- a/tests/TypeInfer.anyerror.test.cpp +++ b/tests/TypeInfer.anyerror.test.cpp @@ -108,10 +108,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_error2") end )"); - if (FFlag::DebugLuauDeferredConstraintResolution) - LUAU_REQUIRE_ERROR_COUNT(2, result); - else - LUAU_REQUIRE_ERROR_COUNT(1, result); + LUAU_REQUIRE_ERROR_COUNT(1, result); CHECK_EQ("*error-type*", toString(requireType("a"))); } diff --git a/tests/TypeInfer.classes.test.cpp b/tests/TypeInfer.classes.test.cpp index 37ecab2ee..9e3a63f79 100644 --- a/tests/TypeInfer.classes.test.cpp +++ b/tests/TypeInfer.classes.test.cpp @@ -12,8 +12,6 @@ using namespace Luau; using std::nullopt; -LUAU_FASTFLAG(LuauTypeMismatchInvarianceInError); - TEST_SUITE_BEGIN("TypeInferClasses"); TEST_CASE_FIXTURE(ClassFixture, "call_method_of_a_class") @@ -462,14 +460,9 @@ local b: B = a )"); LUAU_REQUIRE_ERRORS(result); - if (FFlag::LuauTypeMismatchInvarianceInError) - CHECK_EQ(toString(result.errors[0]), R"(Type 'A' could not be converted into 'B' + CHECK_EQ(toString(result.errors[0]), R"(Type 'A' could not be converted into 'B' caused by: Property 'x' is not compatible. Type 'ChildClass' could not be converted into 'BaseClass' in an invariant context)"); - else - CHECK_EQ(toString(result.errors[0]), R"(Type 'A' could not be converted into 'B' -caused by: - Property 'x' is not compatible. Type 'ChildClass' could not be converted into 'BaseClass')"); } TEST_CASE_FIXTURE(ClassFixture, "callable_classes") diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index 5aabb240b..f0630ca9a 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -1312,10 +1312,6 @@ f(function(x) return x * 2 end) TEST_CASE_FIXTURE(Fixture, "variadic_any_is_compatible_with_a_generic_TypePack") { - ScopedFastFlag sff[] = { - {"LuauVariadicAnyCanBeGeneric", true} - }; - CheckResult result = check(R"( --!strict local function f(...) return ... end @@ -1328,8 +1324,6 @@ TEST_CASE_FIXTURE(Fixture, "variadic_any_is_compatible_with_a_generic_TypePack") // https://github.com/Roblox/luau/issues/767 TEST_CASE_FIXTURE(BuiltinsFixture, "variadic_any_is_compatible_with_a_generic_TypePack_2") { - ScopedFastFlag sff{"LuauVariadicAnyCanBeGeneric", true}; - CheckResult result = check(R"( local function somethingThatsAny(...: any) print(...) @@ -1920,8 +1914,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "dont_assert_when_the_tarjan_limit_is_exceede ScopedFastFlag sff[] = { {"DebugLuauDeferredConstraintResolution", true}, {"LuauClonePublicInterfaceLess2", true}, - {"LuauSubstitutionReentrant", true}, - {"LuauSubstitutionFixMissingFields", true}, {"LuauCloneSkipNonInternalVisit", true}, }; @@ -2089,4 +2081,19 @@ TEST_CASE_FIXTURE(Fixture, "attempt_to_call_an_intersection_of_tables") CHECK_EQ(toString(result.errors[0]), "Cannot call non-function {| x: number |}"); } +TEST_CASE_FIXTURE(BuiltinsFixture, "attempt_to_call_an_intersection_of_tables_with_call_metamethod") +{ + CheckResult result = check(R"( + type Callable = typeof(setmetatable({}, { + __call = function(self, ...) return ... end + })) + + local function f(t: Callable & { x: number }) + t() + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.generics.test.cpp b/tests/TypeInfer.generics.test.cpp index 5ab27f645..72323cf90 100644 --- a/tests/TypeInfer.generics.test.cpp +++ b/tests/TypeInfer.generics.test.cpp @@ -10,7 +10,6 @@ #include "doctest.h" LUAU_FASTFLAG(LuauInstantiateInSubtyping) -LUAU_FASTFLAG(LuauTypeMismatchInvarianceInError) using namespace Luau; @@ -725,24 +724,12 @@ y.a.c = y )"); LUAU_REQUIRE_ERRORS(result); - if (FFlag::LuauTypeMismatchInvarianceInError) - { - CHECK_EQ(toString(result.errors[0]), - R"(Type 'y' could not be converted into 'T' + CHECK_EQ(toString(result.errors[0]), + R"(Type 'y' could not be converted into 'T' caused by: Property 'a' is not compatible. Type '{ c: T?, d: number }' could not be converted into 'U' caused by: Property 'd' is not compatible. Type 'number' could not be converted into 'string' in an invariant context)"); - } - else - { - CHECK_EQ(toString(result.errors[0]), - R"(Type 'y' could not be converted into 'T' -caused by: - Property 'a' is not compatible. Type '{ c: T?, d: number }' could not be converted into 'U' -caused by: - Property 'd' is not compatible. Type 'number' could not be converted into 'string')"); - } } TEST_CASE_FIXTURE(Fixture, "generic_type_pack_unification1") diff --git a/tests/TypeInfer.intersectionTypes.test.cpp b/tests/TypeInfer.intersectionTypes.test.cpp index 3e813b7fa..012dc7b45 100644 --- a/tests/TypeInfer.intersectionTypes.test.cpp +++ b/tests/TypeInfer.intersectionTypes.test.cpp @@ -539,10 +539,6 @@ TEST_CASE_FIXTURE(Fixture, "intersection_of_tables_with_top_properties") TEST_CASE_FIXTURE(Fixture, "intersection_of_tables_with_never_properties") { - ScopedFastFlag sffs[]{ - {"LuauUninhabitedSubAnything2", true}, - }; - CheckResult result = check(R"( local x : { p : number?, q : never } & { p : never, q : string? } -- OK local y : { p : never, q : never } = x -- OK diff --git a/tests/TypeInfer.modules.test.cpp b/tests/TypeInfer.modules.test.cpp index 2a8db46f2..b75f909a8 100644 --- a/tests/TypeInfer.modules.test.cpp +++ b/tests/TypeInfer.modules.test.cpp @@ -11,7 +11,6 @@ #include "doctest.h" LUAU_FASTFLAG(LuauInstantiateInSubtyping) -LUAU_FASTFLAG(LuauTypeMismatchInvarianceInError) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) using namespace Luau; @@ -410,14 +409,9 @@ local b: B.T = a CheckResult result = frontend.check("game/C"); LUAU_REQUIRE_ERROR_COUNT(1, result); - if (FFlag::LuauTypeMismatchInvarianceInError) - CHECK_EQ(toString(result.errors[0]), R"(Type 'T' from 'game/A' could not be converted into 'T' from 'game/B' + CHECK_EQ(toString(result.errors[0]), R"(Type 'T' from 'game/A' could not be converted into 'T' from 'game/B' caused by: Property 'x' is not compatible. Type 'number' could not be converted into 'string' in an invariant context)"); - else - CHECK_EQ(toString(result.errors[0]), R"(Type 'T' from 'game/A' could not be converted into 'T' from 'game/B' -caused by: - Property 'x' is not compatible. Type 'number' could not be converted into 'string')"); } TEST_CASE_FIXTURE(BuiltinsFixture, "module_type_conflict_instantiated") @@ -449,14 +443,9 @@ local b: B.T = a CheckResult result = frontend.check("game/D"); LUAU_REQUIRE_ERROR_COUNT(1, result); - if (FFlag::LuauTypeMismatchInvarianceInError) - CHECK_EQ(toString(result.errors[0]), R"(Type 'T' from 'game/B' could not be converted into 'T' from 'game/C' + CHECK_EQ(toString(result.errors[0]), R"(Type 'T' from 'game/B' could not be converted into 'T' from 'game/C' caused by: Property 'x' is not compatible. Type 'number' could not be converted into 'string' in an invariant context)"); - else - CHECK_EQ(toString(result.errors[0]), R"(Type 'T' from 'game/B' could not be converted into 'T' from 'game/C' -caused by: - Property 'x' is not compatible. Type 'number' could not be converted into 'string')"); } TEST_CASE_FIXTURE(BuiltinsFixture, "constrained_anyification_clone_immutable_types") diff --git a/tests/TypeInfer.oop.test.cpp b/tests/TypeInfer.oop.test.cpp index ee7472520..08c0f7ca0 100644 --- a/tests/TypeInfer.oop.test.cpp +++ b/tests/TypeInfer.oop.test.cpp @@ -26,17 +26,8 @@ TEST_CASE_FIXTURE(Fixture, "dont_suggest_using_colon_rather_than_dot_if_not_defi someTable.Function1() -- Argument count mismatch )"); - if (FFlag::DebugLuauDeferredConstraintResolution) - { - LUAU_REQUIRE_ERROR_COUNT(2, result); - CHECK(toString(result.errors[0]) == "No overload for function accepts 0 arguments."); - CHECK(toString(result.errors[1]) == "Available overloads: (a) -> ()"); - } - else - { - LUAU_REQUIRE_ERROR_COUNT(1, result); - REQUIRE(get(result.errors[0])); - } + LUAU_REQUIRE_ERROR_COUNT(1, result); + REQUIRE(get(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "dont_suggest_using_colon_rather_than_dot_if_it_wont_help_2") @@ -50,17 +41,8 @@ TEST_CASE_FIXTURE(Fixture, "dont_suggest_using_colon_rather_than_dot_if_it_wont_ someTable.Function2() -- Argument count mismatch )"); - if (FFlag::DebugLuauDeferredConstraintResolution) - { - LUAU_REQUIRE_ERROR_COUNT(2, result); - CHECK(toString(result.errors[0]) == "No overload for function accepts 0 arguments."); - CHECK(toString(result.errors[1]) == "Available overloads: (a, b) -> ()"); - } - else - { - LUAU_REQUIRE_ERROR_COUNT(1, result); - REQUIRE(get(result.errors[0])); - } + LUAU_REQUIRE_ERROR_COUNT(1, result); + REQUIRE(get(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "dont_suggest_using_colon_rather_than_dot_if_another_overload_works") diff --git a/tests/TypeInfer.operators.test.cpp b/tests/TypeInfer.operators.test.cpp index c905e1cca..d605d5bc4 100644 --- a/tests/TypeInfer.operators.test.cpp +++ b/tests/TypeInfer.operators.test.cpp @@ -1238,4 +1238,21 @@ TEST_CASE_FIXTURE(Fixture, "add_type_family_works") CHECK(toString(result.errors[0]) == "Type family instance Add is uninhabited"); } +TEST_CASE_FIXTURE(BuiltinsFixture, "normalize_strings_comparison") +{ + CheckResult result = check(R"( +local function sortKeysForPrinting(a: any, b) + local typeofA = type(a) + local typeofB = type(b) + -- strings and numbers are sorted numerically/alphabetically + if typeofA == typeofB and (typeofA == "number" or typeofA == "string") then + return a < b + end + -- sort the rest by type name + return typeofA < typeofB +end +)"); + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index b5a06a746..a1f456a34 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -9,8 +9,6 @@ using namespace Luau; -LUAU_FASTFLAG(LuauTypeMismatchInvarianceInError) - TEST_SUITE_BEGIN("ProvisionalTests"); // These tests check for behavior that differs from the final behavior we'd @@ -793,20 +791,10 @@ TEST_CASE_FIXTURE(Fixture, "assign_table_with_refined_property_with_a_similar_ty LUAU_REQUIRE_ERROR_COUNT(1, result); - if (FFlag::LuauTypeMismatchInvarianceInError) - { - CHECK_EQ(R"(Type '{| x: number? |}' could not be converted into '{| x: number |}' + CHECK_EQ(R"(Type '{| x: number? |}' could not be converted into '{| x: number |}' caused by: Property 'x' is not compatible. Type 'number?' could not be converted into 'number' in an invariant context)", - toString(result.errors[0])); - } - else - { - CHECK_EQ(R"(Type '{| x: number? |}' could not be converted into '{| x: number |}' -caused by: - Property 'x' is not compatible. Type 'number?' could not be converted into 'number')", - toString(result.errors[0])); - } + toString(result.errors[0])); } TEST_CASE_FIXTURE(BuiltinsFixture, "table_insert_with_a_singleton_argument") @@ -856,10 +844,6 @@ TEST_CASE_FIXTURE(Fixture, "lookup_prop_of_intersection_containing_unions_of_tab TEST_CASE_FIXTURE(Fixture, "expected_type_should_be_a_helpful_deduction_guide_for_function_calls") { - ScopedFastFlag sffs[]{ - {"LuauTypeMismatchInvarianceInError", true}, - }; - CheckResult result = check(R"( type Ref = { val: T } @@ -947,10 +931,6 @@ TEST_CASE_FIXTURE(Fixture, "unify_more_complex_unions_that_include_nil") TEST_CASE_FIXTURE(Fixture, "optional_class_instances_are_invariant") { - ScopedFastFlag sff[] = { - {"LuauTypeMismatchInvarianceInError", true} - }; - createSomeClasses(&frontend); CheckResult result = check(R"( diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index 694b62708..e3d712beb 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -17,7 +17,6 @@ using namespace Luau; LUAU_FASTFLAG(LuauLowerBoundsCalculation); LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); LUAU_FASTFLAG(LuauInstantiateInSubtyping) -LUAU_FASTFLAG(LuauTypeMismatchInvarianceInError) TEST_SUITE_BEGIN("TableTests"); @@ -2077,14 +2076,9 @@ local b: B = a )"); LUAU_REQUIRE_ERRORS(result); - if (FFlag::LuauTypeMismatchInvarianceInError) - CHECK_EQ(toString(result.errors[0]), R"(Type 'A' could not be converted into 'B' + CHECK_EQ(toString(result.errors[0]), R"(Type 'A' could not be converted into 'B' caused by: Property 'y' is not compatible. Type 'number' could not be converted into 'string' in an invariant context)"); - else - CHECK_EQ(toString(result.errors[0]), R"(Type 'A' could not be converted into 'B' -caused by: - Property 'y' is not compatible. Type 'number' could not be converted into 'string')"); } TEST_CASE_FIXTURE(Fixture, "error_detailed_prop_nested") @@ -2101,18 +2095,11 @@ local b: B = a )"); LUAU_REQUIRE_ERRORS(result); - if (FFlag::LuauTypeMismatchInvarianceInError) - CHECK_EQ(toString(result.errors[0]), R"(Type 'A' could not be converted into 'B' + CHECK_EQ(toString(result.errors[0]), R"(Type 'A' could not be converted into 'B' caused by: Property 'b' is not compatible. Type 'AS' could not be converted into 'BS' caused by: Property 'y' is not compatible. Type 'number' could not be converted into 'string' in an invariant context)"); - else - CHECK_EQ(toString(result.errors[0]), R"(Type 'A' could not be converted into 'B' -caused by: - Property 'b' is not compatible. Type 'AS' could not be converted into 'BS' -caused by: - Property 'y' is not compatible. Type 'number' could not be converted into 'string')"); } TEST_CASE_FIXTURE(BuiltinsFixture, "error_detailed_metatable_prop") @@ -2128,18 +2115,11 @@ local c2: typeof(a2) = b2 )"); LUAU_REQUIRE_ERROR_COUNT(2, result); - if (FFlag::LuauTypeMismatchInvarianceInError) - CHECK_EQ(toString(result.errors[0]), R"(Type 'b1' could not be converted into 'a1' + CHECK_EQ(toString(result.errors[0]), R"(Type 'b1' could not be converted into 'a1' caused by: Type '{ x: number, y: string }' could not be converted into '{ x: number, y: number }' caused by: Property 'y' is not compatible. Type 'string' could not be converted into 'number' in an invariant context)"); - else - CHECK_EQ(toString(result.errors[0]), R"(Type 'b1' could not be converted into 'a1' -caused by: - Type '{ x: number, y: string }' could not be converted into '{ x: number, y: number }' -caused by: - Property 'y' is not compatible. Type 'string' could not be converted into 'number')"); if (FFlag::LuauInstantiateInSubtyping) { @@ -2170,14 +2150,9 @@ TEST_CASE_FIXTURE(Fixture, "error_detailed_indexer_key") )"); LUAU_REQUIRE_ERRORS(result); - if (FFlag::LuauTypeMismatchInvarianceInError) - CHECK_EQ(toString(result.errors[0]), R"(Type 'A' could not be converted into 'B' + CHECK_EQ(toString(result.errors[0]), R"(Type 'A' could not be converted into 'B' caused by: Property '[indexer key]' is not compatible. Type 'number' could not be converted into 'string' in an invariant context)"); - else - CHECK_EQ(toString(result.errors[0]), R"(Type 'A' could not be converted into 'B' -caused by: - Property '[indexer key]' is not compatible. Type 'number' could not be converted into 'string')"); } TEST_CASE_FIXTURE(Fixture, "error_detailed_indexer_value") @@ -2191,14 +2166,9 @@ TEST_CASE_FIXTURE(Fixture, "error_detailed_indexer_value") )"); LUAU_REQUIRE_ERRORS(result); - if (FFlag::LuauTypeMismatchInvarianceInError) - CHECK_EQ(toString(result.errors[0]), R"(Type 'A' could not be converted into 'B' + CHECK_EQ(toString(result.errors[0]), R"(Type 'A' could not be converted into 'B' caused by: Property '[indexer value]' is not compatible. Type 'number' could not be converted into 'string' in an invariant context)"); - else - CHECK_EQ(toString(result.errors[0]), R"(Type 'A' could not be converted into 'B' -caused by: - Property '[indexer value]' is not compatible. Type 'number' could not be converted into 'string')"); } TEST_CASE_FIXTURE(Fixture, "explicitly_typed_table") @@ -2871,10 +2841,20 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "table_call_metamethod_must_be_callable") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK(result.errors[0] == TypeError{ - Location{{5, 20}, {5, 21}}, - CannotCallNonFunction{builtinTypes->numberType}, - }); + + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK("Cannot call non-function { @metatable { __call: number }, { } }" == toString(result.errors[0])); + } + else + { + TypeError e{ + Location{{5, 20}, {5, 21}}, + CannotCallNonFunction{builtinTypes->numberType}, + }; + + CHECK(result.errors[0] == e); + } } TEST_CASE_FIXTURE(BuiltinsFixture, "table_call_metamethod_generic") diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index efe7fed38..7ecde7feb 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -1291,4 +1291,45 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "convoluted_case_where_two_TypeVars_were_boun // If this code does not crash, we are in good shape. } +/* + * Under DCR we had an issue where constraint resolution resulted in the + * following: + * + * *blocked-55* ~ hasProp {- name: *blocked-55* -}, "name" + * + * This is a perfectly reasonable constraint, but one that doesn't actually + * constrain anything. When we encounter a constraint like this, we need to + * replace the result type by a free type that is scoped to the enclosing table. + * + * Conceptually, it's simplest to think of this constraint as one that is + * tautological. It does not actually contribute any new information. + */ +TEST_CASE_FIXTURE(Fixture, "handle_self_referential_HasProp_constraints") +{ + CheckResult result = check(R"( + local function calculateTopBarHeight(props) + end + local function isTopPage(props) + local topMostOpaquePage + if props.avatarRoute then + topMostOpaquePage = props.avatarRoute.opaque.name + else + topMostOpaquePage = props.opaquePage + end + end + + function TopBarContainer:updateTopBarHeight(prevProps, prevState) + calculateTopBarHeight(self.props) + isTopPage(self.props) + local topMostOpaquePage + if self.props.avatarRoute then + topMostOpaquePage = self.props.avatarRoute.opaque.name + -- ^--------------------------------^ + else + topMostOpaquePage = self.props.opaquePage + end + end + )"); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.tryUnify.test.cpp b/tests/TypeInfer.tryUnify.test.cpp index 7475d04bd..e00d5ae42 100644 --- a/tests/TypeInfer.tryUnify.test.cpp +++ b/tests/TypeInfer.tryUnify.test.cpp @@ -161,10 +161,6 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "uninhabited_intersection_sub_anything") TEST_CASE_FIXTURE(TryUnifyFixture, "uninhabited_table_sub_never") { - ScopedFastFlag sffs[]{ - {"LuauUninhabitedSubAnything2", true}, - }; - CheckResult result = check(R"( function f(arg : { prop : string & number }) : never return arg @@ -175,10 +171,6 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "uninhabited_table_sub_never") TEST_CASE_FIXTURE(TryUnifyFixture, "uninhabited_table_sub_anything") { - ScopedFastFlag sffs[]{ - {"LuauUninhabitedSubAnything2", true}, - }; - CheckResult result = check(R"( function f(arg : { prop : string & number }) : boolean return arg diff --git a/tools/faillist.txt b/tools/faillist.txt index f049a0ee9..1233837fe 100644 --- a/tools/faillist.txt +++ b/tools/faillist.txt @@ -5,8 +5,6 @@ BuiltinTests.assert_removes_falsy_types2 BuiltinTests.assert_removes_falsy_types_even_from_type_pack_tail_but_only_for_the_first_type BuiltinTests.assert_returns_false_and_string_iff_it_knows_the_first_argument_cannot_be_truthy BuiltinTests.bad_select_should_not_crash -BuiltinTests.gmatch_definition -BuiltinTests.math_max_checks_for_numbers BuiltinTests.select_slightly_out_of_range BuiltinTests.select_way_out_of_range BuiltinTests.set_metatable_needs_arguments @@ -16,6 +14,10 @@ BuiltinTests.string_format_correctly_ordered_types BuiltinTests.string_format_report_all_type_errors_at_correct_positions BuiltinTests.string_format_tostring_specifier_type_constraint BuiltinTests.string_format_use_correct_argument2 +BuiltinTests.table_pack +BuiltinTests.table_pack_reduce +BuiltinTests.table_pack_variadic +DefinitionTests.class_definition_indexer DefinitionTests.class_definition_overload_metamethods DefinitionTests.class_definition_string_props GenericsTests.better_mismatch_error_messages @@ -71,6 +73,7 @@ TableTests.expected_indexer_value_type_extra_2 TableTests.explicitly_typed_table TableTests.explicitly_typed_table_with_indexer TableTests.fuzz_table_unify_instantiated_table +TableTests.fuzz_table_unify_instantiated_table_with_prop_realloc TableTests.generic_table_instantiation_potential_regression TableTests.give_up_after_one_metatable_index_look_up TableTests.indexer_on_sealed_table_must_unify_with_free_table @@ -93,6 +96,7 @@ TableTests.shared_selfs TableTests.shared_selfs_from_free_param TableTests.shared_selfs_through_metatables TableTests.table_call_metamethod_basic +TableTests.table_call_metamethod_generic TableTests.table_simple_call TableTests.table_subtyping_with_missing_props_dont_report_multiple_errors TableTests.used_colon_instead_of_dot @@ -127,6 +131,7 @@ TypeInfer.tc_after_error_recovery_no_replacement_name_in_error TypeInfer.type_infer_recursion_limit_no_ice TypeInfer.type_infer_recursion_limit_normalizer TypeInferAnyError.for_in_loop_iterator_is_any2 +TypeInferClasses.callable_classes TypeInferClasses.class_type_mismatch_with_name_conflict TypeInferClasses.index_instance_property TypeInferFunctions.cannot_hoist_interior_defns_into_signature @@ -161,8 +166,6 @@ TypeInferModules.module_type_conflict TypeInferModules.module_type_conflict_instantiated TypeInferOOP.inferring_hundreds_of_self_calls_should_not_suffocate_memory TypeInferOOP.methods_are_topologically_sorted -TypeInferOperators.CallAndOrOfFunctions -TypeInferOperators.CallOrOfFunctions TypeInferOperators.cli_38355_recursive_union TypeInferOperators.compound_assign_mismatch_metatable TypeInferOperators.disallow_string_and_types_without_metatables_from_arithmetic_binary_ops @@ -179,8 +182,6 @@ TypePackTests.detect_cyclic_typepacks2 TypePackTests.pack_tail_unification_check TypePackTests.type_alias_backwards_compatible TypePackTests.type_alias_default_type_errors -TypePackTests.unify_variadic_tails_in_arguments -TypePackTests.variadic_packs TypeSingletons.function_call_with_singletons TypeSingletons.function_call_with_singletons_mismatch TypeSingletons.no_widening_from_callsites @@ -192,5 +193,4 @@ TypeSingletons.widening_happens_almost_everywhere UnionTypes.dont_allow_cyclic_unions_to_be_inferred UnionTypes.generic_function_with_optional_arg UnionTypes.index_on_a_union_type_with_missing_property -UnionTypes.optional_union_follow UnionTypes.table_union_write_indirect From e00dbbeaf2a18e027fa3d59d5e5d063fa02db33d Mon Sep 17 00:00:00 2001 From: Andy Friesen Date: Fri, 7 Jul 2023 10:14:35 -0700 Subject: [PATCH 62/66] Sync to upstream/release/583 --- Analysis/include/Luau/Constraint.h | 10 +- Analysis/include/Luau/Differ.h | 138 +++++++++++ Analysis/include/Luau/Frontend.h | 4 + Analysis/include/Luau/Normalize.h | 9 +- Analysis/include/Luau/Type.h | 2 - Analysis/include/Luau/VisitType.h | 9 +- Analysis/src/Clone.cpp | 20 +- Analysis/src/ConstraintGraphBuilder.cpp | 44 +++- Analysis/src/ConstraintSolver.cpp | 26 +- Analysis/src/Differ.cpp | 273 ++++++++++++++++++++ Analysis/src/Frontend.cpp | 12 + Analysis/src/Normalize.cpp | 56 +++-- Analysis/src/Substitution.cpp | 30 +-- Analysis/src/ToDot.cpp | 9 +- Analysis/src/TypeChecker2.cpp | 5 + Analysis/src/TypeInfer.cpp | 100 +++----- Ast/include/Luau/Location.h | 6 + Ast/include/Luau/TimeTrace.h | 13 +- Ast/src/Location.cpp | 9 + Ast/src/TimeTrace.cpp | 11 +- CodeGen/include/Luau/IrData.h | 15 +- CodeGen/include/Luau/IrRegAllocX64.h | 2 + CodeGen/include/Luau/IrUtils.h | 4 + CodeGen/src/CodeGenA64.cpp | 29 +++ CodeGen/src/CodeGenLower.h | 2 + CodeGen/src/CodeGenX64.cpp | 10 + CodeGen/src/EmitCommon.h | 2 + CodeGen/src/EmitCommonX64.cpp | 17 ++ CodeGen/src/EmitCommonX64.h | 2 + CodeGen/src/EmitInstructionX64.cpp | 5 + CodeGen/src/IrBuilder.cpp | 79 ++++++ CodeGen/src/IrCallWrapperX64.cpp | 8 +- CodeGen/src/IrDump.cpp | 2 + CodeGen/src/IrLoweringA64.cpp | 66 +++-- CodeGen/src/IrLoweringA64.h | 7 + CodeGen/src/IrLoweringX64.cpp | 51 +++- CodeGen/src/IrLoweringX64.h | 9 +- CodeGen/src/IrRegAllocA64.cpp | 21 +- CodeGen/src/IrRegAllocA64.h | 2 + CodeGen/src/IrRegAllocX64.cpp | 8 + CodeGen/src/IrTranslateBuiltins.cpp | 19 ++ CodeGen/src/IrTranslation.cpp | 12 +- CodeGen/src/IrUtils.cpp | 29 ++- CodeGen/src/OptimizeConstProp.cpp | 13 +- Compiler/src/BytecodeBuilder.cpp | 4 +- Compiler/src/Compiler.cpp | 19 +- Compiler/src/Types.cpp | 156 ++++++++++-- Compiler/src/Types.h | 6 +- Makefile | 6 +- Sources.cmake | 3 + VM/src/lfunc.cpp | 4 + VM/src/lobject.h | 2 + VM/src/lvmload.cpp | 17 +- fuzz/proto.cpp | 21 ++ tests/ClassFixture.cpp | 1 - tests/Compiler.test.cpp | 105 +++++++- tests/Differ.test.cpp | 316 ++++++++++++++++++++++++ tests/Frontend.test.cpp | 31 +++ tests/IrBuilder.test.cpp | 56 +++-- tests/IrCallWrapperX64.test.cpp | 44 +++- tests/Normalize.test.cpp | 31 +++ tests/TypeInfer.classes.test.cpp | 1 - tests/TypeInfer.definitions.test.cpp | 1 - tests/TypeInfer.functions.test.cpp | 42 ++++ tests/TypeInfer.refinements.test.cpp | 71 +++++- tests/TypeInfer.singletons.test.cpp | 14 ++ tests/conformance/native.lua | 31 +++ tools/faillist.txt | 4 +- 68 files changed, 1893 insertions(+), 293 deletions(-) create mode 100644 Analysis/include/Luau/Differ.h create mode 100644 Analysis/src/Differ.cpp create mode 100644 tests/Differ.test.cpp diff --git a/Analysis/include/Luau/Constraint.h b/Analysis/include/Luau/Constraint.h index aa1d1c0ec..67f9470ee 100644 --- a/Analysis/include/Luau/Constraint.h +++ b/Analysis/include/Luau/Constraint.h @@ -34,6 +34,11 @@ struct PackSubtypeConstraint { TypePackId subPack; TypePackId superPack; + + // HACK!! TODO clip. + // We need to know which of `PackSubtypeConstraint` are emitted from `AstStatReturn` vs any others. + // Then we force these specific `PackSubtypeConstraint` to only dispatch in the order of the `return`s. + bool returns = false; }; // generalizedType ~ gen sourceType @@ -108,13 +113,12 @@ struct FunctionCallConstraint TypeId fn; TypePackId argsPack; TypePackId result; - class AstExprCall* callSite; + class AstExprCall* callSite = nullptr; std::vector> discriminantTypes; // When we dispatch this constraint, we update the key at this map to record // the overload that we selected. - DenseHashMap* astOriginalCallTypes; - DenseHashMap* astOverloadResolvedTypes; + DenseHashMap* astOverloadResolvedTypes = nullptr; }; // result ~ prim ExpectedType SomeSingletonType MultitonType diff --git a/Analysis/include/Luau/Differ.h b/Analysis/include/Luau/Differ.h new file mode 100644 index 000000000..ad276b4f9 --- /dev/null +++ b/Analysis/include/Luau/Differ.h @@ -0,0 +1,138 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Type.h" +#include +#include + +namespace Luau +{ +struct DiffPathNode +{ + // TODO: consider using Variants to simplify toString implementation + enum Kind + { + TableProperty, + FunctionArgument, + FunctionReturn, + Union, + Intersection, + }; + Kind kind; + // non-null when TableProperty + std::optional tableProperty; + // non-null when FunctionArgument, FunctionReturn, Union, or Intersection (i.e. anonymous fields) + std::optional index; + + /** + * Do not use for leaf nodes + */ + DiffPathNode(Kind kind) + : kind(kind) + { + } + + DiffPathNode(Kind kind, std::optional tableProperty, std::optional index) + : kind(kind) + , tableProperty(tableProperty) + , index(index) + { + } + + std::string toString() const; + + static DiffPathNode constructWithTableProperty(Name tableProperty); +}; +struct DiffPathNodeLeaf +{ + std::optional ty; + std::optional tableProperty; + DiffPathNodeLeaf(std::optional ty, std::optional tableProperty) + : ty(ty) + , tableProperty(tableProperty) + { + } + + static DiffPathNodeLeaf nullopts(); +}; +struct DiffPath +{ + std::vector path; + + std::string toString(bool prependDot) const; +}; +struct DiffError +{ + enum Kind + { + Normal, + MissingProperty, + LengthMismatchInFnArgs, + LengthMismatchInFnRets, + LengthMismatchInUnion, + LengthMismatchInIntersection, + }; + Kind kind; + + DiffPath diffPath; + DiffPathNodeLeaf left; + DiffPathNodeLeaf right; + + std::string leftRootName; + std::string rightRootName; + + DiffError(Kind kind, DiffPathNodeLeaf left, DiffPathNodeLeaf right, std::string leftRootName, std::string rightRootName) + : kind(kind) + , left(left) + , right(right) + , leftRootName(leftRootName) + , rightRootName(rightRootName) + { + checkValidInitialization(left, right); + } + DiffError(Kind kind, DiffPath diffPath, DiffPathNodeLeaf left, DiffPathNodeLeaf right, std::string leftRootName, std::string rightRootName) + : kind(kind) + , diffPath(diffPath) + , left(left) + , right(right) + , leftRootName(leftRootName) + , rightRootName(rightRootName) + { + checkValidInitialization(left, right); + } + + std::string toString() const; + +private: + std::string toStringALeaf(std::string rootName, const DiffPathNodeLeaf& leaf, const DiffPathNodeLeaf& otherLeaf) const; + void checkValidInitialization(const DiffPathNodeLeaf& left, const DiffPathNodeLeaf& right); + void checkNonMissingPropertyLeavesHaveNulloptTableProperty() const; +}; + +struct DifferResult +{ + std::optional diffError; + + DifferResult() {} + DifferResult(DiffError diffError) + : diffError(diffError) + { + } + + void wrapDiffPath(DiffPathNode node); +}; +struct DifferEnvironment +{ + TypeId rootLeft; + TypeId rootRight; +}; +DifferResult diff(TypeId ty1, TypeId ty2); + +/** + * True if ty is a "simple" type, i.e. cannot contain types. + * string, number, boolean are simple types. + * function and table are not simple types. + */ +bool isSimple(TypeId ty); + +} // namespace Luau diff --git a/Analysis/include/Luau/Frontend.h b/Analysis/include/Luau/Frontend.h index 1306ad2c7..7b1eb2076 100644 --- a/Analysis/include/Luau/Frontend.h +++ b/Analysis/include/Luau/Frontend.h @@ -144,6 +144,10 @@ struct Frontend Frontend(FileResolver* fileResolver, ConfigResolver* configResolver, const FrontendOptions& options = {}); + // Parse module graph and prepare SourceNode/SourceModule data, including required dependencies without running typechecking + void parse(const ModuleName& name); + + // Parse and typecheck module graph CheckResult check(const ModuleName& name, std::optional optionOverride = {}); // new shininess bool isDirty(const ModuleName& name, bool forAutocomplete = false) const; diff --git a/Analysis/include/Luau/Normalize.h b/Analysis/include/Luau/Normalize.h index 1a252a88e..75c07a7be 100644 --- a/Analysis/include/Luau/Normalize.h +++ b/Analysis/include/Luau/Normalize.h @@ -274,6 +274,9 @@ struct NormalizedType /// Returns true if the type is a subtype of string(it could be a singleton). Behaves like Type::isString() bool isSubtypeOfString() const; + /// Returns true if this type should result in error suppressing behavior. + bool shouldSuppressErrors() const; + // Helpers that improve readability of the above (they just say if the component is present) bool hasTops() const; bool hasBooleans() const; @@ -343,7 +346,7 @@ class Normalizer void unionTablesWithTable(TypeIds& heres, TypeId there); void unionTables(TypeIds& heres, const TypeIds& theres); bool unionNormals(NormalizedType& here, const NormalizedType& there, int ignoreSmallerTyvars = -1); - bool unionNormalWithTy(NormalizedType& here, TypeId there, int ignoreSmallerTyvars = -1); + bool unionNormalWithTy(NormalizedType& here, TypeId there, std::unordered_set& seenSetTypes, int ignoreSmallerTyvars = -1); // ------- Negations std::optional negateNormal(const NormalizedType& here); @@ -365,9 +368,9 @@ class Normalizer std::optional intersectionOfFunctions(TypeId here, TypeId there); void intersectFunctionsWithFunction(NormalizedFunctionType& heress, TypeId there); void intersectFunctions(NormalizedFunctionType& heress, const NormalizedFunctionType& theress); - bool intersectTyvarsWithTy(NormalizedTyvars& here, TypeId there); + bool intersectTyvarsWithTy(NormalizedTyvars& here, TypeId there, std::unordered_set& seenSetTypes); bool intersectNormals(NormalizedType& here, const NormalizedType& there, int ignoreSmallerTyvars = -1); - bool intersectNormalWithTy(NormalizedType& here, TypeId there); + bool intersectNormalWithTy(NormalizedType& here, TypeId there, std::unordered_set& seenSetTypes); bool normalizeIntersections(const std::vector& intersections, NormalizedType& outType); // Check for inhabitance diff --git a/Analysis/include/Luau/Type.h b/Analysis/include/Luau/Type.h index 9a35a1d6f..e9420922f 100644 --- a/Analysis/include/Luau/Type.h +++ b/Analysis/include/Luau/Type.h @@ -23,7 +23,6 @@ LUAU_FASTINT(LuauTableTypeMaximumStringifierLength) LUAU_FASTINT(LuauTypeMaximumStringifierLength) -LUAU_FASTFLAG(LuauTypecheckClassTypeIndexers) namespace Luau { @@ -527,7 +526,6 @@ struct ClassType , definitionModuleName(definitionModuleName) , indexer(indexer) { - LUAU_ASSERT(FFlag::LuauTypecheckClassTypeIndexers); } }; diff --git a/Analysis/include/Luau/VisitType.h b/Analysis/include/Luau/VisitType.h index 1464aa1b5..a84fb48cc 100644 --- a/Analysis/include/Luau/VisitType.h +++ b/Analysis/include/Luau/VisitType.h @@ -304,13 +304,10 @@ struct GenericTypeVisitor if (ctv->metatable) traverse(*ctv->metatable); - if (FFlag::LuauTypecheckClassTypeIndexers) + if (ctv->indexer) { - if (ctv->indexer) - { - traverse(ctv->indexer->indexType); - traverse(ctv->indexer->indexResultType); - } + traverse(ctv->indexer->indexType); + traverse(ctv->indexer->indexResultType); } } } diff --git a/Analysis/src/Clone.cpp b/Analysis/src/Clone.cpp index 1eb78540a..bdb510a37 100644 --- a/Analysis/src/Clone.cpp +++ b/Analysis/src/Clone.cpp @@ -55,7 +55,6 @@ Property clone(const Property& prop, TypeArena& dest, CloneState& cloneState) static TableIndexer clone(const TableIndexer& indexer, TypeArena& dest, CloneState& cloneState) { - LUAU_ASSERT(FFlag::LuauTypecheckClassTypeIndexers); return TableIndexer{clone(indexer.indexType, dest, cloneState), clone(indexer.indexResultType, dest, cloneState)}; } @@ -312,16 +311,8 @@ void TypeCloner::operator()(const TableType& t) for (const auto& [name, prop] : t.props) ttv->props[name] = clone(prop, dest, cloneState); - if (FFlag::LuauTypecheckClassTypeIndexers) - { - if (t.indexer) - ttv->indexer = clone(*t.indexer, dest, cloneState); - } - else - { - if (t.indexer) - ttv->indexer = TableIndexer{clone(t.indexer->indexType, dest, cloneState), clone(t.indexer->indexResultType, dest, cloneState)}; - } + if (t.indexer) + ttv->indexer = clone(*t.indexer, dest, cloneState); for (TypeId& arg : ttv->instantiatedTypeParams) arg = clone(arg, dest, cloneState); @@ -360,11 +351,8 @@ void TypeCloner::operator()(const ClassType& t) if (t.metatable) ctv->metatable = clone(*t.metatable, dest, cloneState); - if (FFlag::LuauTypecheckClassTypeIndexers) - { - if (t.indexer) - ctv->indexer = clone(*t.indexer, dest, cloneState); - } + if (t.indexer) + ctv->indexer = clone(*t.indexer, dest, cloneState); } void TypeCloner::operator()(const AnyType& t) diff --git a/Analysis/src/ConstraintGraphBuilder.cpp b/Analysis/src/ConstraintGraphBuilder.cpp index 429f1a4db..c62c214c7 100644 --- a/Analysis/src/ConstraintGraphBuilder.cpp +++ b/Analysis/src/ConstraintGraphBuilder.cpp @@ -776,9 +776,10 @@ ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatForIn* f ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatWhile* while_) { - check(scope, while_->condition); + RefinementId refinement = check(scope, while_->condition).refinement; ScopePtr whileScope = childScope(while_, scope); + applyRefinements(whileScope, while_->condition->location, refinement); visit(whileScope, while_->body); @@ -825,8 +826,17 @@ ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocalFun std::unique_ptr c = std::make_unique(constraintScope, function->name->location, GeneralizationConstraint{functionType, sig.signature}); - forEachConstraint(start, end, this, [&c](const ConstraintPtr& constraint) { + Constraint* previous = nullptr; + forEachConstraint(start, end, this, [&c, &previous](const ConstraintPtr& constraint) { c->dependencies.push_back(NotNull{constraint.get()}); + + if (auto psc = get(*constraint); psc && psc->returns) + { + if (previous) + constraint->dependencies.push_back(NotNull{previous}); + + previous = constraint.get(); + } }); addConstraint(scope, std::move(c)); @@ -915,9 +925,18 @@ ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFunction std::unique_ptr c = std::make_unique(constraintScope, function->name->location, GeneralizationConstraint{generalizedType, sig.signature}); - forEachConstraint(start, end, this, [&c, &excludeList](const ConstraintPtr& constraint) { + Constraint* previous = nullptr; + forEachConstraint(start, end, this, [&c, &excludeList, &previous](const ConstraintPtr& constraint) { if (!excludeList.count(constraint.get())) c->dependencies.push_back(NotNull{constraint.get()}); + + if (auto psc = get(*constraint); psc && psc->returns) + { + if (previous) + constraint->dependencies.push_back(NotNull{previous}); + + previous = constraint.get(); + } }); addConstraint(scope, std::move(c)); @@ -936,7 +955,7 @@ ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatReturn* expectedTypes.push_back(ty); TypePackId exprTypes = checkPack(scope, ret->list, expectedTypes).tp; - addConstraint(scope, ret->location, PackSubtypeConstraint{exprTypes, scope->returnType}); + addConstraint(scope, ret->location, PackSubtypeConstraint{exprTypes, scope->returnType, /*returns*/ true}); return ControlFlow::Returns; } @@ -1408,6 +1427,7 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCa std::vector> expectedTypesForCall = getExpectedCallTypesForFunctionOverloads(fnType); module->astOriginalCallTypes[call->func] = fnType; + module->astOriginalCallTypes[call] = fnType; TypePackId expectedArgPack = arena->freshTypePack(scope.get()); TypePackId expectedRetPack = arena->freshTypePack(scope.get()); @@ -1547,7 +1567,6 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCa rets, call, std::move(discriminantTypes), - &module->astOriginalCallTypes, &module->astOverloadResolvedTypes, }); @@ -1642,7 +1661,7 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprConstantSt if (expectedType) { const TypeId expectedTy = follow(*expectedType); - if (get(expectedTy) || get(expectedTy)) + if (get(expectedTy) || get(expectedTy) || get(expectedTy)) { TypeId ty = arena->addType(BlockedType{}); TypeId singletonType = arena->addType(SingletonType(StringSingleton{std::string(string->value.data, string->value.size)})); @@ -1774,8 +1793,17 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprFunction* TypeId generalizedTy = arena->addType(BlockedType{}); NotNull gc = addConstraint(sig.signatureScope, func->location, GeneralizationConstraint{generalizedTy, sig.signature}); - forEachConstraint(startCheckpoint, endCheckpoint, this, [gc](const ConstraintPtr& constraint) { + Constraint* previous = nullptr; + forEachConstraint(startCheckpoint, endCheckpoint, this, [gc, &previous](const ConstraintPtr& constraint) { gc->dependencies.emplace_back(constraint.get()); + + if (auto psc = get(*constraint); psc && psc->returns) + { + if (previous) + constraint->dependencies.push_back(NotNull{previous}); + + previous = constraint.get(); + } }); return Inference{generalizedTy}; @@ -2412,7 +2440,7 @@ void ConstraintGraphBuilder::checkFunctionBody(const ScopePtr& scope, AstExprFun if (nullptr != getFallthrough(fn->body)) { - TypePackId empty = arena->addTypePack({}); // TODO we could have CSG retain one of these forever + TypePackId empty = arena->addTypePack({}); // TODO we could have CGB retain one of these forever addConstraint(scope, fn->location, PackSubtypeConstraint{scope->returnType, empty}); } } diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index 81e6574ad..c9b584fd6 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -1259,18 +1259,13 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNull(fn)) fn = collapse(it).value_or(fn); - if (c.callSite) - (*c.astOriginalCallTypes)[c.callSite] = fn; - // We don't support magic __call metamethods. if (std::optional callMm = findMetatableEntry(builtinTypes, errors, fn, "__call", constraint->location)) { - std::vector args{fn}; - - for (TypeId arg : c.argsPack) - args.push_back(arg); + auto [head, tail] = flatten(c.argsPack); + head.insert(head.begin(), fn); - argsPack = arena->addTypePack(TypePack{args, {}}); + argsPack = arena->addTypePack(TypePack{std::move(head), tail}); fn = *callMm; asMutable(c.result)->ty.emplace(constraint->scope); } @@ -1890,7 +1885,20 @@ bool ConstraintSolver::tryDispatch(const RefineConstraint& c, NotNullty.emplace(result); + const NormalizedType* normType = normalizer->normalize(c.type); + + if (!normType) + reportError(NormalizationTooComplex{}, constraint->location); + + if (normType && normType->shouldSuppressErrors()) + { + auto resultOrError = simplifyUnion(builtinTypes, arena, result, builtinTypes->errorType).result; + asMutable(c.resultType)->ty.emplace(resultOrError); + } + else + { + asMutable(c.resultType)->ty.emplace(result); + } unblock(c.resultType, constraint->location); diff --git a/Analysis/src/Differ.cpp b/Analysis/src/Differ.cpp new file mode 100644 index 000000000..91b6f4b61 --- /dev/null +++ b/Analysis/src/Differ.cpp @@ -0,0 +1,273 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Differ.h" +#include "Luau/Error.h" +#include "Luau/ToString.h" +#include "Luau/Type.h" +#include + +namespace Luau +{ +std::string DiffPathNode::toString() const +{ + switch (kind) + { + case DiffPathNode::Kind::TableProperty: + { + if (!tableProperty.has_value()) + throw InternalCompilerError{"DiffPathNode has kind TableProperty but tableProperty is nullopt"}; + return *tableProperty; + break; + } + default: + { + throw InternalCompilerError{"DiffPathNode::toString is not exhaustive"}; + } + } +} + +DiffPathNode DiffPathNode::constructWithTableProperty(Name tableProperty) +{ + return DiffPathNode{DiffPathNode::Kind::TableProperty, tableProperty, std::nullopt}; +} + +DiffPathNodeLeaf DiffPathNodeLeaf::nullopts() +{ + return DiffPathNodeLeaf{std::nullopt, std::nullopt}; +} + +std::string DiffPath::toString(bool prependDot) const +{ + std::string pathStr; + bool isFirstInForLoop = !prependDot; + for (auto node = path.rbegin(); node != path.rend(); node++) + { + if (isFirstInForLoop) + { + isFirstInForLoop = false; + } + else + { + pathStr += "."; + } + pathStr += node->toString(); + } + return pathStr; +} +std::string DiffError::toStringALeaf(std::string rootName, const DiffPathNodeLeaf& leaf, const DiffPathNodeLeaf& otherLeaf) const +{ + std::string pathStr{rootName + diffPath.toString(true)}; + switch (kind) + { + case DiffError::Kind::Normal: + { + checkNonMissingPropertyLeavesHaveNulloptTableProperty(); + return pathStr + " has type " + Luau::toString(*leaf.ty); + } + case DiffError::Kind::MissingProperty: + { + if (leaf.ty.has_value()) + { + if (!leaf.tableProperty.has_value()) + throw InternalCompilerError{"leaf.tableProperty is nullopt"}; + return pathStr + "." + *leaf.tableProperty + " has type " + Luau::toString(*leaf.ty); + } + else if (otherLeaf.ty.has_value()) + { + if (!otherLeaf.tableProperty.has_value()) + throw InternalCompilerError{"otherLeaf.tableProperty is nullopt"}; + return pathStr + " is missing the property " + *otherLeaf.tableProperty; + } + throw InternalCompilerError{"Both leaf.ty and otherLeaf.ty is nullopt"}; + } + default: + { + throw InternalCompilerError{"DiffPath::toStringWithLeaf is not exhaustive"}; + } + } +} + +void DiffError::checkNonMissingPropertyLeavesHaveNulloptTableProperty() const +{ + if (left.tableProperty.has_value() || right.tableProperty.has_value()) + throw InternalCompilerError{"Non-MissingProperty DiffError should have nullopt tableProperty in both leaves"}; +} + +std::string getDevFixFriendlyName(TypeId ty) +{ + if (auto table = get(ty)) + { + if (table->name.has_value()) + return *table->name; + else if (table->syntheticName.has_value()) + return *table->syntheticName; + } + // else if (auto primitive = get(ty)) + //{ + // return ""; + //} + return ""; +} + +std::string DiffError::toString() const +{ + std::string msg = "DiffError: these two types are not equal because the left type at " + toStringALeaf(leftRootName, left, right) + + ", while the right type at " + toStringALeaf(rightRootName, right, left); + return msg; +} + +void DiffError::checkValidInitialization(const DiffPathNodeLeaf& left, const DiffPathNodeLeaf& right) +{ + if (!left.ty.has_value() || !right.ty.has_value()) + { + // TODO: think about whether this should be always thrown! + // For example, Kind::Primitive doesn't make too much sense to have a TypeId + // throw InternalCompilerError{"Left and Right fields are leaf nodes and must have a TypeId"}; + } +} + +void DifferResult::wrapDiffPath(DiffPathNode node) +{ + if (!diffError.has_value()) + { + throw InternalCompilerError{"Cannot wrap diffPath because there is no diffError"}; + } + + diffError->diffPath.path.push_back(node); +} + +static DifferResult diffUsingEnv(DifferEnvironment& env, TypeId left, TypeId right); +static DifferResult diffTable(DifferEnvironment& env, TypeId left, TypeId right); +static DifferResult diffPrimitive(DifferEnvironment& env, TypeId left, TypeId right); +static DifferResult diffSingleton(DifferEnvironment& env, TypeId left, TypeId right); + +static DifferResult diffTable(DifferEnvironment& env, TypeId left, TypeId right) +{ + const TableType* leftTable = get(left); + const TableType* rightTable = get(right); + + for (auto const& [field, value] : leftTable->props) + { + if (rightTable->props.find(field) == rightTable->props.end()) + { + // left has a field the right doesn't + return DifferResult{DiffError{ + DiffError::Kind::MissingProperty, + DiffPathNodeLeaf{value.type(), field}, + DiffPathNodeLeaf::nullopts(), + getDevFixFriendlyName(env.rootLeft), + getDevFixFriendlyName(env.rootRight), + }}; + } + } + for (auto const& [field, value] : rightTable->props) + { + if (leftTable->props.find(field) == leftTable->props.end()) + { + // right has a field the left doesn't + return DifferResult{DiffError{DiffError::Kind::MissingProperty, DiffPathNodeLeaf::nullopts(), DiffPathNodeLeaf{value.type(), field}, + getDevFixFriendlyName(env.rootLeft), getDevFixFriendlyName(env.rootRight)}}; + } + } + // left and right have the same set of keys + for (auto const& [field, leftValue] : leftTable->props) + { + auto const& rightValue = rightTable->props.at(field); + DifferResult differResult = diffUsingEnv(env, leftValue.type(), rightValue.type()); + if (differResult.diffError.has_value()) + { + differResult.wrapDiffPath(DiffPathNode::constructWithTableProperty(field)); + return differResult; + } + } + return DifferResult{}; +} + +static DifferResult diffPrimitive(DifferEnvironment& env, TypeId left, TypeId right) +{ + const PrimitiveType* leftPrimitive = get(left); + const PrimitiveType* rightPrimitive = get(right); + + if (leftPrimitive->type != rightPrimitive->type) + { + return DifferResult{DiffError{ + DiffError::Kind::Normal, + DiffPathNodeLeaf{left, std::nullopt}, + DiffPathNodeLeaf{right, std::nullopt}, + getDevFixFriendlyName(env.rootLeft), + getDevFixFriendlyName(env.rootRight), + }}; + } + return DifferResult{}; +} + +static DifferResult diffSingleton(DifferEnvironment& env, TypeId left, TypeId right) +{ + const SingletonType* leftSingleton = get(left); + const SingletonType* rightSingleton = get(right); + + if (*leftSingleton != *rightSingleton) + { + return DifferResult{DiffError{ + DiffError::Kind::Normal, + DiffPathNodeLeaf{left, std::nullopt}, + DiffPathNodeLeaf{right, std::nullopt}, + getDevFixFriendlyName(env.rootLeft), + getDevFixFriendlyName(env.rootRight), + }}; + } + return DifferResult{}; +} + +static DifferResult diffUsingEnv(DifferEnvironment& env, TypeId left, TypeId right) +{ + left = follow(left); + right = follow(right); + + if (left->ty.index() != right->ty.index()) + { + return DifferResult{DiffError{ + DiffError::Kind::Normal, + DiffPathNodeLeaf{left, std::nullopt}, + DiffPathNodeLeaf{right, std::nullopt}, + getDevFixFriendlyName(env.rootLeft), + getDevFixFriendlyName(env.rootRight), + }}; + } + + // Both left and right are the same variant + + if (isSimple(left)) + { + if (auto lp = get(left)) + return diffPrimitive(env, left, right); + else if (auto ls = get(left)) + { + return diffSingleton(env, left, right); + } + + throw InternalCompilerError{"Unimplemented Simple TypeId variant for diffing"}; + } + + // Both left and right are the same non-Simple + + if (auto lt = get(left)) + { + return diffTable(env, left, right); + } + throw InternalCompilerError{"Unimplemented non-simple TypeId variant for diffing"}; +} + +DifferResult diff(TypeId ty1, TypeId ty2) +{ + DifferEnvironment differEnv{ty1, ty2}; + return diffUsingEnv(differEnv, ty1, ty2); +} + +bool isSimple(TypeId ty) +{ + ty = follow(ty); + // TODO: think about GenericType, etc. + return get(ty) || get(ty); +} + +} // namespace Luau diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index f88425b55..9f1fd7267 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -415,6 +415,18 @@ Frontend::Frontend(FileResolver* fileResolver, ConfigResolver* configResolver, c { } +void Frontend::parse(const ModuleName& name) +{ + LUAU_TIMETRACE_SCOPE("Frontend::parse", "Frontend"); + LUAU_TIMETRACE_ARGUMENT("name", name.c_str()); + + if (getCheckResult(name, false, false)) + return; + + std::vector buildQueue; + parseGraph(buildQueue, name, false); +} + CheckResult Frontend::check(const ModuleName& name, std::optional optionOverride) { LUAU_TIMETRACE_SCOPE("Frontend::check", "Frontend"); diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp index a7e3bb6e7..33a8b6eb1 100644 --- a/Analysis/src/Normalize.cpp +++ b/Analysis/src/Normalize.cpp @@ -18,6 +18,7 @@ LUAU_FASTFLAGVARIABLE(DebugLuauCheckNormalizeInvariant, false) LUAU_FASTINTVARIABLE(LuauNormalizeIterationLimit, 1200); LUAU_FASTINTVARIABLE(LuauNormalizeCacheLimit, 100000); LUAU_FASTFLAGVARIABLE(LuauNormalizeBlockedTypes, false); +LUAU_FASTFLAGVARIABLE(LuauNormalizeCyclicUnions, false); LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) LUAU_FASTFLAG(LuauTransitiveSubtyping) LUAU_FASTFLAG(DebugLuauReadWriteProperties) @@ -247,6 +248,11 @@ bool NormalizedType::isSubtypeOfString() const !hasTables() && !hasFunctions() && !hasTyvars(); } +bool NormalizedType::shouldSuppressErrors() const +{ + return hasErrors() || get(tops); +} + bool NormalizedType::hasTops() const { return !get(tops); @@ -690,7 +696,8 @@ const NormalizedType* Normalizer::normalize(TypeId ty) return found->second.get(); NormalizedType norm{builtinTypes}; - if (!unionNormalWithTy(norm, ty)) + std::unordered_set seenSetTypes; + if (!unionNormalWithTy(norm, ty, seenSetTypes)) return nullptr; std::unique_ptr uniq = std::make_unique(std::move(norm)); const NormalizedType* result = uniq.get(); @@ -705,9 +712,12 @@ bool Normalizer::normalizeIntersections(const std::vector& intersections NormalizedType norm{builtinTypes}; norm.tops = builtinTypes->anyType; // Now we need to intersect the two types + std::unordered_set seenSetTypes; for (auto ty : intersections) - if (!intersectNormalWithTy(norm, ty)) + { + if (!intersectNormalWithTy(norm, ty, seenSetTypes)) return false; + } if (!unionNormals(outType, norm)) return false; @@ -1438,13 +1448,14 @@ bool Normalizer::withinResourceLimits() } // See above for an explaination of `ignoreSmallerTyvars`. -bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, int ignoreSmallerTyvars) +bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, std::unordered_set& seenSetTypes, int ignoreSmallerTyvars) { RecursionCounter _rc(&sharedState->counters.recursionCount); if (!withinResourceLimits()) return false; there = follow(there); + if (get(there) || get(there)) { TypeId tops = unionOfTops(here.tops, there); @@ -1465,9 +1476,23 @@ bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, int ignor } else if (const UnionType* utv = get(there)) { + if (FFlag::LuauNormalizeCyclicUnions) + { + if (seenSetTypes.count(there)) + return true; + seenSetTypes.insert(there); + } + for (UnionTypeIterator it = begin(utv); it != end(utv); ++it) - if (!unionNormalWithTy(here, *it)) + { + if (!unionNormalWithTy(here, *it, seenSetTypes)) + { + seenSetTypes.erase(there); return false; + } + } + + seenSetTypes.erase(there); return true; } else if (const IntersectionType* itv = get(there)) @@ -1475,8 +1500,10 @@ bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, int ignor NormalizedType norm{builtinTypes}; norm.tops = builtinTypes->anyType; for (IntersectionTypeIterator it = begin(itv); it != end(itv); ++it) - if (!intersectNormalWithTy(norm, *it)) + { + if (!intersectNormalWithTy(norm, *it, seenSetTypes)) return false; + } return unionNormals(here, norm); } else if (FFlag::LuauTransitiveSubtyping && get(here.tops)) @@ -1560,7 +1587,7 @@ bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, int ignor LUAU_ASSERT(!"Unreachable"); for (auto& [tyvar, intersect] : here.tyvars) - if (!unionNormalWithTy(*intersect, there, tyvarIndex(tyvar))) + if (!unionNormalWithTy(*intersect, there, seenSetTypes, tyvarIndex(tyvar))) return false; assertInvariant(here); @@ -2463,12 +2490,12 @@ void Normalizer::intersectFunctions(NormalizedFunctionType& heres, const Normali } } -bool Normalizer::intersectTyvarsWithTy(NormalizedTyvars& here, TypeId there) +bool Normalizer::intersectTyvarsWithTy(NormalizedTyvars& here, TypeId there, std::unordered_set& seenSetTypes) { for (auto it = here.begin(); it != here.end();) { NormalizedType& inter = *it->second; - if (!intersectNormalWithTy(inter, there)) + if (!intersectNormalWithTy(inter, there, seenSetTypes)) return false; if (isShallowInhabited(inter)) ++it; @@ -2541,13 +2568,14 @@ bool Normalizer::intersectNormals(NormalizedType& here, const NormalizedType& th return true; } -bool Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there) +bool Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there, std::unordered_set& seenSetTypes) { RecursionCounter _rc(&sharedState->counters.recursionCount); if (!withinResourceLimits()) return false; there = follow(there); + if (get(there) || get(there)) { here.tops = intersectionOfTops(here.tops, there); @@ -2556,20 +2584,20 @@ bool Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there) else if (!get(here.tops)) { clearNormal(here); - return unionNormalWithTy(here, there); + return unionNormalWithTy(here, there, seenSetTypes); } else if (const UnionType* utv = get(there)) { NormalizedType norm{builtinTypes}; for (UnionTypeIterator it = begin(utv); it != end(utv); ++it) - if (!unionNormalWithTy(norm, *it)) + if (!unionNormalWithTy(norm, *it, seenSetTypes)) return false; return intersectNormals(here, norm); } else if (const IntersectionType* itv = get(there)) { for (IntersectionTypeIterator it = begin(itv); it != end(itv); ++it) - if (!intersectNormalWithTy(here, *it)) + if (!intersectNormalWithTy(here, *it, seenSetTypes)) return false; return true; } @@ -2691,7 +2719,7 @@ bool Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there) return true; } else if (auto nt = get(t)) - return intersectNormalWithTy(here, nt->ty); + return intersectNormalWithTy(here, nt->ty, seenSetTypes); else { // TODO negated unions, intersections, table, and function. @@ -2706,7 +2734,7 @@ bool Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there) else LUAU_ASSERT(!"Unreachable"); - if (!intersectTyvarsWithTy(tyvars, there)) + if (!intersectTyvarsWithTy(tyvars, there, seenSetTypes)) return false; here.tyvars = std::move(tyvars); diff --git a/Analysis/src/Substitution.cpp b/Analysis/src/Substitution.cpp index 6c1908bf6..9c34cd7cc 100644 --- a/Analysis/src/Substitution.cpp +++ b/Analysis/src/Substitution.cpp @@ -191,16 +191,8 @@ static TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool a { if (alwaysClone) { - if (FFlag::LuauTypecheckClassTypeIndexers) - { - ClassType clone{a.name, a.props, a.parent, a.metatable, a.tags, a.userData, a.definitionModuleName, a.indexer}; - return dest.addType(std::move(clone)); - } - else - { - ClassType clone{a.name, a.props, a.parent, a.metatable, a.tags, a.userData, a.definitionModuleName}; - return dest.addType(std::move(clone)); - } + ClassType clone{a.name, a.props, a.parent, a.metatable, a.tags, a.userData, a.definitionModuleName, a.indexer}; + return dest.addType(std::move(clone)); } else return ty; @@ -316,13 +308,10 @@ void Tarjan::visitChildren(TypeId ty, int index) if (ctv->metatable) visitChild(*ctv->metatable); - if (FFlag::LuauTypecheckClassTypeIndexers) + if (ctv->indexer) { - if (ctv->indexer) - { - visitChild(ctv->indexer->indexType); - visitChild(ctv->indexer->indexResultType); - } + visitChild(ctv->indexer->indexType); + visitChild(ctv->indexer->indexResultType); } } else if (const NegationType* ntv = get(ty)) @@ -1038,13 +1027,10 @@ void Substitution::replaceChildren(TypeId ty) if (ctv->metatable) ctv->metatable = replace(*ctv->metatable); - if (FFlag::LuauTypecheckClassTypeIndexers) + if (ctv->indexer) { - if (ctv->indexer) - { - ctv->indexer->indexType = replace(ctv->indexer->indexType); - ctv->indexer->indexResultType = replace(ctv->indexer->indexResultType); - } + ctv->indexer->indexType = replace(ctv->indexer->indexType); + ctv->indexer->indexResultType = replace(ctv->indexer->indexResultType); } } else if (NegationType* ntv = getMutable(ty)) diff --git a/Analysis/src/ToDot.cpp b/Analysis/src/ToDot.cpp index f2f15e85e..c3a1db4cd 100644 --- a/Analysis/src/ToDot.cpp +++ b/Analysis/src/ToDot.cpp @@ -258,13 +258,10 @@ void StateDot::visitChildren(TypeId ty, int index) if (ctv->metatable) visitChild(*ctv->metatable, index, "[metatable]"); - if (FFlag::LuauTypecheckClassTypeIndexers) + if (ctv->indexer) { - if (ctv->indexer) - { - visitChild(ctv->indexer->indexType, index, "[index]"); - visitChild(ctv->indexer->indexResultType, index, "[value]"); - } + visitChild(ctv->indexer->indexType, index, "[index]"); + visitChild(ctv->indexer->indexResultType, index, "[value]"); } } else if (const SingletonType* stv = get(ty)) diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index 7a46bf969..103f0dcab 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -1090,6 +1090,11 @@ struct TypeChecker2 args.head.push_back(lookupType(indexExpr->expr)); argLocs.push_back(indexExpr->expr->location); } + else if (findMetatableEntry(builtinTypes, module->errors, *originalCallTy, "__call", call->func->location)) + { + args.head.insert(args.head.begin(), lookupType(call->func)); + argLocs.push_back(call->func->location); + } for (size_t i = 0; i < call->args.size; ++i) { diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index c9da34f4c..c4a6d103b 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -38,7 +38,6 @@ LUAU_FASTFLAGVARIABLE(LuauAllowIndexClassParameters, false) LUAU_FASTFLAG(LuauOccursIsntAlwaysFailure) LUAU_FASTFLAGVARIABLE(LuauTypecheckTypeguards, false) LUAU_FASTFLAGVARIABLE(LuauTinyControlFlowAnalysis, false) -LUAU_FASTFLAGVARIABLE(LuauTypecheckClassTypeIndexers, false) LUAU_FASTFLAGVARIABLE(LuauAlwaysCommitInferencesOfFunctionCalls, false) LUAU_FASTFLAG(LuauParseDeclareClassIndexer) @@ -2107,21 +2106,18 @@ std::optional TypeChecker::getIndexTypeFromTypeImpl( if (prop) return prop->type(); - if (FFlag::LuauTypecheckClassTypeIndexers) + if (auto indexer = cls->indexer) { - if (auto indexer = cls->indexer) - { - // TODO: Property lookup should work with string singletons or unions thereof as the indexer key type. - ErrorVec errors = tryUnify(stringType, indexer->indexType, scope, location); + // TODO: Property lookup should work with string singletons or unions thereof as the indexer key type. + ErrorVec errors = tryUnify(stringType, indexer->indexType, scope, location); - if (errors.empty()) - return indexer->indexResultType; + if (errors.empty()) + return indexer->indexResultType; - if (addErrors) - reportError(location, UnknownProperty{type, name}); + if (addErrors) + reportError(location, UnknownProperty{type, name}); - return std::nullopt; - } + return std::nullopt; } } else if (const UnionType* utv = get(type)) @@ -3312,38 +3308,24 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex } else if (const ClassType* lhsClass = get(lhs)) { - if (FFlag::LuauTypecheckClassTypeIndexers) + if (const Property* prop = lookupClassProp(lhsClass, name)) { - if (const Property* prop = lookupClassProp(lhsClass, name)) - { - return prop->type(); - } - - if (auto indexer = lhsClass->indexer) - { - Unifier state = mkUnifier(scope, expr.location); - state.tryUnify(stringType, indexer->indexType); - if (state.errors.empty()) - { - state.log.commit(); - return indexer->indexResultType; - } - } - - reportError(TypeError{expr.location, UnknownProperty{lhs, name}}); - return errorRecoveryType(scope); + return prop->type(); } - else + + if (auto indexer = lhsClass->indexer) { - const Property* prop = lookupClassProp(lhsClass, name); - if (!prop) + Unifier state = mkUnifier(scope, expr.location); + state.tryUnify(stringType, indexer->indexType); + if (state.errors.empty()) { - reportError(TypeError{expr.location, UnknownProperty{lhs, name}}); - return errorRecoveryType(scope); + state.log.commit(); + return indexer->indexResultType; } - - return prop->type(); } + + reportError(TypeError{expr.location, UnknownProperty{lhs, name}}); + return errorRecoveryType(scope); } else if (get(lhs)) { @@ -3385,45 +3367,29 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex { if (const ClassType* exprClass = get(exprType)) { - if (FFlag::LuauTypecheckClassTypeIndexers) + if (const Property* prop = lookupClassProp(exprClass, value->value.data)) { - if (const Property* prop = lookupClassProp(exprClass, value->value.data)) - { - return prop->type(); - } - - if (auto indexer = exprClass->indexer) - { - unify(stringType, indexer->indexType, scope, expr.index->location); - return indexer->indexResultType; - } - - reportError(TypeError{expr.location, UnknownProperty{exprType, value->value.data}}); - return errorRecoveryType(scope); + return prop->type(); } - else + + if (auto indexer = exprClass->indexer) { - const Property* prop = lookupClassProp(exprClass, value->value.data); - if (!prop) - { - reportError(TypeError{expr.location, UnknownProperty{exprType, value->value.data}}); - return errorRecoveryType(scope); - } - return prop->type(); + unify(stringType, indexer->indexType, scope, expr.index->location); + return indexer->indexResultType; } + + reportError(TypeError{expr.location, UnknownProperty{exprType, value->value.data}}); + return errorRecoveryType(scope); } } else { - if (FFlag::LuauTypecheckClassTypeIndexers) + if (const ClassType* exprClass = get(exprType)) { - if (const ClassType* exprClass = get(exprType)) + if (auto indexer = exprClass->indexer) { - if (auto indexer = exprClass->indexer) - { - unify(indexType, indexer->indexType, scope, expr.index->location); - return indexer->indexResultType; - } + unify(indexType, indexer->indexType, scope, expr.index->location); + return indexer->indexResultType; } } diff --git a/Ast/include/Luau/Location.h b/Ast/include/Luau/Location.h index dbe36becb..41ca379d1 100644 --- a/Ast/include/Luau/Location.h +++ b/Ast/include/Luau/Location.h @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include namespace Luau { @@ -38,6 +39,11 @@ struct Location bool containsClosed(const Position& p) const; void extend(const Location& other); void shift(const Position& start, const Position& oldEnd, const Position& newEnd); + + /** + * Use offset=1 when displaying for the user. + */ + std::string toString(int offset = 0, bool useBegin = true) const; }; } // namespace Luau diff --git a/Ast/include/Luau/TimeTrace.h b/Ast/include/Luau/TimeTrace.h index be2828272..2f7daf2c6 100644 --- a/Ast/include/Luau/TimeTrace.h +++ b/Ast/include/Luau/TimeTrace.h @@ -4,6 +4,7 @@ #include "Luau/Common.h" #include +#include #include @@ -54,7 +55,7 @@ struct Event struct GlobalContext; struct ThreadContext; -GlobalContext& getGlobalContext(); +std::shared_ptr getGlobalContext(); uint16_t createToken(GlobalContext& context, const char* name, const char* category); uint32_t createThread(GlobalContext& context, ThreadContext* threadContext); @@ -66,7 +67,7 @@ struct ThreadContext ThreadContext() : globalContext(getGlobalContext()) { - threadId = createThread(globalContext, this); + threadId = createThread(*globalContext, this); } ~ThreadContext() @@ -74,16 +75,16 @@ struct ThreadContext if (!events.empty()) flushEvents(); - releaseThread(globalContext, this); + releaseThread(*globalContext, this); } void flushEvents() { - static uint16_t flushToken = createToken(globalContext, "flushEvents", "TimeTrace"); + static uint16_t flushToken = createToken(*globalContext, "flushEvents", "TimeTrace"); events.push_back({EventType::Enter, flushToken, {getClockMicroseconds()}}); - TimeTrace::flushEvents(globalContext, threadId, events, data); + TimeTrace::flushEvents(*globalContext, threadId, events, data); events.clear(); data.clear(); @@ -125,7 +126,7 @@ struct ThreadContext events.push_back({EventType::ArgValue, 0, {pos}}); } - GlobalContext& globalContext; + std::shared_ptr globalContext; uint32_t threadId; std::vector events; std::vector data; diff --git a/Ast/src/Location.cpp b/Ast/src/Location.cpp index d01d8a186..e0ae867fc 100644 --- a/Ast/src/Location.cpp +++ b/Ast/src/Location.cpp @@ -1,5 +1,6 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Location.h" +#include namespace Luau { @@ -128,4 +129,12 @@ void Location::shift(const Position& start, const Position& oldEnd, const Positi end.shift(start, oldEnd, newEnd); } +std::string Location::toString(int offset, bool useBegin) const +{ + const Position& pos = useBegin ? this->begin : this->end; + std::string line{std::to_string(pos.line + offset)}; + std::string column{std::to_string(pos.column + offset)}; + return "(" + line + ", " + column + ")"; +} + } // namespace Luau diff --git a/Ast/src/TimeTrace.cpp b/Ast/src/TimeTrace.cpp index e38076830..8b95cf0b7 100644 --- a/Ast/src/TimeTrace.cpp +++ b/Ast/src/TimeTrace.cpp @@ -90,7 +90,6 @@ namespace TimeTrace { struct GlobalContext { - GlobalContext() = default; ~GlobalContext() { // Ideally we would want all ThreadContext destructors to run @@ -110,11 +109,15 @@ struct GlobalContext uint32_t nextThreadId = 0; std::vector tokens; FILE* traceFile = nullptr; + +private: + friend std::shared_ptr getGlobalContext(); + GlobalContext() = default; }; -GlobalContext& getGlobalContext() +std::shared_ptr getGlobalContext() { - static GlobalContext context; + static std::shared_ptr context = std::shared_ptr{new GlobalContext}; return context; } @@ -261,7 +264,7 @@ ThreadContext& getThreadContext() uint16_t createScopeData(const char* name, const char* category) { - return createToken(Luau::TimeTrace::getGlobalContext(), name, category); + return createToken(*Luau::TimeTrace::getGlobalContext(), name, category); } } // namespace TimeTrace } // namespace Luau diff --git a/CodeGen/include/Luau/IrData.h b/CodeGen/include/Luau/IrData.h index 8cbe7e8b1..16c8df628 100644 --- a/CodeGen/include/Luau/IrData.h +++ b/CodeGen/include/Luau/IrData.h @@ -236,6 +236,10 @@ enum class IrCmd : uint8_t // A: pointer (Table) TABLE_LEN, + // Get string length + // A: pointer (string) + STRING_LEN, + // Allocate new table // A: int (array element count) // B: int (node element count) @@ -361,8 +365,10 @@ enum class IrCmd : uint8_t // Guard against tag mismatch // A, B: tag // C: block/undef + // D: bool (finish execution in VM on failure) // In final x64 lowering, A can also be Rn - // When undef is specified instead of a block, execution is aborted on check failure + // When undef is specified instead of a block, execution is aborted on check failure; if D is true, execution is continued in VM interpreter + // instead. CHECK_TAG, // Guard against readonly table @@ -377,9 +383,9 @@ enum class IrCmd : uint8_t // When undef is specified instead of a block, execution is aborted on check failure CHECK_NO_METATABLE, - // Guard against executing in unsafe environment - // A: block/undef - // When undef is specified instead of a block, execution is aborted on check failure + // Guard against executing in unsafe environment, exits to VM on check failure + // A: unsigned int (pcpos)/undef + // When undef is specified, execution is aborted on check failure CHECK_SAFE_ENV, // Guard against index overflowing the table array size @@ -610,7 +616,6 @@ struct IrConst union { - bool valueBool; int valueInt; unsigned valueUint; double valueDouble; diff --git a/CodeGen/include/Luau/IrRegAllocX64.h b/CodeGen/include/Luau/IrRegAllocX64.h index f83cc2208..959308115 100644 --- a/CodeGen/include/Luau/IrRegAllocX64.h +++ b/CodeGen/include/Luau/IrRegAllocX64.h @@ -39,6 +39,8 @@ struct IrRegAllocX64 RegisterX64 allocRegOrReuse(SizeX64 size, uint32_t instIdx, std::initializer_list oprefs); RegisterX64 takeReg(RegisterX64 reg, uint32_t instIdx); + bool canTakeReg(RegisterX64 reg) const; + void freeReg(RegisterX64 reg); void freeLastUseReg(IrInst& target, uint32_t instIdx); void freeLastUseRegs(const IrInst& inst, uint32_t instIdx); diff --git a/CodeGen/include/Luau/IrUtils.h b/CodeGen/include/Luau/IrUtils.h index a3e97894c..6481342f5 100644 --- a/CodeGen/include/Luau/IrUtils.h +++ b/CodeGen/include/Luau/IrUtils.h @@ -167,6 +167,7 @@ inline bool hasResult(IrCmd cmd) case IrCmd::ABS_NUM: case IrCmd::NOT_ANY: case IrCmd::TABLE_LEN: + case IrCmd::STRING_LEN: case IrCmd::NEW_TABLE: case IrCmd::DUP_TABLE: case IrCmd::TRY_NUM_TO_INDEX: @@ -256,5 +257,8 @@ void foldConstants(IrBuilder& build, IrFunction& function, IrBlock& block, uint3 uint32_t getNativeContextOffset(int bfid); +// Cleans up blocks that were created with no users +void killUnusedBlocks(IrFunction& function); + } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/CodeGenA64.cpp b/CodeGen/src/CodeGenA64.cpp index cc0131820..37af59aa5 100644 --- a/CodeGen/src/CodeGenA64.cpp +++ b/CodeGen/src/CodeGenA64.cpp @@ -24,6 +24,15 @@ struct EntryLocations Label epilogueStart; }; +static void emitClearNativeFlag(AssemblyBuilderA64& build) +{ + build.ldr(x0, mem(rState, offsetof(lua_State, ci))); + build.ldr(w1, mem(x0, offsetof(CallInfo, flags))); + build.mov(w2, ~LUA_CALLINFO_NATIVE); + build.and_(w1, w1, w2); + build.str(w1, mem(x0, offsetof(CallInfo, flags))); +} + static void emitExit(AssemblyBuilderA64& build, bool continueInVm) { build.mov(x0, continueInVm); @@ -31,6 +40,16 @@ static void emitExit(AssemblyBuilderA64& build, bool continueInVm) build.br(x1); } +static void emitUpdatePcAndContinueInVm(AssemblyBuilderA64& build) +{ + // x0 = pcpos * sizeof(Instruction) + build.add(x0, rCode, x0); + build.ldr(x1, mem(rState, offsetof(lua_State, ci))); + build.str(x0, mem(x1, offsetof(CallInfo, savedpc))); + + emitExit(build, /* continueInVm */ true); +} + static void emitInterrupt(AssemblyBuilderA64& build) { // x0 = pc offset @@ -286,6 +305,11 @@ bool initHeaderFunctions(NativeState& data) void assembleHelpers(AssemblyBuilderA64& build, ModuleHelpers& helpers) { + if (build.logText) + build.logAppend("; exitContinueVmClearNativeFlag\n"); + build.setLabel(helpers.exitContinueVmClearNativeFlag); + emitClearNativeFlag(build); + if (build.logText) build.logAppend("; exitContinueVm\n"); build.setLabel(helpers.exitContinueVm); @@ -296,6 +320,11 @@ void assembleHelpers(AssemblyBuilderA64& build, ModuleHelpers& helpers) build.setLabel(helpers.exitNoContinueVm); emitExit(build, /* continueInVm */ false); + if (build.logText) + build.logAppend("; updatePcAndContinueInVm\n"); + build.setLabel(helpers.updatePcAndContinueInVm); + emitUpdatePcAndContinueInVm(build); + if (build.logText) build.logAppend("; reentry\n"); build.setLabel(helpers.reentry); diff --git a/CodeGen/src/CodeGenLower.h b/CodeGen/src/CodeGenLower.h index 5b6c4ffc4..4b74e9f20 100644 --- a/CodeGen/src/CodeGenLower.h +++ b/CodeGen/src/CodeGenLower.h @@ -221,6 +221,8 @@ inline bool lowerIr(A64::AssemblyBuilderA64& build, IrBuilder& ir, ModuleHelpers template inline bool lowerFunction(IrBuilder& ir, AssemblyBuilder& build, ModuleHelpers& helpers, Proto* proto, AssemblyOptions options) { + killUnusedBlocks(ir.function); + computeCfgInfo(ir.function); if (!FFlag::DebugCodegenNoOpt) diff --git a/CodeGen/src/CodeGenX64.cpp b/CodeGen/src/CodeGenX64.cpp index 41c3dbd05..1e62a4d46 100644 --- a/CodeGen/src/CodeGenX64.cpp +++ b/CodeGen/src/CodeGenX64.cpp @@ -180,6 +180,11 @@ bool initHeaderFunctions(NativeState& data) void assembleHelpers(X64::AssemblyBuilderX64& build, ModuleHelpers& helpers) { + if (build.logText) + build.logAppend("; exitContinueVmClearNativeFlag\n"); + build.setLabel(helpers.exitContinueVmClearNativeFlag); + emitClearNativeFlag(build); + if (build.logText) build.logAppend("; exitContinueVm\n"); build.setLabel(helpers.exitContinueVm); @@ -190,6 +195,11 @@ void assembleHelpers(X64::AssemblyBuilderX64& build, ModuleHelpers& helpers) build.setLabel(helpers.exitNoContinueVm); emitExit(build, /* continueInVm */ false); + if (build.logText) + build.logAppend("; updatePcAndContinueInVm\n"); + build.setLabel(helpers.updatePcAndContinueInVm); + emitUpdatePcAndContinueInVm(build); + if (build.logText) build.logAppend("; continueCallInVm\n"); build.setLabel(helpers.continueCallInVm); diff --git a/CodeGen/src/EmitCommon.h b/CodeGen/src/EmitCommon.h index f912ffba7..086660647 100644 --- a/CodeGen/src/EmitCommon.h +++ b/CodeGen/src/EmitCommon.h @@ -24,6 +24,8 @@ struct ModuleHelpers // A64/X64 Label exitContinueVm; Label exitNoContinueVm; + Label exitContinueVmClearNativeFlag; + Label updatePcAndContinueInVm; Label return_; Label interrupt; diff --git a/CodeGen/src/EmitCommonX64.cpp b/CodeGen/src/EmitCommonX64.cpp index 4d70bb7a7..1d707fad9 100644 --- a/CodeGen/src/EmitCommonX64.cpp +++ b/CodeGen/src/EmitCommonX64.cpp @@ -268,6 +268,13 @@ void callStepGc(IrRegAllocX64& regs, AssemblyBuilderX64& build) build.setLabel(skip); } + +void emitClearNativeFlag(AssemblyBuilderX64& build) +{ + build.mov(rax, qword[rState + offsetof(lua_State, ci)]); + build.and_(dword[rax + offsetof(CallInfo, flags)], ~LUA_CALLINFO_NATIVE); +} + void emitExit(AssemblyBuilderX64& build, bool continueInVm) { if (continueInVm) @@ -345,6 +352,16 @@ void emitFallback(IrRegAllocX64& regs, AssemblyBuilderX64& build, int offset, in emitUpdateBase(build); } +void emitUpdatePcAndContinueInVm(AssemblyBuilderX64& build) +{ + // edx = pcpos * sizeof(Instruction) + build.add(rdx, sCode); + build.mov(rax, qword[rState + offsetof(lua_State, ci)]); + build.mov(qword[rax + offsetof(CallInfo, savedpc)], rdx); + + emitExit(build, /* continueInVm */ true); +} + void emitContinueCallInVm(AssemblyBuilderX64& build) { RegisterX64 proto = rcx; // Sync with emitInstCall diff --git a/CodeGen/src/EmitCommonX64.h b/CodeGen/src/EmitCommonX64.h index 5a3548f6c..02d9f40bc 100644 --- a/CodeGen/src/EmitCommonX64.h +++ b/CodeGen/src/EmitCommonX64.h @@ -175,11 +175,13 @@ void callBarrierObject(IrRegAllocX64& regs, AssemblyBuilderX64& build, RegisterX void callBarrierTableFast(IrRegAllocX64& regs, AssemblyBuilderX64& build, RegisterX64 table, IrOp tableOp); void callStepGc(IrRegAllocX64& regs, AssemblyBuilderX64& build); +void emitClearNativeFlag(AssemblyBuilderX64& build); void emitExit(AssemblyBuilderX64& build, bool continueInVm); void emitUpdateBase(AssemblyBuilderX64& build); void emitInterrupt(AssemblyBuilderX64& build); void emitFallback(IrRegAllocX64& regs, AssemblyBuilderX64& build, int offset, int pcpos); +void emitUpdatePcAndContinueInVm(AssemblyBuilderX64& build); void emitContinueCallInVm(AssemblyBuilderX64& build); void emitReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers); diff --git a/CodeGen/src/EmitInstructionX64.cpp b/CodeGen/src/EmitInstructionX64.cpp index 61d5ac63e..ea511958f 100644 --- a/CodeGen/src/EmitInstructionX64.cpp +++ b/CodeGen/src/EmitInstructionX64.cpp @@ -90,6 +90,11 @@ void emitInstCall(AssemblyBuilderX64& build, ModuleHelpers& helpers, int ra, int build.mov(qword[rState + offsetof(lua_State, top)], argi); build.setLabel(skipVararg); + // Keep executing new function + // ci->savedpc = p->code; + build.mov(rax, qword[proto + offsetof(Proto, code)]); + build.mov(qword[ci + offsetof(CallInfo, savedpc)], rax); + // Get native function entry build.mov(rax, qword[proto + offsetof(Proto, exectarget)]); build.test(rax, rax); diff --git a/CodeGen/src/IrBuilder.cpp b/CodeGen/src/IrBuilder.cpp index 98db2977e..69ac295a3 100644 --- a/CodeGen/src/IrBuilder.cpp +++ b/CodeGen/src/IrBuilder.cpp @@ -22,6 +22,82 @@ IrBuilder::IrBuilder() { } +static void buildArgumentTypeChecks(IrBuilder& build, Proto* proto) +{ + if (!proto->typeinfo || proto->numparams == 0) + return; + + for (int i = 0; i < proto->numparams; ++i) + { + uint8_t et = proto->typeinfo[2 + i]; + + uint8_t tag = et & ~LBC_TYPE_OPTIONAL_BIT; + uint8_t optional = et & LBC_TYPE_OPTIONAL_BIT; + + if (tag == LBC_TYPE_ANY) + continue; + + IrOp load = build.inst(IrCmd::LOAD_TAG, build.vmReg(i)); + + IrOp nextCheck; + if (optional) + { + nextCheck = build.block(IrBlockKind::Internal); + IrOp fallbackCheck = build.block(IrBlockKind::Internal); + + build.inst(IrCmd::JUMP_EQ_TAG, load, build.constTag(LUA_TNIL), nextCheck, fallbackCheck); + + build.beginBlock(fallbackCheck); + } + + switch (tag) + { + case LBC_TYPE_NIL: + build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TNIL), build.undef(), build.constInt(1)); + break; + case LBC_TYPE_BOOLEAN: + build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TBOOLEAN), build.undef(), build.constInt(1)); + break; + case LBC_TYPE_NUMBER: + build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TNUMBER), build.undef(), build.constInt(1)); + break; + case LBC_TYPE_STRING: + build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TSTRING), build.undef(), build.constInt(1)); + break; + case LBC_TYPE_TABLE: + build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TTABLE), build.undef(), build.constInt(1)); + break; + case LBC_TYPE_FUNCTION: + build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TFUNCTION), build.undef(), build.constInt(1)); + break; + case LBC_TYPE_THREAD: + build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TTHREAD), build.undef(), build.constInt(1)); + break; + case LBC_TYPE_USERDATA: + build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TUSERDATA), build.undef(), build.constInt(1)); + break; + case LBC_TYPE_VECTOR: + build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TVECTOR), build.undef(), build.constInt(1)); + break; + } + + if (optional) + { + build.inst(IrCmd::JUMP, nextCheck); + build.beginBlock(nextCheck); + } + } + + // If the last argument is optional, we can skip creating a new internal block since one will already have been created. + if (!(proto->typeinfo[2 + proto->numparams - 1] & LBC_TYPE_OPTIONAL_BIT)) + { + IrOp next = build.block(IrBlockKind::Internal); + build.inst(IrCmd::JUMP, next); + + build.beginBlock(next); + } +} + void IrBuilder::buildFunctionIr(Proto* proto) { function.proto = proto; @@ -47,6 +123,9 @@ void IrBuilder::buildFunctionIr(Proto* proto) if (instIndexToBlock[i] != kNoAssociatedBlockIndex) beginBlock(blockAtInst(i)); + if (i == 0) + buildArgumentTypeChecks(*this, proto); + // We skip dead bytecode instructions when they appear after block was already terminated if (!inTerminatedBlock) translateInst(op, pc, i); diff --git a/CodeGen/src/IrCallWrapperX64.cpp b/CodeGen/src/IrCallWrapperX64.cpp index f466df4a8..816e01841 100644 --- a/CodeGen/src/IrCallWrapperX64.cpp +++ b/CodeGen/src/IrCallWrapperX64.cpp @@ -212,7 +212,13 @@ RegisterX64 IrCallWrapperX64::suggestNextArgumentRegister(SizeX64 size) const { OperandX64 target = getNextArgumentTarget(size); - return target.cat == CategoryX64::reg ? regs.takeReg(target.base, kInvalidInstIdx) : regs.allocReg(size, kInvalidInstIdx); + if (target.cat != CategoryX64::reg) + return regs.allocReg(size, kInvalidInstIdx); + + if (!regs.canTakeReg(target.base)) + return regs.allocReg(size, kInvalidInstIdx); + + return regs.takeReg(target.base, kInvalidInstIdx); } OperandX64 IrCallWrapperX64::getNextArgumentTarget(SizeX64 size) const diff --git a/CodeGen/src/IrDump.cpp b/CodeGen/src/IrDump.cpp index e699229c8..dfd7236a3 100644 --- a/CodeGen/src/IrDump.cpp +++ b/CodeGen/src/IrDump.cpp @@ -163,6 +163,8 @@ const char* getCmdName(IrCmd cmd) return "JUMP_SLOT_MATCH"; case IrCmd::TABLE_LEN: return "TABLE_LEN"; + case IrCmd::STRING_LEN: + return "STRING_LEN"; case IrCmd::NEW_TABLE: return "NEW_TABLE"; case IrCmd::DUP_TABLE: diff --git a/CodeGen/src/IrLoweringA64.cpp b/CodeGen/src/IrLoweringA64.cpp index 3cf921730..38e840ab7 100644 --- a/CodeGen/src/IrLoweringA64.cpp +++ b/CodeGen/src/IrLoweringA64.cpp @@ -376,8 +376,9 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.add(inst.regA64, regOp(inst.b), uint16_t(intOp(inst.a))); else { - RegisterA64 temp = tempInt(inst.b); - build.add(inst.regA64, regOp(inst.a), temp); + RegisterA64 temp1 = tempInt(inst.a); + RegisterA64 temp2 = tempInt(inst.b); + build.add(inst.regA64, temp1, temp2); } break; case IrCmd::SUB_INT: @@ -386,8 +387,9 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.sub(inst.regA64, regOp(inst.a), uint16_t(intOp(inst.b))); else { - RegisterA64 temp = tempInt(inst.b); - build.sub(inst.regA64, regOp(inst.a), temp); + RegisterA64 temp1 = tempInt(inst.a); + RegisterA64 temp2 = tempInt(inst.b); + build.sub(inst.regA64, temp1, temp2); } break; case IrCmd::ADD_NUM: @@ -689,6 +691,14 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.scvtf(inst.regA64, x0); break; } + case IrCmd::STRING_LEN: + { + RegisterA64 reg = regOp(inst.a); + inst.regA64 = regs.allocReg(KindA64::w, index); + + build.ldr(inst.regA64, mem(reg, offsetof(TString, len))); + break; + } case IrCmd::NEW_TABLE: { regs.spill(build, index); @@ -816,7 +826,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) } case IrCmd::FASTCALL: regs.spill(build, index); - error |= emitBuiltin(build, function, regs, uintOp(inst.a), vmRegOp(inst.b), vmRegOp(inst.c), inst.d, intOp(inst.e), intOp(inst.f)); + error |= !emitBuiltin(build, function, regs, uintOp(inst.a), vmRegOp(inst.b), vmRegOp(inst.c), inst.d, intOp(inst.e), intOp(inst.f)); break; case IrCmd::INVOKE_FASTCALL: { @@ -1018,8 +1028,9 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) break; case IrCmd::CHECK_TAG: { + bool continueInVm = (inst.d.kind == IrOpKind::Constant && intOp(inst.d)); Label abort; // used when guard aborts execution - Label& fail = inst.c.kind == IrOpKind::Undef ? abort : labelOp(inst.c); + Label& fail = inst.c.kind == IrOpKind::Undef ? (continueInVm ? helpers.exitContinueVmClearNativeFlag : abort) : labelOp(inst.c); if (tagOp(inst.b) == 0) { build.cbnz(regOp(inst.a), fail); @@ -1029,7 +1040,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.cmp(regOp(inst.a), tagOp(inst.b)); build.b(ConditionA64::NotEqual, fail); } - if (abort.id) + if (abort.id && !continueInVm) emitAbort(build, abort); break; } @@ -1060,9 +1071,18 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) RegisterA64 tempw = castReg(KindA64::w, temp); build.ldr(temp, mem(rClosure, offsetof(Closure, env))); build.ldrb(tempw, mem(temp, offsetof(Table, safeenv))); - build.cbz(tempw, inst.a.kind == IrOpKind::Undef ? abort : labelOp(inst.a)); - if (abort.id) + + if (inst.a.kind == IrOpKind::Undef) + { + build.cbz(tempw, abort); emitAbort(build, abort); + } + else + { + Label self; + build.cbz(tempw, self); + exitHandlers.push_back({self, uintOp(inst.a)}); + } break; } case IrCmd::CHECK_ARRAY_SIZE: @@ -1528,7 +1548,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) case IrCmd::BITAND_UINT: { inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.a, inst.b}); - if (inst.b.kind == IrOpKind::Constant && AssemblyBuilderA64::isMaskSupported(unsigned(intOp(inst.b)))) + if (inst.a.kind == IrOpKind::Inst && inst.b.kind == IrOpKind::Constant && AssemblyBuilderA64::isMaskSupported(unsigned(intOp(inst.b)))) build.and_(inst.regA64, regOp(inst.a), unsigned(intOp(inst.b))); else { @@ -1541,7 +1561,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) case IrCmd::BITXOR_UINT: { inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.a, inst.b}); - if (inst.b.kind == IrOpKind::Constant && AssemblyBuilderA64::isMaskSupported(unsigned(intOp(inst.b)))) + if (inst.a.kind == IrOpKind::Inst && inst.b.kind == IrOpKind::Constant && AssemblyBuilderA64::isMaskSupported(unsigned(intOp(inst.b)))) build.eor(inst.regA64, regOp(inst.a), unsigned(intOp(inst.b))); else { @@ -1554,7 +1574,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) case IrCmd::BITOR_UINT: { inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.a, inst.b}); - if (inst.b.kind == IrOpKind::Constant && AssemblyBuilderA64::isMaskSupported(unsigned(intOp(inst.b)))) + if (inst.a.kind == IrOpKind::Inst && inst.b.kind == IrOpKind::Constant && AssemblyBuilderA64::isMaskSupported(unsigned(intOp(inst.b)))) build.orr(inst.regA64, regOp(inst.a), unsigned(intOp(inst.b))); else { @@ -1574,7 +1594,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) case IrCmd::BITLSHIFT_UINT: { inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.a, inst.b}); - if (inst.b.kind == IrOpKind::Constant) + if (inst.a.kind == IrOpKind::Inst && inst.b.kind == IrOpKind::Constant) build.lsl(inst.regA64, regOp(inst.a), uint8_t(unsigned(intOp(inst.b)) & 31)); else { @@ -1587,7 +1607,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) case IrCmd::BITRSHIFT_UINT: { inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.a, inst.b}); - if (inst.b.kind == IrOpKind::Constant) + if (inst.a.kind == IrOpKind::Inst && inst.b.kind == IrOpKind::Constant) build.lsr(inst.regA64, regOp(inst.a), uint8_t(unsigned(intOp(inst.b)) & 31)); else { @@ -1600,7 +1620,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) case IrCmd::BITARSHIFT_UINT: { inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.a, inst.b}); - if (inst.b.kind == IrOpKind::Constant) + if (inst.a.kind == IrOpKind::Inst && inst.b.kind == IrOpKind::Constant) build.asr(inst.regA64, regOp(inst.a), uint8_t(unsigned(intOp(inst.b)) & 31)); else { @@ -1612,7 +1632,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) } case IrCmd::BITLROTATE_UINT: { - if (inst.b.kind == IrOpKind::Constant) + if (inst.a.kind == IrOpKind::Inst && inst.b.kind == IrOpKind::Constant) { inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.a}); build.ror(inst.regA64, regOp(inst.a), uint8_t((32 - unsigned(intOp(inst.b))) & 31)); @@ -1630,7 +1650,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) case IrCmd::BITRROTATE_UINT: { inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.a, inst.b}); - if (inst.b.kind == IrOpKind::Constant) + if (inst.a.kind == IrOpKind::Inst && inst.b.kind == IrOpKind::Constant) build.ror(inst.regA64, regOp(inst.a), uint8_t(unsigned(intOp(inst.b)) & 31)); else { @@ -1751,11 +1771,21 @@ void IrLoweringA64::finishFunction() build.adr(x1, handler.next); build.b(helpers.interrupt); } + + if (build.logText) + build.logAppend("; exit handlers\n"); + + for (ExitHandler& handler : exitHandlers) + { + build.setLabel(handler.self); + build.mov(x0, handler.pcpos * sizeof(Instruction)); + build.b(helpers.updatePcAndContinueInVm); + } } bool IrLoweringA64::hasError() const { - return error; + return error || regs.error; } bool IrLoweringA64::isFallthroughBlock(IrBlock target, IrBlock next) diff --git a/CodeGen/src/IrLoweringA64.h b/CodeGen/src/IrLoweringA64.h index 57c18b2ed..5b1968892 100644 --- a/CodeGen/src/IrLoweringA64.h +++ b/CodeGen/src/IrLoweringA64.h @@ -60,6 +60,12 @@ struct IrLoweringA64 Label next; }; + struct ExitHandler + { + Label self; + unsigned int pcpos; + }; + AssemblyBuilderA64& build; ModuleHelpers& helpers; @@ -70,6 +76,7 @@ struct IrLoweringA64 IrValueLocationTracking valueTracker; std::vector interruptHandlers; + std::vector exitHandlers; bool error = false; }; diff --git a/CodeGen/src/IrLoweringX64.cpp b/CodeGen/src/IrLoweringX64.cpp index abe02eedb..813f5123e 100644 --- a/CodeGen/src/IrLoweringX64.cpp +++ b/CodeGen/src/IrLoweringX64.cpp @@ -584,6 +584,13 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.vcvtsi2sd(inst.regX64, inst.regX64, eax); break; } + case IrCmd::STRING_LEN: + { + RegisterX64 ptr = regOp(inst.a); + inst.regX64 = regs.allocReg(SizeX64::dword, index); + build.mov(inst.regX64, dword[ptr + offsetof(TString, len)]); + break; + } case IrCmd::NEW_TABLE: { IrCallWrapperX64 callWrap(regs, build, index); @@ -720,9 +727,6 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) int nparams = intOp(inst.e); int nresults = intOp(inst.f); - ScopedRegX64 func{regs, SizeX64::qword}; - build.mov(func.reg, qword[rNativeContext + offsetof(NativeContext, luauF_table) + bfid * sizeof(luau_FastFunction)]); - IrCallWrapperX64 callWrap(regs, build, index); callWrap.addArgument(SizeX64::qword, rState); callWrap.addArgument(SizeX64::qword, luauRegAddress(ra)); @@ -748,6 +752,9 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) callWrap.addArgument(SizeX64::dword, nparams); } + ScopedRegX64 func{regs, SizeX64::qword}; + build.mov(func.reg, qword[rNativeContext + offsetof(NativeContext, luauF_table) + bfid * sizeof(luau_FastFunction)]); + callWrap.call(func.release()); inst.regX64 = regs.takeReg(eax, index); // Result of a builtin call is returned in eax break; @@ -878,9 +885,12 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) callPrepareForN(regs, build, vmRegOp(inst.a), vmRegOp(inst.b), vmRegOp(inst.c)); break; case IrCmd::CHECK_TAG: + { + bool continueInVm = (inst.d.kind == IrOpKind::Constant && intOp(inst.d)); build.cmp(memRegTagOp(inst.a), tagOp(inst.b)); - jumpOrAbortOnUndef(ConditionX64::NotEqual, ConditionX64::Equal, inst.c); + jumpOrAbortOnUndef(ConditionX64::NotEqual, ConditionX64::Equal, inst.c, continueInVm); break; + } case IrCmd::CHECK_READONLY: build.cmp(byte[regOp(inst.a) + offsetof(Table, readonly)], 0); jumpOrAbortOnUndef(ConditionX64::NotEqual, ConditionX64::Equal, inst.b); @@ -896,7 +906,20 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.mov(tmp.reg, sClosure); build.mov(tmp.reg, qword[tmp.reg + offsetof(Closure, env)]); build.cmp(byte[tmp.reg + offsetof(Table, safeenv)], 0); - jumpOrAbortOnUndef(ConditionX64::Equal, ConditionX64::NotEqual, inst.a); + + if (inst.a.kind == IrOpKind::Undef) + { + Label skip; + build.jcc(ConditionX64::NotEqual, skip); + build.ud2(); + build.setLabel(skip); + } + else + { + Label self; + build.jcc(ConditionX64::Equal, self); + exitHandlers.push_back({self, uintOp(inst.a)}); + } break; } case IrCmd::CHECK_ARRAY_SIZE: @@ -1403,6 +1426,16 @@ void IrLoweringX64::finishFunction() build.lea(rbx, handler.next); build.jmp(helpers.interrupt); } + + if (build.logText) + build.logAppend("; exit handlers\n"); + + for (ExitHandler& handler : exitHandlers) + { + build.setLabel(handler.self); + build.mov(edx, handler.pcpos * sizeof(Instruction)); + build.jmp(helpers.updatePcAndContinueInVm); + } } bool IrLoweringX64::hasError() const @@ -1425,10 +1458,16 @@ void IrLoweringX64::jumpOrFallthrough(IrBlock& target, IrBlock& next) build.jmp(target.label); } -void IrLoweringX64::jumpOrAbortOnUndef(ConditionX64 cond, ConditionX64 condInverse, IrOp targetOrUndef) +void IrLoweringX64::jumpOrAbortOnUndef(ConditionX64 cond, ConditionX64 condInverse, IrOp targetOrUndef, bool continueInVm) { if (targetOrUndef.kind == IrOpKind::Undef) { + if (continueInVm) + { + build.jcc(cond, helpers.exitContinueVmClearNativeFlag); + return; + } + Label skip; build.jcc(condInverse, skip); build.ud2(); diff --git a/CodeGen/src/IrLoweringX64.h b/CodeGen/src/IrLoweringX64.h index f50812e42..8ea4b41eb 100644 --- a/CodeGen/src/IrLoweringX64.h +++ b/CodeGen/src/IrLoweringX64.h @@ -34,7 +34,7 @@ struct IrLoweringX64 bool isFallthroughBlock(IrBlock target, IrBlock next); void jumpOrFallthrough(IrBlock& target, IrBlock& next); - void jumpOrAbortOnUndef(ConditionX64 cond, ConditionX64 condInverse, IrOp targetOrUndef); + void jumpOrAbortOnUndef(ConditionX64 cond, ConditionX64 condInverse, IrOp targetOrUndef, bool continueInVm = false); void storeDoubleAsFloat(OperandX64 dst, IrOp src); @@ -60,6 +60,12 @@ struct IrLoweringX64 Label next; }; + struct ExitHandler + { + Label self; + unsigned int pcpos; + }; + AssemblyBuilderX64& build; ModuleHelpers& helpers; @@ -70,6 +76,7 @@ struct IrLoweringX64 IrValueLocationTracking valueTracker; std::vector interruptHandlers; + std::vector exitHandlers; }; } // namespace X64 diff --git a/CodeGen/src/IrRegAllocA64.cpp b/CodeGen/src/IrRegAllocA64.cpp index a4cfeaed4..02d7df986 100644 --- a/CodeGen/src/IrRegAllocA64.cpp +++ b/CodeGen/src/IrRegAllocA64.cpp @@ -18,6 +18,8 @@ namespace CodeGen namespace A64 { +static const int8_t kInvalidSpill = 64; + static int allocSpill(uint32_t& free, KindA64 kind) { LUAU_ASSERT(kStackSize <= 256); // to support larger stack frames, we need to ensure qN is allocated at 16b boundary to fit in ldr/str encoding @@ -91,7 +93,8 @@ static void restoreInst(AssemblyBuilderA64& build, uint32_t& freeSpillSlots, IrF { build.ldr(reg, mem(sp, sSpillArea.data + s.slot * 8)); - freeSpill(freeSpillSlots, reg.kind, s.slot); + if (s.slot != kInvalidSpill) + freeSpill(freeSpillSlots, reg.kind, s.slot); } else { @@ -135,9 +138,8 @@ RegisterA64 IrRegAllocA64::allocReg(KindA64 kind, uint32_t index) if (set.free == 0) { - // TODO: remember the error and fail lowering - LUAU_ASSERT(!"Out of registers to allocate"); - return noreg; + error = true; + return RegisterA64{kind, 0}; } int reg = 31 - countlz(set.free); @@ -157,9 +159,8 @@ RegisterA64 IrRegAllocA64::allocTemp(KindA64 kind) if (set.free == 0) { - // TODO: remember the error and fail lowering - LUAU_ASSERT(!"Out of registers to allocate"); - return noreg; + error = true; + return RegisterA64{kind, 0}; } int reg = 31 - countlz(set.free); @@ -332,7 +333,11 @@ size_t IrRegAllocA64::spill(AssemblyBuilderA64& build, uint32_t index, std::init else { int slot = allocSpill(freeSpillSlots, def.regA64.kind); - LUAU_ASSERT(slot >= 0); // TODO: remember the error and fail lowering + if (slot < 0) + { + slot = kInvalidSpill; + error = true; + } build.str(def.regA64, mem(sp, sSpillArea.data + slot * 8)); diff --git a/CodeGen/src/IrRegAllocA64.h b/CodeGen/src/IrRegAllocA64.h index 689743789..854a9f10f 100644 --- a/CodeGen/src/IrRegAllocA64.h +++ b/CodeGen/src/IrRegAllocA64.h @@ -77,6 +77,8 @@ struct IrRegAllocA64 // which 8-byte slots are free uint32_t freeSpillSlots = 0; + + bool error = false; }; } // namespace A64 diff --git a/CodeGen/src/IrRegAllocX64.cpp b/CodeGen/src/IrRegAllocX64.cpp index 273740778..b81aec8ce 100644 --- a/CodeGen/src/IrRegAllocX64.cpp +++ b/CodeGen/src/IrRegAllocX64.cpp @@ -121,6 +121,14 @@ RegisterX64 IrRegAllocX64::takeReg(RegisterX64 reg, uint32_t instIdx) return reg; } +bool IrRegAllocX64::canTakeReg(RegisterX64 reg) const +{ + const std::array& freeMap = reg.size == SizeX64::xmmword ? freeXmmMap : freeGprMap; + const std::array& instUsers = reg.size == SizeX64::xmmword ? xmmInstUsers : gprInstUsers; + + return freeMap[reg.index] || instUsers[reg.index] != kInvalidInstIdx; +} + void IrRegAllocX64::freeReg(RegisterX64 reg) { if (reg.size == SizeX64::xmmword) diff --git a/CodeGen/src/IrTranslateBuiltins.cpp b/CodeGen/src/IrTranslateBuiltins.cpp index e99a991a7..960be4ed8 100644 --- a/CodeGen/src/IrTranslateBuiltins.cpp +++ b/CodeGen/src/IrTranslateBuiltins.cpp @@ -737,6 +737,23 @@ static BuiltinImplResult translateBuiltinVector(IrBuilder& build, int nparams, i return {BuiltinImplType::UsesFallback, 1}; } +static BuiltinImplResult translateBuiltinStringLen(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) +{ + if (nparams < 1 || nresults > 1) + return {BuiltinImplType::None, -1}; + + build.loadAndCheckTag(build.vmReg(arg), LUA_TSTRING, fallback); + + IrOp ts = build.inst(IrCmd::LOAD_POINTER, build.vmReg(arg)); + + IrOp len = build.inst(IrCmd::STRING_LEN, ts); + + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), build.inst(IrCmd::INT_TO_NUM, len)); + build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); + + return {BuiltinImplType::UsesFallback, 1}; +} + BuiltinImplResult translateBuiltin(IrBuilder& build, int bfid, int ra, int arg, IrOp args, int nparams, int nresults, IrOp fallback) { // Builtins are not allowed to handle variadic arguments @@ -821,6 +838,8 @@ BuiltinImplResult translateBuiltin(IrBuilder& build, int bfid, int ra, int arg, return translateBuiltinTypeof(build, nparams, ra, arg, args, nresults, fallback); case LBF_VECTOR: return translateBuiltinVector(build, nparams, ra, arg, args, nresults, fallback); + case LBF_STRING_LEN: + return translateBuiltinStringLen(build, nparams, ra, arg, args, nresults, fallback); default: return {BuiltinImplType::None, -1}; } diff --git a/CodeGen/src/IrTranslation.cpp b/CodeGen/src/IrTranslation.cpp index 8f18827bc..3cbcd3cbd 100644 --- a/CodeGen/src/IrTranslation.cpp +++ b/CodeGen/src/IrTranslation.cpp @@ -516,6 +516,7 @@ void translateInstCloseUpvals(IrBuilder& build, const Instruction* pc) void translateFastCallN(IrBuilder& build, const Instruction* pc, int pcpos, bool customParams, int customParamCount, IrOp customArgs, IrOp next) { + LuauOpcode opcode = LuauOpcode(LUAU_INSN_OP(*pc)); int bfid = LUAU_INSN_A(*pc); int skip = LUAU_INSN_C(*pc); @@ -540,7 +541,8 @@ void translateFastCallN(IrBuilder& build, const Instruction* pc, int pcpos, bool IrOp fallback = build.block(IrBlockKind::Fallback); - build.inst(IrCmd::CHECK_SAFE_ENV, fallback); + // In unsafe environment, instead of retrying fastcall at 'pcpos' we side-exit directly to fallback sequence + build.inst(IrCmd::CHECK_SAFE_ENV, build.constUint(pcpos + getOpLength(opcode))); BuiltinImplResult br = translateBuiltin(build, LuauBuiltinFunction(bfid), ra, arg, builtinArgs, nparams, nresults, fallback); @@ -554,7 +556,7 @@ void translateFastCallN(IrBuilder& build, const Instruction* pc, int pcpos, bool else { // TODO: we can skip saving pc for some well-behaved builtins which we didn't inline - build.inst(IrCmd::SET_SAVEDPC, build.constUint(pcpos + 1)); + build.inst(IrCmd::SET_SAVEDPC, build.constUint(pcpos + getOpLength(opcode))); IrOp res = build.inst(IrCmd::INVOKE_FASTCALL, build.constUint(bfid), build.vmReg(ra), build.vmReg(arg), args, build.constInt(nparams), build.constInt(nresults)); @@ -668,7 +670,7 @@ void translateInstForGPrepNext(IrBuilder& build, const Instruction* pc, int pcpo IrOp fallback = build.block(IrBlockKind::Fallback); // fast-path: pairs/next - build.inst(IrCmd::CHECK_SAFE_ENV, fallback); + build.inst(IrCmd::CHECK_SAFE_ENV, build.constUint(pcpos)); IrOp tagB = build.inst(IrCmd::LOAD_TAG, build.vmReg(ra + 1)); build.inst(IrCmd::CHECK_TAG, tagB, build.constTag(LUA_TTABLE), fallback); IrOp tagC = build.inst(IrCmd::LOAD_TAG, build.vmReg(ra + 2)); @@ -695,7 +697,7 @@ void translateInstForGPrepInext(IrBuilder& build, const Instruction* pc, int pcp IrOp finish = build.block(IrBlockKind::Internal); // fast-path: ipairs/inext - build.inst(IrCmd::CHECK_SAFE_ENV, fallback); + build.inst(IrCmd::CHECK_SAFE_ENV, build.constUint(pcpos)); IrOp tagB = build.inst(IrCmd::LOAD_TAG, build.vmReg(ra + 1)); build.inst(IrCmd::CHECK_TAG, tagB, build.constTag(LUA_TTABLE), fallback); IrOp tagC = build.inst(IrCmd::LOAD_TAG, build.vmReg(ra + 2)); @@ -921,7 +923,7 @@ void translateInstGetImport(IrBuilder& build, const Instruction* pc, int pcpos) IrOp fastPath = build.block(IrBlockKind::Internal); IrOp fallback = build.block(IrBlockKind::Fallback); - build.inst(IrCmd::CHECK_SAFE_ENV, fallback); + build.inst(IrCmd::CHECK_SAFE_ENV, build.constUint(pcpos)); // note: if import failed, k[] is nil; we could check this during codegen, but we instead use runtime fallback // this allows us to handle ahead-of-time codegen smoothly when an import fails to resolve at runtime diff --git a/CodeGen/src/IrUtils.cpp b/CodeGen/src/IrUtils.cpp index 833d1cdd7..2395fb1ec 100644 --- a/CodeGen/src/IrUtils.cpp +++ b/CodeGen/src/IrUtils.cpp @@ -80,6 +80,8 @@ IrValueKind getCmdValueKind(IrCmd cmd) return IrValueKind::None; case IrCmd::TABLE_LEN: return IrValueKind::Double; + case IrCmd::STRING_LEN: + return IrValueKind::Int; case IrCmd::NEW_TABLE: case IrCmd::DUP_TABLE: return IrValueKind::Pointer; @@ -684,8 +686,7 @@ void foldConstants(IrBuilder& build, IrFunction& function, IrBlock& block, uint3 unsigned op1 = unsigned(function.intOp(inst.a)); int op2 = function.intOp(inst.b); - if (unsigned(op2) < 32) - substitute(function, inst, build.constInt(op1 << op2)); + substitute(function, inst, build.constInt(op1 << (op2 & 31))); } else if (inst.b.kind == IrOpKind::Constant && function.intOp(inst.b) == 0) { @@ -698,8 +699,7 @@ void foldConstants(IrBuilder& build, IrFunction& function, IrBlock& block, uint3 unsigned op1 = unsigned(function.intOp(inst.a)); int op2 = function.intOp(inst.b); - if (unsigned(op2) < 32) - substitute(function, inst, build.constInt(op1 >> op2)); + substitute(function, inst, build.constInt(op1 >> (op2 & 31))); } else if (inst.b.kind == IrOpKind::Constant && function.intOp(inst.b) == 0) { @@ -712,12 +712,9 @@ void foldConstants(IrBuilder& build, IrFunction& function, IrBlock& block, uint3 int op1 = function.intOp(inst.a); int op2 = function.intOp(inst.b); - if (unsigned(op2) < 32) - { - // note: technically right shift of negative values is UB, but this behavior is getting defined in C++20 and all compilers do the - // right (shift) thing. - substitute(function, inst, build.constInt(op1 >> op2)); - } + // note: technically right shift of negative values is UB, but this behavior is getting defined in C++20 and all compilers do the + // right (shift) thing. + substitute(function, inst, build.constInt(op1 >> (op2 & 31))); } else if (inst.b.kind == IrOpKind::Constant && function.intOp(inst.b) == 0) { @@ -794,5 +791,17 @@ uint32_t getNativeContextOffset(int bfid) return 0; } +void killUnusedBlocks(IrFunction& function) +{ + // Start from 1 as the first block is the entry block + for (unsigned i = 1; i < function.blocks.size(); i++) + { + IrBlock& block = function.blocks[i]; + + if (block.kind != IrBlockKind::Dead && block.useCount == 0) + kill(function, block); + } +} + } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/OptimizeConstProp.cpp b/CodeGen/src/OptimizeConstProp.cpp index e3cbef415..eeedd6cb0 100644 --- a/CodeGen/src/OptimizeConstProp.cpp +++ b/CodeGen/src/OptimizeConstProp.cpp @@ -441,17 +441,25 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& state.substituteOrRecordVmRegLoad(inst); break; case IrCmd::LOAD_DOUBLE: - if (IrOp value = state.tryGetValue(inst.a); value.kind == IrOpKind::Constant) + { + IrOp value = state.tryGetValue(inst.a); + + if (function.asDoubleOp(value)) substitute(function, inst, value); else if (inst.a.kind == IrOpKind::VmReg) state.substituteOrRecordVmRegLoad(inst); break; + } case IrCmd::LOAD_INT: - if (IrOp value = state.tryGetValue(inst.a); value.kind == IrOpKind::Constant) + { + IrOp value = state.tryGetValue(inst.a); + + if (function.asIntOp(value)) substitute(function, inst, value); else if (inst.a.kind == IrOpKind::VmReg) state.substituteOrRecordVmRegLoad(inst); break; + } case IrCmd::LOAD_TVALUE: if (inst.a.kind == IrOpKind::VmReg) state.substituteOrRecordVmRegLoad(inst); @@ -775,6 +783,7 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& case IrCmd::JUMP_EQ_POINTER: case IrCmd::JUMP_SLOT_MATCH: case IrCmd::TABLE_LEN: + case IrCmd::STRING_LEN: case IrCmd::NEW_TABLE: case IrCmd::DUP_TABLE: case IrCmd::TRY_NUM_TO_INDEX: diff --git a/Compiler/src/BytecodeBuilder.cpp b/Compiler/src/BytecodeBuilder.cpp index 9296519ba..8d360f875 100644 --- a/Compiler/src/BytecodeBuilder.cpp +++ b/Compiler/src/BytecodeBuilder.cpp @@ -2314,9 +2314,9 @@ static const char* getBaseTypeString(uint8_t type) case LBC_TYPE_STRING: return "string"; case LBC_TYPE_TABLE: - return "{ }"; + return "table"; case LBC_TYPE_FUNCTION: - return "function( )"; + return "function"; case LBC_TYPE_THREAD: return "thread"; case LBC_TYPE_USERDATA: diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index 8dd9876ca..83aad3d82 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -26,8 +26,7 @@ LUAU_FASTINTVARIABLE(LuauCompileInlineThreshold, 25) LUAU_FASTINTVARIABLE(LuauCompileInlineThresholdMaxBoost, 300) LUAU_FASTINTVARIABLE(LuauCompileInlineDepth, 5) -LUAU_FASTFLAGVARIABLE(CompileFunctionType, false) -LUAU_FASTFLAG(BytecodeVersion4) +LUAU_FASTFLAGVARIABLE(LuauCompileFunctionType, false) namespace Luau { @@ -103,6 +102,7 @@ struct Compiler , locstants(nullptr) , tableShapes(nullptr) , builtins(nullptr) + , typeMap(nullptr) { // preallocate some buffers that are very likely to grow anyway; this works around std::vector's inefficient growth policy for small arrays localStack.reserve(16); @@ -204,11 +204,11 @@ struct Compiler setDebugLine(func); - if (FFlag::BytecodeVersion4 && FFlag::CompileFunctionType) + if (FFlag::LuauCompileFunctionType) { - std::string funcType = getFunctionType(func); - if (!funcType.empty()) - bytecode.setFunctionTypeInfo(std::move(funcType)); + // note: we move types out of typeMap which is safe because compileFunction is only called once per function + if (std::string* funcType = typeMap.find(func)) + bytecode.setFunctionTypeInfo(std::move(*funcType)); } if (func->vararg) @@ -3807,6 +3807,8 @@ struct Compiler DenseHashMap locstants; DenseHashMap tableShapes; DenseHashMap builtins; + DenseHashMap typeMap; + const DenseHashMap* builtinsFold = nullptr; unsigned int regTop = 0; @@ -3870,6 +3872,11 @@ void compileOrThrow(BytecodeBuilder& bytecode, const ParseResult& parseResult, c root->visit(&fenvVisitor); } + if (FFlag::LuauCompileFunctionType) + { + buildTypeMap(compiler.typeMap, root); + } + // gathers all functions with the invariant that all function references are to functions earlier in the list // for example, function foo() return function() end end will result in two vector entries, [0] = anonymous and [1] = foo std::vector functions; diff --git a/Compiler/src/Types.cpp b/Compiler/src/Types.cpp index 02041986f..b8db7fb27 100644 --- a/Compiler/src/Types.cpp +++ b/Compiler/src/Types.cpp @@ -6,22 +6,58 @@ namespace Luau { -static LuauBytecodeEncodedType getType(AstType* ty) +static bool isGeneric(AstName name, const AstArray& generics) +{ + for (const AstGenericType& gt : generics) + if (gt.name == name) + return true; + + return false; +} + +static LuauBytecodeEncodedType getPrimitiveType(AstName name) +{ + if (name == "nil") + return LBC_TYPE_NIL; + else if (name == "boolean") + return LBC_TYPE_BOOLEAN; + else if (name == "number") + return LBC_TYPE_NUMBER; + else if (name == "string") + return LBC_TYPE_STRING; + else if (name == "thread") + return LBC_TYPE_THREAD; + else if (name == "any" || name == "unknown") + return LBC_TYPE_ANY; + else + return LBC_TYPE_INVALID; +} + +static LuauBytecodeEncodedType getType( + AstType* ty, const AstArray& generics, const DenseHashMap& typeAliases, bool resolveAliases) { if (AstTypeReference* ref = ty->as()) { - if (ref->name == "nil") - return LBC_TYPE_NIL; - else if (ref->name == "boolean") - return LBC_TYPE_BOOLEAN; - else if (ref->name == "number") - return LBC_TYPE_NUMBER; - else if (ref->name == "string") - return LBC_TYPE_STRING; - else if (ref->name == "thread") - return LBC_TYPE_THREAD; - else if (ref->name == "any" || ref->name == "unknown") + if (ref->prefix) return LBC_TYPE_ANY; + + if (AstStatTypeAlias* const* alias = typeAliases.find(ref->name); alias && *alias) + { + // note: we only resolve aliases to the depth of 1 to avoid dealing with recursive aliases + if (resolveAliases) + return getType((*alias)->type, (*alias)->generics, typeAliases, /* resolveAliases= */ false); + else + return LBC_TYPE_ANY; + } + + if (isGeneric(ref->name, generics)) + return LBC_TYPE_ANY; + + if (LuauBytecodeEncodedType prim = getPrimitiveType(ref->name); prim != LBC_TYPE_INVALID) + return prim; + + // not primitive or alias or generic => host-provided, we assume userdata for now + return LBC_TYPE_USERDATA; } else if (AstTypeTable* table = ty->as()) { @@ -38,7 +74,7 @@ static LuauBytecodeEncodedType getType(AstType* ty) for (AstType* ty : un->types) { - LuauBytecodeEncodedType et = getType(ty); + LuauBytecodeEncodedType et = getType(ty, generics, typeAliases, resolveAliases); if (et == LBC_TYPE_NIL) { @@ -69,11 +105,8 @@ static LuauBytecodeEncodedType getType(AstType* ty) return LBC_TYPE_ANY; } -std::string getFunctionType(const AstExprFunction* func) +static std::string getFunctionType(const AstExprFunction* func, const DenseHashMap& typeAliases) { - if (func->vararg || func->generics.size || func->genericPacks.size) - return {}; - bool self = func->self != 0; std::string typeInfo; @@ -88,7 +121,8 @@ std::string getFunctionType(const AstExprFunction* func) bool haveNonAnyParam = false; for (AstLocal* arg : func->args) { - LuauBytecodeEncodedType ty = arg->annotation ? getType(arg->annotation) : LBC_TYPE_ANY; + LuauBytecodeEncodedType ty = + arg->annotation ? getType(arg->annotation, func->generics, typeAliases, /* resolveAliases= */ true) : LBC_TYPE_ANY; if (ty != LBC_TYPE_ANY) haveNonAnyParam = true; @@ -103,4 +137,88 @@ std::string getFunctionType(const AstExprFunction* func) return typeInfo; } -} // namespace Luau \ No newline at end of file +struct TypeMapVisitor : AstVisitor +{ + DenseHashMap& typeMap; + + DenseHashMap typeAliases; + std::vector> typeAliasStack; + + TypeMapVisitor(DenseHashMap& typeMap) + : typeMap(typeMap) + , typeAliases(AstName()) + { + } + + size_t pushTypeAliases(AstStatBlock* block) + { + size_t aliasStackTop = typeAliasStack.size(); + + for (AstStat* stat : block->body) + if (AstStatTypeAlias* alias = stat->as()) + { + AstStatTypeAlias*& prevAlias = typeAliases[alias->name]; + + typeAliasStack.push_back(std::make_pair(alias->name, prevAlias)); + prevAlias = alias; + } + + return aliasStackTop; + } + + void popTypeAliases(size_t aliasStackTop) + { + while (typeAliasStack.size() > aliasStackTop) + { + std::pair& top = typeAliasStack.back(); + + typeAliases[top.first] = top.second; + typeAliasStack.pop_back(); + } + } + + bool visit(AstStatBlock* node) override + { + size_t aliasStackTop = pushTypeAliases(node); + + for (AstStat* stat : node->body) + stat->visit(this); + + popTypeAliases(aliasStackTop); + + return false; + } + + // repeat..until scoping rules are such that condition (along with any possible functions declared in it) has aliases from repeat body in scope + bool visit(AstStatRepeat* node) override + { + size_t aliasStackTop = pushTypeAliases(node->body); + + for (AstStat* stat : node->body->body) + stat->visit(this); + + node->condition->visit(this); + + popTypeAliases(aliasStackTop); + + return false; + } + + bool visit(AstExprFunction* node) override + { + std::string type = getFunctionType(node, typeAliases); + + if (!type.empty()) + typeMap[node] = std::move(type); + + return true; + } +}; + +void buildTypeMap(DenseHashMap& typeMap, AstNode* root) +{ + TypeMapVisitor visitor(typeMap); + root->visit(&visitor); +} + +} // namespace Luau diff --git a/Compiler/src/Types.h b/Compiler/src/Types.h index 1be9155f8..c3dd16209 100644 --- a/Compiler/src/Types.h +++ b/Compiler/src/Types.h @@ -3,7 +3,11 @@ #include "Luau/Ast.h" +#include + namespace Luau { -std::string getFunctionType(const AstExprFunction* func); + +void buildTypeMap(DenseHashMap& typeMap, AstNode* root); + } // namespace Luau diff --git a/Makefile b/Makefile index 852b14f83..17bae9192 100644 --- a/Makefile +++ b/Makefile @@ -176,9 +176,9 @@ coverage: $(TESTS_TARGET) $(COMPILE_CLI_TARGET) mv default.profraw codegen-x64.profraw llvm-profdata merge *.profraw -o default.profdata rm *.profraw - llvm-cov show -format=html -show-instantiations=false -show-line-counts=true -show-region-summary=false -ignore-filename-regex=\(tests\|extern\|CLI\)/.* -output-dir=coverage --instr-profile default.profdata build/coverage/luau-tests - llvm-cov report -ignore-filename-regex=\(tests\|extern\|CLI\)/.* -show-region-summary=false --instr-profile default.profdata build/coverage/luau-tests - llvm-cov export -ignore-filename-regex=\(tests\|extern\|CLI\)/.* -format lcov --instr-profile default.profdata build/coverage/luau-tests >coverage.info + llvm-cov show -format=html -show-instantiations=false -show-line-counts=true -show-region-summary=false -ignore-filename-regex=\(tests\|extern\|CLI\)/.* -output-dir=coverage --instr-profile default.profdata -object build/coverage/luau-tests -object build/coverage/luau-compile + llvm-cov report -ignore-filename-regex=\(tests\|extern\|CLI\)/.* -show-region-summary=false --instr-profile default.profdata -object build/coverage/luau-tests -object build/coverage/luau-compile + llvm-cov export -ignore-filename-regex=\(tests\|extern\|CLI\)/.* -format lcov --instr-profile default.profdata -object build/coverage/luau-tests -object build/coverage/luau-compile >coverage.info format: git ls-files '*.h' '*.cpp' | xargs clang-format-11 -i diff --git a/Sources.cmake b/Sources.cmake index 74709b4bf..ccf2e1df9 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -152,6 +152,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/include/Luau/DataFlowGraph.h Analysis/include/Luau/DcrLogger.h Analysis/include/Luau/Def.h + Analysis/include/Luau/Differ.h Analysis/include/Luau/Documentation.h Analysis/include/Luau/Error.h Analysis/include/Luau/FileResolver.h @@ -209,6 +210,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/src/DataFlowGraph.cpp Analysis/src/DcrLogger.cpp Analysis/src/Def.cpp + Analysis/src/Differ.cpp Analysis/src/EmbeddedBuiltinDefinitions.cpp Analysis/src/Error.cpp Analysis/src/Frontend.cpp @@ -366,6 +368,7 @@ if(TARGET Luau.UnitTest) tests/CostModel.test.cpp tests/DataFlowGraph.test.cpp tests/DenseHash.test.cpp + tests/Differ.test.cpp tests/Error.test.cpp tests/Frontend.test.cpp tests/IrBuilder.test.cpp diff --git a/VM/src/lfunc.cpp b/VM/src/lfunc.cpp index 6330b4c36..a47ad34f3 100644 --- a/VM/src/lfunc.cpp +++ b/VM/src/lfunc.cpp @@ -34,6 +34,7 @@ Proto* luaF_newproto(lua_State* L) f->codeentry = NULL; f->execdata = NULL; f->exectarget = 0; + f->typeinfo = NULL; return f; } @@ -162,6 +163,9 @@ void luaF_freeproto(lua_State* L, Proto* f, lua_Page* page) } #endif + if (f->typeinfo) + luaM_freearray(L, f->typeinfo, f->numparams + 2, uint8_t, f->memcat); + luaM_freegco(L, f, sizeof(Proto), f->memcat, page); } diff --git a/VM/src/lobject.h b/VM/src/lobject.h index a42616332..560969d6d 100644 --- a/VM/src/lobject.h +++ b/VM/src/lobject.h @@ -279,6 +279,8 @@ typedef struct Proto void* execdata; uintptr_t exectarget; + uint8_t* typeinfo; + GCObject* gclist; diff --git a/VM/src/lvmload.cpp b/VM/src/lvmload.cpp index edbe5035d..a26dd0b8f 100644 --- a/VM/src/lvmload.cpp +++ b/VM/src/lvmload.cpp @@ -13,6 +13,8 @@ #include +LUAU_FASTFLAGVARIABLE(LuauLoadCheckGC, false) + // TODO: RAII deallocation doesn't work for longjmp builds if a memory error happens template struct TempBuffer @@ -178,6 +180,10 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size return 1; } + // we will allocate a fair amount of memory so check GC before we do + if (FFlag::LuauLoadCheckGC) + luaC_checkGC(L); + // pause GC for the duration of deserialization - some objects we're creating aren't rooted // TODO: if an allocation error happens mid-load, we do not unpause GC! size_t GCthreshold = L->global->GCthreshold; @@ -188,11 +194,11 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size TString* source = luaS_new(L, chunkname); + uint8_t typesversion = 0; if (version >= 4) { - uint8_t typesversion = read(data, size, offset); - LUAU_ASSERT(typesversion == 1); + typesversion = read(data, size, offset); } // string table @@ -229,7 +235,7 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size uint32_t typesize = readVarInt(data, size, offset); - if (typesize) + if (typesize && typesversion == LBC_TYPE_VERSION) { uint8_t* types = (uint8_t*)data + offset; @@ -237,8 +243,11 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size LUAU_ASSERT(types[0] == LBC_TYPE_FUNCTION); LUAU_ASSERT(types[1] == p->numparams); - offset += typesize; + p->typeinfo = luaM_newarray(L, typesize, uint8_t, p->memcat); + memcpy(p->typeinfo, types, typesize); } + + offset += typesize; } p->sizecode = readVarInt(data, size, offset); diff --git a/fuzz/proto.cpp b/fuzz/proto.cpp index 9366da5e2..c30e93775 100644 --- a/fuzz/proto.cpp +++ b/fuzz/proto.cpp @@ -27,12 +27,16 @@ const bool kFuzzTypeck = true; const bool kFuzzVM = true; const bool kFuzzTranspile = true; const bool kFuzzCodegen = true; +const bool kFuzzCodegenAssembly = true; // Should we generate type annotations? const bool kFuzzTypes = true; +const Luau::CodeGen::AssemblyOptions::Target kFuzzCodegenTarget = Luau::CodeGen::AssemblyOptions::A64; + static_assert(!(kFuzzVM && !kFuzzCompiler), "VM requires the compiler!"); static_assert(!(kFuzzCodegen && !kFuzzVM), "Codegen requires the VM!"); +static_assert(!(kFuzzCodegenAssembly && !kFuzzCompiler), "Codegen requires the compiler!"); std::vector protoprint(const luau::ModuleSet& stat, bool types); @@ -348,6 +352,23 @@ DEFINE_PROTO_FUZZER(const luau::ModuleSet& message) } } + // run codegen on resulting bytecode (in separate state) + if (kFuzzCodegenAssembly && bytecode.size()) + { + static lua_State* globalState = luaL_newstate(); + + if (luau_load(globalState, "=fuzz", bytecode.data(), bytecode.size(), 0) == 0) + { + Luau::CodeGen::AssemblyOptions options; + options.outputBinary = true; + options.target = kFuzzCodegenTarget; + Luau::CodeGen::getAssembly(globalState, -1, options); + } + + lua_pop(globalState, 1); + lua_gc(globalState, LUA_GCCOLLECT, 0); + } + // run resulting bytecode (from last successfully compiler module) if (kFuzzVM && bytecode.size()) { diff --git a/tests/ClassFixture.cpp b/tests/ClassFixture.cpp index 5e28e8d90..784069335 100644 --- a/tests/ClassFixture.cpp +++ b/tests/ClassFixture.cpp @@ -107,7 +107,6 @@ ClassFixture::ClassFixture() globals.globalScope->exportedTypeBindings["CallableClass"] = TypeFun{{}, callableClassType}; auto addIndexableClass = [&arena, &globals](const char* className, TypeId keyType, TypeId returnType) { - ScopedFastFlag LuauTypecheckClassTypeIndexers("LuauTypecheckClassTypeIndexers", true); TypeId indexableClassMetaType = arena.addType(TableType{}); TypeId indexableClassType = arena.addType(ClassType{className, {}, nullopt, indexableClassMetaType, {}, {}, "Test", TableIndexer{keyType, returnType}}); diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index 97cc32635..fde6e90e9 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -49,7 +49,7 @@ static std::string compileFunction0Coverage(const char* source, int level) return bcb.dumpFunction(0); } -static std::string compileFunction0TypeTable(const char* source) +static std::string compileTypeTable(const char* source) { Luau::BytecodeBuilder bcb; bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code); @@ -7080,12 +7080,9 @@ L1: RETURN R3 1 TEST_CASE("EncodedTypeTable") { - ScopedFastFlag sffs[] = { - {"BytecodeVersion4", true}, - {"CompileFunctionType", true}, - }; + ScopedFastFlag sff("LuauCompileFunctionType", true); - CHECK_EQ("\n" + compileFunction0TypeTable(R"( + CHECK_EQ("\n" + compileTypeTable(R"( function myfunc(test: string, num: number) print(test) end @@ -7104,6 +7101,9 @@ end function myfunc5(test: string | number, n: number | boolean) end +function myfunc6(test: (number) -> string) +end + myfunc('test') )"), R"( @@ -7111,9 +7111,10 @@ myfunc('test') 1: function(number?) 2: function(string, number) 3: function(any, number) +5: function(function) )"); - CHECK_EQ("\n" + compileFunction0TypeTable(R"( + CHECK_EQ("\n" + compileTypeTable(R"( local Str = { a = 1 } @@ -7126,7 +7127,95 @@ end Str:test(234) )"), R"( -0: function({ }, number) +0: function(table, number) +)"); +} + +TEST_CASE("HostTypesAreUserdata") +{ + ScopedFastFlag sff("LuauCompileFunctionType", true); + + CHECK_EQ("\n" + compileTypeTable(R"( +function myfunc(test: string, num: number) + print(test) +end + +function myfunc2(test: Instance, num: number) +end + +type Foo = string + +function myfunc3(test: string, n: Foo) +end + +function myfunc4(test: Bar, n: Part) +end +)"), + R"( +0: function(string, number) +1: function(userdata, number) +2: function(string, string) +3: function(any, userdata) +)"); +} + +TEST_CASE("TypeAliasScoping") +{ + ScopedFastFlag sff("LuauCompileFunctionType", true); + + CHECK_EQ("\n" + compileTypeTable(R"( +do + type Part = number +end + +function myfunc1(test: Part, num: number) +end + +do + type Part = number + + function myfunc2(test: Part, num: number) + end +end + +repeat + type Part = number +until (function(test: Part, num: number) end)() + +function myfunc4(test: Instance, num: number) +end + +type Instance = string +)"), + R"( +0: function(userdata, number) +1: function(number, number) +2: function(number, number) +3: function(string, number) +)"); +} + +TEST_CASE("TypeAliasResolve") +{ + ScopedFastFlag sff("LuauCompileFunctionType", true); + + CHECK_EQ("\n" + compileTypeTable(R"( +type Foo1 = number +type Foo2 = { number } +type Foo3 = Part +type Foo4 = Foo1 -- we do not resolve aliases within aliases +type Foo5 = X + +function myfunc(f1: Foo1, f2: Foo2, f3: Foo3, f4: Foo4, f5: Foo5) +end + +function myfuncerr(f1: Foo1, f2: Foo5) +end + +)"), + R"( +0: function(number, table, userdata, any, any) +1: function(number, any) )"); } diff --git a/tests/Differ.test.cpp b/tests/Differ.test.cpp new file mode 100644 index 000000000..6b570f398 --- /dev/null +++ b/tests/Differ.test.cpp @@ -0,0 +1,316 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Differ.h" +#include "Luau/Error.h" +#include "Luau/Frontend.h" + +#include "Fixture.h" + +#include "doctest.h" +#include + +using namespace Luau; + +TEST_SUITE_BEGIN("Differ"); + +TEST_CASE_FIXTURE(Fixture, "equal_numbers") +{ + CheckResult result = check(R"( + local foo = 5 + local almostFoo = 78 + almostFoo = foo + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + TypeId foo = requireType("foo"); + TypeId almostFoo = requireType("almostFoo"); + try + { + DifferResult diffRes = diff(foo, almostFoo); + CHECK(!diffRes.diffError.has_value()); + } + catch (InternalCompilerError e) + { + INFO(("InternalCompilerError: " + e.message)); + CHECK(false); + } +} + +TEST_CASE_FIXTURE(Fixture, "equal_strings") +{ + CheckResult result = check(R"( + local foo = "hello" + local almostFoo = "world" + almostFoo = foo + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + TypeId foo = requireType("foo"); + TypeId almostFoo = requireType("almostFoo"); + try + { + DifferResult diffRes = diff(foo, almostFoo); + CHECK(!diffRes.diffError.has_value()); + } + catch (InternalCompilerError e) + { + INFO(("InternalCompilerError: " + e.message)); + CHECK(false); + } +} + +TEST_CASE_FIXTURE(Fixture, "equal_tables") +{ + CheckResult result = check(R"( + local foo = { x = 1, y = "where" } + local almostFoo = { x = 5, y = "when" } + almostFoo = foo + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + TypeId foo = requireType("foo"); + TypeId almostFoo = requireType("almostFoo"); + try + { + DifferResult diffRes = diff(foo, almostFoo); + CHECK(!diffRes.diffError.has_value()); + } + catch (InternalCompilerError e) + { + INFO(("InternalCompilerError: " + e.message)); + CHECK(false); + } +} + +TEST_CASE_FIXTURE(Fixture, "a_table_missing_property") +{ + CheckResult result = check(R"( + local foo = { x = 1, y = 2 } + local almostFoo = { x = 1, z = 3 } + almostFoo = foo + )"); + LUAU_REQUIRE_ERRORS(result); + + TypeId foo = requireType("foo"); + TypeId almostFoo = requireType("almostFoo"); + std::string diffMessage; + try + { + diffMessage = diff(foo, almostFoo).diffError->toString(); + } + catch (InternalCompilerError e) + { + INFO(("InternalCompilerError: " + e.message)); + CHECK(false); + } + CHECK_EQ("DiffError: these two types are not equal because the left type at foo.y has type number, while the right type at almostFoo is missing " + "the property y", + diffMessage); +} + +TEST_CASE_FIXTURE(Fixture, "left_table_missing_property") +{ + CheckResult result = check(R"( + local foo = { x = 1 } + local almostFoo = { x = 1, z = 3 } + almostFoo = foo + )"); + LUAU_REQUIRE_ERRORS(result); + + TypeId foo = requireType("foo"); + TypeId almostFoo = requireType("almostFoo"); + std::string diffMessage; + try + { + diffMessage = diff(foo, almostFoo).diffError->toString(); + } + catch (InternalCompilerError e) + { + INFO(("InternalCompilerError: " + e.message)); + CHECK(false); + } + CHECK_EQ("DiffError: these two types are not equal because the left type at foo is missing the property z, while the right type at almostFoo.z " + "has type number", + diffMessage); +} + +TEST_CASE_FIXTURE(Fixture, "a_table_wrong_type") +{ + CheckResult result = check(R"( + local foo = { x = 1, y = 2 } + local almostFoo = { x = 1, y = "two" } + almostFoo = foo + )"); + LUAU_REQUIRE_ERRORS(result); + + TypeId foo = requireType("foo"); + TypeId almostFoo = requireType("almostFoo"); + std::string diffMessage; + try + { + diffMessage = diff(foo, almostFoo).diffError->toString(); + } + catch (InternalCompilerError e) + { + INFO(("InternalCompilerError: " + e.message)); + CHECK(false); + } + CHECK_EQ("DiffError: these two types are not equal because the left type at foo.y has type number, while the right type at almostFoo.y has type " + "string", + diffMessage); +} + +TEST_CASE_FIXTURE(Fixture, "a_table_wrong_type") +{ + CheckResult result = check(R"( + local foo: string + local almostFoo: number + almostFoo = foo + )"); + LUAU_REQUIRE_ERRORS(result); + + TypeId foo = requireType("foo"); + TypeId almostFoo = requireType("almostFoo"); + std::string diffMessage; + try + { + diffMessage = diff(foo, almostFoo).diffError->toString(); + } + catch (InternalCompilerError e) + { + INFO(("InternalCompilerError: " + e.message)); + CHECK(false); + } + CHECK_EQ("DiffError: these two types are not equal because the left type at has type string, while the right type at " + " has type number", + diffMessage); +} + +TEST_CASE_FIXTURE(Fixture, "a_nested_table_wrong_type") +{ + CheckResult result = check(R"( + local foo = { x = 1, inner = { table = { has = { wrong = { value = 5 } } } } } + local almostFoo = { x = 1, inner = { table = { has = { wrong = { value = "five" } } } } } + almostFoo = foo + )"); + LUAU_REQUIRE_ERRORS(result); + + TypeId foo = requireType("foo"); + TypeId almostFoo = requireType("almostFoo"); + std::string diffMessage; + try + { + diffMessage = diff(foo, almostFoo).diffError->toString(); + } + catch (InternalCompilerError e) + { + INFO(("InternalCompilerError: " + e.message)); + CHECK(false); + } + CHECK_EQ("DiffError: these two types are not equal because the left type at foo.inner.table.has.wrong.value has type number, while the right " + "type at almostFoo.inner.table.has.wrong.value has type string", + diffMessage); +} + +TEST_CASE_FIXTURE(Fixture, "a_nested_table_wrong_match") +{ + CheckResult result = check(R"( + local foo = { x = 1, inner = { table = { has = { wrong = { variant = { because = { it = { goes = { on = "five" } } } } } } } } } + local almostFoo = { x = 1, inner = { table = { has = { wrong = { variant = "five" } } } } } + almostFoo = foo + )"); + LUAU_REQUIRE_ERRORS(result); + + TypeId foo = requireType("foo"); + TypeId almostFoo = requireType("almostFoo"); + std::string diffMessage; + try + { + diffMessage = diff(foo, almostFoo).diffError->toString(); + } + catch (InternalCompilerError e) + { + INFO(("InternalCompilerError: " + e.message)); + CHECK(false); + } + CHECK_EQ("DiffError: these two types are not equal because the left type at foo.inner.table.has.wrong.variant has type { because: { it: { goes: " + "{ on: string } } } }, while the right type at almostFoo.inner.table.has.wrong.variant has type string", + diffMessage); +} + +TEST_CASE_FIXTURE(Fixture, "singleton") +{ + CheckResult result = check(R"( + local foo: "hello" = "hello" + local almostFoo: true = true + almostFoo = foo + )"); + LUAU_REQUIRE_ERRORS(result); + + TypeId foo = requireType("foo"); + TypeId almostFoo = requireType("almostFoo"); + std::string diffMessage; + try + { + diffMessage = diff(foo, almostFoo).diffError->toString(); + } + catch (InternalCompilerError e) + { + INFO(("InternalCompilerError: " + e.message)); + CHECK(false); + } + CHECK_EQ( + R"(DiffError: these two types are not equal because the left type at has type "hello", while the right type at has type true)", + diffMessage); +} + +TEST_CASE_FIXTURE(Fixture, "equal_singleton") +{ + CheckResult result = check(R"( + local foo: "hello" = "hello" + local almostFoo: "hello" + almostFoo = foo + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + TypeId foo = requireType("foo"); + TypeId almostFoo = requireType("almostFoo"); + try + { + DifferResult diffRes = diff(foo, almostFoo); + INFO(diffRes.diffError->toString()); + CHECK(!diffRes.diffError.has_value()); + } + catch (InternalCompilerError e) + { + INFO(("InternalCompilerError: " + e.message)); + CHECK(false); + } +} + +TEST_CASE_FIXTURE(Fixture, "singleton_string") +{ + CheckResult result = check(R"( + local foo: "hello" = "hello" + local almostFoo: "world" = "world" + almostFoo = foo + )"); + LUAU_REQUIRE_ERRORS(result); + + TypeId foo = requireType("foo"); + TypeId almostFoo = requireType("almostFoo"); + std::string diffMessage; + try + { + diffMessage = diff(foo, almostFoo).diffError->toString(); + } + catch (InternalCompilerError e) + { + INFO(("InternalCompilerError: " + e.message)); + CHECK(false); + } + CHECK_EQ( + R"(DiffError: these two types are not equal because the left type at has type "hello", while the right type at has type "world")", + diffMessage); +} + +TEST_SUITE_END(); diff --git a/tests/Frontend.test.cpp b/tests/Frontend.test.cpp index 0b9c872c2..8f6834a17 100644 --- a/tests/Frontend.test.cpp +++ b/tests/Frontend.test.cpp @@ -1146,4 +1146,35 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "module_scope_check") CHECK_EQ(toString(ty), "number"); } +TEST_CASE_FIXTURE(FrontendFixture, "parse_only") +{ + fileResolver.source["game/Gui/Modules/A"] = R"( + local a: number = 'oh no a type error' + return {a=a} + )"; + + fileResolver.source["game/Gui/Modules/B"] = R"( + local Modules = script.Parent + local A = require(Modules.A) + local b: number = 2 + )"; + + frontend.parse("game/Gui/Modules/B"); + + REQUIRE(frontend.sourceNodes.count("game/Gui/Modules/A")); + REQUIRE(frontend.sourceNodes.count("game/Gui/Modules/B")); + + auto node = frontend.sourceNodes["game/Gui/Modules/B"]; + CHECK_EQ(node->requireSet.count("game/Gui/Modules/A"), 1); + REQUIRE_EQ(node->requireLocations.size(), 1); + CHECK_EQ(node->requireLocations[0].second, Luau::Location(Position(2, 18), Position(2, 36))); + + // Early parse doesn't cause typechecking to be skipped + CheckResult result = frontend.check("game/Gui/Modules/B"); + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ("game/Gui/Modules/A", result.errors[0].moduleName); + CHECK_EQ("Type 'string' could not be converted into 'number'", toString(result.errors[0])); +} + TEST_SUITE_END(); diff --git a/tests/IrBuilder.test.cpp b/tests/IrBuilder.test.cpp index f1399a590..571033525 100644 --- a/tests/IrBuilder.test.cpp +++ b/tests/IrBuilder.test.cpp @@ -526,7 +526,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "Bit32") )"); } -TEST_CASE_FIXTURE(IrBuilderFixture, "Bit32Blocked") +TEST_CASE_FIXTURE(IrBuilderFixture, "Bit32RangeReduction") { IrOp block = build.block(IrBlockKind::Internal); @@ -534,10 +534,10 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "Bit32Blocked") build.inst(IrCmd::STORE_INT, build.vmReg(10), build.inst(IrCmd::BITLSHIFT_UINT, build.constInt(0xf), build.constInt(-10))); build.inst(IrCmd::STORE_INT, build.vmReg(10), build.inst(IrCmd::BITLSHIFT_UINT, build.constInt(0xf), build.constInt(140))); - build.inst(IrCmd::STORE_INT, build.vmReg(10), build.inst(IrCmd::BITRSHIFT_UINT, build.constInt(0xf), build.constInt(-10))); - build.inst(IrCmd::STORE_INT, build.vmReg(10), build.inst(IrCmd::BITRSHIFT_UINT, build.constInt(0xf), build.constInt(140))); - build.inst(IrCmd::STORE_INT, build.vmReg(10), build.inst(IrCmd::BITARSHIFT_UINT, build.constInt(0xf), build.constInt(-10))); - build.inst(IrCmd::STORE_INT, build.vmReg(10), build.inst(IrCmd::BITARSHIFT_UINT, build.constInt(0xf), build.constInt(140))); + build.inst(IrCmd::STORE_INT, build.vmReg(10), build.inst(IrCmd::BITRSHIFT_UINT, build.constInt(0xffffff), build.constInt(-10))); + build.inst(IrCmd::STORE_INT, build.vmReg(10), build.inst(IrCmd::BITRSHIFT_UINT, build.constInt(0xffffff), build.constInt(140))); + build.inst(IrCmd::STORE_INT, build.vmReg(10), build.inst(IrCmd::BITARSHIFT_UINT, build.constInt(0xffffff), build.constInt(-10))); + build.inst(IrCmd::STORE_INT, build.vmReg(10), build.inst(IrCmd::BITARSHIFT_UINT, build.constInt(0xffffff), build.constInt(140))); build.inst(IrCmd::RETURN, build.constUint(0)); @@ -546,18 +546,12 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "Bit32Blocked") CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: - %0 = BITLSHIFT_UINT 15i, -10i - STORE_INT R10, %0 - %2 = BITLSHIFT_UINT 15i, 140i - STORE_INT R10, %2 - %4 = BITRSHIFT_UINT 15i, -10i - STORE_INT R10, %4 - %6 = BITRSHIFT_UINT 15i, 140i - STORE_INT R10, %6 - %8 = BITARSHIFT_UINT 15i, -10i - STORE_INT R10, %8 - %10 = BITARSHIFT_UINT 15i, 140i - STORE_INT R10, %10 + STORE_INT R10, 62914560i + STORE_INT R10, 61440i + STORE_INT R10, 3i + STORE_INT R10, 4095i + STORE_INT R10, 3i + STORE_INT R10, 4095i RETURN 0u )"); @@ -1864,6 +1858,34 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "VaridicRegisterRangeInvalidation") )"); } +TEST_CASE_FIXTURE(IrBuilderFixture, "LoadPropagatesOnlyRightType") +{ + IrOp block = build.block(IrBlockKind::Internal); + + build.beginBlock(block); + + build.inst(IrCmd::STORE_INT, build.vmReg(0), build.constInt(2)); + IrOp value1 = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(0)); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(1), value1); + IrOp value2 = build.inst(IrCmd::LOAD_INT, build.vmReg(1)); + build.inst(IrCmd::STORE_INT, build.vmReg(2), value2); + build.inst(IrCmd::RETURN, build.constUint(0)); + + updateUseCounts(build.function); + constPropInBlockChains(build, true); + + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( +bb_0: + STORE_INT R0, 2i + %1 = LOAD_DOUBLE R0 + STORE_DOUBLE R1, %1 + %3 = LOAD_INT R1 + STORE_INT R2, %3 + RETURN 0u + +)"); +} + TEST_SUITE_END(); TEST_SUITE_BEGIN("Analysis"); diff --git a/tests/IrCallWrapperX64.test.cpp b/tests/IrCallWrapperX64.test.cpp index c8918dbde..ec04e5311 100644 --- a/tests/IrCallWrapperX64.test.cpp +++ b/tests/IrCallWrapperX64.test.cpp @@ -10,8 +10,8 @@ using namespace Luau::CodeGen::X64; class IrCallWrapperX64Fixture { public: - IrCallWrapperX64Fixture() - : build(/* logText */ true, ABIX64::Windows) + IrCallWrapperX64Fixture(ABIX64 abi = ABIX64::Windows) + : build(/* logText */ true, abi) , regs(build, function) , callWrap(regs, build, ~0u) { @@ -42,6 +42,15 @@ class IrCallWrapperX64Fixture static constexpr RegisterX64 rArg4d = r9d; }; +class IrCallWrapperX64FixtureSystemV : public IrCallWrapperX64Fixture +{ +public: + IrCallWrapperX64FixtureSystemV() + : IrCallWrapperX64Fixture(ABIX64::SystemV) + { + } +}; + TEST_SUITE_BEGIN("IrCallWrapperX64"); TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "SimpleRegs") @@ -519,4 +528,35 @@ TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "ImmediateConflictWithFunction") )"); } +TEST_CASE_FIXTURE(IrCallWrapperX64FixtureSystemV, "SuggestedConflictWithReserved") +{ + ScopedRegX64 tmp{regs, regs.takeReg(r9, kInvalidInstIdx)}; + + IrCallWrapperX64 callWrap(regs, build); + callWrap.addArgument(SizeX64::qword, r12); + callWrap.addArgument(SizeX64::qword, r13); + callWrap.addArgument(SizeX64::qword, r14); + callWrap.addArgument(SizeX64::dword, 2); + callWrap.addArgument(SizeX64::qword, 1); + + RegisterX64 reg = callWrap.suggestNextArgumentRegister(SizeX64::dword); + build.mov(reg, 10); + callWrap.addArgument(SizeX64::dword, reg); + + callWrap.call(tmp.release()); + + checkMatch(R"( + mov eax,Ah + mov rdi,r12 + mov rsi,r13 + mov rdx,r14 + mov rcx,r9 + mov r9d,eax + mov rax,rcx + mov ecx,2 + mov r8,1 + call rax +)"); +} + TEST_SUITE_END(); diff --git a/tests/Normalize.test.cpp b/tests/Normalize.test.cpp index 8fe86655c..afcc08f29 100644 --- a/tests/Normalize.test.cpp +++ b/tests/Normalize.test.cpp @@ -735,6 +735,37 @@ TEST_CASE_FIXTURE(NormalizeFixture, "intersection_of_metatables_where_the_metata CHECK("{ @metatable *error-type*, {| |} }" == toString(normal("Mt<{}, any> & Mt<{}, err>"))); } +TEST_CASE_FIXTURE(NormalizeFixture, "recurring_intersection") +{ + CheckResult result = check(R"( + type A = any? + type B = A & A + )"); + + std::optional t = lookupType("B"); + REQUIRE(t); + + const NormalizedType* nt = normalizer.normalize(*t); + REQUIRE(nt); + + CHECK("any" == toString(normalizer.typeFromNormal(*nt))); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "cyclic_union") +{ + ScopedFastFlag sff{"LuauNormalizeCyclicUnions", true}; + + // T where T = any & (number | T) + TypeId t = arena.addType(BlockedType{}); + TypeId u = arena.addType(UnionType{{builtinTypes->numberType, t}}); + asMutable(t)->ty.emplace(IntersectionType{{builtinTypes->anyType, u}}); + + const NormalizedType* nt = normalizer.normalize(t); + REQUIRE(nt); + + CHECK("number" == toString(normalizer.typeFromNormal(*nt))); +} + TEST_CASE_FIXTURE(NormalizeFixture, "crazy_metatable") { CHECK("never" == toString(normal("Mt<{}, number> & Mt<{}, string>"))); diff --git a/tests/TypeInfer.classes.test.cpp b/tests/TypeInfer.classes.test.cpp index 9e3a63f79..07471d444 100644 --- a/tests/TypeInfer.classes.test.cpp +++ b/tests/TypeInfer.classes.test.cpp @@ -479,7 +479,6 @@ TEST_CASE_FIXTURE(ClassFixture, "callable_classes") TEST_CASE_FIXTURE(ClassFixture, "indexable_classes") { // Test reading from an index - ScopedFastFlag LuauTypecheckClassTypeIndexers("LuauTypecheckClassTypeIndexers", true); { CheckResult result = check(R"( local x : IndexableClass diff --git a/tests/TypeInfer.definitions.test.cpp b/tests/TypeInfer.definitions.test.cpp index 0ca9bd735..615b81d6f 100644 --- a/tests/TypeInfer.definitions.test.cpp +++ b/tests/TypeInfer.definitions.test.cpp @@ -398,7 +398,6 @@ TEST_CASE_FIXTURE(Fixture, "class_definition_string_props") TEST_CASE_FIXTURE(Fixture, "class_definition_indexer") { ScopedFastFlag LuauParseDeclareClassIndexer("LuauParseDeclareClassIndexer", true); - ScopedFastFlag LuauTypecheckClassTypeIndexers("LuauTypecheckClassTypeIndexers", true); loadDefinition(R"( declare class Foo diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index f0630ca9a..d456f3783 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -2096,4 +2096,46 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "attempt_to_call_an_intersection_of_tables_wi LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(BuiltinsFixture, "num_is_solved_before_num_or_str") +{ + CheckResult result = check(R"( + function num() + return 5 + end + + local function num_or_str() + if math.random() > 0.5 then + return num() + else + return "some string" + end + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type 'string' could not be converted into 'number'", toString(result.errors[0])); + CHECK_EQ("() -> number", toString(requireType("num_or_str"))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "num_is_solved_after_num_or_str") +{ + CheckResult result = check(R"( + local function num_or_str() + if math.random() > 0.5 then + return num() + else + return "some string" + end + end + + function num() + return 5 + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type 'string' could not be converted into 'number'", toString(result.errors[0])); + CHECK_EQ("() -> number", toString(requireType("num_or_str"))); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index 3b0654a04..0c8887404 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -917,6 +917,52 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "type_comparison_ifelse_expression") CHECK_EQ("any", toString(requireTypeAtPosition({6, 66}))); } + +TEST_CASE_FIXTURE(BuiltinsFixture, "is_truthy_constraint_while_expression") +{ + CheckResult result = check(R"( + function f(v:string?) + while v do + local foo = v + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("string", toString(requireTypeAtPosition({3, 28}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "invert_is_truthy_constraint_while_expression") +{ + CheckResult result = check(R"( + function f(v:string?) + while not v do + local foo = v + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("nil", toString(requireTypeAtPosition({3, 28}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "refine_the_correct_types_opposite_of_while_a_is_not_number_or_string") +{ + CheckResult result = check(R"( + local function f(a: string | number | boolean) + while type(a) ~= "number" and type(a) ~= "string" do + local foo = a + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("boolean", toString(requireTypeAtPosition({3, 28}))); +} + TEST_CASE_FIXTURE(BuiltinsFixture, "correctly_lookup_a_shadowed_local_that_which_was_previously_refined") { CheckResult result = check(R"( @@ -1580,8 +1626,6 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "refine_a_param_that_got_resolved_duri TEST_CASE_FIXTURE(Fixture, "refine_a_property_of_some_global") { - ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true}; - CheckResult result = check(R"( foo = { bar = 5 :: number? } @@ -1590,9 +1634,12 @@ TEST_CASE_FIXTURE(Fixture, "refine_a_property_of_some_global") end )"); - LUAU_REQUIRE_ERROR_COUNT(3, result); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + LUAU_REQUIRE_ERROR_COUNT(3, result); - CHECK_EQ("~(false?)", toString(requireTypeAtPosition({4, 30}))); + CHECK_EQ("~(false?)", toString(requireTypeAtPosition({4, 30}))); + } } TEST_CASE_FIXTURE(BuiltinsFixture, "dataflow_analysis_can_tell_refinements_when_its_appropriate_to_refine_into_nil_or_never") @@ -1757,4 +1804,20 @@ TEST_CASE_FIXTURE(Fixture, "refinements_should_not_affect_assignment") LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(BuiltinsFixture, "refinements_should_preserve_error_suppression") +{ + CheckResult result = check(R"( + local a: any = {} + local b + if typeof(a) == "table" then + b = a.field + end + )"); + + if (FFlag::DebugLuauDeferredConstraintResolution) + LUAU_REQUIRE_NO_ERRORS(result); + else + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.singletons.test.cpp b/tests/TypeInfer.singletons.test.cpp index f028e8e0d..c61ff16e8 100644 --- a/tests/TypeInfer.singletons.test.cpp +++ b/tests/TypeInfer.singletons.test.cpp @@ -9,6 +9,20 @@ using namespace Luau; TEST_SUITE_BEGIN("TypeSingletons"); +TEST_CASE_FIXTURE(Fixture, "function_args_infer_singletons") +{ + CheckResult result = check(R"( +--!strict +type Phase = "A" | "B" | "C" +local function f(e : Phase) : number + return 0 +end +local e = f("B") +)"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_CASE_FIXTURE(Fixture, "bool_singletons") { CheckResult result = check(R"( diff --git a/tests/conformance/native.lua b/tests/conformance/native.lua index 85cc06bf0..bc70df3e7 100644 --- a/tests/conformance/native.lua +++ b/tests/conformance/native.lua @@ -61,4 +61,35 @@ end assert(pcall(fuzzfail5) == false) +local function fuzzfail6(_) + return bit32.extract(_,671088640,_) +end + +assert(pcall(fuzzfail6, 1) == false) + +local function fuzzfail7(_) + return bit32.extract(_,_,671088640) +end + +assert(pcall(fuzzfail7, 1) == false) + +local function fuzzfail8(...) + local _ = _,_ + _.n0,_,_,_,_,_,_,_,_._,_,_,_[...],_,_,_ = nil + _,n0,_,_,_,_,_,_,_,_,l0,_,_,_,_ = nil + function _() + end + _._,_,_,_,_,_,_,_,_,_,_[...],_,n0[l0],_ = nil + _[...],_,_,_,_,_,_,_,_()[_],_,_,_,_,_ = _(),... +end + +assert(pcall(fuzzfail8) == false) + +local function fuzzfail9() + local _ = bit32.bor + local x = _(_(_,_),_(_,_),_(-16834560,_),_(_(- _,-2130706432)),- _),_(_(_,_),_(-16834560,-2130706432)) +end + +assert(pcall(fuzzfail9) == false) + return('OK') diff --git a/tools/faillist.txt b/tools/faillist.txt index 1233837fe..31c42eb87 100644 --- a/tools/faillist.txt +++ b/tools/faillist.txt @@ -54,6 +54,8 @@ ProvisionalTests.table_insert_with_a_singleton_argument ProvisionalTests.typeguard_inference_incomplete RefinementTest.discriminate_from_truthiness_of_x RefinementTest.not_t_or_some_prop_of_t +RefinementTest.refine_a_property_of_some_global +RefinementTest.refinements_should_preserve_error_suppression RefinementTest.truthy_constraint_on_properties RefinementTest.type_narrow_to_vector RefinementTest.typeguard_cast_free_table_to_vector @@ -96,7 +98,6 @@ TableTests.shared_selfs TableTests.shared_selfs_from_free_param TableTests.shared_selfs_through_metatables TableTests.table_call_metamethod_basic -TableTests.table_call_metamethod_generic TableTests.table_simple_call TableTests.table_subtyping_with_missing_props_dont_report_multiple_errors TableTests.used_colon_instead_of_dot @@ -131,7 +132,6 @@ TypeInfer.tc_after_error_recovery_no_replacement_name_in_error TypeInfer.type_infer_recursion_limit_no_ice TypeInfer.type_infer_recursion_limit_normalizer TypeInferAnyError.for_in_loop_iterator_is_any2 -TypeInferClasses.callable_classes TypeInferClasses.class_type_mismatch_with_name_conflict TypeInferClasses.index_instance_property TypeInferFunctions.cannot_hoist_interior_defns_into_signature From dc2a1cc12caca2143d958ec9a36a299982a83043 Mon Sep 17 00:00:00 2001 From: Andy Friesen Date: Fri, 7 Jul 2023 12:41:34 -0700 Subject: [PATCH 63/66] GCC fix. --- tests/Differ.test.cpp | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/tests/Differ.test.cpp b/tests/Differ.test.cpp index 6b570f398..5e8e6b1df 100644 --- a/tests/Differ.test.cpp +++ b/tests/Differ.test.cpp @@ -28,7 +28,7 @@ TEST_CASE_FIXTURE(Fixture, "equal_numbers") DifferResult diffRes = diff(foo, almostFoo); CHECK(!diffRes.diffError.has_value()); } - catch (InternalCompilerError e) + catch (const InternalCompilerError& e) { INFO(("InternalCompilerError: " + e.message)); CHECK(false); @@ -51,7 +51,7 @@ TEST_CASE_FIXTURE(Fixture, "equal_strings") DifferResult diffRes = diff(foo, almostFoo); CHECK(!diffRes.diffError.has_value()); } - catch (InternalCompilerError e) + catch (const InternalCompilerError& e) { INFO(("InternalCompilerError: " + e.message)); CHECK(false); @@ -74,7 +74,7 @@ TEST_CASE_FIXTURE(Fixture, "equal_tables") DifferResult diffRes = diff(foo, almostFoo); CHECK(!diffRes.diffError.has_value()); } - catch (InternalCompilerError e) + catch (const InternalCompilerError& e) { INFO(("InternalCompilerError: " + e.message)); CHECK(false); @@ -97,7 +97,7 @@ TEST_CASE_FIXTURE(Fixture, "a_table_missing_property") { diffMessage = diff(foo, almostFoo).diffError->toString(); } - catch (InternalCompilerError e) + catch (const InternalCompilerError& e) { INFO(("InternalCompilerError: " + e.message)); CHECK(false); @@ -123,7 +123,7 @@ TEST_CASE_FIXTURE(Fixture, "left_table_missing_property") { diffMessage = diff(foo, almostFoo).diffError->toString(); } - catch (InternalCompilerError e) + catch (const InternalCompilerError& e) { INFO(("InternalCompilerError: " + e.message)); CHECK(false); @@ -149,7 +149,7 @@ TEST_CASE_FIXTURE(Fixture, "a_table_wrong_type") { diffMessage = diff(foo, almostFoo).diffError->toString(); } - catch (InternalCompilerError e) + catch (const InternalCompilerError& e) { INFO(("InternalCompilerError: " + e.message)); CHECK(false); @@ -175,7 +175,7 @@ TEST_CASE_FIXTURE(Fixture, "a_table_wrong_type") { diffMessage = diff(foo, almostFoo).diffError->toString(); } - catch (InternalCompilerError e) + catch (const InternalCompilerError& e) { INFO(("InternalCompilerError: " + e.message)); CHECK(false); @@ -201,7 +201,7 @@ TEST_CASE_FIXTURE(Fixture, "a_nested_table_wrong_type") { diffMessage = diff(foo, almostFoo).diffError->toString(); } - catch (InternalCompilerError e) + catch (const InternalCompilerError& e) { INFO(("InternalCompilerError: " + e.message)); CHECK(false); @@ -227,7 +227,7 @@ TEST_CASE_FIXTURE(Fixture, "a_nested_table_wrong_match") { diffMessage = diff(foo, almostFoo).diffError->toString(); } - catch (InternalCompilerError e) + catch (const InternalCompilerError& e) { INFO(("InternalCompilerError: " + e.message)); CHECK(false); @@ -253,7 +253,7 @@ TEST_CASE_FIXTURE(Fixture, "singleton") { diffMessage = diff(foo, almostFoo).diffError->toString(); } - catch (InternalCompilerError e) + catch (const InternalCompilerError& e) { INFO(("InternalCompilerError: " + e.message)); CHECK(false); @@ -280,7 +280,7 @@ TEST_CASE_FIXTURE(Fixture, "equal_singleton") INFO(diffRes.diffError->toString()); CHECK(!diffRes.diffError.has_value()); } - catch (InternalCompilerError e) + catch (const InternalCompilerError& e) { INFO(("InternalCompilerError: " + e.message)); CHECK(false); @@ -303,7 +303,7 @@ TEST_CASE_FIXTURE(Fixture, "singleton_string") { diffMessage = diff(foo, almostFoo).diffError->toString(); } - catch (InternalCompilerError e) + catch (const InternalCompilerError& e) { INFO(("InternalCompilerError: " + e.message)); CHECK(false); From b4030755731510b3770b58f6d00ecb385b20b3a3 Mon Sep 17 00:00:00 2001 From: Vyacheslav Egorov Date: Fri, 14 Jul 2023 08:57:16 -0700 Subject: [PATCH 64/66] Sync to upstream/release/584 --- Analysis/include/Luau/Cancellation.h | 24 + Analysis/include/Luau/Differ.h | 24 +- Analysis/include/Luau/Error.h | 14 +- Analysis/include/Luau/Frontend.h | 4 + Analysis/include/Luau/InsertionOrderedMap.h | 6 +- Analysis/include/Luau/Module.h | 1 + Analysis/include/Luau/ToString.h | 2 +- Analysis/include/Luau/TypeChecker2.h | 3 +- Analysis/include/Luau/TypeInfer.h | 13 + Analysis/include/Luau/Unifier.h | 3 +- Analysis/src/Clone.cpp | 1 - Analysis/src/Differ.cpp | 209 +++++- Analysis/src/Error.cpp | 8 +- Analysis/src/Frontend.cpp | 50 +- Analysis/src/Module.cpp | 26 +- Analysis/src/Substitution.cpp | 90 +-- Analysis/src/ToString.cpp | 14 +- Analysis/src/TypeChecker2.cpp | 16 +- Analysis/src/TypeInfer.cpp | 38 +- Analysis/src/Unifier.cpp | 1 - Ast/include/Luau/Location.h | 5 - Ast/src/Location.cpp | 8 - Ast/src/TimeTrace.cpp | 10 +- CLI/Reduce.cpp | 3 +- CLI/Repl.cpp | 4 + CodeGen/include/Luau/IrBuilder.h | 4 + CodeGen/include/Luau/IrData.h | 9 +- CodeGen/src/CodeBlockUnwind.cpp | 43 +- CodeGen/src/CodeGen.cpp | 12 +- CodeGen/src/CodeGenAssembly.cpp | 3 +- CodeGen/src/CodeGenLower.h | 18 +- CodeGen/src/EmitCommon.h | 4 +- CodeGen/src/EmitCommonX64.cpp | 1 - CodeGen/src/IrBuilder.cpp | 72 +- CodeGen/src/IrDump.cpp | 3 + CodeGen/src/IrLoweringA64.cpp | 101 ++- CodeGen/src/IrLoweringA64.h | 5 + CodeGen/src/IrLoweringX64.cpp | 61 +- CodeGen/src/IrLoweringX64.h | 2 + CodeGen/src/IrTranslateBuiltins.cpp | 232 +++---- CodeGen/src/IrTranslateBuiltins.h | 3 +- CodeGen/src/IrTranslation.cpp | 28 +- CodeGen/src/IrTranslation.h | 2 +- CodeGen/src/OptimizeConstProp.cpp | 11 +- Common/include/Luau/Bytecode.h | 2 +- Compiler/include/Luau/BytecodeBuilder.h | 4 +- Compiler/include/Luau/Compiler.h | 3 + Compiler/include/luacode.h | 3 + Compiler/src/BytecodeBuilder.cpp | 17 +- Compiler/src/Compiler.cpp | 2 +- Compiler/src/Types.cpp | 35 +- Compiler/src/Types.h | 2 +- Sources.cmake | 1 + VM/src/ldebug.cpp | 35 +- VM/src/lfunc.cpp | 6 +- VM/src/lgc.h | 2 + VM/src/lgcdebug.cpp | 226 +++++++ VM/src/lobject.h | 23 +- VM/src/lstate.cpp | 2 - VM/src/lstate.h | 1 - VM/src/lvmload.cpp | 3 +- fuzz/proto.cpp | 15 +- tests/Compiler.test.cpp | 30 +- tests/Conformance.test.cpp | 38 ++ tests/Differ.test.cpp | 710 +++++++++++++++++++- tests/Module.test.cpp | 28 +- tests/Parser.test.cpp | 12 +- tests/TypeFamily.test.cpp | 3 +- tests/TypeInfer.functions.test.cpp | 2 - tests/TypeInfer.intersectionTypes.test.cpp | 30 + tests/TypeInfer.operators.test.cpp | 4 +- tests/TypeInfer.unionTypes.test.cpp | 15 + tests/conformance/native.lua | 9 + 73 files changed, 1915 insertions(+), 539 deletions(-) create mode 100644 Analysis/include/Luau/Cancellation.h diff --git a/Analysis/include/Luau/Cancellation.h b/Analysis/include/Luau/Cancellation.h new file mode 100644 index 000000000..441318631 --- /dev/null +++ b/Analysis/include/Luau/Cancellation.h @@ -0,0 +1,24 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include + +namespace Luau +{ + +struct FrontendCancellationToken +{ + void cancel() + { + cancelled.store(true); + } + + bool requested() + { + return cancelled.load(); + } + + std::atomic cancelled; +}; + +} // namespace Luau diff --git a/Analysis/include/Luau/Differ.h b/Analysis/include/Luau/Differ.h index ad276b4f9..da8b64685 100644 --- a/Analysis/include/Luau/Differ.h +++ b/Analysis/include/Luau/Differ.h @@ -21,8 +21,8 @@ struct DiffPathNode Kind kind; // non-null when TableProperty std::optional tableProperty; - // non-null when FunctionArgument, FunctionReturn, Union, or Intersection (i.e. anonymous fields) - std::optional index; + // non-null when FunctionArgument (unless variadic arg), FunctionReturn (unless variadic arg), Union, or Intersection (i.e. anonymous fields) + std::optional index; /** * Do not use for leaf nodes @@ -32,7 +32,7 @@ struct DiffPathNode { } - DiffPathNode(Kind kind, std::optional tableProperty, std::optional index) + DiffPathNode(Kind kind, std::optional tableProperty, std::optional index) : kind(kind) , tableProperty(tableProperty) , index(index) @@ -42,19 +42,35 @@ struct DiffPathNode std::string toString() const; static DiffPathNode constructWithTableProperty(Name tableProperty); + + static DiffPathNode constructWithKindAndIndex(Kind kind, size_t index); + + static DiffPathNode constructWithKind(Kind kind); }; + struct DiffPathNodeLeaf { std::optional ty; std::optional tableProperty; - DiffPathNodeLeaf(std::optional ty, std::optional tableProperty) + std::optional minLength; + bool isVariadic; + DiffPathNodeLeaf(std::optional ty, std::optional tableProperty, std::optional minLength, bool isVariadic) : ty(ty) , tableProperty(tableProperty) + , minLength(minLength) + , isVariadic(isVariadic) { } + static DiffPathNodeLeaf detailsNormal(TypeId ty); + + static DiffPathNodeLeaf detailsTableProperty(TypeId ty, Name tableProperty); + + static DiffPathNodeLeaf detailsLength(int minLength, bool isVariadic); + static DiffPathNodeLeaf nullopts(); }; + struct DiffPath { std::vector path; diff --git a/Analysis/include/Luau/Error.h b/Analysis/include/Luau/Error.h index 858d1b499..13758b378 100644 --- a/Analysis/include/Luau/Error.h +++ b/Analysis/include/Luau/Error.h @@ -357,13 +357,13 @@ struct PackWhereClauseNeeded bool operator==(const PackWhereClauseNeeded& rhs) const; }; -using TypeErrorData = - Variant; +using TypeErrorData = Variant; struct TypeErrorSummary { diff --git a/Analysis/include/Luau/Frontend.h b/Analysis/include/Luau/Frontend.h index 7b1eb2076..5804b7a8c 100644 --- a/Analysis/include/Luau/Frontend.h +++ b/Analysis/include/Luau/Frontend.h @@ -29,6 +29,7 @@ struct ModuleResolver; struct ParseResult; struct HotComment; struct BuildQueueItem; +struct FrontendCancellationToken; struct LoadDefinitionFileResult { @@ -96,6 +97,8 @@ struct FrontendOptions std::optional randomizeConstraintResolutionSeed; std::optional enabledLintWarnings; + + std::shared_ptr cancellationToken; }; struct CheckResult @@ -191,6 +194,7 @@ struct Frontend std::optional finishTime; std::optional instantiationChildLimit; std::optional unifierIterationLimit; + std::shared_ptr cancellationToken; }; ModulePtr check(const SourceModule& sourceModule, Mode mode, std::vector requireCycles, std::optional environmentScope, diff --git a/Analysis/include/Luau/InsertionOrderedMap.h b/Analysis/include/Luau/InsertionOrderedMap.h index 66d6b2ab8..2937dcda2 100644 --- a/Analysis/include/Luau/InsertionOrderedMap.h +++ b/Analysis/include/Luau/InsertionOrderedMap.h @@ -16,10 +16,10 @@ struct InsertionOrderedMap { static_assert(std::is_trivially_copyable_v, "key must be trivially copyable"); - private: +private: using vec = std::vector>; - public: +public: using iterator = typename vec::iterator; using const_iterator = typename vec::const_iterator; @@ -131,4 +131,4 @@ struct InsertionOrderedMap std::unordered_map indices; }; -} +} // namespace Luau diff --git a/Analysis/include/Luau/Module.h b/Analysis/include/Luau/Module.h index a3b9c4172..cb7617140 100644 --- a/Analysis/include/Luau/Module.h +++ b/Analysis/include/Luau/Module.h @@ -112,6 +112,7 @@ struct Module Mode mode; SourceCode::Type type; bool timeout = false; + bool cancelled = false; TypePackId returnType = nullptr; std::unordered_map exportedTypeBindings; diff --git a/Analysis/include/Luau/ToString.h b/Analysis/include/Luau/ToString.h index dec2c1fc5..efe82b124 100644 --- a/Analysis/include/Luau/ToString.h +++ b/Analysis/include/Luau/ToString.h @@ -139,6 +139,6 @@ std::string dump(const std::shared_ptr& scope, const char* name); std::string generateName(size_t n); std::string toString(const Position& position); -std::string toString(const Location& location); +std::string toString(const Location& location, int offset = 0, bool useBegin = true); } // namespace Luau diff --git a/Analysis/include/Luau/TypeChecker2.h b/Analysis/include/Luau/TypeChecker2.h index def00a440..11d2aff93 100644 --- a/Analysis/include/Luau/TypeChecker2.h +++ b/Analysis/include/Luau/TypeChecker2.h @@ -12,6 +12,7 @@ namespace Luau struct DcrLogger; struct BuiltinTypes; -void check(NotNull builtinTypes, NotNull sharedState, DcrLogger* logger, const SourceModule& sourceModule, Module* module); +void check(NotNull builtinTypes, NotNull sharedState, DcrLogger* logger, const SourceModule& sourceModule, + Module* module); } // namespace Luau diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index 9902e5a1e..79ee60c46 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -25,6 +25,7 @@ namespace Luau struct Scope; struct TypeChecker; struct ModuleResolver; +struct FrontendCancellationToken; using Name = std::string; using ScopePtr = std::shared_ptr; @@ -64,6 +65,15 @@ class TimeLimitError : public InternalCompilerError } }; +class UserCancelError : public InternalCompilerError +{ +public: + explicit UserCancelError(const std::string& moduleName) + : InternalCompilerError("Analysis has been cancelled by user", moduleName) + { + } +}; + struct GlobalTypes { GlobalTypes(NotNull builtinTypes); @@ -262,6 +272,7 @@ struct TypeChecker [[noreturn]] void ice(const std::string& message, const Location& location); [[noreturn]] void ice(const std::string& message); [[noreturn]] void throwTimeLimitError(); + [[noreturn]] void throwUserCancelError(); ScopePtr childFunctionScope(const ScopePtr& parent, const Location& location, int subLevel = 0); ScopePtr childScope(const ScopePtr& parent, const Location& location); @@ -387,6 +398,8 @@ struct TypeChecker std::optional instantiationChildLimit; std::optional unifierIterationLimit; + std::shared_ptr cancellationToken; + public: const TypeId nilType; const TypeId numberType; diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index 99da33f62..7a6a2f760 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -76,8 +76,7 @@ struct Unifier std::vector blockedTypes; std::vector blockedTypePacks; - Unifier( - NotNull normalizer, NotNull scope, const Location& location, Variance variance, TxnLog* parentLog = nullptr); + Unifier(NotNull normalizer, NotNull scope, const Location& location, Variance variance, TxnLog* parentLog = nullptr); // Configure the Unifier to test for scope subsumption via embedded Scope // pointers rather than TypeLevels. diff --git a/Analysis/src/Clone.cpp b/Analysis/src/Clone.cpp index bdb510a37..197aad7af 100644 --- a/Analysis/src/Clone.cpp +++ b/Analysis/src/Clone.cpp @@ -7,7 +7,6 @@ #include "Luau/Unifiable.h" LUAU_FASTFLAG(DebugLuauCopyBeforeNormalizing) -LUAU_FASTFLAG(LuauClonePublicInterfaceLess2) LUAU_FASTFLAG(DebugLuauReadWriteProperties) LUAU_FASTINTVARIABLE(LuauTypeCloneRecursionLimit, 300) diff --git a/Analysis/src/Differ.cpp b/Analysis/src/Differ.cpp index 91b6f4b61..50672cd9e 100644 --- a/Analysis/src/Differ.cpp +++ b/Analysis/src/Differ.cpp @@ -3,7 +3,9 @@ #include "Luau/Error.h" #include "Luau/ToString.h" #include "Luau/Type.h" +#include "Luau/TypePack.h" #include +#include namespace Luau { @@ -18,6 +20,20 @@ std::string DiffPathNode::toString() const return *tableProperty; break; } + case DiffPathNode::Kind::FunctionArgument: + { + if (!index.has_value()) + return "Arg[Variadic]"; + // Add 1 because Lua is 1-indexed + return "Arg[" + std::to_string(*index + 1) + "]"; + } + case DiffPathNode::Kind::FunctionReturn: + { + if (!index.has_value()) + return "Ret[Variadic]"; + // Add 1 because Lua is 1-indexed + return "Ret[" + std::to_string(*index + 1) + "]"; + } default: { throw InternalCompilerError{"DiffPathNode::toString is not exhaustive"}; @@ -30,9 +46,34 @@ DiffPathNode DiffPathNode::constructWithTableProperty(Name tableProperty) return DiffPathNode{DiffPathNode::Kind::TableProperty, tableProperty, std::nullopt}; } +DiffPathNode DiffPathNode::constructWithKindAndIndex(Kind kind, size_t index) +{ + return DiffPathNode{kind, std::nullopt, index}; +} + +DiffPathNode DiffPathNode::constructWithKind(Kind kind) +{ + return DiffPathNode{kind, std::nullopt, std::nullopt}; +} + +DiffPathNodeLeaf DiffPathNodeLeaf::detailsNormal(TypeId ty) +{ + return DiffPathNodeLeaf{ty, std::nullopt, std::nullopt, false}; +} + +DiffPathNodeLeaf DiffPathNodeLeaf::detailsTableProperty(TypeId ty, Name tableProperty) +{ + return DiffPathNodeLeaf{ty, tableProperty, std::nullopt, false}; +} + +DiffPathNodeLeaf DiffPathNodeLeaf::detailsLength(int minLength, bool isVariadic) +{ + return DiffPathNodeLeaf{std::nullopt, std::nullopt, minLength, isVariadic}; +} + DiffPathNodeLeaf DiffPathNodeLeaf::nullopts() { - return DiffPathNodeLeaf{std::nullopt, std::nullopt}; + return DiffPathNodeLeaf{std::nullopt, std::nullopt, std::nullopt, false}; } std::string DiffPath::toString(bool prependDot) const @@ -79,9 +120,21 @@ std::string DiffError::toStringALeaf(std::string rootName, const DiffPathNodeLea } throw InternalCompilerError{"Both leaf.ty and otherLeaf.ty is nullopt"}; } + case DiffError::Kind::LengthMismatchInFnArgs: + { + if (!leaf.minLength.has_value()) + throw InternalCompilerError{"leaf.minLength is nullopt"}; + return pathStr + " takes " + std::to_string(*leaf.minLength) + (leaf.isVariadic ? " or more" : "") + " arguments"; + } + case DiffError::Kind::LengthMismatchInFnRets: + { + if (!leaf.minLength.has_value()) + throw InternalCompilerError{"leaf.minLength is nullopt"}; + return pathStr + " returns " + std::to_string(*leaf.minLength) + (leaf.isVariadic ? " or more" : "") + " values"; + } default: { - throw InternalCompilerError{"DiffPath::toStringWithLeaf is not exhaustive"}; + throw InternalCompilerError{"DiffPath::toStringALeaf is not exhaustive"}; } } } @@ -139,6 +192,14 @@ static DifferResult diffUsingEnv(DifferEnvironment& env, TypeId left, TypeId rig static DifferResult diffTable(DifferEnvironment& env, TypeId left, TypeId right); static DifferResult diffPrimitive(DifferEnvironment& env, TypeId left, TypeId right); static DifferResult diffSingleton(DifferEnvironment& env, TypeId left, TypeId right); +static DifferResult diffFunction(DifferEnvironment& env, TypeId left, TypeId right); +/** + * The last argument gives context info on which complex type contained the TypePack. + */ +static DifferResult diffTpi(DifferEnvironment& env, DiffError::Kind possibleNonNormalErrorKind, TypePackId left, TypePackId right); +static DifferResult diffCanonicalTpShape(DifferEnvironment& env, DiffError::Kind possibleNonNormalErrorKind, + const std::pair, std::optional>& left, const std::pair, std::optional>& right); +static DifferResult diffHandleFlattenedTail(DifferEnvironment& env, DiffError::Kind possibleNonNormalErrorKind, TypePackId left, TypePackId right); static DifferResult diffTable(DifferEnvironment& env, TypeId left, TypeId right) { @@ -152,7 +213,7 @@ static DifferResult diffTable(DifferEnvironment& env, TypeId left, TypeId right) // left has a field the right doesn't return DifferResult{DiffError{ DiffError::Kind::MissingProperty, - DiffPathNodeLeaf{value.type(), field}, + DiffPathNodeLeaf::detailsTableProperty(value.type(), field), DiffPathNodeLeaf::nullopts(), getDevFixFriendlyName(env.rootLeft), getDevFixFriendlyName(env.rootRight), @@ -164,8 +225,9 @@ static DifferResult diffTable(DifferEnvironment& env, TypeId left, TypeId right) if (leftTable->props.find(field) == leftTable->props.end()) { // right has a field the left doesn't - return DifferResult{DiffError{DiffError::Kind::MissingProperty, DiffPathNodeLeaf::nullopts(), DiffPathNodeLeaf{value.type(), field}, - getDevFixFriendlyName(env.rootLeft), getDevFixFriendlyName(env.rootRight)}}; + return DifferResult{ + DiffError{DiffError::Kind::MissingProperty, DiffPathNodeLeaf::nullopts(), DiffPathNodeLeaf::detailsTableProperty(value.type(), field), + getDevFixFriendlyName(env.rootLeft), getDevFixFriendlyName(env.rootRight)}}; } } // left and right have the same set of keys @@ -191,8 +253,8 @@ static DifferResult diffPrimitive(DifferEnvironment& env, TypeId left, TypeId ri { return DifferResult{DiffError{ DiffError::Kind::Normal, - DiffPathNodeLeaf{left, std::nullopt}, - DiffPathNodeLeaf{right, std::nullopt}, + DiffPathNodeLeaf::detailsNormal(left), + DiffPathNodeLeaf::detailsNormal(right), getDevFixFriendlyName(env.rootLeft), getDevFixFriendlyName(env.rootRight), }}; @@ -209,8 +271,8 @@ static DifferResult diffSingleton(DifferEnvironment& env, TypeId left, TypeId ri { return DifferResult{DiffError{ DiffError::Kind::Normal, - DiffPathNodeLeaf{left, std::nullopt}, - DiffPathNodeLeaf{right, std::nullopt}, + DiffPathNodeLeaf::detailsNormal(left), + DiffPathNodeLeaf::detailsNormal(right), getDevFixFriendlyName(env.rootLeft), getDevFixFriendlyName(env.rootRight), }}; @@ -218,6 +280,17 @@ static DifferResult diffSingleton(DifferEnvironment& env, TypeId left, TypeId ri return DifferResult{}; } +static DifferResult diffFunction(DifferEnvironment& env, TypeId left, TypeId right) +{ + const FunctionType* leftFunction = get(left); + const FunctionType* rightFunction = get(right); + + DifferResult differResult = diffTpi(env, DiffError::Kind::LengthMismatchInFnArgs, leftFunction->argTypes, rightFunction->argTypes); + if (differResult.diffError.has_value()) + return differResult; + return diffTpi(env, DiffError::Kind::LengthMismatchInFnRets, leftFunction->retTypes, rightFunction->retTypes); +} + static DifferResult diffUsingEnv(DifferEnvironment& env, TypeId left, TypeId right) { left = follow(left); @@ -227,8 +300,8 @@ static DifferResult diffUsingEnv(DifferEnvironment& env, TypeId left, TypeId rig { return DifferResult{DiffError{ DiffError::Kind::Normal, - DiffPathNodeLeaf{left, std::nullopt}, - DiffPathNodeLeaf{right, std::nullopt}, + DiffPathNodeLeaf::detailsNormal(left), + DiffPathNodeLeaf::detailsNormal(right), getDevFixFriendlyName(env.rootLeft), getDevFixFriendlyName(env.rootRight), }}; @@ -244,6 +317,11 @@ static DifferResult diffUsingEnv(DifferEnvironment& env, TypeId left, TypeId rig { return diffSingleton(env, left, right); } + else if (auto la = get(left)) + { + // Both left and right must be Any if either is Any for them to be equal! + return DifferResult{}; + } throw InternalCompilerError{"Unimplemented Simple TypeId variant for diffing"}; } @@ -254,9 +332,116 @@ static DifferResult diffUsingEnv(DifferEnvironment& env, TypeId left, TypeId rig { return diffTable(env, left, right); } + if (auto lf = get(left)) + { + return diffFunction(env, left, right); + } throw InternalCompilerError{"Unimplemented non-simple TypeId variant for diffing"}; } +static DifferResult diffTpi(DifferEnvironment& env, DiffError::Kind possibleNonNormalErrorKind, TypePackId left, TypePackId right) +{ + left = follow(left); + right = follow(right); + + // Canonicalize + std::pair, std::optional> leftFlatTpi = flatten(left); + std::pair, std::optional> rightFlatTpi = flatten(right); + + // Check for shape equality + DifferResult diffResult = diffCanonicalTpShape(env, possibleNonNormalErrorKind, leftFlatTpi, rightFlatTpi); + if (diffResult.diffError.has_value()) + { + return diffResult; + } + + // Left and Right have the same shape + for (size_t i = 0; i < leftFlatTpi.first.size(); i++) + { + DifferResult differResult = diffUsingEnv(env, leftFlatTpi.first[i], rightFlatTpi.first[i]); + if (!differResult.diffError.has_value()) + continue; + + switch (possibleNonNormalErrorKind) + { + case DiffError::Kind::LengthMismatchInFnArgs: + { + differResult.wrapDiffPath(DiffPathNode::constructWithKindAndIndex(DiffPathNode::Kind::FunctionArgument, i)); + return differResult; + } + case DiffError::Kind::LengthMismatchInFnRets: + { + differResult.wrapDiffPath(DiffPathNode::constructWithKindAndIndex(DiffPathNode::Kind::FunctionReturn, i)); + return differResult; + } + default: + { + throw InternalCompilerError{"Unhandled Tpi diffing case with same shape"}; + } + } + } + if (!leftFlatTpi.second.has_value()) + return DifferResult{}; + + return diffHandleFlattenedTail(env, possibleNonNormalErrorKind, *leftFlatTpi.second, *rightFlatTpi.second); +} + +static DifferResult diffCanonicalTpShape(DifferEnvironment& env, DiffError::Kind possibleNonNormalErrorKind, + const std::pair, std::optional>& left, const std::pair, std::optional>& right) +{ + if (left.first.size() == right.first.size() && left.second.has_value() == right.second.has_value()) + return DifferResult{}; + + return DifferResult{DiffError{ + possibleNonNormalErrorKind, + DiffPathNodeLeaf::detailsLength(int(left.first.size()), left.second.has_value()), + DiffPathNodeLeaf::detailsLength(int(right.first.size()), right.second.has_value()), + getDevFixFriendlyName(env.rootLeft), + getDevFixFriendlyName(env.rootRight), + }}; +} + +static DifferResult diffHandleFlattenedTail(DifferEnvironment& env, DiffError::Kind possibleNonNormalErrorKind, TypePackId left, TypePackId right) +{ + left = follow(left); + right = follow(right); + + if (left->ty.index() != right->ty.index()) + { + throw InternalCompilerError{"Unhandled case where the tail of 2 normalized typepacks have different variants"}; + } + + // Both left and right are the same variant + + if (auto lv = get(left)) + { + auto rv = get(right); + DifferResult differResult = diffUsingEnv(env, lv->ty, rv->ty); + if (!differResult.diffError.has_value()) + return DifferResult{}; + + switch (possibleNonNormalErrorKind) + { + case DiffError::Kind::LengthMismatchInFnArgs: + { + differResult.wrapDiffPath(DiffPathNode::constructWithKind(DiffPathNode::Kind::FunctionArgument)); + return differResult; + } + case DiffError::Kind::LengthMismatchInFnRets: + { + differResult.wrapDiffPath(DiffPathNode::constructWithKind(DiffPathNode::Kind::FunctionReturn)); + return differResult; + } + default: + { + throw InternalCompilerError{"Unhandled flattened tail case for VariadicTypePack"}; + } + } + } + + throw InternalCompilerError{"Unhandled tail type pack variant for flattened tails"}; +} + DifferResult diff(TypeId ty1, TypeId ty2) { DifferEnvironment differEnv{ty1, ty2}; @@ -267,7 +452,7 @@ bool isSimple(TypeId ty) { ty = follow(ty); // TODO: think about GenericType, etc. - return get(ty) || get(ty); + return get(ty) || get(ty) || get(ty); } } // namespace Luau diff --git a/Analysis/src/Error.cpp b/Analysis/src/Error.cpp index fba3c88a3..1a690a50a 100644 --- a/Analysis/src/Error.cpp +++ b/Analysis/src/Error.cpp @@ -495,12 +495,16 @@ struct ErrorConverter std::string operator()(const WhereClauseNeeded& e) const { - return "Type family instance " + Luau::toString(e.ty) + " depends on generic function parameters but does not appear in the function signature; this construct cannot be type-checked at this time"; + return "Type family instance " + Luau::toString(e.ty) + + " depends on generic function parameters but does not appear in the function signature; this construct cannot be type-checked at this " + "time"; } std::string operator()(const PackWhereClauseNeeded& e) const { - return "Type pack family instance " + Luau::toString(e.tp) + " depends on generic function parameters but does not appear in the function signature; this construct cannot be type-checked at this time"; + return "Type pack family instance " + Luau::toString(e.tp) + + " depends on generic function parameters but does not appear in the function signature; this construct cannot be type-checked at this " + "time"; } }; diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 9f1fd7267..2dea162bd 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -35,7 +35,7 @@ LUAU_FASTINTVARIABLE(LuauAutocompleteCheckTimeoutMs, 100) LUAU_FASTFLAGVARIABLE(DebugLuauDeferredConstraintResolution, false) LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJson, false) LUAU_FASTFLAGVARIABLE(DebugLuauReadWriteProperties, false) -LUAU_FASTFLAGVARIABLE(LuauFixBuildQueueExceptionUnwrap, false) +LUAU_FASTFLAGVARIABLE(LuauTypecheckCancellation, false) namespace Luau { @@ -461,6 +461,10 @@ CheckResult Frontend::check(const ModuleName& name, std::optionaltimeout) checkResult.timeoutHits.push_back(item.name); + // If check was manually cancelled, do not return partial results + if (FFlag::LuauTypecheckCancellation && item.module->cancelled) + return {}; + checkResult.errors.insert(checkResult.errors.end(), item.module->errors.begin(), item.module->errors.end()); if (item.name == name) @@ -610,6 +614,7 @@ std::vector Frontend::checkQueuedModules(std::optional nextItems; std::optional itemWithException; + bool cancelled = false; while (remaining != 0) { @@ -626,15 +631,15 @@ std::vector Frontend::checkQueuedModules(std::optionalcancelled) + cancelled = true; + + if (itemWithException || cancelled) + break; recordItemResult(item); @@ -671,8 +676,12 @@ std::vector Frontend::checkQueuedModules(std::optional& items) for (BuildQueueItem& item : items) { checkBuildQueueItem(item); + + if (FFlag::LuauTypecheckCancellation && item.module && item.module->cancelled) + break; + recordItemResult(item); } } @@ -1232,8 +1253,8 @@ ModulePtr Frontend::check(const SourceModule& sourceModule, Mode mode, std::vect catch (const InternalCompilerError& err) { InternalCompilerError augmented = err.location.has_value() - ? InternalCompilerError{err.message, sourceModule.humanReadableName, *err.location} - : InternalCompilerError{err.message, sourceModule.humanReadableName}; + ? InternalCompilerError{err.message, sourceModule.humanReadableName, *err.location} + : InternalCompilerError{err.message, sourceModule.humanReadableName}; throw augmented; } } @@ -1254,6 +1275,9 @@ ModulePtr Frontend::check(const SourceModule& sourceModule, Mode mode, std::vect typeChecker.instantiationChildLimit = typeCheckLimits.instantiationChildLimit; typeChecker.unifierIterationLimit = typeCheckLimits.unifierIterationLimit; + if (FFlag::LuauTypecheckCancellation) + typeChecker.cancellationToken = typeCheckLimits.cancellationToken; + return typeChecker.check(sourceModule, mode, environmentScope); } } diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index 473b8acc4..cb2114abd 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -15,8 +15,6 @@ #include LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); -LUAU_FASTFLAGVARIABLE(LuauClonePublicInterfaceLess2, false); -LUAU_FASTFLAGVARIABLE(LuauCloneSkipNonInternalVisit, false); namespace Luau { @@ -98,7 +96,7 @@ struct ClonePublicInterface : Substitution bool ignoreChildrenVisit(TypeId ty) override { - if (FFlag::LuauCloneSkipNonInternalVisit && ty->owningArena != &module->internalTypes) + if (ty->owningArena != &module->internalTypes) return true; return false; @@ -106,7 +104,7 @@ struct ClonePublicInterface : Substitution bool ignoreChildrenVisit(TypePackId tp) override { - if (FFlag::LuauCloneSkipNonInternalVisit && tp->owningArena != &module->internalTypes) + if (tp->owningArena != &module->internalTypes) return true; return false; @@ -211,35 +209,23 @@ void Module::clonePublicInterface(NotNull builtinTypes, InternalEr TxnLog log; ClonePublicInterface clonePublicInterface{&log, builtinTypes, this}; - if (FFlag::LuauClonePublicInterfaceLess2) - returnType = clonePublicInterface.cloneTypePack(returnType); - else - returnType = clone(returnType, interfaceTypes, cloneState); + returnType = clonePublicInterface.cloneTypePack(returnType); moduleScope->returnType = returnType; if (varargPack) { - if (FFlag::LuauClonePublicInterfaceLess2) - varargPack = clonePublicInterface.cloneTypePack(*varargPack); - else - varargPack = clone(*varargPack, interfaceTypes, cloneState); + varargPack = clonePublicInterface.cloneTypePack(*varargPack); moduleScope->varargPack = varargPack; } for (auto& [name, tf] : moduleScope->exportedTypeBindings) { - if (FFlag::LuauClonePublicInterfaceLess2) - tf = clonePublicInterface.cloneTypeFun(tf); - else - tf = clone(tf, interfaceTypes, cloneState); + tf = clonePublicInterface.cloneTypeFun(tf); } for (auto& [name, ty] : declaredGlobals) { - if (FFlag::LuauClonePublicInterfaceLess2) - ty = clonePublicInterface.cloneType(ty); - else - ty = clone(ty, interfaceTypes, cloneState); + ty = clonePublicInterface.cloneType(ty); } // Copy external stuff over to Module itself diff --git a/Analysis/src/Substitution.cpp b/Analysis/src/Substitution.cpp index 9c34cd7cc..4c6c35b0f 100644 --- a/Analysis/src/Substitution.cpp +++ b/Analysis/src/Substitution.cpp @@ -8,93 +8,15 @@ #include #include -LUAU_FASTFLAG(LuauClonePublicInterfaceLess2) LUAU_FASTINTVARIABLE(LuauTarjanChildLimit, 10000) LUAU_FASTFLAG(DebugLuauReadWriteProperties) -LUAU_FASTFLAG(LuauCloneSkipNonInternalVisit) LUAU_FASTFLAGVARIABLE(LuauTarjanSingleArr, false) namespace Luau { -static TypeId DEPRECATED_shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool alwaysClone) -{ - ty = log->follow(ty); - - TypeId result = ty; - - if (auto pty = log->pending(ty)) - ty = &pty->pending; - - if (const FunctionType* ftv = get(ty)) - { - FunctionType clone = FunctionType{ftv->level, ftv->scope, ftv->argTypes, ftv->retTypes, ftv->definition, ftv->hasSelf}; - clone.generics = ftv->generics; - clone.genericPacks = ftv->genericPacks; - clone.magicFunction = ftv->magicFunction; - clone.dcrMagicFunction = ftv->dcrMagicFunction; - clone.dcrMagicRefinement = ftv->dcrMagicRefinement; - clone.tags = ftv->tags; - clone.argNames = ftv->argNames; - result = dest.addType(std::move(clone)); - } - else if (const TableType* ttv = get(ty)) - { - LUAU_ASSERT(!ttv->boundTo); - TableType clone = TableType{ttv->props, ttv->indexer, ttv->level, ttv->scope, ttv->state}; - clone.definitionModuleName = ttv->definitionModuleName; - clone.definitionLocation = ttv->definitionLocation; - clone.name = ttv->name; - clone.syntheticName = ttv->syntheticName; - clone.instantiatedTypeParams = ttv->instantiatedTypeParams; - clone.instantiatedTypePackParams = ttv->instantiatedTypePackParams; - clone.tags = ttv->tags; - result = dest.addType(std::move(clone)); - } - else if (const MetatableType* mtv = get(ty)) - { - MetatableType clone = MetatableType{mtv->table, mtv->metatable}; - clone.syntheticName = mtv->syntheticName; - result = dest.addType(std::move(clone)); - } - else if (const UnionType* utv = get(ty)) - { - UnionType clone; - clone.options = utv->options; - result = dest.addType(std::move(clone)); - } - else if (const IntersectionType* itv = get(ty)) - { - IntersectionType clone; - clone.parts = itv->parts; - result = dest.addType(std::move(clone)); - } - else if (const PendingExpansionType* petv = get(ty)) - { - PendingExpansionType clone{petv->prefix, petv->name, petv->typeArguments, petv->packArguments}; - result = dest.addType(std::move(clone)); - } - else if (const NegationType* ntv = get(ty)) - { - result = dest.addType(NegationType{ntv->ty}); - } - else if (const TypeFamilyInstanceType* tfit = get(ty)) - { - TypeFamilyInstanceType clone{tfit->family, tfit->typeArguments, tfit->packArguments}; - result = dest.addType(std::move(clone)); - } - else - return result; - - asMutable(result)->documentationSymbol = ty->documentationSymbol; - return result; -} - static TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool alwaysClone) { - if (!FFlag::LuauClonePublicInterfaceLess2) - return DEPRECATED_shallowClone(ty, dest, log, alwaysClone); - auto go = [ty, &dest, alwaysClone](auto&& a) { using T = std::decay_t; @@ -224,7 +146,7 @@ void Tarjan::visitChildren(TypeId ty, int index) { LUAU_ASSERT(ty == log->follow(ty)); - if (FFlag::LuauCloneSkipNonInternalVisit ? ignoreChildrenVisit(ty) : ignoreChildren(ty)) + if (ignoreChildrenVisit(ty)) return; if (auto pty = log->pending(ty)) @@ -324,7 +246,7 @@ void Tarjan::visitChildren(TypePackId tp, int index) { LUAU_ASSERT(tp == log->follow(tp)); - if (FFlag::LuauCloneSkipNonInternalVisit ? ignoreChildrenVisit(tp) : ignoreChildren(tp)) + if (ignoreChildrenVisit(tp)) return; if (auto ptp = log->pending(tp)) @@ -856,7 +778,7 @@ std::optional Substitution::substitute(TypePackId tp) TypeId Substitution::clone(TypeId ty) { - return shallowClone(ty, *arena, log, /* alwaysClone */ FFlag::LuauClonePublicInterfaceLess2); + return shallowClone(ty, *arena, log, /* alwaysClone */ true); } TypePackId Substitution::clone(TypePackId tp) @@ -888,12 +810,8 @@ TypePackId Substitution::clone(TypePackId tp) clone.packArguments.assign(tfitp->packArguments.begin(), tfitp->packArguments.end()); return addTypePack(std::move(clone)); } - else if (FFlag::LuauClonePublicInterfaceLess2) - { - return addTypePack(*tp); - } else - return tp; + return addTypePack(*tp); } void Substitution::foundDirty(TypeId ty) diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index f4375b5ae..19776d0a9 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -13,8 +13,10 @@ #include #include +#include LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) +LUAU_FASTFLAGVARIABLE(LuauToStringPrettifyLocation, false) /* * Enables increasing levels of verbosity for Luau type names when stringifying. @@ -1739,9 +1741,17 @@ std::string toString(const Position& position) return "{ line = " + std::to_string(position.line) + ", col = " + std::to_string(position.column) + " }"; } -std::string toString(const Location& location) +std::string toString(const Location& location, int offset, bool useBegin) { - return "Location { " + toString(location.begin) + ", " + toString(location.end) + " }"; + if (FFlag::LuauToStringPrettifyLocation) + { + return "(" + std::to_string(location.begin.line + offset) + ", " + std::to_string(location.begin.column + offset) + ") - (" + + std::to_string(location.end.line + offset) + ", " + std::to_string(location.end.column + offset) + ")"; + } + else + { + return "Location { " + toString(location.begin) + ", " + toString(location.end) + " }"; + } } } // namespace Luau diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index 103f0dcab..b77f7f159 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -1239,7 +1239,8 @@ struct TypeChecker2 return std::move(u.errors); } - std::pair checkOverload(TypeId fnTy, const TypePack* args, Location fnLoc, const std::vector* argLocs, bool callMetamethodOk = true) + std::pair checkOverload( + TypeId fnTy, const TypePack* args, Location fnLoc, const std::vector* argLocs, bool callMetamethodOk = true) { fnTy = follow(fnTy); @@ -1257,17 +1258,18 @@ struct TypeChecker2 std::vector withSelfLocs = *argLocs; withSelfLocs.insert(withSelfLocs.begin(), fnLoc); - return checkOverload(*callMm, &withSelf, fnLoc, &withSelfLocs, /*callMetamethodOk=*/ false); + return checkOverload(*callMm, &withSelf, fnLoc, &withSelfLocs, /*callMetamethodOk=*/false); } else return {TypeIsNotAFunction, {}}; // Intentionally empty. We can just fabricate the type error later on. } LUAU_NOINLINE - std::pair checkOverload_(TypeId fnTy, const FunctionType* fn, const TypePack* args, Location fnLoc, const std::vector* argLocs) + std::pair checkOverload_( + TypeId fnTy, const FunctionType* fn, const TypePack* args, Location fnLoc, const std::vector* argLocs) { TxnLog fake; - FamilyGraphReductionResult result = reduceFamilies(fnTy, callLoc, arena, builtinTypes, scope, normalizer, &fake, /*force=*/ true); + FamilyGraphReductionResult result = reduceFamilies(fnTy, callLoc, arena, builtinTypes, scope, normalizer, &fake, /*force=*/true); if (!result.errors.empty()) return {OverloadIsNonviable, result.errors}; @@ -2374,6 +2376,9 @@ struct TypeChecker2 return; } + if (norm->shouldSuppressErrors()) + return; + bool foundOneProp = false; std::vector typesMissingTheProp; @@ -2539,7 +2544,8 @@ struct TypeChecker2 } }; -void check(NotNull builtinTypes, NotNull unifierState, DcrLogger* logger, const SourceModule& sourceModule, Module* module) +void check( + NotNull builtinTypes, NotNull unifierState, DcrLogger* logger, const SourceModule& sourceModule, Module* module) { TypeChecker2 typeChecker{builtinTypes, unifierState, logger, &sourceModule, module}; diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index c4a6d103b..cfb0f21cc 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -2,6 +2,7 @@ #include "Luau/TypeInfer.h" #include "Luau/ApplyTypeFunction.h" +#include "Luau/Cancellation.h" #include "Luau/Clone.h" #include "Luau/Common.h" #include "Luau/Instantiation.h" @@ -40,6 +41,7 @@ LUAU_FASTFLAGVARIABLE(LuauTypecheckTypeguards, false) LUAU_FASTFLAGVARIABLE(LuauTinyControlFlowAnalysis, false) LUAU_FASTFLAGVARIABLE(LuauAlwaysCommitInferencesOfFunctionCalls, false) LUAU_FASTFLAG(LuauParseDeclareClassIndexer) +LUAU_FASTFLAGVARIABLE(LuauIndexTableIntersectionStringExpr, false) namespace Luau { @@ -301,6 +303,10 @@ ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mo { currentModule->timeout = true; } + catch (const UserCancelError&) + { + currentModule->cancelled = true; + } if (FFlag::DebugLuauSharedSelf) { @@ -344,7 +350,9 @@ ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mo ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStat& program) { if (finishTime && TimeTrace::getClock() > *finishTime) - throw TimeLimitError(iceHandler->moduleName); + throwTimeLimitError(); + if (cancellationToken && cancellationToken->requested()) + throwUserCancelError(); if (auto block = program.as()) return check(scope, *block); @@ -3381,6 +3389,20 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex reportError(TypeError{expr.location, UnknownProperty{exprType, value->value.data}}); return errorRecoveryType(scope); } + else if (FFlag::LuauIndexTableIntersectionStringExpr && get(exprType)) + { + Name name = std::string(value->value.data, value->value.size); + + if (std::optional ty = getIndexTypeFromType(scope, exprType, name, expr.location, /* addErrors= */ false)) + return *ty; + + // If intersection has a table part, report that it cannot be extended just as a sealed table + if (isTableIntersection(exprType)) + { + reportError(TypeError{expr.location, CannotExtendTable{exprType, CannotExtendTable::Property, name}}); + return errorRecoveryType(scope); + } + } } else { @@ -4914,16 +4936,26 @@ void TypeChecker::reportErrors(const ErrorVec& errors) reportError(err); } -void TypeChecker::ice(const std::string& message, const Location& location) +LUAU_NOINLINE void TypeChecker::ice(const std::string& message, const Location& location) { iceHandler->ice(message, location); } -void TypeChecker::ice(const std::string& message) +LUAU_NOINLINE void TypeChecker::ice(const std::string& message) { iceHandler->ice(message); } +LUAU_NOINLINE void TypeChecker::throwTimeLimitError() +{ + throw TimeLimitError(iceHandler->moduleName); +} + +LUAU_NOINLINE void TypeChecker::throwUserCancelError() +{ + throw UserCancelError(iceHandler->moduleName); +} + void TypeChecker::prepareErrorsForDisplay(ErrorVec& errVec) { // Remove errors with names that were generated by recovery from a parse error diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index eae007885..e54156feb 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -19,7 +19,6 @@ LUAU_FASTINT(LuauTypeInferTypePackLoopLimit) LUAU_FASTFLAG(LuauErrorRecoveryType) LUAU_FASTFLAGVARIABLE(LuauInstantiateInSubtyping, false) -LUAU_FASTFLAGVARIABLE(LuauVariadicAnyCanBeGeneric, false) LUAU_FASTFLAGVARIABLE(LuauMaintainScopesInUnifier, false) LUAU_FASTFLAGVARIABLE(LuauTransitiveSubtyping, false) LUAU_FASTFLAGVARIABLE(LuauOccursIsntAlwaysFailure, false) diff --git a/Ast/include/Luau/Location.h b/Ast/include/Luau/Location.h index 41ca379d1..041a2c631 100644 --- a/Ast/include/Luau/Location.h +++ b/Ast/include/Luau/Location.h @@ -39,11 +39,6 @@ struct Location bool containsClosed(const Position& p) const; void extend(const Location& other); void shift(const Position& start, const Position& oldEnd, const Position& newEnd); - - /** - * Use offset=1 when displaying for the user. - */ - std::string toString(int offset = 0, bool useBegin = true) const; }; } // namespace Luau diff --git a/Ast/src/Location.cpp b/Ast/src/Location.cpp index e0ae867fc..40f8e23ee 100644 --- a/Ast/src/Location.cpp +++ b/Ast/src/Location.cpp @@ -129,12 +129,4 @@ void Location::shift(const Position& start, const Position& oldEnd, const Positi end.shift(start, oldEnd, newEnd); } -std::string Location::toString(int offset, bool useBegin) const -{ - const Position& pos = useBegin ? this->begin : this->end; - std::string line{std::to_string(pos.line + offset)}; - std::string column{std::to_string(pos.column + offset)}; - return "(" + line + ", " + column + ")"; -} - } // namespace Luau diff --git a/Ast/src/TimeTrace.cpp b/Ast/src/TimeTrace.cpp index 8b95cf0b7..9373fa95d 100644 --- a/Ast/src/TimeTrace.cpp +++ b/Ast/src/TimeTrace.cpp @@ -92,14 +92,6 @@ struct GlobalContext { ~GlobalContext() { - // Ideally we would want all ThreadContext destructors to run - // But in VS, not all thread_local object instances are destroyed - for (ThreadContext* context : threads) - { - if (!context->events.empty()) - context->flushEvents(); - } - if (traceFile) fclose(traceFile); } @@ -109,7 +101,7 @@ struct GlobalContext uint32_t nextThreadId = 0; std::vector tokens; FILE* traceFile = nullptr; - + private: friend std::shared_ptr getGlobalContext(); GlobalContext() = default; diff --git a/CLI/Reduce.cpp b/CLI/Reduce.cpp index ffe670b8a..38133e04a 100644 --- a/CLI/Reduce.cpp +++ b/CLI/Reduce.cpp @@ -429,8 +429,7 @@ struct Reducer } } - void run(const std::string scriptName, const std::string command, std::string_view source, - std::string_view searchText) + void run(const std::string scriptName, const std::string command, std::string_view source, std::string_view searchText) { this->scriptName = scriptName; diff --git a/CLI/Repl.cpp b/CLI/Repl.cpp index 87ce27177..df1b4edaf 100644 --- a/CLI/Repl.cpp +++ b/CLI/Repl.cpp @@ -675,6 +675,10 @@ int replMain(int argc, char** argv) setLuauFlagsDefault(); +#ifdef _WIN32 + SetConsoleOutputCP(CP_UTF8); +#endif + int profile = 0; bool coverage = false; bool interactive = false; diff --git a/CodeGen/include/Luau/IrBuilder.h b/CodeGen/include/Luau/IrBuilder.h index 60106d1b5..d854b400a 100644 --- a/CodeGen/include/Luau/IrBuilder.h +++ b/CodeGen/include/Luau/IrBuilder.h @@ -26,6 +26,7 @@ struct IrBuilder void rebuildBytecodeBasicBlocks(Proto* proto); void translateInst(LuauOpcode op, const Instruction* pc, int i); + void handleFastcallFallback(IrOp fallbackOrUndef, const Instruction* pc, int i); bool isInternalBlock(IrOp block); void beginBlock(IrOp block); @@ -61,10 +62,13 @@ struct IrBuilder IrOp vmConst(uint32_t index); IrOp vmUpvalue(uint8_t index); + IrOp vmExit(uint32_t pcpos); + bool inTerminatedBlock = false; bool activeFastcallFallback = false; IrOp fastcallFallbackReturn; + int fastcallSkipTarget = -1; IrFunction function; diff --git a/CodeGen/include/Luau/IrData.h b/CodeGen/include/Luau/IrData.h index 16c8df628..0b38743ac 100644 --- a/CodeGen/include/Luau/IrData.h +++ b/CodeGen/include/Luau/IrData.h @@ -165,7 +165,7 @@ enum class IrCmd : uint8_t NOT_ANY, // TODO: boolean specialization will be useful // Unconditional jump - // A: block + // A: block/vmexit JUMP, // Jump if TValue is truthy @@ -364,7 +364,7 @@ enum class IrCmd : uint8_t // Guard against tag mismatch // A, B: tag - // C: block/undef + // C: block/vmexit/undef // D: bool (finish execution in VM on failure) // In final x64 lowering, A can also be Rn // When undef is specified instead of a block, execution is aborted on check failure; if D is true, execution is continued in VM interpreter @@ -384,7 +384,7 @@ enum class IrCmd : uint8_t CHECK_NO_METATABLE, // Guard against executing in unsafe environment, exits to VM on check failure - // A: unsigned int (pcpos)/undef + // A: vmexit/undef // When undef is specified, execution is aborted on check failure CHECK_SAFE_ENV, @@ -670,6 +670,9 @@ enum class IrOpKind : uint32_t // To reference a VM upvalue VmUpvalue, + + // To reference an exit to VM at specific PC pos + VmExit, }; struct IrOp diff --git a/CodeGen/src/CodeBlockUnwind.cpp b/CodeGen/src/CodeBlockUnwind.cpp index e9ce86747..a762cd371 100644 --- a/CodeGen/src/CodeBlockUnwind.cpp +++ b/CodeGen/src/CodeBlockUnwind.cpp @@ -23,11 +23,23 @@ extern "C" void __register_frame(const void*); extern "C" void __deregister_frame(const void*); extern "C" void __unw_add_dynamic_fde() __attribute__((weak)); - #endif #if defined(__APPLE__) && defined(__aarch64__) #include +#include +#include + +struct unw_dynamic_unwind_sections_t +{ + uintptr_t dso_base; + uintptr_t dwarf_section; + size_t dwarf_section_length; + uintptr_t compact_unwind_section; + size_t compact_unwind_section_length; +}; + +typedef int (*unw_add_find_dynamic_unwind_sections_t)(int (*)(uintptr_t addr, unw_dynamic_unwind_sections_t* info)); #endif namespace Luau @@ -35,6 +47,26 @@ namespace Luau namespace CodeGen { +#if defined(__APPLE__) && defined(__aarch64__) +static int findDynamicUnwindSections(uintptr_t addr, unw_dynamic_unwind_sections_t* info) +{ + // Define a minimal mach header for JIT'd code. + static const mach_header_64 kFakeMachHeader = { + MH_MAGIC_64, + CPU_TYPE_ARM64, + CPU_SUBTYPE_ARM64_ALL, + MH_DYLIB, + }; + + info->dso_base = (uintptr_t)&kFakeMachHeader; + info->dwarf_section = 0; + info->dwarf_section_length = 0; + info->compact_unwind_section = 0; + info->compact_unwind_section_length = 0; + return 1; +} +#endif + #if defined(__linux__) || defined(__APPLE__) static void visitFdeEntries(char* pos, void (*cb)(const void*)) { @@ -86,6 +118,15 @@ void* createBlockUnwindInfo(void* context, uint8_t* block, size_t blockSize, siz visitFdeEntries(unwindData, __register_frame); #endif +#if defined(__APPLE__) && defined(__aarch64__) + // Starting from macOS 14, we need to register unwind section callback to state that our ABI doesn't require pointer authentication + // This might conflict with other JITs that do the same; unfortunately this is the best we can do for now. + static unw_add_find_dynamic_unwind_sections_t unw_add_find_dynamic_unwind_sections = + unw_add_find_dynamic_unwind_sections_t(dlsym(RTLD_DEFAULT, "__unw_add_find_dynamic_unwind_sections")); + static int regonce = unw_add_find_dynamic_unwind_sections ? unw_add_find_dynamic_unwind_sections(findDynamicUnwindSections) : 0; + LUAU_ASSERT(regonce == 0); +#endif + beginOffset = unwindSize + unwind->getBeginOffset(); return block; } diff --git a/CodeGen/src/CodeGen.cpp b/CodeGen/src/CodeGen.cpp index 63dd9a4d6..cdb761c6a 100644 --- a/CodeGen/src/CodeGen.cpp +++ b/CodeGen/src/CodeGen.cpp @@ -141,14 +141,6 @@ static int onEnter(lua_State* L, Proto* proto) return GateFn(data->context.gateEntry)(L, proto, target, &data->context); } -static void onSetBreakpoint(lua_State* L, Proto* proto, int instruction) -{ - if (!proto->execdata) - return; - - LUAU_ASSERT(!"Native breakpoints are not implemented"); -} - #if defined(__aarch64__) unsigned int getCpuFeaturesA64() { @@ -245,7 +237,6 @@ void create(lua_State* L) ecb->close = onCloseState; ecb->destroy = onDestroyFunction; ecb->enter = onEnter; - ecb->setbreakpoint = onSetBreakpoint; } void compile(lua_State* L, int idx) @@ -259,7 +250,8 @@ void compile(lua_State* L, int idx) return; #if defined(__aarch64__) - A64::AssemblyBuilderA64 build(/* logText= */ false, getCpuFeaturesA64()); + static unsigned int cpuFeatures = getCpuFeaturesA64(); + A64::AssemblyBuilderA64 build(/* logText= */ false, cpuFeatures); #else X64::AssemblyBuilderX64 build(/* logText= */ false); #endif diff --git a/CodeGen/src/CodeGenAssembly.cpp b/CodeGen/src/CodeGenAssembly.cpp index 36d8b274f..fed5ddd3e 100644 --- a/CodeGen/src/CodeGenAssembly.cpp +++ b/CodeGen/src/CodeGenAssembly.cpp @@ -100,7 +100,8 @@ std::string getAssembly(lua_State* L, int idx, AssemblyOptions options) case AssemblyOptions::Host: { #if defined(__aarch64__) - A64::AssemblyBuilderA64 build(/* logText= */ options.includeAssembly, getCpuFeaturesA64()); + static unsigned int cpuFeatures = getCpuFeaturesA64(); + A64::AssemblyBuilderA64 build(/* logText= */ options.includeAssembly, cpuFeatures); #else X64::AssemblyBuilderX64 build(/* logText= */ options.includeAssembly); #endif diff --git a/CodeGen/src/CodeGenLower.h b/CodeGen/src/CodeGenLower.h index 4b74e9f20..a7352bce0 100644 --- a/CodeGen/src/CodeGenLower.h +++ b/CodeGen/src/CodeGenLower.h @@ -44,6 +44,18 @@ inline void gatherFunctions(std::vector& results, Proto* proto) gatherFunctions(results, proto->p[i]); } +inline IrBlock& getNextBlock(IrFunction& function, std::vector& sortedBlocks, IrBlock& dummy, size_t i) +{ + for (size_t j = i + 1; j < sortedBlocks.size(); ++j) + { + IrBlock& block = function.blocks[sortedBlocks[j]]; + if (block.kind != IrBlockKind::Dead) + return block; + } + + return dummy; +} + template inline bool lowerImpl(AssemblyBuilder& build, IrLowering& lowering, IrFunction& function, int bytecodeid, AssemblyOptions options) { @@ -118,6 +130,8 @@ inline bool lowerImpl(AssemblyBuilder& build, IrLowering& lowering, IrFunction& build.setLabel(block.label); + IrBlock& nextBlock = getNextBlock(function, sortedBlocks, dummy, i); + for (uint32_t index = block.start; index <= block.finish; index++) { LUAU_ASSERT(index < function.instructions.size()); @@ -156,9 +170,7 @@ inline bool lowerImpl(AssemblyBuilder& build, IrLowering& lowering, IrFunction& toStringDetailed(ctx, block, blockIndex, inst, index, /* includeUseInfo */ true); } - IrBlock& next = i + 1 < sortedBlocks.size() ? function.blocks[sortedBlocks[i + 1]] : dummy; - - lowering.lowerInst(inst, index, next); + lowering.lowerInst(inst, index, nextBlock); if (lowering.hasError()) { diff --git a/CodeGen/src/EmitCommon.h b/CodeGen/src/EmitCommon.h index 086660647..214cfd6de 100644 --- a/CodeGen/src/EmitCommon.h +++ b/CodeGen/src/EmitCommon.h @@ -12,7 +12,7 @@ constexpr unsigned kTValueSizeLog2 = 4; constexpr unsigned kLuaNodeSizeLog2 = 5; // TKey.tt and TKey.next are packed together in a bitfield -constexpr unsigned kOffsetOfTKeyTagNext = 12; // offsetof cannot be used on a bit field +constexpr unsigned kOffsetOfTKeyTagNext = 12; // offsetof cannot be used on a bit field constexpr unsigned kTKeyTagBits = 4; constexpr unsigned kTKeyTagMask = (1 << kTKeyTagBits) - 1; @@ -33,7 +33,7 @@ struct ModuleHelpers Label continueCallInVm; // A64 - Label reentry; // x0: closure + Label reentry; // x0: closure }; } // namespace CodeGen diff --git a/CodeGen/src/EmitCommonX64.cpp b/CodeGen/src/EmitCommonX64.cpp index 1d707fad9..2ad5b040a 100644 --- a/CodeGen/src/EmitCommonX64.cpp +++ b/CodeGen/src/EmitCommonX64.cpp @@ -268,7 +268,6 @@ void callStepGc(IrRegAllocX64& regs, AssemblyBuilderX64& build) build.setLabel(skip); } - void emitClearNativeFlag(AssemblyBuilderX64& build) { build.mov(rax, qword[rState + offsetof(lua_State, ci)]); diff --git a/CodeGen/src/IrBuilder.cpp b/CodeGen/src/IrBuilder.cpp index 69ac295a3..04318effb 100644 --- a/CodeGen/src/IrBuilder.cpp +++ b/CodeGen/src/IrBuilder.cpp @@ -128,8 +128,16 @@ void IrBuilder::buildFunctionIr(Proto* proto) // We skip dead bytecode instructions when they appear after block was already terminated if (!inTerminatedBlock) + { translateInst(op, pc, i); + if (fastcallSkipTarget != -1) + { + nexti = fastcallSkipTarget; + fastcallSkipTarget = -1; + } + } + i = nexti; LUAU_ASSERT(i <= proto->sizecode); @@ -357,49 +365,17 @@ void IrBuilder::translateInst(LuauOpcode op, const Instruction* pc, int i) translateInstCloseUpvals(*this, pc); break; case LOP_FASTCALL: - { - int skip = LUAU_INSN_C(*pc); - IrOp next = blockAtInst(i + skip + 2); - - translateFastCallN(*this, pc, i, false, 0, {}, next); - - activeFastcallFallback = true; - fastcallFallbackReturn = next; + handleFastcallFallback(translateFastCallN(*this, pc, i, false, 0, {}), pc, i); break; - } case LOP_FASTCALL1: - { - int skip = LUAU_INSN_C(*pc); - IrOp next = blockAtInst(i + skip + 2); - - translateFastCallN(*this, pc, i, true, 1, undef(), next); - - activeFastcallFallback = true; - fastcallFallbackReturn = next; + handleFastcallFallback(translateFastCallN(*this, pc, i, true, 1, undef()), pc, i); break; - } case LOP_FASTCALL2: - { - int skip = LUAU_INSN_C(*pc); - IrOp next = blockAtInst(i + skip + 2); - - translateFastCallN(*this, pc, i, true, 2, vmReg(pc[1]), next); - - activeFastcallFallback = true; - fastcallFallbackReturn = next; + handleFastcallFallback(translateFastCallN(*this, pc, i, true, 2, vmReg(pc[1])), pc, i); break; - } case LOP_FASTCALL2K: - { - int skip = LUAU_INSN_C(*pc); - IrOp next = blockAtInst(i + skip + 2); - - translateFastCallN(*this, pc, i, true, 2, vmConst(pc[1]), next); - - activeFastcallFallback = true; - fastcallFallbackReturn = next; + handleFastcallFallback(translateFastCallN(*this, pc, i, true, 2, vmConst(pc[1])), pc, i); break; - } case LOP_FORNPREP: translateInstForNPrep(*this, pc, i); break; @@ -493,6 +469,25 @@ void IrBuilder::translateInst(LuauOpcode op, const Instruction* pc, int i) } } +void IrBuilder::handleFastcallFallback(IrOp fallbackOrUndef, const Instruction* pc, int i) +{ + int skip = LUAU_INSN_C(*pc); + + if (fallbackOrUndef.kind != IrOpKind::Undef) + { + IrOp next = blockAtInst(i + skip + 2); + inst(IrCmd::JUMP, next); + beginBlock(fallbackOrUndef); + + activeFastcallFallback = true; + fastcallFallbackReturn = next; + } + else + { + fastcallSkipTarget = i + skip + 2; + } +} + bool IrBuilder::isInternalBlock(IrOp block) { IrBlock& target = function.blocks[block.index]; @@ -718,5 +713,10 @@ IrOp IrBuilder::vmUpvalue(uint8_t index) return {IrOpKind::VmUpvalue, index}; } +IrOp IrBuilder::vmExit(uint32_t pcpos) +{ + return {IrOpKind::VmExit, pcpos}; +} + } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/IrDump.cpp b/CodeGen/src/IrDump.cpp index dfd7236a3..c44cd8eb8 100644 --- a/CodeGen/src/IrDump.cpp +++ b/CodeGen/src/IrDump.cpp @@ -389,6 +389,9 @@ void toString(IrToStringContext& ctx, IrOp op) case IrOpKind::VmUpvalue: append(ctx.result, "U%d", vmUpvalueOp(op)); break; + case IrOpKind::VmExit: + append(ctx.result, "exit(%d)", op.index); + break; } } diff --git a/CodeGen/src/IrLoweringA64.cpp b/CodeGen/src/IrLoweringA64.cpp index 38e840ab7..92cb49adb 100644 --- a/CodeGen/src/IrLoweringA64.cpp +++ b/CodeGen/src/IrLoweringA64.cpp @@ -178,6 +178,7 @@ IrLoweringA64::IrLoweringA64(AssemblyBuilderA64& build, ModuleHelpers& helpers, , function(function) , regs(function, {{x0, x15}, {x16, x17}, {q0, q7}, {q16, q31}}) , valueTracker(function) + , exitHandlerMap(~0u) { // In order to allocate registers during lowering, we need to know where instruction results are last used updateLastUseLocations(function); @@ -514,8 +515,11 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.cmp(regOp(inst.a), LUA_TBOOLEAN); build.b(ConditionA64::NotEqual, notbool); - // boolean => invert value - build.eor(inst.regA64, regOp(inst.b), 1); + if (inst.b.kind == IrOpKind::Constant) + build.mov(inst.regA64, intOp(inst.b) == 0 ? 1 : 0); + else + build.eor(inst.regA64, regOp(inst.b), 1); // boolean => invert value + build.b(exit); // not boolean => result is true iff tag was nil @@ -527,7 +531,16 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) break; } case IrCmd::JUMP: - jumpOrFallthrough(blockOp(inst.a), next); + if (inst.a.kind == IrOpKind::VmExit) + { + Label fresh; + build.b(getTargetLabel(inst.a, fresh)); + finalizeTargetLabel(inst.a, fresh); + } + else + { + jumpOrFallthrough(blockOp(inst.a), next); + } break; case IrCmd::JUMP_IF_TRUTHY: { @@ -1029,8 +1042,8 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) case IrCmd::CHECK_TAG: { bool continueInVm = (inst.d.kind == IrOpKind::Constant && intOp(inst.d)); - Label abort; // used when guard aborts execution - Label& fail = inst.c.kind == IrOpKind::Undef ? (continueInVm ? helpers.exitContinueVmClearNativeFlag : abort) : labelOp(inst.c); + Label fresh; // used when guard aborts execution or jumps to a VM exit + Label& fail = continueInVm ? helpers.exitContinueVmClearNativeFlag : getTargetLabel(inst.c, fresh); if (tagOp(inst.b) == 0) { build.cbnz(regOp(inst.a), fail); @@ -1040,55 +1053,43 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.cmp(regOp(inst.a), tagOp(inst.b)); build.b(ConditionA64::NotEqual, fail); } - if (abort.id && !continueInVm) - emitAbort(build, abort); + if (!continueInVm) + finalizeTargetLabel(inst.c, fresh); break; } case IrCmd::CHECK_READONLY: { - Label abort; // used when guard aborts execution + Label fresh; // used when guard aborts execution or jumps to a VM exit RegisterA64 temp = regs.allocTemp(KindA64::w); build.ldrb(temp, mem(regOp(inst.a), offsetof(Table, readonly))); - build.cbnz(temp, inst.b.kind == IrOpKind::Undef ? abort : labelOp(inst.b)); - if (abort.id) - emitAbort(build, abort); + build.cbnz(temp, getTargetLabel(inst.b, fresh)); + finalizeTargetLabel(inst.b, fresh); break; } case IrCmd::CHECK_NO_METATABLE: { - Label abort; // used when guard aborts execution + Label fresh; // used when guard aborts execution or jumps to a VM exit RegisterA64 temp = regs.allocTemp(KindA64::x); build.ldr(temp, mem(regOp(inst.a), offsetof(Table, metatable))); - build.cbnz(temp, inst.b.kind == IrOpKind::Undef ? abort : labelOp(inst.b)); - if (abort.id) - emitAbort(build, abort); + build.cbnz(temp, getTargetLabel(inst.b, fresh)); + finalizeTargetLabel(inst.b, fresh); break; } case IrCmd::CHECK_SAFE_ENV: { - Label abort; // used when guard aborts execution + Label fresh; // used when guard aborts execution or jumps to a VM exit RegisterA64 temp = regs.allocTemp(KindA64::x); RegisterA64 tempw = castReg(KindA64::w, temp); build.ldr(temp, mem(rClosure, offsetof(Closure, env))); build.ldrb(tempw, mem(temp, offsetof(Table, safeenv))); - - if (inst.a.kind == IrOpKind::Undef) - { - build.cbz(tempw, abort); - emitAbort(build, abort); - } - else - { - Label self; - build.cbz(tempw, self); - exitHandlers.push_back({self, uintOp(inst.a)}); - } + build.cbz(tempw, getTargetLabel(inst.a, fresh)); + finalizeTargetLabel(inst.a, fresh); break; } case IrCmd::CHECK_ARRAY_SIZE: { - Label abort; // used when guard aborts execution - Label& fail = inst.c.kind == IrOpKind::Undef ? abort : labelOp(inst.c); + Label fresh; // used when guard aborts execution or jumps to a VM exit + Label& fail = getTargetLabel(inst.c, fresh); RegisterA64 temp = regs.allocTemp(KindA64::w); build.ldr(temp, mem(regOp(inst.a), offsetof(Table, sizearray))); @@ -1120,8 +1121,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) else LUAU_ASSERT(!"Unsupported instruction form"); - if (abort.id) - emitAbort(build, abort); + finalizeTargetLabel(inst.c, fresh); break; } case IrCmd::JUMP_SLOT_MATCH: @@ -1158,15 +1158,13 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) } case IrCmd::CHECK_NODE_NO_NEXT: { - Label abort; // used when guard aborts execution + Label fresh; // used when guard aborts execution or jumps to a VM exit RegisterA64 temp = regs.allocTemp(KindA64::w); build.ldr(temp, mem(regOp(inst.a), offsetof(LuaNode, key) + kOffsetOfTKeyTagNext)); build.lsr(temp, temp, kTKeyTagBits); - build.cbnz(temp, inst.b.kind == IrOpKind::Undef ? abort : labelOp(inst.b)); - - if (abort.id) - emitAbort(build, abort); + build.cbnz(temp, getTargetLabel(inst.b, fresh)); + finalizeTargetLabel(inst.b, fresh); break; } case IrCmd::INTERRUPT: @@ -1799,6 +1797,35 @@ void IrLoweringA64::jumpOrFallthrough(IrBlock& target, IrBlock& next) build.b(target.label); } +Label& IrLoweringA64::getTargetLabel(IrOp op, Label& fresh) +{ + if (op.kind == IrOpKind::Undef) + return fresh; + + if (op.kind == IrOpKind::VmExit) + { + if (uint32_t* index = exitHandlerMap.find(op.index)) + return exitHandlers[*index].self; + + return fresh; + } + + return labelOp(op); +} + +void IrLoweringA64::finalizeTargetLabel(IrOp op, Label& fresh) +{ + if (op.kind == IrOpKind::Undef) + { + emitAbort(build, fresh); + } + else if (op.kind == IrOpKind::VmExit && fresh.id != 0) + { + exitHandlerMap[op.index] = uint32_t(exitHandlers.size()); + exitHandlers.push_back({fresh, op.index}); + } +} + RegisterA64 IrLoweringA64::tempDouble(IrOp op) { if (op.kind == IrOpKind::Inst) diff --git a/CodeGen/src/IrLoweringA64.h b/CodeGen/src/IrLoweringA64.h index 5b1968892..72b0da2f6 100644 --- a/CodeGen/src/IrLoweringA64.h +++ b/CodeGen/src/IrLoweringA64.h @@ -2,6 +2,7 @@ #pragma once #include "Luau/AssemblyBuilderA64.h" +#include "Luau/DenseHash.h" #include "Luau/IrData.h" #include "IrRegAllocA64.h" @@ -33,6 +34,9 @@ struct IrLoweringA64 bool isFallthroughBlock(IrBlock target, IrBlock next); void jumpOrFallthrough(IrBlock& target, IrBlock& next); + Label& getTargetLabel(IrOp op, Label& fresh); + void finalizeTargetLabel(IrOp op, Label& fresh); + // Operand data build helpers // May emit data/address synthesis instructions RegisterA64 tempDouble(IrOp op); @@ -77,6 +81,7 @@ struct IrLoweringA64 std::vector interruptHandlers; std::vector exitHandlers; + DenseHashMap exitHandlerMap; bool error = false; }; diff --git a/CodeGen/src/IrLoweringX64.cpp b/CodeGen/src/IrLoweringX64.cpp index 813f5123e..670d60666 100644 --- a/CodeGen/src/IrLoweringX64.cpp +++ b/CodeGen/src/IrLoweringX64.cpp @@ -28,6 +28,7 @@ IrLoweringX64::IrLoweringX64(AssemblyBuilderX64& build, ModuleHelpers& helpers, , function(function) , regs(build, function) , valueTracker(function) + , exitHandlerMap(~0u) { // In order to allocate registers during lowering, we need to know where instruction results are last used updateLastUseLocations(function); @@ -492,8 +493,17 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.jcc(ConditionX64::NotEqual, savezero); } - build.cmp(regOp(inst.b), 0); - build.jcc(ConditionX64::Equal, saveone); + if (inst.b.kind == IrOpKind::Constant) + { + // If value is 1, we fallthrough to storing 0 + if (intOp(inst.b) == 0) + build.jmp(saveone); + } + else + { + build.cmp(regOp(inst.b), 0); + build.jcc(ConditionX64::Equal, saveone); + } build.setLabel(savezero); build.mov(inst.regX64, 0); @@ -506,7 +516,24 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) break; } case IrCmd::JUMP: - jumpOrFallthrough(blockOp(inst.a), next); + if (inst.a.kind == IrOpKind::VmExit) + { + if (uint32_t* index = exitHandlerMap.find(inst.a.index)) + { + build.jmp(exitHandlers[*index].self); + } + else + { + Label self; + build.jmp(self); + exitHandlerMap[inst.a.index] = uint32_t(exitHandlers.size()); + exitHandlers.push_back({self, inst.a.index}); + } + } + else + { + jumpOrFallthrough(blockOp(inst.a), next); + } break; case IrCmd::JUMP_IF_TRUTHY: jumpIfTruthy(build, vmRegOp(inst.a), labelOp(inst.b), labelOp(inst.c)); @@ -907,19 +934,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.mov(tmp.reg, qword[tmp.reg + offsetof(Closure, env)]); build.cmp(byte[tmp.reg + offsetof(Table, safeenv)], 0); - if (inst.a.kind == IrOpKind::Undef) - { - Label skip; - build.jcc(ConditionX64::NotEqual, skip); - build.ud2(); - build.setLabel(skip); - } - else - { - Label self; - build.jcc(ConditionX64::Equal, self); - exitHandlers.push_back({self, uintOp(inst.a)}); - } + jumpOrAbortOnUndef(ConditionX64::Equal, ConditionX64::NotEqual, inst.a); break; } case IrCmd::CHECK_ARRAY_SIZE: @@ -1473,6 +1488,20 @@ void IrLoweringX64::jumpOrAbortOnUndef(ConditionX64 cond, ConditionX64 condInver build.ud2(); build.setLabel(skip); } + else if (targetOrUndef.kind == IrOpKind::VmExit) + { + if (uint32_t* index = exitHandlerMap.find(targetOrUndef.index)) + { + build.jcc(cond, exitHandlers[*index].self); + } + else + { + Label self; + build.jcc(cond, self); + exitHandlerMap[targetOrUndef.index] = uint32_t(exitHandlers.size()); + exitHandlers.push_back({self, targetOrUndef.index}); + } + } else { build.jcc(cond, labelOp(targetOrUndef)); diff --git a/CodeGen/src/IrLoweringX64.h b/CodeGen/src/IrLoweringX64.h index 8ea4b41eb..a8dab3c99 100644 --- a/CodeGen/src/IrLoweringX64.h +++ b/CodeGen/src/IrLoweringX64.h @@ -2,6 +2,7 @@ #pragma once #include "Luau/AssemblyBuilderX64.h" +#include "Luau/DenseHash.h" #include "Luau/IrData.h" #include "Luau/IrRegAllocX64.h" @@ -77,6 +78,7 @@ struct IrLoweringX64 std::vector interruptHandlers; std::vector exitHandlers; + DenseHashMap exitHandlerMap; }; } // namespace X64 diff --git a/CodeGen/src/IrTranslateBuiltins.cpp b/CodeGen/src/IrTranslateBuiltins.cpp index 960be4ed8..73055c393 100644 --- a/CodeGen/src/IrTranslateBuiltins.cpp +++ b/CodeGen/src/IrTranslateBuiltins.cpp @@ -18,12 +18,12 @@ namespace Luau namespace CodeGen { -static void builtinCheckDouble(IrBuilder& build, IrOp arg, IrOp fallback) +static void builtinCheckDouble(IrBuilder& build, IrOp arg, int pcpos) { if (arg.kind == IrOpKind::Constant) LUAU_ASSERT(build.function.constOp(arg).kind == IrConstKind::Double); else - build.loadAndCheckTag(arg, LUA_TNUMBER, fallback); + build.loadAndCheckTag(arg, LUA_TNUMBER, build.vmExit(pcpos)); } static IrOp builtinLoadDouble(IrBuilder& build, IrOp arg) @@ -38,27 +38,27 @@ static IrOp builtinLoadDouble(IrBuilder& build, IrOp arg) // (number, ...) -> number static BuiltinImplResult translateBuiltinNumberToNumber( - IrBuilder& build, LuauBuiltinFunction bfid, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) + IrBuilder& build, LuauBuiltinFunction bfid, int nparams, int ra, int arg, IrOp args, int nresults, int pcpos) { if (nparams < 1 || nresults > 1) return {BuiltinImplType::None, -1}; - builtinCheckDouble(build, build.vmReg(arg), fallback); + builtinCheckDouble(build, build.vmReg(arg), pcpos); build.inst(IrCmd::FASTCALL, build.constUint(bfid), build.vmReg(ra), build.vmReg(arg), args, build.constInt(1), build.constInt(1)); if (ra != arg) build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); - return {BuiltinImplType::UsesFallback, 1}; + return {BuiltinImplType::Full, 1}; } static BuiltinImplResult translateBuiltinNumberToNumberLibm( - IrBuilder& build, LuauBuiltinFunction bfid, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) + IrBuilder& build, LuauBuiltinFunction bfid, int nparams, int ra, int arg, IrOp args, int nresults, int pcpos) { if (nparams < 1 || nresults > 1) return {BuiltinImplType::None, -1}; - builtinCheckDouble(build, build.vmReg(arg), fallback); + builtinCheckDouble(build, build.vmReg(arg), pcpos); IrOp va = builtinLoadDouble(build, build.vmReg(arg)); IrOp res = build.inst(IrCmd::INVOKE_LIBM, build.constUint(bfid), va); @@ -68,17 +68,17 @@ static BuiltinImplResult translateBuiltinNumberToNumberLibm( if (ra != arg) build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); - return {BuiltinImplType::UsesFallback, 1}; + return {BuiltinImplType::Full, 1}; } static BuiltinImplResult translateBuiltin2NumberToNumberLibm( - IrBuilder& build, LuauBuiltinFunction bfid, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) + IrBuilder& build, LuauBuiltinFunction bfid, int nparams, int ra, int arg, IrOp args, int nresults, int pcpos) { if (nparams < 2 || nresults > 1) return {BuiltinImplType::None, -1}; - builtinCheckDouble(build, build.vmReg(arg), fallback); - builtinCheckDouble(build, args, fallback); + builtinCheckDouble(build, build.vmReg(arg), pcpos); + builtinCheckDouble(build, args, pcpos); IrOp va = builtinLoadDouble(build, build.vmReg(arg)); IrOp vb = builtinLoadDouble(build, args); @@ -90,17 +90,17 @@ static BuiltinImplResult translateBuiltin2NumberToNumberLibm( if (ra != arg) build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); - return {BuiltinImplType::UsesFallback, 1}; + return {BuiltinImplType::Full, 1}; } static BuiltinImplResult translateBuiltinMathLdexp( - IrBuilder& build, LuauBuiltinFunction bfid, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) + IrBuilder& build, LuauBuiltinFunction bfid, int nparams, int ra, int arg, IrOp args, int nresults, int pcpos) { if (nparams < 2 || nresults > 1) return {BuiltinImplType::None, -1}; - builtinCheckDouble(build, build.vmReg(arg), fallback); - builtinCheckDouble(build, args, fallback); + builtinCheckDouble(build, build.vmReg(arg), pcpos); + builtinCheckDouble(build, args, pcpos); IrOp va = builtinLoadDouble(build, build.vmReg(arg)); IrOp vb = builtinLoadDouble(build, args); @@ -114,17 +114,17 @@ static BuiltinImplResult translateBuiltinMathLdexp( if (ra != arg) build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); - return {BuiltinImplType::UsesFallback, 1}; + return {BuiltinImplType::Full, 1}; } // (number, ...) -> (number, number) static BuiltinImplResult translateBuiltinNumberTo2Number( - IrBuilder& build, LuauBuiltinFunction bfid, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) + IrBuilder& build, LuauBuiltinFunction bfid, int nparams, int ra, int arg, IrOp args, int nresults, int pcpos) { if (nparams < 1 || nresults > 2) return {BuiltinImplType::None, -1}; - builtinCheckDouble(build, build.vmReg(arg), fallback); + builtinCheckDouble(build, build.vmReg(arg), pcpos); build.inst( IrCmd::FASTCALL, build.constUint(bfid), build.vmReg(ra), build.vmReg(arg), args, build.constInt(1), build.constInt(nresults == 1 ? 1 : 2)); @@ -134,7 +134,7 @@ static BuiltinImplResult translateBuiltinNumberTo2Number( if (nresults != 1) build.inst(IrCmd::STORE_TAG, build.vmReg(ra + 1), build.constTag(LUA_TNUMBER)); - return {BuiltinImplType::UsesFallback, 2}; + return {BuiltinImplType::Full, 2}; } static BuiltinImplResult translateBuiltinAssert(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) @@ -151,12 +151,12 @@ static BuiltinImplResult translateBuiltinAssert(IrBuilder& build, int nparams, i return {BuiltinImplType::UsesFallback, 0}; } -static BuiltinImplResult translateBuiltinMathDeg(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) +static BuiltinImplResult translateBuiltinMathDeg(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, int pcpos) { if (nparams < 1 || nresults > 1) return {BuiltinImplType::None, -1}; - builtinCheckDouble(build, build.vmReg(arg), fallback); + builtinCheckDouble(build, build.vmReg(arg), pcpos); const double rpd = (3.14159265358979323846 / 180.0); @@ -167,15 +167,15 @@ static BuiltinImplResult translateBuiltinMathDeg(IrBuilder& build, int nparams, if (ra != arg) build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); - return {BuiltinImplType::UsesFallback, 1}; + return {BuiltinImplType::Full, 1}; } -static BuiltinImplResult translateBuiltinMathRad(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) +static BuiltinImplResult translateBuiltinMathRad(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, int pcpos) { if (nparams < 1 || nresults > 1) return {BuiltinImplType::None, -1}; - builtinCheckDouble(build, build.vmReg(arg), fallback); + builtinCheckDouble(build, build.vmReg(arg), pcpos); const double rpd = (3.14159265358979323846 / 180.0); @@ -186,11 +186,11 @@ static BuiltinImplResult translateBuiltinMathRad(IrBuilder& build, int nparams, if (ra != arg) build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); - return {BuiltinImplType::UsesFallback, 1}; + return {BuiltinImplType::Full, 1}; } static BuiltinImplResult translateBuiltinMathLog( - IrBuilder& build, LuauBuiltinFunction bfid, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) + IrBuilder& build, LuauBuiltinFunction bfid, int nparams, int ra, int arg, IrOp args, int nresults, int pcpos) { if (nparams < 1 || nresults > 1) return {BuiltinImplType::None, -1}; @@ -213,7 +213,7 @@ static BuiltinImplResult translateBuiltinMathLog( denom = log(*y); } - builtinCheckDouble(build, build.vmReg(arg), fallback); + builtinCheckDouble(build, build.vmReg(arg), pcpos); IrOp va = builtinLoadDouble(build, build.vmReg(arg)); @@ -227,19 +227,19 @@ static BuiltinImplResult translateBuiltinMathLog( if (ra != arg) build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); - return {BuiltinImplType::UsesFallback, 1}; + return {BuiltinImplType::Full, 1}; } -static BuiltinImplResult translateBuiltinMathMin(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) +static BuiltinImplResult translateBuiltinMathMin(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, int pcpos) { if (nparams < 2 || nparams > kMinMaxUnrolledParams || nresults > 1) return {BuiltinImplType::None, -1}; - builtinCheckDouble(build, build.vmReg(arg), fallback); - builtinCheckDouble(build, args, fallback); + builtinCheckDouble(build, build.vmReg(arg), pcpos); + builtinCheckDouble(build, args, pcpos); for (int i = 3; i <= nparams; ++i) - builtinCheckDouble(build, build.vmReg(vmRegOp(args) + (i - 2)), fallback); + builtinCheckDouble(build, build.vmReg(vmRegOp(args) + (i - 2)), pcpos); IrOp varg1 = builtinLoadDouble(build, build.vmReg(arg)); IrOp varg2 = builtinLoadDouble(build, args); @@ -257,19 +257,19 @@ static BuiltinImplResult translateBuiltinMathMin(IrBuilder& build, int nparams, if (ra != arg) build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); - return {BuiltinImplType::UsesFallback, 1}; + return {BuiltinImplType::Full, 1}; } -static BuiltinImplResult translateBuiltinMathMax(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) +static BuiltinImplResult translateBuiltinMathMax(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, int pcpos) { if (nparams < 2 || nparams > kMinMaxUnrolledParams || nresults > 1) return {BuiltinImplType::None, -1}; - builtinCheckDouble(build, build.vmReg(arg), fallback); - builtinCheckDouble(build, args, fallback); + builtinCheckDouble(build, build.vmReg(arg), pcpos); + builtinCheckDouble(build, args, pcpos); for (int i = 3; i <= nparams; ++i) - builtinCheckDouble(build, build.vmReg(vmRegOp(args) + (i - 2)), fallback); + builtinCheckDouble(build, build.vmReg(vmRegOp(args) + (i - 2)), pcpos); IrOp varg1 = builtinLoadDouble(build, build.vmReg(arg)); IrOp varg2 = builtinLoadDouble(build, args); @@ -287,10 +287,10 @@ static BuiltinImplResult translateBuiltinMathMax(IrBuilder& build, int nparams, if (ra != arg) build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); - return {BuiltinImplType::UsesFallback, 1}; + return {BuiltinImplType::Full, 1}; } -static BuiltinImplResult translateBuiltinMathClamp(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) +static BuiltinImplResult translateBuiltinMathClamp(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback, int pcpos) { if (nparams < 3 || nresults > 1) return {BuiltinImplType::None, -1}; @@ -299,9 +299,9 @@ static BuiltinImplResult translateBuiltinMathClamp(IrBuilder& build, int nparams LUAU_ASSERT(args.kind == IrOpKind::VmReg); - builtinCheckDouble(build, build.vmReg(arg), fallback); - builtinCheckDouble(build, args, fallback); - builtinCheckDouble(build, build.vmReg(vmRegOp(args) + 1), fallback); + builtinCheckDouble(build, build.vmReg(arg), pcpos); + builtinCheckDouble(build, args, pcpos); + builtinCheckDouble(build, build.vmReg(vmRegOp(args) + 1), pcpos); IrOp min = builtinLoadDouble(build, args); IrOp max = builtinLoadDouble(build, build.vmReg(vmRegOp(args) + 1)); @@ -321,12 +321,12 @@ static BuiltinImplResult translateBuiltinMathClamp(IrBuilder& build, int nparams return {BuiltinImplType::UsesFallback, 1}; } -static BuiltinImplResult translateBuiltinMathUnary(IrBuilder& build, IrCmd cmd, int nparams, int ra, int arg, int nresults, IrOp fallback) +static BuiltinImplResult translateBuiltinMathUnary(IrBuilder& build, IrCmd cmd, int nparams, int ra, int arg, int nresults, int pcpos) { if (nparams < 1 || nresults > 1) return {BuiltinImplType::None, -1}; - builtinCheckDouble(build, build.vmReg(arg), fallback); + builtinCheckDouble(build, build.vmReg(arg), pcpos); IrOp varg = builtinLoadDouble(build, build.vmReg(arg)); IrOp result = build.inst(cmd, varg); @@ -336,10 +336,10 @@ static BuiltinImplResult translateBuiltinMathUnary(IrBuilder& build, IrCmd cmd, if (ra != arg) build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); - return {BuiltinImplType::UsesFallback, 1}; + return {BuiltinImplType::Full, 1}; } -static BuiltinImplResult translateBuiltinType(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) +static BuiltinImplResult translateBuiltinType(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults) { if (nparams < 1 || nresults > 1) return {BuiltinImplType::None, -1}; @@ -350,10 +350,10 @@ static BuiltinImplResult translateBuiltinType(IrBuilder& build, int nparams, int build.inst(IrCmd::STORE_POINTER, build.vmReg(ra), name); build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TSTRING)); - return {BuiltinImplType::UsesFallback, 1}; + return {BuiltinImplType::Full, 1}; } -static BuiltinImplResult translateBuiltinTypeof(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) +static BuiltinImplResult translateBuiltinTypeof(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults) { if (nparams < 1 || nresults > 1) return {BuiltinImplType::None, -1}; @@ -363,20 +363,20 @@ static BuiltinImplResult translateBuiltinTypeof(IrBuilder& build, int nparams, i build.inst(IrCmd::STORE_POINTER, build.vmReg(ra), name); build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TSTRING)); - return {BuiltinImplType::UsesFallback, 1}; + return {BuiltinImplType::Full, 1}; } static BuiltinImplResult translateBuiltinBit32BinaryOp( - IrBuilder& build, LuauBuiltinFunction bfid, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) + IrBuilder& build, LuauBuiltinFunction bfid, int nparams, int ra, int arg, IrOp args, int nresults, int pcpos) { if (nparams < 2 || nparams > kBit32BinaryOpUnrolledParams || nresults > 1) return {BuiltinImplType::None, -1}; - builtinCheckDouble(build, build.vmReg(arg), fallback); - builtinCheckDouble(build, args, fallback); + builtinCheckDouble(build, build.vmReg(arg), pcpos); + builtinCheckDouble(build, args, pcpos); for (int i = 3; i <= nparams; ++i) - builtinCheckDouble(build, build.vmReg(vmRegOp(args) + (i - 2)), fallback); + builtinCheckDouble(build, build.vmReg(vmRegOp(args) + (i - 2)), pcpos); IrOp va = builtinLoadDouble(build, build.vmReg(arg)); IrOp vb = builtinLoadDouble(build, args); @@ -433,16 +433,16 @@ static BuiltinImplResult translateBuiltinBit32BinaryOp( build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); } - return {BuiltinImplType::UsesFallback, 1}; + return {BuiltinImplType::Full, 1}; } static BuiltinImplResult translateBuiltinBit32Bnot( - IrBuilder& build, LuauBuiltinFunction bfid, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) + IrBuilder& build, LuauBuiltinFunction bfid, int nparams, int ra, int arg, IrOp args, int nresults, int pcpos) { if (nparams < 1 || nresults > 1) return {BuiltinImplType::None, -1}; - builtinCheckDouble(build, build.vmReg(arg), fallback); + builtinCheckDouble(build, build.vmReg(arg), pcpos); IrOp va = builtinLoadDouble(build, build.vmReg(arg)); IrOp vaui = build.inst(IrCmd::NUM_TO_UINT, va); @@ -454,19 +454,19 @@ static BuiltinImplResult translateBuiltinBit32Bnot( if (ra != arg) build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); - return {BuiltinImplType::UsesFallback, 1}; + return {BuiltinImplType::Full, 1}; } static BuiltinImplResult translateBuiltinBit32Shift( - IrBuilder& build, LuauBuiltinFunction bfid, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) + IrBuilder& build, LuauBuiltinFunction bfid, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback, int pcpos) { if (nparams < 2 || nresults > 1) return {BuiltinImplType::None, -1}; IrOp block = build.block(IrBlockKind::Internal); - builtinCheckDouble(build, build.vmReg(arg), fallback); - builtinCheckDouble(build, args, fallback); + builtinCheckDouble(build, build.vmReg(arg), pcpos); + builtinCheckDouble(build, args, pcpos); IrOp va = builtinLoadDouble(build, build.vmReg(arg)); IrOp vb = builtinLoadDouble(build, args); @@ -499,13 +499,13 @@ static BuiltinImplResult translateBuiltinBit32Shift( } static BuiltinImplResult translateBuiltinBit32Rotate( - IrBuilder& build, LuauBuiltinFunction bfid, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) + IrBuilder& build, LuauBuiltinFunction bfid, int nparams, int ra, int arg, IrOp args, int nresults, int pcpos) { if (nparams < 2 || nresults > 1) return {BuiltinImplType::None, -1}; - builtinCheckDouble(build, build.vmReg(arg), fallback); - builtinCheckDouble(build, args, fallback); + builtinCheckDouble(build, build.vmReg(arg), pcpos); + builtinCheckDouble(build, args, pcpos); IrOp va = builtinLoadDouble(build, build.vmReg(arg)); IrOp vb = builtinLoadDouble(build, args); @@ -522,17 +522,17 @@ static BuiltinImplResult translateBuiltinBit32Rotate( if (ra != arg) build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); - return {BuiltinImplType::UsesFallback, 1}; + return {BuiltinImplType::Full, 1}; } static BuiltinImplResult translateBuiltinBit32Extract( - IrBuilder& build, LuauBuiltinFunction bfid, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) + IrBuilder& build, LuauBuiltinFunction bfid, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback, int pcpos) { if (nparams < 2 || nresults > 1) return {BuiltinImplType::None, -1}; - builtinCheckDouble(build, build.vmReg(arg), fallback); - builtinCheckDouble(build, args, fallback); + builtinCheckDouble(build, build.vmReg(arg), pcpos); + builtinCheckDouble(build, args, pcpos); IrOp va = builtinLoadDouble(build, build.vmReg(arg)); IrOp vb = builtinLoadDouble(build, args); @@ -553,7 +553,7 @@ static BuiltinImplResult translateBuiltinBit32Extract( } else { - builtinCheckDouble(build, build.vmReg(args.index + 1), fallback); + builtinCheckDouble(build, build.vmReg(args.index + 1), pcpos); IrOp vc = builtinLoadDouble(build, build.vmReg(args.index + 1)); IrOp w = build.inst(IrCmd::NUM_TO_INT, vc); @@ -586,12 +586,12 @@ static BuiltinImplResult translateBuiltinBit32Extract( } static BuiltinImplResult translateBuiltinBit32ExtractK( - IrBuilder& build, LuauBuiltinFunction bfid, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) + IrBuilder& build, LuauBuiltinFunction bfid, int nparams, int ra, int arg, IrOp args, int nresults, int pcpos) { if (nparams < 2 || nresults > 1) return {BuiltinImplType::None, -1}; - builtinCheckDouble(build, build.vmReg(arg), fallback); + builtinCheckDouble(build, build.vmReg(arg), pcpos); IrOp va = builtinLoadDouble(build, build.vmReg(arg)); IrOp n = build.inst(IrCmd::NUM_TO_UINT, va); @@ -613,16 +613,16 @@ static BuiltinImplResult translateBuiltinBit32ExtractK( if (ra != arg) build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); - return {BuiltinImplType::UsesFallback, 1}; + return {BuiltinImplType::Full, 1}; } static BuiltinImplResult translateBuiltinBit32Countz( - IrBuilder& build, LuauBuiltinFunction bfid, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) + IrBuilder& build, LuauBuiltinFunction bfid, int nparams, int ra, int arg, IrOp args, int nresults, int pcpos) { if (nparams < 1 || nresults > 1) return {BuiltinImplType::None, -1}; - builtinCheckDouble(build, build.vmReg(arg), fallback); + builtinCheckDouble(build, build.vmReg(arg), pcpos); IrOp va = builtinLoadDouble(build, build.vmReg(arg)); IrOp vaui = build.inst(IrCmd::NUM_TO_UINT, va); @@ -637,18 +637,18 @@ static BuiltinImplResult translateBuiltinBit32Countz( if (ra != arg) build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); - return {BuiltinImplType::UsesFallback, 1}; + return {BuiltinImplType::Full, 1}; } static BuiltinImplResult translateBuiltinBit32Replace( - IrBuilder& build, LuauBuiltinFunction bfid, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) + IrBuilder& build, LuauBuiltinFunction bfid, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback, int pcpos) { if (nparams < 3 || nresults > 1) return {BuiltinImplType::None, -1}; - builtinCheckDouble(build, build.vmReg(arg), fallback); - builtinCheckDouble(build, args, fallback); - builtinCheckDouble(build, build.vmReg(args.index + 1), fallback); + builtinCheckDouble(build, build.vmReg(arg), pcpos); + builtinCheckDouble(build, args, pcpos); + builtinCheckDouble(build, build.vmReg(args.index + 1), pcpos); IrOp va = builtinLoadDouble(build, build.vmReg(arg)); IrOp vb = builtinLoadDouble(build, args); @@ -678,7 +678,7 @@ static BuiltinImplResult translateBuiltinBit32Replace( } else { - builtinCheckDouble(build, build.vmReg(args.index + 2), fallback); + builtinCheckDouble(build, build.vmReg(args.index + 2), pcpos); IrOp vd = builtinLoadDouble(build, build.vmReg(args.index + 2)); IrOp w = build.inst(IrCmd::NUM_TO_INT, vd); @@ -716,16 +716,16 @@ static BuiltinImplResult translateBuiltinBit32Replace( return {BuiltinImplType::UsesFallback, 1}; } -static BuiltinImplResult translateBuiltinVector(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) +static BuiltinImplResult translateBuiltinVector(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, int pcpos) { if (nparams < 3 || nresults > 1) return {BuiltinImplType::None, -1}; LUAU_ASSERT(LUA_VECTOR_SIZE == 3); - builtinCheckDouble(build, build.vmReg(arg), fallback); - builtinCheckDouble(build, args, fallback); - builtinCheckDouble(build, build.vmReg(vmRegOp(args) + 1), fallback); + builtinCheckDouble(build, build.vmReg(arg), pcpos); + builtinCheckDouble(build, args, pcpos); + builtinCheckDouble(build, build.vmReg(vmRegOp(args) + 1), pcpos); IrOp x = builtinLoadDouble(build, build.vmReg(arg)); IrOp y = builtinLoadDouble(build, args); @@ -734,15 +734,15 @@ static BuiltinImplResult translateBuiltinVector(IrBuilder& build, int nparams, i build.inst(IrCmd::STORE_VECTOR, build.vmReg(ra), x, y, z); build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TVECTOR)); - return {BuiltinImplType::UsesFallback, 1}; + return {BuiltinImplType::Full, 1}; } -static BuiltinImplResult translateBuiltinStringLen(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) +static BuiltinImplResult translateBuiltinStringLen(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, int pcpos) { if (nparams < 1 || nresults > 1) return {BuiltinImplType::None, -1}; - build.loadAndCheckTag(build.vmReg(arg), LUA_TSTRING, fallback); + build.loadAndCheckTag(build.vmReg(arg), LUA_TSTRING, build.vmExit(pcpos)); IrOp ts = build.inst(IrCmd::LOAD_POINTER, build.vmReg(arg)); @@ -751,10 +751,10 @@ static BuiltinImplResult translateBuiltinStringLen(IrBuilder& build, int nparams build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), build.inst(IrCmd::INT_TO_NUM, len)); build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); - return {BuiltinImplType::UsesFallback, 1}; + return {BuiltinImplType::Full, 1}; } -BuiltinImplResult translateBuiltin(IrBuilder& build, int bfid, int ra, int arg, IrOp args, int nparams, int nresults, IrOp fallback) +BuiltinImplResult translateBuiltin(IrBuilder& build, int bfid, int ra, int arg, IrOp args, int nparams, int nresults, IrOp fallback, int pcpos) { // Builtins are not allowed to handle variadic arguments if (nparams == LUA_MULTRET) @@ -765,27 +765,27 @@ BuiltinImplResult translateBuiltin(IrBuilder& build, int bfid, int ra, int arg, case LBF_ASSERT: return translateBuiltinAssert(build, nparams, ra, arg, args, nresults, fallback); case LBF_MATH_DEG: - return translateBuiltinMathDeg(build, nparams, ra, arg, args, nresults, fallback); + return translateBuiltinMathDeg(build, nparams, ra, arg, args, nresults, pcpos); case LBF_MATH_RAD: - return translateBuiltinMathRad(build, nparams, ra, arg, args, nresults, fallback); + return translateBuiltinMathRad(build, nparams, ra, arg, args, nresults, pcpos); case LBF_MATH_LOG: - return translateBuiltinMathLog(build, LuauBuiltinFunction(bfid), nparams, ra, arg, args, nresults, fallback); + return translateBuiltinMathLog(build, LuauBuiltinFunction(bfid), nparams, ra, arg, args, nresults, pcpos); case LBF_MATH_MIN: - return translateBuiltinMathMin(build, nparams, ra, arg, args, nresults, fallback); + return translateBuiltinMathMin(build, nparams, ra, arg, args, nresults, pcpos); case LBF_MATH_MAX: - return translateBuiltinMathMax(build, nparams, ra, arg, args, nresults, fallback); + return translateBuiltinMathMax(build, nparams, ra, arg, args, nresults, pcpos); case LBF_MATH_CLAMP: - return translateBuiltinMathClamp(build, nparams, ra, arg, args, nresults, fallback); + return translateBuiltinMathClamp(build, nparams, ra, arg, args, nresults, fallback, pcpos); case LBF_MATH_FLOOR: - return translateBuiltinMathUnary(build, IrCmd::FLOOR_NUM, nparams, ra, arg, nresults, fallback); + return translateBuiltinMathUnary(build, IrCmd::FLOOR_NUM, nparams, ra, arg, nresults, pcpos); case LBF_MATH_CEIL: - return translateBuiltinMathUnary(build, IrCmd::CEIL_NUM, nparams, ra, arg, nresults, fallback); + return translateBuiltinMathUnary(build, IrCmd::CEIL_NUM, nparams, ra, arg, nresults, pcpos); case LBF_MATH_SQRT: - return translateBuiltinMathUnary(build, IrCmd::SQRT_NUM, nparams, ra, arg, nresults, fallback); + return translateBuiltinMathUnary(build, IrCmd::SQRT_NUM, nparams, ra, arg, nresults, pcpos); case LBF_MATH_ABS: - return translateBuiltinMathUnary(build, IrCmd::ABS_NUM, nparams, ra, arg, nresults, fallback); + return translateBuiltinMathUnary(build, IrCmd::ABS_NUM, nparams, ra, arg, nresults, pcpos); case LBF_MATH_ROUND: - return translateBuiltinMathUnary(build, IrCmd::ROUND_NUM, nparams, ra, arg, nresults, fallback); + return translateBuiltinMathUnary(build, IrCmd::ROUND_NUM, nparams, ra, arg, nresults, pcpos); case LBF_MATH_EXP: case LBF_MATH_ASIN: case LBF_MATH_SIN: @@ -797,49 +797,49 @@ BuiltinImplResult translateBuiltin(IrBuilder& build, int bfid, int ra, int arg, case LBF_MATH_TAN: case LBF_MATH_TANH: case LBF_MATH_LOG10: - return translateBuiltinNumberToNumberLibm(build, LuauBuiltinFunction(bfid), nparams, ra, arg, args, nresults, fallback); + return translateBuiltinNumberToNumberLibm(build, LuauBuiltinFunction(bfid), nparams, ra, arg, args, nresults, pcpos); case LBF_MATH_SIGN: - return translateBuiltinNumberToNumber(build, LuauBuiltinFunction(bfid), nparams, ra, arg, args, nresults, fallback); + return translateBuiltinNumberToNumber(build, LuauBuiltinFunction(bfid), nparams, ra, arg, args, nresults, pcpos); case LBF_MATH_POW: case LBF_MATH_FMOD: case LBF_MATH_ATAN2: - return translateBuiltin2NumberToNumberLibm(build, LuauBuiltinFunction(bfid), nparams, ra, arg, args, nresults, fallback); + return translateBuiltin2NumberToNumberLibm(build, LuauBuiltinFunction(bfid), nparams, ra, arg, args, nresults, pcpos); case LBF_MATH_LDEXP: - return translateBuiltinMathLdexp(build, LuauBuiltinFunction(bfid), nparams, ra, arg, args, nresults, fallback); + return translateBuiltinMathLdexp(build, LuauBuiltinFunction(bfid), nparams, ra, arg, args, nresults, pcpos); case LBF_MATH_FREXP: case LBF_MATH_MODF: - return translateBuiltinNumberTo2Number(build, LuauBuiltinFunction(bfid), nparams, ra, arg, args, nresults, fallback); + return translateBuiltinNumberTo2Number(build, LuauBuiltinFunction(bfid), nparams, ra, arg, args, nresults, pcpos); case LBF_BIT32_BAND: case LBF_BIT32_BOR: case LBF_BIT32_BXOR: case LBF_BIT32_BTEST: - return translateBuiltinBit32BinaryOp(build, LuauBuiltinFunction(bfid), nparams, ra, arg, args, nresults, fallback); + return translateBuiltinBit32BinaryOp(build, LuauBuiltinFunction(bfid), nparams, ra, arg, args, nresults, pcpos); case LBF_BIT32_BNOT: - return translateBuiltinBit32Bnot(build, LuauBuiltinFunction(bfid), nparams, ra, arg, args, nresults, fallback); + return translateBuiltinBit32Bnot(build, LuauBuiltinFunction(bfid), nparams, ra, arg, args, nresults, pcpos); case LBF_BIT32_LSHIFT: case LBF_BIT32_RSHIFT: case LBF_BIT32_ARSHIFT: - return translateBuiltinBit32Shift(build, LuauBuiltinFunction(bfid), nparams, ra, arg, args, nresults, fallback); + return translateBuiltinBit32Shift(build, LuauBuiltinFunction(bfid), nparams, ra, arg, args, nresults, fallback, pcpos); case LBF_BIT32_LROTATE: case LBF_BIT32_RROTATE: - return translateBuiltinBit32Rotate(build, LuauBuiltinFunction(bfid), nparams, ra, arg, args, nresults, fallback); + return translateBuiltinBit32Rotate(build, LuauBuiltinFunction(bfid), nparams, ra, arg, args, nresults, pcpos); case LBF_BIT32_EXTRACT: - return translateBuiltinBit32Extract(build, LuauBuiltinFunction(bfid), nparams, ra, arg, args, nresults, fallback); + return translateBuiltinBit32Extract(build, LuauBuiltinFunction(bfid), nparams, ra, arg, args, nresults, fallback, pcpos); case LBF_BIT32_EXTRACTK: - return translateBuiltinBit32ExtractK(build, LuauBuiltinFunction(bfid), nparams, ra, arg, args, nresults, fallback); + return translateBuiltinBit32ExtractK(build, LuauBuiltinFunction(bfid), nparams, ra, arg, args, nresults, pcpos); case LBF_BIT32_COUNTLZ: case LBF_BIT32_COUNTRZ: - return translateBuiltinBit32Countz(build, LuauBuiltinFunction(bfid), nparams, ra, arg, args, nresults, fallback); + return translateBuiltinBit32Countz(build, LuauBuiltinFunction(bfid), nparams, ra, arg, args, nresults, pcpos); case LBF_BIT32_REPLACE: - return translateBuiltinBit32Replace(build, LuauBuiltinFunction(bfid), nparams, ra, arg, args, nresults, fallback); + return translateBuiltinBit32Replace(build, LuauBuiltinFunction(bfid), nparams, ra, arg, args, nresults, fallback, pcpos); case LBF_TYPE: - return translateBuiltinType(build, nparams, ra, arg, args, nresults, fallback); + return translateBuiltinType(build, nparams, ra, arg, args, nresults); case LBF_TYPEOF: - return translateBuiltinTypeof(build, nparams, ra, arg, args, nresults, fallback); + return translateBuiltinTypeof(build, nparams, ra, arg, args, nresults); case LBF_VECTOR: - return translateBuiltinVector(build, nparams, ra, arg, args, nresults, fallback); + return translateBuiltinVector(build, nparams, ra, arg, args, nresults, pcpos); case LBF_STRING_LEN: - return translateBuiltinStringLen(build, nparams, ra, arg, args, nresults, fallback); + return translateBuiltinStringLen(build, nparams, ra, arg, args, nresults, pcpos); default: return {BuiltinImplType::None, -1}; } diff --git a/CodeGen/src/IrTranslateBuiltins.h b/CodeGen/src/IrTranslateBuiltins.h index 945b32f3f..8ae64b945 100644 --- a/CodeGen/src/IrTranslateBuiltins.h +++ b/CodeGen/src/IrTranslateBuiltins.h @@ -13,6 +13,7 @@ enum class BuiltinImplType { None, UsesFallback, // Uses fallback for unsupported cases + Full, // Is either implemented in full, or exits to VM }; struct BuiltinImplResult @@ -21,7 +22,7 @@ struct BuiltinImplResult int actualResultCount; }; -BuiltinImplResult translateBuiltin(IrBuilder& build, int bfid, int ra, int arg, IrOp args, int nparams, int nresults, IrOp fallback); +BuiltinImplResult translateBuiltin(IrBuilder& build, int bfid, int ra, int arg, IrOp args, int nparams, int nresults, IrOp fallback, int pcpos); } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/IrTranslation.cpp b/CodeGen/src/IrTranslation.cpp index 3cbcd3cbd..5cde510ff 100644 --- a/CodeGen/src/IrTranslation.cpp +++ b/CodeGen/src/IrTranslation.cpp @@ -514,7 +514,7 @@ void translateInstCloseUpvals(IrBuilder& build, const Instruction* pc) build.inst(IrCmd::CLOSE_UPVALS, build.vmReg(ra)); } -void translateFastCallN(IrBuilder& build, const Instruction* pc, int pcpos, bool customParams, int customParamCount, IrOp customArgs, IrOp next) +IrOp translateFastCallN(IrBuilder& build, const Instruction* pc, int pcpos, bool customParams, int customParamCount, IrOp customArgs) { LuauOpcode opcode = LuauOpcode(LUAU_INSN_OP(*pc)); int bfid = LUAU_INSN_A(*pc); @@ -542,16 +542,25 @@ void translateFastCallN(IrBuilder& build, const Instruction* pc, int pcpos, bool IrOp fallback = build.block(IrBlockKind::Fallback); // In unsafe environment, instead of retrying fastcall at 'pcpos' we side-exit directly to fallback sequence - build.inst(IrCmd::CHECK_SAFE_ENV, build.constUint(pcpos + getOpLength(opcode))); + build.inst(IrCmd::CHECK_SAFE_ENV, build.vmExit(pcpos + getOpLength(opcode))); - BuiltinImplResult br = translateBuiltin(build, LuauBuiltinFunction(bfid), ra, arg, builtinArgs, nparams, nresults, fallback); + BuiltinImplResult br = + translateBuiltin(build, LuauBuiltinFunction(bfid), ra, arg, builtinArgs, nparams, nresults, fallback, pcpos + getOpLength(opcode)); - if (br.type == BuiltinImplType::UsesFallback) + if (br.type != BuiltinImplType::None) { LUAU_ASSERT(nparams != LUA_MULTRET && "builtins are not allowed to handle variadic arguments"); if (nresults == LUA_MULTRET) build.inst(IrCmd::ADJUST_STACK_TO_REG, build.vmReg(ra), build.constInt(br.actualResultCount)); + + if (br.type != BuiltinImplType::UsesFallback) + { + // We ended up not using the fallback block, kill it + build.function.blockOp(fallback).kind = IrBlockKind::Dead; + + return build.undef(); + } } else { @@ -568,10 +577,7 @@ void translateFastCallN(IrBuilder& build, const Instruction* pc, int pcpos, bool build.inst(IrCmd::ADJUST_STACK_TO_TOP); } - build.inst(IrCmd::JUMP, next); - - // this will be filled with IR corresponding to instructions after FASTCALL until skip+1 - build.beginBlock(fallback); + return fallback; } void translateInstForNPrep(IrBuilder& build, const Instruction* pc, int pcpos) @@ -670,7 +676,7 @@ void translateInstForGPrepNext(IrBuilder& build, const Instruction* pc, int pcpo IrOp fallback = build.block(IrBlockKind::Fallback); // fast-path: pairs/next - build.inst(IrCmd::CHECK_SAFE_ENV, build.constUint(pcpos)); + build.inst(IrCmd::CHECK_SAFE_ENV, build.vmExit(pcpos)); IrOp tagB = build.inst(IrCmd::LOAD_TAG, build.vmReg(ra + 1)); build.inst(IrCmd::CHECK_TAG, tagB, build.constTag(LUA_TTABLE), fallback); IrOp tagC = build.inst(IrCmd::LOAD_TAG, build.vmReg(ra + 2)); @@ -697,7 +703,7 @@ void translateInstForGPrepInext(IrBuilder& build, const Instruction* pc, int pcp IrOp finish = build.block(IrBlockKind::Internal); // fast-path: ipairs/inext - build.inst(IrCmd::CHECK_SAFE_ENV, build.constUint(pcpos)); + build.inst(IrCmd::CHECK_SAFE_ENV, build.vmExit(pcpos)); IrOp tagB = build.inst(IrCmd::LOAD_TAG, build.vmReg(ra + 1)); build.inst(IrCmd::CHECK_TAG, tagB, build.constTag(LUA_TTABLE), fallback); IrOp tagC = build.inst(IrCmd::LOAD_TAG, build.vmReg(ra + 2)); @@ -923,7 +929,7 @@ void translateInstGetImport(IrBuilder& build, const Instruction* pc, int pcpos) IrOp fastPath = build.block(IrBlockKind::Internal); IrOp fallback = build.block(IrBlockKind::Fallback); - build.inst(IrCmd::CHECK_SAFE_ENV, build.constUint(pcpos)); + build.inst(IrCmd::CHECK_SAFE_ENV, build.vmExit(pcpos)); // note: if import failed, k[] is nil; we could check this during codegen, but we instead use runtime fallback // this allows us to handle ahead-of-time codegen smoothly when an import fails to resolve at runtime diff --git a/CodeGen/src/IrTranslation.h b/CodeGen/src/IrTranslation.h index 38dcdd40f..0c24b27da 100644 --- a/CodeGen/src/IrTranslation.h +++ b/CodeGen/src/IrTranslation.h @@ -45,7 +45,7 @@ void translateInstDupTable(IrBuilder& build, const Instruction* pc, int pcpos); void translateInstGetUpval(IrBuilder& build, const Instruction* pc, int pcpos); void translateInstSetUpval(IrBuilder& build, const Instruction* pc, int pcpos); void translateInstCloseUpvals(IrBuilder& build, const Instruction* pc); -void translateFastCallN(IrBuilder& build, const Instruction* pc, int pcpos, bool customParams, int customParamCount, IrOp customArgs, IrOp next); +IrOp translateFastCallN(IrBuilder& build, const Instruction* pc, int pcpos, bool customParams, int customParamCount, IrOp customArgs); void translateInstForNPrep(IrBuilder& build, const Instruction* pc, int pcpos); void translateInstForNLoop(IrBuilder& build, const Instruction* pc, int pcpos); void translateInstForGPrepNext(IrBuilder& build, const Instruction* pc, int pcpos); diff --git a/CodeGen/src/OptimizeConstProp.cpp b/CodeGen/src/OptimizeConstProp.cpp index eeedd6cb0..72869ad13 100644 --- a/CodeGen/src/OptimizeConstProp.cpp +++ b/CodeGen/src/OptimizeConstProp.cpp @@ -977,7 +977,7 @@ static void constPropInBlockChain(IrBuilder& build, std::vector& visite // Unconditional jump into a block with a single user (current block) allows us to continue optimization // with the information we have gathered so far (unless we have already visited that block earlier) - if (termInst.cmd == IrCmd::JUMP) + if (termInst.cmd == IrCmd::JUMP && termInst.a.kind != IrOpKind::VmExit) { IrBlock& target = function.blockOp(termInst.a); uint32_t targetIdx = function.getBlockIndex(target); @@ -1011,7 +1011,7 @@ static std::vector collectDirectBlockJumpPath(IrFunction& function, st IrBlock* nextBlock = nullptr; // A chain is made from internal blocks that were not a part of bytecode CFG - if (termInst.cmd == IrCmd::JUMP) + if (termInst.cmd == IrCmd::JUMP && termInst.a.kind != IrOpKind::VmExit) { IrBlock& target = function.blockOp(termInst.a); uint32_t targetIdx = function.getBlockIndex(target); @@ -1052,6 +1052,10 @@ static void tryCreateLinearBlock(IrBuilder& build, std::vector& visited if (termInst.cmd != IrCmd::JUMP) return; + // And it can't be jump to a VM exit + if (termInst.a.kind == IrOpKind::VmExit) + return; + // And it has to jump to a block with more than one user // If there's only one use, it should already be optimized by constPropInBlockChain if (function.blockOp(termInst.a).useCount == 1) @@ -1084,7 +1088,8 @@ static void tryCreateLinearBlock(IrBuilder& build, std::vector& visited build.beginBlock(newBlock); - // By default, blocks are ordered according to start instruction; we alter sort order to make sure linearized block is placed right after the starting block + // By default, blocks are ordered according to start instruction; we alter sort order to make sure linearized block is placed right after the + // starting block function.blocks[newBlock.index].sortkey = startingInsn + 1; replace(function, termInst.a, newBlock); diff --git a/Common/include/Luau/Bytecode.h b/Common/include/Luau/Bytecode.h index eab57b17a..7b3a057b7 100644 --- a/Common/include/Luau/Bytecode.h +++ b/Common/include/Luau/Bytecode.h @@ -428,7 +428,7 @@ enum LuauBytecodeTag }; // Type table tags -enum LuauBytecodeEncodedType +enum LuauBytecodeType { LBC_TYPE_NIL = 0, LBC_TYPE_BOOLEAN, diff --git a/Compiler/include/Luau/BytecodeBuilder.h b/Compiler/include/Luau/BytecodeBuilder.h index f3c2f47d7..3044e4483 100644 --- a/Compiler/include/Luau/BytecodeBuilder.h +++ b/Compiler/include/Luau/BytecodeBuilder.h @@ -47,7 +47,7 @@ class BytecodeBuilder BytecodeBuilder(BytecodeEncoder* encoder = 0); uint32_t beginFunction(uint8_t numparams, bool isvararg = false); - void endFunction(uint8_t maxstacksize, uint8_t numupvalues); + void endFunction(uint8_t maxstacksize, uint8_t numupvalues, uint8_t flags = 0); void setMainFunction(uint32_t fid); @@ -274,7 +274,7 @@ class BytecodeBuilder void dumpConstant(std::string& result, int k) const; void dumpInstruction(const uint32_t* opcode, std::string& output, int targetLabel) const; - void writeFunction(std::string& ss, uint32_t id) const; + void writeFunction(std::string& ss, uint32_t id, uint8_t flags) const; void writeLineInfo(std::string& ss) const; void writeStringTable(std::string& ss) const; diff --git a/Compiler/include/Luau/Compiler.h b/Compiler/include/Luau/Compiler.h index eec70d7a1..36a21a72c 100644 --- a/Compiler/include/Luau/Compiler.h +++ b/Compiler/include/Luau/Compiler.h @@ -35,6 +35,9 @@ struct CompileOptions const char* vectorLib = nullptr; const char* vectorCtor = nullptr; + // vector type name for type tables; disabled by default + const char* vectorType = nullptr; + // null-terminated array of globals that are mutable; disables the import optimization for fields accessed through these const char** mutableGlobals = nullptr; }; diff --git a/Compiler/include/luacode.h b/Compiler/include/luacode.h index 5f69f69e5..7c59ce0b2 100644 --- a/Compiler/include/luacode.h +++ b/Compiler/include/luacode.h @@ -31,6 +31,9 @@ struct lua_CompileOptions const char* vectorLib; const char* vectorCtor; + // vector type name for type tables; disabled by default + const char* vectorType; + // null-terminated array of globals that are mutable; disables the import optimization for fields accessed through these const char** mutableGlobals; }; diff --git a/Compiler/src/BytecodeBuilder.cpp b/Compiler/src/BytecodeBuilder.cpp index 8d360f875..eeb9c10e0 100644 --- a/Compiler/src/BytecodeBuilder.cpp +++ b/Compiler/src/BytecodeBuilder.cpp @@ -249,7 +249,7 @@ uint32_t BytecodeBuilder::beginFunction(uint8_t numparams, bool isvararg) return id; } -void BytecodeBuilder::endFunction(uint8_t maxstacksize, uint8_t numupvalues) +void BytecodeBuilder::endFunction(uint8_t maxstacksize, uint8_t numupvalues, uint8_t flags) { LUAU_ASSERT(currentFunction != ~0u); @@ -265,7 +265,7 @@ void BytecodeBuilder::endFunction(uint8_t maxstacksize, uint8_t numupvalues) // very approximate: 4 bytes per instruction for code, 1 byte for debug line, and 1-2 bytes for aux data like constants plus overhead func.data.reserve(32 + insns.size() * 7); - writeFunction(func.data, currentFunction); + writeFunction(func.data, currentFunction, flags); currentFunction = ~0u; @@ -631,7 +631,7 @@ void BytecodeBuilder::finalize() writeVarInt(bytecode, mainFunction); } -void BytecodeBuilder::writeFunction(std::string& ss, uint32_t id) const +void BytecodeBuilder::writeFunction(std::string& ss, uint32_t id, uint8_t flags) const { LUAU_ASSERT(id < functions.size()); const Function& func = functions[id]; @@ -644,7 +644,7 @@ void BytecodeBuilder::writeFunction(std::string& ss, uint32_t id) const if (FFlag::BytecodeVersion4) { - writeByte(ss, 0); // Reserved for cgflags + writeByte(ss, flags); writeVarInt(ss, uint32_t(func.typeinfo.size())); ss.append(func.typeinfo); @@ -1213,10 +1213,15 @@ void BytecodeBuilder::validateInstructions() const break; case LOP_GETIMPORT: + { VREG(LUAU_INSN_A(insn)); VCONST(LUAU_INSN_D(insn), Import); - // TODO: check insn[i + 1] for conformance with 10-bit import encoding - break; + uint32_t id = insns[i + 1]; + LUAU_ASSERT((id >> 30) != 0); // import chain with length 1-3 + for (unsigned int j = 0; j < (id >> 30); ++j) + VCONST((id >> (20 - 10 * j)) & 1023, String); + } + break; case LOP_GETTABLE: case LOP_SETTABLE: diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index 83aad3d82..fe65f67a1 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -3874,7 +3874,7 @@ void compileOrThrow(BytecodeBuilder& bytecode, const ParseResult& parseResult, c if (FFlag::LuauCompileFunctionType) { - buildTypeMap(compiler.typeMap, root); + buildTypeMap(compiler.typeMap, root, options.vectorType); } // gathers all functions with the invariant that all function references are to functions earlier in the list diff --git a/Compiler/src/Types.cpp b/Compiler/src/Types.cpp index b8db7fb27..e85cc92c4 100644 --- a/Compiler/src/Types.cpp +++ b/Compiler/src/Types.cpp @@ -15,7 +15,7 @@ static bool isGeneric(AstName name, const AstArray& generics) return false; } -static LuauBytecodeEncodedType getPrimitiveType(AstName name) +static LuauBytecodeType getPrimitiveType(AstName name) { if (name == "nil") return LBC_TYPE_NIL; @@ -33,8 +33,8 @@ static LuauBytecodeEncodedType getPrimitiveType(AstName name) return LBC_TYPE_INVALID; } -static LuauBytecodeEncodedType getType( - AstType* ty, const AstArray& generics, const DenseHashMap& typeAliases, bool resolveAliases) +static LuauBytecodeType getType(AstType* ty, const AstArray& generics, const DenseHashMap& typeAliases, + bool resolveAliases, const char* vectorType) { if (AstTypeReference* ref = ty->as()) { @@ -45,7 +45,7 @@ static LuauBytecodeEncodedType getType( { // note: we only resolve aliases to the depth of 1 to avoid dealing with recursive aliases if (resolveAliases) - return getType((*alias)->type, (*alias)->generics, typeAliases, /* resolveAliases= */ false); + return getType((*alias)->type, (*alias)->generics, typeAliases, /* resolveAliases= */ false, vectorType); else return LBC_TYPE_ANY; } @@ -53,7 +53,10 @@ static LuauBytecodeEncodedType getType( if (isGeneric(ref->name, generics)) return LBC_TYPE_ANY; - if (LuauBytecodeEncodedType prim = getPrimitiveType(ref->name); prim != LBC_TYPE_INVALID) + if (vectorType && ref->name == vectorType) + return LBC_TYPE_VECTOR; + + if (LuauBytecodeType prim = getPrimitiveType(ref->name); prim != LBC_TYPE_INVALID) return prim; // not primitive or alias or generic => host-provided, we assume userdata for now @@ -70,11 +73,11 @@ static LuauBytecodeEncodedType getType( else if (AstTypeUnion* un = ty->as()) { bool optional = false; - LuauBytecodeEncodedType type = LBC_TYPE_INVALID; + LuauBytecodeType type = LBC_TYPE_INVALID; for (AstType* ty : un->types) { - LuauBytecodeEncodedType et = getType(ty, generics, typeAliases, resolveAliases); + LuauBytecodeType et = getType(ty, generics, typeAliases, resolveAliases, vectorType); if (et == LBC_TYPE_NIL) { @@ -95,7 +98,7 @@ static LuauBytecodeEncodedType getType( if (type == LBC_TYPE_INVALID) return LBC_TYPE_ANY; - return LuauBytecodeEncodedType(type | (optional && (type != LBC_TYPE_ANY) ? LBC_TYPE_OPTIONAL_BIT : 0)); + return LuauBytecodeType(type | (optional && (type != LBC_TYPE_ANY) ? LBC_TYPE_OPTIONAL_BIT : 0)); } else if (AstTypeIntersection* inter = ty->as()) { @@ -105,7 +108,7 @@ static LuauBytecodeEncodedType getType( return LBC_TYPE_ANY; } -static std::string getFunctionType(const AstExprFunction* func, const DenseHashMap& typeAliases) +static std::string getFunctionType(const AstExprFunction* func, const DenseHashMap& typeAliases, const char* vectorType) { bool self = func->self != 0; @@ -121,8 +124,8 @@ static std::string getFunctionType(const AstExprFunction* func, const DenseHashM bool haveNonAnyParam = false; for (AstLocal* arg : func->args) { - LuauBytecodeEncodedType ty = - arg->annotation ? getType(arg->annotation, func->generics, typeAliases, /* resolveAliases= */ true) : LBC_TYPE_ANY; + LuauBytecodeType ty = + arg->annotation ? getType(arg->annotation, func->generics, typeAliases, /* resolveAliases= */ true, vectorType) : LBC_TYPE_ANY; if (ty != LBC_TYPE_ANY) haveNonAnyParam = true; @@ -140,12 +143,14 @@ static std::string getFunctionType(const AstExprFunction* func, const DenseHashM struct TypeMapVisitor : AstVisitor { DenseHashMap& typeMap; + const char* vectorType; DenseHashMap typeAliases; std::vector> typeAliasStack; - TypeMapVisitor(DenseHashMap& typeMap) + TypeMapVisitor(DenseHashMap& typeMap, const char* vectorType) : typeMap(typeMap) + , vectorType(vectorType) , typeAliases(AstName()) { } @@ -206,7 +211,7 @@ struct TypeMapVisitor : AstVisitor bool visit(AstExprFunction* node) override { - std::string type = getFunctionType(node, typeAliases); + std::string type = getFunctionType(node, typeAliases, vectorType); if (!type.empty()) typeMap[node] = std::move(type); @@ -215,9 +220,9 @@ struct TypeMapVisitor : AstVisitor } }; -void buildTypeMap(DenseHashMap& typeMap, AstNode* root) +void buildTypeMap(DenseHashMap& typeMap, AstNode* root, const char* vectorType) { - TypeMapVisitor visitor(typeMap); + TypeMapVisitor visitor(typeMap, vectorType); root->visit(&visitor); } diff --git a/Compiler/src/Types.h b/Compiler/src/Types.h index c3dd16209..cad55ab54 100644 --- a/Compiler/src/Types.h +++ b/Compiler/src/Types.h @@ -8,6 +8,6 @@ namespace Luau { -void buildTypeMap(DenseHashMap& typeMap, AstNode* root); +void buildTypeMap(DenseHashMap& typeMap, AstNode* root, const char* vectorType); } // namespace Luau diff --git a/Sources.cmake b/Sources.cmake index ccf2e1df9..2a58f061d 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -143,6 +143,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/include/Luau/Autocomplete.h Analysis/include/Luau/Breadcrumb.h Analysis/include/Luau/BuiltinDefinitions.h + Analysis/include/Luau/Cancellation.h Analysis/include/Luau/Clone.h Analysis/include/Luau/Config.h Analysis/include/Luau/Constraint.h diff --git a/VM/src/ldebug.cpp b/VM/src/ldebug.cpp index d3e21f5de..cb3c9d3bb 100644 --- a/VM/src/ldebug.cpp +++ b/VM/src/ldebug.cpp @@ -35,6 +35,10 @@ int lua_getargument(lua_State* L, int level, int n) return 0; CallInfo* ci = L->ci - level; + // changing tables in native functions externally may invalidate safety contracts wrt table state (metatable/size/readonly) + if (ci->flags & LUA_CALLINFO_NATIVE) + return 0; + Proto* fp = getluaproto(ci); int res = 0; @@ -60,9 +64,13 @@ int lua_getargument(lua_State* L, int level, int n) const char* lua_getlocal(lua_State* L, int level, int n) { if (unsigned(level) >= unsigned(L->ci - L->base_ci)) - return 0; + return NULL; CallInfo* ci = L->ci - level; + // changing tables in native functions externally may invalidate safety contracts wrt table state (metatable/size/readonly) + if (ci->flags & LUA_CALLINFO_NATIVE) + return NULL; + Proto* fp = getluaproto(ci); const LocVar* var = fp ? luaF_getlocal(fp, n, currentpc(L, ci)) : NULL; if (var) @@ -77,9 +85,13 @@ const char* lua_getlocal(lua_State* L, int level, int n) const char* lua_setlocal(lua_State* L, int level, int n) { if (unsigned(level) >= unsigned(L->ci - L->base_ci)) - return 0; + return NULL; CallInfo* ci = L->ci - level; + // changing registers in native functions externally may invalidate safety contracts wrt register type tags + if (ci->flags & LUA_CALLINFO_NATIVE) + return NULL; + Proto* fp = getluaproto(ci); const LocVar* var = fp ? luaF_getlocal(fp, n, currentpc(L, ci)) : NULL; if (var) @@ -321,7 +333,8 @@ void luaG_pusherror(lua_State* L, const char* error) void luaG_breakpoint(lua_State* L, Proto* p, int line, bool enable) { - if (p->lineinfo) + // since native code doesn't support breakpoints, we would need to update all call frames with LUAU_CALLINFO_NATIVE that refer to p + if (p->lineinfo && !p->execdata) { for (int i = 0; i < p->sizecode; ++i) { @@ -347,11 +360,6 @@ void luaG_breakpoint(lua_State* L, Proto* p, int line, bool enable) p->code[i] |= op; LUAU_ASSERT(LUAU_INSN_OP(p->code[i]) == op); -#if LUA_CUSTOM_EXECUTION - if (L->global->ecb.setbreakpoint) - L->global->ecb.setbreakpoint(L, p, i); -#endif - // note: this is important! // we only patch the *first* instruction in each proto that's attributed to a given line // this can be changed, but if requires making patching a bit more nuanced so that we don't patch AUX words @@ -410,11 +418,11 @@ static int getmaxline(Proto* p) return result; } -// Find the line number with instructions. If the provided line doesn't have any instruction, it should return the next line number with -// instructions. +// Find the line number with instructions. If the provided line doesn't have any instruction, it should return the next valid line number. static int getnextline(Proto* p, int line) { int closest = -1; + if (p->lineinfo) { for (int i = 0; i < p->sizecode; ++i) @@ -435,7 +443,6 @@ static int getnextline(Proto* p, int line) for (int i = 0; i < p->sizep; ++i) { - // Find the closest line number to the intended one. int candidate = getnextline(p->p[i], line); if (candidate == line) @@ -454,14 +461,12 @@ int lua_breakpoint(lua_State* L, int funcindex, int line, int enabled) api_check(L, ttisfunction(func) && !clvalue(func)->isC); Proto* p = clvalue(func)->l.p; - // Find line number to add the breakpoint to. + + // set the breakpoint to the next closest line with valid instructions int target = getnextline(p, line); if (target != -1) - { - // Add breakpoint on the exact line luaG_breakpoint(L, p, target, bool(enabled)); - } return target; } diff --git a/VM/src/lfunc.cpp b/VM/src/lfunc.cpp index a47ad34f3..88a3e40ab 100644 --- a/VM/src/lfunc.cpp +++ b/VM/src/lfunc.cpp @@ -22,6 +22,7 @@ Proto* luaF_newproto(lua_State* L) f->numparams = 0; f->is_vararg = 0; f->maxstacksize = 0; + f->flags = 0; f->sizelineinfo = 0; f->linegaplog2 = 0; f->lineinfo = NULL; @@ -155,13 +156,8 @@ void luaF_freeproto(lua_State* L, Proto* f, lua_Page* page) if (f->debuginsn) luaM_freearray(L, f->debuginsn, f->sizecode, uint8_t, f->memcat); -#if LUA_CUSTOM_EXECUTION if (f->execdata) - { - LUAU_ASSERT(L->global->ecb.destroy); L->global->ecb.destroy(L, f); - } -#endif if (f->typeinfo) luaM_freearray(L, f->typeinfo, f->numparams + 2, uint8_t, f->memcat); diff --git a/VM/src/lgc.h b/VM/src/lgc.h index 51216bd8e..ec7a6828f 100644 --- a/VM/src/lgc.h +++ b/VM/src/lgc.h @@ -134,5 +134,7 @@ LUAI_FUNC void luaC_barriertable(lua_State* L, Table* t, GCObject* v); LUAI_FUNC void luaC_barrierback(lua_State* L, GCObject* o, GCObject** gclist); LUAI_FUNC void luaC_validate(lua_State* L); LUAI_FUNC void luaC_dump(lua_State* L, void* file, const char* (*categoryName)(lua_State* L, uint8_t memcat)); +LUAI_FUNC void luaC_enumheap(lua_State* L, void* context, void (*node)(void* context, void* ptr, uint8_t tt, uint8_t memcat, const char* name), + void (*edge)(void* context, void* from, void* to, const char* name)); LUAI_FUNC int64_t luaC_allocationrate(lua_State* L); LUAI_FUNC const char* luaC_statename(int state); diff --git a/VM/src/lgcdebug.cpp b/VM/src/lgcdebug.cpp index 2f9c1756a..e68b84a9d 100644 --- a/VM/src/lgcdebug.cpp +++ b/VM/src/lgcdebug.cpp @@ -602,3 +602,229 @@ void luaC_dump(lua_State* L, void* file, const char* (*categoryName)(lua_State* fprintf(f, "}\n"); fprintf(f, "}}\n"); } + +struct EnumContext +{ + lua_State* L; + void* context; + void (*node)(void* context, void* ptr, uint8_t tt, uint8_t memcat, const char* name); + void (*edge)(void* context, void* from, void* to, const char* name); +}; + +static void* enumtopointer(GCObject* gco) +{ + // To match lua_topointer, userdata pointer is represented as a pointer to internal data + return gco->gch.tt == LUA_TUSERDATA ? (void*)gco2u(gco)->data : (void*)gco; +} + +static void enumnode(EnumContext* ctx, GCObject* gco, const char* objname) +{ + ctx->node(ctx->context, enumtopointer(gco), gco->gch.tt, gco->gch.memcat, objname); +} + +static void enumedge(EnumContext* ctx, GCObject* from, GCObject* to, const char* edgename) +{ + ctx->edge(ctx->context, enumtopointer(from), enumtopointer(to), edgename); +} + +static void enumedges(EnumContext* ctx, GCObject* from, TValue* data, size_t size, const char* edgename) +{ + for (size_t i = 0; i < size; ++i) + { + if (iscollectable(&data[i])) + enumedge(ctx, from, gcvalue(&data[i]), edgename); + } +} + +static void enumstring(EnumContext* ctx, TString* ts) +{ + enumnode(ctx, obj2gco(ts), NULL); +} + +static void enumtable(EnumContext* ctx, Table* h) +{ + // Provide a name for a special registry table + enumnode(ctx, obj2gco(h), h == hvalue(registry(ctx->L)) ? "registry" : NULL); + + if (h->node != &luaH_dummynode) + { + for (int i = 0; i < sizenode(h); ++i) + { + const LuaNode& n = h->node[i]; + + if (!ttisnil(&n.val) && (iscollectable(&n.key) || iscollectable(&n.val))) + { + if (iscollectable(&n.key)) + enumedge(ctx, obj2gco(h), gcvalue(&n.key), "[key]"); + + if (iscollectable(&n.val)) + { + if (ttisstring(&n.key)) + { + enumedge(ctx, obj2gco(h), gcvalue(&n.val), svalue(&n.key)); + } + else if (ttisnumber(&n.key)) + { + char buf[32]; + snprintf(buf, sizeof(buf), "%.14g", nvalue(&n.key)); + enumedge(ctx, obj2gco(h), gcvalue(&n.val), buf); + } + else + { + enumedge(ctx, obj2gco(h), gcvalue(&n.val), NULL); + } + } + } + } + } + + if (h->sizearray) + enumedges(ctx, obj2gco(h), h->array, h->sizearray, "array"); + + if (h->metatable) + enumedge(ctx, obj2gco(h), obj2gco(h->metatable), "metatable"); +} + +static void enumclosure(EnumContext* ctx, Closure* cl) +{ + if (cl->isC) + { + enumnode(ctx, obj2gco(cl), cl->c.debugname); + } + else + { + Proto* p = cl->l.p; + + char buf[LUA_IDSIZE]; + + if (p->source) + snprintf(buf, sizeof(buf), "%s:%d %s", p->debugname ? getstr(p->debugname) : "", p->linedefined, getstr(p->source)); + else + snprintf(buf, sizeof(buf), "%s:%d", p->debugname ? getstr(p->debugname) : "", p->linedefined); + + enumnode(ctx, obj2gco(cl), buf); + } + + enumedge(ctx, obj2gco(cl), obj2gco(cl->env), "env"); + + if (cl->isC) + { + if (cl->nupvalues) + enumedges(ctx, obj2gco(cl), cl->c.upvals, cl->nupvalues, "upvalue"); + } + else + { + enumedge(ctx, obj2gco(cl), obj2gco(cl->l.p), "proto"); + + if (cl->nupvalues) + enumedges(ctx, obj2gco(cl), cl->l.uprefs, cl->nupvalues, "upvalue"); + } +} + +static void enumudata(EnumContext* ctx, Udata* u) +{ + enumnode(ctx, obj2gco(u), NULL); + + if (u->metatable) + enumedge(ctx, obj2gco(u), obj2gco(u->metatable), "metatable"); +} + +static void enumthread(EnumContext* ctx, lua_State* th) +{ + Closure* tcl = NULL; + for (CallInfo* ci = th->base_ci; ci <= th->ci; ++ci) + { + if (ttisfunction(ci->func)) + { + tcl = clvalue(ci->func); + break; + } + } + + if (tcl && !tcl->isC && tcl->l.p->source) + { + Proto* p = tcl->l.p; + + enumnode(ctx, obj2gco(th), getstr(p->source)); + } + else + { + enumnode(ctx, obj2gco(th), NULL); + } + + enumedge(ctx, obj2gco(th), obj2gco(th->gt), "globals"); + + if (th->top > th->stack) + enumedges(ctx, obj2gco(th), th->stack, th->top - th->stack, "stack"); +} + +static void enumproto(EnumContext* ctx, Proto* p) +{ + enumnode(ctx, obj2gco(p), p->source ? getstr(p->source) : NULL); + + if (p->sizek) + enumedges(ctx, obj2gco(p), p->k, p->sizek, "constants"); + + for (int i = 0; i < p->sizep; ++i) + enumedge(ctx, obj2gco(p), obj2gco(p->p[i]), "protos"); +} + +static void enumupval(EnumContext* ctx, UpVal* uv) +{ + enumnode(ctx, obj2gco(uv), NULL); + + if (iscollectable(uv->v)) + enumedge(ctx, obj2gco(uv), gcvalue(uv->v), "value"); +} + +static void enumobj(EnumContext* ctx, GCObject* o) +{ + switch (o->gch.tt) + { + case LUA_TSTRING: + return enumstring(ctx, gco2ts(o)); + + case LUA_TTABLE: + return enumtable(ctx, gco2h(o)); + + case LUA_TFUNCTION: + return enumclosure(ctx, gco2cl(o)); + + case LUA_TUSERDATA: + return enumudata(ctx, gco2u(o)); + + case LUA_TTHREAD: + return enumthread(ctx, gco2th(o)); + + case LUA_TPROTO: + return enumproto(ctx, gco2p(o)); + + case LUA_TUPVAL: + return enumupval(ctx, gco2uv(o)); + + default: + LUAU_ASSERT(!"Unknown object tag"); + } +} + +static bool enumgco(void* context, lua_Page* page, GCObject* gco) +{ + enumobj((EnumContext*)context, gco); + return false; +} + +void luaC_enumheap(lua_State* L, void* context, void (*node)(void* context, void* ptr, uint8_t tt, uint8_t memcat, const char* name), + void (*edge)(void* context, void* from, void* to, const char* name)) +{ + global_State* g = L->global; + + EnumContext ctx; + ctx.L = L; + ctx.context = context; + ctx.node = node; + ctx.edge = edge; + + enumgco(&ctx, NULL, obj2gco(g->mainthread)); + + luaM_visitgco(L, &ctx, enumgco); +} diff --git a/VM/src/lobject.h b/VM/src/lobject.h index 560969d6d..74ea16235 100644 --- a/VM/src/lobject.h +++ b/VM/src/lobject.h @@ -263,9 +263,22 @@ typedef struct Proto CommonHeader; + uint8_t nups; // number of upvalues + uint8_t numparams; + uint8_t is_vararg; + uint8_t maxstacksize; + uint8_t flags; + + TValue* k; // constants used by the function Instruction* code; // function bytecode struct Proto** p; // functions defined inside the function + const Instruction* codeentry; + + void* execdata; + uintptr_t exectarget; + + uint8_t* lineinfo; // for each instruction, line number as a delta from baseline int* abslineinfo; // baseline line info, one entry for each 1<memcatbytes[i] == 0); -#if LUA_CUSTOM_EXECUTION if (L->global->ecb.close) L->global->ecb.close(L); -#endif (*g->frealloc)(g->ud, L, sizeof(LG), 0); } diff --git a/VM/src/lstate.h b/VM/src/lstate.h index ca8bc1b31..a7346ccb2 100644 --- a/VM/src/lstate.h +++ b/VM/src/lstate.h @@ -154,7 +154,6 @@ struct lua_ExecutionCallbacks void (*close)(lua_State* L); // called when global VM state is closed void (*destroy)(lua_State* L, Proto* proto); // called when function is destroyed int (*enter)(lua_State* L, Proto* proto); // called when function is about to start/resume (when execdata is present), return 0 to exit VM - void (*setbreakpoint)(lua_State* L, Proto* proto, int line); // called when a breakpoint is set in a function }; /* diff --git a/VM/src/lvmload.cpp b/VM/src/lvmload.cpp index a26dd0b8f..7a065383a 100644 --- a/VM/src/lvmload.cpp +++ b/VM/src/lvmload.cpp @@ -230,8 +230,7 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size if (version >= 4) { - uint8_t cgflags = read(data, size, offset); - LUAU_ASSERT(cgflags == 0); + p->flags = read(data, size, offset); uint32_t typesize = readVarInt(data, size, offset); diff --git a/fuzz/proto.cpp b/fuzz/proto.cpp index c30e93775..e1ecaf65a 100644 --- a/fuzz/proto.cpp +++ b/fuzz/proto.cpp @@ -26,7 +26,7 @@ const bool kFuzzLinter = true; const bool kFuzzTypeck = true; const bool kFuzzVM = true; const bool kFuzzTranspile = true; -const bool kFuzzCodegen = true; +const bool kFuzzCodegenVM = true; const bool kFuzzCodegenAssembly = true; // Should we generate type annotations? @@ -35,7 +35,7 @@ const bool kFuzzTypes = true; const Luau::CodeGen::AssemblyOptions::Target kFuzzCodegenTarget = Luau::CodeGen::AssemblyOptions::A64; static_assert(!(kFuzzVM && !kFuzzCompiler), "VM requires the compiler!"); -static_assert(!(kFuzzCodegen && !kFuzzVM), "Codegen requires the VM!"); +static_assert(!(kFuzzCodegenVM && !kFuzzCompiler), "Codegen requires the compiler!"); static_assert(!(kFuzzCodegenAssembly && !kFuzzCompiler), "Codegen requires the compiler!"); std::vector protoprint(const luau::ModuleSet& stat, bool types); @@ -47,6 +47,7 @@ LUAU_FASTINT(LuauTableTypeMaximumStringifierLength) LUAU_FASTINT(LuauTypeInferIterationLimit) LUAU_FASTINT(LuauTarjanChildLimit) LUAU_FASTFLAG(DebugLuauFreezeArena) +LUAU_FASTFLAG(DebugLuauAbortingChecks) std::chrono::milliseconds kInterruptTimeout(10); std::chrono::time_point interruptDeadline; @@ -90,7 +91,7 @@ lua_State* createGlobalState() { lua_State* L = lua_newstate(allocate, NULL); - if (kFuzzCodegen && Luau::CodeGen::isSupported()) + if (kFuzzCodegenVM && Luau::CodeGen::isSupported()) Luau::CodeGen::create(L); lua_callbacks(L)->interrupt = interrupt; @@ -228,6 +229,7 @@ DEFINE_PROTO_FUZZER(const luau::ModuleSet& message) flag->value = true; FFlag::DebugLuauFreezeArena.value = true; + FFlag::DebugLuauAbortingChecks.value = true; std::vector sources = protoprint(message, kFuzzTypes); @@ -370,7 +372,7 @@ DEFINE_PROTO_FUZZER(const luau::ModuleSet& message) } // run resulting bytecode (from last successfully compiler module) - if (kFuzzVM && bytecode.size()) + if ((kFuzzVM || kFuzzCodegenVM) && bytecode.size()) { static lua_State* globalState = createGlobalState(); @@ -395,9 +397,10 @@ DEFINE_PROTO_FUZZER(const luau::ModuleSet& message) LUAU_ASSERT(heapSize < 256 * 1024); }; - runCode(bytecode, false); + if (kFuzzVM) + runCode(bytecode, false); - if (kFuzzCodegen && Luau::CodeGen::isSupported()) + if (kFuzzCodegenVM && Luau::CodeGen::isSupported()) runCode(bytecode, true); } } diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index fde6e90e9..db779da2b 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -53,7 +53,10 @@ static std::string compileTypeTable(const char* source) { Luau::BytecodeBuilder bcb; bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code); - Luau::compileOrThrow(bcb, source); + + Luau::CompileOptions opts; + opts.vectorType = "Vector3"; + Luau::compileOrThrow(bcb, source, opts); return bcb.dumpTypeInfo(); } @@ -7159,6 +7162,31 @@ end )"); } +TEST_CASE("HostTypesVector") +{ + ScopedFastFlag sff("LuauCompileFunctionType", true); + + CHECK_EQ("\n" + compileTypeTable(R"( +function myfunc(test: Instance, pos: Vector3) +end + +function myfunc2(test: Instance, pos: Vector3) +end + +do + type Vector3 = number + + function myfunc3(test: Instance, pos: Vector3) + end +end +)"), + R"( +0: function(userdata, vector) +1: function(userdata, any) +2: function(userdata, number) +)"); +} + TEST_CASE("TypeAliasScoping") { ScopedFastFlag sff("LuauCompileFunctionType", true); diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index 9b47b6f5d..c98dabb95 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -5,6 +5,7 @@ #include "luacodegen.h" #include "Luau/BuiltinDefinitions.h" +#include "Luau/DenseHash.h" #include "Luau/ModuleResolver.h" #include "Luau/TypeInfer.h" #include "Luau/StringUtils.h" @@ -15,6 +16,7 @@ #include "ScopedFlags.h" #include +#include #include #include @@ -1244,6 +1246,8 @@ TEST_CASE("GCDump") { // internal function, declared in lgc.h - not exposed via lua.h extern void luaC_dump(lua_State * L, void* file, const char* (*categoryName)(lua_State * L, uint8_t memcat)); + extern void luaC_enumheap(lua_State * L, void* context, void (*node)(void* context, void* ptr, uint8_t tt, uint8_t memcat, const char* name), + void (*edge)(void* context, void* from, void* to, const char* name)); StateRef globalState(luaL_newstate(), lua_close); lua_State* L = globalState.get(); @@ -1287,6 +1291,40 @@ TEST_CASE("GCDump") luaC_dump(L, f, nullptr); fclose(f); + + struct Node + { + void* ptr; + uint8_t tag; + uint8_t memcat; + std::string name; + }; + + struct EnumContext + { + EnumContext() + : nodes{nullptr} + , edges{nullptr} + { + } + + Luau::DenseHashMap nodes; + Luau::DenseHashMap edges; + } ctx; + + luaC_enumheap( + L, &ctx, + [](void* ctx, void* gco, uint8_t tt, uint8_t memcat, const char* name) { + EnumContext& context = *(EnumContext*)ctx; + context.nodes[gco] = {gco, tt, memcat, name ? name : ""}; + }, + [](void* ctx, void* s, void* t, const char*) { + EnumContext& context = *(EnumContext*)ctx; + context.edges[s] = t; + }); + + CHECK(!ctx.nodes.empty()); + CHECK(!ctx.edges.empty()); } TEST_CASE("Interrupt") diff --git a/tests/Differ.test.cpp b/tests/Differ.test.cpp index 6b570f398..520c53021 100644 --- a/tests/Differ.test.cpp +++ b/tests/Differ.test.cpp @@ -1,15 +1,20 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Differ.h" +#include "Luau/Common.h" #include "Luau/Error.h" #include "Luau/Frontend.h" #include "Fixture.h" +#include "Luau/Symbol.h" +#include "ScopedFlags.h" #include "doctest.h" #include using namespace Luau; +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) + TEST_SUITE_BEGIN("Differ"); TEST_CASE_FIXTURE(Fixture, "equal_numbers") @@ -28,7 +33,7 @@ TEST_CASE_FIXTURE(Fixture, "equal_numbers") DifferResult diffRes = diff(foo, almostFoo); CHECK(!diffRes.diffError.has_value()); } - catch (InternalCompilerError e) + catch (const InternalCompilerError& e) { INFO(("InternalCompilerError: " + e.message)); CHECK(false); @@ -51,7 +56,7 @@ TEST_CASE_FIXTURE(Fixture, "equal_strings") DifferResult diffRes = diff(foo, almostFoo); CHECK(!diffRes.diffError.has_value()); } - catch (InternalCompilerError e) + catch (const InternalCompilerError& e) { INFO(("InternalCompilerError: " + e.message)); CHECK(false); @@ -74,7 +79,7 @@ TEST_CASE_FIXTURE(Fixture, "equal_tables") DifferResult diffRes = diff(foo, almostFoo); CHECK(!diffRes.diffError.has_value()); } - catch (InternalCompilerError e) + catch (const InternalCompilerError& e) { INFO(("InternalCompilerError: " + e.message)); CHECK(false); @@ -97,7 +102,7 @@ TEST_CASE_FIXTURE(Fixture, "a_table_missing_property") { diffMessage = diff(foo, almostFoo).diffError->toString(); } - catch (InternalCompilerError e) + catch (const InternalCompilerError& e) { INFO(("InternalCompilerError: " + e.message)); CHECK(false); @@ -123,7 +128,7 @@ TEST_CASE_FIXTURE(Fixture, "left_table_missing_property") { diffMessage = diff(foo, almostFoo).diffError->toString(); } - catch (InternalCompilerError e) + catch (const InternalCompilerError& e) { INFO(("InternalCompilerError: " + e.message)); CHECK(false); @@ -149,7 +154,7 @@ TEST_CASE_FIXTURE(Fixture, "a_table_wrong_type") { diffMessage = diff(foo, almostFoo).diffError->toString(); } - catch (InternalCompilerError e) + catch (const InternalCompilerError& e) { INFO(("InternalCompilerError: " + e.message)); CHECK(false); @@ -175,7 +180,7 @@ TEST_CASE_FIXTURE(Fixture, "a_table_wrong_type") { diffMessage = diff(foo, almostFoo).diffError->toString(); } - catch (InternalCompilerError e) + catch (const InternalCompilerError& e) { INFO(("InternalCompilerError: " + e.message)); CHECK(false); @@ -201,7 +206,7 @@ TEST_CASE_FIXTURE(Fixture, "a_nested_table_wrong_type") { diffMessage = diff(foo, almostFoo).diffError->toString(); } - catch (InternalCompilerError e) + catch (const InternalCompilerError& e) { INFO(("InternalCompilerError: " + e.message)); CHECK(false); @@ -227,7 +232,7 @@ TEST_CASE_FIXTURE(Fixture, "a_nested_table_wrong_match") { diffMessage = diff(foo, almostFoo).diffError->toString(); } - catch (InternalCompilerError e) + catch (const InternalCompilerError& e) { INFO(("InternalCompilerError: " + e.message)); CHECK(false); @@ -253,7 +258,7 @@ TEST_CASE_FIXTURE(Fixture, "singleton") { diffMessage = diff(foo, almostFoo).diffError->toString(); } - catch (InternalCompilerError e) + catch (const InternalCompilerError& e) { INFO(("InternalCompilerError: " + e.message)); CHECK(false); @@ -280,7 +285,7 @@ TEST_CASE_FIXTURE(Fixture, "equal_singleton") INFO(diffRes.diffError->toString()); CHECK(!diffRes.diffError.has_value()); } - catch (InternalCompilerError e) + catch (const InternalCompilerError& e) { INFO(("InternalCompilerError: " + e.message)); CHECK(false); @@ -303,7 +308,7 @@ TEST_CASE_FIXTURE(Fixture, "singleton_string") { diffMessage = diff(foo, almostFoo).diffError->toString(); } - catch (InternalCompilerError e) + catch (const InternalCompilerError& e) { INFO(("InternalCompilerError: " + e.message)); CHECK(false); @@ -313,4 +318,685 @@ TEST_CASE_FIXTURE(Fixture, "singleton_string") diffMessage); } +TEST_CASE_FIXTURE(Fixture, "equal_function") +{ + // Old solver does not correctly infer function typepacks + ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true}; + + CheckResult result = check(R"( + function foo(x: number) + return x + end + function almostFoo(y: number) + return y + 10 + end + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + TypeId foo = requireType("foo"); + TypeId almostFoo = requireType("almostFoo"); + try + { + DifferResult diffRes = diff(foo, almostFoo); + INFO(diffRes.diffError->toString()); + CHECK(!diffRes.diffError.has_value()); + } + catch (InternalCompilerError e) + { + INFO(("InternalCompilerError: " + e.message)); + CHECK(false); + } +} + +TEST_CASE_FIXTURE(Fixture, "equal_function_inferred_ret_length") +{ + // Old solver does not correctly infer function typepacks + ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true}; + + CheckResult result = check(R"( + function bar(x: number, y: string) + return x, y + end + function almostBar(a: number, b: string) + return a, b + end + function foo(x: number, y: string, z: boolean) + return z, bar(x, y) + end + function almostFoo(a: number, b: string, c: boolean) + return c, almostBar(a, b) + end + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + TypeId foo = requireType("foo"); + TypeId almostFoo = requireType("almostFoo"); + try + { + DifferResult diffRes = diff(foo, almostFoo); + INFO(diffRes.diffError->toString()); + CHECK(!diffRes.diffError.has_value()); + } + catch (InternalCompilerError e) + { + INFO(("InternalCompilerError: " + e.message)); + CHECK(false); + } +} + +TEST_CASE_FIXTURE(Fixture, "equal_function_inferred_ret_length_2") +{ + // Old solver does not correctly infer function typepacks + ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true}; + + CheckResult result = check(R"( + function bar(x: number, y: string) + return x, y + end + function foo(x: number, y: string, z: boolean) + return bar(x, y), z + end + function almostFoo(a: number, b: string, c: boolean) + return a, c + end + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + TypeId foo = requireType("foo"); + TypeId almostFoo = requireType("almostFoo"); + try + { + DifferResult diffRes = diff(foo, almostFoo); + INFO(diffRes.diffError->toString()); + CHECK(!diffRes.diffError.has_value()); + } + catch (const InternalCompilerError& e) + { + INFO(("InternalCompilerError: " + e.message)); + CHECK(false); + } +} + +TEST_CASE_FIXTURE(Fixture, "function_arg_normal") +{ + // Old solver does not correctly infer function typepacks + ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true}; + + CheckResult result = check(R"( + function foo(x: number, y: number, z: number) + return x * y * z + end + function almostFoo(a: number, b: number, msg: string) + return a + almostFoo = foo + )"); + LUAU_REQUIRE_ERRORS(result); + + TypeId foo = requireType("foo"); + TypeId almostFoo = requireType("almostFoo"); + std::string diffMessage; + try + { + diffMessage = diff(foo, almostFoo).diffError->toString(); + } + catch (const InternalCompilerError& e) + { + INFO(("InternalCompilerError: " + e.message)); + CHECK(false); + } + CHECK_EQ( + R"(DiffError: these two types are not equal because the left type at .Arg[3] has type number, while the right type at .Arg[3] has type string)", + diffMessage); +} + +TEST_CASE_FIXTURE(Fixture, "function_arg_normal_2") +{ + // Old solver does not correctly infer function typepacks + ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true}; + + CheckResult result = check(R"( + function foo(x: number, y: number, z: string) + return x * y + end + function almostFoo(a: number, y: string, msg: string) + return a + almostFoo = foo + )"); + LUAU_REQUIRE_ERRORS(result); + + TypeId foo = requireType("foo"); + TypeId almostFoo = requireType("almostFoo"); + std::string diffMessage; + try + { + diffMessage = diff(foo, almostFoo).diffError->toString(); + } + catch (const InternalCompilerError& e) + { + INFO(("InternalCompilerError: " + e.message)); + CHECK(false); + } + CHECK_EQ( + R"(DiffError: these two types are not equal because the left type at .Arg[2] has type number, while the right type at .Arg[2] has type string)", + diffMessage); +} + +TEST_CASE_FIXTURE(Fixture, "function_ret_normal") +{ + // Old solver does not correctly infer function typepacks + ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true}; + + CheckResult result = check(R"( + function foo(x: number, y: number, z: string) + return x + end + function almostFoo(a: number, b: number, msg: string) + return msg + end + )"); + + TypeId foo = requireType("foo"); + TypeId almostFoo = requireType("almostFoo"); + std::string diffMessage; + try + { + DifferResult diffRes = diff(foo, almostFoo); + if (!diffRes.diffError.has_value()) + { + INFO("Differ did not report type error, even though types are unequal"); + CHECK(false); + } + diffMessage = diffRes.diffError->toString(); + } + catch (const InternalCompilerError& e) + { + INFO(("InternalCompilerError: " + e.message)); + CHECK(false); + } + CHECK_EQ( + R"(DiffError: these two types are not equal because the left type at .Ret[1] has type number, while the right type at .Ret[1] has type string)", + diffMessage); +} + +TEST_CASE_FIXTURE(Fixture, "function_arg_length") +{ + // Old solver does not correctly infer function typepacks + ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true}; + + CheckResult result = check(R"( + function foo(x: number, y: number) + return x + end + function almostFoo(x: number, y: number, c: number) + return x + end + )"); + + TypeId foo = requireType("foo"); + TypeId almostFoo = requireType("almostFoo"); + std::string diffMessage; + try + { + DifferResult diffRes = diff(foo, almostFoo); + if (!diffRes.diffError.has_value()) + { + INFO("Differ did not report type error, even though types are unequal"); + CHECK(false); + } + diffMessage = diffRes.diffError->toString(); + } + catch (const InternalCompilerError& e) + { + INFO(("InternalCompilerError: " + e.message)); + CHECK(false); + } + CHECK_EQ( + R"(DiffError: these two types are not equal because the left type at takes 2 or more arguments, while the right type at takes 3 or more arguments)", + diffMessage); +} + +TEST_CASE_FIXTURE(Fixture, "function_arg_length_2") +{ + // Old solver does not correctly infer function typepacks + ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true}; + + CheckResult result = check(R"( + function foo(x: number, y: string, z: number) + return z + end + function almostFoo(x: number, y: string) + return x + end + )"); + + TypeId foo = requireType("foo"); + TypeId almostFoo = requireType("almostFoo"); + std::string diffMessage; + try + { + DifferResult diffRes = diff(foo, almostFoo); + if (!diffRes.diffError.has_value()) + { + INFO("Differ did not report type error, even though types are unequal"); + CHECK(false); + } + diffMessage = diffRes.diffError->toString(); + } + catch (const InternalCompilerError& e) + { + INFO(("InternalCompilerError: " + e.message)); + CHECK(false); + } + CHECK_EQ( + R"(DiffError: these two types are not equal because the left type at takes 3 or more arguments, while the right type at takes 2 or more arguments)", + diffMessage); +} + +TEST_CASE_FIXTURE(Fixture, "function_arg_length_none") +{ + // Old solver does not correctly infer function typepacks + ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true}; + + CheckResult result = check(R"( + function foo() + return 5 + end + function almostFoo(x: number, y: string) + return x + end + )"); + + TypeId foo = requireType("foo"); + TypeId almostFoo = requireType("almostFoo"); + std::string diffMessage; + try + { + DifferResult diffRes = diff(foo, almostFoo); + if (!diffRes.diffError.has_value()) + { + INFO("Differ did not report type error, even though types are unequal"); + CHECK(false); + } + diffMessage = diffRes.diffError->toString(); + } + catch (const InternalCompilerError& e) + { + INFO(("InternalCompilerError: " + e.message)); + CHECK(false); + } + CHECK_EQ( + R"(DiffError: these two types are not equal because the left type at takes 0 or more arguments, while the right type at takes 2 or more arguments)", + diffMessage); +} + +TEST_CASE_FIXTURE(Fixture, "function_arg_length_none_2") +{ + // Old solver does not correctly infer function typepacks + ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true}; + + CheckResult result = check(R"( + function foo(x: number) + return x + end + function almostFoo() + return 5 + end + )"); + + TypeId foo = requireType("foo"); + TypeId almostFoo = requireType("almostFoo"); + std::string diffMessage; + try + { + DifferResult diffRes = diff(foo, almostFoo); + if (!diffRes.diffError.has_value()) + { + INFO("Differ did not report type error, even though types are unequal"); + CHECK(false); + } + diffMessage = diffRes.diffError->toString(); + } + catch (const InternalCompilerError& e) + { + INFO(("InternalCompilerError: " + e.message)); + CHECK(false); + } + CHECK_EQ( + R"(DiffError: these two types are not equal because the left type at takes 1 or more arguments, while the right type at takes 0 or more arguments)", + diffMessage); +} + +TEST_CASE_FIXTURE(Fixture, "function_ret_length") +{ + // Old solver does not correctly infer function typepacks + ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true}; + + CheckResult result = check(R"( + function foo(x: number, y: number) + return x + end + function almostFoo(x: number, y: number) + return x, y + end + )"); + + TypeId foo = requireType("foo"); + TypeId almostFoo = requireType("almostFoo"); + std::string diffMessage; + try + { + DifferResult diffRes = diff(foo, almostFoo); + if (!diffRes.diffError.has_value()) + { + INFO("Differ did not report type error, even though types are unequal"); + CHECK(false); + } + diffMessage = diffRes.diffError->toString(); + } + catch (const InternalCompilerError& e) + { + INFO(("InternalCompilerError: " + e.message)); + CHECK(false); + } + CHECK_EQ( + R"(DiffError: these two types are not equal because the left type at returns 1 values, while the right type at returns 2 values)", + diffMessage); +} + +TEST_CASE_FIXTURE(Fixture, "function_ret_length_2") +{ + // Old solver does not correctly infer function typepacks + ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true}; + + CheckResult result = check(R"( + function foo(x: number, y: string, z: number) + return y, x, z + end + function almostFoo(x: number, y: string, z: number) + return y, x + end + )"); + + TypeId foo = requireType("foo"); + TypeId almostFoo = requireType("almostFoo"); + std::string diffMessage; + try + { + DifferResult diffRes = diff(foo, almostFoo); + if (!diffRes.diffError.has_value()) + { + INFO("Differ did not report type error, even though types are unequal"); + CHECK(false); + } + diffMessage = diffRes.diffError->toString(); + } + catch (const InternalCompilerError& e) + { + INFO(("InternalCompilerError: " + e.message)); + CHECK(false); + } + CHECK_EQ( + R"(DiffError: these two types are not equal because the left type at returns 3 values, while the right type at returns 2 values)", + diffMessage); +} + +TEST_CASE_FIXTURE(Fixture, "function_ret_length_none") +{ + // Old solver does not correctly infer function typepacks + ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true}; + + CheckResult result = check(R"( + function foo(x: number, y: string) + return + end + function almostFoo(x: number, y: string) + return x + end + )"); + + TypeId foo = requireType("foo"); + TypeId almostFoo = requireType("almostFoo"); + std::string diffMessage; + try + { + DifferResult diffRes = diff(foo, almostFoo); + if (!diffRes.diffError.has_value()) + { + INFO("Differ did not report type error, even though types are unequal"); + CHECK(false); + } + diffMessage = diffRes.diffError->toString(); + } + catch (const InternalCompilerError& e) + { + INFO(("InternalCompilerError: " + e.message)); + CHECK(false); + } + CHECK_EQ( + R"(DiffError: these two types are not equal because the left type at returns 0 values, while the right type at returns 1 values)", + diffMessage); +} + +TEST_CASE_FIXTURE(Fixture, "function_ret_length_none_2") +{ + // Old solver does not correctly infer function typepacks + ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true}; + + CheckResult result = check(R"( + function foo() + return 5 + end + function almostFoo() + return + end + )"); + + TypeId foo = requireType("foo"); + TypeId almostFoo = requireType("almostFoo"); + std::string diffMessage; + try + { + DifferResult diffRes = diff(foo, almostFoo); + if (!diffRes.diffError.has_value()) + { + INFO("Differ did not report type error, even though types are unequal"); + CHECK(false); + } + diffMessage = diffRes.diffError->toString(); + } + catch (const InternalCompilerError& e) + { + INFO(("InternalCompilerError: " + e.message)); + CHECK(false); + } + CHECK_EQ( + R"(DiffError: these two types are not equal because the left type at returns 1 values, while the right type at returns 0 values)", + diffMessage); +} + +TEST_CASE_FIXTURE(Fixture, "function_variadic_arg_normal") +{ + // Old solver does not correctly infer function typepacks + ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true}; + + CheckResult result = check(R"( + function foo(x: number, y: string, ...: number) + return x, y + end + function almostFoo(a: number, b: string, ...: string) + return a, b + end + )"); + + TypeId foo = requireType("foo"); + TypeId almostFoo = requireType("almostFoo"); + std::string diffMessage; + try + { + DifferResult diffRes = diff(foo, almostFoo); + if (!diffRes.diffError.has_value()) + { + INFO("Differ did not report type error, even though types are unequal"); + CHECK(false); + } + diffMessage = diffRes.diffError->toString(); + } + catch (const InternalCompilerError& e) + { + INFO(("InternalCompilerError: " + e.message)); + CHECK(false); + } + CHECK_EQ( + R"(DiffError: these two types are not equal because the left type at .Arg[Variadic] has type number, while the right type at .Arg[Variadic] has type string)", + diffMessage); +} + +TEST_CASE_FIXTURE(Fixture, "function_variadic_arg_missing") +{ + // Old solver does not correctly infer function typepacks + ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true}; + + CheckResult result = check(R"( + function foo(x: number, y: string, ...: number) + return x, y + end + function almostFoo(a: number, b: string) + return a, b + end + )"); + + TypeId foo = requireType("foo"); + TypeId almostFoo = requireType("almostFoo"); + std::string diffMessage; + try + { + DifferResult diffRes = diff(foo, almostFoo); + if (!diffRes.diffError.has_value()) + { + INFO("Differ did not report type error, even though types are unequal"); + CHECK(false); + } + diffMessage = diffRes.diffError->toString(); + } + catch (const InternalCompilerError& e) + { + INFO(("InternalCompilerError: " + e.message)); + CHECK(false); + } + CHECK_EQ( + R"(DiffError: these two types are not equal because the left type at .Arg[Variadic] has type number, while the right type at .Arg[Variadic] has type any)", + diffMessage); +} + +TEST_CASE_FIXTURE(Fixture, "function_variadic_arg_missing_2") +{ + // Old solver does not correctly infer function typepacks + ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true}; + + CheckResult result = check(R"( + function foo(x: number, y: string) + return x, y + end + function almostFoo(a: number, b: string, ...: string) + return a, b + end + )"); + + TypeId foo = requireType("foo"); + TypeId almostFoo = requireType("almostFoo"); + std::string diffMessage; + try + { + DifferResult diffRes = diff(foo, almostFoo); + if (!diffRes.diffError.has_value()) + { + INFO("Differ did not report type error, even though types are unequal"); + CHECK(false); + } + diffMessage = diffRes.diffError->toString(); + } + catch (const InternalCompilerError& e) + { + INFO(("InternalCompilerError: " + e.message)); + CHECK(false); + } + CHECK_EQ( + R"(DiffError: these two types are not equal because the left type at .Arg[Variadic] has type any, while the right type at .Arg[Variadic] has type string)", + diffMessage); +} + +TEST_CASE_FIXTURE(Fixture, "function_variadic_oversaturation") +{ + // Old solver does not correctly infer function typepacks + ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true}; + + CheckResult result = check(R"( + -- allowed to be oversaturated + function foo(x: number, y: string) + return x, y + end + -- must not be oversaturated + local almostFoo: (number, string) -> (number, string) = foo + )"); + + TypeId foo = requireType("foo"); + TypeId almostFoo = requireType("almostFoo"); + std::string diffMessage; + try + { + DifferResult diffRes = diff(foo, almostFoo); + if (!diffRes.diffError.has_value()) + { + INFO("Differ did not report type error, even though types are unequal"); + CHECK(false); + } + diffMessage = diffRes.diffError->toString(); + } + catch (const InternalCompilerError& e) + { + INFO(("InternalCompilerError: " + e.message)); + CHECK(false); + } + CHECK_EQ( + R"(DiffError: these two types are not equal because the left type at takes 2 or more arguments, while the right type at takes 2 arguments)", + diffMessage); +} + +TEST_CASE_FIXTURE(Fixture, "function_variadic_oversaturation_2") +{ + // Old solver does not correctly infer function typepacks + ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true}; + + CheckResult result = check(R"( + -- must not be oversaturated + local foo: (number, string) -> (number, string) + -- allowed to be oversaturated + function almostFoo(x: number, y: string) + return x, y + end + )"); + + TypeId foo = requireType("foo"); + TypeId almostFoo = requireType("almostFoo"); + std::string diffMessage; + try + { + DifferResult diffRes = diff(foo, almostFoo); + if (!diffRes.diffError.has_value()) + { + INFO("Differ did not report type error, even though types are unequal"); + CHECK(false); + } + diffMessage = diffRes.diffError->toString(); + } + catch (const InternalCompilerError& e) + { + INFO(("InternalCompilerError: " + e.message)); + CHECK(false); + } + CHECK_EQ( + R"(DiffError: these two types are not equal because the left type at takes 2 arguments, while the right type at takes 2 or more arguments)", + diffMessage); +} + TEST_SUITE_END(); diff --git a/tests/Module.test.cpp b/tests/Module.test.cpp index 74e8a959e..0be3fa980 100644 --- a/tests/Module.test.cpp +++ b/tests/Module.test.cpp @@ -407,19 +407,15 @@ type B = A TEST_CASE_FIXTURE(BuiltinsFixture, "do_not_clone_reexports") { - ScopedFastFlag flags[] = { - {"LuauClonePublicInterfaceLess2", true}, - }; - fileResolver.source["Module/A"] = R"( -export type A = {p : number} -return {} + export type A = {p : number} + return {} )"; fileResolver.source["Module/B"] = R"( -local a = require(script.Parent.A) -export type B = {q : a.A} -return {} + local a = require(script.Parent.A) + export type B = {q : a.A} + return {} )"; CheckResult result = frontend.check("Module/B"); @@ -442,19 +438,15 @@ return {} TEST_CASE_FIXTURE(BuiltinsFixture, "do_not_clone_types_of_reexported_values") { - ScopedFastFlag flags[] = { - {"LuauClonePublicInterfaceLess2", true}, - }; - fileResolver.source["Module/A"] = R"( -local exports = {a={p=5}} -return exports + local exports = {a={p=5}} + return exports )"; fileResolver.source["Module/B"] = R"( -local a = require(script.Parent.A) -local exports = {b=a.a} -return exports + local a = require(script.Parent.A) + local exports = {b=a.a} + return exports )"; CheckResult result = frontend.check("Module/B"); diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index a8738ac8e..234034d7e 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -54,8 +54,7 @@ TEST_SUITE_BEGIN("AllocatorTests"); TEST_CASE("allocator_can_be_moved") { Counter* c = nullptr; - auto inner = [&]() - { + auto inner = [&]() { Luau::Allocator allocator; c = allocator.alloc(); Luau::Allocator moved{std::move(allocator)}; @@ -922,8 +921,7 @@ TEST_CASE_FIXTURE(Fixture, "parse_interpolated_string_double_brace_mid") TEST_CASE_FIXTURE(Fixture, "parse_interpolated_string_without_end_brace") { - auto columnOfEndBraceError = [this](const char* code) - { + auto columnOfEndBraceError = [this](const char* code) { try { parse(code); @@ -2387,8 +2385,7 @@ class CountAstNodes : public AstVisitor TEST_CASE_FIXTURE(Fixture, "recovery_of_parenthesized_expressions") { - auto checkAstEquivalence = [this](const char* codeWithErrors, const char* code) - { + auto checkAstEquivalence = [this](const char* codeWithErrors, const char* code) { try { parse(codeWithErrors); @@ -2408,8 +2405,7 @@ TEST_CASE_FIXTURE(Fixture, "recovery_of_parenthesized_expressions") CHECK_EQ(counterWithErrors.count, counter.count); }; - auto checkRecovery = [this, checkAstEquivalence](const char* codeWithErrors, const char* code, unsigned expectedErrorCount) - { + auto checkRecovery = [this, checkAstEquivalence](const char* codeWithErrors, const char* code, unsigned expectedErrorCount) { try { parse(codeWithErrors); diff --git a/tests/TypeFamily.test.cpp b/tests/TypeFamily.test.cpp index 613aec801..6e6dba09d 100644 --- a/tests/TypeFamily.test.cpp +++ b/tests/TypeFamily.test.cpp @@ -225,7 +225,8 @@ TEST_CASE_FIXTURE(Fixture, "internal_families_raise_errors") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK(toString(result.errors[0]) == "Type family instance Add depends on generic function parameters but does not appear in the function signature; this construct cannot be type-checked at this time"); + CHECK(toString(result.errors[0]) == "Type family instance Add depends on generic function parameters but does not appear in the function " + "signature; this construct cannot be type-checked at this time"); } TEST_CASE_FIXTURE(BuiltinsFixture, "type_families_inhabited_with_normalization") diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index d456f3783..268980feb 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -1913,8 +1913,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "dont_assert_when_the_tarjan_limit_is_exceede ScopedFastInt sfi{"LuauTarjanChildLimit", 2}; ScopedFastFlag sff[] = { {"DebugLuauDeferredConstraintResolution", true}, - {"LuauClonePublicInterfaceLess2", true}, - {"LuauCloneSkipNonInternalVisit", true}, }; CheckResult result = check(R"( diff --git a/tests/TypeInfer.intersectionTypes.test.cpp b/tests/TypeInfer.intersectionTypes.test.cpp index 012dc7b45..45d127ab8 100644 --- a/tests/TypeInfer.intersectionTypes.test.cpp +++ b/tests/TypeInfer.intersectionTypes.test.cpp @@ -880,4 +880,34 @@ TEST_CASE_FIXTURE(Fixture, "less_greedy_unification_with_intersection_types_2") CHECK_EQ("({| x: number |} & {| x: string |}) -> never", toString(requireType("f"))); } +TEST_CASE_FIXTURE(BuiltinsFixture, "index_property_table_intersection_1") +{ + CheckResult result = check(R"( +type Foo = { + Bar: string, +} & { Baz: number } + +local x: Foo = { Bar = "1", Baz = 2 } +local y = x.Bar + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "index_property_table_intersection_2") +{ + ScopedFastFlag sff{"LuauIndexTableIntersectionStringExpr", true}; + + CheckResult result = check(R"( +type Foo = { + Bar: string, +} & { Baz: number } + +local x: Foo = { Bar = "1", Baz = 2 } +local y = x["Bar"] + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.operators.test.cpp b/tests/TypeInfer.operators.test.cpp index d605d5bc4..0b7a83113 100644 --- a/tests/TypeInfer.operators.test.cpp +++ b/tests/TypeInfer.operators.test.cpp @@ -670,7 +670,9 @@ TEST_CASE_FIXTURE(Fixture, "strict_binary_op_where_lhs_unknown") if (FFlag::DebugLuauDeferredConstraintResolution) { LUAU_REQUIRE_ERROR_COUNT(ops.size(), result); - CHECK_EQ("Type family instance Add depends on generic function parameters but does not appear in the function signature; this construct cannot be type-checked at this time", toString(result.errors[0])); + CHECK_EQ("Type family instance Add depends on generic function parameters but does not appear in the function signature; this " + "construct cannot be type-checked at this time", + toString(result.errors[0])); CHECK_EQ("Unknown type used in - operation; consider adding a type annotation to 'a'", toString(result.errors[1])); } else diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp index d6ae5acc9..12868d8b3 100644 --- a/tests/TypeInfer.unionTypes.test.cpp +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -789,4 +789,19 @@ TEST_CASE_FIXTURE(Fixture, "lookup_prop_of_intersection_containing_unions") CHECK("variables" == unknownProp->key); } +TEST_CASE_FIXTURE(Fixture, "suppress_errors_for_prop_lookup_of_a_union_that_includes_error") +{ + ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true}; + + registerHiddenTypes(&frontend); + + CheckResult result = check(R"( + local a : err | Not + + local b = a.foo + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/conformance/native.lua b/tests/conformance/native.lua index bc70df3e7..6c0e0e0e2 100644 --- a/tests/conformance/native.lua +++ b/tests/conformance/native.lua @@ -92,4 +92,13 @@ end assert(pcall(fuzzfail9) == false) +local function fuzzfail10() + local _ + _ = false,if _ then _ else _ + _ = not _ + l0,_[l0] = not _ +end + +assert(pcall(fuzzfail10) == false) + return('OK') From f16d002db9edc2650b1d68482ae458e8bda6d8c3 Mon Sep 17 00:00:00 2001 From: Vyacheslav Egorov Date: Fri, 14 Jul 2023 10:38:54 -0700 Subject: [PATCH 65/66] GCC fix --- tests/Differ.test.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/Differ.test.cpp b/tests/Differ.test.cpp index 520c53021..1e9dcac32 100644 --- a/tests/Differ.test.cpp +++ b/tests/Differ.test.cpp @@ -341,7 +341,7 @@ TEST_CASE_FIXTURE(Fixture, "equal_function") INFO(diffRes.diffError->toString()); CHECK(!diffRes.diffError.has_value()); } - catch (InternalCompilerError e) + catch (const InternalCompilerError& e) { INFO(("InternalCompilerError: " + e.message)); CHECK(false); @@ -377,7 +377,7 @@ TEST_CASE_FIXTURE(Fixture, "equal_function_inferred_ret_length") INFO(diffRes.diffError->toString()); CHECK(!diffRes.diffError.has_value()); } - catch (InternalCompilerError e) + catch (const InternalCompilerError& e) { INFO(("InternalCompilerError: " + e.message)); CHECK(false); From 5e1aca164c83dd3b91ae99fc3bf0b003d22ba561 Mon Sep 17 00:00:00 2001 From: Vyacheslav Egorov Date: Fri, 28 Jul 2023 14:37:00 +0300 Subject: [PATCH 66/66] Sync to upstream/release/588 --- Analysis/include/Luau/Autocomplete.h | 7 + Analysis/include/Luau/Config.h | 2 +- .../include/Luau/ConstraintGraphBuilder.h | 3 +- Analysis/include/Luau/ConstraintSolver.h | 8 +- Analysis/include/Luau/Differ.h | 20 +- Analysis/include/Luau/Frontend.h | 14 +- Analysis/include/Luau/Normalize.h | 3 + Analysis/include/Luau/Symbol.h | 14 +- Analysis/include/Luau/TypeArena.h | 4 + Analysis/include/Luau/TypeCheckLimits.h | 41 + Analysis/include/Luau/TypeInfer.h | 19 +- Analysis/include/Luau/TypeUtils.h | 46 + Analysis/include/Luau/Unifier.h | 29 +- Analysis/src/Autocomplete.cpp | 249 ++++- Analysis/src/Config.cpp | 3 - Analysis/src/ConstraintGraphBuilder.cpp | 21 +- Analysis/src/ConstraintSolver.cpp | 75 +- Analysis/src/Differ.cpp | 278 +++++- Analysis/src/Frontend.cpp | 61 +- Analysis/src/Linter.cpp | 9 + Analysis/src/Normalize.cpp | 9 +- Analysis/src/Quantify.cpp | 1 - Analysis/src/Scope.cpp | 14 +- Analysis/src/Simplify.cpp | 6 + Analysis/src/Symbol.cpp | 14 + Analysis/src/TxnLog.cpp | 2 - Analysis/src/TypeChecker2.cpp | 250 +++-- Analysis/src/TypeInfer.cpp | 16 + Analysis/src/TypeUtils.cpp | 54 + Analysis/src/Unifier.cpp | 204 ++-- CLI/Repl.cpp | 3 - CodeGen/include/Luau/CodeGen.h | 8 +- CodeGen/include/Luau/IrData.h | 16 +- CodeGen/include/Luau/IrUtils.h | 3 + CodeGen/include/luacodegen.h | 2 +- CodeGen/src/CodeGen.cpp | 12 +- CodeGen/src/IrAnalysis.cpp | 7 +- CodeGen/src/IrBuilder.cpp | 2 +- CodeGen/src/IrDump.cpp | 14 +- CodeGen/src/IrLoweringA64.cpp | 54 +- CodeGen/src/IrLoweringX64.cpp | 62 +- CodeGen/src/IrTranslation.cpp | 56 ++ CodeGen/src/IrTranslation.h | 1 + CodeGen/src/IrUtils.cpp | 7 +- CodeGen/src/IrValueLocationTracking.cpp | 3 +- CodeGen/src/NativeState.cpp | 2 + CodeGen/src/NativeState.h | 2 + CodeGen/src/OptimizeConstProp.cpp | 8 +- Common/include/Luau/Bytecode.h | 12 + Compiler/src/Builtins.cpp | 54 +- Compiler/src/Builtins.h | 9 + Compiler/src/Compiler.cpp | 37 +- Compiler/src/CostModel.cpp | 42 +- Sources.cmake | 1 + VM/include/luaconf.h | 2 +- VM/src/laux.cpp | 2 +- VM/src/lbuiltins.cpp | 77 ++ VM/src/lfunc.cpp | 1 + VM/src/lgc.h | 4 +- VM/src/lmathlib.cpp | 113 ++- VM/src/lobject.h | 2 + VM/src/ltablib.cpp | 6 +- VM/src/lvmexecute.cpp | 11 +- bench/micro_tests/test_ToNumberString.lua | 22 + tests/Autocomplete.test.cpp | 567 ++++++++++- tests/Compiler.test.cpp | 17 + tests/Conformance.test.cpp | 32 + tests/ConstraintGraphBuilderFixture.cpp | 4 +- tests/CostModel.test.cpp | 27 +- tests/Differ.test.cpp | 927 ++++++------------ tests/Fixture.cpp | 13 +- tests/Fixture.h | 49 + tests/Frontend.test.cpp | 47 + tests/Linter.test.cpp | 8 +- tests/Simplify.test.cpp | 11 + tests/StringUtils.test.cpp | 2 +- tests/ToString.test.cpp | 1 + tests/TypeInfer.anyerror.test.cpp | 2 + tests/TypeInfer.classes.test.cpp | 2 + tests/TypeInfer.functions.test.cpp | 22 +- tests/TypeInfer.generics.test.cpp | 3 +- tests/TypeInfer.intersectionTypes.test.cpp | 1 + tests/TypeInfer.provisional.test.cpp | 10 +- tests/TypeInfer.refinements.test.cpp | 42 + tests/TypeInfer.singletons.test.cpp | 2 + tests/TypeInfer.tables.test.cpp | 91 +- tests/TypeInfer.test.cpp | 41 +- tests/TypeInfer.tryUnify.test.cpp | 5 +- tests/TypeInfer.typePacks.cpp | 2 + tests/TypeInfer.unionTypes.test.cpp | 2 + tests/TypeInfer.unknownnever.test.cpp | 2 + tests/conformance/basic.lua | 5 + tests/conformance/math.lua | 7 + tests/conformance/strings.lua | 6 + tests/main.cpp | 8 +- tools/faillist.txt | 7 +- 96 files changed, 3082 insertions(+), 1016 deletions(-) create mode 100644 Analysis/include/Luau/TypeCheckLimits.h create mode 100644 bench/micro_tests/test_ToNumberString.lua diff --git a/Analysis/include/Luau/Autocomplete.h b/Analysis/include/Luau/Autocomplete.h index 618325777..bc709c7f5 100644 --- a/Analysis/include/Luau/Autocomplete.h +++ b/Analysis/include/Luau/Autocomplete.h @@ -38,6 +38,7 @@ enum class AutocompleteEntryKind String, Type, Module, + GeneratedFunction, }; enum class ParenthesesRecommendation @@ -70,6 +71,10 @@ struct AutocompleteEntry std::optional documentationSymbol = std::nullopt; Tags tags; ParenthesesRecommendation parens = ParenthesesRecommendation::None; + std::optional insertText; + + // Only meaningful if kind is Property. + bool indexedWithSelf = false; }; using AutocompleteEntryMap = std::unordered_map; @@ -94,4 +99,6 @@ using StringCompletionCallback = AutocompleteResult autocomplete(Frontend& frontend, const ModuleName& moduleName, Position position, StringCompletionCallback callback); +constexpr char kGeneratedAnonymousFunctionEntryName[] = "function (anonymous autofilled)"; + } // namespace Luau diff --git a/Analysis/include/Luau/Config.h b/Analysis/include/Luau/Config.h index 8ba4ffa56..88c10554b 100644 --- a/Analysis/include/Luau/Config.h +++ b/Analysis/include/Luau/Config.h @@ -19,7 +19,7 @@ struct Config { Config(); - Mode mode; + Mode mode = Mode::Nonstrict; ParseOptions parseOptions; diff --git a/Analysis/include/Luau/ConstraintGraphBuilder.h b/Analysis/include/Luau/ConstraintGraphBuilder.h index ababe0a36..eb1b1fede 100644 --- a/Analysis/include/Luau/ConstraintGraphBuilder.h +++ b/Analysis/include/Luau/ConstraintGraphBuilder.h @@ -94,12 +94,13 @@ struct ConstraintGraphBuilder ScopePtr globalScope; std::function prepareModuleScope; + std::vector requireCycles; DcrLogger* logger; ConstraintGraphBuilder(ModulePtr module, TypeArena* arena, NotNull moduleResolver, NotNull builtinTypes, NotNull ice, const ScopePtr& globalScope, std::function prepareModuleScope, - DcrLogger* logger, NotNull dfg); + DcrLogger* logger, NotNull dfg, std::vector requireCycles); /** * Fabricates a new free type belonging to a given scope. diff --git a/Analysis/include/Luau/ConstraintSolver.h b/Analysis/include/Luau/ConstraintSolver.h index b26d88c33..cba2cbb46 100644 --- a/Analysis/include/Luau/ConstraintSolver.h +++ b/Analysis/include/Luau/ConstraintSolver.h @@ -8,6 +8,7 @@ #include "Luau/Normalize.h" #include "Luau/ToString.h" #include "Luau/Type.h" +#include "Luau/TypeCheckLimits.h" #include "Luau/Variant.h" #include @@ -81,9 +82,11 @@ struct ConstraintSolver std::vector requireCycles; DcrLogger* logger; + TypeCheckLimits limits; explicit ConstraintSolver(NotNull normalizer, NotNull rootScope, std::vector> constraints, - ModuleName moduleName, NotNull moduleResolver, std::vector requireCycles, DcrLogger* logger); + ModuleName moduleName, NotNull moduleResolver, std::vector requireCycles, DcrLogger* logger, + TypeCheckLimits limits); // Randomize the order in which to dispatch constraints void randomize(unsigned seed); @@ -280,6 +283,9 @@ struct ConstraintSolver TypePackId anyifyModuleReturnTypePackGenerics(TypePackId tp); + void throwTimeLimitError(); + void throwUserCancelError(); + ToStringOptions opts; }; diff --git a/Analysis/include/Luau/Differ.h b/Analysis/include/Luau/Differ.h index da8b64685..60f555dc9 100644 --- a/Analysis/include/Luau/Differ.h +++ b/Analysis/include/Luau/Differ.h @@ -1,9 +1,11 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include "Luau/DenseHash.h" #include "Luau/Type.h" #include #include +#include namespace Luau { @@ -17,6 +19,7 @@ struct DiffPathNode FunctionReturn, Union, Intersection, + Negation, }; Kind kind; // non-null when TableProperty @@ -54,11 +57,15 @@ struct DiffPathNodeLeaf std::optional tableProperty; std::optional minLength; bool isVariadic; - DiffPathNodeLeaf(std::optional ty, std::optional tableProperty, std::optional minLength, bool isVariadic) + // TODO: Rename to anonymousIndex, for both union and Intersection + std::optional unionIndex; + DiffPathNodeLeaf( + std::optional ty, std::optional tableProperty, std::optional minLength, bool isVariadic, std::optional unionIndex) : ty(ty) , tableProperty(tableProperty) , minLength(minLength) , isVariadic(isVariadic) + , unionIndex(unionIndex) { } @@ -66,6 +73,8 @@ struct DiffPathNodeLeaf static DiffPathNodeLeaf detailsTableProperty(TypeId ty, Name tableProperty); + static DiffPathNodeLeaf detailsUnionIndex(TypeId ty, size_t index); + static DiffPathNodeLeaf detailsLength(int minLength, bool isVariadic); static DiffPathNodeLeaf nullopts(); @@ -82,11 +91,12 @@ struct DiffError enum Kind { Normal, - MissingProperty, + MissingTableProperty, + MissingUnionMember, + MissingIntersectionMember, + IncompatibleGeneric, LengthMismatchInFnArgs, LengthMismatchInFnRets, - LengthMismatchInUnion, - LengthMismatchInIntersection, }; Kind kind; @@ -141,6 +151,8 @@ struct DifferEnvironment { TypeId rootLeft; TypeId rootRight; + + DenseHashMap genericMatchedPairs; }; DifferResult diff(TypeId ty1, TypeId ty2); diff --git a/Analysis/include/Luau/Frontend.h b/Analysis/include/Luau/Frontend.h index 5804b7a8c..5853eb32c 100644 --- a/Analysis/include/Luau/Frontend.h +++ b/Analysis/include/Luau/Frontend.h @@ -6,6 +6,7 @@ #include "Luau/ModuleResolver.h" #include "Luau/RequireTracer.h" #include "Luau/Scope.h" +#include "Luau/TypeCheckLimits.h" #include "Luau/TypeInfer.h" #include "Luau/Variant.h" @@ -189,14 +190,6 @@ struct Frontend std::optional getCheckResult(const ModuleName& name, bool accumulateNested, bool forAutocomplete = false); private: - struct TypeCheckLimits - { - std::optional finishTime; - std::optional instantiationChildLimit; - std::optional unifierIterationLimit; - std::shared_ptr cancellationToken; - }; - ModulePtr check(const SourceModule& sourceModule, Mode mode, std::vector requireCycles, std::optional environmentScope, bool forAutocomplete, bool recordJsonLog, TypeCheckLimits typeCheckLimits); @@ -248,11 +241,12 @@ struct Frontend ModulePtr check(const SourceModule& sourceModule, const std::vector& requireCycles, NotNull builtinTypes, NotNull iceHandler, NotNull moduleResolver, NotNull fileResolver, - const ScopePtr& globalScope, std::function prepareModuleScope, FrontendOptions options); + const ScopePtr& globalScope, std::function prepareModuleScope, FrontendOptions options, + TypeCheckLimits limits); ModulePtr check(const SourceModule& sourceModule, const std::vector& requireCycles, NotNull builtinTypes, NotNull iceHandler, NotNull moduleResolver, NotNull fileResolver, const ScopePtr& globalScope, std::function prepareModuleScope, FrontendOptions options, - bool recordJsonLog); + TypeCheckLimits limits, bool recordJsonLog); } // namespace Luau diff --git a/Analysis/include/Luau/Normalize.h b/Analysis/include/Luau/Normalize.h index 75c07a7be..0f9352d15 100644 --- a/Analysis/include/Luau/Normalize.h +++ b/Analysis/include/Luau/Normalize.h @@ -277,6 +277,9 @@ struct NormalizedType /// Returns true if this type should result in error suppressing behavior. bool shouldSuppressErrors() const; + /// Returns true if this type contains the primitve top table type, `table`. + bool hasTopTable() const; + // Helpers that improve readability of the above (they just say if the component is present) bool hasTops() const; bool hasBooleans() const; diff --git a/Analysis/include/Luau/Symbol.h b/Analysis/include/Luau/Symbol.h index b47554e0d..337e2a9f2 100644 --- a/Analysis/include/Luau/Symbol.h +++ b/Analysis/include/Luau/Symbol.h @@ -6,8 +6,6 @@ #include -LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) - namespace Luau { @@ -42,17 +40,7 @@ struct Symbol return local != nullptr || global.value != nullptr; } - bool operator==(const Symbol& rhs) const - { - if (local) - return local == rhs.local; - else if (global.value) - return rhs.global.value && global == rhs.global.value; // Subtlety: AstName::operator==(const char*) uses strcmp, not pointer identity. - else if (FFlag::DebugLuauDeferredConstraintResolution) - return !rhs.local && !rhs.global.value; // Reflexivity: we already know `this` Symbol is empty, so check that rhs is. - else - return false; - } + bool operator==(const Symbol& rhs) const; bool operator!=(const Symbol& rhs) const { diff --git a/Analysis/include/Luau/TypeArena.h b/Analysis/include/Luau/TypeArena.h index 0e69bb4aa..5f831f18a 100644 --- a/Analysis/include/Luau/TypeArena.h +++ b/Analysis/include/Luau/TypeArena.h @@ -9,12 +9,16 @@ namespace Luau { +struct Module; struct TypeArena { TypedAllocator types; TypedAllocator typePacks; + // Owning module, if any + Module* owningModule = nullptr; + void clear(); template diff --git a/Analysis/include/Luau/TypeCheckLimits.h b/Analysis/include/Luau/TypeCheckLimits.h new file mode 100644 index 000000000..9eabe0ff6 --- /dev/null +++ b/Analysis/include/Luau/TypeCheckLimits.h @@ -0,0 +1,41 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Cancellation.h" +#include "Luau/Error.h" + +#include +#include +#include + +namespace Luau +{ + +class TimeLimitError : public InternalCompilerError +{ +public: + explicit TimeLimitError(const std::string& moduleName) + : InternalCompilerError("Typeinfer failed to complete in allotted time", moduleName) + { + } +}; + +class UserCancelError : public InternalCompilerError +{ +public: + explicit UserCancelError(const std::string& moduleName) + : InternalCompilerError("Analysis has been cancelled by user", moduleName) + { + } +}; + +struct TypeCheckLimits +{ + std::optional finishTime; + std::optional instantiationChildLimit; + std::optional unifierIterationLimit; + + std::shared_ptr cancellationToken; +}; + +} // namespace Luau diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index 79ee60c46..9a44af496 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -10,6 +10,7 @@ #include "Luau/Symbol.h" #include "Luau/TxnLog.h" #include "Luau/Type.h" +#include "Luau/TypeCheckLimits.h" #include "Luau/TypePack.h" #include "Luau/TypeUtils.h" #include "Luau/Unifier.h" @@ -56,24 +57,6 @@ struct HashBoolNamePair size_t operator()(const std::pair& pair) const; }; -class TimeLimitError : public InternalCompilerError -{ -public: - explicit TimeLimitError(const std::string& moduleName) - : InternalCompilerError("Typeinfer failed to complete in allotted time", moduleName) - { - } -}; - -class UserCancelError : public InternalCompilerError -{ -public: - explicit UserCancelError(const std::string& moduleName) - : InternalCompilerError("Analysis has been cancelled by user", moduleName) - { - } -}; - struct GlobalTypes { GlobalTypes(NotNull builtinTypes); diff --git a/Analysis/include/Luau/TypeUtils.h b/Analysis/include/Luau/TypeUtils.h index 84916cd24..793415ee0 100644 --- a/Analysis/include/Luau/TypeUtils.h +++ b/Analysis/include/Luau/TypeUtils.h @@ -14,6 +14,7 @@ namespace Luau struct TxnLog; struct TypeArena; +class Normalizer; enum class ValueContext { @@ -55,6 +56,51 @@ std::vector reduceUnion(const std::vector& types); */ TypeId stripNil(NotNull builtinTypes, TypeArena& arena, TypeId ty); +enum class ErrorSuppression +{ + Suppress, + DoNotSuppress, + NormalizationFailed +}; + +/** + * Normalizes the given type using the normalizer to determine if the type + * should suppress any errors that would be reported involving it. + * @param normalizer the normalizer to use + * @param ty the type to check for error suppression + * @returns an enum indicating whether or not to suppress the error or to signal a normalization failure + */ +ErrorSuppression shouldSuppressErrors(NotNull normalizer, TypeId ty); + +/** + * Flattens and normalizes the given typepack using the normalizer to determine if the type + * should suppress any errors that would be reported involving it. + * @param normalizer the normalizer to use + * @param tp the typepack to check for error suppression + * @returns an enum indicating whether or not to suppress the error or to signal a normalization failure + */ +ErrorSuppression shouldSuppressErrors(NotNull normalizer, TypePackId tp); + +/** + * Normalizes the two given type using the normalizer to determine if either type + * should suppress any errors that would be reported involving it. + * @param normalizer the normalizer to use + * @param ty1 the first type to check for error suppression + * @param ty2 the second type to check for error suppression + * @returns an enum indicating whether or not to suppress the error or to signal a normalization failure + */ +ErrorSuppression shouldSuppressErrors(NotNull normalizer, TypeId ty1, TypeId ty2); + +/** + * Flattens and normalizes the two given typepacks using the normalizer to determine if either type + * should suppress any errors that would be reported involving it. + * @param normalizer the normalizer to use + * @param tp1 the first typepack to check for error suppression + * @param tp2 the second typepack to check for error suppression + * @returns an enum indicating whether or not to suppress the error or to signal a normalization failure + */ +ErrorSuppression shouldSuppressErrors(NotNull normalizer, TypePackId tp1, TypePackId tp2); + template const T* get(std::optional ty) { diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index 7a6a2f760..f7c5c94c0 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -43,6 +43,21 @@ struct Widen : Substitution TypePackId operator()(TypePackId ty); }; +/** + * Normally, when we unify table properties, we must do so invariantly, but we + * can introduce a special exception: If the table property in the subtype + * position arises from a literal expression, it is safe to instead perform a + * covariant check. + * + * This is very useful for typechecking cases where table literals (and trees of + * table literals) are passed directly to functions. + * + * In this case, we know that the property has no other name referring to it and + * so it is perfectly safe for the function to mutate the table any way it + * wishes. + */ +using LiteralProperties = DenseHashSet; + // TODO: Use this more widely. struct UnifierOptions { @@ -80,7 +95,7 @@ struct Unifier // Configure the Unifier to test for scope subsumption via embedded Scope // pointers rather than TypeLevels. - void enableScopeTests(); + void enableNewSolver(); // Test whether the two type vars unify. Never commits the result. ErrorVec canUnify(TypeId subTy, TypeId superTy); @@ -90,10 +105,10 @@ struct Unifier * Populate the vector errors with any type errors that may arise. * Populate the transaction log with the set of TypeIds that need to be reset to undo the unification attempt. */ - void tryUnify(TypeId subTy, TypeId superTy, bool isFunctionCall = false, bool isIntersection = false); + void tryUnify(TypeId subTy, TypeId superTy, bool isFunctionCall = false, bool isIntersection = false, const LiteralProperties* aliasableMap = nullptr); private: - void tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall = false, bool isIntersection = false); + void tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall = false, bool isIntersection = false, const LiteralProperties* aliasableMap = nullptr); void tryUnifyUnionWithType(TypeId subTy, const UnionType* uv, TypeId superTy); // Traverse the two types provided and block on any BlockedTypes we find. @@ -108,7 +123,7 @@ struct Unifier void tryUnifyPrimitives(TypeId subTy, TypeId superTy); void tryUnifySingletons(TypeId subTy, TypeId superTy); void tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCall = false); - void tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection = false); + void tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection = false, const LiteralProperties* aliasableMap = nullptr); void tryUnifyScalarShape(TypeId subTy, TypeId superTy, bool reversed); void tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed); void tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed); @@ -163,8 +178,10 @@ struct Unifier // Available after regular type pack unification errors std::optional firstPackErrorPos; - // If true, we use the scope hierarchy rather than TypeLevels - bool useScopes = false; + // If true, we do a bunch of small things differently to work better with + // the new type inference engine. Most notably, we use the Scope hierarchy + // directly rather than using TypeLevels. + bool useNewSolver = false; }; void promoteTypeLevels(TxnLog& log, const TypeArena* arena, TypeLevel minLevel, Scope* outerScope, bool useScope, TypePackId tp); diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index d67eda8d5..cd3f4c6e3 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -13,6 +13,10 @@ #include LUAU_FASTFLAG(DebugLuauReadWriteProperties) +LUAU_FASTFLAGVARIABLE(LuauDisableCompletionOutsideQuotes, false) +LUAU_FASTFLAGVARIABLE(LuauAnonymousAutofilled, false); +LUAU_FASTFLAGVARIABLE(LuauAutocompleteLastTypecheck, false) +LUAU_FASTFLAGVARIABLE(LuauAutocompleteHideSelfArg, false) static const std::unordered_set kStatementStartingKeywords = { "while", "if", "local", "repeat", "function", "do", "for", "return", "break", "continue", "type", "export"}; @@ -280,18 +284,38 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, NotNul ParenthesesRecommendation parens = indexType == PropIndexType::Key ? ParenthesesRecommendation::None : getParenRecommendation(type, nodes, typeCorrect); - result[name] = AutocompleteEntry{ - AutocompleteEntryKind::Property, - type, - prop.deprecated, - isWrongIndexer(type), - typeCorrect, - containingClass, - &prop, - prop.documentationSymbol, - {}, - parens, - }; + if (FFlag::LuauAutocompleteHideSelfArg) + { + result[name] = AutocompleteEntry{ + AutocompleteEntryKind::Property, + type, + prop.deprecated, + isWrongIndexer(type), + typeCorrect, + containingClass, + &prop, + prop.documentationSymbol, + {}, + parens, + {}, + indexType == PropIndexType::Colon + }; + } + else + { + result[name] = AutocompleteEntry{ + AutocompleteEntryKind::Property, + type, + prop.deprecated, + isWrongIndexer(type), + typeCorrect, + containingClass, + &prop, + prop.documentationSymbol, + {}, + parens + }; + } } } }; @@ -591,14 +615,14 @@ std::optional getLocalTypeInScopeAt(const Module& module, Position posit return {}; } -static std::optional tryGetTypeNameInScope(ScopePtr scope, TypeId ty) +template +static std::optional tryToStringDetailed(const ScopePtr& scope, T ty, bool functionTypeArguments) { - if (!canSuggestInferredType(scope, ty)) - return std::nullopt; - + LUAU_ASSERT(FFlag::LuauAnonymousAutofilled); ToStringOptions opts; opts.useLineBreaks = false; opts.hideTableKind = true; + opts.functionTypeArguments = functionTypeArguments; opts.scope = scope; ToStringResult name = toStringDetailed(ty, opts); @@ -608,6 +632,30 @@ static std::optional tryGetTypeNameInScope(ScopePtr scope, TypeId ty) return name.name; } +static std::optional tryGetTypeNameInScope(ScopePtr scope, TypeId ty, bool functionTypeArguments = false) +{ + if (!canSuggestInferredType(scope, ty)) + return std::nullopt; + + if (FFlag::LuauAnonymousAutofilled) + { + return tryToStringDetailed(scope, ty, functionTypeArguments); + } + else + { + ToStringOptions opts; + opts.useLineBreaks = false; + opts.hideTableKind = true; + opts.scope = scope; + ToStringResult name = toStringDetailed(ty, opts); + + if (name.error || name.invalid || name.cycle || name.truncated) + return std::nullopt; + + return name.name; + } +} + static bool tryAddTypeCorrectSuggestion(AutocompleteEntryMap& result, ScopePtr scope, AstType* topType, TypeId inferredType, Position position) { std::optional ty; @@ -1297,6 +1345,14 @@ static std::optional autocompleteStringParams(const Source return std::nullopt; } + if (FFlag::LuauDisableCompletionOutsideQuotes && !nodes.back()->is()) + { + if (nodes.back()->location.end == position || nodes.back()->location.begin == position) + { + return std::nullopt; + } + } + AstExprCall* candidate = nodes.at(nodes.size() - 2)->as(); if (!candidate) { @@ -1361,6 +1417,140 @@ static AutocompleteResult autocompleteWhileLoopKeywords(std::vector an return {std::move(ret), std::move(ancestry), AutocompleteContext::Keyword}; } +static std::string makeAnonymous(const ScopePtr& scope, const FunctionType& funcTy) +{ + LUAU_ASSERT(FFlag::LuauAnonymousAutofilled); + std::string result = "function("; + + auto [args, tail] = Luau::flatten(funcTy.argTypes); + + bool first = true; + // Skip the implicit 'self' argument if call is indexed with ':' + for (size_t argIdx = 0; argIdx < args.size(); ++argIdx) + { + if (!first) + result += ", "; + else + first = false; + + std::string name; + if (argIdx < funcTy.argNames.size() && funcTy.argNames[argIdx]) + name = funcTy.argNames[argIdx]->name; + else + name = "a" + std::to_string(argIdx); + + if (std::optional type = tryGetTypeNameInScope(scope, args[argIdx], true)) + result += name + ": " + *type; + else + result += name; + } + + if (tail && (Luau::isVariadic(*tail) || Luau::get(Luau::follow(*tail)))) + { + if (!first) + result += ", "; + + std::optional varArgType; + if (const VariadicTypePack* pack = get(follow(*tail))) + { + if (std::optional res = tryToStringDetailed(scope, pack->ty, true)) + varArgType = std::move(res); + } + + if (varArgType) + result += "...: " + *varArgType; + else + result += "..."; + } + + result += ")"; + + auto [rets, retTail] = Luau::flatten(funcTy.retTypes); + if (const size_t totalRetSize = rets.size() + (retTail ? 1 : 0); totalRetSize > 0) + { + if (std::optional returnTypes = tryToStringDetailed(scope, funcTy.retTypes, true)) + { + result += ": "; + bool wrap = totalRetSize != 1; + if (wrap) + result += "("; + result += *returnTypes; + if (wrap) + result += ")"; + } + } + result += " end"; + return result; +} + +static std::optional makeAnonymousAutofilled(const ModulePtr& module, Position position, const AstNode* node, const std::vector& ancestry) +{ + LUAU_ASSERT(FFlag::LuauAnonymousAutofilled); + const AstExprCall* call = node->as(); + if (!call && ancestry.size() > 1) + call = ancestry[ancestry.size() - 2]->as(); + + if (!call) + return std::nullopt; + + if (!call->location.containsClosed(position) || call->func->location.containsClosed(position)) + return std::nullopt; + + TypeId* typeIter = module->astTypes.find(call->func); + if (!typeIter) + return std::nullopt; + + const FunctionType* outerFunction = get(follow(*typeIter)); + if (!outerFunction) + return std::nullopt; + + size_t argument = 0; + for (size_t i = 0; i < call->args.size; ++i) + { + if (call->args.data[i]->location.containsClosed(position)) + { + argument = i; + break; + } + } + + if (call->self) + argument++; + + std::optional argType; + auto [args, tail] = flatten(outerFunction->argTypes); + if (argument < args.size()) + argType = args[argument]; + + if (!argType) + return std::nullopt; + + TypeId followed = follow(*argType); + const FunctionType* type = get(followed); + if (!type) + { + if (const UnionType* unionType = get(followed)) + { + if (std::optional nonnullFunction = returnFirstNonnullOptionOfType(unionType)) + type = *nonnullFunction; + } + } + + if (!type) + return std::nullopt; + + const ScopePtr scope = findScopeAtPosition(*module, position); + if (!scope) + return std::nullopt; + + AutocompleteEntry entry; + entry.kind = AutocompleteEntryKind::GeneratedFunction; + entry.typeCorrect = TypeCorrectKind::Correct; + entry.type = argType; + entry.insertText = makeAnonymous(scope, *type); + return std::make_optional(std::move(entry)); +} + static AutocompleteResult autocomplete(const SourceModule& sourceModule, const ModulePtr& module, NotNull builtinTypes, TypeArena* typeArena, Scope* globalScope, Position position, StringCompletionCallback callback) { @@ -1612,7 +1802,19 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M return {}; if (node->asExpr()) - return autocompleteExpression(sourceModule, *module, builtinTypes, typeArena, ancestry, position); + { + if (FFlag::LuauAnonymousAutofilled) + { + AutocompleteResult ret = autocompleteExpression(sourceModule, *module, builtinTypes, typeArena, ancestry, position); + if (std::optional generated = makeAnonymousAutofilled(module, position, node, ancestry)) + ret.entryMap[kGeneratedAnonymousFunctionEntryName] = std::move(*generated); + return ret; + } + else + { + return autocompleteExpression(sourceModule, *module, builtinTypes, typeArena, ancestry, position); + } + } else if (node->asStat()) return {autocompleteStatement(sourceModule, *module, ancestry, position), ancestry, AutocompleteContext::Statement}; @@ -1621,11 +1823,14 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M AutocompleteResult autocomplete(Frontend& frontend, const ModuleName& moduleName, Position position, StringCompletionCallback callback) { - // FIXME: We can improve performance here by parsing without checking. - // The old type graph is probably fine. (famous last words!) - FrontendOptions opts; - opts.forAutocomplete = true; - frontend.check(moduleName, opts); + if (!FFlag::LuauAutocompleteLastTypecheck) + { + // FIXME: We can improve performance here by parsing without checking. + // The old type graph is probably fine. (famous last words!) + FrontendOptions opts; + opts.forAutocomplete = true; + frontend.check(moduleName, opts); + } const SourceModule* sourceModule = frontend.getSourceModule(moduleName); if (!sourceModule) diff --git a/Analysis/src/Config.cpp b/Analysis/src/Config.cpp index 00ca7b16f..9369743ed 100644 --- a/Analysis/src/Config.cpp +++ b/Analysis/src/Config.cpp @@ -4,15 +4,12 @@ #include "Luau/Lexer.h" #include "Luau/StringUtils.h" -LUAU_FASTFLAGVARIABLE(LuauEnableNonstrictByDefaultForLuauConfig, false) - namespace Luau { using Error = std::optional; Config::Config() - : mode(FFlag::LuauEnableNonstrictByDefaultForLuauConfig ? Mode::Nonstrict : Mode::NoCheck) { enabledLint.setDefaults(); } diff --git a/Analysis/src/ConstraintGraphBuilder.cpp b/Analysis/src/ConstraintGraphBuilder.cpp index c62c214c7..9c2766ec7 100644 --- a/Analysis/src/ConstraintGraphBuilder.cpp +++ b/Analysis/src/ConstraintGraphBuilder.cpp @@ -139,7 +139,8 @@ void forEachConstraint(const Checkpoint& start, const Checkpoint& end, const Con ConstraintGraphBuilder::ConstraintGraphBuilder(ModulePtr module, TypeArena* arena, NotNull moduleResolver, NotNull builtinTypes, NotNull ice, const ScopePtr& globalScope, - std::function prepareModuleScope, DcrLogger* logger, NotNull dfg) + std::function prepareModuleScope, DcrLogger* logger, NotNull dfg, + std::vector requireCycles) : module(module) , builtinTypes(builtinTypes) , arena(arena) @@ -149,6 +150,7 @@ ConstraintGraphBuilder::ConstraintGraphBuilder(ModulePtr module, TypeArena* aren , ice(ice) , globalScope(globalScope) , prepareModuleScope(std::move(prepareModuleScope)) + , requireCycles(std::move(requireCycles)) , logger(logger) { LUAU_ASSERT(module); @@ -703,6 +705,16 @@ ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocal* l { scope->importedTypeBindings[name] = module->exportedTypeBindings; scope->importedModules[name] = moduleInfo->name; + + // Imported types of requires that transitively refer to current module have to be replaced with 'any' + for (const auto& [location, path] : requireCycles) + { + if (!path.empty() && path.front() == moduleInfo->name) + { + for (auto& [name, tf] : scope->importedTypeBindings[name]) + tf = TypeFun{{}, {}, builtinTypes->anyType}; + } + } } } } @@ -1913,6 +1925,9 @@ std::tuple ConstraintGraphBuilder::checkBinary( NullableBreadcrumbId bc = dfg->getBreadcrumb(typeguard->target); if (!bc) return {leftType, rightType, nullptr}; + auto augmentForErrorSupression = [&](TypeId ty) -> TypeId { + return arena->addType(UnionType{{ty, builtinTypes->errorType}}); + }; TypeId discriminantTy = builtinTypes->neverType; if (typeguard->type == "nil") @@ -1926,9 +1941,9 @@ std::tuple ConstraintGraphBuilder::checkBinary( else if (typeguard->type == "thread") discriminantTy = builtinTypes->threadType; else if (typeguard->type == "table") - discriminantTy = builtinTypes->tableType; + discriminantTy = augmentForErrorSupression(builtinTypes->tableType); else if (typeguard->type == "function") - discriminantTy = builtinTypes->functionType; + discriminantTy = augmentForErrorSupression(builtinTypes->functionType); else if (typeguard->type == "userdata") { // For now, we don't really care about being accurate with userdata if the typeguard was using typeof. diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index c9b584fd6..fbe081627 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -12,6 +12,7 @@ #include "Luau/ModuleResolver.h" #include "Luau/Quantify.h" #include "Luau/Simplify.h" +#include "Luau/TimeTrace.h" #include "Luau/ToString.h" #include "Luau/Type.h" #include "Luau/TypeFamily.h" @@ -259,7 +260,7 @@ struct InstantiationQueuer : TypeOnceVisitor }; ConstraintSolver::ConstraintSolver(NotNull normalizer, NotNull rootScope, std::vector> constraints, - ModuleName moduleName, NotNull moduleResolver, std::vector requireCycles, DcrLogger* logger) + ModuleName moduleName, NotNull moduleResolver, std::vector requireCycles, DcrLogger* logger, TypeCheckLimits limits) : arena(normalizer->arena) , builtinTypes(normalizer->builtinTypes) , normalizer(normalizer) @@ -269,6 +270,7 @@ ConstraintSolver::ConstraintSolver(NotNull normalizer, NotNull *limits.finishTime) + throwTimeLimitError(); + if (limits.cancellationToken && limits.cancellationToken->requested()) + throwUserCancelError(); + std::string saveMe = FFlag::DebugLuauLogSolver ? toString(*c, opts) : std::string{}; StepSnapshot snapshot; @@ -555,6 +562,9 @@ bool ConstraintSolver::tryDispatch(const InstantiationConstraint& c, NotNullscope); + if (limits.instantiationChildLimit) + inst.childLimit = *limits.instantiationChildLimit; + std::optional instantiated = inst.substitute(c.superType); LUAU_ASSERT(get(c.subType)); @@ -586,7 +596,7 @@ bool ConstraintSolver::tryDispatch(const UnaryConstraint& c, NotNull(operandType)) + if (!force && get(operandType)) return block(operandType, constraint); LUAU_ASSERT(get(c.resultType)); @@ -713,6 +723,10 @@ bool ConstraintSolver::tryDispatch(const BinaryConstraint& c, NotNullscope}; + + if (limits.instantiationChildLimit) + instantiation.childLimit = *limits.instantiationChildLimit; + std::optional instantiatedMm = instantiation.substitute(*mm); if (!instantiatedMm) { @@ -1318,6 +1332,9 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNullscope); + if (limits.instantiationChildLimit) + inst.childLimit = *limits.instantiationChildLimit; + std::vector arityMatchingOverloads; std::optional bestOverloadLog; @@ -1334,7 +1351,7 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNullscope, Location{}, Covariant}; - u.enableScopeTests(); + u.enableNewSolver(); u.tryUnify(*instantiated, inferredTy, /* isFunctionCall */ true); @@ -1384,7 +1401,7 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNullscope, Location{}, Covariant}; - u.enableScopeTests(); + u.enableNewSolver(); u.tryUnify(inferredTy, builtinTypes->anyType); u.tryUnify(fn, builtinTypes->anyType); @@ -1505,6 +1522,7 @@ static void updateTheTableType( for (size_t i = 0; i < path.size() - 1; ++i) { + t = follow(t); auto propTy = findTablePropertyRespectingMeta(builtinTypes, dummy, t, path[i], Location{}); dummy.clear(); @@ -1885,19 +1903,20 @@ bool ConstraintSolver::tryDispatch(const RefineConstraint& c, NotNullnormalize(c.type); - - if (!normType) - reportError(NormalizationTooComplex{}, constraint->location); - - if (normType && normType->shouldSuppressErrors()) + switch (shouldSuppressErrors(normalizer, c.type)) + { + case ErrorSuppression::Suppress: { auto resultOrError = simplifyUnion(builtinTypes, arena, result, builtinTypes->errorType).result; asMutable(c.resultType)->ty.emplace(resultOrError); + break; } - else - { + case ErrorSuppression::DoNotSuppress: asMutable(c.resultType)->ty.emplace(result); + break; + case ErrorSuppression::NormalizationFailed: + reportError(NormalizationTooComplex{}, constraint->location); + break; } unblock(c.resultType, constraint->location); @@ -1983,6 +2002,15 @@ bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const Iterabl unify(*anyified, ty, constraint->scope); }; + auto unknownify = [&](auto ty) { + Anyification anyify{arena, constraint->scope, builtinTypes, &iceReporter, builtinTypes->unknownType, builtinTypes->anyTypePack}; + std::optional anyified = anyify.substitute(ty); + if (!anyified) + reportError(CodeTooComplex{}, constraint->location); + else + unify(*anyified, ty, constraint->scope); + }; + auto errorify = [&](auto ty) { Anyification anyify{arena, constraint->scope, builtinTypes, &iceReporter, errorRecoveryType(), errorRecoveryTypePack()}; std::optional errorified = anyify.substitute(ty); @@ -2051,6 +2079,9 @@ bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const Iterabl Instantiation instantiation(TxnLog::empty(), arena, TypeLevel{}, constraint->scope); + if (limits.instantiationChildLimit) + instantiation.childLimit = *limits.instantiationChildLimit; + if (std::optional instantiatedIterFn = instantiation.substitute(*iterFn)) { if (auto iterFtv = get(*instantiatedIterFn)) @@ -2107,6 +2138,8 @@ bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const Iterabl LUAU_ASSERT(false); } + else if (auto primitiveTy = get(iteratorTy); primitiveTy && primitiveTy->type == PrimitiveType::Type::Table) + unknownify(c.variables); else errorify(c.variables); @@ -2357,7 +2390,7 @@ template bool ConstraintSolver::tryUnify(NotNull constraint, TID subTy, TID superTy) { Unifier u{normalizer, constraint->scope, constraint->location, Covariant}; - u.enableScopeTests(); + u.enableNewSolver(); u.tryUnify(subTy, superTy); @@ -2606,7 +2639,7 @@ bool ConstraintSolver::isBlocked(NotNull constraint) ErrorVec ConstraintSolver::unify(TypeId subType, TypeId superType, NotNull scope) { Unifier u{normalizer, scope, Location{}, Covariant}; - u.enableScopeTests(); + u.enableNewSolver(); u.tryUnify(subType, superType); @@ -2631,7 +2664,7 @@ ErrorVec ConstraintSolver::unify(TypePackId subPack, TypePackId superPack, NotNu { UnifierSharedState sharedState{&iceReporter}; Unifier u{normalizer, scope, Location{}, Covariant}; - u.enableScopeTests(); + u.enableNewSolver(); u.tryUnify(subPack, superPack); @@ -2728,7 +2761,7 @@ TypeId ConstraintSolver::unionOfTypes(TypeId a, TypeId b, NotNull scope, if (unifyFreeTypes && (get(a) || get(b))) { Unifier u{normalizer, scope, Location{}, Covariant}; - u.enableScopeTests(); + u.enableNewSolver(); u.tryUnify(b, a); if (u.errors.empty()) @@ -2785,4 +2818,14 @@ TypePackId ConstraintSolver::anyifyModuleReturnTypePackGenerics(TypePackId tp) return arena->addTypePack(resultTypes, resultTail); } +LUAU_NOINLINE void ConstraintSolver::throwTimeLimitError() +{ + throw TimeLimitError(currentModuleName); +} + +LUAU_NOINLINE void ConstraintSolver::throwUserCancelError() +{ + throw UserCancelError(currentModuleName); +} + } // namespace Luau diff --git a/Analysis/src/Differ.cpp b/Analysis/src/Differ.cpp index 50672cd9e..307446ef3 100644 --- a/Analysis/src/Differ.cpp +++ b/Analysis/src/Differ.cpp @@ -1,11 +1,15 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Differ.h" +#include "Luau/Common.h" #include "Luau/Error.h" #include "Luau/ToString.h" #include "Luau/Type.h" #include "Luau/TypePack.h" +#include "Luau/Unifiable.h" #include #include +#include +#include namespace Luau { @@ -34,6 +38,10 @@ std::string DiffPathNode::toString() const // Add 1 because Lua is 1-indexed return "Ret[" + std::to_string(*index + 1) + "]"; } + case DiffPathNode::Kind::Negation: + { + return "Negation"; + } default: { throw InternalCompilerError{"DiffPathNode::toString is not exhaustive"}; @@ -58,22 +66,27 @@ DiffPathNode DiffPathNode::constructWithKind(Kind kind) DiffPathNodeLeaf DiffPathNodeLeaf::detailsNormal(TypeId ty) { - return DiffPathNodeLeaf{ty, std::nullopt, std::nullopt, false}; + return DiffPathNodeLeaf{ty, std::nullopt, std::nullopt, false, std::nullopt}; } DiffPathNodeLeaf DiffPathNodeLeaf::detailsTableProperty(TypeId ty, Name tableProperty) { - return DiffPathNodeLeaf{ty, tableProperty, std::nullopt, false}; + return DiffPathNodeLeaf{ty, tableProperty, std::nullopt, false, std::nullopt}; +} + +DiffPathNodeLeaf DiffPathNodeLeaf::detailsUnionIndex(TypeId ty, size_t index) +{ + return DiffPathNodeLeaf{ty, std::nullopt, std::nullopt, false, index}; } DiffPathNodeLeaf DiffPathNodeLeaf::detailsLength(int minLength, bool isVariadic) { - return DiffPathNodeLeaf{std::nullopt, std::nullopt, minLength, isVariadic}; + return DiffPathNodeLeaf{std::nullopt, std::nullopt, minLength, isVariadic, std::nullopt}; } DiffPathNodeLeaf DiffPathNodeLeaf::nullopts() { - return DiffPathNodeLeaf{std::nullopt, std::nullopt, std::nullopt, false}; + return DiffPathNodeLeaf{std::nullopt, std::nullopt, std::nullopt, false, std::nullopt}; } std::string DiffPath::toString(bool prependDot) const @@ -104,7 +117,7 @@ std::string DiffError::toStringALeaf(std::string rootName, const DiffPathNodeLea checkNonMissingPropertyLeavesHaveNulloptTableProperty(); return pathStr + " has type " + Luau::toString(*leaf.ty); } - case DiffError::Kind::MissingProperty: + case DiffError::Kind::MissingTableProperty: { if (leaf.ty.has_value()) { @@ -120,6 +133,38 @@ std::string DiffError::toStringALeaf(std::string rootName, const DiffPathNodeLea } throw InternalCompilerError{"Both leaf.ty and otherLeaf.ty is nullopt"}; } + case DiffError::Kind::MissingUnionMember: + { + // TODO: do normal case + if (leaf.ty.has_value()) + { + if (!leaf.unionIndex.has_value()) + throw InternalCompilerError{"leaf.unionIndex is nullopt"}; + return pathStr + " is a union containing type " + Luau::toString(*leaf.ty); + } + else if (otherLeaf.ty.has_value()) + { + return pathStr + " is a union missing type " + Luau::toString(*otherLeaf.ty); + } + throw InternalCompilerError{"Both leaf.ty and otherLeaf.ty is nullopt"}; + } + case DiffError::Kind::MissingIntersectionMember: + { + // TODO: better message for intersections + // An intersection of just functions is always an "overloaded function" + // An intersection of just tables is always a "joined table" + if (leaf.ty.has_value()) + { + if (!leaf.unionIndex.has_value()) + throw InternalCompilerError{"leaf.unionIndex is nullopt"}; + return pathStr + " is an intersection containing type " + Luau::toString(*leaf.ty); + } + else if (otherLeaf.ty.has_value()) + { + return pathStr + " is an intersection missing type " + Luau::toString(*otherLeaf.ty); + } + throw InternalCompilerError{"Both leaf.ty and otherLeaf.ty is nullopt"}; + } case DiffError::Kind::LengthMismatchInFnArgs: { if (!leaf.minLength.has_value()) @@ -163,9 +208,20 @@ std::string getDevFixFriendlyName(TypeId ty) std::string DiffError::toString() const { - std::string msg = "DiffError: these two types are not equal because the left type at " + toStringALeaf(leftRootName, left, right) + - ", while the right type at " + toStringALeaf(rightRootName, right, left); - return msg; + switch (kind) + { + case DiffError::Kind::IncompatibleGeneric: + { + std::string diffPathStr{diffPath.toString(true)}; + return "DiffError: these two types are not equal because the left generic at " + leftRootName + diffPathStr + + " cannot be the same type parameter as the right generic at " + rightRootName + diffPathStr; + } + default: + { + return "DiffError: these two types are not equal because the left type at " + toStringALeaf(leftRootName, left, right) + + ", while the right type at " + toStringALeaf(rightRootName, right, left); + } + } } void DiffError::checkValidInitialization(const DiffPathNodeLeaf& left, const DiffPathNodeLeaf& right) @@ -193,6 +249,19 @@ static DifferResult diffTable(DifferEnvironment& env, TypeId left, TypeId right) static DifferResult diffPrimitive(DifferEnvironment& env, TypeId left, TypeId right); static DifferResult diffSingleton(DifferEnvironment& env, TypeId left, TypeId right); static DifferResult diffFunction(DifferEnvironment& env, TypeId left, TypeId right); +static DifferResult diffGeneric(DifferEnvironment& env, TypeId left, TypeId right); +static DifferResult diffNegation(DifferEnvironment& env, TypeId left, TypeId right); +struct FindSeteqCounterexampleResult +{ + // nullopt if no counterexample found + std::optional mismatchIdx; + // true if counterexample is in the left, false if cex is in the right + bool inLeft; +}; +static FindSeteqCounterexampleResult findSeteqCounterexample( + DifferEnvironment& env, const std::vector& left, const std::vector& right); +static DifferResult diffUnion(DifferEnvironment& env, TypeId left, TypeId right); +static DifferResult diffIntersection(DifferEnvironment& env, TypeId left, TypeId right); /** * The last argument gives context info on which complex type contained the TypePack. */ @@ -205,6 +274,8 @@ static DifferResult diffTable(DifferEnvironment& env, TypeId left, TypeId right) { const TableType* leftTable = get(left); const TableType* rightTable = get(right); + LUAU_ASSERT(leftTable); + LUAU_ASSERT(rightTable); for (auto const& [field, value] : leftTable->props) { @@ -212,7 +283,7 @@ static DifferResult diffTable(DifferEnvironment& env, TypeId left, TypeId right) { // left has a field the right doesn't return DifferResult{DiffError{ - DiffError::Kind::MissingProperty, + DiffError::Kind::MissingTableProperty, DiffPathNodeLeaf::detailsTableProperty(value.type(), field), DiffPathNodeLeaf::nullopts(), getDevFixFriendlyName(env.rootLeft), @@ -225,9 +296,9 @@ static DifferResult diffTable(DifferEnvironment& env, TypeId left, TypeId right) if (leftTable->props.find(field) == leftTable->props.end()) { // right has a field the left doesn't - return DifferResult{ - DiffError{DiffError::Kind::MissingProperty, DiffPathNodeLeaf::nullopts(), DiffPathNodeLeaf::detailsTableProperty(value.type(), field), - getDevFixFriendlyName(env.rootLeft), getDevFixFriendlyName(env.rootRight)}}; + return DifferResult{DiffError{DiffError::Kind::MissingTableProperty, DiffPathNodeLeaf::nullopts(), + DiffPathNodeLeaf::detailsTableProperty(value.type(), field), getDevFixFriendlyName(env.rootLeft), + getDevFixFriendlyName(env.rootRight)}}; } } // left and right have the same set of keys @@ -248,6 +319,8 @@ static DifferResult diffPrimitive(DifferEnvironment& env, TypeId left, TypeId ri { const PrimitiveType* leftPrimitive = get(left); const PrimitiveType* rightPrimitive = get(right); + LUAU_ASSERT(leftPrimitive); + LUAU_ASSERT(rightPrimitive); if (leftPrimitive->type != rightPrimitive->type) { @@ -266,6 +339,8 @@ static DifferResult diffSingleton(DifferEnvironment& env, TypeId left, TypeId ri { const SingletonType* leftSingleton = get(left); const SingletonType* rightSingleton = get(right); + LUAU_ASSERT(leftSingleton); + LUAU_ASSERT(rightSingleton); if (*leftSingleton != *rightSingleton) { @@ -284,6 +359,8 @@ static DifferResult diffFunction(DifferEnvironment& env, TypeId left, TypeId rig { const FunctionType* leftFunction = get(left); const FunctionType* rightFunction = get(right); + LUAU_ASSERT(leftFunction); + LUAU_ASSERT(rightFunction); DifferResult differResult = diffTpi(env, DiffError::Kind::LengthMismatchInFnArgs, leftFunction->argTypes, rightFunction->argTypes); if (differResult.diffError.has_value()) @@ -291,6 +368,157 @@ static DifferResult diffFunction(DifferEnvironment& env, TypeId left, TypeId rig return diffTpi(env, DiffError::Kind::LengthMismatchInFnRets, leftFunction->retTypes, rightFunction->retTypes); } +static DifferResult diffGeneric(DifferEnvironment& env, TypeId left, TypeId right) +{ + LUAU_ASSERT(get(left)); + LUAU_ASSERT(get(right)); + // Try to pair up the generics + bool isLeftFree = !env.genericMatchedPairs.contains(left); + bool isRightFree = !env.genericMatchedPairs.contains(right); + if (isLeftFree && isRightFree) + { + env.genericMatchedPairs[left] = right; + env.genericMatchedPairs[right] = left; + return DifferResult{}; + } + else if (isLeftFree || isRightFree) + { + return DifferResult{DiffError{ + DiffError::Kind::IncompatibleGeneric, + DiffPathNodeLeaf::nullopts(), + DiffPathNodeLeaf::nullopts(), + getDevFixFriendlyName(env.rootLeft), + getDevFixFriendlyName(env.rootRight), + }}; + } + + // Both generics are already paired up + if (*env.genericMatchedPairs.find(left) == right) + return DifferResult{}; + + return DifferResult{DiffError{ + DiffError::Kind::IncompatibleGeneric, + DiffPathNodeLeaf::nullopts(), + DiffPathNodeLeaf::nullopts(), + getDevFixFriendlyName(env.rootLeft), + getDevFixFriendlyName(env.rootRight), + }}; +} + +static DifferResult diffNegation(DifferEnvironment& env, TypeId left, TypeId right) +{ + const NegationType* leftNegation = get(left); + const NegationType* rightNegation = get(right); + LUAU_ASSERT(leftNegation); + LUAU_ASSERT(rightNegation); + + DifferResult differResult = diffUsingEnv(env, leftNegation->ty, rightNegation->ty); + if (!differResult.diffError.has_value()) + return DifferResult{}; + + differResult.wrapDiffPath(DiffPathNode::constructWithKind(DiffPathNode::Kind::Negation)); + return differResult; +} + +static FindSeteqCounterexampleResult findSeteqCounterexample( + DifferEnvironment& env, const std::vector& left, const std::vector& right) +{ + std::unordered_set unmatchedRightIdxes; + for (size_t i = 0; i < right.size(); i++) + unmatchedRightIdxes.insert(i); + for (size_t leftIdx = 0; leftIdx < left.size(); leftIdx++) + { + bool leftIdxIsMatched = false; + auto unmatchedRightIdxIt = unmatchedRightIdxes.begin(); + while (unmatchedRightIdxIt != unmatchedRightIdxes.end()) + { + DifferResult differResult = diffUsingEnv(env, left[leftIdx], right[*unmatchedRightIdxIt]); + if (differResult.diffError.has_value()) + { + unmatchedRightIdxIt++; + continue; + } + + // unmatchedRightIdxIt is matched with current leftIdx + leftIdxIsMatched = true; + unmatchedRightIdxIt = unmatchedRightIdxes.erase(unmatchedRightIdxIt); + } + if (!leftIdxIsMatched) + { + return FindSeteqCounterexampleResult{leftIdx, true}; + } + } + if (unmatchedRightIdxes.empty()) + return FindSeteqCounterexampleResult{std::nullopt, false}; + return FindSeteqCounterexampleResult{*unmatchedRightIdxes.begin(), false}; +} + +static DifferResult diffUnion(DifferEnvironment& env, TypeId left, TypeId right) +{ + const UnionType* leftUnion = get(left); + const UnionType* rightUnion = get(right); + LUAU_ASSERT(leftUnion); + LUAU_ASSERT(rightUnion); + + FindSeteqCounterexampleResult findSeteqCexResult = findSeteqCounterexample(env, leftUnion->options, rightUnion->options); + if (findSeteqCexResult.mismatchIdx.has_value()) + { + if (findSeteqCexResult.inLeft) + return DifferResult{DiffError{ + DiffError::Kind::MissingUnionMember, + DiffPathNodeLeaf::detailsUnionIndex(leftUnion->options[*findSeteqCexResult.mismatchIdx], *findSeteqCexResult.mismatchIdx), + DiffPathNodeLeaf::nullopts(), + getDevFixFriendlyName(env.rootLeft), + getDevFixFriendlyName(env.rootRight), + }}; + else + return DifferResult{DiffError{ + DiffError::Kind::MissingUnionMember, + DiffPathNodeLeaf::nullopts(), + DiffPathNodeLeaf::detailsUnionIndex(rightUnion->options[*findSeteqCexResult.mismatchIdx], *findSeteqCexResult.mismatchIdx), + getDevFixFriendlyName(env.rootLeft), + getDevFixFriendlyName(env.rootRight), + }}; + } + + // TODO: somehow detect mismatch index, likely using heuristics + + return DifferResult{}; +} + +static DifferResult diffIntersection(DifferEnvironment& env, TypeId left, TypeId right) +{ + const IntersectionType* leftIntersection = get(left); + const IntersectionType* rightIntersection = get(right); + LUAU_ASSERT(leftIntersection); + LUAU_ASSERT(rightIntersection); + + FindSeteqCounterexampleResult findSeteqCexResult = findSeteqCounterexample(env, leftIntersection->parts, rightIntersection->parts); + if (findSeteqCexResult.mismatchIdx.has_value()) + { + if (findSeteqCexResult.inLeft) + return DifferResult{DiffError{ + DiffError::Kind::MissingIntersectionMember, + DiffPathNodeLeaf::detailsUnionIndex(leftIntersection->parts[*findSeteqCexResult.mismatchIdx], *findSeteqCexResult.mismatchIdx), + DiffPathNodeLeaf::nullopts(), + getDevFixFriendlyName(env.rootLeft), + getDevFixFriendlyName(env.rootRight), + }}; + else + return DifferResult{DiffError{ + DiffError::Kind::MissingIntersectionMember, + DiffPathNodeLeaf::nullopts(), + DiffPathNodeLeaf::detailsUnionIndex(rightIntersection->parts[*findSeteqCexResult.mismatchIdx], *findSeteqCexResult.mismatchIdx), + getDevFixFriendlyName(env.rootLeft), + getDevFixFriendlyName(env.rootRight), + }}; + } + + // TODO: somehow detect mismatch index, likely using heuristics + + return DifferResult{}; +} + static DifferResult diffUsingEnv(DifferEnvironment& env, TypeId left, TypeId right) { left = follow(left); @@ -322,6 +550,10 @@ static DifferResult diffUsingEnv(DifferEnvironment& env, TypeId left, TypeId rig // Both left and right must be Any if either is Any for them to be equal! return DifferResult{}; } + else if (auto ln = get(left)) + { + return diffNegation(env, left, right); + } throw InternalCompilerError{"Unimplemented Simple TypeId variant for diffing"}; } @@ -336,6 +568,24 @@ static DifferResult diffUsingEnv(DifferEnvironment& env, TypeId left, TypeId rig { return diffFunction(env, left, right); } + if (auto lg = get(left)) + { + return diffGeneric(env, left, right); + } + if (auto lu = get(left)) + { + return diffUnion(env, left, right); + } + if (auto li = get(left)) + { + return diffIntersection(env, left, right); + } + if (auto le = get(left)) + { + // TODO: return debug-friendly result state + return DifferResult{}; + } + throw InternalCompilerError{"Unimplemented non-simple TypeId variant for diffing"}; } @@ -444,7 +694,7 @@ static DifferResult diffHandleFlattenedTail(DifferEnvironment& env, DiffError::K DifferResult diff(TypeId ty1, TypeId ty2) { - DifferEnvironment differEnv{ty1, ty2}; + DifferEnvironment differEnv{ty1, ty2, DenseHashMap{nullptr}}; return diffUsingEnv(differEnv, ty1, ty2); } @@ -452,7 +702,7 @@ bool isSimple(TypeId ty) { ty = follow(ty); // TODO: think about GenericType, etc. - return get(ty) || get(ty) || get(ty); + return get(ty) || get(ty) || get(ty) || get(ty); } } // namespace Luau diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 2dea162bd..362fcdcce 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -1141,22 +1141,26 @@ const SourceModule* Frontend::getSourceModule(const ModuleName& moduleName) cons ModulePtr check(const SourceModule& sourceModule, const std::vector& requireCycles, NotNull builtinTypes, NotNull iceHandler, NotNull moduleResolver, NotNull fileResolver, - const ScopePtr& parentScope, std::function prepareModuleScope, FrontendOptions options) + const ScopePtr& parentScope, std::function prepareModuleScope, FrontendOptions options, + TypeCheckLimits limits) { const bool recordJsonLog = FFlag::DebugLuauLogSolverToJson; return check(sourceModule, requireCycles, builtinTypes, iceHandler, moduleResolver, fileResolver, parentScope, std::move(prepareModuleScope), - options, recordJsonLog); + options, limits, recordJsonLog); } ModulePtr check(const SourceModule& sourceModule, const std::vector& requireCycles, NotNull builtinTypes, NotNull iceHandler, NotNull moduleResolver, NotNull fileResolver, const ScopePtr& parentScope, std::function prepareModuleScope, FrontendOptions options, - bool recordJsonLog) + TypeCheckLimits limits, bool recordJsonLog) { ModulePtr result = std::make_shared(); result->name = sourceModule.name; result->humanReadableName = sourceModule.humanReadableName; + result->internalTypes.owningModule = result.get(); + result->interfaceTypes.owningModule = result.get(); + iceHandler->moduleName = sourceModule.name; std::unique_ptr logger; @@ -1174,32 +1178,34 @@ ModulePtr check(const SourceModule& sourceModule, const std::vectorinternalTypes, builtinTypes, NotNull{&unifierState}}; - ConstraintGraphBuilder cgb{ - result, - &result->internalTypes, - moduleResolver, - builtinTypes, - iceHandler, - parentScope, - std::move(prepareModuleScope), - logger.get(), - NotNull{&dfg}, - }; + ConstraintGraphBuilder cgb{result, &result->internalTypes, moduleResolver, builtinTypes, iceHandler, parentScope, std::move(prepareModuleScope), + logger.get(), NotNull{&dfg}, requireCycles}; cgb.visit(sourceModule.root); result->errors = std::move(cgb.errors); - ConstraintSolver cs{ - NotNull{&normalizer}, NotNull(cgb.rootScope), borrowConstraints(cgb.constraints), result->name, moduleResolver, requireCycles, logger.get()}; + ConstraintSolver cs{NotNull{&normalizer}, NotNull(cgb.rootScope), borrowConstraints(cgb.constraints), result->humanReadableName, moduleResolver, + requireCycles, logger.get(), limits}; if (options.randomizeConstraintResolutionSeed) cs.randomize(*options.randomizeConstraintResolutionSeed); - cs.run(); + try + { + cs.run(); + } + catch (const TimeLimitError&) + { + result->timeout = true; + } + catch (const UserCancelError&) + { + result->cancelled = true; + } for (TypeError& e : cs.errors) result->errors.emplace_back(std::move(e)); @@ -1209,7 +1215,22 @@ ModulePtr check(const SourceModule& sourceModule, const std::vectorclonePublicInterface(builtinTypes, *iceHandler); - Luau::check(builtinTypes, NotNull{&unifierState}, logger.get(), sourceModule, result.get()); + if (result->timeout || result->cancelled) + { + // If solver was interrupted, skip typechecking and replace all module results with error-supressing types to avoid leaking blocked/pending types + ScopePtr moduleScope = result->getModuleScope(); + moduleScope->returnType = builtinTypes->errorRecoveryTypePack(); + + for (auto& [name, ty] : result->declaredGlobals) + ty = builtinTypes->errorRecoveryType(); + + for (auto& [name, tf] : result->exportedTypeBindings) + tf.type = builtinTypes->errorRecoveryType(); + } + else + { + Luau::check(builtinTypes, NotNull{&unifierState}, logger.get(), sourceModule, result.get()); + } // It would be nice if we could freeze the arenas before doing type // checking, but we'll have to do some work to get there. @@ -1248,7 +1269,7 @@ ModulePtr Frontend::check(const SourceModule& sourceModule, Mode mode, std::vect { return Luau::check(sourceModule, requireCycles, builtinTypes, NotNull{&iceHandler}, NotNull{forAutocomplete ? &moduleResolverForAutocomplete : &moduleResolver}, NotNull{fileResolver}, - environmentScope ? *environmentScope : globals.globalScope, prepareModuleScopeWrap, options, recordJsonLog); + environmentScope ? *environmentScope : globals.globalScope, prepareModuleScopeWrap, options, typeCheckLimits, recordJsonLog); } catch (const InternalCompilerError& err) { diff --git a/Analysis/src/Linter.cpp b/Analysis/src/Linter.cpp index d6aafda62..4abc1aa1a 100644 --- a/Analysis/src/Linter.cpp +++ b/Analysis/src/Linter.cpp @@ -14,6 +14,8 @@ LUAU_FASTINTVARIABLE(LuauSuggestionDistance, 4) +LUAU_FASTFLAGVARIABLE(LuauLintNativeComment, false) + namespace Luau { @@ -2825,6 +2827,12 @@ static void lintComments(LintContext& context, const std::vector& ho "optimize directive uses unknown optimization level '%s', 0..2 expected", level); } } + else if (FFlag::LuauLintNativeComment && first == "native") + { + if (space != std::string::npos) + emitWarning(context, LintWarning::Code_CommentDirective, hc.location, + "native directive has extra symbols at the end of the line"); + } else { static const char* kHotComments[] = { @@ -2833,6 +2841,7 @@ static void lintComments(LintContext& context, const std::vector& ho "nonstrict", "strict", "optimize", + "native", }; if (const char* suggestion = fuzzyMatch(first, kHotComments, std::size(kHotComments))) diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp index 33a8b6eb1..bcad75b04 100644 --- a/Analysis/src/Normalize.cpp +++ b/Analysis/src/Normalize.cpp @@ -19,7 +19,6 @@ LUAU_FASTINTVARIABLE(LuauNormalizeIterationLimit, 1200); LUAU_FASTINTVARIABLE(LuauNormalizeCacheLimit, 100000); LUAU_FASTFLAGVARIABLE(LuauNormalizeBlockedTypes, false); LUAU_FASTFLAGVARIABLE(LuauNormalizeCyclicUnions, false); -LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) LUAU_FASTFLAG(LuauTransitiveSubtyping) LUAU_FASTFLAG(DebugLuauReadWriteProperties) @@ -253,6 +252,14 @@ bool NormalizedType::shouldSuppressErrors() const return hasErrors() || get(tops); } +bool NormalizedType::hasTopTable() const +{ + return hasTables() && std::any_of(tables.begin(), tables.end(), [&](TypeId ty) { + auto primTy = get(ty); + return primTy && primTy->type == PrimitiveType::Type::Table; + }); +} + bool NormalizedType::hasTops() const { return !get(tops); diff --git a/Analysis/src/Quantify.cpp b/Analysis/src/Quantify.cpp index f7ed7619a..0cc53d656 100644 --- a/Analysis/src/Quantify.cpp +++ b/Analysis/src/Quantify.cpp @@ -9,7 +9,6 @@ #include "Luau/VisitType.h" LUAU_FASTFLAG(DebugLuauSharedSelf) -LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); namespace Luau { diff --git a/Analysis/src/Scope.cpp b/Analysis/src/Scope.cpp index 2de381be2..bcd21d262 100644 --- a/Analysis/src/Scope.cpp +++ b/Analysis/src/Scope.cpp @@ -2,6 +2,8 @@ #include "Luau/Scope.h" +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); + namespace Luau { @@ -160,14 +162,12 @@ void Scope::inheritRefinements(const ScopePtr& childScope) dcrRefinements[k] = a; } } - else + + for (const auto& [k, a] : childScope->refinements) { - for (const auto& [k, a] : childScope->refinements) - { - Symbol symbol = getBaseSymbol(k); - if (lookup(symbol)) - refinements[k] = a; - } + Symbol symbol = getBaseSymbol(k); + if (lookup(symbol)) + refinements[k] = a; } } diff --git a/Analysis/src/Simplify.cpp b/Analysis/src/Simplify.cpp index e17df3870..20a9fa57f 100644 --- a/Analysis/src/Simplify.cpp +++ b/Analysis/src/Simplify.cpp @@ -364,7 +364,13 @@ Relation relate(TypeId left, TypeId right) if (auto ut = get(left)) return Relation::Intersects; else if (auto ut = get(right)) + { + std::vector opts; + for (TypeId part : ut) + if (relate(left, part) == Relation::Subset) + return Relation::Subset; return Relation::Intersects; + } if (auto rnt = get(right)) { diff --git a/Analysis/src/Symbol.cpp b/Analysis/src/Symbol.cpp index 5922bb50e..4b808f194 100644 --- a/Analysis/src/Symbol.cpp +++ b/Analysis/src/Symbol.cpp @@ -3,9 +3,23 @@ #include "Luau/Common.h" +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) + namespace Luau { +bool Symbol::operator==(const Symbol& rhs) const +{ + if (local) + return local == rhs.local; + else if (global.value) + return rhs.global.value && global == rhs.global.value; // Subtlety: AstName::operator==(const char*) uses strcmp, not pointer identity. + else if (FFlag::DebugLuauDeferredConstraintResolution) + return !rhs.local && !rhs.global.value; // Reflexivity: we already know `this` Symbol is empty, so check that rhs is. + else + return false; +} + std::string toString(const Symbol& name) { if (name.local) diff --git a/Analysis/src/TxnLog.cpp b/Analysis/src/TxnLog.cpp index 8a9b35684..6446570cc 100644 --- a/Analysis/src/TxnLog.cpp +++ b/Analysis/src/TxnLog.cpp @@ -9,8 +9,6 @@ #include #include -LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) - namespace Luau { diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index b77f7f159..40a4bd0fc 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -283,7 +283,8 @@ struct TypeChecker2 if (errors.empty()) noTypeFamilyErrors.insert(instance); - reportErrors(std::move(errors)); + if (!isErrorSuppressing(location, instance)) + reportErrors(std::move(errors)); return instance; } @@ -488,7 +489,7 @@ struct TypeChecker2 u.hideousFixMeGenericsAreActuallyFree = true; u.tryUnify(actualRetType, expectedRetType); - const bool ok = u.errors.empty() && u.log.empty(); + const bool ok = (u.errors.empty() && u.log.empty()) || isErrorSuppressing(ret->location, actualRetType, ret->location, expectedRetType); if (!ok) { @@ -526,9 +527,7 @@ struct TypeChecker2 TypeId valueType = value ? lookupType(value) : nullptr; if (valueType) { - ErrorVec errors = tryUnify(stack.back(), value->location, valueType, annotationType); - if (!errors.empty()) - reportErrors(std::move(errors)); + reportErrors(tryUnify(stack.back(), value->location, valueType, annotationType)); } visit(var->annotation); @@ -554,9 +553,7 @@ struct TypeChecker2 if (var->annotation) { TypeId varType = lookupAnnotation(var->annotation); - ErrorVec errors = tryUnify(stack.back(), value->location, valueTypes.head[j - i], varType); - if (!errors.empty()) - reportErrors(std::move(errors)); + reportErrors(tryUnify(stack.back(), value->location, valueTypes.head[j - i], varType)); visit(var->annotation); } @@ -764,6 +761,11 @@ struct TypeChecker2 } }; + const NormalizedType* iteratorNorm = normalizer.normalize(iteratorTy); + + if (!iteratorNorm) + reportError(NormalizationTooComplex{}, firstValue->location); + /* * If the first iterator argument is a function * * There must be 1 to 3 iterator arguments. Name them (nextTy, @@ -798,7 +800,7 @@ struct TypeChecker2 { // nothing } - else if (isOptional(iteratorTy)) + else if (isOptional(iteratorTy) && !(iteratorNorm && iteratorNorm->shouldSuppressErrors())) { reportError(OptionalValueAccess{iteratorTy}, forInStatement->values.data[0]->location); } @@ -833,7 +835,7 @@ struct TypeChecker2 { checkFunction(nextFtv, instantiatedIteratorTypes, true); } - else + else if (!isErrorSuppressing(forInStatement->values.data[0]->location, *instantiatedNextFn)) { reportError(CannotCallNonFunction{*instantiatedNextFn}, forInStatement->values.data[0]->location); } @@ -843,7 +845,7 @@ struct TypeChecker2 reportError(UnificationTooComplex{}, forInStatement->values.data[0]->location); } } - else + else if (!isErrorSuppressing(forInStatement->values.data[0]->location, *iterMmTy)) { // TODO: This will not tell the user that this is because the // metamethod isn't callable. This is not ideal, and we should @@ -859,7 +861,11 @@ struct TypeChecker2 reportError(UnificationTooComplex{}, forInStatement->values.data[0]->location); } } - else + else if (iteratorNorm && iteratorNorm->hasTopTable()) + { + // nothing + } + else if (!iteratorNorm || !iteratorNorm->shouldSuppressErrors()) { reportError(CannotCallNonFunction{iteratorTy}, forInStatement->values.data[0]->location); } @@ -882,7 +888,9 @@ struct TypeChecker2 if (get(lhsType)) continue; - if (!isSubtype(rhsType, lhsType, stack.back())) + + if (!isSubtype(rhsType, lhsType, stack.back()) && + !isErrorSuppressing(assign->vars.data[i]->location, lhsType, assign->values.data[i]->location, rhsType)) { reportError(TypeMismatch{lhsType, rhsType}, rhs->location); } @@ -1064,8 +1072,8 @@ struct TypeChecker2 void visitCall(AstExprCall* call) { TypePack args; - std::vector argLocs; - argLocs.reserve(call->args.size + 1); + std::vector argExprs; + argExprs.reserve(call->args.size + 1); TypeId* originalCallTy = module->astOriginalCallTypes.find(call); TypeId* selectedOverloadTy = module->astOverloadResolvedTypes.find(call); @@ -1088,18 +1096,18 @@ struct TypeChecker2 ice->ice("method call expression has no 'self'"); args.head.push_back(lookupType(indexExpr->expr)); - argLocs.push_back(indexExpr->expr->location); + argExprs.push_back(indexExpr->expr); } else if (findMetatableEntry(builtinTypes, module->errors, *originalCallTy, "__call", call->func->location)) { args.head.insert(args.head.begin(), lookupType(call->func)); - argLocs.push_back(call->func->location); + argExprs.push_back(call->func); } for (size_t i = 0; i < call->args.size; ++i) { AstExpr* arg = call->args.data[i]; - argLocs.push_back(arg->location); + argExprs.push_back(arg); TypeId* argTy = module->astTypes.find(arg); if (argTy) args.head.push_back(*argTy); @@ -1127,20 +1135,20 @@ struct TypeChecker2 call->location, }; - resolver.resolve(fnTy, &args, call->func->location, &argLocs); + resolver.resolve(fnTy, &args, call->func, &argExprs); + auto norm = normalizer.normalize(fnTy); + if (!norm) + reportError(NormalizationTooComplex{}, call->func->location); - if (!resolver.ok.empty()) + if (norm && norm->shouldSuppressErrors()) + return; // error suppressing function type! + else if (!resolver.ok.empty()) return; // We found a call that works, so this is ok. - else if (auto norm = normalizer.normalize(fnTy); !norm || !normalizer.isInhabited(norm)) - { - if (!norm) - reportError(NormalizationTooComplex{}, call->func->location); - else - return; // Ok. Calling an uninhabited type is no-op. - } + else if (!norm || !normalizer.isInhabited(norm)) + return; // Ok. Calling an uninhabited type is no-op. else if (!resolver.nonviableOverloads.empty()) { - if (resolver.nonviableOverloads.size() == 1) + if (resolver.nonviableOverloads.size() == 1 && !isErrorSuppressing(call->func->location, resolver.nonviableOverloads.front().first)) reportErrors(resolver.nonviableOverloads.front().second); else { @@ -1224,13 +1232,26 @@ struct TypeChecker2 InsertionOrderedMap> resolution; private: - template - std::optional tryUnify(const Location& location, Ty subTy, Ty superTy) + std::optional tryUnify(const Location& location, TypeId subTy, TypeId superTy, const LiteralProperties* literalProperties = nullptr) + { + Unifier u{normalizer, scope, location, Covariant}; + u.ctx = CountMismatch::Arg; + u.hideousFixMeGenericsAreActuallyFree = true; + u.enableNewSolver(); + u.tryUnify(subTy, superTy, /*isFunctionCall*/ false, /*isIntersection*/ false, literalProperties); + + if (u.errors.empty()) + return std::nullopt; + + return std::move(u.errors); + } + + std::optional tryUnify(const Location& location, TypePackId subTy, TypePackId superTy) { Unifier u{normalizer, scope, location, Covariant}; u.ctx = CountMismatch::Arg; u.hideousFixMeGenericsAreActuallyFree = true; - u.enableScopeTests(); + u.enableNewSolver(); u.tryUnify(subTy, superTy); if (u.errors.empty()) @@ -1240,7 +1261,7 @@ struct TypeChecker2 } std::pair checkOverload( - TypeId fnTy, const TypePack* args, Location fnLoc, const std::vector* argLocs, bool callMetamethodOk = true) + TypeId fnTy, const TypePack* args, AstExpr* fnLoc, const std::vector* argExprs, bool callMetamethodOk = true) { fnTy = follow(fnTy); @@ -1248,25 +1269,64 @@ struct TypeChecker2 if (get(fnTy) || get(fnTy) || get(fnTy)) return {Ok, {}}; else if (auto fn = get(fnTy)) - return checkOverload_(fnTy, fn, args, fnLoc, argLocs); // Intentionally split to reduce the stack pressure of this function. + return checkOverload_(fnTy, fn, args, fnLoc, argExprs); // Intentionally split to reduce the stack pressure of this function. else if (auto callMm = findMetatableEntry(builtinTypes, discard, fnTy, "__call", callLoc); callMm && callMetamethodOk) { // Calling a metamethod forwards the `fnTy` as self. TypePack withSelf = *args; withSelf.head.insert(withSelf.head.begin(), fnTy); - std::vector withSelfLocs = *argLocs; - withSelfLocs.insert(withSelfLocs.begin(), fnLoc); + std::vector withSelfExprs = *argExprs; + withSelfExprs.insert(withSelfExprs.begin(), fnLoc); - return checkOverload(*callMm, &withSelf, fnLoc, &withSelfLocs, /*callMetamethodOk=*/false); + return checkOverload(*callMm, &withSelf, fnLoc, &withSelfExprs, /*callMetamethodOk=*/false); } else return {TypeIsNotAFunction, {}}; // Intentionally empty. We can just fabricate the type error later on. } + static bool isLiteral(AstExpr* expr) + { + if (auto group = expr->as()) + return isLiteral(group->expr); + else if (auto assertion = expr->as()) + return isLiteral(assertion->expr); + + return + expr->is() || + expr->is() || + expr->is() || + expr->is() || + expr->is() || + expr->is(); + } + + static std::unique_ptr buildLiteralPropertiesSet(AstExpr* expr) + { + const AstExprTable* table = expr->as(); + if (!table) + return nullptr; + + std::unique_ptr result = std::make_unique(Name{}); + + for (const AstExprTable::Item& item : table->items) + { + if (item.kind != AstExprTable::Item::Record) + continue; + + AstExprConstantString* keyExpr = item.key->as(); + LUAU_ASSERT(keyExpr); + + if (isLiteral(item.value)) + result->insert(Name{keyExpr->value.begin(), keyExpr->value.end()}); + } + + return result; + } + LUAU_NOINLINE std::pair checkOverload_( - TypeId fnTy, const FunctionType* fn, const TypePack* args, Location fnLoc, const std::vector* argLocs) + TypeId fnTy, const FunctionType* fn, const TypePack* args, AstExpr* fnExpr, const std::vector* argExprs) { TxnLog fake; FamilyGraphReductionResult result = reduceFamilies(fnTy, callLoc, arena, builtinTypes, scope, normalizer, &fake, /*force=*/true); @@ -1286,9 +1346,11 @@ struct TypeChecker2 TypeId paramTy = *paramIter; TypeId argTy = args->head[argOffset]; - Location argLoc = argLocs->at(argOffset >= argLocs->size() ? argLocs->size() - 1 : argOffset); + AstExpr* argLoc = argExprs->at(argOffset >= argExprs->size() ? argExprs->size() - 1 : argOffset); + + std::unique_ptr literalProperties{buildLiteralPropertiesSet(argLoc)}; - if (auto errors = tryUnify(argLoc, argTy, paramTy)) + if (auto errors = tryUnify(argLoc->location, argTy, paramTy, literalProperties.get())) { // Since we're stopping right here, we need to decide if this is a nonviable overload or if there is an arity mismatch. // If it's a nonviable overload, then we need to keep going to get all type errors. @@ -1308,19 +1370,21 @@ struct TypeChecker2 // If we can iterate over the head of arguments, then we have exhausted the head of the parameters. LUAU_ASSERT(paramIter == end(fn->argTypes)); - Location argLoc = argLocs->at(argOffset >= argLocs->size() ? argLocs->size() - 1 : argOffset); + AstExpr* argExpr = argExprs->at(argOffset >= argExprs->size() ? argExprs->size() - 1 : argOffset); if (!paramIter.tail()) { auto [minParams, optMaxParams] = getParameterExtents(TxnLog::empty(), fn->argTypes); - TypeError error{argLoc, CountMismatch{minParams, optMaxParams, args->head.size(), CountMismatch::Arg, false}}; + TypeError error{argExpr->location, CountMismatch{minParams, optMaxParams, args->head.size(), CountMismatch::Arg, false}}; return {ArityMismatch, {error}}; } else if (auto vtp = get(follow(paramIter.tail()))) { - if (auto errors = tryUnify(argLoc, args->head[argOffset], vtp->ty)) + if (auto errors = tryUnify(argExpr->location, args->head[argOffset], vtp->ty)) argumentErrors.insert(argumentErrors.end(), errors->begin(), errors->end()); } + else if (get(follow(paramIter.tail()))) + argumentErrors.push_back(TypeError{argExpr->location, TypePackMismatch{fn->argTypes, arena->addTypePack(*args)}}); ++argOffset; } @@ -1333,18 +1397,18 @@ struct TypeChecker2 // It may have a tail, however, so check that. if (auto vtp = get(follow(args->tail))) { - Location argLoc = argLocs->at(argLocs->size() - 1); + AstExpr* argExpr = argExprs->at(argExprs->size() - 1); - if (auto errors = tryUnify(argLoc, vtp->ty, *paramIter)) + if (auto errors = tryUnify(argExpr->location, vtp->ty, *paramIter)) argumentErrors.insert(argumentErrors.end(), errors->begin(), errors->end()); } else if (!isOptional(*paramIter)) { - Location argLoc = argLocs->empty() ? fnLoc : argLocs->at(argLocs->size() - 1); + AstExpr* argExpr = argExprs->empty() ? fnExpr : argExprs->at(argExprs->size() - 1); // It is ok to have excess parameters as long as they are all optional. auto [minParams, optMaxParams] = getParameterExtents(TxnLog::empty(), fn->argTypes); - TypeError error{argLoc, CountMismatch{minParams, optMaxParams, args->head.size(), CountMismatch::Arg, false}}; + TypeError error{argExpr->location, CountMismatch{minParams, optMaxParams, args->head.size(), CountMismatch::Arg, false}}; return {ArityMismatch, {error}}; } @@ -1355,13 +1419,27 @@ struct TypeChecker2 LUAU_ASSERT(paramIter == end(fn->argTypes)); LUAU_ASSERT(argOffset == args->head.size()); + const Location argLoc = argExprs->empty() ? Location{} // TODO + : argExprs->at(argExprs->size() - 1)->location; + if (paramIter.tail() && args->tail) { - Location argLoc = argLocs->at(argLocs->size() - 1); - if (auto errors = tryUnify(argLoc, *args->tail, *paramIter.tail())) argumentErrors.insert(argumentErrors.end(), errors->begin(), errors->end()); } + else if (paramIter.tail()) + { + const TypePackId paramTail = follow(*paramIter.tail()); + + if (get(paramTail)) + { + argumentErrors.push_back(TypeError{argLoc, TypePackMismatch{fn->argTypes, arena->addTypePack(*args)}}); + } + else if (get(paramTail)) + { + // Nothing. This is ok. + } + } return {argumentErrors.empty() ? Ok : OverloadIsNonviable, argumentErrors}; } @@ -1409,14 +1487,14 @@ struct TypeChecker2 } public: - void resolve(TypeId fnTy, const TypePack* args, Location selfLoc, const std::vector* argLocs) + void resolve(TypeId fnTy, const TypePack* args, AstExpr* selfExpr, const std::vector* argExprs) { fnTy = follow(fnTy); auto it = get(fnTy); if (!it) { - auto [analysis, errors] = checkOverload(fnTy, args, selfLoc, argLocs); + auto [analysis, errors] = checkOverload(fnTy, args, selfExpr, argExprs); add(analysis, fnTy, std::move(errors)); return; } @@ -1426,7 +1504,7 @@ struct TypeChecker2 if (resolution.find(ty) != resolution.end()) continue; - auto [analysis, errors] = checkOverload(ty, args, selfLoc, argLocs); + auto [analysis, errors] = checkOverload(ty, args, selfExpr, argExprs); add(analysis, ty, std::move(errors)); } } @@ -1508,8 +1586,7 @@ struct TypeChecker2 return; } - // TODO! - visit(indexExpr->expr, ValueContext::LValue); + visit(indexExpr->expr, ValueContext::RValue); visit(indexExpr->index, ValueContext::RValue); NotNull scope = stack.back(); @@ -1576,7 +1653,8 @@ struct TypeChecker2 TypeId inferredArgTy = *argIt; TypeId annotatedArgTy = lookupAnnotation(arg->annotation); - if (!isSubtype(inferredArgTy, annotatedArgTy, stack.back())) + if (!isSubtype(inferredArgTy, annotatedArgTy, stack.back()) && + !isErrorSuppressing(arg->location, inferredArgTy, arg->annotation->location, annotatedArgTy)) { reportError(TypeMismatch{inferredArgTy, annotatedArgTy}, arg->location); } @@ -1610,7 +1688,7 @@ struct TypeChecker2 TypeId operandType = lookupType(expr->expr); TypeId resultType = lookupType(expr); - if (get(operandType) || get(operandType) || get(operandType)) + if (isErrorSuppressing(expr->expr->location, operandType)) return; if (auto it = kUnaryOpMetamethods.find(expr->op); it != kUnaryOpMetamethods.end()) @@ -1645,7 +1723,7 @@ struct TypeChecker2 TypeId expectedFunction = testArena.addType(FunctionType{expectedArgs, expectedRet}); ErrorVec errors = tryUnify(scope, expr->location, *mm, expectedFunction); - if (!errors.empty()) + if (!errors.empty() && !isErrorSuppressing(expr->expr->location, *firstArg, expr->expr->location, operandType)) { reportError(TypeMismatch{*firstArg, operandType}, expr->location); return; @@ -1660,7 +1738,10 @@ struct TypeChecker2 { DenseHashSet seen{nullptr}; int recursionCount = 0; + const NormalizedType* nty = normalizer.normalize(operandType); + if (nty && nty->shouldSuppressErrors()) + return; if (!hasLength(operandType, seen, &recursionCount)) { @@ -1719,6 +1800,8 @@ struct TypeChecker2 return leftType; else if (get(rightType) || get(rightType) || get(rightType)) return rightType; + else if ((normLeft && normLeft->shouldSuppressErrors()) || (normRight && normRight->shouldSuppressErrors())) + return builtinTypes->anyType; // we can't say anything better if it's error suppressing but not any or error alone. if ((get(leftType) || get(leftType) || get(leftType)) && !isEquality && !isLogical) { @@ -1933,6 +2016,9 @@ struct TypeChecker2 case AstExprBinary::Op::CompareLe: case AstExprBinary::Op::CompareLt: { + if (normLeft && normLeft->shouldSuppressErrors()) + return builtinTypes->numberType; + if (normLeft && normLeft->isExactlyNumber()) { reportErrors(tryUnify(scope, expr->right->location, rightType, builtinTypes->numberType)); @@ -2324,7 +2410,7 @@ struct TypeChecker2 TypeArena arena; Unifier u{NotNull{&normalizer}, scope, Location{}, Covariant}; u.hideousFixMeGenericsAreActuallyFree = genericsOkay; - u.enableScopeTests(); + u.enableNewSolver(); u.tryUnify(subTy, superTy); const bool ok = u.errors.empty() && u.log.empty(); @@ -2338,9 +2424,12 @@ struct TypeChecker2 Unifier u{NotNull{&normalizer}, scope, location, Covariant}; u.ctx = context; u.hideousFixMeGenericsAreActuallyFree = genericsOkay; - u.enableScopeTests(); + u.enableNewSolver(); u.tryUnify(subTy, superTy); + if (isErrorSuppressing(location, subTy, location, superTy)) + return {}; + return std::move(u.errors); } @@ -2376,6 +2465,7 @@ struct TypeChecker2 return; } + // if the type is error suppressing, we don't actually have any work left to do. if (norm->shouldSuppressErrors()) return; @@ -2542,6 +2632,50 @@ struct TypeChecker2 if (!candidates.empty()) data = TypeErrorData(UnknownPropButFoundLikeProp{utk->table, utk->key, candidates}); } + + bool isErrorSuppressing(Location loc, TypeId ty) + { + switch (shouldSuppressErrors(NotNull{&normalizer}, ty)) + { + case ErrorSuppression::DoNotSuppress: + return false; + case ErrorSuppression::Suppress: + return true; + case ErrorSuppression::NormalizationFailed: + reportError(NormalizationTooComplex{}, loc); + return false; + }; + + LUAU_ASSERT(false); + return false; // UNREACHABLE + } + + bool isErrorSuppressing(Location loc1, TypeId ty1, Location loc2, TypeId ty2) + { + return isErrorSuppressing(loc1, ty1) || isErrorSuppressing(loc2, ty2); + } + + bool isErrorSuppressing(Location loc, TypePackId tp) + { + switch (shouldSuppressErrors(NotNull{&normalizer}, tp)) + { + case ErrorSuppression::DoNotSuppress: + return false; + case ErrorSuppression::Suppress: + return true; + case ErrorSuppression::NormalizationFailed: + reportError(NormalizationTooComplex{}, loc); + return false; + }; + + LUAU_ASSERT(false); + return false; // UNREACHABLE + } + + bool isErrorSuppressing(Location loc1, TypePackId tp1, Location loc2, TypePackId tp2) + { + return isErrorSuppressing(loc1, tp1) || isErrorSuppressing(loc2, tp2); + } }; void check( diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index cfb0f21cc..a80250969 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -36,6 +36,7 @@ LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) LUAU_FASTFLAGVARIABLE(DebugLuauSharedSelf, false) LUAU_FASTFLAG(LuauInstantiateInSubtyping) LUAU_FASTFLAGVARIABLE(LuauAllowIndexClassParameters, false) +LUAU_FASTFLAGVARIABLE(LuauFixCyclicModuleExports, false) LUAU_FASTFLAG(LuauOccursIsntAlwaysFailure) LUAU_FASTFLAGVARIABLE(LuauTypecheckTypeguards, false) LUAU_FASTFLAGVARIABLE(LuauTinyControlFlowAnalysis, false) @@ -269,6 +270,8 @@ ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mo currentModule.reset(new Module); currentModule->name = module.name; currentModule->humanReadableName = module.humanReadableName; + currentModule->internalTypes.owningModule = currentModule.get(); + currentModule->interfaceTypes.owningModule = currentModule.get(); currentModule->type = module.type; currentModule->allocator = module.allocator; currentModule->names = module.names; @@ -1193,6 +1196,19 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatLocal& local) { scope->importedTypeBindings[name] = module->exportedTypeBindings; scope->importedModules[name] = moduleInfo->name; + + if (FFlag::LuauFixCyclicModuleExports) + { + // Imported types of requires that transitively refer to current module have to be replaced with 'any' + for (const auto& [location, path] : requireCycles) + { + if (!path.empty() && path.front() == moduleInfo->name) + { + for (auto& [name, tf] : scope->importedTypeBindings[name]) + tf = TypeFun{{}, {}, anyType}; + } + } + } } // In non-strict mode we force the module type on the variable, in strict mode it is already unified diff --git a/Analysis/src/TypeUtils.cpp b/Analysis/src/TypeUtils.cpp index 9124e2fc5..4f87de8f0 100644 --- a/Analysis/src/TypeUtils.cpp +++ b/Analysis/src/TypeUtils.cpp @@ -295,4 +295,58 @@ TypeId stripNil(NotNull builtinTypes, TypeArena& arena, TypeId ty) return follow(ty); } +ErrorSuppression shouldSuppressErrors(NotNull normalizer, TypeId ty) +{ + const NormalizedType* normType = normalizer->normalize(ty); + + if (!normType) + return ErrorSuppression::NormalizationFailed; + + return (normType->shouldSuppressErrors()) ? ErrorSuppression::Suppress : ErrorSuppression::DoNotSuppress; +} + +ErrorSuppression shouldSuppressErrors(NotNull normalizer, TypePackId tp) +{ + auto [tys, tail] = flatten(tp); + + // check the head, one type at a time + for (TypeId ty : tys) + { + auto result = shouldSuppressErrors(normalizer, ty); + if (result != ErrorSuppression::DoNotSuppress) + return result; + } + + // check the tail if we have one and it's finite + if (tail && finite(*tail)) + return shouldSuppressErrors(normalizer, *tail); + + return ErrorSuppression::DoNotSuppress; +} + +// This is a useful helper because it is often the case that we are looking at specifically a pair of types that might suppress. +ErrorSuppression shouldSuppressErrors(NotNull normalizer, TypeId ty1, TypeId ty2) +{ + auto result = shouldSuppressErrors(normalizer, ty1); + + // if ty1 is do not suppress, ty2 determines our overall behavior + if (result == ErrorSuppression::DoNotSuppress) + return shouldSuppressErrors(normalizer, ty2); + + // otherwise, ty1 is either suppress or normalization failure which are both the appropriate overarching result + return result; +} + +ErrorSuppression shouldSuppressErrors(NotNull normalizer, TypePackId tp1, TypePackId tp2) +{ + auto result = shouldSuppressErrors(normalizer, tp1); + + // if tp1 is do not suppress, tp2 determines our overall behavior + if (result == ErrorSuppression::DoNotSuppress) + return shouldSuppressErrors(normalizer, tp2); + + // otherwise, tp1 is either suppress or normalization failure which are both the appropriate overarching result + return result; +} + } // namespace Luau diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index e54156feb..c1b5e45e1 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -22,7 +22,6 @@ LUAU_FASTFLAGVARIABLE(LuauInstantiateInSubtyping, false) LUAU_FASTFLAGVARIABLE(LuauMaintainScopesInUnifier, false) LUAU_FASTFLAGVARIABLE(LuauTransitiveSubtyping, false) LUAU_FASTFLAGVARIABLE(LuauOccursIsntAlwaysFailure, false) -LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) LUAU_FASTFLAG(LuauNormalizeBlockedTypes) LUAU_FASTFLAG(LuauAlwaysCommitInferencesOfFunctionCalls) @@ -50,7 +49,7 @@ struct PromoteTypeLevels final : TypeOnceVisitor template void promote(TID ty, T* t) { - if (FFlag::DebugLuauDeferredConstraintResolution && !t) + if (useScopes && !t) return; LUAU_ASSERT(t); @@ -369,7 +368,6 @@ static std::optional> getTableMatchT return std::nullopt; } -// TODO: Inline and clip with FFlag::DebugLuauDeferredConstraintResolution template static bool subsumes(bool useScopes, TY_A* left, TY_B* right) { @@ -406,11 +404,11 @@ Unifier::Unifier(NotNull normalizer, NotNull scope, const Loc LUAU_ASSERT(sharedState.iceHandler); } -void Unifier::tryUnify(TypeId subTy, TypeId superTy, bool isFunctionCall, bool isIntersection) +void Unifier::tryUnify(TypeId subTy, TypeId superTy, bool isFunctionCall, bool isIntersection, const LiteralProperties* literalProperties) { sharedState.counters.iterationCount = 0; - tryUnify_(subTy, superTy, isFunctionCall, isIntersection); + tryUnify_(subTy, superTy, isFunctionCall, isIntersection, literalProperties); } static bool isBlocked(const TxnLog& log, TypeId ty) @@ -425,7 +423,7 @@ static bool isBlocked(const TxnLog& log, TypePackId tp) return get(tp); } -void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool isIntersection) +void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool isIntersection, const LiteralProperties* literalProperties) { RecursionLimiter _ra(&sharedState.counters.recursionCount, sharedState.counters.recursionLimit); @@ -443,6 +441,16 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool if (superTy == subTy) return; + if (isBlocked(log, subTy) && isBlocked(log, superTy)) + { + blockedTypes.push_back(subTy); + blockedTypes.push_back(superTy); + } + else if (isBlocked(log, subTy)) + blockedTypes.push_back(subTy); + else if (isBlocked(log, superTy)) + blockedTypes.push_back(superTy); + if (log.get(superTy)) { // We do not report errors from reducing here. This is because we will @@ -470,7 +478,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool auto superFree = log.getMutable(superTy); auto subFree = log.getMutable(subTy); - if (superFree && subFree && subsumes(useScopes, superFree, subFree)) + if (superFree && subFree && subsumes(useNewSolver, superFree, subFree)) { if (!occursCheck(subTy, superTy, /* reversed = */ false)) log.replace(subTy, BoundType(superTy)); @@ -481,7 +489,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool { if (!occursCheck(superTy, subTy, /* reversed = */ true)) { - if (subsumes(useScopes, superFree, subFree)) + if (subsumes(useNewSolver, superFree, subFree)) { log.changeLevel(subTy, superFree->level); } @@ -495,7 +503,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool { // Unification can't change the level of a generic. auto subGeneric = log.getMutable(subTy); - if (subGeneric && !subsumes(useScopes, subGeneric, superFree)) + if (subGeneric && !subsumes(useNewSolver, subGeneric, superFree)) { // TODO: a more informative error message? CLI-39912 reportError(location, GenericError{"Generic subtype escaping scope"}); @@ -504,7 +512,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool if (!occursCheck(superTy, subTy, /* reversed = */ true)) { - promoteTypeLevels(log, types, superFree->level, superFree->scope, useScopes, subTy); + promoteTypeLevels(log, types, superFree->level, superFree->scope, useNewSolver, subTy); Widen widen{types, builtinTypes}; log.replace(superTy, BoundType(widen(subTy))); @@ -521,7 +529,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool // Unification can't change the level of a generic. auto superGeneric = log.getMutable(superTy); - if (superGeneric && !subsumes(useScopes, superGeneric, subFree)) + if (superGeneric && !subsumes(useNewSolver, superGeneric, subFree)) { // TODO: a more informative error message? CLI-39912 reportError(location, GenericError{"Generic supertype escaping scope"}); @@ -530,7 +538,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool if (!occursCheck(subTy, superTy, /* reversed = */ false)) { - promoteTypeLevels(log, types, subFree->level, subFree->scope, useScopes, superTy); + promoteTypeLevels(log, types, subFree->level, subFree->scope, useNewSolver, superTy); log.replace(subTy, BoundType(superTy)); } @@ -542,7 +550,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool auto superGeneric = log.getMutable(superTy); auto subGeneric = log.getMutable(subTy); - if (superGeneric && subGeneric && subsumes(useScopes, superGeneric, subGeneric)) + if (superGeneric && subGeneric && subsumes(useNewSolver, superGeneric, subGeneric)) { if (!occursCheck(subTy, superTy, /* reversed = */ false)) log.replace(subTy, BoundType(superTy)); @@ -637,16 +645,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool size_t errorCount = errors.size(); - if (isBlocked(log, subTy) && isBlocked(log, superTy)) - { - blockedTypes.push_back(subTy); - blockedTypes.push_back(superTy); - } - else if (isBlocked(log, subTy)) - blockedTypes.push_back(subTy); - else if (isBlocked(log, superTy)) - blockedTypes.push_back(superTy); - else if (const UnionType* subUnion = log.getMutable(subTy)) + if (const UnionType* subUnion = log.getMutable(subTy)) { tryUnifyUnionWithType(subTy, subUnion, superTy); } @@ -717,7 +716,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool else if (log.getMutable(superTy) && log.getMutable(subTy)) { - tryUnifyTables(subTy, superTy, isIntersection); + tryUnifyTables(subTy, superTy, isIntersection, literalProperties); } else if (log.get(superTy) && (log.get(subTy) || log.get(subTy))) { @@ -772,7 +771,7 @@ void Unifier::tryUnifyUnionWithType(TypeId subTy, const UnionType* subUnion, Typ Unifier innerState = makeChildUnifier(); innerState.tryUnify_(type, superTy); - if (FFlag::DebugLuauDeferredConstraintResolution) + if (useNewSolver) logs.push_back(std::move(innerState.log)); if (auto e = hasUnificationTooComplex(innerState.errors)) @@ -955,7 +954,7 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp if (FFlag::LuauTransitiveSubtyping ? !innerState.failure : innerState.errors.empty()) { found = true; - if (FFlag::DebugLuauDeferredConstraintResolution) + if (useNewSolver) logs.push_back(std::move(innerState.log)); else { @@ -980,7 +979,7 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp } } - if (FFlag::DebugLuauDeferredConstraintResolution) + if (useNewSolver) log.concatAsUnion(combineLogsIntoUnion(std::move(logs)), NotNull{types}); if (unificationTooComplex) @@ -1061,14 +1060,14 @@ void Unifier::tryUnifyTypeWithIntersection(TypeId subTy, TypeId superTy, const I firstFailedOption = {innerState.errors.front()}; } - if (FFlag::DebugLuauDeferredConstraintResolution) + if (useNewSolver) logs.push_back(std::move(innerState.log)); else log.concat(std::move(innerState.log)); failure |= innerState.failure; } - if (FFlag::DebugLuauDeferredConstraintResolution) + if (useNewSolver) log.concat(combineLogsIntoIntersection(std::move(logs))); if (unificationTooComplex) @@ -1118,7 +1117,7 @@ void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionType* } } - if (FFlag::DebugLuauDeferredConstraintResolution && normalize) + if (useNewSolver && normalize) { // We cannot normalize a type that contains blocked types. We have to // stop for now if we find any. @@ -1161,7 +1160,7 @@ void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionType* { found = true; errorsSuppressed = innerState.failure; - if (FFlag::DebugLuauDeferredConstraintResolution || (FFlag::LuauTransitiveSubtyping && innerState.failure)) + if (useNewSolver || (FFlag::LuauTransitiveSubtyping && innerState.failure)) logs.push_back(std::move(innerState.log)); else { @@ -1176,7 +1175,7 @@ void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionType* } } - if (FFlag::DebugLuauDeferredConstraintResolution) + if (useNewSolver) log.concat(combineLogsIntoIntersection(std::move(logs))); else if (FFlag::LuauTransitiveSubtyping && errorsSuppressed) log.concat(std::move(logs.front())); @@ -1296,7 +1295,7 @@ void Unifier::tryUnifyNormalizedTypes( } } - if (FFlag::DebugLuauDeferredConstraintResolution) + if (useNewSolver) { for (TypeId superTable : superNorm.tables) { @@ -1527,6 +1526,15 @@ struct WeirdIter return pack != nullptr && index < pack->head.size(); } + std::optional tail() const + { + if (!pack) + return packId; + + LUAU_ASSERT(index == pack->head.size()); + return pack->tail; + } + bool advance() { if (!pack) @@ -1588,9 +1596,9 @@ struct WeirdIter } }; -void Unifier::enableScopeTests() +void Unifier::enableNewSolver() { - useScopes = true; + useNewSolver = true; log.useScopes = true; } @@ -1664,25 +1672,26 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal blockedTypePacks.push_back(superTp); } else if (isBlocked(log, subTp)) - { blockedTypePacks.push_back(subTp); - } else if (isBlocked(log, superTp)) - { blockedTypePacks.push_back(superTp); - } - if (log.getMutable(superTp)) + + if (auto superFree = log.getMutable(superTp)) { if (!occursCheck(superTp, subTp, /* reversed = */ true)) { Widen widen{types, builtinTypes}; + if (useNewSolver) + promoteTypeLevels(log, types, superFree->level, superFree->scope, /*useScopes*/ true, subTp); log.replace(superTp, Unifiable::Bound(widen(subTp))); } } - else if (log.getMutable(subTp)) + else if (auto subFree = log.getMutable(subTp)) { if (!occursCheck(subTp, superTp, /* reversed = */ false)) { + if (useNewSolver) + promoteTypeLevels(log, types, subFree->level, subFree->scope, /*useScopes*/ true, superTp); log.replace(subTp, Unifiable::Bound(superTp)); } } @@ -1771,28 +1780,74 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal // If both are at the end, we're done if (!superIter.good() && !subIter.good()) { - const bool lFreeTail = superTpv->tail && log.getMutable(log.follow(*superTpv->tail)) != nullptr; - const bool rFreeTail = subTpv->tail && log.getMutable(log.follow(*subTpv->tail)) != nullptr; - if (lFreeTail && rFreeTail) - { - tryUnify_(*subTpv->tail, *superTpv->tail); - } - else if (lFreeTail) + if (useNewSolver) { - tryUnify_(emptyTp, *superTpv->tail); - } - else if (rFreeTail) - { - tryUnify_(emptyTp, *subTpv->tail); + if (subIter.tail() && superIter.tail()) + tryUnify_(*subIter.tail(), *superIter.tail()); + else if (subIter.tail()) + { + const TypePackId subTail = log.follow(*subIter.tail()); + + if (log.get(subTail)) + tryUnify_(subTail, emptyTp); + else if (log.get(subTail)) + reportError(location, TypePackMismatch{subTail, emptyTp}); + else if (log.get(subTail) || log.get(subTail)) + { + // Nothing. This is ok. + } + else + { + ice("Unexpected subtype tail pack " + toString(subTail), location); + } + } + else if (superIter.tail()) + { + const TypePackId superTail = log.follow(*superIter.tail()); + + if (log.get(superTail)) + tryUnify_(emptyTp, superTail); + else if (log.get(superTail)) + reportError(location, TypePackMismatch{emptyTp, superTail}); + else if (log.get(superTail) || log.get(superTail)) + { + // Nothing. This is ok. + } + else + { + ice("Unexpected supertype tail pack " + toString(superTail), location); + } + } + else + { + // Nothing. This is ok. + } } - else if (subTpv->tail && superTpv->tail) + else { - if (log.getMutable(superIter.packId)) - tryUnifyVariadics(subIter.packId, superIter.packId, false, int(subIter.index)); - else if (log.getMutable(subIter.packId)) - tryUnifyVariadics(superIter.packId, subIter.packId, true, int(superIter.index)); - else + const bool lFreeTail = superTpv->tail && log.getMutable(log.follow(*superTpv->tail)) != nullptr; + const bool rFreeTail = subTpv->tail && log.getMutable(log.follow(*subTpv->tail)) != nullptr; + if (lFreeTail && rFreeTail) + { tryUnify_(*subTpv->tail, *superTpv->tail); + } + else if (lFreeTail) + { + tryUnify_(emptyTp, *superTpv->tail); + } + else if (rFreeTail) + { + tryUnify_(emptyTp, *subTpv->tail); + } + else if (subTpv->tail && superTpv->tail) + { + if (log.getMutable(superIter.packId)) + tryUnifyVariadics(subIter.packId, superIter.packId, false, int(subIter.index)); + else if (log.getMutable(subIter.packId)) + tryUnifyVariadics(superIter.packId, subIter.packId, true, int(superIter.index)); + else + tryUnify_(*subTpv->tail, *superTpv->tail); + } } break; @@ -2049,7 +2104,7 @@ struct Resetter } // namespace -void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) +void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection, const LiteralProperties* literalProperties) { if (isPrim(log.follow(subTy), PrimitiveType::Table)) subTy = builtinTypes->emptyTableType; @@ -2134,7 +2189,8 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) { // TODO: read-only properties don't need invariance Resetter resetter{&variance}; - variance = Invariant; + if (!literalProperties || !literalProperties->contains(name)) + variance = Invariant; Unifier innerState = makeChildUnifier(); innerState.tryUnify_(r->second.type(), prop.type()); @@ -2150,7 +2206,8 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) // TODO: read-only indexers don't need invariance // TODO: really we should only allow this if prop.type is optional. Resetter resetter{&variance}; - variance = Invariant; + if (!literalProperties || !literalProperties->contains(name)) + variance = Invariant; Unifier innerState = makeChildUnifier(); innerState.tryUnify_(subTable->indexer->indexResultType, prop.type()); @@ -2213,10 +2270,17 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) // TODO: read-only indexers don't need invariance // TODO: really we should only allow this if prop.type is optional. Resetter resetter{&variance}; - variance = Invariant; + if (!literalProperties || !literalProperties->contains(name)) + variance = Invariant; Unifier innerState = makeChildUnifier(); - innerState.tryUnify_(superTable->indexer->indexResultType, prop.type()); + if (useNewSolver) + innerState.tryUnify_(prop.type(), superTable->indexer->indexResultType); + else + { + // Incredibly, the old solver depends on this bug somehow. + innerState.tryUnify_(superTable->indexer->indexResultType, prop.type()); + } checkChildUnifierTypeMismatch(innerState.errors, name, superTy, subTy); @@ -2478,7 +2542,7 @@ void Unifier::tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed) { case TableState::Free: { - if (FFlag::DebugLuauDeferredConstraintResolution) + if (useNewSolver) { Unifier innerState = makeChildUnifier(); bool missingProperty = false; @@ -2843,8 +2907,8 @@ std::optional Unifier::findTablePropertyRespectingMeta(TypeId lhsType, N TxnLog Unifier::combineLogsIntoIntersection(std::vector logs) { - LUAU_ASSERT(FFlag::DebugLuauDeferredConstraintResolution); - TxnLog result(useScopes); + LUAU_ASSERT(useNewSolver); + TxnLog result(useNewSolver); for (TxnLog& log : logs) result.concatAsIntersections(std::move(log), NotNull{types}); return result; @@ -2852,7 +2916,7 @@ TxnLog Unifier::combineLogsIntoIntersection(std::vector logs) TxnLog Unifier::combineLogsIntoUnion(std::vector logs) { - TxnLog result(useScopes); + TxnLog result(useNewSolver); for (TxnLog& log : logs) result.concatAsUnion(std::move(log), NotNull{types}); return result; @@ -3012,8 +3076,8 @@ Unifier Unifier::makeChildUnifier() u.normalize = normalize; u.checkInhabited = checkInhabited; - if (useScopes) - u.enableScopeTests(); + if (useNewSolver) + u.enableNewSolver(); return u; } diff --git a/CLI/Repl.cpp b/CLI/Repl.cpp index df1b4edaf..b2a523d99 100644 --- a/CLI/Repl.cpp +++ b/CLI/Repl.cpp @@ -274,9 +274,6 @@ std::string runCode(lua_State* L, const std::string& source) return error; } - if (codegen) - Luau::CodeGen::compile(L, -1); - lua_State* T = lua_newthread(L); lua_pushvalue(L, -2); diff --git a/CodeGen/include/Luau/CodeGen.h b/CodeGen/include/Luau/CodeGen.h index febd021cd..002b4a994 100644 --- a/CodeGen/include/Luau/CodeGen.h +++ b/CodeGen/include/Luau/CodeGen.h @@ -12,12 +12,18 @@ namespace Luau namespace CodeGen { +enum CodeGenFlags +{ + // Only run native codegen for modules that have been marked with --!native + CodeGen_OnlyNativeModules = 1 << 0, +}; + bool isSupported(); void create(lua_State* L); // Builds target function and all inner functions -void compile(lua_State* L, int idx); +void compile(lua_State* L, int idx, unsigned int flags = 0); using AnnotatorFn = void (*)(void* context, std::string& result, int fid, int instpos); diff --git a/CodeGen/include/Luau/IrData.h b/CodeGen/include/Luau/IrData.h index 0b38743ac..b950a8eca 100644 --- a/CodeGen/include/Luau/IrData.h +++ b/CodeGen/include/Luau/IrData.h @@ -77,6 +77,12 @@ enum class IrCmd : uint8_t // B: unsigned int (hash) GET_HASH_NODE_ADDR, + // Get pointer (TValue) to Closure upvalue. + // A: pointer or undef (Closure) + // B: UPn + // When undef is specified, uses current function Closure. + GET_CLOSURE_UPVAL_ADDR, + // Store a tag into TValue // A: Rn // B: tag @@ -542,10 +548,10 @@ enum class IrCmd : uint8_t FALLBACK_GETVARARGS, // Create closure from a child proto - // A: unsigned int (bytecode instruction index) - // B: Rn (dest) + // A: unsigned int (nups) + // B: pointer (table) // C: unsigned int (protoid) - FALLBACK_NEWCLOSURE, + NEWCLOSURE, // Create closure from a pre-created function object (reusing it unless environments diverge) // A: unsigned int (bytecode instruction index) @@ -600,6 +606,10 @@ enum class IrCmd : uint8_t // Returns the string name of a type either from a __type metatable field or just based on the tag, alternative for typeof(x) // A: Rn GET_TYPEOF, + + // Find or create an upval at the given level + // A: Rn (level) + FINDUPVAL, }; enum class IrConstKind : uint8_t diff --git a/CodeGen/include/Luau/IrUtils.h b/CodeGen/include/Luau/IrUtils.h index 6481342f5..9a9b84b40 100644 --- a/CodeGen/include/Luau/IrUtils.h +++ b/CodeGen/include/Luau/IrUtils.h @@ -150,6 +150,7 @@ inline bool hasResult(IrCmd cmd) case IrCmd::GET_ARR_ADDR: case IrCmd::GET_SLOT_NODE_ADDR: case IrCmd::GET_HASH_NODE_ADDR: + case IrCmd::GET_CLOSURE_UPVAL_ADDR: case IrCmd::ADD_INT: case IrCmd::SUB_INT: case IrCmd::ADD_NUM: @@ -192,6 +193,8 @@ inline bool hasResult(IrCmd cmd) case IrCmd::INVOKE_LIBM: case IrCmd::GET_TYPE: case IrCmd::GET_TYPEOF: + case IrCmd::NEWCLOSURE: + case IrCmd::FINDUPVAL: return true; default: break; diff --git a/CodeGen/include/luacodegen.h b/CodeGen/include/luacodegen.h index 654fc2c90..9de15e63c 100644 --- a/CodeGen/include/luacodegen.h +++ b/CodeGen/include/luacodegen.h @@ -6,7 +6,7 @@ #define LUACODEGEN_API extern #endif -struct lua_State; +typedef struct lua_State lua_State; // returns 1 if Luau code generator is supported, 0 otherwise LUACODEGEN_API int luau_codegen_supported(void); diff --git a/CodeGen/src/CodeGen.cpp b/CodeGen/src/CodeGen.cpp index cdb761c6a..602130f1b 100644 --- a/CodeGen/src/CodeGen.cpp +++ b/CodeGen/src/CodeGen.cpp @@ -239,7 +239,7 @@ void create(lua_State* L) ecb->enter = onEnter; } -void compile(lua_State* L, int idx) +void compile(lua_State* L, int idx, unsigned int flags) { LUAU_ASSERT(lua_isLfunction(L, idx)); const TValue* func = luaA_toobject(L, idx); @@ -249,6 +249,13 @@ void compile(lua_State* L, int idx) if (!data) return; + Proto* root = clvalue(func)->l.p; + if ((flags & CodeGen_OnlyNativeModules) != 0 && (root->flags & LPF_NATIVE_MODULE) == 0) + return; + + std::vector protos; + gatherFunctions(protos, root); + #if defined(__aarch64__) static unsigned int cpuFeatures = getCpuFeaturesA64(); A64::AssemblyBuilderA64 build(/* logText= */ false, cpuFeatures); @@ -256,9 +263,6 @@ void compile(lua_State* L, int idx) X64::AssemblyBuilderX64 build(/* logText= */ false); #endif - std::vector protos; - gatherFunctions(protos, clvalue(func)->l.p); - ModuleHelpers helpers; #if defined(__aarch64__) A64::assembleHelpers(build, helpers); diff --git a/CodeGen/src/IrAnalysis.cpp b/CodeGen/src/IrAnalysis.cpp index cf7161ef0..62c0b8ab2 100644 --- a/CodeGen/src/IrAnalysis.cpp +++ b/CodeGen/src/IrAnalysis.cpp @@ -427,9 +427,6 @@ static RegisterSet computeBlockLiveInRegSet(IrFunction& function, const IrBlock& case IrCmd::FALLBACK_GETVARARGS: defRange(vmRegOp(inst.b), function.intOp(inst.c)); break; - case IrCmd::FALLBACK_NEWCLOSURE: - def(inst.b); - break; case IrCmd::FALLBACK_DUPCLOSURE: def(inst.b); break; @@ -448,6 +445,10 @@ static RegisterSet computeBlockLiveInRegSet(IrFunction& function, const IrBlock& use(inst.a); break; + case IrCmd::FINDUPVAL: + use(inst.a); + break; + default: // All instructions which reference registers have to be handled explicitly LUAU_ASSERT(inst.a.kind != IrOpKind::VmReg); diff --git a/CodeGen/src/IrBuilder.cpp b/CodeGen/src/IrBuilder.cpp index 04318effb..52d0a0b56 100644 --- a/CodeGen/src/IrBuilder.cpp +++ b/CodeGen/src/IrBuilder.cpp @@ -452,7 +452,7 @@ void IrBuilder::translateInst(LuauOpcode op, const Instruction* pc, int i) inst(IrCmd::FALLBACK_GETVARARGS, constUint(i), vmReg(LUAU_INSN_A(*pc)), constInt(LUAU_INSN_B(*pc) - 1)); break; case LOP_NEWCLOSURE: - inst(IrCmd::FALLBACK_NEWCLOSURE, constUint(i), vmReg(LUAU_INSN_A(*pc)), constUint(LUAU_INSN_D(*pc))); + translateInstNewClosure(*this, pc, i); break; case LOP_DUPCLOSURE: inst(IrCmd::FALLBACK_DUPCLOSURE, constUint(i), vmReg(LUAU_INSN_A(*pc)), vmConst(LUAU_INSN_D(*pc))); diff --git a/CodeGen/src/IrDump.cpp b/CodeGen/src/IrDump.cpp index c44cd8eb8..ce0cbfb3f 100644 --- a/CodeGen/src/IrDump.cpp +++ b/CodeGen/src/IrDump.cpp @@ -61,6 +61,12 @@ static const char* getTagName(uint8_t tag) return "tuserdata"; case LUA_TTHREAD: return "tthread"; + case LUA_TPROTO: + return "tproto"; + case LUA_TUPVAL: + return "tupval"; + case LUA_TDEADKEY: + return "tdeadkey"; default: LUAU_ASSERT(!"Unknown type tag"); LUAU_UNREACHABLE(); @@ -93,6 +99,8 @@ const char* getCmdName(IrCmd cmd) return "GET_SLOT_NODE_ADDR"; case IrCmd::GET_HASH_NODE_ADDR: return "GET_HASH_NODE_ADDR"; + case IrCmd::GET_CLOSURE_UPVAL_ADDR: + return "GET_CLOSURE_UPVAL_ADDR"; case IrCmd::STORE_TAG: return "STORE_TAG"; case IrCmd::STORE_POINTER: @@ -267,8 +275,8 @@ const char* getCmdName(IrCmd cmd) return "FALLBACK_PREPVARARGS"; case IrCmd::FALLBACK_GETVARARGS: return "FALLBACK_GETVARARGS"; - case IrCmd::FALLBACK_NEWCLOSURE: - return "FALLBACK_NEWCLOSURE"; + case IrCmd::NEWCLOSURE: + return "NEWCLOSURE"; case IrCmd::FALLBACK_DUPCLOSURE: return "FALLBACK_DUPCLOSURE"; case IrCmd::FALLBACK_FORGPREP: @@ -303,6 +311,8 @@ const char* getCmdName(IrCmd cmd) return "GET_TYPE"; case IrCmd::GET_TYPEOF: return "GET_TYPEOF"; + case IrCmd::FINDUPVAL: + return "FINDUPVAL"; } LUAU_UNREACHABLE(); diff --git a/CodeGen/src/IrLoweringA64.cpp b/CodeGen/src/IrLoweringA64.cpp index 92cb49adb..3c247abdc 100644 --- a/CodeGen/src/IrLoweringA64.cpp +++ b/CodeGen/src/IrLoweringA64.cpp @@ -314,6 +314,14 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.add(inst.regA64, inst.regA64, zextReg(temp2), kLuaNodeSizeLog2); break; } + case IrCmd::GET_CLOSURE_UPVAL_ADDR: + { + inst.regA64 = regs.allocReuse(KindA64::x, index, {inst.a}); + RegisterA64 cl = inst.a.kind == IrOpKind::Undef ? rClosure : regOp(inst.a); + + build.add(inst.regA64, cl, uint16_t(offsetof(Closure, l.uprefs) + sizeof(TValue) * vmUpvalueOp(inst.b))); + break; + } case IrCmd::STORE_TAG: { RegisterA64 temp = regs.allocTemp(KindA64::w); @@ -1044,13 +1052,20 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) bool continueInVm = (inst.d.kind == IrOpKind::Constant && intOp(inst.d)); Label fresh; // used when guard aborts execution or jumps to a VM exit Label& fail = continueInVm ? helpers.exitContinueVmClearNativeFlag : getTargetLabel(inst.c, fresh); + + // To support DebugLuauAbortingChecks, CHECK_TAG with VmReg has to be handled + RegisterA64 tag = inst.a.kind == IrOpKind::VmReg ? regs.allocTemp(KindA64::w) : regOp(inst.a); + + if (inst.a.kind == IrOpKind::VmReg) + build.ldr(tag, mem(rBase, vmRegOp(inst.a) * sizeof(TValue) + offsetof(TValue, tt))); + if (tagOp(inst.b) == 0) { - build.cbnz(regOp(inst.a), fail); + build.cbnz(tag, fail); } else { - build.cmp(regOp(inst.a), tagOp(inst.b)); + build.cmp(tag, tagOp(inst.b)); build.b(ConditionA64::NotEqual, fail); } if (!continueInVm) @@ -1517,13 +1532,26 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) regs.spill(build, index); emitFallback(build, offsetof(NativeContext, executeGETVARARGS), uintOp(inst.a)); break; - case IrCmd::FALLBACK_NEWCLOSURE: - LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); - LUAU_ASSERT(inst.c.kind == IrOpKind::Constant); + case IrCmd::NEWCLOSURE: + { + RegisterA64 reg = regOp(inst.b); // note: we need to call regOp before spill so that we don't do redundant reloads - regs.spill(build, index); - emitFallback(build, offsetof(NativeContext, executeNEWCLOSURE), uintOp(inst.a)); + regs.spill(build, index, {reg}); + build.mov(x2, reg); + + build.mov(x0, rState); + build.mov(w1, uintOp(inst.a)); + + build.ldr(x3, mem(rClosure, offsetof(Closure, l.p))); + build.ldr(x3, mem(x3, offsetof(Proto, p))); + build.ldr(x3, mem(x3, sizeof(Proto*) * uintOp(inst.c))); + + build.ldr(x4, mem(rNativeContext, offsetof(NativeContext, luaF_newLclosure))); + build.blr(x4); + + inst.regA64 = regs.takeReg(x0, index); break; + } case IrCmd::FALLBACK_DUPCLOSURE: LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); LUAU_ASSERT(inst.c.kind == IrOpKind::VmConst); @@ -1743,6 +1771,18 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) break; } + case IrCmd::FINDUPVAL: + { + regs.spill(build, index); + build.mov(x0, rState); + build.add(x1, rBase, uint16_t(vmRegOp(inst.a) * sizeof(TValue))); + build.ldr(x2, mem(rNativeContext, offsetof(NativeContext, luaF_findupval))); + build.blr(x2); + + inst.regA64 = regs.takeReg(x0, index); + break; + } + // To handle unsupported instructions, add "case IrCmd::OP" and make sure to set error = true! } diff --git a/CodeGen/src/IrLoweringX64.cpp b/CodeGen/src/IrLoweringX64.cpp index 670d60666..e791e55d5 100644 --- a/CodeGen/src/IrLoweringX64.cpp +++ b/CodeGen/src/IrLoweringX64.cpp @@ -186,14 +186,40 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.add(inst.regX64, tmp.reg); break; }; + case IrCmd::GET_CLOSURE_UPVAL_ADDR: + { + inst.regX64 = regs.allocRegOrReuse(SizeX64::qword, index, {inst.a}); + + if (inst.a.kind == IrOpKind::Undef) + { + build.mov(inst.regX64, sClosure); + } + else + { + RegisterX64 cl = regOp(inst.a); + if (inst.regX64 != cl) + build.mov(inst.regX64, cl); + } + + build.add(inst.regX64, offsetof(Closure, l.uprefs) + sizeof(TValue) * vmUpvalueOp(inst.b)); + break; + } case IrCmd::STORE_TAG: if (inst.b.kind == IrOpKind::Constant) - build.mov(luauRegTag(vmRegOp(inst.a)), tagOp(inst.b)); + { + if (inst.a.kind == IrOpKind::Inst) + build.mov(dword[regOp(inst.a) + offsetof(TValue, tt)], tagOp(inst.b)); + else + build.mov(luauRegTag(vmRegOp(inst.a)), tagOp(inst.b)); + } else LUAU_ASSERT(!"Unsupported instruction form"); break; case IrCmd::STORE_POINTER: - build.mov(luauRegValue(vmRegOp(inst.a)), regOp(inst.b)); + if (inst.a.kind == IrOpKind::Inst) + build.mov(qword[regOp(inst.a) + offsetof(TValue, value)], regOp(inst.b)); + else + build.mov(luauRegValue(vmRegOp(inst.a)), regOp(inst.b)); break; case IrCmd::STORE_DOUBLE: if (inst.b.kind == IrOpKind::Constant) @@ -1207,12 +1233,25 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) emitFallback(regs, build, offsetof(NativeContext, executeGETVARARGS), uintOp(inst.a)); break; - case IrCmd::FALLBACK_NEWCLOSURE: - LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); - LUAU_ASSERT(inst.c.kind == IrOpKind::Constant); + case IrCmd::NEWCLOSURE: + { + ScopedRegX64 tmp2{regs, SizeX64::qword}; + build.mov(tmp2.reg, sClosure); + build.mov(tmp2.reg, qword[tmp2.reg + offsetof(Closure, l.p)]); + build.mov(tmp2.reg, qword[tmp2.reg + offsetof(Proto, p)]); + build.mov(tmp2.reg, qword[tmp2.reg + sizeof(Proto*) * uintOp(inst.c)]); + + IrCallWrapperX64 callWrap(regs, build, index); + callWrap.addArgument(SizeX64::qword, rState); + callWrap.addArgument(SizeX64::dword, uintOp(inst.a), inst.a); + callWrap.addArgument(SizeX64::qword, regOp(inst.b), inst.b); + callWrap.addArgument(SizeX64::qword, tmp2); - emitFallback(regs, build, offsetof(NativeContext, executeNEWCLOSURE), uintOp(inst.a)); + callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaF_newLclosure)]); + + inst.regX64 = regs.takeReg(rax, index); break; + } case IrCmd::FALLBACK_DUPCLOSURE: LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); LUAU_ASSERT(inst.c.kind == IrOpKind::VmConst); @@ -1412,6 +1451,17 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) break; } + case IrCmd::FINDUPVAL: + { + IrCallWrapperX64 callWrap(regs, build); + callWrap.addArgument(SizeX64::qword, rState); + callWrap.addArgument(SizeX64::qword, luauRegAddress(vmRegOp(inst.a))); + callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaF_findupval)]); + + inst.regX64 = regs.takeReg(rax, index); + break; + } + // Pseudo instructions case IrCmd::NOP: case IrCmd::SUBSTITUTE: diff --git a/CodeGen/src/IrTranslation.cpp b/CodeGen/src/IrTranslation.cpp index 5cde510ff..63e756e16 100644 --- a/CodeGen/src/IrTranslation.cpp +++ b/CodeGen/src/IrTranslation.cpp @@ -1213,5 +1213,61 @@ void translateInstOrX(IrBuilder& build, const Instruction* pc, int pcpos, IrOp c } } +void translateInstNewClosure(IrBuilder& build, const Instruction* pc, int pcpos) +{ + LUAU_ASSERT(unsigned(LUAU_INSN_D(*pc)) < unsigned(build.function.proto->sizep)); + + int ra = LUAU_INSN_A(*pc); + Proto* pv = build.function.proto->p[LUAU_INSN_D(*pc)]; + + build.inst(IrCmd::SET_SAVEDPC, build.constUint(pcpos + 1)); + + IrOp env = build.inst(IrCmd::LOAD_ENV); + IrOp ncl = build.inst(IrCmd::NEWCLOSURE, build.constUint(pv->nups), env, build.constUint(LUAU_INSN_D(*pc))); + + build.inst(IrCmd::STORE_POINTER, build.vmReg(ra), ncl); + build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TFUNCTION)); + + for (int ui = 0; ui < pv->nups; ++ui) + { + Instruction uinsn = pc[ui + 1]; + LUAU_ASSERT(LUAU_INSN_OP(uinsn) == LOP_CAPTURE); + + IrOp dst = build.inst(IrCmd::GET_CLOSURE_UPVAL_ADDR, ncl, build.vmUpvalue(ui)); + + switch (LUAU_INSN_A(uinsn)) + { + case LCT_VAL: + { + IrOp src = build.inst(IrCmd::LOAD_TVALUE, build.vmReg(LUAU_INSN_B(uinsn))); + build.inst(IrCmd::STORE_TVALUE, dst, src); + break; + } + + case LCT_REF: + { + IrOp src = build.inst(IrCmd::FINDUPVAL, build.vmReg(LUAU_INSN_B(uinsn))); + build.inst(IrCmd::STORE_POINTER, dst, src); + build.inst(IrCmd::STORE_TAG, dst, build.constTag(LUA_TUPVAL)); + break; + } + + case LCT_UPVAL: + { + IrOp src = build.inst(IrCmd::GET_CLOSURE_UPVAL_ADDR, build.undef(), build.vmUpvalue(LUAU_INSN_B(uinsn))); + IrOp load = build.inst(IrCmd::LOAD_TVALUE, src); + build.inst(IrCmd::STORE_TVALUE, dst, load); + break; + } + + default: + LUAU_ASSERT(!"Unknown upvalue capture type"); + LUAU_UNREACHABLE(); // improves switch() codegen by eliding opcode bounds checks + } + } + + build.inst(IrCmd::CHECK_GC); +} + } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/IrTranslation.h b/CodeGen/src/IrTranslation.h index 0c24b27da..aff18b308 100644 --- a/CodeGen/src/IrTranslation.h +++ b/CodeGen/src/IrTranslation.h @@ -65,6 +65,7 @@ void translateInstCapture(IrBuilder& build, const Instruction* pc, int pcpos); void translateInstNamecall(IrBuilder& build, const Instruction* pc, int pcpos); void translateInstAndX(IrBuilder& build, const Instruction* pc, int pcpos, IrOp c); void translateInstOrX(IrBuilder& build, const Instruction* pc, int pcpos, IrOp c); +void translateInstNewClosure(IrBuilder& build, const Instruction* pc, int pcpos); inline int getOpLength(LuauOpcode op) { diff --git a/CodeGen/src/IrUtils.cpp b/CodeGen/src/IrUtils.cpp index 2395fb1ec..310c15b89 100644 --- a/CodeGen/src/IrUtils.cpp +++ b/CodeGen/src/IrUtils.cpp @@ -38,6 +38,7 @@ IrValueKind getCmdValueKind(IrCmd cmd) case IrCmd::GET_ARR_ADDR: case IrCmd::GET_SLOT_NODE_ADDR: case IrCmd::GET_HASH_NODE_ADDR: + case IrCmd::GET_CLOSURE_UPVAL_ADDR: return IrValueKind::Pointer; case IrCmd::STORE_TAG: case IrCmd::STORE_POINTER: @@ -141,7 +142,9 @@ IrValueKind getCmdValueKind(IrCmd cmd) case IrCmd::FALLBACK_NAMECALL: case IrCmd::FALLBACK_PREPVARARGS: case IrCmd::FALLBACK_GETVARARGS: - case IrCmd::FALLBACK_NEWCLOSURE: + return IrValueKind::None; + case IrCmd::NEWCLOSURE: + return IrValueKind::Pointer; case IrCmd::FALLBACK_DUPCLOSURE: case IrCmd::FALLBACK_FORGPREP: return IrValueKind::None; @@ -164,6 +167,8 @@ IrValueKind getCmdValueKind(IrCmd cmd) case IrCmd::GET_TYPE: case IrCmd::GET_TYPEOF: return IrValueKind::Pointer; + case IrCmd::FINDUPVAL: + return IrValueKind::Pointer; } LUAU_UNREACHABLE(); diff --git a/CodeGen/src/IrValueLocationTracking.cpp b/CodeGen/src/IrValueLocationTracking.cpp index e94a43476..0ed7c3883 100644 --- a/CodeGen/src/IrValueLocationTracking.cpp +++ b/CodeGen/src/IrValueLocationTracking.cpp @@ -77,7 +77,6 @@ void IrValueLocationTracking::beforeInstLowering(IrInst& inst) case IrCmd::FALLBACK_GETVARARGS: invalidateRestoreVmRegs(vmRegOp(inst.b), function.intOp(inst.c)); break; - case IrCmd::FALLBACK_NEWCLOSURE: case IrCmd::FALLBACK_DUPCLOSURE: invalidateRestoreOp(inst.b); break; @@ -109,6 +108,8 @@ void IrValueLocationTracking::beforeInstLowering(IrInst& inst) case IrCmd::FALLBACK_PREPVARARGS: case IrCmd::ADJUST_STACK_TO_TOP: case IrCmd::GET_TYPEOF: + case IrCmd::NEWCLOSURE: + case IrCmd::FINDUPVAL: break; // These instrucitons read VmReg only after optimizeMemoryOperandsX64 diff --git a/CodeGen/src/NativeState.cpp b/CodeGen/src/NativeState.cpp index 14c1acd99..65984562e 100644 --- a/CodeGen/src/NativeState.cpp +++ b/CodeGen/src/NativeState.cpp @@ -56,6 +56,8 @@ void initFunctions(NativeState& data) data.context.luaC_step = luaC_step; data.context.luaF_close = luaF_close; + data.context.luaF_findupval = luaF_findupval; + data.context.luaF_newLclosure = luaF_newLclosure; data.context.luaT_gettm = luaT_gettm; data.context.luaT_objtypenamestr = luaT_objtypenamestr; diff --git a/CodeGen/src/NativeState.h b/CodeGen/src/NativeState.h index a2393bbfe..1a039812f 100644 --- a/CodeGen/src/NativeState.h +++ b/CodeGen/src/NativeState.h @@ -52,6 +52,8 @@ struct NativeContext size_t (*luaC_step)(lua_State* L, bool assist) = nullptr; void (*luaF_close)(lua_State* L, StkId level) = nullptr; + UpVal* (*luaF_findupval)(lua_State* L, StkId level) = nullptr; + Closure* (*luaF_newLclosure)(lua_State* L, int nelems, Table* e, Proto* p) = nullptr; const TValue* (*luaT_gettm)(Table* events, TMS event, TString* ename) = nullptr; const TString* (*luaT_objtypenamestr)(lua_State* L, const TValue* o) = nullptr; diff --git a/CodeGen/src/OptimizeConstProp.cpp b/CodeGen/src/OptimizeConstProp.cpp index 72869ad13..758518a2c 100644 --- a/CodeGen/src/OptimizeConstProp.cpp +++ b/CodeGen/src/OptimizeConstProp.cpp @@ -415,6 +415,8 @@ static void handleBuiltinEffects(ConstPropState& state, LuauBuiltinFunction bfid case LBF_RAWLEN: case LBF_BIT32_EXTRACTK: case LBF_GETMETATABLE: + case LBF_TONUMBER: + case LBF_TOSTRING: break; case LBF_SETMETATABLE: state.invalidateHeap(); // TODO: only knownNoMetatable is affected and we might know which one @@ -760,6 +762,7 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& case IrCmd::GET_ARR_ADDR: case IrCmd::GET_SLOT_NODE_ADDR: case IrCmd::GET_HASH_NODE_ADDR: + case IrCmd::GET_CLOSURE_UPVAL_ADDR: break; case IrCmd::ADD_INT: case IrCmd::SUB_INT: @@ -823,6 +826,7 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& case IrCmd::BITXOR_UINT: case IrCmd::BITOR_UINT: case IrCmd::BITNOT_UINT: + break; case IrCmd::BITLSHIFT_UINT: case IrCmd::BITRSHIFT_UINT: case IrCmd::BITARSHIFT_UINT: @@ -833,6 +837,7 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& case IrCmd::INVOKE_LIBM: case IrCmd::GET_TYPE: case IrCmd::GET_TYPEOF: + case IrCmd::FINDUPVAL: break; case IrCmd::JUMP_CMP_ANY: @@ -923,8 +928,7 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& case IrCmd::FALLBACK_GETVARARGS: state.invalidateRegisterRange(vmRegOp(inst.b), function.intOp(inst.c)); break; - case IrCmd::FALLBACK_NEWCLOSURE: - state.invalidate(inst.b); + case IrCmd::NEWCLOSURE: break; case IrCmd::FALLBACK_DUPCLOSURE: state.invalidate(inst.b); diff --git a/Common/include/Luau/Bytecode.h b/Common/include/Luau/Bytecode.h index 7b3a057b7..976dd04f2 100644 --- a/Common/include/Luau/Bytecode.h +++ b/Common/include/Luau/Bytecode.h @@ -44,6 +44,7 @@ // Version 1: Baseline version for the open-source release. Supported until 0.521. // Version 2: Adds Proto::linedefined. Supported until 0.544. // Version 3: Adds FORGPREP/JUMPXEQK* and enhances AUX encoding for FORGLOOP. Removes FORGLOOP_NEXT/INEXT and JUMPIFEQK/JUMPIFNOTEQK. Currently supported. +// Version 4: Adds Proto::flags and typeinfo. Currently supported. // Bytecode opcode, part of the instruction header enum LuauOpcode @@ -543,6 +544,10 @@ enum LuauBuiltinFunction // get/setmetatable LBF_GETMETATABLE, LBF_SETMETATABLE, + + // tonumber/tostring + LBF_TONUMBER, + LBF_TOSTRING, }; // Capture type, used in LOP_CAPTURE @@ -552,3 +557,10 @@ enum LuauCaptureType LCT_REF, LCT_UPVAL, }; + +// Proto flag bitmask, stored in Proto::flags +enum LuauProtoFlag +{ + // used to tag main proto for modules with --!native + LPF_NATIVE_MODULE = 1 << 0, +}; diff --git a/Compiler/src/Builtins.cpp b/Compiler/src/Builtins.cpp index 80fe0b6dd..4ec083bbd 100644 --- a/Compiler/src/Builtins.cpp +++ b/Compiler/src/Builtins.cpp @@ -4,6 +4,9 @@ #include "Luau/Bytecode.h" #include "Luau/Compiler.h" +LUAU_FASTFLAGVARIABLE(LuauCompileBuiltinTonumber, false) +LUAU_FASTFLAGVARIABLE(LuauCompileBuiltinTostring, false) + namespace Luau { namespace Compile @@ -69,6 +72,11 @@ static int getBuiltinFunctionId(const Builtin& builtin, const CompileOptions& op if (builtin.isGlobal("setmetatable")) return LBF_SETMETATABLE; + if (FFlag::LuauCompileBuiltinTonumber && builtin.isGlobal("tonumber")) + return LBF_TONUMBER; + if (FFlag::LuauCompileBuiltinTostring && builtin.isGlobal("tostring")) + return LBF_TOSTRING; + if (builtin.object == "math") { if (builtin.method == "abs") @@ -257,10 +265,10 @@ BuiltinInfo getBuiltinInfo(int bfid) case LBF_MATH_ABS: case LBF_MATH_ACOS: case LBF_MATH_ASIN: - return {1, 1}; + return {1, 1, BuiltinInfo::Flag_NoneSafe}; case LBF_MATH_ATAN2: - return {2, 1}; + return {2, 1, BuiltinInfo::Flag_NoneSafe}; case LBF_MATH_ATAN: case LBF_MATH_CEIL: @@ -269,19 +277,19 @@ BuiltinInfo getBuiltinInfo(int bfid) case LBF_MATH_DEG: case LBF_MATH_EXP: case LBF_MATH_FLOOR: - return {1, 1}; + return {1, 1, BuiltinInfo::Flag_NoneSafe}; case LBF_MATH_FMOD: - return {2, 1}; + return {2, 1, BuiltinInfo::Flag_NoneSafe}; case LBF_MATH_FREXP: - return {1, 2}; + return {1, 2, BuiltinInfo::Flag_NoneSafe}; case LBF_MATH_LDEXP: - return {2, 1}; + return {2, 1, BuiltinInfo::Flag_NoneSafe}; case LBF_MATH_LOG10: - return {1, 1}; + return {1, 1, BuiltinInfo::Flag_NoneSafe}; case LBF_MATH_LOG: return {-1, 1}; // 1 or 2 parameters @@ -291,10 +299,10 @@ BuiltinInfo getBuiltinInfo(int bfid) return {-1, 1}; // variadic case LBF_MATH_MODF: - return {1, 2}; + return {1, 2, BuiltinInfo::Flag_NoneSafe}; case LBF_MATH_POW: - return {2, 1}; + return {2, 1, BuiltinInfo::Flag_NoneSafe}; case LBF_MATH_RAD: case LBF_MATH_SINH: @@ -302,16 +310,16 @@ BuiltinInfo getBuiltinInfo(int bfid) case LBF_MATH_SQRT: case LBF_MATH_TANH: case LBF_MATH_TAN: - return {1, 1}; + return {1, 1, BuiltinInfo::Flag_NoneSafe}; case LBF_BIT32_ARSHIFT: - return {2, 1}; + return {2, 1, BuiltinInfo::Flag_NoneSafe}; case LBF_BIT32_BAND: return {-1, 1}; // variadic case LBF_BIT32_BNOT: - return {1, 1}; + return {1, 1, BuiltinInfo::Flag_NoneSafe}; case LBF_BIT32_BOR: case LBF_BIT32_BXOR: @@ -323,14 +331,14 @@ BuiltinInfo getBuiltinInfo(int bfid) case LBF_BIT32_LROTATE: case LBF_BIT32_LSHIFT: - return {2, 1}; + return {2, 1, BuiltinInfo::Flag_NoneSafe}; case LBF_BIT32_REPLACE: return {-1, 1}; // 3 or 4 parameters case LBF_BIT32_RROTATE: case LBF_BIT32_RSHIFT: - return {2, 1}; + return {2, 1, BuiltinInfo::Flag_NoneSafe}; case LBF_TYPE: return {1, 1}; @@ -342,7 +350,7 @@ BuiltinInfo getBuiltinInfo(int bfid) return {-1, 1}; // variadic case LBF_STRING_LEN: - return {1, 1}; + return {1, 1, BuiltinInfo::Flag_NoneSafe}; case LBF_TYPEOF: return {1, 1}; @@ -351,11 +359,11 @@ BuiltinInfo getBuiltinInfo(int bfid) return {-1, 1}; // 2 or 3 parameters case LBF_MATH_CLAMP: - return {3, 1}; + return {3, 1, BuiltinInfo::Flag_NoneSafe}; case LBF_MATH_SIGN: case LBF_MATH_ROUND: - return {1, 1}; + return {1, 1, BuiltinInfo::Flag_NoneSafe}; case LBF_RAWSET: return {3, 1}; @@ -375,22 +383,28 @@ BuiltinInfo getBuiltinInfo(int bfid) case LBF_BIT32_COUNTLZ: case LBF_BIT32_COUNTRZ: - return {1, 1}; + return {1, 1, BuiltinInfo::Flag_NoneSafe}; case LBF_SELECT_VARARG: return {-1, -1}; // variadic case LBF_RAWLEN: - return {1, 1}; + return {1, 1, BuiltinInfo::Flag_NoneSafe}; case LBF_BIT32_EXTRACTK: - return {3, 1}; + return {3, 1, BuiltinInfo::Flag_NoneSafe}; case LBF_GETMETATABLE: return {1, 1}; case LBF_SETMETATABLE: return {2, 1}; + + case LBF_TONUMBER: + return {-1, 1}; // 1 or 2 parameters + + case LBF_TOSTRING: + return {1, 1}; }; LUAU_UNREACHABLE(); diff --git a/Compiler/src/Builtins.h b/Compiler/src/Builtins.h index e179218aa..2d7832484 100644 --- a/Compiler/src/Builtins.h +++ b/Compiler/src/Builtins.h @@ -41,8 +41,17 @@ void analyzeBuiltins(DenseHashMap& result, const DenseHashMap struct BuiltinInfo { + enum Flags + { + // none-safe builtins are builtins that have the same behavior for arguments that are nil or none + // this allows the compiler to compile calls to builtins more efficiently in certain cases + // for example, math.abs(x()) may compile x() as if it returns one value; if it returns no values, abs() will get nil instead of none + Flag_NoneSafe = 1 << 0, + }; + int params; int results; + unsigned int flags; }; BuiltinInfo getBuiltinInfo(int bfid); diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index fe65f67a1..23deec9b3 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -27,6 +27,9 @@ LUAU_FASTINTVARIABLE(LuauCompileInlineThresholdMaxBoost, 300) LUAU_FASTINTVARIABLE(LuauCompileInlineDepth, 5) LUAU_FASTFLAGVARIABLE(LuauCompileFunctionType, false) +LUAU_FASTFLAGVARIABLE(LuauCompileNativeComment, false) + +LUAU_FASTFLAGVARIABLE(LuauCompileFixBuiltinArity, false) namespace Luau { @@ -187,7 +190,7 @@ struct Compiler return node->as(); } - uint32_t compileFunction(AstExprFunction* func) + uint32_t compileFunction(AstExprFunction* func, uint8_t protoflags) { LUAU_TIMETRACE_SCOPE("Compiler::compileFunction", "Compiler"); @@ -262,7 +265,7 @@ struct Compiler if (bytecode.getInstructionCount() > kMaxInstructionCount) CompileError::raise(func->location, "Exceeded function instruction limit; split the function into parts to compile"); - bytecode.endFunction(uint8_t(stackSize), uint8_t(upvals.size())); + bytecode.endFunction(uint8_t(stackSize), uint8_t(upvals.size()), protoflags); Function& f = functions[func]; f.id = fid; @@ -792,8 +795,19 @@ struct Compiler { if (!isExprMultRet(expr->args.data[expr->args.size - 1])) return compileExprFastcallN(expr, target, targetCount, targetTop, multRet, regs, bfid); - else if (options.optimizationLevel >= 2 && int(expr->args.size) == getBuiltinInfo(bfid).params) - return compileExprFastcallN(expr, target, targetCount, targetTop, multRet, regs, bfid); + else if (options.optimizationLevel >= 2) + { + if (FFlag::LuauCompileFixBuiltinArity) + { + // when a builtin is none-safe with matching arity, even if the last expression returns 0 or >1 arguments, + // we can rely on the behavior of the function being the same (none-safe means nil and none are interchangeable) + BuiltinInfo info = getBuiltinInfo(bfid); + if (int(expr->args.size) == info.params && (info.flags & BuiltinInfo::Flag_NoneSafe) != 0) + return compileExprFastcallN(expr, target, targetCount, targetTop, multRet, regs, bfid); + } + else if (int(expr->args.size) == getBuiltinInfo(bfid).params) + return compileExprFastcallN(expr, target, targetCount, targetTop, multRet, regs, bfid); + } } if (expr->self) @@ -3147,7 +3161,7 @@ struct Compiler } } - // compute expressions with side effects for lulz + // compute expressions with side effects for (size_t i = stat->vars.size; i < stat->values.size; ++i) { RegScope rsi(this); @@ -3834,11 +3848,20 @@ void compileOrThrow(BytecodeBuilder& bytecode, const ParseResult& parseResult, c LUAU_ASSERT(parseResult.errors.empty()); CompileOptions options = inputOptions; + uint8_t mainFlags = 0; for (const HotComment& hc : parseResult.hotcomments) + { if (hc.header && hc.content.compare(0, 9, "optimize ") == 0) options.optimizationLevel = std::max(0, std::min(2, atoi(hc.content.c_str() + 9))); + if (FFlag::LuauCompileNativeComment && hc.header && hc.content == "native") + { + mainFlags |= LPF_NATIVE_MODULE; + options.optimizationLevel = 2; // note: this might be removed in the future in favor of --!optimize + } + } + AstStatBlock* root = parseResult.root; Compiler compiler(bytecode, options); @@ -3884,12 +3907,12 @@ void compileOrThrow(BytecodeBuilder& bytecode, const ParseResult& parseResult, c root->visit(&functionVisitor); for (AstExprFunction* expr : functions) - compiler.compileFunction(expr); + compiler.compileFunction(expr, 0); AstExprFunction main(root->location, /*generics= */ AstArray(), /*genericPacks= */ AstArray(), /* self= */ nullptr, AstArray(), /* vararg= */ true, /* varargLocation= */ Luau::Location(), root, /* functionDepth= */ 0, /* debugname= */ AstName()); - uint32_t mainid = compiler.compileFunction(&main); + uint32_t mainid = compiler.compileFunction(&main, mainFlags); const Compiler::Function* mainf = compiler.functions.find(&main); LUAU_ASSERT(mainf && mainf->upvals.empty()); diff --git a/Compiler/src/CostModel.cpp b/Compiler/src/CostModel.cpp index ffc1cb1f6..2f7af6ea1 100644 --- a/Compiler/src/CostModel.cpp +++ b/Compiler/src/CostModel.cpp @@ -6,6 +6,8 @@ #include +LUAU_FASTFLAGVARIABLE(LuauAssignmentHasCost, false) + namespace Luau { namespace Compile @@ -302,12 +304,14 @@ struct CostVisitor : AstVisitor return false; } - bool visit(AstStat* node) override + bool visit(AstStatIf* node) override { - if (node->is()) + // unconditional 'else' may require a jump after the 'if' body + // note: this ignores cases when 'then' always terminates and also assumes comparison requires an extra instruction which may be false + if (!FFlag::LuauAssignmentHasCost) result += 2; - else if (node->is() || node->is()) - result += 1; + else + result += 1 + (node->elsebody && !node->elsebody->is()); return true; } @@ -333,7 +337,21 @@ struct CostVisitor : AstVisitor for (size_t i = 0; i < node->vars.size; ++i) assign(node->vars.data[i]); - return true; + if (!FFlag::LuauAssignmentHasCost) + return true; + + for (size_t i = 0; i < node->vars.size || i < node->values.size; ++i) + { + Cost ac; + if (i < node->vars.size) + ac += model(node->vars.data[i]); + if (i < node->values.size) + ac += model(node->values.data[i]); + // local->local or constant->local assignment is not free + result += ac.model == 0 ? Cost(1) : ac; + } + + return false; } bool visit(AstStatCompoundAssign* node) override @@ -345,6 +363,20 @@ struct CostVisitor : AstVisitor return true; } + + bool visit(AstStatBreak* node) override + { + result += 1; + + return false; + } + + bool visit(AstStatContinue* node) override + { + result += 1; + + return false; + } }; uint64_t modelCost(AstNode* root, AstLocal* const* vars, size_t varCount, const DenseHashMap& builtins) diff --git a/Sources.cmake b/Sources.cmake index 2a58f061d..c1230f30a 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -186,6 +186,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/include/Luau/TypeArena.h Analysis/include/Luau/TypeAttach.h Analysis/include/Luau/TypeChecker2.h + Analysis/include/Luau/TypeCheckLimits.h Analysis/include/Luau/TypedAllocator.h Analysis/include/Luau/TypeFamily.h Analysis/include/Luau/TypeInfer.h diff --git a/VM/include/luaconf.h b/VM/include/luaconf.h index 2045768a3..dcb785b60 100644 --- a/VM/include/luaconf.h +++ b/VM/include/luaconf.h @@ -123,7 +123,7 @@ // enables callbacks to redirect code execution from Luau VM to a custom implementation #ifndef LUA_CUSTOM_EXECUTION -#define LUA_CUSTOM_EXECUTION 0 +#define LUA_CUSTOM_EXECUTION 1 #endif // }================================================================== diff --git a/VM/src/laux.cpp b/VM/src/laux.cpp index b4490fff3..0b9787a04 100644 --- a/VM/src/laux.cpp +++ b/VM/src/laux.cpp @@ -336,7 +336,7 @@ const char* luaL_findtable(lua_State* L, int idx, const char* fname, int szhint) const char* luaL_typename(lua_State* L, int idx) { const TValue* obj = luaA_toobject(L, idx); - return luaT_objtypename(L, obj); + return obj ? luaT_objtypename(L, obj) : "no value"; } /* diff --git a/VM/src/lbuiltins.cpp b/VM/src/lbuiltins.cpp index e0dc8a38f..c893d6037 100644 --- a/VM/src/lbuiltins.cpp +++ b/VM/src/lbuiltins.cpp @@ -23,6 +23,8 @@ #endif #endif +LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauFastcallGC, false) + // luauF functions implement FASTCALL instruction that performs a direct execution of some builtin functions from the VM // The rule of thumb is that FASTCALL functions can not call user code, yield, fail, or reallocate stack. // If types of the arguments mismatch, luauF_* needs to return -1 and the execution will fall back to the usual call path @@ -830,6 +832,8 @@ static int luauF_char(lua_State* L, StkId res, TValue* arg0, int nresults, StkId if (nparams < int(sizeof(buffer)) && nresults <= 1) { + if (DFFlag::LuauFastcallGC && luaC_needsGC(L)) + return -1; // we can't call luaC_checkGC so fall back to C implementation if (nparams >= 1) { @@ -900,6 +904,9 @@ static int luauF_sub(lua_State* L, StkId res, TValue* arg0, int nresults, StkId int i = int(nvalue(args)); int j = int(nvalue(args + 1)); + if (DFFlag::LuauFastcallGC && luaC_needsGC(L)) + return -1; // we can't call luaC_checkGC so fall back to C implementation + if (i >= 1 && j >= i && unsigned(j - 1) < unsigned(ts->len)) { setsvalue(L, res, luaS_newlstr(L, getstr(ts) + (i - 1), j - i + 1)); @@ -1247,6 +1254,73 @@ static int luauF_setmetatable(lua_State* L, StkId res, TValue* arg0, int nresult return -1; } +static int luauF_tonumber(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (nparams == 1 && nresults <= 1) + { + double num; + + if (ttisnumber(arg0)) + { + setnvalue(res, nvalue(arg0)); + return 1; + } + else if (ttisstring(arg0) && luaO_str2d(svalue(arg0), &num)) + { + setnvalue(res, num); + return 1; + } + else + { + setnilvalue(res); + return 1; + } + } + + return -1; +} + +static int luauF_tostring(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (nparams >= 1 && nresults <= 1) + { + switch (ttype(arg0)) + { + case LUA_TNIL: + { + TString* s = L->global->ttname[LUA_TNIL]; + setsvalue(L, res, s); + return 1; + } + case LUA_TBOOLEAN: + { + TString* s = bvalue(arg0) ? luaS_newliteral(L, "true") : luaS_newliteral(L, "false"); + setsvalue(L, res, s); + return 1; + } + case LUA_TNUMBER: + { + if (DFFlag::LuauFastcallGC && luaC_needsGC(L)) + return -1; // we can't call luaC_checkGC so fall back to C implementation + + char s[LUAI_MAXNUM2STR]; + char* e = luai_num2str(s, nvalue(arg0)); + setsvalue(L, res, luaS_newlstr(L, s, e - s)); + return 1; + } + case LUA_TSTRING: + { + setsvalue(L, res, tsvalue(arg0)); + return 1; + } + } + + // fall back to generic C implementation + } + + return -1; +} + static int luauF_missing(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) { return -1; @@ -1411,6 +1485,9 @@ const luau_FastFunction luauF_table[256] = { luauF_getmetatable, luauF_setmetatable, + luauF_tonumber, + luauF_tostring, + // When adding builtins, add them above this line; what follows is 64 "dummy" entries with luauF_missing fallback. // This is important so that older versions of the runtime that don't support newer builtins automatically fall back via luauF_missing. // Given the builtin addition velocity this should always provide a larger compatibility window than bytecode versions suggest. diff --git a/VM/src/lfunc.cpp b/VM/src/lfunc.cpp index 88a3e40ab..276598fe5 100644 --- a/VM/src/lfunc.cpp +++ b/VM/src/lfunc.cpp @@ -36,6 +36,7 @@ Proto* luaF_newproto(lua_State* L) f->execdata = NULL; f->exectarget = 0; f->typeinfo = NULL; + f->userdata = NULL; return f; } diff --git a/VM/src/lgc.h b/VM/src/lgc.h index ec7a6828f..1ebb01d28 100644 --- a/VM/src/lgc.h +++ b/VM/src/lgc.h @@ -73,10 +73,12 @@ #define luaC_white(g) cast_to(uint8_t, ((g)->currentwhite) & WHITEBITS) +#define luaC_needsGC(L) (L->global->totalbytes >= L->global->GCthreshold) + #define luaC_checkGC(L) \ { \ condhardstacktests(luaD_reallocstack(L, L->stacksize - EXTRA_STACK)); \ - if (L->global->totalbytes >= L->global->GCthreshold) \ + if (luaC_needsGC(L)) \ { \ condhardmemtests(luaC_validate(L), 1); \ luaC_step(L, true); \ diff --git a/VM/src/lmathlib.cpp b/VM/src/lmathlib.cpp index 2d4e3277a..fe7b1a128 100644 --- a/VM/src/lmathlib.cpp +++ b/VM/src/lmathlib.cpp @@ -7,6 +7,8 @@ #include #include +LUAU_FASTFLAGVARIABLE(LuauFasterNoise, false) + #undef PI #define PI (3.14159265358979323846) #define RADIANS_PER_DEGREE (PI / 180.0) @@ -275,6 +277,7 @@ static int math_randomseed(lua_State* L) return 0; } +// TODO: Delete with LuauFasterNoise static const unsigned char kPerlin[512] = {151, 160, 137, 91, 90, 15, 131, 13, 201, 95, 96, 53, 194, 233, 7, 225, 140, 36, 103, 30, 69, 142, 8, 99, 37, 240, 21, 10, 23, 190, 6, 148, 247, 120, 234, 75, 0, 26, 197, 62, 94, 252, 219, 203, 117, 35, 11, 32, 57, 177, 33, 88, 237, 149, 56, 87, 174, 20, 125, 136, 171, 168, 68, 175, 74, 165, 71, 134, 139, 48, 27, 166, 77, 146, 158, 231, 83, 111, 229, 122, 60, 211, 133, 230, 220, 105, 92, 41, @@ -295,18 +298,32 @@ static const unsigned char kPerlin[512] = {151, 160, 137, 91, 90, 15, 131, 13, 2 106, 157, 184, 84, 204, 176, 115, 121, 50, 45, 127, 4, 150, 254, 138, 236, 205, 93, 222, 114, 67, 29, 24, 72, 243, 141, 128, 195, 78, 66, 215, 61, 156, 180}; -static float fade(float t) +static const unsigned char kPerlinHash[257] = {151, 160, 137, 91, 90, 15, 131, 13, 201, 95, 96, 53, 194, 233, 7, 225, 140, 36, 103, 30, 69, 142, 8, 99, + 37, 240, 21, 10, 23, 190, 6, 148, 247, 120, 234, 75, 0, 26, 197, 62, 94, 252, 219, 203, 117, 35, 11, 32, 57, 177, 33, 88, 237, 149, 56, 87, 174, + 20, 125, 136, 171, 168, 68, 175, 74, 165, 71, 134, 139, 48, 27, 166, 77, 146, 158, 231, 83, 111, 229, 122, 60, 211, 133, 230, 220, 105, 92, 41, + 55, 46, 245, 40, 244, 102, 143, 54, 65, 25, 63, 161, 1, 216, 80, 73, 209, 76, 132, 187, 208, 89, 18, 169, 200, 196, 135, 130, 116, 188, 159, 86, + 164, 100, 109, 198, 173, 186, 3, 64, 52, 217, 226, 250, 124, 123, 5, 202, 38, 147, 118, 126, 255, 82, 85, 212, 207, 206, 59, 227, 47, 16, 58, 17, + 182, 189, 28, 42, 223, 183, 170, 213, 119, 248, 152, 2, 44, 154, 163, 70, 221, 153, 101, 155, 167, 43, 172, 9, 129, 22, 39, 253, 19, 98, 108, 110, + 79, 113, 224, 232, 178, 185, 112, 104, 218, 246, 97, 228, 251, 34, 242, 193, 238, 210, 144, 12, 191, 179, 162, 241, 81, 51, 145, 235, 249, 14, + 239, 107, 49, 192, 214, 31, 181, 199, 106, 157, 184, 84, 204, 176, 115, 121, 50, 45, 127, 4, 150, 254, 138, 236, 205, 93, 222, 114, 67, 29, 24, + 72, 243, 141, 128, 195, 78, 66, 215, 61, 156, 180, 151}; + +const float kPerlinGrad[16][3] = {{1, 1, 0}, {-1, 1, 0}, {1, -1, 0}, {-1, -1, 0}, {1, 0, 1}, {-1, 0, 1}, {1, 0, -1}, {-1, 0, -1}, {0, 1, 1}, + {0, -1, 1}, {0, 1, -1}, {0, -1, -1}, {1, 1, 0}, {0, -1, 1}, {-1, 1, 0}, {0, -1, -1}}; + +static float perlin_fade(float t) { return t * t * t * (t * (t * 6 - 15) + 10); } -static float math_lerp(float t, float a, float b) +static float perlin_lerp(float t, float a, float b) { return a + t * (b - a); } static float grad(unsigned char hash, float x, float y, float z) { + LUAU_ASSERT(!FFlag::LuauFasterNoise); unsigned char h = hash & 15; float u = (h < 8) ? x : y; float v = (h < 4) ? y : (h == 12 || h == 14) ? x : z; @@ -314,8 +331,15 @@ static float grad(unsigned char hash, float x, float y, float z) return (h & 1 ? -u : u) + (h & 2 ? -v : v); } -static float perlin(float x, float y, float z) +static float perlin_grad(int hash, float x, float y, float z) { + const float* g = kPerlinGrad[hash & 15]; + return g[0] * x + g[1] * y + g[2] * z; +} + +static float perlin_dep(float x, float y, float z) +{ + LUAU_ASSERT(!FFlag::LuauFasterNoise); float xflr = floorf(x); float yflr = floorf(y); float zflr = floorf(z); @@ -328,9 +352,9 @@ static float perlin(float x, float y, float z) float yf = y - yflr; float zf = z - zflr; - float u = fade(xf); - float v = fade(yf); - float w = fade(zf); + float u = perlin_fade(xf); + float v = perlin_fade(yf); + float w = perlin_fade(zf); const unsigned char* p = kPerlin; @@ -342,24 +366,79 @@ static float perlin(float x, float y, float z) int ba = p[b] + zi; int bb = p[b + 1] + zi; - return math_lerp(w, - math_lerp(v, math_lerp(u, grad(p[aa], xf, yf, zf), grad(p[ba], xf - 1, yf, zf)), - math_lerp(u, grad(p[ab], xf, yf - 1, zf), grad(p[bb], xf - 1, yf - 1, zf))), - math_lerp(v, math_lerp(u, grad(p[aa + 1], xf, yf, zf - 1), grad(p[ba + 1], xf - 1, yf, zf - 1)), - math_lerp(u, grad(p[ab + 1], xf, yf - 1, zf - 1), grad(p[bb + 1], xf - 1, yf - 1, zf - 1)))); + return perlin_lerp(w, + perlin_lerp(v, perlin_lerp(u, grad(p[aa], xf, yf, zf), grad(p[ba], xf - 1, yf, zf)), + perlin_lerp(u, grad(p[ab], xf, yf - 1, zf), grad(p[bb], xf - 1, yf - 1, zf))), + perlin_lerp(v, perlin_lerp(u, grad(p[aa + 1], xf, yf, zf - 1), grad(p[ba + 1], xf - 1, yf, zf - 1)), + perlin_lerp(u, grad(p[ab + 1], xf, yf - 1, zf - 1), grad(p[bb + 1], xf - 1, yf - 1, zf - 1)))); +} + +static float perlin(float x, float y, float z) +{ + LUAU_ASSERT(FFlag::LuauFasterNoise); + float xflr = floorf(x); + float yflr = floorf(y); + float zflr = floorf(z); + + int xi = int(xflr) & 255; + int yi = int(yflr) & 255; + int zi = int(zflr) & 255; + + float xf = x - xflr; + float yf = y - yflr; + float zf = z - zflr; + + float u = perlin_fade(xf); + float v = perlin_fade(yf); + float w = perlin_fade(zf); + + const unsigned char* p = kPerlinHash; + + int a = (p[xi] + yi) & 255; + int aa = (p[a] + zi) & 255; + int ab = (p[a + 1] + zi) & 255; + + int b = (p[xi + 1] + yi) & 255; + int ba = (p[b] + zi) & 255; + int bb = (p[b + 1] + zi) & 255; + + float la = perlin_lerp(u, perlin_grad(p[aa], xf, yf, zf), perlin_grad(p[ba], xf - 1, yf, zf)); + float lb = perlin_lerp(u, perlin_grad(p[ab], xf, yf - 1, zf), perlin_grad(p[bb], xf - 1, yf - 1, zf)); + float la1 = perlin_lerp(u, perlin_grad(p[aa + 1], xf, yf, zf - 1), perlin_grad(p[ba + 1], xf - 1, yf, zf - 1)); + float lb1 = perlin_lerp(u, perlin_grad(p[ab + 1], xf, yf - 1, zf - 1), perlin_grad(p[bb + 1], xf - 1, yf - 1, zf - 1)); + + return perlin_lerp(w, perlin_lerp(v, la, lb), perlin_lerp(v, la1, lb1)); } static int math_noise(lua_State* L) { - double x = luaL_checknumber(L, 1); - double y = luaL_optnumber(L, 2, 0.0); - double z = luaL_optnumber(L, 3, 0.0); + if (FFlag::LuauFasterNoise) + { + int nx, ny, nz; + double x = lua_tonumberx(L, 1, &nx); + double y = lua_tonumberx(L, 2, &ny); + double z = lua_tonumberx(L, 3, &nz); - double r = perlin((float)x, (float)y, (float)z); + luaL_argexpected(L, nx, 1, "number"); + luaL_argexpected(L, ny || lua_isnoneornil(L, 2), 2, "number"); + luaL_argexpected(L, nz || lua_isnoneornil(L, 3), 3, "number"); - lua_pushnumber(L, r); + double r = perlin((float)x, (float)y, (float)z); - return 1; + lua_pushnumber(L, r); + return 1; + } + else + { + double x = luaL_checknumber(L, 1); + double y = luaL_optnumber(L, 2, 0.0); + double z = luaL_optnumber(L, 3, 0.0); + + double r = perlin_dep((float)x, (float)y, (float)z); + + lua_pushnumber(L, r); + return 1; + } } static int math_clamp(lua_State* L) diff --git a/VM/src/lobject.h b/VM/src/lobject.h index 74ea16235..18ff75460 100644 --- a/VM/src/lobject.h +++ b/VM/src/lobject.h @@ -290,6 +290,8 @@ typedef struct Proto uint8_t* typeinfo; + void* userdata; + GCObject* gclist; diff --git a/VM/src/ltablib.cpp b/VM/src/ltablib.cpp index 9bc624e93..fbf03deb8 100644 --- a/VM/src/ltablib.cpp +++ b/VM/src/ltablib.cpp @@ -10,6 +10,8 @@ #include "ldebug.h" #include "lvm.h" +LUAU_FASTFLAGVARIABLE(LuauFasterTableConcat, false) + static int foreachi(lua_State* L) { luaL_checktype(L, 1, LUA_TTABLE); @@ -219,8 +221,8 @@ static int tmove(lua_State* L) static void addfield(lua_State* L, luaL_Buffer* b, int i) { - lua_rawgeti(L, 1, i); - if (!lua_isstring(L, -1)) + int tt = lua_rawgeti(L, 1, i); + if (FFlag::LuauFasterTableConcat ? (tt != LUA_TSTRING && tt != LUA_TNUMBER) : !lua_isstring(L, -1)) luaL_error(L, "invalid value (%s) at index %d in table for 'concat'", luaL_typename(L, -1), i); luaL_addvalue(b); } diff --git a/VM/src/lvmexecute.cpp b/VM/src/lvmexecute.cpp index 90c5a7e86..2909d4775 100644 --- a/VM/src/lvmexecute.cpp +++ b/VM/src/lvmexecute.cpp @@ -131,6 +131,9 @@ goto dispatchContinue #endif +// Does VM support native execution via ExecutionCallbacks? We mostly assume it does but keep the define to make it easy to quantify the cost. +#define VM_HAS_NATIVE LUA_CUSTOM_EXECUTION + LUAU_NOINLINE void luau_callhook(lua_State* L, lua_Hook hook, void* userdata) { ptrdiff_t base = savestack(L, L->base); @@ -207,7 +210,7 @@ static void luau_execute(lua_State* L) LUAU_ASSERT(L->isactive); LUAU_ASSERT(!isblack(obj2gco(L))); // we don't use luaC_threadbarrier because active threads never turn black -#if LUA_CUSTOM_EXECUTION +#if VM_HAS_NATIVE if ((L->ci->flags & LUA_CALLINFO_NATIVE) && !SingleStep) { Proto* p = clvalue(L->ci->func)->l.p; @@ -1036,7 +1039,7 @@ static void luau_execute(lua_State* L) Closure* nextcl = clvalue(cip->func); Proto* nextproto = nextcl->l.p; -#if LUA_CUSTOM_EXECUTION +#if VM_HAS_NATIVE if (LUAU_UNLIKELY((cip->flags & LUA_CALLINFO_NATIVE) && !SingleStep)) { if (L->global->ecb.enter(L, nextproto) == 1) @@ -2371,7 +2374,7 @@ static void luau_execute(lua_State* L) ci->flags = LUA_CALLINFO_NATIVE; ci->savedpc = p->code; -#if LUA_CUSTOM_EXECUTION +#if VM_HAS_NATIVE if (L->global->ecb.enter(L, p) == 1) goto reentry; else @@ -2890,7 +2893,7 @@ int luau_precall(lua_State* L, StkId func, int nresults) ci->savedpc = p->code; -#if LUA_CUSTOM_EXECUTION +#if VM_HAS_NATIVE if (p->execdata) ci->flags = LUA_CALLINFO_NATIVE; #endif diff --git a/bench/micro_tests/test_ToNumberString.lua b/bench/micro_tests/test_ToNumberString.lua new file mode 100644 index 000000000..611047831 --- /dev/null +++ b/bench/micro_tests/test_ToNumberString.lua @@ -0,0 +1,22 @@ +local bench = script and require(script.Parent.bench_support) or require("bench_support") + +bench.runCode(function() + for j=1,1e6 do + tonumber("42") + tonumber(42) + end +end, "tonumber") + +bench.runCode(function() + for j=1,1e6 do + tostring(nil) + tostring("test") + tostring(42) + end +end, "tostring") + +bench.runCode(function() + for j=1,1e6 do + tostring(j) + end +end, "tostring-gc") diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index b8dee9976..e13e203a1 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -15,6 +15,7 @@ LUAU_FASTFLAG(LuauTraceTypesInNonstrictMode2) LUAU_FASTFLAG(LuauSetMetatableDoesNotTimeTravel) +LUAU_FASTFLAG(LuauAutocompleteLastTypecheck) using namespace Luau; @@ -33,14 +34,40 @@ struct ACFixtureImpl : BaseType AutocompleteResult autocomplete(unsigned row, unsigned column) { + if (FFlag::LuauAutocompleteLastTypecheck) + { + FrontendOptions opts; + opts.forAutocomplete = true; + this->frontend.check("MainModule", opts); + } + return Luau::autocomplete(this->frontend, "MainModule", Position{row, column}, nullCallback); } AutocompleteResult autocomplete(char marker, StringCompletionCallback callback = nullCallback) { + if (FFlag::LuauAutocompleteLastTypecheck) + { + FrontendOptions opts; + opts.forAutocomplete = true; + this->frontend.check("MainModule", opts); + } + return Luau::autocomplete(this->frontend, "MainModule", getPosition(marker), callback); } + AutocompleteResult autocomplete(const ModuleName& name, Position pos, StringCompletionCallback callback = nullCallback) + { + if (FFlag::LuauAutocompleteLastTypecheck) + { + FrontendOptions opts; + opts.forAutocomplete = true; + this->frontend.check(name, opts); + } + + return Luau::autocomplete(this->frontend, name, pos, callback); + } + CheckResult check(const std::string& source) { markerPosition.clear(); @@ -99,7 +126,7 @@ struct ACFixtureImpl : BaseType LUAU_ASSERT(i != markerPosition.end()); return i->second; } - + ScopedFastFlag flag{"LuauAutocompleteHideSelfArg", true}; // Maps a marker character (0-9 inclusive) to a position in the source code. std::map markerPosition; }; @@ -1319,7 +1346,7 @@ local a: aa frontend.check("Module/B"); - auto ac = Luau::autocomplete(frontend, "Module/B", Position{2, 11}, nullCallback); + auto ac = autocomplete("Module/B", Position{2, 11}); CHECK(ac.entryMap.count("aaa")); CHECK_EQ(ac.context, AutocompleteContext::Type); @@ -1342,7 +1369,7 @@ local a: aaa. frontend.check("Module/B"); - auto ac = Luau::autocomplete(frontend, "Module/B", Position{2, 13}, nullCallback); + auto ac = autocomplete("Module/B", Position{2, 13}); CHECK_EQ(2, ac.entryMap.size()); CHECK(ac.entryMap.count("A")); @@ -1999,7 +2026,7 @@ ex.a(function(x: frontend.check("Module/B"); - auto ac = Luau::autocomplete(frontend, "Module/B", Position{2, 16}, nullCallback); + auto ac = autocomplete("Module/B", Position{2, 16}); CHECK(!ac.entryMap.count("done")); @@ -2010,7 +2037,7 @@ ex.b(function(x: frontend.check("Module/C"); - ac = Luau::autocomplete(frontend, "Module/C", Position{2, 16}, nullCallback); + ac = autocomplete("Module/C", Position{2, 16}); CHECK(!ac.entryMap.count("(done) -> number")); } @@ -2033,7 +2060,7 @@ ex.a(function(x: frontend.check("Module/B"); - auto ac = Luau::autocomplete(frontend, "Module/B", Position{2, 16}, nullCallback); + auto ac = autocomplete("Module/B", Position{2, 16}); CHECK(!ac.entryMap.count("done")); CHECK(ac.entryMap.count("ex.done")); @@ -2046,7 +2073,7 @@ ex.b(function(x: frontend.check("Module/C"); - ac = Luau::autocomplete(frontend, "Module/C", Position{2, 16}, nullCallback); + ac = autocomplete("Module/C", Position{2, 16}); CHECK(!ac.entryMap.count("(done) -> number")); CHECK(ac.entryMap.count("(ex.done) -> number")); @@ -2360,7 +2387,7 @@ local a: aaa.do frontend.check("Module/B"); - auto ac = Luau::autocomplete(frontend, "Module/B", Position{2, 15}, nullCallback); + auto ac = autocomplete("Module/B", Position{2, 15}); CHECK_EQ(2, ac.entryMap.size()); CHECK(ac.entryMap.count("done")); @@ -2372,7 +2399,7 @@ TEST_CASE_FIXTURE(ACFixture, "comments") { fileResolver.source["Comments"] = "--!str"; - auto ac = Luau::autocomplete(frontend, "Comments", Position{0, 6}, nullCallback); + auto ac = autocomplete("Comments", Position{0, 6}); CHECK_EQ(0, ac.entryMap.size()); } @@ -2391,7 +2418,7 @@ TEST_CASE_FIXTURE(ACBuiltinsFixture, "autocompleteProp_index_function_metamethod -- | Column 20 )"; - auto ac = Luau::autocomplete(frontend, "Module/A", Position{9, 20}, nullCallback); + auto ac = autocomplete("Module/A", Position{9, 20}); REQUIRE_EQ(1, ac.entryMap.size()); CHECK(ac.entryMap.count("x")); } @@ -2484,7 +2511,7 @@ TEST_CASE_FIXTURE(ACFixture, "not_the_var_we_are_defining") { fileResolver.source["Module/A"] = "abc,de"; - auto ac = Luau::autocomplete(frontend, "Module/A", Position{0, 6}, nullCallback); + auto ac = autocomplete("Module/A", Position{0, 6}); CHECK(!ac.entryMap.count("de")); } @@ -2495,7 +2522,7 @@ TEST_CASE_FIXTURE(ACFixture, "recursive_function_global") end )"; - auto ac = Luau::autocomplete(frontend, "global", Position{1, 0}, nullCallback); + auto ac = autocomplete("global", Position{1, 0}); CHECK(ac.entryMap.count("abc")); } @@ -2508,7 +2535,7 @@ TEST_CASE_FIXTURE(ACFixture, "recursive_function_local") end )"; - auto ac = Luau::autocomplete(frontend, "local", Position{1, 0}, nullCallback); + auto ac = autocomplete("local", Position{1, 0}); CHECK(ac.entryMap.count("abc")); } @@ -3147,6 +3174,8 @@ t:@1 REQUIRE(ac.entryMap.count("two")); CHECK(!ac.entryMap["one"].wrongIndexType); CHECK(ac.entryMap["two"].wrongIndexType); + CHECK(ac.entryMap["one"].indexedWithSelf); + CHECK(ac.entryMap["two"].indexedWithSelf); } { @@ -3161,6 +3190,8 @@ t.@1 REQUIRE(ac.entryMap.count("two")); CHECK(ac.entryMap["one"].wrongIndexType); CHECK(!ac.entryMap["two"].wrongIndexType); + CHECK(!ac.entryMap["one"].indexedWithSelf); + CHECK(!ac.entryMap["two"].indexedWithSelf); } } @@ -3190,6 +3221,7 @@ t:@1 REQUIRE(ac.entryMap.count("m")); CHECK(!ac.entryMap["m"].wrongIndexType); + CHECK(ac.entryMap["m"].indexedWithSelf); } TEST_CASE_FIXTURE(ACFixture, "no_incompatible_self_calls") @@ -3204,6 +3236,7 @@ t:@1 REQUIRE(ac.entryMap.count("m")); CHECK(ac.entryMap["m"].wrongIndexType); + CHECK(ac.entryMap["m"].indexedWithSelf); } TEST_CASE_FIXTURE(ACFixture, "no_incompatible_self_calls_2") @@ -3219,6 +3252,7 @@ t:@1 REQUIRE(ac.entryMap.count("f")); CHECK(ac.entryMap["f"].wrongIndexType); + CHECK(ac.entryMap["f"].indexedWithSelf); } TEST_CASE_FIXTURE(ACFixture, "do_wrong_compatible_self_calls") @@ -3234,6 +3268,22 @@ t:@1 REQUIRE(ac.entryMap.count("m")); // We can make changes to mark this as a wrong way to call even though it's compatible CHECK(!ac.entryMap["m"].wrongIndexType); + CHECK(ac.entryMap["m"].indexedWithSelf); +} + +TEST_CASE_FIXTURE(ACFixture, "do_wrong_compatible_nonself_calls") +{ + check(R"( +local t = {} +function t:m(x: string) end +t.@1 + )"); + + auto ac = autocomplete('1'); + + REQUIRE(ac.entryMap.count("m")); + CHECK(!ac.entryMap["m"].wrongIndexType); + CHECK(!ac.entryMap["m"].indexedWithSelf); } TEST_CASE_FIXTURE(ACFixture, "no_wrong_compatible_self_calls_with_generics") @@ -3249,6 +3299,7 @@ t:@1 REQUIRE(ac.entryMap.count("m")); // While this call is compatible with the type, this requires instantiation of a generic type which we don't perform CHECK(ac.entryMap["m"].wrongIndexType); + CHECK(ac.entryMap["m"].indexedWithSelf); } TEST_CASE_FIXTURE(ACFixture, "string_prim_self_calls_are_fine") @@ -3262,10 +3313,13 @@ s:@1 REQUIRE(ac.entryMap.count("byte")); CHECK(ac.entryMap["byte"].wrongIndexType == false); + CHECK(ac.entryMap["byte"].indexedWithSelf); REQUIRE(ac.entryMap.count("char")); CHECK(ac.entryMap["char"].wrongIndexType == true); + CHECK(ac.entryMap["char"].indexedWithSelf); REQUIRE(ac.entryMap.count("sub")); CHECK(ac.entryMap["sub"].wrongIndexType == false); + CHECK(ac.entryMap["sub"].indexedWithSelf); } TEST_CASE_FIXTURE(ACFixture, "string_prim_non_self_calls_are_avoided") @@ -3279,8 +3333,10 @@ s.@1 REQUIRE(ac.entryMap.count("char")); CHECK(ac.entryMap["char"].wrongIndexType == false); + CHECK(!ac.entryMap["char"].indexedWithSelf); REQUIRE(ac.entryMap.count("sub")); CHECK(ac.entryMap["sub"].wrongIndexType == true); + CHECK(!ac.entryMap["sub"].indexedWithSelf); } TEST_CASE_FIXTURE(ACBuiltinsFixture, "library_non_self_calls_are_fine") @@ -3293,10 +3349,13 @@ string.@1 REQUIRE(ac.entryMap.count("byte")); CHECK(ac.entryMap["byte"].wrongIndexType == false); + CHECK(!ac.entryMap["byte"].indexedWithSelf); REQUIRE(ac.entryMap.count("char")); CHECK(ac.entryMap["char"].wrongIndexType == false); + CHECK(!ac.entryMap["char"].indexedWithSelf); REQUIRE(ac.entryMap.count("sub")); CHECK(ac.entryMap["sub"].wrongIndexType == false); + CHECK(!ac.entryMap["sub"].indexedWithSelf); check(R"( table.@1 @@ -3306,10 +3365,13 @@ table.@1 REQUIRE(ac.entryMap.count("remove")); CHECK(ac.entryMap["remove"].wrongIndexType == false); + CHECK(!ac.entryMap["remove"].indexedWithSelf); REQUIRE(ac.entryMap.count("getn")); CHECK(ac.entryMap["getn"].wrongIndexType == false); + CHECK(!ac.entryMap["getn"].indexedWithSelf); REQUIRE(ac.entryMap.count("insert")); CHECK(ac.entryMap["insert"].wrongIndexType == false); + CHECK(!ac.entryMap["insert"].indexedWithSelf); } TEST_CASE_FIXTURE(ACBuiltinsFixture, "library_self_calls_are_invalid") @@ -3322,13 +3384,16 @@ string:@1 REQUIRE(ac.entryMap.count("byte")); CHECK(ac.entryMap["byte"].wrongIndexType == true); + CHECK(ac.entryMap["byte"].indexedWithSelf); REQUIRE(ac.entryMap.count("char")); CHECK(ac.entryMap["char"].wrongIndexType == true); + CHECK(ac.entryMap["char"].indexedWithSelf); // We want the next test to evaluate to 'true', but we have to allow function defined with 'self' to be callable with ':' // We may change the definition of the string metatable to not use 'self' types in the future (like byte/char/pack/unpack) REQUIRE(ac.entryMap.count("sub")); CHECK(ac.entryMap["sub"].wrongIndexType == false); + CHECK(ac.entryMap["sub"].indexedWithSelf); } TEST_CASE_FIXTURE(ACFixture, "source_module_preservation_and_invalidation") @@ -3489,4 +3554,480 @@ TEST_CASE_FIXTURE(ACFixture, "frontend_use_correct_global_scope") CHECK(ac.entryMap.count("Name")); } +TEST_CASE_FIXTURE(ACFixture, "string_completion_outside_quotes") +{ + ScopedFastFlag flag{"LuauDisableCompletionOutsideQuotes", true}; + + loadDefinition(R"( + declare function require(path: string): any + )"); + + std::optional require = frontend.globalsForAutocomplete.globalScope->linearSearchForBinding("require"); + REQUIRE(require); + Luau::unfreeze(frontend.globalsForAutocomplete.globalTypes); + attachTag(require->typeId, "RequireCall"); + Luau::freeze(frontend.globalsForAutocomplete.globalTypes); + + check(R"( + local x = require(@1"@2"@3) + )"); + + StringCompletionCallback callback = [](std::string, std::optional, + std::optional contents) -> std::optional + { + Luau::AutocompleteEntryMap results = {{"test", Luau::AutocompleteEntry{Luau::AutocompleteEntryKind::String, std::nullopt, false, false}}}; + return results; + }; + + auto ac = autocomplete('2', callback); + + CHECK_EQ(1, ac.entryMap.size()); + CHECK(ac.entryMap.count("test")); + + ac = autocomplete('1', callback); + + CHECK_EQ(0, ac.entryMap.size()); + + ac = autocomplete('3', callback); + + CHECK_EQ(0, ac.entryMap.size()); +} + +TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_empty") +{ + ScopedFastFlag flag{"LuauAnonymousAutofilled", true}; + + check(R"( +local function foo(a: () -> ()) + a() +end + +foo(@1) + )"); + + const std::optional EXPECTED_INSERT = "function() end"; + + auto ac = autocomplete('1'); + + REQUIRE(ac.entryMap.count(kGeneratedAnonymousFunctionEntryName) == 1); + CHECK(ac.entryMap[kGeneratedAnonymousFunctionEntryName].kind == Luau::AutocompleteEntryKind::GeneratedFunction); + CHECK(ac.entryMap[kGeneratedAnonymousFunctionEntryName].typeCorrect == Luau::TypeCorrectKind::Correct); + REQUIRE(ac.entryMap[kGeneratedAnonymousFunctionEntryName].insertText); + CHECK_EQ(EXPECTED_INSERT, *ac.entryMap[kGeneratedAnonymousFunctionEntryName].insertText); +} + +TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_args") +{ + ScopedFastFlag flag{"LuauAnonymousAutofilled", true}; + + check(R"( +local function foo(a: (number, string) -> ()) + a() +end + +foo(@1) + )"); + + const std::optional EXPECTED_INSERT = "function(a0: number, a1: string) end"; + + auto ac = autocomplete('1'); + + REQUIRE(ac.entryMap.count(kGeneratedAnonymousFunctionEntryName) == 1); + CHECK(ac.entryMap[kGeneratedAnonymousFunctionEntryName].kind == Luau::AutocompleteEntryKind::GeneratedFunction); + CHECK(ac.entryMap[kGeneratedAnonymousFunctionEntryName].typeCorrect == Luau::TypeCorrectKind::Correct); + REQUIRE(ac.entryMap[kGeneratedAnonymousFunctionEntryName].insertText); + CHECK_EQ(EXPECTED_INSERT, *ac.entryMap[kGeneratedAnonymousFunctionEntryName].insertText); +} + +TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_args_single_return") +{ + ScopedFastFlag flag{"LuauAnonymousAutofilled", true}; + + check(R"( +local function foo(a: (number, string) -> (string)) + a() +end + +foo(@1) + )"); + + const std::optional EXPECTED_INSERT = "function(a0: number, a1: string): string end"; + + auto ac = autocomplete('1'); + + REQUIRE(ac.entryMap.count(kGeneratedAnonymousFunctionEntryName) == 1); + CHECK(ac.entryMap[kGeneratedAnonymousFunctionEntryName].kind == Luau::AutocompleteEntryKind::GeneratedFunction); + CHECK(ac.entryMap[kGeneratedAnonymousFunctionEntryName].typeCorrect == Luau::TypeCorrectKind::Correct); + REQUIRE(ac.entryMap[kGeneratedAnonymousFunctionEntryName].insertText); + CHECK_EQ(EXPECTED_INSERT, *ac.entryMap[kGeneratedAnonymousFunctionEntryName].insertText); +} + +TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_args_multi_return") +{ + ScopedFastFlag flag{"LuauAnonymousAutofilled", true}; + + check(R"( +local function foo(a: (number, string) -> (string, number)) + a() +end + +foo(@1) + )"); + + const std::optional EXPECTED_INSERT = "function(a0: number, a1: string): (string, number) end"; + + auto ac = autocomplete('1'); + + REQUIRE(ac.entryMap.count(kGeneratedAnonymousFunctionEntryName) == 1); + CHECK(ac.entryMap[kGeneratedAnonymousFunctionEntryName].kind == Luau::AutocompleteEntryKind::GeneratedFunction); + CHECK(ac.entryMap[kGeneratedAnonymousFunctionEntryName].typeCorrect == Luau::TypeCorrectKind::Correct); + REQUIRE(ac.entryMap[kGeneratedAnonymousFunctionEntryName].insertText); + CHECK_EQ(EXPECTED_INSERT, *ac.entryMap[kGeneratedAnonymousFunctionEntryName].insertText); +} + +TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled__noargs_multi_return") +{ + ScopedFastFlag flag{"LuauAnonymousAutofilled", true}; + + check(R"( +local function foo(a: () -> (string, number)) + a() +end + +foo(@1) + )"); + + const std::optional EXPECTED_INSERT = "function(): (string, number) end"; + + auto ac = autocomplete('1'); + + REQUIRE(ac.entryMap.count(kGeneratedAnonymousFunctionEntryName) == 1); + CHECK(ac.entryMap[kGeneratedAnonymousFunctionEntryName].kind == Luau::AutocompleteEntryKind::GeneratedFunction); + CHECK(ac.entryMap[kGeneratedAnonymousFunctionEntryName].typeCorrect == Luau::TypeCorrectKind::Correct); + REQUIRE(ac.entryMap[kGeneratedAnonymousFunctionEntryName].insertText); + CHECK_EQ(EXPECTED_INSERT, *ac.entryMap[kGeneratedAnonymousFunctionEntryName].insertText); +} + +TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled__varargs_multi_return") +{ + ScopedFastFlag flag{"LuauAnonymousAutofilled", true}; + + check(R"( +local function foo(a: (...number) -> (string, number)) + a() +end + +foo(@1) + )"); + + const std::optional EXPECTED_INSERT = "function(...: number): (string, number) end"; + + auto ac = autocomplete('1'); + + REQUIRE(ac.entryMap.count(kGeneratedAnonymousFunctionEntryName) == 1); + CHECK(ac.entryMap[kGeneratedAnonymousFunctionEntryName].kind == Luau::AutocompleteEntryKind::GeneratedFunction); + CHECK(ac.entryMap[kGeneratedAnonymousFunctionEntryName].typeCorrect == Luau::TypeCorrectKind::Correct); + REQUIRE(ac.entryMap[kGeneratedAnonymousFunctionEntryName].insertText); + CHECK_EQ(EXPECTED_INSERT, *ac.entryMap[kGeneratedAnonymousFunctionEntryName].insertText); +} + +TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_multi_varargs_multi_return") +{ + ScopedFastFlag flag{"LuauAnonymousAutofilled", true}; + + check(R"( +local function foo(a: (string, ...number) -> (string, number)) + a() +end + +foo(@1) + )"); + + const std::optional EXPECTED_INSERT = "function(a0: string, ...: number): (string, number) end"; + + auto ac = autocomplete('1'); + + REQUIRE(ac.entryMap.count(kGeneratedAnonymousFunctionEntryName) == 1); + CHECK(ac.entryMap[kGeneratedAnonymousFunctionEntryName].kind == Luau::AutocompleteEntryKind::GeneratedFunction); + CHECK(ac.entryMap[kGeneratedAnonymousFunctionEntryName].typeCorrect == Luau::TypeCorrectKind::Correct); + REQUIRE(ac.entryMap[kGeneratedAnonymousFunctionEntryName].insertText); + CHECK_EQ(EXPECTED_INSERT, *ac.entryMap[kGeneratedAnonymousFunctionEntryName].insertText); +} + +TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_multi_varargs_varargs_return") +{ + ScopedFastFlag flag{"LuauAnonymousAutofilled", true}; + + check(R"( +local function foo(a: (string, ...number) -> ...number) + a() +end + +foo(@1) + )"); + + const std::optional EXPECTED_INSERT = "function(a0: string, ...: number): ...number end"; + + auto ac = autocomplete('1'); + + REQUIRE(ac.entryMap.count(kGeneratedAnonymousFunctionEntryName) == 1); + CHECK(ac.entryMap[kGeneratedAnonymousFunctionEntryName].kind == Luau::AutocompleteEntryKind::GeneratedFunction); + CHECK(ac.entryMap[kGeneratedAnonymousFunctionEntryName].typeCorrect == Luau::TypeCorrectKind::Correct); + REQUIRE(ac.entryMap[kGeneratedAnonymousFunctionEntryName].insertText); + CHECK_EQ(EXPECTED_INSERT, *ac.entryMap[kGeneratedAnonymousFunctionEntryName].insertText); +} + +TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_multi_varargs_multi_varargs_return") +{ + ScopedFastFlag flag{"LuauAnonymousAutofilled", true}; + + check(R"( +local function foo(a: (string, ...number) -> (boolean, ...number)) + a() +end + +foo(@1) + )"); + + const std::optional EXPECTED_INSERT = "function(a0: string, ...: number): (boolean, ...number) end"; + + auto ac = autocomplete('1'); + + REQUIRE(ac.entryMap.count(kGeneratedAnonymousFunctionEntryName) == 1); + CHECK(ac.entryMap[kGeneratedAnonymousFunctionEntryName].kind == Luau::AutocompleteEntryKind::GeneratedFunction); + CHECK(ac.entryMap[kGeneratedAnonymousFunctionEntryName].typeCorrect == Luau::TypeCorrectKind::Correct); + REQUIRE(ac.entryMap[kGeneratedAnonymousFunctionEntryName].insertText); + CHECK_EQ(EXPECTED_INSERT, *ac.entryMap[kGeneratedAnonymousFunctionEntryName].insertText); +} + +TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_named_args") +{ + ScopedFastFlag flag{"LuauAnonymousAutofilled", true}; + + check(R"( +local function foo(a: (foo: number, bar: string) -> (string, number)) + a() +end + +foo(@1) + )"); + + const std::optional EXPECTED_INSERT = "function(foo: number, bar: string): (string, number) end"; + + auto ac = autocomplete('1'); + + REQUIRE(ac.entryMap.count(kGeneratedAnonymousFunctionEntryName) == 1); + CHECK(ac.entryMap[kGeneratedAnonymousFunctionEntryName].kind == Luau::AutocompleteEntryKind::GeneratedFunction); + CHECK(ac.entryMap[kGeneratedAnonymousFunctionEntryName].typeCorrect == Luau::TypeCorrectKind::Correct); + REQUIRE(ac.entryMap[kGeneratedAnonymousFunctionEntryName].insertText); + CHECK_EQ(EXPECTED_INSERT, *ac.entryMap[kGeneratedAnonymousFunctionEntryName].insertText); +} + +TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_partially_args") +{ + ScopedFastFlag flag{"LuauAnonymousAutofilled", true}; + + check(R"( +local function foo(a: (number, bar: string) -> (string, number)) + a() +end + +foo(@1) + )"); + + const std::optional EXPECTED_INSERT = "function(a0: number, bar: string): (string, number) end"; + + auto ac = autocomplete('1'); + + REQUIRE(ac.entryMap.count(kGeneratedAnonymousFunctionEntryName) == 1); + CHECK(ac.entryMap[kGeneratedAnonymousFunctionEntryName].kind == Luau::AutocompleteEntryKind::GeneratedFunction); + CHECK(ac.entryMap[kGeneratedAnonymousFunctionEntryName].typeCorrect == Luau::TypeCorrectKind::Correct); + REQUIRE(ac.entryMap[kGeneratedAnonymousFunctionEntryName].insertText); + CHECK_EQ(EXPECTED_INSERT, *ac.entryMap[kGeneratedAnonymousFunctionEntryName].insertText); +} + +TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_partially_args_last") +{ + ScopedFastFlag flag{"LuauAnonymousAutofilled", true}; + + check(R"( +local function foo(a: (foo: number, string) -> (string, number)) + a() +end + +foo(@1) + )"); + + const std::optional EXPECTED_INSERT = "function(foo: number, a1: string): (string, number) end"; + + auto ac = autocomplete('1'); + + REQUIRE(ac.entryMap.count(kGeneratedAnonymousFunctionEntryName) == 1); + CHECK(ac.entryMap[kGeneratedAnonymousFunctionEntryName].kind == Luau::AutocompleteEntryKind::GeneratedFunction); + CHECK(ac.entryMap[kGeneratedAnonymousFunctionEntryName].typeCorrect == Luau::TypeCorrectKind::Correct); + REQUIRE(ac.entryMap[kGeneratedAnonymousFunctionEntryName].insertText); + CHECK_EQ(EXPECTED_INSERT, *ac.entryMap[kGeneratedAnonymousFunctionEntryName].insertText); +} + +TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_typeof_args") +{ + ScopedFastFlag flag{"LuauAnonymousAutofilled", true}; + + check(R"( +local t = { a = 1, b = 2 } + +local function foo(a: (foo: typeof(t)) -> ()) + a() +end + +foo(@1) + )"); + + const std::optional EXPECTED_INSERT = "function(foo) end"; // Cannot utter this type. + + auto ac = autocomplete('1'); + + REQUIRE(ac.entryMap.count(kGeneratedAnonymousFunctionEntryName) == 1); + CHECK(ac.entryMap[kGeneratedAnonymousFunctionEntryName].kind == Luau::AutocompleteEntryKind::GeneratedFunction); + CHECK(ac.entryMap[kGeneratedAnonymousFunctionEntryName].typeCorrect == Luau::TypeCorrectKind::Correct); + REQUIRE(ac.entryMap[kGeneratedAnonymousFunctionEntryName].insertText); + CHECK_EQ(EXPECTED_INSERT, *ac.entryMap[kGeneratedAnonymousFunctionEntryName].insertText); +} + +TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_table_literal_args") +{ + ScopedFastFlag flag{"LuauAnonymousAutofilled", true}; + + check(R"( +local function foo(a: (tbl: { x: number, y: number }) -> number) return a({x=2, y = 3}) end +foo(@1) + )"); + + const std::optional EXPECTED_INSERT = "function(tbl: { x: number, y: number }): number end"; + + auto ac = autocomplete('1'); + + REQUIRE(ac.entryMap.count(kGeneratedAnonymousFunctionEntryName) == 1); + CHECK(ac.entryMap[kGeneratedAnonymousFunctionEntryName].kind == Luau::AutocompleteEntryKind::GeneratedFunction); + CHECK(ac.entryMap[kGeneratedAnonymousFunctionEntryName].typeCorrect == Luau::TypeCorrectKind::Correct); + REQUIRE(ac.entryMap[kGeneratedAnonymousFunctionEntryName].insertText); + CHECK_EQ(EXPECTED_INSERT, *ac.entryMap[kGeneratedAnonymousFunctionEntryName].insertText); +} + +TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_typeof_returns") +{ + ScopedFastFlag flag{"LuauAnonymousAutofilled", true}; + + check(R"( +local t = { a = 1, b = 2 } + +local function foo(a: () -> typeof(t)) + a() +end + +foo(@1) + )"); + + const std::optional EXPECTED_INSERT = "function() end"; // Cannot utter this type. + + auto ac = autocomplete('1'); + + REQUIRE(ac.entryMap.count(kGeneratedAnonymousFunctionEntryName) == 1); + CHECK(ac.entryMap[kGeneratedAnonymousFunctionEntryName].kind == Luau::AutocompleteEntryKind::GeneratedFunction); + CHECK(ac.entryMap[kGeneratedAnonymousFunctionEntryName].typeCorrect == Luau::TypeCorrectKind::Correct); + REQUIRE(ac.entryMap[kGeneratedAnonymousFunctionEntryName].insertText); + CHECK_EQ(EXPECTED_INSERT, *ac.entryMap[kGeneratedAnonymousFunctionEntryName].insertText); +} + +TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_table_literal_args") +{ + ScopedFastFlag flag{"LuauAnonymousAutofilled", true}; + + check(R"( +local function foo(a: () -> { x: number, y: number }) return {x=2, y = 3} end +foo(@1) + )"); + + const std::optional EXPECTED_INSERT = "function(): { x: number, y: number } end"; + + auto ac = autocomplete('1'); + + REQUIRE(ac.entryMap.count(kGeneratedAnonymousFunctionEntryName) == 1); + CHECK(ac.entryMap[kGeneratedAnonymousFunctionEntryName].kind == Luau::AutocompleteEntryKind::GeneratedFunction); + CHECK(ac.entryMap[kGeneratedAnonymousFunctionEntryName].typeCorrect == Luau::TypeCorrectKind::Correct); + REQUIRE(ac.entryMap[kGeneratedAnonymousFunctionEntryName].insertText); + CHECK_EQ(EXPECTED_INSERT, *ac.entryMap[kGeneratedAnonymousFunctionEntryName].insertText); +} + +TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_typeof_vararg") +{ + ScopedFastFlag flag{"LuauAnonymousAutofilled", true}; + + check(R"( +local t = { a = 1, b = 2 } + +local function foo(a: (...typeof(t)) -> ()) + a() +end + +foo(@1) + )"); + + const std::optional EXPECTED_INSERT = "function(...) end"; // Cannot utter this type. + + auto ac = autocomplete('1'); + + REQUIRE(ac.entryMap.count(kGeneratedAnonymousFunctionEntryName) == 1); + CHECK(ac.entryMap[kGeneratedAnonymousFunctionEntryName].kind == Luau::AutocompleteEntryKind::GeneratedFunction); + CHECK(ac.entryMap[kGeneratedAnonymousFunctionEntryName].typeCorrect == Luau::TypeCorrectKind::Correct); + REQUIRE(ac.entryMap[kGeneratedAnonymousFunctionEntryName].insertText); + CHECK_EQ(EXPECTED_INSERT, *ac.entryMap[kGeneratedAnonymousFunctionEntryName].insertText); +} + +TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_generic_type_pack_vararg") +{ + ScopedFastFlag flag{"LuauAnonymousAutofilled", true}; + + check(R"( +local function foo(a: (...A) -> number, ...: A) + return a(...) +end + +foo(@1) + )"); + + const std::optional EXPECTED_INSERT = "function(...): number end"; + + auto ac = autocomplete('1'); + + REQUIRE(ac.entryMap.count(kGeneratedAnonymousFunctionEntryName) == 1); + CHECK(ac.entryMap[kGeneratedAnonymousFunctionEntryName].kind == Luau::AutocompleteEntryKind::GeneratedFunction); + CHECK(ac.entryMap[kGeneratedAnonymousFunctionEntryName].typeCorrect == Luau::TypeCorrectKind::Correct); + REQUIRE(ac.entryMap[kGeneratedAnonymousFunctionEntryName].insertText); + CHECK_EQ(EXPECTED_INSERT, *ac.entryMap[kGeneratedAnonymousFunctionEntryName].insertText); +} + +TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_generic_on_argument_type_pack_vararg") +{ + ScopedFastFlag flag{"LuauAnonymousAutofilled", true}; + + check(R"( +local function foo(a: (...: T...) -> number) + return a(4, 5, 6) +end + +foo(@1) + )"); + + const std::optional EXPECTED_INSERT = "function(...): number end"; + + auto ac = autocomplete('1'); + + REQUIRE(ac.entryMap.count(kGeneratedAnonymousFunctionEntryName) == 1); + CHECK(ac.entryMap[kGeneratedAnonymousFunctionEntryName].kind == Luau::AutocompleteEntryKind::GeneratedFunction); + CHECK(ac.entryMap[kGeneratedAnonymousFunctionEntryName].typeCorrect == Luau::TypeCorrectKind::Correct); + REQUIRE(ac.entryMap[kGeneratedAnonymousFunctionEntryName].insertText); + CHECK_EQ(EXPECTED_INSERT, *ac.entryMap[kGeneratedAnonymousFunctionEntryName].insertText); +} + TEST_SUITE_END(); diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index db779da2b..7abf04238 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -6978,6 +6978,8 @@ L3: RETURN R0 0 TEST_CASE("BuiltinArity") { + ScopedFastFlag sff("LuauCompileFixBuiltinArity", true); + // by default we can't assume that we know parameter/result count for builtins as they can be overridden at runtime CHECK_EQ("\n" + compileFunction(R"( return math.abs(unknown()) @@ -7037,6 +7039,21 @@ FASTCALL 34 L0 GETIMPORT R0 4 [bit32.extract] CALL R0 -1 1 L0: RETURN R0 1 +)"); + + // some builtins are not variadic and have a fixed number of arguments but are not none-safe, meaning that we can't replace calls that may + // return none with calls that will return nil + CHECK_EQ("\n" + compileFunction(R"( +return type(unknown()) +)", + 0, 2), + R"( +GETIMPORT R1 1 [unknown] +CALL R1 0 -1 +FASTCALL 40 L0 +GETIMPORT R0 3 [type] +CALL R0 -1 1 +L0: RETURN R0 1 )"); // importantly, this optimization also helps us get around the multret inlining restriction for builtin wrappers diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index c98dabb95..c07aab0d7 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -273,6 +273,8 @@ TEST_CASE("Assert") TEST_CASE("Basic") { + ScopedFastFlag sff("LuauCompileFixBuiltinArity", true); + runConformance("basic.lua"); } @@ -326,6 +328,8 @@ TEST_CASE("Clear") TEST_CASE("Strings") { + ScopedFastFlag sff("LuauCompileFixBuiltinArity", true); + runConformance("strings.lua"); } @@ -1112,6 +1116,34 @@ static bool endsWith(const std::string& str, const std::string& suffix) return suffix == std::string_view(str.c_str() + str.length() - suffix.length(), suffix.length()); } +TEST_CASE("ApiType") +{ + StateRef globalState(luaL_newstate(), lua_close); + lua_State* L = globalState.get(); + + lua_pushnumber(L, 2); + CHECK(strcmp(luaL_typename(L, -1), "number") == 0); + CHECK(strcmp(luaL_typename(L, 1), "number") == 0); + CHECK(lua_type(L, -1) == LUA_TNUMBER); + CHECK(lua_type(L, 1) == LUA_TNUMBER); + + CHECK(strcmp(luaL_typename(L, 2), "no value") == 0); + CHECK(lua_type(L, 2) == LUA_TNONE); + CHECK(strcmp(lua_typename(L, lua_type(L, 2)), "no value") == 0); + + lua_newuserdata(L, 0); + CHECK(strcmp(luaL_typename(L, -1), "userdata") == 0); + CHECK(lua_type(L, -1) == LUA_TUSERDATA); + + lua_newtable(L); + lua_pushstring(L, "hello"); + lua_setfield(L, -2, "__type"); + lua_setmetatable(L, -2); + + CHECK(strcmp(luaL_typename(L, -1), "hello") == 0); + CHECK(lua_type(L, -1) == LUA_TUSERDATA); +} + #if !LUA_USE_LONGJMP TEST_CASE("ExceptionObject") { diff --git a/tests/ConstraintGraphBuilderFixture.cpp b/tests/ConstraintGraphBuilderFixture.cpp index 6bfb15901..01b9a5dd6 100644 --- a/tests/ConstraintGraphBuilderFixture.cpp +++ b/tests/ConstraintGraphBuilderFixture.cpp @@ -21,7 +21,7 @@ void ConstraintGraphBuilderFixture::generateConstraints(const std::string& code) AstStatBlock* root = parse(code); dfg = std::make_unique(DataFlowGraphBuilder::build(root, NotNull{&ice})); cgb = std::make_unique(mainModule, &arena, NotNull(&moduleResolver), builtinTypes, NotNull(&ice), - frontend.globals.globalScope, /*prepareModuleScope*/ nullptr, &logger, NotNull{dfg.get()}); + frontend.globals.globalScope, /*prepareModuleScope*/ nullptr, &logger, NotNull{dfg.get()}, std::vector()); cgb->visit(root); rootScope = cgb->rootScope; constraints = Luau::borrowConstraints(cgb->constraints); @@ -30,7 +30,7 @@ void ConstraintGraphBuilderFixture::generateConstraints(const std::string& code) void ConstraintGraphBuilderFixture::solve(const std::string& code) { generateConstraints(code); - ConstraintSolver cs{NotNull{&normalizer}, NotNull{rootScope}, constraints, "MainModule", NotNull(&moduleResolver), {}, &logger}; + ConstraintSolver cs{NotNull{&normalizer}, NotNull{rootScope}, constraints, "MainModule", NotNull(&moduleResolver), {}, &logger, {}}; cs.run(); } diff --git a/tests/CostModel.test.cpp b/tests/CostModel.test.cpp index 686a99d17..206b83a81 100644 --- a/tests/CostModel.test.cpp +++ b/tests/CostModel.test.cpp @@ -133,6 +133,8 @@ end TEST_CASE("ControlFlow") { + ScopedFastFlag sff("LuauAssignmentHasCost", true); + uint64_t model = modelFunction(R"( function test(a) while a < 0 do @@ -156,8 +158,8 @@ end const bool args1[] = {false}; const bool args2[] = {true}; - CHECK_EQ(82, Luau::Compile::computeCost(model, args1, 1)); - CHECK_EQ(79, Luau::Compile::computeCost(model, args2, 1)); + CHECK_EQ(76, Luau::Compile::computeCost(model, args1, 1)); + CHECK_EQ(73, Luau::Compile::computeCost(model, args2, 1)); } TEST_CASE("Conditional") @@ -240,4 +242,25 @@ end CHECK_EQ(3, Luau::Compile::computeCost(model, args2, 1)); } +TEST_CASE("MultipleAssignments") +{ + ScopedFastFlag sff("LuauAssignmentHasCost", true); + + uint64_t model = modelFunction(R"( +function test(a) + local x = 0 + x = a + x = a + 1 + x, x, x = a + x = a, a, a +end +)"); + + const bool args1[] = {false}; + const bool args2[] = {true}; + + CHECK_EQ(8, Luau::Compile::computeCost(model, args1, 1)); + CHECK_EQ(7, Luau::Compile::computeCost(model, args2, 1)); +} + TEST_SUITE_END(); diff --git a/tests/Differ.test.cpp b/tests/Differ.test.cpp index 520c53021..132b0267a 100644 --- a/tests/Differ.test.cpp +++ b/tests/Differ.test.cpp @@ -17,308 +17,266 @@ LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) TEST_SUITE_BEGIN("Differ"); -TEST_CASE_FIXTURE(Fixture, "equal_numbers") +TEST_CASE_FIXTURE(DifferFixture, "equal_numbers") { CheckResult result = check(R"( local foo = 5 local almostFoo = 78 - almostFoo = foo )"); LUAU_REQUIRE_NO_ERRORS(result); - TypeId foo = requireType("foo"); - TypeId almostFoo = requireType("almostFoo"); - try - { - DifferResult diffRes = diff(foo, almostFoo); - CHECK(!diffRes.diffError.has_value()); - } - catch (const InternalCompilerError& e) - { - INFO(("InternalCompilerError: " + e.message)); - CHECK(false); - } + compareTypesEq("foo", "almostFoo"); } -TEST_CASE_FIXTURE(Fixture, "equal_strings") +TEST_CASE_FIXTURE(DifferFixture, "equal_strings") { CheckResult result = check(R"( local foo = "hello" local almostFoo = "world" - almostFoo = foo )"); LUAU_REQUIRE_NO_ERRORS(result); - TypeId foo = requireType("foo"); - TypeId almostFoo = requireType("almostFoo"); - try - { - DifferResult diffRes = diff(foo, almostFoo); - CHECK(!diffRes.diffError.has_value()); - } - catch (const InternalCompilerError& e) - { - INFO(("InternalCompilerError: " + e.message)); - CHECK(false); - } + compareTypesEq("foo", "almostFoo"); } -TEST_CASE_FIXTURE(Fixture, "equal_tables") +TEST_CASE_FIXTURE(DifferFixture, "equal_tables") { CheckResult result = check(R"( local foo = { x = 1, y = "where" } local almostFoo = { x = 5, y = "when" } - almostFoo = foo )"); LUAU_REQUIRE_NO_ERRORS(result); - TypeId foo = requireType("foo"); - TypeId almostFoo = requireType("almostFoo"); - try - { - DifferResult diffRes = diff(foo, almostFoo); - CHECK(!diffRes.diffError.has_value()); - } - catch (const InternalCompilerError& e) - { - INFO(("InternalCompilerError: " + e.message)); - CHECK(false); - } + compareTypesEq("foo", "almostFoo"); } -TEST_CASE_FIXTURE(Fixture, "a_table_missing_property") +TEST_CASE_FIXTURE(DifferFixture, "a_table_missing_property") { CheckResult result = check(R"( local foo = { x = 1, y = 2 } local almostFoo = { x = 1, z = 3 } - almostFoo = foo )"); - LUAU_REQUIRE_ERRORS(result); + LUAU_REQUIRE_NO_ERRORS(result); - TypeId foo = requireType("foo"); - TypeId almostFoo = requireType("almostFoo"); - std::string diffMessage; - try - { - diffMessage = diff(foo, almostFoo).diffError->toString(); - } - catch (const InternalCompilerError& e) - { - INFO(("InternalCompilerError: " + e.message)); - CHECK(false); - } - CHECK_EQ("DiffError: these two types are not equal because the left type at foo.y has type number, while the right type at almostFoo is missing " - "the property y", - diffMessage); + compareTypesNe("foo", "almostFoo", + "DiffError: these two types are not equal because the left type at foo.y has type number, while the right type at almostFoo is missing " + "the property y"); } -TEST_CASE_FIXTURE(Fixture, "left_table_missing_property") +TEST_CASE_FIXTURE(DifferFixture, "left_table_missing_property") { CheckResult result = check(R"( local foo = { x = 1 } local almostFoo = { x = 1, z = 3 } - almostFoo = foo )"); - LUAU_REQUIRE_ERRORS(result); + LUAU_REQUIRE_NO_ERRORS(result); - TypeId foo = requireType("foo"); - TypeId almostFoo = requireType("almostFoo"); - std::string diffMessage; - try - { - diffMessage = diff(foo, almostFoo).diffError->toString(); - } - catch (const InternalCompilerError& e) - { - INFO(("InternalCompilerError: " + e.message)); - CHECK(false); - } - CHECK_EQ("DiffError: these two types are not equal because the left type at foo is missing the property z, while the right type at almostFoo.z " - "has type number", - diffMessage); + compareTypesNe("foo", "almostFoo", + "DiffError: these two types are not equal because the left type at foo is missing the property z, while the right type at almostFoo.z " + "has type number"); } -TEST_CASE_FIXTURE(Fixture, "a_table_wrong_type") +TEST_CASE_FIXTURE(DifferFixture, "a_table_wrong_type") { CheckResult result = check(R"( local foo = { x = 1, y = 2 } local almostFoo = { x = 1, y = "two" } - almostFoo = foo )"); - LUAU_REQUIRE_ERRORS(result); + LUAU_REQUIRE_NO_ERRORS(result); - TypeId foo = requireType("foo"); - TypeId almostFoo = requireType("almostFoo"); - std::string diffMessage; - try - { - diffMessage = diff(foo, almostFoo).diffError->toString(); - } - catch (const InternalCompilerError& e) - { - INFO(("InternalCompilerError: " + e.message)); - CHECK(false); - } - CHECK_EQ("DiffError: these two types are not equal because the left type at foo.y has type number, while the right type at almostFoo.y has type " - "string", - diffMessage); + compareTypesNe("foo", "almostFoo", + "DiffError: these two types are not equal because the left type at foo.y has type number, while the right type at almostFoo.y has type " + "string"); } -TEST_CASE_FIXTURE(Fixture, "a_table_wrong_type") +TEST_CASE_FIXTURE(DifferFixture, "a_table_wrong_type") { CheckResult result = check(R"( local foo: string local almostFoo: number - almostFoo = foo )"); - LUAU_REQUIRE_ERRORS(result); + LUAU_REQUIRE_NO_ERRORS(result); - TypeId foo = requireType("foo"); - TypeId almostFoo = requireType("almostFoo"); - std::string diffMessage; - try - { - diffMessage = diff(foo, almostFoo).diffError->toString(); - } - catch (const InternalCompilerError& e) - { - INFO(("InternalCompilerError: " + e.message)); - CHECK(false); - } - CHECK_EQ("DiffError: these two types are not equal because the left type at has type string, while the right type at " - " has type number", - diffMessage); + compareTypesNe("foo", "almostFoo", + "DiffError: these two types are not equal because the left type at has type string, while the right type at " + " has type number"); } -TEST_CASE_FIXTURE(Fixture, "a_nested_table_wrong_type") +TEST_CASE_FIXTURE(DifferFixture, "a_nested_table_wrong_type") { CheckResult result = check(R"( local foo = { x = 1, inner = { table = { has = { wrong = { value = 5 } } } } } local almostFoo = { x = 1, inner = { table = { has = { wrong = { value = "five" } } } } } - almostFoo = foo )"); - LUAU_REQUIRE_ERRORS(result); + LUAU_REQUIRE_NO_ERRORS(result); - TypeId foo = requireType("foo"); - TypeId almostFoo = requireType("almostFoo"); - std::string diffMessage; - try - { - diffMessage = diff(foo, almostFoo).diffError->toString(); - } - catch (const InternalCompilerError& e) - { - INFO(("InternalCompilerError: " + e.message)); - CHECK(false); - } - CHECK_EQ("DiffError: these two types are not equal because the left type at foo.inner.table.has.wrong.value has type number, while the right " - "type at almostFoo.inner.table.has.wrong.value has type string", - diffMessage); + compareTypesNe("foo", "almostFoo", + "DiffError: these two types are not equal because the left type at foo.inner.table.has.wrong.value has type number, while the right " + "type at almostFoo.inner.table.has.wrong.value has type string"); } -TEST_CASE_FIXTURE(Fixture, "a_nested_table_wrong_match") +TEST_CASE_FIXTURE(DifferFixture, "a_nested_table_wrong_match") { CheckResult result = check(R"( local foo = { x = 1, inner = { table = { has = { wrong = { variant = { because = { it = { goes = { on = "five" } } } } } } } } } local almostFoo = { x = 1, inner = { table = { has = { wrong = { variant = "five" } } } } } - almostFoo = foo )"); - LUAU_REQUIRE_ERRORS(result); + LUAU_REQUIRE_NO_ERRORS(result); - TypeId foo = requireType("foo"); - TypeId almostFoo = requireType("almostFoo"); - std::string diffMessage; - try - { - diffMessage = diff(foo, almostFoo).diffError->toString(); - } - catch (const InternalCompilerError& e) - { - INFO(("InternalCompilerError: " + e.message)); - CHECK(false); - } - CHECK_EQ("DiffError: these two types are not equal because the left type at foo.inner.table.has.wrong.variant has type { because: { it: { goes: " - "{ on: string } } } }, while the right type at almostFoo.inner.table.has.wrong.variant has type string", - diffMessage); + compareTypesNe("foo", "almostFoo", + "DiffError: these two types are not equal because the left type at foo.inner.table.has.wrong.variant has type { because: { it: { goes: " + "{ on: string } } } }, while the right type at almostFoo.inner.table.has.wrong.variant has type string"); } -TEST_CASE_FIXTURE(Fixture, "singleton") +TEST_CASE_FIXTURE(DifferFixture, "singleton") { CheckResult result = check(R"( local foo: "hello" = "hello" local almostFoo: true = true - almostFoo = foo )"); - LUAU_REQUIRE_ERRORS(result); + LUAU_REQUIRE_NO_ERRORS(result); - TypeId foo = requireType("foo"); - TypeId almostFoo = requireType("almostFoo"); - std::string diffMessage; - try - { - diffMessage = diff(foo, almostFoo).diffError->toString(); - } - catch (const InternalCompilerError& e) - { - INFO(("InternalCompilerError: " + e.message)); - CHECK(false); - } - CHECK_EQ( - R"(DiffError: these two types are not equal because the left type at has type "hello", while the right type at has type true)", - diffMessage); + compareTypesNe("foo", "almostFoo", + R"(DiffError: these two types are not equal because the left type at has type "hello", while the right type at has type true)"); } -TEST_CASE_FIXTURE(Fixture, "equal_singleton") +TEST_CASE_FIXTURE(DifferFixture, "equal_singleton") { CheckResult result = check(R"( local foo: "hello" = "hello" local almostFoo: "hello" - almostFoo = foo )"); LUAU_REQUIRE_NO_ERRORS(result); - TypeId foo = requireType("foo"); - TypeId almostFoo = requireType("almostFoo"); - try - { - DifferResult diffRes = diff(foo, almostFoo); - INFO(diffRes.diffError->toString()); - CHECK(!diffRes.diffError.has_value()); - } - catch (const InternalCompilerError& e) - { - INFO(("InternalCompilerError: " + e.message)); - CHECK(false); - } + compareTypesEq("foo", "almostFoo"); } -TEST_CASE_FIXTURE(Fixture, "singleton_string") +TEST_CASE_FIXTURE(DifferFixture, "singleton_string") { CheckResult result = check(R"( local foo: "hello" = "hello" local almostFoo: "world" = "world" - almostFoo = foo )"); - LUAU_REQUIRE_ERRORS(result); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesNe("foo", "almostFoo", + R"(DiffError: these two types are not equal because the left type at has type "hello", while the right type at has type "world")"); +} + +TEST_CASE_FIXTURE(DifferFixtureWithBuiltins, "negation") +{ + // Old solver does not correctly refine test types + ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true}; + + CheckResult result = check(R"( + local bar: { x: { y: unknown }} + local almostBar: { x: { y: unknown }} + + local foo + local almostFoo + + if typeof(bar.x.y) ~= "string" then + foo = bar + end + + if typeof(almostBar.x.y) ~= "number" then + almostFoo = almostBar + end + + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesNe("foo", "almostFoo", + R"(DiffError: these two types are not equal because the left type at .x.y.Negation has type string, while the right type at .x.y.Negation has type number)"); +} + +TEST_CASE_FIXTURE(DifferFixture, "union_missing_right") +{ + CheckResult result = check(R"( + local foo: string | number + local almostFoo: boolean | string + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesNe("foo", "almostFoo", + R"(DiffError: these two types are not equal because the left type at is a union containing type number, while the right type at is a union missing type number)"); +} + +TEST_CASE_FIXTURE(DifferFixture, "union_missing_left") +{ + CheckResult result = check(R"( + local foo: string | number + local almostFoo: boolean | string | number + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesNe("foo", "almostFoo", + R"(DiffError: these two types are not equal because the left type at is a union missing type boolean, while the right type at is a union containing type boolean)"); +} + +TEST_CASE_FIXTURE(DifferFixture, "union_missing") +{ + // TODO: this test case produces an error message that is not the most UX-friendly + + CheckResult result = check(R"( + local foo: { bar: number, pan: string } | { baz: boolean, rot: "singleton" } + local almostFoo: { bar: number, pan: string } | { baz: string, rot: "singleton" } + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesNe("foo", "almostFoo", + R"(DiffError: these two types are not equal because the left type at is a union containing type {| baz: boolean, rot: "singleton" |}, while the right type at is a union missing type {| baz: boolean, rot: "singleton" |})"); +} + +TEST_CASE_FIXTURE(DifferFixture, "intersection_missing_right") +{ + CheckResult result = check(R"( + local foo: (number) -> () & (string) -> () + local almostFoo: (string) -> () & (boolean) -> () + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesNe("foo", "almostFoo", + R"(DiffError: these two types are not equal because the left type at is an intersection containing type (number) -> (), while the right type at is an intersection missing type (number) -> ())"); +} + +TEST_CASE_FIXTURE(DifferFixture, "intersection_missing_left") +{ + CheckResult result = check(R"( + local foo: (number) -> () & (string) -> () + local almostFoo: (string) -> () & (boolean) -> () & (number) -> () + )"); + LUAU_REQUIRE_NO_ERRORS(result); - TypeId foo = requireType("foo"); - TypeId almostFoo = requireType("almostFoo"); - std::string diffMessage; - try - { - diffMessage = diff(foo, almostFoo).diffError->toString(); - } - catch (const InternalCompilerError& e) - { - INFO(("InternalCompilerError: " + e.message)); - CHECK(false); - } - CHECK_EQ( - R"(DiffError: these two types are not equal because the left type at has type "hello", while the right type at has type "world")", - diffMessage); + compareTypesNe("foo", "almostFoo", + R"(DiffError: these two types are not equal because the left type at is an intersection missing type (boolean) -> (), while the right type at is an intersection containing type (boolean) -> ())"); } -TEST_CASE_FIXTURE(Fixture, "equal_function") +TEST_CASE_FIXTURE(DifferFixture, "intersection_tables_missing_right") +{ + CheckResult result = check(R"( + local foo: { x: number } & { y: string } + local almostFoo: { y: string } & { z: boolean } + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesNe("foo", "almostFoo", + R"(DiffError: these two types are not equal because the left type at is an intersection containing type {| x: number |}, while the right type at is an intersection missing type {| x: number |})"); +} + +TEST_CASE_FIXTURE(DifferFixture, "intersection_tables_missing_left") +{ + CheckResult result = check(R"( + local foo: { x: number } & { y: string } + local almostFoo: { y: string } & { z: boolean } & { x: number } + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesNe("foo", "almostFoo", + R"(DiffError: these two types are not equal because the left type at is an intersection missing type {| z: boolean |}, while the right type at is an intersection containing type {| z: boolean |})"); +} + +TEST_CASE_FIXTURE(DifferFixture, "equal_function") { // Old solver does not correctly infer function typepacks ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true}; @@ -333,22 +291,10 @@ TEST_CASE_FIXTURE(Fixture, "equal_function") )"); LUAU_REQUIRE_NO_ERRORS(result); - TypeId foo = requireType("foo"); - TypeId almostFoo = requireType("almostFoo"); - try - { - DifferResult diffRes = diff(foo, almostFoo); - INFO(diffRes.diffError->toString()); - CHECK(!diffRes.diffError.has_value()); - } - catch (InternalCompilerError e) - { - INFO(("InternalCompilerError: " + e.message)); - CHECK(false); - } + compareTypesEq("foo", "almostFoo"); } -TEST_CASE_FIXTURE(Fixture, "equal_function_inferred_ret_length") +TEST_CASE_FIXTURE(DifferFixture, "equal_function_inferred_ret_length") { // Old solver does not correctly infer function typepacks ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true}; @@ -369,22 +315,10 @@ TEST_CASE_FIXTURE(Fixture, "equal_function_inferred_ret_length") )"); LUAU_REQUIRE_NO_ERRORS(result); - TypeId foo = requireType("foo"); - TypeId almostFoo = requireType("almostFoo"); - try - { - DifferResult diffRes = diff(foo, almostFoo); - INFO(diffRes.diffError->toString()); - CHECK(!diffRes.diffError.has_value()); - } - catch (InternalCompilerError e) - { - INFO(("InternalCompilerError: " + e.message)); - CHECK(false); - } + compareTypesEq("foo", "almostFoo"); } -TEST_CASE_FIXTURE(Fixture, "equal_function_inferred_ret_length_2") +TEST_CASE_FIXTURE(DifferFixture, "equal_function_inferred_ret_length_2") { // Old solver does not correctly infer function typepacks ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true}; @@ -402,22 +336,10 @@ TEST_CASE_FIXTURE(Fixture, "equal_function_inferred_ret_length_2") )"); LUAU_REQUIRE_NO_ERRORS(result); - TypeId foo = requireType("foo"); - TypeId almostFoo = requireType("almostFoo"); - try - { - DifferResult diffRes = diff(foo, almostFoo); - INFO(diffRes.diffError->toString()); - CHECK(!diffRes.diffError.has_value()); - } - catch (const InternalCompilerError& e) - { - INFO(("InternalCompilerError: " + e.message)); - CHECK(false); - } + compareTypesEq("foo", "almostFoo"); } -TEST_CASE_FIXTURE(Fixture, "function_arg_normal") +TEST_CASE_FIXTURE(DifferFixture, "function_arg_normal") { // Old solver does not correctly infer function typepacks ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true}; @@ -428,28 +350,15 @@ TEST_CASE_FIXTURE(Fixture, "function_arg_normal") end function almostFoo(a: number, b: number, msg: string) return a - almostFoo = foo + end )"); - LUAU_REQUIRE_ERRORS(result); + LUAU_REQUIRE_NO_ERRORS(result); - TypeId foo = requireType("foo"); - TypeId almostFoo = requireType("almostFoo"); - std::string diffMessage; - try - { - diffMessage = diff(foo, almostFoo).diffError->toString(); - } - catch (const InternalCompilerError& e) - { - INFO(("InternalCompilerError: " + e.message)); - CHECK(false); - } - CHECK_EQ( - R"(DiffError: these two types are not equal because the left type at .Arg[3] has type number, while the right type at .Arg[3] has type string)", - diffMessage); + compareTypesNe("foo", "almostFoo", + R"(DiffError: these two types are not equal because the left type at .Arg[3] has type number, while the right type at .Arg[3] has type string)"); } -TEST_CASE_FIXTURE(Fixture, "function_arg_normal_2") +TEST_CASE_FIXTURE(DifferFixture, "function_arg_normal_2") { // Old solver does not correctly infer function typepacks ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true}; @@ -460,28 +369,15 @@ TEST_CASE_FIXTURE(Fixture, "function_arg_normal_2") end function almostFoo(a: number, y: string, msg: string) return a - almostFoo = foo + end )"); - LUAU_REQUIRE_ERRORS(result); + LUAU_REQUIRE_NO_ERRORS(result); - TypeId foo = requireType("foo"); - TypeId almostFoo = requireType("almostFoo"); - std::string diffMessage; - try - { - diffMessage = diff(foo, almostFoo).diffError->toString(); - } - catch (const InternalCompilerError& e) - { - INFO(("InternalCompilerError: " + e.message)); - CHECK(false); - } - CHECK_EQ( - R"(DiffError: these two types are not equal because the left type at .Arg[2] has type number, while the right type at .Arg[2] has type string)", - diffMessage); + compareTypesNe("foo", "almostFoo", + R"(DiffError: these two types are not equal because the left type at .Arg[2] has type number, while the right type at .Arg[2] has type string)"); } -TEST_CASE_FIXTURE(Fixture, "function_ret_normal") +TEST_CASE_FIXTURE(DifferFixture, "function_ret_normal") { // Old solver does not correctly infer function typepacks ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true}; @@ -494,31 +390,13 @@ TEST_CASE_FIXTURE(Fixture, "function_ret_normal") return msg end )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesNe("foo", "almostFoo", + R"(DiffError: these two types are not equal because the left type at .Ret[1] has type number, while the right type at .Ret[1] has type string)"); +} - TypeId foo = requireType("foo"); - TypeId almostFoo = requireType("almostFoo"); - std::string diffMessage; - try - { - DifferResult diffRes = diff(foo, almostFoo); - if (!diffRes.diffError.has_value()) - { - INFO("Differ did not report type error, even though types are unequal"); - CHECK(false); - } - diffMessage = diffRes.diffError->toString(); - } - catch (const InternalCompilerError& e) - { - INFO(("InternalCompilerError: " + e.message)); - CHECK(false); - } - CHECK_EQ( - R"(DiffError: these two types are not equal because the left type at .Ret[1] has type number, while the right type at .Ret[1] has type string)", - diffMessage); -} - -TEST_CASE_FIXTURE(Fixture, "function_arg_length") +TEST_CASE_FIXTURE(DifferFixture, "function_arg_length") { // Old solver does not correctly infer function typepacks ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true}; @@ -531,31 +409,13 @@ TEST_CASE_FIXTURE(Fixture, "function_arg_length") return x end )"); + LUAU_REQUIRE_NO_ERRORS(result); - TypeId foo = requireType("foo"); - TypeId almostFoo = requireType("almostFoo"); - std::string diffMessage; - try - { - DifferResult diffRes = diff(foo, almostFoo); - if (!diffRes.diffError.has_value()) - { - INFO("Differ did not report type error, even though types are unequal"); - CHECK(false); - } - diffMessage = diffRes.diffError->toString(); - } - catch (const InternalCompilerError& e) - { - INFO(("InternalCompilerError: " + e.message)); - CHECK(false); - } - CHECK_EQ( - R"(DiffError: these two types are not equal because the left type at takes 2 or more arguments, while the right type at takes 3 or more arguments)", - diffMessage); -} - -TEST_CASE_FIXTURE(Fixture, "function_arg_length_2") + compareTypesNe("foo", "almostFoo", + R"(DiffError: these two types are not equal because the left type at takes 2 or more arguments, while the right type at takes 3 or more arguments)"); +} + +TEST_CASE_FIXTURE(DifferFixture, "function_arg_length_2") { // Old solver does not correctly infer function typepacks ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true}; @@ -568,31 +428,13 @@ TEST_CASE_FIXTURE(Fixture, "function_arg_length_2") return x end )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesNe("foo", "almostFoo", + R"(DiffError: these two types are not equal because the left type at takes 3 or more arguments, while the right type at takes 2 or more arguments)"); +} - TypeId foo = requireType("foo"); - TypeId almostFoo = requireType("almostFoo"); - std::string diffMessage; - try - { - DifferResult diffRes = diff(foo, almostFoo); - if (!diffRes.diffError.has_value()) - { - INFO("Differ did not report type error, even though types are unequal"); - CHECK(false); - } - diffMessage = diffRes.diffError->toString(); - } - catch (const InternalCompilerError& e) - { - INFO(("InternalCompilerError: " + e.message)); - CHECK(false); - } - CHECK_EQ( - R"(DiffError: these two types are not equal because the left type at takes 3 or more arguments, while the right type at takes 2 or more arguments)", - diffMessage); -} - -TEST_CASE_FIXTURE(Fixture, "function_arg_length_none") +TEST_CASE_FIXTURE(DifferFixture, "function_arg_length_none") { // Old solver does not correctly infer function typepacks ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true}; @@ -605,31 +447,13 @@ TEST_CASE_FIXTURE(Fixture, "function_arg_length_none") return x end )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesNe("foo", "almostFoo", + R"(DiffError: these two types are not equal because the left type at takes 0 or more arguments, while the right type at takes 2 or more arguments)"); +} - TypeId foo = requireType("foo"); - TypeId almostFoo = requireType("almostFoo"); - std::string diffMessage; - try - { - DifferResult diffRes = diff(foo, almostFoo); - if (!diffRes.diffError.has_value()) - { - INFO("Differ did not report type error, even though types are unequal"); - CHECK(false); - } - diffMessage = diffRes.diffError->toString(); - } - catch (const InternalCompilerError& e) - { - INFO(("InternalCompilerError: " + e.message)); - CHECK(false); - } - CHECK_EQ( - R"(DiffError: these two types are not equal because the left type at takes 0 or more arguments, while the right type at takes 2 or more arguments)", - diffMessage); -} - -TEST_CASE_FIXTURE(Fixture, "function_arg_length_none_2") +TEST_CASE_FIXTURE(DifferFixture, "function_arg_length_none_2") { // Old solver does not correctly infer function typepacks ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true}; @@ -642,31 +466,13 @@ TEST_CASE_FIXTURE(Fixture, "function_arg_length_none_2") return 5 end )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesNe("foo", "almostFoo", + R"(DiffError: these two types are not equal because the left type at takes 1 or more arguments, while the right type at takes 0 or more arguments)"); +} - TypeId foo = requireType("foo"); - TypeId almostFoo = requireType("almostFoo"); - std::string diffMessage; - try - { - DifferResult diffRes = diff(foo, almostFoo); - if (!diffRes.diffError.has_value()) - { - INFO("Differ did not report type error, even though types are unequal"); - CHECK(false); - } - diffMessage = diffRes.diffError->toString(); - } - catch (const InternalCompilerError& e) - { - INFO(("InternalCompilerError: " + e.message)); - CHECK(false); - } - CHECK_EQ( - R"(DiffError: these two types are not equal because the left type at takes 1 or more arguments, while the right type at takes 0 or more arguments)", - diffMessage); -} - -TEST_CASE_FIXTURE(Fixture, "function_ret_length") +TEST_CASE_FIXTURE(DifferFixture, "function_ret_length") { // Old solver does not correctly infer function typepacks ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true}; @@ -679,31 +485,13 @@ TEST_CASE_FIXTURE(Fixture, "function_ret_length") return x, y end )"); + LUAU_REQUIRE_NO_ERRORS(result); - TypeId foo = requireType("foo"); - TypeId almostFoo = requireType("almostFoo"); - std::string diffMessage; - try - { - DifferResult diffRes = diff(foo, almostFoo); - if (!diffRes.diffError.has_value()) - { - INFO("Differ did not report type error, even though types are unequal"); - CHECK(false); - } - diffMessage = diffRes.diffError->toString(); - } - catch (const InternalCompilerError& e) - { - INFO(("InternalCompilerError: " + e.message)); - CHECK(false); - } - CHECK_EQ( - R"(DiffError: these two types are not equal because the left type at returns 1 values, while the right type at returns 2 values)", - diffMessage); -} - -TEST_CASE_FIXTURE(Fixture, "function_ret_length_2") + compareTypesNe("foo", "almostFoo", + R"(DiffError: these two types are not equal because the left type at returns 1 values, while the right type at returns 2 values)"); +} + +TEST_CASE_FIXTURE(DifferFixture, "function_ret_length_2") { // Old solver does not correctly infer function typepacks ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true}; @@ -716,31 +504,13 @@ TEST_CASE_FIXTURE(Fixture, "function_ret_length_2") return y, x end )"); + LUAU_REQUIRE_NO_ERRORS(result); - TypeId foo = requireType("foo"); - TypeId almostFoo = requireType("almostFoo"); - std::string diffMessage; - try - { - DifferResult diffRes = diff(foo, almostFoo); - if (!diffRes.diffError.has_value()) - { - INFO("Differ did not report type error, even though types are unequal"); - CHECK(false); - } - diffMessage = diffRes.diffError->toString(); - } - catch (const InternalCompilerError& e) - { - INFO(("InternalCompilerError: " + e.message)); - CHECK(false); - } - CHECK_EQ( - R"(DiffError: these two types are not equal because the left type at returns 3 values, while the right type at returns 2 values)", - diffMessage); -} - -TEST_CASE_FIXTURE(Fixture, "function_ret_length_none") + compareTypesNe("foo", "almostFoo", + R"(DiffError: these two types are not equal because the left type at returns 3 values, while the right type at returns 2 values)"); +} + +TEST_CASE_FIXTURE(DifferFixture, "function_ret_length_none") { // Old solver does not correctly infer function typepacks ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true}; @@ -753,31 +523,13 @@ TEST_CASE_FIXTURE(Fixture, "function_ret_length_none") return x end )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesNe("foo", "almostFoo", + R"(DiffError: these two types are not equal because the left type at returns 0 values, while the right type at returns 1 values)"); +} - TypeId foo = requireType("foo"); - TypeId almostFoo = requireType("almostFoo"); - std::string diffMessage; - try - { - DifferResult diffRes = diff(foo, almostFoo); - if (!diffRes.diffError.has_value()) - { - INFO("Differ did not report type error, even though types are unequal"); - CHECK(false); - } - diffMessage = diffRes.diffError->toString(); - } - catch (const InternalCompilerError& e) - { - INFO(("InternalCompilerError: " + e.message)); - CHECK(false); - } - CHECK_EQ( - R"(DiffError: these two types are not equal because the left type at returns 0 values, while the right type at returns 1 values)", - diffMessage); -} - -TEST_CASE_FIXTURE(Fixture, "function_ret_length_none_2") +TEST_CASE_FIXTURE(DifferFixture, "function_ret_length_none_2") { // Old solver does not correctly infer function typepacks ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true}; @@ -790,31 +542,13 @@ TEST_CASE_FIXTURE(Fixture, "function_ret_length_none_2") return end )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesNe("foo", "almostFoo", + R"(DiffError: these two types are not equal because the left type at returns 1 values, while the right type at returns 0 values)"); +} - TypeId foo = requireType("foo"); - TypeId almostFoo = requireType("almostFoo"); - std::string diffMessage; - try - { - DifferResult diffRes = diff(foo, almostFoo); - if (!diffRes.diffError.has_value()) - { - INFO("Differ did not report type error, even though types are unequal"); - CHECK(false); - } - diffMessage = diffRes.diffError->toString(); - } - catch (const InternalCompilerError& e) - { - INFO(("InternalCompilerError: " + e.message)); - CHECK(false); - } - CHECK_EQ( - R"(DiffError: these two types are not equal because the left type at returns 1 values, while the right type at returns 0 values)", - diffMessage); -} - -TEST_CASE_FIXTURE(Fixture, "function_variadic_arg_normal") +TEST_CASE_FIXTURE(DifferFixture, "function_variadic_arg_normal") { // Old solver does not correctly infer function typepacks ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true}; @@ -827,31 +561,13 @@ TEST_CASE_FIXTURE(Fixture, "function_variadic_arg_normal") return a, b end )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesNe("foo", "almostFoo", + R"(DiffError: these two types are not equal because the left type at .Arg[Variadic] has type number, while the right type at .Arg[Variadic] has type string)"); +} - TypeId foo = requireType("foo"); - TypeId almostFoo = requireType("almostFoo"); - std::string diffMessage; - try - { - DifferResult diffRes = diff(foo, almostFoo); - if (!diffRes.diffError.has_value()) - { - INFO("Differ did not report type error, even though types are unequal"); - CHECK(false); - } - diffMessage = diffRes.diffError->toString(); - } - catch (const InternalCompilerError& e) - { - INFO(("InternalCompilerError: " + e.message)); - CHECK(false); - } - CHECK_EQ( - R"(DiffError: these two types are not equal because the left type at .Arg[Variadic] has type number, while the right type at .Arg[Variadic] has type string)", - diffMessage); -} - -TEST_CASE_FIXTURE(Fixture, "function_variadic_arg_missing") +TEST_CASE_FIXTURE(DifferFixture, "function_variadic_arg_missing") { // Old solver does not correctly infer function typepacks ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true}; @@ -864,31 +580,13 @@ TEST_CASE_FIXTURE(Fixture, "function_variadic_arg_missing") return a, b end )"); + LUAU_REQUIRE_NO_ERRORS(result); - TypeId foo = requireType("foo"); - TypeId almostFoo = requireType("almostFoo"); - std::string diffMessage; - try - { - DifferResult diffRes = diff(foo, almostFoo); - if (!diffRes.diffError.has_value()) - { - INFO("Differ did not report type error, even though types are unequal"); - CHECK(false); - } - diffMessage = diffRes.diffError->toString(); - } - catch (const InternalCompilerError& e) - { - INFO(("InternalCompilerError: " + e.message)); - CHECK(false); - } - CHECK_EQ( - R"(DiffError: these two types are not equal because the left type at .Arg[Variadic] has type number, while the right type at .Arg[Variadic] has type any)", - diffMessage); -} - -TEST_CASE_FIXTURE(Fixture, "function_variadic_arg_missing_2") + compareTypesNe("foo", "almostFoo", + R"(DiffError: these two types are not equal because the left type at .Arg[Variadic] has type number, while the right type at .Arg[Variadic] has type any)"); +} + +TEST_CASE_FIXTURE(DifferFixture, "function_variadic_arg_missing_2") { // Old solver does not correctly infer function typepacks ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true}; @@ -901,31 +599,13 @@ TEST_CASE_FIXTURE(Fixture, "function_variadic_arg_missing_2") return a, b end )"); + LUAU_REQUIRE_NO_ERRORS(result); - TypeId foo = requireType("foo"); - TypeId almostFoo = requireType("almostFoo"); - std::string diffMessage; - try - { - DifferResult diffRes = diff(foo, almostFoo); - if (!diffRes.diffError.has_value()) - { - INFO("Differ did not report type error, even though types are unequal"); - CHECK(false); - } - diffMessage = diffRes.diffError->toString(); - } - catch (const InternalCompilerError& e) - { - INFO(("InternalCompilerError: " + e.message)); - CHECK(false); - } - CHECK_EQ( - R"(DiffError: these two types are not equal because the left type at .Arg[Variadic] has type any, while the right type at .Arg[Variadic] has type string)", - diffMessage); -} - -TEST_CASE_FIXTURE(Fixture, "function_variadic_oversaturation") + compareTypesNe("foo", "almostFoo", + R"(DiffError: these two types are not equal because the left type at .Arg[Variadic] has type any, while the right type at .Arg[Variadic] has type string)"); +} + +TEST_CASE_FIXTURE(DifferFixture, "function_variadic_oversaturation") { // Old solver does not correctly infer function typepacks ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true}; @@ -938,31 +618,13 @@ TEST_CASE_FIXTURE(Fixture, "function_variadic_oversaturation") -- must not be oversaturated local almostFoo: (number, string) -> (number, string) = foo )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesNe("foo", "almostFoo", + R"(DiffError: these two types are not equal because the left type at takes 2 or more arguments, while the right type at takes 2 arguments)"); +} - TypeId foo = requireType("foo"); - TypeId almostFoo = requireType("almostFoo"); - std::string diffMessage; - try - { - DifferResult diffRes = diff(foo, almostFoo); - if (!diffRes.diffError.has_value()) - { - INFO("Differ did not report type error, even though types are unequal"); - CHECK(false); - } - diffMessage = diffRes.diffError->toString(); - } - catch (const InternalCompilerError& e) - { - INFO(("InternalCompilerError: " + e.message)); - CHECK(false); - } - CHECK_EQ( - R"(DiffError: these two types are not equal because the left type at takes 2 or more arguments, while the right type at takes 2 arguments)", - diffMessage); -} - -TEST_CASE_FIXTURE(Fixture, "function_variadic_oversaturation_2") +TEST_CASE_FIXTURE(DifferFixture, "function_variadic_oversaturation_2") { // Old solver does not correctly infer function typepacks ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true}; @@ -975,28 +637,67 @@ TEST_CASE_FIXTURE(Fixture, "function_variadic_oversaturation_2") return x, y end )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesNe("foo", "almostFoo", + R"(DiffError: these two types are not equal because the left type at takes 2 arguments, while the right type at takes 2 or more arguments)"); +} + +TEST_CASE_FIXTURE(DifferFixture, "generic") +{ + // Old solver does not correctly infer function typepacks + ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true}; + + CheckResult result = check(R"( + function foo(x, y) + return x, y + end + function almostFoo(x, y) + return y, x + end + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesNe("foo", "almostFoo", + R"(DiffError: these two types are not equal because the left generic at .Ret[1] cannot be the same type parameter as the right generic at .Ret[1])"); +} + +TEST_CASE_FIXTURE(DifferFixture, "generic_one_vs_two") +{ + // Old solver does not correctly infer function typepacks + ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true}; + + CheckResult result = check(R"( + function foo(x: X, y: X) + return + end + function almostFoo(x: T, y: U) + return + end + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesNe("foo", "almostFoo", + R"(DiffError: these two types are not equal because the left generic at .Arg[2] cannot be the same type parameter as the right generic at .Arg[2])"); +} + +TEST_CASE_FIXTURE(DifferFixture, "generic_three_or_three") +{ + // Old solver does not correctly infer function typepacks + ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true}; + + CheckResult result = check(R"( + function foo(x: X, y: X, z: Y) + return + end + function almostFoo(x: T, y: U, z: U) + return + end + )"); + LUAU_REQUIRE_NO_ERRORS(result); - TypeId foo = requireType("foo"); - TypeId almostFoo = requireType("almostFoo"); - std::string diffMessage; - try - { - DifferResult diffRes = diff(foo, almostFoo); - if (!diffRes.diffError.has_value()) - { - INFO("Differ did not report type error, even though types are unequal"); - CHECK(false); - } - diffMessage = diffRes.diffError->toString(); - } - catch (const InternalCompilerError& e) - { - INFO(("InternalCompilerError: " + e.message)); - CHECK(false); - } - CHECK_EQ( - R"(DiffError: these two types are not equal because the left type at takes 2 arguments, while the right type at takes 2 or more arguments)", - diffMessage); + compareTypesNe("foo", "almostFoo", + R"(DiffError: these two types are not equal because the left generic at .Arg[2] cannot be the same type parameter as the right generic at .Arg[2])"); } TEST_SUITE_END(); diff --git a/tests/Fixture.cpp b/tests/Fixture.cpp index c6fc475b2..d4fa71786 100644 --- a/tests/Fixture.cpp +++ b/tests/Fixture.cpp @@ -176,7 +176,7 @@ AstStatBlock* Fixture::parse(const std::string& source, const ParseOptions& pars if (FFlag::DebugLuauDeferredConstraintResolution) { ModulePtr module = Luau::check(*sourceModule, {}, builtinTypes, NotNull{&ice}, NotNull{&moduleResolver}, NotNull{&fileResolver}, - frontend.globals.globalScope, /*prepareModuleScope*/ nullptr, frontend.options); + frontend.globals.globalScope, /*prepareModuleScope*/ nullptr, frontend.options, {}); Luau::lint(sourceModule->root, *sourceModule->names, frontend.globals.globalScope, module.get(), sourceModule->hotcomments, {}); } @@ -415,6 +415,17 @@ TypeId Fixture::requireTypeAlias(const std::string& name) return *ty; } +TypeId Fixture::requireExportedType(const ModuleName& moduleName, const std::string& name) +{ + ModulePtr module = frontend.moduleResolver.getModule(moduleName); + REQUIRE(module); + + auto it = module->exportedTypeBindings.find(name); + REQUIRE(it != module->exportedTypeBindings.end()); + + return it->second.type; +} + std::string Fixture::decorateWithTypes(const std::string& code) { fileResolver.source[mainModuleName] = code; diff --git a/tests/Fixture.h b/tests/Fixture.h index 8d48ab1dc..a9c5d9b00 100644 --- a/tests/Fixture.h +++ b/tests/Fixture.h @@ -2,6 +2,8 @@ #pragma once #include "Luau/Config.h" +#include "Luau/Differ.h" +#include "Luau/Error.h" #include "Luau/FileResolver.h" #include "Luau/Frontend.h" #include "Luau/IostreamHelpers.h" @@ -15,6 +17,7 @@ #include "IostreamOptional.h" #include "ScopedFlags.h" +#include "doctest.h" #include #include #include @@ -92,6 +95,7 @@ struct Fixture std::optional lookupType(const std::string& name); std::optional lookupImportedType(const std::string& moduleAlias, const std::string& name); TypeId requireTypeAlias(const std::string& name); + TypeId requireExportedType(const ModuleName& moduleName, const std::string& name); ScopedFastFlag sff_DebugLuauFreezeArena; @@ -153,6 +157,51 @@ std::optional linearSearchForBinding(Scope* scope, const char* name); void registerHiddenTypes(Frontend* frontend); void createSomeClasses(Frontend* frontend); +template +struct DifferFixtureGeneric : BaseFixture +{ + void compareNe(TypeId left, TypeId right, const std::string& expectedMessage) + { + std::string diffMessage; + try + { + DifferResult diffRes = diff(left, right); + REQUIRE_MESSAGE(diffRes.diffError.has_value(), "Differ did not report type error, even though types are unequal"); + diffMessage = diffRes.diffError->toString(); + } + catch (const InternalCompilerError& e) + { + REQUIRE_MESSAGE(false, ("InternalCompilerError: " + e.message)); + } + CHECK_EQ(expectedMessage, diffMessage); + } + + void compareTypesNe(const std::string& leftSymbol, const std::string& rightSymbol, const std::string& expectedMessage) + { + compareNe(BaseFixture::requireType(leftSymbol), BaseFixture::requireType(rightSymbol), expectedMessage); + } + + void compareEq(TypeId left, TypeId right) + { + try + { + DifferResult diffRes = diff(left, right); + CHECK_MESSAGE(!diffRes.diffError.has_value(), diffRes.diffError->toString()); + } + catch (const InternalCompilerError& e) + { + REQUIRE_MESSAGE(false, ("InternalCompilerError: " + e.message)); + } + } + + void compareTypesEq(const std::string& leftSymbol, const std::string& rightSymbol) + { + compareEq(BaseFixture::requireType(leftSymbol), BaseFixture::requireType(rightSymbol)); + } +}; +using DifferFixture = DifferFixtureGeneric; +using DifferFixtureWithBuiltins = DifferFixtureGeneric; + } // namespace Luau #define LUAU_REQUIRE_ERRORS(result) \ diff --git a/tests/Frontend.test.cpp b/tests/Frontend.test.cpp index 8f6834a17..fda0a6f0a 100644 --- a/tests/Frontend.test.cpp +++ b/tests/Frontend.test.cpp @@ -444,6 +444,53 @@ TEST_CASE_FIXTURE(FrontendFixture, "cycle_incremental_type_surface_longer") CHECK_EQ(toString(tyB), "any"); } +TEST_CASE_FIXTURE(FrontendFixture, "cycle_incremental_type_surface_exports") +{ + ScopedFastFlag luauFixCyclicModuleExports{"LuauFixCyclicModuleExports", true}; + + fileResolver.source["game/A"] = R"( +local b = require(game.B) +export type atype = { x: b.btype } +return {mod_a = 1} + )"; + + fileResolver.source["game/B"] = R"( +export type btype = { x: number } + +local function bf() + local a = require(game.A) + local bfl : a.atype = nil + return {bfl.x} +end +return {mod_b = 2} + )"; + + ToStringOptions opts; + opts.exhaustive = true; + + CheckResult resultA = frontend.check("game/A"); + LUAU_REQUIRE_ERRORS(resultA); + + CheckResult resultB = frontend.check("game/B"); + LUAU_REQUIRE_ERRORS(resultB); + + TypeId tyB = requireExportedType("game/B", "btype"); + CHECK_EQ(toString(tyB, opts), "{| x: number |}"); + + TypeId tyA = requireExportedType("game/A", "atype"); + CHECK_EQ(toString(tyA, opts), "{| x: any |}"); + + frontend.markDirty("game/B"); + resultB = frontend.check("game/B"); + LUAU_REQUIRE_ERRORS(resultB); + + tyB = requireExportedType("game/B", "btype"); + CHECK_EQ(toString(tyB, opts), "{| x: number |}"); + + tyA = requireExportedType("game/A", "atype"); + CHECK_EQ(toString(tyA, opts), "{| x: any |}"); +} + TEST_CASE_FIXTURE(FrontendFixture, "dont_reparse_clean_file_when_linting") { fileResolver.source["Modules/A"] = R"( diff --git a/tests/Linter.test.cpp b/tests/Linter.test.cpp index 54a1f44cb..e906c224a 100644 --- a/tests/Linter.test.cpp +++ b/tests/Linter.test.cpp @@ -1657,6 +1657,8 @@ _ = (math.random() < 0.5 and false) or 42 -- currently ignored TEST_CASE_FIXTURE(Fixture, "WrongComment") { + ScopedFastFlag sff("LuauLintNativeComment", true); + LintResult result = lint(R"( --!strict --!struct @@ -1666,17 +1668,19 @@ TEST_CASE_FIXTURE(Fixture, "WrongComment") --!nolint UnknownGlobal --! no more lint --!strict here +--!native on do end --!nolint )"); - REQUIRE(6 == result.warnings.size()); + REQUIRE(7 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Unknown comment directive 'struct'; did you mean 'strict'?"); CHECK_EQ(result.warnings[1].text, "Unknown comment directive 'nolintGlobal'"); CHECK_EQ(result.warnings[2].text, "nolint directive refers to unknown lint rule 'Global'"); CHECK_EQ(result.warnings[3].text, "nolint directive refers to unknown lint rule 'KnownGlobal'; did you mean 'UnknownGlobal'?"); CHECK_EQ(result.warnings[4].text, "Comment directive with the type checking mode has extra symbols at the end of the line"); - CHECK_EQ(result.warnings[5].text, "Comment directive is ignored because it is placed after the first non-comment token"); + CHECK_EQ(result.warnings[5].text, "native directive has extra symbols at the end of the line"); + CHECK_EQ(result.warnings[6].text, "Comment directive is ignored because it is placed after the first non-comment token"); } TEST_CASE_FIXTURE(Fixture, "WrongCommentMuteSelf") diff --git a/tests/Simplify.test.cpp b/tests/Simplify.test.cpp index 1223152ba..63c03ba8e 100644 --- a/tests/Simplify.test.cpp +++ b/tests/Simplify.test.cpp @@ -114,6 +114,17 @@ struct SimplifyFixture : Fixture TEST_SUITE_BEGIN("Simplify"); +TEST_CASE_FIXTURE(SimplifyFixture, "overload_negation_refinement_is_never") +{ + TypeId f1 = mkFunction(stringTy, numberTy); + TypeId f2 = mkFunction(numberTy, stringTy); + TypeId intersection = arena->addType(IntersectionType{{f1, f2}}); + TypeId unionT = arena->addType(UnionType{{errorTy, functionTy}}); + TypeId negationT = mkNegation(unionT); + // The intersection of string -> number & number -> string, ~(error | function) + CHECK(neverTy == intersect(intersection, negationT)); +} + TEST_CASE_FIXTURE(SimplifyFixture, "unknown_and_other_tops_and_bottom_types") { CHECK(unknownTy == intersect(unknownTy, unknownTy)); diff --git a/tests/StringUtils.test.cpp b/tests/StringUtils.test.cpp index 786f965ea..cf65856d1 100644 --- a/tests/StringUtils.test.cpp +++ b/tests/StringUtils.test.cpp @@ -59,7 +59,7 @@ TEST_CASE("BenchmarkLevenshteinDistance") auto end = std::chrono::steady_clock::now(); auto time = std::chrono::duration_cast(end - start); - std::cout << "Running levenshtein distance " << count << " times took " << time.count() << "ms" << std::endl; + MESSAGE("Running levenshtein distance ", count, " times took ", time.count(), "ms"); } #endif diff --git a/tests/ToString.test.cpp b/tests/ToString.test.cpp index 39759c716..d4a25f80c 100644 --- a/tests/ToString.test.cpp +++ b/tests/ToString.test.cpp @@ -11,6 +11,7 @@ using namespace Luau; LUAU_FASTFLAG(LuauRecursiveTypeParameterRestriction); +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); TEST_SUITE_BEGIN("ToString"); diff --git a/tests/TypeInfer.anyerror.test.cpp b/tests/TypeInfer.anyerror.test.cpp index 687bc766d..5601b8af8 100644 --- a/tests/TypeInfer.anyerror.test.cpp +++ b/tests/TypeInfer.anyerror.test.cpp @@ -13,6 +13,8 @@ using namespace Luau; +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); + TEST_SUITE_BEGIN("TypeInferAnyError"); TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_returns_any") diff --git a/tests/TypeInfer.classes.test.cpp b/tests/TypeInfer.classes.test.cpp index 07471d444..a4df1be7c 100644 --- a/tests/TypeInfer.classes.test.cpp +++ b/tests/TypeInfer.classes.test.cpp @@ -12,6 +12,8 @@ using namespace Luau; using std::nullopt; +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); + TEST_SUITE_BEGIN("TypeInferClasses"); TEST_CASE_FIXTURE(ClassFixture, "call_method_of_a_class") diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index 268980feb..b37bcf834 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -14,7 +14,8 @@ using namespace Luau; -LUAU_FASTFLAG(LuauInstantiateInSubtyping) +LUAU_FASTFLAG(LuauInstantiateInSubtyping); +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); TEST_SUITE_BEGIN("TypeInferFunctions"); @@ -2094,6 +2095,25 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "attempt_to_call_an_intersection_of_tables_wi LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "generic_packs_are_not_variadic") +{ + ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true}; + + CheckResult result = check(R"( + local function apply(f: (a, b...) -> c..., x: a) + return f(x) + end + + local function add(x: number, y: number) + return x + y + end + + apply(add, 5) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + TEST_CASE_FIXTURE(BuiltinsFixture, "num_is_solved_before_num_or_str") { CheckResult result = check(R"( diff --git a/tests/TypeInfer.generics.test.cpp b/tests/TypeInfer.generics.test.cpp index 72323cf90..35df644ba 100644 --- a/tests/TypeInfer.generics.test.cpp +++ b/tests/TypeInfer.generics.test.cpp @@ -9,7 +9,8 @@ #include "doctest.h" -LUAU_FASTFLAG(LuauInstantiateInSubtyping) +LUAU_FASTFLAG(LuauInstantiateInSubtyping); +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); using namespace Luau; diff --git a/tests/TypeInfer.intersectionTypes.test.cpp b/tests/TypeInfer.intersectionTypes.test.cpp index 45d127ab8..954d9858d 100644 --- a/tests/TypeInfer.intersectionTypes.test.cpp +++ b/tests/TypeInfer.intersectionTypes.test.cpp @@ -8,6 +8,7 @@ using namespace Luau; +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); TEST_SUITE_BEGIN("IntersectionTypes"); diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index a1f456a34..6f4f93287 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -9,6 +9,8 @@ using namespace Luau; +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); + TEST_SUITE_BEGIN("ProvisionalTests"); // These tests check for behavior that differs from the final behavior we'd @@ -502,6 +504,9 @@ TEST_CASE_FIXTURE(Fixture, "free_options_cannot_be_unified_together") Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}}; Unifier u{NotNull{&normalizer}, NotNull{scope.get()}, Location{}, Variance::Covariant}; + if (FFlag::DebugLuauDeferredConstraintResolution) + u.enableNewSolver(); + u.tryUnify(option1, option2); CHECK(!u.failure); @@ -565,7 +570,7 @@ return wrapStrictTable(Constants, "Constants") std::optional result = first(m->returnType); REQUIRE(result); if (FFlag::DebugLuauDeferredConstraintResolution) - CHECK_EQ("(any & ~table)?", toString(*result)); + CHECK_EQ("(any & ~(*error-type* | table))?", toString(*result)); else CHECK_MESSAGE(get(*result), *result); } @@ -905,6 +910,9 @@ TEST_CASE_FIXTURE(Fixture, "free_options_can_be_unified_together") Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}}; Unifier u{NotNull{&normalizer}, NotNull{scope.get()}, Location{}, Variance::Covariant}; + if (FFlag::DebugLuauDeferredConstraintResolution) + u.enableNewSolver(); + u.tryUnify(option1, option2); CHECK(!u.failure); diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index 0c8887404..ca302a2ff 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -1820,4 +1820,46 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "refinements_should_preserve_error_suppressio LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(BuiltinsFixture, "many_refinements_on_val") +{ + CheckResult result = check(R"( + local function is_nan(val: any): boolean + return type(val) == "number" and val ~= val + end + + local function is_js_boolean(val: any): boolean + return not not val and val ~= 0 and val ~= "" and not is_nan(val) + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("(any) -> boolean", toString(requireType("is_nan"))); + CHECK_EQ("(any) -> boolean", toString(requireType("is_js_boolean"))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "refine_unknown_to_table") +{ + ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true}; + // this test is DCR-only as an instance of DCR fixing a bug in the old solver + + CheckResult result = check(R"( + local a : unknown = nil + + local idx, val + + if typeof(a) == "table" then + for i, v in a do + idx = i + val = v + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("unknown", toString(requireType("idx"))); + CHECK_EQ("unknown", toString(requireType("val"))); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.singletons.test.cpp b/tests/TypeInfer.singletons.test.cpp index c61ff16e8..84e9fc7cd 100644 --- a/tests/TypeInfer.singletons.test.cpp +++ b/tests/TypeInfer.singletons.test.cpp @@ -7,6 +7,8 @@ using namespace Luau; +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); + TEST_SUITE_BEGIN("TypeSingletons"); TEST_CASE_FIXTURE(Fixture, "function_args_infer_singletons") diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index e3d712beb..8d93561f1 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -132,6 +132,24 @@ TEST_CASE_FIXTURE(Fixture, "cannot_change_type_of_table_prop") LUAU_REQUIRE_ERROR_COUNT(1, result); } +TEST_CASE_FIXTURE(Fixture, "report_sensible_error_when_adding_a_value_to_a_nonexistent_prop") +{ + CheckResult result = check(R"( + local t = {} + t.foo[1] = 'one' + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + INFO(result.errors[0]); + + UnknownProperty* err = get(result.errors[0]); + REQUIRE(err); + + CHECK("t" == toString(err->table)); + CHECK("foo" == err->key); +} + TEST_CASE_FIXTURE(Fixture, "function_calls_can_produce_tables") { CheckResult result = check("function get_table() return {prop=999} end get_table().prop = 0"); @@ -439,8 +457,6 @@ TEST_CASE_FIXTURE(Fixture, "table_param_row_polymorphism_2") )"); LUAU_REQUIRE_NO_ERRORS(result); - for (const auto& e : result.errors) - std::cout << "Error: " << e << std::endl; TypeId qType = requireType("q"); const TableType* qTable = get(qType); @@ -3642,4 +3658,75 @@ end LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "certain_properties_of_table_literal_arguments_can_be_covariant") +{ + CheckResult result = check(R"( + function f(a: {[string]: string | {any} | nil }) + return a + end + + local x = f({ + title = "Feature.VirtualEvents.EnableNotificationsModalTitle", + body = "Feature.VirtualEvents.EnableNotificationsModalBody", + notNow = "Feature.VirtualEvents.NotNowButton", + getNotified = "Feature.VirtualEvents.GetNotifiedButton", + }) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "subproperties_can_also_be_covariantly_tested") +{ + CheckResult result = check(R"( + type T = { + [string]: {[string]: (string | number)?} + } + + function f(t: T) + return t + end + + local x = f({ + subprop={x="hello"} + }) + + local y = f({ + subprop={x=41} + }) + + local z = f({ + subprop={} + }) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "cyclic_shifted_tables") +{ + CheckResult result = check(R"( + local function id(x: a): a + return x + end + + -- Remove name from cyclic table + local foo = id({}) + foo.foo = id({}) + foo.foo.foo = id({}) + foo.foo.foo.foo = id({}) + foo.foo.foo.foo.foo = foo + + local almostFoo = id({}) + almostFoo.foo = id({}) + almostFoo.foo.foo = id({}) + almostFoo.foo.foo.foo = id({}) + almostFoo.foo.foo.foo.foo = almostFoo + -- Shift + almostFoo = almostFoo.foo.foo + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index 7ecde7feb..36422f8da 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -224,7 +224,7 @@ TEST_CASE_FIXTURE(Fixture, "crazy_complexity") A:A():A():A():A():A():A():A():A():A():A():A() )"); - std::cout << "OK! Allocated " << typeChecker.types.size() << " types" << std::endl; + MESSAGE("OK! Allocated ", typeChecker.types.size(), " types"); } #endif @@ -1332,4 +1332,43 @@ TEST_CASE_FIXTURE(Fixture, "handle_self_referential_HasProp_constraints") )"); } +/* We had an issue where we were unifying two type packs + * + * free-2-0... and (string, free-4-0...) + * + * The correct thing to do here is to promote everything on the right side to + * level 2-0 before binding the left pack to the right. If we fail to do this, + * then the code fragment here fails to typecheck because the argument and + * return types of C are generalized before we ever get to checking the body of + * C. + */ +TEST_CASE_FIXTURE(Fixture, "promote_tail_type_packs") +{ + CheckResult result = check(R"( + --!strict + + local A: any = nil + + local C + local D = A( + A({}, { + __call = function(a): string + local E: string = C(a) + return E + end + }), + { + F = function(s: typeof(C)) + end + } + ) + + function C(b: any): string + return '' + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.tryUnify.test.cpp b/tests/TypeInfer.tryUnify.test.cpp index e00d5ae42..5656e8715 100644 --- a/tests/TypeInfer.tryUnify.test.cpp +++ b/tests/TypeInfer.tryUnify.test.cpp @@ -345,7 +345,6 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "metatables_unify_against_shape_of_free_table { ScopedFastFlag sff[] = { {"LuauTransitiveSubtyping", true}, - {"DebugLuauDeferredConstraintResolution", true}, }; TableType::Props freeProps{ @@ -369,6 +368,7 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "metatables_unify_against_shape_of_free_table TypeId target = arena.addType(TableType{TableState::Unsealed, TypeLevel{}}); TypeId metatable = arena.addType(MetatableType{target, mt}); + state.enableNewSolver(); state.tryUnify(metatable, free); state.log.commit(); @@ -439,11 +439,10 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "unifying_two_unions_under_dcr_does_not_creat const TypeId innerType = arena.freshType(nestedScope.get()); ScopedFastFlag sffs[]{ - {"DebugLuauDeferredConstraintResolution", true}, {"LuauAlwaysCommitInferencesOfFunctionCalls", true}, }; - state.enableScopeTests(); + state.enableNewSolver(); SUBCASE("equal_scopes") { diff --git a/tests/TypeInfer.typePacks.cpp b/tests/TypeInfer.typePacks.cpp index afe0552cc..d2ae166b8 100644 --- a/tests/TypeInfer.typePacks.cpp +++ b/tests/TypeInfer.typePacks.cpp @@ -9,6 +9,8 @@ using namespace Luau; +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); + TEST_SUITE_BEGIN("TypePackTests"); TEST_CASE_FIXTURE(Fixture, "infer_multi_return") diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp index 12868d8b3..3ab7bebb0 100644 --- a/tests/TypeInfer.unionTypes.test.cpp +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -8,6 +8,8 @@ using namespace Luau; +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); + TEST_SUITE_BEGIN("UnionTypes"); TEST_CASE_FIXTURE(Fixture, "return_types_can_be_disjoint") diff --git a/tests/TypeInfer.unknownnever.test.cpp b/tests/TypeInfer.unknownnever.test.cpp index e78c3d06d..038594469 100644 --- a/tests/TypeInfer.unknownnever.test.cpp +++ b/tests/TypeInfer.unknownnever.test.cpp @@ -6,6 +6,8 @@ using namespace Luau; +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); + TEST_SUITE_BEGIN("TypeInferUnknownNever"); TEST_CASE_FIXTURE(Fixture, "string_subtype_and_unknown_supertype") diff --git a/tests/conformance/basic.lua b/tests/conformance/basic.lua index f4a91fc38..07150ba80 100644 --- a/tests/conformance/basic.lua +++ b/tests/conformance/basic.lua @@ -944,6 +944,11 @@ end)(true) == 5050) assert(pcall(typeof) == false) assert(pcall(type) == false) +function nothing() end + +assert(pcall(function() return typeof(nothing()) end) == false) +assert(pcall(function() return type(nothing()) end) == false) + -- typeof == type in absence of custom userdata assert(concat(typeof(5), typeof(nil), typeof({}), typeof(newproxy())) == "number,nil,table,userdata") diff --git a/tests/conformance/math.lua b/tests/conformance/math.lua index 473427309..9e9ae3843 100644 --- a/tests/conformance/math.lua +++ b/tests/conformance/math.lua @@ -235,6 +235,12 @@ assert(flag); assert(select(2, pcall(math.random, 1, 2, 3)):match("wrong number of arguments")) +-- argument count +function nothing() end + +assert(pcall(math.abs) == false) +assert(pcall(function() return math.abs(nothing()) end) == false) + -- min/max assert(math.min(1) == 1) assert(math.min(1, 2) == 1) @@ -249,6 +255,7 @@ assert(math.max(1, -1, 2) == 2) assert(math.noise(0.5) == 0) assert(math.noise(0.5, 0.5) == -0.25) assert(math.noise(0.5, 0.5, -0.5) == 0.125) +assert(math.noise(455.7204209769105, 340.80410508750134, 121.80087666537628) == 0.5010709762573242) local inf = math.huge * 2 local nan = 0 / 0 diff --git a/tests/conformance/strings.lua b/tests/conformance/strings.lua index 702b51f83..370641d92 100644 --- a/tests/conformance/strings.lua +++ b/tests/conformance/strings.lua @@ -107,6 +107,12 @@ assert(tostring(1234567890123) == '1234567890123') assert(#tostring('\0') == 1) assert(tostring(true) == "true") assert(tostring(false) == "false") + +function nothing() end + +assert(pcall(tostring) == false) +assert(pcall(function() return tostring(nothing()) end) == false) + print('+') x = '"ílo"\n\\' diff --git a/tests/main.cpp b/tests/main.cpp index 5395e7c60..cfa1f9db1 100644 --- a/tests/main.cpp +++ b/tests/main.cpp @@ -201,7 +201,7 @@ static FValueResult parseFFlag(std::string_view view) auto [name, value] = parseFValueHelper(view); bool state = value ? *value == "true" : true; if (value && value != "true" && value != "false") - std::cerr << "Ignored '" << name << "' because '" << *value << "' is not a valid FFlag state." << std::endl; + fprintf(stderr, "Ignored '%s' because '%s' is not a valid flag state\n", name.c_str(), value->c_str()); return {name, state}; } @@ -264,9 +264,7 @@ int main(int argc, char** argv) if (skipFastFlag(flag->name)) continue; - if (flag->dynamic) - std::cout << 'D'; - std::cout << "FFlag" << flag->name << std::endl; + printf("%sFFlag%s\n", flag->dynamic ? "D" : "", flag->name); } return 0; @@ -286,7 +284,7 @@ int main(int argc, char** argv) if (doctest::parseIntOption(argc, argv, "-O", doctest::option_int, level)) { if (level < 0 || level > 2) - std::cerr << "Optimization level must be between 0 and 2 inclusive." << std::endl; + fprintf(stderr, "Optimization level must be between 0 and 2 inclusive\n"); else optimizationLevel = level; } diff --git a/tools/faillist.txt b/tools/faillist.txt index 31c42eb87..3e2ee1851 100644 --- a/tools/faillist.txt +++ b/tools/faillist.txt @@ -1,4 +1,6 @@ AstQuery.last_argument_function_call_type +AutocompleteTest.anonymous_autofilled_generic_on_argument_type_pack_vararg +AutocompleteTest.anonymous_autofilled_generic_type_pack_vararg BuiltinTests.aliased_string_format BuiltinTests.assert_removes_falsy_types BuiltinTests.assert_removes_falsy_types2 @@ -55,7 +57,6 @@ ProvisionalTests.typeguard_inference_incomplete RefinementTest.discriminate_from_truthiness_of_x RefinementTest.not_t_or_some_prop_of_t RefinementTest.refine_a_property_of_some_global -RefinementTest.refinements_should_preserve_error_suppression RefinementTest.truthy_constraint_on_properties RefinementTest.type_narrow_to_vector RefinementTest.typeguard_cast_free_table_to_vector @@ -69,9 +70,6 @@ TableTests.dont_crash_when_setmetatable_does_not_produce_a_metatabletypevar TableTests.dont_hang_when_trying_to_look_up_in_cyclic_metatable_index TableTests.dont_suggest_exact_match_keys TableTests.error_detailed_metatable_prop -TableTests.expected_indexer_from_table_union -TableTests.expected_indexer_value_type_extra -TableTests.expected_indexer_value_type_extra_2 TableTests.explicitly_typed_table TableTests.explicitly_typed_table_with_indexer TableTests.fuzz_table_unify_instantiated_table @@ -128,6 +126,7 @@ TypeInfer.follow_on_new_types_in_substitution TypeInfer.fuzz_free_table_type_change_during_index_check TypeInfer.infer_assignment_value_types_mutable_lval TypeInfer.no_stack_overflow_from_isoptional +TypeInfer.recursive_function_that_invokes_itself_with_a_refinement_of_its_parameter_2 TypeInfer.tc_after_error_recovery_no_replacement_name_in_error TypeInfer.type_infer_recursion_limit_no_ice TypeInfer.type_infer_recursion_limit_normalizer