Skip to content

Commit

Permalink
Support unions in elm (#705)
Browse files Browse the repository at this point in the history
  • Loading branch information
mbryzek authored Oct 1, 2024
1 parent 6bb611f commit 3fa0ce9
Show file tree
Hide file tree
Showing 9 changed files with 169 additions and 121 deletions.
15 changes: 13 additions & 2 deletions elm-generator/src/main/scala/generator/elm/ElmGenerator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,9 @@ case class ElmGenerator() {
ElmCommon(args).generate().validNec,
generateEnums(args).validNec,
generateModels(args),
generateUnions(args),
generateResources(args)
).mapN { case (a,b,c,d) => Seq(a,b,c,d) }.map { contents =>
).mapN { case (a,b,c,d,e) => Seq(a,b,c,d,e) }.map { contents =>
Seq(File(
name = s"Generated/" + pascalServiceName(service) + ".elm",
contents = generate(
Expand All @@ -53,15 +54,25 @@ case class ElmGenerator() {
s"module Generated.${pascalServiceName(service)} exposing (..)",
args.imports.generateCode(), // must be generated after the contents
args.functions.generateCode(),
contents.mkString("\n\n")
trimTrailing(contents.filterNot(_.isEmpty).mkString("\n\n"))
).mkString("\n\n")
}

private def trimTrailing(str: String): String = {
str.split("\n").toSeq.map(_.stripTrailing).mkString("\n").stripTrailing
}


private[elm] def generateModels(args: GenArgs): ValidatedNec[String, String] = {
val models = ElmModel(args)
args.service.models.map(models.generate).sequence.map(_.mkString("\n\n"))
}

private[elm] def generateUnions(args: GenArgs): ValidatedNec[String, String] = {
val unions = ElmUnion(args)
args.service.unions.map(unions.generate).sequence.map(_.mkString("\n\n"))
}

private def generateEnums(args: GenArgs): String = {
val enums = ElmEnum(args)
args.service.enums.map(enums.generate).mkString("\n\n")
Expand Down
6 changes: 5 additions & 1 deletion elm-generator/src/main/scala/generator/elm/ElmJson.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,12 @@ case class ElmJson(imports: Imports) {
}


def decoderName(name: String): String = {
s"${Names.camelCase(name)}Decoder"
}

def decoder(name: String)(contents: String): ElmFunction = {
val n = s"${Names.camelCase(name)}Decoder"
val n = decoderName(name)
imports.addAs("Json.Decode", "Decode")
val code = Seq(
s"$n : Decode.Decoder ${Names.pascalCase(name)}",
Expand Down
65 changes: 65 additions & 0 deletions elm-generator/src/main/scala/generator/elm/ElmUnion.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
package generator.elm

import cats.data.ValidatedNec
import cats.implicits._
import io.apibuilder.spec.v0.models.Union

case class ElmUnion(args: GenArgs) {
private val elmJson = ElmJson(args.imports)

def generate(union: Union): ValidatedNec[String, String] = {
genDecoder(union).map { decoder =>
(
Seq(genTypeAlias(union)) ++ Seq(decoder.code)
).mkString("\n\n")
}
}

private def genTypeAlias(union: Union): String = {
val unionName = Names.pascalCase(union.name)
Seq(
s"type $unionName =",
union.types.map { t =>
val typeName = Names.pascalCase(t.`type`)
s"$unionName${typeName} $typeName"
}.mkString("\n| ").indent(4).stripTrailing(),
).mkString("\n")
}

private def genDecoder(m: Union): ValidatedNec[String, ElmFunction] = {
m.discriminator match {
case None => "Only union types with discriminators are currently supported".invalidNec
case Some(disc) => genDecoderType(m, disc).validNec
}
}

private def decoderByDiscriminatorName(u: Union, disc: String): String = {
elmJson.decoderName(u.name) + "By" + Names.pascalCase(disc)
}

private def genDecoderType(u: Union, disc: String): ElmFunction = {
args.imports.addAs("Json.Decode", "Decode")
elmJson.decoder(u.name) {
Seq(
s"Decode.field ${Names.wrapInQuotes(disc)} Decode.string",
s" |> Decode.andThen (\\disc ->",
s" case disc of",
genDecoderDiscriminator(u, disc).indent(12).stripTrailing,
s" )"
).mkString("\n")
}
}

private def genDecoderDiscriminator(u: Union, disc: String): String = {
val unionName = Names.pascalCase(u.name)
val all = u.types.map { t =>
s"""
|${Names.wrapInQuotes(t.`type`)} ->
| ${elmJson.decoderName(t.`type`)} |> Decode.map $unionName${Names.pascalCase(t.`type`)}
|""".stripMargin.strip
} ++ Seq(s"_ ->\n Decode.fail (\"Unknown ${Names.maybeQuote(disc)}: \" ++ disc)")
all.mkString("\n\n").strip()
}

}

5 changes: 3 additions & 2 deletions generator/app/controllers/Invocations.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package controllers

import io.apibuilder.generator.v0.models.json._
import io.apibuilder.generator.v0.models.{Invocation, InvocationForm}
import lib.Validation
import lib.{ServiceApidocBug, Validation}
import play.api.libs.json._
import play.api.mvc._

Expand All @@ -17,7 +17,8 @@ class Invocations extends InjectedController {
def postByKey(key: String): Action[AnyContent] = Action { request =>
request.body.asJson match {
case None => Conflict(Json.toJson(Validation.error("Must provide form data (JSON)")))
case Some(js) => {
case Some(incomingJs) => {
val js = ServiceApidocBug.rewrite(incomingJs)
Generators.findGenerator(key).map(_.generator) match {
case Some(generator) =>
js.validate[InvocationForm] match {
Expand Down
63 changes: 63 additions & 0 deletions generator/app/lib/ServiceApidocBug.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package lib

import play.api.libs.json.{JsArray, JsObject, JsDefined, JsValue, Json}

/**
* Version 0.4.28 of apibuilder-validation has an invalid Service spec definition which makes
* the apidoc field required. This dependency is pulled in by apibuilder-graphql. Once we can
* migrate the generator to scala3 we can pull in the latest dependencies. Until then to keep
* things working, we inject a default "apidoc" node where not specified.
*/
object ServiceApidocBug {
private val DefaultApiDoc = Json.obj(
"apidoc" -> Json.obj(
"version" -> "0.16.0"
)
)

def rewrite(js: JsValue): JsValue = {
js match {
case o: JsObject => rewriteObject(o)
case _ => js
}
}

private def rewriteObject(js: JsObject): JsObject = {
rewriteImportedServices(
rewriteService(js)
)
}

private def rewriteService(js: JsObject): JsObject = {
js \ "service" match {
case JsDefined(svc: JsObject) => {
js ++ Json.obj("service" -> maybeAddApidoc(svc))
}
case _ => js
}
}

private def rewriteImportedServices(js: JsObject): JsObject = {
js \ "imported_services" match {
case JsDefined(svc: JsArray) => {
js ++ Json.obj(
"imported_services" -> JsArray(
svc.value.toSeq.map {
case o: JsObject => maybeAddApidoc(o)
case o => o
}
)
)
}
case _ => js
}
}

private def maybeAddApidoc(js: JsObject): JsObject = {
if (js.fields.contains("apidoc")) {
js
} else {
js ++ DefaultApiDoc
}
}
}
25 changes: 0 additions & 25 deletions scala-generator/src/main/scala/models/FeatureMigration.scala

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package scala.generator

import scala.annotation.nowarn

import scala.models.{FeatureMigration, JsonImports}
import scala.models.JsonImports
import io.apibuilder.spec.v0.models.{ResponseCodeInt, ResponseCodeOption, ResponseCodeUndefinedType}
import lib.Text._

Expand All @@ -17,9 +17,6 @@ class ScalaClientMethodGenerator(

protected val sortedResources: Seq[ScalaResource] = ssd.resources.sortWith { _.plural.toLowerCase < _.plural.toLowerCase }

@nowarn("msg=value apidoc in class Service is deprecated")
protected val featureMigration: FeatureMigration = FeatureMigration(ssd.service.apidoc)

def traitsAndErrors(): String = {
Seq(
interfaces(),
Expand Down Expand Up @@ -112,7 +109,7 @@ class ScalaClientMethodGenerator(

protected def includeJsonImportsInErrorsPackage: Boolean = true

protected def errorTypeClass(response: ScalaResponse): String = {
private def errorTypeClass(response: ScalaResponse): String = {
require(!response.isSuccess)

if (response.isUnit) {
Expand All @@ -128,7 +125,7 @@ class ScalaClientMethodGenerator(
}
}

@nowarn protected def exceptionClass(
@nowarn private def exceptionClass(
className: String,
body: Option[String] = None
): String = {
Expand Down Expand Up @@ -185,25 +182,14 @@ class ScalaClientMethodGenerator(
case v => s"""${v.mkString("\n\n")}\n\n_executeRequest("${op.method}", $path, ${args.mkString(", ")})"""
}

val hasOptionResult = if (featureMigration.hasImplicit404s()) {
op.responses.filter(_.isSuccess).find(_.isOption).map { _ =>
s"\ncase r if r.${config.responseStatusMethod} == 404 => None"
}
} else {
None
}
val hasOptionResult = None

val allResponseCodes = (
op.responses.flatMap { r =>
r.code match {
case ResponseCodeInt(value) => Some(value)
case ResponseCodeOption.Default | ResponseCodeOption.UNDEFINED(_) | ResponseCodeUndefinedType(_) => None
}
} ++ (hasOptionResult match {
case None => Seq.empty
case Some(_) => Seq(404)
})
).distinct.sorted
val allResponseCodes = op.responses.flatMap { r =>
r.code match {
case ResponseCodeInt(value) => Some(value)
case ResponseCodeOption.Default | ResponseCodeOption.UNDEFINED(_) | ResponseCodeUndefinedType(_) => None
}
}.distinct.sorted

val defaultResponse = op.responses.find { r =>
r.code match {
Expand All @@ -228,26 +214,14 @@ class ScalaClientMethodGenerator(
response.code match {
case ResponseCodeInt(statusCode) => {
if (response.isSuccess) {
if (featureMigration.hasImplicit404s() && response.isOption) {
if (response.isUnit) {
Some(s"case r if r.${config.responseStatusMethod} == $statusCode => Some($unitResponse")
} else {
val result = config.buildResponse("r", response.datatype.name)
Some(s"case r if r.${config.responseStatusMethod} == $statusCode => Some($result)")
}

} else if (response.isUnit) {
if (response.isUnit) {
Some(s"case r if r.${config.responseStatusMethod} == $statusCode => $unitResponse")

} else {
val result = config.buildResponse("r", response.datatype.name)
Some(s"case r if r.${config.responseStatusMethod} == $statusCode => $result")
}

} else if (featureMigration.hasImplicit404s() && response.isNotFound && response.isOption) {
// will be added later
None

} else {
if (response.isUnit) {
Some(s"case r if r.${config.responseStatusMethod} == $statusCode => throw ${namespaces.errors}.${response.errorClassName}(r.${config.responseStatusMethod})")
Expand Down
Loading

0 comments on commit 3fa0ce9

Please sign in to comment.