Skip to content

Commit

Permalink
Add upper bound for delay when using backoff
Browse files Browse the repository at this point in the history
  • Loading branch information
rucek committed Nov 21, 2023
1 parent d61e292 commit 4c267d5
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 13 deletions.
9 changes: 7 additions & 2 deletions core/src/main/scala/ox/retry/retry.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,13 @@ object RetryPolicy:
case class Delay(maxRetries: Int, delay: FiniteDuration) extends RetryPolicy:
def nextDelay(attempt: Int): FiniteDuration = delay

case class Backoff(maxRetries: Int, initialDelay: FiniteDuration, jitter: Jitter = Jitter.None) extends RetryPolicy:
def nextDelay(attempt: Int): FiniteDuration = initialDelay * Math.pow(2, attempt).toLong // TODO jitter
case class Backoff(
maxRetries: Int,
initialDelay: FiniteDuration,
maxDelay: FiniteDuration = 1.day,
jitter: Jitter = Jitter.None
) extends RetryPolicy:
def nextDelay(attempt: Int): FiniteDuration = (initialDelay * Math.pow(2, attempt).toLong).min(maxDelay) // TODO jitter

def retry[T](f: => T)(policy: RetryPolicy): T =
retry(f, _ => true)(policy)
Expand Down
42 changes: 31 additions & 11 deletions core/src/test/scala/ox/retry/BackoffRetryTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,25 +16,45 @@ class BackoffRetryTest extends AnyFlatSpec with Matchers with EitherValues with
it should "retry a function" in {
// given
val maxRetries = 3
val sleep = 100.millis
val initialDelay = 100.millis
var counter = 0
def f =
counter += 1
if true then throw new RuntimeException("boom")

// when
val (result, elapsedTime) = measure(the[RuntimeException] thrownBy retry(f)(RetryPolicy.Backoff(maxRetries, sleep)))
val (result, elapsedTime) = measure(the[RuntimeException] thrownBy retry(f)(RetryPolicy.Backoff(maxRetries, initialDelay)))

// then
result should have message "boom"
elapsedTime.toMillis should be >= expectedTotalBackoffTimeMillis(maxRetries, sleep)
elapsedTime.toMillis should be >= expectedTotalBackoffTimeMillis(maxRetries, initialDelay)
counter shouldBe 4
}

it should "respect maximum delay" in {
// given
val maxRetries = 3
val initialDelay = 100.millis
val maxDelay = 200.millis
var counter = 0
def f =
counter += 1
if true then throw new RuntimeException("boom")

// when
val (result, elapsedTime) = measure(the[RuntimeException] thrownBy retry(f)(RetryPolicy.Backoff(maxRetries, initialDelay, maxDelay)))

// then
result should have message "boom"
elapsedTime.toMillis should be >= expectedTotalBackoffTimeMillis(maxRetries, initialDelay, maxDelay)
elapsedTime.toMillis should be < initialDelay.toMillis + maxRetries * maxDelay.toMillis
counter shouldBe 4
}

it should "retry an Either" in {
// given
val maxRetries = 3
val sleep = 100.millis
val initialDelay = 100.millis
var counter = 0
val errorMessage = "boom"

Expand All @@ -43,18 +63,18 @@ class BackoffRetryTest extends AnyFlatSpec with Matchers with EitherValues with
Left(errorMessage)

// when
val (result, elapsedTime) = measure(retry(f)(RetryPolicy.Backoff(maxRetries, sleep)))
val (result, elapsedTime) = measure(retry(f)(RetryPolicy.Backoff(maxRetries, initialDelay)))

// then
result.left.value shouldBe errorMessage
elapsedTime.toMillis should be >= expectedTotalBackoffTimeMillis(maxRetries, sleep)
elapsedTime.toMillis should be >= expectedTotalBackoffTimeMillis(maxRetries, initialDelay)
counter shouldBe 4
}

it should "retry a Try" in {
// given
val maxRetries = 3
val sleep = 100.millis
val initialDelay = 100.millis
var counter = 0
val errorMessage = "boom"

Expand All @@ -63,13 +83,13 @@ class BackoffRetryTest extends AnyFlatSpec with Matchers with EitherValues with
Failure(new RuntimeException(errorMessage))

// when
val (result, elapsedTime) = measure(retry(f)(RetryPolicy.Backoff(maxRetries, sleep)))
val (result, elapsedTime) = measure(retry(f)(RetryPolicy.Backoff(maxRetries, initialDelay)))

// then
result.failure.exception should have message errorMessage
elapsedTime.toMillis should be >= expectedTotalBackoffTimeMillis(maxRetries, sleep)
elapsedTime.toMillis should be >= expectedTotalBackoffTimeMillis(maxRetries, initialDelay)
counter shouldBe 4
}

private def expectedTotalBackoffTimeMillis(maxRetries: Int, sleep: FiniteDuration): Long =
(0 until maxRetries).map(sleep.toMillis * Math.pow(2, _).toLong).sum
private def expectedTotalBackoffTimeMillis(maxRetries: Int, initialDelay: FiniteDuration, maxDelay: FiniteDuration = 1.day): Long =
(0 until maxRetries).map(attempt => (initialDelay * Math.pow(2, attempt)).min(maxDelay).toMillis).sum

0 comments on commit 4c267d5

Please sign in to comment.