Skip to content

Commit

Permalink
refactor dirty contracts
Browse files Browse the repository at this point in the history
  • Loading branch information
weiihann committed Oct 28, 2024
1 parent 8a5ed9f commit c466cca
Showing 1 changed file with 29 additions and 33 deletions.
62 changes: 29 additions & 33 deletions core/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,15 @@ type StateReader interface {

type State struct {
txn db.Transaction

// This map holds the contract objects which are being updated in the current state update.
contracts map[felt.Felt]*StateContract
}

func NewState(txn db.Transaction) *State {
return &State{
txn: txn,
txn: txn,
contracts: make(map[felt.Felt]*StateContract),
}
}

Expand Down Expand Up @@ -292,7 +296,6 @@ func (s *State) Update(blockNumber uint64, update *StateUpdate, declaredClasses
return err
}

contracts := make(map[felt.Felt]*StateContract)
// register deployed contracts
for addr, classHash := range update.StateDiff.DeployedContracts {
// check if contract is already deployed
Expand All @@ -305,14 +308,14 @@ func (s *State) Update(blockNumber uint64, update *StateUpdate, declaredClasses
return err
}

contracts[addr] = NewStateContract(&addr, classHash, &felt.Zero, blockNumber)
s.contracts[addr] = NewStateContract(&addr, classHash, &felt.Zero, blockNumber)
}

if err = s.updateContracts(blockNumber, update.StateDiff, true, contracts); err != nil {
if err = s.updateContracts(blockNumber, update.StateDiff, true); err != nil {
return err
}

if err = s.Commit(stateTrie, contracts, true, blockNumber); err != nil {
if err = s.Commit(stateTrie, true, blockNumber); err != nil {
return fmt.Errorf("state commit: %v", err)
}

Check warning on line 320 in core/state.go

View check run for this annotation

Codecov / codecov/patch

core/state.go#L319-L320

Added lines #L319 - L320 were not covered by tests

Expand Down Expand Up @@ -386,7 +389,6 @@ var (
// Commit updates the state by committing the dirty contracts to the database.
func (s *State) Commit(
stateTrie *trie.Trie,
contracts map[felt.Felt]*StateContract,
logChanges bool,
blockNumber uint64,
) error {
Expand All @@ -396,22 +398,21 @@ func (s *State) Commit(
}

// // sort the contracts in descending storage diff order
keys := slices.SortedStableFunc(maps.Keys(contracts), func(a, b felt.Felt) int {
return len(contracts[a].dirtyStorage) - len(contracts[b].dirtyStorage)
keys := slices.SortedStableFunc(maps.Keys(s.contracts), func(a, b felt.Felt) int {
return len(s.contracts[a].dirtyStorage) - len(s.contracts[b].dirtyStorage)
})

contractPools := pool.NewWithResults[*bufferedTransactionWithAddress]().WithErrors().WithMaxGoroutines(runtime.GOMAXPROCS(0))
for _, addr := range keys {
contract := contracts[addr]
contractPools.Go(func() (*bufferedTransactionWithAddress, error) {
txn, err := contract.BufferedCommit(s.txn, logChanges, blockNumber)
txn, err := s.contracts[addr].BufferedCommit(s.txn, logChanges, blockNumber)
if err != nil {
return nil, err
}

Check warning on line 411 in core/state.go

View check run for this annotation

Codecov / codecov/patch

core/state.go#L410-L411

Added lines #L410 - L411 were not covered by tests

return &bufferedTransactionWithAddress{
txn: txn,
addr: contract.Address,
addr: &addr,
}, nil
})
}
Expand All @@ -432,39 +433,37 @@ func (s *State) Commit(
}

Check warning on line 433 in core/state.go

View check run for this annotation

Codecov / codecov/patch

core/state.go#L432-L433

Added lines #L432 - L433 were not covered by tests
}

for _, contract := range contracts {
for _, contract := range s.contracts {
if err := s.updateContractCommitment(stateTrie, contract); err != nil {
return err
}

Check warning on line 439 in core/state.go

View check run for this annotation

Codecov / codecov/patch

core/state.go#L438-L439

Added lines #L438 - L439 were not covered by tests
}

// finally, clear the contracts map
s.contracts = make(map[felt.Felt]*StateContract)

return nil
}

func (s *State) updateContracts(blockNumber uint64, diff *StateDiff, logChanges bool, contracts map[felt.Felt]*StateContract) error {
if contracts == nil {
return fmt.Errorf("contracts is nil")
}

if err := s.updateContractClasses(blockNumber, diff.ReplacedClasses, logChanges, contracts); err != nil {
func (s *State) updateContracts(blockNumber uint64, diff *StateDiff, logChanges bool) error {
if err := s.updateContractClasses(blockNumber, diff.ReplacedClasses, logChanges); err != nil {
return err
}

Check warning on line 451 in core/state.go

View check run for this annotation

Codecov / codecov/patch

core/state.go#L450-L451

Added lines #L450 - L451 were not covered by tests

if err := s.updateContractNonces(blockNumber, diff.Nonces, logChanges, contracts); err != nil {
if err := s.updateContractNonces(blockNumber, diff.Nonces, logChanges); err != nil {
return err
}

Check warning on line 455 in core/state.go

View check run for this annotation

Codecov / codecov/patch

core/state.go#L454-L455

Added lines #L454 - L455 were not covered by tests

return s.updateContractStorages(blockNumber, diff.StorageDiffs, contracts)
return s.updateContractStorages(blockNumber, diff.StorageDiffs)
}

func (s *State) updateContractClasses(
blockNumber uint64,
replacedClasses map[felt.Felt]*felt.Felt,
logChanges bool,
contracts map[felt.Felt]*StateContract,
) error {
for addr, classHash := range replacedClasses {
contract, err := s.getContract(addr, contracts)
contract, err := s.getContract(addr)
if err != nil {
return err
}
Expand All @@ -484,10 +483,9 @@ func (s *State) updateContractNonces(
blockNumber uint64,
nonces map[felt.Felt]*felt.Felt,
logChanges bool,
contracts map[felt.Felt]*StateContract,
) error {
for addr, nonce := range nonces {
contract, err := s.getContract(addr, contracts)
contract, err := s.getContract(addr)
if err != nil {
return err
}
Expand All @@ -506,14 +504,13 @@ func (s *State) updateContractNonces(
func (s *State) updateContractStorages(
blockNumber uint64,
storageDiffs map[felt.Felt]map[felt.Felt]*felt.Felt,
contracts map[felt.Felt]*StateContract,
) error {
for addr, diff := range storageDiffs {
contract, err := s.getContract(addr, contracts)
contract, err := s.getContract(addr)
if err != nil {
if _, ok := noClassContracts[addr]; ok && errors.Is(err, ErrContractNotDeployed) {
contract = NewStateContract(&addr, noClassContractsClassHash, &felt.Zero, blockNumber)
contracts[addr] = contract
s.contracts[addr] = contract
} else {
return err
}
Expand All @@ -524,15 +521,15 @@ func (s *State) updateContractStorages(
return nil
}

func (s *State) getContract(addr felt.Felt, contracts map[felt.Felt]*StateContract) (*StateContract, error) {
contract, ok := contracts[addr]
func (s *State) getContract(addr felt.Felt) (*StateContract, error) {
contract, ok := s.contracts[addr]
if !ok {
var err error
contract, err = GetContract(&addr, s.txn)
if err != nil {
return nil, err
}
contracts[addr] = contract
s.contracts[addr] = contract
}
return contract, nil
}
Expand Down Expand Up @@ -655,12 +652,11 @@ func (s *State) Revert(blockNumber uint64, update *StateUpdate) error {
return err
}

contracts := make(map[felt.Felt]*StateContract)
if err = s.updateContracts(blockNumber, reversedDiff, false, contracts); err != nil {
if err = s.updateContracts(blockNumber, reversedDiff, false); err != nil {
return fmt.Errorf("update contracts: %v", err)
}

if err = s.Commit(stateTrie, contracts, false, blockNumber); err != nil {
if err = s.Commit(stateTrie, false, blockNumber); err != nil {
return fmt.Errorf("state commit: %v", err)

Check warning on line 660 in core/state.go

View check run for this annotation

Codecov / codecov/patch

core/state.go#L660

Added line #L660 was not covered by tests
}

Expand Down

0 comments on commit c466cca

Please sign in to comment.