From de3bc4243305d76b050fe524e8e693e29a5149fa Mon Sep 17 00:00:00 2001 From: Michael <151919150+heswithme@users.noreply.github.com> Date: Mon, 14 Oct 2024 10:32:02 +0200 Subject: [PATCH] Stateful testing of TWA module (#11) * test: hypothesis structure for TWA module * test: hypothesis parallel testing * test: hypothesis parallel testing * dev: TWA contract min_dt enforced; test: rm parallel hypothesis, added TWA invariant --------- Co-authored-by: Alberto --- contracts/TWA.vy | 16 +++- tests/__init__.py | 0 tests/hypothesis/__init__.py | 0 tests/hypothesis/conftest.py | 4 + tests/hypothesis/twa/__init__.py | 0 tests/hypothesis/twa/stateful_base.py | 114 ++++++++++++++++++++++ tests/hypothesis/twa/test_twa.py | 133 ++++++++++++++++++++++++++ 7 files changed, 263 insertions(+), 4 deletions(-) create mode 100644 tests/__init__.py create mode 100644 tests/hypothesis/__init__.py create mode 100644 tests/hypothesis/conftest.py create mode 100644 tests/hypothesis/twa/__init__.py create mode 100644 tests/hypothesis/twa/stateful_base.py create mode 100644 tests/hypothesis/twa/test_twa.py diff --git a/contracts/TWA.vy b/contracts/TWA.vy index e5e90d9..497cc51 100644 --- a/contracts/TWA.vy +++ b/contracts/TWA.vy @@ -70,7 +70,7 @@ struct Snapshot: @deploy def __init__(_twa_window: uint256, _min_snapshot_dt_seconds: uint256): self._set_twa_window(_twa_window) - self._set_snapshot_dt(_min_snapshot_dt_seconds) + self._set_snapshot_dt(max(1, _min_snapshot_dt_seconds)) ################################################################ @@ -158,15 +158,20 @@ def _compute() -> uint256: i_backwards: uint256 = index_array_end - i current_snapshot: Snapshot = self.snapshots[i_backwards] next_snapshot: Snapshot = current_snapshot - if i != 0: # If not the first iteration, get the next snapshot + if i != 0: # If not the first iteration (last snapshot), get the next snapshot next_snapshot = self.snapshots[i_backwards + 1] + # Time Axis (Increasing to the Right) ---> + # SNAPSHOT + # |---------|---------|---------|------------------------|---------|---------| + # t0 time_window_start interval_start interval_end block.timestamp (Now) + interval_start: uint256 = current_snapshot.timestamp # Adjust interval start if it is before the time window start if interval_start < time_window_start: interval_start = time_window_start - interval_end: uint256 = 0 + interval_end: uint256 = interval_start if i == 0: # First iteration - we are on the last snapshot (i_backwards = num_snapshots - 1) # For the last snapshot, interval end is block.timestamp interval_end = block.timestamp @@ -186,7 +191,10 @@ def _compute() -> uint256: total_weighted_tracked_value += averaged_tracked_value * time_delta total_time += time_delta + if total_time == 0 and len(self.snapshots) == 1: + # case when only snapshot is taken in the block where computation is called + return self.snapshots[0].tracked_value + assert total_time > 0, "Zero total time!" twa: uint256 = total_weighted_tracked_value // total_time - return twa diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/hypothesis/__init__.py b/tests/hypothesis/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/hypothesis/conftest.py b/tests/hypothesis/conftest.py new file mode 100644 index 0000000..7fcd91f --- /dev/null +++ b/tests/hypothesis/conftest.py @@ -0,0 +1,4 @@ +from hypothesis import Phase, Verbosity, settings + +settings.register_profile("debug", settings(verbosity=Verbosity.verbose, phases=list(Phase)[:4])) +settings.load_profile("debug") diff --git a/tests/hypothesis/twa/__init__.py b/tests/hypothesis/twa/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/hypothesis/twa/stateful_base.py b/tests/hypothesis/twa/stateful_base.py new file mode 100644 index 0000000..28f92c9 --- /dev/null +++ b/tests/hypothesis/twa/stateful_base.py @@ -0,0 +1,114 @@ +import boa +from hypothesis import note +from hypothesis import strategies as st +from hypothesis.stateful import RuleBasedStateMachine, initialize # , invariant, rule + + +class TWAStatefulBase(RuleBasedStateMachine): + twa_deployer = boa.load_partial("contracts/TWA.vy") + + def __init__(self): + super().__init__() + note("INIT") + self.twa_contract = None + self.twa_window = None + self.min_snapshot_dt_seconds = None + self.snapshots = [] + self.last_snapshot_timestamp = 0 + + @initialize( + twa_window=st.integers(min_value=1, max_value=86400 * 7), # 1 second to 1 week + min_snapshot_dt_seconds=st.integers(min_value=1, max_value=86400), # 1 second to 1 day + ) + def setup(self, twa_window, min_snapshot_dt_seconds): + """Initialize the TWA contract and set up initial parameters.""" + note("SETUP") + self.twa_contract = TWAStatefulBase.twa_deployer(twa_window, min_snapshot_dt_seconds) + + self.twa_window = twa_window + self.min_snapshot_dt_seconds = min_snapshot_dt_seconds + self.snapshots = [] + self.last_snapshot_timestamp = 0 + + def python_take_snapshot(self, value): + """ + Python model of the contract's `_take_snapshot` function. + Mirrors the contract logic and updates the internal state. + """ + # Contract logic: only take a snapshot if the time condition is met + block_timestamp = boa.env.evm.patch.timestamp + if self.last_snapshot_timestamp + self.min_snapshot_dt_seconds <= block_timestamp: + self.last_snapshot_timestamp = block_timestamp + self.snapshots.append({"tracked_value": value, "timestamp": block_timestamp}) + note( + f"python_take_snapshot: Python snapshot added: value={value}, timestamp={block_timestamp}" # noqa: E501 + ) + else: + note("python_take_snapshot: Python snapshot skipped (time condition not met)") + + def python_compute_twa(self): + """ + Python version of the contract's _compute function. + Computes the TWA (Time-Weighted Average) based on the snapshots in self.snapshots. + """ + block_timestamp = boa.env.evm.patch.timestamp + + num_snapshots = len(self.snapshots) + if num_snapshots == 0: + note("python_compute_twa: No snapshots, no TWA") + return 0 + + time_window_start = block_timestamp - self.twa_window + + total_weighted_tracked_value = 0 + total_time = 0 + + # Iterate backwards over all snapshots + index_array_end = num_snapshots - 1 + for i in range(0, num_snapshots): + i_backwards = index_array_end - i + current_snapshot = self.snapshots[i_backwards] + next_snapshot = current_snapshot + + if i != 0: # If not the first iteration, get the next snapshot + next_snapshot = self.snapshots[i_backwards + 1] + + interval_start = current_snapshot["timestamp"] + + # Adjust interval start if it is before the time window start + if interval_start < time_window_start: + interval_start = time_window_start + + if i == 0: + # For the last snapshot, interval end is the block_timestamp + interval_end = block_timestamp + else: + # For other snapshots, interval end is the timestamp of the next snapshot + interval_end = next_snapshot["timestamp"] + + if interval_end <= time_window_start: + break + + time_delta = interval_end - interval_start + + # Interpolation using the trapezoidal rule + averaged_tracked_value = ( + current_snapshot["tracked_value"] + next_snapshot["tracked_value"] + ) // 2 + + # Accumulate weighted rate and time + total_weighted_tracked_value += averaged_tracked_value * time_delta + total_time += time_delta + + if total_time == 0 and len(self.snapshots) == 1: + # case when only snapshot is taken in the block where computation is called + return self.snapshots[0]["tracked_value"] + + # Ensure there is non-zero time for division + if total_time == 0: + raise ValueError("TWA: Zero total time!") + + # Calculate TWA + twa = total_weighted_tracked_value // total_time + note(f"python_compute_twa: Computed TWA: {twa}") + return twa diff --git a/tests/hypothesis/twa/test_twa.py b/tests/hypothesis/twa/test_twa.py new file mode 100644 index 0000000..4b8b6f5 --- /dev/null +++ b/tests/hypothesis/twa/test_twa.py @@ -0,0 +1,133 @@ +import boa +from hypothesis import HealthCheck, Verbosity, settings +from hypothesis import strategies as st +from hypothesis.stateful import invariant, rule + +from tests.hypothesis.twa.stateful_base import TWAStatefulBase + + +def test_state_machine(): + # Explicitly run the state machine + TestTWAStateful = TWAStateful.TestCase() + TestTWAStateful.run() + + +@settings( + max_examples=10, + stateful_step_count=1000, + suppress_health_check=[ + HealthCheck.large_base_example + ], # skips issue when trying to add 1000 examples with 0 dt + verbosity=Verbosity.verbose, +) +class TWAStateful(TWAStatefulBase): + @invariant() + def check_initialization(self): + assert self.twa_window > 0, "TWA window must be set" + assert self.min_snapshot_dt_seconds > 0, "Minimum snapshot interval must be set" + + @invariant() + def check_crude_twa_invariant(self): + """ + Crude invariant to ensure that the computed TWA is reasonable. + It checks that the TWA is non-negative and is between the minimum and maximum + values of the snapshots within the TWA window. + """ + # Get current block timestamp + current_time = boa.env.evm.patch.timestamp + + # Calculate the time window start + time_window_start = current_time - self.twa_window + + # Collect snapshots within the TWA window + snapshots_in_window = [ + snapshot for snapshot in self.snapshots if snapshot["timestamp"] >= time_window_start + ] + + # Also consider the last snapshot just outside TWA window (needed for trapezoidal rule) + previous_snapshot = None + for snapshot in self.snapshots: + if snapshot["timestamp"] < time_window_start: + previous_snapshot = snapshot + else: + break # We passed the start of the window + + # If a previous snapshot exists, we add it to the window (on the boundary) + # not changing timestamp as we only assert values here + if previous_snapshot: + snapshots_in_window.append(previous_snapshot) + + # If there are still no snapshots (even outside the window), TWA should be zero + if not snapshots_in_window: + contract_twa = self.twa_contract.compute_twa() + python_twa = self.python_compute_twa() + + # Assert both TWAs are zero + assert contract_twa == 0, f"Contract TWA should be zero but is {contract_twa}" + assert python_twa == 0, f"Python TWA should be zero but is {python_twa}" + return + + # Extract tracked values from snapshots in the window + tracked_values = [snapshot["tracked_value"] for snapshot in snapshots_in_window] + + # Compute the min and max values of the tracked values + min_value = min(tracked_values) + max_value = max(tracked_values) + # Compute the TWA from the contract and Python model + contract_twa = self.twa_contract.compute_twa() + python_twa = self.python_compute_twa() + + # Ensure that the TWA is non-negative + assert contract_twa >= 0, f"Contract TWA is negative: {contract_twa}" + assert python_twa >= 0, f"Python TWA is negative: {python_twa}" + + # Ensure that the TWA is between the min and max values of the snapshots + assert ( + min_value <= contract_twa <= max_value + ), f"Contract TWA {contract_twa} is not between min {min_value} and max {max_value}" + assert ( + min_value <= python_twa <= max_value + ), f"Python TWA {python_twa} is not between min {min_value} and max {max_value}" + + @rule( + value=st.integers(min_value=0, max_value=100_000_000 * 10**18), # 0 to 100 million crvUSD + timestamp_delta=st.integers( + min_value=0, max_value=10 * 86400 + ), # 0s to 10 days between snapshots + ) + def take_snapshot_rule(self, value, timestamp_delta): + """ + Rule to test taking snapshots in both the Python model and the contract. + """ + boa.env.time_travel(seconds=timestamp_delta) + # Call snapshot-taking functions in both the Python model and the contract + self.twa_contract.eval(f"self._take_snapshot({value})") + self.python_take_snapshot(value) + + # Assert equal numbe of the snapshots + contract_snapshot_len = self.twa_contract.get_len_snapshots() + python_snapshot_len = len(self.snapshots) + + assert contract_snapshot_len == python_snapshot_len, ( + "Mismatch in snapshot length: " + + f"contract={contract_snapshot_len}, python={python_snapshot_len}" + ) + + @rule( + timestamp_delta=st.integers( + min_value=0, max_value=10 * 86400 + ), # 0s to 10days between compute calls + ) + def compute_twa_rule(self, timestamp_delta): + boa.env.time_travel(seconds=timestamp_delta) + # TWA computation for contract/python model + contract_twa = self.twa_contract.compute_twa() + python_twa = self.python_compute_twa() + + # Assert that both values are the same + assert ( + contract_twa == python_twa + ), f"Mismatch in TWA: contract={contract_twa}, python={python_twa}" + + +# TestTWAStateful = TWAStateful.TestCase