Skip to content

Commit

Permalink
ocr3 - changes related to multiramps (#1034)
Browse files Browse the repository at this point in the history
Update commit plugin logic to support multiramps.
With multiramps we are not able to compute the msg hash on the source
chain but only on the destination chain.

Changes:

0. Updated the spec. Check the spec diff for a summary of the changes.
1. Msg ID is computed on source and only used for tracing purposes.
2. Msg Hash is computed for the destination chain by the MsgHasher
destination chain specific implementation.
4. Merkle tree uses msg hash for leaves instead of msg id.
  • Loading branch information
dimkouv authored Jun 25, 2024
1 parent 795690d commit 2fbf8fd
Show file tree
Hide file tree
Showing 7 changed files with 227 additions and 191 deletions.
9 changes: 5 additions & 4 deletions core/services/ocr3/plugins/ccip/commit/plugin_e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package commit
import (
"context"
"reflect"
"strconv"
"testing"
"time"

Expand Down Expand Up @@ -301,8 +302,8 @@ func setupAllNodesReadAllChains(ctx context.Context, t *testing.T, lggr logger.L
chainB,
cciptypes.NewSeqNumRange(21, cciptypes.SeqNum(21+cfg.NewMsgScanBatchSize)),
).Return([]cciptypes.CCIPMsg{
{CCIPMsgBaseDetails: cciptypes.CCIPMsgBaseDetails{ID: cciptypes.Bytes32{1}, SourceChain: chainB, SeqNum: 21}},
{CCIPMsgBaseDetails: cciptypes.CCIPMsgBaseDetails{ID: cciptypes.Bytes32{2}, SourceChain: chainB, SeqNum: 22}},
{CCIPMsgBaseDetails: cciptypes.CCIPMsgBaseDetails{MsgHash: cciptypes.Bytes32{1}, ID: "1", SourceChain: chainB, SeqNum: 21}},
{CCIPMsgBaseDetails: cciptypes.CCIPMsgBaseDetails{MsgHash: cciptypes.Bytes32{2}, ID: "2", SourceChain: chainB, SeqNum: 22}},
}, nil)

n.ccipReader.On("GasPrices", ctx, []cciptypes.ChainSelector{chainA, chainB}).
Expand Down Expand Up @@ -401,8 +402,8 @@ func setupNodesDoNotAgreeOnMsgs(ctx context.Context, t *testing.T, lggr logger.L
cciptypes.SeqNum(21+cfg.NewMsgScanBatchSize),
),
).Return([]cciptypes.CCIPMsg{
{CCIPMsgBaseDetails: cciptypes.CCIPMsgBaseDetails{ID: cciptypes.Bytes32{1, byte(i)}, SourceChain: chainB, SeqNum: 21 + cciptypes.SeqNum(i*10)}},
{CCIPMsgBaseDetails: cciptypes.CCIPMsgBaseDetails{ID: cciptypes.Bytes32{2, byte(i)}, SourceChain: chainB, SeqNum: 22 + cciptypes.SeqNum(i*20)}},
{CCIPMsgBaseDetails: cciptypes.CCIPMsgBaseDetails{MsgHash: cciptypes.Bytes32{1}, ID: "1" + strconv.Itoa(i), SourceChain: chainB, SeqNum: 21 + cciptypes.SeqNum(i*10)}},
{CCIPMsgBaseDetails: cciptypes.CCIPMsgBaseDetails{MsgHash: cciptypes.Bytes32{2}, ID: "2" + strconv.Itoa(i), SourceChain: chainB, SeqNum: 22 + cciptypes.SeqNum(i*20)}},
}, nil)

n.ccipReader.On("GasPrices", ctx, []cciptypes.ChainSelector{chainA, chainB}).
Expand Down
93 changes: 45 additions & 48 deletions core/services/ocr3/plugins/ccip/commit/plugin_functions.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import (
)

// observeLatestCommittedSeqNums finds the maximum committed sequence numbers for each source chain.
// If we cannot observe the dest we return an empty slice and no error..
// If we cannot observe the dest we return an empty slice and no error.
func observeLatestCommittedSeqNums(
ctx context.Context,
lggr logger.Logger,
Expand Down Expand Up @@ -85,16 +85,12 @@ func observeNewMsgs(
lggr.Debugw("no new messages discovered", "chain", seqNumChain.ChainSel)
}

for _, msg := range newMsgs {
msgHash, err := msgHasher.Hash(ctx, msg)
for i := range newMsgs {
h, err := msgHasher.Hash(ctx, newMsgs[i])
if err != nil {
return fmt.Errorf("hash message: %w", err)
}

if msgHash != msg.ID {
lggr.Warnw("invalid message discovered", "msg", msg, "err", err)
continue
}
newMsgs[i].MsgHash = h // populate msgHash field
}

newMsgsPerChain[chainIdx] = newMsgs
Expand Down Expand Up @@ -244,52 +240,52 @@ func newMsgsConsensusForChain(
lggr.Debugw("observed messages consensus",
"chain", chainSel, "fChain", fChain, "observedMsgs", len(observedMsgs))

// First come to consensus about the (sequence number, id) pairs.
// For each sequence number consider correct the ID with the most votes.
msgSeqNumToIDCounts := make(map[cciptypes.SeqNum]map[string]int) // seqNum -> msgID -> count
// First come to consensus about the (sequence number, msg hash) pairs.
// For each sequence number consider the Hash with the most votes.
msgSeqNumToHashCounts := make(map[cciptypes.SeqNum]map[string]int) // seqNum -> msgHash -> count
for _, msg := range observedMsgs {
if _, exists := msgSeqNumToIDCounts[msg.SeqNum]; !exists {
msgSeqNumToIDCounts[msg.SeqNum] = make(map[string]int)
if _, exists := msgSeqNumToHashCounts[msg.SeqNum]; !exists {
msgSeqNumToHashCounts[msg.SeqNum] = make(map[string]int)
}
msgSeqNumToIDCounts[msg.SeqNum][msg.ID.String()]++
msgSeqNumToHashCounts[msg.SeqNum][msg.MsgHash.String()]++
}
lggr.Debugw("observed message counts", "chain", chainSel, "msgSeqNumToIdCounts", msgSeqNumToIDCounts)
lggr.Debugw("observed message counts", "chain", chainSel, "msgSeqNumToHashCounts", msgSeqNumToHashCounts)

msgObservationsCount := make(map[cciptypes.SeqNum]int)
msgSeqNumToID := make(map[cciptypes.SeqNum]cciptypes.Bytes32)
for seqNum, idCounts := range msgSeqNumToIDCounts {
if len(idCounts) == 0 {
lggr.Errorw("critical error id counts should never be empty", "seqNum", seqNum)
msgSeqNumToHash := make(map[cciptypes.SeqNum]cciptypes.Bytes32)
for seqNum, hashCounts := range msgSeqNumToHashCounts {
if len(hashCounts) == 0 {
lggr.Fatalw("hash counts should never be empty", "seqNum", seqNum)
continue
}

// Find the ID with the most votes for each sequence number.
idsSlice := make([]string, 0, len(idCounts))
for id := range idCounts {
idsSlice = append(idsSlice, id)
// Find the MsgHash with the most votes for each sequence number.
hashesSlice := make([]string, 0, len(hashCounts))
for h := range hashCounts {
hashesSlice = append(hashesSlice, h)
}
// determinism in case we have the same count for different ids
sort.Slice(idsSlice, func(i, j int) bool { return idsSlice[i] < idsSlice[j] })
// determinism in case we have the same count for different hashes
sort.Slice(hashesSlice, func(i, j int) bool { return hashesSlice[i] < hashesSlice[j] })

maxCnt := idCounts[idsSlice[0]]
mostVotedID := idsSlice[0]
for _, id := range idsSlice[1:] {
cnt := idCounts[id]
maxCnt := hashCounts[hashesSlice[0]]
mostVotedHash := hashesSlice[0]
for _, h := range hashesSlice[1:] {
cnt := hashCounts[h]
if cnt > maxCnt {
maxCnt = cnt
mostVotedID = id
mostVotedHash = h
}
}

msgObservationsCount[seqNum] = maxCnt
idBytes, err := cciptypes.NewBytes32FromString(mostVotedID)
hashBytes, err := cciptypes.NewBytes32FromString(mostVotedHash)
if err != nil {
return observedMsgsConsensus{}, fmt.Errorf("critical issue converting id '%s' to bytes32: %w",
mostVotedID, err)
return observedMsgsConsensus{}, fmt.Errorf("critical issue converting hash '%s' to bytes32: %w",
mostVotedHash, err)
}
msgSeqNumToID[seqNum] = idBytes
msgSeqNumToHash[seqNum] = hashBytes
}
lggr.Debugw("observed message consensus", "chain", chainSel, "msgSeqNumToId", msgSeqNumToID)
lggr.Debugw("observed message consensus", "chain", chainSel, "msgSeqNumToHash", msgSeqNumToHash)

// Filter out msgs not observed by at least 2f_chain+1 followers.
msgSeqNumsQuorum := mapset.NewSet[cciptypes.SeqNum]()
Expand All @@ -313,22 +309,13 @@ func newMsgsConsensusForChain(
seqNumConsensusRange.SetEnd(seqNum)
}

msgsBySeqNum := make(map[cciptypes.SeqNum]cciptypes.CCIPMsgBaseDetails)
for _, msg := range observedMsgs {
consensusMsgID, ok := msgSeqNumToID[msg.SeqNum]
if !ok || consensusMsgID != msg.ID {
continue
}
msgsBySeqNum[msg.SeqNum] = msg
}

treeLeaves := make([][32]byte, 0)
for seqNum := seqNumConsensusRange.Start(); seqNum <= seqNumConsensusRange.End(); seqNum++ {
msg, ok := msgsBySeqNum[seqNum]
msgHash, ok := msgSeqNumToHash[seqNum]
if !ok {
return observedMsgsConsensus{}, fmt.Errorf("msg not found in map for seq num %d", seqNum)
return observedMsgsConsensus{}, fmt.Errorf("msg hash not found for seq num %d", seqNum)
}
treeLeaves = append(treeLeaves, msg.ID)
treeLeaves = append(treeLeaves, msgHash)
}

lggr.Debugw("constructing merkle tree", "chain", chainSel, "treeLeaves", len(treeLeaves))
Expand Down Expand Up @@ -491,18 +478,28 @@ func validateObservedSequenceNumbers(msgs []cciptypes.CCIPMsgBaseDetails, maxSeq
}

seqNums := make(map[cciptypes.ChainSelector]mapset.Set[cciptypes.SeqNum], len(msgs))
hashes := mapset.NewSet[string]()
for _, msg := range msgs {
// The same sequence number must not appear more than once for the same chain and must be valid.
if msg.MsgHash.IsEmpty() {
return fmt.Errorf("observed msg hash must not be empty")
}

if _, exists := seqNums[msg.SourceChain]; !exists {
seqNums[msg.SourceChain] = mapset.NewSet[cciptypes.SeqNum]()
}

// The same sequence number must not appear more than once for the same chain and must be valid.
if seqNums[msg.SourceChain].Contains(msg.SeqNum) {
return fmt.Errorf("duplicate sequence number %d for chain %d", msg.SeqNum, msg.SourceChain)
}
seqNums[msg.SourceChain].Add(msg.SeqNum)

// The observed msg hash cannot appear twice for different msgs.
if hashes.Contains(msg.MsgHash.String()) {
return fmt.Errorf("duplicate msg hash %s", msg.MsgHash.String())
}
hashes.Add(msg.MsgHash.String())

// The observed msg sequence number cannot be less than or equal to the max observed sequence number.
maxSeqNum, exists := maxSeqNumsMap[msg.SourceChain]
if !exists {
Expand Down
Loading

0 comments on commit 2fbf8fd

Please sign in to comment.