Skip to content

Commit

Permalink
Use dedicated macro for scala 3 enum type
Browse files Browse the repository at this point in the history
  • Loading branch information
RustedBones committed May 31, 2024
1 parent 71df285 commit acf0302
Show file tree
Hide file tree
Showing 7 changed files with 117 additions and 85 deletions.
2 changes: 1 addition & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ val commonSettings = Seq(
"-Yretain-trees",
// tolerate some nested macro expansion
"-Xmax-inlines",
"64"
"128"
)
case Some((2, 13)) =>
Seq(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,15 @@

package magnolify.shared

import magnolia1.{CaseClass, SealedTrait}
import magnolia1._

import scala.annotation.implicitNotFound

trait EnumTypeDerivation {
implicit def gen[T]: EnumType[T] = macro EnumTypeMacros.derivationEnumTypeMacro[T]
}

object EnumTypeDerivation {
type Typeclass[T] = EnumType[T]

// EnumType can only be split into objects with fixed name
Expand Down Expand Up @@ -49,7 +53,7 @@ trait EnumTypeDerivation {
val ns = sealedTrait.typeName.owner
val subs = sealedTrait.subtypes.map(_.typeclass)
val values = subs.flatMap(_.values).sorted.toList
val annotations = (sealedTrait.annotations ++ subs.flatMap(_.annotations)).toList
val annotations = sealedTrait.inheritedAnnotations.toList ++ sealedTrait.annotations.toList
EnumType.create(
n,
ns,
Expand All @@ -60,4 +64,6 @@ trait EnumTypeDerivation {
v => subs.find(_.name == v).get.from(v)
)
}

def gen[T]: EnumType[T] = macro Magnolia.gen[T]
}
16 changes: 9 additions & 7 deletions shared/src/main/scala-2/magnolify/shared/EnumTypeMacros.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@

package magnolify.shared

import magnolia1.Magnolia
import scala.reflect.macros.whitebox

object EnumTypeMacros {

def scalaEnumTypeMacro[T: c.WeakTypeTag](
c: whitebox.Context
)(annotations: c.Expr[AnnotationType[T]]): c.Tree = {
Expand All @@ -44,18 +44,20 @@ object EnumTypeMacros {
val tpe = weakTypeOf[T]
val symbol = tpe.typeSymbol
if (symbol.isModuleClass) {
q"new _root_.magnolify.shared.EnumType.EnumValue[$tpe]{}"
q"new _root_.magnolify.shared.EnumTypeDerivation.EnumValue[$tpe]{}"
} else {
c.abort(c.enclosingPosition, "EnumType value must be an object")
}
}

def derivationEnumTypeMacro[T: c.WeakTypeTag](c: whitebox.Context): c.Tree = {
import c.universe._
val tpe = weakTypeOf[T]
q"_root_.magnolify.shared.EnumTypeDerivation.gen[$tpe]"
}
}

trait EnumTypeCompanionMacros extends EnumTypeCompanionLowPrioMacros {
trait EnumTypeCompanionMacros extends EnumTypeDerivation {
implicit def scalaEnumType[T <: Enumeration#Value: AnnotationType]: EnumType[T] =
macro EnumTypeMacros.scalaEnumTypeMacro[T]
}

trait EnumTypeCompanionLowPrioMacros extends EnumTypeDerivation {
implicit def gen[T]: EnumType[T] = macro Magnolia.gen[T]
}
79 changes: 29 additions & 50 deletions shared/src/main/scala-3/magnolify/shared/EnumTypeDerivation.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,62 +16,41 @@

package magnolify.shared

import magnolia1.{CaseClass, CommonDerivation, SealedTrait, SealedTraitDerivation}
import magnolia1.*

import scala.compiletime.*
import scala.deriving.Mirror

// Do not extend Derivation so we can add an extra check when deriving the sum type
trait EnumTypeDerivation extends CommonDerivation[EnumType] with SealedTraitDerivation:

transparent inline def subtypes[T, SubtypeTuple <: Tuple](
m: Mirror.SumOf[T],
idx: Int = 0 // no longer used, kept for bincompat
): List[SealedTrait.Subtype[Typeclass, T, _]] =
subtypesFromMirror[T, SubtypeTuple](m, idx)

inline def derivedMirrorSum[A](sum: Mirror.SumOf[A]): EnumType[A] =
summonAll[Tuple.Map[sum.MirroredElemTypes, [S] =>> S <:< Singleton]] // assert all singleton
split(sealedTraitFromMirror(sum))

inline def derivedMirror[A](using mirror: Mirror.Of[A]): EnumType[A] =
inline mirror match
case sum: Mirror.SumOf[A] => derivedMirrorSum[A](sum)
case product: Mirror.ProductOf[A] => derivedMirrorProduct[A](product)

inline def derived[A](using Mirror.Of[A]): EnumType[A] = derivedMirror[A]
trait EnumTypeDerivation:
implicit inline def gen[T](using mirror: Mirror.Of[T]): EnumType[T] = EnumTypeDerivation.gen[T]

protected override inline def deriveSubtype[s](m: Mirror.Of[s]): EnumType[s] =
derivedMirror[s](using m)

def join[T](caseClass: CaseClass[EnumType, T]): EnumType[T] =
// fail at runtime since we can't prevent derivation
// see https://github.com/softwaremill/magnolia/issues/267
require(caseClass.isObject, s"Cannot derive EnumType[T] for case class ${caseClass.typeInfo}")
val n = caseClass.typeInfo.short
val ns = caseClass.typeInfo.owner
EnumType.create(
n,
ns,
List(n),
caseClass.annotations.toList,
_ => caseClass.rawConstruct(Nil)
)
end join
// Do not extend Derivation so we can add an extra check when deriving the sum type
object EnumTypeDerivation:

private transparent inline def values[A, S <: Tuple](m: Mirror.Of[A]): List[A] =
inline erasedValue[S] match
case _: EmptyTuple =>
Nil
case _: (s *: tail) =>
val infos = summonFrom {
case mm: Mirror.SumOf[`s`] =>
values[A, mm.MirroredElemTypes](mm.asInstanceOf[m.type])
case mm: Mirror.ProductOf[`s`] if Macro.isObject[`s`] =>
List(mm.fromProduct(EmptyTuple).asInstanceOf[A])
case _ =>
error("Cannot derive EnumType for non singleton sum type")
}
infos ::: values[A, tail](m)

inline implicit def gen[T](using mirror: Mirror.Of[T]): EnumType[T] =
val s = values[T, T *: EmptyTuple](mirror)
val it = Macro.typeInfo[T]
val annotations = Macro.inheritedAnns[T] ++ Macro.anns[T]

def split[T](sealedTrait: SealedTrait[EnumType, T]): EnumType[T] =
val n = sealedTrait.typeInfo.short
val ns = sealedTrait.typeInfo.owner
val subs = sealedTrait.subtypes.map(_.typeclass)
val values = subs.flatMap(_.values).sorted.toList
val annotations = (sealedTrait.annotations ++ subs.flatMap(_.annotations)).toList
EnumType.create(
n,
ns,
values,
it.short,
it.owner,
s.map(_.toString),
annotations,
// it is ok to use the inefficient find here because it will be called only once
// and cached inside an instance of EnumType
v => subs.find(_.name == v).get.from(v)
name => s.find(_.toString == name).get
)
end split
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
package magnolify.shared

import scala.quoted.*
import scala.deriving.Mirror

object EnumTypeMacros:
def scalaEnumTypeMacro[T: Type](annotations: Expr[AnnotationType[T]])(using
Expand All @@ -35,13 +34,8 @@ object EnumTypeMacros:
val map = '{ $e.values.iterator.map(x => x.toString -> x.asInstanceOf[T]).toMap.apply(_) }
'{ EnumType.create[T]($n, $ns, $vs, $as, $map) }

trait EnumTypeCompanionMacros extends EnumTypeCompanionMacros0

trait EnumTypeCompanionMacros0 extends EnumTypeCompanionMacros1:
trait EnumTypeCompanionMacros extends EnumTypeDerivation:
inline implicit def scalaEnumType[T <: Enumeration#Value](using
annotations: AnnotationType[T]
): EnumType[T] =
${ EnumTypeMacros.scalaEnumTypeMacro[T]('annotations) }

trait EnumTypeCompanionMacros1 extends EnumTypeDerivation:
inline implicit def gen[T](using Mirror.Of[T]): EnumType[T] = derivedMirror[T]
71 changes: 57 additions & 14 deletions test/src/test/scala/magnolify/shared/EnumTypeSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,31 @@ class EnumTypeSuite extends MagnolifySuite {
}

test("ADT") {
val et = ensureSerializable(EnumType[ADT.Color])
assertEquals(et.name, "Color")
assertEquals(et.namespace, "magnolify.test.ADT")
assertEquals(et.values, List("Blue", "Green", "Red")) // ADTs are ordered alphabetically
assertEquals(et.from("Red"), ADT.Red)
assertEquals(et.to(ADT.Red), "Red")
val etPrimaryColor = ensureSerializable(EnumType[ADT.PrimaryColor])
assertEquals(etPrimaryColor.name, "PrimaryColor")
assertEquals(etPrimaryColor.namespace, "magnolify.test.ADT")
assertEquals(
etPrimaryColor.values,
List("Blue", "Green", "Red")
) // ADTs are ordered alphabetically
assertEquals(etPrimaryColor.from("Red"), ADT.Red)
assertEquals(etPrimaryColor.to(ADT.Red), "Red")
// Magnolia does not capture Java annotations
val as = et.annotations.collect { case a: ScalaAnnotation => a.value }
assertEquals(as, List("Color", "Red"))
val annPrimaryColor = etPrimaryColor.annotations.collect { case a: ScalaAnnotation => a.value }
assertEquals(annPrimaryColor, List("Color", "PrimaryColor"))

val etColor = ensureSerializable(EnumType[ADT.Color])
assertEquals(etColor.name, "Color")
assertEquals(etColor.namespace, "magnolify.test.ADT")
assertEquals(
etColor.values,
List("Blue", "Cyan", "Green", "Magenta", "Red", "Yellow")
) // ADTs are ordered alphabetically
assertEquals(etColor.from("Magenta"), ADT.Magenta)
assertEquals(etColor.to(ADT.Magenta), "Magenta")
// Magnolia does not capture Java annotations
val as = etColor.annotations.collect { case a: ScalaAnnotation => a.value }
assertEquals(as, List("Color"))
}

test("ADT No Default Constructor") {
Expand All @@ -77,9 +93,9 @@ class EnumTypeSuite extends MagnolifySuite {
| ^
|""".stripMargin
val scala3Error =
"""|error: Cannot prove that Some[magnolify.test.ADT.Color] <:< Singleton.
|
| ^
"""|error: Cannot derive EnumType for non sum type
| val error = compileErrors("EnumType.gen[Option[ADT.Color]]")
| ^
|""".stripMargin
if (Properties.versionNumberString.startsWith("2.12")) {
assertNoDiff(error, scala2Error)
Expand All @@ -100,10 +116,37 @@ class EnumTypeSuite extends MagnolifySuite {
| ^
|""".stripMargin

@nowarn
val scala3Error =
"""|error: Cannot prove that Some[magnolify.test.ADT.Color] <:< Singleton.
"""|error:
|No given instance of type magnolify.shared.EnumType[Option[magnolify.test.ADT.Color]] was found for parameter et of method apply in object EnumType.
|I found:
|
| ^
| magnolify.shared.EnumType.gen[Option[magnolify.test.ADT.Color]](
| {
| final class $anon() extends Object(), Serializable {
| type MirroredMonoType = Option[magnolify.test.ADT.Color]
| }
| (new $anon():Object & Serializable)
| }.$asInstanceOf[
|
| scala.deriving.Mirror.Sum{
| type MirroredMonoType² = Option[magnolify.test.ADT.Color];
| type MirroredType = Option[magnolify.test.ADT.Color];
| type MirroredLabel = ("Option" : String);
| type MirroredElemTypes = (None.type, Some[magnolify.test.ADT.Color]);
| type MirroredElemLabels = (("None$" : String), ("Some" : String))
| }
|
| ]
| )
|
|But method gen in trait EnumTypeCompanionMacros1 does not match type magnolify.shared.EnumType[Option[magnolify.test.ADT.Color]]
|
|where: MirroredMonoType is a type in an anonymous class locally defined in class EnumTypeSuite which is an alias of Option[magnolify.test.ADT.Color]
| MirroredMonoType² is a type in trait Mirror with bounds""".stripMargin + " \n" + """|.
|EnumType[Option[ADT.Color]]
| ^
|""".stripMargin

if (Properties.versionNumberString.startsWith("2.12")) {
Expand Down Expand Up @@ -132,7 +175,7 @@ class EnumTypeSuite extends MagnolifySuite {
}

test("ADT CaseMapper") {
val et = ensureSerializable(EnumType[ADT.Color](CaseMapper(_.toLowerCase)))
val et = ensureSerializable(EnumType[ADT.PrimaryColor](CaseMapper(_.toLowerCase)))
assertEquals(et.values, List("blue", "green", "red")) // ADTs are ordered alphabetically
assertEquals(et.from("red"), ADT.Red)
assertEquals(et.to(ADT.Red), "red")
Expand Down
16 changes: 12 additions & 4 deletions test/src/test/scala/magnolify/test/ADT.scala
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,18 @@ object ADT {

@ScalaAnnotation("Color")
sealed trait Color
@ScalaAnnotation("Red")
case object Red extends Color
case object Green extends Color
case object Blue extends Color
@ScalaAnnotation("PrimaryColor")
sealed trait PrimaryColor extends Color
case object Red extends PrimaryColor
case object Green extends PrimaryColor
case object Blue extends PrimaryColor
@ScalaAnnotation("SecondaryColor")
sealed abstract class SecondaryColor(p1: PrimaryColor, p2: PrimaryColor) extends Color {
def primaryColors: Set[PrimaryColor] = Set(p1, p2)
}
case object Yellow extends SecondaryColor(Red, Green)
case object Cyan extends SecondaryColor(Green, Blue)
case object Magenta extends SecondaryColor(Red, Blue)

// This is needed to simulate an error with "no valid constructor"
// exception on attempt to deserialize a case object implementing an abstract class without
Expand Down

0 comments on commit acf0302

Please sign in to comment.