diff --git a/src/main/scala/scalaz/nio/Buffer.scala b/src/main/scala/scalaz/nio/Buffer.scala index 349f93c..d171c7f 100644 --- a/src/main/scala/scalaz/nio/Buffer.scala +++ b/src/main/scala/scalaz/nio/Buffer.scala @@ -61,6 +61,9 @@ object ByteBuffer { def apply(capacity: Int): IO[Exception, ByteBuffer] = IO.syncException(JByteBuffer.allocate(capacity)).map(new ByteBuffer(_)) + + def apply(bytes: Array[Byte]): IO[Exception, ByteBuffer] = + IO.syncException(JByteBuffer.wrap(bytes)).map(new ByteBuffer(_)) } object Buffer { diff --git a/src/main/scala/scalaz/nio/channels/SocketClient.scala b/src/main/scala/scalaz/nio/channels/SocketClient.scala new file mode 100644 index 0000000..52fae8c --- /dev/null +++ b/src/main/scala/scalaz/nio/channels/SocketClient.scala @@ -0,0 +1,40 @@ +package scalaz.nio.channels + +import scalaz.nio.{ ByteBuffer, InetSocketAddress, SocketAddress } +import scalaz.zio.{ IO, Managed } + +class SocketClient private (channel: AsynchronousSocketChannel) { + + def write(bytes: Array[Byte]): IO[Exception, Unit] = + for { + buffer <- ByteBuffer(bytes) + _ <- channel.write(buffer) + } yield () + + def read(numBytes: Int): IO[Exception, Array[Byte]] = + for { + buffer <- ByteBuffer(numBytes) + _ <- channel.read(buffer) + array <- buffer.array + } yield array + + def close: IO[Exception, Unit] = channel.close + +} + +object SocketClient { + + def apply(host: String, port: Int): Managed[Exception, SocketClient] = + Managed(for { + address <- SocketAddress.inetSocketAddress(host, port) + channel <- AsynchronousSocketChannel() + _ <- channel.connect(address) + } yield new SocketClient(channel))(_.close.attempt.void) + + def apply(address: InetSocketAddress): Managed[Exception, SocketClient] = + Managed(for { + channel <- AsynchronousSocketChannel() + _ <- channel.connect(address) + } yield new SocketClient(channel))(_.close.attempt.void) + +} diff --git a/src/test/scala/scalaz/nio/ChannelSuite.scala b/src/test/scala/scalaz/nio/ChannelSuite.scala index 1db93b9..4b55d96 100644 --- a/src/test/scala/scalaz/nio/ChannelSuite.scala +++ b/src/test/scala/scalaz/nio/ChannelSuite.scala @@ -1,6 +1,10 @@ package scalaz.nio -import scalaz.nio.channels.{ AsynchronousServerSocketChannel, AsynchronousSocketChannel } +import scalaz.nio.channels.{ + AsynchronousServerSocketChannel, + AsynchronousSocketChannel, + SocketClient +} import scalaz.zio.{ IO, RTS } import testz.{ Harness, assert } @@ -8,47 +12,86 @@ object ChannelSuite extends RTS { def tests[T](harness: Harness[T]): T = { import harness._ - section(test("read/write") { () => - val inetAddress = InetAddress.localHost - .flatMap(iAddr => SocketAddress.inetSocketAddress(iAddr, 1337)) - - def echoServer: IO[Exception, Unit] = - for { - address <- inetAddress - sink <- Buffer.byte(3) - server <- AsynchronousServerSocketChannel() - _ <- server.bind(address) - worker <- server.accept - _ <- worker.read(sink) - _ <- sink.flip - _ <- worker.write(sink) - _ <- worker.close - _ <- server.close - } yield () - - def echoClient: IO[Exception, Boolean] = - for { - address <- inetAddress - src <- Buffer.byte(3) - client <- AsynchronousSocketChannel() - _ <- client.connect(address) - sent <- src.array - _ = sent.update(0, 1) - _ <- client.write(src) - _ <- src.flip - _ <- client.read(src) - received <- src.array - _ <- client.close - } yield sent.sameElements(received) - - val testProgram: IO[Exception, Boolean] = for { - serverFiber <- echoServer.fork - clientFiber <- echoClient.fork - _ <- serverFiber.join - same <- clientFiber.join - } yield same - - assert(unsafeRun(testProgram)) - }) + section( + test("read/write") { () => + val inetAddress = InetAddress.localHost + .flatMap(iAddr => SocketAddress.inetSocketAddress(iAddr, 1337)) + + def echoServer: IO[Exception, Unit] = + for { + address <- inetAddress + sink <- Buffer.byte(3) + server <- AsynchronousServerSocketChannel() + _ <- server.bind(address) + worker <- server.accept + _ <- worker.read(sink) + _ <- sink.flip + _ <- worker.write(sink) + _ <- worker.close + _ <- server.close + } yield () + + def echoClient: IO[Exception, Boolean] = + for { + address <- inetAddress + src <- Buffer.byte(3) + client <- AsynchronousSocketChannel() + _ <- client.connect(address) + sent <- src.array + _ = sent.update(0, 1) + _ <- client.write(src) + _ <- src.flip + _ <- client.read(src) + received <- src.array + _ <- client.close + } yield sent.sameElements(received) + + val testProgram: IO[Exception, Boolean] = for { + serverFiber <- echoServer.fork + clientFiber <- echoClient.fork + _ <- serverFiber.join + same <- clientFiber.join + } yield same + + assert(unsafeRun(testProgram)) + }, + test("read/write with SocketClient") { () => + val inetAddress = InetAddress.localHost + .flatMap(iAddr => SocketAddress.inetSocketAddress(iAddr, 1337)) + + def echoServer: IO[Exception, Unit] = + for { + address <- inetAddress + sink <- Buffer.byte(3) + server <- AsynchronousServerSocketChannel() + _ <- server.bind(address) + worker <- server.accept + _ <- worker.read(sink) + _ <- sink.flip + _ <- worker.write(sink) + _ <- worker.close + _ <- server.close + } yield () + + def echoClient(address: InetSocketAddress): IO[Exception, Boolean] = + SocketClient(address).use { client => + val sent: Array[Byte] = Array(0, 1, 2) + for { + _ <- client.write(sent) + received <- client.read(sent.length) + } yield sent.sameElements(received) + } + + val testProgram: IO[Exception, Boolean] = for { + address <- inetAddress + serverFiber <- echoServer.fork + clientFiber <- echoClient(address).fork + _ <- serverFiber.join + same <- clientFiber.join + } yield same + + assert(unsafeRun(testProgram)) + } + ) } }