Skip to content

Commit

Permalink
added encoding for DatePeriod, DateTimePeriod, Instant, LocalDateTime…
Browse files Browse the repository at this point in the history
…, and LocalDate, Duration not working
  • Loading branch information
Jolanrensen committed Apr 7, 2024
1 parent 48db819 commit ab4c455
Show file tree
Hide file tree
Showing 8 changed files with 288 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ import org.apache.spark.sql.types.UserDefinedType
import org.apache.spark.unsafe.types.CalendarInterval
import org.jetbrains.kotlinx.spark.api.plugin.annotations.ColumnName
import org.jetbrains.kotlinx.spark.api.plugin.annotations.Sparkify
import org.jetbrains.kotlinx.spark.api.udts.DatePeriodUdt
import org.jetbrains.kotlinx.spark.api.udts.DateTimePeriodUdt
import org.jetbrains.kotlinx.spark.api.udts.InstantUdt
import org.jetbrains.kotlinx.spark.api.udts.LocalDateTimeUdt
import org.jetbrains.kotlinx.spark.api.udts.LocalDateUdt
import scala.reflect.ClassTag
import java.io.Serializable
import java.util.*
Expand Down Expand Up @@ -170,12 +175,14 @@ object KotlinTypeInference : Serializable {
* @return an [AgnosticEncoder] for the given [kType].
*/
@Suppress("UNCHECKED_CAST")
fun <T> encoderFor(kType: KType): AgnosticEncoder<T> =
encoderFor(
fun <T> encoderFor(kType: KType): AgnosticEncoder<T> {
registerUdts()
return encoderFor(
currentType = kType,
seenTypeSet = emptySet(),
typeVariables = emptyMap(),
) as AgnosticEncoder<T>
}


private inline fun <reified T> KType.isSubtypeOf(): Boolean = isSubtypeOf(typeOf<T>())
Expand Down Expand Up @@ -296,6 +303,16 @@ object KotlinTypeInference : Serializable {
private fun <K, V> transitiveMerge(a: Map<K, V>, b: Map<K, V>, valueToKey: (V) -> K?): Map<K, V> =
a + b.mapValues { a.getOrDefault(valueToKey(it.value), it.value) }

private fun registerUdts() {
UDTRegistration.register(kotlinx.datetime.LocalDate::class.java.name, LocalDateUdt::class.java.name)
UDTRegistration.register(kotlinx.datetime.Instant::class.java.name, InstantUdt::class.java.name)
UDTRegistration.register(kotlinx.datetime.LocalDateTime::class.java.name, LocalDateTimeUdt::class.java.name)
UDTRegistration.register(kotlinx.datetime.DatePeriod::class.java.name, DatePeriodUdt::class.java.name)
UDTRegistration.register(kotlinx.datetime.DateTimePeriod::class.java.name, DateTimePeriodUdt::class.java.name)
// TODO
// UDTRegistration.register(kotlin.time.Duration::class.java.name, DurationUdt::class.java.name)
}

/**
*
*/
Expand Down Expand Up @@ -375,19 +392,12 @@ object KotlinTypeInference : Serializable {
currentType.isSubtypeOf<java.math.BigInteger?>() -> AgnosticEncoders.`JavaBigIntEncoder$`.`MODULE$`
currentType.isSubtypeOf<CalendarInterval?>() -> AgnosticEncoders.`CalendarIntervalEncoder$`.`MODULE$`
currentType.isSubtypeOf<java.time.LocalDate?>() -> AgnosticEncoders.STRICT_LOCAL_DATE_ENCODER()
currentType.isSubtypeOf<kotlinx.datetime.LocalDate?>() -> TODO("User java.time.LocalDate for now. We'll create a UDT for this.")
currentType.isSubtypeOf<java.sql.Date?>() -> AgnosticEncoders.STRICT_DATE_ENCODER()
currentType.isSubtypeOf<java.time.Instant?>() -> AgnosticEncoders.STRICT_INSTANT_ENCODER()
currentType.isSubtypeOf<kotlinx.datetime.Instant?>() -> TODO("Use java.time.Instant for now. We'll create a UDT for this.")
currentType.isSubtypeOf<kotlin.time.TimeMark?>() -> TODO("Use java.time.Instant for now. We'll create a UDT for this.")
currentType.isSubtypeOf<java.sql.Timestamp?>() -> AgnosticEncoders.STRICT_TIMESTAMP_ENCODER()
currentType.isSubtypeOf<java.time.LocalDateTime?>() -> AgnosticEncoders.`LocalDateTimeEncoder$`.`MODULE$`
currentType.isSubtypeOf<kotlinx.datetime.LocalDateTime?>() -> TODO("Use java.time.LocalDateTime for now. We'll create a UDT for this.")
currentType.isSubtypeOf<java.time.Duration?>() -> AgnosticEncoders.`DayTimeIntervalEncoder$`.`MODULE$`
currentType.isSubtypeOf<kotlin.time.Duration?>() -> TODO("Use java.time.Duration for now. We'll create a UDT for this.")
currentType.isSubtypeOf<java.time.Period?>() -> AgnosticEncoders.`YearMonthIntervalEncoder$`.`MODULE$`
currentType.isSubtypeOf<kotlinx.datetime.DateTimePeriod?>() -> TODO("Use java.time.Period for now. We'll create a UDT for this.")
currentType.isSubtypeOf<kotlinx.datetime.DatePeriod?>() -> TODO("Use java.time.Period for now. We'll create a UDT for this.")
currentType.isSubtypeOf<Row?>() -> AgnosticEncoders.`UnboundRowEncoder$`.`MODULE$`

// enums
Expand All @@ -414,6 +424,8 @@ object KotlinTypeInference : Serializable {
AgnosticEncoders.UDTEncoder(udt, udt.javaClass)
}

currentType.isSubtypeOf<kotlin.time.Duration?>() -> TODO("kotlin.time.Duration is unsupported. Use java.time.Duration for now.")

currentType.isSubtypeOf<scala.Option<*>?>() -> {
val elementEncoder = encoderFor(
currentType = tArguments.first().type!!,
Expand Down Expand Up @@ -666,8 +678,6 @@ object KotlinTypeInference : Serializable {
fields.asScalaSeq(),
)
}

// else -> throw IllegalArgumentException("No encoder found for type $currentType")
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package org.jetbrains.kotlinx.spark.api.udts

import kotlinx.datetime.DatePeriod
import kotlinx.datetime.toJavaPeriod
import kotlinx.datetime.toKotlinDatePeriod
import org.apache.spark.sql.catalyst.util.IntervalUtils
import org.apache.spark.sql.types.UserDefinedType
import org.apache.spark.sql.types.YearMonthIntervalType

/**
* NOTE: Just like java.time.DatePeriod, this is truncated to months.
*/
class DatePeriodUdt : UserDefinedType<DatePeriod>() {

override fun userClass(): Class<DatePeriod> = DatePeriod::class.java
override fun deserialize(datum: Any?): DatePeriod? =
when (datum) {
null -> null
is Int -> IntervalUtils.monthsToPeriod(datum).toKotlinDatePeriod()
else -> throw IllegalArgumentException("Unsupported datum: $datum")
}

override fun serialize(obj: DatePeriod?): Int? =
obj?.let { IntervalUtils.periodToMonths(it.toJavaPeriod()) }

override fun sqlType(): YearMonthIntervalType = YearMonthIntervalType.apply()
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package org.jetbrains.kotlinx.spark.api.udts

import kotlinx.datetime.DateTimePeriod
import org.apache.spark.sql.types.CalendarIntervalType
import org.apache.spark.sql.types.`CalendarIntervalType$`
import org.apache.spark.sql.types.UserDefinedType
import org.apache.spark.unsafe.types.CalendarInterval
import kotlin.time.Duration.Companion.hours
import kotlin.time.Duration.Companion.minutes
import kotlin.time.Duration.Companion.nanoseconds
import kotlin.time.Duration.Companion.seconds

/**
* NOTE: Just like java.time.DatePeriod, this is truncated to months.
*/
class DateTimePeriodUdt : UserDefinedType<DateTimePeriod>() {

override fun userClass(): Class<DateTimePeriod> = DateTimePeriod::class.java
override fun deserialize(datum: Any?): DateTimePeriod? =
when (datum) {
null -> null
is CalendarInterval ->
DateTimePeriod(
months = datum.months,
days = datum.days,
nanoseconds = datum.microseconds * 1_000,
)

else -> throw IllegalArgumentException("Unsupported datum: $datum")
}

override fun serialize(obj: DateTimePeriod?): CalendarInterval? =
obj?.let {
CalendarInterval(
/* months = */ obj.months + obj.years * 12,
/* days = */ obj.days,
/* microseconds = */
(obj.hours.hours +
obj.minutes.minutes +
obj.seconds.seconds +
obj.nanoseconds.nanoseconds).inWholeMicroseconds,
)
}

override fun sqlType(): CalendarIntervalType = `CalendarIntervalType$`.`MODULE$`
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package org.jetbrains.kotlinx.spark.api.udts

import org.apache.spark.sql.catalyst.util.IntervalUtils
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.types.DayTimeIntervalType
import org.apache.spark.sql.types.UserDefinedType
import kotlin.time.Duration
import kotlin.time.Duration.Companion.milliseconds
import kotlin.time.Duration.Companion.nanoseconds
import kotlin.time.toJavaDuration
import kotlin.time.toKotlinDuration

// TODO Fails, likely because Duration is a value class.
class DurationUdt : UserDefinedType<Duration>() {

override fun userClass(): Class<Duration> = Duration::class.java
override fun deserialize(datum: Any?): Duration? =
when (datum) {
null -> null
is Long -> IntervalUtils.microsToDuration(datum).toKotlinDuration()
// is Long -> IntervalUtils.microsToDuration(datum).toKotlinDuration().let {
// // store in nanos
// it.inWholeNanoseconds shl 1
// }
else -> throw IllegalArgumentException("Unsupported datum: $datum")
}

// override fun serialize(obj: Duration): Long =
// IntervalUtils.durationToMicros(obj.toJavaDuration())

fun serialize(obj: Long): Long? =
obj?.let { rawValue ->
val unitDiscriminator = rawValue.toInt() and 1
fun isInNanos() = unitDiscriminator == 0
val value = rawValue shr 1
val duration = if (isInNanos()) value.nanoseconds else value.milliseconds

IntervalUtils.durationToMicros(duration.toJavaDuration())
}

override fun serialize(obj: Duration): Long? =
obj?.let { IntervalUtils.durationToMicros(it.toJavaDuration()) }


override fun sqlType(): DataType = DayTimeIntervalType.apply()
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package org.jetbrains.kotlinx.spark.api.udts

import kotlinx.datetime.Instant
import kotlinx.datetime.toJavaInstant
import kotlinx.datetime.toKotlinInstant
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.types.`TimestampType$`
import org.apache.spark.sql.types.UserDefinedType


class InstantUdt : UserDefinedType<Instant>() {

override fun userClass(): Class<Instant> = Instant::class.java
override fun deserialize(datum: Any?): Instant? =
when (datum) {
null -> null
is Long -> DateTimeUtils.microsToInstant(datum).toKotlinInstant()
else -> throw IllegalArgumentException("Unsupported datum: $datum")
}

override fun serialize(obj: Instant?): Long? =
obj?.let { DateTimeUtils.instantToMicros(it.toJavaInstant()) }

override fun sqlType(): DataType = `TimestampType$`.`MODULE$`
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package org.jetbrains.kotlinx.spark.api.udts

import kotlinx.datetime.LocalDateTime
import kotlinx.datetime.toJavaLocalDateTime
import kotlinx.datetime.toKotlinLocalDateTime
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.types.`TimestampNTZType$`
import org.apache.spark.sql.types.UserDefinedType


class LocalDateTimeUdt : UserDefinedType<LocalDateTime>() {

override fun userClass(): Class<LocalDateTime> = LocalDateTime::class.java
override fun deserialize(datum: Any?): LocalDateTime? =
when (datum) {
null -> null
is Long -> DateTimeUtils.microsToLocalDateTime(datum).toKotlinLocalDateTime()
else -> throw IllegalArgumentException("Unsupported datum: $datum")
}

override fun serialize(obj: LocalDateTime?): Long? =
obj?.let { DateTimeUtils.localDateTimeToMicros(it.toJavaLocalDateTime()) }

override fun sqlType(): DataType = `TimestampNTZType$`.`MODULE$`
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package org.jetbrains.kotlinx.spark.api.udts

import kotlinx.datetime.LocalDate
import kotlinx.datetime.toJavaLocalDate
import kotlinx.datetime.toKotlinLocalDate
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.types.`DateType$`
import org.apache.spark.sql.types.UserDefinedType


class LocalDateUdt : UserDefinedType<LocalDate>() {

override fun userClass(): Class<LocalDate> = LocalDate::class.java
override fun deserialize(datum: Any?): LocalDate? =
when (datum) {
null -> null
is Int -> DateTimeUtils.daysToLocalDate(datum).toKotlinLocalDate()
else -> throw IllegalArgumentException("Unsupported datum: $datum")
}

override fun serialize(obj: LocalDate?): Int? =
obj?.let { DateTimeUtils.localDateToDays(it.toJavaLocalDate()) }

override fun sqlType(): DataType = `DateType$`.`MODULE$`
}
Loading

0 comments on commit ab4c455

Please sign in to comment.