From 270583ad38edee3a7ab82986eda1230f03f7c907 Mon Sep 17 00:00:00 2001 From: Finley McIlwaine Date: Mon, 22 Jan 2024 19:54:27 -0800 Subject: [PATCH] Add StreamType wrappers to Client.Binary --- src/Network/GRPC/Client/Binary.hs | 58 +++++++++++++++++++++++++++++-- src/Network/GRPC/Server/Binary.hs | 6 ++-- 2 files changed, 59 insertions(+), 5 deletions(-) diff --git a/src/Network/GRPC/Client/Binary.hs b/src/Network/GRPC/Client/Binary.hs index 25abeb1c..c19f9d2c 100644 --- a/src/Network/GRPC/Client/Binary.hs +++ b/src/Network/GRPC/Client/Binary.hs @@ -4,20 +4,29 @@ -- -- import Network.GRPC.Client.Binary qualified as Binary module Network.GRPC.Client.Binary ( - -- | Convenience wrappers using @binary@ for serialization/deserialization + -- * Convenience wrappers using @binary@ for serialization/deserialization sendInput , sendFinalInput , recvOutput , recvFinalOutput + + -- * Stream types + , nonStreaming + , clientStreaming + , serverStreaming + , biDiStreaming ) where +import Control.Monad +import Control.Monad.Catch import Control.Monad.IO.Class import Data.Binary import Network.GRPC.Client (Call) import Network.GRPC.Client qualified as Client -import Network.GRPC.Common.Binary import Network.GRPC.Common +import Network.GRPC.Common.Binary +import Network.GRPC.Common.StreamType qualified as StreamType {------------------------------------------------------------------------------- Convenience wrappers using @binary@ for serialization/deserialization @@ -61,3 +70,48 @@ recvFinalOutput :: forall out serv meth m. recvFinalOutput call = liftIO $ do (out, md) <- Client.recvFinalOutput call (, md) <$> decodeOrThrow out + +{------------------------------------------------------------------------------- + Wrappers for common streaming types that handle encoding/decoding +-------------------------------------------------------------------------------} + +-- | Wrapper for 'StreamType.nonStreaming' that handles encoding\/decoding of +-- input\/output. +nonStreaming :: forall inp out serv meth m. + (Binary inp, Binary out, MonadThrow m) + => StreamType.NonStreamingHandler m (BinaryRpc serv meth) + -> inp -> m out +nonStreaming h inp = + StreamType.nonStreaming h (encode inp) >>= decodeOrThrow + +-- | Wrapper for 'StreamType.clientStreaming' that handles encoding\/decoding of +-- input\/output. +clientStreaming :: forall inp out serv meth m. + (Binary inp, Binary out, MonadThrow m) + => StreamType.ClientStreamingHandler m (BinaryRpc serv meth) + -> m (StreamElem NoMetadata inp) + -> m out +clientStreaming h f = + StreamType.clientStreaming h (fmap encode <$> f) >>= decodeOrThrow + +-- | Wrapper for 'StreamType.serverStreaming' that handles encoding\/decoding of +-- input\/output. +serverStreaming :: forall inp out serv meth m. + (Binary inp, Binary out, MonadThrow m) + => StreamType.ServerStreamingHandler m (BinaryRpc serv meth) + -> inp + -> (out -> m ()) + -> m () +serverStreaming h inp f = + StreamType.serverStreaming h (encode inp) (f <=< decodeOrThrow) + +-- | Wrapper for 'StreamType.biDiStreaming' that handles encoding\/decoding of +-- input\/output. +biDiStreaming :: forall inp out serv meth m. + (Binary inp, Binary out, MonadThrow m) + => StreamType.BiDiStreamingHandler m (BinaryRpc serv meth) + -> m (StreamElem NoMetadata inp) + -> (out -> m ()) + -> m () +biDiStreaming h inp f = + StreamType.biDiStreaming h (fmap encode <$> inp) (f <=< decodeOrThrow) diff --git a/src/Network/GRPC/Server/Binary.hs b/src/Network/GRPC/Server/Binary.hs index 46382924..3a8f202b 100644 --- a/src/Network/GRPC/Server/Binary.hs +++ b/src/Network/GRPC/Server/Binary.hs @@ -98,9 +98,9 @@ mkServerStreaming f = StreamType.mkServerStreaming $ \inp send -> do mkBiDiStreaming :: forall m serv meth. MonadThrow m => ( (forall inp. Binary inp => m (StreamElem NoMetadata inp)) - -> (forall out. Binary out => out -> m ()) - -> m () - ) + -> (forall out. Binary out => out -> m ()) + -> m () + ) -> StreamType.BiDiStreamingHandler m (BinaryRpc serv meth) mkBiDiStreaming f = StreamType.mkBiDiStreaming $ \recv send -> f (recv >>= traverse decodeOrThrow) (send . encode)