Skip to content

Commit

Permalink
fix: add unique ids for queries (#25)
Browse files Browse the repository at this point in the history
Close #24
  • Loading branch information
tamimattafi authored Jun 3, 2024
1 parent 7404e5e commit 9b2cc56
Show file tree
Hide file tree
Showing 11 changed files with 314 additions and 120 deletions.
51 changes: 51 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
name: Test

on:
push:
branches: [ main ]
pull_request:
branches: [ main ]
workflow_dispatch:
inputs:
logLevel:
description: 'Log Level'
required: false
default: 'warning'
type: choice
options:
- info
- warning
- debug

concurrency:
cancel-in-progress: true
group: ${{ github.workflow }}-${{ github.ref }}

permissions:
contents: read
checks: write
id-token: write

jobs:
test-library:
name: Run Library Tests
runs-on: macos-latest
steps:
- name: Check out code
uses: actions/checkout@v2
- name: Set up JDK 17
uses: actions/setup-java@v2
with:
distribution: adopt
java-version: 17
- name: Run Compiler Tests
run: ./gradlew library:compiler:allTests
- name: Publish Test Results
uses: mikepenz/action-junit-report@v4
if: success() || failure()
with:
report_paths: '**/build/test-results/**/TEST-*.xml'
include_passed: true
fail_on_failure: true
annotate_notice: true
follow_symlink: true
4 changes: 4 additions & 0 deletions convention/multiplatform/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@ plugins {
id(libs.plugins.java.gradle.plugin.get().pluginId)
}

kotlin {
jvmToolchain(17)
}

gradlePlugin {
plugins.create("multiplatform") {
id = "com.attafitamim.kabin.multiplatform"
Expand Down
4 changes: 4 additions & 0 deletions convention/publishing/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@ plugins {
id(libs.plugins.java.gradle.plugin.get().pluginId)
}

kotlin {
jvmToolchain(17)
}

gradlePlugin {
plugins.create("publish") {
id = "com.attafitamim.kabin.publish"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import org.gradle.api.publish.maven.MavenPomScm

class PublishConventions : Plugin<Project> {

private val version = "0.1.0-alpha10"
private val version = "0.1.0-alpha11"
private val group = "com.attafitamim.kabin"

override fun apply(project: Project) {
Expand Down
4 changes: 4 additions & 0 deletions library/compiler/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ kotlin {
implementation(libs.kotlin.ksp)
implementation(libs.sqldelight.runtime)
}

jvmTest.dependencies {
implementation(libs.junit)
}
}

java {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -514,8 +514,15 @@ class DaoGenerator(
val entityColumnAccess = entitySpec
.getColumnAccessChain(compoundRelationSpec.relation.entityColumn)

val functionName = entitySpec.getQueryByColumnsName(
entityColumnAccess.last(),
val directEntityColumn = entityColumnAccess.last()

val entityQuery = getSelectSQLQuery(
entitySpec,
directEntityColumn
)

val functionName = entitySpec.getQueryFunctionName(
entityQuery,
isNullable = false,
parent = null
)
Expand All @@ -524,20 +531,26 @@ class DaoGenerator(
if (junctionSpec != null) {
val junctionParentAccess = junctionSpec.entitySpec
.getColumnAccessChain(junctionSpec.parentColumn)
.last()

val junctionQuery = getSelectSQLQuery(
junctionSpec.entitySpec,
junctionParentAccess
)

val junctionEntityAccess = junctionSpec.entitySpec
.getColumnAccessChain(junctionSpec.entityColumn)

val awaitFunction = "awaitAsOneNotNullIO"
val directJunctionColumn = junctionEntityAccess.last()
val directEntityColumn = entityColumnAccess.last()
val adapter = directJunctionColumn
.getAdapterReference(directEntityColumn)

adapter?.let(adapters::add)

val junctionElement = "junction"
val junctionFunctionName = junctionSpec.entitySpec.getQueryByColumnsName(
junctionParentAccess.last(),
val junctionFunctionName = junctionSpec.entitySpec.getQueryFunctionName(
junctionQuery,
isNullable = false,
parent = null
)
Expand All @@ -561,7 +574,6 @@ class DaoGenerator(
val awaitFunction = property.dataTypeSpec.getAwaitFunction()

val directParentColumn = parentColumnAccess.last()
val directEntityColumn = entityColumnAccess.last()
val adapter = directParentColumn
.getAdapterReference(directEntityColumn)

Expand Down Expand Up @@ -689,15 +701,19 @@ class DaoGenerator(
val directEntityColumn = entityColumnAccess.last()
val directEntityColumnName = directEntityColumn.declaration.simpleNameString
if (junctionSpec != null) {
val junctionParentColumn = junctionSpec.entitySpec
val junctionColumn = junctionSpec.entitySpec
.getColumnAccessChain(junctionSpec.parentColumn)
.last()

actualFunctionName = junctionSpec.entitySpec.getQueryByColumnsName(
junctionParentColumn.toSortedSet(),
junctionColumn,
isNullable = false,
parent = null
)

val junctionParentColumn = junctionSpec.entitySpec
.getColumnAccessChain(junctionSpec.parentColumn)

val directJunctionParentColumn = junctionParentColumn.last()
val directJunctionParentColumnName = directJunctionParentColumn.declaration.simpleNameString
parameterReference = ParameterReference(
Expand All @@ -710,10 +726,9 @@ class DaoGenerator(
directEntityColumn.declaration.type.toTypeName()
)


actualFunctionName = if (relationSpec.property.dataTypeSpec.dataType is DataTypeSpec.DataType.Collection) {
val functionName = compoundSpec.declaration.getQueryByColumnsName(
entityColumnAccess.toSortedSet(),
entityColumnAccess,
isNullable = false,
parent = null
)
Expand All @@ -723,7 +738,7 @@ class DaoGenerator(
}
} else {
compoundReturnType.getQueryByColumnsName(
entityColumnAccess.toSortedSet(),
entityColumnAccess,
parent = null
)
}
Expand All @@ -747,6 +762,12 @@ class DaoGenerator(
if (junctionSpec != null) {
val junctionParentAccess = junctionSpec.entitySpec
.getColumnAccessChain(junctionSpec.parentColumn)
.last()

val junctionQuery = getSelectSQLQuery(
junctionSpec.entitySpec,
junctionParentAccess
)

val junctionEntityAccess = junctionSpec.entitySpec
.getColumnAccessChain(junctionSpec.entityColumn)
Expand All @@ -758,8 +779,8 @@ class DaoGenerator(
adapter?.let(adapters::add)

val junctionElement = "junction"
val junctionFunctionName = junctionSpec.entitySpec.getQueryByColumnsName(
junctionParentAccess.last(),
val junctionFunctionName = junctionSpec.entitySpec.getQueryFunctionName(
junctionQuery,
isNullable = false,
parent = null
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ sealed interface SQLQuery {
data class Parameters(
override val value: String,
override val parametersSize: Int,
val queryParameters: Collection<QueryParameter>,
val queryParameters: List<QueryParameter>,
val mutatedKeys: Set<String>,
override val queriedKeys: Set<String>
): SQLQuery {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,53 @@ import com.attafitamim.kabin.specs.dao.DataTypeSpec
import com.squareup.kotlinpoet.CodeBlock
import com.squareup.kotlinpoet.FunSpec


const val EXECUTE_FUNCTION = "execute"
const val EXECUTE_QUERY_FUNCTION = "executeQuery"

private val generatedIds = HashMap<String, Int>()
private val generatedHashCodes = HashMap<Int, String>()

private const val MAX_UNIQUE_CYCLES = 100
private const val CYCLES_START_INDEX = 0
private const val HASH_CODE_INTERVAL = 31

fun SQLQuery.getQueryIdentifier(): Int? = when (this) {
is SQLQuery.Columns -> value.getUniqueQueryIdentifier()
is SQLQuery.Parameters -> value.getUniqueQueryIdentifier()
is SQLQuery.Raw -> null
}

// TODO: optimize this
fun String.getUniqueQueryIdentifier(): Int? {
val generatedId = generatedIds[this]
if (generatedId != null) {
return generatedId
}

var cycle = CYCLES_START_INDEX
while (cycle <= MAX_UNIQUE_CYCLES) {
val originalHashCode = hashCode()
val hashCode = if (cycle == CYCLES_START_INDEX) {
originalHashCode
} else {
originalHashCode * (cycle * HASH_CODE_INTERVAL)
}

val valueLinkedToHashCode = generatedHashCodes[hashCode]
if (valueLinkedToHashCode == null || valueLinkedToHashCode == this) {
generatedIds[this] = hashCode
generatedHashCodes[hashCode] = this
return hashCode
}

cycle++
}

return null
}


fun FunSpec.Builder.addDriverQueryCode(
query: SQLQuery,
function: String = EXECUTE_QUERY_FUNCTION,
Expand Down Expand Up @@ -56,21 +100,22 @@ fun FunSpec.Builder.addDriverRawQueryCode(
function: String,
binderCode: (CodeBlock.Builder.() -> Unit)? = null
): FunSpec.Builder = apply {
val identifier = query.getQueryIdentifier()
val logic = if (function == EXECUTE_QUERY_FUNCTION) {
"""
|val result = driver.executeQuery(
| null,
| %L,
| mapper,
| 0
| identifier = $identifier,
| sql = %L,
| mapper = mapper,
| parameters = 0
|)
""".trimMargin()
} else {
"""
|driver.execute(
| null,
| %L,
| 0
| identifier = $identifier,
| sql = %L,
| parameters = 0
|)
""".trimMargin()
}
Expand Down Expand Up @@ -142,30 +187,31 @@ fun FunSpec.Builder.addDriverQueryCode(

sizeExpression.append(simpleParametersSize)
val codeBlockBuilder = CodeBlock.builder()
.addStatement("val kabinQuery = %P", query.value)
.addStatement("val kabinParametersCount = $sizeExpression")
.addStatement("val internalQuerySql = %P", query.value)
.addStatement("val internalQueryParametersCount = $sizeExpression")

val originalIdentifier = query.getQueryIdentifier()
val identifier = if (addedConstants.isEmpty()) {
query.value.hashCode()
originalIdentifier
} else {
"kabinQuery.hashCode()"
null
}

val logic = if (function == "executeQuery") {
"""
|val result = driver.executeQuery(
| $identifier,
| kabinQuery,
| mapper,
| kabinParametersCount
| identifier = $identifier,
| sql = internalQuerySql,
| mapper = mapper,
| parameters = internalQueryParametersCount
|)
""".trimMargin()
} else {
"""
|driver.execute(
| $identifier,
| kabinQuery,
| kabinParametersCount
| identifier = $identifier,
| sql = internalQuerySql,
| parameters = internalQueryParametersCount
|)
""".trimMargin()
}
Expand All @@ -177,7 +223,8 @@ fun FunSpec.Builder.addDriverQueryCode(
}

if (query.mutatedKeys.isNotEmpty()) {
codeBlockBuilder.beginControlFlow("notifyQueries($identifier) { emit ->")
val notifyIdentifier = identifier ?: -1
codeBlockBuilder.beginControlFlow("notifyQueries($notifyIdentifier) { emit ->")

query.mutatedKeys.forEach { key ->
codeBlockBuilder.addStatement("emit(%S)", key)
Expand All @@ -200,22 +247,22 @@ fun FunSpec.Builder.addDriverQueryCode(
) = apply {
val codeBlockBuilder = CodeBlock.builder()

val identifier = query.hashCode()
val identifier = query.getQueryIdentifier()
val logic = if (function == EXECUTE_QUERY_FUNCTION) {
"""
|val result = driver.executeQuery(
| $identifier,
| %P,
| mapper,
| ${query.parametersSize}
| identifier = $identifier,
| sql = %P,
| mapper = mapper,
| parameters = ${query.parametersSize}
|)
""".trimMargin()
} else {
"""
|driver.execute(
| $identifier,
| %P,
| ${query.parametersSize}
| identifier = $identifier,
| sql = %P,
| parameters = ${query.parametersSize}
|)
""".trimMargin()
}
Expand Down
Loading

0 comments on commit 9b2cc56

Please sign in to comment.