From d54ddb49fcf318846abeda986e4b4e2f2515d0f8 Mon Sep 17 00:00:00 2001 From: Brent Yorgey Date: Sat, 22 Jun 2024 17:30:24 -0500 Subject: [PATCH] apply bindings before returning unification mismatch error --- src/swarm-lang/Swarm/Language/Typecheck.hs | 17 ++++++++++++++++- test/unit/TestLanguagePipeline.hs | 10 ++++++++-- 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/src/swarm-lang/Swarm/Language/Typecheck.hs b/src/swarm-lang/Swarm/Language/Typecheck.hs index 260bccb58..b168ed9f9 100644 --- a/src/swarm-lang/Swarm/Language/Typecheck.hs +++ b/src/swarm-lang/Swarm/Language/Typecheck.hs @@ -127,6 +127,19 @@ type Sourced a = (Source, a) -- | A "join" where an expected thing meets an actual thing. newtype Join a = Join (Source -> a) + deriving (Functor) + +instance Foldable Join where + foldMap :: Monoid m => (a -> m) -> Join a -> m + foldMap f j = f a1 <> f a2 + where + (a1, a2) = getJoin j + +instance Traversable Join where + traverse :: Applicative f => (a -> f b) -> Join a -> f (Join b) + traverse f j = joined <$> f a1 <*> f a2 + where + (a1, a2) = getJoin j instance (Show a) => Show (Join a) where show (getJoin -> (e, a)) = "(expected: " <> show e <> ", actual: " <> show a <> ")" @@ -313,7 +326,9 @@ unify :: unify ms j = do res <- expected =:= actual case res of - Left _ -> throwTypeErr NoLoc $ Mismatch ms j + Left _ -> do + j' <- traverse U.applyBindings j + throwTypeErr NoLoc $ Mismatch ms j' Right ty -> return ty where (expected, actual) = getJoin j diff --git a/test/unit/TestLanguagePipeline.hs b/test/unit/TestLanguagePipeline.hs index fdecd3282..6c7b04174 100644 --- a/test/unit/TestLanguagePipeline.hs +++ b/test/unit/TestLanguagePipeline.hs @@ -482,7 +482,7 @@ testLanguagePipeline = "applying a pair" ( process "(1,2) \"hi\"" - "1:1: Type mismatch:\n From context, expected `(1, 2)` to be a function,\n but it is actually a pair" + "1:1: Type mismatch:\n From context, expected `(1, 2)` to be a function,\n but it actually has type `Int * Int`" ) , testCase "providing a pair as an argument" @@ -520,6 +520,12 @@ testLanguagePipeline = "1 + (\\x. \\y. 3)" "1:5: Type mismatch:\n From context, expected `\\x. \\y. 3` to have type `Int`,\n but it is actually a function\n" ) + , testCase + "apply HOF to int - #1888" + ( process + "(\\f. f 3) 2" + "1:11: Type mismatch:\n From context, expected `2` to have a type like `Int -> _`" + ) ] , testGroup "generalize top-level binds #351 #1501" @@ -644,7 +650,7 @@ testLanguagePipeline = "def at non-cmd type" ( process "def x = 3 end; x + 2" - "1:16: Type mismatch:\n From context, expected `x + 2` to have a type like" + "1:16: Type mismatch:\n From context, expected `x + 2` to be a command" ) , testCase "def at cmd type"