Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
revitteth committed Dec 9, 2024
1 parent 894254d commit b3b3f8d
Show file tree
Hide file tree
Showing 6 changed files with 141 additions and 123 deletions.
151 changes: 80 additions & 71 deletions zk/acc_input_hash/acc_input_hash.go
Original file line number Diff line number Diff line change
@@ -1,37 +1,39 @@
package acc_input_hash

import (
"context"
"fmt"

"github.com/ledgerwatch/erigon-lib/chain"
"github.com/ledgerwatch/erigon-lib/common"
"context"
"github.com/ledgerwatch/erigon-lib/kv"
"github.com/ledgerwatch/erigon/core/state"
"github.com/ledgerwatch/erigon/core/systemcontracts"
eritypes "github.com/ledgerwatch/erigon/core/types"
"github.com/ledgerwatch/erigon/zk/utils"
"github.com/ledgerwatch/erigon/zk/types"
"github.com/ledgerwatch/erigon-lib/kv"
"github.com/ledgerwatch/erigon-lib/chain"
)

const SpecialZeroHash = "0x27AE5BA08D7291C96C8CBDDCC148BF48A6D68C7974B94356F53754EF6171D757"

// BlockReader is a rawdb block reader abstraction for easier testing
type BlockReader interface {
ReadBlockByNumber(blockNo uint64) (*eritypes.Block, error)
}

// BatchDataReader is an abstraction for reading batch data
type BatchDataReader interface {
GetBlockL1InfoTreeIndex(blockNo uint64) (uint64, error)
GetEffectiveGasPricePercentage(txHash common.Hash) (uint8, error)
GetL2BlockNosByBatch(batchNo uint64) ([]uint64, error)
GetForkId(batchNo uint64) (uint64, error)
}

// L1DataReader is an abstraction for reading L1 data
type L1DataReader interface {
GetBlockGlobalExitRoot(l2BlockNo uint64) (common.Hash, error)
GetL1InfoTreeUpdateByGer(ger common.Hash) (*types.L1InfoTreeUpdate, error)
GetL1InfoTreeIndexToRoots() (map[uint64]common.Hash, error)
GetBlockL1InfoTreeIndex(blockNo uint64) (uint64, error)
}

