Skip to content

Commit

Permalink
Fix incorrect behavior in lenient tagged union decoders (#1620)
Browse files Browse the repository at this point in the history
* Fix incorrect behavior in lenient tagged union decoders

* Rearrange things + add entry in CHANGELOG
  • Loading branch information
msosnicki authored Nov 8, 2024
1 parent 31e6188 commit e369dd4
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 90 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ Thank you!
* Adds utility types for working with endpoint handlers (see [#1612](https://github.com/disneystreaming/smithy4s/pull/1612))
* Add a more informative error message for repeated namespaces (see [#1608](https://github.com/disneystreaming/smithy4s/pull/1608)).
* Adds `com.disneystreaming.smithy4s:smithy4s-protocol` dependency to the generation of `smithy-build.json` in the `smithy4sUpdateLSPConfig` tasks of the codegen plugins (see [#1610](https://github.com/disneystreaming/smithy4s/pull/1610)).
* Fix for the lenient union decoding [bug](https://github.com/disneystreaming/smithy4s/issues/1617) (see[#1620](https://github.com/disneystreaming/smithy4s/pull/1620)).

# 0.18.25

Expand Down
143 changes: 53 additions & 90 deletions modules/json/src/smithy4s/json/internals/SchemaVisitorJCodec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -976,30 +976,63 @@ private[smithy4s] class SchemaVisitorJCodec(

private type Writer[A] = A => JsonWriter => Unit

private def taggedUnion[U](
alternatives: Vector[Alt[U, _]]
)(dispatch: Alt.Dispatcher[U]): JCodec[U] =
new JCodec[U] {
val expecting: String = "tagged-union"
private abstract class TaggedUnionJCodec[U](alternatives: Vector[Alt[U, _]])(
dispatch: Alt.Dispatcher[U]
) extends JCodec[U] {

override def canBeKey: Boolean = false
val expecting = "tagged-union"

def jsonLabel[A](alt: Alt[U, A]): String =
alt.hints.get(JsonName) match {
case None => alt.label
case Some(x) => x.value
override def canBeKey: Boolean = false

def jsonLabel[A](alt: Alt[U, A]): String =
alt.hints.get(JsonName) match {
case None => alt.label
case Some(x) => x.value
}

protected val handlerMap =
new util.HashMap[String, (Cursor, JsonReader) => U] {
def handler[A](alt: Alt[U, A]) = {
val codec = apply(alt.schema)
(cursor: Cursor, reader: JsonReader) =>
alt.inject(cursor.decode(codec, reader))
}

private[this] val handlerMap =
new util.HashMap[String, (Cursor, JsonReader) => U] {
def handler[A](alt: Alt[U, A]) = {
val codec = apply(alt.schema)
(cursor: Cursor, reader: JsonReader) =>
alt.inject(cursor.decode(codec, reader))
alternatives.foreach(alt => put(jsonLabel(alt), handler(alt)))
}

protected val precompiler = new smithy4s.schema.Alt.Precompiler[Writer] {
def apply[A](label: String, instance: Schema[A]): Writer[A] = {
val jsonLabel =
instance.hints.get(JsonName).map(_.value).getOrElse(label)
val jcodecA = instance.compile(self)
a =>
out => {
out.writeObjectStart()
out.writeKey(jsonLabel)
jcodecA.encodeValue(a, out)
out.writeObjectEnd()
}
}
}
protected val writer = dispatch.compile(precompiler)

alternatives.foreach(alt => put(jsonLabel(alt), handler(alt)))
}
def encodeValue(u: U, out: JsonWriter): Unit = {
writer(u)(out)
}

def decodeKey(in: JsonReader): U =
in.decodeError("Cannot use coproducts as keys")

def encodeKey(u: U, out: JsonWriter): Unit =
out.encodeError("Cannot use coproducts as keys")

}

private def taggedUnion[U](
alternatives: Vector[Alt[U, _]]
)(dispatch: Alt.Dispatcher[U]): JCodec[U] =
new TaggedUnionJCodec[U](alternatives)(dispatch) {

def decodeValue(cursor: Cursor, in: JsonReader): U =
if (in.isNextToken('{')) {
Expand All @@ -1020,66 +1053,20 @@ private[smithy4s] class SchemaVisitorJCodec(
}
}
} else in.decodeError("Expected JSON object")

val precompiler = new smithy4s.schema.Alt.Precompiler[Writer] {
def apply[A](label: String, instance: Schema[A]): Writer[A] = {
val jsonLabel =
instance.hints.get(JsonName).map(_.value).getOrElse(label)
val jcodecA = instance.compile(self)
a =>
out => {
out.writeObjectStart()
out.writeKey(jsonLabel)
jcodecA.encodeValue(a, out)
out.writeObjectEnd()
}
}
}
val writer = dispatch.compile(precompiler)

def encodeValue(u: U, out: JsonWriter): Unit = {
writer(u)(out)
}

def decodeKey(in: JsonReader): U =
in.decodeError("Cannot use coproducts as keys")

def encodeKey(u: U, out: JsonWriter): Unit =
out.encodeError("Cannot use coproducts as keys")
}

private def lenientTaggedUnion[U](
alternatives: Vector[Alt[U, _]]
)(dispatch: Alt.Dispatcher[U]): JCodec[U] =
new JCodec[U] {
val expecting: String = "tagged-union"

override def canBeKey: Boolean = false

def jsonLabel[A](alt: Alt[U, A]): String =
alt.hints.get(JsonName) match {
case None => alt.label
case Some(x) => x.value
}

private[this] val handlerMap =
new util.HashMap[String, (Cursor, JsonReader) => U] {
def handler[A](alt: Alt[U, A]) = {
val codec = apply(alt.schema)
(cursor: Cursor, reader: JsonReader) =>
alt.inject(cursor.decode(codec, reader))
}

alternatives.foreach(alt => put(jsonLabel(alt), handler(alt)))
}

new TaggedUnionJCodec[U](alternatives)(dispatch) {
def decodeValue(cursor: Cursor, in: JsonReader): U = {
var result: U = null.asInstanceOf[U]
if (in.isNextToken('{')) {
if (!in.isNextToken('}')) {
in.rollbackToken()
while ({
val key = in.readKeyAsString()
cursor.push(key)
val handler = handlerMap.get(key)
if (handler eq null) in.skip()
else if (in.isNextToken('n')) {
Expand All @@ -1103,31 +1090,7 @@ private[smithy4s] class SchemaVisitorJCodec(
}
} else in.decodeError("Expected JSON object")
}
val precompiler = new smithy4s.schema.Alt.Precompiler[Writer] {
def apply[A](label: String, instance: Schema[A]): Writer[A] = {
val jsonLabel =
instance.hints.get(JsonName).map(_.value).getOrElse(label)
val jcodecA = instance.compile(self)
a =>
out => {
out.writeObjectStart()
out.writeKey(jsonLabel)
jcodecA.encodeValue(a, out)
out.writeObjectEnd()
}
}
}
val writer = dispatch.compile(precompiler)

def encodeValue(u: U, out: JsonWriter): Unit = {
writer(u)(out)
}

def decodeKey(in: JsonReader): U =
in.decodeError("Cannot use coproducts as keys")

def encodeKey(u: U, out: JsonWriter): Unit =
out.encodeError("Cannot use coproducts as keys")
}

private def untaggedUnion[U](
Expand Down
38 changes: 38 additions & 0 deletions modules/json/test/src/smithy4s/json/SchemaVisitorJCodecTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,44 @@ class SchemaVisitorJCodecTests() extends FunSuite {
expect.same(readFromString[Either[Int, String]](json2), Left(1))
}

test("Lenient and regular unions have the same error messages") {
val json = """|{
| "left" : {"foo": "b"}
|}
|""".stripMargin

val schema = Schema.either(
Schema
.struct[String](
Schema.string
.required[String]("bar", identity)
)(identity),
Schema
.struct[String](
Schema.string
.required[String]("baz", identity)
)(identity)
)

val regularCodec =
JsoniterCodecCompilerImpl.defaultJsoniterCodecCompiler.fromSchema(schema)
val lenientCodec =
JsoniterCodecCompilerImpl.defaultJsoniterCodecCompiler.withLenientTaggedUnionDecoding
.fromSchema(schema)

def decodeCheck(codec: JsonCodec[Either[String, String]]) =
expect.same(
Try(
readFromString[Either[String, String]](json)(codec)
).toEither.left.map(_.getMessage),
Left("Missing required field (path: .left.bar)")
)

decodeCheck(regularCodec)
decodeCheck(lenientCodec)

}

test("Untagged union are encoded / decoded") {
val oneJ = """ {"three":"three_value"}"""
val twoJ = """ {"four":4}"""
Expand Down

0 comments on commit e369dd4

Please sign in to comment.