From 001ac195ae09e858b176361a161507c3e24e5732 Mon Sep 17 00:00:00 2001 From: Brent Yorgey Date: Fri, 9 Jun 2023 20:22:50 -0500 Subject: [PATCH] More typechecking error message improvements (#1308) Towards #1297. In this PR: - Moves some functions out of `Pipeline` and into more appropriate modules - Adds a new type `Join` to represent two types or other things, one "expected" and one "actual", at the point where they join and we check whether they are equal; a bunch of refactoring to use this new type. This way we don't have to remember which argument is which when calling a function that takes one "expected" thing and one "actual" thing. - Refactors all the `decomposeXXX` functions to take a `Syntax` node to use in error messages, and to take a `Sourced UType` (*i.e.* a `UType` along with whether it is "expected" or "actual") instead of just a `UType` so we can generate accurate error messages - Introduces a basic typechecking stack that keeps track of "what we are currently in the middle of checking", so that when we encounter an error we can say things like "while checking the definition of `foo`". Currently this is just a basic proof of concept. - A few other miscellaneous error message improvements. 
--- src/Swarm/Language/LSP.hs | 17 +- src/Swarm/Language/Pipeline.hs | 42 ++-- src/Swarm/Language/Pipeline/QQ.hs | 3 +- src/Swarm/Language/Pretty.hs | 86 ++++++-- src/Swarm/Language/Typecheck.hs | 277 ++++++++++++++++++-------- src/Swarm/Language/Typecheck/Unify.hs | 4 +- test/unit/TestLanguagePipeline.hs | 31 ++- 7 files changed, 319 insertions(+), 141 deletions(-) diff --git a/src/Swarm/Language/LSP.hs b/src/Swarm/Language/LSP.hs index 5abc5ae85..eeed96a4a 100644 --- a/src/Swarm/Language/LSP.hs +++ b/src/Swarm/Language/LSP.hs @@ -18,11 +18,14 @@ import Language.LSP.Server import Language.LSP.Types (Hover (Hover)) import Language.LSP.Types qualified as J import Language.LSP.Types.Lens qualified as J -import Language.LSP.VFS +import Language.LSP.VFS (VirtualFile (..), virtualFileText) import Swarm.Language.LSP.Hover qualified as H import Swarm.Language.LSP.VarUsage qualified as VU import Swarm.Language.Parse import Swarm.Language.Pipeline +import Swarm.Language.Pretty (prettyText) +import Swarm.Language.Syntax (SrcLoc (..)) +import Swarm.Language.Typecheck (ContextualTypeErr (..)) import System.IO (stderr) import Witch @@ -59,7 +62,7 @@ lspMain = diagnosticSourcePrefix :: Text diagnosticSourcePrefix = "swarm-lsp" -debug :: MonadIO m => Text -> m () +debug :: (MonadIO m) => Text -> m () debug msg = liftIO $ Text.hPutStrLn stderr $ "[swarm-lsp] " <> msg validateSwarmCode :: J.NormalizedUri -> J.TextDocumentVersion -> Text -> LspM () () @@ -125,6 +128,16 @@ validateSwarmCode doc version content = do Nothing -- tags (Just (J.List [])) +showTypeErrorPos :: Text -> ContextualTypeErr -> ((Int, Int), (Int, Int), Text) +showTypeErrorPos code (CTE l _ te) = (minusOne start, minusOne end, msg) + where + minusOne (x, y) = (x - 1, y - 1) + + (start, end) = case l of + SrcLoc s e -> getLocRange code (s, e) + NoLoc -> ((1, 1), (65535, 65535)) -- unknown loc spans the whole document + msg = prettyText te + handlers :: Handlers (LspM ()) handlers = mconcat diff --git 
a/src/Swarm/Language/Pipeline.hs b/src/Swarm/Language/Pipeline.hs index 8c66b73f1..ff225222b 100644 --- a/src/Swarm/Language/Pipeline.hs +++ b/src/Swarm/Language/Pipeline.hs @@ -14,8 +14,6 @@ module Swarm.Language.Pipeline ( processParsedTerm, processTerm', processParsedTerm', - prettyTypeErr, - showTypeErrorPos, processTermEither, ) where @@ -37,15 +35,17 @@ import Swarm.Language.Types import Witch -- | A record containing the results of the language processing --- pipeline. Put a 'Term' in, and get one of these out. -data ProcessedTerm - = ProcessedTerm - TModule - -- ^ The elaborated + type-annotated term, plus types of any embedded definitions - Requirements - -- ^ Requirements of the term - ReqCtx - -- ^ Capability context for any definitions embedded in the term +-- pipeline. Put a 'Term' in, and get one of these out. A +-- 'ProcessedTerm' contains: +-- +-- * The elaborated + type-annotated term, plus the types of any +-- embedded definitions ('TModule') +-- +-- * The 'Requirements' of the term +-- +-- * The requirements context for any definitions embedded in the +-- term ('ReqCtx') +data ProcessedTerm = ProcessedTerm TModule Requirements ReqCtx deriving (Data, Show, Eq, Generic) processTermEither :: Text -> Either String ProcessedTerm @@ -80,25 +80,7 @@ processParsedTerm = processParsedTerm' empty empty processTerm' :: TCtx -> ReqCtx -> Text -> Either Text (Maybe ProcessedTerm) processTerm' ctx capCtx txt = do mt <- readTerm txt - first (prettyTypeErr txt) $ traverse (processParsedTerm' ctx capCtx) mt - -prettyTypeErr :: Text -> ContextualTypeErr -> Text -prettyTypeErr code (CTE l te) = teLoc <> prettyText te - where - teLoc = case l of - SrcLoc s e -> (into @Text . showLoc . 
fst $ getLocRange code (s, e)) <> ": " - NoLoc -> "" - showLoc (r, c) = show r ++ ":" ++ show c - -showTypeErrorPos :: Text -> ContextualTypeErr -> ((Int, Int), (Int, Int), Text) -showTypeErrorPos code (CTE l te) = (minusOne start, minusOne end, msg) - where - minusOne (x, y) = (x - 1, y - 1) - - (start, end) = case l of - SrcLoc s e -> getLocRange code (s, e) - NoLoc -> ((1, 1), (65535, 65535)) -- unknown loc spans the whole document - msg = prettyText te + first (prettyTypeErrText txt) $ traverse (processParsedTerm' ctx capCtx) mt -- | Like 'processTerm'', but use a term that has already been parsed. processParsedTerm' :: TCtx -> ReqCtx -> Syntax -> Either ContextualTypeErr ProcessedTerm diff --git a/src/Swarm/Language/Pipeline/QQ.hs b/src/Swarm/Language/Pipeline/QQ.hs index 82dbd4e6c..5bf0aaee7 100644 --- a/src/Swarm/Language/Pipeline/QQ.hs +++ b/src/Swarm/Language/Pipeline/QQ.hs @@ -9,6 +9,7 @@ import Language.Haskell.TH qualified as TH import Language.Haskell.TH.Quote import Swarm.Language.Parse import Swarm.Language.Pipeline +import Swarm.Language.Pretty import Swarm.Language.Syntax import Swarm.Language.Types (Polytype) import Swarm.Util (failT, liftText) @@ -41,7 +42,7 @@ quoteTermExp s = do ) parsed <- runParserTH pos parseTerm s case processParsedTerm parsed of - Left err -> failT [prettyTypeErr (from s) err] + Left err -> failT [prettyTypeErrText (from s) err] Right ptm -> dataToExpQ ((fmap liftText . 
cast) `extQ` antiTermExp) ptm antiTermExp :: Term' Polytype -> Maybe TH.ExpQ diff --git a/src/Swarm/Language/Pretty.hs b/src/Swarm/Language/Pretty.hs index a1048b02f..286e6fcef 100644 --- a/src/Swarm/Language/Pretty.hs +++ b/src/Swarm/Language/Pretty.hs @@ -25,27 +25,39 @@ import Prettyprinter.Render.String qualified as RS import Prettyprinter.Render.Text qualified as RT import Swarm.Language.Capability import Swarm.Language.Context +import Swarm.Language.Parse (getLocRange) import Swarm.Language.Syntax import Swarm.Language.Typecheck import Swarm.Language.Types import Witch +------------------------------------------------------------ +-- PrettyPrec class + utilities + -- | Type class for things that can be pretty-printed, given a -- precedence level of their context. class PrettyPrec a where prettyPrec :: Int -> a -> Doc ann -- can replace with custom ann type later if desired -- | Pretty-print a thing, with a context precedence level of zero. -ppr :: PrettyPrec a => a -> Doc ann +ppr :: (PrettyPrec a) => a -> Doc ann ppr = prettyPrec 0 +-- | Render a pretty-printed document as @Text@. +docToText :: Doc a -> Text +docToText = RT.renderStrict . layoutPretty defaultLayoutOptions + -- | Pretty-print something and render it as @Text@. -prettyText :: PrettyPrec a => a -> Text -prettyText = RT.renderStrict . layoutPretty defaultLayoutOptions . ppr +prettyText :: (PrettyPrec a) => a -> Text +prettyText = docToText . ppr + +-- | Render a pretty-printed document as a @String@. +docToString :: Doc a -> String +docToString = RS.renderString . layoutPretty defaultLayoutOptions -- | Pretty-print something and render it as a @String@. -prettyString :: PrettyPrec a => a -> String -prettyString = RS.renderString . layoutPretty defaultLayoutOptions . ppr +prettyString :: (PrettyPrec a) => a -> String +prettyString = docToString . ppr -- | Optionally surround a document with parentheses depending on the -- @Bool@ argument. 
@@ -53,9 +65,25 @@ pparens :: Bool -> Doc ann -> Doc ann pparens True = parens pparens False = id +-- | Surround a document with backticks. bquote :: Doc ann -> Doc ann bquote d = "`" <> d <> "`" +-------------------------------------------------- +-- Bullet lists + +data BulletList i = BulletList + { bulletListHeader :: forall a. Doc a + , bulletListItems :: [i] + } + +instance (PrettyPrec i) => PrettyPrec (BulletList i) where + prettyPrec _ (BulletList hdr items) = + nest 2 . vcat $ hdr : map (("-" <+>) . ppr) items + +------------------------------------------------------------ +-- PrettyPrec instances for terms, types, etc. + instance PrettyPrec Text where prettyPrec _ = pretty @@ -72,14 +100,14 @@ instance PrettyPrec BaseTy where instance PrettyPrec IntVar where prettyPrec _ = pretty . mkVarName "u" -instance PrettyPrec (t (Fix t)) => PrettyPrec (Fix t) where +instance (PrettyPrec (t (Fix t))) => PrettyPrec (Fix t) where prettyPrec p = prettyPrec p . unFix instance (PrettyPrec (t (UTerm t v)), PrettyPrec v) => PrettyPrec (UTerm t v) where prettyPrec p (UTerm t) = prettyPrec p t prettyPrec p (UVar v) = prettyPrec p v -instance PrettyPrec t => PrettyPrec (TypeF t) where +instance (PrettyPrec t) => PrettyPrec (TypeF t) where prettyPrec _ (TyBaseF b) = ppr b prettyPrec _ (TyVarF v) = pretty v prettyPrec p (TySumF ty1 ty2) = @@ -103,7 +131,7 @@ instance PrettyPrec UPolytype where prettyPrec _ (Forall [] t) = ppr t prettyPrec _ (Forall xs t) = hsep ("∀" : map pretty xs) <> "." <+> ppr t -instance PrettyPrec t => PrettyPrec (Ctx t) where +instance (PrettyPrec t) => PrettyPrec (Ctx t) where prettyPrec _ Empty = emptyDoc prettyPrec _ (assocs -> bs) = brackets (hsep (punctuate "," (map prettyBinding bs))) @@ -206,20 +234,41 @@ appliedTermPrec (TApp f _) = case f of _ -> appliedTermPrec f appliedTermPrec _ = 10 +------------------------------------------------------------ +-- Error messages + +-- | Format a 'ContextualTypeError' for the user and render it as +-- @Text@. 
+prettyTypeErrText :: Text -> ContextualTypeErr -> Text +prettyTypeErrText code = docToText . prettyTypeErr code + +-- | Format a 'ContextualTypeError' for the user. +prettyTypeErr :: Text -> ContextualTypeErr -> Doc ann +prettyTypeErr code (CTE l tcStack te) = + vcat + [ teLoc <> ppr te + , ppr (BulletList "" tcStack) + ] + where + teLoc = case l of + SrcLoc s e -> (showLoc . fst $ getLocRange code (s, e)) <> ": " + NoLoc -> emptyDoc + showLoc (r, c) = pretty r <> ":" <> pretty c + instance PrettyPrec TypeErr where prettyPrec _ (UnifyErr ty1 ty2) = "Can't unify" <+> ppr ty1 <+> "and" <+> ppr ty2 - prettyPrec _ (Mismatch Nothing ty1 ty2) = + prettyPrec _ (Mismatch Nothing (getJoin -> (ty1, ty2))) = "Type mismatch: expected" <+> ppr ty1 <> ", but got" <+> ppr ty2 - prettyPrec _ (Mismatch (Just t) ty1 ty2) = + prettyPrec _ (Mismatch (Just t) (getJoin -> (ty1, ty2))) = nest 2 . vcat $ [ "Type mismatch:" , "From context, expected" <+> bquote (ppr t) <+> "to have type" <+> bquote (ppr ty1) <> "," , "but it actually has type" <+> bquote (ppr ty2) ] - prettyPrec _ (LambdaArgMismatch ty1 ty2) = + prettyPrec _ (LambdaArgMismatch (getJoin -> (ty1, ty2))) = "Lambda argument has type annotation" <+> ppr ty2 <> ", but expected argument type" <+> ppr ty1 - prettyPrec _ (FieldsMismatch expFs actFs) = fieldMismatchMsg expFs actFs + prettyPrec _ (FieldsMismatch (getJoin -> (expFs, actFs))) = fieldMismatchMsg expFs actFs prettyPrec _ (EscapedSkolem x) = "Skolem variable" <+> pretty x <+> "would escape its scope" prettyPrec _ (UnboundVar x) = @@ -255,11 +304,10 @@ instance PrettyPrec InvalidAtomicReason where prettyPrec _ NestedAtomic = "nested atomic block" prettyPrec _ LongConst = "commands that can take multiple ticks to execute are not allowed" -data BulletList i = BulletList - { bulletListHeader :: forall a. 
Doc a - , bulletListItems :: [i] - } +instance PrettyPrec LocatedTCFrame where + prettyPrec p (LocatedTCFrame _ f) = prettyPrec p f -instance PrettyPrec i => PrettyPrec (BulletList i) where - prettyPrec _ (BulletList hdr items) = - nest 2 . vcat $ hdr : map (("-" <+>) . ppr) items +instance PrettyPrec TCFrame where + prettyPrec _ (TCDef x) = "While checking the definition of" <+> pretty x + prettyPrec _ TCBindL = "While checking the left-hand side of a semicolon" + prettyPrec _ TCBindR = "While checking the right-hand side of a semicolon" diff --git a/src/Swarm/Language/Typecheck.hs b/src/Swarm/Language/Typecheck.hs index 590e46e3a..cda8f4783 100644 --- a/src/Swarm/Language/Typecheck.hs +++ b/src/Swarm/Language/Typecheck.hs @@ -1,6 +1,7 @@ {-# LANGUAGE InstanceSigs #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE QuasiQuotes #-} +{-# LANGUAGE ViewPatterns #-} {-# OPTIONS_GHC -fno-warn-orphans #-} -- For 'Ord IntVar' instance @@ -17,16 +18,26 @@ module Swarm.Language.Typecheck ( TypeErr (..), InvalidAtomicReason (..), - -- * Inference monad + -- * Type provenance + Source (..), + Join, + getJoin, + + -- * Typechecking stack + TCFrame (..), + LocatedTCFrame (..), + TCStack, + withFrame, + getTCStack, + + -- * Typechecking monad TC, runTC, - lookup, fresh, -- * Unification substU, - expect, - (=:=), + unify, HasBindings (..), instantiate, skolemize, @@ -38,8 +49,6 @@ module Swarm.Language.Typecheck ( infer, inferConst, check, - decomposeCmdTy, - decomposeFunTy, isSimpleUType, ) where @@ -49,7 +58,7 @@ import Control.Lens ((^.)) import Control.Lens.Indexed (itraverse) import Control.Monad.Except import Control.Monad.Reader -import Control.Unification hiding (applyBindings, (=:=)) +import Control.Unification hiding (applyBindings, unify, (=:=)) import Control.Unification qualified as U import Control.Unification.IntVar import Data.Data (Data, gmapM) @@ -71,6 +80,63 @@ import Swarm.Language.Typecheck.Unify import Swarm.Language.Types import Prelude hiding (lookup) 
+------------------------------------------------------------ +-- Typechecking stack + +-- | A frame to keep track of something we were in the middle of doing +-- during typechecking. +data TCFrame where + -- | Checking a definition. + TCDef :: Var -> TCFrame + -- | Inferring the LHS of a bind. + TCBindL :: TCFrame + -- | Inferring the RHS of a bind. + TCBindR :: TCFrame + deriving (Show) + +-- | A typechecking stack frame together with the relevant @SrcLoc@. +data LocatedTCFrame = LocatedTCFrame SrcLoc TCFrame + deriving (Show) + +-- | A typechecking stack keeps track of what we are currently in the +-- middle of doing during typechecking. +type TCStack = [LocatedTCFrame] + +------------------------------------------------------------ +-- Type source + +-- | The source of a type during typechecking. +data Source + = -- | An expected type that was "pushed down" from the context. + Expected + | -- | An actual/inferred type that was "pulled up" from a term. + Actual + deriving (Show, Eq, Ord, Bounded, Enum) + +-- | A value along with its source (expected vs actual). +type Sourced a = (Source, a) + +-- | A "join" where an expected thing meets an actual thing. +newtype Join a = Join (Source -> a) + +instance (Show a) => Show (Join a) where + show (getJoin -> (e, a)) = "(expected: " <> show e <> ", actual: " <> show a <> ")" + +type TypeJoin = Join UType + +-- | Create a 'Join' from an expected thing and an actual thing (in that order). +joined :: a -> a -> Join a +joined expect actual = Join (\case Expected -> expect; Actual -> actual) + +-- | Create a 'Join' from a 'Sourced' thing together with another +-- thing (which is assumed to have the opposite 'Source'). +mkJoin :: Sourced a -> a -> Join a +mkJoin (src, a1) a2 = Join $ \s -> if s == src then a1 else a2 + +-- | Convert a 'Join' into a pair of (expected, actual). 
+getJoin :: Join a -> (a, a) +getJoin (Join j) = (j Expected, j Actual) + ------------------------------------------------------------ -- Type checking monad @@ -78,7 +144,18 @@ import Prelude hiding (lookup) -- monad transformer provided by the @unification-fd@ library which -- supports various operations such as generating fresh variables -- and unifying things. -type TC = ReaderT UCtx (ExceptT ContextualTypeErr (IntBindingT TypeF Identity)) +type TC = ReaderT UCtx (ReaderT TCStack (ExceptT ContextualTypeErr (IntBindingT TypeF Identity))) + +-- | Push a frame on the typechecking stack within a local 'TC' +-- computation. +withFrame :: SrcLoc -> TCFrame -> TC a -> TC a +withFrame l f = mapReaderT (local (LocatedTCFrame l f :)) + +-- | Get the current typechecking stack. +getTCStack :: TC TCStack +getTCStack = lift ask + +------------------------------------------------------------ -- | Run a top-level inference computation, returning either a -- 'TypeErr' or a fully resolved 'TModule'. @@ -92,6 +169,7 @@ runTC ctx = <*> pure (fromU uctx) ) >>> flip runReaderT (toU ctx) + >>> flip runReaderT [] >>> runExceptT >>> evalIntBindingT >>> runIdentity @@ -102,15 +180,19 @@ runTC ctx = -- 'instantiate'. lookup :: SrcLoc -> Var -> TC UType lookup loc x = do - ctx <- ask + ctx <- getCtx maybe (throwTypeErr loc $ UnboundVar x) instantiate (Ctx.lookup x ctx) +-- | Get the current type context. +getCtx :: TC UCtx +getCtx = ask + -- | Catch any thrown type errors and re-throw them with an added source -- location. addLocToTypeErr :: SrcLoc -> TC a -> TC a addLocToTypeErr l m = m `catchError` \case - CTE NoLoc te -> throwTypeErr l te + CTE NoLoc _ te -> throwTypeErr l te te -> throwError te ------------------------------------------------------------ @@ -128,10 +210,10 @@ class FreeVars a where -- | We can get the free unification variables of a 'UType'. instance FreeVars UType where - freeVars ut = fmap S.fromList . lift . lift $ getFreeVars ut + freeVars ut = fmap S.fromList . 
lift . lift . lift $ getFreeVars ut -- | We can also get the free variables of a polytype. -instance FreeVars t => FreeVars (Poly t) where +instance (FreeVars t) => FreeVars (Poly t) where freeVars (Forall _ t) = freeVars t -- | We can get the free variables in any polytype in a context. @@ -140,7 +222,7 @@ instance FreeVars UCtx where -- | Generate a fresh unification variable. fresh :: TC UType -fresh = UVar <$> lift (lift freeVar) +fresh = UVar <$> (lift . lift . lift $ freeVar) -- | Perform a substitution over a 'UType', substituting for both type -- and unification variables. Note that since 'UType's do not have @@ -182,25 +264,24 @@ noSkolems l (Forall xs upty) = do infix 4 =:= --- | @expect t expTy actTy@ expects that term @t@ has type @expTy@, --- where it actually has type @actTy@. Ensure those types are the --- same. -expect :: Maybe Syntax -> UType -> UType -> TC UType -expect ms expected actual = case unifyCheck expected actual of - Apart -> throwTypeErr NoLoc $ Mismatch ms expected actual +-- | @unify t expTy actTy@ ensures that the given two types are equal. +-- If we know the actual term @t@ which is supposed to have these +-- types, we can use it to generate better error messages. +-- +-- We first do a quick-and-dirty check to see whether we know for +-- sure the types either are or cannot be equal, generating an +-- equality constraint for the unifier as a last resort. +unify :: Maybe Syntax -> TypeJoin -> TC UType +unify ms j = case unifyCheck expected actual of + Apart -> throwTypeErr NoLoc $ Mismatch ms j Equal -> return expected - MightUnify -> lift $ expected U.=:= actual + MightUnify -> lift . lift $ expected U.=:= actual + where + (expected, actual) = getJoin j --- | Constrain two types to be equal, first with a quick-and-dirty --- check to see whether we know for sure they either are or cannot --- be equal, generating an equality constraint for the unifier as a --- last resort. 
--- --- Important: the first given type should be the "expected" type --- from context, and the second should be the "actual" type (in any --- situation when this distinction makes sense). +-- | Ensure two types are the same. (=:=) :: UType -> UType -> TC UType -(=:=) = expect Nothing +ty1 =:= ty2 = unify Nothing (joined ty1 ty2) -- | @unification-fd@ provides a function 'U.applyBindings' which -- fully substitutes for any bound unification variables (for @@ -212,7 +293,7 @@ class HasBindings u where applyBindings :: u -> TC u instance HasBindings UType where - applyBindings = lift . U.applyBindings + applyBindings = lift . lift . U.applyBindings instance HasBindings UPolytype where applyBindings (Forall xs u) = Forall xs <$> applyBindings u @@ -261,7 +342,7 @@ skolemize (Forall xs uty) = do generalize :: UType -> TC UPolytype generalize uty = do uty' <- applyBindings uty - ctx <- ask + ctx <- getCtx tmfvs <- freeVars uty' ctxfvs <- freeVars ctx let fvs = S.toList $ tmfvs \\ ctxfvs @@ -279,24 +360,23 @@ generalize uty = do -- Type errors -- | A type error along with various contextual information to help us --- generate better error messages. For now there is only a SrcLoc --- but there will be additional context in the future, such as a --- stack of stuff we were in the middle of doing, relevant names in --- scope, etc. (#1297). -data ContextualTypeErr = CTE {cteSrcLoc :: SrcLoc, cteTypeErr :: TypeErr} +-- generate better error messages. +data ContextualTypeErr = CTE {cteSrcLoc :: SrcLoc, cteStack :: TCStack, cteTypeErr :: TypeErr} deriving (Show) -- | Create a raw 'ContextualTypeErr' with no context information. mkRawTypeErr :: TypeErr -> ContextualTypeErr -mkRawTypeErr = CTE NoLoc +mkRawTypeErr = CTE NoLoc [] -- | Create a 'ContextualTypeErr' value from a 'TypeErr' and context. -mkTypeErr :: SrcLoc -> TypeErr -> ContextualTypeErr +mkTypeErr :: SrcLoc -> TCStack -> TypeErr -> ContextualTypeErr mkTypeErr = CTE -- | Throw a 'ContextualTypeErr'. 
throwTypeErr :: SrcLoc -> TypeErr -> TC a -throwTypeErr l te = throwError (mkTypeErr l te) +throwTypeErr l te = do + stk <- getTCStack + throwError $ mkTypeErr l stk te -- | Errors that can occur during type checking. The idea is that -- each error carries information that can be used to help explain @@ -315,13 +395,13 @@ data TypeErr | -- | Type mismatch caught by 'unifyCheck'. The given term was -- expected to have a certain type, but has a different type -- instead. - Mismatch (Maybe Syntax) UType UType -- expected, actual + Mismatch (Maybe Syntax) TypeJoin | -- | Lambda argument type mismatch. - LambdaArgMismatch UType UType -- expected, actual + LambdaArgMismatch TypeJoin | -- | Record field mismatch, i.e. based on the expected type we -- were expecting a record with certain fields, but found one with -- a different field set. - FieldsMismatch (Set Var) (Set Var) -- expected fields, actual + FieldsMismatch (Join (Set Var)) | -- | A definition was encountered not at the top level. DefNotTopLevel Term | -- | A term was encountered which we cannot infer the type of. @@ -357,38 +437,46 @@ instance Fallible TypeF IntVar ContextualTypeErr where ------------------------------------------------------------ -- Type decomposition --- | Decompose a type that is supposed to be a delay type. -decomposeDelayTy :: UType -> TC UType -decomposeDelayTy (UTyDelay a) = return a -decomposeDelayTy ty = do +-- | Decompose a type that is supposed to be a delay type. Also take +-- the term which is supposed to have that type, for use in error +-- messages. +decomposeDelayTy :: Syntax -> Sourced UType -> TC UType +decomposeDelayTy _ (_, UTyDelay a) = return a +decomposeDelayTy t ty = do a <- fresh - _ <- UTyDelay a =:= ty + _ <- unify (Just t) (mkJoin ty (UTyDelay a)) return a --- | Decompose a type that is supposed to be a command type. 
-decomposeCmdTy :: UType -> TC UType -decomposeCmdTy (UTyCmd a) = return a -decomposeCmdTy ty = do +-- | Decompose a type that is supposed to be a command type. Also take +-- the term which is supposed to have that type, for use in error +-- messages. +decomposeCmdTy :: Syntax -> Sourced UType -> TC UType +decomposeCmdTy _ (_, UTyCmd a) = return a +decomposeCmdTy t ty = do a <- fresh - _ <- UTyCmd a =:= ty + _ <- unify (Just t) (mkJoin ty (UTyCmd a)) return a --- | Decompose a type that is supposed to be a function type. -decomposeFunTy :: UType -> TC (UType, UType) -decomposeFunTy (UTyFun ty1 ty2) = return (ty1, ty2) -decomposeFunTy ty = do +-- | Decompose a type that is supposed to be a function type. Also take +-- the term which is supposed to have that type, for use in error +-- messages. +decomposeFunTy :: Syntax -> Sourced UType -> TC (UType, UType) +decomposeFunTy _ (_, UTyFun ty1 ty2) = return (ty1, ty2) +decomposeFunTy t ty = do ty1 <- fresh ty2 <- fresh - _ <- UTyFun ty1 ty2 =:= ty + _ <- unify (Just t) (mkJoin ty (UTyFun ty1 ty2)) return (ty1, ty2) --- | Decompose a type that is supposed to be a product type. -decomposeProdTy :: UType -> TC (UType, UType) -decomposeProdTy (UTyProd ty1 ty2) = return (ty1, ty2) -decomposeProdTy ty = do +-- | Decompose a type that is supposed to be a product type. Also take +-- the term which is supposed to have that type, for use in error +-- messages. +decomposeProdTy :: Syntax -> Sourced UType -> TC (UType, UType) +decomposeProdTy _ (_, UTyProd ty1 ty2) = return (ty1, ty2) +decomposeProdTy t ty = do ty1 <- fresh ty2 <- fresh - _ <- UTyProd ty1 ty2 =:= ty + _ <- unify (Just t) (mkJoin ty (UTyProd ty1 ty2)) return (ty1, ty2) ------------------------------------------------------------ @@ -408,16 +496,16 @@ inferModule s@(Syntax l t) = addLocToTypeErr l $ case t of -- variable for the body, infer the body under an extended context, -- and unify the two. Then generalize the type and return an -- appropriate context. 
- SDef r x Nothing t1 -> do + SDef r x Nothing t1 -> withFrame l (TCDef (lvVar x)) $ do xTy <- fresh t1' <- withBinding (lvVar x) (Forall [] xTy) $ infer t1 - _ <- xTy =:= t1' ^. sType + _ <- unify (Just t1) (joined xTy (t1' ^. sType)) pty <- generalize (t1' ^. sType) return $ Module (Syntax' l (SDef r x Nothing t1') (UTyCmd UTyUnit)) (singleton (lvVar x) pty) -- If a (poly)type signature has been provided, skolemize it and -- check the definition. - SDef r x (Just pty) t1 -> do + SDef r x (Just pty) t1 -> withFrame l (TCDef (lvVar x)) $ do let upty = toU pty uty <- skolemize upty t1' <- withBinding (lvVar x) upty $ check t1 uty @@ -428,8 +516,8 @@ inferModule s@(Syntax l t) = addLocToTypeErr l $ case t of -- correct context when checking the right-hand side in particular. SBind mx c1 c2 -> do -- First, infer the left side. - Module c1' ctx1 <- inferModule c1 - a <- decomposeCmdTy (c1' ^. sType) + Module c1' ctx1 <- withFrame l TCBindL $ inferModule c1 + a <- decomposeCmdTy c1 (Actual, c1' ^. sType) -- Now infer the right side under an extended context: things in -- scope on the right-hand side include both any definitions @@ -440,13 +528,13 @@ inferModule s@(Syntax l t) = addLocToTypeErr l $ case t of -- that binding /after/ (i.e. /within/) the application of @ctx1@. withBindings ctx1 $ maybe id ((`withBinding` Forall [] a) . lvVar) mx $ do - Module c2' ctx2 <- inferModule c2 + Module c2' ctx2 <- withFrame l TCBindR $ inferModule c2 -- We don't actually need the result type since we're just -- going to return the entire type, but it's important to -- ensure it's a command type anyway. Otherwise something -- like 'move; 3' would be accepted with type int. - _ <- decomposeCmdTy (c2' ^. sType) + _ <- decomposeCmdTy c2 (Actual, c2' ^. sType) -- Ctx.union is right-biased, so ctx1 `union` ctx2 means later -- definitions will shadow previous ones. 
Include the binder @@ -527,19 +615,38 @@ infer s@(Syntax l t) = addLocToTypeErr l $ case t of SApp f x -> do -- Infer the type of the left-hand side and make sure it has a function type. f' <- infer f - (argTy, resTy) <- decomposeFunTy (f' ^. sType) + (argTy, resTy) <- decomposeFunTy f (Actual, f' ^. sType) -- Then check that the argument has the right type. x' <- check x argTy - return $ Syntax' l (SApp f' x') resTy + + -- Call applyBindings explicitly, so that anything we learned + -- about unification variables while checking the type of the + -- argument can flow to later steps. This is especially helpful + -- while checking applications of polymorphic multi-argument + -- functions such as 'if'. Without this call to 'applyBindings', + -- type mismatches between the branches of an 'if' tend to get + -- caught in the unifier, resulting in vague "can't unify" + -- messages (for example, "if true {3} {move}" yields "can't + -- unify int and cmd unit"). With this 'applyBindings' call, we + -- get more specific errors about how the second branch was + -- expected to have the same type as the first (e.g. "expected + -- `move` to have type `int`, but it actually has type `cmd + -- unit`). + resTy' <- applyBindings resTy + + return $ Syntax' l (SApp f' x') resTy' -- We handle binds in inference mode for a similar reason to -- application. SBind mx c1 c2 -> do - c1' <- infer c1 - a <- decomposeCmdTy (c1' ^. sType) - c2' <- maybe id ((`withBinding` Forall [] a) . lvVar) mx $ infer c2 - _ <- decomposeCmdTy (c2' ^. sType) + c1' <- withFrame l TCBindL $ infer c1 + a <- decomposeCmdTy c1 (Actual, c1' ^. sType) + c2' <- + maybe id ((`withBinding` Forall [] a) . lvVar) mx + . withFrame l TCBindR + $ infer c2 + _ <- decomposeCmdTy c2 (Actual, c2' ^. sType) return $ Syntax' l (SBind mx c1' c2') (c2' ^. sType) -- Handle record projection in inference mode. 
Knowing the expected @@ -567,7 +674,7 @@ infer s@(Syntax l t) = addLocToTypeErr l $ case t of uty <- skolemize upty _ <- check c uty -- Make sure no skolem variables have escaped. - ask >>= mapM_ (noSkolems l) + getCtx >>= mapM_ (noSkolems l) -- If check against skolemized polytype is successful, -- instantiate polytype with unification variables. -- Free variables should be able to unify with anything in @@ -703,28 +810,28 @@ check s@(Syntax l t) expected = addLocToTypeErr l $ case t of -- dynamically at runtime when evaluating recursive let or def expressions, -- so we don't have to worry about typechecking them here. SDelay d s1 -> do - ty1 <- decomposeDelayTy expected + ty1 <- decomposeDelayTy s (Expected, expected) s1' <- check s1 ty1 return $ Syntax' l (SDelay d s1') (UTyDelay ty1) -- To check the type of a pair, make sure the expected type is a -- product type, and push the two types down into the left and right. SPair s1 s2 -> do - (ty1, ty2) <- decomposeProdTy expected + (ty1, ty2) <- decomposeProdTy s (Expected, expected) s1' <- check s1 ty1 s2' <- check s2 ty2 return $ Syntax' l (SPair s1' s2') (UTyProd ty1 ty2) -- To check a lambda, make sure the expected type is a function type. SLam x mxTy body -> do - (argTy, resTy) <- decomposeFunTy expected + (argTy, resTy) <- decomposeFunTy s (Expected, expected) case toU mxTy of Just xTy -> case unifyCheck argTy xTy of -- Generate a special error when the explicit type annotation -- on a lambda doesn't match the expected type, -- e.g. (\x:int. x + 2) : text -> int, since the usual -- "expected/but got" language would probably be confusing. - Apart -> throwTypeErr l $ LambdaArgMismatch argTy xTy + Apart -> throwTypeErr l $ LambdaArgMismatch (joined argTy xTy) -- Otherwise, make sure to unify the annotation with the -- expected argument type. 
_ -> void $ argTy =:= xTy @@ -738,7 +845,7 @@ check s@(Syntax l t) expected = addLocToTypeErr l $ case t of TConst c :$: at | c `elem` [Atomic, Instant] -> do - argTy <- decomposeCmdTy expected + argTy <- decomposeCmdTy s (Expected, expected) at' <- check at (UTyCmd argTy) atomic' <- infer (Syntax l (TConst c)) -- It's important that we typecheck the subterm @at@ *before* we @@ -777,7 +884,7 @@ check s@(Syntax l t) expected = addLocToTypeErr l $ case t of t2' <- withBinding (lvVar x) upty $ check t2 expected -- Make sure no skolem variables have escaped. - ask >>= mapM_ (noSkolems l) + getCtx >>= mapM_ (noSkolems l) -- Return the annotated let. return $ Syntax' l (SLet r x mxTy t1' t2') expected @@ -800,7 +907,7 @@ check s@(Syntax l t) expected = addLocToTypeErr l $ case t of actualFields = M.keysSet fields when (actualFields /= expectedFields) $ throwTypeErr l $ - FieldsMismatch expectedFields actualFields + FieldsMismatch (joined expectedFields actualFields) m' <- itraverse (\x ms -> check (fromMaybe (STerm (TVar x)) ms) (tyMap ! x)) fields return $ Syntax' l (SRcd (Just <$> m')) expected @@ -808,7 +915,7 @@ check s@(Syntax l t) expected = addLocToTypeErr l $ case t of -- get is what we expected. 
_ -> do Syntax' l' t' actual <- infer s - Syntax' l' t' <$> expect (Just s) expected actual + Syntax' l' t' <$> unify (Just s) (joined expected actual) -- ~~~~ Note [Checking and inference for record literals] -- @@ -908,7 +1015,7 @@ analyzeAtomic locals (Syntax l t) = case t of TVar x | x `S.member` locals -> return 0 | otherwise -> do - mxTy <- asks $ Ctx.lookup x + mxTy <- Ctx.lookup x <$> getCtx case mxTy of -- If the variable is undefined, return 0 to indicate the -- atomic block is valid, because we'd rather have the error diff --git a/src/Swarm/Language/Typecheck/Unify.hs b/src/Swarm/Language/Typecheck/Unify.hs index 4b1c86c3c..c68308ebf 100644 --- a/src/Swarm/Language/Typecheck/Unify.hs +++ b/src/Swarm/Language/Typecheck/Unify.hs @@ -52,8 +52,8 @@ instance Monoid UnifyStatus where -- 'Equal'. In case (1), we can generate a much better error -- message at the instant the two types come together than we could -- if we threw a constraint into the unifier. In case (2), we don't --- have to bother with generating a constraint. If we don't know for --- sure whether they will unify, return 'MightUnify'. +-- have to bother with generating a trivial constraint. If we don't +-- know for sure whether they will unify, return 'MightUnify'. 
unifyCheck :: UType -> UType -> UnifyStatus unifyCheck ty1 ty2 = case (ty1, ty2) of (UVar x, UVar y) diff --git a/test/unit/TestLanguagePipeline.hs b/test/unit/TestLanguagePipeline.hs index b73cf9b18..cec8624fe 100644 --- a/test/unit/TestLanguagePipeline.hs +++ b/test/unit/TestLanguagePipeline.hs @@ -100,13 +100,13 @@ testLanguagePipeline = "failure inside bind chain" ( process "move;\n1;\nmove" - "2:1: Type mismatch: expected cmd u0, but got int" + "2:1: Type mismatch:\n From context, expected `1` to have type `cmd u0`,\n but it actually has type `int`" ) , testCase "failure inside function call" ( process "if true \n{} \n(move)" - "3:1: Type mismatch:\n From context, expected `move` to have type `{u0}`,\n but it actually has type `cmd unit`" + "3:1: Type mismatch:\n From context, expected `move` to have type `{cmd unit}`,\n but it actually has type `cmd unit`" ) , testCase "parsing operators #236 - report failure on invalid operator start" @@ -336,6 +336,33 @@ testLanguagePipeline = "1:1: Lambda argument has type annotation int, but expected argument type text" ) ] + , testGroup + "typechecking errors" + [ testCase + "applying a pair" + ( process + "(1,2) \"hi\"" + "1:1: Type mismatch:\n From context, expected `(1, 2)` to have type `u3 -> u4`,\n but it actually has type `u1 * u2`" + ) + , testCase + "providing a pair as an argument" + ( process + "(\\x:int. x + 1) (1,2)" + "1:17: Type mismatch:\n From context, expected `(1, 2)` to have type `int`,\n but it actually has type `u0 * u1`" + ) + , testCase + "mismatched if branches" + ( process + "if true {grab} {}" + "1:16: Type mismatch:\n From context, expected `noop` to have type `cmd text`,\n but it actually has type `cmd unit`" + ) + , testCase + "definition with wrong result" + ( process + "def m : int -> int -> int = \\x. \\y. {3} end" + "1:37: Type mismatch:\n From context, expected `{3}` to have type `int`,\n but it actually has type `{u0}`" + ) + ] ] where valid = flip process ""