Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add hard abort on resource closure if any part of the stream remains open #853

Open
wants to merge 4 commits into
base: series/0.9
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
224 changes: 122 additions & 102 deletions core/src/main/scala/org/http4s/jdkhttpclient/JdkWSClient.scala
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import java.net.http.{WebSocket => JWebSocket}
import java.nio.ByteBuffer
import java.util.concurrent.CompletableFuture
import java.util.concurrent.CompletionStage
import scala.concurrent.duration.DurationInt

/** A `WSClient` wrapper for the JDK 11+ websocket client. It will reply to Pongs with Pings even in
* "low-level" mode. Custom (non-GET) HTTP methods are ignored.
Expand All @@ -47,115 +48,134 @@ object JdkWSClient {
jdkHttpClient: HttpClient
)(implicit F: Async[F]): WSClient[F] =
WSClient(respondToPings = false) { req =>
Dispatcher.sequential.flatMap { dispatcher =>
Resource
.make {
for {
wsBuilder <- F.delay {
val builder = jdkHttpClient.newWebSocketBuilder()
val (subprotocols, hs) = req.headers.headers.partitionEither {
case Header.Raw(ci"Sec-WebSocket-Protocol", p) => Left(p)
case h => Right(h)
Dispatcher
.sequential(
await = false
)
.flatMap { dispatcher =>
Resource
.make {
for {
wsBuilder <- F.delay {
val builder = jdkHttpClient.newWebSocketBuilder()
val (subprotocols, hs) = req.headers.headers.partitionEither {
case Header.Raw(ci"Sec-WebSocket-Protocol", p) => Left(p)
case h => Right(h)
}
hs.foreach { h => builder.header(h.name.toString, h.value); () }
subprotocols match {
case head :: tail => builder.subprotocols(head, tail: _*)
case Nil =>
}
builder
}
hs.foreach { h => builder.header(h.name.toString, h.value); () }
subprotocols match {
case head :: tail => builder.subprotocols(head, tail: _*)
case Nil =>
queue <- Queue.unbounded[F, Either[Throwable, WSFrame]]
closedDef <- Deferred[F, Unit]
handleReceive =
(wsf: Either[Throwable, WSFrame]) =>
dispatcher.unsafeToCompletableFuture(
queue.offer(wsf) *> (wsf match {
case Left(_) | Right(_: WSFrame.Close) => closedDef.complete(()).void
case _ => F.unit
})
)
wsListener = new JWebSocket.Listener {
override def onOpen(webSocket: JWebSocket): Unit = ()
override def onClose(webSocket: JWebSocket, statusCode: Int, reason: String)
: CompletionStage[_] =
// The output side of this connection will be closed when the returned CompletionStage completes.
// Therefore, we return a never completing CompletionStage, so we can control when the output will
// be closed (as it is allowed to continue sending frames (as few as possible) after a close frame
// has been received).
handleReceive(WSFrame.Close(statusCode, reason).asRight)
.thenCompose[Nothing](_ => new CompletableFuture[Nothing])
override def onText(webSocket: JWebSocket, data: CharSequence, last: Boolean)
: CompletionStage[_] =
handleReceive(WSFrame.Text(data.toString, last).asRight)
override def onBinary(webSocket: JWebSocket, data: ByteBuffer, last: Boolean)
: CompletionStage[_] =
handleReceive(WSFrame.Binary(ByteVector(data), last).asRight)
override def onPing(webSocket: JWebSocket, message: ByteBuffer)
: CompletionStage[_] =
handleReceive(WSFrame.Ping(ByteVector(message)).asRight)
override def onPong(webSocket: JWebSocket, message: ByteBuffer)
: CompletionStage[_] =
handleReceive(WSFrame.Pong(ByteVector(message)).asRight)
override def onError(webSocket: JWebSocket, error: Throwable): Unit = {
handleReceive(error.asLeft); ()
}
}
builder
}
queue <- Queue.unbounded[F, Either[Throwable, WSFrame]]
closedDef <- Deferred[F, Unit]
handleReceive =
(wsf: Either[Throwable, WSFrame]) =>
dispatcher.unsafeToCompletableFuture(
queue.offer(wsf) *> (wsf match {
case Left(_) | Right(_: WSFrame.Close) => closedDef.complete(()).void
case _ => F.unit
})
webSocket <- F.fromCompletableFuture(
F.delay(wsBuilder.buildAsync(URI.create(req.uri.renderString), wsListener))
)
sendSem <- Semaphore[F](1L)
} yield (webSocket, queue, closedDef, sendSem)
} { case (webSocket, queue, closedDef, sendSem) =>
val cleanupF = for {
isOutputOpen <- F.delay(!webSocket.isOutputClosed)
closeOutput = sendSem.permit.use { _ =>
F.fromCompletableFuture(
F.delay(webSocket.sendClose(JWebSocket.NORMAL_CLOSURE, ""))
)
wsListener = new JWebSocket.Listener {
override def onOpen(webSocket: JWebSocket): Unit = ()
override def onClose(webSocket: JWebSocket, statusCode: Int, reason: String)
: CompletionStage[_] =
// The output side of this connection will be closed when the returned CompletionStage completes.
// Therefore, we return a never completing CompletionStage, so we can control when the output will
// be closed (as it is allowed to continue sending frames (as few as possible) after a close frame
// has been received).
handleReceive(WSFrame.Close(statusCode, reason).asRight)
.thenCompose[Nothing](_ => new CompletableFuture[Nothing])
override def onText(webSocket: JWebSocket, data: CharSequence, last: Boolean)
: CompletionStage[_] =
handleReceive(WSFrame.Text(data.toString, last).asRight)
override def onBinary(webSocket: JWebSocket, data: ByteBuffer, last: Boolean)
: CompletionStage[_] =
handleReceive(WSFrame.Binary(ByteVector(data), last).asRight)
override def onPing(webSocket: JWebSocket, message: ByteBuffer)
: CompletionStage[_] =
handleReceive(WSFrame.Ping(ByteVector(message)).asRight)
override def onPong(webSocket: JWebSocket, message: ByteBuffer)
: CompletionStage[_] =
handleReceive(WSFrame.Pong(ByteVector(message)).asRight)
override def onError(webSocket: JWebSocket, error: Throwable): Unit = {
handleReceive(error.asLeft); ()
}
_ <-
closeOutput
.whenA(isOutputOpen)
.recover { case e: IOException if e.getMessage == "closed output" => () }
.onError { case e: IOException =>
for {
errs <- Stream
.repeatEval(queue.tryTake)
.unNoneTerminate
.collect { case Left(e) => e }
.compile
.toList
_ <- F.raiseError[Unit](CompositeFailure.fromList(errs) match {
case Some(cf) => cf
case None => e
})
} yield ()
}

isInputOpen <- F.delay {
!webSocket.isInputClosed
}

_ <- if (isInputOpen) F.timeoutTo(closedDef.get, 1.second, F.unit) else F.unit

} yield ()

val ensureResourceReleaseF = F.delay {
webSocket.abort
}
webSocket <- F.fromCompletableFuture(
F.delay(wsBuilder.buildAsync(URI.create(req.uri.renderString), wsListener))
)
sendSem <- Semaphore[F](1L)
} yield (webSocket, queue, closedDef, sendSem)
} { case (webSocket, queue, _, _) =>
for {
isOutputOpen <- F.delay(!webSocket.isOutputClosed)
closeOutput = F.fromCompletableFuture(
F.delay(webSocket.sendClose(JWebSocket.NORMAL_CLOSURE, ""))
)
_ <-
closeOutput
.whenA(isOutputOpen)
.recover { case e: IOException if e.getMessage == "closed output" => () }
.onError { case e: IOException =>
for {
errs <- Stream
.repeatEval(queue.tryTake)
.unNoneTerminate
.collect { case Left(e) => e }
.compile
.toList
_ <- F.raiseError[Unit](CompositeFailure.fromList(errs) match {
case Some(cf) => cf
case None => e
})
} yield ()
}
} yield ()
}
.map { case (webSocket, queue, closedDef, sendSem) =>
// sending will throw if done in parallel
val rawSend = (wsf: WSFrame) =>
F.fromCompletableFuture(F.delay(wsf match {
case WSFrame.Text(text, last) => webSocket.sendText(text, last)
case WSFrame.Binary(data, last) => webSocket.sendBinary(data.toByteBuffer, last)
case WSFrame.Ping(data) => webSocket.sendPing(data.toByteBuffer)
case WSFrame.Pong(data) => webSocket.sendPong(data.toByteBuffer)
case WSFrame.Close(statusCode, reason) => webSocket.sendClose(statusCode, reason)
}))
.void
new WSConnection[F] {
override def send(wsf: WSFrame) =
sendSem.permit.use(_ => rawSend(wsf))
override def sendMany[G[_]: Foldable, A <: WSFrame](wsfs: G[A]) =
sendSem.permit.use(_ => wsfs.traverse_(rawSend))
override def receive = closedDef.tryGet.flatMap {
case None => F.delay(webSocket.request(1)) *> queue.take.rethrow.map(_.some)
case Some(()) => none[WSFrame].pure[F]

F.guarantee(cleanupF, ensureResourceReleaseF)
}
.map { case (webSocket, queue, closedDef, sendSem) =>
// sending will throw if done in parallel
val rawSend = (wsf: WSFrame) =>
F.fromCompletableFuture(F.delay(wsf match {
case WSFrame.Text(text, last) => webSocket.sendText(text, last)
case WSFrame.Binary(data, last) => webSocket.sendBinary(data.toByteBuffer, last)
case WSFrame.Ping(data) => webSocket.sendPing(data.toByteBuffer)
case WSFrame.Pong(data) => webSocket.sendPong(data.toByteBuffer)
case WSFrame.Close(statusCode, reason) => webSocket.sendClose(statusCode, reason)
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One thing I have seen is this line failing because the output was already closed. We don't send close frames manually so the only way this should be invoked is by connectHighLevel which mirrors the close frame. Maybe we should guard this with an if, checking if the output is still open? I don't like that much though because low level usage would then have implicit restrictions.

Also as a side note, the mirroring of the closure frame in connectHighLevel actually may fail if the server emits a close code that is not applicable for a client to send.

}))
.void
new WSConnection[F] {
override def send(wsf: WSFrame) =
sendSem.permit.use(_ => rawSend(wsf))
override def sendMany[G[_]: Foldable, A <: WSFrame](wsfs: G[A]) =
sendSem.permit.use(_ => wsfs.traverse_(rawSend))
override def receive = closedDef.tryGet.flatMap {
case None => F.delay(webSocket.request(1)) *> queue.take.rethrow.map(_.some)
case Some(()) => none[WSFrame].pure[F]
}
override def subprotocol =
webSocket.getSubprotocol.some.filter(_.nonEmpty)
}
override def subprotocol =
webSocket.getSubprotocol.some.filter(_.nonEmpty)
}
}
}
}
}

/** A `WSClient` wrapping the default `HttpClient`. */
Expand Down