Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make use of Ktor websocket extensions and serialization #799

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 17 additions & 12 deletions core/src/commonMain/kotlin/builder/kord/KordBuilderUtil.kt
Original file line number Diff line number Diff line change
@@ -1,22 +1,37 @@
package dev.kord.core.builder.kord

import dev.kord.common.annotation.KordInternal
import dev.kord.common.annotation.KordUnsafe
import dev.kord.common.entity.Snowflake
import dev.kord.common.http.HttpEngine
import dev.kord.gateway.WebSocketCompression
import io.ktor.client.*
import io.ktor.client.plugins.contentnegotiation.*
import io.ktor.client.plugins.websocket.*
import io.ktor.serialization.kotlinx.*
import io.ktor.serialization.kotlinx.json.*
import io.ktor.util.*
import kotlinx.serialization.json.Json

@OptIn(KordUnsafe::class)
internal fun HttpClientConfig<*>.defaultConfig() {
expectSuccess = false

val json = Json {
encodeDefaults = false
allowStructuredMapKeys = true
ignoreUnknownKeys = true
isLenient = true
}
install(ContentNegotiation) {
json()
json(json)
}
install(WebSockets) {
contentConverter = KotlinxWebsocketSerializationConverter(json)
extensions {
install(WebSocketCompression)
}
}
install(WebSockets)
}

/** @suppress */
Expand All @@ -26,18 +41,8 @@ public fun HttpClient?.configure(): HttpClient {
defaultConfig()
}

val json = Json {
encodeDefaults = false
allowStructuredMapKeys = true
ignoreUnknownKeys = true
isLenient = true
}

return HttpClient(HttpEngine) {
defaultConfig()
install(ContentNegotiation) {
json(json)
}
}
}

Expand Down
38 changes: 36 additions & 2 deletions gateway/api/gateway.api
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,11 @@ public final class dev/kord/gateway/Close$ZombieConnection : dev/kord/gateway/Cl
}

public abstract class dev/kord/gateway/Command {
public static final field Companion Ldev/kord/gateway/Command$Companion;
}

public final class dev/kord/gateway/Command$Companion {
public final fun serializer ()Lkotlinx/serialization/KSerializer;
}

public final class dev/kord/gateway/Command$Heartbeat : dev/kord/gateway/Command {
Expand All @@ -212,8 +217,10 @@ public final class dev/kord/gateway/Command$Heartbeat : dev/kord/gateway/Command
public fun toString ()Ljava/lang/String;
}

