diff --git a/auto-update/Control/AutoUpdate.hs b/auto-update/Control/AutoUpdate.hs index 723dbd2d1..a5b3daf3b 100644 --- a/auto-update/Control/AutoUpdate.hs +++ b/auto-update/Control/AutoUpdate.hs @@ -1,4 +1,5 @@ {-# LANGUAGE CPP #-} + -- | In a multithreaded environment, running actions on a regularly scheduled -- background thread can dramatically improve performance. -- For example, web servers need to return the current time with each HTTP response. @@ -29,39 +30,52 @@ -- -- For more examples, . module Control.AutoUpdate ( - -- * Type - UpdateSettings - , defaultUpdateSettings - -- * Accessors - , updateAction - , updateFreq - , updateSpawnThreshold - -- * Creation - , mkAutoUpdate - , mkAutoUpdateWithModify - ) where + -- * Type + UpdateSettings, + defaultUpdateSettings, + + -- * Accessors + updateAction, + updateFreq, + updateSpawnThreshold, + + -- * Creation + mkAutoUpdate, + mkAutoUpdateWithModify, +) where #if __GLASGOW_HASKELL__ < 709 import Control.Applicative ((<*>)) #endif -import Control.Concurrent (forkIO, threadDelay) -import Control.Concurrent.MVar (newEmptyMVar, putMVar, readMVar, - takeMVar, tryPutMVar) -import Control.Exception (SomeException, catch, mask_, throw, - try) -import Control.Monad (void) -import Data.IORef (newIORef, readIORef, writeIORef) -import Data.Maybe (fromMaybe) +import Control.Concurrent (forkIO, threadDelay) +import Control.Concurrent.MVar ( + newEmptyMVar, + putMVar, + readMVar, + takeMVar, + tryPutMVar, + ) +import Control.Exception ( + SomeException, + catch, + mask_, + throw, + try, + ) +import Control.Monad (void) +import Data.IORef (newIORef, readIORef, writeIORef) +import Data.Maybe (fromMaybe) -- | Default value for creating an 'UpdateSettings'. -- -- @since 0.1.0 defaultUpdateSettings :: UpdateSettings () -defaultUpdateSettings = UpdateSettings - { updateFreq = 1000000 - , updateSpawnThreshold = 3 - , updateAction = return () - } +defaultUpdateSettings = + UpdateSettings + { updateFreq = 1000000 + , updateSpawnThreshold = 3 + , updateAction = return () + } -- | Settings to control how values are updated. -- @@ -74,7 +88,7 @@ defaultUpdateSettings = UpdateSettings -- -- @since 0.1.0 data UpdateSettings a = UpdateSettings - { updateFreq :: Int + { updateFreq :: Int -- ^ Microseconds between update calls. Same considerations as -- 'threadDelay' apply. -- @@ -91,7 +105,7 @@ data UpdateSettings a = UpdateSettings -- Default: 3 -- -- @since 0.1.0 - , updateAction :: IO a + , updateAction :: IO a -- ^ Action to be performed to get the current value. -- -- Default: does nothing. @@ -137,12 +151,16 @@ mkAutoUpdateHelper us updateActionModify = do let fillRefOnExit f = do eres <- try f case eres of - Left e -> writeIORef currRef $ error $ - "Control.AutoUpdate.mkAutoUpdate: worker thread exited with exception: " - ++ show (e :: SomeException) - Right () -> writeIORef currRef $ error $ - "Control.AutoUpdate.mkAutoUpdate: worker thread exited normally, " - ++ "which should be impossible due to usage of infinite loop" + Left e -> + writeIORef currRef $ + error $ + "Control.AutoUpdate.mkAutoUpdate: worker thread exited with exception: " + ++ show (e :: SomeException) + Right () -> + writeIORef currRef $ + error $ + "Control.AutoUpdate.mkAutoUpdate: worker thread exited normally, " + ++ "which should be impossible due to usage of infinite loop" -- fork the worker thread immediately. Note that we mask async exceptions, -- but *not* in an uninterruptible manner. This will allow a diff --git a/auto-update/Control/AutoUpdate/Util.hs b/auto-update/Control/AutoUpdate/Util.hs index 8978b439f..7e2348911 100644 --- a/auto-update/Control/AutoUpdate/Util.hs +++ b/auto-update/Control/AutoUpdate/Util.hs @@ -1,7 +1,8 @@ {-# LANGUAGE CPP #-} -module Control.AutoUpdate.Util - ( atomicModifyIORef' - ) where + +module Control.AutoUpdate.Util ( + atomicModifyIORef', +) where #ifndef MIN_VERSION_base #define MIN_VERSION_base(x,y,z) 1 diff --git a/auto-update/Control/Debounce.hs b/auto-update/Control/Debounce.hs index bae733ae6..2f77d734f 100644 --- a/auto-update/Control/Debounce.hs +++ b/auto-update/Control/Debounce.hs @@ -23,37 +23,40 @@ -- See the fast-logger package ("System.Log.FastLogger") for real-world usage. -- -- @since 0.1.2 -module Control.Debounce - ( -- * Type - DI.DebounceSettings - , defaultDebounceSettings - -- * Accessors - , DI.debounceFreq - , DI.debounceAction - , DI.debounceEdge - , DI.leadingEdge - , DI.trailingEdge - -- * Creation - , mkDebounce - ) where +module Control.Debounce ( + -- * Type + DI.DebounceSettings, + defaultDebounceSettings, -import Control.Concurrent (newEmptyMVar, threadDelay) + -- * Accessors + DI.debounceFreq, + DI.debounceAction, + DI.debounceEdge, + DI.leadingEdge, + DI.trailingEdge, + + -- * Creation + mkDebounce, +) where + +import Control.Concurrent (newEmptyMVar, threadDelay) import qualified Control.Debounce.Internal as DI -- | Default value for creating a 'DebounceSettings'. -- -- @since 0.1.2 defaultDebounceSettings :: DI.DebounceSettings -defaultDebounceSettings = DI.DebounceSettings - { DI.debounceFreq = 1000000 - , DI.debounceAction = return () - , DI.debounceEdge = DI.leadingEdge - } +defaultDebounceSettings = + DI.DebounceSettings + { DI.debounceFreq = 1000000 + , DI.debounceAction = return () + , DI.debounceEdge = DI.leadingEdge + } -- | Generate an action which will trigger the debounced action to be performed. -- -- @since 0.1.2 mkDebounce :: DI.DebounceSettings -> IO (IO ()) mkDebounce settings = do - baton <- newEmptyMVar - DI.mkDebounceInternal baton threadDelay settings + baton <- newEmptyMVar + DI.mkDebounceInternal baton threadDelay settings diff --git a/auto-update/Control/Debounce/Internal.hs b/auto-update/Control/Debounce/Internal.hs index 1cdb970b4..1426a97b8 100644 --- a/auto-update/Control/Debounce/Internal.hs +++ b/auto-update/Control/Debounce/Internal.hs @@ -2,17 +2,22 @@ -- | Unstable API which exposes internals for testing. module Control.Debounce.Internal ( - DebounceSettings(..) - , DebounceEdge(..) - , leadingEdge - , trailingEdge - , mkDebounceInternal - ) where + DebounceSettings (..), + DebounceEdge (..), + leadingEdge, + trailingEdge, + mkDebounceInternal, +) where -import Control.Concurrent (forkIO) -import Control.Concurrent.MVar (takeMVar, tryPutMVar, tryTakeMVar, MVar) -import Control.Exception (SomeException, handle, mask_) -import Control.Monad (forever, void) +import Control.Concurrent (forkIO) +import Control.Concurrent.MVar ( + MVar, + takeMVar, + tryPutMVar, + tryTakeMVar, + ) +import Control.Exception (SomeException, handle, mask_) +import Control.Monad (forever, void) -- | Settings to control how debouncing should work. -- @@ -25,7 +30,7 @@ import Control.Monad (forever, void) -- -- @since 0.1.2 data DebounceSettings = DebounceSettings - { debounceFreq :: Int + { debounceFreq :: Int -- ^ Length of the debounce timeout period in microseconds. -- -- Default: 1 second (1000000) @@ -52,16 +57,15 @@ data DebounceSettings = DebounceSettings -- edge of the timeout. -- -- @since 0.1.6 -data DebounceEdge = - Leading - -- ^ Perform the action immediately, and then begin a cooldown period. - -- If the trigger happens again during the cooldown, wait until the end of the cooldown - -- and then perform the action again, then enter a new cooldown period. - | Trailing - -- ^ Start a cooldown period and perform the action when the period ends. If another trigger - -- happens during the cooldown, it has no effect. - deriving (Show, Eq) - +data DebounceEdge + = -- | Perform the action immediately, and then begin a cooldown period. + -- If the trigger happens again during the cooldown, wait until the end of the cooldown + -- and then perform the action again, then enter a new cooldown period. + Leading + | -- | Start a cooldown period and perform the action when the period ends. If another trigger + -- happens during the cooldown, it has no effect. + Trailing + deriving (Show, Eq) -- | Perform the action immediately, and then begin a cooldown period. -- If the trigger happens again during the cooldown, wait until the end of the cooldown @@ -78,19 +82,20 @@ leadingEdge = Leading trailingEdge :: DebounceEdge trailingEdge = Trailing -mkDebounceInternal :: MVar () -> (Int -> IO ()) -> DebounceSettings -> IO (IO ()) +mkDebounceInternal + :: MVar () -> (Int -> IO ()) -> DebounceSettings -> IO (IO ()) mkDebounceInternal baton delayFn (DebounceSettings freq action edge) = do mask_ $ void $ forkIO $ forever $ do takeMVar baton case edge of - Leading -> do - ignoreExc action - delayFn freq - Trailing -> do - delayFn freq - -- Empty the baton of any other activations during the interval - void $ tryTakeMVar baton - ignoreExc action + Leading -> do + ignoreExc action + delayFn freq + Trailing -> do + delayFn freq + -- Empty the baton of any other activations during the interval + void $ tryTakeMVar baton + ignoreExc action return $ void $ tryPutMVar baton () diff --git a/auto-update/Control/Reaper.hs b/auto-update/Control/Reaper.hs index 20720fe6a..eb09b1a9c 100644 --- a/auto-update/Control/Reaper.hs +++ b/auto-update/Control/Reaper.hs @@ -1,5 +1,5 @@ -{-# LANGUAGE RecordWildCards #-} -{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE RecordWildCards #-} -- | This module provides the ability to create reapers: dedicated cleanup -- threads. These threads will automatically spawn and die based on the @@ -12,28 +12,32 @@ -- For real-world usage, search the -- for imports of "Control.Reaper". module Control.Reaper ( - -- * Example: Regularly cleaning a cache - -- $example1 + -- * Example: Regularly cleaning a cache + -- $example1 - -- * Settings - ReaperSettings - , defaultReaperSettings - -- * Accessors - , reaperAction - , reaperDelay - , reaperCons - , reaperNull - , reaperEmpty - -- * Type - , Reaper(..) - -- * Creation - , mkReaper - -- * Helper - , mkListAction - ) where + -- * Settings + ReaperSettings, + defaultReaperSettings, + + -- * Accessors + reaperAction, + reaperDelay, + reaperCons, + reaperNull, + reaperEmpty, + + -- * Type + Reaper (..), + + -- * Creation + mkReaper, + + -- * Helper + mkListAction, +) where import Control.AutoUpdate.Util (atomicModifyIORef') -import Control.Concurrent (forkIO, threadDelay, killThread, ThreadId) +import Control.Concurrent (ThreadId, forkIO, killThread, threadDelay) import Control.Exception (mask_) import Data.IORef (IORef, newIORef, readIORef, writeIORef) @@ -90,30 +94,34 @@ data ReaperSettings workload item = ReaperSettings -- -- @since 0.1.1 defaultReaperSettings :: ReaperSettings [item] item -defaultReaperSettings = ReaperSettings - { reaperAction = \wl -> return (wl ++) - , reaperDelay = 30000000 - , reaperCons = (:) - , reaperNull = null - , reaperEmpty = [] - } +defaultReaperSettings = + ReaperSettings + { reaperAction = \wl -> return (wl ++) + , reaperDelay = 30000000 + , reaperCons = (:) + , reaperNull = null + , reaperEmpty = [] + } -- | A data structure to hold reaper APIs. -data Reaper workload item = Reaper { - -- | Adding an item to the workload - reaperAdd :: item -> IO () - -- | Reading workload. - , reaperRead :: IO workload - -- | Stopping the reaper thread if exists. +data Reaper workload item = Reaper + { reaperAdd :: item -> IO () + -- ^ Adding an item to the workload + , reaperRead :: IO workload + -- ^ Reading workload. + , reaperStop :: IO workload + -- ^ Stopping the reaper thread if exists. -- The current workload is returned. - , reaperStop :: IO workload - -- | Killing the reaper thread immediately if exists. - , reaperKill :: IO () - } + , reaperKill :: IO () + -- ^ Killing the reaper thread immediately if exists. + } -- | State of reaper. -data State workload = NoReaper -- ^ No reaper thread - | Workload !workload -- ^ The current jobs +data State workload + = -- | No reaper thread + NoReaper + | -- | The current jobs + Workload !workload -- | Create a reaper addition function. This function can be used to add -- new items to the workload. Spawning of reaper threads will be handled @@ -123,52 +131,62 @@ data State workload = NoReaper -- ^ No reaper thread mkReaper :: ReaperSettings workload item -> IO (Reaper workload item) mkReaper settings@ReaperSettings{..} = do stateRef <- newIORef NoReaper - tidRef <- newIORef Nothing - return Reaper { - reaperAdd = add settings stateRef tidRef - , reaperRead = readRef stateRef - , reaperStop = stop stateRef - , reaperKill = kill tidRef - } + tidRef <- newIORef Nothing + return + Reaper + { reaperAdd = add settings stateRef tidRef + , reaperRead = readRef stateRef + , reaperStop = stop stateRef + , reaperKill = kill tidRef + } where readRef stateRef = do mx <- readIORef stateRef case mx of - NoReaper -> return reaperEmpty + NoReaper -> return reaperEmpty Workload wl -> return wl stop stateRef = atomicModifyIORef' stateRef $ \mx -> case mx of - NoReaper -> (NoReaper, reaperEmpty) + NoReaper -> (NoReaper, reaperEmpty) Workload x -> (Workload reaperEmpty, x) kill tidRef = do mtid <- readIORef tidRef case mtid of - Nothing -> return () + Nothing -> return () Just tid -> killThread tid -add :: ReaperSettings workload item - -> IORef (State workload) -> IORef (Maybe ThreadId) - -> item -> IO () +add + :: ReaperSettings workload item + -> IORef (State workload) + -> IORef (Maybe ThreadId) + -> item + -> IO () add settings@ReaperSettings{..} stateRef tidRef item = mask_ $ do - next <- atomicModifyIORef' stateRef cons - next + next <- atomicModifyIORef' stateRef cons + next where - cons NoReaper = let wl = reaperCons item reaperEmpty - in (Workload wl, spawn settings stateRef tidRef) - cons (Workload wl) = let wl' = reaperCons item wl - in (Workload wl', return ()) + cons NoReaper = + let wl = reaperCons item reaperEmpty + in (Workload wl, spawn settings stateRef tidRef) + cons (Workload wl) = + let wl' = reaperCons item wl + in (Workload wl', return ()) -spawn :: ReaperSettings workload item - -> IORef (State workload) -> IORef (Maybe ThreadId) - -> IO () +spawn + :: ReaperSettings workload item + -> IORef (State workload) + -> IORef (Maybe ThreadId) + -> IO () spawn settings stateRef tidRef = do tid <- forkIO $ reaper settings stateRef tidRef writeIORef tidRef $ Just tid -reaper :: ReaperSettings workload item - -> IORef (State workload) -> IORef (Maybe ThreadId) - -> IO () +reaper + :: ReaperSettings workload item + -> IORef (State workload) + -> IORef (Maybe ThreadId) + -> IO () reaper settings@ReaperSettings{..} stateRef tidRef = do threadDelay reaperDelay -- Getting the current jobs. Push an empty job to the reference. @@ -181,15 +199,15 @@ reaper settings@ReaperSettings{..} stateRef tidRef = do next <- atomicModifyIORef' stateRef (check merge) next where - swapWithEmpty NoReaper = error "Control.Reaper.reaper: unexpected NoReaper (1)" + swapWithEmpty NoReaper = error "Control.Reaper.reaper: unexpected NoReaper (1)" swapWithEmpty (Workload wl) = (Workload reaperEmpty, wl) - check _ NoReaper = error "Control.Reaper.reaper: unexpected NoReaper (2)" + check _ NoReaper = error "Control.Reaper.reaper: unexpected NoReaper (2)" check merge (Workload wl) - -- If there is no job, reaper is terminated. - | reaperNull wl' = (NoReaper, writeIORef tidRef Nothing) - -- If there are jobs, carry them out. - | otherwise = (Workload wl', reaper settings stateRef tidRef) + -- If there is no job, reaper is terminated. + | reaperNull wl' = (NoReaper, writeIORef tidRef Nothing) + -- If there are jobs, carry them out. + | otherwise = (Workload wl', reaper settings stateRef tidRef) where wl' = merge wl @@ -199,19 +217,20 @@ reaper settings@ReaperSettings{..} stateRef tidRef = do -- expired. -- -- @since 0.1.1 -mkListAction :: (item -> IO (Maybe item')) - -> [item] - -> IO ([item'] -> [item']) +mkListAction + :: (item -> IO (Maybe item')) + -> [item] + -> IO ([item'] -> [item']) mkListAction f = go id where go !front [] = return front - go !front (x:xs) = do + go !front (x : xs) = do my <- f x let front' = case my of Nothing -> front - Just y -> front . (y:) + Just y -> front . (y :) go front' xs -- $example1 diff --git a/auto-update/Setup.hs b/auto-update/Setup.hs index 9a994af67..e8ef27dbb 100644 --- a/auto-update/Setup.hs +++ b/auto-update/Setup.hs @@ -1,2 +1,3 @@ import Distribution.Simple + main = defaultMain diff --git a/auto-update/test/Control/AutoUpdateSpec.hs b/auto-update/test/Control/AutoUpdateSpec.hs index 45261c082..7d95711e0 100644 --- a/auto-update/test/Control/AutoUpdateSpec.hs +++ b/auto-update/test/Control/AutoUpdateSpec.hs @@ -5,31 +5,33 @@ module Control.AutoUpdateSpec (spec) where -- import Control.Monad (replicateM_, forM_) -- import Data.IORef import Test.Hspec + -- import Test.Hspec.QuickCheck spec :: Spec spec = return () - -- do - -- prop "incrementer" $ \st' -> do - -- let st = abs st' `mod` 10000 - -- ref <- newIORef 0 - -- next <- mkAutoUpdate defaultUpdateSettings - -- { updateAction = atomicModifyIORef ref $ \i -> - -- let i' = succ i in i' `seq` (i', i') - -- , updateSpawnThreshold = st - -- , updateFreq = 10000 - -- } - -- forM_ [1..st + 1] $ \i -> do - -- j <- next - -- j `shouldBe` i +-- do +-- prop "incrementer" $ \st' -> do +-- let st = abs st' `mod` 10000 +-- ref <- newIORef 0 +-- next <- mkAutoUpdate defaultUpdateSettings +-- { updateAction = atomicModifyIORef ref $ \i -> +-- let i' = succ i in i' `seq` (i', i') +-- , updateSpawnThreshold = st +-- , updateFreq = 10000 +-- } + +-- forM_ [1..st + 1] $ \i -> do +-- j <- next +-- j `shouldBe` i - -- replicateM_ 50 $ do - -- i <- next - -- i `shouldBe` st + 2 +-- replicateM_ 50 $ do +-- i <- next +-- i `shouldBe` st + 2 - -- threadDelay 60000 - -- last1 <- readIORef ref - -- threadDelay 20000 - -- last2 <- readIORef ref - -- last2 `shouldBe` last1 +-- threadDelay 60000 +-- last1 <- readIORef ref +-- threadDelay 20000 +-- last2 <- readIORef ref +-- last2 `shouldBe` last1 diff --git a/auto-update/test/Control/DebounceSpec.hs b/auto-update/test/Control/DebounceSpec.hs index 7729eb3b0..0be944077 100644 --- a/auto-update/test/Control/DebounceSpec.hs +++ b/auto-update/test/Control/DebounceSpec.hs @@ -75,7 +75,6 @@ spec = describe "mkDebounce" $ do returnFromWait waitUntil 5 $ readIORef ref `shouldReturn` 1 - -- | Make a controllable delay function getWaitAction :: IO (p -> IO (), IO ()) getWaitAction = do @@ -87,20 +86,24 @@ getWaitAction = do -- | Get a debounce system with access to the internals for testing getDebounce :: DI.DebounceEdge -> IO (IORef Int, IO (), MVar (), IO ()) getDebounce edge = do - ref <- newIORef 0 - let action = modifyIORef ref (+ 1) + ref <- newIORef 0 + let action = modifyIORef ref (+ 1) - (waitAction, returnFromWait) <- getWaitAction + (waitAction, returnFromWait) <- getWaitAction - baton <- newEmptyMVar + baton <- newEmptyMVar - debounced <- DI.mkDebounceInternal baton waitAction defaultDebounceSettings { - debounceFreq = 5000000 -- unused - , debounceAction = action - , debounceEdge = edge - } + debounced <- + DI.mkDebounceInternal + baton + waitAction + defaultDebounceSettings + { debounceFreq = 5000000 -- unused + , debounceAction = action + , debounceEdge = edge + } - return (ref, debounced, baton, returnFromWait) + return (ref, debounced, baton, returnFromWait) -- | Pause briefly (100ms) pause :: IO () @@ -112,8 +115,9 @@ waitForBatonToBeTaken baton = waitUntil 5 $ tryReadMVar baton `shouldReturn` Not -- | Wait up to n seconds for an action to complete without throwing an HUnitFailure waitUntil :: Int -> IO a -> IO () waitUntil n action = recovering policy [handler] (\_status -> void action) - where policy = constantDelay 1000 `mappend` limitRetries (n * 1000) -- 1ms * n * 1000 tries = n seconds - handler _status = Handler (\HUnitFailure{} -> return True) + where + policy = constantDelay 1000 `mappend` limitRetries (n * 1000) -- 1ms * n * 1000 tries = n seconds + handler _status = Handler (\HUnitFailure{} -> return True) main :: IO () main = hspec spec diff --git a/auto-update/test/Control/ReaperSpec.hs b/auto-update/test/Control/ReaperSpec.hs index 64ccbf7c8..568a79cf6 100644 --- a/auto-update/test/Control/ReaperSpec.hs +++ b/auto-update/test/Control/ReaperSpec.hs @@ -4,11 +4,12 @@ module Control.ReaperSpec (spec) where -- import Control.Reaper -- import Data.IORef import Test.Hspec --- import Test.Hspec.QuickCheck +-- import Test.Hspec.QuickCheck spec :: Spec spec = return () + -- prop "works" $ \is -> do -- reaper <- mkReaper defaultReaperSettings -- { reaperAction = action diff --git a/fourmolu.yaml b/fourmolu.yaml new file mode 100644 index 000000000..13564e825 --- /dev/null +++ b/fourmolu.yaml @@ -0,0 +1,48 @@ +# Number of spaces per indentation step +indentation: 4 + +# Max line length for automatic line breaking +column-limit: 80 + +# Styling of arrows in type signatures (choices: trailing, leading, or leading-args) +function-arrows: leading + +# How to place commas in multi-line lists, records, etc. (choices: leading or trailing) +comma-style: leading + +# Styling of import/export lists (choices: leading, trailing, or diff-friendly) +import-export-style: diff-friendly + +# Whether to full-indent or half-indent 'where' bindings past the preceding body +indent-wheres: false + +# Whether to leave a space before an opening record brace +record-brace-space: false + +# Number of spaces between top-level declarations +newlines-between-decls: 1 + +# How to print Haddock comments (choices: single-line, multi-line, or multi-line-compact) +haddock-style: single-line + +# How to print module docstring +haddock-style-module: null + +# Styling of let blocks (choices: auto, inline, newline, or mixed) +let-style: inline + +# How to align the 'in' keyword with respect to the 'let' keyword (choices: left-align, right-align, or no-space) +in-style: right-align + +# Whether to put parentheses around a single constraint (choices: auto, always, or never) +single-constraint-parens: never + +# Output Unicode syntax (choices: detect, always, or never) +unicode: never + +# Give the programmer more choice on where to insert blank lines +respectful: true + +# Fixity information for operators +fixities: [] + diff --git a/mime-types/Network/Mime.hs b/mime-types/Network/Mime.hs index c27e46343..971a651a2 100644 --- a/mime-types/Network/Mime.hs +++ b/mime-types/Network/Mime.hs @@ -1,28 +1,32 @@ {-# LANGUAGE OverloadedStrings #-} -module Network.Mime - ( -- * Lookups - mimeByExt - , defaultMimeLookup - -- * Defaults - , defaultMimeType - , defaultMimeMap - , defaultExtensionMap - -- * Utilities - , fileNameExtensions - -- * Types - , FileName - , MimeType - , MimeMap - , Extension - , ExtensionMap - ) where -import qualified Data.List as L -import Data.Text (Text) -import qualified Data.Text as T +module Network.Mime ( + -- * Lookups + mimeByExt, + defaultMimeLookup, + + -- * Defaults + defaultMimeType, + defaultMimeMap, + defaultExtensionMap, + + -- * Utilities + fileNameExtensions, + + -- * Types + FileName, + MimeType, + MimeMap, + Extension, + ExtensionMap, +) where + import Data.ByteString (ByteString) import Data.ByteString.Char8 () +import qualified Data.List as L import qualified Data.Map as Map +import Data.Text (Text) +import qualified Data.Text as T -- | Maps extensions to mime types. type MimeMap = Map.Map Extension MimeType @@ -41,15 +45,17 @@ type FileName = Text type MimeType = ByteString -- | Look up a mime type from the given mime map and default mime type. -mimeByExt :: MimeMap - -> MimeType -- ^ default mime type - -> FileName - -> MimeType +mimeByExt + :: MimeMap + -> MimeType + -- ^ default mime type + -> FileName + -> MimeType mimeByExt mm def = go . fileNameExtensions where go [] = def - go (e:es) = + go (e : es) = case Map.lookup e mm of Nothing -> go es Just mt -> mt @@ -274,10 +280,16 @@ mimeAscList = , ("dna", "application/vnd.dna") , ("doc", "application/msword") , ("docm", "application/vnd.ms-word.document.macroenabled.12") - , ("docx", "application/vnd.openxmlformats-officedocument.wordprocessingml.document") + , + ( "docx" + , "application/vnd.openxmlformats-officedocument.wordprocessingml.document" + ) , ("dot", "application/msword") , ("dotm", "application/vnd.ms-word.template.macroenabled.12") - , ("dotx", "application/vnd.openxmlformats-officedocument.wordprocessingml.template") + , + ( "dotx" + , "application/vnd.openxmlformats-officedocument.wordprocessingml.template" + ) , ("dp", "application/vnd.osgi.dp") , ("dpg", "application/vnd.dpgraph") , ("dra", "audio/vnd.dra") @@ -494,11 +506,9 @@ mimeAscList = , ("lbd", "application/vnd.llamagraphics.life-balance.desktop") , ("lbe", "application/vnd.llamagraphics.life-balance.exchange+xml") , ("les", "application/vnd.hhe.lesson-player") - - -- Added after deliberation in PR (https://github.com/yesodweb/wai/pull/534) - -- Accepted mainly because of StackOverflow (http://stackoverflow.com/questions/7319555/how-to-add-less-to-iis-7-0) - , ("less", "text/css") - + , -- Added after deliberation in PR (https://github.com/yesodweb/wai/pull/534) + -- Accepted mainly because of StackOverflow (http://stackoverflow.com/questions/7319555/how-to-add-less-to-iis-7-0) + ("less", "text/css") , ("lha", "application/x-lzh-compressed") , ("link66", "application/vnd.route66.link66+xml") , ("list", "text/plain") @@ -729,16 +739,25 @@ mimeAscList = , ("portpkg", "application/vnd.macports.portpkg") , ("pot", "application/vnd.ms-powerpoint") , ("potm", "application/vnd.ms-powerpoint.template.macroenabled.12") - , ("potx", "application/vnd.openxmlformats-officedocument.presentationml.template") + , + ( "potx" + , "application/vnd.openxmlformats-officedocument.presentationml.template" + ) , ("ppam", "application/vnd.ms-powerpoint.addin.macroenabled.12") , ("ppd", "application/vnd.cups-ppd") , ("ppm", "image/x-portable-pixmap") , ("pps", "application/vnd.ms-powerpoint") , ("ppsm", "application/vnd.ms-powerpoint.slideshow.macroenabled.12") - , ("ppsx", "application/vnd.openxmlformats-officedocument.presentationml.slideshow") + , + ( "ppsx" + , "application/vnd.openxmlformats-officedocument.presentationml.slideshow" + ) , ("ppt", "application/vnd.ms-powerpoint") , ("pptm", "application/vnd.ms-powerpoint.presentation.macroenabled.12") - , ("pptx", "application/vnd.openxmlformats-officedocument.presentationml.presentation") + , + ( "pptx" + , "application/vnd.openxmlformats-officedocument.presentationml.presentation" + ) , ("pqa", "application/vnd.palm") , ("prc", "application/x-mobipocket-ebook") , ("pre", "application/vnd.lotus-freelance") diff --git a/recv/Network/Socket/BufferPool.hs b/recv/Network/Socket/BufferPool.hs index da3930e72..de99fb333 100644 --- a/recv/Network/Socket/BufferPool.hs +++ b/recv/Network/Socket/BufferPool.hs @@ -21,22 +21,25 @@ -- When the buffer gets small -- and usefless, a new large buffer is allocated. module Network.Socket.BufferPool ( - -- * Recv - Recv - , receive - , BufferPool - , newBufferPool - , withBufferPool - -- * RecvN - , RecvN - , makeRecvN - -- * Types - , Buffer - , BufSize - -- * Utilities - , mallocBS - , copy - ) where + -- * Recv + Recv, + receive, + BufferPool, + newBufferPool, + withBufferPool, + + -- * RecvN + RecvN, + makeRecvN, + + -- * Types + Buffer, + BufSize, + + -- * Utilities + mallocBS, + copy, +) where import Network.Socket.BufferPool.Buffer import Network.Socket.BufferPool.Recv diff --git a/recv/Network/Socket/BufferPool/Buffer.hs b/recv/Network/Socket/BufferPool/Buffer.hs index 7bdafac65..021c9ee30 100644 --- a/recv/Network/Socket/BufferPool/Buffer.hs +++ b/recv/Network/Socket/BufferPool/Buffer.hs @@ -1,16 +1,16 @@ module Network.Socket.BufferPool.Buffer ( - newBufferPool - , withBufferPool - , mallocBS - , copy - ) where + newBufferPool, + withBufferPool, + mallocBS, + copy, +) where import qualified Data.ByteString as BS -import Data.ByteString.Internal (ByteString(..)) -import Data.ByteString.Unsafe (unsafeTake, unsafeDrop) +import Data.ByteString.Internal (ByteString (..)) +import Data.ByteString.Unsafe (unsafeDrop, unsafeTake) import Data.IORef (newIORef, readIORef, writeIORef) import Foreign.ForeignPtr -import Foreign.Marshal.Alloc (mallocBytes, finalizerFree) +import Foreign.Marshal.Alloc (finalizerFree, mallocBytes) import Foreign.Marshal.Utils (copyBytes) import Foreign.Ptr (castPtr, plusPtr) @@ -36,8 +36,10 @@ newBufferPool l h = BufferPool l h <$> newIORef BS.empty withBufferPool :: BufferPool -> (Buffer -> BufSize -> IO Int) -> IO ByteString withBufferPool (BufferPool l h ref) f = do buf0 <- readIORef ref - buf <- if BS.length buf0 >= l then return buf0 - else mallocBS h + buf <- + if BS.length buf0 >= l + then return buf0 + else mallocBS h consumed <- withForeignBuffer buf f writeIORef ref $ unsafeDrop consumed buf return $ unsafeTake consumed buf diff --git a/recv/Network/Socket/BufferPool/Types.hs b/recv/Network/Socket/BufferPool/Types.hs index bf3149fa5..bf7e4b8b3 100644 --- a/recv/Network/Socket/BufferPool/Types.hs +++ b/recv/Network/Socket/BufferPool/Types.hs @@ -12,14 +12,15 @@ type Buffer = Ptr Word8 type BufSize = Int -- | Type for read buffer pool. -data BufferPool = BufferPool { - minBufSize :: Int -- ^ If the buffer is larger than or equal to this size, - -- the buffer is used. - -- Otherwise, a new buffer is allocated. - -- The thrown buffer is eventually freed. - , maxBufSize :: Int - , poolBuffer :: IORef ByteString - } +data BufferPool = BufferPool + { minBufSize :: Int + -- ^ If the buffer is larger than or equal to this size, + -- the buffer is used. + -- Otherwise, a new buffer is allocated. + -- The thrown buffer is eventually freed. + , maxBufSize :: Int + , poolBuffer :: IORef ByteString + } -- | Type for the receiving function with a buffer pool. type Recv = IO ByteString diff --git a/recv/Network/Socket/BufferPool/Windows.hs b/recv/Network/Socket/BufferPool/Windows.hs index a28ce1fff..af6b976b0 100644 --- a/recv/Network/Socket/BufferPool/Windows.hs +++ b/recv/Network/Socket/BufferPool/Windows.hs @@ -1,7 +1,8 @@ {-# LANGUAGE CPP #-} -module Network.Socket.BufferPool.Windows - ( windowsThreadBlockHack - ) where + +module Network.Socket.BufferPool.Windows ( + windowsThreadBlockHack, +) where #ifdef mingw32_HOST_OS import Control.Concurrent.MVar diff --git a/recv/test/BufferPoolSpec.hs b/recv/test/BufferPoolSpec.hs index e0e34e4ec..28e145b85 100644 --- a/recv/test/BufferPoolSpec.hs +++ b/recv/test/BufferPoolSpec.hs @@ -1,20 +1,20 @@ module BufferPoolSpec where import qualified Data.ByteString as B -import qualified Data.ByteString.Internal as B (ByteString(PS)) +import qualified Data.ByteString.Internal as B (ByteString (PS)) import Foreign.ForeignPtr (withForeignPtr) import Foreign.Marshal.Utils (copyBytes) import Foreign.Ptr (plusPtr) import Network.Socket.BufferPool -import Test.Hspec (Spec, hspec, shouldBe, describe, it) +import Test.Hspec (Spec, describe, hspec, it, shouldBe) main :: IO () main = hspec spec -- Two ByteStrings each big enough to fill a buffer (16K). wantData, otherData :: B.ByteString -wantData = B.replicate 16384 0xac +wantData = B.replicate 16384 0xac otherData = B.replicate 16384 0x77 spec :: Spec diff --git a/time-manager/System/TimeManager.hs b/time-manager/System/TimeManager.hs index 016e9402e..dba938f7b 100644 --- a/time-manager/System/TimeManager.hs +++ b/time-manager/System/TimeManager.hs @@ -1,34 +1,38 @@ {-# LANGUAGE DeriveDataTypeable #-} module System.TimeManager ( - -- ** Types - Manager - , TimeoutAction - , Handle - -- ** Manager - , initialize - , stopManager - , killManager - , withManager - , withManager' - -- ** Registration - , register - , registerKillThread - -- ** Control - , tickle - , cancel - , pause - , resume - -- ** Exceptions - , TimeoutThread (..) - ) where + -- ** Types + Manager, + TimeoutAction, + Handle, + + -- ** Manager + initialize, + stopManager, + killManager, + withManager, + withManager', + + -- ** Registration + register, + registerKillThread, + + -- ** Control + tickle, + cancel, + pause, + resume, + + -- ** Exceptions + TimeoutThread (..), +) where import Control.Concurrent (myThreadId) -import qualified UnliftIO.Exception as E import Control.Reaper -import Data.Typeable (Typeable) import Data.IORef (IORef) import qualified Data.IORef as I +import Data.Typeable (Typeable) +import qualified UnliftIO.Exception as E ---------------------------------------------------------------- @@ -41,20 +45,23 @@ type TimeoutAction = IO () -- | A handle used by 'Manager' data Handle = Handle !(IORef TimeoutAction) !(IORef State) -data State = Active -- Manager turns it to Inactive. - | Inactive -- Manager removes it with timeout action. - | Paused -- Manager does not change it. - | Canceled -- Manager removes it without timeout action. +data State + = Active -- Manager turns it to Inactive. + | Inactive -- Manager removes it with timeout action. + | Paused -- Manager does not change it. + | Canceled -- Manager removes it without timeout action. ---------------------------------------------------------------- -- | Creating timeout manager which works every N micro seconds -- where N is the first argument. initialize :: Int -> IO Manager -initialize timeout = mkReaper defaultReaperSettings - { reaperAction = mkListAction prune - , reaperDelay = timeout - } +initialize timeout = + mkReaper + defaultReaperSettings + { reaperAction = mkListAction prune + , reaperDelay = timeout + } where prune m@(Handle actionRef stateRef) = do state <- I.atomicModifyIORef' stateRef (\x -> (inactivate x, x)) @@ -64,7 +71,7 @@ initialize timeout = mkReaper defaultReaperSettings onTimeout `E.catch` ignoreAll return Nothing Canceled -> return Nothing - _ -> return $ Just m + _ -> return $ Just m inactivate Active = Inactive inactivate x = x @@ -92,7 +99,7 @@ killManager = reaperKill register :: Manager -> TimeoutAction -> IO Handle register mgr onTimeout = do actionRef <- I.newIORef onTimeout - stateRef <- I.newIORef Active + stateRef <- I.newIORef Active let h = Handle actionRef stateRef reaperAdd mgr h return h @@ -110,7 +117,7 @@ registerKillThread m onTimeout = do register m $ onTimeout `E.finally` E.throwTo tid TimeoutThread data TimeoutThread = TimeoutThread - deriving Typeable + deriving (Typeable) instance E.Exception TimeoutThread where toException = E.asyncExceptionToException fromException = E.asyncExceptionFromException @@ -145,18 +152,26 @@ resume = tickle -- | Call the inner function with a timeout manager. -- 'stopManager' is used after that. -withManager :: Int -- ^ timeout in microseconds - -> (Manager -> IO a) - -> IO a -withManager timeout f = E.bracket (initialize timeout) - stopManager - f +withManager + :: Int + -- ^ timeout in microseconds + -> (Manager -> IO a) + -> IO a +withManager timeout f = + E.bracket + (initialize timeout) + stopManager + f -- | Call the inner function with a timeout manager. -- 'killManager' is used after that. -withManager' :: Int -- ^ timeout in microseconds - -> (Manager -> IO a) - -> IO a -withManager' timeout f = E.bracket (initialize timeout) - killManager - f +withManager' + :: Int + -- ^ timeout in microseconds + -> (Manager -> IO a) + -> IO a +withManager' timeout f = + E.bracket + (initialize timeout) + killManager + f diff --git a/wai-app-static/Network/Wai/Application/Static.hs b/wai-app-static/Network/Wai/Application/Static.hs index 8e98c9737..86573e460 100644 --- a/wai-app-static/Network/Wai/Application/Static.hs +++ b/wai-app-static/Network/Wai/Application/Static.hs @@ -1,38 +1,42 @@ +{-# LANGUAGE CPP #-} {-# LANGUAGE OverloadedStrings #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TemplateHaskell, CPP #-} -{-# LANGUAGE RecordWildCards #-} {-# LANGUAGE PatternGuards #-} +{-# LANGUAGE RecordWildCards #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TemplateHaskell #-} + -- | Static file serving for WAI. -module Network.Wai.Application.Static - ( -- * WAI application - staticApp - -- ** Default Settings - , defaultWebAppSettings - , webAppSettingsWithLookup - , defaultFileServerSettings - , embeddedSettings - -- ** Settings - , StaticSettings - , ssLookupFile - , ssMkRedirect - , ssGetMimeType - , ssListing - , ssIndices - , ssMaxAge - , ssRedirectToIndex - , ssAddTrailingSlash - , ss404Handler - ) where +module Network.Wai.Application.Static ( + -- * WAI application + staticApp, + + -- ** Default Settings + defaultWebAppSettings, + webAppSettingsWithLookup, + defaultFileServerSettings, + embeddedSettings, + + -- ** Settings + StaticSettings, + ssLookupFile, + ssMkRedirect, + ssGetMimeType, + ssListing, + ssIndices, + ssMaxAge, + ssRedirectToIndex, + ssAddTrailingSlash, + ss404Handler, +) where -import Prelude hiding (FilePath) -import qualified Network.Wai as W -import qualified Network.HTTP.Types as H +import Control.Monad.IO.Class (liftIO) import Data.ByteString (ByteString) import qualified Data.ByteString.Char8 as S8 import qualified Data.ByteString.Lazy as L import Data.ByteString.Lazy.Char8 () -import Control.Monad.IO.Class (liftIO) +import qualified Network.HTTP.Types as H +import qualified Network.Wai as W +import Prelude hiding (FilePath) import Data.ByteString.Builder (toLazyByteString) @@ -42,51 +46,63 @@ import Data.Text (Text) import qualified Data.Text as T import qualified Data.Text.Encoding as TE -import Network.HTTP.Date (parseHTTPDate, epochTimeToHTTPDate, formatHTTPDate) +import Network.HTTP.Date (epochTimeToHTTPDate, formatHTTPDate, parseHTTPDate) -import WaiAppStatic.Types +import Network.Mime (MimeType) import Util -import WaiAppStatic.Storage.Filesystem import WaiAppStatic.Storage.Embedded -import Network.Mime (MimeType) +import WaiAppStatic.Storage.Filesystem +import WaiAppStatic.Types -data StaticResponse = - -- | Just the etag hash or Nothing for no etag hash +data StaticResponse + = -- | Just the etag hash or Nothing for no etag hash Redirect Pieces (Maybe ByteString) | RawRedirect ByteString | NotFound | FileResponse File H.ResponseHeaders | NotModified - -- TODO: add file size - | SendContent MimeType L.ByteString + | -- TODO: add file size + SendContent MimeType L.ByteString | WaiResponse W.Response -safeInit :: [a] -> [a] +safeInit :: [a] -> [a] safeInit [] = [] safeInit xs = init xs filterButLast :: (a -> Bool) -> [a] -> [a] filterButLast _ [] = [] filterButLast _ [x] = [x] -filterButLast f (x:xs) +filterButLast f (x : xs) | f x = x : filterButLast f xs | otherwise = filterButLast f xs -- | Serve an appropriate response for a folder request. -serveFolder :: StaticSettings -> Pieces -> W.Request -> Folder -> IO StaticResponse -serveFolder StaticSettings {..} pieces req folder = +serveFolder + :: StaticSettings -> Pieces -> W.Request -> Folder -> IO StaticResponse +serveFolder StaticSettings{..} pieces req folder = case ssListing of - Just _ | Just path <- addTrailingSlash req, ssAddTrailingSlash -> - return $ RawRedirect path + Just _ + | Just path <- addTrailingSlash req + , ssAddTrailingSlash -> + return $ RawRedirect path Just listing -> do -- directory listings turned on, display it builder <- listing pieces folder - return $ WaiResponse $ W.responseBuilder H.status200 - [ ("Content-Type", "text/html; charset=utf-8") - ] builder - Nothing -> return $ WaiResponse $ W.responseLBS H.status403 - [ ("Content-Type", "text/plain") - ] "Directory listings disabled" + return $ + WaiResponse $ + W.responseBuilder + H.status200 + [ ("Content-Type", "text/html; charset=utf-8") + ] + builder + Nothing -> + return $ + WaiResponse $ + W.responseLBS + H.status403 + [ ("Content-Type", "text/plain") + ] + "Directory listings disabled" addTrailingSlash :: W.Request -> Maybe ByteString addTrailingSlash req @@ -96,16 +112,18 @@ addTrailingSlash req where rp = W.rawPathInfo req -checkPieces :: StaticSettings - -> Pieces -- ^ parsed request - -> W.Request - -> IO StaticResponse +checkPieces + :: StaticSettings + -> Pieces + -- ^ parsed request + -> W.Request + -> IO StaticResponse -- If we have any empty pieces in the middle of the requested path, generate a -- redirect to get rid of them. -checkPieces _ pieces _ | any (T.null . fromPiece) $ safeInit pieces = - return $ Redirect (filterButLast (not . T.null . fromPiece) pieces) Nothing - -checkPieces ss@StaticSettings {..} pieces req = do +checkPieces _ pieces _ + | any (T.null . fromPiece) $ safeInit pieces = + return $ Redirect (filterButLast (not . T.null . fromPiece) pieces) Nothing +checkPieces ss@StaticSettings{..} pieces req = do res <- lookupResult case res of Left location -> return $ RawRedirect location @@ -115,32 +133,35 @@ checkPieces ss@StaticSettings {..} pieces req = do where lookupResult :: IO (Either ByteString LookupResult) lookupResult = do - nonIndexResult <- ssLookupFile pieces - case nonIndexResult of - LRFile{} -> return $ Right nonIndexResult - _ -> do - eIndexResult <- lookupIndices (map (\ index -> dropLastIfNull pieces ++ [index]) ssIndices) - return $ case eIndexResult of - Left redirect -> Left redirect - Right indexResult -> case indexResult of - LRNotFound -> Right nonIndexResult - LRFile file | ssRedirectToIndex -> - let relPath = - case reverse pieces of - -- Served at root - [] -> fromPiece $ fileName file - lastSegment:_ -> - case fromPiece lastSegment of - -- Ends with a trailing slash - "" -> fromPiece $ fileName file - -- Lacks a trailing slash - lastSegment' -> T.concat - [ lastSegment' - , "/" - , fromPiece $ fileName file - ] - in Left $ TE.encodeUtf8 relPath - _ -> Right indexResult + nonIndexResult <- ssLookupFile pieces + case nonIndexResult of + LRFile{} -> return $ Right nonIndexResult + _ -> do + eIndexResult <- + lookupIndices (map (\index -> dropLastIfNull pieces ++ [index]) ssIndices) + return $ case eIndexResult of + Left redirect -> Left redirect + Right indexResult -> case indexResult of + LRNotFound -> Right nonIndexResult + LRFile file + | ssRedirectToIndex -> + let relPath = + case reverse pieces of + -- Served at root + [] -> fromPiece $ fileName file + lastSegment : _ -> + case fromPiece lastSegment of + -- Ends with a trailing slash + "" -> fromPiece $ fileName file + -- Lacks a trailing slash + lastSegment' -> + T.concat + [ lastSegment' + , "/" + , fromPiece $ fileName file + ] + in Left $ TE.encodeUtf8 relPath + _ -> Right indexResult lookupIndices :: [Pieces] -> IO (Either ByteString LookupResult) lookupIndices (x : xs) = do @@ -153,7 +174,7 @@ checkPieces ss@StaticSettings {..} pieces req = do lookupIndices [] = return $ Right LRNotFound serveFile :: StaticSettings -> W.Request -> File -> IO StaticResponse -serveFile StaticSettings {..} req file +serveFile StaticSettings{..} req file -- First check etag values, if turned on | ssUseHash = do mHash <- fileGetHash file @@ -161,7 +182,6 @@ serveFile StaticSettings {..} req file case (mHash, lookup "if-none-match" $ W.requestHeaders req) of -- if-none-match matches the actual hash, return a 304 (Just hash, Just lastHash) | hash == lastHash -> return NotModified - -- Didn't match, but we have a hash value. Send the file contents -- with an ETag header. -- @@ -173,7 +193,6 @@ serveFile StaticSettings {..} req file -- > interoperating with older intermediaries that might not implement -- > If-None-Match. (Just hash, _) -> respond [("ETag", hash)] - -- No hash value available, fall back to last modified support. (Nothing, _) -> lastMod -- etag turned off, so jump straight to last modified @@ -188,10 +207,8 @@ serveFile StaticSettings {..} req file -- Question: should the comparison be, date <= lastSent? (Just mdate, Just lastSent) | mdate == lastSent -> return NotModified - -- Did not match, but we have a new last-modified header (Just mdate, _) -> respond [("last-modified", formatHTTPDate mdate)] - -- No modification time available (Nothing, _) -> respond [] @@ -217,11 +234,11 @@ cacheControl maxage = MaxAgeForever -> (:) ("Cache-Control", maxAgeValue oneYear) NoStore -> (:) ("Cache-Control", "no-store") headerExpires = - case maxage of - NoMaxAge -> id - MaxAgeSeconds _ -> id -- FIXME - MaxAgeForever -> (:) ("Expires", "Thu, 31 Dec 2037 23:55:55 GMT") - NoStore -> id + case maxage of + NoMaxAge -> id + MaxAgeSeconds _ -> id -- FIXME + MaxAgeForever -> (:) ("Expires", "Thu, 31 Dec 2037 23:55:55 GMT") + NoStore -> id -- | Turn a @StaticSettings@ into a WAI application. staticApp :: StaticSettings -> W.Application @@ -229,61 +246,84 @@ staticApp set req = staticAppPieces set (W.pathInfo req) req staticAppPieces :: StaticSettings -> [Text] -> W.Application staticAppPieces _ _ req sendResponse - | notElem (W.requestMethod req) ["GET", "HEAD"] = sendResponse $ W.responseLBS - H.status405 - [("Content-Type", "text/plain")] - "Only GET or HEAD is supported" -staticAppPieces _ [".hidden", "folder.png"] _ sendResponse = sendResponse $ W.responseLBS H.status200 [("Content-Type", "image/png")] $ L.fromChunks [$(makeRelativeToProject "images/folder.png" >>= embedFile)] -staticAppPieces _ [".hidden", "haskell.png"] _ sendResponse = sendResponse $ W.responseLBS H.status200 [("Content-Type", "image/png")] $ L.fromChunks [$(makeRelativeToProject "images/haskell.png" >>= embedFile)] + | notElem (W.requestMethod req) ["GET", "HEAD"] = + sendResponse $ + W.responseLBS + H.status405 + [("Content-Type", "text/plain")] + "Only GET or HEAD is supported" +staticAppPieces _ [".hidden", "folder.png"] _ sendResponse = + sendResponse $ + W.responseLBS H.status200 [("Content-Type", "image/png")] $ + L.fromChunks [$(makeRelativeToProject "images/folder.png" >>= embedFile)] +staticAppPieces _ [".hidden", "haskell.png"] _ sendResponse = + sendResponse $ + W.responseLBS H.status200 [("Content-Type", "image/png")] $ + L.fromChunks [$(makeRelativeToProject "images/haskell.png" >>= embedFile)] staticAppPieces ss rawPieces req sendResponse = liftIO $ do case toPieces rawPieces of Just pieces -> checkPieces ss pieces req >>= response - Nothing -> sendResponse $ W.responseLBS H.status403 - [ ("Content-Type", "text/plain") - ] "Forbidden" + Nothing -> + sendResponse $ + W.responseLBS + H.status403 + [ ("Content-Type", "text/plain") + ] + "Forbidden" where response :: StaticResponse -> IO W.ResponseReceived response (FileResponse file ch) = do mimetype <- ssGetMimeType ss file -- let filesize = fileGetSize file - let headers = ("Content-Type", mimetype) + let headers = + ("Content-Type", mimetype) -- Let Warp provide the content-length, since it takes -- range requests into account -- : ("Content-Length", S8.pack $ show filesize) : ch sendResponse $ fileToResponse file H.status200 headers - response NotModified = - sendResponse $ W.responseLBS H.status304 [] "" - + sendResponse $ W.responseLBS H.status304 [] "" response (SendContent mt lbs) = do - -- TODO: set caching headers - sendResponse $ W.responseLBS H.status200 + -- TODO: set caching headers + sendResponse $ + W.responseLBS + H.status200 [ ("Content-Type", mt) - -- TODO: set Content-Length - ] lbs - + -- TODO: set Content-Length + ] + lbs response (Redirect pieces' mHash) = do - let loc = ssMkRedirect ss pieces' $ L.toStrict $ toLazyByteString (H.encodePathSegments $ map fromPiece pieces') - let qString = case mHash of - Just hash -> replace "etag" (Just hash) (W.queryString req) - Nothing -> remove "etag" (W.queryString req) - - sendResponse $ W.responseLBS H.status302 + let loc = + ssMkRedirect ss pieces' $ + L.toStrict $ + toLazyByteString (H.encodePathSegments $ map fromPiece pieces') + let qString = case mHash of + Just hash -> replace "etag" (Just hash) (W.queryString req) + Nothing -> remove "etag" (W.queryString req) + + sendResponse $ + W.responseLBS + H.status302 [ ("Content-Type", "text/plain") , ("Location", S8.append loc $ H.renderQuery True qString) - ] "Redirect" - + ] + "Redirect" response (RawRedirect path) = - sendResponse $ W.responseLBS H.status302 + sendResponse $ + W.responseLBS + H.status302 [ ("Content-Type", "text/plain") , ("Location", path) - ] "Redirect" - + ] + "Redirect" response NotFound = case ss404Handler ss of Just app -> app req sendResponse - Nothing -> sendResponse $ W.responseLBS H.status404 - [ ("Content-Type", "text/plain") - ] "File not found" - + Nothing -> + sendResponse $ + W.responseLBS + H.status404 + [ ("Content-Type", "text/plain") + ] + "File not found" response (WaiResponse r) = sendResponse r diff --git a/wai-app-static/Util.hs b/wai-app-static/Util.hs index 518f1cbe4..ce3640426 100644 --- a/wai-app-static/Util.hs +++ b/wai-app-static/Util.hs @@ -1,28 +1,32 @@ -{-# LANGUAGE OverloadedStrings, ViewPatterns #-} -module Util - ( relativeDirFromPieces - , defaultMkRedirect - , replace - , remove - , dropLastIfNull - ) where +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE ViewPatterns #-} + +module Util ( + relativeDirFromPieces, + defaultMkRedirect, + replace, + remove, + dropLastIfNull, +) where -import WaiAppStatic.Types -import qualified Data.Text as T import Data.ByteString (ByteString) import qualified Data.ByteString.Char8 as S8 +import qualified Data.Text as T import qualified Data.Text.Encoding as TE +import WaiAppStatic.Types -- alist helper functions replace :: Eq a => a -> b -> [(a, b)] -> [(a, b)] -replace k v [] = [(k,v)] -replace k v (x:xs) | fst x == k = (k,v):xs - | otherwise = x:replace k v xs +replace k v [] = [(k, v)] +replace k v (x : xs) + | fst x == k = (k, v) : xs + | otherwise = x : replace k v xs remove :: Eq a => a -> [(a, b)] -> [(a, b)] remove _ [] = [] -remove k (x:xs) | fst x == k = xs - | otherwise = x:remove k xs +remove k (x : xs) + | fst x == k = xs + | otherwise = x : remove k xs -- | Turn a list of pieces into a relative path to the root folder. relativeDirFromPieces :: Pieces -> T.Text @@ -31,8 +35,10 @@ relativeDirFromPieces pieces = T.concat $ map (const "../") (drop 1 pieces) -- l -- | Construct redirects with relative paths. defaultMkRedirect :: Pieces -> ByteString -> S8.ByteString defaultMkRedirect pieces newPath - | S8.null newPath || S8.null relDir || - S8.last relDir /= '/' || S8.head newPath /= '/' = + | S8.null newPath + || S8.null relDir + || S8.last relDir /= '/' + || S8.head newPath /= '/' = relDir `S8.append` newPath | otherwise = relDir `S8.append` S8.tail newPath where diff --git a/wai-app-static/WaiAppStatic/Listing.hs b/wai-app-static/WaiAppStatic/Listing.hs index 0de3e1a50..2b7ae8e3b 100644 --- a/wai-app-static/WaiAppStatic/Listing.hs +++ b/wai-app-static/WaiAppStatic/Listing.hs @@ -1,16 +1,17 @@ {-# LANGUAGE CPP #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE ViewPatterns #-} -module WaiAppStatic.Listing - ( defaultListing - ) where -import qualified Text.Blaze.Html5.Attributes as A -import qualified Text.Blaze.Html5 as H -import Text.Blaze ((!)) +module WaiAppStatic.Listing ( + defaultListing, +) where + import qualified Data.Text as T import Data.Time import Data.Time.Clock.POSIX +import Text.Blaze ((!)) +import qualified Text.Blaze.Html5 as H +import qualified Text.Blaze.Html5.Attributes as A import WaiAppStatic.Types #if !MIN_VERSION_time(1,5,0) import System.Locale (defaultTimeLocale) @@ -28,33 +29,37 @@ defaultListing pieces (Folder contents) = do let isTop = null pieces || map Just pieces == [toPiece ""] let fps'' :: [Either FolderName File] fps'' = (if isTop then id else (Left (unsafeToPiece "") :)) contents -- FIXME emptyParentFolder feels like a bit of a hack - return $ HU.renderHtmlBuilder - $ H.html $ do - H.head $ do - let title = T.intercalate "/" $ map fromPiece pieces - let title' = if T.null title then "root folder" else title - H.title $ H.toHtml title' - H.style $ H.toHtml $ unlines [ "table { margin: 0 auto; width: 760px; border-collapse: collapse; font-family: 'sans-serif'; }" - , "table, th, td { border: 1px solid #353948; }" - , "td.size { text-align: right; font-size: 0.7em; width: 50px }" - , "td.date { text-align: right; font-size: 0.7em; width: 130px }" - , "td { padding-right: 1em; padding-left: 1em; }" - , "th.first { background-color: white; width: 24px }" - , "td.first { padding-right: 0; padding-left: 0; text-align: center }" - , "tr { background-color: white; }" - , "tr.alt { background-color: #A3B5BA}" - , "th { background-color: #3C4569; color: white; font-size: 1.125em; }" - , "h1 { width: 760px; margin: 1em auto; font-size: 1em; font-family: sans-serif }" - , "img { width: 20px }" - , "a { text-decoration: none }" - ] - H.body $ do - let hasTrailingSlash = - case map fromPiece $ reverse pieces of - "":_ -> True - _ -> False - H.h1 $ showFolder' hasTrailingSlash $ filter (not . T.null . fromPiece) pieces - renderDirectoryContentsTable (map fromPiece pieces) haskellSrc folderSrc fps'' + return $ + HU.renderHtmlBuilder $ + H.html $ do + H.head $ do + let title = T.intercalate "/" $ map fromPiece pieces + let title' = if T.null title then "root folder" else title + H.title $ H.toHtml title' + H.style $ + H.toHtml $ + unlines + [ "table { margin: 0 auto; width: 760px; border-collapse: collapse; font-family: 'sans-serif'; }" + , "table, th, td { border: 1px solid #353948; }" + , "td.size { text-align: right; font-size: 0.7em; width: 50px }" + , "td.date { text-align: right; font-size: 0.7em; width: 130px }" + , "td { padding-right: 1em; padding-left: 1em; }" + , "th.first { background-color: white; width: 24px }" + , "td.first { padding-right: 0; padding-left: 0; text-align: center }" + , "tr { background-color: white; }" + , "tr.alt { background-color: #A3B5BA}" + , "th { background-color: #3C4569; color: white; font-size: 1.125em; }" + , "h1 { width: 760px; margin: 1em auto; font-size: 1em; font-family: sans-serif }" + , "img { width: 20px }" + , "a { text-decoration: none }" + ] + H.body $ do + let hasTrailingSlash = + case map fromPiece $ reverse pieces of + "" : _ -> True + _ -> False + H.h1 $ showFolder' hasTrailingSlash $ filter (not . T.null . fromPiece) pieces + renderDirectoryContentsTable (map fromPiece pieces) haskellSrc folderSrc fps'' where image x = T.unpack $ T.concat [relativeDirFromPieces pieces, ".hidden/", x, ".png"] folderSrc = image "folder" @@ -69,7 +74,7 @@ defaultListing pieces (Folder contents) = do showFolder :: Bool -> Pieces -> H.Html showFolder _ [] = "/" -- won't happen showFolder _ [x] = H.toHtml $ showName $ fromPiece x - showFolder hasTrailingSlash (x:xs) = do + showFolder hasTrailingSlash (x : xs) = do let len = length xs - (if hasTrailingSlash then 0 else 1) href | len == 0 = "." @@ -86,64 +91,72 @@ defaultListing pieces (Folder contents) = do -- a new page template to wrap around this HTML. -- -- see also: 'getMetaData', 'renderDirectoryContents' -renderDirectoryContentsTable :: [T.Text] -- ^ requested path info - -> String - -> String - -> [Either FolderName File] - -> H.Html +renderDirectoryContentsTable + :: [T.Text] + -- ^ requested path info + -> String + -> String + -> [Either FolderName File] + -> H.Html renderDirectoryContentsTable pathInfo' haskellSrc folderSrc fps = - H.table $ do H.thead $ do H.th ! A.class_ "first" $ H.img ! A.src (H.toValue haskellSrc) - H.th "Name" - H.th "Modified" - H.th "Size" - H.tbody $ mapM_ mkRow (zip (sortBy sortMD fps) $ cycle [False, True]) - where - sortMD :: Either FolderName File -> Either FolderName File -> Ordering - sortMD Left{} Right{} = LT - sortMD Right{} Left{} = GT - sortMD (Left a) (Left b) = compare a b - sortMD (Right a) (Right b) = compare (fileName a) (fileName b) + H.table $ do + H.thead $ do + H.th ! A.class_ "first" $ H.img ! A.src (H.toValue haskellSrc) + H.th "Name" + H.th "Modified" + H.th "Size" + H.tbody $ mapM_ mkRow (zip (sortBy sortMD fps) $ cycle [False, True]) + where + sortMD :: Either FolderName File -> Either FolderName File -> Ordering + sortMD Left{} Right{} = LT + sortMD Right{} Left{} = GT + sortMD (Left a) (Left b) = compare a b + sortMD (Right a) (Right b) = compare (fileName a) (fileName b) - mkRow :: (Either FolderName File, Bool) -> H.Html - mkRow (md, alt) = - (if alt then (! A.class_ "alt") else id) $ - H.tr $ do - H.td ! A.class_ "first" - $ case md of - Left{} -> H.img ! A.src (H.toValue folderSrc) - ! A.alt "Folder" - Right{} -> return () - let name = - case either id fileName md of - (fromPiece -> "") -> unsafeToPiece ".." - x -> x - let href = addCurrentDir $ fromPiece name - addCurrentDir x = - case reverse pathInfo' of - "":_ -> x -- has a trailing slash - [] -> x -- at the root - currentDir:_ -> T.concat [currentDir, "/", x] - H.td (H.a ! A.href (H.toValue href) $ H.toHtml $ fromPiece name) - H.td ! A.class_ "date" $ H.toHtml $ - case md of - Right File { fileGetModified = Just t } -> - formatCalendarTime defaultTimeLocale "%d-%b-%Y %X" t - _ -> "" - H.td ! A.class_ "size" $ H.toHtml $ - case md of - Right File { fileGetSize = s } -> prettyShow s - Left{} -> "" - formatCalendarTime a b c = formatTime a b $ posixSecondsToUTCTime (realToFrac c :: POSIXTime) - prettyShow x + mkRow :: (Either FolderName File, Bool) -> H.Html + mkRow (md, alt) = + (if alt then (! A.class_ "alt") else id) $ + H.tr $ do + H.td ! A.class_ "first" $ + case md of + Left{} -> + H.img + ! A.src (H.toValue folderSrc) + ! A.alt "Folder" + Right{} -> return () + let name = + case either id fileName md of + (fromPiece -> "") -> unsafeToPiece ".." + x -> x + let href = addCurrentDir $ fromPiece name + addCurrentDir x = + case reverse pathInfo' of + "" : _ -> x -- has a trailing slash + [] -> x -- at the root + currentDir : _ -> T.concat [currentDir, "/", x] + H.td (H.a ! A.href (H.toValue href) $ H.toHtml $ fromPiece name) + H.td ! A.class_ "date" $ + H.toHtml $ + case md of + Right File{fileGetModified = Just t} -> + formatCalendarTime defaultTimeLocale "%d-%b-%Y %X" t + _ -> "" + H.td ! A.class_ "size" $ + H.toHtml $ + case md of + Right File{fileGetSize = s} -> prettyShow s + Left{} -> "" + formatCalendarTime a b c = formatTime a b $ posixSecondsToUTCTime (realToFrac c :: POSIXTime) + prettyShow x | x > 1024 = prettyShowK $ x `div` 1024 | otherwise = addCommas "B" x - prettyShowK x + prettyShowK x | x > 1024 = prettyShowM $ x `div` 1024 | otherwise = addCommas "KB" x - prettyShowM x + prettyShowM x | x > 1024 = prettyShowG $ x `div` 1024 | otherwise = addCommas "MB" x - prettyShowG x = addCommas "GB" x - addCommas s = (++ (' ' : s)) . reverse . addCommas' . reverse . show - addCommas' (a:b:c:d:e) = a : b : c : ',' : addCommas' (d : e) - addCommas' x = x + prettyShowG x = addCommas "GB" x + addCommas s = (++ (' ' : s)) . reverse . addCommas' . reverse . show + addCommas' (a : b : c : d : e) = a : b : c : ',' : addCommas' (d : e) + addCommas' x = x diff --git a/wai-app-static/WaiAppStatic/Storage/Embedded.hs b/wai-app-static/WaiAppStatic/Storage/Embedded.hs index daa6e5048..46a686529 100644 --- a/wai-app-static/WaiAppStatic/Storage/Embedded.hs +++ b/wai-app-static/WaiAppStatic/Storage/Embedded.hs @@ -1,12 +1,12 @@ -module WaiAppStatic.Storage.Embedded( +module WaiAppStatic.Storage.Embedded ( -- * Basic - embeddedSettings + embeddedSettings, -- * Template Haskell - , Etag - , EmbeddableEntry(..) - , mkSettings - ) where + Etag, + EmbeddableEntry (..), + mkSettings, +) where import WaiAppStatic.Storage.Embedded.Runtime import WaiAppStatic.Storage.Embedded.TH diff --git a/wai-app-static/WaiAppStatic/Storage/Embedded/Runtime.hs b/wai-app-static/WaiAppStatic/Storage/Embedded/Runtime.hs index ac9abf971..1fc4194b0 100644 --- a/wai-app-static/WaiAppStatic/Storage/Embedded/Runtime.hs +++ b/wai-app-static/WaiAppStatic/Storage/Embedded/Runtime.hs @@ -1,21 +1,22 @@ {-# LANGUAGE CPP #-} + -- | Lookup files stored in memory instead of from the filesystem. -module WaiAppStatic.Storage.Embedded.Runtime - ( -- * Settings - embeddedSettings - ) where +module WaiAppStatic.Storage.Embedded.Runtime ( + -- * Settings + embeddedSettings, +) where -import WaiAppStatic.Types +import Control.Arrow (second, (&&&)) import Data.ByteString (ByteString) -import Control.Arrow ((&&&), second) -import Data.List (groupBy, sortBy) +import qualified Data.ByteString as S import Data.ByteString.Builder (byteString) -import qualified Network.Wai as W -import qualified Data.Map as Map import Data.Function (on) -import qualified Data.Text as T +import Data.List (groupBy, sortBy) +import qualified Data.Map as Map import Data.Ord -import qualified Data.ByteString as S +import qualified Data.Text as T +import qualified Network.Wai as W +import WaiAppStatic.Types #ifdef MIN_VERSION_crypton import Crypto.Hash (hash, MD5, Digest) import Data.ByteArray.Encoding @@ -23,14 +24,15 @@ import Data.ByteArray.Encoding import Crypto.Hash.MD5 (hash) import Data.ByteString.Base64 (encode) #endif -import WaiAppStatic.Storage.Filesystem (defaultFileServerSettings) import System.FilePath (isPathSeparator) +import WaiAppStatic.Storage.Filesystem (defaultFileServerSettings) -- | Serve the list of path/content pairs directly from memory. embeddedSettings :: [(Prelude.FilePath, ByteString)] -> StaticSettings -embeddedSettings files = (defaultFileServerSettings $ error "unused") - { ssLookupFile = embeddedLookup $ toEmbedded files - } +embeddedSettings files = + (defaultFileServerSettings $ error "unused") + { ssLookupFile = embeddedLookup $ toEmbedded files + } type Embedded = Map.Map Piece EmbeddedEntry @@ -40,10 +42,10 @@ embeddedLookup :: Embedded -> Pieces -> IO LookupResult embeddedLookup root pieces = return $ elookup pieces root where - elookup :: Pieces -> Embedded -> LookupResult + elookup :: Pieces -> Embedded -> LookupResult elookup [] x = LRFolder $ Folder $ fmap toEntry $ Map.toList x elookup [p] x | T.null (fromPiece p) = elookup [] x - elookup (p:ps) x = + elookup (p : ps) x = case Map.lookup p x of Nothing -> LRNotFound Just (EEFile f) -> @@ -54,13 +56,15 @@ embeddedLookup root pieces = toEntry :: (Piece, EmbeddedEntry) -> Either FolderName File toEntry (name, EEFolder{}) = Left name -toEntry (name, EEFile bs) = Right File - { fileGetSize = fromIntegral $ S.length bs - , fileToResponse = \s h -> W.responseBuilder s h $ byteString bs - , fileName = name - , fileGetHash = return $ Just $ runHash bs - , fileGetModified = Nothing - } +toEntry (name, EEFile bs) = + Right + File + { fileGetSize = fromIntegral $ S.length bs + , fileToResponse = \s h -> W.responseBuilder s h $ byteString bs + , fileName = name + , fileGetHash = return $ Just $ runHash bs + , fileGetModified = Nothing + } toEmbedded :: [(Prelude.FilePath, ByteString)] -> Embedded toEmbedded fps = @@ -91,13 +95,14 @@ toEmbedded fps = go' x = EEFolder $ go $ filter (\y -> not $ null $ fst y) x bsToFile :: Piece -> ByteString -> File -bsToFile name bs = File - { fileGetSize = fromIntegral $ S.length bs - , fileToResponse = \s h -> W.responseBuilder s h $ byteString bs - , fileName = name - , fileGetHash = return $ Just $ runHash bs - , fileGetModified = Nothing - } +bsToFile name bs = + File + { fileGetSize = fromIntegral $ S.length bs + , fileToResponse = \s h -> W.responseBuilder s h $ byteString bs + , fileName = name + , fileGetHash = return $ Just $ runHash bs + , fileGetModified = Nothing + } runHash :: ByteString -> ByteString #ifdef MIN_VERSION_crypton diff --git a/wai-app-static/WaiAppStatic/Storage/Embedded/TH.hs b/wai-app-static/WaiAppStatic/Storage/Embedded/TH.hs index 3aaa22d05..7fe7ea2a7 100644 --- a/wai-app-static/WaiAppStatic/Storage/Embedded/TH.hs +++ b/wai-app-static/WaiAppStatic/Storage/Embedded/TH.hs @@ -1,22 +1,27 @@ -{-# LANGUAGE TemplateHaskell, QuasiQuotes, OverloadedStrings, MagicHash, CPP #-} -module WaiAppStatic.Storage.Embedded.TH( - Etag - , EmbeddableEntry(..) - , mkSettings +{-# LANGUAGE CPP #-} +{-# LANGUAGE MagicHash #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE QuasiQuotes #-} +{-# LANGUAGE TemplateHaskell #-} + +module WaiAppStatic.Storage.Embedded.TH ( + Etag, + EmbeddableEntry (..), + mkSettings, ) where -import Data.ByteString.Builder.Extra (byteStringInsert) import Codec.Compression.GZip (compress) +import qualified Data.ByteString as B +import Data.ByteString.Builder.Extra (byteStringInsert) +import qualified Data.ByteString.Lazy as BL import Data.ByteString.Unsafe (unsafePackAddressLen) import Data.Either (lefts, rights) -import GHC.Exts (Int(..)) +import GHC.Exts (Int (..)) import Language.Haskell.TH import Network.Mime (MimeType, defaultMimeLookup) import System.IO.Unsafe (unsafeDupablePerformIO) -import WaiAppStatic.Types import WaiAppStatic.Storage.Filesystem (defaultWebAppSettings) -import qualified Data.ByteString as B -import qualified Data.ByteString.Lazy as BL +import WaiAppStatic.Types #if !MIN_VERSION_template_haskell(2, 8, 0) import qualified Data.ByteString.Char8 as B8 import qualified Data.ByteString.Lazy.Char8 as BL8 @@ -33,35 +38,37 @@ import qualified Network.Wai as W type Etag = T.Text -- | Used at compile time to hold data about an entry to embed into the compiled executable. -data EmbeddableEntry = EmbeddableEntry { - eLocation :: T.Text -- ^ The location where this resource should be served from. The - -- location can contain forward slashes (/) to simulate directories, - -- but must not end with a forward slash. - , eMimeType :: MimeType -- ^ The mime type. - , eContent :: Either (Etag, BL.ByteString) ExpQ - -- ^ The content itself. The content can be given as a tag and bytestring, - -- in which case the content will be embedded directly into the execuatble. - -- Alternatively, the content can be given as a template haskell expression - -- returning @IO ('Etag', 'BL.ByteString')@ in which case this action will - -- be executed on every request to reload the content (this is useful - -- for a debugging mode). -} +data EmbeddableEntry = EmbeddableEntry + { eLocation :: T.Text + -- ^ The location where this resource should be served from. The + -- location can contain forward slashes (/) to simulate directories, + -- but must not end with a forward slash. + , eMimeType :: MimeType + -- ^ The mime type. + , eContent :: Either (Etag, BL.ByteString) ExpQ + -- ^ The content itself. The content can be given as a tag and bytestring, + -- in which case the content will be embedded directly into the execuatble. + -- Alternatively, the content can be given as a template haskell expression + -- returning @IO ('Etag', 'BL.ByteString')@ in which case this action will + -- be executed on every request to reload the content (this is useful + -- for a debugging mode). + } -- | This structure is used at runtime to hold the entry. -data EmbeddedEntry = EmbeddedEntry { - embLocation :: !T.Text - , embMime :: !MimeType - , embEtag :: !B.ByteString - , embCompressed :: !Bool - , embContent :: !B.ByteString -} +data EmbeddedEntry = EmbeddedEntry + { embLocation :: !T.Text + , embMime :: !MimeType + , embEtag :: !B.ByteString + , embCompressed :: !Bool + , embContent :: !B.ByteString + } -- | This structure is used at runtime to hold the reload entries. -data ReloadEntry = ReloadEntry { - reloadLocation :: !T.Text - , reloadMime :: !MimeType - , reloadContent :: IO (T.Text, BL.ByteString) -} +data ReloadEntry = ReloadEntry + { reloadLocation :: !T.Text + , reloadMime :: !MimeType + , reloadContent :: IO (T.Text, BL.ByteString) + } -- The use of unsafePackAddressLen is safe here because the length -- is correct and we will only be reading from the bytestring, never @@ -102,90 +109,100 @@ bytestringLazyE b = -- | A template haskell expression which creates either an EmbeddedEntry or ReloadEntry. mkEntry :: EmbeddableEntry -> ExpQ mkEntry (EmbeddableEntry loc mime (Left (etag, ct))) = - [| Left $ EmbeddedEntry (T.pack $locE) - $(bytestringE mime) - $(bytestringE $ T.encodeUtf8 etag) - (1 == I# $compressedE) - $(bytestringLazyE ct') + [| + Left $ + EmbeddedEntry + (T.pack $locE) + $(bytestringE mime) + $(bytestringE $ T.encodeUtf8 etag) + (1 == I# $compressedE) + $(bytestringLazyE ct') |] - where - locE = litE $ stringL $ T.unpack loc - (compressed, ct') = tryCompress mime ct - compressedE = litE $ intPrimL $ if compressed then 1 else 0 - + where + locE = litE $ stringL $ T.unpack loc + (compressed, ct') = tryCompress mime ct + compressedE = litE $ intPrimL $ if compressed then 1 else 0 mkEntry (EmbeddableEntry loc mime (Right expr)) = - [| Right $ ReloadEntry (T.pack $locE) - $(bytestringE mime) - $expr + [| + Right $ + ReloadEntry + (T.pack $locE) + $(bytestringE mime) + $expr |] - where - locE = litE $ stringL $ T.unpack loc + where + locE = litE $ stringL $ T.unpack loc -- | Converts an embedded entry to a file embeddedToFile :: EmbeddedEntry -> File -embeddedToFile entry = File - { fileGetSize = fromIntegral $ B.length $ embContent entry - , fileToResponse = \s h -> - let h' = if embCompressed entry - then h ++ [("Content-Encoding", "gzip")] - else h - in W.responseBuilder s h' $ byteStringInsert $ embContent entry - - -- Usually the fileName should just be the filename not the entire path, - -- but we need the whole path to make the lookup within lookupMime - -- possible. lookupMime is provided only with the File and from that - -- we must find the mime type. Putting the path here is OK since - -- within staticApp the fileName is used for directory listings which - -- we have disabled. - , fileName = unsafeToPiece $ embLocation entry - , fileGetHash = return $ if B.null (embEtag entry) - then Nothing - else Just $ embEtag entry - , fileGetModified = Nothing - } +embeddedToFile entry = + File + { fileGetSize = fromIntegral $ B.length $ embContent entry + , fileToResponse = \s h -> + let h' = + if embCompressed entry + then h ++ [("Content-Encoding", "gzip")] + else h + in W.responseBuilder s h' $ byteStringInsert $ embContent entry + , -- Usually the fileName should just be the filename not the entire path, + -- but we need the whole path to make the lookup within lookupMime + -- possible. lookupMime is provided only with the File and from that + -- we must find the mime type. Putting the path here is OK since + -- within staticApp the fileName is used for directory listings which + -- we have disabled. + fileName = unsafeToPiece $ embLocation entry + , fileGetHash = + return $ + if B.null (embEtag entry) + then Nothing + else Just $ embEtag entry + , fileGetModified = Nothing + } -- | Converts a reload entry to a file reloadToFile :: ReloadEntry -> IO File reloadToFile entry = do (etag, ct) <- reloadContent entry let etag' = T.encodeUtf8 etag - return $ File - { fileGetSize = fromIntegral $ BL.length ct - , fileToResponse = \s h -> W.responseLBS s h ct - -- Similar to above the entire path needs to be in the fileName. - , fileName = unsafeToPiece $ reloadLocation entry - , fileGetHash = return $ if T.null etag then Nothing else Just etag' - , fileGetModified = Nothing - } - + return $ + File + { fileGetSize = fromIntegral $ BL.length ct + , fileToResponse = \s h -> W.responseLBS s h ct + , -- Similar to above the entire path needs to be in the fileName. + fileName = unsafeToPiece $ reloadLocation entry + , fileGetHash = return $ if T.null etag then Nothing else Just etag' + , fileGetModified = Nothing + } -- | Build a static settings based on a filemap. filemapToSettings :: M.HashMap T.Text (MimeType, IO File) -> StaticSettings -filemapToSettings mfiles = (defaultWebAppSettings "") - { ssLookupFile = lookupFile - , ssGetMimeType = lookupMime - } - where - piecesToFile p = T.intercalate "/" $ map fromPiece p +filemapToSettings mfiles = + (defaultWebAppSettings "") + { ssLookupFile = lookupFile + , ssGetMimeType = lookupMime + } + where + piecesToFile p = T.intercalate "/" $ map fromPiece p - lookupFile [] = return LRNotFound - lookupFile p = - case M.lookup (piecesToFile p) mfiles of - Nothing -> return LRNotFound - Just (_,act) -> LRFile <$> act + lookupFile [] = return LRNotFound + lookupFile p = + case M.lookup (piecesToFile p) mfiles of + Nothing -> return LRNotFound + Just (_, act) -> LRFile <$> act - lookupMime (File { fileName = p }) = - case M.lookup (fromPiece p) mfiles of - Just (mime,_) -> return mime - Nothing -> return $ defaultMimeLookup $ fromPiece p + lookupMime (File{fileName = p}) = + case M.lookup (fromPiece p) mfiles of + Just (mime, _) -> return mime + Nothing -> return $ defaultMimeLookup $ fromPiece p -- | Create a 'StaticSettings' from a list of entries. Executed at run time. entriesToSt :: [Either EmbeddedEntry ReloadEntry] -> StaticSettings entriesToSt entries = hmap `seq` filemapToSettings hmap - where - embFiles = [ (embLocation e, (embMime e, return $ embeddedToFile e)) | e <- lefts entries] - reloadFiles = [ (reloadLocation r, (reloadMime r, reloadToFile r)) | r <- rights entries] - hmap = M.fromList $ embFiles ++ reloadFiles + where + embFiles = + [(embLocation e, (embMime e, return $ embeddedToFile e)) | e <- lefts entries] + reloadFiles = [(reloadLocation r, (reloadMime r, reloadToFile r)) | r <- rights entries] + hmap = M.fromList $ embFiles ++ reloadFiles -- | Create a 'StaticSettings' at compile time that embeds resources directly into the compiled -- executable. The embedded resources are precompressed (depending on mime type) @@ -197,17 +214,17 @@ entriesToSt entries = hmap `seq` filemapToSettings hmap -- -- > {-# LANGUAGE TemplateHaskell, QuasiQuotes, OverloadedStrings #-} -- > module A (mkEmbedded) where --- > +-- > -- > import WaiAppStatic.Storage.Embedded -- > import Crypto.Hash.MD5 (hashlazy) -- > import qualified Data.ByteString.Lazy as BL -- > import qualified Data.ByteString.Base64 as B64 -- > import qualified Data.Text as T -- > import qualified Data.Text.Encoding as T --- > +-- > -- > hash :: BL.ByteString -> T.Text -- > hash = T.take 8 . T.decodeUtf8 . B64.encode . hashlazy --- > +-- > -- > mkEmbedded :: IO [EmbeddableEntry] -- > mkEmbedded = do -- > file <- BL.readFile "test.css" @@ -216,13 +233,13 @@ entriesToSt entries = hmap `seq` filemapToSettings hmap -- > , eMimeType = "text/css" -- > , eContent = Left (hash file, file) -- > } --- > +-- > -- > let reload = EmbeddableEntry { -- > eLocation = "anotherdir/test2.txt" -- > , eMimeType = "text/plain" -- > , eContent = Right [| BL.readFile "test2.txt" >>= \c -> return (hash c, c) |] -- > } --- > +-- > -- > return [emb, reload] -- -- The above @mkEmbedded@ will be executed at compile time. It loads the contents of test.css and @@ -241,37 +258,38 @@ entriesToSt entries = hmap `seq` filemapToSettings hmap -- -- > {-# LANGUAGE TemplateHaskell #-} -- > module B where --- > +-- > -- > import A -- > import Network.Wai (Application) -- > import Network.Wai.Application.Static (staticApp) -- > import WaiAppStatic.Storage.Embedded -- > import Network.Wai.Handler.Warp (run) --- > +-- > -- > myApp :: Application -- > myApp = staticApp $(mkSettings mkEmbedded) --- > +-- > -- > main :: IO () -- > main = run 3000 myApp mkSettings :: IO [EmbeddableEntry] -> ExpQ mkSettings action = do entries <- runIO action - [| entriesToSt $(listE $ map mkEntry entries) |] + [|entriesToSt $(listE $ map mkEntry entries)|] shouldCompress :: MimeType -> Bool shouldCompress m = "text/" `B.isPrefixOf` m || m `elem` extra - where - extra = [ "application/json" - , "application/javascript" - , "application/ecmascript" - ] + where + extra = + [ "application/json" + , "application/javascript" + , "application/ecmascript" + ] -- | Only compress if the mime type is correct and the compressed text is actually shorter. tryCompress :: MimeType -> BL.ByteString -> (Bool, BL.ByteString) tryCompress mime ct - | shouldCompress mime = (c, ct') - | otherwise = (False, ct) - where - compressed = compress ct - c = BL.length compressed < BL.length ct - ct' = if c then compressed else ct + | shouldCompress mime = (c, ct') + | otherwise = (False, ct) + where + compressed = compress ct + c = BL.length compressed < BL.length ct + ct' = if c then compressed else ct diff --git a/wai-app-static/WaiAppStatic/Storage/Filesystem.hs b/wai-app-static/WaiAppStatic/Storage/Filesystem.hs index 7a14f650c..81df82ec1 100644 --- a/wai-app-static/WaiAppStatic/Storage/Filesystem.hs +++ b/wai-app-static/WaiAppStatic/Storage/Filesystem.hs @@ -1,31 +1,42 @@ +{-# LANGUAGE BangPatterns #-} {-# LANGUAGE CPP #-} -{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE OverloadedStrings #-} -{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE ScopedTypeVariables #-} + -- | Access files on the filesystem. -module WaiAppStatic.Storage.Filesystem - ( -- * Types - ETagLookup - -- * Settings - , defaultWebAppSettings - , defaultFileServerSettings - , webAppSettingsWithLookup - ) where +module WaiAppStatic.Storage.Filesystem ( + -- * Types + ETagLookup, -import WaiAppStatic.Types -import System.FilePath (()) -import System.IO (withBinaryFile, IOMode(..)) -import System.Directory (doesFileExist, doesDirectoryExist, getDirectoryContents) -import Data.List (foldl') + -- * Settings + defaultWebAppSettings, + defaultFileServerSettings, + webAppSettingsWithLookup, +) where + +import Control.Exception (SomeException, try) import Control.Monad (forM) -import Util import Data.ByteString (ByteString) -import Control.Exception (SomeException, try) +import Data.List (foldl') +import Data.Maybe (catMaybes) +import Network.Mime import qualified Network.Wai as W +import System.Directory ( + doesDirectoryExist, + doesFileExist, + getDirectoryContents, + ) +import System.FilePath (()) +import System.IO (IOMode (..), withBinaryFile) +import System.PosixCompat.Files ( + fileSize, + getFileStatus, + isRegularFile, + modificationTime, + ) +import Util import WaiAppStatic.Listing -import Network.Mime -import System.PosixCompat.Files (fileSize, getFileStatus, modificationTime, isRegularFile) -import Data.Maybe (catMaybes) +import WaiAppStatic.Types #ifdef MIN_VERSION_crypton import Data.ByteArray.Encoding import Crypto.Hash (hashlazy, MD5, Digest) @@ -42,70 +53,88 @@ pathFromPieces = foldl' (\fp p -> fp T.unpack (fromPiece p)) -- | Settings optimized for a web application. Files will have aggressive -- caching applied and hashes calculated, and indices and listings are disabled. -defaultWebAppSettings :: FilePath -- ^ root folder to serve from - -> StaticSettings -defaultWebAppSettings root = StaticSettings - { ssLookupFile = webAppLookup hashFileIfExists root - , ssMkRedirect = defaultMkRedirect - , ssGetMimeType = return . defaultMimeLookup . fromPiece . fileName - , ssMaxAge = MaxAgeForever - , ssListing = Nothing - , ssIndices = [] - , ssRedirectToIndex = False - , ssUseHash = True - , ssAddTrailingSlash = False - , ss404Handler = Nothing - } +defaultWebAppSettings + :: FilePath + -- ^ root folder to serve from + -> StaticSettings +defaultWebAppSettings root = + StaticSettings + { ssLookupFile = webAppLookup hashFileIfExists root + , ssMkRedirect = defaultMkRedirect + , ssGetMimeType = return . defaultMimeLookup . fromPiece . fileName + , ssMaxAge = MaxAgeForever + , ssListing = Nothing + , ssIndices = [] + , ssRedirectToIndex = False + , ssUseHash = True + , ssAddTrailingSlash = False + , ss404Handler = Nothing + } -- | Settings optimized for a file server. More conservative caching will be -- applied, and indices and listings are enabled. -defaultFileServerSettings :: FilePath -- ^ root folder to serve from - -> StaticSettings -defaultFileServerSettings root = StaticSettings - { ssLookupFile = fileSystemLookup (fmap Just . hashFile) root - , ssMkRedirect = defaultMkRedirect - , ssGetMimeType = return . defaultMimeLookup . fromPiece . fileName - , ssMaxAge = NoMaxAge - , ssListing = Just defaultListing - , ssIndices = map unsafeToPiece ["index.html", "index.htm"] - , ssRedirectToIndex = False - , ssUseHash = False - , ssAddTrailingSlash = False - , ss404Handler = Nothing - } +defaultFileServerSettings + :: FilePath + -- ^ root folder to serve from + -> StaticSettings +defaultFileServerSettings root = + StaticSettings + { ssLookupFile = fileSystemLookup (fmap Just . hashFile) root + , ssMkRedirect = defaultMkRedirect + , ssGetMimeType = return . defaultMimeLookup . fromPiece . fileName + , ssMaxAge = NoMaxAge + , ssListing = Just defaultListing + , ssIndices = map unsafeToPiece ["index.html", "index.htm"] + , ssRedirectToIndex = False + , ssUseHash = False + , ssAddTrailingSlash = False + , ss404Handler = Nothing + } -- | Same as @defaultWebAppSettings@, but additionally uses a specialized -- @ETagLookup@ in place of the standard one. This can allow you to cache your -- hash values, or even precompute them. -webAppSettingsWithLookup :: FilePath -- ^ root folder to serve from - -> ETagLookup - -> StaticSettings +webAppSettingsWithLookup + :: FilePath + -- ^ root folder to serve from + -> ETagLookup + -> StaticSettings webAppSettingsWithLookup dir etagLookup = - (defaultWebAppSettings dir) { ssLookupFile = webAppLookup etagLookup dir} + (defaultWebAppSettings dir){ssLookupFile = webAppLookup etagLookup dir} -- | Convenience wrapper for @fileHelper@. -fileHelperLR :: ETagLookup - -> FilePath -- ^ file location - -> Piece -- ^ file name - -> IO LookupResult +fileHelperLR + :: ETagLookup + -> FilePath + -- ^ file location + -> Piece + -- ^ file name + -> IO LookupResult fileHelperLR a b c = fmap (maybe LRNotFound LRFile) $ fileHelper a b c -- | Attempt to load up a @File@ from the given path. -fileHelper :: ETagLookup - -> FilePath -- ^ file location - -> Piece -- ^ file name - -> IO (Maybe File) +fileHelper + :: ETagLookup + -> FilePath + -- ^ file location + -> Piece + -- ^ file name + -> IO (Maybe File) fileHelper hashFunc fp name = do efs <- try $ getFileStatus fp case efs of Left (_ :: SomeException) -> return Nothing - Right fs | isRegularFile fs -> return $ Just File - { fileGetSize = fromIntegral $ fileSize fs - , fileToResponse = \s h -> W.responseFile s h fp Nothing - , fileName = name - , fileGetHash = hashFunc fp - , fileGetModified = Just $ modificationTime fs - } + Right fs + | isRegularFile fs -> + return $ + Just + File + { fileGetSize = fromIntegral $ fileSize fs + , fileToResponse = \s h -> W.responseFile s h fp Nothing + , fileName = name + , fileGetHash = hashFunc fp + , fileGetModified = Just $ modificationTime fs + } Right _ -> return Nothing -- | How to calculate etags. Can perform filesystem reads on each call, or use @@ -144,14 +173,17 @@ hashFileIfExists fp = do Right x -> Just x isVisible :: FilePath -> Bool -isVisible ('.':_) = False +isVisible ('.' : _) = False isVisible "" = False isVisible _ = True -- | Get a proper @LookupResult@, checking if the path is a file or folder. -- Compare with @webAppLookup@, which only deals with files. -fileSystemLookup :: ETagLookup - -> FilePath -> Pieces -> IO LookupResult +fileSystemLookup + :: ETagLookup + -> FilePath + -> Pieces + -> IO LookupResult fileSystemLookup hashFunc prefix pieces = do let fp = pathFromPieces prefix pieces fe <- doesFileExist fp diff --git a/wai-app-static/WaiAppStatic/Types.hs b/wai-app-static/WaiAppStatic/Types.hs index 4044d6c4c..bfbed6088 100644 --- a/wai-app-static/WaiAppStatic/Types.hs +++ b/wai-app-static/WaiAppStatic/Types.hs @@ -1,31 +1,34 @@ -module WaiAppStatic.Types - ( -- * Pieces - Piece - , toPiece - , fromPiece - , unsafeToPiece - , Pieces - , toPieces - -- * Caching - , MaxAge (..) - -- * File\/folder serving - , FolderName - , Folder (..) - , File (..) - , LookupResult (..) - , Listing - -- * Settings - , StaticSettings (..) - ) where +module WaiAppStatic.Types ( + -- * Pieces + Piece, + toPiece, + fromPiece, + unsafeToPiece, + Pieces, + toPieces, + + -- * Caching + MaxAge (..), + + -- * File\/folder serving + FolderName, + Folder (..), + File (..), + LookupResult (..), + Listing, + + -- * Settings + StaticSettings (..), +) where +import Data.ByteString (ByteString) +import Data.ByteString.Builder (Builder) import Data.Text (Text) +import qualified Data.Text as T import qualified Network.HTTP.Types as H +import Network.Mime (MimeType) import qualified Network.Wai as W -import Data.ByteString (ByteString) import System.Posix.Types (EpochTime) -import qualified Data.Text as T -import Data.ByteString.Builder (Builder) -import Network.Mime (MimeType) -- | An individual component of a path, or of a filepath. -- @@ -35,7 +38,7 @@ import Network.Mime (MimeType) -- -- Individual file lookup backends must know how to convert from a @Piece@ to -- their storage system. -newtype Piece = Piece { fromPiece :: Text } +newtype Piece = Piece {fromPiece :: Text} deriving (Show, Eq, Ord) -- | Smart constructor for a @Piece@. Won\'t allow unsafe components, such as @@ -64,10 +67,15 @@ toPieces = mapM toPiece type Pieces = [Piece] -- | Values for the max-age component of the cache-control response header. -data MaxAge = NoMaxAge -- ^ no cache-control set - | MaxAgeSeconds Int -- ^ set to the given number of seconds - | MaxAgeForever -- ^ essentially infinite caching; in reality, probably one year - | NoStore -- ^ set cache-control to no-store @since 3.1.8 +data MaxAge + = -- | no cache-control set + NoMaxAge + | -- | set to the given number of seconds + MaxAgeSeconds Int + | -- | essentially infinite caching; in reality, probably one year + MaxAgeForever + | -- | set cache-control to no-store @since 3.1.8 + NoStore -- | Just the name of a folder. type FolderName = Piece @@ -80,26 +88,27 @@ data Folder = Folder -- | Information on an individual file. data File = File - { -- | Size of file in bytes - fileGetSize :: Integer - -- | How to construct a WAI response for this file. Some files are stored - -- on the filesystem and can use @ResponseFile@, while others are stored - -- in memory and should use @ResponseBuilder@. + { fileGetSize :: Integer + -- ^ Size of file in bytes , fileToResponse :: H.Status -> H.ResponseHeaders -> W.Response - -- | Last component of the filename. + -- ^ How to construct a WAI response for this file. Some files are stored + -- on the filesystem and can use @ResponseFile@, while others are stored + -- in memory and should use @ResponseBuilder@. , fileName :: Piece - -- | Calculate a hash of the contents of this file, such as for etag. + -- ^ Last component of the filename. , fileGetHash :: IO (Maybe ByteString) - -- | Last modified time, used for both display in listings and if-modified-since. + -- ^ Calculate a hash of the contents of this file, such as for etag. , fileGetModified :: Maybe EpochTime + -- ^ Last modified time, used for both display in listings and if-modified-since. } -- | Result of looking up a file in some storage backend. -- -- The lookup is either a file or folder, or does not exist. -data LookupResult = LRFile File - | LRFolder Folder - | LRNotFound +data LookupResult + = LRFile File + | LRFolder Folder + | LRNotFound -- | How to construct a directory listing page for the given request path and -- the resulting folder. @@ -110,45 +119,35 @@ type Listing = Pieces -> Folder -> IO Builder -- Note that you should use the settings type approach for modifying values. -- See for more information. data StaticSettings = StaticSettings - { - -- | Lookup a single file or folder. This is how you can control storage - -- backend (filesystem, embedded, etc) and where to lookup. - ssLookupFile :: Pieces -> IO LookupResult - - -- | Determine the mime type of the given file. Note that this function - -- lives in @IO@ in case you want to perform more complicated mimetype - -- analysis, such as via the @file@ utility. + { ssLookupFile :: Pieces -> IO LookupResult + -- ^ Lookup a single file or folder. This is how you can control storage + -- backend (filesystem, embedded, etc) and where to lookup. , ssGetMimeType :: File -> IO MimeType - - -- | Ordered list of filenames to be used for indices. If the user - -- requests a folder, and a file with the given name is found in that - -- folder, that file is served. This supercedes any directory listing. + -- ^ Determine the mime type of the given file. Note that this function + -- lives in @IO@ in case you want to perform more complicated mimetype + -- analysis, such as via the @file@ utility. , ssIndices :: [Piece] - - -- | How to perform a directory listing. Optional. Will be used when the - -- user requested a folder. + -- ^ Ordered list of filenames to be used for indices. If the user + -- requests a folder, and a file with the given name is found in that + -- folder, that file is served. This supercedes any directory listing. , ssListing :: Maybe Listing - - -- | Value to provide for max age in the cache-control. + -- ^ How to perform a directory listing. Optional. Will be used when the + -- user requested a folder. , ssMaxAge :: MaxAge - - -- | Given a requested path and a new destination, construct a string - -- that will go there. Default implementation will use relative paths. + -- ^ Value to provide for max age in the cache-control. , ssMkRedirect :: Pieces -> ByteString -> ByteString - - -- | If @True@, send a redirect to the user when a folder is requested - -- and an index page should be displayed. When @False@, display the - -- content immediately. + -- ^ Given a requested path and a new destination, construct a string + -- that will go there. Default implementation will use relative paths. , ssRedirectToIndex :: Bool - - -- | Prefer usage of etag caching to last-modified caching. + -- ^ If @True@, send a redirect to the user when a folder is requested + -- and an index page should be displayed. When @False@, display the + -- content immediately. , ssUseHash :: Bool - - -- | Force a trailing slash at the end of directories + -- ^ Prefer usage of etag caching to last-modified caching. , ssAddTrailingSlash :: Bool - - -- | Optional `W.Application` to be used in case of 404 errors - -- - -- Since 3.1.3 + -- ^ Force a trailing slash at the end of directories , ss404Handler :: Maybe W.Application + -- ^ Optional `W.Application` to be used in case of 404 errors + -- + -- Since 3.1.3 } diff --git a/wai-app-static/embedded-sample.hs b/wai-app-static/embedded-sample.hs index 46a6ae65d..b126b4819 100644 --- a/wai-app-static/embedded-sample.hs +++ b/wai-app-static/embedded-sample.hs @@ -1,13 +1,17 @@ -{-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE TemplateHaskell #-} + +import Data.FileEmbed import Network.Wai.Application.Static import Network.Wai.Handler.Warp (run) -import Data.FileEmbed -import WaiAppStatic.Types import WaiAppStatic.Storage.Embedded +import WaiAppStatic.Types main :: IO () -main = run 3000 $ staticApp (embeddedSettings $(embedDir "test")) - { ssIndices = [] - , ssMaxAge = NoMaxAge - } +main = + run 3000 $ + staticApp + (embeddedSettings $(embedDir "test")) + { ssIndices = [] + , ssMaxAge = NoMaxAge + } diff --git a/wai-app-static/sample.hs b/wai-app-static/sample.hs index b9cc2510f..c15355ced 100644 --- a/wai-app-static/sample.hs +++ b/wai-app-static/sample.hs @@ -1,17 +1,19 @@ {-# LANGUAGE OverloadedStrings #-} +import Data.Maybe (mapMaybe) import Data.String import Data.Text (pack) -import Data.Maybe (mapMaybe) +import Network.Wai.Application.Static (defaultFileServerSettings, staticApp) +import Network.Wai.Handler.Warp (defaultSettings, runSettings, settingsPort) import WaiAppStatic.Types (ssIndices, toPiece) -import Network.Wai.Application.Static (staticApp,defaultFileServerSettings) -import Network.Wai.Handler.Warp (runSettings, settingsPort, defaultSettings) - + main :: IO () -main = runSettings defaultSettings - { - settingsPort = 3000 - } $ staticApp (defaultFileServerSettings $ fromString ".") - { - ssIndices = mapMaybe (toPiece . pack) ["index.html"] - } +main = + runSettings + defaultSettings + { settingsPort = 3000 + } + $ staticApp + (defaultFileServerSettings $ fromString ".") + { ssIndices = mapMaybe (toPiece . pack) ["index.html"] + } diff --git a/wai-app-static/test/EmbeddedTestEntries.hs b/wai-app-static/test/EmbeddedTestEntries.hs index e0fd67ada..dc4fe47fd 100644 --- a/wai-app-static/test/EmbeddedTestEntries.hs +++ b/wai-app-static/test/EmbeddedTestEntries.hs @@ -1,55 +1,59 @@ -{-# LANGUAGE TemplateHaskell, QuasiQuotes, OverloadedStrings #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE QuasiQuotes #-} +{-# LANGUAGE TemplateHaskell #-} module EmbeddedTestEntries where -import WaiAppStatic.Storage.Embedded +import qualified Data.ByteString.Lazy as BL import qualified Data.Text as T import qualified Data.Text.Lazy as TL import qualified Data.Text.Lazy.Encoding as TL -import qualified Data.ByteString.Lazy as BL +import WaiAppStatic.Storage.Embedded body :: Int -> Char -> BL.ByteString body i c = TL.encodeUtf8 $ TL.pack $ replicate i c mkEntries :: IO [EmbeddableEntry] -mkEntries = return - -- An entry that should be compressed - [ EmbeddableEntry "e1.txt" - "text/plain" - (Left ("Etag 1", body 1000 'A')) - - -- An entry so short that the compressed text is longer - , EmbeddableEntry "e2.txt" - "text/plain" - (Left ("Etag 2", "ABC")) - - -- An entry that is not compressed because of the mime - , EmbeddableEntry "somedir/e3.txt" - "xxx" - (Left ("Etag 3", body 1000 'A')) - - -- A reloadable entry - , EmbeddableEntry "e4.css" - "text/css" - (Right [| return ("Etag 4" :: T.Text, body 2000 'Q') |]) - - -- An entry without etag - , EmbeddableEntry "e5.txt" - "text/plain" - (Left ("", body 1000 'Z')) - - -- A reloadable entry without etag - , EmbeddableEntry "e6.txt" - "text/plain" - (Right [| return ("" :: T.Text, body 1000 'W') |] ) - - -- An index file - , EmbeddableEntry "index.html" - "text/html" - (Right [| return ("" :: T.Text, "index file") |] ) - - -- An index file in a subdir - , EmbeddableEntry "foo/index.html" - "text/html" - (Right [| return ("" :: T.Text, "index file in subdir") |] ) - ] +mkEntries = + return + -- An entry that should be compressed + [ EmbeddableEntry + "e1.txt" + "text/plain" + (Left ("Etag 1", body 1000 'A')) + , -- An entry so short that the compressed text is longer + EmbeddableEntry + "e2.txt" + "text/plain" + (Left ("Etag 2", "ABC")) + , -- An entry that is not compressed because of the mime + EmbeddableEntry + "somedir/e3.txt" + "xxx" + (Left ("Etag 3", body 1000 'A')) + , -- A reloadable entry + EmbeddableEntry + "e4.css" + "text/css" + (Right [|return ("Etag 4" :: T.Text, body 2000 'Q')|]) + , -- An entry without etag + EmbeddableEntry + "e5.txt" + "text/plain" + (Left ("", body 1000 'Z')) + , -- A reloadable entry without etag + EmbeddableEntry + "e6.txt" + "text/plain" + (Right [|return ("" :: T.Text, body 1000 'W')|]) + , -- An index file + EmbeddableEntry + "index.html" + "text/html" + (Right [|return ("" :: T.Text, "index file")|]) + , -- An index file in a subdir + EmbeddableEntry + "foo/index.html" + "text/html" + (Right [|return ("" :: T.Text, "index file in subdir")|]) + ] diff --git a/wai-app-static/test/WaiAppEmbeddedTest.hs b/wai-app-static/test/WaiAppEmbeddedTest.hs index 009370b35..fee556997 100644 --- a/wai-app-static/test/WaiAppEmbeddedTest.hs +++ b/wai-app-static/test/WaiAppEmbeddedTest.hs @@ -1,4 +1,6 @@ -{-# LANGUAGE TemplateHaskell, OverloadedStrings #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE TemplateHaskell #-} + module WaiAppEmbeddedTest (embSpec) where import Codec.Compression.GZip (compress) @@ -28,27 +30,32 @@ embSpec = do assertBody (compress $ body 1000 'A') req it "304 when valid if-none-match sent" $ embed $ do - req <- request (setRawPathInfo defRequest "e1.txt") - { requestHeaders = [("If-None-Match", "Etag 1")] } + req <- + request + (setRawPathInfo defRequest "e1.txt") + { requestHeaders = [("If-None-Match", "Etag 1")] + } assertStatus 304 req it "ssIndices works" $ do - let testSettings = $(mkSettings mkEntries){ - ssIndices = [unsafeToPiece "index.html"] - } + let testSettings = + $(mkSettings mkEntries) + { ssIndices = [unsafeToPiece "index.html"] + } embedSettings testSettings $ do - req <- request defRequest - assertStatus 200 req - assertBody "index file" req + req <- request defRequest + assertStatus 200 req + assertBody "index file" req it "ssIndices works with trailing slashes" $ do - let testSettings = $(mkSettings mkEntries){ - ssIndices = [unsafeToPiece "index.html"] - } + let testSettings = + $(mkSettings mkEntries) + { ssIndices = [unsafeToPiece "index.html"] + } embedSettings testSettings $ do - req <- request (setRawPathInfo defRequest "/foo/") - assertStatus 200 req - assertBody "index file in subdir" req + req <- request (setRawPathInfo defRequest "/foo/") + assertStatus 200 req + assertBody "index file in subdir" req describe "embedded, uncompressed entry" $ do it "too short" $ embed $ do @@ -68,13 +75,14 @@ embSpec = do assertBody (body 1000 'A') req describe "reloadable entry" $ - it "served correctly" $ embed $ do - req <- request (setRawPathInfo defRequest "e4.css") - assertStatus 200 req - assertHeader "Content-Type" "text/css" req - assertNoHeader "Content-Encoding" req - assertHeader "ETag" "Etag 4" req - assertBody (body 2000 'Q') req + it "served correctly" $ + embed $ do + req <- request (setRawPathInfo defRequest "e4.css") + assertStatus 200 req + assertHeader "Content-Type" "text/css" req + assertNoHeader "Content-Encoding" req + assertHeader "ETag" "Etag 4" req + assertBody (body 2000 'Q') req describe "entries without etags" $ do it "embedded entry" $ embed $ do diff --git a/wai-app-static/test/WaiAppStaticTest.hs b/wai-app-static/test/WaiAppStaticTest.hs index 6c47d4acb..d24701b24 100644 --- a/wai-app-static/test/WaiAppStaticTest.hs +++ b/wai-app-static/test/WaiAppStaticTest.hs @@ -1,19 +1,24 @@ -{-# LANGUAGE OverloadedStrings, NoMonomorphismRestriction #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE NoMonomorphismRestriction #-} + module WaiAppStaticTest (spec) where import Network.Wai.Application.Static import WaiAppStatic.Types +import qualified Data.ByteString.Char8 as S8 import Test.Hspec import Test.Mockery.Directory -import qualified Data.ByteString.Char8 as S8 + -- import qualified Data.ByteString.Lazy.Char8 as L8 -import System.PosixCompat.Files (getFileStatus, modificationTime) + import System.FilePath import System.IO.Temp +import System.PosixCompat.Files (getFileStatus, modificationTime) import Network.HTTP.Date import Network.HTTP.Types (status500) + {-import System.Locale (defaultTimeLocale)-} {-import Data.Time.Format (formatTime)-} @@ -29,157 +34,181 @@ defRequest = defaultRequest spec :: Spec spec = do - let webApp = flip runSession $ staticApp $ defaultWebAppSettings "test" - let fileServerAppWithSettings settings = flip runSession $ staticApp settings - let fileServerApp = fileServerAppWithSettings (defaultFileServerSettings "test") - { ssAddTrailingSlash = True - } - - let etag = "1B2M2Y8AsgTpgAmY7PhCfg==" - let file = "a/b" - let statFile = setRawPathInfo defRequest file - - describe "mime types" $ do - it "fileNameExtensions" $ - fileNameExtensions "foo.tar.gz" `shouldBe` ["tar.gz", "gz"] - it "handles multi-extensions" $ - defaultMimeLookup "foo.tar.gz" `shouldBe` "application/x-tgz" - it "defaults correctly" $ - defaultMimeLookup "foo.unknown" `shouldBe` "application/octet-stream" - - describe "webApp" $ do - it "403 for unsafe paths" $ webApp $ - forM_ ["..", "."] $ \path -> - assertStatus 403 =<< - request (setRawPathInfo defRequest path) - - it "200 for hidden paths" $ webApp $ - forM_ [".hidden/folder.png", ".hidden/haskell.png"] $ \path -> - assertStatus 200 =<< - request (setRawPathInfo defRequest path) - - it "404 for non-existent files" $ webApp $ - assertStatus 404 =<< - request (setRawPathInfo defRequest "doesNotExist") - - it "302 redirect when multiple slashes" $ webApp $ do - req <- request (setRawPathInfo defRequest "a//b/c") - assertStatus 302 req - assertHeader "Location" "../../a/b/c" req - - let absoluteApp = flip runSession $ staticApp $ (defaultWebAppSettings "test") { - ssMkRedirect = \_ u -> S8.append "http://www.example.com" u - } - it "302 redirect when multiple slashes" $ absoluteApp $ - forM_ ["/a//b/c", "a//b/c"] $ \path -> do - req <- request (setRawPathInfo defRequest path) - assertStatus 302 req - assertHeader "Location" "http://www.example.com/a/b/c" req - - describe "webApp when requesting a static asset" $ do - it "200 and etag when no etag query parameters" $ webApp $ do - req <- request statFile - assertStatus 200 req - assertHeader "ETag" etag req - assertNoHeader "Last-Modified" req - - it "Cache-Control set when etag parameter is correct" $ webApp $ do - req <- request statFile { queryString = [("etag", Just etag)] } - assertStatus 200 req - assertHeader "Cache-Control" "public, max-age=31536000" req - assertNoHeader "Last-Modified" req - - it "200 when invalid in-none-match sent" $ webApp $ - forM_ ["cached", ""] $ \badETag -> do - req <- request statFile { requestHeaders = [("If-None-Match", badETag)] } - assertStatus 200 req - assertHeader "ETag" etag req - assertNoHeader "Last-Modified" req - - it "304 when valid if-none-match sent" $ webApp $ do - req <- request statFile { requestHeaders = [("If-None-Match", etag)] } - assertStatus 304 req - assertNoHeader "Etag" req - assertNoHeader "Last-Modified" req - - describe "fileServerApp" $ do - let fileDate = do - stat <- liftIO $ getFileStatus $ "test/" ++ file - return $ formatHTTPDate . epochTimeToHTTPDate $ modificationTime stat - - it "directory listing for index" $ fileServerApp $ do - resp <- request (setRawPathInfo defRequest "a/") - assertStatus 200 resp - -- note the unclosed img tags so both /> and > will pass - assertBodyContains "\"Folder\""b" resp - - it "200 when invalid if-modified-since header" $ fileServerApp $ do - forM_ ["123", ""] $ \badDate -> do - req <- request statFile { - requestHeaders = [("If-Modified-Since", badDate)] - } - assertStatus 200 req - fdate <- fileDate - assertHeader "Last-Modified" fdate req - - it "304 when if-modified-since matches" $ fileServerApp $ do - fdate <- fileDate - req <- request statFile { - requestHeaders = [("If-Modified-Since", fdate)] - } - assertStatus 304 req - assertNoHeader "Cache-Control" req - - context "302 redirect to add a trailing slash on directories if missing" $ do - it "works at the root" $ fileServerApp $ do - req <- request (setRawPathInfo defRequest "/a") - assertStatus 302 req - assertHeader "Location" "/a/" req - - it "works when an index.html is delivered" $ do - let settings = (defaultFileServerSettings "."){ - ssAddTrailingSlash = True - } - inTempDirectory $ fileServerAppWithSettings settings $ do - liftIO $ touch "foo/index.html" - req <- request (setRawPathInfo defRequest "/foo") - assertStatus 302 req - assertHeader "Location" "/foo/" req - - let urlMapApp = flip runSession $ \req send -> - case pathInfo req of - "subPath":rest -> - let req' = req { pathInfo = rest } - in (staticApp (defaultFileServerSettings "test") - { ssAddTrailingSlash = True - }) req' send - _ -> send $ responseLBS status500 [] - "urlMapApp: only works at subPath" - it "works with subpath at the root of the file server" $ urlMapApp $ do - req <- request (setRawPathInfo defRequest "/subPath") - assertStatus 302 req - assertHeader "Location" "/subPath/" req - - context "with defaultWebAppSettings" $ do - it "ssIndices works" $ do - withSystemTempDirectory "wai-app-static-test" $ \ dir -> do - writeFile (dir "index.html") "foo" - let testSettings = (defaultWebAppSettings dir) { - ssIndices = [unsafeToPiece "index.html"] - } - fileServerAppWithSettings testSettings $ do - resp <- request (setRawPathInfo defRequest "/") + let webApp = flip runSession $ staticApp $ defaultWebAppSettings "test" + let fileServerAppWithSettings settings = flip runSession $ staticApp settings + let fileServerApp = + fileServerAppWithSettings + (defaultFileServerSettings "test") + { ssAddTrailingSlash = True + } + + let etag = "1B2M2Y8AsgTpgAmY7PhCfg==" + let file = "a/b" + let statFile = setRawPathInfo defRequest file + + describe "mime types" $ do + it "fileNameExtensions" $ + fileNameExtensions "foo.tar.gz" `shouldBe` ["tar.gz", "gz"] + it "handles multi-extensions" $ + defaultMimeLookup "foo.tar.gz" `shouldBe` "application/x-tgz" + it "defaults correctly" $ + defaultMimeLookup "foo.unknown" `shouldBe` "application/octet-stream" + + describe "webApp" $ do + it "403 for unsafe paths" $ + webApp $ + forM_ ["..", "."] $ \path -> + assertStatus 403 + =<< request (setRawPathInfo defRequest path) + + it "200 for hidden paths" $ + webApp $ + forM_ [".hidden/folder.png", ".hidden/haskell.png"] $ \path -> + assertStatus 200 + =<< request (setRawPathInfo defRequest path) + + it "404 for non-existent files" $ + webApp $ + assertStatus 404 + =<< request (setRawPathInfo defRequest "doesNotExist") + + it "302 redirect when multiple slashes" $ webApp $ do + req <- request (setRawPathInfo defRequest "a//b/c") + assertStatus 302 req + assertHeader "Location" "../../a/b/c" req + + let absoluteApp = + flip runSession $ + staticApp $ + (defaultWebAppSettings "test") + { ssMkRedirect = \_ u -> S8.append "http://www.example.com" u + } + it "302 redirect when multiple slashes" $ + absoluteApp $ + forM_ ["/a//b/c", "a//b/c"] $ \path -> do + req <- request (setRawPathInfo defRequest path) + assertStatus 302 req + assertHeader "Location" "http://www.example.com/a/b/c" req + + describe "webApp when requesting a static asset" $ do + it "200 and etag when no etag query parameters" $ webApp $ do + req <- request statFile + assertStatus 200 req + assertHeader "ETag" etag req + assertNoHeader "Last-Modified" req + + it "Cache-Control set when etag parameter is correct" $ webApp $ do + req <- request statFile{queryString = [("etag", Just etag)]} + assertStatus 200 req + assertHeader "Cache-Control" "public, max-age=31536000" req + assertNoHeader "Last-Modified" req + + it "200 when invalid in-none-match sent" $ + webApp $ + forM_ ["cached", ""] $ \badETag -> do + req <- request statFile{requestHeaders = [("If-None-Match", badETag)]} + assertStatus 200 req + assertHeader "ETag" etag req + assertNoHeader "Last-Modified" req + + it "304 when valid if-none-match sent" $ webApp $ do + req <- request statFile{requestHeaders = [("If-None-Match", etag)]} + assertStatus 304 req + assertNoHeader "Etag" req + assertNoHeader "Last-Modified" req + + describe "fileServerApp" $ do + let fileDate = do + stat <- liftIO $ getFileStatus $ "test/" ++ file + return $ formatHTTPDate . epochTimeToHTTPDate $ modificationTime stat + + it "directory listing for index" $ fileServerApp $ do + resp <- request (setRawPathInfo defRequest "a/") assertStatus 200 resp - assertBody "foo" resp - - context "with defaultFileServerSettings" $ do - it "prefers ssIndices over ssListing" $ do - withSystemTempDirectory "wai-app-static-test" $ \ dir -> do - writeFile (dir "index.html") "foo" - let testSettings = defaultFileServerSettings dir - fileServerAppWithSettings testSettings $ do - resp <- request (setRawPathInfo defRequest "/") - assertStatus 200 resp - assertBody "foo" resp + -- note the unclosed img tags so both /> and > will pass + assertBodyContains "\"Folder\""b" resp + + it "200 when invalid if-modified-since header" $ fileServerApp $ do + forM_ ["123", ""] $ \badDate -> do + req <- + request + statFile + { requestHeaders = [("If-Modified-Since", badDate)] + } + assertStatus 200 req + fdate <- fileDate + assertHeader "Last-Modified" fdate req + + it "304 when if-modified-since matches" $ fileServerApp $ do + fdate <- fileDate + req <- + request + statFile + { requestHeaders = [("If-Modified-Since", fdate)] + } + assertStatus 304 req + assertNoHeader "Cache-Control" req + + context "302 redirect to add a trailing slash on directories if missing" $ do + it "works at the root" $ fileServerApp $ do + req <- request (setRawPathInfo defRequest "/a") + assertStatus 302 req + assertHeader "Location" "/a/" req + + it "works when an index.html is delivered" $ do + let settings = + (defaultFileServerSettings ".") + { ssAddTrailingSlash = True + } + inTempDirectory $ fileServerAppWithSettings settings $ do + liftIO $ touch "foo/index.html" + req <- request (setRawPathInfo defRequest "/foo") + assertStatus 302 req + assertHeader "Location" "/foo/" req + + let urlMapApp = flip runSession $ \req send -> + case pathInfo req of + "subPath" : rest -> + let req' = req{pathInfo = rest} + in ( staticApp + (defaultFileServerSettings "test") + { ssAddTrailingSlash = True + } + ) + req' + send + _ -> + send $ + responseLBS + status500 + [] + "urlMapApp: only works at subPath" + it "works with subpath at the root of the file server" $ urlMapApp $ do + req <- request (setRawPathInfo defRequest "/subPath") + assertStatus 302 req + assertHeader "Location" "/subPath/" req + + context "with defaultWebAppSettings" $ do + it "ssIndices works" $ do + withSystemTempDirectory "wai-app-static-test" $ \dir -> do + writeFile (dir "index.html") "foo" + let testSettings = + (defaultWebAppSettings dir) + { ssIndices = [unsafeToPiece "index.html"] + } + fileServerAppWithSettings testSettings $ do + resp <- request (setRawPathInfo defRequest "/") + assertStatus 200 resp + assertBody "foo" resp + + context "with defaultFileServerSettings" $ do + it "prefers ssIndices over ssListing" $ do + withSystemTempDirectory "wai-app-static-test" $ \dir -> do + writeFile (dir "index.html") "foo" + let testSettings = defaultFileServerSettings dir + fileServerAppWithSettings testSettings $ do + resp <- request (setRawPathInfo defRequest "/") + assertStatus 200 resp + assertBody "foo" resp diff --git a/wai-app-static/tests.hs b/wai-app-static/tests.hs index f6e4d9ccf..f2bcb11f0 100644 --- a/wai-app-static/tests.hs +++ b/wai-app-static/tests.hs @@ -1,6 +1,6 @@ -import WaiAppStaticTest (spec) -import WaiAppEmbeddedTest (embSpec) import Test.Hspec +import WaiAppEmbeddedTest (embSpec) +import WaiAppStaticTest (spec) main :: IO () main = hspec $ spec >> embSpec diff --git a/wai-conduit/Network/Wai/Conduit.hs b/wai-conduit/Network/Wai/Conduit.hs index e2d18b572..d10bf8eee 100644 --- a/wai-conduit/Network/Wai/Conduit.hs +++ b/wai-conduit/Network/Wai/Conduit.hs @@ -1,23 +1,25 @@ -- | A light-weight wrapper around @Network.Wai@ to provide easy conduit support. -module Network.Wai.Conduit - ( -- * Request body - sourceRequestBody - -- * Response body - , responseSource - , responseRawSource - -- * Re-export - , module Network.Wai - ) where +module Network.Wai.Conduit ( + -- * Request body + sourceRequestBody, -import Network.Wai -import Data.Conduit + -- * Response body + responseSource, + responseRawSource, + + -- * Re-export + module Network.Wai, +) where + +import Control.Monad (unless) import Control.Monad.IO.Class (MonadIO, liftIO) import Data.ByteString (ByteString) import qualified Data.ByteString as S -import Control.Monad (unless) -import Network.HTTP.Types import Data.ByteString.Builder (Builder) +import Data.Conduit import qualified Data.Conduit.List as CL +import Network.HTTP.Types +import Network.Wai -- | Stream the request body. -- @@ -37,12 +39,17 @@ sourceRequestBody req = -- | Create an HTTP response out of a @Source@. -- -- Since 3.0.0 -responseSource :: Status -> ResponseHeaders -> ConduitT () (Flush Builder) IO () -> Response +responseSource + :: Status -> ResponseHeaders -> ConduitT () (Flush Builder) IO () -> Response responseSource s hs src = responseStream s hs $ \send flush -> - runConduit $ src .| CL.mapM_ (\mbuilder -> - case mbuilder of - Chunk b -> send b - Flush -> flush) + runConduit $ + src + .| CL.mapM_ + ( \mbuilder -> + case mbuilder of + Chunk b -> send b + Flush -> flush + ) -- | Create a raw response using @Source@ and @Sink@ conduits. -- @@ -55,10 +62,11 @@ responseSource s hs src = responseStream s hs $ \send flush -> -- the handler does not support @responseRaw@. -- -- Since 3.0.0 -responseRawSource :: (MonadIO m, MonadIO n) - => (ConduitT () ByteString m () -> ConduitT ByteString Void n () -> IO ()) - -> Response - -> Response +responseRawSource + :: (MonadIO m, MonadIO n) + => (ConduitT () ByteString m () -> ConduitT ByteString Void n () -> IO ()) + -> Response + -> Response responseRawSource app = responseRaw app' where diff --git a/wai-conduit/Setup.hs b/wai-conduit/Setup.hs index 9a994af67..e8ef27dbb 100644 --- a/wai-conduit/Setup.hs +++ b/wai-conduit/Setup.hs @@ -1,2 +1,3 @@ import Distribution.Simple + main = defaultMain diff --git a/wai-conduit/example/Main.hs b/wai-conduit/example/Main.hs index a895f1c2f..eb3efbea6 100755 --- a/wai-conduit/example/Main.hs +++ b/wai-conduit/example/Main.hs @@ -1,19 +1,21 @@ #!/usr/bin/env stack -- stack --resolver lts-11.10 script -import Network.Wai -import Network.Wai.Handler.Warp -import Network.Wai.Conduit -import Network.HTTP.Types + import Conduit -import Data.ByteString.Builder (byteString) import Control.Monad.Trans.Resource +import Data.ByteString.Builder (byteString) +import Network.HTTP.Types +import Network.Wai +import Network.Wai.Conduit +import Network.Wai.Handler.Warp main :: IO () main = run 3000 app app :: Application app _req respond = - runResourceT $ withInternalState $ \is -> - respond $ responseSource status200 [] $ - transPipe (`runInternalState` is) (sourceFile "Main.hs") .| - mapC (Chunk . byteString) + runResourceT $ withInternalState $ \is -> + respond $ + responseSource status200 [] $ + transPipe (`runInternalState` is) (sourceFile "Main.hs") + .| mapC (Chunk . byteString) diff --git a/wai-extra/Network/Wai/EventSource.hs b/wai-extra/Network/Wai/EventSource.hs index c19a631cd..035f1a208 100644 --- a/wai-extra/Network/Wai/EventSource.hs +++ b/wai-extra/Network/Wai/EventSource.hs @@ -1,13 +1,12 @@ -{-| - A WAI adapter to the HTML5 Server-Sent Events API. - - If running through a proxy like Nginx you might need to add the - headers: - - > [ ("X-Accel-Buffering", "no"), ("Cache-Control", "no-cache")] - - There is a small example using these functions in the @example@ directory. --} +-- | +-- A WAI adapter to the HTML5 Server-Sent Events API. +-- +-- If running through a proxy like Nginx you might need to add the +-- headers: +-- +-- > [ ("X-Accel-Buffering", "no"), ("Cache-Control", "no-cache")] +-- +-- There is a small example using these functions in the @example@ directory. module Network.Wai.EventSource ( ServerEvent (..), eventSourceAppChan, @@ -34,29 +33,32 @@ eventSourceAppChan chan req sendResponse = do -- the given IO action. eventSourceAppIO :: IO ServerEvent -> Application eventSourceAppIO src _ sendResponse = - sendResponse $ responseStream - status200 - [(hContentType, "text/event-stream")] + sendResponse + $ responseStream + status200 + [(hContentType, "text/event-stream")] $ \sendChunk flush -> do flush fix $ \loop -> do se <- src case eventToBuilder se of Nothing -> return () - Just b -> sendChunk b >> flush >> loop + Just b -> sendChunk b >> flush >> loop -- | Make a new WAI EventSource application with a handler that emits events. -- -- @since 3.0.28 -eventStreamAppRaw :: ((ServerEvent -> IO()) -> IO () -> IO ()) -> Application +eventStreamAppRaw :: ((ServerEvent -> IO ()) -> IO () -> IO ()) -> Application eventStreamAppRaw handler _ sendResponse = - sendResponse $ responseStream - status200 - [(hContentType, "text/event-stream")] + sendResponse + $ responseStream + status200 + [(hContentType, "text/event-stream")] $ \sendChunk flush -> handler (sendEvent sendChunk) flush - where - sendEvent sendChunk event = - case eventToBuilder event of - Nothing -> return () - Just b -> sendChunk b + where + sendEvent sendChunk event = + case eventToBuilder event of + Nothing -> return () + Just b -> sendChunk b + {- HLint ignore eventStreamAppRaw "Use forM_" -} diff --git a/wai-extra/Network/Wai/EventSource/EventStream.hs b/wai-extra/Network/Wai/EventSource/EventStream.hs index c29620742..5cc0f8eab 100644 --- a/wai-extra/Network/Wai/EventSource/EventStream.hs +++ b/wai-extra/Network/Wai/EventSource/EventStream.hs @@ -1,9 +1,9 @@ {-# LANGUAGE CPP #-} + {- code adapted by Mathias Billman originally from Chris Smith https://github.com/cdsmith/gloss-web -} -{-| - Internal module, usually you don't need to use it. --} +-- | +-- Internal module, usually you don't need to use it. module Network.Wai.EventSource.EventStream ( ServerEvent (..), eventToBuilder, @@ -15,36 +15,31 @@ import Data.Monoid #endif import Data.Word8 (_colon, _lf) -{-| - Type representing a communication over an event stream. This can be an - actual event, a comment, a modification to the retry timer, or a special - "close" event indicating the server should close the connection. --} +-- | +-- Type representing a communication over an event stream. This can be an +-- actual event, a comment, a modification to the retry timer, or a special +-- "close" event indicating the server should close the connection. data ServerEvent - = ServerEvent { - eventName :: Maybe Builder, - eventId :: Maybe Builder, - eventData :: [Builder] + = ServerEvent + { eventName :: Maybe Builder + , eventId :: Maybe Builder + , eventData :: [Builder] } - | CommentEvent { - eventComment :: Builder + | CommentEvent + { eventComment :: Builder } - | RetryEvent { - eventRetry :: Int + | RetryEvent + { eventRetry :: Int } | CloseEvent - -{-| - Newline as a Builder. --} +-- | +-- Newline as a Builder. nl :: Builder nl = word8 _lf - -{-| - Field names as Builder --} +-- | +-- Field names as Builder nameField, idField, dataField, retryField, commentField :: Builder nameField = string7 "event:" idField = string7 "id:" @@ -52,26 +47,23 @@ dataField = string7 "data:" retryField = string7 "retry:" commentField = word8 _colon - -{-| - Wraps the text as a labeled field of an event stream. --} +-- | +-- Wraps the text as a labeled field of an event stream. field :: Builder -> Builder -> Builder field l b = l `mappend` b `mappend` nl - -{-| - Converts a 'ServerEvent' to its wire representation as specified by the - @text/event-stream@ content type. --} +-- | +-- Converts a 'ServerEvent' to its wire representation as specified by the +-- @text/event-stream@ content type. eventToBuilder :: ServerEvent -> Maybe Builder eventToBuilder (CommentEvent txt) = Just $ field commentField txt -eventToBuilder (RetryEvent n) = Just $ field retryField (string8 . show $ n) -eventToBuilder CloseEvent = Nothing -eventToBuilder (ServerEvent n i d)= Just $ - name n (evid i $ mconcat (map (field dataField) d)) `mappend` nl +eventToBuilder (RetryEvent n) = Just $ field retryField (string8 . show $ n) +eventToBuilder CloseEvent = Nothing +eventToBuilder (ServerEvent n i d) = + Just $ + name n (evid i $ mconcat (map (field dataField) d)) `mappend` nl where - name Nothing = id + name Nothing = id name (Just n') = mappend (field nameField n') - evid Nothing = id - evid (Just i') = mappend (field idField i') + evid Nothing = id + evid (Just i') = mappend (field idField i') diff --git a/wai-extra/Network/Wai/Handler/SCGI.hs b/wai-extra/Network/Wai/Handler/SCGI.hs index 149f66320..b27b9c900 100644 --- a/wai-extra/Network/Wai/Handler/SCGI.hs +++ b/wai-extra/Network/Wai/Handler/SCGI.hs @@ -1,9 +1,10 @@ {-# LANGUAGE CPP #-} {-# LANGUAGE ForeignFunctionInterface #-} -module Network.Wai.Handler.SCGI - ( run - , runSendfile - ) where + +module Network.Wai.Handler.SCGI ( + run, + runSendfile, +) where import Data.ByteString (ByteString) import qualified Data.ByteString as S @@ -35,8 +36,12 @@ runOne sf app = do (i, _) <- listToMaybe $ reads conLenS pure i conLenI <- newIORef conLen - runGeneric headers (requestBodyFunc $ input socket conLenI) - (write socket) sf app + runGeneric + headers + (requestBodyFunc $ input socket conLenI) + (write socket) + sf + app drain socket conLenI _ <- c'close socket return () @@ -52,8 +57,9 @@ input socket ilen rlen = do case len of 0 -> return Nothing _ -> do - bs <- readByteString socket - $ minimum [defaultChunkSize, len, rlen] + bs <- + readByteString socket $ + minimum [defaultChunkSize, len, rlen] writeIORef ilen $ len - S.length bs return $ Just bs @@ -64,7 +70,7 @@ drain socket ilen = do return () parseHeaders :: [S.ByteString] -> [(String, String)] -parseHeaders (x:y:z) = (S8.unpack x, S8.unpack y) : parseHeaders z +parseHeaders (x : y : z) = (S8.unpack x, S8.unpack y) : parseHeaders z parseHeaders _ = [] readNetstring :: CInt -> IO S.ByteString diff --git a/wai-extra/Network/Wai/Header.hs b/wai-extra/Network/Wai/Header.hs index 884b47e0c..0fe116dc4 100644 --- a/wai-extra/Network/Wai/Header.hs +++ b/wai-extra/Network/Wai/Header.hs @@ -1,9 +1,9 @@ -- | Some helpers for dealing with WAI 'Header's. -module Network.Wai.Header - ( contentLength - , parseQValueList - , replaceHeader - ) where +module Network.Wai.Header ( + contentLength, + parseQValueList, + replaceHeader, +) where import Control.Monad (guard) import qualified Data.ByteString as S @@ -49,7 +49,9 @@ parseQValueList = fmap go . splitCommas checkQ (val, bs) = -- RFC 7231 says optional whitespace can be around the semicolon. -- So drop any before it , . and any behind it $ and drop the semicolon - (dropWhileEnd (== _space) val, parseQval . S.dropWhile (== _space) $ S.drop 1 bs) + ( dropWhileEnd (== _space) val + , parseQval . S.dropWhile (== _space) $ S.drop 1 bs + ) where parseQval qVal = do q <- S.stripPrefix "q=" qVal diff --git a/wai-extra/Network/Wai/Middleware/AcceptOverride.hs b/wai-extra/Network/Wai/Middleware/AcceptOverride.hs index 22183ab70..99e56bb37 100644 --- a/wai-extra/Network/Wai/Middleware/AcceptOverride.hs +++ b/wai-extra/Network/Wai/Middleware/AcceptOverride.hs @@ -1,10 +1,10 @@ -module Network.Wai.Middleware.AcceptOverride - ( -- $howto - acceptOverride - ) where +module Network.Wai.Middleware.AcceptOverride ( + -- $howto + acceptOverride, +) where -import Network.Wai import Control.Monad (join) +import Network.Wai import Network.Wai.Header (replaceHeader) @@ -31,6 +31,7 @@ acceptOverride app req = req' = case join $ lookup "_accept" $ queryString req of Nothing -> req - Just a -> req { - requestHeaders = replaceHeader "Accept" a $ requestHeaders req - } + Just a -> + req + { requestHeaders = replaceHeader "Accept" a $ requestHeaders req + } diff --git a/wai-extra/Network/Wai/Middleware/AddHeaders.hs b/wai-extra/Network/Wai/Middleware/AddHeaders.hs index 246f7c573..f213085a8 100644 --- a/wai-extra/Network/Wai/Middleware/AddHeaders.hs +++ b/wai-extra/Network/Wai/Middleware/AddHeaders.hs @@ -1,9 +1,9 @@ -- | -- -- Since 3.0.3 -module Network.Wai.Middleware.AddHeaders - ( addHeaders - ) where +module Network.Wai.Middleware.AddHeaders ( + addHeaders, +) where import Control.Arrow (first) import Data.ByteString (ByteString) @@ -12,12 +12,10 @@ import Network.HTTP.Types (Header) import Network.Wai (Middleware, mapResponseHeaders, modifyResponse) import Network.Wai.Internal (Response (..)) - addHeaders :: [(ByteString, ByteString)] -> Middleware -- ^ Prepend a list of headers without any checks -- -- Since 3.0.3 - addHeaders h = modifyResponse $ addHeaders' (map (first CI.mk) h) addHeaders' :: [Header] -> Response -> Response diff --git a/wai-extra/Network/Wai/Middleware/Approot.hs b/wai-extra/Network/Wai/Middleware/Approot.hs index e6c162789..c4ae1b96d 100644 --- a/wai-extra/Network/Wai/Middleware/Approot.hs +++ b/wai-extra/Network/Wai/Middleware/Approot.hs @@ -1,4 +1,5 @@ {-# LANGUAGE DeriveDataTypeable #-} + -- | Middleware for establishing the root of the application. -- -- Many application need the ability to create URLs referring back to the @@ -13,18 +14,20 @@ -- @/foo/bar?baz=bin@. For example, if your application is hosted on -- example.com using HTTPS, the approot would be @https://example.com@. Note -- the lack of a trailing slash. -module Network.Wai.Middleware.Approot - ( -- * Middleware - approotMiddleware - -- * Common providers - , envFallback - , envFallbackNamed - , hardcoded - , fromRequest - -- * Functions for applications - , getApproot - , getApprootMay - ) where +module Network.Wai.Middleware.Approot ( + -- * Middleware + approotMiddleware, + + -- * Common providers + envFallback, + envFallbackNamed, + hardcoded, + fromRequest, + + -- * Functions for applications + getApproot, + getApprootMay, +) where import Control.Exception (Exception, throw) import Data.ByteString (ByteString) @@ -48,11 +51,13 @@ approotKey = unsafePerformIO V.newKey -- functionality more conveniently. -- -- Since 3.0.7 -approotMiddleware :: (Request -> IO ByteString) -- ^ get the approot - -> Middleware +approotMiddleware + :: (Request -> IO ByteString) + -- ^ get the approot + -> Middleware approotMiddleware getRoot app req respond = do ar <- getRoot req - let req' = req { vault = V.insert approotKey ar $ vault req } + let req' = req{vault = V.insert approotKey ar $ vault req} app req' respond -- | Same as @'envFallbackNamed' "APPROOT"@. diff --git a/wai-extra/Network/Wai/Middleware/Autohead.hs b/wai-extra/Network/Wai/Middleware/Autohead.hs index 400f973b9..674ff9d86 100644 --- a/wai-extra/Network/Wai/Middleware/Autohead.hs +++ b/wai-extra/Network/Wai/Middleware/Autohead.hs @@ -1,4 +1,5 @@ {-# LANGUAGE CPP #-} + -- | Automatically produce responses to HEAD requests based on the underlying -- applications GET response. module Network.Wai.Middleware.Autohead (autohead) where @@ -6,11 +7,16 @@ module Network.Wai.Middleware.Autohead (autohead) where #if __GLASGOW_HASKELL__ < 710 import Data.Monoid (mempty) #endif -import Network.Wai (Middleware, requestMethod, responseBuilder, responseToStream) +import Network.Wai ( + Middleware, + requestMethod, + responseBuilder, + responseToStream, + ) autohead :: Middleware autohead app req sendResponse - | requestMethod req == "HEAD" = app req { requestMethod = "GET" } $ \res -> do + | requestMethod req == "HEAD" = app req{requestMethod = "GET"} $ \res -> do let (s, hs, _) = responseToStream res sendResponse $ responseBuilder s hs mempty | otherwise = app req sendResponse diff --git a/wai-extra/Network/Wai/Middleware/CleanPath.hs b/wai-extra/Network/Wai/Middleware/CleanPath.hs index 8a647269b..5e23ab7e9 100644 --- a/wai-extra/Network/Wai/Middleware/CleanPath.hs +++ b/wai-extra/Network/Wai/Middleware/CleanPath.hs @@ -1,7 +1,8 @@ {-# LANGUAGE CPP #-} -module Network.Wai.Middleware.CleanPath - ( cleanPath - ) where + +module Network.Wai.Middleware.CleanPath ( + cleanPath, +) where import qualified Data.ByteString.Char8 as B import qualified Data.ByteString.Lazy as L @@ -12,10 +13,11 @@ import Data.Text (Text) import Network.HTTP.Types (hLocation, status301) import Network.Wai (Application, pathInfo, rawQueryString, responseLBS) -cleanPath :: ([Text] -> Either B.ByteString [Text]) - -> B.ByteString - -> ([Text] -> Application) - -> Application +cleanPath + :: ([Text] -> Either B.ByteString [Text]) + -> B.ByteString + -> ([Text] -> Application) + -> Application cleanPath splitter prefix app env sendResponse = case splitter $ pathInfo env of Right pieces -> app pieces env sendResponse @@ -25,10 +27,10 @@ cleanPath splitter prefix app env sendResponse = status301 [(hLocation, mconcat [prefix, p, suffix])] L.empty - where - -- include the query string if present - suffix = - case B.uncons $ rawQueryString env of - Nothing -> B.empty - Just ('?', _) -> rawQueryString env - _ -> B.cons '?' $ rawQueryString env + where + -- include the query string if present + suffix = + case B.uncons $ rawQueryString env of + Nothing -> B.empty + Just ('?', _) -> rawQueryString env + _ -> B.cons '?' $ rawQueryString env diff --git a/wai-extra/Network/Wai/Middleware/CombineHeaders.hs b/wai-extra/Network/Wai/Middleware/CombineHeaders.hs index 59c1859be..0e6c5e408 100644 --- a/wai-extra/Network/Wai/Middleware/CombineHeaders.hs +++ b/wai-extra/Network/Wai/Middleware/CombineHeaders.hs @@ -1,34 +1,35 @@ {-# LANGUAGE LambdaCase #-} {-# LANGUAGE RecordWildCards #-} -{- | -Sometimes incoming requests don't stick to the -"no duplicate headers" invariant, for a number -of possible reasons (e.g. proxy servers blindly -adding headers), or your application (or other -middleware) blindly adds headers. -In those cases, you can use this 'Middleware' -to make sure that headers that /can/ be combined -/are/ combined. (e.g. applications might only -check the first \"Accept\" header and fail, while -there might be another one that would match) - -} -module Network.Wai.Middleware.CombineHeaders - ( combineHeaders - , CombineSettings - , defaultCombineSettings - , HeaderMap - , HandleType - , defaultHeaderMap +-- | +-- Sometimes incoming requests don't stick to the +-- "no duplicate headers" invariant, for a number +-- of possible reasons (e.g. proxy servers blindly +-- adding headers), or your application (or other +-- middleware) blindly adds headers. +-- +-- In those cases, you can use this 'Middleware' +-- to make sure that headers that /can/ be combined +-- /are/ combined. (e.g. applications might only +-- check the first \"Accept\" header and fail, while +-- there might be another one that would match) +module Network.Wai.Middleware.CombineHeaders ( + combineHeaders, + CombineSettings, + defaultCombineSettings, + HeaderMap, + HandleType, + defaultHeaderMap, + -- * Adjusting the settings - , setHeader - , removeHeader - , setHeaderMap - , regular - , keepOnly - , setRequestHeaders - , setResponseHeaders - ) where + setHeader, + removeHeader, + setHeaderMap, + regular, + keepOnly, + setRequestHeaders, + setResponseHeaders, +) where import qualified Data.ByteString as B import qualified Data.List as L (foldl', reverse) @@ -36,7 +37,7 @@ import qualified Data.Map.Strict as M import Data.Word8 (_comma, _space, _tab) import Network.HTTP.Types (Header, HeaderName, RequestHeaders) import qualified Network.HTTP.Types.Header as H -import Network.Wai (Middleware, requestHeaders, mapResponseHeaders) +import Network.Wai (Middleware, mapResponseHeaders, requestHeaders) import Network.Wai.Util (dropWhileEnd) -- | The mapping of 'HeaderName' to 'HandleType' @@ -56,14 +57,15 @@ type HeaderMap = M.Map HeaderName HandleType -- combined) -- -- @since 3.1.13.0 -data CombineSettings = CombineSettings { - combineHeaderMap :: HeaderMap, +data CombineSettings = CombineSettings + { combineHeaderMap :: HeaderMap -- ^ Which headers should be combined? And how? (cf. 'HandleType') - combineRequestHeaders :: Bool, + , combineRequestHeaders :: Bool -- ^ Should request headers be combined? - combineResponseHeaders :: Bool + , combineResponseHeaders :: Bool -- ^ Should response headers be combined? -} deriving (Eq, Show) + } + deriving (Eq, Show) -- | Settings that combine request headers, -- but don't touch response headers. @@ -111,11 +113,12 @@ data CombineSettings = CombineSettings { -- -- @since 3.1.13.0 defaultCombineSettings :: CombineSettings -defaultCombineSettings = CombineSettings { - combineHeaderMap = defaultHeaderMap, - combineRequestHeaders = True, - combineResponseHeaders = False -} +defaultCombineSettings = + CombineSettings + { combineHeaderMap = defaultHeaderMap + , combineRequestHeaders = True + , combineResponseHeaders = False + } -- | Override the 'HeaderMap' of the 'CombineSettings' -- (default: 'defaultHeaderMap') @@ -144,18 +147,18 @@ setResponseHeaders b set = set{combineResponseHeaders = b} -- @since 3.1.13.0 setHeader :: HeaderName -> HandleType -> CombineSettings -> CombineSettings setHeader name typ settings = - settings { - combineHeaderMap = M.insert name typ $ combineHeaderMap settings - } + settings + { combineHeaderMap = M.insert name typ $ combineHeaderMap settings + } -- | Convenience function to remove a header from the header map. -- -- @since 3.1.13.0 removeHeader :: HeaderName -> CombineSettings -> CombineSettings removeHeader name settings = - settings { - combineHeaderMap = M.delete name $ combineHeaderMap settings - } + settings + { combineHeaderMap = M.delete name $ combineHeaderMap settings + } -- | This middleware will reorganize the incoming and/or outgoing -- headers in such a way that it combines any duplicates of @@ -180,7 +183,7 @@ combineHeaders CombineSettings{..} app req resFunc = app newReq $ resFunc . adjustRes where newReq - | combineRequestHeaders = req { requestHeaders = mkNewHeaders oldHeaders } + | combineRequestHeaders = req{requestHeaders = mkNewHeaders oldHeaders} | otherwise = req oldHeaders = requestHeaders req adjustRes @@ -191,14 +194,16 @@ combineHeaders CombineSettings{..} app req resFunc = go acc hdr@(name, _) = M.alter (checkHeader hdr) name acc checkHeader :: Header -> Maybe HeaderHandling -> Maybe HeaderHandling - checkHeader (name, newVal) = Just . \case - Nothing -> (name `M.lookup` combineHeaderMap, [newVal]) - -- Yes, this reverses the order of headers, but these - -- will be reversed again in 'finishHeaders' - Just (mHandleType, hdrs) -> (mHandleType, newVal : hdrs) + checkHeader (name, newVal) = + Just . \case + Nothing -> (name `M.lookup` combineHeaderMap, [newVal]) + -- Yes, this reverses the order of headers, but these + -- will be reversed again in 'finishHeaders' + Just (mHandleType, hdrs) -> (mHandleType, newVal : hdrs) -- | Unpack 'HeaderHandling' back into 'Header's again -finishHeaders :: HeaderName -> HeaderHandling -> RequestHeaders -> RequestHeaders +finishHeaders + :: HeaderName -> HeaderHandling -> RequestHeaders -> RequestHeaders finishHeaders name (shouldCombine, xs) hdrs = case shouldCombine of Just typ -> (name, combinedHeader typ) : hdrs @@ -230,7 +235,7 @@ type HeaderHandling = (Maybe HandleType, [B.ByteString]) data HandleType = Regular | KeepOnly B.ByteString - deriving (Eq, Show) + deriving (Eq, Show) -- | Use the regular strategy when combining headers. -- (i.e. merge into one header and separate values with commas) @@ -256,44 +261,42 @@ keepOnly = KeepOnly -- -- @since 3.1.13.0 defaultHeaderMap :: HeaderMap -defaultHeaderMap = M.fromList - [ (H.hAccept, Regular) - , ("Accept-CH", Regular) - , (H.hAcceptCharset, Regular) - , (H.hAcceptEncoding, Regular) - , (H.hAcceptLanguage, Regular) - , ("Accept-Post", Regular) - , ("Access-Control-Allow-Headers" , Regular) -- wildcard? yes, but can just add to list - , ("Access-Control-Allow-Methods" , Regular) -- wildcard? yes, but can just add to list - , ("Access-Control-Expose-Headers" , Regular) -- wildcard? yes, but can just add to list - , ("Access-Control-Request-Headers", Regular) - , (H.hAllow, Regular) - , ("Alt-Svc", KeepOnly "clear") -- special "clear" value (if any is "clear", only keep that one) - , (H.hCacheControl, Regular) - , ("Clear-Site-Data", KeepOnly "*") -- wildcard (if any is "*", only keep that one) - - -- If "close" and anything else is used together, it's already F-ed, - -- so just combine them. - , (H.hConnection, Regular) - - , (H.hContentEncoding, Regular) - , (H.hContentLanguage, Regular) - , ("Digest", Regular) - - -- We could handle this, but it's experimental AND - -- will be replaced by "Permissions-Policy" - -- , "Feature-Policy" -- "semicolon ';' separated" +defaultHeaderMap = + M.fromList + [ (H.hAccept, Regular) + , ("Accept-CH", Regular) + , (H.hAcceptCharset, Regular) + , (H.hAcceptEncoding, Regular) + , (H.hAcceptLanguage, Regular) + , ("Accept-Post", Regular) + , ("Access-Control-Allow-Headers", Regular) -- wildcard? yes, but can just add to list + , ("Access-Control-Allow-Methods", Regular) -- wildcard? yes, but can just add to list + , ("Access-Control-Expose-Headers", Regular) -- wildcard? yes, but can just add to list + , ("Access-Control-Request-Headers", Regular) + , (H.hAllow, Regular) + , ("Alt-Svc", KeepOnly "clear") -- special "clear" value (if any is "clear", only keep that one) + , (H.hCacheControl, Regular) + , ("Clear-Site-Data", KeepOnly "*") -- wildcard (if any is "*", only keep that one) + , -- If "close" and anything else is used together, it's already F-ed, + -- so just combine them. + (H.hConnection, Regular) + , (H.hContentEncoding, Regular) + , (H.hContentLanguage, Regular) + , ("Digest", Regular) + , -- We could handle this, but it's experimental AND + -- will be replaced by "Permissions-Policy" + -- , "Feature-Policy" -- "semicolon ';' separated" - , (H.hIfMatch, Regular) - , (H.hIfNoneMatch, KeepOnly "*") -- wildcard? (if any is "*", only keep that one) - , ("Link", Regular) - , ("Permissions-Policy", Regular) - , (H.hTE, Regular) - , ("Timing-Allow-Origin", KeepOnly "*") -- wildcard? (if any is "*", only keep that one) - , (H.hTrailer, Regular) - , (H.hTransferEncoding, Regular) - , (H.hUpgrade, Regular) - , (H.hVia, Regular) - , (H.hVary, KeepOnly "*") -- wildcard? (if any is "*", only keep that one) - , ("Want-Digest", Regular) - ] + (H.hIfMatch, Regular) + , (H.hIfNoneMatch, KeepOnly "*") -- wildcard? (if any is "*", only keep that one) + , ("Link", Regular) + , ("Permissions-Policy", Regular) + , (H.hTE, Regular) + , ("Timing-Allow-Origin", KeepOnly "*") -- wildcard? (if any is "*", only keep that one) + , (H.hTrailer, Regular) + , (H.hTransferEncoding, Regular) + , (H.hUpgrade, Regular) + , (H.hVia, Regular) + , (H.hVary, KeepOnly "*") -- wildcard? (if any is "*", only keep that one) + , ("Want-Digest", Regular) + ] diff --git a/wai-extra/Network/Wai/Middleware/ForceDomain.hs b/wai-extra/Network/Wai/Middleware/ForceDomain.hs index 482644973..f39a011c9 100644 --- a/wai-extra/Network/Wai/Middleware/ForceDomain.hs +++ b/wai-extra/Network/Wai/Middleware/ForceDomain.hs @@ -1,4 +1,5 @@ {-# LANGUAGE CPP #-} + -- | -- -- @since 3.0.14 @@ -28,16 +29,15 @@ forceDomain checkDomain app req sendResponse = app req sendResponse Just domain -> sendResponse $ redirectResponse domain + where + -- From: Network.Wai.Middleware.ForceSSL + redirectResponse domain = + responseBuilder status [(hLocation, location domain)] mempty - where - -- From: Network.Wai.Middleware.ForceSSL - redirectResponse domain = - responseBuilder status [(hLocation, location domain)] mempty - - location h = - let p = if appearsSecure req then "https://" else "http://" in - p <> h <> rawPathInfo req <> rawQueryString req + location h = + let p = if appearsSecure req then "https://" else "http://" + in p <> h <> rawPathInfo req <> rawQueryString req - status - | requestMethod req == methodGet = status301 - | otherwise = status307 + status + | requestMethod req == methodGet = status301 + | otherwise = status307 diff --git a/wai-extra/Network/Wai/Middleware/ForceSSL.hs b/wai-extra/Network/Wai/Middleware/ForceSSL.hs index 41aff2368..b5a7cc0fe 100644 --- a/wai-extra/Network/Wai/Middleware/ForceSSL.hs +++ b/wai-extra/Network/Wai/Middleware/ForceSSL.hs @@ -1,11 +1,11 @@ {-# LANGUAGE CPP #-} + -- | Redirect non-SSL requests to https -- -- Since 3.0.7 -module Network.Wai.Middleware.ForceSSL - ( forceSSL - ) where - +module Network.Wai.Middleware.ForceSSL ( + forceSSL, +) where #if __GLASGOW_HASKELL__ < 710 import Control.Applicative ((<$>)) @@ -26,12 +26,12 @@ forceSSL :: Middleware forceSSL app req sendResponse = case (appearsSecure req, redirectResponse req) of (False, Just resp) -> sendResponse resp - _ -> app req sendResponse + _ -> app req sendResponse redirectResponse :: Request -> Maybe Response redirectResponse req = do - host <- requestHeaderHost req - return $ responseBuilder status [(hLocation, location host)] mempty + host <- requestHeaderHost req + return $ responseBuilder status [(hLocation, location host)] mempty where location h = "https://" <> h <> rawPathInfo req <> rawQueryString req status diff --git a/wai-extra/Network/Wai/Middleware/Gzip.hs b/wai-extra/Network/Wai/Middleware/Gzip.hs index 04cad772e..c722815c9 100644 --- a/wai-extra/Network/Wai/Middleware/Gzip.hs +++ b/wai-extra/Network/Wai/Middleware/Gzip.hs @@ -1,4 +1,7 @@ --------------------------------------------------------- + +--------------------------------------------------------- + -- | -- Module : Network.Wai.Middleware.Gzip -- Copyright : Michael Snoyman @@ -9,34 +12,38 @@ -- Portability : portable -- -- Automatic gzip compression of responses. --- ---------------------------------------------------------- -module Network.Wai.Middleware.Gzip - ( -- * How to use this module - -- $howto +module Network.Wai.Middleware.Gzip ( + -- * How to use this module + -- $howto - -- ** The Middleware - -- $gzip - gzip + -- ** The Middleware + -- $gzip + gzip, - -- ** The Settings - -- $settings - , GzipSettings - , defaultGzipSettings - , gzipFiles - , gzipCheckMime - , gzipSizeThreshold + -- ** The Settings + -- $settings + GzipSettings, + defaultGzipSettings, + gzipFiles, + gzipCheckMime, + gzipSizeThreshold, - -- ** How to handle file responses - , GzipFiles (..) + -- ** How to handle file responses + GzipFiles (..), - -- ** Miscellaneous - -- $miscellaneous - , defaultCheckMime - , def - ) where + -- ** Miscellaneous + -- $miscellaneous + defaultCheckMime, + def, +) where -import Control.Exception (IOException, SomeException, fromException, throwIO, try) +import Control.Exception ( + IOException, + SomeException, + fromException, + throwIO, + try, + ) import Control.Monad (unless) import qualified Data.ByteString as S import Data.ByteString.Builder (byteString) @@ -109,8 +116,6 @@ import Network.Wai.Util (splitCommas, trimWS) -- evaluate to 'True' when applied to 'gzipCheckMime' -- (though 'GzipPreCompressed' will use the \".gz\" file regardless -- of MIME type on any 'ResponseFile' response) --- - -- $settings -- @@ -133,34 +138,34 @@ import Network.Wai.Util (splitCommas, trimWS) -- @ data GzipSettings = GzipSettings - { -- | Gzip behavior for files - -- - -- Only applies to 'ResponseFile' ('responseFile') responses. - -- So any streamed data will be compressed based solely on the - -- response headers having the right \"Content-Type\" and - -- \"Content-Length\". (which are checked with 'gzipCheckMime' - -- and 'gzipSizeThreshold', respectively) - gzipFiles :: GzipFiles - -- | Decide which files to compress based on MIME type - -- - -- The 'S.ByteString' is the value of the \"Content-Type\" response - -- header and will default to 'False' if the header is missing. - -- - -- E.g. if you'd only want to compress @json@ data, you might - -- define your own function as follows: - -- - -- > myCheckMime mime = mime == "application/json" + { gzipFiles :: GzipFiles + -- ^ Gzip behavior for files + -- + -- Only applies to 'ResponseFile' ('responseFile') responses. + -- So any streamed data will be compressed based solely on the + -- response headers having the right \"Content-Type\" and + -- \"Content-Length\". (which are checked with 'gzipCheckMime' + -- and 'gzipSizeThreshold', respectively) , gzipCheckMime :: S.ByteString -> Bool - -- | Skip compression when the size of the response body is - -- below this amount of bytes (default: 860.) - -- - -- /Setting this option to less than 150 will actually increase/ - -- /the size of outgoing data if its original size is less than 150 bytes/. - -- - -- This will only skip compression if the response includes a - -- \"Content-Length\" header /AND/ the length is less than this - -- threshold. + -- ^ Decide which files to compress based on MIME type + -- + -- The 'S.ByteString' is the value of the \"Content-Type\" response + -- header and will default to 'False' if the header is missing. + -- + -- E.g. if you'd only want to compress @json@ data, you might + -- define your own function as follows: + -- + -- > myCheckMime mime = mime == "application/json" , gzipSizeThreshold :: Integer + -- ^ Skip compression when the size of the response body is + -- below this amount of bytes (default: 860.) + -- + -- /Setting this option to less than 150 will actually increase/ + -- /the size of outgoing data if its original size is less than 150 bytes/. + -- + -- This will only skip compression if the response includes a + -- \"Content-Length\" header /AND/ the length is less than this + -- threshold. } -- | Gzip behavior for files. @@ -218,12 +223,13 @@ defaultCheckMime bs = S8.isPrefixOf "text/" bs || bs' `Set.member` toCompress where bs' = fst $ S.break (== _semicolon) bs - toCompress = Set.fromList - [ "application/json" - , "application/javascript" - , "application/ecmascript" - , "image/x-icon" - ] + toCompress = + Set.fromList + [ "application/json" + , "application/javascript" + , "application/ecmascript" + , "image/x-icon" + ] -- | Use gzip to compress the body of the response. gzip :: GzipSettings -> Middleware @@ -233,14 +239,14 @@ gzip set app req sendResponse' let runAction x = case x of (ResponseRaw{}, _) -> sendResponse res -- Always skip if 'GzipIgnore' - (ResponseFile {}, GzipIgnore) -> sendResponse res + (ResponseFile{}, GzipIgnore) -> sendResponse res -- If there's a compressed version of the file, we send that. (ResponseFile s hs file Nothing, GzipPreCompressed nextAction) -> let compressedVersion = file ++ ".gz" - in doesFileExist compressedVersion >>= \y -> - if y - then sendResponse $ ResponseFile s (fixHeaders hs) compressedVersion Nothing - else runAction (ResponseFile s hs file Nothing, nextAction) + in doesFileExist compressedVersion >>= \y -> + if y + then sendResponse $ ResponseFile s (fixHeaders hs) compressedVersion Nothing + else runAction (ResponseFile s hs file Nothing, nextAction) -- Skip if it's not a MIME type we want to compress _ | not $ isCorrectMime (responseHeaders res) -> sendResponse res -- Use static caching logic @@ -252,12 +258,12 @@ gzip set app req sendResponse' in compressFile s hs file mETag cache sendResponse -- Use streaming logic _ -> compressE res sendResponse - in runAction (res, gzipFiles set) + in runAction (res, gzipFiles set) where isCorrectMime = maybe False (gzipCheckMime set) . lookup hContentType sendResponse = sendResponse' . mapResponseHeaders mAddVary - acceptEncoding = "Accept-Encoding" + acceptEncoding = "Accept-Encoding" acceptEncodingLC = "accept-encoding" -- Instead of just adding a header willy-nilly, we check if -- "Vary" is already present, and add to it if not already included. @@ -268,8 +274,9 @@ gzip set app req sendResponse' lowercase = S.map W8.toLower -- Field names are case-insensitive, so we lowercase to match hasAccEnc = acceptEncodingLC `elem` fmap lowercase vals - newH | hasAccEnc = h - | otherwise = (hVary, acceptEncoding <> ", " <> val) + newH + | hasAccEnc = h + | otherwise = (hVary, acceptEncoding <> ", " <> val) in newH : hs | otherwise = h : mAddVary hs @@ -288,7 +295,8 @@ gzip set app req sendResponse' maybe False ("MSIE 6" `S.isInfixOf`) $ hUserAgent `lookup` reqHdrs -- Can we skip just by looking at the current 'Response'? - checkCompress :: (Response -> IO ResponseReceived) -> Response -> IO ResponseReceived + checkCompress + :: (Response -> IO ResponseReceived) -> Response -> IO ResponseReceived checkCompress continue res = if isEncodedAlready || isPartial || tooSmall then sendResponse res @@ -311,7 +319,14 @@ gzip set app req sendResponse' minimumLength :: Integer minimumLength = 860 -compressFile :: Status -> [Header] -> FilePath -> Maybe S.ByteString -> FilePath -> (Response -> IO a) -> IO a +compressFile + :: Status + -> [Header] + -> FilePath + -> Maybe S.ByteString + -> FilePath + -> (Response -> IO a) + -> IO a compressFile s hs file mETag cache sendResponse = do e <- doesFileExist tmpfile if e @@ -319,25 +334,25 @@ compressFile s hs file mETag cache sendResponse = do else do createDirectoryIfMissing True cache x <- try $ - IO.withBinaryFile file IO.ReadMode $ \inH -> - IO.withBinaryFile tmpfile IO.WriteMode $ \outH -> do - deflate <- Z.initDeflate 7 $ Z.WindowBits 31 - -- FIXME this code should write to a temporary file, then - -- rename to the final file - let goPopper popper = fix $ \loop -> do - res <- popper - case res of - Z.PRDone -> return () - Z.PRNext bs -> do - S.hPut outH bs - loop - Z.PRError ex -> throwIO ex - fix $ \loop -> do - bs <- S.hGetSome inH defaultChunkSize - unless (S.null bs) $ do - Z.feedDeflate deflate bs >>= goPopper - loop - goPopper $ Z.finishDeflate deflate + IO.withBinaryFile file IO.ReadMode $ \inH -> + IO.withBinaryFile tmpfile IO.WriteMode $ \outH -> do + deflate <- Z.initDeflate 7 $ Z.WindowBits 31 + -- FIXME this code should write to a temporary file, then + -- rename to the final file + let goPopper popper = fix $ \loop -> do + res <- popper + case res of + Z.PRDone -> return () + Z.PRNext bs -> do + S.hPut outH bs + loop + Z.PRError ex -> throwIO ex + fix $ \loop -> do + bs <- S.hGetSome inH defaultChunkSize + unless (S.null bs) $ do + Z.feedDeflate deflate bs >>= goPopper + loop + goPopper $ Z.finishDeflate deflate either onErr (const onSucc) (x :: Either SomeException ()) where onSucc = sendResponse $ responseFile s (fixHeaders hs) tmpfile Nothing @@ -365,9 +380,10 @@ compressFile s hs file mETag cache sendResponse = do safe '_' = '_' safe _ = '_' -compressE :: Response - -> (Response -> IO ResponseReceived) - -> IO ResponseReceived +compressE + :: Response + -> (Response -> IO ResponseReceived) + -> IO ResponseReceived compressE res sendResponse = wb $ \body -> sendResponse $ responseStream s (fixHeaders hs) $ \sendChunk flush -> do diff --git a/wai-extra/Network/Wai/Middleware/HealthCheckEndpoint.hs b/wai-extra/Network/Wai/Middleware/HealthCheckEndpoint.hs index 8318d2f2f..d4e4df92f 100644 --- a/wai-extra/Network/Wai/Middleware/HealthCheckEndpoint.hs +++ b/wai-extra/Network/Wai/Middleware/HealthCheckEndpoint.hs @@ -1,4 +1,7 @@ --------------------------------------------------------- + +--------------------------------------------------------- + -- | -- Module : Network.Wai.Middleware.HealthCheckEndpoint -- Copyright : Michael Snoyman @@ -9,12 +12,10 @@ -- Portability : portable -- -- Add empty endpoint (for Health check tests) --- ---------------------------------------------------------- -module Network.Wai.Middleware.HealthCheckEndpoint - ( healthCheck, +module Network.Wai.Middleware.HealthCheckEndpoint ( + healthCheck, voidEndpoint, - ) +) where import Data.ByteString (ByteString) @@ -32,6 +33,6 @@ healthCheck = voidEndpoint "/_healthz" -- @since 3.1.9 voidEndpoint :: ByteString -> Middleware voidEndpoint endpointPath router request respond = - if rawPathInfo request == endpointPath - then respond $ responseLBS status200 mempty "-" - else router request respond + if rawPathInfo request == endpointPath + then respond $ responseLBS status200 mempty "-" + else router request respond diff --git a/wai-extra/Network/Wai/Middleware/HttpAuth.hs b/wai-extra/Network/Wai/Middleware/HttpAuth.hs index 2fb9c9b05..2e56011f8 100644 --- a/wai-extra/Network/Wai/Middleware/HttpAuth.hs +++ b/wai-extra/Network/Wai/Middleware/HttpAuth.hs @@ -1,22 +1,24 @@ {-# LANGUAGE CPP #-} {-# LANGUAGE RecordWildCards #-} {-# LANGUAGE TupleSections #-} + -- | Implements HTTP Basic Authentication. -- -- This module may add digest authentication in the future. -module Network.Wai.Middleware.HttpAuth - ( -- * Middleware - basicAuth - , basicAuth' - , CheckCreds - , AuthSettings - , authRealm - , authOnNoAuth - , authIsProtected - -- * Helping functions - , extractBasicAuth - , extractBearerAuth - ) where +module Network.Wai.Middleware.HttpAuth ( + -- * Middleware + basicAuth, + basicAuth', + CheckCreds, + AuthSettings, + authRealm, + authOnNoAuth, + authIsProtected, + + -- * Helping functions + extractBasicAuth, + extractBearerAuth, +) where #if __GLASGOW_HASKELL__ < 710 import Control.Applicative @@ -27,31 +29,38 @@ import Data.ByteString.Base64 (decodeLenient) import Data.String (IsString (..)) import Data.Word8 (isSpace, toLower, _colon) import Network.HTTP.Types (hAuthorization, hContentType, status401) -import Network.Wai (Application, Middleware, Request (requestHeaders), responseLBS) - +import Network.Wai ( + Application, + Middleware, + Request (requestHeaders), + responseLBS, + ) -- | Check if a given username and password is valid. -type CheckCreds = ByteString - -> ByteString - -> IO Bool +type CheckCreds = + ByteString + -> ByteString + -> IO Bool -- | Perform basic authentication. -- -- > basicAuth (\u p -> return $ u == "michael" && p == "mypass") "My Realm" -- -- @since 1.3.4 -basicAuth :: CheckCreds - -> AuthSettings - -> Middleware +basicAuth + :: CheckCreds + -> AuthSettings + -> Middleware basicAuth = basicAuth' . const -- | Like 'basicAuth', but also passes a request to the authentication function. -- -- @since 3.0.19 -basicAuth' :: (Request -> CheckCreds) - -> AuthSettings - -> Middleware -basicAuth' checkCreds AuthSettings {..} app req sendResponse = do +basicAuth' + :: (Request -> CheckCreds) + -> AuthSettings + -> Middleware +basicAuth' checkCreds AuthSettings{..} app req sendResponse = do isProtected <- authIsProtected req allowed <- if isProtected then check else return True if allowed @@ -60,7 +69,7 @@ basicAuth' checkCreds AuthSettings {..} app req sendResponse = do where check = case lookup hAuthorization (requestHeaders req) - >>= extractBasicAuth of + >>= extractBasicAuth of Nothing -> return False Just (username, password) -> checkCreds req username password @@ -91,20 +100,26 @@ data AuthSettings = AuthSettings } instance IsString AuthSettings where - fromString s = AuthSettings - { authRealm = fromString s - , authOnNoAuth = \realm _req f -> f $ responseLBS - status401 - [ (hContentType, "text/plain") - , ("WWW-Authenticate", S.concat - [ "Basic realm=\"" - , realm - , "\"" - ]) - ] - "Basic authentication is required" - , authIsProtected = const $ return True - } + fromString s = + AuthSettings + { authRealm = fromString s + , authOnNoAuth = \realm _req f -> + f $ + responseLBS + status401 + [ (hContentType, "text/plain") + , + ( "WWW-Authenticate" + , S.concat + [ "Basic realm=\"" + , realm + , "\"" + ] + ) + ] + "Basic authentication is required" + , authIsProtected = const $ return True + } -- | Extract basic authentication data from usually __Authorization__ -- header value. Returns username and password @@ -113,14 +128,14 @@ instance IsString AuthSettings where extractBasicAuth :: ByteString -> Maybe (ByteString, ByteString) extractBasicAuth bs = let (x, y) = S.break isSpace bs - in if S.map toLower x == "basic" - then extract $ S.dropWhile isSpace y - else Nothing + in if S.map toLower x == "basic" + then extract $ S.dropWhile isSpace y + else Nothing where extract encoded = let raw = decodeLenient encoded (username, password') = S.break (== _colon) raw - in (username,) . snd <$> S.uncons password' + in (username,) . snd <$> S.uncons password' -- | Extract bearer authentication data from __Authorization__ header -- value. Returns bearer token @@ -129,6 +144,6 @@ extractBasicAuth bs = extractBearerAuth :: ByteString -> Maybe ByteString extractBearerAuth bs = let (x, y) = S.break isSpace bs - in if S.map toLower x == "bearer" - then Just $ S.dropWhile isSpace y - else Nothing + in if S.map toLower x == "bearer" + then Just $ S.dropWhile isSpace y + else Nothing diff --git a/wai-extra/Network/Wai/Middleware/Jsonp.hs b/wai-extra/Network/Wai/Middleware/Jsonp.hs index 661d62411..efe873469 100644 --- a/wai-extra/Network/Wai/Middleware/Jsonp.hs +++ b/wai-extra/Network/Wai/Middleware/Jsonp.hs @@ -1,5 +1,9 @@ {-# LANGUAGE CPP #-} + +--------------------------------------------------------- + --------------------------------------------------------- + -- | -- Module : Network.Wai.Middleware.Jsonp -- Copyright : Michael Snoyman @@ -10,8 +14,6 @@ -- Portability : portable -- -- Automatic wrapping of JSON responses to convert into JSONP. --- ---------------------------------------------------------- module Network.Wai.Middleware.Jsonp (jsonp) where import Control.Monad (join) @@ -46,10 +48,13 @@ jsonp app env sendResponse = do let env' = case callback of Nothing -> env - Just _ -> env - { requestHeaders = changeVal hAccept - "application/json" - $ requestHeaders env + Just _ -> + env + { requestHeaders = + changeVal + hAccept + "application/json" + $ requestHeaders env } app env' $ \res -> case callback of @@ -59,11 +64,12 @@ jsonp app env sendResponse = do go c r@(ResponseBuilder s hs b) = sendResponse $ case checkJSON hs of Nothing -> r - Just hs' -> responseBuilder s hs' $ - byteStringCopy c - `mappend` char7 '(' - `mappend` b - `mappend` char7 ')' + Just hs' -> + responseBuilder s hs' $ + byteStringCopy c + `mappend` char7 '(' + `mappend` b + `mappend` char7 ')' go c r = case checkJSON hs of Just hs' -> addCallback c s hs' wb @@ -85,10 +91,12 @@ jsonp app env sendResponse = do _ <- body sendChunk flush sendChunk $ char7 ')' -changeVal :: Eq a - => a - -> ByteString - -> [(a, ByteString)] - -> [(a, ByteString)] -changeVal key val old = (key, val) - : filter (\(k, _) -> k /= key) old +changeVal + :: Eq a + => a + -> ByteString + -> [(a, ByteString)] + -> [(a, ByteString)] +changeVal key val old = + (key, val) + : filter (\(k, _) -> k /= key) old diff --git a/wai-extra/Network/Wai/Middleware/MethodOverride.hs b/wai-extra/Network/Wai/Middleware/MethodOverride.hs index 16b509f7b..254c930a9 100644 --- a/wai-extra/Network/Wai/Middleware/MethodOverride.hs +++ b/wai-extra/Network/Wai/Middleware/MethodOverride.hs @@ -1,6 +1,6 @@ -module Network.Wai.Middleware.MethodOverride - ( methodOverride - ) where +module Network.Wai.Middleware.MethodOverride ( + methodOverride, +) where import Control.Monad (join) import Network.Wai (Middleware, queryString, requestMethod) @@ -16,5 +16,5 @@ methodOverride app req = where req' = case (requestMethod req, join $ lookup "_method" $ queryString req) of - ("POST", Just m) -> req { requestMethod = m } + ("POST", Just m) -> req{requestMethod = m} _ -> req diff --git a/wai-extra/Network/Wai/Middleware/MethodOverridePost.hs b/wai-extra/Network/Wai/Middleware/MethodOverridePost.hs index 5a4804386..2bef4dd27 100644 --- a/wai-extra/Network/Wai/Middleware/MethodOverridePost.hs +++ b/wai-extra/Network/Wai/Middleware/MethodOverridePost.hs @@ -1,12 +1,15 @@ {-# LANGUAGE CPP #-} + +----------------------------------------------------------------- + ----------------------------------------------------------------- + -- | Module : Network.Wai.Middleware.MethodOverridePost -- -- Changes the request-method via first post-parameter _method. ------------------------------------------------------------------ -module Network.Wai.Middleware.MethodOverridePost - ( methodOverridePost - ) where +module Network.Wai.Middleware.MethodOverridePost ( + methodOverridePost, +) where import Data.ByteString.Lazy (toChunks) import Data.IORef (atomicModifyIORef, newIORef) @@ -26,20 +29,20 @@ import Network.Wai -- parameter. -- -- * This middleware only applies when the initial request method is POST. --- methodOverridePost :: Middleware methodOverridePost app req send = case (requestMethod req, lookup hContentType (requestHeaders req)) of - ("POST", Just "application/x-www-form-urlencoded") -> setPost req >>= flip app send - _ -> app req send + ("POST", Just "application/x-www-form-urlencoded") -> setPost req >>= flip app send + _ -> app req send setPost :: Request -> IO Request setPost req = do - body <- (mconcat . toChunks) `fmap` lazyRequestBody req - ref <- newIORef body - let rb = atomicModifyIORef ref $ \bs -> (mempty, bs) - req' = setRequestBodyChunks rb req - case parseQuery body of - (("_method", Just newmethod):_) -> return req' {requestMethod = newmethod} - _ -> return req' + body <- (mconcat . toChunks) `fmap` lazyRequestBody req + ref <- newIORef body + let rb = atomicModifyIORef ref $ \bs -> (mempty, bs) + req' = setRequestBodyChunks rb req + case parseQuery body of + (("_method", Just newmethod) : _) -> return req'{requestMethod = newmethod} + _ -> return req' + {- HLint ignore setPost "Use tuple-section" -} diff --git a/wai-extra/Network/Wai/Middleware/RealIp.hs b/wai-extra/Network/Wai/Middleware/RealIp.hs index 0aabefae4..8444634ef 100644 --- a/wai-extra/Network/Wai/Middleware/RealIp.hs +++ b/wai-extra/Network/Wai/Middleware/RealIp.hs @@ -1,11 +1,11 @@ -- | Infer the remote IP address using headers -module Network.Wai.Middleware.RealIp - ( realIp - , realIpHeader - , realIpTrusted - , defaultTrusted - , ipInRange - ) where +module Network.Wai.Middleware.RealIp ( + realIp, + realIpHeader, + realIpTrusted, + defaultTrusted, + ipInRange, +) where import qualified Data.ByteString.Char8 as B8 (split, unpack) import qualified Data.IP as IP @@ -51,23 +51,25 @@ realIpTrusted :: HeaderName -> (IP.IP -> Bool) -> Middleware realIpTrusted header isTrusted app req = app req' where req' = fromMaybe req $ do - (ip, port) <- IP.fromSockAddr (remoteHost req) - ip' <- if isTrusted ip - then findRealIp (requestHeaders req) header isTrusted - else Nothing - Just $ req { remoteHost = IP.toSockAddr (ip', port) } + (ip, port) <- IP.fromSockAddr (remoteHost req) + ip' <- + if isTrusted ip + then findRealIp (requestHeaders req) header isTrusted + else Nothing + Just $ req{remoteHost = IP.toSockAddr (ip', port)} -- | Standard private IP ranges. -- -- @since 3.1.5 defaultTrusted :: [IP.IPRange] -defaultTrusted = [ "127.0.0.0/8" - , "10.0.0.0/8" - , "172.16.0.0/12" - , "192.168.0.0/16" - , "::1/128" - , "fc00::/7" - ] +defaultTrusted = + [ "127.0.0.0/8" + , "10.0.0.0/8" + , "172.16.0.0/12" + , "192.168.0.0/16" + , "::1/128" + , "fc00::/7" + ] -- | Check if the given IP address is in the given range. -- @@ -81,14 +83,13 @@ ipInRange (IP.IPv6 ip) (IP.IPv6Range r) = ip `IP.isMatchedTo` r ipInRange (IP.IPv4 ip) (IP.IPv6Range r) = IP.ipv4ToIPv6 ip `IP.isMatchedTo` r ipInRange _ _ = False - findRealIp :: RequestHeaders -> HeaderName -> (IP.IP -> Bool) -> Maybe IP.IP findRealIp reqHeaders header isTrusted = case (nonTrusted, ips) of - ([], xs) -> listToMaybe xs - (xs, _) -> listToMaybe $ reverse xs + ([], xs) -> listToMaybe xs + (xs, _) -> listToMaybe $ reverse xs where -- account for repeated headers - headerVals = [ v | (k, v) <- reqHeaders, k == header ] + headerVals = [v | (k, v) <- reqHeaders, k == header] ips = concatMap (mapMaybe (readMaybe . B8.unpack) . B8.split ',') headerVals nonTrusted = filter (not . isTrusted) ips diff --git a/wai-extra/Network/Wai/Middleware/RequestLogger.hs b/wai-extra/Network/Wai/Middleware/RequestLogger.hs index d24112a3b..8974b6fee 100644 --- a/wai-extra/Network/Wai/Middleware/RequestLogger.hs +++ b/wai-extra/Network/Wai/Middleware/RequestLogger.hs @@ -4,31 +4,32 @@ -- NOTE: Due to https://github.com/yesodweb/wai/issues/192, this module should -- not use CPP. -- EDIT: Fixed this by adding two "zero-width spaces" in between the "*/*" -module Network.Wai.Middleware.RequestLogger - ( -- * Basic stdout logging - logStdout - , logStdoutDev - -- * Create more versions - , mkRequestLogger - , RequestLoggerSettings - , defaultRequestLoggerSettings - , outputFormat - , autoFlush - , destination - , OutputFormat (..) - , ApacheSettings - , defaultApacheSettings - , setApacheIPAddrSource - , setApacheRequestFilter - , setApacheUserGetter - , DetailedSettings (..) - , OutputFormatter - , OutputFormatterWithDetails - , OutputFormatterWithDetailsAndHeaders - , Destination (..) - , Callback - , IPAddrSource (..) - ) where +module Network.Wai.Middleware.RequestLogger ( + -- * Basic stdout logging + logStdout, + logStdoutDev, + + -- * Create more versions + mkRequestLogger, + RequestLoggerSettings, + defaultRequestLoggerSettings, + outputFormat, + autoFlush, + destination, + OutputFormat (..), + ApacheSettings, + defaultApacheSettings, + setApacheIPAddrSource, + setApacheRequestFilter, + setApacheUserGetter, + DetailedSettings (..), + OutputFormatter, + OutputFormatterWithDetails, + OutputFormatterWithDetailsAndHeaders, + Destination (..), + Callback, + IPAddrSource (..), +) where import Control.Monad (when) import Control.Monad.IO.Class (liftIO) @@ -46,12 +47,17 @@ import Data.Monoid ((<>)) import Data.Text.Encoding (decodeUtf8') import Data.Time (NominalDiffTime, UTCTime, diffUTCTime, getCurrentTime) import Network.HTTP.Types as H -import Network.Wai - ( Request(..), requestBodyLength, RequestBodyLength(..) - , Middleware - , Response, responseStatus, responseHeaders - , getRequestBodyChunk, setRequestBodyChunks - ) +import Network.Wai ( + Middleware, + Request (..), + RequestBodyLength (..), + Response, + getRequestBodyChunk, + requestBodyLength, + responseHeaders, + responseStatus, + setRequestBodyChunks, + ) import Network.Wai.Internal (Response (..)) import Network.Wai.Logger import System.Console.ANSI @@ -61,24 +67,27 @@ import System.Log.FastLogger import Network.Wai.Header (contentLength) import Network.Wai.Middleware.RequestLogger.Internal -import Network.Wai.Parse - ( Param - , File - , fileName - , getRequestBodyType - , lbsBackEnd - , sinkRequestBody - ) +import Network.Wai.Parse ( + File, + Param, + fileName, + getRequestBodyType, + lbsBackEnd, + sinkRequestBody, + ) -- | The logging format. data OutputFormat - = Apache IPAddrSource - | ApacheWithSettings ApacheSettings -- ^ @since 3.1.8 - | Detailed Bool -- ^ use colors? - | DetailedWithSettings DetailedSettings -- ^ @since 3.1.3 - | CustomOutputFormat OutputFormatter - | CustomOutputFormatWithDetails OutputFormatterWithDetails - | CustomOutputFormatWithDetailsAndHeaders OutputFormatterWithDetailsAndHeaders + = Apache IPAddrSource + | -- | @since 3.1.8 + ApacheWithSettings ApacheSettings + | -- | use colors? + Detailed Bool + | -- | @since 3.1.3 + DetailedWithSettings DetailedSettings + | CustomOutputFormat OutputFormatter + | CustomOutputFormatWithDetails OutputFormatterWithDetails + | CustomOutputFormatWithDetailsAndHeaders OutputFormatterWithDetailsAndHeaders -- | Settings for the `ApacheWithSettings` `OutputFormat`. This is purposely kept as an abstract data -- type so that new settings can be added without breaking backwards @@ -95,11 +104,12 @@ data ApacheSettings = ApacheSettings } defaultApacheSettings :: ApacheSettings -defaultApacheSettings = ApacheSettings - { apacheIPAddrSource = FromSocket - , apacheRequestFilter = \_ _ -> True - , apacheUserGetter = const Nothing - } +defaultApacheSettings = + ApacheSettings + { apacheIPAddrSource = FromSocket + , apacheRequestFilter = \_ _ -> True + , apacheUserGetter = const Nothing + } -- | Where to take IP addresses for clients from. See 'IPAddrSource' for more information. -- @@ -107,7 +117,7 @@ defaultApacheSettings = ApacheSettings -- -- @since 3.1.8 setApacheIPAddrSource :: IPAddrSource -> ApacheSettings -> ApacheSettings -setApacheIPAddrSource x y = y { apacheIPAddrSource = x } +setApacheIPAddrSource x y = y{apacheIPAddrSource = x} -- | Function that allows you to filter which requests are logged, based on -- the request and response @@ -115,8 +125,9 @@ setApacheIPAddrSource x y = y { apacheIPAddrSource = x } -- Default: log all requests -- -- @since 3.1.8 -setApacheRequestFilter :: (Request -> Response -> Bool) -> ApacheSettings -> ApacheSettings -setApacheRequestFilter x y = y { apacheRequestFilter = x } +setApacheRequestFilter + :: (Request -> Response -> Bool) -> ApacheSettings -> ApacheSettings +setApacheRequestFilter x y = y{apacheRequestFilter = x} -- | Function that allows you to get the current user from the request, which -- will then be added in the log. @@ -124,8 +135,9 @@ setApacheRequestFilter x y = y { apacheRequestFilter = x } -- Default: return no user -- -- @since 3.1.8 -setApacheUserGetter :: (Request -> Maybe BS.ByteString) -> ApacheSettings -> ApacheSettings -setApacheUserGetter x y = y { apacheUserGetter = x } +setApacheUserGetter + :: (Request -> Maybe BS.ByteString) -> ApacheSettings -> ApacheSettings +setApacheUserGetter x y = y{apacheUserGetter = x} -- | Settings for the `Detailed` `OutputFormat`. -- @@ -143,28 +155,30 @@ data DetailedSettings = DetailedSettings { useColors :: Bool , mModifyParams :: Maybe (Param -> Maybe Param) , mFilterRequests :: Maybe (Request -> Response -> Bool) - , mPrelogRequests :: Bool -- ^ @since 3.1.7 + , mPrelogRequests :: Bool + -- ^ @since 3.1.7 } instance Default DetailedSettings where - def = DetailedSettings - { useColors = True - , mModifyParams = Nothing - , mFilterRequests = Nothing - , mPrelogRequests = False - } + def = + DetailedSettings + { useColors = True + , mModifyParams = Nothing + , mFilterRequests = Nothing + , mPrelogRequests = False + } type OutputFormatter = ZonedDate -> Request -> Status -> Maybe Integer -> LogStr -type OutputFormatterWithDetails - = ZonedDate - -> Request - -> Status - -> Maybe Integer - -> NominalDiffTime - -> [S8.ByteString] - -> B.Builder - -> LogStr +type OutputFormatterWithDetails = + ZonedDate + -> Request + -> Status + -> Maybe Integer + -> NominalDiffTime + -> [S8.ByteString] + -> B.Builder + -> LogStr -- | Same as @OutputFormatterWithDetails@ but with response headers included -- @@ -173,20 +187,29 @@ type OutputFormatterWithDetails -- header in your application and retrieve in the log formatter. -- -- @since 3.0.27 -type OutputFormatterWithDetailsAndHeaders - = ZonedDate -- ^ When the log message was generated - -> Request -- ^ The WAI request - -> Status -- ^ HTTP status code - -> Maybe Integer -- ^ Response size - -> NominalDiffTime -- ^ Duration of the request - -> [S8.ByteString] -- ^ The request body - -> B.Builder -- ^ Raw response - -> [Header] -- ^ The response headers - -> LogStr - -data Destination = Handle Handle - | Logger LoggerSet - | Callback Callback +type OutputFormatterWithDetailsAndHeaders = + ZonedDate + -- ^ When the log message was generated + -> Request + -- ^ The WAI request + -> Status + -- ^ HTTP status code + -> Maybe Integer + -- ^ Response size + -> NominalDiffTime + -- ^ Duration of the request + -> [S8.ByteString] + -- ^ The request body + -> B.Builder + -- ^ Raw response + -> [Header] + -- ^ The response headers + -> LogStr + +data Destination + = Handle Handle + | Logger LoggerSet + | Callback Callback type Callback = LogStr -> IO () @@ -196,23 +219,23 @@ type Callback = LogStr -> IO () -- for the record type @RequestLoggerSettings@, so they can be used to -- modify settings values using record syntax. data RequestLoggerSettings = RequestLoggerSettings - { - -- | Default value: @Detailed@ @True@. - outputFormat :: OutputFormat - -- | Only applies when using the @Handle@ constructor for @destination@. - -- - -- Default value: @True@. + { outputFormat :: OutputFormat + -- ^ Default value: @Detailed@ @True@. , autoFlush :: Bool - -- | Default: @Handle@ @stdout@. + -- ^ Only applies when using the @Handle@ constructor for @destination@. + -- + -- Default value: @True@. , destination :: Destination + -- ^ Default: @Handle@ @stdout@. } defaultRequestLoggerSettings :: RequestLoggerSettings -defaultRequestLoggerSettings = RequestLoggerSettings - { outputFormat = Detailed True - , autoFlush = True - , destination = Handle stdout - } +defaultRequestLoggerSettings = + RequestLoggerSettings + { outputFormat = Detailed True + , autoFlush = True + , destination = Handle stdout + } instance Default RequestLoggerSettings where def = defaultRequestLoggerSettings @@ -232,11 +255,16 @@ mkRequestLogger RequestLoggerSettings{..} = do return $ apacheMiddleware (\_ _ -> True) apache ApacheWithSettings ApacheSettings{..} -> do getdate <- getDateGetter flusher - apache <- initLoggerUser (Just apacheUserGetter) apacheIPAddrSource (LogCallback callback flusher) getdate + apache <- + initLoggerUser + (Just apacheUserGetter) + apacheIPAddrSource + (LogCallback callback flusher) + getdate return $ apacheMiddleware apacheRequestFilter apache Detailed useColors -> - let settings = def { useColors = useColors} - in detailedMiddleware callbackAndFlush settings + let settings = def{useColors = useColors} + in detailedMiddleware callbackAndFlush settings DetailedWithSettings settings -> detailedMiddleware callbackAndFlush settings CustomOutputFormat formatter -> do @@ -247,12 +275,15 @@ mkRequestLogger RequestLoggerSettings{..} = do return $ customMiddlewareWithDetails callbackAndFlush getdate formatter CustomOutputFormatWithDetailsAndHeaders formatter -> do getdate <- getDateGetter flusher - return $ customMiddlewareWithDetailsAndHeaders callbackAndFlush getdate formatter + return $ + customMiddlewareWithDetailsAndHeaders callbackAndFlush getdate formatter -apacheMiddleware :: (Request -> Response -> Bool) -> ApacheLoggerActions -> Middleware +apacheMiddleware + :: (Request -> Response -> Bool) -> ApacheLoggerActions -> Middleware apacheMiddleware applyRequestFilter ala app req sendResponse = app req $ \res -> do when (applyRequestFilter req res) $ - apacheLogger ala req (responseStatus res) $ contentLength (responseHeaders res) + apacheLogger ala req (responseStatus res) $ + contentLength (responseHeaders res) sendResponse res customMiddleware :: Callback -> IO ZonedDate -> OutputFormatter -> Middleware @@ -262,47 +293,53 @@ customMiddleware cb getdate formatter app req sendResponse = app req $ \res -> d liftIO $ cb $ formatter date req (responseStatus res) msize sendResponse res -customMiddlewareWithDetails :: Callback -> IO ZonedDate -> OutputFormatterWithDetails -> Middleware +customMiddlewareWithDetails + :: Callback -> IO ZonedDate -> OutputFormatterWithDetails -> Middleware customMiddlewareWithDetails cb getdate formatter app req sendResponse = do - (req', reqBody) <- getRequestBody req - t0 <- getCurrentTime - app req' $ \res -> do - t1 <- getCurrentTime - date <- liftIO getdate - let msize = contentLength (responseHeaders res) - builderIO <- newIORef $ B.byteString "" - res' <- recordChunks builderIO res - rspRcv <- sendResponse res' - _ <- liftIO . cb . - formatter date req' (responseStatus res') msize (t1 `diffUTCTime` t0) reqBody =<< - readIORef builderIO - return rspRcv - -customMiddlewareWithDetailsAndHeaders :: Callback -> IO ZonedDate -> OutputFormatterWithDetailsAndHeaders -> Middleware + (req', reqBody) <- getRequestBody req + t0 <- getCurrentTime + app req' $ \res -> do + t1 <- getCurrentTime + date <- liftIO getdate + let msize = contentLength (responseHeaders res) + builderIO <- newIORef $ B.byteString "" + res' <- recordChunks builderIO res + rspRcv <- sendResponse res' + _ <- + liftIO + . cb + . formatter date req' (responseStatus res') msize (t1 `diffUTCTime` t0) reqBody + =<< readIORef builderIO + return rspRcv + +customMiddlewareWithDetailsAndHeaders + :: Callback -> IO ZonedDate -> OutputFormatterWithDetailsAndHeaders -> Middleware customMiddlewareWithDetailsAndHeaders cb getdate formatter app req sendResponse = do - (req', reqBody) <- getRequestBody req - t0 <- getCurrentTime - app req' $ \res -> do - t1 <- getCurrentTime - date <- liftIO getdate - let msize = contentLength (responseHeaders res) - builderIO <- newIORef $ B.byteString "" - res' <- recordChunks builderIO res - rspRcv <- sendResponse res' - _ <- do - rawResponse <- readIORef builderIO - let status = responseStatus res' - duration = t1 `diffUTCTime` t0 - resHeaders = responseHeaders res' - liftIO . cb $ formatter date req' status msize duration reqBody rawResponse resHeaders - return rspRcv + (req', reqBody) <- getRequestBody req + t0 <- getCurrentTime + app req' $ \res -> do + t1 <- getCurrentTime + date <- liftIO getdate + let msize = contentLength (responseHeaders res) + builderIO <- newIORef $ B.byteString "" + res' <- recordChunks builderIO res + rspRcv <- sendResponse res' + _ <- do + rawResponse <- readIORef builderIO + let status = responseStatus res' + duration = t1 `diffUTCTime` t0 + resHeaders = responseHeaders res' + liftIO . cb $ + formatter date req' status msize duration reqBody rawResponse resHeaders + return rspRcv + -- | Production request logger middleware. -- -- This uses the 'Apache' logging format, and takes IP addresses for clients from -- the socket (see 'IPAddrSource' for more information). It logs to 'stdout'. {-# NOINLINE logStdout #-} logStdout :: Middleware -logStdout = unsafePerformIO $ mkRequestLogger def { outputFormat = Apache FromSocket } +logStdout = unsafePerformIO $ mkRequestLogger def{outputFormat = Apache FromSocket} -- | Development request logger middleware. -- @@ -333,16 +370,14 @@ logStdoutDev = unsafePerformIO $ mkRequestLogger def -- > Accept: text/css,*​/​*;q=0.1 -- > Status: 304 Not Modified 0.010555s detailedMiddleware :: Callback -> DetailedSettings -> IO Middleware - -- NB: The *​/​* in the comments above have "zero-width spaces" in them, so the -- CPP doesn't screw up everything. So don't copy those; they're technically wrong. detailedMiddleware cb settings = let (ansiColor, ansiMethod, ansiStatusCode) = - if useColors settings - then (ansiColor', ansiMethod', ansiStatusCode') - else (\_ t -> [t], (:[]), \_ t -> [t]) - - in return $ detailedMiddleware' cb settings ansiColor ansiMethod ansiStatusCode + if useColors settings + then (ansiColor', ansiMethod', ansiStatusCode') + else (\_ t -> [t], (: []), \_ t -> [t]) + in return $ detailedMiddleware' cb settings ansiColor ansiMethod ansiStatusCode ansiColor' :: Color -> BS.ByteString -> [BS.ByteString] ansiColor' color bs = @@ -354,115 +389,124 @@ ansiColor' color bs = -- | Tags http method with a unique color. ansiMethod' :: BS.ByteString -> [BS.ByteString] ansiMethod' m = case m of - "GET" -> ansiColor' Cyan m - "HEAD" -> ansiColor' Cyan m - "PUT" -> ansiColor' Green m - "POST" -> ansiColor' Yellow m + "GET" -> ansiColor' Cyan m + "HEAD" -> ansiColor' Cyan m + "PUT" -> ansiColor' Green m + "POST" -> ansiColor' Yellow m "DELETE" -> ansiColor' Red m - _ -> ansiColor' Magenta m + _ -> ansiColor' Magenta m ansiStatusCode' :: BS.ByteString -> BS.ByteString -> [BS.ByteString] ansiStatusCode' c t = case S8.take 1 c of - "2" -> ansiColor' Green t - "3" -> ansiColor' Yellow t - "4" -> ansiColor' Red t - "5" -> ansiColor' Magenta t - _ -> ansiColor' Blue t + "2" -> ansiColor' Green t + "3" -> ansiColor' Yellow t + "4" -> ansiColor' Red t + "5" -> ansiColor' Magenta t + _ -> ansiColor' Blue t recordChunks :: IORef B.Builder -> Response -> IO Response recordChunks i (ResponseStream s h sb) = - return . ResponseStream s h $ (\send flush -> sb (\b -> modifyIORef i (<> b) >> send b) flush) + return . ResponseStream s h $ + (\send flush -> sb (\b -> modifyIORef i (<> b) >> send b) flush) recordChunks i (ResponseBuilder s h b) = - modifyIORef i (<> b) >> return (ResponseBuilder s h b) + modifyIORef i (<> b) >> return (ResponseBuilder s h b) recordChunks _ r = - return r + return r getRequestBody :: Request -> IO (Request, [S8.ByteString]) getRequestBody req = do - let loop front = do - bs <- getRequestBodyChunk req - if S8.null bs - then return $ front [] - else loop $ front . (bs:) - body <- loop id - -- logging the body here consumes it, so fill it back up - -- obviously not efficient, but this is the development logger - -- - -- Note: previously, we simply used CL.sourceList. However, - -- that meant that you could read the request body in twice. - -- While that in itself is not a problem, the issue is that, - -- in production, you wouldn't be able to do this, and - -- therefore some bugs wouldn't show up during testing. This - -- implementation ensures that each chunk is only returned - -- once. - ichunks <- newIORef body - let rbody = atomicModifyIORef ichunks $ \chunks -> - case chunks of - [] -> ([], S8.empty) - x:y -> (y, x) - let req' = setRequestBodyChunks rbody req - return (req', body) + let loop front = do + bs <- getRequestBodyChunk req + if S8.null bs + then return $ front [] + else loop $ front . (bs :) + body <- loop id + -- logging the body here consumes it, so fill it back up + -- obviously not efficient, but this is the development logger + -- + -- Note: previously, we simply used CL.sourceList. However, + -- that meant that you could read the request body in twice. + -- While that in itself is not a problem, the issue is that, + -- in production, you wouldn't be able to do this, and + -- therefore some bugs wouldn't show up during testing. This + -- implementation ensures that each chunk is only returned + -- once. + ichunks <- newIORef body + let rbody = atomicModifyIORef ichunks $ \chunks -> + case chunks of + [] -> ([], S8.empty) + x : y -> (y, x) + let req' = setRequestBodyChunks rbody req + return (req', body) + {- HLint ignore getRequestBody "Use lambda-case" -} -detailedMiddleware' :: Callback - -> DetailedSettings - -> (Color -> BS.ByteString -> [BS.ByteString]) - -> (BS.ByteString -> [BS.ByteString]) - -> (BS.ByteString -> BS.ByteString -> [BS.ByteString]) - -> Middleware +detailedMiddleware' + :: Callback + -> DetailedSettings + -> (Color -> BS.ByteString -> [BS.ByteString]) + -> (BS.ByteString -> [BS.ByteString]) + -> (BS.ByteString -> BS.ByteString -> [BS.ByteString]) + -> Middleware detailedMiddleware' cb DetailedSettings{..} ansiColor ansiMethod ansiStatusCode app req sendResponse = do - (req', body) <- - -- second tuple item should not be necessary, but a test runner might mess it up - case (requestBodyLength req, contentLength (requestHeaders req)) of - -- log the request body if it is small - (KnownLength len, _) | len <= 2048 -> getRequestBody req - (_, Just len) | len <= 2048 -> getRequestBody req - _ -> return (req, []) - - let reqbodylog _ = if null body || isJust mModifyParams - then [""] - else ansiColor White " Request Body: " <> body <> ["\n"] - reqbody = concatMap (either (const [""]) reqbodylog . decodeUtf8') body - postParams <- if requestMethod req `elem` ["GET", "HEAD"] - then return [] - else do (unmodifiedPostParams, files) <- liftIO $ allPostParams body - let postParams = - case mModifyParams of - Just modifyParams -> mapMaybe modifyParams unmodifiedPostParams - Nothing -> unmodifiedPostParams - return $ collectPostParams (postParams, files) - - let getParams = map emptyGetParam $ queryString req - accept = fromMaybe "" $ lookup H.hAccept $ requestHeaders req - params = let par | not $ null postParams = [pack (show postParams)] - | not $ null getParams = [pack (show getParams)] - | otherwise = [] - in if null par then [""] else ansiColor White " Params: " <> par <> ["\n"] - - t0 <- getCurrentTime - - -- Optionally prelog the request - when mPrelogRequests $ - cb $ "PRELOGGING REQUEST: " <> mkRequestLog params reqbody accept - - app req' $ \rsp -> do - case mFilterRequests of - Just f | not $ f req' rsp -> pure () - _ -> do - let isRaw = - case rsp of - ResponseRaw{} -> True - _ -> False - stCode = statusBS rsp - stMsg = msgBS rsp - t1 <- getCurrentTime - - -- log the status of the response - cb $ - mkRequestLog params reqbody accept - <> mkResponseLog isRaw stCode stMsg t1 t0 - - sendResponse rsp + (req', body) <- + -- second tuple item should not be necessary, but a test runner might mess it up + case (requestBodyLength req, contentLength (requestHeaders req)) of + -- log the request body if it is small + (KnownLength len, _) | len <= 2048 -> getRequestBody req + (_, Just len) | len <= 2048 -> getRequestBody req + _ -> return (req, []) + + let reqbodylog _ = + if null body || isJust mModifyParams + then [""] + else ansiColor White " Request Body: " <> body <> ["\n"] + reqbody = concatMap (either (const [""]) reqbodylog . decodeUtf8') body + postParams <- + if requestMethod req `elem` ["GET", "HEAD"] + then return [] + else do + (unmodifiedPostParams, files) <- liftIO $ allPostParams body + let postParams = + case mModifyParams of + Just modifyParams -> mapMaybe modifyParams unmodifiedPostParams + Nothing -> unmodifiedPostParams + return $ collectPostParams (postParams, files) + + let getParams = map emptyGetParam $ queryString req + accept = fromMaybe "" $ lookup H.hAccept $ requestHeaders req + params = + let par + | not $ null postParams = [pack (show postParams)] + | not $ null getParams = [pack (show getParams)] + | otherwise = [] + in if null par then [""] else ansiColor White " Params: " <> par <> ["\n"] + + t0 <- getCurrentTime + + -- Optionally prelog the request + when mPrelogRequests $ + cb $ + "PRELOGGING REQUEST: " <> mkRequestLog params reqbody accept + + app req' $ \rsp -> do + case mFilterRequests of + Just f | not $ f req' rsp -> pure () + _ -> do + let isRaw = + case rsp of + ResponseRaw{} -> True + _ -> False + stCode = statusBS rsp + stMsg = msgBS rsp + t1 <- getCurrentTime + + -- log the status of the response + cb $ + mkRequestLog params reqbody accept + <> mkResponseLog isRaw stCode stMsg t1 t0 + + sendResponse rsp where allPostParams body = case getRequestBodyType req of @@ -472,37 +516,43 @@ detailedMiddleware' cb DetailedSettings{..} ansiColor ansiMethod ansiStatusCode let rbody = atomicModifyIORef ichunks $ \chunks -> case chunks of [] -> ([], S8.empty) - x:y -> (y, x) + x : y -> (y, x) sinkRequestBody lbsBackEnd rbt rbody - emptyGetParam :: (BS.ByteString, Maybe BS.ByteString) -> (BS.ByteString, BS.ByteString) - emptyGetParam (k, Just v) = (k,v) - emptyGetParam (k, Nothing) = (k,"") + emptyGetParam + :: (BS.ByteString, Maybe BS.ByteString) -> (BS.ByteString, BS.ByteString) + emptyGetParam (k, Just v) = (k, v) + emptyGetParam (k, Nothing) = (k, "") collectPostParams :: ([Param], [File LBS.ByteString]) -> [Param] - collectPostParams (postParams, files) = postParams ++ - map (\(k,v) -> (k, "FILE: " <> fileName v)) files + collectPostParams (postParams, files) = + postParams + ++ map (\(k, v) -> (k, "FILE: " <> fileName v)) files mkRequestLog :: (Foldable t, ToLogStr m) => t m -> t m -> m -> LogStr mkRequestLog params reqbody accept = foldMap toLogStr (ansiMethod (requestMethod req)) - <> " " - <> toLogStr (rawPathInfo req) - <> "\n" - <> foldMap toLogStr params - <> foldMap toLogStr reqbody - <> foldMap toLogStr (ansiColor White " Accept: ") - <> toLogStr accept - <> "\n" - - mkResponseLog :: Bool -> S8.ByteString -> S8.ByteString -> UTCTime -> UTCTime -> LogStr + <> " " + <> toLogStr (rawPathInfo req) + <> "\n" + <> foldMap toLogStr params + <> foldMap toLogStr reqbody + <> foldMap toLogStr (ansiColor White " Accept: ") + <> toLogStr accept + <> "\n" + + mkResponseLog + :: Bool -> S8.ByteString -> S8.ByteString -> UTCTime -> UTCTime -> LogStr mkResponseLog isRaw stCode stMsg t1 t0 = - if isRaw then "" else - foldMap toLogStr (ansiColor White " Status: ") - <> foldMap toLogStr (ansiStatusCode stCode (stCode <> " " <> stMsg)) - <> " " - <> toLogStr (pack $ show $ diffUTCTime t1 t0) - <> "\n" + if isRaw + then "" + else + foldMap toLogStr (ansiColor White " Status: ") + <> foldMap toLogStr (ansiStatusCode stCode (stCode <> " " <> stMsg)) + <> " " + <> toLogStr (pack $ show $ diffUTCTime t1 t0) + <> "\n" + {- HLint ignore detailedMiddleware' "Use lambda-case" -} statusBS :: Response -> BS.ByteString diff --git a/wai-extra/Network/Wai/Middleware/RequestLogger/Internal.hs b/wai-extra/Network/Wai/Middleware/RequestLogger/Internal.hs index 8edaf52a8..62d7bf429 100644 --- a/wai-extra/Network/Wai/Middleware/RequestLogger/Internal.hs +++ b/wai-extra/Network/Wai/Middleware/RequestLogger/Internal.hs @@ -1,11 +1,12 @@ {-# LANGUAGE CPP #-} + -- | A module for containing some CPPed code, due to: -- -- https://github.com/yesodweb/wai/issues/192 -module Network.Wai.Middleware.RequestLogger.Internal - ( getDateGetter - , logToByteString - ) where +module Network.Wai.Middleware.RequestLogger.Internal ( + getDateGetter, + logToByteString, +) where #if !MIN_VERSION_wai_logger(2, 2, 0) import Control.Concurrent (forkIO, threadDelay) @@ -18,8 +19,10 @@ import System.Log.FastLogger (LogStr, fromLogStr) logToByteString :: LogStr -> ByteString logToByteString = fromLogStr -getDateGetter :: IO () -- ^ flusher - -> IO (IO ByteString) +getDateGetter + :: IO () + -- ^ flusher + -> IO (IO ByteString) #if !MIN_VERSION_wai_logger(2, 2, 0) getDateGetter flusher = do (getter, updater) <- clockDateCacher diff --git a/wai-extra/Network/Wai/Middleware/RequestLogger/JSON.hs b/wai-extra/Network/Wai/Middleware/RequestLogger/JSON.hs index 11acdb067..8a2484b1e 100644 --- a/wai-extra/Network/Wai/Middleware/RequestLogger/JSON.hs +++ b/wai-extra/Network/Wai/Middleware/RequestLogger/JSON.hs @@ -1,11 +1,10 @@ {-# LANGUAGE CPP #-} -module Network.Wai.Middleware.RequestLogger.JSON - ( formatAsJSON - , formatAsJSONWithHeaders - - , requestToJSON - ) where +module Network.Wai.Middleware.RequestLogger.JSON ( + formatAsJSON, + formatAsJSONWithHeaders, + requestToJSON, +) where import Data.Aeson import qualified Data.ByteString.Builder as BB (toLazyByteString) @@ -33,20 +32,23 @@ import Network.Wai.Middleware.RequestLogger formatAsJSON :: OutputFormatterWithDetails formatAsJSON date req status responseSize duration reqBody response = - toLogStr (encode $ - object - [ "request" .= requestToJSON req reqBody (Just duration) - , "response" .= - object - [ "status" .= statusCode status - , "size" .= responseSize - , "body" .= - if statusCode status >= 400 - then Just . decodeUtf8With lenientDecode . toStrict . BB.toLazyByteString $ response - else Nothing - ] - , "time" .= decodeUtf8With lenientDecode date - ]) <> "\n" + toLogStr + ( encode $ + object + [ "request" .= requestToJSON req reqBody (Just duration) + , "response" + .= object + [ "status" .= statusCode status + , "size" .= responseSize + , "body" + .= if statusCode status >= 400 + then Just . decodeUtf8With lenientDecode . toStrict . BB.toLazyByteString $ response + else Nothing + ] + , "time" .= decodeUtf8With lenientDecode date + ] + ) + <> "\n" -- | Same as @formatAsJSON@ but with response headers included -- @@ -57,20 +59,24 @@ formatAsJSON date req status responseSize duration reqBody response = -- @since 3.0.27 formatAsJSONWithHeaders :: OutputFormatterWithDetailsAndHeaders formatAsJSONWithHeaders date req status resSize duration reqBody res resHeaders = - toLogStr (encode $ - object - [ "request" .= requestToJSON req reqBody (Just duration) - , "response" .= object - [ "status" .= statusCode status - , "size" .= resSize - , "headers" .= responseHeadersToJSON resHeaders - , "body" .= - if statusCode status >= 400 - then Just . decodeUtf8With lenientDecode . toStrict . BB.toLazyByteString $ res - else Nothing - ] - , "time" .= decodeUtf8With lenientDecode date - ]) <> "\n" + toLogStr + ( encode $ + object + [ "request" .= requestToJSON req reqBody (Just duration) + , "response" + .= object + [ "status" .= statusCode status + , "size" .= resSize + , "headers" .= responseHeadersToJSON resHeaders + , "body" + .= if statusCode status >= 400 + then Just . decodeUtf8With lenientDecode . toStrict . BB.toLazyByteString $ res + else Nothing + ] + , "time" .= decodeUtf8With lenientDecode date + ] + ) + <> "\n" word32ToHostAddress :: Word32 -> Text word32ToHostAddress = T.intercalate "." . map (T.pack . show) . fromIPv4 . fromHostAddress @@ -102,62 +108,81 @@ readAsDouble = read -- without a major version bump. -- -- @since 3.1.4 -requestToJSON :: Request -- ^ The WAI request - -> [S8.ByteString] -- ^ Chunked request body - -> Maybe NominalDiffTime -- ^ Optional request duration - -> Value +requestToJSON + :: Request + -- ^ The WAI request + -> [S8.ByteString] + -- ^ Chunked request body + -> Maybe NominalDiffTime + -- ^ Optional request duration + -> Value requestToJSON req reqBody duration = - object $ - [ "method" .= decodeUtf8With lenientDecode (requestMethod req) - , "path" .= decodeUtf8With lenientDecode (rawPathInfo req) - , "queryString" .= map queryItemToJSON (queryString req) - , "size" .= requestBodyLengthToJSON (requestBodyLength req) - , "body" .= decodeUtf8With lenientDecode (S8.concat reqBody) - , "remoteHost" .= sockToJSON (remoteHost req) - , "httpVersion" .= httpVersionToJSON (httpVersion req) - , "headers" .= requestHeadersToJSON (requestHeaders req) - ] - <> - maybeToList (("durationMs" .=) . readAsDouble . printf "%.2f" . rationalToDouble . (* 1000) . toRational <$> duration) + object $ + [ "method" .= decodeUtf8With lenientDecode (requestMethod req) + , "path" .= decodeUtf8With lenientDecode (rawPathInfo req) + , "queryString" .= map queryItemToJSON (queryString req) + , "size" .= requestBodyLengthToJSON (requestBodyLength req) + , "body" .= decodeUtf8With lenientDecode (S8.concat reqBody) + , "remoteHost" .= sockToJSON (remoteHost req) + , "httpVersion" .= httpVersionToJSON (httpVersion req) + , "headers" .= requestHeadersToJSON (requestHeaders req) + ] + <> maybeToList + ( ("durationMs" .=) + . readAsDouble + . printf "%.2f" + . rationalToDouble + . (* 1000) + . toRational + <$> duration + ) where rationalToDouble :: Rational -> Double rationalToDouble = fromRational sockToJSON :: SockAddr -> Value sockToJSON (SockAddrInet pn ha) = - object - [ "port" .= portToJSON pn - , "hostAddress" .= word32ToHostAddress ha - ] + object + [ "port" .= portToJSON pn + , "hostAddress" .= word32ToHostAddress ha + ] sockToJSON (SockAddrInet6 pn _ ha _) = - object - [ "port" .= portToJSON pn - , "hostAddress" .= ha - ] + object + [ "port" .= portToJSON pn + , "hostAddress" .= ha + ] sockToJSON (SockAddrUnix sock) = - object [ "unix" .= sock ] + object ["unix" .= sock] #if !MIN_VERSION_network(3,0,0) sockToJSON (SockAddrCan i) = object [ "can" .= i ] #endif queryItemToJSON :: QueryItem -> Value -queryItemToJSON (name, mValue) = toJSON (decodeUtf8With lenientDecode name, fmap (decodeUtf8With lenientDecode) mValue) +queryItemToJSON (name, mValue) = + toJSON + (decodeUtf8With lenientDecode name, fmap (decodeUtf8With lenientDecode) mValue) requestHeadersToJSON :: RequestHeaders -> Value -requestHeadersToJSON = toJSON . map hToJ where - -- Redact cookies - hToJ ("Cookie", _) = toJSON ("Cookie" :: Text, "-RDCT-" :: Text) - hToJ hd = headerToJSON hd +requestHeadersToJSON = toJSON . map hToJ + where + -- Redact cookies + hToJ ("Cookie", _) = toJSON ("Cookie" :: Text, "-RDCT-" :: Text) + hToJ hd = headerToJSON hd responseHeadersToJSON :: [Header] -> Value -responseHeadersToJSON = toJSON . map hToJ where - -- Redact cookies - hToJ ("Set-Cookie", _) = toJSON ("Set-Cookie" :: Text, "-RDCT-" :: Text) - hToJ hd = headerToJSON hd +responseHeadersToJSON = toJSON . map hToJ + where + -- Redact cookies + hToJ ("Set-Cookie", _) = toJSON ("Set-Cookie" :: Text, "-RDCT-" :: Text) + hToJ hd = headerToJSON hd headerToJSON :: Header -> Value -headerToJSON (headerName, header) = toJSON (decodeUtf8With lenientDecode . original $ headerName, decodeUtf8With lenientDecode header) +headerToJSON (headerName, header) = + toJSON + ( decodeUtf8With lenientDecode . original $ headerName + , decodeUtf8With lenientDecode header + ) portToJSON :: PortNumber -> Value portToJSON = toJSON . toInteger diff --git a/wai-extra/Network/Wai/Middleware/RequestSizeLimit.hs b/wai-extra/Network/Wai/Middleware/RequestSizeLimit.hs index 7f4b26209..6b6a7a46c 100644 --- a/wai-extra/Network/Wai/Middleware/RequestSizeLimit.hs +++ b/wai-extra/Network/Wai/Middleware/RequestSizeLimit.hs @@ -1,19 +1,21 @@ {-# LANGUAGE CPP #-} + -- | The functions in this module allow you to limit the total size of incoming request bodies. -- -- Limiting incoming request body size helps protect your server against denial-of-service (DOS) attacks, -- in which an attacker sends huge bodies to your server. -module Network.Wai.Middleware.RequestSizeLimit - ( +module Network.Wai.Middleware.RequestSizeLimit ( -- * Middleware - requestSizeLimitMiddleware + requestSizeLimitMiddleware, + -- * Constructing 'RequestSizeLimitSettings' - , defaultRequestSizeLimitSettings + defaultRequestSizeLimitSettings, + -- * 'RequestSizeLimitSettings' and accessors - , RequestSizeLimitSettings - , setMaxLengthForRequest - , setOnLengthExceeded - ) where + RequestSizeLimitSettings, + setMaxLengthForRequest, + setOnLengthExceeded, +) where import Control.Exception (catch, try) import qualified Data.ByteString.Lazy as BSL @@ -25,7 +27,11 @@ import Data.Word (Word64) import Network.HTTP.Types.Status (requestEntityTooLarge413) import Network.Wai -import Network.Wai.Middleware.RequestSizeLimit.Internal (RequestSizeLimitSettings (..), setMaxLengthForRequest, setOnLengthExceeded) +import Network.Wai.Middleware.RequestSizeLimit.Internal ( + RequestSizeLimitSettings (..), + setMaxLengthForRequest, + setOnLengthExceeded, + ) import Network.Wai.Request -- | Create a 'RequestSizeLimitSettings' with these settings: @@ -35,10 +41,11 @@ import Network.Wai.Request -- -- @since 3.1.1 defaultRequestSizeLimitSettings :: RequestSizeLimitSettings -defaultRequestSizeLimitSettings = RequestSizeLimitSettings - { maxLengthForRequest = \_req -> pure $ Just $ 2 * 1024 * 1024 - , onLengthExceeded = \maxLen _app req sendResponse -> sendResponse (tooLargeResponse maxLen (requestBodyLength req)) - } +defaultRequestSizeLimitSettings = + RequestSizeLimitSettings + { maxLengthForRequest = \_req -> pure $ Just $ 2 * 1024 * 1024 + , onLengthExceeded = \maxLen _app req sendResponse -> sendResponse (tooLargeResponse maxLen (requestBodyLength req)) + } -- | Middleware to limit request bodies to a certain size. -- @@ -57,21 +64,23 @@ requestSizeLimitMiddleware settings app req sendResponse = do -- In the case of a known-length request, RequestSizeException will be thrown immediately Left (RequestSizeException _maxLen) -> handleLengthExceeded maxLen -- In the case of a chunked request (unknown length), RequestSizeException will be thrown during the processing of a body - Right newReq -> app newReq sendResponse `catch` \(RequestSizeException _maxLen) -> handleLengthExceeded maxLen - - where - handleLengthExceeded maxLen = onLengthExceeded settings maxLen app req sendResponse + Right newReq -> + app newReq sendResponse `catch` \(RequestSizeException _maxLen) -> handleLengthExceeded maxLen + where + handleLengthExceeded maxLen = onLengthExceeded settings maxLen app req sendResponse tooLargeResponse :: Word64 -> RequestBodyLength -> Response -tooLargeResponse maxLen bodyLen = responseLBS - requestEntityTooLarge413 - [("Content-Type", "text/plain")] - (BSL.concat - [ "Request body too large to be processed. The maximum size is " - , LS8.pack (show maxLen) - , " bytes; your request body was " - , case bodyLen of - KnownLength knownLen -> LS8.pack (show knownLen) <> " bytes." - ChunkedBody -> "split into chunks, whose total size is unknown, but exceeded the limit." - , " If you're the developer of this site, you can configure the maximum length with `requestSizeLimitMiddleware`." - ]) +tooLargeResponse maxLen bodyLen = + responseLBS + requestEntityTooLarge413 + [("Content-Type", "text/plain")] + ( BSL.concat + [ "Request body too large to be processed. The maximum size is " + , LS8.pack (show maxLen) + , " bytes; your request body was " + , case bodyLen of + KnownLength knownLen -> LS8.pack (show knownLen) <> " bytes." + ChunkedBody -> "split into chunks, whose total size is unknown, but exceeded the limit." + , " If you're the developer of this site, you can configure the maximum length with `requestSizeLimitMiddleware`." + ] + ) diff --git a/wai-extra/Network/Wai/Middleware/RequestSizeLimit/Internal.hs b/wai-extra/Network/Wai/Middleware/RequestSizeLimit/Internal.hs index c597a9a8a..f82f69713 100644 --- a/wai-extra/Network/Wai/Middleware/RequestSizeLimit/Internal.hs +++ b/wai-extra/Network/Wai/Middleware/RequestSizeLimit/Internal.hs @@ -1,9 +1,9 @@ -- | Internal constructors and helper functions. Note that no guarantees are given for stability of these interfaces. -module Network.Wai.Middleware.RequestSizeLimit.Internal - ( RequestSizeLimitSettings(..) - , setMaxLengthForRequest - , setOnLengthExceeded - ) where +module Network.Wai.Middleware.RequestSizeLimit.Internal ( + RequestSizeLimitSettings (..), + setMaxLengthForRequest, + setOnLengthExceeded, +) where import Data.Word (Word64) import Network.Wai (Middleware, Request) @@ -40,18 +40,24 @@ import Network.Wai (Middleware, Request) -- -- @since 3.1.1 data RequestSizeLimitSettings = RequestSizeLimitSettings - { maxLengthForRequest :: Request -> IO (Maybe Word64) -- ^ Function to determine the maximum request size in bytes for the request. Return 'Nothing' for no limit. Since 3.1.1 - , onLengthExceeded :: Word64 -> Middleware -- ^ Callback function when maximum length is exceeded. The 'Word64' argument is the limit computed by 'maxLengthForRequest'. Since 3.1.1 + { maxLengthForRequest :: Request -> IO (Maybe Word64) + -- ^ Function to determine the maximum request size in bytes for the request. Return 'Nothing' for no limit. Since 3.1.1 + , onLengthExceeded :: Word64 -> Middleware + -- ^ Callback function when maximum length is exceeded. The 'Word64' argument is the limit computed by 'maxLengthForRequest'. Since 3.1.1 } -- | Function to determine the maximum request size in bytes for the request. Return 'Nothing' for no limit. -- -- @since 3.1.1 -setMaxLengthForRequest :: (Request -> IO (Maybe Word64)) -> RequestSizeLimitSettings -> RequestSizeLimitSettings -setMaxLengthForRequest fn settings = settings { maxLengthForRequest = fn } +setMaxLengthForRequest + :: (Request -> IO (Maybe Word64)) + -> RequestSizeLimitSettings + -> RequestSizeLimitSettings +setMaxLengthForRequest fn settings = settings{maxLengthForRequest = fn} -- | Callback function when maximum length is exceeded. The 'Word64' argument is the limit computed by 'setMaxLengthForRequest'. -- -- @since 3.1.1 -setOnLengthExceeded :: (Word64 -> Middleware) -> RequestSizeLimitSettings -> RequestSizeLimitSettings -setOnLengthExceeded fn settings = settings { onLengthExceeded = fn } +setOnLengthExceeded + :: (Word64 -> Middleware) -> RequestSizeLimitSettings -> RequestSizeLimitSettings +setOnLengthExceeded fn settings = settings{onLengthExceeded = fn} diff --git a/wai-extra/Network/Wai/Middleware/Rewrite.hs b/wai-extra/Network/Wai/Middleware/Rewrite.hs index a5453679c..9c4d5fd28 100644 --- a/wai-extra/Network/Wai/Middleware/Rewrite.hs +++ b/wai-extra/Network/Wai/Middleware/Rewrite.hs @@ -1,43 +1,38 @@ -{-# LANGUAGE TupleSections #-} {-# LANGUAGE CPP #-} +{-# LANGUAGE TupleSections #-} -module Network.Wai.Middleware.Rewrite - ( -- * How to use this module - -- $howto - - -- ** A note on semantics - - -- $semantics - - -- ** Paths and Queries - - -- $pathsandqueries - PathsAndQueries - - -- ** An example rewriting paths with queries +module Network.Wai.Middleware.Rewrite ( + -- * How to use this module + -- $howto - -- $takeover + -- ** A note on semantics + -- $semantics - -- ** Upgrading from wai-extra ≤ 3.0.16.1 + -- ** Paths and Queries + -- $pathsandqueries + PathsAndQueries, - -- $upgrading + -- ** An example rewriting paths with queries + -- $takeover - -- * 'Middleware' + -- ** Upgrading from wai-extra ≤ 3.0.16.1 + -- $upgrading - -- ** Recommended functions - , rewriteWithQueries - , rewritePureWithQueries - , rewriteRoot + -- * 'Middleware' - -- ** Deprecated - , rewrite - , rewritePure + -- ** Recommended functions + rewriteWithQueries, + rewritePureWithQueries, + rewriteRoot, - -- * Operating on 'Request's + -- ** Deprecated + rewrite, + rewritePure, - , rewriteRequest - , rewriteRequestPure - ) where + -- * Operating on 'Request's + rewriteRequest, + rewriteRequestPure, +) where -- GHC ≤ 7.10 does not export Applicative functions from the prelude. #if __GLASGOW_HASKELL__ <= 710 @@ -51,7 +46,6 @@ import qualified Data.Text.Encoding as TE import Network.HTTP.Types as H import Network.Wai - -- $howto -- This module provides 'Middleware' to rewrite URL paths. It also provides -- functions that will convert a 'Request' to a modified 'Request'. @@ -202,7 +196,9 @@ import Network.Wai -- upgrading as the upgraded functions no longer modify 'rawPathInfo'. -------------------------------------------------- + -- * Types + -------------------------------------------------- -- | A tuple of the path sections as ['Text'] and query parameters as @@ -216,7 +212,9 @@ import Network.Wai type PathsAndQueries = ([Text], H.Query) -------------------------------------------------- + -- * Rewriting 'Middleware' + -------------------------------------------------- -- | Rewrite based on your own conversion function for paths only, to be @@ -225,14 +223,17 @@ type PathsAndQueries = ([Text], H.Query) -- For new code, use 'rewriteWithQueries' instead. rewrite :: ([Text] -> H.RequestHeaders -> IO [Text]) -> Middleware rewrite convert app req sendResponse = do - let convertIO = pathsOnly . curry $ liftIO . uncurry convert - newReq <- rewriteRequestRawM convertIO req - app newReq sendResponse -{-# WARNING rewrite [ - "This modifies the 'rawPathInfo' field of a 'Request'." - , " This is not recommended behaviour; it is however how" - , " this function has worked in the past." - , " Use 'rewriteWithQueries' instead"] #-} + let convertIO = pathsOnly . curry $ liftIO . uncurry convert + newReq <- rewriteRequestRawM convertIO req + app newReq sendResponse +{-# WARNING + rewrite + [ "This modifies the 'rawPathInfo' field of a 'Request'." + , " This is not recommended behaviour; it is however how" + , " this function has worked in the past." + , " Use 'rewriteWithQueries' instead" + ] + #-} -- | Rewrite based on pure conversion function for paths only, to be -- supplied by users of this library. @@ -240,28 +241,33 @@ rewrite convert app req sendResponse = do -- For new code, use 'rewritePureWithQueries' instead. rewritePure :: ([Text] -> H.RequestHeaders -> [Text]) -> Middleware rewritePure convert app req = - let convertPure = pathsOnly . curry $ Identity . uncurry convert - newReq = runIdentity $ rewriteRequestRawM convertPure req - in app newReq -{-# WARNING rewritePure [ - "This modifies the 'rawPathInfo' field of a 'Request'." - , " This is not recommended behaviour; it is however how" - , " this function has worked in the past." - , " Use 'rewritePureWithQueries' instead"] #-} + let convertPure = pathsOnly . curry $ Identity . uncurry convert + newReq = runIdentity $ rewriteRequestRawM convertPure req + in app newReq +{-# WARNING + rewritePure + [ "This modifies the 'rawPathInfo' field of a 'Request'." + , " This is not recommended behaviour; it is however how" + , " this function has worked in the past." + , " Use 'rewritePureWithQueries' instead" + ] + #-} -- | Rewrite based on your own conversion function for paths and queries. -- This function is to be supplied by users of this library, and operates -- in 'IO'. -rewriteWithQueries :: (PathsAndQueries -> H.RequestHeaders -> IO PathsAndQueries) - -> Middleware +rewriteWithQueries + :: (PathsAndQueries -> H.RequestHeaders -> IO PathsAndQueries) + -> Middleware rewriteWithQueries convert app req sendResponse = do - newReq <- rewriteRequestM convert req - app newReq sendResponse + newReq <- rewriteRequestM convert req + app newReq sendResponse -- | Rewrite based on pure conversion function for paths and queries. This -- function is to be supplied by users of this library. -rewritePureWithQueries :: (PathsAndQueries -> H.RequestHeaders -> PathsAndQueries) - -> Middleware +rewritePureWithQueries + :: (PathsAndQueries -> H.RequestHeaders -> PathsAndQueries) + -> Middleware rewritePureWithQueries convert app req = app $ rewriteRequestPure convert req -- | Rewrite root requests (/) to a specified path @@ -280,60 +286,79 @@ rewriteRoot root = rewritePureWithQueries onlyRoot onlyRoot paths _ = paths -------------------------------------------------- + -- * Modifying 'Request's directly + -------------------------------------------------- -- | Modify a 'Request' using the supplied function in 'IO'. This is suitable for -- the reverse proxy example. -rewriteRequest :: (PathsAndQueries -> H.RequestHeaders -> IO PathsAndQueries) - -> Request -> IO Request +rewriteRequest + :: (PathsAndQueries -> H.RequestHeaders -> IO PathsAndQueries) + -> Request + -> IO Request rewriteRequest convert req = - let convertIO = curry $ liftIO . uncurry convert - in rewriteRequestRawM convertIO req + let convertIO = curry $ liftIO . uncurry convert + in rewriteRequestRawM convertIO req -- | Modify a 'Request' using the pure supplied function. This is suitable for -- the reverse proxy example. -rewriteRequestPure :: (PathsAndQueries -> H.RequestHeaders -> PathsAndQueries) - -> Request -> Request +rewriteRequestPure + :: (PathsAndQueries -> H.RequestHeaders -> PathsAndQueries) + -> Request + -> Request rewriteRequestPure convert req = - let convertPure = curry $ Identity . uncurry convert - in runIdentity $ rewriteRequestRawM convertPure req + let convertPure = curry $ Identity . uncurry convert + in runIdentity $ rewriteRequestRawM convertPure req -------------------------------------------------- + -- * Helper functions + -------------------------------------------------- -- | This helper function factors out the common behaviour of rewriting requests. -rewriteRequestM :: (Applicative m, Monad m) - => (PathsAndQueries -> H.RequestHeaders -> m PathsAndQueries) - -> Request -> m Request +rewriteRequestM + :: (Applicative m, Monad m) + => (PathsAndQueries -> H.RequestHeaders -> m PathsAndQueries) + -> Request + -> m Request rewriteRequestM convert req = do - (pInfo, qByteStrings) <- curry convert (pathInfo req) (queryString req) (requestHeaders req) - pure req {pathInfo = pInfo, queryString = qByteStrings} + (pInfo, qByteStrings) <- + curry convert (pathInfo req) (queryString req) (requestHeaders req) + pure req{pathInfo = pInfo, queryString = qByteStrings} -- | This helper function preserves the semantics of wai-extra ≤ 3.0, in -- which the rewrite functions modify the 'rawPathInfo' parameter. Note -- that this has not been extended to modify the 'rawQueryInfo' as -- modifying either of these values has been deprecated. -rewriteRequestRawM :: (Applicative m, Monad m) - => (PathsAndQueries -> H.RequestHeaders -> m PathsAndQueries) - -> Request -> m Request +rewriteRequestRawM + :: (Applicative m, Monad m) + => (PathsAndQueries -> H.RequestHeaders -> m PathsAndQueries) + -> Request + -> m Request rewriteRequestRawM convert req = do - newReq <- rewriteRequestM convert req - let rawPInfo = TE.encodeUtf8 . T.intercalate "/" . pathInfo $ newReq - pure newReq { rawPathInfo = rawPInfo } -{-# WARNING rewriteRequestRawM [ - "This modifies the 'rawPathInfo' field of a 'Request'." - , " This is not recommended behaviour; it is however how" - , " this function has worked in the past." - , " Use 'rewriteRequestM' instead"] #-} + newReq <- rewriteRequestM convert req + let rawPInfo = TE.encodeUtf8 . T.intercalate "/" . pathInfo $ newReq + pure newReq{rawPathInfo = rawPInfo} +{-# WARNING + rewriteRequestRawM + [ "This modifies the 'rawPathInfo' field of a 'Request'." + , " This is not recommended behaviour; it is however how" + , " this function has worked in the past." + , " Use 'rewriteRequestM' instead" + ] + #-} -- | Produce a function that works on 'PathsAndQueries' from one working -- only on paths. This is not exported, as it is only needed to handle -- code written for versions ≤ 3.0 of the library; see the -- example above using 'Data.Bifunctor.first' to do something similar. -pathsOnly :: (Applicative m, Monad m) - => ([Text] -> H.RequestHeaders -> m [Text]) - -> PathsAndQueries -> H.RequestHeaders -> m PathsAndQueries +pathsOnly + :: (Applicative m, Monad m) + => ([Text] -> H.RequestHeaders -> m [Text]) + -> PathsAndQueries + -> H.RequestHeaders + -> m PathsAndQueries pathsOnly convert psAndQs headers = (,[]) <$> convert (fst psAndQs) headers {-# INLINE pathsOnly #-} diff --git a/wai-extra/Network/Wai/Middleware/Routed.hs b/wai-extra/Network/Wai/Middleware/Routed.hs index 05a908a95..af61512d2 100644 --- a/wai-extra/Network/Wai/Middleware/Routed.hs +++ b/wai-extra/Network/Wai/Middleware/Routed.hs @@ -1,10 +1,10 @@ -- | -- -- Since 3.0.9 -module Network.Wai.Middleware.Routed - ( routedMiddleware - , hostedMiddleware - ) where +module Network.Wai.Middleware.Routed ( + routedMiddleware, + hostedMiddleware, +) where import Data.ByteString (ByteString) import Data.Text (Text) @@ -17,22 +17,30 @@ import Network.Wai -- > let corsify = routedMiddleWare ("static" `elem`) addCorsHeaders -- -- Since 3.0.9 -routedMiddleware :: ([Text] -> Bool) -- ^ Only use middleware if this pathInfo test returns True - -> Middleware -- ^ middleware to apply the path prefix guard to - -> Middleware -- ^ modified middleware +routedMiddleware + :: ([Text] -> Bool) + -- ^ Only use middleware if this pathInfo test returns True + -> Middleware + -- ^ middleware to apply the path prefix guard to + -> Middleware + -- ^ modified middleware routedMiddleware pathCheck middle app req - | pathCheck (pathInfo req) = middle app req - | otherwise = app req + | pathCheck (pathInfo req) = middle app req + | otherwise = app req -- | Only apply the middleware to certain hosts -- -- Since 3.0.9 -hostedMiddleware :: ByteString -- ^ Domain the middleware applies to - -> Middleware -- ^ middleware to apply the path prefix guard to - -> Middleware -- ^ modified middleware +hostedMiddleware + :: ByteString + -- ^ Domain the middleware applies to + -> Middleware + -- ^ middleware to apply the path prefix guard to + -> Middleware + -- ^ modified middleware hostedMiddleware domain middle app req - | hasDomain domain req = middle app req - | otherwise = app req + | hasDomain domain req = middle app req + | otherwise = app req hasDomain :: ByteString -> Request -> Bool hasDomain domain req = Just domain == requestHeaderHost req diff --git a/wai-extra/Network/Wai/Middleware/Select.hs b/wai-extra/Network/Wai/Middleware/Select.hs index b5a55db21..234675471 100644 --- a/wai-extra/Network/Wai/Middleware/Select.hs +++ b/wai-extra/Network/Wai/Middleware/Select.hs @@ -1,4 +1,7 @@ --------------------------------------------------------- + +--------------------------------------------------------- + -- | -- Module : Network.Wai.Middleware.Select -- Copyright : Michael Snoyman @@ -24,10 +27,8 @@ -- > $ healthCheck app -- -- @since 3.1.10 --- ---------------------------------------------------------- -module Network.Wai.Middleware.Select - ( -- * Middleware selection +module Network.Wai.Middleware.Select ( + -- * Middleware selection MiddlewareSelection (..), selectMiddleware, @@ -36,7 +37,7 @@ module Network.Wai.Middleware.Select selectMiddlewareOnRawPathInfo, selectMiddlewareExceptRawPathInfo, passthroughMiddleware, - ) +) where import Control.Applicative ((<|>)) @@ -45,31 +46,35 @@ import Data.Maybe (fromMaybe) import Network.Wai -------------------------------------------------- + -- * Middleware selection + -------------------------------------------------- -- | Relevant Middleware for a given 'Request'. newtype MiddlewareSelection = MiddlewareSelection - { applySelectedMiddleware :: Request -> Maybe Middleware - } + { applySelectedMiddleware :: Request -> Maybe Middleware + } instance Semigroup MiddlewareSelection where - MiddlewareSelection f <> MiddlewareSelection g = - MiddlewareSelection $ \req -> f req <|> g req + MiddlewareSelection f <> MiddlewareSelection g = + MiddlewareSelection $ \req -> f req <|> g req instance Monoid MiddlewareSelection where - mempty = MiddlewareSelection $ const Nothing + mempty = MiddlewareSelection $ const Nothing -- | Create the 'Middleware' dynamically applying 'MiddlewareSelection'. selectMiddleware :: MiddlewareSelection -> Middleware selectMiddleware selection app request respond = - mw app request respond + mw app request respond where mw :: Middleware mw = fromMaybe passthroughMiddleware (applySelectedMiddleware selection request) -------------------------------------------------- + -- * Helpers + -------------------------------------------------- passthroughMiddleware :: Middleware @@ -78,14 +83,15 @@ passthroughMiddleware = id -- | Use the 'Middleware' when the predicate holds. selectMiddlewareOn :: (Request -> Bool) -> Middleware -> MiddlewareSelection selectMiddlewareOn doesApply mw = MiddlewareSelection $ \request -> - if doesApply request - then Just mw - else Nothing + if doesApply request + then Just mw + else Nothing -- | Use the `Middleware` for the given 'rawPathInfo'. selectMiddlewareOnRawPathInfo :: ByteString -> Middleware -> MiddlewareSelection selectMiddlewareOnRawPathInfo path = selectMiddlewareOn ((== path) . rawPathInfo) -- | Use the `Middleware` for all 'rawPathInfo' except then given one. -selectMiddlewareExceptRawPathInfo :: ByteString -> Middleware -> MiddlewareSelection +selectMiddlewareExceptRawPathInfo + :: ByteString -> Middleware -> MiddlewareSelection selectMiddlewareExceptRawPathInfo path = selectMiddlewareOn ((/= path) . rawPathInfo) diff --git a/wai-extra/Network/Wai/Middleware/StreamFile.hs b/wai-extra/Network/Wai/Middleware/StreamFile.hs index 0256ccfe0..5ea0a6a05 100644 --- a/wai-extra/Network/Wai/Middleware/StreamFile.hs +++ b/wai-extra/Network/Wai/Middleware/StreamFile.hs @@ -1,8 +1,7 @@ -- | -- -- Since 3.0.4 -module Network.Wai.Middleware.StreamFile - (streamFile) where +module Network.Wai.Middleware.StreamFile (streamFile) where import qualified Data.ByteString.Char8 as S8 import Network.HTTP.Types (hContentLength) @@ -10,28 +9,28 @@ import Network.Wai (Middleware, responseStream, responseToStream) import Network.Wai.Internal import System.Directory (getFileSize) --- |Convert ResponseFile type responses into ResponseStream type +-- | Convert ResponseFile type responses into ResponseStream type -- --- Checks the response type, and if it's a ResponseFile, converts it --- into a ResponseStream. Other response types are passed through --- unchanged. +-- Checks the response type, and if it's a ResponseFile, converts it +-- into a ResponseStream. Other response types are passed through +-- unchanged. -- --- Converted responses get a Content-Length header. +-- Converted responses get a Content-Length header. -- --- Streaming a file will bypass a sendfile system call, and may be --- useful to work around systems without working sendfile --- implementations. +-- Streaming a file will bypass a sendfile system call, and may be +-- useful to work around systems without working sendfile +-- implementations. -- --- Since 3.0.4 +-- Since 3.0.4 streamFile :: Middleware streamFile app env sendResponse = app env $ \res -> case res of - ResponseFile _ _ fp _ -> withBody sendBody + ResponseFile _ _ fp _ -> withBody sendBody where (s, hs, withBody) = responseToStream res sendBody :: StreamingBody -> IO ResponseReceived sendBody body = do - len <- getFileSize fp - let hs' = (hContentLength, S8.pack (show len)) : hs - sendResponse $ responseStream s hs' body - _ -> sendResponse res + len <- getFileSize fp + let hs' = (hContentLength, S8.pack (show len)) : hs + sendResponse $ responseStream s hs' body + _ -> sendResponse res diff --git a/wai-extra/Network/Wai/Middleware/StripHeaders.hs b/wai-extra/Network/Wai/Middleware/StripHeaders.hs index 19373ffef..55abe2687 100644 --- a/wai-extra/Network/Wai/Middleware/StripHeaders.hs +++ b/wai-extra/Network/Wai/Middleware/StripHeaders.hs @@ -7,26 +7,31 @@ -- When using this, care should be taken not to strip out headers that are -- required for correct operation of the client (eg Content-Type). -module Network.Wai.Middleware.StripHeaders - ( stripHeader - , stripHeaders - , stripHeaderIf - , stripHeadersIf - ) where +module Network.Wai.Middleware.StripHeaders ( + stripHeader, + stripHeaders, + stripHeaderIf, + stripHeadersIf, +) where import Data.ByteString (ByteString) import qualified Data.CaseInsensitive as CI -import Network.Wai (Middleware, Request, modifyResponse, mapResponseHeaders, ifRequest) +import Network.Wai ( + Middleware, + Request, + ifRequest, + mapResponseHeaders, + modifyResponse, + ) import Network.Wai.Internal (Response) - stripHeader :: ByteString -> (Response -> Response) -stripHeader h = mapResponseHeaders (filter (\ hdr -> fst hdr /= CI.mk h)) +stripHeader h = mapResponseHeaders (filter (\hdr -> fst hdr /= CI.mk h)) stripHeaders :: [ByteString] -> (Response -> Response) stripHeaders hs = - let hnames = map CI.mk hs - in mapResponseHeaders (filter (\ hdr -> fst hdr `notElem` hnames)) + let hnames = map CI.mk hs + in mapResponseHeaders (filter (\hdr -> fst hdr `notElem` hnames)) -- | If the request satisifes the provided predicate, strip headers matching -- the provided header name. @@ -34,12 +39,12 @@ stripHeaders hs = -- Since 3.0.8 stripHeaderIf :: ByteString -> (Request -> Bool) -> Middleware stripHeaderIf h rpred = - ifRequest rpred (modifyResponse $ stripHeader h) + ifRequest rpred (modifyResponse $ stripHeader h) -- | If the request satisifes the provided predicate, strip all headers whose -- header name is in the list of provided header names. -- -- Since 3.0.8 stripHeadersIf :: [ByteString] -> (Request -> Bool) -> Middleware -stripHeadersIf hs rpred - = ifRequest rpred (modifyResponse $ stripHeaders hs) +stripHeadersIf hs rpred = + ifRequest rpred (modifyResponse $ stripHeaders hs) diff --git a/wai-extra/Network/Wai/Middleware/Timeout.hs b/wai-extra/Network/Wai/Middleware/Timeout.hs index 7c6b3a948..164f64311 100644 --- a/wai-extra/Network/Wai/Middleware/Timeout.hs +++ b/wai-extra/Network/Wai/Middleware/Timeout.hs @@ -1,9 +1,9 @@ -- | Timeout requests -module Network.Wai.Middleware.Timeout - ( timeout - , timeoutStatus - , timeoutAs - ) where +module Network.Wai.Middleware.Timeout ( + timeout, + timeoutStatus, + timeoutAs, +) where import Network.HTTP.Types (Status, status503) import Network.Wai diff --git a/wai-extra/Network/Wai/Middleware/Vhost.hs b/wai-extra/Network/Wai/Middleware/Vhost.hs index 8d82ccb37..aa738b41d 100644 --- a/wai-extra/Network/Wai/Middleware/Vhost.hs +++ b/wai-extra/Network/Wai/Middleware/Vhost.hs @@ -1,11 +1,11 @@ {-# LANGUAGE CPP #-} -module Network.Wai.Middleware.Vhost - ( vhost - , redirectWWW - , redirectTo - , redirectToLogged - ) where +module Network.Wai.Middleware.Vhost ( + vhost, + redirectWWW, + redirectTo, + redirectToLogged, +) where import qualified Data.ByteString as BS #if __GLASGOW_HASKELL__ < 710 @@ -20,23 +20,28 @@ vhost :: [(Request -> Bool, Application)] -> Application -> Application vhost vhosts def req = case filter (\(b, _) -> b req) vhosts of [] -> def req - (_, app):_ -> app req + (_, app) : _ -> app req redirectWWW :: Text -> Application -> Application -- W.MiddleWare redirectWWW home = - redirectIf home (maybe True (BS.isPrefixOf "www") . lookup "host" . requestHeaders) + redirectIf + home + (maybe True (BS.isPrefixOf "www") . lookup "host" . requestHeaders) redirectIf :: Text -> (Request -> Bool) -> Application -> Application redirectIf home cond app req sendResponse = - if cond req - then sendResponse $ redirectTo $ TE.encodeUtf8 home - else app req sendResponse + if cond req + then sendResponse $ redirectTo $ TE.encodeUtf8 home + else app req sendResponse redirectTo :: BS.ByteString -> Response -redirectTo location = responseLBS H.status301 - [ (H.hContentType, "text/plain") , (H.hLocation, location) ] "Redirect" +redirectTo location = + responseLBS + H.status301 + [(H.hContentType, "text/plain"), (H.hLocation, location)] + "Redirect" redirectToLogged :: (Text -> IO ()) -> BS.ByteString -> IO Response redirectToLogged logger loc = do - logger $ "redirecting to: " `mappend` TE.decodeUtf8 loc - return $ redirectTo loc + logger $ "redirecting to: " `mappend` TE.decodeUtf8 loc + return $ redirectTo loc diff --git a/wai-extra/Network/Wai/Request.hs b/wai-extra/Network/Wai/Request.hs index 1a972c551..41a8974ec 100644 --- a/wai-extra/Network/Wai/Request.hs +++ b/wai-extra/Network/Wai/Request.hs @@ -1,12 +1,12 @@ {-# LANGUAGE DeriveDataTypeable #-} --- | Some helpers for interrogating a WAI 'Request'. -module Network.Wai.Request - ( appearsSecure - , guessApproot - , RequestSizeException(..) - , requestSizeCheck - ) where +-- | Some helpers for interrogating a WAI 'Request'. +module Network.Wai.Request ( + appearsSecure, + guessApproot, + RequestSizeException (..), + requestSizeCheck, +) where import Control.Exception (Exception, throwIO) import Data.ByteString (ByteString) @@ -19,7 +19,6 @@ import Data.Word (Word64) import Network.HTTP.Types (HeaderName) import Network.Wai - -- | Does this request appear to have been made over an SSL connection? -- -- This function first checks @'isSecure'@, but also checks for headers that may @@ -33,14 +32,16 @@ import Network.Wai -- -- @since 3.0.7 appearsSecure :: Request -> Bool -appearsSecure request = isSecure request || any (uncurry matchHeader) - [ ("HTTPS" , (== "on")) - , ("HTTP_X_FORWARDED_SSL" , (== "on")) - , ("HTTP_X_FORWARDED_SCHEME", (== "https")) - , ("HTTP_X_FORWARDED_PROTO" , (== ["https"]) . take 1 . C.split ',') - , ("X-Forwarded-Proto" , (== "https")) -- Used by Nginx and AWS ELB. - ] - +appearsSecure request = + isSecure request + || any + (uncurry matchHeader) + [ ("HTTPS", (== "on")) + , ("HTTP_X_FORWARDED_SSL", (== "on")) + , ("HTTP_X_FORWARDED_SCHEME", (== "https")) + , ("HTTP_X_FORWARDED_PROTO", (== ["https"]) . take 1 . C.split ',') + , ("X-Forwarded-Proto", (== "https")) -- Used by Nginx and AWS ELB. + ] where matchHeader :: HeaderName -> (ByteString -> Bool) -> Bool matchHeader h f = maybe False f $ lookup h $ requestHeaders request @@ -54,8 +55,8 @@ appearsSecure request = isSecure request || any (uncurry matchHeader) -- @since 3.0.7 guessApproot :: Request -> ByteString guessApproot req = - (if appearsSecure req then "https://" else "http://") `S.append` - fromMaybe "localhost" (requestHeaderHost req) + (if appearsSecure req then "https://" else "http://") + `S.append` fromMaybe "localhost" (requestHeaderHost req) -- | see 'requestSizeCheck' -- @@ -68,7 +69,9 @@ instance Exception RequestSizeException instance Show RequestSizeException where showsPrec p (RequestSizeException limit) = - showString "Request Body is larger than " . showsPrec p limit . showString " bytes." + showString "Request Body is larger than " + . showsPrec p limit + . showString " bytes." -- | Check request body size to avoid server crash when request is too large. -- @@ -81,19 +84,19 @@ instance Show RequestSizeException where requestSizeCheck :: Word64 -> Request -> IO Request requestSizeCheck maxSize req = case requestBodyLength req of - KnownLength len -> + KnownLength len -> if len > maxSize then return $ setRequestBodyChunks (throwIO $ RequestSizeException maxSize) req else return req - ChunkedBody -> do + ChunkedBody -> do currentSize <- newIORef 0 let rbody = do bs <- getRequestBodyChunk req total <- atomicModifyIORef' currentSize $ \sz -> let nextSize = sz + fromIntegral (S.length bs) - in (nextSize, nextSize) + in (nextSize, nextSize) if total > maxSize - then throwIO (RequestSizeException maxSize) - else return bs + then throwIO (RequestSizeException maxSize) + else return bs return $ setRequestBodyChunks rbody req diff --git a/wai-extra/Network/Wai/Test.hs b/wai-extra/Network/Wai/Test.hs index a72320cb1..37ae89ee2 100644 --- a/wai-extra/Network/Wai/Test.hs +++ b/wai-extra/Network/Wai/Test.hs @@ -1,35 +1,39 @@ -{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE CPP #-} -module Network.Wai.Test - ( -- * Session - Session - , runSession - , withSession - -- * Client Cookies - , ClientCookies - , getClientCookies - , modifyClientCookies - , setClientCookie - , deleteClientCookie - -- * Requests - , request - , srequest - , SRequest (..) - , SResponse (..) - , defaultRequest - , setPath - , setRawPathInfo - -- * Assertions - , assertStatus - , assertContentType - , assertBody - , assertBodyContains - , assertHeader - , assertNoHeader - , assertClientCookieExists - , assertNoClientCookieExists - , assertClientCookieValue - ) where +{-# LANGUAGE OverloadedStrings #-} + +module Network.Wai.Test ( + -- * Session + Session, + runSession, + withSession, + + -- * Client Cookies + ClientCookies, + getClientCookies, + modifyClientCookies, + setClientCookie, + deleteClientCookie, + + -- * Requests + request, + srequest, + SRequest (..), + SResponse (..), + defaultRequest, + setPath, + setRawPathInfo, + + -- * Assertions + assertStatus, + assertContentType, + assertBody, + assertBodyContains, + assertHeader, + assertNoHeader, + assertClientCookieExists, + assertNoClientCookieExists, + assertClientCookieValue, +) where #if __GLASGOW_HASKELL__ < 710 import Control.Applicative ((<$>)) @@ -72,22 +76,22 @@ getClientCookies = clientCookies <$> lift ST.get -- Since 3.0.6 modifyClientCookies :: (ClientCookies -> ClientCookies) -> Session () modifyClientCookies f = - lift (ST.modify (\cs -> cs { clientCookies = f $ clientCookies cs })) + lift (ST.modify (\cs -> cs{clientCookies = f $ clientCookies cs})) -- | -- -- Since 3.0.6 setClientCookie :: Cookie.SetCookie -> Session () setClientCookie c = - modifyClientCookies $ - Map.insert (Cookie.setCookieName c) c + modifyClientCookies $ + Map.insert (Cookie.setCookieName c) c -- | -- -- Since 3.0.6 deleteClientCookie :: ByteString -> Session () deleteClientCookie = - modifyClientCookies . Map.delete + modifyClientCookies . Map.delete -- | See also: 'runSessionWith'. runSession :: Session a -> Application -> IO a @@ -124,60 +128,66 @@ request req = do -- | Set whole path (request path + query string). setPath :: Request -> S8.ByteString -> Request -setPath req path = req { - pathInfo = segments - , rawPathInfo = L8.toStrict . toLazyByteString $ H.encodePathSegments segments - , queryString = query - , rawQueryString = H.renderQuery True query - } +setPath req path = + req + { pathInfo = segments + , rawPathInfo = L8.toStrict . toLazyByteString $ H.encodePathSegments segments + , queryString = query + , rawQueryString = H.renderQuery True query + } where (segments, query) = H.decodePath path setRawPathInfo :: Request -> S8.ByteString -> Request setRawPathInfo r rawPinfo = let pInfo = dropFrontSlash $ T.split (== '/') $ TE.decodeUtf8 rawPinfo - in r { rawPathInfo = rawPinfo, pathInfo = pInfo } + in r{rawPathInfo = rawPinfo, pathInfo = pInfo} where - dropFrontSlash ["",""] = [] -- homepage, a single slash - dropFrontSlash ("":path) = path + dropFrontSlash ["", ""] = [] -- homepage, a single slash + dropFrontSlash ("" : path) = path dropFrontSlash path = path addCookiesToRequest :: Request -> Session Request addCookiesToRequest req = do - oldClientCookies <- getClientCookies - let requestPath = "/" `T.append` T.intercalate "/" (pathInfo req) - currentUTCTime <- liftIO getCurrentTime - let cookiesForRequest = - Map.filter - (\c -> checkCookieTime currentUTCTime c - && checkCookiePath requestPath c) - oldClientCookies - let cookiePairs = [ (Cookie.setCookieName c, Cookie.setCookieValue c) - | c <- map snd $ Map.toList cookiesForRequest - ] - let cookieValue = L8.toStrict . toLazyByteString $ Cookie.renderCookies cookiePairs - addCookieHeader rest - | null cookiePairs = rest - | otherwise = ("Cookie", cookieValue) : rest - return $ req { requestHeaders = addCookieHeader $ requestHeaders req } - where checkCookieTime t c = - case Cookie.setCookieExpires c of - Nothing -> True - Just t' -> t < t' - checkCookiePath p c = - case Cookie.setCookiePath c of - Nothing -> True - Just p' -> p' `S8.isPrefixOf` TE.encodeUtf8 p + oldClientCookies <- getClientCookies + let requestPath = "/" `T.append` T.intercalate "/" (pathInfo req) + currentUTCTime <- liftIO getCurrentTime + let cookiesForRequest = + Map.filter + ( \c -> + checkCookieTime currentUTCTime c + && checkCookiePath requestPath c + ) + oldClientCookies + let cookiePairs = + [ (Cookie.setCookieName c, Cookie.setCookieValue c) + | c <- map snd $ Map.toList cookiesForRequest + ] + let cookieValue = L8.toStrict . toLazyByteString $ Cookie.renderCookies cookiePairs + addCookieHeader rest + | null cookiePairs = rest + | otherwise = ("Cookie", cookieValue) : rest + return $ req{requestHeaders = addCookieHeader $ requestHeaders req} + where + checkCookieTime t c = + case Cookie.setCookieExpires c of + Nothing -> True + Just t' -> t < t' + checkCookiePath p c = + case Cookie.setCookiePath c of + Nothing -> True + Just p' -> p' `S8.isPrefixOf` TE.encodeUtf8 p extractSetCookieFromSResponse :: SResponse -> Session SResponse extractSetCookieFromSResponse response = do - let setCookieHeaders = - filter (("Set-Cookie"==) . fst) $ simpleHeaders response - let newClientCookies = map (Cookie.parseSetCookie . snd) setCookieHeaders - modifyClientCookies - (Map.union - (Map.fromList [(Cookie.setCookieName c, c) | c <- newClientCookies ])) - return response + let setCookieHeaders = + filter (("Set-Cookie" ==) . fst) $ simpleHeaders response + let newClientCookies = map (Cookie.parseSetCookie . snd) setCookieHeaders + modifyClientCookies + ( Map.union + (Map.fromList [(Cookie.setCookieName c, c) | c <- newClientCookies]) + ) + return response -- | Similar to 'request', but allows setting the request body as a plain -- 'L.ByteString'. @@ -187,8 +197,9 @@ srequest (SRequest req bod) = do let rbody = atomicModifyIORef refChunks $ \bss -> case bss of [] -> ([], S.empty) - x:y -> (y, x) + x : y -> (y, x) request $ setRequestBodyChunks rbody req + {- HLint ignore srequest "Use lambda-case" -} runResponse :: IORef SResponse -> Response -> IO ResponseReceived @@ -218,112 +229,140 @@ assertFailure = liftIO . HUnit.assertFailure assertContentType :: HasCallStack => ByteString -> SResponse -> Session () assertContentType ct SResponse{simpleHeaders = h} = case lookup "content-type" h of - Nothing -> assertString $ concat - [ "Expected content type " - , show ct - , ", but no content type provided" - ] - Just ct' -> assertBool (concat - [ "Expected content type " - , show ct - , ", but received " - , show ct' - ]) (go ct == go ct') + Nothing -> + assertString $ + concat + [ "Expected content type " + , show ct + , ", but no content type provided" + ] + Just ct' -> + assertBool + ( concat + [ "Expected content type " + , show ct + , ", but received " + , show ct' + ] + ) + (go ct == go ct') where go = S8.takeWhile (/= ';') assertStatus :: HasCallStack => Int -> SResponse -> Session () -assertStatus i SResponse{simpleStatus = s} = assertBool (concat - [ "Expected status code " - , show i - , ", but received " - , show sc - ]) $ i == sc +assertStatus i SResponse{simpleStatus = s} = + assertBool + ( concat + [ "Expected status code " + , show i + , ", but received " + , show sc + ] + ) + $ i == sc where sc = H.statusCode s assertBody :: HasCallStack => L.ByteString -> SResponse -> Session () -assertBody lbs SResponse{simpleBody = lbs'} = assertBool (concat - [ "Expected response body " - , show $ L8.unpack lbs - , ", but received " - , show $ L8.unpack lbs' - ]) $ lbs == lbs' +assertBody lbs SResponse{simpleBody = lbs'} = + assertBool + ( concat + [ "Expected response body " + , show $ L8.unpack lbs + , ", but received " + , show $ L8.unpack lbs' + ] + ) + $ lbs == lbs' assertBodyContains :: HasCallStack => L.ByteString -> SResponse -> Session () -assertBodyContains lbs SResponse{simpleBody = lbs'} = assertBool (concat - [ "Expected response body to contain " - , show $ L8.unpack lbs - , ", but received " - , show $ L8.unpack lbs' - ]) $ strict lbs `S.isInfixOf` strict lbs' +assertBodyContains lbs SResponse{simpleBody = lbs'} = + assertBool + ( concat + [ "Expected response body to contain " + , show $ L8.unpack lbs + , ", but received " + , show $ L8.unpack lbs' + ] + ) + $ strict lbs `S.isInfixOf` strict lbs' where strict = S.concat . L.toChunks -assertHeader :: HasCallStack => CI ByteString -> ByteString -> SResponse -> Session () +assertHeader + :: HasCallStack => CI ByteString -> ByteString -> SResponse -> Session () assertHeader header value SResponse{simpleHeaders = h} = case lookup header h of - Nothing -> assertString $ concat - [ "Expected header " - , show header - , " to be " - , show value - , ", but it was not present" - ] - Just value' -> assertBool (concat - [ "Expected header " - , show header - , " to be " - , show value - , ", but received " - , show value' - ]) (value == value') + Nothing -> + assertString $ + concat + [ "Expected header " + , show header + , " to be " + , show value + , ", but it was not present" + ] + Just value' -> + assertBool + ( concat + [ "Expected header " + , show header + , " to be " + , show value + , ", but received " + , show value' + ] + ) + (value == value') assertNoHeader :: HasCallStack => CI ByteString -> SResponse -> Session () assertNoHeader header SResponse{simpleHeaders = h} = case lookup header h of Nothing -> return () - Just s -> assertString $ concat - [ "Unexpected header " - , show header - , " containing " - , show s - ] + Just s -> + assertString $ + concat + [ "Unexpected header " + , show header + , " containing " + , show s + ] -- | -- -- Since 3.0.6 assertClientCookieExists :: HasCallStack => String -> ByteString -> Session () assertClientCookieExists s cookieName = do - cookies <- getClientCookies - assertBool s $ Map.member cookieName cookies + cookies <- getClientCookies + assertBool s $ Map.member cookieName cookies -- | -- -- Since 3.0.6 assertNoClientCookieExists :: HasCallStack => String -> ByteString -> Session () assertNoClientCookieExists s cookieName = do - cookies <- getClientCookies - assertBool s $ not $ Map.member cookieName cookies + cookies <- getClientCookies + assertBool s $ not $ Map.member cookieName cookies -- | -- -- Since 3.0.6 -assertClientCookieValue :: HasCallStack => String -> ByteString -> ByteString -> Session () +assertClientCookieValue + :: HasCallStack => String -> ByteString -> ByteString -> Session () assertClientCookieValue s cookieName cookieValue = do - cookies <- getClientCookies - case Map.lookup cookieName cookies of - Nothing -> - assertFailure (s ++ " (cookie does not exist)") - Just c -> - assertBool - (concat - [ s - , " (actual value " - , show $ Cookie.setCookieValue c - , " expected value " - , show cookieValue - , ")" - ] - ) - (Cookie.setCookieValue c == cookieValue) + cookies <- getClientCookies + case Map.lookup cookieName cookies of + Nothing -> + assertFailure (s ++ " (cookie does not exist)") + Just c -> + assertBool + ( concat + [ s + , " (actual value " + , show $ Cookie.setCookieValue c + , " expected value " + , show cookieValue + , ")" + ] + ) + (Cookie.setCookieValue c == cookieValue) diff --git a/wai-extra/Network/Wai/UrlMap.hs b/wai-extra/Network/Wai/UrlMap.hs index b09f84026..96cc54f25 100644 --- a/wai-extra/Network/Wai/UrlMap.hs +++ b/wai-extra/Network/Wai/UrlMap.hs @@ -1,21 +1,20 @@ {-# LANGUAGE ExistentialQuantification #-} {-# LANGUAGE FlexibleInstances #-} -{- | This module gives you a way to mount applications under sub-URIs. -For example: -> bugsApp, helpdeskApp, apiV1, apiV2, mainApp :: Application -> -> myApp :: Application -> myApp = mapUrls $ -> mount "bugs" bugsApp -> <|> mount "helpdesk" helpdeskApp -> <|> mount "api" -> ( mount "v1" apiV1 -> <|> mount "v2" apiV2 -> ) -> <|> mountRoot mainApp - --} +-- | This module gives you a way to mount applications under sub-URIs. +-- For example: +-- +-- > bugsApp, helpdeskApp, apiV1, apiV2, mainApp :: Application +-- > +-- > myApp :: Application +-- > myApp = mapUrls $ +-- > mount "bugs" bugsApp +-- > <|> mount "helpdesk" helpdeskApp +-- > <|> mount "api" +-- > ( mount "v1" apiV1 +-- > <|> mount "v2" apiV2 +-- > ) +-- > <|> mountRoot mainApp module Network.Wai.UrlMap ( UrlMap', UrlMap, @@ -35,19 +34,22 @@ import Network.HTTP.Types (hContentType, status404) import Network.Wai (Application, Request (pathInfo, rawPathInfo), responseLBS) type Path = [Text] -newtype UrlMap' a = UrlMap' { unUrlMap :: [(Path, a)] } +newtype UrlMap' a = UrlMap' {unUrlMap :: [(Path, a)]} instance Functor UrlMap' where fmap f (UrlMap' xs) = UrlMap' (fmap (fmap f) xs) instance Applicative UrlMap' where - pure x = UrlMap' [([], x)] - (UrlMap' xs) <*> (UrlMap' ys) = UrlMap' [ (p, f y) | - (p, y) <- ys, - f <- map snd xs ] + pure x = UrlMap' [([], x)] + (UrlMap' xs) <*> (UrlMap' ys) = + UrlMap' + [ (p, f y) + | (p, y) <- ys + , f <- map snd xs + ] instance Alternative UrlMap' where - empty = UrlMap' empty + empty = UrlMap' empty (UrlMap' xs) <|> (UrlMap' ys) = UrlMap' (xs <|> ys) type UrlMap = UrlMap' Application @@ -68,14 +70,17 @@ mount prefix = mount' [prefix] mountRoot :: ToApplication a => a -> UrlMap mountRoot = mount' [] -try :: Eq a - => [a] -- ^ Path info of request - -> [([a], b)] -- ^ List of applications to match +try + :: Eq a + => [a] + -- ^ Path info of request + -> [([a], b)] + -- ^ List of applications to match -> Maybe ([a], b) try xs = foldl go Nothing - where - go (Just x) _ = Just x - go _ (prefix, y) = stripPrefix prefix xs >>= \xs' -> return (xs', y) + where + go (Just x) _ = Just x + go _ (prefix, y) = stripPrefix prefix xs >>= \xs' -> return (xs', y) class ToApplication a where toApplication :: a -> Application @@ -87,16 +92,20 @@ instance ToApplication UrlMap where toApplication urlMap req sendResponse = case try (pathInfo req) (unUrlMap urlMap) of Just (newPath, app) -> - app (req { pathInfo = newPath - , rawPathInfo = makeRaw newPath - }) sendResponse + app + ( req + { pathInfo = newPath + , rawPathInfo = makeRaw newPath + } + ) + sendResponse Nothing -> - sendResponse $ responseLBS - status404 - [(hContentType, "text/plain")] - "Not found\n" - - where + sendResponse $ + responseLBS + status404 + [(hContentType, "text/plain")] + "Not found\n" + where makeRaw :: [Text] -> B.ByteString makeRaw = ("/" `B.append`) . T.encodeUtf8 . T.intercalate "/" diff --git a/wai-extra/Network/Wai/Util.hs b/wai-extra/Network/Wai/Util.hs index 138faa459..3a68527a1 100644 --- a/wai-extra/Network/Wai/Util.hs +++ b/wai-extra/Network/Wai/Util.hs @@ -1,10 +1,11 @@ {-# LANGUAGE CPP #-} + -- | Some helpers functions. -module Network.Wai.Util - ( dropWhileEnd - , splitCommas - , trimWS - ) where +module Network.Wai.Util ( + dropWhileEnd, + splitCommas, + trimWS, +) where import qualified Data.ByteString as S import Data.Word8 (Word8, _comma, _space) diff --git a/wai-extra/example/Main.hs b/wai-extra/example/Main.hs index aa70c12b0..4cd558f73 100644 --- a/wai-extra/example/Main.hs +++ b/wai-extra/example/Main.hs @@ -1,26 +1,36 @@ {-# LANGUAGE OverloadedStrings #-} + module Main where -import Data.ByteString.Builder (string8) import Control.Concurrent (forkIO, threadDelay) import Control.Concurrent.Chan import Control.Monad (forever) +import Data.ByteString.Builder (string8) import Data.Time.Clock.POSIX (getPOSIXTime) import Network.HTTP.Types (status200) import Network.Wai (Application, Middleware, pathInfo, responseFile) -import Network.Wai.EventSource (ServerEvent(..), eventSourceAppChan, eventSourceAppIO) +import Network.Wai.EventSource ( + ServerEvent (..), + eventSourceAppChan, + eventSourceAppIO, + ) import Network.Wai.Handler.Warp (run) import Network.Wai.Middleware.AddHeaders (addHeaders) -import Network.Wai.Middleware.Gzip (gzip, def) - +import Network.Wai.Middleware.Gzip (def, gzip) app :: Chan ServerEvent -> Application app chan req respond = case pathInfo req of - [] -> respond $ responseFile status200 [("Content-Type", "text/html")] "example/index.html" Nothing - ["esold"] -> eventSourceAppChan chan req respond + [] -> + respond $ + responseFile + status200 + [("Content-Type", "text/html")] + "example/index.html" + Nothing + ["esold"] -> eventSourceAppChan chan req respond ["eschan"] -> eventSourceAppChan chan req respond - ["esio"] -> eventSourceAppIO eventIO req respond + ["esio"] -> eventSourceAppIO eventIO req respond _ -> error "unexpected pathInfo" eventChan :: Chan ServerEvent -> IO () @@ -33,30 +43,37 @@ eventIO :: IO ServerEvent eventIO = do threadDelay 1000000 time <- getPOSIXTime - return $ ServerEvent (Just $ string8 "io") - Nothing - [string8 . show $ time] + return $ + ServerEvent + (Just $ string8 "io") + Nothing + [string8 . show $ time] eventRaw :: (ServerEvent -> IO ()) -> IO () -> IO () eventRaw = handle (0 :: Int) - where - handle counter emit flush = do - threadDelay 1000000 - _ <- emit $ ServerEvent (Just $ string8 "raw") - Nothing - [string8 . show $ counter] - _ <- flush - handle (counter + 1) emit flush + where + handle counter emit flush = do + threadDelay 1000000 + _ <- + emit $ + ServerEvent + (Just $ string8 "raw") + Nothing + [string8 . show $ counter] + _ <- flush + handle (counter + 1) emit flush main :: IO () main = do chan <- newChan _ <- forkIO . eventChan $ chan run 8080 (gzip def $ headers $ app chan) - where - -- headers required for SSE to work through nginx - -- not required if using warp directly - headers :: Middleware - headers = addHeaders [ ("X-Accel-Buffering", "no") - , ("Cache-Control", "no-cache") - ] + where + -- headers required for SSE to work through nginx + -- not required if using warp directly + headers :: Middleware + headers = + addHeaders + [ ("X-Accel-Buffering", "no") + , ("Cache-Control", "no-cache") + ] diff --git a/wai-extra/proxy.hs b/wai-extra/proxy.hs index 6178bacd3..8c03a66d8 100644 --- a/wai-extra/proxy.hs +++ b/wai-extra/proxy.hs @@ -1,19 +1,20 @@ {-# LANGUAGE OverloadedStrings #-} -import qualified Network.HTTP.Enumerator as H -import qualified Network.Wai as W -import Network.Wai.Middleware.Gzip (gzip) -import Network.Wai.Handler.SimpleServer (run) + +import Blaze.ByteString.Builder (fromByteString) import qualified Data.ByteString.Char8 as S8 import qualified Data.ByteString.Lazy.Char8 () -import Data.Enumerator (($$), joinI, run_) -import Blaze.ByteString.Builder (fromByteString) +import Data.Enumerator (joinI, run_, ($$)) import qualified Data.Enumerator as E +import qualified Network.HTTP.Enumerator as H +import qualified Network.Wai as W +import Network.Wai.Handler.SimpleServer (run) +import Network.Wai.Middleware.Gzip (gzip) main :: IO () main = run 3000 $ gzip False app app :: W.Application -app W.Request { W.pathInfo = path } = +app W.Request{W.pathInfo = path} = case H.parseUrl $ "http://wiki.yesodweb.com" ++ S8.unpack path of Nothing -> return notFound Just hreq -> return $ W.ResponseEnumerator $ run_ . H.http hreq . go diff --git a/wai-extra/test/Network/Wai/Middleware/ApprootSpec.hs b/wai-extra/test/Network/Wai/Middleware/ApprootSpec.hs index fe384ee61..5f2527cf5 100644 --- a/wai-extra/test/Network/Wai/Middleware/ApprootSpec.hs +++ b/wai-extra/test/Network/Wai/Middleware/ApprootSpec.hs @@ -1,8 +1,9 @@ {-# LANGUAGE OverloadedStrings #-} -module Network.Wai.Middleware.ApprootSpec - ( main - , spec - ) where + +module Network.Wai.Middleware.ApprootSpec ( + main, + spec, +) where import Data.ByteString (ByteString) import Network.HTTP.Types (RequestHeaders, status200) @@ -22,16 +23,23 @@ spec = do simpleHeaders resp `shouldBe` [("Approot", expected)] test "respects host header" "foobar" False [] "http://foobar" test "respects isSecure" "foobar" True [] "https://foobar" - test "respects SSL headers" "foobar" False - [("HTTP_X_FORWARDED_SSL", "on")] "https://foobar" + test + "respects SSL headers" + "foobar" + False + [("HTTP_X_FORWARDED_SSL", "on")] + "https://foobar" runApp :: ByteString -> Bool -> RequestHeaders -> IO SResponse -runApp host secure headers = runSession - (request defaultRequest - { requestHeaderHost = Just host - , isSecure = secure - , requestHeaders = headers - }) $ fromRequest app - +runApp host secure headers = + runSession + ( request + defaultRequest + { requestHeaderHost = Just host + , isSecure = secure + , requestHeaders = headers + } + ) + $ fromRequest app where app req respond = respond $ responseLBS status200 [("Approot", getApproot req)] "" diff --git a/wai-extra/test/Network/Wai/Middleware/CombineHeadersSpec.hs b/wai-extra/test/Network/Wai/Middleware/CombineHeadersSpec.hs index d8faa6937..e9b7ba3f1 100644 --- a/wai-extra/test/Network/Wai/Middleware/CombineHeadersSpec.hs +++ b/wai-extra/test/Network/Wai/Middleware/CombineHeadersSpec.hs @@ -1,8 +1,9 @@ {-# LANGUAGE OverloadedStrings #-} -module Network.Wai.Middleware.CombineHeadersSpec - ( main - , spec - ) where + +module Network.Wai.Middleware.CombineHeadersSpec ( + main, + spec, +) where import Data.ByteString (ByteString) import Data.IORef (newIORef, readIORef, writeIORef) @@ -11,7 +12,13 @@ import Network.HTTP.Types.Header import Network.Wai import Test.Hspec -import Network.Wai.Middleware.CombineHeaders (CombineSettings, combineHeaders, defaultCombineSettings, setRequestHeaders, setResponseHeaders) +import Network.Wai.Middleware.CombineHeaders ( + CombineSettings, + combineHeaders, + defaultCombineSettings, + setRequestHeaders, + setResponseHeaders, + ) import Network.Wai.Test (SResponse (simpleHeaders), request, runSession) main :: IO () @@ -26,18 +33,24 @@ spec = do testReqHdrs name a b = test name defaultCombineSettings a b [] [] testResHdrs name a b = - test name (setRequestHeaders False $ setResponseHeaders True defaultCombineSettings) [] [] a b + test + name + (setRequestHeaders False $ setResponseHeaders True defaultCombineSettings) + [] + [] + a + b -- Request Headers testReqHdrs "should reorder alphabetically (request)" - [host , userAgent, acceptHtml] - [acceptHtml, host , userAgent ] + [host, userAgent, acceptHtml] + [acceptHtml, host, userAgent] -- Response Headers testResHdrs "should reorder alphabetically (response)" - [expires , location, contentTypeHtml] - [contentTypeHtml, expires , location ] + [expires, location, contentTypeHtml] + [contentTypeHtml, expires, location] -- Request Headers testReqHdrs @@ -48,58 +61,70 @@ spec = do testResHdrs -- Using the default header map, Cache-Control is a "combineable" header, "Set-Cookie" is not "combines Cache-Control (in order) and keeps Set-Cookie (in order)" - [ cacheControlPublic, setCookie "2", date, cacheControlMax, setCookie "1"] - [ cacheControlPublic `combineHdrs` cacheControlMax, date, setCookie "2", setCookie "1"] + [cacheControlPublic, setCookie "2", date, cacheControlMax, setCookie "1"] + [ cacheControlPublic `combineHdrs` cacheControlMax + , date + , setCookie "2" + , setCookie "1" + ] -- Request Headers testReqHdrs "KeepOnly works as expected (present | request)" -- "Alt-Svc" has (KeepOnly "clear") - [ date, altSvc "wrong", altSvc "clear", altSvc "wrong again", host ] - [ altSvc "clear", date, host ] + [date, altSvc "wrong", altSvc "clear", altSvc "wrong again", host] + [altSvc "clear", date, host] testReqHdrs "KeepOnly works as expected ( absent | request)" -- "Alt-Svc" has (KeepOnly "clear"), but will combine when there's no "clear" (AND keeps order) - [ date, altSvc "wrong", altSvc "not clear", altSvc "wrong again", host ] - [ altSvc "wrong, not clear, wrong again", date, host ] + [date, altSvc "wrong", altSvc "not clear", altSvc "wrong again", host] + [altSvc "wrong, not clear, wrong again", date, host] -- Response Headers testResHdrs "KeepOnly works as expected (present | response)" -- "If-None-Match" has (KeepOnly "*") - [ date, ifNoneMatch "wrong", ifNoneMatch "*", ifNoneMatch "wrong again", host ] - [ date, host, ifNoneMatch "*" ] + [date, ifNoneMatch "wrong", ifNoneMatch "*", ifNoneMatch "wrong again", host] + [date, host, ifNoneMatch "*"] testResHdrs "KeepOnly works as expected ( absent | response)" -- "If-None-Match" has (KeepOnly "*"), but will combine when there's no "*" (AND keeps order) - [ date, ifNoneMatch "wrong", ifNoneMatch "not *", ifNoneMatch "wrong again", host ] - [ date, host, ifNoneMatch "wrong, not *, wrong again" ] + [ date + , ifNoneMatch "wrong" + , ifNoneMatch "not *" + , ifNoneMatch "wrong again" + , host + ] + [date, host, ifNoneMatch "wrong, not *, wrong again"] -- Request Headers testReqHdrs "Technically acceptable headers get combined correctly (request)" - [ ifNoneMatch "correct, ", ifNoneMatch "something else \t", ifNoneMatch "and more , "] - [ ifNoneMatch "correct, something else, and more" ] + [ ifNoneMatch "correct, " + , ifNoneMatch "something else \t" + , ifNoneMatch "and more , " + ] + [ifNoneMatch "correct, something else, and more"] -- Response Headers testResHdrs "Technically acceptable headers get combined correctly (response)" - [ altSvc "correct\t, ", altSvc "something else", altSvc "and more, , "] - [ altSvc "correct, something else, and more" ] + [altSvc "correct\t, ", altSvc "something else", altSvc "and more, , "] + [altSvc "correct, something else, and more"] combineHdrs :: Header -> Header -> Header combineHdrs (hname, h1) (_, h2) = (hname, h1 <> ", " <> h2) -acceptHtml, - acceptJSON, - cacheControlMax, - cacheControlPublic, - contentTypeHtml, - date, - expires, - host, - location, - userAgent :: Header - +acceptHtml + , acceptJSON + , cacheControlMax + , cacheControlPublic + , contentTypeHtml + , date + , expires + , host + , location + , userAgent + :: Header acceptHtml = (hAccept, "text/html") acceptJSON = (hAccept, "application/json") altSvc :: ByteString -> Header @@ -117,18 +142,24 @@ setCookie :: ByteString -> Header setCookie val = (hSetCookie, val) userAgent = (hUserAgent, "curl/7.68.0") -runApp :: CombineSettings -> RequestHeaders -> ResponseHeaders -> IO (RequestHeaders, ResponseHeaders) +runApp + :: CombineSettings + -> RequestHeaders + -> ResponseHeaders + -> IO (RequestHeaders, ResponseHeaders) runApp settings reqHeaders resHeaders = do reqHdrs <- newIORef $ error "IORef not set" - sResponse <- runSession - session - $ combineHeaders settings $ app reqHdrs + sResponse <- + runSession + session + $ combineHeaders settings + $ app reqHdrs finalReqHeaders <- readIORef reqHdrs pure (finalReqHeaders, simpleHeaders sResponse) where session = request - defaultRequest { requestHeaders = reqHeaders } + defaultRequest{requestHeaders = reqHeaders} app hdrRef req respond = do writeIORef hdrRef $ requestHeaders req respond $ responseLBS status200 resHeaders "" diff --git a/wai-extra/test/Network/Wai/Middleware/ForceSSLSpec.hs b/wai-extra/test/Network/Wai/Middleware/ForceSSLSpec.hs index 998b06d1b..4cf70de7b 100644 --- a/wai-extra/test/Network/Wai/Middleware/ForceSSLSpec.hs +++ b/wai-extra/test/Network/Wai/Middleware/ForceSSLSpec.hs @@ -1,10 +1,10 @@ {-# LANGUAGE CPP #-} {-# LANGUAGE OverloadedStrings #-} -module Network.Wai.Middleware.ForceSSLSpec - ( main - , spec - ) where +module Network.Wai.Middleware.ForceSSLSpec ( + main, + spec, +) where import Control.Monad (forM_) import Data.ByteString (ByteString) @@ -28,7 +28,6 @@ spec = describe "forceSSL" (forM_ hosts $ \host -> hostSpec host) hostSpec :: ByteString -> Spec hostSpec host = describe ("forceSSL on host " <> show host <> "") $ do - it "redirects non-https requests to https" $ do resp <- runApp host forceSSL defaultRequest @@ -36,28 +35,39 @@ hostSpec host = describe ("forceSSL on host " <> show host <> "") $ do simpleHeaders resp `shouldBe` [("Location", "https://" <> host)] it "redirects with 307 in the case of a non-GET request" $ do - resp <- runApp host forceSSL defaultRequest - { requestMethod = methodPost } + resp <- + runApp + host + forceSSL + defaultRequest + { requestMethod = methodPost + } simpleStatus resp `shouldBe` status307 simpleHeaders resp `shouldBe` [("Location", "https://" <> host)] it "does not redirect already-secure requests" $ do - resp <- runApp host forceSSL defaultRequest { isSecure = True } + resp <- runApp host forceSSL defaultRequest{isSecure = True} simpleStatus resp `shouldBe` status200 it "preserves the original host, path, and query string" $ do - resp <- runApp host forceSSL defaultRequest - { rawPathInfo = "/foo/bar" - , rawQueryString = "?baz=bat" - } + resp <- + runApp + host + forceSSL + defaultRequest + { rawPathInfo = "/foo/bar" + , rawQueryString = "?baz=bat" + } - simpleHeaders resp `shouldBe` - [("Location", "https://" <> host <> "/foo/bar?baz=bat")] + simpleHeaders resp + `shouldBe` [("Location", "https://" <> host <> "/foo/bar?baz=bat")] runApp :: ByteString -> Middleware -> Request -> IO SResponse -runApp host mw req = runSession - (request req { requestHeaderHost = Just host }) $ mw app +runApp host mw req = + runSession + (request req{requestHeaderHost = Just host}) + $ mw app where app _ respond = respond $ responseLBS status200 [] "" diff --git a/wai-extra/test/Network/Wai/Middleware/RealIpSpec.hs b/wai-extra/test/Network/Wai/Middleware/RealIpSpec.hs index 13c532a6f..aba41a5dc 100644 --- a/wai-extra/test/Network/Wai/Middleware/RealIpSpec.hs +++ b/wai-extra/test/Network/Wai/Middleware/RealIpSpec.hs @@ -1,8 +1,8 @@ {-# LANGUAGE OverloadedStrings #-} -module Network.Wai.Middleware.RealIpSpec - ( spec - ) where +module Network.Wai.Middleware.RealIpSpec ( + spec, +) where import qualified Data.ByteString.Lazy.Char8 as B8 import qualified Data.IP as IP @@ -82,10 +82,13 @@ spec = do simpleBody resp1 `shouldBe` "10.0.0.1" simpleBody resp2 `shouldBe` "10.0.0.2" - runApp :: IP.IP -> RequestHeaders -> Middleware -> IO SResponse -runApp ip hs mw = runSession - (request defaultRequest { remoteHost = IP.toSockAddr (ip, 80), requestHeaders = hs }) $ mw app +runApp ip hs mw = + runSession + ( request + defaultRequest{remoteHost = IP.toSockAddr (ip, 80), requestHeaders = hs} + ) + $ mw app where app req respond = respond $ responseLBS status200 [] $ renderIp req renderIp = B8.pack . maybe "" (show . fst) . IP.fromSockAddr . remoteHost diff --git a/wai-extra/test/Network/Wai/Middleware/RequestSizeLimitSpec.hs b/wai-extra/test/Network/Wai/Middleware/RequestSizeLimitSpec.hs index 99dc76fb4..366cc9051 100644 --- a/wai-extra/test/Network/Wai/Middleware/RequestSizeLimitSpec.hs +++ b/wai-extra/test/Network/Wai/Middleware/RequestSizeLimitSpec.hs @@ -1,4 +1,5 @@ {-# LANGUAGE OverloadedStrings #-} + module Network.Wai.Middleware.RequestSizeLimitSpec (main, spec) where import Control.Monad (replicateM) @@ -20,86 +21,121 @@ main = hspec spec spec :: Spec spec = describe "RequestSizeLimitMiddleware" $ do - describe "Plain text response" $ do - runStrictBodyTests "returns 413 for request bodies > 10 bytes, when streaming the whole body" tenByteLimitSettings "1234567890a" isStatus413 - runStrictBodyTests "returns 200 for request bodies <= 10 bytes, when streaming the whole body" tenByteLimitSettings "1234567890" isStatus200 - - describe "JSON response" $ do - runStrictBodyTests "returns 413 for request bodies > 10 bytes, when streaming the whole body" tenByteLimitJSONSettings "1234567890a" (isStatus413 >> isJSONContentType) - runStrictBodyTests "returns 200 for request bodies <= 10 bytes, when streaming the whole body" tenByteLimitJSONSettings "1234567890" isStatus200 - - describe "Per-request sizes" $ do - - it "allows going over the limit, when the path has been whitelisted" $ do - let req = SRequest defaultRequest - { pathInfo = ["upload", "image"] - } "1234567890a" - settings = - setMaxLengthForRequest - (\req' -> if pathInfo req' == ["upload", "image"] then pure $ Just 20 else pure $ Just 10) - defaultRequestSizeLimitSettings - resp <- runStrictBodyApp settings req - isStatus200 resp - - describe "streaming chunked bodies" $ do - let streamingReq = - setRequestBodyChunks (return "a") defaultRequest - { isSecure = False - , requestBodyLength = ChunkedBody - } - it "413s if the combined chunk size is > the size limit" $ do - resp <- runStreamingChunkApp 11 tenByteLimitSettings streamingReq - simpleStatus resp `shouldBe` status413 - it "200s if the combined chunk size is <= the size limit" $ do - resp <- runStreamingChunkApp 10 tenByteLimitSettings streamingReq - simpleStatus resp `shouldBe` status200 - + describe "Plain text response" $ do + runStrictBodyTests + "returns 413 for request bodies > 10 bytes, when streaming the whole body" + tenByteLimitSettings + "1234567890a" + isStatus413 + runStrictBodyTests + "returns 200 for request bodies <= 10 bytes, when streaming the whole body" + tenByteLimitSettings + "1234567890" + isStatus200 + + describe "JSON response" $ do + runStrictBodyTests + "returns 413 for request bodies > 10 bytes, when streaming the whole body" + tenByteLimitJSONSettings + "1234567890a" + (isStatus413 >> isJSONContentType) + runStrictBodyTests + "returns 200 for request bodies <= 10 bytes, when streaming the whole body" + tenByteLimitJSONSettings + "1234567890" + isStatus200 + + describe "Per-request sizes" $ do + it "allows going over the limit, when the path has been whitelisted" $ do + let req = + SRequest + defaultRequest + { pathInfo = ["upload", "image"] + } + "1234567890a" + settings = + setMaxLengthForRequest + ( \req' -> + if pathInfo req' == ["upload", "image"] then pure $ Just 20 else pure $ Just 10 + ) + defaultRequestSizeLimitSettings + resp <- runStrictBodyApp settings req + isStatus200 resp + + describe "streaming chunked bodies" $ do + let streamingReq = + setRequestBodyChunks + (return "a") + defaultRequest + { isSecure = False + , requestBodyLength = ChunkedBody + } + it "413s if the combined chunk size is > the size limit" $ do + resp <- runStreamingChunkApp 11 tenByteLimitSettings streamingReq + simpleStatus resp `shouldBe` status413 + it "200s if the combined chunk size is <= the size limit" $ do + resp <- runStreamingChunkApp 10 tenByteLimitSettings streamingReq + simpleStatus resp `shouldBe` status200 where tenByteLimitSettings = - setMaxLengthForRequest - (\_req -> pure $ Just 10) - defaultRequestSizeLimitSettings + setMaxLengthForRequest + (\_req -> pure $ Just 10) + defaultRequestSizeLimitSettings tenByteLimitJSONSettings = - setOnLengthExceeded - (\_maxLen _app _req sendResponse -> sendResponse $ responseLBS status413 [("Content-Type", "application/json")] (encode $ object ["error" .= ("request size too large" :: Text)])) - tenByteLimitSettings + setOnLengthExceeded + ( \_maxLen _app _req sendResponse -> + sendResponse $ + responseLBS + status413 + [("Content-Type", "application/json")] + (encode $ object ["error" .= ("request size too large" :: Text)]) + ) + tenByteLimitSettings isStatus413 = \sResp -> simpleStatus sResp `shouldBe` status413 isStatus200 = \sResp -> simpleStatus sResp `shouldBe` status200 isJSONContentType = \sResp -> simpleHeaders sResp `shouldBe` [("Content-Type", "application/json")] data LengthType = UseKnownLength | UseChunked - deriving (Show, Eq) - -runStrictBodyTests :: String -> RequestSizeLimitSettings -> ByteString -> (SResponse -> Expectation) -> Spec + deriving (Show, Eq) + +runStrictBodyTests + :: String + -> RequestSizeLimitSettings + -> ByteString + -> (SResponse -> Expectation) + -> Spec runStrictBodyTests name settings reqBody runExpectations = describe name $ do - it "chunked" $ do - let req = mkRequestWithBytestring reqBody UseChunked - resp <- runStrictBodyApp settings req + it "chunked" $ do + let req = mkRequestWithBytestring reqBody UseChunked + resp <- runStrictBodyApp settings req - runExpectations resp - it "non-chunked" $ do - let req = mkRequestWithBytestring reqBody UseKnownLength - resp <- runStrictBodyApp settings req + runExpectations resp + it "non-chunked" $ do + let req = mkRequestWithBytestring reqBody UseKnownLength + resp <- runStrictBodyApp settings req - runExpectations resp + runExpectations resp where mkRequestWithBytestring :: ByteString -> LengthType -> SRequest mkRequestWithBytestring body lengthType = SRequest adjustedRequest $ - L.fromChunks $ map S.singleton $ S.unpack body + L.fromChunks $ + map S.singleton $ + S.unpack body where - adjustedRequest = defaultRequest - { requestHeaders = - [ (hContentLength, S8.pack $ show $ S.length body) - | lengthType == UseKnownLength - ] - , requestMethod = "POST" - , requestBodyLength = - if lengthType == UseKnownLength - then KnownLength $ fromIntegral $ S.length body - else ChunkedBody - } + adjustedRequest = + defaultRequest + { requestHeaders = + [ (hContentLength, S8.pack $ show $ S.length body) + | lengthType == UseKnownLength + ] + , requestMethod = "POST" + , requestBodyLength = + if lengthType == UseKnownLength + then KnownLength $ fromIntegral $ S.length body + else ChunkedBody + } runStrictBodyApp :: RequestSizeLimitSettings -> SRequest -> IO SResponse runStrictBodyApp settings req = @@ -107,14 +143,15 @@ runStrictBodyApp settings req = requestSizeLimitMiddleware settings app where app req' respond = do - _body <- strictRequestBody req' - respond $ responseLBS status200 [] "" + _body <- strictRequestBody req' + respond $ responseLBS status200 [] "" -runStreamingChunkApp :: Int -> RequestSizeLimitSettings -> Request -> IO SResponse +runStreamingChunkApp + :: Int -> RequestSizeLimitSettings -> Request -> IO SResponse runStreamingChunkApp times settings req = runSession (request req) $ requestSizeLimitMiddleware settings app where app req' respond = do - _chunks <- replicateM times (getRequestBodyChunk req') - respond $ responseLBS status200 [] "" + _chunks <- replicateM times (getRequestBodyChunk req') + respond $ responseLBS status200 [] "" diff --git a/wai-extra/test/Network/Wai/Middleware/RoutedSpec.hs b/wai-extra/test/Network/Wai/Middleware/RoutedSpec.hs index 3eb479855..6e66d52e3 100644 --- a/wai-extra/test/Network/Wai/Middleware/RoutedSpec.hs +++ b/wai-extra/test/Network/Wai/Middleware/RoutedSpec.hs @@ -1,8 +1,9 @@ {-# LANGUAGE OverloadedStrings #-} -module Network.Wai.Middleware.RoutedSpec - ( main - , spec - ) where + +module Network.Wai.Middleware.RoutedSpec ( + main, + spec, +) where import Data.ByteString (ByteString) import Data.String (IsString) @@ -11,42 +12,49 @@ import Network.Wai import Network.Wai.Test import Test.Hspec -import Network.Wai.Middleware.Routed import Network.Wai.Middleware.ForceSSL (forceSSL) +import Network.Wai.Middleware.Routed main :: IO () main = hspec spec spec :: Spec spec = describe "forceSSL" $ do - it "routed middleware" $ do - let destination = "https://example.com/d/" - let routedSslJsonApp prefix = routedMiddleware (checkPrefix prefix) forceSSL jsonApp - checkPrefix p (p1:_) = p == p1 - checkPrefix _ _ = False - - flip runSession (routedSslJsonApp "r") $ do - res <- testDPath "http" - assertNoHeader location res - assertStatus 200 res - assertBody "{\"foo\":\"bar\"}" res - - flip runSession (routedSslJsonApp "d") $ do - res2 <- testDPath "http" - assertHeader location destination res2 - assertStatus 301 res2 + it "routed middleware" $ do + let destination = "https://example.com/d/" + let routedSslJsonApp prefix = routedMiddleware (checkPrefix prefix) forceSSL jsonApp + checkPrefix p (p1 : _) = p == p1 + checkPrefix _ _ = False + + flip runSession (routedSslJsonApp "r") $ do + res <- testDPath "http" + assertNoHeader location res + assertStatus 200 res + assertBody "{\"foo\":\"bar\"}" res + + flip runSession (routedSslJsonApp "d") $ do + res2 <- testDPath "http" + assertHeader location destination res2 + assertStatus 301 res2 jsonApp :: Application -jsonApp _req cps = cps $ responseLBS status200 - [(hContentType, "application/json")] - "{\"foo\":\"bar\"}" +jsonApp _req cps = + cps $ + responseLBS + status200 + [(hContentType, "application/json")] + "{\"foo\":\"bar\"}" testDPath :: ByteString -> Session SResponse testDPath proto = - request $ flip setRawPathInfo "/d/" defaultRequest - { requestHeaders = [("X-Forwarded-Proto", proto)] - , requestHeaderHost = Just "example.com" - } + request $ + flip + setRawPathInfo + "/d/" + defaultRequest + { requestHeaders = [("X-Forwarded-Proto", proto)] + , requestHeaderHost = Just "example.com" + } location :: IsString ci => ci location = "Location" diff --git a/wai-extra/test/Network/Wai/Middleware/SelectSpec.hs b/wai-extra/test/Network/Wai/Middleware/SelectSpec.hs index 4befb74ea..ff3de43db 100644 --- a/wai-extra/test/Network/Wai/Middleware/SelectSpec.hs +++ b/wai-extra/test/Network/Wai/Middleware/SelectSpec.hs @@ -1,10 +1,10 @@ {-# LANGUAGE CPP #-} {-# LANGUAGE OverloadedStrings #-} -module Network.Wai.Middleware.SelectSpec - ( main, +module Network.Wai.Middleware.SelectSpec ( + main, spec, - ) +) where import Data.ByteString (ByteString) @@ -22,34 +22,52 @@ main = hspec spec spec :: Spec spec = describe "Select" $ do - it "With empty select should passthrough" $ - runApp (selectMiddleware mempty) "/" `shouldReturn` status200 - it "With other path select should passthrough" $ - runApp (selectMiddleware $ selectOverride "/_" status401) "/" `shouldReturn` status200 - it "With path select should hit override" $ - runApp (selectMiddleware $ selectOverride "/_" status401) "/_" `shouldReturn` status401 - it "With twice path select should hit first override" $ - runApp (selectMiddleware $ selectOverride "/_" status401 <> selectOverride "/_" status500) "/_" - `shouldReturn` status401 - it "With two paths select should hit first matching (first)" $ - runApp (selectMiddleware $ selectOverride "/_" status401 <> selectOverride "/-" status500) "/_" - `shouldReturn` status401 - it "With two paths select should hit first matching (last)" $ - runApp (selectMiddleware $ selectOverride "/_" status401 <> selectOverride "/-" status500) "/-" - `shouldReturn` status500 - it "With other two paths select should passthrough" $ - runApp (selectMiddleware $ selectOverride "/_" status401 <> selectOverride "/-" status500) "/" - `shouldReturn` status200 - it "With mempty then the path select should hit the pass" $ - runApp (selectMiddleware $ mempty <> selectOverride "/-" status500) "/-" - `shouldReturn` status500 + it "With empty select should passthrough" $ + runApp (selectMiddleware mempty) "/" `shouldReturn` status200 + it "With other path select should passthrough" $ + runApp (selectMiddleware $ selectOverride "/_" status401) "/" + `shouldReturn` status200 + it "With path select should hit override" $ + runApp (selectMiddleware $ selectOverride "/_" status401) "/_" + `shouldReturn` status401 + it "With twice path select should hit first override" $ + runApp + ( selectMiddleware $ + selectOverride "/_" status401 <> selectOverride "/_" status500 + ) + "/_" + `shouldReturn` status401 + it "With two paths select should hit first matching (first)" $ + runApp + ( selectMiddleware $ + selectOverride "/_" status401 <> selectOverride "/-" status500 + ) + "/_" + `shouldReturn` status401 + it "With two paths select should hit first matching (last)" $ + runApp + ( selectMiddleware $ + selectOverride "/_" status401 <> selectOverride "/-" status500 + ) + "/-" + `shouldReturn` status500 + it "With other two paths select should passthrough" $ + runApp + ( selectMiddleware $ + selectOverride "/_" status401 <> selectOverride "/-" status500 + ) + "/" + `shouldReturn` status200 + it "With mempty then the path select should hit the pass" $ + runApp (selectMiddleware $ mempty <> selectOverride "/-" status500) "/-" + `shouldReturn` status500 runApp :: Middleware -> ByteString -> IO Status runApp mw path = - fmap simpleStatus $ - runSession - (request $ defaultRequest {rawPathInfo = path}) - $ mw app + fmap simpleStatus + $ runSession + (request $ defaultRequest{rawPathInfo = path}) + $ mw app app :: Application app = constApp status200 diff --git a/wai-extra/test/Network/Wai/Middleware/StripHeadersSpec.hs b/wai-extra/test/Network/Wai/Middleware/StripHeadersSpec.hs index 8a15520c5..fbb56b722 100644 --- a/wai-extra/test/Network/Wai/Middleware/StripHeadersSpec.hs +++ b/wai-extra/test/Network/Wai/Middleware/StripHeadersSpec.hs @@ -1,9 +1,10 @@ {-# LANGUAGE CPP #-} {-# LANGUAGE OverloadedStrings #-} -module Network.Wai.Middleware.StripHeadersSpec - ( main - , spec - ) where + +module Network.Wai.Middleware.StripHeadersSpec ( + main, + spec, +) where import Control.Arrow (first) import Data.ByteString (ByteString) @@ -19,11 +20,9 @@ import Test.Hspec import Network.Wai.Middleware.AddHeaders (addHeaders) import Network.Wai.Middleware.StripHeaders (stripHeaderIf, stripHeadersIf) - main :: IO () main = hspec spec - spec :: Spec spec = describe "stripHeader" $ do let host = "example.com" @@ -31,8 +30,16 @@ spec = describe "stripHeader" $ do it "strips a specific header" $ do resp1 <- runApp host (addHeaders testHeaders) defaultRequest - resp2 <- runApp host (stripHeaderIf "Foo" (const False) . addHeaders testHeaders) defaultRequest - resp3 <- runApp host (stripHeaderIf "Foo" (const True) . addHeaders testHeaders) defaultRequest + resp2 <- + runApp + host + (stripHeaderIf "Foo" (const False) . addHeaders testHeaders) + defaultRequest + resp3 <- + runApp + host + (stripHeaderIf "Foo" (const True) . addHeaders testHeaders) + defaultRequest simpleHeaders resp1 `shouldBe` ciTestHeaders simpleHeaders resp2 `shouldBe` ciTestHeaders @@ -40,20 +47,28 @@ spec = describe "stripHeader" $ do it "strips specific set of headers" $ do resp1 <- runApp host (addHeaders testHeaders) defaultRequest - resp2 <- runApp host (stripHeadersIf ["Bar", "Foo"] (const False) . addHeaders testHeaders) defaultRequest - resp3 <- runApp host (stripHeadersIf ["Bar", "Foo"] (const True) . addHeaders testHeaders) defaultRequest + resp2 <- + runApp + host + (stripHeadersIf ["Bar", "Foo"] (const False) . addHeaders testHeaders) + defaultRequest + resp3 <- + runApp + host + (stripHeadersIf ["Bar", "Foo"] (const True) . addHeaders testHeaders) + defaultRequest simpleHeaders resp1 `shouldBe` ciTestHeaders simpleHeaders resp2 `shouldBe` ciTestHeaders simpleHeaders resp3 `shouldBe` [last ciTestHeaders] - testHeaders :: [(ByteString, ByteString)] testHeaders = [("Foo", "fooey"), ("Bar", "barbican"), ("Baz", "bazooka")] - runApp :: ByteString -> Middleware -> Request -> IO SResponse -runApp host mw req = runSession - (request req { requestHeaderHost = Just $ host <> ":80" }) $ mw app +runApp host mw req = + runSession + (request req{requestHeaderHost = Just $ host <> ":80"}) + $ mw app where app _ respond = respond $ responseLBS status200 [] "" diff --git a/wai-extra/test/Network/Wai/Middleware/TimeoutSpec.hs b/wai-extra/test/Network/Wai/Middleware/TimeoutSpec.hs index 8144dac61..3f382821f 100644 --- a/wai-extra/test/Network/Wai/Middleware/TimeoutSpec.hs +++ b/wai-extra/test/Network/Wai/Middleware/TimeoutSpec.hs @@ -1,7 +1,8 @@ {-# LANGUAGE OverloadedStrings #-} -module Network.Wai.Middleware.TimeoutSpec - ( spec - ) where + +module Network.Wai.Middleware.TimeoutSpec ( + spec, +) where import Control.Concurrent (threadDelay) import Network.HTTP.Types (status200, status503, status504) diff --git a/wai-extra/test/Network/Wai/ParseSpec.hs b/wai-extra/test/Network/Wai/ParseSpec.hs index 43964ff91..295843474 100644 --- a/wai-extra/test/Network/Wai/ParseSpec.hs +++ b/wai-extra/test/Network/Wai/ParseSpec.hs @@ -1,5 +1,6 @@ {-# LANGUAGE CPP #-} {-# LANGUAGE OverloadedStrings #-} + module Network.Wai.ParseSpec (main, spec) where #if __GLASGOW_HASKELL__ < 804 @@ -13,8 +14,12 @@ import qualified Data.IORef as I import qualified Data.Text as TS import qualified Data.Text.Encoding as TE import Data.Word8 (_e) -import Network.Wai (Request (requestHeaders), defaultRequest, setRequestBodyChunks) -import Network.Wai.Handler.Warp (InvalidRequest(..)) +import Network.Wai ( + Request (requestHeaders), + defaultRequest, + setRequestBodyChunks, + ) +import Network.Wai.Handler.Warp (InvalidRequest (..)) import System.IO (IOMode (ReadMode), withFile) import Test.HUnit (Assertion, (@=?), (@?=)) import Test.Hspec @@ -30,15 +35,24 @@ spec :: Spec spec = do describe "parseContentType" $ do let go (x, y, z) = it (TS.unpack $ TE.decodeUtf8 x) $ parseContentType x `shouldBe` (y, z) - mapM_ go + mapM_ + go [ ("text/plain", "text/plain", []) , ("text/plain; charset=UTF-8 ", "text/plain", [("charset", "UTF-8")]) - , ("text/plain; charset=UTF-8 ; boundary = foo", "text/plain", [("charset", "UTF-8"), ("boundary", "foo")]) - , ("text/plain; charset=UTF-8 ; boundary = \"quoted\"", "text/plain", [("charset", "UTF-8"), ("boundary", "quoted")]) + , + ( "text/plain; charset=UTF-8 ; boundary = foo" + , "text/plain" + , [("charset", "UTF-8"), ("boundary", "foo")] + ) + , + ( "text/plain; charset=UTF-8 ; boundary = \"quoted\"" + , "text/plain" + , [("charset", "UTF-8"), ("boundary", "quoted")] + ) ] it "parseHttpAccept" caseParseHttpAccept describe "parseRequestBody" $ do - caseParseRequestBody + caseParseRequestBody it "multipart with plus" caseMultipartPlus it "multipart with multiple attributes" caseMultipartAttrs it "urlencoded with plus" caseUrlEncPlus @@ -48,163 +62,186 @@ spec = do caseParseHttpAccept :: Assertion caseParseHttpAccept = do - let input = "text/plain; q=0.5, text/html;charset=utf-8, text/*;q=0.8;ext=blah, text/x-dvi; q=0.8, text/x-c" + let input = + "text/plain; q=0.5, text/html;charset=utf-8, text/*;q=0.8;ext=blah, text/x-dvi; q=0.8, text/x-c" expected = ["text/html;charset=utf-8", "text/x-c", "text/x-dvi", "text/*", "text/plain"] expected @=? parseHttpAccept input -parseRequestBody' :: BackEnd file - -> SRequest - -> IO ([(S.ByteString, S.ByteString)], [(S.ByteString, FileInfo file)]) +parseRequestBody' + :: BackEnd file + -> SRequest + -> IO ([(S.ByteString, S.ByteString)], [(S.ByteString, FileInfo file)]) parseRequestBody' sink (SRequest req bod) = case getRequestBodyType req of Nothing -> return ([], []) Just rbt -> do ref <- I.newIORef $ L.toChunks bod let rb = I.atomicModifyIORef ref $ \chunks -> - case chunks of - [] -> ([], S.empty) - x:y -> (y, x) + case chunks of + [] -> ([], S.empty) + x : y -> (y, x) sinkRequestBody sink rbt rb caseParseRequestBody :: Spec caseParseRequestBody = do - it "parsing post x-www-form-urlencoded" $ do - let content1 = "foo=bar&baz=bin" - let ctype1 = "application/x-www-form-urlencoded" - result1 <- parseRequestBody' lbsBackEnd $ toRequest ctype1 content1 - result1 `shouldBe` ([("foo", "bar"), ("baz", "bin")], []) - - let ctype2 = "multipart/form-data; boundary=AaB03x" - let expectedsmap2 = - [ ("title", "A File") - , ("summary", "This is my file\nfile test") - ] - let textPlain = "text/plain; charset=iso-8859-1" - let expectedfile2 = - [("document", FileInfo "b.txt" textPlain "This is a file.\nIt has two lines.")] - let expected2 = (expectedsmap2, expectedfile2) - - it "parsing post multipart/form-data" $ do - result2 <- parseRequestBody' lbsBackEnd $ toRequest ctype2 content2 - result2 `shouldBe` expected2 - - it "parsing post multipart/form-data 2" $ do - result2' <- parseRequestBody' lbsBackEnd $ toRequest' ctype2 content2 - result2' `shouldBe` expected2 - - - let ctype3 = "multipart/form-data; boundary=----WebKitFormBoundaryB1pWXPZ6lNr8RiLh" - let expectedsmap3 = [] - let expectedfile3 = [("yaml", FileInfo "README" "application/octet-stream" "Photo blog using Hack.\n")] - let expected3 = (expectedsmap3, expectedfile3) - - let def = defaultParseRequestBodyOptions - it "parsing actual post multipart/form-data" $ do - result3 <- parseRequestBody' lbsBackEnd $ toRequest ctype3 content3 - result3 `shouldBe` expected3 - - it "parsing actual post multipart/form-data 2" $ do - result3' <- parseRequestBody' lbsBackEnd $ toRequest' ctype3 content3 - result3' `shouldBe` expected3 - - it "parsing with memory limit" $ do - SRequest req4 _bod4 <- toRequest'' ctype3 content3 - result4' <- parseRequestBodyEx ( setMaxRequestNumFiles 1 $ setMaxRequestKeyLength 14 def ) lbsBackEnd req4 - result4' `shouldBe` expected3 - - it "exceeding number of files" $ do - SRequest req4 _bod4 <- toRequest'' ctype3 content3 - (parseRequestBodyEx ( setMaxRequestNumFiles 0 def ) lbsBackEnd req4) `shouldThrow` anyException - - it "exceeding parameter length" $ do - SRequest req4 _bod4 <- toRequest'' ctype3 content3 - (parseRequestBodyEx ( setMaxRequestKeyLength 2 def ) lbsBackEnd req4) `shouldThrow` anyException - - it "exceeding file size" $ do - SRequest req4 _bod4 <- toRequest'' ctype3 content3 - (parseRequestBodyEx ( setMaxRequestFileSize 2 def ) lbsBackEnd req4) - `shouldThrow` (== PayloadTooLarge) - - it "exceeding total file size" $ do - SRequest req4 _bod4 <- toRequest'' ctype3 content3 - (parseRequestBodyEx ( setMaxRequestFilesSize 20 def ) lbsBackEnd req4) - `shouldThrow` (== PayloadTooLarge) - SRequest req5 _bod5 <- toRequest'' ctype3 content5 - (parseRequestBodyEx ( setMaxRequestFilesSize 20 def ) lbsBackEnd req5) - `shouldThrow` (== PayloadTooLarge) - - it "exceeding max parm value size" $ do - SRequest req4 _bod4 <- toRequest'' ctype2 content2 - (parseRequestBodyEx ( setMaxRequestParmsSize 10 def ) lbsBackEnd req4) - `shouldThrow` (== PayloadTooLarge) - - it "exceeding max header lines" $ do - SRequest req4 _bod4 <- toRequest'' ctype2 content2 - (parseRequestBodyEx ( setMaxHeaderLines 1 def ) lbsBackEnd req4) `shouldThrow` anyException - - it "exceeding header line size" $ do - SRequest req4 _bod4 <- toRequest'' ctype3 content4 - (parseRequestBodyEx ( setMaxHeaderLineLength 8190 def ) lbsBackEnd req4) - `shouldThrow` (== RequestHeaderFieldsTooLarge) - - it "Testing parseRequestBodyEx with application/x-www-form-urlencoded" $ do - let content = "thisisalongparameterkey=andthisbeanevenlongerparametervaluehelloworldhowareyou" - let ctype = "application/x-www-form-urlencoded" - SRequest req _bod <- toRequest'' ctype content - result <- parseRequestBodyEx def lbsBackEnd req - result `shouldBe` ([( "thisisalongparameterkey" - , "andthisbeanevenlongerparametervaluehelloworldhowareyou" )], []) - - it "exceeding max parm value size with x-www-form-urlencoded mimetype" $ do - let content = "thisisalongparameterkey=andthisbeanevenlongerparametervaluehelloworldhowareyou" - let ctype = "application/x-www-form-urlencoded" - SRequest req _bod <- toRequest'' ctype content - (parseRequestBodyEx ( setMaxRequestParmsSize 10 def ) lbsBackEnd req) `shouldThrow` anyException - + it "parsing post x-www-form-urlencoded" $ do + let content1 = "foo=bar&baz=bin" + let ctype1 = "application/x-www-form-urlencoded" + result1 <- parseRequestBody' lbsBackEnd $ toRequest ctype1 content1 + result1 `shouldBe` ([("foo", "bar"), ("baz", "bin")], []) + + let ctype2 = "multipart/form-data; boundary=AaB03x" + let expectedsmap2 = + [ ("title", "A File") + , ("summary", "This is my file\nfile test") + ] + let textPlain = "text/plain; charset=iso-8859-1" + let expectedfile2 = + [("document", FileInfo "b.txt" textPlain "This is a file.\nIt has two lines.")] + let expected2 = (expectedsmap2, expectedfile2) + + it "parsing post multipart/form-data" $ do + result2 <- parseRequestBody' lbsBackEnd $ toRequest ctype2 content2 + result2 `shouldBe` expected2 + + it "parsing post multipart/form-data 2" $ do + result2' <- parseRequestBody' lbsBackEnd $ toRequest' ctype2 content2 + result2' `shouldBe` expected2 + + let ctype3 = "multipart/form-data; boundary=----WebKitFormBoundaryB1pWXPZ6lNr8RiLh" + let expectedsmap3 = [] + let expectedfile3 = + [ + ( "yaml" + , FileInfo "README" "application/octet-stream" "Photo blog using Hack.\n" + ) + ] + let expected3 = (expectedsmap3, expectedfile3) + + let def = defaultParseRequestBodyOptions + it "parsing actual post multipart/form-data" $ do + result3 <- parseRequestBody' lbsBackEnd $ toRequest ctype3 content3 + result3 `shouldBe` expected3 + + it "parsing actual post multipart/form-data 2" $ do + result3' <- parseRequestBody' lbsBackEnd $ toRequest' ctype3 content3 + result3' `shouldBe` expected3 + + it "parsing with memory limit" $ do + SRequest req4 _bod4 <- toRequest'' ctype3 content3 + result4' <- + parseRequestBodyEx + (setMaxRequestNumFiles 1 $ setMaxRequestKeyLength 14 def) + lbsBackEnd + req4 + result4' `shouldBe` expected3 + + it "exceeding number of files" $ do + SRequest req4 _bod4 <- toRequest'' ctype3 content3 + (parseRequestBodyEx (setMaxRequestNumFiles 0 def) lbsBackEnd req4) + `shouldThrow` anyException + + it "exceeding parameter length" $ do + SRequest req4 _bod4 <- toRequest'' ctype3 content3 + (parseRequestBodyEx (setMaxRequestKeyLength 2 def) lbsBackEnd req4) + `shouldThrow` anyException + + it "exceeding file size" $ do + SRequest req4 _bod4 <- toRequest'' ctype3 content3 + (parseRequestBodyEx (setMaxRequestFileSize 2 def) lbsBackEnd req4) + `shouldThrow` (== PayloadTooLarge) + + it "exceeding total file size" $ do + SRequest req4 _bod4 <- toRequest'' ctype3 content3 + (parseRequestBodyEx (setMaxRequestFilesSize 20 def) lbsBackEnd req4) + `shouldThrow` (== PayloadTooLarge) + SRequest req5 _bod5 <- toRequest'' ctype3 content5 + (parseRequestBodyEx (setMaxRequestFilesSize 20 def) lbsBackEnd req5) + `shouldThrow` (== PayloadTooLarge) + + it "exceeding max parm value size" $ do + SRequest req4 _bod4 <- toRequest'' ctype2 content2 + (parseRequestBodyEx (setMaxRequestParmsSize 10 def) lbsBackEnd req4) + `shouldThrow` (== PayloadTooLarge) + + it "exceeding max header lines" $ do + SRequest req4 _bod4 <- toRequest'' ctype2 content2 + (parseRequestBodyEx (setMaxHeaderLines 1 def) lbsBackEnd req4) + `shouldThrow` anyException + + it "exceeding header line size" $ do + SRequest req4 _bod4 <- toRequest'' ctype3 content4 + (parseRequestBodyEx (setMaxHeaderLineLength 8190 def) lbsBackEnd req4) + `shouldThrow` (== RequestHeaderFieldsTooLarge) + + it "Testing parseRequestBodyEx with application/x-www-form-urlencoded" $ do + let content = + "thisisalongparameterkey=andthisbeanevenlongerparametervaluehelloworldhowareyou" + let ctype = "application/x-www-form-urlencoded" + SRequest req _bod <- toRequest'' ctype content + result <- parseRequestBodyEx def lbsBackEnd req + result + `shouldBe` ( + [ + ( "thisisalongparameterkey" + , "andthisbeanevenlongerparametervaluehelloworldhowareyou" + ) + ] + , [] + ) + + it "exceeding max parm value size with x-www-form-urlencoded mimetype" $ do + let content = + "thisisalongparameterkey=andthisbeanevenlongerparametervaluehelloworldhowareyou" + let ctype = "application/x-www-form-urlencoded" + SRequest req _bod <- toRequest'' ctype content + (parseRequestBodyEx (setMaxRequestParmsSize 10 def) lbsBackEnd req) + `shouldThrow` anyException where content2 = - "--AaB03x\n" - <> "Content-Disposition: form-data; name=\"document\"; filename=\"b.txt\"\n" - <> "Content-Type: text/plain; charset=iso-8859-1\n\n" - <> "This is a file.\n" - <> "It has two lines.\n" - <> "--AaB03x\n" - <> "Content-Disposition: form-data; name=\"title\"\n" - <> "Content-Type: text/plain; charset=iso-8859-1\n\n" - <> "A File\n" - <> "--AaB03x\n" - <> "Content-Disposition: form-data; name=\"summary\"\n" - <> "Content-Type: text/plain; charset=iso-8859-1\n\n" - <> "This is my file\n" - <> "file test\n" - <> "--AaB03x--" + "--AaB03x\n" + <> "Content-Disposition: form-data; name=\"document\"; filename=\"b.txt\"\n" + <> "Content-Type: text/plain; charset=iso-8859-1\n\n" + <> "This is a file.\n" + <> "It has two lines.\n" + <> "--AaB03x\n" + <> "Content-Disposition: form-data; name=\"title\"\n" + <> "Content-Type: text/plain; charset=iso-8859-1\n\n" + <> "A File\n" + <> "--AaB03x\n" + <> "Content-Disposition: form-data; name=\"summary\"\n" + <> "Content-Type: text/plain; charset=iso-8859-1\n\n" + <> "This is my file\n" + <> "file test\n" + <> "--AaB03x--" content3 = - "------WebKitFormBoundaryB1pWXPZ6lNr8RiLh\r\n" - <> "Content-Disposition: form-data; name=\"yaml\"; filename=\"README\"\r\n" - <> "Content-Type: application/octet-stream\r\n\r\n" - <> "Photo blog using Hack.\n\r\n" - <> "------WebKitFormBoundaryB1pWXPZ6lNr8RiLh--\r\n" + "------WebKitFormBoundaryB1pWXPZ6lNr8RiLh\r\n" + <> "Content-Disposition: form-data; name=\"yaml\"; filename=\"README\"\r\n" + <> "Content-Type: application/octet-stream\r\n\r\n" + <> "Photo blog using Hack.\n\r\n" + <> "------WebKitFormBoundaryB1pWXPZ6lNr8RiLh--\r\n" content4 = - "------WebKitFormBoundaryB1pWXPZ6lNr8RiLh\r\n" - <> "Content-Disposition: form-data; name=\"alb\"; filename=\"README\"\r\n" - <> "Content-Type: application/octet-stream\r\n\r\n" - <> "Photo blog using Hack.\r\n\r\n" - <> "------WebKitFormBoundaryB1pWXPZ6lNr8RiLh\r\n" - <> "Content-Disposition: form-data; name=\"bla\"; filename=\"riedmi" - <> S.replicate 8190 _e <> "\"\r\n" - <> "Content-Type: application/octet-stream\r\n\r\n" - <> "Photo blog using Hack.\r\n\r\n" - <> "------WebKitFormBoundaryB1pWXPZ6lNr8RiLh--\r\n" + "------WebKitFormBoundaryB1pWXPZ6lNr8RiLh\r\n" + <> "Content-Disposition: form-data; name=\"alb\"; filename=\"README\"\r\n" + <> "Content-Type: application/octet-stream\r\n\r\n" + <> "Photo blog using Hack.\r\n\r\n" + <> "------WebKitFormBoundaryB1pWXPZ6lNr8RiLh\r\n" + <> "Content-Disposition: form-data; name=\"bla\"; filename=\"riedmi" + <> S.replicate 8190 _e + <> "\"\r\n" + <> "Content-Type: application/octet-stream\r\n\r\n" + <> "Photo blog using Hack.\r\n\r\n" + <> "------WebKitFormBoundaryB1pWXPZ6lNr8RiLh--\r\n" content5 = - "------WebKitFormBoundaryB1pWXPZ6lNr8RiLh\r\n" - <> "Content-Disposition: form-data; name=\"yaml\"; filename=\"README\"\r\n" - <> "Content-Type: application/octet-stream\r\n\r\n" - <> "Photo blog using Hack.\n\r\n" - <> "------WebKitFormBoundaryB1pWXPZ6lNr8RiLh\r\n" - <> "Content-Disposition: form-data; name=\"yaml2\"; filename=\"MEADRE\"\r\n" - <> "Content-Type: application/octet-stream\r\n\r\n" - <> "Photo blog using Hack.\n\r\n" - <> "------WebKitFormBoundaryB1pWXPZ6lNr8RiLh--\r\n" + "------WebKitFormBoundaryB1pWXPZ6lNr8RiLh\r\n" + <> "Content-Disposition: form-data; name=\"yaml\"; filename=\"README\"\r\n" + <> "Content-Type: application/octet-stream\r\n\r\n" + <> "Photo blog using Hack.\n\r\n" + <> "------WebKitFormBoundaryB1pWXPZ6lNr8RiLh\r\n" + <> "Content-Disposition: form-data; name=\"yaml2\"; filename=\"MEADRE\"\r\n" + <> "Content-Type: application/octet-stream\r\n\r\n" + <> "Photo blog using Hack.\n\r\n" + <> "------WebKitFormBoundaryB1pWXPZ6lNr8RiLh--\r\n" caseMultipartPlus :: Assertion caseMultipartPlus = do @@ -212,11 +249,11 @@ caseMultipartPlus = do result @?= ([("email", "has+plus")], []) where content = - "--AaB03x\n" <> - "Content-Disposition: form-data; name=\"email\"\n" <> - "Content-Type: text/plain; charset=iso-8859-1\n\n" <> - "has+plus\n" <> - "--AaB03x--" + "--AaB03x\n" + <> "Content-Disposition: form-data; name=\"email\"\n" + <> "Content-Type: text/plain; charset=iso-8859-1\n\n" + <> "has+plus\n" + <> "--AaB03x--" ctype = "multipart/form-data; boundary=AaB03x" caseMultipartAttrs :: Assertion @@ -225,17 +262,17 @@ caseMultipartAttrs = do result @?= ([("email", "has+plus")], []) where content = - "--AaB03x\n" <> - "Content-Disposition: form-data; name=\"email\"\n" <> - "Content-Type: text/plain; charset=iso-8859-1\n\n" <> - "has+plus\n" <> - "--AaB03x--" + "--AaB03x\n" + <> "Content-Disposition: form-data; name=\"email\"\n" + <> "Content-Type: text/plain; charset=iso-8859-1\n\n" + <> "has+plus\n" + <> "--AaB03x--" ctype = "multipart/form-data; charset=UTF-8; boundary=AaB03x" caseUrlEncPlus :: Assertion caseUrlEncPlus = do result <- runResourceT $ withInternalState $ \state -> - parseRequestBody' (tempFileBackEnd state) $ toRequest ctype content + parseRequestBody' (tempFileBackEnd state) $ toRequest ctype content result @?= ([("email", "has+plus")], []) where content = "email=has%2Bplus" @@ -253,8 +290,14 @@ dalvikHelper includeLength = do , ("REQUEST_URI", "http://192.168.1.115:3000/") , ("REQUEST_METHOD", "POST") , ("HTTP_CONNECTION", "Keep-Alive") - , ("HTTP_COOKIE", "_SESSION=fgUGM5J/k6mGAAW+MMXIJZCJHobw/oEbb6T17KQN0p9yNqiXn/m/ACrsnRjiCEgqtG4fogMUDI+jikoFGcwmPjvuD5d+MDz32iXvDdDJsFdsFMfivuey2H+n6IF6yFGD") - , ("HTTP_USER_AGENT", "Dalvik/1.1.0 (Linux; U; Android 2.1-update1; sdk Build/ECLAIR)") + , + ( "HTTP_COOKIE" + , "_SESSION=fgUGM5J/k6mGAAW+MMXIJZCJHobw/oEbb6T17KQN0p9yNqiXn/m/ACrsnRjiCEgqtG4fogMUDI+jikoFGcwmPjvuD5d+MDz32iXvDdDJsFdsFMfivuey2H+n6IF6yFGD" + ) + , + ( "HTTP_USER_AGENT" + , "Dalvik/1.1.0 (Linux; U; Android 2.1-update1; sdk Build/ECLAIR)" + ) , ("HTTP_HOST", "192.168.1.115:3000") , ("HTTP_ACCEPT", "*, */*") , ("HTTP_VERSION", "HTTP/1.1") @@ -263,9 +306,10 @@ dalvikHelper includeLength = do headers | includeLength = ("content-length", "12098") : headers' | otherwise = headers' - let request' = defaultRequest - { requestHeaders = headers - } + let request' = + defaultRequest + { requestHeaders = headers + } (params, files) <- case getRequestBodyType request' of Nothing -> return ([], []) @@ -277,18 +321,25 @@ dalvikHelper includeLength = do length files @?= 1 toRequest' :: S8.ByteString -> S8.ByteString -> SRequest -toRequest' ctype content = SRequest defaultRequest - { requestHeaders = [("Content-Type", ctype)] - } (L.fromChunks $ map S.singleton $ S.unpack content) +toRequest' ctype content = + SRequest + defaultRequest + { requestHeaders = [("Content-Type", ctype)] + } + (L.fromChunks $ map S.singleton $ S.unpack content) toRequest'' :: S8.ByteString -> S8.ByteString -> IO SRequest toRequest'' ctype content = mkRB content >>= \b -> do - let req = setRequestBodyChunks b defaultRequest {requestHeaders = [("Content-Type", ctype)]} + let req = + setRequestBodyChunks + b + defaultRequest{requestHeaders = [("Content-Type", ctype)]} return $ SRequest req (L.fromChunks $ map S.singleton $ S.unpack content) mkRB :: S8.ByteString -> IO (IO S8.ByteString) mkRB content = do r <- I.newIORef content return $ - I.atomicModifyIORef r $ \a -> (S8.empty, a) + I.atomicModifyIORef r $ + \a -> (S8.empty, a) diff --git a/wai-extra/test/Network/Wai/RequestSpec.hs b/wai-extra/test/Network/Wai/RequestSpec.hs index df1126f08..82bbd7010 100644 --- a/wai-extra/test/Network/Wai/RequestSpec.hs +++ b/wai-extra/test/Network/Wai/RequestSpec.hs @@ -1,15 +1,21 @@ {-# LANGUAGE OverloadedStrings #-} -module Network.Wai.RequestSpec - ( main - , spec - ) where +module Network.Wai.RequestSpec ( + main, + spec, +) where import Control.Exception (try) import Control.Monad (forever) import Data.ByteString (ByteString) import Network.HTTP.Types (HeaderName) -import Network.Wai (Request (..), RequestBodyLength (..), defaultRequest, getRequestBodyChunk, setRequestBodyChunks) +import Network.Wai ( + Request (..), + RequestBodyLength (..), + defaultRequest, + getRequestBodyChunk, + setRequestBodyChunks, + ) import Network.Wai.Request import Test.Hspec @@ -22,45 +28,50 @@ spec = do it "too large content length should throw RequestSizeException" $ do let limit = 1024 largeRequest = - setRequestBodyChunks (return "repeat this chunk") defaultRequest - { isSecure = False - , requestBodyLength = KnownLength (limit + 1) - } + setRequestBodyChunks + (return "repeat this chunk") + defaultRequest + { isSecure = False + , requestBodyLength = KnownLength (limit + 1) + } checkedRequest <- requestSizeCheck limit largeRequest body <- try (getRequestBodyChunk checkedRequest) case body of Left (RequestSizeException l) -> l `shouldBe` limit - Right _ -> expectationFailure "request size check failed" + Right _ -> expectationFailure "request size check failed" it "too many chunks should throw RequestSizeException" $ do let limit = 1024 largeRequest = - setRequestBodyChunks (return "repeat this chunk") defaultRequest - { isSecure = False - , requestBodyLength = ChunkedBody - } + setRequestBodyChunks + (return "repeat this chunk") + defaultRequest + { isSecure = False + , requestBodyLength = ChunkedBody + } checkedRequest <- requestSizeCheck limit largeRequest body <- try (forever $ getRequestBodyChunk checkedRequest) case body of Left (RequestSizeException l) -> l `shouldBe` limit - Right _ -> expectationFailure "request size check failed" + Right _ -> expectationFailure "request size check failed" describe "appearsSecure" $ do - let insecureRequest = defaultRequest - { isSecure = False - , requestHeaders = - [ ("HTTPS", "off") - , ("HTTP_X_FORWARDED_SSL", "off") - , ("HTTP_X_FORWARDED_SCHEME", "http") - , ("HTTP_X_FORWARDED_PROTO", "http,xyz") - ] - } + let insecureRequest = + defaultRequest + { isSecure = False + , requestHeaders = + [ ("HTTPS", "off") + , ("HTTP_X_FORWARDED_SSL", "off") + , ("HTTP_X_FORWARDED_SCHEME", "http") + , ("HTTP_X_FORWARDED_PROTO", "http,xyz") + ] + } it "returns False for an insecure request" $ insecureRequest `shouldSatisfy` not . appearsSecure it "checks if the Request is actually secure" $ do - let req = insecureRequest { isSecure = True } + let req = insecureRequest{isSecure = True} req `shouldSatisfy` appearsSecure @@ -85,8 +96,9 @@ spec = do req `shouldSatisfy` appearsSecure addHeader :: HeaderName -> ByteString -> Request -> Request -addHeader name value req = req - { requestHeaders = (name, value) : otherHeaders } - +addHeader name value req = + req + { requestHeaders = (name, value) : otherHeaders + } where otherHeaders = filter ((/= name) . fst) $ requestHeaders req diff --git a/wai-extra/test/Network/Wai/TestSpec.hs b/wai-extra/test/Network/Wai/TestSpec.hs index 1f9652776..00d85030c 100644 --- a/wai-extra/test/Network/Wai/TestSpec.hs +++ b/wai-extra/test/Network/Wai/TestSpec.hs @@ -1,4 +1,5 @@ {-# LANGUAGE OverloadedStrings #-} + module Network.Wai.TestSpec (main, spec) where import Control.Monad (void) @@ -24,192 +25,215 @@ toByteString = L8.toStrict . toLazyByteString spec :: Spec spec = do - describe "setPath" $ do - - let req = setPath defaultRequest "/foo/bar/baz?foo=23&bar=42&baz" - - it "sets pathInfo" $ do - pathInfo req `shouldBe` ["foo", "bar", "baz"] - - it "utf8 path" $ - pathInfo (setPath defaultRequest "/foo/%D7%A9%D7%9C%D7%95%D7%9D/bar") `shouldBe` - ["foo", "שלום", "bar"] - - it "sets rawPathInfo" $ do - rawPathInfo req `shouldBe` "/foo/bar/baz" - - it "sets queryString" $ do - queryString req `shouldBe` [("foo", Just "23"), ("bar", Just "42"), ("baz", Nothing)] - - it "sets rawQueryString" $ do - rawQueryString req `shouldBe` "?foo=23&bar=42&baz" - - context "when path has no query string" $ do - it "sets rawQueryString to empty string" $ do - rawQueryString (setPath defaultRequest "/foo/bar/baz") `shouldBe` "" - - describe "srequest" $ do - - let echoApp req respond = do - reqBody <- L8.fromStrict <$> getRequestBodyChunk req - let reqHeaders = requestHeaders req - respond $ - responseLBS - status200 - reqHeaders - reqBody - - it "returns the response body of an echo app" $ do - sresp <- flip runSession echoApp $ - srequest $ SRequest defaultRequest "request body" - simpleBody sresp `shouldBe` "request body" - - describe "request" $ do - - let echoApp req respond = do - reqBody <- L8.fromStrict <$> getRequestBodyChunk req - let reqHeaders = requestHeaders req - respond $ - responseLBS - status200 - reqHeaders - reqBody - - it "returns the status code of an echo app on default request" $ do - sresp <- runSession (request defaultRequest) echoApp - simpleStatus sresp `shouldBe` status200 - - it "returns the response body of an echo app" $ do - bodyRef <- IORef.newIORef "request body" - let getBodyChunk = IORef.atomicModifyIORef bodyRef $ \leftover -> ("", leftover) - sresp <- flip runSession echoApp $ - request $ - setRequestBodyChunks getBodyChunk defaultRequest - simpleBody sresp `shouldBe` "request body" - - it "returns the response headers of an echo app" $ do - sresp <- flip runSession echoApp $ - request $ - defaultRequest - { requestHeaders = [("foo", "bar")] - } - simpleHeaders sresp `shouldBe` [("foo", "bar")] - - let cookieApp req respond = - case pathInfo req of - ["set", name, val] -> - respond $ - responseLBS - status200 - [( "Set-Cookie" - , toByteString $ Cookie.renderSetCookie $ - Cookie.def { Cookie.setCookieName = TE.encodeUtf8 name - , Cookie.setCookieValue = TE.encodeUtf8 val - } - ) - ] - "set_cookie_body" - ["delete", name] -> - respond $ - responseLBS - status200 - [( "Set-Cookie" - , toByteString $ Cookie.renderSetCookie $ - Cookie.def { Cookie.setCookieName = - TE.encodeUtf8 name - , Cookie.setCookieExpires = - Just $ UTCTime (fromGregorian 1970 1 1) 0 - } - ) - ] - "set_cookie_body" - _ -> - respond $ - responseLBS - status200 - [] - ( L8.pack - $ show - $ map snd - $ filter ((=="Cookie") . fst) - $ requestHeaders req - ) - - it "sends a Cookie header with correct value after receiving a Set-Cookie header" $ do - sresp <- flip runSession cookieApp $ do - void $ request $ - setPath defaultRequest "/set/cookie_name/cookie_value" - request $ - setPath defaultRequest "/get" - simpleBody sresp `shouldBe` "[\"cookie_name=cookie_value\"]" - - it "sends a Cookie header with updated value after receiving a Set-Cookie header update" $ do - sresp <- flip runSession cookieApp $ do - void $ request $ - setPath defaultRequest "/set/cookie_name/cookie_value" - void $ request $ - setPath defaultRequest "/set/cookie_name/cookie_value2" - request $ - setPath defaultRequest "/get" - simpleBody sresp `shouldBe` "[\"cookie_name=cookie_value2\"]" - - it "handles multiple cookies" $ do - sresp <- flip runSession cookieApp $ do - void $ request $ - setPath defaultRequest "/set/cookie_name/cookie_value" - void $ request $ - setPath defaultRequest "/set/cookie_name2/cookie_value2" - request $ - setPath defaultRequest "/get" - simpleBody sresp `shouldBe` "[\"cookie_name=cookie_value;cookie_name2=cookie_value2\"]" - - it "removes a deleted cookie" $ do - sresp <- flip runSession cookieApp $ do - void $ request $ - setPath defaultRequest "/set/cookie_name/cookie_value" - void $ request $ - setPath defaultRequest "/set/cookie_name2/cookie_value2" - void $ request $ - setPath defaultRequest "/delete/cookie_name2" - request $ - setPath defaultRequest "/get" - simpleBody sresp `shouldBe` "[\"cookie_name=cookie_value\"]" - - it "sends a cookie set with setClientCookie to server" $ do - sresp <- flip runSession cookieApp $ do - setClientCookie - (Cookie.def { Cookie.setCookieName = "cookie_name" - , Cookie.setCookieValue = "cookie_value" - } - ) - request $ - setPath defaultRequest "/get" - simpleBody sresp `shouldBe` "[\"cookie_name=cookie_value\"]" - - it "sends a cookie updated with setClientCookie to server" $ do - sresp <- flip runSession cookieApp $ do - setClientCookie - (Cookie.def { Cookie.setCookieName = "cookie_name" - , Cookie.setCookieValue = "cookie_value" - } - ) - setClientCookie - (Cookie.def { Cookie.setCookieName = "cookie_name" - , Cookie.setCookieValue = "cookie_value2" - } - ) - request $ - setPath defaultRequest "/get" - simpleBody sresp `shouldBe` "[\"cookie_name=cookie_value2\"]" - - it "does not send a cookie deleted with deleteClientCookie to server" $ do - sresp <- flip runSession cookieApp $ do - setClientCookie - (Cookie.def { Cookie.setCookieName = "cookie_name" - , Cookie.setCookieValue = "cookie_value" - } - ) - deleteClientCookie "cookie_name" - request $ - setPath defaultRequest "/get" - simpleBody sresp `shouldBe` "[]" + describe "setPath" $ do + let req = setPath defaultRequest "/foo/bar/baz?foo=23&bar=42&baz" + + it "sets pathInfo" $ do + pathInfo req `shouldBe` ["foo", "bar", "baz"] + + it "utf8 path" $ + pathInfo (setPath defaultRequest "/foo/%D7%A9%D7%9C%D7%95%D7%9D/bar") + `shouldBe` ["foo", "שלום", "bar"] + + it "sets rawPathInfo" $ do + rawPathInfo req `shouldBe` "/foo/bar/baz" + + it "sets queryString" $ do + queryString req + `shouldBe` [("foo", Just "23"), ("bar", Just "42"), ("baz", Nothing)] + + it "sets rawQueryString" $ do + rawQueryString req `shouldBe` "?foo=23&bar=42&baz" + + context "when path has no query string" $ do + it "sets rawQueryString to empty string" $ do + rawQueryString (setPath defaultRequest "/foo/bar/baz") `shouldBe` "" + + describe "srequest" $ do + let echoApp req respond = do + reqBody <- L8.fromStrict <$> getRequestBodyChunk req + let reqHeaders = requestHeaders req + respond $ + responseLBS + status200 + reqHeaders + reqBody + + it "returns the response body of an echo app" $ do + sresp <- + flip runSession echoApp $ + srequest $ + SRequest defaultRequest "request body" + simpleBody sresp `shouldBe` "request body" + + describe "request" $ do + let echoApp req respond = do + reqBody <- L8.fromStrict <$> getRequestBodyChunk req + let reqHeaders = requestHeaders req + respond $ + responseLBS + status200 + reqHeaders + reqBody + + it "returns the status code of an echo app on default request" $ do + sresp <- runSession (request defaultRequest) echoApp + simpleStatus sresp `shouldBe` status200 + + it "returns the response body of an echo app" $ do + bodyRef <- IORef.newIORef "request body" + let getBodyChunk = IORef.atomicModifyIORef bodyRef $ \leftover -> ("", leftover) + sresp <- + flip runSession echoApp $ + request $ + setRequestBodyChunks getBodyChunk defaultRequest + simpleBody sresp `shouldBe` "request body" + + it "returns the response headers of an echo app" $ do + sresp <- + flip runSession echoApp $ + request $ + defaultRequest + { requestHeaders = [("foo", "bar")] + } + simpleHeaders sresp `shouldBe` [("foo", "bar")] + + let cookieApp req respond = + case pathInfo req of + ["set", name, val] -> + respond $ + responseLBS + status200 + [ + ( "Set-Cookie" + , toByteString $ + Cookie.renderSetCookie $ + Cookie.def + { Cookie.setCookieName = TE.encodeUtf8 name + , Cookie.setCookieValue = TE.encodeUtf8 val + } + ) + ] + "set_cookie_body" + ["delete", name] -> + respond $ + responseLBS + status200 + [ + ( "Set-Cookie" + , toByteString $ + Cookie.renderSetCookie $ + Cookie.def + { Cookie.setCookieName = + TE.encodeUtf8 name + , Cookie.setCookieExpires = + Just $ UTCTime (fromGregorian 1970 1 1) 0 + } + ) + ] + "set_cookie_body" + _ -> + respond $ + responseLBS + status200 + [] + ( L8.pack $ + show $ + map snd $ + filter ((== "Cookie") . fst) $ + requestHeaders req + ) + + it + "sends a Cookie header with correct value after receiving a Set-Cookie header" $ do + sresp <- flip runSession cookieApp $ do + void $ + request $ + setPath defaultRequest "/set/cookie_name/cookie_value" + request $ + setPath defaultRequest "/get" + simpleBody sresp `shouldBe` "[\"cookie_name=cookie_value\"]" + + it + "sends a Cookie header with updated value after receiving a Set-Cookie header update" $ do + sresp <- flip runSession cookieApp $ do + void $ + request $ + setPath defaultRequest "/set/cookie_name/cookie_value" + void $ + request $ + setPath defaultRequest "/set/cookie_name/cookie_value2" + request $ + setPath defaultRequest "/get" + simpleBody sresp `shouldBe` "[\"cookie_name=cookie_value2\"]" + + it "handles multiple cookies" $ do + sresp <- flip runSession cookieApp $ do + void $ + request $ + setPath defaultRequest "/set/cookie_name/cookie_value" + void $ + request $ + setPath defaultRequest "/set/cookie_name2/cookie_value2" + request $ + setPath defaultRequest "/get" + simpleBody sresp + `shouldBe` "[\"cookie_name=cookie_value;cookie_name2=cookie_value2\"]" + + it "removes a deleted cookie" $ do + sresp <- flip runSession cookieApp $ do + void $ + request $ + setPath defaultRequest "/set/cookie_name/cookie_value" + void $ + request $ + setPath defaultRequest "/set/cookie_name2/cookie_value2" + void $ + request $ + setPath defaultRequest "/delete/cookie_name2" + request $ + setPath defaultRequest "/get" + simpleBody sresp `shouldBe` "[\"cookie_name=cookie_value\"]" + + it "sends a cookie set with setClientCookie to server" $ do + sresp <- flip runSession cookieApp $ do + setClientCookie + ( Cookie.def + { Cookie.setCookieName = "cookie_name" + , Cookie.setCookieValue = "cookie_value" + } + ) + request $ + setPath defaultRequest "/get" + simpleBody sresp `shouldBe` "[\"cookie_name=cookie_value\"]" + + it "sends a cookie updated with setClientCookie to server" $ do + sresp <- flip runSession cookieApp $ do + setClientCookie + ( Cookie.def + { Cookie.setCookieName = "cookie_name" + , Cookie.setCookieValue = "cookie_value" + } + ) + setClientCookie + ( Cookie.def + { Cookie.setCookieName = "cookie_name" + , Cookie.setCookieValue = "cookie_value2" + } + ) + request $ + setPath defaultRequest "/get" + simpleBody sresp `shouldBe` "[\"cookie_name=cookie_value2\"]" + + it "does not send a cookie deleted with deleteClientCookie to server" $ do + sresp <- flip runSession cookieApp $ do + setClientCookie + ( Cookie.def + { Cookie.setCookieName = "cookie_name" + , Cookie.setCookieValue = "cookie_value" + } + ) + deleteClientCookie "cookie_name" + request $ + setPath defaultRequest "/get" + simpleBody sresp `shouldBe` "[]" diff --git a/wai-extra/test/sample.hs b/wai-extra/test/sample.hs index 553552fae..a21fffb91 100644 --- a/wai-extra/test/sample.hs +++ b/wai-extra/test/sample.hs @@ -5,22 +5,29 @@ import Data.ByteString.Lazy (fromChunks) import Data.Text () import Network.HTTP.Types import Network.Wai +import Network.Wai.Handler.Warp import Network.Wai.Middleware.Gzip import Network.Wai.Middleware.Jsonp -import Network.Wai.Handler.Warp app :: Application app request = return $ case pathInfo request of - [] -> responseLBS status200 [] - $ fromChunks $ flip map [1..10000] $ \i -> pack $ concat - [ "

Just this same paragraph again. " - , show (i :: Int) - , "

" - ] + [] -> responseLBS status200 [] $ + fromChunks $ + flip map [1 .. 10000] $ \i -> + pack $ + concat + [ "

Just this same paragraph again. " + , show (i :: Int) + , "

" + ] ["test.html"] -> ResponseFile status200 [] "test.html" Nothing - ["json"] -> ResponseFile status200 [(hContentType, "application/json")] - "json" Nothing - _ -> ResponseFile status404 [] "../LICENSE" Nothing + ["json"] -> + ResponseFile + status200 + [(hContentType, "application/json")] + "json" + Nothing + _ -> ResponseFile status404 [] "../LICENSE" Nothing main :: IO () main = run 3000 $ gzip def $ jsonp app diff --git a/wai-frontend-monadcgi/Network/Wai/Frontend/MonadCGI.hs b/wai-frontend-monadcgi/Network/Wai/Frontend/MonadCGI.hs index ca84c099a..ec696b018 100644 --- a/wai-frontend-monadcgi/Network/Wai/Frontend/MonadCGI.hs +++ b/wai-frontend-monadcgi/Network/Wai/Frontend/MonadCGI.hs @@ -1,18 +1,18 @@ -module Network.Wai.Frontend.MonadCGI - ( cgiToApp - , cgiToAppGeneric - ) where +module Network.Wai.Frontend.MonadCGI ( + cgiToApp, + cgiToAppGeneric, +) where -import Network.Wai +import Control.Monad.IO.Class (liftIO) +import Data.CaseInsensitive (original) import Network.CGI.Monad import Network.CGI.Protocol import Network.HTTP.Types (Status (..)) -import Control.Monad.IO.Class (liftIO) -import Data.CaseInsensitive (original) +import Network.Wai -import qualified Data.Map as Map -import qualified Data.ByteString.Lazy as BS import qualified Data.ByteString.Char8 as S8 +import qualified Data.ByteString.Lazy as BS +import qualified Data.Map as Map import Control.Arrow (first) import Data.Char (toUpper) @@ -20,41 +20,48 @@ import Data.String (fromString) safeRead :: Read a => a -> String -> a safeRead d s = case reads s of - ((x, _):_) -> x - _ -> d + ((x, _) : _) -> x + _ -> d cgiToApp :: CGI CGIResult -> Application cgiToApp = cgiToAppGeneric id -cgiToAppGeneric :: Monad m - => (m (Headers, CGIResult) -> IO (Headers, CGIResult)) - -> CGIT m CGIResult - -> Application +cgiToAppGeneric + :: Monad m + => (m (Headers, CGIResult) -> IO (Headers, CGIResult)) + -> CGIT m CGIResult + -> Application cgiToAppGeneric toIO cgi env sendResponse = do input <- lazyRequestBody env - let vars = map (first fixVarName . go) (requestHeaders env) - ++ getCgiVars env + let vars = + map (first fixVarName . go) (requestHeaders env) + ++ getCgiVars env (inputs, body') = decodeInput vars input - req = CGIRequest + req = + CGIRequest { cgiVars = Map.fromList vars , cgiInputs = inputs , cgiRequestBody = body' } (headers'', output') <- liftIO $ toIO $ runCGIT cgi req let output = case output' of - CGIOutput bs -> bs - CGINothing -> BS.empty - let headers' = map (\(HeaderName x, y) -> - (fromString x, S8.pack y)) headers'' + CGIOutput bs -> bs + CGINothing -> BS.empty + let headers' = + map + ( \(HeaderName x, y) -> + (fromString x, S8.pack y) + ) + headers'' let status' = case lookup (fromString "Status") headers' of - Nothing -> 200 - Just s -> safeRead 200 $ S8.unpack s + Nothing -> 200 + Just s -> safeRead 200 $ S8.unpack s sendResponse $ responseLBS (Status status' S8.empty) headers' output where go (x, y) = (S8.unpack $ original x, S8.unpack y) fixVarName :: String -> String -fixVarName = ("HTTP_"++) . map fixVarNameChar +fixVarName = ("HTTP_" ++) . map fixVarNameChar fixVarNameChar :: Char -> Char fixVarNameChar '-' = '_' @@ -64,8 +71,10 @@ getCgiVars :: Request -> [(String, String)] getCgiVars e = [ ("PATH_INFO", S8.unpack $ rawPathInfo e) , ("REQUEST_METHOD", show $ requestMethod e) - , ("QUERY_STRING", - case S8.unpack $ rawQueryString e of - '?':rest -> rest - x -> x) + , + ( "QUERY_STRING" + , case S8.unpack $ rawQueryString e of + '?' : rest -> rest + x -> x + ) ] diff --git a/wai-frontend-monadcgi/samples/wai_cgi.hs b/wai-frontend-monadcgi/samples/wai_cgi.hs index 6b8e73147..83e419639 100644 --- a/wai-frontend-monadcgi/samples/wai_cgi.hs +++ b/wai-frontend-monadcgi/samples/wai_cgi.hs @@ -1,8 +1,9 @@ import qualified Network.CGI -import qualified Network.Wai.Handler.SimpleServer import qualified Network.Wai.Frontend.MonadCGI +import qualified Network.Wai.Handler.SimpleServer -main = Network.Wai.Handler.SimpleServer.run 3000 - $ Network.Wai.Frontend.MonadCGI.cgiToApp mainCGI +main = + Network.Wai.Handler.SimpleServer.run 3000 $ + Network.Wai.Frontend.MonadCGI.cgiToApp mainCGI mainCGI = Network.CGI.output "This is a test" diff --git a/wai-frontend-monadcgi/samples/wai_cgi_generic.hs b/wai-frontend-monadcgi/samples/wai_cgi_generic.hs index 35e7dce46..1460c970c 100644 --- a/wai-frontend-monadcgi/samples/wai_cgi_generic.hs +++ b/wai-frontend-monadcgi/samples/wai_cgi_generic.hs @@ -1,14 +1,16 @@ {-# LANGUAGE PackageImports #-} + import qualified Network.CGI -import qualified Network.Wai.Handler.SimpleServer import qualified Network.Wai.Frontend.MonadCGI +import qualified Network.Wai.Handler.SimpleServer import "mtl" Control.Monad.Reader main :: IO () -main = Network.Wai.Handler.SimpleServer.run 3000 - $ Network.Wai.Frontend.MonadCGI.cgiToAppGeneric - monadToIO - mainCGI +main = + Network.Wai.Handler.SimpleServer.run 3000 $ + Network.Wai.Frontend.MonadCGI.cgiToAppGeneric + monadToIO + mainCGI mainCGI :: Network.CGI.CGIT (Reader String) Network.CGI.CGIResult mainCGI = do diff --git a/wai-frontend-monadcgi/samples/wai_fastcgi.hs b/wai-frontend-monadcgi/samples/wai_fastcgi.hs index 82321e2b5..5a624a342 100644 --- a/wai-frontend-monadcgi/samples/wai_fastcgi.hs +++ b/wai-frontend-monadcgi/samples/wai_fastcgi.hs @@ -1,8 +1,9 @@ import qualified Network.CGI -import qualified Network.Wai.Handler.FastCGI import qualified Network.Wai.Frontend.MonadCGI +import qualified Network.Wai.Handler.FastCGI -main = Network.Wai.Handler.FastCGI.run - $ Network.Wai.Frontend.MonadCGI.cgiToApp mainCGI +main = + Network.Wai.Handler.FastCGI.run $ + Network.Wai.Frontend.MonadCGI.cgiToApp mainCGI mainCGI = Network.CGI.output "This is a test" diff --git a/wai-http2-extra/Network/Wai/Middleware/Push/Referer.hs b/wai-http2-extra/Network/Wai/Middleware/Push/Referer.hs index fac40cd9e..f0cbda3eb 100644 --- a/wai-http2-extra/Network/Wai/Middleware/Push/Referer.hs +++ b/wai-http2-extra/Network/Wai/Middleware/Push/Referer.hs @@ -2,28 +2,30 @@ -- | Middleware for server push learning dependency based on Referer:. module Network.Wai.Middleware.Push.Referer ( - -- * Middleware - pushOnReferer - -- * Making push promise - , URLPath - , MakePushPromise - , defaultMakePushPromise - -- * Settings - , Settings - , M.defaultSettings - , makePushPromise - , duration - , keyLimit - , valueLimit - ) where + -- * Middleware + pushOnReferer, + + -- * Making push promise + URLPath, + MakePushPromise, + defaultMakePushPromise, + + -- * Settings + Settings, + M.defaultSettings, + makePushPromise, + duration, + keyLimit, + valueLimit, +) where import Control.Monad (when) import qualified Data.ByteString as BS import Data.Maybe (isNothing) -import Network.HTTP.Types (Status(..)) +import Network.HTTP.Types (Status (..)) import Network.Wai import Network.Wai.Handler.Warp hiding (Settings, defaultSettings) -import Network.Wai.Internal (Response(..)) +import Network.Wai.Internal (Response (..)) import qualified Network.Wai.Middleware.Push.Referer.Manager as M import Network.Wai.Middleware.Push.Referer.ParseURL @@ -41,39 +43,43 @@ pushOnReferer settings app req sendResponse = do where path = rawPathInfo req push mgr res@(ResponseFile (Status 200 "OK") _ file Nothing) - -- file: /index.html - -- path: / - -- referer: - -- refPath: - | isHTML path = do + -- file: /index.html + -- path: / + -- referer: + -- refPath: + | isHTML path = do xs <- M.lookup path mgr case xs of - [] -> return () - ps -> do - let h2d = defaultHTTP2Data { http2dataPushPromise = ps } - setHTTP2Data req $ Just h2d + [] -> return () + ps -> do + let h2d = defaultHTTP2Data{http2dataPushPromise = ps} + setHTTP2Data req $ Just h2d sendResponse res - -- file: /style.css - -- path: /style.css - -- referer: /index.html - -- refPath: / - | otherwise = case requestHeaderReferer req of - Nothing -> sendResponse res - Just referer -> do - (mauth,refPath) <- parseUrl referer - when ((isNothing mauth || requestHeaderHost req == mauth) - && path /= refPath - && isHTML refPath) $ do - let path' = BS.copy path - refPath' = BS.copy refPath - mpp <- makePushPromise settings refPath' path' file - case mpp of - Nothing -> return () - Just pp -> M.insert refPath' pp mgr - sendResponse res + -- file: /style.css + -- path: /style.css + -- referer: /index.html + -- refPath: / + | otherwise = case requestHeaderReferer req of + Nothing -> sendResponse res + Just referer -> do + (mauth, refPath) <- parseUrl referer + when + ( (isNothing mauth || requestHeaderHost req == mauth) + && path /= refPath + && isHTML refPath + ) + $ do + let path' = BS.copy path + refPath' = BS.copy refPath + mpp <- makePushPromise settings refPath' path' file + case mpp of + Nothing -> return () + Just pp -> M.insert refPath' pp mgr + sendResponse res push _ res = sendResponse res isHTML :: URLPath -> Bool -isHTML p = ("/" `BS.isSuffixOf` p) +isHTML p = + ("/" `BS.isSuffixOf` p) || (".html" `BS.isSuffixOf` p) || (".htm" `BS.isSuffixOf` p) diff --git a/wai-http2-extra/Network/Wai/Middleware/Push/Referer/LRU.hs b/wai-http2-extra/Network/Wai/Middleware/Push/Referer/LRU.hs index 0be574ec1..ec92e4774 100644 --- a/wai-http2-extra/Network/Wai/Middleware/Push/Referer/LRU.hs +++ b/wai-http2-extra/Network/Wai/Middleware/Push/Referer/LRU.hs @@ -1,15 +1,15 @@ -- from https://jaspervdj.be/posts/2015-02-24-lru-cache.html module Network.Wai.Middleware.Push.Referer.LRU ( - Cache(..) - , Priority - , empty - , insert - , lookup - ) where + Cache (..), + Priority, + empty, + insert, + lookup, +) where +import Data.Int (Int64) import Data.OrdPSQ (OrdPSQ) import qualified Data.OrdPSQ as PSQ -import Data.Int (Int64) import Prelude hiding (lookup) import Network.Wai.Middleware.Push.Referer.Multi (Multi) @@ -17,48 +17,55 @@ import qualified Network.Wai.Middleware.Push.Referer.Multi as M type Priority = Int64 -data Cache k v = Cache { - cCapacity :: Int -- ^ The maximum number of elements in the queue - , cSize :: Int -- ^ The current number of elements in the queue - , cValLimit :: Int - , cTick :: Priority -- ^ The next logical time - , cQueue :: OrdPSQ k Priority (Multi v) - } deriving (Eq, Show) +data Cache k v = Cache + { cCapacity :: Int + -- ^ The maximum number of elements in the queue + , cSize :: Int + -- ^ The current number of elements in the queue + , cValLimit :: Int + , cTick :: Priority + -- ^ The next logical time + , cQueue :: OrdPSQ k Priority (Multi v) + } + deriving (Eq, Show) empty :: Int -> Int -> Cache k v empty capacity valLimit - | capacity < 1 = error "Cache.empty: capacity < 1" - | otherwise = Cache { - cCapacity = capacity - , cSize = 0 - , cValLimit = valLimit - , cTick = 0 - , cQueue = PSQ.empty - } + | capacity < 1 = error "Cache.empty: capacity < 1" + | otherwise = + Cache + { cCapacity = capacity + , cSize = 0 + , cValLimit = valLimit + , cTick = 0 + , cQueue = PSQ.empty + } trim :: Ord k => Cache k v -> Cache k v trim c - | cTick c == maxBound = empty (cCapacity c) (cValLimit c) - | cSize c > cCapacity c = c { - cSize = cSize c - 1 - , cQueue = PSQ.deleteMin (cQueue c) - } - | otherwise = c + | cTick c == maxBound = empty (cCapacity c) (cValLimit c) + | cSize c > cCapacity c = + c + { cSize = cSize c - 1 + , cQueue = PSQ.deleteMin (cQueue c) + } + | otherwise = c insert :: (Ord k, Ord v) => k -> v -> Cache k v -> Cache k v insert k v c = case PSQ.alter lookupAndBump k (cQueue c) of - (True, q) -> trim $ c { cTick = cTick c + 1, cQueue = q, cSize = cSize c + 1} - (False, q) -> trim $ c { cTick = cTick c + 1, cQueue = q } + (True, q) -> trim $ c{cTick = cTick c + 1, cQueue = q, cSize = cSize c + 1} + (False, q) -> trim $ c{cTick = cTick c + 1, cQueue = q} where - lookupAndBump Nothing = (True, Just (cTick c, M.singleton (cValLimit c) v)) + lookupAndBump Nothing = (True, Just (cTick c, M.singleton (cValLimit c) v)) lookupAndBump (Just (_, x)) = (False, Just (cTick c, M.insert v x)) lookup :: Ord k => k -> Cache k v -> (Cache k v, [v]) lookup k c = case PSQ.alter lookupAndBump k (cQueue c) of (Nothing, _) -> (c, []) - (Just x, q) -> let c' = trim $ c { cTick = cTick c + 1, cQueue = q } - xs = M.list x - in (c', xs) + (Just x, q) -> + let c' = trim $ c{cTick = cTick c + 1, cQueue = q} + xs = M.list x + in (c', xs) where - lookupAndBump Nothing = (Nothing, Nothing) - lookupAndBump (Just (_, x)) = (Just x, Just (cTick c, x)) + lookupAndBump Nothing = (Nothing, Nothing) + lookupAndBump (Just (_, x)) = (Just x, Just (cTick c, x)) diff --git a/wai-http2-extra/Network/Wai/Middleware/Push/Referer/LimitMultiMap.hs b/wai-http2-extra/Network/Wai/Middleware/Push/Referer/LimitMultiMap.hs index 33b561297..6daf8e51e 100644 --- a/wai-http2-extra/Network/Wai/Middleware/Push/Referer/LimitMultiMap.hs +++ b/wai-http2-extra/Network/Wai/Middleware/Push/Referer/LimitMultiMap.hs @@ -7,11 +7,12 @@ import qualified Data.Map.Strict as M import Data.Set (Set) import qualified Data.Set as S -data LimitMultiMap k v = LimitMultiMap { - limitKey :: !Int +data LimitMultiMap k v = LimitMultiMap + { limitKey :: !Int , limitVal :: !Int , multiMap :: !(Map k (Set v)) - } deriving (Eq, Show) + } + deriving (Eq, Show) isEmpty :: LimitMultiMap k t -> Bool isEmpty (LimitMultiMap _ _ m) = M.null m @@ -19,22 +20,22 @@ isEmpty (LimitMultiMap _ _ m) = M.null m empty :: Int -> Int -> LimitMultiMap k v empty lk lv = LimitMultiMap lk lv M.empty -insert :: (Ord k, Ord v) => (k,v) -> LimitMultiMap k v -> LimitMultiMap k v -insert (k,v) (LimitMultiMap lk lv m) - | siz < lk = let !m' = M.alter alt k m in LimitMultiMap lk lv m' - | siz == lk = let !m' = M.adjust adj k m in LimitMultiMap lk lv m' - | otherwise = error "insert" +insert :: (Ord k, Ord v) => (k, v) -> LimitMultiMap k v -> LimitMultiMap k v +insert (k, v) (LimitMultiMap lk lv m) + | siz < lk = let !m' = M.alter alt k m in LimitMultiMap lk lv m' + | siz == lk = let !m' = M.adjust adj k m in LimitMultiMap lk lv m' + | otherwise = error "insert" where siz = M.size m - alt Nothing = Just $ S.singleton v + alt Nothing = Just $ S.singleton v alt s@(Just set) - | S.size set == lv = s - | otherwise = Just $ S.insert v set + | S.size set == lv = s + | otherwise = Just $ S.insert v set adj set - | S.size set == lv = set - | otherwise = S.insert v set + | S.size set == lv = set + | otherwise = S.insert v set lookup :: Ord k => k -> LimitMultiMap k v -> [v] lookup k (LimitMultiMap _ _ m) = case M.lookup k m of - Nothing -> [] - Just set -> S.toList set + Nothing -> [] + Just set -> S.toList set diff --git a/wai-http2-extra/Network/Wai/Middleware/Push/Referer/Manager.hs b/wai-http2-extra/Network/Wai/Middleware/Push/Referer/Manager.hs index 1917b1e9c..1bbb4696a 100644 --- a/wai-http2-extra/Network/Wai/Middleware/Push/Referer/Manager.hs +++ b/wai-http2-extra/Network/Wai/Middleware/Push/Referer/Manager.hs @@ -2,24 +2,24 @@ {-# LANGUAGE RecordWildCards #-} module Network.Wai.Middleware.Push.Referer.Manager ( - MakePushPromise - , defaultMakePushPromise - , Settings(..) - , defaultSettings - , Manager - , URLPath - , getManager - , Network.Wai.Middleware.Push.Referer.Manager.lookup - , Network.Wai.Middleware.Push.Referer.Manager.insert - ) where + MakePushPromise, + defaultMakePushPromise, + Settings (..), + defaultSettings, + Manager, + URLPath, + getManager, + Network.Wai.Middleware.Push.Referer.Manager.lookup, + Network.Wai.Middleware.Push.Referer.Manager.insert, +) where import Control.Monad (unless) import Data.IORef import Network.Wai.Handler.Warp hiding (Settings, defaultSettings) import System.IO.Unsafe (unsafePerformIO) -import Network.Wai.Middleware.Push.Referer.Types import qualified Network.Wai.Middleware.Push.Referer.LRU as LRU +import Network.Wai.Middleware.Push.Referer.Types newtype Manager = Manager (IORef (LRU.Cache URLPath PushPromise)) diff --git a/wai-http2-extra/Network/Wai/Middleware/Push/Referer/Multi.hs b/wai-http2-extra/Network/Wai/Middleware/Push/Referer/Multi.hs index ec5c9a40b..c0c80dc9e 100644 --- a/wai-http2-extra/Network/Wai/Middleware/Push/Referer/Multi.hs +++ b/wai-http2-extra/Network/Wai/Middleware/Push/Referer/Multi.hs @@ -1,14 +1,16 @@ {-# LANGUAGE RecordWildCards #-} + module Network.Wai.Middleware.Push.Referer.Multi where import Data.Set (Set) import qualified Data.Set as Set -data Multi a = Multi { - limit :: Int - , list :: [a] - , check :: Set a - } deriving (Eq, Show) +data Multi a = Multi + { limit :: Int + , list :: [a] + , check :: Set a + } + deriving (Eq, Show) empty :: Int -> Multi a empty n = Multi n [] Set.empty @@ -18,9 +20,9 @@ singleton n v = Multi n [v] $ Set.singleton v insert :: Ord a => a -> Multi a -> Multi a insert _ m@Multi{..} - | Set.size check == limit = m + | Set.size check == limit = m insert v m@Multi{..} - | Set.size check == Set.size check' = m - | otherwise = Multi limit (v:list) check' + | Set.size check == Set.size check' = m + | otherwise = Multi limit (v : list) check' where check' = Set.insert v check diff --git a/wai-http2-extra/Network/Wai/Middleware/Push/Referer/ParseURL.hs b/wai-http2-extra/Network/Wai/Middleware/Push/Referer/ParseURL.hs index 77607dbb0..5474a9bab 100644 --- a/wai-http2-extra/Network/Wai/Middleware/Push/Referer/ParseURL.hs +++ b/wai-http2-extra/Network/Wai/Middleware/Push/Referer/ParseURL.hs @@ -1,13 +1,14 @@ {-# LANGUAGE OverloadedStrings #-} + module Network.Wai.Middleware.Push.Referer.ParseURL ( - parseUrl - ) where + parseUrl, +) where import Data.ByteString (ByteString) -import Data.ByteString.Internal (ByteString(..), memchr) +import Data.ByteString.Internal (ByteString (..), memchr) import Data.Word8 -import Foreign.ForeignPtr (withForeignPtr, ForeignPtr) -import Foreign.Ptr (Ptr, plusPtr, minusPtr, nullPtr) +import Foreign.ForeignPtr (ForeignPtr, withForeignPtr) +import Foreign.Ptr (Ptr, minusPtr, nullPtr, plusPtr) import Foreign.Storable (peek) import Network.Wai.Middleware.Push.Referer.Types @@ -30,57 +31,63 @@ import Network.Wai.Middleware.Push.Referer.Types -- (Just "www.example.com:8080","/path/to/dir/") -- >>> parseUrl "/path/to/dir/" -- (Nothing,"/path/to/dir/") - parseUrl :: ByteString -> IO (Maybe ByteString, URLPath) parseUrl bs@(PS fptr0 off len) - | len == 0 = return (Nothing, "") - | len == 1 = return (Nothing, bs) - | otherwise = withForeignPtr fptr0 $ \ptr0 -> do - let begptr = ptr0 `plusPtr` off - limptr = begptr `plusPtr` len - parseUrl' fptr0 ptr0 begptr limptr len + | len == 0 = return (Nothing, "") + | len == 1 = return (Nothing, bs) + | otherwise = withForeignPtr fptr0 $ \ptr0 -> do + let begptr = ptr0 `plusPtr` off + limptr = begptr `plusPtr` len + parseUrl' fptr0 ptr0 begptr limptr len -parseUrl' :: ForeignPtr Word8 -> Ptr Word8 -> Ptr Word8 -> Ptr Word8 -> Int - -> IO (Maybe ByteString, URLPath) +parseUrl' + :: ForeignPtr Word8 + -> Ptr Word8 + -> Ptr Word8 + -> Ptr Word8 + -> Int + -> IO (Maybe ByteString, URLPath) parseUrl' fptr0 ptr0 begptr limptr len0 = do - w0 <- peek begptr - if w0 == _slash then do - w1 <- peek $ begptr `plusPtr` 1 - if w1 == _slash then - doubleSlashed begptr len0 - else - slashed begptr len0 Nothing + w0 <- peek begptr + if w0 == _slash + then do + w1 <- peek $ begptr `plusPtr` 1 + if w1 == _slash + then doubleSlashed begptr len0 + else slashed begptr len0 Nothing else do - colonptr <- memchr begptr _colon $ fromIntegral len0 - if colonptr == nullPtr then - return (Nothing, "") - else do - let authptr = colonptr `plusPtr` 1 - doubleSlashed authptr (limptr `minusPtr` authptr) + colonptr <- memchr begptr _colon $ fromIntegral len0 + if colonptr == nullPtr + then return (Nothing, "") + else do + let authptr = colonptr `plusPtr` 1 + doubleSlashed authptr (limptr `minusPtr` authptr) where -- // / ? doubleSlashed :: Ptr Word8 -> Int -> IO (Maybe ByteString, URLPath) doubleSlashed ptr len - | len < 2 = return (Nothing, "") - | otherwise = do - let ptr1 = ptr `plusPtr` 2 - pathptr <- memchr ptr1 _slash $ fromIntegral len - if pathptr == nullPtr then - return (Nothing, "") - else do - let auth = bs ptr0 ptr1 pathptr - slashed pathptr (limptr `minusPtr` pathptr) (Just auth) + | len < 2 = return (Nothing, "") + | otherwise = do + let ptr1 = ptr `plusPtr` 2 + pathptr <- memchr ptr1 _slash $ fromIntegral len + if pathptr == nullPtr + then return (Nothing, "") + else do + let auth = bs ptr0 ptr1 pathptr + slashed pathptr (limptr `minusPtr` pathptr) (Just auth) -- / ? - slashed :: Ptr Word8 -> Int -> Maybe ByteString -> IO (Maybe ByteString, URLPath) + slashed + :: Ptr Word8 -> Int -> Maybe ByteString -> IO (Maybe ByteString, URLPath) slashed ptr len mauth = do questionptr <- memchr ptr _question $ fromIntegral len - if questionptr == nullPtr then do - let path = bs ptr0 ptr limptr - return (mauth, path) - else do - let path = bs ptr0 ptr questionptr - return (mauth, path) + if questionptr == nullPtr + then do + let path = bs ptr0 ptr limptr + return (mauth, path) + else do + let path = bs ptr0 ptr questionptr + return (mauth, path) bs p0 p1 p2 = path where off = p1 `minusPtr` p0 diff --git a/wai-http2-extra/Network/Wai/Middleware/Push/Referer/Types.hs b/wai-http2-extra/Network/Wai/Middleware/Push/Referer/Types.hs index 2dd92329b..95b519937 100644 --- a/wai-http2-extra/Network/Wai/Middleware/Push/Referer/Types.hs +++ b/wai-http2-extra/Network/Wai/Middleware/Push/Referer/Types.hs @@ -1,15 +1,16 @@ {-# LANGUAGE OverloadedStrings #-} + module Network.Wai.Middleware.Push.Referer.Types ( - URLPath - , MakePushPromise - , defaultMakePushPromise - , Settings(..) - , defaultSettings - ) where + URLPath, + MakePushPromise, + defaultMakePushPromise, + Settings (..), + defaultSettings, +) where import Data.ByteString (ByteString) import qualified Data.ByteString as BS -import Network.Wai.Handler.Warp (PushPromise(..), defaultPushPromise) +import Network.Wai.Handler.Warp (PushPromise (..), defaultPushPromise) -- | Type for URL path. type URLPath = ByteString @@ -21,43 +22,55 @@ type URLPath = ByteString -- this function should return 'Just'. -- If 'Nothing' is returned, -- the middleware learns nothing. -type MakePushPromise = URLPath -- ^ path in referer (key: /index.html) - -> URLPath -- ^ path to be pushed (value: /style.css) - -> FilePath -- ^ file to be pushed (file_path/style.css) - -> IO (Maybe PushPromise) +type MakePushPromise = + URLPath + -- ^ path in referer (key: /index.html) + -> URLPath + -- ^ path to be pushed (value: /style.css) + -> FilePath + -- ^ file to be pushed (file_path/style.css) + -> IO (Maybe PushPromise) -- | Learn if the file to be pushed is CSS (.css) or JavaScript (.js) file. defaultMakePushPromise :: MakePushPromise defaultMakePushPromise refPath path file = case getCT path of - Nothing -> return Nothing - Just ct -> do - let pp = defaultPushPromise { - promisedPath = path - , promisedFile = file - , promisedResponseHeaders = [("content-type", ct) - ,("x-http2-push", refPath)] - } - return $ Just pp + Nothing -> return Nothing + Just ct -> do + let pp = + defaultPushPromise + { promisedPath = path + , promisedFile = file + , promisedResponseHeaders = + [ ("content-type", ct) + , ("x-http2-push", refPath) + ] + } + return $ Just pp getCT :: URLPath -> Maybe ByteString getCT p - | ".js" `BS.isSuffixOf` p = Just "application/javascript" - | ".css" `BS.isSuffixOf` p = Just "text/css" - | otherwise = Nothing + | ".js" `BS.isSuffixOf` p = Just "application/javascript" + | ".css" `BS.isSuffixOf` p = Just "text/css" + | otherwise = Nothing -- | Settings for server push based on Referer:. -data Settings = Settings { - makePushPromise :: MakePushPromise -- ^ Default: 'defaultMakePushPromise' - , duration :: Int -- ^ Deprecated - , keyLimit :: Int -- ^ Max number of keys (e.g. index.html) in the learning information. Default: 20 - , valueLimit :: Int -- ^ Max number of values (e.g. style.css) in the learning information. Default: 20 - } +data Settings = Settings + { makePushPromise :: MakePushPromise + -- ^ Default: 'defaultMakePushPromise' + , duration :: Int + -- ^ Deprecated + , keyLimit :: Int + -- ^ Max number of keys (e.g. index.html) in the learning information. Default: 20 + , valueLimit :: Int + -- ^ Max number of values (e.g. style.css) in the learning information. Default: 20 + } -- | Default settings. defaultSettings :: Settings -defaultSettings = Settings { - makePushPromise = defaultMakePushPromise - , duration = 0 - , keyLimit = 20 - , valueLimit = 20 - } +defaultSettings = + Settings + { makePushPromise = defaultMakePushPromise + , duration = 0 + , keyLimit = 20 + , valueLimit = 20 + } diff --git a/wai-websockets/Network/Wai/Handler/WebSockets.hs b/wai-websockets/Network/Wai/Handler/WebSockets.hs index 838732a99..7930ec7fc 100644 --- a/wai-websockets/Network/Wai/Handler/WebSockets.hs +++ b/wai-websockets/Network/Wai/Handler/WebSockets.hs @@ -1,30 +1,33 @@ {-# LANGUAGE OverloadedStrings #-} -module Network.Wai.Handler.WebSockets - ( websocketsOr - , websocketsApp - , isWebSocketsReq - , getRequestHead - , runWebSockets - ) where -import Control.Exception (bracket, tryJust) -import Data.ByteString (ByteString) -import qualified Data.ByteString.Char8 as BC -import qualified Data.ByteString.Lazy as BL -import qualified Data.CaseInsensitive as CI -import Network.HTTP.Types (status500) -import qualified Network.Wai as Wai -import qualified Network.WebSockets as WS -import qualified Network.WebSockets.Connection as WS -import qualified Network.WebSockets.Stream as WS +module Network.Wai.Handler.WebSockets ( + websocketsOr, + websocketsApp, + isWebSocketsReq, + getRequestHead, + runWebSockets, +) where + +import Control.Exception (bracket, tryJust) +import Data.ByteString (ByteString) +import qualified Data.ByteString.Char8 as BC +import qualified Data.ByteString.Lazy as BL +import qualified Data.CaseInsensitive as CI +import Network.HTTP.Types (status500) +import qualified Network.Wai as Wai +import qualified Network.WebSockets as WS +import qualified Network.WebSockets.Connection as WS +import qualified Network.WebSockets.Stream as WS -------------------------------------------------------------------------------- + -- | Returns whether or not the given 'Wai.Request' is a WebSocket request. isWebSocketsReq :: Wai.Request -> Bool isWebSocketsReq req = fmap CI.mk (lookup "upgrade" $ Wai.requestHeaders req) == Just "websocket" -------------------------------------------------------------------------------- + -- | Upgrade a @websockets@ 'WS.ServerApp' to a @wai@ 'Wai.Application'. Uses -- the given backup 'Wai.Application' to handle 'Wai.Request's that are not -- WebSocket requests. @@ -51,25 +54,28 @@ isWebSocketsReq req = -- backupApp :: 'Wai.Application' -- backupApp _ respond = respond $ 'Wai.responseLBS' 'Network.HTTP.Types.status400' [] "Not a WebSocket request" -- @ -websocketsOr :: WS.ConnectionOptions - -> WS.ServerApp - -> Wai.Application - -> Wai.Application +websocketsOr + :: WS.ConnectionOptions + -> WS.ServerApp + -> Wai.Application + -> Wai.Application websocketsOr opts app backup req sendResponse = case websocketsApp opts app req of Nothing -> backup req sendResponse Just res -> sendResponse res -------------------------------------------------------------------------------- + -- | Handle a single @wai@ 'Wai.Request' with the given @websockets@ -- 'WS.ServerApp'. Returns 'Nothing' if the 'Wai.Request' is not a WebSocket -- request, 'Just' otherwise. -- -- Usually, 'websocketsOr' is more convenient. -websocketsApp :: WS.ConnectionOptions - -> WS.ServerApp - -> Wai.Request - -> Maybe Wai.Response +websocketsApp + :: WS.ConnectionOptions + -> WS.ServerApp + -> Wai.Request + -> Maybe Wai.Response websocketsApp opts app req | isWebSocketsReq req = Just $ flip Wai.responseRaw backup $ \src sink -> @@ -77,41 +83,49 @@ websocketsApp opts app req | otherwise = Nothing where req' = getRequestHead req - backup = Wai.responseLBS status500 [("Content-Type", "text/plain")] - "The web application attempted to send a WebSockets response, but WebSockets are not supported by your WAI handler." + backup = + Wai.responseLBS + status500 + [("Content-Type", "text/plain")] + "The web application attempted to send a WebSockets response, but WebSockets are not supported by your WAI handler." -------------------------------------------------------------------------------- getRequestHead :: Wai.Request -> WS.RequestHead -getRequestHead req = WS.RequestHead - (Wai.rawPathInfo req `BC.append` Wai.rawQueryString req) - (Wai.requestHeaders req) - (Wai.isSecure req) +getRequestHead req = + WS.RequestHead + (Wai.rawPathInfo req `BC.append` Wai.rawQueryString req) + (Wai.requestHeaders req) + (Wai.isSecure req) -------------------------------------------------------------------------------- + -- | Internal function to run the WebSocket io-streams using the conduit library. -runWebSockets :: WS.ConnectionOptions - -> WS.RequestHead - -> (WS.PendingConnection -> IO a) - -> IO ByteString - -> (ByteString -> IO ()) - -> IO a +runWebSockets + :: WS.ConnectionOptions + -> WS.RequestHead + -> (WS.PendingConnection -> IO a) + -> IO ByteString + -> (ByteString -> IO ()) + -> IO a runWebSockets opts req app src sink = bracket mkStream ensureClose (app . pc) where ensureClose = tryJust onConnectionException . WS.close onConnectionException :: WS.ConnectionException -> Maybe () onConnectionException WS.ConnectionClosed = Just () - onConnectionException _ = Nothing + onConnectionException _ = Nothing mkStream = WS.makeStream - (do + ( do bs <- src - return $ if BC.null bs then Nothing else Just bs) + return $ if BC.null bs then Nothing else Just bs + ) -- 'mapM_' works here because it's 'return ()' on Nothing (mapM_ $ \bl -> mapM_ sink (BL.toChunks bl)) - pc stream = WS.PendingConnection - { WS.pendingOptions = opts - , WS.pendingRequest = req - , WS.pendingOnAccept = \_ -> return () - , WS.pendingStream = stream - } + pc stream = + WS.PendingConnection + { WS.pendingOptions = opts + , WS.pendingRequest = req + , WS.pendingOnAccept = \_ -> return () + , WS.pendingStream = stream + } diff --git a/wai/Network/Wai.hs b/wai/Network/Wai.hs index ec1625ec8..9aaa24da5 100644 --- a/wai/Network/Wai.hs +++ b/wai/Network/Wai.hs @@ -1,116 +1,127 @@ -{-| - -This module defines a generic web application interface. It is a common -protocol between web servers and web applications. - -The overriding design principles here are performance and generality. To -address performance, this library uses a streaming interface for request and -response bodies, paired with bytestring's 'Builder' type. The advantages of a -streaming API over lazy IO have been debated elsewhere and so will not be -addressed here. However, helper functions like 'responseLBS' allow you to -continue using lazy IO if you so desire. - -Generality is achieved by removing many variables commonly found in similar -projects that are not universal to all servers. The goal is that the 'Request' -object contains only data which is meaningful in all circumstances. - -Please remember when using this package that, while your application may -compile without a hitch against many different servers, there are other -considerations to be taken when moving to a new backend. For example, if you -transfer from a CGI application to a FastCGI one, you might suddenly find you -have a memory leak. Conversely, a FastCGI application would be well served to -preload all templates from disk when first starting; this would kill the -performance of a CGI application. - -This package purposely provides very little functionality. You can find various -middlewares, backends and utilities on Hackage. Some of the most commonly used -include: - -[warp] - -[wai-extra] - --} -- Ignore deprecations, because this module needs to use the deprecated requestBody to construct a response. {-# OPTIONS_GHC -fno-warn-deprecations #-} -module Network.Wai - ( - -- * Types - Application - , Middleware - , ResponseReceived - -- * Request - , Request - , defaultRequest - , RequestBodyLength (..) - -- ** Request accessors - , requestMethod - , httpVersion - , rawPathInfo - , rawQueryString - , requestHeaders - , isSecure - , remoteHost - , pathInfo - , queryString - , getRequestBodyChunk - , requestBody - , vault - , requestBodyLength - , requestHeaderHost - , requestHeaderRange - , requestHeaderReferer - , requestHeaderUserAgent + +-- | +-- +-- This module defines a generic web application interface. It is a common +-- protocol between web servers and web applications. +-- +-- The overriding design principles here are performance and generality. To +-- address performance, this library uses a streaming interface for request and +-- response bodies, paired with bytestring's 'Builder' type. The advantages of a +-- streaming API over lazy IO have been debated elsewhere and so will not be +-- addressed here. However, helper functions like 'responseLBS' allow you to +-- continue using lazy IO if you so desire. +-- +-- Generality is achieved by removing many variables commonly found in similar +-- projects that are not universal to all servers. The goal is that the 'Request' +-- object contains only data which is meaningful in all circumstances. +-- +-- Please remember when using this package that, while your application may +-- compile without a hitch against many different servers, there are other +-- considerations to be taken when moving to a new backend. For example, if you +-- transfer from a CGI application to a FastCGI one, you might suddenly find you +-- have a memory leak. Conversely, a FastCGI application would be well served to +-- preload all templates from disk when first starting; this would kill the +-- performance of a CGI application. +-- +-- This package purposely provides very little functionality. You can find various +-- middlewares, backends and utilities on Hackage. Some of the most commonly used +-- include: +-- +-- [warp] +-- +-- [wai-extra] +module Network.Wai ( + -- * Types + Application, + Middleware, + ResponseReceived, + + -- * Request + Request, + defaultRequest, + RequestBodyLength (..), + + -- ** Request accessors + requestMethod, + httpVersion, + rawPathInfo, + rawQueryString, + requestHeaders, + isSecure, + remoteHost, + pathInfo, + queryString, + getRequestBodyChunk, + requestBody, + vault, + requestBodyLength, + requestHeaderHost, + requestHeaderRange, + requestHeaderReferer, + requestHeaderUserAgent, -- $streamingRequestBodies - , strictRequestBody - , consumeRequestBodyStrict - , lazyRequestBody - , consumeRequestBodyLazy - -- ** Request modifiers - , setRequestBodyChunks - , mapRequestHeaders - -- * Response - , Response - , StreamingBody - , FilePart (..) - -- ** Response composers - , responseFile - , responseBuilder - , responseLBS - , responseStream - , responseRaw - -- ** Response accessors - , responseStatus - , responseHeaders - -- ** Response modifiers - , responseToStream - , mapResponseHeaders - , mapResponseStatus - -- * Middleware composition - , ifRequest - , modifyRequest - , modifyResponse - ) where - -import Data.ByteString.Builder (Builder, byteString, lazyByteString) -import Control.Monad (unless) -import qualified Data.ByteString as B -import qualified Data.ByteString.Lazy as L + strictRequestBody, + consumeRequestBodyStrict, + lazyRequestBody, + consumeRequestBodyLazy, + + -- ** Request modifiers + setRequestBodyChunks, + mapRequestHeaders, + + -- * Response + Response, + StreamingBody, + FilePart (..), + + -- ** Response composers + responseFile, + responseBuilder, + responseLBS, + responseStream, + responseRaw, + + -- ** Response accessors + responseStatus, + responseHeaders, + + -- ** Response modifiers + responseToStream, + mapResponseHeaders, + mapResponseStatus, + + -- * Middleware composition + ifRequest, + modifyRequest, + modifyResponse, +) where + +import Control.Monad (unless) +import qualified Data.ByteString as B +import Data.ByteString.Builder ( + Builder, + byteString, + lazyByteString, + ) +import qualified Data.ByteString.Lazy as L +import Data.ByteString.Lazy.Internal (defaultChunkSize) import qualified Data.ByteString.Lazy.Internal as LI -import Data.ByteString.Lazy.Internal (defaultChunkSize) -import Data.Function (fix) -import qualified Network.HTTP.Types as H -import Network.Socket (SockAddr (SockAddrInet)) -import Network.Wai.Internal -import qualified System.IO as IO -import System.IO.Unsafe (unsafeInterleaveIO) +import Data.Function (fix) +import qualified Network.HTTP.Types as H +import Network.Socket (SockAddr (SockAddrInet)) +import Network.Wai.Internal +import qualified System.IO as IO +import System.IO.Unsafe (unsafeInterleaveIO) ---------------------------------------------------------------- -- | Creating 'Response' from a file. -- -- @since 2.0.0 -responseFile :: H.Status -> H.ResponseHeaders -> FilePath -> Maybe FilePart -> Response +responseFile + :: H.Status -> H.ResponseHeaders -> FilePath -> Maybe FilePart -> Response responseFile = ResponseFile -- | Creating 'Response' from 'Builder'. @@ -170,10 +181,11 @@ responseLBS s h = ResponseBuilder s h . lazyByteString -- and response headers to depend on the scarce resource. -- -- @since 3.0.0 -responseStream :: H.Status - -> H.ResponseHeaders - -> StreamingBody - -> Response +responseStream + :: H.Status + -> H.ResponseHeaders + -> StreamingBody + -> Response responseStream = ResponseStream -- | Create a response for a raw application. This is useful for \"upgrade\" @@ -187,9 +199,10 @@ responseStream = ResponseStream -- @responseRaw@, behavior is undefined. -- -- @since 2.1.0 -responseRaw :: (IO B.ByteString -> (B.ByteString -> IO ()) -> IO ()) - -> Response - -> Response +responseRaw + :: (IO B.ByteString -> (B.ByteString -> IO ()) -> IO ()) + -> Response + -> Response responseRaw = ResponseRaw ---------------------------------------------------------------- @@ -198,28 +211,29 @@ responseRaw = ResponseRaw -- -- @since 1.2.0 responseStatus :: Response -> H.Status -responseStatus (ResponseFile s _ _ _) = s -responseStatus (ResponseBuilder s _ _ ) = s -responseStatus (ResponseStream s _ _ ) = s -responseStatus (ResponseRaw _ res ) = responseStatus res +responseStatus (ResponseFile s _ _ _) = s +responseStatus (ResponseBuilder s _ _) = s +responseStatus (ResponseStream s _ _) = s +responseStatus (ResponseRaw _ res) = responseStatus res -- | Accessing 'H.ResponseHeaders' in 'Response'. -- -- @since 2.0.0 responseHeaders :: Response -> H.ResponseHeaders -responseHeaders (ResponseFile _ hs _ _) = hs -responseHeaders (ResponseBuilder _ hs _ ) = hs -responseHeaders (ResponseStream _ hs _ ) = hs -responseHeaders (ResponseRaw _ res) = responseHeaders res +responseHeaders (ResponseFile _ hs _ _) = hs +responseHeaders (ResponseBuilder _ hs _) = hs +responseHeaders (ResponseStream _ hs _) = hs +responseHeaders (ResponseRaw _ res) = responseHeaders res -- | Converting the body information in 'Response' to a 'StreamingBody'. -- -- @since 3.0.0 -responseToStream :: Response - -> ( H.Status - , H.ResponseHeaders - , (StreamingBody -> IO a) -> IO a - ) +responseToStream + :: Response + -> ( H.Status + , H.ResponseHeaders + , (StreamingBody -> IO a) -> IO a + ) responseToStream (ResponseStream s h b) = (s, h, ($ b)) responseToStream (ResponseFile s h fp (Just part)) = ( s @@ -239,7 +253,7 @@ responseToStream (ResponseFile s h fp Nothing) = ( s , h , \withBody -> IO.withBinaryFile fp IO.ReadMode $ \handle -> - withBody $ \sendChunk _flush -> fix $ \loop -> do + withBody $ \sendChunk _flush -> fix $ \loop -> do bs <- B.hGetSome handle defaultChunkSize unless (B.null bs) $ do sendChunk $ byteString bs @@ -252,7 +266,8 @@ responseToStream (ResponseRaw _ res) = responseToStream res -- | Apply the provided function to the response header list of the Response. -- -- @since 3.0.3.0 -mapResponseHeaders :: (H.ResponseHeaders -> H.ResponseHeaders) -> Response -> Response +mapResponseHeaders + :: (H.ResponseHeaders -> H.ResponseHeaders) -> Response -> Response mapResponseHeaders f (ResponseFile s h b1 b2) = ResponseFile s (f h) b1 b2 mapResponseHeaders f (ResponseBuilder s h b) = ResponseBuilder s (f h) b mapResponseHeaders f (ResponseStream s h b) = ResponseStream s (f h) b @@ -282,32 +297,32 @@ mapResponseStatus _ r@(ResponseRaw _ _) = r -- (putStrLn \"Cleaning up\") -- (respond $ responseLBS status200 [] \"Hello World\") -- @ -type Application = Request -> (Response -> IO ResponseReceived) -> IO ResponseReceived - +type Application = + Request -> (Response -> IO ResponseReceived) -> IO ResponseReceived -- | A default, blank request. -- -- @since 2.0.0 defaultRequest :: Request -defaultRequest = Request - { requestMethod = H.methodGet - , httpVersion = H.http10 - , rawPathInfo = B.empty - , rawQueryString = B.empty - , requestHeaders = [] - , isSecure = False - , remoteHost = SockAddrInet 0 0 - , pathInfo = [] - , queryString = [] - , requestBody = return B.empty - , vault = mempty - , requestBodyLength = KnownLength 0 - , requestHeaderHost = Nothing - , requestHeaderRange = Nothing - , requestHeaderReferer = Nothing - , requestHeaderUserAgent = Nothing - } - +defaultRequest = + Request + { requestMethod = H.methodGet + , httpVersion = H.http10 + , rawPathInfo = B.empty + , rawQueryString = B.empty + , requestHeaders = [] + , isSecure = False + , remoteHost = SockAddrInet 0 0 + , pathInfo = [] + , queryString = [] + , requestBody = return B.empty + , vault = mempty + , requestBodyLength = KnownLength 0 + , requestHeaderHost = Nothing + , requestHeaderRange = Nothing + , requestHeaderReferer = Nothing + , requestHeaderUserAgent = Nothing + } -- | A @Middleware@ is a component that sits between the server and application. -- @@ -451,7 +466,6 @@ defaultRequest = Request -- However, modifying the response (especially the response body) is not trivial, -- so in order to get a sense of how to do it (dealing with the type of 'responseToStream'), -- it’s best to look at an example, for example . - type Middleware = Application -> Application -- | Apply a function that modifies a request as a 'Middleware' @@ -472,7 +486,7 @@ modifyResponse f app req respond = app req $ respond . f ifRequest :: (Request -> Bool) -> Middleware -> Middleware ifRequest rpred middle app req | rpred req = middle app req - | otherwise = app req + | otherwise = app req -- $streamingRequestBodies -- @@ -576,5 +590,6 @@ consumeRequestBodyLazy = lazyRequestBody -- | Apply the provided function to the request header list of the 'Request'. -- -- @since 3.2.4 -mapRequestHeaders :: (H.RequestHeaders -> H.RequestHeaders) -> Request -> Request -mapRequestHeaders f request = request { requestHeaders = f (requestHeaders request) } +mapRequestHeaders + :: (H.RequestHeaders -> H.RequestHeaders) -> Request -> Request +mapRequestHeaders f request = request{requestHeaders = f (requestHeaders request)} diff --git a/wai/Network/Wai/Internal.hs b/wai/Network/Wai/Internal.hs index 954ccde77..31d0b3f5f 100644 --- a/wai/Network/Wai/Internal.hs +++ b/wai/Network/Wai/Internal.hs @@ -1,94 +1,99 @@ -{-# OPTIONS_HADDOCK not-home #-} {-# LANGUAGE RecordWildCards #-} +{-# OPTIONS_HADDOCK not-home #-} + -- | Internal constructors and helper functions. Note that no guarantees are -- given for stability of these interfaces. module Network.Wai.Internal where -import Data.ByteString.Builder (Builder) -import qualified Data.ByteString as B -import Data.Text (Text) -import Data.Typeable (Typeable) -import Data.Vault.Lazy (Vault) -import Data.Word (Word64) -import qualified Network.HTTP.Types as H -import Network.Socket (SockAddr) -import Data.List (intercalate) +import qualified Data.ByteString as B +import Data.ByteString.Builder (Builder) +import Data.List (intercalate) +import Data.Text (Text) +import Data.Typeable (Typeable) +import Data.Vault.Lazy (Vault) +import Data.Word (Word64) +import qualified Network.HTTP.Types as H +import Network.Socket (SockAddr) -- | Information on the request sent by the client. This abstracts away the -- details of the underlying implementation. -{-# DEPRECATED requestBody "requestBody's name is misleading because it only gets a partial chunk of the body. Use getRequestBodyChunk instead when getting the field, and setRequestBodyChunks when setting the field." #-} -data Request = Request { - -- | Request method such as GET. - requestMethod :: H.Method - -- | HTTP version such as 1.1. - , httpVersion :: H.HttpVersion - -- | Extra path information sent by the client. The meaning varies slightly - -- depending on backend; in a standalone server setting, this is most likely - -- all information after the domain name. In a CGI application, this would be - -- the information following the path to the CGI executable itself. - -- - -- Middlewares and routing tools should not modify this raw value, as it may - -- be used for such things as creating redirect destinations by applications. - -- Instead, if you are writing a middleware or routing framework, modify the - -- @pathInfo@ instead. This is the approach taken by systems like Yesod - -- subsites. - -- - -- /Note/: At the time of writing this documentation, there is at least one - -- system (@Network.Wai.UrlMap@ from @wai-extra@) that does not follow the - -- above recommendation. Therefore, it is recommended that you test the - -- behavior of your application when using @rawPathInfo@ and any form of - -- library that might modify the @Request@. - , rawPathInfo :: B.ByteString - -- | If no query string was specified, this should be empty. This value - -- /will/ include the leading question mark. - -- Do not modify this raw value - modify queryString instead. - , rawQueryString :: B.ByteString - -- | A list of headers (a pair of key and value) in an HTTP request. - , requestHeaders :: H.RequestHeaders - -- | Was this request made over an SSL connection? - -- - -- Note that this value will /not/ tell you if the client originally made - -- this request over SSL, but rather whether the current connection is SSL. - -- The distinction lies with reverse proxies. In many cases, the client will - -- connect to a load balancer over SSL, but connect to the WAI handler - -- without SSL. In such a case, 'isSecure' will be 'False', but from a user - -- perspective, there is a secure connection. - , isSecure :: Bool - -- | The client\'s host information. - , remoteHost :: SockAddr - -- | Path info in individual pieces - the URL without a hostname/port and - -- without a query string, split on forward slashes. - , pathInfo :: [Text] - -- | Parsed query string information. - , queryString :: H.Query - -- | Get the next chunk of the body. Returns 'B.empty' when the - -- body is fully consumed. Since 3.2.2, this is deprecated in favor of 'getRequestBodyChunk'. - , requestBody :: IO B.ByteString - -- | A location for arbitrary data to be shared by applications and middleware. - , vault :: Vault - -- | The size of the request body. In the case of a chunked request body, - -- this may be unknown. - -- - -- @since 1.4.0 - , requestBodyLength :: RequestBodyLength - -- | The value of the Host header in a HTTP request. - -- - -- @since 2.0.0 - , requestHeaderHost :: Maybe B.ByteString - -- | The value of the Range header in a HTTP request. - -- - -- @since 2.0.0 - , requestHeaderRange :: Maybe B.ByteString - -- | The value of the Referer header in a HTTP request. - -- - -- @since 3.2.0 - , requestHeaderReferer :: Maybe B.ByteString - -- | The value of the User-Agent header in a HTTP request. - -- - -- @since 3.2.0 - , requestHeaderUserAgent :: Maybe B.ByteString - } - deriving (Typeable) +{-# DEPRECATED + requestBody + "requestBody's name is misleading because it only gets a partial chunk of the body. Use getRequestBodyChunk instead when getting the field, and setRequestBodyChunks when setting the field." + #-} + +data Request = Request + { requestMethod :: H.Method + -- ^ Request method such as GET. + , httpVersion :: H.HttpVersion + -- ^ HTTP version such as 1.1. + , rawPathInfo :: B.ByteString + -- ^ Extra path information sent by the client. The meaning varies slightly + -- depending on backend; in a standalone server setting, this is most likely + -- all information after the domain name. In a CGI application, this would be + -- the information following the path to the CGI executable itself. + -- + -- Middlewares and routing tools should not modify this raw value, as it may + -- be used for such things as creating redirect destinations by applications. + -- Instead, if you are writing a middleware or routing framework, modify the + -- @pathInfo@ instead. This is the approach taken by systems like Yesod + -- subsites. + -- + -- /Note/: At the time of writing this documentation, there is at least one + -- system (@Network.Wai.UrlMap@ from @wai-extra@) that does not follow the + -- above recommendation. Therefore, it is recommended that you test the + -- behavior of your application when using @rawPathInfo@ and any form of + -- library that might modify the @Request@. + , rawQueryString :: B.ByteString + -- ^ If no query string was specified, this should be empty. This value + -- /will/ include the leading question mark. + -- Do not modify this raw value - modify queryString instead. + , requestHeaders :: H.RequestHeaders + -- ^ A list of headers (a pair of key and value) in an HTTP request. + , isSecure :: Bool + -- ^ Was this request made over an SSL connection? + -- + -- Note that this value will /not/ tell you if the client originally made + -- this request over SSL, but rather whether the current connection is SSL. + -- The distinction lies with reverse proxies. In many cases, the client will + -- connect to a load balancer over SSL, but connect to the WAI handler + -- without SSL. In such a case, 'isSecure' will be 'False', but from a user + -- perspective, there is a secure connection. + , remoteHost :: SockAddr + -- ^ The client\'s host information. + , pathInfo :: [Text] + -- ^ Path info in individual pieces - the URL without a hostname/port and + -- without a query string, split on forward slashes. + , queryString :: H.Query + -- ^ Parsed query string information. + , requestBody :: IO B.ByteString + -- ^ Get the next chunk of the body. Returns 'B.empty' when the + -- body is fully consumed. Since 3.2.2, this is deprecated in favor of 'getRequestBodyChunk'. + , vault :: Vault + -- ^ A location for arbitrary data to be shared by applications and middleware. + , requestBodyLength :: RequestBodyLength + -- ^ The size of the request body. In the case of a chunked request body, + -- this may be unknown. + -- + -- @since 1.4.0 + , requestHeaderHost :: Maybe B.ByteString + -- ^ The value of the Host header in a HTTP request. + -- + -- @since 2.0.0 + , requestHeaderRange :: Maybe B.ByteString + -- ^ The value of the Range header in a HTTP request. + -- + -- @since 2.0.0 + , requestHeaderReferer :: Maybe B.ByteString + -- ^ The value of the Referer header in a HTTP request. + -- + -- @since 3.2.0 + , requestHeaderUserAgent :: Maybe B.ByteString + -- ^ The value of the User-Agent header in a HTTP request. + -- + -- @since 3.2.0 + } + deriving (Typeable) -- | Get the next chunk of the body. Returns 'B.empty' when the -- body is fully consumed. @@ -106,35 +111,34 @@ getRequestBodyChunk = requestBody -- @since 3.2.4 setRequestBodyChunks :: IO B.ByteString -> Request -> Request setRequestBodyChunks requestBody r = - r {requestBody = requestBody} + r{requestBody = requestBody} instance Show Request where - show Request{..} = "Request {" ++ intercalate ", " [a ++ " = " ++ b | (a,b) <- fields] ++ "}" - where - fields = - [("requestMethod",show requestMethod) - ,("httpVersion",show httpVersion) - ,("rawPathInfo",show rawPathInfo) - ,("rawQueryString",show rawQueryString) - ,("requestHeaders",show requestHeaders) - ,("isSecure",show isSecure) - ,("remoteHost",show remoteHost) - ,("pathInfo",show pathInfo) - ,("queryString",show queryString) - ,("requestBody","") - ,("vault","") - ,("requestBodyLength",show requestBodyLength) - ,("requestHeaderHost",show requestHeaderHost) - ,("requestHeaderRange",show requestHeaderRange) - ] - + show Request{..} = "Request {" ++ intercalate ", " [a ++ " = " ++ b | (a, b) <- fields] ++ "}" + where + fields = + [ ("requestMethod", show requestMethod) + , ("httpVersion", show httpVersion) + , ("rawPathInfo", show rawPathInfo) + , ("rawQueryString", show rawQueryString) + , ("requestHeaders", show requestHeaders) + , ("isSecure", show isSecure) + , ("remoteHost", show remoteHost) + , ("pathInfo", show pathInfo) + , ("queryString", show queryString) + , ("requestBody", "") + , ("vault", "") + , ("requestBodyLength", show requestBodyLength) + , ("requestHeaderHost", show requestHeaderHost) + , ("requestHeaderRange", show requestHeaderRange) + ] data Response = ResponseFile H.Status H.ResponseHeaders FilePath (Maybe FilePart) | ResponseBuilder H.Status H.ResponseHeaders Builder | ResponseStream H.Status H.ResponseHeaders StreamingBody | ResponseRaw (IO B.ByteString -> (B.ByteString -> IO ()) -> IO ()) Response - deriving Typeable + deriving (Typeable) -- | Represents a streaming HTTP response body. It's a function of two -- parameters; the first parameter provides a means of sending another chunk of @@ -148,7 +152,7 @@ type StreamingBody = (Builder -> IO ()) -> IO () -> IO () -- not be known. -- -- @since 1.4.0 -data RequestBodyLength = ChunkedBody | KnownLength Word64 deriving Show +data RequestBodyLength = ChunkedBody | KnownLength Word64 deriving (Show) -- | Information on which part to be sent. -- Sophisticated application handles Range (and If-Range) then @@ -156,10 +160,11 @@ data RequestBodyLength = ChunkedBody | KnownLength Word64 deriving Show -- -- @since 0.4.0 data FilePart = FilePart - { filePartOffset :: Integer + { filePartOffset :: Integer , filePartByteCount :: Integer - , filePartFileSize :: Integer - } deriving Show + , filePartFileSize :: Integer + } + deriving (Show) -- | A special datatype to indicate that the WAI handler has received the -- response. This is to avoid the need for Rank2Types in the definition of @@ -170,4 +175,4 @@ data FilePart = FilePart -- -- @since 3.0.0 data ResponseReceived = ResponseReceived - deriving Typeable + deriving (Typeable) diff --git a/wai/test/Network/WaiSpec.hs b/wai/test/Network/WaiSpec.hs index 503930d0b..91436fce4 100644 --- a/wai/test/Network/WaiSpec.hs +++ b/wai/test/Network/WaiSpec.hs @@ -1,15 +1,16 @@ {-# LANGUAGE LambdaCase #-} + module Network.WaiSpec (spec) where -import Test.Hspec -import Test.Hspec.QuickCheck (prop) -import Network.Wai -import Data.Word (Word8) -import Data.IORef +import Control.Monad (forM_) import qualified Data.ByteString as S -import qualified Data.ByteString.Lazy as L import Data.ByteString.Builder (Builder, toLazyByteString, word8) -import Control.Monad (forM_) +import qualified Data.ByteString.Lazy as L +import Data.IORef +import Data.Word (Word8) +import Network.Wai +import Test.Hspec +import Test.Hspec.QuickCheck (prop) spec :: Spec spec = do @@ -29,8 +30,11 @@ spec = do body <- getBody $ responseLBS undefined undefined $ L.pack bytes body `shouldBe` S.pack bytes prop "responseBuilder" $ \bytes -> do - body <- getBody $ responseBuilder undefined undefined - $ mconcat $ map word8 bytes + body <- + getBody $ + responseBuilder undefined undefined $ + mconcat $ + map word8 bytes body `shouldBe` S.pack bytes prop "responseStream" $ \chunks -> do body <- getBody $ responseStream undefined undefined $ \sendChunk _ -> @@ -47,11 +51,15 @@ spec = do let total = S.length totalBS offset = abs offset' `mod` total count = abs count' `mod` (total - offset) - body <- getBody $ responseFile undefined undefined fp $ Just FilePart - { filePartOffset = fromIntegral offset - , filePartByteCount = fromIntegral count - , filePartFileSize = fromIntegral total - } + body <- + getBody $ + responseFile undefined undefined fp $ + Just + FilePart + { filePartOffset = fromIntegral offset + , filePartByteCount = fromIntegral count + , filePartFileSize = fromIntegral total + } let expected = S.take count $ S.drop offset totalBS body `shouldBe` expected describe "lazyRequestBody" $ do @@ -76,4 +84,4 @@ mkRequestFromChunks chunks = do flip setRequestBodyChunks defaultRequest $ atomicModifyIORef ref $ \case [] -> ([], S.empty) - x:y -> (y, x) + x : y -> (y, x) diff --git a/wai/webkit-sample/webkit-sample.hs b/wai/webkit-sample/webkit-sample.hs index 0ae8e72d4..1b0816352 100644 --- a/wai/webkit-sample/webkit-sample.hs +++ b/wai/webkit-sample/webkit-sample.hs @@ -1,11 +1,12 @@ -{-# LANGUAGE Rank2Types #-} {-# LANGUAGE OverloadedStrings #-} -import Network.Wai -import Network.Wai.Handler.Webkit +{-# LANGUAGE Rank2Types #-} + import Data.ByteString (ByteString) -import qualified Data.ByteString.Lazy as L -import Data.Enumerator (consume, Iteratee) import Data.ByteString.Builder (lazyByteString) +import qualified Data.ByteString.Lazy as L +import Data.Enumerator (Iteratee, consume) +import Network.Wai +import Network.Wai.Handler.Webkit main :: IO () main = putStrLn "http://localhost:3000/" >> run "Webkit Sample" app @@ -18,13 +19,17 @@ app req = case pathInfo req of _ -> indexResponse indexResponse :: Iteratee ByteString IO Response -indexResponse = return $ ResponseFile - status200 - [("Content-Type" , "text/html")] - "index.html" +indexResponse = + return $ + ResponseFile + status200 + [("Content-Type", "text/html")] + "index.html" postResponse :: L.ByteString -> Iteratee ByteString IO Response -postResponse lbs = return $ ResponseBuilder - status200 - [("Content-Type", "text/plain")] - (lazyByteString lbs) +postResponse lbs = + return $ + ResponseBuilder + status200 + [("Content-Type", "text/plain")] + (lazyByteString lbs) diff --git a/warp-quic/Network/Wai/Handler/WarpQUIC.hs b/warp-quic/Network/Wai/Handler/WarpQUIC.hs index 60f1b5f82..44f9a9c8b 100644 --- a/warp-quic/Network/Wai/Handler/WarpQUIC.hs +++ b/warp-quic/Network/Wai/Handler/WarpQUIC.hs @@ -21,21 +21,23 @@ runQUIC :: QUICSettings -> Settings -> Application -> IO () runQUIC quicsettings settings app = do withII settings $ \ii -> Q.run quicsettings $ \conn -> do - info <- getConnectionInfo conn - mccc <- clientCertificateChain conn - let addr = remoteSockAddr info - malpn = alpn info - transport = QUIC { - quicNegotiatedProtocol = malpn - , quicChiperID = cipherID $ cipher info - , quicClientCertificate = mccc - } - pread = pReadMaker ii - timmgr = timeoutManager ii - conf = H3.Config H3.defaultHooks pread timmgr - case malpn of - Nothing -> return () - Just appProto -> do - let runX | "h3" `BS.isPrefixOf` appProto = H3.run - | otherwise = HQ.run - runX conn conf $ http2server settings ii transport addr app + info <- getConnectionInfo conn + mccc <- clientCertificateChain conn + let addr = remoteSockAddr info + malpn = alpn info + transport = + QUIC + { quicNegotiatedProtocol = malpn + , quicChiperID = cipherID $ cipher info + , quicClientCertificate = mccc + } + pread = pReadMaker ii + timmgr = timeoutManager ii + conf = H3.Config H3.defaultHooks pread timmgr + case malpn of + Nothing -> return () + Just appProto -> do + let runX + | "h3" `BS.isPrefixOf` appProto = H3.run + | otherwise = HQ.run + runX conn conf $ http2server settings ii transport addr app diff --git a/warp-quic/Setup.hs b/warp-quic/Setup.hs index 9a994af67..e8ef27dbb 100644 --- a/warp-quic/Setup.hs +++ b/warp-quic/Setup.hs @@ -1,2 +1,3 @@ import Distribution.Simple + main = defaultMain diff --git a/warp-tls/graceful-pong.hs b/warp-tls/graceful-pong.hs index 659b56a14..a1937a318 100644 --- a/warp-tls/graceful-pong.hs +++ b/warp-tls/graceful-pong.hs @@ -1,44 +1,45 @@ {-# LANGUAGE OverloadedStrings #-} + +import Blaze.ByteString.Builder (copyByteString) import Control.Concurrent -import UnliftIO.Async +import qualified Data.Conduit as C +import qualified Data.Conduit.List as CL +import Data.Monoid import Network +import Network.HTTP.Types (status200) +import Network.Socket (Socket (..), SocketStatus (..), mkSocket) import Network.Wai import Network.Wai.Handler.Warp (defaultSettings) import Network.Wai.Handler.WarpTLS -import Network.HTTP.Types (status200) -import Network.Socket (Socket(..), mkSocket, SocketStatus(..)) -import Blaze.ByteString.Builder (copyByteString) -import Data.Monoid -import qualified Data.Conduit as C -import qualified Data.Conduit.List as CL import System.Environment (getEnv) import System.Posix.Process import System.Posix.Signals +import UnliftIO.Async envSocketName = "GRACEFUL_PONG_SOCKET" fromSocket2Env sock = [(envSocketName, show (fd, addrFamily, socketType, protocolNumber))] - where - MkSocket fd addrFamily socketType protocolNumber _socketStatus = sock + where + MkSocket fd addrFamily socketType protocolNumber _socketStatus = sock handleSIGHUP proc spawn = modifyMVar_ proc $ \oldpid -> do - pid <- spawn - signalProcess lostConnection oldpid - status <- waitpid oldpid - print (oldpid, status) - return pid + pid <- spawn + signalProcess lostConnection oldpid + status <- waitpid oldpid + print (oldpid, status) + return pid waitpid pid = do mstatus <- getProcessStatus False True pid case mstatus of - Nothing -> waitpid pid - Just status -> return status + Nothing -> waitpid pid + Just status -> return status handleSIGTERM proc finish = - takeMVar proc >>= - signalProcess softwareTermination >> - putMVar finish () + takeMVar proc + >>= signalProcess softwareTermination + >> putMVar finish () child (fd, addrFamily, socketType, protocolNumber) = do sock <- mkSocket fd addrFamily socketType protocolNumber Listening @@ -46,7 +47,8 @@ child (fd, addrFamily, socketType, protocolNumber) = do finish <- newEmptyMVar let sighup = Catch $ putMVar finish () _handler <- installHandler lostConnection sighup Nothing - runTLSSocket (TLSSettings "certificate.pem" "key.pem") defaultSettings sock app `race_` takeMVar finish + runTLSSocket (TLSSettings "certificate.pem" "key.pem") defaultSettings sock app + `race_` takeMVar finish parent = do sock <- listenOn $ PortNumber $ toEnum 3000 @@ -73,53 +75,63 @@ app req = return $ "/source/nolen" -> sourceNoLen x -> index x -builderWithLen = ResponseBuilder - status200 - [ ("Content-Type", "text/plain") - , ("Content-Length", "4") - ] - $ copyByteString "PONG" - -builderNoLen = ResponseBuilder - status200 - [ ("Content-Type", "text/plain") - ] - $ copyByteString "PONG" - -sourceWithLen = ResponseSource - status200 - [ ("Content-Type", "text/plain") - , ("Content-Length", "4") - ] - $ CL.sourceList [C.Chunk $ copyByteString "PONG"] - -sourceNoLen = ResponseSource - status200 - [ ("Content-Type", "text/plain") - ] - $ CL.sourceList [C.Chunk $ copyByteString "PONG"] - -fileWithLen = ResponseFile - status200 - [ ("Content-Type", "text/plain") - , ("Content-Length", "4") - ] - "pong.txt" - Nothing - -fileNoLen = ResponseFile - status200 - [ ("Content-Type", "text/plain") - ] - "pong.txt" - Nothing - -index p = ResponseBuilder status200 [("Content-Type", "text/html")] $ mconcat $ map copyByteString - [ "

builder withlen

\n" - , "

builder nolen

\n" - , "

file withlen

\n" - , "

file nolen

\n" - , "

source withlen

\n" - , "

source nolen

\n" - , p - ] +builderWithLen = + ResponseBuilder + status200 + [ ("Content-Type", "text/plain") + , ("Content-Length", "4") + ] + $ copyByteString "PONG" + +builderNoLen = + ResponseBuilder + status200 + [ ("Content-Type", "text/plain") + ] + $ copyByteString "PONG" + +sourceWithLen = + ResponseSource + status200 + [ ("Content-Type", "text/plain") + , ("Content-Length", "4") + ] + $ CL.sourceList [C.Chunk $ copyByteString "PONG"] + +sourceNoLen = + ResponseSource + status200 + [ ("Content-Type", "text/plain") + ] + $ CL.sourceList [C.Chunk $ copyByteString "PONG"] + +fileWithLen = + ResponseFile + status200 + [ ("Content-Type", "text/plain") + , ("Content-Length", "4") + ] + "pong.txt" + Nothing + +fileNoLen = + ResponseFile + status200 + [ ("Content-Type", "text/plain") + ] + "pong.txt" + Nothing + +index p = + ResponseBuilder status200 [("Content-Type", "text/html")] $ + mconcat $ + map + copyByteString + [ "

builder withlen

\n" + , "

builder nolen

\n" + , "

file withlen

\n" + , "

file nolen

\n" + , "

source withlen

\n" + , "

source nolen

\n" + , p + ] diff --git a/warp-tls/pong.hs b/warp-tls/pong.hs index 279f40148..8477b794b 100644 --- a/warp-tls/pong.hs +++ b/warp-tls/pong.hs @@ -1,21 +1,25 @@ {-# LANGUAGE OverloadedStrings #-} -import Network.Wai -import Network.Wai.Handler.Warp (defaultSettings, setPort) -import Network.Wai.Handler.WarpTLS -import Network.HTTP.Types (status200) + import Blaze.ByteString.Builder (copyByteString) -import Data.Monoid import qualified Data.Conduit as C import qualified Data.Conduit.List as CL -import Network.HTTP.ReverseProxy +import Data.Monoid import Network.HTTP.Conduit +import Network.HTTP.ReverseProxy +import Network.HTTP.Types (status200) +import Network.Wai +import Network.Wai.Handler.Warp (defaultSettings, setPort) +import Network.Wai.Handler.WarpTLS main = do putStrLn "https://localhost:3009/" - --manager <- newManager def - runTLS (tlsSettings "certificate.pem" "key.pem") - { onInsecure = AllowInsecure - } (setPort 3009 defaultSettings) app + -- manager <- newManager def + runTLS + (tlsSettings "certificate.pem" "key.pem") + { onInsecure = AllowInsecure + } + (setPort 3009 defaultSettings) + app app req f = f $ case rawPathInfo req of @@ -28,18 +32,20 @@ app req f = f $ "/secure" -> responseLBS status200 [] $ if isSecure req then "secure" else "insecure" x -> index x -builderWithLen = responseBuilder - status200 - [ ("Content-Type", "text/plain") - , ("Content-Length", "4") - ] - $ copyByteString "PONG" +builderWithLen = + responseBuilder + status200 + [ ("Content-Type", "text/plain") + , ("Content-Length", "4") + ] + $ copyByteString "PONG" -builderNoLen = responseBuilder - status200 - [ ("Content-Type", "text/plain") - ] - $ copyByteString "PONG" +builderNoLen = + responseBuilder + status200 + [ ("Content-Type", "text/plain") + ] + $ copyByteString "PONG" sourceWithLen = responseStream status200 @@ -54,27 +60,33 @@ sourceNoLen = responseStream ] $ \sendChunk _flush -> sendChunk $ copyByteString "PONG" -fileWithLen = responseFile - status200 - [ ("Content-Type", "text/plain") - , ("Content-Length", "4") - ] - "pong.txt" - Nothing +fileWithLen = + responseFile + status200 + [ ("Content-Type", "text/plain") + , ("Content-Length", "4") + ] + "pong.txt" + Nothing -fileNoLen = responseFile - status200 - [ ("Content-Type", "text/plain") - ] - "pong.txt" - Nothing +fileNoLen = + responseFile + status200 + [ ("Content-Type", "text/plain") + ] + "pong.txt" + Nothing -index p = responseBuilder status200 [("Content-Type", "text/html")] $ mconcat $ map copyByteString - [ "

builder withlen

\n" - , "

builder nolen

\n" - , "

file withlen

\n" - , "

file nolen

\n" - , "

source withlen

\n" - , "

source nolen

\n" - , p - ] +index p = + responseBuilder status200 [("Content-Type", "text/html")] $ + mconcat $ + map + copyByteString + [ "

builder withlen

\n" + , "

builder nolen

\n" + , "

file withlen

\n" + , "

file nolen

\n" + , "

source withlen

\n" + , "

source nolen

\n" + , p + ] diff --git a/warp/Network/Wai/Handler/Warp/Buffer.hs b/warp/Network/Wai/Handler/Warp/Buffer.hs index f15d8e83f..931ea8a8c 100644 --- a/warp/Network/Wai/Handler/Warp/Buffer.hs +++ b/warp/Network/Wai/Handler/Warp/Buffer.hs @@ -1,15 +1,15 @@ module Network.Wai.Handler.Warp.Buffer ( - createWriteBuffer - , allocateBuffer - , freeBuffer - , toBuilderBuffer - , bufferIO - ) where + createWriteBuffer, + allocateBuffer, + freeBuffer, + toBuilderBuffer, + bufferIO, +) where import Data.IORef (IORef, readIORef) import qualified Data.Streaming.ByteString.Builder.Buffer as B (Buffer (..)) import Foreign.ForeignPtr -import Foreign.Marshal.Alloc (mallocBytes, free) +import Foreign.Marshal.Alloc (free, mallocBytes) import Foreign.Ptr (plusPtr) import Network.Socket.BufferPool @@ -22,13 +22,13 @@ import Network.Wai.Handler.Warp.Types -- containing that size and a finalizer. createWriteBuffer :: BufSize -> IO WriteBuffer createWriteBuffer size = do - bytes <- allocateBuffer size - return - WriteBuffer - { bufBuffer = bytes, - bufSize = size, - bufFree = freeBuffer bytes - } + bytes <- allocateBuffer size + return + WriteBuffer + { bufBuffer = bytes + , bufSize = size + , bufFree = freeBuffer bytes + } ---------------------------------------------------------------- diff --git a/warp/Network/Wai/Handler/Warp/Conduit.hs b/warp/Network/Wai/Handler/Warp/Conduit.hs index 2c6b62101..f9d846366 100644 --- a/warp/Network/Wai/Handler/Warp/Conduit.hs +++ b/warp/Network/Wai/Handler/Warp/Conduit.hs @@ -2,10 +2,10 @@ module Network.Wai.Handler.Warp.Conduit where -import UnliftIO (assert, throwIO) import qualified Data.ByteString as S import qualified Data.IORef as I import Data.Word8 (_0, _9, _A, _F, _a, _cr, _f, _lf) +import UnliftIO (assert, throwIO) import Network.Wai.Handler.Warp.Imports import Network.Wai.Handler.Warp.Types @@ -30,44 +30,44 @@ readISource (ISource src ref) = do if count == 0 then return S.empty else do - - bs <- readSource src - - -- If no chunk available, then there aren't enough bytes in the - -- stream. Throw a ConnectionClosedByPeer - when (S.null bs) $ throwIO ConnectionClosedByPeer - - let -- How many of the bytes in this chunk to send downstream - toSend = min count (S.length bs) - -- How many bytes will still remain to be sent downstream - count' = count - toSend - - I.writeIORef ref count' - - if count' > 0 then - -- The expected count is greater than the size of the - -- chunk we just read. Send the entire chunk - -- downstream, and then loop on this function for the - -- next chunk. - return bs - else do - -- Some of the bytes in this chunk should not be sent - -- downstream. Split up the chunk into the sent and - -- not-sent parts, add the not-sent parts onto the new - -- source, and send the rest of the chunk downstream. - let (x, y) = S.splitAt toSend bs - leftoverSource src y - assert (count' == 0) $ return x + bs <- readSource src + + -- If no chunk available, then there aren't enough bytes in the + -- stream. Throw a ConnectionClosedByPeer + when (S.null bs) $ throwIO ConnectionClosedByPeer + + let -- How many of the bytes in this chunk to send downstream + toSend = min count (S.length bs) + -- How many bytes will still remain to be sent downstream + count' = count - toSend + + I.writeIORef ref count' + + if count' > 0 + then -- The expected count is greater than the size of the + -- chunk we just read. Send the entire chunk + -- downstream, and then loop on this function for the + -- next chunk. + return bs + else do + -- Some of the bytes in this chunk should not be sent + -- downstream. Split up the chunk into the sent and + -- not-sent parts, add the not-sent parts onto the new + -- source, and send the rest of the chunk downstream. + let (x, y) = S.splitAt toSend bs + leftoverSource src y + assert (count' == 0) $ return x ---------------------------------------------------------------- data CSource = CSource !Source !(I.IORef ChunkState) -data ChunkState = NeedLen - | NeedLenNewline - | HaveLen Word - | DoneChunking - deriving Show +data ChunkState + = NeedLen + | NeedLenNewline + | HaveLen Word + | DoneChunking + deriving (Show) mkCSource :: Source -> IO CSource mkCSource src = do @@ -144,13 +144,14 @@ readCSource (CSource src ref) = do (x, y) | S.null y -> do bs2 <- readSource' src - return $ if S.null bs2 - then (x, y) - else S.break (== _lf) $ bs `S.append` bs2 + return $ + if S.null bs2 + then (x, y) + else S.break (== _lf) $ bs `S.append` bs2 | otherwise -> return (x, y) let w = - S.foldl' (\i c -> i * 16 + fromIntegral (hexToWord c)) 0 - $ S.takeWhile isHexDigit x + S.foldl' (\i c -> i * 16 + fromIntegral (hexToWord c)) 0 $ + S.takeWhile isHexDigit x let y' = S.drop 1 y y'' <- @@ -165,6 +166,7 @@ readCSource (CSource src ref) = do | otherwise = w - 87 isHexDigit :: Word8 -> Bool -isHexDigit w = w >= _0 && w <= _9 - || w >= _A && w <= _F - || w >= _a && w <= _f +isHexDigit w = + w >= _0 && w <= _9 + || w >= _A && w <= _F + || w >= _a && w <= _f diff --git a/warp/Network/Wai/Handler/Warp/Counter.hs b/warp/Network/Wai/Handler/Warp/Counter.hs index a01c6abf6..34dbcd976 100644 --- a/warp/Network/Wai/Handler/Warp/Counter.hs +++ b/warp/Network/Wai/Handler/Warp/Counter.hs @@ -1,18 +1,17 @@ {-# LANGUAGE CPP #-} module Network.Wai.Handler.Warp.Counter ( - Counter - , newCounter - , waitForZero - , increase - , decrease - ) where + Counter, + newCounter, + waitForZero, + increase, + decrease, +) where import Control.Concurrent.STM import Network.Wai.Handler.Warp.Imports - newtype Counter = Counter (TVar Int) newCounter :: IO Counter diff --git a/warp/Network/Wai/Handler/Warp/Date.hs b/warp/Network/Wai/Handler/Warp/Date.hs index c01975a3b..b85c8ec66 100644 --- a/warp/Network/Wai/Handler/Warp/Date.hs +++ b/warp/Network/Wai/Handler/Warp/Date.hs @@ -1,11 +1,11 @@ {-# LANGUAGE CPP #-} module Network.Wai.Handler.Warp.Date ( - withDateCache - , GMTDate - ) where + withDateCache, + GMTDate, +) where -import Control.AutoUpdate (defaultUpdateSettings, updateAction, mkAutoUpdate) +import Control.AutoUpdate (defaultUpdateSettings, mkAutoUpdate, updateAction) import Data.ByteString import Network.HTTP.Date @@ -25,9 +25,11 @@ withDateCache :: (IO GMTDate -> IO a) -> IO a withDateCache action = initialize >>= action initialize :: IO (IO GMTDate) -initialize = mkAutoUpdate defaultUpdateSettings { - updateAction = formatHTTPDate <$> getCurrentHTTPDate - } +initialize = + mkAutoUpdate + defaultUpdateSettings + { updateAction = formatHTTPDate <$> getCurrentHTTPDate + } #ifdef WINDOWS uToH :: UTCTime -> HTTPDate diff --git a/warp/Network/Wai/Handler/Warp/File.hs b/warp/Network/Wai/Handler/Warp/File.hs index b27084837..8da87dde1 100644 --- a/warp/Network/Wai/Handler/Warp/File.hs +++ b/warp/Network/Wai/Handler/Warp/File.hs @@ -1,13 +1,13 @@ +{-# LANGUAGE BangPatterns #-} {-# LANGUAGE CPP #-} {-# LANGUAGE OverloadedStrings #-} -{-# LANGUAGE BangPatterns #-} module Network.Wai.Handler.Warp.File ( - RspFileInfo(..) - , conditionalRequest - , addContentHeadersForFilePart - , H.parseByteRanges - ) where + RspFileInfo (..), + conditionalRequest, + addContentHeadersForFilePart, + H.parseByteRanges, +) where import Data.Array ((!)) import qualified Data.ByteString.Char8 as C8 (pack) @@ -21,36 +21,39 @@ import Network.Wai.Handler.Warp.Header import Network.Wai.Handler.Warp.Imports import Network.Wai.Handler.Warp.PackInt - -- $setup -- >>> import Test.QuickCheck ---------------------------------------------------------------- -data RspFileInfo = WithoutBody H.Status - | WithBody H.Status H.ResponseHeaders Integer Integer - deriving (Eq,Show) +data RspFileInfo + = WithoutBody H.Status + | WithBody H.Status H.ResponseHeaders Integer Integer + deriving (Eq, Show) ---------------------------------------------------------------- -conditionalRequest :: I.FileInfo - -> H.ResponseHeaders - -> H.Method - -> IndexedHeader -- ^ Response - -> IndexedHeader -- ^ Request - -> RspFileInfo +conditionalRequest + :: I.FileInfo + -> H.ResponseHeaders + -> H.Method + -> IndexedHeader + -- ^ Response + -> IndexedHeader + -- ^ Request + -> RspFileInfo conditionalRequest finfo hs0 method rspidx reqidx = case condition of nobody@(WithoutBody _) -> nobody - WithBody s _ off len -> + WithBody s _ off len -> let !hs1 = addContentHeaders hs0 off len size !hs = case rspidx ! fromEnum ResLastModified of Just _ -> hs1 - Nothing -> (H.hLastModified,date) : hs1 - in WithBody s hs off len + Nothing -> (H.hLastModified, date) : hs1 + in WithBody s hs off len where !mtime = I.fileInfoTime finfo - !size = I.fileInfoSize finfo - !date = I.fileInfoDate finfo + !size = I.fileInfoSize finfo + !date = I.fileInfoDate finfo -- According to RFC 9110: -- "A recipient cache or origin server MUST evaluate the request -- preconditions defined by this specification in the following order: @@ -64,9 +67,10 @@ conditionalRequest finfo hs0 method rspidx reqidx = case condition of -- we also don't want to block middleware or applications from -- using ETags. And sending If-(None-)Match headers in a request -- to a server that doesn't use them is requester's problem. - !mcondition = ifunmodified reqidx mtime - <|> ifmodified reqidx mtime method - <|> ifrange reqidx mtime method size + !mcondition = + ifunmodified reqidx mtime + <|> ifmodified reqidx mtime method + <|> ifrange reqidx mtime method size !condition = fromMaybe (unconditional reqidx size) mcondition ---------------------------------------------------------------- @@ -112,7 +116,7 @@ ifrange reqidx mtime method size = do -- "When the method is GET and both Range and If-Range are -- present, evaluate the If-Range precondition:" date <- ifRange reqidx - rng <- reqidx ! fromEnum ReqRange + rng <- reqidx ! fromEnum ReqRange guard $ method == H.methodGet return $ if date == mtime @@ -122,27 +126,28 @@ ifrange reqidx mtime method size = do unconditional :: IndexedHeader -> Integer -> RspFileInfo unconditional reqidx = case reqidx ! fromEnum ReqRange of - Nothing -> WithBody H.ok200 [] 0 + Nothing -> WithBody H.ok200 [] 0 Just rng -> parseRange rng ---------------------------------------------------------------- parseRange :: ByteString -> Integer -> RspFileInfo parseRange rng size = case H.parseByteRanges rng of - Nothing -> WithoutBody H.requestedRangeNotSatisfiable416 - Just [] -> WithoutBody H.requestedRangeNotSatisfiable416 - Just (r:_) -> let (!beg, !end) = checkRange r size - !len = end - beg + 1 - s = if beg == 0 && end == size - 1 then - H.ok200 - else - H.partialContent206 - in WithBody s [] beg len + Nothing -> WithoutBody H.requestedRangeNotSatisfiable416 + Just [] -> WithoutBody H.requestedRangeNotSatisfiable416 + Just (r : _) -> + let (!beg, !end) = checkRange r size + !len = end - beg + 1 + s = + if beg == 0 && end == size - 1 + then H.ok200 + else H.partialContent206 + in WithBody s [] beg len checkRange :: H.ByteRange -> Integer -> (Integer, Integer) -checkRange (H.ByteRangeFrom beg) size = (beg, size - 1) -checkRange (H.ByteRangeFromTo beg end) size = (beg, min (size - 1) end) -checkRange (H.ByteRangeSuffix count) size = (max 0 (size - count), size - 1) +checkRange (H.ByteRangeFrom beg) size = (beg, size - 1) +checkRange (H.ByteRangeFromTo beg end) size = (beg, min (size - 1) end) +checkRange (H.ByteRangeSuffix count) size = (max 0 (size - count), size - 1) ---------------------------------------------------------------- @@ -151,24 +156,37 @@ checkRange (H.ByteRangeSuffix count) size = (max 0 (size - count), size - 1) contentRangeHeader :: Integer -> Integer -> Integer -> H.Header contentRangeHeader beg end total = (H.hContentRange, range) where - range = C8.pack - -- building with ShowS - $ 'b' : 'y': 't' : 'e' : 's' : ' ' - : (if beg > end then ('*':) else - showInt beg - . ('-' :) - . showInt end) - ( '/' - : showInt total "") - -addContentHeaders :: H.ResponseHeaders -> Integer -> Integer -> Integer -> H.ResponseHeaders + range = + C8.pack + -- building with ShowS + $ + 'b' + : 'y' + : 't' + : 'e' + : 's' + : ' ' + : ( if beg > end + then ('*' :) + else + showInt beg + . ('-' :) + . showInt end + ) + ( '/' + : showInt total "" + ) + +addContentHeaders + :: H.ResponseHeaders -> Integer -> Integer -> Integer -> H.ResponseHeaders addContentHeaders hs off len size - | len == size = hs' - | otherwise = let !ctrng = contentRangeHeader off (off + len - 1) size - in ctrng:hs' + | len == size = hs' + | otherwise = + let !ctrng = contentRangeHeader off (off + len - 1) size + in ctrng : hs' where !lengthBS = packIntegral len - !hs' = (H.hContentLength, lengthBS) : (H.hAcceptRanges,"bytes") : hs + !hs' = (H.hContentLength, lengthBS) : (H.hAcceptRanges, "bytes") : hs -- | -- @@ -176,7 +194,8 @@ addContentHeaders hs off len size -- [("Content-Range","bytes 2-11/16"),("Content-Length","10"),("Accept-Ranges","bytes")] -- >>> addContentHeadersForFilePart [] (FilePart 0 16 16) -- [("Content-Length","16"),("Accept-Ranges","bytes")] -addContentHeadersForFilePart :: H.ResponseHeaders -> FilePart -> H.ResponseHeaders +addContentHeadersForFilePart + :: H.ResponseHeaders -> FilePart -> H.ResponseHeaders addContentHeadersForFilePart hs part = addContentHeaders hs off len size where off = filePartOffset part diff --git a/warp/Network/Wai/Handler/Warp/FileInfoCache.hs b/warp/Network/Wai/Handler/Warp/FileInfoCache.hs index 8fa4fa90f..26e2d388b 100644 --- a/warp/Network/Wai/Handler/Warp/FileInfoCache.hs +++ b/warp/Network/Wai/Handler/Warp/FileInfoCache.hs @@ -1,10 +1,11 @@ -{-# LANGUAGE RecordWildCards, CPP #-} +{-# LANGUAGE CPP #-} +{-# LANGUAGE RecordWildCards #-} module Network.Wai.Handler.Warp.FileInfoCache ( - FileInfo(..) - , withFileInfoCache - , getInfo -- test purpose only - ) where + FileInfo (..), + withFileInfoCache, + getInfo, -- test purpose only +) where import Control.Reaper import Network.HTTP.Date @@ -13,7 +14,7 @@ import System.PosixCompat.Files #else import System.Posix.Files #endif -import qualified UnliftIO (onException, bracket, throwIO) +import qualified UnliftIO (bracket, onException, throwIO) import Network.Wai.Handler.Warp.HashMap (HashMap) import qualified Network.Wai.Handler.Warp.HashMap as M @@ -22,16 +23,19 @@ import Network.Wai.Handler.Warp.Imports ---------------------------------------------------------------- -- | File information. -data FileInfo = FileInfo { - fileInfoName :: !FilePath - , fileInfoSize :: !Integer - , fileInfoTime :: HTTPDate -- ^ Modification time - , fileInfoDate :: ByteString -- ^ Modification time in the GMT format - } deriving (Eq, Show) +data FileInfo = FileInfo + { fileInfoName :: !FilePath + , fileInfoSize :: !Integer + , fileInfoTime :: HTTPDate + -- ^ Modification time + , fileInfoDate :: ByteString + -- ^ Modification time in the GMT format + } + deriving (Eq, Show) data Entry = Negative | Positive FileInfo type Cache = HashMap Entry -type FileInfoCache = Reaper Cache (FilePath,Entry) +type FileInfoCache = Reaper Cache (FilePath, Entry) ---------------------------------------------------------------- @@ -41,19 +45,20 @@ getInfo path = do fs <- getFileStatus path -- file access let regular = not (isDirectory fs) readable = fileMode fs `intersectFileModes` ownerReadMode /= 0 - if regular && readable then do - let time = epochTimeToHTTPDate $ modificationTime fs - date = formatHTTPDate time - size = fromIntegral $ fileSize fs - info = FileInfo { - fileInfoName = path - , fileInfoSize = size - , fileInfoTime = time - , fileInfoDate = date - } - return info - else - UnliftIO.throwIO (userError "FileInfoCache:getInfo") + if regular && readable + then do + let time = epochTimeToHTTPDate $ modificationTime fs + date = formatHTTPDate time + size = fromIntegral $ fileSize fs + info = + FileInfo + { fileInfoName = path + , fileInfoSize = size + , fileInfoTime = time + , fileInfoDate = date + } + return info + else UnliftIO.throwIO (userError "FileInfoCache:getInfo") getInfoNaive :: FilePath -> IO FileInfo getInfoNaive = getInfo @@ -64,10 +69,11 @@ getAndRegisterInfo :: FileInfoCache -> FilePath -> IO FileInfo getAndRegisterInfo reaper@Reaper{..} path = do cache <- reaperRead case M.lookup path cache of - Just Negative -> UnliftIO.throwIO (userError "FileInfoCache:getAndRegisterInfo") + Just Negative -> UnliftIO.throwIO (userError "FileInfoCache:getAndRegisterInfo") Just (Positive x) -> return x - Nothing -> positive reaper path - `UnliftIO.onException` negative reaper path + Nothing -> + positive reaper path + `UnliftIO.onException` negative reaper path positive :: FileInfoCache -> FilePath -> IO FileInfo positive Reaper{..} path = do @@ -85,26 +91,28 @@ negative Reaper{..} path = do -- | Creating a file information cache -- and executing the action in the second argument. -- The first argument is a cache duration in second. -withFileInfoCache :: Int - -> ((FilePath -> IO FileInfo) -> IO a) - -> IO a -withFileInfoCache 0 action = action getInfoNaive +withFileInfoCache + :: Int + -> ((FilePath -> IO FileInfo) -> IO a) + -> IO a +withFileInfoCache 0 action = action getInfoNaive withFileInfoCache duration action = UnliftIO.bracket - (initialize duration) - terminate - (action . getAndRegisterInfo) + (initialize duration) + terminate + (action . getAndRegisterInfo) initialize :: Int -> IO FileInfoCache initialize duration = mkReaper settings where - settings = defaultReaperSettings { - reaperAction = override - , reaperDelay = duration - , reaperCons = \(path,v) -> M.insert path v - , reaperNull = M.isEmpty - , reaperEmpty = M.empty - } + settings = + defaultReaperSettings + { reaperAction = override + , reaperDelay = duration + , reaperCons = \(path, v) -> M.insert path v + , reaperNull = M.isEmpty + , reaperEmpty = M.empty + } override :: Cache -> IO (Cache -> Cache) override _ = return $ const M.empty diff --git a/warp/Network/Wai/Handler/Warp/HTTP1.hs b/warp/Network/Wai/Handler/Warp/HTTP1.hs index c60664777..0d16d2423 100644 --- a/warp/Network/Wai/Handler/Warp/HTTP1.hs +++ b/warp/Network/Wai/Handler/Warp/HTTP1.hs @@ -5,21 +5,21 @@ {-# LANGUAGE ScopedTypeVariables #-} module Network.Wai.Handler.Warp.HTTP1 ( - http1 - ) where + http1, +) where -import "iproute" Data.IP (toHostAddress, toHostAddress6) import qualified Control.Concurrent as Conc (yield) -import qualified UnliftIO -import UnliftIO (SomeException, fromException, throwIO) import qualified Data.ByteString as BS import Data.Char (chr) import Data.IORef (IORef, newIORef, readIORef, writeIORef) import Data.Word8 (_cr, _space) -import Network.Socket (SockAddr(SockAddrInet, SockAddrInet6)) +import Network.Socket (SockAddr (SockAddrInet, SockAddrInet6)) import Network.Wai import Network.Wai.Internal (ResponseReceived (ResponseReceived)) import qualified System.TimeManager as T +import UnliftIO (SomeException, fromException, throwIO) +import qualified UnliftIO +import "iproute" Data.IP (toHostAddress, toHostAddress6) import Network.Wai.Handler.Warp.Header import Network.Wai.Handler.Warp.Imports hiding (readInt) @@ -29,7 +29,16 @@ import Network.Wai.Handler.Warp.Response import Network.Wai.Handler.Warp.Settings import Network.Wai.Handler.Warp.Types -http1 :: Settings -> InternalInfo -> Connection -> Transport -> Application -> SockAddr -> T.Handle -> ByteString -> IO () +http1 + :: Settings + -> InternalInfo + -> Connection + -> Transport + -> Application + -> SockAddr + -> T.Handle + -> ByteString + -> IO () http1 settings ii conn transport app origAddr th bs0 = do istatus <- newIORef True src <- mkSource (wrappedRecv conn istatus (settingsSlowlorisSize settings)) @@ -37,7 +46,7 @@ http1 settings ii conn transport app origAddr th bs0 = do addr <- getProxyProtocolAddr src http1server settings ii conn transport app addr th istatus src where - wrappedRecv Connection { connRecv = recv } istatus slowlorisSize = do + wrappedRecv Connection{connRecv = recv} istatus slowlorisSize = do bs <- recv unless (BS.null bs) $ do writeIORef istatus True @@ -55,56 +64,97 @@ http1 settings ii conn transport app origAddr th bs0 = do seg <- readSource src if BS.isPrefixOf "PROXY " seg then parseProxyProtocolHeader src seg - else do leftoverSource src seg - return origAddr + else do + leftoverSource src seg + return origAddr parseProxyProtocolHeader src seg = do - let (header,seg') = BS.break (== _cr) seg + let (header, seg') = BS.break (== _cr) seg maybeAddr = case BS.split _space header of - ["PROXY","TCP4",clientAddr,_,clientPort,_] -> + ["PROXY", "TCP4", clientAddr, _, clientPort, _] -> case [x | (x, t) <- reads (decodeAscii clientAddr), null t] of - [a] -> Just (SockAddrInet (readInt clientPort) - (toHostAddress a)) + [a] -> + Just + ( SockAddrInet + (readInt clientPort) + (toHostAddress a) + ) _ -> Nothing - ["PROXY","TCP6",clientAddr,_,clientPort,_] -> + ["PROXY", "TCP6", clientAddr, _, clientPort, _] -> case [x | (x, t) <- reads (decodeAscii clientAddr), null t] of - [a] -> Just (SockAddrInet6 (readInt clientPort) - 0 - (toHostAddress6 a) - 0) + [a] -> + Just + ( SockAddrInet6 + (readInt clientPort) + 0 + (toHostAddress6 a) + 0 + ) _ -> Nothing - ("PROXY":"UNKNOWN":_) -> + ("PROXY" : "UNKNOWN" : _) -> Just origAddr _ -> Nothing case maybeAddr of Nothing -> throwIO (BadProxyHeader (decodeAscii header)) - Just a -> do leftoverSource src (BS.drop 2 seg') -- drop CRLF - return a + Just a -> do + leftoverSource src (BS.drop 2 seg') -- drop CRLF + return a decodeAscii = map (chr . fromEnum) . BS.unpack -http1server :: Settings -> InternalInfo -> Connection -> Transport -> Application -> SockAddr -> T.Handle -> IORef Bool -> Source -> IO () +http1server + :: Settings + -> InternalInfo + -> Connection + -> Transport + -> Application + -> SockAddr + -> T.Handle + -> IORef Bool + -> Source + -> IO () http1server settings ii conn transport app addr th istatus src = loop True `UnliftIO.catchAny` handler where handler e - -- See comment below referencing - -- https://github.com/yesodweb/wai/issues/618 - | Just NoKeepAliveRequest <- fromException e = return () - -- No valid request - | Just (BadFirstLine _) <- fromException e = return () - | otherwise = do - _ <- sendErrorResponse settings ii conn th istatus defaultRequest { remoteHost = addr } e - throwIO e + -- See comment below referencing + -- https://github.com/yesodweb/wai/issues/618 + | Just NoKeepAliveRequest <- fromException e = return () + -- No valid request + | Just (BadFirstLine _) <- fromException e = return () + | otherwise = do + _ <- + sendErrorResponse + settings + ii + conn + th + istatus + defaultRequest{remoteHost = addr} + e + throwIO e loop firstRequest = do - (req, mremainingRef, idxhdr, nextBodyFlush) <- recvRequest firstRequest settings conn ii th addr src transport - keepAlive <- processRequest settings ii conn app th istatus src req mremainingRef idxhdr nextBodyFlush - `UnliftIO.catchAny` \e -> do - settingsOnException settings (Just req) e - -- Don't throw the error again to prevent calling settingsOnException twice. - return False + (req, mremainingRef, idxhdr, nextBodyFlush) <- + recvRequest firstRequest settings conn ii th addr src transport + keepAlive <- + processRequest + settings + ii + conn + app + th + istatus + src + req + mremainingRef + idxhdr + nextBodyFlush + `UnliftIO.catchAny` \e -> do + settingsOnException settings (Just req) e + -- Don't throw the error again to prevent calling settingsOnException twice. + return False -- When doing a keep-alive connection, the other side may just -- close the connection. We don't want to treat that as an @@ -117,7 +167,19 @@ http1server settings ii conn transport app addr th istatus src = when keepAlive $ loop False -processRequest :: Settings -> InternalInfo -> Connection -> Application -> T.Handle -> IORef Bool -> Source -> Request -> Maybe (IORef Int) -> IndexedHeader -> IO ByteString -> IO Bool +processRequest + :: Settings + -> InternalInfo + -> Connection + -> Application + -> T.Handle + -> IORef Bool + -> Source + -> Request + -> Maybe (IORef Int) + -> IndexedHeader + -> IO ByteString + -> IO Bool processRequest settings ii conn app th istatus src req mremainingRef idxhdr nextBodyFlush = do -- Let the application run for as long as it wants T.pause th @@ -138,8 +200,8 @@ processRequest settings ii conn app th istatus src req mremainingRef idxhdr next case r of Right ResponseReceived -> return () Left (e :: SomeException) - | Just (ExceptionInsideResponseBody e') <- fromException e -> throwIO e' - | otherwise -> do + | Just (ExceptionInsideResponseBody e') <- fromException e -> throwIO e' + | otherwise -> do keepAlive <- sendErrorResponse settings ii conn th istatus req e settingsOnException settings (Just req) e writeIORef keepAliveRef keepAlive @@ -157,8 +219,7 @@ processRequest settings ii conn app th istatus src req mremainingRef idxhdr next Conc.yield if keepAlive - then - -- If there is an unknown or large amount of data to still be read + then -- If there is an unknown or large amount of data to still be read -- from the request body, simple drop this connection instead of -- reading it all in to satisfy a keep-alive request. case settingsMaximumBodyFlush settings of @@ -170,33 +231,47 @@ processRequest settings ii conn app th istatus src req mremainingRef idxhdr next let tryKeepAlive = do -- flush the rest of the request body isComplete <- flushBody nextBodyFlush maxToRead - if isComplete then do - T.resume th - return True - else - return False + if isComplete + then do + T.resume th + return True + else return False case mremainingRef of Just ref -> do remaining <- readIORef ref - if remaining <= maxToRead then - tryKeepAlive - else - return False + if remaining <= maxToRead + then tryKeepAlive + else return False Nothing -> tryKeepAlive - else - return False + else return False -sendErrorResponse :: Settings -> InternalInfo -> Connection -> T.Handle -> IORef Bool -> Request -> SomeException -> IO Bool +sendErrorResponse + :: Settings + -> InternalInfo + -> Connection + -> T.Handle + -> IORef Bool + -> Request + -> SomeException + -> IO Bool sendErrorResponse settings ii conn th istatus req e = do status <- readIORef istatus - if shouldSendErrorResponse e && status then - sendResponse settings conn ii th req defaultIndexRequestHeader (return BS.empty) errorResponse - else - return False + if shouldSendErrorResponse e && status + then + sendResponse + settings + conn + ii + th + req + defaultIndexRequestHeader + (return BS.empty) + errorResponse + else return False where shouldSendErrorResponse se - | Just ConnectionClosedByPeer <- fromException se = False - | otherwise = True + | Just ConnectionClosedByPeer <- fromException se = False + | otherwise = True errorResponse = settingsOnExceptionResponse settings e flushEntireBody :: IO ByteString -> IO () @@ -207,9 +282,13 @@ flushEntireBody src = bs <- src unless (BS.null bs) loop -flushBody :: IO ByteString -- ^ get next chunk - -> Int -- ^ maximum to flush - -> IO Bool -- ^ True == flushed the entire body, False == we didn't +flushBody + :: IO ByteString + -- ^ get next chunk + -> Int + -- ^ maximum to flush + -> IO Bool + -- ^ True == flushed the entire body, False == we didn't flushBody src = loop where loop toRead = do diff --git a/warp/Network/Wai/Handler/Warp/HTTP2/PushPromise.hs b/warp/Network/Wai/Handler/Warp/HTTP2/PushPromise.hs index 78c21aa04..4083b3e00 100644 --- a/warp/Network/Wai/Handler/Warp/HTTP2/PushPromise.hs +++ b/warp/Network/Wai/Handler/Warp/HTTP2/PushPromise.hs @@ -3,9 +3,9 @@ module Network.Wai.Handler.Warp.HTTP2.PushPromise where -import qualified UnliftIO import qualified Network.HTTP.Types as H import qualified Network.HTTP2.Server as H2 +import qualified UnliftIO import Network.Wai import Network.Wai.Handler.Warp.FileInfoCache @@ -24,10 +24,10 @@ fromPushPromise :: InternalInfo -> PushPromise -> IO (Maybe H2.PushPromise) fromPushPromise ii (PushPromise path file rsphdr w) = do efinfo <- UnliftIO.tryIO $ getFileInfo ii file case efinfo of - Left (_ex :: UnliftIO.IOException) -> return Nothing - Right finfo -> do - let !siz = fromIntegral $ fileInfoSize finfo - !fileSpec = H2.FileSpec file 0 siz - !rsp = H2.responseFile H.ok200 rsphdr fileSpec - !pp = H2.pushPromise path rsp w - return $ Just pp + Left (_ex :: UnliftIO.IOException) -> return Nothing + Right finfo -> do + let !siz = fromIntegral $ fileInfoSize finfo + !fileSpec = H2.FileSpec file 0 siz + !rsp = H2.responseFile H.ok200 rsphdr fileSpec + !pp = H2.pushPromise path rsp w + return $ Just pp diff --git a/warp/Network/Wai/Handler/Warp/HTTP2/Response.hs b/warp/Network/Wai/Handler/Warp/HTTP2/Response.hs index 25a6e89d0..ca0f78d57 100644 --- a/warp/Network/Wai/Handler/Warp/HTTP2/Response.hs +++ b/warp/Network/Wai/Handler/Warp/HTTP2/Response.hs @@ -1,50 +1,55 @@ -{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE BangPatterns #-} +{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE ScopedTypeVariables #-} module Network.Wai.Handler.Warp.HTTP2.Response ( - fromResponse - ) where + fromResponse, +) where import qualified Data.ByteString.Builder as BB import qualified Data.List as L (find) import qualified Network.HTTP.Types as H import qualified Network.HTTP2.Server as H2 -import Network.Wai hiding (responseFile, responseBuilder, responseStream) -import Network.Wai.Internal (Response(..)) +import Network.Wai hiding (responseBuilder, responseFile, responseStream) +import Network.Wai.Internal (Response (..)) import qualified UnliftIO import Network.Wai.Handler.Warp.File -import Network.Wai.Handler.Warp.Header import Network.Wai.Handler.Warp.HTTP2.Request (getHTTP2Data) import Network.Wai.Handler.Warp.HTTP2.Types +import Network.Wai.Handler.Warp.Header import qualified Network.Wai.Handler.Warp.Response as R import qualified Network.Wai.Handler.Warp.Settings as S import Network.Wai.Handler.Warp.Types ---------------------------------------------------------------- -fromResponse :: S.Settings -> InternalInfo -> Request -> Response -> IO (H2.Response, H.Status, Bool) +fromResponse + :: S.Settings + -> InternalInfo + -> Request + -> Response + -> IO (H2.Response, H.Status, Bool) fromResponse settings ii req rsp = do date <- getDate ii rspst@(h2rsp, st, hasBody) <- case rsp of - ResponseFile st rsphdr path mpart -> do - let rsphdr' = add date rsphdr - responseFile st rsphdr' method path mpart ii reqhdr - ResponseBuilder st rsphdr builder -> do - let rsphdr' = add date rsphdr - return $ responseBuilder st rsphdr' method builder - ResponseStream st rsphdr strmbdy -> do - let rsphdr' = add date rsphdr - return $ responseStream st rsphdr' method strmbdy - _ -> error "ResponseRaw is not supported in HTTP/2" + ResponseFile st rsphdr path mpart -> do + let rsphdr' = add date rsphdr + responseFile st rsphdr' method path mpart ii reqhdr + ResponseBuilder st rsphdr builder -> do + let rsphdr' = add date rsphdr + return $ responseBuilder st rsphdr' method builder + ResponseStream st rsphdr strmbdy -> do + let rsphdr' = add date rsphdr + return $ responseStream st rsphdr' method strmbdy + _ -> error "ResponseRaw is not supported in HTTP/2" mh2data <- getHTTP2Data req case mh2data of - Nothing -> return rspst - Just h2data -> do - let !trailers = http2dataTrailers h2data - !h2rsp' = H2.setResponseTrailersMaker h2rsp trailers - return (h2rsp', st, hasBody) + Nothing -> return rspst + Just h2data -> do + let !trailers = http2dataTrailers h2data + !h2rsp' = H2.setResponseTrailersMaker h2rsp trailers + return (h2rsp', st, hasBody) where !method = requestMethod req !reqhdr = requestHeaders req @@ -53,24 +58,28 @@ fromResponse settings ii req rsp = do let hasServerHdr = L.find ((== H.hServer) . fst) rsphdr addSVR = maybe ((H.hServer, server) :) (const id) hasServerHdr - in R.addAltSvc settings $ - (H.hDate, date) : addSVR rsphdr + in R.addAltSvc settings $ + (H.hDate, date) : addSVR rsphdr ---------------------------------------------------------------- -responseFile :: H.Status -> H.ResponseHeaders -> H.Method - -> FilePath -> Maybe FilePart -> InternalInfo -> H.RequestHeaders - -> IO (H2.Response, H.Status, Bool) +responseFile + :: H.Status + -> H.ResponseHeaders + -> H.Method + -> FilePath + -> Maybe FilePart + -> InternalInfo + -> H.RequestHeaders + -> IO (H2.Response, H.Status, Bool) responseFile st rsphdr _ _ _ _ _ - | noBody st = return $ responseNoBody st rsphdr - + | noBody st = return $ responseNoBody st rsphdr responseFile st rsphdr method path (Just fp) _ _ = return $ responseFile2XX st rsphdr method fileSpec where - !off' = fromIntegral $ filePartOffset fp + !off' = fromIntegral $ filePartOffset fp !bytes' = fromIntegral $ filePartByteCount fp !fileSpec = H2.FileSpec path off' bytes' - responseFile _ rsphdr method path Nothing ii reqhdr = do efinfo <- UnliftIO.tryIO $ getFileInfo ii path case efinfo of @@ -79,37 +88,48 @@ responseFile _ rsphdr method path Nothing ii reqhdr = do let reqidx = indexRequestHeader reqhdr rspidx = indexResponseHeader rsphdr case conditionalRequest finfo rsphdr method rspidx reqidx of - WithoutBody s -> return $ responseNoBody s rsphdr + WithoutBody s -> return $ responseNoBody s rsphdr WithBody s rsphdr' off bytes -> do - let !off' = fromIntegral off + let !off' = fromIntegral off !bytes' = fromIntegral bytes !fileSpec = H2.FileSpec path off' bytes' return $ responseFile2XX s rsphdr' method fileSpec ---------------------------------------------------------------- -responseFile2XX :: H.Status -> H.ResponseHeaders -> H.Method -> H2.FileSpec -> (H2.Response, H.Status, Bool) +responseFile2XX + :: H.Status + -> H.ResponseHeaders + -> H.Method + -> H2.FileSpec + -> (H2.Response, H.Status, Bool) responseFile2XX st rsphdr method fileSpec - | method == H.methodHead = responseNoBody st rsphdr - | otherwise = (H2.responseFile st rsphdr fileSpec, st, True) + | method == H.methodHead = responseNoBody st rsphdr + | otherwise = (H2.responseFile st rsphdr fileSpec, st, True) ---------------------------------------------------------------- -responseBuilder :: H.Status -> H.ResponseHeaders -> H.Method - -> BB.Builder - -> (H2.Response, H.Status, Bool) +responseBuilder + :: H.Status + -> H.ResponseHeaders + -> H.Method + -> BB.Builder + -> (H2.Response, H.Status, Bool) responseBuilder st rsphdr method builder - | method == H.methodHead || noBody st = responseNoBody st rsphdr - | otherwise = (H2.responseBuilder st rsphdr builder, st, True) + | method == H.methodHead || noBody st = responseNoBody st rsphdr + | otherwise = (H2.responseBuilder st rsphdr builder, st, True) ---------------------------------------------------------------- -responseStream :: H.Status -> H.ResponseHeaders -> H.Method - -> StreamingBody - -> (H2.Response, H.Status, Bool) +responseStream + :: H.Status + -> H.ResponseHeaders + -> H.Method + -> StreamingBody + -> (H2.Response, H.Status, Bool) responseStream st rsphdr method strmbdy - | method == H.methodHead || noBody st = responseNoBody st rsphdr - | otherwise = (H2.responseStreaming st rsphdr strmbdy, st, True) + | method == H.methodHead || noBody st = responseNoBody st rsphdr + | otherwise = (H2.responseStreaming st rsphdr strmbdy, st, True) ---------------------------------------------------------------- diff --git a/warp/Network/Wai/Handler/Warp/HTTP2/Types.hs b/warp/Network/Wai/Handler/Warp/HTTP2/Types.hs index 67b5133a5..d3f2bccf0 100644 --- a/warp/Network/Wai/Handler/Warp/HTTP2/Types.hs +++ b/warp/Network/Wai/Handler/Warp/HTTP2/Types.hs @@ -18,7 +18,7 @@ isHTTP2 TCP = False isHTTP2 tls = useHTTP2 where useHTTP2 = case tlsNegotiatedProtocol tls of - Nothing -> False + Nothing -> False Just proto -> "h2" `BS.isPrefixOf` proto ---------------------------------------------------------------- @@ -26,15 +26,15 @@ isHTTP2 tls = useHTTP2 -- | HTTP/2 specific data. -- -- Since: 3.2.7 -data HTTP2Data = HTTP2Data { - -- | Accessor for 'PushPromise' in 'HTTP2Data'. +data HTTP2Data = HTTP2Data + { http2dataPushPromise :: [PushPromise] + -- ^ Accessor for 'PushPromise' in 'HTTP2Data'. -- -- Since: 3.2.7 - http2dataPushPromise :: [PushPromise] - -- | Accessor for 'H2.TrailersMaker' in 'HTTP2Data'. + , http2dataTrailers :: H2.TrailersMaker + -- ^ Accessor for 'H2.TrailersMaker' in 'HTTP2Data'. -- -- Since: 3.2.8 but the type changed in 3.3.0 - , http2dataTrailers :: H2.TrailersMaker } -- | Default HTTP/2 specific data. @@ -48,30 +48,31 @@ defaultHTTP2Data = HTTP2Data [] H2.defaultTrailersMaker -- while the HTTP/2 library supports other types. -- -- Since: 3.2.7 -data PushPromise = PushPromise { - -- | Accessor for a URL path in 'PushPromise'. +data PushPromise = PushPromise + { promisedPath :: ByteString + -- ^ Accessor for a URL path in 'PushPromise'. -- E.g. \"\/style\/default.css\". -- -- Since: 3.2.7 - promisedPath :: ByteString - -- | Accessor for 'FilePath' in 'PushPromise'. + , promisedFile :: FilePath + -- ^ Accessor for 'FilePath' in 'PushPromise'. -- E.g. \"FILE_PATH/default.css\". -- -- Since: 3.2.7 - , promisedFile :: FilePath - -- | Accessor for 'H.ResponseHeaders' in 'PushPromise' + , promisedResponseHeaders :: H.ResponseHeaders + -- ^ Accessor for 'H.ResponseHeaders' in 'PushPromise' -- \"content-type\" must be specified. -- Default value: []. -- -- -- Since: 3.2.7 - , promisedResponseHeaders :: H.ResponseHeaders - -- | Accessor for 'Weight' in 'PushPromise'. + , promisedWeight :: Weight + -- ^ Accessor for 'Weight' in 'PushPromise'. -- Default value: 16. -- -- Since: 3.2.7 - , promisedWeight :: Weight - } deriving (Eq,Ord,Show) + } + deriving (Eq, Ord, Show) -- | Default push promise. -- diff --git a/warp/Network/Wai/Handler/Warp/HashMap.hs b/warp/Network/Wai/Handler/Warp/HashMap.hs index 31ddd5dfe..0230aa675 100644 --- a/warp/Network/Wai/Handler/Warp/HashMap.hs +++ b/warp/Network/Wai/Handler/Warp/HashMap.hs @@ -28,8 +28,9 @@ isEmpty (HashMap hm) = I.null hm ---------------------------------------------------------------- insert :: FilePath -> v -> HashMap v -> HashMap v -insert path v (HashMap hm) = HashMap - $ I.insertWith M.union (hash path) (M.singleton path v) hm +insert path v (HashMap hm) = + HashMap $ + I.insertWith M.union (hash path) (M.singleton path v) hm lookup :: FilePath -> HashMap v -> Maybe v -lookup path (HashMap hm) = I.lookup (hash path) hm >>= M.lookup path \ No newline at end of file +lookup path (HashMap hm) = I.lookup (hash path) hm >>= M.lookup path diff --git a/warp/Network/Wai/Handler/Warp/Header.hs b/warp/Network/Wai/Handler/Warp/Header.hs index 94970a067..38edd0bc6 100644 --- a/warp/Network/Wai/Handler/Warp/Header.hs +++ b/warp/Network/Wai/Handler/Warp/Header.hs @@ -1,4 +1,5 @@ -{-# LANGUAGE OverloadedStrings, FlexibleContexts #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE OverloadedStrings #-} module Network.Wai.Handler.Warp.Header where @@ -20,20 +21,21 @@ type IndexedHeader = Array Int (Maybe HeaderValue) indexRequestHeader :: RequestHeaders -> IndexedHeader indexRequestHeader hdr = traverseHeader hdr requestMaxIndex requestKeyIndex -data RequestHeaderIndex = ReqContentLength - | ReqTransferEncoding - | ReqExpect - | ReqConnection - | ReqRange - | ReqHost - | ReqIfModifiedSince - | ReqIfUnmodifiedSince - | ReqIfRange - | ReqReferer - | ReqUserAgent - | ReqIfMatch - | ReqIfNoneMatch - deriving (Enum,Bounded) +data RequestHeaderIndex + = ReqContentLength + | ReqTransferEncoding + | ReqExpect + | ReqConnection + | ReqRange + | ReqHost + | ReqIfModifiedSince + | ReqIfUnmodifiedSince + | ReqIfRange + | ReqReferer + | ReqUserAgent + | ReqIfMatch + | ReqIfNoneMatch + deriving (Enum, Bounded) -- | The size for 'IndexedHeader' for HTTP Request. -- From 0 to this corresponds to: @@ -56,36 +58,40 @@ requestMaxIndex = fromEnum (maxBound :: RequestHeaderIndex) requestKeyIndex :: HeaderName -> Int requestKeyIndex hn = case BS.length bs of - 4 | bs == "host" -> fromEnum ReqHost - 5 | bs == "range" -> fromEnum ReqRange - 6 | bs == "expect" -> fromEnum ReqExpect - 7 | bs == "referer" -> fromEnum ReqReferer - 8 | bs == "if-range" -> fromEnum ReqIfRange - | bs == "if-match" -> fromEnum ReqIfMatch - 10 | bs == "user-agent" -> fromEnum ReqUserAgent - | bs == "connection" -> fromEnum ReqConnection - 13 | bs == "if-none-match" -> fromEnum ReqIfNoneMatch - 14 | bs == "content-length" -> fromEnum ReqContentLength - 17 | bs == "transfer-encoding" -> fromEnum ReqTransferEncoding - | bs == "if-modified-since" -> fromEnum ReqIfModifiedSince - 19 | bs == "if-unmodified-since" -> fromEnum ReqIfUnmodifiedSince - _ -> -1 + 4 | bs == "host" -> fromEnum ReqHost + 5 | bs == "range" -> fromEnum ReqRange + 6 | bs == "expect" -> fromEnum ReqExpect + 7 | bs == "referer" -> fromEnum ReqReferer + 8 + | bs == "if-range" -> fromEnum ReqIfRange + | bs == "if-match" -> fromEnum ReqIfMatch + 10 + | bs == "user-agent" -> fromEnum ReqUserAgent + | bs == "connection" -> fromEnum ReqConnection + 13 | bs == "if-none-match" -> fromEnum ReqIfNoneMatch + 14 | bs == "content-length" -> fromEnum ReqContentLength + 17 + | bs == "transfer-encoding" -> fromEnum ReqTransferEncoding + | bs == "if-modified-since" -> fromEnum ReqIfModifiedSince + 19 | bs == "if-unmodified-since" -> fromEnum ReqIfUnmodifiedSince + _ -> -1 where bs = foldedCase hn defaultIndexRequestHeader :: IndexedHeader -defaultIndexRequestHeader = array (0, requestMaxIndex) [(i, Nothing) | i <- [0..requestMaxIndex]] +defaultIndexRequestHeader = array (0, requestMaxIndex) [(i, Nothing) | i <- [0 .. requestMaxIndex]] ---------------------------------------------------------------- indexResponseHeader :: ResponseHeaders -> IndexedHeader indexResponseHeader hdr = traverseHeader hdr responseMaxIndex responseKeyIndex -data ResponseHeaderIndex = ResContentLength - | ResServer - | ResDate - | ResLastModified - deriving (Enum,Bounded) +data ResponseHeaderIndex + = ResContentLength + | ResServer + | ResDate + | ResLastModified + deriving (Enum, Bounded) -- | The size for 'IndexedHeader' for HTTP Response. responseMaxIndex :: Int @@ -93,11 +99,11 @@ responseMaxIndex = fromEnum (maxBound :: ResponseHeaderIndex) responseKeyIndex :: HeaderName -> Int responseKeyIndex hn = case BS.length bs of - 4 | bs == "date" -> fromEnum ResDate - 6 | bs == "server" -> fromEnum ResServer + 4 | bs == "date" -> fromEnum ResDate + 6 | bs == "server" -> fromEnum ResServer 13 | bs == "last-modified" -> fromEnum ResLastModified 14 | bs == "content-length" -> fromEnum ResContentLength - _ -> -1 + _ -> -1 where bs = foldedCase hn @@ -105,12 +111,12 @@ responseKeyIndex hn = case BS.length bs of traverseHeader :: [Header] -> Int -> (HeaderName -> Int) -> IndexedHeader traverseHeader hdr maxidx getIndex = runSTArray $ do - arr <- newArray (0,maxidx) Nothing + arr <- newArray (0, maxidx) Nothing mapM_ (insert arr) hdr return arr where - insert arr (key,val) - | idx == -1 = return () - | otherwise = writeArray arr idx (Just val) + insert arr (key, val) + | idx == -1 = return () + | otherwise = writeArray arr idx (Just val) where idx = getIndex key diff --git a/warp/Network/Wai/Handler/Warp/IO.hs b/warp/Network/Wai/Handler/Warp/IO.hs index 9cc83d8cc..dd0dbac09 100644 --- a/warp/Network/Wai/Handler/Warp/IO.hs +++ b/warp/Network/Wai/Handler/Warp/IO.hs @@ -8,39 +8,43 @@ import Network.Wai.Handler.Warp.Buffer import Network.Wai.Handler.Warp.Imports import Network.Wai.Handler.Warp.Types -toBufIOWith :: Int -> IORef WriteBuffer -> (ByteString -> IO ()) -> Builder -> IO Integer +toBufIOWith + :: Int -> IORef WriteBuffer -> (ByteString -> IO ()) -> Builder -> IO Integer toBufIOWith maxRspBufSize writeBufferRef io builder = do - writeBuffer <- readIORef writeBufferRef - loop writeBuffer firstWriter 0 + writeBuffer <- readIORef writeBufferRef + loop writeBuffer firstWriter 0 where firstWriter = runBuilder builder loop writeBuffer writer bytesSent = do - let buf = bufBuffer writeBuffer - size = bufSize writeBuffer - (len, signal) <- writer buf size - bufferIO buf len io - let totalBytesSent = toInteger len + bytesSent - case signal of - Done -> return totalBytesSent - More minSize next - | size < minSize -> do - when (minSize > maxRspBufSize) $ - error $ "Sending a Builder response required a buffer of size " - ++ show minSize ++ " which is bigger than the specified maximum of " - ++ show maxRspBufSize ++ "!" - -- The current WriteBuffer is too small to fit the next - -- batch of bytes from the Builder so we free it and - -- create a new bigger one. Freeing the current buffer, - -- creating a new one and writing it to the IORef need - -- to be performed atomically to prevent both double - -- frees and missed frees. So we mask async exceptions: - biggerWriteBuffer <- mask_ $ do - bufFree writeBuffer - biggerWriteBuffer <- createWriteBuffer minSize - writeIORef writeBufferRef biggerWriteBuffer - return biggerWriteBuffer - loop biggerWriteBuffer next totalBytesSent - | otherwise -> loop writeBuffer next totalBytesSent - Chunk bs next -> do - io bs - loop writeBuffer next totalBytesSent + let buf = bufBuffer writeBuffer + size = bufSize writeBuffer + (len, signal) <- writer buf size + bufferIO buf len io + let totalBytesSent = toInteger len + bytesSent + case signal of + Done -> return totalBytesSent + More minSize next + | size < minSize -> do + when (minSize > maxRspBufSize) $ + error $ + "Sending a Builder response required a buffer of size " + ++ show minSize + ++ " which is bigger than the specified maximum of " + ++ show maxRspBufSize + ++ "!" + -- The current WriteBuffer is too small to fit the next + -- batch of bytes from the Builder so we free it and + -- create a new bigger one. Freeing the current buffer, + -- creating a new one and writing it to the IORef need + -- to be performed atomically to prevent both double + -- frees and missed frees. So we mask async exceptions: + biggerWriteBuffer <- mask_ $ do + bufFree writeBuffer + biggerWriteBuffer <- createWriteBuffer minSize + writeIORef writeBufferRef biggerWriteBuffer + return biggerWriteBuffer + loop biggerWriteBuffer next totalBytesSent + | otherwise -> loop writeBuffer next totalBytesSent + Chunk bs next -> do + io bs + loop writeBuffer next totalBytesSent diff --git a/warp/Network/Wai/Handler/Warp/Imports.hs b/warp/Network/Wai/Handler/Warp/Imports.hs index 5c886c111..ff0fd3199 100644 --- a/warp/Network/Wai/Handler/Warp/Imports.hs +++ b/warp/Network/Wai/Handler/Warp/Imports.hs @@ -1,23 +1,23 @@ module Network.Wai.Handler.Warp.Imports ( - ByteString(..) - , NonEmpty(..) - , module Control.Applicative - , module Control.Monad - , module Data.Bits - , module Data.Int - , module Data.Monoid - , module Data.Ord - , module Data.Word - , module Data.Maybe - , module Numeric - ) where + ByteString (..), + NonEmpty (..), + module Control.Applicative, + module Control.Monad, + module Data.Bits, + module Data.Int, + module Data.Monoid, + module Data.Ord, + module Data.Word, + module Data.Maybe, + module Numeric, +) where import Control.Applicative import Control.Monad import Data.Bits -import Data.ByteString.Internal (ByteString(..)) +import Data.ByteString.Internal (ByteString (..)) import Data.Int -import Data.List.NonEmpty (NonEmpty(..)) +import Data.List.NonEmpty (NonEmpty (..)) import Data.Maybe import Data.Monoid import Data.Ord diff --git a/warp/Network/Wai/Handler/Warp/Internal.hs b/warp/Network/Wai/Handler/Warp/Internal.hs index c98f0d1a9..1fc6983b7 100644 --- a/warp/Network/Wai/Handler/Warp/Internal.hs +++ b/warp/Network/Wai/Handler/Warp/Internal.hs @@ -2,40 +2,49 @@ module Network.Wai.Handler.Warp.Internal ( -- * Settings - Settings (..) - , ProxyProtocol(..) + Settings (..), + ProxyProtocol (..), + -- * Low level run functions - , runSettingsConnection - , runSettingsConnectionMaker - , runSettingsConnectionMakerSecure - , Transport (..) + runSettingsConnection, + runSettingsConnectionMaker, + runSettingsConnectionMakerSecure, + Transport (..), + -- * Connection - , Connection (..) - , socketConnection + Connection (..), + socketConnection, + -- ** Receive - , Recv - , RecvBuf + Recv, + RecvBuf, + -- ** Buffer - , Buffer - , BufSize - , WriteBuffer(..) - , createWriteBuffer - , allocateBuffer - , freeBuffer - , copy + Buffer, + BufSize, + WriteBuffer (..), + createWriteBuffer, + allocateBuffer, + freeBuffer, + copy, + -- ** Sendfile - , FileId (..) - , SendFile - , sendFile - , readSendFile + FileId (..), + SendFile, + sendFile, + readSendFile, + -- * Version - , warpVersion + warpVersion, + -- * Data types - , InternalInfo (..) - , HeaderValue - , IndexedHeader - , requestMaxIndex + InternalInfo (..), + HeaderValue, + IndexedHeader, + requestMaxIndex, + -- * Time out manager + -- | -- -- In order to provide slowloris protection, Warp provides timeout handlers. We @@ -53,25 +62,31 @@ module Network.Wai.Handler.Warp.Internal ( -- resumed as soon as we return from user code. -- -- * Every time data is successfully sent to the client, the timeout is tickled. - , module System.TimeManager + module System.TimeManager, + -- * File descriptor cache - , module Network.Wai.Handler.Warp.FdCache + module Network.Wai.Handler.Warp.FdCache, + -- * File information cache - , module Network.Wai.Handler.Warp.FileInfoCache + module Network.Wai.Handler.Warp.FileInfoCache, + -- * Date - , module Network.Wai.Handler.Warp.Date + module Network.Wai.Handler.Warp.Date, + -- * Request and response - , Source - , recvRequest - , sendResponse + Source, + recvRequest, + sendResponse, + -- * Platform dependent helper functions - , setSocketCloseOnExec - , windowsThreadBlockHack + setSocketCloseOnExec, + windowsThreadBlockHack, + -- * Misc - , http2server - , withII - , pReadMaker - ) where + http2server, + withII, + pReadMaker, +) where import Network.Socket.BufferPool import System.TimeManager diff --git a/warp/Network/Wai/Handler/Warp/MultiMap.hs b/warp/Network/Wai/Handler/Warp/MultiMap.hs index 8a994c091..49bb2dec3 100644 --- a/warp/Network/Wai/Handler/Warp/MultiMap.hs +++ b/warp/Network/Wai/Handler/Warp/MultiMap.hs @@ -1,14 +1,14 @@ module Network.Wai.Handler.Warp.MultiMap ( - MultiMap - , isEmpty - , empty - , singleton - , insert - , Network.Wai.Handler.Warp.MultiMap.lookup - , pruneWith - , toList - , merge - ) where + MultiMap, + isEmpty, + empty, + singleton, + insert, + Network.Wai.Handler.Warp.MultiMap.lookup, + pruneWith, + toList, + merge, +) where import Control.Monad (filterM) import Data.Hashable (hash) @@ -28,7 +28,7 @@ import Prelude -- Silence redundant import warnings -- Because only positive entries are stored, -- Malicious attack cannot cause the inner list to blow up. -- So, lists are good enough. -newtype MultiMap v = MultiMap (IntMap [(FilePath,v)]) +newtype MultiMap v = MultiMap (IntMap [(FilePath, v)]) ---------------------------------------------------------------- @@ -44,7 +44,7 @@ isEmpty (MultiMap mm) = I.null mm -- | O(1) singleton :: FilePath -> v -> MultiMap v -singleton path v = MultiMap $ I.singleton (hash path) [(path,v)] +singleton path v = MultiMap $ I.singleton (hash path) [(path, v)] ---------------------------------------------------------------- @@ -52,35 +52,37 @@ singleton path v = MultiMap $ I.singleton (hash path) [(path,v)] lookup :: FilePath -> MultiMap v -> Maybe v lookup path (MultiMap mm) = case I.lookup (hash path) mm of Nothing -> Nothing - Just s -> Prelude.lookup path s + Just s -> Prelude.lookup path s ---------------------------------------------------------------- -- | O(log n) insert :: FilePath -> v -> MultiMap v -> MultiMap v -insert path v (MultiMap mm) = MultiMap - $ I.insertWith (<>) (hash path) [(path,v)] mm +insert path v (MultiMap mm) = + MultiMap $ + I.insertWith (<>) (hash path) [(path, v)] mm ---------------------------------------------------------------- -- | O(n) -toList :: MultiMap v -> [(FilePath,v)] +toList :: MultiMap v -> [(FilePath, v)] toList (MultiMap mm) = concatMap snd $ I.toAscList mm ---------------------------------------------------------------- -- | O(n) -pruneWith :: MultiMap v - -> ((FilePath,v) -> IO Bool) - -> IO (MultiMap v) -pruneWith (MultiMap mm) action - = I.foldrWithKey go (pure . MultiMap) mm I.empty +pruneWith + :: MultiMap v + -> ((FilePath, v) -> IO Bool) + -> IO (MultiMap v) +pruneWith (MultiMap mm) action = + I.foldrWithKey go (pure . MultiMap) mm I.empty where go h s cont acc = do - rs <- filterM action s - case rs of - [] -> cont acc - _ -> cont $! I.insert h rs acc + rs <- filterM action s + case rs of + [] -> cont acc + _ -> cont $! I.insert h rs acc ---------------------------------------------------------------- diff --git a/warp/Network/Wai/Handler/Warp/PackInt.hs b/warp/Network/Wai/Handler/Warp/PackInt.hs index 930206d43..db0e8e8ff 100644 --- a/warp/Network/Wai/Handler/Warp/PackInt.hs +++ b/warp/Network/Wai/Handler/Warp/PackInt.hs @@ -1,4 +1,5 @@ -{-# LANGUAGE OverloadedStrings, BangPatterns #-} +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE OverloadedStrings #-} module Network.Wai.Handler.Warp.PackInt where @@ -18,7 +19,6 @@ import Network.Wai.Handler.Warp.Imports -- -- prop> packIntegral (abs n) == C8.pack (show (abs n)) -- prop> \(Large n) -> let n' = fromIntegral (abs n :: Int) in packIntegral n' == C8.pack (show n') - packIntegral :: Integral a => a -> ByteString packIntegral 0 = "0" packIntegral n | n < 0 = error "packIntegral" @@ -29,10 +29,9 @@ packIntegral n = unsafeCreate len go0 go0 p = go n $ p `plusPtr` (len - 1) go :: Integral a => a -> Ptr Word8 -> IO () go i p = do - let (d,r) = i `divMod` 10 + let (d, r) = i `divMod` 10 poke p (_0 + fromIntegral r) when (d /= 0) $ go d (p `plusPtr` (-1)) - {-# SPECIALIZE packIntegral :: Int -> ByteString #-} {-# SPECIALIZE packIntegral :: Integer -> ByteString #-} @@ -42,16 +41,15 @@ packIntegral n = unsafeCreate len go0 -- "200" -- >>> packStatus H.preconditionFailed412 -- "412" - packStatus :: H.Status -> ByteString packStatus status = unsafeCreate 3 $ \p -> do - poke p (toW8 r2) + poke p (toW8 r2) poke (p `plusPtr` 1) (toW8 r1) poke (p `plusPtr` 2) (toW8 r0) where toW8 :: Int -> Word8 toW8 n = _0 + fromIntegral n !s = fromIntegral $ H.statusCode status - (!q0,!r0) = s `divMod` 10 - (!q1,!r1) = q0 `divMod` 10 + (!q0, !r0) = s `divMod` 10 + (!q1, !r1) = q0 `divMod` 10 !r2 = q1 `mod` 10 diff --git a/warp/Network/Wai/Handler/Warp/ReadInt.hs b/warp/Network/Wai/Handler/Warp/ReadInt.hs index cd4bad1e7..24a44b0ed 100644 --- a/warp/Network/Wai/Handler/Warp/ReadInt.hs +++ b/warp/Network/Wai/Handler/Warp/ReadInt.hs @@ -1,13 +1,13 @@ -{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE BangPatterns #-} +{-# LANGUAGE OverloadedStrings #-} -- Copyright : Erik de Castro Lopo -- License : BSD3 module Network.Wai.Handler.Warp.ReadInt ( - readInt - , readInt64 - ) where + readInt, + readInt64, +) where import qualified Data.ByteString as S import Data.Word8 (isDigit, _0) @@ -28,5 +28,6 @@ readInt bs = fromIntegral $ readInt64 bs {-# NOINLINE readInt64 #-} readInt64 :: ByteString -> Int64 -readInt64 bs = S.foldl' (\ !i !c -> i * 10 + fromIntegral (c - _0)) 0 - $ S.takeWhile isDigit bs +readInt64 bs = + S.foldl' (\ !i !c -> i * 10 + fromIntegral (c - _0)) 0 $ + S.takeWhile isDigit bs diff --git a/warp/Network/Wai/Handler/Warp/RequestHeader.hs b/warp/Network/Wai/Handler/Warp/RequestHeader.hs index 1941a68ad..1caf79dd6 100644 --- a/warp/Network/Wai/Handler/Warp/RequestHeader.hs +++ b/warp/Network/Wai/Handler/Warp/RequestHeader.hs @@ -1,19 +1,19 @@ {-# LANGUAGE BangPatterns #-} module Network.Wai.Handler.Warp.RequestHeader ( - parseHeaderLines - ) where + parseHeaderLines, +) where -import UnliftIO (throwIO) import qualified Data.ByteString as S import qualified Data.ByteString.Char8 as C8 (unpack) import Data.ByteString.Internal (memchr) import qualified Data.CaseInsensitive as CI import Data.Word8 import Foreign.ForeignPtr (withForeignPtr) -import Foreign.Ptr (Ptr, plusPtr, minusPtr, nullPtr) +import Foreign.Ptr (Ptr, minusPtr, nullPtr, plusPtr) import Foreign.Storable (peek) import qualified Network.HTTP.Types as H +import UnliftIO (throwIO) import Network.Wai.Handler.Warp.Imports import Network.Wai.Handler.Warp.Types @@ -23,16 +23,18 @@ import Network.Wai.Handler.Warp.Types ---------------------------------------------------------------- -parseHeaderLines :: [ByteString] - -> IO (H.Method - ,ByteString -- Path - ,ByteString -- Path, parsed - ,ByteString -- Query - ,H.HttpVersion - ,H.RequestHeaders - ) +parseHeaderLines + :: [ByteString] + -> IO + ( H.Method + , ByteString -- Path + , ByteString -- Path, parsed + , ByteString -- Query + , H.HttpVersion + , H.RequestHeaders + ) parseHeaderLines [] = throwIO $ NotEnoughLines [] -parseHeaderLines (firstLine:otherLines) = do +parseHeaderLines (firstLine : otherLines) = do (method, path', query, httpversion) <- parseRequestLine firstLine let path = H.extractPath path' hdr = map parseHeader otherLines @@ -52,11 +54,14 @@ parseHeaderLines (firstLine:otherLines) = do -- *** Exception: Warp: Request line specified a non-HTTP request -- >>> parseRequestLine "PRI * HTTP/2.0" -- ("PRI","*","",HTTP/2.0) -parseRequestLine :: ByteString - -> IO (H.Method - ,ByteString -- Path - ,ByteString -- Query - ,H.HttpVersion) +parseRequestLine + :: ByteString + -> IO + ( H.Method + , ByteString -- Path + , ByteString -- Query + , H.HttpVersion + ) parseRequestLine requestLine@(PS fptr off len) = withForeignPtr fptr $ \ptr -> do when (len < 14) $ throwIO baderr let methodptr = ptr `plusPtr` off @@ -81,13 +86,13 @@ parseRequestLine requestLine@(PS fptr off len) = withForeignPtr fptr $ \ptr -> d let !method = bs ptr methodptr pathptr0 !path - | queryptr == nullPtr = bs ptr pathptr httpptr0 - | otherwise = bs ptr pathptr queryptr + | queryptr == nullPtr = bs ptr pathptr httpptr0 + | otherwise = bs ptr pathptr queryptr !query - | queryptr == nullPtr = S.empty - | otherwise = bs ptr queryptr httpptr0 + | queryptr == nullPtr = S.empty + | otherwise = bs ptr queryptr httpptr0 - return (method,path,query,hv) + return (method, path, query, hv) where baderr = BadFirstLine $ C8.unpack requestLine check :: Ptr Word8 -> Int -> Word8 -> IO () @@ -105,9 +110,9 @@ parseRequestLine requestLine@(PS fptr off len) = withForeignPtr fptr $ \ptr -> d major <- peek (httpptr `plusPtr` 5) :: IO Word8 minor <- peek (httpptr `plusPtr` 7) :: IO Word8 let version - | major == _1 = if minor == _1 then H.http11 else H.http10 - | major == _2 && minor == _0 = H.http20 - | otherwise = H.http10 + | major == _1 = if minor == _1 then H.http11 else H.http10 + | major == _2 && minor == _0 = H.http20 + | otherwise = H.http10 return version bs ptr p0 p1 = PS fptr o l where @@ -126,7 +131,6 @@ parseRequestLine requestLine@(PS fptr off len) = withForeignPtr fptr $ \ptr -> d -- ("Host","example.com:8080") -- >>> parseHeader "NoSemiColon" -- ("NoSemiColon","") - parseHeader :: ByteString -> H.Header parseHeader s = let (k, rest) = S.break (== _colon) s diff --git a/warp/Network/Wai/Handler/Warp/ResponseHeader.hs b/warp/Network/Wai/Handler/Warp/ResponseHeader.hs index f98ac19a9..d3282a332 100644 --- a/warp/Network/Wai/Handler/Warp/ResponseHeader.hs +++ b/warp/Network/Wai/Handler/Warp/ResponseHeader.hs @@ -1,13 +1,13 @@ -{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE BangPatterns #-} +{-# LANGUAGE OverloadedStrings #-} module Network.Wai.Handler.Warp.ResponseHeader (composeHeader) where import qualified Data.ByteString as S import Data.ByteString.Internal (create) import qualified Data.CaseInsensitive as CI -import Data.Word8 import Data.List (foldl') +import Data.Word8 import Foreign.Ptr import GHC.Storable import qualified Network.HTTP.Types as H @@ -24,7 +24,7 @@ composeHeader !httpversion !status !responseHeaders = create len $ \ptr -> do void $ copyCRLF ptr2 where !len = 17 + slen + foldl' fieldLength 0 responseHeaders - fieldLength !l (!k,!v) = l + S.length (CI.original k) + S.length v + 4 + fieldLength !l (!k, !v) = l + S.length (CI.original k) + S.length v + 4 !slen = S.length $ H.statusMessage status httpVer11 :: ByteString @@ -45,22 +45,22 @@ copyStatus !ptr !httpversion !status = do copyCRLF ptr2 where httpVer - | httpversion == H.HttpVersion 1 1 = httpVer11 - | otherwise = httpVer10 - (q0,r0) = H.statusCode status `divMod` 10 - (q1,r1) = q0 `divMod` 10 + | httpversion == H.HttpVersion 1 1 = httpVer11 + | otherwise = httpVer10 + (q0, r0) = H.statusCode status `divMod` 10 + (q1, r1) = q0 `divMod` 10 r2 = q1 `mod` 10 {-# INLINE copyHeaders #-} copyHeaders :: Ptr Word8 -> [H.Header] -> IO (Ptr Word8) copyHeaders !ptr [] = return ptr -copyHeaders !ptr (h:hs) = do +copyHeaders !ptr (h : hs) = do ptr1 <- copyHeader ptr h copyHeaders ptr1 hs {-# INLINE copyHeader #-} copyHeader :: Ptr Word8 -> H.Header -> IO (Ptr Word8) -copyHeader !ptr (k,v) = do +copyHeader !ptr (k, v) = do ptr1 <- copy ptr (CI.original k) writeWord8OffPtr ptr1 0 _colon writeWord8OffPtr ptr1 1 _space diff --git a/warp/Network/Wai/Handler/Warp/Windows.hs b/warp/Network/Wai/Handler/Warp/Windows.hs index c2821aca2..575c7b17d 100644 --- a/warp/Network/Wai/Handler/Warp/Windows.hs +++ b/warp/Network/Wai/Handler/Warp/Windows.hs @@ -1,7 +1,8 @@ {-# LANGUAGE CPP #-} -module Network.Wai.Handler.Warp.Windows - ( windowsThreadBlockHack - ) where + +module Network.Wai.Handler.Warp.Windows ( + windowsThreadBlockHack, +) where #if WINDOWS import Control.Concurrent.MVar diff --git a/warp/Network/Wai/Handler/Warp/WithApplication.hs b/warp/Network/Wai/Handler/Warp/WithApplication.hs index ac36c00d8..4259931ee 100644 --- a/warp/Network/Wai/Handler/Warp/WithApplication.hs +++ b/warp/Network/Wai/Handler/Warp/WithApplication.hs @@ -1,23 +1,24 @@ {-# LANGUAGE OverloadedStrings #-} + module Network.Wai.Handler.Warp.WithApplication ( - withApplication, - withApplicationSettings, - testWithApplication, - testWithApplicationSettings, - openFreePort, - withFreePort, + withApplication, + withApplicationSettings, + testWithApplication, + testWithApplicationSettings, + openFreePort, + withFreePort, ) where -import Control.Concurrent +import Control.Concurrent +import Control.Monad (when) +import Data.Streaming.Network (bindRandomPortTCP) +import Network.Socket +import Network.Wai +import Network.Wai.Handler.Warp.Run +import Network.Wai.Handler.Warp.Settings +import Network.Wai.Handler.Warp.Types import qualified UnliftIO -import UnliftIO.Async -import Control.Monad (when) -import Data.Streaming.Network (bindRandomPortTCP) -import Network.Socket -import Network.Wai -import Network.Wai.Handler.Warp.Run -import Network.Wai.Handler.Warp.Settings -import Network.Wai.Handler.Warp.Types +import UnliftIO.Async -- | Runs the given 'Application' on a free port. Passes the port to the given -- operation and executes it, while the 'Application' is running. Shuts down the @@ -33,20 +34,21 @@ withApplication = withApplicationSettings defaultSettings -- @since 3.2.7 withApplicationSettings :: Settings -> IO Application -> (Port -> IO a) -> IO a withApplicationSettings settings' mkApp action = do - app <- mkApp - withFreePort $ \ (port, sock) -> do - started <- mkWaiter - let settings = - settings' { - settingsBeforeMainLoop - = notify started () >> settingsBeforeMainLoop settings' - } - result <- race - (runSettingsSocket settings sock app) - (waitFor started >> action port) - case result of - Left () -> UnliftIO.throwString "Unexpected: runSettingsSocket exited" - Right x -> return x + app <- mkApp + withFreePort $ \(port, sock) -> do + started <- mkWaiter + let settings = + settings' + { settingsBeforeMainLoop = + notify started () >> settingsBeforeMainLoop settings' + } + result <- + race + (runSettingsSocket settings sock app) + (waitFor started >> action port) + case result of + Left () -> UnliftIO.throwString "Unexpected: runSettingsSocket exited" + Right x -> return x -- | Same as 'withApplication' but with different exception handling: If the -- given 'Application' throws an exception, 'testWithApplication' will re-throw @@ -67,31 +69,32 @@ testWithApplication = testWithApplicationSettings defaultSettings -- | 'testWithApplication' with given 'Settings'. -- -- @since 3.2.7 -testWithApplicationSettings :: Settings -> IO Application -> (Port -> IO a) -> IO a +testWithApplicationSettings + :: Settings -> IO Application -> (Port -> IO a) -> IO a testWithApplicationSettings settings mkApp action = do - callingThread <- myThreadId - app <- mkApp - let wrappedApp request respond = - app request respond `UnliftIO.catchAny` \ e -> do - when - (defaultShouldDisplayException e) - (throwTo callingThread e) - UnliftIO.throwIO e - withApplicationSettings settings (return wrappedApp) action + callingThread <- myThreadId + app <- mkApp + let wrappedApp request respond = + app request respond `UnliftIO.catchAny` \e -> do + when + (defaultShouldDisplayException e) + (throwTo callingThread e) + UnliftIO.throwIO e + withApplicationSettings settings (return wrappedApp) action -data Waiter a - = Waiter { - notify :: a -> IO (), - waitFor :: IO a - } +data Waiter a = Waiter + { notify :: a -> IO () + , waitFor :: IO a + } mkWaiter :: IO (Waiter a) mkWaiter = do - mvar <- newEmptyMVar - return Waiter { - notify = putMVar mvar, - waitFor = readMVar mvar - } + mvar <- newEmptyMVar + return + Waiter + { notify = putMVar mvar + , waitFor = readMVar mvar + } -- | Opens a socket on a free port and returns both port and socket. -- diff --git a/warp/attic/bigtable-single.hs b/warp/attic/bigtable-single.hs index a834f6cef..8ba83c622 100644 --- a/warp/attic/bigtable-single.hs +++ b/warp/attic/bigtable-single.hs @@ -1,21 +1,24 @@ {-# LANGUAGE OverloadedStrings #-} -import Network.Wai + import Blaze.ByteString.Builder (Builder, fromByteString) import Blaze.ByteString.Builder.Char8 (fromShow) import Data.Monoid (mappend) +import Network.Wai import Network.Wai.Handler.Warp (run) bigtable :: Builder bigtable = fromByteString "" - `mappend` foldr mappend (fromByteString "
") (replicate 2 row) + `mappend` foldr mappend (fromByteString "") (replicate 2 row) where - row = fromByteString "" - `mappend` foldr go (fromByteString "") [1..2] - go i rest = fromByteString "" - `mappend` fromShow i - `mappend` fromByteString "" - `mappend` rest + row = + fromByteString "" + `mappend` foldr go (fromByteString "") [1 .. 2] + go i rest = + fromByteString "" + `mappend` fromShow i + `mappend` fromByteString "" + `mappend` rest main = run 3000 app diff --git a/warp/attic/bigtable-stream.hs b/warp/attic/bigtable-stream.hs index bf9be1982..bf97aa8db 100644 --- a/warp/attic/bigtable-stream.hs +++ b/warp/attic/bigtable-stream.hs @@ -1,22 +1,25 @@ {-# LANGUAGE OverloadedStrings #-} -import Network.Wai + import Blaze.ByteString.Builder (Builder, fromByteString) import Blaze.ByteString.Builder.Char8 (fromShow) +import Data.Enumerator (enumList, run_, ($$)) import Data.Monoid (mappend) +import Network.Wai import Network.Wai.Handler.Warp (run) -import Data.Enumerator (enumList, ($$), run_) bigtable :: [Builder] bigtable = fromByteString "" - : foldr (:) [fromByteString "
"] (replicate 2 row) + : foldr (:) [fromByteString ""] (replicate 2 row) where - row = fromByteString "" - `mappend` foldr go (fromByteString "") [1..2] - go i rest = fromByteString "" - `mappend` fromShow i - `mappend` fromByteString "" - `mappend` rest + row = + fromByteString "" + `mappend` foldr go (fromByteString "") [1 .. 2] + go i rest = + fromByteString "" + `mappend` fromShow i + `mappend` fromByteString "" + `mappend` rest main = run 3000 app diff --git a/warp/attic/file-nolen.hs b/warp/attic/file-nolen.hs index 9b3c081d6..1fed8dce9 100644 --- a/warp/attic/file-nolen.hs +++ b/warp/attic/file-nolen.hs @@ -1,6 +1,11 @@ {-# LANGUAGE OverloadedStrings #-} + +import Network.HTTP.Types import Network.Wai import Network.Wai.Handler.Warp -import Network.HTTP.Types -main = run 3000 $ const $ return $ ResponseFile status200 [("Content-Type", "text/plain")] "test.txt" Nothing +main = + run 3000 $ + const $ + return $ + ResponseFile status200 [("Content-Type", "text/plain")] "test.txt" Nothing diff --git a/warp/attic/file.hs b/warp/attic/file.hs index 57734f10a..4833fcd8a 100644 --- a/warp/attic/file.hs +++ b/warp/attic/file.hs @@ -1,6 +1,15 @@ {-# LANGUAGE OverloadedStrings #-} + +import Network.HTTP.Types import Network.Wai import Network.Wai.Handler.Warp -import Network.HTTP.Types -main = run 3000 $ const $ return $ ResponseFile status200 [("Content-Type", "text/plain"), ("Content-Length", "16")] "test.txt" Nothing +main = + run 3000 $ + const $ + return $ + ResponseFile + status200 + [("Content-Type", "text/plain"), ("Content-Length", "16")] + "test.txt" + Nothing diff --git a/warp/attic/pong.hs b/warp/attic/pong.hs index 5327d8579..406315cca 100644 --- a/warp/attic/pong.hs +++ b/warp/attic/pong.hs @@ -1,37 +1,42 @@ {-# LANGUAGE OverloadedStrings #-} -import Network.Wai -import Network.Wai.Handler.Warp -import Network.HTTP.Types (status200) + import Blaze.ByteString.Builder (copyByteString) -import Data.Monoid import qualified Data.Conduit as C import qualified Data.Conduit.List as CL +import Data.Monoid +import Network.HTTP.Types (status200) +import Network.Wai +import Network.Wai.Handler.Warp main = run 3000 app -app req = ($ - case rawPathInfo req of - "/builder/withlen" -> builderWithLen - "/builder/nolen" -> builderNoLen - "/file/withlen" -> fileWithLen - "/file/nolen" -> fileNoLen - "/source/withlen" -> sourceWithLen - "/source/nolen" -> sourceNoLen - "/notfound" -> responseFile status200 [] "notfound" Nothing - x -> index x) +app req = + ( $ + case rawPathInfo req of + "/builder/withlen" -> builderWithLen + "/builder/nolen" -> builderNoLen + "/file/withlen" -> fileWithLen + "/file/nolen" -> fileNoLen + "/source/withlen" -> sourceWithLen + "/source/nolen" -> sourceNoLen + "/notfound" -> responseFile status200 [] "notfound" Nothing + x -> index x + ) -builderWithLen = responseBuilder - status200 - [ ("Content-Type", "text/plain") - , ("Content-Length", "4") - ] - $ copyByteString "PONG" +builderWithLen = + responseBuilder + status200 + [ ("Content-Type", "text/plain") + , ("Content-Length", "4") + ] + $ copyByteString "PONG" -builderNoLen = responseBuilder - status200 - [ ("Content-Type", "text/plain") - ] - $ copyByteString "PONG" +builderNoLen = + responseBuilder + status200 + [ ("Content-Type", "text/plain") + ] + $ copyByteString "PONG" sourceWithLen = responseStream status200 @@ -46,27 +51,33 @@ sourceNoLen = responseStream ] $ \send _ -> send $ copyByteString "PONG" -fileWithLen = responseFile - status200 - [ ("Content-Type", "text/plain") - , ("Content-Length", "4") - ] - "pong.txt" - Nothing +fileWithLen = + responseFile + status200 + [ ("Content-Type", "text/plain") + , ("Content-Length", "4") + ] + "pong.txt" + Nothing -fileNoLen = responseFile - status200 - [ ("Content-Type", "text/plain") - ] - "pong.txt" - Nothing +fileNoLen = + responseFile + status200 + [ ("Content-Type", "text/plain") + ] + "pong.txt" + Nothing -index p = responseBuilder status200 [("Content-Type", "text/html")] $ mconcat $ map copyByteString - [ "

builder withlen

\n" - , "

builder nolen

\n" - , "

file withlen

\n" - , "

file nolen

\n" - , "

source withlen

\n" - , "

source nolen

\n" - , p - ] +index p = + responseBuilder status200 [("Content-Type", "text/html")] $ + mconcat $ + map + copyByteString + [ "

builder withlen

\n" + , "

builder nolen

\n" + , "

file withlen

\n" + , "

file nolen

\n" + , "

source withlen

\n" + , "

source nolen

\n" + , p + ] diff --git a/warp/attic/print-post.hs b/warp/attic/print-post.hs index 55894de93..e2037ba09 100644 --- a/warp/attic/print-post.hs +++ b/warp/attic/print-post.hs @@ -1,11 +1,12 @@ {-# LANGUAGE OverloadedStrings #-} -import Network.Wai -import Network.Wai.Handler.Warp -import Network.HTTP.Types (status200) + import Blaze.ByteString.Builder (fromByteString) +import Control.Monad.IO.Class (liftIO) import qualified Data.Conduit as C import qualified Data.Conduit.List as CL -import Control.Monad.IO.Class (liftIO) +import Network.HTTP.Types (status200) +import Network.Wai +import Network.Wai.Handler.Warp {- - use `curl -H "Transfer-Encoding: chunked" estfile http://localhost:3000/` to send a chunked post request. diff --git a/warp/attic/readInt.hs b/warp/attic/readInt.hs index d5698baeb..a268d51b9 100644 --- a/warp/attic/readInt.hs +++ b/warp/attic/readInt.hs @@ -1,4 +1,6 @@ -{-# LANGUAGE OverloadedStrings, MagicHash, BangPatterns #-} +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE MagicHash #-} +{-# LANGUAGE OverloadedStrings #-} -- Copyright : Erik de Castro Lopo -- License : BSD3 @@ -16,10 +18,10 @@ import Data.Maybe (fromMaybe) import qualified Data.ByteString as S import qualified Data.ByteString.Char8 as B +import qualified Data.ByteString.Lex.Integral as LI import qualified Data.Char as C import qualified Numeric as N import qualified Test.QuickCheck as QC -import qualified Data.ByteString.Lex.Integral as LI import GHC.Prim import GHC.Types @@ -30,33 +32,29 @@ readIntOrig :: ByteString -> Integer readIntOrig = S.foldl' (\x w -> x * 10 + fromIntegral w - 48) 0 - -- Using Numeric.readDec which works on String, so the ByteString has to be -- unpacked first. readDec :: ByteString -> Integer readDec s = case N.readDec (B.unpack s) of [] -> 0 - (x, _):_ -> x - + (x, _) : _ -> x -- Use ByteString's readInteger function (which requires some extra stuff to -- slow us down). readIntegerBS :: ByteString -> Integer readIntegerBS bs = fst $ fromMaybe (0, "") $ B.readInteger bs - -- No checking for non-digits. Will overflow at 2^31 on 32 bit CPUs. readIntRaw :: ByteString -> Int readIntRaw = B.foldl' (\i c -> i * 10 + C.digitToInt c) 0 - -- The first good solution. readInt64 :: ByteString -> Int64 readInt64 bs = - B.foldl' (\i c -> i * 10 + fromIntegral (C.digitToInt c)) (0::Int64) - $ B.takeWhile C.isDigit bs + B.foldl' (\i c -> i * 10 + fromIntegral (C.digitToInt c)) (0 :: Int64) $ + B.takeWhile C.isDigit bs readInt :: ByteString -> Int readInt bs = fromIntegral $ readInt64 bs @@ -64,15 +62,14 @@ readInt bs = fromIntegral $ readInt64 bs readInteger :: ByteString -> Integer readInteger bs = fromIntegral $ readInt64 bs - -- MagicHash version suggested by Vincent Hanquez. readIntMH :: Integral a => ByteString -> a readIntMH s = fromIntegral $ ireadInt64MH s where ireadInt64MH :: ByteString -> Int64 ireadInt64MH bs = - B.foldl' (\i c -> i * 10 + fromIntegral (mhDigitToInt c)) 0 - $ B.takeWhile C.isDigit bs + B.foldl' (\i c -> i * 10 + fromIntegral (mhDigitToInt c)) 0 $ + B.takeWhile C.isDigit bs readInt64MH :: ByteString -> Int64 readInt64MH = readIntMH @@ -87,24 +84,24 @@ mhDigitToInt (C# i) = I# (word2Int# (indexWord8OffAddr# addr (ord# i))) where !(Table addr) = table table :: Table - table = Table - "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\ - \\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\ - \\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\ - \\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x00\x00\x00\x00\x00\x00\ - \\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\ - \\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\ - \\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\ - \\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\ - \\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\ - \\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\ - \\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\ - \\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\ - \\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\ - \\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\ - \\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\ - \\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"# - + table = + Table + "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\ + \\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\ + \\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\ + \\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x00\x00\x00\x00\x00\x00\ + \\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\ + \\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\ + \\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\ + \\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\ + \\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\ + \\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\ + \\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\ + \\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\ + \\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\ + \\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\ + \\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\ + \\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"# readIntBSL :: ByteString -> Int readIntBSL = LI.readDecimal_ @@ -112,18 +109,17 @@ readIntBSL = LI.readDecimal_ readInt64BSL :: ByteString -> Int64 readInt64BSL = LI.readDecimal_ - -- A QuickCheck property. Test that for a number >= 0, converting it to -- a string using show and then reading the value back with the function -- under test returns the original value. -- The functions under test only work on Natural numbers (the Conent-Length -- field in a HTTP header is always >= 0) so we check the absolute value of -- the value that QuickCheck generates for us. -prop_read_show_idempotent :: (Integral a, Show a) => (ByteString -> a) -> a -> Bool +prop_read_show_idempotent + :: (Integral a, Show a) => (ByteString -> a) -> a -> Bool prop_read_show_idempotent freader x = let px = abs x - in px == freader (B.pack $ show px) - + in px == freader (B.pack $ show px) runQuickCheckTests :: IO () runQuickCheckTests = do @@ -138,20 +134,19 @@ runQuickCheckTests = do runCriterionTests :: ByteString -> IO () runCriterionTests number = defaultMain - [ bench "readIntOrig" $ nf readIntOrig number - , bench "readDec" $ nf readDec number - , bench "readIntegerBS" $ nf readIntegerBS number - , bench "readRaw" $ nf readIntRaw number - , bench "readInt" $ nf readInt number - , bench "readInt64" $ nf readInt64 number - , bench "readInteger" $ nf readInteger number - , bench "readInt64MH" $ nf readInt64MH number - , bench "readIntegerMH" $ nf readIntegerMH number - - -- Current best. - , bench "readIntBSL" $ nf readIntBSL number - , bench "readInt64BSL" $ nf readInt64BSL number - ] + [ bench "readIntOrig" $ nf readIntOrig number + , bench "readDec" $ nf readDec number + , bench "readIntegerBS" $ nf readIntegerBS number + , bench "readRaw" $ nf readIntRaw number + , bench "readInt" $ nf readInt number + , bench "readInt64" $ nf readInt64 number + , bench "readInteger" $ nf readInteger number + , bench "readInt64MH" $ nf readInt64MH number + , bench "readIntegerMH" $ nf readIntegerMH number + , -- Current best. + bench "readIntBSL" $ nf readIntBSL number + , bench "readInt64BSL" $ nf readInt64BSL number + ] main :: IO () main = do @@ -159,4 +154,3 @@ main = do runQuickCheckTests putStrLn "Criterion tests." runCriterionTests "1234567898765432178979128361238162386182" - diff --git a/warp/attic/runtests.hs b/warp/attic/runtests.hs index 823ed4527..3b1c17375 100644 --- a/warp/attic/runtests.hs +++ b/warp/attic/runtests.hs @@ -1,26 +1,29 @@ {-# LANGUAGE OverloadedStrings #-} -import Test.Framework (defaultMain, testGroup, Test) + +import Test.Framework (Test, defaultMain, testGroup) import Test.Framework.Providers.HUnit import Test.Framework.Providers.QuickCheck2 import Test.HUnit hiding (Test) -import Network.Wai.Handler.Warp (takeHeaders, InvalidRequest (..), readInt) -import Data.Enumerator (run_, ($$), enumList, run) import Control.Exception (fromException) import qualified Data.ByteString.Char8 as S8 +import Data.Enumerator (enumList, run, run_, ($$)) +import Network.Wai.Handler.Warp (InvalidRequest (..), readInt, takeHeaders) main :: IO () main = defaultMain [testSuite] testSuite :: Test -testSuite = testGroup "Text.Hamlet" - [ testCase "takeUntilBlank safe" caseTakeUntilBlankSafe - , testCase "takeUntilBlank too many lines" caseTakeUntilBlankTooMany - , testCase "takeUntilBlank too large" caseTakeUntilBlankTooLarge - , testProperty "takeInt" $ \i' -> - let i = abs i' - in i == readInt (S8.pack $ show i) - ] +testSuite = + testGroup + "Text.Hamlet" + [ testCase "takeUntilBlank safe" caseTakeUntilBlankSafe + , testCase "takeUntilBlank too many lines" caseTakeUntilBlankTooMany + , testCase "takeUntilBlank too large" caseTakeUntilBlankTooLarge + , testProperty "takeInt" $ \i' -> + let i = abs i' + in i == readInt (S8.pack $ show i) + ] caseTakeUntilBlankSafe = do x <- run_ $ (enumList 1 ["f", "oo\n", "bar\nbaz\n\r\n"]) $$ takeHeaders diff --git a/warp/attic/server-no-keepalive.hs b/warp/attic/server-no-keepalive.hs index 8ffadac77..f852bb6dd 100644 --- a/warp/attic/server-no-keepalive.hs +++ b/warp/attic/server-no-keepalive.hs @@ -1,34 +1,43 @@ {-# LANGUAGE OverloadedStrings #-} -import Network.Wai.Handler.Warp + +import Control.Concurrent (forkIO) import Control.Exception (bracket) import Control.Monad (forever) -import Network (sClose) -import Network.Socket (accept) import Control.Monad.IO.Class (liftIO) +import Data.Enumerator (run_, ($$)) import qualified Data.Enumerator as E import qualified Data.Enumerator.Binary as EB -import Control.Concurrent (forkIO) -import Network.Wai (responseLBS) +import Network (sClose) import Network.HTTP.Types (status200) -import Data.Enumerator (($$), run_) +import Network.Socket (accept) +import Network.Wai (responseLBS) +import Network.Wai.Handler.Warp -app = const $ return $ responseLBS status200 [("Content-type", "text/plain")] "This is not kept alive under any circumstances" +app = + const $ + return $ + responseLBS + status200 + [("Content-type", "text/plain")] + "This is not kept alive under any circumstances" -main = withManager 30000000 $ \man -> bracket - (bindPort (settingsPort set) (settingsHost set)) - sClose - (\socket -> forever $ do - (conn, sa) <- accept socket - th <- liftIO $ registerKillThread man - _ <- forkIO $ do - run_ $ enumSocket th 4096 conn $$ do - liftIO $ pause th - (len, env) <- parseRequest (settingsPort set) sa - liftIO $ resume th - res <- E.joinI $ EB.isolate len $$ app env - _ <- liftIO $ sendResponse th env conn res - liftIO $ sClose conn - return () - ) +main = withManager 30000000 $ \man -> + bracket + (bindPort (settingsPort set) (settingsHost set)) + sClose + ( \socket -> forever $ do + (conn, sa) <- accept socket + th <- liftIO $ registerKillThread man + _ <- forkIO $ do + run_ $ + enumSocket th 4096 conn $$ do + liftIO $ pause th + (len, env) <- parseRequest (settingsPort set) sa + liftIO $ resume th + res <- E.joinI $ EB.isolate len $$ app env + _ <- liftIO $ sendResponse th env conn res + liftIO $ sClose conn + return () + ) where set = defaultSettings diff --git a/warp/attic/statuses.hs b/warp/attic/statuses.hs index 11d91ebf6..3037c6793 100644 --- a/warp/attic/statuses.hs +++ b/warp/attic/statuses.hs @@ -1,18 +1,21 @@ {-# LANGUAGE OverloadedStrings #-} + +import qualified Data.ByteString.Char8 as S +import Data.ByteString.Lazy.Char8 (pack) import Network.Wai import Network.Wai.Handler.Warp -import Data.ByteString.Lazy.Char8 (pack) -import qualified Data.ByteString.Char8 as S main = run 3000 app app req = - return $ responseLBS (Status s' s) [("Content-Type", "text/plain")] - $ pack $ concat - [ "The status code is " - , S.unpack s - , ". Have a nice day!" - ] + return $ + responseLBS (Status s' s) [("Content-Type", "text/plain")] $ + pack $ + concat + [ "The status code is " + , S.unpack s + , ". Have a nice day!" + ] where s = S.dropWhile (== '/') $ pathInfo req s' = read $ S.unpack s diff --git a/warp/attic/undrained.hs b/warp/attic/undrained.hs index d9c2b32b1..d09691ec4 100644 --- a/warp/attic/undrained.hs +++ b/warp/attic/undrained.hs @@ -1,9 +1,15 @@ {-# LANGUAGE OverloadedStrings #-} + +import Blaze.ByteString.Builder (fromByteString) import Network.Wai import Network.Wai.Handler.Warp -import Blaze.ByteString.Builder (fromByteString) -main = run 3000 $ const $ return $ responseBuilder - status200 - [("Content-Type", "text/html")] - $ fromByteString "
" +main = + run 3000 + $ const + $ return + $ responseBuilder + status200 + [("Content-Type", "text/html")] + $ fromByteString + "
" diff --git a/warp/bench/Parser.hs b/warp/bench/Parser.hs index a200c5877..98b4e3102 100644 --- a/warp/bench/Parser.hs +++ b/warp/bench/Parser.hs @@ -1,17 +1,17 @@ +{-# LANGUAGE BangPatterns #-} {-# LANGUAGE CPP #-} {-# LANGUAGE OverloadedStrings #-} -{-# LANGUAGE BangPatterns #-} module Main where import Control.Monad import qualified Data.ByteString as S import qualified Data.ByteString.Char8 as B (unpack) -import Data.Word8 (_H, _P, _T, _slash, _period) +import Data.Word8 (_H, _P, _T, _period, _slash) import qualified Network.HTTP.Types as H import Network.Wai.Handler.Warp.Types +import UnliftIO.Exception (impureThrow, throwIO) import Prelude hiding (lines) -import UnliftIO.Exception (throwIO, impureThrow) import Data.ByteString.Internal import Data.Word @@ -34,20 +34,22 @@ main :: IO () main = do let requestLine1 = "GET http://www.example.com HTTP/1.1" let requestLine2 = "GET http://www.example.com/cgi-path/search.cgi?key=parser HTTP/1.0" - defaultMain [ - bgroup "requestLine1" [ - bench "parseRequestLine3" $ whnf parseRequestLine3 requestLine1 - , bench "parseRequestLine2" $ whnfIO $ parseRequestLine2 requestLine1 - , bench "parseRequestLine1" $ whnfIO $ parseRequestLine1 requestLine1 - , bench "parseRequestLine0" $ whnfIO $ parseRequestLine0 requestLine1 - ] - , bgroup "requestLine2" [ - bench "parseRequestLine3" $ whnf parseRequestLine3 requestLine2 - , bench "parseRequestLine2" $ whnfIO $ parseRequestLine2 requestLine2 - , bench "parseRequestLine1" $ whnfIO $ parseRequestLine1 requestLine2 - , bench "parseRequestLine0" $ whnfIO $ parseRequestLine0 requestLine2 - ] - ] + defaultMain + [ bgroup + "requestLine1" + [ bench "parseRequestLine3" $ whnf parseRequestLine3 requestLine1 + , bench "parseRequestLine2" $ whnfIO $ parseRequestLine2 requestLine1 + , bench "parseRequestLine1" $ whnfIO $ parseRequestLine1 requestLine1 + , bench "parseRequestLine0" $ whnfIO $ parseRequestLine0 requestLine1 + ] + , bgroup + "requestLine2" + [ bench "parseRequestLine3" $ whnf parseRequestLine3 requestLine2 + , bench "parseRequestLine2" $ whnfIO $ parseRequestLine2 requestLine2 + , bench "parseRequestLine1" $ whnfIO $ parseRequestLine1 requestLine2 + , bench "parseRequestLine0" $ whnfIO $ parseRequestLine0 requestLine2 + ] + ] ---------------------------------------------------------------- @@ -61,26 +63,29 @@ main = do -- *** Exception: BadFirstLine "GET " -- >>> parseRequestLine3 "GET /NotHTTP UNKNOWN/1.1" -- *** Exception: NonHttp -parseRequestLine3 :: ByteString - -> (H.Method - ,ByteString -- Path - ,ByteString -- Query - ,H.HttpVersion) +parseRequestLine3 + :: ByteString + -> ( H.Method + , ByteString -- Path + , ByteString -- Query + , H.HttpVersion + ) parseRequestLine3 requestLine = ret where - (!method,!rest) = S.break (== 32) requestLine -- ' ' - (!pathQuery,!httpVer') - | rest == "" = impureThrow badmsg - | otherwise = S.break (== 32) (S.drop 1 rest) -- ' ' - (!path,!query) = S.break (== 63) pathQuery -- '?' + (!method, !rest) = S.break (== 32) requestLine -- ' ' + (!pathQuery, !httpVer') + | rest == "" = impureThrow badmsg + | otherwise = S.break (== 32) (S.drop 1 rest) -- ' ' + (!path, !query) = S.break (== 63) pathQuery -- '?' !httpVer = S.drop 1 httpVer' - (!http,!ver) - | httpVer == "" = impureThrow badmsg - | otherwise = S.break (== 47) httpVer -- '/' - !hv | http /= "HTTP" = impureThrow NonHttp - | ver == "/1.1" = H.http11 - | otherwise = H.http10 - !ret = (method,path,query,hv) + (!http, !ver) + | httpVer == "" = impureThrow badmsg + | otherwise = S.break (== 47) httpVer -- '/' + !hv + | http /= "HTTP" = impureThrow NonHttp + | ver == "/1.1" = H.http11 + | otherwise = H.http10 + !ret = (method, path, query, hv) badmsg = BadFirstLine $ B.unpack requestLine ---------------------------------------------------------------- @@ -95,11 +100,14 @@ parseRequestLine3 requestLine = ret -- *** Exception: BadFirstLine "GET " -- >>> parseRequestLine2 "GET /NotHTTP UNKNOWN/1.1" -- *** Exception: NonHttp -parseRequestLine2 :: ByteString - -> IO (H.Method - ,ByteString -- Path - ,ByteString -- Query - ,H.HttpVersion) +parseRequestLine2 + :: ByteString + -> IO + ( H.Method + , ByteString -- Path + , ByteString -- Query + , H.HttpVersion + ) parseRequestLine2 requestLine@(PS fptr off len) = withForeignPtr fptr $ \ptr -> do when (len < 14) $ throwIO baderr let methodptr = ptr `plusPtr` off @@ -121,16 +129,15 @@ parseRequestLine2 requestLine@(PS fptr off len) = withForeignPtr fptr $ \ptr -> checkHTTP httpptr !hv <- httpVersion httpptr queryptr <- memchr pathptr 63 lim2 -- '?' - let !method = bs ptr methodptr pathptr0 !path - | queryptr == nullPtr = bs ptr pathptr httpptr0 - | otherwise = bs ptr pathptr queryptr + | queryptr == nullPtr = bs ptr pathptr httpptr0 + | otherwise = bs ptr pathptr queryptr !query - | queryptr == nullPtr = S.empty - | otherwise = bs ptr queryptr httpptr0 + | queryptr == nullPtr = S.empty + | otherwise = bs ptr queryptr httpptr0 - return (method,path,query,hv) + return (method, path, query, hv) where baderr = BadFirstLine $ B.unpack requestLine check :: Ptr Word8 -> Int -> Word8 -> IO () @@ -147,10 +154,9 @@ parseRequestLine2 requestLine@(PS fptr off len) = withForeignPtr fptr $ \ptr -> httpVersion httpptr = do major <- peek $ httpptr `plusPtr` 5 minor <- peek $ httpptr `plusPtr` 7 - if major == (49 :: Word8) && minor == (49 :: Word8) then - return H.http11 - else - return H.http10 + if major == (49 :: Word8) && minor == (49 :: Word8) + then return H.http11 + else return H.http10 bs ptr p0 p1 = PS fptr o l where o = p0 `minusPtr` ptr @@ -168,23 +174,29 @@ parseRequestLine2 requestLine@(PS fptr off len) = withForeignPtr fptr $ \ptr -> -- *** Exception: BadFirstLine "GET " -- >>> parseRequestLine1 "GET /NotHTTP UNKNOWN/1.1" -- *** Exception: NonHttp -parseRequestLine1 :: ByteString - -> IO (H.Method - ,ByteString -- Path - ,ByteString -- Query - ,H.HttpVersion) +parseRequestLine1 + :: ByteString + -> IO + ( H.Method + , ByteString -- Path + , ByteString -- Query + , H.HttpVersion + ) parseRequestLine1 requestLine = do - let (!method,!rest) = S.break (== 32) requestLine -- ' ' - (!pathQuery,!httpVer') = S.break (== 32) (S.drop 1 rest) -- ' ' + let (!method, !rest) = S.break (== 32) requestLine -- ' ' + (!pathQuery, !httpVer') = S.break (== 32) (S.drop 1 rest) -- ' ' !httpVer = S.drop 1 httpVer' when (rest == "" || httpVer == "") $ - throwIO $ BadFirstLine $ B.unpack requestLine - let (!path,!query) = S.break (== 63) pathQuery -- '?' - (!http,!ver) = S.break (== 47) httpVer -- '/' + throwIO $ + BadFirstLine $ + B.unpack requestLine + let (!path, !query) = S.break (== 63) pathQuery -- '?' + (!http, !ver) = S.break (== 47) httpVer -- '/' when (http /= "HTTP") $ throwIO NonHttp - let !hv | ver == "/1.1" = H.http11 - | otherwise = H.http10 - return $! (method,path,query,hv) + let !hv + | ver == "/1.1" = H.http11 + | otherwise = H.http10 + return $! (method, path, query, hv) ---------------------------------------------------------------- @@ -198,23 +210,27 @@ parseRequestLine1 requestLine = do -- *** Exception: BadFirstLine "GET " -- >>> parseRequestLine0 "GET /NotHTTP UNKNOWN/1.1" -- *** Exception: NonHttp -parseRequestLine0 :: ByteString - -> IO (H.Method - ,ByteString -- Path - ,ByteString -- Query - ,H.HttpVersion) +parseRequestLine0 + :: ByteString + -> IO + ( H.Method + , ByteString -- Path + , ByteString -- Query + , H.HttpVersion + ) parseRequestLine0 s = - case filter (not . S.null) $ S.splitWith (\c -> c == 32 || c == 9) s of -- ' - (method':query:http'') -> do + case filter (not . S.null) $ S.splitWith (\c -> c == 32 || c == 9) s of -- ' + (method' : query : http'') -> do let !method = method' !http' = S.concat http'' (!hfirst, !hsecond) = S.splitAt 5 http' if hfirst == "HTTP/" - then let (!rpath, !qstring) = S.break (== 63) query -- '?' + then + let (!rpath, !qstring) = S.break (== 63) query -- '?' !hv = case hsecond of "1.1" -> H.http11 _ -> H.http10 - in return $! (method, rpath, qstring, hv) - else throwIO NonHttp + in return $! (method, rpath, qstring, hv) + else throwIO NonHttp _ -> throwIO $ BadFirstLine $ B.unpack s diff --git a/warp/test/ConduitSpec.hs b/warp/test/ConduitSpec.hs index 0e9929034..76a02fe58 100644 --- a/warp/test/ConduitSpec.hs +++ b/warp/test/ConduitSpec.hs @@ -1,12 +1,13 @@ {-# LANGUAGE OverloadedStrings #-} + module ConduitSpec (main, spec) where +import Control.Monad (replicateM) +import qualified Data.ByteString as S +import Data.IORef as I import Network.Wai.Handler.Warp.Conduit import Network.Wai.Handler.Warp.Types -import Control.Monad (replicateM) import Test.Hspec -import Data.IORef as I -import qualified Data.ByteString as S main :: IO () main = hspec spec @@ -14,23 +15,23 @@ main = hspec spec spec :: Spec spec = describe "conduit" $ do it "IsolatedBSSource" $ do - ref <- newIORef $ map S.singleton [1..50] + ref <- newIORef $ map S.singleton [1 .. 50] src <- mkSource $ do x <- readIORef ref case x of [] -> return S.empty - y:z -> do + y : z -> do writeIORef ref z return y isrc <- mkISource src 40 x <- replicateM 20 $ readISource isrc - S.concat x `shouldBe` S.pack [1..20] + S.concat x `shouldBe` S.pack [1 .. 20] y <- replicateM 40 $ readISource isrc - S.concat y `shouldBe` S.pack [21..40] + S.concat y `shouldBe` S.pack [21 .. 40] z <- replicateM 40 $ readSource src - S.concat z `shouldBe` S.pack [41..50] + S.concat z `shouldBe` S.pack [41 .. 50] it "chunkedSource" $ do ref <- newIORef "5\r\n12345\r\n3\r\n678\r\n0\r\n\r\nBLAH" src <- mkSource $ do @@ -45,17 +46,18 @@ spec = describe "conduit" $ do y <- replicateM 15 $ readSource src S.concat y `shouldBe` "BLAH" it "chunk boundaries" $ do - ref <- newIORef - [ "5\r\n" - , "12345\r\n3\r" - , "\n678\r\n0\r\n" - , "\r\nBLAH" - ] + ref <- + newIORef + [ "5\r\n" + , "12345\r\n3\r" + , "\n678\r\n0\r\n" + , "\r\nBLAH" + ] src <- mkSource $ do x <- readIORef ref case x of [] -> return S.empty - y:z -> do + y : z -> do writeIORef ref z return y csrc <- mkCSource src diff --git a/warp/test/ExceptionSpec.hs b/warp/test/ExceptionSpec.hs index 13cd08f15..947806bae 100644 --- a/warp/test/ExceptionSpec.hs +++ b/warp/test/ExceptionSpec.hs @@ -1,4 +1,5 @@ -{-# LANGUAGE OverloadedStrings, CPP #-} +{-# LANGUAGE CPP #-} +{-# LANGUAGE OverloadedStrings #-} module ExceptionSpec (main, spec) where @@ -6,15 +7,15 @@ module ExceptionSpec (main, spec) where import Control.Applicative #endif import Control.Monad +import qualified Data.Streaming.Network as N import Network.HTTP.Types hiding (Header) +import Network.Socket (close) import Network.Wai hiding (Response, responseStatus) -import Network.Wai.Internal (Request(..)) import Network.Wai.Handler.Warp +import Network.Wai.Internal (Request (..)) import Test.Hspec -import UnliftIO.Exception -import qualified Data.Streaming.Network as N import UnliftIO.Async (withAsync) -import Network.Socket (close) +import UnliftIO.Exception import HTTP @@ -26,11 +27,11 @@ withTestServer inner = bracket (N.bindRandomPortTCP "127.0.0.1") (close . snd) $ \(prt, lsocket) -> do - withAsync (runSettingsSocket defaultSettings lsocket testApp) - $ \_ -> inner prt + withAsync (runSettingsSocket defaultSettings lsocket testApp) $ + \_ -> inner prt testApp :: Application -testApp (Network.Wai.Internal.Request {pathInfo = [x]}) f +testApp (Network.Wai.Internal.Request{pathInfo = [x]}) f | x == "statusError" = f $ responseLBS undefined [] "foo" | x == "headersError" = @@ -43,24 +44,25 @@ testApp (Network.Wai.Internal.Request {pathInfo = [x]}) f void $ fail "ioException" f $ responseLBS ok200 [] "foo" testApp _ f = - f $ responseLBS ok200 [] "foo" + f $ responseLBS ok200 [] "foo" spec :: Spec spec = describe "responds even if there is an exception" $ do - {- Disabling these tests. We can consider forcing evaluation in Warp. - it "statusError" $ do - sc <- responseStatus <$> sendGET "http://127.0.0.1:2345/statusError" - sc `shouldBe` internalServerError500 - it "headersError" $ do - sc <- responseStatus <$> sendGET "http://127.0.0.1:2345/headersError" - sc `shouldBe` internalServerError500 - it "headerError" $ do - sc <- responseStatus <$> sendGET "http://127.0.0.1:2345/headerError" - sc `shouldBe` internalServerError500 - it "bodyError" $ do - sc <- responseStatus <$> sendGET "http://127.0.0.1:2345/bodyError" - sc `shouldBe` internalServerError500 - -} - it "ioException" $ withTestServer $ \prt -> do - sc <- responseStatus <$> sendGET ("http://127.0.0.1:" ++ show prt ++ "/ioException") - sc `shouldBe` internalServerError500 + {- Disabling these tests. We can consider forcing evaluation in Warp. + it "statusError" $ do + sc <- responseStatus <$> sendGET "http://127.0.0.1:2345/statusError" + sc `shouldBe` internalServerError500 + it "headersError" $ do + sc <- responseStatus <$> sendGET "http://127.0.0.1:2345/headersError" + sc `shouldBe` internalServerError500 + it "headerError" $ do + sc <- responseStatus <$> sendGET "http://127.0.0.1:2345/headerError" + sc `shouldBe` internalServerError500 + it "bodyError" $ do + sc <- responseStatus <$> sendGET "http://127.0.0.1:2345/bodyError" + sc `shouldBe` internalServerError500 + -} + it "ioException" $ withTestServer $ \prt -> do + sc <- + responseStatus <$> sendGET ("http://127.0.0.1:" ++ show prt ++ "/ioException") + sc `shouldBe` internalServerError500 diff --git a/warp/test/FileSpec.hs b/warp/test/FileSpec.hs index cd088ded0..bae10b4e6 100644 --- a/warp/test/FileSpec.hs +++ b/warp/test/FileSpec.hs @@ -15,7 +15,8 @@ import Test.Hspec main :: IO () main = hspec spec -changeHeaders :: (ResponseHeaders -> ResponseHeaders) -> RspFileInfo -> RspFileInfo +changeHeaders + :: (ResponseHeaders -> ResponseHeaders) -> RspFileInfo -> RspFileInfo changeHeaders f rfi = case rfi of WithBody s hs off len -> WithBody s (f hs) off len @@ -27,10 +28,11 @@ getHeaders rfi = WithBody _ hs _ _ -> hs _ -> [] -testFileRange :: String - -> RequestHeaders - -> RspFileInfo - -> Spec +testFileRange + :: String + -> RequestHeaders + -> RspFileInfo + -> Spec testFileRange desc reqhs ans = it desc $ do finfo <- getInfo "attic/hex" let f = (:) ("Last-Modified", fileInfoDate finfo) @@ -41,21 +43,25 @@ testFileRange desc reqhs ans = it desc $ do [] methodGet (indexResponseHeader hs) - (indexRequestHeader reqhs) `shouldBe` ans' + (indexRequestHeader reqhs) + `shouldBe` ans' farPast, farFuture :: ByteString farPast = "Thu, 01 Jan 1970 00:00:00 GMT" farFuture = "Sun, 05 Oct 3000 00:00:00 GMT" regularBody :: RspFileInfo -regularBody = WithBody ok200 [("Content-Length","16"),("Accept-Ranges","bytes")] 0 16 +regularBody = WithBody ok200 [("Content-Length", "16"), ("Accept-Ranges", "bytes")] 0 16 make206Body :: Integer -> Integer -> RspFileInfo make206Body start len = - WithBody status206 [crHeader, lenHeader, ("Accept-Ranges","bytes")] start len + WithBody status206 [crHeader, lenHeader, ("Accept-Ranges", "bytes")] start len where lenHeader = ("Content-Length", fromString $ show len) - crHeader = ("Content-Range", fromString $ "bytes " <> show start <> "-" <> show (start + len - 1) <> "/16") + crHeader = + ( "Content-Range" + , fromString $ "bytes " <> show start <> "-" <> show (start + len - 1) <> "/16" + ) spec :: Spec spec = do @@ -66,15 +72,15 @@ spec = do regularBody testFileRange "gets a file size from file system and handles Range and returns Partical Content" - [("Range","bytes=2-14")] + [("Range", "bytes=2-14")] $ make206Body 2 13 testFileRange "truncates end point of range to file size" - [("Range","bytes=10-20")] + [("Range", "bytes=10-20")] $ make206Body 10 6 testFileRange "gets a file size from file system and handles Range and returns OK if Range means the entire" - [("Range:","bytes=0-15")] + [("Range:", "bytes=0-15")] regularBody testFileRange "returns a 412 if the file has been changed in the meantime" @@ -90,7 +96,10 @@ spec = do regularBody testFileRange "still gives only a range, even after conditionals" - [("If-Match", "SomeETag"), ("If-Unmodified-Since", farPast), ("Range","bytes=10-20")] + [ ("If-Match", "SomeETag") + , ("If-Unmodified-Since", farPast) + , ("Range", "bytes=10-20") + ] $ make206Body 10 6 testFileRange "gets a file if the file has been changed in the meantime" @@ -106,13 +115,18 @@ spec = do regularBody testFileRange "still gives only a range, even after conditionals" - [("If-None-Match", "SomeETag"), ("If-Modified-Since", farFuture), ("Range","bytes=10-13")] + [ ("If-None-Match", "SomeETag") + , ("If-Modified-Since", farFuture) + , ("Range", "bytes=10-13") + ] $ make206Body 10 4 testFileRange "gives the a range, if the condition is met" - [("If-Range", fileInfoDate (unsafePerformIO $ getInfo "attic/hex")), ("Range","bytes=2-7")] + [ ("If-Range", fileInfoDate (unsafePerformIO $ getInfo "attic/hex")) + , ("Range", "bytes=2-7") + ] $ make206Body 2 6 testFileRange "gives the entire body and ignores the Range header if the condition isn't met" - [("If-Range", farPast), ("Range","bytes=2-7")] + [("If-Range", farPast), ("Range", "bytes=2-7")] regularBody diff --git a/warp/test/HTTP.hs b/warp/test/HTTP.hs index 0a591633b..09ffaf0a7 100644 --- a/warp/test/HTTP.hs +++ b/warp/test/HTTP.hs @@ -1,19 +1,19 @@ module HTTP ( - sendGET - , sendGETwH - , sendHEAD - , sendHEADwH - , responseBody - , responseStatus - , responseHeaders - , getHeaderValue - , HeaderName - ) where + sendGET, + sendGETwH, + sendHEAD, + sendHEADwH, + responseBody, + responseStatus, + responseHeaders, + getHeaderValue, + HeaderName, +) where -import Network.HTTP.Client -import Network.HTTP.Types import Data.ByteString import qualified Data.ByteString.Lazy as BL +import Network.HTTP.Client +import Network.HTTP.Types sendGET :: String -> IO (Response BL.ByteString) sendGET url = sendGETwH url [] @@ -22,7 +22,7 @@ sendGETwH :: String -> [Header] -> IO (Response BL.ByteString) sendGETwH url hdr = do manager <- newManager defaultManagerSettings request <- parseRequest url - let request' = request { requestHeaders = hdr } + let request' = request{requestHeaders = hdr} response <- httpLbs request' manager return response @@ -33,7 +33,7 @@ sendHEADwH :: String -> [Header] -> IO (Response BL.ByteString) sendHEADwH url hdr = do manager <- newManager defaultManagerSettings request <- parseRequest url - let request' = request { requestHeaders = hdr, method = methodHead } + let request' = request{requestHeaders = hdr, method = methodHead} response <- httpLbs request' manager return response diff --git a/warp/test/ReadIntSpec.hs b/warp/test/ReadIntSpec.hs index 59fad5dc3..a9dfb3f37 100644 --- a/warp/test/ReadIntSpec.hs +++ b/warp/test/ReadIntSpec.hs @@ -1,9 +1,9 @@ module ReadIntSpec (main, spec) where import Data.ByteString (ByteString) -import Test.Hspec -import Network.Wai.Handler.Warp.ReadInt import qualified Data.ByteString.Char8 as B +import Network.Wai.Handler.Warp.ReadInt +import Test.Hspec import qualified Test.QuickCheck as QC main :: IO () @@ -11,7 +11,8 @@ main = hspec spec spec :: Spec spec = describe "readInt64" $ do - it "converts ByteString to Int" $ QC.property (prop_read_show_idempotent readInt64) + it "converts ByteString to Int" $ + QC.property (prop_read_show_idempotent readInt64) -- A QuickCheck property. Test that for a number >= 0, converting it to -- a string using show and then reading the value back with the function @@ -19,7 +20,8 @@ spec = describe "readInt64" $ do -- The functions under test only work on Natural numbers (the Conent-Length -- field in a HTTP header is always >= 0) so we check the absolute value of -- the value that QuickCheck generates for us. -prop_read_show_idempotent :: (Integral a, Show a) => (ByteString -> a) -> a -> Bool +prop_read_show_idempotent + :: (Integral a, Show a) => (ByteString -> a) -> a -> Bool prop_read_show_idempotent freader x = px == freader (toByteString px) where px = abs x diff --git a/warp/test/RequestSpec.hs b/warp/test/RequestSpec.hs index 96724d4d6..f3e6c7522 100644 --- a/warp/test/RequestSpec.hs +++ b/warp/test/RequestSpec.hs @@ -1,18 +1,21 @@ -{-# OPTIONS_GHC -fno-warn-orphans #-} {-# LANGUAGE OverloadedStrings #-} +{-# OPTIONS_GHC -fno-warn-orphans #-} module RequestSpec (main, spec) where +import qualified Data.ByteString as S +import qualified Data.ByteString.Char8 as S8 +import Data.IORef +import qualified Network.HTTP.Types.Header as HH import Network.Wai.Handler.Warp.File (parseByteRanges) import Network.Wai.Handler.Warp.Request +import Network.Wai.Handler.Warp.Settings ( + defaultSettings, + settingsMaxTotalHeaderLength, + ) import Network.Wai.Handler.Warp.Types -import Network.Wai.Handler.Warp.Settings (settingsMaxTotalHeaderLength, defaultSettings) import Test.Hspec import Test.Hspec.QuickCheck -import qualified Data.ByteString as S -import qualified Data.ByteString.Char8 as S8 -import qualified Network.HTTP.Types.Header as HH -import Data.IORef main :: IO () main = hspec spec @@ -22,86 +25,95 @@ defaultMaxTotalHeaderLength = settingsMaxTotalHeaderLength defaultSettings spec :: Spec spec = do - describe "headerLines" $ do - it "takes until blank" $ - blankSafe `shouldReturn` ("", ["foo", "bar", "baz"]) - it "ignored leading whitespace in bodies" $ - whiteSafe `shouldReturn` (" hi there", ["foo", "bar", "baz"]) - it "throws OverLargeHeader when too many" $ - tooMany `shouldThrow` overLargeHeader - it "throws OverLargeHeader when too large" $ - tooLarge `shouldThrow` overLargeHeader - it "known bad chunking behavior #239" $ do - let chunks = - [ "GET / HTTP/1.1\r\nConnection: Close\r" - , "\n\r\n" - ] - (actual, src) <- headerLinesList' chunks - leftover <- readLeftoverSource src - leftover `shouldBe` S.empty - actual `shouldBe` ["GET / HTTP/1.1", "Connection: Close"] - prop "random chunking" $ \breaks extraS -> do - let bsFull = "GET / HTTP/1.1\r\nConnection: Close\r\n\r\n" `S8.append` extra - extra = S8.pack extraS - chunks = loop breaks bsFull - loop [] bs = [bs, undefined] - loop (x:xs) bs = - bs1 : loop xs bs2 - where - (bs1, bs2) = S8.splitAt ((x `mod` 10) + 1) bs - (actual, src) <- headerLinesList' chunks - leftover <- consumeLen (length extraS) src - - actual `shouldBe` ["GET / HTTP/1.1", "Connection: Close"] - leftover `shouldBe` extra - describe "parseByteRanges" $ do - let test x y = it x $ parseByteRanges (S8.pack x) `shouldBe` y - test "bytes=0-499" $ Just [HH.ByteRangeFromTo 0 499] - test "bytes=500-999" $ Just [HH.ByteRangeFromTo 500 999] - test "bytes=-500" $ Just [HH.ByteRangeSuffix 500] - test "bytes=9500-" $ Just [HH.ByteRangeFrom 9500] - test "foobytes=9500-" Nothing - test "bytes=0-0,-1" $ Just [HH.ByteRangeFromTo 0 0, HH.ByteRangeSuffix 1] - - describe "headerLines" $ do - it "can handle a normal case" $ do - src <- mkSourceFunc ["Status: 200\r\nContent-Type: text/plain\r\n\r\n"] >>= mkSource - x <- headerLines defaultMaxTotalHeaderLength True src - x `shouldBe` ["Status: 200", "Content-Type: text/plain"] - - it "can handle a nasty case (1)" $ do - src <- mkSourceFunc ["Status: 200", "\r\nContent-Type: text/plain", "\r\n\r\n"] >>= mkSource - x <- headerLines defaultMaxTotalHeaderLength True src - x `shouldBe` ["Status: 200", "Content-Type: text/plain"] - - it "can handle a nasty case (1)" $ do - src <- mkSourceFunc ["Status: 200", "\r", "\nContent-Type: text/plain", "\r", "\n\r\n"] >>= mkSource - x <- headerLines defaultMaxTotalHeaderLength True src - x `shouldBe` ["Status: 200", "Content-Type: text/plain"] - - it "can handle a nasty case (1)" $ do - src <- mkSourceFunc ["Status: 200", "\r", "\n", "Content-Type: text/plain", "\r", "\n", "\r", "\n"] >>= mkSource - x <- headerLines defaultMaxTotalHeaderLength True src - x `shouldBe` ["Status: 200", "Content-Type: text/plain"] - - it "can handle an illegal case (1)" $ do - src <- mkSourceFunc ["\nStatus:", "\n 200", "\nContent-Type: text/plain", "\r\n\r\n"] >>= mkSource - x <- headerLines defaultMaxTotalHeaderLength True src - x `shouldBe` [] - y <- headerLines defaultMaxTotalHeaderLength True src - y `shouldBe` ["Status: 200", "Content-Type: text/plain"] - - -- Length is 39, this shouldn't fail - let testLengthHeaders = ["Sta", "tus: 200\r", "\n", "Content-Type: ", "text/plain\r\n\r\n"] - it "doesn't throw on correct length" $ do - src <- mkSourceFunc testLengthHeaders >>= mkSource - x <- headerLines 39 True src - x `shouldBe` ["Status: 200", "Content-Type: text/plain"] - -- Length is still 39, this should fail - it "throws error on correct length too long" $ do - src <- mkSourceFunc testLengthHeaders >>= mkSource - headerLines 38 True src `shouldThrow` (== OverLargeHeader) - + describe "headerLines" $ do + it "takes until blank" $ + blankSafe `shouldReturn` ("", ["foo", "bar", "baz"]) + it "ignored leading whitespace in bodies" $ + whiteSafe `shouldReturn` (" hi there", ["foo", "bar", "baz"]) + it "throws OverLargeHeader when too many" $ + tooMany `shouldThrow` overLargeHeader + it "throws OverLargeHeader when too large" $ + tooLarge `shouldThrow` overLargeHeader + it "known bad chunking behavior #239" $ do + let chunks = + [ "GET / HTTP/1.1\r\nConnection: Close\r" + , "\n\r\n" + ] + (actual, src) <- headerLinesList' chunks + leftover <- readLeftoverSource src + leftover `shouldBe` S.empty + actual `shouldBe` ["GET / HTTP/1.1", "Connection: Close"] + prop "random chunking" $ \breaks extraS -> do + let bsFull = "GET / HTTP/1.1\r\nConnection: Close\r\n\r\n" `S8.append` extra + extra = S8.pack extraS + chunks = loop breaks bsFull + loop [] bs = [bs, undefined] + loop (x : xs) bs = + bs1 : loop xs bs2 + where + (bs1, bs2) = S8.splitAt ((x `mod` 10) + 1) bs + (actual, src) <- headerLinesList' chunks + leftover <- consumeLen (length extraS) src + + actual `shouldBe` ["GET / HTTP/1.1", "Connection: Close"] + leftover `shouldBe` extra + describe "parseByteRanges" $ do + let test x y = it x $ parseByteRanges (S8.pack x) `shouldBe` y + test "bytes=0-499" $ Just [HH.ByteRangeFromTo 0 499] + test "bytes=500-999" $ Just [HH.ByteRangeFromTo 500 999] + test "bytes=-500" $ Just [HH.ByteRangeSuffix 500] + test "bytes=9500-" $ Just [HH.ByteRangeFrom 9500] + test "foobytes=9500-" Nothing + test "bytes=0-0,-1" $ Just [HH.ByteRangeFromTo 0 0, HH.ByteRangeSuffix 1] + + describe "headerLines" $ do + it "can handle a normal case" $ do + src <- + mkSourceFunc ["Status: 200\r\nContent-Type: text/plain\r\n\r\n"] >>= mkSource + x <- headerLines defaultMaxTotalHeaderLength True src + x `shouldBe` ["Status: 200", "Content-Type: text/plain"] + + it "can handle a nasty case (1)" $ do + src <- + mkSourceFunc ["Status: 200", "\r\nContent-Type: text/plain", "\r\n\r\n"] + >>= mkSource + x <- headerLines defaultMaxTotalHeaderLength True src + x `shouldBe` ["Status: 200", "Content-Type: text/plain"] + + it "can handle a nasty case (1)" $ do + src <- + mkSourceFunc ["Status: 200", "\r", "\nContent-Type: text/plain", "\r", "\n\r\n"] + >>= mkSource + x <- headerLines defaultMaxTotalHeaderLength True src + x `shouldBe` ["Status: 200", "Content-Type: text/plain"] + + it "can handle a nasty case (1)" $ do + src <- + mkSourceFunc + ["Status: 200", "\r", "\n", "Content-Type: text/plain", "\r", "\n", "\r", "\n"] + >>= mkSource + x <- headerLines defaultMaxTotalHeaderLength True src + x `shouldBe` ["Status: 200", "Content-Type: text/plain"] + + it "can handle an illegal case (1)" $ do + src <- + mkSourceFunc ["\nStatus:", "\n 200", "\nContent-Type: text/plain", "\r\n\r\n"] + >>= mkSource + x <- headerLines defaultMaxTotalHeaderLength True src + x `shouldBe` [] + y <- headerLines defaultMaxTotalHeaderLength True src + y `shouldBe` ["Status: 200", "Content-Type: text/plain"] + + -- Length is 39, this shouldn't fail + let testLengthHeaders = ["Sta", "tus: 200\r", "\n", "Content-Type: ", "text/plain\r\n\r\n"] + it "doesn't throw on correct length" $ do + src <- mkSourceFunc testLengthHeaders >>= mkSource + x <- headerLines 39 True src + x `shouldBe` ["Status: 200", "Content-Type: text/plain"] + -- Length is still 39, this should fail + it "throws error on correct length too long" $ do + src <- mkSourceFunc testLengthHeaders >>= mkSource + headerLines 38 True src `shouldThrow` (== OverLargeHeader) where blankSafe = headerLinesList ["f", "oo\n", "bar\nbaz\n\r\n"] whiteSafe = headerLinesList ["foo\r\nbar\r\nbaz\r\n\r\n hi there"] @@ -121,7 +133,7 @@ headerLinesList' orig = do x <- readIORef ref case x of [] -> return S.empty - y:z -> do + y : z -> do writeIORef ref z return y src' <- mkSource src @@ -140,7 +152,7 @@ consumeLen len0 src = then loop front 0 else do let (x, _) = S.splitAt len bs - loop (front . (x:)) (len - S.length x) + loop (front . (x :)) (len - S.length x) overLargeHeader :: Selector InvalidRequest overLargeHeader e = e == OverLargeHeader @@ -153,7 +165,7 @@ mkSourceFunc bss = do reader ref = do xss <- readIORef ref case xss of - [] -> return S.empty - (x:xs) -> do + [] -> return S.empty + (x : xs) -> do writeIORef ref xs return x diff --git a/warp/test/ResponseHeaderSpec.hs b/warp/test/ResponseHeaderSpec.hs index 361128656..d8ca6df9c 100644 --- a/warp/test/ResponseHeaderSpec.hs +++ b/warp/test/ResponseHeaderSpec.hs @@ -4,9 +4,9 @@ module ResponseHeaderSpec (main, spec) where import Data.ByteString import qualified Network.HTTP.Types as H -import Network.Wai.Handler.Warp.ResponseHeader -import Network.Wai.Handler.Warp.Response import Network.Wai.Handler.Warp.Header +import Network.Wai.Handler.Warp.Response +import Network.Wai.Handler.Warp.ResponseHeader import Test.Hspec main :: IO () @@ -21,9 +21,9 @@ spec = do it "adds Server if not exist" $ do let hdrs = [] rspidxhdr = indexResponseHeader hdrs - addServer "MyServer" rspidxhdr hdrs `shouldBe` [("Server","MyServer")] + addServer "MyServer" rspidxhdr hdrs `shouldBe` [("Server", "MyServer")] it "does not add Server if exists" $ do - let hdrs = [("Server","MyServer")] + let hdrs = [("Server", "MyServer")] rspidxhdr = indexResponseHeader hdrs addServer "MyServer2" rspidxhdr hdrs `shouldBe` hdrs it "does not add Server if empty" $ do @@ -31,18 +31,19 @@ spec = do rspidxhdr = indexResponseHeader hdrs addServer "" rspidxhdr hdrs `shouldBe` hdrs it "deletes Server " $ do - let hdrs = [("Server","MyServer")] + let hdrs = [("Server", "MyServer")] rspidxhdr = indexResponseHeader hdrs addServer "" rspidxhdr hdrs `shouldBe` [] headers :: H.ResponseHeaders -headers = [ - ("Date", "Mon, 13 Aug 2012 04:22:55 GMT") - , ("Content-Length", "151") - , ("Server", "Mighttpd/2.5.8") - , ("Last-Modified", "Fri, 22 Jun 2012 01:18:08 GMT") - , ("Content-Type", "text/html") - ] +headers = + [ ("Date", "Mon, 13 Aug 2012 04:22:55 GMT") + , ("Content-Length", "151") + , ("Server", "Mighttpd/2.5.8") + , ("Last-Modified", "Fri, 22 Jun 2012 01:18:08 GMT") + , ("Content-Type", "text/html") + ] composedHeader :: ByteString -composedHeader = "HTTP/1.1 200 OK\r\nDate: Mon, 13 Aug 2012 04:22:55 GMT\r\nContent-Length: 151\r\nServer: Mighttpd/2.5.8\r\nLast-Modified: Fri, 22 Jun 2012 01:18:08 GMT\r\nContent-Type: text/html\r\n\r\n" +composedHeader = + "HTTP/1.1 200 OK\r\nDate: Mon, 13 Aug 2012 04:22:55 GMT\r\nContent-Length: 151\r\nServer: Mighttpd/2.5.8\r\nLast-Modified: Fri, 22 Jun 2012 01:18:08 GMT\r\nContent-Type: text/html\r\n\r\n" diff --git a/warp/test/ResponseSpec.hs b/warp/test/ResponseSpec.hs index ddc12c77e..75569262f 100644 --- a/warp/test/ResponseSpec.hs +++ b/warp/test/ResponseSpec.hs @@ -10,16 +10,20 @@ import Network.HTTP.Types import Network.Wai hiding (responseHeaders) import Network.Wai.Handler.Warp import Network.Wai.Handler.Warp.Response -import RunSpec (withApp, msWrite, msRead, withMySocket) +import RunSpec (msRead, msWrite, withApp, withMySocket) import Test.Hspec main :: IO () main = hspec spec -testRange :: S.ByteString -- ^ range value - -> String -- ^ expected output - -> Maybe String -- ^ expected content-range value - -> Spec +testRange + :: S.ByteString + -- ^ range value + -> String + -- ^ expected output + -> Maybe String + -- ^ expected content-range value + -> Spec testRange range out crange = it title $ withApp defaultSettings app $ withMySocket $ \ms -> do msWrite ms "GET / HTTP/1.0\r\n" msWrite ms "Range: bytes=" @@ -36,14 +40,19 @@ testRange range out crange = it title $ withApp defaultSettings app $ withMySock title = show (range, out, crange) toHeader s = case break (== ':') s of - (x, ':':y) -> Just (x, dropWhile (== ' ') y) + (x, ':' : y) -> Just (x, dropWhile (== ' ') y) _ -> Nothing -testPartial :: Integer -- ^ file size - -> Integer -- ^ offset - -> Integer -- ^ byte count - -> String -- ^ expected output - -> Spec +testPartial + :: Integer + -- ^ file size + -> Integer + -- ^ offset + -> Integer + -- ^ byte count + -> String + -- ^ expected output + -> Spec testPartial size offset count out = it title $ withApp defaultSettings app $ withMySocket $ \ms -> do msWrite ms "GET / HTTP/1.0\r\n\r\n" threadDelay 10000 @@ -57,21 +66,22 @@ testPartial size offset count out = it title $ withApp defaultSettings app $ wit title = show (offset, count, out) toHeader s = case break (== ':') s of - (x, ':':y) -> Just (x, dropWhile (== ' ') y) + (x, ':' : y) -> Just (x, dropWhile (== ' ') y) _ -> Nothing - range = "bytes " ++ show offset ++ "-" ++ show (offset + count - 1) ++ "/" ++ show size + range = + "bytes " ++ show offset ++ "-" ++ show (offset + count - 1) ++ "/" ++ show size spec :: Spec spec = do -{- http-client does not support this. - describe "preventing response splitting attack" $ do - it "sanitizes header values" $ do - let app _ respond = respond $ responseLBS status200 [("foo", "foo\r\nbar")] "Hello" - withApp defaultSettings app $ \port -> do - res <- sendGET $ "http://127.0.0.1:" ++ show port - getHeaderValue "foo" (responseHeaders res) `shouldBe` - Just "foo bar" -- HTTP inserts two spaces for \r\n. --} + {- http-client does not support this. + describe "preventing response splitting attack" $ do + it "sanitizes header values" $ do + let app _ respond = respond $ responseLBS status200 [("foo", "foo\r\nbar")] "Hello" + withApp defaultSettings app $ \port -> do + res <- sendGET $ "http://127.0.0.1:" ++ show port + getHeaderValue "foo" (responseHeaders res) `shouldBe` + Just "foo bar" -- HTTP inserts two spaces for \r\n. + -} describe "sanitizeHeaderValue" $ do it "doesn't alter valid multiline header values" $ do diff --git a/warp/test/RunSpec.hs b/warp/test/RunSpec.hs index 5aab5746a..9e93f87d2 100644 --- a/warp/test/RunSpec.hs +++ b/warp/test/RunSpec.hs @@ -4,10 +4,8 @@ module RunSpec (main, spec, withApp, MySocket, msWrite, msRead, withMySocket) where import Control.Concurrent (forkIO, killThread, threadDelay) -import Control.Concurrent.MVar (newEmptyMVar, takeMVar, putMVar) +import Control.Concurrent.MVar (newEmptyMVar, putMVar, takeMVar) import Control.Concurrent.STM -import qualified UnliftIO.Exception as E -import UnliftIO.Exception (bracket, try, IOException, onException) import Control.Monad (forM_, replicateM_, unless) import Control.Monad.IO.Class (MonadIO, liftIO) import Data.ByteString (ByteString) @@ -25,6 +23,8 @@ import Network.Wai.Handler.Warp import System.IO.Unsafe (unsafePerformIO) import System.Timeout (timeout) import Test.Hspec +import UnliftIO.Exception (IOException, bracket, onException, try) +import qualified UnliftIO.Exception as E import HTTP @@ -35,35 +35,35 @@ type Counter = I.IORef (Either String Int) type CounterApplication = Counter -> Application data MySocket = MySocket - { msSocket :: !Socket - , msBuffer :: !(I.IORef ByteString) - } + { msSocket :: !Socket + , msBuffer :: !(I.IORef ByteString) + } msWrite :: MySocket -> ByteString -> IO () msWrite = sendAll . msSocket msRead :: MySocket -> Int -> IO ByteString msRead (MySocket s ref) expected = do - bs <- I.readIORef ref - inner (bs:) (S.length bs) + bs <- I.readIORef ref + inner (bs :) (S.length bs) where inner front total = - case compare total expected of - EQ -> do - I.writeIORef ref mempty - pure $ S.concat $ front [] - GT -> do - let bs = S.concat $ front [] - (x, y) = S.splitAt expected bs - I.writeIORef ref y - pure x - LT -> do - bs <- safeRecv s 4096 - if S.null bs - then do - I.writeIORef ref mempty - pure $ S.concat $ front [] - else inner (front . (bs:)) (total + S.length bs) + case compare total expected of + EQ -> do + I.writeIORef ref mempty + pure $ S.concat $ front [] + GT -> do + let bs = S.concat $ front [] + (x, y) = S.splitAt expected bs + I.writeIORef ref y + pure x + LT -> do + bs <- safeRecv s 4096 + if S.null bs + then do + I.writeIORef ref mempty + pure $ S.concat $ front [] + else inner (front . (bs :)) (total + S.length bs) msClose :: MySocket -> IO () msClose = Network.Socket.close . msSocket @@ -72,19 +72,22 @@ connectTo :: Int -> IO MySocket connectTo port = do s <- fst <$> getSocketTCP "127.0.0.1" port ref <- I.newIORef mempty - return MySocket { - msSocket = s - , msBuffer = ref - } + return + MySocket + { msSocket = s + , msBuffer = ref + } withMySocket :: (MySocket -> IO a) -> Int -> IO a withMySocket body port = bracket (connectTo port) msClose body incr :: MonadIO m => Counter -> m () incr icount = liftIO $ I.atomicModifyIORef icount $ \ecount -> - (case ecount of + ( case ecount of Left s -> Left s - Right i -> Right $ i + 1, ()) + Right i -> Right $ i + 1 + , () + ) err :: (MonadIO m, Show a) => Counter -> a -> m () err icount msg = liftIO $ I.writeIORef icount $ Left $ show msg @@ -94,12 +97,12 @@ readBody icount req f = do body <- consumeBody $ getRequestBodyChunk req case () of () - | pathInfo req == ["hello"] && L.fromChunks body /= "Hello" - -> err icount ("Invalid hello" :: String, body) - | requestMethod req == "GET" && L.fromChunks body /= "" - -> err icount ("Invalid GET" :: String, body) - | requestMethod req `notElem` ["GET", "POST"] - -> err icount ("Invalid request method (readBody)" :: String, requestMethod req) + | pathInfo req == ["hello"] && L.fromChunks body /= "Hello" -> + err icount ("Invalid hello" :: String, body) + | requestMethod req == "GET" && L.fromChunks body /= "" -> + err icount ("Invalid GET" :: String, body) + | requestMethod req `notElem` ["GET", "POST"] -> + err icount ("Invalid request method (readBody)" :: String, requestMethod req) | otherwise -> incr icount f $ responseLBS status200 [] "Read the body" @@ -135,25 +138,31 @@ withApp :: Settings -> Application -> (Int -> IO a) -> IO a withApp settings app f = do port <- RunSpec.getPort baton <- newEmptyMVar - let settings' = setPort port - $ setHost "127.0.0.1" - $ setBeforeMainLoop (putMVar baton ()) - settings + let settings' = + setPort port $ + setHost "127.0.0.1" $ + setBeforeMainLoop + (putMVar baton ()) + settings bracket (forkIO $ runSettings settings' app `onException` putMVar baton ()) killThread - (const $ do + ( const $ do takeMVar baton -- use timeout to make sure we don't take too long mres <- timeout (60 * 1000 * 1000) (f port) case mres of - Nothing -> error "Timeout triggered, too slow!" - Just a -> pure a) - -runTest :: Int -- ^ expected number of requests - -> CounterApplication - -> [ByteString] -- ^ chunks to send - -> IO () + Nothing -> error "Timeout triggered, too slow!" + Just a -> pure a + ) + +runTest + :: Int + -- ^ expected number of requests + -> CounterApplication + -> [ByteString] + -- ^ chunks to send + -> IO () runTest expected app chunks = do ref <- I.newIORef (Right 0) withApp defaultSettings (app ref) $ withMySocket $ \ms -> do @@ -167,9 +176,10 @@ runTest expected app chunks = do dummyApp :: Application dummyApp _ f = f $ responseLBS status200 [] "foo" -runTerminateTest :: InvalidRequest - -> ByteString - -> IO () +runTerminateTest + :: InvalidRequest + -> ByteString + -> IO () runTerminateTest expected input = do ref <- I.newIORef Nothing let onExc _ = I.writeIORef ref . Just @@ -197,43 +207,66 @@ spec = do describe "non-pipelining" $ do it "no body, read" $ runTest 5 readBody $ replicate 5 singleGet it "no body, ignore" $ runTest 5 ignoreBody $ replicate 5 singleGet - it "has body, read" $ runTest 2 readBody - [ singlePostHello - , singleGet - ] - it "has body, ignore" $ runTest 2 ignoreBody - [ singlePostHello - , singleGet - ] - it "chunked body, read" $ runTest 2 readBody $ - singleChunkedPostHello ++ [singleGet] - it "chunked body, ignore" $ runTest 2 ignoreBody $ - singleChunkedPostHello ++ [singleGet] + it "has body, read" $ + runTest + 2 + readBody + [ singlePostHello + , singleGet + ] + it "has body, ignore" $ + runTest + 2 + ignoreBody + [ singlePostHello + , singleGet + ] + it "chunked body, read" $ + runTest 2 readBody $ + singleChunkedPostHello ++ [singleGet] + it "chunked body, ignore" $ + runTest 2 ignoreBody $ + singleChunkedPostHello ++ [singleGet] describe "pipelining" $ do it "no body, read" $ runTest 5 readBody [S.concat $ replicate 5 singleGet] it "no body, ignore" $ runTest 5 ignoreBody [S.concat $ replicate 5 singleGet] - it "has body, read" $ runTest 2 readBody $ return $ S.concat - [ singlePostHello - , singleGet - ] - it "has body, ignore" $ runTest 2 ignoreBody $ return $ S.concat - [ singlePostHello - , singleGet - ] - it "chunked body, read" $ runTest 2 readBody $ return $ S.concat - [ S.concat singleChunkedPostHello - , singleGet - ] - it "chunked body, ignore" $ runTest 2 ignoreBody $ return $ S.concat - [ S.concat singleChunkedPostHello - , singleGet - ] + it "has body, read" $ + runTest 2 readBody $ + return $ + S.concat + [ singlePostHello + , singleGet + ] + it "has body, ignore" $ + runTest 2 ignoreBody $ + return $ + S.concat + [ singlePostHello + , singleGet + ] + it "chunked body, read" $ + runTest 2 readBody $ + return $ + S.concat + [ S.concat singleChunkedPostHello + , singleGet + ] + it "chunked body, ignore" $ + runTest 2 ignoreBody $ + return $ + S.concat + [ S.concat singleChunkedPostHello + , singleGet + ] describe "no hanging" $ do - it "has body, read" $ runTest 1 readBody $ map S.singleton $ S.unpack singlePostHello + it "has body, read" $ + runTest 1 readBody $ + map S.singleton $ + S.unpack singlePostHello it "double connect" $ runTest 1 doubleConnect [singlePostHello] describe "connection termination" $ do --- it "ConnectionClosedByPeer" $ runTerminateTest ConnectionClosedByPeer "GET / HTTP/1.1\r\ncontent-length: 10\r\n\r\nhello" + -- it "ConnectionClosedByPeer" $ runTerminateTest ConnectionClosedByPeer "GET / HTTP/1.1\r\ncontent-length: 10\r\n\r\nhello" it "IncompleteHeaders" $ runTerminateTest IncompleteHeaders "GET / HTTP/1.1\r\ncontent-length: 10\r\n" @@ -244,30 +277,32 @@ spec = do liftIO $ I.writeIORef iheaders $ requestHeaders req f $ responseLBS status200 [] "" withApp defaultSettings app $ withMySocket $ \ms -> do - let input = S.concat - [ "GET / HTTP/1.1\r\nfoo: bar\r\n baz\r\n\tbin\r\n\r\n" - ] + let input = + S.concat + [ "GET / HTTP/1.1\r\nfoo: bar\r\n baz\r\n\tbin\r\n\r\n" + ] msWrite ms input threadDelay 5000 headers <- I.readIORef iheaders - headers `shouldBe` - [ ("foo", "bar baz\tbin") - ] + headers + `shouldBe` [ ("foo", "bar baz\tbin") + ] it "no space between colon and value" $ do iheaders <- I.newIORef [] let app req f = do liftIO $ I.writeIORef iheaders $ requestHeaders req f $ responseLBS status200 [] "" withApp defaultSettings app $ withMySocket $ \ms -> do - let input = S.concat - [ "GET / HTTP/1.1\r\nfoo:bar\r\n\r\n" - ] + let input = + S.concat + [ "GET / HTTP/1.1\r\nfoo:bar\r\n\r\n" + ] msWrite ms input threadDelay 5000 headers <- I.readIORef iheaders - headers `shouldBe` - [ ("foo", "bar") - ] + headers + `shouldBe` [ ("foo", "bar") + ] describe "chunked bodies" $ do it "works" $ do @@ -275,81 +310,92 @@ spec = do ifront <- I.newIORef id let app req f = do bss <- consumeBody $ getRequestBodyChunk req - liftIO $ I.atomicModifyIORef ifront $ \front -> (front . (S.concat bss:), ()) + liftIO $ I.atomicModifyIORef ifront $ \front -> (front . (S.concat bss :), ()) atomically $ modifyTVar countVar (+ 1) f $ responseLBS status200 [] "" withApp defaultSettings app $ withMySocket $ \ms -> do - let input = S.concat - [ "POST / HTTP/1.1\r\nTransfer-Encoding: Chunked\r\n\r\n" - , "c\r\nHello World\n\r\n3\r\nBye\r\n0\r\n\r\n" - , "POST / HTTP/1.1\r\nTransfer-Encoding: Chunked\r\n\r\n" - , "b\r\nHello World\r\n0\r\n\r\n" - ] + let input = + S.concat + [ "POST / HTTP/1.1\r\nTransfer-Encoding: Chunked\r\n\r\n" + , "c\r\nHello World\n\r\n3\r\nBye\r\n0\r\n\r\n" + , "POST / HTTP/1.1\r\nTransfer-Encoding: Chunked\r\n\r\n" + , "b\r\nHello World\r\n0\r\n\r\n" + ] msWrite ms input atomically $ do - count <- readTVar countVar - check $ count == 2 + count <- readTVar countVar + check $ count == 2 front <- I.readIORef ifront - front [] `shouldBe` - [ "Hello World\nBye" - , "Hello World" - ] + front [] + `shouldBe` [ "Hello World\nBye" + , "Hello World" + ] it "lots of chunks" $ do ifront <- I.newIORef id countVar <- newTVarIO (0 :: Int) let app req f = do bss <- consumeBody $ getRequestBodyChunk req - I.atomicModifyIORef ifront $ \front -> (front . (S.concat bss:), ()) + I.atomicModifyIORef ifront $ \front -> (front . (S.concat bss :), ()) atomically $ modifyTVar countVar (+ 1) f $ responseLBS status200 [] "" withApp defaultSettings app $ withMySocket $ \ms -> do - let input = concat $ replicate 2 $ - ["POST / HTTP/1.1\r\nTransfer-Encoding: Chunked\r\n\r\n"] ++ - replicate 50 "5\r\n12345\r\n" ++ - ["0\r\n\r\n"] + let input = + concat $ + replicate 2 $ + ["POST / HTTP/1.1\r\nTransfer-Encoding: Chunked\r\n\r\n"] + ++ replicate 50 "5\r\n12345\r\n" + ++ ["0\r\n\r\n"] mapM_ (msWrite ms) input atomically $ do - count <- readTVar countVar - check $ count == 2 + count <- readTVar countVar + check $ count == 2 front <- I.readIORef ifront front [] `shouldBe` replicate 2 (S.concat $ replicate 50 "12345") --- For some reason, the following test on Windows causes the socket --- to be killed prematurely. Worth investigating in the future if possible. + -- For some reason, the following test on Windows causes the socket + -- to be killed prematurely. Worth investigating in the future if possible. it "in chunks" $ do ifront <- I.newIORef id countVar <- newTVarIO (0 :: Int) let app req f = do bss <- consumeBody $ getRequestBodyChunk req - liftIO $ I.atomicModifyIORef ifront $ \front -> (front . (S.concat bss:), ()) + liftIO $ I.atomicModifyIORef ifront $ \front -> (front . (S.concat bss :), ()) atomically $ modifyTVar countVar (+ 1) f $ responseLBS status200 [] "" withApp defaultSettings app $ withMySocket $ \ms -> do - let input = S.concat - [ "POST / HTTP/1.1\r\nTransfer-Encoding: Chunked\r\n\r\n" - , "c\r\nHello World\n\r\n3\r\nBye\r\n0\r\n" - , "POST / HTTP/1.1\r\nTransfer-Encoding: Chunked\r\n\r\n" - , "b\r\nHello World\r\n0\r\n\r\n" - ] + let input = + S.concat + [ "POST / HTTP/1.1\r\nTransfer-Encoding: Chunked\r\n\r\n" + , "c\r\nHello World\n\r\n3\r\nBye\r\n0\r\n" + , "POST / HTTP/1.1\r\nTransfer-Encoding: Chunked\r\n\r\n" + , "b\r\nHello World\r\n0\r\n\r\n" + ] mapM_ (msWrite ms . S.singleton) $ S.unpack input atomically $ do - count <- readTVar countVar - check $ count == 2 + count <- readTVar countVar + check $ count == 2 front <- I.readIORef ifront - front [] `shouldBe` - [ "Hello World\nBye" - , "Hello World" - ] + front [] + `shouldBe` [ "Hello World\nBye" + , "Hello World" + ] it "timeout in request body" $ do ifront <- I.newIORef id let app req f = do - bss <- consumeBody (getRequestBodyChunk req) `onException` - liftIO (I.atomicModifyIORef ifront (\front -> (front . ("consume interrupted":), ()))) - liftIO $ threadDelay 4000000 `E.catch` \e -> do - I.atomicModifyIORef ifront (\front -> - ( front . ((S8.pack $ "threadDelay interrupted: " ++ show e):) - , ())) - E.throwIO (e :: E.SomeException) - liftIO $ I.atomicModifyIORef ifront $ \front -> (front . (S.concat bss:), ()) + bss <- + consumeBody (getRequestBodyChunk req) + `onException` liftIO + (I.atomicModifyIORef ifront (\front -> (front . ("consume interrupted" :), ()))) + liftIO $ + threadDelay 4000000 `E.catch` \e -> do + I.atomicModifyIORef + ifront + ( \front -> + ( front . ((S8.pack $ "threadDelay interrupted: " ++ show e) :) + , () + ) + ) + E.throwIO (e :: E.SomeException) + liftIO $ I.atomicModifyIORef ifront $ \front -> (front . (S.concat bss :), ()) f $ responseLBS status200 [] "" withApp (setTimeout 1 defaultSettings) app $ withMySocket $ \ms -> do let bs1 = S.replicate 2048 88 @@ -384,13 +430,18 @@ spec = do msWrite ms "67890" timeout 100000 (msRead ms 10) `shouldReturn` Just "6677889900" it "only one date and server header" $ do - let app _ f = f $ responseLBS status200 - [ ("server", "server") - , ("date", "date") - ] "" - getValues key = map snd - . filter (\(key', _) -> key == key') - . responseHeaders + let app _ f = + f $ + responseLBS + status200 + [ ("server", "server") + , ("date", "date") + ] + "" + getValues key = + map snd + . filter (\(key', _) -> key == key') + . responseHeaders withApp defaultSettings app $ \port -> do res <- sendGET $ "http://127.0.0.1:" ++ show port getValues hServer res `shouldBe` ["server"] @@ -399,20 +450,20 @@ spec = do it "streaming echo #249" $ do countVar <- newTVarIO (0 :: Int) let app req f = f $ responseStream status200 [] $ \write _ -> do - let loop = do - bs <- getRequestBodyChunk req - unless (S.null bs) $ do - write $ byteString bs - atomically $ modifyTVar countVar (+ 1) - loop - loop + let loop = do + bs <- getRequestBodyChunk req + unless (S.null bs) $ do + write $ byteString bs + atomically $ modifyTVar countVar (+ 1) + loop + loop withApp defaultSettings app $ withMySocket $ \ms -> do msWrite ms "POST / HTTP/1.1\r\ntransfer-encoding: chunked\r\n\r\n" threadDelay 10000 msWrite ms "5\r\nhello\r\n0\r\n\r\n" atomically $ do - count <- readTVar countVar - check $ count >= 1 + count <- readTVar countVar + check $ count >= 1 bs <- safeRecv (msSocket ms) 4096 -- must not use msRead S.takeWhile (/= 13) bs `shouldBe` "HTTP/1.1 200 OK" @@ -441,11 +492,13 @@ spec = do it "file, no range" $ withApp defaultSettings app $ \port -> do bs <- S.readFile fp res <- sendHEAD $ concat ["http://127.0.0.1:", show port, "/file"] - getHeaderValue hContentLength (responseHeaders res) `shouldBe` Just (S8.pack $ show $ S.length bs) + getHeaderValue hContentLength (responseHeaders res) + `shouldBe` Just (S8.pack $ show $ S.length bs) it "file, with range" $ withApp defaultSettings app $ \port -> do - res <- sendHEADwH - (concat ["http://127.0.0.1:", show port, "/file"]) - [(hRange, "bytes=0-1")] + res <- + sendHEADwH + (concat ["http://127.0.0.1:", show port, "/file"]) + [(hRange, "bytes=0-1")] getHeaderValue hContentLength (responseHeaders res) `shouldBe` Just "2" consumeBody :: IO ByteString -> IO [ByteString] @@ -456,4 +509,4 @@ consumeBody body = bs <- body if S.null bs then return $ front [] - else loop $ front . (bs:) + else loop $ front . (bs :) diff --git a/warp/test/SendFileSpec.hs b/warp/test/SendFileSpec.hs index 19fe99221..8b385a54f 100644 --- a/warp/test/SendFileSpec.hs +++ b/warp/test/SendFileSpec.hs @@ -20,40 +20,44 @@ main = hspec spec spec :: Spec spec = do - describe "packHeader" $ do - it "returns how much the buffer is consumed (1)" $ - tryPackHeader 10 ["foo"] `shouldReturn` 3 - it "returns how much the buffer is consumed (2)" $ - tryPackHeader 10 ["foo", "bar"] `shouldReturn` 6 - it "returns how much the buffer is consumed (3)" $ - tryPackHeader 10 ["0123456789"] `shouldReturn` 0 - it "returns how much the buffer is consumed (4)" $ - tryPackHeader 10 ["01234", "56789"] `shouldReturn` 0 - it "returns how much the buffer is consumed (5)" $ - tryPackHeader 10 ["01234567890", "12"] `shouldReturn` 3 - it "returns how much the buffer is consumed (6)" $ - tryPackHeader 10 ["012345678901234567890123456789012", "34"] `shouldReturn` 5 + describe "packHeader" $ do + it "returns how much the buffer is consumed (1)" $ + tryPackHeader 10 ["foo"] `shouldReturn` 3 + it "returns how much the buffer is consumed (2)" $ + tryPackHeader 10 ["foo", "bar"] `shouldReturn` 6 + it "returns how much the buffer is consumed (3)" $ + tryPackHeader 10 ["0123456789"] `shouldReturn` 0 + it "returns how much the buffer is consumed (4)" $ + tryPackHeader 10 ["01234", "56789"] `shouldReturn` 0 + it "returns how much the buffer is consumed (5)" $ + tryPackHeader 10 ["01234567890", "12"] `shouldReturn` 3 + it "returns how much the buffer is consumed (6)" $ + tryPackHeader 10 ["012345678901234567890123456789012", "34"] `shouldReturn` 5 - it "sends headers correctly (1)" $ - tryPackHeader2 10 ["foo"] "" `shouldReturn` True - it "sends headers correctly (2)" $ - tryPackHeader2 10 ["foo", "bar"] "" `shouldReturn` True - it "sends headers correctly (3)" $ - tryPackHeader2 10 ["0123456789"] "0123456789" `shouldReturn` True - it "sends headers correctly (4)" $ - tryPackHeader2 10 ["01234", "56789"] "0123456789" `shouldReturn` True - it "sends headers correctly (5)" $ - tryPackHeader2 10 ["01234567890", "12"] "0123456789" `shouldReturn` True - it "sends headers correctly (6)" $ - tryPackHeader2 10 ["012345678901234567890123456789012", "34"] "012345678901234567890123456789" `shouldReturn` True + it "sends headers correctly (1)" $ + tryPackHeader2 10 ["foo"] "" `shouldReturn` True + it "sends headers correctly (2)" $ + tryPackHeader2 10 ["foo", "bar"] "" `shouldReturn` True + it "sends headers correctly (3)" $ + tryPackHeader2 10 ["0123456789"] "0123456789" `shouldReturn` True + it "sends headers correctly (4)" $ + tryPackHeader2 10 ["01234", "56789"] "0123456789" `shouldReturn` True + it "sends headers correctly (5)" $ + tryPackHeader2 10 ["01234567890", "12"] "0123456789" `shouldReturn` True + it "sends headers correctly (6)" $ + tryPackHeader2 + 10 + ["012345678901234567890123456789012", "34"] + "012345678901234567890123456789" + `shouldReturn` True - describe "readSendFile" $ do - it "sends a file correctly (1)" $ - tryReadSendFile 10 0 1474 ["foo"] `shouldReturn` ExitSuccess - it "sends a file correctly (2)" $ - tryReadSendFile 10 0 1474 ["012345678", "901234"] `shouldReturn` ExitSuccess - it "sends a file correctly (3)" $ - tryReadSendFile 10 20 100 ["012345678", "901234"] `shouldReturn` ExitSuccess + describe "readSendFile" $ do + it "sends a file correctly (1)" $ + tryReadSendFile 10 0 1474 ["foo"] `shouldReturn` ExitSuccess + it "sends a file correctly (2)" $ + tryReadSendFile 10 0 1474 ["012345678", "901234"] `shouldReturn` ExitSuccess + it "sends a file correctly (3)" $ + tryReadSendFile 10 20 100 ["012345678", "901234"] `shouldReturn` ExitSuccess tryPackHeader :: Int -> [ByteString] -> IO Int tryPackHeader siz hdrs = bracket (allocateBuffer siz) freeBuffer $ \buf -> @@ -95,20 +99,20 @@ tryReadSendFile siz off len hdrs = bracket setup teardown $ \buf -> do checkFile :: FilePath -> ByteString -> IO Bool checkFile path bs = do exist <- doesFileExist path - if exist then do - bs' <- BS.readFile path - return $ bs == bs' - else - return $ bs == "" + if exist + then do + bs' <- BS.readFile path + return $ bs == bs' + else return $ bs == "" compareFiles :: FilePath -> FilePath -> IO ExitCode compareFiles file1 file2 = system $ "cmp -s " ++ file1 ++ " " ++ file2 copyfile :: FilePath -> FilePath -> Integer -> Integer -> IO () copyfile src dst off len = - IO.withBinaryFile src IO.ReadMode $ \h -> do - IO.hSeek h IO.AbsoluteSeek off - BS.hGet h (fromIntegral len) >>= BS.appendFile dst + IO.withBinaryFile src IO.ReadMode $ \h -> do + IO.hSeek h IO.AbsoluteSeek off + BS.hGet h (fromIntegral len) >>= BS.appendFile dst removeFileIfExists :: FilePath -> IO () removeFileIfExists file = do diff --git a/warp/test/WithApplicationSpec.hs b/warp/test/WithApplicationSpec.hs index b786223c9..8d534fd6a 100644 --- a/warp/test/WithApplicationSpec.hs +++ b/warp/test/WithApplicationSpec.hs @@ -2,41 +2,45 @@ module WithApplicationSpec where -import Network.HTTP.Types -import Network.Wai -import System.Environment -import System.Process -import Test.Hspec -import UnliftIO.Exception +import Network.HTTP.Types +import Network.Wai +import System.Environment +import System.Process +import Test.Hspec +import UnliftIO.Exception -import Network.Wai.Handler.Warp.WithApplication +import Network.Wai.Handler.Warp.WithApplication -- All these tests assume the "curl" process can be called directly. spec :: Spec spec = do - runIO $ do - unsetEnv "http_proxy" - unsetEnv "https_proxy" - describe "\"curl\" dependency" $ - let msg = "All \"WithApplication\" tests assume the \"curl\" process can be called directly." - underline = replicate (length msg) '^' - in it (msg ++ "\n " ++ underline) True - describe "withApplication" $ do - it "runs a wai Application while executing the given action" $ do - let mkApp = return $ \ _request respond -> respond $ responseLBS ok200 [] "foo" - withApplication mkApp $ \ port -> do - output <- readProcess "curl" ["-s", "localhost:" ++ show port] "" - output `shouldBe` "foo" + runIO $ do + unsetEnv "http_proxy" + unsetEnv "https_proxy" + describe "\"curl\" dependency" $ + let msg = + "All \"WithApplication\" tests assume the \"curl\" process can be called directly." + underline = replicate (length msg) '^' + in it (msg ++ "\n " ++ underline) True + describe "withApplication" $ do + it "runs a wai Application while executing the given action" $ do + let mkApp = return $ \_request respond -> respond $ responseLBS ok200 [] "foo" + withApplication mkApp $ \port -> do + output <- readProcess "curl" ["-s", "localhost:" ++ show port] "" + output `shouldBe` "foo" - it "does not propagate exceptions from the server to the executing thread" $ do - let mkApp = return $ \ _request _respond -> throwString "foo" - withApplication mkApp $ \ port -> do - output <- readProcess "curl" ["-s", "localhost:" ++ show port] "" - output `shouldContain` "Something went wron" + it "does not propagate exceptions from the server to the executing thread" $ do + let mkApp = return $ \_request _respond -> throwString "foo" + withApplication mkApp $ \port -> do + output <- readProcess "curl" ["-s", "localhost:" ++ show port] "" + output `shouldContain` "Something went wron" - describe "testWithApplication" $ do - it "propagates exceptions from the server to the executing thread" $ do - let mkApp = return $ \ _request _respond -> throwString "foo" - testWithApplication mkApp (\ port -> do - readProcess "curl" ["-s", "localhost:" ++ show port] "") - `shouldThrow` (\(StringException str _) -> str == "foo") + describe "testWithApplication" $ do + it "propagates exceptions from the server to the executing thread" $ do + let mkApp = return $ \_request _respond -> throwString "foo" + testWithApplication + mkApp + ( \port -> do + readProcess "curl" ["-s", "localhost:" ++ show port] "" + ) + `shouldThrow` (\(StringException str _) -> str == "foo")