diff --git a/config.go b/config.go index 6339d00e..c015cbbc 100644 --- a/config.go +++ b/config.go @@ -89,6 +89,10 @@ type Config[H Hash] struct { // Note that PreBlock-dependent PreCommit verification should be performed inside PreBlock.Verify // callback. VerifyPreCommit func(p ConsensusPayload[H]) error + // VerifyCommit performs external Commit verification and returns nil if it's successful. + // Note that Block-dependent Commit verification should be performed inside Block.Verify + // callback. + VerifyCommit func(p ConsensusPayload[H]) error } const defaultSecondsPerBlock = time.Second * 15 @@ -151,6 +155,8 @@ func checkConfig[H Hash](cfg *Config[H]) error { return errors.New("NewRecoveryRequest is nil") } else if cfg.NewRecoveryMessage == nil { return errors.New("NewRecoveryMessage is nil") + } else if cfg.VerifyCommit == nil { + return errors.New("VerifyCommit is nil") } else if cfg.AntiMEVExtensionEnablingHeight >= 0 { if cfg.NewPreBlockFromContext == nil { return errors.New("NewPreBlockFromContext is nil") @@ -400,3 +406,10 @@ func WithVerifyPreCommit[H Hash](f func(preCommit ConsensusPayload[H]) error) fu cfg.VerifyPreCommit = f } } + +// WithVerifyCommit sets VerifyCommit. +func WithVerifyCommit[H Hash](f func(commit ConsensusPayload[H]) error) func(config *Config[H]) { + return func(cfg *Config[H]) { + cfg.VerifyCommit = f + } +} diff --git a/dbft.go b/dbft.go index 5efaf31b..09975cea 100644 --- a/dbft.go +++ b/dbft.go @@ -586,6 +586,12 @@ func (d *DBFT[H]) onCommit(msg ConsensusPayload[H]) { } d.CommitPayloads[msg.ValidatorIndex()] = msg if d.ViewNumber == msg.ViewNumber() { + if err := d.VerifyCommit(msg); err != nil { + d.CommitPayloads[msg.ValidatorIndex()] = nil + d.Logger.Warn("invalid Commit", zap.Uint16("from", msg.ValidatorIndex()), zap.String("error", err.Error())) + return + } + d.Logger.Info("received Commit", zap.Uint("validator", uint(msg.ValidatorIndex()))) d.extendTimer(4) header := d.MakeHeader() diff --git a/dbft_test.go b/dbft_test.go index 2fcdea1a..c31305fc 100644 --- a/dbft_test.go +++ b/dbft_test.go @@ -578,6 +578,14 @@ func TestDBFT_Invalid(t *testing.T) { opts = append(opts, dbft.WithNewRecoveryMessage[crypto.Uint256](func() dbft.RecoveryMessage[crypto.Uint256] { return nil })) + t.Run("without VerifyCommit", func(t *testing.T) { + _, err := dbft.New(opts...) + require.Error(t, err) + }) + + opts = append(opts, dbft.WithVerifyCommit[crypto.Uint256](func(dbft.ConsensusPayload[crypto.Uint256]) error { + return nil + })) t.Run("with all defaults", func(t *testing.T) { d, err := dbft.New(opts...) require.NoError(t, err) @@ -1150,6 +1158,7 @@ func (s *testState) getOptions() []func(*dbft.Config[crypto.Uint256]) { dbft.WithNewRecoveryMessage[crypto.Uint256](func() dbft.RecoveryMessage[crypto.Uint256] { return consensus.NewRecoveryMessage(nil) }), + dbft.WithVerifyCommit[crypto.Uint256](func(p dbft.ConsensusPayload[crypto.Uint256]) error { return nil }), } verify := s.verify diff --git a/internal/consensus/consensus.go b/internal/consensus/consensus.go index e0ebdd6d..37653efc 100644 --- a/internal/consensus/consensus.go +++ b/internal/consensus/consensus.go @@ -40,6 +40,7 @@ func New(logger *zap.Logger, key dbft.PrivateKey, pub dbft.PublicKey, dbft.WithGetValidators[crypto.Uint256](getValidators), dbft.WithVerifyPrepareRequest[crypto.Uint256](verifyPayload), dbft.WithVerifyPrepareResponse[crypto.Uint256](verifyPayload), + dbft.WithVerifyCommit[crypto.Uint256](verifyPayload), dbft.WithNewBlockFromContext[crypto.Uint256](newBlockFromContext), dbft.WithNewConsensusPayload[crypto.Uint256](defaultNewConsensusPayload),