diff --git a/CHANGELOG.md b/CHANGELOG.md index 61dc41b07..23dcb15d1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -83,6 +83,7 @@ reverts inside of a try-catch. for (1) ERC20 transfers with tokens that return false success values instead of throwing an error and (2) ERC20 transfers with other operations that don't bring about the expected resulting balance for the transfer recipient. +- [#2092](https://github.com/NibiruChain/nibiru/pull/2092) - feat(evm): add validation for wasm multi message execution #### Nibiru EVM | Before Audit 1 - 2024-10-18 diff --git a/x/evm/precompile/wasm.go b/x/evm/precompile/wasm.go index a7b21684c..f19511cb9 100644 --- a/x/evm/precompile/wasm.go +++ b/x/evm/precompile/wasm.go @@ -10,6 +10,7 @@ import ( "github.com/NibiruChain/nibiru/v2/x/evm/embeds" wasmkeeper "github.com/CosmWasm/wasmd/x/wasm/keeper" + wasm "github.com/CosmWasm/wasmd/x/wasm/types" gethabi "github.com/ethereum/go-ethereum/accounts/abi" gethcommon "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/vm" @@ -275,10 +276,15 @@ func (p precompileWasm) executeMulti( callerBech32 := eth.EthAddrToNibiruAddr(caller) var responses [][]byte - for _, m := range wasmExecMsgs { + for i, m := range wasmExecMsgs { wasmContract, e := sdk.AccAddressFromBech32(m.ContractAddr) if e != nil { - err = fmt.Errorf("Execute failed: %w", e) + err = fmt.Errorf("Execute failed at index %d: %w", i, e) + return + } + msgArgsCopy := wasm.RawContractMessage(m.MsgArgs) + if e := msgArgsCopy.ValidateBasic(); e != nil { + err = fmt.Errorf("Execute failed at index %d: error parsing msg args: %w", i, e) return } var funds sdk.Coins @@ -290,7 +296,7 @@ func (p precompileWasm) executeMulti( } respBz, e := p.Wasm.Execute(ctx, wasmContract, callerBech32, m.MsgArgs, funds) if e != nil { - err = e + err = fmt.Errorf("Execute failed at index %d: %w", i, e) return } responses = append(responses, respBz) diff --git a/x/evm/precompile/wasm_test.go b/x/evm/precompile/wasm_test.go index dbe35b839..72c45046c 100644 --- a/x/evm/precompile/wasm_test.go +++ b/x/evm/precompile/wasm_test.go @@ -8,6 +8,7 @@ import ( wasm "github.com/CosmWasm/wasmd/x/wasm/types" "github.com/NibiruChain/nibiru/v2/x/common/testutil" + "github.com/NibiruChain/nibiru/v2/x/common/testutil/testapp" "github.com/NibiruChain/nibiru/v2/x/evm/embeds" "github.com/NibiruChain/nibiru/v2/x/evm/evmtest" "github.com/NibiruChain/nibiru/v2/x/evm/precompile" @@ -313,3 +314,185 @@ func (s *WasmSuite) TestSadArgsExecute() { }) } } + +type WasmExecuteMsg struct { + ContractAddr string `json:"contractAddr"` + MsgArgs []byte `json:"msgArgs"` + Funds []precompile.WasmBankCoin `json:"funds"` +} + +func (s *WasmSuite) TestExecuteMultiValidation() { + deps := evmtest.NewTestDeps() + + s.Require().NoError(testapp.FundAccount( + deps.App.BankKeeper, + deps.Ctx, + deps.Sender.NibiruAddr, + sdk.NewCoins(sdk.NewCoin("unibi", sdk.NewInt(100))), + )) + + wasmContracts := test.SetupWasmContracts(&deps, &s.Suite) + wasmContract := wasmContracts[1] // hello_world_counter.wasm + + invalidMsgArgsBz := []byte(`{"invalid": "json"}`) // Invalid message format + validMsgArgsBz := []byte(`{"increment": {}}`) // Valid increment message + + var emptyFunds []precompile.WasmBankCoin + validFunds := []precompile.WasmBankCoin{{ + Denom: "unibi", + Amount: big.NewInt(100), + }} + invalidFunds := []precompile.WasmBankCoin{{ + Denom: "invalid!denom", + Amount: big.NewInt(100), + }} + + testCases := []struct { + name string + executeMsgs []WasmExecuteMsg + wantError string + }{ + { + name: "valid - single message", + executeMsgs: []WasmExecuteMsg{ + { + ContractAddr: wasmContract.String(), + MsgArgs: validMsgArgsBz, + Funds: emptyFunds, + }, + }, + wantError: "", + }, + { + name: "valid - multiple messages", + executeMsgs: []WasmExecuteMsg{ + { + ContractAddr: wasmContract.String(), + MsgArgs: validMsgArgsBz, + Funds: validFunds, + }, + { + ContractAddr: wasmContract.String(), + MsgArgs: validMsgArgsBz, + Funds: emptyFunds, + }, + }, + wantError: "", + }, + { + name: "invalid - malformed contract address", + executeMsgs: []WasmExecuteMsg{ + { + ContractAddr: "invalid-address", + MsgArgs: validMsgArgsBz, + Funds: emptyFunds, + }, + }, + wantError: "decoding bech32 failed", + }, + { + name: "invalid - malformed message args", + executeMsgs: []WasmExecuteMsg{ + { + ContractAddr: wasmContract.String(), + MsgArgs: invalidMsgArgsBz, + Funds: emptyFunds, + }, + }, + wantError: "unknown variant", + }, + { + name: "invalid - malformed funds", + executeMsgs: []WasmExecuteMsg{ + { + ContractAddr: wasmContract.String(), + MsgArgs: validMsgArgsBz, + Funds: invalidFunds, + }, + }, + wantError: "invalid coins", + }, + { + name: "invalid - second message fails validation", + executeMsgs: []WasmExecuteMsg{ + { + ContractAddr: wasmContract.String(), + MsgArgs: validMsgArgsBz, + Funds: emptyFunds, + }, + { + ContractAddr: wasmContract.String(), + MsgArgs: invalidMsgArgsBz, + Funds: emptyFunds, + }, + }, + wantError: "unknown variant", + }, + } + + for _, tc := range testCases { + s.Run(tc.name, func() { + callArgs := []any{tc.executeMsgs} + input, err := embeds.SmartContract_Wasm.ABI.Pack( + string(precompile.WasmMethod_executeMulti), + callArgs..., + ) + s.Require().NoError(err) + + ethTxResp, _, err := deps.EvmKeeper.CallContractWithInput( + deps.Ctx, deps.Sender.EthAddr, &precompile.PrecompileAddr_Wasm, true, input, + ) + + if tc.wantError != "" { + s.Require().ErrorContains(err, tc.wantError) + return + } + s.Require().NoError(err) + s.NotNil(ethTxResp) + s.NotEmpty(ethTxResp.Ret) + }) + } +} + +// TestExecuteMultiPartialExecution ensures that no state changes occur if any message +// in the batch fails validation +func (s *WasmSuite) TestExecuteMultiPartialExecution() { + deps := evmtest.NewTestDeps() + wasmContracts := test.SetupWasmContracts(&deps, &s.Suite) + wasmContract := wasmContracts[1] // hello_world_counter.wasm + + // First verify initial state is 0 + test.AssertWasmCounterState(&s.Suite, deps, wasmContract, 0) + + // Create a batch where the second message will fail validation + executeMsgs := []WasmExecuteMsg{ + { + ContractAddr: wasmContract.String(), + MsgArgs: []byte(`{"increment": {}}`), + Funds: []precompile.WasmBankCoin{}, + }, + { + ContractAddr: wasmContract.String(), + MsgArgs: []byte(`{"invalid": "json"}`), // This will fail validation + Funds: []precompile.WasmBankCoin{}, + }, + } + + callArgs := []any{executeMsgs} + input, err := embeds.SmartContract_Wasm.ABI.Pack( + string(precompile.WasmMethod_executeMulti), + callArgs..., + ) + s.Require().NoError(err) + + ethTxResp, _, err := deps.EvmKeeper.CallContractWithInput( + deps.Ctx, deps.Sender.EthAddr, &precompile.PrecompileAddr_Wasm, true, input, + ) + + // Verify that the call failed + s.Require().Error(err, "ethTxResp: ", ethTxResp) + s.Require().Contains(err.Error(), "unknown variant") + + // Verify that no state changes occurred + test.AssertWasmCounterState(&s.Suite, deps, wasmContract, 0) +} diff --git a/x/evm/statedb/statedb.go b/x/evm/statedb/statedb.go index 4c79e61af..a411660f9 100644 --- a/x/evm/statedb/statedb.go +++ b/x/evm/statedb/statedb.go @@ -480,9 +480,6 @@ func (s *StateDB) Snapshot() int { // RevertToSnapshot reverts all state changes made since the given revision. func (s *StateDB) RevertToSnapshot(revid int) { - fmt.Printf("len(s.validRevisions): %d\n", len(s.validRevisions)) - fmt.Printf("s.validRevisions: %v\n", s.validRevisions) - // Find the snapshot in the stack of valid snapshots. idx := sort.Search(len(s.validRevisions), func(i int) bool { return s.validRevisions[i].id >= revid