diff --git a/src/Type/Infer.hs b/src/Type/Infer.hs
index 4db298d80..ca63faaf7 100644
--- a/src/Type/Infer.hs
+++ b/src/Type/Infer.hs
@@ -75,6 +75,7 @@ import Core.BindingGroups( regroup )
import qualified Syntax.RangeMap as RM
import Common.File (seqqList)
+import Type.Operations (hasOptionalOrImplicits)
{--------------------------------------------------------------------------
@@ -1985,9 +1986,10 @@ inferArgsN ctx range parArgs
inferArgExpr tp argexpr hidden -- TODO: add context for better error messages for implicit parameters?
= (if hidden then withNoRangeInfo else id) $
allowReturn False $
- inferExpr (Just (tp,getRange argexpr))
- (if isRho tp then Instantiated else Generalized False)
- argexpr
+ do -- traceDoc (\penv -> text "inferArgExpr: expected" <+> ppType penv tp)
+ etaArg <- etaExpandVarArg tp argexpr
+ inferExpr (Just (tp,getRange argexpr)) (if isRho tp then Instantiated else Generalized False) etaArg
+
inferArg !acc []
= return (reverse acc)
@@ -2033,6 +2035,31 @@ inferArgsN ctx range parArgs
inferArg ((eff,core) : acc) args
+-- Possibly eta-expand an (variable) argument to resolve any optional and implicit parameters
+etaExpandVarArg :: Type -> Expr Type -> Inf (Expr Type)
+etaExpandVarArg tp argexpr
+ = case (argexpr,splitFunType tp) of -- not for polymorphic types (e.g. `type/hr1.kk`)
+ (Var name _ vrng, Just (parTps,_,resTp)) | not (hasOptionalOrImplicits parTps)
+ -> do -- variable argument with an expected function type without optional parameters
+ matches <- lookupNameCtx isInfoValFunExt name (CtxFunArgs True {-not partial-} (length parTps) [] (Just resTp)) vrng
+ case matches of
+ [(qname,info)]
+ -> do let vtp = infoType info
+ -- traceDoc $ \penv -> text "inferArgExpr: try eta-expanded:" <+> ppParam penv (qname,vtp)
+ case splitFunScheme vtp of
+ Just (_,_,vparTps,_,_) | hasOptionalOrImplicits vparTps
+ && all isMonoType (map snd parTps) -- cannot abstract over polymorphic parameters
+ -> -- the variable has a type with optional parameters, eta-expand it to match the expected type without optional parameters
+ do let range = getRange argexpr
+ nameFixed = [makeHiddenName "arg" (newName ("x" ++ show i)) | (i,_) <- zip [1..] parTps]
+ argsFixed = [(Nothing,Var name False range) | name <- nameFixed]
+ body = App argexpr argsFixed range
+ eta = Lam [ValueBinder name Nothing Nothing range range | name <- nameFixed] body range
+ return eta
+ _ -> return argexpr
+ _ -> return argexpr
+ _ -> return argexpr
+
-- | Is an expression annotated?
isAnnot (Parens expr _ _ rng) = isAnnot expr
isAnnot (Ann expr tp rng) = True
diff --git a/src/Type/InferMonad.hs b/src/Type/InferMonad.hs
index eae5f3d7f..210ea8873 100644
--- a/src/Type/InferMonad.hs
+++ b/src/Type/InferMonad.hs
@@ -126,7 +126,7 @@ import Syntax.Syntax(Expr(..),ValueBinder(..))
import qualified Debug.Trace as DT
trace s x =
- -- DT.trace (" " ++ s)
+ DT.trace (" " ++ s)
x
{--------------------------------------------------------------------------
diff --git a/src/Type/Operations.hs b/src/Type/Operations.hs
index d2daa3b84..a69a8e1bb 100644
--- a/src/Type/Operations.hs
+++ b/src/Type/Operations.hs
@@ -16,6 +16,7 @@ module Type.Operations( instantiate
, Evidence(..)
, freshSub
, isOptionalOrImplicit, splitOptionalImplicit, requiresImplicits
+ , hasOptionalOrImplicits
) where
@@ -44,6 +45,10 @@ splitOptionalImplicit pars
(opts,named) = span (isOptional . snd) rest
in (fixed,opts,named)
+hasOptionalOrImplicits :: [(Name,Type)] -> Bool
+hasOptionalOrImplicits pars
+ = any isOptionalOrImplicit pars
+
--------------------------------------------------------------------------
-- Instantiation
--------------------------------------------------------------------------
diff --git a/src/Type/Type.hs b/src/Type/Type.hs
index 3e5882211..6c2ccd43e 100644
--- a/src/Type/Type.hs
+++ b/src/Type/Type.hs
@@ -24,7 +24,7 @@ module Type.Type (-- * Types
-- ** Accessors
, maxSynonymRank
, synonymRank, typeVarId, typeConName, typeSynName
- , isBound, isSkolem, isMeta
+ , isBound, isSkolem, isMeta, isMonoType
-- ** Operations
, makeScheme
, quantifyType, qualifyType, applyType, tForall
@@ -274,6 +274,12 @@ predType :: Pred -> Type
predType (PredSub t1 t2) = typeFun [(newName "sub",t1)] typeTotal t2
predType (PredIFace name tps) = typeUnit --todo "Type.Operations.predType.PredIFace"
+isMonoType :: Type -> Bool
+isMonoType tp
+ = case expandSyn tp of
+ TForall{} -> False
+ _ -> True
+
{--------------------------------------------------------------------------
Equality
diff --git a/test/overload/config.json b/test/overload/config.json
new file mode 100644
index 000000000..da24c1df1
--- /dev/null
+++ b/test/overload/config.json
@@ -0,0 +1 @@
+"-e --showtypesigs"
diff --git a/test/overload/eta1.kk b/test/overload/eta1.kk
new file mode 100644
index 000000000..a8640d401
--- /dev/null
+++ b/test/overload/eta1.kk
@@ -0,0 +1,6 @@
+import std/num/float64
+
+pub fun main()
+ // with `float64/show(f : float64, precision : ?int) : string`
+ // the implicit parameter needs to be eta-expanded
+ println(1.0)
\ No newline at end of file
diff --git a/test/overload/eta1.kk.out b/test/overload/eta1.kk.out
new file mode 100644
index 000000000..4edcb70db
--- /dev/null
+++ b/test/overload/eta1.kk.out
@@ -0,0 +1,3 @@
+1
+
+overload/eta1/main: () -> console ()
\ No newline at end of file
diff --git a/test/overload/eta2.kk b/test/overload/eta2.kk
new file mode 100644
index 000000000..dc8c870b7
--- /dev/null
+++ b/test/overload/eta2.kk
@@ -0,0 +1,8 @@
+noinline fun plus(x:int,y:int,z:int=1)
+ x+y+z
+
+noinline fun foo( xs : list, zero : a, ?plus : (a,a) -> a) : a
+ xs.foldl(zero,?plus)
+
+pub fun main()
+ [1,2,3].foo(0)
\ No newline at end of file
diff --git a/test/overload/eta2.kk.out b/test/overload/eta2.kk.out
new file mode 100644
index 000000000..f6879a343
--- /dev/null
+++ b/test/overload/eta2.kk.out
@@ -0,0 +1,5 @@
+9
+
+overload/eta2/foo: forall (xs : list, zero : a, ?plus : (a, a) -> a) -> a
+overload/eta2/main: () -> int
+overload/eta2/plus: (x : int, y : int, z : ? int) -> int
\ No newline at end of file
diff --git a/test/overload/monoid2.kk b/test/overload/monoid2.kk
index caec5e084..579582efd 100644
--- a/test/overload/monoid2.kk
+++ b/test/overload/monoid2.kk
@@ -16,7 +16,7 @@ val intmul/monoid = Monoid(1,(*))
val float/monoid = Monoid(0.0,(+))
fun msum( xs, ?monoid )
- xs.foldl(zero(),fn(x,y) plus(x,y))
+ xs.foldl(zero(),plus) // fn(x,y) plus(x,y))
// note1: zero is a function now..
// note2: we need to eta expand `plus` to match the type `(a,a,?monoid:monoid) -> a` to `(a,a) -> a`
// maybe we can do this kind of conversion automatically in the compiler?
diff --git a/test/type/opt1.kk b/test/type/opt1.kk
new file mode 100644
index 000000000..d113c857d
--- /dev/null
+++ b/test/type/opt1.kk
@@ -0,0 +1,5 @@
+noinline fun plus(x:int,y:int,z:int=1)
+ x+y+z
+
+pub fun main()
+ [1,2,3].foldl(0,plus)
diff --git a/test/type/opt1.kk.out b/test/type/opt1.kk.out
new file mode 100644
index 000000000..2174bc442
--- /dev/null
+++ b/test/type/opt1.kk.out
@@ -0,0 +1,2 @@
+type/opt1/main: () -> int
+type/opt1/plus: (x : int, y : int, z : ? int) -> int
\ No newline at end of file
diff --git a/test/type/opt2.kk b/test/type/opt2.kk
new file mode 100644
index 000000000..0b7e3b2f5
--- /dev/null
+++ b/test/type/opt2.kk
@@ -0,0 +1,8 @@
+noinline fun plus(x:int, y:int, z:int=1)
+ x+y+z
+
+fun psum( xs:list, ?plus) : int
+ xs.foldl(0,plus)
+
+pub fun main()
+ [1,2,3].psum(?plus=plus)
diff --git a/test/type/opt2.kk.out b/test/type/opt2.kk.out
new file mode 100644
index 000000000..c4c49ef83
--- /dev/null
+++ b/test/type/opt2.kk.out
@@ -0,0 +1,3 @@
+type/opt2/main: () -> int
+type/opt2/plus: (x : int, y : int, z : ? int) -> int
+type/opt2/psum: forall (xs : list, ?plus : a) -> int
\ No newline at end of file