From 4fe9ca4dae4607af63e8a927ba38056d262ed593 Mon Sep 17 00:00:00 2001 From: Michel Davit Date: Tue, 14 Jan 2025 15:18:11 +0100 Subject: [PATCH] Set avro class-loader on the current thread (#229) Velcity engine loads classes through the thread class loader --- .../scala/com/github/sbt/avro/SbtAvro.scala | 43 ++++++++++--------- 1 file changed, 23 insertions(+), 20 deletions(-) diff --git a/plugin/src/main/scala/com/github/sbt/avro/SbtAvro.scala b/plugin/src/main/scala/com/github/sbt/avro/SbtAvro.scala index 1d7d00e..a73395f 100644 --- a/plugin/src/main/scala/com/github/sbt/avro/SbtAvro.scala +++ b/plugin/src/main/scala/com/github/sbt/avro/SbtAvro.scala @@ -251,28 +251,30 @@ object SbtAvro extends AutoPlugin { .toArray, this.getClass.getClassLoader ) + val initLoader = Thread.currentThread().getContextClassLoader - val compiler = avroClassLoader - .loadClass(avroCompiler.value) - .getDeclaredConstructor() - .newInstance() - .asInstanceOf[AvroCompiler] - - compiler.setStringType(avroStringType.value) - compiler.setFieldVisibility(avroFieldVisibility.value.toUpperCase) - compiler.setEnableDecimalLogicalType(avroEnableDecimalLogicalType.value) - compiler.setCreateSetters(avroCreateSetters.value) - compiler.setOptionalGetters(avroOptionalGetters.value) - - val recs = records.map(avroClassLoader.loadClass) - val avdls = srcDirs.flatMap(d => (d ** AvroAvdlFilter).get()) - val avscs = srcDirs.flatMap(d => (d ** AvroAvscFilter).get()) - val avprs = srcDirs.flatMap(d => (d ** AvroAvrpFilter).get()) - - out.log.info( - s"Avro compiler ${avroVersion.value} using stringType=${avroStringType.value}" - ) try { + val compiler = avroClassLoader + .loadClass(avroCompiler.value) + .getDeclaredConstructor() + .newInstance() + .asInstanceOf[AvroCompiler] + + compiler.setStringType(avroStringType.value) + compiler.setFieldVisibility(avroFieldVisibility.value.toUpperCase) + compiler.setEnableDecimalLogicalType(avroEnableDecimalLogicalType.value) + compiler.setCreateSetters(avroCreateSetters.value) + compiler.setOptionalGetters(avroOptionalGetters.value) + + val recs = records.map(avroClassLoader.loadClass) + val avdls = srcDirs.flatMap(d => (d ** AvroAvdlFilter).get()) + val avscs = srcDirs.flatMap(d => (d ** AvroAvscFilter).get()) + val avprs = srcDirs.flatMap(d => (d ** AvroAvrpFilter).get()) + + out.log.info( + s"Avro compiler ${avroVersion.value} using stringType=${avroStringType.value}" + ) + Thread.currentThread().setContextClassLoader(avroClassLoader) compiler.recompile(recs.toArray, outDir) compiler.compileAvscs(avscs.toArray, outDir) compiler.compileIdls(avdls.toArray, outDir) @@ -284,6 +286,7 @@ object SbtAvro extends AutoPlugin { out.log.err(e.getMessage) throw new AvroGenerateFailedException } finally { + Thread.currentThread().setContextClassLoader(initLoader) avroClassLoader.close() } } else {