diff --git a/src/Type/Infer.hs b/src/Type/Infer.hs index 4db298d80..ca63faaf7 100644 --- a/src/Type/Infer.hs +++ b/src/Type/Infer.hs @@ -75,6 +75,7 @@ import Core.BindingGroups( regroup ) import qualified Syntax.RangeMap as RM import Common.File (seqqList) +import Type.Operations (hasOptionalOrImplicits) {-------------------------------------------------------------------------- @@ -1985,9 +1986,10 @@ inferArgsN ctx range parArgs inferArgExpr tp argexpr hidden -- TODO: add context for better error messages for implicit parameters? = (if hidden then withNoRangeInfo else id) $ allowReturn False $ - inferExpr (Just (tp,getRange argexpr)) - (if isRho tp then Instantiated else Generalized False) - argexpr + do -- traceDoc (\penv -> text "inferArgExpr: expected" <+> ppType penv tp) + etaArg <- etaExpandVarArg tp argexpr + inferExpr (Just (tp,getRange argexpr)) (if isRho tp then Instantiated else Generalized False) etaArg + inferArg !acc [] = return (reverse acc) @@ -2033,6 +2035,31 @@ inferArgsN ctx range parArgs inferArg ((eff,core) : acc) args +-- Possibly eta-expand an (variable) argument to resolve any optional and implicit parameters +etaExpandVarArg :: Type -> Expr Type -> Inf (Expr Type) +etaExpandVarArg tp argexpr + = case (argexpr,splitFunType tp) of -- not for polymorphic types (e.g. `type/hr1.kk`) + (Var name _ vrng, Just (parTps,_,resTp)) | not (hasOptionalOrImplicits parTps) + -> do -- variable argument with an expected function type without optional parameters + matches <- lookupNameCtx isInfoValFunExt name (CtxFunArgs True {-not partial-} (length parTps) [] (Just resTp)) vrng + case matches of + [(qname,info)] + -> do let vtp = infoType info + -- traceDoc $ \penv -> text "inferArgExpr: try eta-expanded:" <+> ppParam penv (qname,vtp) + case splitFunScheme vtp of + Just (_,_,vparTps,_,_) | hasOptionalOrImplicits vparTps + && all isMonoType (map snd parTps) -- cannot abstract over polymorphic parameters + -> -- the variable has a type with optional parameters, eta-expand it to match the expected type without optional parameters + do let range = getRange argexpr + nameFixed = [makeHiddenName "arg" (newName ("x" ++ show i)) | (i,_) <- zip [1..] parTps] + argsFixed = [(Nothing,Var name False range) | name <- nameFixed] + body = App argexpr argsFixed range + eta = Lam [ValueBinder name Nothing Nothing range range | name <- nameFixed] body range + return eta + _ -> return argexpr + _ -> return argexpr + _ -> return argexpr + -- | Is an expression annotated? isAnnot (Parens expr _ _ rng) = isAnnot expr isAnnot (Ann expr tp rng) = True diff --git a/src/Type/InferMonad.hs b/src/Type/InferMonad.hs index eae5f3d7f..210ea8873 100644 --- a/src/Type/InferMonad.hs +++ b/src/Type/InferMonad.hs @@ -126,7 +126,7 @@ import Syntax.Syntax(Expr(..),ValueBinder(..)) import qualified Debug.Trace as DT trace s x = - -- DT.trace (" " ++ s) + DT.trace (" " ++ s) x {-------------------------------------------------------------------------- diff --git a/src/Type/Operations.hs b/src/Type/Operations.hs index d2daa3b84..a69a8e1bb 100644 --- a/src/Type/Operations.hs +++ b/src/Type/Operations.hs @@ -16,6 +16,7 @@ module Type.Operations( instantiate , Evidence(..) , freshSub , isOptionalOrImplicit, splitOptionalImplicit, requiresImplicits + , hasOptionalOrImplicits ) where @@ -44,6 +45,10 @@ splitOptionalImplicit pars (opts,named) = span (isOptional . snd) rest in (fixed,opts,named) +hasOptionalOrImplicits :: [(Name,Type)] -> Bool +hasOptionalOrImplicits pars + = any isOptionalOrImplicit pars + -------------------------------------------------------------------------- -- Instantiation -------------------------------------------------------------------------- diff --git a/src/Type/Type.hs b/src/Type/Type.hs index 3e5882211..6c2ccd43e 100644 --- a/src/Type/Type.hs +++ b/src/Type/Type.hs @@ -24,7 +24,7 @@ module Type.Type (-- * Types -- ** Accessors , maxSynonymRank , synonymRank, typeVarId, typeConName, typeSynName - , isBound, isSkolem, isMeta + , isBound, isSkolem, isMeta, isMonoType -- ** Operations , makeScheme , quantifyType, qualifyType, applyType, tForall @@ -274,6 +274,12 @@ predType :: Pred -> Type predType (PredSub t1 t2) = typeFun [(newName "sub",t1)] typeTotal t2 predType (PredIFace name tps) = typeUnit --todo "Type.Operations.predType.PredIFace" +isMonoType :: Type -> Bool +isMonoType tp + = case expandSyn tp of + TForall{} -> False + _ -> True + {-------------------------------------------------------------------------- Equality diff --git a/test/overload/config.json b/test/overload/config.json new file mode 100644 index 000000000..da24c1df1 --- /dev/null +++ b/test/overload/config.json @@ -0,0 +1 @@ +"-e --showtypesigs" diff --git a/test/overload/eta1.kk b/test/overload/eta1.kk new file mode 100644 index 000000000..a8640d401 --- /dev/null +++ b/test/overload/eta1.kk @@ -0,0 +1,6 @@ +import std/num/float64 + +pub fun main() + // with `float64/show(f : float64, precision : ?int) : string` + // the implicit parameter needs to be eta-expanded + println(1.0) \ No newline at end of file diff --git a/test/overload/eta1.kk.out b/test/overload/eta1.kk.out new file mode 100644 index 000000000..4edcb70db --- /dev/null +++ b/test/overload/eta1.kk.out @@ -0,0 +1,3 @@ +1 + +overload/eta1/main: () -> console () \ No newline at end of file diff --git a/test/overload/eta2.kk b/test/overload/eta2.kk new file mode 100644 index 000000000..dc8c870b7 --- /dev/null +++ b/test/overload/eta2.kk @@ -0,0 +1,8 @@ +noinline fun plus(x:int,y:int,z:int=1) + x+y+z + +noinline fun foo( xs : list, zero : a, ?plus : (a,a) -> a) : a + xs.foldl(zero,?plus) + +pub fun main() + [1,2,3].foo(0) \ No newline at end of file diff --git a/test/overload/eta2.kk.out b/test/overload/eta2.kk.out new file mode 100644 index 000000000..f6879a343 --- /dev/null +++ b/test/overload/eta2.kk.out @@ -0,0 +1,5 @@ +9 + +overload/eta2/foo: forall (xs : list, zero : a, ?plus : (a, a) -> a) -> a +overload/eta2/main: () -> int +overload/eta2/plus: (x : int, y : int, z : ? int) -> int \ No newline at end of file diff --git a/test/overload/monoid2.kk b/test/overload/monoid2.kk index caec5e084..579582efd 100644 --- a/test/overload/monoid2.kk +++ b/test/overload/monoid2.kk @@ -16,7 +16,7 @@ val intmul/monoid = Monoid(1,(*)) val float/monoid = Monoid(0.0,(+)) fun msum( xs, ?monoid ) - xs.foldl(zero(),fn(x,y) plus(x,y)) + xs.foldl(zero(),plus) // fn(x,y) plus(x,y)) // note1: zero is a function now.. // note2: we need to eta expand `plus` to match the type `(a,a,?monoid:monoid) -> a` to `(a,a) -> a` // maybe we can do this kind of conversion automatically in the compiler? diff --git a/test/type/opt1.kk b/test/type/opt1.kk new file mode 100644 index 000000000..d113c857d --- /dev/null +++ b/test/type/opt1.kk @@ -0,0 +1,5 @@ +noinline fun plus(x:int,y:int,z:int=1) + x+y+z + +pub fun main() + [1,2,3].foldl(0,plus) diff --git a/test/type/opt1.kk.out b/test/type/opt1.kk.out new file mode 100644 index 000000000..2174bc442 --- /dev/null +++ b/test/type/opt1.kk.out @@ -0,0 +1,2 @@ +type/opt1/main: () -> int +type/opt1/plus: (x : int, y : int, z : ? int) -> int \ No newline at end of file diff --git a/test/type/opt2.kk b/test/type/opt2.kk new file mode 100644 index 000000000..0b7e3b2f5 --- /dev/null +++ b/test/type/opt2.kk @@ -0,0 +1,8 @@ +noinline fun plus(x:int, y:int, z:int=1) + x+y+z + +fun psum( xs:list, ?plus) : int + xs.foldl(0,plus) + +pub fun main() + [1,2,3].psum(?plus=plus) diff --git a/test/type/opt2.kk.out b/test/type/opt2.kk.out new file mode 100644 index 000000000..c4c49ef83 --- /dev/null +++ b/test/type/opt2.kk.out @@ -0,0 +1,3 @@ +type/opt2/main: () -> int +type/opt2/plus: (x : int, y : int, z : ? int) -> int +type/opt2/psum: forall (xs : list, ?plus : a) -> int \ No newline at end of file