From 7b7beb7b53a537a8ad1462cf8ca5520e21e4978c Mon Sep 17 00:00:00 2001 From: Matthias <97468149+matthiasmatt@users.noreply.github.com> Date: Wed, 30 Oct 2024 15:09:25 +0100 Subject: [PATCH] feat: add validation for multi message execution wasm (#2092) * feat: add validation for multi message execution wasm * chore: changelog * merge conflicts + simplify --------- Co-authored-by: Unique-Divine --- CHANGELOG.md | 1 + x/evm/precompile/wasm.go | 12 ++- x/evm/precompile/wasm_test.go | 183 ++++++++++++++++++++++++++++++++++ x/evm/statedb/statedb.go | 3 - 4 files changed, 193 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 61dc41b07e..23dcb15d1a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -83,6 +83,7 @@ reverts inside of a try-catch. for (1) ERC20 transfers with tokens that return false success values instead of throwing an error and (2) ERC20 transfers with other operations that don't bring about the expected resulting balance for the transfer recipient. +- [#2092](https://github.com/NibiruChain/nibiru/pull/2092) - feat(evm): add validation for wasm multi message execution #### Nibiru EVM | Before Audit 1 - 2024-10-18 diff --git a/x/evm/precompile/wasm.go b/x/evm/precompile/wasm.go index a7b21684cf..f19511cb9a 100644 --- a/x/evm/precompile/wasm.go +++ b/x/evm/precompile/wasm.go @@ -10,6 +10,7 @@ import ( "github.com/NibiruChain/nibiru/v2/x/evm/embeds" wasmkeeper "github.com/CosmWasm/wasmd/x/wasm/keeper" + wasm "github.com/CosmWasm/wasmd/x/wasm/types" gethabi "github.com/ethereum/go-ethereum/accounts/abi" gethcommon "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/vm" @@ -275,10 +276,15 @@ func (p precompileWasm) executeMulti( callerBech32 := eth.EthAddrToNibiruAddr(caller) var responses [][]byte - for _, m := range wasmExecMsgs { + for i, m := range wasmExecMsgs { wasmContract, e := sdk.AccAddressFromBech32(m.ContractAddr) if e != nil { - err = fmt.Errorf("Execute failed: %w", e) + err = fmt.Errorf("Execute failed at index %d: %w", i, e) + return + } + msgArgsCopy := wasm.RawContractMessage(m.MsgArgs) + if e := msgArgsCopy.ValidateBasic(); e != nil { + err = fmt.Errorf("Execute failed at index %d: error parsing msg args: %w", i, e) return } var funds sdk.Coins @@ -290,7 +296,7 @@ func (p precompileWasm) executeMulti( } respBz, e := p.Wasm.Execute(ctx, wasmContract, callerBech32, m.MsgArgs, funds) if e != nil { - err = e + err = fmt.Errorf("Execute failed at index %d: %w", i, e) return } responses = append(responses, respBz) diff --git a/x/evm/precompile/wasm_test.go b/x/evm/precompile/wasm_test.go index dbe35b839b..72c45046c6 100644 --- a/x/evm/precompile/wasm_test.go +++ b/x/evm/precompile/wasm_test.go @@ -8,6 +8,7 @@ import ( wasm "github.com/CosmWasm/wasmd/x/wasm/types" "github.com/NibiruChain/nibiru/v2/x/common/testutil" + "github.com/NibiruChain/nibiru/v2/x/common/testutil/testapp" "github.com/NibiruChain/nibiru/v2/x/evm/embeds" "github.com/NibiruChain/nibiru/v2/x/evm/evmtest" "github.com/NibiruChain/nibiru/v2/x/evm/precompile" @@ -313,3 +314,185 @@ func (s *WasmSuite) TestSadArgsExecute() { }) } } + +type WasmExecuteMsg struct { + ContractAddr string `json:"contractAddr"` + MsgArgs []byte `json:"msgArgs"` + Funds []precompile.WasmBankCoin `json:"funds"` +} + +func (s *WasmSuite) TestExecuteMultiValidation() { + deps := evmtest.NewTestDeps() + + s.Require().NoError(testapp.FundAccount( + deps.App.BankKeeper, + deps.Ctx, + deps.Sender.NibiruAddr, + sdk.NewCoins(sdk.NewCoin("unibi", sdk.NewInt(100))), + )) + + wasmContracts := test.SetupWasmContracts(&deps, &s.Suite) + wasmContract := wasmContracts[1] // hello_world_counter.wasm + + invalidMsgArgsBz := []byte(`{"invalid": "json"}`) // Invalid message format + validMsgArgsBz := []byte(`{"increment": {}}`) // Valid increment message + + var emptyFunds []precompile.WasmBankCoin + validFunds := []precompile.WasmBankCoin{{ + Denom: "unibi", + Amount: big.NewInt(100), + }} + invalidFunds := []precompile.WasmBankCoin{{ + Denom: "invalid!denom", + Amount: big.NewInt(100), + }} + + testCases := []struct { + name string + executeMsgs []WasmExecuteMsg + wantError string + }{ + { + name: "valid - single message", + executeMsgs: []WasmExecuteMsg{ + { + ContractAddr: wasmContract.String(), + MsgArgs: validMsgArgsBz, + Funds: emptyFunds, + }, + }, + wantError: "", + }, + { + name: "valid - multiple messages", + executeMsgs: []WasmExecuteMsg{ + { + ContractAddr: wasmContract.String(), + MsgArgs: validMsgArgsBz, + Funds: validFunds, + }, + { + ContractAddr: wasmContract.String(), + MsgArgs: validMsgArgsBz, + Funds: emptyFunds, + }, + }, + wantError: "", + }, + { + name: "invalid - malformed contract address", + executeMsgs: []WasmExecuteMsg{ + { + ContractAddr: "invalid-address", + MsgArgs: validMsgArgsBz, + Funds: emptyFunds, + }, + }, + wantError: "decoding bech32 failed", + }, + { + name: "invalid - malformed message args", + executeMsgs: []WasmExecuteMsg{ + { + ContractAddr: wasmContract.String(), + MsgArgs: invalidMsgArgsBz, + Funds: emptyFunds, + }, + }, + wantError: "unknown variant", + }, + { + name: "invalid - malformed funds", + executeMsgs: []WasmExecuteMsg{ + { + ContractAddr: wasmContract.String(), + MsgArgs: validMsgArgsBz, + Funds: invalidFunds, + }, + }, + wantError: "invalid coins", + }, + { + name: "invalid - second message fails validation", + executeMsgs: []WasmExecuteMsg{ + { + ContractAddr: wasmContract.String(), + MsgArgs: validMsgArgsBz, + Funds: emptyFunds, + }, + { + ContractAddr: wasmContract.String(), + MsgArgs: invalidMsgArgsBz, + Funds: emptyFunds, + }, + }, + wantError: "unknown variant", + }, + } + + for _, tc := range testCases { + s.Run(tc.name, func() { + callArgs := []any{tc.executeMsgs} + input, err := embeds.SmartContract_Wasm.ABI.Pack( + string(precompile.WasmMethod_executeMulti), + callArgs..., + ) + s.Require().NoError(err) + + ethTxResp, _, err := deps.EvmKeeper.CallContractWithInput( + deps.Ctx, deps.Sender.EthAddr, &precompile.PrecompileAddr_Wasm, true, input, + ) + + if tc.wantError != "" { + s.Require().ErrorContains(err, tc.wantError) + return + } + s.Require().NoError(err) + s.NotNil(ethTxResp) + s.NotEmpty(ethTxResp.Ret) + }) + } +} + +// TestExecuteMultiPartialExecution ensures that no state changes occur if any message +// in the batch fails validation +func (s *WasmSuite) TestExecuteMultiPartialExecution() { + deps := evmtest.NewTestDeps() + wasmContracts := test.SetupWasmContracts(&deps, &s.Suite) + wasmContract := wasmContracts[1] // hello_world_counter.wasm + + // First verify initial state is 0 + test.AssertWasmCounterState(&s.Suite, deps, wasmContract, 0) + + // Create a batch where the second message will fail validation + executeMsgs := []WasmExecuteMsg{ + { + ContractAddr: wasmContract.String(), + MsgArgs: []byte(`{"increment": {}}`), + Funds: []precompile.WasmBankCoin{}, + }, + { + ContractAddr: wasmContract.String(), + MsgArgs: []byte(`{"invalid": "json"}`), // This will fail validation + Funds: []precompile.WasmBankCoin{}, + }, + } + + callArgs := []any{executeMsgs} + input, err := embeds.SmartContract_Wasm.ABI.Pack( + string(precompile.WasmMethod_executeMulti), + callArgs..., + ) + s.Require().NoError(err) + + ethTxResp, _, err := deps.EvmKeeper.CallContractWithInput( + deps.Ctx, deps.Sender.EthAddr, &precompile.PrecompileAddr_Wasm, true, input, + ) + + // Verify that the call failed + s.Require().Error(err, "ethTxResp: ", ethTxResp) + s.Require().Contains(err.Error(), "unknown variant") + + // Verify that no state changes occurred + test.AssertWasmCounterState(&s.Suite, deps, wasmContract, 0) +} diff --git a/x/evm/statedb/statedb.go b/x/evm/statedb/statedb.go index 4c79e61af7..a411660f9e 100644 --- a/x/evm/statedb/statedb.go +++ b/x/evm/statedb/statedb.go @@ -480,9 +480,6 @@ func (s *StateDB) Snapshot() int { // RevertToSnapshot reverts all state changes made since the given revision. func (s *StateDB) RevertToSnapshot(revid int) { - fmt.Printf("len(s.validRevisions): %d\n", len(s.validRevisions)) - fmt.Printf("s.validRevisions: %v\n", s.validRevisions) - // Find the snapshot in the stack of valid snapshots. idx := sort.Search(len(s.validRevisions), func(i int) bool { return s.validRevisions[i].id >= revid