diff --git a/ouroboros-network-framework/changelog.d/20230224_094922_tdammers_raw_bearer.rst b/ouroboros-network-framework/changelog.d/20230224_094922_tdammers_raw_bearer.rst new file mode 100644 index 00000000000..e2b5f727b00 --- /dev/null +++ b/ouroboros-network-framework/changelog.d/20230224_094922_tdammers_raw_bearer.rst @@ -0,0 +1,6 @@ +Added +----- + +- RawBearer API +- ToRawBearer typeclass +- ToRawBearer instances for `Socket`, `LocalSocket`, and `Simulation.Network.Snocket.FD` diff --git a/ouroboros-network-framework/ouroboros-network-framework.cabal b/ouroboros-network-framework/ouroboros-network-framework.cabal index 43479c7dfa5..9830606e53f 100644 --- a/ouroboros-network-framework/ouroboros-network-framework.cabal +++ b/ouroboros-network-framework/ouroboros-network-framework.cabal @@ -31,6 +31,7 @@ library Ouroboros.Network.IOManager Ouroboros.Network.Mux Ouroboros.Network.MuxMode + Ouroboros.Network.RawBearer Ouroboros.Network.Protocol.Handshake Ouroboros.Network.Protocol.Handshake.Type @@ -155,6 +156,7 @@ test-suite test Test.Ouroboros.Network.Socket Test.Ouroboros.Network.Subscription Test.Ouroboros.Network.RateLimiting + Test.Ouroboros.Network.RawBearer Test.Simulation.Network.Snocket build-depends: base >=4.14 && <4.17 diff --git a/ouroboros-network-framework/src/Ouroboros/Network/RawBearer.hs b/ouroboros-network-framework/src/Ouroboros/Network/RawBearer.hs new file mode 100644 index 00000000000..2b4cbc65ffa --- /dev/null +++ b/ouroboros-network-framework/src/Ouroboros/Network/RawBearer.hs @@ -0,0 +1,50 @@ +{-# LANGUAGE CPP #-} +{-# LANGUAGE MultiParamTypeClasses #-} + +module Ouroboros.Network.RawBearer +where + +import Network.Socket (Socket) +import qualified Network.Socket as Socket +import Foreign.Ptr (Ptr) +import Data.Word (Word8) + +#if defined(mingw32_HOST_OS) +import Data.Bits +import Foreign.Ptr (IntPtr (..), ptrToIntPtr) +import qualified System.Win32 as Win32 +import qualified System.Win32.Async as Win32.Async +import qualified System.Win32.NamedPipes as Win32 +#endif + +-- | Generalized API for sending and receiving raw bytes over a file +-- descriptor, socket, or similar object. +data RawBearer m = + RawBearer + { send :: Ptr Word8 -> Int -> m Int + , recv :: Ptr Word8 -> Int -> m Int + } + +class ToRawBearer m fd where + toRawBearer :: fd -> m (RawBearer m) + +instance ToRawBearer IO Socket where + toRawBearer s = + return RawBearer + { send = Socket.sendBuf s + , recv = Socket.recvBuf s + } + +#if defined(mingw32_HOST_OS) + +-- | We cannot declare an @instance ToRawBearer Win32.HANDLE@, because +-- 'Win32.Handle' is just a type alias for @Ptr ()@. So instead, we provide +-- this function, which can be used to implement 'ToRawBearer' elsewhere (e.g. +-- over a newtype). +win32HandleToRawBearer :: Win32.HANDLE -> RawBearer IO +win32HandleToRawBearer s = + RawBearer + { send = \buf size -> fromIntegral <$> Win32.win32_WriteFile s (castPtr buf) (fromIntegral size) + , recv = \buf size -> fromIntegral <$> Win32.win32_ReadFile s (castPtr buf) (fromIntegral size) + } +#endif diff --git a/ouroboros-network-framework/src/Ouroboros/Network/Snocket.hs b/ouroboros-network-framework/src/Ouroboros/Network/Snocket.hs index d5b1e4085c3..91d1a10595e 100644 --- a/ouroboros-network-framework/src/Ouroboros/Network/Snocket.hs +++ b/ouroboros-network-framework/src/Ouroboros/Network/Snocket.hs @@ -60,6 +60,7 @@ import qualified Network.Socket as Socket import Network.Mux.Bearer +import Ouroboros.Network.RawBearer import Ouroboros.Network.IOManager @@ -349,10 +350,17 @@ data LocalSocket = LocalSocket { getLocalHandle :: !LocalHandle } deriving (Eq, Generic) deriving Show via Quiet LocalSocket + +instance ToRawBearer IO LocalSocket where + toRawBearer = return . win32HandleToRawBearer . getLocalHandle + #else newtype LocalSocket = LocalSocket { getLocalHandle :: LocalHandle } deriving (Eq, Generic) deriving Show via Quiet LocalSocket + +instance ToRawBearer IO LocalSocket where + toRawBearer = toRawBearer . getLocalHandle #endif makeLocalBearer :: MakeBearer IO LocalSocket diff --git a/ouroboros-network-framework/src/Simulation/Network/Snocket.hs b/ouroboros-network-framework/src/Simulation/Network/Snocket.hs index 03db9005a02..ffc06dc4a77 100644 --- a/ouroboros-network-framework/src/Simulation/Network/Snocket.hs +++ b/ouroboros-network-framework/src/Simulation/Network/Snocket.hs @@ -51,6 +51,9 @@ import Control.Monad (when) import Control.Monad.Class.MonadThrow import Control.Monad.Class.MonadTime import Control.Monad.Class.MonadTimer +import Control.Monad.Class.MonadST +import Control.Monad.Class.MonadSay +import Control.Monad.ST.Unsafe (unsafeIOToST) import Control.Tracer (Tracer, contramap, contramapM, traceWith) import GHC.IO.Exception @@ -62,8 +65,12 @@ import Data.Map.Strict (Map) import qualified Data.Map.Strict as Map import Data.Typeable (Typeable) import Foreign.C.Error +import Foreign.Ptr (castPtr) +import Foreign.Marshal (copyBytes) import Numeric.Natural (Natural) import Text.Printf (printf) +import qualified Data.ByteString as BS +import qualified Data.ByteString.Lazy as LBS import Data.Monoid.Synchronisation (FirstToFinish (..)) import Data.Wedge @@ -74,6 +81,7 @@ import Network.Mux.Bearer.AttenuatedChannel import Ouroboros.Network.ConnectionId import Ouroboros.Network.ConnectionManager.Types (AddressType (..)) import Ouroboros.Network.Snocket +import Ouroboros.Network.RawBearer import Ouroboros.Network.Testing.Data.Script (Script (..), stepScriptSTM) @@ -526,8 +534,92 @@ instance Show addr => Show (FD_ m addr) where -- | File descriptor type. -- -newtype FD m peerAddr = FD { fdVar :: (StrictTVar m (FD_ m peerAddr)) } - +newtype FD m peerAddr = FD { fdVar :: StrictTVar m (FD_ m peerAddr) } + +instance ( MonadST m + , MonadThrow m + , MonadSay m + , MonadLabelledSTM m + , Show addr + ) => ToRawBearer m (FD m (TestAddress addr)) where + toRawBearer = makeRawFDBearer + +-- | Make a 'RawBearer' from an 'FD'. Since this is only used for testing, we +-- can bypass the requirement of moving raw bytes directly between file +-- descriptors and provided memory buffers, and we can instead covertly use +-- plain old 'ByteString' under the hood. This allows us to use the +-- 'AttenuatedChannel' inside the `FD_`, even though its send and receive +-- methods do not have the right format. +makeRawFDBearer :: forall addr m. + ( MonadST m + , MonadLabelledSTM m + , MonadThrow m + , MonadSay m + , Show addr + ) + => FD m (TestAddress addr) + -> m (RawBearer m) +makeRawFDBearer (FD {fdVar}) = do + (bufVar :: StrictTMVar m LBS.ByteString) <- newTMVarIO LBS.empty + return RawBearer + { send = \src srcSize -> do + labelTVarIO fdVar "sender" + say $ "Sending " ++ show srcSize ++ " bytes" + fd_ <- readTVarIO fdVar + case fd_ of + FDConnected _ conn -> do + bs <- withLiftST $ \liftST -> + liftST . unsafeIOToST $ BS.packCStringLen (castPtr src, srcSize) + let bsl = LBS.fromStrict bs + acWrite (connChannelLocal conn) bsl + say $ "Sent " ++ show srcSize ++ " bytes" + return srcSize + _ -> + throwIO (invalidError fd_) + , recv = \dst size -> do + labelTVarIO fdVar "receiver" + let size64 = fromIntegral size + say $ "Receiving " ++ show size ++ " bytes" + fd_ <- readTVarIO fdVar + case fd_ of + FDConnected _ conn -> do + say $ "Checking buffer" + bytesFromBuffer <- atomically $ takeTMVar bufVar + say $ "Buffer: " ++ show bytesFromBuffer + (lhs, rhs) <- if not (LBS.null bytesFromBuffer) then do + say $ "Reading up to " ++ show size ++ " bytes from buffer" + return (LBS.take size64 bytesFromBuffer, LBS.drop size64 bytesFromBuffer) + else do + bytesRead <- acRead (connChannelLocal conn) + say $ "Received " ++ show (LBS.length $ LBS.take size64 bytesRead) ++ " or more bytes" + return (LBS.take size64 bytesRead, LBS.drop size64 bytesRead) + say $ "Updating buffer; use: " ++ show lhs ++ " keep: " ++ show (LBS.take 10 rhs) ++ + if (LBS.length . LBS.take 11 $ rhs) == 11 then "..." else "" + atomically $ putTMVar bufVar rhs + say $ "Receive: buffer updated" + if LBS.null lhs then do + say $ "Receive: End of stream" + return 0 + else do + say $ "Receive: copying." + let bs = LBS.toStrict lhs + withLiftST $ \liftST -> + liftST . unsafeIOToST $ BS.useAsCStringLen bs $ \(src, srcSize) -> do + copyBytes dst (castPtr src) srcSize + return srcSize + _ -> + throwIO (invalidError fd_) + } + where + invalidError :: FD_ m (TestAddress addr) -> IOError + invalidError fd_ = IOError + { ioe_handle = Nothing + , ioe_type = InvalidArgument + , ioe_location = "Ouroboros.Network.Snocket.Sim.toRawBearer" + , ioe_description = printf "Invalid argument (%s)" (show fd_) + , ioe_errno = Nothing + , ioe_filename = Nothing + } makeFDBearer :: forall addr m. ( MonadMonotonicTime m @@ -551,17 +643,16 @@ makeFDBearer = MakeBearer $ \sduTimeout muxTracer FD { fdVar } -> do (connChannelLocal conn) FDClosed {} -> throwIO (invalidError fd_) - where - -- io errors - invalidError :: FD_ m (TestAddress addr) -> IOError - invalidError fd_ = IOError - { ioe_handle = Nothing - , ioe_type = InvalidArgument - , ioe_location = "Ouroboros.Network.Snocket.Sim.toBearer" - , ioe_description = printf "Invalid argument (%s)" (show fd_) - , ioe_errno = Nothing - , ioe_filename = Nothing - } + where + invalidError :: FD_ m (TestAddress addr) -> IOError + invalidError fd_ = IOError + { ioe_handle = Nothing + , ioe_type = InvalidArgument + , ioe_location = "Ouroboros.Network.Snocket.Sim.toBearer" + , ioe_description = printf "Invalid argument (%s)" (show fd_) + , ioe_errno = Nothing + , ioe_filename = Nothing + } -- -- Simulated snockets diff --git a/ouroboros-network-framework/test/Main.hs b/ouroboros-network-framework/test/Main.hs index 22db7863812..a889060bd8c 100644 --- a/ouroboros-network-framework/test/Main.hs +++ b/ouroboros-network-framework/test/Main.hs @@ -9,6 +9,7 @@ import qualified Test.Ouroboros.Network.Server2 as Server2 import qualified Test.Ouroboros.Network.Socket as Socket import qualified Test.Ouroboros.Network.Subscription as Subscription import qualified Test.Simulation.Network.Snocket as Snocket +import qualified Test.Ouroboros.Network.RawBearer as RawBearer main :: IO () main = defaultMain tests @@ -23,6 +24,7 @@ tests = , Subscription.tests , RateLimiting.tests , Snocket.tests + , RawBearer.tests ] diff --git a/ouroboros-network-framework/test/Test/Ouroboros/Network/RawBearer.hs b/ouroboros-network-framework/test/Test/Ouroboros/Network/RawBearer.hs new file mode 100644 index 00000000000..c6bcbe99299 --- /dev/null +++ b/ouroboros-network-framework/test/Test/Ouroboros/Network/RawBearer.hs @@ -0,0 +1,190 @@ +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE RankNTypes #-} + +module Test.Ouroboros.Network.RawBearer +where + +import Ouroboros.Network.RawBearer +import Ouroboros.Network.Snocket +import Ouroboros.Network.Testing.Data.AbsBearerInfo +import Ouroboros.Network.IOManager + +import Control.Monad (when) +import Control.Monad.Class.MonadAsync +import Control.Monad.Class.MonadST (MonadST, withLiftST) +import Control.Monad.Class.MonadThrow (MonadThrow, bracket, throwIO, catchJust, finally) +import Control.Monad.Class.MonadFork (labelThisThread) +import Control.Monad.Class.MonadTimer (MonadTimer (..), threadDelay) +import Control.Monad.Class.MonadMVar +import Control.Monad.Class.MonadSay +import Control.Monad.IOSim hiding (liftST) +import Control.Monad.ST.Unsafe (unsafeIOToST) +import Data.ByteString (ByteString) +import qualified Data.ByteString as BS +import Foreign.Marshal (copyBytes, mallocBytes, free) +import Foreign.Ptr (castPtr, plusPtr) +import qualified Network.Socket as Socket +import Simulation.Network.Snocket as SimSnocket +import System.Directory (removeFile) +import System.IO.Error (isDoesNotExistErrorType, ioeGetErrorType) +import Control.Tracer (nullTracer) +import Control.Exception (Exception) + +import Test.Tasty +import Test.Tasty.QuickCheck +import Test.Simulation.Network.Snocket (toBearerInfo) + +tests :: TestTree +tests = testGroup "Ouroboros.Network.RawBearer" + [ testProperty "raw bearer send receive simulated socket" prop_raw_bearer_send_and_receive_iosim + , testProperty "raw bearer send receive unix socket" prop_raw_bearer_send_and_receive_unix + ] + +prop_raw_bearer_send_and_receive_unix :: Message -> Property +prop_raw_bearer_send_and_receive_unix msg = + ioProperty $ withIOManager $ \iomgr -> do + let clientName = "unix_socket_client.test" + let serverName = "unix_socket_server.test" + cleanUp clientName + cleanUp serverName + let clientAddr = Socket.SockAddrUnix clientName + let serverAddr = Socket.SockAddrUnix serverName + rawBearerSendAndReceive + (socketSnocket iomgr) + clientAddr serverAddr + msg `finally` do + cleanUp clientName + cleanUp serverName + where + cleanUp name = do + catchJust (\e -> if isDoesNotExistErrorType (ioeGetErrorType e) then Just () else Nothing) + (removeFile name) + (\_ -> return ()) + +prop_raw_bearer_send_and_receive_iosim :: Int -> Int -> Message -> Property +prop_raw_bearer_send_and_receive_iosim clientInt serverInt msg = + (clientInt /= serverInt) ==> + iosimProperty $ + SimSnocket.withSnocket + nullTracer + (toBearerInfo absNoAttenuation) + mempty $ \snocket _observe -> do + rawBearerSendAndReceive + snocket + (TestAddress clientInt) + (TestAddress serverInt) + msg + +newtype Message = Message { messageBytes :: ByteString } + deriving (Show, Eq, Ord) + +instance Arbitrary Message where + shrink = filter (not . BS.null . messageBytes) . fmap (Message . BS.pack) . shrink . BS.unpack . messageBytes + arbitrary = Message . BS.pack <$> listOf1 arbitrary + +newtype TestError = + TestError String + deriving (Show) + +instance Exception TestError where + +rawBearerSendAndReceive :: forall m fd addr + . ( MonadST m + , MonadThrow m + , MonadAsync m + -- , MonadTimer m + , MonadMVar m + , MonadSay m + , ToRawBearer m fd + ) + => Snocket m fd addr + -> addr + -> addr + -> Message + -> m Property +rawBearerSendAndReceive snocket clientAddr serverAddr msg = + withLiftST $ \liftST -> do + let io = liftST . unsafeIOToST + let size = BS.length (messageBytes msg) + retVar <- newEmptyMVar + senderDone <- newEmptyMVar + let sender = bracket (openToConnect snocket clientAddr) (close snocket) $ \s -> do + say "sender: connecting" + connect snocket s serverAddr + say "sender: connected" + bearer <- toRawBearer s + bracket (io $ mallocBytes size) (io . free) $ \srcBuf -> do + io $ BS.useAsCStringLen (messageBytes msg) + (uncurry (copyBytes srcBuf)) + let go _ 0 = do + say "sender: done" + return () + go buf n = do + say $ "sender: " ++ show n ++ " bytes left" + bytesSent <- send bearer buf n + when (bytesSent == 0) (throwIO $ TestError "sender: premature hangup") + let n' = n - bytesSent + say $ "sender: " ++ show bytesSent ++ " bytes sent, " ++ show n' ++ " remaining" + go (plusPtr buf bytesSent) n' + go (castPtr srcBuf) size + putMVar senderDone () + receiver s = do + let acceptLoop :: Accept m fd addr -> m () + acceptLoop accept0 = do + say "receiver: accepting connection" + (accepted, acceptNext) <- runAccept accept0 + case accepted :: Accepted fd addr of + AcceptFailure err -> + throwIO err + Accepted s' _ -> do + labelThisThread "accept" + say "receiver: connection accepted" + flip finally (say "receiver: closing connection" >> close snocket s' >> say "receiver: connection closed") $ do + bearer <- toRawBearer s' + retval <- bracket (io $ mallocBytes size) (io . free) $ \dstBuf -> do + let go _ 0 = do + say "receiver: done receiving" + return () + go buf n = do + say $ "receiver: " ++ show n ++ " bytes left" + bytesReceived <- recv bearer buf n + when (bytesReceived == 0) (throwIO $ TestError "receiver: premature hangup") + let n' = n - bytesReceived + say $ "receiver: " ++ show bytesReceived ++ " bytes received, " ++ show n' ++ " remaining" + go (plusPtr buf bytesReceived) n' + go (castPtr dstBuf) size + io (BS.packCStringLen (castPtr dstBuf, size)) + say $ "receiver: received " ++ show retval + written <- tryPutMVar retVar retval + say $ if written then "receiver: stored " ++ show retval else "receiver: already have result" + say "receiver: finishing connection" + acceptLoop acceptNext + accept snocket s >>= acceptLoop + + resBSEither <- bracket (open snocket (addrFamily snocket serverAddr)) (close snocket) $ \s -> do + say "receiver: starting" + bind snocket s serverAddr + listen snocket s + say "receiver: listening" + race + (sender `concurrently` receiver s) + (takeMVar retVar <* takeMVar senderDone) + return $ resBSEither === Right (messageBytes msg) + +iosimProperty :: (forall s . IOSim s Property) + -> Property +iosimProperty sim = + let tr = runSimTrace sim + in case traceResult True tr of + Left e -> counterexample + (unlines + [ "=== Say Events ===" + , unlines (selectTraceEventsSay' tr) + , "=== Trace Events ===" + , unlines (show `map` traceEvents tr) + , "=== Error ===" + , show e ++ "\n" + ]) + False + Right prop -> prop +