Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dropUnusedMap implementation #120

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions avocADO/src/main/scala/ado.scala
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,32 @@ trait AvocADO[F[_]] {
def zip[A, B](fa: F[A], fb: F[B]): F[(A, B)]
def flatMap[A, B](fa: F[A], f: A => F[B]): F[B]
}

/**
* Drops unused trailing map calls in a for-comprehension. Helps with making for-comprehensions stack-safe.
* Example usage:
* ```scala
* dropUnusedMap {
* for {
* a <- doSth()
* _ <- doSideEffectAndReturnUnit(a)
* } yield ()
* }
* ```
*
* The above code will be transformed to code essentially equivalent to:
* ```scala
* doSth().flatMap(a => doSideEffectAndReturnUnit(a))
* ```
*
* instead of the normal for-comprehension desugaring:
* ```scala
* doSth().map(a => doSideEffectAndReturnUnit(a)).map(_ => ())
* ```
*
* Handled cases:
* - returning `()` from the for-comprehension, when the last generator expression also binds to `Unit`
* - returning the same variable reference as the last generator expression
*/
inline def dropUnusedMap[F[_], A](inline comp: F[A]): F[A] =
${ macros.dropUnusedMapImpl[F, A]('comp) }
128 changes: 114 additions & 14 deletions avocADO/src/main/scala/macros.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ private[avocado] object macros {
def adoImpl[F[_]: Type, A: Type](compExpr: Expr[F[A]], instanceExpr: Expr[AvocADO[F]])(using Quotes): Expr[F[A]] =
new ADOImpl(using quotes).adoImpl(compExpr, instanceExpr)

def dropUnusedMapImpl[F[_]: Type, A: Type](compExpr: Expr[F[A]])(using Quotes): Expr[F[A]] =
new ADOImpl(using quotes).dropUnusedMapImpl1(compExpr)

class ADOImpl(using Quotes) {
import quotes.reflect.*

Expand All @@ -34,7 +37,56 @@ private[avocado] object macros {

private def ctx(using context: Context): Context = context

def dropUnusedMapImpl[F[_]: Type, A: Type](compExpr: Expr[F[A]], instanceExpr: Expr[AvocADO[F]])(using Quotes): Expr[F[A]] = {
implCommon(compExpr, instanceExpr)(doPar = false, doDropMap = true)
}

def dropUnusedMapImpl1[F[_]: Type, A: Type](compExpr: Expr[F[A]]): Expr[F[A]] = {
given Context = Context(compExpr.asTerm, TypeRepr.of[F]) // This instance is wrong on purpose, with the assumption that it won't be used
def doDrop(expr: Term, arg: Term): Option[Term] = {
arg match {
case Lambda(List(param), body)
if isConstUnitBody(param, body) && expr.tpe.widen <:< ctx.fTpe.appliedTo(TypeRepr.of[Unit]) =>
Some(expr)
case Lambda(List(param), body)
if isIdentityBody(param, body) =>
Some(expr)
case _ =>
None
}
}
def isConstUnitBody(param: ValDef, tree: Tree): Boolean = tree match {
case Block(Nil, body) => isConstUnitBody(param, body)
case Match(scrutinee, List(CaseDef(_, _, body))) if scrutinee.symbol == param.symbol => isConstUnitBody(param, body)
case Literal(UnitConstant()) => true
case _ => false
}
def isIdentityBody(param: ValDef, tree: Tree): Boolean = tree match {
case Block(Nil, body) => isIdentityBody(param, body)
case Ident(name) if name == param.name => true
case _ => false
}
object dropUnusedMapMap extends TreeMap {
override def transformTerm(tree: Term)(owner: Symbol): Term = tree match {
case NormalAllowed(expr, methodName, typeArgs, arg) if methodName == "map" =>
doDrop(expr, arg).fold(super.transformTerm(tree)(owner))(identity)
case WithImplicitsAllowed(expr, args, methodName, typeArgs, arg) if methodName == "map" =>
doDrop(expr, arg).fold(super.transformTerm(tree)(owner))(identity)
case FromTypeclassAllowed(expr, evidences, methodName, typeArgs, arg) if methodName == "map" =>
doDrop(expr, arg).fold(super.transformTerm(tree)(owner))(identity)
case _ =>
super.transformTerm(tree)(owner)
}
}

dropUnusedMapMap.transformTerm(compExpr.asTerm)(Symbol.spliceOwner).asExprOf[F[A]]
}

def adoImpl[F[_]: Type, A: Type](compExpr: Expr[F[A]], instanceExpr: Expr[AvocADO[F]])(using Quotes): Expr[F[A]] = {
implCommon(compExpr, instanceExpr)(doPar = true, doDropMap = false)
}

def implCommon[F[_]: Type, A: Type](compExpr: Expr[F[A]], instanceExpr: Expr[AvocADO[F]])(doPar: Boolean, doDropMap: Boolean)(using Quotes): Expr[F[A]] = {
val exprTree = compExpr.asTerm match
case Inlined(_, _, tree) => tree match
case Block(Nil, expr) => expr
Expand All @@ -51,30 +103,70 @@ private[avocado] object macros {
case binding => (binding, getBindingDependencies(binding.tree, bindingVals))
}

connectBindings(bindingsWithDependencies, res).asExprOf[F[A]]
val splitFn = if doPar then splitToZip else splitByOne
val dropMapFn = if doDropMap then maybeDropMap else (_: Term, _: Tree, _: Term) => None

connectBindings(bindingsWithDependencies, res)(splitFn, dropMapFn).asExprOf[F[A]]
}

private def connectBindings(bindings: List[(Binding, Set[Symbol])], res: Term)(using Context): Tree = {
def go(bindings: List[(Binding, Set[Symbol])], zipped: List[(Tree, TypeRepr)], acc: Term, lastBinding: Binding): Term = bindings match {
private def connectBindings(
bindings: List[(Binding, Set[Symbol])],
res: Term
)(
splitFn: List[(Binding, Set[Symbol])] => (List[(Binding, Set[Symbol])], List[(Binding, Set[Symbol])], Binding),
dropMapFn: (Term, Tree, Term) => Option[Term]
)(using Context): Tree = {
def go(bindings: List[(Binding, Set[Symbol])], zipped: List[(Tree, TypeRepr)], acc: Term, lastBinding: Binding, res: Term): Term = bindings match {
case Nil =>
val arg = funFromZipped(zipped, res, Symbol.spliceOwner)
ctx.instance
.select(ctx.instance.tpe.typeSymbol.methodMember(lastBinding.methodName).head)
.appliedToTypes(List(typeReprForBindings(zipped), adaptTpeForMethod(res, lastBinding.methodName)))
.appliedToArgs(List(acc, arg))
case head :: Nil =>
dropMapFn(head._1.tree, head._1.pattern, res) match {
case Some(prevExpr) =>
go(Nil, zipped, acc, lastBinding, head._1.tree)
case None =>
makeNonFinalCall(bindings, zipped, acc, lastBinding, res)
}
case _ =>
val (toZip, rest, newLastBinding) = splitToZip(bindings)
val body = go(rest, toZip.map(b => b._1.pattern -> b._1.tpe), zipExprs(toZip.map(_._1), Symbol.spliceOwner), newLastBinding)
val arg = funFromZipped(zipped, body, Symbol.spliceOwner)
val tpes = lastBinding.typeArgs.map(_.widen)
ctx.instance
.select(ctx.instance.tpe.typeSymbol.methodMember(lastBinding.methodName).head)
.appliedToTypes(List(typeReprForBindings(zipped), adaptTpeForMethod(body, lastBinding.methodName)))
.appliedToArgs(List(acc, arg))
makeNonFinalCall(bindings, zipped, acc, lastBinding, res)
}

def makeNonFinalCall(bindings: List[(Binding, Set[Symbol])], zipped: List[(Tree, TypeRepr)], acc: Term, lastBinding: Binding, res: Term): Term = {
val (toZip, rest, newLastBinding) = splitFn(bindings)
val body = go(rest, toZip.map(b => b._1.pattern -> b._1.tpe), zipExprs(toZip.map(_._1), Symbol.spliceOwner), newLastBinding, res)
val arg = funFromZipped(zipped, body, Symbol.spliceOwner)
val tpes = lastBinding.typeArgs.map(_.widen)
ctx.instance
.select(ctx.instance.tpe.typeSymbol.methodMember(lastBinding.methodName).head)
.appliedToTypes(List(typeReprForBindings(zipped), adaptTpeForMethod(body, lastBinding.methodName)))
.appliedToArgs(List(acc, arg))
}

val (toZip, rest, lastMethod) = splitFn(bindings)
go(rest, toZip.map(b => b._1.pattern -> b._1.tpe), zipExprs(toZip.map(_._1), Symbol.spliceOwner), lastMethod, res)
}

private def maybeDropMap(prevTerm: Term, prevPattern: Tree, res: Term)(using Context): Option[Term] = {
res match {
case Literal(UnitConstant()) if extractTypeFromApplicative(prevTerm.tpe).widen =:= TypeRepr.of[Unit] =>
Some(prevTerm)
case _ if eqPrevPatternRef(prevPattern, res) =>
Some(prevTerm)
case _ =>
None
}
}

val (toZip, rest, lastMethod) = splitToZip(bindings)
go(rest, toZip.map(b => b._1.pattern -> b._1.tpe), zipExprs(toZip.map(_._1), Symbol.spliceOwner), lastMethod)
private def eqPrevPatternRef(prevPattern: Tree, res: Term): Boolean = {
(prevPattern, res) match {
case (valdef: ValDef, ident: Ident) =>
valdef.symbol == ident.symbol
case _ =>
false
}
}

private def adaptTpeForMethod(arg: Term, methodName: String): TypeRepr =
Expand Down Expand Up @@ -127,8 +219,16 @@ private[avocado] object macros {
(List(head), tail, head._1)
case _ =>
throwGenericError()
}
}

private def splitByOne(bindings: List[(Binding, Set[Symbol])]): (List[(Binding, Set[Symbol])], List[(Binding, Set[Symbol])], Binding) = {
bindings match {
case head :: tail =>
(List(head), tail, head._1)
case _ =>
throwGenericError()
}

}

private val tuple2: Term = Ref(Symbol.requiredModule("scala.Tuple2"))
Expand Down
190 changes: 190 additions & 0 deletions avocADO/src/test/scala/DropMapTest.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
package avocado.tests

import avocado.*

class DropMapTest extends munit.FunSuite {
class myOptionPackage(doOnMap: => Unit) {
sealed trait MyOption[+A] {
def map[B](f: A => B): MyOption[B] = this match {
case MySome(a) =>
doOnMap
MySome(f(a))
case MyNone => MyNone
}
def flatMap[B](f: A => MyOption[B]): MyOption[B] = this match {
case MySome(a) => f(a)
case MyNone => MyNone
}
def zip[B](that: MyOption[B]): MyOption[(A, B)] = (this, that) match {
case (MySome(a), MySome(b)) => MySome((a, b))
case _ => MyNone
}
def value: Option[A] = this match {
case MySome(a) => Some(a)
case MyNone => None
}
}
case class MySome[A](a: A) extends MyOption[A]
case object MyNone extends MyOption[Nothing]
}

test("don't drop map in a simple case") {
val (resOrg, mapUnusedResOrg) = {
var mapUsed = 0
val myOption = new myOptionPackage({ mapUsed = mapUsed + 1 })
import myOption.*
(for {
a <- MySome(1)
b <- MySome(2)
} yield a + b
).value -> mapUsed
}
val (res, mapUnusedRes) = {
var mapUsed = 0
val myOption = new myOptionPackage({ mapUsed = mapUsed + 1 })
import myOption.*
dropUnusedMap(
for {
a <- MySome(1)
b <- MySome(2)
} yield a + b
).value -> mapUsed
}
assertEquals(res, resOrg)
assert(mapUnusedRes == mapUnusedResOrg)
}

test("drop map with same var ref as result") {
val (resOrg, mapUnusedResOrg) = {
var mapUsed = 0
val myOption = new myOptionPackage({ mapUsed = mapUsed + 1 })
import myOption.*
(for {
a <- MySome(1)
b <- MySome(a)
} yield b
).value -> mapUsed
}
val (res, mapUnusedRes) = {
var mapUsed = 0
val myOption = new myOptionPackage({ mapUsed = mapUsed + 1 })
import myOption.*
dropUnusedMap(
for {
a <- MySome(1)
b <- MySome(a)
} yield b
).value -> mapUsed
}
assertEquals(res, resOrg)
assert(mapUnusedRes < mapUnusedResOrg)
}

test("drop map with unit result and wildcard last pattern") {
val (resOrg, mapUnusedResOrg) = {
var mapUsed = 0
val myOption = new myOptionPackage({ mapUsed = mapUsed + 1 })
import myOption.*
(for {
a <- MySome(1)
_ <- MySome(())
} yield ()
).value -> mapUsed
}
val (res, mapUnusedRes) = {
var mapUsed = 0
val myOption = new myOptionPackage({ mapUsed = mapUsed + 1 })
import myOption.*
dropUnusedMap(
for {
a <- MySome(1)
_ <- MySome(())
} yield ()
).value -> mapUsed
}
assertEquals(res, resOrg)
assert(mapUnusedRes < mapUnusedResOrg)
}

test("drop map with unit result and named last pattern") {
val (resOrg, mapUnusedResOrg) = {
var mapUsed = 0
val myOption = new myOptionPackage({ mapUsed = mapUsed + 1 })
import myOption.*
(for {
a <- MySome(1)
b <- MySome(())
} yield ()
).value -> mapUsed
}
val (res, mapUnusedRes) = {
var mapUsed = 0
val myOption = new myOptionPackage({ mapUsed = mapUsed + 1 })
import myOption.*
dropUnusedMap(
for {
a <- MySome(1)
b <- MySome(())
} yield ()
).value -> mapUsed
}
assertEquals(res, resOrg)
assert(mapUnusedRes < mapUnusedResOrg)
}

test("drop map with unit result and wildcard last pattern with alias in the middle") {
val (resOrg, mapUnusedResOrg) = {
var mapUsed = 0
val myOption = new myOptionPackage({ mapUsed = mapUsed + 1 })
import myOption.*
(for {
a <- MySome(1)
b = a
_ <- MySome(())
} yield ()
).value -> mapUsed
}
val (res, mapUnusedRes) = {
var mapUsed = 0
val myOption = new myOptionPackage({ mapUsed = mapUsed + 1 })
import myOption.*
dropUnusedMap(
for {
a <- MySome(1)
b = a
_ <- MySome(())
} yield ()
).value -> mapUsed
}
assertEquals(res, resOrg)
assert(mapUnusedRes < mapUnusedResOrg)
}

test("drop map with unit result and named last pattern with alias in the middle") {
val (resOrg, mapUnusedResOrg) = {
var mapUsed = 0
val myOption = new myOptionPackage({ mapUsed = mapUsed + 1 })
import myOption.*
(for {
a <- MySome(1)
b = a
c <- MySome(())
} yield ()
).value -> mapUsed
}
val (res, mapUnusedRes) = {
var mapUsed = 0
val myOption = new myOptionPackage({ mapUsed = mapUsed + 1 })
import myOption.*
dropUnusedMap(
for {
a <- MySome(1)
b = a
c <- MySome(())
} yield ()
).value -> mapUsed
}
assertEquals(res, resOrg)
assert(mapUnusedRes < mapUnusedResOrg)
}
}
Loading