From fc8da62625a80ca7de366ce4957f75d101ee5a1b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Walter?= Date: Wed, 16 Oct 2024 18:06:55 +0200 Subject: [PATCH] Implement and diff test validate_header --- cairo/programs/fork.cairo | 203 +++++++++++++++++++++++++++ cairo/tests/fixtures/data.py | 30 ++++ cairo/tests/programs/test_fork.cairo | 48 +++++++ cairo/tests/programs/test_fork.py | 99 +++++++++++++ cairo/tests/utils/models.py | 15 +- cairo/tests/utils/parsers.py | 7 +- 6 files changed, 398 insertions(+), 4 deletions(-) create mode 100644 cairo/programs/fork.cairo create mode 100644 cairo/tests/programs/test_fork.cairo create mode 100644 cairo/tests/programs/test_fork.py diff --git a/cairo/programs/fork.cairo b/cairo/programs/fork.cairo new file mode 100644 index 0000000..1923d0f --- /dev/null +++ b/cairo/programs/fork.cairo @@ -0,0 +1,203 @@ +// See https://github.com/ethereum/execution-specs/blob/master/src/ethereum/cancun/fork.py + +from starkware.cairo.common.uint256 import Uint256 +from starkware.cairo.common.math import unsigned_div_rem, split_felt +from starkware.cairo.common.math_cmp import is_nn +from starkware.cairo.common.bool import FALSE + +from src.model import model + +using Uint128 = felt; + +const ELASTICITY_MULTIPLIER = 2; +const GAS_LIMIT_ADJUSTMENT_FACTOR = 1024; +const GAS_LIMIT_MINIMUM = 5000; +const BASE_FEE_MAX_CHANGE_DENOMINATOR = 8; +const EMPTY_OMMER_HASH_LOW = 0xd312451b948a7413f0a142fd40d49347; +const EMPTY_OMMER_HASH_HIGH = 0x1dcc4de8dec75d7aab85b567b6ccd41a; + +// @notice See https://github.com/ethereum/execution-specs/blob/master/src/ethereum/cancun/fork.py#L1118-L1154 +// @dev We use the Uint128 alias to strenghten the fact that these felts should have been range_checked before +func check_gas_limit{range_check_ptr}(gas_limit: Uint128, parent_gas_limit: Uint128) { + let (max_adjustment_delta, _) = unsigned_div_rem(parent_gas_limit, GAS_LIMIT_ADJUSTMENT_FACTOR); + + with_attr error_message("InvalidBlock") { + assert [range_check_ptr] = parent_gas_limit + max_adjustment_delta - gas_limit - 1; + assert [range_check_ptr + 1] = gas_limit - (parent_gas_limit - max_adjustment_delta) - 1; + assert [range_check_ptr + 2] = gas_limit - GAS_LIMIT_MINIMUM; + let range_check_ptr = range_check_ptr + 3; + } + + return (); +} + +// @notice See https://github.com/ethereum/execution-specs/blob/master/src/ethereum/cancun/fork.py#L226-L285 +// @dev We use the Uint128 alias to strenghten the fact that these felts should have been range_checked before +func calculate_base_fee_per_gas{range_check_ptr}( + block_gas_limit: Uint128, + parent_gas_limit: Uint128, + parent_gas_used: Uint128, + parent_base_fee_per_gas: Uint128, +) -> Uint128 { + let (parent_gas_target, _) = unsigned_div_rem(parent_gas_limit, ELASTICITY_MULTIPLIER); + + check_gas_limit(block_gas_limit, parent_gas_limit); + + if (parent_gas_used == parent_gas_target) { + return parent_base_fee_per_gas; + } + + let is_parent_gas_used_greater_than_parent_gas_target = is_nn( + parent_gas_used - parent_gas_target - 1 + ); + if (is_parent_gas_used_greater_than_parent_gas_target != FALSE) { + let gas_used_delta = parent_gas_used - parent_gas_target; + let parent_fee_gas_delta = parent_base_fee_per_gas * gas_used_delta; + let (target_fee_gas_delta, _) = unsigned_div_rem(parent_fee_gas_delta, parent_gas_target); + let (base_fee_per_gas_delta, _) = unsigned_div_rem( + target_fee_gas_delta, BASE_FEE_MAX_CHANGE_DENOMINATOR + ); + if (base_fee_per_gas_delta == 0) { + return 1; + } + return base_fee_per_gas_delta; + } + + let gas_used_delta = parent_gas_target - parent_gas_used; + let parent_fee_gas_delta = parent_base_fee_per_gas * gas_used_delta; + let (target_fee_gas_delta, _) = unsigned_div_rem(parent_fee_gas_delta, parent_gas_target); + let (base_fee_per_gas_delta, _) = unsigned_div_rem( + target_fee_gas_delta, BASE_FEE_MAX_CHANGE_DENOMINATOR + ); + + return parent_base_fee_per_gas - base_fee_per_gas_delta; +} + +// @notice See https://github.com/ethereum/execution-specs/blob/master/src/ethereum/cancun/fork.py#L288-L332 +// @dev Initial range checks for all values because header is filled with a hint +func validate_header{range_check_ptr}(header: model.BlockHeader, parent_header: model.BlockHeader) { + // parent_hash + assert [range_check_ptr] = header.parent_hash.low; + let range_check_ptr = range_check_ptr + 1; + assert [range_check_ptr] = header.parent_hash.high; + let range_check_ptr = range_check_ptr + 1; + // ommers_hash + assert [range_check_ptr] = header.ommers_hash.low; + let range_check_ptr = range_check_ptr + 1; + assert [range_check_ptr] = header.ommers_hash.high; + let range_check_ptr = range_check_ptr + 1; + // coinbase + let (coinbase_high, coinbase_low) = split_felt(header.coinbase); + assert [range_check_ptr] = coinbase_low; + let range_check_ptr = range_check_ptr + 1; + assert [range_check_ptr] = coinbase_high; + let range_check_ptr = range_check_ptr + 1; + assert [range_check_ptr] = 2 ** 32 - coinbase_high - 1; + let range_check_ptr = range_check_ptr + 1; + // state_root + assert [range_check_ptr] = header.state_root.low; + let range_check_ptr = range_check_ptr + 1; + assert [range_check_ptr] = header.state_root.high; + let range_check_ptr = range_check_ptr + 1; + // transactions_root + assert [range_check_ptr] = header.transactions_root.low; + let range_check_ptr = range_check_ptr + 1; + assert [range_check_ptr] = header.transactions_root.high; + let range_check_ptr = range_check_ptr + 1; + // receipt_root + assert [range_check_ptr] = header.receipt_root.low; + let range_check_ptr = range_check_ptr + 1; + assert [range_check_ptr] = header.receipt_root.high; + let range_check_ptr = range_check_ptr + 1; + // withdrawals_root + assert header.withdrawals_root.is_some * (1 - header.withdrawals_root.is_some) = 0; + let withdrawals_root = cast(header.withdrawals_root.value, Uint256*); + assert [range_check_ptr] = withdrawals_root.low; + let range_check_ptr = range_check_ptr + 1; + assert [range_check_ptr] = withdrawals_root.high; + let range_check_ptr = range_check_ptr + 1; + // difficulty + assert [range_check_ptr] = header.difficulty.low; + let range_check_ptr = range_check_ptr + 1; + assert [range_check_ptr] = header.difficulty.high; + let range_check_ptr = range_check_ptr + 1; + // number + assert [range_check_ptr] = header.number; + let range_check_ptr = range_check_ptr + 1; + // gas_limit + assert [range_check_ptr] = header.gas_limit; + let range_check_ptr = range_check_ptr + 1; + // gas_used + assert [range_check_ptr] = header.gas_used; + let range_check_ptr = range_check_ptr + 1; + // timestamp + assert [range_check_ptr] = header.timestamp; + let range_check_ptr = range_check_ptr + 1; + // mix_hash + assert [range_check_ptr] = header.mix_hash.low; + let range_check_ptr = range_check_ptr + 1; + assert [range_check_ptr] = header.mix_hash.high; + let range_check_ptr = range_check_ptr + 1; + // nonce + assert [range_check_ptr] = header.nonce; + let range_check_ptr = range_check_ptr + 1; + // base_fee_per_gas + assert header.base_fee_per_gas.is_some * (1 - header.base_fee_per_gas.is_some) = 0; + assert [range_check_ptr] = header.base_fee_per_gas.value; + let range_check_ptr = range_check_ptr + 1; + // blob_gas_used + assert header.blob_gas_used.is_some * (1 - header.blob_gas_used.is_some) = 0; + assert [range_check_ptr] = header.blob_gas_used.value; + let range_check_ptr = range_check_ptr + 1; + // excess_blob_gas + assert header.excess_blob_gas.is_some * (1 - header.excess_blob_gas.is_some) = 0; + assert [range_check_ptr] = header.excess_blob_gas.value; + let range_check_ptr = range_check_ptr + 1; + // parent_beacon_block_root + assert header.parent_beacon_block_root.is_some * ( + 1 - header.parent_beacon_block_root.is_some + ) = 0; + let parent_beacon_block_root = cast(header.parent_beacon_block_root.value, Uint256*); + assert [range_check_ptr] = parent_beacon_block_root.low; + let range_check_ptr = range_check_ptr + 1; + assert [range_check_ptr] = parent_beacon_block_root.high; + let range_check_ptr = range_check_ptr + 1; + // requests_root + assert header.requests_root.is_some * (1 - header.requests_root.is_some) = 0; + let requests_root = cast(header.requests_root.value, Uint256*); + assert [range_check_ptr] = requests_root.low; + let range_check_ptr = range_check_ptr + 1; + assert [range_check_ptr] = requests_root.high; + let range_check_ptr = range_check_ptr + 1; + // extra_data_len + assert [range_check_ptr] = header.extra_data_len; + let range_check_ptr = range_check_ptr + 1; + + with_attr error_message("InvalidBlock") { + assert [range_check_ptr] = header.gas_limit - header.gas_used; + let range_check_ptr = range_check_ptr + 1; + let expected_base_fee_per_gas = calculate_base_fee_per_gas( + header.gas_limit, + parent_header.gas_limit, + parent_header.gas_used, + parent_header.base_fee_per_gas.value, + ); + + assert expected_base_fee_per_gas = header.base_fee_per_gas.value; + assert [range_check_ptr] = header.timestamp - parent_header.timestamp - 1; + assert [range_check_ptr + 1] = header.number - parent_header.number - 1; + assert [range_check_ptr + 2] = 32 - header.extra_data_len; + let range_check_ptr = range_check_ptr + 3; + assert header.difficulty.low = 0; + assert header.difficulty.high = 0; + assert header.nonce = 0; + assert header.ommers_hash.low = EMPTY_OMMER_HASH_LOW; + assert header.ommers_hash.high = EMPTY_OMMER_HASH_HIGH; + } + + // TODO: Implement block header hash check + // block_parent_hash = keccak256(rlp.encode(parent_header)) + // if header.parent_hash != block_parent_hash: + // raise InvalidBlock + return (); +} diff --git a/cairo/tests/fixtures/data.py b/cairo/tests/fixtures/data.py index 095d9d5..ee20e80 100644 --- a/cairo/tests/fixtures/data.py +++ b/cairo/tests/fixtures/data.py @@ -1,7 +1,37 @@ import pytest +from hypothesis import strategies as st from tests.utils.models import Account, Block, State +block_header_strategy = st.fixed_dictionaries( + { + "parent_hash": st.binary(min_size=32, max_size=32), + "ommers_hash": st.just( + bytes.fromhex( + "1dcc4de8dec75d7aab85b567b6ccd41ad312451b948a7413f0a142fd40d49347" + ) + ), + "coinbase": st.binary(min_size=20, max_size=20), + "state_root": st.binary(min_size=32, max_size=32), + "transactions_root": st.binary(min_size=32, max_size=32), + "receipt_root": st.binary(min_size=32, max_size=32), + "bloom": st.binary(min_size=256, max_size=256), + "difficulty": st.just(0x00), + "number": st.integers(min_value=0, max_value=2**64 - 1), + "gas_limit": st.integers(min_value=0, max_value=2**64 - 1), + "gas_used": st.integers(min_value=0, max_value=2**64 - 1), + "timestamp": st.integers(min_value=0, max_value=2**64 - 1), + "extra_data": st.binary(max_size=32), + "prev_randao": st.binary(min_size=32, max_size=32), + "nonce": st.just("0x0000000000000000"), + "base_fee_per_gas": st.integers(min_value=0, max_value=2**128 - 1), + "withdrawals_root": st.binary(min_size=32, max_size=32), + "blob_gas_used": st.integers(min_value=0, max_value=2**64 - 1), + "excess_blob_gas": st.integers(min_value=0, max_value=2**64 - 1), + "parent_beacon_block_root": st.binary(min_size=32, max_size=32), + } +) + @pytest.fixture def block(): diff --git a/cairo/tests/programs/test_fork.cairo b/cairo/tests/programs/test_fork.cairo new file mode 100644 index 0000000..9b8145f --- /dev/null +++ b/cairo/tests/programs/test_fork.cairo @@ -0,0 +1,48 @@ +from programs.fork import check_gas_limit, calculate_base_fee_per_gas, validate_header, Uint128 +from src.model import model + +func test_check_gas_limit{range_check_ptr}() { + tempvar gas_limit: Uint128; + tempvar parent_gas_limit: Uint128; + %{ + ids.gas_limit = program_input["gas_limit"] + ids.parent_gas_limit = program_input["parent_gas_limit"] + %} + check_gas_limit(gas_limit, parent_gas_limit); + + return (); +} + +func test_calculate_base_fee_per_gas{range_check_ptr}() -> Uint128 { + tempvar block_gas_limit: Uint128; + tempvar parent_gas_limit: Uint128; + tempvar parent_gas_used: Uint128; + tempvar parent_base_fee_per_gas: Uint128; + %{ + ids.block_gas_limit = program_input["block_gas_limit"] + ids.parent_gas_limit = program_input["parent_gas_limit"] + ids.parent_gas_used = program_input["parent_gas_used"] + ids.parent_base_fee_per_gas = program_input["parent_base_fee_per_gas"] + %} + return calculate_base_fee_per_gas( + block_gas_limit, parent_gas_limit, parent_gas_used, parent_base_fee_per_gas + ); +} + +func test_validate_header{range_check_ptr}() { + alloc_locals; + local header: model.BlockHeader*; + local parent_header: model.BlockHeader*; + %{ + if '__dict_manager' not in globals(): + from starkware.cairo.common.dict import DictManager + __dict_manager = DictManager() + + from tests.utils.hints import gen_arg + + ids.header = gen_arg(__dict_manager, segments, program_input["header"]) + ids.parent_header = gen_arg(__dict_manager, segments, program_input["parent_header"]) + %} + validate_header([header], [parent_header]); + return (); +} diff --git a/cairo/tests/programs/test_fork.py b/cairo/tests/programs/test_fork.py new file mode 100644 index 0000000..421f1c1 --- /dev/null +++ b/cairo/tests/programs/test_fork.py @@ -0,0 +1,99 @@ +from ethereum.cancun.blocks import Header +from ethereum.cancun.fork import ( + calculate_base_fee_per_gas, + check_gas_limit, + validate_header, +) +from ethereum.exceptions import InvalidBlock +from hypothesis import given +from hypothesis.strategies import integers + +from tests.fixtures.data import block_header_strategy +from tests.utils.errors import cairo_error +from tests.utils.models import BlockHeader + + +class TestFork: + @given( + integers(min_value=0, max_value=2**128 - 1), + integers(min_value=0, max_value=2**128 - 1), + ) + def test_check_gas_limit(self, cairo_run, gas_limit, parent_gas_limit): + expected = check_gas_limit(gas_limit, parent_gas_limit) + if not expected: + with cairo_error("InvalidBlock"): + cairo_run( + "test_check_gas_limit", + gas_limit=gas_limit, + parent_gas_limit=parent_gas_limit, + ) + else: + cairo_run( + "test_check_gas_limit", + gas_limit=gas_limit, + parent_gas_limit=parent_gas_limit, + ) + + @given( + integers(min_value=0, max_value=2**128 - 1), + integers(min_value=0, max_value=2**128 - 1), + integers(min_value=0, max_value=2**128 - 1), + integers(min_value=0, max_value=2**128 - 1), + ) + def test_calculate_base_fee_per_gas( + self, + cairo_run, + block_gas_limit, + parent_gas_limit, + parent_gas_used, + parent_base_fee_per_gas, + ): + try: + expected = calculate_base_fee_per_gas( + block_gas_limit, + parent_gas_limit, + parent_gas_used, + parent_base_fee_per_gas, + ) + except InvalidBlock: + expected = None + + if expected is not None: + assert expected == cairo_run( + "test_calculate_base_fee_per_gas", + block_gas_limit=block_gas_limit, + parent_gas_limit=parent_gas_limit, + parent_gas_used=parent_gas_used, + parent_base_fee_per_gas=parent_base_fee_per_gas, + ) + else: + with cairo_error("InvalidBlock"): + cairo_run( + "test_calculate_base_fee_per_gas", + block_gas_limit=block_gas_limit, + parent_gas_limit=parent_gas_limit, + parent_gas_used=parent_gas_used, + parent_base_fee_per_gas=parent_base_fee_per_gas, + ) + + @given(header=block_header_strategy, parent_header=block_header_strategy) + def test_validate_header(self, cairo_run, header, parent_header): + error = None + try: + validate_header(Header(**header), Header(**parent_header)) + except InvalidBlock as e: + error = e + + if error is not None: + with cairo_error("InvalidBlock"): + cairo_run( + "test_validate_header", + header=BlockHeader.model_validate(header), + parent_header=BlockHeader.model_validate(parent_header), + ) + else: + cairo_run( + "test_validate_header", + header=BlockHeader.model_validate(header), + parent_header=BlockHeader.model_validate(parent_header), + ) diff --git a/cairo/tests/utils/models.py b/cairo/tests/utils/models.py index c513b13..b4a7c02 100644 --- a/cairo/tests/utils/models.py +++ b/cairo/tests/utils/models.py @@ -68,6 +68,7 @@ def split_uint256(cls, values): "withdrawals_root", "difficulty", "mix_hash", + "prev_randao", ]: if key not in values: key = to_camel(key) @@ -125,7 +126,7 @@ def parse_bloom(cls, v): raise ValueError("Bloom cannot be empty") if len(bloom) != 256: raise ValueError("Bloom must be 256 bytes") - return tuple(int(chunk) for chunk in wrap(bloom.hex(), 32)) + return tuple(int(chunk, 16) for chunk in wrap(bloom.hex(), 32)) parent_hash_low: int parent_hash_high: int @@ -179,8 +180,16 @@ def parse_bloom(cls, v): gas_limit: int gas_used: int timestamp: int - mix_hash_low: int - mix_hash_high: int + mix_hash_low: int = Field( + validation_alias=AliasChoices( + "mix_hash", "mixHash", "prev_randao", "prevRandao" + ) + ) + mix_hash_high: int = Field( + validation_alias=AliasChoices( + "mixHashHigh", "prev_randao_high", "prevRandaoHigh" + ) + ) nonce: int base_fee_per_gas_is_some: bool base_fee_per_gas_value: int diff --git a/cairo/tests/utils/parsers.py b/cairo/tests/utils/parsers.py index 544b860..f19c102 100644 --- a/cairo/tests/utils/parsers.py +++ b/cairo/tests/utils/parsers.py @@ -1,13 +1,18 @@ +import re from typing import Optional, Union +hex_pattern = re.compile(r"^(0x)?[0-9a-fA-F]+$") + def to_int(v: Optional[Union[str, int]]) -> Optional[int]: if v is None: return v if isinstance(v, str): - if v.startswith("0x"): + if hex_pattern.match(v): return int(v, 16) return int(v) + if isinstance(v, bytes): + return int.from_bytes(v, "big") return v