Skip to content

Commit

Permalink
Improve muxer test coverage. Fix several muxer issues (#285)
Browse files Browse the repository at this point in the history
* Refactor MuxHandlerAbstractTest: add ability to handle outbound messages
* Control and check ByteBuf allocations and releases
* Add more testcases and assertions
* [Generic] (buffer leak) release message buffer if inbound frame have non-existing streamId: 71c68c4
* [Yamux] Process RST (Reset) flag: 42e40f6
* [Generic] Writing to a stream which was prior closed should be prohibited: d3c4580
* [Generic] Receiving a data on a steam which was remotely closed should result in exception (recoverable, i.e. connection should not be terminated): 8974317
* [Yamux] switch the logic of onLocalDisconnect() and onLocalClose() methods. onLocalDisconnect() should leave the stream open for inbound data: 820c252
* [Yamux] need to clean up stream entries on any kind of stream closure (local Reset, remote Reset, local + remote Close): 66ef36b
* Convert RemoteWriteClosed to singleton
  • Loading branch information
Nashatyrev authored May 30, 2023
1 parent 63869ca commit 0981ec6
Show file tree
Hide file tree
Showing 10 changed files with 440 additions and 103 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -50,20 +50,39 @@ abstract class AbstractMuxHandler<TData>() :
}

fun getChannelHandlerContext(): ChannelHandlerContext {
return ctx ?: throw InternalErrorException("Internal error: handler context should be initialized at this stage")
return ctx
?: throw InternalErrorException("Internal error: handler context should be initialized at this stage")
}

protected fun childRead(id: MuxId, msg: TData) {
val child = streamMap[id] ?: throw ConnectionClosedException("Channel with id $id not opened")
pendingReadComplete += id
child.pipeline().fireChannelRead(msg)
val child = streamMap[id]
when {
child == null -> {
releaseMessage(msg)
throw ConnectionClosedException("Channel with id $id not opened")
}
child.remoteDisconnected -> {
releaseMessage(msg)
throw ConnectionClosedException("Channel with id $id was closed for sending by remote")
}
else -> {
pendingReadComplete += id
child.pipeline().fireChannelRead(msg)
}
}
}

override fun channelReadComplete(ctx: ChannelHandlerContext) {
pendingReadComplete.forEach { streamMap[it]?.pipeline()?.fireChannelReadComplete() }
pendingReadComplete.clear()
}

/**
* Needs to be called when message was not passed to the child channel pipeline due to any error.
* (if a message was passed to the child channel it's the child channel's responsibility to release the message)
*/
abstract fun releaseMessage(msg: TData)

abstract fun onChildWrite(child: MuxChannel<TData>, data: TData)

protected fun onRemoteOpen(id: MuxId) {
Expand Down Expand Up @@ -96,13 +115,15 @@ abstract class AbstractMuxHandler<TData>() :

fun onClosed(child: MuxChannel<TData>) {
streamMap.remove(child.id)
onChildClosed(child)
}

abstract override fun channelRead(ctx: ChannelHandlerContext, msg: Any)
protected open fun onRemoteCreated(child: MuxChannel<TData>) {}
protected abstract fun onLocalOpen(child: MuxChannel<TData>)
protected abstract fun onLocalClose(child: MuxChannel<TData>)
protected abstract fun onLocalDisconnect(child: MuxChannel<TData>)
protected abstract fun onChildClosed(child: MuxChannel<TData>)

private fun createChild(
id: MuxId,
Expand Down Expand Up @@ -142,5 +163,6 @@ abstract class AbstractMuxHandler<TData>() :
}
}

private fun checkClosed() = if (closed) throw ConnectionClosedException("Can't create a new stream: connection was closed: " + ctx!!.channel()) else Unit
private fun checkClosed() =
if (closed) throw ConnectionClosedException("Can't create a new stream: connection was closed: " + ctx!!.channel()) else Unit
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package io.libp2p.etc.util.netty.mux

