Skip to content

Commit

Permalink
Cleanup http3 client
Browse files Browse the repository at this point in the history
  • Loading branch information
mathieuancelin committed Oct 31, 2023
1 parent 13ab723 commit 973cee2
Showing 1 changed file with 27 additions and 25 deletions.
52 changes: 27 additions & 25 deletions otoroshi/app/netty/h3client.scala
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package otoroshi.netty

import akka.http.scaladsl.model.HttpHeader.ParsingResult
import akka.http.scaladsl.model.headers.{`Content-Length`, `Content-Type`, `User-Agent`, RawHeader}
import akka.http.scaladsl.model.headers.{RawHeader, `Content-Length`, `Content-Type`, `User-Agent`}
import akka.http.scaladsl.model.{ContentType, HttpHeader, StatusCode, Uri}
import akka.stream.scaladsl.{Sink, Source}
import akka.util.ByteString
Expand Down Expand Up @@ -42,7 +42,7 @@ import scala.collection.concurrent.TrieMap
import scala.concurrent.duration.{Duration, FiniteDuration}
import scala.concurrent.{Await, Future, Promise}
import scala.jdk.CollectionConverters._
import scala.util.{Failure, Success}
import scala.util.{Failure, Success, Try}
import scala.xml.{Elem, XML}

case class Http3Response(
Expand Down Expand Up @@ -71,7 +71,7 @@ class NettyHttp3Client(val env: Env) {

private[netty] val logger = NettyHttp3Client.logger
private val group = new NioEventLoopGroup(Runtime.getRuntime.availableProcessors() + 1)
private val bs = new Bootstrap()
private val bs = new Bootstrap().group(group).channel(classOf[NioDatagramChannel])

private val codecs = Caches.bounded[String, ChannelHandler](999)
private val channels = Caches.bounded[String, Future[Channel]](999)
Expand Down Expand Up @@ -104,8 +104,8 @@ class NettyHttp3Client(val env: Env) {
_ => {
val promise = Promise.apply[Channel]()
val future = bs
.group(group)
.channel(classOf[NioDatagramChannel])
//.group(group)
//.channel(classOf[NioDatagramChannel])
.handler(codec)
.bind(0)
future.addListener(new GenericFutureListener[io.netty.util.concurrent.Future[Void]] {
Expand Down Expand Up @@ -200,6 +200,10 @@ class NettyHttp3Client(val env: Env) {
val hotSource = Sinks.many().unicast().onBackpressureBuffer[ByteString]()
val hotFlux = hotSource.asFlux()

override def channelRead(ctx: ChannelHandlerContext, frame: Http3UnknownFrame): Unit = {
if (logger.isDebugEnabled) logger.debug("unknown frame")
}

override def channelActive(ctx: ChannelHandlerContext): Unit = {
if (logger.isDebugEnabled) logger.debug("channel active")
}
Expand All @@ -221,12 +225,13 @@ class NettyHttp3Client(val env: Env) {
}

override def channelReadComplete(ctx: ChannelHandlerContext): Unit = {
if (logger.isDebugEnabled) logger.debug("channelReadComplete")
ctx.close()
hotSource.tryEmitComplete()
}

override def channelRead(ctx: ChannelHandlerContext, frame: Http3HeadersFrame): Unit = {
val isLast = false
val isLast = headersReceived
if (logger.isDebugEnabled) logger.debug(s"got header frame !!!! ${isLast}")
if (headersReceived) {
val trailerHeaders = frame
Expand All @@ -236,6 +241,7 @@ class NettyHttp3Client(val env: Env) {
.map(name => (name.toString, frame.headers().getAll(name).asScala.map(_.toString)))
.toMap
trailerPromise.trySuccess(trailerHeaders)
ReferenceCountUtil.release(frame)
} else {
headersReceived = true
status = frame.headers().status().toString.toInt
Expand All @@ -246,34 +252,20 @@ class NettyHttp3Client(val env: Env) {
.map(name => (name.toString, frame.headers().getAll(name).asScala.map(_.toString)))
.toMap
promise.trySuccess(Http3Response(status, headers, hotFlux, trailerPromise.future))
releaseFrameAndCloseIfLast(ctx, frame, isLast)
ReferenceCountUtil.release(frame)
}
}

override def channelRead(ctx: ChannelHandlerContext, frame: Http3DataFrame): Unit = {
val isLast = false
val content = frame.content().toString(CharsetUtil.US_ASCII)
val chunk = ByteString(content)
if (logger.isDebugEnabled) logger.debug(s"got data frame !!! - ${isLast}")
if (logger.isDebugEnabled) logger.debug(s"got data frame in !!!")
hotSource.tryEmitNext(chunk)
releaseFrameAndCloseIfLast(ctx, frame, isLast)
}

private def releaseFrameAndCloseIfLast(
ctx: ChannelHandlerContext,
frame: Http3RequestStreamFrame,
isLast: Boolean
) {
ReferenceCountUtil.release(frame)
// println("releaseFrameAndCloseIfLast", isLast, frame.getClass.getName)
if (isLast) {
//promise.trySuccess(Http3Response(status, headers, hotFlux))
ctx.close()
hotSource.tryEmitComplete()
}
}

override def channelInputClosed(ctx: ChannelHandlerContext): Unit = {
if (logger.isDebugEnabled) logger.debug("channelInputClosed")
// TODO: check if right
ctx.close()
hotSource.tryEmitComplete()
Expand Down Expand Up @@ -328,7 +320,17 @@ case class NettyHttp3ClientStrictWsResponse(resp: NettyHttp3ClientWsResponse, bo

case class NettyHttp3ClientWsResponse(resp: Http3Response, _uri: Uri, env: Env) extends WSResponse with TrailerSupport {

private lazy val _body: Source[ByteString, _] = Source.fromPublisher(resp.bodyFlux).filter(_.nonEmpty)
private lazy val _body: Source[ByteString, _] = Try {
Source.fromPublisher(resp.bodyFlux).filter(_.nonEmpty).alsoTo(Sink.onComplete {
case Failure(e) => e.printStackTrace()
case Success(_) => ()
})
} match {
case Failure(e) =>
e.printStackTrace()
Source.empty
case Success(source) => source
}

private lazy val _bodyAsBytes: ByteString = {
Await.result(
Expand Down Expand Up @@ -661,7 +663,7 @@ case class NettyHttp3ClientWsRequest(
.addListener(QuicStreamChannel.SHUTDOWN_OUTPUT)
}
promise.future.andThen {
case _ => {
case e => {
streamChannel.closeFuture().sync()
quicChannel.close().sync()
}
Expand Down

0 comments on commit 973cee2

Please sign in to comment.