diff --git a/.golangci.yml b/.golangci.yml
new file mode 100644
index 0000000..6fbcf2a
--- /dev/null
+++ b/.golangci.yml
@@ -0,0 +1,150 @@
+linters-settings:
+  gci:
+    local-prefixes: github.com/txaty/go-merkletree
+  dupl:
+    threshold: 100
+  errorlint:
+    errorf: true
+  errcheck:
+    check-type-assertions: true
+    check-blank: true
+  exhaustive:
+    check-generated: false
+    default-signifies-exhaustive: false
+  funlen:
+    lines: 65
+    statements: 40
+  gocognit:
+    min-complexity: 15
+  gocyclo:
+    min-complexity: 10
+  goconst:
+    min-len: 2
+    min-occurrences: 2
+  gocritic:
+    enabled-tags:
+      - diagnostic
+      - experimental
+      - opinionated
+      - performance
+      - style
+    disabled-checks:
+      - dupImport
+      - unnecessaryBlock
+  gofumpt:
+    extra-rules: true
+  gomnd:
+    settings:
+      mnd:
+        checks:
+          - argument
+          - case
+          - condition
+          - operation
+          - return
+  govet:
+    check-shadowing: true
+  misspell:
+    locale: US
+  nestif:
+    min-complexity: 4
+  nolintlint:
+    require-explanation: true
+    require-specific: true
+
+linters:
+  disable-all: true
+  enable:
+    - asasalint
+    - asciicheck
+    - bidichk
+    - bodyclose
+    - containedctx
+    - contextcheck
+    - cyclop
+    - decorder
+    - depguard
+    - dogsled
+    - dupl
+    - durationcheck
+    - errcheck
+    - errorlint
+    - exhaustive
+    - exportloopref
+    - forbidigo
+    - funlen
+    - gci
+    - gochecknoglobals
+    - gochecknoinits
+    - gocognit
+    - goconst
+    - gocritic
+    - gocyclo
+    - godot
+    - godox
+    - goerr113
+    - gofmt
+    - gofumpt
+    - goimports
+    - gomnd
+    - gomodguard
+    - goprintffuncname
+    - gosec
+    - gosimple
+    - govet
+    - ineffassign
+    - makezero
+    - misspell
+    - nakedret
+    - nestif
+    - nlreturn
+    - noctx
+    - nolintlint
+    - paralleltest
+    - predeclared
+    - revive
+    - rowserrcheck
+    - sloglint
+    - sqlclosecheck
+    - staticcheck
+    - stylecheck
+    - tparallel
+    - thelper
+    - typecheck
+    - unconvert
+    - unparam
+    - unused
+    - wastedassign
+    - wsl
+    - whitespace
+    - goheader
+    - prealloc
+    - wrapcheck
+    - zerologlint
+
+  disable:
+    - testpackage
+
+issues:
+  exclude-rules:
+    - path: _test\.go
+      linters:
+        - cyclop
+        - depguard
+        - dupl
+        - forbidigo
+        - funlen
+        - gocognit
+        - gocritic
+        - gocyclo
+        - goerr113
+        - gosec
+        - nestif
+        - nlreturn
+        - paralleltest
+        - unparam
+        - wsl
+
+run:
+  skip-dirs:
+    - docs
diff --git a/default_hash.go b/default_hash.go
index 86fe058..d4e6f13 100644
--- a/default_hash.go
+++ b/default_hash.go
@@ -25,7 +25,9 @@ package merkletree
 import "crypto/sha256"
 
 // sha256Digest is the reusable digest for DefaultHashFunc.
-// It is used to avoid creating a new hash digest for every call to DefaultHashFunc.
+// It is used to avoid creating a new hash digest for every call to DefaultHashFunc and reduce memory allocations.
+
+//nolint:gochecknoglobals // Ignoring this linting error as this has to be a global variable.
 var sha256Digest = sha256.New()
 
 // DefaultHashFunc is the default hash function used when no user-specified hash function is provided.
@@ -33,6 +35,7 @@ var sha256Digest = sha256.New()
 func DefaultHashFunc(data []byte) ([]byte, error) {
     defer sha256Digest.Reset()
     sha256Digest.Write(data)
+
     return sha256Digest.Sum(make([]byte, 0, sha256Digest.Size())), nil
 }
@@ -42,5 +45,6 @@ func DefaultHashFunc(data []byte) ([]byte, error) {
 func DefaultHashFuncParallel(data []byte) ([]byte, error) {
     digest := sha256.New()
     digest.Write(data)
+
     return digest.Sum(make([]byte, 0, digest.Size())), nil
 }
diff --git a/leaf.go b/leaf.go
index a858e76..4345a06 100644
--- a/leaf.go
+++ b/leaf.go
@@ -22,7 +22,11 @@ package merkletree
 
-import "golang.org/x/sync/errgroup"
+import (
+    "fmt"
+
+    "golang.org/x/sync/errgroup"
+)
 
 // computeLeafNodes compute the leaf nodes from the data blocks.
 func (m *MerkleTree) computeLeafNodes(blocks []DataBlock) ([][]byte, error) {
@@ -32,11 +36,13 @@ func (m *MerkleTree) computeLeafNodes(blocks []DataBlock) ([][]byte, error) {
         disableLeafHashing = m.DisableLeafHashing
         err                error
     )
+
     for i := 0; i < m.NumLeaves; i++ {
         if leaves[i], err = dataBlockToLeaf(blocks[i], hashFunc, disableLeafHashing); err != nil {
             return nil, err
         }
     }
+
     return leaves, nil
 }
@@ -50,9 +56,12 @@ func (m *MerkleTree) computeLeafNodesParallel(blocks []DataBlock) ([][]byte, err
         disableLeafHashing = m.DisableLeafHashing
         eg                 = new(errgroup.Group)
     )
+
     numRoutines = min(numRoutines, lenLeaves)
+
     for startIdx := 0; startIdx < numRoutines; startIdx++ {
         startIdx := startIdx
+
         eg.Go(func() error {
             var err error
             for i := startIdx; i < lenLeaves; i += numRoutines {
@@ -60,12 +69,15 @@ func (m *MerkleTree) computeLeafNodesParallel(blocks []DataBlock) ([][]byte, err
                     return err
                 }
             }
+
             return nil
         })
     }
+
     if err := eg.Wait(); err != nil {
-        return nil, err
+        return nil, fmt.Errorf("computeLeafNodesParallel: %w", err)
     }
+
     return leaves, nil
 }
@@ -74,13 +86,16 @@ func (m *MerkleTree) computeLeafNodesParallel(blocks []DataBlock) ([][]byte, err
 func dataBlockToLeaf(block DataBlock, hashFunc TypeHashFunc, disableLeafHashing bool) ([]byte, error) {
     blockBytes, err := block.Serialize()
     if err != nil {
-        return nil, err
+        return nil, fmt.Errorf("dataBlockToLeaf: %w", err)
     }
+
     if disableLeafHashing {
         // copy the value so that the original byte slice is not modified
         leaf := make([]byte, len(blockBytes))
         copy(leaf, blockBytes)
+
         return leaf, nil
     }
+
     return hashFunc(blockBytes)
 }
diff --git a/merkle_tree.go b/merkle_tree.go
index 05be68d..8eb611a 100644
--- a/merkle_tree.go
+++ b/merkle_tree.go
@@ -118,16 +118,6 @@ func New(config *Config, blocks []DataBlock) (m *MerkleTree, err error) {
         Depth: bits.Len(uint(len(blocks) - 1)),
     }
 
-    // Initialize the hash function.
-    if m.HashFunc == nil {
-        if m.RunInParallel {
-            // Use a concurrent safe hash function for parallel execution.
-            m.HashFunc = DefaultHashFuncParallel
-        } else {
-            m.HashFunc = DefaultHashFunc
-        }
-    }
-
     // Hash concatenation function initialization.
     if m.concatHashFunc == nil {
         if m.SortSiblingPairs {
@@ -137,69 +127,106 @@ func New(config *Config, blocks []DataBlock) (m *MerkleTree, err error) {
         }
     }
 
-    // Configure parallelization settings.
+    // Perform actions based on the configured mode.
+    // Set the mode to ModeProofGen by default if not specified.
+    if m.Mode == 0 {
+        m.Mode = ModeProofGen
+    }
+
     if m.RunInParallel {
-        // Set NumRoutines to the number of CPU cores if not specified or invalid.
-        if m.NumRoutines <= 0 {
-            m.NumRoutines = runtime.NumCPU()
-        }
-        if m.Leaves, err = m.computeLeafNodesParallel(blocks); err != nil {
-            return nil, err
-        }
-    } else {
-        // Generate leaves without parallelization.
-        if m.Leaves, err = m.computeLeafNodes(blocks); err != nil {
+        if err := m.newParallel(blocks); err != nil {
             return nil, err
         }
+
+        return m, nil
     }
-    // Perform actions based on the configured mode.
-    // Set the mode to ModeProofGen by default if not specified.
-    if m.Mode == 0 {
-        m.Mode = ModeProofGen
+    if err := m.new(blocks); err != nil {
+        return nil, err
+    }
+
+    return m, nil
+}
+
+func (m *MerkleTree) new(blocks []DataBlock) error {
+    // Initialize the hash function.
+    if m.HashFunc == nil {
+        m.HashFunc = DefaultHashFunc
+    }
+
+    // Generate leaves.
+    var err error
+    m.Leaves, err = m.computeLeafNodes(blocks)
+
+    if err != nil {
+        return err
     }
-    // Generate proofs in ModeProofGen.
     if m.Mode == ModeProofGen {
-        if m.RunInParallel {
-            err = m.proofGenParallel()
-            return
-        }
-        err = m.proofGen()
-        return
+        return m.proofGen()
     }
+
     // Initialize the leafMap for ModeTreeBuild and ModeProofGenAndTreeBuild.
     m.leafMap = make(map[string]int)
-    // Build the tree in ModeTreeBuild.
     if m.Mode == ModeTreeBuild {
-        if m.RunInParallel {
-            err = m.treeBuildParallel()
-            return
-        }
-        err = m.treeBuild()
-        return
+        return m.treeBuild()
     }
 
     // Build the tree and generate proofs in ModeProofGenAndTreeBuild.
     if m.Mode == ModeProofGenAndTreeBuild {
-        if m.RunInParallel {
-            err = m.proofGenAndTreeBuildParallel()
-            return
-        }
-        err = m.proofGenAndTreeBuild()
-        return
+        return m.proofGenAndTreeBuild()
+    }
+
+    // Return an error if the configuration mode is invalid.
+    return ErrInvalidConfigMode
+}
+
+func (m *MerkleTree) newParallel(blocks []DataBlock) error {
+    // Initialize the hash function.
+    if m.HashFunc == nil {
+        m.HashFunc = DefaultHashFuncParallel
+    }
+
+    // Set NumRoutines to the number of CPU cores if not specified or invalid.
+    if m.NumRoutines <= 0 {
+        m.NumRoutines = runtime.NumCPU()
+    }
+
+    // Generate leaves.
+    var err error
+    m.Leaves, err = m.computeLeafNodesParallel(blocks)
+
+    if err != nil {
+        return err
+    }
+
+    if m.Mode == ModeProofGen {
+        return m.proofGenParallel()
+    }
+
+    // Initialize the leafMap for ModeTreeBuild and ModeProofGenAndTreeBuild.
+    m.leafMap = make(map[string]int)
+
+    if m.Mode == ModeTreeBuild {
+        return m.treeBuildParallel()
+    }
+
+    // Build the tree and generate proofs in ModeProofGenAndTreeBuild.
+    if m.Mode == ModeProofGenAndTreeBuild {
+        return m.proofGenAndTreeBuildParallel()
     }
 
     // Return an error if the configuration mode is invalid.
-    return nil, ErrInvalidConfigMode
+    return ErrInvalidConfigMode
 }
 
 // concatHash concatenates two byte slices, b1 and b2.
-func concatHash(b1 []byte, b2 []byte) []byte {
+func concatHash(b1, b2 []byte) []byte {
     result := make([]byte, len(b1)+len(b2))
     copy(result, b1)
     copy(result[len(b1):], b2)
+
     return result
 }
@@ -207,9 +234,10 @@ func concatHash(b1 []byte, b2 []byte) []byte {
 // The function ensures that the smaller byte slice (in terms of lexicographic order)
 // is placed before the larger one. This is used for compatibility with OpenZeppelin's
 // Merkle Proof verification implementation.
-func concatSortHash(b1 []byte, b2 []byte) []byte {
+func concatSortHash(b1, b2 []byte) []byte {
     if bytes.Compare(b1, b2) < 0 {
         return concatHash(b1, b2)
     }
+
     return concatHash(b2, b1)
 }
diff --git a/proof.go b/proof.go
index d80958d..fc7cc11 100644
--- a/proof.go
+++ b/proof.go
@@ -47,6 +47,7 @@ func (m *MerkleTree) Proof(dataBlock DataBlock) (*Proof, error) {
     m.leafMapMu.Lock()
     idx, ok := m.leafMap[string(leaf)]
     m.leafMapMu.Unlock()
+
     if !ok {
         return nil, ErrProofInvalidDataBlock
     }
@@ -56,6 +57,7 @@ func (m *MerkleTree) Proof(dataBlock DataBlock) (*Proof, error) {
         path     uint32
         siblings = make([][]byte, m.Depth)
     )
+
     for i := 0; i < m.Depth; i++ {
         if idx&1 == 1 {
             siblings[i] = m.nodes[i][idx-1]
@@ -63,8 +65,10 @@ func (m *MerkleTree) Proof(dataBlock DataBlock) (*Proof, error) {
             path += 1 << i
             siblings[i] = m.nodes[i][idx+1]
         }
+
         idx >>= 1
     }
+
     return &Proof{
         Path:     path,
         Siblings: siblings,
diff --git a/proof_gen.go b/proof_gen.go
index a33d423..77372bf 100644
--- a/proof_gen.go
+++ b/proof_gen.go
@@ -23,6 +23,7 @@ package merkletree
 
 import (
+    "fmt"
     "sync"
 
     "golang.org/x/sync/errgroup"
@@ -33,36 +34,46 @@ func (m *MerkleTree) proofGen() (err error) {
     m.initProofs()
     buffer, bufferSize := initBuffer(m.Leaves)
+
     for step := 0; step < m.Depth; step++ {
         bufferSize = fixOddNumOfNodes(buffer, bufferSize, step)
         m.updateProofs(buffer, bufferSize, step)
+
         for idx := 0; idx < bufferSize; idx += 2 {
             leftIdx := idx << step
             rightIdx := min(leftIdx+(1<>= 1
     }
+
     m.Root = buffer[0]
+
     return
 }
 
 // proofGenParallel generates proofs concurrently for the MerkleTree.
-func (m *MerkleTree) proofGenParallel() (err error) {
+func (m *MerkleTree) proofGenParallel() error {
     m.initProofs()
     buffer, bufferSize := initBuffer(m.Leaves)
     numRoutines := m.NumRoutines
+
     for step := 0; step < m.Depth; step++ {
         // Limit the number of workers to the previous level length.
         numRoutines = min(numRoutines, bufferSize)
         bufferSize = fixOddNumOfNodes(buffer, bufferSize, step)
         m.updateProofsParallel(buffer, bufferSize, step)
+
         eg := new(errgroup.Group)
-        for startIdx := 0; startIdx < numRoutines; startIdx++ {
-            startIdx := startIdx << 1
+
+        for workerIdx := 0; workerIdx < numRoutines; workerIdx++ {
+            startIdx := workerIdx << 1
+
             eg.Go(func() error {
                 var err error
                 for i := startIdx; i < bufferSize; i += numRoutines << 1 {
@@ -73,16 +84,21 @@ func (m *MerkleTree) proofGenParallel() (err error) {
                         return err
                     }
                 }
+
                 return nil
             })
         }
-        if err = eg.Wait(); err != nil {
-            return
+
+        if err := eg.Wait(); err != nil {
+            return fmt.Errorf("proofGenParallel: %w", err)
         }
+
         bufferSize >>= 1
     }
+
     m.Root = buffer[0]
-    return
+
+    return nil
 }
 
 // initProofs initializes the MerkleTree's Proofs with the appropriate size and depth.
@@ -97,18 +113,18 @@ func (m *MerkleTree) proofGenParallel() (err error) {
 // initBuffer initializes the buffer with the leaves and returns the buffer size.
 // If the number of leaves is odd, the buffer size is increased by 1.
-func initBuffer(leaves [][]byte) ([][]byte, int) {
-    var (
-        numLeaves = len(leaves)
-        buffer    [][]byte
-    )
+func initBuffer(leaves [][]byte) (buffer [][]byte, numLeaves int) {
+    numLeaves = len(leaves)
+
+    // If the number of leaves is odd, make initial buffer size even by adding 1.
     if numLeaves&1 == 1 {
         buffer = make([][]byte, numLeaves+1)
     } else {
         buffer = make([][]byte, numLeaves)
     }
+
     copy(buffer, leaves)
+
     return buffer, numLeaves
 }
@@ -119,11 +135,13 @@ func fixOddNumOfNodes(buffer [][]byte, bufferSize, step int) int {
     if bufferSize&1 == 0 {
         return bufferSize
     }
+
     // Determine the node to append.
     appendNodeIndex := (bufferSize - 1) << step
     // The appended node will be put at the end of the buffer.
     buffer[len(buffer)-1] = buffer[appendNodeIndex]
     bufferSize++
+
     return bufferSize
 }
@@ -141,11 +159,14 @@ func (m *MerkleTree) updateProofsParallel(buffer [][]byte, bufferLength, step in
         batch = 1 << step
         wg    sync.WaitGroup
     )
+
     numRoutines := min(m.NumRoutines, bufferLength)
     wg.Add(numRoutines)
+
     for startIdx := 0; startIdx < numRoutines; startIdx++ {
         go func(startIdx int) {
             defer wg.Done()
+
             for i := startIdx; i < bufferLength; i += numRoutines << 1 {
                 updateProofInTwoBatches(m.Proofs, buffer, i, batch, step)
             }
@@ -159,13 +180,16 @@ func updateProofInTwoBatches(proofs []*Proof, buffer [][]byte, idx, batch, step
     start := idx * batch
     end := min(start+batch, len(proofs))
     siblingNodeIdx := min((idx+1)<>1)
+
         for j := 0; j < numNodes; j += 2 {
             if m.nodes[i+1][j>>1], err = m.HashFunc(
                 m.concatHashFunc(m.nodes[i][j], m.nodes[i][j+1]),
@@ -41,28 +47,34 @@
             }
         }
     }
+
     if m.Root, err = m.HashFunc(m.concatHashFunc(
         m.nodes[m.Depth-1][0],
         m.nodes[m.Depth-1][1],
     )); err != nil {
         return
     }
+
     <-finishMap
+
     return
 }
 
 // treeBuildParallel builds the Merkle Tree and stores all the nodes in parallel.
-func (m *MerkleTree) treeBuildParallel() (err error) {
+func (m *MerkleTree) treeBuildParallel() error {
     finishMap := make(chan struct{})
     go m.workerBuildLeafMap(finishMap)
     m.initNodes()
+
     for i := 0; i < m.Depth-1; i++ {
         m.nodes[i] = appendNodeIfOdd(m.nodes[i])
         numNodes := len(m.nodes[i])
         m.nodes[i+1] = make([][]byte, numNodes>>1)
         numRoutines := min(m.NumRoutines, numNodes)
         eg := new(errgroup.Group)
+
         for startIdx := 0; startIdx < numRoutines; startIdx++ {
             startIdx := startIdx
+
             eg.Go(func() error {
                 for j := startIdx << 1; j < numNodes; j += numRoutines << 1 {
                     newHash, err := m.HashFunc(m.concatHashFunc(
@@ -73,25 +85,32 @@ func (m *MerkleTree) treeBuildParallel() (err error) {
                     }
                     m.nodes[i+1][j>>1] = newHash
                 }
+
                 return nil
             })
         }
-        if err = eg.Wait(); err != nil {
-            return
+
+        if err := eg.Wait(); err != nil {
+            return fmt.Errorf("treeBuildParallel: %w", err)
         }
     }
+
+    var err error
     if m.Root, err = m.HashFunc(m.concatHashFunc(
         m.nodes[m.Depth-1][0],
         m.nodes[m.Depth-1][1],
     )); err != nil {
-        return
+        return err
     }
+
     <-finishMap
-    return
+
+    return nil
 }
 
 func (m *MerkleTree) workerBuildLeafMap(finishChan chan struct{}) {
     m.leafMapMu.Lock()
     defer m.leafMapMu.Unlock()
+
     for i := 0; i < m.NumLeaves; i++ {
         m.leafMap[string(m.Leaves[i])] = i
     }
@@ -108,7 +127,10 @@ func appendNodeIfOdd(buffer [][]byte) [][]byte {
     if len(buffer)&1 == 0 {
         return buffer
     }
+
     appendNode := buffer[len(buffer)-1]
+
     buffer = append(buffer, appendNode)
+
     return buffer
 }
diff --git a/verify.go b/verify.go
index c1f902e..974987b 100644
--- a/verify.go
+++ b/verify.go
@@ -37,12 +37,15 @@ func Verify(dataBlock DataBlock, proof *Proof, root []byte, config *Config) (boo
     if dataBlock == nil {
         return false, ErrDataBlockIsNil
     }
+
     if proof == nil {
         return false, ErrProofIsNil
     }
+
     if config == nil {
         config = new(Config)
     }
+
     if config.HashFunc == nil {
         config.HashFunc = DefaultHashFunc
     }
@@ -63,17 +66,22 @@ func Verify(dataBlock DataBlock, proof *Proof, root []byte, config *Config) (boo
     // Copy the slice so that the original leaf won't be modified.
     result := make([]byte, len(leaf))
     copy(result, leaf)
+
     path := proof.Path
+
     for _, sib := range proof.Siblings {
         if path&1 == 1 {
             result, err = config.HashFunc(concatFunc(result, sib))
         } else {
             result, err = config.HashFunc(concatFunc(sib, result))
         }
+
         if err != nil {
             return false, err
         }
+
         path >>= 1
     }
+
     return bytes.Equal(result, root), nil
 }
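Below is a minimal usage sketch (not part of the patch) of the refactored constructor and verifier, assuming the exported API touched by this diff (New, Config, Mode/ModeProofGen, DataBlock, Proofs, Root, Verify) is otherwise unchanged and that Proofs are returned in block order; testBlock is a hypothetical DataBlock implementation used only for illustration.

package main

import (
    "crypto/rand"
    "fmt"

    mt "github.com/txaty/go-merkletree"
)

// testBlock is a hypothetical DataBlock implementation used only for illustration.
type testBlock struct {
    data []byte
}

// Serialize returns the raw bytes that will be hashed into a leaf.
func (b *testBlock) Serialize() ([]byte, error) {
    return b.data, nil
}

func main() {
    // Build a few random data blocks.
    blocks := make([]mt.DataBlock, 8)
    for i := range blocks {
        buf := make([]byte, 32)
        if _, err := rand.Read(buf); err != nil {
            panic(err)
        }
        blocks[i] = &testBlock{data: buf}
    }

    // ModeProofGen is also what the constructor falls back to when Config.Mode is left unset.
    tree, err := mt.New(&mt.Config{Mode: mt.ModeProofGen}, blocks)
    if err != nil {
        panic(err)
    }

    // Verify the first block against the root; a nil config makes Verify
    // fall back to DefaultHashFunc, as shown in the verify.go hunk above.
    ok, err := mt.Verify(blocks[0], tree.Proofs[0], tree.Root, nil)
    if err != nil {
        panic(err)
    }

    fmt.Println("proof valid:", ok)
}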