Skip to content

Commit

Permalink
Set avro class-loader on the current thread (#229)
Browse files Browse the repository at this point in the history
Velcity engine loads classes through the thread class loader
  • Loading branch information
RustedBones authored Jan 14, 2025
1 parent ed72061 commit 4fe9ca4
Showing 1 changed file with 23 additions and 20 deletions.
43 changes: 23 additions & 20 deletions plugin/src/main/scala/com/github/sbt/avro/SbtAvro.scala
Original file line number Diff line number Diff line change
Expand Up @@ -251,28 +251,30 @@ object SbtAvro extends AutoPlugin {
.toArray,
this.getClass.getClassLoader
)
val initLoader = Thread.currentThread().getContextClassLoader

val compiler = avroClassLoader
.loadClass(avroCompiler.value)
.getDeclaredConstructor()
.newInstance()
.asInstanceOf[AvroCompiler]

compiler.setStringType(avroStringType.value)
compiler.setFieldVisibility(avroFieldVisibility.value.toUpperCase)
compiler.setEnableDecimalLogicalType(avroEnableDecimalLogicalType.value)
compiler.setCreateSetters(avroCreateSetters.value)
compiler.setOptionalGetters(avroOptionalGetters.value)

val recs = records.map(avroClassLoader.loadClass)
val avdls = srcDirs.flatMap(d => (d ** AvroAvdlFilter).get())
val avscs = srcDirs.flatMap(d => (d ** AvroAvscFilter).get())
val avprs = srcDirs.flatMap(d => (d ** AvroAvrpFilter).get())

out.log.info(
s"Avro compiler ${avroVersion.value} using stringType=${avroStringType.value}"
)
try {
val compiler = avroClassLoader
.loadClass(avroCompiler.value)
.getDeclaredConstructor()
.newInstance()
.asInstanceOf[AvroCompiler]

compiler.setStringType(avroStringType.value)
compiler.setFieldVisibility(avroFieldVisibility.value.toUpperCase)
compiler.setEnableDecimalLogicalType(avroEnableDecimalLogicalType.value)
compiler.setCreateSetters(avroCreateSetters.value)
compiler.setOptionalGetters(avroOptionalGetters.value)

val recs = records.map(avroClassLoader.loadClass)
val avdls = srcDirs.flatMap(d => (d ** AvroAvdlFilter).get())
val avscs = srcDirs.flatMap(d => (d ** AvroAvscFilter).get())
val avprs = srcDirs.flatMap(d => (d ** AvroAvrpFilter).get())

out.log.info(
s"Avro compiler ${avroVersion.value} using stringType=${avroStringType.value}"
)
Thread.currentThread().setContextClassLoader(avroClassLoader)
compiler.recompile(recs.toArray, outDir)
compiler.compileAvscs(avscs.toArray, outDir)
compiler.compileIdls(avdls.toArray, outDir)
Expand All @@ -284,6 +286,7 @@ object SbtAvro extends AutoPlugin {
out.log.err(e.getMessage)
throw new AvroGenerateFailedException
} finally {
Thread.currentThread().setContextClassLoader(initLoader)
avroClassLoader.close()
}
} else {
Expand Down

0 comments on commit 4fe9ca4

Please sign in to comment.