From 4184906514166f7ac12ce0663af11c342d78e29e Mon Sep 17 00:00:00 2001 From: Unique-Divine Date: Mon, 28 Oct 2024 19:49:45 -0500 Subject: [PATCH] remove new bank keeper --- CHANGELOG.md | 23 ++--- app/keepers.go | 15 +-- x/evm/keeper/bank_extension.go | 163 --------------------------------- x/evm/keeper/keeper.go | 5 +- x/evm/keeper/statedb.go | 17 +++- x/evm/statedb/debug.go | 39 ++++++++ x/evm/statedb/journal.go | 19 ---- x/evm/statedb/journal_test.go | 21 ++--- x/evm/statedb/statedb.go | 9 -- 9 files changed, 72 insertions(+), 239 deletions(-) delete mode 100644 x/evm/keeper/bank_extension.go create mode 100644 x/evm/statedb/debug.go diff --git a/CHANGELOG.md b/CHANGELOG.md index ac2dba69cb..473fdaf98a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -70,21 +70,14 @@ consistent setup and dynamic gas calculations, addressing the following tickets. - [#2089](https://github.com/NibiruChain/nibiru/pull/2089) - better handling of gas consumption within erc20 contract execution - [#2091](https://github.com/NibiruChain/nibiru/pull/2091) - feat(evm): add fun token creation fee validation - [#2094](https://github.com/NibiruChain/nibiru/pull/2094) - fix(evm): Following -from the changs in #2086, this pull request implements two critical security -fixes. - 1. First, we add new `JournalChange` struct that saves a deep copy of the - state multi store before each state-modifying, Nibiru-specific precompiled - contract is called (`OnRunStart`). Additionally, we commit the `StateDB` there - as well. This guarantees that the non-EVM and EVM state will be in sync even - if there are complex, multi-step Ethereum transactions, such as in the case of - an EthereumTx that influences the `StateDB`, then calls a precompile that also - changes non-EVM state, and then EVM reverts inside of a try-catch. - 2. Second, the solution from #2086 that records NIBI (ether) transfers on the - `StateDB` during precompiled contract calls is generalized as - `NibiruBankKeeper`, which is struct extension of the `bankkeeper.BaseKeeper` - that is used throughout the Nibiru base application. The `NibiruBankKeeper` - holds a reference to the current EVM `StateDB` if there is one and records - balance changes in wei as journal changes automatically. +from the changs in #2086, this pull request implements a new `JournalChange` +struct that saves a deep copy of the state multi store before each +state-modifying, Nibiru-specific precompiled contract is called (`OnRunStart`). +Additionally, we commit the `StateDB` there as well. This guarantees that the +non-EVM and EVM state will be in sync even if there are complex, multi-step +Ethereum transactions, such as in the case of an EthereumTx that influences the +`StateDB`, then calls a precompile that also changes non-EVM state, and then EVM +reverts inside of a try-catch. #### Nibiru EVM | Before Audit 1 - 2024-10-18 diff --git a/app/keepers.go b/app/keepers.go index be75b23572..12421be659 100644 --- a/app/keepers.go +++ b/app/keepers.go @@ -264,24 +264,13 @@ func (app *NibiruApp) InitKeepers( govModuleAddr, ) - app.bankBaseKeeper = bankkeeper.NewBaseKeeper( + app.BankKeeper = bankkeeper.NewBaseKeeper( appCodec, keys[banktypes.StoreKey], app.AccountKeeper, BlockedAddresses(), govModuleAddr, ) - nibiruBankKeeper := evmkeeper.NibiruBankKeeper{ - BaseKeeper: bankkeeper.NewBaseKeeper( - appCodec, - keys[banktypes.StoreKey], - app.AccountKeeper, - BlockedAddresses(), - govModuleAddr, - ), - StateDB: nil, - } - app.BankKeeper = nibiruBankKeeper app.StakingKeeper = stakingkeeper.NewKeeper( appCodec, @@ -384,7 +373,7 @@ func (app *NibiruApp) InitKeepers( tkeys[evm.TransientKey], authtypes.NewModuleAddress(govtypes.ModuleName), app.AccountKeeper, - &nibiruBankKeeper, + app.BankKeeper, app.StakingKeeper, cast.ToString(appOpts.Get("evm.tracer")), ) diff --git a/x/evm/keeper/bank_extension.go b/x/evm/keeper/bank_extension.go deleted file mode 100644 index a2bcd6d277..0000000000 --- a/x/evm/keeper/bank_extension.go +++ /dev/null @@ -1,163 +0,0 @@ -package keeper - -import ( - sdk "github.com/cosmos/cosmos-sdk/types" - auth "github.com/cosmos/cosmos-sdk/x/auth/types" - bankkeeper "github.com/cosmos/cosmos-sdk/x/bank/keeper" - - "github.com/NibiruChain/nibiru/v2/eth" - "github.com/NibiruChain/nibiru/v2/x/evm" - "github.com/NibiruChain/nibiru/v2/x/evm/statedb" -) - -var ( - _ bankkeeper.Keeper = &NibiruBankKeeper{} - _ bankkeeper.SendKeeper = &NibiruBankKeeper{} -) - -type NibiruBankKeeper struct { - bankkeeper.BaseKeeper - StateDB *statedb.StateDB - balanceChangesForStateDB uint64 -} - -func (evmKeeper *Keeper) NewStateDB( - ctx sdk.Context, txConfig statedb.TxConfig, -) *statedb.StateDB { - stateDB := statedb.New(ctx, evmKeeper, txConfig) - bk := evmKeeper.bankKeeper - bk.StateDB = stateDB - bk.balanceChangesForStateDB = 0 - return stateDB -} - -// BalanceChangesForStateDB returns the count of [statedb.JournalChange] entries -// that were added to the current [statedb.StateDB] -func (bk *NibiruBankKeeper) BalanceChangesForStateDB() uint64 { return bk.balanceChangesForStateDB } - -func (bk NibiruBankKeeper) MintCoins( - ctx sdk.Context, - moduleName string, - coins sdk.Coins, -) error { - // Use the embedded function from [bankkeeper.Keeper] - if err := bk.BaseKeeper.MintCoins(ctx, moduleName, coins); err != nil { - return err - } - if findEtherBalanceChangeFromCoins(coins) { - moduleBech32Addr := auth.NewModuleAddress(evm.ModuleName) - bk.SyncStateDBWithAccount(ctx, moduleBech32Addr) - } - return nil -} - -func (bk NibiruBankKeeper) BurnCoins( - ctx sdk.Context, - moduleName string, - coins sdk.Coins, -) error { - // Use the embedded function from [bankkeeper.Keeper] - if err := bk.BaseKeeper.BurnCoins(ctx, moduleName, coins); err != nil { - return err - } - if findEtherBalanceChangeFromCoins(coins) { - moduleBech32Addr := auth.NewModuleAddress(evm.ModuleName) - bk.SyncStateDBWithAccount(ctx, moduleBech32Addr) - } - return nil -} - -func (bk NibiruBankKeeper) SendCoins( - ctx sdk.Context, - fromAddr sdk.AccAddress, - toAddr sdk.AccAddress, - coins sdk.Coins, -) error { - // Use the embedded function from [bankkeeper.Keeper] - if err := bk.BaseKeeper.SendCoins(ctx, fromAddr, toAddr, coins); err != nil { - return err - } - if findEtherBalanceChangeFromCoins(coins) { - bk.SyncStateDBWithAccount(ctx, fromAddr) - bk.SyncStateDBWithAccount(ctx, toAddr) - } - return nil -} - -func (bk *NibiruBankKeeper) SyncStateDBWithAccount( - ctx sdk.Context, acc sdk.AccAddress, -) { - // If there's no StateDB set, it means we're not in an EthereumTx. - if bk.StateDB == nil { - return - } - balanceWei := evm.NativeToWei( - bk.GetBalance(ctx, acc, evm.EVMBankDenom).Amount.BigInt(), - ) - bk.StateDB.SetBalanceWei(eth.NibiruAddrToEthAddr(acc), balanceWei) - bk.balanceChangesForStateDB += 1 -} - -func findEtherBalanceChangeFromCoins(coins sdk.Coins) (found bool) { - for _, c := range coins { - if c.Denom == evm.EVMBankDenom { - return true - } - } - return false -} - -func (bk NibiruBankKeeper) SendCoinsFromAccountToModule( - ctx sdk.Context, - senderAddr sdk.AccAddress, - recipientModule string, - coins sdk.Coins, -) error { - // Use the embedded function from [bankkeeper.Keeper] - if err := bk.BaseKeeper.SendCoinsFromAccountToModule(ctx, senderAddr, recipientModule, coins); err != nil { - return err - } - if findEtherBalanceChangeFromCoins(coins) { - bk.SyncStateDBWithAccount(ctx, senderAddr) - moduleBech32Addr := auth.NewModuleAddress(recipientModule) - bk.SyncStateDBWithAccount(ctx, moduleBech32Addr) - } - return nil -} - -func (bk NibiruBankKeeper) SendCoinsFromModuleToAccount( - ctx sdk.Context, - senderModule string, - recipientAddr sdk.AccAddress, - coins sdk.Coins, -) error { - // Use the embedded function from [bankkeeper.Keeper] - if err := bk.BaseKeeper.SendCoinsFromModuleToAccount(ctx, senderModule, recipientAddr, coins); err != nil { - return err - } - if findEtherBalanceChangeFromCoins(coins) { - moduleBech32Addr := auth.NewModuleAddress(senderModule) - bk.SyncStateDBWithAccount(ctx, moduleBech32Addr) - bk.SyncStateDBWithAccount(ctx, recipientAddr) - } - return nil -} - -func (bk NibiruBankKeeper) SendCoinsFromModuleToModule( - ctx sdk.Context, - senderModule string, - recipientModule string, - coins sdk.Coins, -) error { - // Use the embedded function from [bankkeeper.Keeper] - if err := bk.BaseKeeper.SendCoinsFromModuleToModule(ctx, senderModule, recipientModule, coins); err != nil { - return err - } - if findEtherBalanceChangeFromCoins(coins) { - senderBech32Addr := auth.NewModuleAddress(senderModule) - recipientBech32Addr := auth.NewModuleAddress(recipientModule) - bk.SyncStateDBWithAccount(ctx, senderBech32Addr) - bk.SyncStateDBWithAccount(ctx, recipientBech32Addr) - } - return nil -} diff --git a/x/evm/keeper/keeper.go b/x/evm/keeper/keeper.go index c6b0720d8a..dd31229fd1 100644 --- a/x/evm/keeper/keeper.go +++ b/x/evm/keeper/keeper.go @@ -15,6 +15,7 @@ import ( "github.com/cosmos/cosmos-sdk/codec" storetypes "github.com/cosmos/cosmos-sdk/store/types" sdk "github.com/cosmos/cosmos-sdk/types" + bankkeeper "github.com/cosmos/cosmos-sdk/x/bank/keeper" gethcommon "github.com/ethereum/go-ethereum/common" "github.com/NibiruChain/nibiru/v2/app/appconst" @@ -40,7 +41,7 @@ type Keeper struct { // this should be the x/gov module account. authority sdk.AccAddress - bankKeeper *NibiruBankKeeper + bankKeeper bankkeeper.Keeper accountKeeper evm.AccountKeeper stakingKeeper evm.StakingKeeper @@ -63,7 +64,7 @@ func NewKeeper( storeKey, transientKey storetypes.StoreKey, authority sdk.AccAddress, accKeeper evm.AccountKeeper, - bankKeeper *NibiruBankKeeper, + bankKeeper bankkeeper.Keeper, stakingKeeper evm.StakingKeeper, tracer string, ) Keeper { diff --git a/x/evm/keeper/statedb.go b/x/evm/keeper/statedb.go index 575962d022..7aff8c02d8 100644 --- a/x/evm/keeper/statedb.go +++ b/x/evm/keeper/statedb.go @@ -17,6 +17,13 @@ import ( var _ statedb.Keeper = &Keeper{} +func (k *Keeper) NewStateDB( + ctx sdk.Context, + txConfig statedb.TxConfig, +) *statedb.StateDB { + return statedb.New(ctx, k, txConfig) +} + // ---------------------------------------------------------------------------- // StateDB Keeper implementation // ---------------------------------------------------------------------------- @@ -73,26 +80,26 @@ func (k *Keeper) SetAccBalance( ctx sdk.Context, addr gethcommon.Address, amountEvmDenom *big.Int, ) error { nativeAddr := sdk.AccAddress(addr.Bytes()) - balance := k.bankKeeper.BaseKeeper.GetBalance(ctx, nativeAddr, evm.EVMBankDenom).Amount.BigInt() + balance := k.bankKeeper.GetBalance(ctx, nativeAddr, evm.EVMBankDenom).Amount.BigInt() delta := new(big.Int).Sub(amountEvmDenom, balance) switch delta.Sign() { case 1: // mint coins := sdk.NewCoins(sdk.NewCoin(evm.EVMBankDenom, sdkmath.NewIntFromBigInt(delta))) - if err := k.bankKeeper.BaseKeeper.MintCoins(ctx, evm.ModuleName, coins); err != nil { + if err := k.bankKeeper.MintCoins(ctx, evm.ModuleName, coins); err != nil { return err } - if err := k.bankKeeper.BaseKeeper.SendCoinsFromModuleToAccount(ctx, evm.ModuleName, nativeAddr, coins); err != nil { + if err := k.bankKeeper.SendCoinsFromModuleToAccount(ctx, evm.ModuleName, nativeAddr, coins); err != nil { return err } case -1: // burn coins := sdk.NewCoins(sdk.NewCoin(evm.EVMBankDenom, sdkmath.NewIntFromBigInt(new(big.Int).Neg(delta)))) - if err := k.bankKeeper.BaseKeeper.SendCoinsFromAccountToModule(ctx, nativeAddr, evm.ModuleName, coins); err != nil { + if err := k.bankKeeper.SendCoinsFromAccountToModule(ctx, nativeAddr, evm.ModuleName, coins); err != nil { return err } - if err := k.bankKeeper.BaseKeeper.BurnCoins(ctx, evm.ModuleName, coins); err != nil { + if err := k.bankKeeper.BurnCoins(ctx, evm.ModuleName, coins); err != nil { return err } default: diff --git a/x/evm/statedb/debug.go b/x/evm/statedb/debug.go new file mode 100644 index 0000000000..c2b5fb968b --- /dev/null +++ b/x/evm/statedb/debug.go @@ -0,0 +1,39 @@ +package statedb + +// Copyright (c) 2023-2024 Nibi, Inc. + +import ( + "github.com/ethereum/go-ethereum/common" +) + +// DebugDirtiesCount is a test helper to inspect how many entries in the journal +// are still dirty (uncommitted). After calling [StateDB.Commit], this function +// should return zero. +func (s *StateDB) DebugDirtiesCount() int { + dirtiesCount := 0 + for _, dirtyCount := range s.Journal.dirties { + dirtiesCount += dirtyCount + } + return dirtiesCount +} + +// DebugDirties is a test helper that returns the journal's dirty account changes map. +func (s *StateDB) DebugDirties() map[common.Address]int { + return s.Journal.dirties +} + +// DebugEntries is a test helper that returns the sequence of [JournalChange] +// objects added during execution. +func (s *StateDB) DebugEntries() []JournalChange { + return s.Journal.entries +} + +// DebugStateObjects is a test helper that returns returns a copy of the +// [StateDB.stateObjects] map. +func (s *StateDB) DebugStateObjects() map[common.Address]*stateObject { + copyOfMap := make(map[common.Address]*stateObject) + for key, val := range s.stateObjects { + copyOfMap[key] = val + } + return copyOfMap +} diff --git a/x/evm/statedb/journal.go b/x/evm/statedb/journal.go index e684fd5742..dd42a16f10 100644 --- a/x/evm/statedb/journal.go +++ b/x/evm/statedb/journal.go @@ -93,25 +93,6 @@ func (j *journal) Length() int { return len(j.entries) } -// DirtiesCount is a test helper to inspect how many entries in the journal are -// still dirty (uncommitted). After calling [StateDB.Commit], this function should -// return zero. -func (s *StateDB) DirtiesCount() int { - dirtiesCount := 0 - for _, dirtyCount := range s.Journal.dirties { - dirtiesCount += dirtyCount - } - return dirtiesCount -} - -func (s *StateDB) Dirties() map[common.Address]int { - return s.Journal.dirties -} - -func (s *StateDB) Entries() []JournalChange { - return s.Journal.entries -} - // ------------------------------------------------------ // createObjectChange diff --git a/x/evm/statedb/journal_test.go b/x/evm/statedb/journal_test.go index 046fc514c0..6390b640fb 100644 --- a/x/evm/statedb/journal_test.go +++ b/x/evm/statedb/journal_test.go @@ -59,7 +59,7 @@ func (s *Suite) TestComplexJournalChanges() { s.Run("Populate dirty journal entries. Remove with Commit", func() { stateDB := evmObj.StateDB.(*statedb.StateDB) - s.Equal(0, stateDB.DirtiesCount()) + s.Equal(0, stateDB.DebugDirtiesCount()) randomAcc := evmtest.NewEthPrivAcc().EthAddr balDelta := evm.NativeToWei(big.NewInt(4)) @@ -69,7 +69,7 @@ func (s *Suite) TestComplexJournalChanges() { stateDB.AddBalance(randomAcc, balDelta) // 1 dirties from [balanceChange] stateDB.SubBalance(randomAcc, balDelta) - if stateDB.DirtiesCount() != 4 { + if stateDB.DebugDirtiesCount() != 4 { debugDirtiesCountMismatch(stateDB, s.T()) s.FailNow("expected 4 dirty journal changes") } @@ -77,7 +77,7 @@ func (s *Suite) TestComplexJournalChanges() { s.T().Log("StateDB.Commit, then Dirties should be gone") err = stateDB.Commit() s.NoError(err) - if stateDB.DirtiesCount() != 0 { + if stateDB.DebugDirtiesCount() != 0 { debugDirtiesCountMismatch(stateDB, s.T()) s.FailNow("expected 0 dirty journal changes") } @@ -99,7 +99,7 @@ func (s *Suite) TestComplexJournalChanges() { ) s.Require().NoError(err) stateDB := evmObj.StateDB.(*statedb.StateDB) - if stateDB.DirtiesCount() != 2 { + if stateDB.DebugDirtiesCount() != 2 { debugDirtiesCountMismatch(stateDB, s.T()) s.FailNow("expected 2 dirty journal changes") } @@ -137,7 +137,7 @@ func (s *Suite) TestComplexJournalChanges() { ) stateDB, ok := evmObj.StateDB.(*statedb.StateDB) s.Require().True(ok, "error retrieving StateDB from the EVM") - if stateDB.DirtiesCount() != 0 { + if stateDB.DebugDirtiesCount() != 0 { debugDirtiesCountMismatch(stateDB, s.T()) s.FailNow("expected 0 dirty journal changes") } @@ -151,7 +151,7 @@ func (s *Suite) TestComplexJournalChanges() { s.Require().True(ok, "error retrieving StateDB from the EVM") s.T().Log("Expect exactly 0 dirty journal entry for the precompile snapshot") - if stateDB.DirtiesCount() != 0 { + if stateDB.DebugDirtiesCount() != 0 { debugDirtiesCountMismatch(stateDB, s.T()) s.FailNow("expected 0 dirty journal changes") } @@ -218,17 +218,12 @@ snapshots and see the prior states.`)) &s.Suite, deps, wasmContract, 7, // state before precompile called ) }) - - s.Run("too many precompile calls in one tx will fail", func() { - // currently - // evmObj - }) } func debugDirtiesCountMismatch(db *statedb.StateDB, t *testing.T) string { lines := []string{} - dirties := db.Dirties() - stateObjects := db.StateObjects() + dirties := db.DebugDirties() + stateObjects := db.DebugStateObjects() for addr, dirtyCountForAddr := range dirties { lines = append(lines, fmt.Sprintf("Dirty addr: %s, dirtyCountForAddr=%d", addr, dirtyCountForAddr)) diff --git a/x/evm/statedb/statedb.go b/x/evm/statedb/statedb.go index 957da78888..4c79e61af7 100644 --- a/x/evm/statedb/statedb.go +++ b/x/evm/statedb/statedb.go @@ -608,12 +608,3 @@ func (s *StateDB) SavePrecompileCalledJournalChange( } const maxMultistoreCacheCount uint8 = 10 - -// StateObjects: Returns a copy of the [StateDB.stateObjects] map. -func (s *StateDB) StateObjects() map[common.Address]*stateObject { - copyOfMap := make(map[common.Address]*stateObject) - for key, val := range s.stateObjects { - copyOfMap[key] = val - } - return copyOfMap -}