Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix incorrect behavior in lenient tagged union decoders #1620

Merged
merged 2 commits into from
Nov 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ Thank you!
* Adds utility types for working with endpoint handlers (see [#1612](https://github.com/disneystreaming/smithy4s/pull/1612))
* Add a more informative error message for repeated namespaces (see [#1608](https://github.com/disneystreaming/smithy4s/pull/1608)).
* Adds `com.disneystreaming.smithy4s:smithy4s-protocol` dependency to the generation of `smithy-build.json` in the `smithy4sUpdateLSPConfig` tasks of the codegen plugins (see [#1610](https://github.com/disneystreaming/smithy4s/pull/1610)).
* Fix for the lenient union decoding [bug](https://github.com/disneystreaming/smithy4s/issues/1617) (see[#1620](https://github.com/disneystreaming/smithy4s/pull/1620)).

# 0.18.25

Expand Down
143 changes: 53 additions & 90 deletions modules/json/src/smithy4s/json/internals/SchemaVisitorJCodec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -976,30 +976,63 @@ private[smithy4s] class SchemaVisitorJCodec(

private type Writer[A] = A => JsonWriter => Unit

private def taggedUnion[U](
alternatives: Vector[Alt[U, _]]
)(dispatch: Alt.Dispatcher[U]): JCodec[U] =
new JCodec[U] {
val expecting: String = "tagged-union"
private abstract class TaggedUnionJCodec[U](alternatives: Vector[Alt[U, _]])(
dispatch: Alt.Dispatcher[U]
) extends JCodec[U] {

override def canBeKey: Boolean = false
val expecting = "tagged-union"

def jsonLabel[A](alt: Alt[U, A]): String =
alt.hints.get(JsonName) match {
case None => alt.label
case Some(x) => x.value
override def canBeKey: Boolean = false

def jsonLabel[A](alt: Alt[U, A]): String =
alt.hints.get(JsonName) match {
case None => alt.label
case Some(x) => x.value
}

protected val handlerMap =
new util.HashMap[String, (Cursor, JsonReader) => U] {
def handler[A](alt: Alt[U, A]) = {
val codec = apply(alt.schema)
(cursor: Cursor, reader: JsonReader) =>
alt.inject(cursor.decode(codec, reader))
}

private[this] val handlerMap =
new util.HashMap[String, (Cursor, JsonReader) => U] {
def handler[A](alt: Alt[U, A]) = {
val codec = apply(alt.schema)
(cursor: Cursor, reader: JsonReader) =>
alt.inject(cursor.decode(codec, reader))
alternatives.foreach(alt => put(jsonLabel(alt), handler(alt)))
}

protected val precompiler = new smithy4s.schema.Alt.Precompiler[Writer] {
def apply[A](label: String, instance: Schema[A]): Writer[A] = {
val jsonLabel =
instance.hints.get(JsonName).map(_.value).getOrElse(label)
val jcodecA = instance.compile(self)
a =>
out => {
out.writeObjectStart()
out.writeKey(jsonLabel)
jcodecA.encodeValue(a, out)
out.writeObjectEnd()
}
}
}
protected val writer = dispatch.compile(precompiler)

alternatives.foreach(alt => put(jsonLabel(alt), handler(alt)))
}
def encodeValue(u: U, out: JsonWriter): Unit = {
writer(u)(out)
}

def decodeKey(in: JsonReader): U =
in.decodeError("Cannot use coproducts as keys")

def encodeKey(u: U, out: JsonWriter): Unit =
out.encodeError("Cannot use coproducts as keys")

}

private def taggedUnion[U](
alternatives: Vector[Alt[U, _]]
)(dispatch: Alt.Dispatcher[U]): JCodec[U] =
new TaggedUnionJCodec[U](alternatives)(dispatch) {

def decodeValue(cursor: Cursor, in: JsonReader): U =
if (in.isNextToken('{')) {
Expand All @@ -1020,66 +1053,20 @@ private[smithy4s] class SchemaVisitorJCodec(
}
}
} else in.decodeError("Expected JSON object")

val precompiler = new smithy4s.schema.Alt.Precompiler[Writer] {
def apply[A](label: String, instance: Schema[A]): Writer[A] = {
val jsonLabel =
instance.hints.get(JsonName).map(_.value).getOrElse(label)
val jcodecA = instance.compile(self)
a =>
out => {
out.writeObjectStart()
out.writeKey(jsonLabel)
jcodecA.encodeValue(a, out)
out.writeObjectEnd()
}
}
}
val writer = dispatch.compile(precompiler)

def encodeValue(u: U, out: JsonWriter): Unit = {
writer(u)(out)
}

def decodeKey(in: JsonReader): U =
in.decodeError("Cannot use coproducts as keys")

def encodeKey(u: U, out: JsonWriter): Unit =
out.encodeError("Cannot use coproducts as keys")
}

private def lenientTaggedUnion[U](
alternatives: Vector[Alt[U, _]]
)(dispatch: Alt.Dispatcher[U]): JCodec[U] =
new JCodec[U] {
val expecting: String = "tagged-union"

override def canBeKey: Boolean = false

def jsonLabel[A](alt: Alt[U, A]): String =
alt.hints.get(JsonName) match {
case None => alt.label
case Some(x) => x.value
}

private[this] val handlerMap =
new util.HashMap[String, (Cursor, JsonReader) => U] {
def handler[A](alt: Alt[U, A]) = {
val codec = apply(alt.schema)
(cursor: Cursor, reader: JsonReader) =>
alt.inject(cursor.decode(codec, reader))
}

alternatives.foreach(alt => put(jsonLabel(alt), handler(alt)))
}

new TaggedUnionJCodec[U](alternatives)(dispatch) {
def decodeValue(cursor: Cursor, in: JsonReader): U = {
var result: U = null.asInstanceOf[U]
if (in.isNextToken('{')) {
if (!in.isNextToken('}')) {
in.rollbackToken()
while ({
val key = in.readKeyAsString()
cursor.push(key)
val handler = handlerMap.get(key)
if (handler eq null) in.skip()
else if (in.isNextToken('n')) {
Expand All @@ -1103,31 +1090,7 @@ private[smithy4s] class SchemaVisitorJCodec(
}
} else in.decodeError("Expected JSON object")
}
val precompiler = new smithy4s.schema.Alt.Precompiler[Writer] {
def apply[A](label: String, instance: Schema[A]): Writer[A] = {
val jsonLabel =
instance.hints.get(JsonName).map(_.value).getOrElse(label)
val jcodecA = instance.compile(self)
a =>
out => {
out.writeObjectStart()
out.writeKey(jsonLabel)
jcodecA.encodeValue(a, out)
out.writeObjectEnd()
}
}
}
val writer = dispatch.compile(precompiler)

def encodeValue(u: U, out: JsonWriter): Unit = {
writer(u)(out)
}

def decodeKey(in: JsonReader): U =
in.decodeError("Cannot use coproducts as keys")

def encodeKey(u: U, out: JsonWriter): Unit =
out.encodeError("Cannot use coproducts as keys")
}

private def untaggedUnion[U](
Expand Down
38 changes: 38 additions & 0 deletions modules/json/test/src/smithy4s/json/SchemaVisitorJCodecTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,44 @@ class SchemaVisitorJCodecTests() extends FunSuite {
expect.same(readFromString[Either[Int, String]](json2), Left(1))
}

test("Lenient and regular unions have the same error messages") {
val json = """|{
| "left" : {"foo": "b"}
|}
|""".stripMargin

val schema = Schema.either(
Schema
.struct[String](
Schema.string
.required[String]("bar", identity)
)(identity),
Schema
.struct[String](
Schema.string
.required[String]("baz", identity)
)(identity)
)

val regularCodec =
JsoniterCodecCompilerImpl.defaultJsoniterCodecCompiler.fromSchema(schema)
val lenientCodec =
JsoniterCodecCompilerImpl.defaultJsoniterCodecCompiler.withLenientTaggedUnionDecoding
.fromSchema(schema)

def decodeCheck(codec: JsonCodec[Either[String, String]]) =
expect.same(
Try(
readFromString[Either[String, String]](json)(codec)
).toEither.left.map(_.getMessage),
Left("Missing required field (path: .left.bar)")
)

decodeCheck(regularCodec)
decodeCheck(lenientCodec)

}

test("Untagged union are encoded / decoded") {
val oneJ = """ {"three":"three_value"}"""
val twoJ = """ {"four":4}"""
Expand Down