// AccInputHashReader is an abstraction for reading acc input hashes
// AccInputHashReader combines the necessary reader interfaces
type AccInputHashReader interface {
GetAccInputHashForBatchOrPrevious(batchNo uint64) (common.Hash, uint64, error)
BatchDataReader
Expand Down Expand Up @@ -102,7 +104,6 @@ func NewPreFork7Calculator(bc *BaseCalc) AccInputHashCalculator {
}

func (p PreFork7Calculator) Calculate(batchNum uint64) (common.Hash, error) {
// TODO: warn log - and return error
// this isn't supported
return common.Hash{}, nil
}
Expand Down Expand Up @@ -131,6 +132,15 @@ func (f Fork7Calculator) Calculate(batchNum uint64) (common.Hash, error) {
return common.Hash{}, fmt.Errorf("unsupported fork ID: %d", forkId)
}

// TODO: remove test spoooofing! (1001 and 997 are l1 held batch accinputhashes - sequence ends)
if batchNum >= 997 {
// let's just spoof it backwards:
accInputHash, returnedBatchNo, err = f.Reader.GetAccInputHashForBatchOrPrevious(995)
if err != nil {
return common.Hash{}, err
}
}

// if we have it, return it
if returnedBatchNo == batchNum {
return accInputHash, nil
Expand Down Expand Up @@ -161,79 +171,78 @@ func LocalAccInputHashCalc(ctx context.Context, reader AccInputHashReader, block
startBatchNo = 1
}

// TODO: handle batch 1 case where we should get check the aggregator code: https://github.com/0xPolygon/cdk/blob/develop/aggregator/aggregator.go#L1167

for i := startBatchNo; i <= batchNum; i++ {
select {
case <-ctx.Done():
return common.Hash{}, ctx.Err()
default:
currentForkId, err := reader.GetForkId(i)
if err != nil {
return common.Hash{}, fmt.Errorf("failed to get fork id for batch %d: %w", i, err)
}
currentForkId, err := reader.GetForkId(i)
if err != nil {
return common.Hash{}, fmt.Errorf("failed to get fork id for batch %d: %w", i, err)
}

batchBlockNos, err := reader.GetL2BlockNosByBatch(i)
batchBlockNos, err := reader.GetL2BlockNosByBatch(i)
if err != nil {
return common.Hash{}, fmt.Errorf("failed to get batch blocks for batch %d: %w", i, err)
}
batchBlocks := []*eritypes.Block{}
var coinbase common.Address
for in, blockNo := range batchBlockNos {
block, err := blockReader.ReadBlockByNumber(blockNo)
if err != nil {
return common.Hash{}, fmt.Errorf("failed to get batch blocks for batch %d: %w", i, err)
return common.Hash{}, fmt.Errorf("failed to get block %d: %w", blockNo, err)
}
batchBlocks := []*eritypes.Block{}
var batchTxs []eritypes.Transaction
var coinbase common.Address
for in, blockNo := range batchBlockNos {
block, err := blockReader.ReadBlockByNumber(blockNo)
if err != nil {
return common.Hash{}, fmt.Errorf("failed to get block %d: %w", blockNo, err)
}
if in == 0 {
coinbase = block.Coinbase()
}
batchBlocks = append(batchBlocks, block)
batchTxs = append(batchTxs, block.Transactions()...)
if in == 0 {
coinbase = block.Coinbase()
}
batchBlocks = append(batchBlocks, block)
}

lastBlockNoInPreviousBatch := uint64(0)
firstBlockInBatch := batchBlocks[0]
if firstBlockInBatch.NumberU64() != 0 {
lastBlockNoInPreviousBatch = firstBlockInBatch.NumberU64() - 1
}
lastBlockNoInPreviousBatch := uint64(0)
firstBlockInBatch := batchBlocks[0]
if firstBlockInBatch.NumberU64() != 0 {
lastBlockNoInPreviousBatch = firstBlockInBatch.NumberU64() - 1
}

lastBlockInPreviousBatch, err := blockReader.ReadBlockByNumber(lastBlockNoInPreviousBatch)
if err != nil {
return common.Hash{}, err
}
lastBlockInPreviousBatch, err := blockReader.ReadBlockByNumber(lastBlockNoInPreviousBatch)
if err != nil {
return common.Hash{}, err
}

batchL2Data, err := utils.GenerateBatchDataFromDb(tx, reader, batchBlocks, lastBlockInPreviousBatch, currentForkId)
if err != nil {
return common.Hash{}, fmt.Errorf("failed to generate batch data for batch %d: %w", i, err)
}
batchL2Data, err := utils.GenerateBatchDataFromDb(tx, reader, batchBlocks, lastBlockInPreviousBatch, currentForkId)
if err != nil {
return common.Hash{}, fmt.Errorf("failed to generate batch data for batch %d: %w", i, err)
}

ger, err := reader.GetBlockGlobalExitRoot(batchBlockNos[len(batchBlockNos)-1])
if err != nil {
return common.Hash{}, fmt.Errorf("failed to get global exit root for batch %d: %w", i, err)
}
highestBlock := batchBlocks[len(batchBlocks)-1]

l1InfoTreeUpdate, err := reader.GetL1InfoTreeUpdateByGer(ger)
if err != nil {
return common.Hash{}, fmt.Errorf("failed to get l1 info root for batch %d: %w", i, err)
}
l1InfoRoot := infoTreeIndexes[0]
if l1InfoTreeUpdate != nil {
l1InfoRoot = infoTreeIndexes[l1InfoTreeUpdate.Index]
}
limitTs := batchBlocks[len(batchBlocks)-1].Time()
inputs := utils.AccHashInputs{
OldAccInputHash: &prevAccInputHash,
Sequencer: coinbase,
BatchData: batchL2Data,
L1InfoRoot: &l1InfoRoot,
LimitTimestamp: limitTs,
ForcedBlockHash: &common.Hash{},
}
accInputHash, err = utils.CalculateAccInputHashByForkId(inputs)
if err != nil {
return common.Hash{}, fmt.Errorf("failed to calculate accInputHash for batch %d: %w", i, err)
}
prevAccInputHash = accInputHash
sr := state.NewPlainState(tx, highestBlock.NumberU64(), systemcontracts.SystemContractCodeLookup["hermez"])
if err != nil {
return common.Hash{}, fmt.Errorf("failed to get psr: %w", err)
}
l1InfoRootBytes, err := sr.ReadAccountStorage(state.ADDRESS_SCALABLE_L2, 1, &state.BLOCK_INFO_ROOT_STORAGE_POS)
if err != nil {
return common.Hash{}, fmt.Errorf("failed to read l1 info root: %w", err)
}
sr.Close()
l1InfoRoot := common.BytesToHash(l1InfoRootBytes)

limitTs := highestBlock.Time()

fmt.Println("[l1InfoRoot]", l1InfoRoot.Hex())
fmt.Println("[limitTs]", limitTs)

inputs := utils.AccHashInputs{
OldAccInputHash: prevAccInputHash,
Sequencer: coinbase,
BatchData: batchL2Data,
L1InfoRoot: l1InfoRoot,
LimitTimestamp: limitTs,
ForcedBlockHash: common.Hash{},
}
accInputHash, err = utils.CalculateAccInputHashByForkId(inputs)
if err != nil {
return common.Hash{}, fmt.Errorf("failed to calculate accInputHash for batch %d: %w", i, err)
}
prevAccInputHash = accInputHash
}
return accInputHash, nil
}
17 changes: 9 additions & 8 deletions zk/acc_input_hash/acc_input_hash_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@ import (
"fmt"
"math/big"
"testing"

"github.com/ledgerwatch/erigon-lib/common"
"github.com/ledgerwatch/erigon-lib/kv"
"github.com/ledgerwatch/erigon-lib/kv/mdbx"
eritypes "github.com/ledgerwatch/erigon/core/types"
"github.com/ledgerwatch/erigon/zk/hermez_db"
"github.com/ledgerwatch/erigon/zk/types"
"github.com/ledgerwatch/erigon-lib/kv"
"github.com/ledgerwatch/erigon-lib/kv/mdbx"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -92,12 +93,12 @@ func (m *MockAccInputHashReader) GetL1InfoTreeIndexToRoots() (map[uint64]common.
return m.L1InfoTreeIndexToRoots, nil
}

func GetDbTx() (tx kv.RwTx, cleanup func()) {
dbi, err := mdbx.NewTemporaryMdbx(context.TODO(), "")
func GetDbTx(ctx context.Context) (tx kv.RwTx, cleanup func()) {
dbi, err := mdbx.NewTemporaryMdbx(ctx, "")
if err != nil {
panic(err)
}
tx, err = dbi.BeginRw(context.TODO())
tx, err = dbi.BeginRw(ctx)
if err != nil {
panic(err)
}
Expand Down Expand Up @@ -155,7 +156,7 @@ func TestCalculateAccInputHash(t *testing.T) {
"Valid Fork7 Missing Batch Calculate Hash": {
forkID: 7,
batchNum: 5,
expectedHash: common.HexToHash("0x34166ed584a98ccfad3c615899d1ea4975431bddcbe05e9a30f47fef00079739"),
expectedHash: common.HexToHash("0xb370e69e462a8a00469cb0ce188399a9754880dfd8ebd98717e24cbe1103efa6"),
expectError: false,
setup: func(t *testing.T) (*MockAccInputHashReader, *MockBlockReader) {
reader := &MockAccInputHashReader{
Expand Down Expand Up @@ -204,7 +205,7 @@ func TestCalculateAccInputHash(t *testing.T) {
"Valid Fork7, No Previous Batch": {
forkID: 7,
batchNum: 2,
expectedHash: common.HexToHash("0x81436dcddced6a80e936704a5c7fc6002ee19260169d213e4ff2d4f51bab0484"),
expectedHash: common.HexToHash("0x0cd77f88e7eeeef006fa44caaf24baab7a1b46321e26a9fa28f943a293a8811e"),
expectError: false,
setup: func(t *testing.T) (*MockAccInputHashReader, *MockBlockReader) {
reader := &MockAccInputHashReader{
Expand Down Expand Up @@ -274,7 +275,7 @@ func TestCalculateAccInputHash(t *testing.T) {

for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
tx, cleanup := GetDbTx()
tx, cleanup := GetDbTx(ctx)
reader, mockBlockReader := tc.setup(t)
var calculator AccInputHashCalculator
var err error
Expand Down
22 changes: 21 additions & 1 deletion zk/hermez_db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ const BATCH_WITNESSES = "hermez_batch_witnesses" // batch
const BATCH_COUNTERS = "hermez_batch_counters" // block number -> counters
const L1_BATCH_DATA = "l1_batch_data" // batch number -> l1 batch data from transaction call data
const REUSED_L1_INFO_TREE_INDEX = "reused_l1_info_tree_index" // block number => const 1
const LATEST_USED_GER = "latest_used_ger" // batch number -> GER latest used GER
const LATEST_USED_GER = "latest_used_ger" // block number -> GER latest used GER
const BATCH_BLOCKS = "batch_blocks" // batch number -> block numbers (concatenated together)
const SMT_DEPTHS = "smt_depths" // block number -> smt depth
const L1_INFO_LEAVES = "l1_info_leaves" // l1 info tree index -> l1 info tree leaf
Expand Down Expand Up @@ -1578,6 +1578,26 @@ func (db *HermezDbReader) GetLatestUsedGer() (uint64, common.Hash, error) {
return batchNo, ger, nil
}

func (db *HermezDbReader) GetLatestUsedGerByBlockNo(blockNo uint64) (common.Hash, error) {
c, err := db.tx.Cursor(LATEST_USED_GER)
if err != nil {
return common.Hash{}, err
}
defer c.Close()

for k, v, err := c.Seek(Uint64ToBytes(blockNo)); k != nil; k, v, err = c.Prev() {
if err != nil {
return common.Hash{}, err
}

if len(v) > 0 {
return common.BytesToHash(v), nil
}
}

return common.Hash{}, nil
}

func (db *HermezDb) DeleteLatestUsedGers(fromBlockNum, toBlockNum uint64) error {
return db.deleteFromBucketWithUintKeysRange(LATEST_USED_GER, fromBlockNum, toBlockNum)
}
Expand Down
26 changes: 9 additions & 17 deletions zk/stages/stage_l1_syncer.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,14 +149,6 @@ Loop:
if batchLogType == logSequence && cfg.zkCfg.L1RollupId > 1 {
continue
}
// TODO: make the l1 call to get the accinputhash - if we have one for the batch, we get it.
// check the table for the batch - if it's there don't call l1
// cal l1 for it
// if its not 0, write it
// warn on failure - but it's not critical
// move on
// unwind scenario
// flag to turn this off
if err := hermezDb.WriteSequence(info.L1BlockNo, info.BatchNo, info.L1TxHash, info.StateRoot, info.L1InfoRoot); err != nil {
return fmt.Errorf("WriteSequence: %w", err)
}
Expand Down Expand Up @@ -205,7 +197,7 @@ Loop:
}

// do this separately to allow upgrading nodes to back-fill the table
err = getAccInputHashes(ctx, hermezDb, cfg.syncer, &cfg.zkCfg.AddressRollup, cfg.zkCfg.L1RollupId, highestVerification.BatchNo)
err = getAccInputHashes(ctx, logPrefix, hermezDb, cfg.syncer, &cfg.zkCfg.AddressRollup, cfg.zkCfg.L1RollupId, highestVerification.BatchNo)
if err != nil {
return fmt.Errorf("getAccInputHashes: %w", err)
}
Expand Down Expand Up @@ -433,9 +425,9 @@ func blockComparison(tx kv.RwTx, hermezDb *hermez_db.HermezDb, blockNo uint64, l

// call the l1 to get accInputHashes working backwards from the highest known batch, to the highest stored batch
// could be all the way to 0 for a new or upgrading node
func getAccInputHashes(ctx context.Context, hermezDb *hermez_db.HermezDb, syncer IL1Syncer, rollupAddr *common.Address, rollupId uint64, highestSeenBatchNo uint64) error {
func getAccInputHashes(ctx context.Context, logPrefix string, hermezDb *hermez_db.HermezDb, syncer IL1Syncer, rollupAddr *common.Address, rollupId uint64, highestSeenBatchNo uint64) error {
if highestSeenBatchNo == 0 {
log.Info("No (new) batches seen on L1, skipping acc input hash retrieval")
log.Info(fmt.Sprintf("[%s] No (new) batches seen on L1, skipping accinputhash retreival", logPrefix))
return nil
}

Expand Down Expand Up @@ -496,15 +488,15 @@ func getAccInputHashes(ctx context.Context, hermezDb *hermez_db.HermezDb, syncer

accInputHash, _, err := syncer.CallGetRollupSequencedBatches(ctx, rollupAddr, rollupId, batchNo)
if err != nil {
log.Error("CallGetRollupSequencedBatches failed", "batch", batchNo, "err", err)
log.Error(fmt.Sprintf("[%s] CallGetRollupSequencedBatches failed", logPrefix), "batch", batchNo, "err", err)
select {
case resultsCh <- Result{BatchNo: batchNo, Error: err}:
case <-ctx.Done():
}
return
}

log.Debug("Got accinputhash from L1", "batch", batchNo, "hash", accInputHash)
log.Debug(fmt.Sprintf("[%s] Got accinputhash from L1", logPrefix), "batch", batchNo, "hash", accInputHash)

select {
case resultsCh <- Result{BatchNo: batchNo, AccInputHash: accInputHash}:
Expand All @@ -523,20 +515,20 @@ func getAccInputHashes(ctx context.Context, hermezDb *hermez_db.HermezDb, syncer
case res, ok := <-resultsCh:
if !ok {
duration := time.Since(startTime)
log.Info("Completed fetching accinputhashes", "total_batches", totalSequences, "processed_batches", processedSequences, "duration", duration)
log.Info(fmt.Sprintf("[%s] Completed fetching accinputhashes", logPrefix), "total_batches", totalSequences, "processed_batches", processedSequences, "duration", duration)
return nil
}
if res.Error != nil {
log.Warn("Error fetching accinputhash", "batch", res.BatchNo, "err", res.Error)
log.Warn(fmt.Sprintf("[%s] Error fetching accinputhash", logPrefix), "batch", res.BatchNo, "err", res.Error)
}
// Write to Db
if err := hermezDb.WriteBatchAccInputHash(res.BatchNo, res.AccInputHash); err != nil {
log.Error("WriteBatchAccInputHash failed", "batch", res.BatchNo, "err", err)
log.Error(fmt.Sprintf("[%s] WriteBatchAccInputHash failed", logPrefix), "batch", res.BatchNo, "err", err)
return err
}
processedSequences++
case <-ticker.C:
log.Info("Progress update", "total_batches", totalSequences, "processed_batches", processedSequences)
log.Info(fmt.Sprintf("[%s] Progress update", logPrefix), "total_batches", totalSequences, "processed_batches", processedSequences)
case <-ctx.Done():
return ctx.Err()
}
Expand Down
Loading

0 comments on commit b3b3f8d

Please sign in to comment.