Skip to content

Commit

Permalink
Fix JSON handling in generator
Browse files Browse the repository at this point in the history
  • Loading branch information
sake92 committed Sep 20, 2024
1 parent eb6c47d commit 4652f48
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 15 deletions.
7 changes: 3 additions & 4 deletions docs/src/files/tutorials/CodeGen.scala
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,14 @@ object CodeGen extends TutorialPage {
```scala
import $$ivy.`ba.sake:squery-generator_2.13:${Consts.ArtifactVersion}`
import $$ivy.`ba.sake::squery:${Consts.ArtifactVersion}`
// if using Postgres JSONB
// import $$ivy.`ba.sake::squery-postgres-jawn:${Consts.ArtifactVersion}`
import $$ivy.`org.postgresql:postgresql:42.7.4`
import $$ivy.`com.zaxxer:HikariCP:5.1.0`

import ba.sake.squery.generator.*
import com.zaxxer.hikari.HikariDataSource

// if using Postgres JSONB
// import $$ivy.`ba.sake::squery-postgres-jawn:${Consts.ArtifactVersion}`
// import ba.sake.squery.postgres.jawn.{*, given}

val dataSource = HikariDataSource()
dataSource.setJdbcUrl("jdbc:postgresql://localhost:5432/mydb")
dataSource.setUsername("username")
Expand Down
10 changes: 9 additions & 1 deletion generator/src/ba/sake/squery/generator/DbDefExtractor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,15 @@ sealed abstract class ColumnType {
def name: String
}
object ColumnType {
case class Predefined(tpe: scala.meta.Type) extends ColumnType {
sealed abstract class ScalarType extends ColumnType {
def tpe: scala.meta.Type
}

case class Predefined(tpe: scala.meta.Type) extends ScalarType {
override def name: String = tpe.toString()
}
// e.g. jawn JSON, needs a custom import
case class ThirdParty(tpe: scala.meta.Type, requiredImports: Seq[String]) extends ScalarType {
override def name: String = tpe.toString()
}
case class Enumeration(name: String, values: Seq[String]) extends ColumnType
Expand Down
18 changes: 14 additions & 4 deletions generator/src/ba/sake/squery/generator/PostgresDefExtractor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,9 @@ class PostgresDefExtractor(ds: DataSource) extends DbDefExtractor(ds) {

}

private def resolveScalarType(tpe: String) = Try {
//

private def resolveScalarType(tpe: String): Try[ColumnType.ScalarType] = Try {
tpe match {
case "boolean" => ColumnType.Predefined(t"Boolean")
case "integer" => ColumnType.Predefined(t"Int")
Expand All @@ -98,9 +100,17 @@ class PostgresDefExtractor(ds: DataSource) extends DbDefExtractor(ds) {
case "timestamp with time zone" => ColumnType.Predefined(t"Instant")
case "uuid" => ColumnType.Predefined(t"UUID")
case "bytea" => ColumnType.Predefined(t"Array[Byte]")
case "json" => ColumnType.Predefined(t"org.typelevel.jawn.ast.JValue")
case "jsonb" => ColumnType.Predefined(t"org.typelevel.jawn.ast.JValue")
case other => throw new RuntimeException(s"Unknown scalar type ${other}")
case "json" =>
ColumnType.ThirdParty(
t"JValue",
Seq("org.typelevel.jawn.ast.JValue", "ba.sake.squery.postgres.jawn.{*, given}")
)
case "jsonb" =>
ColumnType.ThirdParty(
t"JValue",
Seq("org.typelevel.jawn.ast.JValue", "ba.sake.squery.postgres.jawn.{*, given}")
)
case other => throw new RuntimeException(s"Unknown scalar type ${other}")
}
}

Expand Down
31 changes: 25 additions & 6 deletions generator/src/ba/sake/squery/generator/SqueryGenerator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import scala.meta.dialects.Scala34
import ba.sake.regenesca._
import ba.sake.squery.generator.ColumnType.Predefined
import ba.sake.squery.generator.ColumnType.Unknown
import ba.sake.squery.generator.ColumnType.ThirdParty

class SqueryGenerator(ds: DataSource, config: SqueryGeneratorConfig = SqueryGeneratorConfig.Default) {
private val logger = Logger(getClass.getName)
Expand All @@ -33,11 +34,18 @@ class SqueryGenerator(ds: DataSource, config: SqueryGeneratorConfig = SqueryGene
dbDef.schemas.find(_.name == schemaName) match {
case Some(schemaDef) =>
logger.info(s"Started generating schema '${schemaName}'")
val requiredImports = schemaDef.tables
.flatMap(_.columnDefs.map(_.scalaType).collect { case ThirdParty(_, requiredImports) =>
requiredImports
})
.flatten
.distinct
val (modelFiles, daoFiles) =
generateSchema(schemaDef, dbType = dbDef.tpe, basePackage = "", fileGen = false)
// models first because of Ammonite eval order!
val allFiles = modelFiles ++ daoFiles
val allSources = generateBaseImports(dbDef.tpe).map(_.syntax) ++ allFiles.map(_.source.syntax)
val allSources =
generateModelImports(dbDef.tpe, requiredImports).map(_.syntax) ++ allFiles.map(_.source.syntax)
logger.info(s"Finished generating schema '${schemaName}'")
allSources.mkString("\n")
case None =>
Expand Down Expand Up @@ -204,14 +212,20 @@ class SqueryGenerator(ds: DataSource, config: SqueryGeneratorConfig = SqueryGene
// object goes first, coz class references PK type
// so that it works in ammonite where definitions are parsed 1 by 1
val source =
if (fileGen)
if (fileGen) {
val requiredImports = schemaDef.tables
.flatMap(_.columnDefs.map(_.scalaType).collect { case ThirdParty(_, requiredImport) =>
requiredImport
})
.flatten
.distinct
source"""
package ${generatePkgSelect(s"${basePackage}.models")}
..${generateBaseImports(dbType)}
..${generateModelImports(dbType, requiredImports)}
${objectDefn}
${caseClassDefn}
"""
else
} else
source"""
${objectDefn}
${caseClassDefn}
Expand Down Expand Up @@ -542,7 +556,7 @@ class SqueryGenerator(ds: DataSource, config: SqueryGeneratorConfig = SqueryGene
}
}

private def generateBaseImports(dbType: DbType) = {
private def generateBaseImports(dbType: DbType): List[Import] = {
val dbSpecificImporter = s"ba.sake.squery.${dbType.squeryPackage}.{*, given}".parse[Importer].get
List(
q"import java.time.*",
Expand All @@ -553,7 +567,11 @@ class SqueryGenerator(ds: DataSource, config: SqueryGeneratorConfig = SqueryGene
q"import ..${List(dbSpecificImporter)}"
)
}
private def generateDaoImports(dbType: DbType, basePackage: String) = {
private def generateModelImports(dbType: DbType, additionalImportsStr: Seq[String]): List[Import] = {
val additionalImports = additionalImportsStr.map(_.parse[Importer].get).toList
generateBaseImports(dbType).appended(q"import ..${additionalImports}")
}
private def generateDaoImports(dbType: DbType, basePackage: String): List[Import] = {
val modelsImporter = s"${basePackage}.models.*".parse[Importer].get
val modelsImport = q"import ..${List(modelsImporter)}"
generateBaseImports(dbType) ++ List(modelsImport)
Expand All @@ -568,6 +586,7 @@ class SqueryGenerator(ds: DataSource, config: SqueryGeneratorConfig = SqueryGene
implicit class ColumnTypeOps(tpe: ColumnType) {
def safeTypeName: String = tpe match {
case p: ColumnType.Predefined => p.name
case p: ColumnType.ThirdParty => p.name
case ColumnType.Enumeration(enumName, _) => enumName.safeTypeName
case ColumnType.Unknown(originalName) => s"<UNKNOWN> // ${originalName}"
}
Expand Down

0 comments on commit 4652f48

Please sign in to comment.