public final class dev/kord/gateway/Command$SerializationStrategy : kotlinx/serialization/SerializationStrategy {
public final class dev/kord/gateway/Command$SerializationStrategy : kotlinx/serialization/KSerializer {
public static final field INSTANCE Ldev/kord/gateway/Command$SerializationStrategy;
public fun deserialize (Lkotlinx/serialization/encoding/Decoder;)Ldev/kord/gateway/Command;
public synthetic fun deserialize (Lkotlinx/serialization/encoding/Decoder;)Ljava/lang/Object;
public fun getDescriptor ()Lkotlinx/serialization/descriptors/SerialDescriptor;
public fun serialize (Lkotlinx/serialization/encoding/Encoder;Ldev/kord/gateway/Command;)V
public synthetic fun serialize (Lkotlinx/serialization/encoding/Encoder;Ljava/lang/Object;)V
Expand Down Expand Up @@ -598,13 +605,20 @@ public abstract class dev/kord/gateway/DispatchEvent : dev/kord/gateway/Event {
}

public abstract class dev/kord/gateway/Event {
public static final field Companion Ldev/kord/gateway/Event$Companion;
}

public final class dev/kord/gateway/Event$DeserializationStrategy : kotlinx/serialization/DeserializationStrategy {
public final class dev/kord/gateway/Event$Companion {
public final fun serializer ()Lkotlinx/serialization/KSerializer;
}

public final class dev/kord/gateway/Event$DeserializationStrategy : kotlinx/serialization/KSerializer {
public static final field INSTANCE Ldev/kord/gateway/Event$DeserializationStrategy;
public fun deserialize (Lkotlinx/serialization/encoding/Decoder;)Ldev/kord/gateway/Event;
public synthetic fun deserialize (Lkotlinx/serialization/encoding/Decoder;)Ljava/lang/Object;
public fun getDescriptor ()Lkotlinx/serialization/descriptors/SerialDescriptor;
public fun serialize (Lkotlinx/serialization/encoding/Encoder;Ldev/kord/gateway/Event;)Ljava/lang/Void;
public synthetic fun serialize (Lkotlinx/serialization/encoding/Encoder;Ljava/lang/Object;)V
}

public abstract interface class dev/kord/gateway/Gateway : kotlinx/coroutines/CoroutineScope {
Expand Down Expand Up @@ -1948,6 +1962,26 @@ public final class dev/kord/gateway/VoiceStateUpdate : dev/kord/gateway/Dispatch
public fun toString ()Ljava/lang/String;
}

public final class dev/kord/gateway/WebSocketCompression : io/ktor/websocket/WebSocketExtension {
public static final field Companion Ldev/kord/gateway/WebSocketCompression$Companion;
public fun <init> ()V
public fun clientNegotiation (Ljava/util/List;)Z
public fun getFactory ()Lio/ktor/websocket/WebSocketExtensionFactory;
public fun getProtocols ()Ljava/util/List;
public fun processIncomingFrame (Lio/ktor/websocket/Frame;)Lio/ktor/websocket/Frame;
public fun processOutgoingFrame (Lio/ktor/websocket/Frame;)Lio/ktor/websocket/Frame;
public fun serverNegotiation (Ljava/util/List;)Ljava/util/List;
}

public final class dev/kord/gateway/WebSocketCompression$Companion : io/ktor/websocket/WebSocketExtensionFactory {
public fun getKey ()Lio/ktor/util/AttributeKey;
public fun getRsv1 ()Z
public fun getRsv2 ()Z
public fun getRsv3 ()Z
public fun install (Lkotlin/jvm/functions/Function1;)Ldev/kord/gateway/WebSocketCompression;
public synthetic fun install (Lkotlin/jvm/functions/Function1;)Lio/ktor/websocket/WebSocketExtension;
}

public final class dev/kord/gateway/WebhooksUpdate : dev/kord/gateway/DispatchEvent {
public fun <init> (Ldev/kord/common/entity/DiscordWebhooksUpdateData;Ljava/lang/Integer;)V
public final fun component1 ()Ldev/kord/common/entity/DiscordWebhooksUpdateData;
Expand Down
9 changes: 7 additions & 2 deletions gateway/src/commonMain/kotlin/Command.kt
Original file line number Diff line number Diff line change
Expand Up @@ -8,26 +8,31 @@ import dev.kord.common.serialization.InstantInEpochMillisecondsSerializer
import kotlinx.atomicfu.atomic
import kotlinx.datetime.Instant
import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.KSerializer
import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable
import kotlinx.serialization.builtins.serializer
import kotlinx.serialization.descriptors.SerialDescriptor
import kotlinx.serialization.descriptors.buildClassSerialDescriptor
import kotlinx.serialization.encoding.Decoder
import kotlinx.serialization.encoding.Encoder
import kotlinx.serialization.json.JsonElement
import kotlinx.serialization.SerializationStrategy as KSerializationStrategy

@Serializable(with = Command.SerializationStrategy::class)
public sealed class Command {

public data class Heartbeat(val sequenceNumber: Int?) : Command()

public object SerializationStrategy : KSerializationStrategy<Command> {
public object SerializationStrategy : KSerializer<Command> {

override val descriptor: SerialDescriptor = buildClassSerialDescriptor("Command") {
element("op", OpCode.serializer().descriptor)
element("d", JsonElement.serializer().descriptor)
}

override fun deserialize(decoder: Decoder): Command =
TODO("Deserializing gateway commands is not supported yet")

@OptIn(PrivilegedIntent::class)
override fun serialize(encoder: Encoder, value: Command) {
val composite = encoder.beginStructure(descriptor)
Expand Down
55 changes: 55 additions & 0 deletions gateway/src/commonMain/kotlin/Compression.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package dev.kord.gateway

import dev.kord.common.annotation.KordUnsafe
import io.ktor.util.*
import io.ktor.websocket.*

/**
* [WebSocketExtension] inflating incoming websocket requests using `zlib`.
*
* *Note:** Normally you don't need this and this is configured by Kord automatically, however, if you want to use
* a custom HTTP client, you might need to add this, don't use it if you don't use what you're doing
*/
@KordUnsafe
public class WebSocketCompression : WebSocketExtension<Unit> {
/**
* https://discord.com/developers/docs/topics/gateway#transport-compression
*
* > Every connection to the gateway should use its own unique zlib context.
*
* https://api.ktor.io/ktor-shared/ktor-websockets/io.ktor.websocket/-web-socket-extension/index.html
* > A WebSocket extension instance. This instance is created for each WebSocket request,
* for every installed extension by WebSocketExtensionFactory.
*/
private val inflater = Inflater()

override val factory: WebSocketExtensionFactory<Unit, out WebSocketExtension<Unit>>
get() = Companion
override val protocols: List<WebSocketExtensionHeader>
get() = emptyList()

override fun clientNegotiation(negotiatedProtocols: List<WebSocketExtensionHeader>): Boolean = true

override fun processIncomingFrame(frame: Frame): Frame {
return if (frame is Frame.Binary) {
with(inflater) { Frame.Text(frame.inflateData()) }
} else {
frame
}
}

// Discord doesn't support deflating of gateway commands
override fun processOutgoingFrame(frame: Frame): Frame = frame

override fun serverNegotiation(requestedProtocols: List<WebSocketExtensionHeader>): List<WebSocketExtensionHeader> =
requestedProtocols

public companion object : WebSocketExtensionFactory<Unit, WebSocketCompression> {
override val key: AttributeKey<WebSocketCompression> = AttributeKey("WebSocketCompression")
override val rsv1: Boolean = false
override val rsv2: Boolean = false
override val rsv3: Boolean = false

override fun install(config: Unit.() -> Unit): WebSocketCompression = WebSocketCompression()
}
}
69 changes: 12 additions & 57 deletions gateway/src/commonMain/kotlin/DefaultGateway.kt
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@ import kotlinx.atomicfu.AtomicRef
import kotlinx.atomicfu.atomic
import kotlinx.atomicfu.update
import kotlinx.coroutines.*
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.channels.ReceiveChannel
import kotlinx.coroutines.flow.*
import kotlinx.coroutines.flow.MutableSharedFlow
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.SharedFlow
import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import kotlinx.serialization.json.Json
import mu.KotlinLogging
import kotlin.contracts.InvocationKind
import kotlin.contracts.contract
Expand Down Expand Up @@ -76,13 +76,6 @@ public class DefaultGateway(private val data: DefaultGatewayData) : Gateway {

private val handshakeHandler: HandshakeHandler

private lateinit var inflater: Inflater

private val jsonParser = Json {
ignoreUnknownKeys = true
isLenient = true
}

private val stateMutex = Mutex()

init {
Expand Down Expand Up @@ -110,14 +103,9 @@ public class DefaultGateway(private val data: DefaultGatewayData) : Gateway {
}

defaultGatewayLogger.trace { "opening gateway connection to $gatewayUrl" }
socket = data.client.webSocketSession { url(gatewayUrl) }

/**
* https://discord.com/developers/docs/topics/gateway#transport-compression
*
* > Every connection to the gateway should use its own unique zlib context.
*/
inflater = Inflater()
socket = data.client.webSocketSession {
url(gatewayUrl)
}
} catch (exception: Exception) {
defaultGatewayLogger.error(exception)
if (exception.isTimeout()) {
Expand Down Expand Up @@ -167,31 +155,12 @@ public class DefaultGateway(private val data: DefaultGatewayData) : Gateway {
}


@OptIn(ExperimentalCoroutinesApi::class)
private suspend fun readSocket() {
socket.incoming.asFlow().buffer(Channel.UNLIMITED).collect {
when (it) {
is Frame.Binary, is Frame.Text -> read(it)
else -> { /*ignore*/
}
}
}
}

private suspend fun read(frame: Frame) {
defaultGatewayLogger.trace { "Received raw frame: $frame" }
val json = when {
compression -> with(inflater) { frame.inflateData() }
else -> frame.data.decodeToString()
}

try {
defaultGatewayLogger.trace { "Gateway <<< $json" }
val event = jsonParser.decodeFromString(Event.DeserializationStrategy, json) ?: return
while (!socket.incoming.isClosedForReceive) {
val event = socket.receiveDeserialized<Event>()
data.eventFlow.emit(event)
} catch (exception: Exception) {
defaultGatewayLogger.error(exception)
}

}

private suspend fun handleClose() {
Expand All @@ -209,6 +178,7 @@ public class DefaultGateway(private val data: DefaultGatewayData) : Gateway {
state.update { State.Stopped }
throw IllegalStateException("Gateway closed: ${reason.code} ${reason.message}")
}

discordReason.resetSession -> {
setStopped()
}
Expand All @@ -220,14 +190,6 @@ public class DefaultGateway(private val data: DefaultGatewayData) : Gateway {
state.update { State.Running(true) }
}

private fun <T> ReceiveChannel<T>.asFlow() = flow {
try {
for (value in this@asFlow) emit(value)
} catch (ignore: CancellationException) {
//reading was stopped from somewhere else, ignore
}
}

override suspend fun stop() {
check(state.value !is State.Detached) { "The resources of this gateway are detached, create another one" }
data.eventFlow.emit(Close.UserClose)
Expand Down Expand Up @@ -268,14 +230,7 @@ public class DefaultGateway(private val data: DefaultGatewayData) : Gateway {

private suspend fun sendUnsafe(command: Command) {
data.sendRateLimiter.consume()
val json = Json.encodeToString(Command.SerializationStrategy, command)
if (command is Identify) {
defaultGatewayLogger.trace {
val copy = command.copy(token = "token")
"Gateway >>> ${Json.encodeToString(Command.SerializationStrategy, copy)}"
}
} else defaultGatewayLogger.trace { "Gateway >>> $json" }
socket.send(Frame.Text(json))
socket.sendSerialized(command)
}

@OptIn(ExperimentalCoroutinesApi::class)
Expand Down
15 changes: 10 additions & 5 deletions gateway/src/commonMain/kotlin/DefaultGatewayBuilder.kt
Original file line number Diff line number Diff line change
@@ -1,21 +1,22 @@
package dev.kord.gateway

import dev.kord.common.KordConfiguration
import dev.kord.common.annotation.KordUnsafe
import dev.kord.common.http.HttpEngine
import dev.kord.common.ratelimit.IntervalRateLimiter
import dev.kord.common.ratelimit.RateLimiter
import dev.kord.gateway.ratelimit.IdentifyRateLimiter
import dev.kord.gateway.retry.LinearRetry
import dev.kord.gateway.retry.Retry
import io.ktor.client.*
import io.ktor.client.plugins.contentnegotiation.*
import io.ktor.client.plugins.websocket.*
import io.ktor.client.request.*
import io.ktor.http.*
import io.ktor.serialization.kotlinx.json.*
import io.ktor.serialization.kotlinx.*
import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.flow.MutableSharedFlow
import kotlinx.serialization.json.Json
import kotlin.time.Duration.Companion.seconds

public class DefaultGatewayBuilder {
Expand All @@ -28,11 +29,15 @@ public class DefaultGatewayBuilder {
public var dispatcher: CoroutineDispatcher = Dispatchers.Default
public var eventFlow: MutableSharedFlow<Event> = MutableSharedFlow(extraBufferCapacity = Int.MAX_VALUE)

@OptIn(KordUnsafe::class)
public fun build(): DefaultGateway {
val client = client ?: HttpClient(HttpEngine) {
install(WebSockets)
install(ContentNegotiation) {
json()
install(WebSockets) {
contentConverter = KotlinxWebsocketSerializationConverter(Json)

extensions {
install(WebSocketCompression)
}
}
}
val retry = reconnectRetry ?: LinearRetry(2.seconds, 20.seconds, 10)
Expand Down
Loading