Skip to content

Commit

Permalink
timer: move Timer inteface to dbft package
Browse files Browse the repository at this point in the history
Store Timer interface along with other dBFT interfaces and provide
default timer implementation in `timer` package. Create `dbft.HV`
interface and configuration for its customisation. Provide default
implementation of `dbft.HV` in `timer` package.

A part of #84.

Signed-off-by: Anna Shaleva <[email protected]>
  • Loading branch information
AnnaShaleva committed Mar 8, 2024
1 parent 32e1df2 commit f1519e9
Show file tree
Hide file tree
Showing 8 changed files with 147 additions and 88 deletions.
20 changes: 16 additions & 4 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"errors"
"time"

"github.com/nspcc-dev/dbft/timer"
"go.uber.org/zap"
)

Expand All @@ -14,7 +13,9 @@ type Config[H Hash] struct {
// Logger
Logger *zap.Logger
// Timer
Timer timer.Timer
Timer Timer
// NewHeightView is a constructor for [dbft.HV] object.
NewHeightView func(height uint32, view byte) HV
// SecondsPerBlock is the number of seconds that
// need to pass before another block will be accepted.
SecondsPerBlock time.Duration
Expand Down Expand Up @@ -85,7 +86,6 @@ func defaultConfig[H Hash]() *Config[H] {
// fields which are set to nil must be provided from client
return &Config[H]{
Logger: zap.NewNop(),
Timer: timer.New(),
SecondsPerBlock: defaultSecondsPerBlock,
TimestampIncrement: defaultTimestampIncrement,
GetKeyPair: nil,
Expand All @@ -110,6 +110,10 @@ func defaultConfig[H Hash]() *Config[H] {
func checkConfig[H Hash](cfg *Config[H]) error {
if cfg.GetKeyPair == nil {
return errors.New("private key is nil")
} else if cfg.Timer == nil {
return errors.New("Timer is nil")
} else if cfg.NewHeightView == nil {
return errors.New("NewHeightView is nil")
} else if cfg.CurrentHeight == nil {
return errors.New("CurrentHeight is nil")
} else if cfg.CurrentBlockHash == nil {
Expand Down Expand Up @@ -176,12 +180,20 @@ func WithLogger[H Hash](log *zap.Logger) func(config *Config[H]) {
}

// WithTimer sets Timer.
func WithTimer[H Hash](t timer.Timer) func(config *Config[H]) {
func WithTimer[H Hash](t Timer) func(config *Config[H]) {
return func(cfg *Config[H]) {
cfg.Timer = t
}
}

// WithNewHeightView sets NewHeightView constructor.
func WithNewHeightView[H Hash](f func(height uint32, view byte) HV) func(config *Config[H]) {
return func(cfg *Config[H]) {
cfg.NewHeightView = f
}

}

// WithSecondsPerBlock sets SecondsPerBlock.
func WithSecondsPerBlock[H Hash](d time.Duration) func(config *Config[H]) {
return func(cfg *Config[H]) {
Expand Down
14 changes: 5 additions & 9 deletions context.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@ import (
"crypto/rand"
"encoding/binary"
"time"

"github.com/nspcc-dev/dbft/timer"
)

// Context is a main dBFT structure which
Expand Down Expand Up @@ -67,7 +65,7 @@ type Context[H Hash] struct {
LastChangeViewPayloads []ConsensusPayload[H]
// LastSeenMessage array stores the height of the last seen message, for each validator.
// if this node never heard from validator i, LastSeenMessage[i] will be -1.
LastSeenMessage []*timer.HV
LastSeenMessage []*HV

lastBlockTimestamp uint64 // ns-precision timestamp from the last header (used for the next block timestamp calculations).
lastBlockTime time.Time // Wall clock time of when the last block was first seen (used for timer adjustments).
Expand Down Expand Up @@ -120,7 +118,7 @@ func (c *Context[H]) CountCommitted() (count int) {
// for this view and that hasn't sent the Commit message at the previous views.
func (c *Context[H]) CountFailed() (count int) {
for i, hv := range c.LastSeenMessage {
if c.CommitPayloads[i] == nil && (hv == nil || hv.Height < c.BlockIndex || hv.View < c.ViewNumber) {
if c.CommitPayloads[i] == nil && (hv == nil || (*hv).Height() < c.BlockIndex || (*hv).View() < c.ViewNumber) {
count++
}
}
Expand Down Expand Up @@ -200,7 +198,7 @@ func (c *Context[H]) reset(view byte, ts uint64) {
c.LastChangeViewPayloads = make([]ConsensusPayload[H], n)

if c.LastSeenMessage == nil {
c.LastSeenMessage = make([]*timer.HV, n)
c.LastSeenMessage = make([]*HV, n)
}
c.blockProcessed = false
} else {
Expand Down Expand Up @@ -233,10 +231,8 @@ func (c *Context[H]) reset(view byte, ts uint64) {
c.ViewNumber = view

if c.MyIndex >= 0 {
c.LastSeenMessage[c.MyIndex] = &timer.HV{
Height: c.BlockIndex,
View: c.ViewNumber,
}
hv := c.Config.NewHeightView(c.BlockIndex, c.ViewNumber)
c.LastSeenMessage[c.MyIndex] = &hv
}
}

Expand Down
23 changes: 10 additions & 13 deletions dbft.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"sync"
"time"

"github.com/nspcc-dev/dbft/timer"
"go.uber.org/zap"
)

Expand Down Expand Up @@ -169,22 +168,22 @@ func (d *DBFT[H]) OnTransaction(tx Transaction[H]) {
}

// OnTimeout advances state machine as if timeout was fired.
func (d *DBFT[H]) OnTimeout(hv timer.HV) {
func (d *DBFT[H]) OnTimeout(hv HV) {
if d.Context.WatchOnly() || d.BlockSent() {
return
}

if hv.Height != d.BlockIndex || hv.View != d.ViewNumber {
if hv.Height() != d.BlockIndex || hv.View() != d.ViewNumber {
d.Logger.Debug("timeout: ignore old timer",
zap.Uint32("height", hv.Height),
zap.Uint("view", uint(hv.View)))
zap.Uint32("height", hv.Height()),
zap.Uint("view", uint(hv.View())))

return
}

d.Logger.Debug("timeout",
zap.Uint32("height", hv.Height),
zap.Uint("view", uint(hv.View)))
zap.Uint32("height", hv.Height()),
zap.Uint("view", uint(hv.View())))

if d.IsPrimary() && !d.RequestSentOrReceived() {
d.sendPrepareRequest()
Expand Down Expand Up @@ -237,11 +236,9 @@ func (d *DBFT[H]) OnReceive(msg ConsensusPayload[H]) {
}

hv := d.LastSeenMessage[msg.ValidatorIndex()]
if hv == nil || hv.Height < msg.Height() || hv.View < msg.ViewNumber() {
d.LastSeenMessage[msg.ValidatorIndex()] = &timer.HV{
Height: msg.Height(),
View: msg.ViewNumber(),
}
if hv == nil || (*hv).Height() < msg.Height() || (*hv).View() < msg.ViewNumber() {
hv := d.Config.NewHeightView(msg.Height(), msg.ViewNumber())
d.LastSeenMessage[msg.ValidatorIndex()] = &hv
}

if d.BlockSent() && msg.Type() != RecoveryRequestType {
Expand Down Expand Up @@ -612,7 +609,7 @@ func (d *DBFT[H]) changeTimer(delay time.Duration) {
zap.Uint32("h", d.BlockIndex),
zap.Int("v", int(d.ViewNumber)),
zap.Duration("delay", delay))
d.Timer.Reset(timer.HV{Height: d.BlockIndex, View: d.ViewNumber}, delay)
d.Timer.Reset(d.Config.NewHeightView(d.BlockIndex, d.ViewNumber), delay)
}

func (d *DBFT[H]) extendTimer(count time.Duration) {
Expand Down
44 changes: 29 additions & 15 deletions dbft_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ func TestDBFT_OnStartPrimarySendPrepareRequest(t *testing.T) {
require.EqualValues(t, 2, p.ValidatorIndex())

t.Run("primary send ChangeView on timeout", func(t *testing.T) {
service.OnTimeout(timer.HV{Height: s.currHeight + 1})
service.OnTimeout(timer.HV{H: s.currHeight + 1})

// if there are many faulty must send RecoveryRequest
cv := s.tryRecv()
Expand All @@ -73,9 +73,10 @@ func TestDBFT_OnStartPrimarySendPrepareRequest(t *testing.T) {

// if all nodes are up must send ChangeView
for i := range service.LastSeenMessage {
service.LastSeenMessage[i] = &timer.HV{Height: s.currHeight + 1}
var hv dbft.HV = timer.HV{H: s.currHeight + 1}
service.LastSeenMessage[i] = &hv
}
service.OnTimeout(timer.HV{Height: s.currHeight + 1})
service.OnTimeout(timer.HV{H: s.currHeight + 1})

cv = s.tryRecv()
require.NotNil(t, cv)
Expand Down Expand Up @@ -166,7 +167,8 @@ func TestDBFT_OnReceiveRequestSendResponse(t *testing.T) {
service.Start(0)

for i := range service.LastSeenMessage {
service.LastSeenMessage[i] = &timer.HV{Height: s.currHeight + 1}
var hv dbft.HV = timer.HV{H: s.currHeight + 1}
service.LastSeenMessage[i] = &hv
}

p := s.getPrepareRequest(5, txs[0].Hash())
Expand Down Expand Up @@ -303,10 +305,10 @@ func TestDBFT_OnReceiveCommit(t *testing.T) {
require.NoError(t, service.Header().Verify(pub, cm.GetCommit().Signature()))

t.Run("send recovery message on timeout", func(t *testing.T) {
service.OnTimeout(timer.HV{Height: 1})
service.OnTimeout(timer.HV{H: 1})
require.Nil(t, s.tryRecv())

service.OnTimeout(timer.HV{Height: s.currHeight + 1})
service.OnTimeout(timer.HV{H: s.currHeight + 1})

r := s.tryRecv()
require.NotNil(t, r)
Expand Down Expand Up @@ -394,13 +396,13 @@ func TestDBFT_OnReceiveChangeView(t *testing.T) {
service.OnReceive(resp)
require.Nil(t, s.tryRecv())

service.OnTimeout(timer.HV{Height: s.currHeight + 1})
service.OnTimeout(timer.HV{H: s.currHeight + 1})
cv := s.tryRecv()
require.NotNil(t, cv)
require.Equal(t, dbft.ChangeViewType, cv.Type())

t.Run("primary sends prepare request after timeout", func(t *testing.T) {
service.OnTimeout(timer.HV{Height: s.currHeight + 1, View: 1})
service.OnTimeout(timer.HV{H: s.currHeight + 1, V: 1})
pr := s.tryRecv()
require.NotNil(t, pr)
require.Equal(t, dbft.PrepareRequestType, pr.Type())
Expand All @@ -418,6 +420,16 @@ func TestDBFT_Invalid(t *testing.T) {
require.NotNil(t, pub)

opts := []func(*dbft.Config[crypto.Uint256]){dbft.WithKeyPair[crypto.Uint256](priv, pub)}
t.Run("without Timer", func(t *testing.T) {
require.Nil(t, dbft.New(opts...))
})

opts = append(opts, dbft.WithTimer[crypto.Uint256](timer.New()))
t.Run("without NewHeightView", func(t *testing.T) {
require.Nil(t, dbft.New(opts...))
})

opts = append(opts, dbft.WithNewHeightView[crypto.Uint256](timer.NewHV))
t.Run("without CurrentHeight", func(t *testing.T) {
require.Nil(t, dbft.New(opts...))
})
Expand Down Expand Up @@ -570,7 +582,7 @@ func TestDBFT_FourGoodNodesDeadlock(t *testing.T) {
// (possible on timeout) and sends the ChangeView message.
s3.OnReceive(resp0V0)
s3.OnReceive(resp2V0)
s3.OnTimeout(timer.HV{Height: r3.currHeight + 1, View: 0})
s3.OnTimeout(timer.HV{H: r3.currHeight + 1, V: 0})
cv3V0 := r3.tryRecv()
require.NotNil(t, cv3V0)
require.Equal(t, dbft.ChangeViewType, cv3V0.Type())
Expand All @@ -580,7 +592,7 @@ func TestDBFT_FourGoodNodesDeadlock(t *testing.T) {
// current view) and sends the ChangeView message.
s1.OnReceive(resp0V0)
s1.OnReceive(cv3V0)
s1.OnTimeout(timer.HV{Height: r1.currHeight + 1, View: 0})
s1.OnTimeout(timer.HV{H: r1.currHeight + 1, V: 0})
cv1V0 := r1.tryRecv()
require.NotNil(t, cv1V0)
require.Equal(t, dbft.ChangeViewType, cv1V0.Type())
Expand All @@ -589,7 +601,7 @@ func TestDBFT_FourGoodNodesDeadlock(t *testing.T) {
// (possible on timeout after receiving at least M non-commit messages from the
// current view) and sends the ChangeView message.
s0.OnReceive(cv3V0)
s0.OnTimeout(timer.HV{Height: r0.currHeight + 1, View: 0})
s0.OnTimeout(timer.HV{H: r0.currHeight + 1, V: 0})
cv0V0 := r0.tryRecv()
require.NotNil(t, cv0V0)
require.Equal(t, dbft.ChangeViewType, cv0V0.Type())
Expand All @@ -605,7 +617,7 @@ func TestDBFT_FourGoodNodesDeadlock(t *testing.T) {
require.Equal(t, uint8(1), s0.ViewNumber)

// Step 10. The primary (at view 1) replica 0 sends the PrepareRequest message.
s0.OnTimeout(timer.HV{Height: r0.currHeight + 1, View: 1})
s0.OnTimeout(timer.HV{H: r0.currHeight + 1, V: 1})
reqV1 := r0.tryRecv()
require.NotNil(t, reqV1)
require.Equal(t, dbft.PrepareRequestType, reqV1.Type())
Expand All @@ -628,7 +640,7 @@ func TestDBFT_FourGoodNodesDeadlock(t *testing.T) {
// Intermediate step A. It is added to make step 14 possible. The backup (at
// view 1) replica 3 doesn't receive anything for a long time and sends
// RecoveryRequest.
s3.OnTimeout(timer.HV{Height: r3.currHeight + 1, View: 1})
s3.OnTimeout(timer.HV{H: r3.currHeight + 1, V: 1})
rcvr3V1 := r3.tryRecv()
require.NotNil(t, rcvr3V1)
require.Equal(t, dbft.RecoveryRequestType, rcvr3V1.Type())
Expand Down Expand Up @@ -663,15 +675,15 @@ func TestDBFT_FourGoodNodesDeadlock(t *testing.T) {
// of "lost" nodes. That's why we'he added Intermediate steps A and B.
//
// After that replica 1 is allowed to send the CV message.
s1.OnTimeout(timer.HV{Height: r1.currHeight + 1, View: 1})
s1.OnTimeout(timer.HV{H: r1.currHeight + 1, V: 1})
cv1V1 := r1.tryRecv()
require.NotNil(t, cv1V1)
require.Equal(t, dbft.ChangeViewType, cv1V1.Type())

// Step 13. The primary (at view 1) replica 0 decides to change its view
// (possible on timeout) and sends the ChangeView message.
s0.OnReceive(resp1V1)
s0.OnTimeout(timer.HV{Height: r0.currHeight + 1, View: 1})
s0.OnTimeout(timer.HV{H: r0.currHeight + 1, V: 1})
cv0V1 := r0.tryRecv()
require.NotNil(t, cv0V1)
require.Equal(t, dbft.ChangeViewType, cv0V1.Type())
Expand Down Expand Up @@ -806,6 +818,8 @@ func (s testState) copyWithIndex(myIndex int) *testState {

func (s *testState) getOptions() []func(*dbft.Config[crypto.Uint256]) {
opts := []func(*dbft.Config[crypto.Uint256]){
dbft.WithTimer[crypto.Uint256](timer.New()),
dbft.WithNewHeightView[crypto.Uint256](timer.NewHV),
dbft.WithCurrentHeight[crypto.Uint256](func() uint32 { return s.currHeight }),
dbft.WithCurrentBlockHash[crypto.Uint256](func() crypto.Uint256 { return s.currHash }),
dbft.WithGetValidators[crypto.Uint256](func(...dbft.Transaction[crypto.Uint256]) []dbft.PublicKey { return s.pubs }),
Expand Down
3 changes: 3 additions & 0 deletions internal/consensus/consensus.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (

"github.com/nspcc-dev/dbft"
"github.com/nspcc-dev/dbft/internal/crypto"
"github.com/nspcc-dev/dbft/timer"
"go.uber.org/zap"
)

Expand All @@ -18,6 +19,8 @@ func New(logger *zap.Logger, key dbft.PrivateKey, pub dbft.PublicKey,
getValidators func(...dbft.Transaction[crypto.Uint256]) []dbft.PublicKey,
verifyPayload func(consensusPayload dbft.ConsensusPayload[crypto.Uint256]) error) *dbft.DBFT[crypto.Uint256] {
return dbft.New[crypto.Uint256](
dbft.WithTimer[crypto.Uint256](timer.New()),
dbft.WithNewHeightView[crypto.Uint256](timer.NewHV),
dbft.WithLogger[crypto.Uint256](logger),
dbft.WithSecondsPerBlock[crypto.Uint256](time.Second*5),
dbft.WithKeyPair[crypto.Uint256](key, pub),
Expand Down
30 changes: 30 additions & 0 deletions timer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package dbft

import (
"time"
)

// Timer is an interface which implements all time-related
// functions. It can be mocked for testing.
type Timer interface {
// Now returns current time.
Now() time.Time
// Reset resets timer to the specified block height and view.
Reset(hv HV, d time.Duration)
// Sleep stops execution for duration d.
Sleep(d time.Duration)
// Extend extends current timer with duration d.
Extend(d time.Duration)
// Stop stops timer.
Stop()
// HV returns current height and view set for the timer.
HV() HV
// C returns channel for timer events.
C() <-chan time.Time
}

// HV is an abstraction for pair of a Height and a View.
type HV interface {
Height() uint32
View() byte
}
Loading

0 comments on commit f1519e9

Please sign in to comment.