diff --git a/CHANGELOG.md b/CHANGELOG.md index 5e2484c..172298e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,12 +8,13 @@ - Replace `array` with `primitive` - Make calling `cancel` more than once on a recurring timer not enter an infinite loop - Slightly improve timer insert performance +- Fix minor race condition in `register` -## [0.4.0.1] - 2022-11-05 +## [0.4.0.1] - November 5, 2022 - Fix inaccurate haddock on `recurring` -## [0.4.0] - 2022-11-05 +## [0.4.0] - November 5, 2022 - Add `create` - Rename `Data.TimerWheel` to `TimerWheel` @@ -21,7 +22,7 @@ - Treat negative delays as 0 - Drop support for GHC < 8.6 -## [0.3.0] - 2020-06-18 +## [0.3.0] - June 18, 2020 - Add `with` - Add support for GHC 8.8, GHC 8.10 @@ -35,11 +36,11 @@ - Remove `InvalidTimerWheelConfig` exception. `error` is used instead - Remove support for GHC < 8.6 -## [0.2.0.1] - 2019-05-19 +## [0.2.0.1] - May 19, 2019 - Swap out `ghc-prim` and `primitive` for `vector` -## [0.2.0] - 2019-02-03 +## [0.2.0] - February 3, 2019 - Add `destroy` function, for reaping the background thread - Add `recurring_` function @@ -54,6 +55,6 @@ spin forever and peg a CPU - Rename `new` to `create` - Make recurring timers more accurate -## [0.1.0] - 2018-07-18 +## [0.1.0] - July 18, 2018 - Initial release diff --git a/src/TimerWheel.hs b/src/TimerWheel.hs index fd94d31..c095465 100644 --- a/src/TimerWheel.hs +++ b/src/TimerWheel.hs @@ -16,21 +16,24 @@ module TimerWheel register_, recurring, recurring_, + reregister, ) where +import Control.Exception (mask_) import qualified Data.Atomics as Atomics -import Data.Foldable (for_) +import qualified Data.Map.Strict as Map import Data.Primitive.Array (MutableArray) import qualified Data.Primitive.Array as Array import GHC.Base (RealWorld) import qualified Ki import TimerWheel.Internal.Counter (Counter, decrCounter_, incrCounter, incrCounter_, newCounter, readCounter) -import TimerWheel.Internal.Entries (Entries) -import qualified TimerWheel.Internal.Entries as Entries import TimerWheel.Internal.Micros (Micros (..)) import qualified TimerWheel.Internal.Micros as Micros import TimerWheel.Internal.Prelude +import TimerWheel.Internal.Timer (Timer (..)) +import TimerWheel.Internal.TimerBucket (TimerBucket) +import qualified TimerWheel.Internal.TimerBucket as TimerBucket import TimerWheel.Internal.Timestamp (Timestamp) import qualified TimerWheel.Internal.Timestamp as Timestamp @@ -90,7 +93,7 @@ import qualified TimerWheel.Internal.Timestamp as Timestamp -- ↑ -- @ data TimerWheel = TimerWheel - { buckets :: {-# UNPACK #-} !(MutableArray RealWorld Entries), + { buckets :: {-# UNPACK #-} !(MutableArray RealWorld TimerBucket), resolution :: {-# UNPACK #-} !Micros, numTimers :: {-# UNPACK #-} !Counter, -- A counter to generate unique ints that identify registered actions, so they can be canceled. @@ -100,9 +103,6 @@ data TimerWheel = TimerWheel totalMicros :: {-# UNPACK #-} !Micros } --- Internal type alias for readability -type TimerId = Int - -- | Timer wheel config. -- -- * @spokes@ must be ∈ @[1, maxBound]@, and is set to @1024@ if invalid. @@ -124,10 +124,10 @@ create :: -- | ​ IO TimerWheel create scope (Config spokes0 resolution0) = do - buckets <- Array.newArray spokes Entries.empty + buckets <- Array.newArray spokes Map.empty numTimers <- newCounter timerIdSupply <- newCounter - Ki.fork_ scope (runTimerReaperThread buckets numTimers resolution) + Ki.fork_ scope (runTimerReaperThread buckets resolution) pure TimerWheel {buckets, numTimers, resolution, timerIdSupply, totalMicros} where spokes = if spokes0 <= 0 then 1024 else spokes0 @@ -178,14 +178,31 @@ register_ wheel delay action = do pure () registerImpl :: TimerWheel -> Micros -> IO () -> IO (IO Bool) -registerImpl TimerWheel {buckets, numTimers, resolution, timerIdSupply, totalMicros} delay action = do +registerImpl wheel delay action = do now <- Timestamp.now - incrCounter_ numTimers - timerId <- incrCounter timerIdSupply - let index = timestampToIndex buckets resolution (now `Timestamp.plus` delay) - let c = unMicros (delay `Micros.div` totalMicros) - atomicInsertIntoBucket buckets index (Entries.insert timerId c action) - pure (atomicDeleteFromBucket buckets index timerId) + + let timestamp = now `Timestamp.plus` delay + let index = timestampToIndex wheel.buckets wheel.resolution timestamp + let action1 = action >> decrCounter_ wheel.numTimers + + timerId <- incrCounter wheel.timerIdSupply + + -- Mask so that we don't increment the timer count, then get hit by an async exception before we can put the timer + -- in the wheel + mask_ do + incrCounter_ wheel.numTimers + atomicModifyArray wheel.buckets index (TimerBucket.insert timestamp (Timer timerId action1)) + + pure do + let loop :: Atomics.Ticket TimerBucket -> IO Bool + loop ticket = + case TimerBucket.delete timestamp timerId (Atomics.peekTicket ticket) of + DidntDelete -> pure False + Deleted bucket -> do + (success, ticket1) <- Atomics.casArrayElem wheel.buckets index ticket bucket + if success then pure True else loop ticket1 + ticket0 <- Atomics.readArrayElem wheel.buckets index + loop ticket0 -- | @recurring wheel action delay@ registers an action __@action@__ in timer wheel __@wheel@__ to fire every -- __@delay@__ seconds (but no more often than __@wheel@__'s /resolution/). @@ -273,19 +290,19 @@ recurring_ wheel (Micros.fromSeconds -> delay) action = do -- act as if it's still "one bucket ago" at the moment we re-register -- it. reregister :: TimerWheel -> Micros -> IO () -> IO (IO Bool) -reregister wheel@TimerWheel {resolution} delay = +reregister wheel delay = registerImpl wheel - if resolution > delay + if wheel.resolution > delay then Micros 0 - else delay `Micros.minus` resolution + else delay `Micros.minus` wheel.resolution -- | Get the number of timers in a timer wheel. -- -- /O(1)/. count :: TimerWheel -> IO Int -count TimerWheel {numTimers} = - readCounter numTimers +count wheel = + readCounter wheel.numTimers -- `timestampToIndex buckets resolution timestamp` figures out which index `timestamp` corresponds to in `buckets`, -- where each bucket corresponds to `resolution` microseconds. @@ -305,7 +322,7 @@ count TimerWheel {numTimers} = -- 2. Wrap around per the actual length of the array: -- -- 105329801 `rem` 3 = 2 -timestampToIndex :: MutableArray RealWorld Entries -> Micros -> Timestamp -> Int +timestampToIndex :: MutableArray RealWorld bucket -> Micros -> Timestamp -> Int timestampToIndex buckets resolution timestamp = -- This downcast is safe because there are at most `maxBound :: Int` buckets (not that anyone would ever have that -- many...) @@ -313,63 +330,53 @@ timestampToIndex buckets resolution timestamp = (Timestamp.epoch resolution timestamp `rem` fromIntegral @Int @Word64 (Array.sizeofMutableArray buckets)) ------------------------------------------------------------------------------------------------------------------------ --- Atomic operations on buckets +-- Atomic operations on arrays -atomicInsertIntoBucket :: MutableArray RealWorld Entries -> Int -> (Entries -> Entries) -> IO () -atomicInsertIntoBucket buckets index doInsert = do - ticket0 <- Atomics.readArrayElem buckets index +atomicModifyArray :: forall a. MutableArray RealWorld a -> Int -> (a -> a) -> IO () +atomicModifyArray array index f = do + ticket0 <- Atomics.readArrayElem array index loop ticket0 where + loop :: Atomics.Ticket a -> IO () loop ticket = do - (success, ticket1) <- Atomics.casArrayElem buckets index ticket (doInsert (Atomics.peekTicket ticket)) + (success, ticket1) <- Atomics.casArrayElem array index ticket (f (Atomics.peekTicket ticket)) if success then pure () else loop ticket1 -atomicDeleteFromBucket :: MutableArray RealWorld Entries -> Int -> TimerId -> IO Bool -atomicDeleteFromBucket buckets index timerId = do - ticket0 <- Atomics.readArrayElem buckets index - loop ticket0 - where - loop ticket = - case Entries.delete timerId (Atomics.peekTicket ticket) of - Nothing -> pure False - Just entries1 -> do - (success, ticket1) <- Atomics.casArrayElem buckets index ticket entries1 - if success then pure True else loop ticket1 - -atomicExtractExpiredTimersFromBucket :: MutableArray RealWorld Entries -> Int -> IO [IO ()] -atomicExtractExpiredTimersFromBucket buckets index = do +atomicExtractExpiredTimersFromBucket :: MutableArray RealWorld TimerBucket -> Int -> Timestamp -> IO TimerBucket +atomicExtractExpiredTimersFromBucket buckets index now = do ticket0 <- Atomics.readArrayElem buckets index loop ticket0 where + loop :: Atomics.Ticket TimerBucket -> IO TimerBucket loop ticket - | Entries.null entries = pure [] + | Map.null bucket0 = pure bucket0 | otherwise = do - let (expired, entries1) = Entries.partition entries - (success, ticket1) <- Atomics.casArrayElem buckets index ticket entries1 + let Pair expired bucket1 = TimerBucket.partition now bucket0 + (success, ticket1) <- Atomics.casArrayElem buckets index ticket bucket1 if success then pure expired else loop ticket1 where - entries = Atomics.peekTicket ticket + bucket0 = Atomics.peekTicket ticket ------------------------------------------------------------------------------------------------------------------------ -- Timer reaper thread -runTimerReaperThread :: MutableArray RealWorld Entries -> Counter -> Micros -> IO void -runTimerReaperThread buckets numTimers resolution = do +runTimerReaperThread :: MutableArray RealWorld TimerBucket -> Micros -> IO void +runTimerReaperThread buckets resolution = do -- Sleep until the very first bucket of timers expires now <- Timestamp.now let remainingBucketMicros = resolution `Micros.minus` (now `Timestamp.rem` resolution) Micros.sleep remainingBucketMicros loop + (now `Timestamp.plus` remainingBucketMicros) (now `Timestamp.plus` remainingBucketMicros `Timestamp.plus` resolution) (timestampToIndex buckets resolution now) where - loop :: Timestamp -> Int -> IO void - loop !nextTime !index = do - expired <- atomicExtractExpiredTimersFromBucket buckets index - for_ expired \action -> do - action - decrCounter_ numTimers + -- FIXME read over this carefully and document + loop :: Timestamp -> Timestamp -> Int -> IO void + loop !thisTime !nextTime !index = do + expired <- atomicExtractExpiredTimersFromBucket buckets index thisTime + TimerBucket.fire expired now <- Timestamp.now - when (now < nextTime) (Micros.sleep (nextTime `Timestamp.minus` now)) - loop (nextTime `Timestamp.plus` resolution) ((index + 1) `rem` Array.sizeofMutableArray buckets) + Micros.sleep (nextTime `Timestamp.minus` now) -- it's ok if now > nextTime; that'll return immediately + loop nextTime (nextTime `Timestamp.plus` resolution) ((index + 1) `rem` Array.sizeofMutableArray buckets) diff --git a/src/TimerWheel/Internal/Entries.hs b/src/TimerWheel/Internal/Entries.hs deleted file mode 100644 index bcc667d..0000000 --- a/src/TimerWheel/Internal/Entries.hs +++ /dev/null @@ -1,63 +0,0 @@ -module TimerWheel.Internal.Entries - ( Entries, - empty, - TimerWheel.Internal.Entries.null, - size, - insert, - delete, - partition, - ) -where - -import Data.IntPSQ (IntPSQ) -import qualified Data.IntPSQ as IntPSQ -import TimerWheel.Internal.Prelude - -newtype Entries - = Entries (IntPSQ Word64 (IO ())) - --- | An empty collection. -empty :: Entries -empty = - Entries IntPSQ.empty -{-# INLINEABLE empty #-} - -null :: Entries -> Bool -null = - coerce (IntPSQ.null @Word64 @(IO ())) -{-# INLINEABLE null #-} - --- | The number of timers in the collection. --- --- /O(n)/. -size :: Entries -> Int -size = - coerce (IntPSQ.size @Word64 @(IO ())) -{-# INLINEABLE size #-} - --- | @insert i n m x@ inserts callback @m@ into collection @x@ with unique --- identifier @i@ and "count" @n@. The -insert :: Int -> Word64 -> IO () -> Entries -> Entries -insert i n m = - coerce (IntPSQ.unsafeInsertNew i n m) -{-# INLINEABLE insert #-} - --- | Delete a timer by id. Returns 'Nothing' if the timer was not found. -delete :: Int -> Entries -> Maybe Entries -delete = - coerce delete_ -{-# INLINEABLE delete #-} - -delete_ :: Int -> IntPSQ Word64 (IO ()) -> Maybe (IntPSQ Word64 (IO ())) -delete_ i xs = - (\(_, _, ys) -> ys) <$> IntPSQ.deleteView i xs - --- | Extract expired timers. -partition :: Entries -> ([IO ()], Entries) -partition (Entries entries) = - case IntPSQ.atMostView 0 entries of - (expired, alive) -> - ( map (\(_, _, m) -> m) expired, - Entries (IntPSQ.unsafeMapMonotonic (\_ n m -> (n - 1, m)) alive) - ) -{-# INLINEABLE partition #-} diff --git a/src/TimerWheel/Internal/Prelude.hs b/src/TimerWheel/Internal/Prelude.hs index 138e340..2f9874b 100644 --- a/src/TimerWheel/Internal/Prelude.hs +++ b/src/TimerWheel/Internal/Prelude.hs @@ -1,5 +1,7 @@ module TimerWheel.Internal.Prelude - ( Seconds, + ( DeleteResult (..), + Pair (..), + Seconds, module X, ) where @@ -8,9 +10,20 @@ import Control.Monad as X (when) import Data.Coerce as X (coerce) import Data.Fixed (E6, Fixed) import Data.IORef as X (newIORef, readIORef, writeIORef) +import Data.Map as X (Map) import Data.Word as X (Word64) import GHC.Generics as X (Generic) +-- | The result of attempting to delete something. +data DeleteResult a + = Deleted !a + | DidntDelete + deriving stock (Functor) + +-- | A strict pair. +data Pair a b + = Pair !a !b + -- | A number of seconds, with microsecond precision. -- -- You can use numeric literals to construct a value of this type, e.g. @0.5@. diff --git a/src/TimerWheel/Internal/Timer.hs b/src/TimerWheel/Internal/Timer.hs new file mode 100644 index 0000000..cf1bb96 --- /dev/null +++ b/src/TimerWheel/Internal/Timer.hs @@ -0,0 +1,14 @@ +module TimerWheel.Internal.Timer + ( TimerId, + Timer (..), + ) +where + +data Timer + = Timer + { id :: !TimerId, + action :: !(IO ()) + } + +type TimerId = + Int diff --git a/src/TimerWheel/Internal/TimerBag.hs b/src/TimerWheel/Internal/TimerBag.hs new file mode 100644 index 0000000..ff51c9c --- /dev/null +++ b/src/TimerWheel/Internal/TimerBag.hs @@ -0,0 +1,60 @@ +module TimerWheel.Internal.TimerBag + ( TimerBag, + singleton, + delete, + insert, + fire, + ) +where + +import TimerWheel.Internal.Prelude +import TimerWheel.Internal.Timer (Timer (..), TimerId) + +data TimerBag + = OneTimer {-# UNPACK #-} !Timer + | -- 2+ timers, stored in the reverse order that they were enqueued (so the last should fire first) + ManyTimers ![Timer] + +singleton :: Timer -> TimerBag +singleton = + OneTimer + +-- Make a timer bag from a non-empty list. +unsafeFromList :: [Timer] -> TimerBag +unsafeFromList = \case + [timer] -> OneTimer timer + timers -> ManyTimers timers + +delete :: TimerId -> TimerBag -> DeleteResult (Maybe TimerBag) +delete timerId = \case + OneTimer timer + | timerId == timer.id -> Deleted Nothing + | otherwise -> DidntDelete + ManyTimers timers0 -> Just . unsafeFromList <$> delete_ timerId timers0 + +delete_ :: TimerId -> [Timer] -> DeleteResult [Timer] +delete_ timerId = + go + where + go :: [Timer] -> DeleteResult [Timer] + go = \case + [] -> DidntDelete + timer : timers -> + if timerId == timer.id + then Deleted timers + else (timer :) <$> go timers + +insert :: Timer -> TimerBag -> TimerBag +insert new = \case + OneTimer old -> ManyTimers [new, old] + ManyTimers old -> ManyTimers (new : old) + +-- | Fire every timer in a bag in the order they were inserted. +fire :: TimerBag -> IO () +fire = \case + OneTimer timer -> timer.action + ManyTimers timers -> foldr f (pure ()) timers + where + f :: Timer -> IO () -> IO () + f timer acc = + acc >> timer.action diff --git a/src/TimerWheel/Internal/TimerBucket.hs b/src/TimerWheel/Internal/TimerBucket.hs new file mode 100644 index 0000000..da1e1f7 --- /dev/null +++ b/src/TimerWheel/Internal/TimerBucket.hs @@ -0,0 +1,55 @@ +module TimerWheel.Internal.TimerBucket + ( TimerBucket, + delete, + fire, + insert, + partition, + ) +where + +import Data.Foldable (traverse_) +import qualified Data.Map.Internal as Map.Internal (link) +import qualified Data.Map.Strict as Map +import qualified Data.Map.Strict.Internal as Map.Internal +import TimerWheel.Internal.Prelude +import TimerWheel.Internal.Timer (Timer (..), TimerId) +import TimerWheel.Internal.TimerBag (TimerBag) +import qualified TimerWheel.Internal.TimerBag as TimerBag +import TimerWheel.Internal.Timestamp (Timestamp) + +type TimerBucket = + Map Timestamp TimerBag + +-- | @delete timestamp timerId bucket@ deletes the timer with expiry time @timestamp@ and id @timerId@ in bucket +-- @bucket@. +delete :: Timestamp -> TimerId -> TimerBucket -> DeleteResult TimerBucket +delete timestamp timerId = + Map.alterF f timestamp + where + f :: Maybe TimerBag -> DeleteResult (Maybe TimerBag) + f = \case + Nothing -> DidntDelete + Just bag -> TimerBag.delete timerId bag + +insert :: Timestamp -> Timer -> TimerBucket -> TimerBucket +insert timestamp timer = + Map.alter (Just . maybe (TimerBag.singleton timer) (TimerBag.insert timer)) timestamp + +partition :: Timestamp -> TimerBucket -> Pair TimerBucket TimerBucket +partition now = go + where + go :: TimerBucket -> Pair TimerBucket TimerBucket + go = \case + Map.Internal.Tip -> Pair Map.Internal.Tip Map.Internal.Tip + Map.Internal.Bin _ expires timers bucketL bucketR + | expires <= now -> + let Pair bucketRL bucketRR = go bucketR + in Pair (Map.Internal.link expires timers bucketL bucketRL) bucketRR + | otherwise -> + let Pair bucketLL bucketRL = go bucketL + in Pair bucketLL (Map.Internal.link expires timers bucketRL bucketR) + +-- | Fire every timer in a bucket, in order +fire :: TimerBucket -> IO () +fire = + traverse_ TimerBag.fire diff --git a/timer-wheel.cabal b/timer-wheel.cabal index a452b6f..59b7c15 100644 --- a/timer-wheel.cabal +++ b/timer-wheel.cabal @@ -35,17 +35,20 @@ source-repository head common component build-depends: - base ^>= 4.12 || ^>= 4.13 || ^>= 4.14 || ^>= 4.15 || ^>= 4.16 || ^>= 4.17 || ^>= 4.18, + base ^>= 4.16 || ^>= 4.17 || ^>= 4.18 || ^>= 4.19 || ^>= 4.20, default-extensions: BangPatterns BlockArguments DeriveAnyClass + DeriveFunctor DeriveGeneric DerivingStrategies DuplicateRecordFields GeneralizedNewtypeDeriving LambdaCase NamedFieldPuns + NoFieldSelectors + OverloadedRecordDot RecursiveDo ScopedTypeVariables TupleSections @@ -65,22 +68,27 @@ common component if impl(ghc >= 9.2) ghc-options: -Wno-missing-kind-signatures + if impl(ghc >= 9.8) + ghc-options: + -Wno-missing-role-annotations library import: component build-depends: atomic-primops ^>= 0.8, + containers ^>= 0.6 || ^>= 0.7, ki ^>= 1.0.0, primitive ^>= 0.7 || ^>= 0.8, - psqueues ^>= 0.2.7, exposed-modules: TimerWheel hs-source-dirs: src other-modules: TimerWheel.Internal.Counter - TimerWheel.Internal.Entries TimerWheel.Internal.Micros TimerWheel.Internal.Prelude + TimerWheel.Internal.Timer + TimerWheel.Internal.TimerBag + TimerWheel.Internal.TimerBucket TimerWheel.Internal.Timestamp test-suite tests