import io.libp2p.core.ConnectionClosedException
import io.libp2p.etc.util.netty.AbstractChildChannel
import io.netty.channel.ChannelMetadata
import io.netty.channel.ChannelOutboundBuffer
Expand All @@ -16,8 +17,8 @@ class MuxChannel<TData>(
val initiator: Boolean
) : AbstractChildChannel(parent.ctx!!.channel(), id) {

private var remoteDisconnected = false
private var localDisconnected = false
var remoteDisconnected = false
var localDisconnected = false

override fun metadata(): ChannelMetadata = ChannelMetadata(true)
override fun localAddress0() =
Expand All @@ -35,6 +36,9 @@ class MuxChannel<TData>(
while (true) {
val msg = buf.current() ?: break
try {
if (localDisconnected) {
throw ConnectionClosedException("The stream was closed for writing locally: $id")
}
// the msg is released by both onChildWrite and buf.remove() so we need to retain
// however it is still to be confirmed that no buf leaks happen here TODO
ReferenceCountUtil.retain(msg)
Expand All @@ -55,7 +59,7 @@ class MuxChannel<TData>(
}

fun onRemoteDisconnected() {
pipeline().fireUserEventTriggered(RemoteWriteClosed())
pipeline().fireUserEventTriggered(RemoteWriteClosed)
remoteDisconnected = true
closeIfBothDisconnected()
}
Expand All @@ -74,11 +78,6 @@ class MuxChannel<TData>(
}
}

/**
* This Netty user event is fired to the [Stream] channel when remote peer closes its write side of the Stream
*/
class RemoteWriteClosed

data class MultiplexSocketAddress(val parentAddress: SocketAddress, val streamId: MuxId) : SocketAddress() {
override fun toString(): String {
return "Mux[$parentAddress-$streamId]"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
package io.libp2p.etc.util.netty.mux

/**
* This Netty user event is fired to the [Stream] channel when remote peer closes its write side of the Stream
*/
object RemoteWriteClosed
4 changes: 4 additions & 0 deletions libp2p/src/main/kotlin/io/libp2p/mux/MuxHandler.kt
Original file line number Diff line number Diff line change
Expand Up @@ -52,4 +52,8 @@ abstract class MuxHandler(
}.thenApply { it.attr(STREAM).get() }
return StreamPromise(stream, controller)
}

override fun releaseMessage(msg: ByteBuf) {
msg.release()
}
}
3 changes: 1 addition & 2 deletions libp2p/src/main/kotlin/io/libp2p/mux/mplex/MplexHandler.kt
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,5 @@ open class MplexHandler(
getChannelHandlerContext().writeAndFlush(MplexFrame.createResetFrame(child.id))
}

override fun onRemoteCreated(child: MuxChannel<ByteBuf>) {
}
override fun onChildClosed(child: MuxChannel<ByteBuf>) {}
}
46 changes: 28 additions & 18 deletions libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxHandler.kt
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,15 @@ open class YamuxHandler(

fun handleFlags(msg: YamuxFrame) {
val ctx = getChannelHandlerContext()
if (msg.flags == YamuxFlags.SYN) {
// ACK the new stream
onRemoteOpen(msg.id)
ctx.writeAndFlush(YamuxFrame(msg.id, YamuxType.WINDOW_UPDATE, YamuxFlags.ACK, 0))
when (msg.flags) {
YamuxFlags.SYN -> {
// ACK the new stream
onRemoteOpen(msg.id)
ctx.writeAndFlush(YamuxFrame(msg.id, YamuxType.WINDOW_UPDATE, YamuxFlags.ACK, 0))
}
YamuxFlags.FIN -> onRemoteDisconnect(msg.id)
YamuxFlags.RST -> onRemoteClose(msg.id)
}
if (msg.flags == YamuxFlags.FIN)
onRemoteDisconnect(msg.id)
}

fun handleDataRead(msg: YamuxFrame) {
Expand All @@ -88,8 +90,10 @@ open class YamuxHandler(
if (size.toInt() == 0)
return
val recWindow = receiveWindows.get(msg.id)
if (recWindow == null)
if (recWindow == null) {
releaseMessage(msg.data!!)
throw Libp2pException("No receive window for " + msg.id)
}
val newWindow = recWindow.addAndGet(-size.toInt())
if (newWindow < INITIAL_WINDOW_SIZE / 2) {
val delta = INITIAL_WINDOW_SIZE / 2
Expand Down Expand Up @@ -143,30 +147,36 @@ open class YamuxHandler(
}

override fun onLocalOpen(child: MuxChannel<ByteBuf>) {
onStreamCreate(child)
getChannelHandlerContext().writeAndFlush(YamuxFrame(child.id, YamuxType.DATA, YamuxFlags.SYN, 0))
}

override fun onRemoteCreated(child: MuxChannel<ByteBuf>) {
onStreamCreate(child)
}

private fun onStreamCreate(child: MuxChannel<ByteBuf>) {
receiveWindows.put(child.id, AtomicInteger(INITIAL_WINDOW_SIZE))
sendWindows.put(child.id, AtomicInteger(INITIAL_WINDOW_SIZE))
}

override fun onLocalDisconnect(child: MuxChannel<ByteBuf>) {
sendWindows.remove(child.id)
receiveWindows.remove(child.id)
sendBuffers.remove(child.id)
getChannelHandlerContext().writeAndFlush(YamuxFrame(child.id, YamuxType.DATA, YamuxFlags.FIN, 0))
}

override fun onLocalClose(child: MuxChannel<ByteBuf>) {
getChannelHandlerContext().writeAndFlush(YamuxFrame(child.id, YamuxType.DATA, YamuxFlags.RST, 0))
val sendWindow = sendWindows.remove(child.id)
val buffered = sendBuffers.remove(child.id)
if (buffered != null && sendWindow != null) {
buffered.flush(sendWindow, child.id)
}
getChannelHandlerContext().writeAndFlush(YamuxFrame(child.id, YamuxType.DATA, YamuxFlags.FIN, 0))
}

override fun onRemoteCreated(child: MuxChannel<ByteBuf>) {
receiveWindows.put(child.id, AtomicInteger(INITIAL_WINDOW_SIZE))
sendWindows.put(child.id, AtomicInteger(INITIAL_WINDOW_SIZE))
override fun onLocalClose(child: MuxChannel<ByteBuf>) {
getChannelHandlerContext().writeAndFlush(YamuxFrame(child.id, YamuxType.DATA, YamuxFlags.RST, 0))
}

override fun onChildClosed(child: MuxChannel<ByteBuf>) {
sendWindows.remove(child.id)
receiveWindows.remove(child.id)
sendBuffers.remove(child.id)
}

override fun generateNextId() =
Expand Down
Loading

0 comments on commit 0981ec6

Please sign in to comment.