Skip to content

Commit

Permalink
auto eta-expand variable arguments that take optional parameters but …
Browse files Browse the repository at this point in the history
…match with a function type without optional parameters. This also addresses PR #493 (I think)
  • Loading branch information
daanx committed Sep 13, 2024
1 parent 1369d3b commit 123d3e7
Show file tree
Hide file tree
Showing 14 changed files with 85 additions and 6 deletions.
33 changes: 30 additions & 3 deletions src/Type/Infer.hs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ import Core.BindingGroups( regroup )

import qualified Syntax.RangeMap as RM
import Common.File (seqqList)
import Type.Operations (hasOptionalOrImplicits)


{--------------------------------------------------------------------------
Expand Down Expand Up @@ -1985,9 +1986,10 @@ inferArgsN ctx range parArgs
inferArgExpr tp argexpr hidden -- TODO: add context for better error messages for implicit parameters?
= (if hidden then withNoRangeInfo else id) $
allowReturn False $
inferExpr (Just (tp,getRange argexpr))
(if isRho tp then Instantiated else Generalized False)
argexpr
do -- traceDoc (\penv -> text "inferArgExpr: expected" <+> ppType penv tp)
etaArg <- etaExpandVarArg tp argexpr
inferExpr (Just (tp,getRange argexpr)) (if isRho tp then Instantiated else Generalized False) etaArg


inferArg !acc []
= return (reverse acc)
Expand Down Expand Up @@ -2033,6 +2035,31 @@ inferArgsN ctx range parArgs
inferArg ((eff,core) : acc) args


-- Possibly eta-expand an (variable) argument to resolve any optional and implicit parameters
etaExpandVarArg :: Type -> Expr Type -> Inf (Expr Type)
etaExpandVarArg tp argexpr
= case (argexpr,splitFunType tp) of -- not for polymorphic types (e.g. `type/hr1.kk`)
(Var name _ vrng, Just (parTps,_,resTp)) | not (hasOptionalOrImplicits parTps)
-> do -- variable argument with an expected function type without optional parameters
matches <- lookupNameCtx isInfoValFunExt name (CtxFunArgs True {-not partial-} (length parTps) [] (Just resTp)) vrng
case matches of
[(qname,info)]
-> do let vtp = infoType info
-- traceDoc $ \penv -> text "inferArgExpr: try eta-expanded:" <+> ppParam penv (qname,vtp)
case splitFunScheme vtp of
Just (_,_,vparTps,_,_) | hasOptionalOrImplicits vparTps
&& all isMonoType (map snd parTps) -- cannot abstract over polymorphic parameters
-> -- the variable has a type with optional parameters, eta-expand it to match the expected type without optional parameters
do let range = getRange argexpr
nameFixed = [makeHiddenName "arg" (newName ("x" ++ show i)) | (i,_) <- zip [1..] parTps]
argsFixed = [(Nothing,Var name False range) | name <- nameFixed]
body = App argexpr argsFixed range
eta = Lam [ValueBinder name Nothing Nothing range range | name <- nameFixed] body range
return eta
_ -> return argexpr
_ -> return argexpr
_ -> return argexpr

-- | Is an expression annotated?
isAnnot (Parens expr _ _ rng) = isAnnot expr
isAnnot (Ann expr tp rng) = True
Expand Down
2 changes: 1 addition & 1 deletion src/Type/InferMonad.hs
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ import Syntax.Syntax(Expr(..),ValueBinder(..))
import qualified Debug.Trace as DT

trace s x =
-- DT.trace (" " ++ s)
DT.trace (" " ++ s)
x

{--------------------------------------------------------------------------
Expand Down
5 changes: 5 additions & 0 deletions src/Type/Operations.hs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ module Type.Operations( instantiate
, Evidence(..)
, freshSub
, isOptionalOrImplicit, splitOptionalImplicit, requiresImplicits
, hasOptionalOrImplicits
) where


Expand Down Expand Up @@ -44,6 +45,10 @@ splitOptionalImplicit pars
(opts,named) = span (isOptional . snd) rest
in (fixed,opts,named)

hasOptionalOrImplicits :: [(Name,Type)] -> Bool
hasOptionalOrImplicits pars
= any isOptionalOrImplicit pars

--------------------------------------------------------------------------
-- Instantiation
--------------------------------------------------------------------------
Expand Down
8 changes: 7 additions & 1 deletion src/Type/Type.hs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ module Type.Type (-- * Types
-- ** Accessors
, maxSynonymRank
, synonymRank, typeVarId, typeConName, typeSynName
, isBound, isSkolem, isMeta
, isBound, isSkolem, isMeta, isMonoType
-- ** Operations
, makeScheme
, quantifyType, qualifyType, applyType, tForall
Expand Down Expand Up @@ -274,6 +274,12 @@ predType :: Pred -> Type
predType (PredSub t1 t2) = typeFun [(newName "sub",t1)] typeTotal t2
predType (PredIFace name tps) = typeUnit --todo "Type.Operations.predType.PredIFace"

isMonoType :: Type -> Bool
isMonoType tp
= case expandSyn tp of
TForall{} -> False
_ -> True


{--------------------------------------------------------------------------
Equality
Expand Down
1 change: 1 addition & 0 deletions test/overload/config.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"-e --showtypesigs"
6 changes: 6 additions & 0 deletions test/overload/eta1.kk
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import std/num/float64

pub fun main()
// with `float64/show(f : float64, precision : ?int) : string`
// the implicit parameter needs to be eta-expanded
println(1.0)
3 changes: 3 additions & 0 deletions test/overload/eta1.kk.out
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
1

overload/eta1/main: () -> console ()
8 changes: 8 additions & 0 deletions test/overload/eta2.kk
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
noinline fun plus(x:int,y:int,z:int=1)
x+y+z

noinline fun foo( xs : list<a>, zero : a, ?plus : (a,a) -> a) : a
xs.foldl(zero,?plus)

pub fun main()
[1,2,3].foo(0)
5 changes: 5 additions & 0 deletions test/overload/eta2.kk.out
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
9

overload/eta2/foo: forall<a> (xs : list<a>, zero : a, ?plus : (a, a) -> a) -> a
overload/eta2/main: () -> int
overload/eta2/plus: (x : int, y : int, z : ? int) -> int
2 changes: 1 addition & 1 deletion test/overload/monoid2.kk
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ val intmul/monoid = Monoid(1,(*))
val float/monoid = Monoid(0.0,(+))

fun msum( xs, ?monoid )
xs.foldl(zero(),fn(x,y) plus(x,y))
xs.foldl(zero(),plus) // fn(x,y) plus(x,y))
// note1: zero is a function now..
// note2: we need to eta expand `plus` to match the type `(a,a,?monoid:monoid<a>) -> a` to `(a,a) -> a`
// maybe we can do this kind of conversion automatically in the compiler?
Expand Down
5 changes: 5 additions & 0 deletions test/type/opt1.kk
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
noinline fun plus(x:int,y:int,z:int=1)
x+y+z

pub fun main()
[1,2,3].foldl(0,plus)
2 changes: 2 additions & 0 deletions test/type/opt1.kk.out
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
type/opt1/main: () -> int
type/opt1/plus: (x : int, y : int, z : ? int) -> int
8 changes: 8 additions & 0 deletions test/type/opt2.kk
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
noinline fun plus(x:int, y:int, z:int=1)
x+y+z

fun psum( xs:list<int>, ?plus) : int
xs.foldl(0,plus)

pub fun main()
[1,2,3].psum(?plus=plus)
3 changes: 3 additions & 0 deletions test/type/opt2.kk.out
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
type/opt2/main: () -> int
type/opt2/plus: (x : int, y : int, z : ? int) -> int
type/opt2/psum: forall<a> (xs : list<int>, ?plus : a) -> int

0 comments on commit 123d3e7

Please sign in to comment.