Skip to content

Commit

Permalink
WIP - almost complete without tests
Browse files Browse the repository at this point in the history
  • Loading branch information
pnowosie committed Oct 8, 2024
1 parent d383531 commit 1ef09de
Show file tree
Hide file tree
Showing 6 changed files with 248 additions and 36 deletions.
36 changes: 36 additions & 0 deletions core/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ type StateReader interface {
ContractNonce(addr *felt.Felt) (*felt.Felt, error)
ContractStorage(addr, key *felt.Felt) (*felt.Felt, error)
Class(classHash *felt.Felt) (*DeclaredClass, error)
ClassTrie() (*trie.Trie, func() error, error)
StorageTrie() (*trie.Trie, func() error, error)
StorageTrieForAddr(addr *felt.Felt) (*trie.Trie, error)
StateAndClassRoot() (*felt.Felt, *felt.Felt, error)
}

type State struct {
Expand Down Expand Up @@ -733,3 +737,35 @@ func (s *State) buildReverseDiff(blockNumber uint64, diff *StateDiff) (*StateDif

return &reversed, nil
}

func (s *State) StateAndClassRoot() (*felt.Felt, *felt.Felt, error) {
var storageRoot, classesRoot *felt.Felt

sStorage, closer, err := s.storage()
if err != nil {
return nil, nil, err
}

if storageRoot, err = sStorage.Root(); err != nil {
return nil, nil, err
}

if err = closer(); err != nil {
return nil, nil, err
}

classes, closer, err := s.classesTrie()
if err != nil {
return nil, nil, err
}

if classesRoot, err = classes.Root(); err != nil {
return nil, nil, err
}

if err = closer(); err != nil {
return nil, nil, err
}

return storageRoot, classesRoot, nil
}
4 changes: 2 additions & 2 deletions core/trie/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ type Node struct {
}

// Hash calculates the hash of a [Node]
func (n *Node) Hash(path *Key, hashFunc hashFunc) *felt.Felt {
func (n *Node) Hash(path *Key, hashFunc HashFunc) *felt.Felt {
if path.Len() == 0 {
// we have to deference the Value, since the Node can released back
// to the NodePool and be reused anytime
Expand All @@ -33,7 +33,7 @@ func (n *Node) Hash(path *Key, hashFunc hashFunc) *felt.Felt {
}

// Hash calculates the hash of a [Node]
func (n *Node) HashFromParent(parnetKey, nodeKey *Key, hashFunc hashFunc) *felt.Felt {
func (n *Node) HashFromParent(parnetKey, nodeKey *Key, hashFunc HashFunc) *felt.Felt {
path := path(nodeKey, parnetKey)
return n.Hash(&path, hashFunc)
}
Expand Down
29 changes: 17 additions & 12 deletions core/trie/proof.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ var (
)

type ProofNode interface {
Hash(hash hashFunc) *felt.Felt
Hash(hash HashFunc) *felt.Felt
Len() uint8
PrettyPrint()
}
Expand All @@ -23,7 +23,7 @@ type Binary struct {
RightHash *felt.Felt
}

func (b *Binary) Hash(hash hashFunc) *felt.Felt {
func (b *Binary) Hash(hash HashFunc) *felt.Felt {
return hash(b.LeftHash, b.RightHash)
}

Expand All @@ -42,7 +42,7 @@ type Edge struct {
Path *Key // path from parent to child
}

func (e *Edge) Hash(hash hashFunc) *felt.Felt {
func (e *Edge) Hash(hash HashFunc) *felt.Felt {
length := make([]byte, len(e.Path.bitset))
length[len(e.Path.bitset)-1] = e.Path.len
pathFelt := e.Path.Felt()
Expand All @@ -54,6 +54,11 @@ func (e *Edge) Len() uint8 {
return e.Path.Len()
}

func (e *Edge) PathInt() uint64 {
f := e.Path.Felt()
return f.Uint64()
}

func (e *Edge) PrettyPrint() {
fmt.Printf(" Edge:\n")
fmt.Printf(" Child: %v\n", e.Child)
Expand Down Expand Up @@ -199,7 +204,7 @@ func traverseNodes(currNode ProofNode, path *[]ProofNode, nodeHashes map[felt.Fe
// merges paths in the specified order [commonNodes..., leftNodes..., rightNodes...]
// ordering of the merged path is not important
// since SplitProofPath can discover the left and right paths using the merged path and the rootHash
func MergeProofPaths(leftPath, rightPath []ProofNode, hash hashFunc) ([]ProofNode, *felt.Felt, error) {
func MergeProofPaths(leftPath, rightPath []ProofNode, hash HashFunc) ([]ProofNode, *felt.Felt, error) {
merged := []ProofNode{}
minLen := min(len(leftPath), len(rightPath))

Expand Down Expand Up @@ -236,7 +241,7 @@ func MergeProofPaths(leftPath, rightPath []ProofNode, hash hashFunc) ([]ProofNod
// SplitProofPath splits the merged proof path into two paths (left and right), which were merged before
// it first validates that the merged path is not circular, the split happens at most once and rootHash exists
// then calls traverseNodes to split the path to left and right paths
func SplitProofPath(mergedPath []ProofNode, rootHash *felt.Felt, hash hashFunc) ([]ProofNode, []ProofNode, error) {
func SplitProofPath(mergedPath []ProofNode, rootHash *felt.Felt, hash HashFunc) ([]ProofNode, []ProofNode, error) {
commonPath := []ProofNode{}
leftPath := []ProofNode{}
rightPath := []ProofNode{}
Expand Down Expand Up @@ -316,7 +321,7 @@ func GetProof(key *Key, tri *Trie) ([]ProofNode, error) {

// verifyProof checks if `leafPath` leads from `root` to `leafHash` along the `proofNodes`
// https://github.com/eqlabs/pathfinder/blob/main/crates/merkle-tree/src/tree.rs#L2006
func VerifyProof(root *felt.Felt, key *Key, value *felt.Felt, proofs []ProofNode, hash hashFunc) bool {
func VerifyProof(root *felt.Felt, key *Key, value *felt.Felt, proofs []ProofNode, hash HashFunc) bool {
expectedHash := root
remainingPath := NewKey(key.len, key.bitset[:])
for i, proofNode := range proofs {
Expand Down Expand Up @@ -363,7 +368,7 @@ func VerifyProof(root *felt.Felt, key *Key, value *felt.Felt, proofs []ProofNode
// and therefore it's hash won't match the expected root.
// ref: https://github.com/ethereum/go-ethereum/blob/v1.14.3/trie/proof.go#L484
func VerifyRangeProof(root *felt.Felt, keys, values []*felt.Felt, proofKeys [2]*Key, proofValues [2]*felt.Felt,
proofs [2][]ProofNode, hash hashFunc,
proofs [2][]ProofNode, hash HashFunc,
) (bool, error) {
// Step 0: checks
if len(keys) != len(values) {
Expand Down Expand Up @@ -440,7 +445,7 @@ func ensureMonotonicIncreasing(proofKeys [2]*Key, keys []*felt.Felt) error {
}

// compressNode determines if the node needs compressed, and if so, the len needed to arrive at the next key
func compressNode(idx int, proofNodes []ProofNode, hashF hashFunc) (int, uint8, error) {
func compressNode(idx int, proofNodes []ProofNode, hashF HashFunc) (int, uint8, error) {
parent := proofNodes[idx]

if idx == len(proofNodes)-1 {
Expand Down Expand Up @@ -474,7 +479,7 @@ func compressNode(idx int, proofNodes []ProofNode, hashF hashFunc) (int, uint8,
}

func assignChild(i, compressedParent int, parentNode *Node,
nilKey, leafKey, parentKey *Key, proofNodes []ProofNode, hashF hashFunc,
nilKey, leafKey, parentKey *Key, proofNodes []ProofNode, hashF HashFunc,
) (*Key, error) {
childInd := i + compressedParent + 1
childKey, err := getChildKey(childInd, parentKey, leafKey, nilKey, proofNodes, hashF)
Expand All @@ -494,7 +499,7 @@ func assignChild(i, compressedParent int, parentNode *Node,
// ProofToPath returns a set of storage nodes from the root to the end of the proof path.
// The storage nodes will have the hashes of the children, but only the key of the child
// along the path outlined by the proof.
func ProofToPath(proofNodes []ProofNode, leafKey *Key, hashF hashFunc) ([]StorageNode, error) {
func ProofToPath(proofNodes []ProofNode, leafKey *Key, hashF HashFunc) ([]StorageNode, error) {
pathNodes := []StorageNode{}

// Child keys that can't be derived are set to nilKey, so that we can store the node
Expand Down Expand Up @@ -552,7 +557,7 @@ func ProofToPath(proofNodes []ProofNode, leafKey *Key, hashF hashFunc) ([]Storag
return pathNodes, nil
}

func skipNode(pNode ProofNode, pathNodes []StorageNode, hashF hashFunc) bool {
func skipNode(pNode ProofNode, pathNodes []StorageNode, hashF HashFunc) bool {
lastNode := pathNodes[len(pathNodes)-1].node
noLeftMatch, noRightMatch := false, false
if lastNode.LeftHash != nil && !pNode.Hash(hashF).Equal(lastNode.LeftHash) {
Expand Down Expand Up @@ -607,7 +612,7 @@ func getParentKey(idx int, compressedParentOffset uint8, leafKey *Key,
return crntKey, err
}

func getChildKey(childIdx int, crntKey, leafKey, nilKey *Key, proofNodes []ProofNode, hashF hashFunc) (*Key, error) {
func getChildKey(childIdx int, crntKey, leafKey, nilKey *Key, proofNodes []ProofNode, hashF HashFunc) (*Key, error) {
if childIdx > len(proofNodes)-1 {
return nilKey, nil
}
Expand Down
10 changes: 7 additions & 3 deletions core/trie/trie.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (
"github.com/NethermindEth/juno/db"
)

type hashFunc func(*felt.Felt, *felt.Felt) *felt.Felt
type HashFunc func(*felt.Felt, *felt.Felt) *felt.Felt

// Trie is a dense Merkle Patricia Trie (i.e., all internal nodes have two children).
//
Expand All @@ -37,7 +37,7 @@ type Trie struct {
rootKey *Key
maxKey *felt.Felt
storage *Storage
hash hashFunc
hash HashFunc

dirtyNodes []*Key
rootKeyIsDirty bool
Expand All @@ -53,7 +53,7 @@ func NewTriePoseidon(storage *Storage, height uint8) (*Trie, error) {
return newTrie(storage, height, crypto.Poseidon)
}

func newTrie(storage *Storage, height uint8, hash hashFunc) (*Trie, error) {
func newTrie(storage *Storage, height uint8, hash HashFunc) (*Trie, error) {
if height > felt.Bits {
return nil, fmt.Errorf("max trie height is %d, got: %d", felt.Bits, height)
}
Expand Down Expand Up @@ -668,6 +668,10 @@ func (t *Trie) RootKey() *Key {
return t.rootKey
}

func (t *Trie) HashFunc() HashFunc {
return t.hash
}

func (t *Trie) Dump() {
t.dump(0, nil)
}
Expand Down
Loading

0 comments on commit 1ef09de

Please sign in to comment.