Skip to content

Commit

Permalink
Add jitter
Browse files Browse the repository at this point in the history
  • Loading branch information
rucek committed Nov 22, 2023
1 parent 4c267d5 commit 5b430e7
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 9 deletions.
31 changes: 23 additions & 8 deletions core/src/main/scala/ox/retry/retry.scala
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package ox.retry

import scala.concurrent.duration.*
import scala.util.Try
import scala.util.{Random, Try}

sealed trait Jitter

Expand All @@ -13,22 +13,37 @@ object Jitter:

trait RetryPolicy:
def maxRetries: Int
def nextDelay(attempt: Int): FiniteDuration
def nextDelay(attempt: Int, lastDelay: Option[FiniteDuration]): FiniteDuration

object RetryPolicy:
case class Direct(maxRetries: Int) extends RetryPolicy:
def nextDelay(attempt: Int): FiniteDuration = Duration.Zero
def nextDelay(attempt: Int, lastDelay: Option[FiniteDuration]): FiniteDuration = Duration.Zero

case class Delay(maxRetries: Int, delay: FiniteDuration) extends RetryPolicy:
def nextDelay(attempt: Int): FiniteDuration = delay
def nextDelay(attempt: Int, lastDelay: Option[FiniteDuration]): FiniteDuration = delay

case class Backoff(
maxRetries: Int,
initialDelay: FiniteDuration,
maxDelay: FiniteDuration = 1.day,
jitter: Jitter = Jitter.None
) extends RetryPolicy:
def nextDelay(attempt: Int): FiniteDuration = (initialDelay * Math.pow(2, attempt).toLong).min(maxDelay) // TODO jitter
def nextDelay(attempt: Int, lastDelay: Option[FiniteDuration]): FiniteDuration =
def backoffDelay = Backoff.delay(attempt, initialDelay, maxDelay)

jitter match
case Jitter.None => backoffDelay
case Jitter.Full => Random.between(0, backoffDelay.toMillis).millis
case Jitter.Equal =>
val backoff = backoffDelay.toMillis
(backoff / 2 + Random.between(0, backoff / 2)).millis
case Jitter.Decorrelated =>
val last = lastDelay.getOrElse(initialDelay).toMillis
Random.between(initialDelay.toMillis, last * 3).millis

private[retry] object Backoff:
def delay(attempt: Int, initialDelay: FiniteDuration, maxDelay: FiniteDuration): FiniteDuration =
(initialDelay * Math.pow(2, attempt).toLong).min(maxDelay)

def retry[T](f: => T)(policy: RetryPolicy): T =
retry(f, _ => true)(policy)
Expand All @@ -42,12 +57,12 @@ def retry[E, T](f: => Either[E, T])(policy: RetryPolicy)(using dummy: DummyImpli
def retry[E, T](f: => Either[E, T], isSuccess: T => Boolean, isWorthRetrying: E => Boolean = (_: E) => true)(policy: RetryPolicy)(using
dummy: DummyImplicit
): Either[E, T] =
def loop(remainingAttempts: Int): Either[E, T] =
def loop(remainingAttempts: Int, lastDelay: Option[FiniteDuration] = None): Either[E, T] =
def nextAttemptOr(e: => Either[E, T]) =
if remainingAttempts > 0 then
val delay = policy.nextDelay(policy.maxRetries - remainingAttempts).toMillis
val delay = policy.nextDelay(policy.maxRetries - remainingAttempts, lastDelay).toMillis
if delay > 0 then Thread.sleep(delay)
loop(remainingAttempts - 1)
loop(remainingAttempts - 1, Some(delay.millis))
else e

f match
Expand Down
23 changes: 22 additions & 1 deletion core/src/test/scala/ox/retry/BackoffRetryTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,27 @@ class BackoffRetryTest extends AnyFlatSpec with Matchers with EitherValues with
counter shouldBe 4
}

it should "use jitter" in {
// given
val maxRetries = 3
val initialDelay = 100.millis
val maxDelay = 200.millis
var counter = 0
def f =
counter += 1
if true then throw new RuntimeException("boom")

// when
val (result, elapsedTime) =
measure(the[RuntimeException] thrownBy retry(f)(RetryPolicy.Backoff(maxRetries, initialDelay, maxDelay, Jitter.Equal)))

// then
result should have message "boom"
elapsedTime.toMillis should be >= expectedTotalBackoffTimeMillis(maxRetries, initialDelay, maxDelay) / 2
elapsedTime.toMillis should be < initialDelay.toMillis + maxRetries * maxDelay.toMillis
counter shouldBe 4
}

it should "retry an Either" in {
// given
val maxRetries = 3
Expand Down Expand Up @@ -92,4 +113,4 @@ class BackoffRetryTest extends AnyFlatSpec with Matchers with EitherValues with
}

private def expectedTotalBackoffTimeMillis(maxRetries: Int, initialDelay: FiniteDuration, maxDelay: FiniteDuration = 1.day): Long =
(0 until maxRetries).map(attempt => (initialDelay * Math.pow(2, attempt)).min(maxDelay).toMillis).sum
(0 until maxRetries).map(RetryPolicy.Backoff.delay(_, initialDelay, maxDelay).toMillis).sum
68 changes: 68 additions & 0 deletions core/src/test/scala/ox/retry/JitterTest.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package ox.retry

import org.scalatest.Inspectors
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers

import scala.concurrent.duration.*

class JitterTest extends AnyFlatSpec with Matchers {

behavior of "Jitter"

private val basePolicy = RetryPolicy.Backoff(maxRetries = 3, initialDelay = 100.millis)

it should "use no jitter" in {
// given
val policy = basePolicy

// when
val delays = (1 to 5).map(policy.nextDelay(_, None))

// then
delays should contain theSameElementsInOrderAs Seq(200, 400, 800, 1600, 3200).map(_.millis)
}

it should "use full jitter" in {
// given
val policy = basePolicy.copy(jitter = Jitter.Full)

// when
val delays = (1 to 5).map(policy.nextDelay(_, None))

// then
Inspectors.forEvery(delays.zipWithIndex) { case (delay, i) =>
val backoffDelay = RetryPolicy.Backoff.delay(i + 1, policy.initialDelay, policy.maxDelay)
delay should (be >= 0.millis and be <= backoffDelay)
}
}

it should "use equal jitter" in {
// given
val policy = basePolicy.copy(jitter = Jitter.Equal)

// when
val delays = (1 to 5).map(policy.nextDelay(_, None))

// then
Inspectors.forEvery(delays.zipWithIndex) { case (delay, i) =>
val backoffDelay = RetryPolicy.Backoff.delay(i + 1, policy.initialDelay, policy.maxDelay)
delay should (be >= backoffDelay / 2 and be <= backoffDelay)
}
}

it should "use decorrelated jitter" in {
// given
val policy = basePolicy.copy(jitter = Jitter.Decorrelated)

// when
val delays = (1 to 5).map(policy.nextDelay(_, None))

// then
Inspectors.forEvery(delays.sliding(2).map(_.toList).toList) {
case List(previousDelay, delay) =>
delay should (be >= policy.initialDelay and be <= previousDelay * 3)
case _ => succeed // so that the match is exhaustive
}
}
}

0 comments on commit 5b430e7

Please sign in to comment.