diff --git a/blockchain/blockchain.go b/blockchain/blockchain.go index 7ff31a3e22..228c5e5332 100644 --- a/blockchain/blockchain.go +++ b/blockchain/blockchain.go @@ -339,7 +339,7 @@ func (b *Blockchain) Store(block *core.Block, blockCommitments *core.BlockCommit return err } - if err := txn.Delete(db.Pending.Key()); err != nil { + if err := storeEmptyPending(txn, block.Header); err != nil { return err } @@ -822,14 +822,6 @@ func (b *Blockchain) revertHead(txn db.Transaction) error { return err } } - if !genesisBlock { - var newHeader *core.Header - newHeader, err = blockHeaderByNumber(txn, blockNumber-1) - if err != nil { - return err - } - b.newHeads.Send(newHeader) - } if err = removeTxsAndReceipts(txn, blockNumber, header.TransactionCount); err != nil { return err @@ -840,16 +832,26 @@ func (b *Blockchain) revertHead(txn db.Transaction) error { return err } - // remove pending - if err = txn.Delete(db.Pending.Key()); err != nil { - return err - } + // Revert chain height and pending. - // update chain height if genesisBlock { + if err = txn.Delete(db.Pending.Key()); err != nil { + return err + } return txn.Delete(db.ChainHeight.Key()) } + var newHeader *core.Header + newHeader, err = blockHeaderByNumber(txn, blockNumber-1) + if err != nil { + return err + } + b.newHeads.Send(newHeader) + + if err := storeEmptyPending(txn, newHeader); err != nil { + return err + } + heightBin := core.MarshalBlockNumber(blockNumber - 1) return txn.Set(db.ChainHeight.Key(), heightBin) } @@ -881,21 +883,52 @@ func removeTxsAndReceipts(txn db.Transaction, blockNumber, numTxs uint64) error return nil } +func storeEmptyPending(txn db.Transaction, latestHeader *core.Header) error { + receipts := make([]*core.TransactionReceipt, 0) + pendingBlock := &core.Block{ + Header: &core.Header{ + ParentHash: latestHeader.Hash, + SequencerAddress: latestHeader.SequencerAddress, + Timestamp: latestHeader.Timestamp + 1, + ProtocolVersion: latestHeader.ProtocolVersion, + EventsBloom: core.EventsBloom(receipts), + GasPrice: latestHeader.GasPrice, + }, + Transactions: make([]core.Transaction, 0), + Receipts: receipts, + } + + emptyPending := &Pending{ + Block: pendingBlock, + StateUpdate: &core.StateUpdate{ + OldRoot: latestHeader.GlobalStateRoot, + StateDiff: &core.StateDiff{ + StorageDiffs: make(map[felt.Felt][]core.StorageDiff, 0), + Nonces: make(map[felt.Felt]*felt.Felt, 0), + DeployedContracts: make([]core.DeployedContract, 0), + DeclaredV0Classes: make([]*felt.Felt, 0), + DeclaredV1Classes: make([]core.DeclaredV1Class, 0), + ReplacedClasses: make([]core.ReplacedClass, 0), + }, + }, + NewClasses: make(map[felt.Felt]core.Class, 0), + } + return storePending(txn, emptyPending) +} + // StorePending stores a pending block given that it is for the next height func (b *Blockchain) StorePending(pending *Pending) error { return b.database.Update(func(txn db.Transaction) error { - expectedParent := new(felt.Felt) - expectedOldRoot := new(felt.Felt) - h, err := head(txn) + expectedParentHash := new(felt.Felt) + h, err := headsHeader(txn) if err != nil && !errors.Is(err, db.ErrKeyNotFound) { return err } else if err == nil { - expectedParent = h.Hash - expectedOldRoot = h.GlobalStateRoot + expectedParentHash = h.Hash } - if !expectedParent.Equal(pending.Block.ParentHash) || !expectedOldRoot.Equal(pending.StateUpdate.OldRoot) { - return errors.New("pending block parent is not our local HEAD") + if !expectedParentHash.Equal(pending.Block.ParentHash) { + return ErrParentDoesNotMatchHead } existingPending, err := pendingBlock(txn) @@ -903,14 +936,18 @@ func (b *Blockchain) StorePending(pending *Pending) error { return nil // ignore the incoming pending if it has fewer transactions than the one we already have } - pendingBytes, err := encoder.Marshal(pending) - if err != nil { - return err - } - return txn.Set(db.Pending.Key(), pendingBytes) + return storePending(txn, pending) }) } +func storePending(txn db.Transaction, pending *Pending) error { + pendingBytes, err := encoder.Marshal(pending) + if err != nil { + return err + } + return txn.Set(db.Pending.Key(), pendingBytes) +} + func pendingBlock(txn db.Transaction) (Pending, error) { var pending Pending err := txn.Get(db.Pending.Key(), func(bytes []byte) error { diff --git a/blockchain/blockchain_test.go b/blockchain/blockchain_test.go index a5eda36b88..73580817c7 100644 --- a/blockchain/blockchain_test.go +++ b/blockchain/blockchain_test.go @@ -652,6 +652,11 @@ func TestPending(t *testing.T) { su, err := gw.StateUpdate(context.Background(), 0) require.NoError(t, err) + t.Run("pending state shouldnt exist if no pending block", func(t *testing.T) { + _, _, err = chain.PendingState() + require.Error(t, err) + }) + t.Run("store genesis as pending", func(t *testing.T) { pendingGenesis := blockchain.Pending{ Block: b, @@ -664,10 +669,50 @@ func TestPending(t *testing.T) { assert.Equal(t, pendingGenesis, gotPending) }) - t.Run("storing genesis as an accepted block should clear pending", func(t *testing.T) { - require.NoError(t, chain.Store(b, &emptyCommitments, su, nil)) - _, pErr := chain.Pending() - require.ErrorIs(t, pErr, db.ErrKeyNotFound) + require.NoError(t, chain.Store(b, &emptyCommitments, su, nil)) + + t.Run("no pending block means pending state matches head state", func(t *testing.T) { + pending, pErr := chain.Pending() + require.NoError(t, pErr) + require.Equal(t, b.Timestamp+1, pending.Block.Timestamp) + require.Equal(t, b.SequencerAddress, pending.Block.SequencerAddress) + require.Equal(t, b.GasPrice, pending.Block.GasPrice) + require.Equal(t, b.ProtocolVersion, pending.Block.ProtocolVersion) + require.Equal(t, su.NewRoot, pending.StateUpdate.OldRoot) + require.Empty(t, pending.StateUpdate.StateDiff.Nonces) + require.Empty(t, pending.StateUpdate.StateDiff.StorageDiffs) + require.Empty(t, pending.StateUpdate.StateDiff.ReplacedClasses) + require.Empty(t, pending.StateUpdate.StateDiff.DeclaredV0Classes) + require.Empty(t, pending.StateUpdate.StateDiff.DeclaredV1Classes) + require.Empty(t, pending.StateUpdate.StateDiff.DeployedContracts) + require.Empty(t, pending.NewClasses) + + // PendingState matches head state. + require.NoError(t, pErr) + reader, closer, pErr := chain.PendingState() + require.NoError(t, pErr) + t.Cleanup(func() { + require.NoError(t, closer()) + }) + + for address, diff := range su.StateDiff.StorageDiffs { + for _, kv := range diff { + value, csErr := reader.ContractStorage(&address, kv.Key) + require.NoError(t, csErr) + require.Equal(t, kv.Value, value) + } + } + + for address, nonce := range su.StateDiff.Nonces { + got, cnErr := reader.ContractNonce(&address) + require.NoError(t, cnErr) + require.Equal(t, nonce, got) + } + + for _, hash := range su.StateDiff.DeclaredV0Classes { + _, err = reader.Class(hash) + require.NoError(t, err) + } }) t.Run("storing a pending too far into the future should fail", func(t *testing.T) { @@ -680,12 +725,7 @@ func TestPending(t *testing.T) { Block: b, StateUpdate: su, } - require.EqualError(t, chain.StorePending(¬ExpectedPending), "pending block parent is not our local HEAD") - }) - - t.Run("pending state shouldnt exist if no pending block", func(t *testing.T) { - _, _, err = chain.PendingState() - require.Error(t, err) + require.ErrorIs(t, chain.StorePending(¬ExpectedPending), blockchain.ErrParentDoesNotMatchHead) }) t.Run("store expected pending block", func(t *testing.T) {