diff --git a/src/ethereum_test_tools/utility/generators.py b/src/ethereum_test_tools/utility/generators.py index 72135a5bb52..0f1defc0cb4 100644 --- a/src/ethereum_test_tools/utility/generators.py +++ b/src/ethereum_test_tools/utility/generators.py @@ -21,6 +21,13 @@ class DeploymentTestType(Enum): DEPLOY_AFTER_FORK = "deploy_after_fork" +class ContractAddressHasBalance(Enum): + """Represents whether the target deployment test has a balance before deployment.""" + + ZERO_BALANCE = "zero_balance" + NONZERO_BALANCE = "nonzero_balance" + + class SystemContractDeployTestFunction(Protocol): """ Represents a function to be decorated with the `generate_system_contract_deploy_test` @@ -61,9 +68,17 @@ def generate_system_contract_deploy_test( """ Generate a test that verifies the correct deployment of a system contract. - Generates two tests: - - One that deploys the contract before the fork. - - One that deploys the contract after the fork. + Generates four test cases: + + | before/after fork | has balance | + ------------------------------------|-------------------|-------------| + `deploy_before_fork-nonzero_balance`| before | True | + `deploy_before_fork-zero_balance` | before | False | + `deploy_after_fork-nonzero_balance` | after | True | + `deploy_after_fork-zero_balance` | after | False | + + where `has balance` refers to whether the contract address has a non-zero balance before + deployment, or not. Args: fork (Fork): The fork to test. @@ -89,6 +104,14 @@ def generate_system_contract_deploy_test( deployer_address = deploy_tx.sender def decorator(func: SystemContractDeployTestFunction): + @pytest.mark.parametrize( + "has_balance", + [ + pytest.param(ContractAddressHasBalance.NONZERO_BALANCE), + pytest.param(ContractAddressHasBalance.ZERO_BALANCE), + ], + ids=lambda x: x.name.lower(), + ) @pytest.mark.parametrize( "test_type", [ @@ -101,6 +124,7 @@ def decorator(func: SystemContractDeployTestFunction): @pytest.mark.valid_at_transition_to(fork.name()) def wrapper( blockchain_test: BlockchainTestFiller, + has_balance: ContractAddressHasBalance, pre: Alloc, test_type: DeploymentTestType, fork: Fork, @@ -131,11 +155,11 @@ def wrapper( timestamp=15_001, ), ] - + balance = 1 if has_balance == ContractAddressHasBalance.NONZERO_BALANCE else 0 pre[expected_deploy_address] = Account( code=b"", # Remove the code that is automatically allocated on the fork nonce=0, - balance=0, + balance=balance, ) pre[deployer_address] = Account( balance=deployer_required_balance, @@ -151,12 +175,14 @@ def wrapper( post[expected_deploy_address] = Account( code=expected_code, nonce=1, + balance=balance, ) else: post[expected_deploy_address] = Account( storage=expected_system_contract_storage, code=expected_code, nonce=1, + balance=balance, ) post[deployer_address] = Account( nonce=1,