diff --git a/peers/app_request_network.go b/peers/app_request_network.go index f25a8518..df4e20e0 100644 --- a/peers/app_request_network.go +++ b/peers/app_request_network.go @@ -212,12 +212,12 @@ type ConnectedCanonicalValidators struct { ConnectedWeight uint64 TotalValidatorWeight uint64 ValidatorSet []*warp.Validator - nodeValidatorIndexMap map[ids.NodeID]int + NodeValidatorIndexMap map[ids.NodeID]int } // Returns the Warp Validator and its index in the canonical Validator ordering for a given nodeID func (c *ConnectedCanonicalValidators) GetValidator(nodeID ids.NodeID) (*warp.Validator, int) { - return c.ValidatorSet[c.nodeValidatorIndexMap[nodeID]], c.nodeValidatorIndexMap[nodeID] + return c.ValidatorSet[c.NodeValidatorIndexMap[nodeID]], c.NodeValidatorIndexMap[nodeID] } // ConnectToCanonicalValidators connects to the canonical validators of the given subnet and returns the connected @@ -258,7 +258,7 @@ func (n *appRequestNetwork) ConnectToCanonicalValidators(subnetID ids.ID) (*Conn ConnectedWeight: connectedWeight, TotalValidatorWeight: totalValidatorWeight, ValidatorSet: validatorSet, - nodeValidatorIndexMap: nodeValidatorIndexMap, + NodeValidatorIndexMap: nodeValidatorIndexMap, }, nil } diff --git a/signature-aggregator/aggregator/aggregator.go b/signature-aggregator/aggregator/aggregator.go index a48ac728..756630b1 100644 --- a/signature-aggregator/aggregator/aggregator.go +++ b/signature-aggregator/aggregator/aggregator.go @@ -5,6 +5,7 @@ package aggregator import ( "bytes" + "encoding/hex" "errors" "fmt" "math/big" @@ -554,6 +555,7 @@ func (s *SignatureAggregator) isValidSignatureResponse( if !bls.Verify(pubKey, sig, unsignedMessage.Bytes()) { s.logger.Debug( "Failed verification for signature", + zap.String("pubKey", hex.EncodeToString(bls.PublicKeyToUncompressedBytes(pubKey))), ) return blsSignatureBuf{}, false } diff --git a/signature-aggregator/aggregator/aggregator_test.go b/signature-aggregator/aggregator/aggregator_test.go index 13a6952c..ff03646f 100644 --- a/signature-aggregator/aggregator/aggregator_test.go +++ b/signature-aggregator/aggregator/aggregator_test.go @@ -1,23 +1,37 @@ package aggregator import ( + "context" + "encoding/hex" + "fmt" + "os" "testing" "github.com/ava-labs/avalanchego/ids" "github.com/ava-labs/avalanchego/message" + "github.com/ava-labs/avalanchego/snow/validators" + "github.com/ava-labs/avalanchego/subnets" + "github.com/ava-labs/avalanchego/utils" "github.com/ava-labs/avalanchego/utils/constants" + "github.com/ava-labs/avalanchego/utils/crypto/bls" "github.com/ava-labs/avalanchego/utils/logging" + "github.com/ava-labs/avalanchego/utils/set" "github.com/ava-labs/avalanchego/vms/platformvm/warp" "github.com/ava-labs/awm-relayer/peers" "github.com/ava-labs/awm-relayer/peers/mocks" "github.com/ava-labs/awm-relayer/signature-aggregator/metrics" + evmMsg "github.com/ava-labs/subnet-evm/plugin/evm/message" "github.com/prometheus/client_golang/prometheus" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" ) -var sigAggMetrics *metrics.SignatureAggregatorMetrics -var messageCreator message.Creator +var ( + sigAggMetrics *metrics.SignatureAggregatorMetrics + messageCreator message.Creator +) func instantiateAggregator(t *testing.T) ( *SignatureAggregator, @@ -35,23 +49,77 @@ func instantiateAggregator(t *testing.T) ( constants.DefaultNetworkCompressionType, constants.DefaultNetworkMaximumInboundTimeout, ) - require.Equal(t, err, nil) + require.Equal(t, nil, err) } aggregator, err := NewSignatureAggregator( mockNetwork, - logging.NoLog{}, + logging.NewLogger( + "aggregator_test", + logging.NewWrappedCore( + logging.Debug, + os.Stdout, + zapcore.NewConsoleEncoder( + zap.NewProductionEncoderConfig(), + ), + ), + ), 1024, sigAggMetrics, messageCreator, ) - require.Equal(t, err, nil) + require.Equal(t, nil, err) return aggregator, mockNetwork } +func makeConnectedValidators(validatorCount int) (*peers.ConnectedCanonicalValidators, []*bls.SecretKey) { + var validatorSet []*warp.Validator + var validatorSecretKeys []*bls.SecretKey + + nodeValidatorIndexMap := make(map[ids.NodeID]int) + + for i := 0; i < validatorCount; i++ { + secretKey, err := bls.NewSecretKey() + if err != nil { + panic(err) + } + validatorSecretKeys = append(validatorSecretKeys, secretKey) + + pubKey := bls.PublicFromSecretKey(secretKey) + + nodeID, err := ids.ToNodeID(utils.RandomBytes(20)) + if err != nil { + panic(err) + } + nodeValidatorIndexMap[nodeID] = i + + fmt.Printf( + "validator with pubKey %s has nodeID %s\n", + hex.EncodeToString(bls.PublicKeyToUncompressedBytes(pubKey)), + nodeID.String(), + ) + + validatorSet = append(validatorSet, + &warp.Validator{ + PublicKey: pubKey, + PublicKeyBytes: bls.PublicKeyToUncompressedBytes(pubKey), + Weight: 1, + NodeIDs: []ids.NodeID{nodeID}, + }, + ) + } + + return &peers.ConnectedCanonicalValidators{ + ConnectedWeight: uint64(validatorCount), + TotalValidatorWeight: uint64(validatorCount), + ValidatorSet: validatorSet, + NodeValidatorIndexMap: nodeValidatorIndexMap, + }, validatorSecretKeys +} + func TestCreateSignedMessageFailsWithNoValidators(t *testing.T) { aggregator, mockNetwork := instantiateAggregator(t) msg, err := warp.NewUnsignedMessage(0, ids.Empty, []byte{}) - require.Equal(t, err, nil) + require.Equal(t, nil, err) mockNetwork.EXPECT().GetSubnetID(ids.Empty).Return(ids.Empty, nil) mockNetwork.EXPECT().ConnectToCanonicalValidators(ids.Empty).Return( &peers.ConnectedCanonicalValidators{ @@ -68,7 +136,7 @@ func TestCreateSignedMessageFailsWithNoValidators(t *testing.T) { func TestCreateSignedMessageFailsWithoutSufficientConnectedStake(t *testing.T) { aggregator, mockNetwork := instantiateAggregator(t) msg, err := warp.NewUnsignedMessage(0, ids.Empty, []byte{}) - require.Equal(t, err, nil) + require.Equal(t, nil, err) mockNetwork.EXPECT().GetSubnetID(ids.Empty).Return(ids.Empty, nil) mockNetwork.EXPECT().ConnectToCanonicalValidators(ids.Empty).Return( &peers.ConnectedCanonicalValidators{ @@ -85,3 +153,248 @@ func TestCreateSignedMessageFailsWithoutSufficientConnectedStake(t *testing.T) { "failed to connect to a threshold of stake", ) } + +func makeAppRequests( + chainID ids.ID, + requestID uint32, + connectedValidators *peers.ConnectedCanonicalValidators, +) []ids.RequestID { + var appRequests []ids.RequestID + for _, validator := range connectedValidators.ValidatorSet { + for _, nodeID := range validator.NodeIDs { + appRequests = append( + appRequests, + ids.RequestID{ + NodeID: nodeID, + SourceChainID: chainID, + DestinationChainID: chainID, + RequestID: requestID, + Op: byte( + message.AppResponseOp, + ), + }, + ) + } + } + return appRequests +} + +func TestCreateSignedMessageRetriesAndFailsWithoutP2PResponses(t *testing.T) { + aggregator, mockNetwork := instantiateAggregator(t) + + var ( + connectedValidators, _ = makeConnectedValidators(2) + requestID = aggregator.currentRequestID.Load() + 1 + ) + + chainID, err := ids.ToID(utils.RandomBytes(32)) + if err != nil { + panic(err) + } + + msg, err := warp.NewUnsignedMessage(0, chainID, []byte{}) + require.Equal(t, nil, err) + + subnetID, err := ids.ToID(utils.RandomBytes(32)) + require.Equal(t, nil, err) + mockNetwork.EXPECT().GetSubnetID(chainID).Return( + subnetID, + nil, + ) + + mockNetwork.EXPECT().ConnectToCanonicalValidators(subnetID).Return( + connectedValidators, + nil, + ) + + appRequests := makeAppRequests(chainID, requestID, connectedValidators) + for _, appRequest := range appRequests { + mockNetwork.EXPECT().RegisterAppRequest(appRequest).Times( + maxRelayerQueryAttempts, + ) + } + + mockNetwork.EXPECT().RegisterRequestID( + requestID, + len(appRequests), + ).Return( + make(chan message.InboundMessage, len(appRequests)), + ).Times(maxRelayerQueryAttempts) + + var nodeIDs set.Set[ids.NodeID] + for _, appRequest := range appRequests { + nodeIDs.Add(appRequest.NodeID) + } + mockNetwork.EXPECT().Send( + gomock.Any(), + nodeIDs, + subnetID, + subnets.NoOpAllower, + ).Times(maxRelayerQueryAttempts) + + _, err = aggregator.CreateSignedMessage(msg, subnetID, 80) + require.ErrorContains( + t, + err, + "failed to collect a threshold of signatures", + ) +} + +func TestCreateSignedMessageSucceeds(t *testing.T) { + var msg *warp.UnsignedMessage // to be signed + chainID, err := ids.ToID(utils.RandomBytes(32)) + if err != nil { + panic(err) + } + networkID := uint32(0) + msg, err = warp.NewUnsignedMessage( + networkID, + chainID, + utils.RandomBytes(1234), + ) + require.Equal(t, nil, err) + + // the signers: + var connectedValidators, validatorSecretKeys = makeConnectedValidators(5) + + // prime the aggregator: + + aggregator, mockNetwork := instantiateAggregator(t) + + subnetID, err := ids.ToID(utils.RandomBytes(32)) + require.Equal(t, nil, err) + mockNetwork.EXPECT().GetSubnetID(chainID).Return( + subnetID, + nil, + ) + + mockNetwork.EXPECT().ConnectToCanonicalValidators(subnetID).Return( + connectedValidators, + nil, + ) + + // prime the signers' responses: + + var requestID = aggregator.currentRequestID.Load() + 1 + + appRequests := makeAppRequests(chainID, requestID, connectedValidators) + for _, appRequest := range appRequests { + mockNetwork.EXPECT().RegisterAppRequest(appRequest).Times(1) + } + + var nodeIDs set.Set[ids.NodeID] + responseChan := make(chan message.InboundMessage, len(appRequests)) + for _, appRequest := range appRequests { + nodeIDs.Add(appRequest.NodeID) + validatorSecretKey := validatorSecretKeys[connectedValidators.NodeValidatorIndexMap[appRequest.NodeID]] + responseBytes, err := evmMsg.Codec.Marshal( + 0, + &evmMsg.SignatureResponse{ + Signature: [bls.SignatureLen]byte( + bls.SignatureToBytes( + bls.Sign( + validatorSecretKey, + msg.Bytes(), + ), + ), + ), + }, + ) + require.Equal(t, nil, err) + responseChan <- message.InboundAppResponse( + chainID, + requestID, + responseBytes, + appRequest.NodeID, + ) + } + mockNetwork.EXPECT().RegisterRequestID( + requestID, + len(appRequests), + ).Return(responseChan).Times(1) + + mockNetwork.EXPECT().Send( + gomock.Any(), + nodeIDs, + subnetID, + subnets.NoOpAllower, + ).Times(1).Return(nodeIDs) + + // aggregate the signatures: + var quorumPercentage uint64 = 80 + signedMessage, err := aggregator.CreateSignedMessage( + msg, + subnetID, + quorumPercentage, + ) + require.Equal(t, nil, err) + + // verify the aggregated signature: + pChainState := newPChainStateStub( + chainID, + subnetID, + 1, + connectedValidators, + ) + require.Equal( + t, + nil, + signedMessage.Signature.Verify( + context.Background(), + msg, + networkID, + pChainState, + pChainState.currentHeight, + quorumPercentage, + 100, + ), + ) +} + +type pChainStateStub struct { + subnetIDByChainID map[ids.ID]ids.ID + connectedCanonicalValidators *peers.ConnectedCanonicalValidators + currentHeight uint64 +} + +func newPChainStateStub( + chainID, subnetID ids.ID, + currentHeight uint64, + connectedValidators *peers.ConnectedCanonicalValidators, +) *pChainStateStub { + subnetIDByChainID := make(map[ids.ID]ids.ID) + subnetIDByChainID[chainID] = subnetID + return &pChainStateStub{ + subnetIDByChainID: subnetIDByChainID, + connectedCanonicalValidators: connectedValidators, + currentHeight: currentHeight, + } +} + +func (p pChainStateStub) GetSubnetID(ctx context.Context, chainID ids.ID) (ids.ID, error) { + return p.subnetIDByChainID[chainID], nil +} + +func (p pChainStateStub) GetMinimumHeight(context.Context) (uint64, error) { return 0, nil } + +func (p pChainStateStub) GetCurrentHeight(context.Context) (uint64, error) { + return p.currentHeight, nil +} + +func (p pChainStateStub) GetValidatorSet( + ctx context.Context, + height uint64, + subnetID ids.ID, +) (map[ids.NodeID]*validators.GetValidatorOutput, error) { + output := make(map[ids.NodeID]*validators.GetValidatorOutput) + for _, validator := range p.connectedCanonicalValidators.ValidatorSet { + for _, nodeID := range validator.NodeIDs { + output[nodeID] = &validators.GetValidatorOutput{ + NodeID: nodeID, + PublicKey: validator.PublicKey, + Weight: validator.Weight, + } + } + } + return output, nil +} diff --git a/signature-aggregator/aggregator/cache/cache.go b/signature-aggregator/aggregator/cache/cache.go index 475a6fdc..89d30941 100644 --- a/signature-aggregator/aggregator/cache/cache.go +++ b/signature-aggregator/aggregator/cache/cache.go @@ -1,6 +1,7 @@ package cache import ( + "encoding/hex" "math" "github.com/ava-labs/avalanchego/ids" @@ -41,7 +42,16 @@ func (c *Cache) Get(msgID ids.ID) (map[PublicKeyBytes]SignatureBytes, bool) { cachedValue, isCached := c.signatures.Get(msgID) if isCached { - c.logger.Debug("cache hit", zap.Stringer("msgID", msgID)) + var encodedKeys []string + for key := range cachedValue { + encodedKeys = append(encodedKeys, hex.EncodeToString(key[:])) + } + c.logger.Debug( + "cache hit", + zap.Stringer("msgID", msgID), + zap.Int("signatureCount", len(cachedValue)), + zap.Strings("public keys", encodedKeys), + ) return cachedValue, true } else { c.logger.Debug("cache miss", zap.Stringer("msgID", msgID))